diff --git a/examples/gnrf/args.json b/examples/gnrf/args.json new file mode 100644 index 00000000..4f49266c --- /dev/null +++ b/examples/gnrf/args.json @@ -0,0 +1,218 @@ +{ + "wisconsin": { + "dataset": "wisconsin", + "device": "cuda:0", + "trial": 10, + "verbose": false, + "epoch": 500, + "lr": 0.0014577890711895912, + "weight_decay": 0.0008857926282823821, + "dropout": 0.27372220290104154, + "num_hid": 128, + "use_mlp_in": false, + "use_mlp_out": false, + "use_bn_in": true, + "use_bn_out": false, + "ode": "GNRF", + "t_start": 1e-05, + "t_end": 0.8483774748255035, + "tol_scale": 2.118818198337877, + "solver": "implicit_adams", + "adjoint": false, + "channel_curv": true, + "num_feat": 1703, + "num_class": 5, + "rewiring": null, + "damping": true, + "edgenet": true + }, + "cornell": { + "dataset": "cornell", + "device": "cuda:0", + "trial": 10, + "verbose": false, + "epoch": 800, + "lr": 0.002832941670402573, + "weight_decay": 4.315845302779911e-05, + "dropout": 0.20115854835043973, + "num_hid": 192, + "use_mlp_in": false, + "use_mlp_out": false, + "use_bn_in": false, + "use_bn_out": false, + "ode": "GNRF", + "t_start": 1e-05, + "t_end": 0.10918163696818216, + "tol_scale": 4.348962999711223, + "solver": "implicit_adams", + "adjoint": false, + "channel_curv": true, + "num_feat": 1703, + "num_class": 5, + "rewiring": null, + "damping": true, + "edgenet": true + }, + "texas": { + "dataset": "texas", + "device": "cuda:0", + "trial": 10, + "verbose": false, + "epoch": 500, + "lr": 0.009465689674847282, + "weight_decay": 3.318139009838993e-05, + "dropout": 0.02318890583746977, + "num_hid": 256, + "use_mlp_in": false, + "use_mlp_out": false, + "use_bn_in": false, + "use_bn_out": false, + "ode": "GNRF", + "t_start": 1e-05, + "t_end": 0.664120788452176, + "tol_scale": 2.2510097955315866, + "solver": "implicit_adams", + "adjoint": false, + "channel_curv": true, + "num_feat": 1703, + "num_class": 5, + "rewiring": null, + "damping": true, + "edgenet": true + }, + 
"roman-empire": { + "dataset": "Roman-empire", + "device": "cuda:0", + "trial": 10, + "verbose": false, + "epoch": 500, + "lr": 0.0017770294770524057, + "weight_decay": 1.206367283375403e-06, + "dropout": 0.4076238151758252, + "num_hid": 256, + "use_mlp_in": false, + "use_mlp_out": false, + "use_bn_in": true, + "use_bn_out": true, + "ode": "GNRF", + "t_start": 1e-05, + "t_end": 7.1808984693838624, + "tol_scale": 14766.304417175616, + "solver": "implicit_adams", + "adjoint": false, + "channel_curv": true, + "num_feat": 300, + "num_class": 18, + "rewiring": null, + "damping": true, + "edgenet": true + }, + "cora_full": { + "dataset": "Cora_Full", + "device": "cuda:0", + "trial": 10, + "verbose": false, + "epoch": 500, + "lr": 0.00027656068839244653, + "weight_decay": 9.76581378785642e-05, + "dropout": 0.6837586052894707, + "num_hid": 256, + "use_mlp_in": false, + "use_mlp_out": false, + "use_bn_in": true, + "use_bn_out": true, + "ode": "GNRF", + "t_start": 1e-05, + "t_end": 4.092580732743444, + "tol_scale": 5.831748380763541, + "solver": "implicit_adams", + "adjoint": false, + "channel_curv": true, + "num_feat": 8710, + "num_class": 70, + "rewiring": null, + "damping": true, + "edgenet": true + }, + "pubmed": { + "dataset": "pubmed", + "device": "cuda:0", + "trial": 10, + "verbose": false, + "epoch": 500, + "lr": 0.000842290192251508, + "weight_decay": 6.724802070706501e-06, + "dropout": 0.3068111160234254, + "num_hid": 64, + "use_mlp_in": false, + "use_mlp_out": false, + "use_bn_in": true, + "use_bn_out": true, + "ode": "GNRF", + "t_start": 1e-05, + "t_end": 1.562395740177495, + "tol_scale": 76.17766226880025, + "solver": "implicit_adams", + "adjoint": false, + "channel_curv": true, + "num_feat": 500, + "num_class": 3, + "rewiring": null, + "damping": true, + "edgenet": true + }, + "tolokers": { + "dataset": "Tolokers", + "device": "cuda:0", + "trial": 10, + "verbose": false, + "epoch": 500, + "lr": 0.009902825826309957, + "weight_decay": 0.0006860252495670429, + 
"dropout": 0.0920985414289007, + "num_hid": 64, + "use_mlp_in": false, + "use_mlp_out": false, + "use_bn_in": true, + "use_bn_out": true, + "ode": "GNRF", + "t_start": 1e-05, + "t_end": 3.02086710208908, + "tol_scale": 172.61973364435173, + "solver": "implicit_adams", + "adjoint": false, + "channel_curv": true, + "num_feat": 10, + "num_class": 2, + "rewiring": null, + "damping": true, + "edgenet": true + }, + "ogbn-arxiv": { + "dataset": "ogbn-arxiv", + "device": "cuda:2", + "trial": 10, + "verbose": false, + "epoch": 1500, + "lr": 0.0012487994084971797, + "weight_decay": 4.8251083637061214e-05, + "dropout":0.1127683161361892, + "num_hid": 64, + "use_mlp_in": false, + "use_mlp_out": false, + "use_bn_in": false, + "use_bn_out": true, + "ode": "GNRF", + "t_start": 1e-05, + "t_end": 3.206093833264813, + "tol_scale": 7.211796294690889, + "solver": "implicit_adams", + "adjoint": true, + "channel_curv": true, + "num_feat": 128, + "num_class": 40, + "rewiring": null, + "damping": true, + "edgenet": true + } +} \ No newline at end of file diff --git a/examples/gnrf/dataset.py b/examples/gnrf/dataset.py new file mode 100644 index 00000000..32a4b21e --- /dev/null +++ b/examples/gnrf/dataset.py @@ -0,0 +1,64 @@ +import os +import numpy as np +import tensorlayerx as tlx +from gammagl.datasets import Planetoid, WebKB +from gammagl.utils import to_undirected, remove_self_loops +from gammagl.datasets.custom_datasets import CustomDataset + +script_dir = os.path.dirname(os.path.abspath(__file__)) +DATASET_ROOT = script_dir + "/Datasets" + +class NodeDataset: + def __init__(self, dataset_name: str): + self.dataset_name = dataset_name.lower() + self.file_root = DATASET_ROOT + + if self.dataset_name in ['pubmed']: + dataset = Planetoid(root=self.file_root, name=self.dataset_name, force_reload=True) + data = dataset[0] + elif self.dataset_name in ['cornell', 'texas', 'wisconsin']: + dataset = WebKB(root=self.file_root, name=self.dataset_name, force_reload=True) + data = dataset[0] + 
class NodeDataset:
    """Load a node-classification graph and prepare it for the GNRF example.

    The constructor lower-cases ``dataset_name``, picks the matching loader
    (``Planetoid`` for pubmed, ``WebKB`` for cornell/texas/wisconsin,
    ``CustomDataset`` for the rest), converts features, edges and labels to
    backend tensors, removes self-loops, makes the edge list undirected and
    row-normalises the features.

    Attributes set on the instance:
        x, edge_index, y: backend tensors (float32 / int64 / int64).
        nfeat: ``x.shape[1]``.
        nclass: ``max(y) + 1``.
        nnode: ``x.shape[0]``.

    Raises:
        ValueError: if ``dataset_name`` is not a supported dataset.
    """

    def __init__(self, dataset_name: str):
        self.dataset_name = dataset_name.lower()
        self.file_root = DATASET_ROOT

        if self.dataset_name in ['pubmed']:
            dataset = Planetoid(root=self.file_root, name=self.dataset_name, force_reload=True)
            data = dataset[0]
        elif self.dataset_name in ['cornell', 'texas', 'wisconsin']:
            dataset = WebKB(root=self.file_root, name=self.dataset_name, force_reload=True)
            data = dataset[0]
        else:
            data = self._load_custom_dataset()

        self.x = tlx.convert_to_tensor(data.x, dtype=tlx.float32)
        self.edge_index = tlx.convert_to_tensor(data.edge_index, dtype=tlx.int64)
        self.y = tlx.convert_to_tensor(data.y, dtype=tlx.int64)

        # ogbn-arxiv labels may arrive with a trailing axis; flatten them so
        # every dataset exposes a 1-D label vector.
        if self.dataset_name == 'ogbn-arxiv' and len(self.y.shape) > 1:
            self.y = tlx.reshape(self.y, (-1,))

        self.nfeat = self.x.shape[1]
        self.nclass = int(tlx.reduce_max(self.y)) + 1
        self.nnode = self.x.shape[0]

        self.edge_index = remove_self_loops(self.edge_index)[0]
        self.edge_index = to_undirected(self.edge_index)

        self._normalize_features()

    def _load_custom_dataset(self):
        """Return the single graph of a ``CustomDataset``-backed dataset.

        Raises:
            ValueError: if ``self.dataset_name`` is not one of the four
                datasets handled by ``CustomDataset``.
        """
        if self.dataset_name not in ['roman-empire', 'tolokers', 'cora_full', 'ogbn-arxiv']:
            raise ValueError(f"Unsupported dataset: {self.dataset_name}")
        dataset = CustomDataset(root=self.file_root, name=self.dataset_name)
        return dataset[0]

    def _normalize_features(self):
        """Replace NaNs with zero and scale every feature row to sum to one."""
        self.x = tlx.where(tlx.is_nan(self.x), tlx.zeros_like(self.x), self.x)
        rowsum = tlx.reduce_sum(self.x, axis=1, keepdims=True)
        # Rows that sum to zero are divided by one instead, i.e. left as-is.
        rowsum = tlx.where(rowsum == 0., tlx.ones_like(rowsum), rowsum)
        self.x = self.x / rowsum

    @staticmethod
    def _split_indices(num_nodes, seed, p_train, p_val):
        """Partition ``range(num_nodes)`` into train/val/test index arrays.

        Seeds NumPy's global RNG with ``seed`` and slices one permutation into
        three contiguous, non-overlapping pieces that together cover every
        node exactly once.

        Raises:
            ValueError: if a fraction is negative or the two sum to more
                than one.
        """
        if p_train < 0 or p_val < 0 or p_train + p_val > 1 + 1e-9:
            raise ValueError(
                f"Invalid split fractions: p_train={p_train}, p_val={p_val}"
            )
        np.random.seed(seed)
        n_train = int(num_nodes * p_train)
        n_val = int(num_nodes * p_val)
        order = np.random.permutation(num_nodes)
        # Slices are half-open, so consecutive pieces must share their
        # boundary index. The previous ``+ 1`` offsets silently dropped one
        # node at each boundary and made the validation set one node short.
        train_idx = order[:n_train]
        val_idx = order[n_train:n_train + n_val]
        test_idx = order[n_train + n_val:]
        return train_idx, val_idx, test_idx

    def random_split(self, seed: int, p_train: float = 0.6, p_val: float = 0.2):
        """Return a reproducible random ``(train_idx, val_idx, test_idx)`` split.

        Args:
            seed: seed applied to both the backend RNG and NumPy's global RNG.
            p_train: fraction of nodes placed in the training set.
            p_val: fraction of nodes placed in the validation set; the
                remaining nodes form the test set.

        Returns:
            Three NumPy index arrays that partition ``range(self.nnode)``.

        Raises:
            ValueError: if the fractions are negative or sum to more than one.
        """
        tlx.set_seed(seed)
        return self._split_indices(self.nnode, seed, p_train, p_val)
https://github.com/GalenChen320/GNRF_new + +## Datasets and Performances +--- + +|Datasets|Cornell|Wisconsin|Texas|Roman-Empire|Tolokers|Cora_Full|Pubmed| +|---|---|---|---|---|---|---|---| +|Hom.level|0.1227|0.1778|0.0609|0.0000|0.6344|0.5670|0.8024| +|#Node|183|251|183|22,662|11,758|19,793|19,717| +|Paper|87.28(±3.12)|88.00(±2.00)|87.39(±4.13)|86.25(±0.46)|83.96(±0.39)|72.12(±0.50)|90.37(±0.69)| +|Ours|79.46(±5.57)|87.60(±2.33)|84.86(±6.64)|85.01(±1.04)|81.14(±0.98)|68.62(±0.59)|88.85(±0.39)| + +|Datasets|Ogbn-Arxiv| +|---|---| +|depth|3| +|num-hid|64| +|Paper|69.33| +|Ours|60.01| + +## Notes +--- + +- On the Cornell dataset, re-running the authors' source code in the environment described in the paper on an RTX 3090 gives mean: 79.46, std: 7.77. The GammaGL version matches this with mean: 79.46, std: 5.57. +- For the Ogbn-arxiv dataset (depth=3, num-hid=64), the paper reports no standard deviation. Re-running the authors' source code in the paper's environment on an RTX 3090 gives mean: 66.64, std: 0.62, while the GammaGL version obtains mean: 60.01, std: 1.91. +- When PaddlePaddle or MindSpore is used as the backend, the odeint module is implemented manually because these backends lack a mature, unified Neural ODE solver ecosystem; it has only been tested for correctness, and its performance is not guaranteed. +- All results in the tables above were obtained with the PyTorch backend only. + +## How To Run +--- + +Run the following command from this directory (`examples/gnrf`). + + +```bash +python train.py --dataset wisconsin +``` + +The script runs on GPU by default; to run on CPU, use the following command. Note that the device string for CPU is case-sensitive, and the accepted spelling depends on the backend. 
def _coerce_like(raw, default):
    """Convert a raw command-line string to the type of its JSON default.

    Values that are not strings (already typed, or ``None`` for a flag given
    without a value) are returned unchanged.

    Raises:
        ValueError: if ``raw`` cannot be parsed as the default's type.
    """
    if not isinstance(raw, str):
        return raw
    # bool is tested before int because bool is a subclass of int.
    if isinstance(default, bool):
        token = raw.strip().lower()
        if token in ("true", "1", "yes", "y", "on"):
            return True
        if token in ("false", "0", "no", "n", "off"):
            return False
        raise ValueError(f"Cannot interpret {raw!r} as a boolean")
    if isinstance(default, int):
        try:
            return int(raw)
        except ValueError:
            # Accept spellings such as "1e3" for integer-valued settings.
            return float(raw)
    if isinstance(default, float):
        return float(raw)
    if default is None and raw.strip().lower() in ("none", "null"):
        return None
    return raw


def load_best_args(args):
    """Fill ``args`` with the tuned hyper-parameters stored in ``args.json``.

    The preset is looked up by the lower-cased ``args.dataset``. Keys missing
    from ``args`` are copied from the preset; keys already present (supplied
    on the command line) win, but string values are converted to the type of
    the preset value so that, for example, ``--lr 0.01`` becomes a float and
    ``--verbose false`` becomes ``False`` instead of a truthy string.

    Args:
        args: namespace with at least a ``dataset`` attribute; mutated in
            place.

    Returns:
        The same ``args`` object.

    Raises:
        ValueError: if a command-line override cannot be parsed as the
            preset's type.
    """
    import json

    dataset = args.dataset.lower()
    # Resolved against the current working directory, as before.
    with open("args.json", "r", encoding="utf-8") as f:
        presets = json.load(f)
    # A dataset without a preset leaves ``args`` untouched.
    for key, default in presets.get(dataset, {}).items():
        if hasattr(args, key):
            setattr(args, key, _coerce_like(getattr(args, key), default))
        else:
            setattr(args, key, default)
    return args
def _to_float(value):
    """Return a backend scalar (tensor or plain number) as a Python number."""
    if hasattr(value, 'item'):
        return value.item()
    if hasattr(value, 'numpy'):
        return float(value.numpy())
    return float(value)


class SemiSpvzLoss(WithLoss):
    """Loss wrapper that scores only the training nodes of a full-graph pass."""

    def __init__(self, backbone, loss_fn):
        super(SemiSpvzLoss, self).__init__(backbone=backbone, loss_fn=loss_fn)

    def forward(self, data, y):
        """Run the backbone on the whole graph and return the training loss.

        ``data`` must provide the keys ``'x'``, ``'edge_index'`` and
        ``'train_idx'``.
        """
        logits = self.backbone_network(data['x'], data['edge_index'])
        train_logits = tlx.gather(logits, data['train_idx'])
        train_y = tlx.gather(y, data['train_idx'])
        loss = self._loss_fn(train_logits, train_y)
        return loss


def evaluate_loss(model, data, y, mask_idx, loss_fn):
    """Return ``loss_fn`` over the nodes in ``mask_idx`` as a Python float.

    Switches ``model`` to evaluation mode before the forward pass.
    """
    model.set_eval()
    logits = model(data['x'], data['edge_index'])
    target_logits = tlx.gather(logits, mask_idx)
    target_y = tlx.gather(y, mask_idx)
    return _to_float(loss_fn(target_logits, target_y))


def evaluate_acc(model, data, y, mask_idx):
    """Return the classification accuracy over the nodes in ``mask_idx``.

    Switches ``model`` to evaluation mode before the forward pass.
    """
    model.set_eval()
    logits = model(data['x'], data['edge_index'])
    target_logits = tlx.gather(logits, mask_idx)
    target_y = tlx.gather(y, mask_idx)
    pred_class = tlx.argmax(target_logits, axis=1)
    pred_np = tlx.convert_to_numpy(pred_class)
    true_np = tlx.convert_to_numpy(target_y)
    return (pred_np == true_np).mean()


def main(args):
    """Train and evaluate GNRF for ``args.trial`` random splits.

    For every trial a fresh 60/20/20 split and model are created; the test
    accuracy at the epoch with the lowest validation loss is recorded.

    Args:
        args: namespace carrying the hyper-parameters from ``args.json``
            plus any command-line overrides.

    Returns:
        Mean of the per-trial test accuracies (mean and std are also printed).
    """
    tlx.set_device(args.device)
    data = NodeDataset(args.dataset)
    # NodeDataset and load_best_args both lower-case the dataset name, so the
    # rewiring check does the same; previously "--dataset Cornell" silently
    # skipped rewiring.
    dataset_key = args.dataset.lower()
    if (dataset_key in ["cornell", "wisconsin", "texas"]) and args.rewiring is not None:
        # The context manager closes the archive's file handle once the
        # requested array has been read into memory.
        with np.load("Datasets/rewiring.npz") as rewired:
            edge_index = rewired[f"{args.rewiring}_{dataset_key}"]
        edge_index = tlx.convert_to_tensor(edge_index, dtype=tlx.int64)
        edge_index = to_undirected(edge_index)
        edge_index = remove_self_loops(edge_index)[0]
        edge_index = tlx.to_device(edge_index, args.device)
    else:
        edge_index = tlx.to_device(data.edge_index, args.device)
    x = tlx.to_device(data.x, args.device)
    y = tlx.to_device(data.y, args.device)
    # The loaded data is authoritative for the model's input/output sizes.
    args.num_feat = data.nfeat
    args.num_class = data.nclass

    loss_fn = tlx.losses.softmax_cross_entropy_with_logits
    results = []
    for trial in range(int(args.trial)):
        tlx.set_seed(trial)
        np.random.seed(trial)
        train_idx, val_idx, test_idx = data.random_split(seed=trial, p_train=0.6, p_val=0.2)
        train_idx, val_idx, test_idx = (
            tlx.to_device(tlx.convert_to_tensor(idx, dtype=tlx.int64), args.device)
            for idx in (train_idx, val_idx, test_idx)
        )

        data_dict = {'x': x, 'edge_index': edge_index, 'train_idx': train_idx}
        model = GNN(args)
        # Backends expose different names for moving a module to a device.
        if hasattr(model, 'to_device'):
            model.to_device(args.device)
        elif hasattr(model, 'to'):
            model.to(args.device)
        optimizer = Adam(lr=args.lr, weight_decay=args.weight_decay)
        train_weights = model.trainable_weights
        loss_func = SemiSpvzLoss(model, loss_fn)
        train_one_step = TrainOneStep(loss_func, optimizer, train_weights)

        best_val_loss = float('inf')
        final_test_acc = 0.0
        for epoch in range(int(args.epoch)):
            model.set_train()
            train_loss = _to_float(train_one_step(data_dict, y))
            val_loss = evaluate_loss(model, data_dict, y, val_idx, loss_fn)
            test_acc = evaluate_acc(model, data_dict, y, test_idx)

            if args.verbose:
                print(f"[Epoch {epoch:3d}] Train Loss: {train_loss:.4f}",
                      f"Valid Loss: {val_loss:.4f}, Test Acc: {test_acc:.4f}.")
            # Model selection: keep the test accuracy of the best-validation epoch.
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                final_test_acc = test_acc

            # NOTE(review): the original indentation was lost, so it is not
            # certain whether this ran per epoch or per trial; per epoch is
            # the conservative reading - confirm against the upstream file.
            gc.collect()

        if args.verbose:
            print(f"Best Test Acc: {final_test_acc:.4f}")
        results.append(final_test_acc)

    results_np = np.array(results)
    mean, std = np.mean(results_np), np.std(results_np)
    print(f"Mean: {mean:.4f}, Std: {std:.4f}", flush=True)
    return mean


if __name__ == "__main__":
    from itertools import zip_longest

    parser = argparse.ArgumentParser()
    # Declared explicitly so a missing dataset is reported by argparse rather
    # than surfacing later as an AttributeError.
    parser.add_argument("--dataset", required=True)
    args, unknown_args = parser.parse_known_args()
    # Remaining tokens are free-form "--key value" pairs that override the
    # presets in args.json; a trailing key without a value maps to None.
    for flag, value in zip_longest(unknown_args[0::2], unknown_args[1::2]):
        setattr(args, flag.lstrip('-'), value)
    args = load_best_args(args)
    main(args)
class CustomDataset(InMemoryDataset):
    r"""Loader for four extra node-classification datasets.

    Supports ``'roman-empire'``, ``'tolokers'``, ``'cora_full'`` and
    ``'ogbn-arxiv'`` and stores each one as a single :class:`Graph` through
    GammaGL's :class:`InMemoryDataset` machinery.

    Parameters
    ----------
    root : str, optional
        Root directory where the dataset is stored. If ``None``, defaults to
        ``<directory of this file>/../data``.
    name : str
        Name of the dataset (case-insensitive); one of the keys of ``urls``.
    transform : callable, optional
        Forwarded to :class:`InMemoryDataset`.
    pre_transform : callable, optional
        Applied to the :class:`Graph` in :meth:`process` before it is saved.
    force_reload : bool, optional
        Forwarded to :class:`InMemoryDataset`. Defaults to ``False``.

    Raises
    ------
    ValueError
        If ``name`` is not a supported dataset.
    """

    urls = {
        'roman-empire': 'https://github.com/yandex-research/heterophilous-graphs/raw/main/data/roman_empire.npz',
        'tolokers': 'https://github.com/yandex-research/heterophilous-graphs/raw/main/data/tolokers.npz',
        'cora_full': 'https://github.com/abojchevski/graph2gauss/raw/master/data/cora.npz',
        'ogbn-arxiv': 'http://snap.stanford.edu/ogb/data/nodeproppred/arxiv.zip',
    }

    def __init__(self, root: Optional[str] = None, name: str = 'roman-empire',
                 transform: Optional[Callable] = None,
                 pre_transform: Optional[Callable] = None,
                 force_reload: bool = False):
        self.name = name.lower()
        # Raised explicitly instead of asserted: asserts vanish under ``-O``.
        if self.name not in self.urls:
            raise ValueError(
                f"Unsupported dataset '{self.name}'. Choose from {list(self.urls.keys())}"
            )

        if root is None:
            # NOTE(review): relative to gammagl/datasets this resolves to the
            # ``gammagl/data`` directory - confirm that is the intended cache.
            root = osp.join(osp.dirname(osp.abspath(__file__)), '..', 'data')
        # Set before super().__init__ so the raw_dir/processed_dir properties
        # below already work while the base class initialises.
        self.root = root

        super().__init__(root, transform, pre_transform, force_reload=force_reload)
        self.data, self.slices = self.load_data(self.processed_paths[0])

    @property
    def raw_dir(self) -> str:
        """Directory holding the downloaded files of this dataset."""
        return osp.join(self.root, self.name, 'raw')

    @property
    def processed_dir(self) -> str:
        """Directory holding the processed files of this dataset."""
        return osp.join(self.root, self.name, 'processed')

    @property
    def raw_file_names(self) -> List[str]:
        """Files whose presence marks the raw data as downloaded."""
        if self.name == 'ogbn-arxiv':
            # Marker file created by download() after the archive is unpacked.
            return ['arxiv_loaded']
        return [f'{self.name}.npz']

    @property
    def processed_file_names(self) -> str:
        """Processed file name; it is prefixed with the active backend."""
        return tlx.BACKEND + '_data.pt'

    def download(self):
        """Download the raw files into ``raw_dir``.

        For ogbn-arxiv the zip archive is extracted and deleted, then an empty
        ``arxiv_loaded`` marker file is created.
        """
        url = self.urls[self.name]
        if self.name == 'ogbn-arxiv':
            zip_path = download_url(url, self.raw_dir, filename='arxiv.zip')
            with zipfile.ZipFile(zip_path, 'r') as zf:
                zf.extractall(self.raw_dir)
            os.remove(zip_path)
            open(osp.join(self.raw_dir, 'arxiv_loaded'), 'a').close()
        else:
            download_url(url, self.raw_dir, filename=self.raw_file_names[0])

    def process(self):
        """Build the :class:`Graph` for this dataset and save it to disk."""
        if self.name in ['roman-empire', 'tolokers']:
            data = self._process_hetero_npz()
        elif self.name == 'cora_full':
            data = self._process_cora_full()
        elif self.name == 'ogbn-arxiv':
            data = self._process_ogbn_arxiv()
        else:
            raise NotImplementedError(f"Processing for {self.name} not implemented.")
        if self.pre_transform is not None:
            data = self.pre_transform(data)
        self.save_data(self.collate([data]), self.processed_paths[0])

    def _process_hetero_npz(self) -> Graph:
        """Build the graph for the roman-empire and tolokers ``.npz`` files.

        Raises:
            ValueError: if the stored ``edges`` array is neither ``(2, E)``
                nor ``(E, 2)``.
        """
        path = osp.join(self.raw_dir, self.raw_file_names[0])
        with np.load(path, allow_pickle=False) as z:
            x = np.asarray(z['node_features'], dtype=np.float32)
            y = np.asarray(z['node_labels'], dtype=np.int64).reshape(-1)
            edges = np.asarray(z['edges'], dtype=np.int64)
        # Accept either layout; a (2, 2) array is treated as already (2, E).
        if edges.ndim == 2 and edges.shape[0] == 2:
            edge_index = edges
        elif edges.ndim == 2 and edges.shape[1] == 2:
            edge_index = edges.T
        else:
            raise ValueError(f'Unexpected edges shape in {self.name}.npz: {edges.shape}')
        edge_index = coalesce(edge_index)
        return Graph(x=x, edge_index=edge_index, y=y)

    def _process_cora_full(self) -> Graph:
        """Build the graph for cora_full from its CSR-encoded ``.npz`` file.

        Node attributes are densified and binarised (every positive entry
        becomes ``1.0``).

        Raises:
            ImportError: if scipy is not installed.
        """
        try:
            import scipy.sparse as sp
        except ImportError as err:
            raise ImportError(
                "Loading 'cora_full' requires scipy. Install via: pip install scipy"
            ) from err
        path = osp.join(self.raw_dir, self.raw_file_names[0])
        with np.load(path, allow_pickle=False) as f:
            x = sp.csr_matrix(
                (f['attr_data'], f['attr_indices'], f['attr_indptr']),
                shape=tuple(f['attr_shape'])
            ).todense()
            x = np.asarray(x, dtype=np.float32)
            x[x > 0.0] = 1.0
            adj = sp.csr_matrix(
                (f['adj_data'], f['adj_indices'], f['adj_indptr']),
                shape=tuple(f['adj_shape'])
            ).tocoo()
            edge_index = np.stack([adj.row, adj.col], axis=0).astype(np.int64)
            y = np.asarray(f['labels'], dtype=np.int64).reshape(-1)
        edge_index = coalesce(edge_index)
        return Graph(x=x, edge_index=edge_index, y=y)

    def _process_ogbn_arxiv(self) -> Graph:
        """Build the graph for ogbn-arxiv from the extracted CSV files.

        The split files under ``split/time`` are stored on the graph as the
        boolean ``train_mask``, ``val_mask`` and ``test_mask``.
        """
        arxiv_raw_dir = osp.join(self.raw_dir, 'arxiv', 'raw')
        split_dir = osp.join(self.raw_dir, 'arxiv', 'split', 'time')

        def read_csv(path):
            return pd.read_csv(path, compression='gzip', header=None).values

        x = read_csv(osp.join(arxiv_raw_dir, 'node-feat.csv.gz')).astype(np.float32)
        edge_index = read_csv(osp.join(arxiv_raw_dir, 'edge.csv.gz')).T.astype(np.int64)
        y = read_csv(osp.join(arxiv_raw_dir, 'node-label.csv.gz')).reshape(-1).astype(np.int64)

        train_idx = read_csv(osp.join(split_dir, 'train.csv.gz')).reshape(-1)
        val_idx = read_csv(osp.join(split_dir, 'valid.csv.gz')).reshape(-1)
        test_idx = read_csv(osp.join(split_dir, 'test.csv.gz')).reshape(-1)

        num_nodes = x.shape[0]
        train_mask = np.zeros(num_nodes, dtype=bool)
        val_mask = np.zeros(num_nodes, dtype=bool)
        test_mask = np.zeros(num_nodes, dtype=bool)
        train_mask[train_idx] = True
        val_mask[val_idx] = True
        test_mask[test_idx] = True

        edge_index = coalesce(edge_index)
        graph = Graph(x=x, edge_index=edge_index, y=y)
        graph.train_mask = train_mask
        graph.val_mask = val_mask
        graph.test_mask = test_mask
        return graph

    @staticmethod
    def _mask_to_index(mask):
        """Return the positions of the true entries of ``mask`` as int64."""
        if not isinstance(mask, np.ndarray):
            mask = tlx.convert_to_numpy(mask)
        return np.flatnonzero(mask).astype(np.int64)

    def get_idx_split(self):
        """Return the official ogbn-arxiv split as sorted index arrays.

        The indices are rebuilt from the masks stored on the graph. The class
        never assigns ``train_idx``/``val_idx``/``test_idx`` attributes, so
        reading them (as this method used to) always failed.

        Returns:
            dict with keys ``'train'``, ``'valid'`` and ``'test'``.

        Raises:
            NotImplementedError: for every dataset other than ogbn-arxiv.
        """
        if self.name != 'ogbn-arxiv':
            raise NotImplementedError("Official split is only available for 'ogbn-arxiv'.")
        # NOTE(review): assumes the masks set in _process_ogbn_arxiv survive
        # collate/save/load - confirm against InMemoryDataset.
        graph = self[0]
        return {
            'train': self._mask_to_index(graph.train_mask),
            'valid': self._mask_to_index(graph.val_mask),
            'test': self._mask_to_index(graph.test_mask),
        }

    def __repr__(self) -> str:
        return f"{self.name.capitalize()}()"
+Subclasses need to implement the odeint method, which receives a function func, initial value y0, time steps t, method method, relative tolerance rtol, absolute tolerance atol, and options options. +The return value is a tensor containing the solution, with shape (t.shape[0], *y0.shape). +torch: method supports implicit_adams, dopri5, dopri8, rk4, bosh3, adaptive_heun, explicit_adams, etc. +tensorflow: Due to TensorFlow's graph mode and variable tracking mechanism, some network layers not involved in the computation graph may cause training or gradient issues. The current adapter does not implement adjoint mode. method supports dopri5, rk4, euler, midpoint, adaptive_heun, etc. +Since paddlepaddle and mindspore lack a mature and unified Neural ODE solving ecosystem, the odeint method is manually implemented in this module. The method parameter is currently retained only for interface compatibility. +paddlepaddle: RK4 solver with fixed step size integration, specified by t_paddle. The method/rtol/atol/adjoint parameters are not currently effective. +mindspore: Euler solver with fixed step size integration, specified by t_ms. The method/rtol/atol/adjoint parameters are not currently effective. 
+''' + +class ODEIntAdapter: + @staticmethod + def odeint(func, y0, t, method='dopri5', rtol=1e-3, atol=1e-6, options=None): + raise NotImplementedError + +class TorchODEIntAdapter(ODEIntAdapter): + @staticmethod + def odeint(func, y0, t, method='dopri5', rtol=1e-3, atol=1e-6, options=None, adjoint=False): + try: + if adjoint: + from torchdiffeq import odeint_adjoint as torch_odeint + else: + from torchdiffeq import odeint as torch_odeint + except ImportError: + raise ImportError("使用 PyTorch 后端需要安装 torchdiffeq: pip install torchdiffeq") + + import torch + + y0_torch = y0._tensor if hasattr(y0, '_tensor') else torch.as_tensor(y0) + t_torch = t._tensor if hasattr(t, '_tensor') else torch.as_tensor(t) + + class TorchFuncWrapper(torch.nn.Module): + def __init__(self, original_func): + super().__init__() + self.original_func = original_func + params = [] + if hasattr(original_func, 'trainable_weights'): + for w in original_func.trainable_weights: + wt = w._tensor if hasattr(w, '_tensor') else w + if not isinstance(wt, torch.nn.Parameter): + wt = torch.nn.Parameter(wt) + params.append(wt) + elif hasattr(original_func, 'parameters'): + for w in original_func.parameters(): + wt = w._tensor if hasattr(w, '_tensor') else w + if not isinstance(wt, torch.nn.Parameter): + wt = torch.nn.Parameter(wt) + params.append(wt) + self.params = torch.nn.ParameterList(params) + + def forward(self, t_val, y_val): + res = self.original_func(t_val, y_val) + if hasattr(res, '_tensor'): + return res._tensor + return res + + torch_func = TorchFuncWrapper(func) + + kwargs = {} + if adjoint: + kwargs['adjoint_params'] = tuple(torch_func.parameters()) + + res_torch = torch_odeint( + torch_func, + y0_torch, + t_torch, + method=method, + rtol=rtol, + atol=atol, + options=options, + **kwargs + ) + return tlx.convert_to_tensor(res_torch) + + +class TensorFlowODEIntAdapter(ODEIntAdapter): + @staticmethod + def odeint(func, y0, t, method='dopri5', rtol=1e-3, atol=1e-6, options=None, adjoint=False): + 
try: + from tfdiffeq import odeint as tf_odeint + except ImportError: + raise ImportError("使用 TensorFlow 后端需要安装 tfdiffeq: pip install tfdiffeq") + + import tensorflow as tf + + y0_tf = y0._tensor if hasattr(y0, '_tensor') else tf.convert_to_tensor(y0) + t_tf = t._tensor if hasattr(t, '_tensor') else tf.convert_to_tensor(t) + + def tf_func(t_val, y_val): + res = func(t_val, y_val) + if hasattr(res, '_tensor'): + return res._tensor + return res + + ode_options = options or {} + ode_options.update({'rtol': rtol, 'atol': atol}) + + res_tf = tf_odeint(tf_func, y0_tf, t_tf, method=method, **ode_options) + return tlx.convert_to_tensor(res_tf) + + +class PaddleODEIntAdapter(ODEIntAdapter): + @staticmethod + def odeint(func, y0, t, method='dopri5', rtol=1e-3, atol=1e-6, options=None, adjoint=False): + try: + import paddle + except ImportError: + raise ImportError("使用 PaddlePaddle 后端需要安装 paddlepaddle") + + y0_paddle = y0._tensor if hasattr(y0, '_tensor') else paddle.to_tensor(y0) + t_paddle = t._tensor if hasattr(t, '_tensor') else paddle.to_tensor(t) + + def rk4_step(func, t_val, y_val, dt): + k1 = func(t_val, y_val) + k2 = func(t_val + dt/2, y_val + dt*k1/2) + k3 = func(t_val + dt/2, y_val + dt*k2/2) + k4 = func(t_val + dt, y_val + dt*k3) + return y_val + dt * (k1 + 2*k2 + 2*k3 + k4) / 6 + + def paddle_func(t_val, y_val): + res = func(t_val, y_val) + if hasattr(res, '_tensor'): + return res._tensor + return res + + ys = [y0_paddle] + for i in range(1, len(t_paddle)): + dt = t_paddle[i] - t_paddle[i-1] + y_next = rk4_step(paddle_func, t_paddle[i-1], ys[-1], dt) + ys.append(y_next) + + res_paddle = paddle.stack(ys, axis=0) + return tlx.convert_to_tensor(res_paddle) + + +class MindSporeODEIntAdapter(ODEIntAdapter): + @staticmethod + def odeint(func, y0, t, method='dopri5', rtol=1e-3, atol=1e-6, options=None, adjoint=False): + try: + import mindspore as ms + import mindspore.ops as ops + except ImportError: + raise ImportError("使用 MindSpore 后端需要安装 mindspore") + + y0_ms = 
y0._tensor if hasattr(y0, '_tensor') else ms.Tensor(y0) + t_ms = t._tensor if hasattr(t, '_tensor') else ms.Tensor(t) + + ys = [y0_ms] + for i in range(1, len(t_ms)): + dt = t_ms[i] - t_ms[i-1] + dy = func(t_ms[i-1], ys[-1]) + if hasattr(dy, '_tensor'): + dy_ms = dy._tensor + else: + dy_ms = dy + y_next = ys[-1] + dt * dy_ms + ys.append(y_next) + + res_ms = ops.stack(ys, axis=0) + return tlx.convert_to_tensor(res_ms) + + +_ADAPTERS = { + 'torch': TorchODEIntAdapter, + 'pytorch': TorchODEIntAdapter, + 'tensorflow': TensorFlowODEIntAdapter, + 'tf': TensorFlowODEIntAdapter, + 'paddle': PaddleODEIntAdapter, + 'mindspore': MindSporeODEIntAdapter, + 'ms': MindSporeODEIntAdapter, +} + +def get_adapter(backend=None): + backend = backend or TL_BACKEND + adapter = _ADAPTERS.get(backend.lower()) + if adapter is None: + raise ValueError(f"不支持的后端: {backend}. 请从 {list(_ADAPTERS.keys())} 中选择。") + return adapter + + +def odeint(func, y0, t, method='dopri5', rtol=1e-3, atol=1e-6, options=None, backend=None, adjoint=False): + adapter = get_adapter(backend) + return adapter.odeint(func, y0, t, method=method, rtol=rtol, atol=atol, options=options, adjoint=adjoint) + + +class SimpleMLP(nn.Module): + def __init__(self, in_channels, hidden_channels, out_channels, num_layers, dropout=0.0): + super().__init__() + self.lins = nn.ModuleList() + self.dropout_layer = tlx.nn.Dropout(p=dropout) + in_dims = [in_channels] + [hidden_channels] * (num_layers - 1) + out_dims = [hidden_channels] * (num_layers - 1) + [out_channels] + for i, o in zip(in_dims, out_dims): + self.lins.append(tlx.nn.Linear(in_features=i, out_features=o)) + + def forward(self, x): + for i, lin in enumerate(self.lins): + x = self.dropout_layer(x) + x = lin(x) + if i < len(self.lins) - 1: + x = tlx.nn.ReLU()(x) + return x + +class GNRF(nn.Module): + def __init__(self, args): + super().__init__() + self.args = args + self.edge_index = None + self.damping = args.damping + self.edgenet = args.edgenet + if self.edgenet: + 
class SimpleMLP(nn.Module):
    """Stack of ``num_layers`` linear layers with dropout before each layer.

    ReLU is applied after every layer except the last.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, num_layers, dropout=0.0):
        super().__init__()
        self.lins = nn.ModuleList()
        self.dropout_layer = tlx.nn.Dropout(p=dropout)
        in_dims = [in_channels] + [hidden_channels] * (num_layers - 1)
        out_dims = [hidden_channels] * (num_layers - 1) + [out_channels]
        for i, o in zip(in_dims, out_dims):
            self.lins.append(tlx.nn.Linear(in_features=i, out_features=o))

    def forward(self, x):
        last = len(self.lins) - 1
        for i, lin in enumerate(self.lins):
            x = self.dropout_layer(x)
            x = lin(x)
            if i < last:
                x = tlx.nn.ReLU()(x)
        return x


class GNRF(nn.Module):
    """Right-hand side ``dH/dt`` of the Graph Neural Ricci Flow ODE.

    Reads from ``args``: ``damping``, ``edgenet``, ``channel_curv``,
    ``num_hid`` and ``dropout``. With ``edgenet`` the per-edge curvature is
    produced by two MLPs; otherwise it is a single learnable scalar.
    :meth:`set_edges` must be called before the module is evaluated.
    """

    def __init__(self, args):
        super().__init__()
        self.args = args
        self.edge_index = None
        self.damping = args.damping
        self.edgenet = args.edgenet
        if self.edgenet:
            self.mlp_1 = SimpleMLP(
                in_channels=2 * args.num_hid,
                hidden_channels=args.num_hid,
                out_channels=args.num_hid,
                num_layers=2,
                dropout=args.dropout
            )
            # Per-channel curvature has num_hid outputs, scalar curvature one.
            self.mlp_2 = SimpleMLP(
                in_channels=2 * args.num_hid,
                hidden_channels=args.num_hid,
                out_channels=args.num_hid if args.channel_curv else 1,
                num_layers=2,
                dropout=args.dropout
            )
        else:
            self.a = tlx.Variable(initial_value=tlx.convert_to_tensor(0.5, dtype=tlx.float32), trainable=True)

    def set_edges(self, edge_index):
        """Store the edge index used by subsequent evaluations."""
        self.edge_index = edge_index

    def curvature(self, H_i, H_j, num_nodes=None):
        """Return the learned curvature of every edge.

        Args:
            H_i: features gathered at ``edge_index[0]``.
            H_j: features gathered at ``edge_index[1]``.
            num_nodes: number of segments for the per-node aggregation. When
                omitted it is derived from the largest index in
                ``edge_index``, which costs a reduction per call.
        """
        if num_nodes is None:
            top = tlx.reduce_max(self.edge_index)
            num_nodes = int(top.item() if hasattr(top, 'item') else top) + 1
        curv = tlx.concat([H_i, H_j], axis=1)
        curv = tlx.nn.ReLU()(self.mlp_1(curv))
        # Sum edge messages per node, then pair both endpoints' summaries.
        curv = unsorted_segment_sum(curv, self.edge_index[0], num_segments=num_nodes)
        curv = tlx.concat([tlx.gather(curv, self.edge_index[0]), tlx.gather(curv, self.edge_index[1])], axis=1)
        curv = self.mlp_2(curv)
        return curv

    def forward(self, t, H):
        """Evaluate ``dH/dt`` at state ``H``; the time ``t`` is not used."""
        eps = 1e-8
        if self.damping:
            norm = tlx.sqrt(tlx.reduce_sum(tlx.square(H), axis=1, keepdims=True) + eps)
            H = H / norm
        H_i = tlx.gather(H, self.edge_index[0])
        H_j = tlx.gather(H, self.edge_index[1])
        if self.edgenet:
            # H.shape[0] is at least max(edge_index) + 1, so the values
            # gathered back per edge are unchanged, while the reduce_max that
            # previously ran twice per function evaluation is avoided.
            curv = self.curvature(H_i, H_j, num_nodes=H.shape[0])
        else:
            # The scalar broadcasts over the edge features below; no
            # device-bound tensor of ones is needed.
            curv = tlx.clip_by_value(self.a, eps, 1.0)
        if self.damping:
            cos = tlx.reduce_sum(H_i * H_j, axis=1, keepdims=True)
            H_edge = curv * (H_j - cos * H_i)
        else:
            H_edge = curv * (H_j - H_i)
        dH = unsorted_segment_mean(H_edge, self.edge_index[0], num_segments=H.shape[0])
        if self.damping:
            dH_t2 = dH._tensor if hasattr(dH, '_tensor') else dH
            norm_dH = tlx.sqrt(tlx.reduce_sum(tlx.square(dH_t2), axis=1, keepdims=True) + eps)
            dH = dH_t2 / norm_dH
        return dH


class GNN(nn.Module):
    """Encoder -> GNRF ODE block -> decoder node classifier.

    Reads from ``args``: ``num_feat``, ``num_hid``, ``num_class``,
    ``dropout``, ``t_start``, ``t_end``, ``solver``, ``adjoint``,
    ``tol_scale``, ``device`` and the optional switches ``use_bn_in``,
    ``use_mlp_in``, ``use_bn_out`` and ``use_mlp_out``.
    """

    def __init__(self, args):
        super().__init__()
        self.args = args
        if getattr(args, 'use_bn_in', False):
            self.bn_in = tlx.nn.BatchNorm1d(num_features=args.num_hid, momentum=0.9)
        if getattr(args, 'use_mlp_in', False):
            self.mlp_in = SimpleMLP(
                in_channels=args.num_hid,
                hidden_channels=args.num_hid,
                out_channels=args.num_hid,
                num_layers=2,
                dropout=args.dropout
            )
        self.lin_in = tlx.nn.Linear(in_features=args.num_feat, out_features=args.num_hid)

        if getattr(args, 'use_bn_out', False):
            self.bn_out = tlx.nn.BatchNorm1d(num_features=args.num_hid, momentum=0.9)
        if getattr(args, 'use_mlp_out', False):
            self.mlp_out = SimpleMLP(
                in_channels=args.num_hid,
                hidden_channels=args.num_hid,
                out_channels=args.num_hid,
                num_layers=2,
                dropout=args.dropout
            )
        self.lin_out = tlx.nn.Linear(in_features=args.num_hid, out_features=args.num_class)

        self.ODE_block = GNRF(args)
        # Integration interval [t_start, t_end]; only the end state is used.
        self.t = tlx.convert_to_tensor([args.t_start, args.t_end], dtype=tlx.float32)
        self.t = tlx.to_device(self.t, args.device)
        self.solver = args.solver
        self.adjoint = args.adjoint
        self.tol_scale = args.tol_scale

        self.dropout = tlx.nn.Dropout(p=args.dropout)

    def pre_transform(self, x, edge_index):
        """Encode raw features into the hidden space (``edge_index`` unused)."""
        x = self.dropout(x)
        x = self.lin_in(x)
        x = tlx.nn.ReLU()(x)
        if getattr(self.args, 'use_mlp_in', False):
            x = self.mlp_in(x)
            x = tlx.nn.ReLU()(x)
        if getattr(self.args, 'use_bn_in', False):
            x = self.bn_in(x)
        return x

    def solve_ODE(self, x_0, edge_index):
        """Integrate the GNRF flow from ``x_0`` and return the final state."""
        self.ODE_block.set_edges(edge_index)
        rtol = self.tol_scale * 1e-9
        atol = self.tol_scale * 1e-7

        trajectory = odeint(
            func=self.ODE_block,
            y0=x_0,
            t=self.t,
            method=self.solver,
            rtol=rtol,
            atol=atol,
            adjoint=self.adjoint
        )
        return trajectory[-1]

    def post_transform(self, x, edge_index):
        """Decode hidden states into class logits (``edge_index`` unused)."""
        x = tlx.nn.ReLU()(x)
        if getattr(self.args, 'use_bn_out', False):
            x = self.bn_out(x)
        if getattr(self.args, 'use_mlp_out', False):
            x = self.mlp_out(x)
            x = tlx.nn.ReLU()(x)
        x = self.dropout(x)
        x = self.lin_out(x)
        return x

    def forward(self, x, edge_index):
        """Return class logits for every node."""
        x = self.pre_transform(x, edge_index)
        x = self.solve_ODE(x, edge_index)
        x = self.post_transform(x, edge_index)
        return x