Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
258 changes: 258 additions & 0 deletions xpotato/dataset/dataset.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,18 @@
import json
import re
from re import I
from typing import Dict, List, Tuple
import spacy
import amrlib

import networkx as nx
import pandas as pd
from networkx.readwrite import json_graph
from tqdm import tqdm
from tuw_nlp.graph.utils import check_if_str_is_penman, graph_to_pn
from tuw_nlp.grammar.text_to_4lang import TextTo4lang
from tuw_nlp.grammar.text_to_ud import TextToUD
from tuw_nlp.graph.amr_graph import AMRGraph

from xpotato.dataset.sample import Sample
from xpotato.graph_extractor.extract import GraphExtractor
Expand Down Expand Up @@ -44,6 +50,258 @@ def save_dataframe(df: pd.DataFrame, path: str) -> None:
df["graph"] = graphs
df.to_csv(path, index=False, sep="\t")

@staticmethod
def generate_dataframe_ud_fl(
    df: pd.DataFrame,
    sentence_colname: str = "",
    label_id_colname: str = "",
    sentence_id_colname: str = "",
) -> pd.DataFrame:
    """
    Submit a dataframe and compute a response dataframe with entity-marked UD and 4lang (FL) graphs.

    Expects columns b1,e1 and b2,e2 to provide begin/end character positions of the respective
    entities 1 and 2. Every sentence is parsed with TextToUD and TextTo4lang; each UD token that
    falls into an entity span gets ``entity`` = 1 or 2 set on its UD node and, when the FL graph
    has a node with the same token id, on that FL node as well.

    :param df: The dataframe containing the input sentence and all necessary conversion information
    :param sentence_colname: which column is the sentence in
    :param label_id_colname: which column is the label in
    :param sentence_id_colname: which column is the unique sentence id in
    :return: a new dataframe with the columns SID, text, label_id, ud, fl (the tagged graphs),
        report_fl (a textual report of which entity tokens could be mapped to FL nodes) and
        e1_found_fl / e2_found_fl (fraction of entity tokens mapped, 0.0 if the entity has no tokens)
    """

    def is_allcaps(text):
        # "All caps" here means "contains no lowercase ASCII letter", so digits and
        # punctuation-only tokens qualify too.
        return not re.search(r'[a-z]', text)

    def in_entity(start, end, begin, stop, text):
        # A token belongs to an entity if it lies completely inside the entity bounds, OR if
        # only one of its ends does and the token is ALLCAPS. The second condition is necessary
        # because the CrowdTruth dataset gives us inaccurate indexes.
        if start >= begin and end <= stop:
            return True
        return (begin <= start <= stop or begin <= end <= stop) and is_allcaps(text)

    extractor = GraphExtractor(lang="en", cache_dir=None, cache_fn=None)
    parser_kwargs = {
        "lang": extractor.lang,
        "nlp_cache": extractor.cache_fn,
        "cache_dir": extractor.cache_dir,
    }
    ud_parser = TextToUD(**parser_kwargs)
    fl_parser = TextTo4lang(**parser_kwargs)

    # Read each column once; iterating rows three times is slow and iterrows()
    # does not preserve the column dtypes.
    sentences = df[sentence_colname].tolist()
    label_ids = df[label_id_colname].tolist()
    sentence_ids = df[sentence_id_colname].tolist()

    ud_graph_list = []
    fl_graphs = []
    reports = []
    e1_found = []
    e2_found = []

    for i, sent in enumerate(sentences):
        # The entity bounds are constant for the whole sentence, so fetch them once
        # instead of rebuilding the row for every token.
        row = df.iloc[i]
        bounds = {1: (row.b1, row.e1), 2: (row.b2, row.e2)}

        # For each graph keep track of which entity tokens are linked to nodes,
        # and which couldn't be found. All dicts are keyed by entity number.
        total = {1: 0, 2: 0}
        found = {1: 0, 2: 0}
        mapped = {1: {}, 2: {}}
        missing = {1: [], 2: []}

        # Parse UD
        ud_graphs = list(ud_parser(sent))
        last = len(ud_graphs) - 1

        # If there are more than one sentences, take the last one. This seems to discard
        # only "et al." introductions, which are irrelevant
        ud_graph = ud_graphs[last]
        fl_graph = list(fl_parser(sent))[last]

        # 4lang graph nodes have a different index but the same token IDs as the UD graph
        # they are built from, so we need this mapping
        node_ids = {
            fl_graph.G.nodes[idx]['token_id']: idx for idx in fl_graph.G.nodes
        }

        # Go through the tokens and check for each token if it is part of a relationship
        # entity. If yes, mark the node. Entity 1 takes precedence over entity 2.
        for token in ud_graph.ud_graph.tokens:
            if in_entity(token.start_char, token.end_char, *bounds[1], token.text):
                entity = 1
            elif in_entity(token.start_char, token.end_char, *bounds[2], token.text):
                entity = 2
            else:
                continue

            # Tag UD
            token_id = token.id[0]
            ud_graph.G.nodes[token_id]["entity"] = entity

            # If the FL graph has a node with the same ID, we can tag it as well
            total[entity] += 1
            if token_id in node_ids:
                fl_node = fl_graph.G.nodes[node_ids[token_id]]
                fl_node["entity"] = entity
                found[entity] += 1
                mapped[entity][token.text] = fl_node
            else:
                missing[entity].append(token.text)

        ud_graph_list.append(ud_graph.G)
        fl_graphs.append(fl_graph.G)

        # Build a report of missing or correctly mapped tokens for the FL graph
        reports.append(
            f"Entity 1: Found {found[1]} / {total[1]} token nodes\n"
            f"{mapped[1]}\n"
            f"Not found: {missing[1]}\n\n"
            f"Entity 2: Found {found[2]} / {total[2]} token nodes\n"
            f"{mapped[2]}\n"
            f"Not found: {missing[2]}\n\n"
        )
        e1_found.append(0.0 if total[1] == 0 else found[1] / total[1])
        e2_found.append(0.0 if total[2] == 0 else found[2] / total[2])

    df_parsed = pd.DataFrame(
        {
            "SID": sentence_ids,
            "text": sentences,
            "label_id": label_ids,
            "ud": ud_graph_list,
            "fl": fl_graphs,
            "report_fl": reports,
            "e1_found_fl": e1_found,
            "e2_found_fl": e2_found,
        }
    )

    return df_parsed

@staticmethod
def generate_dataframe_amr(
    df: pd.DataFrame,
    sentence_colname: str = "",
    label_id_colname: str = "",
    sentence_id_colname: str = "",
) -> pd.DataFrame:
    """
    Submit a dataframe and compute a response dataframe with entity-marked AMR graphs.

    Expects columns b1,e1 and b2,e2 to provide begin/end character positions of the respective
    entities 1 and 2. Every sentence is parsed with the amrlib sentence-to-graph model and wrapped
    in an AMRGraph; each token that falls into an entity span gets ``entity`` = 1 or 2 set on the
    AMR node aligned to it, if there is one.

    :param df: The dataframe containing the input sentence and all necessary conversion information
    :param sentence_colname: which column is the sentence in
    :param label_id_colname: which column is the label in
    :param sentence_id_colname: which column is the unique sentence id in
    :return: a new dataframe with the columns SID, text, label_id, amr (the tagged graphs),
        report_amr (a textual report of which entity tokens could be mapped to AMR nodes) and
        e1_found_amr / e2_found_amr (fraction of entity tokens mapped, 0.0 if the entity has no
        tokens)
    """

    def is_allcaps(text):
        # "All caps" here means "contains no lowercase ASCII letter", so digits and
        # punctuation-only tokens qualify too.
        return not re.search(r'[a-z]', text)

    def in_entity(start, end, begin, stop, text):
        # A token belongs to an entity if it lies completely inside the entity bounds, OR if
        # only one of its ends does and the token has no lowercase letters.
        if start >= begin and end <= stop:
            return True
        return (begin <= start <= stop or begin <= end <= stop) and is_allcaps(text)

    # Read each column once; iterating rows three times is slow and iterrows()
    # does not preserve the column dtypes.
    sentences = df[sentence_colname].tolist()
    label_ids = df[label_id_colname].tolist()
    sentence_ids = df[sentence_id_colname].tolist()

    amr_graph_list = []
    reports = []
    e1_found = []
    e2_found = []

    # The original conversions were produced with the amrlib 0.8.0 xfm bart large model
    amr_stog = amrlib.load_stog_model()
    # Load the same spacy model that TUW NLP uses to build the graphs to get a comparable mapping
    # NOTE(review): this tokenizes every sentence a second time. It is needed because AMRGraph
    # currently provides no link from a token to its character position in the source text,
    # which is how the entities are marked. If AMRGraph is extended to expose that mapping,
    # the spaCy pass here should be dropped in favour of it.
    spacy_nlp = spacy.load('en_core_web_sm')

    for i, sent in enumerate(sentences):
        # The entity bounds are constant for the whole sentence, so fetch them once
        # instead of rebuilding the row for every token.
        row = df.iloc[i]
        bounds = {1: (row.b1, row.e1), 2: (row.b2, row.e2)}

        # For each graph keep track of which entity tokens are linked to nodes,
        # and which couldn't be found. All dicts are keyed by entity number.
        total = {1: 0, 2: 0}
        found = {1: 0, 2: 0}
        mapped = {1: {}, 2: {}}
        missing = {1: [], 2: []}

        # Parse AMR
        pn_graphs = amr_stog.parse_sents([sent])
        amr_graph = AMRGraph(pn_graphs[0], sent)

        # Map the nodes to their respective tokens; nodes without a token alignment are skipped
        token_to_node = {}
        for idx in amr_graph.G.nodes:
            t_id = amr_graph.G.nodes[idx]['token_id']
            if t_id is not None:
                token_to_node[t_id] = idx

        # Unlike UD we don't get a neat mapping between token ID and the start/end characters.
        # So we redo the original spacy AMR conversion to get the token character info from there.
        doc = spacy_nlp(sent)
        indices = [(t.idx, t.idx + len(t)) for t in doc]

        # assumes amr_graph.tokens is a JSON list aligned one-to-one with the spaCy tokens
        # above - TODO confirm; a longer spaCy tokenization raises IndexError below
        tokens = json.loads(amr_graph.tokens)

        # Go through the tokens and check for each token if it is an entity. If yes, mark the
        # node. Entity 1 takes precedence over entity 2, matching generate_dataframe_ud_fl, so
        # a token is never counted for both entities or silently re-tagged.
        for token_num, (start_char, end_char) in enumerate(indices):
            tok = tokens[token_num]
            if in_entity(start_char, end_char, *bounds[1], tok):
                entity = 1
            elif in_entity(start_char, end_char, *bounds[2], tok):
                entity = 2
            else:
                continue

            total[entity] += 1
            if token_num in token_to_node:
                amr_node = amr_graph.G.nodes[token_to_node[token_num]]
                amr_node["entity"] = entity
                found[entity] += 1
                mapped[entity][tok] = amr_node
            else:
                missing[entity].append(tok)

        amr_graph_list.append(amr_graph.G)

        # Build a report of missing or correctly mapped tokens for the AMR graph
        reports.append(
            f"Entity 1: Found {found[1]} / {total[1]} token nodes\n"
            f"{mapped[1]}\n"
            f"Not found: {missing[1]}\n\n"
            f"Entity 2: Found {found[2]} / {total[2]} token nodes\n"
            f"{mapped[2]}\n"
            f"Not found: {missing[2]}\n\n"
        )
        e1_found.append(0.0 if total[1] == 0 else found[1] / total[1])
        e2_found.append(0.0 if total[2] == 0 else found[2] / total[2])

    df_parsed = pd.DataFrame(
        {
            "SID": sentence_ids,
            "text": sentences,
            "label_id": label_ids,
            "amr": amr_graph_list,
            "report_amr": reports,
            "e1_found_amr": e1_found,
            "e2_found_amr": e2_found,
        }
    )

    return df_parsed

def prune_graphs(self, graphs: List[nx.DiGraph] = None) -> None:
graphs_str = []
for i, graph in enumerate(graphs):
Expand Down
10 changes: 9 additions & 1 deletion xpotato/dataset/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import json
import re
from collections import defaultdict

import networkx as nx
Expand Down Expand Up @@ -43,8 +44,10 @@ def default_pn_to_graph(raw_dl, edge_attr="color"):
G = nx.DiGraph()

char_to_id = defaultdict(int)
char_to_entity = dict()
next_id = 0
for i, trip in enumerate(g.triples):
#print(f"Potato: {trip}")
if i == 0:
root_id = next_id
name = trip[2]
Expand All @@ -54,6 +57,9 @@ def default_pn_to_graph(raw_dl, edge_attr="color"):
char_to_id[trip[0]] = next_id
next_id += 1

elif re.match("entity", trip[0].split('_')[0]):
char_to_entity[trip[0]] = trip[2]

elif trip[1] == ":instance":
if trip[2]:
name = trip[2]
Expand All @@ -64,7 +70,9 @@ def default_pn_to_graph(raw_dl, edge_attr="color"):
next_id += 1

for trip in g.triples:
if trip[1] != ":instance":
if re.match(r":entity", trip[1]):
G.nodes[char_to_id[trip[0]]]["entity"] = int(char_to_entity[trip[2]])
elif trip[1] != ":instance":
edge = trip[1].split(":")[1]
src = trip[0]
tgt = trip[2]
Expand Down
1 change: 0 additions & 1 deletion xpotato/graph_extractor/extract.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@

from xpotato.dataset.utils import amr_pn_to_graph, default_pn_to_graph, ud_to_graph


class GraphExtractor:
def __init__(self, cache_dir=None, cache_fn=None, lang=None):
if cache_dir is None:
Expand Down