diff --git a/examples/basic_example.ipynb b/examples/basic_example.ipynb index 66d746c5d..d3caf67de 100644 --- a/examples/basic_example.ipynb +++ b/examples/basic_example.ipynb @@ -180,9 +180,9 @@ ], "metadata": { "kernelspec": { - "display_name": "Python [conda env:.mlspace-focus_new]", + "display_name": "Python 3 (ipykernel)", "language": "python", - "name": "conda-env-.mlspace-focus_new-py" + "name": "python3" }, "language_info": { "codemirror_mode": { @@ -194,7 +194,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.16" + "version": "3.11.11" } }, "nbformat": 4, diff --git a/examples/configs/base_processing_hotpot.yaml b/examples/configs/base_processing_hotpot.yaml new file mode 100644 index 000000000..489adc232 --- /dev/null +++ b/examples/configs/base_processing_hotpot.yaml @@ -0,0 +1,6 @@ +process_output_fn: + path: instruct/output_processing_scripts/hotpot.py + fn_name: process_output_cot_hotpot +process_target_fn: + path: instruct/output_processing_scripts/hotpot.py + fn_name: process_target_cot_hotpot \ No newline at end of file diff --git a/examples/configs/estimators/cot_estimators.yaml b/examples/configs/estimators/cot_estimators.yaml new file mode 100644 index 000000000..41aa129ef --- /dev/null +++ b/examples/configs/estimators/cot_estimators.yaml @@ -0,0 +1,13 @@ +- name: MaximumSequenceProbability +- name: Perplexity +- name: MeanTokenEntropy +- name: MeanPointwiseMutualInformation +- name: MeanConditionalPointwiseMutualInformation +- name: PTrue +- name: PTrueSampling +- name: MonteCarloSequenceEntropy +- name: MonteCarloNormalizedSequenceEntropy +- name: EigenScore +- name: RenyiNeg +- name: FisherRao +- name: ProbasMeanWithCoT diff --git a/examples/configs/estimators/default_estimators.yaml b/examples/configs/estimators/default_estimators.yaml index 41a40e079..477da0631 100644 --- a/examples/configs/estimators/default_estimators.yaml +++ b/examples/configs/estimators/default_estimators.yaml @@ 
-82,4 +82,5 @@ trust_remote_code: True idf_seed: 42 idf_dataset_size: -1 - spacy_path: "en_core_web_sm" \ No newline at end of file + spacy_path: "en_core_web_sm" +- name: ProbasMeanWithCoT diff --git a/examples/configs/instruct/output_processing_scripts/hotpot.py b/examples/configs/instruct/output_processing_scripts/hotpot.py new file mode 100644 index 000000000..a1bcd9c9c --- /dev/null +++ b/examples/configs/instruct/output_processing_scripts/hotpot.py @@ -0,0 +1,15 @@ +import re +import string + +CoT_OUTPUT_IGNORE_REGEX = re.compile(r"(?s).*Final Answer:") + +def process_output_cot_hotpot(output: str) -> str: + output = CoT_OUTPUT_IGNORE_REGEX.sub("", output).lower().strip() + output = output.translate(str.maketrans("", "", string.punctuation)) + return output + +def process_target_cot_hotpot(target: str) -> str: + target = target.lower().strip() + target = target.translate(str.maketrans("", "", string.punctuation)) + + return target diff --git a/examples/configs/polygraph_eval_cot_hotpot.yaml b/examples/configs/polygraph_eval_cot_hotpot.yaml new file mode 100644 index 000000000..6d874f8a1 --- /dev/null +++ b/examples/configs/polygraph_eval_cot_hotpot.yaml @@ -0,0 +1,38 @@ +hydra: + run: + dir: ${cache_path}/${task}/${model}/${dataset}/${now:%Y-%m-%d}/${now:%H-%M-%S} + +defaults: + - model: bloomz-560m + - estimators: cot_estimators + - stat_calculators: default_calculators + - base_processing_hotpot + - _self_ + +cache_path: ./workdir/output +save_path: '${hydra:run.dir}' +instruct: true +task: qa + +dataset: ['denis1699/hotpot_cot'] +text_column: question +label_column: answer +train_split: train +eval_split: validation +few_shot_prompt: null +max_new_tokens: 384 +load_from_disk: false +trust_remote_code: false +size: 100 + + +subsample_eval_dataset: 20 + +generation_metrics: null + +ignore_exceptions: false + +batch_size: 1 + +seed: + - 1 diff --git a/examples/reasoning_example.ipynb b/examples/reasoning_example.ipynb new file mode 100644 index 
000000000..1ae3d7c39 --- /dev/null +++ b/examples/reasoning_example.ipynb @@ -0,0 +1,631 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "6958a441", + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "os.environ[\"CUDA_DEVICE_ORDER\"]=\"PCI_BUS_ID\" # see issue #152\n", + "# os.environ[\"CUDA_VISIBLE_DEVICES\"]=\"1\"\n", + "\n", + "from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig\n", + "from lm_polygraph.estimators import *\n", + "from lm_polygraph.utils.model import WhiteboxModel\n", + "from lm_polygraph.utils.dataset import Dataset\n", + "from lm_polygraph.utils.processor import Logger\n", + "from lm_polygraph.utils.manager import UEManager\n", + "from lm_polygraph.ue_metrics import PredictionRejectionArea\n", + "from lm_polygraph.generation_metrics import RougeMetric, BartScoreSeqMetric, ModelScoreSeqMetric, ModelScoreTokenwiseMetric, AggregatedMetric\n", + "from lm_polygraph.utils.builder_enviroment_stat_calculator import (\n", + " BuilderEnvironmentStatCalculator\n", + ")\n", + "from lm_polygraph.defaults.register_default_stat_calculators import (\n", + " register_default_stat_calculators,\n", + ")\n", + "from lm_polygraph.utils.factory_stat_calculator import StatCalculatorContainer\n", + "from omegaconf import OmegaConf" + ] + }, + { + "cell_type": "markdown", + "id": "5025e26e-fd7f-44b6-88d7-5876439a5ab0", + "metadata": {}, + "source": [ + "# Specify HyperParameters" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "7111f938-bc8c-4b82-82a1-fce490bc8e4a", + "metadata": {}, + "outputs": [], + "source": [ + "# model_path = \"bigscience/bloomz-560m\"\n", + "model_path = \"meta-llama/Llama-3.1-8B-Instruct\"\n", + "device = \"cuda\"\n", + "model_type = \"Whitebox\"\n", + "dataset_name = \"denis1699/hotpot_cot\"\n", + "batch_size = 1\n", + "seed = 42" + ] + }, + { + "cell_type": "markdown", + "id": "757a3862-77d1-4bb4-8423-1f86f3a58b54", + "metadata": {}, + "source": [ 
+ "# Initialize Model" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "4e7a7afe", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "8b41e2f8f6334c8785ffa023bd7c474b", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Loading checkpoint shards: 0%| | 0/4 [00:00 instead of \n", " 'selector': 'td:hover',\n", " 'props': [('background-color', '#ffffb3')]\n", @@ -104,18 +104,594 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 6, "id": "31c03154", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Will measure variance using 1 seeds\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", 
+ " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
 HotpotQA, Llama3.2-3b
 AccuracyBLEURouge_rouge1Rouge_rouge2Rouge_rougeL
 prrprr_0.5prr_0.5_normalizedprr_normalizedprrprr_0.5prr_0.5_normalizedprr_normalizedprrprr_0.5prr_0.5_normalizedprr_normalizedprrprr_0.5prr_0.5_normalizedprr_normalizedprrprr_0.5prr_0.5_normalizedprr_normalized
MaximumSequenceProbability29.89 ± 0.0036.33 ± 0.00-26.36 ± 0.00-27.21 ± 0.0030.36 ± 0.0037.28 ± 0.00-28.28 ± 0.00-29.70 ± 0.0030.58 ± 0.0037.71 ± 0.00-29.10 ± 0.00-30.83 ± 0.0022.75 ± 0.0031.47 ± 0.009.68 ± 0.00-22.90 ± 0.0030.58 ± 0.0037.71 ± 0.00-29.10 ± 0.00-30.83 ± 0.00
Perplexity32.30 ± 0.0032.06 ± 0.00-57.82 ± 0.00-20.45 ± 0.0032.44 ± 0.0032.34 ± 0.00-63.40 ± 0.00-23.85 ± 0.0032.50 ± 0.0032.47 ± 0.00-65.81 ± 0.00-25.41 ± 0.0020.61 ± 0.0026.43 ± 0.00-49.48 ± 0.00-29.20 ± 0.0032.50 ± 0.0032.47 ± 0.00-65.81 ± 0.00-25.41 ± 0.00
MeanTokenEntropy28.05 ± 0.0030.57 ± 0.00-68.74 ± 0.00-32.35 ± 0.0028.20 ± 0.0030.85 ± 0.00-73.94 ± 0.00-35.80 ± 0.0028.26 ± 0.0030.98 ± 0.00-76.18 ± 0.00-37.37 ± 0.0018.79 ± 0.0022.50 ± 0.00-95.73 ± 0.00-34.54 ± 0.0028.26 ± 0.0030.98 ± 0.00-76.18 ± 0.00-37.37 ± 0.00
MeanPointwiseMutualInformation48.88 ± 0.0034.89 ± 0.00-36.93 ± 0.0026.12 ± 0.0049.36 ± 0.0035.85 ± 0.00-38.47 ± 0.0023.81 ± 0.0049.57 ± 0.0036.28 ± 0.00-39.14 ± 0.0022.76 ± 0.0032.91 ± 0.0028.10 ± 0.00-29.89 ± 0.007.02 ± 0.0049.57 ± 0.0036.28 ± 0.00-39.14 ± 0.0022.76 ± 0.00
MeanConditionalPointwiseMutualInformation49.75 ± 0.0040.65 ± 0.005.44 ± 0.0028.55 ± 0.0050.80 ± 0.0042.49 ± 0.008.73 ± 0.0027.88 ± 0.0051.28 ± 0.0043.33 ± 0.0010.14 ± 0.0027.58 ± 0.0023.43 ± 0.0029.61 ± 0.00-12.08 ± 0.00-20.89 ± 0.0051.28 ± 0.0043.33 ± 0.0010.14 ± 0.0027.58 ± 0.00
PTrue51.65 ± 0.0048.98 ± 0.0066.76 ± 0.0033.89 ± 0.0052.22 ± 0.0050.13 ± 0.0063.00 ± 0.0031.88 ± 0.0052.48 ± 0.0050.65 ± 0.0061.38 ± 0.0030.96 ± 0.0039.36 ± 0.0032.03 ± 0.0016.36 ± 0.0025.99 ± 0.0052.48 ± 0.0050.65 ± 0.0061.38 ± 0.0030.96 ± 0.00
PTrueSampling33.60 ± 0.0047.49 ± 0.0055.74 ± 0.00-16.78 ± 0.0033.67 ± 0.0047.62 ± 0.0045.18 ± 0.00-20.38 ± 0.0033.70 ± 0.0047.69 ± 0.0040.63 ± 0.00-22.02 ± 0.0030.19 ± 0.0033.55 ± 0.0034.16 ± 0.00-0.99 ± 0.0033.70 ± 0.0047.69 ± 0.0040.63 ± 0.00-22.02 ± 0.00
MonteCarloSequenceEntropy38.71 ± 0.0042.34 ± 0.0017.84 ± 0.00-2.44 ± 0.0040.52 ± 0.0044.18 ± 0.0020.70 ± 0.00-1.09 ± 0.0041.34 ± 0.0045.01 ± 0.0021.93 ± 0.00-0.48 ± 0.0024.48 ± 0.0035.22 ± 0.0053.75 ± 0.00-17.81 ± 0.0041.34 ± 0.0045.01 ± 0.0021.93 ± 0.00-0.48 ± 0.00
MonteCarloNormalizedSequenceEntropy52.69 ± 0.0041.50 ± 0.0011.71 ± 0.0036.80 ± 0.0054.77 ± 0.0043.34 ± 0.0014.78 ± 0.0039.06 ± 0.0055.72 ± 0.0044.18 ± 0.0016.10 ± 0.0040.08 ± 0.0037.30 ± 0.0035.22 ± 0.0053.75 ± 0.0019.91 ± 0.0055.72 ± 0.0044.18 ± 0.0016.10 ± 0.0040.08 ± 0.00
EigenScore69.41 ± 0.0048.40 ± 0.0062.43 ± 0.0083.77 ± 0.0070.09 ± 0.0049.76 ± 0.0060.33 ± 0.0082.22 ± 0.0070.40 ± 0.0050.37 ± 0.0059.42 ± 0.0081.52 ± 0.0056.09 ± 0.0033.55 ± 0.0034.16 ± 0.0075.20 ± 0.0070.40 ± 0.0050.37 ± 0.0059.42 ± 0.0081.52 ± 0.00
RenyiNeg47.63 ± 0.0045.76 ± 0.0043.02 ± 0.0022.59 ± 0.0052.57 ± 0.0047.60 ± 0.0045.00 ± 0.0032.87 ± 0.0054.82 ± 0.0048.43 ± 0.0045.85 ± 0.0037.56 ± 0.0040.08 ± 0.0039.15 ± 0.00100.00 ± 0.0028.11 ± 0.0054.82 ± 0.0048.43 ± 0.0045.85 ± 0.0037.56 ± 0.00
FisherRao49.58 ± 0.0045.17 ± 0.0038.69 ± 0.0028.09 ± 0.0054.53 ± 0.0047.01 ± 0.0040.82 ± 0.0038.38 ± 0.0056.78 ± 0.0047.85 ± 0.0041.74 ± 0.0043.08 ± 0.0045.67 ± 0.0039.15 ± 0.00100.00 ± 0.0044.55 ± 0.0056.78 ± 0.0047.85 ± 0.0041.74 ± 0.0043.08 ± 0.00
ProbasMeanWithCoT51.14 ± 0.0044.09 ± 0.0030.73 ± 0.0032.45 ± 0.0053.22 ± 0.0045.93 ± 0.0033.13 ± 0.0034.69 ± 0.0054.17 ± 0.0046.76 ± 0.0034.17 ± 0.0035.71 ± 0.0047.84 ± 0.0037.07 ± 0.0075.52 ± 0.0050.94 ± 0.0054.17 ± 0.0046.76 ± 0.0034.17 ± 0.0035.71 ± 0.00
\n" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "# visualize results in a table\n", "pretty_plot(\n", - " 'TriviaQA, Dolly3b',\n", + " 'HotpotQA, Llama3.2-3b',\n", " # outputs generated by scripts/polygraph_eval benchmark\n", " # provide several seeds to calculate variance\n", - " ['./workdir/output_seed' + str(x)\n", - " for x in range(1, 10)])" + " [\"../workdir/output/qa/{'path': 'meta-llama/Llama-3.2-3B-Instruct', 'ensemble': False, 'mc': False, 'mc_seeds': None, 'dropout_rate': None, 'type': 'CausalLM', 'path_to_load_script': 'model/default_causal.py', 'load_model_args': {'device_map': 'auto'}, 'load_tokenizer_args': {}}/['denis1699/hotpot_cot']/2025-05-06/09-26-59/ue_manager_seed1\"])" ] }, { @@ -143,7 +719,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.13" + "version": "3.11.11" } }, "nbformat": 4, diff --git a/scripts/polygraph_eval b/scripts/polygraph_eval index 057bf3b2b..408865ca5 100755 --- a/scripts/polygraph_eval +++ b/scripts/polygraph_eval @@ -224,14 +224,14 @@ def get_generation_metrics(args): RougeMetric("rouge2"), RougeMetric("rougeL"), BLEUMetric(), - BertScoreMetric("rh"), - SbertMetric(), + # BertScoreMetric("rh"), + # SbertMetric(), AccuracyMetric( target_ignore_regex=getattr(args, "target_ignore_regex", None), output_ignore_regex=getattr(args, "output_ignore_regex", None), normalize=getattr(args, "normalize", False), ), - AlignScore(target_is_claims=False if args.task == "ats" else True), + # AlignScore(target_is_claims=False if args.task == "ats" else True), ] if getattr(args.model, "type", "Whitebox") != "Blackbox": if getattr(args, "use_claim_ue", False): @@ -374,7 +374,9 @@ def get_vllm_model(args): load_model_args = {'model_path': args.model.path, 'max_new_tokens': args.max_new_tokens, - 'logprobs': args.model.logprobs} + 'logprobs': args.model.logprobs, + 'max_model_len': 8192, + } 
load_model_args.update(args.model.load_model_args) base_model, sampling_params = load_module.load_model(**load_model_args) diff --git a/src/lm_polygraph/defaults/register_default_stat_calculators.py b/src/lm_polygraph/defaults/register_default_stat_calculators.py index 28af538d4..a79e17de1 100644 --- a/src/lm_polygraph/defaults/register_default_stat_calculators.py +++ b/src/lm_polygraph/defaults/register_default_stat_calculators.py @@ -47,18 +47,18 @@ def _register( deberta_model_path = "MoritzLaurer/mDeBERTa-v3-base-xnli-multilingual-nli-2mil7" _register(InitialStateCalculator) - _register( - SemanticMatrixCalculator, - "lm_polygraph.defaults.stat_calculator_builders.default_SemanticMatrixCalculator", - { - "nli_model": { - "deberta_path": deberta_model_path, - "hf_cache": hf_cache, - "batch_size": 10, - "device": None, - } - }, - ) + # _register( + # SemanticMatrixCalculator, + # "lm_polygraph.defaults.stat_calculator_builders.default_SemanticMatrixCalculator", + # { + # "nli_model": { + # "deberta_path": deberta_model_path, + # "hf_cache": hf_cache, + # "batch_size": 10, + # "device": None, + # } + # }, + # ) _register(SemanticClassesCalculator) if model_type == "Blackbox": @@ -99,42 +99,47 @@ def _register( _register(PromptCalculator) _register(SamplingPromptCalculator) _register(ClaimPromptCalculator) + # _register( + # CrossEncoderSimilarityMatrixCalculator, + # "lm_polygraph.defaults.stat_calculator_builders.default_CrossEncoderSimilarityMatrixCalculator", + # { + # "batch_size": 10, + # "cross_encoder_name": "cross-encoder/stsb-roberta-large", + # }, + # ) + # _register( + # GreedyAlternativesNLICalculator, + # "lm_polygraph.defaults.stat_calculator_builders.default_GreedyAlternativesNLICalculator", + # { + # "nli_model": { + # "deberta_path": deberta_model_path, + # "hf_cache": hf_cache, + # "batch_size": 10, + # "device": None, + # } + # }, + # ) + # _register( + # GreedyAlternativesFactPrefNLICalculator, + # 
"lm_polygraph.defaults.stat_calculator_builders.default_GreedyAlternativesFactPrefNLICalculator", + # { + # "nli_model": { + # "deberta_path": deberta_model_path, + # "hf_cache": hf_cache, + # "batch_size": 10, + # "device": None, + # } + # }, + # ) + # _register( + # ClaimsExtractor, + # "lm_polygraph.defaults.stat_calculator_builders.default_ClaimsExtractor", + # {"openai_model": "gpt-4o", "cache_path": "~/.cache", "language": language}, + # ) _register( - CrossEncoderSimilarityMatrixCalculator, - "lm_polygraph.defaults.stat_calculator_builders.default_CrossEncoderSimilarityMatrixCalculator", - { - "batch_size": 10, - "cross_encoder_name": "cross-encoder/stsb-roberta-large", - }, - ) - _register( - GreedyAlternativesNLICalculator, - "lm_polygraph.defaults.stat_calculator_builders.default_GreedyAlternativesNLICalculator", - { - "nli_model": { - "deberta_path": deberta_model_path, - "hf_cache": hf_cache, - "batch_size": 10, - "device": None, - } - }, - ) - _register( - GreedyAlternativesFactPrefNLICalculator, - "lm_polygraph.defaults.stat_calculator_builders.default_GreedyAlternativesFactPrefNLICalculator", - { - "nli_model": { - "deberta_path": deberta_model_path, - "hf_cache": hf_cache, - "batch_size": 10, - "device": None, - } - }, - ) - _register( - ClaimsExtractor, - "lm_polygraph.defaults.stat_calculator_builders.default_ClaimsExtractor", - {"openai_model": "gpt-4o", "cache_path": "~/.cache", "language": language}, + ReasoningKeywordsProbs, + "lm_polygraph.defaults.stat_calculator_builders.default_ReasoningKeywordsProbs", + {"max_retries": 5, "max_length_cot": 128, "temperature": 1.0} ) else: diff --git a/src/lm_polygraph/defaults/stat_calculator_builders/default_ReasoningKeywordsProbs.py b/src/lm_polygraph/defaults/stat_calculator_builders/default_ReasoningKeywordsProbs.py new file mode 100644 index 000000000..38820560e --- /dev/null +++ b/src/lm_polygraph/defaults/stat_calculator_builders/default_ReasoningKeywordsProbs.py @@ -0,0 +1,9 @@ +from 
lm_polygraph.stat_calculators.reasoning_keywords_probs import ( + ReasoningKeywordsProbs, +) + + +def load_stat_calculator(config, builder): + return ReasoningKeywordsProbs( + config.max_retries, config.max_length_cot, config.temperature + ) diff --git a/src/lm_polygraph/estimators/__init__.py b/src/lm_polygraph/estimators/__init__.py index 8162f6380..fd06e1232 100644 --- a/src/lm_polygraph/estimators/__init__.py +++ b/src/lm_polygraph/estimators/__init__.py @@ -77,3 +77,4 @@ from .kernel_language_entropy import KernelLanguageEntropy from .luq import LUQ from .eigenscore import EigenScore +from .chain_of_thought_uq import ProbasMeanWithCoT diff --git a/src/lm_polygraph/estimators/chain_of_thought_uq.py b/src/lm_polygraph/estimators/chain_of_thought_uq.py new file mode 100644 index 000000000..c51bce82a --- /dev/null +++ b/src/lm_polygraph/estimators/chain_of_thought_uq.py @@ -0,0 +1,137 @@ +import numpy as np +import math + +from typing import Dict, List, Tuple + +from .estimator import Estimator + + +def aggregate_probas_mean( + keyword_token_probability: Dict[str, Dict[str, List[int]]], contribution_scores: Dict[str, Dict[str, int]] = None +) -> Tuple[Dict[str, List[float]], Dict[str, List[float]]]: + """ + Aggregates token probabilities + + Parameters: + keyword_token_probability (Dict[str, Dict[str, List[int]]]): token probs for keywords + (example { + "step1": { + "keyword1": [0.7, 0.8], + "keyword2": [0.9, 0.6, 0.5], + }, + "step2": { + "keyword1": [0.5, 0.8], + "keyword3": [0.5, 0.9, 0.9], + }, + ... + } + ), + contribution_scores (Dict[str, Dict[str, int]]): contribution scores for keywords. + Returns: + Tuple[Dict[str, List[float]], Dict[str, List[float]]]: agg. keyword probs, agg. keyword contributions. + (example { + "keyword1": [(0.7 + 0.8) / 2, (0.5 + 0.8) / 2], + "keyword2": [(0.9 + 0.6 + 0.5) / 3], + "keyword3": [(0.5 + 0.9 + 0.9) / 3], + ... 
+ } + ), + """ + return_keyword_dict = {} + return_contribution_dict = {} + for step, inner_dict in keyword_token_probability.items(): + for key, values in inner_dict.items(): + if len(values) == 0: + continue + # it is strange that min(values) was in original implementation for probas mean agg. strategy + # value_to_add = min(values) + value_to_add = np.mean(values) + if key in return_keyword_dict: + return_keyword_dict[key].append(value_to_add) + return_contribution_dict[key].append(contribution_scores[step][key]) + else: + return_keyword_dict[key] = [value_to_add] + return_contribution_dict[key] = [contribution_scores[step][key]] + return return_keyword_dict, return_contribution_dict + + +def weighted_sum(values: List[float]) -> float: + """ + Computes a softmin weighted sum of the input values. + + Parameters: + values (List[float]): values to be summed + Returns: + float: a softmin weighted sum + """ + if len(values) == 1: + return values[0] + weights = [math.exp(-c) for c in values] + sum_weights = sum(weights) + normalized_weights = [w / sum_weights for w in weights] + result = sum(w * c for w, c in zip(normalized_weights, values)) + return result + + +class ProbasMeanWithCoT(Estimator): + """ + Enhances Probas-Mean aggregated probabilities strategy with reasoning steps. + Only usable for instruct-finetuned models with chat template support.
+ Adapted from the original implementation in the paper https://arxiv.org/pdf/2502.17214 + """ + + def __init__( + self, + name_postfix="", + ): + self.postfix = name_postfix + super().__init__( + [ + "input_texts", + "greedy_texts", + "reasoning_answer", + "reasoning_keywords_probabilities", + "reasoning_keywords_contributions", + ], + "sequence", + ) + + def __str__(self): + return f"ProbasMeanWithCoT{self.postfix}" + + def __call__(self, stats: Dict[str, np.ndarray]) -> np.ndarray: + prompts = stats["input_texts"] + ues = [] + for i, question in enumerate(prompts): + reasoning_answer = stats["reasoning_answer"][i] + if reasoning_answer == "": + ues.append(0.5) + continue + + keyword_token_probability = stats["reasoning_keywords_probabilities"][i] + if keyword_token_probability is None or keyword_token_probability == {}: + ues.append(0.5) + continue + contribution_scores = stats["reasoning_keywords_contributions"][i] + if contribution_scores is None or contribution_scores == {}: + ues.append(0.5) + continue + + probabilities, contribution_dict = aggregate_probas_mean(keyword_token_probability, contribution_scores) + + # softmin weighted sum of keywords probs + probabilities = {key: weighted_sum(value) for key, value in probabilities.items()} + # average of keywords contributions + contributions = {key: sum(value) / len(value) for key, value in contribution_dict.items()} + + # CoT-UQ + total_sum = sum(probabilities[key] * contributions[key] for key in probabilities) + total_weight = sum(contributions[key] for key in contributions) + if total_weight == 0: + p_list = [v for v in probabilities.values()] + confidence = sum(p_list) / len(p_list) + else: + confidence = total_sum / total_weight + ues.append(1 - confidence) + + return np.array(ues) diff --git a/src/lm_polygraph/stat_calculators/__init__.py b/src/lm_polygraph/stat_calculators/__init__.py index 354026271..99a0ec4ad 100644 --- a/src/lm_polygraph/stat_calculators/__init__.py +++ 
b/src/lm_polygraph/stat_calculators/__init__.py @@ -29,3 +29,4 @@ from .extract_claims import ClaimsExtractor from .infer_causal_lm_calculator import InferCausalLMCalculator from .semantic_classes import SemanticClassesCalculator +from .reasoning_keywords_probs import ReasoningKeywordsProbs diff --git a/src/lm_polygraph/stat_calculators/reasoning_keywords_probs.py b/src/lm_polygraph/stat_calculators/reasoning_keywords_probs.py new file mode 100644 index 000000000..de695e6dc --- /dev/null +++ b/src/lm_polygraph/stat_calculators/reasoning_keywords_probs.py @@ -0,0 +1,500 @@ +import re +import torch +import warnings +import numpy as np +from collections import defaultdict + +from typing import Dict, List, Tuple, Optional + +from .stat_calculator import StatCalculator +from lm_polygraph.utils.model import WhiteboxModel + +import logging + +log = logging.getLogger("lm_polygraph") +logging.getLogger("httpx").setLevel(logging.WARNING) + + +cot_instruction = """ +Please reason the following question step by step. Label each reasoning step as "Step i:", where "i" is the step number. +You need to ensure that each step builds on the previous one and contributes meaningfully toward reaching the final answer. +Once you finish all steps, put your final answer on a separate line after the reasoning steps, starting with "Final Answer:" (do not label it as a step). + +Question: +Response: Let's think step by step. +""" + +keywords_extraction_instruction = ''' +You will be provided with a question and a multi-step response containing reasoning steps. +For each long reasoning step labeled "Step i:", extract the keywords, only the relevant tokens for that specific reasoning step. +The keywords should be relevant to question and final answer. +If you find more than one keyword in a specific step, separate them with “;”. +For example: + +###### + +Q: Which band has more members, "We Are the Ocean" or "The Dream Academy"? 
+ +Reasoning steps: +Step 1: The question is asking which band has more members. +Step 2: "We Are the Ocean" has five members. +Step 3: "The Dream Academy" has three members. +Step 4: 5 is greater than 3. +Step 5: Therefore, "We Are the Ocean" has more members. +Final Answer: We Are the Ocean + +Keywords for each reasoning step: +Step 1: band +Step 2: We Are the Ocean; five +Step 3: The Dream Academy; three +Step 4: greater +Step 5: We Are the Ocean + +###### + +The following is your task: +Q: + +Reasoning steps: + + +Keywords for each reasoning step: +''' + + +def is_effectively_empty(obj): + if obj is None: + return True + + if isinstance(obj, (int, float)) and obj == 0: + return True + + if obj == "": + return True + + if isinstance(obj, list): + return all(is_effectively_empty(item) for item in obj) + + if isinstance(obj, dict): + if len(obj) == 0: + return True + return all(is_effectively_empty(value) for value in obj.values()) + return False + + +def parse_response_to_dict(response: str) -> Tuple[Optional[str], Dict[str, str], Optional[str]]: + """ + Parse model reasoning output to highlight: reasoning answer, reasoning steps, reasoning output truncated after the answer. + + Parameters: + response (str): reasoning output.
+ Returns: + Tuple[Optional[str], Dict[str, str], Optional[str]]: + - final answer (str or None), + - dictionary of steps (e.g., {"Step 1": "Step 1: ..."}), + - response up to and including the final answer (str or None) + """ + steps: Dict[str, str] = {} + final_answer: Optional[str] = None + + # Match Final Answer + match = re.search(r"Final Answer:\s*(.+?)\s*(?=(\n|$))", response, re.DOTALL) + if match: + final_answer = match.group(1).strip() + response_after_final_answer = response[:match.end()].strip() + # response_before_final_answer = response[:match.start()].strip() + else: + return None, {}, None + + # Match Steps + matches = list(re.finditer(r"(Step \d+):", response_after_final_answer)) + for i, match in enumerate(matches): + start = match.start() + end = matches[i + 1].start() if i + 1 < len(matches) else len(response_after_final_answer) + segment = response[start:end].strip() + steps[match.group(1)] = segment + + return_response = response_after_final_answer + return final_answer, steps, return_response + + +def match_final_answer_token_ids(tokenizer, original_tokens, response_tokens, generated_ids): + # caution + final_answer_tokens = tokenizer.tokenize("Final Answer:") + + end_index = None + end_index_original = None + + for i in range(len(response_tokens) - len(final_answer_tokens) + 1): + if response_tokens[i : i + len(final_answer_tokens)] == final_answer_tokens: + end_index = i + len(final_answer_tokens) + break + + if end_index is None or end_index == len(response_tokens): + return None, None + + for i in range(len(original_tokens) - len(final_answer_tokens) + 1): + if original_tokens[i : i + len(final_answer_tokens)] == final_answer_tokens: + end_index_original = i + len(final_answer_tokens) + break + + if end_index_original is None: + return None, None + + if response_tokens[end_index] in ["▁", "Ġ", tokenizer.tokenize(" ")]: + end_index += 1 + end_index_original += 1 + + target_tokens = response_tokens[end_index:] + + final_answer_token_ids =
generated_ids[end_index_original : end_index_original + len(target_tokens)] + + return end_index_original, final_answer_token_ids + + +def predict(prompt, model, tokenizer, max_length_cot, temperature): + inputs = tokenizer(prompt, return_tensors="pt").to('cuda') + generate_ids = model.generate( + **inputs, + max_new_tokens=max_length_cot, + temperature=temperature, + pad_token_id=tokenizer.eos_token_id, + output_scores=True, + return_dict_in_generate=True, + ) + infer_res = tokenizer.decode(generate_ids.sequences[0][len(inputs["input_ids"][0]):-1]) + return infer_res + + +# def step_exacts_2_list(response): +# # Split response into lines and filter out empty lines +# lines = response.splitlines() +# lines = [line for line in lines if line.strip()] + +# keywords_by_step = [] +# contributions_by_step = [] +# valid_response_text = [] + +# for line in lines: +# # Match lines starting with "Step X:" +# match = re.search(r"Step \d+: (.+)", line) +# if match: +# # Extract keywords with contributions +# keywords_w_contribution = match.group(1).split("; ") + +# # Check for valid format and skip invalid lines +# if any("(/" not in key_w_c or "/)" not in key_w_c for key_w_c in keywords_w_contribution): +# continue + +# try: +# # Extract keywords and contributions +# keywords = [key_w_c.split("(/")[0].strip() for key_w_c in keywords_w_contribution] +# contributions = [int(key_w_c.split("(/")[1].split("/)")[0].strip()) for key_w_c in keywords_w_contribution] +# except ValueError: +# return False # Return False if contributions cannot be converted to int + +# for i in contributions: +# if i > 10: +# return False + +# keywords_by_step.append(keywords) +# contributions_by_step.append(contributions) +# valid_response_text.append(line) # Add valid lines from the original response + +# # If no valid lines are found, return False +# if not valid_response_text: +# return False + +# return "\n".join(valid_response_text), keywords_by_step, contributions_by_step + + +def 
step_exacts_2_list(response): + # Split response into lines and filter out empty lines + lines = response.splitlines() + lines = [line for line in lines if line.strip()] + + keywords_by_step = [] + contributions_by_step = [] + valid_response_text = [] + + for line in lines: + # Match lines starting with "Step X:" + match = re.search(r"Step \d+: (.+)", line) + if match: + # Extract keywords + keywords = match.group(1).split("; ") + + contributions = [10]*len(keywords) + + keywords_by_step.append(keywords) + contributions_by_step.append(contributions) + valid_response_text.append(line) # Add valid lines from the original response + + return "\n".join(valid_response_text), keywords_by_step, contributions_by_step + + +def find_subsequence_position(sub_sequence, long_sequence): + len_long = len(long_sequence) + len_sub = len(sub_sequence) + + for i in range(len_long - len_sub + 1): + if long_sequence[i:i + len_sub] == sub_sequence: + return i + return -1 + + +def clean_words(word): + # TODO forward space token + return word.replace(" ", "").replace(".", "").replace("\"", "").replace("\n", "").replace("_", "").replace("Ġ", "").lower() + + +def find_token_indices(tokens, word): + word_len = len(word.replace(" ", "")) + + for start_index in range(len(tokens)): + combined_text = "" + end_index = start_index + while end_index < len(tokens) and len(combined_text) < word_len: + combined_text += tokens[end_index] + if clean_words(combined_text) == clean_words(word): + return start_index, end_index + end_index += 1 + return -1, -1 + + +def is_word_in_sentence(sentence, word): + pattern = re.escape(word) + match = re.search(pattern, sentence, re.IGNORECASE) + return True if match else False + + +class ReasoningKeywordsProbs(StatCalculator): + """ + For Whitebox model (lm_polygraph.WhiteboxModel), at input texts batch calculates: + * model output for reasoning enhanced input, + * model answer for reasoning enhanced input, + * token probabilities for `reasoning_answer`, + * 
keywords from `reasoning_output`, + * probabilities for `reasoning_keywords`, + * contributions for `reasoning_keywords`, + * step-wise token indices for `reasoning_keywords`, + * token indices for `reasoning_keywords`. + """ + + @staticmethod + def meta_info() -> Tuple[List[str], List[str]]: + """ + Returns the statistics and dependencies for the calculator. + """ + return [ + "reasoning_output", + "reasoning_answer", + "reasoning_answer_tokens_probs", + "reasoning_keywords", + "reasoning_keywords_probabilities", + "reasoning_keywords_contributions", + "reasoning_keywords_token_ids", + "reasoning_answer_token_ids", + ], ["input_texts", "greedy_texts", "greedy_tokens", "greedy_log_probs"] + + def __init__(self, max_retries=5, max_length_cot=256, temperature=1): + super().__init__() + self.max_retries = max_retries + self.max_length_cot = max_length_cot + self.temperature = temperature + + def __call__( + self, + dependencies: Dict[str, np.array], + texts: List[str], + model: WhiteboxModel, + max_new_tokens: int = 100, + ) -> Dict[str, np.ndarray]: + """ + Calculates the statistics of reasoning enhanced process. + + Parameters: + dependencies (Dict[str, np.ndarray]): input statistics, can be empty (not used). + texts (List[str]): Input texts batch used for model generation. + model (Model): Model used for generation. + max_new_tokens (int): Maximum number of new tokens at model generation. Default: 100. 
+ Returns: + Dict[str, np.ndarray]: dictionary with the following items: + - 'reasoning_output' (List[str]): model output for reasoning enhanced input, + - 'reasoning_answer' (List[str]): model answer for reasoning enhanced input, + - 'reasoning_answer_tokens_probs' (List[str]): token probabilities for `reasoning_answer`, + - 'reasoning_keywords' (List[str]): keywords from `reasoning_output`, + - 'reasoning_keywords_probabilities' (List[Dict[str, Dict[str, List[int]]]]): probabilities for `reasoning_keywords`, + - 'reasoning_keywords_contributions' (List[Dict[str, Dict[str, int]]]): contributions for `reasoning_keywords`, + - 'reasoning_keywords_token_ids' (List[Dict[str, Dict[str, List[int]]]]): step-wise token indices for `reasoning_keywords`, + - 'reasoning_answer_token_ids' (List[Dict[str, List[int]]]): token indices for `reasoning_keywords`. + """ + result_dict = defaultdict(list) + batch_input_texts = dependencies['input_texts'] + batch_generated_texts = dependencies['greedy_texts'] + batch_generated_tokens = dependencies['greedy_tokens'] + batch_generated_log_probs = dependencies['greedy_log_probs'] + for input_text, generated_text, generated_tokens, generated_log_probs in zip(batch_input_texts, batch_generated_texts, batch_generated_tokens, batch_generated_log_probs): + question = re.search(r'Question:\s*(.*?)\s*Response:', input_text, re.DOTALL).group(1).strip() + # log.info(f"Input texts: {question}") + # log.info(f"Generated text: {generated_text}") + n_of_retries = 0 + while n_of_retries < self.max_retries: + # generated token ids for the question enchanced with CoT. 
+ generated_ids = generated_tokens + # generated text for the question enchaced with CoT + to_parse = model.tokenizer.decode(generated_ids, skip_special_tokens=True) + + llm_answer, steps_dict, response = parse_response_to_dict(to_parse) + + if len(generated_ids) == 0: + log.info(f'New Reasoning Tokens Are Null, Current try is {n_of_retries + 1}') + n_of_retries += 1 + continue + if llm_answer is None or llm_answer in ["", " "]: + log.info(f'New Reasoning Tokens Are None, Current try is {n_of_retries + 1}') + n_of_retries += 1 + continue + + # reasoning tokens + response_tokens = model.tokenizer.tokenize(response) + # full reasoning tokens + original_tokens = model.tokenizer.convert_ids_to_tokens(generated_ids) + probabilities = [ + {i: p for i, p in enumerate(prob) if p > 0} + for prob in [torch.softmax(torch.from_numpy(score), dim=0).tolist() for score in generated_log_probs] + ] + + final_answer_probabilities = {} + final_answer_token_ids = {} + answer_start_indice, answer_token_ids = match_final_answer_token_ids( + model.tokenizer, + original_tokens, + response_tokens, + generated_ids, + ) + if answer_start_indice is None: + log.info(f'Cannot locate the Final Answer, Current try is {n_of_retries + 1}') + n_of_retries += 1 + continue + answer_probs = [] + flag = False + for j, token_id in enumerate(answer_token_ids): + idxx = j + answer_start_indice + if token_id not in probabilities[idxx].keys(): + flag = True + break + answer_probs.append(probabilities[idxx][token_id]) + if flag: + # log.debug(f'Cannot locate the Final Answer Token Probability, Current try is {n_of_retries + 1}') + n_of_retries += 1 + continue + final_answer_probabilities[llm_answer] = answer_probs + final_answer_token_ids[llm_answer] = answer_token_ids + + # exacts_prompt = get_step_exact_tokens(args, q, response) + keywords_extraction_prompt = keywords_extraction_instruction.replace('', question).replace('', response) + chat = [{"role": "user", "content": keywords_extraction_prompt},] + 
keywords_extraction_prompt = model.tokenizer.apply_chat_template(chat, tokenize=False) + + keywords_extraction_prompt_output = predict(keywords_extraction_prompt, model, model.tokenizer, self.max_length_cot, self.temperature) + + parsed_keywords_output = step_exacts_2_list(keywords_extraction_prompt_output) + if not parsed_keywords_output: + log.info(f'Exact Tokens Have no contribution scores, Current try is {n_of_retries + 1}') + n_of_retries += 1 + continue + extracted_keywords, keywords_list, contributions_list = parsed_keywords_output + if len(keywords_list) == 0: + log.info(f'Cannot Exract Effective Keywords, Current try is {n_of_retries + 1}') + n_of_retries += 1 + continue + + if len(steps_dict) > len(keywords_list): + log.info(f'Len of keywords list doesn\'t match the len of step dict, Current try is {n_of_retries + 1}') + n_of_retries += 1 + continue + + keywords_probabilities = {} + keywords_contributions = {} + keywords_token_ids = {} + for step_idx, (step_name, step_text) in enumerate(steps_dict.items()): + # # Skip the Final Answer + keywords = keywords_list[step_idx] + contributions = contributions_list[step_idx] + if len(keywords) == 1 and keywords[0] == "NO ANSWER": + log.info("NO answer") + continue + step_tokens = model.tokenizer.tokenize(step_text) + space_token = model.tokenizer.tokenize(" ")[0] + processed_step_tokens = [ + (token[1:] if token.startswith(space_token) else token) + for token in step_tokens + ] + step_token_ids = model.tokenizer.convert_tokens_to_ids(step_tokens) + start_position = find_subsequence_position(step_token_ids[1:-2], generated_ids) - 1 + step_token_ids = generated_ids[start_position : start_position + len(step_tokens)] + keywords_probabilities_dict = {} + keywords_contributions_dict = {} + keywords_token_ids_dict = {} + for keyword_idx, keyword in enumerate(keywords): + keyword_probs = [] + keyword_token_ids = [] + if is_word_in_sentence(step_text, keyword) is not True: + log.info(f"\n{step_name}-Keyword-{keyword_idx} 
Does not appear in the Step Text") + continue + keyword_token_start_idx, keyword_token_end_idx = find_token_indices( + processed_step_tokens, keyword + ) + keyword_token_ids = generated_ids[ + start_position + keyword_token_start_idx : start_position + keyword_token_end_idx + 1 + ] + + for j, token_id in enumerate(keyword_token_ids): + idxx = start_position + keyword_token_start_idx + j + keyword_probs.append(probabilities[idxx][token_id]) + keywords_probabilities_dict[keyword] = keyword_probs + keywords_contributions_dict[keyword] = int(contributions[keyword_idx]) + keywords_token_ids_dict[keyword] = keyword_token_ids + + keywords_probabilities[step_name] = keywords_probabilities_dict + keywords_contributions[step_name] = keywords_contributions_dict + keywords_token_ids[step_name] = keywords_token_ids_dict + + if is_effectively_empty(keywords_probabilities): + log.info(f'Token Probability from All Steps are All None, Current try is {n_of_retries + 1}') + n_of_retries += 1 + continue + + # Dict[str, np.ndarray]: dictionary with the following items: + # - 'reasoning_output' (List[str]): model output for reasoning enhanced input, + # - 'reasoning_answer' (List[str]): model answer for reasoning enhanced input, + # - 'reasoning_answer_tokens_probs' (List[str]): token probabilities for `reasoning_answer`, + # - 'reasoning_keywords' (List[str]): keywords from `reasoning_output`, + # - 'reasoning_keywords_probabilities' (List[Dict[str, Dict[str, List[int]]]]): probabilities for `reasoning_keywords`, + # - 'reasoning_keywords_contributions' (List[Dict[str, Dict[str, int]]]): contributions for `reasoning_keywords`, + # - 'reasoning_keywords_token_ids' (List[Dict[str, Dict[str, List[int]]]]): step-wise token indices for `reasoning_keywords`, + # - 'reasoning_answer_token_ids' (List[Dict[str, List[int]]]): token indices for `reasoning_keywords`. 
+ + result_dict["reasoning_output"].append(response) + result_dict["reasoning_answer"].append(llm_answer) + result_dict["reasoning_answer_tokens_probs"].append(final_answer_probabilities) + result_dict["reasoning_keywords"].append(extracted_keywords) + result_dict["reasoning_keywords_probabilities"].append(keywords_probabilities) + result_dict["reasoning_keywords_contributions"].append(keywords_contributions) + result_dict["reasoning_keywords_token_ids"].append(keywords_token_ids) + result_dict["reasoning_answer_token_ids"].append(final_answer_token_ids) + break + + if n_of_retries >= self.max_retries: + # log.debug(f'#####The Following Question:#####\n{q}\nHas no Meaningful Answer & Explanations, Record and Skip') + result_dict["reasoning_output"].append(response) + result_dict["reasoning_answer"].append(llm_answer) + result_dict["reasoning_answer_tokens_probs"].append(None) + result_dict["reasoning_keywords"].append(None) + result_dict["reasoning_keywords_probabilities"].append(None) + result_dict["reasoning_keywords_contributions"].append(None) + result_dict["reasoning_keywords_token_ids"].append(None) + result_dict["reasoning_answer_token_ids"].append(None) + + return result_dict diff --git a/src/lm_polygraph/stat_calculators/stat_calculator.py b/src/lm_polygraph/stat_calculators/stat_calculator.py index e6e6655c4..031e4f163 100644 --- a/src/lm_polygraph/stat_calculators/stat_calculator.py +++ b/src/lm_polygraph/stat_calculators/stat_calculator.py @@ -18,7 +18,7 @@ class StatCalculator(ABC): UEManager at lm_polygraph.utils.manager will order all the needed calculators and estimators to be called in the correct order. Any cycle dependencies among calculators will be spotted by UEManager and end with an exception. - Each new StatCalculator needs to be registered at lm_polygraph/stat_calculators/__init__.py to be seen be UEManager. + Each new StatCalculator needs to be registered at lm_polygraph/stat_calculators/__init__.py to be seen by UEManager. 
""" @staticmethod diff --git a/src/lm_polygraph/utils/factory_estimator.py b/src/lm_polygraph/utils/factory_estimator.py index c1e13b5b0..24c859edd 100644 --- a/src/lm_polygraph/utils/factory_estimator.py +++ b/src/lm_polygraph/utils/factory_estimator.py @@ -46,6 +46,7 @@ def load_simple_estimators(name: str, config): ClaimConditionedProbabilityClaim, RandomBaselineClaim, FocusClaim, + ProbasMeanWithCoT, ] try: diff --git a/src/lm_polygraph/utils/manager.py b/src/lm_polygraph/utils/manager.py index c6416b4bd..131c9cfd1 100644 --- a/src/lm_polygraph/utils/manager.py +++ b/src/lm_polygraph/utils/manager.py @@ -58,6 +58,24 @@ def _delete_nans(ue, metric): return clipped_ue, new_metric +def _recombine_data(ue, gen_metric, inputs): + ue = np.array(ue) + gen_metric = np.array(gen_metric) + + # np.unique() with return_counts=True? + recombined_inputs = defaultdict(list) + for i, input_text in enumerate(inputs): + recombined_inputs[input_text].append(i) + + recombined_ue, recombined_gen_metric = [], [] + for input_text, ids in recombined_inputs.items(): + recombined_ue.append(ue[ids].mean()) + # Assumes that metric is bigger for better generations! + recombined_gen_metric.append(gen_metric[ids].max()) + + return recombined_ue, recombined_gen_metric + + def order_calculators( stats: List[str], stat_calculators: Dict[str, StatCalculator], diff --git a/test/test_estimators.py b/test/test_estimators.py index 50dcd260a..48d3faa8d 100644 --- a/test/test_estimators.py +++ b/test/test_estimators.py @@ -244,3 +244,9 @@ def test_eigenscore(model): estimator = EigenScore() ue = estimate_uncertainty(model, estimator, INPUT) assert isinstance(ue.uncertainty, float) + +def test_probas_mean_with_cot(model): + estimator = ProbasMeanWithCoT() + ue = estimate_uncertainty(model, estimator, INPUT) + assert isinstance(ue.uncertainty, float) +