Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions examples/nasa/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# Regularizing GNNs via Consistency-Diversity Graph Augmentations (NASA)

This example implements the model from the paper: [Regularizing Graph Neural Networks via Consistency-Diversity Graph Augmentations](https://arxiv.org/abs/2110.07627) (AAAI 2022).

The implementation includes:
- `NR_Augmentor`: A graph transformation that implements the "Neighbor Replacement" (NR) augmentation strategy.
- `NASA_GCN`: A GCN-based model that incorporates the neighbor-constrained regularization loss (L_CR).

## How to Run

You can run the training script from the root directory of the GammaGL repository:

```bash
# Run on Cora dataset
python examples/nasa/nasa_gcn_trainer.py --dataset Cora

# Run on Citeseer dataset
python examples/nasa/nasa_gcn_trainer.py --dataset Citeseer

# Run on PubMed dataset
python examples/nasa/nasa_gcn_trainer.py --dataset PubMed
153 changes: 153 additions & 0 deletions examples/nasa/nasa_gcn_trainer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
import tensorlayerx as tlx
import argparse
import time

from gammagl.data import Graph
from gammagl.utils import mask_to_index
from gammagl.datasets import Planetoid
from gammagl.models import NASA_GCN
from gammagl.transforms import NR_Augmentor
from gammagl.utils import accuracy_tlx, compute_gcn_norm

def main(args):

try:
tlx.set_device(device='GPU', id=args.gpu_id)
print(f"Using GPU: {args.gpu_id}")
except:
tlx.set_device(device='CPU')
print("GPU not available, using CPU.")

try:
dataset = Planetoid(root=args.dataset_path, name=args.dataset)
except Exception as e:
print(f"Error loading dataset {args.dataset}: {e}")
print("Please ensure the dataset name is correct (Cora, Citeseer, PubMed) and it's downloadable.")
return

graph = dataset[0]
num_nodes = graph.num_nodes

graph_x = tlx.convert_to_tensor(graph.x, dtype=tlx.float32)
graph_y = tlx.convert_to_tensor(graph.y, dtype=tlx.int64)
graph_train_mask = tlx.convert_to_tensor(graph.train_mask, dtype=tlx.bool)
graph_val_mask = tlx.convert_to_tensor(graph.val_mask, dtype=tlx.bool)
graph_test_mask = tlx.convert_to_tensor(graph.test_mask, dtype=tlx.bool)

eval_edge_index, eval_edge_weight = compute_gcn_norm(
graph.edge_index, num_nodes, dtype=graph_x.dtype, add_self_loops_flag=True
)
eval_edge_index = tlx.convert_to_tensor(eval_edge_index)
eval_edge_weight = tlx.convert_to_tensor(eval_edge_weight)

# initialize
augmentor = NR_Augmentor(probability=args.nr_prob)

model = NASA_GCN(
feature_dim=dataset.num_node_features,
hidden_dim=args.hidden_dim,
num_classes=dataset.num_classes,
dropout_rate=args.dropout,
temp=args.temp,
alpha=args.alpha
)

optimizer = tlx.optimizers.Adam(lr=args.lr, weight_decay=args.weight_decay)
train_weights = model.trainable_weights

best_val_acc = 0
best_test_acc = 0
best_epoch = 0

print("Starting training...")
for epoch in range(args.epochs):
epoch_start_time = time.time()
model.set_train()

# 1. Dynamic Augmentation
temp_original_graph_for_aug = Graph(x=graph_x, edge_index=graph.edge_index, num_nodes=num_nodes)
augmented_graph = augmentor.augment(temp_original_graph_for_aug)

# 2. Preprocess augmented graph for GCN
aug_edge_index, aug_edge_weight = compute_gcn_norm(
augmented_graph.edge_index,
tlx.get_tensor_shape(augmented_graph.x)[0],
dtype=augmented_graph.x.dtype,
add_self_loops_flag=True
)
aug_edge_index = tlx.convert_to_tensor(aug_edge_index)
aug_edge_weight = tlx.convert_to_tensor(aug_edge_weight)

import tensorflow as tf
with tf.GradientTape() as tape:
output_logits_aug = model(
augmented_graph.x,
aug_edge_index,
edge_weight=aug_edge_weight,
num_nodes=tlx.get_tensor_shape(augmented_graph.x)[0]
)

loss = model.compute_nasa_loss(
output_logits_aug,
augmented_graph,
graph_y,
graph_train_mask
)

gradients = tape.gradient(loss, train_weights)
optimizer.apply_gradients(zip(gradients, train_weights))

# evaluation
model.set_eval()

eval_logits = model.predict(
graph_x,
eval_edge_index,
edge_weight=eval_edge_weight,
num_nodes=num_nodes
)
eval_pred_softmax = tlx.softmax(eval_logits, axis=-1)

train_acc = accuracy_tlx(tlx.gather(eval_pred_softmax, mask_to_index(graph_train_mask)),
tlx.gather(graph_y, mask_to_index(graph_train_mask)))
val_acc = accuracy_tlx(tlx.gather(eval_pred_softmax, mask_to_index(graph_val_mask)),
tlx.gather(graph_y, mask_to_index(graph_val_mask)))
test_acc = accuracy_tlx(tlx.gather(eval_pred_softmax, mask_to_index(graph_test_mask)),
tlx.gather(graph_y, mask_to_index(graph_test_mask)))

epoch_duration = time.time() - epoch_start_time
print(f"Epoch {epoch+1:03d}/{args.epochs} | Loss: {loss.item():.4f} | "
f"Train Acc: {train_acc:.4f} | Val Acc: {val_acc:.4f} | Test Acc: {test_acc:.4f} | "
f"Time: {epoch_duration:.2f}s")

if val_acc > best_val_acc:
best_val_acc = val_acc
best_test_acc = test_acc
best_epoch = epoch +1
# save the best model
# model.save_weights(f"nasa_gcn_{args.dataset}_best.npz")
# print(f"New best validation accuracy: {best_val_acc:.4f}, saving model.")

print("Training finished.")
print(f"Best Epoch: {best_epoch}, Best Val Acc: {best_val_acc:.4f}, Corresponding Test Acc: {best_test_acc:.4f}")


if __name__ == '__main__':
parser = argparse.ArgumentParser(description="NASA GCN training with GammaGL")
parser.add_argument('--dataset', type=str, default='Cora', help='Dataset name (Cora, Citeseer, PubMed)')
parser.add_argument('--dataset_path', type=str, default='./data', help='Path to store/load datasets')
parser.add_argument('--epochs', type=int, default=500, help='Number of training epochs')
parser.add_argument('--lr', type=float, default=0.01, help='Learning rate')
parser.add_argument('--weight_decay', type=float, default=1e-3, help='Weight decay for Adam optimizer')
parser.add_argument('--hidden_dim', type=int, default=32, help='Number of hidden units in GCN')
parser.add_argument('--dropout', type=float, default=0.7, help='Dropout rate for GCN layers and input features')
# NASA specific hyperparameters
parser.add_argument('--alpha', type=float, default=1.0, help='Weight for L_CR loss component')
parser.add_argument('--temp', type=float, default=0.5, help='Temperature for sharpening pseudo-labels')
parser.add_argument('--nr_prob', type=float, default=0.5, help='Probability for Neighbor Replacement')

parser.add_argument('--gpu_id', type=int, default=0, help='GPU ID to use, -1 for CPU')

cli_args = parser.parse_args()
print("Arguments:", cli_args)
main(cli_args)
2 changes: 2 additions & 0 deletions gammagl/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@
from .adagad import PreModel, ReModel
from .dyfss import MoeSSL,VGAE,Discriminator,InnerProductDecoder
from .egt import EGTModel
from .nasa_gcn import NASA_GCN

__all__ = [
'HeCo',
Expand Down Expand Up @@ -149,6 +150,7 @@
'Discriminator',
'InnerProductDecoder',
'EGTModel',
'NASA_GCN',
]

classes = __all__
87 changes: 87 additions & 0 deletions gammagl/models/nasa_gcn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
import tensorlayerx as tlx
from tensorlayerx.nn import Module as Model
from gammagl.layers.conv import GCNConv
from gammagl.utils import mask_to_index
# from gammagl.mpops import unsorted_segment_mean

class NASA_GCN(Model):
def __init__(self, feature_dim, hidden_dim, num_classes, dropout_rate, temp, alpha, name=None):
super().__init__(name=name)
self.conv1 = GCNConv(in_channels=feature_dim, out_channels=hidden_dim)
self.conv2 = GCNConv(in_channels=hidden_dim, out_channels=num_classes)
self.dropout = tlx.layers.Dropout(p=dropout_rate)
self.temp = temp
self.alpha = alpha
self.elu = tlx.layers.ELU()

def _compute_L_CR(self, aug_graph_edge_index, aug_pred_softmax, num_nodes):
src, dst = aug_graph_edge_index[0], aug_graph_edge_index[1]

# 1. Compute average of neighbor predictions (ŷ_i)
aug_pred_softmax_src = tlx.gather(aug_pred_softmax, src)
#avg_pred = unsorted_segment_mean(data=aug_pred_softmax_src, segment_ids=dst, num_segments=num_nodes)
avg_pred = tlx.ops.unsorted_segment_mean(aug_pred_softmax_src, dst, num_segments=num_nodes)

# Handle nodes with no incoming messages if unsorted_segment_mean results in NaN/Inf
#avg_pred = tlx.where(tlx.is_finite(avg_pred), avg_pred, tlx.zeros_like(avg_pred))
is_inf_avg_pred = tlx.is_inf(avg_pred)
is_nan_avg_pred = tlx.is_nan(avg_pred)
avg_pred_finite_mask = tlx.logical_and(
tlx.logical_not(is_inf_avg_pred),
tlx.logical_not(is_nan_avg_pred)
)
avg_pred = tlx.where(avg_pred_finite_mask, avg_pred, tlx.zeros_like(avg_pred))

# 2. Sharpening (p_i)
avg_pred_eps = avg_pred + 1e-12
pow_avg_pred = tlx.pow(avg_pred_eps, 1.0 / self.temp)
sharp_pseudo_labels = pow_avg_pred / (tlx.reduce_sum(pow_avg_pred, axis=1, keepdims=True) + 1e-12)
sharp_pseudo_labels_detached = tlx.ops.stop_gradient(sharp_pseudo_labels)

# 3. Compute KL Divergence Loss
p_dst_detached = tlx.gather(sharp_pseudo_labels_detached, dst)
q_src_softmax = tlx.gather(aug_pred_softmax, src)

log_p_dst_detached = tlx.log(p_dst_detached + 1e-12)
log_q_src_softmax = tlx.log(q_src_softmax + 1e-12)

kl_div_elements = p_dst_detached * (log_p_dst_detached - log_q_src_softmax)
kl_div_per_edge = tlx.reduce_sum(kl_div_elements, axis=1)

if tlx.get_tensor_shape(kl_div_per_edge)[0] > 0:
loss_cr = tlx.reduce_mean(kl_div_per_edge)
else:
loss_cr = tlx.convert_to_tensor(0.0, dtype=tlx.float32)

return loss_cr

def forward(self, x, edge_index, edge_weight=None, num_nodes=None):
h = self.dropout(x)
h = self.conv1(h, edge_index, edge_weight=edge_weight, num_nodes=num_nodes)
h = self.elu(h)
h = self.dropout(h)
logits = self.conv2(h, edge_index, edge_weight=edge_weight, num_nodes=num_nodes)
return logits

def compute_nasa_loss(self, output_logits_aug, augmented_graph, original_graph_labels, original_graph_train_mask):
train_indices = mask_to_index(original_graph_train_mask)
gathered_logits_aug = tlx.gather(output_logits_aug, train_indices)
gathered_labels = tlx.gather(original_graph_labels, train_indices)

loss_ce = tlx.losses.softmax_cross_entropy_with_logits(gathered_logits_aug, gathered_labels)

aug_pred_softmax = tlx.softmax(output_logits_aug, axis=-1)
loss_cr = self._compute_L_CR(
augmented_graph.edge_index,
aug_pred_softmax,
tlx.get_tensor_shape(augmented_graph.x)[0]
)

total_loss = loss_ce + self.alpha * loss_cr
return total_loss

def predict(self, x, edge_index, edge_weight=None, num_nodes=None):
h = self.conv1(x, edge_index, edge_weight=edge_weight, num_nodes=num_nodes)
h = self.elu(h)
logits = self.conv2(h, edge_index, edge_weight=edge_weight, num_nodes=num_nodes)
return logits
5 changes: 3 additions & 2 deletions gammagl/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from .random_link_split import RandomLinkSplit
from .vgae_pre import mask_test_edges, sparse_to_tuple
from .svd_feature_reduction import SVDFeatureReduction
from .nr_augmentor import NR_Augmentor

__all__ = [
'BaseTransform',
Expand All @@ -18,8 +19,8 @@
'RandomLinkSplit',
'mask_test_edges',
'sparse_to_tuple',
'SVDFeatureReduction'

'SVDFeatureReduction',
'NR_Augmentor'
]

classes = __all__
94 changes: 94 additions & 0 deletions gammagl/transforms/nr_augmentor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
# gammagl/transforms/nr_augmentor.py

import tensorlayerx as tlx
import numpy as np
import random
from gammagl.data import Graph
from gammagl.utils import to_undirected, coalesce

class NR_Augmentor:
def __init__(self, probability=0.5):
"""
Neighbor Replacement Augmentor for GammaGL.
This augmentor implements the NeighborReplace (NR) strategy from the paper
"Regularizing Graph Neural Networks via Consistency-Diversity Graph Augmentations".

Args:
probability (float): Probability of replacing a 1-hop neighbor with a 2-hop neighbor.
"""
self.probability = probability

def _get_1hop_neighbors_dict(self, edge_index, num_nodes):
"""
Creates an adjacency list dictionary {node_idx: [neighbor1, neighbor2,...]}
from edge_index, containing unique 1-hop neighbors (excluding self-loops).
"""
adj = {i: set() for i in range(num_nodes)}
src_nodes, dst_nodes = tlx.convert_to_numpy(edge_index[0]), tlx.convert_to_numpy(edge_index[1])
for i in range(len(src_nodes)):
u, v = src_nodes[i], dst_nodes[i]
if u != v:
adj[u].add(v)
return {node_idx: list(neighbors) for node_idx, neighbors in adj.items()}

def __call__(self, original_graph: Graph) -> Graph:
"""
Applies Neighbor Replacement augmentation to the input graph.
Making the class callable is a common pattern for transforms.

Args:
original_graph (gammagl.data.Graph): The original graph object.

Returns:
gammagl.data.Graph: The augmented graph object.
"""
num_nodes = original_graph.num_nodes

adj_1hop_dict = self._get_1hop_neighbors_dict(original_graph.edge_index, num_nodes)

new_edge_src_list = []
new_edge_dst_list = []

for u_node_idx in range(num_nodes):
current_1hop_neighbors_of_u = adj_1hop_dict[u_node_idx]
if not current_1hop_neighbors_of_u:
continue

for v_neighbor_idx in current_1hop_neighbors_of_u:
if random.random() < self.probability:
potential_2hop_neighbors_of_u_via_v = [
vv_node for vv_node in adj_1hop_dict.get(v_neighbor_idx, [])
if vv_node != u_node_idx and vv_node != v_neighbor_idx
]

if potential_2hop_neighbors_of_u_via_v:
vv_chosen = random.choice(potential_2hop_neighbors_of_u_via_v)
new_edge_src_list.append(u_node_idx)
new_edge_dst_list.append(vv_chosen)
else:
new_edge_src_list.append(u_node_idx)
new_edge_dst_list.append(v_neighbor_idx)
else:
new_edge_src_list.append(u_node_idx)
new_edge_dst_list.append(v_neighbor_idx)

if not new_edge_src_list:
aug_edge_index = tlx.convert_to_tensor(np.array([[],[]]), dtype=tlx.int64)
else:
temp_edge_index = tlx.stack([
tlx.convert_to_tensor(new_edge_src_list, dtype=tlx.int64),
tlx.convert_to_tensor(new_edge_dst_list, dtype=tlx.int64)
])

undirected_temp_edge_index = to_undirected(temp_edge_index, num_nodes=num_nodes)
aug_edge_index = coalesce(undirected_temp_edge_index, num_nodes=num_nodes)

aug_graph = Graph(x=original_graph.x, edge_index=aug_edge_index, num_nodes=num_nodes)

# Copy other essential attributes if they exist
if hasattr(original_graph, 'y'): aug_graph.y = original_graph.y
if hasattr(original_graph, 'train_mask'): aug_graph.train_mask = original_graph.train_mask
if hasattr(original_graph, 'val_mask'): aug_graph.val_mask = original_graph.val_mask
if hasattr(original_graph, 'test_mask'): aug_graph.test_mask = original_graph.test_mask

return aug_graph
Loading