From b8e602b8a74ffad2b3da9f0c691aa44801c9989a Mon Sep 17 00:00:00 2001 From: Timur Khairulov Date: Fri, 13 Dec 2024 14:15:33 +0900 Subject: [PATCH] In this commit 2 changes have been introduced: 1. Due to generated sequence length limitation, errors occured during answer parsing stage. So, additional cheks were added to avoid unexpected parsing errors. 2. As mentioned in https://github.com/Yushi-Hu/tifa/issues/5 change has been made for correct dictionary handling. --- tifascore/question_gen_llama2.py | 16 ++++++++++------ tifascore/tifa_score.py | 3 ++- 2 files changed, 12 insertions(+), 7 deletions(-) diff --git a/tifascore/question_gen_llama2.py b/tifascore/question_gen_llama2.py index ff2b40f..fe0cdde 100644 --- a/tifascore/question_gen_llama2.py +++ b/tifascore/question_gen_llama2.py @@ -56,16 +56,20 @@ def parse_resp(resp): for line_number in range(6, len(resp)): line = resp[line_number] if line.startswith('About '): - whole_line = line[len('About '):-1] - this_entity = whole_line.split(' (')[0] - this_type = whole_line.split(' (')[1].split(')')[0] + if '(' in line and ')' in line: + whole_line = line[len('About '):-1] + this_entity = whole_line.split(' (')[0] + this_type = whole_line.split(' (')[1].split(')')[0] elif line.startswith('Q: '): - this_question = line[3:] + if len(line) > 3: + this_question = line[3:] elif line.startswith('Choices: '): - this_choices = line[9:].split(', ') + if len(line) > 9 and ',' in line: + this_choices = line[9:].split(', ') elif line.startswith('A: '): - this_answer = line[3:] + if len(line) > 3: + this_answer = line[3:] if this_entity and this_question and this_choices: question_instances.append( diff --git a/tifascore/tifa_score.py b/tifascore/tifa_score.py index 2109e24..da4f690 100644 --- a/tifascore/tifa_score.py +++ b/tifascore/tifa_score.py @@ -3,6 +3,7 @@ from tqdm import tqdm from .vqa_models import VQAModel from statistics import mean, stdev +import copy def tifa_score_benchmark(vqa_model_name, question_answer_path, id2img_path): @@ -85,7 +86,7 @@ def tifa_score_single(vqa_model, question_answer_pairs, img_path): # read the question, choices, and answers if question_answer_pair['question'] not in question_logs: - question_logs[question_answer_pair['question']] = question_answer_pair + question_logs[question_answer_pair['question']] = copy.deepcopy(question_answer_pair) choices=question_answer_pair['choices'] # get VQA answer