Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,8 @@ explicit = true
flash-attn = ["torch"]

[tool.uv.sources]
skyrl-train = { git = "https://github.com/adityasoni9998/SkyRL.git", rev = "81e5a97c7430503c0c4e6508497cc5aa01a0c624", subdirectory = "skyrl-train" }
# skyrl-train = { git = "https://github.com/NovaSky-AI/SkyRL", rev = "69ca4d9", subdirectory = "skyrl-train" }
skyrl-train = { path = "/project/flame/lsutawik/cso/SkyRL/skyrl-train", editable = true }
flash-attn = {url = "https://github.com/Dao-AILab/flash-attention/releases/download/v2.8.3/flash_attn-2.8.3+cu12torch2.8cxx11abiTRUE-cp313-cp313-linux_x86_64.whl"}
openhands-sdk = { workspace = true }
openhands-tools = { workspace = true }
Expand All @@ -95,4 +96,4 @@ members = [
"software-agent-sdk/openhands-tools",
"software-agent-sdk/openhands-workspace",
"software-agent-sdk/openhands-agent-server",
]
]
11 changes: 5 additions & 6 deletions src/build_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@

from datasets import load_dataset

# from src.utils.dataset import extract_functions_from_patch
from src.utils.dataset import extract_functions_from_patch


def main():
parser = argparse.ArgumentParser(description="Build dataset from patches")
parser.add_argument("--dataset", default="adityasoni17/SWE-smith-py-code-search", help="Input dataset path")
parser.add_argument("--dataset", default="SWE-Gym/SWE-Gym", help="Input dataset path")
parser.add_argument("--split", default="train", help="Dataset split to use")
parser.add_argument("--output", required=True, help="Output file path for processed dataset")
parser.add_argument("--use_patch", action="store_true", help="Whether to use patches to extract target functions")
Expand All @@ -18,7 +18,7 @@ def main():
dataset = load_dataset(args.dataset, split=args.split).to_pandas()

dataset["target"] = dataset.apply(
lambda row: row["file_changes"], axis=1
lambda row: f"{extract_functions_from_patch(row['patch'])}", axis=1
)

# Remove rows with empty problem_statement
Expand All @@ -30,7 +30,6 @@ def main():

if args.use_patch:
dataset["use_patch"] = True
dataset["base_commit"] = None
else:
dataset["use_patch"] = False

Expand All @@ -41,7 +40,7 @@ def main():
pass

# shuffle dataset
dataset = dataset.sample(frac=1, random_state=42).reset_index(drop=True)
dataset = dataset.sample(frac=1).reset_index(drop=True)

# train_size = int(0.975 * len(dataset))
train_dataset = dataset.iloc[:-100]
Expand All @@ -59,4 +58,4 @@ def main():


if __name__ == "__main__":
main()
main()
89 changes: 37 additions & 52 deletions src/generator/code_search_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,9 +165,6 @@ def init_and_run(
system_prompt_path = os.path.join(prompts_base_dir, generator_cfg.prompts.system_prompt)
user_prompt_path = os.path.join(prompts_base_dir, generator_cfg.prompts.user_prompt)

assert os.path.exists(system_prompt_path), f"System prompt file {system_prompt_path} does not exist"
assert os.path.exists(user_prompt_path), f"User prompt file {user_prompt_path} does not exist"

agent = CustomAgent(
llm=LLM(
usage_id="agent",
Expand All @@ -178,14 +175,14 @@ def init_and_run(
litellm_extra_body={
"return_token_ids": True,
"include_stop_str_in_output": False,
"add_generation_prompt": True,
"chat_template_kwargs": {
"add_generation_prompt": True,
"enable_thinking": False
}
}
"enable_thinking": False,
}
},
),
tools=tools,
# security_analyzer=None,
security_analyzer=None,
system_prompt_filename=system_prompt_path
)

Expand All @@ -196,49 +193,43 @@ def init_and_run(
workspace=str(working_dir),
)
input_message = get_instruction(instance, user_prompt_path, str(working_dir))
conversation.send_message(input_message)

logger.info("Conversation Starting")

# Capture start time
start_time = time.time()
start_timestamp = datetime.now().isoformat()

try:
conversation.send_message(input_message)
logger.info("Conversation Starting")
conversation.run()
messages = list(map(lambda event: event.model_dump(), conversation.state.events))
final_message = get_agent_final_response(conversation.state.events)
structured_locations = get_structured_locations(conversation.state.events)
except Exception as e:
logger.error(f"Error during conversation: {str(e)}", exc_info=True)
try:
messages = list(map(lambda event: event.model_dump(), conversation.state.events))
final_message = get_agent_final_response(conversation.state.events)
structured_locations = get_structured_locations(conversation.state.events)
except Exception as e:
logger.error(f"Error during final message extraction in err'ed rollout: {str(e)}", exc_info=True)
messages = []
final_message = ""
finally:
# Capture end time
try:
if workspace.exists():
os.system(f"rm -rf {str(workspace)}")
logger.info(f"Removed workspace {str(workspace)}")
conversation.close()
except Exception as _:
pass
logger.info("Conversation Finished")
end_time = time.time()
end_timestamp = datetime.now().isoformat()
wall_clock_duration = end_time - start_time

additional_attr = {
"wall_clock_duration": wall_clock_duration,
"start_timestamp": start_timestamp,
"end_timestamp": end_timestamp
}
logger.error(f"Error during conversation run: {e}", exc_info=True)

messages = list(map(lambda event: event.model_dump(), conversation.state.events))
final_message = get_agent_final_response(conversation.state.events)
structured_locations = get_structured_locations(conversation.state.events)
try:
if workspace.exists():
os.system(f"rm -rf {str(workspace)}")
logger.info(f"Removed workspace {str(workspace)}")
conversation.close()
except Exception as _:
pass
conversation.close()
logger.info("Conversation Finished")

# Capture end time
end_time = time.time()
end_timestamp = datetime.now().isoformat()
wall_clock_duration = end_time - start_time

additional_attr = {
"wall_clock_duration": wall_clock_duration,
"start_timestamp": start_timestamp,
"end_timestamp": end_timestamp
}

# NOTE: Hard-coded final message to ensure all rollouts that don't call the custom finish tool have reward == 0
return messages, final_message, structured_locations, additional_attr


Expand Down Expand Up @@ -269,7 +260,7 @@ def __init__(
self.tokenizer = tokenizer
self.model_name = model_name
# self.litellm_model_name = "openai/" + self.model_name
self.litellm_model_name = "openai/" + self.model_name
self.litellm_model_name = "litellm_proxy/" + self.model_name

# if self.generator_cfg.chat_template.name_or_path is not None:
# raise NotImplementedError(
Expand Down Expand Up @@ -329,7 +320,7 @@ async def code_search_loop(
batch_metadata.training_phase,
)
except Exception as e:
logger.error(f"Critical Error in conversation: {str(e)}", exc_info=True)
logger.error(f"Error in starting conversation: {e}", exc_info=True)
# TODO properly handle this
error = str(e) + "\n" + traceback.format_exc()
messages = []
Expand Down Expand Up @@ -415,7 +406,6 @@ async def code_search_loop(

token_messages = [msg for msg in messages if msg["kind"] == "TokenEvent"]
rollout_list = []
num_steps = len(token_messages)
if len(token_messages) > 0:
if self.step_wise:
for idx, message in enumerate(token_messages):
Expand Down Expand Up @@ -500,11 +490,9 @@ async def code_search_loop(
)

else:
# Ideally the code should not reach here
logger.info("IMPORTANT_ERROR: No TokenEvents found in the conversation. Saving an error rollout with minimal data.")
response_ids = [151643]
stop_reason = "error"
loss_mask = [0] # NOTE: Mask out loss completely
loss_mask = [1]
initial_input_ids = [151643]
trajectory_metrics = {} # Empty metrics for error case
rollout_list.append(
Expand Down Expand Up @@ -544,10 +532,7 @@ async def code_search_loop(
os.makedirs(os.path.dirname(filename_path), exist_ok=True)

# get everything between ```` with regex
try:
raw_final_message = json.dumps(structured_locations) if structured_locations is not None else final_message
except Exception as e:
raw_final_message = ""
raw_final_message = final_message
matches = re.findall(r"```(.*?)```", final_message, re.DOTALL)
parsed_final_message = matches[-1] if matches else final_message

Expand Down