diff --git a/examples/coed/coed_trainer.py b/examples/coed/coed_trainer.py new file mode 100644 index 00000000..bdbfd723 --- /dev/null +++ b/examples/coed/coed_trainer.py @@ -0,0 +1,391 @@ +# !/usr/bin/env python +# -*- encoding: utf-8 -*- +""" +@File : coed_trainer.py +@Time : 2024/12/30 15:30:00 +@Author : GammaGL +""" + +import os +# os.environ['CUDA_VISIBLE_DEVICES'] = '0' +# os.environ['TL_BACKEND'] = 'torch' +os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' +# 0:Output all; 1:Filter out INFO; 2:Filter out INFO and WARNING; 3:Filter out INFO, WARNING, and ERROR + +import argparse +import numpy as np +import tensorlayerx as tlx +from gammagl.datasets import WebKB, WikipediaNetwork +from gammagl.models import CoEDModel +from gammagl.utils import mask_to_index +from tensorlayerx.model import TrainOneStep, WithLoss +import gammagl.transforms as T + +from geom_planetoid import load_planetoid_with_geom_splits + + +class SemiSpvzLoss(WithLoss): + r"""Loss wrapper for semi-supervised node classification.""" + + def __init__(self, net, loss_fn): + super(SemiSpvzLoss, self).__init__(backbone=net, loss_fn=loss_fn) + + def forward(self, data, y): + logits = self.backbone_network(data['x'], data['edge_index'], data['edge_weight'], data['num_nodes']) + train_logits = tlx.gather(logits, data['train_idx']) + train_y = tlx.gather(data['y'], data['train_idx']) + loss = self._loss_fn(train_logits, train_y) + return loss + + +def calculate_acc(logits, y, metrics): + r"""Compute accuracy via the TLX metrics API.""" + metrics.update(logits, y) + rst = metrics.result() + metrics.reset() + return rst + + +def get_edge_index_and_theta(edge_index): + r"""Build the fuzzy edge list and initial phase angles from an edge_index. + + Symmetric (undirected) edges are kept only once with theta = pi/4; + directed edges are kept as-is with theta = 0. + """ + src = tlx.convert_to_numpy(edge_index[0]).tolist() + dst = tlx.convert_to_numpy(edge_index[1]).tolist() + + edges = [(int(u), int(v)) for u, v in zip(src, dst) if u != v] + edge_set = set(edges) + + triu_symm_edges = [] + triu_dir_edges = [] + tril_dir_edges = [] + + for u, v in edges: + if u < v: + if (v, u) in edge_set: + triu_symm_edges.append((u, v)) + else: + triu_dir_edges.append((u, v)) + elif u > v and (v, u) not in edge_set: + tril_dir_edges.append((u, v)) + + triu_symm_edges = sorted(set(triu_symm_edges)) + triu_dir_edges = sorted(set(triu_dir_edges)) + tril_dir_edges = sorted(set(tril_dir_edges)) + + if triu_symm_edges: + if not triu_dir_edges and not tril_dir_edges: + processed_edges = triu_symm_edges + theta = [np.pi / 4.0] * len(triu_symm_edges) + else: + processed_edges = triu_dir_edges + tril_dir_edges + triu_symm_edges + theta = [0.0] * (len(triu_dir_edges) + len(tril_dir_edges)) + [np.pi / 4.0] * len(triu_symm_edges) + else: + processed_edges = triu_dir_edges + tril_dir_edges + theta = [0.0] * len(processed_edges) + + edge_index_fuzzy = tlx.convert_to_tensor(np.array(processed_edges, dtype=np.int64).T, dtype=tlx.int64) + theta = tlx.convert_to_tensor(np.array(theta, dtype=np.float32), dtype=tlx.float32) + return edge_index_fuzzy, theta + + +def get_fuzzy_laplacian(edge_index, theta, num_nodes, edge_weight=None, add_self_loop=False): + r"""Construct normalized directional edge weights for CoED message passing. + + This implements the fuzzy Laplacian normalization described in the paper. + For each edge (i, j) with phase angle theta_k, the directional weights are: + - src-to-dst: cos^2(theta_k) + - dst-to-src: sin^2(theta_k) + These are then symmetrically normalized by node degrees. + """ + from gammagl.mpops import unsorted_segment_sum + + senders = edge_index[0] + receivers = edge_index[1] + + if edge_weight is None: + edge_weight = tlx.ones((tlx.get_tensor_shape(senders)[0],), dtype=tlx.float32) + + theta = tlx.cast(theta, tlx.float32) + edge_weight = tlx.cast(edge_weight, tlx.float32) + cos_sq = tlx.cos(theta) ** 2 + sin_sq = tlx.sin(theta) ** 2 + + conv_senders = tlx.concat([senders, receivers], axis=0) + conv_receivers = tlx.concat([receivers, senders], axis=0) + out_weight = tlx.concat([cos_sq * edge_weight, sin_sq * edge_weight], axis=0) + in_weight = tlx.concat([sin_sq * edge_weight, cos_sq * edge_weight], axis=0) + + if add_self_loop: + self_loops = tlx.arange(start=0, limit=num_nodes, dtype=tlx.int64) + ones = tlx.ones((num_nodes,), dtype=tlx.float32) + conv_senders = tlx.concat([conv_senders, self_loops], axis=0) + conv_receivers = tlx.concat([conv_receivers, self_loops], axis=0) + out_weight = tlx.concat([out_weight, ones], axis=0) + in_weight = tlx.concat([in_weight, ones], axis=0) + + deg_senders = tlx.reshape( + unsorted_segment_sum(out_weight, conv_senders, num_segments=num_nodes), (-1,) + ) + 1e-12 + deg_receivers = tlx.reshape( + unsorted_segment_sum(in_weight, conv_senders, num_segments=num_nodes), (-1,) + ) + 1e-12 + + deg_inv_sqrt_senders = tlx.where( + deg_senders < 1e-11, tlx.zeros_like(deg_senders), tlx.pow(deg_senders, -0.5) + ) + deg_inv_sqrt_receivers = tlx.where( + deg_receivers < 1e-11, tlx.zeros_like(deg_receivers), tlx.pow(deg_receivers, -0.5) + ) + + ew_src_to_dst = ( + tlx.gather(deg_inv_sqrt_senders, conv_senders) + * out_weight + * tlx.gather(deg_inv_sqrt_receivers, conv_receivers) + ) + ew_dst_to_src = ( + tlx.gather(deg_inv_sqrt_receivers, conv_senders) + * in_weight + * tlx.gather(deg_inv_sqrt_senders, conv_receivers) + ) + + conv_edge_index = tlx.stack([conv_senders, conv_receivers], axis=0) + conv_edge_weight = (tlx.reshape(ew_src_to_dst, (-1, 1)), tlx.reshape(ew_dst_to_src, (-1, 1))) + return conv_edge_index, conv_edge_weight + + +def set_seed(seed): + r"""Set random seeds for reproducible runs.""" + np.random.seed(seed) + tlx.set_seed(seed) + + +def main(args): + # ------------------------------------------------------------------ + # 1. Load dataset + # ------------------------------------------------------------------ + dataset_name = str.lower(args.dataset) + + if dataset_name in ['cora', 'pubmed', 'citeseer']: + # Planetoid with Geom-GCN 10 fixed splits + dataset, graph = load_planetoid_with_geom_splits( + root=args.dataset_path, name=dataset_name, + num_splits=args.num_splits, transform=T.NormalizeFeatures(), + ) + elif dataset_name in ['texas', 'wisconsin', 'cornell']: + dataset = WebKB(args.dataset_path, dataset_name, transform=T.NormalizeFeatures()) + graph = dataset[0] + # WebKB masks are flat 1D: concatenation of 10 splits + n = graph.num_nodes + train_idx = mask_to_index(graph.train_mask[args.split_idx * n: (args.split_idx + 1) * n]) + val_idx = mask_to_index(graph.val_mask[args.split_idx * n: (args.split_idx + 1) * n]) + test_idx = mask_to_index(graph.test_mask[args.split_idx * n: (args.split_idx + 1) * n]) + elif dataset_name in ['chameleon', 'squirrel']: + dataset = WikipediaNetwork(args.dataset_path, dataset_name, geom_gcn_preprocess=True) + graph = dataset[0] + # WikipediaNetwork masks are flat 1D: concatenation of 10 splits + n = graph.num_nodes + train_idx = mask_to_index(graph.train_mask[args.split_idx * n: (args.split_idx + 1) * n]) + val_idx = mask_to_index(graph.val_mask[args.split_idx * n: (args.split_idx + 1) * n]) + test_idx = mask_to_index(graph.test_mask[args.split_idx * n: (args.split_idx + 1) * n]) + else: + raise ValueError('Unknown dataset: {}'.format(args.dataset)) + + # ------------------------------------------------------------------ + # 2. Build fuzzy edge structure (dataset-level, shared across splits) + # ------------------------------------------------------------------ + if args.remove_existing_self_loop: + # Remove self-loops from the original edge_index + src = tlx.convert_to_numpy(graph.edge_index[0]) + dst = tlx.convert_to_numpy(graph.edge_index[1]) + mask = src != dst + graph.edge_index = tlx.convert_to_tensor( + np.array([src[mask], dst[mask]], dtype=np.int64), dtype=tlx.int64 + ) + + edge_index, theta = get_edge_index_and_theta(graph.edge_index) + num_nodes = graph.num_nodes + + conv_edge_index, conv_edge_weight = get_fuzzy_laplacian( + edge_index=edge_index, + theta=theta, + num_nodes=num_nodes, + add_self_loop=args.self_loop, + ) + + # ------------------------------------------------------------------ + # 3. Run multi-split evaluation + # ------------------------------------------------------------------ + split_test_accs = [] + + for split_id in range(args.num_splits): + # Reload masks for this split + if dataset_name in ['cora', 'pubmed', 'citeseer']: + # Geom-GCN splits: 2D masks [num_nodes, num_splits] + train_idx = mask_to_index(graph.train_mask[:, split_id]) + val_idx = mask_to_index(graph.val_mask[:, split_id]) + test_idx = mask_to_index(graph.test_mask[:, split_id]) + else: + n = graph.num_nodes + train_idx = mask_to_index(graph.train_mask[split_id * n: (split_id + 1) * n]) + val_idx = mask_to_index(graph.val_mask[split_id * n: (split_id + 1) * n]) + test_idx = mask_to_index(graph.test_mask[split_id * n: (split_id + 1) * n]) + + data = { + "x": graph.x, + "y": graph.y, + "edge_index": conv_edge_index, + "edge_weight": conv_edge_weight, + "train_idx": train_idx, + "test_idx": test_idx, + "val_idx": val_idx, + "num_nodes": num_nodes, + } + + for run in range(args.runs): + set_seed(args.seed + split_id * 97 + run) + + # Instantiate model + jk = args.jumping_knowledge if args.jumping_knowledge != "None" else None + net = CoEDModel( + feature_dim=dataset.num_node_features, + hidden_dim=args.hidden_dim, + num_class=dataset.num_classes, + num_layers=args.num_layers, + alpha=args.alpha, + drop_rate=args.drop_rate, + normalize=args.normalize, + self_feature_transform=args.self_feature_transform, + jumping_knowledge=jk, + name="CoED", + ) + + optimizer = tlx.optimizers.Adam(lr=args.lr, weight_decay=args.l2_coef) + metrics = tlx.metrics.Accuracy() + train_weights = net.trainable_weights + + loss_func = SemiSpvzLoss(net, tlx.losses.softmax_cross_entropy_with_logits) + train_one_step = TrainOneStep(loss_func, optimizer, train_weights) + + best_val_acc = 0 + best_test_acc = 0 + bad_counter = 0 + + for epoch in range(1, args.n_epoch + 1): + net.set_train() + train_loss = train_one_step(data, graph.y) + + net.set_eval() + logits = net(data['x'], data['edge_index'], data['edge_weight'], data['num_nodes']) + + val_logits = tlx.gather(logits, data['val_idx']) + val_y = tlx.gather(data['y'], data['val_idx']) + val_acc = calculate_acc(val_logits, val_y, metrics) + + test_logits = tlx.gather(logits, data['test_idx']) + test_y = tlx.gather(data['y'], data['test_idx']) + test_acc = calculate_acc(test_logits, test_y, metrics) + + if val_acc > best_val_acc: + best_val_acc = val_acc + best_test_acc = test_acc + bad_counter = 0 + net.save_weights(args.best_model_path + net.name + ".npz", format='npz_dict') + else: + bad_counter += 1 + + if epoch % args.print_freq == 0 or epoch == 1: + print( + "split {:02d} run {:02d} epoch {:04d} " + "loss {:.4f} val {:.4f} best_test {:.4f} patience {}/{}".format( + split_id, run, epoch, + float(train_loss.item()), + val_acc, best_test_acc, + bad_counter, args.patience, + ) + ) + + if bad_counter >= args.patience: + break + + # Restore best model for final evaluation + net.load_weights(args.best_model_path + net.name + ".npz", format='npz_dict') + if tlx.BACKEND == 'torch': + net.to(data['x'].device) + net.set_eval() + logits = net(data['x'], data['edge_index'], data['edge_weight'], data['num_nodes']) + test_logits = tlx.gather(logits, data['test_idx']) + test_y = tlx.gather(data['y'], data['test_idx']) + best_test_acc = calculate_acc(test_logits, test_y, metrics) + split_test_accs.append(best_test_acc) + print("split {:02d} run {:02d} best test acc: {:.5f}".format(split_id, run, best_test_acc * 100.0)) + + mean_test = float(np.mean(split_test_accs) * 100.0) + std_test = float(np.std(split_test_accs) * 100.0) + print("test acc: {:.5f} +/- {:.5f}".format(mean_test, std_test)) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description="CoED-GNN node classification with GammaGL/TensorLayerX.") + + # Dataset + parser.add_argument('--dataset', type=str, default='cora', + choices=['cora', 'texas', 'wisconsin', 'chameleon', 'squirrel'], + help='Dataset name.') + parser.add_argument('--dataset_path', type=str, default=r'', help='Path to save/load dataset.') + parser.add_argument('--num_splits', type=int, default=10, help='Number of fixed splits to evaluate.') + parser.add_argument('--split_idx', type=int, default=0, help='Unused when num_splits > 0.') + parser.add_argument('--runs', type=int, default=1, help='Runs per split.') + parser.add_argument('--seed', type=int, default=42, help='Random seed.') + + # Model + parser.add_argument('--hidden_dim', type=int, default=64, help='Hidden dimension.') + parser.add_argument('--num_layers', type=int, default=2, help='Number of GNN layers.') + parser.add_argument('--alpha', type=float, default=0.5, help='Direction convex combination parameter.') + parser.add_argument('--drop_rate', type=float, default=0.0, help='Feature dropout rate.') + parser.add_argument('--normalize', dest='normalize', action='store_true', + help='L2-normalize hidden features at each layer.') + parser.add_argument('--no_normalize', dest='normalize', action='store_false') + parser.add_argument('--self_feature_transform', dest='self_feature_transform', action='store_true', + help='Learn a separate self-feature transform branch.') + parser.add_argument('--no_self_feature_transform', dest='self_feature_transform', action='store_false') + parser.add_argument('--self_loop', dest='self_loop', action='store_true', + help='Mix self features into directional messages.') + parser.add_argument('--no_self_loop', dest='self_loop', action='store_false') + parser.add_argument('--jumping_knowledge', type=str, default='None', + choices=['None', 'cat', 'max', 'lstm'], + help='Jumping-knowledge aggregation type.') + parser.add_argument('--remove_existing_self_loop', dest='remove_existing_self_loop', + action='store_true', + help='Remove existing self-loops from the graph before processing.') + parser.add_argument('--no_remove_existing_self_loop', dest='remove_existing_self_loop', + action='store_false') + + # Training + parser.add_argument('--lr', type=float, default=0.001, help='Learning rate.') + parser.add_argument('--l2_coef', type=float, default=0.0, help='Weight decay (L2 regularization).') + parser.add_argument('--n_epoch', type=int, default=5000, help='Max training epochs.') + parser.add_argument('--patience', type=int, default=100, help='Early stopping patience.') + parser.add_argument('--print_freq', type=int, default=50, help='Print frequency (epochs).') + + # System + parser.add_argument('--best_model_path', type=str, default=r'./', help='Path to save best model.') + parser.add_argument('--gpu', type=int, default=0, help='GPU index, -1 for CPU.') + + parser.set_defaults( + normalize=False, + self_feature_transform=False, + self_loop=True, + remove_existing_self_loop=False, + ) + + args = parser.parse_args() + + if args.gpu >= 0: + tlx.set_device("GPU", args.gpu) + else: + tlx.set_device("CPU") + + main(args) diff --git a/examples/coed/geom_planetoid.py b/examples/coed/geom_planetoid.py new file mode 100644 index 00000000..ec8c8e1a --- /dev/null +++ b/examples/coed/geom_planetoid.py @@ -0,0 +1,53 @@ +"""Helpers for applying Geom-GCN 10-split evaluation to GammaGL Planetoid.""" + +import os +import os.path as osp + +import numpy as np +import tensorlayerx as tlx + +from gammagl.data import download_url +from gammagl.datasets import Planetoid + + +GEOM_GCN_URL = "https://raw.githubusercontent.com/graphdml-uiuc-jlu/geom-gcn/master/splits" + + +def _geom_raw_dir(root, name): + return osp.join(root, name.lower(), "geom-gcn", "raw") + + +def _split_file(name, split_id): + return "{}_split_0.6_0.2_{}.npz".format(name.lower(), split_id) + + +def ensure_geom_gcn_splits(root, name, num_splits=10): + """Ensure Geom-GCN split files exist under the GammaGL dataset directory.""" + raw_dir = _geom_raw_dir(root, name) + os.makedirs(raw_dir, exist_ok=True) + for split_id in range(num_splits): + filename = _split_file(name, split_id) + path = osp.join(raw_dir, filename) + if not osp.exists(path): + download_url("{}/{}".format(GEOM_GCN_URL, filename), raw_dir) + return raw_dir + + +def load_planetoid_with_geom_splits(root, name, num_splits=10): + """Load Planetoid data and replace masks with Geom-GCN fixed splits.""" + dataset = Planetoid(root=root, name=name) + graph = dataset[0] + raw_dir = ensure_geom_gcn_splits(root, name, num_splits=num_splits) + + train_masks, val_masks, test_masks = [], [], [] + for split_id in range(num_splits): + split_path = osp.join(raw_dir, _split_file(name, split_id)) + split_data = np.load(split_path) + train_masks.append(split_data["train_mask"]) + val_masks.append(split_data["val_mask"]) + test_masks.append(split_data["test_mask"]) + + graph.train_mask = tlx.convert_to_tensor(np.stack(train_masks, axis=1), dtype=tlx.bool) + graph.val_mask = tlx.convert_to_tensor(np.stack(val_masks, axis=1), dtype=tlx.bool) + graph.test_mask = tlx.convert_to_tensor(np.stack(test_masks, axis=1), dtype=tlx.bool) + return dataset, graph diff --git a/examples/coed/readme.md b/examples/coed/readme.md new file mode 100644 index 00000000..c5b6fc26 --- /dev/null +++ b/examples/coed/readme.md @@ -0,0 +1,80 @@ +# CoED-GNN Node Classification + +- Paper link: [https://arxiv.org/abs/2410.14109](https://arxiv.org/abs/2410.14109) +- Author's code repo: [https://github.com/hormoz-lab/coed-gnn](https://github.com/hormoz-lab/coed-gnn) + +## Dataset Statics + +| Dataset | # Nodes | # Edges | # Classes | +|------------|---------|---------|-----------| +| Cora | 2,708 | 10,556 | 7 | +| Texas | 183 | 309 | 5 | +| Wisconsin | 251 | 515 | 5 | +| Chameleon | 2,277 | 36,101 | 5 | +| Squirrel | 5,201 | 217,073 | 5 | + +All datasets use the `Geom-GCN` 10 fixed splits for evaluation. + +## Files + +- `examples/coed/coed_trainer.py`: Multi-dataset training and evaluation entry +- `gammagl/models/coed.py`: CoED-GNN backbone model +- `gammagl/layers/conv/coed_conv.py`: CoED directional convolution layer + +## Results + +### Cora + +```bash +TL_BACKEND="torch" python examples/coed/coed_trainer.py --dataset cora +``` + +| Metric | Paper | Our(torch) | +|------------|------------|----------------------| +| Test Acc | 86.42 | 87.00 +/- 1.44 | + +### Texas + +```bash +TL_BACKEND="torch" python examples/coed/coed_trainer.py --dataset texas +``` + +| Metric | Paper | Our(torch) | +|------------|------------|----------------------| +| Test Acc | | | + +### Wisconsin + +```bash +TL_BACKEND="torch" python examples/coed/coed_trainer.py --dataset wisconsin +``` + +| Metric | Paper | Our(torch) | +|------------|------------|----------------------| +| Test Acc | | | + +### Chameleon + +```bash +TL_BACKEND="torch" python examples/coed/coed_trainer.py --dataset chameleon +``` + +| Metric | Paper | Our(torch) | +|------------|------------|----------------------| +| Test Acc | | | + +### Squirrel + +```bash +TL_BACKEND="torch" python examples/coed/coed_trainer.py --dataset squirrel +``` + +| Metric | Paper | Our(torch) | +|------------|------------|----------------------| +| Test Acc | | | + +## Notes + +- The default setup uses `hidden_dim=64`, `num_layers=2`, `lr=1e-3`, `l2_coef=0.0`, `alpha=0.5`, `self_loop=True`, `normalize=False`, `self_feature_transform=False`, `patience=100`, `n_epoch=5000`. +- The implementation evaluates all 10 Geom-GCN fixed splits and reports mean +/- std test accuracy. +- The model and convolution layers are registered in `gammagl/models/__init__.py` and `gammagl/layers/conv/__init__.py` and can be imported via standard GammaGL paths. diff --git a/examples/coed/reproduce_cora.sh b/examples/coed/reproduce_cora.sh new file mode 100755 index 00000000..81324351 --- /dev/null +++ b/examples/coed/reproduce_cora.sh @@ -0,0 +1,9 @@ +#!/usr/bin/env bash +set -e + +ROOT_DIR="$(cd "$(dirname "$0")/../.." && pwd)" +source /home/mr/venv/gammagl-py311-cpu/bin/activate +export TL_BACKEND=torch +export PYTHONPATH="${ROOT_DIR}:${PYTHONPATH}" + +python "${ROOT_DIR}/examples/coed/coed_trainer.py" "$@" diff --git a/examples/coed/run_coed_cora.py b/examples/coed/run_coed_cora.py new file mode 100644 index 00000000..2d76cd0c --- /dev/null +++ b/examples/coed/run_coed_cora.py @@ -0,0 +1,19 @@ +"""Launcher for the CoED-GNN Cora reproduction.""" + +import os +import subprocess +import sys + + +def main(): + root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")) + env = os.environ.copy() + env.setdefault("TL_BACKEND", "torch") + env["PYTHONPATH"] = root + (os.pathsep + env["PYTHONPATH"] if env.get("PYTHONPATH") else "") + + cmd = [sys.executable, os.path.join(os.path.dirname(__file__), "coed_trainer.py")] + sys.argv[1:] + raise SystemExit(subprocess.call(cmd, env=env, cwd=root)) + + +if __name__ == "__main__": + main() diff --git a/gammagl/layers/conv/__init__.py b/gammagl/layers/conv/__init__.py index 442db13a..b5782cf0 100644 --- a/gammagl/layers/conv/__init__.py +++ b/gammagl/layers/conv/__init__.py @@ -36,6 +36,7 @@ from .dhn_conv import DHNConv from .dna_conv import DNAConv from .rohehan_conv import RoheHANConv +from .coed_conv import CoEDConv __all__ = [ 'MessagePassing', @@ -75,7 +76,8 @@ 'HEATlayer', 'DHNConv', 'DNAConv', - 'RoheHANConv' + 'RoheHANConv', + 'CoEDConv' ] classes = __all__ diff --git a/gammagl/layers/conv/coed_conv.py b/gammagl/layers/conv/coed_conv.py new file mode 100644 index 00000000..9e670b00 --- /dev/null +++ b/gammagl/layers/conv/coed_conv.py @@ -0,0 +1,120 @@ +"""CoED directional convolution layer. + +This module implements the directional message passing operator used in +`"Co-Embedding of Edges and Directions for Graph Neural Networks" +`_. +""" + +import tensorlayerx as tlx +from tensorlayerx.nn import Linear + +from gammagl.layers.conv import MessagePassing + + +class CoEDConv(MessagePassing): + r"""The directional convolution operator used by CoED-GNN. + + The layer separately aggregates messages for two directional channels and + optionally applies an additional self-feature transformation. + + Parameters + ---------- + in_channels: int + Size of each input sample. + out_channels: int + Size of each output sample. + self_feature_transform: bool, optional + If set to :obj:`True`, adds an extra linear transform on the input node + features and combines it with directional messages. + bias: bool, optional + If set to :obj:`False`, the layer will not learn additive bias terms. + + """ + + def __init__(self, in_channels, out_channels, self_feature_transform=True, bias=True): + super().__init__() + self.self_feature_transform = self_feature_transform + + self.lin_src_to_dst = Linear( + in_features=in_channels, + out_features=out_channels, + W_init="xavier_uniform", + b_init=None, + ) + self.lin_dst_to_src = Linear( + in_features=in_channels, + out_features=out_channels, + W_init="xavier_uniform", + b_init=None, + ) + + if self_feature_transform: + self.lin_self = Linear( + in_features=in_channels, + out_features=out_channels, + W_init="xavier_uniform", + b_init=None, + ) + else: + self.lin_self = None + + if bias: + zeros = tlx.initializers.Zeros() + self.bias_src_to_dst = self._get_weights("bias_src_to_dst", shape=(out_channels,), init=zeros) + self.bias_dst_to_src = self._get_weights("bias_dst_to_src", shape=(out_channels,), init=zeros) + self.bias_self = ( + self._get_weights("bias_self", shape=(out_channels,), init=zeros) + if self_feature_transform + else None + ) + else: + self.bias_src_to_dst = None + self.bias_dst_to_src = None + self.bias_self = None + + def forward(self, x, edge_index, edge_weight=None, num_nodes=None): + """Compute directional node representations.""" + if num_nodes is None: + num_nodes = tlx.get_tensor_shape(x)[0] + + if isinstance(edge_weight, (tuple, list)): + edge_weight_src_to_dst, edge_weight_dst_to_src = edge_weight + else: + edge_weight_src_to_dst = edge_weight + edge_weight_dst_to_src = edge_weight + + x_src_to_dst = self.propagate( + x=x, + edge_index=edge_index, + edge_weight=edge_weight_src_to_dst, + num_nodes=num_nodes, + ) + x_dst_to_src = self.propagate( + x=x, + edge_index=edge_index, + edge_weight=edge_weight_dst_to_src, + num_nodes=num_nodes, + ) + + x_src_to_dst = self.lin_src_to_dst.forward(x_src_to_dst) + x_dst_to_src = self.lin_dst_to_src.forward(x_dst_to_src) + + if self.bias_src_to_dst is not None: + x_src_to_dst = x_src_to_dst + self.bias_src_to_dst + if self.bias_dst_to_src is not None: + x_dst_to_src = x_dst_to_src + self.bias_dst_to_src + + if self.self_feature_transform: + x_self = self.lin_self.forward(x) + if self.bias_self is not None: + x_self = x_self + self.bias_self + return x_src_to_dst, x_dst_to_src, x_self + + return x_src_to_dst, x_dst_to_src + + def message(self, x, edge_index, edge_weight=None): + """Construct messages on each edge.""" + msg = tlx.gather(x, edge_index[0, :]) + if edge_weight is None: + return msg + return msg * tlx.reshape(edge_weight, (-1, 1)) diff --git a/gammagl/models/__init__.py b/gammagl/models/__init__.py index 062ee67e..af583985 100644 --- a/gammagl/models/__init__.py +++ b/gammagl/models/__init__.py @@ -67,6 +67,7 @@ from .sgformer import SGFormerModel from .adagad import PreModel, ReModel from .nodeid import NodeIDGNN +from .coed import CoEDModel __all__ = [ 'HeCo', @@ -142,7 +143,8 @@ 'sgformer', 'PreModel', 'ReModel' - , 'NodeIDGNN' + , 'NodeIDGNN', + 'CoEDModel' ] classes = __all__ diff --git a/gammagl/models/coed.py b/gammagl/models/coed.py new file mode 100644 index 00000000..abf1b980 --- /dev/null +++ b/gammagl/models/coed.py @@ -0,0 +1,132 @@ +"""CoED-GNN backbone model. + +This module implements the node classification backbone described in +`"Co-Embedding of Edges and Directions for Graph Neural Networks" +`_. +""" + +import tensorlayerx as tlx +from tensorlayerx.nn import Dropout, Linear, Module, ReLU + +from gammagl.layers.conv import CoEDConv, JumpingKnowledge + + +class CoEDModel(Module): + r"""CoED-GNN model for node classification. + + Parameters + ---------- + feature_dim: int + Input feature dimension. + hidden_dim: int + Hidden feature dimension. + num_class: int + Number of output classes. + num_layers: int, optional + Number of directional convolution layers. + alpha: float, optional + Mixture coefficient for combining the two directional channels. + drop_rate: float, optional + Dropout rate applied between hidden layers. + normalize: bool, optional + If set to :obj:`True`, applies L2 normalization to hidden features. + self_feature_transform: bool, optional + If set to :obj:`True`, each CoED layer also learns a self-feature + transform branch. + jumping_knowledge: str, optional + Type of jumping-knowledge aggregation (:obj:`"cat"`, :obj:`"max"`, + :obj:`"lstm"`, or :obj:`None`). When set, intermediate layer + outputs are aggregated and projected through an additional linear + layer. + name: str, optional + Model name. + + """ + + def __init__( + self, + feature_dim, + hidden_dim, + num_class, + num_layers=2, + alpha=0.0, + drop_rate=0.5, + normalize=False, + self_feature_transform=False, + jumping_knowledge=None, + name=None, + ): + super().__init__(name=name) + self.alpha = alpha + self.num_layers = num_layers + self.normalize = normalize + self.jumping_knowledge = jumping_knowledge + + self.convs = [] + in_channels = feature_dim + for layer_idx in range(num_layers): + conv = CoEDConv( + in_channels=in_channels, + out_channels=hidden_dim, + self_feature_transform=self_feature_transform, + ) + self.convs.append(conv) + self.add_module("conv{}".format(layer_idx + 1), conv) + in_channels = hidden_dim + + if jumping_knowledge is not None: + self.jump = JumpingKnowledge(jumping_knowledge, hidden_dim, num_layers) + if jumping_knowledge == "cat": + jk_dim = hidden_dim * num_layers + else: + jk_dim = hidden_dim + self.lin = Linear( + in_features=jk_dim, + out_features=num_class, + W_init="xavier_uniform", + b_init=tlx.initializers.Zeros(), + ) + self.readout = None + else: + self.jump = None + self.lin = None + self.readout = Linear( + in_features=hidden_dim, + out_features=num_class, + W_init="xavier_uniform", + b_init=tlx.initializers.Zeros(), + ) + + self.relu = ReLU() + self.dropout = Dropout(p=drop_rate) + + def combine(self, xs): + """Combine directional features with the optional self-feature branch.""" + if len(xs) == 3: + x_src_to_dst, x_dst_to_src, x_self = xs + return self.alpha * x_src_to_dst + (1.0 - self.alpha) * x_dst_to_src + x_self + + x_src_to_dst, x_dst_to_src = xs + return self.alpha * x_src_to_dst + (1.0 - self.alpha) * x_dst_to_src + + def forward(self, x, edge_index, edge_weight=None, num_nodes=None): + """Compute node logits.""" + x_intermediate = [] + + for layer_idx, conv in enumerate(self.convs): + x = self.combine(conv.forward(x, edge_index, edge_weight=edge_weight, num_nodes=num_nodes)) + + if layer_idx != self.num_layers - 1 or self.jump is not None: + x = self.relu.forward(x) + x = self.dropout.forward(x) + if self.normalize: + x = tlx.l2_normalize(x, axis=1) + x_intermediate.append(x) + + if self.jump is not None: + x = self.jump(x_intermediate) + x = self.lin.forward(x) + else: + x = self.readout.forward(x) + + return x