diff --git a/tifascore/question_gen_llama2.py b/tifascore/question_gen_llama2.py index ff2b40f..fe0cdde 100644 --- a/tifascore/question_gen_llama2.py +++ b/tifascore/question_gen_llama2.py @@ -56,16 +56,20 @@ def parse_resp(resp): for line_number in range(6, len(resp)): line = resp[line_number] if line.startswith('About '): - whole_line = line[len('About '):-1] - this_entity = whole_line.split(' (')[0] - this_type = whole_line.split(' (')[1].split(')')[0] + if '(' in line and ')' in line: + whole_line = line[len('About '):-1] + this_entity = whole_line.split(' (')[0] + this_type = whole_line.split(' (')[1].split(')')[0] elif line.startswith('Q: '): - this_question = line[3:] + if len(line) > 3: + this_question = line[3:] elif line.startswith('Choices: '): - this_choices = line[9:].split(', ') + if len(line) > 9 and ',' in line: + this_choices = line[9:].split(', ') elif line.startswith('A: '): - this_answer = line[3:] + if len(line) > 3: + this_answer = line[3:] if this_entity and this_question and this_choices: question_instances.append( diff --git a/tifascore/tifa_score.py b/tifascore/tifa_score.py index 2109e24..da4f690 100644 --- a/tifascore/tifa_score.py +++ b/tifascore/tifa_score.py @@ -3,6 +3,7 @@ from tqdm import tqdm from .vqa_models import VQAModel from statistics import mean, stdev +import copy def tifa_score_benchmark(vqa_model_name, question_answer_path, id2img_path): @@ -85,7 +86,7 @@ def tifa_score_single(vqa_model, question_answer_pairs, img_path): # read the question, choices, and answers if question_answer_pair['question'] not in question_logs: - question_logs[question_answer_pair['question']] = question_answer_pair + question_logs[question_answer_pair['question']] = copy.deepcopy(question_answer_pair) choices=question_answer_pair['choices'] # get VQA answer