Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions configs/skyrl-experiments/read-only.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ tools:
- glob
- grep
- terminal
- localization_finish

prompts:
system_prompt: "templates/system_prompt.j2"
Expand Down
1 change: 1 addition & 0 deletions configs/skyrl-experiments/terminal.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ reward:

tools:
- terminal
- localization_finish

prompts:
system_prompt: "templates/system_prompt.j2"
Expand Down
42 changes: 39 additions & 3 deletions src/generator/code_search_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,8 @@
LLMConvertibleEvent,
get_logger,
)
from openhands.sdk.event import ActionEvent
from src.tools.localization_finish import LocalizationFinishAction

from src.prompts.prompt_builder import get_instruction
from src.utils.instance import clone_instance
Expand All @@ -74,6 +76,35 @@

file_path = os.path.dirname(__file__)


def get_structured_locations(events: List[Event]) -> Optional[List[Dict[str, Any]]]:
"""Extract structured locations from LocalizationFinishAction in events.

Args:
events: List of conversation events to search through.

Returns:
List of location dicts with 'file', 'class', 'function' keys, or None if not found.
"""
# Find the last LocalizationFinishAction
for event in reversed(events):
if (
isinstance(event, ActionEvent)
and event.source == "agent"
and isinstance(event.action, LocalizationFinishAction)
):
# Extract structured locations from the action
locations = []
for loc in event.action.locations:
locations.append({
"file": loc.file,
"class": loc.class_name,
"function": loc.function_name,
})
return locations
return None


@ray.remote(num_cpus=0.01)
def init_and_run(
instance: dict,
Expand Down Expand Up @@ -156,6 +187,9 @@ def init_and_run(
messages = list(map(lambda event: event.model_dump(), conversation.state.events))
final_message = get_agent_final_response(conversation.state.events)

# Extract structured locations if available
structured_locations = get_structured_locations(conversation.state.events)

# remove the workspace dir
try:
if workspace.exists():
Expand All @@ -179,7 +213,7 @@ def init_and_run(
"end_timestamp": end_timestamp
}

return messages, final_message, additional_attr
return messages, final_message, structured_locations, additional_attr


class CodeSearchGenerator(SkyRLGymGenerator):
Expand Down Expand Up @@ -230,7 +264,7 @@ async def code_search_loop(
instance = env_extras
error = None
try:
messages, final_message, additional_attr = await init_and_run.remote(
messages, final_message, structured_locations, additional_attr = await init_and_run.remote(
instance,
self.litellm_model_name,
# sweagent_config,
Expand All @@ -249,6 +283,7 @@ async def code_search_loop(
error = str(e) + "\n" + traceback.format_exc()
messages = []
final_message = ""
structured_locations = None
additional_attr = {
"wall_clock_duration": 0.0,
"start_timestamp": None,
Expand All @@ -269,14 +304,15 @@ async def code_search_loop(
try:
input_args = {
"final_message": final_message,
"structured_locations": structured_locations,
"messages": messages,
"instance": instance,
}

reward_fn = get_reward_function(reward_fn_args["fn"])

input_args = {
**input_args,
**input_args,
**reward_fn_args.get("args", {})
}

Expand Down
96 changes: 75 additions & 21 deletions src/prompts/templates/system_prompt.j2
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,52 @@ You are given access to the codebase in a linux file system.
* Read targeted line ranges around matches using `sed -n 'START,ENDp'`
* Only read additional chunks if the initial sections are relevant

### Final Answer Format (REQUIRED)
- You MUST return your final answer in backticks ``` ... ```
- Format: ```\nfull_path1/file1.py\nclass: MyClass1\nfunction: my_function1\n\nfull_path2/file2.py\nfunction: MyClass2.my_function2\n\nfull_path3/file3.py\nfunction: my_function3\n```
- List one file path per line
- Use relative paths as they appear in the repository
- DO NOT include any other text inside the backticks
### Submitting Your Answer (REQUIRED)

When you have identified all relevant locations, use the `localization_finish` tool to submit your results.

**Format Requirements:**
Submit a structured list of locations. Each location is a JSON object with:
- `file`: Path to the file (REQUIRED)
- `class_name`: Class name (OPTIONAL - omit for file-level or standalone functions)
- `function_name`: Function/method name (OPTIONAL - omit for file-level or class-level only)

**When to include what:**
- File-level changes (imports, globals, new top-level classes): Just `file`
- Class-level changes (new methods, attributes, entire class): `file` + `class_name`
- Standalone function (top-level function): `file` + `function_name`
- Method in a class: `file` + `class_name` + `function_name`

**Example formats:**

1. File-only (imports, globals, new class):
```json
{"file": "path/to/file1.py"}
```

2. File + Class (class-level changes):
```json
{"file": "path/to/file2.py", "class_name": "MyClass"}
```

3. File + Function (standalone function):
```json
{"file": "path/to/file3.py", "function_name": "my_function"}
```

4. File + Class + Function (method):
```json
{"file": "path/to/file4.py", "class_name": "MyClass", "function_name": "my_method"}
```

5. Multiple locations:
```json
[
{"file": "src/parser.py", "class_name": "DataParser", "function_name": "parse_json"},
{"file": "src/models/user.py", "class_name": "User"},
{"file": "src/config.py"}
]
```

## SEARCH STRATEGY

Expand All @@ -61,32 +101,46 @@ You are given access to the codebase in a linux file system.
3. **Final Verification**: Confirm your file list
- Verify each candidate file is truly relevant
- Ensure you haven't missed related files
- Return your answer in backticks ``` ... ```
- Use the `localization_finish` tool to submit your answer

## CRITICAL RULES
- NEVER exceed 5 parallel bash tool calls in a single turn
- NEVER respond without wrapping your file list in backticks ```
- ALWAYS use the `localization_finish` tool when you're done
- ALWAYS use bash tool to search (do not guess file locations)
- NEVER read entire large files - always read in chunks (100-line ranges)
- Check file size with `wc -l` before reading
- Read file contents in chunks to verify relevance before including them
- Return file paths as they appear in the repository. Do not begin the path with "./"
- Aim for high precision (all files relevant) and high recall (no relevant files missed)
- Class and function names are OPTIONAL - only include when changes are at that level

## EXAMPLE OUTPUT
## EXAMPLE SUBMISSION

After exploring the codebase, return your answer like this:
When ready, call the `localization_finish` tool with your findings:

Your final output should list the locations requiring modification, wrapped with triple backticks ```
Each location should include the file path, class name (if applicable), and function name. Here is an example Output:
```json
[
{
"file": "src/utils/parser.py",
"class_name": "DataParser",
"function_name": "parse_json"
},
{
"file": "src/models/user.py",
"class_name": "User"
},
{
"file": "src/config.py"
},
{
"file": "src/api/endpoints.py",
"function_name": "handle_request"
}
]
```
full_path1/file1.py
class: MyClass1
function: my_function1

full_path2/file2.py
function: MyClass2.my_function2

full_path3/file3.py
function: my_function3
```
**Note:** In this example:
- `parser.py` has a specific method change (file + class + function)
- `user.py` has a class-level change (file + class only)
- `config.py` has file-level changes (file only)
- `endpoints.py` has a standalone function change (file + function only)
15 changes: 13 additions & 2 deletions src/rewards/file_localization/file_localization.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,15 @@ def file_localization_f1_reward(
final_message: str,
instance: dict,
file_level_weight: float=1.0,
structured_locations=None,
**kwargs
):
all_found_files, all_found_modules, all_found_entities = get_simple_results_from_raw_outputs(final_message)
# Use structured locations if available, otherwise parse final_message
if structured_locations is not None:
all_found_files, all_found_modules, all_found_entities = get_simple_results_from_raw_outputs(structured_locations)
else:
all_found_files, all_found_modules, all_found_entities = get_simple_results_from_raw_outputs(final_message)

true_files = set(x[0] for x in ast.literal_eval(instance["target"]))
file_level_score = compute_file_f1_score(all_found_files, true_files)
weighted_file_score = file_level_weight * file_level_score
Expand All @@ -42,6 +48,7 @@ def multilevel_localization_f1_reward(
file_level_weight: float=1.0,
module_level_weight: float=1.0,
entity_level_weight: float=1.0,
structured_locations=None,
**kwargs
):

Expand All @@ -67,7 +74,11 @@ def multilevel_localization_f1_reward(
gt_modules = set(gt_modules)
gt_entities = set(gt_entities)

predicted_files, predicted_modules, predicted_entities = get_simple_results_from_raw_outputs(final_message)
# Use structured locations if available, otherwise parse final_message
if structured_locations is not None:
predicted_files, predicted_modules, predicted_entities = get_simple_results_from_raw_outputs(structured_locations)
else:
predicted_files, predicted_modules, predicted_entities = get_simple_results_from_raw_outputs(final_message)

file_f1_score = compute_file_f1_score(predicted_files, gt_files)
module_f1_score = compute_file_f1_score(predicted_modules, gt_modules)
Expand Down
30 changes: 25 additions & 5 deletions src/rewards/file_localization/module_rewards.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,17 @@
import logging
from typing import Dict, List, Tuple
from typing import Dict, List, Tuple, Union

def parse_simple_output(raw_output: str) -> List[Dict[str, str]]:
def parse_simple_output(raw_output: Union[str, List[Dict[str, str]]]) -> List[Dict[str, str]]:
"""
Parse simplified agent output containing filename, optional class, and function.

Args:
raw_output: Raw text output from the agent
raw_output: Either a raw text string OR a list of location dicts (for structured input)

Returns:
List of dictionaries with keys: 'file', 'class' (optional), 'function'

Example input format:
Example string input format:
```
path/to/file1.py
class: MyClass
Expand All @@ -21,12 +21,32 @@ def parse_simple_output(raw_output: str) -> List[Dict[str, str]]:
function: standalone_function
```

Example output:
Example structured input format:
[
{'file': 'path/to/file1.py', 'class': 'MyClass', 'function': 'my_method'},
{'file': 'path/to/file2.py', 'class': None, 'function': 'standalone_function'}
]

Example output (same for both):
[
{'file': 'path/to/file1.py', 'class': 'MyClass', 'function': 'my_method'},
{'file': 'path/to/file2.py', 'class': None, 'function': 'standalone_function'}
]
"""
# Handle structured input (list of dicts)
if isinstance(raw_output, list):
# Already in the correct format (or close to it)
# Normalize field names: class_name -> class, function_name -> function
normalized = []
for loc in raw_output:
normalized.append({
'file': loc.get('file', ''),
'class': loc.get('class') or loc.get('class_name'),
'function': loc.get('function') or loc.get('function_name'),
})
return normalized

# Handle string input (legacy format)
# Remove triple backticks and whitespace
raw_output = raw_output.strip("` \n")

Expand Down
Loading