Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
104 changes: 104 additions & 0 deletions examples/graphtranslator/Producer/Embeddings_GraphSAGE.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
import time
from ogb.nodeproppred import PygNodePropPredDataset
import torch

from gammagl.data import Graph
from gammagl.loader import LinkNeighborLoader
from gammagl.layers.conv import SAGEConv
import tensorlayerx.nn as nn
import tensorlayerx as tlx
from tensorlayerx.model import WithLoss, TrainOneStep
import os
os.environ['TL_BACKEND'] = "torch"

bert_node_embeddings = torch.load("../data/bert_node_embeddings.pt")

dataset = PygNodePropPredDataset(name='ogbn-arxiv', root='./arxiv/')

edge_index = dataset[0].edge_index
row, col = edge_index[0], edge_index[1]

src_node = tlx.concat((tlx.convert_to_tensor(row), tlx.convert_to_tensor(col)), axis=0)
dst_node = tlx.concat((tlx.convert_to_tensor(col), tlx.convert_to_tensor(row)), axis=0)
edge_index = tlx.stack((src_node, dst_node), axis=0)


split_idx = dataset.get_idx_split()
train_idx, valid_idx, test_idx = split_idx['train'], split_idx['valid'], split_idx['test']

graph = Graph(x=bert_node_embeddings, edge_index=edge_index, train_idx=train_idx, valid_idx=valid_idx, test_idx=test_idx, y=dataset[0].y.squeeze())
train_loader = LinkNeighborLoader(
graph,
batch_size=65536,
shuffle=True,
neg_sampling_ratio=1.0,
num_neighbors=[10, 10],
edge_label_index=graph.edge_index,
edge_label=None
)


class Net(nn.Module):
def __init__(self, in_dim, hid_dim, out_dim):
super(Net, self).__init__()
self.conv1 = SAGEConv(in_channels=in_dim,
out_channels=hid_dim)
self.conv2 = SAGEConv(in_channels=hid_dim,
out_channels=out_dim)
self.act = nn.ReLU()
self.dropout = nn.Dropout(p=0.5)

def forward(self, x, edge_index):
x = self.conv1(x, edge_index)
x = self.act(x)
x = self.dropout(x)
x = self.conv2(x, edge_index)
return x


device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
model = Net(768, 1024, 768).to(device)

class LinkPredictionLoss(WithLoss):
def __init__(self, net, loss_fn):
super(LinkPredictionLoss, self).__init__(backbone=net, loss_fn=loss_fn)

def forward(self, data, label):
h = self._backbone(data['x'], data['edge_index'])
h_src = tlx.gather(h, data['edge_label_index'][0])
h_dst = tlx.gather(h, data['edge_label_index'][1])
pred = tlx.reduce_sum(h_src * h_dst, axis=-1)
loss = self._loss_fn(output=pred, target=label)
return loss

def train():

optimizer = tlx.optimizers.Adam(0.01)
loss_func = LinkPredictionLoss(model, tlx.losses.sigmoid_cross_entropy)
train_one_step = TrainOneStep(loss_func, optimizer, model.trainable_weights)

total_loss = 0
total_num = 0
for batch in train_loader:
data = {'x': batch.x.to(device),
'edge_index': batch.edge_index.to(device),
'edge_label_index': batch.edge_label_index.to(device)
}
model.set_train()
loss = train_one_step(data, batch.edge_label.to(device))
total_loss += float(loss) * batch.edge_label.shape[0]
total_num += batch.edge_label.shape[0]

return total_loss/ total_num


best_acc = 0
for epoch in range(10):
start = time.time()
loss = train()
print("loss:", loss)

out = model(graph.x, graph.edge_index)
if os.environ['TL_BACKEND'] == "torch":
torch.save(out, "../../data/graphsage_node_embeddings.pt")

194 changes: 194 additions & 0 deletions examples/graphtranslator/Producer/producer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,194 @@
import random
import csv
import argparse
import time
import logging
import os
import sys
import pandas as pd
import numpy as np
import torch
import tensorlayerx.nn as nn
import tensorlayerx as tlx
from transformers import AutoTokenizer, AutoModel




def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument('--llm_checkpoint', type=str, default="../Translator/models/chatglm2-6b", required=False)
parser.add_argument("--batch_size", type=int, default=1, help="Batch size")
parser.add_argument("--distributed", action='store_const', default=False, const=True)
parser.add_argument('--random_seed', type=int, default=42, help="random seed for initialization")
parser.add_argument("--num_workers", default=1, type=int)

return parser.parse_args()

def init_seeds(distributed, seed=0):
tlx.set_seed(seed)
random.seed(seed)
np.random.seed(seed)


if seed == 0:
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

def setup_logging():
logging_formatter = logging.Formatter("%(asctime)s-%(levelname)s-%(message)s")
# Setup common logger
root = logging.getLogger()
root.setLevel(logging.INFO)

handler = logging.StreamHandler(sys.stderr)
handler.setLevel(logging.INFO)
handler.setFormatter(logging_formatter)
root.addHandler(handler)

args = parse_args()

if torch.cuda.is_available():
torch.backends.cudnn.benchmark = True
device = torch.device('cuda:0')
else:
device = torch.device('cpu')

def read_arxiv_dataset():
# paperid to node的映射和node to paperid的映射
node2paperid = {}
paperid2node = {}
with open('../data/arxiv_nodeidx2paperid.csv', 'r') as file:
reader = csv.reader(file)
next(reader)
for row in reader:
nodeIdx = int(row[0])
paperId = int(row[1])
node2paperid[nodeIdx] = paperId
paperid2node[paperId] = nodeIdx

# 读取paperId到title和abstract映射的内容
paperId2titleAndabs = pd.read_csv("../data/titleabs.tsv", delimiter='\t', header=None)
paperId2titleAndabs = paperId2titleAndabs.rename(columns={0: "paper_id", 1: "title", 2: "abstract"})
paperId2titleAndabs['node_id'] = paperId2titleAndabs['paper_id'].map(paperid2node).fillna(-1).astype(int)
paperId2titleAndabs["title_abstract"] = "Title: " + paperId2titleAndabs["title"] + "\n" +"Abstract: " + paperId2titleAndabs["abstract"]
paperId2titleAndabs = paperId2titleAndabs[paperId2titleAndabs['node_id'] != -1]

paperId2titleAndabs = paperId2titleAndabs.replace('≤', '', regex=True)
paperId2titleAndabs = paperId2titleAndabs.replace('≥', '', regex=True)
paperId2titleAndabs = paperId2titleAndabs.replace('≠', '', regex=True)
paperId2titleAndabs = paperId2titleAndabs.replace('≠', '', regex=True)
paperId2titleAndabs = paperId2titleAndabs.replace('∫', '', regex=True)
paperId2titleAndabs = paperId2titleAndabs.replace('∞', '', regex=True)
paperId2titleAndabs = paperId2titleAndabs.replace('√', '', regex=True)

sorted_paperId2titleAndabs = paperId2titleAndabs.sort_values(by='node_id')
sample_neighbor_df = pd.read_csv("../data/sample_neighbor_df.csv")

return sorted_paperId2titleAndabs, sample_neighbor_df


class LLM(nn.Module):
def __init__(self, args, **kwargs):
super().__init__()
self._args = args
# tokenizer
self.tokenizer = AutoTokenizer.from_pretrained(self._args.llm_checkpoint, trust_remote_code=True)
# model
self.llm = AutoModel.from_pretrained(self._args.llm_checkpoint, trust_remote_code=True).half().to(device)

def inference_chatglm_arxiv(self, arxiv_data, sample_neighbor_df):
self.llm.eval()

node_title_and_abs = arxiv_data.set_index('node_id')['title_abstract'].to_dict()
src_to_dst_dict = sample_neighbor_df.groupby('src_node')['dst_node'].apply(list).to_dict()
node2title = arxiv_data.set_index('node_id')['title'].to_dict()

print(f"{time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time()))} total paper count: {arxiv_data.shape[0]}")
summary = []
total = 0
for data in arxiv_data.iterrows():
node_id = data[1]['node_id']
title = data[1]['title']
src_prompt_pre = "The title and abstract of this paper are as follows: "
src_prompt = '\n please summarize this paper and list five key words of this paper. All answers are in English and No Chinese in your answer'
src_title_abstract = data[1]['title_abstract']
node_word_input = src_prompt_pre + src_title_abstract
if len(node_word_input[0]) > 3000- len(src_prompt):
node_word_input = node_word_input[:3000-len(src_prompt)]
node_word_input += src_prompt

dst_prompt_pre = '\n The paper title and abstract are provided as follows: '
dst_prompt = "\n Please summarize the topic and content of these papers. All answers are in English and No Chinese in your answer"
dst_title_abstract = ""
for neighbor_id in src_to_dst_dict[node_id]:
dst_title_abstract = dst_title_abstract + node_title_and_abs[neighbor_id] + '\n'

neighbor_word_input = dst_prompt_pre + dst_title_abstract
if len(neighbor_word_input[0]) > 3000-len(dst_prompt):
neighbor_word_input = neighbor_word_input[:3000-len(dst_prompt)]
neighbor_word_input += dst_prompt

try:
response_node, _ = self.llm.chat(self.tokenizer,
node_word_input ,
history=[])
response_neighbor, _ = self.llm.chat(self.tokenizer,
neighbor_word_input,
history=[])
summary.append({
'node_id': node_id,
'title': title,
'response_node': response_node,
'response_neighbor': response_neighbor
})
print(f"{time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time()))} paper {node_id+1} title: \"{title}\"")
except RuntimeError as exception:
if "out of memory" in str(exception):
print("CUDA out of memory error detected, skipping this batch")
continue
else:
continue
total += 1
if total == 6480:
break

summary_df = pd.DataFrame(summary)
embeddings = torch.load("../data/graphsage_node_embeddings.pt").to('cpu')
new_data = []
for _, row in summary_df.iterrows():
node_id = int(row['node_id'])
embedding = np.array(embeddings[node_id].detach())
str_array = [str(num) for num in embedding]
str_representation = ", ".join(str_array)
title = node2title[row['node_id']]

new_data.append({
'node_id': node_id,
'embedding':str_representation ,
'paper_summary':row['response_node'],
'citepapers_summary':row['response_neighbor'],
'title':title
})
summary_embeddings = pd.DataFrame(new_data)
summary_embeddings.to_csv('../../data/summary_embeddings_0.csv',index=False)


def main():
setup_logging()
init_seeds(args.distributed, args.random_seed)

logging.info("Main arguments:")
for k, v in args.__dict__.items():
logging.info("{}={}".format(k, v))


# load model
model = LLM(args)
logging.info('start inference')
arxiv_data, sample_neighbor_df = read_arxiv_dataset()
model.inference_chatglm_arxiv(arxiv_data, sample_neighbor_df)


if __name__ == "__main__":
main()
Loading