Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 33 additions & 5 deletions examples/data_preprocess/asearcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,16 @@
- agent_name is set to "search_agent" (matching agent_config.yaml)
- ground_truth is placed in tools_kwargs["reward"] for SearchRewardSpec
- system prompt instructs the model to use search + finish tools

If ``--input_json`` is omitted, the filtered ASearcher dataset
(``aidenjhwu/ASearcher_en_no-math_Qwen3-8B-reject-sample``) is downloaded from
HuggingFace automatically. That dataset is derived from the original
``inclusionAI/ASearcher-train-data`` with Chinese samples and math problems
removed and reject sampling applied, and it already ships in the nested
``extra_info`` schema this script expects. To preprocess the original
(unfiltered) ASearcher data instead, write a separate preprocessing script: its
records use a different, flat schema (top-level ``question``/``answer``) that
this script does not handle.
"""

import argparse
Expand All @@ -30,6 +40,10 @@
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)

# Filtered ASearcher dataset used by default when no local --input_json is given.
HF_FILTERED_REPO = "aidenjhwu/ASearcher_en_no-math_Qwen3-8B-reject-sample"
HF_FILTERED_FILE = "ASearcher_en_nomath_rejectsample.json"

DEFAULT_SYSTEM_CONTENT = (
"You are an expert research assistant. Your goal is to answer the user's question by thoroughly "
"researching it. You must follow a structured process of reasoning and tool use.\n\n"
Expand Down Expand Up @@ -117,12 +131,22 @@ def _read_input_as_dataframe(json_path: str) -> pd.DataFrame:
return pd.read_json(json_path, lines=True)


def _load_filtered_dataset() -> pd.DataFrame:
"""Download the filtered ASearcher dataset from HuggingFace as a DataFrame."""
from datasets import load_dataset

logger.info(f"Downloading {HF_FILTERED_REPO}:{HF_FILTERED_FILE} from HuggingFace...")
dataset = load_dataset(HF_FILTERED_REPO, data_files=HF_FILTERED_FILE, split="train")
return dataset.to_pandas()
Comment on lines +134 to +140

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The datasets library is an optional dependency and might not be installed in all environments. If a user runs this script without --input_json and does not have datasets installed, they will get a generic ImportError. Wrapping the import in a try-except block with a clear, actionable error message will improve the user experience.

def _load_filtered_dataset() -> pd.DataFrame:
    """Download the filtered ASearcher dataset from HuggingFace as a DataFrame."""
    try:
        from datasets import load_dataset
    except ImportError as e:
        raise ImportError(
            "The 'datasets' library is required to download the default dataset from HuggingFace. "
            "Please install it using 'pip install datasets' or provide a local path via '--input_json'."
        ) from e

    logger.info(f"Downloading {HF_FILTERED_REPO}:{HF_FILTERED_FILE} from HuggingFace...")
    dataset = load_dataset(HF_FILTERED_REPO, data_files=HF_FILTERED_FILE, split="train")
    return dataset.to_pandas()



def main():
parser = argparse.ArgumentParser(description="Preprocess ASearcher JSON/JSONL dataset for uni-agent training.")
parser.add_argument(
"--input_json",
required=True,
help="Path to raw ASearcher JSON or JSONL file.",
default=None,
help="Path to raw ASearcher JSON or JSONL file. If omitted, the filtered dataset "
f"({HF_FILTERED_REPO}) is downloaded from HuggingFace.",
)
parser.add_argument(
"--local_save_dir",
Expand All @@ -143,12 +167,16 @@ def main():
)
args = parser.parse_args()

input_json_path = os.path.expanduser(args.input_json)
local_save_dir = os.path.expanduser(args.local_save_dir)
os.makedirs(local_save_dir, exist_ok=True)

df_raw = _read_input_as_dataframe(input_json_path)
logger.info(f"Loaded {len(df_raw)} records from {input_json_path}")
if args.input_json:
input_json_path = os.path.expanduser(args.input_json)
df_raw = _read_input_as_dataframe(input_json_path)
logger.info(f"Loaded {len(df_raw)} records from {input_json_path}")
else:
df_raw = _load_filtered_dataset()
logger.info(f"Loaded {len(df_raw)} records from {HF_FILTERED_REPO}")
Comment on lines +173 to +179

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

If the downloaded filtered dataset contains fewer than the default 8,292 rows (8,192 train + 100 test), the script will crash with a ValueError at line 182. To make the script robust, we should dynamically adjust the train/test split sizes if the loaded dataset is smaller than the requested total.

Suggested change
if args.input_json:
input_json_path = os.path.expanduser(args.input_json)
df_raw = _read_input_as_dataframe(input_json_path)
logger.info(f"Loaded {len(df_raw)} records from {input_json_path}")
else:
df_raw = _load_filtered_dataset()
logger.info(f"Loaded {len(df_raw)} records from {HF_FILTERED_REPO}")
if args.input_json:
input_json_path = os.path.expanduser(args.input_json)
df_raw = _read_input_as_dataframe(input_json_path)
logger.info(f"Loaded {len(df_raw)} records from {input_json_path}")
else:
df_raw = _load_filtered_dataset()
logger.info(f"Loaded {len(df_raw)} records from {HF_FILTERED_REPO}")
if len(df_raw) < args.train_rows + args.test_rows:
args.test_rows = max(1, len(df_raw) // 10) if len(df_raw) > 1 else 0
args.train_rows = len(df_raw) - args.test_rows


total_needed = args.train_rows + args.test_rows
if len(df_raw) < total_needed:
Expand Down
Loading