From 8d1f79b1155aedd674a4e87ae526ccd2fd9c5fab Mon Sep 17 00:00:00 2001 From: Adam Kovacs Date: Mon, 5 May 2025 16:25:55 +0000 Subject: [PATCH 1/2] Optimized LLM inference, introduced cache --- .gitignore | 5 +- lettucedetect/models/inference.py | 271 +++++++++++------- lettucedetect/prompts/examples_de.json | 21 ++ .../prompts/hallucination_detection.txt | 29 ++ 4 files changed, 224 insertions(+), 102 deletions(-) create mode 100644 lettucedetect/prompts/examples_de.json create mode 100644 lettucedetect/prompts/hallucination_detection.txt diff --git a/.gitignore b/.gitignore index d04ce72..b81c6d0 100644 --- a/.gitignore +++ b/.gitignore @@ -175,4 +175,7 @@ data/ # output/ output/ -temp/ \ No newline at end of file +temp/ + +# cache/ +lettucedetect/cache/ \ No newline at end of file diff --git a/lettucedetect/models/inference.py b/lettucedetect/models/inference.py index a0500fd..cba4bda 100644 --- a/lettucedetect/models/inference.py +++ b/lettucedetect/models/inference.py @@ -1,7 +1,10 @@ +import hashlib import json import os import re from abc import ABC, abstractmethod +from pathlib import Path +from string import Template import torch from openai import OpenAI @@ -11,6 +14,7 @@ HallucinationDataset, ) +# ==== For formatting user input to the right format ==== PROMPT_QA = """ Briefly answer the following question: {question} @@ -25,72 +29,10 @@ {text} output: """ - -PROMPT_LLM = """ - -You will act as an expert annotator to evaluate an answer against a provided source text. -The source text will be given within ... XML tags. -The answer will be given within ... XML tags. - -For each answer, follow these steps: -Step 1: Read and fully understand the answer in german. The answer is a text containing information related to the source text. -Step 2: Thoroughly analyze how the answer relates to the information in the source text. Determine whether the answer contains hallucinations. Hallucinations are sentences that contain one of the following information: - a. conflict: instances where the answer presents direct contraction or opposition to the original source. - b. baseless info: instances where the generated answer includes information which is not inferred from the original source. -Step 3: Determine whether the answer contains any hallucinations. If no hallucinations are found, return an empty list. -Step 4: Compile the labeled hallucinated spans found into a JSON dict, with a key "hallucination list" and its value is a list of -hallucinated spans. If there exist potential hallucinations, the output should be in the following JSON format: {{"hallucination -list": [hallucination span1, hallucination span2, ...]}}. In case of no hallucinations, please output an empty list : {{"hallucination -list": []}}. -Output only the JSON dict. - - - -Given below are three examples for you to comprehend the task. - - - -Source: Was ist die Hauptstadt von Frankreich? Wie hoch ist die Bevölkerung Frankreichs? Frankreich ist ein Land in Europa. Die Hauptstadt von Frankreich ist Paris. Die Bevölkerung Frankreichs beträgt 67 Millionen. -Answer: Die Hauptstadt von Frankreich ist Paris. Die Bevölkerung Frankreichs beträgt 69 Millionen. - -1.The answer states that Paris is capital of France. This matches the source and is correct. -2.The answer states that the population of France is 69 million. This condradicts the source that the population is actually 67 million. -Hallucination -> "Die Bevölkerung von Frankreich beträgt 69 Millionen." -Therefore, output only {{"hallucination list": ["Die Bevölkerung Frankreichs beträgt 69 Millionen." ]}} - - - -Source: Was ist die Hauptstadt von Frankreich? Wie hoch ist die Bevölkerung Frankreichs? Die Hauptstadt von Frankreich ist Paris. Die Bevölkerung von Frankreich beträgt 67 Millionen. -Answer: Die Hauptstadt von Frankreich ist Paris. Die Bevölkerung von Frankreich beträgt 67 Millionen, und die Amtssprache ist Spanisch. - -1.The answer states that Paris is capital of France. This matches the source and is correct. -2.The answer states that the population of France is 69 million. This matches the source and is correct. -3. The answer states that the language spoken in France is Spanish. This is incorrect and not supported by the source. -Hallucination -> "die Amtssprache ist Spanisch" -Therefore, output only {{"hallucination list": ["die Amtssprache ist Spanisch" ]}} - - - - -Source: Was ist die Hauptstadt von Österreich? Wie hoch ist die Bevölkerung Österreich? Österreich ist ein Land in Europa. Die Hauptstadt von Österreich ist Wien. Die Bevölkerung Österreichs beträgt 9.1 Millionen. -Answer: Die Hauptstadt von Österreich ist Wien. Die Bevölkerung Österreichs beträgt 9.1 Millionen. -1.The answer states that Vienna is capital of Austria. This matches the source and is correct. -2.The answer states that the population of Austria is 9.1 million. This matches the source and is correct. -Hallucination -> No hallucinations found -Therefore, output only {{"hallucination list": []}} - - -\n - -{context} - -\n - -{answer} - -)""" +# ===================================================== +# ==== Base class for all detectors ==== class BaseDetector(ABC): @abstractmethod def predict(self, context: str, answer: str, output_format: str = "tokens") -> list: @@ -103,6 +45,7 @@ def predict(self, context: str, answer: str, output_format: str = "tokens") -> l pass +# ==== Transformer-based detector ==== class TransformerDetector(BaseDetector): def __init__(self, model_path: str, max_length: int = 4096, device=None, **kwargs): """Initialize the TransformerDetector. @@ -271,17 +214,122 @@ def predict( return self._predict(prompt, answer, output_format) +# ==== LLM-based detector ==== +ANNOTATE_SCHEMA = [ + { + "type": "function", + "function": { + "name": "annotate", + "description": "Return hallucinated substrings from the answer relative to the source.", + "parameters": { + "type": "object", + "properties": { + "hallucination_list": { + "type": "array", + "items": {"type": "string"}, + } + }, + "required": ["hallucination_list"], + }, + }, + } +] + + class LLMDetector(BaseDetector): - def __init__(self, model: str = "gpt-4o", temperature: int = 0): + """LLM-powered hallucination detector using function calling and a prompt template.""" + + def __init__( + self, + model: str = "gpt-4o", + temperature: int = 0, + lang: str = "en", + fewshot_path: str | None = None, + prompt_path: str | None = None, + cache_file: str | None = None, + ): """Initialize the LLMDetector. :param model: OpenAI model. :param temperature: model temperature. + :param lang: language of the examples. + :param fewshot_path: path to the fewshot examples. + :param prompt_path: path to the prompt template. + :param cache_file: path to the cache file. """ self.model = model self.temperature = temperature - def _form_prompt(self, context: list[str], question: str | None) -> str: + self.lang = lang + + if fewshot_path is None: + print( + f"No fewshot path provided, using default path: {Path(__file__).parent.parent / 'prompts' / f'examples_{lang.lower()}.json'}" + ) + fewshot_path = ( + Path(__file__).parent.parent / "prompts" / f"examples_{lang.lower()}.json" + ) + + if not fewshot_path.exists(): + raise FileNotFoundError(f"Fewshot file not found at {fewshot_path}") + else: + fewshot_path = Path(fewshot_path) + + if prompt_path is None: + print( + f"No prompt path provided, using default path: {Path(__file__).parent.parent / 'prompts' / 'hallucination_detection.txt'}" + ) + template_path = Path(__file__).parent.parent / "prompts" / "hallucination_detection.txt" + else: + template_path = Path(prompt_path) + + self.template = Template(template_path.read_text(encoding="utf-8")) + self.fewshot = json.loads(fewshot_path.read_text(encoding="utf-8")) + self.cache_path = cache_file + + if cache_file is None: + self.cache_path = ( + Path(__file__).parent.parent / "cache" / f"cache_{self.model}_{self.lang}.json" + ) + self.cache_path.parent.mkdir(parents=True, exist_ok=True) + self.cache = {} + print(f"Cache file not provided, using default path: {self.cache_path}") + else: + self.cache_path = Path(cache_file) + if not self.cache_path.exists(): + raise FileNotFoundError(f"Cache file not found at {self.cache_path}") + self.cache = json.loads(self.cache_path.read_text(encoding="utf-8")) + + def _build_prompt( + self, + context: str, + answer: str, + ) -> str: + """Fill the template with runtime values, inserting zero or many few‑shot examples. + Uses `${placeholder}` tokens in the .txt file. + """ + fewshot_block = "" + if self.fewshot: + lines: list[str] = [] + for idx, ex in enumerate(self.fewshot, 1): + lines.append( + f""" +{ex["source"]} +{ex["answer"]} +{{"hallucination_list": {json.dumps(ex["hallucination_list"], ensure_ascii=False)} }} +""" + ) + fewshot_block = "\n".join(lines) + + filled = self.template.substitute( + lang=self.lang, + context=context, + answer=answer, + fewshot_block=fewshot_block, + ) + return filled + + def _form_context(self, context: list[str], question: str | None) -> str: """Form a prompt from the provided context and question. We use different prompts for summary and QA tasks. :param context: A list of context strings. :param question: The question string. @@ -297,23 +345,6 @@ def _form_prompt(self, context: list[str], question: str | None) -> str: question=question, num_passages=len(context), context=context_str ) - def _create_labels(self, llm_content: str, answer: str) -> list: - """Create hallucination labels for each answer.""" - labels = [] - match_dict = re.search(r"\{.*?\}", llm_content, re.DOTALL) - try: - hal_dict = match_dict.group(0) - hal_dict = json.loads(hal_dict) - except json.JSONDecodeError: - return labels - - for hal in hal_dict["hallucination list"]: - match = re.search(re.escape(hal), answer) - if match: - labels.append({"start": match.start(), "end": match.end(), "text": hal}) - - return labels - def _get_openai_client(self) -> OpenAI: """Get OpenAI client configured from environment variables. @@ -326,6 +357,48 @@ def _get_openai_client(self) -> OpenAI: api_key=api_key, ) + def _hash(self, prompt: str) -> str: + """Hash the prompt.""" + return hashlib.sha256(prompt.encode("utf-8")).hexdigest() + + def _call_openai(self, prompt: str) -> str: + """Call the OpenAI API. + + :param prompt: The prompt to call the OpenAI API with. + :return: The response from the OpenAI API. + """ + client = self._get_openai_client() + resp = client.chat.completions.create( + model=self.model, + messages=[ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": prompt}, + ], + tools=ANNOTATE_SCHEMA, + tool_choice={"type": "function", "function": {"name": "annotate"}}, + temperature=self.temperature, + ) + + return resp.choices[0].message.tool_calls[0].function.arguments + + def _save_cache(self): + """Save the cache to the cache file.""" + self.cache_path.write_text(json.dumps(self.cache, ensure_ascii=False), encoding="utf-8") + + def _to_spans(self, subs: list[str], answer: str) -> list[dict]: + """Convert a list of substrings to a list of spans. + + :param subs: A list of substrings. + :param answer: The answer string. + :return: A list of spans. + """ + spans = [] + for s in subs: + m = re.search(re.escape(s), answer) + if m: + spans.append({"start": m.start(), "end": m.end(), "text": s}) + return spans + def _predict(self, context: str, answer: str, output_format: str = "spans") -> list: """Prompts the ChatGPT model to predict hallucination spans from the provided context and answer. @@ -333,24 +406,20 @@ def _predict(self, context: str, answer: str, output_format: str = "spans") -> l :param answer: The answer string. :param output_format: works only for "spans" and returns grouped spans. """ - client = self._get_openai_client() - if output_format == "spans": - llm_prompt = PROMPT_LLM.format(context=context, answer=answer) - llm_response = client.chat.completions.create( - model=self.model, - messages=[ - { - "role": "system", - "content": "You are a helpful assistant.", - }, - {"role": "user", "content": llm_prompt}, - ], - temperature=self.temperature, - ) - llm_content = llm_response.choices[0].message.content - predictions = self._create_labels(llm_content, answer) - return predictions + llm_prompt = self._build_prompt(context, answer) + + key = self._hash("||".join([llm_prompt, self.model, str(self.temperature)])) + + # Check if the response is cached + cached_response = self.cache.get(key) + if cached_response is None: + cached_response = self._call_openai(llm_prompt) + self.cache[key] = cached_response + self._save_cache() + + payload = json.loads(cached_response) + return self._to_spans(payload["hallucination_list"], answer) else: raise ValueError( "Invalid output_format. This model can only predict hallucination spans. Use spans." @@ -380,7 +449,7 @@ def predict( :param question: The question string. :param output_format: "spans" to return grouped spans. """ - prompt = self._form_prompt(context, question) + prompt = self._form_context(context, question) return self._predict(prompt, answer, output_format=output_format) diff --git a/lettucedetect/prompts/examples_de.json b/lettucedetect/prompts/examples_de.json new file mode 100644 index 0000000..43c0dd4 --- /dev/null +++ b/lettucedetect/prompts/examples_de.json @@ -0,0 +1,21 @@ +[ + { + "source": "Was ist die Hauptstadt von Frankreich? Wie hoch ist die Bevölkerung Frankreichs? Frankreich ist ein Land in Europa. Die Hauptstadt von Frankreich ist Paris. Die Bevölkerung Frankreichs beträgt 67 Millionen.", + "answer": "Die Hauptstadt von Frankreich ist Paris. Die Bevölkerung Frankreichs beträgt 69 Millionen.", + "hallucination_list": [ + "Die Bevölkerung Frankreichs beträgt 69 Millionen." + ] + }, + { + "source": "Was ist die Hauptstadt von Frankreich? Wie hoch ist die Bevölkerung Frankreichs? Die Hauptstadt von Frankreich ist Paris. Die Bevölkerung von Frankreich beträgt 67 Millionen.", + "answer": "Die Hauptstadt von Frankreich ist Paris. Die Bevölkerung von Frankreich beträgt 67 Millionen, und die Amtssprache ist Spanisch.", + "hallucination_list": [ + "die Amtssprache ist Spanisch" + ] + }, + { + "source": "Was ist die Hauptstadt von Österreich? Wie hoch ist die Bevölkerung Österreich? Österreich ist ein Land in Europa. Die Hauptstadt von Österreich ist Wien. Die Bevölkerung Österreichs beträgt 9.1 Millionen.", + "answer": "Die Hauptstadt von Österreich ist Wien. Die Bevölkerung Österreichs beträgt 9.1 Millionen.", + "hallucination_list": [] + } +] \ No newline at end of file diff --git a/lettucedetect/prompts/hallucination_detection.txt b/lettucedetect/prompts/hallucination_detection.txt new file mode 100644 index 0000000..e9b2f6e --- /dev/null +++ b/lettucedetect/prompts/hallucination_detection.txt @@ -0,0 +1,29 @@ + +You are an expert annotator who must identify hallucinated substrings in a generated **answer** with respect to a given **source**. + +## Language +- The source and answer are written in **${lang}**. +- Respond **in ${lang} only**. + +## Step‑by‑step instructions +1. **Read** the answer inside . +2. **Compare** each statement with the information in …. + - *Hallucination* = a substring that **(a)** contradicts the source **or** **(b)** introduces facts not supported by the source. +3. **Decide** whether the answer contains any hallucinations. +4. **Return** a JSON object following *exactly* this schema + (no extra keys, no markdown, no code‑block fences): + + `{"hallucination_list": ["substring1", "substring2", …]}` + + If none are found, return `{"hallucination_list": []}`. + + +${fewshot_block} + + +${context} + + + +${answer} + From b6bdf4bf9bd85877d73eda46a2ae1137868dd8fa Mon Sep 17 00:00:00 2001 From: Adam Kovacs Date: Tue, 6 May 2025 12:00:50 +0000 Subject: [PATCH 2/2] Reorganized LLM baseline --- docs/EVALUATION.md | 13 ++ lettucedetect/models/evaluator.py | 50 +----- lettucedetect/models/inference.py | 93 +++++++--- lettucedetect/prompts/examples_en.json | 35 ++++ lettucedetect/prompts/examples_es.json | 35 ++++ lettucedetect/prompts/examples_fr.json | 35 ++++ lettucedetect/prompts/examples_it.json | 35 ++++ lettucedetect/prompts/examples_pl.json | 35 ++++ .../prompts/hallucination_detection.txt | 4 +- lettucedetect/prompts/qa_prompt_de.txt | 6 + lettucedetect/prompts/qa_prompt_en.txt | 6 + lettucedetect/prompts/qa_prompt_es.txt | 6 + lettucedetect/prompts/qa_prompt_fr.txt | 6 + lettucedetect/prompts/qa_prompt_it.txt | 6 + lettucedetect/prompts/qa_prompt_pl.txt | 6 + lettucedetect/prompts/summary_prompt_de.txt | 3 + lettucedetect/prompts/summary_prompt_en.txt | 3 + lettucedetect/prompts/summary_prompt_es.txt | 3 + lettucedetect/prompts/summary_prompt_fr.txt | 3 + lettucedetect/prompts/summary_prompt_it.txt | 3 + lettucedetect/prompts/summary_prompt_pl.txt | 3 + scripts/evaluate.py | 159 +++--------------- scripts/evaluate_llm.py | 132 +++++++++++++++ 23 files changed, 476 insertions(+), 204 deletions(-) create mode 100644 docs/EVALUATION.md create mode 100644 lettucedetect/prompts/examples_en.json create mode 100644 lettucedetect/prompts/examples_es.json create mode 100644 lettucedetect/prompts/examples_fr.json create mode 100644 lettucedetect/prompts/examples_it.json create mode 100644 lettucedetect/prompts/examples_pl.json create mode 100644 lettucedetect/prompts/qa_prompt_de.txt create mode 100644 lettucedetect/prompts/qa_prompt_en.txt create mode 100644 lettucedetect/prompts/qa_prompt_es.txt create mode 100644 lettucedetect/prompts/qa_prompt_fr.txt create mode 100644 lettucedetect/prompts/qa_prompt_it.txt create mode 100644 lettucedetect/prompts/qa_prompt_pl.txt create mode 100644 lettucedetect/prompts/summary_prompt_de.txt create mode 100644 lettucedetect/prompts/summary_prompt_en.txt create mode 100644 lettucedetect/prompts/summary_prompt_es.txt create mode 100644 lettucedetect/prompts/summary_prompt_fr.txt create mode 100644 lettucedetect/prompts/summary_prompt_it.txt create mode 100644 lettucedetect/prompts/summary_prompt_pl.txt create mode 100644 scripts/evaluate_llm.py diff --git a/docs/EVALUATION.md b/docs/EVALUATION.md new file mode 100644 index 0000000..c6b8cdb --- /dev/null +++ b/docs/EVALUATION.md @@ -0,0 +1,13 @@ +# Evaluation + +## Use LLM baselines + +```bash +python scripts/evaluate_llm.py --model "gpt-4o-mini" --data_path "data/translated/ragtruth-de-translated-300sample.json" --evaluation_type "example_level" +``` + +## Use HallucinationDetector + +```bash +python scripts/evaluate.py --model_path "output/hallucination_detection_de_210m" --data_path "data/translated/ragtruth-de-translated-300sample.json" --evaluation_type "example_level" +``` diff --git a/lettucedetect/models/evaluator.py b/lettucedetect/models/evaluator.py index 476f806..2b6b451 100644 --- a/lettucedetect/models/evaluator.py +++ b/lettucedetect/models/evaluator.py @@ -9,7 +9,7 @@ from torch.utils.data import DataLoader from tqdm.auto import tqdm -from lettucedetect.datasets.hallucination_dataset import HallucinationData, HallucinationSample +from lettucedetect.datasets.hallucination_dataset import HallucinationSample from lettucedetect.models.inference import HallucinationDetector @@ -224,8 +224,6 @@ def create_sample_llm(sample, labels): def evaluate_detector_char_level( detector: HallucinationDetector, samples: list[HallucinationSample], - samples_llm: list[HallucinationSample] = None, - baseline_file_exists: bool = False, ) -> dict[str, float]: """Evaluate the HallucinationDetector at the character level. @@ -241,24 +239,17 @@ def evaluate_detector_char_level( :param detector: The detector to evaluate. :param samples: A list of samples to evaluate. - :param samples_llm : A list of samples containing LLM generated labels, is used if baseline file exists. - :param baseline_file_exists: Gives information if baseline file exists or should be created. :return: A dictionary with global metrics: {"char_precision": ..., "char_recall": ..., "char_f1": ...} """ total_overlap = 0 total_predicted = 0 total_gold = 0 - hallucination_data_llm = HallucinationData(samples=[]) - for i, sample in enumerate(tqdm(samples, desc="Evaluating", leave=False)): + for sample in tqdm(samples, desc="Evaluating", leave=False): prompt = sample.prompt answer = sample.answer gold_spans = sample.labels - predicted_spans = ( - samples_llm[i].labels - if baseline_file_exists - else detector.predict_prompt(prompt, answer, output_format="spans") - ) + predicted_spans = detector.predict_prompt(prompt, answer, output_format="spans") # Compute total predicted span length for this sample. sample_predicted_length = sum(pred["end"] - pred["start"] for pred in predicted_spans) @@ -278,22 +269,16 @@ def evaluate_detector_char_level( sample_overlap += overlap_end - overlap_start total_overlap += sample_overlap - if not baseline_file_exists: - sample_llm = create_sample_llm(sample, predicted_spans) - hallucination_data_llm.samples.append(sample_llm) - precision = total_overlap / total_predicted if total_predicted > 0 else 0 recall = total_overlap / total_gold if total_gold > 0 else 0 f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0 - return {"precision": precision, "recall": recall, "f1": f1}, hallucination_data_llm + return {"precision": precision, "recall": recall, "f1": f1} def evaluate_detector_example_level( detector: HallucinationDetector, samples: list[HallucinationSample], - samples_llm: list[HallucinationSample] = None, - baseline_file_exists: bool = None, verbose: bool = True, ) -> dict[str, dict[str, float]]: """Evaluate the HallucinationDetector at the example level. @@ -319,27 +304,18 @@ def evaluate_detector_example_level( """ example_preds: list[int] = [] example_labels: list[int] = [] - hallucination_data_llm = HallucinationData(samples=[]) - for i, sample in enumerate(tqdm(samples, desc="Evaluating", leave=False)): + for sample in tqdm(samples, desc="Evaluating", leave=False): prompt = sample.prompt answer = sample.answer gold_spans = sample.labels - predicted_spans = ( - samples_llm.__getitem__(i).labels - if baseline_file_exists - else detector.predict_prompt(prompt, answer, output_format="spans") - ) + predicted_spans = detector.predict_prompt(prompt, answer, output_format="spans") true_example_label = 1 if gold_spans else 0 pred_example_label = 1 if predicted_spans else 0 example_labels.append(true_example_label) example_preds.append(pred_example_label) - if not baseline_file_exists: - sample_llm = create_sample_llm(sample, predicted_spans) - hallucination_data_llm.samples.append(sample_llm) - precision, recall, f1, _ = precision_recall_fscore_support( example_labels, example_preds, labels=[0, 1], average=None, zero_division=0 ) @@ -361,18 +337,6 @@ def evaluate_detector_example_level( fpr, tpr, _ = roc_curve(example_labels, example_preds) auroc = auc(fpr, tpr) - results: dict[str, dict[str, float]] = { - "supported": { # Class 0 - "precision": float(precision[0]), - "recall": float(recall[0]), - "f1": float(f1[0]), - }, - "hallucinated": { # Class 1 - "precision": float(precision[1]), - "recall": float(recall[1]), - "f1": float(f1[1]), - }, - } results["auroc"] = auroc if verbose: @@ -387,4 +351,4 @@ def evaluate_detector_example_level( print(report) results["classification_report"] = report - return results, hallucination_data_llm + return results diff --git a/lettucedetect/models/inference.py b/lettucedetect/models/inference.py index cba4bda..f638954 100644 --- a/lettucedetect/models/inference.py +++ b/lettucedetect/models/inference.py @@ -5,6 +5,7 @@ from abc import ABC, abstractmethod from pathlib import Path from string import Template +from typing import Literal import torch from openai import OpenAI @@ -14,22 +15,14 @@ HallucinationDataset, ) -# ==== For formatting user input to the right format ==== -PROMPT_QA = """ -Briefly answer the following question: -{question} -Bear in mind that your response should be strictly based on the following {num_passages} passages: -{context} -In case the passages do not contain the necessary information to answer the question, please reply with: "Unable to answer based on given passages." -output: -""" - -PROMPT_SUMMARY = """ -Summarize the following text: -{text} -output: -""" -# ===================================================== +LANG_TO_PASSAGE = { + "en": "passage", + "de": "Passage", + "fr": "passage", + "es": "pasaje", + "it": "brano", + "pl": "fragment", +} # ==== Base class for all detectors ==== @@ -47,13 +40,25 @@ def predict(self, context: str, answer: str, output_format: str = "tokens") -> l # ==== Transformer-based detector ==== class TransformerDetector(BaseDetector): - def __init__(self, model_path: str, max_length: int = 4096, device=None, **kwargs): + def __init__( + self, + model_path: str, + max_length: int = 4096, + device=None, + lang: Literal["en", "de", "fr", "es", "it", "pl"] = "en", + **kwargs, + ): """Initialize the TransformerDetector. :param model_path: The path to the model. :param max_length: The maximum length of the input sequence. :param device: The device to run the model on. + :param lang: The language of the model. """ + if lang not in LANG_TO_PASSAGE: + raise ValueError(f"Invalid language. Use one of: {', '.join(LANG_TO_PASSAGE.keys())}") + + self.lang = lang self.tokenizer = AutoTokenizer.from_pretrained(model_path, **kwargs) self.model = AutoModelForTokenClassification.from_pretrained(model_path, **kwargs) self.max_length = max_length @@ -61,6 +66,13 @@ def __init__(self, model_path: str, max_length: int = 4096, device=None, **kwarg self.model.to(self.device) self.model.eval() + prompt_path = Path(__file__).parent.parent / "prompts" / f"qa_prompt_{lang.lower()}.txt" + self.prompt_qa = Template(prompt_path.read_text(encoding="utf-8")) + prompt_path = ( + Path(__file__).parent.parent / "prompts" / f"summary_prompt_{lang.lower()}.txt" + ) + self.prompt_summary = Template(prompt_path.read_text(encoding="utf-8")) + def _form_prompt(self, context: list[str], question: str | None) -> str: """Form a prompt from the provided context and question. We use different prompts for summary and QA tasks. @@ -69,12 +81,15 @@ def _form_prompt(self, context: list[str], question: str | None) -> str: :return: The formatted prompt. """ context_str = "\n".join( - [f"passage {i + 1}: {passage}" for i, passage in enumerate(context)] + [ + f"{LANG_TO_PASSAGE[self.lang]} {i + 1}: {passage}" + for i, passage in enumerate(context) + ] ) if question is None: - return PROMPT_SUMMARY.format(text=context_str) + return self.prompt_summary.substitute(text=context_str) else: - return PROMPT_QA.format( + return self.prompt_qa.substitute( question=question, num_passages=len(context), context=context_str ) @@ -243,7 +258,8 @@ def __init__( self, model: str = "gpt-4o", temperature: int = 0, - lang: str = "en", + lang: Literal["en", "de", "fr", "es", "it", "pl"] = "en", + zero_shot: bool = False, fewshot_path: str | None = None, prompt_path: str | None = None, cache_file: str | None = None, @@ -253,6 +269,7 @@ def __init__( :param model: OpenAI model. :param temperature: model temperature. :param lang: language of the examples. + :param zero_shot: whether to use zero-shot prompting. :param fewshot_path: path to the fewshot examples. :param prompt_path: path to the prompt template. :param cache_file: path to the cache file. @@ -260,8 +277,11 @@ def __init__( self.model = model self.temperature = temperature - self.lang = lang + if lang not in LANG_TO_PASSAGE: + raise ValueError(f"Invalid language. Use one of: {', '.join(LANG_TO_PASSAGE.keys())}") + self.lang = lang + self.zero_shot = zero_shot if fewshot_path is None: print( f"No fewshot path provided, using default path: {Path(__file__).parent.parent / 'prompts' / f'examples_{lang.lower()}.json'}" @@ -283,7 +303,15 @@ def __init__( else: template_path = Path(prompt_path) + prompt_qa_path = Path(__file__).parent.parent / "prompts" / f"qa_prompt_{lang.lower()}.txt" + prompt_summary_path = ( + Path(__file__).parent.parent / "prompts" / f"summary_prompt_{lang.lower()}.txt" + ) + self.template = Template(template_path.read_text(encoding="utf-8")) + self.prompt_qa = Template(prompt_qa_path.read_text(encoding="utf-8")) + self.prompt_summary = Template(prompt_summary_path.read_text(encoding="utf-8")) + self.fewshot = json.loads(fewshot_path.read_text(encoding="utf-8")) self.cache_path = cache_file @@ -292,7 +320,13 @@ def __init__( Path(__file__).parent.parent / "cache" / f"cache_{self.model}_{self.lang}.json" ) self.cache_path.parent.mkdir(parents=True, exist_ok=True) - self.cache = {} + + # Read in cache + if self.cache_path.exists(): + self.cache = json.loads(self.cache_path.read_text(encoding="utf-8")) + else: + self.cache = {} + print(f"Cache file not provided, using default path: {self.cache_path}") else: self.cache_path = Path(cache_file) @@ -309,7 +343,7 @@ def _build_prompt( Uses `${placeholder}` tokens in the .txt file. """ fewshot_block = "" - if self.fewshot: + if self.fewshot and not self.zero_shot: lines: list[str] = [] for idx, ex in enumerate(self.fewshot, 1): lines.append( @@ -339,9 +373,9 @@ def _form_context(self, context: list[str], question: str | None) -> str: [f"passage {i + 1}: {passage}" for i, passage in enumerate(context)] ) if question is None: - return PROMPT_SUMMARY.format(text=context_str) + return self.prompt_summary.substitute(text=context_str) else: - return PROMPT_QA.format( + return self.prompt_qa.substitute( question=question, num_passages=len(context), context=context_str ) @@ -352,9 +386,11 @@ def _get_openai_client(self) -> OpenAI: :raises ValueError: If API key is not set """ api_key = os.getenv("OPENAI_API_KEY") or "EMPTY" + api_base = os.getenv("OPENAI_API_BASE") or "https://api.openai.com/v1" return OpenAI( api_key=api_key, + base_url=api_base, ) def _hash(self, prompt: str) -> str: @@ -371,7 +407,10 @@ def _call_openai(self, prompt: str) -> str: resp = client.chat.completions.create( model=self.model, messages=[ - {"role": "system", "content": "You are a helpful assistant."}, + { + "role": "system", + "content": "You are an expert in detecting hallucinations in LLM outputs.", + }, {"role": "user", "content": prompt}, ], tools=ANNOTATE_SCHEMA, diff --git a/lettucedetect/prompts/examples_en.json b/lettucedetect/prompts/examples_en.json new file mode 100644 index 0000000..86e2f4f --- /dev/null +++ b/lettucedetect/prompts/examples_en.json @@ -0,0 +1,35 @@ +[ + { + "source": "What is the capital of France? What is the population of France? France is a country in Europe. The capital of France is Paris. The population of France is 67 million.", + "answer": "The capital of France is Paris. The population of France is 69 million.", + "hallucination_list": [ + "The population of France is 69 million." + ] + }, + { + "source": "What is the capital of France? What is the population of France? The capital of France is Paris. The population of France is 67 million.", + "answer": "The capital of France is Paris. The population of France is 67 million, and the official language is Spanish.", + "hallucination_list": [ + "the official language is Spanish" + ] + }, + { + "source": "What is the capital of Austria? What is the population of Austria? Austria is a country in Europe. The capital of Austria is Vienna. The population of Austria is 9.1 million.", + "answer": "The capital of Austria is Vienna. The population of Austria is 9.1 million.", + "hallucination_list": [] + }, + { + "source": "Who was the first person to walk on the moon? When did this happen? Neil Armstrong was the first person to walk on the moon. This historic event took place on July 20, 1969, during the Apollo 11 mission.", + "answer": "Neil Armstrong was the first person to walk on the moon. This historic event took place on July 16, 1969, during the Apollo 11 mission.", + "hallucination_list": [ + "This historic event took place on July 16, 1969" + ] + }, + { + "source": "What is the tallest mountain in the world? Mount Everest is the tallest mountain in the world, with a height of 8,848 meters above sea level. It is located in the Mahalangur Himal sub-range of the Himalayas on the border between Nepal and Tibet.", + "answer": "Mount Everest is the tallest mountain in the world, with a height of 8,848 meters above sea level. It is located in the Mahalangur Himal sub-range of the Himalayas on the border between India and China.", + "hallucination_list": [ + "It is located in the Mahalangur Himal sub-range of the Himalayas on the border between India and China." + ] + } +] diff --git a/lettucedetect/prompts/examples_es.json b/lettucedetect/prompts/examples_es.json new file mode 100644 index 0000000..db3efaf --- /dev/null +++ b/lettucedetect/prompts/examples_es.json @@ -0,0 +1,35 @@ +[ + { + "source": "¿Cuál es la capital de Francia? ¿Cuál es la población de Francia? Francia es un país en Europa. La capital de Francia es París. La población de Francia es de 67 millones.", + "answer": "La capital de Francia es París. La población de Francia es de 69 millones.", + "hallucination_list": [ + "La población de Francia es de 69 millones." + ] + }, + { + "source": "¿Cuál es la capital de Francia? ¿Cuál es la población de Francia? La capital de Francia es París. La población de Francia es de 67 millones.", + "answer": "La capital de Francia es París. La población de Francia es de 67 millones, y el idioma oficial es el italiano.", + "hallucination_list": [ + "el idioma oficial es el italiano" + ] + }, + { + "source": "¿Cuál es la capital de Austria? ¿Cuál es la población de Austria? Austria es un país en Europa. La capital de Austria es Viena. La población de Austria es de 9,1 millones.", + "answer": "La capital de Austria es Viena. La población de Austria es de 9,1 millones.", + "hallucination_list": [] + }, + { + "source": "¿Quién fue la primera persona en caminar sobre la luna? ¿Cuándo sucedió esto? Neil Armstrong fue la primera persona en caminar sobre la luna. Este histórico evento tuvo lugar el 20 de julio de 1969, durante la misión Apolo 11.", + "answer": "Neil Armstrong fue la primera persona en caminar sobre la luna. Este histórico evento tuvo lugar el 16 de julio de 1969, durante la misión Apolo 11.", + "hallucination_list": [ + "Este histórico evento tuvo lugar el 16 de julio de 1969" + ] + }, + { + "source": "¿Cuál es la montaña más alta del mundo? El Monte Everest es la montaña más alta del mundo, con una altura de 8.848 metros sobre el nivel del mar. Está ubicado en la subcordillera Mahalangur Himal del Himalaya, en la frontera entre Nepal y Tíbet.", + "answer": "El Monte Everest es la montaña más alta del mundo, con una altura de 8.848 metros sobre el nivel del mar. Está ubicado en la subcordillera Mahalangur Himal del Himalaya, en la frontera entre India y China.", + "hallucination_list": [ + "Está ubicado en la subcordillera Mahalangur Himal del Himalaya, en la frontera entre India y China." + ] + } +] \ No newline at end of file diff --git a/lettucedetect/prompts/examples_fr.json b/lettucedetect/prompts/examples_fr.json new file mode 100644 index 0000000..55f42ae --- /dev/null +++ b/lettucedetect/prompts/examples_fr.json @@ -0,0 +1,35 @@ +[ + { + "source": "Quelle est la capitale de la France ? Quelle est la population de la France ? La France est un pays d'Europe. La capitale de la France est Paris. La population de la France est de 67 millions.", + "answer": "La capitale de la France est Paris. La population de la France est de 69 millions.", + "hallucination_list": [ + "La population de la France est de 69 millions." + ] + }, + { + "source": "Quelle est la capitale de la France ? Quelle est la population de la France ? La capitale de la France est Paris. La population de la France est de 67 millions.", + "answer": "La capitale de la France est Paris. La population de la France est de 67 millions, et la langue officielle est l'espagnol.", + "hallucination_list": [ + "la langue officielle est l'espagnol" + ] + }, + { + "source": "Quelle est la capitale de l'Autriche ? Quelle est la population de l'Autriche ? L'Autriche est un pays d'Europe. La capitale de l'Autriche est Vienne. La population de l'Autriche est de 9,1 millions.", + "answer": "La capitale de l'Autriche est Vienne. La population de l'Autriche est de 9,1 millions.", + "hallucination_list": [] + }, + { + "source": "Qui a été le premier homme à marcher sur la lune ? Quand cela s'est-il produit ? Neil Armstrong a été le premier homme à marcher sur la lune. Cet événement historique a eu lieu le 20 juillet 1969, lors de la mission Apollo 11.", + "answer": "Neil Armstrong a été le premier homme à marcher sur la lune. Cet événement historique a eu lieu le 16 juillet 1969, lors de la mission Apollo 11.", + "hallucination_list": [ + "Cet événement historique a eu lieu le 16 juillet 1969" + ] + }, + { + "source": "Quelle est la plus haute montagne du monde ? Le mont Everest est la plus haute montagne du monde, avec une altitude de 8 848 mètres au-dessus du niveau de la mer. Il est situé dans la sous-chaîne de Mahalangur Himal de l'Himalaya, à la frontière entre le Népal et le Tibet.", + "answer": "Le mont Everest est la plus haute montagne du monde, avec une altitude de 8 848 mètres au-dessus du niveau de la mer. Il est situé dans la sous-chaîne de Mahalangur Himal de l'Himalaya, à la frontière entre l'Inde et la Chine.", + "hallucination_list": [ + "Il est situé dans la sous-chaîne de Mahalangur Himal de l'Himalaya, à la frontière entre l'Inde et la Chine." + ] + } +] \ No newline at end of file diff --git a/lettucedetect/prompts/examples_it.json b/lettucedetect/prompts/examples_it.json new file mode 100644 index 0000000..c9b3fef --- /dev/null +++ b/lettucedetect/prompts/examples_it.json @@ -0,0 +1,35 @@ +[ + { + "source": "Qual è la capitale della Francia? Qual è la popolazione della Francia? La Francia è un paese in Europa. La capitale della Francia è Parigi. La popolazione della Francia è di 67 milioni.", + "answer": "La capitale della Francia è Parigi. La popolazione della Francia è di 69 milioni.", + "hallucination_list": [ + "La popolazione della Francia è di 69 milioni." + ] + }, + { + "source": "Qual è la capitale della Francia? Qual è la popolazione della Francia? La capitale della Francia è Parigi. La popolazione della Francia è di 67 milioni.", + "answer": "La capitale della Francia è Parigi. La popolazione della Francia è di 67 milioni, e la lingua ufficiale è lo spagnolo.", + "hallucination_list": [ + "la lingua ufficiale è lo spagnolo" + ] + }, + { + "source": "Qual è la capitale dell'Austria? Qual è la popolazione dell'Austria? L'Austria è un paese in Europa. La capitale dell'Austria è Vienna. La popolazione dell'Austria è di 9,1 milioni.", + "answer": "La capitale dell'Austria è Vienna. La popolazione dell'Austria è di 9,1 milioni.", + "hallucination_list": [] + }, + { + "source": "Chi è stato il primo uomo a camminare sulla luna? Quando è successo? Neil Armstrong è stato il primo uomo a camminare sulla luna. Questo evento storico è avvenuto il 20 luglio 1969, durante la missione Apollo 11.", + "answer": "Neil Armstrong è stato il primo uomo a camminare sulla luna. Questo evento storico è avvenuto il 16 luglio 1969, durante la missione Apollo 11.", + "hallucination_list": [ + "Questo evento storico è avvenuto il 16 luglio 1969" + ] + }, + { + "source": "Qual è la montagna più alta del mondo? Il Monte Everest è la montagna più alta del mondo, con un'altitudine di 8.848 metri sul livello del mare. Si trova nella sottocatena Mahalangur Himal dell'Himalaya, al confine tra Nepal e Tibet.", + "answer": "Il Monte Everest è la montagna più alta del mondo, con un'altitudine di 8.848 metri sul livello del mare. Si trova nella sottocatena Mahalangur Himal dell'Himalaya, al confine tra India e Cina.", + "hallucination_list": [ + "Si trova nella sottocatena Mahalangur Himal dell'Himalaya, al confine tra India e Cina." + ] + } +] \ No newline at end of file diff --git a/lettucedetect/prompts/examples_pl.json b/lettucedetect/prompts/examples_pl.json new file mode 100644 index 0000000..b1f9dfa --- /dev/null +++ b/lettucedetect/prompts/examples_pl.json @@ -0,0 +1,35 @@ +[ + { + "source": "Jaka jest stolica Francji? Jaka jest populacja Francji? Francja jest krajem w Europie. Stolicą Francji jest Paryż. Populacja Francji wynosi 67 milionów.", + "answer": "Stolicą Francji jest Paryż. Populacja Francji wynosi 69 milionów.", + "hallucination_list": [ + "Populacja Francji wynosi 69 milionów." + ] + }, + { + "source": "Jaka jest stolica Francji? Jaka jest populacja Francji? Stolicą Francji jest Paryż. Populacja Francji wynosi 67 milionów.", + "answer": "Stolicą Francji jest Paryż. Populacja Francji wynosi 67 milionów, a językiem urzędowym jest hiszpański.", + "hallucination_list": [ + "językiem urzędowym jest hiszpański" + ] + }, + { + "source": "Jaka jest stolica Austrii? Jaka jest populacja Austrii? Austria jest krajem w Europie. Stolicą Austrii jest Wiedeń. Populacja Austrii wynosi 9,1 miliona.", + "answer": "Stolicą Austrii jest Wiedeń. Populacja Austrii wynosi 9,1 miliona.", + "hallucination_list": [] + }, + { + "source": "Kto był pierwszym człowiekiem, który chodził po księżycu? Kiedy to się stało? Neil Armstrong był pierwszym człowiekiem, który chodził po księżycu. To historyczne wydarzenie miało miejsce 20 lipca 1969 roku, podczas misji Apollo 11.", + "answer": "Neil Armstrong był pierwszym człowiekiem, który chodził po księżycu. To historyczne wydarzenie miało miejsce 16 lipca 1969 roku, podczas misji Apollo 11.", + "hallucination_list": [ + "To historyczne wydarzenie miało miejsce 16 lipca 1969 roku" + ] + }, + { + "source": "Jaka jest najwyższa góra na świecie? Mount Everest jest najwyższą górą na świecie, z wysokością 8848 metrów nad poziomem morza. Znajduje się w paśmie Mahalangur Himal w Himalajach, na granicy między Nepalem a Tybetem.", + "answer": "Mount Everest jest najwyższą górą na świecie, z wysokością 8848 metrów nad poziomem morza. Znajduje się w paśmie Mahalangur Himal w Himalajach, na granicy między Indiami a Chinami.", + "hallucination_list": [ + "Znajduje się w paśmie Mahalangur Himal w Himalajach, na granicy między Indiami a Chinami." + ] + } +] \ No newline at end of file diff --git a/lettucedetect/prompts/hallucination_detection.txt b/lettucedetect/prompts/hallucination_detection.txt index e9b2f6e..cb439b6 100644 --- a/lettucedetect/prompts/hallucination_detection.txt +++ b/lettucedetect/prompts/hallucination_detection.txt @@ -9,7 +9,9 @@ You are an expert annotator who must identify hallucinated substrings in a gener 1. **Read** the answer inside . 2. **Compare** each statement with the information in …. - *Hallucination* = a substring that **(a)** contradicts the source **or** **(b)** introduces facts not supported by the source. -3. **Decide** whether the answer contains any hallucinations. + - *Not hallucination* = a substring that is consistent with the source. + - *Boileplate substring* = a substring that is not a hallucination but is not relevant to the question (e.g. introductory phrases, etc.) +3. **Decide** whether the answer contains any hallucinations, be precise, in your answer only include substrings that are hallucinations. 4. **Return** a JSON object following *exactly* this schema (no extra keys, no markdown, no code‑block fences): diff --git a/lettucedetect/prompts/qa_prompt_de.txt b/lettucedetect/prompts/qa_prompt_de.txt new file mode 100644 index 0000000..e8be970 --- /dev/null +++ b/lettucedetect/prompts/qa_prompt_de.txt @@ -0,0 +1,6 @@ +Beantworte die folgende Frage kurz: +${question} +Beachte, dass deine Antwort nur auf den folgenden ${num_passages} Passagen basieren soll: +${context} +Falls die Passagen nicht die notwendigen Informationen zur Beantwortung der Frage enthalten, antworte bitte mit: "Kann aufgrund der gegebenen Passagen nicht beantwortet werden." +Ausgabe: \ No newline at end of file diff --git a/lettucedetect/prompts/qa_prompt_en.txt b/lettucedetect/prompts/qa_prompt_en.txt new file mode 100644 index 0000000..8ffd4a3 --- /dev/null +++ b/lettucedetect/prompts/qa_prompt_en.txt @@ -0,0 +1,6 @@ +Briefly answer the following question: +${question} +Bear in mind that your response should be strictly based on the following ${num_passages} passages: +${context} +In case the passages do not contain the necessary information to answer the question, please reply with: "Unable to answer based on given passages." +output: \ No newline at end of file diff --git a/lettucedetect/prompts/qa_prompt_es.txt b/lettucedetect/prompts/qa_prompt_es.txt new file mode 100644 index 0000000..595d2ed --- /dev/null +++ b/lettucedetect/prompts/qa_prompt_es.txt @@ -0,0 +1,6 @@ +Responde a la siguiente pregunta brevemente: +${question} +Ten en cuenta que tu respuesta debe basarse únicamente en los siguientes ${num_passages} pasajes: +${context} +Si los pasajes no contienen la información necesaria para responder a la pregunta, por favor responde con: "No se puede responder con base en los pasajes proporcionados." +Salida: \ No newline at end of file diff --git a/lettucedetect/prompts/qa_prompt_fr.txt b/lettucedetect/prompts/qa_prompt_fr.txt new file mode 100644 index 0000000..92c0b57 --- /dev/null +++ b/lettucedetect/prompts/qa_prompt_fr.txt @@ -0,0 +1,6 @@ +Réponds brièvement à la question suivante : +${question} +Note que ta réponse doit être basée uniquement sur les ${num_passages} passages suivants : +${context} +Si les passages ne contiennent pas les informations nécessaires pour répondre à la question, réponds s'il te plaît par : "Impossible de répondre sur la base des passages fournis." +Sortie : \ No newline at end of file diff --git a/lettucedetect/prompts/qa_prompt_it.txt b/lettucedetect/prompts/qa_prompt_it.txt new file mode 100644 index 0000000..11550b4 --- /dev/null +++ b/lettucedetect/prompts/qa_prompt_it.txt @@ -0,0 +1,6 @@ +Rispondi alla seguente domanda in modo breve: +${question} +Tieni presente che la tua risposta deve basarsi solo sui seguenti ${num_passages} passaggi: +${context} +Se i passaggi non contengono le informazioni necessarie per rispondere alla domanda, per favore rispondi con: "Non può essere risposto in base ai passaggi forniti." +Risposta: diff --git a/lettucedetect/prompts/qa_prompt_pl.txt b/lettucedetect/prompts/qa_prompt_pl.txt new file mode 100644 index 0000000..df7fe4c --- /dev/null +++ b/lettucedetect/prompts/qa_prompt_pl.txt @@ -0,0 +1,6 @@ +Odpowiedz krótko na następujące pytanie: +${question} +Pamiętaj, że Twoja odpowiedź powinna opierać się wyłącznie na następujących ${num_passages} fragmentach: +${context} +Jeśli fragmenty nie zawierają informacji niezbędnych do udzielenia odpowiedzi na pytanie, odpowiedz: "Nie można odpowiedzieć na podstawie podanych fragmentów." +Wynik: \ No newline at end of file diff --git a/lettucedetect/prompts/summary_prompt_de.txt b/lettucedetect/prompts/summary_prompt_de.txt new file mode 100644 index 0000000..a3ab347 --- /dev/null +++ b/lettucedetect/prompts/summary_prompt_de.txt @@ -0,0 +1,3 @@ +Fasse den folgenden Text zusammen: +${text} +Ausgabe: diff --git a/lettucedetect/prompts/summary_prompt_en.txt b/lettucedetect/prompts/summary_prompt_en.txt new file mode 100644 index 0000000..67831a9 --- /dev/null +++ b/lettucedetect/prompts/summary_prompt_en.txt @@ -0,0 +1,3 @@ +Summarize the following text: +${text} +output: \ No newline at end of file diff --git a/lettucedetect/prompts/summary_prompt_es.txt b/lettucedetect/prompts/summary_prompt_es.txt new file mode 100644 index 0000000..04ce79b --- /dev/null +++ b/lettucedetect/prompts/summary_prompt_es.txt @@ -0,0 +1,3 @@ +Resume el siguiente texto: +${text} +Salida: \ No newline at end of file diff --git a/lettucedetect/prompts/summary_prompt_fr.txt b/lettucedetect/prompts/summary_prompt_fr.txt new file mode 100644 index 0000000..ba91cbc --- /dev/null +++ b/lettucedetect/prompts/summary_prompt_fr.txt @@ -0,0 +1,3 @@ +Résume le texte suivant : +${text} +Sortie : \ No newline at end of file diff --git a/lettucedetect/prompts/summary_prompt_it.txt b/lettucedetect/prompts/summary_prompt_it.txt new file mode 100644 index 0000000..c6f9c86 --- /dev/null +++ b/lettucedetect/prompts/summary_prompt_it.txt @@ -0,0 +1,3 @@ +Riassumi il seguente testo: +${text} +Risposta: \ No newline at end of file diff --git a/lettucedetect/prompts/summary_prompt_pl.txt b/lettucedetect/prompts/summary_prompt_pl.txt new file mode 100644 index 0000000..78042b5 --- /dev/null +++ b/lettucedetect/prompts/summary_prompt_pl.txt @@ -0,0 +1,3 @@ +Streść następujący tekst: +${text} +Wynik: \ No newline at end of file diff --git a/scripts/evaluate.py b/scripts/evaluate.py index f48eb0b..29cc73b 100644 --- a/scripts/evaluate.py +++ b/scripts/evaluate.py @@ -59,40 +59,13 @@ def evaluate_task_samples( else: # char_level print("\n---- Character-Level Span Evaluation ----") - metrics = evaluate_detector_char_level(detector, samples)[0] + metrics = evaluate_detector_char_level(detector, samples) print(f" Precision: {metrics['precision']:.4f}") print(f" Recall: {metrics['recall']:.4f}") print(f" F1: {metrics['f1']:.4f}") return metrics -def evaluate_task_samples_llm( - samples, evaluation_type, detector, samples_llm, baseline_file_exists -): - print(f"\nEvaluating model on {len(samples)} samples") - - if evaluation_type == "example_level": - print("\n---- Example-Level Span Evaluation ----") - metrics, hallucination_data_llm = evaluate_detector_example_level( - detector, samples, samples_llm, baseline_file_exists - ) - print_metrics(metrics) - return metrics, hallucination_data_llm - elif evaluation_type == "char_level": - print("\n---- Character-Level Span Evaluation ----") - metrics, hallucination_data_llm = evaluate_detector_char_level( - detector, samples, samples_llm, baseline_file_exists - ) - print(f" Precision: {metrics['precision']:.4f}") - print(f" Recall: {metrics['recall']:.4f}") - print(f" F1: {metrics['f1']:.4f}") - return metrics, hallucination_data_llm - else: - raise ValueError( - "This evaluation type is not available for this method. Use either 'example_level' or 'char_level'." - ) - - def load_data(data_path): data_path = Path(data_path) hallucination_data = HallucinationData.from_json(json.loads(data_path.read_text())) @@ -109,36 +82,8 @@ def load_data(data_path): return test_samples, task_type_map -def save_baseline_data(data_path_llm, hallucination_data_llm): - """This function saves the LLM baseline data into a file.""" - data_path_llm = Path(data_path_llm) - (data_path_llm).write_text(json.dumps(hallucination_data_llm.to_json(), indent=4)) - - -def exists_baseline_data(data_path, data_path_llm): - """This function checks whether there is already an existing file containing LLM labels.""" - data_path = Path(data_path) - data_path_llm = Path(data_path_llm) - - if data_path_llm.exists() and data_path_llm.is_file(): - hallucination_data = HallucinationData.from_json(json.loads(data_path.read_text())) - hallucination_data_llm = HallucinationData.from_json(json.loads(data_path_llm.read_text())) - if len(hallucination_data.samples) == len(hallucination_data_llm.samples): - return True - else: - return False - else: - return False - - def main(): parser = argparse.ArgumentParser(description="Evaluate a hallucination detection model") - parser.add_argument( - "--method", - type=str, - required=True, - help="Detector method. Choose either 'transformer' or 'llm'.", - ) parser.add_argument("--model_path", type=str, required=True, help="Path to the saved model") parser.add_argument( "--data_path", @@ -159,55 +104,29 @@ def main(): help="Batch size for evaluation", ) - parser.add_argument( - "--data_path_llm", - type=int, - default=None, - help="Path to LLM baseline data (JSON Format)", - ) - args = parser.parse_args() test_samples, task_type_map = load_data(args.data_path) - baseline_file_exists = ( - False - if args.data_path_llm is None - else exists_baseline_data(args.data_path, args.data_path_llm) - ) - print(f"\nEvaluating model on test samples: {len(test_samples)}") # Setup model/detector based on evaluation type - if args.method == "transformer": - if args.evaluation_type in {"token_level", "example_level"}: - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - model = AutoModelForTokenClassification.from_pretrained( - args.model_path, trust_remote_code=True - ).to(device) - tokenizer = AutoTokenizer.from_pretrained(args.model_path) - detector = None - else: # char_level - model, tokenizer, device = None, None, None - detector = HallucinationDetector(method=args.method, model_path=args.model_path) - - # Evaluate each task type separately - for task_type, samples in task_type_map.items(): - print(f"\nTask type: {task_type}") - evaluate_task_samples( - samples, - args.evaluation_type, - model=model, - tokenizer=tokenizer, - detector=detector, - device=device, - batch_size=args.batch_size, - ) + if args.evaluation_type in {"token_level", "example_level"}: + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + model = AutoModelForTokenClassification.from_pretrained( + args.model_path, trust_remote_code=True + ).to(device) + tokenizer = AutoTokenizer.from_pretrained(args.model_path) + detector = None + else: # char_level + model, tokenizer, device = None, None, None + detector = HallucinationDetector(method="transformer", model_path=args.model_path) - # Evaluate the whole dataset - print("\nTask type: whole dataset") + # Evaluate each task type separately + for task_type, samples in task_type_map.items(): + print(f"\nTask type: {task_type}") evaluate_task_samples( - test_samples, + samples, args.evaluation_type, model=model, tokenizer=tokenizer, @@ -216,43 +135,17 @@ def main(): batch_size=args.batch_size, ) - elif args.method == "llm": - if baseline_file_exists: - test_samples_llm, task_type_map_llm = load_data(args.data_path_llm) - else: - test_samples_llm, task_type_map_llm = None, None - model, tokenizer, device = None, None, None - detector = HallucinationDetector(method=args.method) - samples, samples_llm = (test_samples, test_samples_llm) - - # Evaluate the whole dataset - print("\nTask type: whole dataset") - metrics, hallucination_data_llm = evaluate_task_samples_llm( - samples, - args.evaluation_type, - detector=detector, - samples_llm=samples_llm, - baseline_file_exists=baseline_file_exists, - ) - - if not baseline_file_exists: - save_baseline_data(args.data_path_llm, hallucination_data_llm) - - test_samples_llm, task_type_map_llm = load_data(args.data_path_llm) - - for task_type, samples in task_type_map.items(): - for task_type_llm, samples_llm in task_type_map_llm.items(): - print(f"\nTask type: {task_type}") - evaluate_task_samples_llm( - samples, - args.evaluation_type, - detector=detector, - samples_llm=samples_llm, - baseline_file_exists=True, - ) - - else: - raise ValueError("Unsupported method. Choose 'transformer' or 'llm'.") + # Evaluate the whole dataset + print("\nTask type: whole dataset") + evaluate_task_samples( + test_samples, + args.evaluation_type, + model=model, + tokenizer=tokenizer, + detector=detector, + device=device, + batch_size=args.batch_size, + ) if __name__ == "__main__": diff --git a/scripts/evaluate_llm.py b/scripts/evaluate_llm.py new file mode 100644 index 0000000..cba492f --- /dev/null +++ b/scripts/evaluate_llm.py @@ -0,0 +1,132 @@ +import argparse +import json +from pathlib import Path + +from lettucedetect.datasets.hallucination_dataset import ( + HallucinationData, + HallucinationSample, +) +from lettucedetect.models.evaluator import ( + evaluate_detector_char_level, + evaluate_detector_example_level, + print_metrics, +) +from lettucedetect.models.inference import HallucinationDetector + + +def evaluate_task_samples_llm( + samples: list[HallucinationSample], evaluation_type: str, detector: HallucinationDetector +): + """Evaluate the model on the samples. + + :param samples: list of samples to evaluate + :param evaluation_type: evaluation type (example_level or char_level) + :param detector: detector to use + :return: metrics and hallucination data + """ + print(f"\nEvaluating model on {len(samples)} samples") + + if evaluation_type == "example_level": + print("\n---- Example-Level Span Evaluation ----") + metrics = evaluate_detector_example_level(detector, samples) + print_metrics(metrics) + return metrics + elif evaluation_type == "char_level": + print("\n---- Character-Level Span Evaluation ----") + metrics = evaluate_detector_char_level(detector, samples) + print(f" Precision: {metrics['precision']:.4f}") + print(f" Recall: {metrics['recall']:.4f}") + print(f" F1: {metrics['f1']:.4f}") + return metrics + else: + raise ValueError( + "This evaluation type is not available for this method. Use either 'example_level' or 'char_level'." + ) + + +def load_data(data_path): + data_path = Path(data_path) + hallucination_data = HallucinationData.from_json(json.loads(data_path.read_text())) + + # Filter test samples from the data + test_samples = [sample for sample in hallucination_data.samples if sample.split == "test"] + + # group samples by task type + task_type_map = {} + for sample in test_samples: + if sample.task_type not in task_type_map: + task_type_map[sample.task_type] = [] + task_type_map[sample.task_type].append(sample) + return test_samples, task_type_map + + +def main(): + parser = argparse.ArgumentParser( + description="Evaluate a hallucination detection model based on LLM" + ) + parser.add_argument( + "--model", type=str, required=True, help="Model name to evaluate, e.g. 'gpt-4o'" + ) + parser.add_argument( + "--data_path", + type=str, + required=True, + help="Path to the evaluation data (JSON format)", + ) + parser.add_argument( + "--evaluation_type", + type=str, + default="example_level", + help="Evaluation type (example_level or char_level)", + ) + parser.add_argument( + "--lang", + type=str, + default="en", + help="Language of the evaluation data", + ) + parser.add_argument( + "--zero_shot", + action="store_true", + help="Whether to use zero-shot prompting", + ) + parser.add_argument( + "--cache_path", + type=str, + default=None, + help="Path to the cache file", + ) + + args = parser.parse_args() + + test_samples, task_type_map = load_data(args.data_path) + + print(f"\nEvaluating model on test samples: {len(test_samples)}") + + detector = HallucinationDetector( + method="llm", + lang=args.lang, + cache_file=args.cache_path, + model=args.model, + zero_shot=args.zero_shot, + ) + + # Evaluate the whole dataset + print("\nTask type: whole dataset") + evaluate_task_samples_llm( + test_samples, + args.evaluation_type, + detector=detector, + ) + + for task_type, samples in task_type_map.items(): + print(f"\nTask type: {task_type}") + evaluate_task_samples_llm( + samples, + args.evaluation_type, + detector=detector, + ) + + +if __name__ == "__main__": + main()