"""Robust multiple-choice answer extraction.

Shared utility for benchmark tasks that need to extract a choice letter
(A/B/C/D/...) from free-form model output. Handles 10+ common answer
formats and uses a priority ranking to pick the best candidate.

Usage::

    from lmms_eval.tasks._task_utils.mcq_extract import extract_mcq_answer

    letter = extract_mcq_answer("The correct answer is (B).")  # -> "B"
"""

import re
from typing import List, Optional

_DEFAULT_CHOICES = ["A", "B", "C", "D", "E", "F", "G", "H"]

_ANSWER_PHRASES = [
    "the answer is",
    "answer is",
    "the correct answer is",
    "correct answer is",
    "the best answer is",
    "best answer is",
    "the correct option is",
    "correct option is",
    "the best option is",
    "best option is",
    "the choice is",
    "choice is",
    "the correct choice is",
    "correct choice is",
    "i choose",
    "i select",
    "i pick",
    "my answer is",
    "my choice is",
]

# Higher = more confident that this is the intended answer.
_FORMAT_PRIORITY = {
    "start": 10,
    "end": 9,
    "phrase": 7,
    "parentheses": 6,
    "period": 5,
    "colon": 4,
    "right_paren": 3,
    "space": 2,
    "fallback": 0,
}

# Punctuation trimmed from both ends of the response before matching.
_EDGE_PUNCTUATION = ",.!?;:'\""

# Word-boundary guards: a choice letter only counts when it is not glued to
# another letter or digit (so the "A" of "USA." is not read as answer A).
_NOT_ALNUM_BEFORE = r"(?<![A-Za-z0-9])"
_NOT_ALNUM_AFTER = r"(?![A-Za-z0-9])"

# (format name, regex placed before the letter, regex placed after the letter)
_FORMAT_AFFIXES = (
    ("parentheses", r"\(", r"\)"),
    ("period", _NOT_ALNUM_BEFORE, r"\."),
    ("colon", _NOT_ALNUM_BEFORE, r":"),
    ("right_paren", _NOT_ALNUM_BEFORE, r"\)"),
    ("space", _NOT_ALNUM_BEFORE, r"(?=\s)"),
)

# Any answer phrase, matched case-insensitively on the original text so the
# reported indices are valid for that text (str.lower() may change length).
_PHRASE_RE = re.compile(
    r"(?<![A-Za-z])(?:"
    + "|".join(re.escape(p) for p in sorted(_ANSWER_PHRASES, key=len, reverse=True))
    + r")",
    re.IGNORECASE,
)

# What may sit between an answer phrase and its letter: whitespace, a colon or
# dash, opening brackets, markdown emphasis, quotes, and an optional
# "option"/"choice" word -- e.g. "the answer is: (B)" or "I choose option B".
_AFTER_PHRASE_LEAD = r"[\s:\-(\[*'`\x22]*(?:(?i:option|choice)\s+)?[(\[*'`\x22]*"


def _last_start(pattern: str, text: str) -> int:
    """Return the start index of the last match of *pattern* in *text*, or -1."""
    last = -1
    for match in re.finditer(pattern, text):
        last = match.start()
    return last


def extract_mcq_answer(response: str, choices: Optional[List[str]] = None) -> str:
    """Extract a multiple-choice answer letter from model output.

    Searches for choice letters in various common formats and returns the
    best candidate using the ``_FORMAT_PRIORITY`` ranking. When multiple
    candidates match, prefers the **last** occurrence in the
    **highest-priority** format, which handles reasoning-style outputs where
    the model discusses options before giving its final answer.

    Matching is case-sensitive: only the exact strings in ``choices`` are
    recognised. In every format except ``(X)`` the letter must not be
    directly preceded by a letter or digit, and for answer phrases only the
    letter that immediately follows the phrase is accepted.

    Args:
        response: Model output. It should already have reasoning-tag markup
            (presumably ``<think>`` blocks) stripped by the postprocessing
            pipeline. Non-string values such as ``None`` or NaN yield ``""``.
        choices: Valid choice letters. Defaults to ``["A".."H"]``; an empty
            list also falls back to the defaults.

    Returns:
        The matching entry of ``choices`` (an uppercase letter with the
        defaults), or ``""`` if none is found.
    """
    if not isinstance(response, str) or not response.strip():
        return ""

    all_choices = [ch for ch in (choices or _DEFAULT_CHOICES) if ch]
    if not all_choices:
        return ""

    # Trim whitespace and edge punctuation in one pass, then pad with spaces
    # so a letter at either end still has a neighbour to match against.
    text = " " + response.strip().strip(_EDGE_PUNCTUATION) + " "

    candidates: list = []  # (letter, position, format_name)

    # --- (A) / A. / A: / A) / A followed by whitespace ---
    for format_name, prefix, suffix in _FORMAT_AFFIXES:
        for ch in all_choices:
            pos = _last_start(prefix + re.escape(ch) + suffix, text)
            if pos != -1:
                candidates.append((ch, pos, format_name))

    # --- Common answer phrases ("the answer is A", etc.) ---
    # Only the letter directly after a phrase counts; every phrase occurrence
    # is considered so the last one wins through the position tie-break.
    alternation = "|".join(
        re.escape(ch) for ch in sorted(all_choices, key=len, reverse=True)
    )
    after_phrase = re.compile(
        _AFTER_PHRASE_LEAD + "(" + alternation + ")" + _NOT_ALNUM_AFTER
    )
    for phrase_match in _PHRASE_RE.finditer(text):
        letter_match = after_phrase.match(text, phrase_match.end())
        if letter_match:
            candidates.append(
                (letter_match.group(1), letter_match.start(1), "phrase")
            )

    # --- Starts / ends with a standalone choice letter ---
    # NOTE(review): a response opening with the article "A " (e.g. "A dog ...")
    # is indistinguishable from answer "A" here and wins at top priority.
    stripped = text.strip()
    for ch in all_choices:
        if stripped.startswith(ch) and not stripped[len(ch):len(ch) + 1].isalpha():
            candidates.append((ch, 0, "start"))
        if stripped.endswith(ch) and not stripped[:-len(ch)][-1:].isalpha():
            candidates.append((ch, len(text) - 1, "end"))

    # --- Fallback (lowest priority): standalone letters first ... ---
    if not candidates:
        for ch in all_choices:
            pos = _last_start(
                _NOT_ALNUM_BEFORE + re.escape(ch) + _NOT_ALNUM_AFTER, text
            )
            if pos != -1:
                candidates.append((ch, pos, "fallback"))

    # --- ... then any occurrence at all, even inside a word. ---
    if not candidates:
        for ch in all_choices:
            pos = text.rfind(ch)
            if pos != -1:
                candidates.append((ch, pos, "fallback"))

    if not candidates:
        return ""

    # Highest-priority format wins; within the same format, the later
    # position (closer to the end of the response) wins.
    best = max(candidates, key=lambda c: (_FORMAT_PRIORITY.get(c[2], 0), c[1]))
    return best[0]


# ---------------------------------------------------------------------------
# The rest of this patch span belongs to other files. It is kept here so that
# nothing from the original change is lost.
# ---------------------------------------------------------------------------


# From lmms_eval/tasks/mmbench/en_utils.py (hunk @@ -120,6 +120,21 @@).
# Uses json, generate_submission_file and mmbench_evaluator, none of which
# are defined in the visible hunk (presumably provided by en_utils.py).
def mmbench_aggregate_dev_results_static(results, args):
    """Static eval using regex/substring MCQ extraction — no OpenAI API needed.

    Calls ``mmbench_evaluator.eval_result(results, eval_method="static")``,
    writes the overall, category and L2-category accuracies as JSON to the
    path returned by
    ``generate_submission_file("mmbench_en_dev_static_results.json", args)``,
    and returns the overall accuracy multiplied by 100.
    """
    print("============= MMBench-EN(Dev) Static Eval =============")
    overall_acc, category_acc, l2_category_acc = mmbench_evaluator.eval_result(results, eval_method="static")
    file = generate_submission_file("mmbench_en_dev_static_results.json", args)
    details_info = {
        "overall_acc": overall_acc,
        "category_acc": category_acc,
        "l2_category_acc": l2_category_acc,
    }
    with open(file, "w") as f:
        json.dump(details_info, f)
    return overall_acc * 100


# From lmms_eval/tasks/mmbench/mmbench_en_dev_static.yaml (new file):
#
#   task: "mmbench_en_dev_static"
#   test_split: dev
#   include: _default_template_mmbench_en_yaml
#   metric_list:
#     - metric: gpt_eval_score
#       aggregation: !function en_utils.mmbench_aggregate_dev_results_static
#       higher_is_better: true
#     - metric: submission
#       aggregation: !function en_utils.mmbench_aggregate_dev_results_submission
#       higher_is_better: true
#
# The patch continues with hunks for lmms_eval/tasks/mmbench/mmbench_evals.py.
b/lmms_eval/tasks/mmbench/mmbench_evals.py @@ -203,7 +203,7 @@ def extract_answer_from_item(self, item): return chars[tmp], "Failed to predict, thus randomly generate one. " # Extract answer from multiple rolling records - def eval_sub_data(self, sub_data, answer_map): + def eval_sub_data(self, sub_data, answer_map, static_only=False): lt = len(sub_data) GT, PRED = [], [] for i in range(lt): @@ -217,6 +217,15 @@ def eval_sub_data(self, sub_data, answer_map): for i in range(lt): if PRED[i]: continue + elif static_only: + # Use robust MCQ extraction instead of GPT API + from lmms_eval.tasks._task_utils.mcq_extract import extract_mcq_answer + + choices = self.build_choices(sub_data.iloc[i]) + choice_list = sorted(choices.keys()) + PRED[i] = extract_mcq_answer(sub_data.iloc[i]["prediction"], choices=choice_list) + if not PRED[i] or PRED[i] != GT[i]: + return 0 else: ret, _ = self.extract_answer_from_item(sub_data.iloc[i]) PRED[i] = ret @@ -242,7 +251,8 @@ def calculate_hit_rates(self, data): # Evaluate Results def eval_result(self, results, eval_method): rd.seed(2680) - assert eval_method == "openai" + static_only = eval_method == "static" + assert eval_method in ("openai", "static") # Set a large retry number to avoid failure # model = OpenAI('gpt-3.5-turbo-0613', retry=99) @@ -286,7 +296,7 @@ def eval_result(self, results, eval_method): continue sub_data = data[data["index"] % int(1e6) == idx] - ret = self.eval_sub_data(sub_data, answer_map) + ret = self.eval_sub_data(sub_data, answer_map, static_only=static_only) result[idx] = ret hit += ret tot += 1