AQ-MedAI/TOOL_CURE

TOOL-CURE: Tool Selection via Curriculum-Enhanced Reinforcement Learning

Accepted Paper

TOOL-CURE is a training recipe for LLM-based tool-selection agents, built on top of the verl RL framework. It is designed for noisy, partially-correct, and dirty real-world data, and is based on two key ideas:

  1. Proficiency-Scaled Curriculum Learning (PSCL) – a two-stage curriculum that first trains on easier samples, then continues on harder ones.
  2. Online Policy Guarding via Sample Screening (OPGSS) – an online sample-filtering mechanism that masks low-quality samples during RL training.

This repository provides:

  • A minimal integration of TOOL-CURE into the verl framework.
  • A four-step script pipeline to run training → checkpoint merge → inference → evaluation.

Key Features

  • Robust to noisy data: Handles partially-correct and dirty samples gracefully.
  • Better generalization: PSCL encourages learning from easier samples before harder ones.
  • Stable RL training: OPGSS prevents collapse from conflicting reward signals.

PSCL & OPGSS Overview

PSCL (Proficiency-Scaled Curriculum Learning)

PSCL is implemented as data preprocessing + two-stage training:

  1. Run an initial policy on each training prompt, sample G rollouts, and compute an aggregate score (reward_sum) over those rollouts.
  2. Split into:
    • D_easy: samples with reward_sum >= threshold
    • D_hard: samples with reward_sum < threshold
  3. Stage 1 – train on D_easy for several epochs.
  4. Stage 2 – resume from Stage 1 checkpoints and train on D_hard.

This corresponds to Section 3.2 of the TOOL-CURE paper.
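
The split in steps 1 and 2 can be sketched as follows. This is a minimal illustration, not the repository's preprocessing code: the function name split_by_proficiency, the prompt ids, and the reward values are hypothetical, and it assumes that reward_sum is the plain sum of the G rollout rewards.

```python
from typing import Dict, List, Tuple


def split_by_proficiency(
    rollout_rewards: Dict[str, List[float]],
    threshold: float,
) -> Tuple[List[str], List[str]]:
    """Partition prompt ids into (D_easy, D_hard) by their aggregate rollout reward."""
    d_easy: List[str] = []
    d_hard: List[str] = []
    for prompt_id, rewards in rollout_rewards.items():
        # Assumption: the aggregate score is the sum over the G rollouts.
        reward_sum = sum(rewards)
        if reward_sum >= threshold:
            d_easy.append(prompt_id)
        else:
            d_hard.append(prompt_id)
    return d_easy, d_hard


# Hypothetical rewards for G = 4 rollouts per prompt.
rollout_rewards = {
    "p0": [1.0, 1.0, 0.0, 1.0],  # reward_sum = 3.0
    "p1": [0.0, 0.0, 1.0, 0.0],  # reward_sum = 1.0
    "p2": [1.0, 0.0, 1.0, 0.0],  # reward_sum = 2.0
}
d_easy, d_hard = split_by_proficiency(rollout_rewards, threshold=2.0)
print(d_easy, d_hard)  # ['p0', 'p2'] ['p1']
```

Stage 1 would then train on the rows listed in d_easy, and Stage 2 would resume from the Stage 1 checkpoint on the rows in d_hard.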

OPGSS (Online Policy Guarding via Sample Screening)

OPGSS is integrated directly into the PPO trainer:

  • After computing token-level rewards, aggregate them per sample into reward_sum.
  • Compare each sample’s reward_sum with trainer.reward_sum_threshold.
  • Build a gradient mask:
    • If reward_sum < trainer.reward_sum_threshold → grad_mask = 0 (sample is masked)
    • Else grad_mask = 1
  • Attach grad_mask to the batch:
    • batch.batch["grad_mask"] = mask_tensor
    • Log training/masked_sample_ratio
  • The actor and critic workers apply grad_mask when computing policy and value losses.

This corresponds to Section 3.3 of the paper and is implemented in verl/trainer/ppo/ray_trainer.py.
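
The screening logic above can be illustrated with plain Python lists. This is a sketch under stated assumptions rather than the code in ray_trainer.py, which operates on batched tensors: build_grad_mask and masked_mean_loss are hypothetical helper names, and averaging the loss over unmasked samples only is an assumed way of applying the mask.

```python
from typing import List, Tuple


def build_grad_mask(
    token_rewards: List[List[float]],
    reward_sum_threshold: float,
) -> Tuple[List[int], float]:
    """Return a per-sample 0/1 mask and the fraction of masked samples."""
    grad_mask: List[int] = []
    for sample_rewards in token_rewards:
        reward_sum = sum(sample_rewards)  # aggregate token-level rewards per sample
        grad_mask.append(0 if reward_sum < reward_sum_threshold else 1)
    masked_ratio = grad_mask.count(0) / len(grad_mask) if grad_mask else 0.0
    return grad_mask, masked_ratio


def masked_mean_loss(per_sample_loss: List[float], grad_mask: List[int]) -> float:
    """Average the loss over unmasked samples (assumed normalization)."""
    kept = sum(grad_mask)
    if kept == 0:
        return 0.0  # every sample was screened out, so nothing contributes
    return sum(loss * m for loss, m in zip(per_sample_loss, grad_mask)) / kept


# Hypothetical batch of three samples with token-level rewards.
token_rewards = [[0.0, 0.0, 1.0], [0.0, 0.0, -1.0], [0.0, 0.0, 0.5]]
grad_mask, masked_ratio = build_grad_mask(token_rewards, reward_sum_threshold=0.0)
loss = masked_mean_loss([0.2, 5.0, 0.4], grad_mask)
print(grad_mask)  # [1, 0, 1]
```

In this example the second sample falls below the threshold, so its large loss value does not influence the update.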


Project Structure

WSDM-ToolCure/verl/                 # Integrated verl framework with TOOL-CURE
├── verl/                           # Core verl framework (upstream + small changes)
│   ├── utils/reward_score/
│   │   └── rlla_intent.py          # TOOL-CURE reward function
│   ├── workers/reward_manager/
│   │   └── CustomRewardManger.py   # TOOL-CURE reward manager
│   └── trainer/ppo/
│       └── ray_trainer.py          # PPO trainer with OPGSS logic
├── step1_train.sh                  # Step 1: RL training (GRPO + OPGSS)
├── step2_model_merger.sh           # Step 2: merge FSDP checkpoints to HF weights
├── step3_inference.sh              # Step 3: offline inference (multi-sample rollout)
├── step4_evaluate.py               # Step 4: evaluation on inference parquet
└── README.md                       # This document

Installation

Prerequisites

  • Python >= 3.10
  • CUDA >= 12.4
  • cuDNN >= 9.1.0
  • vLLM >= 0.8.5.post1

You can use the official verl Docker image as a base:

docker pull verlai/verl:app-verl0.4-sglang0.4.6.post5-vllm0.8.5-mcore0.12.2-te2.2-deepep

Then clone this repository inside the container, and follow the quick start below.


Quick Start

Data preparation (parquet)

We recommend using a single parquet file with at least the following fields:

  • prompt: list of chat messages (for tokenizer.apply_chat_template(...)).
  • data_source: dataset identifier (used to route reward functions).
  • reward_model.ground_truth: label information (e.g. ground_truth_intent).

Example row:

{
  "data_source": "rlla_func_calling",
  "ability": "tool_selection",
  "prompt": [
    {
      "role": "user",
      "content": "## Tool APIs
Tool: get_videos_by_channel
Description: Fetch the latest videos from a YouTube channel.
Tool: titles_with_changed_episodes
Description: Retrieve a listing of titles that have changes to their episodes.
Tool: search
Description: Search for games on the Epic Games Store.

## Task Logic
Given the dialogue, select one or more tools from the API list.
Output format:
  Action: [\"tool_name_1\"]
  Action: [\"tool_name_1\", \"tool_name_2\"]
  Action: [\"NONE\"]

## Input Element
Dialogue: **Dialogue Records History**
User: I want to know what new games are available on the Epic Games Store. Also, can you get the latest videos from channel 'CHANNEL_ID'?"
    }
  ],
  "reward_model": {
    "ground_truth": "{\"task\": \"tool_selection\", \"domain\": \"rlla_func_calling\", \"ground_truth_intent\": [\"get_videos_by_channel\", \"search\"]}",
    "style": "rule"
  },
  "extra_info": {
    "answer": "<answer>['get_videos_by_channel', 'search']</answer>",
    "index": 3334,
    "split": "test"
  }
}

Loader reference: verl/utils/dataset/rl_dataset.py::RLHFDataset (default data.prompt_key=prompt).
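
Before writing rows to parquet, it can help to check that each row carries the fields listed above. The sketch below is not part of the repository: check_row is a hypothetical helper, and it assumes, as in the example row, that reward_model.ground_truth is a JSON string containing a ground_truth_intent list.

```python
import json
from typing import Any, Dict, List


def check_row(row: Dict[str, Any]) -> List[str]:
    """Return a list of problems found in one dataset row (empty means OK)."""
    problems: List[str] = []

    prompt = row.get("prompt")
    if (
        not isinstance(prompt, list)
        or not prompt
        or not all(isinstance(m, dict) and "role" in m and "content" in m for m in prompt)
    ):
        problems.append("prompt must be a non-empty list of {role, content} messages")

    if not isinstance(row.get("data_source"), str):
        problems.append("data_source must be a string")

    reward_model = row.get("reward_model")
    ground_truth = reward_model.get("ground_truth") if isinstance(reward_model, dict) else None
    try:
        intents = json.loads(ground_truth)["ground_truth_intent"]
        if not isinstance(intents, list):
            problems.append("ground_truth_intent must be a list")
    except (TypeError, ValueError, KeyError):
        problems.append("reward_model.ground_truth must be JSON with ground_truth_intent")

    return problems


# A shortened, hypothetical row in the same shape as the example above.
good_row = {
    "data_source": "rlla_func_calling",
    "prompt": [{"role": "user", "content": "## Tool APIs ..."}],
    "reward_model": {
        "ground_truth": json.dumps({"ground_truth_intent": ["search"]}),
        "style": "rule",
    },
}
print(check_row(good_row))  # []
```

Rows read back from parquet may expose prompt as an array type rather than a Python list, so this check is intended for rows before they are written.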

This repo wraps the official quickstart into four scripts:

  1. RL training – step1_train.sh
  2. Checkpoint merge – step2_model_merger.sh
  3. Inference – step3_inference.sh
  4. Evaluation – step4_evaluate.py

You only need to edit a few placeholder paths in each script.

Step 1 – RL training (step1_train.sh)

This script launches GRPO training with OPGSS enabled, by calling python -m verl.trainer.main_ppo under the hood.

Edit the placeholders at the top of step1_train.sh:

  • ROOT_DIR="/path/to/WSDM-ToolCure/workspace/verl"
  • TRAIN_DATA_PATH="$ROOT_DIR/your_train_data_path.parquet"
  • VAL_DATA_PATH="$ROOT_DIR/your_val_data_path.parquet"
  • MODEL_PATH="$ROOT_DIR/your_base_or_merged_model_path"
  • REWARD_SUM_LOG_DIR="$ROOT_DIR/your_reward_summary_dir"
  • REWARD_SUM_LOG_FILENAME="your_reward_summary_filename.jsonl"
  • TRAIN_LOG_PATH="$ROOT_DIR/your_train_log_filename.log"
  • TENSORBOARD_DIR="$ROOT_DIR/your_tensorboard_file_directory"
  • CKPT_OUTPUT_DIR="$ROOT_DIR/your_ckpt_output_dir"

Run:

cd /path/to/WSDM-ToolCure/workspace/verl
sh step1_train.sh

This produces RL checkpoints under CKPT_OUTPUT_DIR (for example: .../global_step_XX/actor).

Step 2 – model merge (step2_model_merger.sh)

After training, we merge the actor checkpoints into a standalone Hugging Face (HF) model. Edit the following in step2_model_merger.sh:

  • ROOT_DIR="/path/to/WSDM-ToolCure/workspace/verl"
  • --local_dir "$ROOT_DIR/your_ckpt_output_dir/global_step_xx/actor"
  • --target_dir "$ROOT_DIR/your_merged_model_output_dir"

Run:

cd /path/to/WSDM-ToolCure/workspace/verl
sh step2_model_merger.sh

Step 3 – inference (step3_inference.sh)

This script uses python -m verl.trainer.main_generation to run offline inference and store model outputs in a parquet file. Edit in step3_inference.sh:

  • ROOT_DIR="/path/to/WSDM-ToolCure/workspace/verl"
  • DATA_PATH="$ROOT_DIR/your_eval_input_data.parquet"
  • SAVE_PATH="$ROOT_DIR/your_eval_output_data.parquet"
  • MODEL_PATH="$ROOT_DIR/your_merged_HF_model_output_dir" (Merged HF model produced in Step 2)

Run:

cd /path/to/WSDM-ToolCure/workspace/verl
sh step3_inference.sh

Step 4 – evaluation (step4_evaluate.py)

Finally, we evaluate the inference results with a multi-label metric suite.

Edit the input path in step4_evaluate.py:

if __name__ == "__main__":
    # Set this to the parquet produced by step3_inference.sh
    input_parquet = "/path/to/your_eval_output_data.parquet"
    ...

Run:

cd /path/to/WSDM-ToolCure/workspace/verl
python step4_evaluate.py

The script will:

  • Parse Action: lines from model responses to extract predicted tool intents.
  • Extract ground-truth intents from reward_model.ground_truth / extra_info.answer fields.
  • Compute:
    • Subset accuracy (exact match)
    • Micro / macro / weighted precision, recall, and F1
    • Per-label metrics and a classification report
  • Save detailed outputs under eval_results_<input_basename>/, including:
    • sample_level_results.csv, recall_metrics.json
    • error_samples.json, correct_samples.json, format_error_samples.json
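
The parsing and two of the metrics can be sketched as below. This is an illustration, not the code in step4_evaluate.py: parse_action and subset_accuracy_and_micro_f1 are hypothetical names, taking the first Action: line is an assumption, and so is treating an unparseable response as an empty prediction.

```python
import json
import re
from typing import List, Optional, Set, Tuple

# Matches the JSON-style list that follows "Action:" in the output format.
ACTION_RE = re.compile(r"Action:\s*(\[.*?\])")


def parse_action(response: str) -> Optional[Set[str]]:
    """Extract the predicted tool names, or None on a format error."""
    match = ACTION_RE.search(response)
    if match is None:
        return None
    try:
        tools = json.loads(match.group(1))
    except ValueError:
        return None
    if not isinstance(tools, list) or not all(isinstance(t, str) for t in tools):
        return None
    return set(tools)


def subset_accuracy_and_micro_f1(
    preds: List[Optional[Set[str]]],
    golds: List[Set[str]],
) -> Tuple[float, float]:
    """Compute exact-match accuracy and micro-averaged F1 over label sets."""
    exact = tp = fp = fn = 0
    for pred, gold in zip(preds, golds):
        pred = pred or set()  # assumption: format errors count as empty predictions
        exact += int(pred == gold)
        tp += len(pred & gold)
        fp += len(pred - gold)
        fn += len(gold - pred)
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return (exact / len(golds) if golds else 0.0), f1


# Hypothetical model responses and ground-truth intents.
responses = [
    'Action: ["search", "get_videos_by_channel"]',
    'Action: ["search"]',
    "I am not sure.",
]
golds = [
    {"get_videos_by_channel", "search"},
    {"search", "titles_with_changed_episodes"},
    {"NONE"},
]
preds = [parse_action(r) for r in responses]
accuracy, micro_f1 = subset_accuracy_and_micro_f1(preds, golds)
print(round(accuracy, 3), round(micro_f1, 3))  # 0.333 0.75
```

Only the first response matches its label set exactly, while micro F1 still credits the partially correct second response.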

License

This project is licensed under the Apache License 2.0 – see the LICENSE file for details.


Citation

If you use TOOL-CURE in your research, please cite:

@inproceedings{zhang2026toolcure,
  author= {Jie Zhang and Dongsheng Bi and Tao Sun and Minghui Yang and Jian Wang and Yiwei Wang},
  title= {TOOL-CURE: Tool Selection via Curriculum-Enhanced Reinforcement Learning with Sample Screening for LLMs},
  booktitle= {Proceedings of the 19th {ACM} International Conference on Web Search and Data Mining (WSDM 2026)},
  year= {2026}
}

Acknowledgments

  • Built on top of the verl framework.
  • Uses Qwen models as base policies in our experiments.

Note: This is the open-source release that accompanies the WSDM 2026 paper. For the full, upstream verl framework, please visit the verl repository.
