diff --git a/envs/repl_env/rubrics.py b/envs/repl_env/rubrics.py index 7c677f6d4..f971bb50d 100644 --- a/envs/repl_env/rubrics.py +++ b/envs/repl_env/rubrics.py @@ -21,6 +21,12 @@ from typing import Any, Callable from openenv.core.rubrics.base import Rubric +from openenv.core.rubrics.components import ( + RewardComponent, + RewardComponentType, + aggregate_weighted_sum, + serialize_reward_components, +) class ExactMatchRubric(Rubric): @@ -173,16 +179,61 @@ def set_expected(self, expected: str | None) -> None: if hasattr(self.outcome, "set_expected"): self.outcome.set_expected(expected) - def forward(self, action: Any, observation: Any) -> float: + def evaluate_components(self, action: Any, observation: Any) -> list[RewardComponent]: + """Compute structured reward components for the current step.""" done = getattr(observation, "done", False) if done: final = getattr(observation, "metadata", {}).get("final_answer") if final is not None: - return self.outcome(action, observation) - # Done but no final answer (max iterations exhausted) - return self.failure_reward - # Non-terminal step: process reward only - return self.process(action, observation) + return [ + RewardComponent( + name="outcome_match", + type=RewardComponentType.BINARY, + value=float(self.outcome(action, observation)), + weight=1.0, + terminal_only=True, + ) + ] + return [ + RewardComponent( + name="max_iterations_failure", + type=RewardComponentType.PENALTY, + value=float(self.failure_reward), + weight=1.0, + terminal_only=True, + ) + ] + + return [ + RewardComponent( + name="code_execution_quality", + type=RewardComponentType.SHAPING, + value=float(self.process(action, observation)), + weight=1.0, + terminal_only=False, + ) + ] + + def _attach_reward_metadata( + self, + observation: Any, + components: list[RewardComponent], + total: float, + ) -> None: + """Attach component diagnostics to observation metadata if available.""" + if not hasattr(observation, "metadata"): + return + metadata = getattr(observation, "metadata", None) or {} + metadata["reward_components"] = serialize_reward_components(components) + metadata["reward_total"] = total + metadata["reward_aggregation"] = "weighted_sum_v1" + observation.metadata = metadata + + def forward(self, action: Any, observation: Any) -> float: + components = self.evaluate_components(action, observation) + total = aggregate_weighted_sum(components) + self._attach_reward_metadata(observation, components, total) + return total def reset(self) -> None: self.outcome.reset() diff --git a/src/openenv/core/rubrics/__init__.py b/src/openenv/core/rubrics/__init__.py index abe368494..d7a51849c 100644 --- a/src/openenv/core/rubrics/__init__.py +++ b/src/openenv/core/rubrics/__init__.py @@ -10,6 +10,12 @@ """ from openenv.core.rubrics.base import Rubric +from openenv.core.rubrics.components import ( + RewardComponent, + RewardComponentType, + aggregate_weighted_sum, + serialize_reward_components, +) from openenv.core.rubrics.containers import ( Gate, RubricDict, @@ -26,6 +32,11 @@ __all__ = [ # Base "Rubric", + # Components + "RewardComponent", + "RewardComponentType", + "aggregate_weighted_sum", + "serialize_reward_components", # Containers "Sequential", "Gate", diff --git a/src/openenv/core/rubrics/components.py b/src/openenv/core/rubrics/components.py new file mode 100644 index 000000000..23bab094e --- /dev/null +++ b/src/openenv/core/rubrics/components.py @@ -0,0 +1,65 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Reward component schema and helpers. + +This module provides a standard representation for decomposed reward signals +while preserving a scalar optimization target. +""" + +from enum import Enum +from typing import Any, Dict, List + +from pydantic import BaseModel, ConfigDict, Field + + +class RewardComponentType(str, Enum): + """Common reward component styles for training diagnostics.""" + + BINARY = "binary" + SPARSE = "sparse" + DENSE = "dense" + SHAPING = "shaping" + PENALTY = "penalty" + + +class RewardComponent(BaseModel): + """Structured reward component emitted by a rubric/environment.""" + + model_config = ConfigDict(extra="forbid") + + name: str = Field(..., description="Stable component identifier") + type: RewardComponentType = Field(..., description="Reward component style") + value: float = Field(..., description="Raw component value before weighting") + weight: float = Field(default=1.0, description="Aggregation weight") + weighted_value: float | None = Field( + default=None, + description="Optional explicit weighted value (defaults to value * weight)", + ) + terminal_only: bool = Field( + default=False, + description="Whether this component is meaningful only at terminal steps", + ) + metadata: Dict[str, Any] = Field( + default_factory=dict, + description="Additional component-specific diagnostics", + ) + + def effective_weighted_value(self) -> float: + """Return weighted component value, respecting explicit overrides.""" + if self.weighted_value is not None: + return self.weighted_value + return self.value * self.weight + + +def aggregate_weighted_sum(components: List[RewardComponent]) -> float: + """Aggregate reward components with weighted-sum semantics.""" + return float(sum(component.effective_weighted_value() for component in components)) + + +def serialize_reward_components(components: List[RewardComponent]) -> List[Dict[str, Any]]: + """Serialize reward components for observation metadata.""" + return [component.model_dump() for component in components] diff --git a/tests/core/test_rubrics/test_reward_components.py b/tests/core/test_rubrics/test_reward_components.py new file mode 100644 index 000000000..156d3fa16 --- /dev/null +++ b/tests/core/test_rubrics/test_reward_components.py @@ -0,0 +1,68 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Tests for reward component schema and helpers.""" + +from openenv.core.rubrics.components import ( + RewardComponent, + RewardComponentType, + aggregate_weighted_sum, + serialize_reward_components, +) + + +class TestRewardComponent: + def test_weighted_value_defaults_to_value_times_weight(self): + component = RewardComponent( + name="progress", + type=RewardComponentType.DENSE, + value=0.6, + weight=0.5, + ) + assert component.effective_weighted_value() == 0.3 + + def test_explicit_weighted_value_takes_precedence(self): + component = RewardComponent( + name="safety_penalty", + type=RewardComponentType.PENALTY, + value=-1.0, + weight=0.2, + weighted_value=-0.5, + ) + assert component.effective_weighted_value() == -0.5 + + +class TestRewardComponentHelpers: + def test_aggregate_weighted_sum(self): + components = [ + RewardComponent( + name="success", + type=RewardComponentType.BINARY, + value=1.0, + weight=0.7, + ), + RewardComponent( + name="format", + type=RewardComponentType.SPARSE, + value=1.0, + weight=0.3, + ), + ] + assert aggregate_weighted_sum(components) == 1.0 + + def test_serialize_reward_components(self): + components = [ + RewardComponent( + name="step_quality", + type=RewardComponentType.SHAPING, + value=-0.05, + ) + ] + payload = serialize_reward_components(components) + assert len(payload) == 1 + assert payload[0]["name"] == "step_quality" + assert payload[0]["type"] == "shaping" + assert payload[0]["value"] == -0.05 diff --git a/tests/envs/test_repl_env.py b/tests/envs/test_repl_env.py index 4811119fb..855770665 100644 --- a/tests/envs/test_repl_env.py +++ b/tests/envs/test_repl_env.py @@ -289,6 +289,13 @@ def test_rubric_reward_on_success(self): obs = env.step(REPLAction(code="print('FINAL(done)')")) assert obs.done assert obs.reward == 1.0 + assert obs.metadata["reward_total"] == 1.0 + assert obs.metadata["reward_aggregation"] == "weighted_sum_v1" + assert len(obs.metadata["reward_components"]) == 1 + component = obs.metadata["reward_components"][0] + assert component["name"] == "outcome_match" + assert component["type"] == "binary" + assert component["terminal_only"] is True def test_rubric_reward_on_wrong_answer(self): """Test rubric reward when final answer does not match expected.""" @@ -299,6 +306,9 @@ def test_rubric_reward_on_wrong_answer(self): obs = env.step(REPLAction(code="print('FINAL(wrong)')")) assert obs.done assert obs.reward == 0.0 + assert obs.metadata["reward_total"] == 0.0 + assert obs.metadata["reward_components"][0]["name"] == "outcome_match" + assert obs.metadata["reward_components"][0]["type"] == "binary" def test_rubric_reward_on_error(self): """Test rubric process reward on code error.""" @@ -306,6 +316,10 @@ def test_rubric_reward_on_error(self): env.reset() obs = env.step(REPLAction(code="raise ValueError()")) assert obs.reward == -0.05 # default CodeExecutionRubric error_penalty + assert obs.metadata["reward_total"] == -0.05 + assert obs.metadata["reward_components"][0]["name"] == "code_execution_quality" + assert obs.metadata["reward_components"][0]["type"] == "shaping" + assert obs.metadata["reward_components"][0]["terminal_only"] is False def test_close(self): """Test close cleans up resources."""