Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
147 changes: 147 additions & 0 deletions lmms_eval/tasks/_task_utils/mcq_extract.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
"""Robust multiple-choice answer extraction.

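Shared utility for benchmark tasks that need to extract a choice letter
(A/B/C/D/...) from free-form model output. Recognises the common answer
formats ("(A)", "A.", "A:", "A)", a bare letter, answer phrases such as
"the answer is A", and a leading or trailing letter) and uses a priority
ranking to pick the best candidate.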

Usage::

from lmms_eval.tasks._task_utils.mcq_extract import extract_mcq_answer

letter = extract_mcq_answer("The correct answer is (B).") # -> "B"
"""

import re
from typing import List, Optional

_DEFAULT_CHOICES = ["A", "B", "C", "D", "E", "F", "G", "H"]

_ANSWER_PHRASES = [
    "the answer is",
    "answer is",
    "the correct answer is",
    "correct answer is",
    "the best answer is",
    "best answer is",
    "the correct option is",
    "correct option is",
    "the best option is",
    "best option is",
    "the choice is",
    "choice is",
    "the correct choice is",
    "correct choice is",
    "i choose",
    "i select",
    "i pick",
    "my answer is",
    "my choice is",
]

# Higher = more confident that this is the intended answer.
# An explicit answer phrase immediately followed by a choice ("the answer
# is B") is the strongest signal, so it outranks a letter that merely
# happens to open or close the response ("A is wrong ... the answer is B").
_FORMAT_PRIORITY = {
    "phrase": 11,
    "start": 10,
    "end": 9,
    "parentheses": 6,
    "period": 5,
    "colon": 4,
    "right_paren": 3,
    "space": 2,
    "fallback": 0,
}

# Punctuation trimmed from both ends of the response before matching.
_EDGE_PUNCTUATION = ",.!?;:'\""

# (format name, regex for the character that must directly follow the choice).
_SUFFIX_FORMATS = (
    ("period", r"\."),
    ("colon", ":"),
    ("right_paren", r"\)"),
    ("space", " "),
)

# Case-insensitive alternation of every answer phrase, longest first, with
# flexible whitespace between words and a guard so a phrase cannot start in
# the middle of another word.
_PHRASE_PREFIX = (
    "(?<![A-Za-z])(?i:"
    + "|".join(
        r"\s+".join(re.escape(word) for word in phrase.split())
        for phrase in sorted(_ANSWER_PHRASES, key=len, reverse=True)
    )
    + ")"
)

# What may sit between a phrase and the choice itself: whitespace, an
# optional ":" or "-", an optional "option"/"choice" word, and opening
# brackets, quotes or markdown asterisks.
_PHRASE_GAP = r"""\s*[:\-]?\s*(?:(?i:option|choice)\s+)?[(\[*"']*"""


def _last_start(pattern: str, text: str) -> int:
    """Return the start index of the last match of *pattern* in *text*, or -1."""
    position = -1
    for match in re.finditer(pattern, text):
        position = match.start()
    return position


def extract_mcq_answer(response: str, choices: Optional[List[str]] = None) -> str:
    """Extract a multiple-choice answer letter from model output.

    Collects candidate choices in several common formats and returns the
    one with the highest format priority (see ``_FORMAT_PRIORITY``); within
    the same format the occurrence closest to the end of the text wins.

    Matching rules:

    * Choice matching is case-sensitive; answer phrases are matched
      case-insensitively.
    * An answer phrase only counts when the choice follows it directly
      (optionally after ``:``/``-``, the word "option"/"choice", or opening
      brackets/quotes/asterisks) and is not the first letter of a longer
      word. "The answer is B because C is wrong" therefore yields ``B``.
    * ``A.``, ``A:``, ``A)`` and ``A<space>`` only count when the choice is
      not glued to a preceding letter or digit, so "USA:" is not read as
      choice ``A``.
    * If nothing else matches, any occurrence of a choice string is
      accepted as a last resort.

    Args:
        response: Model output (should already have ``<think>`` tags
            stripped by the postprocessing pipeline).
        choices: Valid choice letters. Defaults to ``["A".."H"]``.

    Returns:
        The matching entry of ``choices``, or ``""`` if none is found.
    """
    if not response or not response.strip():
        return ""

    # Ignore empty entries: they would match everywhere.
    all_choices = [ch for ch in (choices or _DEFAULT_CHOICES) if ch]
    if not all_choices:
        return ""

    # Trim surrounding whitespace and punctuation in one pass so the result
    # does not depend on the order in which the marks appear.
    stripped = response.strip().strip(_EDGE_PUNCTUATION).strip()
    if not stripped:
        return ""
    # Pad with spaces for boundary matching.
    text = " " + stripped + " "

    candidates: list = []  # (letter, position, format_name)

    for ch in all_choices:
        token = re.escape(ch)

        # --- (A) ---
        position = text.rfind(f"({ch})")
        if position != -1:
            candidates.append((ch, position, "parentheses"))

        # --- A.  A:  A)  A<space> --- (choice must not continue a word)
        for format_name, suffix in _SUFFIX_FORMATS:
            position = _last_start(f"(?<![A-Za-z0-9]){token}{suffix}", text)
            if position != -1:
                candidates.append((ch, position, format_name))

    # --- Common answer phrases ("the answer is A", etc.) ---
    # Longest choice first so multi-character choices are not shadowed.
    letters = "|".join(
        re.escape(ch) for ch in sorted(all_choices, key=len, reverse=True)
    )
    phrase_pattern = re.compile(
        _PHRASE_PREFIX + _PHRASE_GAP + "(" + letters + ")(?![A-Za-z0-9])"
    )
    for match in phrase_pattern.finditer(text):
        candidates.append((match.group(1), match.start(1), "phrase"))

    for ch in all_choices:
        size = len(ch)
        # --- Starts with standalone choice letter (not part of a word) ---
        # Slicing yields "" at the end of the string, and "".isalpha() is False.
        if stripped.startswith(ch) and not stripped[size:size + 1].isalpha():
            candidates.append((ch, 0, "start"))
        # --- Ends with standalone choice letter ---
        if stripped.endswith(ch) and not stripped[:len(stripped) - size][-1:].isalpha():
            candidates.append((ch, len(text) - 1, "end"))

    # --- Fallback: any occurrence (lowest priority) ---
    if not candidates:
        for ch in all_choices:
            if ch in text:
                candidates.append((ch, text.rfind(ch), "fallback"))

    if not candidates:
        return ""

    # Highest-priority format wins; within the same format, the later
    # position (closer to the end) wins.
    best = max(candidates, key=lambda item: (_FORMAT_PRIORITY.get(item[2], 0), item[1]))
    return best[0]
15 changes: 15 additions & 0 deletions lmms_eval/tasks/mmbench/en_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,21 @@ def mmbench_aggregate_dev_results_eval(results, args):
return overall_acc * 100


def mmbench_aggregate_dev_results_static(results, args):
    """Aggregate MMBench-EN dev results with the static evaluator.

    Runs ``mmbench_evaluator.eval_result`` with ``eval_method="static"``
    (regex/substring MCQ extraction, so no OpenAI API is needed), writes
    the overall and per-category accuracies to a JSON submission file, and
    returns the overall accuracy scaled by 100.
    """
    print("============= MMBench-EN(Dev) Static Eval =============")
    scores = mmbench_evaluator.eval_result(results, eval_method="static")
    overall_acc, category_acc, l2_category_acc = scores

    output_path = generate_submission_file("mmbench_en_dev_static_results.json", args)
    summary = {}
    summary["overall_acc"] = overall_acc
    summary["category_acc"] = category_acc
    summary["l2_category_acc"] = l2_category_acc
    with open(output_path, "w") as handle:
        json.dump(summary, handle)

    return overall_acc * 100


def mmbench_aggregate_dev_results_submission(results, args):
df = pd.DataFrame(results)
excel_write_path = generate_submission_file("mmbench_en_dev_results.xlsx", args)
Expand Down
10 changes: 10 additions & 0 deletions lmms_eval/tasks/mmbench/mmbench_en_dev_static.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# MMBench-EN dev split, scored by the static aggregator in en_utils
# (regex/substring MCQ extraction, no OpenAI API call).
task: "mmbench_en_dev_static"
test_split: dev
# Shared settings come from the common MMBench-EN template.
include: _default_template_mmbench_en_yaml
metric_list:
# NOTE(review): the metric is still called gpt_eval_score although the
# aggregation below does not use GPT; presumably the name has to match the
# key produced by the template's result processing. Confirm before renaming.
- metric: gpt_eval_score
aggregation: !function en_utils.mmbench_aggregate_dev_results_static
higher_is_better: true
# Also produces the submission file via the submission aggregator.
- metric: submission
aggregation: !function en_utils.mmbench_aggregate_dev_results_submission
higher_is_better: true
16 changes: 13 additions & 3 deletions lmms_eval/tasks/mmbench/mmbench_evals.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,7 @@ def extract_answer_from_item(self, item):
return chars[tmp], "Failed to predict, thus randomly generate one. "

# Extract answer from multiple rolling records
def eval_sub_data(self, sub_data, answer_map):
def eval_sub_data(self, sub_data, answer_map, static_only=False):
lt = len(sub_data)
GT, PRED = [], []
for i in range(lt):
Expand All @@ -217,6 +217,15 @@ def eval_sub_data(self, sub_data, answer_map):
for i in range(lt):
if PRED[i]:
continue
elif static_only:
# Use robust MCQ extraction instead of GPT API
from lmms_eval.tasks._task_utils.mcq_extract import extract_mcq_answer

choices = self.build_choices(sub_data.iloc[i])
choice_list = sorted(choices.keys())
PRED[i] = extract_mcq_answer(sub_data.iloc[i]["prediction"], choices=choice_list)
if not PRED[i] or PRED[i] != GT[i]:
return 0
else:
ret, _ = self.extract_answer_from_item(sub_data.iloc[i])
PRED[i] = ret
Expand All @@ -242,7 +251,8 @@ def calculate_hit_rates(self, data):
# Evaluate Results
def eval_result(self, results, eval_method):
rd.seed(2680)
assert eval_method == "openai"
static_only = eval_method == "static"
assert eval_method in ("openai", "static")
# Set a large retry number to avoid failure
# model = OpenAI('gpt-3.5-turbo-0613', retry=99)

Expand Down Expand Up @@ -286,7 +296,7 @@ def eval_result(self, results, eval_method):
continue

sub_data = data[data["index"] % int(1e6) == idx]
ret = self.eval_sub_data(sub_data, answer_map)
ret = self.eval_sub_data(sub_data, answer_map, static_only=static_only)
result[idx] = ret
hit += ret
tot += 1
Expand Down
Loading