diff --git a/envs/mini_swe_env/grading.py b/envs/mini_swe_env/grading.py index d09f46a4c..d0c6c120d 100644 --- a/envs/mini_swe_env/grading.py +++ b/envs/mini_swe_env/grading.py @@ -4,16 +4,16 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -"""Binary grading for SWE-Gym tasks. +"""Grading helpers for SWE-Gym tasks. -This module is intentionally SWE-Gym-native and does not depend on -external repo/version parser maps. +Supports both the canonical binary SWE-Gym reward and an optional +case-fraction shaping mode for RL experiments: -The grading contract matches SWE-Gym semantics: - -- reward = 1.0 iff every FAIL_TO_PASS test passes AND every PASS_TO_PASS - test still passes. -- reward = 0.0 otherwise. +- ``binary``: reward = 1.0 iff every FAIL_TO_PASS test passes AND every + PASS_TO_PASS test still passes, else 0.0. +- ``case_fraction``: reward = passed_cases / total_cases, while + ``resolved`` still requires every FAIL_TO_PASS and PASS_TO_PASS case to + pass. 
""" from __future__ import annotations @@ -32,6 +32,7 @@ class GradeResult: __slots__ = ( "reward", + "case_fraction", "resolved", "patch_applied", "tests_status", @@ -42,12 +43,14 @@ def __init__( self, *, reward: float, + case_fraction: float, resolved: bool, patch_applied: bool, tests_status: dict[str, Any] | None = None, instance_id: str = "", ): self.reward = reward + self.case_fraction = case_fraction self.resolved = resolved self.patch_applied = patch_applied self.tests_status = tests_status @@ -55,7 +58,8 @@ def __init__( def __repr__(self) -> str: return ( - f"GradeResult(reward={self.reward}, resolved={self.resolved}, " + f"GradeResult(reward={self.reward}, case_fraction={self.case_fraction}, " + f"resolved={self.resolved}, " f"patch_applied={self.patch_applied}, instance_id={self.instance_id!r})" ) @@ -65,6 +69,7 @@ def grade_from_case_results( case_results: dict[str, bool], *, patch_applied: bool = True, + reward_mode: str = "binary", ) -> GradeResult: """Grade directly from per-test-case outcomes. @@ -72,14 +77,24 @@ def grade_from_case_results( task: SWE-Gym task with FAIL_TO_PASS and PASS_TO_PASS lists. case_results: Mapping ``test_case -> passed``. patch_applied: Whether test patch was successfully applied. + reward_mode: ``"binary"`` or ``"case_fraction"``. 
""" if not isinstance(case_results, dict): raise GradingError("case_results must be a dict[str, bool]") + if reward_mode not in {"binary", "case_fraction"}: + raise GradingError( + f"Unknown reward_mode {reward_mode!r}; expected 'binary' or 'case_fraction'" + ) def _passed(case: str) -> bool: return bool(case_results.get(case, False)) + all_cases = list(dict.fromkeys([*task.FAIL_TO_PASS, *task.PASS_TO_PASS])) + passed_cases = sum(1 for case in all_cases if _passed(case)) + total_cases = len(all_cases) + case_fraction = (passed_cases / total_cases) if total_cases else 0.0 + resolved = all(_passed(case) for case in task.FAIL_TO_PASS) and all( _passed(case) for case in task.PASS_TO_PASS ) @@ -95,8 +110,13 @@ def _passed(case: str) -> bool: }, } + reward = 1.0 if resolved else 0.0 + if reward_mode == "case_fraction": + reward = case_fraction + return GradeResult( - reward=1.0 if resolved else 0.0, + reward=reward, + case_fraction=case_fraction, resolved=resolved, patch_applied=patch_applied, tests_status=tests_status, diff --git a/envs/mini_swe_env/harness.py b/envs/mini_swe_env/harness.py index e00255d1f..17ec2f815 100644 --- a/envs/mini_swe_env/harness.py +++ b/envs/mini_swe_env/harness.py @@ -50,11 +50,12 @@ import json import queue as _queue_mod import logging +import os import shlex import time import uuid -from dataclasses import dataclass, field -from typing import Any, Literal +from dataclasses import dataclass, field, replace +from typing import Any, Callable, Literal from openenv.core.harness import Message, ResourceSessionFactory, VerifyResult from openenv.core.harness.agents import get_agent_spec @@ -79,6 +80,32 @@ VERIFY_TIMEOUT_S = 300 SETUP_TIMEOUT_S = 600 +DEFAULT_GIT_CHECKOUT_TIMEOUT_S = 300 + + +def _git_checkout_timeout_s() -> int: + for name in ("SWE_GIT_CHECKOUT_TIMEOUT_S", "SWE_LOCAL_GIT_CHECKOUT_TIMEOUT_S"): + value = os.environ.get(name) + if value is None: + continue + try: + timeout_s = int(value) + except ValueError: + _log.warning( + "%s 
must be an integer; using %ss", + name, + DEFAULT_GIT_CHECKOUT_TIMEOUT_S, + ) + return DEFAULT_GIT_CHECKOUT_TIMEOUT_S + if timeout_s > 0: + return timeout_s + _log.warning( + "%s must be positive; using %ss", + name, + DEFAULT_GIT_CHECKOUT_TIMEOUT_S, + ) + return DEFAULT_GIT_CHECKOUT_TIMEOUT_S + return DEFAULT_GIT_CHECKOUT_TIMEOUT_S _ANSWER_TOOL_DEFINITION = { "type": "function", @@ -101,43 +128,103 @@ {problem_statement} +{hints_block} + # Task Instructions ## Overview -You're a software engineer working on a codebase at /testbed. +You're a software engineer working on a codebase at {workdir}. Your task is to fix the issue described in the PR description above by making changes to the source code (non-test files). ## Important Boundaries -- MODIFY: Regular source code files in /testbed +- MODIFY: Regular source code files in {workdir} - DO NOT MODIFY: Tests, configuration files (pyproject.toml, setup.cfg, etc.) +- Test edits do not help. Grading may restore evaluation tests before running. +- Plain text does not execute anything. To inspect files, run commands, edit + code, or submit a fix, you must use your available tools. +- The repo environment is already bootstrapped. Do not create a new virtualenv + or reinstall the project unless a command clearly shows it is necessary. +- Do not use `git commit`, `git branch`, or `git push`. Grading only checks the + final working tree state. +- Each bash tool call runs in a fresh shell. If commands depend on shell state, + combine them into one bash invocation; `source` and shell variables do not + persist across separate tool calls. +- Do not claim a file was changed or a test passed unless you actually used a + tool and observed that result. ## Recommended Workflow -1. Analyze the codebase by finding and reading relevant files -2. Create a script to reproduce the issue -3. Edit the source code to resolve the issue -4. Verify your fix works by running your script again -5. 
Test edge cases to ensure your fix is robust - -## Submitting Your Answer -When you've completed your work and verified your fix, call the `answer` -tool to submit your solution for grading. This runs the test suite and -returns whether the issue is resolved. - -You cannot continue working after submitting — make sure your fix is -tested before calling `answer`. +1. Start with the maintainer hints or issue description. If they point to a + likely source file or function, inspect that before broad repo-wide searches. + If the issue mentions an identifier or exact error text, grep for that first. +2. Reproduce the issue with focused commands instead of exhaustive greps over + unrelated tests. +3. Edit the source code to resolve the issue. +4. Verify the fix with targeted commands, then check for obvious regressions. +5. If the `answer` tool is available, call it only after you have actually + changed source code and verified the result. + +{submission_block} """ -def _wrap_instruction(problem_statement: str) -> str: +def _bool_env(name: str, default: bool) -> bool: + raw = os.environ.get(name) + if raw is None: + return default + return raw.strip().lower() in {"1", "true", "yes", "on"} + + +def _answer_tool_enabled() -> bool: + return _bool_env("SWE_ENABLE_ANSWER_TOOL", True) + + +def _wrap_instruction( + problem_statement: str, + *, + hints_text: str = "", + workdir: str = TESTBED, + answer_tool_enabled: bool = True, +) -> str: """Wrap a problem statement with SWE-Gym-style task instructions. Tells the agent about the workflow, boundaries, and crucially about the ``answer`` tool for submission. 
""" + hints = (hints_text or "").strip() + hints_block = "" + if hints: + hints_block = ( + "\n" + "Additional context from issue triage or maintainers:\n" + f"{hints}\n" + "" + ) + if answer_tool_enabled: + submission_block = ( + "## Submitting Your Answer\n" + "When you've completed your work and verified your fix, call the " + "`answer`\n" + "tool to submit your solution for grading. This runs the test suite and\n" + "returns whether the issue is resolved.\n\n" + "You cannot continue working after submitting — make sure your fix is\n" + "tested before calling `answer`." + ) + else: + submission_block = ( + "## Ending The Run\n" + "There is no `answer` tool in this run.\n" + "Keep working until you have made and checked the best source-code " + "fix you can.\n" + "Your final repo state will be graded automatically when the session " + "ends." + ) return _SWE_INSTRUCTION_TEMPLATE.format( problem_statement=problem_statement, + hints_block=hints_block, + workdir=workdir, + submission_block=submission_block, ) @@ -197,6 +284,9 @@ def __init__( self._answer_reward_source: str | None = None self._answer_called = False self._answer_bridged = False + self._fallback_grader: ( + Callable[..., tuple[float, bool]] | None + ) = None @property def swe_task(self) -> SWETask: @@ -294,7 +384,9 @@ def verify( t0 = time.time() try: r = self.sandbox.exec( - cmd, cwd=TESTBED, timeout=self._verify_timeout_s + cmd, + cwd=self.config.workdir, + timeout=self._verify_timeout_s, ) detail = { "cmd": cmd, @@ -332,7 +424,37 @@ def verify( }, ) - # 4. No reward source — agent didn't call answer, no verify cmds. + # 4. Final-state fallback: grade the sandbox even if the agent forgot + # to call answer(). This preserves valid reward signal for training. 
+ if self._fallback_grader is not None: + try: + reward, resolved = self._fallback_grader( + self.sandbox, + self._swe_task, + home=self.config.sandbox_home, + workdir=self.config.workdir, + ) + return VerifyResult( + env_reward=float(reward), + done=True, + metrics={ + "instance_id": self._swe_task.instance_id, + "reward_source": "host_verify_fallback", + "resolved": bool(resolved), + "answer_called": False, + "answer_bridged": False, + }, + artifacts={ + "task_id": self._swe_task.task_id, + }, + ) + except Exception: + _log.exception( + "fallback grading failed for %s", + self._swe_task.instance_id, + ) + + # 5. No reward source — agent didn't call answer, no verify cmds. return VerifyResult( env_reward=0.0, done=True, @@ -476,6 +598,10 @@ def create( swe_task = coerce_swe_task(task) validate_swe_task(swe_task) + backend_supports_images = bool( + getattr(self._backend, "supports_images", True) + ) + requested_image = swe_task.sandbox_image if backend_supports_images else None sandbox_timeout = int(self._config.agent_timeout_s) + 600 sandbox = self._backend.create( timeout_s=sandbox_timeout, @@ -484,17 +610,31 @@ def create( if episode_id else {"instance_id": swe_task.instance_id} ), - image=swe_task.sandbox_image, + image=requested_image, + ) + + session_config = replace( + self._config, + sandbox_home=self._resolve_sandbox_home(sandbox), + workdir=self._resolve_workdir(sandbox), ) try: - if not swe_task.sandbox_image: - self._prepare_repo(sandbox, swe_task) + if not requested_image: + self._prepare_repo(sandbox, swe_task, workdir=session_config.workdir) + self._bootstrap_local_repo_env( + sandbox, + swe_task, + config=session_config, + ) - self._run_setup(sandbox, swe_task) + self._run_setup(sandbox, swe_task, workdir=session_config.workdir) - agent_task = self._build_agent_task(swe_task) - self._driver._bootstrap_sandbox(sandbox, agent_task, self._config) + agent_task = self._build_agent_task( + swe_task, + workdir=session_config.workdir, + ) + 
self._driver._bootstrap_sandbox(sandbox, agent_task, session_config) except Exception as exc: _log.error("SWESessionFactory.create: bootstrap failed: %r", exc) @@ -516,9 +656,15 @@ def create( rollout_id, ) - agent_task = self._build_agent_task(swe_task) + agent_task = self._build_agent_task( + swe_task, + workdir=session_config.workdir, + ) agent_bg = self._driver._start_agent( - sandbox, agent_task, self._config, base_url_override=base_url_override + sandbox, + agent_task, + session_config, + base_url_override=base_url_override, ) session = SWESession( @@ -527,7 +673,7 @@ def create( spec=self._spec, sandbox=sandbox, task=agent_task, - config=self._config, + config=session_config, base_url_override=base_url_override, agent_bg_job=agent_bg, interception_server=self._interception_server, @@ -535,19 +681,38 @@ def create( interception_queue=interception_queue, ) - if self._mode == "interception_gate": + if self._mode == "interception_gate" and _answer_tool_enabled(): self._register_answer_tool(session) + session._fallback_grader = self._grade_answer_submission return session # ── Bootstrap helpers ────────────────────────────────────────────────── - def _prepare_repo(self, sandbox: SandboxHandle, task: SWETask) -> None: + def _resolve_sandbox_home(self, sandbox: SandboxHandle) -> str: + home = getattr(sandbox, "sandbox_home", None) + if isinstance(home, str) and home.strip(): + return home + return self._config.sandbox_home + + def _resolve_workdir(self, sandbox: SandboxHandle) -> str: + workdir = getattr(sandbox, "workdir", None) + if isinstance(workdir, str) and workdir.strip(): + return workdir + return self._config.workdir + + def _prepare_repo( + self, + sandbox: SandboxHandle, + task: SWETask, + *, + workdir: str, + ) -> None: """Clone the repo and reset to base_commit.""" - sandbox.exec(f"mkdir -p {TESTBED}", timeout=10) + sandbox.exec(f"mkdir -p {shlex.quote(workdir)}", timeout=10) clone_url = f"https://github.com/{task.repo}.git" r = sandbox.exec( - f"git 
clone --quiet {clone_url} {TESTBED}", + f"git clone --quiet {clone_url} {shlex.quote(workdir)}", timeout=SETUP_TIMEOUT_S, ) if r.exit_code != 0: @@ -556,32 +721,94 @@ def _prepare_repo(self, sandbox: SandboxHandle, task: SWETask) -> None: ) r = sandbox.exec( f"git checkout --quiet {task.base_commit}", - cwd=TESTBED, - timeout=60, + cwd=workdir, + timeout=_git_checkout_timeout_s(), ) if r.exit_code != 0: raise RuntimeError( f"git checkout failed (exit {r.exit_code}): {r.stderr[:500]}" ) - def _run_setup(self, sandbox: SandboxHandle, task: SWETask) -> None: + def _run_setup( + self, + sandbox: SandboxHandle, + task: SWETask, + *, + workdir: str, + ) -> None: """Run task setup commands in the workspace.""" for cmd in task.setup: - r = sandbox.exec(cmd, cwd=TESTBED, timeout=SETUP_TIMEOUT_S) + r = sandbox.exec(cmd, cwd=workdir, timeout=SETUP_TIMEOUT_S) if r.exit_code != 0: raise RuntimeError( f"Setup command failed (exit {r.exit_code}): " f"{cmd[:120]}\nstderr: {(r.stderr or '')[:500]}" ) - def _build_agent_task(self, swe_task: SWETask) -> _SWEAgentTask: + def _bootstrap_local_repo_env( + self, + sandbox: SandboxHandle, + swe_task: SWETask, + *, + config: SWEAgentConfig, + ) -> None: + """Install repo/runtime deps when the backend cannot provide task images. + + SWE-Gym tasks usually rely on prebuilt per-task images. For rootless + local sandboxes we recreate just enough of that environment to run the + repeated-task pilot by installing the repo editable plus common test + dependencies inside the sandbox-local virtualenv. + """ + del swe_task + workdir_q = shlex.quote(config.workdir) + commands = [ + "python -m pip install -U pip setuptools wheel", + "python -m pip install pytest", + ( + f"cd {workdir_q} && (" + "python -m pip install -e .[all] || " + "python -m pip install -e .[tests] || " + "python -m pip install -e .[test] || " + "python -m pip install -e . 
|| " + "python -m pip install .)" + ), + ] + if sandbox.exists(f"{config.workdir}/requirements-tests.txt"): + commands.append( + f"cd {workdir_q} && python -m pip install -r requirements-tests.txt" + ) + elif sandbox.exists(f"{config.workdir}/requirements-test.txt"): + commands.append( + f"cd {workdir_q} && python -m pip install -r requirements-test.txt" + ) + + for cmd in commands: + result = sandbox.exec(cmd, cwd=config.workdir, timeout=SETUP_TIMEOUT_S) + if result.exit_code != 0: + raise RuntimeError( + "local sandbox repo bootstrap failed " + f"(exit {result.exit_code}): {(result.stderr or result.stdout)[-500:]}" + ) + + def _build_agent_task( + self, + swe_task: SWETask, + *, + workdir: str, + ) -> _SWEAgentTask: """Convert SWETask into the shape CLIAgentDriver expects. Wraps the raw problem statement with SWE-Gym-style instructions that tell the agent about the ``answer`` tool. """ + answer_tool_enabled = _answer_tool_enabled() return _SWEAgentTask( - instruction=_wrap_instruction(swe_task.instruction), + instruction=_wrap_instruction( + swe_task.instruction, + hints_text=str((swe_task.metadata or {}).get("hints_text", "") or ""), + workdir=workdir, + answer_tool_enabled=answer_tool_enabled, + ), setup_shell=None, metadata={ "task_id": swe_task.task_id, @@ -614,6 +841,8 @@ async def _answer_handler(arguments: dict[str, Any]) -> dict[str, Any]: self._grade_answer_submission, session.sandbox, session.swe_task, + home=session.config.sandbox_home, + workdir=session.config.workdir, ) session.set_answer_reward(reward, source="host_answer_tool") return { @@ -635,14 +864,26 @@ def _grade_answer_submission( self, sandbox: SandboxHandle, swe_task: SWETask, + *, + home: str, + workdir: str, ) -> tuple[float, bool]: """Compute answer-tool reward on host and return ``(reward, resolved)``.""" try: metadata = swe_task.metadata or {} required = {"version", "patch", "test_patch", "FAIL_TO_PASS"} if required.issubset(metadata): - return 
self._grade_with_swegym_metadata(sandbox, swe_task) - return self._grade_with_verify_commands(sandbox, swe_task) + return self._grade_with_swegym_metadata( + sandbox, + swe_task, + home=home, + workdir=workdir, + ) + return self._grade_with_verify_commands( + sandbox, + swe_task, + workdir=workdir, + ) except Exception: _log.exception("answer-tool grading failed for %s", swe_task.instance_id) return 0.0, False @@ -651,6 +892,9 @@ def _grade_with_swegym_metadata( self, sandbox: SandboxHandle, swe_task: SWETask, + *, + home: str, + workdir: str, ) -> tuple[float, bool]: """Grade SWE-Gym tasks directly from FAIL/PASS test-case outcomes.""" metadata = swe_task.metadata @@ -672,33 +916,56 @@ def _grade_with_swegym_metadata( ) touched_files = self._extract_paths_from_test_patch(gym_task.test_patch) + changed_test_like_files = self._list_changed_test_paths( + sandbox, + workdir=workdir, + ) + files_to_restore = sorted(set(touched_files) | set(changed_test_like_files)) self._revert_test_files( sandbox, base_commit=swe_task.base_commit, - paths=touched_files, + paths=files_to_restore, strict=True, + workdir=workdir, ) - self._apply_test_patch(sandbox, gym_task.test_patch) - case_results = self._run_swegym_case_tests(sandbox, gym_task) - grade = grade_from_case_results(gym_task, case_results) + self._apply_test_patch(sandbox, gym_task.test_patch, home=home, workdir=workdir) + case_results = self._run_swegym_case_tests( + sandbox, + gym_task, + workdir=workdir, + ) + grade = grade_from_case_results( + gym_task, + case_results, + reward_mode=os.environ.get("SWE_REWARD_MODE", "binary").strip().lower() + or "binary", + ) # Best-effort cleanup in case grading was interrupted. 
self._revert_test_files( sandbox, base_commit=swe_task.base_commit, - paths=touched_files, + paths=files_to_restore, strict=False, + workdir=workdir, ) return float(grade.reward), bool(grade.resolved) - def _apply_test_patch(self, sandbox: SandboxHandle, test_patch: str) -> None: - patch_path = f"{HOME}/.openenv_swe_test_patch.diff" + def _apply_test_patch( + self, + sandbox: SandboxHandle, + test_patch: str, + *, + home: str, + workdir: str, + ) -> None: + patch_path = f"{home}/.openenv_swe_test_patch.diff" sandbox.write_text(patch_path, test_patch) result = sandbox.exec( f"git apply --whitespace=nowarn {shlex.quote(patch_path)}", - cwd=TESTBED, + cwd=workdir, timeout=30, ) if result.exit_code != 0: @@ -711,6 +978,8 @@ def _run_swegym_case_tests( self, sandbox: SandboxHandle, gym_task: SWEGymTask, + *, + workdir: str, ) -> dict[str, bool]: cases: list[str] = [] seen: set[str] = set() @@ -723,7 +992,7 @@ def _run_swegym_case_tests( results: dict[str, bool] = {} for case in cases: cmd = f"python -m pytest -q --maxfail=1 {shlex.quote(case)}" - run = sandbox.exec(cmd, cwd=TESTBED, timeout=self._verify_timeout_s) + run = sandbox.exec(cmd, cwd=workdir, timeout=self._verify_timeout_s) results[case] = run.exit_code == 0 return results @@ -731,13 +1000,15 @@ def _grade_with_verify_commands( self, sandbox: SandboxHandle, swe_task: SWETask, + *, + workdir: str, ) -> tuple[float, bool]: """Legacy fallback for non-SWE-Gym tasks.""" if not swe_task.verify: return 0.0, False passed = 0 for cmd in swe_task.verify: - r = sandbox.exec(cmd, cwd=TESTBED, timeout=self._verify_timeout_s) + r = sandbox.exec(cmd, cwd=workdir, timeout=self._verify_timeout_s) if r.exit_code == 0: passed += 1 reward = passed / len(swe_task.verify) @@ -755,6 +1026,42 @@ def _extract_paths_from_test_patch(test_patch: str) -> list[str]: paths.append(path) return sorted(set(paths)) + @staticmethod + def _is_test_like_path(path: str) -> bool: + text = (path or "").strip().strip('"') + if not text: + return False 
+ normalized = text.replace("\\", "/") + basename = normalized.rsplit("/", 1)[-1] + if basename == "conftest.py": + return True + if basename.startswith("test_") or basename.endswith("_test.py"): + return True + parts = normalized.split("/") + return any(part in {"test", "tests", "testing"} for part in parts[:-1]) + + @classmethod + def _list_changed_test_paths( + cls, + sandbox: SandboxHandle, + *, + workdir: str, + ) -> list[str]: + commands = ( + "git diff --name-only HEAD --", + "git ls-files --others --exclude-standard", + ) + paths: set[str] = set() + for cmd in commands: + result = sandbox.exec(cmd, cwd=workdir, timeout=10) + if result.exit_code != 0: + continue + for line in (result.stdout or "").splitlines(): + path = line.strip() + if cls._is_test_like_path(path): + paths.add(path) + return sorted(paths) + def _revert_test_files( self, sandbox: SandboxHandle, @@ -762,6 +1069,7 @@ def _revert_test_files( base_commit: str, paths: list[str], strict: bool, + workdir: str, ) -> None: if not paths: return @@ -770,7 +1078,7 @@ def _revert_test_files( for path in paths: has_file = sandbox.exec( f"git cat-file -e {shlex.quote(f'{base_commit}:{path}')}", - cwd=TESTBED, + cwd=workdir, timeout=10, ) if has_file.exit_code == 0: @@ -781,7 +1089,7 @@ def _revert_test_files( else: cmd = f"rm -f -- {shlex.quote(path)}" - result = sandbox.exec(cmd, cwd=TESTBED, timeout=20) + result = sandbox.exec(cmd, cwd=workdir, timeout=20) if result.exit_code != 0: failures.append( f"{path}: {(result.stderr or result.stdout or '').strip()}" diff --git a/envs/mini_swe_env/server/swe_environment.py b/envs/mini_swe_env/server/swe_environment.py index 916d6f938..c6b21341f 100644 --- a/envs/mini_swe_env/server/swe_environment.py +++ b/envs/mini_swe_env/server/swe_environment.py @@ -29,6 +29,7 @@ import json import logging +import os import shlex import time from pathlib import Path @@ -83,6 +84,32 @@ MCP_PORT = 8765 VERIFY_TIMEOUT_S = 300 SETUP_TIMEOUT_S = 600 +DEFAULT_GIT_CHECKOUT_TIMEOUT_S 
= 300 + + +def _git_checkout_timeout_s() -> int: + for name in ("SWE_GIT_CHECKOUT_TIMEOUT_S", "SWE_LOCAL_GIT_CHECKOUT_TIMEOUT_S"): + value = os.environ.get(name) + if value is None: + continue + try: + timeout_s = int(value) + except ValueError: + _log.warning( + "%s must be an integer; using %ss", + name, + DEFAULT_GIT_CHECKOUT_TIMEOUT_S, + ) + return DEFAULT_GIT_CHECKOUT_TIMEOUT_S + if timeout_s > 0: + return timeout_s + _log.warning( + "%s must be positive; using %ss", + name, + DEFAULT_GIT_CHECKOUT_TIMEOUT_S, + ) + return DEFAULT_GIT_CHECKOUT_TIMEOUT_S + return DEFAULT_GIT_CHECKOUT_TIMEOUT_S # Path to the sandbox_mcp_server.py source alongside this module. _SANDBOX_MCP_SERVER_SOURCE = Path(__file__).parent / "sandbox_mcp_server.py" @@ -509,7 +536,14 @@ def _grade_submission( ) self._apply_test_patch(sandbox, gym_task.test_patch) case_results = self._run_swegym_case_tests(sandbox, gym_task) - return grade_from_case_results(gym_task, case_results) + return grade_from_case_results( + gym_task, + case_results, + reward_mode=( + os.environ.get("SWE_REWARD_MODE", "binary").strip().lower() + or "binary" + ), + ) except Exception as exc: _log.warning("SWE-Gym grading failed, falling back: %s", exc) return None @@ -651,7 +685,7 @@ def _prepare_repo(self, sandbox: Any, task: SWETask) -> None: r = sandbox.exec( f"git checkout --quiet {task.base_commit}", cwd=TESTBED, - timeout=60, + timeout=_git_checkout_timeout_s(), ) if r.exit_code != 0: raise RuntimeError( diff --git a/examples/mini_swe_env/async_grpo/rollout_worker.py b/examples/mini_swe_env/async_grpo/rollout_worker.py index f46cd4df8..545824b00 100644 --- a/examples/mini_swe_env/async_grpo/rollout_worker.py +++ b/examples/mini_swe_env/async_grpo/rollout_worker.py @@ -3,23 +3,21 @@ Implements ``RolloutWorkerProtocol`` from TRL's ``AsyncGRPOTrainer``. 
Architecture: - Pi (sandbox) → InterceptionServer → this worker → vLLM /v1/completions + Pi (sandbox) → InterceptionServer → this worker → vLLM /v1/chat/completions ← chat response back to Pi Pi drives the generation loop inside the sandbox. This worker: 1. Dequeues each intercepted LLM request from Pi. -2. Tokenizes the messages with ``apply_chat_template``. -3. Calls vLLM ``/v1/completions`` with ``prompt=token_ids``, - ``return_token_ids=True``, ``logprobs=0`` — same as TRL's own - ``AsyncRolloutWorker._generate_one_turn``. -4. Gets exact ``completion_ids`` and ``completion_logprobs`` from vLLM. -5. Wraps the completion text as a chat response and delivers it back to Pi. -6. Tracks multi-turn token sequences matching TRL's pattern: +2. Forwards the intercepted request to vLLM ``/v1/chat/completions``. +3. Requests exact ``prompt_token_ids`` / ``completion_ids`` and per-token + logprobs from vLLM. +4. Delivers the OpenAI-compatible chat response back to Pi. +5. Tracks multi-turn token sequences matching TRL's pattern: ``input_ids = initial_prompt_ids + [turn_ids + suffix_ids]*N`` ``completion_mask = [0]*prompt + [1]*turn + [0]*suffix + ...`` -7. On ``answer()``, bridges to host-side grading. -8. Assembles the final ``RolloutSample`` and pushes to ``rollout_buffer``. +6. On ``answer()``, bridges to host-side grading. +7. Assembles the final ``RolloutSample`` and pushes to ``rollout_buffer``. 
""" from __future__ import annotations @@ -27,12 +25,15 @@ import asyncio import json import logging +import math +import os import queue +import re import threading import time import uuid from dataclasses import dataclass, field -from typing import Any, Iterator, Sequence, cast +from typing import Any, Callable, Iterator, Sequence, cast import requests @@ -95,17 +96,36 @@ class RolloutSample: metrics: dict[str, Any] = field(default_factory=dict) +@dataclass +class PendingRollout: + """One rollout before group-relative advantage normalization.""" + + input_ids: list[int] + completion_mask: list[int] + old_log_probs: list[float] + reward: float + model_version: int + metrics: dict[str, Any] = field(default_factory=dict) + + # ── Config ───────────────────────────────────────────────────────────── @dataclass(frozen=True) class WorkerConfig: max_inflight: int = 2 + max_rollout_attempts: int = 4 + num_generations: int = 4 queue_maxsize: int = 64 request_timeout_s: float = 600.0 max_turns: int = 50 + max_model_len: int = 4096 max_completion_tokens: int = 2048 temperature: float = 1.0 + max_tool_message_chars: int = 6000 + min_tool_message_chars: int = 256 + max_assistant_message_chars: int = 4000 + min_assistant_message_chars: int = 256 # After returning a terminal plain-text response (finish_reason=stop, # no tool_calls), wait briefly for a follow-up request before treating # the rollout as complete. 
This avoids 600s stalls when agent exit @@ -113,6 +133,7 @@ class WorkerConfig: post_response_grace_s: float = 10.0 stop_on_idle_terminal_response: bool = True idle_backoff_s: float = 0.5 + failure_backoff_s: float = 30.0 # ── Worker ───────────────────────────────────────────────────────────── @@ -148,6 +169,12 @@ def __init__( self._vllm_api_key = vllm_api_key self._vllm_model = vllm_model self._cfg = config or WorkerConfig() + if not self._tasks: + raise ValueError("SWERolloutWorker requires at least one SWE task") + if self._cfg.num_generations < 2: + raise ValueError( + "WorkerConfig.num_generations must be >= 2 for valid GRPO grouping" + ) self.rollout_buffer: queue.Queue[RolloutSample] = queue.Queue( maxsize=self._cfg.queue_maxsize, @@ -162,6 +189,12 @@ def __init__( self._model_version = 0 self._started = False self._model_update_group: Any | None = None + self._next_group_id = 0 + self._current_group_task: SWETask | None = None + self._current_group_id: int | None = None + self._current_group_model_version: int | None = None + self._group_replica_idx = 0 + self._pending_groups: dict[int, list[PendingRollout]] = {} # Prefix to prepend to trainer parameter names so they match vLLM's # model architecture. For VLM models like Qwen3_5ForConditionalGeneration @@ -228,6 +261,7 @@ def send_weights(self, iterator: Iterator[tuple[str, Any]]) -> None: ) return + items = [(self._vllm_weight_name(name), tensor) for name, tensor in items] names = [name for name, _ in items] # VLM models (e.g. 
Qwen3_5ForConditionalGeneration) have parameters at @@ -302,6 +336,16 @@ def update_model_version(self, version: int) -> None: with self._lock: self._model_version = version + def _vllm_weight_name(self, name: str) -> str: + """Map trainer-side text model names to vLLM's served module names.""" + model_id = self._vllm_model.lower().replace("_", "") + if "qwen3.5" in model_id or "qwen35" in model_id: + if name.startswith("model."): + return f"language_model.{name}" + if name.startswith("lm_head."): + return f"language_model.{name}" + return name + def _post_json( self, path: str, @@ -324,6 +368,15 @@ def _post_json( return response def _init_weight_transfer(self) -> None: + if os.environ.get("SWE_DISABLE_WEIGHT_TRANSFER", "").strip().lower() in { + "1", + "true", + "yes", + "on", + }: + _log.warning("weight sync disabled by SWE_DISABLE_WEIGHT_TRANSFER") + return + if ( NCCLWeightTransferEngine is None or NCCLTrainerSendWeightsArgs is None @@ -403,16 +456,96 @@ def _destroy_model_update_group(self) -> None: # ── Internal ─────────────────────────────────────────────────── - def _next_task(self) -> SWETask: + def _next_rollout_assignment(self) -> tuple[SWETask, int, int, int]: with self._lock: - t = self._tasks[self._task_idx % len(self._tasks)] - self._task_idx += 1 - return t + if self._group_replica_idx == 0: + self._current_group_task = self._tasks[self._task_idx % len(self._tasks)] + self._task_idx += 1 + self._current_group_id = self._next_group_id + self._current_group_model_version = self._model_version + self._next_group_id += 1 + + assert self._current_group_task is not None + assert self._current_group_id is not None + assert self._current_group_model_version is not None + + task = self._current_group_task + group_id = self._current_group_id + model_version = self._current_group_model_version + replica_idx = self._group_replica_idx + + self._group_replica_idx += 1 + if self._group_replica_idx >= self._cfg.num_generations: + self._group_replica_idx = 0 + 
self._current_group_task = None + self._current_group_id = None + self._current_group_model_version = None + + return task, group_id, replica_idx, model_version def _model_ver(self) -> int: with self._lock: return self._model_version + def _collect_group_sample( + self, + *, + group_id: int, + rollout: PendingRollout, + ) -> list[RolloutSample] | None: + with self._lock: + bucket = self._pending_groups.setdefault(group_id, []) + bucket.append(rollout) + if len(bucket) < self._cfg.num_generations: + return None + if len(bucket) > self._cfg.num_generations: + raise RuntimeError( + f"received too many rollouts for group {group_id}: " + f"{len(bucket)} > {self._cfg.num_generations}" + ) + group_rollouts = list(bucket) + del self._pending_groups[group_id] + + return self._finalize_group_rollouts( + group_id=group_id, + rollouts=group_rollouts, + ) + + def _finalize_group_rollouts( + self, + *, + group_id: int, + rollouts: Sequence[PendingRollout], + ) -> list[RolloutSample]: + rewards = [float(rollout.reward) for rollout in rollouts] + advantages, reward_mean, reward_std = _compute_group_advantages(rewards) + _log.info( + "rollout group complete: group_id=%d reward_mean=%.4f reward_std=%.4f rewards=%s", + group_id, + reward_mean, + reward_std, + [round(reward, 4) for reward in rewards], + ) + + samples: list[RolloutSample] = [] + for rollout, advantage in zip(rollouts, advantages, strict=True): + metrics = dict(rollout.metrics) + metrics["reward"] = float(rollout.reward) + metrics["reward_mean"] = reward_mean + metrics["reward_std"] = reward_std + metrics["group_size"] = float(len(rollouts)) + samples.append( + RolloutSample( + input_ids=rollout.input_ids, + completion_mask=rollout.completion_mask, + old_log_probs=rollout.old_log_probs, + advantage=advantage, + model_version=rollout.model_version, + metrics=metrics, + ) + ) + return samples + def _loop(self, idx: int) -> None: while not self._stop.is_set(): while self._pause.is_set() and not self._stop.is_set(): @@ 
-420,27 +553,138 @@ def _loop(self, idx: int) -> None: if self._stop.is_set(): return - task = self._next_task() - eid = f"swe-{idx}-{uuid.uuid4().hex[:8]}" - try: - sample = asyncio.run(self._rollout(task, eid)) - except Exception: - _log.exception("rollout failed worker=%d id=%s", idx, task.instance_id) - time.sleep(self._cfg.idle_backoff_s) - continue + task, group_id, replica_idx, model_version = self._next_rollout_assignment() + eid = f"swe-{idx}-g{group_id}-r{replica_idx}-{uuid.uuid4().hex[:8]}" + rollout_t0 = time.time() + failed = False + sample: PendingRollout | None = None + last_exc: Exception | None = None + + for attempt in range(1, self._cfg.max_rollout_attempts + 1): + try: + sample = asyncio.run( + self._rollout(task, eid, model_version=model_version) + ) + last_exc = None + break + except Exception as exc: + last_exc = exc + retriable = ( + attempt < self._cfg.max_rollout_attempts + and _is_retriable_rollout_error(exc) + ) + if retriable: + _log.warning( + "rollout failed worker=%d id=%s attempt=%d/%d; retrying: %s", + idx, + task.instance_id, + attempt, + self._cfg.max_rollout_attempts, + exc, + ) + time.sleep(self._cfg.failure_backoff_s * attempt) + continue + + _log.exception( + "rollout failed worker=%d id=%s attempts=%d/%d", + idx, + task.instance_id, + attempt, + self._cfg.max_rollout_attempts, + ) + sample = self._build_failed_rollout( + task=task, + elapsed_s=time.time() - rollout_t0, + exc=exc, + model_version=model_version, + ) + failed = True + break + + if sample is None: + if last_exc is None: + time.sleep(self._cfg.idle_backoff_s) + continue + raise RuntimeError( + f"rollout loop exhausted without sample for {task.instance_id}" + ) from last_exc if sample is None: time.sleep(self._cfg.idle_backoff_s) continue - try: - self.rollout_buffer.put(sample, timeout=2.0) - except queue.Full: - _log.warning("queue full, dropping %s", task.instance_id) + group_samples = self._collect_group_sample( + group_id=group_id, + rollout=sample, + ) + if 
group_samples is not None: + for group_sample in group_samples: + try: + self.rollout_buffer.put(group_sample, timeout=2.0) + except queue.Full: + _log.warning("queue full, dropping group %d", group_id) + break + + if failed: + time.sleep(self._cfg.failure_backoff_s) + + def _build_failed_rollout( + self, + *, + task: SWETask, + elapsed_s: float, + exc: Exception, + model_version: int, + ) -> PendingRollout: + """Return a zero-reward rollout when infrastructure fails before rollout. + + Remote sandbox providers can transiently refuse new jobs. Producing a + neutral sample keeps distributed trainer ranks moving together instead + of letting nonzero ranks block in Accelerate's dataloader broadcast. + """ + prompt_ids = self._render_prompt_ids( + [{"role": "user", "content": task.instruction}], + None, + ) + completion_id = ( + getattr(self._tokenizer, "eos_token_id", None) + or getattr(self._tokenizer, "pad_token_id", None) + or 0 + ) + input_ids = [*prompt_ids, int(completion_id)] + completion_mask = [0] * len(prompt_ids) + [1] + old_log_probs = [0.0] * len(input_ids) + exc_name = type(exc).__name__ + return PendingRollout( + input_ids=input_ids, + completion_mask=completion_mask, + old_log_probs=old_log_probs, + reward=0.0, + model_version=model_version, + metrics={ + "reward": 0.0, + "turns": 0.0, + "answer_called": 0.0, + "terminal_idle_stop": 0.0, + "wall_s": round(elapsed_s, 3), + "n_tokens": float(len(input_ids)), + "rollout_error": 1.0, + "sandbox_create_error": float( + "sandbox" in exc_name.lower() + or "sandbox" in str(exc).lower() + ), + }, + ) # ── Single rollout ───────────────────────────────────────────── - async def _rollout(self, task: SWETask, episode_id: str) -> RolloutSample | None: + async def _rollout( + self, + task: SWETask, + episode_id: str, + *, + model_version: int, + ) -> PendingRollout | None: session = self._factory.create(task=task, episode_id=episode_id) # Accumulate the full token sequence across turns, matching TRL's @@ -467,9 
+711,20 @@ async def _rollout(self, task: SWETask, episode_id: str) -> RolloutSample | None intercept = pending_intercept pending_intercept = None else: - intercept = await session.next_request( - timeout_s=self._cfg.request_timeout_s, - ) + try: + intercept = await session.next_request( + timeout_s=self._cfg.request_timeout_s, + ) + except TimeoutError: + if turns == 0: + raise + rollout_stop_reason = "request_idle_timeout" + _log.info( + "rollout request idle-timeout: instance_id=%s turns=%d", + task.instance_id, + turns, + ) + break if intercept is None: rollout_stop_reason = "agent_exit_detected" break @@ -477,7 +732,19 @@ async def _rollout(self, task: SWETask, episode_id: str) -> RolloutSample | None # ── Tokenize this turn's full prompt ────────────── messages = _get_messages(intercept) tools = _get_tools(intercept) - current_prompt_ids = self._render_prompt_ids(messages, tools) + ( + current_prompt_ids, + turn_ids, + turn_lps, + chat_resp, + _finish_reason, + ) = self._generate( + intercept=intercept, + messages=messages, + tools=tools, + ) + if not current_prompt_ids: + current_prompt_ids = self._render_prompt_ids(messages, tools) if initial_prompt_ids is None: # First turn: the entire prompt is non-completion tokens. 
@@ -498,11 +765,6 @@ async def _rollout(self, task: SWETask, episode_id: str) -> RolloutSample | None turns += 1 - # ── Generate via /v1/completions ────────────────── - turn_ids, turn_lps, text, finish_reason = self._generate( - current_prompt_ids - ) - all_ids.extend(turn_ids) all_mask.extend([1] * len(turn_ids)) all_lps.extend(turn_lps) @@ -510,18 +772,6 @@ async def _rollout(self, task: SWETask, episode_id: str) -> RolloutSample | None # For next turn's suffix computation: prev_prompt_ids = current_prompt_ids + turn_ids - # ── Build chat response for Pi ──────────────────── - assistant_message = _parse_assistant_message( - tokenizer=self._tokenizer, - completion_ids=turn_ids, - fallback_text=text, - ) - chat_resp = _make_chat_response( - assistant_message, - self._vllm_model, - finish_reason=finish_reason, - ) - # ── Check for answer tool call ──────────────────── if _has_answer_call(chat_resp): answer_called = True @@ -571,12 +821,12 @@ async def _rollout(self, task: SWETask, episode_id: str) -> RolloutSample | None all_mask = [1] all_lps = [0.0] - return RolloutSample( + return PendingRollout( input_ids=all_ids, completion_mask=all_mask, old_log_probs=all_lps, - advantage=reward, - model_version=self._model_ver(), + reward=reward, + model_version=model_version, metrics={ "reward": reward, "turns": float(turns), @@ -588,6 +838,9 @@ async def _rollout(self, task: SWETask, episode_id: str) -> RolloutSample | None "agent_exit_after_terminal_stop", } ), + "request_idle_timeout": float( + rollout_stop_reason == "request_idle_timeout" + ), "wall_s": round(time.time() - t0, 3), "n_tokens": float(len(all_ids)), }, @@ -595,7 +848,7 @@ async def _rollout(self, task: SWETask, episode_id: str) -> RolloutSample | None finally: session.close() - # ── vLLM call (matches TRL's _generate_one_turn exactly) ────── + # ── vLLM call ───────────────────────────────────────────────── def _render_prompt_ids( self, @@ -618,42 +871,143 @@ def _render_prompt_ids( def _generate( self, - 
prompt_ids: list[int], - ) -> tuple[list[int], list[float], str, str | None]: - """POST /v1/completions with token IDs. - - Returns: ``(token_ids, token_logprobs, text, finish_reason)``. - """ - body = { - "model": self._vllm_model, - "prompt": prompt_ids, - "max_tokens": self._cfg.max_completion_tokens, - "temperature": self._cfg.temperature, - "n": 1, - "return_token_ids": True, - "logprobs": 0, - } - resp = requests.post( - f"{self._vllm_base_url}/v1/completions", - headers={ - "Authorization": f"Bearer {self._vllm_api_key}", - "Content-Type": "application/json", - }, - json=body, - timeout=self._cfg.request_timeout_s, + *, + intercept: dict[str, Any], + messages: list[dict[str, Any]], + tools: list[dict[str, Any]] | None, + ) -> tuple[list[int], list[int], list[float], dict[str, Any], str | None]: + """POST /v1/chat/completions and return prompt/output tokens.""" + ( + messages, + rendered_prompt_ids, + ) = _fit_messages_to_context_window( + messages=messages, + tools=tools, + render_prompt_ids=self._render_prompt_ids, + requested_completion_tokens=self._cfg.max_completion_tokens, + max_model_len=self._cfg.max_model_len, + max_tool_message_chars=self._cfg.max_tool_message_chars, + min_tool_message_chars=self._cfg.min_tool_message_chars, + max_assistant_message_chars=self._cfg.max_assistant_message_chars, + min_assistant_message_chars=self._cfg.min_assistant_message_chars, + ) + raw_body = intercept.get("body") + body = dict(raw_body) if isinstance(raw_body, dict) else {} + body["model"] = self._vllm_model + body["messages"] = messages + if tools is not None: + body["tools"] = tools + else: + body.pop("tools", None) + body.pop("stream", None) + body.pop("stream_options", None) + body.pop("max_tokens", None) + body["max_completion_tokens"] = _clamp_max_completion_tokens( + prompt_len=len(rendered_prompt_ids), + requested=self._cfg.max_completion_tokens, + max_model_len=self._cfg.max_model_len, ) - if resp.status_code != 200: + body["temperature"] = 
self._cfg.temperature + body["n"] = 1 + body["logprobs"] = True + body["top_logprobs"] = 0 + body["return_token_ids"] = True + + if os.environ.get("SWE_LOG_PROMPT_TOKENS", "").strip().lower() in { + "1", + "true", + "yes", + "on", + }: + _log.info( + "rollout prompt window: prompt_tokens=%d max_completion_tokens=%d tools=%d", + len(rendered_prompt_ids), + int(body["max_completion_tokens"]), + len(tools or []), + ) + + chat_template_kwargs = body.get("chat_template_kwargs") + if not isinstance(chat_template_kwargs, dict): + chat_template_kwargs = {} + model_id = self._vllm_model.lower().replace("_", "") + if ( + ("qwen3" in model_id or "deepseek" in model_id) + and "enable_thinking" not in chat_template_kwargs + ): + chat_template_kwargs["enable_thinking"] = False + if chat_template_kwargs: + body["chat_template_kwargs"] = chat_template_kwargs + + if tools and not body.get("tool_choice"): + body["tool_choice"] = "auto" + + for attempt in range(2): + resp = requests.post( + f"{self._vllm_base_url}/v1/chat/completions", + headers={ + "Authorization": f"Bearer {self._vllm_api_key}", + "Content-Type": "application/json", + }, + json=body, + timeout=self._cfg.request_timeout_s, + ) + if resp.status_code == 200: + break + + retry_budget = _retry_completion_tokens_from_context_error(resp.text) + if attempt == 0 and retry_budget is not None: + requested = int(body.get("max_completion_tokens") or 0) + clamped_tokens, max_model_len, prompt_tokens = retry_budget + if clamped_tokens < requested: + _log.info( + "clamped max_completion_tokens from %d to %d " + "(prompt_tokens=%d max_model_len=%d)", + requested, + clamped_tokens, + prompt_tokens, + max_model_len, + ) + body["max_completion_tokens"] = clamped_tokens + continue + raise RuntimeError(f"vllm {resp.status_code}: {resp.text[:400]}") - choice = resp.json()["choices"][0] + payload = resp.json() + choice = payload["choices"][0] + turn_ids = _coerce_token_ids(choice.get("token_ids")) + choice["message"] = 
_normalize_chat_choice_message( + tokenizer=self._tokenizer, + choice=choice, + completion_ids=turn_ids, + ) + chat_resp = dict(payload) + chat_resp["choices"] = [choice] return ( - choice["token_ids"], - choice["logprobs"]["token_logprobs"], - choice.get("text", ""), + _extract_prompt_token_ids(payload) or rendered_prompt_ids, + turn_ids, + _extract_chat_choice_logprobs(choice, expected_len=len(turn_ids)), + chat_resp, choice.get("finish_reason"), ) +def _compute_group_advantages( + rewards: Sequence[float], + *, + eps: float = 1e-8, +) -> tuple[list[float], float, float]: + """Return z-scored group advantages, mean reward, and reward stddev.""" + if not rewards: + raise ValueError("rewards must not be empty") + + reward_mean = sum(rewards) / len(rewards) + reward_var = sum((reward - reward_mean) ** 2 for reward in rewards) / len(rewards) + reward_std = math.sqrt(reward_var) + denom = reward_std + eps + advantages = [float((reward - reward_mean) / denom) for reward in rewards] + return advantages, float(reward_mean), float(reward_std) + + # ── Helpers ──────────────────────────────────────────────────────────── @@ -679,6 +1033,286 @@ def _get_tools(intercept: dict[str, Any]) -> list[dict[str, Any]] | None: return None +def _coerce_token_ids(raw: Any) -> list[int]: + if not isinstance(raw, list): + return [] + token_ids: list[int] = [] + for token_id in raw: + try: + token_ids.append(int(token_id)) + except (TypeError, ValueError): + return [] + return token_ids + + +def _extract_prompt_token_ids(payload: dict[str, Any]) -> list[int]: + return _coerce_token_ids(payload.get("prompt_token_ids")) + + +def _extract_chat_choice_logprobs( + choice: dict[str, Any], + *, + expected_len: int, +) -> list[float]: + content = (choice.get("logprobs") or {}).get("content") or [] + values: list[float] = [] + if isinstance(content, list): + for item in content: + if not isinstance(item, dict): + values.append(0.0) + continue + raw = item.get("logprob") + values.append(float(raw) 
if isinstance(raw, (int, float)) else 0.0) + + if len(values) < expected_len: + values.extend([0.0] * (expected_len - len(values))) + return values[:expected_len] + + +def _clamp_max_completion_tokens( + *, + prompt_len: int, + requested: int, + max_model_len: int, + safety_margin: int = 16, +) -> int: + """Clamp completion tokens so prompt + generation fit in vLLM context.""" + available = max_model_len - prompt_len - safety_margin + return max(1, min(int(requested), int(available))) + + +_TRUNCATION_MARKER = "\n...[truncated]...\n" +_OMITTED_TOOL_OUTPUT_MARKER = "[tool output omitted]" +_OMITTED_ASSISTANT_TEXT_MARKER = "[omitted]" + + +def _truncate_text_middle( + text: str, + *, + max_chars: int, + marker: str = _TRUNCATION_MARKER, +) -> tuple[str, bool]: + if max_chars <= 0 or len(text) <= max_chars: + return text, False + + usable = max_chars - len(marker) + if usable <= 8: + return marker[: max(1, max_chars)], True + + head = max(1, int(usable * 0.75)) + tail = max(1, usable - head) + return text[:head].rstrip() + marker + text[-tail:].lstrip(), True + + +def _truncate_messages_for_prompt_budget( + messages: Sequence[dict[str, Any]], + *, + max_tool_message_chars: int, + max_assistant_message_chars: int, +) -> tuple[list[dict[str, Any]], int, int]: + truncated_messages: list[dict[str, Any]] = [] + tool_truncations = 0 + assistant_truncations = 0 + + for message in messages: + updated = dict(message) + content = updated.get("content") + if not isinstance(content, str): + truncated_messages.append(updated) + continue + + role = str(updated.get("role") or "") + if role == "tool": + content, changed = _truncate_text_middle( + content, + max_chars=max_tool_message_chars, + ) + if changed: + tool_truncations += 1 + elif role == "assistant": + content, changed = _truncate_text_middle( + content, + max_chars=max_assistant_message_chars, + ) + if changed: + assistant_truncations += 1 + + updated["content"] = content + truncated_messages.append(updated) + + return 
truncated_messages, tool_truncations, assistant_truncations + + +def _replace_oldest_message_content( + messages: Sequence[dict[str, Any]], + *, + role: str, + replacement: str, +) -> tuple[list[dict[str, Any]], bool]: + updated_messages = [dict(message) for message in messages] + for idx, message in enumerate(updated_messages): + if message.get("role") != role: + continue + content = message.get("content") + if not isinstance(content, str) or content == replacement: + continue + if len(replacement) >= len(content): + continue + message["content"] = replacement + updated_messages[idx] = message + return updated_messages, True + return updated_messages, False + + +def _fit_messages_to_context_window( + *, + messages: Sequence[dict[str, Any]], + tools: list[dict[str, Any]] | None, + render_prompt_ids: Callable[ + [list[dict[str, Any]], list[dict[str, Any]] | None], + list[int], + ], + requested_completion_tokens: int, + max_model_len: int, + max_tool_message_chars: int, + min_tool_message_chars: int, + max_assistant_message_chars: int, + min_assistant_message_chars: int, + safety_margin: int = 16, +) -> tuple[list[dict[str, Any]], list[int]]: + prompt_budget = max( + 1, + max_model_len - max(1, int(requested_completion_tokens)) - safety_margin, + ) + tool_char_budget = max(1, int(max_tool_message_chars)) + assistant_char_budget = max(1, int(max_assistant_message_chars)) + min_tool_chars = max(1, int(min_tool_message_chars)) + min_assistant_chars = max(1, int(min_assistant_message_chars)) + + base_messages = [dict(message) for message in messages] + prepared_messages = base_messages + prompt_ids = render_prompt_ids(prepared_messages, tools) + tool_truncations = 0 + assistant_truncations = 0 + + while True: + prepared_messages, tool_truncations, assistant_truncations = ( + _truncate_messages_for_prompt_budget( + base_messages, + max_tool_message_chars=tool_char_budget, + max_assistant_message_chars=assistant_char_budget, + ) + ) + prompt_ids = 
render_prompt_ids(prepared_messages, tools) + if len(prompt_ids) <= prompt_budget: + break + if tool_char_budget > min_tool_chars: + tool_char_budget = max(min_tool_chars, tool_char_budget // 2) + continue + if assistant_char_budget > min_assistant_chars: + assistant_char_budget = max( + min_assistant_chars, + assistant_char_budget // 2, + ) + continue + break + + omitted_tool_messages = 0 + omitted_assistant_messages = 0 + while len(prompt_ids) > prompt_budget: + prepared_messages, changed = _replace_oldest_message_content( + prepared_messages, + role="tool", + replacement=_OMITTED_TOOL_OUTPUT_MARKER, + ) + if changed: + omitted_tool_messages += 1 + prompt_ids = render_prompt_ids(prepared_messages, tools) + continue + + prepared_messages, changed = _replace_oldest_message_content( + prepared_messages, + role="assistant", + replacement=_OMITTED_ASSISTANT_TEXT_MARKER, + ) + if not changed: + break + omitted_assistant_messages += 1 + prompt_ids = render_prompt_ids(prepared_messages, tools) + + if ( + tool_truncations + or assistant_truncations + or omitted_tool_messages + or omitted_assistant_messages + ): + _log.info( + "trimmed intercepted prompt: prompt_tokens=%d/%d tool_truncations=%d " + "assistant_truncations=%d omitted_tool_messages=%d " + "omitted_assistant_messages=%d", + len(prompt_ids), + prompt_budget, + tool_truncations, + assistant_truncations, + omitted_tool_messages, + omitted_assistant_messages, + ) + + return prepared_messages, prompt_ids + + +_CONTEXT_LIMIT_RE = re.compile( + r"maximum context length is (?P<max_model_len>[\d,]+) tokens.*?" 
+ r"requested (?P<requested_tokens>[\d,]+) output tokens and your prompt contains " + r"at least (?P<prompt_tokens>[\d,]+) input tokens", + flags=re.IGNORECASE | re.DOTALL, +) + + +def _retry_completion_tokens_from_context_error( + error_text: str, + *, + safety_margin: int = 16, +) -> tuple[int, int, int] | None: + """Return a smaller completion budget after a vLLM context-window error.""" + match = _CONTEXT_LIMIT_RE.search(error_text or "") + if match is None: + return None + + max_model_len = int(match.group("max_model_len").replace(",", "")) + prompt_tokens = int(match.group("prompt_tokens").replace(",", "")) + clamped_tokens = max(1, max_model_len - prompt_tokens - safety_margin) + return clamped_tokens, max_model_len, prompt_tokens + + +def _is_context_window_error(exc: Exception) -> bool: + text = str(exc or "").lower() + return ( + "maximum context length" in text + or "parameter=input_tokens" in text + or "input tokens" in text + and "output tokens" in text + ) + + +def _is_retriable_rollout_error(exc: Exception) -> bool: + if _is_context_window_error(exc): + return False + + text = str(exc or "").lower() + retriable_markers = ( + "429", + "too many requests", + "timed out", + "timeout", + "connection reset", + "temporarily unavailable", + "tunnel failed", + "sandbox", + ) + return any(marker in text for marker in retriable_markers) + + def _normalize_tool_calls(raw_tool_calls: Any) -> list[dict[str, Any]]: normalized: list[dict[str, Any]] = [] if not isinstance(raw_tool_calls, list): @@ -714,6 +1348,80 @@ def _normalize_tool_calls(raw_tool_calls: Any) -> list[dict[str, Any]]: return normalized +_TOOL_CALL_XML_RE = re.compile( + r"<tool_call>\s*(.*?)\s*</tool_call>", + flags=re.DOTALL, +) + + +def _extract_xml_tool_calls(text: str) -> tuple[str, list[dict[str, Any]]]: + """Parse Qwen-style XML tool-call blocks from assistant text. 
+ + vLLM's Qwen3 XML parser can return tool calls only inside + ``message.content``: + + <tool_call> + {"name": "answer", "arguments": {}} + </tool_call> + + When that happens we need to recover structured ``tool_calls`` before + sending the response back to Pi, otherwise the harness treats the reply as + plain terminal text and the rollout dies after one turn. + """ + if not text: + return "", [] + + tool_calls: list[dict[str, Any]] = [] + cursor = 0 + content_parts: list[str] = [] + + for match in _TOOL_CALL_XML_RE.finditer(text): + start, end = match.span() + if start > cursor: + content_parts.append(text[cursor:start]) + + raw_block = match.group(1).strip() + parsed_block: Any + try: + parsed_block = json.loads(raw_block) + except json.JSONDecodeError: + content_parts.append(text[start:end]) + cursor = end + continue + + blocks = parsed_block if isinstance(parsed_block, list) else [parsed_block] + parsed_any = False + for block in blocks: + if not isinstance(block, dict): + continue + name = block.get("name") + if not isinstance(name, str) or not name: + continue + parsed_any = True + tool_calls.append( + { + "id": f"call_{uuid.uuid4().hex[:8]}", + "type": "function", + "function": { + "name": name, + "arguments": json.dumps( + block.get("arguments") or {}, + ensure_ascii=False, + ), + }, + } + ) + + if not parsed_any: + content_parts.append(text[start:end]) + cursor = end + + if cursor < len(text): + content_parts.append(text[cursor:]) + + return "".join(content_parts).strip(), tool_calls + + def _parse_assistant_message( *, tokenizer: Any, @@ -748,6 +1456,44 @@ def _parse_assistant_message( return message +def _normalize_chat_choice_message( + *, + tokenizer: Any, + choice: dict[str, Any], + completion_ids: list[int], +) -> dict[str, Any]: + raw_message = choice.get("message") + if isinstance(raw_message, dict): + message = dict(raw_message) + else: + message = {"role": "assistant", "content": ""} + + fallback_text = message.get("content") + if not 
isinstance(fallback_text, str): + 
fallback_text = "" + + parsed = _parse_assistant_message( + tokenizer=tokenizer, + completion_ids=completion_ids, + fallback_text=fallback_text, + ) + + if not isinstance(message.get("role"), str): + message["role"] = "assistant" + if not isinstance(message.get("content"), str): + message["content"] = parsed.get("content", "") + if not (message.get("tool_calls") or []): + tool_calls = parsed.get("tool_calls") or [] + if not tool_calls and fallback_text: + text_content, xml_tool_calls = _extract_xml_tool_calls(fallback_text) + if xml_tool_calls: + tool_calls = xml_tool_calls + message["content"] = text_content + if tool_calls: + message["tool_calls"] = tool_calls + return message + + def _make_chat_response( assistant_message: dict[str, Any], model: str, diff --git a/examples/mini_swe_env/async_grpo/sbatch_multinode_async_grpo.sh b/examples/mini_swe_env/async_grpo/sbatch_multinode_async_grpo.sh new file mode 100755 index 000000000..bc8027be8 --- /dev/null +++ b/examples/mini_swe_env/async_grpo/sbatch_multinode_async_grpo.sh @@ -0,0 +1,439 @@ +#!/usr/bin/env bash +#SBATCH --job-name=mini-swe-async-grpo +#SBATCH --partition=hopper-prod +#SBATCH --nodes=2 +#SBATCH --ntasks-per-node=1 +#SBATCH --gres=gpu:h100:1 +#SBATCH --cpus-per-task=16 +#SBATCH --time=08:00:00 + +set -euo pipefail + +SCRIPT_DIR=$(cd -- "$(dirname "${BASH_SOURCE[0]}")" && pwd) +DEFAULT_REPO_ROOT=${SLURM_SUBMIT_DIR:-} +if [[ -z "$DEFAULT_REPO_ROOT" ]]; then + DEFAULT_REPO_ROOT=$(cd -- "$SCRIPT_DIR/../../.." 
&& pwd) +fi +REPO_ROOT=${REPO_ROOT:-$DEFAULT_REPO_ROOT} +RUNS_ROOT=${RUNS_ROOT:-$REPO_ROOT/runs/mini_swe_async_grpo} +RUN_DIR=${RUN_DIR:-$RUNS_ROOT/${SLURM_JOB_ID:-manual}} +CLOUDFLARED=${CLOUDFLARED:-/fsx/benjamin_burtenshaw/bin/cloudflared} + +GPUS_PER_NODE=${GPUS_PER_NODE:-1} +CPUS_PER_TASK=${CPUS_PER_TASK:-${SLURM_CPUS_PER_TASK:-16}} +JOB_PORT_BASE=${JOB_PORT_BASE:-$((20000 + (${SLURM_JOB_ID:-0} % 10000)))} +VLLM_PORT=${VLLM_PORT:-$JOB_PORT_BASE} +INTERCEPTION_PORT=${INTERCEPTION_PORT:-$((JOB_PORT_BASE + 1000))} +MASTER_PORT=${MASTER_PORT:-$((JOB_PORT_BASE + 2000))} + +SWE_MODEL=${SWE_MODEL:-Qwen/Qwen3-1.7B} +SWE_SANDBOX_BACKEND=${SWE_SANDBOX_BACKEND:-hf} +SWE_AGENT=${SWE_AGENT:-pi} +MAX_MODEL_LEN=${MAX_MODEL_LEN:-} +GPU_MEMORY_UTILIZATION=${GPU_MEMORY_UTILIZATION:-0.80} +MAX_TASKS=${MAX_TASKS:-4} +MAX_STEPS=${MAX_STEPS:-3} +MAX_TURNS=${MAX_TURNS:-30} +SWE_ROLLOUT_MAX_INFLIGHT=${SWE_ROLLOUT_MAX_INFLIGHT:-} +SWE_TRAIN_DTYPE=${SWE_TRAIN_DTYPE:-bf16} +SWE_TORCH_EMPTY_CACHE_STEPS=${SWE_TORCH_EMPTY_CACHE_STEPS:-1} +SWE_LORA=${SWE_LORA:-0} +SWE_LORA_R=${SWE_LORA_R:-16} +SWE_LORA_ALPHA=${SWE_LORA_ALPHA:-} +SWE_LORA_DROPOUT=${SWE_LORA_DROPOUT:-0.0} +SWE_LORA_TARGET_MODULES=${SWE_LORA_TARGET_MODULES:-} +SWE_LORA_BIAS=${SWE_LORA_BIAS:-none} +SWE_LORA_USE_RSLORA=${SWE_LORA_USE_RSLORA:-0} + +if [[ -z "$MAX_MODEL_LEN" ]]; then + MAX_MODEL_LEN=$( + SWE_MODEL="$SWE_MODEL" "$REPO_ROOT/.venv/bin/python" - <<'PY' 2>/dev/null || true +from transformers import AutoConfig +import os + +cfg = AutoConfig.from_pretrained(os.environ["SWE_MODEL"]) +for section in (cfg, getattr(cfg, "text_config", None), getattr(cfg, "llm_config", None)): + if section is None: + continue + for key in ("max_position_embeddings", "model_max_length", "max_seq_len", "seq_length"): + value = section.get(key) if isinstance(section, dict) else getattr(section, key, None) + if isinstance(value, int) and value > 0: + print(value) + raise SystemExit(0) +PY + ) +fi +MAX_MODEL_LEN=${MAX_MODEL_LEN:-16384} + +if [[ -z 
"$SWE_ROLLOUT_MAX_INFLIGHT" ]]; then + if [[ "$SWE_SANDBOX_BACKEND" == "hf" ]]; then + SWE_ROLLOUT_MAX_INFLIGHT=1 + else + SWE_ROLLOUT_MAX_INFLIGHT=2 + fi +fi + +MODEL_LOWER=$(printf '%s' "$SWE_MODEL" | tr '[:upper:]' '[:lower:]') +VLLM_TOOL_CALL_PARSER=${VLLM_TOOL_CALL_PARSER:-} +if [[ -z "$VLLM_TOOL_CALL_PARSER" ]]; then + case "$MODEL_LOWER" in + *qwen3*coder*) + VLLM_TOOL_CALL_PARSER=qwen3_coder + ;; + *qwen3*) + VLLM_TOOL_CALL_PARSER=qwen3_xml + ;; + esac +fi + +SWE_CUDA_HOME=${SWE_CUDA_HOME:-} +if [[ -z "$SWE_CUDA_HOME" ]]; then + case "$MODEL_LOWER" in + *qwen3.5*|*qwen3_5*) + if [[ -d /usr/local/cuda-12.4 ]]; then + SWE_CUDA_HOME=/usr/local/cuda-12.4 + fi + ;; + esac +fi + +SWE_VLLM_GDN_PREFILL_BACKEND=${SWE_VLLM_GDN_PREFILL_BACKEND:-} +if [[ -z "$SWE_VLLM_GDN_PREFILL_BACKEND" ]]; then + case "$MODEL_LOWER" in + *qwen3.5*|*qwen3_5*) + SWE_VLLM_GDN_PREFILL_BACKEND=triton + ;; + esac +fi +SWE_VLLM_EXTRA_ARGS=${SWE_VLLM_EXTRA_ARGS:-} + +mkdir -p "$RUN_DIR/home" +cd "$REPO_ROOT" + +if [[ -z "${HF_TOKEN:-}" ]]; then + if [[ -s /admin/home/benjamin_burtenshaw/.cache/huggingface/token ]]; then + HF_TOKEN="$(< /admin/home/benjamin_burtenshaw/.cache/huggingface/token)" + elif [[ -s /fsx/benjamin_burtenshaw/.cache/huggingface/token ]]; then + HF_TOKEN="$(< /fsx/benjamin_burtenshaw/.cache/huggingface/token)" + fi +fi +if [[ "$SWE_SANDBOX_BACKEND" == "hf" && -z "${HF_TOKEN:-}" ]]; then + echo "ERROR: HF_TOKEN is required for HF sandbox creation" >&2 + exit 2 +fi + +INTERCEPTION_AUTH_TOKEN=${INTERCEPTION_AUTH_TOKEN:-$(python3 - <<'PY' +import secrets +print(secrets.token_urlsafe(32)) +PY +)} + +mapfile -t ALLOC_NODES < <(scontrol show hostnames "$SLURM_JOB_NODELIST") +if (( ${#ALLOC_NODES[@]} < 2 )); then + echo "ERROR: expected at least 2 allocated nodes, got ${#ALLOC_NODES[@]}" >&2 + exit 3 +fi + +VLLM_NODE=${ALLOC_NODES[0]} +TRAINER_NODES=("${ALLOC_NODES[@]:1}") +TRAINER_MASTER=${TRAINER_NODES[0]} +TRAINER_NODE_COUNT=${#TRAINER_NODES[@]} +TRAINER_NODELIST=$(IFS=,; 
echo "${TRAINER_NODES[*]}") +TOTAL_TRAINER_PROCS=$((TRAINER_NODE_COUNT * GPUS_PER_NODE)) +VLLM_URL="http://${VLLM_NODE}:${VLLM_PORT}" +GPU_IDS=$(seq -s, 0 $((GPUS_PER_NODE - 1))) + +export RUN_DIR REPO_ROOT GPUS_PER_NODE VLLM_PORT MASTER_PORT +export MAX_MODEL_LEN GPU_MEMORY_UTILIZATION MAX_TASKS MAX_STEPS MAX_TURNS +export TRAINER_NODE_COUNT TOTAL_TRAINER_PROCS TRAINER_MASTER VLLM_URL GPU_IDS +export PYTHONPATH="$REPO_ROOT/src:$REPO_ROOT/envs" +export PYTHONUNBUFFERED=1 +export TRL_EXPERIMENTAL_SILENCE=1 +export HF_HOME=/fsx/benjamin_burtenshaw/.cache/huggingface +export HF_HUB_CACHE=/fsx/benjamin_burtenshaw/.cache/huggingface/hub +export HF_DATASETS_CACHE=/fsx/benjamin_burtenshaw/.cache/huggingface/datasets +export HF_HUB_ENABLE_HF_TRANSFER=1 +export VLLM_CACHE_ROOT=/fsx/benjamin_burtenshaw/.cache/vllm +export XDG_CACHE_HOME=/fsx/benjamin_burtenshaw/.cache +export TORCH_HOME=/fsx/benjamin_burtenshaw/.cache/torch +export TRITON_CACHE_DIR=/fsx/benjamin_burtenshaw/.cache/triton +export FLASHINFER_CACHE_DIR=/fsx/benjamin_burtenshaw/.cache/flashinfer +export FLASHINFER_WORKSPACE_BASE=${FLASHINFER_WORKSPACE_BASE:-/tmp/${USER}/flashinfer-workspace-${SLURM_JOB_ID:-manual}} +export PYTHONDONTWRITEBYTECODE=${PYTHONDONTWRITEBYTECODE:-1} +export PYTHONPYCACHEPREFIX=${PYTHONPYCACHEPREFIX:-/tmp/${USER}/openenv-pycache} +export VLLM_NO_USAGE_STATS=1 +export HF_TOKEN +export HOME="$RUN_DIR/home" +export SWE_MODEL +export SWE_SANDBOX_BACKEND +export SWE_AGENT +export VLLM_API_KEY=${VLLM_API_KEY:-token} +export INTERCEPTION_HOST=0.0.0.0 +export INTERCEPTION_PORT +export INTERCEPTION_AUTH_TOKEN +export SWE_TRACKIO_SPACE_ID=${SWE_TRACKIO_SPACE_ID:-${TRACKIO_SPACE_ID:-burtenshaw/swe-grpo-dashboard}} +export SWE_TRACKIO_PROJECT=${SWE_TRACKIO_PROJECT:-swe-async-grpo} +export SWE_CHECKPOINT_TO_HUB=${SWE_CHECKPOINT_TO_HUB:-0} +export SWE_HUB_MODEL_ID=${SWE_HUB_MODEL_ID:-} +export SWE_HUB_PRIVATE_REPO=${SWE_HUB_PRIVATE_REPO:-1} +export 
SWE_CHECKPOINT_SAVE_STEPS=${SWE_CHECKPOINT_SAVE_STEPS:-2} +export SWE_CHECKPOINT_SAVE_TOTAL_LIMIT=${SWE_CHECKPOINT_SAVE_TOTAL_LIMIT:-2} +export SWE_RESUME_FROM_CHECKPOINT=${SWE_RESUME_FROM_CHECKPOINT:-none} +export SWE_IGNORE_DATA_SKIP=${SWE_IGNORE_DATA_SKIP:-1} +export SWE_ASYNC_MAX_STALENESS=${SWE_ASYNC_MAX_STALENESS:-16} +export SWE_ASYNC_WEIGHT_SYNC_STEPS=${SWE_ASYNC_WEIGHT_SYNC_STEPS:-1} +export SWE_ASYNC_MAX_INFLIGHT_TASKS=${SWE_ASYNC_MAX_INFLIGHT_TASKS:-$SWE_ROLLOUT_MAX_INFLIGHT} +export SWE_ASYNC_QUEUE_MAXSIZE=${SWE_ASYNC_QUEUE_MAXSIZE:-64} +export SWE_ROLLOUT_MAX_INFLIGHT +export SWE_VLLM_MAX_MODEL_LEN=${SWE_VLLM_MAX_MODEL_LEN:-$MAX_MODEL_LEN} +export SWE_TRAIN_DTYPE SWE_TORCH_EMPTY_CACHE_STEPS +export SWE_LORA SWE_LORA_R SWE_LORA_ALPHA SWE_LORA_DROPOUT +export SWE_LORA_TARGET_MODULES SWE_LORA_BIAS SWE_LORA_USE_RSLORA +export SWE_HF_SANDBOX_CREATE_RETRIES=${SWE_HF_SANDBOX_CREATE_RETRIES:-6} +export SWE_HF_SANDBOX_CREATE_BACKOFF_S=${SWE_HF_SANDBOX_CREATE_BACKOFF_S:-20} +export OPENENV_HF_SANDBOX_URL_TIMEOUT_S=${OPENENV_HF_SANDBOX_URL_TIMEOUT_S:-120} +export SWE_ROLLOUT_QUEUE_TIMEOUT_S=${SWE_ROLLOUT_QUEUE_TIMEOUT_S:-900} +export SWE_ROLLOUT_REQUEST_TIMEOUT_S=${SWE_ROLLOUT_REQUEST_TIMEOUT_S:-120} +export SWE_ROLLOUT_FAILURE_BACKOFF_S=${SWE_ROLLOUT_FAILURE_BACKOFF_S:-30} +export SWE_ROLLOUT_MAX_ATTEMPTS=${SWE_ROLLOUT_MAX_ATTEMPTS:-4} +export SWE_GIT_CHECKOUT_TIMEOUT_S=${SWE_GIT_CHECKOUT_TIMEOUT_S:-300} +export SWE_DISABLE_WEIGHT_TRANSFER=${SWE_DISABLE_WEIGHT_TRANSFER:-0} +export VLLM_TOOL_CALL_PARSER +export SWE_CUDA_HOME SWE_VLLM_GDN_PREFILL_BACKEND SWE_VLLM_EXTRA_ARGS +export NCCL_DEBUG=${NCCL_DEBUG:-WARN} +export TORCH_NCCL_ASYNC_ERROR_HANDLING=${TORCH_NCCL_ASYNC_ERROR_HANDLING:-1} +export PYTORCH_CUDA_ALLOC_CONF=${PYTORCH_CUDA_ALLOC_CONF:-expandable_segments:True} +export OMP_NUM_THREADS=${OMP_NUM_THREADS:-4} +export OPENENV_LOCAL_SANDBOX_ROOT=${OPENENV_LOCAL_SANDBOX_ROOT:-/tmp/${USER}/openenv-local-sandboxes-${SLURM_JOB_ID:-manual}} +export 
OPENENV_LOCAL_SANDBOX_PRESERVE=${OPENENV_LOCAL_SANDBOX_PRESERVE:-0} +if [[ -n "$SWE_CUDA_HOME" ]]; then + export CUDA_HOME="$SWE_CUDA_HOME" + export CUDA_PATH="$SWE_CUDA_HOME" + export PATH="$SWE_CUDA_HOME/bin:$PATH" + export LD_LIBRARY_PATH="$SWE_CUDA_HOME/lib64:${LD_LIBRARY_PATH:-}" +fi +mkdir -p "$OPENENV_LOCAL_SANDBOX_ROOT" "$FLASHINFER_WORKSPACE_BASE" + +VLLM_STEP_PID= +CLOUDFLARED_PID= +cleanup() { + local rc=$? + if [[ -n "${CLOUDFLARED_PID:-}" ]]; then + kill "$CLOUDFLARED_PID" >/dev/null 2>&1 || true + fi + if [[ -n "${VLLM_STEP_PID:-}" ]]; then + kill "$VLLM_STEP_PID" >/dev/null 2>&1 || true + wait "$VLLM_STEP_PID" >/dev/null 2>&1 || true + fi + exit "$rc" +} +trap cleanup EXIT + +echo "job_start=$(date -Is)" +echo "run_dir=$RUN_DIR" +echo "nodes=${ALLOC_NODES[*]}" +echo "vllm_node=$VLLM_NODE trainer_nodes=$TRAINER_NODELIST trainer_master=$TRAINER_MASTER" +echo "ports=vllm:$VLLM_PORT interception:$INTERCEPTION_PORT master:$MASTER_PORT" +echo "model=$SWE_MODEL agent=$SWE_AGENT parser=${VLLM_TOOL_CALL_PARSER:-none} sandbox=$SWE_SANDBOX_BACKEND dtype=$SWE_TRAIN_DTYPE lora=$SWE_LORA max_model_len=$MAX_MODEL_LEN max_tasks=$MAX_TASKS max_steps=$MAX_STEPS max_turns=$MAX_TURNS inflight=$SWE_ROLLOUT_MAX_INFLIGHT staleness=$SWE_ASYNC_MAX_STALENESS preserve_local=$OPENENV_LOCAL_SANDBOX_PRESERVE cuda_home=${CUDA_HOME:-none} gdn_prefill_backend=${SWE_VLLM_GDN_PREFILL_BACKEND:-default}" + +: > "$RUN_DIR/vllm.log" +: > "$RUN_DIR/trainer.log" +: > "$RUN_DIR/cloudflared.log" + +echo "starting_vllm=$(date -Is)" +srun \ + --nodes=1 \ + --ntasks=1 \ + --nodelist="$VLLM_NODE" \ + --cpus-per-task="$CPUS_PER_TASK" \ + --gres="gpu:h100:${GPUS_PER_NODE}" \ + --kill-on-bad-exit=1 \ + bash -lc ' + set -euo pipefail + cd "$REPO_ROOT" + mkdir -p "$PYTHONPYCACHEPREFIX" + export CUDA_VISIBLE_DEVICES="$GPU_IDS" + export VLLM_SERVER_DEV_MODE=1 + TOOL_ARGS=() + if [[ -n "${VLLM_TOOL_CALL_PARSER:-}" ]]; then + TOOL_ARGS+=(--enable-auto-tool-choice --tool-call-parser 
"$VLLM_TOOL_CALL_PARSER") + fi + GDN_ARGS=() + if [[ -n "${SWE_VLLM_GDN_PREFILL_BACKEND:-}" ]]; then + GDN_ARGS+=(--gdn-prefill-backend "$SWE_VLLM_GDN_PREFILL_BACKEND") + fi + EXTRA_ARGS=() + if [[ -n "${SWE_VLLM_EXTRA_ARGS:-}" ]]; then + read -r -a EXTRA_ARGS <<< "$SWE_VLLM_EXTRA_ARGS" + fi + exec .venv/bin/vllm serve "$SWE_MODEL" \ + --tensor-parallel-size "$GPUS_PER_NODE" \ + --max-model-len "$MAX_MODEL_LEN" \ + --host 0.0.0.0 \ + --port "$VLLM_PORT" \ + --api-key "$VLLM_API_KEY" \ + --gpu-memory-utilization "$GPU_MEMORY_UTILIZATION" \ + --logprobs-mode processed_logprobs \ + --weight-transfer-config '\''{"backend":"nccl"}'\'' \ + "${TOOL_ARGS[@]}" \ + "${GDN_ARGS[@]}" \ + "${EXTRA_ARGS[@]}" + ' > "$RUN_DIR/vllm.log" 2>&1 & +VLLM_STEP_PID=$! + +VLLM_BIND_MARKER="http://0.0.0.0:${VLLM_PORT}" +VLLM_READY_MARKER="Application startup complete." +for _ in $(seq 1 900); do + if grep -Fq "Address already in use" "$RUN_DIR/vllm.log"; then + echo "ERROR: vLLM failed to bind on ${VLLM_PORT}" >&2 + tail -100 "$RUN_DIR/vllm.log" >&2 || true + exit 4 + fi + if ! kill -0 "$VLLM_STEP_PID" >/dev/null 2>&1; then + echo "ERROR: vLLM srun exited before readiness" >&2 + tail -100 "$RUN_DIR/vllm.log" >&2 || true + exit 4 + fi + if grep -Fq "$VLLM_BIND_MARKER" "$RUN_DIR/vllm.log" \ + && grep -Fq "$VLLM_READY_MARKER" "$RUN_DIR/vllm.log" \ + && curl -fsS "$VLLM_URL/health" >/dev/null 2>&1 \ + && curl -fsS "$VLLM_URL/get_world_size" >/dev/null 2>&1; then + echo "vllm_ready=$(date -Is)" + break + fi + sleep 2 +done +if ! grep -Fq "$VLLM_BIND_MARKER" "$RUN_DIR/vllm.log" \ + || ! grep -Fq "$VLLM_READY_MARKER" "$RUN_DIR/vllm.log" \ + || ! curl -fsS "$VLLM_URL/health" >/dev/null 2>&1 \ + || ! 
curl -fsS "$VLLM_URL/get_world_size" >/dev/null 2>&1; then + echo "ERROR: vLLM did not become ready on ${VLLM_URL}" >&2 + tail -100 "$RUN_DIR/vllm.log" >&2 || true + exit 5 +fi + +INTERCEPTION_BASE_URL= +if [[ "$SWE_SANDBOX_BACKEND" == "docker" ]]; then + INTERCEPTION_BASE_URL="http://host.docker.internal:${INTERCEPTION_PORT}" +elif [[ "$SWE_SANDBOX_BACKEND" == "local" ]]; then + INTERCEPTION_BASE_URL="http://127.0.0.1:${INTERCEPTION_PORT}" +else + echo "starting_cloudflared=$(date -Is)" + "$CLOUDFLARED" tunnel \ + --url "http://${TRAINER_MASTER}:${INTERCEPTION_PORT}" \ + --no-autoupdate \ + --protocol quic \ + --ha-connections 1 \ + > "$RUN_DIR/cloudflared.log" 2>&1 & + CLOUDFLARED_PID=$! + echo "$CLOUDFLARED_PID" > "$RUN_DIR/cloudflared.pid" + + for _ in $(seq 1 120); do + INTERCEPTION_BASE_URL=$(grep -Eo 'https://[A-Za-z0-9.-]+\.trycloudflare\.com' "$RUN_DIR/cloudflared.log" | tail -n 1 || true) + if [[ -n "$INTERCEPTION_BASE_URL" ]]; then + break + fi + if ! kill -0 "$CLOUDFLARED_PID" >/dev/null 2>&1; then + echo "ERROR: cloudflared exited before URL creation" >&2 + tail -100 "$RUN_DIR/cloudflared.log" >&2 || true + exit 6 + fi + sleep 1 + done + if [[ -z "$INTERCEPTION_BASE_URL" ]]; then + echo "ERROR: cloudflared did not publish a tunnel URL" >&2 + tail -100 "$RUN_DIR/cloudflared.log" >&2 || true + exit 7 + fi +fi +export INTERCEPTION_BASE_URL +echo "$INTERCEPTION_BASE_URL" > "$RUN_DIR/interception_base_url.txt" + +{ + echo "SLURM_JOB_ID=${SLURM_JOB_ID:-}" + echo "RUN_DIR=$RUN_DIR" + echo "SWE_MODEL=$SWE_MODEL" + echo "SWE_SANDBOX_BACKEND=$SWE_SANDBOX_BACKEND" + echo "SWE_AGENT=$SWE_AGENT" + echo "VLLM_NODE=$VLLM_NODE" + echo "VLLM_URL=$VLLM_URL" + echo "VLLM_PORT=$VLLM_PORT" + echo "TRAINER_NODELIST=$TRAINER_NODELIST" + echo "TRAINER_MASTER=$TRAINER_MASTER" + echo "TOTAL_TRAINER_PROCS=$TOTAL_TRAINER_PROCS" + echo "INTERCEPTION_BASE_URL=$INTERCEPTION_BASE_URL" + echo "INTERCEPTION_PORT=$INTERCEPTION_PORT" + echo "MASTER_PORT=$MASTER_PORT" + echo 
"SWE_TRACKIO_SPACE_ID=$SWE_TRACKIO_SPACE_ID" + echo "SWE_TRACKIO_PROJECT=$SWE_TRACKIO_PROJECT" + echo "SWE_CHECKPOINT_TO_HUB=$SWE_CHECKPOINT_TO_HUB" + echo "SWE_HUB_MODEL_ID=$SWE_HUB_MODEL_ID" + echo "SWE_HUB_PRIVATE_REPO=$SWE_HUB_PRIVATE_REPO" + echo "SWE_CHECKPOINT_SAVE_STEPS=$SWE_CHECKPOINT_SAVE_STEPS" + echo "SWE_CHECKPOINT_SAVE_TOTAL_LIMIT=$SWE_CHECKPOINT_SAVE_TOTAL_LIMIT" + echo "SWE_RESUME_FROM_CHECKPOINT=$SWE_RESUME_FROM_CHECKPOINT" + echo "SWE_IGNORE_DATA_SKIP=$SWE_IGNORE_DATA_SKIP" + echo "MAX_TASKS=$MAX_TASKS" + echo "MAX_STEPS=$MAX_STEPS" + echo "MAX_TURNS=$MAX_TURNS" + echo "SWE_TASK_INDICES=${SWE_TASK_INDICES:-}" + echo "SWE_REPEAT_TASKS=${SWE_REPEAT_TASKS:-}" + echo "SWE_NUM_GENERATIONS=${SWE_NUM_GENERATIONS:-}" + echo "SWE_TEMPERATURE=${SWE_TEMPERATURE:-}" + echo "SWE_LEARNING_RATE=${SWE_LEARNING_RATE:-}" + echo "SWE_REWARD_MODE=${SWE_REWARD_MODE:-}" + echo "SWE_ENABLE_ANSWER_TOOL=${SWE_ENABLE_ANSWER_TOOL:-}" + echo "SWE_ROLLOUT_MAX_INFLIGHT=$SWE_ROLLOUT_MAX_INFLIGHT" + echo "SWE_ROLLOUT_MAX_ATTEMPTS=$SWE_ROLLOUT_MAX_ATTEMPTS" + echo "SWE_ROLLOUT_REQUEST_TIMEOUT_S=$SWE_ROLLOUT_REQUEST_TIMEOUT_S" + echo "SWE_GIT_CHECKOUT_TIMEOUT_S=$SWE_GIT_CHECKOUT_TIMEOUT_S" + echo "SWE_VLLM_MAX_MODEL_LEN=$SWE_VLLM_MAX_MODEL_LEN" + echo "SWE_TRAIN_DTYPE=$SWE_TRAIN_DTYPE" + echo "SWE_LORA=$SWE_LORA" + echo "SWE_LORA_R=$SWE_LORA_R" + echo "SWE_LORA_ALPHA=$SWE_LORA_ALPHA" + echo "SWE_LORA_DROPOUT=$SWE_LORA_DROPOUT" + echo "SWE_LORA_TARGET_MODULES=$SWE_LORA_TARGET_MODULES" + echo "SWE_LORA_BIAS=$SWE_LORA_BIAS" + echo "SWE_LORA_USE_RSLORA=$SWE_LORA_USE_RSLORA" + echo "SWE_OPTIM=${SWE_OPTIM:-}" + echo "SWE_TORCH_EMPTY_CACHE_STEPS=$SWE_TORCH_EMPTY_CACHE_STEPS" + echo "VLLM_TOOL_CALL_PARSER=${VLLM_TOOL_CALL_PARSER:-}" + echo "SWE_DISABLE_WEIGHT_TRANSFER=$SWE_DISABLE_WEIGHT_TRANSFER" + echo "OPENENV_LOCAL_SANDBOX_ROOT=$OPENENV_LOCAL_SANDBOX_ROOT" +} > "$RUN_DIR/run.env" + +echo "starting_trainer=$(date -Is)" +srun \ + --nodes="$TRAINER_NODE_COUNT" \ + 
#!/usr/bin/env bash
# Submit the multi-node async GRPO sbatch template.
#
# Resolves the run directory and ports, records them in
# "$RUN_DIR/submission.env", submits the template with sbatch and stores the
# returned job id in "$RUN_DIR/job_id.txt".
#
# Overridable via environment: RUNS_ROOT, NODES, GPUS_PER_NODE, CPUS_PER_TASK,
# PARTITION, TIME_LIMIT, JOB_NAME, RUN_ID, RUN_DIR, PORT_SEED, VLLM_PORT,
# INTERCEPTION_PORT, MASTER_PORT, SBATCH_NODELIST, SBATCH_EXCLUDE.

set -euo pipefail

SCRIPT_DIR=$(cd -- "$(dirname "${BASH_SOURCE[0]}")" && pwd)
REPO_ROOT=$(cd -- "$SCRIPT_DIR/../../.." && pwd)
SBATCH_TEMPLATE="$SCRIPT_DIR/sbatch_multinode_async_grpo.sh"
RUNS_ROOT=${RUNS_ROOT:-$REPO_ROOT/runs/mini_swe_async_grpo}

NODES=${NODES:-2}
GPUS_PER_NODE=${GPUS_PER_NODE:-1}
CPUS_PER_TASK=${CPUS_PER_TASK:-16}
PARTITION=${PARTITION:-hopper-prod}
TIME_LIMIT=${TIME_LIMIT:-08:00:00}
JOB_NAME=${JOB_NAME:-mini-swe-async-grpo}
RUN_ID=${RUN_ID:-$(date -u +%Y%m%dT%H%M%SZ)-r$RANDOM}
RUN_DIR=${RUN_DIR:-$RUNS_ROOT/$RUN_ID}

# Validate the resource request before creating anything on disk.
if (( NODES < 2 || NODES > 4 )); then
  echo "ERROR: NODES must be between 2 and 4, got ${NODES}" >&2
  exit 2
fi
if (( GPUS_PER_NODE < 1 )); then
  echo "ERROR: GPUS_PER_NODE must be >= 1, got ${GPUS_PER_NODE}" >&2
  exit 2
fi
if (( CPUS_PER_TASK < 1 )); then
  echo "ERROR: CPUS_PER_TASK must be >= 1, got ${CPUS_PER_TASK}" >&2
  exit 2
fi

mkdir -p "$RUN_DIR/home"

# `date +%S` is zero-padded ("08", "09"); inside $(( )) a leading 0 means
# octal, so those two values are invalid constants and abort the script.
# The 10# prefix forces base-10 interpretation.
PORT_SEED=${PORT_SEED:-$(( (RANDOM + 10#$(date -u +%S)) % 1000 ))}
VLLM_PORT=${VLLM_PORT:-$((31000 + PORT_SEED))}
INTERCEPTION_PORT=${INTERCEPTION_PORT:-$((32000 + PORT_SEED))}
MASTER_PORT=${MASTER_PORT:-$((33000 + PORT_SEED))}

export REPO_ROOT RUNS_ROOT RUN_DIR
export GPUS_PER_NODE CPUS_PER_TASK
export VLLM_PORT INTERCEPTION_PORT MASTER_PORT

# Record what was requested so the run can be inspected later.
{
  echo "RUN_ID=$RUN_ID"
  echo "RUN_DIR=$RUN_DIR"
  echo "NODES=$NODES"
  echo "GPUS_PER_NODE=$GPUS_PER_NODE"
  echo "CPUS_PER_TASK=$CPUS_PER_TASK"
  echo "PARTITION=$PARTITION"
  echo "TIME_LIMIT=$TIME_LIMIT"
  echo "JOB_NAME=$JOB_NAME"
  echo "VLLM_PORT=$VLLM_PORT"
  echo "INTERCEPTION_PORT=$INTERCEPTION_PORT"
  echo "MASTER_PORT=$MASTER_PORT"
} > "$RUN_DIR/submission.env"

SBATCH_ARGS=(
  --job-name="$JOB_NAME"
  --partition="$PARTITION"
  --nodes="$NODES"
  --ntasks-per-node=1
  --gres="gpu:h100:${GPUS_PER_NODE}"
  --cpus-per-task="$CPUS_PER_TASK"
  --time="$TIME_LIMIT"
  --output="$RUN_DIR/slurm-%j.out"
  --error="$RUN_DIR/slurm-%j.err"
)

# Optional node pinning / exclusion.
if [[ -n "${SBATCH_NODELIST:-}" ]]; then
  SBATCH_ARGS+=(--nodelist="$SBATCH_NODELIST")
fi
if [[ -n "${SBATCH_EXCLUDE:-}" ]]; then
  SBATCH_ARGS+=(--exclude="$SBATCH_EXCLUDE")
fi

submit_output=$(sbatch "${SBATCH_ARGS[@]}" "$SBATCH_TEMPLATE")
printf '%s\n' "$submit_output"
# The job id is taken from the fourth whitespace-separated field of sbatch's
# output (presumably "Submitted batch job <id>" -- confirm for this cluster).
job_id=$(awk '{print $4}' <<<"$submit_output")
if [[ -n "$job_id" ]]; then
  printf '%s\n' "$job_id" > "$RUN_DIR/job_id.txt"
fi

echo "run_dir=$RUN_DIR"
echo "ports=vllm:${VLLM_PORT} interception:${INTERCEPTION_PORT} master:${MASTER_PORT}"
def _float_env(name: str, default: float, *, min_value: float = 0.0) -> float:
    """Read a float setting from the environment variable ``name``.

    Args:
        name: Environment variable to read.
        default: Value returned when the variable is unset or blank.
        min_value: Smallest accepted value (inclusive).

    Returns:
        The parsed value, or ``default`` when the variable is unset or blank.

    Raises:
        ValueError: If the value is not a float, is NaN, or is below
            ``min_value``.
    """
    raw = os.environ.get(name)
    if raw is None or not raw.strip():
        return default
    try:
        value = float(raw)
    except ValueError as exc:
        # Name the offending variable; float()'s own message does not.
        raise ValueError(f"{name} must be a float, got {raw!r}") from exc
    # Written as ``not value >= min_value`` instead of ``value < min_value``:
    # every comparison with NaN is false, so the latter would accept "nan".
    if not value >= min_value:
        raise ValueError(f"{name} must be >= {min_value}, got {value}")
    return value


def _parse_task_indices(raw: str) -> list[int]:
    """Parse a comma-separated list of task indices.

    Empty pieces between commas are ignored.

    Args:
        raw: Text such as ``"3, 1,2"``.

    Returns:
        The indices in the order given; an empty list when ``raw`` is blank.

    Raises:
        ValueError: If a piece is not an integer, or if ``raw`` is non-blank
            but contains no integers at all (for example ``","``).
    """
    text = raw.strip()
    if not text:
        return []
    tokens = [piece.strip() for piece in text.split(",")]
    try:
        indices = [int(token) for token in tokens if token]
    except ValueError as exc:
        raise ValueError(
            f"task-indices must be comma-separated integers, got {raw!r}"
        ) from exc
    if not indices:
        raise ValueError("task-indices must contain at least one integer")
    return indices


def _select_task_indices(
    *,
    total_tasks: int,
    max_tasks: int,
    task_offset: int,
    task_stride: int,
    task_indices_raw: str,
    repeat_tasks: int,
) -> list[int]:
    """Choose which task indices to train on.

    Explicit indices in ``task_indices_raw`` take precedence; in that case
    ``max_tasks``, ``task_offset`` and ``task_stride`` are validated but not
    used for selection. Otherwise the selection is
    ``range(task_offset, total_tasks, task_stride)`` truncated to
    ``max_tasks`` entries. The resulting list is repeated ``repeat_tasks``
    times.

    Returns:
        The selected indices, possibly containing duplicates.

    Raises:
        ValueError: If an argument is out of its allowed range or the
            selection is empty.
        IndexError: If a selected index is outside ``[0, total_tasks - 1]``.
    """
    if total_tasks <= 0:
        raise ValueError("SWE-Gym task source is empty")
    if max_tasks < 1:
        raise ValueError(f"max_tasks must be >= 1, got {max_tasks}")
    if task_offset < 0:
        raise ValueError(f"task_offset must be >= 0, got {task_offset}")
    if task_stride < 1:
        raise ValueError(f"task_stride must be >= 1, got {task_stride}")
    if repeat_tasks < 1:
        raise ValueError(f"repeat_tasks must be >= 1, got {repeat_tasks}")

    base_indices = _parse_task_indices(task_indices_raw)
    if not base_indices:
        base_indices = list(range(task_offset, total_tasks, task_stride))[:max_tasks]

    if not base_indices:
        raise ValueError(
            "task selection produced no tasks; adjust max-tasks/task-offset/task-stride"
        )

    for idx in base_indices:
        if not 0 <= idx < total_tasks:
            raise IndexError(
                f"task index {idx} out of range [0, {total_tasks - 1}]"
            )

    return base_indices * repeat_tasks
def _train_dtype_from_env() -> torch.dtype:
    """Return the training dtype selected by ``SWE_TRAIN_DTYPE``.

    Accepted values (case-insensitive, surrounding whitespace ignored) are
    ``bf16``/``bfloat16``, ``fp16``/``float16`` and ``fp32``/``float32``.
    An unset or blank variable selects ``bf16``, matching how the other
    environment helpers in this file treat blank values.

    Raises:
        ValueError: If the variable holds any other value.
    """
    mapping = {
        "bf16": torch.bfloat16,
        "bfloat16": torch.bfloat16,
        "fp16": torch.float16,
        "float16": torch.float16,
        "fp32": torch.float32,
        "float32": torch.float32,
    }
    # ``or "bf16"`` also covers a variable that is exported but empty.
    raw = os.environ.get("SWE_TRAIN_DTYPE", "").strip().lower() or "bf16"
    dtype = mapping.get(raw)
    if dtype is None:
        raise ValueError(
            "SWE_TRAIN_DTYPE must be one of "
            f"{sorted(mapping)}, got {raw!r}"
        )
    return dtype


def _optional_env(name: str) -> str | None:
    """Return the stripped value of ``name``, or ``None`` if unset or blank."""
    value = os.environ.get(name, "").strip()
    return value or None


def _csv_env(name: str) -> tuple[str, ...]:
    """Return the comma-separated values stored in ``name``.

    Each piece is stripped and empty pieces are dropped.

    Returns:
        The values in order; an empty tuple when the variable is unset or
        blank.

    Raises:
        ValueError: If the variable is non-blank but every piece is empty
            (for example ``","``).
    """
    raw = os.environ.get(name, "").strip()
    if not raw:
        return ()
    values = tuple(piece.strip() for piece in raw.split(",") if piece.strip())
    if not values:
        raise ValueError(f"{name} must contain at least one non-empty value")
    return values


def _count_trainable_parameters(model: torch.nn.Module) -> int:
    """Return the total element count of parameters with ``requires_grad``."""
    return sum(param.numel() for param in model.parameters() if param.requires_grad)


def _single_gpu_trainable_param_limit() -> int:
    """Return the trainable-parameter limit used for single-GPU runs.

    Read from ``SWE_SINGLE_GPU_TRAINABLE_PARAM_LIMIT``; defaults to
    9,000,000,000 when the variable is unset or blank.

    Raises:
        ValueError: If the value is not an integer or is below 1.
    """
    raw = os.environ.get("SWE_SINGLE_GPU_TRAINABLE_PARAM_LIMIT", "").strip()
    if not raw:
        return 9_000_000_000
    value = int(raw)
    if value < 1:
        raise ValueError(
            f"SWE_SINGLE_GPU_TRAINABLE_PARAM_LIMIT must be >= 1, got {value}"
        )
    return value
def _default_lora_target_modules(model_name: str) -> tuple[str, ...]:
    """Pick default LoRA target module names from the model identifier.

    The identifier is lower-cased and underscores are removed before
    matching. If it contains one of the listed family names the attention
    projections plus ``gate_proj``/``up_proj``/``down_proj`` are returned;
    otherwise only the four attention projections are returned.
    """
    attention_projections = ("q_proj", "k_proj", "v_proj", "o_proj")
    known_families = ("qwen", "llama", "mistral", "mixtral", "deepseek")

    normalized = model_name.lower().replace("_", "")
    for family in known_families:
        if family in normalized:
            return attention_projections + ("gate_proj", "up_proj", "down_proj")
    return attention_projections


def _sanitize_lora_merged_weight_name(name: str) -> str | None:
    """Map a wrapped parameter name to its plain name, or drop it.

    Leading ``module.`` and ``base_model.model.`` prefixes are removed, in
    that order. Names that still contain ``.lora_``, ``.modules_to_save.``
    or ``.original_module.`` yield ``None``. In the remaining names every
    ``.base_layer.`` segment is collapsed to ``.``.
    """
    stripped = name
    for prefix in ("module.", "base_model.model."):
        stripped = stripped.removeprefix(prefix)

    for marker in (".lora_", ".modules_to_save.", ".original_module."):
        if marker in stripped:
            return None

    return stripped.replace(".base_layer.", ".")
@contextlib.contextmanager
def _patched_async_grpo_model_loader(
    *,
    dtype: torch.dtype,
    attn_implementation: str | None,
    lora_config: Any | None = None,
):
    """Force AsyncGRPOTrainer to load the policy in the requested dtype.

    TRL's experimental AsyncGRPOTrainer currently hardcodes
    ``AutoModelForCausalLM.from_pretrained(..., dtype=torch.float32)``,
    which makes 8B full-finetuning overflow 80GB H100s on the first
    optimizer step. Patch the loader just around trainer construction so
    the example can run with bf16/fp16 policy weights.

    While the context is active, ``AutoModelForCausalLM.from_pretrained`` is
    replaced by a wrapper; the original attribute is restored on exit, also
    when the body raises.

    Args:
        dtype: Substituted when the caller passes no ``dtype`` or a float32
            one.
        attn_implementation: Injected when truthy and the caller did not
            supply a truthy ``attn_implementation`` itself.
        lora_config: When not ``None``, the loaded model is wrapped with
            ``peft.get_peft_model`` using this config.
    """

    # Fetch the raw class attribute from the class dict so that exactly the
    # same object can be put back in ``finally``.
    original = AutoModelForCausalLM.__dict__["from_pretrained"]

    def _patched(
        cls,
        pretrained_model_name_or_path: str,
        *args: Any,
        **kwargs: Any,
    ):
        requested_dtype = kwargs.get("dtype")
        # Override only a missing or float32 request; any other explicit
        # dtype is passed through unchanged.
        if requested_dtype in {None, torch.float32, "float32", "fp32"}:
            kwargs["dtype"] = dtype
        if attn_implementation and not kwargs.get("attn_implementation"):
            kwargs["attn_implementation"] = attn_implementation
        kwargs.setdefault("low_cpu_mem_usage", True)
        # Bind the saved attribute to the class the wrapper was invoked on
        # before calling it.
        model = original.__get__(None, cls)(
            pretrained_model_name_or_path,
            *args,
            **kwargs,
        )
        if lora_config is not None:
            # Imported lazily so peft is only needed when LoRA is requested.
            from peft import get_peft_model

            model = get_peft_model(model, lora_config)
            # NOTE(review): nesting of this check under the LoRA branch is
            # reconstructed from a whitespace-collapsed diff -- confirm.
            if hasattr(model, "enable_input_require_grads"):
                model.enable_input_require_grads()
        if hasattr(model, "config"):
            model.config.use_cache = False
        return model

    AutoModelForCausalLM.from_pretrained = classmethod(_patched)
    try:
        yield
    finally:
        AutoModelForCausalLM.from_pretrained = original
def _patch_peft_lora_resume_tp_sharding_for_ddp(model: torch.nn.Module) -> None:
    """Guard PEFT's ``_maybe_shard_state_dict_for_tp`` hook.

    Replaces ``peft.utils.save_and_load._maybe_shard_state_dict_for_tp`` with
    a wrapper that delegates to the original only when some module of the
    model passed to the hook (or that module's ``get_base_layer()`` result)
    has both ``_hf_tp_plan`` and ``_hf_device_mesh`` set; otherwise the
    wrapper returns ``None`` without calling it.

    Does nothing when peft cannot be imported, when the hook is missing or
    not callable, or when the hook is already the guarded wrapper. The
    ``model`` argument is not read here.
    """
    try:
        import peft.utils.save_and_load as peft_io
    except ImportError:
        return

    hook = getattr(peft_io, "_maybe_shard_state_dict_for_tp", None)
    if not callable(hook):
        return
    if getattr(hook, "_swe_ddp_guard", False):
        # Already wrapped by an earlier call.
        return

    def _module_has_tp_plan(module: torch.nn.Module) -> bool:
        unwrap = getattr(module, "get_base_layer", None)
        layer = unwrap() if callable(unwrap) else module
        if getattr(layer, "_hf_tp_plan", None) is None:
            return False
        return getattr(layer, "_hf_device_mesh", None) is not None

    def _guarded_maybe_shard_state_dict_for_tp(
        model_arg: torch.nn.Module,
        state_dict: dict[str, torch.Tensor],
        adapter_name: str,
    ) -> None:
        if any(_module_has_tp_plan(module) for module in model_arg.modules()):
            return hook(model_arg, state_dict, adapter_name)
        return None

    # Marker attributes: the first makes repeated patching a no-op, the
    # second keeps a handle on the wrapped function.
    _guarded_maybe_shard_state_dict_for_tp._swe_ddp_guard = True  # type: ignore[attr-defined]
    _guarded_maybe_shard_state_dict_for_tp._swe_original = hook  # type: ignore[attr-defined]
    peft_io._maybe_shard_state_dict_for_tp = (  # type: ignore[method-assign]
        _guarded_maybe_shard_state_dict_for_tp
    )
def _patch_chunked_lm_head_for_wrapped_causal_lm(
    model: torch.nn.Module,
    *,
    temperature: float,
    chunk_size: int = 8192,
) -> None:
    """Replace ``model.forward`` with a chunked log-prob forward pass.

    The new forward runs the backbone returned by
    ``_chunked_logprob_backbone(model)``, takes its final hidden states and
    passes them with ``model.lm_head.weight`` to
    ``_ChunkedLogProbFunction.apply`` to get per-token log-probabilities and
    entropies.

    Args:
        model: Module whose ``forward`` is rebound in place. The new forward
            reads ``config`` and ``lm_head`` from it.
        temperature: Forwarded to ``_ChunkedLogProbFunction.apply``.
        chunk_size: Forwarded to ``_ChunkedLogProbFunction.apply``.
    """

    def _chunked_forward(
        self: torch.nn.Module,
        input_ids: torch.Tensor | None = None,
        attention_mask: torch.Tensor | None = None,
        labels: torch.Tensor | None = None,
        completion_mask: torch.Tensor | None = None,
        use_cache: bool = False,
        **kwargs: Any,
    ) -> dict[str, torch.Tensor]:
        """Return ``{"log_probs", "entropy"}``, each of shape ``(b, s)``.

        ``s`` is the input sequence length minus one. Positions where
        ``completion_mask`` is false are left at zero.

        Raises:
            ValueError: If ``labels`` is ``None``.
            AttributeError: If the backbone output has neither
                ``last_hidden_state`` nor a non-empty ``hidden_states``.
        """
        # Explicit raise instead of ``assert`` so the check survives
        # ``python -O``; it runs before any model work.
        if labels is None:
            raise ValueError("requires labels to not be None for logprob computation")

        outputs = _chunked_logprob_backbone(self)(
            input_ids=input_ids,
            attention_mask=attention_mask,
            use_cache=use_cache,
            **kwargs,
        )
        logit_scale = getattr(self.config, "logit_scale", 1.0)
        hidden_states = getattr(outputs, "last_hidden_state", None)
        if hidden_states is None:
            # Fall back to the last entry of ``hidden_states``.
            all_hidden_states = getattr(outputs, "hidden_states", None)
            if not all_hidden_states:
                raise AttributeError(
                    "Chunked GRPO forward requires `last_hidden_state` or "
                    "`hidden_states` on the model outputs."
                )
            hidden_states = all_hidden_states[-1]

        # Pair the hidden state at position t with the label at t + 1.
        hidden_states = hidden_states[:, :-1, :]
        labels = labels[:, 1:]

        b, s, h = hidden_states.shape
        hidden_flat = hidden_states.reshape(b * s, h).contiguous()
        targets_flat = labels.reshape(b * s).contiguous()

        valid_mask = None
        if completion_mask is not None:
            # Keep only the masked-in positions for the log-prob computation.
            completion_mask = completion_mask[:, 1:]
            valid_mask = completion_mask.bool().reshape(b * s)
            hidden_flat = hidden_flat[valid_mask]
            targets_flat = targets_flat[valid_mask]

        logprobs_valid, entropy_valid = _ChunkedLogProbFunction.apply(
            hidden_flat,
            self.lm_head.weight,
            targets_flat,
            temperature,
            chunk_size,
            logit_scale,
        )

        if valid_mask is not None:
            # Scatter the computed values back into full-length buffers.
            logprobs = torch.zeros(
                b * s,
                device=logprobs_valid.device,
                dtype=logprobs_valid.dtype,
            )
            entropy = torch.zeros(
                b * s,
                device=entropy_valid.device,
                dtype=entropy_valid.dtype,
            )
            logprobs[valid_mask] = logprobs_valid
            entropy[valid_mask] = entropy_valid
        else:
            logprobs = logprobs_valid
            entropy = entropy_valid

        return {
            "log_probs": logprobs.reshape(b, s),
            "entropy": entropy.reshape(b, s),
        }

    model.forward = types.MethodType(_chunked_forward, model)
+ ) # ── Load tasks ──────────────────────────────────────────────── - gym_tasks = load_swegym_tasks(args.task_variant)[: args.max_tasks] + all_gym_tasks = load_swegym_tasks(args.task_variant) + selected_indices = _select_task_indices( + total_tasks=len(all_gym_tasks), + max_tasks=args.max_tasks, + task_offset=args.task_offset, + task_stride=args.task_stride, + task_indices_raw=args.task_indices, + repeat_tasks=args.repeat_tasks, + ) + gym_tasks = [all_gym_tasks[idx] for idx in selected_indices] swe_tasks = [t.to_swe_task() for t in gym_tasks] - _log.info("loaded %d tasks", len(swe_tasks)) + unique_indices = sorted(set(selected_indices)) + _log.info( + "loaded %d task slots (%d unique tasks) indices=%s", + len(swe_tasks), + len(unique_indices), + selected_indices, + ) + _log.info( + "task instances=%s", + [all_gym_tasks[idx].instance_id for idx in unique_indices], + ) # ── Dataset (prompt per task) ───────────────────────────────── dataset = Dataset.from_list( @@ -254,43 +776,93 @@ def main() -> int: tokenizer = AutoTokenizer.from_pretrained(model) if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token + model_context_limit = _model_context_limit(model) # ── Interception control plane (background thread) ──────────── - control_cfg = SWEAsyncControlPlaneConfig.from_env() - control_plane = SWEAsyncControlPlane(config=control_cfg) - server_loop, server_thread = start_interception_server(control_plane) - _log.info("InterceptionServer running in background thread") + is_main_process = _is_main_process() + control_plane: SWEAsyncControlPlane | None = None + server_loop: asyncio.AbstractEventLoop | None = None + server_thread: threading.Thread | None = None + if is_main_process: + control_cfg = SWEAsyncControlPlaneConfig.from_env() + control_plane = SWEAsyncControlPlane(config=control_cfg) + server_loop, server_thread = start_interception_server(control_plane) + _log.info("InterceptionServer running in background thread") try: # ── Session factory (Pi 
in sandbox) ─────────────────────── - backend = create_sandbox_backend(args.sandbox_backend) - session_factory = SWESessionFactory( - agent=args.agent, - config=SWEAgentConfig( - base_url=control_plane.interception_base_url, - api_key=control_plane.auth_token, - model=model, - agent_timeout_s=1800.0, - ), - sandbox_backend=backend, - mode="interception_gate", - interception_server=control_plane.server, - interception_base_url=control_plane.interception_base_url, - ) + worker: SWERolloutWorker | None = None + if is_main_process: + assert control_plane is not None + backend_kwargs: dict[str, Any] = {} + if args.sandbox_backend == "hf": + backend_kwargs = { + "create_retries": _int_env("SWE_HF_SANDBOX_CREATE_RETRIES", 6), + "create_backoff_s": _float_env( + "SWE_HF_SANDBOX_CREATE_BACKOFF_S", 20.0 + ), + } + elif args.sandbox_backend == "local": + backend_kwargs = { + "root_dir": os.environ.get("OPENENV_LOCAL_SANDBOX_ROOT", "").strip() + or None, + "preserve_root": _bool_env( + "OPENENV_LOCAL_SANDBOX_PRESERVE", + False, + ), + } + default_rollout_inflight = min( + args.num_generations, + 1 if args.sandbox_backend == "hf" else max(2, args.num_generations), + ) + backend = create_sandbox_backend(args.sandbox_backend, **backend_kwargs) + session_factory = SWESessionFactory( + agent=args.agent, + config=SWEAgentConfig( + base_url=control_plane.interception_base_url, + api_key=control_plane.auth_token, + model=model, + agent_timeout_s=_float_env("SWE_AGENT_TIMEOUT_S", 1800.0), + thinking=args.agent_thinking, + ), + sandbox_backend=backend, + mode="interception_gate", + interception_server=control_plane.server, + interception_base_url=control_plane.interception_base_url, + ) - # ── Rollout worker ──────────────────────────────────────── - worker = SWERolloutWorker( - session_factory=session_factory, - tasks=swe_tasks, - tokenizer=tokenizer, - vllm_base_url=vllm_url, - vllm_api_key=vllm_key, - vllm_model=model, - config=WorkerConfig( - max_inflight=2, - 
max_turns=args.max_turns, - ), - ) + # ── Rollout worker ──────────────────────────────────────── + worker = SWERolloutWorker( + session_factory=session_factory, + tasks=swe_tasks, + tokenizer=tokenizer, + vllm_base_url=vllm_url, + vllm_api_key=vllm_key, + vllm_model=model, + config=WorkerConfig( + max_inflight=_int_env( + "SWE_ROLLOUT_MAX_INFLIGHT", + default_rollout_inflight, + ), + max_rollout_attempts=_int_env( + "SWE_ROLLOUT_MAX_ATTEMPTS", + 4, + ), + num_generations=args.num_generations, + request_timeout_s=_float_env("SWE_ROLLOUT_REQUEST_TIMEOUT_S", 600.0), + max_turns=args.max_turns, + max_model_len=_int_env( + "SWE_VLLM_MAX_MODEL_LEN", + model_context_limit, + ), + max_completion_tokens=args.max_completion_tokens, + temperature=args.temperature, + failure_backoff_s=_float_env( + "SWE_ROLLOUT_FAILURE_BACKOFF_S", + 30.0, + ), + ), + ) # ── Trainer ─────────────────────────────────────────────── def _noop_reward(**kwargs: Any) -> list[float]: @@ -301,30 +873,74 @@ def _noop_reward(**kwargs: Any) -> list[float]: checkpoint_args, resume_from_checkpoint, checkpoint_requested = ( _build_checkpoint_args() ) + train_optim = os.environ.get("SWE_OPTIM", "").strip() or "paged_adamw_8bit" async_grpo_args: dict[str, Any] = { "output_dir": os.path.join( os.environ.get("HOME", "/tmp"), "outputs/swe_async_grpo" ), "vllm_server_base_url": vllm_url, - "vllm_server_timeout": 2400.0, - "max_completion_length": 2048, + "vllm_server_timeout": _float_env("SWE_ROLLOUT_QUEUE_TIMEOUT_S", 900.0), + "max_completion_length": args.max_completion_tokens, "max_steps": args.max_steps, + # Online rollout samples are newly generated after resume; do not + # discard live GRPO groups to match the previous dataloader offset. 
+ "ignore_data_skip": _bool_env("SWE_IGNORE_DATA_SKIP", True), "per_device_train_batch_size": 1, "gradient_accumulation_steps": 1, - "num_generations": 1, - "learning_rate": 1e-6, - "temperature": 1.0, - "optim": "adamw_bnb_8bit", - "bf16": True, + "num_generations": args.num_generations, + "learning_rate": args.learning_rate, + "temperature": args.temperature, + "optim": train_optim, + "bf16": train_dtype == torch.bfloat16, + "fp16": train_dtype == torch.float16, "gradient_checkpointing": True, - "max_staleness": 4, - "weight_sync_steps": 1, - "max_inflight_tasks": 2, + "torch_empty_cache_steps": _int_env( + "SWE_TORCH_EMPTY_CACHE_STEPS", + 1, + ), + "max_staleness": _int_env( + "SWE_ASYNC_MAX_STALENESS", + max(16, args.num_generations * 4), + ), + "weight_sync_steps": _int_env("SWE_ASYNC_WEIGHT_SYNC_STEPS", 1), + "max_inflight_tasks": _int_env( + "SWE_ASYNC_MAX_INFLIGHT_TASKS", + args.num_generations, + ), + "queue_maxsize": _int_env("SWE_ASYNC_QUEUE_MAXSIZE", 64), "logging_steps": 1, "report_to": "trackio", "run_name": f"swe-grpo-{model.split('/')[-1]}", - "trackio_space_id": os.environ.get("TRACKIO_SPACE_ID", "").strip() or None, + "project": ( + os.environ.get("SWE_TRACKIO_PROJECT", "").strip() + or os.environ.get("TRACKIO_PROJECT", "").strip() + or "huggingface" + ), + "trackio_space_id": ( + os.environ.get("SWE_TRACKIO_SPACE_ID", "").strip() + or os.environ.get("TRACKIO_SPACE_ID", "").strip() + or None + ), + "log_completions": _bool_env("SWE_LOG_COMPLETIONS", False), + "num_completions_to_print": _int_env( + "SWE_LOG_COMPLETIONS_LIMIT", + 3, + ), + "accelerator_config": { + "split_batches": True, + "dispatch_batches": True, + }, } + async_grpo_args.update( + _filter_async_grpo_kwargs( + { + # PEFT LoRA + DDP can mark the same adapter parameter ready + # twice with reentrant checkpointing during GRPO backward. 
+ "gradient_checkpointing_kwargs": {"use_reentrant": False}, + "ddp_find_unused_parameters": False, + } + ) + ) filtered_checkpoint_args = _filter_async_grpo_kwargs(checkpoint_args) async_grpo_args.update(filtered_checkpoint_args) @@ -336,21 +952,55 @@ def _noop_reward(**kwargs: Any) -> list[float]: ) resume_from_checkpoint = None - trainer = AsyncGRPOTrainer( - model=model, - reward_funcs=_noop_reward, - train_dataset=dataset, - processing_class=tokenizer, - rollout_worker=worker, - args=AsyncGRPOConfig(**async_grpo_args), - ) + trainer_args = AsyncGRPOConfig(**async_grpo_args) + with _patched_async_grpo_model_loader( + dtype=train_dtype, + attn_implementation=train_attn_implementation, + lora_config=lora_config, + ): + trainer = AsyncGRPOTrainer( + model=model, + reward_funcs=_noop_reward, + train_dataset=dataset, + processing_class=tokenizer, + rollout_worker=worker, + args=trainer_args, + ) + + trainable_params = _count_trainable_parameters(trainer.model) + trainer_world_size = trainer.accelerator.num_processes + if lora_config is not None: + _patch_lora_weight_streaming(trainer) + _patch_peft_lora_resume_tp_sharding_for_ddp(trainer.model) + _patch_chunked_lm_head_for_wrapped_causal_lm( + trainer.model, + temperature=args.temperature, + ) + single_gpu_param_limit = _single_gpu_trainable_param_limit() + if trainer_world_size == 1 and trainable_params > single_gpu_param_limit: + raise RuntimeError( + "Unsupported full-parameter Async GRPO config for a single training GPU: " + f"model={model} trainable_params={trainable_params:,} " + f"limit={single_gpu_param_limit:,}. " + "Use a smaller base model, add sharded training, or reduce trainable " + "state before launching this example." 
+ ) _log.info( - "starting training: model=%s tasks=%d checkpointing=%s resume=%s", + "starting training: model=%s tasks=%d generations=%d temp=%.2f lr=%g optim=%s adapter=%s dtype=%s attn_impl=%s checkpointing=%s resume=%s trainable_params=%s world_size=%d", model, len(swe_tasks), + args.num_generations, + args.temperature, + args.learning_rate, + train_optim, + adapter_summary, + str(train_dtype).split(".")[-1], + train_attn_implementation or "auto", checkpoint_enabled, resume_from_checkpoint or "none", + f"{trainable_params:,}", + trainer_world_size, ) if resume_from_checkpoint is None: trainer.train() @@ -369,7 +1019,11 @@ def _noop_reward(**kwargs: Any) -> list[float]: else: raise - if checkpoint_enabled and hasattr(trainer, "push_to_hub"): + if ( + checkpoint_enabled + and hasattr(trainer, "push_to_hub") + and trainer.is_world_process_zero() + ): trainer.push_to_hub( commit_message=f"Final checkpoint at step {getattr(trainer.state, 'global_step', '?')}" ) @@ -377,7 +1031,12 @@ def _noop_reward(**kwargs: Any) -> list[float]: _log.info("done: step=%s", getattr(trainer.state, "global_step", "?")) return 0 finally: - stop_interception_server(control_plane, server_loop, server_thread) + if ( + control_plane is not None + and server_loop is not None + and server_thread is not None + ): + stop_interception_server(control_plane, server_loop, server_thread) if __name__ == "__main__": diff --git a/src/openenv/core/harness/sandbox/__init__.py b/src/openenv/core/harness/sandbox/__init__.py index 208fe54d5..e11d9c812 100644 --- a/src/openenv/core/harness/sandbox/__init__.py +++ b/src/openenv/core/harness/sandbox/__init__.py @@ -18,6 +18,7 @@ from .base import BgJob, ExecResult, SandboxBackend, SandboxHandle from .docker_backend import DockerBgJob, DockerSandboxBackend, DockerSandboxHandle +from .local_backend import LocalBgJob, LocalSandboxBackend, LocalSandboxHandle __all__ = [ "BgJob", @@ -25,6 +26,9 @@ "DockerSandboxBackend", "DockerSandboxHandle", "ExecResult", + 
"LocalBgJob", + "LocalSandboxBackend", + "LocalSandboxHandle", "SandboxBackend", "SandboxHandle", "create_sandbox_backend", @@ -46,7 +50,7 @@ def create_sandbox_backend( - backend: Literal["e2b", "docker", "hf"] = "e2b", + backend: Literal["e2b", "docker", "hf", "local"] = "e2b", **kwargs: Any, ) -> SandboxBackend: """Create a sandbox backend by name. @@ -57,6 +61,8 @@ def create_sandbox_backend( For ``"docker"``: local Docker, no external dependencies. For ``"hf"``: Hugging Face Jobs via ``hf-sandbox``. + + For ``"local"``: isolated temp directories and subprocesses on the host. """ if backend == "e2b": from .e2b_backend import E2BSandboxBackend @@ -68,6 +74,8 @@ def create_sandbox_backend( from .hf_backend import HFSandboxBackend return HFSandboxBackend(**kwargs) + elif backend == "local": + return LocalSandboxBackend(**kwargs) raise ValueError( - f"Unknown sandbox backend: {backend!r}. Use 'e2b', 'docker', or 'hf'." + f"Unknown sandbox backend: {backend!r}. Use 'e2b', 'docker', 'hf', or 'local'." ) diff --git a/src/openenv/core/harness/sandbox/local_backend.py b/src/openenv/core/harness/sandbox/local_backend.py new file mode 100644 index 000000000..62725e3a6 --- /dev/null +++ b/src/openenv/core/harness/sandbox/local_backend.py @@ -0,0 +1,294 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Local subprocess implementation of :class:`SandboxBackend`. + +Each sandbox gets an isolated temp root on the host filesystem plus an +optional per-sandbox virtualenv. Commands execute directly on the trainer +host, which makes this backend suitable for rootless cluster environments +where Docker/HF Jobs are unavailable. 
+""" + +from __future__ import annotations + +import os +import shutil +import signal +import subprocess +import sys +import tempfile +import time +import uuid +from pathlib import Path + +from openenv.core.harness.sandbox.base import BgJob, ExecResult, SandboxHandle + +_CANONICAL_HOME = "/home/user" +_CANONICAL_WORKDIR = "/testbed" + + +class LocalBgJob: + """Handle to a background subprocess launched in a local sandbox.""" + + def __init__(self, proc: subprocess.Popen[str]) -> None: + self._proc = proc + + @property + def pid(self) -> int: + return int(self._proc.pid) + + def wait(self, timeout: float | None = None) -> int: + try: + return int(self._proc.wait(timeout=timeout)) + except subprocess.TimeoutExpired as exc: + raise TimeoutError( + f"Background command (pid={self._proc.pid}) did not exit within {timeout}s" + ) from exc + + def kill(self) -> None: + if self._proc.poll() is not None: + return + try: + os.killpg(self._proc.pid, signal.SIGTERM) + except ProcessLookupError: + return + except Exception: + self._proc.terminate() + return + + deadline = time.monotonic() + 5.0 + while self._proc.poll() is None and time.monotonic() < deadline: + time.sleep(0.1) + if self._proc.poll() is None: + try: + os.killpg(self._proc.pid, signal.SIGKILL) + except Exception: + self._proc.kill() + + +class LocalSandboxHandle: + """Host-backed sandbox handle with an isolated temp root.""" + + supports_images = False + + def __init__( + self, + *, + root_dir: str, + home_dir: str, + workdir: str, + tmp_dir: str, + default_envs: dict[str, str] | None = None, + preserve_root: bool = False, + ) -> None: + self._root_dir = root_dir + self._home_dir = home_dir + self._workdir = workdir + self._tmp_dir = tmp_dir + self._default_envs = dict(default_envs or {}) + self._preserve_root = preserve_root + self._bg_jobs: list[LocalBgJob] = [] + self._sandbox_id = f"local-{uuid.uuid4().hex[:12]}" + + @property + def sandbox_id(self) -> str: + return self._sandbox_id + + @property + def 
sandbox_home(self) -> str: + return self._home_dir + + @property + def workdir(self) -> str: + return self._workdir + + @property + def tmp_dir(self) -> str: + return self._tmp_dir + + def exec( + self, + cmd: str, + *, + envs: dict[str, str] | None = None, + cwd: str | None = None, + timeout: float | None = 60, + ) -> ExecResult: + run_env = self._build_env(envs) + resolved_cwd = self._resolve_cwd(cwd) + try: + result = subprocess.run( + ["bash", "-lc", cmd], + cwd=resolved_cwd, + env=run_env, + capture_output=True, + text=True, + timeout=timeout, + ) + return ExecResult( + exit_code=int(result.returncode), + stdout=result.stdout, + stderr=result.stderr, + ) + except subprocess.TimeoutExpired: + return ExecResult( + exit_code=-1, + stdout="", + stderr=f"Command timed out after {timeout}s", + ) + except Exception as exc: + return ExecResult(exit_code=-1, stdout="", stderr=str(exc)) + + def start_bg( + self, + cmd: str, + *, + envs: dict[str, str] | None = None, + cwd: str | None = None, + ) -> BgJob: + proc = subprocess.Popen( + ["bash", "-lc", cmd], + cwd=self._resolve_cwd(cwd), + env=self._build_env(envs), + stdout=subprocess.DEVNULL, + stderr=subprocess.DEVNULL, + text=True, + preexec_fn=os.setsid, + ) + job = LocalBgJob(proc) + self._bg_jobs.append(job) + return job + + def write_text(self, path: str, content: str) -> None: + resolved = Path(self._resolve_path(path)) + resolved.parent.mkdir(parents=True, exist_ok=True) + resolved.write_text(content) + + def read_text(self, path: str) -> str: + return Path(self._resolve_path(path)).read_text() + + def exists(self, path: str) -> bool: + return Path(self._resolve_path(path)).exists() + + def kill(self) -> None: + for job in self._bg_jobs: + try: + job.kill() + except Exception: + pass + self._bg_jobs.clear() + if not self._preserve_root: + shutil.rmtree(self._root_dir, ignore_errors=True) + + def _resolve_cwd(self, cwd: str | None) -> str: + candidate = self._resolve_path(cwd) if cwd else self._workdir + 
Path(candidate).mkdir(parents=True, exist_ok=True) + return candidate + + def _resolve_path(self, path: str | None) -> str: + if not path: + return self._workdir + if path == _CANONICAL_HOME or path.startswith(f"{_CANONICAL_HOME}/"): + suffix = path[len(_CANONICAL_HOME) :].lstrip("/") + return str(Path(self._home_dir) / suffix) if suffix else self._home_dir + if path == _CANONICAL_WORKDIR or path.startswith(f"{_CANONICAL_WORKDIR}/"): + suffix = path[len(_CANONICAL_WORKDIR) :].lstrip("/") + return str(Path(self._workdir) / suffix) if suffix else self._workdir + return path + + def _build_env(self, envs: dict[str, str] | None) -> dict[str, str]: + merged = os.environ.copy() + merged.update(self._default_envs) + merged.update(envs or {}) + merged.setdefault("HOME", self._home_dir) + merged.setdefault("TMPDIR", self._tmp_dir) + merged.setdefault("PIP_CACHE_DIR", str(Path(self._home_dir) / ".cache" / "pip")) + merged.setdefault("XDG_CACHE_HOME", str(Path(self._home_dir) / ".cache")) + merged.setdefault("PIP_DISABLE_PIP_VERSION_CHECK", "1") + return merged + + +class LocalSandboxBackend: + """Create host-local sandboxes rooted in unique temp directories.""" + + supports_images = False + + def __init__( + self, + *, + root_dir: str | None = None, + create_virtualenv: bool = True, + python_executable: str | None = None, + preserve_root: bool = False, + ) -> None: + self._root_dir = root_dir + self._create_virtualenv = create_virtualenv + self._python_executable = python_executable or sys.executable + self._preserve_root = preserve_root + + def create( + self, + *, + timeout_s: int = 900, + envs: dict[str, str] | None = None, + metadata: dict[str, str] | None = None, + image: str | None = None, + ) -> SandboxHandle: + del timeout_s, metadata, image + + if self._root_dir: + Path(self._root_dir).mkdir(parents=True, exist_ok=True) + root = tempfile.mkdtemp(prefix="openenv_local_", dir=self._root_dir) + home = str(Path(root) / "home") + workdir = str(Path(root) / "testbed") + 
tmp_dir = str(Path(root) / "tmp") + Path(home).mkdir(parents=True, exist_ok=True) + Path(workdir).mkdir(parents=True, exist_ok=True) + Path(tmp_dir).mkdir(parents=True, exist_ok=True) + + default_envs = dict(envs or {}) + default_envs.setdefault("HOME", home) + default_envs.setdefault("TMPDIR", tmp_dir) + default_envs.setdefault( + "PIP_CACHE_DIR", + os.environ.get("PIP_CACHE_DIR", str(Path(home) / ".cache" / "pip")), + ) + default_envs.setdefault( + "XDG_CACHE_HOME", + os.environ.get("XDG_CACHE_HOME", str(Path(home) / ".cache")), + ) + default_envs.setdefault("PIP_DISABLE_PIP_VERSION_CHECK", "1") + + venv_dir = Path(home) / "venv" + if self._create_virtualenv: + subprocess.run( + [self._python_executable, "-m", "venv", str(venv_dir)], + check=True, + capture_output=True, + text=True, + timeout=120, + ) + path_prefix = str(venv_dir / "bin") + default_envs["VIRTUAL_ENV"] = str(venv_dir) + default_envs["PATH"] = ( + f"{path_prefix}:{os.environ.get('PATH', '')}".rstrip(":") + ) + + return LocalSandboxHandle( + root_dir=root, + home_dir=home, + workdir=workdir, + tmp_dir=tmp_dir, + default_envs=default_envs, + preserve_root=self._preserve_root, + ) + + +__all__ = [ + "LocalBgJob", + "LocalSandboxBackend", + "LocalSandboxHandle", +] diff --git a/tests/core/test_local_sandbox_backend.py b/tests/core/test_local_sandbox_backend.py new file mode 100644 index 000000000..eed2ef1d8 --- /dev/null +++ b/tests/core/test_local_sandbox_backend.py @@ -0,0 +1,83 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +from __future__ import annotations + +from pathlib import Path + +from openenv.core.harness.sandbox.local_backend import LocalSandboxBackend + + +def test_local_sandbox_backend_basic_lifecycle(tmp_path): + backend = LocalSandboxBackend(root_dir=str(tmp_path), create_virtualenv=False) + sandbox = backend.create(envs={"OPENENV_TEST": "1"}) + + root = Path(sandbox.sandbox_home).parent + try: + result = sandbox.exec( + 'printf "%s|%s|%s" "$HOME" "$TMPDIR" "$OPENENV_TEST"', + cwd="/testbed", + ) + assert result.exit_code == 0 + assert result.stdout == f"{sandbox.sandbox_home}|{sandbox.tmp_dir}|1" + + sandbox.write_text("/testbed/hello.txt", "hello\n") + assert sandbox.read_text("/testbed/hello.txt") == "hello\n" + assert sandbox.exists("/testbed/hello.txt") + + job = sandbox.start_bg("sleep 0.1", cwd="/testbed") + assert job.wait(timeout=2.0) == 0 + finally: + sandbox.kill() + + assert not root.exists() + + +def test_local_sandbox_backend_creates_virtualenv(tmp_path): + backend = LocalSandboxBackend(root_dir=str(tmp_path), create_virtualenv=True) + sandbox = backend.create() + + try: + result = sandbox.exec( + 'python -c "import os,sys; print(sys.prefix); print(os.environ.get(\'VIRTUAL_ENV\', \'\'))"' + ) + assert result.exit_code == 0 + lines = result.stdout.strip().splitlines() + assert len(lines) == 2 + assert lines[0] == f"{sandbox.sandbox_home}/venv" + assert lines[1] == f"{sandbox.sandbox_home}/venv" + finally: + sandbox.kill() + + +def test_local_sandbox_backend_creates_missing_root_dir(tmp_path): + root_dir = tmp_path / "nested" / "sandboxes" + backend = LocalSandboxBackend(root_dir=str(root_dir), create_virtualenv=False) + sandbox = backend.create() + + try: + assert root_dir.exists() + assert Path(sandbox.sandbox_home).parent.parent == root_dir + finally: + sandbox.kill() + + +def test_local_sandbox_backend_inherits_host_cache_envs(tmp_path, monkeypatch): + monkeypatch.setenv("PIP_CACHE_DIR", "/shared/pip-cache") + monkeypatch.setenv("XDG_CACHE_HOME", 
"/shared/xdg-cache") + + backend = LocalSandboxBackend(root_dir=str(tmp_path), create_virtualenv=False) + sandbox = backend.create() + + try: + result = sandbox.exec( + 'printf "%s|%s" "$PIP_CACHE_DIR" "$XDG_CACHE_HOME"', + cwd="/testbed", + ) + assert result.exit_code == 0 + assert result.stdout == "/shared/pip-cache|/shared/xdg-cache" + finally: + sandbox.kill() diff --git a/tests/envs/test_swe_async_rollout_worker.py b/tests/envs/test_swe_async_rollout_worker.py index a8bb0e04c..24c26c15f 100644 --- a/tests/envs/test_swe_async_rollout_worker.py +++ b/tests/envs/test_swe_async_rollout_worker.py @@ -1,19 +1,128 @@ +import asyncio import sys from pathlib import Path +from types import SimpleNamespace + +import pytest _ROOT = Path(__file__).resolve().parents[2] -if str(_ROOT) not in sys.path: - sys.path.insert(0, str(_ROOT)) +for _p in (_ROOT, _ROOT / "src", _ROOT / "envs"): + if str(_p) not in sys.path: + sys.path.insert(0, str(_p)) from examples.mini_swe_env.async_grpo.rollout_worker import ( + _OMITTED_TOOL_OUTPUT_MARKER, + SWERolloutWorker, + WorkerConfig, + _clamp_max_completion_tokens, + _compute_group_advantages, + _coerce_token_ids, + _extract_xml_tool_calls, + _extract_chat_choice_logprobs, + _extract_prompt_token_ids, + _fit_messages_to_context_window, _get_tools, _has_answer_call, + _is_context_window_error, + _is_retriable_rollout_error, _is_terminal_non_tool_response, _make_chat_response, + _normalize_chat_choice_message, _normalize_tool_calls, + _retry_completion_tokens_from_context_error, + _truncate_messages_for_prompt_budget, + _truncate_text_middle, ) +def test_rollout_grades_partial_work_on_request_idle_timeout( + monkeypatch: pytest.MonkeyPatch, +) -> None: + monkeypatch.setattr(SWERolloutWorker, "_init_weight_transfer", lambda self: None) + + class _Session: + answer_called = False + + def __init__(self) -> None: + self.calls = 0 + self.delivered = [] + self.closed = False + + async def next_request(self, timeout_s: float) -> dict[str, object] | 
None: + self.calls += 1 + if self.calls == 1: + return { + "messages": [{"role": "user", "content": "fix it"}], + "body": {}, + } + raise TimeoutError("no request within timeout") + + async def deliver( + self, + intercept: dict[str, object], + response: dict[str, object], + ) -> None: + self.delivered.append((intercept, response)) + + def verify(self, transcript: list[object]) -> object: + return SimpleNamespace(env_reward=0.5) + + def close(self) -> None: + self.closed = True + + session = _Session() + + class _Factory: + def create(self, *, task: object, episode_id: str) -> _Session: + return session + + class _Tokenizer: + pad_token_id = 0 + + def apply_chat_template(self, messages: object, **kwargs: object) -> list[int]: + return [10, 11] + + worker = SWERolloutWorker( + session_factory=_Factory(), + tasks=[SimpleNamespace(instance_id="repo__issue-1", instruction="fix it")], + tokenizer=_Tokenizer(), + vllm_base_url="http://vllm", + vllm_api_key="token", + vllm_model="model", + config=WorkerConfig(request_timeout_s=0.01), + ) + monkeypatch.setattr( + worker, + "_generate", + lambda **kwargs: ( + [10, 11], + [12, 13], + [-0.2, -0.3], + _make_chat_response( + {"role": "assistant", "content": "ran a command"}, + model="model", + finish_reason="tool_calls", + ), + "tool_calls", + ), + ) + + sample = asyncio.run( + worker._rollout( # type: ignore[attr-defined] + worker._tasks[0], + "episode-1", + model_version=3, + ) + ) + + assert sample is not None + assert sample.reward == pytest.approx(0.5) + assert sample.model_version == 3 + assert sample.metrics["turns"] == 1.0 + assert sample.metrics["request_idle_timeout"] == 1.0 + assert session.closed is True + + def test_normalize_tool_calls_serializes_arguments_to_json_string() -> None: calls = _normalize_tool_calls( [ @@ -68,3 +177,167 @@ def test_get_tools_from_intercept_or_body() -> None: assert _get_tools({"tools": [tool_schema]}) == [tool_schema] assert _get_tools({"body": {"tools": [tool_schema]}}) == 
[tool_schema] assert _get_tools({}) is None + + +def test_coerce_token_ids_and_prompt_token_ids_are_int_lists() -> None: + assert _coerce_token_ids([1, "2", 3]) == [1, 2, 3] + assert _coerce_token_ids(["bad"]) == [] + assert _extract_prompt_token_ids({"prompt_token_ids": [4, "5"]}) == [4, 5] + + +def test_extract_chat_choice_logprobs_pads_missing_values() -> None: + choice = { + "logprobs": { + "content": [ + {"logprob": -0.1}, + {}, + {"logprob": -0.3}, + ] + } + } + assert _extract_chat_choice_logprobs(choice, expected_len=4) == pytest.approx( + [-0.1, 0.0, -0.3, 0.0] + ) + + +def test_extract_xml_tool_calls_recovers_qwen_style_blocks() -> None: + content, tool_calls = _extract_xml_tool_calls( + 'Working...\n\n{"name": "answer", "arguments": {}}\n\nDone.' + ) + assert content == "Working...\n\nDone." + assert len(tool_calls) == 1 + assert tool_calls[0]["function"]["name"] == "answer" + assert tool_calls[0]["function"]["arguments"] == "{}" + + +def test_normalize_chat_choice_message_parses_xml_tool_calls_from_content() -> None: + message = _normalize_chat_choice_message( + tokenizer=object(), + choice={ + "message": { + "role": "assistant", + "content": '\n{"name": "answer", "arguments": {}}\n', + "tool_calls": [], + } + }, + completion_ids=[], + ) + assert message["content"] == "" + assert message["tool_calls"][0]["function"]["name"] == "answer" + assert message["tool_calls"][0]["function"]["arguments"] == "{}" + + +def test_clamp_max_completion_tokens_reserves_context_margin() -> None: + assert _clamp_max_completion_tokens( + prompt_len=2561, + requested=1536, + max_model_len=4096, + ) == 1519 + + +def test_retry_completion_tokens_from_context_error_parses_vllm_message() -> None: + error_text = ( + "This model's maximum context length is 4096 tokens. However, you " + "requested 1536 output tokens and your prompt contains at least 2561 " + "input tokens, for a total of at least 4097 tokens." 
+ ) + assert _retry_completion_tokens_from_context_error(error_text) == ( + 1519, + 4096, + 2561, + ) + + +def test_truncate_text_middle_preserves_edges() -> None: + text = "A" * 80 + "B" * 80 + truncated, changed = _truncate_text_middle(text, max_chars=64) + assert changed is True + assert len(truncated) <= 64 + assert truncated.startswith("A" * 10) + assert truncated.endswith("B" * 8) + + +def test_truncate_messages_for_prompt_budget_only_trims_long_tool_output() -> None: + messages, tool_truncations, assistant_truncations = ( + _truncate_messages_for_prompt_budget( + [ + {"role": "user", "content": "short"}, + {"role": "tool", "content": "x" * 200}, + {"role": "assistant", "content": "ok"}, + ], + max_tool_message_chars=64, + max_assistant_message_chars=64, + ) + ) + assert tool_truncations == 1 + assert assistant_truncations == 0 + assert messages[0]["content"] == "short" + assert len(messages[1]["content"]) <= 64 + assert messages[2]["content"] == "ok" + + +def test_fit_messages_to_context_window_omits_oldest_tool_output_when_needed() -> None: + def render_prompt_ids( + messages: list[dict[str, object]], + tools: list[dict[str, object]] | None, + ) -> list[int]: + total = 0 + for message in messages: + content = message.get("content") + if isinstance(content, str): + total += len(content) + return list(range(total)) + + messages, prompt_ids = _fit_messages_to_context_window( + messages=[ + {"role": "user", "content": "task"}, + {"role": "tool", "content": "x" * 220}, + {"role": "assistant", "content": "thinking"}, + {"role": "tool", "content": "y" * 220}, + ], + tools=None, + render_prompt_ids=render_prompt_ids, + requested_completion_tokens=40, + max_model_len=116, + max_tool_message_chars=64, + min_tool_message_chars=32, + max_assistant_message_chars=32, + min_assistant_message_chars=16, + ) + assert len(prompt_ids) <= 60 + assert messages[1]["content"] == _OMITTED_TOOL_OUTPUT_MARKER + assert messages[3]["content"] == _OMITTED_TOOL_OUTPUT_MARKER + + +def 
test_context_window_errors_are_not_marked_retriable() -> None: + exc = RuntimeError( + "vllm 400: This model's maximum context length is 8192 tokens. " + "However, you requested 1 output tokens and your prompt contains " + "at least 8192 input tokens." + ) + assert _is_context_window_error(exc) is True + assert _is_retriable_rollout_error(exc) is False + + +def test_transient_sandbox_errors_are_marked_retriable() -> None: + exc = RuntimeError("HF sandbox tunnel failed: 429 Too Many Requests") + assert _is_context_window_error(exc) is False + assert _is_retriable_rollout_error(exc) is True + + +def test_compute_group_advantages_zscores_rewards() -> None: + advantages, reward_mean, reward_std = _compute_group_advantages( + [0.0, 0.0, 1.0, 1.0] + ) + assert reward_mean == pytest.approx(0.5) + assert reward_std == pytest.approx(0.5) + assert advantages == pytest.approx([-1.0, -1.0, 1.0, 1.0], abs=1e-6) + + +def test_compute_group_advantages_returns_zero_for_constant_rewards() -> None: + advantages, reward_mean, reward_std = _compute_group_advantages( + [1.0, 1.0, 1.0, 1.0] + ) + assert reward_mean == pytest.approx(1.0) + assert reward_std == pytest.approx(0.0) + assert advantages == pytest.approx([0.0, 0.0, 0.0, 0.0], abs=1e-6) diff --git a/tests/envs/test_swe_grading.py b/tests/envs/test_swe_grading.py new file mode 100644 index 000000000..7c0ca198b --- /dev/null +++ b/tests/envs/test_swe_grading.py @@ -0,0 +1,182 @@ +import sys +from pathlib import Path +from types import SimpleNamespace + +import pytest + +_ROOT = Path(__file__).resolve().parents[2] +if str(_ROOT) not in sys.path: + sys.path.insert(0, str(_ROOT)) +_ENVS = _ROOT / "envs" +if str(_ENVS) not in sys.path: + sys.path.insert(0, str(_ENVS)) + +from mini_swe_env.grading import GradingError, grade_from_case_results +from mini_swe_env.harness import ( + HOME, + SWESession, + SWEAgentConfig, + SWESessionFactory, + _wrap_instruction, +) +from mini_swe_env.models import SWEGymTask, SWETask +from 
openenv.core.harness.sandbox.base import ExecResult + + +def _task() -> SWEGymTask: + return SWEGymTask( + instance_id="demo__task-1", + repo="demo/repo", + base_commit="deadbeef", + problem_statement="Fix the bug.", + version="1.0", + patch="", + test_patch="", + FAIL_TO_PASS=["tests/test_a.py::test_fix"], + PASS_TO_PASS=["tests/test_b.py::test_regression"], + ) + + +def test_grade_from_case_results_binary_mode_stays_sparse() -> None: + grade = grade_from_case_results( + _task(), + { + "tests/test_a.py::test_fix": True, + "tests/test_b.py::test_regression": False, + }, + ) + assert grade.case_fraction == pytest.approx(0.5) + assert grade.reward == pytest.approx(0.0) + assert grade.resolved is False + + +def test_grade_from_case_results_case_fraction_mode_is_dense() -> None: + grade = grade_from_case_results( + _task(), + { + "tests/test_a.py::test_fix": True, + "tests/test_b.py::test_regression": False, + }, + reward_mode="case_fraction", + ) + assert grade.case_fraction == pytest.approx(0.5) + assert grade.reward == pytest.approx(0.5) + assert grade.resolved is False + + +def test_grade_from_case_results_rejects_unknown_reward_mode() -> None: + with pytest.raises(GradingError): + grade_from_case_results(_task(), {}, reward_mode="unknown") + + +def test_wrap_instruction_includes_optional_hints_block() -> None: + wrapped = _wrap_instruction( + "Fix the bug.", + hints_text="Look at module.py around line 10.", + workdir="/testbed", + ) + assert "" in wrapped + assert "Look at module.py around line 10." in wrapped + assert "You cannot continue working after submitting" in wrapped + assert "Each bash tool call runs in a fresh shell." in wrapped + assert "Test edits do not help." in wrapped + assert "Do not use `git commit`, `git branch`, or `git push`." in wrapped + assert "Start with the maintainer hints or issue description." 
in wrapped + + +def test_wrap_instruction_supports_fallback_only_grading_mode() -> None: + wrapped = _wrap_instruction( + "Fix the bug.", + workdir="/testbed", + answer_tool_enabled=False, + ) + assert "There is no `answer` tool in this run." in wrapped + assert "graded automatically when the session ends." in wrapped + assert "If the `answer` tool is available" in wrapped + + +def test_swe_session_verify_uses_host_fallback_grader() -> None: + task = SWETask( + task_id="task-1", + source="unit", + instance_id="demo__task-1", + repo="demo/repo", + base_commit="deadbeef", + instruction="Fix the bug.", + setup=[], + verify=[], + ) + session = SWESession( + swe_task=task, + spec=SimpleNamespace(name="opencode"), + sandbox=SimpleNamespace(), + task=SimpleNamespace(instruction="Fix the bug."), + config=SWEAgentConfig(sandbox_home=HOME, workdir="/testbed"), + ) + + def _grader(_sandbox, swe_task, *, home: str, workdir: str) -> tuple[float, bool]: + assert swe_task.instance_id == "demo__task-1" + assert home == HOME + assert workdir == "/testbed" + return 0.5, False + + session._fallback_grader = _grader + result = session.verify(transcript=[]) + + assert result.env_reward == pytest.approx(0.5) + assert result.metrics["reward_source"] == "host_verify_fallback" + assert result.metrics["resolved"] is False + + +def test_list_changed_test_paths_filters_to_test_like_files() -> None: + class _Sandbox: + def exec(self, cmd: str, *, cwd=None, timeout=None): # type: ignore[no-untyped-def] + if cmd == "git diff --name-only HEAD --": + return ExecResult( + exit_code=0, + stdout=( + "moto/ssm/models.py\n" + "tests/test_ssm/test_ssm_boto3.py\n" + "pkg/widget_test.py\n" + ), + stderr="", + ) + if cmd == "git ls-files --others --exclude-standard": + return ExecResult( + exit_code=0, + stdout="notes.txt\ntesting/helpers/new_case.py\n", + stderr="", + ) + raise AssertionError(f"unexpected command: {cmd}") + + paths = SWESessionFactory._list_changed_test_paths(_Sandbox(), 
workdir="/testbed") + assert paths == [ + "pkg/widget_test.py", + "testing/helpers/new_case.py", + "tests/test_ssm/test_ssm_boto3.py", + ] + + +def test_prepare_repo_uses_configurable_checkout_timeout(monkeypatch) -> None: + monkeypatch.setenv("SWE_GIT_CHECKOUT_TIMEOUT_S", "240") + calls = [] + + class _Sandbox: + def exec(self, cmd: str, *, cwd=None, timeout=None): # type: ignore[no-untyped-def] + calls.append((cmd, cwd, timeout)) + return ExecResult(exit_code=0, stdout="", stderr="") + + task = SimpleNamespace(repo="demo/repo", base_commit="deadbeef") + + SWESessionFactory._prepare_repo( + object(), + _Sandbox(), + task, + workdir="/testbed", + ) + + assert ( + "git checkout --quiet deadbeef", + "/testbed", + 240, + ) in calls diff --git a/tests/envs/test_train_swe_async_grpo.py b/tests/envs/test_train_swe_async_grpo.py new file mode 100644 index 000000000..3eacfcb60 --- /dev/null +++ b/tests/envs/test_train_swe_async_grpo.py @@ -0,0 +1,340 @@ +import sys +from pathlib import Path + +import pytest +import torch +from transformers.models.qwen3.configuration_qwen3 import Qwen3Config +from transformers.models.qwen3.modeling_qwen3 import Qwen3ForCausalLM + +from peft import LoraConfig, TaskType, get_peft_model + +_ROOT = Path(__file__).resolve().parents[2] +if str(_ROOT) not in sys.path: + sys.path.insert(0, str(_ROOT)) + +import examples.mini_swe_env.train_swe_async_grpo as train_mod +from examples.mini_swe_env.train_swe_async_grpo import ( + _chunked_logprob_backbone, + _default_lora_target_modules, + _filter_async_grpo_kwargs, + _lora_config_from_env, + _model_context_limit, + _patch_chunked_lm_head_for_wrapped_causal_lm, + _patch_lora_weight_streaming, + _patch_peft_lora_resume_tp_sharding_for_ddp, + _select_task_indices, + _sanitize_lora_merged_weight_name, +) + + +def test_select_task_indices_uses_offset_stride_and_repeat() -> None: + selected = _select_task_indices( + total_tasks=20, + max_tasks=3, + task_offset=2, + task_stride=4, + task_indices_raw="", + 
repeat_tasks=2, + ) + assert selected == [2, 6, 10, 2, 6, 10] + + +def test_select_task_indices_prefers_explicit_indices() -> None: + selected = _select_task_indices( + total_tasks=20, + max_tasks=5, + task_offset=0, + task_stride=1, + task_indices_raw="16,3,16", + repeat_tasks=1, + ) + assert selected == [16, 3, 16] + + +def test_select_task_indices_rejects_out_of_range_index() -> None: + with pytest.raises(IndexError): + _select_task_indices( + total_tasks=5, + max_tasks=2, + task_offset=0, + task_stride=1, + task_indices_raw="0,5", + repeat_tasks=1, + ) + + +def test_model_context_limit_prefers_first_positive_candidate( + monkeypatch: pytest.MonkeyPatch, +) -> None: + class _Cfg: + max_position_embeddings = 40960 + model_max_length = 32768 + + monkeypatch.setattr( + train_mod.AutoConfig, + "from_pretrained", + lambda model_name: _Cfg(), + ) + + assert _model_context_limit("Qwen/Qwen3-8B") == 40960 + + +def test_model_context_limit_reads_nested_text_config( + monkeypatch: pytest.MonkeyPatch, +) -> None: + class _TextCfg: + max_position_embeddings = 262144 + + class _Cfg: + max_position_embeddings = None + model_max_length = None + max_seq_len = None + seq_length = None + text_config = _TextCfg() + + monkeypatch.setattr( + train_mod.AutoConfig, + "from_pretrained", + lambda model_name: _Cfg(), + ) + + assert _model_context_limit("Qwen/Qwen3.5-4B") == 262144 + + +def test_model_context_limit_rejects_missing_values( + monkeypatch: pytest.MonkeyPatch, +) -> None: + class _Cfg: + max_position_embeddings = None + model_max_length = None + max_seq_len = None + seq_length = None + + monkeypatch.setattr( + train_mod.AutoConfig, + "from_pretrained", + lambda model_name: _Cfg(), + ) + + with pytest.raises(ValueError): + _model_context_limit("missing-context-model") + + +def test_default_lora_target_modules_for_qwen() -> None: + assert _default_lora_target_modules("Qwen/Qwen3-14B") == ( + "q_proj", + "k_proj", + "v_proj", + "o_proj", + "gate_proj", + "up_proj", + 
"down_proj", + ) + + +def test_filter_async_grpo_kwargs_accepts_ddp_lora_checkpointing_controls() -> None: + filtered = _filter_async_grpo_kwargs( + { + "gradient_checkpointing_kwargs": {"use_reentrant": False}, + "ddp_find_unused_parameters": False, + "not_a_training_arg": True, + } + ) + + assert filtered["gradient_checkpointing_kwargs"] == {"use_reentrant": False} + assert filtered["ddp_find_unused_parameters"] is False + assert "not_a_training_arg" not in filtered + + +def test_sanitize_lora_merged_weight_name_strips_wrapper_paths() -> None: + assert ( + _sanitize_lora_merged_weight_name( + "base_model.model.model.layers.0.self_attn.q_proj.base_layer.weight" + ) + == "model.layers.0.self_attn.q_proj.weight" + ) + assert ( + _sanitize_lora_merged_weight_name( + "base_model.model.model.layers.0.self_attn.q_proj.lora_A.default.weight" + ) + is None + ) + + +def test_lora_config_from_env_returns_full_when_disabled( + monkeypatch: pytest.MonkeyPatch, +) -> None: + monkeypatch.delenv("SWE_LORA", raising=False) + config, summary = _lora_config_from_env("Qwen/Qwen3-8B") + assert config is None + assert summary == "full" + + +def test_chunked_logprob_backbone_uses_decoder_for_wrapped_causal_lm() -> None: + cfg = Qwen3Config( + vocab_size=128, + hidden_size=64, + intermediate_size=128, + num_hidden_layers=2, + num_attention_heads=4, + num_key_value_heads=4, + ) + base_model = Qwen3ForCausalLM(cfg) + peft_model = get_peft_model( + base_model, + LoraConfig( + task_type=TaskType.CAUSAL_LM, + r=4, + lora_alpha=8, + target_modules=["q_proj", "k_proj", "v_proj", "o_proj"], + ), + ) + + assert _chunked_logprob_backbone(base_model) is base_model.model + assert _chunked_logprob_backbone(peft_model) is peft_model.model.model + + +def test_patch_chunked_lm_head_for_wrapped_causal_lm_handles_peft_model() -> None: + cfg = Qwen3Config( + vocab_size=128, + hidden_size=64, + intermediate_size=128, + num_hidden_layers=2, + num_attention_heads=4, + num_key_value_heads=4, + ) + peft_model 
= get_peft_model( + Qwen3ForCausalLM(cfg), + LoraConfig( + task_type=TaskType.CAUSAL_LM, + r=4, + lora_alpha=8, + target_modules=["q_proj", "k_proj", "v_proj", "o_proj"], + ), + ) + _patch_chunked_lm_head_for_wrapped_causal_lm( + peft_model, + temperature=1.0, + chunk_size=32, + ) + + input_ids = torch.randint(0, cfg.vocab_size, (2, 8)) + attention_mask = torch.ones_like(input_ids) + outputs = peft_model( + input_ids=input_ids, + attention_mask=attention_mask, + labels=input_ids, + completion_mask=attention_mask, + use_cache=False, + ) + + assert outputs["log_probs"].shape == (2, 7) + assert outputs["entropy"].shape == (2, 7) + + +def test_patch_lora_weight_streaming_unwraps_distributed_model() -> None: + cfg = Qwen3Config( + vocab_size=128, + hidden_size=64, + intermediate_size=128, + num_hidden_layers=2, + num_attention_heads=4, + num_key_value_heads=4, + ) + peft_model = get_peft_model( + Qwen3ForCausalLM(cfg), + LoraConfig( + task_type=TaskType.CAUSAL_LM, + r=4, + lora_alpha=8, + target_modules=["q_proj", "k_proj", "v_proj", "o_proj"], + ), + ) + + class _Wrapped(torch.nn.Module): + def __init__(self, module: torch.nn.Module) -> None: + super().__init__() + self.module = module + + class _Accelerator: + device = torch.device("cpu") + + @staticmethod + def unwrap_model(model: torch.nn.Module) -> torch.nn.Module: + return model.module + + class _Trainer: + model = _Wrapped(peft_model) + accelerator = _Accelerator() + + trainer = _Trainer() + _patch_lora_weight_streaming(trainer) # type: ignore[arg-type] + streamed = dict(trainer._streaming_iter()) # type: ignore[attr-defined] + + assert any(name.endswith("self_attn.q_proj.weight") for name in streamed) + assert all(".lora_" not in name for name in streamed) + + +def test_patch_peft_lora_resume_tp_sharding_skips_plain_ddp_model( + monkeypatch: pytest.MonkeyPatch, +) -> None: + import peft.utils.save_and_load as save_and_load + + def _raise_if_called( + model: torch.nn.Module, + state_dict: dict[str, torch.Tensor], 
+ adapter_name: str, + ) -> None: + raise AssertionError("TP sharding should not run for plain DDP models") + + monkeypatch.setattr( + save_and_load, + "_maybe_shard_state_dict_for_tp", + _raise_if_called, + ) + + model = torch.nn.Linear(2, 2) + _patch_peft_lora_resume_tp_sharding_for_ddp(model) + save_and_load._maybe_shard_state_dict_for_tp(model, {}, "default") + + +def test_patch_peft_lora_resume_tp_sharding_preserves_tp_model( + monkeypatch: pytest.MonkeyPatch, +) -> None: + import peft.utils.save_and_load as save_and_load + + called = False + + def _mark_called( + model: torch.nn.Module, + state_dict: dict[str, torch.Tensor], + adapter_name: str, + ) -> None: + nonlocal called + called = True + + monkeypatch.setattr( + save_and_load, + "_maybe_shard_state_dict_for_tp", + _mark_called, + ) + + class _Base(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self._hf_tp_plan = "colwise" + self._hf_device_mesh = object() + + class _LoraLike(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.base = _Base() + + def get_base_layer(self) -> torch.nn.Module: + return self.base + + model = _LoraLike() + _patch_peft_lora_resume_tp_sharding_for_ddp(model) + save_and_load._maybe_shard_state_dict_for_tp(model, {}, "default") + + assert called