Skip to content

PearAnne/cgtl-2

Repository files navigation

HC-GAE: The Hierarchical Cluster-based Graph Auto-Encoder for Graph Representation Learning

项目简介

HC-GAE (Hierarchical Cluster-based Graph Auto-Encoder) 是一个用于图表示学习的图神经网络模型。该模型通过分层聚类方法实现图的自编码器架构,能够学习图结构数据的低维表示。

主要特点:

  • 基于分层聚类的图池化机制
  • 自编码器架构用于无监督图表示学习
  • 支持多种图数据集(Cora, Citeseer等)
  • 通过节点分类任务评估学习到的表示质量

架构组成

  1. 编码器-解码器结构
  • 编码器 (local_gcn.py:294-315): 通过图卷积层逐层对图进行粗化(coarsening)
  • 分层池化 (local_gcn.py:207-209): 使用 SubGraphs 模块在多个层级上将节点划分为簇
  • 解码器 (local_gcn.py:317-340): 从粗化的表示重构原始图结构
  1. 损失函数 (local_gcn.py:343-404)
  • 二元交叉熵(BCE): 用于图重构
  • JS散度: 用于邻接矩阵分布匹配

主要文件

  • main.py: 训练流程、数据加载、模型实例化、SVM节点分类评估
  • local_gcn.py: HCGAE核心模型实现
  • subgraph.py: 子图构建模块和专用图卷积
  • model.py: 基线GNN模型(GCN, SAGE, GIN)
  • utils.py: 边遮蔽、边分割、评估指标等工具

文件说明

main.py

  • 主程序入口文件
  • 包含训练和评估流程
  • 处理数据加载、模型训练、测试等完整流程
  • 支持命令行参数配置

local_gcn.py

  • HCGAE模型核心实现
  • 实现了分层聚类图自编码器
  • 包含图卷积层、编码器-解码器架构
  • 定义了损失函数和前向传播逻辑

model.py

  • 各种基线GNN模型实现
  • 包含GCN、SAGE等经典图神经网络模型
  • 提供不同的图卷积层实现

subgraph.py

  • 子图构建和处理模块
  • 实现子图构造和分配功能
  • 包含子图GCN卷积层实现

utils.py

  • 工具函数集合
  • 数据处理和评估相关的辅助函数
  • 包含边分割、评估指标计算等功能

依赖环境

  • PyTorch
  • PyTorch Geometric
  • NumPy
  • Scikit-learn
  • OGB (Open Graph Benchmark)

运行方式

基本运行命令

python main.py --use_sage=HCGAE --dataset=Cora --epochs=100

主要参数说明

  • --use_sage: 模型类型 (HCGAE)
  • --dataset: 数据集 (Cora, Citeseer)
  • --epochs: 训练轮数
  • --mask_ratio: 训练时掩码边的比例
  • --hidden_channels: 隐藏层维度大小

示例命令

# 在Cora数据集上训练HCGAE模型
python main.py --use_sage=HCGAE --dataset=Cora --epochs=100
python main.py --use_sage=HCGAE --dataset=Cora --lr=1e-2 --epochs=50 --batch_size=1024
python main.py --use_sage=HCGAE --dataset=Citeseer --lr=1e-2 --epochs=50 --batch_size=1024

# 在Citeseer数据集上训练HCGAE模型
python main.py --use_sage=HCGAE --dataset=Citeseer --epochs=100

模型架构

HCGAE模型采用编码器-解码器架构:

  1. 编码器: 通过图卷积层逐步聚合节点信息,生成图的低维表示
  2. 分层聚类池化: 使用分层聚类方法对图进行池化操作,保留重要结构信息
  3. 解码器: 从低维表示重构原始图结构
  4. 损失函数: 通过重构损失和JS散度损失优化模型参数

评估方法

模型通过节点分类任务评估学习到的图表示质量:

  • 使用SVM分类器进行节点分类
  • 采用5折交叉验证
  • 评估指标包括准确率、F1分数等

About

Continuous-Graph-Transfer-Learning

Resources

Stars

Watchers

Forks

Releases

No releases published

Packages

 
 
 

Contributors