class CoBFormerLoss(WithLoss):
    """Loss wrapper that delegates to the backbone's own ``loss`` method.

    ``loss_fn`` is handed to ``WithLoss`` unchanged but is never called here;
    the objective comes from ``self.backbone_network.loss``.
    """

    def __init__(self, net, loss_fn):
        super().__init__(backbone=net, loss_fn=loss_fn)

    def forward(self, data, y):
        """Run the backbone on ``data`` and return its training loss.

        Args:
            data: mapping that provides the ``'x'``, ``'patch'``,
                ``'edge_index'``, ``'y'`` and ``'train_mask'`` entries.
            y: accepted for signature compatibility but not read; labels are
                taken from ``data['y']``.

        Returns:
            Whatever ``self.backbone_network.loss`` returns for the two
            predictions produced by the backbone.
        """
        net = self.backbone_network
        first_pred, second_pred = net(data['x'], data['patch'], data['edge_index'])
        # The backbone combines both predictions into one objective itself.
        return net.loss(first_pred, second_pred, data['y'], data['train_mask'])
def _contiguous_membership(num_nodes, n_patches):
    """Assign node ids to ``n_patches`` contiguous, near-equal id ranges."""
    import numpy as np

    # Integer scaling spreads the remainder over the patches instead of
    # piling it onto the last one, which keeps the padded width small.
    return (np.arange(num_nodes, dtype=np.int64) * n_patches) // num_nodes


def _metis_membership(graph, n_patches):
    """Return one partition id per node, computed with METIS.

    Raises ImportError when ``networkx`` or ``metis`` cannot be imported.
    """
    import numpy as np
    import networkx as nx
    import metis

    num_nodes = graph.num_nodes
    if num_nodes < n_patches:
        # More patches than nodes: METIS is skipped and ids are drawn at random.
        return np.random.randint(0, n_patches, num_nodes)

    edge_index = graph.edge_index
    # Move the tensor to host memory first when it exposes ``.cpu()``.
    if hasattr(edge_index, 'cpu'):
        edge_index = edge_index.cpu().numpy()
    else:
        edge_index = edge_index.numpy()

    G = nx.Graph()
    G.add_nodes_from(range(num_nodes))
    # ``tolist()`` yields plain Python ints, matching the node ids added above.
    G.add_edges_from(zip(edge_index[0].tolist(), edge_index[1].tolist()))

    _, membership = metis.part_graph(G, n_patches, recursive=True)
    return np.asarray(membership)


def _pad_patches(membership, n_patches, num_nodes):
    """Group node ids by partition id and right-pad every group to one width.

    The pad value is ``num_nodes - 1``, as in the code this replaces.
    """
    import numpy as np

    # ``membership`` may arrive as a plain list; ``list == int`` is a single
    # ``False`` rather than an element-wise mask, which would leave every
    # patch empty. Converting first makes the comparison element-wise.
    membership = np.asarray(membership)
    pad_id = num_nodes - 1
    patch = [np.flatnonzero(membership == i).tolist() for i in range(n_patches)]
    width = max(1, max(len(nodes) for nodes in patch))
    return [nodes + [pad_id] * (width - len(nodes)) for nodes in patch]


def metis_partition(graph, n_patches):
    """Partition the nodes of ``graph`` into ``n_patches`` padded patches.

    METIS (through ``networkx`` and ``metis``) is used when both packages can
    be imported; otherwise nodes are split into contiguous id ranges.

    Args:
        graph: object exposing ``num_nodes`` and ``edge_index``.
        n_patches: number of patches (rows of the result); must be >= 1.

    Returns:
        An int32 tensor with ``n_patches`` rows. Each row lists the node ids
        of one patch, right-padded with ``num_nodes - 1`` to a common width.

    Raises:
        ValueError: if ``n_patches`` or ``graph.num_nodes`` is smaller than 1.
    """
    num_nodes = graph.num_nodes
    if n_patches < 1:
        raise ValueError('n_patches must be >= 1, got {}'.format(n_patches))
    if num_nodes < 1:
        raise ValueError('graph must contain at least one node')

    try:
        membership = _metis_membership(graph, n_patches)
        source = "METIS"
    except ImportError:
        # Fallback to simple partitioning if METIS or NetworkX is not available
        print("METIS or NetworkX not available, using simple partitioning")
        membership = _contiguous_membership(num_nodes, n_patches)
        source = "Fallback"

    padded = _pad_patches(membership, n_patches, num_nodes)
    patch_result = tlx.convert_to_tensor(padded, dtype=tlx.int32)

    print(f"{source} patch tensor shape: {tlx.get_tensor_shape(patch_result)}")
    print(f"{source} patch tensor dtype: {patch_result.dtype}")
    return patch_result
+ n_head=args.heads, + gcn_layers=args.gcn_layers, + alpha=args.alpha, + tau=args.tau, + dropout1=args.drop_rate1, + dropout2=args.drop_rate2, + gcn_use_bn=args.gcn_use_bn, + use_patch_attn=args.use_patch_attn, + name="CoBFormer") + + optimizer = tlx.optimizers.Adam(lr=args.lr, weight_decay=args.l2_coef) + metrics = tlx.metrics.Accuracy() + train_weights = net.trainable_weights + + loss_func = CoBFormerLoss(net, None) # Loss is computed within the model + train_one_step = TrainOneStep(loss_func, optimizer, train_weights) + + data = { + "x": graph.x, + "y": graph.y, + "edge_index": edge_index, + "patch": patch, + "train_idx": train_idx, + "test_idx": test_idx, + "val_idx": val_idx, + "num_nodes": graph.num_nodes, + "train_mask": graph.train_mask, + "test_mask": graph.test_mask, + "val_mask": graph.val_mask, + } + + best_val_acc = 0 + for epoch in range(args.n_epoch): + net.set_train() + train_loss = train_one_step(data, graph.y) + net.set_eval() + + # Get predictions from both models + pred1, pred2 = net(data['x'], data['patch'], data['edge_index']) + + # Evaluate first model (GCN) + val_logits1 = tlx.gather(pred1, data['val_idx']) + val_y = tlx.gather(data['y'], data['val_idx']) + num_classes = tlx.get_tensor_shape(data['y'])[-1] if len(tlx.get_tensor_shape(data['y'])) > 1 else int(tlx.reduce_max(data['y'])) + 1 + # Add numerical stability check + try: + val_micro_f1_1 = calculate_f1(val_logits1, val_y, num_classes, 'micro') + val_macro_f1_1 = calculate_f1(val_logits1, val_y, num_classes, 'macro') + except Exception as e: + print(f"Warning: Error calculating F1 for GCN model: {e}") + val_micro_f1_1 = 0.0 + val_macro_f1_1 = 0.0 + + # Evaluate second model (Transformer) + val_logits2 = tlx.gather(pred2, data['val_idx']) + try: + val_micro_f1_2 = calculate_f1(val_logits2, val_y, num_classes, 'micro') + val_macro_f1_2 = calculate_f1(val_logits2, val_y, num_classes, 'macro') + except Exception as e: + print(f"Warning: Error calculating F1 for Transformer model: {e}") + 
val_micro_f1_2 = 0.0 + val_macro_f1_2 = 0.0 + + print("Epoch [{:0>3d}] ".format(epoch + 1) + " train loss: {:.4f}".format(train_loss.item()) + " val micro_f1_1: {:.4f}".format(val_micro_f1_1) + " val macro_f1_1: {:.4f}".format(val_macro_f1_1) + " val micro_f1_2: {:.4f}".format(val_micro_f1_2) + " val macro_f1_2: {:.4f}".format(val_macro_f1_2)) + + # save best model on evaluation set (using the better of the two models) + val_acc = max(val_micro_f1_1, val_micro_f1_2) + if val_acc > best_val_acc: + best_val_acc = val_acc + net.save_weights(args.best_model_path + net.name + ".npz", format='npz_dict') + + net.load_weights(args.best_model_path + net.name + ".npz", format='npz_dict') + net.set_eval() + + # Get final predictions + pred1, pred2 = net(data['x'], data['patch'], data['edge_index']) + + # Test first model (GCN) + test_logits1 = tlx.gather(pred1, data['test_idx']) + test_y = tlx.gather(data['y'], data['test_idx']) + num_classes = tlx.get_tensor_shape(data['y'])[-1] if len(tlx.get_tensor_shape(data['y'])) > 1 else int(tlx.reduce_max(data['y'])) + 1 + try: + test_micro_f1_1 = calculate_f1(test_logits1, test_y, num_classes, 'micro') + test_macro_f1_1 = calculate_f1(test_logits1, test_y, num_classes, 'macro') + except Exception as e: + print(f"Warning: Error calculating test F1 for GCN model: {e}") + test_micro_f1_1 = 0.0 + test_macro_f1_1 = 0.0 + + # Test second model (Transformer) + test_logits2 = tlx.gather(pred2, data['test_idx']) + try: + test_micro_f1_2 = calculate_f1(test_logits2, test_y, num_classes, 'micro') + test_macro_f1_2 = calculate_f1(test_logits2, test_y, num_classes, 'macro') + except Exception as e: + print(f"Warning: Error calculating test F1 for Transformer model: {e}") + test_micro_f1_2 = 0.0 + test_macro_f1_2 = 0.0 + + print("Test Micro-F1 (GCN): {:.4f}".format(test_micro_f1_1)) + print("Test Macro-F1 (GCN): {:.4f}".format(test_macro_f1_1)) + print("Test Micro-F1 (Transformer): {:.4f}".format(test_micro_f1_2)) + print("Test Macro-F1 
(Transformer): {:.4f}".format(test_macro_f1_2)) + + +if __name__ == "__main__": + # parameters setting + parser = argparse.ArgumentParser() + parser.add_argument("--lr", type=float, default=0.005, help="learning rate") + parser.add_argument("--n_epoch", type=int, default=500, help="number of epoch") + parser.add_argument("--hidden_dim", type=int, default=64, help="dimension of hidden layers") + parser.add_argument("--layers", type=int, default=2, help="number of transformer layers") + parser.add_argument("--heads", type=int, default=1, help="number of attention heads") + parser.add_argument("--gcn_layers", type=int, default=2, help="number of GCN layers") + parser.add_argument("--alpha", type=float, default=0.7, help="loss balancing parameter") + parser.add_argument("--tau", type=float, default=0.3, help="temperature parameter") + parser.add_argument("--drop_rate1", type=float, default=0.5, help="dropout rate 1") + parser.add_argument("--drop_rate2", type=float, default=0.1, help="dropout rate 2") + parser.add_argument("--l2_coef", type=float, default=1e-3, help="l2 loss coefficient") + parser.add_argument('--dataset', type=str, default='Pubmed', help='dataset') + parser.add_argument("--dataset_path", type=str, default=r'', help="path to save dataset") + parser.add_argument("--best_model_path", type=str, default=r'./', help="path to save best model") + parser.add_argument("--self_loops", type=int, default=1, help="number of graph self-loop") + parser.add_argument("--n_patches", type=int, default=224, help="number of patches for partitioning") + parser.add_argument("--gpu", type=int, default=0) + parser.add_argument('--gcn_use_bn', action='store_true', help='gcn use batch norm') + parser.add_argument('--use_patch_attn', action='store_true', help='transformer use patch attention') + parser.add_argument("--seed", type=int, default=123, help="random seed for reproducibility") + + args = parser.parse_args() + + # Set random seed for reproducibility + set_seed(args.seed) 
+ + if args.gpu >= 0: + tlx.set_device("GPU", args.gpu) + else: + tlx.set_device("CPU") + + main(args) \ No newline at end of file diff --git a/examples/cobformer/readme.md b/examples/cobformer/readme.md new file mode 100644 index 00000000..b65a4909 --- /dev/null +++ b/examples/cobformer/readme.md @@ -0,0 +1,29 @@ +# CoBFormer: Less is More - On the Over-Globalizing Problem in Graph Transformers + +- Paper link: [https://arxiv.org/abs/2405.14786](https://arxiv.org/abs/2405.14786) +- Author's code repo: [https://github.com/Graph-COM/CoBFormer](https://github.com/Graph-COM/CoBFormer) Note that the original code is + implemented with PyTorch for the paper. + +## Dataset Statics + +| Dataset | # Nodes | # Edges | # Classes | +|----------|---------|---------|-----------| +| Pubmed | 19,717 | 88,651 | 3 | + +Refer to [Planetoid](https://gammagl.readthedocs.io/en/latest/api/gammagl.datasets.html#gammagl.datasets.Planetoid). + +## How to run examples +```bash +TL_BACKEND="torch" python cobformer_trainer.py --dataset Pubmed --hidden_dim 64 --layers 1 --heads 1 --gcn_layers 2 --lr 0.005 --l2_coef 0.001 --drop_rate1 0.5 --drop_rate2 0.1 --alpha 0.7 --tau 0.3 --gpu 1 --n_epoch 200 --seed 42 +``` + +## Performance + +| Dataset | Metrics | Author's Code (CoB_G) | Author's Code (CoB_T) | GAMMAGL's Code (CoB_G) | GAMMAGL's Code (CoB_T) | +|:-------:|:-------:|:---------------------:|:---------------------:|:----------------------:|:----------------------:| +| PubMed | Mi-F1 | 80.52 | 81.42 | 86.20 | 64.30 | +| PubMed | Ma-F1 | 80.02 | 81.04 | 85.24 | 62.07 | + + + + diff --git a/gammagl/models/cobformer.py b/gammagl/models/cobformer.py new file mode 100644 index 00000000..d7bb730e --- /dev/null +++ b/gammagl/models/cobformer.py @@ -0,0 +1,306 @@ +import tensorlayerx as tlx +import tensorlayerx.nn as nn +from gammagl.layers.conv import GCNConv +from gammagl.utils import mask_to_index + +class ScaledDotProductAttention(nn.Module): + ''' Scaled Dot-Product Attention ''' + + def 
class MultiHeadAttention(nn.Module):
    ''' Multi-Head Attention module with a residual connection around it.

    Queries, keys and values are projected by bias-free linear layers, split
    into ``n_head`` heads of ``channels // n_head`` features each, passed
    through ``ScaledDotProductAttention`` and merged again by ``fc``.
    '''

    def __init__(self, n_head, channels, dropout=0.1):
        super(MultiHeadAttention, self).__init__()

        # Fail early: a non-divisible width would otherwise only surface as a
        # reshape error inside forward().
        if n_head <= 0 or channels % n_head != 0:
            raise ValueError(
                'channels ({}) must be divisible by a positive n_head ({})'.format(channels, n_head))

        self.n_head = n_head
        self.channels = channels
        d_k = channels // n_head

        self.w_qs = nn.Linear(in_features=channels, out_features=channels, b_init=None)
        self.w_ks = nn.Linear(in_features=channels, out_features=channels, b_init=None)
        self.w_vs = nn.Linear(in_features=channels, out_features=channels, b_init=None)
        self.fc = nn.Linear(in_features=channels, out_features=channels, b_init=None)

        self.attention = ScaledDotProductAttention(temperature=d_k ** 0.5)

        # NOTE(review): this layer is constructed but forward() never applies
        # it; kept so the module's attributes stay unchanged. Confirm whether
        # dropout after ``fc`` was intended.
        self.dropout = nn.Dropout(dropout)

    def forward(self, q, k, v, mask=None):
        ''' Attend from ``q`` to ``k``/``v`` and add the input ``q`` back.

        Args:
            q, k, v: tensors whose first two axes are read as (batch, length)
                and whose remaining size equals ``channels``.
            mask: optional mask; an axis is inserted at position 1 so it
                broadcasts over heads before being passed to the attention.

        Returns:
            ``(out, attn)`` where ``out`` has the leading shape of ``q`` and
            ``attn`` is the second value returned by the attention module.
        '''
        n_head = self.n_head
        d_head = self.channels // n_head
        B_q, N_q = tlx.get_tensor_shape(q)[:2]
        B_k, N_k = tlx.get_tensor_shape(k)[:2]
        B_v, N_v = tlx.get_tensor_shape(v)[:2]

        residual = q

        # Project, then separate the heads: (B, N, h, d_head).
        q = tlx.reshape(self.w_qs(q), (B_q, N_q, n_head, d_head))
        k = tlx.reshape(self.w_ks(k), (B_k, N_k, n_head, d_head))
        v = tlx.reshape(self.w_vs(v), (B_v, N_v, n_head, d_head))

        # Move the head axis forward for the attention product: (B, h, N, d_head).
        q = tlx.transpose(q, (0, 2, 1, 3))
        k = tlx.transpose(k, (0, 2, 1, 3))
        v = tlx.transpose(v, (0, 2, 1, 3))

        # For head axis broadcasting.
        if mask is not None:
            mask = tlx.expand_dims(mask, 1)

        q, attn = self.attention(q, k, v, mask=mask)

        # Move the head axis back and concatenate the heads: (B, N, h * d_head).
        q = tlx.reshape(tlx.transpose(q, (0, 2, 1, 3)), (B_q, N_q, -1))
        q = self.fc(q)
        q = q + residual

        return q, attn
+ self.patch_norm = nn.LayerNorm(normalized_shape=channels) + self.patch_transformer = MultiHeadAttention(n_head, channels, dropout) + self.node_ffn = FFN(channels, dropout) + self.patch_ffn = FFN(channels, dropout) + self.fuse_lin = nn.Linear(in_features=2 * channels, out_features=channels) + self.use_patch_attn = use_patch_attn + self.attn = None + + def forward(self, x, patch, attn_mask=None, need_attn=False): + x = self.node_norm(x) + + # More efficient patch processing using gather and scatter operations + patch_shape = tlx.get_tensor_shape(patch) + + # Flatten patch indices for efficient gathering + patch_flat = tlx.reshape(patch, [-1]) + + # Gather all patch features at once using the correct gather interface + patch_features_flat = tlx.gather(x, patch_flat) + + # Reshape back to [num_patches, max_patch_size, feature_dim] + patch_x = tlx.reshape(patch_features_flat, [patch_shape[0], patch_shape[1], -1]) + + patch_x, attn = self.node_transformer(patch_x, patch_x, patch_x, attn_mask) + patch_x = self.node_ffn(patch_x) + + if need_attn: + # Note: This is a simplified version without the complex attention tracking + self.attn = attn + + if self.use_patch_attn: + # Mean pooling across patch dimension + patch_mean = tlx.reduce_mean(patch_x, axis=1, keepdims=False) + p = tlx.expand_dims(self.patch_norm(patch_mean), axis=0) + p, _ = self.patch_transformer(p, p, p) + p = tlx.transpose(self.patch_ffn(p), (1, 0, 2)) + + # Repeat to match patch dimensions + p = tlx.tile(p, (1, tlx.get_tensor_shape(patch_x)[1], 1)) + z = tlx.concat([patch_x, p], axis=2) + patch_x = tlx.relu(self.fuse_lin(z)) + patch_x + + # More efficient scatter back to original tensor + # Flatten patch_x for scattering + patch_x_flat = tlx.reshape(patch_x, [-1, tlx.get_tensor_shape(patch_x)[-1]]) + + # Flatten patch indices + patch_flat = tlx.reshape(patch, [-1]) + + # Update x at the patch node positions using tensor_scatter_nd_update + x = tlx.tensor_scatter_nd_update(x, tlx.expand_dims(patch_flat, 
class BGA(nn.Module):
    """Attribute encoder, a stack of ``BGALayer`` blocks, and a linear classifier."""

    def __init__(self, num_nodes: int, in_channels: int, hidden_channels: int, out_channels: int,
                 layers: int, n_head: int, use_patch_attn=True, dropout1=0.5, dropout2=0.1):
        super().__init__()
        self.layers = layers
        self.n_head = n_head
        self.num_nodes = num_nodes
        self.dropout = nn.Dropout(dropout1)
        self.attribute_encoder = SimpleFFN(in_channels, hidden_channels)
        self.BGALayers = nn.ModuleList()
        for _ in range(layers):
            block = BGALayer(n_head, hidden_channels, num_nodes, use_patch_attn, dropout=dropout2)
            self.BGALayers.append(block)
        self.classifier = nn.Linear(in_features=hidden_channels, out_features=out_channels)

    def forward(self, x, patch, need_attn=False):
        """Encode ``x``, run every BGA layer over ``patch`` and classify.

        Args:
            x: node features fed to the attribute encoder.
            patch: integer tensor of node ids per patch; entries equal to
                ``num_nodes - 1`` are treated as padding by the mask below.
            need_attn: forwarded to every layer.

        Returns:
            The classifier output computed from the last layer's features.
        """
        # NOTE(review): the pad value coincides with the id of the last real
        # node, so that node is masked as well - confirm this is intended.
        pad_id = self.num_nodes - 1
        keep = tlx.cast(patch != pad_id, dtype=tlx.float32)
        # Outer product of the keep-vector with itself, per patch: an entry is
        # non-zero only where both positions hold non-padding ids.
        pairwise = tlx.matmul(tlx.expand_dims(keep, axis=-1), tlx.expand_dims(keep, axis=1))
        attn_mask = tlx.cast(pairwise, dtype=tlx.int32)

        hidden = self.attribute_encoder(x)
        for idx in range(self.layers):
            hidden = self.BGALayers[idx](hidden, patch, attn_mask, need_attn)
        return self.classifier(self.dropout(hidden))
class CoBFormerModel(nn.Module):
    """Joint model that runs a GCN branch and a BGA branch on the same input.

    ``forward`` returns the logits of both branches; ``loss`` combines a
    supervised term on the masked nodes with a mutual term between the two
    branches on the remaining nodes.
    """

    def __init__(self, num_nodes: int, in_channels: int, hidden_channels: int, out_channels: int,
                 activation=None, gcn_layers: int = 2, gcn_type: int = 1, layers: int = 1,
                 n_head: int = 4, dropout1=0.5, dropout2=0.1,
                 alpha=0.8, tau=0.5, gcn_use_bn=False, use_patch_attn=True, name=None):
        super(CoBFormerModel, self).__init__(name=name)
        self.alpha = alpha
        self.tau = tau
        self.layers = layers
        self.n_head = n_head
        self.num_nodes = num_nodes
        self.activation = activation
        self.dropout = nn.Dropout(dropout1)

        # ``gcn_type`` is accepted for interface compatibility only: every
        # value builds the same GCN.
        self.gcn = GCN(in_channels, hidden_channels, out_channels, activation, k=gcn_layers, use_bn=gcn_use_bn)

        self.bga = BGA(num_nodes, in_channels, hidden_channels, out_channels, layers, n_head,
                       use_patch_attn, dropout1, dropout2)

    def forward(self, x, patch, edge_index, need_attn=False):
        """Return ``(gcn_logits, bga_logits)`` for node features ``x``."""
        z1 = self.gcn(x, edge_index)
        z2 = self.bga(x, patch, need_attn)
        return z1, z2

    @staticmethod
    def _mean_cross_entropy(logits, target, index):
        """Mean softmax cross-entropy over the rows selected by ``index``.

        Returns ``0.0`` when ``index`` is empty so an empty subset adds
        nothing to the objective instead of producing NaN.
        """
        if tlx.get_tensor_shape(index)[0] == 0:
            return 0.0
        picked_logits = tlx.gather(logits, index)
        picked_target = tlx.gather(target, index)
        # reduce_mean leaves an already reduced scalar unchanged and averages
        # a per-row result, so either reduction behaviour gives the subset mean.
        return tlx.reduce_mean(
            tlx.losses.softmax_cross_entropy_with_logits(picked_logits, picked_target))

    def loss(self, pred1, pred2, label, mask):
        """Combine supervised and mutual cross-entropy terms.

        Args:
            pred1, pred2: logits of the two branches, one row per node.
            label: class indices, or one-hot rows (converted with argmax).
            mask: per-node mask; non-zero entries mark supervised nodes.

        Returns:
            ``alpha * (l1 + l2) + (1 - alpha) * (l3 + l4)`` where ``l1``/``l2``
            are cross-entropies against ``label`` on masked nodes and
            ``l3``/``l4`` are cross-entropies of each branch against the other
            branch's softmax on the remaining nodes, with logits scaled by
            ``tau``.
        """
        # Convert one-hot label to indices if needed
        if len(tlx.get_tensor_shape(label)) > 1 and tlx.get_tensor_shape(label)[-1] > 1:
            label = tlx.argmax(label, axis=-1)

        # Rows are selected BEFORE the loss is computed. Multiplying the loss
        # output by the mask only restricts it if that output is per-node.
        # NOTE(review): tlx.losses.softmax_cross_entropy_with_logits appears to
        # reduce over the batch by default, which would make the old masked sum
        # average over every node, held-out labels included - confirm against
        # the installed TensorLayerX version.
        mask = tlx.cast(mask, dtype=tlx.float32)
        labeled_idx = mask_to_index(mask > 0.5)
        unlabeled_idx = mask_to_index(mask <= 0.5)

        # Supervised terms on the labeled nodes.
        l1 = self._mean_cross_entropy(pred1, label, labeled_idx)
        l2 = self._mean_cross_entropy(pred2, label, labeled_idx)

        pred1_scaled = pred1 * self.tau
        pred2_scaled = pred2 * self.tau

        # For unlabeled nodes, use softmax of the other prediction
        pred2_softmax = tlx.softmax(pred2_scaled, axis=-1)
        pred1_softmax = tlx.softmax(pred1_scaled, axis=-1)

        l3 = self._mean_cross_entropy(pred1_scaled, pred2_softmax, unlabeled_idx)
        l4 = self._mean_cross_entropy(pred2_scaled, pred1_softmax, unlabeled_idx)

        loss = self.alpha * (l1 + l2) + (1 - self.alpha) * (l3 + l4)
        return loss
b/examples/cobformer/cobformer_trainer.py @@ -2,7 +2,7 @@ # -*- encoding: utf-8 -*- """ @File : cobformer_trainer.py -@Time : 2024/09/04 2:33:00 +@Time : 2025/09/04 2:33:00 @Author : lzd """