diff --git a/examples/hgat/hgat_trainer.py b/examples/hgat/hgat_trainer.py new file mode 100644 index 000000000..7e4f351b8 --- /dev/null +++ b/examples/hgat/hgat_trainer.py @@ -0,0 +1,183 @@ +# !/usr/bin/env python3 +# -*- coding:utf-8 -*- + +# @Time : 2022/04/16 25:16 +# @Author : Jingyu Huang +# @FileName: hgat_trainer.py +import os +os.environ['CUDA_VISIBLE_DEVICES'] = '2' +os.environ['TL_BACKEND'] = 'tensorflow' +# 0:Output all; 1:Filter out INFO; 2:Filter out INFO and WARNING; 3:Filter out INFO, WARNING, and ERROR +import numpy as np +import argparse +import tensorlayerx as tlx +import gammagl.transforms as T +from gammagl.datasets import AGNews,IMDB, OHSUMED, Twitter + + +from gammagl.models import HGATModel; +from gammagl.utils import mask_to_index, set_device +from tensorlayerx.model import TrainOneStep, WithLoss + +class SemiSpvzLoss(WithLoss): + def __init__(self, net, loss_fn): + super(SemiSpvzLoss, self).__init__(backbone=net, loss_fn=loss_fn) + + def forward(self, data, y, node_tpye): + logits = self.backbone_network(data['x_dict'], data['edge_index_dict'], data['num_nodes_dict']) + train_logits = tlx.gather(logits[node_tpye], data['train_idx']) + train_y = tlx.gather(data['y'], data['train_idx']) + loss = self._loss_fn(train_logits, train_y) + return loss + + +def calculate_acc(logits, y, metrics): + """ + Args: + logits: node logits + y: node labels + metrics: tensorlayerx.metrics + + Returns: + rst + """ + + metrics.update(logits, y) + rst = metrics.result() + metrics.reset() + return rst + + +def main(args): + # NOTE: ONLY IMDB DATASET + # If you want to execute HAN on other dataset (e.g. ACM), + # you will be needed to init `metepaths` + # and set `movie` string with proper values. + # path = osp.join(osp.dirname(osp.realpath(__file__)), '../IMDB') + if(args.dataset=="IMDB"): + dataset = IMDB(args.dataset_path) + graph = dataset[0] + print(graph) + y = graph['movie'].y + node_type = 'movie' + + + + if(args.dataset=="agnews"): + dataset = AGNews(args.dataset_path) + graph = dataset[0] + print(graph) + y = graph['text'].y + node_type = 'text' + + + if(args.dataset=="ohsumed"): + dataset = OHSUMED(args.dataset_path) + graph = dataset[0] + print(graph) + y = graph['documents'].y + node_type = 'documents' + + + + if(args.dataset=="twitter"): + dataset = Twitter(args.dataset_path) + graph = dataset[0] + print(graph) + y = graph['twitter'].y + node_type = 'twitter' + + + + # for mindspore, it should be passed into node indices + train_idx = mask_to_index(graph[node_type].train_mask) + test_idx = mask_to_index(graph[node_type].test_mask) + val_idx = mask_to_index(graph[node_type].val_mask) + node_type_list = graph.metadata()[0] + in_channel = {} + num_nodes_dict = {} + for node_type in node_type_list: + in_channel[node_type]=graph.x_dict[node_type].shape[1] + num_nodes_dict[node_type]=graph.x_dict[node_type].shape[0] + + + net = HGATModel( + in_channels=in_channel, + out_channels=len(np.unique(graph.y.cpu())), # graph.num_classes, + metadata=graph.metadata(), + drop_rate=args.drop_rate, + hidden_channels=args.hidden_dim, + name = 'hgat', + ) + + + optimizer = tlx.optimizers.Adam(lr=args.lr, weight_decay=args.l2_coef) + metrics = tlx.metrics.Accuracy() + train_weights = net.trainable_weights + + loss_func = tlx.losses.softmax_cross_entropy_with_logits + semi_spvz_loss = SemiSpvzLoss(net, loss_func) + train_one_step = TrainOneStep(semi_spvz_loss, optimizer, train_weights) + + data = { + "x_dict": graph.x_dict, + "y": y, + "edge_index_dict": graph.edge_index_dict, + "train_idx": train_idx, + "test_idx": test_idx, + "val_idx": val_idx, + "num_nodes_dict": num_nodes_dict, + } + print(np.unique(y.cpu())) + best_val_acc = 0 + + for epoch in range(args.n_epoch): + net.set_train() + train_loss = train_one_step(data, y, node_type) + net.set_eval() + + logits = net(data['x_dict'], data['edge_index_dict'], data['num_nodes_dict']) + val_logits = tlx.gather(logits[node_type], data['val_idx']) + val_y = tlx.gather(data['y'], data['val_idx']) + val_acc = calculate_acc(val_logits, val_y, metrics) + + print("Epoch [{:0>3d}] ".format(epoch + 1) + + " train_loss: {:.4f}".format(train_loss.item()) + # + " train_acc: {:.4f}".format(train_acc) + + " val_acc: {:.4f}".format(val_acc)) + + # save best model on evaluation set + if val_acc > best_val_acc: + best_val_acc = val_acc + net.save_weights(args.best_model_path + net.name + ".npz", format='npz_dict') + + net.load_weights(args.best_model_path + net.name + ".npz", format='npz_dict') + net.set_eval() + logits = net(data['x_dict'], data['edge_index_dict'], data['num_nodes_dict']) + test_logits = tlx.gather(logits[node_type], data['test_idx']) + test_y = tlx.gather(data['y'], data['test_idx']) + test_acc = calculate_acc(test_logits, test_y, metrics) + print("Test acc: {:.4f}".format(test_acc)) + + +if __name__ == '__main__': + # parameters setting + parser = argparse.ArgumentParser() + parser.add_argument("--lr", type=float, default=0.005, help="learnin rate") + parser.add_argument("--n_epoch", type=int, default=100, help="number of epoch") + parser.add_argument("--hidden_dim", type=int, default=64, help="dimention of hidden layers") + parser.add_argument("--l2_coef", type=float, default=1e-3, help="l2 loss coeficient") + parser.add_argument("--heads", type=int, default=8, help="number of heads for stablization") + parser.add_argument("--drop_rate", type=float, default=0.6, help="drop_rate") + parser.add_argument("--gpu", type=int, default=0, help="gpu id") + parser.add_argument("--dataset_path", type=str, default=r'', help="path to save dataset") + parser.add_argument('--dataset', type=str, default='IMDB', help='dataset') + parser.add_argument("--best_model_path", type=str, default=r'./', help="path to save best model") + + args = parser.parse_args() + if args.gpu >= 0: + tlx.set_device("GPU", args.gpu) + else: + tlx.set_device("CPU") + + main(args) diff --git a/examples/hgat/readme.md b/examples/hgat/readme.md new file mode 100644 index 000000000..9044b357e --- /dev/null +++ b/examples/hgat/readme.md @@ -0,0 +1,29 @@ +# Heterogeneous Graph Attention Network (HGAT) + +This is an implementation of `HAN` for heterogeneous graphs. + +- Paper link: [https://aclanthology.org/D19-1488/](https://aclanthology.org/D19-1488/) +- Author's code repo: [https://github.com/BUPT-GAMMA/HGAT](https://github.com/BUPT-GAMMA/HGAT). Note that the original code is + implemented with Tensorflow for the paper. + +## Usage + +`python hgat_trainer.py` for reproducing HGAT's work on IMDB. + + + +## Performance + + + +| Dataset |Paper(80% training) | Our(tf) | Our(th) | Our(pd) | +| ------- | ------------------ | ------- | ------- |-------- | +| AGNews | 72.10 | 63.80 | | | +| Ohsumed | 42.68 | 25.82 | | | +| Twitter | 63.21 | 61.06 | | | +| IMDB | | 57.71 | | | + +```bash +TL_BACKEND="tensorflow" python3 hgat_trainer.py --n_epoch 100 --lr 0.01 --l2_coef 0.0001 --drop_rate 0.8 + +``` \ No newline at end of file diff --git a/gammagl/datasets/__init__.py b/gammagl/datasets/__init__.py index 2dc598bba..bd64fd92d 100644 --- a/gammagl/datasets/__init__.py +++ b/gammagl/datasets/__init__.py @@ -22,6 +22,9 @@ from .molecule_net import MoleculeNet from .acm4heco import ACM4HeCo from .yelp import Yelp +from .agnews import AGNews +from .ohsumed import OHSUMED +from .twitter import Twitter __all__ = [ 'ACM4HeCo', @@ -46,7 +49,10 @@ 'WikiCS', 'MoleculeNet', 'NGSIM_US_101', - 'Yelp' + 'Yelp', + 'AGNews', + 'OHSUMED', + 'Twitter' ] classes = __all__ diff --git a/gammagl/datasets/agnews.py b/gammagl/datasets/agnews.py new file mode 100644 index 000000000..bbd969385 --- /dev/null +++ b/gammagl/datasets/agnews.py @@ -0,0 +1,81 @@ +import os +import os.path as osp +from itertools import product +from typing import Callable, List, Optional + +import numpy as np +import scipy.sparse as sp +import tensorlayerx as tlx + +from gammagl.data import (HeteroGraph, InMemoryDataset, download_url, + extract_zip) + +class AGNews(InMemoryDataset): + r"""AGNews dataset processed for use in GNN models.""" + + url = 'https://www.dropbox.com/scl/fi/m809k1xdqzf0rhdmb83jf/agnews.zip?rlkey=wrz4by7f4tvtsdte2scuiec5k&st=s3ty36oi&dl=1' + + def __init__(self, root: str = None, transform: Optional[Callable] = None, + pre_transform: Optional[Callable] = None, force_reload: bool = False): + super().__init__(root, transform, pre_transform, force_reload=force_reload) + self.data, self.slices = self.load_data(self.processed_paths[0]) + + @property + def raw_file_names(self) -> List[str]: + return [ + 'adjM.npz', 'features_0.npz', 'features_1.npz', 'features_2.npz', + 'labels.npy', 'train_val_test_idx.npz' + ] + + @property + def processed_file_names(self) -> str: + return tlx.BACKEND + 'data.pt' + + def download(self): + path = download_url(self.url, self.raw_dir) + extract_zip(path, self.raw_dir) + os.remove(path) + + def process(self): + data = HeteroGraph() + + node_types = ['text', 'topic', 'entity'] + for i, node_type in enumerate(node_types): + x = sp.load_npz(osp.join(self.raw_dir, f'features_{i}.npz')) + data[node_type].x = tlx.convert_to_tensor(x.todense(), dtype=tlx.float32) + y = np.load(osp.join(self.raw_dir, 'labels.npy')) + y = np.argmax(y,axis=1) + data['text'].y = tlx.convert_to_tensor(y, dtype=tlx.int64) + + split = np.load(osp.join(self.raw_dir, 'train_val_test_idx.npz')) + for name in ['train', 'val', 'test']: + idx = split[f'{name}_idx'] + mask = np.zeros(data['text'].num_nodes, dtype=np.bool_) + mask[idx] = True + data['text'][f'{name}_mask'] = tlx.convert_to_tensor(mask, dtype=tlx.bool) + + + s = {} + N_m = data['text'].num_nodes + N_d = data['topic'].num_nodes + N_a = data['entity'].num_nodes + s['text'] = (0, N_m) + s['topic'] = (N_m, N_m + N_d) + s['entity'] = (N_m + N_d, N_m + N_d + N_a) + + A = sp.load_npz(osp.join(self.raw_dir, 'adjM.npz')).tocsr() + for src, dst in product(node_types, node_types): + A_sub = A[s[src][0]:s[src][1], s[dst][0]:s[dst][1]].tocoo() + if A_sub.nnz > 0: + row = tlx.convert_to_tensor(A_sub.row, dtype=tlx.int64) + col = tlx.convert_to_tensor(A_sub.col, dtype=tlx.int64) + data[src, dst].edge_index = tlx.stack([row, col], axis=0) + print(src+"____"+dst) + + if self.pre_transform is not None: + data = self.pre_transform(data) + + self.save_data(self.collate([data]), self.processed_paths[0]) + + def __repr__(self) -> str: + return f'{self.__class__.__name__}()' \ No newline at end of file diff --git a/gammagl/datasets/ohsumed.py b/gammagl/datasets/ohsumed.py new file mode 100644 index 000000000..ed90e211f --- /dev/null +++ b/gammagl/datasets/ohsumed.py @@ -0,0 +1,82 @@ +import os +import os.path as osp +from itertools import product +from typing import Callable, List, Optional + +import numpy as np +import scipy.sparse as sp +import tensorlayerx as tlx + +from gammagl.data import (HeteroGraph, InMemoryDataset, download_url, + extract_zip) + +class OHSUMED(InMemoryDataset): + r"""AGNews dataset processed for use in GNN models.""" + + url = 'https://www.dropbox.com/scl/fi/di4u2apxat4v6oibq8j0q/ohsumed.zip?rlkey=hkbleedkz8bqw9p40y5zws425&st=yxj4kyzr&dl=1' + + def __init__(self, root: str = None, transform: Optional[Callable] = None, + pre_transform: Optional[Callable] = None, force_reload: bool = False): + super().__init__(root, transform, pre_transform, force_reload=force_reload) + self.data, self.slices = self.load_data(self.processed_paths[0]) + + @property + def raw_file_names(self) -> List[str]: + return [ + 'adjM.npz', 'features_0.npz', 'features_1.npz', 'features_2.npz', + 'labels.npy', 'train_val_test_idx.npz' + ] + + @property + def processed_file_names(self) -> str: + return tlx.BACKEND + 'data.pt' + + def download(self): + path = download_url(self.url, self.raw_dir) + extract_zip(path, self.raw_dir) + os.remove(path) + + def process(self): + data = HeteroGraph() + + node_types = ['documents', 'topics', 'words'] + for i, node_type in enumerate(node_types): + x = sp.load_npz(osp.join(self.raw_dir, f'features_{i}.npz')) + data[node_type].x = tlx.convert_to_tensor(x.todense(), dtype=tlx.float32) + + y = np.load(osp.join(self.raw_dir, 'labels.npy')) + y = np.argmax(y,axis=1) + data['documents'].y = tlx.convert_to_tensor(y, dtype=tlx.int64) + + split = np.load(osp.join(self.raw_dir, 'train_val_test_idx.npz')) + for name in ['train', 'val', 'test']: + idx = split[f'{name}_idx'] + mask = np.zeros(data['documents'].num_nodes, dtype=np.bool_) + mask[idx] = True + data['documents'][f'{name}_mask'] = tlx.convert_to_tensor(mask, dtype=tlx.bool) + + + s = {} + N_m = data['documents'].num_nodes + N_d = data['topics'].num_nodes + N_a = data['words'].num_nodes + s['documents'] = (0, N_m) + s['topics'] = (N_m, N_m + N_d) + s['words'] = (N_m + N_d, N_m + N_d + N_a) + + A = sp.load_npz(osp.join(self.raw_dir, 'adjM.npz')).tocsr() + for src, dst in product(node_types, node_types): + A_sub = A[s[src][0]:s[src][1], s[dst][0]:s[dst][1]].tocoo() + if A_sub.nnz > 0: + row = tlx.convert_to_tensor(A_sub.row, dtype=tlx.int64) + col = tlx.convert_to_tensor(A_sub.col, dtype=tlx.int64) + data[src, dst].edge_index = tlx.stack([row, col], axis=0) + print(src+"____"+dst) + + if self.pre_transform is not None: + data = self.pre_transform(data) + + self.save_data(self.collate([data]), self.processed_paths[0]) + + def __repr__(self) -> str: + return f'{self.__class__.__name__}()' \ No newline at end of file diff --git a/gammagl/datasets/twitter.py b/gammagl/datasets/twitter.py new file mode 100644 index 000000000..704d5d800 --- /dev/null +++ b/gammagl/datasets/twitter.py @@ -0,0 +1,82 @@ +import os +import os.path as osp +from itertools import product +from typing import Callable, List, Optional + +import numpy as np +import scipy.sparse as sp +import tensorlayerx as tlx + +from gammagl.data import (HeteroGraph, InMemoryDataset, download_url, + extract_zip) + +class Twitter(InMemoryDataset): + r"""AGNews dataset processed for use in GNN models.""" + + url = 'https://www.dropbox.com/scl/fi/uqiglprqpz6rytaoc4k5p/twitter.zip?rlkey=b8f9cltuus36hr0haqt0988o8&st=0cqjc9qx&dl=1' + + def __init__(self, root: str = None, transform: Optional[Callable] = None, + pre_transform: Optional[Callable] = None, force_reload: bool = False): + super().__init__(root, transform, pre_transform, force_reload=force_reload) + self.data, self.slices = self.load_data(self.processed_paths[0]) + + @property + def raw_file_names(self) -> List[str]: + return [ + 'adjM.npz', 'features_0.npz', 'features_1.npz', 'features_2.npz', + 'labels.npy', 'train_val_test_idx.npz' + ] + + @property + def processed_file_names(self) -> str: + return tlx.BACKEND + 'data.pt' + + def download(self): + path = download_url(self.url, self.raw_dir) + extract_zip(path, self.raw_dir) + os.remove(path) + + def process(self): + data = HeteroGraph() + + node_types = ['twitter', 'topics', 'entity'] + for i, node_type in enumerate(node_types): + x = sp.load_npz(osp.join(self.raw_dir, f'features_{i}.npz')) + data[node_type].x = tlx.convert_to_tensor(x.todense(), dtype=tlx.float32) + + y = np.load(osp.join(self.raw_dir, 'labels.npy')) + y = np.argmax(y,axis=1) + data['twitter'].y = tlx.convert_to_tensor(y, dtype=tlx.int64) + + split = np.load(osp.join(self.raw_dir, 'train_val_test_idx.npz')) + for name in ['train', 'val', 'test']: + idx = split[f'{name}_idx'] + mask = np.zeros(data['twitter'].num_nodes, dtype=np.bool_) + mask[idx] = True + data['twitter'][f'{name}_mask'] = tlx.convert_to_tensor(mask, dtype=tlx.bool) + + + s = {} + N_m = data['twitter'].num_nodes + N_d = data['topics'].num_nodes + N_a = data['entity'].num_nodes + s['twitter'] = (0, N_m) + s['topics'] = (N_m, N_m + N_d) + s['entity'] = (N_m + N_d, N_m + N_d + N_a) + + A = sp.load_npz(osp.join(self.raw_dir, 'adjM.npz')).tocsr() + for src, dst in product(node_types, node_types): + A_sub = A[s[src][0]:s[src][1], s[dst][0]:s[dst][1]].tocoo() + if A_sub.nnz > 0: + row = tlx.convert_to_tensor(A_sub.row, dtype=tlx.int64) + col = tlx.convert_to_tensor(A_sub.col, dtype=tlx.int64) + data[src, dst].edge_index = tlx.stack([row, col], axis=0) + print(src+"____"+dst) + + if self.pre_transform is not None: + data = self.pre_transform(data) + + self.save_data(self.collate([data]), self.processed_paths[0]) + + def __repr__(self) -> str: + return f'{self.__class__.__name__}()' \ No newline at end of file diff --git a/gammagl/layers/conv/__init__.py b/gammagl/layers/conv/__init__.py index 0162a6cb9..4efb5167e 100644 --- a/gammagl/layers/conv/__init__.py +++ b/gammagl/layers/conv/__init__.py @@ -33,6 +33,7 @@ from .magcl_conv import MAGCLConv from .fusedgat_conv import FusedGATConv from .hid_conv import Hid_conv +from .hgat_conv import HGATConv __all__ = [ 'MessagePassing', 'GCNConv', @@ -68,7 +69,8 @@ 'MAGCLConv', 'FusedGATConv', 'Hid_conv', - 'HEATlayer' + 'HEATlayer', + 'HGAT_conv' ] classes = __all__ diff --git a/gammagl/layers/conv/hgat_conv.py b/gammagl/layers/conv/hgat_conv.py new file mode 100644 index 000000000..e89bd946e --- /dev/null +++ b/gammagl/layers/conv/hgat_conv.py @@ -0,0 +1,121 @@ +import tensorlayerx as tlx +from tensorlayerx.nn import ModuleDict +from gammagl.layers.conv import MessagePassing +from gammagl.utils import segment_softmax +from gammagl.mpops import unsorted_segment_sum +class HGATConv(MessagePassing): + def __init__(self, + in_channels, + out_channels, + metadata, + negative_slope=0.2, + drop_rate=0.5): + super().__init__() + + self.in_channels = in_channels + self.out_channels = out_channels + self.metadata = metadata + self.negetive_slop = negative_slope + self.dropout = tlx.layers.Dropout(drop_rate) + + for node_type in self.metadata[0]: + self.heterlinear[node_type] = tlx.layers.Linear(out_features=out_channels, + in_features=in_channels[node_type], + W_init='xavier_uniform') + + + self.Linear_dict_l = ModuleDict({}) + self.Linear_dict_r = ModuleDict({}) + + l_list = [] + r_list = [] + + for edge_type in metadata[1]: + src_type,_, dst_type = edge_type + if src_type not in r_list: + r_list.append(src_type) + if dst_type not in l_list: + l_list.append(dst_type) + + for dst_type in l_list: + print(dst_type) + self.Linear_dict_l[dst_type] = tlx.layers.Linear(out_features=1, + in_features=in_channels, + W_init='xavier_uniform') + for src_type in r_list: + print(src_type) + self.Linear_dict_r[src_type]= tlx.layers.Linear(out_features=1, + in_features=in_channels, + W_init='xavier_uniform') + + + self.nodeAttention = tlx.layers.Linear(out_features=1, + in_features=in_channels*2, + W_init='xavier_uniform') + + + self.leakyReLu = tlx.LeakyReLU(negative_slope=negative_slope) + + + def forward(self, x_dict, edge_index_dict, num_nodes_dict): + + for node_type, x_node in x_dict.items(): + x_dict[node_type]= self.dropout(self.heterlinear[node_type](x_node)) + + edge_pattern_dict={} + for node_type, value in x_dict.items(): + edge_pattern_dict[node_type]={} + for edge_type, edge_index in edge_index_dict.items(): + src_type, _, dst_type = edge_type + if(dst_type==node_type): + edge_pattern_dict[node_type][edge_type]=edge_index + + + + # There are several edge_type in each pattern, and their scr_type are the same + alpha_pattern_dict={} + beta_pattern_dict = {} + for node_type, pattern_dict in edge_pattern_dict.items(): + alpha_pattern_dict[node_type]={} + beta_pattern_dict[node_type]={} + Attention_value_dict = {} + for edge_type, edge_index in pattern_dict.items(): + src_type, _, dst_type = edge_type + src = edge_index[0,:] + dst = edge_index[1,:] + + message = unsorted_segment_sum(tlx.gather(x_dict[src_type],src),dst,x_dict[dst_type].shape[0]) + + h_l = self.Linear_dict_l[dst_type](x_dict[dst_type]) + h_r = self.Linear_dict_r[dst_type](message) + Type_Attention_Value = h_l + h_r # N values, N equals the number of edges + Type_Attention_Value = self.leakyReLu(Type_Attention_Value) + + Type_Attention_Value = tlx.exp(Type_Attention_Value) + Attention_value_dict[edge_type]=Type_Attention_Value + Attention_value_list = [value for value in Attention_value_dict.values()] + + Summation = tlx.reduce_sum(tlx.stack(Attention_value_list,axis=0),axis=0) + + + for edge_type, edge_index in pattern_dict.items(): + alpha_pattern_dict[node_type][edge_type] = Attention_value_dict[edge_type]/Summation # N values, N equals the number of edges + + out_dict={} + for node_type, pattern_dict in edge_pattern_dict.items(): + message_list= [] + for edge_type, edge_index in pattern_dict.items(): + alpha = alpha_pattern_dict[node_type][edge_type] + src_type, _, dst_type = edge_type + src = edge_index[0,:] + dst = edge_index[1,:] + # Use the broadcast mechanism alpha(N,1), followed by (N,hidden_dim*2), where N denotes the number of edges. + value = self.nodeAttention(tlx.gather(alpha,dst)*tlx.concat([tlx.gather(x_dict[dst_type],dst),tlx.gather(x_dict[src_type],src)],axis=1)) + value = self.leakyReLu(value) + value = segment_softmax(value,dst,num_segments=None) + beta_pattern_dict[node_type][edge_type] = value + message_list.append(unsorted_segment_sum(value*tlx.gather(x_dict[dst_type],dst),dst,x_dict[dst_type].shape[0])) + out_dict[node_type]=tlx.reduce_sum(tlx.stack(message_list,axis=0),axis=0) + + return out_dict + diff --git a/gammagl/models/__init__.py b/gammagl/models/__init__.py index c08b85c01..49037dd12 100644 --- a/gammagl/models/__init__.py +++ b/gammagl/models/__init__.py @@ -57,6 +57,7 @@ from .fusedgat import FusedGATModel from .hid_net import Hid_net from .gnnlfhf import GNNLFHFModel +from .hgat import HGATModel __all__ = [ 'HeCo', @@ -117,7 +118,8 @@ 'FusedGATModel', 'hid_net', 'HEAT', - 'GNNLFHFModel' + 'GNNLFHFModel', + 'HGATModel' ] classes = __all__ diff --git a/gammagl/models/hgat.py b/gammagl/models/hgat.py new file mode 100644 index 000000000..77d8cfb90 --- /dev/null +++ b/gammagl/models/hgat.py @@ -0,0 +1,50 @@ +import tensorlayerx as tlx +from tensorlayerx.nn import ModuleDict +from gammagl.layers.conv import MessagePassing, GCNConv, HGATConv,GATConv +from gammagl.utils import segment_softmax, to_homograph, to_heterograph +from gammagl.mpops import unsorted_segment_sum + +class HGATModel(tlx.nn.Module): + def __init__(self, + in_channels, + out_channels, + metadata, + drop_rate, + hidden_channels=128, + name=None): + + + super().__init__(name=name) + + self.hgat_conv = HGATConv(in_channels=hidden_channels, + out_channels=hidden_channels, + metadata=metadata, + drop_rate=drop_rate) + + # This two layers is equal to another HGAT Layer, as I followed the code of the paper. + # It seems like after the first HGAT layer, the node can be considered as a same type + # If we use another HGATConv again, there will be error, indicating: Grad is none + # This may because the design of the HGAT layer, if it is placed in the last layer, some of the parameter won't influence the output + self.gcn_conv = GCNConv(hidden_channels, + hidden_channels) + self.gat_conv = GATConv(in_channels=hidden_channels, + out_channels=hidden_channels, + dropout_rate=drop_rate) + self.linear = tlx.nn.Linear(out_features=out_channels, + in_features=hidden_channels) + self.softmax = tlx.nn.activation.Softmax() + + def forward(self, x_dict, edge_index_dict, num_nodes_dict): + out = self.hgat_conv(x_dict, edge_index_dict, num_nodes_dict) + + out, edge_index, edge_value = to_homograph(out,edge_index_dict,num_nodes_dict,None) + out = self.gcn_conv(out, edge_index) + + out = self.gat_conv(out, edge_index) + + out_dict = to_heterograph(out,edge_index,num_nodes_dict) + for node_type, _ in x_dict.items(): + out_dict[node_type] = self.softmax(self.linear(out_dict[node_type])) + + return out_dict + \ No newline at end of file diff --git a/gammagl/utils/__init__.py b/gammagl/utils/__init__.py index f57d8d09b..251f4e5d0 100644 --- a/gammagl/utils/__init__.py +++ b/gammagl/utils/__init__.py @@ -20,6 +20,8 @@ from .shortest_path import shortest_path_distance, batched_shortest_path_distance from .get_split import get_train_val_test_split from .get_laplacian import get_laplacian +from .homo_heter_mutual_convert import to_homograph, to_heterograph, add_num + __all__ = [ 'calc_A_norm_hat', @@ -46,7 +48,10 @@ 'shortest_path_distance', 'batched_shortest_path_distance', 'get_train_val_test_split', - 'get_laplacian' + 'get_laplacian', + 'to_homograph', + 'to_heterograph', + 'add_num' ] diff --git a/gammagl/utils/homo_heter_mutual_convert.py b/gammagl/utils/homo_heter_mutual_convert.py new file mode 100644 index 000000000..348d24ad4 --- /dev/null +++ b/gammagl/utils/homo_heter_mutual_convert.py @@ -0,0 +1,78 @@ +import tensorlayerx as tlx + +def to_homograph(x_dict ,edge_index_dict, num_nodes_dict, edge_value_dict): + + node_type_num = len(num_nodes_dict) + x_list = list(x_dict) + node_type_list = list(num_nodes_dict.keys()) + node_num_list = list(num_nodes_dict.values()) + add_num_list = add_num(node_num_list) + x_feature = x_dict[node_type_list[0]] + if(tlx.backend=="tensorflow"): + x_feature = tlx.stack(x_feature) + for i in node_type_list[1:]: + if(tlx.backend=="tensorflow"): + x_feature = tlx.concat((x_feature,tlx.stack(x_dict[i])), axis=0) + else: + x_feature = tlx.concat((x_feature,x_dict[i]), axis=0) + edge_type_list = list(edge_index_dict.keys()) + edge_index = edge_index_dict[edge_type_list[0]] + _,num = edge_index.shape + node_src = tlx.slice(edge_index,[0,0],[1,num]) + node_dst = tlx.slice(edge_index,[1,0],[1,num]) + node_src = tlx.add(node_src,add_num_list[node_type_list.index(edge_type_list[0][0])]) + node_dst = tlx.add(node_dst,add_num_list[node_type_list.index(edge_type_list[0][2])]) + if(tlx.backend=="tensorflow"): + node_src = tlx.stack(node_src) + node_dst = tlx.stack(node_dst) + edge_index = tlx.concat((node_src,node_dst),axis=0) + if(edge_value_dict!=None): + edge_value = edge_value_dict[edge_type_list[0]] + if(tlx.backend=="tensorflow"): + edge_value[0]=tlx.stack(edge_value[0]) + edge_value = edge_value[0] + for i in edge_type_list[1:]: + edge_index_tem = edge_index_dict[i] + _,num = edge_index_tem.shape + node_src = tlx.slice(edge_index_tem,[0,0],[1,num]) + node_dst = tlx.slice(edge_index_tem,[1,0],[1,num]) + node_src = tlx.add(node_src,add_num_list[node_type_list.index(i[0])]) + node_dst = tlx.add(node_dst,add_num_list[node_type_list.index(i[2])]) + if(tlx.backend=="tensorflow"): + node_src = tlx.stack(node_src) + node_dst = tlx.stack(node_dst) + edge_index_tem = tlx.concat((node_src,node_dst),axis=0) + if(tlx.backend=="tensorflow"): + edge_index_tem = tlx.stack(edge_index_tem) + edge_index = tlx.concat((edge_index,edge_index_tem),axis=1) + if(edge_value_dict!= None): + if(tlx.backend=="tensorflow"): + edge_value = tlx.concat((edge_value,tlx.stack(edge_value_dict[i][0])),axis=0) + else: + edge_value = tlx.concat((edge_value,edge_value_dict[i][0]),axis=0) + else: + edge_value = None + return [x_feature, edge_index, edge_value] + +def to_heterograph(x_value, edge_index, num_node_dict): + + x_dict = {} + node_type_list = list(num_node_dict.keys()) + node_num_list = list(num_node_dict.values()) + node_index_list = add_num(node_num_list) + type_num = len(node_type_list) + for node_type in node_type_list: + index = node_type_list.index(node_type) + if(index == type_num-1): + x_dict[node_type] = x_value[node_index_list[index]:,:] + else: + x_dict[node_type] = x_value[node_index_list[index]:node_index_list[index+1],:] + return x_dict + +def add_num(int_list): + out = [] + num = len(int_list) + out.append(0) + for i in range(num-1): + out.append(out[-1]+int_list[i]) + return out \ No newline at end of file