diff --git a/README_Training.md b/README_Training.md
index 86df63f..4564024 100644
--- a/README_Training.md
+++ b/README_Training.md
@@ -6,7 +6,7 @@
 uv run src/build_dataset.py --output ../data/
 ```
 
-## Train Model
+## Train Model with RLVR
 
 ```
 bash scripts run_training.sh -m Qwen/Qwen3-0.6B -d
@@ -24,4 +24,26 @@ bash scripts/run_async_training.sh \
   -o "+generator.exp_config=configs/skyrl-experiments/read-only.yaml" \
   -d $DATA_PATH \
   2>&1 | tee training.log
-```
\ No newline at end of file
+```
+
+## Train Model with On-Policy Distillation
+Pass the student model (the model being trained) via `-m` and the teacher model (the model distilled from) via `-r`.
+```
+DATA_PATH=
+bash scripts/run_distillation.sh \
+    -m Qwen/Qwen3-4B \
+    -r Qwen/Qwen3-8B \
+    -d $DATA_PATH \
+    2>&1 | tee distillation.log
+```
+
+```
+DATA_PATH=
+bash scripts/run_distillation.sh \
+    -m Qwen/Qwen3-4B \
+    -r Qwen/Qwen3-8B \
+    -o "+generator.exp_config=configs/skyrl-experiments/read-only.yaml" \
+    -d $DATA_PATH \
+    2>&1 | tee distillation.log
+```
+
diff --git a/scripts/distill.sh b/scripts/distill.sh
new file mode 100644
index 0000000..c085f45
--- /dev/null
+++ b/scripts/distill.sh
@@ -0,0 +1,10 @@
+#!/bin/bash
+
+# Loop over 10
+for i in $(seq 1 10)
+do
+    echo "Run number: $i"
+    # Kill any process using port 8080 after 4 hours
+    ( sleep 14400 && fuser -k 8080/tcp ) & \
+    bash scripts/run_distillation.sh "$@"
+done
diff --git a/scripts/run_distillation.sh b/scripts/run_distillation.sh
new file mode 100644
index 0000000..43b7c0f
--- /dev/null
+++ b/scripts/run_distillation.sh
@@ -0,0 +1,130 @@
+#!/bin/bash
+#
+# Usage: bash scripts/run_distillation.sh \
+#   -m Qwen/Qwen3-4B \    # Student model (model to be trained)
+#   -r Qwen/Qwen3-32B \   # Reference/Teacher model (model to distill from)
+#   -d data/swe_gym \     # Data path
+#   [-s ckpt_path] [-n n_rollouts] [-i num_inference_engines] [-t num_training_engines]
+#
+
+. 
.env 2>/dev/null || true + +while getopts ":m:r:n:d:s:o:i:t:b:" opt; do + case ${opt} in + m ) STUDENT_MODEL=$OPTARG;; # -m: Student model (model to be trained) + r ) TEACHER_MODEL=$OPTARG;; # -r: Reference/Teacher model (model to distill from) + n ) N_ROLLOUTS=$OPTARG;; + d ) DATA_PATH=$OPTARG;; + s ) CKPT_PATH=$OPTARG;; + o ) OTHER_OPTION=$OPTARG;; + i ) NUM_INFERENCE_ENGINES=$OPTARG;; + t ) NUM_TRAINING_ENGINES=$OPTARG;; + b ) MICRO_BATCH_SIZE=$OPTARG;; + \? ) echo "Usage: $0 -m -r [-d data_path] [-s ckpt_path] [-n n_rollouts] [-i num_inference_engines] [-t num_training_engines] [-b micro_batch_size] [-o other_options]"; exit 1;; + esac +done + +# Validate required parameters +if [ -z "$STUDENT_MODEL" ]; then + echo "Error: Student model (-m) is required" + echo "Usage: $0 -m -r -d " + exit 1 +fi +if [ -z "$TEACHER_MODEL" ]; then + echo "Error: Teacher model (-r) is required" + echo "Usage: $0 -m -r -d " + exit 1 +fi + +STUDENT_MODEL_ALIAS=$(echo $STUDENT_MODEL | sed 's/\//-/g') +TEACHER_MODEL_ALIAS=$(echo $TEACHER_MODEL | sed 's/\//-/g') +# Get number of GPUs available +NUM_GPUS=$(nvidia-smi -L | wc -l) +N_ROLLOUTS="${N_ROLLOUTS:-8}" +BATCH_SIZE=16 # Must be <= num_parallel_generation_workers (set to 16 below) +MAX_LENGTH=8192 +RUN_NAME="code_search_distillation_${STUDENT_MODEL_ALIAS}_${TEACHER_MODEL_ALIAS}" +set -x + +DATA_PATH="${DATA_PATH:-data/swe_smith}" +CKPT_PATH="${CKPT_PATH:-$(pwd)/ckpts/${STUDENT_MODEL_ALIAS}}" +mkdir -p $CKPT_PATH + +HALF_NUM_GPUS=$((NUM_GPUS / 2)) +NUM_INFERENCE_ENGINES="${NUM_INFERENCE_ENGINES:-$NUM_GPUS}" +NUM_TRAINING_ENGINES="${NUM_TRAINING_ENGINES:-$NUM_GPUS}" + +export VLLM_FLASH_ATTN_VERSION=2 +export CUDA_LAUNCH_BLOCKING=1 +export TORCH_USE_CUDA_DSA=1 + + +uv run python -m src.train \ + +run_async_trainer=false \ + +use_distillation=true \ + data.train_data="['$DATA_PATH/train.parquet']" \ + data.val_data="['$DATA_PATH/validation.parquet']" \ + trainer.algorithm.advantage_estimator="no_op" \ + 
trainer.algorithm.policy_loss_type="importance_sampling" \
+    trainer.algorithm.use_kl_in_reward=true \
+    trainer.algorithm.use_kl_loss=false \
+    trainer.policy.model.path=${STUDENT_MODEL} \
+    trainer.ref.model.path=${TEACHER_MODEL} \
+    trainer.placement.colocate_all=true \
+    trainer.placement.colocate_policy_ref=true \
+    trainer.strategy=fsdp2 \
+    trainer.policy.fsdp_config.cpu_offload=true \
+    trainer.policy.fsdp_config.reshard_after_forward=true \
+    trainer.policy.fsdp_config.fsdp_size=-1 \
+    trainer.fully_async.num_parallel_generation_workers=16 \
+    trainer.placement.policy_num_gpus_per_node=${NUM_TRAINING_ENGINES} \
+    trainer.placement.ref_num_gpus_per_node=${NUM_TRAINING_ENGINES} \
+    trainer.placement.policy_num_nodes=1 \
+    trainer.placement.ref_num_nodes=1 \
+    trainer.policy.sequence_parallel_size=1 \
+    generator.num_inference_engines=${NUM_INFERENCE_ENGINES} \
+    generator.inference_engine_tensor_parallel_size=1 \
+    +generator.traj_dir=${CKPT_PATH}/trajectories/ \
+    +generator.engine_init_kwargs.enable_auto_tool_choice=true \
+    +generator.engine_init_kwargs.tool_call_parser=hermes \
+    +generator.engine_init_kwargs.reasoning_parser=qwen3 \
+    trainer.epochs=20 \
+    trainer.eval_batch_size=100 \
+    trainer.eval_before_train=false \
+    trainer.eval_interval=100 \
+    trainer.update_epochs_per_batch=1 \
+    trainer.train_batch_size=${BATCH_SIZE} \
+    trainer.policy_mini_batch_size=${BATCH_SIZE} \
+    trainer.micro_forward_batch_size_per_gpu=1 \
+    trainer.micro_train_batch_size_per_gpu=${MICRO_BATCH_SIZE:-1} \
+    trainer.dump_data_batch=true \
+    trainer.export_path="${CKPT_PATH}/exported_model/" \
+    trainer.hf_save_interval=5 \
+    trainer.ckpt_interval=5 \
+    trainer.max_prompt_length=4096 \
+    generator.sampling_params.max_generate_length=${MAX_LENGTH} \
+    generator.sampling_params.temperature=1.0 \
+    generator.max_input_length=32768 \
+    generator.max_num_batched_tokens=131072 \
+    generator.max_turns=4 \
+    trainer.policy.optimizer_config.lr=1.0e-6 \
+    
trainer.algorithm.use_kl_loss=False \ + generator.backend=vllm \ + generator.run_engines_locally=True \ + generator.enable_http_endpoint=True \ + generator.http_endpoint_host='0.0.0.0' \ + generator.http_endpoint_port=8080 \ + generator.weight_sync_backend=nccl \ + generator.async_engine=true \ + generator.batched=false \ + generator.n_samples_per_prompt=${N_ROLLOUTS} \ + generator.gpu_memory_utilization=0.75 \ + generator.enforce_eager=false \ + trainer.step_wise_training=true \ + trainer.logger="wandb" \ + trainer.project_name="code_search" \ + trainer.run_name=${RUN_NAME} \ + trainer.resume_mode=latest \ + trainer.ckpt_path="$CKPT_PATH" \ + trainer.max_ckpts_to_keep=3 \ + $OTHER_OPTION diff --git a/src/distiller.py b/src/distiller.py new file mode 100644 index 0000000..e2a0c0a --- /dev/null +++ b/src/distiller.py @@ -0,0 +1,64 @@ +import torch +import ray +from omegaconf import DictConfig +from skyrl_train.entrypoints.main_base import BasePPOExp +import hydra +from skyrl_train.trainer import RayPPOTrainer +from skyrl_train.utils import initialize_ray +from skyrl_train.entrypoints.main_base import config_dir, validate_cfg +from skyrl_train.utils.ppo_utils import ( + register_advantage_estimator, + register_policy_loss, + reduce_loss, +) +from skyrl_train.training_batch import TrainingInputBatch +from skyrl_train.fully_async_trainer import FullyAsyncRayPPOTrainer + +def apply_reward_kl_penalty(data: TrainingInputBatch) -> TrainingInputBatch: + """Computes the KL penalty and sets the rewards to the KL penalty""" + loss_masks_all: torch.Tensor = data["loss_mask"] + teacher_action_log_probs: torch.Tensor = data["base_action_log_probs"] + action_log_probs: torch.Tensor = data["action_log_probs"] + rewards = -(action_log_probs - teacher_action_log_probs) * loss_masks_all + data["rewards"] = rewards + return data + +class OnPolicyDistillationTrainer(RayPPOTrainer): + """ + Custom trainer for On Policy Distillation. 
+ + Overrides the apply_reward_kl_penalty method to set the rewards just to the kl penalty + """ + + def apply_reward_kl_penalty( + self, + data: TrainingInputBatch, + ) -> TrainingInputBatch: + """Computes the KL penalty and sets the rewards to the KL penalty""" + return apply_reward_kl_penalty(data) + +class FullyAsyncOnPolicyDistillationTrainer(FullyAsyncRayPPOTrainer): + def apply_reward_kl_penalty( + self, + data: TrainingInputBatch, + ) -> TrainingInputBatch: + return apply_reward_kl_penalty(data) + + +# Using the decorator +@register_advantage_estimator("no_op") +def compute_no_op_advantage(token_level_rewards: torch.Tensor, **kwargs): + # just pass through the rewards + return token_level_rewards, token_level_rewards + + +@register_policy_loss("importance_sampling") +def compute_importance_sampling_policy_loss( + log_probs, old_log_probs, advantages, config, loss_mask=None, rollout_logprobs=None, **kwargs +): + # as defined here: https://tinker-docs.thinkingmachines.ai/losses#policy-gradient-importance_sampling + loss = -torch.exp(log_probs - old_log_probs) * advantages + + loss = reduce_loss(loss, loss_mask, "seq_mean_token_sum_norm", config.max_seq_len) + # return loss and a dummy clip ratio value as we aren't clipping here + return loss, 0.0 diff --git a/src/generator/code_search_generator.py b/src/generator/code_search_generator.py index 0674634..81c9946 100644 --- a/src/generator/code_search_generator.py +++ b/src/generator/code_search_generator.py @@ -168,6 +168,15 @@ def init_and_run( messages = list(map(lambda event: event.model_dump(), conversation.state.events)) final_message = get_agent_final_response(conversation.state.events) + # remove the workspace dir + try: + if workspace.exists(): + os.system(f"rm -rf {str(workspace)}") + logger.info(f"Removed workspace {str(workspace)}") + except Exception as e: + logger.error(f"Error removing workspace {str(workspace)}: {e}", exc_info=True) + + conversation.close() logger.info("Conversation Finished") @@ 
-354,31 +363,44 @@ async def code_search_loop( current_response_ids = current_response_ids[len(current_prompt_ids):] max_response_len = max_train_len - len(current_prompt_ids) - + mask = [1]*len(token_messages[0]["response_token_ids"]) + for i in range(1, len(token_messages)): + mask += [0] * (len(token_messages[i]["prompt_token_ids"]) - len(token_messages[i-1]["prompt_token_ids"]) - len(token_messages[i-1]["response_token_ids"])) + mask += [1] * len(token_messages[i]["response_token_ids"]) # make mask of 0 for everything inside <|im_start|> # and assistant and 1 elsewhere start_token_id = self.tokenizer.convert_tokens_to_ids("<|im_start|>") end_token_id = self.tokenizer.convert_tokens_to_ids("assistant") + end_of_turn_token_id = self.tokenizer.convert_tokens_to_ids("<|im_end|>") mask = [] + found_role_switch = False inside = False - for token_id in current_response_ids: - if token_id == start_token_id: - inside = True - mask.append(0) - elif token_id == end_token_id: - inside = False - mask.append(0) + idx = 0 + while idx < len(current_response_ids): + token_id = current_response_ids[idx] + if not inside: + mask.append(1) + idx += 1 + if token_id == end_of_turn_token_id: + inside = True else: - if inside: + if token_id == start_token_id: + inside = True + mask.append(0) + idx += 1 + elif token_id == end_token_id and found_role_switch: + inside = False mask.append(0) + mask.append(0) + idx += 2 else: - mask.append(1) + mask.append(0) + idx += 1 - # mask zero out everything beyond max_response_len - # Don't truncate the response, just mask out the loss - if len(current_response_ids) > max_response_len: - for i in range(max_response_len, len(current_response_ids)): - mask[i] = 0 + if token_id == start_token_id: + found_role_switch = True + else: + found_role_switch = False rollout_list.append( ( @@ -509,7 +531,7 @@ async def generate(self, input_batch: GeneratorInput) -> GeneratorOutput: for step_id in range(len(step_outputs)): out_trajectory_id = 
copy.deepcopy(trajectory_ids[i]) out_trajectory_id.step = step_id - out_trajectory_ids.append(out_trajectory_id.instance_id) + out_trajectory_ids.append(out_trajectory_id) is_last_step.append(step_id == len(step_outputs) - 1) if not len(responses): diff --git a/src/prompts/prompt_builder.py b/src/prompts/prompt_builder.py index 23a70db..16d85c1 100644 --- a/src/prompts/prompt_builder.py +++ b/src/prompts/prompt_builder.py @@ -25,7 +25,7 @@ def get_instruction( "workspace_dir_name": workspace_dir_name, "working_dir": workspace_path, } - context["test_instructions"] = "" + # context["test_instructions"] = "" # Render the instruction instruction = template.render(context) diff --git a/src/prompts/templates/default.j2 b/src/prompts/templates/default.j2 index c9323f7..c291f10 100644 --- a/src/prompts/templates/default.j2 +++ b/src/prompts/templates/default.j2 @@ -1,4 +1,4 @@ -I have access to a python code repository in the directory {{ instance.repo_path }} . +I have access to a python code repository in the directory {{ working_dir }} . Consider the following issue description: diff --git a/src/prompts/templates/file_localization.j2 b/src/prompts/templates/file_localization.j2 index 987df10..556e019 100644 --- a/src/prompts/templates/file_localization.j2 +++ b/src/prompts/templates/file_localization.j2 @@ -1,4 +1,4 @@ -I have access to a python code repository in the directory {{ instance.repo_path }} . +I have access to a python code repository in the directory {{ working_dir }} . Consider the following issue description: diff --git a/src/prompts/templates/file_module.j2 b/src/prompts/templates/file_module.j2 index 8fe4007..a71a943 100644 --- a/src/prompts/templates/file_module.j2 +++ b/src/prompts/templates/file_module.j2 @@ -1,4 +1,4 @@ -I have access to a python code repository in the directory {{ instance.repo_path }} . Consider the following issue description: +I have access to a python code repository in the directory {{ working_dir }} . 
Consider the following issue description: {{ instance.problem_statement }} diff --git a/src/prompts/templates/file_module_parallel_tools.j2 b/src/prompts/templates/file_module_parallel_tools.j2 index 0bc83c0..9a779bb 100644 --- a/src/prompts/templates/file_module_parallel_tools.j2 +++ b/src/prompts/templates/file_module_parallel_tools.j2 @@ -1,4 +1,4 @@ -I have access to a python code repository in the directory {{ instance.repo_path }} . Consider the following issue description: +I have access to a python code repository in the directory {{ working_dir }} . Consider the following issue description: {{ instance.problem_statement }} diff --git a/src/rewards/file_localization/file_localization.py b/src/rewards/file_localization/file_localization.py index 73f7d1a..9051e26 100644 --- a/src/rewards/file_localization/file_localization.py +++ b/src/rewards/file_localization/file_localization.py @@ -6,11 +6,12 @@ def compute_file_f1_score(predicted_files, true_files): pred, true = set(predicted_files), set(true_files) + if not true: + return 0.0 # return 0 reward if ground truth is empty tp = len(pred & true) precision = tp / len(pred) if pred else 0.0 recall = tp / len(true) if true else 0.0 - if not pred and not true: - return 1.0 + return 0.0 if precision + recall == 0 else 2 * precision * recall / (precision + recall) # def file_localization_f1_reward(final_message, instance): diff --git a/src/train.py b/src/train.py index 52b1e1a..23f0cda 100644 --- a/src/train.py +++ b/src/train.py @@ -9,6 +9,7 @@ # from src.tools import tool_exists from src.generator.code_search_generator import CodeSearchGenerator from src.async_trainer import CustomFullyAsyncRayPPOTrainer as FullyAsyncRayPPOTrainer +from src.distiller import FullyAsyncOnPolicyDistillationTrainer, OnPolicyDistillationTrainer # from skyrl_train.fully_async_trainer import FullyAsyncRayPPOTrainer @@ -53,15 +54,34 @@ def run(self): asyncio.run(trainer.train()) +class CodeSearchOnPolicyDistillationExp(CodeSearchPPOExp): 
+ def get_trainer(self, *args, **kwargs): + return OnPolicyDistillationTrainer(*args, **kwargs) + +class AsyncCodeSearchOnPolicyDistillationExp(AsyncCodeSearchPPOExp): + def get_trainer(self, *args, **kwargs): + return FullyAsyncOnPolicyDistillationTrainer(*args, **kwargs) + @ray.remote(num_cpus=1) def skyrl_entrypoint(cfg: DictConfig): # make sure that the training loop is not run on the head node. - if cfg.get("run_async_trainer", False): - print("Running async trainer") - exp = AsyncCodeSearchPPOExp(cfg) - else: - print("Running sync trainer") - exp = CodeSearchPPOExp(cfg) + run_async = cfg.get("run_async_trainer", False) + use_distillation = cfg.get("use_distillation", False) + + match (run_async, use_distillation): + case (True, True): + print("Running async distillation trainer") + exp = AsyncCodeSearchOnPolicyDistillationExp(cfg) + case (True, False): + print("Running async trainer") + exp = AsyncCodeSearchPPOExp(cfg) + case (False, True): + print("Running sync distillation trainer") + exp = CodeSearchOnPolicyDistillationExp(cfg) + case (False, False): + print("Running sync trainer") + exp = CodeSearchPPOExp(cfg) + exp.run()