From 4d1cf76edfab19dd1a2757cba4ae00cca3e50511 Mon Sep 17 00:00:00 2001 From: "Yilong.Niu" <3384756954@qq.com> Date: Mon, 23 Mar 2026 23:30:28 +0800 Subject: [PATCH] EdgePrompt --- examples/edgeprompt/README.md | 46 +++ examples/edgeprompt/__init__.py | 1 + .../edgeprompt/node_edgeprompt_finetune.py | 185 +++++++++ .../edgeprompt/node_edgeprompt_pretrain.py | 379 ++++++++++++++++++ gammagl/models/__init__.py | 7 +- gammagl/models/edgeprompt.py | 316 +++++++++++++++ gammagl/utils/__init__.py | 6 +- gammagl/utils/get_split.py | 37 +- gammagl/utils/subgraph.py | 26 +- 9 files changed, 997 insertions(+), 6 deletions(-) create mode 100644 examples/edgeprompt/README.md create mode 100644 examples/edgeprompt/__init__.py create mode 100644 examples/edgeprompt/node_edgeprompt_finetune.py create mode 100644 examples/edgeprompt/node_edgeprompt_pretrain.py create mode 100644 gammagl/models/edgeprompt.py diff --git a/examples/edgeprompt/README.md b/examples/edgeprompt/README.md new file mode 100644 index 000000000..e3751fc7f --- /dev/null +++ b/examples/edgeprompt/README.md @@ -0,0 +1,46 @@ +# EdgePrompt / EdgePrompt+ + +- Paper link: [https://arxiv.org/abs/2503.00750](https://arxiv.org/abs/2503.00750) +- Author's code repo: [https://github.com/xbfu/EdgePrompt](https://github.com/xbfu/EdgePrompt) + +## Dataset Statistics + +| Dataset | # Nodes | # Edges | # Classes | +|----------|---------|---------|-----------| +| Cora | 2,708 | 10,556 | 7 | +| CiteSeer | 3,327 | 9,228 | 6 | +| PubMed | 19,717 | 88,651 | 3 | + +Refer to [Planetoid](https://gammagl.readthedocs.io/en/latest/api/gammagl.datasets.html#gammagl.datasets.Planetoid). 
+ +### Pretraining (produces the checkpoint required by `--pretrained_path`) + +```bash +cd examples/edgeprompt +TL_BACKEND=torch python node_edgeprompt_pretrain.py --dataset Cora --epochs 100 --seed 0 +``` + +### Downstream EdgePrompt + +```bash +cd examples/edgeprompt +TL_BACKEND=torch python node_edgeprompt_finetune.py --dataset Cora --method edgeprompt --num_shots 5 --pretrained_path ./cora_ep_gppt_backbone.npz --epochs 100 --seed 0 +``` + +### Downstream EdgePrompt+ + +```bash +cd examples/edgeprompt +TL_BACKEND=torch python node_edgeprompt_finetune.py --dataset Cora --method edgeprompt_plus --num_shots 5 --pretrained_path ./cora_ep_gppt_backbone.npz --epochs 100 --seed 0 +``` + +## Results + +| Dataset | Method | Paper | Ours (Torch) | +|----------|--------------|-------|------------| +| Cora | EdgePrompt | 37.26 | 73.20 | +| Cora | EdgePrompt+ | 56.41 | 73.94 | +| CiteSeer | EdgePrompt | 29.83 | 46.77 | +| CiteSeer | EdgePrompt+ | 43.49 | 47.37 | +| PubMed | EdgePrompt | 47.20 | 51.05 | +| PubMed | EdgePrompt+ | 61.51 | 55.62 | diff --git a/examples/edgeprompt/__init__.py b/examples/edgeprompt/__init__.py new file mode 100644 index 000000000..9a031a99c --- /dev/null +++ b/examples/edgeprompt/__init__.py @@ -0,0 +1 @@ +# EdgePrompt example package. 
diff --git a/examples/edgeprompt/node_edgeprompt_finetune.py b/examples/edgeprompt/node_edgeprompt_finetune.py new file mode 100644 index 000000000..4c666bd18 --- /dev/null +++ b/examples/edgeprompt/node_edgeprompt_finetune.py @@ -0,0 +1,185 @@ +import argparse +import os +import random + +os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2" + +import numpy as np +import tensorlayerx as tlx +from gammagl.datasets import Planetoid +from gammagl.loader import DataLoader +from gammagl.models import EdgePromptGCNModel, EdgePromptNodeClassifier +from gammagl.utils import get_few_shot_split, node_subgraph +from tensorlayerx.model import TrainOneStep, WithLoss + + +HIDDEN_DIM = 128 +NUM_LAYERS = 2 +DROP_RATE = 0.5 +NUM_ANCHORS = 10 +BATCH_SIZE = 32 +LR = 0.001 +WEIGHT_DECAY = 0.0 +NUM_HOPS = 2 +TEST_RATIO = 0.2 +DATA_ROOT = "Planetoid" + + +def set_random_seed(seed): + random.seed(seed) + np.random.seed(seed) + tlx.set_seed(seed) + + +class NodeClsLoss(WithLoss): + def __init__(self, net, loss_fn): + super().__init__(backbone=net, loss_fn=loss_fn) + + def forward(self, data, y): + del y + logits = forward_subgraph_batch(self.backbone_network, data) + return self._loss_fn(logits, tlx.reshape(data.y, (-1,))) + + +def forward_subgraph_batch(model, graph_batch): + graph_emb = model.backbone( + graph_batch, + prompt_type=model.prompt_type, + prompt=model.prompt, + pooling="mean", + ) + return model.classifier(graph_emb) + + +def evaluate(model, loader): + total_loss = [] + pred_list = [] + label_list = [] + + for batch in loader: + logits = forward_subgraph_batch(model, batch) + labels = tlx.reshape(batch.y, (-1,)) + loss = tlx.losses.softmax_cross_entropy_with_logits(logits, labels) + total_loss.append(float(loss.item())) + pred_list.extend(tlx.convert_to_numpy(tlx.argmax(logits, axis=-1)).tolist()) + label_list.extend(tlx.convert_to_numpy(labels).tolist()) + + test_acc = float((np.asarray(pred_list) == np.asarray(label_list)).mean()) + test_loss = float(np.mean(total_loss)) + return 
test_loss, test_acc + + +def build_subgraphs(graph, labels, node_indices): + labels_np = tlx.convert_to_numpy(tlx.reshape(labels, (-1,))) + node_indices = tlx.convert_to_numpy(tlx.reshape(node_indices, (-1,))).astype(np.int64) + graph_list = [] + for node_idx in node_indices.tolist(): + subgraph = node_subgraph(graph, int(node_idx), num_hops=NUM_HOPS) + subgraph.y = tlx.convert_to_tensor( + np.asarray([labels_np[int(node_idx)]], dtype=np.int64), + dtype=tlx.int64, + ) + graph_list.append(subgraph) + return graph_list + + +def main(args): + if not os.path.exists(args.pretrained_path): + raise FileNotFoundError("Pretrained checkpoint not found: {}".format(args.pretrained_path)) + + set_random_seed(args.seed) + dataset = Planetoid(DATA_ROOT, args.dataset) + graph = dataset[0] + labels = tlx.reshape(graph.y, (-1,)) + train_idx, test_idx = get_few_shot_split( + labels, + num_shots=args.num_shots, + test_ratio=TEST_RATIO, + random_state=args.seed, + ) + train_graphs = build_subgraphs(graph, labels, train_idx) + test_graphs = build_subgraphs(graph, labels, test_idx) + + print( + "Few-shot subgraphs on {}: train={}, test={} ({}-hop, test_ratio={:.2f})".format( + args.dataset, + len(train_graphs), + len(test_graphs), + NUM_HOPS, + TEST_RATIO, + ) + ) + + train_loader = DataLoader(train_graphs, batch_size=BATCH_SIZE, shuffle=True) + test_loader = DataLoader(test_graphs, batch_size=BATCH_SIZE, shuffle=False) + + backbone = EdgePromptGCNModel( + feature_dim=dataset.num_node_features, + hidden_dim=HIDDEN_DIM, + num_layers=NUM_LAYERS, + drop_rate=DROP_RATE, + name="EdgePromptGCN", + ) + backbone.load_weights(args.pretrained_path, format="npz_dict") + + model = EdgePromptNodeClassifier( + backbone=backbone, + num_classes=dataset.num_classes, + prompt_type=args.method, + num_prompts=NUM_ANCHORS, + name="EdgePromptNodeClassifier", + ) + + optimizer = tlx.optimizers.Adam(lr=LR, weight_decay=WEIGHT_DECAY) + loss_func = NodeClsLoss(model, tlx.losses.softmax_cross_entropy_with_logits) 
+ train_one_step = TrainOneStep(loss_func, optimizer, model.tuning_weights()) + + last_test_acc = 0.0 + for epoch in range(1, args.epochs + 1): + model.set_train() + total_train_loss = [] + for batch in train_loader: + train_loss = train_one_step(batch, batch.y) + total_train_loss.append(float(train_loss.item())) + train_loss = float(np.mean(total_train_loss)) + + model.set_eval() + test_loss, test_acc = evaluate(model, test_loader) + last_test_acc = test_acc + print( + "Epoch [{:0>3d}] train loss: {:.4f} test loss: {:.4f} test acc: {:.4f}".format( + epoch, + train_loss, + test_loss, + test_acc, + ) + ) + + print("Final test acc: {:.4f}".format(last_test_acc)) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Downstream task: node classification") + parser.add_argument("--dataset", type=str, default="Cora", choices=["Cora", "CiteSeer", "PubMed"]) + parser.add_argument("--method", type=str, default="edgeprompt", choices=["edgeprompt", "edgeprompt_plus"]) + parser.add_argument("--num_shots", type=int, default=5) + parser.add_argument("--pretrained_path", type=str, required=True) + parser.add_argument("--epochs", type=int, default=200) + parser.add_argument("--seed", type=int, default=0) + + args = parser.parse_args() + if tlx.BACKEND == "torch": + try: + import torch + if torch.cuda.is_available(): + tlx.set_device("GPU", 0) + else: + tlx.set_device("CPU") + except ImportError: + tlx.set_device("CPU") + else: + tlx.set_device("CPU") + + main(args) + + diff --git a/examples/edgeprompt/node_edgeprompt_pretrain.py b/examples/edgeprompt/node_edgeprompt_pretrain.py new file mode 100644 index 000000000..29a9273f3 --- /dev/null +++ b/examples/edgeprompt/node_edgeprompt_pretrain.py @@ -0,0 +1,379 @@ +import argparse +import json +import os +import random + +os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2" + +import numpy as np +import tensorlayerx as tlx +from gammagl.data import Graph +from gammagl.datasets import Planetoid +from gammagl.models import 
EdgePromptGCNModel +from sklearn.metrics import average_precision_score, roc_auc_score +from tensorlayerx.model import TrainOneStep, WithLoss + + +DATASET_NAME_MAP = { + "cora": "Cora", + "citeseer": "CiteSeer", + "pubmed": "PubMed", +} + + +def canonical_dataset_name(dataset): + dataset_key = str(dataset).lower() + if dataset_key not in DATASET_NAME_MAP: + raise ValueError("Unknown dataset: {}".format(dataset)) + return DATASET_NAME_MAP[dataset_key] + + +def set_random_seed(seed): + random.seed(seed) + np.random.seed(seed) + tlx.set_seed(seed) + + +def ensure_dir(path): + if not path: + return + directory = path + if os.path.splitext(path)[1]: + directory = os.path.dirname(path) + if directory: + os.makedirs(directory, exist_ok=True) + + +def build_checkpoint_path(save_dir, best_model_path, filename): + if best_model_path: + if os.path.splitext(best_model_path)[1]: + ensure_dir(best_model_path) + return best_model_path + ensure_dir(best_model_path) + return os.path.join(best_model_path, filename) + ensure_dir(save_dir) + return os.path.join(save_dir, filename) + + +def load_node_dataset(dataset, dataset_path): + dataset_name = canonical_dataset_name(dataset) + dataset_obj = Planetoid(os.path.join(dataset_path, "Planetoid"), dataset_name) + return dataset_name, dataset_obj, dataset_obj[0] + + +def empty_edge_index(): + return tlx.convert_to_tensor(np.zeros((2, 0), dtype=np.int64), dtype=tlx.int64) + + +def gather_edge_pairs(edge_pairs, indices): + if len(indices) == 0: + return empty_edge_index() + return tlx.gather( + edge_pairs, + tlx.convert_to_tensor(indices, dtype=tlx.int64), + axis=1, + ) + + +def unique_undirected_edge_pairs(edge_index): + edge_np = np.asarray(tlx.convert_to_numpy(edge_index), dtype=np.int64) + if edge_np.size == 0: + return empty_edge_index() + + row = edge_np[0] + col = edge_np[1] + mask = row != col + row = row[mask] + col = col[mask] + if row.size == 0: + return empty_edge_index() + + pairs = np.stack([np.minimum(row, col), 
np.maximum(row, col)], axis=1) + pairs = np.unique(pairs, axis=0) + return tlx.convert_to_tensor(pairs.T, dtype=tlx.int64) + + +def to_bidirectional_edge_index(edge_pairs): + if int(edge_pairs.shape[1]) == 0: + return empty_edge_index() + reverse_pairs = tlx.stack([edge_pairs[1], edge_pairs[0]], axis=0) + return tlx.concat([edge_pairs, reverse_pairs], axis=1) + + +def build_graph_from_edge_pairs(graph, edge_pairs): + return Graph( + x=graph.x, + edge_index=to_bidirectional_edge_index(edge_pairs), + y=graph.y, + num_nodes=graph.num_nodes, + ) + + +def split_edge_pairs(edge_pairs, val_ratio, test_ratio, seed): + num_edges = int(edge_pairs.shape[1]) + if num_edges == 0: + raise ValueError("No edges are available for pretraining.") + + num_val = max(1, int(num_edges * val_ratio)) + num_test = max(1, int(num_edges * test_ratio)) + if num_val + num_test >= num_edges: + raise ValueError("val_ratio + test_ratio leaves no training edges.") + + rng = np.random.default_rng(seed) + perm = rng.permutation(num_edges) + val_idx = perm[:num_val] + test_idx = perm[num_val:num_val + num_test] + train_idx = perm[num_val + num_test:] + return ( + gather_edge_pairs(edge_pairs, train_idx), + gather_edge_pairs(edge_pairs, val_idx), + gather_edge_pairs(edge_pairs, test_idx), + ) + + +def sample_negative_edge_pairs(num_nodes, positive_edge_pairs, num_samples, seed): + if num_samples <= 0: + return empty_edge_index() + + positive_np = tlx.convert_to_numpy(positive_edge_pairs).T + positive_set = {(int(u), int(v)) for u, v in positive_np} + negative_pairs = [] + negative_set = set() + rng = np.random.default_rng(seed) + max_trials = max(num_samples * 50, 10000) + trials = 0 + while len(negative_pairs) < num_samples and trials < max_trials: + u = int(rng.integers(0, num_nodes)) + v = int(rng.integers(0, num_nodes)) + trials += 1 + if u == v: + continue + pair = (u, v) if u < v else (v, u) + if pair in positive_set or pair in negative_set: + continue + negative_set.add(pair) + 
negative_pairs.append(pair) + + if len(negative_pairs) < num_samples: + raise ValueError("Unable to sample enough negative node pairs.") + + return tlx.convert_to_tensor(np.asarray(negative_pairs, dtype=np.int64).T, dtype=tlx.int64) + + +def prepare_edge_prediction_splits(graph, val_ratio, test_ratio, neg_ratio, seed): + full_edge_pairs = unique_undirected_edge_pairs(graph.edge_index) + train_edge_pairs, val_edge_pairs, test_edge_pairs = split_edge_pairs( + full_edge_pairs, + val_ratio=val_ratio, + test_ratio=test_ratio, + seed=seed, + ) + + train_message_graph = build_graph_from_edge_pairs(graph, train_edge_pairs) + num_val_neg = max(1, int(val_edge_pairs.shape[1] * neg_ratio)) + num_test_neg = max(1, int(test_edge_pairs.shape[1] * neg_ratio)) + val_neg_pairs = sample_negative_edge_pairs(graph.num_nodes, full_edge_pairs, num_val_neg, seed + 1) + test_neg_pairs = sample_negative_edge_pairs(graph.num_nodes, full_edge_pairs, num_test_neg, seed + 2) + + return { + "full_edge_pairs": full_edge_pairs, + "train_edge_pairs": train_edge_pairs, + "val": { + "message_graph": train_message_graph, + "pos_edge_index": val_edge_pairs, + "neg_edge_index": val_neg_pairs, + }, + "test": { + "message_graph": train_message_graph, + "pos_edge_index": test_edge_pairs, + "neg_edge_index": test_neg_pairs, + }, + } + + +def sample_masked_edge_prediction_task(graph, train_edge_pairs, full_edge_pairs, mask_ratio, neg_ratio, seed): + num_edges = int(train_edge_pairs.shape[1]) + if num_edges == 0: + raise ValueError("No training edges are available after the pretraining split.") + + rng = np.random.default_rng(seed) + perm = rng.permutation(num_edges) + num_masked = max(1, int(num_edges * mask_ratio)) + pos_idx = perm[:num_masked] + keep_idx = perm[num_masked:] + + pos_edge_pairs = gather_edge_pairs(train_edge_pairs, pos_idx) + message_edge_pairs = gather_edge_pairs(train_edge_pairs, keep_idx) + num_neg_edges = max(1, int(pos_edge_pairs.shape[1] * neg_ratio)) + neg_edge_pairs = 
sample_negative_edge_pairs(graph.num_nodes, full_edge_pairs, num_neg_edges, seed) + + return { + "message_graph": build_graph_from_edge_pairs(graph, message_edge_pairs), + "pos_edge_index": pos_edge_pairs, + "neg_edge_index": neg_edge_pairs, + } + + +def edge_scores(node_emb, edge_index): + src_x = tlx.gather(node_emb, edge_index[0]) + dst_x = tlx.gather(node_emb, edge_index[1]) + return tlx.reduce_sum(src_x * dst_x, axis=-1) + + +def link_prediction_loss(pos_logits, neg_logits): + pos_loss = tlx.losses.sigmoid_cross_entropy(pos_logits, tlx.ones_like(pos_logits)) + neg_loss = tlx.losses.sigmoid_cross_entropy(neg_logits, tlx.zeros_like(neg_logits)) + return pos_loss + neg_loss + + +class EdgePredLoss(WithLoss): + def __init__(self, net): + super().__init__(backbone=net, loss_fn=None) + + def forward(self, data, label): + del label + node_emb = self.backbone_network(data["message_graph"]) + pos_logits = edge_scores(node_emb, data["pos_edge_index"]) + neg_logits = edge_scores(node_emb, data["neg_edge_index"]) + return link_prediction_loss(pos_logits, neg_logits) + + +def evaluate(model, data): + node_emb = model(data["message_graph"]) + pos_logits = edge_scores(node_emb, data["pos_edge_index"]) + neg_logits = edge_scores(node_emb, data["neg_edge_index"]) + loss = link_prediction_loss(pos_logits, neg_logits) + + pos_scores = tlx.convert_to_numpy(tlx.sigmoid(pos_logits)) + neg_scores = tlx.convert_to_numpy(tlx.sigmoid(neg_logits)) + scores = np.concatenate([pos_scores, neg_scores], axis=0) + labels = np.concatenate([ + np.ones(pos_scores.shape[0], dtype=np.int64), + np.zeros(neg_scores.shape[0], dtype=np.int64), + ], axis=0) + return float(loss.item()), roc_auc_score(labels, scores), average_precision_score(labels, scores) + + +def save_metadata(checkpoint_path, args, dataset_name): + metadata_path = os.path.splitext(checkpoint_path)[0] + ".json" + metadata = { + "dataset": dataset_name, + "hidden_dim": args.hidden_dim, + "num_layers": args.num_layers, + "drop_rate": 
args.drop_rate, + "mask_ratio": args.mask_ratio, + "neg_ratio": args.neg_ratio, + "val_ratio": args.val_ratio, + "test_ratio": args.test_ratio, + "lr": args.lr, + "weight_decay": args.weight_decay, + "epochs": args.epochs, + "checkpoint_contains": "EdgePromptGCNModel backbone weights only", + "note": "Minimal EP-GPPT-style engineering pretraining, not an official standalone script.", + } + with open(metadata_path, "w", encoding="utf-8") as file: + json.dump(metadata, file, indent=2) + + +def main(args): + set_random_seed(args.seed) + dataset_name, dataset, graph = load_node_dataset(args.dataset, args.dataset_path) + splits = prepare_edge_prediction_splits( + graph, + val_ratio=args.val_ratio, + test_ratio=args.test_ratio, + neg_ratio=args.neg_ratio, + seed=args.seed, + ) + + print( + "Edge pretrain split on {}: train_message_edges={}, val_pos={}, test_pos={}".format( + dataset_name, + int(splits["train_edge_pairs"].shape[1]), + int(splits["val"]["pos_edge_index"].shape[1]), + int(splits["test"]["pos_edge_index"].shape[1]), + ) + ) + + backbone = EdgePromptGCNModel( + feature_dim=dataset.num_node_features, + hidden_dim=args.hidden_dim, + num_layers=args.num_layers, + drop_rate=args.drop_rate, + name="EdgePromptGCN", + ) + optimizer = tlx.optimizers.Adam(lr=args.lr, weight_decay=args.weight_decay) + loss_func = EdgePredLoss(backbone) + train_one_step = TrainOneStep(loss_func, optimizer, backbone.trainable_weights) + + checkpoint_path = build_checkpoint_path( + args.save_dir, + args.best_model_path, + "{}_ep_gppt_backbone.npz".format(dataset_name.lower()), + ) + + best_val_ap = -1.0 + for epoch in range(args.epochs): + backbone.set_train() + train_task = sample_masked_edge_prediction_task( + graph, + splits["train_edge_pairs"], + splits["full_edge_pairs"], + mask_ratio=args.mask_ratio, + neg_ratio=args.neg_ratio, + seed=args.seed + epoch, + ) + train_loss = train_one_step(train_task, tlx.convert_to_tensor([1])) + + backbone.set_eval() + val_loss, val_auc, val_ap = 
evaluate(backbone, splits["val"]) + print( + "Epoch [{:0>3d}] train loss: {:.4f} val loss: {:.4f} val auc: {:.4f} val ap: {:.4f}".format( + epoch + 1, + float(train_loss.item()), + val_loss, + val_auc, + val_ap, + ) + ) + + if val_ap > best_val_ap: + best_val_ap = val_ap + backbone.save_weights(checkpoint_path, format="npz_dict") + + backbone.load_weights(checkpoint_path, format="npz_dict") + backbone.set_eval() + test_loss, test_auc, test_ap = evaluate(backbone, splits["test"]) + save_metadata(checkpoint_path, args, dataset_name) + print("Best checkpoint: {}".format(checkpoint_path)) + print("Test loss: {:.4f}".format(test_loss)) + print("Test auc: {:.4f}".format(test_auc)) + print("Test ap: {:.4f}".format(test_ap)) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--dataset", type=str, default="Cora", help="dataset") + parser.add_argument("--dataset_path", type=str, default=r"", help="path to save dataset") + parser.add_argument("--hidden_dim", type=int, default=128, help="hidden dimension") + parser.add_argument("--num_layers", type=int, default=2, help="number of GCN layers") + parser.add_argument("--drop_rate", type=float, default=0.5, help="dropout rate") + parser.add_argument("--mask_ratio", type=float, default=0.1, help="masked edge ratio") + parser.add_argument("--neg_ratio", type=float, default=1.0, help="negative sampling ratio") + parser.add_argument("--val_ratio", type=float, default=0.05, help="validation edge ratio") + parser.add_argument("--test_ratio", type=float, default=0.1, help="test edge ratio") + parser.add_argument("--lr", type=float, default=0.001, help="learning rate") + parser.add_argument("--weight_decay", type=float, default=0.0, help="weight decay") + parser.add_argument("--epochs", type=int, default=200, help="training epochs") + parser.add_argument("--save_dir", type=str, default=r"./", help="path to save checkpoints") + parser.add_argument("--best_model_path", type=str, default="", 
help="optional checkpoint file path or checkpoint directory") + parser.add_argument("--seed", type=int, default=0, help="random seed") + parser.add_argument("--gpu", type=int, default=0) + + args = parser.parse_args() + if args.gpu >= 0: + tlx.set_device("GPU", args.gpu) + else: + tlx.set_device("CPU") + + main(args) diff --git a/gammagl/models/__init__.py b/gammagl/models/__init__.py index 2f5f1c196..05ef10f73 100644 --- a/gammagl/models/__init__.py +++ b/gammagl/models/__init__.py @@ -66,6 +66,7 @@ from .gcil import GCILModel, LogReg from .sgformer import SGFormerModel from .adagad import PreModel, ReModel +from .edgeprompt import EdgePromptGCNModel, EdgePromptNodeClassifier __all__ = [ 'HeCo', @@ -111,7 +112,7 @@ 'HPN', 'GMMModel', 'HERec', - 'MetaPath2Vec' + 'MetaPath2Vec', 'ieHGCNModel', 'TADWModel', 'MGNNI_m_MLP', @@ -140,7 +141,9 @@ 'LogReg', 'sgformer', 'PreModel', - 'ReModel' + 'ReModel', + 'EdgePromptGCNModel', + 'EdgePromptNodeClassifier' ] classes = __all__ diff --git a/gammagl/models/edgeprompt.py b/gammagl/models/edgeprompt.py new file mode 100644 index 000000000..256a63a74 --- /dev/null +++ b/gammagl/models/edgeprompt.py @@ -0,0 +1,316 @@ +from typing import List, Optional + +import tensorlayerx as tlx +import tensorlayerx.nn as nn + +from gammagl.layers.pool import global_mean_pool +from gammagl.utils import add_self_loops, degree + + +def normalize_prompt_type(prompt_type: Optional[str]) -> Optional[str]: + r"""Normalize the prompt type string to the internal canonical form.""" + if prompt_type is None: + return None + + prompt_key = str(prompt_type).lower() + if prompt_key == "edgeprompt": + return "EdgePrompt" + if prompt_key in ("edgepromptplus", "edgeprompt_plus", "edgeprompt+"): + return "EdgePromptplus" + if prompt_key in ("none", "no_prompt"): + return None + + raise ValueError("Unsupported prompt type: {}".format(prompt_type)) + + +class EdgePromptGCNConv(tlx.nn.Module): + r"""A GCN-style convolution layer with optional edge prompt 
injection. + + Parameters + ---------- + in_channels: int + Dimension of the input node features. + out_channels: int + Dimension of the output node features. + name: str, optional + Name of the module. + """ + def __init__( + self, + in_channels: int, + out_channels: int, + name: Optional[str] = None, + ): + super().__init__(name=name) + + self.in_channels = in_channels + self.out_channels = out_channels + self.linear = nn.Linear( + in_features=in_channels, + out_features=out_channels, + W_init="xavier_uniform", + b_init=None, + ) + self.bias = self._get_weights( + "bias", + shape=(1, out_channels), + init=tlx.initializers.zeros(), + ) + + def forward(self, x, edge_index, edge_prompt=None): + # Add self-loops first so every node can aggregate its own feature. + num_nodes = int(tlx.get_tensor_shape(x)[0]) + edge_index, _ = add_self_loops(edge_index, num_nodes=num_nodes) + + row, col = edge_index[0], edge_index[1] + # Standard symmetric normalization used in GCN. + deg = degree(col, num_nodes=num_nodes, dtype=tlx.float32) + deg_inv_sqrt = tlx.pow(deg, -0.5) + deg_inv_sqrt = tlx.where( + tlx.is_inf(deg_inv_sqrt), + tlx.zeros_like(deg_inv_sqrt), + deg_inv_sqrt, + ) + norm = tlx.gather(deg_inv_sqrt, row) * tlx.gather(deg_inv_sqrt, col) + + src_x = tlx.gather(x, row) + if edge_prompt is not None: + # Inject prompt information into source node features on each edge. + src_x = src_x + edge_prompt + + messages = self.linear(src_x) + messages = messages * tlx.expand_dims(norm, axis=-1) + out = tlx.unsorted_segment_sum(messages, col, num_nodes) + + return out + self.bias + + +class EdgePrompt(tlx.nn.Module): + r"""The basic EdgePrompt module with one learnable prompt per layer. + + Parameters + ---------- + dim_list: list + Feature dimensions used by each prompt layer. + name: str, optional + Name of the module. 
+ """ + def __init__(self, dim_list: List[int], name: Optional[str] = None): + super().__init__(name=name) + + self.dim_list = dim_list + self.global_prompt = [] + for layer, dim in enumerate(dim_list): + prompt = self._get_weights( + "global_prompt_{}".format(layer), + shape=(1, dim), + init=tlx.initializers.xavier_uniform(), + ) + self.global_prompt.append(prompt) + + def get_prompt(self, x, edge_index, layer): + # The vanilla EdgePrompt ignores graph structure and returns a global prompt. + del x, edge_index + return self.global_prompt[layer] + + +class EdgePromptPlus(tlx.nn.Module): + r"""The EdgePrompt+ module that generates edge-aware prompts dynamically. + + Parameters + ---------- + dim_list: list + Feature dimensions used by each prompt layer. + num_anchors: int + Number of anchor prompts maintained for each layer. + name: str, optional + Name of the module. + """ + def __init__( + self, + dim_list: List[int], + num_anchors: int, + name: Optional[str] = None, + ): + super().__init__(name=name) + + self.dim_list = dim_list + self.num_anchors = num_anchors + self.anchor_prompt = [] + self.projectors = nn.ModuleList() + + for layer, dim in enumerate(dim_list): + anchor = self._get_weights( + "anchor_prompt_{}".format(layer), + shape=(num_anchors, dim), + init=tlx.initializers.xavier_uniform(), + ) + self.anchor_prompt.append(anchor) + self.projectors.append( + nn.Linear( + in_features=2 * dim, + out_features=num_anchors, + W_init="xavier_uniform", + ) + ) + + def get_prompt(self, x, edge_index, layer): + # Build an edge representation from source and destination node features, + # then use it to mix anchor prompts adaptively. 
+ edge_index, _ = add_self_loops( + edge_index, + num_nodes=int(tlx.get_tensor_shape(x)[0]), + ) + src_x = tlx.gather(x, edge_index[0]) + dst_x = tlx.gather(x, edge_index[1]) + edge_feat = tlx.concat([src_x, dst_x], axis=-1) + coeff = self.projectors[layer](edge_feat) + coeff = tlx.leaky_relu(coeff, negative_slope=0.2) + coeff = tlx.softmax(coeff, axis=-1) + return tlx.matmul(coeff, self.anchor_prompt[layer]) + + +class EdgePromptGCNModel(tlx.nn.Module): + r"""A stacked GCN backbone for node or graph representations with EdgePrompt. + + Parameters + ---------- + feature_dim: int + Dimension of the input node features. + hidden_dim: int + Dimension of hidden representations. + num_layers: int, optional + Number of GCN layers. + drop_rate: float, optional + Dropout rate applied between hidden layers. + name: str, optional + Name of the module. + """ + def __init__( + self, + feature_dim: int, + hidden_dim: int, + num_layers: int = 2, + drop_rate: float = 0.5, + name: Optional[str] = None, + ): + super().__init__(name=name) + + if num_layers < 1: + raise ValueError("num_layers must be at least 1.") + + self.feature_dim = feature_dim + self.hidden_dim = hidden_dim + self.num_layers = num_layers + self.prompt_dims = [feature_dim] + [hidden_dim] * (num_layers - 1) + + self.convs = nn.ModuleList() + in_dims = self.prompt_dims + out_dims = [hidden_dim] * num_layers + for layer, (in_dim, out_dim) in enumerate(zip(in_dims, out_dims)): + self.convs.append( + EdgePromptGCNConv( + in_dim, + out_dim, + name="edgeprompt_conv_{}".format(layer), + ) + ) + + self.relu = tlx.ReLU() + self.dropout = nn.Dropout(p=drop_rate) + + def forward(self, graph, prompt_type=None, prompt=None, pooling=None): + # When prompt is enabled, each layer obtains its own edge prompt tensor. 
+ x, edge_index = graph.x, graph.edge_index + prompt_type = normalize_prompt_type(prompt_type) + + for layer, conv in enumerate(self.convs): + edge_prompt = None + if prompt is not None and prompt_type in ("EdgePrompt", "EdgePromptplus"): + edge_prompt = prompt.get_prompt(x, edge_index, layer) + + x = conv(x, edge_index, edge_prompt=edge_prompt) + if layer != self.num_layers - 1: + x = self.relu(x) + x = self.dropout(x) + + if pooling == "mean": + # Graph-level mean pooling for batched graphs. + batch = getattr(graph, "batch", None) + if batch is None: + raise ValueError("Mean pooling requires batched graphs with `batch`.") + return global_mean_pool(x, batch) + + if pooling == "target": + # Gather the designated target node from each sampled subgraph. + if not hasattr(graph, "ptr") or not hasattr(graph, "target_node"): + raise ValueError( + "Target pooling requires batched subgraphs with `ptr` and `target_node`." + ) + target_index = graph.ptr[:-1] + tlx.reshape(graph.target_node, (-1,)) + return tlx.gather(x, target_index) + + return x + + +class EdgePromptNodeClassifier(tlx.nn.Module): + r"""A node classification wrapper built on top of the EdgePrompt GCN backbone. + + Parameters + ---------- + backbone: EdgePromptGCNModel + The GCN backbone used to produce node embeddings. + num_classes: int + Number of prediction classes. + prompt_type: str or None + Prompt strategy, supporting ``EdgePrompt``, ``EdgePromptplus`` or ``None``. + num_prompts: int, optional + Number of anchor prompts used by ``EdgePromptplus``. + name: str, optional + Name of the module. 
+ """ + def __init__( + self, + backbone: EdgePromptGCNModel, + num_classes: int, + prompt_type: Optional[str], + num_prompts: int = 10, + name: Optional[str] = None, + ): + super().__init__(name=name) + + self.backbone = backbone + self.prompt_type = normalize_prompt_type(prompt_type) + + if self.prompt_type == "EdgePrompt": + self.prompt = EdgePrompt(backbone.prompt_dims, name="EdgePrompt") + elif self.prompt_type == "EdgePromptplus": + self.prompt = EdgePromptPlus( + backbone.prompt_dims, + num_anchors=num_prompts, + name="EdgePromptPlus", + ) + else: + self.prompt = None + + self.classifier = nn.Linear( + in_features=backbone.hidden_dim, + out_features=num_classes, + W_init="xavier_uniform", + ) + + def forward(self, graph): + # The classifier head operates on the prompt-enhanced node embeddings. + node_emb = self.backbone( + graph, + prompt_type=self.prompt_type, + prompt=self.prompt, + ) + return self.classifier(node_emb) + + def tuning_weights(self): + # During prompt tuning, only optimize the classifier and prompt parameters. 
+ weights = list(self.classifier.trainable_weights) + if self.prompt is not None: + weights += list(self.prompt.trainable_weights) + return weights diff --git a/gammagl/utils/__init__.py b/gammagl/utils/__init__.py index a1ffb8bf2..09230165f 100644 --- a/gammagl/utils/__init__.py +++ b/gammagl/utils/__init__.py @@ -10,7 +10,7 @@ from .inspector import Inspector from .device import set_device from .to_dense_batch import to_dense_batch -from .subgraph import k_hop_subgraph +from .subgraph import k_hop_subgraph, node_subgraph from .negative_sampling import negative_sampling from .convert import to_scipy_sparse_matrix, edge_index_to_adj_matrix from .read_embeddings import read_embeddings @@ -18,7 +18,7 @@ from .to_dense_adj import to_dense_adj from .smiles import from_smiles from .shortest_path import shortest_path_distance, batched_shortest_path_distance -from .get_split import get_train_val_test_split +from .get_split import get_train_val_test_split, get_few_shot_split from .get_laplacian import get_laplacian from .simple_path import find_all_simple_paths @@ -38,6 +38,7 @@ 'set_device', 'to_dense_batch', 'k_hop_subgraph', + 'node_subgraph', 'negative_sampling', 'to_scipy_sparse_matrix', 'read_embeddings', @@ -47,6 +48,7 @@ 'shortest_path_distance', 'batched_shortest_path_distance', 'get_train_val_test_split', + 'get_few_shot_split', 'get_laplacian', 'find_all_simple_paths', 'edge_index_to_adj_matrix' diff --git a/gammagl/utils/get_split.py b/gammagl/utils/get_split.py index fc90494dc..746f6b63b 100644 --- a/gammagl/utils/get_split.py +++ b/gammagl/utils/get_split.py @@ -54,4 +54,39 @@ def generate_masks(num_nodes, train_indices, val_indices, test_indices): val_mask = tlx.ops.convert_to_tensor(np_val_mask, dtype=tlx.bool) test_mask = tlx.ops.convert_to_tensor(np_test_mask, dtype=tlx.bool) - return train_mask, val_mask, test_mask \ No newline at end of file + return train_mask, val_mask, test_mask + + +def get_few_shot_split(labels, num_shots, test_ratio=0.2, 
random_state=0): + """Sample a minimal few-shot train/test split for node classification. + + This follows the original EdgePrompt node downstream protocol closely: + sample up to ``num_shots`` nodes per class for training, remove those nodes + from the candidate pool, then draw a random test subset from the remaining + nodes. + """ + if test_ratio <= 0 or test_ratio > 1: + raise ValueError('test_ratio must be in (0, 1].') + + labels = tlx.reshape(labels, (-1,)) + labels_np = tlx.convert_to_numpy(labels) + rng = np.random.RandomState(random_state) + + train_indices = [] + for cls in np.unique(labels_np): + cls_indices = np.where(labels_np == cls)[0] + if cls_indices.shape[0] <= num_shots: + train_indices.extend(cls_indices.tolist()) + else: + train_indices.extend(rng.choice(cls_indices, size=num_shots, replace=False).tolist()) + + train_set = set(train_indices) + remaining_indices = [idx for idx in rng.permutation(labels_np.shape[0]).tolist() if idx not in train_set] + num_test = max(1, int(test_ratio * labels_np.shape[0])) + num_test = min(num_test, len(remaining_indices)) + test_indices = remaining_indices[:num_test] + + return ( + tlx.convert_to_tensor(np.asarray(train_indices, dtype=np.int64), dtype=tlx.int64), + tlx.convert_to_tensor(np.asarray(test_indices, dtype=np.int64), dtype=tlx.int64), + ) diff --git a/gammagl/utils/subgraph.py b/gammagl/utils/subgraph.py index 29f10532e..15dedb516 100644 --- a/gammagl/utils/subgraph.py +++ b/gammagl/utils/subgraph.py @@ -1,7 +1,9 @@ import tensorlayerx as tlx import numpy as np +from gammagl.data import Graph from .num_nodes import maybe_num_nodes + def k_hop_subgraph(node_idx, num_hops, edge_index, relabel_nodes=False, num_nodes=None, reverse=False): r""" Computes the induced subgraph of :obj:`edge_index` around all nodes in @@ -79,4 +81,26 @@ def k_hop_subgraph(node_idx, num_hops, edge_index, relabel_nodes=False, num_node edge_index = tlx.reshape(edge_index, (2, -1)) else: edge_index = tlx.gather(node_idx, 
edge_index) - return subset, edge_index, inv, edge_mask \ No newline at end of file + return subset, edge_index, inv, edge_mask + + +def node_subgraph(graph, node_idx, num_hops=2): + """Return a node-centered k-hop subgraph as a ``Graph`` object.""" + subset, edge_index, mapping, _ = k_hop_subgraph( + node_idx=node_idx, + num_hops=num_hops, + edge_index=graph.edge_index, + relabel_nodes=True, + num_nodes=graph.num_nodes, + ) + if tlx.is_tensor(mapping): + mapping_np = tlx.convert_to_numpy(mapping) + else: + mapping_np = np.asarray(mapping) + target_node = int(np.asarray(mapping_np).reshape(-1)[0]) + return Graph( + x=tlx.gather(graph.x, subset), + edge_index=edge_index, + target_node=tlx.convert_to_tensor([target_node], dtype=tlx.int64), + num_nodes=int(tlx.get_tensor_shape(subset)[0]), + )