diff --git a/examples/cobformer/Data/data_utils.py b/examples/cobformer/Data/data_utils.py new file mode 100644 index 00000000..b9a4cbe9 --- /dev/null +++ b/examples/cobformer/Data/data_utils.py @@ -0,0 +1,167 @@ +import torch +import torch.nn.functional as F +import numpy as np +import networkx as nx +import metis +from torch_geometric.utils.num_nodes import maybe_num_nodes +from torch_geometric.utils import add_remaining_self_loops +from torch_geometric.utils import scatter + + +def rand_train_test_idx(label, train_prop=.5, valid_prop=.25, ignore_negative=True): + """ randomly splits label into train/valid/test splits """ + if ignore_negative: + labeled_nodes = torch.where(label != -1)[0] + else: + labeled_nodes = label + + n = labeled_nodes.shape[0] + train_num = int(n * train_prop) + valid_num = int(n * valid_prop) + + perm = torch.as_tensor(np.random.permutation(n)) + + train_indices = perm[:train_num] + val_indices = perm[train_num:train_num + valid_num] + test_indices = perm[train_num + valid_num:] + + if not ignore_negative: + return train_indices, val_indices, test_indices + + train_idx = labeled_nodes[train_indices] + valid_idx = labeled_nodes[val_indices] + test_idx = labeled_nodes[test_indices] + train_mask = torch.zeros_like(label, dtype=torch.bool) + train_mask[train_idx] = True + valid_mask = torch.zeros_like(label, dtype=torch.bool) + valid_mask[valid_idx] = True + test_mask = torch.zeros_like(label, dtype=torch.bool) + test_mask[test_idx] = True + + return train_mask, valid_mask, test_mask + + +def load_fixed_splits(data_dir, dataset, name, protocol): + splits_lst = [] + if name in ['Cora', 'CiteSeer', 'PubMed', 'ogbn-arxiv', 'ogbn-products'] and protocol == 'semi': + splits = {} + splits['train'] = torch.as_tensor(dataset.train_mask) + splits['valid'] = torch.as_tensor(dataset.valid_mask) + splits['test'] = torch.as_tensor(dataset.test_mask) + splits['train'] = F.pad(splits['train'], [0, 1]) + splits['valid'] = F.pad(splits['valid'], [0, 1]) + splits['test'] = F.pad(splits['test'], [0, 1]) + splits_lst.append(splits) + elif name in ['film', 'deezer']: + for i in range(10): + splits_file_path = '{}/{}'.format(data_dir, name) + '_split_50_25_' + str(i) + '.npz' + splits = {} + with np.load(splits_file_path) as splits_file: + splits['train'] = torch.BoolTensor(splits_file['train_mask']) + splits['valid'] = torch.BoolTensor(splits_file['val_mask']) + splits['test'] = torch.BoolTensor(splits_file['test_mask']) + splits['train'] = F.pad(splits['train'], [0, 1]) + splits['valid'] = F.pad(splits['valid'], [0, 1]) + splits['test'] = F.pad(splits['test'], [0, 1]) + splits_lst.append(splits) + else: + raise NotImplementedError + + return splits_lst + + +def class_rand_splits(label, label_num_per_class): + train_idx, non_train_idx = [], [] + idx = torch.arange(label.shape[0]) + class_list = label.squeeze().unique() + valid_num, test_num = 500, 1000 + for i in range(class_list.shape[0]): + c_i = class_list[i] + idx_i = idx[label.squeeze() == c_i] + n_i = idx_i.shape[0] + rand_idx = idx_i[torch.randperm(n_i)] + train_idx += rand_idx[:label_num_per_class].tolist() + non_train_idx += rand_idx[label_num_per_class:].tolist() + train_idx = torch.as_tensor(train_idx) + non_train_idx = torch.as_tensor(non_train_idx) + non_train_idx = non_train_idx[torch.randperm(non_train_idx.shape[0])] + valid_idx, test_idx = non_train_idx[:valid_num], non_train_idx[valid_num:valid_num + test_num] + train_mask = torch.zeros_like(label, dtype=torch.bool) + train_mask[train_idx] = True + valid_mask = torch.zeros_like(label, dtype=torch.bool) + valid_mask[valid_idx] = True + test_mask = torch.zeros_like(label, dtype=torch.bool) + test_mask[test_idx] = True + + return train_mask, valid_mask, test_mask + + +def metis_partition(g, n_patches=50): + if g['num_nodes'] < n_patches: + membership = torch.randperm(n_patches) + else: + # data augmentation + adjlist = g['edge_index'].t() + G = nx.Graph() + G.add_nodes_from(np.arange(g['num_nodes'])) + G.add_edges_from(adjlist.tolist()) + # metis partition + cuts, membership = metis.part_graph(G, n_patches, recursive=True) + + assert len(membership) >= g['num_nodes'] + membership = torch.tensor(membership[:g['num_nodes']]) + + + patch = [] + max_patch_size = -1 + for i in range(n_patches): + patch.append(list()) + patch[-1] = torch.where(membership == i)[0].tolist() + max_patch_size = max(max_patch_size, len(patch[-1])) + + for i in range(len(patch)): + l = len(patch[i]) + if l < max_patch_size: + patch[i] += [g['num_nodes']] * (max_patch_size - l) + + patch = torch.tensor(patch) + + return patch + + +def patch2batch(g, node_mask): + patches = node_mask.shape[0] + max_patch_size = node_mask.sum(dim=1).max() + all_nodes = torch.tensor(range(g['num_nodes'])) + batch_node_list = list() + for i in range(patches): + patch_nodes = all_nodes[node_mask[i, :]].tolist() + l = len(patch_nodes) + if l < max_patch_size: + patch_nodes += [g['num_nodes']] * (max_patch_size - l) + batch_node_list.append(patch_nodes) + + batch = torch.tensor(batch_node_list) + return batch + + +def norm(edge_index, num_nodes=None, edge_weight=None): + num_nodes = maybe_num_nodes(edge_index, num_nodes) + fill_value = 1. + + edge_index, edge_weight = add_remaining_self_loops( + edge_index, edge_weight, fill_value, num_nodes) + + if edge_weight is None: + edge_weight = torch.ones((edge_index.size(1),), dtype=torch.float, + device=edge_index.device) + + row, col = edge_index[0], edge_index[1] + idx = col + deg = scatter(edge_weight, idx, dim=0, dim_size=num_nodes, reduce='sum') + deg_inv_sqrt = deg.pow_(-0.5) + deg_inv_sqrt.masked_fill_(deg_inv_sqrt == float('inf'), 0) + edge_weight = deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col] + + # return torch.sparse_coo_tensor(edge_index, edge_weight) + return edge_index, edge_weight diff --git a/examples/cobformer/Data/get_batch_data.py b/examples/cobformer/Data/get_batch_data.py new file mode 100644 index 00000000..de2565ec --- /dev/null +++ b/examples/cobformer/Data/get_batch_data.py @@ -0,0 +1,124 @@ +import torch +import torch.nn.functional as F +import numpy as np +import torch_geometric +import networkx as nx +import metis +from torch_geometric.datasets import Planetoid +from torch_geometric.utils.num_nodes import maybe_num_nodes +from torch_geometric.utils import add_remaining_self_loops, to_undirected, remove_self_loops, add_self_loops +from torch_geometric.utils import scatter +import torch_geometric.transforms as T +from infomap import Infomap +from Data.data_utils import * +from ogb.nodeproppred import NodePropPredDataset +import scipy +import scipy.io +import scipy.sparse as sp +import os + + +class NCDataset(object): + def __init__(self, name): + """ + based off of ogb NodePropPredDataset + https://github.com/snap-stanford/ogb/blob/master/ogb/nodeproppred/dataset.py + Gives torch tensors instead of numpy arrays + - name (str): name of the dataset + - root (str): root directory to store the dataset folder + - meta_dict: dictionary that stores all the meta-information about data. Default is None, + but when something is passed, it uses its information. Useful for debugging for external contributers. + + Usage after construction: + + split_idx = dataset.get_idx_split() + train_idx, valid_idx, test_idx = split_idx["train"], split_idx["valid"], split_idx["test"] + graph, label = dataset[0] + + Where the graph is a dictionary of the following form: + dataset.graph = {'edge_index': edge_index, + 'edge_feat': None, + 'node_feat': node_feat, + 'num_nodes': num_nodes} + For additional documentation, see OGB Library-Agnostic Loader https://ogb.stanford.edu/docs/nodeprop/ + + """ + + self.name = name # original name, e.g., ogbn-proteins + self.graph = {} + self.label = None + + def get_idx_split(self, split_type='random', train_prop=.5, valid_prop=.25, label_num_per_class=20): + """ + split_type: 'random' for random splitting, 'class' for splitting with equal node num per class + train_prop: The proportion of dataset for train split. Between 0 and 1. + valid_prop: The proportion of dataset for validation split. Between 0 and 1. + label_num_per_class: num of nodes per class + """ + + if split_type == 'random': + ignore_negative = False if self.name == 'ogbn-proteins' else True + train_mask, valid_mask, test_mask = rand_train_test_idx( + self.label, train_prop=train_prop, valid_prop=valid_prop, ignore_negative=ignore_negative) + train_mask = F.pad(train_mask, [0, 1]) + valid_mask = F.pad(valid_mask, [0, 1]) + test_mask = F.pad(test_mask, [0, 1]) + split_idx = {'train': train_mask, + 'valid': valid_mask, + 'test': test_mask} + elif split_type == 'class': + train_mask, valid_mask, test_mask = class_rand_splits(self.label, label_num_per_class=label_num_per_class) + train_mask = F.pad(train_mask, [0, 1]) + valid_mask = F.pad(valid_mask, [0, 1]) + test_mask = F.pad(test_mask, [0, 1]) + split_idx = {'train': train_mask, + 'valid': valid_mask, + 'test': test_mask} + return split_idx + + def partition_patch(self, n_patches): + node_mask = metis_partition(g=self.graph, n_patches=n_patches) + patch = patch2batch(self.graph, node_mask) + self.graph['num_nodes'] += 1 + self.graph['node_feat'] = F.pad(self.graph['node_feat'], [0, 0, 0, 1]) + self.label = F.pad(self.label, [0, 1]) + return patch + + def __getitem__(self, idx): + assert idx == 0, 'This dataset has only one graph' + return self.graph, self.label + + def __len__(self): + return 1 + + def __repr__(self): + return '{}({})'.format(self.__class__.__name__, len(self)) + + +def get_data_batch(path, name, batch_size=100000): + if name in ('ogbn-products'): + dataset = load_ogb_dataset(path, name, batch_size) + + return dataset + + +def load_ogb_dataset(data_dir, name, batch_size): + dataset = NCDataset(name) + ogb_dataset = NodePropPredDataset(name=name, root=f'{data_dir}/ogb') + graph = ogb_dataset.graph + graph['edge_index'] = torch.as_tensor(graph['edge_index']) + graph['node_feat'] = torch.as_tensor(graph['node_feat']) + + label = torch.as_tensor(ogb_dataset.labels).squeeze(-1) + + split_idx = ogb_dataset.get_idx_split() + train_mask = torch.zeros_like(label, dtype=torch.bool) + train_mask[split_idx['train']] = True + valid_mask = torch.zeros_like(label, dtype=torch.bool) + valid_mask[split_idx['valid']] = True + test_mask = torch.zeros_like(label, dtype=torch.bool) + test_mask[split_idx['test']] = True + + graph['edge_index'] = to_undirected(graph['edge_index']) + + return dataset \ No newline at end of file diff --git a/examples/cobformer/Data/get_data.py b/examples/cobformer/Data/get_data.py new file mode 100644 index 00000000..aa53bb19 --- /dev/null +++ b/examples/cobformer/Data/get_data.py @@ -0,0 +1,270 @@ +import torch +import torch.nn.functional as F +import numpy as np +import torch_geometric +import networkx as nx +import metis +from torch_geometric.datasets import Planetoid +from torch_geometric.utils.num_nodes import maybe_num_nodes +from torch_geometric.utils import add_remaining_self_loops, to_undirected, remove_self_loops, add_self_loops +from torch_geometric.utils import scatter +import torch_geometric.transforms as T +from infomap import Infomap +from Data.data_utils import * +from ogb.nodeproppred import NodePropPredDataset +import scipy +import scipy.io +import scipy.sparse as sp +import os + + +class NCDataset(object): + def __init__(self, name): + """ + based off of ogb NodePropPredDataset + https://github.com/snap-stanford/ogb/blob/master/ogb/nodeproppred/dataset.py + Gives torch tensors instead of numpy arrays + - name (str): name of the dataset + - root (str): root directory to store the dataset folder + - meta_dict: dictionary that stores all the meta-information about data. Default is None, + but when something is passed, it uses its information. Useful for debugging for external contributers. + + Usage after construction: + + split_idx = dataset.get_idx_split() + train_idx, valid_idx, test_idx = split_idx["train"], split_idx["valid"], split_idx["test"] + graph, label = dataset[0] + + Where the graph is a dictionary of the following form: + dataset.graph = {'edge_index': edge_index, + 'edge_feat': None, + 'node_feat': node_feat, + 'num_nodes': num_nodes} + For additional documentation, see OGB Library-Agnostic Loader https://ogb.stanford.edu/docs/nodeprop/ + + """ + + self.name = name # original name, e.g., ogbn-proteins + self.graph = {} + self.label = None + + def get_idx_split(self, split_type='random', train_prop=.5, valid_prop=.25, label_num_per_class=20): + """ + split_type: 'random' for random splitting, 'class' for splitting with equal node num per class + train_prop: The proportion of dataset for train split. Between 0 and 1. + valid_prop: The proportion of dataset for validation split. Between 0 and 1. + label_num_per_class: num of nodes per class + """ + + if split_type == 'random': + ignore_negative = False if self.name == 'ogbn-proteins' else True + train_mask, valid_mask, test_mask = rand_train_test_idx( + self.label, train_prop=train_prop, valid_prop=valid_prop, ignore_negative=ignore_negative) + train_mask = F.pad(train_mask, [0, 1]) + valid_mask = F.pad(valid_mask, [0, 1]) + test_mask = F.pad(test_mask, [0, 1]) + split_idx = {'train': train_mask, + 'valid': valid_mask, + 'test': test_mask} + elif split_type == 'class': + train_mask, valid_mask, test_mask = class_rand_splits(self.label, label_num_per_class=label_num_per_class) + train_mask = F.pad(train_mask, [0, 1]) + valid_mask = F.pad(valid_mask, [0, 1]) + test_mask = F.pad(test_mask, [0, 1]) + split_idx = {'train': train_mask, + 'valid': valid_mask, + 'test': test_mask} + return split_idx + + def partition_patch(self, n_patches, load_path=None): + if load_path is not None: + patch = torch.load(load_path) + else: + if n_patches == 1: + patch = torch.tensor(range(self.graph['num_nodes'] + 1)).unsqueeze(dim=0) + else: + patch = metis_partition(g=self.graph, n_patches=n_patches) + print('metis done!!!') + print('patch done!!!') + self.graph['num_nodes'] += 1 + self.graph['node_feat'] = F.pad(self.graph['node_feat'], [0, 0, 0, 1]) + self.label = F.pad(self.label, [0, 1]) + return patch + + def __getitem__(self, idx): + assert idx == 0, 'This dataset has only one graph' + return self.graph, self.label + + def __len__(self): + return 1 + + def __repr__(self): + return '{}({})'.format(self.__class__.__name__, len(self)) + + +def get_data(path, name): + if name in ('Cora', 'CiteSeer', 'PubMed'): + dataset = load_planetoid_dataset(path, name) + elif name in ('ogbn-arxiv', 'ogbn-products'): + dataset = load_ogb_dataset(path, name) + elif name in ('film'): + dataset = load_geom_gcn_dataset(path, name) + elif name in ('deezer'): + dataset = load_deezer_dataset(path) + + return dataset + + +def load_planetoid_dataset(data_dir, name): + # transform = T.NormalizeFeatures() + torch_dataset = Planetoid(root=data_dir, + name=name) + data = torch_dataset[0] + + edge_index = data.edge_index + node_feat = data.x + label = data.y + num_nodes = data.num_nodes + + dataset = NCDataset(name) + + dataset.train_mask = data.train_mask + dataset.valid_mask = data.val_mask + dataset.test_mask = data.test_mask + + dataset.graph = {'edge_index': edge_index, + 'node_feat': node_feat, + 'edge_feat': None, + 'num_nodes': num_nodes} + dataset.label = label + + return dataset + + +def load_ogb_dataset(data_dir, name): + dataset = NCDataset(name) + ogb_dataset = NodePropPredDataset(name=name, root=f'{data_dir}/ogb') + dataset.graph = ogb_dataset.graph + dataset.graph['edge_index'] = torch.as_tensor(dataset.graph['edge_index']) + dataset.graph['node_feat'] = torch.as_tensor(dataset.graph['node_feat']) + + dataset.label = torch.as_tensor(ogb_dataset.labels).squeeze(-1) + + split_idx = ogb_dataset.get_idx_split() + dataset.train_mask = torch.zeros_like(dataset.label, dtype=torch.bool) + dataset.train_mask[split_idx['train']] = True + dataset.valid_mask = torch.zeros_like(dataset.label, dtype=torch.bool) + dataset.valid_mask[split_idx['valid']] = True + dataset.test_mask = torch.zeros_like(dataset.label, dtype=torch.bool) + dataset.test_mask[split_idx['test']] = True + if name == "ogbn-arxiv": + dataset.graph['edge_index'] = to_undirected(dataset.graph['edge_index']) + + return dataset + + +def load_geom_gcn_dataset(data_dir, name): + graph_adjacency_list_file_path = os.path.join(data_dir, 'out1_graph_edges.txt'.format(name)) + graph_node_features_and_labels_file_path = os.path.join(data_dir, 'out1_node_feature_label.txt'.format(name)) + + G = nx.DiGraph() + graph_node_features_dict = {} + graph_labels_dict = {} + + if name == 'film': + with open(graph_node_features_and_labels_file_path) as graph_node_features_and_labels_file: + graph_node_features_and_labels_file.readline() + for line in graph_node_features_and_labels_file: + line = line.rstrip().split('\t') + assert (len(line) == 3) + assert (int(line[0]) not in graph_node_features_dict and int( + line[0]) not in graph_labels_dict) + feature_blank = np.zeros(932, dtype=np.uint8) + feature_blank[np.array( + line[1].split(','), dtype=np.uint16)] = 1 + graph_node_features_dict[int(line[0])] = feature_blank + graph_labels_dict[int(line[0])] = int(line[2]) + else: + with open(graph_node_features_and_labels_file_path) as graph_node_features_and_labels_file: + graph_node_features_and_labels_file.readline() + for line in graph_node_features_and_labels_file: + line = line.rstrip().split('\t') + assert (len(line) == 3) + assert (int(line[0]) not in graph_node_features_dict and int( + line[0]) not in graph_labels_dict) + graph_node_features_dict[int(line[0])] = np.array( + line[1].split(','), dtype=np.uint8) + graph_labels_dict[int(line[0])] = int(line[2]) + + with open(graph_adjacency_list_file_path) as graph_adjacency_list_file: + graph_adjacency_list_file.readline() + for line in graph_adjacency_list_file: + line = line.rstrip().split('\t') + assert (len(line) == 2) + if int(line[0]) not in G: + G.add_node(int(line[0]), features=graph_node_features_dict[int(line[0])], + label=graph_labels_dict[int(line[0])]) + if int(line[1]) not in G: + G.add_node(int(line[1]), features=graph_node_features_dict[int(line[1])], + label=graph_labels_dict[int(line[1])]) + G.add_edge(int(line[0]), int(line[1])) + + adj = nx.adjacency_matrix(G, sorted(G.nodes())) + adj = sp.coo_matrix(adj) + adj = adj + sp.eye(adj.shape[0]) + adj = adj.tocoo().astype(np.float32) + features = np.array( + [features for _, features in sorted(G.nodes(data='features'), key=lambda x: x[0])]) + labels = np.array( + [label for _, label in sorted(G.nodes(data='label'), key=lambda x: x[0])]) + print(features.shape) + + def preprocess_features(feat): + """Row-normalize feature matrix and convert to tuple representation""" + rowsum = np.array(feat.sum(1)) + rowsum = (rowsum == 0) * 1 + rowsum + r_inv = np.power(rowsum, -1).flatten() + r_inv[np.isinf(r_inv)] = 0. + r_mat_inv = sp.diags(r_inv) + feat = r_mat_inv.dot(feat) + return feat + + features = preprocess_features(features) + + edge_index = torch.from_numpy( + np.vstack((adj.row, adj.col)).astype(np.int64)) + node_feat = torch.FloatTensor(features) + labels = torch.LongTensor(labels) + num_nodes = node_feat.shape[0] + print(f"Num nodes: {num_nodes}") + + dataset = NCDataset(name) + + dataset.graph = {'edge_index': edge_index, + 'node_feat': node_feat, + 'edge_feat': None, + 'num_nodes': num_nodes} + + dataset.label = labels + dataset.graph['edge_index'] = to_undirected(dataset.graph['edge_index']) + return dataset + + +def load_deezer_dataset(path): + filename = 'deezer-europe' + dataset = NCDataset(filename) + deezer = scipy.io.loadmat(f'{path}/deezer-europe.mat') + + A, label, features = deezer['A'], deezer['label'], deezer['features'] + edge_index = torch.tensor(A.nonzero(), dtype=torch.long) + node_feat = torch.tensor(features.todense(), dtype=torch.float) + label = torch.tensor(label, dtype=torch.long).squeeze() + num_nodes = label.shape[0] + + dataset.graph = {'edge_index': edge_index, + 'edge_feat': None, + 'node_feat': node_feat, + 'num_nodes': num_nodes} + # dataset.graph['edge_index'] = to_undirected(dataset.graph['edge_index']) + dataset.label = label + return dataset diff --git a/examples/cobformer/Data/partition/ogbn-products_partition_8192.pt b/examples/cobformer/Data/partition/ogbn-products_partition_8192.pt new file mode 100644 index 00000000..09ef1719 Binary files /dev/null and b/examples/cobformer/Data/partition/ogbn-products_partition_8192.pt differ diff --git a/examples/cobformer/Data/save_data.py b/examples/cobformer/Data/save_data.py new file mode 100644 index 00000000..89e014f1 --- /dev/null +++ b/examples/cobformer/Data/save_data.py @@ -0,0 +1,42 @@ +import numpy as np +import os.path as osp + +import scipy.sparse + +from get_data import get_data +import torch +from torch_sparse import SparseTensor +from torch_geometric.utils import degree + + +dataset = 'ogbn-products' +path = osp.join(osp.expanduser('~'), 'datasets', dataset) +data = get_data(path, dataset) +N = data.graph['num_nodes'] +edge_index = data.graph['edge_index'] +row, col = edge_index +d = degree(col, N).float() +d_norm_in = (1. / d[col]).sqrt() +d_norm_out = (1. / d[row]).sqrt() +value = torch.ones_like(row) * d_norm_in * d_norm_out +value = torch.nan_to_num(value, nan=0.0, posinf=0.0, neginf=0.0) +adj = scipy.sparse.coo_matrix((value, (row.numpy(), col.numpy())), shape=[N,N]) +feature = data.graph['node_feat'] +labels = data.label +split_list = [] +if dataset in ['film', 'deezer']: + for i in range(10): + splits_file_path = '{}/{}'.format(path, dataset) + '_split_50_25_' + str(i) + '.npz' + with np.load(splits_file_path) as splits_file: + idx_train = torch.BoolTensor(splits_file['train_mask']).nonzero().squeeze() + idx_val = torch.BoolTensor(splits_file['val_mask']).nonzero().squeeze() + idx_test = torch.BoolTensor(splits_file['test_mask']).nonzero().squeeze() + split_list.append([idx_train, idx_val, idx_test]) +else: + idx_train = data.train_mask.nonzero().squeeze() + idx_val = data.valid_mask.nonzero().squeeze() + idx_test = data.test_mask.nonzero().squeeze() + split_list.append([idx_train, idx_val, idx_test]) +data_list = [adj, feature, labels] + split_list +torch.save(data_list, f'data4NAG/{dataset}.pt') +print("save done") diff --git a/examples/cobformer/Data/split_save.py b/examples/cobformer/Data/split_save.py new file mode 100644 index 00000000..c2fda9a7 --- /dev/null +++ b/examples/cobformer/Data/split_save.py @@ -0,0 +1,21 @@ +import torch.nn.functional as F +import numpy as np +import os.path as osp +from yaml import SafeLoader +from get_data import get_data +from data_utils import rand_train_test_idx + +dataset = 'film' +path = osp.join(osp.expanduser('~'), 'datasets', dataset) + +train_prop = 0.5 +valid_prop = 0.25 + +data = get_data(path, dataset) +# get the splits for all runs +for i in range(10): + ignore_negative = False if dataset == 'ogbn-proteins' else True + train_mask, valid_mask, test_mask = rand_train_test_idx( + data.label, train_prop=train_prop, valid_prop=valid_prop, ignore_negative=ignore_negative) + splits_file_path = '{}/{}'.format(path, dataset) + '_split_50_25_' + str(i) + np.savez(splits_file_path, train_mask=train_mask, val_mask=valid_mask, test_mask=test_mask) diff --git a/examples/cobformer/Train/train_test.py b/examples/cobformer/Train/train_test.py new file mode 100644 index 00000000..df5c4fba --- /dev/null +++ b/examples/cobformer/Train/train_test.py @@ -0,0 +1,103 @@ +import torch +import torch.nn.functional as F +from torcheval.metrics.functional import multiclass_f1_score + + +def eval_f1(pred, label, num_classes): + micro = multiclass_f1_score(pred, label, num_classes=num_classes, average='micro') + macro = multiclass_f1_score(pred, label, num_classes=num_classes, average='macro') + return micro.item(), macro.item() + + +def co_train(model, data, label, patch, split_index, optimizer): + model.train() + optimizer.zero_grad() + pred1, pred2 = model(data.graph['node_feat'], patch, data.graph['edge_index']) + loss = model.loss(pred1, pred2, label, split_index['train']) + loss.backward() + optimizer.step() + # eval + model.eval() + with torch.no_grad(): + pred1, pred2 = model(data.graph['node_feat'], patch, data.graph['edge_index']) + + # pred1 = F.log_softmax(pred1, dim=-1) + # pred2 = F.log_softmax(pred2, dim=-1) + + y = data.label.squeeze() + num_classes = y.max() + 1 + + y1_ = torch.argmax(pred1, dim=1).squeeze() + micro_val1, macro_val1 = eval_f1(y1_[split_index['valid']], y[split_index['valid']], num_classes) + micro_test1, macro_test1 = eval_f1(y1_[split_index['test']], y[split_index['test']], num_classes) + + y2_ = torch.argmax(pred2, dim=1).squeeze() + micro_val2, macro_val2 = eval_f1(y2_[split_index['valid']], y[split_index['valid']], num_classes) + micro_test2, macro_test2 = eval_f1(y2_[split_index['test']], y[split_index['test']], num_classes) + + return micro_val1, micro_test1, macro_val1, macro_test1, micro_val2, micro_test2, macro_val2, macro_test2 + + +def co_train_batch(model, node_feat_i, edge_index_i, label_i, patch_i, train_idx, optimizer): + model.train() + optimizer.zero_grad() + pred1, pred2 = model(node_feat_i, patch_i, edge_index_i) + loss = model.loss(pred1, pred2, label_i, train_idx) + loss.backward() + optimizer.step() + + +def co_test_batch(model, node_feat_i, edge_index_i, label_i, patch_i, valid_idx, test_idx): + model.eval() + with torch.no_grad(): + pred1, pred2 = model(node_feat_i, patch_i, edge_index_i) + + # pred1 = F.log_softmax(pred1, dim=-1) + # pred2 = F.log_softmax(pred2, dim=-1) + + y = data.label.squeeze() + num_classes = y.max() + 1 + + y1_ = torch.argmax(pred1, dim=1).squeeze() + micro_val1, macro_val1 = eval_f1(y1_, y, num_classes) + micro_test1, macro_test1 = eval_f1(y1_, y, num_classes) + + y2_ = torch.argmax(pred2, dim=1).squeeze() + micro_val2, macro_val2 = eval_f1(y2_, y, num_classes) + micro_test2, macro_test2 = eval_f1(y2_, y, num_classes) + + return micro_val1.item(), micro_test1.item(), macro_val1.item(), macro_test1.item(), micro_val2.item(), micro_test2.item(), macro_val2.item(), macro_test2.item() + + +def test(model, data, patch, split_index): + model.eval() + with torch.no_grad(): + pred = model(data.graph['node_feat'], patch, data.graph['edge_index']) + y_hat_val = torch.argmax(pred[split_index['valid']], dim=1) + acc_val = torch.mean(torch.eq(y_hat_val, data.label[split_index['valid']]).float()) + y_hat_test = torch.argmax(pred[split_index['test']], dim=1) + acc_test = torch.mean(torch.eq(y_hat_test, data.label[split_index['test']]).float()) + + return acc_val.item(), acc_test.item() + + +def co_test(model, data, patch, split_index): + model.eval() + with torch.no_grad(): + pred1, pred2 = model(data.graph['node_feat'], patch, data.graph['edge_index']) + + # pred1 = F.log_softmax(pred1, dim=-1) + # pred2 = F.log_softmax(pred2, dim=-1) + + y = data.label.squeeze() + num_classes = y.max() + 1 + + y1_ = torch.argmax(pred1, dim=1).squeeze() + micro_val1, macro_val1 = eval_f1(y1_[split_index['valid']], y[split_index['valid']], num_classes) + micro_test1, macro_test1 = eval_f1(y1_[split_index['test']], y[split_index['test']], num_classes) + + y2_ = torch.argmax(pred2, dim=1).squeeze() + micro_val2, macro_val2 = eval_f1(y2_[split_index['valid']], y[split_index['valid']], num_classes) + micro_test2, macro_test2 = eval_f1(y2_[split_index['test']], y[split_index['test']], num_classes) + + return micro_val1, micro_test1, macro_val1, macro_test1, micro_val2, micro_test2, macro_val2, macro_test2 diff --git a/examples/cobformer/cal_partition.py b/examples/cobformer/cal_partition.py new file mode 100644 index 00000000..0af5fa4c --- /dev/null +++ b/examples/cobformer/cal_partition.py @@ -0,0 +1,18 @@ +import torch + +from Data.get_data import * +import os.path as osp + +dataset = 'ogbn-products' +# dataset = 'Cora' +# n_patch = 16384 +n_patch = 8192 +# n_patch = 112 +path = osp.join(osp.expanduser('~'), 'datasets', dataset) +data = get_data(path, dataset) + +patch = metis_partition(data.graph, n_patch) +# node_mask = torch.load('partition/'+dataset+'_partition_{}.pt') +# patch = patch2batch(data.graph, node_mask) +torch.save(patch, 'Data/partition/'+dataset+f'_partition_{n_patch}.pt') +print('Done!!!') diff --git a/examples/cobformer/cobformer_trainer.py b/examples/cobformer/cobformer_trainer.py new file mode 100644 index 00000000..01a2b573 --- /dev/null +++ b/examples/cobformer/cobformer_trainer.py @@ -0,0 +1,124 @@ +import os +os.environ['TL_BACKEND'] = "torch" +import argparse +import torch +import torch.nn.functional as F +import numpy as np +import yaml +import os +import os.path as osp +from yaml import SafeLoader +from Data.get_data import get_data +from run import run +from run_batch import run_batch +from Data.data_utils import load_fixed_splits +import random +import warnings + + +def fix_seed(seed): + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + torch.backends.cudnn.deterministic = True + + +if __name__ == '__main__': + warnings.filterwarnings('ignore') + parser = argparse.ArgumentParser() + parser.add_argument('--dataset', type=str, default='Cora') + parser.add_argument('--learning_rate', type=float, default=0.01) + parser.add_argument('--weight_decay', type=float, default=5e-4) + parser.add_argument('--gcn_wd', type=float, default=5e-4) + parser.add_argument('--gpu_id', type=int, default=6) + parser.add_argument('--config', type=str, default='config.yaml') + parser.add_argument('--gcn_use_bn', action='store_true', help='gcn use batch norm') + parser.add_argument('--use_patch_attn', action='store_true', help='transformer use patch attention') + parser.add_argument('--show_details', type=bool, default=True) + parser.add_argument('--gcn_type', type=int, default=1) + parser.add_argument('--gcn_layers', type=int, default=2) + parser.add_argument('--n_patch', type=int, default=112) + parser.add_argument('--batch_size', type=int, default=100000) + parser.add_argument('--rand_split', action='store_true', help='random split dataset') + parser.add_argument('--rand_split_class', action='store_true', help='random split dataset by class') + parser.add_argument('--protocol', type=str, default='semi') + parser.add_argument('--label_num_per_class', type=int, default=20) + parser.add_argument('--train_prop', type=float, default=.6) + parser.add_argument('--valid_prop', type=float, default=.2) + parser.add_argument('--alpha', type=float, default=.8) + parser.add_argument('--tau', type=float, default=.3) + + args = parser.parse_args() + + assert args.gpu_id in range(0, 8) + torch.cuda.set_device(args.gpu_id) + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + + config = yaml.load(open(args.config), Loader=SafeLoader)[args.dataset] + fix_seed(config['seed']) + + path = osp.join(osp.expanduser('~'), 'datasets', args.dataset) + results = dict() + n_patch = args.n_patch + alpha = args.alpha + tau = args.tau + load_path = None + if args.dataset in ['ogbn-products']: + load_path = f'Data/partition/{args.dataset}_partition_{n_patch}.pt' + + # postfix = f'{n_patch}' + postfix = "test" + runs = 5 + print("n_patch: ", n_patch) + + data = get_data(path, args.dataset) + # get the splits for all runs + if args.rand_split: + split_idx_lst = [data.get_idx_split(train_prop=args.train_prop, valid_prop=args.valid_prop) + for _ in range(runs)] + elif args.rand_split_class: + split_idx_lst = [data.get_idx_split(split_type='class', label_num_per_class=args.label_num_per_class) + for _ in range(runs)] + else: + split_idx_lst = load_fixed_splits(path, data, name=args.dataset, protocol=args.protocol) + patch = data.partition_patch(n_patch, load_path) + + batch_size = args.batch_size + + results = [[], []] + for r in range(runs): + if args.dataset in ['Cora', 'CiteSeer', 'PubMed', 'ogbn-arxiv', 'ogbn-products'] and args.protocol == 'semi': + split_idx = split_idx_lst[0] + else: + split_idx = split_idx_lst[r] + + if args.dataset in ['ogbn-products']: + res_gnn, res_trans = run_batch(args, config, device, data, patch, batch_size, split_idx, alpha, tau, + postfix) + else: + res_gnn, res_trans = run(args, config, device, data, patch, split_idx, alpha, tau, postfix) + results[0].append(res_gnn) + results[1].append(res_trans) + + print(f"==== Final GNN====") + result = torch.tensor(results[0]) * 100. + print(result) + print(f"max: {torch.max(result, dim=0)[0]}") + print(f"min: {torch.min(result, dim=0)[0]}") + print(f"mean: {torch.mean(result, dim=0)}") + print(f"std: {torch.std(result, dim=0)}") + + print(f'GNN Micro: {torch.mean(result, dim=0)[1]:.2f} ± {torch.std(result, dim=0)[1]:.2f}') + print(f'GNN Macro: {torch.mean(result, dim=0)[3]:.2f} ± {torch.std(result, dim=0)[3]:.2f}') + + print(f"==== Final Trans====") + result = torch.tensor(results[1]) * 100. + print(result) + print(f"max: {torch.max(result, dim=0)[0]}") + print(f"min: {torch.min(result, dim=0)[0]}") + print(f"mean: {torch.mean(result, dim=0)}") + print(f"std: {torch.std(result, dim=0)}") + + print(f'Trans Micro: {torch.mean(result, dim=0)[1]:.2f} ± {torch.std(result, dim=0)[1]:.2f}') + print(f'Trans Macro: {torch.mean(result, dim=0)[3]:.2f} ± {torch.std(result, dim=0)[3]:.2f}') diff --git a/examples/cobformer/cobformer_trainer_batch.py b/examples/cobformer/cobformer_trainer_batch.py new file mode 100644 index 00000000..7162d309 --- /dev/null +++ b/examples/cobformer/cobformer_trainer_batch.py @@ -0,0 +1,106 @@ +import argparse +import torch +import torch.nn.functional as F +import numpy as np +import yaml +import os +import os.path as osp +from yaml import SafeLoader +from Data.get_data import get_data +from run import run +from run_batch import run_batch +from Data.data_utils import load_fixed_splits +import random + + +def fix_seed(seed): + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + torch.backends.cudnn.deterministic = True + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--dataset', type=str, default='Cora') + parser.add_argument('--gpu_id', type=int, default=6) + parser.add_argument('--config', type=str, default='config.yaml') + parser.add_argument('--gcn_use_bn', action='store_true', help='gcn use batch norm') + parser.add_argument('--show_details', type=bool, default=True) + parser.add_argument('--gcn_type', type=int, default=1) + parser.add_argument('--gcn_layers', type=int, default=2) + parser.add_argument('--n_patch', type=int, default=112) + parser.add_argument('--rand_split', action='store_true', help='random split dataset') + parser.add_argument('--rand_split_class', action='store_true', help='random split dataset by class') + parser.add_argument('--protocol', type=str, default='semi') + parser.add_argument('--label_num_per_class', type=int, default=20) + parser.add_argument('--train_prop', type=float, default=.6) + parser.add_argument('--valid_prop', type=float, default=.2) + parser.add_argument('--alpha', type=float, default=.8) + parser.add_argument('--tau', type=float, default=.3) + + args = parser.parse_args() + + assert args.gpu_id in range(0, 8) + torch.cuda.set_device(args.gpu_id) + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + + config = yaml.load(open(args.config), Loader=SafeLoader)[args.dataset] + + path = osp.join(osp.expanduser('~'), 'datasets', args.dataset) + + results = dict() + n_patch = args.n_patch + alpha = args.alpha + tau = args.tau + # postfix = f'{n_patch}' + postfix = "test" + runs = 5 + + data = get_data(path, args.dataset) + # get the splits for all runs + if args.rand_split: + split_idx_lst = [data.get_idx_split(train_prop=args.train_prop, valid_prop=args.valid_prop) + for _ in range(runs)] + elif args.rand_split_class: + split_idx_lst = [data.get_idx_split(split_type='class', label_num_per_class=args.label_num_per_class) + for _ in range(runs)] + else: + split_idx_lst = load_fixed_splits(path, data, name=args.dataset, protocol=args.protocol) + + batch_size = 100000 + + results = [[], []] + fix_seed(config['seed']) + for r in range(runs): + if args.dataset in ['Cora', 'CiteSeer', 'PubMed', 'ogbn-arxiv'] and args.protocol == 'semi': + split_idx = split_idx_lst[0] + else: + split_idx = split_idx_lst[r] + + res_gnn, res_trans = run_batch(args, config, device, data, patch, batch_size, split_idx, alpha, tau, postfix) + results[0].append(res_gnn) + results[1].append(res_trans) + + print(f"==== Final GNN====") + result = torch.tensor(results[0]) * 100. + print(result) + print(f"max: {torch.max(result, dim=0)[0]}") + print(f"min: {torch.min(result, dim=0)[0]}") + print(f"mean: {torch.mean(result, dim=0)}") + print(f"std: {torch.std(result, dim=0)}") + + print(f'GNN Micro: {torch.mean(result, dim=0)[1]:.2f} ± {torch.std(result, dim=0)[1]:.2f}') + print(f'GNN Macro: {torch.mean(result, dim=0)[3]:.2f} ± {torch.std(result, dim=0)[3]:.2f}') + + print(f"==== Final Trans====") + result = torch.tensor(results[1]) * 100. + print(result) + print(f"max: {torch.max(result, dim=0)[0]}") + print(f"min: {torch.min(result, dim=0)[0]}") + print(f"mean: {torch.mean(result, dim=0)}") + print(f"std: {torch.std(result, dim=0)}") + + print(f'Trans Micro: {torch.mean(result, dim=0)[1]:.2f} ± {torch.std(result, dim=0)[1]:.2f}') + print(f'Trans Macro: {torch.mean(result, dim=0)[3]:.2f} ± {torch.std(result, dim=0)[3]:.2f}') diff --git a/examples/cobformer/config.yaml b/examples/cobformer/config.yaml new file mode 100644 index 00000000..7667e795 --- /dev/null +++ b/examples/cobformer/config.yaml @@ -0,0 +1,67 @@ +Cora: + seed: 123 + learning_rate: 0.01 + num_hidden: 64 + activation: 'relu' + base_model: 'GCNConv' + num_layers: 1 + n_head: 4 + num_epochs: 500 + weight_decay: 0.0005 +CiteSeer: + seed: 123 + learning_rate: 0.005 + num_hidden: 64 + activation: 'relu' + base_model: 'GCNConv' + num_layers: 1 + n_head: 1 + num_epochs: 500 + weight_decay: 0.001 +PubMed: + seed: 123 + learning_rate: 0.005 + num_hidden: 64 + activation: 'relu' + base_model: 'GCNConv' + num_layers: 1 + n_head: 1 + num_epochs: 500 + weight_decay: 0.001 +ogbn-arxiv: + seed: 123 + learning_rate: 0.005 + num_hidden: 256 + activation: 'relu' + base_model: 'GCNConv' + num_layers: 1 + n_head: 4 + num_epochs: 3000 + weight_decay: 0. +ogbn-products: + seed: 123 + num_hidden: 256 + activation: 'relu' + base_model: 'GCNConv' + n_head: 4 + num_epochs: 1000 +film: + seed: 123 + learning_rate: 0.05 + num_hidden: 64 + activation: 'relu' + base_model: 'GCNConv' + num_layers: 1 + n_head: 2 + num_epochs: 600 + weight_decay: 0.005 +deezer: + seed: 123 + learning_rate: 0.01 + num_hidden: 96 + activation: 'relu' + base_model: 'GCNConv' + num_layers: 1 + n_head: 1 + num_epochs: 200 + weight_decay: 0.0005 \ No newline at end of file diff --git a/examples/cobformer/generate_split.py b/examples/cobformer/generate_split.py new file mode 100644 index 00000000..bcb37867 --- /dev/null +++ b/examples/cobformer/generate_split.py @@ -0,0 +1,56 @@ +import os +import os.path as osp +import numpy as np + +def generate_splits(dataset, name, data_dir): + """ + dataset: 已加载的 NCDataset,必须含 dataset.graph['num_nodes'] + name: 'film' 或 'deezer' + data_dir: 存放 .npz 文件的目录 + """ + N = dataset.graph['num_nodes'] + os.makedirs(data_dir, exist_ok=True) + + for i in range(10): + # 每一折用不同的 seed 保证可复现 + rng = np.random.default_rng(seed=i) + perm = rng.permutation(N) + + n_train = int(N * 0.50) + n_val = int(N * 0.25) + # 剩下的就是 test + train_idx = perm[:n_train] + val_idx = perm[n_train:n_train + n_val] + test_idx = perm[n_train + n_val:] + + # 构造布尔 mask + train_mask = np.zeros(N, dtype=bool) + val_mask = np.zeros(N, dtype=bool) + test_mask = np.zeros(N, dtype=bool) + train_mask[train_idx] = True + val_mask[val_idx] = True + test_mask[test_idx] = True + + # 保存 npz + out_path = osp.join(data_dir, f"{name}_split_50_25_{i}.npz") + np.savez(out_path, + train_mask=train_mask, + val_mask=val_mask, + test_mask=test_mask) + print(f"Saved split {i} ➜ {out_path}") + +if __name__=="__main__": + from Data.get_data import load_geom_gcn_dataset, load_deezer_dataset + + path = osp.join(osp.expanduser('~'), 'datasets', 'film') + + dataset = load_geom_gcn_dataset(path, "film") + + generate_splits(dataset, name="film", data_dir=path) + + path = osp.join(osp.expanduser('~'), 'datasets', 'deezer') + + dataset = load_deezer_dataset(path) + + generate_splits(dataset, name="deezer", data_dir=path) + diff --git a/examples/cobformer/readme.md b/examples/cobformer/readme.md new file mode 100644 index 00000000..c459ba8a --- /dev/null +++ b/examples/cobformer/readme.md @@ -0,0 +1,64 @@ +# Less is More: on the Over-Globalizing Problem in Graph Transformers (CoBFormer) + +- Paper link: [http://www.shichuan.org/doc/177.pdf](http://www.shichuan.org/doc/177.pdf) +- Author's code repo: [https://github.com/BUPT-GAMMA/CoBFormer](https://github.com/BUPT-GAMMA/CoBFormer) + +# Dataset Statics + +| Dataset | # Nodes | # Edges | # Feats | Edge hom | # Classes | +|---------------|-----------|------------|---------|----------|-----------| +| Cora | 2,708 | 5,429 | 1,433 | 0.83 | 7 | +| CiteSeer | 3,327 | 4,732 | 3,703 | 0.72 | 6 | +| PubMed | 19,717 | 44,338 | 500 | 0.79 | 3 | +| Actor | 7,600 | 26,752 | 931 | 0.22 | 5 | +| Deezer | 28,281 | 92,752 | 31,241 | 0.52 | 2 | +| Ogbn-Arxiv | 169,343 | 1,166,343 | 128 | 0.63 | 40 | +| Ogbn-Products | 2,449,029 | 61,859,140 | 100 | 0.81 | 47 | + +Refer to [Planetoid](https://gammagl.readthedocs.io/en/latest/api/gammagl.datasets.html#gammagl.datasets.Planetoid). + +# Results + +- Available dataset: "Cora", "Citeseer", "Pubmed", "film", "deezer", "ogbn-arxiv", "ogbn-products" + +```bash +# available dataset: "cora", "citeseer", "pubmed" +python cobformer_trainer.py --dataset=Cora --learning_rate=0.01 --gcn_wd=1e-3 --weight_decay=5e-5 --gcn_type=1 --gcn_layers=2 --n_patch=112 --use_patch_attn --alpha=0.7 --tau=0.3 --gpu_id=3 +python cobformer_trainer.py --dataset=CiteSeer --learning_rate=5e-3 --gcn_wd=1e-2 --weight_decay=5e-5 --gcn_type=1 --gcn_layers=2 --n_patch=144 --use_patch_attn --alpha=0.8 --tau=0.7 --gpu_id=3 +python cobformer_trainer.py --dataset=PubMed --learning_rate=5e-3 --gcn_wd=1e-3 --weight_decay=1e-3 --gcn_type=1 --gcn_layers=2 --n_patch=224 --use_patch_attn --alpha=0.7 --tau=0.3 --gpu_id=3 +python cobformer_trainer.py --dataset=film --learning_rate=5e-2 --gcn_wd=1e-4 --weight_decay=1e-3 --gcn_type=1 --gcn_layers=2 --n_patch=112 --use_patch_attn --alpha=0.7 --tau=0.9 --gpu_id=3 +python cobformer_trainer.py --dataset=deezer --learning_rate=0.01 --gcn_wd=1e-3 --weight_decay=5e-4 --gcn_type=1 --gcn_layers=2 --n_patch=224 --use_patch_attn --alpha=0.8 --tau=0.9 --gpu_id=3 +python cobformer_trainer.py --dataset=ogbn-arxiv --learning_rate=1e-3 --weight_decay=0. --gcn_use_bn --gcn_type=2 --gcn_layers=3 --n_patch=2048 --use_patch_attn --alpha=0.9 --tau=0.9 --gpu_id=3 +python cobformer_trainer.py --dataset=ogbn-products --learning_rate=5e-4 --weight_decay=0. --gcn_type=2 --gcn_layers=3 --gcn_use_bn --n_patch=8192 --use_patch_attn --batch_size=150000 --alpha=0.9 --tau=0.7 --gpu_id=3 +``` + +- Or use `runs.sh` to run all experiments + +```bash +bash runs.sh +``` + +Paper: + +| Dataset | CoB-G Mi-F1 | CoB-T Mi-F1 | CoB-G Ma-F1 | CoB-T Ma-F1 | +|-------------------|----------------|----------------|----------------|----------------| +| **Cora** | 84.96 ± 0.34 % | 85.28 ± 0.16 % | 83.52 ± 0.15 % | 84.10 ± 0.28 % | +| **CiteSeer** | 74.68 ± 0.33 % | 74.52 ± 0.48 % | 69.73 ± 0.45 % | 69.82 ± 0.55 % | +| **PubMed** | 80.52 ± 0.25 % | 81.42 ± 0.53 % | 80.02 ± 0.28 % | 81.04 ± 0.49 % | +| **Actor** | 31.05 ± 1.02 % | 37.41 ± 0.36 % | 27.01 ± 1.77 % | 34.96 ± 0.68 % | +| **Deezer** | 63.76 ± 0.62 % | 66.96 ± 0.37 % | 62.32 ± 0.94 % | 65.63 ± 0.36 % | +| **Ogbn-Arxiv** | 73.17 ± 0.18 % | 72.76 ± 0.11 % | 52.31 ± 0.40 % | 51.64 ± 0.09 % | +| **Ogbn-Products** | 78.09 ± 0.16 % | 78.15 ± 0.07 % | 38.21 ± 0.22 % | 37.91 ± 0.44 % | + +Our: + +| Dataset | CoB-G Mi-F1 | CoB-T Mi-F1 | CoB-G Ma-F1 | CoB-T Ma-F1 | +|-------------------|----------------|----------------|----------------|----------------| +| **Cora** | 84.44 ± 0.60 % | 84.42 ± 0.79 % | 83.23 ± 0.93 % | 83.41 ± 0.45 % | +| **CiteSeer** | 74.64 ± 0.40 % | 74.40 ± 0.52 % | 69.76 ± 0.37 % | 69.93 ± 0.59 % | +| **PubMed** | 80.30 ± 0.32 % | 81.02 ± 0.47 % | 79.72 ± 0.35 % | 80.66 ± 0.52 % | +| **Actor** | 30.44 ± 1.16 % | 34.19 ± 4.05 % | 24.42 ± 3.25 % | 28.08 ± 8.36 % | +| **Deezer** | 64.05 ± 0.42 % | 66.62 ± 0.81 % | 63.12 ± 0.46 % | 65.03 ± 0.39 % | +| **Ogbn-Arxiv** | 73.09 ± 0.10 % | 72.71 ± 0.16 % | 52.37 ± 0.33 % | 51.39 ± 0.21 % | +| **Ogbn-Products** | 78.14 ± 0.09 % | 78.15 ± 0.08 % | 38.36 ± 0.19 % | 38.09 ± 0.35 % | + diff --git a/examples/cobformer/run.py b/examples/cobformer/run.py new file mode 100644 index 00000000..95da2b30 --- /dev/null +++ b/examples/cobformer/run.py @@ -0,0 +1,114 @@ +from time import perf_counter as t + +import torch +import torch.nn as nn +import random +from gammagl.models.CoBFormer import CoBFormer +from Train.train_test import * + +# max_val = -10000 +def co_early_stop_train(epochs, patience, model, data, label, patch, split_index, optimizer, show_details, + postfix, save_path=None): + best_epoch1 = 0 + best_epoch2 = 0 + acc_val1_max = 0. + acc_val2_max = 0. + logger = [] + max_val = -10000 + + for epoch in range(1, epochs + 1): + + micro_val1, micro_test1, macro_val1, macro_test1, micro_val2, micro_test2, macro_val2, macro_test2 = co_train( + model, data, label, patch, split_index, optimizer) + logger.append( + [micro_val1, micro_test1, macro_val1, macro_test1, micro_val2, micro_test2, macro_val2, macro_test2]) + + if show_details and epoch % 50 == 0: + print( + f'(T) | Epoch={epoch:03d}\n', + f'micro_val1={micro_val1:.4f}, micro_test1={micro_test1:.4f}, macro_val1={macro_val1:.4f}, macro_test1={macro_test1:.4f}\n', + f'micro_val2={micro_val2:.4f}, micro_test2={micro_test2:.4f}, macro_val2={macro_val2:.4f}, macro_test2={macro_test2:.4f}\n') + + logger = torch.tensor(logger) + ind = torch.argmax(logger, dim=0) + + res_gnn = [] + res_trans = [] + + res_gnn.append(logger[ind[0]][0]) + res_gnn.append(logger[ind[0]][1]) + res_gnn.append(logger[ind[2]][2]) + res_gnn.append(logger[ind[2]][3]) + res_gnn.append(logger[ind[1]][1]) + res_gnn.append(logger[ind[3]][3]) + + res_trans.append(logger[ind[4]][4]) + res_trans.append(logger[ind[4]][5]) + res_trans.append(logger[ind[6]][6]) + res_trans.append(logger[ind[6]][7]) + res_trans.append(logger[ind[5]][5]) + res_trans.append(logger[ind[7]][7]) + + return res_gnn, res_trans + + +def run(args, config, device, data, patch, split_idx, alpha, tau, postfix): + learning_rate = args.learning_rate + # learning_rate2 = args.learning_rate2 + + weight_decay = args.weight_decay + gcn_wd = args.gcn_wd + num_hidden = config['num_hidden'] + activation = ({'relu': F.relu, 'prelu': nn.PReLU()})[config['activation']] + num_layers = config['num_layers'] + n_head = config['n_head'] + num_epochs = config['num_epochs'] + gcn_type = args.gcn_type + gcn_layers = args.gcn_layers + gcn_use_bn = args.gcn_use_bn + use_patch_attn = args.use_patch_attn + show_details = args.show_details + patch = patch.to(device) + num_nodes = data.graph['num_nodes'] + num_classes = data.label.max() + 1 + num_features = data.graph['node_feat'].shape[-1] + data.graph['node_feat'] = data.graph['node_feat'].to(device) + data.graph['edge_index'] = data.graph['edge_index'].to(device) + data.label = data.label.to(device) + + split_idx['train'] = split_idx['train'].to(device) + split_idx['valid'] = split_idx['valid'].to(device) + split_idx['test'] = split_idx['test'].to(device) + + label = F.one_hot(data.label, num_classes).float() + + # model = Beyondformer(num_nodes, num_features, num_hidden, num_classes, activation, + # layers=num_layers, gnn_layers=gcn_layers, n_head=n_head, alpha=alpha, ratio=ratio).to(device) + model = CoBFormer(num_nodes, num_features, num_hidden, num_classes, activation, layers=num_layers, + gcn_layers=gcn_layers, gcn_type=gcn_type, n_head=n_head, alpha=alpha, tau=tau, + gcn_use_bn=gcn_use_bn, use_patch_attn=use_patch_attn).to(device) + # print(model) + + if args.dataset in ['film', 'CiteSeer', 'Cora', 'PubMed', "Deezer"]: + optimizer = torch.optim.Adam([ + {'params': model.bga.parameters(), 'weight_decay': weight_decay}, + {'params': model.gcn.parameters(), 'weight_decay': gcn_wd} + ], lr=learning_rate) + else: + optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay) + + + patience = num_epochs + + res_gnn, res_trans = co_early_stop_train( + num_epochs, patience, + model, data, label, + patch, split_idx, + optimizer, show_details, + postfix) + print("=== Train Final ===") + print( + f'micro_val1={res_gnn[0]:.4f}, micro_test1={res_gnn[1]:.4f}, macro_val1={res_gnn[2]:.4f}, macro_test1={res_gnn[3]:.4f}, micro_best1={res_gnn[4]:.4f}, macro_best1={res_gnn[5]:.4f},\n', + f'micro_val2={res_trans[0]:.4f}, micro_test2={res_trans[1]:.4f}, macro_val2={res_trans[2]:.4f}, macro_test2={res_trans[3]:.4f}, micro_best2={res_trans[4]:.4f}, macro_best2={res_trans[5]:.4f}\n') + + return res_gnn, res_trans diff --git a/examples/cobformer/run_batch.py b/examples/cobformer/run_batch.py new file mode 100644 index 00000000..27788ddc --- /dev/null +++ b/examples/cobformer/run_batch.py @@ -0,0 +1,163 @@ +from time import perf_counter as t + +import torch +import random +from gammagl.models.CoBFormer import * +from torch_geometric.utils import subgraph +from torch_geometric.utils.map import map_index +from Train.train_test import * +import torch.nn as nn + + +def co_early_stop_train_batch(epochs, patience, model, data, label, patch, batch_size, split_index, optimizer, + show_details, device, postfix): + best_epoch1 = 0 + best_epoch2 = 0 + acc_val1_max = 0. + acc_val2_max = 0. + logger = [] + + n_patch, patch_size = patch.shape + patch_per_batch = batch_size // patch_size + num_batch = n_patch // patch_per_batch + (n_patch % patch_per_batch > 0) + + for epoch in range(1, epochs + 1): + + idx = torch.randperm(n_patch) + for i in range(num_batch): + patch_idx = idx[i * patch_per_batch: (i + 1) * patch_per_batch] + patch_i = patch[patch_idx] + node_i = torch.unique(patch_i) + patch_i = map_index(patch_i, node_i)[0].view(patch_i.shape).to(device) + node_feat_i = data.graph['node_feat'][node_i].to(device) + edge_index_i, _ = subgraph(node_i, data.graph['edge_index'], num_nodes=data.graph['num_nodes'], + relabel_nodes=True) + edge_index_i = edge_index_i.to(device) + label_i = label[node_i].to(device) + train_idx = split_index['train'][node_i].to(device) + + co_train_batch(model, node_feat_i, edge_index_i, label_i, patch_i, train_idx, optimizer) + + if epoch % 5 == 0: + model.eval() + y1 = torch.zeros_like(data.label).to(device) + y2 = torch.zeros_like(data.label).to(device) + with torch.no_grad(): + idx = torch.randperm(n_patch) + for i in range(num_batch): + patch_idx = idx[i * patch_per_batch: (i + 1) * patch_per_batch] + patch_i = patch[patch_idx] + node_i = torch.unique(patch_i) + patch_i = map_index(patch_i, node_i)[0].view(patch_i.shape).to(device) + node_feat_i = data.graph['node_feat'][node_i].to(device) + edge_index_i, _ = subgraph(node_i, data.graph['edge_index'], num_nodes=data.graph['num_nodes'], + relabel_nodes=True) + edge_index_i = edge_index_i.to(device) + pred1, pred2 = model(node_feat_i, patch_i, edge_index_i) + pred1 = torch.argmax(pred1, dim=1).squeeze() + pred2 = torch.argmax(pred2, dim=1).squeeze() + y1[node_i] = pred1 + y2[node_i] = pred2 + y = torch.tensor(data.label).to(device) + num_classes = data.label.max() + 1 + micro_val1, macro_val1 = eval_f1(y1[split_index['valid']], y[split_index['valid']], num_classes) + micro_test1, macro_test1 = eval_f1(y1[split_index['test']], y[split_index['test']], num_classes) + micro_val2, macro_val2 = eval_f1(y2[split_index['valid']], y[split_index['valid']], num_classes) + micro_test2, macro_test2 = eval_f1(y2[split_index['test']], y[split_index['test']], num_classes) + acc1 = torch.eq(y1[split_index['test']], y[split_index['test']]).float().mean() + acc2 = torch.eq(y2[split_index['test']], y[split_index['test']]).float().mean() + + + logger.append( + [micro_val1, micro_test1, macro_val1, macro_test1, micro_val2, micro_test2, macro_val2, macro_test2]) + + if show_details and epoch % 5 == 0: + print( + f'(T) | Epoch={epoch:03d}\n', + f'micro_val1={micro_val1:.4f}, micro_test1={micro_test1:.4f}, acc1={acc1:.4f}, macro_val1={macro_val1:.4f}, macro_test1={macro_test1:.4f}\n', + f'micro_val2={micro_val2:.4f}, micro_test2={micro_test2:.4f}, acc2={acc2:.4f}, macro_val2={macro_val2:.4f}, macro_test2={macro_test2:.4f}\n') + # acc_val = (acc_val1 + acc_val2) /2. + # if acc_val > acc_val1_max: + # acc_val1_max = acc_val + # best_epoch1 = epoch + # torch.save(model.state_dict(), f"tem/weight_best_pretrain_{postfix}_1.pkl") + # if acc_val2 > acc_val2_max: + # acc_val2_max = acc_val2 + # best_epoch2 = epoch + # torch.save(model.state_dict(), f"tem/weight_best_pretrain_{postfix}_2.pkl") + + logger = torch.tensor(logger) + ind = torch.argmax(logger, dim=0) + + res_gnn = [] + res_trans = [] + + res_gnn.append(logger[ind[0]][0]) + res_gnn.append(logger[ind[0]][1]) + res_gnn.append(logger[ind[2]][2]) + res_gnn.append(logger[ind[2]][3]) + res_gnn.append(logger[ind[1]][1]) + res_gnn.append(logger[ind[3]][3]) + + res_trans.append(logger[ind[4]][4]) + res_trans.append(logger[ind[4]][5]) + res_trans.append(logger[ind[6]][6]) + res_trans.append(logger[ind[6]][7]) + res_trans.append(logger[ind[5]][5]) + res_trans.append(logger[ind[7]][7]) + + # model.load_state_dict(torch.load(f"tem/weight_best_pretrain_{postfix}_1.pkl")) + # acc_val1, acc_test1, acc_val2, acc_test2 = co_test(model, data, patch, split_index) + # model.load_state_dict(torch.load(f"tem/weight_best_pretrain_{postfix}_2.pkl")) + # _, _, acc_val2, acc_test2 = co_test(model, data, patch, split_index) + return res_gnn, res_trans + + +def run_batch(args, config, device, data, patch, batch_size, split_idx, alpha, tau, postfix): + learning_rate = args.learning_rate + weight_decay = args.weight_decay + num_hidden = config['num_hidden'] + activation = ({'relu': F.relu, 'prelu': nn.PReLU()})[config['activation']] + num_layers = 1 + n_head = config['n_head'] + num_epochs = config['num_epochs'] + gcn_type = args.gcn_type + gcn_layers = args.gcn_layers + gcn_use_bn = args.gcn_use_bn + show_details = args.show_details + patch = patch + num_nodes = data.graph['num_nodes'] + num_classes = data.label.max() + 1 + num_features = data.graph['node_feat'].shape[-1] + + label = F.one_hot(data.label, num_classes).float() + + # model = Beyondformer(num_nodes, num_features, num_hidden, num_classes, activation, + # layers=num_layers, gnn_layers=gcn_layers, n_head=n_head, alpha=alpha, ratio=ratio).to(device) + model = CoBFormer(num_nodes, num_features, num_hidden, num_classes, activation, layers=num_layers, + gcn_layers=gcn_layers, gcn_type=gcn_type, n_head=n_head, alpha=alpha, tau=tau, + gcn_use_bn=gcn_use_bn).to(device) + # print(model) + + optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay) + + patience = num_epochs + + # best_epoch, acc_val, acc_test = early_stop_train(num_epochs, patience, model, data, label, patch, patch_adj, + # optimizer, show_details, postfix) + # print("=== Train Final ===") + # print( + # f"best_epoch: {best_epoch}, acc_val: {acc_val}, acc_test: {acc_test}") + + res_gnn, res_trans = co_early_stop_train_batch( + num_epochs, patience, + model, data, label, + patch, batch_size, split_idx, + optimizer, show_details, + device, postfix) + print("=== Train Final ===") + print( + f'micro_val1={res_gnn[0]:.4f}, micro_test1={res_gnn[1]:.4f}, macro_val1={res_gnn[2]:.4f}, macro_test1={res_gnn[3]:.4f}, micro_best1={res_gnn[4]:.4f}, macro_best1={res_gnn[5]:.4f},\n', + f'micro_val2={res_trans[0]:.4f}, micro_test2={res_trans[1]:.4f}, macro_val2={res_trans[2]:.4f}, macro_test2={res_trans[3]:.4f}, micro_best2={res_trans[4]:.4f}, macro_best2={res_trans[5]:.4f}\n') + + return res_gnn, res_trans diff --git a/examples/cobformer/runs.sh b/examples/cobformer/runs.sh new file mode 100644 index 00000000..83be04b4 --- /dev/null +++ b/examples/cobformer/runs.sh @@ -0,0 +1,7 @@ +python cobformer_trainer.py --dataset=Cora --learning_rate=0.01 --gcn_wd=1e-3 --weight_decay=5e-5 --gcn_type=1 --gcn_layers=2 --n_patch=112 --use_patch_attn --alpha=0.7 --tau=0.3 --gpu_id=3 +python cobformer_trainer.py --dataset=CiteSeer --learning_rate=5e-3 --gcn_wd=1e-2 --weight_decay=5e-5 --gcn_type=1 --gcn_layers=2 --n_patch=144 --use_patch_attn --alpha=0.8 --tau=0.7 --gpu_id=3 +python cobformer_trainer.py --dataset=PubMed --learning_rate=5e-3 --gcn_wd=1e-3 --weight_decay=1e-3 --gcn_type=1 --gcn_layers=2 --n_patch=224 --use_patch_attn --alpha=0.7 --tau=0.3 --gpu_id=3 +python cobformer_trainer.py --dataset=film --learning_rate=5e-2 --gcn_wd=1e-4 --weight_decay=1e-3 --gcn_type=1 --gcn_layers=2 --n_patch=112 --use_patch_attn --alpha=0.7 --tau=0.9 --gpu_id=3 +python cobformer_trainer.py --dataset=deezer --learning_rate=0.01 --gcn_wd=1e-3 --weight_decay=5e-4 --gcn_type=1 --gcn_layers=2 --n_patch=224 --use_patch_attn --alpha=0.8 --tau=0.9 --gpu_id=3 +python cobformer_trainer.py --learning_rate=1e-3 --weight_decay=0. --dataset=ogbn-arxiv --gcn_use_bn --gcn_type=2 --gcn_layers=3 --n_patch=2048 --use_patch_attn --alpha=0.9 --tau=0.9 --gpu_id=3 +python cobformer_trainer.py --dataset=ogbn-products --learning_rate=5e-4 --weight_decay=0. --gcn_type=2 --gcn_layers=3 --gcn_use_bn --n_patch=8192 --use_patch_attn --batch_size=150000 --alpha=0.9 --tau=0.7 --gpu_id=5 \ No newline at end of file diff --git a/gammagl/layers/attention/BGA.py b/gammagl/layers/attention/BGA.py new file mode 100644 index 00000000..83dd2cb9 --- /dev/null +++ b/gammagl/layers/attention/BGA.py @@ -0,0 +1,35 @@ +from gammagl.models.ffn import * +from gammagl.models.gcn import * +from gammagl.layers.attention.BGA_layer import BGALayer +from torch_geometric.nn import GCNConv +import torch.nn as nn + + +class BGA(torch.nn.Module): + def __init__(self, num_nodes: int, in_channels: int, hidden_channels: int, out_channels: int, + layers: int, n_head: int, use_patch_attn=True, dropout1=0.5, dropout2=0.1, need_attn=False): + super(BGA, self).__init__() + self.layers = layers + self.n_head = n_head + self.num_nodes = num_nodes + self.dropout = nn.Dropout(dropout1) + self.attribute_encoder = FFN(in_channels, hidden_channels) + self.BGALayers = nn.ModuleList() + for _ in range(0, layers): + self.BGALayers.append( + BGALayer(n_head, hidden_channels, use_patch_attn, dropout=dropout2)) + self.classifier = nn.Linear(hidden_channels, out_channels) + self.attn=[] + + def forward(self, x: torch.Tensor, patch: torch.Tensor, need_attn=False): + patch_mask = (patch != self.num_nodes - 1).float().unsqueeze(-1) + attn_mask = torch.matmul(patch_mask, patch_mask.transpose(1, 2)).int() + + x = self.attribute_encoder(x) + for i in range(0, self.layers): + x = self.BGALayers[i](x, patch, attn_mask, need_attn) + if need_attn: + self.attn.append(self.BGALayers[i].attn) + x = self.dropout(x) + x = self.classifier(x) + return x diff --git a/gammagl/layers/attention/BGA_layer.py b/gammagl/layers/attention/BGA_layer.py new file mode 100644 index 00000000..2b8f41fc --- /dev/null +++ b/gammagl/layers/attention/BGA_layer.py @@ -0,0 +1,147 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from tqdm import tqdm + + +class ScaledDotProductAttention(nn.Module): + ''' Scaled Dot-Product Attention ''' + + def __init__(self, temperature, attn_dropout=0.1): + super(ScaledDotProductAttention, self).__init__() + self.temperature = temperature + self.dropout = nn.Dropout(attn_dropout) + # self.label_same_matrix = torch.load('analysis/label_same_matrix_citeseer.pt').float() + + def forward(self, q, k, v, mask=None): + attn = torch.matmul(q / self.temperature, k.transpose(2, 3)) + + if mask is not None: + attn = attn.masked_fill(mask == 0, -1e9) + # self.label_same_matrix = self.label_same_matrix.to(attn.device) + # attn = attn * self.label_same_matrix * 2 + attn * (1-self.label_same_matrix) + attn = self.dropout(F.softmax(attn, dim=-1)) + # attn = self.dropout(attn) + + output = torch.matmul(attn, v) + + return output, attn + + +class MultiHeadAttention(nn.Module): + ''' Multi-Head Attention module ''' + + def __init__(self, n_head, channels, dropout=0.1): + super(MultiHeadAttention, self).__init__() + + self.n_head = n_head + self.channels = channels + d_q = d_k = d_v = channels // n_head + + self.w_qs = nn.Linear(channels, channels, bias=False) + self.w_ks = nn.Linear(channels, channels, bias=False) + self.w_vs = nn.Linear(channels, channels, bias=False) + self.fc = nn.Linear(channels, channels, bias=False) + + self.attention = ScaledDotProductAttention(temperature=d_k ** 0.5) + + self.dropout = nn.Dropout(dropout) + + def forward(self, q, k, v, mask=None): + n_head = self.n_head + d_q = d_k = d_v = self.channels // n_head + B_q = q.size(0) + N_q = q.size(1) + B_k = k.size(0) + N_k = k.size(1) + B_v = v.size(0) + N_v = v.size(1) + + residual = q + # x = self.dropout(q) + + # Pass through the pre-attention projection: B * N x (h*dv) + # Separate different heads: B * N x h x dv + q = self.w_qs(q).view(B_q, N_q, n_head, d_q) + k = self.w_ks(k).view(B_k, N_k, n_head, d_k) + v = self.w_vs(v).view(B_v, N_v, n_head, d_v) + + # Transpose for attention dot product: B * h x N x dv + q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2) + + # For head axis broadcasting. + if mask is not None: + mask = mask.unsqueeze(1) + + q, attn = self.attention(q, k, v, mask=mask) + + # Transpose to move the head dimension back: B x N x h x dv + # Combine the last two dimensions to concatenate all the heads together: B x N x (h*dv) + q = q.transpose(1, 2).contiguous().view(B_q, N_q, -1) + q = self.fc(q) + q = q + residual + + return q, attn + + +class FFN(nn.Module): + ''' A two-feed-forward-layer module ''' + + def __init__(self, channels, dropout=0.1): + super(FFN, self).__init__() + self.lin1 = nn.Linear(channels, channels) # position-wise + self.lin2 = nn.Linear(channels, channels) # position-wise + self.layer_norm = nn.LayerNorm(channels, eps=1e-6) + self.Dropout = nn.Dropout(dropout) + + def forward(self, x): + residual = x + x = self.layer_norm(x) + x = self.Dropout(x) + x = F.relu(self.lin1(x)) + x = self.lin2(x) + residual + + return x + + +class BGALayer(nn.Module): + def __init__(self, n_head, channels, use_patch_attn=True, dropout=0.1): + super(BGALayer, self).__init__() + self.node_norm = nn.LayerNorm(channels) + self.node_transformer = MultiHeadAttention(n_head, channels, dropout) + self.patch_norm = nn.LayerNorm(channels) + self.patch_transformer = MultiHeadAttention(n_head, channels, dropout) + self.node_ffn = FFN(channels, dropout) + self.patch_ffn = FFN(channels, dropout) + self.fuse_lin = nn.Linear(2 * channels, channels) + self.use_patch_attn = use_patch_attn + self.attn = None + + def forward(self, x, patch, attn_mask=None, need_attn=False): + x = self.node_norm(x) + patch_x = x[patch] + patch_x, attn = self.node_transformer(patch_x, patch_x, patch_x, attn_mask) + patch_x = self.node_ffn(patch_x) + if need_attn: + self.attn = torch.zeros((x.shape[0], x.shape[0])) + for i in tqdm(range(patch.shape[0])): + p = patch[i].tolist() + row = torch.tensor([p] * len(p)).T.flatten() + col = torch.tensor(p * len(p)) + a = attn[i].mean(0).flatten().cpu() + self.attn = self.attn.index_put((row, col), a) + + self.attn = self.attn[:-1][:, :-1].detach().cpu() + + if self.use_patch_attn: + p = self.patch_norm(patch_x.mean(dim=1, keepdim=False)).unsqueeze(0) + p, _ = self.patch_transformer(p, p, p) + p = self.patch_ffn(p).permute(1, 0, 2) + # + p = p.repeat(1, patch.shape[1], 1) + z = torch.cat([patch_x, p], dim=2) + patch_x = F.relu(self.fuse_lin(z)) + patch_x + + x[patch] = patch_x + + return x diff --git a/gammagl/models/CoBFormer.py b/gammagl/models/CoBFormer.py new file mode 100644 index 00000000..65c1ad81 --- /dev/null +++ b/gammagl/models/CoBFormer.py @@ -0,0 +1,45 @@ +import torch + +from gammagl.models.ffn import * +from gammagl.models.gcn import * +from gammagl.layers.attention.BGA import BGA + + +class CoBFormer(torch.nn.Module): + def __init__(self, num_nodes: int, in_channels: int, hidden_channels: int, out_channels: int, + activation, gcn_layers: int, gcn_type: int, layers: int, n_head: int, dropout1=0.5, dropout2=0.1, + alpha=0.8, tau=0.5, gcn_use_bn=False, use_patch_attn=True): + super(CoBFormer, self).__init__() + self.alpha = alpha + self.tau = tau + self.layers = layers + self.n_head = n_head + self.num_nodes = num_nodes + self.activation = activation + self.dropout = nn.Dropout(dropout1) + if gcn_type == 1: + self.gcn = GCN(in_channels, hidden_channels, out_channels, activation, k=gcn_layers, use_bn=gcn_use_bn) + else: + self.gcn = GraphConv(in_channels, hidden_channels, out_channels, num_layers=gcn_layers, use_bn=gcn_use_bn) + # self.gat = GAT(in_channels, hidden_channels, out_channels, activation, k=gcn_layers, use_bn=gcn_use_bn) + self.bga = BGA(num_nodes, in_channels, hidden_channels, out_channels, layers, n_head, + use_patch_attn, dropout1, dropout2) + self.attn = None + + def forward(self, x: torch.Tensor, patch: torch.Tensor, edge_index: torch.Tensor, need_attn=False): + z1 = self.gcn(x, edge_index) + z2 = self.bga(x, patch, need_attn) + if need_attn: + self.attn = self.beyondformer.attn + + return z1, z2 + + def loss(self, pred1, pred2, label, mask): + l1 = F.cross_entropy(pred1[mask], label[mask]) + l2 = F.cross_entropy(pred2[mask], label[mask]) + pred1 *= self.tau + pred2 *= self.tau + l3 = F.cross_entropy(pred1[~mask], F.softmax(pred2, dim=1)[~mask]) + l4 = F.cross_entropy(pred2[~mask], F.softmax(pred1, dim=1)[~mask]) + loss = self.alpha * (l1 + l2) + (1 - self.alpha) * (l3 + l4) + return loss \ No newline at end of file diff --git a/gammagl/models/__init__.py b/gammagl/models/__init__.py index 0d066889..b137a652 100644 --- a/gammagl/models/__init__.py +++ b/gammagl/models/__init__.py @@ -68,6 +68,7 @@ from .adagad import PreModel, ReModel from .dyfss import MoeSSL,VGAE,Discriminator,InnerProductDecoder from .egt import EGTModel +from .CoBFormer import CoBFormer __all__ = [ 'HeCo', @@ -149,6 +150,7 @@ 'Discriminator', 'InnerProductDecoder', 'EGTModel', + 'CoBFormer', ] classes = __all__ diff --git a/gammagl/models/ffn.py b/gammagl/models/ffn.py new file mode 100644 index 00000000..b5219c3e --- /dev/null +++ b/gammagl/models/ffn.py @@ -0,0 +1,36 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + + +def sparse_dropout(x: torch.Tensor, p: float, training: bool): + x = x.coalesce() + return torch.sparse_coo_tensor(x.indices(), F.dropout(x.values(), p=p, training=training), + size=x.size()) + + +class FFN(torch.nn.Module): + def __init__(self, in_channels: int, hidden_channels: int): + super(FFN, self).__init__() + self.lin1 = nn.Linear(in_channels, hidden_channels) + self.lin2 = nn.Linear(hidden_channels, hidden_channels) + self.dropout = nn.Dropout(0.5) + + def forward(self, x): + # x = self.dropout(x) + x = F.relu(self.lin1(x)) + return x + # return self.lin2(x) + + +class SparseFFN(torch.nn.Module): + def __init__(self, in_channels: int, hidden_channels: int): + super(SparseFFN, self).__init__() + self.lin1 = nn.Linear(in_channels, hidden_channels) + self.lin2 = nn.Linear(hidden_channels, hidden_channels) + self.dropout = nn.Dropout(0.5) + + def forward(self, x): + x = sparse_dropout(x, 0.5, self.training) + x = F.relu(self.lin1(x)) + return self.lin2(x) + x diff --git a/gammagl/models/gcn.py b/gammagl/models/gcn.py index cbed577f..05a6dbd4 100644 --- a/gammagl/models/gcn.py +++ b/gammagl/models/gcn.py @@ -1,6 +1,13 @@ import tensorlayerx as tlx import tensorlayerx.nn as nn from gammagl.layers.conv import GCNConv +import torch +from torch_geometric.nn import GCNConv as tGCNConv, GATConv +import torch.nn as tnn +import torch.nn.functional as F +from torch_sparse import SparseTensor, matmul +from torch_geometric.utils import degree + class GCNModel(tlx.nn.Module): @@ -63,3 +70,150 @@ def forward(self, x, edge_index, edge_weight, num_nodes): # x = self.conv2(x, edge_index, edge_weight, num_nodes) return x +class GCN(torch.nn.Module): + def __init__(self, in_channels: int, hidden_channels: int, out_channels: int, + activation, k: int = 2, use_bn=False): + super(GCN, self).__init__() + assert k > 1, "k must > 1 !!" + self.use_bn = use_bn + self.k = k + self.conv = tnn.ModuleList([tGCNConv(in_channels, hidden_channels)]) + self.bns = tnn.ModuleList([tnn.BatchNorm1d(hidden_channels)]) + for _ in range(1, k - 1): + self.conv.append(tGCNConv(hidden_channels, hidden_channels)) + self.bns.append(tnn.BatchNorm1d(hidden_channels)) + self.conv.append(tGCNConv(hidden_channels, out_channels)) + # self.conv.append(tGCNConv(hidden_channels, hidden_channels)) + if activation is None: + self.activation = F.relu + else: + self.activation = activation + + def forward(self, x: torch.Tensor, edge_index: torch.Tensor): + for i in range(self.k - 1): + # x = F.dropout(x, p=0.5, training=self.training) + x = self.conv[i](x, edge_index) + if self.use_bn: + x = self.bns[i](x) + x = self.activation(x) + x = F.dropout(x, p=0.5, training=self.training) + return self.conv[-1](x, edge_index) + + +class GraphConvLayer(tnn.Module): + def __init__(self, in_channels, out_channels, use_weight=True, use_init=False): + super(GraphConvLayer, self).__init__() + + self.use_init = use_init + self.use_weight = use_weight + if self.use_init: + in_channels_ = 2 * in_channels + else: + in_channels_ = in_channels + self.W = tnn.Linear(in_channels_, out_channels) + + def reset_parameters(self): + self.W.reset_parameters() + + def forward(self, x, edge_index, x0): + N = x.shape[0] + row, col = edge_index + d = degree(col, N).float() + d_norm_in = (1. / d[col]).sqrt() + d_norm_out = (1. / d[row]).sqrt() + value = torch.ones_like(row) * d_norm_in * d_norm_out + value = torch.nan_to_num(value, nan=0.0, posinf=0.0, neginf=0.0) + adj = SparseTensor(row=col, col=row, value=value, sparse_sizes=(N, N)) + x = matmul(adj, x) # [N, D] + + if self.use_init: + x = torch.cat([x, x0], 1) + x = self.W(x) + elif self.use_weight: + x = self.W(x) + + return x + + +class GraphConv(tnn.Module): + def __init__(self, in_channels, hidden_channels, out_channels, num_layers=2, dropout=0.5, use_bn=True, + use_residual=True, use_weight=True, use_init=False, use_act=True): + super(GraphConv, self).__init__() + + self.convs = tnn.ModuleList() + self.fcs = tnn.ModuleList() + self.fcs.append(tnn.Linear(in_channels, hidden_channels)) + + self.bns = tnn.ModuleList() + self.bns.append(tnn.BatchNorm1d(hidden_channels)) + for _ in range(num_layers): + self.convs.append( + GraphConvLayer(hidden_channels, hidden_channels, use_weight, use_init)) + self.bns.append(tnn.BatchNorm1d(hidden_channels)) + self.classifier = tnn.Linear(hidden_channels, out_channels) + self.dropout = dropout + self.activation = F.relu + self.use_bn = use_bn + self.use_residual = use_residual + self.use_act = use_act + + def reset_parameters(self): + for conv in self.convs: + conv.reset_parameters() + for bn in self.bns: + bn.reset_parameters() + for fc in self.fcs: + fc.reset_parameters() + + def forward(self, x, edge_index): + layer_ = [] + + x = self.fcs[0](x) + if self.use_bn: + x = self.bns[0](x) + x = self.activation(x) + x = F.dropout(x, p=self.dropout, training=self.training) + + layer_.append(x) + + for i, conv in enumerate(self.convs): + x = conv(x, edge_index, layer_[0]) + if self.use_bn: + x = self.bns[i + 1](x) + if self.use_act: + x = self.activation(x) + x = F.dropout(x, p=self.dropout, training=self.training) + if self.use_residual: + x = x + layer_[-1] + # layer_.append(x) + return self.classifier(x) + # return x + + +class GAT(torch.nn.Module): + def __init__(self, in_channels: int, hidden_channels: int, out_channels: int, + activation, n_heads=8, k: int = 2, use_bn=False): + super(GAT, self).__init__() + assert k > 1, "k must > 1 !!" + self.use_bn = use_bn + self.k = k + self.conv = tnn.ModuleList([GATConv(in_channels, hidden_channels// n_heads, heads=n_heads, dropout=0.6)]) + self.bns = tnn.ModuleList([tnn.BatchNorm1d(hidden_channels)]) + for _ in range(1, k - 1): + self.conv.append(GATConv(hidden_channels, hidden_channels // n_heads, heads=n_heads, dropout=0.6)) + self.bns.append(tnn.BatchNorm1d(hidden_channels)) + self.conv.append(GATConv(hidden_channels, out_channels, heads=8, concat=False, dropout=0.6)) + # self.conv.append(tGCNConv(hidden_channels, hidden_channels)) + self.activation = F.relu + + def forward(self, x: torch.Tensor, edge_index: torch.Tensor): + for i in range(self.k - 1): + x = F.dropout(x, p=0.6, training=self.training) + x = self.conv[i](x, edge_index) + if self.use_bn: + x = self.bns[i](x) + x = self.activation(x) + x = F.dropout(x, p=0.6, training=self.training) + return self.conv[-1](x, edge_index) + +