From e14437f2d1b8ef9e67b2739635161302dfe0c8b4 Mon Sep 17 00:00:00 2001 From: wensi123 <1763506289@qq.com> Date: Sun, 29 Jun 2025 17:03:46 +0800 Subject: [PATCH] [Model] NASA --- examples/nasa/README.md | 21 ++++ examples/nasa/nasa_gcn_trainer.py | 153 +++++++++++++++++++++++++++++ gammagl/models/__init__.py | 2 + gammagl/models/nasa_gcn.py | 87 ++++++++++++++++ gammagl/transforms/__init__.py | 5 +- gammagl/transforms/nr_augmentor.py | 94 ++++++++++++++++++ gammagl/utils/__init__.py | 7 +- gammagl/utils/nasa_utils.py | 87 ++++++++++++++++ 8 files changed, 453 insertions(+), 3 deletions(-) create mode 100644 examples/nasa/README.md create mode 100644 examples/nasa/nasa_gcn_trainer.py create mode 100644 gammagl/models/nasa_gcn.py create mode 100644 gammagl/transforms/nr_augmentor.py create mode 100644 gammagl/utils/nasa_utils.py diff --git a/examples/nasa/README.md b/examples/nasa/README.md new file mode 100644 index 000000000..31e625d1b --- /dev/null +++ b/examples/nasa/README.md @@ -0,0 +1,21 @@ +# Regularizing GNNs via Consistency-Diversity Graph Augmentations (NASA) + +This example implements the model from the paper: [Regularizing Graph Neural Networks via Consistency-Diversity Graph Augmentations](https://arxiv.org/abs/2110.07627) (AAAI 2022). + +The implementation includes: +- `NR_Augmentor`: A graph transformation that implements the "Neighbor Replacement" (NR) augmentation strategy. +- `NASA_GCN`: A GCN-based model that incorporates the neighbor-constrained regularization loss (L_CR). + +## How to Run + +You can run the training script from the root directory of the GammaGL repository: + +```bash +# Run on Cora dataset +python examples/nasa/nasa_gcn_trainer.py --dataset Cora + +# Run on Citeseer dataset +python examples/nasa/nasa_gcn_trainer.py --dataset Citeseer + +# Run on PubMed dataset +python examples/nasa/nasa_gcn_trainer.py --dataset PubMed \ No newline at end of file diff --git a/examples/nasa/nasa_gcn_trainer.py b/examples/nasa/nasa_gcn_trainer.py new file mode 100644 index 000000000..99f0eb2da --- /dev/null +++ b/examples/nasa/nasa_gcn_trainer.py @@ -0,0 +1,153 @@ +import tensorlayerx as tlx +import argparse +import time + +from gammagl.data import Graph +from gammagl.utils import mask_to_index +from gammagl.datasets import Planetoid +from gammagl.models import NASA_GCN +from gammagl.transforms import NR_Augmentor +from gammagl.utils import accuracy_tlx, compute_gcn_norm + +def main(args): + + try: + tlx.set_device(device='GPU', id=args.gpu_id) + print(f"Using GPU: {args.gpu_id}") + except: + tlx.set_device(device='CPU') + print("GPU not available, using CPU.") + + try: + dataset = Planetoid(root=args.dataset_path, name=args.dataset) + except Exception as e: + print(f"Error loading dataset {args.dataset}: {e}") + print("Please ensure the dataset name is correct (Cora, Citeseer, PubMed) and it's downloadable.") + return + + graph = dataset[0] + num_nodes = graph.num_nodes + + graph_x = tlx.convert_to_tensor(graph.x, dtype=tlx.float32) + graph_y = tlx.convert_to_tensor(graph.y, dtype=tlx.int64) + graph_train_mask = tlx.convert_to_tensor(graph.train_mask, dtype=tlx.bool) + graph_val_mask = tlx.convert_to_tensor(graph.val_mask, dtype=tlx.bool) + graph_test_mask = tlx.convert_to_tensor(graph.test_mask, dtype=tlx.bool) + + eval_edge_index, eval_edge_weight = compute_gcn_norm( + graph.edge_index, num_nodes, dtype=graph_x.dtype, add_self_loops_flag=True + ) + eval_edge_index = tlx.convert_to_tensor(eval_edge_index) + eval_edge_weight = tlx.convert_to_tensor(eval_edge_weight) + + # initialize + augmentor = NR_Augmentor(probability=args.nr_prob) + + model = NASA_GCN( + feature_dim=dataset.num_node_features, + hidden_dim=args.hidden_dim, + num_classes=dataset.num_classes, + dropout_rate=args.dropout, + temp=args.temp, + alpha=args.alpha + ) + + optimizer = tlx.optimizers.Adam(lr=args.lr, weight_decay=args.weight_decay) + train_weights = model.trainable_weights + + best_val_acc = 0 + best_test_acc = 0 + best_epoch = 0 + + print("Starting training...") + for epoch in range(args.epochs): + epoch_start_time = time.time() + model.set_train() + + # 1. Dynamic Augmentation + temp_original_graph_for_aug = Graph(x=graph_x, edge_index=graph.edge_index, num_nodes=num_nodes) + augmented_graph = augmentor.augment(temp_original_graph_for_aug) + + # 2. Preprocess augmented graph for GCN + aug_edge_index, aug_edge_weight = compute_gcn_norm( + augmented_graph.edge_index, + tlx.get_tensor_shape(augmented_graph.x)[0], + dtype=augmented_graph.x.dtype, + add_self_loops_flag=True + ) + aug_edge_index = tlx.convert_to_tensor(aug_edge_index) + aug_edge_weight = tlx.convert_to_tensor(aug_edge_weight) + + import tensorflow as tf + with tf.GradientTape() as tape: + output_logits_aug = model( + augmented_graph.x, + aug_edge_index, + edge_weight=aug_edge_weight, + num_nodes=tlx.get_tensor_shape(augmented_graph.x)[0] + ) + + loss = model.compute_nasa_loss( + output_logits_aug, + augmented_graph, + graph_y, + graph_train_mask + ) + + gradients = tape.gradient(loss, train_weights) + optimizer.apply_gradients(zip(gradients, train_weights)) + + # evaluation + model.set_eval() + + eval_logits = model.predict( + graph_x, + eval_edge_index, + edge_weight=eval_edge_weight, + num_nodes=num_nodes + ) + eval_pred_softmax = tlx.softmax(eval_logits, axis=-1) + + train_acc = accuracy_tlx(tlx.gather(eval_pred_softmax, mask_to_index(graph_train_mask)), + tlx.gather(graph_y, mask_to_index(graph_train_mask))) + val_acc = accuracy_tlx(tlx.gather(eval_pred_softmax, mask_to_index(graph_val_mask)), + tlx.gather(graph_y, mask_to_index(graph_val_mask))) + test_acc = accuracy_tlx(tlx.gather(eval_pred_softmax, mask_to_index(graph_test_mask)), + tlx.gather(graph_y, mask_to_index(graph_test_mask))) + + epoch_duration = time.time() - epoch_start_time + print(f"Epoch {epoch+1:03d}/{args.epochs} | Loss: {loss.item():.4f} | " + f"Train Acc: {train_acc:.4f} | Val Acc: {val_acc:.4f} | Test Acc: {test_acc:.4f} | " + f"Time: {epoch_duration:.2f}s") + + if val_acc > best_val_acc: + best_val_acc = val_acc + best_test_acc = test_acc + best_epoch = epoch +1 + # save the best model + # model.save_weights(f"nasa_gcn_{args.dataset}_best.npz") + # print(f"New best validation accuracy: {best_val_acc:.4f}, saving model.") + + print("Training finished.") + print(f"Best Epoch: {best_epoch}, Best Val Acc: {best_val_acc:.4f}, Corresponding Test Acc: {best_test_acc:.4f}") + + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description="NASA GCN training with GammaGL") + parser.add_argument('--dataset', type=str, default='Cora', help='Dataset name (Cora, Citeseer, PubMed)') + parser.add_argument('--dataset_path', type=str, default='./data', help='Path to store/load datasets') + parser.add_argument('--epochs', type=int, default=500, help='Number of training epochs') + parser.add_argument('--lr', type=float, default=0.01, help='Learning rate') + parser.add_argument('--weight_decay', type=float, default=1e-3, help='Weight decay for Adam optimizer') + parser.add_argument('--hidden_dim', type=int, default=32, help='Number of hidden units in GCN') + parser.add_argument('--dropout', type=float, default=0.7, help='Dropout rate for GCN layers and input features') + # NASA specific hyperparameters + parser.add_argument('--alpha', type=float, default=1.0, help='Weight for L_CR loss component') + parser.add_argument('--temp', type=float, default=0.5, help='Temperature for sharpening pseudo-labels') + parser.add_argument('--nr_prob', type=float, default=0.5, help='Probability for Neighbor Replacement') + + parser.add_argument('--gpu_id', type=int, default=0, help='GPU ID to use, -1 for CPU') + + cli_args = parser.parse_args() + print("Arguments:", cli_args) + main(cli_args) \ No newline at end of file diff --git a/gammagl/models/__init__.py b/gammagl/models/__init__.py index 0d0668893..e9357cd7c 100644 --- a/gammagl/models/__init__.py +++ b/gammagl/models/__init__.py @@ -68,6 +68,7 @@ from .adagad import PreModel, ReModel from .dyfss import MoeSSL,VGAE,Discriminator,InnerProductDecoder from .egt import EGTModel +from .nasa_gcn import NASA_GCN __all__ = [ 'HeCo', @@ -149,6 +150,7 @@ 'Discriminator', 'InnerProductDecoder', 'EGTModel', + 'NASA_GCN', ] classes = __all__ diff --git a/gammagl/models/nasa_gcn.py b/gammagl/models/nasa_gcn.py new file mode 100644 index 000000000..87b785de0 --- /dev/null +++ b/gammagl/models/nasa_gcn.py @@ -0,0 +1,87 @@ +import tensorlayerx as tlx +from tensorlayerx.nn import Module as Model +from gammagl.layers.conv import GCNConv +from gammagl.utils import mask_to_index +# from gammagl.mpops import unsorted_segment_mean + +class NASA_GCN(Model): + def __init__(self, feature_dim, hidden_dim, num_classes, dropout_rate, temp, alpha, name=None): + super().__init__(name=name) + self.conv1 = GCNConv(in_channels=feature_dim, out_channels=hidden_dim) + self.conv2 = GCNConv(in_channels=hidden_dim, out_channels=num_classes) + self.dropout = tlx.layers.Dropout(p=dropout_rate) + self.temp = temp + self.alpha = alpha + self.elu = tlx.layers.ELU() + + def _compute_L_CR(self, aug_graph_edge_index, aug_pred_softmax, num_nodes): + src, dst = aug_graph_edge_index[0], aug_graph_edge_index[1] + + # 1. Compute average of neighbor predictions (ลท_i) + aug_pred_softmax_src = tlx.gather(aug_pred_softmax, src) + #avg_pred = unsorted_segment_mean(data=aug_pred_softmax_src, segment_ids=dst, num_segments=num_nodes) + avg_pred = tlx.ops.unsorted_segment_mean(aug_pred_softmax_src, dst, num_segments=num_nodes) + + # Handle nodes with no incoming messages if unsorted_segment_mean results in NaN/Inf + #avg_pred = tlx.where(tlx.is_finite(avg_pred), avg_pred, tlx.zeros_like(avg_pred)) + is_inf_avg_pred = tlx.is_inf(avg_pred) + is_nan_avg_pred = tlx.is_nan(avg_pred) + avg_pred_finite_mask = tlx.logical_and( + tlx.logical_not(is_inf_avg_pred), + tlx.logical_not(is_nan_avg_pred) + ) + avg_pred = tlx.where(avg_pred_finite_mask, avg_pred, tlx.zeros_like(avg_pred)) + + # 2. Sharpening (p_i) + avg_pred_eps = avg_pred + 1e-12 + pow_avg_pred = tlx.pow(avg_pred_eps, 1.0 / self.temp) + sharp_pseudo_labels = pow_avg_pred / (tlx.reduce_sum(pow_avg_pred, axis=1, keepdims=True) + 1e-12) + sharp_pseudo_labels_detached = tlx.ops.stop_gradient(sharp_pseudo_labels) + + # 3. Compute KL Divergence Loss + p_dst_detached = tlx.gather(sharp_pseudo_labels_detached, dst) + q_src_softmax = tlx.gather(aug_pred_softmax, src) + + log_p_dst_detached = tlx.log(p_dst_detached + 1e-12) + log_q_src_softmax = tlx.log(q_src_softmax + 1e-12) + + kl_div_elements = p_dst_detached * (log_p_dst_detached - log_q_src_softmax) + kl_div_per_edge = tlx.reduce_sum(kl_div_elements, axis=1) + + if tlx.get_tensor_shape(kl_div_per_edge)[0] > 0: + loss_cr = tlx.reduce_mean(kl_div_per_edge) + else: + loss_cr = tlx.convert_to_tensor(0.0, dtype=tlx.float32) + + return loss_cr + + def forward(self, x, edge_index, edge_weight=None, num_nodes=None): + h = self.dropout(x) + h = self.conv1(h, edge_index, edge_weight=edge_weight, num_nodes=num_nodes) + h = self.elu(h) + h = self.dropout(h) + logits = self.conv2(h, edge_index, edge_weight=edge_weight, num_nodes=num_nodes) + return logits + + def compute_nasa_loss(self, output_logits_aug, augmented_graph, original_graph_labels, original_graph_train_mask): + train_indices = mask_to_index(original_graph_train_mask) + gathered_logits_aug = tlx.gather(output_logits_aug, train_indices) + gathered_labels = tlx.gather(original_graph_labels, train_indices) + + loss_ce = tlx.losses.softmax_cross_entropy_with_logits(gathered_logits_aug, gathered_labels) + + aug_pred_softmax = tlx.softmax(output_logits_aug, axis=-1) + loss_cr = self._compute_L_CR( + augmented_graph.edge_index, + aug_pred_softmax, + tlx.get_tensor_shape(augmented_graph.x)[0] + ) + + total_loss = loss_ce + self.alpha * loss_cr + return total_loss + + def predict(self, x, edge_index, edge_weight=None, num_nodes=None): + h = self.conv1(x, edge_index, edge_weight=edge_weight, num_nodes=num_nodes) + h = self.elu(h) + logits = self.conv2(h, edge_index, edge_weight=edge_weight, num_nodes=num_nodes) + return logits \ No newline at end of file diff --git a/gammagl/transforms/__init__.py b/gammagl/transforms/__init__.py index b7ff5dfbd..c03fa8dc9 100644 --- a/gammagl/transforms/__init__.py +++ b/gammagl/transforms/__init__.py @@ -7,6 +7,7 @@ from .random_link_split import RandomLinkSplit from .vgae_pre import mask_test_edges, sparse_to_tuple from .svd_feature_reduction import SVDFeatureReduction +from .nr_augmentor import NR_Augmentor __all__ = [ 'BaseTransform', @@ -18,8 +19,8 @@ 'RandomLinkSplit', 'mask_test_edges', 'sparse_to_tuple', - 'SVDFeatureReduction' - + 'SVDFeatureReduction', + 'NR_Augmentor' ] classes = __all__ diff --git a/gammagl/transforms/nr_augmentor.py b/gammagl/transforms/nr_augmentor.py new file mode 100644 index 000000000..a43bc500f --- /dev/null +++ b/gammagl/transforms/nr_augmentor.py @@ -0,0 +1,94 @@ +# gammagl/transforms/nr_augmentor.py + +import tensorlayerx as tlx +import numpy as np +import random +from gammagl.data import Graph +from gammagl.utils import to_undirected, coalesce + +class NR_Augmentor: + def __init__(self, probability=0.5): + """ + Neighbor Replacement Augmentor for GammaGL. + This augmentor implements the NeighborReplace (NR) strategy from the paper + "Regularizing Graph Neural Networks via Consistency-Diversity Graph Augmentations". + + Args: + probability (float): Probability of replacing a 1-hop neighbor with a 2-hop neighbor. + """ + self.probability = probability + + def _get_1hop_neighbors_dict(self, edge_index, num_nodes): + """ + Creates an adjacency list dictionary {node_idx: [neighbor1, neighbor2,...]} + from edge_index, containing unique 1-hop neighbors (excluding self-loops). + """ + adj = {i: set() for i in range(num_nodes)} + src_nodes, dst_nodes = tlx.convert_to_numpy(edge_index[0]), tlx.convert_to_numpy(edge_index[1]) + for i in range(len(src_nodes)): + u, v = src_nodes[i], dst_nodes[i] + if u != v: + adj[u].add(v) + return {node_idx: list(neighbors) for node_idx, neighbors in adj.items()} + + def __call__(self, original_graph: Graph) -> Graph: + """ + Applies Neighbor Replacement augmentation to the input graph. + Making the class callable is a common pattern for transforms. + + Args: + original_graph (gammagl.data.Graph): The original graph object. + + Returns: + gammagl.data.Graph: The augmented graph object. + """ + num_nodes = original_graph.num_nodes + + adj_1hop_dict = self._get_1hop_neighbors_dict(original_graph.edge_index, num_nodes) + + new_edge_src_list = [] + new_edge_dst_list = [] + + for u_node_idx in range(num_nodes): + current_1hop_neighbors_of_u = adj_1hop_dict[u_node_idx] + if not current_1hop_neighbors_of_u: + continue + + for v_neighbor_idx in current_1hop_neighbors_of_u: + if random.random() < self.probability: + potential_2hop_neighbors_of_u_via_v = [ + vv_node for vv_node in adj_1hop_dict.get(v_neighbor_idx, []) + if vv_node != u_node_idx and vv_node != v_neighbor_idx + ] + + if potential_2hop_neighbors_of_u_via_v: + vv_chosen = random.choice(potential_2hop_neighbors_of_u_via_v) + new_edge_src_list.append(u_node_idx) + new_edge_dst_list.append(vv_chosen) + else: + new_edge_src_list.append(u_node_idx) + new_edge_dst_list.append(v_neighbor_idx) + else: + new_edge_src_list.append(u_node_idx) + new_edge_dst_list.append(v_neighbor_idx) + + if not new_edge_src_list: + aug_edge_index = tlx.convert_to_tensor(np.array([[],[]]), dtype=tlx.int64) + else: + temp_edge_index = tlx.stack([ + tlx.convert_to_tensor(new_edge_src_list, dtype=tlx.int64), + tlx.convert_to_tensor(new_edge_dst_list, dtype=tlx.int64) + ]) + + undirected_temp_edge_index = to_undirected(temp_edge_index, num_nodes=num_nodes) + aug_edge_index = coalesce(undirected_temp_edge_index, num_nodes=num_nodes) + + aug_graph = Graph(x=original_graph.x, edge_index=aug_edge_index, num_nodes=num_nodes) + + # Copy other essential attributes if they exist + if hasattr(original_graph, 'y'): aug_graph.y = original_graph.y + if hasattr(original_graph, 'train_mask'): aug_graph.train_mask = original_graph.train_mask + if hasattr(original_graph, 'val_mask'): aug_graph.val_mask = original_graph.val_mask + if hasattr(original_graph, 'test_mask'): aug_graph.test_mask = original_graph.test_mask + + return aug_graph \ No newline at end of file diff --git a/gammagl/utils/__init__.py b/gammagl/utils/__init__.py index dcaf483a9..cca07fc83 100644 --- a/gammagl/utils/__init__.py +++ b/gammagl/utils/__init__.py @@ -22,6 +22,9 @@ from .get_laplacian import get_laplacian from .simple_path import find_all_simple_paths from .dotdict import HDict +from .nasa_utils import compute_gcn_norm +from .nasa_utils import accuracy_tlx + __all__ = [ 'calc_A_norm_hat', 'calc_gcn_norm', @@ -50,7 +53,9 @@ 'get_laplacian', 'find_all_simple_paths', 'edge_index_to_adj_matrix', - 'HDict' + 'HDict', + 'accuracy_tlx', + 'compute_gcn_norm', ] classes = __all__ diff --git a/gammagl/utils/nasa_utils.py b/gammagl/utils/nasa_utils.py new file mode 100644 index 000000000..fd2781fd1 --- /dev/null +++ b/gammagl/utils/nasa_utils.py @@ -0,0 +1,87 @@ +import tensorlayerx as tlx +from gammagl.utils import add_self_loops +from gammagl.mpops import unsorted_segment_sum + +def accuracy_tlx(logits, labels): + if tlx.BACKEND == 'tensorflow': + predicted_labels = tlx.argmax(logits, axis=-1) + correct = tlx.reduce_sum(tlx.cast(tlx.equal(predicted_labels, labels), dtype=tlx.float32)) + return correct / tlx.cast(tlx.get_tensor_shape(labels)[0], dtype=tlx.float32) + elif tlx.BACKEND == 'torch': + _, indices = tlx.ops.max(logits, dim=1) + correct = tlx.reduce_sum(tlx.cast(indices == labels, dtype=tlx.float32)) + return correct / float(len(labels)) + elif tlx.BACKEND == 'paddle': + predicted_labels = tlx.argmax(logits, axis=-1) + correct = tlx.reduce_sum(tlx.cast(tlx.equal(predicted_labels, labels), dtype=tlx.float32)) + return correct / float(tlx.get_tensor_shape(labels)[0]) + elif tlx.BACKEND == 'mindspore': + predicted_labels = tlx.argmax(logits, axis=-1) + correct = tlx.reduce_sum(tlx.cast(tlx.equal(predicted_labels, labels), dtype=tlx.float32)) + return correct / float(tlx.get_tensor_shape(labels)[0]) + else: + raise NotImplementedError(f"Accuracy not implemented for backend: {tlx.BACKEND}") + + +def compute_gcn_norm(edge_index, num_nodes, dtype, add_self_loops_flag=True): + """ + Computes GCN normalization for edges (D^-0.5 * A * D^-0.5). + Adapted from PyG's GCNConv normalization. + + Args: + edge_index: Edge index tensor. + num_nodes: Number of nodes in the graph. + dtype: Dtype for the edge weights. + add_self_loops_flag: Whether to add self-loops before normalization. + + Returns: + Tuple[Tensor, Tensor]: Normalized edge_index and edge_weight. + """ + edge_src, edge_dst = edge_index[0], edge_index[1] + + if add_self_loops_flag: + edge_index_sl, _ = add_self_loops(edge_index, num_nodes=num_nodes) + edge_src_sl, edge_dst_sl = edge_index_sl[0], edge_index_sl[1] + else: + edge_index_sl = edge_index + edge_src_sl, edge_dst_sl = edge_src, edge_dst + + edge_weight_ones = tlx.ones(shape=(tlx.get_tensor_shape(edge_src_sl)[0],), dtype=dtype) + + # Calculate D_out (degree of source nodes of the SL-augmented graph) + deg_out = unsorted_segment_sum(edge_weight_ones, edge_src_sl, num_segments=num_nodes) + # Calculate D_in (degree of destination nodes of the SL-augmented graph) + deg_in = unsorted_segment_sum(edge_weight_ones, edge_dst_sl, num_segments=num_nodes) + + # Inverse square root, handling degree 0 + deg_out_inv_sqrt = tlx.pow(deg_out, -0.5) + # deg_out_inv_sqrt = tlx.where(tlx.is_finite(deg_out_inv_sqrt), deg_out_inv_sqrt, tlx.zeros_like(deg_out_inv_sqrt)) + deg_in_inv_sqrt = tlx.pow(deg_in, -0.5) + # deg_in_inv_sqrt = tlx.where(tlx.is_finite(deg_in_inv_sqrt), deg_in_inv_sqrt, tlx.zeros_like(deg_in_inv_sqrt)) + + # For deg_out_inv_sqrt + is_inf_deg_out = tlx.is_inf(deg_out_inv_sqrt) + is_nan_deg_out = tlx.is_nan(deg_out_inv_sqrt) + + deg_out_inv_sqrt_finite_mask_tlx = tlx.logical_and( + tlx.logical_not(is_inf_deg_out), + tlx.logical_not(is_nan_deg_out) + ) + + # For deg_in_inv_sqrt + is_inf_deg_in = tlx.is_inf(deg_in_inv_sqrt) + is_nan_deg_in = tlx.is_nan(deg_in_inv_sqrt) + deg_in_inv_sqrt_finite_mask_tlx = tlx.logical_and( + tlx.logical_not(is_inf_deg_in), + tlx.logical_not(is_nan_deg_in) + ) + + deg_out_inv_sqrt = tlx.where(deg_out_inv_sqrt_finite_mask_tlx, deg_out_inv_sqrt, tlx.zeros_like(deg_out_inv_sqrt)) + deg_in_inv_sqrt = tlx.where(deg_in_inv_sqrt_finite_mask_tlx, deg_in_inv_sqrt, tlx.zeros_like(deg_in_inv_sqrt)) + + # Normalized edge weights: deg_out_inv_sqrt[src] * deg_in_inv_sqrt[dst] + # (For A_ij, it's D_ii^-0.5 * A_ij * D_jj^-0.5) + # If edge is (j,i), then weight is deg(j)^-0.5 * deg(i)^-0.5 + norm_edge_weight = tlx.gather(deg_out_inv_sqrt, edge_src_sl) * tlx.gather(deg_in_inv_sqrt, edge_dst_sl) + + return edge_index_sl, norm_edge_weight \ No newline at end of file