
EdgePrompt #247

Open
Yilong-sudo wants to merge 2 commits into BUPT-GAMMA:main from Yilong-sudo:EdgePrompt

Conversation

@Yilong-sudo


Description

Checklist

Please feel free to remove inapplicable items for your PR.

  • The PR title starts with [$CATEGORY] (such as [NN], [Model], [Doc], [Feature])
  • Changes are complete (i.e. I finished coding on this PR)
  • All changes have test coverage
  • Code is well-documented
  • To the best of my knowledge, examples are either not affected by this change,
    or have been fixed to be compatible with this change
  • Related issue is referred in this PR

Changes

Copilot AI review requested due to automatic review settings March 23, 2026 15:39

Copilot AI left a comment


Pull request overview

This PR adds an EdgePrompt/EdgePrompt+ implementation to GammaGL, along with utility helpers and runnable example scripts for pretraining and few-shot downstream node classification.

Changes:

  • Introduces EdgePromptGCNModel and EdgePromptNodeClassifier with EdgePrompt/EdgePrompt+ prompt injection.
  • Adds node_subgraph (node-centered k-hop subgraph as a Graph) and get_few_shot_split (few-shot train/test sampling).
  • Adds end-to-end example scripts (edge pretraining + downstream few-shot finetuning) and accompanying README.

Reviewed changes

Copilot reviewed 9 out of 9 changed files in this pull request and generated 11 comments.

| File | Description |
| --- | --- |
| `gammagl/utils/subgraph.py` | Adds node_subgraph helper built on k_hop_subgraph. |
| `gammagl/utils/get_split.py` | Adds get_few_shot_split for few-shot train/test index sampling. |
| `gammagl/utils/__init__.py` | Exposes node_subgraph and get_few_shot_split via utils public API. |
| `gammagl/models/edgeprompt.py` | New EdgePrompt/EdgePrompt+ modules and GCN backbone + node classifier wrapper. |
| `gammagl/models/__init__.py` | Exposes new EdgePrompt models via models public API. |
| `examples/edgeprompt/node_edgeprompt_pretrain.py` | New example script for masked-edge link prediction pretraining. |
| `examples/edgeprompt/node_edgeprompt_finetune.py` | New example script for few-shot downstream node classification. |
| `examples/edgeprompt/__init__.py` | Marks EdgePrompt examples as a package. |
| `examples/edgeprompt/README.md` | Usage instructions and reported results for the new examples. |


Comment on lines +60 to +81
def get_few_shot_split(labels, num_shots, test_ratio=0.2, random_state=0):
    """Sample a minimal few-shot train/test split for node classification.

    This follows the original EdgePrompt node downstream protocol closely:
    sample up to ``num_shots`` nodes per class for training, remove those nodes
    from the candidate pool, then draw a random test subset from the remaining
    nodes.
    """
    if test_ratio <= 0 or test_ratio > 1:
        raise ValueError('test_ratio must be in (0, 1].')

    labels = tlx.reshape(labels, (-1,))
    labels_np = tlx.convert_to_numpy(labels)
    rng = np.random.RandomState(random_state)

    train_indices = []
    for cls in np.unique(labels_np):
        cls_indices = np.where(labels_np == cls)[0]
        if cls_indices.shape[0] <= num_shots:
            train_indices.extend(cls_indices.tolist())
        else:
            train_indices.extend(rng.choice(cls_indices, size=num_shots, replace=False).tolist())

Copilot AI Mar 23, 2026


get_few_shot_split does not validate num_shots. For num_shots <= 0, the current code will sample zero elements per class (or behave unexpectedly), which is likely not intended for a “few-shot” split. Consider adding a check that num_shots is a positive integer and raising a ValueError otherwise.
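
One way the requested check could look is sketched below. The helper name `check_num_shots` is hypothetical and not part of the PR; the sketch only assumes that a non-integer or non-positive shot count should be rejected with a ValueError, as the comment proposes.

```python
import numbers


def check_num_shots(num_shots):
    """Hypothetical helper: reject non-integer or non-positive shot counts."""
    # bool is a subclass of int, so it is excluded explicitly.
    if isinstance(num_shots, bool) or not isinstance(num_shots, numbers.Integral):
        raise ValueError(
            "num_shots must be a positive integer, got {!r}.".format(num_shots)
        )
    if num_shots <= 0:
        raise ValueError(
            "num_shots must be a positive integer, got {!r}.".format(num_shots)
        )
    return int(num_shots)
```

Such a check would most naturally sit next to the existing test_ratio validation at the top of get_few_shot_split.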

Comment on lines +60 to +92
def get_few_shot_split(labels, num_shots, test_ratio=0.2, random_state=0):
    """Sample a minimal few-shot train/test split for node classification.

    This follows the original EdgePrompt node downstream protocol closely:
    sample up to ``num_shots`` nodes per class for training, remove those nodes
    from the candidate pool, then draw a random test subset from the remaining
    nodes.
    """
    if test_ratio <= 0 or test_ratio > 1:
        raise ValueError('test_ratio must be in (0, 1].')

    labels = tlx.reshape(labels, (-1,))
    labels_np = tlx.convert_to_numpy(labels)
    rng = np.random.RandomState(random_state)

    train_indices = []
    for cls in np.unique(labels_np):
        cls_indices = np.where(labels_np == cls)[0]
        if cls_indices.shape[0] <= num_shots:
            train_indices.extend(cls_indices.tolist())
        else:
            train_indices.extend(rng.choice(cls_indices, size=num_shots, replace=False).tolist())

    train_set = set(train_indices)
    remaining_indices = [idx for idx in rng.permutation(labels_np.shape[0]).tolist() if idx not in train_set]
    num_test = max(1, int(test_ratio * labels_np.shape[0]))
    num_test = min(num_test, len(remaining_indices))
    test_indices = remaining_indices[:num_test]

    return (
        tlx.convert_to_tensor(np.asarray(train_indices, dtype=np.int64), dtype=tlx.int64),
        tlx.convert_to_tensor(np.asarray(test_indices, dtype=np.int64), dtype=tlx.int64),
    )

Copilot AI Mar 23, 2026


There is existing test coverage for get_train_val_test_split in tests/utils/test_get_split.py, but the newly added get_few_shot_split is untested. Adding a unit test that checks per-class shot counts, reproducibility via random_state, and that train/test indices are disjoint would help keep this utility stable.
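
A rough sketch of the three properties named above follows. To stay self-contained it uses a NumPy-only stand-in, `few_shot_split_np`, which is hypothetical and only mirrors the sampling logic quoted in this thread; an actual test in tests/utils/ would presumably import get_few_shot_split from gammagl.utils instead.

```python
import numpy as np


def few_shot_split_np(labels, num_shots, test_ratio=0.2, random_state=0):
    """NumPy-only stand-in (hypothetical) mirroring the quoted sampling logic."""
    labels = np.asarray(labels).reshape(-1)
    rng = np.random.RandomState(random_state)
    train = []
    for cls in np.unique(labels):
        idx = np.where(labels == cls)[0]
        if idx.shape[0] <= num_shots:
            train.extend(idx.tolist())
        else:
            train.extend(rng.choice(idx, size=num_shots, replace=False).tolist())
    train_set = set(train)
    remaining = [i for i in rng.permutation(labels.shape[0]).tolist()
                 if i not in train_set]
    num_test = min(max(1, int(test_ratio * labels.shape[0])), len(remaining))
    return (np.asarray(train, dtype=np.int64),
            np.asarray(remaining[:num_test], dtype=np.int64))


def test_few_shot_split_properties():
    labels = np.repeat(np.arange(3), 10)  # 3 classes, 10 nodes each
    train, test = few_shot_split_np(labels, num_shots=2, random_state=0)

    # Per-class shot counts: exactly num_shots when the class is large enough.
    assert np.bincount(labels[train], minlength=3).tolist() == [2, 2, 2]

    # Train and test indices are disjoint.
    assert set(train.tolist()).isdisjoint(test.tolist())

    # Same random_state gives the same split.
    train2, test2 = few_shot_split_np(labels, num_shots=2, random_state=0)
    assert np.array_equal(train, train2)
    assert np.array_equal(test, test2)
```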

- Paper link: [https://arxiv.org/abs/2503.00750](https://arxiv.org/abs/2503.00750)
- Author's code repo: [https://github.com/xbfu/EdgePrompt](https://github.com/xbfu/EdgePrompt)

## Dataset Statics

Copilot AI Mar 23, 2026


Spelling: “Dataset Statics” should be “Dataset Statistics”.

Suggested change
- ## Dataset Statics
+ ## Dataset Statistics

Comment on lines +61 to +64
# Add self-loops first so every node can aggregate its own feature.
num_nodes = int(tlx.get_tensor_shape(x)[0])
edge_index, _ = add_self_loops(edge_index, num_nodes=num_nodes)

Copilot AI Mar 23, 2026


EdgePromptGCNConv unconditionally adds self-loops, but edge_prompt is computed outside (and EdgePromptPlus.get_prompt also adds self-loops). This can make edge_prompt.shape[0] not match the post-add_self_loops edge count, causing incorrect broadcasting or runtime shape errors. Consider adding self-loops exactly once (e.g., in EdgePromptGCNModel.forward) and passing the same edge_index to both prompt.get_prompt(...) and conv(...), or have conv only add self-loops when edge_prompt is None/already aligned.

Suggested change
- # Add self-loops first so every node can aggregate its own feature.
- num_nodes = int(tlx.get_tensor_shape(x)[0])
- edge_index, _ = add_self_loops(edge_index, num_nodes=num_nodes)
+ # Compute number of nodes from input features.
+ num_nodes = int(tlx.get_tensor_shape(x)[0])
+ # Add self-loops only when no external edge prompt is provided.
+ # When edge_prompt is not None, we assume edge_index is already
+ # aligned with edge_prompt (including any self-loops).
+ if edge_prompt is None:
+     edge_index, _ = add_self_loops(edge_index, num_nodes=num_nodes)
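
The "add self-loops exactly once" option can be illustrated with a small NumPy sketch. Both `add_self_loops_np` and `check_prompt_alignment` are hypothetical stand-ins, not GammaGL functions, and whether the PR's current flow really misaligns depends on parts of the diff not quoted here. The sketch only shows the shape invariant the comment is concerned with: a per-edge prompt needs one row per edge actually used in message passing.

```python
import numpy as np


def add_self_loops_np(edge_index, num_nodes):
    """NumPy stand-in (hypothetical): append one (i, i) edge per node."""
    loops = np.tile(np.arange(num_nodes), (2, 1))
    return np.concatenate([edge_index, loops], axis=1)


def check_prompt_alignment(edge_prompt, edge_index):
    """Hypothetical guard: prompt is global (1 row) or has one row per edge."""
    num_edges = edge_index.shape[1]
    rows = edge_prompt.shape[0]
    if rows not in (1, num_edges):
        raise ValueError(
            "edge_prompt has {} rows but edge_index has {} edges.".format(
                rows, num_edges
            )
        )


num_nodes = 4
edge_index = np.array([[0, 1, 2], [1, 2, 3]])  # 3 edges

# Add self-loops once, then reuse the same edge_index for prompt and conv.
full_edge_index = add_self_loops_np(edge_index, num_nodes)  # 3 + 4 = 7 edges
edge_prompt = np.zeros((full_edge_index.shape[1], 8))
check_prompt_alignment(edge_prompt, full_edge_index)  # aligned, no error

# Adding self-loops a second time yields 11 edges, which no longer matches
# the 7 prompt rows; check_prompt_alignment would raise for this pair.
twice = add_self_loops_np(full_edge_index, num_nodes)
```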

Comment on lines +85 to +86
num_test = max(1, int(test_ratio * labels_np.shape[0]))
num_test = min(num_test, len(remaining_indices))

Copilot AI Mar 23, 2026


When num_shots is large enough to consume nearly all nodes, remaining_indices can be empty; in that case num_test becomes 0 after min(...), returning an empty test split even if test_ratio > 0. Either raise a clear error when there are no remaining candidates for testing, or document that the test set may be empty and adjust the max(1, ...) logic accordingly.

Suggested change
- num_test = max(1, int(test_ratio * labels_np.shape[0]))
- num_test = min(num_test, len(remaining_indices))
+ if len(remaining_indices) == 0:
+     raise ValueError(
+         "No remaining nodes are available for the test set after sampling training shots. "
+         "Consider reducing 'num_shots' or 'test_ratio'."
+     )
+ num_test = min(max(1, int(test_ratio * labels_np.shape[0])), len(remaining_indices))
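
To make the edge case concrete, here is the test-size arithmetic in isolation. The function `num_test_nodes` is a hypothetical helper written for this illustration, not code from the PR.

```python
def num_test_nodes(num_nodes, num_train, test_ratio=0.2):
    """Hypothetical helper isolating the guarded test-size arithmetic."""
    remaining = num_nodes - num_train
    if remaining <= 0:
        raise ValueError(
            "No remaining nodes are available for the test set; "
            "consider reducing num_shots."
        )
    return min(max(1, int(test_ratio * num_nodes)), remaining)


# Without the guard, 6 nodes with all 6 used for training would give
# min(max(1, int(0.2 * 6)), 0) == 0, i.e. a silently empty test split.
unguarded = min(max(1, int(0.2 * 6)), 0)
```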

Comment on lines +160 to +163
edge_index, _ = add_self_loops(
    edge_index,
    num_nodes=int(tlx.get_tensor_shape(x)[0]),
)

Copilot AI Mar 23, 2026


EdgePromptPlus.get_prompt calls add_self_loops, but EdgePromptGCNConv.forward also adds self-loops. With the current flow (EdgePromptGCNModel computes the prompt before calling conv), this double-modifies edge_index and will misalign the generated per-edge prompts with the edges actually used by the convolution. Suggest removing self-loop addition from either the prompt generator or the convolution, and ensuring prompts are generated for the exact edge_index used in message passing.

Suggested change
- edge_index, _ = add_self_loops(
-     edge_index,
-     num_nodes=int(tlx.get_tensor_shape(x)[0]),
- )

Comment on lines +173 to +253
class EdgePromptGCNModel(tlx.nn.Module):
    r"""A stacked GCN backbone for node or graph representations with EdgePrompt.

    Parameters
    ----------
    feature_dim: int
        Dimension of the input node features.
    hidden_dim: int
        Dimension of hidden representations.
    num_layers: int, optional
        Number of GCN layers.
    drop_rate: float, optional
        Dropout rate applied between hidden layers.
    name: str, optional
        Name of the module.
    """
    def __init__(
        self,
        feature_dim: int,
        hidden_dim: int,
        num_layers: int = 2,
        drop_rate: float = 0.5,
        name: Optional[str] = None,
    ):
        super().__init__(name=name)

        if num_layers < 1:
            raise ValueError("num_layers must be at least 1.")

        self.feature_dim = feature_dim
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.prompt_dims = [feature_dim] + [hidden_dim] * (num_layers - 1)

        self.convs = nn.ModuleList()
        in_dims = self.prompt_dims
        out_dims = [hidden_dim] * num_layers
        for layer, (in_dim, out_dim) in enumerate(zip(in_dims, out_dims)):
            self.convs.append(
                EdgePromptGCNConv(
                    in_dim,
                    out_dim,
                    name="edgeprompt_conv_{}".format(layer),
                )
            )

        self.relu = tlx.ReLU()
        self.dropout = nn.Dropout(p=drop_rate)

    def forward(self, graph, prompt_type=None, prompt=None, pooling=None):
        # When prompt is enabled, each layer obtains its own edge prompt tensor.
        x, edge_index = graph.x, graph.edge_index
        prompt_type = normalize_prompt_type(prompt_type)

        for layer, conv in enumerate(self.convs):
            edge_prompt = None
            if prompt is not None and prompt_type in ("EdgePrompt", "EdgePromptplus"):
                edge_prompt = prompt.get_prompt(x, edge_index, layer)

            x = conv(x, edge_index, edge_prompt=edge_prompt)
            if layer != self.num_layers - 1:
                x = self.relu(x)
                x = self.dropout(x)

        if pooling == "mean":
            # Graph-level mean pooling for batched graphs.
            batch = getattr(graph, "batch", None)
            if batch is None:
                raise ValueError("Mean pooling requires batched graphs with `batch`.")
            return global_mean_pool(x, batch)

        if pooling == "target":
            # Gather the designated target node from each sampled subgraph.
            if not hasattr(graph, "ptr") or not hasattr(graph, "target_node"):
                raise ValueError(
                    "Target pooling requires batched subgraphs with `ptr` and `target_node`."
                )
            target_index = graph.ptr[:-1] + tlx.reshape(graph.target_node, (-1,))
            return tlx.gather(x, target_index)

        return x

Copilot AI Mar 23, 2026


New model code (EdgePromptGCNModel/EdgePromptNodeClassifier) is added but there are no corresponding unit tests under tests/models/ (the repo already has model tests, e.g. tests/models/test_mlp.py). Adding at least a smoke test that checks forward shapes (with/without prompts, and pooling="mean"/"target") would help prevent regressions.

Comment thread gammagl/utils/subgraph.py
Comment on lines +89 to +106
subset, edge_index, mapping, _ = k_hop_subgraph(
    node_idx=node_idx,
    num_hops=num_hops,
    edge_index=graph.edge_index,
    relabel_nodes=True,
    num_nodes=graph.num_nodes,
)
if tlx.is_tensor(mapping):
    mapping_np = tlx.convert_to_numpy(mapping)
else:
    mapping_np = np.asarray(mapping)
target_node = int(np.asarray(mapping_np).reshape(-1)[0])
return Graph(
    x=tlx.gather(graph.x, subset),
    edge_index=edge_index,
    target_node=tlx.convert_to_tensor([target_node], dtype=tlx.int64),
    num_nodes=int(tlx.get_tensor_shape(subset)[0]),
)

Copilot AI Mar 23, 2026


node_subgraph drops edge-level attributes from the original graph (e.g., edge_attr) even though k_hop_subgraph returns edge_mask specifically to support filtering aligned edge features. If callers use graphs with edge_attr, the returned Graph will be inconsistent with edge_index. Consider propagating edge_attr (and any other edge-level fields you want to support) using the edge mask.
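
The idea of filtering edge features with the edge mask can be sketched in plain NumPy. The toy arrays and the mask values below are made up for illustration, and the sketch makes no claim about how GammaGL tensors would be indexed. It only shows that applying the same boolean mask to edge_index columns and edge_attr rows keeps the two aligned.

```python
import numpy as np

# Toy graph: 4 edges, each with a 2-dimensional attribute.
edge_index = np.array([[0, 1, 2, 3],
                       [1, 2, 3, 0]])
edge_attr = np.arange(8, dtype=np.float32).reshape(4, 2)

# Hypothetical mask of the kind a k-hop routine returns: True for kept edges.
edge_mask = np.array([True, True, False, False])

# Apply the same mask to both, so row i of sub_edge_attr describes
# column i of sub_edge_index (node relabeling is a separate step).
sub_edge_index = edge_index[:, edge_mask]
sub_edge_attr = edge_attr[edge_mask]
```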

Comment thread gammagl/utils/subgraph.py
    mapping_np = tlx.convert_to_numpy(mapping)
else:
    mapping_np = np.asarray(mapping)
target_node = int(np.asarray(mapping_np).reshape(-1)[0])

Copilot AI Mar 23, 2026


node_subgraph is described as “node-centered”, but it will silently accept multi-node node_idx inputs because k_hop_subgraph supports lists/tuples; it then picks only the first entry from mapping as target_node. This can lead to incorrect labels/targets without an obvious error. Consider validating that node_idx refers to a single node (scalar or length-1 tensor/list) and raising a clear ValueError otherwise.

Suggested change
- target_node = int(np.asarray(mapping_np).reshape(-1)[0])
+ mapping_np = np.asarray(mapping_np)
+ if mapping_np.size != 1:
+     raise ValueError(
+         f"node_subgraph expects a single central node, but got {mapping_np.size} "
+         "nodes from k_hop_subgraph. Ensure that `node_idx` is a scalar or a "
+         "length-1 tensor/list/tuple."
+     )
+ target_node = int(mapping_np.reshape(-1)[0])

Comment thread gammagl/utils/subgraph.py
Comment on lines +87 to +106
def node_subgraph(graph, node_idx, num_hops=2):
    """Return a node-centered k-hop subgraph as a ``Graph`` object."""
    subset, edge_index, mapping, _ = k_hop_subgraph(
        node_idx=node_idx,
        num_hops=num_hops,
        edge_index=graph.edge_index,
        relabel_nodes=True,
        num_nodes=graph.num_nodes,
    )
    if tlx.is_tensor(mapping):
        mapping_np = tlx.convert_to_numpy(mapping)
    else:
        mapping_np = np.asarray(mapping)
    target_node = int(np.asarray(mapping_np).reshape(-1)[0])
    return Graph(
        x=tlx.gather(graph.x, subset),
        edge_index=edge_index,
        target_node=tlx.convert_to_tensor([target_node], dtype=tlx.int64),
        num_nodes=int(tlx.get_tensor_shape(subset)[0]),
    )

Copilot AI Mar 23, 2026


There is existing test coverage for k_hop_subgraph in tests/utils/test_k_hop_subgraph.py, but the new node_subgraph helper is untested. Adding a small unit test (e.g., verify target_node is correct, edge_index is relabeled, and x matches the subset) would align with the repo’s existing utils test patterns.


2 participants