PKU-ML/GRASP


GRASP: Graph Reasoning via Agentic Solving and Probing of LLMs


Authors: Xiaojun Guo*, Mingxue Tian*, Chenheng Zhang, Xiaohan Wang, Jiajun Chai, Guojun Yin, Wei Lin, Yifei Wang, Yisen Wang
*Equal Contribution. Correspondence.

📊 Overview

Integrating graph knowledge into Large Language Models (LLMs) via passive representation faces critical bottlenecks: limited context windows, unreliable numerical computation, and structural hallucinations. To address these, we propose GRASP (Graph Reasoning via Agentic Solving and Probing), which shifts the paradigm from passive ingestion to proactive agentic exploration. By interleaving Neighbor Retrieval for on-demand probing with a Code Interpreter as a deterministic solver, GRASP enables LLMs to autonomously navigate and compute over complex topologies. We employ a staged reinforcement learning strategy (GRPO) that transitions from visible tuning to a structure-blind environment, forcing the agent to develop genuine topological awareness. Evaluated on multi-domain graph reasoning benchmarks, our 4B model achieves a 53.06% average performance boost, surpasses SOTA baselines such as DeepSeek-V3.2, and generalizes to unseen tasks. It also shows strong potential for tackling million-node graphs through sampling and for solving Hard-level LeetCode graph problems.

📌 Key Takeaways

1️⃣ Agentic Probing over Passive Ingestion

We propose GRASP (Graph Reasoning via Agentic Solving and Probing), which shifts the paradigm from passive ingestion to proactive agentic exploration. By interleaving Neighbor Retrieval (Eyes 👀) for on-demand probing with a Code Interpreter (Hands 🙌) as a deterministic solver, GRASP enables LLMs to autonomously navigate and compute over complex topologies.

2️⃣ Structure-Blind RL Training

We employ a staged reinforcement learning strategy (GRPO) that transitions from visible tuning to a structure-blind environment, forcing the agent to develop genuine topological awareness.

3️⃣ From Million-Node Graphs to Hard LeetCode

Evaluated on multi-domain graph reasoning benchmarks, our 4B model achieves a 53.06% average performance boost, surpasses SOTA baselines such as DeepSeek-V3.2, and generalizes to unseen tasks. It also shows strong potential for tackling million-node graphs through sampling and for solving Hard-level LeetCode graph problems.

🔥 Open-source Collections

Our models and datasets are released in the Hugging Face collection PKU-ML/GRASP. In detail:

Dataset

We release our dataset RealErdős on Hugging Face at PKU-ML/RealErdos, and the full training/evaluation data at PKU-ML/GRASP.

Models

We release the GRASP-series models on Hugging Face:

🌟 A Quick Start

We provide example code for graph reasoning with our GRASP models. Run the following:

cd evaluation
python example.py

In this example, the task is to determine a maximal independent set for a wireless network:

🙋‍♀️ PROMPT: "A wireless network has nodes Base_Station_1, Transmitter_1, Cell_A, Channel_A, Transceiver_A, Frequency_A, Signal_A, Station_X, Band_A, Spectrum_1, Sector_1, Relay_1, Antenna_A, Tower_1 with interference Base_Station_1 creates conflict with Transmitter_1. Base_Station_1 and Cell_A cannot operate simultaneously. Base_Station_1 conflicts with Channel_A. Base_Station_1 conflicts with Station_X. Base_Station_1 creates conflict with Band_A. Base_Station_1 and Spectrum_1 cannot operate simultaneously. Base_Station_1 creates conflict with Relay_1. Base_Station_1 and Antenna_A cannot operate simultaneously. Base_Station_1 interferes with Tower_1. Transmitter_1 interferes with Transceiver_A. Transmitter_1 creates conflict with Frequency_A. Transmitter_1 interferes with Signal_A. There is interference between Transmitter_1 and Spectrum_1. Cell_A creates conflict with Channel_A. Cell_A creates conflict with Transceiver_A. There is interference between Cell_A and Frequency_A. There is interference between Cell_A and Signal_A. Cell_A and Station_X cannot operate simultaneously. Transceiver_A interferes with Relay_1. There is interference between Signal_A and Band_A. Signal_A interferes with Sector_1. Station_X conflicts with Tower_1. There is interference between Band_A and Antenna_A. Spectrum_1 interferes with Sector_1.. Find maximal independent set. Question: What is a maximal independent set of this graph? You need to format your answer as a list of node names, e.g., [node-1, node-2, ..., node-n]. If necessary, you can provide Python code as a string argument to the code execution tool."

GRASP-4B correctly calls the Code Interpreter tool:

🤖 GRASP: "Okay, so I need to find a maximal independent set (MIS) for this wireless network graph. Let me first recall what a maximal independent set is. An independent set is a set of nodes where no two nodes are adjacent (i.e., there's no interference or conflict between them). ... First, I need to model the graph. Since the problem is to find an MIS, which is a bit complex to do by hand for 14 nodes, maybe I can approach it systematically.\n\nBut since the user wants the answer as a list of node names, and given that the problem is a bit large, perhaps using a tool to compute the MIS would be better. ... I can provide Python code as a string argument to the code execution tool. So maybe I can generate a Python script that uses a library like NetworkX to compute the MIS. ... Let me outline the steps: ... So the next step is to call the execute_code tool with this code.\n\n <tool_call>\n{"name": "execute_code", "arguments": {"code_string": "import networkx as nx\n\nG = nx.Graph()\n\nnodes = [\n \"Base_Station_1\", \"Transmitter_1\", \"Cell_A\", \"Channel_A\", \"Transceiver_A\",\n \"Frequency_A\", \"Signal_A\", \"Station_X\", \"Band_A\", \"Spectrum_1\", \"Sector_1\",\n \"Relay_1\", \"Antenna_A\", \"Tower_1\"\n]\n\nedges = [\n (\"Base_Station_1\", \"Transmitter_1\"),\n (\"Base_Station_1\", \"Cell_A\"),\n (\"Base_Station_1\", \"Channel_A\"),\n (\"Base_Station_1\", \"Station_X\"),\n (\"Base_Station_1\", \"Band_A\"),\n (\"Base_Station_1\", \"Spectrum_1\"),\n (\"Base_Station_1\", \"Relay_1\"),\n (\"Base_Station_1\", \"Antenna_A\"),\n (\"Base_Station_1\", \"Tower_1\"),\n (\"Transmitter_1\", \"Transceiver_A\"),\n (\"Transmitter_1\", \"Frequency_A\"),\n (\"Transmitter_1\", \"Signal_A\"),\n (\"Transmitter_1\", \"Spectrum_1\"),\n (\"Cell_A\", \"Channel_A\"),\n (\"Cell_A\", \"Transceiver_A\"),\n (\"Cell_A\", \"Frequency_A\"),\n (\"Cell_A\", \"Signal_A\"),\n (\"Cell_A\", \"Station_X\"),\n (\"Transceiver_A\", \"Relay_1\"),\n (\"Signal_A\", \"Band_A\"),\n (\"Signal_A\", \"Sector_1\"),\n (\"Station_X\", \"Tower_1\"),\n (\"Band_A\", \"Antenna_A\"),\n (\"Spectrum_1\", \"Sector_1\")\n]\n\nG.add_nodes_from(nodes)\nG.add_edges_from(edges)\n\nmis = list(nx.maximal_independent_set(G))\nresult = mis"}}\n </tool_call>"

The outputs of the code interpreter are:

💻 Tool: <tool_response>['Antenna_A', 'Channel_A', 'Relay_1', 'Station_X', 'Transmitter_1', 'Sector_1']</tool_response>

Based on the tool response, GRASP-4B gives a correct final answer:

🤖 GRASP: "Okay, let me check the tool response. The code execution returned a maximal independent set: ['Antenna_A', 'Channel_A', 'Relay_1', 'Station_X', 'Transmitter_1', 'Sector_1'].\n\nWait, but I need to verify if this is indeed a maximal independent set. Let me think.\n\nFirst, a maximal independent set means that no more nodes can be added without violating the independence condition. So each node in the set shouldn't be adjacent to any other in the set.\n\nLet me check the nodes in the result: ... Therefore, the set is maximal.\n\nThe answer is the list as given by the tool response.\n\n <answer>["Antenna_A", "Channel_A", "Relay_1", "Station_X", "Transmitter_1", "Sector_1"] </answer>"
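A maximal independent set is generally not unique, and `nx.maximal_independent_set` visits nodes in random order (it accepts a `seed` argument), so repeated runs may return different but equally valid answers. A checker should therefore test the two defining properties rather than compare against one fixed list. The sketch below is our own illustration, not part of the GRASP evaluation code: it re-encodes the example's edges and checks independence and maximality in plain Python, assuming every node appears in at least one edge (true for this example).

```python
# Hypothetical checker (not from the GRASP repo): verifies that a candidate
# node list is a maximal independent set of an undirected graph.
from collections import defaultdict

EDGES = [
    ("Base_Station_1", "Transmitter_1"), ("Base_Station_1", "Cell_A"),
    ("Base_Station_1", "Channel_A"), ("Base_Station_1", "Station_X"),
    ("Base_Station_1", "Band_A"), ("Base_Station_1", "Spectrum_1"),
    ("Base_Station_1", "Relay_1"), ("Base_Station_1", "Antenna_A"),
    ("Base_Station_1", "Tower_1"), ("Transmitter_1", "Transceiver_A"),
    ("Transmitter_1", "Frequency_A"), ("Transmitter_1", "Signal_A"),
    ("Transmitter_1", "Spectrum_1"), ("Cell_A", "Channel_A"),
    ("Cell_A", "Transceiver_A"), ("Cell_A", "Frequency_A"),
    ("Cell_A", "Signal_A"), ("Cell_A", "Station_X"),
    ("Transceiver_A", "Relay_1"), ("Signal_A", "Band_A"),
    ("Signal_A", "Sector_1"), ("Station_X", "Tower_1"),
    ("Band_A", "Antenna_A"), ("Spectrum_1", "Sector_1"),
]


def build_adjacency(edges):
    """Return an undirected adjacency map: node -> set of neighbors."""
    adj = defaultdict(set)
    for u, v in edges:
        adj[u].add(v)
        adj[v].add(u)
    return adj


def is_maximal_independent_set(candidate, adj):
    """True iff no two candidates are adjacent and every other node
    has at least one neighbor inside the candidate set."""
    chosen = set(candidate)
    if not chosen <= set(adj):
        return False  # unknown node name
    # Independence: no chosen node has a chosen neighbor.
    if any(adj[n] & chosen for n in chosen):
        return False
    # Maximality: every unchosen node is blocked by some chosen neighbor.
    return all(adj[n] & chosen for n in adj if n not in chosen)


adj = build_adjacency(EDGES)
answer = ["Antenna_A", "Channel_A", "Relay_1", "Station_X", "Transmitter_1", "Sector_1"]
print(is_maximal_independent_set(answer, adj))  # True
```

Checking properties instead of exact lists is what makes such a verifier usable as a rule-based reward for tasks with many correct answers.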

🚀 Environment Setups

Our implementation is based on the verl library (version 0.7.0.dev0) developed by the ByteDance Seed team.

1. Requirements:

  • Python: Version >= 3.9
  • CUDA: Version >= 12.1

verl supports various backends. Currently, the following configurations are available: FSDP and Megatron-LM (optional) for training, and SGLang, vLLM, and TGI for rollout generation. For more information, please check the verl documentation.

2. To install the dependencies, we recommend using a fresh conda environment:

conda create -n verl python==3.12
conda activate verl

3. Then, run the install_vllm_sglang_mcore.sh script provided in the scripts directory:

cd scripts
# Make sure you have activated verl conda env
# If you need to run with megatron
bash install_vllm_sglang_mcore.sh
# Or if you simply need to run with FSDP
USE_MEGATRON=0 bash install_vllm_sglang_mcore.sh

If you encounter any issues during installation, please refer to the Installation Guide provided by verl. If problems persist, don’t hesitate to report them to us.

4. Install the torch-geometric packages following the official PyG installation guide. For example, for CUDA 12.6 and PyTorch 2.8, run:

pip install torch_geometric

# Optional dependencies:
pip install pyg_lib torch_scatter torch_sparse torch_cluster -f https://data.pyg.org/whl/torch-2.8.0+cu126.html

5. Install TAGLAS:

cd TAGLAS
pip install -e .
pip install rdkit
pip install torchmetrics walker munkres fast-tsp

NOTE: For reference, requirements.txt lists all the required packages and their corresponding versions.
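Before launching training, it can help to confirm that the stack from steps 1 to 5 is importable. The script below is a hedged sketch of our own, not a file shipped with GRASP; the import names it probes (`verl`, `vllm`, `sglang`, `torch_geometric`, `rdkit`, `torchmetrics`) are our assumptions about the installed packages. Note that `torch.version.cuda` reports the CUDA version PyTorch was built against, which is not necessarily the driver version.

```python
# Hypothetical environment check (not shipped with GRASP): reports the Python
# version, which assumed import names resolve, and PyTorch's CUDA build.
import importlib.util
import sys

# Assumed import names for the packages installed in the steps above.
PACKAGES = ["torch", "torch_geometric", "verl", "vllm", "sglang", "rdkit", "torchmetrics"]


def check_environment(packages=PACKAGES, min_python=(3, 9)):
    report = {"python": sys.version.split()[0],
              "python_ok": sys.version_info[:2] >= min_python}
    # find_spec locates a top-level module without importing it.
    report["missing"] = [p for p in packages if importlib.util.find_spec(p) is None]
    try:
        import torch
        report["torch_cuda_build"] = torch.version.cuda  # None for CPU-only builds
        report["cuda_available"] = torch.cuda.is_available()
    except ImportError:
        report["torch_cuda_build"] = None
        report["cuda_available"] = False
    return report


if __name__ == "__main__":
    for key, value in check_environment().items():
        print(f"{key}: {value}")
```

Using `find_spec` keeps the check fast, since heavy packages such as vLLM are only located, not imported.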

🎯 Dataset Preparation

Download from Hugging Face (Recommended)

We have uploaded the training and evaluation datasets to Hugging Face (PKU-ML/GRASP) for direct use.
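For a scripted download, `huggingface_hub.snapshot_download` can fetch a whole repository. This is a sketch under assumptions: we assume `PKU-ML/GRASP` is a Hugging Face *dataset* repository and that its files can live directly under `datasets/GRASP` (the `--input_dir` used by the preprocessing step below). The helper name `download_datasets` is hypothetical; please check the Hub page if the repository type or layout differs.

```python
# Hypothetical download helper. Assumes PKU-ML/GRASP is a Hugging Face *dataset*
# repository whose files can be placed directly under datasets/GRASP.
from pathlib import Path

DATASET_REPOS = {"PKU-ML/GRASP": "datasets/GRASP"}


def download_datasets(repos=DATASET_REPOS, root="."):
    """Download a snapshot of each repo into root/<local_dir>; return the paths."""
    # Imported lazily so this module loads even without huggingface_hub installed.
    from huggingface_hub import snapshot_download

    paths = []
    for repo_id, local_dir in repos.items():
        target = Path(root) / local_dir
        target.mkdir(parents=True, exist_ok=True)
        snapshot_download(repo_id=repo_id, repo_type="dataset", local_dir=str(target))
        paths.append(target)
    return paths
```

Calling `download_datasets()` from the repository root would place the files under `datasets/GRASP`; adjust `repo_type` or the mapping if the Hub listing says otherwise.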

Prepare from Scratch

If you want to build the datasets from scratch, please follow the steps below.

Datasets in Stage I

1. Change the directory to TAGLAS/build_benchmarks:

cd TAGLAS/build_benchmarks

2. For Node Classification tasks and Link Prediction tasks, run

python build_node_classification.py
# or
python build_link_prediction.py

You can edit the dataset_name variable to choose which dataset to generate. Taking the cora dataset as an example, the program automatically downloads the raw dataset and saves it to datasets/taglas by default. It then processes the data into parquet format and saves it to datasets/GRASP_All by default. If the automatic download fails due to network issues, you can also download the raw datasets manually (see the URLs in TAGLAS/download_datasets.py).

3. For TSGBench, first download the raw data from TSGBench and save it to the directory datasets/Scene_Graphs. Then, run

python build_scene_graph.py

4. For ExplaGraphs, first download the raw data from G-Retriever and save it to the directory datasets/Explain_Graphs. Then, run

python build_expla_graph.py

5. For graph-theoretic tasks, download them directly from Erdős and RealErdős. Then, copy each file to the directory datasets/GRASP_All.

6. Finally, the following command will merge all the sub-tasks from datasets/GRASP_All and save the merged parquet files to datasets/GRASP.

python combine_parquets.py

Datasets in Stage II

For the blind setting in Stage II, we extract the graph descriptions from the Erdős benchmark and save them to disk, leaving only the pure questions in the prompts.

First, download the datasets from Erdős and save them to datasets/Erdos. Then, run

python build_harder_erdos.py

The program converts the original benchmarks, saving the graph description files to datasets/ErdosGraph and the training/evaluation parquets to datasets/HardErdos.
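To make the structure-blind conversion concrete: the graph description moves out of the prompt and only the question stays, so the agent must recover structure through its tools. The sketch below is our simplified illustration, not `build_harder_erdos.py`. It assumes, as in the Quick Start prompt above, that each prompt contains a `Question:` segment; the one-JSON-file-per-graph layout and the helper names are hypothetical.

```python
# Simplified, hypothetical sketch of the Stage II "blind" conversion; the real
# build_harder_erdos.py may differ. Assumes prompts contain a "Question:" marker.
import json
from pathlib import Path

MARKER = "Question:"


def split_prompt(prompt):
    """Split a prompt into (graph_description, question) at the last marker."""
    idx = prompt.rfind(MARKER)
    if idx == -1:
        raise ValueError("prompt has no 'Question:' segment")
    return prompt[:idx].strip(), prompt[idx:].strip()


def make_blind_sample(sample_id, prompt, graph_dir):
    """Write the graph description to graph_dir/<id>.json; return a blind sample."""
    description, question = split_prompt(prompt)
    graph_dir = Path(graph_dir)
    graph_dir.mkdir(parents=True, exist_ok=True)
    path = graph_dir / f"{sample_id}.json"
    path.write_text(json.dumps({"id": sample_id, "description": description}))
    # Only the question stays in the prompt; the graph is reachable via its file.
    return {"id": sample_id, "prompt": question, "graph_file": str(path)}
```

For example, `split_prompt("Graph: A-B, B-C. Question: Find a maximal independent set.")` separates the edge list from the question, and `make_blind_sample` would store the former under a directory such as `datasets/ErdosGraph`.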

🧩 Running Reinforcement Learning

Follow these steps to reproduce our GRASP implementation:

1. Preprocess the dataset for RL training

Run the preprocessing script to convert the dataset format:

python examples/data_preprocess/grasp.py --input_dir datasets/GRASP --local_save_dir our_datasets/GRASP
  • --input_dir: Input directory for the dataset to be processed.

  • --local_save_dir: Output directory for the processed dataset.

2. Download the backbone model

Download Qwen3-4B-Thinking-2507 and save it to models/Qwen3-4B-Thinking-2507 locally.

3. Launch RL Training in Stage I

Execute the training script, adjusting the following settings inside it as needed:

bash run_grasp_qwen3_think_4B.sh
  • SAVE_DIR: Output directory for the trained model.
  • train_path and test_path: Paths to the processed dataset.
  • actor_rollout_ref.rollout.multi_turn.tool_config_path: The path to the tool config file.
  • Logging: Defaults to Tensorboard. To use Weights & Biases, set trainer.logger = ['console','wandb'].

GPU requirements:

Our paper used 14×H20 GPUs. With limited GPU resources, reduce these parameters (this may affect performance):

    actor_rollout_ref.actor.ppo_mini_batch_size
    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu
    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu
    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu

Set trainer.n_gpus_per_node and trainer.nnodes to your actual GPU count.
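When shrinking batch sizes, a divisibility check can catch configuration errors before a job is launched. The helper below is a hedged sketch based on our understanding of verl's FSDP workers, which (in the versions we are aware of) multiply `ppo_mini_batch_size` by the rollout count `n`, divide it by the data-parallel size, and require the result to be a multiple of `ppo_micro_batch_size_per_gpu`. Please confirm this against the verl version you install; the function names and the numbers in the example are illustrative, not the paper's settings.

```python
# Sanity check for batch settings; mirrors our understanding of how verl's FSDP
# workers normalize ppo_mini_batch_size (verify against your verl version).


def per_gpu_mini_batch(ppo_mini_batch_size, rollout_n, n_gpus_per_node, nnodes, sp_size=1):
    """Samples each data-parallel rank processes per PPO mini-batch."""
    world_size = n_gpus_per_node * nnodes
    dp_size = world_size // sp_size  # data-parallel ranks
    if dp_size == 0:
        raise ValueError("sequence-parallel size exceeds the number of GPUs")
    return ppo_mini_batch_size * rollout_n // dp_size


def check_batch_config(ppo_mini_batch_size, ppo_micro_batch_size_per_gpu,
                       rollout_n, n_gpus_per_node, nnodes, sp_size=1):
    """Return gradient-accumulation steps per mini-batch, or raise ValueError."""
    per_gpu = per_gpu_mini_batch(ppo_mini_batch_size, rollout_n,
                                 n_gpus_per_node, nnodes, sp_size)
    if per_gpu == 0:
        raise ValueError("mini-batch is smaller than the number of data-parallel ranks")
    if per_gpu % ppo_micro_batch_size_per_gpu != 0:
        raise ValueError(f"per-GPU mini-batch {per_gpu} is not a multiple of "
                         f"micro batch {ppo_micro_batch_size_per_gpu}")
    return per_gpu // ppo_micro_batch_size_per_gpu


# Illustrative numbers: 32 prompts x 8 rollouts over 8 GPUs -> 32 samples per GPU,
# processed as 8 micro-batches of 4.
print(check_batch_config(32, 4, rollout_n=8, n_gpus_per_node=8, nnodes=1))  # 8
```

Lowering `ppo_micro_batch_size_per_gpu` reduces peak memory at the cost of more accumulation steps, while the effective mini-batch stays the same.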

4. Convert the checkpoint to a Hugging Face model

Run the following command to convert the checkpoints to Hugging Face models.

bash scripts/merge_models.sh

The model will be saved to our_models/GRASP-base-4B by default.

5. Launch RL Training in Stage II

Execute the training script:

bash run_grasp_hard_erdos_4B.sh

Then, convert and save the final model:

bash scripts/merge_models.sh

NOTE: Edit the target_dir and local_dir variables in merge_models.sh.

🌊 Inference and Evaluation

We provide evaluation scripts for multiple benchmarks in the directory evaluation.

To get the main results in our paper, please run

cd evaluation
python eval.py --model our_models/GRASP-4B --save_name GRASP-4B --gpu 1

This command saves the generations to eval_results/GRASP/GRASP-4B.json. Then, run correctness_check.py to compute the final metrics.

python correctness_check.py --dataset GRASP --model GRASP-4B

🎨 Customization Guide

If you want to improve on our method or create new tools, follow these steps:

1. Write the tool configuration in configs to register your tools. For example, in configs/grasp_tool_config.yaml:

tools:
  - class_name: "verl.tools.graph_query_tool_cached.GraphQueryToolCached"
    config: 
      type: native
    tool_schema:
      type: "function"
      function:
        name: "query_neighbor_information"
        description: "Retrieves information about the neighbors of a specified node within a graph structure. This function allows you to query up to 'k' hops away from the given node, using a specific dataset name."
        parameters:
          type: "object"
          properties:
            node_index:
              type: "integer"
              description: "The unique index identifier of the node whose neighbors you want to query."
            dataset_name:
              type: "string"
              description: "The dataset name or database from which the node's neighbor information will be retrieved."
            k:
              type: "integer"
              description: "The number of hops away from the node to include in the query. A '1-hop' query will return direct neighbors, a '2-hop' will include neighbors of neighbors, etc. NOTE: for data from arxiv, k should be less than 3 to avoid out-of-memory."
          required: ["node_index", "dataset_name", "k"]

This file defines the tool class verl.tools.graph_query_tool_cached.GraphQueryToolCached, with the function name (query_neighbor_information), description, and required parameters (["node_index", "dataset_name", "k"]).
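To see what the schema buys you, here is a small hypothetical validator that checks a parsed call's arguments against the `required` list and JSON types above before dispatching it. It is our own illustration (the schema is transcribed by hand into a Python dict), and verl's actual argument handling may differ.

```python
# Hypothetical pre-dispatch check against the query_neighbor_information schema
# (transcribed from configs/grasp_tool_config.yaml); not part of verl.
SCHEMA = {
    "properties": {
        "node_index": {"type": "integer"},
        "dataset_name": {"type": "string"},
        "k": {"type": "integer"},
    },
    "required": ["node_index", "dataset_name", "k"],
}

# JSON-schema type names mapped to Python types.
JSON_TYPES = {"integer": int, "string": str, "number": (int, float), "boolean": bool}


def validate_arguments(arguments, schema=SCHEMA):
    """Return a list of problems; an empty list means the call looks valid."""
    errors = [f"missing required argument '{name}'"
              for name in schema["required"] if name not in arguments]
    for name, value in arguments.items():
        spec = schema["properties"].get(name)
        if spec is None:
            errors.append(f"unexpected argument '{name}'")
            continue
        # bool is a subclass of int in Python, so reject it explicitly.
        if isinstance(value, bool) and spec["type"] != "boolean":
            errors.append(f"'{name}' should be {spec['type']}, got boolean")
        elif not isinstance(value, JSON_TYPES[spec["type"]]):
            errors.append(f"'{name}' should be {spec['type']}, got {type(value).__name__}")
    return errors
```

Returning readable error strings, rather than raising, lets a tool send the message back to the agent so it can correct the call in its next turn.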

In the training script run_grasp_qwen3_think_4B.sh, we write

actor_rollout_ref.rollout.multi_turn.tool_config_path="configs/grasp_tool_config.yaml"

Then, the tools will be registered to the ToolAgentLoop class.

The verl framework parses each LLM generation to check whether it contains a tool call. For example, when the LLM agent generates the following content

...<tool_call>{"name": "query_neighbor_information", "arguments": {"node_index": 3930, "dataset_name": "pubmed", "k": 1}}</tool_call>...

verl parses the <tool_call></tool_call> tags and recognizes the registered function query_neighbor_information. It then calls verl.tools.graph_query_tool_cached.GraphQueryToolCached and passes the parameters {"node_index": 3930, "dataset_name": "pubmed", "k": 1} to the class's execute() function, which we introduce below.
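The parse-and-dispatch step can be pictured with the simplified sketch below. It is not verl's implementation (see `verl/experimental/agent_loop/tool_parser.py` for that); it only shows the idea of extracting hermes-style `<tool_call>` JSON blocks and routing them to a registry. The registry and the stand-in handler are hypothetical.

```python
# Simplified illustration of hermes-style tool-call parsing; not verl's parser.
import json
import re

TOOL_CALL_RE = re.compile(r"<tool_call>\s*(.*?)\s*</tool_call>", re.DOTALL)


def parse_tool_calls(text):
    """Return [(name, arguments)] for every well-formed <tool_call> block."""
    calls = []
    for payload in TOOL_CALL_RE.findall(text):
        try:
            obj = json.loads(payload)
        except json.JSONDecodeError:
            continue  # malformed JSON is skipped in this sketch
        if isinstance(obj, dict) and "name" in obj:
            calls.append((obj["name"], obj.get("arguments", {})))
    return calls


def dispatch(calls, registry):
    """Call the registered handler for each parsed tool call."""
    results = []
    for name, arguments in calls:
        handler = registry.get(name)
        results.append(handler(**arguments) if handler else f"unknown tool: {name}")
    return results


def fake_neighbor_query(node_index, dataset_name, k):
    """Hypothetical stand-in for GraphQueryToolCached.execute()."""
    return f"{k}-hop neighbors of {node_index} in {dataset_name}"


registry = {"query_neighbor_information": fake_neighbor_query}
generation = ('...<tool_call>{"name": "query_neighbor_information", "arguments": '
              '{"node_index": 3930, "dataset_name": "pubmed", "k": 1}}</tool_call>...')
print(dispatch(parse_tool_calls(generation), registry))
# ['1-hop neighbors of 3930 in pubmed']
```

The non-greedy pattern stops at the first closing tag, so several tool calls in one generation are parsed as separate calls.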

NOTE: Because we use Qwen3 models as the backbone, their chat template explicitly defines the format of tool calling (See tokenizer_config.json):

"""...
# Tools

You may call one or more functions to assist with the user query.

You are provided with function signatures within <tools></tools> XML tags:<tools></tools>

For each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:
<tool_call>{"name": <function-name>, "arguments": <args-json-object>}</tool_call> 
...
"""
...

Correspondingly, we use verl's default format parser, i.e., the hermes <tool_call> format (see verl/experimental/agent_loop/tool_parser.py):

config.actor_rollout_ref.rollout.multi_turn.format="hermes"

Therefore, if you use another backbone model, check its chat template for tool calling and whether it matches one of verl's tool parser formats ("hermes" or "gpt-oss").


2. Define / Edit the tool classes in verl/tools.

The verl framework introduces a fully asynchronous PPO training system that completely decouples the Trainer and the Rollouter (see Recipe: Fully Async Policy Trainer). Agentic RL training is implemented asynchronously (see Agentic RL Training).

We recommend reading verl/tools/base_tool.py for a quick overview. To create a custom tool class, you need to define two methods: async def create() and async def execute().

  • create(): Creates a tool instance. It receives an instance_id parameter and returns the tool's instance id together with the tool's response at creation time.
  • execute(): Executes the tool. It receives the instance_id and the parameters to pass to the tool (of type dict[str, Any]), and returns a ToolResponse object, a tool reward score, and tool metrics.
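The sketch below mirrors that two-method contract in a self-contained way, so it runs without verl installed. `ToolResponse` here is a minimal stand-in dataclass and `NodeDegreeTool` is a hypothetical example tool; a real tool would subclass the base class in `verl/tools/base_tool.py` and use verl's own `ToolResponse`, whose exact signatures should be taken from that file.

```python
# Self-contained mock of the create()/execute() contract described above.
# ToolResponse and NodeDegreeTool are hypothetical stand-ins, not verl classes.
import asyncio
import uuid
from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class ToolResponse:
    text: Optional[str] = None


class NodeDegreeTool:
    """Toy tool: returns the degree of a node in an in-memory adjacency list."""

    def __init__(self, adjacency: dict[int, list[int]]):
        self.adjacency = adjacency
        self._instance_dict: dict[str, dict[str, Any]] = {}

    async def create(self, instance_id: Optional[str] = None, **kwargs) -> tuple[str, ToolResponse]:
        # One instance per rollout trajectory keeps per-trajectory state apart.
        instance_id = instance_id or str(uuid.uuid4())
        self._instance_dict[instance_id] = {"calls": 0}
        return instance_id, ToolResponse()

    async def execute(self, instance_id: str, parameters: dict[str, Any],
                      **kwargs) -> tuple[ToolResponse, float, dict]:
        node = parameters.get("node_index")
        if node not in self.adjacency:
            return ToolResponse(text=f"Node {node} not found."), 0.0, {}
        self._instance_dict[instance_id]["calls"] += 1
        degree = len(self.adjacency[node])
        # (response, tool reward, tool metrics), matching the contract above.
        return ToolResponse(text=f"Node {node} has degree {degree}."), 0.0, {"degree": degree}


async def demo():
    tool = NodeDegreeTool({0: [1, 2], 1: [0], 2: [0]})
    iid, _ = await tool.create()
    response, reward, metrics = await tool.execute(iid, {"node_index": 0})
    return response.text, reward, metrics


print(asyncio.run(demo()))  # ('Node 0 has degree 2.', 0.0, {'degree': 2})
```

Returning an error message for an unknown node, instead of raising, keeps the rollout alive and gives the agent feedback it can act on.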

For example, our GraphRetrieval Tool is implemented in execute() of verl/tools/graph_query_tool_cached.py:

@rollout_trace_op
async def execute(self, instance_id: str, parameters: dict[str, Any], **kwargs) -> tuple[ToolResponse, float, dict]:
       
    # The parameters should be passed to the tool by LLM agents.
    node_index = parameters.get("node_index")
    dataset_name = parameters.get("dataset_name", "")
    k = parameters.get("k", 1)
        
    # ... omit some parameter correctness checking

    logger.info(f"Query {k}-hop neighbor information for node {node_index}")

    # Query neighbor information (uses cached dataset)
    neighbor_information = self._query_neighbor_information(node_index, dataset_name, k)
        
    # Store the result
    self._instance_dict[instance_id]["neighbor_information"] = neighbor_information

    return ToolResponse(text=neighbor_information), 0.0, {}
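The helper `_query_neighbor_information` is not shown above. One plausible shape for such a helper is a breadth-first search bounded at `k` hops; the sketch below is hypothetical (our own function names and text format, with an in-memory adjacency list in place of the cached dataset). It also shows why the neighborhood, and hence the tool output, grows quickly with `k`, which is what the out-of-memory note for arxiv in the tool schema warns about.

```python
# Hypothetical k-hop neighborhood query via bounded BFS; the real
# _query_neighbor_information in GRASP may work differently.
from collections import deque


def k_hop_neighbors(adjacency, source, k):
    """Map each node within k hops of source (excluding source) to its distance."""
    dist = {source: 0}
    queue = deque([source])
    while queue:
        node = queue.popleft()
        if dist[node] == k:
            continue  # do not expand beyond k hops
        for nbr in adjacency.get(node, ()):
            if nbr not in dist:
                dist[nbr] = dist[node] + 1
                queue.append(nbr)
    del dist[source]
    return dist


def format_neighbors(source, hops):
    """Render the result as text grouped by hop, e.g. for a ToolResponse."""
    lines = [f"Neighbors of node {source}:"]
    for h in sorted(set(hops.values())):
        nodes = sorted(n for n, d in hops.items() if d == h)
        lines.append(f"{h}-hop: {nodes}")
    return "\n".join(lines)


adjacency = {0: [1, 2], 1: [0, 3], 2: [0], 3: [1, 4], 4: [3]}
print(format_neighbors(0, k_hop_neighbors(adjacency, 0, 2)))
# Neighbors of node 0:
# 1-hop: [1, 2]
# 2-hop: [3]
```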

3. When constructing the training dataset for RL, you need to tell the agent how to use these tools. For example, in examples/data_preprocess/grasp.py, we write system_instruction by combining the instruction and the tool schema.
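As a rough illustration, the sketch below renders tool schemas into a system message and packs one sample into the row layout used by verl's example preprocessing scripts (`data_source`, `prompt`, `ability`, `reward_model`, `extra_info`). The instruction text, the `data_source` and `ability` values, the `agent_name` field, the abbreviated schema, and the helper names are our assumptions, not copied from `examples/data_preprocess/grasp.py`.

```python
# Hypothetical sketch of building one RL training row; field names follow verl's
# example preprocessing scripts, while the instruction text and values are ours.
import json

TOOL_SCHEMAS = [{
    "type": "function",
    "function": {
        "name": "query_neighbor_information",
        "description": "Retrieves information about the neighbors of a specified node.",
        "parameters": {
            "type": "object",
            "properties": {
                "node_index": {"type": "integer"},
                "dataset_name": {"type": "string"},
                "k": {"type": "integer"},
            },
            "required": ["node_index", "dataset_name", "k"],
        },
    },
}]


def build_system_instruction(instruction, tool_schemas):
    """Append the JSON tool schemas to a task instruction."""
    tools = "\n".join(json.dumps(schema) for schema in tool_schemas)
    return (f"{instruction}\n\nYou may call the following tools:\n<tools>\n{tools}\n</tools>\n"
            'Call a tool with <tool_call>{"name": ..., "arguments": {...}}</tool_call>.')


def build_row(index, question, answer, tool_schemas, split="train"):
    system = build_system_instruction("You are an expert in graph reasoning.", tool_schemas)
    return {
        "data_source": "grasp",       # assumed identifier used to select a reward function
        "agent_name": "tool_agent",   # assumed: routes the sample to verl's tool agent loop
        "prompt": [{"role": "system", "content": system},
                   {"role": "user", "content": question}],
        "ability": "graph",
        "reward_model": {"style": "rule", "ground_truth": answer},
        "extra_info": {"split": split, "index": index},
    }
```

Rows like this would typically be collected into a list and written to parquet for the `train_path` and `test_path` settings of the training script.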

4. When writing the RL training script, besides the general parameters, there are some tool-specific ones:

actor_rollout_ref.rollout.multi_turn.tool_config_path="configs/grasp_tool_config.yaml" \
actor_rollout_ref.rollout.multi_turn.max_user_turns=10 \
actor_rollout_ref.rollout.multi_turn.max_assistant_turns=10 \
actor_rollout_ref.rollout.multi_turn.max_tool_response_length=2000 \
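How these limits bound an episode can be pictured with a toy rollout loop. In this hypothetical simplification (not verl's `ToolAgentLoop`), assistant turns count model generations, user turns count tool-response turns, and each tool response is cut to `max_tool_response_length`; we assume the limit counts characters and mark cuts with `...(truncated)`. The model and tool are fake callables.

```python
# Toy multi-turn rollout showing how turn limits and response truncation bound
# an episode. A hypothetical simplification, not verl's ToolAgentLoop.
import re

TOOL_CALL_RE = re.compile(r"<tool_call>(.*?)</tool_call>", re.DOTALL)


def truncate(text, max_len, marker="...(truncated)"):
    """Keep at most max_len characters of a tool response (unit assumed: chars)."""
    return text if len(text) <= max_len else text[:max_len] + marker


def rollout(model, tool, prompt, max_assistant_turns=10, max_user_turns=10,
            max_tool_response_length=2000):
    messages = [{"role": "user", "content": prompt}]
    assistant_turns = user_turns = 0
    while assistant_turns < max_assistant_turns:
        reply = model(messages)
        assistant_turns += 1
        messages.append({"role": "assistant", "content": reply})
        calls = TOOL_CALL_RE.findall(reply)
        # Stop on a final answer or when either turn budget is exhausted.
        if (not calls or assistant_turns >= max_assistant_turns
                or user_turns >= max_user_turns):
            break
        responses = [truncate(tool(call), max_tool_response_length) for call in calls]
        messages.append({"role": "tool", "content": "\n".join(responses)})
        user_turns += 1
    return messages


def always_call(messages):
    """Fake policy that requests a tool call on every turn."""
    return "<tool_call>{}</tool_call>"


def long_tool(call):
    """Fake tool that returns an oversized response."""
    return "x" * 5000


msgs = rollout(always_call, long_tool, "q", max_assistant_turns=3, max_tool_response_length=10)
print(len(msgs), msgs[2]["content"])  # 6 xxxxxxxxxx...(truncated)
```

Capping both turns and response length keeps every trajectory within a predictable token budget, which matters when batching many rollouts per GPU.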

5. In our experience, the base model (Qwen3-4B-Thinking-2507) understands these two tools well and improves through trial and error during RL. However, if the base model performs poorly with your tools, we recommend first revising the system prompt and tool configurations. If that does not help, supervised fine-tuning may be a good option.

🔔 verl also provides Multi-turn Rollout Support documentation in its official docs, along with a quick start.

Citation

If you find this work useful, please cite our paper:

About

The public implementation of the paper "GRASP: Graph Reasoning via Agentic Solving and Probing of LLMs"
