From b520e52a4072131bab84ae294db9f57abe20b6aa Mon Sep 17 00:00:00 2001 From: ConstFr Date: Sat, 21 Feb 2026 11:44:49 +0400 Subject: [PATCH 1/2] added hidden dynamics method --- .gitignore | 2 + .../base_processing_mmlu_reasoning.yaml | 6 + .../single_sequence_estimators.yaml | 34 ++++ .../mmlu_reasoning.py | 7 + examples/configs/model/llama3-8b.yaml | 10 + .../configs/polygraph_eval_mmlu_fewshot.yaml | 47 +++++ scripts/polygraph_eval | 48 ++--- .../register_default_stat_calculators.py | 1 + src/lm_polygraph/estimators/__init__.py | 1 + .../estimators/hidden_dynamics.py | 176 ++++++++++++++++++ src/lm_polygraph/stat_calculators/__init__.py | 1 + .../stat_calculators/all_hidden_states.py | 76 ++++++++ src/lm_polygraph/utils/factory_estimator.py | 1 + 13 files changed, 386 insertions(+), 24 deletions(-) create mode 100644 examples/configs/base_processing_mmlu_reasoning.yaml create mode 100644 examples/configs/estimators/single_sequence_estimators.yaml create mode 100644 examples/configs/instruct/output_processing_scripts/mmlu_reasoning.py create mode 100644 examples/configs/model/llama3-8b.yaml create mode 100644 examples/configs/polygraph_eval_mmlu_fewshot.yaml create mode 100644 src/lm_polygraph/estimators/hidden_dynamics.py create mode 100644 src/lm_polygraph/stat_calculators/all_hidden_states.py diff --git a/.gitignore b/.gitignore index 238da6459..e1d3dba51 100644 --- a/.gitignore +++ b/.gitignore @@ -179,3 +179,5 @@ cython_debug/ # Agent instructions **/AGENTS.md +*.sbatch +sbatch_logs/ diff --git a/examples/configs/base_processing_mmlu_reasoning.yaml b/examples/configs/base_processing_mmlu_reasoning.yaml new file mode 100644 index 000000000..ab402fa20 --- /dev/null +++ b/examples/configs/base_processing_mmlu_reasoning.yaml @@ -0,0 +1,6 @@ +process_output_fn: + path: instruct/output_processing_scripts/mmlu_reasoning.py + fn_name: process_output_mmlu_reasoning +process_target_fn: + path: instruct/output_processing_scripts/mmlu_reasoning.py + fn_name: 
process_target_mmlu_reasoning \ No newline at end of file diff --git a/examples/configs/estimators/single_sequence_estimators.yaml b/examples/configs/estimators/single_sequence_estimators.yaml new file mode 100644 index 000000000..62d7a3caf --- /dev/null +++ b/examples/configs/estimators/single_sequence_estimators.yaml @@ -0,0 +1,34 @@ +- name: MaximumSequenceProbability +- name: Perplexity +- name: MeanTokenEntropy +# - name: AttentionScore +# cfg: +# gen_only: False +# - name: RAUQ +# cfg: +# alpha: 0.2 +# use_entropy: False +# - name: RAUQ +# cfg: +# alpha: 0.8 +# use_entropy: True +- name: HiddenDynamics + cfg: + layer_idx: 19 + head_idx: -1 +- name: HiddenDynamics + cfg: + layer_idx: -2 + head_idx: 19 +# - name: HiddenDynamics +# cfg: +# layer_idx: 19 +# head_idx: 19 +# - name: HiddenDynamics +# cfg: +# layer_idx: 21 +# head_idx: -1 +# - name: HiddenDynamics +# cfg: +# layer_idx: 21 +# head_idx: 21 diff --git a/examples/configs/instruct/output_processing_scripts/mmlu_reasoning.py b/examples/configs/instruct/output_processing_scripts/mmlu_reasoning.py new file mode 100644 index 000000000..34f9a3b56 --- /dev/null +++ b/examples/configs/instruct/output_processing_scripts/mmlu_reasoning.py @@ -0,0 +1,7 @@ +def process_output_mmlu_reasoning(output: str) -> str: + if "Final Answer: " in output: + output = output.split("Final Answer:")[1].split("\n")[0] + return output.lower() + +def process_target_mmlu_reasoning(target: str) -> str: + return target.lower() diff --git a/examples/configs/model/llama3-8b.yaml b/examples/configs/model/llama3-8b.yaml new file mode 100644 index 000000000..49390ca3e --- /dev/null +++ b/examples/configs/model/llama3-8b.yaml @@ -0,0 +1,10 @@ +defaults: + - default + +path: meta-llama/Llama-3.1-8B +type: CausalLM +path_to_load_script: model/default_causal.py + +load_model_args: + device_map: auto +load_tokenizer_args: {} \ No newline at end of file diff --git a/examples/configs/polygraph_eval_mmlu_fewshot.yaml 
b/examples/configs/polygraph_eval_mmlu_fewshot.yaml new file mode 100644 index 000000000..b74089cab --- /dev/null +++ b/examples/configs/polygraph_eval_mmlu_fewshot.yaml @@ -0,0 +1,47 @@ +hydra: + run: + dir: ${cache_path}/${task}/${model}/${dataset}/${now:%Y-%m-%d}/${now:%H-%M-%S} + +defaults: + - model: llama3-8b + - estimators: single_sequence_estimators + - stat_calculators: default_calculators + - base_processing_mmlu_reasoning + - _self_ + + +cache_path: ./workdir/output +save_path: '${hydra:run.dir}' +instruct: false +task: qa + +dataset: ['denis1699/mmlu_reasoning'] +text_column: question +label_column: answer +train_split: train +eval_split: test +few_shot_prompt: null +max_new_tokens: 512 +load_from_disk: false +trust_remote_code: false + + +subsample_eval_dataset: 1000 + +generation_metrics: null + +generation_params: + generate_until: + - "Q:" + - "Question:" + - "\n\n" + +ignore_exceptions: false + +batch_size: 1 + +stat_calculator: + batch_size: 1 + +seed: + - 1 diff --git a/scripts/polygraph_eval b/scripts/polygraph_eval index de8b80a48..b96103d2b 100755 --- a/scripts/polygraph_eval +++ b/scripts/polygraph_eval @@ -181,10 +181,10 @@ def main(args): def get_ue_metrics(args): ue_metrics = [ - PredictionRejectionArea(), + # PredictionRejectionArea(), PredictionRejectionArea(max_rejection=0.5), - IsotonicPCC(), - ECE(normalize=True), + # IsotonicPCC(), + # ECE(normalize=True), ] if getattr(args, "use_claim_ue", False): ue_metrics += [ @@ -255,33 +255,33 @@ def get_generation_metrics(args): ignore_regex = getattr(args, "source_ignore_regex", None) if not generation_metrics: result = [ - RougeMetric("rouge1"), - RougeMetric("rouge2"), - RougeMetric("rougeL"), - BLEUMetric(), - BertScoreMetric(), - SbertMetric(), + # RougeMetric("rouge1"), + # RougeMetric("rouge2"), + # RougeMetric("rougeL"), + # BLEUMetric(), + # BertScoreMetric(), + # SbertMetric(), AccuracyMetric( target_ignore_regex=getattr(args, "target_ignore_regex", None), 
output_ignore_regex=getattr(args, "output_ignore_regex", None), normalize=getattr(args, "normalize", False), ), ] - if args.task == "ats": - result += [AlignScore(target_is_claims=False, source_ignore_regex=ignore_regex, source_as_target=True)] - else: - result += [AlignScore(target_is_claims=True)] - if getattr(args.model, "type", "Whitebox") != "Blackbox": - if getattr(args, "use_claim_ue", False): - result += [ - OpenAIFactCheck( - cache_path=args.cache_path, - language=getattr(args, "language", "en"), - n_threads=getattr(args, "n_threads", 1), - ) - ] - if args.task == "nmt": - result += [Comet(source_ignore_regex=ignore_regex)] + # if args.task == "ats": + # result += [AlignScore(target_is_claims=False, source_ignore_regex=ignore_regex, source_as_target=True)] + # else: + # result += [AlignScore(target_is_claims=True)] + # if getattr(args.model, "type", "Whitebox") != "Blackbox": + # if getattr(args, "use_claim_ue", False): + # result += [ + # OpenAIFactCheck( + # cache_path=args.cache_path, + # language=getattr(args, "language", "en"), + # n_threads=getattr(args, "n_threads", 1), + # ) + # ] + # if args.task == "nmt": + # result += [Comet(source_ignore_regex=ignore_regex)] else: result = [] for metric in generation_metrics: diff --git a/src/lm_polygraph/defaults/register_default_stat_calculators.py b/src/lm_polygraph/defaults/register_default_stat_calculators.py index f3ddabd82..64d190e36 100644 --- a/src/lm_polygraph/defaults/register_default_stat_calculators.py +++ b/src/lm_polygraph/defaults/register_default_stat_calculators.py @@ -148,6 +148,7 @@ def _register( }, ) _register(AttentionForwardPassCalculator) + _register(AllHiddenStatesCalculator) elif model_type == "VisualLM": _register( GreedyProbsVisualCalculator, diff --git a/src/lm_polygraph/estimators/__init__.py b/src/lm_polygraph/estimators/__init__.py index ee93389a5..97bf99165 100644 --- a/src/lm_polygraph/estimators/__init__.py +++ b/src/lm_polygraph/estimators/__init__.py @@ -90,3 +90,4 @@ from 
.rauq import RAUQ
 from .csl import CSL
 from .semantic_density import SemanticDensity
+from .hidden_dynamics import HiddenDynamics
diff --git a/src/lm_polygraph/estimators/hidden_dynamics.py b/src/lm_polygraph/estimators/hidden_dynamics.py
new file mode 100644
index 000000000..a42efc573
--- /dev/null
+++ b/src/lm_polygraph/estimators/hidden_dynamics.py
@@ -0,0 +1,256 @@
+import numpy as np
+import logging
+from typing import Dict, Optional, Tuple
+
+from .estimator import Estimator
+
+log = logging.getLogger(__name__)
+
+# ``layer_idx`` value that switches off attention weighting: the sequence score
+# is then the plain mean of the per-token scores.
+_MEAN_AGGREGATION = -2
+
+
+def _moving_average(y: np.ndarray, window_length: int) -> np.ndarray:
+    """Return a centred moving average of *y* (zero-padded at both ends)."""
+    kernel = np.ones(window_length, dtype=np.float32) / float(window_length)
+    return np.convolve(y, kernel, mode="same")
+
+
+def _try_savgol(y: np.ndarray, window_length: int, polyorder: int) -> np.ndarray:
+    """
+    Savitzky–Golay smoothing if scipy is available, else a simple moving average fallback.
+
+    The moving average is also used when scipy rejects the filter parameters
+    (e.g. ``polyorder >= window_length``). Inputs shorter than the window, and
+    windows of length <= 1, are returned as an unsmoothed copy.
+    """
+    if window_length <= 1 or len(y) < window_length:
+        return y.copy()
+
+    try:
+        from scipy.signal import savgol_filter  # type: ignore
+    except ImportError:
+        return _moving_average(y, window_length)
+
+    try:
+        return savgol_filter(y, window_length=window_length, polyorder=polyorder, mode="interp")
+    except ValueError:
+        # Invalid filter parameters; HiddenDynamics.__init__ warns about this once.
+        return _moving_average(y, window_length)
+
+
+class HiddenDynamics(Estimator):
+    """
+    Hidden Dynamics UQ (Sampling-free uncertainty via hidden state dynamics).
+
+    For every generated token, the cosine similarity between each layer's hidden
+    state and the last layer's hidden state forms a per-layer trajectory. The
+    layer ``tau`` with the largest curvature of the smoothed trajectory splits it
+    in two, and the token score is ``alpha * SSS + beta * CCS`` where
+      - SSS is one minus the variance of the trajectory up to ``tau``;
+      - CCS is the mean layer-to-layer change of the trajectory after ``tau``.
+
+    Token scores are aggregated into one sequence score: a plain mean when
+    ``layer_idx == -2``, otherwise a sum weighted by the softmax of the attention
+    that head ``head_idx`` of layer ``layer_idx`` pays to the answer tokens.
+
+    Returns:
+      - sequence-level score (shape [1]).
+    """
+
+    def __init__(
+        self,
+        alpha: float = 0.5,
+        beta: float = 0.5,
+        sg_window_halfwidth: int = 5,
+        sg_polyorder: int = 3,
+        delta_layer: int = 10,
+        curvature_eps: float = 1e-6,
+        layer_idx: int = -1,
+        head_idx: int = -1,
+        use_last_step_attention: bool = True,
+    ):
+        """
+        Parameters:
+            alpha: weight of the SSS term in the token score.
+            beta: weight of the CCS term in the token score.
+            sg_window_halfwidth: half-width ``w`` of the smoothing window
+                (the window length is ``2 * w + 1``).
+            sg_polyorder: polynomial order of the Savitzky–Golay filter.
+            delta_layer: first layer (1-based) at which ``tau`` may be placed.
+            curvature_eps: small constant added to the curvature denominator.
+            layer_idx: attention layer used for aggregation; ``-2`` selects the
+                plain mean over answer tokens instead.
+            head_idx: attention head used for aggregation.
+            use_last_step_attention: if True, weight tokens by the attention of
+                the last position; otherwise by the mean attention over all
+                answer positions.
+        """
+        dependencies = ["all_hidden_states", "all_attentions", "greedy_tokens", "prompt_len"]
+        super().__init__(dependencies, "sequence")
+
+        self.alpha = float(alpha)
+        self.beta = float(beta)
+        self.w = int(sg_window_halfwidth)
+        self.polyorder = int(sg_polyorder)
+        self.delta = int(delta_layer)
+        self.eps = float(curvature_eps)
+        self.layer_idx = int(layer_idx)
+        self.head_idx = int(head_idx)
+        self.use_last_step_attention = use_last_step_attention
+
+        if self.w >= 1 and self.polyorder >= 2 * self.w + 1:
+            log.warning(
+                "HiddenDynamics: sg_polyorder=%d is not smaller than the window length %d; "
+                "a moving average is used instead of Savitzky-Golay smoothing",
+                self.polyorder,
+                2 * self.w + 1,
+            )
+
+    def __str__(self) -> str:
+        return f"HiddenDynamics_layer_{self.layer_idx}_head_{self.head_idx}"
+
+    def __call__(self, stats: Dict[str, np.ndarray]) -> np.ndarray:
+        """
+        Compute the sequence-level score for a single sample.
+
+        Parameters:
+            stats: must contain "all_hidden_states" [L+1, 1, T, D],
+                "all_attentions" [L, 1, H, T, T] and "prompt_len", the number
+                of prompt tokens at the start of T.
+        Returns:
+            np.ndarray of shape [1]; NaN if the sample has no answer tokens.
+        """
+        hs = stats["all_hidden_states"]  # [L+1, B, T, D]
+        att = stats["all_attentions"]  # [L, B, H, T, T]
+        prompt_len = int(stats["prompt_len"])
+
+        # Debug level, lazily formatted: this runs once per evaluated sample.
+        log.debug(
+            "HiddenDynamics: all_hidden_states shape %s, all_attentions shape %s, prompt_len %s",
+            hs.shape,
+            att.shape,
+            prompt_len,
+        )
+        log.debug("layer_idx=%s, head_idx=%s", self.layer_idx, self.head_idx)
+
+        if hs.ndim != 4:
+            raise ValueError(f"Expected all_hidden_states shape [L+1,B,T,D], got {hs.shape}")
+        if att.ndim != 5:
+            raise ValueError(f"Expected all_attentions shape [L,B,H,T,T], got {att.shape}")
+
+        Lp1, B, T, _ = hs.shape
+        if Lp1 < 2:
+            raise ValueError(f"Expected at least one layer after the embedding layer, got {hs.shape}")
+        if B != 1:
+            raise NotImplementedError("Batch size > 1 not supported")
+
+        # Answer span indices
+        ans_start = min(max(prompt_len, 0), T)
+        if ans_start >= T:
+            # No generated tokens: the mean / softmax below would be undefined.
+            log.warning("HiddenDynamics: no answer tokens (prompt_len=%d, T=%d); returning NaN", prompt_len, T)
+            return np.array([np.nan])
+
+        # Drop embedding layer and prompt positions: hs_layers [L, T_answer, D]
+        hs_layers = hs[1:, 0, ans_start:, :]
+        if hs_layers.dtype == np.float16:
+            # Products and squared norms below can overflow in half precision.
+            hs_layers = hs_layers.astype(np.float32)
+        h_final = hs_layers[-1]  # [T_answer, D]
+
+        # LSR score s[l, t] = cos(h_l(t), h_final(t)) -> [L, T_answer]
+        num = np.sum(hs_layers * h_final[None, :, :], axis=-1)
+        den = np.linalg.norm(hs_layers, axis=-1) * np.linalg.norm(h_final, axis=-1)[None, :]
+        s_scores = num / np.clip(den, 1e-12, None)
+
+        # Token UQ only for answer tokens
+        token_uq = np.zeros((T - ans_start,), dtype=np.float32)
+
+        for i in range(T - ans_start):
+            traj = s_scores[:, i].astype(np.float64)  # [L]
+            tau = self._compute_tau(traj)  # 1-based
+
+            sss = self._sss(traj, tau)
+            ccs = self._ccs(traj, tau)
+
+            token_uq[i] = self.alpha * sss + self.beta * ccs
+
+        if self.layer_idx == _MEAN_AGGREGATION:
+            # No attention weighting requested: mean over answer tokens
+            return np.array([token_uq.mean()])
+
+        useq = self._sequence_uq_from_attention(
+            token_uq=token_uq,
+            attentions=att,
+            ans_start=ans_start,
+        )
+        return np.array([useq])
+
+    def _compute_tau(self, s_traj) -> int:
+        """Return the 1-based layer index with the largest curvature of the smoothed trajectory."""
+        L = int(s_traj.shape[0])
+
+        win = 2 * self.w + 1
+        s_hat = _try_savgol(s_traj, window_length=win, polyorder=self.polyorder)
+
+        # Forward differences; the trailing entries stay zero.
+        s1 = np.zeros_like(s_hat)
+        s2 = np.zeros_like(s_hat)
+        s1[:-1] = s_hat[1:] - s_hat[:-1]
+        s2[:-2] = s1[1:-1] - s1[:-2]
+
+        kappa = np.abs(s2) / np.power(1.0 + (s1 * s1) + self.eps, 1.5)
+
+        # Search from layer ``delta`` on, or everywhere if there are fewer layers.
+        start_idx = max(self.delta - 1, 0)
+        if start_idx >= L:
+            start_idx = 0
+
+        tau_idx = int(np.argmax(kappa[start_idx:])) + start_idx  # 0-based
+        return tau_idx + 1  # 1-based
+
+    def _sss(self, s_traj, tau_1based) -> float:
+        """Return one minus the (population) variance of the trajectory up to ``tau``."""
+        tau = max(int(tau_1based), 1)
+        pre = s_traj[:tau]
+        mu = float(np.mean(pre))
+        var = float(np.mean((pre - mu) ** 2))
+        return float(1.0 - var)
+
+    def _ccs(self, s_traj, tau_1based) -> float:
+        """Return the mean layer-to-layer change after ``tau`` (0.0 if ``tau`` is the last layer)."""
+        L = int(s_traj.shape[0])
+        tau = min(max(int(tau_1based), 1), L)
+        if tau >= L:
+            return 0.0
+
+        diffs = s_traj[tau:] - s_traj[tau - 1 : L - 1]
+        return float(np.mean(diffs))
+
+    def _sequence_uq_from_attention(
+        self,
+        token_uq,  # [T_answer]
+        attentions,  # [L, B, H, T, T]
+        ans_start,
+    ) -> float:
+        """Aggregate token scores with softmax(attention) weights, or a plain mean."""
+        if token_uq.size == 0:
+            return float("nan")
+        if self.layer_idx == _MEAN_AGGREGATION:
+            return float(token_uq.mean())
+        # float64 avoids half-precision rounding when attentions are stored as float16.
+        A = attentions[self.layer_idx, 0, self.head_idx].astype(np.float64)  # [T, T]
+
+        if self.use_last_step_attention:
+            attn_vec = A[-1, :]  # [T]
+        else:
+            attn_vec = A[ans_start:, :].mean(axis=0)  # [T]
+
+        attn_ans = attn_vec[ans_start:]  # [T_answer]
+
+        w = np.exp(attn_ans - np.max(attn_ans))
+        w = w / np.clip(w.sum(), 1e-12, None)
+
+        return float(np.sum(w * token_uq))
diff --git a/src/lm_polygraph/stat_calculators/__init__.py b/src/lm_polygraph/stat_calculators/__init__.py
index b3dcc0852..ae9f5422f 100644
--- a/src/lm_polygraph/stat_calculators/__init__.py
+++ b/src/lm_polygraph/stat_calculators/__init__.py
@@ -61,3 +61,4 @@
 from .infer_causal_lm_calculator import InferCausalLMCalculator
 from .semantic_classes import SemanticClassesCalculator
 from .attention_forward_pass_visual import AttentionForwardPassCalculatorVisual
+from .all_hidden_states import AllHiddenStatesCalculator
diff --git a/src/lm_polygraph/stat_calculators/all_hidden_states.py b/src/lm_polygraph/stat_calculators/all_hidden_states.py
new file mode 100644
index 000000000..1894bcbb1
--- /dev/null
+++ b/src/lm_polygraph/stat_calculators/all_hidden_states.py
@@ -0,0 +1,108 @@
+import torch
+import numpy as np
+
+from typing import Dict, List, Tuple
+
+from .stat_calculator import StatCalculator
+from lm_polygraph.utils.model import WhiteboxModel
+
+
+class AllHiddenStatesCalculator(StatCalculator):
+    """
+    Runs one forward pass over prompt + greedy generation and returns the
+    hidden states and attention maps of every layer for the whole sequence.
+
+    Only batch size 1 is supported.
+    """
+
+    @staticmethod
+    def meta_info() -> Tuple[List[str], List[str]]:
+        """
+        Returns the statistics and dependencies for the calculator.
+        """
+        # Depends on greedy generation result, and also needs the prompt text to compute prompt_len
+        return ["all_hidden_states", "all_attentions", "prompt_len"], ["greedy_texts", "greedy_tokens"]
+
+    def __init__(self):
+        super().__init__()
+
+    @torch.no_grad()
+    def __call__(
+        self,
+        dependencies: Dict[str, np.array],
+        texts: List[str],
+        model: WhiteboxModel,
+        max_new_tokens: int = 100,
+    ) -> Dict[str, np.ndarray]:
+        """
+        Parameters:
+            dependencies: must contain "greedy_tokens", the generated token ids
+                with shape [1, T_gen].
+            texts: input prompts; exactly one is supported.
+            model: model used for tokenization and the forward pass.
+            max_new_tokens: not used by this calculator.
+        Returns:
+            Dict with
+              - "all_hidden_states": float32 array [L+1, 1, T, D];
+              - "all_attentions": float16 array [L, 1, H, T, T];
+              - "prompt_len": number of prompt tokens at the start of T.
+        Raises:
+            NotImplementedError: if more than one text is passed.
+            RuntimeError: if the model returns no attention maps.
+        """
+        # doesn't support bs size > 1
+        if len(texts) > 1:
+            raise NotImplementedError("Batch size > 1 not supported for AllHiddenStatesCalculator")
+
+        device = model.device()
+
+        prompt_text = texts[0]
+        # NOTE(review): prompt_len is only right if generation tokenized the prompt the
+        # same way (special tokens included) - confirm against the greedy calculator.
+        prompt_ids = model.tokenizer(
+            prompt_text,
+            return_tensors="pt",
+            padding=True,
+            add_special_tokens=True,
+        )["input_ids"].to(device)  # [1, T_prompt]
+        prompt_len = prompt_ids.shape[1]
+
+        gen_ids = torch.tensor(dependencies["greedy_tokens"], dtype=torch.long).to(device)  # [1, T_gen]
+
+        # Full sequence = prompt + generated
+        full_ids = torch.cat([prompt_ids, gen_ids], dim=1)  # [1, T_total]
+
+        fw = model(
+            input_ids=full_ids,
+            output_hidden_states=True,
+            output_attentions=True,
+            use_cache=False,
+            return_dict=True,
+        )
+
+        if fw.attentions is None or any(a is None for a in fw.attentions):
+            raise RuntimeError(
+                "AllHiddenStatesCalculator: the model returned no attention maps; "
+                "load it with an attention implementation that supports output_attentions"
+            )
+
+        # hidden_states: tuple length (L+1), each [B, T, D] -> stack to [L+1, B, T, D].
+        # Cast to float32 on CPU: numpy has no bfloat16, so .numpy() would fail for
+        # models loaded in that dtype.
+        hidden_states = torch.stack(
+            [h.detach().to("cpu", dtype=torch.float32) for h in fw.hidden_states],
+            dim=0,
+        )
+
+        # attentions: tuple length L, each [B, H, T, T] -> stack to [L, B, H, T, T].
+        # float16 halves the size of this L * H * T * T array compared to float32.
+        attentions = torch.stack(
+            [a.detach().to("cpu", dtype=torch.float16) for a in fw.attentions],
+            dim=0,
+        )
+
+        return {
+            "all_hidden_states": hidden_states.numpy(),
+            "all_attentions": attentions.numpy(),
+            "prompt_len": prompt_len,
+        }
diff --git a/src/lm_polygraph/utils/factory_estimator.py b/src/lm_polygraph/utils/factory_estimator.py
index 7a8a52bfb..13cf1ab25 100644
---
a/src/lm_polygraph/utils/factory_estimator.py
+++ b/src/lm_polygraph/utils/factory_estimator.py
@@ -60,6 +60,7 @@ def load_simple_estimators(name: str, config):
         CSL,
         SemanticDensity,
         BoostedProbSequence,
+        HiddenDynamics,
     ]
 
     try:
From b2eae35a88e1711575d95bcd371d4b3d61606437 Mon Sep 17 00:00:00 2001
From: ConstFr
Date: Sun, 22 Feb 2026 12:38:53 +0400
Subject: [PATCH 2/2] changed dataset due to the input token limit

---
 .../mmlu_reasoning.py                         | 23 +++++++++++++++++++----
 .../configs/polygraph_eval_mmlu_fewshot.yaml  |  7 +++----
 .../register_default_stat_calculators.py      |  2 +-
 3 files changed, 23 insertions(+), 9 deletions(-)

diff --git a/examples/configs/instruct/output_processing_scripts/mmlu_reasoning.py b/examples/configs/instruct/output_processing_scripts/mmlu_reasoning.py
index 34f9a3b56..a4c66840e 100644
--- a/examples/configs/instruct/output_processing_scripts/mmlu_reasoning.py
+++ b/examples/configs/instruct/output_processing_scripts/mmlu_reasoning.py
@@ -1,7 +1,22 @@
+import re
+
 def process_output_mmlu_reasoning(output: str) -> str:
-    if "Final Answer: " in output:
-        output = output.split("Final Answer:")[1].split("\n")[0]
-    return output.lower()
+    """
+    Extract the FIRST answer letter (a/b/c/d)
+    after '### Answer:' and ignore everything after.
+
+    If no such marker is found, the whole output is returned
+    stripped and lower-cased.
+    """
+    match = re.search(
+        r"###\s*answer:\s*([a-dA-D])\b",
+        output,
+        flags=re.IGNORECASE
+    )
+    if match:
+        return match.group(1).lower()
+    return output.strip().lower()
 
 def process_target_mmlu_reasoning(target: str) -> str:
-    return target.lower()
+    """Strip and lower-case the reference answer."""
+    return target.strip().lower()
diff --git a/examples/configs/polygraph_eval_mmlu_fewshot.yaml b/examples/configs/polygraph_eval_mmlu_fewshot.yaml
index b74089cab..f11bdd8e4 100644
--- a/examples/configs/polygraph_eval_mmlu_fewshot.yaml
+++ b/examples/configs/polygraph_eval_mmlu_fewshot.yaml
@@ -15,11 +15,11 @@ save_path: '${hydra:run.dir}'
 instruct: false
 task: qa
 
-dataset: ['denis1699/mmlu_reasoning']
+dataset: ['UGRIP-LM-Polygraph/mmlu-reasoning']
 text_column: question
 label_column: answer
-train_split: train
-eval_split: test
+train_split: test
+eval_split: validation
 few_shot_prompt: null
 max_new_tokens: 512
 load_from_disk: false
@@ -34,7 +34,6 @@ generation_params:
   generate_until:
     - "Q:"
     - "Question:"
-    - "\n\n"
 
 ignore_exceptions: false
 
diff --git a/src/lm_polygraph/defaults/register_default_stat_calculators.py b/src/lm_polygraph/defaults/register_default_stat_calculators.py
index 64d190e36..fcba8e74d 100644
--- a/src/lm_polygraph/defaults/register_default_stat_calculators.py
+++ b/src/lm_polygraph/defaults/register_default_stat_calculators.py
@@ -53,7 +53,7 @@ def _register(
         "deberta_path": deberta_model_path,
         "hf_cache": hf_cache,
         "batch_size": deberta_batch_size,
-        "device": None,
+        "device": 1,
     }
 
     _register(InitialStateCalculator)