diff --git a/envs/mini_swe_env/grading.py b/envs/mini_swe_env/grading.py
index d09f46a4c..d0c6c120d 100644
--- a/envs/mini_swe_env/grading.py
+++ b/envs/mini_swe_env/grading.py
@@ -4,16 +4,16 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
-"""Binary grading for SWE-Gym tasks.
+"""Grading helpers for SWE-Gym tasks.
-This module is intentionally SWE-Gym-native and does not depend on
-external repo/version parser maps.
+Supports both the canonical binary SWE-Gym reward and an optional
+case-fraction shaping mode for RL experiments:
-The grading contract matches SWE-Gym semantics:
-
-- reward = 1.0 iff every FAIL_TO_PASS test passes AND every PASS_TO_PASS
- test still passes.
-- reward = 0.0 otherwise.
+- ``binary``: reward = 1.0 iff every FAIL_TO_PASS test passes AND every
+ PASS_TO_PASS test still passes, else 0.0.
+- ``case_fraction``: reward = passed_cases / total_cases, while
+ ``resolved`` still requires every FAIL_TO_PASS and PASS_TO_PASS case to
+ pass.
"""
from __future__ import annotations
@@ -32,6 +32,7 @@ class GradeResult:
__slots__ = (
"reward",
+ "case_fraction",
"resolved",
"patch_applied",
"tests_status",
@@ -42,12 +43,14 @@ def __init__(
self,
*,
reward: float,
+ case_fraction: float,
resolved: bool,
patch_applied: bool,
tests_status: dict[str, Any] | None = None,
instance_id: str = "",
):
self.reward = reward
+ self.case_fraction = case_fraction
self.resolved = resolved
self.patch_applied = patch_applied
self.tests_status = tests_status
@@ -55,7 +58,8 @@ def __init__(
def __repr__(self) -> str:
return (
- f"GradeResult(reward={self.reward}, resolved={self.resolved}, "
+ f"GradeResult(reward={self.reward}, case_fraction={self.case_fraction}, "
+ f"resolved={self.resolved}, "
f"patch_applied={self.patch_applied}, instance_id={self.instance_id!r})"
)
@@ -65,6 +69,7 @@ def grade_from_case_results(
case_results: dict[str, bool],
*,
patch_applied: bool = True,
+ reward_mode: str = "binary",
) -> GradeResult:
"""Grade directly from per-test-case outcomes.
@@ -72,14 +77,24 @@ def grade_from_case_results(
task: SWE-Gym task with FAIL_TO_PASS and PASS_TO_PASS lists.
case_results: Mapping ``test_case -> passed``.
patch_applied: Whether test patch was successfully applied.
+ reward_mode: ``"binary"`` or ``"case_fraction"``.
"""
if not isinstance(case_results, dict):
raise GradingError("case_results must be a dict[str, bool]")
+ if reward_mode not in {"binary", "case_fraction"}:
+ raise GradingError(
+ f"Unknown reward_mode {reward_mode!r}; expected 'binary' or 'case_fraction'"
+ )
def _passed(case: str) -> bool:
return bool(case_results.get(case, False))
+ all_cases = list(dict.fromkeys([*task.FAIL_TO_PASS, *task.PASS_TO_PASS]))
+ passed_cases = sum(1 for case in all_cases if _passed(case))
+ total_cases = len(all_cases)
+ case_fraction = (passed_cases / total_cases) if total_cases else 0.0
+
resolved = all(_passed(case) for case in task.FAIL_TO_PASS) and all(
_passed(case) for case in task.PASS_TO_PASS
)
@@ -95,8 +110,13 @@ def _passed(case: str) -> bool:
},
}
+ reward = 1.0 if resolved else 0.0
+ if reward_mode == "case_fraction":
+ reward = case_fraction
+
return GradeResult(
- reward=1.0 if resolved else 0.0,
+ reward=reward,
+ case_fraction=case_fraction,
resolved=resolved,
patch_applied=patch_applied,
tests_status=tests_status,
diff --git a/envs/mini_swe_env/harness.py b/envs/mini_swe_env/harness.py
index e00255d1f..17ec2f815 100644
--- a/envs/mini_swe_env/harness.py
+++ b/envs/mini_swe_env/harness.py
@@ -50,11 +50,12 @@
import json
import queue as _queue_mod
import logging
+import os
import shlex
import time
import uuid
-from dataclasses import dataclass, field
-from typing import Any, Literal
+from dataclasses import dataclass, field, replace
+from typing import Any, Callable, Literal
from openenv.core.harness import Message, ResourceSessionFactory, VerifyResult
from openenv.core.harness.agents import get_agent_spec
@@ -79,6 +80,32 @@
VERIFY_TIMEOUT_S = 300
SETUP_TIMEOUT_S = 600
+DEFAULT_GIT_CHECKOUT_TIMEOUT_S = 300
+
+
+def _git_checkout_timeout_s() -> int:
+ for name in ("SWE_GIT_CHECKOUT_TIMEOUT_S", "SWE_LOCAL_GIT_CHECKOUT_TIMEOUT_S"):
+ value = os.environ.get(name)
+ if value is None:
+ continue
+ try:
+ timeout_s = int(value)
+ except ValueError:
+ _log.warning(
+ "%s must be an integer; using %ss",
+ name,
+ DEFAULT_GIT_CHECKOUT_TIMEOUT_S,
+ )
+ return DEFAULT_GIT_CHECKOUT_TIMEOUT_S
+ if timeout_s > 0:
+ return timeout_s
+ _log.warning(
+ "%s must be positive; using %ss",
+ name,
+ DEFAULT_GIT_CHECKOUT_TIMEOUT_S,
+ )
+ return DEFAULT_GIT_CHECKOUT_TIMEOUT_S
+ return DEFAULT_GIT_CHECKOUT_TIMEOUT_S
_ANSWER_TOOL_DEFINITION = {
"type": "function",
@@ -101,43 +128,103 @@
{problem_statement}
+{hints_block}
+
# Task Instructions
## Overview
-You're a software engineer working on a codebase at /testbed.
+You're a software engineer working on a codebase at {workdir}.
Your task is to fix the issue described in the PR description above
by making changes to the source code (non-test files).
## Important Boundaries
-- MODIFY: Regular source code files in /testbed
+- MODIFY: Regular source code files in {workdir}
- DO NOT MODIFY: Tests, configuration files (pyproject.toml, setup.cfg, etc.)
+- Test edits do not help. Grading may restore evaluation tests before running.
+- Plain text does not execute anything. To inspect files, run commands, edit
+ code, or submit a fix, you must use your available tools.
+- The repo environment is already bootstrapped. Do not create a new virtualenv
+ or reinstall the project unless a command clearly shows it is necessary.
+- Do not use `git commit`, `git branch`, or `git push`. Grading only checks the
+ final working tree state.
+- Each bash tool call runs in a fresh shell. If commands depend on shell state,
+ combine them into one bash invocation; `source` and shell variables do not
+ persist across separate tool calls.
+- Do not claim a file was changed or a test passed unless you actually used a
+ tool and observed that result.
## Recommended Workflow
-1. Analyze the codebase by finding and reading relevant files
-2. Create a script to reproduce the issue
-3. Edit the source code to resolve the issue
-4. Verify your fix works by running your script again
-5. Test edge cases to ensure your fix is robust
-
-## Submitting Your Answer
-When you've completed your work and verified your fix, call the `answer`
-tool to submit your solution for grading. This runs the test suite and
-returns whether the issue is resolved.
-
-You cannot continue working after submitting — make sure your fix is
-tested before calling `answer`.
+1. Start with the maintainer hints or issue description. If they point to a
+ likely source file or function, inspect that before broad repo-wide searches.
+ If the issue mentions an identifier or exact error text, grep for that first.
+2. Reproduce the issue with focused commands instead of exhaustive greps over
+ unrelated tests.
+3. Edit the source code to resolve the issue.
+4. Verify the fix with targeted commands, then check for obvious regressions.
+5. If the `answer` tool is available, call it only after you have actually
+ changed source code and verified the result.
+
+{submission_block}
"""
-def _wrap_instruction(problem_statement: str) -> str:
+def _bool_env(name: str, default: bool) -> bool:
+ raw = os.environ.get(name)
+ if raw is None:
+ return default
+ return raw.strip().lower() in {"1", "true", "yes", "on"}
+
+
+def _answer_tool_enabled() -> bool:
+ return _bool_env("SWE_ENABLE_ANSWER_TOOL", True)
+
+
+def _wrap_instruction(
+ problem_statement: str,
+ *,
+ hints_text: str = "",
+ workdir: str = TESTBED,
+ answer_tool_enabled: bool = True,
+) -> str:
"""Wrap a problem statement with SWE-Gym-style task instructions.
Tells the agent about the workflow, boundaries, and crucially
about the ``answer`` tool for submission.
"""
+ hints = (hints_text or "").strip()
+ hints_block = ""
+ if hints:
+ hints_block = (
+ "\n"
+ "Additional context from issue triage or maintainers:\n"
+ f"{hints}\n"
+ ""
+ )
+ if answer_tool_enabled:
+ submission_block = (
+ "## Submitting Your Answer\n"
+ "When you've completed your work and verified your fix, call the "
+ "`answer`\n"
+ "tool to submit your solution for grading. This runs the test suite and\n"
+ "returns whether the issue is resolved.\n\n"
+ "You cannot continue working after submitting — make sure your fix is\n"
+ "tested before calling `answer`."
+ )
+ else:
+ submission_block = (
+ "## Ending The Run\n"
+ "There is no `answer` tool in this run.\n"
+ "Keep working until you have made and checked the best source-code "
+ "fix you can.\n"
+ "Your final repo state will be graded automatically when the session "
+ "ends."
+ )
return _SWE_INSTRUCTION_TEMPLATE.format(
problem_statement=problem_statement,
+ hints_block=hints_block,
+ workdir=workdir,
+ submission_block=submission_block,
)
@@ -197,6 +284,9 @@ def __init__(
self._answer_reward_source: str | None = None
self._answer_called = False
self._answer_bridged = False
+ self._fallback_grader: (
+ Callable[..., tuple[float, bool]] | None
+ ) = None
@property
def swe_task(self) -> SWETask:
@@ -294,7 +384,9 @@ def verify(
t0 = time.time()
try:
r = self.sandbox.exec(
- cmd, cwd=TESTBED, timeout=self._verify_timeout_s
+ cmd,
+ cwd=self.config.workdir,
+ timeout=self._verify_timeout_s,
)
detail = {
"cmd": cmd,
@@ -332,7 +424,37 @@ def verify(
},
)
- # 4. No reward source — agent didn't call answer, no verify cmds.
+ # 4. Final-state fallback: grade the sandbox even if the agent forgot
+ # to call answer(). This preserves valid reward signal for training.
+ if self._fallback_grader is not None:
+ try:
+ reward, resolved = self._fallback_grader(
+ self.sandbox,
+ self._swe_task,
+ home=self.config.sandbox_home,
+ workdir=self.config.workdir,
+ )
+ return VerifyResult(
+ env_reward=float(reward),
+ done=True,
+ metrics={
+ "instance_id": self._swe_task.instance_id,
+ "reward_source": "host_verify_fallback",
+ "resolved": bool(resolved),
+ "answer_called": False,
+ "answer_bridged": False,
+ },
+ artifacts={
+ "task_id": self._swe_task.task_id,
+ },
+ )
+ except Exception:
+ _log.exception(
+ "fallback grading failed for %s",
+ self._swe_task.instance_id,
+ )
+
+ # 5. No reward source — agent didn't call answer, no verify cmds.
return VerifyResult(
env_reward=0.0,
done=True,
@@ -476,6 +598,10 @@ def create(
swe_task = coerce_swe_task(task)
validate_swe_task(swe_task)
+ backend_supports_images = bool(
+ getattr(self._backend, "supports_images", True)
+ )
+ requested_image = swe_task.sandbox_image if backend_supports_images else None
sandbox_timeout = int(self._config.agent_timeout_s) + 600
sandbox = self._backend.create(
timeout_s=sandbox_timeout,
@@ -484,17 +610,31 @@ def create(
if episode_id
else {"instance_id": swe_task.instance_id}
),
- image=swe_task.sandbox_image,
+ image=requested_image,
+ )
+
+ session_config = replace(
+ self._config,
+ sandbox_home=self._resolve_sandbox_home(sandbox),
+ workdir=self._resolve_workdir(sandbox),
)
try:
- if not swe_task.sandbox_image:
- self._prepare_repo(sandbox, swe_task)
+ if not requested_image:
+ self._prepare_repo(sandbox, swe_task, workdir=session_config.workdir)
+ self._bootstrap_local_repo_env(
+ sandbox,
+ swe_task,
+ config=session_config,
+ )
- self._run_setup(sandbox, swe_task)
+ self._run_setup(sandbox, swe_task, workdir=session_config.workdir)
- agent_task = self._build_agent_task(swe_task)
- self._driver._bootstrap_sandbox(sandbox, agent_task, self._config)
+ agent_task = self._build_agent_task(
+ swe_task,
+ workdir=session_config.workdir,
+ )
+ self._driver._bootstrap_sandbox(sandbox, agent_task, session_config)
except Exception as exc:
_log.error("SWESessionFactory.create: bootstrap failed: %r", exc)
@@ -516,9 +656,15 @@ def create(
rollout_id,
)
- agent_task = self._build_agent_task(swe_task)
+ agent_task = self._build_agent_task(
+ swe_task,
+ workdir=session_config.workdir,
+ )
agent_bg = self._driver._start_agent(
- sandbox, agent_task, self._config, base_url_override=base_url_override
+ sandbox,
+ agent_task,
+ session_config,
+ base_url_override=base_url_override,
)
session = SWESession(
@@ -527,7 +673,7 @@ def create(
spec=self._spec,
sandbox=sandbox,
task=agent_task,
- config=self._config,
+ config=session_config,
base_url_override=base_url_override,
agent_bg_job=agent_bg,
interception_server=self._interception_server,
@@ -535,19 +681,38 @@ def create(
interception_queue=interception_queue,
)
- if self._mode == "interception_gate":
+ if self._mode == "interception_gate" and _answer_tool_enabled():
self._register_answer_tool(session)
+ session._fallback_grader = self._grade_answer_submission
return session
# ── Bootstrap helpers ──────────────────────────────────────────────────
- def _prepare_repo(self, sandbox: SandboxHandle, task: SWETask) -> None:
+ def _resolve_sandbox_home(self, sandbox: SandboxHandle) -> str:
+ home = getattr(sandbox, "sandbox_home", None)
+ if isinstance(home, str) and home.strip():
+ return home
+ return self._config.sandbox_home
+
+ def _resolve_workdir(self, sandbox: SandboxHandle) -> str:
+ workdir = getattr(sandbox, "workdir", None)
+ if isinstance(workdir, str) and workdir.strip():
+ return workdir
+ return self._config.workdir
+
+ def _prepare_repo(
+ self,
+ sandbox: SandboxHandle,
+ task: SWETask,
+ *,
+ workdir: str,
+ ) -> None:
"""Clone the repo and reset to base_commit."""
- sandbox.exec(f"mkdir -p {TESTBED}", timeout=10)
+ sandbox.exec(f"mkdir -p {shlex.quote(workdir)}", timeout=10)
clone_url = f"https://github.com/{task.repo}.git"
r = sandbox.exec(
- f"git clone --quiet {clone_url} {TESTBED}",
+ f"git clone --quiet {clone_url} {shlex.quote(workdir)}",
timeout=SETUP_TIMEOUT_S,
)
if r.exit_code != 0:
@@ -556,32 +721,94 @@ def _prepare_repo(self, sandbox: SandboxHandle, task: SWETask) -> None:
)
r = sandbox.exec(
f"git checkout --quiet {task.base_commit}",
- cwd=TESTBED,
- timeout=60,
+ cwd=workdir,
+ timeout=_git_checkout_timeout_s(),
)
if r.exit_code != 0:
raise RuntimeError(
f"git checkout failed (exit {r.exit_code}): {r.stderr[:500]}"
)
- def _run_setup(self, sandbox: SandboxHandle, task: SWETask) -> None:
+ def _run_setup(
+ self,
+ sandbox: SandboxHandle,
+ task: SWETask,
+ *,
+ workdir: str,
+ ) -> None:
"""Run task setup commands in the workspace."""
for cmd in task.setup:
- r = sandbox.exec(cmd, cwd=TESTBED, timeout=SETUP_TIMEOUT_S)
+ r = sandbox.exec(cmd, cwd=workdir, timeout=SETUP_TIMEOUT_S)
if r.exit_code != 0:
raise RuntimeError(
f"Setup command failed (exit {r.exit_code}): "
f"{cmd[:120]}\nstderr: {(r.stderr or '')[:500]}"
)
- def _build_agent_task(self, swe_task: SWETask) -> _SWEAgentTask:
+ def _bootstrap_local_repo_env(
+ self,
+ sandbox: SandboxHandle,
+ swe_task: SWETask,
+ *,
+ config: SWEAgentConfig,
+ ) -> None:
+ """Install repo/runtime deps when the backend cannot provide task images.
+
+ SWE-Gym tasks usually rely on prebuilt per-task images. For rootless
+ local sandboxes we recreate just enough of that environment to run the
+ repeated-task pilot by installing the repo editable plus common test
+ dependencies inside the sandbox-local virtualenv.
+ """
+ del swe_task
+ workdir_q = shlex.quote(config.workdir)
+ commands = [
+ "python -m pip install -U pip setuptools wheel",
+ "python -m pip install pytest",
+ (
+ f"cd {workdir_q} && ("
+ "python -m pip install -e .[all] || "
+ "python -m pip install -e .[tests] || "
+ "python -m pip install -e .[test] || "
+ "python -m pip install -e . || "
+ "python -m pip install .)"
+ ),
+ ]
+ if sandbox.exists(f"{config.workdir}/requirements-tests.txt"):
+ commands.append(
+ f"cd {workdir_q} && python -m pip install -r requirements-tests.txt"
+ )
+ elif sandbox.exists(f"{config.workdir}/requirements-test.txt"):
+ commands.append(
+ f"cd {workdir_q} && python -m pip install -r requirements-test.txt"
+ )
+
+ for cmd in commands:
+ result = sandbox.exec(cmd, cwd=config.workdir, timeout=SETUP_TIMEOUT_S)
+ if result.exit_code != 0:
+ raise RuntimeError(
+ "local sandbox repo bootstrap failed "
+ f"(exit {result.exit_code}): {(result.stderr or result.stdout)[-500:]}"
+ )
+
+ def _build_agent_task(
+ self,
+ swe_task: SWETask,
+ *,
+ workdir: str,
+ ) -> _SWEAgentTask:
"""Convert SWETask into the shape CLIAgentDriver expects.
Wraps the raw problem statement with SWE-Gym-style instructions
that tell the agent about the ``answer`` tool.
"""
+ answer_tool_enabled = _answer_tool_enabled()
return _SWEAgentTask(
- instruction=_wrap_instruction(swe_task.instruction),
+ instruction=_wrap_instruction(
+ swe_task.instruction,
+ hints_text=str((swe_task.metadata or {}).get("hints_text", "") or ""),
+ workdir=workdir,
+ answer_tool_enabled=answer_tool_enabled,
+ ),
setup_shell=None,
metadata={
"task_id": swe_task.task_id,
@@ -614,6 +841,8 @@ async def _answer_handler(arguments: dict[str, Any]) -> dict[str, Any]:
self._grade_answer_submission,
session.sandbox,
session.swe_task,
+ home=session.config.sandbox_home,
+ workdir=session.config.workdir,
)
session.set_answer_reward(reward, source="host_answer_tool")
return {
@@ -635,14 +864,26 @@ def _grade_answer_submission(
self,
sandbox: SandboxHandle,
swe_task: SWETask,
+ *,
+ home: str,
+ workdir: str,
) -> tuple[float, bool]:
"""Compute answer-tool reward on host and return ``(reward, resolved)``."""
try:
metadata = swe_task.metadata or {}
required = {"version", "patch", "test_patch", "FAIL_TO_PASS"}
if required.issubset(metadata):
- return self._grade_with_swegym_metadata(sandbox, swe_task)
- return self._grade_with_verify_commands(sandbox, swe_task)
+ return self._grade_with_swegym_metadata(
+ sandbox,
+ swe_task,
+ home=home,
+ workdir=workdir,
+ )
+ return self._grade_with_verify_commands(
+ sandbox,
+ swe_task,
+ workdir=workdir,
+ )
except Exception:
_log.exception("answer-tool grading failed for %s", swe_task.instance_id)
return 0.0, False
@@ -651,6 +892,9 @@ def _grade_with_swegym_metadata(
self,
sandbox: SandboxHandle,
swe_task: SWETask,
+ *,
+ home: str,
+ workdir: str,
) -> tuple[float, bool]:
"""Grade SWE-Gym tasks directly from FAIL/PASS test-case outcomes."""
metadata = swe_task.metadata
@@ -672,33 +916,56 @@ def _grade_with_swegym_metadata(
)
touched_files = self._extract_paths_from_test_patch(gym_task.test_patch)
+ changed_test_like_files = self._list_changed_test_paths(
+ sandbox,
+ workdir=workdir,
+ )
+ files_to_restore = sorted(set(touched_files) | set(changed_test_like_files))
self._revert_test_files(
sandbox,
base_commit=swe_task.base_commit,
- paths=touched_files,
+ paths=files_to_restore,
strict=True,
+ workdir=workdir,
)
- self._apply_test_patch(sandbox, gym_task.test_patch)
- case_results = self._run_swegym_case_tests(sandbox, gym_task)
- grade = grade_from_case_results(gym_task, case_results)
+ self._apply_test_patch(sandbox, gym_task.test_patch, home=home, workdir=workdir)
+ case_results = self._run_swegym_case_tests(
+ sandbox,
+ gym_task,
+ workdir=workdir,
+ )
+ grade = grade_from_case_results(
+ gym_task,
+ case_results,
+ reward_mode=os.environ.get("SWE_REWARD_MODE", "binary").strip().lower()
+ or "binary",
+ )
# Best-effort cleanup in case grading was interrupted.
self._revert_test_files(
sandbox,
base_commit=swe_task.base_commit,
- paths=touched_files,
+ paths=files_to_restore,
strict=False,
+ workdir=workdir,
)
return float(grade.reward), bool(grade.resolved)
- def _apply_test_patch(self, sandbox: SandboxHandle, test_patch: str) -> None:
- patch_path = f"{HOME}/.openenv_swe_test_patch.diff"
+ def _apply_test_patch(
+ self,
+ sandbox: SandboxHandle,
+ test_patch: str,
+ *,
+ home: str,
+ workdir: str,
+ ) -> None:
+ patch_path = f"{home}/.openenv_swe_test_patch.diff"
sandbox.write_text(patch_path, test_patch)
result = sandbox.exec(
f"git apply --whitespace=nowarn {shlex.quote(patch_path)}",
- cwd=TESTBED,
+ cwd=workdir,
timeout=30,
)
if result.exit_code != 0:
@@ -711,6 +978,8 @@ def _run_swegym_case_tests(
self,
sandbox: SandboxHandle,
gym_task: SWEGymTask,
+ *,
+ workdir: str,
) -> dict[str, bool]:
cases: list[str] = []
seen: set[str] = set()
@@ -723,7 +992,7 @@ def _run_swegym_case_tests(
results: dict[str, bool] = {}
for case in cases:
cmd = f"python -m pytest -q --maxfail=1 {shlex.quote(case)}"
- run = sandbox.exec(cmd, cwd=TESTBED, timeout=self._verify_timeout_s)
+ run = sandbox.exec(cmd, cwd=workdir, timeout=self._verify_timeout_s)
results[case] = run.exit_code == 0
return results
@@ -731,13 +1000,15 @@ def _grade_with_verify_commands(
self,
sandbox: SandboxHandle,
swe_task: SWETask,
+ *,
+ workdir: str,
) -> tuple[float, bool]:
"""Legacy fallback for non-SWE-Gym tasks."""
if not swe_task.verify:
return 0.0, False
passed = 0
for cmd in swe_task.verify:
- r = sandbox.exec(cmd, cwd=TESTBED, timeout=self._verify_timeout_s)
+ r = sandbox.exec(cmd, cwd=workdir, timeout=self._verify_timeout_s)
if r.exit_code == 0:
passed += 1
reward = passed / len(swe_task.verify)
@@ -755,6 +1026,42 @@ def _extract_paths_from_test_patch(test_patch: str) -> list[str]:
paths.append(path)
return sorted(set(paths))
+ @staticmethod
+ def _is_test_like_path(path: str) -> bool:
+ text = (path or "").strip().strip('"')
+ if not text:
+ return False
+ normalized = text.replace("\\", "/")
+ basename = normalized.rsplit("/", 1)[-1]
+ if basename == "conftest.py":
+ return True
+ if basename.startswith("test_") or basename.endswith("_test.py"):
+ return True
+ parts = normalized.split("/")
+ return any(part in {"test", "tests", "testing"} for part in parts[:-1])
+
+ @classmethod
+ def _list_changed_test_paths(
+ cls,
+ sandbox: SandboxHandle,
+ *,
+ workdir: str,
+ ) -> list[str]:
+ commands = (
+ "git diff --name-only HEAD --",
+ "git ls-files --others --exclude-standard",
+ )
+ paths: set[str] = set()
+ for cmd in commands:
+ result = sandbox.exec(cmd, cwd=workdir, timeout=10)
+ if result.exit_code != 0:
+ continue
+ for line in (result.stdout or "").splitlines():
+ path = line.strip()
+ if cls._is_test_like_path(path):
+ paths.add(path)
+ return sorted(paths)
+
def _revert_test_files(
self,
sandbox: SandboxHandle,
@@ -762,6 +1069,7 @@ def _revert_test_files(
base_commit: str,
paths: list[str],
strict: bool,
+ workdir: str,
) -> None:
if not paths:
return
@@ -770,7 +1078,7 @@ def _revert_test_files(
for path in paths:
has_file = sandbox.exec(
f"git cat-file -e {shlex.quote(f'{base_commit}:{path}')}",
- cwd=TESTBED,
+ cwd=workdir,
timeout=10,
)
if has_file.exit_code == 0:
@@ -781,7 +1089,7 @@ def _revert_test_files(
else:
cmd = f"rm -f -- {shlex.quote(path)}"
- result = sandbox.exec(cmd, cwd=TESTBED, timeout=20)
+ result = sandbox.exec(cmd, cwd=workdir, timeout=20)
if result.exit_code != 0:
failures.append(
f"{path}: {(result.stderr or result.stdout or '').strip()}"
diff --git a/envs/mini_swe_env/server/swe_environment.py b/envs/mini_swe_env/server/swe_environment.py
index 916d6f938..c6b21341f 100644
--- a/envs/mini_swe_env/server/swe_environment.py
+++ b/envs/mini_swe_env/server/swe_environment.py
@@ -29,6 +29,7 @@
import json
import logging
+import os
import shlex
import time
from pathlib import Path
@@ -83,6 +84,32 @@
MCP_PORT = 8765
VERIFY_TIMEOUT_S = 300
SETUP_TIMEOUT_S = 600
+DEFAULT_GIT_CHECKOUT_TIMEOUT_S = 300
+
+
+def _git_checkout_timeout_s() -> int:
+ for name in ("SWE_GIT_CHECKOUT_TIMEOUT_S", "SWE_LOCAL_GIT_CHECKOUT_TIMEOUT_S"):
+ value = os.environ.get(name)
+ if value is None:
+ continue
+ try:
+ timeout_s = int(value)
+ except ValueError:
+ _log.warning(
+ "%s must be an integer; using %ss",
+ name,
+ DEFAULT_GIT_CHECKOUT_TIMEOUT_S,
+ )
+ return DEFAULT_GIT_CHECKOUT_TIMEOUT_S
+ if timeout_s > 0:
+ return timeout_s
+ _log.warning(
+ "%s must be positive; using %ss",
+ name,
+ DEFAULT_GIT_CHECKOUT_TIMEOUT_S,
+ )
+ return DEFAULT_GIT_CHECKOUT_TIMEOUT_S
+ return DEFAULT_GIT_CHECKOUT_TIMEOUT_S
# Path to the sandbox_mcp_server.py source alongside this module.
_SANDBOX_MCP_SERVER_SOURCE = Path(__file__).parent / "sandbox_mcp_server.py"
@@ -509,7 +536,14 @@ def _grade_submission(
)
self._apply_test_patch(sandbox, gym_task.test_patch)
case_results = self._run_swegym_case_tests(sandbox, gym_task)
- return grade_from_case_results(gym_task, case_results)
+ return grade_from_case_results(
+ gym_task,
+ case_results,
+ reward_mode=(
+ os.environ.get("SWE_REWARD_MODE", "binary").strip().lower()
+ or "binary"
+ ),
+ )
except Exception as exc:
_log.warning("SWE-Gym grading failed, falling back: %s", exc)
return None
@@ -651,7 +685,7 @@ def _prepare_repo(self, sandbox: Any, task: SWETask) -> None:
r = sandbox.exec(
f"git checkout --quiet {task.base_commit}",
cwd=TESTBED,
- timeout=60,
+ timeout=_git_checkout_timeout_s(),
)
if r.exit_code != 0:
raise RuntimeError(
diff --git a/examples/mini_swe_env/async_grpo/rollout_worker.py b/examples/mini_swe_env/async_grpo/rollout_worker.py
index f46cd4df8..545824b00 100644
--- a/examples/mini_swe_env/async_grpo/rollout_worker.py
+++ b/examples/mini_swe_env/async_grpo/rollout_worker.py
@@ -3,23 +3,21 @@
Implements ``RolloutWorkerProtocol`` from TRL's ``AsyncGRPOTrainer``.
Architecture:
- Pi (sandbox) → InterceptionServer → this worker → vLLM /v1/completions
+ Pi (sandbox) → InterceptionServer → this worker → vLLM /v1/chat/completions
← chat response back to Pi
Pi drives the generation loop inside the sandbox. This worker:
1. Dequeues each intercepted LLM request from Pi.
-2. Tokenizes the messages with ``apply_chat_template``.
-3. Calls vLLM ``/v1/completions`` with ``prompt=token_ids``,
- ``return_token_ids=True``, ``logprobs=0`` — same as TRL's own
- ``AsyncRolloutWorker._generate_one_turn``.
-4. Gets exact ``completion_ids`` and ``completion_logprobs`` from vLLM.
-5. Wraps the completion text as a chat response and delivers it back to Pi.
-6. Tracks multi-turn token sequences matching TRL's pattern:
+2. Forwards the intercepted request to vLLM ``/v1/chat/completions``.
+3. Requests exact ``prompt_token_ids`` / ``completion_ids`` and per-token
+ logprobs from vLLM.
+4. Delivers the OpenAI-compatible chat response back to Pi.
+5. Tracks multi-turn token sequences matching TRL's pattern:
``input_ids = initial_prompt_ids + [turn_ids + suffix_ids]*N``
``completion_mask = [0]*prompt + [1]*turn + [0]*suffix + ...``
-7. On ``answer()``, bridges to host-side grading.
-8. Assembles the final ``RolloutSample`` and pushes to ``rollout_buffer``.
+6. On ``answer()``, bridges to host-side grading.
+7. Assembles the final ``RolloutSample`` and pushes to ``rollout_buffer``.
"""
from __future__ import annotations
@@ -27,12 +25,15 @@
import asyncio
import json
import logging
+import math
+import os
import queue
+import re
import threading
import time
import uuid
from dataclasses import dataclass, field
-from typing import Any, Iterator, Sequence, cast
+from typing import Any, Callable, Iterator, Sequence, cast
import requests
@@ -95,17 +96,36 @@ class RolloutSample:
metrics: dict[str, Any] = field(default_factory=dict)
+@dataclass
+class PendingRollout:
+ """One rollout before group-relative advantage normalization."""
+
+ input_ids: list[int]
+ completion_mask: list[int]
+ old_log_probs: list[float]
+ reward: float
+ model_version: int
+ metrics: dict[str, Any] = field(default_factory=dict)
+
+
# ── Config ─────────────────────────────────────────────────────────────
@dataclass(frozen=True)
class WorkerConfig:
max_inflight: int = 2
+ max_rollout_attempts: int = 4
+ num_generations: int = 4
queue_maxsize: int = 64
request_timeout_s: float = 600.0
max_turns: int = 50
+ max_model_len: int = 4096
max_completion_tokens: int = 2048
temperature: float = 1.0
+ max_tool_message_chars: int = 6000
+ min_tool_message_chars: int = 256
+ max_assistant_message_chars: int = 4000
+ min_assistant_message_chars: int = 256
# After returning a terminal plain-text response (finish_reason=stop,
# no tool_calls), wait briefly for a follow-up request before treating
# the rollout as complete. This avoids 600s stalls when agent exit
@@ -113,6 +133,7 @@ class WorkerConfig:
post_response_grace_s: float = 10.0
stop_on_idle_terminal_response: bool = True
idle_backoff_s: float = 0.5
+ failure_backoff_s: float = 30.0
# ── Worker ─────────────────────────────────────────────────────────────
@@ -148,6 +169,12 @@ def __init__(
self._vllm_api_key = vllm_api_key
self._vllm_model = vllm_model
self._cfg = config or WorkerConfig()
+ if not self._tasks:
+ raise ValueError("SWERolloutWorker requires at least one SWE task")
+ if self._cfg.num_generations < 2:
+ raise ValueError(
+ "WorkerConfig.num_generations must be >= 2 for valid GRPO grouping"
+ )
self.rollout_buffer: queue.Queue[RolloutSample] = queue.Queue(
maxsize=self._cfg.queue_maxsize,
@@ -162,6 +189,12 @@ def __init__(
self._model_version = 0
self._started = False
self._model_update_group: Any | None = None
+ self._next_group_id = 0
+ self._current_group_task: SWETask | None = None
+ self._current_group_id: int | None = None
+ self._current_group_model_version: int | None = None
+ self._group_replica_idx = 0
+ self._pending_groups: dict[int, list[PendingRollout]] = {}
# Prefix to prepend to trainer parameter names so they match vLLM's
# model architecture. For VLM models like Qwen3_5ForConditionalGeneration
@@ -228,6 +261,7 @@ def send_weights(self, iterator: Iterator[tuple[str, Any]]) -> None:
)
return
+ items = [(self._vllm_weight_name(name), tensor) for name, tensor in items]
names = [name for name, _ in items]
# VLM models (e.g. Qwen3_5ForConditionalGeneration) have parameters at
@@ -302,6 +336,16 @@ def update_model_version(self, version: int) -> None:
with self._lock:
self._model_version = version
+ def _vllm_weight_name(self, name: str) -> str:
+ """Map trainer-side text model names to vLLM's served module names."""
+ model_id = self._vllm_model.lower().replace("_", "")
+ if "qwen3.5" in model_id or "qwen35" in model_id:
+ if name.startswith("model."):
+ return f"language_model.{name}"
+ if name.startswith("lm_head."):
+ return f"language_model.{name}"
+ return name
+
def _post_json(
self,
path: str,
@@ -324,6 +368,15 @@ def _post_json(
return response
def _init_weight_transfer(self) -> None:
+ if os.environ.get("SWE_DISABLE_WEIGHT_TRANSFER", "").strip().lower() in {
+ "1",
+ "true",
+ "yes",
+ "on",
+ }:
+ _log.warning("weight sync disabled by SWE_DISABLE_WEIGHT_TRANSFER")
+ return
+
if (
NCCLWeightTransferEngine is None
or NCCLTrainerSendWeightsArgs is None
@@ -403,16 +456,96 @@ def _destroy_model_update_group(self) -> None:
# ── Internal ───────────────────────────────────────────────────
- def _next_task(self) -> SWETask:
+ def _next_rollout_assignment(self) -> tuple[SWETask, int, int, int]:
with self._lock:
- t = self._tasks[self._task_idx % len(self._tasks)]
- self._task_idx += 1
- return t
+ if self._group_replica_idx == 0:
+ self._current_group_task = self._tasks[self._task_idx % len(self._tasks)]
+ self._task_idx += 1
+ self._current_group_id = self._next_group_id
+ self._current_group_model_version = self._model_version
+ self._next_group_id += 1
+
+ assert self._current_group_task is not None
+ assert self._current_group_id is not None
+ assert self._current_group_model_version is not None
+
+ task = self._current_group_task
+ group_id = self._current_group_id
+ model_version = self._current_group_model_version
+ replica_idx = self._group_replica_idx
+
+ self._group_replica_idx += 1
+ if self._group_replica_idx >= self._cfg.num_generations:
+ self._group_replica_idx = 0
+ self._current_group_task = None
+ self._current_group_id = None
+ self._current_group_model_version = None
+
+ return task, group_id, replica_idx, model_version
def _model_ver(self) -> int:
with self._lock:
return self._model_version
+ def _collect_group_sample(
+ self,
+ *,
+ group_id: int,
+ rollout: PendingRollout,
+ ) -> list[RolloutSample] | None:
+ with self._lock:
+ bucket = self._pending_groups.setdefault(group_id, [])
+ bucket.append(rollout)
+ if len(bucket) < self._cfg.num_generations:
+ return None
+ if len(bucket) > self._cfg.num_generations:
+ raise RuntimeError(
+ f"received too many rollouts for group {group_id}: "
+ f"{len(bucket)} > {self._cfg.num_generations}"
+ )
+ group_rollouts = list(bucket)
+ del self._pending_groups[group_id]
+
+ return self._finalize_group_rollouts(
+ group_id=group_id,
+ rollouts=group_rollouts,
+ )
+
+ def _finalize_group_rollouts(
+ self,
+ *,
+ group_id: int,
+ rollouts: Sequence[PendingRollout],
+ ) -> list[RolloutSample]:
+ rewards = [float(rollout.reward) for rollout in rollouts]
+ advantages, reward_mean, reward_std = _compute_group_advantages(rewards)
+ _log.info(
+ "rollout group complete: group_id=%d reward_mean=%.4f reward_std=%.4f rewards=%s",
+ group_id,
+ reward_mean,
+ reward_std,
+ [round(reward, 4) for reward in rewards],
+ )
+
+ samples: list[RolloutSample] = []
+ for rollout, advantage in zip(rollouts, advantages, strict=True):
+ metrics = dict(rollout.metrics)
+ metrics["reward"] = float(rollout.reward)
+ metrics["reward_mean"] = reward_mean
+ metrics["reward_std"] = reward_std
+ metrics["group_size"] = float(len(rollouts))
+ samples.append(
+ RolloutSample(
+ input_ids=rollout.input_ids,
+ completion_mask=rollout.completion_mask,
+ old_log_probs=rollout.old_log_probs,
+ advantage=advantage,
+ model_version=rollout.model_version,
+ metrics=metrics,
+ )
+ )
+ return samples
+
def _loop(self, idx: int) -> None:
while not self._stop.is_set():
while self._pause.is_set() and not self._stop.is_set():
@@ -420,27 +553,138 @@ def _loop(self, idx: int) -> None:
if self._stop.is_set():
return
- task = self._next_task()
- eid = f"swe-{idx}-{uuid.uuid4().hex[:8]}"
- try:
- sample = asyncio.run(self._rollout(task, eid))
- except Exception:
- _log.exception("rollout failed worker=%d id=%s", idx, task.instance_id)
- time.sleep(self._cfg.idle_backoff_s)
- continue
+ task, group_id, replica_idx, model_version = self._next_rollout_assignment()
+ eid = f"swe-{idx}-g{group_id}-r{replica_idx}-{uuid.uuid4().hex[:8]}"
+ rollout_t0 = time.time()
+ failed = False
+ sample: PendingRollout | None = None
+ last_exc: Exception | None = None
+
+ for attempt in range(1, self._cfg.max_rollout_attempts + 1):
+ try:
+ sample = asyncio.run(
+ self._rollout(task, eid, model_version=model_version)
+ )
+ last_exc = None
+ break
+ except Exception as exc:
+ last_exc = exc
+ retriable = (
+ attempt < self._cfg.max_rollout_attempts
+ and _is_retriable_rollout_error(exc)
+ )
+ if retriable:
+ _log.warning(
+ "rollout failed worker=%d id=%s attempt=%d/%d; retrying: %s",
+ idx,
+ task.instance_id,
+ attempt,
+ self._cfg.max_rollout_attempts,
+ exc,
+ )
+ time.sleep(self._cfg.failure_backoff_s * attempt)
+ continue
+
+ _log.exception(
+ "rollout failed worker=%d id=%s attempts=%d/%d",
+ idx,
+ task.instance_id,
+ attempt,
+ self._cfg.max_rollout_attempts,
+ )
+ sample = self._build_failed_rollout(
+ task=task,
+ elapsed_s=time.time() - rollout_t0,
+ exc=exc,
+ model_version=model_version,
+ )
+ failed = True
+ break
+
+ if sample is None:
+ if last_exc is None:
+ time.sleep(self._cfg.idle_backoff_s)
+ continue
+ raise RuntimeError(
+ f"rollout loop exhausted without sample for {task.instance_id}"
+ ) from last_exc
if sample is None:
time.sleep(self._cfg.idle_backoff_s)
continue
- try:
- self.rollout_buffer.put(sample, timeout=2.0)
- except queue.Full:
- _log.warning("queue full, dropping %s", task.instance_id)
+ group_samples = self._collect_group_sample(
+ group_id=group_id,
+ rollout=sample,
+ )
+ if group_samples is not None:
+ for group_sample in group_samples:
+ try:
+ self.rollout_buffer.put(group_sample, timeout=2.0)
+ except queue.Full:
+ _log.warning("queue full, dropping group %d", group_id)
+ break
+
+ if failed:
+ time.sleep(self._cfg.failure_backoff_s)
+
+ def _build_failed_rollout(
+ self,
+ *,
+ task: SWETask,
+ elapsed_s: float,
+ exc: Exception,
+ model_version: int,
+ ) -> PendingRollout:
+ """Return a zero-reward rollout when infrastructure fails before rollout.
+
+ Remote sandbox providers can transiently refuse new jobs. Producing a
+ neutral sample keeps distributed trainer ranks moving together instead
+ of letting nonzero ranks block in Accelerate's dataloader broadcast.
+ """
+ prompt_ids = self._render_prompt_ids(
+ [{"role": "user", "content": task.instruction}],
+ None,
+ )
+ completion_id = (
+ getattr(self._tokenizer, "eos_token_id", None)
+ or getattr(self._tokenizer, "pad_token_id", None)
+ or 0
+ )
+ input_ids = [*prompt_ids, int(completion_id)]
+ completion_mask = [0] * len(prompt_ids) + [1]
+ old_log_probs = [0.0] * len(input_ids)
+ exc_name = type(exc).__name__
+ return PendingRollout(
+ input_ids=input_ids,
+ completion_mask=completion_mask,
+ old_log_probs=old_log_probs,
+ reward=0.0,
+ model_version=model_version,
+ metrics={
+ "reward": 0.0,
+ "turns": 0.0,
+ "answer_called": 0.0,
+ "terminal_idle_stop": 0.0,
+ "wall_s": round(elapsed_s, 3),
+ "n_tokens": float(len(input_ids)),
+ "rollout_error": 1.0,
+ "sandbox_create_error": float(
+ "sandbox" in exc_name.lower()
+ or "sandbox" in str(exc).lower()
+ ),
+ },
+ )
# ── Single rollout ─────────────────────────────────────────────
- async def _rollout(self, task: SWETask, episode_id: str) -> RolloutSample | None:
+ async def _rollout(
+ self,
+ task: SWETask,
+ episode_id: str,
+ *,
+ model_version: int,
+ ) -> PendingRollout | None:
session = self._factory.create(task=task, episode_id=episode_id)
# Accumulate the full token sequence across turns, matching TRL's
@@ -467,9 +711,20 @@ async def _rollout(self, task: SWETask, episode_id: str) -> RolloutSample | None
intercept = pending_intercept
pending_intercept = None
else:
- intercept = await session.next_request(
- timeout_s=self._cfg.request_timeout_s,
- )
+ try:
+ intercept = await session.next_request(
+ timeout_s=self._cfg.request_timeout_s,
+ )
+ except TimeoutError:
+ if turns == 0:
+ raise
+ rollout_stop_reason = "request_idle_timeout"
+ _log.info(
+ "rollout request idle-timeout: instance_id=%s turns=%d",
+ task.instance_id,
+ turns,
+ )
+ break
if intercept is None:
rollout_stop_reason = "agent_exit_detected"
break
@@ -477,7 +732,19 @@ async def _rollout(self, task: SWETask, episode_id: str) -> RolloutSample | None
# ── Tokenize this turn's full prompt ──────────────
messages = _get_messages(intercept)
tools = _get_tools(intercept)
- current_prompt_ids = self._render_prompt_ids(messages, tools)
+ (
+ current_prompt_ids,
+ turn_ids,
+ turn_lps,
+ chat_resp,
+ _finish_reason,
+ ) = self._generate(
+ intercept=intercept,
+ messages=messages,
+ tools=tools,
+ )
+ if not current_prompt_ids:
+ current_prompt_ids = self._render_prompt_ids(messages, tools)
if initial_prompt_ids is None:
# First turn: the entire prompt is non-completion tokens.
@@ -498,11 +765,6 @@ async def _rollout(self, task: SWETask, episode_id: str) -> RolloutSample | None
turns += 1
- # ── Generate via /v1/completions ──────────────────
- turn_ids, turn_lps, text, finish_reason = self._generate(
- current_prompt_ids
- )
-
all_ids.extend(turn_ids)
all_mask.extend([1] * len(turn_ids))
all_lps.extend(turn_lps)
@@ -510,18 +772,6 @@ async def _rollout(self, task: SWETask, episode_id: str) -> RolloutSample | None
# For next turn's suffix computation:
prev_prompt_ids = current_prompt_ids + turn_ids
- # ── Build chat response for Pi ────────────────────
- assistant_message = _parse_assistant_message(
- tokenizer=self._tokenizer,
- completion_ids=turn_ids,
- fallback_text=text,
- )
- chat_resp = _make_chat_response(
- assistant_message,
- self._vllm_model,
- finish_reason=finish_reason,
- )
-
# ── Check for answer tool call ────────────────────
if _has_answer_call(chat_resp):
answer_called = True
@@ -571,12 +821,12 @@ async def _rollout(self, task: SWETask, episode_id: str) -> RolloutSample | None
all_mask = [1]
all_lps = [0.0]
- return RolloutSample(
+ return PendingRollout(
input_ids=all_ids,
completion_mask=all_mask,
old_log_probs=all_lps,
- advantage=reward,
- model_version=self._model_ver(),
+ reward=reward,
+ model_version=model_version,
metrics={
"reward": reward,
"turns": float(turns),
@@ -588,6 +838,9 @@ async def _rollout(self, task: SWETask, episode_id: str) -> RolloutSample | None
"agent_exit_after_terminal_stop",
}
),
+ "request_idle_timeout": float(
+ rollout_stop_reason == "request_idle_timeout"
+ ),
"wall_s": round(time.time() - t0, 3),
"n_tokens": float(len(all_ids)),
},
@@ -595,7 +848,7 @@ async def _rollout(self, task: SWETask, episode_id: str) -> RolloutSample | None
finally:
session.close()
- # ── vLLM call (matches TRL's _generate_one_turn exactly) ──────
+ # ── vLLM call ─────────────────────────────────────────────────
def _render_prompt_ids(
self,
@@ -618,42 +871,143 @@ def _render_prompt_ids(
def _generate(
self,
- prompt_ids: list[int],
- ) -> tuple[list[int], list[float], str, str | None]:
- """POST /v1/completions with token IDs.
-
- Returns: ``(token_ids, token_logprobs, text, finish_reason)``.
- """
- body = {
- "model": self._vllm_model,
- "prompt": prompt_ids,
- "max_tokens": self._cfg.max_completion_tokens,
- "temperature": self._cfg.temperature,
- "n": 1,
- "return_token_ids": True,
- "logprobs": 0,
- }
- resp = requests.post(
- f"{self._vllm_base_url}/v1/completions",
- headers={
- "Authorization": f"Bearer {self._vllm_api_key}",
- "Content-Type": "application/json",
- },
- json=body,
- timeout=self._cfg.request_timeout_s,
+ *,
+ intercept: dict[str, Any],
+ messages: list[dict[str, Any]],
+ tools: list[dict[str, Any]] | None,
+ ) -> tuple[list[int], list[int], list[float], dict[str, Any], str | None]:
+ """POST /v1/chat/completions and return prompt/output tokens."""
+ (
+ messages,
+ rendered_prompt_ids,
+ ) = _fit_messages_to_context_window(
+ messages=messages,
+ tools=tools,
+ render_prompt_ids=self._render_prompt_ids,
+ requested_completion_tokens=self._cfg.max_completion_tokens,
+ max_model_len=self._cfg.max_model_len,
+ max_tool_message_chars=self._cfg.max_tool_message_chars,
+ min_tool_message_chars=self._cfg.min_tool_message_chars,
+ max_assistant_message_chars=self._cfg.max_assistant_message_chars,
+ min_assistant_message_chars=self._cfg.min_assistant_message_chars,
+ )
+ raw_body = intercept.get("body")
+ body = dict(raw_body) if isinstance(raw_body, dict) else {}
+ body["model"] = self._vllm_model
+ body["messages"] = messages
+ if tools is not None:
+ body["tools"] = tools
+ else:
+ body.pop("tools", None)
+ body.pop("stream", None)
+ body.pop("stream_options", None)
+ body.pop("max_tokens", None)
+ body["max_completion_tokens"] = _clamp_max_completion_tokens(
+ prompt_len=len(rendered_prompt_ids),
+ requested=self._cfg.max_completion_tokens,
+ max_model_len=self._cfg.max_model_len,
)
- if resp.status_code != 200:
+ body["temperature"] = self._cfg.temperature
+ body["n"] = 1
+ body["logprobs"] = True
+ body["top_logprobs"] = 0
+ body["return_token_ids"] = True
+
+ if os.environ.get("SWE_LOG_PROMPT_TOKENS", "").strip().lower() in {
+ "1",
+ "true",
+ "yes",
+ "on",
+ }:
+ _log.info(
+ "rollout prompt window: prompt_tokens=%d max_completion_tokens=%d tools=%d",
+ len(rendered_prompt_ids),
+ int(body["max_completion_tokens"]),
+ len(tools or []),
+ )
+
+ chat_template_kwargs = body.get("chat_template_kwargs")
+ if not isinstance(chat_template_kwargs, dict):
+ chat_template_kwargs = {}
+ model_id = self._vllm_model.lower().replace("_", "")
+ if (
+ ("qwen3" in model_id or "deepseek" in model_id)
+ and "enable_thinking" not in chat_template_kwargs
+ ):
+ chat_template_kwargs["enable_thinking"] = False
+ if chat_template_kwargs:
+ body["chat_template_kwargs"] = chat_template_kwargs
+
+ if tools and not body.get("tool_choice"):
+ body["tool_choice"] = "auto"
+
+ for attempt in range(2):
+ resp = requests.post(
+ f"{self._vllm_base_url}/v1/chat/completions",
+ headers={
+ "Authorization": f"Bearer {self._vllm_api_key}",
+ "Content-Type": "application/json",
+ },
+ json=body,
+ timeout=self._cfg.request_timeout_s,
+ )
+ if resp.status_code == 200:
+ break
+
+ retry_budget = _retry_completion_tokens_from_context_error(resp.text)
+ if attempt == 0 and retry_budget is not None:
+ requested = int(body.get("max_completion_tokens") or 0)
+ clamped_tokens, max_model_len, prompt_tokens = retry_budget
+ if clamped_tokens < requested:
+ _log.info(
+ "clamped max_completion_tokens from %d to %d "
+ "(prompt_tokens=%d max_model_len=%d)",
+ requested,
+ clamped_tokens,
+ prompt_tokens,
+ max_model_len,
+ )
+ body["max_completion_tokens"] = clamped_tokens
+ continue
+
raise RuntimeError(f"vllm {resp.status_code}: {resp.text[:400]}")
- choice = resp.json()["choices"][0]
+ payload = resp.json()
+ choice = payload["choices"][0]
+ turn_ids = _coerce_token_ids(choice.get("token_ids"))
+ choice["message"] = _normalize_chat_choice_message(
+ tokenizer=self._tokenizer,
+ choice=choice,
+ completion_ids=turn_ids,
+ )
+ chat_resp = dict(payload)
+ chat_resp["choices"] = [choice]
return (
- choice["token_ids"],
- choice["logprobs"]["token_logprobs"],
- choice.get("text", ""),
+ _extract_prompt_token_ids(payload) or rendered_prompt_ids,
+ turn_ids,
+ _extract_chat_choice_logprobs(choice, expected_len=len(turn_ids)),
+ chat_resp,
choice.get("finish_reason"),
)
+def _compute_group_advantages(
+ rewards: Sequence[float],
+ *,
+ eps: float = 1e-8,
+) -> tuple[list[float], float, float]:
+ """Return z-scored group advantages, mean reward, and reward stddev."""
+ if not rewards:
+ raise ValueError("rewards must not be empty")
+
+ reward_mean = sum(rewards) / len(rewards)
+ reward_var = sum((reward - reward_mean) ** 2 for reward in rewards) / len(rewards)
+ reward_std = math.sqrt(reward_var)
+ denom = reward_std + eps
+ advantages = [float((reward - reward_mean) / denom) for reward in rewards]
+ return advantages, float(reward_mean), float(reward_std)
+
+
# ── Helpers ────────────────────────────────────────────────────────────
@@ -679,6 +1033,286 @@ def _get_tools(intercept: dict[str, Any]) -> list[dict[str, Any]] | None:
return None
+def _coerce_token_ids(raw: Any) -> list[int]:
+ if not isinstance(raw, list):
+ return []
+ token_ids: list[int] = []
+ for token_id in raw:
+ try:
+ token_ids.append(int(token_id))
+ except (TypeError, ValueError):
+ return []
+ return token_ids
+
+
+def _extract_prompt_token_ids(payload: dict[str, Any]) -> list[int]:
+ return _coerce_token_ids(payload.get("prompt_token_ids"))
+
+
+def _extract_chat_choice_logprobs(
+ choice: dict[str, Any],
+ *,
+ expected_len: int,
+) -> list[float]:
+ content = (choice.get("logprobs") or {}).get("content") or []
+ values: list[float] = []
+ if isinstance(content, list):
+ for item in content:
+ if not isinstance(item, dict):
+ values.append(0.0)
+ continue
+ raw = item.get("logprob")
+ values.append(float(raw) if isinstance(raw, (int, float)) else 0.0)
+
+ if len(values) < expected_len:
+ values.extend([0.0] * (expected_len - len(values)))
+ return values[:expected_len]
+
+
+def _clamp_max_completion_tokens(
+ *,
+ prompt_len: int,
+ requested: int,
+ max_model_len: int,
+ safety_margin: int = 16,
+) -> int:
+ """Clamp completion tokens so prompt + generation fit in vLLM context."""
+ available = max_model_len - prompt_len - safety_margin
+ return max(1, min(int(requested), int(available)))
+
+
+_TRUNCATION_MARKER = "\n...[truncated]...\n"
+_OMITTED_TOOL_OUTPUT_MARKER = "[tool output omitted]"
+_OMITTED_ASSISTANT_TEXT_MARKER = "[omitted]"
+
+
+def _truncate_text_middle(
+ text: str,
+ *,
+ max_chars: int,
+ marker: str = _TRUNCATION_MARKER,
+) -> tuple[str, bool]:
+ if max_chars <= 0 or len(text) <= max_chars:
+ return text, False
+
+ usable = max_chars - len(marker)
+ if usable <= 8:
+ return marker[: max(1, max_chars)], True
+
+ head = max(1, int(usable * 0.75))
+ tail = max(1, usable - head)
+ return text[:head].rstrip() + marker + text[-tail:].lstrip(), True
+
+
+def _truncate_messages_for_prompt_budget(
+ messages: Sequence[dict[str, Any]],
+ *,
+ max_tool_message_chars: int,
+ max_assistant_message_chars: int,
+) -> tuple[list[dict[str, Any]], int, int]:
+ truncated_messages: list[dict[str, Any]] = []
+ tool_truncations = 0
+ assistant_truncations = 0
+
+ for message in messages:
+ updated = dict(message)
+ content = updated.get("content")
+ if not isinstance(content, str):
+ truncated_messages.append(updated)
+ continue
+
+ role = str(updated.get("role") or "")
+ if role == "tool":
+ content, changed = _truncate_text_middle(
+ content,
+ max_chars=max_tool_message_chars,
+ )
+ if changed:
+ tool_truncations += 1
+ elif role == "assistant":
+ content, changed = _truncate_text_middle(
+ content,
+ max_chars=max_assistant_message_chars,
+ )
+ if changed:
+ assistant_truncations += 1
+
+ updated["content"] = content
+ truncated_messages.append(updated)
+
+ return truncated_messages, tool_truncations, assistant_truncations
+
+
+def _replace_oldest_message_content(
+ messages: Sequence[dict[str, Any]],
+ *,
+ role: str,
+ replacement: str,
+) -> tuple[list[dict[str, Any]], bool]:
+ updated_messages = [dict(message) for message in messages]
+ for idx, message in enumerate(updated_messages):
+ if message.get("role") != role:
+ continue
+ content = message.get("content")
+ if not isinstance(content, str) or content == replacement:
+ continue
+ if len(replacement) >= len(content):
+ continue
+ message["content"] = replacement
+ updated_messages[idx] = message
+ return updated_messages, True
+ return updated_messages, False
+
+
+def _fit_messages_to_context_window(
+ *,
+ messages: Sequence[dict[str, Any]],
+ tools: list[dict[str, Any]] | None,
+ render_prompt_ids: Callable[
+ [list[dict[str, Any]], list[dict[str, Any]] | None],
+ list[int],
+ ],
+ requested_completion_tokens: int,
+ max_model_len: int,
+ max_tool_message_chars: int,
+ min_tool_message_chars: int,
+ max_assistant_message_chars: int,
+ min_assistant_message_chars: int,
+ safety_margin: int = 16,
+) -> tuple[list[dict[str, Any]], list[int]]:
+ prompt_budget = max(
+ 1,
+ max_model_len - max(1, int(requested_completion_tokens)) - safety_margin,
+ )
+ tool_char_budget = max(1, int(max_tool_message_chars))
+ assistant_char_budget = max(1, int(max_assistant_message_chars))
+ min_tool_chars = max(1, int(min_tool_message_chars))
+ min_assistant_chars = max(1, int(min_assistant_message_chars))
+
+ base_messages = [dict(message) for message in messages]
+ prepared_messages = base_messages
+ prompt_ids = render_prompt_ids(prepared_messages, tools)
+ tool_truncations = 0
+ assistant_truncations = 0
+
+ while True:
+ prepared_messages, tool_truncations, assistant_truncations = (
+ _truncate_messages_for_prompt_budget(
+ base_messages,
+ max_tool_message_chars=tool_char_budget,
+ max_assistant_message_chars=assistant_char_budget,
+ )
+ )
+ prompt_ids = render_prompt_ids(prepared_messages, tools)
+ if len(prompt_ids) <= prompt_budget:
+ break
+ if tool_char_budget > min_tool_chars:
+ tool_char_budget = max(min_tool_chars, tool_char_budget // 2)
+ continue
+ if assistant_char_budget > min_assistant_chars:
+ assistant_char_budget = max(
+ min_assistant_chars,
+ assistant_char_budget // 2,
+ )
+ continue
+ break
+
+ omitted_tool_messages = 0
+ omitted_assistant_messages = 0
+ while len(prompt_ids) > prompt_budget:
+ prepared_messages, changed = _replace_oldest_message_content(
+ prepared_messages,
+ role="tool",
+ replacement=_OMITTED_TOOL_OUTPUT_MARKER,
+ )
+ if changed:
+ omitted_tool_messages += 1
+ prompt_ids = render_prompt_ids(prepared_messages, tools)
+ continue
+
+ prepared_messages, changed = _replace_oldest_message_content(
+ prepared_messages,
+ role="assistant",
+ replacement=_OMITTED_ASSISTANT_TEXT_MARKER,
+ )
+ if not changed:
+ break
+ omitted_assistant_messages += 1
+ prompt_ids = render_prompt_ids(prepared_messages, tools)
+
+ if (
+ tool_truncations
+ or assistant_truncations
+ or omitted_tool_messages
+ or omitted_assistant_messages
+ ):
+ _log.info(
+ "trimmed intercepted prompt: prompt_tokens=%d/%d tool_truncations=%d "
+ "assistant_truncations=%d omitted_tool_messages=%d "
+ "omitted_assistant_messages=%d",
+ len(prompt_ids),
+ prompt_budget,
+ tool_truncations,
+ assistant_truncations,
+ omitted_tool_messages,
+ omitted_assistant_messages,
+ )
+
+ return prepared_messages, prompt_ids
+
+
+_CONTEXT_LIMIT_RE = re.compile(
+ r"maximum context length is (?P[\d,]+) tokens.*?"
+ r"requested (?P[\d,]+) output tokens and your prompt contains "
+ r"at least (?P[\d,]+) input tokens",
+ flags=re.IGNORECASE | re.DOTALL,
+)
+
+
+def _retry_completion_tokens_from_context_error(
+ error_text: str,
+ *,
+ safety_margin: int = 16,
+) -> tuple[int, int, int] | None:
+ """Return a smaller completion budget after a vLLM context-window error."""
+ match = _CONTEXT_LIMIT_RE.search(error_text or "")
+ if match is None:
+ return None
+
+ max_model_len = int(match.group("max_model_len").replace(",", ""))
+ prompt_tokens = int(match.group("prompt_tokens").replace(",", ""))
+ clamped_tokens = max(1, max_model_len - prompt_tokens - safety_margin)
+ return clamped_tokens, max_model_len, prompt_tokens
+
+
+def _is_context_window_error(exc: Exception) -> bool:
+ text = str(exc or "").lower()
+ return (
+ "maximum context length" in text
+ or "parameter=input_tokens" in text
+ or "input tokens" in text
+ and "output tokens" in text
+ )
+
+
+def _is_retriable_rollout_error(exc: Exception) -> bool:
+ if _is_context_window_error(exc):
+ return False
+
+ text = str(exc or "").lower()
+ retriable_markers = (
+ "429",
+ "too many requests",
+ "timed out",
+ "timeout",
+ "connection reset",
+ "temporarily unavailable",
+ "tunnel failed",
+ "sandbox",
+ )
+ return any(marker in text for marker in retriable_markers)
+
+
def _normalize_tool_calls(raw_tool_calls: Any) -> list[dict[str, Any]]:
normalized: list[dict[str, Any]] = []
if not isinstance(raw_tool_calls, list):
@@ -714,6 +1348,80 @@ def _normalize_tool_calls(raw_tool_calls: Any) -> list[dict[str, Any]]:
return normalized
+_TOOL_CALL_XML_RE = re.compile(
+ r"\s*(.*?)\s*",
+ flags=re.DOTALL,
+)
+
+
+def _extract_xml_tool_calls(text: str) -> tuple[str, list[dict[str, Any]]]:
+ """Parse Qwen-style XML tool-call blocks from assistant text.
+
+ vLLM's Qwen3 XML parser can return tool calls only inside
+ ``message.content``:
+
+
+ {"name": "answer", "arguments": {}}
+
+
+ When that happens we need to recover structured ``tool_calls`` before
+ sending the response back to Pi, otherwise the harness treats the reply as
+ plain terminal text and the rollout dies after one turn.
+ """
+ if not text:
+ return "", []
+
+ tool_calls: list[dict[str, Any]] = []
+ cursor = 0
+ content_parts: list[str] = []
+
+ for match in _TOOL_CALL_XML_RE.finditer(text):
+ start, end = match.span()
+ if start > cursor:
+ content_parts.append(text[cursor:start])
+
+ raw_block = match.group(1).strip()
+ parsed_block: Any
+ try:
+ parsed_block = json.loads(raw_block)
+ except json.JSONDecodeError:
+ content_parts.append(text[start:end])
+ cursor = end
+ continue
+
+ blocks = parsed_block if isinstance(parsed_block, list) else [parsed_block]
+ parsed_any = False
+ for block in blocks:
+ if not isinstance(block, dict):
+ continue
+ name = block.get("name")
+ if not isinstance(name, str) or not name:
+ continue
+ parsed_any = True
+ tool_calls.append(
+ {
+ "id": f"call_{uuid.uuid4().hex[:8]}",
+ "type": "function",
+ "function": {
+ "name": name,
+ "arguments": json.dumps(
+ block.get("arguments") or {},
+ ensure_ascii=False,
+ ),
+ },
+ }
+ )
+
+ if not parsed_any:
+ content_parts.append(text[start:end])
+ cursor = end
+
+ if cursor < len(text):
+ content_parts.append(text[cursor:])
+
+ return "".join(content_parts).strip(), tool_calls
+
+
def _parse_assistant_message(
*,
tokenizer: Any,
@@ -748,6 +1456,44 @@ def _parse_assistant_message(
return message
+def _normalize_chat_choice_message(
+ *,
+ tokenizer: Any,
+ choice: dict[str, Any],
+ completion_ids: list[int],
+) -> dict[str, Any]:
+ raw_message = choice.get("message")
+ if isinstance(raw_message, dict):
+ message = dict(raw_message)
+ else:
+ message = {"role": "assistant", "content": ""}
+
+ fallback_text = message.get("content")
+ if not isinstance(fallback_text, str):
+ fallback_text = ""
+
+ parsed = _parse_assistant_message(
+ tokenizer=tokenizer,
+ completion_ids=completion_ids,
+ fallback_text=fallback_text,
+ )
+
+ if not isinstance(message.get("role"), str):
+ message["role"] = "assistant"
+ if not isinstance(message.get("content"), str):
+ message["content"] = parsed.get("content", "")
+ if not (message.get("tool_calls") or []):
+ tool_calls = parsed.get("tool_calls") or []
+ if not tool_calls and fallback_text:
+ text_content, xml_tool_calls = _extract_xml_tool_calls(fallback_text)
+ if xml_tool_calls:
+ tool_calls = xml_tool_calls
+ message["content"] = text_content
+ if tool_calls:
+ message["tool_calls"] = tool_calls
+ return message
+
+
def _make_chat_response(
assistant_message: dict[str, Any],
model: str,
diff --git a/examples/mini_swe_env/async_grpo/sbatch_multinode_async_grpo.sh b/examples/mini_swe_env/async_grpo/sbatch_multinode_async_grpo.sh
new file mode 100755
index 000000000..bc8027be8
--- /dev/null
+++ b/examples/mini_swe_env/async_grpo/sbatch_multinode_async_grpo.sh
@@ -0,0 +1,439 @@
+#!/usr/bin/env bash
+#SBATCH --job-name=mini-swe-async-grpo
+#SBATCH --partition=hopper-prod
+#SBATCH --nodes=2
+#SBATCH --ntasks-per-node=1
+#SBATCH --gres=gpu:h100:1
+#SBATCH --cpus-per-task=16
+#SBATCH --time=08:00:00
+
+set -euo pipefail
+
+SCRIPT_DIR=$(cd -- "$(dirname "${BASH_SOURCE[0]}")" && pwd)
+DEFAULT_REPO_ROOT=${SLURM_SUBMIT_DIR:-}
+if [[ -z "$DEFAULT_REPO_ROOT" ]]; then
+ DEFAULT_REPO_ROOT=$(cd -- "$SCRIPT_DIR/../../.." && pwd)
+fi
+REPO_ROOT=${REPO_ROOT:-$DEFAULT_REPO_ROOT}
+RUNS_ROOT=${RUNS_ROOT:-$REPO_ROOT/runs/mini_swe_async_grpo}
+RUN_DIR=${RUN_DIR:-$RUNS_ROOT/${SLURM_JOB_ID:-manual}}
+CLOUDFLARED=${CLOUDFLARED:-/fsx/benjamin_burtenshaw/bin/cloudflared}
+
+GPUS_PER_NODE=${GPUS_PER_NODE:-1}
+CPUS_PER_TASK=${CPUS_PER_TASK:-${SLURM_CPUS_PER_TASK:-16}}
+JOB_PORT_BASE=${JOB_PORT_BASE:-$((20000 + (${SLURM_JOB_ID:-0} % 10000)))}
+VLLM_PORT=${VLLM_PORT:-$JOB_PORT_BASE}
+INTERCEPTION_PORT=${INTERCEPTION_PORT:-$((JOB_PORT_BASE + 1000))}
+MASTER_PORT=${MASTER_PORT:-$((JOB_PORT_BASE + 2000))}
+
+SWE_MODEL=${SWE_MODEL:-Qwen/Qwen3-1.7B}
+SWE_SANDBOX_BACKEND=${SWE_SANDBOX_BACKEND:-hf}
+SWE_AGENT=${SWE_AGENT:-pi}
+MAX_MODEL_LEN=${MAX_MODEL_LEN:-}
+GPU_MEMORY_UTILIZATION=${GPU_MEMORY_UTILIZATION:-0.80}
+MAX_TASKS=${MAX_TASKS:-4}
+MAX_STEPS=${MAX_STEPS:-3}
+MAX_TURNS=${MAX_TURNS:-30}
+SWE_ROLLOUT_MAX_INFLIGHT=${SWE_ROLLOUT_MAX_INFLIGHT:-}
+SWE_TRAIN_DTYPE=${SWE_TRAIN_DTYPE:-bf16}
+SWE_TORCH_EMPTY_CACHE_STEPS=${SWE_TORCH_EMPTY_CACHE_STEPS:-1}
+SWE_LORA=${SWE_LORA:-0}
+SWE_LORA_R=${SWE_LORA_R:-16}
+SWE_LORA_ALPHA=${SWE_LORA_ALPHA:-}
+SWE_LORA_DROPOUT=${SWE_LORA_DROPOUT:-0.0}
+SWE_LORA_TARGET_MODULES=${SWE_LORA_TARGET_MODULES:-}
+SWE_LORA_BIAS=${SWE_LORA_BIAS:-none}
+SWE_LORA_USE_RSLORA=${SWE_LORA_USE_RSLORA:-0}
+
+if [[ -z "$MAX_MODEL_LEN" ]]; then
+ MAX_MODEL_LEN=$(
+ SWE_MODEL="$SWE_MODEL" "$REPO_ROOT/.venv/bin/python" - <<'PY' 2>/dev/null || true
+from transformers import AutoConfig
+import os
+
+cfg = AutoConfig.from_pretrained(os.environ["SWE_MODEL"])
+for section in (cfg, getattr(cfg, "text_config", None), getattr(cfg, "llm_config", None)):
+ if section is None:
+ continue
+ for key in ("max_position_embeddings", "model_max_length", "max_seq_len", "seq_length"):
+ value = section.get(key) if isinstance(section, dict) else getattr(section, key, None)
+ if isinstance(value, int) and value > 0:
+ print(value)
+ raise SystemExit(0)
+PY
+ )
+fi
+MAX_MODEL_LEN=${MAX_MODEL_LEN:-16384}
+
+if [[ -z "$SWE_ROLLOUT_MAX_INFLIGHT" ]]; then
+ if [[ "$SWE_SANDBOX_BACKEND" == "hf" ]]; then
+ SWE_ROLLOUT_MAX_INFLIGHT=1
+ else
+ SWE_ROLLOUT_MAX_INFLIGHT=2
+ fi
+fi
+
+MODEL_LOWER=$(printf '%s' "$SWE_MODEL" | tr '[:upper:]' '[:lower:]')
+VLLM_TOOL_CALL_PARSER=${VLLM_TOOL_CALL_PARSER:-}
+if [[ -z "$VLLM_TOOL_CALL_PARSER" ]]; then
+ case "$MODEL_LOWER" in
+ *qwen3*coder*)
+ VLLM_TOOL_CALL_PARSER=qwen3_coder
+ ;;
+ *qwen3*)
+ VLLM_TOOL_CALL_PARSER=qwen3_xml
+ ;;
+ esac
+fi
+
+SWE_CUDA_HOME=${SWE_CUDA_HOME:-}
+if [[ -z "$SWE_CUDA_HOME" ]]; then
+ case "$MODEL_LOWER" in
+ *qwen3.5*|*qwen3_5*)
+ if [[ -d /usr/local/cuda-12.4 ]]; then
+ SWE_CUDA_HOME=/usr/local/cuda-12.4
+ fi
+ ;;
+ esac
+fi
+
+SWE_VLLM_GDN_PREFILL_BACKEND=${SWE_VLLM_GDN_PREFILL_BACKEND:-}
+if [[ -z "$SWE_VLLM_GDN_PREFILL_BACKEND" ]]; then
+ case "$MODEL_LOWER" in
+ *qwen3.5*|*qwen3_5*)
+ SWE_VLLM_GDN_PREFILL_BACKEND=triton
+ ;;
+ esac
+fi
+SWE_VLLM_EXTRA_ARGS=${SWE_VLLM_EXTRA_ARGS:-}
+
+mkdir -p "$RUN_DIR/home"
+cd "$REPO_ROOT"
+
+if [[ -z "${HF_TOKEN:-}" ]]; then
+ if [[ -s /admin/home/benjamin_burtenshaw/.cache/huggingface/token ]]; then
+ HF_TOKEN="$(< /admin/home/benjamin_burtenshaw/.cache/huggingface/token)"
+ elif [[ -s /fsx/benjamin_burtenshaw/.cache/huggingface/token ]]; then
+ HF_TOKEN="$(< /fsx/benjamin_burtenshaw/.cache/huggingface/token)"
+ fi
+fi
+if [[ "$SWE_SANDBOX_BACKEND" == "hf" && -z "${HF_TOKEN:-}" ]]; then
+ echo "ERROR: HF_TOKEN is required for HF sandbox creation" >&2
+ exit 2
+fi
+
+INTERCEPTION_AUTH_TOKEN=${INTERCEPTION_AUTH_TOKEN:-$(python3 - <<'PY'
+import secrets
+print(secrets.token_urlsafe(32))
+PY
+)}
+
+mapfile -t ALLOC_NODES < <(scontrol show hostnames "$SLURM_JOB_NODELIST")
+if (( ${#ALLOC_NODES[@]} < 2 )); then
+ echo "ERROR: expected at least 2 allocated nodes, got ${#ALLOC_NODES[@]}" >&2
+ exit 3
+fi
+
+VLLM_NODE=${ALLOC_NODES[0]}
+TRAINER_NODES=("${ALLOC_NODES[@]:1}")
+TRAINER_MASTER=${TRAINER_NODES[0]}
+TRAINER_NODE_COUNT=${#TRAINER_NODES[@]}
+TRAINER_NODELIST=$(IFS=,; echo "${TRAINER_NODES[*]}")
+TOTAL_TRAINER_PROCS=$((TRAINER_NODE_COUNT * GPUS_PER_NODE))
+VLLM_URL="http://${VLLM_NODE}:${VLLM_PORT}"
+GPU_IDS=$(seq -s, 0 $((GPUS_PER_NODE - 1)))
+
+export RUN_DIR REPO_ROOT GPUS_PER_NODE VLLM_PORT MASTER_PORT
+export MAX_MODEL_LEN GPU_MEMORY_UTILIZATION MAX_TASKS MAX_STEPS MAX_TURNS
+export TRAINER_NODE_COUNT TOTAL_TRAINER_PROCS TRAINER_MASTER VLLM_URL GPU_IDS
+export PYTHONPATH="$REPO_ROOT/src:$REPO_ROOT/envs"
+export PYTHONUNBUFFERED=1
+export TRL_EXPERIMENTAL_SILENCE=1
+export HF_HOME=/fsx/benjamin_burtenshaw/.cache/huggingface
+export HF_HUB_CACHE=/fsx/benjamin_burtenshaw/.cache/huggingface/hub
+export HF_DATASETS_CACHE=/fsx/benjamin_burtenshaw/.cache/huggingface/datasets
+export HF_HUB_ENABLE_HF_TRANSFER=1
+export VLLM_CACHE_ROOT=/fsx/benjamin_burtenshaw/.cache/vllm
+export XDG_CACHE_HOME=/fsx/benjamin_burtenshaw/.cache
+export TORCH_HOME=/fsx/benjamin_burtenshaw/.cache/torch
+export TRITON_CACHE_DIR=/fsx/benjamin_burtenshaw/.cache/triton
+export FLASHINFER_CACHE_DIR=/fsx/benjamin_burtenshaw/.cache/flashinfer
+export FLASHINFER_WORKSPACE_BASE=${FLASHINFER_WORKSPACE_BASE:-/tmp/${USER}/flashinfer-workspace-${SLURM_JOB_ID:-manual}}
+export PYTHONDONTWRITEBYTECODE=${PYTHONDONTWRITEBYTECODE:-1}
+export PYTHONPYCACHEPREFIX=${PYTHONPYCACHEPREFIX:-/tmp/${USER}/openenv-pycache}
+export VLLM_NO_USAGE_STATS=1
+export HF_TOKEN
+export HOME="$RUN_DIR/home"
+export SWE_MODEL
+export SWE_SANDBOX_BACKEND
+export SWE_AGENT
+export VLLM_API_KEY=${VLLM_API_KEY:-token}
+export INTERCEPTION_HOST=0.0.0.0
+export INTERCEPTION_PORT
+export INTERCEPTION_AUTH_TOKEN
+export SWE_TRACKIO_SPACE_ID=${SWE_TRACKIO_SPACE_ID:-${TRACKIO_SPACE_ID:-burtenshaw/swe-grpo-dashboard}}
+export SWE_TRACKIO_PROJECT=${SWE_TRACKIO_PROJECT:-swe-async-grpo}
+export SWE_CHECKPOINT_TO_HUB=${SWE_CHECKPOINT_TO_HUB:-0}
+export SWE_HUB_MODEL_ID=${SWE_HUB_MODEL_ID:-}
+export SWE_HUB_PRIVATE_REPO=${SWE_HUB_PRIVATE_REPO:-1}
+export SWE_CHECKPOINT_SAVE_STEPS=${SWE_CHECKPOINT_SAVE_STEPS:-2}
+export SWE_CHECKPOINT_SAVE_TOTAL_LIMIT=${SWE_CHECKPOINT_SAVE_TOTAL_LIMIT:-2}
+export SWE_RESUME_FROM_CHECKPOINT=${SWE_RESUME_FROM_CHECKPOINT:-none}
+export SWE_IGNORE_DATA_SKIP=${SWE_IGNORE_DATA_SKIP:-1}
+export SWE_ASYNC_MAX_STALENESS=${SWE_ASYNC_MAX_STALENESS:-16}
+export SWE_ASYNC_WEIGHT_SYNC_STEPS=${SWE_ASYNC_WEIGHT_SYNC_STEPS:-1}
+export SWE_ASYNC_MAX_INFLIGHT_TASKS=${SWE_ASYNC_MAX_INFLIGHT_TASKS:-$SWE_ROLLOUT_MAX_INFLIGHT}
+export SWE_ASYNC_QUEUE_MAXSIZE=${SWE_ASYNC_QUEUE_MAXSIZE:-64}
+export SWE_ROLLOUT_MAX_INFLIGHT
+export SWE_VLLM_MAX_MODEL_LEN=${SWE_VLLM_MAX_MODEL_LEN:-$MAX_MODEL_LEN}
+export SWE_TRAIN_DTYPE SWE_TORCH_EMPTY_CACHE_STEPS
+export SWE_LORA SWE_LORA_R SWE_LORA_ALPHA SWE_LORA_DROPOUT
+export SWE_LORA_TARGET_MODULES SWE_LORA_BIAS SWE_LORA_USE_RSLORA
+export SWE_HF_SANDBOX_CREATE_RETRIES=${SWE_HF_SANDBOX_CREATE_RETRIES:-6}
+export SWE_HF_SANDBOX_CREATE_BACKOFF_S=${SWE_HF_SANDBOX_CREATE_BACKOFF_S:-20}
+export OPENENV_HF_SANDBOX_URL_TIMEOUT_S=${OPENENV_HF_SANDBOX_URL_TIMEOUT_S:-120}
+export SWE_ROLLOUT_QUEUE_TIMEOUT_S=${SWE_ROLLOUT_QUEUE_TIMEOUT_S:-900}
+export SWE_ROLLOUT_REQUEST_TIMEOUT_S=${SWE_ROLLOUT_REQUEST_TIMEOUT_S:-120}
+export SWE_ROLLOUT_FAILURE_BACKOFF_S=${SWE_ROLLOUT_FAILURE_BACKOFF_S:-30}
+export SWE_ROLLOUT_MAX_ATTEMPTS=${SWE_ROLLOUT_MAX_ATTEMPTS:-4}
+export SWE_GIT_CHECKOUT_TIMEOUT_S=${SWE_GIT_CHECKOUT_TIMEOUT_S:-300}
+export SWE_DISABLE_WEIGHT_TRANSFER=${SWE_DISABLE_WEIGHT_TRANSFER:-0}
+export VLLM_TOOL_CALL_PARSER
+export SWE_CUDA_HOME SWE_VLLM_GDN_PREFILL_BACKEND SWE_VLLM_EXTRA_ARGS
+export NCCL_DEBUG=${NCCL_DEBUG:-WARN}
+export TORCH_NCCL_ASYNC_ERROR_HANDLING=${TORCH_NCCL_ASYNC_ERROR_HANDLING:-1}
+export PYTORCH_CUDA_ALLOC_CONF=${PYTORCH_CUDA_ALLOC_CONF:-expandable_segments:True}
+export OMP_NUM_THREADS=${OMP_NUM_THREADS:-4}
+export OPENENV_LOCAL_SANDBOX_ROOT=${OPENENV_LOCAL_SANDBOX_ROOT:-/tmp/${USER}/openenv-local-sandboxes-${SLURM_JOB_ID:-manual}}
+export OPENENV_LOCAL_SANDBOX_PRESERVE=${OPENENV_LOCAL_SANDBOX_PRESERVE:-0}
+if [[ -n "$SWE_CUDA_HOME" ]]; then
+ export CUDA_HOME="$SWE_CUDA_HOME"
+ export CUDA_PATH="$SWE_CUDA_HOME"
+ export PATH="$SWE_CUDA_HOME/bin:$PATH"
+ export LD_LIBRARY_PATH="$SWE_CUDA_HOME/lib64:${LD_LIBRARY_PATH:-}"
+fi
+mkdir -p "$OPENENV_LOCAL_SANDBOX_ROOT" "$FLASHINFER_WORKSPACE_BASE"
+
+VLLM_STEP_PID=
+CLOUDFLARED_PID=
+cleanup() {
+ local rc=$?
+ if [[ -n "${CLOUDFLARED_PID:-}" ]]; then
+ kill "$CLOUDFLARED_PID" >/dev/null 2>&1 || true
+ fi
+ if [[ -n "${VLLM_STEP_PID:-}" ]]; then
+ kill "$VLLM_STEP_PID" >/dev/null 2>&1 || true
+ wait "$VLLM_STEP_PID" >/dev/null 2>&1 || true
+ fi
+ exit "$rc"
+}
+trap cleanup EXIT
+
+echo "job_start=$(date -Is)"
+echo "run_dir=$RUN_DIR"
+echo "nodes=${ALLOC_NODES[*]}"
+echo "vllm_node=$VLLM_NODE trainer_nodes=$TRAINER_NODELIST trainer_master=$TRAINER_MASTER"
+echo "ports=vllm:$VLLM_PORT interception:$INTERCEPTION_PORT master:$MASTER_PORT"
+echo "model=$SWE_MODEL agent=$SWE_AGENT parser=${VLLM_TOOL_CALL_PARSER:-none} sandbox=$SWE_SANDBOX_BACKEND dtype=$SWE_TRAIN_DTYPE lora=$SWE_LORA max_model_len=$MAX_MODEL_LEN max_tasks=$MAX_TASKS max_steps=$MAX_STEPS max_turns=$MAX_TURNS inflight=$SWE_ROLLOUT_MAX_INFLIGHT staleness=$SWE_ASYNC_MAX_STALENESS preserve_local=$OPENENV_LOCAL_SANDBOX_PRESERVE cuda_home=${CUDA_HOME:-none} gdn_prefill_backend=${SWE_VLLM_GDN_PREFILL_BACKEND:-default}"
+
+: > "$RUN_DIR/vllm.log"
+: > "$RUN_DIR/trainer.log"
+: > "$RUN_DIR/cloudflared.log"
+
+echo "starting_vllm=$(date -Is)"
+srun \
+ --nodes=1 \
+ --ntasks=1 \
+ --nodelist="$VLLM_NODE" \
+ --cpus-per-task="$CPUS_PER_TASK" \
+ --gres="gpu:h100:${GPUS_PER_NODE}" \
+ --kill-on-bad-exit=1 \
+ bash -lc '
+ set -euo pipefail
+ cd "$REPO_ROOT"
+ mkdir -p "$PYTHONPYCACHEPREFIX"
+ export CUDA_VISIBLE_DEVICES="$GPU_IDS"
+ export VLLM_SERVER_DEV_MODE=1
+ TOOL_ARGS=()
+ if [[ -n "${VLLM_TOOL_CALL_PARSER:-}" ]]; then
+ TOOL_ARGS+=(--enable-auto-tool-choice --tool-call-parser "$VLLM_TOOL_CALL_PARSER")
+ fi
+ GDN_ARGS=()
+ if [[ -n "${SWE_VLLM_GDN_PREFILL_BACKEND:-}" ]]; then
+ GDN_ARGS+=(--gdn-prefill-backend "$SWE_VLLM_GDN_PREFILL_BACKEND")
+ fi
+ EXTRA_ARGS=()
+ if [[ -n "${SWE_VLLM_EXTRA_ARGS:-}" ]]; then
+ read -r -a EXTRA_ARGS <<< "$SWE_VLLM_EXTRA_ARGS"
+ fi
+ exec .venv/bin/vllm serve "$SWE_MODEL" \
+ --tensor-parallel-size "$GPUS_PER_NODE" \
+ --max-model-len "$MAX_MODEL_LEN" \
+ --host 0.0.0.0 \
+ --port "$VLLM_PORT" \
+ --api-key "$VLLM_API_KEY" \
+ --gpu-memory-utilization "$GPU_MEMORY_UTILIZATION" \
+ --logprobs-mode processed_logprobs \
+ --weight-transfer-config '\''{"backend":"nccl"}'\'' \
+ "${TOOL_ARGS[@]}" \
+ "${GDN_ARGS[@]}" \
+ "${EXTRA_ARGS[@]}"
+ ' > "$RUN_DIR/vllm.log" 2>&1 &
+VLLM_STEP_PID=$!
+
+VLLM_BIND_MARKER="http://0.0.0.0:${VLLM_PORT}"
+VLLM_READY_MARKER="Application startup complete."
+for _ in $(seq 1 900); do
+ if grep -Fq "Address already in use" "$RUN_DIR/vllm.log"; then
+ echo "ERROR: vLLM failed to bind on ${VLLM_PORT}" >&2
+ tail -100 "$RUN_DIR/vllm.log" >&2 || true
+ exit 4
+ fi
+ if ! kill -0 "$VLLM_STEP_PID" >/dev/null 2>&1; then
+ echo "ERROR: vLLM srun exited before readiness" >&2
+ tail -100 "$RUN_DIR/vllm.log" >&2 || true
+ exit 4
+ fi
+ if grep -Fq "$VLLM_BIND_MARKER" "$RUN_DIR/vllm.log" \
+ && grep -Fq "$VLLM_READY_MARKER" "$RUN_DIR/vllm.log" \
+ && curl -fsS "$VLLM_URL/health" >/dev/null 2>&1 \
+ && curl -fsS "$VLLM_URL/get_world_size" >/dev/null 2>&1; then
+ echo "vllm_ready=$(date -Is)"
+ break
+ fi
+ sleep 2
+done
+if ! grep -Fq "$VLLM_BIND_MARKER" "$RUN_DIR/vllm.log" \
+ || ! grep -Fq "$VLLM_READY_MARKER" "$RUN_DIR/vllm.log" \
+ || ! curl -fsS "$VLLM_URL/health" >/dev/null 2>&1 \
+ || ! curl -fsS "$VLLM_URL/get_world_size" >/dev/null 2>&1; then
+ echo "ERROR: vLLM did not become ready on ${VLLM_URL}" >&2
+ tail -100 "$RUN_DIR/vllm.log" >&2 || true
+ exit 5
+fi
+
+INTERCEPTION_BASE_URL=
+if [[ "$SWE_SANDBOX_BACKEND" == "docker" ]]; then
+ INTERCEPTION_BASE_URL="http://host.docker.internal:${INTERCEPTION_PORT}"
+elif [[ "$SWE_SANDBOX_BACKEND" == "local" ]]; then
+ INTERCEPTION_BASE_URL="http://127.0.0.1:${INTERCEPTION_PORT}"
+else
+ echo "starting_cloudflared=$(date -Is)"
+ "$CLOUDFLARED" tunnel \
+ --url "http://${TRAINER_MASTER}:${INTERCEPTION_PORT}" \
+ --no-autoupdate \
+ --protocol quic \
+ --ha-connections 1 \
+ > "$RUN_DIR/cloudflared.log" 2>&1 &
+ CLOUDFLARED_PID=$!
+ echo "$CLOUDFLARED_PID" > "$RUN_DIR/cloudflared.pid"
+
+ for _ in $(seq 1 120); do
+ INTERCEPTION_BASE_URL=$(grep -Eo 'https://[A-Za-z0-9.-]+\.trycloudflare\.com' "$RUN_DIR/cloudflared.log" | tail -n 1 || true)
+ if [[ -n "$INTERCEPTION_BASE_URL" ]]; then
+ break
+ fi
+ if ! kill -0 "$CLOUDFLARED_PID" >/dev/null 2>&1; then
+ echo "ERROR: cloudflared exited before URL creation" >&2
+ tail -100 "$RUN_DIR/cloudflared.log" >&2 || true
+ exit 6
+ fi
+ sleep 1
+ done
+ if [[ -z "$INTERCEPTION_BASE_URL" ]]; then
+ echo "ERROR: cloudflared did not publish a tunnel URL" >&2
+ tail -100 "$RUN_DIR/cloudflared.log" >&2 || true
+ exit 7
+ fi
+fi
+export INTERCEPTION_BASE_URL
+echo "$INTERCEPTION_BASE_URL" > "$RUN_DIR/interception_base_url.txt"
+
+{
+ echo "SLURM_JOB_ID=${SLURM_JOB_ID:-}"
+ echo "RUN_DIR=$RUN_DIR"
+ echo "SWE_MODEL=$SWE_MODEL"
+ echo "SWE_SANDBOX_BACKEND=$SWE_SANDBOX_BACKEND"
+ echo "SWE_AGENT=$SWE_AGENT"
+ echo "VLLM_NODE=$VLLM_NODE"
+ echo "VLLM_URL=$VLLM_URL"
+ echo "VLLM_PORT=$VLLM_PORT"
+ echo "TRAINER_NODELIST=$TRAINER_NODELIST"
+ echo "TRAINER_MASTER=$TRAINER_MASTER"
+ echo "TOTAL_TRAINER_PROCS=$TOTAL_TRAINER_PROCS"
+ echo "INTERCEPTION_BASE_URL=$INTERCEPTION_BASE_URL"
+ echo "INTERCEPTION_PORT=$INTERCEPTION_PORT"
+ echo "MASTER_PORT=$MASTER_PORT"
+ echo "SWE_TRACKIO_SPACE_ID=$SWE_TRACKIO_SPACE_ID"
+ echo "SWE_TRACKIO_PROJECT=$SWE_TRACKIO_PROJECT"
+ echo "SWE_CHECKPOINT_TO_HUB=$SWE_CHECKPOINT_TO_HUB"
+ echo "SWE_HUB_MODEL_ID=$SWE_HUB_MODEL_ID"
+ echo "SWE_HUB_PRIVATE_REPO=$SWE_HUB_PRIVATE_REPO"
+ echo "SWE_CHECKPOINT_SAVE_STEPS=$SWE_CHECKPOINT_SAVE_STEPS"
+ echo "SWE_CHECKPOINT_SAVE_TOTAL_LIMIT=$SWE_CHECKPOINT_SAVE_TOTAL_LIMIT"
+ echo "SWE_RESUME_FROM_CHECKPOINT=$SWE_RESUME_FROM_CHECKPOINT"
+ echo "SWE_IGNORE_DATA_SKIP=$SWE_IGNORE_DATA_SKIP"
+ echo "MAX_TASKS=$MAX_TASKS"
+ echo "MAX_STEPS=$MAX_STEPS"
+ echo "MAX_TURNS=$MAX_TURNS"
+ echo "SWE_TASK_INDICES=${SWE_TASK_INDICES:-}"
+ echo "SWE_REPEAT_TASKS=${SWE_REPEAT_TASKS:-}"
+ echo "SWE_NUM_GENERATIONS=${SWE_NUM_GENERATIONS:-}"
+ echo "SWE_TEMPERATURE=${SWE_TEMPERATURE:-}"
+ echo "SWE_LEARNING_RATE=${SWE_LEARNING_RATE:-}"
+ echo "SWE_REWARD_MODE=${SWE_REWARD_MODE:-}"
+ echo "SWE_ENABLE_ANSWER_TOOL=${SWE_ENABLE_ANSWER_TOOL:-}"
+ echo "SWE_ROLLOUT_MAX_INFLIGHT=$SWE_ROLLOUT_MAX_INFLIGHT"
+ echo "SWE_ROLLOUT_MAX_ATTEMPTS=$SWE_ROLLOUT_MAX_ATTEMPTS"
+ echo "SWE_ROLLOUT_REQUEST_TIMEOUT_S=$SWE_ROLLOUT_REQUEST_TIMEOUT_S"
+ echo "SWE_GIT_CHECKOUT_TIMEOUT_S=$SWE_GIT_CHECKOUT_TIMEOUT_S"
+ echo "SWE_VLLM_MAX_MODEL_LEN=$SWE_VLLM_MAX_MODEL_LEN"
+ echo "SWE_TRAIN_DTYPE=$SWE_TRAIN_DTYPE"
+ echo "SWE_LORA=$SWE_LORA"
+ echo "SWE_LORA_R=$SWE_LORA_R"
+ echo "SWE_LORA_ALPHA=$SWE_LORA_ALPHA"
+ echo "SWE_LORA_DROPOUT=$SWE_LORA_DROPOUT"
+ echo "SWE_LORA_TARGET_MODULES=$SWE_LORA_TARGET_MODULES"
+ echo "SWE_LORA_BIAS=$SWE_LORA_BIAS"
+ echo "SWE_LORA_USE_RSLORA=$SWE_LORA_USE_RSLORA"
+ echo "SWE_OPTIM=${SWE_OPTIM:-}"
+ echo "SWE_TORCH_EMPTY_CACHE_STEPS=$SWE_TORCH_EMPTY_CACHE_STEPS"
+ echo "VLLM_TOOL_CALL_PARSER=${VLLM_TOOL_CALL_PARSER:-}"
+ echo "SWE_DISABLE_WEIGHT_TRANSFER=$SWE_DISABLE_WEIGHT_TRANSFER"
+ echo "OPENENV_LOCAL_SANDBOX_ROOT=$OPENENV_LOCAL_SANDBOX_ROOT"
+} > "$RUN_DIR/run.env"
+
+echo "starting_trainer=$(date -Is)"
+srun \
+ --nodes="$TRAINER_NODE_COUNT" \
+ --ntasks="$TRAINER_NODE_COUNT" \
+ --ntasks-per-node=1 \
+ --nodelist="$TRAINER_NODELIST" \
+ --cpus-per-task="$CPUS_PER_TASK" \
+ --gres="gpu:h100:${GPUS_PER_NODE}" \
+ --kill-on-bad-exit=1 \
+ bash -lc '
+ set -euo pipefail
+ cd "$REPO_ROOT"
+ mkdir -p "$PYTHONPYCACHEPREFIX"
+ export CUDA_VISIBLE_DEVICES="$GPU_IDS"
+ .venv/bin/python .venv/bin/accelerate launch \
+ --num_processes "$TOTAL_TRAINER_PROCS" \
+ --num_machines "$TRAINER_NODE_COUNT" \
+ --machine_rank "$SLURM_PROCID" \
+ --main_process_ip "$TRAINER_MASTER" \
+ --main_process_port "$MASTER_PORT" \
+ --mixed_precision no \
+ --num_cpu_threads_per_process "$OMP_NUM_THREADS" \
+ examples/mini_swe_env/train_swe_async_grpo.py \
+ --sandbox-backend "$SWE_SANDBOX_BACKEND" \
+ --agent "$SWE_AGENT" \
+ --vllm-url "$VLLM_URL" \
+ --task-variant lite \
+ --max-tasks "$MAX_TASKS" \
+ --max-steps "$MAX_STEPS" \
+ --max-turns "$MAX_TURNS"
+ ' > "$RUN_DIR/trainer.log" 2>&1
+
+echo "trainer_done=$(date -Is)"
+echo "job_done=$(date -Is)"
diff --git a/examples/mini_swe_env/async_grpo/submit_multinode_async_grpo.sh b/examples/mini_swe_env/async_grpo/submit_multinode_async_grpo.sh
new file mode 100755
index 000000000..b78121a65
--- /dev/null
+++ b/examples/mini_swe_env/async_grpo/submit_multinode_async_grpo.sh
@@ -0,0 +1,84 @@
+#!/usr/bin/env bash
+
+set -euo pipefail
+
+SCRIPT_DIR=$(cd -- "$(dirname "${BASH_SOURCE[0]}")" && pwd)
+REPO_ROOT=$(cd -- "$SCRIPT_DIR/../../.." && pwd)
+SBATCH_TEMPLATE="$SCRIPT_DIR/sbatch_multinode_async_grpo.sh"
+RUNS_ROOT=${RUNS_ROOT:-$REPO_ROOT/runs/mini_swe_async_grpo}
+
+NODES=${NODES:-2}
+GPUS_PER_NODE=${GPUS_PER_NODE:-1}
+CPUS_PER_TASK=${CPUS_PER_TASK:-16}
+PARTITION=${PARTITION:-hopper-prod}
+TIME_LIMIT=${TIME_LIMIT:-08:00:00}
+JOB_NAME=${JOB_NAME:-mini-swe-async-grpo}
+RUN_ID=${RUN_ID:-$(date -u +%Y%m%dT%H%M%SZ)-r$RANDOM}
+RUN_DIR=${RUN_DIR:-$RUNS_ROOT/$RUN_ID}
+
+if (( NODES < 2 || NODES > 4 )); then
+ echo "ERROR: NODES must be between 2 and 4, got ${NODES}" >&2
+ exit 2
+fi
+if (( GPUS_PER_NODE < 1 )); then
+ echo "ERROR: GPUS_PER_NODE must be >= 1, got ${GPUS_PER_NODE}" >&2
+ exit 2
+fi
+if (( CPUS_PER_TASK < 1 )); then
+ echo "ERROR: CPUS_PER_TASK must be >= 1, got ${CPUS_PER_TASK}" >&2
+ exit 2
+fi
+
+mkdir -p "$RUN_DIR/home"
+
+PORT_SEED=${PORT_SEED:-$(( (RANDOM + $(date -u +%S)) % 1000 ))}
+VLLM_PORT=${VLLM_PORT:-$((31000 + PORT_SEED))}
+INTERCEPTION_PORT=${INTERCEPTION_PORT:-$((32000 + PORT_SEED))}
+MASTER_PORT=${MASTER_PORT:-$((33000 + PORT_SEED))}
+
+export REPO_ROOT RUNS_ROOT RUN_DIR
+export GPUS_PER_NODE CPUS_PER_TASK
+export VLLM_PORT INTERCEPTION_PORT MASTER_PORT
+
+{
+ echo "RUN_ID=$RUN_ID"
+ echo "RUN_DIR=$RUN_DIR"
+ echo "NODES=$NODES"
+ echo "GPUS_PER_NODE=$GPUS_PER_NODE"
+ echo "CPUS_PER_TASK=$CPUS_PER_TASK"
+ echo "PARTITION=$PARTITION"
+ echo "TIME_LIMIT=$TIME_LIMIT"
+ echo "JOB_NAME=$JOB_NAME"
+ echo "VLLM_PORT=$VLLM_PORT"
+ echo "INTERCEPTION_PORT=$INTERCEPTION_PORT"
+ echo "MASTER_PORT=$MASTER_PORT"
+} > "$RUN_DIR/submission.env"
+
+SBATCH_ARGS=(
+ --job-name="$JOB_NAME"
+ --partition="$PARTITION"
+ --nodes="$NODES"
+ --ntasks-per-node=1
+ --gres="gpu:h100:${GPUS_PER_NODE}"
+ --cpus-per-task="$CPUS_PER_TASK"
+ --time="$TIME_LIMIT"
+ --output="$RUN_DIR/slurm-%j.out"
+ --error="$RUN_DIR/slurm-%j.err"
+)
+
+if [[ -n "${SBATCH_NODELIST:-}" ]]; then
+ SBATCH_ARGS+=(--nodelist="$SBATCH_NODELIST")
+fi
+if [[ -n "${SBATCH_EXCLUDE:-}" ]]; then
+ SBATCH_ARGS+=(--exclude="$SBATCH_EXCLUDE")
+fi
+
+submit_output=$(sbatch "${SBATCH_ARGS[@]}" "$SBATCH_TEMPLATE")
+printf '%s\n' "$submit_output"
+job_id=$(awk '{print $4}' <<<"$submit_output")
+if [[ -n "$job_id" ]]; then
+ printf '%s\n' "$job_id" > "$RUN_DIR/job_id.txt"
+fi
+
+echo "run_dir=$RUN_DIR"
+echo "ports=vllm:${VLLM_PORT} interception:${INTERCEPTION_PORT} master:${MASTER_PORT}"
diff --git a/examples/mini_swe_env/train_swe_async_grpo.py b/examples/mini_swe_env/train_swe_async_grpo.py
index 6b6a204a3..7adebba8a 100644
--- a/examples/mini_swe_env/train_swe_async_grpo.py
+++ b/examples/mini_swe_env/train_swe_async_grpo.py
@@ -2,7 +2,7 @@
"""Train SWE with AsyncGRPOTrainer + Pi agent + InterceptionServer.
Architecture:
- Pi (HF Sandbox) → InterceptionServer → SWERolloutWorker → vLLM /v1/completions
+ Pi (HF Sandbox) → InterceptionServer → SWERolloutWorker → vLLM /v1/chat/completions
← chat response back to Pi
vLLM runs on a separate GPU. The trainer, interception server, and rollout
@@ -27,22 +27,26 @@
import argparse
import asyncio
+import contextlib
import inspect
import logging
import os
import sys
import threading
+import types
from pathlib import Path
from typing import Any
+import torch
_root = Path(__file__).resolve().parent.parent.parent
for _p in (_root, _root / "src", _root / "envs"):
if str(_p) not in sys.path:
sys.path.insert(0, str(_p))
from datasets import Dataset # noqa: E402
-from transformers import AutoTokenizer # noqa: E402
+from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer # noqa: E402
from trl.experimental.async_grpo import AsyncGRPOConfig, AsyncGRPOTrainer # noqa: E402
+from trl.trainer.utils import _ChunkedLogProbFunction # noqa: E402
from examples.mini_swe_env.async_grpo.control_plane import ( # noqa: E402
SWEAsyncControlPlane,
@@ -113,11 +117,59 @@ def _args() -> argparse.Namespace:
p = argparse.ArgumentParser()
p.add_argument("--task-variant", default="lite", choices=["lite", "full"])
p.add_argument("--max-tasks", type=int, default=5)
+ p.add_argument(
+ "--task-offset",
+ type=int,
+ default=int(os.environ.get("SWE_TASK_OFFSET", "0")),
+ )
+ p.add_argument(
+ "--task-stride",
+ type=int,
+ default=int(os.environ.get("SWE_TASK_STRIDE", "1")),
+ )
+ p.add_argument("--task-indices", default=os.environ.get("SWE_TASK_INDICES", ""))
+ p.add_argument(
+ "--repeat-tasks",
+ type=int,
+ default=int(os.environ.get("SWE_REPEAT_TASKS", "1")),
+ )
p.add_argument("--max-steps", type=int, default=10)
p.add_argument("--max-turns", type=int, default=30)
- p.add_argument("--sandbox-backend", default="hf", choices=["docker", "e2b", "hf"])
+ p.add_argument(
+ "--num-generations",
+ type=int,
+ default=int(os.environ.get("SWE_NUM_GENERATIONS", "4")),
+ )
+ p.add_argument(
+ "--temperature",
+ type=float,
+ default=float(os.environ.get("SWE_TEMPERATURE", "1.0")),
+ )
+ p.add_argument(
+ "--learning-rate",
+ type=float,
+ default=float(os.environ.get("SWE_LEARNING_RATE", "1e-6")),
+ )
+ p.add_argument(
+ "--max-completion-tokens",
+ type=int,
+ default=int(os.environ.get("SWE_MAX_COMPLETION_TOKENS", "2048")),
+ )
+ p.add_argument(
+ "--sandbox-backend",
+ default=os.environ.get("SWE_SANDBOX_BACKEND", "hf"),
+ choices=["docker", "e2b", "hf", "local"],
+ )
p.add_argument("--vllm-url", default="http://localhost:8000")
- p.add_argument("--agent", default="pi", choices=["pi", "opencode"])
+ p.add_argument(
+ "--agent",
+ default=os.environ.get("SWE_AGENT", "pi"),
+ choices=["pi", "opencode"],
+ )
+ p.add_argument(
+ "--agent-thinking",
+ default=os.environ.get("SWE_AGENT_THINKING", "off"),
+ )
return p.parse_args()
@@ -145,6 +197,72 @@ def _int_env(name: str, default: int, *, min_value: int = 1) -> int:
return value
+def _float_env(name: str, default: float, *, min_value: float = 0.0) -> float:
+ raw = os.environ.get(name)
+ if raw is None or not raw.strip():
+ return default
+ value = float(raw)
+ if value < min_value:
+ raise ValueError(f"{name} must be >= {min_value}, got {value}")
+ return value
+
+
+def _parse_task_indices(raw: str) -> list[int]:
+ text = raw.strip()
+ if not text:
+ return []
+ indices: list[int] = []
+ for piece in text.split(","):
+ piece = piece.strip()
+ if not piece:
+ continue
+ indices.append(int(piece))
+ if not indices:
+ raise ValueError("task-indices must contain at least one integer")
+ return indices
+
+
+def _select_task_indices(
+ *,
+ total_tasks: int,
+ max_tasks: int,
+ task_offset: int,
+ task_stride: int,
+ task_indices_raw: str,
+ repeat_tasks: int,
+) -> list[int]:
+ if total_tasks <= 0:
+ raise ValueError("SWE-Gym task source is empty")
+ if max_tasks < 1:
+ raise ValueError(f"max_tasks must be >= 1, got {max_tasks}")
+ if task_offset < 0:
+ raise ValueError(f"task_offset must be >= 0, got {task_offset}")
+ if task_stride < 1:
+ raise ValueError(f"task_stride must be >= 1, got {task_stride}")
+ if repeat_tasks < 1:
+ raise ValueError(f"repeat_tasks must be >= 1, got {repeat_tasks}")
+
+ base_indices = _parse_task_indices(task_indices_raw)
+ if not base_indices:
+ base_indices = list(range(task_offset, total_tasks, task_stride))[:max_tasks]
+
+ if not base_indices:
+ raise ValueError(
+ "task selection produced no tasks; adjust max-tasks/task-offset/task-stride"
+ )
+
+ for idx in base_indices:
+ if idx < 0 or idx >= total_tasks:
+ raise IndexError(
+ f"task index {idx} out of range [0, {total_tasks - 1}]"
+ )
+
+ selected_indices: list[int] = []
+ for _ in range(repeat_tasks):
+ selected_indices.extend(base_indices)
+ return selected_indices
+
+
def _derive_checkpoint_repo_id() -> str | None:
explicit = os.environ.get("SWE_HUB_MODEL_ID", "").strip()
if explicit:
@@ -223,6 +341,379 @@ def _is_missing_checkpoint_error(exc: Exception) -> bool:
return any(hint in msg for hint in hints)
+def _is_main_process() -> bool:
+ """Return true for the single rank that owns rollout infrastructure."""
+ rank = os.environ.get("RANK")
+ if rank is not None and rank.strip():
+ return rank.strip() == "0"
+ local_rank = os.environ.get("LOCAL_RANK")
+ if local_rank is not None and local_rank.strip():
+ return local_rank.strip() == "0"
+ return True
+
+
+def _train_dtype_from_env() -> torch.dtype:
+ raw = os.environ.get("SWE_TRAIN_DTYPE", "bf16").strip().lower()
+ mapping = {
+ "bf16": torch.bfloat16,
+ "bfloat16": torch.bfloat16,
+ "fp16": torch.float16,
+ "float16": torch.float16,
+ "fp32": torch.float32,
+ "float32": torch.float32,
+ }
+ dtype = mapping.get(raw)
+ if dtype is None:
+ raise ValueError(
+ "SWE_TRAIN_DTYPE must be one of "
+ f"{sorted(mapping)}, got {raw!r}"
+ )
+ return dtype
+
+
+def _optional_env(name: str) -> str | None:
+ value = os.environ.get(name, "").strip()
+ return value or None
+
+
+def _csv_env(name: str) -> tuple[str, ...]:
+ raw = os.environ.get(name, "").strip()
+ if not raw:
+ return ()
+ values = tuple(piece.strip() for piece in raw.split(",") if piece.strip())
+ if not values:
+ raise ValueError(f"{name} must contain at least one non-empty value")
+ return values
+
+
+def _count_trainable_parameters(model: torch.nn.Module) -> int:
+ return sum(param.numel() for param in model.parameters() if param.requires_grad)
+
+
+def _single_gpu_trainable_param_limit() -> int:
+ raw = os.environ.get("SWE_SINGLE_GPU_TRAINABLE_PARAM_LIMIT", "").strip()
+ if not raw:
+ return 9_000_000_000
+ value = int(raw)
+ if value < 1:
+ raise ValueError(
+ f"SWE_SINGLE_GPU_TRAINABLE_PARAM_LIMIT must be >= 1, got {value}"
+ )
+ return value
+
+
+def _model_context_limit(model_name: str) -> int:
+ cfg = AutoConfig.from_pretrained(model_name)
+ cfg_sections = [
+ cfg,
+ getattr(cfg, "text_config", None),
+ getattr(cfg, "llm_config", None),
+ ]
+ for section in cfg_sections:
+ if section is None:
+ continue
+ for key in ("max_position_embeddings", "model_max_length", "max_seq_len", "seq_length"):
+ if isinstance(section, dict):
+ candidate = section.get(key)
+ else:
+ candidate = getattr(section, key, None)
+ if isinstance(candidate, int) and candidate > 0:
+ return candidate
+ raise ValueError(
+ f"Could not infer a positive context limit for model {model_name!r}"
+ )
+
+
+def _default_lora_target_modules(model_name: str) -> tuple[str, ...]:
+ model_id = model_name.lower().replace("_", "")
+ if any(
+ token in model_id
+ for token in ("qwen", "llama", "mistral", "mixtral", "deepseek")
+ ):
+ return (
+ "q_proj",
+ "k_proj",
+ "v_proj",
+ "o_proj",
+ "gate_proj",
+ "up_proj",
+ "down_proj",
+ )
+ return ("q_proj", "k_proj", "v_proj", "o_proj")
+
+
+def _sanitize_lora_merged_weight_name(name: str) -> str | None:
+ cleaned = name.removeprefix("module.").removeprefix("base_model.model.")
+ if (
+ ".lora_" in cleaned
+ or ".modules_to_save." in cleaned
+ or ".original_module." in cleaned
+ ):
+ return None
+ return cleaned.replace(".base_layer.", ".")
+
+
+def _lora_config_from_env(model_name: str) -> tuple[Any | None, str]:
+ if not _bool_env("SWE_LORA", False):
+ return None, "full"
+
+ try:
+ from peft import LoraConfig, TaskType
+ except ImportError as exc: # pragma: no cover - exercised in real runtime
+ raise RuntimeError(
+ "SWE_LORA=1 requires `peft` to be installed in the active environment."
+ ) from exc
+
+ rank = _int_env("SWE_LORA_R", 16)
+ alpha = _int_env("SWE_LORA_ALPHA", rank * 2)
+ dropout = _float_env("SWE_LORA_DROPOUT", 0.0)
+ target_modules = list(
+ _csv_env("SWE_LORA_TARGET_MODULES")
+ or _default_lora_target_modules(model_name)
+ )
+ modules_to_save = list(_csv_env("SWE_LORA_MODULES_TO_SAVE")) or None
+ bias = os.environ.get("SWE_LORA_BIAS", "none").strip() or "none"
+
+ config = LoraConfig(
+ task_type=TaskType.CAUSAL_LM,
+ r=rank,
+ lora_alpha=alpha,
+ lora_dropout=dropout,
+ target_modules=target_modules,
+ modules_to_save=modules_to_save,
+ bias=bias,
+ use_rslora=_bool_env("SWE_LORA_USE_RSLORA", False),
+ )
+ summary = (
+ "lora("
+ f"r={rank},alpha={alpha},dropout={dropout:g},"
+ f"targets={','.join(target_modules)}"
+ ")"
+ )
+ return config, summary
+
+
+@contextlib.contextmanager
+def _patched_async_grpo_model_loader(
+ *,
+ dtype: torch.dtype,
+ attn_implementation: str | None,
+ lora_config: Any | None = None,
+):
+ """Force AsyncGRPOTrainer to load the policy in the requested dtype.
+
+ TRL's experimental AsyncGRPOTrainer currently hardcodes
+ ``AutoModelForCausalLM.from_pretrained(..., dtype=torch.float32)``,
+ which makes 8B full-finetuning overflow 80GB H100s on the first
+ optimizer step. Patch the loader just around trainer construction so
+ the example can run with bf16/fp16 policy weights.
+ """
+
+ original = AutoModelForCausalLM.__dict__["from_pretrained"]
+
+ def _patched(
+ cls,
+ pretrained_model_name_or_path: str,
+ *args: Any,
+ **kwargs: Any,
+ ):
+ requested_dtype = kwargs.get("dtype")
+ if requested_dtype in {None, torch.float32, "float32", "fp32"}:
+ kwargs["dtype"] = dtype
+ if attn_implementation and not kwargs.get("attn_implementation"):
+ kwargs["attn_implementation"] = attn_implementation
+ kwargs.setdefault("low_cpu_mem_usage", True)
+ model = original.__get__(None, cls)(
+ pretrained_model_name_or_path,
+ *args,
+ **kwargs,
+ )
+ if lora_config is not None:
+ from peft import get_peft_model
+
+ model = get_peft_model(model, lora_config)
+ if hasattr(model, "enable_input_require_grads"):
+ model.enable_input_require_grads()
+ if hasattr(model, "config"):
+ model.config.use_cache = False
+ return model
+
+ AutoModelForCausalLM.from_pretrained = classmethod(_patched)
+ try:
+ yield
+ finally:
+ AutoModelForCausalLM.from_pretrained = original
+
+
+def _patch_lora_weight_streaming(trainer: AsyncGRPOTrainer) -> None:
+ def _streaming_iter_with_merged_lora(self: AsyncGRPOTrainer):
+ device = self.accelerator.device
+ unwrap_model = getattr(self.accelerator, "unwrap_model", None)
+ model = (
+ unwrap_model(self.model)
+ if callable(unwrap_model)
+ else self.model
+ )
+ merge_adapter = getattr(model, "merge_adapter", None)
+ unmerge_adapter = getattr(model, "unmerge_adapter", None)
+ get_base_model = getattr(model, "get_base_model", None)
+ if not callable(merge_adapter) or not callable(unmerge_adapter):
+ raise RuntimeError(
+ "SWE_LORA requires a PEFT model with merge_adapter()/unmerge_adapter() "
+ "support so vLLM can receive merged weights."
+ )
+ if not callable(get_base_model):
+ raise RuntimeError(
+ "SWE_LORA weight sync requires get_base_model() on the PEFT wrapper."
+ )
+
+ merge_adapter()
+ try:
+ base_model = get_base_model()
+ for name, param in base_model.named_parameters():
+ mapped_name = _sanitize_lora_merged_weight_name(name)
+ if mapped_name is None:
+ continue
+ full = param.full_tensor() if hasattr(param, "full_tensor") else param.detach()
+ if full.device != device:
+ full = full.to(device)
+ yield mapped_name, full
+ finally:
+ unmerge_adapter()
+
+ trainer._streaming_iter = types.MethodType( # type: ignore[method-assign]
+ _streaming_iter_with_merged_lora,
+ trainer,
+ )
+
+
+def _patch_peft_lora_resume_tp_sharding_for_ddp(model: torch.nn.Module) -> None:
+ """Avoid PEFT's TP-only resume hook for ordinary DDP LoRA training."""
+ try:
+ import peft.utils.save_and_load as peft_save_and_load
+ except ImportError:
+ return
+
+ original = getattr(peft_save_and_load, "_maybe_shard_state_dict_for_tp", None)
+ if not callable(original) or getattr(original, "_swe_ddp_guard", False):
+ return
+
+ def _has_hf_tensor_parallel_plan(model_arg: torch.nn.Module) -> bool:
+ for module in model_arg.modules():
+ get_base_layer = getattr(module, "get_base_layer", None)
+ base_layer = get_base_layer() if callable(get_base_layer) else module
+ if (
+ getattr(base_layer, "_hf_tp_plan", None) is not None
+ and getattr(base_layer, "_hf_device_mesh", None) is not None
+ ):
+ return True
+ return False
+
+ def _guarded_maybe_shard_state_dict_for_tp(
+ model_arg: torch.nn.Module,
+ state_dict: dict[str, torch.Tensor],
+ adapter_name: str,
+ ) -> None:
+ if _has_hf_tensor_parallel_plan(model_arg):
+ return original(model_arg, state_dict, adapter_name)
+ return None
+
+ _guarded_maybe_shard_state_dict_for_tp._swe_ddp_guard = True # type: ignore[attr-defined]
+ _guarded_maybe_shard_state_dict_for_tp._swe_original = original # type: ignore[attr-defined]
+ peft_save_and_load._maybe_shard_state_dict_for_tp = ( # type: ignore[method-assign]
+ _guarded_maybe_shard_state_dict_for_tp
+ )
+
+
+def _chunked_logprob_backbone(model: torch.nn.Module) -> torch.nn.Module:
+ backbone = getattr(model, "model", model)
+ if hasattr(backbone, "lm_head") and hasattr(backbone, "model"):
+ return backbone.model
+ return backbone
+
+
+def _patch_chunked_lm_head_for_wrapped_causal_lm(
+ model: torch.nn.Module,
+ *,
+ temperature: float,
+ chunk_size: int = 8192,
+) -> None:
+ def _chunked_forward(
+ self: torch.nn.Module,
+ input_ids: torch.Tensor | None = None,
+ attention_mask: torch.Tensor | None = None,
+ labels: torch.Tensor | None = None,
+ completion_mask: torch.Tensor | None = None,
+ use_cache: bool = False,
+ **kwargs: Any,
+ ) -> dict[str, torch.Tensor]:
+ assert labels is not None, "requires labels to not be None for logprob computation"
+
+ outputs = _chunked_logprob_backbone(self)(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ use_cache=use_cache,
+ **kwargs,
+ )
+ logit_scale = getattr(self.config, "logit_scale", 1.0)
+ hidden_states = getattr(outputs, "last_hidden_state", None)
+ if hidden_states is None:
+ all_hidden_states = getattr(outputs, "hidden_states", None)
+ if not all_hidden_states:
+ raise AttributeError(
+ "Chunked GRPO forward requires `last_hidden_state` or "
+ "`hidden_states` on the model outputs."
+ )
+ hidden_states = all_hidden_states[-1]
+
+ hidden_states = hidden_states[:, :-1, :]
+ labels = labels[:, 1:]
+
+ b, s, h = hidden_states.shape
+ hidden_flat = hidden_states.reshape(b * s, h).contiguous()
+ targets_flat = labels.reshape(b * s).contiguous()
+
+ valid_mask = None
+ if completion_mask is not None:
+ completion_mask = completion_mask[:, 1:]
+ valid_mask = completion_mask.bool().reshape(b * s)
+ hidden_flat = hidden_flat[valid_mask]
+ targets_flat = targets_flat[valid_mask]
+
+ logprobs_valid, entropy_valid = _ChunkedLogProbFunction.apply(
+ hidden_flat,
+ self.lm_head.weight,
+ targets_flat,
+ temperature,
+ chunk_size,
+ logit_scale,
+ )
+
+ if valid_mask is not None:
+ logprobs = torch.zeros(
+ b * s,
+ device=logprobs_valid.device,
+ dtype=logprobs_valid.dtype,
+ )
+ entropy = torch.zeros(
+ b * s,
+ device=entropy_valid.device,
+ dtype=entropy_valid.dtype,
+ )
+ logprobs[valid_mask] = logprobs_valid
+ entropy[valid_mask] = entropy_valid
+ else:
+ logprobs = logprobs_valid
+ entropy = entropy_valid
+
+ return {
+ "log_probs": logprobs.reshape(b, s),
+ "entropy": entropy.reshape(b, s),
+ }
+
+ model.forward = types.MethodType(_chunked_forward, model)
+
+
def main() -> int:
logging.basicConfig(
level=logging.INFO,
@@ -230,14 +721,45 @@ def main() -> int:
datefmt="%H:%M:%S",
)
args = _args()
+ if args.num_generations < 2:
+ raise ValueError(
+ f"--num-generations must be >= 2 for GRPO, got {args.num_generations}"
+ )
model = _env("SWE_MODEL")
vllm_url = args.vllm_url
vllm_key = os.environ.get("VLLM_API_KEY", "token").strip()
+ train_dtype = _train_dtype_from_env()
+ train_attn_implementation = _optional_env("SWE_TRAIN_ATTN_IMPLEMENTATION")
+ lora_config, adapter_summary = _lora_config_from_env(model)
+ if lora_config is not None and _bool_env("SWE_DISABLE_WEIGHT_TRANSFER", False):
+ raise RuntimeError(
+ "SWE_LORA requires SWE_DISABLE_WEIGHT_TRANSFER=0 so rollout generation "
+ "tracks the trained policy."
+ )
# ── Load tasks ────────────────────────────────────────────────
- gym_tasks = load_swegym_tasks(args.task_variant)[: args.max_tasks]
+ all_gym_tasks = load_swegym_tasks(args.task_variant)
+ selected_indices = _select_task_indices(
+ total_tasks=len(all_gym_tasks),
+ max_tasks=args.max_tasks,
+ task_offset=args.task_offset,
+ task_stride=args.task_stride,
+ task_indices_raw=args.task_indices,
+ repeat_tasks=args.repeat_tasks,
+ )
+ gym_tasks = [all_gym_tasks[idx] for idx in selected_indices]
swe_tasks = [t.to_swe_task() for t in gym_tasks]
- _log.info("loaded %d tasks", len(swe_tasks))
+ unique_indices = sorted(set(selected_indices))
+ _log.info(
+ "loaded %d task slots (%d unique tasks) indices=%s",
+ len(swe_tasks),
+ len(unique_indices),
+ selected_indices,
+ )
+ _log.info(
+ "task instances=%s",
+ [all_gym_tasks[idx].instance_id for idx in unique_indices],
+ )
# ── Dataset (prompt per task) ─────────────────────────────────
dataset = Dataset.from_list(
@@ -254,43 +776,93 @@ def main() -> int:
tokenizer = AutoTokenizer.from_pretrained(model)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
+ model_context_limit = _model_context_limit(model)
# ── Interception control plane (background thread) ────────────
- control_cfg = SWEAsyncControlPlaneConfig.from_env()
- control_plane = SWEAsyncControlPlane(config=control_cfg)
- server_loop, server_thread = start_interception_server(control_plane)
- _log.info("InterceptionServer running in background thread")
+ is_main_process = _is_main_process()
+ control_plane: SWEAsyncControlPlane | None = None
+ server_loop: asyncio.AbstractEventLoop | None = None
+ server_thread: threading.Thread | None = None
+ if is_main_process:
+ control_cfg = SWEAsyncControlPlaneConfig.from_env()
+ control_plane = SWEAsyncControlPlane(config=control_cfg)
+ server_loop, server_thread = start_interception_server(control_plane)
+ _log.info("InterceptionServer running in background thread")
try:
# ── Session factory (Pi in sandbox) ───────────────────────
- backend = create_sandbox_backend(args.sandbox_backend)
- session_factory = SWESessionFactory(
- agent=args.agent,
- config=SWEAgentConfig(
- base_url=control_plane.interception_base_url,
- api_key=control_plane.auth_token,
- model=model,
- agent_timeout_s=1800.0,
- ),
- sandbox_backend=backend,
- mode="interception_gate",
- interception_server=control_plane.server,
- interception_base_url=control_plane.interception_base_url,
- )
+ worker: SWERolloutWorker | None = None
+ if is_main_process:
+ assert control_plane is not None
+ backend_kwargs: dict[str, Any] = {}
+ if args.sandbox_backend == "hf":
+ backend_kwargs = {
+ "create_retries": _int_env("SWE_HF_SANDBOX_CREATE_RETRIES", 6),
+ "create_backoff_s": _float_env(
+ "SWE_HF_SANDBOX_CREATE_BACKOFF_S", 20.0
+ ),
+ }
+ elif args.sandbox_backend == "local":
+ backend_kwargs = {
+ "root_dir": os.environ.get("OPENENV_LOCAL_SANDBOX_ROOT", "").strip()
+ or None,
+ "preserve_root": _bool_env(
+ "OPENENV_LOCAL_SANDBOX_PRESERVE",
+ False,
+ ),
+ }
+ default_rollout_inflight = min(
+ args.num_generations,
+ 1 if args.sandbox_backend == "hf" else max(2, args.num_generations),
+ )
+ backend = create_sandbox_backend(args.sandbox_backend, **backend_kwargs)
+ session_factory = SWESessionFactory(
+ agent=args.agent,
+ config=SWEAgentConfig(
+ base_url=control_plane.interception_base_url,
+ api_key=control_plane.auth_token,
+ model=model,
+ agent_timeout_s=_float_env("SWE_AGENT_TIMEOUT_S", 1800.0),
+ thinking=args.agent_thinking,
+ ),
+ sandbox_backend=backend,
+ mode="interception_gate",
+ interception_server=control_plane.server,
+ interception_base_url=control_plane.interception_base_url,
+ )
- # ── Rollout worker ────────────────────────────────────────
- worker = SWERolloutWorker(
- session_factory=session_factory,
- tasks=swe_tasks,
- tokenizer=tokenizer,
- vllm_base_url=vllm_url,
- vllm_api_key=vllm_key,
- vllm_model=model,
- config=WorkerConfig(
- max_inflight=2,
- max_turns=args.max_turns,
- ),
- )
+ # ── Rollout worker ────────────────────────────────────────
+ worker = SWERolloutWorker(
+ session_factory=session_factory,
+ tasks=swe_tasks,
+ tokenizer=tokenizer,
+ vllm_base_url=vllm_url,
+ vllm_api_key=vllm_key,
+ vllm_model=model,
+ config=WorkerConfig(
+ max_inflight=_int_env(
+ "SWE_ROLLOUT_MAX_INFLIGHT",
+ default_rollout_inflight,
+ ),
+ max_rollout_attempts=_int_env(
+ "SWE_ROLLOUT_MAX_ATTEMPTS",
+ 4,
+ ),
+ num_generations=args.num_generations,
+ request_timeout_s=_float_env("SWE_ROLLOUT_REQUEST_TIMEOUT_S", 600.0),
+ max_turns=args.max_turns,
+ max_model_len=_int_env(
+ "SWE_VLLM_MAX_MODEL_LEN",
+ model_context_limit,
+ ),
+ max_completion_tokens=args.max_completion_tokens,
+ temperature=args.temperature,
+ failure_backoff_s=_float_env(
+ "SWE_ROLLOUT_FAILURE_BACKOFF_S",
+ 30.0,
+ ),
+ ),
+ )
# ── Trainer ───────────────────────────────────────────────
def _noop_reward(**kwargs: Any) -> list[float]:
@@ -301,30 +873,74 @@ def _noop_reward(**kwargs: Any) -> list[float]:
checkpoint_args, resume_from_checkpoint, checkpoint_requested = (
_build_checkpoint_args()
)
+ train_optim = os.environ.get("SWE_OPTIM", "").strip() or "paged_adamw_8bit"
async_grpo_args: dict[str, Any] = {
"output_dir": os.path.join(
os.environ.get("HOME", "/tmp"), "outputs/swe_async_grpo"
),
"vllm_server_base_url": vllm_url,
- "vllm_server_timeout": 2400.0,
- "max_completion_length": 2048,
+ "vllm_server_timeout": _float_env("SWE_ROLLOUT_QUEUE_TIMEOUT_S", 900.0),
+ "max_completion_length": args.max_completion_tokens,
"max_steps": args.max_steps,
+ # Online rollout samples are newly generated after resume; do not
+ # discard live GRPO groups to match the previous dataloader offset.
+ "ignore_data_skip": _bool_env("SWE_IGNORE_DATA_SKIP", True),
"per_device_train_batch_size": 1,
"gradient_accumulation_steps": 1,
- "num_generations": 1,
- "learning_rate": 1e-6,
- "temperature": 1.0,
- "optim": "adamw_bnb_8bit",
- "bf16": True,
+ "num_generations": args.num_generations,
+ "learning_rate": args.learning_rate,
+ "temperature": args.temperature,
+ "optim": train_optim,
+ "bf16": train_dtype == torch.bfloat16,
+ "fp16": train_dtype == torch.float16,
"gradient_checkpointing": True,
- "max_staleness": 4,
- "weight_sync_steps": 1,
- "max_inflight_tasks": 2,
+ "torch_empty_cache_steps": _int_env(
+ "SWE_TORCH_EMPTY_CACHE_STEPS",
+ 1,
+ ),
+ "max_staleness": _int_env(
+ "SWE_ASYNC_MAX_STALENESS",
+ max(16, args.num_generations * 4),
+ ),
+ "weight_sync_steps": _int_env("SWE_ASYNC_WEIGHT_SYNC_STEPS", 1),
+ "max_inflight_tasks": _int_env(
+ "SWE_ASYNC_MAX_INFLIGHT_TASKS",
+ args.num_generations,
+ ),
+ "queue_maxsize": _int_env("SWE_ASYNC_QUEUE_MAXSIZE", 64),
"logging_steps": 1,
"report_to": "trackio",
"run_name": f"swe-grpo-{model.split('/')[-1]}",
- "trackio_space_id": os.environ.get("TRACKIO_SPACE_ID", "").strip() or None,
+ "project": (
+ os.environ.get("SWE_TRACKIO_PROJECT", "").strip()
+ or os.environ.get("TRACKIO_PROJECT", "").strip()
+ or "huggingface"
+ ),
+ "trackio_space_id": (
+ os.environ.get("SWE_TRACKIO_SPACE_ID", "").strip()
+ or os.environ.get("TRACKIO_SPACE_ID", "").strip()
+ or None
+ ),
+ "log_completions": _bool_env("SWE_LOG_COMPLETIONS", False),
+ "num_completions_to_print": _int_env(
+ "SWE_LOG_COMPLETIONS_LIMIT",
+ 3,
+ ),
+ "accelerator_config": {
+ "split_batches": True,
+ "dispatch_batches": True,
+ },
}
+ async_grpo_args.update(
+ _filter_async_grpo_kwargs(
+ {
+ # PEFT LoRA + DDP can mark the same adapter parameter ready
+ # twice with reentrant checkpointing during GRPO backward.
+ "gradient_checkpointing_kwargs": {"use_reentrant": False},
+ "ddp_find_unused_parameters": False,
+ }
+ )
+ )
filtered_checkpoint_args = _filter_async_grpo_kwargs(checkpoint_args)
async_grpo_args.update(filtered_checkpoint_args)
@@ -336,21 +952,55 @@ def _noop_reward(**kwargs: Any) -> list[float]:
)
resume_from_checkpoint = None
- trainer = AsyncGRPOTrainer(
- model=model,
- reward_funcs=_noop_reward,
- train_dataset=dataset,
- processing_class=tokenizer,
- rollout_worker=worker,
- args=AsyncGRPOConfig(**async_grpo_args),
- )
+ trainer_args = AsyncGRPOConfig(**async_grpo_args)
+ with _patched_async_grpo_model_loader(
+ dtype=train_dtype,
+ attn_implementation=train_attn_implementation,
+ lora_config=lora_config,
+ ):
+ trainer = AsyncGRPOTrainer(
+ model=model,
+ reward_funcs=_noop_reward,
+ train_dataset=dataset,
+ processing_class=tokenizer,
+ rollout_worker=worker,
+ args=trainer_args,
+ )
+
+ trainable_params = _count_trainable_parameters(trainer.model)
+ trainer_world_size = trainer.accelerator.num_processes
+ if lora_config is not None:
+ _patch_lora_weight_streaming(trainer)
+ _patch_peft_lora_resume_tp_sharding_for_ddp(trainer.model)
+ _patch_chunked_lm_head_for_wrapped_causal_lm(
+ trainer.model,
+ temperature=args.temperature,
+ )
+ single_gpu_param_limit = _single_gpu_trainable_param_limit()
+ if trainer_world_size == 1 and trainable_params > single_gpu_param_limit:
+ raise RuntimeError(
+ "Unsupported full-parameter Async GRPO config for a single training GPU: "
+ f"model={model} trainable_params={trainable_params:,} "
+ f"limit={single_gpu_param_limit:,}. "
+ "Use a smaller base model, add sharded training, or reduce trainable "
+ "state before launching this example."
+ )
_log.info(
- "starting training: model=%s tasks=%d checkpointing=%s resume=%s",
+ "starting training: model=%s tasks=%d generations=%d temp=%.2f lr=%g optim=%s adapter=%s dtype=%s attn_impl=%s checkpointing=%s resume=%s trainable_params=%s world_size=%d",
model,
len(swe_tasks),
+ args.num_generations,
+ args.temperature,
+ args.learning_rate,
+ train_optim,
+ adapter_summary,
+ str(train_dtype).split(".")[-1],
+ train_attn_implementation or "auto",
checkpoint_enabled,
resume_from_checkpoint or "none",
+ f"{trainable_params:,}",
+ trainer_world_size,
)
if resume_from_checkpoint is None:
trainer.train()
@@ -369,7 +1019,11 @@ def _noop_reward(**kwargs: Any) -> list[float]:
else:
raise
- if checkpoint_enabled and hasattr(trainer, "push_to_hub"):
+ if (
+ checkpoint_enabled
+ and hasattr(trainer, "push_to_hub")
+ and trainer.is_world_process_zero()
+ ):
trainer.push_to_hub(
commit_message=f"Final checkpoint at step {getattr(trainer.state, 'global_step', '?')}"
)
@@ -377,7 +1031,12 @@ def _noop_reward(**kwargs: Any) -> list[float]:
_log.info("done: step=%s", getattr(trainer.state, "global_step", "?"))
return 0
finally:
- stop_interception_server(control_plane, server_loop, server_thread)
+ if (
+ control_plane is not None
+ and server_loop is not None
+ and server_thread is not None
+ ):
+ stop_interception_server(control_plane, server_loop, server_thread)
if __name__ == "__main__":
diff --git a/src/openenv/core/harness/sandbox/__init__.py b/src/openenv/core/harness/sandbox/__init__.py
index 208fe54d5..e11d9c812 100644
--- a/src/openenv/core/harness/sandbox/__init__.py
+++ b/src/openenv/core/harness/sandbox/__init__.py
@@ -18,6 +18,7 @@
from .base import BgJob, ExecResult, SandboxBackend, SandboxHandle
from .docker_backend import DockerBgJob, DockerSandboxBackend, DockerSandboxHandle
+from .local_backend import LocalBgJob, LocalSandboxBackend, LocalSandboxHandle
__all__ = [
"BgJob",
@@ -25,6 +26,9 @@
"DockerSandboxBackend",
"DockerSandboxHandle",
"ExecResult",
+ "LocalBgJob",
+ "LocalSandboxBackend",
+ "LocalSandboxHandle",
"SandboxBackend",
"SandboxHandle",
"create_sandbox_backend",
@@ -46,7 +50,7 @@
def create_sandbox_backend(
- backend: Literal["e2b", "docker", "hf"] = "e2b",
+ backend: Literal["e2b", "docker", "hf", "local"] = "e2b",
**kwargs: Any,
) -> SandboxBackend:
"""Create a sandbox backend by name.
@@ -57,6 +61,8 @@ def create_sandbox_backend(
For ``"docker"``: local Docker, no external dependencies.
For ``"hf"``: Hugging Face Jobs via ``hf-sandbox``.
+
+ For ``"local"``: isolated temp directories and subprocesses on the host.
"""
if backend == "e2b":
from .e2b_backend import E2BSandboxBackend
@@ -68,6 +74,8 @@ def create_sandbox_backend(
from .hf_backend import HFSandboxBackend
return HFSandboxBackend(**kwargs)
+ elif backend == "local":
+ return LocalSandboxBackend(**kwargs)
raise ValueError(
- f"Unknown sandbox backend: {backend!r}. Use 'e2b', 'docker', or 'hf'."
+ f"Unknown sandbox backend: {backend!r}. Use 'e2b', 'docker', 'hf', or 'local'."
)
diff --git a/src/openenv/core/harness/sandbox/local_backend.py b/src/openenv/core/harness/sandbox/local_backend.py
new file mode 100644
index 000000000..62725e3a6
--- /dev/null
+++ b/src/openenv/core/harness/sandbox/local_backend.py
@@ -0,0 +1,294 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""Local subprocess implementation of :class:`SandboxBackend`.
+
+Each sandbox gets an isolated temp root on the host filesystem plus an
+optional per-sandbox virtualenv. Commands execute directly on the trainer
+host, which makes this backend suitable for rootless cluster environments
+where Docker/HF Jobs are unavailable.
+"""
+
+from __future__ import annotations
+
+import os
+import shutil
+import signal
+import subprocess
+import sys
+import tempfile
+import time
+import uuid
+from pathlib import Path
+
+from openenv.core.harness.sandbox.base import BgJob, ExecResult, SandboxHandle
+
+_CANONICAL_HOME = "/home/user"
+_CANONICAL_WORKDIR = "/testbed"
+
+
+class LocalBgJob:
+ """Handle to a background subprocess launched in a local sandbox."""
+
+ def __init__(self, proc: subprocess.Popen[str]) -> None:
+ self._proc = proc
+
+ @property
+ def pid(self) -> int:
+ return int(self._proc.pid)
+
+ def wait(self, timeout: float | None = None) -> int:
+ try:
+ return int(self._proc.wait(timeout=timeout))
+ except subprocess.TimeoutExpired as exc:
+ raise TimeoutError(
+ f"Background command (pid={self._proc.pid}) did not exit within {timeout}s"
+ ) from exc
+
+ def kill(self) -> None:
+ if self._proc.poll() is not None:
+ return
+ try:
+ os.killpg(self._proc.pid, signal.SIGTERM)
+ except ProcessLookupError:
+ return
+ except Exception:
+ self._proc.terminate()
+ return
+
+ deadline = time.monotonic() + 5.0
+ while self._proc.poll() is None and time.monotonic() < deadline:
+ time.sleep(0.1)
+ if self._proc.poll() is None:
+ try:
+ os.killpg(self._proc.pid, signal.SIGKILL)
+ except Exception:
+ self._proc.kill()
+
+
+class LocalSandboxHandle:
+ """Host-backed sandbox handle with an isolated temp root."""
+
+ supports_images = False
+
+ def __init__(
+ self,
+ *,
+ root_dir: str,
+ home_dir: str,
+ workdir: str,
+ tmp_dir: str,
+ default_envs: dict[str, str] | None = None,
+ preserve_root: bool = False,
+ ) -> None:
+ self._root_dir = root_dir
+ self._home_dir = home_dir
+ self._workdir = workdir
+ self._tmp_dir = tmp_dir
+ self._default_envs = dict(default_envs or {})
+ self._preserve_root = preserve_root
+ self._bg_jobs: list[LocalBgJob] = []
+ self._sandbox_id = f"local-{uuid.uuid4().hex[:12]}"
+
+ @property
+ def sandbox_id(self) -> str:
+ return self._sandbox_id
+
+ @property
+ def sandbox_home(self) -> str:
+ return self._home_dir
+
+ @property
+ def workdir(self) -> str:
+ return self._workdir
+
+ @property
+ def tmp_dir(self) -> str:
+ return self._tmp_dir
+
+ def exec(
+ self,
+ cmd: str,
+ *,
+ envs: dict[str, str] | None = None,
+ cwd: str | None = None,
+ timeout: float | None = 60,
+ ) -> ExecResult:
+ run_env = self._build_env(envs)
+ resolved_cwd = self._resolve_cwd(cwd)
+ try:
+ result = subprocess.run(
+ ["bash", "-lc", cmd],
+ cwd=resolved_cwd,
+ env=run_env,
+ capture_output=True,
+ text=True,
+ timeout=timeout,
+ )
+ return ExecResult(
+ exit_code=int(result.returncode),
+ stdout=result.stdout,
+ stderr=result.stderr,
+ )
+ except subprocess.TimeoutExpired:
+ return ExecResult(
+ exit_code=-1,
+ stdout="",
+ stderr=f"Command timed out after {timeout}s",
+ )
+ except Exception as exc:
+ return ExecResult(exit_code=-1, stdout="", stderr=str(exc))
+
+ def start_bg(
+ self,
+ cmd: str,
+ *,
+ envs: dict[str, str] | None = None,
+ cwd: str | None = None,
+ ) -> BgJob:
+ proc = subprocess.Popen(
+ ["bash", "-lc", cmd],
+ cwd=self._resolve_cwd(cwd),
+ env=self._build_env(envs),
+ stdout=subprocess.DEVNULL,
+ stderr=subprocess.DEVNULL,
+ text=True,
+ preexec_fn=os.setsid,
+ )
+ job = LocalBgJob(proc)
+ self._bg_jobs.append(job)
+ return job
+
+ def write_text(self, path: str, content: str) -> None:
+ resolved = Path(self._resolve_path(path))
+ resolved.parent.mkdir(parents=True, exist_ok=True)
+ resolved.write_text(content)
+
+ def read_text(self, path: str) -> str:
+ return Path(self._resolve_path(path)).read_text()
+
+ def exists(self, path: str) -> bool:
+ return Path(self._resolve_path(path)).exists()
+
+ def kill(self) -> None:
+ for job in self._bg_jobs:
+ try:
+ job.kill()
+ except Exception:
+ pass
+ self._bg_jobs.clear()
+ if not self._preserve_root:
+ shutil.rmtree(self._root_dir, ignore_errors=True)
+
+ def _resolve_cwd(self, cwd: str | None) -> str:
+ candidate = self._resolve_path(cwd) if cwd else self._workdir
+ Path(candidate).mkdir(parents=True, exist_ok=True)
+ return candidate
+
+ def _resolve_path(self, path: str | None) -> str:
+ if not path:
+ return self._workdir
+ if path == _CANONICAL_HOME or path.startswith(f"{_CANONICAL_HOME}/"):
+ suffix = path[len(_CANONICAL_HOME) :].lstrip("/")
+ return str(Path(self._home_dir) / suffix) if suffix else self._home_dir
+ if path == _CANONICAL_WORKDIR or path.startswith(f"{_CANONICAL_WORKDIR}/"):
+ suffix = path[len(_CANONICAL_WORKDIR) :].lstrip("/")
+ return str(Path(self._workdir) / suffix) if suffix else self._workdir
+ return path
+
+ def _build_env(self, envs: dict[str, str] | None) -> dict[str, str]:
+ merged = os.environ.copy()
+ merged.update(self._default_envs)
+ merged.update(envs or {})
+ merged.setdefault("HOME", self._home_dir)
+ merged.setdefault("TMPDIR", self._tmp_dir)
+ merged.setdefault("PIP_CACHE_DIR", str(Path(self._home_dir) / ".cache" / "pip"))
+ merged.setdefault("XDG_CACHE_HOME", str(Path(self._home_dir) / ".cache"))
+ merged.setdefault("PIP_DISABLE_PIP_VERSION_CHECK", "1")
+ return merged
+
+
+class LocalSandboxBackend:
+ """Create host-local sandboxes rooted in unique temp directories."""
+
+ supports_images = False
+
+ def __init__(
+ self,
+ *,
+ root_dir: str | None = None,
+ create_virtualenv: bool = True,
+ python_executable: str | None = None,
+ preserve_root: bool = False,
+ ) -> None:
+ self._root_dir = root_dir
+ self._create_virtualenv = create_virtualenv
+ self._python_executable = python_executable or sys.executable
+ self._preserve_root = preserve_root
+
+ def create(
+ self,
+ *,
+ timeout_s: int = 900,
+ envs: dict[str, str] | None = None,
+ metadata: dict[str, str] | None = None,
+ image: str | None = None,
+ ) -> SandboxHandle:
+ del timeout_s, metadata, image
+
+ if self._root_dir:
+ Path(self._root_dir).mkdir(parents=True, exist_ok=True)
+ root = tempfile.mkdtemp(prefix="openenv_local_", dir=self._root_dir)
+ home = str(Path(root) / "home")
+ workdir = str(Path(root) / "testbed")
+ tmp_dir = str(Path(root) / "tmp")
+ Path(home).mkdir(parents=True, exist_ok=True)
+ Path(workdir).mkdir(parents=True, exist_ok=True)
+ Path(tmp_dir).mkdir(parents=True, exist_ok=True)
+
+ default_envs = dict(envs or {})
+ default_envs.setdefault("HOME", home)
+ default_envs.setdefault("TMPDIR", tmp_dir)
+ default_envs.setdefault(
+ "PIP_CACHE_DIR",
+ os.environ.get("PIP_CACHE_DIR", str(Path(home) / ".cache" / "pip")),
+ )
+ default_envs.setdefault(
+ "XDG_CACHE_HOME",
+ os.environ.get("XDG_CACHE_HOME", str(Path(home) / ".cache")),
+ )
+ default_envs.setdefault("PIP_DISABLE_PIP_VERSION_CHECK", "1")
+
+ venv_dir = Path(home) / "venv"
+ if self._create_virtualenv:
+ subprocess.run(
+ [self._python_executable, "-m", "venv", str(venv_dir)],
+ check=True,
+ capture_output=True,
+ text=True,
+ timeout=120,
+ )
+ path_prefix = str(venv_dir / "bin")
+ default_envs["VIRTUAL_ENV"] = str(venv_dir)
+ default_envs["PATH"] = (
+ f"{path_prefix}:{os.environ.get('PATH', '')}".rstrip(":")
+ )
+
+ return LocalSandboxHandle(
+ root_dir=root,
+ home_dir=home,
+ workdir=workdir,
+ tmp_dir=tmp_dir,
+ default_envs=default_envs,
+ preserve_root=self._preserve_root,
+ )
+
+
+__all__ = [
+ "LocalBgJob",
+ "LocalSandboxBackend",
+ "LocalSandboxHandle",
+]
diff --git a/tests/core/test_local_sandbox_backend.py b/tests/core/test_local_sandbox_backend.py
new file mode 100644
index 000000000..eed2ef1d8
--- /dev/null
+++ b/tests/core/test_local_sandbox_backend.py
@@ -0,0 +1,83 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from __future__ import annotations
+
+from pathlib import Path
+
+from openenv.core.harness.sandbox.local_backend import LocalSandboxBackend
+
+
+def test_local_sandbox_backend_basic_lifecycle(tmp_path):
+ backend = LocalSandboxBackend(root_dir=str(tmp_path), create_virtualenv=False)
+ sandbox = backend.create(envs={"OPENENV_TEST": "1"})
+
+ root = Path(sandbox.sandbox_home).parent
+ try:
+ result = sandbox.exec(
+ 'printf "%s|%s|%s" "$HOME" "$TMPDIR" "$OPENENV_TEST"',
+ cwd="/testbed",
+ )
+ assert result.exit_code == 0
+ assert result.stdout == f"{sandbox.sandbox_home}|{sandbox.tmp_dir}|1"
+
+ sandbox.write_text("/testbed/hello.txt", "hello\n")
+ assert sandbox.read_text("/testbed/hello.txt") == "hello\n"
+ assert sandbox.exists("/testbed/hello.txt")
+
+ job = sandbox.start_bg("sleep 0.1", cwd="/testbed")
+ assert job.wait(timeout=2.0) == 0
+ finally:
+ sandbox.kill()
+
+ assert not root.exists()
+
+
+def test_local_sandbox_backend_creates_virtualenv(tmp_path):
+ backend = LocalSandboxBackend(root_dir=str(tmp_path), create_virtualenv=True)
+ sandbox = backend.create()
+
+ try:
+ result = sandbox.exec(
+ 'python -c "import os,sys; print(sys.prefix); print(os.environ.get(\'VIRTUAL_ENV\', \'\'))"'
+ )
+ assert result.exit_code == 0
+ lines = result.stdout.strip().splitlines()
+ assert len(lines) == 2
+ assert lines[0] == f"{sandbox.sandbox_home}/venv"
+ assert lines[1] == f"{sandbox.sandbox_home}/venv"
+ finally:
+ sandbox.kill()
+
+
+def test_local_sandbox_backend_creates_missing_root_dir(tmp_path):
+ root_dir = tmp_path / "nested" / "sandboxes"
+ backend = LocalSandboxBackend(root_dir=str(root_dir), create_virtualenv=False)
+ sandbox = backend.create()
+
+ try:
+ assert root_dir.exists()
+ assert Path(sandbox.sandbox_home).parent.parent == root_dir
+ finally:
+ sandbox.kill()
+
+
+def test_local_sandbox_backend_inherits_host_cache_envs(tmp_path, monkeypatch):
+ monkeypatch.setenv("PIP_CACHE_DIR", "/shared/pip-cache")
+ monkeypatch.setenv("XDG_CACHE_HOME", "/shared/xdg-cache")
+
+ backend = LocalSandboxBackend(root_dir=str(tmp_path), create_virtualenv=False)
+ sandbox = backend.create()
+
+ try:
+ result = sandbox.exec(
+ 'printf "%s|%s" "$PIP_CACHE_DIR" "$XDG_CACHE_HOME"',
+ cwd="/testbed",
+ )
+ assert result.exit_code == 0
+ assert result.stdout == "/shared/pip-cache|/shared/xdg-cache"
+ finally:
+ sandbox.kill()
diff --git a/tests/envs/test_swe_async_rollout_worker.py b/tests/envs/test_swe_async_rollout_worker.py
index a8bb0e04c..24c26c15f 100644
--- a/tests/envs/test_swe_async_rollout_worker.py
+++ b/tests/envs/test_swe_async_rollout_worker.py
@@ -1,19 +1,128 @@
+import asyncio
import sys
from pathlib import Path
+from types import SimpleNamespace
+
+import pytest
_ROOT = Path(__file__).resolve().parents[2]
-if str(_ROOT) not in sys.path:
- sys.path.insert(0, str(_ROOT))
+for _p in (_ROOT, _ROOT / "src", _ROOT / "envs"):
+ if str(_p) not in sys.path:
+ sys.path.insert(0, str(_p))
from examples.mini_swe_env.async_grpo.rollout_worker import (
+ _OMITTED_TOOL_OUTPUT_MARKER,
+ SWERolloutWorker,
+ WorkerConfig,
+ _clamp_max_completion_tokens,
+ _compute_group_advantages,
+ _coerce_token_ids,
+ _extract_xml_tool_calls,
+ _extract_chat_choice_logprobs,
+ _extract_prompt_token_ids,
+ _fit_messages_to_context_window,
_get_tools,
_has_answer_call,
+ _is_context_window_error,
+ _is_retriable_rollout_error,
_is_terminal_non_tool_response,
_make_chat_response,
+ _normalize_chat_choice_message,
_normalize_tool_calls,
+ _retry_completion_tokens_from_context_error,
+ _truncate_messages_for_prompt_budget,
+ _truncate_text_middle,
)
+def test_rollout_grades_partial_work_on_request_idle_timeout(
+ monkeypatch: pytest.MonkeyPatch,
+) -> None:
+ monkeypatch.setattr(SWERolloutWorker, "_init_weight_transfer", lambda self: None)
+
+ class _Session:
+ answer_called = False
+
+ def __init__(self) -> None:
+ self.calls = 0
+ self.delivered = []
+ self.closed = False
+
+ async def next_request(self, timeout_s: float) -> dict[str, object] | None:
+ self.calls += 1
+ if self.calls == 1:
+ return {
+ "messages": [{"role": "user", "content": "fix it"}],
+ "body": {},
+ }
+ raise TimeoutError("no request within timeout")
+
+ async def deliver(
+ self,
+ intercept: dict[str, object],
+ response: dict[str, object],
+ ) -> None:
+ self.delivered.append((intercept, response))
+
+ def verify(self, transcript: list[object]) -> object:
+ return SimpleNamespace(env_reward=0.5)
+
+ def close(self) -> None:
+ self.closed = True
+
+ session = _Session()
+
+ class _Factory:
+ def create(self, *, task: object, episode_id: str) -> _Session:
+ return session
+
+ class _Tokenizer:
+ pad_token_id = 0
+
+ def apply_chat_template(self, messages: object, **kwargs: object) -> list[int]:
+ return [10, 11]
+
+ worker = SWERolloutWorker(
+ session_factory=_Factory(),
+ tasks=[SimpleNamespace(instance_id="repo__issue-1", instruction="fix it")],
+ tokenizer=_Tokenizer(),
+ vllm_base_url="http://vllm",
+ vllm_api_key="token",
+ vllm_model="model",
+ config=WorkerConfig(request_timeout_s=0.01),
+ )
+ monkeypatch.setattr(
+ worker,
+ "_generate",
+ lambda **kwargs: (
+ [10, 11],
+ [12, 13],
+ [-0.2, -0.3],
+ _make_chat_response(
+ {"role": "assistant", "content": "ran a command"},
+ model="model",
+ finish_reason="tool_calls",
+ ),
+ "tool_calls",
+ ),
+ )
+
+ sample = asyncio.run(
+ worker._rollout( # type: ignore[attr-defined]
+ worker._tasks[0],
+ "episode-1",
+ model_version=3,
+ )
+ )
+
+ assert sample is not None
+ assert sample.reward == pytest.approx(0.5)
+ assert sample.model_version == 3
+ assert sample.metrics["turns"] == 1.0
+ assert sample.metrics["request_idle_timeout"] == 1.0
+ assert session.closed is True
+
+
def test_normalize_tool_calls_serializes_arguments_to_json_string() -> None:
calls = _normalize_tool_calls(
[
@@ -68,3 +177,167 @@ def test_get_tools_from_intercept_or_body() -> None:
assert _get_tools({"tools": [tool_schema]}) == [tool_schema]
assert _get_tools({"body": {"tools": [tool_schema]}}) == [tool_schema]
assert _get_tools({}) is None
+
+
+def test_coerce_token_ids_and_prompt_token_ids_are_int_lists() -> None:
+ assert _coerce_token_ids([1, "2", 3]) == [1, 2, 3]
+ assert _coerce_token_ids(["bad"]) == []
+ assert _extract_prompt_token_ids({"prompt_token_ids": [4, "5"]}) == [4, 5]
+
+
+def test_extract_chat_choice_logprobs_pads_missing_values() -> None:
+ choice = {
+ "logprobs": {
+ "content": [
+ {"logprob": -0.1},
+ {},
+ {"logprob": -0.3},
+ ]
+ }
+ }
+ assert _extract_chat_choice_logprobs(choice, expected_len=4) == pytest.approx(
+ [-0.1, 0.0, -0.3, 0.0]
+ )
+
+
+def test_extract_xml_tool_calls_recovers_qwen_style_blocks() -> None:
+ content, tool_calls = _extract_xml_tool_calls(
+ 'Working...\n\n{"name": "answer", "arguments": {}}\n\nDone.'
+ )
+ assert content == "Working...\n\nDone."
+ assert len(tool_calls) == 1
+ assert tool_calls[0]["function"]["name"] == "answer"
+ assert tool_calls[0]["function"]["arguments"] == "{}"
+
+
+def test_normalize_chat_choice_message_parses_xml_tool_calls_from_content() -> None:
+ message = _normalize_chat_choice_message(
+ tokenizer=object(),
+ choice={
+ "message": {
+ "role": "assistant",
+ "content": '\n{"name": "answer", "arguments": {}}\n',
+ "tool_calls": [],
+ }
+ },
+ completion_ids=[],
+ )
+ assert message["content"] == ""
+ assert message["tool_calls"][0]["function"]["name"] == "answer"
+ assert message["tool_calls"][0]["function"]["arguments"] == "{}"
+
+
+def test_clamp_max_completion_tokens_reserves_context_margin() -> None:
+ assert _clamp_max_completion_tokens(
+ prompt_len=2561,
+ requested=1536,
+ max_model_len=4096,
+ ) == 1519
+
+
+def test_retry_completion_tokens_from_context_error_parses_vllm_message() -> None:
+ error_text = (
+ "This model's maximum context length is 4096 tokens. However, you "
+ "requested 1536 output tokens and your prompt contains at least 2561 "
+ "input tokens, for a total of at least 4097 tokens."
+ )
+ assert _retry_completion_tokens_from_context_error(error_text) == (
+ 1519,
+ 4096,
+ 2561,
+ )
+
+
+def test_truncate_text_middle_preserves_edges() -> None:
+ text = "A" * 80 + "B" * 80
+ truncated, changed = _truncate_text_middle(text, max_chars=64)
+ assert changed is True
+ assert len(truncated) <= 64
+ assert truncated.startswith("A" * 10)
+ assert truncated.endswith("B" * 8)
+
+
+def test_truncate_messages_for_prompt_budget_only_trims_long_tool_output() -> None:
+ messages, tool_truncations, assistant_truncations = (
+ _truncate_messages_for_prompt_budget(
+ [
+ {"role": "user", "content": "short"},
+ {"role": "tool", "content": "x" * 200},
+ {"role": "assistant", "content": "ok"},
+ ],
+ max_tool_message_chars=64,
+ max_assistant_message_chars=64,
+ )
+ )
+ assert tool_truncations == 1
+ assert assistant_truncations == 0
+ assert messages[0]["content"] == "short"
+ assert len(messages[1]["content"]) <= 64
+ assert messages[2]["content"] == "ok"
+
+
+def test_fit_messages_to_context_window_omits_oldest_tool_output_when_needed() -> None:
+ def render_prompt_ids(
+ messages: list[dict[str, object]],
+ tools: list[dict[str, object]] | None,
+ ) -> list[int]:
+ total = 0
+ for message in messages:
+ content = message.get("content")
+ if isinstance(content, str):
+ total += len(content)
+ return list(range(total))
+
+ messages, prompt_ids = _fit_messages_to_context_window(
+ messages=[
+ {"role": "user", "content": "task"},
+ {"role": "tool", "content": "x" * 220},
+ {"role": "assistant", "content": "thinking"},
+ {"role": "tool", "content": "y" * 220},
+ ],
+ tools=None,
+ render_prompt_ids=render_prompt_ids,
+ requested_completion_tokens=40,
+ max_model_len=116,
+ max_tool_message_chars=64,
+ min_tool_message_chars=32,
+ max_assistant_message_chars=32,
+ min_assistant_message_chars=16,
+ )
+ assert len(prompt_ids) <= 60
+ assert messages[1]["content"] == _OMITTED_TOOL_OUTPUT_MARKER
+ assert messages[3]["content"] == _OMITTED_TOOL_OUTPUT_MARKER
+
+
+def test_context_window_errors_are_not_marked_retriable() -> None:
+ exc = RuntimeError(
+ "vllm 400: This model's maximum context length is 8192 tokens. "
+ "However, you requested 1 output tokens and your prompt contains "
+ "at least 8192 input tokens."
+ )
+ assert _is_context_window_error(exc) is True
+ assert _is_retriable_rollout_error(exc) is False
+
+
+def test_transient_sandbox_errors_are_marked_retriable() -> None:
+ exc = RuntimeError("HF sandbox tunnel failed: 429 Too Many Requests")
+ assert _is_context_window_error(exc) is False
+ assert _is_retriable_rollout_error(exc) is True
+
+
+def test_compute_group_advantages_zscores_rewards() -> None:
+ advantages, reward_mean, reward_std = _compute_group_advantages(
+ [0.0, 0.0, 1.0, 1.0]
+ )
+ assert reward_mean == pytest.approx(0.5)
+ assert reward_std == pytest.approx(0.5)
+ assert advantages == pytest.approx([-1.0, -1.0, 1.0, 1.0], abs=1e-6)
+
+
+def test_compute_group_advantages_returns_zero_for_constant_rewards() -> None:
+ advantages, reward_mean, reward_std = _compute_group_advantages(
+ [1.0, 1.0, 1.0, 1.0]
+ )
+ assert reward_mean == pytest.approx(1.0)
+ assert reward_std == pytest.approx(0.0)
+ assert advantages == pytest.approx([0.0, 0.0, 0.0, 0.0], abs=1e-6)
diff --git a/tests/envs/test_swe_grading.py b/tests/envs/test_swe_grading.py
new file mode 100644
index 000000000..7c0ca198b
--- /dev/null
+++ b/tests/envs/test_swe_grading.py
@@ -0,0 +1,182 @@
+import sys
+from pathlib import Path
+from types import SimpleNamespace
+
+import pytest
+
+_ROOT = Path(__file__).resolve().parents[2]
+if str(_ROOT) not in sys.path:
+ sys.path.insert(0, str(_ROOT))
+_ENVS = _ROOT / "envs"
+if str(_ENVS) not in sys.path:
+ sys.path.insert(0, str(_ENVS))
+
+from mini_swe_env.grading import GradingError, grade_from_case_results
+from mini_swe_env.harness import (
+ HOME,
+ SWESession,
+ SWEAgentConfig,
+ SWESessionFactory,
+ _wrap_instruction,
+)
+from mini_swe_env.models import SWEGymTask, SWETask
+from openenv.core.harness.sandbox.base import ExecResult
+
+
+def _task() -> SWEGymTask:
+ return SWEGymTask(
+ instance_id="demo__task-1",
+ repo="demo/repo",
+ base_commit="deadbeef",
+ problem_statement="Fix the bug.",
+ version="1.0",
+ patch="",
+ test_patch="",
+ FAIL_TO_PASS=["tests/test_a.py::test_fix"],
+ PASS_TO_PASS=["tests/test_b.py::test_regression"],
+ )
+
+
+def test_grade_from_case_results_binary_mode_stays_sparse() -> None:
+ grade = grade_from_case_results(
+ _task(),
+ {
+ "tests/test_a.py::test_fix": True,
+ "tests/test_b.py::test_regression": False,
+ },
+ )
+ assert grade.case_fraction == pytest.approx(0.5)
+ assert grade.reward == pytest.approx(0.0)
+ assert grade.resolved is False
+
+
+def test_grade_from_case_results_case_fraction_mode_is_dense() -> None:
+ grade = grade_from_case_results(
+ _task(),
+ {
+ "tests/test_a.py::test_fix": True,
+ "tests/test_b.py::test_regression": False,
+ },
+ reward_mode="case_fraction",
+ )
+ assert grade.case_fraction == pytest.approx(0.5)
+ assert grade.reward == pytest.approx(0.5)
+ assert grade.resolved is False
+
+
+def test_grade_from_case_results_rejects_unknown_reward_mode() -> None:
+ with pytest.raises(GradingError):
+ grade_from_case_results(_task(), {}, reward_mode="unknown")
+
+
+def test_wrap_instruction_includes_optional_hints_block() -> None:
+ wrapped = _wrap_instruction(
+ "Fix the bug.",
+ hints_text="Look at module.py around line 10.",
+ workdir="/testbed",
+ )
+ assert "" in wrapped
+ assert "Look at module.py around line 10." in wrapped
+ assert "You cannot continue working after submitting" in wrapped
+ assert "Each bash tool call runs in a fresh shell." in wrapped
+ assert "Test edits do not help." in wrapped
+ assert "Do not use `git commit`, `git branch`, or `git push`." in wrapped
+ assert "Start with the maintainer hints or issue description." in wrapped
+
+
+def test_wrap_instruction_supports_fallback_only_grading_mode() -> None:
+ wrapped = _wrap_instruction(
+ "Fix the bug.",
+ workdir="/testbed",
+ answer_tool_enabled=False,
+ )
+ assert "There is no `answer` tool in this run." in wrapped
+ assert "graded automatically when the session ends." in wrapped
+ assert "If the `answer` tool is available" in wrapped
+
+
+def test_swe_session_verify_uses_host_fallback_grader() -> None:
+ task = SWETask(
+ task_id="task-1",
+ source="unit",
+ instance_id="demo__task-1",
+ repo="demo/repo",
+ base_commit="deadbeef",
+ instruction="Fix the bug.",
+ setup=[],
+ verify=[],
+ )
+ session = SWESession(
+ swe_task=task,
+ spec=SimpleNamespace(name="opencode"),
+ sandbox=SimpleNamespace(),
+ task=SimpleNamespace(instruction="Fix the bug."),
+ config=SWEAgentConfig(sandbox_home=HOME, workdir="/testbed"),
+ )
+
+ def _grader(_sandbox, swe_task, *, home: str, workdir: str) -> tuple[float, bool]:
+ assert swe_task.instance_id == "demo__task-1"
+ assert home == HOME
+ assert workdir == "/testbed"
+ return 0.5, False
+
+ session._fallback_grader = _grader
+ result = session.verify(transcript=[])
+
+ assert result.env_reward == pytest.approx(0.5)
+ assert result.metrics["reward_source"] == "host_verify_fallback"
+ assert result.metrics["resolved"] is False
+
+
+def test_list_changed_test_paths_filters_to_test_like_files() -> None:
+ class _Sandbox:
+ def exec(self, cmd: str, *, cwd=None, timeout=None): # type: ignore[no-untyped-def]
+ if cmd == "git diff --name-only HEAD --":
+ return ExecResult(
+ exit_code=0,
+ stdout=(
+ "moto/ssm/models.py\n"
+ "tests/test_ssm/test_ssm_boto3.py\n"
+ "pkg/widget_test.py\n"
+ ),
+ stderr="",
+ )
+ if cmd == "git ls-files --others --exclude-standard":
+ return ExecResult(
+ exit_code=0,
+ stdout="notes.txt\ntesting/helpers/new_case.py\n",
+ stderr="",
+ )
+ raise AssertionError(f"unexpected command: {cmd}")
+
+ paths = SWESessionFactory._list_changed_test_paths(_Sandbox(), workdir="/testbed")
+ assert paths == [
+ "pkg/widget_test.py",
+ "testing/helpers/new_case.py",
+ "tests/test_ssm/test_ssm_boto3.py",
+ ]
+
+
+def test_prepare_repo_uses_configurable_checkout_timeout(monkeypatch) -> None:
+ monkeypatch.setenv("SWE_GIT_CHECKOUT_TIMEOUT_S", "240")
+ calls = []
+
+ class _Sandbox:
+ def exec(self, cmd: str, *, cwd=None, timeout=None): # type: ignore[no-untyped-def]
+ calls.append((cmd, cwd, timeout))
+ return ExecResult(exit_code=0, stdout="", stderr="")
+
+ task = SimpleNamespace(repo="demo/repo", base_commit="deadbeef")
+
+ SWESessionFactory._prepare_repo(
+ object(),
+ _Sandbox(),
+ task,
+ workdir="/testbed",
+ )
+
+ assert (
+ "git checkout --quiet deadbeef",
+ "/testbed",
+ 240,
+ ) in calls
diff --git a/tests/envs/test_train_swe_async_grpo.py b/tests/envs/test_train_swe_async_grpo.py
new file mode 100644
index 000000000..3eacfcb60
--- /dev/null
+++ b/tests/envs/test_train_swe_async_grpo.py
@@ -0,0 +1,340 @@
+import sys
+from pathlib import Path
+
+import pytest
+import torch
+from transformers.models.qwen3.configuration_qwen3 import Qwen3Config
+from transformers.models.qwen3.modeling_qwen3 import Qwen3ForCausalLM
+
+from peft import LoraConfig, TaskType, get_peft_model
+
+_ROOT = Path(__file__).resolve().parents[2]
+if str(_ROOT) not in sys.path:
+ sys.path.insert(0, str(_ROOT))
+
+import examples.mini_swe_env.train_swe_async_grpo as train_mod
+from examples.mini_swe_env.train_swe_async_grpo import (
+ _chunked_logprob_backbone,
+ _default_lora_target_modules,
+ _filter_async_grpo_kwargs,
+ _lora_config_from_env,
+ _model_context_limit,
+ _patch_chunked_lm_head_for_wrapped_causal_lm,
+ _patch_lora_weight_streaming,
+ _patch_peft_lora_resume_tp_sharding_for_ddp,
+ _select_task_indices,
+ _sanitize_lora_merged_weight_name,
+)
+
+
+def test_select_task_indices_uses_offset_stride_and_repeat() -> None:
+ selected = _select_task_indices(
+ total_tasks=20,
+ max_tasks=3,
+ task_offset=2,
+ task_stride=4,
+ task_indices_raw="",
+ repeat_tasks=2,
+ )
+ assert selected == [2, 6, 10, 2, 6, 10]
+
+
+def test_select_task_indices_prefers_explicit_indices() -> None:
+ selected = _select_task_indices(
+ total_tasks=20,
+ max_tasks=5,
+ task_offset=0,
+ task_stride=1,
+ task_indices_raw="16,3,16",
+ repeat_tasks=1,
+ )
+ assert selected == [16, 3, 16]
+
+
+def test_select_task_indices_rejects_out_of_range_index() -> None:
+ with pytest.raises(IndexError):
+ _select_task_indices(
+ total_tasks=5,
+ max_tasks=2,
+ task_offset=0,
+ task_stride=1,
+ task_indices_raw="0,5",
+ repeat_tasks=1,
+ )
+
+
+def test_model_context_limit_prefers_first_positive_candidate(
+ monkeypatch: pytest.MonkeyPatch,
+) -> None:
+ class _Cfg:
+ max_position_embeddings = 40960
+ model_max_length = 32768
+
+ monkeypatch.setattr(
+ train_mod.AutoConfig,
+ "from_pretrained",
+ lambda model_name: _Cfg(),
+ )
+
+ assert _model_context_limit("Qwen/Qwen3-8B") == 40960
+
+
+def test_model_context_limit_reads_nested_text_config(
+ monkeypatch: pytest.MonkeyPatch,
+) -> None:
+ class _TextCfg:
+ max_position_embeddings = 262144
+
+ class _Cfg:
+ max_position_embeddings = None
+ model_max_length = None
+ max_seq_len = None
+ seq_length = None
+ text_config = _TextCfg()
+
+ monkeypatch.setattr(
+ train_mod.AutoConfig,
+ "from_pretrained",
+ lambda model_name: _Cfg(),
+ )
+
+ assert _model_context_limit("Qwen/Qwen3.5-4B") == 262144
+
+
+def test_model_context_limit_rejects_missing_values(
+ monkeypatch: pytest.MonkeyPatch,
+) -> None:
+ class _Cfg:
+ max_position_embeddings = None
+ model_max_length = None
+ max_seq_len = None
+ seq_length = None
+
+ monkeypatch.setattr(
+ train_mod.AutoConfig,
+ "from_pretrained",
+ lambda model_name: _Cfg(),
+ )
+
+ with pytest.raises(ValueError):
+ _model_context_limit("missing-context-model")
+
+
+def test_default_lora_target_modules_for_qwen() -> None:
+ assert _default_lora_target_modules("Qwen/Qwen3-14B") == (
+ "q_proj",
+ "k_proj",
+ "v_proj",
+ "o_proj",
+ "gate_proj",
+ "up_proj",
+ "down_proj",
+ )
+
+
+def test_filter_async_grpo_kwargs_accepts_ddp_lora_checkpointing_controls() -> None:
+ filtered = _filter_async_grpo_kwargs(
+ {
+ "gradient_checkpointing_kwargs": {"use_reentrant": False},
+ "ddp_find_unused_parameters": False,
+ "not_a_training_arg": True,
+ }
+ )
+
+ assert filtered["gradient_checkpointing_kwargs"] == {"use_reentrant": False}
+ assert filtered["ddp_find_unused_parameters"] is False
+ assert "not_a_training_arg" not in filtered
+
+
+def test_sanitize_lora_merged_weight_name_strips_wrapper_paths() -> None:
+ assert (
+ _sanitize_lora_merged_weight_name(
+ "base_model.model.model.layers.0.self_attn.q_proj.base_layer.weight"
+ )
+ == "model.layers.0.self_attn.q_proj.weight"
+ )
+ assert (
+ _sanitize_lora_merged_weight_name(
+ "base_model.model.model.layers.0.self_attn.q_proj.lora_A.default.weight"
+ )
+ is None
+ )
+
+
+def test_lora_config_from_env_returns_full_when_disabled(
+ monkeypatch: pytest.MonkeyPatch,
+) -> None:
+ monkeypatch.delenv("SWE_LORA", raising=False)
+ config, summary = _lora_config_from_env("Qwen/Qwen3-8B")
+ assert config is None
+ assert summary == "full"
+
+
+def test_chunked_logprob_backbone_uses_decoder_for_wrapped_causal_lm() -> None:
+ cfg = Qwen3Config(
+ vocab_size=128,
+ hidden_size=64,
+ intermediate_size=128,
+ num_hidden_layers=2,
+ num_attention_heads=4,
+ num_key_value_heads=4,
+ )
+ base_model = Qwen3ForCausalLM(cfg)
+ peft_model = get_peft_model(
+ base_model,
+ LoraConfig(
+ task_type=TaskType.CAUSAL_LM,
+ r=4,
+ lora_alpha=8,
+ target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
+ ),
+ )
+
+ assert _chunked_logprob_backbone(base_model) is base_model.model
+ assert _chunked_logprob_backbone(peft_model) is peft_model.model.model
+
+
+def test_patch_chunked_lm_head_for_wrapped_causal_lm_handles_peft_model() -> None:
+ cfg = Qwen3Config(
+ vocab_size=128,
+ hidden_size=64,
+ intermediate_size=128,
+ num_hidden_layers=2,
+ num_attention_heads=4,
+ num_key_value_heads=4,
+ )
+ peft_model = get_peft_model(
+ Qwen3ForCausalLM(cfg),
+ LoraConfig(
+ task_type=TaskType.CAUSAL_LM,
+ r=4,
+ lora_alpha=8,
+ target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
+ ),
+ )
+ _patch_chunked_lm_head_for_wrapped_causal_lm(
+ peft_model,
+ temperature=1.0,
+ chunk_size=32,
+ )
+
+ input_ids = torch.randint(0, cfg.vocab_size, (2, 8))
+ attention_mask = torch.ones_like(input_ids)
+ outputs = peft_model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ labels=input_ids,
+ completion_mask=attention_mask,
+ use_cache=False,
+ )
+
+ assert outputs["log_probs"].shape == (2, 7)
+ assert outputs["entropy"].shape == (2, 7)
+
+
+def test_patch_lora_weight_streaming_unwraps_distributed_model() -> None:
+ cfg = Qwen3Config(
+ vocab_size=128,
+ hidden_size=64,
+ intermediate_size=128,
+ num_hidden_layers=2,
+ num_attention_heads=4,
+ num_key_value_heads=4,
+ )
+ peft_model = get_peft_model(
+ Qwen3ForCausalLM(cfg),
+ LoraConfig(
+ task_type=TaskType.CAUSAL_LM,
+ r=4,
+ lora_alpha=8,
+ target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
+ ),
+ )
+
+ class _Wrapped(torch.nn.Module):
+ def __init__(self, module: torch.nn.Module) -> None:
+ super().__init__()
+ self.module = module
+
+ class _Accelerator:
+ device = torch.device("cpu")
+
+ @staticmethod
+ def unwrap_model(model: torch.nn.Module) -> torch.nn.Module:
+ return model.module
+
+ class _Trainer:
+ model = _Wrapped(peft_model)
+ accelerator = _Accelerator()
+
+ trainer = _Trainer()
+ _patch_lora_weight_streaming(trainer) # type: ignore[arg-type]
+ streamed = dict(trainer._streaming_iter()) # type: ignore[attr-defined]
+
+ assert any(name.endswith("self_attn.q_proj.weight") for name in streamed)
+ assert all(".lora_" not in name for name in streamed)
+
+
+def test_patch_peft_lora_resume_tp_sharding_skips_plain_ddp_model(
+ monkeypatch: pytest.MonkeyPatch,
+) -> None:
+ import peft.utils.save_and_load as save_and_load
+
+ def _raise_if_called(
+ model: torch.nn.Module,
+ state_dict: dict[str, torch.Tensor],
+ adapter_name: str,
+ ) -> None:
+ raise AssertionError("TP sharding should not run for plain DDP models")
+
+ monkeypatch.setattr(
+ save_and_load,
+ "_maybe_shard_state_dict_for_tp",
+ _raise_if_called,
+ )
+
+ model = torch.nn.Linear(2, 2)
+ _patch_peft_lora_resume_tp_sharding_for_ddp(model)
+ save_and_load._maybe_shard_state_dict_for_tp(model, {}, "default")
+
+
+def test_patch_peft_lora_resume_tp_sharding_preserves_tp_model(
+ monkeypatch: pytest.MonkeyPatch,
+) -> None:
+ import peft.utils.save_and_load as save_and_load
+
+ called = False
+
+ def _mark_called(
+ model: torch.nn.Module,
+ state_dict: dict[str, torch.Tensor],
+ adapter_name: str,
+ ) -> None:
+ nonlocal called
+ called = True
+
+ monkeypatch.setattr(
+ save_and_load,
+ "_maybe_shard_state_dict_for_tp",
+ _mark_called,
+ )
+
+ class _Base(torch.nn.Module):
+ def __init__(self) -> None:
+ super().__init__()
+ self._hf_tp_plan = "colwise"
+ self._hf_device_mesh = object()
+
+ class _LoraLike(torch.nn.Module):
+ def __init__(self) -> None:
+ super().__init__()
+ self.base = _Base()
+
+ def get_base_layer(self) -> torch.nn.Module:
+ return self.base
+
+ model = _LoraLike()
+ _patch_peft_lora_resume_tp_sharding_for_ddp(model)
+ save_and_load._maybe_shard_state_dict_for_tp(model, {}, "default")
+
+ assert called