diff --git a/envs/coding_env/server/__init__.py b/envs/coding_env/server/__init__.py index dab6b748a..41d01bba7 100644 --- a/envs/coding_env/server/__init__.py +++ b/envs/coding_env/server/__init__.py @@ -4,8 +4,20 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -"""Coding environment server components.""" +"""Coding environment server components. -from .python_codeact_env import PythonCodeActEnv +Keep imports lazy so utility modules (for example transforms) remain importable +without pulling optional runtime dependencies like smolagents. +""" + +from typing import Any __all__ = ["PythonCodeActEnv"] + + +def __getattr__(name: str) -> Any: + if name == "PythonCodeActEnv": + from .python_codeact_env import PythonCodeActEnv + + return PythonCodeActEnv + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/envs/coding_env/server/app.py b/envs/coding_env/server/app.py index 4c712916b..2271b69de 100644 --- a/envs/coding_env/server/app.py +++ b/envs/coding_env/server/app.py @@ -21,9 +21,10 @@ python -m envs.coding_env.server.app """ +from openenv.core.env_server import create_app + from coding_env.models import CodeAction, CodeObservation from coding_env.server.python_codeact_env import PythonCodeActEnv -from openenv.core.env_server import create_app # Create the app with web interface and README integration # Pass the class (factory) instead of an instance for WebSocket session support diff --git a/envs/coding_env/server/python_codeact_env.py b/envs/coding_env/server/python_codeact_env.py index dbfc39e6a..043838096 100644 --- a/envs/coding_env/server/python_codeact_env.py +++ b/envs/coding_env/server/python_codeact_env.py @@ -12,6 +12,7 @@ """ import uuid +from typing import Any, Optional from openenv.core.env_server.interfaces import Action, Environment, Observation @@ -50,15 +51,32 @@ def __init__( self._executor = PyExecutor() self._state = CodeState() - 
def reset(self) -> Observation: + def reset( + self, + seed: Optional[int] = None, + episode_id: Optional[str] = None, + **kwargs: Any, + ) -> Observation: """ Reset environment and start fresh execution session. + Args: + seed: Accepted for API compatibility. This deterministic executor + has no random state to seed. + episode_id: Optional episode identifier override. + **kwargs: Forward-compatible reset parameters accepted by the base + Environment API but unused by this environment. + Returns: Initial observation with empty stdout/stderr and exit_code=0 """ + del seed, kwargs + # Initialize fresh state - self._state = CodeState(episode_id=str(uuid.uuid4()), step_count=0) + self._state = CodeState( + episode_id=episode_id if episode_id is not None else str(uuid.uuid4()), + step_count=0, + ) # Add last_exit_code to state self._state.last_exit_code = 0 @@ -77,7 +95,10 @@ def reset(self) -> Observation: return self._apply_transform(observation) - def step(self, action: Action) -> Observation: + def step( + self, + action: Action, + ) -> Observation: """ Execute code action and return observation. diff --git a/envs/coding_env/server/transforms.py b/envs/coding_env/server/transforms.py index fc92e89ba..5baed87ce 100644 --- a/envs/coding_env/server/transforms.py +++ b/envs/coding_env/server/transforms.py @@ -7,7 +7,6 @@ """Transforms specific to coding environments.""" import ast -import re from openenv.core.env_server.base_transforms import CompositeTransform from openenv.core.env_server.interfaces import Transform @@ -17,18 +16,52 @@ class CodeSafetyTransform(Transform): - """Evaluates code safety and assigns penalties for dangerous patterns.""" + """ + Assign penalties for obviously unsafe coding patterns. + + This is a reward heuristic, not a security sandbox. Container isolation is + the security boundary; this transform only shapes rewards for common cases. 
+ """ def __init__(self, penalty: float = -1.0): self.penalty = penalty - self.dangerous_patterns = [ - r"import\s+os", - r"import\s+subprocess", - r"eval\(", - r"exec\(", - r"__import__", - r"open\(", - ] + + def _detect_violation(self, code: str) -> str | None: + """ + Detect dangerous operations using AST analysis. + + AST-based detection avoids false positives from harmless string literals + (e.g. ``print("import os")``) or similarly named user functions + (e.g. ``myopen()``). + """ + try: + tree = ast.parse(code) + except (SyntaxError, RecursionError, ValueError): + # Intentional trade-off: once the code is syntactically invalid or + # pathologically nested, this AST-only safety pass cannot reliably + # inspect partial code. CodeQualityTransform applies the syntax + # penalty instead. + return None + + for node in ast.walk(tree): + if isinstance(node, ast.Import): + for alias in node.names: + top_level_module = alias.name.split(".", 1)[0] + if top_level_module in {"os", "subprocess"}: + return f"import {top_level_module}" + + if isinstance(node, ast.ImportFrom) and node.module: + top_level_module = node.module.split(".", 1)[0] + if top_level_module in {"os", "subprocess"}: + return f"import {top_level_module}" + + if isinstance(node, ast.Call): + if isinstance(node.func, ast.Name): + called_name = node.func.id + if called_name in {"eval", "exec", "open", "__import__"}: + return called_name + + return None def __call__(self, observation: Observation) -> Observation: if not isinstance(observation, CodeObservation): @@ -36,14 +69,12 @@ def __call__(self, observation: Observation) -> Observation: if "last_code" in observation.metadata: code = observation.metadata["last_code"] - for pattern in self.dangerous_patterns: - if re.search(pattern, code): - observation.reward = self.penalty - observation.metadata["safety_violation"] = pattern - break - else: - if observation.reward is None: - observation.reward = 0.0 + violation = self._detect_violation(code) + if 
def _apply_safety_transform(code: str) -> CodeObservation:
    """Run ``code`` through a default ``CodeSafetyTransform``.

    The snippet is attached as ``metadata["last_code"]`` on an observation
    with empty stdout/stderr and ``exit_code=0``; the transformed observation
    is returned after checking it is still a ``CodeObservation``.
    """
    result = CodeSafetyTransform()(
        CodeObservation(
            stdout="",
            stderr="",
            exit_code=0,
            metadata={"last_code": code},
        )
    )
    assert isinstance(result, CodeObservation)
    return result
def test_quality_transform_handles_ast_recursion_error(monkeypatch):
    """CodeQualityTransform applies the syntax penalty when parsing overflows.

    ``ast.parse`` is replaced with a stub that always raises RecursionError;
    with the concise bonus zeroed out, the resulting reward must equal the
    configured syntax penalty.
    """

    def always_overflow(_code: str):
        raise RecursionError("pathologically nested code")

    monkeypatch.setattr("coding_env.server.transforms.ast.parse", always_overflow)

    quality = CodeQualityTransform(concise_bonus=0.0, syntax_penalty=-0.2)
    transformed = quality(
        CodeObservation(
            stdout="",
            stderr="",
            exit_code=0,
            metadata={"last_code": "x = 1"},
        )
    )

    assert isinstance(transformed, CodeObservation)
    assert transformed.reward == -0.2