-
Notifications
You must be signed in to change notification settings - Fork 59
[examples] feat: add filtered ASearcher dataset source to the preprocess script #74
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -19,6 +19,16 @@ | |||||||||||||||||||||||||||||||||||
| - agent_name is set to "search_agent" (matching agent_config.yaml) | ||||||||||||||||||||||||||||||||||||
| - ground_truth is placed in tools_kwargs["reward"] for SearchRewardSpec | ||||||||||||||||||||||||||||||||||||
| - system prompt instructs the model to use search + finish tools | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| If ``--input_json`` is omitted, the filtered ASearcher dataset | ||||||||||||||||||||||||||||||||||||
| (``aidenjhwu/ASearcher_en_no-math_Qwen3-8B-reject-sample``) is downloaded from | ||||||||||||||||||||||||||||||||||||
| HuggingFace automatically. That dataset is derived from the original | ||||||||||||||||||||||||||||||||||||
| ``inclusionAI/ASearcher-train-data`` with Chinese samples and math problems | ||||||||||||||||||||||||||||||||||||
| removed and reject sampling applied, and it already ships in the nested | ||||||||||||||||||||||||||||||||||||
| ``extra_info`` schema this script expects. To preprocess the original | ||||||||||||||||||||||||||||||||||||
| (unfiltered) ASearcher data instead, write a separate preprocessing script: its | ||||||||||||||||||||||||||||||||||||
| records use a different, flat schema (top-level ``question``/``answer``) that | ||||||||||||||||||||||||||||||||||||
| this script does not handle. | ||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| import argparse | ||||||||||||||||||||||||||||||||||||
|
|
@@ -30,6 +40,10 @@ | |||||||||||||||||||||||||||||||||||
| logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s") | ||||||||||||||||||||||||||||||||||||
| logger = logging.getLogger(__name__) | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| # Filtered ASearcher dataset used by default when no local --input_json is given. | ||||||||||||||||||||||||||||||||||||
| HF_FILTERED_REPO = "aidenjhwu/ASearcher_en_no-math_Qwen3-8B-reject-sample" | ||||||||||||||||||||||||||||||||||||
| HF_FILTERED_FILE = "ASearcher_en_nomath_rejectsample.json" | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| DEFAULT_SYSTEM_CONTENT = ( | ||||||||||||||||||||||||||||||||||||
| "You are an expert research assistant. Your goal is to answer the user's question by thoroughly " | ||||||||||||||||||||||||||||||||||||
| "researching it. You must follow a structured process of reasoning and tool use.\n\n" | ||||||||||||||||||||||||||||||||||||
|
|
@@ -117,12 +131,22 @@ def _read_input_as_dataframe(json_path: str) -> pd.DataFrame: | |||||||||||||||||||||||||||||||||||
| return pd.read_json(json_path, lines=True) | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| def _load_filtered_dataset() -> pd.DataFrame: | ||||||||||||||||||||||||||||||||||||
| """Download the filtered ASearcher dataset from HuggingFace as a DataFrame.""" | ||||||||||||||||||||||||||||||||||||
| from datasets import load_dataset | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| logger.info(f"Downloading {HF_FILTERED_REPO}:{HF_FILTERED_FILE} from HuggingFace...") | ||||||||||||||||||||||||||||||||||||
| dataset = load_dataset(HF_FILTERED_REPO, data_files=HF_FILTERED_FILE, split="train") | ||||||||||||||||||||||||||||||||||||
| return dataset.to_pandas() | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| def main(): | ||||||||||||||||||||||||||||||||||||
| parser = argparse.ArgumentParser(description="Preprocess ASearcher JSON/JSONL dataset for uni-agent training.") | ||||||||||||||||||||||||||||||||||||
| parser.add_argument( | ||||||||||||||||||||||||||||||||||||
| "--input_json", | ||||||||||||||||||||||||||||||||||||
| required=True, | ||||||||||||||||||||||||||||||||||||
| help="Path to raw ASearcher JSON or JSONL file.", | ||||||||||||||||||||||||||||||||||||
| default=None, | ||||||||||||||||||||||||||||||||||||
| help="Path to raw ASearcher JSON or JSONL file. If omitted, the filtered dataset " | ||||||||||||||||||||||||||||||||||||
| f"({HF_FILTERED_REPO}) is downloaded from HuggingFace.", | ||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||
| parser.add_argument( | ||||||||||||||||||||||||||||||||||||
| "--local_save_dir", | ||||||||||||||||||||||||||||||||||||
|
|
@@ -143,12 +167,16 @@ def main(): | |||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||
| args = parser.parse_args() | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| input_json_path = os.path.expanduser(args.input_json) | ||||||||||||||||||||||||||||||||||||
| local_save_dir = os.path.expanduser(args.local_save_dir) | ||||||||||||||||||||||||||||||||||||
| os.makedirs(local_save_dir, exist_ok=True) | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| df_raw = _read_input_as_dataframe(input_json_path) | ||||||||||||||||||||||||||||||||||||
| logger.info(f"Loaded {len(df_raw)} records from {input_json_path}") | ||||||||||||||||||||||||||||||||||||
| if args.input_json: | ||||||||||||||||||||||||||||||||||||
| input_json_path = os.path.expanduser(args.input_json) | ||||||||||||||||||||||||||||||||||||
| df_raw = _read_input_as_dataframe(input_json_path) | ||||||||||||||||||||||||||||||||||||
| logger.info(f"Loaded {len(df_raw)} records from {input_json_path}") | ||||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||||
| df_raw = _load_filtered_dataset() | ||||||||||||||||||||||||||||||||||||
| logger.info(f"Loaded {len(df_raw)} records from {HF_FILTERED_REPO}") | ||||||||||||||||||||||||||||||||||||
|
Comment on lines
+173
to
+179
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If the downloaded filtered dataset contains fewer than the default 8,292 rows (8,192 train + 100 test), the script will crash with a
Suggested change
|
||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| total_needed = args.train_rows + args.test_rows | ||||||||||||||||||||||||||||||||||||
| if len(df_raw) < total_needed: | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The
datasetslibrary is an optional dependency and might not be installed in all environments. If a user runs this script without--input_jsonand does not havedatasetsinstalled, they will get a genericImportError. Wrapping the import in atry-exceptblock with a clear, actionable error message will improve the user experience.