Skip to content
16 changes: 14 additions & 2 deletions envs/coding_env/server/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,20 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""Coding environment server components."""
"""Coding environment server components.

from .python_codeact_env import PythonCodeActEnv
Keep imports lazy so utility modules (for example transforms) remain importable
without pulling optional runtime dependencies like smolagents.
"""

from typing import Any

__all__ = ["PythonCodeActEnv"]


def __getattr__(name: str) -> Any:
    """Resolve package attributes lazily on first access.

    Only ``PythonCodeActEnv`` is provided; it is imported from
    ``.python_codeact_env`` at lookup time rather than at package import
    time. Every other name raises ``AttributeError``.

    Args:
        name: Attribute being looked up on this module.

    Returns:
        The ``PythonCodeActEnv`` class when ``name`` matches it.

    Raises:
        AttributeError: For any other attribute name.
    """
    # Guard clause first: unknown names never trigger the deferred import.
    if name != "PythonCodeActEnv":
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

    from .python_codeact_env import PythonCodeActEnv

    return PythonCodeActEnv
3 changes: 2 additions & 1 deletion envs/coding_env/server/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,10 @@
python -m envs.coding_env.server.app
"""

from openenv.core.env_server import create_app

from coding_env.models import CodeAction, CodeObservation
from coding_env.server.python_codeact_env import PythonCodeActEnv
from openenv.core.env_server import create_app

# Create the app with web interface and README integration
# Pass the class (factory) instead of an instance for WebSocket session support
Expand Down
27 changes: 24 additions & 3 deletions envs/coding_env/server/python_codeact_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
"""

import uuid
from typing import Any, Optional

from openenv.core.env_server.interfaces import Action, Environment, Observation

Expand Down Expand Up @@ -50,15 +51,32 @@ def __init__(
self._executor = PyExecutor()
self._state = CodeState()

def reset(self) -> Observation:
def reset(
self,
seed: Optional[int] = None,
episode_id: Optional[str] = None,
**kwargs: Any,
) -> Observation:
"""
Reset environment and start fresh execution session.

Args:
seed: Accepted for API compatibility. This deterministic executor
has no random state to seed.
episode_id: Optional episode identifier override.
**kwargs: Forward-compatible reset parameters accepted by the base
Environment API but unused by this environment.

Returns:
Initial observation with empty stdout/stderr and exit_code=0
"""
del seed, kwargs

# Initialize fresh state
self._state = CodeState(episode_id=str(uuid.uuid4()), step_count=0)
self._state = CodeState(
episode_id=episode_id if episode_id is not None else str(uuid.uuid4()),
step_count=0,
)
# Add last_exit_code to state
self._state.last_exit_code = 0

Expand All @@ -77,7 +95,10 @@ def reset(self) -> Observation:

return self._apply_transform(observation)

def step(self, action: Action) -> Observation:
def step(
self,
action: Action,
) -> Observation:
"""
Execute code action and return observation.

Expand Down
69 changes: 50 additions & 19 deletions envs/coding_env/server/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
"""Transforms specific to coding environments."""

import ast
import re

from openenv.core.env_server.base_transforms import CompositeTransform
from openenv.core.env_server.interfaces import Transform
Expand All @@ -17,33 +16,65 @@


class CodeSafetyTransform(Transform):
    """
    Assign penalties for obviously unsafe coding patterns.

    This is a reward heuristic, not a security sandbox. Container isolation is
    the security boundary; this transform only shapes rewards for common cases.

    Detection is AST-based: absolute imports of a small set of modules and
    bare-name calls to a small set of builtins are reported. Text inside
    string literals and attribute calls such as ``db.exec(...)`` are not.
    """

    # Top-level modules whose absolute import is penalized.
    _DANGEROUS_MODULES = frozenset({"os", "subprocess"})
    # Builtins penalized when called by their bare name.
    _DANGEROUS_CALLS = frozenset({"eval", "exec", "open", "__import__"})

    def __init__(self, penalty: float = -1.0):
        """
        Args:
            penalty: Reward assigned to an observation whose code contains a
                detected violation.
        """
        self.penalty = penalty

    def _detect_violation(self, code: str) -> str | None:
        """
        Detect dangerous operations using AST analysis.

        AST-based detection avoids false positives from harmless string literals
        (e.g. ``print("import os")``) or similarly named user functions
        (e.g. ``myopen()``).

        Args:
            code: Source text to inspect.

        Returns:
            ``"import <module>"`` for a flagged import, the callable's name for
            a flagged call, or ``None`` when nothing is flagged or the code
            cannot be parsed.
        """
        try:
            tree = ast.parse(code)
        except (SyntaxError, RecursionError, ValueError):
            # Intentional trade-off: once the code is syntactically invalid or
            # pathologically nested, this AST-only safety pass cannot reliably
            # inspect partial code. CodeQualityTransform applies the syntax
            # penalty instead.
            return None

        for node in ast.walk(tree):
            if isinstance(node, ast.Import):
                for alias in node.names:
                    # "os.path" and "os" both resolve to the "os" package.
                    top_level_module = alias.name.split(".", 1)[0]
                    if top_level_module in self._DANGEROUS_MODULES:
                        return f"import {top_level_module}"
            elif isinstance(node, ast.ImportFrom):
                # level > 0 is a relative import (``from .os import x``), which
                # names a package-local module rather than the stdlib one.
                if node.module and node.level == 0:
                    top_level_module = node.module.split(".", 1)[0]
                    if top_level_module in self._DANGEROUS_MODULES:
                        return f"import {top_level_module}"
            elif isinstance(node, ast.Call) and isinstance(node.func, ast.Name):
                # Only bare-name calls count; ``obj.exec()`` is an Attribute.
                called_name = node.func.id
                if called_name in self._DANGEROUS_CALLS:
                    return called_name

        return None

    def __call__(self, observation: Observation) -> Observation:
        """
        Apply the safety penalty to a code observation.

        Observations that are not ``CodeObservation`` instances, or that carry
        no ``"last_code"`` metadata entry, are returned unchanged.

        Args:
            observation: Observation produced by the environment.

        Returns:
            The same observation object. On a violation its reward is set to
            ``self.penalty`` and ``metadata["safety_violation"]`` records what
            was found; otherwise a ``None`` reward is normalized to ``0.0``.
        """
        if not isinstance(observation, CodeObservation):
            return observation

        if "last_code" not in observation.metadata:
            return observation

        violation = self._detect_violation(observation.metadata["last_code"])
        if violation is not None:
            observation.reward = self.penalty
            observation.metadata["safety_violation"] = violation
        elif observation.reward is None:
            # Leave an already-assigned reward from an earlier transform alone.
            observation.reward = 0.0

        return observation

Expand Down Expand Up @@ -77,7 +108,7 @@ def __call__(self, observation: Observation) -> Observation:
# Check syntax (redundant but useful for quality assessment)
try:
ast.parse(code)
except SyntaxError:
except (SyntaxError, RecursionError, ValueError):
quality_score += self.syntax_penalty

# Add to existing reward
Expand Down
119 changes: 119 additions & 0 deletions tests/envs/test_coding_safety_transform.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""Tests for coding_env safety transform false-positive handling."""

from coding_env.models import CodeObservation
from coding_env.server.transforms import CodeQualityTransform, CodeSafetyTransform


def _apply_safety_transform(code: str) -> CodeObservation:
    """Run a default ``CodeSafetyTransform`` over an observation carrying ``code``.

    The observation is built with empty stdout/stderr, exit code 0 and the
    snippet stored under the ``"last_code"`` metadata key.
    """
    source_observation = CodeObservation(
        stdout="",
        stderr="",
        exit_code=0,
        metadata={"last_code": code},
    )
    result = CodeSafetyTransform()(source_observation)
    assert isinstance(result, CodeObservation)
    return result


def test_blocks_real_dangerous_import():
    """A real ``import os`` statement is penalized and reported as ``import os``."""
    observation = _apply_safety_transform("import os\nprint('x')")
    assert observation.reward == -1.0
    # Pin the reported value (not just key presence) so a regression that
    # records the wrong violation is caught, matching the sibling tests.
    assert observation.metadata["safety_violation"] == "import os"


def test_blocks_subprocess_import():
    """A plain ``import subprocess`` is penalized and reported by module name."""
    result = _apply_safety_transform("import subprocess")
    expected_reward = -1.0
    assert result.reward == expected_reward
    assert result.metadata["safety_violation"] == "import subprocess"


def test_blocks_from_subprocess_import():
    """``from subprocess import ...`` is reported the same way as ``import subprocess``."""
    snippet = "from subprocess import run"
    result = _apply_safety_transform(snippet)
    assert result.reward == -1.0
    assert result.metadata["safety_violation"] == "import subprocess"


def test_blocks_from_os_path_import():
    """Importing from a submodule (``os.path``) is attributed to top-level ``os``."""
    snippet = "from os.path import join"
    result = _apply_safety_transform(snippet)
    assert result.reward == -1.0
    assert result.metadata["safety_violation"] == "import os"


def test_blocks_builtin_open_call():
    """A bare call to the builtin ``open`` is penalized and reported as ``open``."""
    observation = _apply_safety_transform(
        "with open('f.txt') as f:\n data = f.read()"
    )
    assert observation.reward == -1.0
    # ``f.read()`` is an attribute call and is not reported, so the recorded
    # violation must be exactly the bare ``open`` call.
    assert observation.metadata["safety_violation"] == "open"


def test_blocks_builtin_eval_call():
    """A bare ``eval(...)`` call is penalized and reported as ``eval``."""
    snippet = "result = eval('1 + 1')"
    result = _apply_safety_transform(snippet)
    assert result.reward == -1.0
    assert result.metadata["safety_violation"] == "eval"


def test_blocks_builtin_exec_call():
    """A bare ``exec(...)`` call is penalized and reported as ``exec``."""
    snippet = "exec('x = 1')"
    result = _apply_safety_transform(snippet)
    assert result.reward == -1.0
    assert result.metadata["safety_violation"] == "exec"


def test_blocks_builtin_import_call():
    """A bare ``__import__(...)`` call is penalized and reported by name."""
    snippet = "__import__('os')"
    result = _apply_safety_transform(snippet)
    assert result.reward == -1.0
    assert result.metadata["safety_violation"] == "__import__"


def test_does_not_flag_string_literal_with_dangerous_text():
    """Dangerous-looking text inside a string literal is not a violation."""
    snippet = "print('import os')"
    result = _apply_safety_transform(snippet)
    assert result.reward == 0.0
    assert "safety_violation" not in result.metadata


def test_does_not_flag_user_defined_myopen_function():
    """A user function whose name merely contains ``open`` is not a violation."""
    snippet = "def myopen():\n return 1\nresult = myopen()"
    result = _apply_safety_transform(snippet)
    assert result.reward == 0.0
    assert "safety_violation" not in result.metadata


def test_does_not_flag_attribute_method_named_exec():
    """Calling a method named ``exec`` on an object is not a violation."""
    # The snippet must be valid Python with properly nested indentation:
    # unparsable code also yields "no violation", which would let this test
    # pass without ever exercising the attribute-call path.
    observation = _apply_safety_transform(
        "class DB:\n"
        "    def exec(self, sql):\n"
        "        return sql\n"
        "db = DB()\n"
        "result = db.exec('SELECT 1')"
    )
    assert observation.reward == 0.0
    assert "safety_violation" not in observation.metadata


def test_quality_transform_handles_ast_recursion_error(monkeypatch):
    """CodeQualityTransform yields the syntax penalty when parsing raises RecursionError."""

    def _always_overflow(_code: str):
        raise RecursionError("pathologically nested code")

    # Force the parser used by the transforms module to fail on any input.
    monkeypatch.setattr("coding_env.server.transforms.ast.parse", _always_overflow)

    quality = CodeQualityTransform(concise_bonus=0.0, syntax_penalty=-0.2)
    result = quality(
        CodeObservation(
            stdout="",
            stderr="",
            exit_code=0,
            metadata={"last_code": "x = 1"},
        )
    )

    assert isinstance(result, CodeObservation)
    assert result.reward == -0.2
20 changes: 20 additions & 0 deletions tests/envs/test_python_codeact_reset.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,3 +166,23 @@ def test_reset_changes_episode_id():

# Episode IDs should be different
assert episode_id_1 != episode_id_2


def test_reset_accepts_episode_id_override():
    """reset() stores a caller-supplied episode_id and starts at step 0."""
    requested_id = "episode-123"
    environment = PythonCodeActEnv()

    environment.reset(episode_id=requested_id)

    assert environment.state.episode_id == requested_id
    assert environment.state.step_count == 0


def test_reset_preserves_empty_episode_id_override():
    """reset() keeps an explicit empty-string episode_id instead of replacing it."""
    requested_id = ""
    environment = PythonCodeActEnv()

    environment.reset(episode_id=requested_id)

    assert environment.state.episode_id == requested_id
    assert environment.state.step_count == 0