diff --git a/pyproject.toml b/pyproject.toml index cce1f28..1f7c964 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -76,7 +76,8 @@ explicit = true flash-attn = ["torch"] [tool.uv.sources] -skyrl-train = { git = "https://github.com/adityasoni9998/SkyRL.git", rev = "81e5a97c7430503c0c4e6508497cc5aa01a0c624", subdirectory = "skyrl-train" } +# skyrl-train = { git = "https://github.com/NovaSky-AI/SkyRL", rev = "69ca4d9", subdirectory = "skyrl-train" } +skyrl-train = { path = "/project/flame/lsutawik/cso/SkyRL/skyrl-train", editable = true } flash-attn = {url = "https://github.com/Dao-AILab/flash-attention/releases/download/v2.8.3/flash_attn-2.8.3+cu12torch2.8cxx11abiTRUE-cp313-cp313-linux_x86_64.whl"} openhands-sdk = { workspace = true } openhands-tools = { workspace = true } @@ -95,4 +96,4 @@ members = [ "software-agent-sdk/openhands-tools", "software-agent-sdk/openhands-workspace", "software-agent-sdk/openhands-agent-server", -] +] \ No newline at end of file diff --git a/src/build_dataset.py b/src/build_dataset.py index b28ae50..6dda444 100644 --- a/src/build_dataset.py +++ b/src/build_dataset.py @@ -3,12 +3,12 @@ from datasets import load_dataset -# from src.utils.dataset import extract_functions_from_patch +from src.utils.dataset import extract_functions_from_patch def main(): parser = argparse.ArgumentParser(description="Build dataset from patches") - parser.add_argument("--dataset", default="adityasoni17/SWE-smith-py-code-search", help="Input dataset path") + parser.add_argument("--dataset", default="SWE-Gym/SWE-Gym", help="Input dataset path") parser.add_argument("--split", default="train", help="Dataset split to use") parser.add_argument("--output", required=True, help="Output file path for processed dataset") parser.add_argument("--use_patch", action="store_true", help="Whether to use patches to extract target functions") @@ -18,7 +18,7 @@ def main(): dataset = load_dataset(args.dataset, split=args.split).to_pandas() dataset["target"] = dataset.apply( - lambda row: row["file_changes"], axis=1 + lambda row: f"{extract_functions_from_patch(row['patch'])}", axis=1 ) # Remove rows with empty problem_statement @@ -30,7 +30,6 @@ def main(): if args.use_patch: dataset["use_patch"] = True - dataset["base_commit"] = None else: dataset["use_patch"] = False @@ -41,7 +40,7 @@ def main(): pass # shuffle dataset - dataset = dataset.sample(frac=1, random_state=42).reset_index(drop=True) + dataset = dataset.sample(frac=1).reset_index(drop=True) # train_size = int(0.975 * len(dataset)) train_dataset = dataset.iloc[:-100] @@ -59,4 +58,4 @@ def main(): if __name__ == "__main__": - main() + main() \ No newline at end of file diff --git a/src/generator/code_search_generator.py b/src/generator/code_search_generator.py index 9431c9d..52d6e71 100644 --- a/src/generator/code_search_generator.py +++ b/src/generator/code_search_generator.py @@ -165,9 +165,6 @@ def init_and_run( system_prompt_path = os.path.join(prompts_base_dir, generator_cfg.prompts.system_prompt) user_prompt_path = os.path.join(prompts_base_dir, generator_cfg.prompts.user_prompt) - assert os.path.exists(system_prompt_path), f"System prompt file {system_prompt_path} does not exist" - assert os.path.exists(user_prompt_path), f"User prompt file {user_prompt_path} does not exist" - agent = CustomAgent( llm=LLM( usage_id="agent", @@ -178,14 +175,14 @@ def init_and_run( litellm_extra_body={ "return_token_ids": True, "include_stop_str_in_output": False, + "add_generation_prompt": True, "chat_template_kwargs": { - "add_generation_prompt": True, - "enable_thinking": False - } - } + "enable_thinking": False, + } + }, ), tools=tools, - # security_analyzer=None, + security_analyzer=None, system_prompt_filename=system_prompt_path ) @@ -196,49 +193,43 @@ def init_and_run( workspace=str(working_dir), ) input_message = get_instruction(instance, user_prompt_path, str(working_dir)) + conversation.send_message(input_message) + + logger.info("Conversation Starting") # Capture start time start_time = time.time() start_timestamp = datetime.now().isoformat() try: - conversation.send_message(input_message) - logger.info("Conversation Starting") conversation.run() - messages = list(map(lambda event: event.model_dump(), conversation.state.events)) - final_message = get_agent_final_response(conversation.state.events) - structured_locations = get_structured_locations(conversation.state.events) except Exception as e: - logger.error(f"Error during conversation: {str(e)}", exc_info=True) - try: - messages = list(map(lambda event: event.model_dump(), conversation.state.events)) - final_message = get_agent_final_response(conversation.state.events) - structured_locations = get_structured_locations(conversation.state.events) - except Exception as e: - logger.error(f"Error during final message extraction in err'ed rollout: {str(e)}", exc_info=True) - messages = [] - final_message = "" - finally: - # Capture end time - try: - if workspace.exists(): - os.system(f"rm -rf {str(workspace)}") - logger.info(f"Removed workspace {str(workspace)}") - conversation.close() - except Exception as _: - pass - logger.info("Conversation Finished") - end_time = time.time() - end_timestamp = datetime.now().isoformat() - wall_clock_duration = end_time - start_time - - additional_attr = { - "wall_clock_duration": wall_clock_duration, - "start_timestamp": start_timestamp, - "end_timestamp": end_timestamp - } + logger.error(f"Error during conversation run: {e}", exc_info=True) + + messages = list(map(lambda event: event.model_dump(), conversation.state.events)) + final_message = get_agent_final_response(conversation.state.events) + structured_locations = get_structured_locations(conversation.state.events) + try: + if workspace.exists(): + os.system(f"rm -rf {str(workspace)}") + logger.info(f"Removed workspace {str(workspace)}") + conversation.close() + except Exception as _: + pass + conversation.close() + logger.info("Conversation Finished") + + # Capture end time + end_time = time.time() + end_timestamp = datetime.now().isoformat() + wall_clock_duration = end_time - start_time + + additional_attr = { + "wall_clock_duration": wall_clock_duration, + "start_timestamp": start_timestamp, + "end_timestamp": end_timestamp + } - # NOTE: Hard-coded final message to ensure all rollouts that don't call the custom finish tool have reward == 0 return messages, final_message, structured_locations, additional_attr @@ -269,7 +260,7 @@ def __init__( self.tokenizer = tokenizer self.model_name = model_name # self.litellm_model_name = "openai/" + self.model_name - self.litellm_model_name = "openai/" + self.model_name + self.litellm_model_name = "litellm_proxy/" + self.model_name # if self.generator_cfg.chat_template.name_or_path is not None: # raise NotImplementedError( @@ -329,7 +320,7 @@ async def code_search_loop( batch_metadata.training_phase, ) except Exception as e: - logger.error(f"Critical Error in conversation: {str(e)}", exc_info=True) + logger.error(f"Error in starting conversation: {e}", exc_info=True) # TODO properly handle this error = str(e) + "\n" + traceback.format_exc() messages = [] @@ -415,7 +406,6 @@ async def code_search_loop( token_messages = [msg for msg in messages if msg["kind"] == "TokenEvent"] rollout_list = [] - num_steps = len(token_messages) if len(token_messages) > 0: if self.step_wise: for idx, message in enumerate(token_messages): @@ -500,11 +490,9 @@ async def code_search_loop( ) else: - # Ideally the code should not reach here - logger.info("IMPORTANT_ERROR: No TokenEvents found in the conversation. Saving an error rollout with minimal data.") response_ids = [151643] stop_reason = "error" - loss_mask = [0] # NOTE: Mask out loss completely + loss_mask = [1] initial_input_ids = [151643] trajectory_metrics = {} # Empty metrics for error case rollout_list.append( @@ -544,10 +532,7 @@ async def code_search_loop( os.makedirs(os.path.dirname(filename_path), exist_ok=True) # get everything between ```` with regex - try: - raw_final_message = json.dumps(structured_locations) if structured_locations is not None else final_message - except Exception as e: - raw_final_message = "" + raw_final_message = final_message matches = re.findall(r"```(.*?)```", final_message, re.DOTALL) parsed_final_message = matches[-1] if matches else final_message