EdgePrompt#247
Conversation
There was a problem hiding this comment.
Pull request overview
This PR adds an EdgePrompt/EdgePrompt+ implementation to GammaGL, along with utility helpers and runnable example scripts for pretraining and few-shot downstream node classification.
Changes:
- Introduces
EdgePromptGCNModelandEdgePromptNodeClassifierwith EdgePrompt/EdgePrompt+ prompt injection. - Adds
node_subgraph(node-centered k-hop subgraph as aGraph) andget_few_shot_split(few-shot train/test sampling). - Adds end-to-end example scripts (edge pretraining + downstream few-shot finetuning) and accompanying README.
Reviewed changes
Copilot reviewed 9 out of 9 changed files in this pull request and generated 11 comments.
Show a summary per file
| File | Description |
|---|---|
| gammagl/utils/subgraph.py | Adds node_subgraph helper built on k_hop_subgraph. |
| gammagl/utils/get_split.py | Adds get_few_shot_split for few-shot train/test index sampling. |
| gammagl/utils/init.py | Exposes node_subgraph and get_few_shot_split via utils public API. |
| gammagl/models/edgeprompt.py | New EdgePrompt/EdgePrompt+ modules and GCN backbone + node classifier wrapper. |
| gammagl/models/init.py | Exposes new EdgePrompt models via models public API. |
| examples/edgeprompt/node_edgeprompt_pretrain.py | New example script for masked-edge link prediction pretraining. |
| examples/edgeprompt/node_edgeprompt_finetune.py | New example script for few-shot downstream node classification. |
| examples/edgeprompt/init.py | Marks EdgePrompt examples as a package. |
| examples/edgeprompt/README.md | Usage instructions and reported results for the new examples. |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| def get_few_shot_split(labels, num_shots, test_ratio=0.2, random_state=0): | ||
| """Sample a minimal few-shot train/test split for node classification. | ||
|
|
||
| This follows the original EdgePrompt node downstream protocol closely: | ||
| sample up to ``num_shots`` nodes per class for training, remove those nodes | ||
| from the candidate pool, then draw a random test subset from the remaining | ||
| nodes. | ||
| """ | ||
| if test_ratio <= 0 or test_ratio > 1: | ||
| raise ValueError('test_ratio must be in (0, 1].') | ||
|
|
||
| labels = tlx.reshape(labels, (-1,)) | ||
| labels_np = tlx.convert_to_numpy(labels) | ||
| rng = np.random.RandomState(random_state) | ||
|
|
||
| train_indices = [] | ||
| for cls in np.unique(labels_np): | ||
| cls_indices = np.where(labels_np == cls)[0] | ||
| if cls_indices.shape[0] <= num_shots: | ||
| train_indices.extend(cls_indices.tolist()) | ||
| else: | ||
| train_indices.extend(rng.choice(cls_indices, size=num_shots, replace=False).tolist()) |
There was a problem hiding this comment.
get_few_shot_split does not validate num_shots. For num_shots <= 0, the current code will sample zero elements per class (or behave unexpectedly), which is likely not intended for a “few-shot” split. Consider adding a check that num_shots is a positive integer and raising a ValueError otherwise.
| def get_few_shot_split(labels, num_shots, test_ratio=0.2, random_state=0): | ||
| """Sample a minimal few-shot train/test split for node classification. | ||
|
|
||
| This follows the original EdgePrompt node downstream protocol closely: | ||
| sample up to ``num_shots`` nodes per class for training, remove those nodes | ||
| from the candidate pool, then draw a random test subset from the remaining | ||
| nodes. | ||
| """ | ||
| if test_ratio <= 0 or test_ratio > 1: | ||
| raise ValueError('test_ratio must be in (0, 1].') | ||
|
|
||
| labels = tlx.reshape(labels, (-1,)) | ||
| labels_np = tlx.convert_to_numpy(labels) | ||
| rng = np.random.RandomState(random_state) | ||
|
|
||
| train_indices = [] | ||
| for cls in np.unique(labels_np): | ||
| cls_indices = np.where(labels_np == cls)[0] | ||
| if cls_indices.shape[0] <= num_shots: | ||
| train_indices.extend(cls_indices.tolist()) | ||
| else: | ||
| train_indices.extend(rng.choice(cls_indices, size=num_shots, replace=False).tolist()) | ||
|
|
||
| train_set = set(train_indices) | ||
| remaining_indices = [idx for idx in rng.permutation(labels_np.shape[0]).tolist() if idx not in train_set] | ||
| num_test = max(1, int(test_ratio * labels_np.shape[0])) | ||
| num_test = min(num_test, len(remaining_indices)) | ||
| test_indices = remaining_indices[:num_test] | ||
|
|
||
| return ( | ||
| tlx.convert_to_tensor(np.asarray(train_indices, dtype=np.int64), dtype=tlx.int64), | ||
| tlx.convert_to_tensor(np.asarray(test_indices, dtype=np.int64), dtype=tlx.int64), | ||
| ) |
There was a problem hiding this comment.
There is existing test coverage for get_train_val_test_split in tests/utils/test_get_split.py, but the newly added get_few_shot_split is untested. Adding a unit test that checks per-class shot counts, reproducibility via random_state, and that train/test indices are disjoint would help keep this utility stable.
| - Paper link: [https://arxiv.org/abs/2503.00750](https://arxiv.org/abs/2503.00750) | ||
| - Author's code repo: [https://github.com/xbfu/EdgePrompt](https://github.com/xbfu/EdgePrompt) | ||
|
|
||
| ## Dataset Statics |
There was a problem hiding this comment.
Spelling: “Dataset Statics” should be “Dataset Statistics”.
| ## Dataset Statics | |
| ## Dataset Statistics |
| # Add self-loops first so every node can aggregate its own feature. | ||
| num_nodes = int(tlx.get_tensor_shape(x)[0]) | ||
| edge_index, _ = add_self_loops(edge_index, num_nodes=num_nodes) | ||
|
|
There was a problem hiding this comment.
EdgePromptGCNConv unconditionally adds self-loops, but edge_prompt is computed outside (and EdgePromptPlus.get_prompt also adds self-loops). This can make edge_prompt.shape[0] not match the post-add_self_loops edge count, causing incorrect broadcasting or runtime shape errors. Consider adding self-loops exactly once (e.g., in EdgePromptGCNModel.forward) and passing the same edge_index to both prompt.get_prompt(...) and conv(...), or have conv only add self-loops when edge_prompt is None/already aligned.
| # Add self-loops first so every node can aggregate its own feature. | |
| num_nodes = int(tlx.get_tensor_shape(x)[0]) | |
| edge_index, _ = add_self_loops(edge_index, num_nodes=num_nodes) | |
| # Compute number of nodes from input features. | |
| num_nodes = int(tlx.get_tensor_shape(x)[0]) | |
| # Add self-loops only when no external edge prompt is provided. | |
| # When edge_prompt is not None, we assume edge_index is already | |
| # aligned with edge_prompt (including any self-loops). | |
| if edge_prompt is None: | |
| edge_index, _ = add_self_loops(edge_index, num_nodes=num_nodes) |
| num_test = max(1, int(test_ratio * labels_np.shape[0])) | ||
| num_test = min(num_test, len(remaining_indices)) |
There was a problem hiding this comment.
When num_shots is large enough to consume nearly all nodes, remaining_indices can be empty; in that case num_test becomes 0 after min(...), returning an empty test split even if test_ratio > 0. Either raise a clear error when there are no remaining candidates for testing, or document that the test set may be empty and adjust the max(1, ...) logic accordingly.
| num_test = max(1, int(test_ratio * labels_np.shape[0])) | |
| num_test = min(num_test, len(remaining_indices)) | |
| if len(remaining_indices) == 0: | |
| raise ValueError( | |
| "No remaining nodes are available for the test set after sampling training shots. " | |
| "Consider reducing 'num_shots' or 'test_ratio'." | |
| ) | |
| num_test = min(max(1, int(test_ratio * labels_np.shape[0])), len(remaining_indices)) |
| edge_index, _ = add_self_loops( | ||
| edge_index, | ||
| num_nodes=int(tlx.get_tensor_shape(x)[0]), | ||
| ) |
There was a problem hiding this comment.
EdgePromptPlus.get_prompt calls add_self_loops, but EdgePromptGCNConv.forward also adds self-loops. With the current flow (EdgePromptGCNModel computes the prompt before calling conv), this double-modifies edge_index and will misalign the generated per-edge prompts with the edges actually used by the convolution. Suggest removing self-loop addition from either the prompt generator or the convolution, and ensuring prompts are generated for the exact edge_index used in message passing.
| edge_index, _ = add_self_loops( | |
| edge_index, | |
| num_nodes=int(tlx.get_tensor_shape(x)[0]), | |
| ) |
| class EdgePromptGCNModel(tlx.nn.Module): | ||
| r"""A stacked GCN backbone for node or graph representations with EdgePrompt. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| feature_dim: int | ||
| Dimension of the input node features. | ||
| hidden_dim: int | ||
| Dimension of hidden representations. | ||
| num_layers: int, optional | ||
| Number of GCN layers. | ||
| drop_rate: float, optional | ||
| Dropout rate applied between hidden layers. | ||
| name: str, optional | ||
| Name of the module. | ||
| """ | ||
| def __init__( | ||
| self, | ||
| feature_dim: int, | ||
| hidden_dim: int, | ||
| num_layers: int = 2, | ||
| drop_rate: float = 0.5, | ||
| name: Optional[str] = None, | ||
| ): | ||
| super().__init__(name=name) | ||
|
|
||
| if num_layers < 1: | ||
| raise ValueError("num_layers must be at least 1.") | ||
|
|
||
| self.feature_dim = feature_dim | ||
| self.hidden_dim = hidden_dim | ||
| self.num_layers = num_layers | ||
| self.prompt_dims = [feature_dim] + [hidden_dim] * (num_layers - 1) | ||
|
|
||
| self.convs = nn.ModuleList() | ||
| in_dims = self.prompt_dims | ||
| out_dims = [hidden_dim] * num_layers | ||
| for layer, (in_dim, out_dim) in enumerate(zip(in_dims, out_dims)): | ||
| self.convs.append( | ||
| EdgePromptGCNConv( | ||
| in_dim, | ||
| out_dim, | ||
| name="edgeprompt_conv_{}".format(layer), | ||
| ) | ||
| ) | ||
|
|
||
| self.relu = tlx.ReLU() | ||
| self.dropout = nn.Dropout(p=drop_rate) | ||
|
|
||
| def forward(self, graph, prompt_type=None, prompt=None, pooling=None): | ||
| # When prompt is enabled, each layer obtains its own edge prompt tensor. | ||
| x, edge_index = graph.x, graph.edge_index | ||
| prompt_type = normalize_prompt_type(prompt_type) | ||
|
|
||
| for layer, conv in enumerate(self.convs): | ||
| edge_prompt = None | ||
| if prompt is not None and prompt_type in ("EdgePrompt", "EdgePromptplus"): | ||
| edge_prompt = prompt.get_prompt(x, edge_index, layer) | ||
|
|
||
| x = conv(x, edge_index, edge_prompt=edge_prompt) | ||
| if layer != self.num_layers - 1: | ||
| x = self.relu(x) | ||
| x = self.dropout(x) | ||
|
|
||
| if pooling == "mean": | ||
| # Graph-level mean pooling for batched graphs. | ||
| batch = getattr(graph, "batch", None) | ||
| if batch is None: | ||
| raise ValueError("Mean pooling requires batched graphs with `batch`.") | ||
| return global_mean_pool(x, batch) | ||
|
|
||
| if pooling == "target": | ||
| # Gather the designated target node from each sampled subgraph. | ||
| if not hasattr(graph, "ptr") or not hasattr(graph, "target_node"): | ||
| raise ValueError( | ||
| "Target pooling requires batched subgraphs with `ptr` and `target_node`." | ||
| ) | ||
| target_index = graph.ptr[:-1] + tlx.reshape(graph.target_node, (-1,)) | ||
| return tlx.gather(x, target_index) | ||
|
|
||
| return x |
There was a problem hiding this comment.
New model code (EdgePromptGCNModel/EdgePromptNodeClassifier) is added but there are no corresponding unit tests under tests/models/ (the repo already has model tests, e.g. tests/models/test_mlp.py). Adding at least a smoke test that checks forward shapes (with/without prompts, and pooling="mean"/"target") would help prevent regressions.
| subset, edge_index, mapping, _ = k_hop_subgraph( | ||
| node_idx=node_idx, | ||
| num_hops=num_hops, | ||
| edge_index=graph.edge_index, | ||
| relabel_nodes=True, | ||
| num_nodes=graph.num_nodes, | ||
| ) | ||
| if tlx.is_tensor(mapping): | ||
| mapping_np = tlx.convert_to_numpy(mapping) | ||
| else: | ||
| mapping_np = np.asarray(mapping) | ||
| target_node = int(np.asarray(mapping_np).reshape(-1)[0]) | ||
| return Graph( | ||
| x=tlx.gather(graph.x, subset), | ||
| edge_index=edge_index, | ||
| target_node=tlx.convert_to_tensor([target_node], dtype=tlx.int64), | ||
| num_nodes=int(tlx.get_tensor_shape(subset)[0]), | ||
| ) |
There was a problem hiding this comment.
node_subgraph drops edge-level attributes from the original graph (e.g., edge_attr) even though k_hop_subgraph returns edge_mask specifically to support filtering aligned edge features. If callers use graphs with edge_attr, the returned Graph will be inconsistent with edge_index. Consider propagating edge_attr (and any other edge-level fields you want to support) using the edge mask.
| mapping_np = tlx.convert_to_numpy(mapping) | ||
| else: | ||
| mapping_np = np.asarray(mapping) | ||
| target_node = int(np.asarray(mapping_np).reshape(-1)[0]) |
There was a problem hiding this comment.
node_subgraph is described as “node-centered”, but it will silently accept multi-node node_idx inputs because k_hop_subgraph supports lists/tuples; it then picks only the first entry from mapping as target_node. This can lead to incorrect labels/targets without an obvious error. Consider validating that node_idx refers to a single node (scalar or length-1 tensor/list) and raising a clear ValueError otherwise.
| target_node = int(np.asarray(mapping_np).reshape(-1)[0]) | |
| mapping_np = np.asarray(mapping_np) | |
| if mapping_np.size != 1: | |
| raise ValueError( | |
| f"node_subgraph expects a single central node, but got {mapping_np.size} " | |
| "nodes from k_hop_subgraph. Ensure that `node_idx` is a scalar or a " | |
| "length-1 tensor/list/tuple." | |
| ) | |
| target_node = int(mapping_np.reshape(-1)[0]) |
| def node_subgraph(graph, node_idx, num_hops=2): | ||
| """Return a node-centered k-hop subgraph as a ``Graph`` object.""" | ||
| subset, edge_index, mapping, _ = k_hop_subgraph( | ||
| node_idx=node_idx, | ||
| num_hops=num_hops, | ||
| edge_index=graph.edge_index, | ||
| relabel_nodes=True, | ||
| num_nodes=graph.num_nodes, | ||
| ) | ||
| if tlx.is_tensor(mapping): | ||
| mapping_np = tlx.convert_to_numpy(mapping) | ||
| else: | ||
| mapping_np = np.asarray(mapping) | ||
| target_node = int(np.asarray(mapping_np).reshape(-1)[0]) | ||
| return Graph( | ||
| x=tlx.gather(graph.x, subset), | ||
| edge_index=edge_index, | ||
| target_node=tlx.convert_to_tensor([target_node], dtype=tlx.int64), | ||
| num_nodes=int(tlx.get_tensor_shape(subset)[0]), | ||
| ) |
There was a problem hiding this comment.
There is existing test coverage for k_hop_subgraph in tests/utils/test_k_hop_subgraph.py, but the new node_subgraph helper is untested. Adding a small unit test (e.g., verify target_node is correct, edge_index is relabeled, and x matches the subset) would align with the repo’s existing utils test patterns.
Description
Checklist
Please feel free to remove inapplicable items for your PR.
or have been fixed to be compatible with this change
Changes