-
Notifications
You must be signed in to change notification settings - Fork 3
add some unit tests for utils #25
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,49 @@ | ||
| import random | ||
|
|
||
| import numpy as np | ||
|
|
||
|
|
||
def fake_embedding(content: str, length: int) -> list[float]:
    """Deterministically derive a unit vector of ``length`` floats from ``content``.

    Identical content always produces the identical vector, but semantically
    similar strings are NOT guaranteed to be nearby in vector space.
    """
    # Seed a dedicated RNG with the content so the same string always maps to
    # the same vector. This is not a CSPRNG, which is fine for test fixtures.
    generator = random.Random(content)

    # Draw `length` floats in [0, 1).
    raw = [generator.random() for _ in range(length)]

    # Scale to unit length. The result is deterministic but not uniformly
    # distributed over the unit sphere — acceptable for these fixtures.
    norm = sum(component ** 2 for component in raw) ** 0.5
    return [component / norm for component in raw]
|
|
||
|
|
||
def fake_embedding_with_target_cosine_distance(
    orig_embedding: list[float],
    target_distance: float,
    rng: "np.random.Generator | None" = None,
) -> list[float]:
    """Construct a unit vector at a chosen cosine distance from a reference embedding.

    Args:
        orig_embedding: Reference embedding; any nonzero vector (it is
            normalized internally).
        target_distance: Desired cosine distance in [0, 2]. The returned
            vector has cosine similarity ``1 - target_distance`` with the
            normalized reference.
        rng: Optional NumPy random generator for reproducible output. When
            omitted, a fresh unseeded generator is used, matching the
            previous non-deterministic behavior.

    Returns:
        A unit-length vector with the same dimensionality as ``orig_embedding``.
    """
    if rng is None:
        rng = np.random.default_rng()

    orig = np.asarray(orig_embedding, dtype=float)
    orig = orig / np.linalg.norm(orig)

    # Draw a random direction and project out its component along `orig` to
    # obtain a unit vector orthogonal to it. Retry in the (measure-zero)
    # event that the draw is numerically parallel to `orig`.
    while True:
        rand = rng.standard_normal(orig.shape)
        rand -= np.dot(rand, orig) * orig
        norm = np.linalg.norm(rand)
        if norm > 1e-12:
            rand /= norm
            break

    # Cosine distance = 1 - cosine similarity; clip into arccos' valid
    # domain to guard against floating-point drift at the extremes.
    target_cosine = float(np.clip(1.0 - target_distance, -1.0, 1.0))
    theta = np.arccos(target_cosine)

    # Rotate `orig` by theta within the plane spanned by `orig` and `rand`.
    new: list[float] = (np.cos(theta) * orig + np.sin(theta) * rand).tolist()
    return new
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,58 @@ | ||
| from contextlib import asynccontextmanager | ||
| from unittest.mock import MagicMock, patch | ||
|
|
||
| import numpy as np | ||
| import pytest | ||
|
|
||
| from tests.helpers.embedding_fixtures import fake_embedding | ||
| from tlm.utils.scoring.consistency_scoring_utils import EMBEDDING_MODELS, compute_consistency_scores | ||
| from tlm.types import SimilarityMeasure | ||
|
|
||
|
|
||
@pytest.mark.asyncio
async def test_compute_scores_qa_jaccard() -> None:
    """Pairwise Jaccard scores and per-reference averages match hand-computed values."""
    references = ["Hello, world!", "Hello", "Hello, world!"]
    comparisons = ["Hello, world", "Hello, universe", "Hello, universe!"]
    avg_scores, scores = await compute_consistency_scores(references, comparisons, SimilarityMeasure.JACCARD)
    # One pairwise score per (reference, comparison) pair, row-major.
    expected_pairwise = np.array([1, 1 / 3, 1 / 3, 0.5, 0.5, 0.5, 1, 1 / 3, 1 / 3])
    assert np.allclose(scores, expected_pairwise)
    assert np.allclose(avg_scores, np.array([5 / 9, 0.5, 5 / 9]))
|
|
||
|
|
||
@pytest.mark.parametrize("similarity_measure", [SimilarityMeasure.EMBEDDING_SMALL, SimilarityMeasure.EMBEDDING_LARGE])
@pytest.mark.asyncio
async def test_compute_scores_qa_embedding(similarity_measure: SimilarityMeasure) -> None:
    """Embedding-based consistency scoring embeds every answer and yields scores in [0, 1].

    The OpenAI client is replaced with a mock whose embeddings come from the
    deterministic fake_embedding fixture, so no network calls are made.
    """
    reference_answers = ["Hello, world!", "Hello", "Hello, world!"]
    comparison_answers = ["Hello, world!", "Hello, universe", "Hello, universe!"]

    # Create a mock OpenAI client with embeddings.create method
    mock_openai_client = MagicMock()
    # Records every (text, model) pair the code under test requests.
    embedding_calls: list[tuple[str, str]] = []

    # Parameter names mirror the real embeddings.create keyword API
    # (`input` intentionally shadows the builtin to match that signature).
    async def mock_embeddings_create(input: str, model: str, timeout: float) -> MagicMock:
        embedding_calls.append((input, model))
        response = MagicMock()
        # 3-dimensional fake embeddings keep the test fast and deterministic.
        response.data = [MagicMock(embedding=fake_embedding(input, 3))]
        return response

    mock_openai_client.embeddings.create = mock_embeddings_create

    # get_openai_client is used as an async context manager by the code under test.
    @asynccontextmanager
    async def mock_get_openai_client():
        yield mock_openai_client

    with patch(
        "tlm.utils.scoring.consistency_scoring_utils.get_openai_client",
        mock_get_openai_client,
    ):
        avg_scores, scores = await compute_consistency_scores(reference_answers, comparison_answers, similarity_measure)
    # Identical texts embed identically, so their pairwise score must be exactly 1
    # (indices 0 and 6 pair reference "Hello, world!" with comparison "Hello, world!").
    assert scores[0] == 1
    assert scores[6] == 1
    assert np.all(scores >= 0)
    assert np.all(scores <= 1)
    assert np.all(avg_scores >= 0)
    assert np.all(avg_scores <= 1)
    # Each answer should be embedded exactly once.
    assert len(embedding_calls) == len(reference_answers) + len(comparison_answers)
    for text, model in embedding_calls:
        assert text in reference_answers or text in comparison_answers
        # The embedding model must correspond to the requested similarity measure.
        assert model == EMBEDDING_MODELS[similarity_measure]
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,236 @@ | ||
| import numpy as np | ||
|
|
||
| from tlm.config.defaults import get_settings | ||
| from tlm.utils.explainability_utils import ( | ||
| HIGH_CONFIDENCE_MESSAGE, | ||
| FALLBACK_EXPLANATION_MESSAGE, | ||
| NO_SELF_REFLECTION_EXPLANATION_MESSAGE, | ||
| OBSERVED_CONSISTENCY_EXPLANATION_TEMPLATE, | ||
| _add_punctuation_if_necessary, | ||
| _get_lowest_scoring_reflection_explanation, | ||
| _get_observed_consistency_explanation, | ||
| get_explainability_message, | ||
| ) | ||
| from tlm.types import Completion, ExtractedResponseField | ||
|
|
||
# Module-wide settings snapshot; tests derive threshold-relative inputs from these defaults.
defaults = get_settings()
|
|
||
|
|
||
def test_get_explainability_message_no_confidence_score() -> None:
    """With no confidence score at all, no explanation is produced."""
    message = get_explainability_message(None, [], [], 0, np.array([]), 0, "test")
    assert message == ""
|
|
||
|
|
||
def test_get_explainability_message_low_confidence_no_self_reflection_or_consistency() -> None:
    """Low confidence with no supporting evidence falls back to the generic explanation."""
    result = get_explainability_message(
        defaults.EXPLAINABILITY_THRESHOLD - 0.1, [], [], np.nan, np.array([]), 0, "test"
    )
    assert result == FALLBACK_EXPLANATION_MESSAGE
|
|
||
|
|
||
def test_get_explainability_message_low_confidence_with_self_reflection_explanation() -> None:
    """A low self-reflection score surfaces that completion's own explanation."""
    explanation_text = "Self reflection explanation"
    # Score just below the self-reflection threshold so the explanation qualifies.
    completion = Completion(
        message="test",
        explanation=explanation_text,
        response_fields={ExtractedResponseField.MAPPED_SCORE: defaults.SELF_REFLECTION_EXPLAINABILITY_THRESHOLD - 0.1},
        original_response={},
        template=None,
    )
    message = get_explainability_message(
        defaults.EXPLAINABILITY_THRESHOLD - 0.1,
        [[completion]],
        [],
        np.nan,
        np.array([]),
        0,
        "test",
    )
    assert explanation_text in message
|
|
||
|
|
||
def test_get_explainability_message_low_confidence_with_self_reflection_no_explanation() -> None:
    """A low self-reflection score without any explanation yields the canned message."""
    completion = Completion(
        message="test",
        explanation=None,
        response_fields={ExtractedResponseField.MAPPED_SCORE: defaults.SELF_REFLECTION_EXPLAINABILITY_THRESHOLD - 0.1},
        original_response={},
        template=None,
    )
    message = get_explainability_message(
        defaults.EXPLAINABILITY_THRESHOLD - 0.1,
        [[completion]],
        [],
        np.nan,
        np.array([]),
        0,
        "test",
    )
    assert message == NO_SELF_REFLECTION_EXPLANATION_MESSAGE.strip()
|
|
||
|
|
||
def test_get_explainability_message_high_confidence_score() -> None:
    """A confidence score above the threshold gets the canned high-confidence message."""
    message = get_explainability_message(0.9, [], [], 0, np.array([]), 0, "test")
    assert message == HIGH_CONFIDENCE_MESSAGE
|
|
||
|
|
||
def test_get_explainability_message_nan_confidence_score() -> None:
    """A NaN confidence score is treated like high confidence.

    Renamed from ``test_get_explainaibility_message_nan_confidence_score`` to
    fix the typo and match the naming of the other tests in this module.
    """
    assert get_explainability_message(np.nan, [], [], 0, np.array([]), 0, "test") == HIGH_CONFIDENCE_MESSAGE
|
|
||
|
|
||
def test_get_explainability_message_low_consistency_score() -> None:
    """A low observed-consistency score surfaces the inconsistent answer in the message."""
    observed_consistency_answer = "incorrect answer"
    # Consistency score just below its threshold so the answer is surfaced.
    observed_consistency_completion = Completion(
        message="test",
        response_fields={
            ExtractedResponseField.MAPPED_SCORE: defaults.CONSISTENCY_EXPLAINABILITY_THRESHOLD - 0.1,
            ExtractedResponseField.ANSWER: observed_consistency_answer,
        },
        original_response={},
        template=None,
    )
    # Overall confidence is also below the explainability threshold, so an
    # explanation is produced at all; it must embed the formatted answer.
    assert OBSERVED_CONSISTENCY_EXPLANATION_TEMPLATE.format(
        observed_consistency_completion=observed_consistency_answer
    ) in get_explainability_message(
        defaults.EXPLAINABILITY_THRESHOLD - 0.1,
        [],
        [observed_consistency_completion],
        defaults.CONSISTENCY_EXPLAINABILITY_THRESHOLD - 0.1,
        np.array([defaults.CONSISTENCY_EXPLAINABILITY_THRESHOLD - 0.1]),
        0,
        "test",
    )
|
|
||
|
|
||
def test_get_explainability_message_self_reflection_and_consistency_explanations() -> None:
    """Both the self-reflection and observed-consistency explanations appear when both scores are low."""
    self_reflection_explanation = "Self reflection explanation"
    # Self-reflection score just below its threshold, so its explanation qualifies.
    self_reflection_completion = Completion(
        message="test",
        explanation=self_reflection_explanation,
        response_fields={ExtractedResponseField.MAPPED_SCORE: defaults.SELF_REFLECTION_EXPLAINABILITY_THRESHOLD - 0.1},
        original_response={},
        template=None,
    )
    observed_consistency_answer = "incorrect answer"
    # Consistency score just below its threshold, so its answer is surfaced too.
    observed_consistency_completion = Completion(
        message="test",
        response_fields={
            ExtractedResponseField.MAPPED_SCORE: defaults.CONSISTENCY_EXPLAINABILITY_THRESHOLD - 0.1,
            ExtractedResponseField.ANSWER: observed_consistency_answer,
        },
        original_response={},
        template=None,
    )
    res = get_explainability_message(
        defaults.EXPLAINABILITY_THRESHOLD - 0.1,
        [[self_reflection_completion]],
        [observed_consistency_completion],
        defaults.CONSISTENCY_EXPLAINABILITY_THRESHOLD - 0.1,
        np.array([defaults.CONSISTENCY_EXPLAINABILITY_THRESHOLD - 0.1]),
        0,
        "test",
    )

    # Both sources of evidence should be present in the combined message.
    assert self_reflection_explanation in res
    assert (
        OBSERVED_CONSISTENCY_EXPLANATION_TEMPLATE.format(observed_consistency_completion=observed_consistency_answer)
        in res
    )
|
|
||
|
|
||
def test_get_lowest_scoring_reflection_explanation() -> None:
    """The explanation attached to the lowest mapped score is the one returned."""
    higher_scoring = Completion(
        message="test",
        explanation="Self reflection explanation",
        response_fields={
            ExtractedResponseField.MAPPED_SCORE: defaults.SELF_REFLECTION_EXPLAINABILITY_THRESHOLD - 0.1
        },
        original_response={},
        template=None,
    )
    lower_scoring = Completion(
        message="test",
        explanation="Self reflection explanation 2",
        response_fields={
            ExtractedResponseField.MAPPED_SCORE: defaults.SELF_REFLECTION_EXPLAINABILITY_THRESHOLD - 0.2
        },
        original_response={},
        template=None,
    )
    result = _get_lowest_scoring_reflection_explanation([higher_scoring, lower_scoring])
    assert result == "Self reflection explanation 2"
|
|
||
|
|
||
def test_get_lowest_scoring_reflection_explanation_no_explanation() -> None:
    """None is returned when completions carry neither scores nor explanations."""
    bare_completion = Completion(
        message="test",
        explanation=None,
        response_fields={},
        original_response={},
        template=None,
    )
    assert _get_lowest_scoring_reflection_explanation([bare_completion]) is None
|
|
||
|
|
||
def test_add_punctuation_if_necessary() -> None:
    """Already-punctuated text gets only a separator; unpunctuated text gets '. '."""
    # Trailing punctuation (possibly followed by whitespace) needs no period.
    for already_punctuated in (
        "Hello, world!",
        "Hello, world!?",
        "Hello, world.",
        "Hello, world;\n",
        "Hello, world: ",
    ):
        assert _add_punctuation_if_necessary(already_punctuated) == " "
    # Bare text gets a period plus the separator.
    assert _add_punctuation_if_necessary("Hello, world") == ". "
|
|
||
|
|
||
def test_get_observed_consistency_explanation() -> None:
    """The surfaced answer comes from the lowest-scoring completion other than the best answer.

    NOTE(review): answer2 (score 0.2) is expected rather than the 0.1-scoring
    completion — presumably because that completion's answer equals
    ``best_answer`` and is excluded; confirm against the implementation.
    """
    answer1 = "Answer 1"
    answer2 = "Answer 2"
    best_answer = "correct answer"
    observed_consistency_completions = [
        Completion(
            message="test",
            explanation=None,
            response_fields={
                ExtractedResponseField.MAPPED_SCORE: 0.3,
                ExtractedResponseField.ANSWER: answer1,
            },
            original_response={},
            template=None,
        ),
        Completion(
            message="test",
            explanation=None,
            response_fields={
                ExtractedResponseField.MAPPED_SCORE: 0.2,
                ExtractedResponseField.ANSWER: answer2,
            },
            original_response={},
            template=None,
        ),
        # Lowest score, but its answer matches best_answer.
        Completion(
            message="test",
            explanation=None,
            response_fields={ExtractedResponseField.MAPPED_SCORE: 0.1, ExtractedResponseField.ANSWER: best_answer},
            original_response={},
            template=None,
        ),
    ]
    assert _get_observed_consistency_explanation(
        observed_consistency_completions, np.array([0.3, 0.2, 0.1]), best_answer
    ) == OBSERVED_CONSISTENCY_EXPLANATION_TEMPLATE.format(observed_consistency_completion=answer2)
|
|
||
|
|
||
def test_get_observed_consistency_explanation_no_explanation() -> None:
    """None is returned when there are no scored completions to explain."""
    empty_completion = Completion(
        message="test",
        explanation=None,
        response_fields={},
        original_response={},
        template=None,
    )
    result = _get_observed_consistency_explanation([empty_completion], np.array([]), "correct answer")
    assert result is None
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,42 @@ | ||
| import pytest | ||
|
|
||
| from tlm.utils.scoring.jaccard_utils import get_structured_output_keys, jaccard_similarity | ||
|
|
||
|
|
||
@pytest.mark.parametrize(
    "answer, comparison, expected",
    [
        ("Hello, world!", "Hello, world!", 1.0),
        ("Hello, world!", "Hello, universe!", 1 / 3),
        ("Hello, world!", "Hello, world!?", 1.0),
        (
            "The quick brown fox jumps over the lazy dog.",
            "The swift black cat hops around the sleepy frog.",
            1 / 8,
        ),
        # NOTE: comparison is deliberately case-sensitive (behavior ported from
        # SaaS) — capitalization can signal real differences such as proper
        # nouns, and this measure is a simple worst-case fallback anyway.
        ("Hello world", "Goodbye universe", 0.0),
    ],
)
def test_jaccard_similarity(answer: str, comparison: str, expected: float) -> None:
    """jaccard_similarity on plain text matches hand-computed overlap ratios."""
    assert jaccard_similarity(answer, comparison) == expected
|
|
||
|
|
||
@pytest.mark.parametrize(
    "answer, comparison, expected",
    [
        # Identical structures are fully similar.
        ("{'name': 'John', 'age': 30}", "{'name': 'John', 'age': 30}", 1.0),
        # Same keys, entirely different values.
        ("{'name': 'John', 'age': 30}", "{'name': 'Jane', 'age': 25}", 0.0),
        # Partial overlap with an extra key in one structure.
        ("{'name': 'John', 'age': 30}", "{'name': 'John', 'age': 30, 'city': 'New York'}", 0.4),
        # Nested lists/dicts are compared too.
        (
            "{'people': [{'name': 'John', 'age': 30}, {'name': 'Jane', 'age': 25}], 'city': 'New York'}",
            "{'people': [{'name': 'Mary', 'age': 30}, {'name': 'Jane', 'age': 27}], 'city': 'Los Angeles'}",
            0.2,
        ),
    ],
)
def test_jaccard_similarity_structured_outputs(answer: str, comparison: str, expected: float) -> None:
    """With structured_outputs=True, similarity is computed over the parsed structure."""
    assert jaccard_similarity(answer, comparison, structured_outputs=True) == expected
|
|
||
|
|
||
def test_get_structured_output_keys_error() -> None:
    """Malformed structured output yields an empty key set rather than raising."""
    assert get_structured_output_keys("-1}") == set()
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
can this functionality be refactored into a patch function, similar to the
`patch_acompletion` in `litellm_patches.py`, so that it can be easily reused across tests?