From 61b9161f96d77a8afbdc39494e0c2dbb42803eb9 Mon Sep 17 00:00:00 2001 From: kamilbenkirane Date: Fri, 27 Mar 2026 15:50:23 +0100 Subject: [PATCH 1/7] feat: add multimodal embeddings with gemini-embedding-2-preview (#214) Support text, image, and video embedding in a unified vector space. Adds domain routing, namespace methods, Vertex AI guard, and demo notebook. --- notebooks/multimodal-embeddings.ipynb | 226 ++++++++++++++++++ src/celeste/core.py | 2 + src/celeste/modalities/embeddings/client.py | 38 ++- src/celeste/modalities/embeddings/io.py | 20 +- .../embeddings/providers/google/client.py | 70 +++++- .../embeddings/providers/google/models.py | 9 + src/celeste/namespaces/domains.py | 124 +++++++++- .../providers/google/embeddings/client.py | 15 ++ .../integration_tests/embeddings/conftest.py | 26 ++ .../embeddings/test_embed_image.py | 109 +++++++++ .../embeddings/test_embed_video.py | 58 +++++ tests/unit_tests/test_embeddings_input.py | 71 ++++++ .../unit_tests/test_embeddings_multimodal.py | 113 +++++++++ 13 files changed, 856 insertions(+), 25 deletions(-) create mode 100644 notebooks/multimodal-embeddings.ipynb create mode 100644 tests/integration_tests/embeddings/conftest.py create mode 100644 tests/integration_tests/embeddings/test_embed_image.py create mode 100644 tests/integration_tests/embeddings/test_embed_video.py create mode 100644 tests/unit_tests/test_embeddings_input.py create mode 100644 tests/unit_tests/test_embeddings_multimodal.py diff --git a/notebooks/multimodal-embeddings.ipynb b/notebooks/multimodal-embeddings.ipynb new file mode 100644 index 00000000..98f9fb6e --- /dev/null +++ b/notebooks/multimodal-embeddings.ipynb @@ -0,0 +1,226 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Celeste AI - Multimodal Embeddings\n", + "\n", + "Embed **text**, **images**, and **video** into a unified vector space with `gemini-embedding-2-preview`.\n", + "\n", + "Star on GitHub 👉 
[withceleste/celeste-python](https://github.com/withceleste/celeste-python)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Setup" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2026-03-27T13:50:41.286419Z", + "start_time": "2026-03-27T13:50:40.643376Z" + } + }, + "outputs": [], + "source": [ + "import celeste\n", + "import numpy as np\n", + "from IPython.display import Image, display" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "---\n", + "\n", + "## Text Embedding\n", + "\n", + "Embed text using the `celeste.text` domain namespace." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2026-03-27T13:50:42.413433Z", + "start_time": "2026-03-27T13:50:41.845251Z" + } + }, + "outputs": [], + "source": [ + "text_result = await celeste.text.embed(\n", + " \"A happy golden retriever\", model=\"gemini-embedding-2-preview\"\n", + ")\n", + "\n", + "print(f\"Dimensions: {len(text_result.content)}\")\n", + "print(f\"First 5 values: {text_result.content[:5]}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "---\n", + "\n", + "## Image Embedding\n", + "\n", + "Generate a dog image, then embed it using the `celeste.images` domain namespace." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2026-03-27T13:52:51.673322Z", + "start_time": "2026-03-27T13:52:42.237801Z" + } + }, + "outputs": [], + "source": [ + "img_result = await celeste.images.generate(\n", + " \"A golden retriever dog\", model=\"gemini-2.5-flash-image\"\n", + ")\n", + "display(Image(data=img_result.content.data))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2026-03-27T13:53:11.915963Z", + "start_time": "2026-03-27T13:53:10.531407Z" + } + }, + "outputs": [], + "source": [ + "img_emb = await celeste.images.embed(\n", + " img_result.content, model=\"gemini-embedding-2-preview\"\n", + ")\n", + "print(f\"Dimensions: {len(img_emb.content)}\")\n", + "print(f\"First 5 values: {img_emb.content[:5]}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2026-03-27T13:57:05.136610Z", + "start_time": "2026-03-27T13:57:01.124944Z" + } + }, + "outputs": [], + "source": [ + "import httpx\n", + "from celeste.artifacts import VideoArtifact\n", + "from celeste.mime_types import VideoMimeType\n", + "\n", + "video_bytes = httpx.get(\"https://download.samplelib.com/mp4/sample-5s.mp4\").content\n", + "video = VideoArtifact(data=video_bytes, mime_type=VideoMimeType.MP4)\n", + "\n", + "vid_emb = await celeste.videos.embed(video, model=\"gemini-embedding-2-preview\")\n", + "print(f\"Dimensions: {len(vid_emb.content)}\")\n", + "print(f\"First 5 values: {vid_emb.content[:5]}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "---\n", + "\n", + "## Multimodal Search Test\n", + "\n", + "Query: dog image. Compare similarity to text \"dog\" vs text \"chair\"." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2026-03-27T14:21:14.584469Z", + "start_time": "2026-03-27T14:21:14.362443Z" + } + }, + "outputs": [], + "source": [ + "!uv pip install matplotlib seaborn" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2026-03-27T14:22:17.809811Z", + "start_time": "2026-03-27T14:22:16.624079Z" + } + }, + "outputs": [], + "source": [ + "import seaborn as sns\n", + "import pandas as pd\n", + "\n", + "\n", + "def cosine_sim(a, b):\n", + " a, b = np.array(a), np.array(b)\n", + " return np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))\n", + "\n", + "\n", + "texts = [\"dog\", \"labrador\", \"chair\"]\n", + "embeddings = {\n", + " t: await celeste.text.embed(t, model=\"gemini-embedding-2-preview\") for t in texts\n", + "}\n", + "scores = {t: cosine_sim(img_emb.content, emb.content) for t, emb in embeddings.items()}\n", + "\n", + "df = pd.DataFrame({\"text\": scores.keys(), \"similarity\": scores.values()}).sort_values(\n", + " \"similarity\", ascending=False\n", + ")\n", + "sns.barplot(data=df, x=\"similarity\", y=\"text\").set(\n", + " xlim=(0, 1), title=\"Image(dog) similarity to:\"\n", + ")\n", + "\n", + "assert scores[\"dog\"] > scores[\"chair\"]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "---\n", + "Star on GitHub 👉 [withceleste/celeste-python](https://github.com/withceleste/celeste-python)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "name": "python", + "version": "3.12.0" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/src/celeste/core.py b/src/celeste/core.py index e4126a38..a31e9365 100644 --- a/src/celeste/core.py +++ b/src/celeste/core.py 
@@ -149,10 +149,12 @@ class Domain(StrEnum): (Domain.IMAGES, Operation.GENERATE): Modality.IMAGES, (Domain.IMAGES, Operation.EDIT): Modality.IMAGES, (Domain.IMAGES, Operation.ANALYZE): Modality.TEXT, + (Domain.IMAGES, Operation.EMBED): Modality.EMBEDDINGS, (Domain.AUDIO, Operation.SPEAK): Modality.AUDIO, (Domain.AUDIO, Operation.ANALYZE): Modality.TEXT, (Domain.VIDEOS, Operation.GENERATE): Modality.VIDEOS, (Domain.VIDEOS, Operation.ANALYZE): Modality.TEXT, + (Domain.VIDEOS, Operation.EMBED): Modality.EMBEDDINGS, } diff --git a/src/celeste/modalities/embeddings/client.py b/src/celeste/modalities/embeddings/client.py index a98c1dab..6eea9ede 100644 --- a/src/celeste/modalities/embeddings/client.py +++ b/src/celeste/modalities/embeddings/client.py @@ -6,7 +6,7 @@ from celeste.client import ModalityClient from celeste.core import Modality -from celeste.types import EmbeddingsContent +from celeste.types import EmbeddingsContent, ImageContent, VideoContent from .io import ( EmbeddingsChunk, @@ -40,33 +40,42 @@ def _output_class(cls) -> type[EmbeddingsOutput]: async def embed( self, - text: str | list[str], + text: str | list[str] | None = None, *, + images: ImageContent | None = None, + videos: VideoContent | None = None, extra_body: dict[str, Any] | None = None, extra_headers: dict[str, str] | None = None, **parameters: Unpack[EmbeddingsParameters], ) -> EmbeddingsOutput: - """Generate embeddings from text. + """Generate embeddings from text, images, or video. Args: text: Text to embed. Single string or list of strings. + images: Image(s) to embed. Single ImageArtifact or list. + videos: Video(s) to embed. Single VideoArtifact or list. extra_body: Additional provider-specific fields to merge into request. extra_headers: Additional HTTP headers to include in the request. **parameters: Embedding parameters (e.g., dimensions). 
Returns: - EmbeddingsOutput with content as: - - list[float] if text was a string - - list[list[float]] if text was a list + EmbeddingsOutput with content as EmbeddingsContent: + - Single vector for single inputs (str, ImageArtifact, VideoArtifact) + - List of vectors for batch inputs (list[str], list[ImageArtifact], etc.) """ - inputs = EmbeddingsInput(text=text) + inputs = EmbeddingsInput(text=text, images=images, videos=videos) output = await self._predict( inputs, extra_body=extra_body, extra_headers=extra_headers, **parameters ) - # If single text input, unwrap from batch format to single embedding + # Unwrap single-item results from batch format + is_batch = ( + isinstance(text, list) + or isinstance(images, list) + or isinstance(videos, list) + ) if ( - isinstance(text, str) + not is_batch and isinstance(output.content, list) and output.content and isinstance(output.content[0], list) @@ -89,15 +98,22 @@ def __init__(self, client: EmbeddingsClient) -> None: def embed( self, - text: str | list[str], + text: str | list[str] | None = None, *, + images: ImageContent | None = None, + videos: VideoContent | None = None, extra_body: dict[str, Any] | None = None, extra_headers: dict[str, str] | None = None, **parameters: Unpack[EmbeddingsParameters], ) -> EmbeddingsOutput: """Blocking embeddings generation.""" return async_to_sync(self._client.embed)( - text, extra_body=extra_body, extra_headers=extra_headers, **parameters + text, + images=images, + videos=videos, + extra_body=extra_body, + extra_headers=extra_headers, + **parameters, ) diff --git a/src/celeste/modalities/embeddings/io.py b/src/celeste/modalities/embeddings/io.py index 7ce6c711..9be49db6 100644 --- a/src/celeste/modalities/embeddings/io.py +++ b/src/celeste/modalities/embeddings/io.py @@ -1,15 +1,29 @@ """IO types for embeddings modality.""" -from pydantic import Field +from pydantic import Field, model_validator from celeste.io import Chunk, FinishReason, Input, Output, Usage -from celeste.types 
import EmbeddingsContent +from celeste.types import EmbeddingsContent, ImageContent, VideoContent class EmbeddingsInput(Input): """Input for embeddings operations.""" - text: str | list[str] + text: str | list[str] | None = None + images: ImageContent | None = None + videos: VideoContent | None = None + + @model_validator(mode="after") + def _validate_inputs(self) -> "EmbeddingsInput": + if self.text is None and self.images is None and self.videos is None: + msg = "At least one of text, images, or videos must be provided" + raise ValueError(msg) + if isinstance(self.text, list) and ( + self.images is not None or self.videos is not None + ): + msg = "Batch text (list[str]) cannot be combined with images or videos" + raise ValueError(msg) + return self class EmbeddingsFinishReason(FinishReason): diff --git a/src/celeste/modalities/embeddings/providers/google/client.py b/src/celeste/modalities/embeddings/providers/google/client.py index 683f10be..8496aca3 100644 --- a/src/celeste/modalities/embeddings/providers/google/client.py +++ b/src/celeste/modalities/embeddings/providers/google/client.py @@ -1,12 +1,15 @@ """Google embeddings client.""" +import base64 from typing import Any +from celeste.artifacts import ImageArtifact, VideoArtifact from celeste.parameters import ParameterMapper from celeste.providers.google.embeddings.client import ( GoogleEmbeddingsClient as GoogleEmbeddingsMixin, ) from celeste.types import EmbeddingsContent +from celeste.utils import detect_mime_type from ...client import EmbeddingsClient from ...io import EmbeddingsInput @@ -21,23 +24,78 @@ def parameter_mappers(cls) -> list[ParameterMapper[EmbeddingsContent]]: """Return parameter mappers for Google embeddings.""" return GOOGLE_PARAMETER_MAPPERS + def _build_image_part(self, image: ImageArtifact) -> dict[str, Any]: + """Build a Gemini part from an ImageArtifact.""" + if image.url: + return {"file_data": {"file_uri": image.url}} + image_bytes = image.get_bytes() + b64 = 
base64.b64encode(image_bytes).decode("utf-8") + mime = image.mime_type or detect_mime_type(image_bytes) + mime_str = mime.value if mime else None + return {"inline_data": {"mime_type": mime_str, "data": b64}} + + def _build_video_part(self, video: VideoArtifact) -> dict[str, Any]: + """Build a Gemini part from a VideoArtifact.""" + if video.url: + return {"file_data": {"file_uri": video.url}} + video_bytes = video.get_bytes() + b64 = base64.b64encode(video_bytes).decode("utf-8") + mime = video.mime_type or detect_mime_type(video_bytes) + mime_str = mime.value if mime else None + return {"inline_data": {"mime_type": mime_str, "data": b64}} + def _init_request(self, inputs: EmbeddingsInput) -> dict[str, Any]: """Build Google embeddings request from inputs.""" - texts = inputs.text if isinstance(inputs.text, list) else [inputs.text] + # Batch images → separate embeddings via batchEmbedContents + if isinstance(inputs.images, list): + return { + "requests": [ + { + "model": f"models/{self.model.id}", + "content": {"parts": [self._build_image_part(img)]}, + } + for img in inputs.images + ] + } - if len(texts) == 1: - return {"content": {"parts": [{"text": texts[0]}]}} - else: + # Batch videos → separate embeddings via batchEmbedContents + if isinstance(inputs.videos, list): return { "requests": [ { "model": f"models/{self.model.id}", - "content": {"parts": [{"text": text}]}, + "content": {"parts": [self._build_video_part(vid)]}, } - for text in texts + for vid in inputs.videos ] } + # Single/combined multimodal → one aggregated embedding + if inputs.images is not None or inputs.videos is not None: + parts: list[dict[str, Any]] = [] + if inputs.text is not None: + parts.append({"text": inputs.text}) + if inputs.images is not None: + parts.append(self._build_image_part(inputs.images)) + if inputs.videos is not None: + parts.append(self._build_video_part(inputs.videos)) + return {"content": {"parts": parts}} + + # Text-only (existing behavior) + assert inputs.text is not 
None + texts = inputs.text if isinstance(inputs.text, list) else [inputs.text] + if len(texts) == 1: + return {"content": {"parts": [{"text": texts[0]}]}} + return { + "requests": [ + { + "model": f"models/{self.model.id}", + "content": {"parts": [{"text": text}]}, + } + for text in texts + ] + } + def _parse_content( self, response_data: dict[str, Any], diff --git a/src/celeste/modalities/embeddings/providers/google/models.py b/src/celeste/modalities/embeddings/providers/google/models.py index 5441eac1..cdbfa3d3 100644 --- a/src/celeste/modalities/embeddings/providers/google/models.py +++ b/src/celeste/modalities/embeddings/providers/google/models.py @@ -16,4 +16,13 @@ EmbeddingsParameter.DIMENSIONS: Choice(options=[768, 1536, 3072]), }, ), + Model( + id="gemini-embedding-2-preview", + provider=Provider.GOOGLE, + display_name="Gemini Embedding 2 Preview", + operations={Modality.EMBEDDINGS: {Operation.EMBED}}, + parameter_constraints={ + EmbeddingsParameter.DIMENSIONS: Choice(options=[768, 1536, 3072]), + }, + ), ] diff --git a/src/celeste/namespaces/domains.py b/src/celeste/namespaces/domains.py index d6f894e8..3dfcc2f0 100644 --- a/src/celeste/namespaces/domains.py +++ b/src/celeste/namespaces/domains.py @@ -136,8 +136,10 @@ def generate( def embed( self, - text: str | list[str], + text: str | list[str] | None = None, *, + images: ImageContent | None = None, + videos: VideoContent | None = None, model: str, provider: Provider | None = None, api_key: str | SecretStr | None = None, @@ -153,7 +155,7 @@ def embed( api_key=api_key, auth=auth, ) - return client.sync.embed(text, **params) + return client.sync.embed(text, images=images, videos=videos, **params) @property def stream(self) -> SyncStreamTextNamespace: @@ -214,18 +216,22 @@ async def generate( async def embed( self, - text: str | list[str], + text: str | list[str] | None = None, *, + images: ImageContent | None = None, + videos: VideoContent | None = None, model: str, provider: Provider | None = None, 
api_key: str | SecretStr | None = None, auth: Authentication | None = None, **parameters: Unpack[EmbeddingsParameters], ) -> EmbeddingsOutput: - """Generate embeddings from text. + """Generate embeddings from text, images, or video. Args: text: Text to embed. Single string or list of strings. + images: Image(s) to embed. Single ImageArtifact or list. + videos: Video(s) to embed. Single VideoArtifact or list. model: Model ID to use (required). provider: Optional provider override. api_key: Optional API key override. @@ -243,7 +249,7 @@ async def embed( api_key=api_key, auth=auth, ) - return await client.embed(text, **parameters) + return await client.embed(text, images=images, videos=videos, **parameters) @property def sync(self) -> SyncTextNamespace: @@ -467,6 +473,27 @@ def analyze( ) return client.sync.analyze(prompt, messages=messages, image=image, **params) + def embed( + self, + images: ImageContent, + *, + model: str, + provider: Provider | None = None, + api_key: str | SecretStr | None = None, + auth: Authentication | None = None, + **params: Unpack[EmbeddingsParameters], + ) -> EmbeddingsOutput: + """Blocking image embeddings generation.""" + client = create_client( + modality=Modality.EMBEDDINGS, + operation=Operation.EMBED, + model=model, + provider=provider, + api_key=api_key, + auth=auth, + ) + return client.sync.embed(images=images, **params) + @property def stream(self) -> SyncStreamImagesNamespace: """Access sync streaming image operations.""" @@ -586,6 +613,39 @@ async def analyze( prompt, messages=messages, image=image, **parameters ) + async def embed( + self, + images: ImageContent, + *, + model: str, + provider: Provider | None = None, + api_key: str | SecretStr | None = None, + auth: Authentication | None = None, + **parameters: Unpack[EmbeddingsParameters], + ) -> EmbeddingsOutput: + """Generate embeddings from images. + + Args: + images: Image or list of images to embed. + model: Model ID to use (required). 
+ provider: Optional provider override. + api_key: Optional API key override. + auth: Optional Authentication object (e.g., GoogleADC for Vertex AI). + **parameters: Additional model parameters. + + Returns: + EmbeddingsOutput with embedding vectors. + """ + client = create_client( + modality=Modality.EMBEDDINGS, + operation=Operation.EMBED, + model=model, + provider=provider, + api_key=api_key, + auth=auth, + ) + return await client.embed(images=images, **parameters) + @property def sync(self) -> SyncImagesNamespace: """Access synchronous image operations.""" @@ -940,6 +1000,27 @@ def analyze( ) return client.sync.analyze(prompt, messages=messages, video=video, **params) + def embed( + self, + videos: VideoContent, + *, + model: str, + provider: Provider | None = None, + api_key: str | SecretStr | None = None, + auth: Authentication | None = None, + **params: Unpack[EmbeddingsParameters], + ) -> EmbeddingsOutput: + """Blocking video embeddings generation.""" + client = create_client( + modality=Modality.EMBEDDINGS, + operation=Operation.EMBED, + model=model, + provider=provider, + api_key=api_key, + auth=auth, + ) + return client.sync.embed(videos=videos, **params) + @property def stream(self) -> SyncStreamVideosNamespace: """Access sync streaming video operations.""" @@ -1024,6 +1105,39 @@ async def analyze( prompt, messages=messages, video=video, **parameters ) + async def embed( + self, + videos: VideoContent, + *, + model: str, + provider: Provider | None = None, + api_key: str | SecretStr | None = None, + auth: Authentication | None = None, + **parameters: Unpack[EmbeddingsParameters], + ) -> EmbeddingsOutput: + """Generate embeddings from videos. + + Args: + videos: Video or list of videos to embed. + model: Model ID to use (required). + provider: Optional provider override. + api_key: Optional API key override. + auth: Optional Authentication object (e.g., GoogleADC for Vertex AI). + **parameters: Additional model parameters. 
+ + Returns: + EmbeddingsOutput with embedding vectors. + """ + client = create_client( + modality=Modality.EMBEDDINGS, + operation=Operation.EMBED, + model=model, + provider=provider, + api_key=api_key, + auth=auth, + ) + return await client.embed(videos=videos, **parameters) + @property def sync(self) -> SyncVideosNamespace: """Access synchronous video operations.""" diff --git a/src/celeste/providers/google/embeddings/client.py b/src/celeste/providers/google/embeddings/client.py index a3cb7a5e..39e20442 100644 --- a/src/celeste/providers/google/embeddings/client.py +++ b/src/celeste/providers/google/embeddings/client.py @@ -79,6 +79,21 @@ async def _make_request( """Make HTTP request to embeddings endpoint.""" # Vertex :predict expects {"instances": [{"content": "..."}]} format if isinstance(self.auth, GoogleADC): + # Check for multimodal parts (inline_data / file_data) + parts_to_check: list[dict[str, Any]] = [] + if "requests" in request_body: + for req in request_body["requests"]: + parts_to_check.extend(req["content"]["parts"]) + else: + parts_to_check = request_body["content"]["parts"] + + if any("inline_data" in p or "file_data" in p for p in parts_to_check): + msg = ( + "Multimodal embeddings (images/videos) are not yet supported " + "via Vertex AI (GoogleADC). Use a Gemini API key instead." 
+ ) + raise ValueError(msg) + if "requests" in request_body: texts = [ req["content"]["parts"][0]["text"] diff --git a/tests/integration_tests/embeddings/conftest.py b/tests/integration_tests/embeddings/conftest.py new file mode 100644 index 00000000..f8f76b01 --- /dev/null +++ b/tests/integration_tests/embeddings/conftest.py @@ -0,0 +1,26 @@ +"""Embeddings modality integration test fixtures.""" + +from pathlib import Path + +import pytest + +from celeste.artifacts import ImageArtifact, VideoArtifact +from celeste.mime_types import ImageMimeType, VideoMimeType + +ASSETS_DIR = Path(__file__).parent.parent / "text" / "assets" + + +@pytest.fixture +def square_image() -> ImageArtifact: + """Provide a square shape test image.""" + return ImageArtifact( + path=str(ASSETS_DIR / "square.png"), mime_type=ImageMimeType.PNG + ) + + +@pytest.fixture +def test_video() -> VideoArtifact: + """Provide a minimal test video (2s blue screen, 160x120).""" + return VideoArtifact( + path=str(ASSETS_DIR / "test_video.mp4"), mime_type=VideoMimeType.MP4 + ) diff --git a/tests/integration_tests/embeddings/test_embed_image.py b/tests/integration_tests/embeddings/test_embed_image.py new file mode 100644 index 00000000..865fc51b --- /dev/null +++ b/tests/integration_tests/embeddings/test_embed_image.py @@ -0,0 +1,109 @@ +"""Integration tests for embeddings embed operation - image inputs.""" + +import warnings + +# Suppress deprecation warnings from legacy capability packages +warnings.filterwarnings( + "ignore", + message=".*capability parameter is deprecated.*", + category=DeprecationWarning, +) + +import pytest # noqa: E402 + +from celeste import ( # noqa: E402 + Modality, + create_client, +) +from celeste.artifacts import ImageArtifact # noqa: E402 +from celeste.modalities.embeddings import ( # noqa: E402 + EmbeddingsOutput, + EmbeddingsUsage, +) +from celeste.providers.google.auth import GoogleADC # noqa: E402 + + +@pytest.mark.integration +@pytest.mark.asyncio +async def 
test_embed_image(square_image: ImageArtifact) -> None: + """Test embedding a single image.""" + client = create_client( + modality=Modality.EMBEDDINGS, + model="gemini-embedding-2-preview", + ) + + response = await client.embed(images=square_image) + + assert isinstance(response, EmbeddingsOutput) + assert response.content is not None + assert isinstance(response.content, list) + assert len(response.content) > 0 + assert isinstance(response.content[0], float) + assert isinstance(response.usage, EmbeddingsUsage) + + +@pytest.mark.integration +@pytest.mark.asyncio +async def test_embed_batch_images(square_image: ImageArtifact) -> None: + """Test embedding multiple images returns separate vectors.""" + client = create_client( + modality=Modality.EMBEDDINGS, + model="gemini-embedding-2-preview", + ) + + response = await client.embed(images=[square_image, square_image]) + + assert isinstance(response, EmbeddingsOutput) + assert isinstance(response.content, list) + assert len(response.content) == 2 + assert isinstance(response.content[0], list) + assert isinstance(response.content[0][0], float) + + +@pytest.mark.integration +@pytest.mark.asyncio +async def test_embed_text_and_image(square_image: ImageArtifact) -> None: + """Test aggregated text+image embedding produces a single vector.""" + client = create_client( + modality=Modality.EMBEDDINGS, + model="gemini-embedding-2-preview", + ) + + response = await client.embed(text="a square shape", images=square_image) + + assert isinstance(response, EmbeddingsOutput) + assert isinstance(response.content, list) + assert len(response.content) > 0 + assert isinstance(response.content[0], float) + + +@pytest.mark.integration +def test_sync_embed_image(square_image: ImageArtifact) -> None: + """Test sync wrapper works correctly. + + Single model smoke test - sync is just async_to_sync wrapper. 
+ """ + client = create_client( + modality=Modality.EMBEDDINGS, + model="gemini-embedding-2-preview", + ) + + response = client.sync.embed(images=square_image) + + assert isinstance(response, EmbeddingsOutput) + assert response.content is not None + + +@pytest.mark.integration +@pytest.mark.asyncio +async def test_vertex_embed_image_raises(square_image: ImageArtifact) -> None: + """Test that multimodal + Vertex AI raises a clear error.""" + client = create_client( + modality=Modality.EMBEDDINGS, + provider="google", + model="gemini-embedding-2-preview", + auth=GoogleADC(), + ) + + with pytest.raises(ValueError, match="not yet supported via Vertex AI"): + await client.embed(images=square_image) diff --git a/tests/integration_tests/embeddings/test_embed_video.py b/tests/integration_tests/embeddings/test_embed_video.py new file mode 100644 index 00000000..430caf26 --- /dev/null +++ b/tests/integration_tests/embeddings/test_embed_video.py @@ -0,0 +1,58 @@ +"""Integration tests for embeddings embed operation - video inputs.""" + +import warnings + +# Suppress deprecation warnings from legacy capability packages +warnings.filterwarnings( + "ignore", + message=".*capability parameter is deprecated.*", + category=DeprecationWarning, +) + +import pytest # noqa: E402 + +from celeste import ( # noqa: E402 + Modality, + create_client, +) +from celeste.artifacts import VideoArtifact # noqa: E402 +from celeste.modalities.embeddings import ( # noqa: E402 + EmbeddingsOutput, + EmbeddingsUsage, +) + + +@pytest.mark.integration +@pytest.mark.asyncio +async def test_embed_video(test_video: VideoArtifact) -> None: + """Test embedding a single video.""" + client = create_client( + modality=Modality.EMBEDDINGS, + model="gemini-embedding-2-preview", + ) + + response = await client.embed(videos=test_video) + + assert isinstance(response, EmbeddingsOutput) + assert response.content is not None + assert isinstance(response.content, list) + assert len(response.content) > 0 + assert 
isinstance(response.content[0], float) + assert isinstance(response.usage, EmbeddingsUsage) + + +@pytest.mark.integration +def test_sync_embed_video(test_video: VideoArtifact) -> None: + """Test sync wrapper works correctly. + + Single model smoke test - sync is just async_to_sync wrapper. + """ + client = create_client( + modality=Modality.EMBEDDINGS, + model="gemini-embedding-2-preview", + ) + + response = client.sync.embed(videos=test_video) + + assert isinstance(response, EmbeddingsOutput) + assert response.content is not None diff --git a/tests/unit_tests/test_embeddings_input.py b/tests/unit_tests/test_embeddings_input.py new file mode 100644 index 00000000..5addb8e4 --- /dev/null +++ b/tests/unit_tests/test_embeddings_input.py @@ -0,0 +1,71 @@ +"""Unit tests for EmbeddingsInput validation.""" + +import pytest + +from celeste.artifacts import ImageArtifact, VideoArtifact +from celeste.mime_types import ImageMimeType, VideoMimeType +from celeste.modalities.embeddings.io import EmbeddingsInput + +_TEST_PNG_BYTES = ( + b"\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x01" + b"\x00\x00\x00\x01\x08\x02\x00\x00\x00\x90wS\xde\x00" + b"\x00\x00\x0cIDATx\x9cc\xf8\x0f\x00\x00\x01\x01\x00" + b"\x05\x18\xd8N\x00\x00\x00\x00IEND\xaeB`\x82" +) + + +def test_text_only() -> None: + inp = EmbeddingsInput(text="hello") + assert inp.text == "hello" + assert inp.images is None + assert inp.videos is None + + +def test_batch_text() -> None: + inp = EmbeddingsInput(text=["a", "b"]) + assert isinstance(inp.text, list) + + +def test_image_only() -> None: + img = ImageArtifact(data=_TEST_PNG_BYTES, mime_type=ImageMimeType.PNG) + inp = EmbeddingsInput(images=img) + assert inp.images is not None + assert inp.text is None + + +def test_image_list() -> None: + img1 = ImageArtifact(data=_TEST_PNG_BYTES, mime_type=ImageMimeType.PNG) + img2 = ImageArtifact(data=_TEST_PNG_BYTES, mime_type=ImageMimeType.PNG) + inp = EmbeddingsInput(images=[img1, img2]) + assert isinstance(inp.images, list) + 
assert len(inp.images) == 2 + + +def test_video_only() -> None: + vid = VideoArtifact(data=b"\x00" * 10, mime_type=VideoMimeType.MP4) + inp = EmbeddingsInput(videos=vid) + assert inp.videos is not None + + +def test_text_and_image() -> None: + img = ImageArtifact(data=_TEST_PNG_BYTES, mime_type=ImageMimeType.PNG) + inp = EmbeddingsInput(text="a cat", images=img) + assert inp.text == "a cat" + assert inp.images is not None + + +def test_no_input_raises() -> None: + with pytest.raises(Exception, match="At least one"): + EmbeddingsInput() + + +def test_batch_text_with_image_raises() -> None: + img = ImageArtifact(data=_TEST_PNG_BYTES, mime_type=ImageMimeType.PNG) + with pytest.raises(Exception, match="Batch text"): + EmbeddingsInput(text=["a", "b"], images=img) + + +def test_batch_text_with_video_raises() -> None: + vid = VideoArtifact(data=b"\x00" * 10, mime_type=VideoMimeType.MP4) + with pytest.raises(Exception, match="Batch text"): + EmbeddingsInput(text=["a", "b"], videos=vid) diff --git a/tests/unit_tests/test_embeddings_multimodal.py b/tests/unit_tests/test_embeddings_multimodal.py new file mode 100644 index 00000000..3d6173bd --- /dev/null +++ b/tests/unit_tests/test_embeddings_multimodal.py @@ -0,0 +1,113 @@ +"""Unit tests for Google embeddings multimodal request building (no network).""" + +from pydantic import SecretStr + +from celeste import Model +from celeste.artifacts import ImageArtifact, VideoArtifact +from celeste.auth import AuthHeader +from celeste.core import Modality, Operation, Provider +from celeste.mime_types import ImageMimeType, VideoMimeType +from celeste.modalities.embeddings.io import EmbeddingsInput +from celeste.modalities.embeddings.providers.google.client import GoogleEmbeddingsClient + +_MODEL = Model( + id="gemini-embedding-2-preview", + provider=Provider.GOOGLE, + display_name="Gemini Embedding 2 Preview", + operations={Modality.EMBEDDINGS: {Operation.EMBED}}, +) + +_AUTH = AuthHeader(secret=SecretStr("test"), 
header="x-goog-api-key", prefix="") + +_TEST_PNG_BYTES = ( + b"\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x01" + b"\x00\x00\x00\x01\x08\x02\x00\x00\x00\x90wS\xde\x00" + b"\x00\x00\x0cIDATx\x9cc\xf8\x0f\x00\x00\x01\x01\x00" + b"\x05\x18\xd8N\x00\x00\x00\x00IEND\xaeB`\x82" +) + + +def _make_client() -> GoogleEmbeddingsClient: + return GoogleEmbeddingsClient( + model=_MODEL, + provider=Provider.GOOGLE, + auth=_AUTH, + ) + + +def test_text_only_single() -> None: + client = _make_client() + request = client._init_request(EmbeddingsInput(text="hello")) + + assert "content" in request + assert request["content"]["parts"] == [{"text": "hello"}] + assert "requests" not in request + + +def test_text_only_batch() -> None: + client = _make_client() + request = client._init_request(EmbeddingsInput(text=["a", "b"])) + + assert "requests" in request + assert len(request["requests"]) == 2 + assert request["requests"][0]["content"]["parts"] == [{"text": "a"}] + assert request["requests"][1]["content"]["parts"] == [{"text": "b"}] + assert request["requests"][0]["model"] == "models/gemini-embedding-2-preview" + + +def test_single_image() -> None: + client = _make_client() + img = ImageArtifact(data=_TEST_PNG_BYTES, mime_type=ImageMimeType.PNG) + request = client._init_request(EmbeddingsInput(images=img)) + + assert "content" in request + parts = request["content"]["parts"] + assert len(parts) == 1 + assert "inline_data" in parts[0] + assert parts[0]["inline_data"]["mime_type"] == "image/png" + + +def test_batch_images() -> None: + client = _make_client() + img1 = ImageArtifact(data=_TEST_PNG_BYTES, mime_type=ImageMimeType.PNG) + img2 = ImageArtifact(data=_TEST_PNG_BYTES, mime_type=ImageMimeType.PNG) + request = client._init_request(EmbeddingsInput(images=[img1, img2])) + + assert "requests" in request + assert len(request["requests"]) == 2 + assert "inline_data" in request["requests"][0]["content"]["parts"][0] + assert "inline_data" in 
request["requests"][1]["content"]["parts"][0] + + +def test_single_video() -> None: + client = _make_client() + vid = VideoArtifact(data=b"\x00" * 10, mime_type=VideoMimeType.MP4) + request = client._init_request(EmbeddingsInput(videos=vid)) + + assert "content" in request + parts = request["content"]["parts"] + assert len(parts) == 1 + assert "inline_data" in parts[0] + assert parts[0]["inline_data"]["mime_type"] == "video/mp4" + + +def test_text_and_image_combined() -> None: + client = _make_client() + img = ImageArtifact(data=_TEST_PNG_BYTES, mime_type=ImageMimeType.PNG) + request = client._init_request(EmbeddingsInput(text="a cat", images=img)) + + assert "content" in request + parts = request["content"]["parts"] + assert len(parts) == 2 + assert parts[0] == {"text": "a cat"} + assert "inline_data" in parts[1] + + +def test_image_with_url_uses_file_data() -> None: + client = _make_client() + img = ImageArtifact(url="https://example.com/image.png") + request = client._init_request(EmbeddingsInput(images=img)) + + parts = request["content"]["parts"] + assert "file_data" in parts[0] + assert parts[0]["file_data"]["file_uri"] == "https://example.com/image.png" From 9995c88665fae9cd4fa60df7f533d8c21b485895 Mon Sep 17 00:00:00 2001 From: kamilbenkirane Date: Fri, 27 Mar 2026 16:12:35 +0100 Subject: [PATCH 2/7] fix: add optional_input_types to embedding models and parametrize tests Register ImagesConstraint and VideosConstraint on gemini-embedding-2-preview so optional_input_types auto-populates. Parametrize image/video integration tests with list_models() + InputType filtering instead of hardcoded model IDs. 
--- .../modalities/embeddings/parameters.py | 2 + .../embeddings/providers/google/models.py | 4 +- .../embeddings/test_embed_image.py | 52 ++++++++++++++++--- .../embeddings/test_embed_video.py | 26 ++++++++-- 4 files changed, 73 insertions(+), 11 deletions(-) diff --git a/src/celeste/modalities/embeddings/parameters.py b/src/celeste/modalities/embeddings/parameters.py index db954905..87abd8c8 100644 --- a/src/celeste/modalities/embeddings/parameters.py +++ b/src/celeste/modalities/embeddings/parameters.py @@ -9,6 +9,8 @@ class EmbeddingsParameter(StrEnum): """Parameter names for embeddings.""" DIMENSIONS = "dimensions" + IMAGE = "image" + VIDEO = "video" class EmbeddingsParameters(Parameters): diff --git a/src/celeste/modalities/embeddings/providers/google/models.py b/src/celeste/modalities/embeddings/providers/google/models.py index cdbfa3d3..ae2adfbb 100644 --- a/src/celeste/modalities/embeddings/providers/google/models.py +++ b/src/celeste/modalities/embeddings/providers/google/models.py @@ -1,6 +1,6 @@ """Google models for embeddings modality.""" -from celeste.constraints import Choice +from celeste.constraints import Choice, ImagesConstraint, VideosConstraint from celeste.core import Modality, Operation, Provider from celeste.models import Model @@ -23,6 +23,8 @@ operations={Modality.EMBEDDINGS: {Operation.EMBED}}, parameter_constraints={ EmbeddingsParameter.DIMENSIONS: Choice(options=[768, 1536, 3072]), + EmbeddingsParameter.IMAGE: ImagesConstraint(), + EmbeddingsParameter.VIDEO: VideosConstraint(), }, ), ] diff --git a/tests/integration_tests/embeddings/test_embed_image.py b/tests/integration_tests/embeddings/test_embed_image.py index 865fc51b..28c385fe 100644 --- a/tests/integration_tests/embeddings/test_embed_image.py +++ b/tests/integration_tests/embeddings/test_embed_image.py @@ -13,9 +13,13 @@ from celeste import ( # noqa: E402 Modality, + Model, + Operation, create_client, + list_models, ) from celeste.artifacts import ImageArtifact # noqa: E402 +from 
celeste.core import InputType # noqa: E402 from celeste.modalities.embeddings import ( # noqa: E402 EmbeddingsOutput, EmbeddingsUsage, @@ -23,13 +27,22 @@ from celeste.providers.google.auth import GoogleADC # noqa: E402 +@pytest.mark.parametrize( + "model", + [ + m + for m in list_models(modality=Modality.EMBEDDINGS, operation=Operation.EMBED) + if InputType.IMAGE in m.optional_input_types + ], + ids=lambda m: f"{m.provider}-{m.id}", +) @pytest.mark.integration @pytest.mark.asyncio -async def test_embed_image(square_image: ImageArtifact) -> None: +async def test_embed_image(model: Model, square_image: ImageArtifact) -> None: """Test embedding a single image.""" client = create_client( modality=Modality.EMBEDDINGS, - model="gemini-embedding-2-preview", + model=model, ) response = await client.embed(images=square_image) @@ -42,13 +55,22 @@ async def test_embed_image(square_image: ImageArtifact) -> None: assert isinstance(response.usage, EmbeddingsUsage) +@pytest.mark.parametrize( + "model", + [ + m + for m in list_models(modality=Modality.EMBEDDINGS, operation=Operation.EMBED) + if InputType.IMAGE in m.optional_input_types + ], + ids=lambda m: f"{m.provider}-{m.id}", +) @pytest.mark.integration @pytest.mark.asyncio -async def test_embed_batch_images(square_image: ImageArtifact) -> None: +async def test_embed_batch_images(model: Model, square_image: ImageArtifact) -> None: """Test embedding multiple images returns separate vectors.""" client = create_client( modality=Modality.EMBEDDINGS, - model="gemini-embedding-2-preview", + model=model, ) response = await client.embed(images=[square_image, square_image]) @@ -60,13 +82,22 @@ async def test_embed_batch_images(square_image: ImageArtifact) -> None: assert isinstance(response.content[0][0], float) +@pytest.mark.parametrize( + "model", + [ + m + for m in list_models(modality=Modality.EMBEDDINGS, operation=Operation.EMBED) + if InputType.IMAGE in m.optional_input_types + ], + ids=lambda m: f"{m.provider}-{m.id}", +) 
@pytest.mark.integration @pytest.mark.asyncio -async def test_embed_text_and_image(square_image: ImageArtifact) -> None: +async def test_embed_text_and_image(model: Model, square_image: ImageArtifact) -> None: """Test aggregated text+image embedding produces a single vector.""" client = create_client( modality=Modality.EMBEDDINGS, - model="gemini-embedding-2-preview", + model=model, ) response = await client.embed(text="a square shape", images=square_image) @@ -83,9 +114,16 @@ def test_sync_embed_image(square_image: ImageArtifact) -> None: Single model smoke test - sync is just async_to_sync wrapper. """ + models = [ + m + for m in list_models(modality=Modality.EMBEDDINGS, operation=Operation.EMBED) + if InputType.IMAGE in m.optional_input_types + ] + model = models[0] + client = create_client( modality=Modality.EMBEDDINGS, - model="gemini-embedding-2-preview", + model=model, ) response = client.sync.embed(images=square_image) diff --git a/tests/integration_tests/embeddings/test_embed_video.py b/tests/integration_tests/embeddings/test_embed_video.py index 430caf26..9b80f766 100644 --- a/tests/integration_tests/embeddings/test_embed_video.py +++ b/tests/integration_tests/embeddings/test_embed_video.py @@ -13,22 +13,35 @@ from celeste import ( # noqa: E402 Modality, + Model, + Operation, create_client, + list_models, ) from celeste.artifacts import VideoArtifact # noqa: E402 +from celeste.core import InputType # noqa: E402 from celeste.modalities.embeddings import ( # noqa: E402 EmbeddingsOutput, EmbeddingsUsage, ) +@pytest.mark.parametrize( + "model", + [ + m + for m in list_models(modality=Modality.EMBEDDINGS, operation=Operation.EMBED) + if InputType.VIDEO in m.optional_input_types + ], + ids=lambda m: f"{m.provider}-{m.id}", +) @pytest.mark.integration @pytest.mark.asyncio -async def test_embed_video(test_video: VideoArtifact) -> None: +async def test_embed_video(model: Model, test_video: VideoArtifact) -> None: """Test embedding a single video.""" client = 
create_client( modality=Modality.EMBEDDINGS, - model="gemini-embedding-2-preview", + model=model, ) response = await client.embed(videos=test_video) @@ -47,9 +60,16 @@ def test_sync_embed_video(test_video: VideoArtifact) -> None: Single model smoke test - sync is just async_to_sync wrapper. """ + models = [ + m + for m in list_models(modality=Modality.EMBEDDINGS, operation=Operation.EMBED) + if InputType.VIDEO in m.optional_input_types + ] + model = models[0] + client = create_client( modality=Modality.EMBEDDINGS, - model="gemini-embedding-2-preview", + model=model, ) response = client.sync.embed(videos=test_video) From f8218fcbcfeeeee495072e3197bad8dc401cf6f6 Mon Sep 17 00:00:00 2001 From: kamilbenkirane Date: Fri, 27 Mar 2026 16:29:44 +0100 Subject: [PATCH 3/7] feat: add audio embedding support Wire AudioContent through EmbeddingsInput, Google provider, domain routing, and namespace methods. Add AudiosConstraint to gemini-embedding-2-preview. --- src/celeste/core.py | 1 + src/celeste/modalities/embeddings/client.py | 13 +++- src/celeste/modalities/embeddings/io.py | 19 +++-- .../modalities/embeddings/parameters.py | 1 + .../embeddings/providers/google/client.py | 32 +++++++- .../embeddings/providers/google/models.py | 8 +- src/celeste/namespaces/domains.py | 54 +++++++++++++ .../integration_tests/embeddings/conftest.py | 12 ++- .../embeddings/test_embed_audio.py | 78 +++++++++++++++++++ tests/unit_tests/test_embeddings_input.py | 17 +++- .../unit_tests/test_embeddings_multimodal.py | 16 +++- 11 files changed, 233 insertions(+), 18 deletions(-) create mode 100644 tests/integration_tests/embeddings/test_embed_audio.py diff --git a/src/celeste/core.py b/src/celeste/core.py index a31e9365..f7d99136 100644 --- a/src/celeste/core.py +++ b/src/celeste/core.py @@ -150,6 +150,7 @@ class Domain(StrEnum): (Domain.IMAGES, Operation.EDIT): Modality.IMAGES, (Domain.IMAGES, Operation.ANALYZE): Modality.TEXT, (Domain.IMAGES, Operation.EMBED): Modality.EMBEDDINGS, + 
(Domain.AUDIO, Operation.EMBED): Modality.EMBEDDINGS, (Domain.AUDIO, Operation.SPEAK): Modality.AUDIO, (Domain.AUDIO, Operation.ANALYZE): Modality.TEXT, (Domain.VIDEOS, Operation.GENERATE): Modality.VIDEOS, diff --git a/src/celeste/modalities/embeddings/client.py b/src/celeste/modalities/embeddings/client.py index 6eea9ede..ab2f2b99 100644 --- a/src/celeste/modalities/embeddings/client.py +++ b/src/celeste/modalities/embeddings/client.py @@ -6,7 +6,7 @@ from celeste.client import ModalityClient from celeste.core import Modality -from celeste.types import EmbeddingsContent, ImageContent, VideoContent +from celeste.types import AudioContent, EmbeddingsContent, ImageContent, VideoContent from .io import ( EmbeddingsChunk, @@ -44,26 +44,28 @@ async def embed( *, images: ImageContent | None = None, videos: VideoContent | None = None, + audio: AudioContent | None = None, extra_body: dict[str, Any] | None = None, extra_headers: dict[str, str] | None = None, **parameters: Unpack[EmbeddingsParameters], ) -> EmbeddingsOutput: - """Generate embeddings from text, images, or video. + """Generate embeddings from text, images, video, or audio. Args: text: Text to embed. Single string or list of strings. images: Image(s) to embed. Single ImageArtifact or list. videos: Video(s) to embed. Single VideoArtifact or list. + audio: Audio file(s) to embed. Single AudioArtifact or list. extra_body: Additional provider-specific fields to merge into request. extra_headers: Additional HTTP headers to include in the request. **parameters: Embedding parameters (e.g., dimensions). Returns: EmbeddingsOutput with content as EmbeddingsContent: - - Single vector for single inputs (str, ImageArtifact, VideoArtifact) + - Single vector for single inputs (str, ImageArtifact, etc.) - List of vectors for batch inputs (list[str], list[ImageArtifact], etc.) 
""" - inputs = EmbeddingsInput(text=text, images=images, videos=videos) + inputs = EmbeddingsInput(text=text, images=images, videos=videos, audio=audio) output = await self._predict( inputs, extra_body=extra_body, extra_headers=extra_headers, **parameters ) @@ -73,6 +75,7 @@ async def embed( isinstance(text, list) or isinstance(images, list) or isinstance(videos, list) + or isinstance(audio, list) ) if ( not is_batch @@ -102,6 +105,7 @@ def embed( *, images: ImageContent | None = None, videos: VideoContent | None = None, + audio: AudioContent | None = None, extra_body: dict[str, Any] | None = None, extra_headers: dict[str, str] | None = None, **parameters: Unpack[EmbeddingsParameters], @@ -111,6 +115,7 @@ def embed( text, images=images, videos=videos, + audio=audio, extra_body=extra_body, extra_headers=extra_headers, **parameters, diff --git a/src/celeste/modalities/embeddings/io.py b/src/celeste/modalities/embeddings/io.py index 9be49db6..c99bb1b4 100644 --- a/src/celeste/modalities/embeddings/io.py +++ b/src/celeste/modalities/embeddings/io.py @@ -3,7 +3,7 @@ from pydantic import Field, model_validator from celeste.io import Chunk, FinishReason, Input, Output, Usage -from celeste.types import EmbeddingsContent, ImageContent, VideoContent +from celeste.types import AudioContent, EmbeddingsContent, ImageContent, VideoContent class EmbeddingsInput(Input): @@ -12,16 +12,25 @@ class EmbeddingsInput(Input): text: str | list[str] | None = None images: ImageContent | None = None videos: VideoContent | None = None + audio: AudioContent | None = None @model_validator(mode="after") def _validate_inputs(self) -> "EmbeddingsInput": - if self.text is None and self.images is None and self.videos is None: - msg = "At least one of text, images, or videos must be provided" + if ( + self.text is None + and self.images is None + and self.videos is None + and self.audio is None + ): + msg = "At least one of text, images, videos, or audio must be provided" raise ValueError(msg) if 
isinstance(self.text, list) and ( - self.images is not None or self.videos is not None + self.images is not None or self.videos is not None or self.audio is not None ): - msg = "Batch text (list[str]) cannot be combined with images or videos" + msg = ( + "Batch text (list[str]) cannot be combined with images, videos," + " or audio" + ) raise ValueError(msg) return self diff --git a/src/celeste/modalities/embeddings/parameters.py b/src/celeste/modalities/embeddings/parameters.py index 87abd8c8..86da04ef 100644 --- a/src/celeste/modalities/embeddings/parameters.py +++ b/src/celeste/modalities/embeddings/parameters.py @@ -11,6 +11,7 @@ class EmbeddingsParameter(StrEnum): DIMENSIONS = "dimensions" IMAGE = "image" VIDEO = "video" + AUDIO = "audio" class EmbeddingsParameters(Parameters): diff --git a/src/celeste/modalities/embeddings/providers/google/client.py b/src/celeste/modalities/embeddings/providers/google/client.py index 8496aca3..ee5f3ace 100644 --- a/src/celeste/modalities/embeddings/providers/google/client.py +++ b/src/celeste/modalities/embeddings/providers/google/client.py @@ -3,7 +3,7 @@ import base64 from typing import Any -from celeste.artifacts import ImageArtifact, VideoArtifact +from celeste.artifacts import AudioArtifact, ImageArtifact, VideoArtifact from celeste.parameters import ParameterMapper from celeste.providers.google.embeddings.client import ( GoogleEmbeddingsClient as GoogleEmbeddingsMixin, @@ -44,6 +44,16 @@ def _build_video_part(self, video: VideoArtifact) -> dict[str, Any]: mime_str = mime.value if mime else None return {"inline_data": {"mime_type": mime_str, "data": b64}} + def _build_audio_part(self, audio: AudioArtifact) -> dict[str, Any]: + """Build a Gemini part from an AudioArtifact.""" + if audio.url: + return {"file_data": {"file_uri": audio.url}} + audio_bytes = audio.get_bytes() + b64 = base64.b64encode(audio_bytes).decode("utf-8") + mime = audio.mime_type or detect_mime_type(audio_bytes) + mime_str = mime.value if mime else None 
+ return {"inline_data": {"mime_type": mime_str, "data": b64}} + def _init_request(self, inputs: EmbeddingsInput) -> dict[str, Any]: """Build Google embeddings request from inputs.""" # Batch images → separate embeddings via batchEmbedContents @@ -70,8 +80,24 @@ def _init_request(self, inputs: EmbeddingsInput) -> dict[str, Any]: ] } + # Batch audio → separate embeddings via batchEmbedContents + if isinstance(inputs.audio, list): + return { + "requests": [ + { + "model": f"models/{self.model.id}", + "content": {"parts": [self._build_audio_part(aud)]}, + } + for aud in inputs.audio + ] + } + # Single/combined multimodal → one aggregated embedding - if inputs.images is not None or inputs.videos is not None: + if ( + inputs.images is not None + or inputs.videos is not None + or inputs.audio is not None + ): parts: list[dict[str, Any]] = [] if inputs.text is not None: parts.append({"text": inputs.text}) @@ -79,6 +105,8 @@ def _init_request(self, inputs: EmbeddingsInput) -> dict[str, Any]: parts.append(self._build_image_part(inputs.images)) if inputs.videos is not None: parts.append(self._build_video_part(inputs.videos)) + if inputs.audio is not None: + parts.append(self._build_audio_part(inputs.audio)) return {"content": {"parts": parts}} # Text-only (existing behavior) diff --git a/src/celeste/modalities/embeddings/providers/google/models.py b/src/celeste/modalities/embeddings/providers/google/models.py index ae2adfbb..952bf43a 100644 --- a/src/celeste/modalities/embeddings/providers/google/models.py +++ b/src/celeste/modalities/embeddings/providers/google/models.py @@ -1,6 +1,11 @@ """Google models for embeddings modality.""" -from celeste.constraints import Choice, ImagesConstraint, VideosConstraint +from celeste.constraints import ( + AudiosConstraint, + Choice, + ImagesConstraint, + VideosConstraint, +) from celeste.core import Modality, Operation, Provider from celeste.models import Model @@ -25,6 +30,7 @@ EmbeddingsParameter.DIMENSIONS: Choice(options=[768, 1536, 
3072]), EmbeddingsParameter.IMAGE: ImagesConstraint(), EmbeddingsParameter.VIDEO: VideosConstraint(), + EmbeddingsParameter.AUDIO: AudiosConstraint(), }, ), ] diff --git a/src/celeste/namespaces/domains.py b/src/celeste/namespaces/domains.py index 3dfcc2f0..dcb51e23 100644 --- a/src/celeste/namespaces/domains.py +++ b/src/celeste/namespaces/domains.py @@ -802,6 +802,27 @@ def analyze( ) return client.sync.analyze(prompt, messages=messages, audio=audio, **params) + def embed( + self, + audio: AudioContent, + *, + model: str, + provider: Provider | None = None, + api_key: str | SecretStr | None = None, + auth: Authentication | None = None, + **params: Unpack[EmbeddingsParameters], + ) -> EmbeddingsOutput: + """Blocking audio embeddings generation.""" + client = create_client( + modality=Modality.EMBEDDINGS, + operation=Operation.EMBED, + model=model, + provider=provider, + api_key=api_key, + auth=auth, + ) + return client.sync.embed(audio=audio, **params) + @property def stream(self) -> SyncStreamAudioNamespace: """Access sync streaming audio operations.""" @@ -886,6 +907,39 @@ async def analyze( prompt, messages=messages, audio=audio, **parameters ) + async def embed( + self, + audio: AudioContent, + *, + model: str, + provider: Provider | None = None, + api_key: str | SecretStr | None = None, + auth: Authentication | None = None, + **parameters: Unpack[EmbeddingsParameters], + ) -> EmbeddingsOutput: + """Generate embeddings from audio. + + Args: + audio: Audio or list of audio files to embed. + model: Model ID to use (required). + provider: Optional provider override. + api_key: Optional API key override. + auth: Optional Authentication object (e.g., GoogleADC for Vertex AI). + **parameters: Additional model parameters. + + Returns: + EmbeddingsOutput with embedding vectors. 
+ """ + client = create_client( + modality=Modality.EMBEDDINGS, + operation=Operation.EMBED, + model=model, + provider=provider, + api_key=api_key, + auth=auth, + ) + return await client.embed(audio=audio, **parameters) + @property def sync(self) -> SyncAudioNamespace: """Access synchronous audio operations.""" diff --git a/tests/integration_tests/embeddings/conftest.py b/tests/integration_tests/embeddings/conftest.py index f8f76b01..67378970 100644 --- a/tests/integration_tests/embeddings/conftest.py +++ b/tests/integration_tests/embeddings/conftest.py @@ -4,8 +4,8 @@ import pytest -from celeste.artifacts import ImageArtifact, VideoArtifact -from celeste.mime_types import ImageMimeType, VideoMimeType +from celeste.artifacts import AudioArtifact, ImageArtifact, VideoArtifact +from celeste.mime_types import AudioMimeType, ImageMimeType, VideoMimeType ASSETS_DIR = Path(__file__).parent.parent / "text" / "assets" @@ -24,3 +24,11 @@ def test_video() -> VideoArtifact: return VideoArtifact( path=str(ASSETS_DIR / "test_video.mp4"), mime_type=VideoMimeType.MP4 ) + + +@pytest.fixture +def test_audio() -> AudioArtifact: + """Provide a minimal test audio (2s 440Hz sine wave).""" + return AudioArtifact( + path=str(ASSETS_DIR / "test_audio.mp3"), mime_type=AudioMimeType.MP3 + ) diff --git a/tests/integration_tests/embeddings/test_embed_audio.py b/tests/integration_tests/embeddings/test_embed_audio.py new file mode 100644 index 00000000..54b85557 --- /dev/null +++ b/tests/integration_tests/embeddings/test_embed_audio.py @@ -0,0 +1,78 @@ +"""Integration tests for embeddings embed operation - audio inputs.""" + +import warnings + +# Suppress deprecation warnings from legacy capability packages +warnings.filterwarnings( + "ignore", + message=".*capability parameter is deprecated.*", + category=DeprecationWarning, +) + +import pytest # noqa: E402 + +from celeste import ( # noqa: E402 + Modality, + Model, + Operation, + create_client, + list_models, +) +from celeste.artifacts import 
AudioArtifact # noqa: E402 +from celeste.core import InputType # noqa: E402 +from celeste.modalities.embeddings import ( # noqa: E402 + EmbeddingsOutput, + EmbeddingsUsage, +) + + +@pytest.mark.parametrize( + "model", + [ + m + for m in list_models(modality=Modality.EMBEDDINGS, operation=Operation.EMBED) + if InputType.AUDIO in m.optional_input_types + ], + ids=lambda m: f"{m.provider}-{m.id}", +) +@pytest.mark.integration +@pytest.mark.asyncio +async def test_embed_audio(model: Model, test_audio: AudioArtifact) -> None: + """Test embedding a single audio file.""" + client = create_client( + modality=Modality.EMBEDDINGS, + model=model, + ) + + response = await client.embed(audio=test_audio) + + assert isinstance(response, EmbeddingsOutput) + assert response.content is not None + assert isinstance(response.content, list) + assert len(response.content) > 0 + assert isinstance(response.content[0], float) + assert isinstance(response.usage, EmbeddingsUsage) + + +@pytest.mark.integration +def test_sync_embed_audio(test_audio: AudioArtifact) -> None: + """Test sync wrapper works correctly. + + Single model smoke test - sync is just async_to_sync wrapper. 
+ """ + models = [ + m + for m in list_models(modality=Modality.EMBEDDINGS, operation=Operation.EMBED) + if InputType.AUDIO in m.optional_input_types + ] + model = models[0] + + client = create_client( + modality=Modality.EMBEDDINGS, + model=model, + ) + + response = client.sync.embed(audio=test_audio) + + assert isinstance(response, EmbeddingsOutput) + assert response.content is not None diff --git a/tests/unit_tests/test_embeddings_input.py b/tests/unit_tests/test_embeddings_input.py index 5addb8e4..5f3c9303 100644 --- a/tests/unit_tests/test_embeddings_input.py +++ b/tests/unit_tests/test_embeddings_input.py @@ -2,8 +2,8 @@ import pytest -from celeste.artifacts import ImageArtifact, VideoArtifact -from celeste.mime_types import ImageMimeType, VideoMimeType +from celeste.artifacts import AudioArtifact, ImageArtifact, VideoArtifact +from celeste.mime_types import AudioMimeType, ImageMimeType, VideoMimeType from celeste.modalities.embeddings.io import EmbeddingsInput _TEST_PNG_BYTES = ( @@ -69,3 +69,16 @@ def test_batch_text_with_video_raises() -> None: vid = VideoArtifact(data=b"\x00" * 10, mime_type=VideoMimeType.MP4) with pytest.raises(Exception, match="Batch text"): EmbeddingsInput(text=["a", "b"], videos=vid) + + +def test_audio_only() -> None: + aud = AudioArtifact(data=b"\x00" * 10, mime_type=AudioMimeType.MP3) + inp = EmbeddingsInput(audio=aud) + assert inp.audio is not None + assert inp.text is None + + +def test_batch_text_with_audio_raises() -> None: + aud = AudioArtifact(data=b"\x00" * 10, mime_type=AudioMimeType.MP3) + with pytest.raises(Exception, match="Batch text"): + EmbeddingsInput(text=["a", "b"], audio=aud) diff --git a/tests/unit_tests/test_embeddings_multimodal.py b/tests/unit_tests/test_embeddings_multimodal.py index 3d6173bd..9d9144bc 100644 --- a/tests/unit_tests/test_embeddings_multimodal.py +++ b/tests/unit_tests/test_embeddings_multimodal.py @@ -3,10 +3,10 @@ from pydantic import SecretStr from celeste import Model -from 
celeste.artifacts import ImageArtifact, VideoArtifact +from celeste.artifacts import AudioArtifact, ImageArtifact, VideoArtifact from celeste.auth import AuthHeader from celeste.core import Modality, Operation, Provider -from celeste.mime_types import ImageMimeType, VideoMimeType +from celeste.mime_types import AudioMimeType, ImageMimeType, VideoMimeType from celeste.modalities.embeddings.io import EmbeddingsInput from celeste.modalities.embeddings.providers.google.client import GoogleEmbeddingsClient @@ -111,3 +111,15 @@ def test_image_with_url_uses_file_data() -> None: parts = request["content"]["parts"] assert "file_data" in parts[0] assert parts[0]["file_data"]["file_uri"] == "https://example.com/image.png" + + +def test_single_audio() -> None: + client = _make_client() + aud = AudioArtifact(data=b"\x00" * 10, mime_type=AudioMimeType.MP3) + request = client._init_request(EmbeddingsInput(audio=aud)) + + assert "content" in request + parts = request["content"]["parts"] + assert len(parts) == 1 + assert "inline_data" in parts[0] + assert parts[0]["inline_data"]["mime_type"] == "audio/mpeg" From 33c966e6286f8988aed5938ad6e0fd6a94b06acb Mon Sep 17 00:00:00 2001 From: kamilbenkirane Date: Fri, 27 Mar 2026 16:37:56 +0100 Subject: [PATCH 4/7] fix: add audio param to text.embed() namespace and update notebook Pass audio through TextNamespace.embed() and SyncTextNamespace.embed(). Add audio embedding cell to multimodal-embeddings notebook. 
--- notebooks/multimodal-embeddings.ipynb | 25 ++++++++++++++++++------- src/celeste/namespaces/domains.py | 13 ++++++++++--- 2 files changed, 28 insertions(+), 10 deletions(-) diff --git a/notebooks/multimodal-embeddings.ipynb b/notebooks/multimodal-embeddings.ipynb index 98f9fb6e..dd17bb10 100644 --- a/notebooks/multimodal-embeddings.ipynb +++ b/notebooks/multimodal-embeddings.ipynb @@ -3,13 +3,7 @@ { "cell_type": "markdown", "metadata": {}, - "source": [ - "# Celeste AI - Multimodal Embeddings\n", - "\n", - "Embed **text**, **images**, and **video** into a unified vector space with `gemini-embedding-2-preview`.\n", - "\n", - "Star on GitHub 👉 [withceleste/celeste-python](https://github.com/withceleste/celeste-python)" - ] + "source": "# Celeste AI - Multimodal Embeddings\n\nEmbed **text**, **images**, **video**, and **audio** into a unified vector space with `gemini-embedding-2-preview`.\n\nStar on GitHub 👉 [withceleste/celeste-python](https://github.com/withceleste/celeste-python)" }, { "cell_type": "markdown", @@ -110,6 +104,11 @@ "print(f\"First 5 values: {img_emb.content[:5]}\")" ] }, + { + "cell_type": "markdown", + "source": "---\n\n## Video Embedding\n\nDownload a short sample video and embed it using the `celeste.videos` domain namespace.", + "metadata": {} + }, { "cell_type": "code", "execution_count": null, @@ -133,6 +132,18 @@ "print(f\"First 5 values: {vid_emb.content[:5]}\")" ] }, + { + "cell_type": "markdown", + "source": "---\n\n## Audio Embedding\n\nDownload a short sample audio and embed it using the `celeste.audio` domain namespace.", + "metadata": {} + }, + { + "cell_type": "code", + "source": "from celeste.artifacts import AudioArtifact\nfrom celeste.mime_types import AudioMimeType\n\naudio_bytes = httpx.get(\"https://download.samplelib.com/mp3/sample-3s.mp3\").content\naudio = AudioArtifact(data=audio_bytes, mime_type=AudioMimeType.MP3)\n\naud_emb = await celeste.audio.embed(audio, model=\"gemini-embedding-2-preview\")\nprint(f\"Dimensions: 
{len(aud_emb.content)}\")\nprint(f\"First 5 values: {aud_emb.content[:5]}\")", + "metadata": {}, + "execution_count": null, + "outputs": [] + }, { "cell_type": "markdown", "metadata": {}, diff --git a/src/celeste/namespaces/domains.py b/src/celeste/namespaces/domains.py index dcb51e23..ec03dfa2 100644 --- a/src/celeste/namespaces/domains.py +++ b/src/celeste/namespaces/domains.py @@ -140,6 +140,7 @@ def embed( *, images: ImageContent | None = None, videos: VideoContent | None = None, + audio: AudioContent | None = None, model: str, provider: Provider | None = None, api_key: str | SecretStr | None = None, @@ -155,7 +156,9 @@ def embed( api_key=api_key, auth=auth, ) - return client.sync.embed(text, images=images, videos=videos, **params) + return client.sync.embed( + text, images=images, videos=videos, audio=audio, **params + ) @property def stream(self) -> SyncStreamTextNamespace: @@ -220,18 +223,20 @@ async def embed( *, images: ImageContent | None = None, videos: VideoContent | None = None, + audio: AudioContent | None = None, model: str, provider: Provider | None = None, api_key: str | SecretStr | None = None, auth: Authentication | None = None, **parameters: Unpack[EmbeddingsParameters], ) -> EmbeddingsOutput: - """Generate embeddings from text, images, or video. + """Generate embeddings from text, images, video, or audio. Args: text: Text to embed. Single string or list of strings. images: Image(s) to embed. Single ImageArtifact or list. videos: Video(s) to embed. Single VideoArtifact or list. + audio: Audio file(s) to embed. Single AudioArtifact or list. model: Model ID to use (required). provider: Optional provider override. api_key: Optional API key override. 
@@ -249,7 +254,9 @@ async def embed( api_key=api_key, auth=auth, ) - return await client.embed(text, images=images, videos=videos, **parameters) + return await client.embed( + text, images=images, videos=videos, audio=audio, **parameters + ) @property def sync(self) -> SyncTextNamespace: From 6a2ac1c6408983bfa46832ea6947be97c4f67350 Mon Sep 17 00:00:00 2001 From: kamilbenkirane Date: Fri, 27 Mar 2026 16:47:34 +0100 Subject: [PATCH 5/7] fix: display video and audio in notebook for user validation --- notebooks/multimodal-embeddings.ipynb | 120 +++++++++++--------------- 1 file changed, 50 insertions(+), 70 deletions(-) diff --git a/notebooks/multimodal-embeddings.ipynb b/notebooks/multimodal-embeddings.ipynb index dd17bb10..6402552e 100644 --- a/notebooks/multimodal-embeddings.ipynb +++ b/notebooks/multimodal-embeddings.ipynb @@ -14,19 +14,19 @@ }, { "cell_type": "code", - "execution_count": null, "metadata": { "ExecuteTime": { - "end_time": "2026-03-27T13:50:41.286419Z", - "start_time": "2026-03-27T13:50:40.643376Z" + "end_time": "2026-03-27T15:38:28.216364Z", + "start_time": "2026-03-27T15:38:27.886367Z" } }, - "outputs": [], "source": [ "import celeste\n", "import numpy as np\n", "from IPython.display import Image, display" - ] + ], + "outputs": [], + "execution_count": 1 }, { "cell_type": "markdown", @@ -41,14 +41,12 @@ }, { "cell_type": "code", - "execution_count": null, "metadata": { "ExecuteTime": { - "end_time": "2026-03-27T13:50:42.413433Z", - "start_time": "2026-03-27T13:50:41.845251Z" + "end_time": "2026-03-27T15:38:28.346425Z", + "start_time": "2026-03-27T15:38:28.217400Z" } }, - "outputs": [], "source": [ "text_result = await celeste.text.embed(\n", " \"A happy golden retriever\", model=\"gemini-embedding-2-preview\"\n", @@ -56,7 +54,25 @@ "\n", "print(f\"Dimensions: {len(text_result.content)}\")\n", "print(f\"First 5 values: {text_result.content[:5]}\")" - ] + ], + "outputs": [ + { + "ename": "MissingCredentialsError", + "evalue": "Provider google has 
no credentials configured. Set the appropriate environment variable or pass api_key parameter.", + "output_type": "error", + "traceback": [ + "\u001b[31m---------------------------------------------------------------------------\u001b[39m", + "\u001b[31mMissingCredentialsError\u001b[39m Traceback (most recent call last)", + "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[2]\u001b[39m\u001b[32m, line 1\u001b[39m\n\u001b[32m----> \u001b[39m\u001b[32m1\u001b[39m text_result = \u001b[38;5;28;01mawait\u001b[39;00m celeste.text.embed(\n\u001b[32m 2\u001b[39m \u001b[33m\"\u001b[39m\u001b[33mA happy golden retriever\u001b[39m\u001b[33m\"\u001b[39m, model=\u001b[33m\"\u001b[39m\u001b[33mgemini-embedding-2-preview\u001b[39m\u001b[33m\"\u001b[39m\n\u001b[32m 3\u001b[39m )\n\u001b[32m 5\u001b[39m \u001b[38;5;28mprint\u001b[39m(\u001b[33mf\u001b[39m\u001b[33m\"\u001b[39m\u001b[33mDimensions: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mlen\u001b[39m(text_result.content)\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m\"\u001b[39m)\n\u001b[32m 6\u001b[39m \u001b[38;5;28mprint\u001b[39m(\u001b[33mf\u001b[39m\u001b[33m\"\u001b[39m\u001b[33mFirst 5 values: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mtext_result.content[:\u001b[32m5\u001b[39m]\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m\"\u001b[39m)\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/Desktop/Projects/withceleste/celeste-python/src/celeste/namespaces/domains.py:244\u001b[39m, in \u001b[36mTextNamespace.embed\u001b[39m\u001b[34m(self, text, images, videos, model, provider, api_key, auth, **parameters)\u001b[39m\n\u001b[32m 217\u001b[39m \u001b[38;5;28;01masync\u001b[39;00m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34membed\u001b[39m(\n\u001b[32m 218\u001b[39m \u001b[38;5;28mself\u001b[39m,\n\u001b[32m 219\u001b[39m text: \u001b[38;5;28mstr\u001b[39m | \u001b[38;5;28mlist\u001b[39m[\u001b[38;5;28mstr\u001b[39m] | \u001b[38;5;28;01mNone\u001b[39;00m = 
\u001b[38;5;28;01mNone\u001b[39;00m,\n\u001b[32m (...)\u001b[39m\u001b[32m 227\u001b[39m **parameters: Unpack[EmbeddingsParameters],\n\u001b[32m 228\u001b[39m ) -> EmbeddingsOutput:\n\u001b[32m 229\u001b[39m \u001b[38;5;250m \u001b[39m\u001b[33;03m\"\"\"Generate embeddings from text, images, or video.\u001b[39;00m\n\u001b[32m 230\u001b[39m \n\u001b[32m 231\u001b[39m \u001b[33;03m Args:\u001b[39;00m\n\u001b[32m (...)\u001b[39m\u001b[32m 242\u001b[39m \u001b[33;03m EmbeddingsOutput with embedding vectors.\u001b[39;00m\n\u001b[32m 243\u001b[39m \u001b[33;03m \"\"\"\u001b[39;00m\n\u001b[32m--> \u001b[39m\u001b[32m244\u001b[39m client = \u001b[43mcreate_client\u001b[49m\u001b[43m(\u001b[49m\n\u001b[32m 245\u001b[39m \u001b[43m \u001b[49m\u001b[43mmodality\u001b[49m\u001b[43m=\u001b[49m\u001b[43mModality\u001b[49m\u001b[43m.\u001b[49m\u001b[43mEMBEDDINGS\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 246\u001b[39m \u001b[43m \u001b[49m\u001b[43moperation\u001b[49m\u001b[43m=\u001b[49m\u001b[43mOperation\u001b[49m\u001b[43m.\u001b[49m\u001b[43mEMBED\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 247\u001b[39m \u001b[43m \u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m=\u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 248\u001b[39m \u001b[43m \u001b[49m\u001b[43mprovider\u001b[49m\u001b[43m=\u001b[49m\u001b[43mprovider\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 249\u001b[39m \u001b[43m \u001b[49m\u001b[43mapi_key\u001b[49m\u001b[43m=\u001b[49m\u001b[43mapi_key\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 250\u001b[39m \u001b[43m \u001b[49m\u001b[43mauth\u001b[49m\u001b[43m=\u001b[49m\u001b[43mauth\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 251\u001b[39m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 252\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;01mawait\u001b[39;00m client.embed(text, images=images, videos=videos, **parameters)\n", + "\u001b[36mFile 
\u001b[39m\u001b[32m~/Desktop/Projects/withceleste/celeste-python/src/celeste/__init__.py:252\u001b[39m, in \u001b[36mcreate_client\u001b[39m\u001b[34m(capability, modality, operation, provider, model, api_key, auth, protocol, base_url)\u001b[39m\n\u001b[32m 250\u001b[39m resolved_auth = NoAuth()\n\u001b[32m 251\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[32m--> \u001b[39m\u001b[32m252\u001b[39m resolved_auth = \u001b[43mcredentials\u001b[49m\u001b[43m.\u001b[49m\u001b[43mget_auth\u001b[49m\u001b[43m(\u001b[49m\n\u001b[32m 253\u001b[39m \u001b[43m \u001b[49m\u001b[43mresolved_model\u001b[49m\u001b[43m.\u001b[49m\u001b[43mprovider\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;66;43;03m# type: ignore[arg-type] # always Provider in this branch\u001b[39;49;00m\n\u001b[32m 254\u001b[39m \u001b[43m \u001b[49m\u001b[43moverride_auth\u001b[49m\u001b[43m=\u001b[49m\u001b[43mauth\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 255\u001b[39m \u001b[43m \u001b[49m\u001b[43moverride_key\u001b[49m\u001b[43m=\u001b[49m\u001b[43mapi_key\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 256\u001b[39m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 258\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m modality_client_class(\n\u001b[32m 259\u001b[39m modality=resolved_modality,\n\u001b[32m 260\u001b[39m model=resolved_model,\n\u001b[32m (...)\u001b[39m\u001b[32m 264\u001b[39m base_url=base_url,\n\u001b[32m 265\u001b[39m )\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/Desktop/Projects/withceleste/celeste-python/src/celeste/credentials.py:190\u001b[39m, in \u001b[36mCredentials.get_auth\u001b[39m\u001b[34m(self, provider, override_auth, override_key)\u001b[39m\n\u001b[32m 188\u001b[39m \u001b[38;5;66;03m# API key config tuple → AuthHeader\u001b[39;00m\n\u001b[32m 189\u001b[39m _secret_name, header, prefix = registered\n\u001b[32m--> \u001b[39m\u001b[32m190\u001b[39m api_key = 
\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43mget_credentials\u001b[49m\u001b[43m(\u001b[49m\u001b[43mprovider\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43moverride_key\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 191\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m AuthHeader(secret=api_key, header=header, prefix=prefix)\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/Desktop/Projects/withceleste/celeste-python/src/celeste/credentials.py:127\u001b[39m, in \u001b[36mCredentials.get_credentials\u001b[39m\u001b[34m(self, provider, override_key)\u001b[39m\n\u001b[32m 125\u001b[39m credential: SecretStr | \u001b[38;5;28;01mNone\u001b[39;00m = \u001b[38;5;28mgetattr\u001b[39m(\u001b[38;5;28mself\u001b[39m, field_name, \u001b[38;5;28;01mNone\u001b[39;00m)\n\u001b[32m 126\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m credential \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m credential.get_secret_value().strip():\n\u001b[32m--> \u001b[39m\u001b[32m127\u001b[39m \u001b[38;5;28;01mraise\u001b[39;00m MissingCredentialsError(provider=provider)\n\u001b[32m 129\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m credential\n", + "\u001b[31mMissingCredentialsError\u001b[39m: Provider google has no credentials configured. Set the appropriate environment variable or pass api_key parameter." 
+ ] + } + ], + "execution_count": 2 }, { "cell_type": "markdown", @@ -71,38 +87,28 @@ }, { "cell_type": "code", - "execution_count": null, - "metadata": { - "ExecuteTime": { - "end_time": "2026-03-27T13:52:51.673322Z", - "start_time": "2026-03-27T13:52:42.237801Z" - } - }, - "outputs": [], + "metadata": {}, "source": [ "img_result = await celeste.images.generate(\n", " \"A golden retriever dog\", model=\"gemini-2.5-flash-image\"\n", ")\n", "display(Image(data=img_result.content.data))" - ] + ], + "outputs": [], + "execution_count": null }, { "cell_type": "code", - "execution_count": null, - "metadata": { - "ExecuteTime": { - "end_time": "2026-03-27T13:53:11.915963Z", - "start_time": "2026-03-27T13:53:10.531407Z" - } - }, - "outputs": [], + "metadata": {}, "source": [ "img_emb = await celeste.images.embed(\n", " img_result.content, model=\"gemini-embedding-2-preview\"\n", ")\n", "print(f\"Dimensions: {len(img_emb.content)}\")\n", "print(f\"First 5 values: {img_emb.content[:5]}\")" - ] + ], + "outputs": [], + "execution_count": null }, { "cell_type": "markdown", @@ -111,26 +117,10 @@ }, { "cell_type": "code", - "execution_count": null, - "metadata": { - "ExecuteTime": { - "end_time": "2026-03-27T13:57:05.136610Z", - "start_time": "2026-03-27T13:57:01.124944Z" - } - }, + "metadata": {}, + "source": "import httpx\nfrom celeste.artifacts import VideoArtifact\nfrom celeste.mime_types import VideoMimeType\nfrom IPython.display import Video, Audio\n\nvideo_bytes = httpx.get(\"https://download.samplelib.com/mp4/sample-5s.mp4\").content\nvideo = VideoArtifact(data=video_bytes, mime_type=VideoMimeType.MP4)\ndisplay(Video(data=video_bytes, embed=True, mimetype=\"video/mp4\"))\n\nvid_emb = await celeste.videos.embed(video, model=\"gemini-embedding-2-preview\")\nprint(f\"Dimensions: {len(vid_emb.content)}\")\nprint(f\"First 5 values: {vid_emb.content[:5]}\")", "outputs": [], - "source": [ - "import httpx\n", - "from celeste.artifacts import VideoArtifact\n", - "from 
celeste.mime_types import VideoMimeType\n", - "\n", - "video_bytes = httpx.get(\"https://download.samplelib.com/mp4/sample-5s.mp4\").content\n", - "video = VideoArtifact(data=video_bytes, mime_type=VideoMimeType.MP4)\n", - "\n", - "vid_emb = await celeste.videos.embed(video, model=\"gemini-embedding-2-preview\")\n", - "print(f\"Dimensions: {len(vid_emb.content)}\")\n", - "print(f\"First 5 values: {vid_emb.content[:5]}\")" - ] + "execution_count": null }, { "cell_type": "markdown", @@ -139,10 +129,10 @@ }, { "cell_type": "code", - "source": "from celeste.artifacts import AudioArtifact\nfrom celeste.mime_types import AudioMimeType\n\naudio_bytes = httpx.get(\"https://download.samplelib.com/mp3/sample-3s.mp3\").content\naudio = AudioArtifact(data=audio_bytes, mime_type=AudioMimeType.MP3)\n\naud_emb = await celeste.audio.embed(audio, model=\"gemini-embedding-2-preview\")\nprint(f\"Dimensions: {len(aud_emb.content)}\")\nprint(f\"First 5 values: {aud_emb.content[:5]}\")", + "source": "from celeste.artifacts import AudioArtifact\nfrom celeste.mime_types import AudioMimeType\n\naudio_bytes = httpx.get(\"https://download.samplelib.com/mp3/sample-3s.mp3\").content\naudio = AudioArtifact(data=audio_bytes, mime_type=AudioMimeType.MP3)\ndisplay(Audio(data=audio_bytes, autoplay=False))\n\naud_emb = await celeste.audio.embed(audio, model=\"gemini-embedding-2-preview\")\nprint(f\"Dimensions: {len(aud_emb.content)}\")\nprint(f\"First 5 values: {aud_emb.content[:5]}\")", "metadata": {}, - "execution_count": null, - "outputs": [] + "outputs": [], + "execution_count": null }, { "cell_type": "markdown", @@ -157,28 +147,16 @@ }, { "cell_type": "code", - "execution_count": null, - "metadata": { - "ExecuteTime": { - "end_time": "2026-03-27T14:21:14.584469Z", - "start_time": "2026-03-27T14:21:14.362443Z" - } - }, - "outputs": [], + "metadata": {}, "source": [ "!uv pip install matplotlib seaborn" - ] + ], + "outputs": [], + "execution_count": null }, { "cell_type": "code", - 
"execution_count": null, - "metadata": { - "ExecuteTime": { - "end_time": "2026-03-27T14:22:17.809811Z", - "start_time": "2026-03-27T14:22:16.624079Z" - } - }, - "outputs": [], + "metadata": {}, "source": [ "import seaborn as sns\n", "import pandas as pd\n", @@ -203,7 +181,9 @@ ")\n", "\n", "assert scores[\"dog\"] > scores[\"chair\"]" - ] + ], + "outputs": [], + "execution_count": null }, { "cell_type": "markdown", @@ -215,10 +195,10 @@ }, { "cell_type": "code", - "execution_count": null, "metadata": {}, + "source": [], "outputs": [], - "source": [] + "execution_count": null } ], "metadata": { From 106ca474cec633d4564c9879a1333465a05e5f5c Mon Sep 17 00:00:00 2001 From: kamilbenkirane Date: Fri, 27 Mar 2026 17:58:29 +0100 Subject: [PATCH 6/7] refactor: extract shared build_media_part utility Closes #247 Replace duplicated _build_image/video/audio_part methods across text, images, embeddings, and generate_content Google clients with a single shared build_media_part function in celeste.providers.google.utils. 
--- .../embeddings/providers/google/client.py | 100 ++++-------------- .../images/providers/google/gemini.py | 27 +---- .../text/providers/google/client.py | 46 +------- .../google/generate_content/parameters.py | 26 +---- src/celeste/providers/google/utils.py | 18 ++++ 5 files changed, 47 insertions(+), 170 deletions(-) create mode 100644 src/celeste/providers/google/utils.py diff --git a/src/celeste/modalities/embeddings/providers/google/client.py b/src/celeste/modalities/embeddings/providers/google/client.py index ee5f3ace..2acc323c 100644 --- a/src/celeste/modalities/embeddings/providers/google/client.py +++ b/src/celeste/modalities/embeddings/providers/google/client.py @@ -1,15 +1,13 @@ """Google embeddings client.""" -import base64 from typing import Any -from celeste.artifacts import AudioArtifact, ImageArtifact, VideoArtifact from celeste.parameters import ParameterMapper from celeste.providers.google.embeddings.client import ( GoogleEmbeddingsClient as GoogleEmbeddingsMixin, ) +from celeste.providers.google.utils import build_media_part from celeste.types import EmbeddingsContent -from celeste.utils import detect_mime_type from ...client import EmbeddingsClient from ...io import EmbeddingsInput @@ -24,89 +22,33 @@ def parameter_mappers(cls) -> list[ParameterMapper[EmbeddingsContent]]: """Return parameter mappers for Google embeddings.""" return GOOGLE_PARAMETER_MAPPERS - def _build_image_part(self, image: ImageArtifact) -> dict[str, Any]: - """Build a Gemini part from an ImageArtifact.""" - if image.url: - return {"file_data": {"file_uri": image.url}} - image_bytes = image.get_bytes() - b64 = base64.b64encode(image_bytes).decode("utf-8") - mime = image.mime_type or detect_mime_type(image_bytes) - mime_str = mime.value if mime else None - return {"inline_data": {"mime_type": mime_str, "data": b64}} - - def _build_video_part(self, video: VideoArtifact) -> dict[str, Any]: - """Build a Gemini part from a VideoArtifact.""" - if video.url: - return {"file_data": 
{"file_uri": video.url}} - video_bytes = video.get_bytes() - b64 = base64.b64encode(video_bytes).decode("utf-8") - mime = video.mime_type or detect_mime_type(video_bytes) - mime_str = mime.value if mime else None - return {"inline_data": {"mime_type": mime_str, "data": b64}} - - def _build_audio_part(self, audio: AudioArtifact) -> dict[str, Any]: - """Build a Gemini part from an AudioArtifact.""" - if audio.url: - return {"file_data": {"file_uri": audio.url}} - audio_bytes = audio.get_bytes() - b64 = base64.b64encode(audio_bytes).decode("utf-8") - mime = audio.mime_type or detect_mime_type(audio_bytes) - mime_str = mime.value if mime else None - return {"inline_data": {"mime_type": mime_str, "data": b64}} - def _init_request(self, inputs: EmbeddingsInput) -> dict[str, Any]: """Build Google embeddings request from inputs.""" - # Batch images → separate embeddings via batchEmbedContents - if isinstance(inputs.images, list): - return { - "requests": [ - { - "model": f"models/{self.model.id}", - "content": {"parts": [self._build_image_part(img)]}, - } - for img in inputs.images - ] - } - - # Batch videos → separate embeddings via batchEmbedContents - if isinstance(inputs.videos, list): - return { - "requests": [ - { - "model": f"models/{self.model.id}", - "content": {"parts": [self._build_video_part(vid)]}, - } - for vid in inputs.videos - ] - } - - # Batch audio → separate embeddings via batchEmbedContents - if isinstance(inputs.audio, list): - return { - "requests": [ - { - "model": f"models/{self.model.id}", - "content": {"parts": [self._build_audio_part(aud)]}, - } - for aud in inputs.audio - ] - } + # Batch media → separate embeddings via batchEmbedContents + for field in (inputs.images, inputs.videos, inputs.audio): + if isinstance(field, list): + return { + "requests": [ + { + "model": f"models/{self.model.id}", + "content": {"parts": [build_media_part(item)]}, + } + for item in field + ] + } # Single/combined multimodal → one aggregated embedding - if ( - 
inputs.images is not None - or inputs.videos is not None - or inputs.audio is not None - ): + media = [ + f + for f in (inputs.images, inputs.videos, inputs.audio) + if f is not None and not isinstance(f, list) + ] + if media: parts: list[dict[str, Any]] = [] if inputs.text is not None: parts.append({"text": inputs.text}) - if inputs.images is not None: - parts.append(self._build_image_part(inputs.images)) - if inputs.videos is not None: - parts.append(self._build_video_part(inputs.videos)) - if inputs.audio is not None: - parts.append(self._build_audio_part(inputs.audio)) + for artifact in media: + parts.append(build_media_part(artifact)) return {"content": {"parts": parts}} # Text-only (existing behavior) diff --git a/src/celeste/modalities/images/providers/google/gemini.py b/src/celeste/modalities/images/providers/google/gemini.py index 472c87d7..9b7973d8 100644 --- a/src/celeste/modalities/images/providers/google/gemini.py +++ b/src/celeste/modalities/images/providers/google/gemini.py @@ -1,6 +1,5 @@ """Gemini client for Google images modality.""" -import base64 from typing import Any, Unpack from celeste.artifacts import ImageArtifact @@ -9,6 +8,7 @@ from celeste.parameters import ParameterMapper from celeste.providers.google.generate_content import config as google_config from celeste.providers.google.generate_content.client import GoogleGenerateContentClient +from celeste.providers.google.utils import build_media_part from celeste.types import ImageContent from ...client import ImagesClient @@ -17,29 +17,6 @@ from .parameters import GEMINI_PARAMETER_MAPPERS -def _build_image_part(image: ImageArtifact) -> dict[str, Any]: - """Build a Gemini image part from an ImageArtifact (snake_case, provider-style).""" - if image.url: - return {"file_data": {"file_uri": image.url}} - - if image.data is not None: - image_bytes = image.data - elif image.path: - with open(image.path, "rb") as f: - image_bytes = f.read() - else: - msg = "ImageArtifact must have url, data, or 
path" - raise ValueError(msg) - - base64_data = base64.b64encode(image_bytes).decode("utf-8") - return { - "inline_data": { - "mime_type": image.mime_type, - "data": base64_data, - } - } - - class GeminiImagesClient(GoogleGenerateContentClient, ImagesClient): """Google Gemini client for images modality (generate + edit).""" @@ -78,7 +55,7 @@ def _init_request(self, inputs: ImageInput) -> dict[str, Any]: # Edit uses an input image (generation omits it) if inputs.image is not None: - parts.append(_build_image_part(inputs.image)) + parts.append(build_media_part(inputs.image)) parts.append({"text": inputs.prompt}) diff --git a/src/celeste/modalities/text/providers/google/client.py b/src/celeste/modalities/text/providers/google/client.py index 555fe3dc..26ac77dd 100644 --- a/src/celeste/modalities/text/providers/google/client.py +++ b/src/celeste/modalities/text/providers/google/client.py @@ -1,19 +1,17 @@ """Google text client (modality).""" -import base64 from typing import Any, Unpack from uuid import uuid4 -from celeste.artifacts import AudioArtifact, ImageArtifact, VideoArtifact from celeste.parameters import ParameterMapper from celeste.providers.google.generate_content import config as google_config from celeste.providers.google.generate_content.client import GoogleGenerateContentClient from celeste.providers.google.generate_content.streaming import ( GoogleGenerateContentStream as _GoogleGenerateContentStream, ) +from celeste.providers.google.utils import build_media_part from celeste.tools import ToolCall, ToolResult from celeste.types import AudioContent, ImageContent, Message, TextContent, VideoContent -from celeste.utils import detect_mime_type from ...client import TextClient from ...io import ( @@ -166,58 +164,22 @@ def content_to_parts(content: Any) -> list[dict[str, Any]]: if inputs.image is not None: images = inputs.image if isinstance(inputs.image, list) else [inputs.image] for img in images: - parts.append(self._build_image_part(img)) + 
parts.append(build_media_part(img)) if inputs.video is not None: videos = inputs.video if isinstance(inputs.video, list) else [inputs.video] for vid in videos: - parts.append(self._build_video_part(vid)) + parts.append(build_media_part(vid)) if inputs.audio is not None: audios = inputs.audio if isinstance(inputs.audio, list) else [inputs.audio] for aud in audios: - parts.append(self._build_audio_part(aud)) + parts.append(build_media_part(aud)) parts.append({"text": inputs.prompt or ""}) return {"contents": [{"role": "user", "parts": parts}]} - def _build_image_part(self, image: ImageArtifact) -> dict[str, Any]: - """Build a Gemini part from an ImageArtifact.""" - if image.url: - return {"file_data": {"file_uri": image.url}} - - image_bytes = image.get_bytes() - b64 = base64.b64encode(image_bytes).decode("utf-8") - mime = image.mime_type or detect_mime_type(image_bytes) - mime_str = mime.value if mime else None - - return {"inline_data": {"mime_type": mime_str, "data": b64}} - - def _build_video_part(self, video: VideoArtifact) -> dict[str, Any]: - """Build a Gemini part from a VideoArtifact.""" - if video.url: - return {"file_data": {"file_uri": video.url}} - - video_bytes = video.get_bytes() - b64 = base64.b64encode(video_bytes).decode("utf-8") - mime = video.mime_type or detect_mime_type(video_bytes) - mime_str = mime.value if mime else None - - return {"inline_data": {"mime_type": mime_str, "data": b64}} - - def _build_audio_part(self, audio: AudioArtifact) -> dict[str, Any]: - """Build a Gemini part from an AudioArtifact.""" - if audio.url: - return {"file_data": {"file_uri": audio.url}} - - audio_bytes = audio.get_bytes() - b64 = base64.b64encode(audio_bytes).decode("utf-8") - mime = audio.mime_type or detect_mime_type(audio_bytes) - mime_str = mime.value if mime else None - - return {"inline_data": {"mime_type": mime_str, "data": b64}} - def _parse_content( self, response_data: dict[str, Any], diff --git 
a/src/celeste/providers/google/generate_content/parameters.py b/src/celeste/providers/google/generate_content/parameters.py index 0f44828a..0b9daba4 100644 --- a/src/celeste/providers/google/generate_content/parameters.py +++ b/src/celeste/providers/google/generate_content/parameters.py @@ -1,16 +1,15 @@ """Google GenerateContent API parameter mappers.""" -import base64 import json from typing import Any, get_args, get_origin from pydantic import BaseModel, TypeAdapter -from celeste.artifacts import ImageArtifact from celeste.exceptions import InvalidToolError from celeste.mime_types import ApplicationMimeType from celeste.models import Model from celeste.parameters import ParameterMapper +from celeste.providers.google.utils import build_media_part from celeste.tools import Tool from celeste.types import TextContent @@ -136,27 +135,6 @@ def map( class MediaContentMapper[Content](ParameterMapper[Content]): """Map reference_images to Google contents.parts field.""" - def _build_image_part(self, image: ImageArtifact) -> dict[str, Any]: - """Build a Gemini part from an ImageArtifact.""" - if image.url: - return {"file_data": {"file_uri": image.url}} - - if image.data: - b64 = ( - base64.b64encode(image.data).decode("utf-8") - if isinstance(image.data, bytes) - else image.data - ) - return {"inline_data": {"mime_type": str(image.mime_type), "data": b64}} - - if image.path: - with open(image.path, "rb") as f: - b64 = base64.b64encode(f.read()).decode("utf-8") - return {"inline_data": {"mime_type": str(image.mime_type), "data": b64}} - - msg = "ImageArtifact must have url, data, or path" - raise ValueError(msg) - def map( self, request: dict[str, Any], @@ -169,7 +147,7 @@ def map( return request # Convert list of ImageArtifact to list of image parts - image_parts = [self._build_image_part(img) for img in validated_value] + image_parts = [build_media_part(img) for img in validated_value] # Add image parts before text in contents[0].parts if "contents" in request and 
len(request["contents"]) > 0: diff --git a/src/celeste/providers/google/utils.py b/src/celeste/providers/google/utils.py new file mode 100644 index 00000000..97b7c472 --- /dev/null +++ b/src/celeste/providers/google/utils.py @@ -0,0 +1,18 @@ +"""Shared utilities for Google/Gemini API providers.""" + +import base64 +from typing import Any + +from celeste.artifacts import Artifact +from celeste.utils import detect_mime_type + + +def build_media_part(artifact: Artifact) -> dict[str, Any]: + """Convert any media artifact to a Gemini inline_data/file_data part.""" + if artifact.url: + return {"file_data": {"file_uri": artifact.url}} + media_bytes = artifact.get_bytes() + b64 = base64.b64encode(media_bytes).decode("utf-8") + mime = artifact.mime_type or detect_mime_type(media_bytes) + mime_str = mime.value if mime else None + return {"inline_data": {"mime_type": mime_str, "data": b64}} From fa20780f74524365b3c4174cdc700206a4cb2b3f Mon Sep 17 00:00:00 2001 From: kamilbenkirane Date: Fri, 27 Mar 2026 18:02:29 +0100 Subject: [PATCH 7/7] chore: clear notebook outputs --- notebooks/multimodal-embeddings.ipynb | 141 +++++++++++++++----------- 1 file changed, 81 insertions(+), 60 deletions(-) diff --git a/notebooks/multimodal-embeddings.ipynb b/notebooks/multimodal-embeddings.ipynb index 6402552e..54dfc00b 100644 --- a/notebooks/multimodal-embeddings.ipynb +++ b/notebooks/multimodal-embeddings.ipynb @@ -3,7 +3,13 @@ { "cell_type": "markdown", "metadata": {}, - "source": "# Celeste AI - Multimodal Embeddings\n\nEmbed **text**, **images**, **video**, and **audio** into a unified vector space with `gemini-embedding-2-preview`.\n\nStar on GitHub 👉 [withceleste/celeste-python](https://github.com/withceleste/celeste-python)" + "source": [ + "# Celeste AI - Multimodal Embeddings\n", + "\n", + "Embed **text**, **images**, **video**, and **audio** into a unified vector space with `gemini-embedding-2-preview`.\n", + "\n", + "Star on GitHub 👉 
[withceleste/celeste-python](https://github.com/withceleste/celeste-python)" + ] }, { "cell_type": "markdown", @@ -14,19 +20,16 @@ }, { "cell_type": "code", - "metadata": { - "ExecuteTime": { - "end_time": "2026-03-27T15:38:28.216364Z", - "start_time": "2026-03-27T15:38:27.886367Z" - } - }, + "execution_count": null, + "metadata": {}, + "outputs": [], "source": [ + "import os\n", + "\n", "import celeste\n", "import numpy as np\n", "from IPython.display import Image, display" - ], - "outputs": [], - "execution_count": 1 + ] }, { "cell_type": "markdown", @@ -41,38 +44,20 @@ }, { "cell_type": "code", - "metadata": { - "ExecuteTime": { - "end_time": "2026-03-27T15:38:28.346425Z", - "start_time": "2026-03-27T15:38:28.217400Z" - } - }, + "execution_count": null, + "metadata": {}, + "outputs": [], "source": [ + "from dotenv import load_dotenv\n", + "load_dotenv()\n", + "\n", "text_result = await celeste.text.embed(\n", - " \"A happy golden retriever\", model=\"gemini-embedding-2-preview\"\n", + " \"A happy golden retriever\", model=\"gemini-embedding-2-preview\", api_key=os.getenv(\"GOOGLE_API_KEY\")\n", ")\n", "\n", "print(f\"Dimensions: {len(text_result.content)}\")\n", "print(f\"First 5 values: {text_result.content[:5]}\")" - ], - "outputs": [ - { - "ename": "MissingCredentialsError", - "evalue": "Provider google has no credentials configured. 
Set the appropriate environment variable or pass api_key parameter.", - "output_type": "error", - "traceback": [ - "\u001b[31m---------------------------------------------------------------------------\u001b[39m", - "\u001b[31mMissingCredentialsError\u001b[39m Traceback (most recent call last)", - "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[2]\u001b[39m\u001b[32m, line 1\u001b[39m\n\u001b[32m----> \u001b[39m\u001b[32m1\u001b[39m text_result = \u001b[38;5;28;01mawait\u001b[39;00m celeste.text.embed(\n\u001b[32m 2\u001b[39m \u001b[33m\"\u001b[39m\u001b[33mA happy golden retriever\u001b[39m\u001b[33m\"\u001b[39m, model=\u001b[33m\"\u001b[39m\u001b[33mgemini-embedding-2-preview\u001b[39m\u001b[33m\"\u001b[39m\n\u001b[32m 3\u001b[39m )\n\u001b[32m 5\u001b[39m \u001b[38;5;28mprint\u001b[39m(\u001b[33mf\u001b[39m\u001b[33m\"\u001b[39m\u001b[33mDimensions: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mlen\u001b[39m(text_result.content)\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m\"\u001b[39m)\n\u001b[32m 6\u001b[39m \u001b[38;5;28mprint\u001b[39m(\u001b[33mf\u001b[39m\u001b[33m\"\u001b[39m\u001b[33mFirst 5 values: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mtext_result.content[:\u001b[32m5\u001b[39m]\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m\"\u001b[39m)\n", - "\u001b[36mFile \u001b[39m\u001b[32m~/Desktop/Projects/withceleste/celeste-python/src/celeste/namespaces/domains.py:244\u001b[39m, in \u001b[36mTextNamespace.embed\u001b[39m\u001b[34m(self, text, images, videos, model, provider, api_key, auth, **parameters)\u001b[39m\n\u001b[32m 217\u001b[39m \u001b[38;5;28;01masync\u001b[39;00m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34membed\u001b[39m(\n\u001b[32m 218\u001b[39m \u001b[38;5;28mself\u001b[39m,\n\u001b[32m 219\u001b[39m text: \u001b[38;5;28mstr\u001b[39m | \u001b[38;5;28mlist\u001b[39m[\u001b[38;5;28mstr\u001b[39m] | \u001b[38;5;28;01mNone\u001b[39;00m = \u001b[38;5;28;01mNone\u001b[39;00m,\n\u001b[32m 
(...)\u001b[39m\u001b[32m 227\u001b[39m **parameters: Unpack[EmbeddingsParameters],\n\u001b[32m 228\u001b[39m ) -> EmbeddingsOutput:\n\u001b[32m 229\u001b[39m \u001b[38;5;250m \u001b[39m\u001b[33;03m\"\"\"Generate embeddings from text, images, or video.\u001b[39;00m\n\u001b[32m 230\u001b[39m \n\u001b[32m 231\u001b[39m \u001b[33;03m Args:\u001b[39;00m\n\u001b[32m (...)\u001b[39m\u001b[32m 242\u001b[39m \u001b[33;03m EmbeddingsOutput with embedding vectors.\u001b[39;00m\n\u001b[32m 243\u001b[39m \u001b[33;03m \"\"\"\u001b[39;00m\n\u001b[32m--> \u001b[39m\u001b[32m244\u001b[39m client = \u001b[43mcreate_client\u001b[49m\u001b[43m(\u001b[49m\n\u001b[32m 245\u001b[39m \u001b[43m \u001b[49m\u001b[43mmodality\u001b[49m\u001b[43m=\u001b[49m\u001b[43mModality\u001b[49m\u001b[43m.\u001b[49m\u001b[43mEMBEDDINGS\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 246\u001b[39m \u001b[43m \u001b[49m\u001b[43moperation\u001b[49m\u001b[43m=\u001b[49m\u001b[43mOperation\u001b[49m\u001b[43m.\u001b[49m\u001b[43mEMBED\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 247\u001b[39m \u001b[43m \u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m=\u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 248\u001b[39m \u001b[43m \u001b[49m\u001b[43mprovider\u001b[49m\u001b[43m=\u001b[49m\u001b[43mprovider\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 249\u001b[39m \u001b[43m \u001b[49m\u001b[43mapi_key\u001b[49m\u001b[43m=\u001b[49m\u001b[43mapi_key\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 250\u001b[39m \u001b[43m \u001b[49m\u001b[43mauth\u001b[49m\u001b[43m=\u001b[49m\u001b[43mauth\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 251\u001b[39m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 252\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;01mawait\u001b[39;00m client.embed(text, images=images, videos=videos, **parameters)\n", - "\u001b[36mFile \u001b[39m\u001b[32m~/Desktop/Projects/withceleste/celeste-python/src/celeste/__init__.py:252\u001b[39m, in 
\u001b[36mcreate_client\u001b[39m\u001b[34m(capability, modality, operation, provider, model, api_key, auth, protocol, base_url)\u001b[39m\n\u001b[32m 250\u001b[39m resolved_auth = NoAuth()\n\u001b[32m 251\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[32m--> \u001b[39m\u001b[32m252\u001b[39m resolved_auth = \u001b[43mcredentials\u001b[49m\u001b[43m.\u001b[49m\u001b[43mget_auth\u001b[49m\u001b[43m(\u001b[49m\n\u001b[32m 253\u001b[39m \u001b[43m \u001b[49m\u001b[43mresolved_model\u001b[49m\u001b[43m.\u001b[49m\u001b[43mprovider\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;66;43;03m# type: ignore[arg-type] # always Provider in this branch\u001b[39;49;00m\n\u001b[32m 254\u001b[39m \u001b[43m \u001b[49m\u001b[43moverride_auth\u001b[49m\u001b[43m=\u001b[49m\u001b[43mauth\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 255\u001b[39m \u001b[43m \u001b[49m\u001b[43moverride_key\u001b[49m\u001b[43m=\u001b[49m\u001b[43mapi_key\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 256\u001b[39m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 258\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m modality_client_class(\n\u001b[32m 259\u001b[39m modality=resolved_modality,\n\u001b[32m 260\u001b[39m model=resolved_model,\n\u001b[32m (...)\u001b[39m\u001b[32m 264\u001b[39m base_url=base_url,\n\u001b[32m 265\u001b[39m )\n", - "\u001b[36mFile \u001b[39m\u001b[32m~/Desktop/Projects/withceleste/celeste-python/src/celeste/credentials.py:190\u001b[39m, in \u001b[36mCredentials.get_auth\u001b[39m\u001b[34m(self, provider, override_auth, override_key)\u001b[39m\n\u001b[32m 188\u001b[39m \u001b[38;5;66;03m# API key config tuple → AuthHeader\u001b[39;00m\n\u001b[32m 189\u001b[39m _secret_name, header, prefix = registered\n\u001b[32m--> \u001b[39m\u001b[32m190\u001b[39m api_key = \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43mget_credentials\u001b[49m\u001b[43m(\u001b[49m\u001b[43mprovider\u001b[49m\u001b[43m,\u001b[49m\u001b[43m 
\u001b[49m\u001b[43moverride_key\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 191\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m AuthHeader(secret=api_key, header=header, prefix=prefix)\n", - "\u001b[36mFile \u001b[39m\u001b[32m~/Desktop/Projects/withceleste/celeste-python/src/celeste/credentials.py:127\u001b[39m, in \u001b[36mCredentials.get_credentials\u001b[39m\u001b[34m(self, provider, override_key)\u001b[39m\n\u001b[32m 125\u001b[39m credential: SecretStr | \u001b[38;5;28;01mNone\u001b[39;00m = \u001b[38;5;28mgetattr\u001b[39m(\u001b[38;5;28mself\u001b[39m, field_name, \u001b[38;5;28;01mNone\u001b[39;00m)\n\u001b[32m 126\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m credential \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m credential.get_secret_value().strip():\n\u001b[32m--> \u001b[39m\u001b[32m127\u001b[39m \u001b[38;5;28;01mraise\u001b[39;00m MissingCredentialsError(provider=provider)\n\u001b[32m 129\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m credential\n", - "\u001b[31mMissingCredentialsError\u001b[39m: Provider google has no credentials configured. Set the appropriate environment variable or pass api_key parameter." 
- ] - } - ], - "execution_count": 2 + ] }, { "cell_type": "markdown", @@ -87,52 +72,88 @@ }, { "cell_type": "code", + "execution_count": null, "metadata": {}, + "outputs": [], "source": [ "img_result = await celeste.images.generate(\n", - " \"A golden retriever dog\", model=\"gemini-2.5-flash-image\"\n", + " \"A golden retriever dog\", model=\"gemini-2.5-flash-image\", api_key=os.getenv(\"GOOGLE_API_KEY\")\n", ")\n", "display(Image(data=img_result.content.data))" - ], - "outputs": [], - "execution_count": null + ] }, { "cell_type": "code", + "execution_count": null, "metadata": {}, + "outputs": [], "source": [ "img_emb = await celeste.images.embed(\n", - " img_result.content, model=\"gemini-embedding-2-preview\"\n", + " img_result.content, model=\"gemini-embedding-2-preview\", api_key=os.getenv(\"GOOGLE_API_KEY\")\n", ")\n", "print(f\"Dimensions: {len(img_emb.content)}\")\n", "print(f\"First 5 values: {img_emb.content[:5]}\")" - ], - "outputs": [], - "execution_count": null + ] }, { "cell_type": "markdown", - "source": "---\n\n## Video Embedding\n\nDownload a short sample video and embed it using the `celeste.videos` domain namespace.", - "metadata": {} + "metadata": {}, + "source": [ + "---\n", + "\n", + "## Video Embedding\n", + "\n", + "Download a short sample video and embed it using the `celeste.videos` domain namespace." 
+ ] }, { "cell_type": "code", + "execution_count": null, "metadata": {}, - "source": "import httpx\nfrom celeste.artifacts import VideoArtifact\nfrom celeste.mime_types import VideoMimeType\nfrom IPython.display import Video, Audio\n\nvideo_bytes = httpx.get(\"https://download.samplelib.com/mp4/sample-5s.mp4\").content\nvideo = VideoArtifact(data=video_bytes, mime_type=VideoMimeType.MP4)\ndisplay(Video(data=video_bytes, embed=True, mimetype=\"video/mp4\"))\n\nvid_emb = await celeste.videos.embed(video, model=\"gemini-embedding-2-preview\")\nprint(f\"Dimensions: {len(vid_emb.content)}\")\nprint(f\"First 5 values: {vid_emb.content[:5]}\")", "outputs": [], - "execution_count": null + "source": [ + "import httpx\n", + "from celeste.artifacts import VideoArtifact\n", + "from celeste.mime_types import VideoMimeType\n", + "from IPython.display import Video, Audio\n", + "\n", + "video_bytes = httpx.get(\"https://download.samplelib.com/mp4/sample-5s.mp4\").content\n", + "video = VideoArtifact(data=video_bytes, mime_type=VideoMimeType.MP4)\n", + "display(Video(data=video_bytes, embed=True, mimetype=\"video/mp4\"))\n", + "\n", + "vid_emb = await celeste.videos.embed(video, model=\"gemini-embedding-2-preview\", api_key=os.getenv(\"GOOGLE_API_KEY\"))\n", + "print(f\"Dimensions: {len(vid_emb.content)}\")\n", + "print(f\"First 5 values: {vid_emb.content[:5]}\")" + ] }, { "cell_type": "markdown", - "source": "---\n\n## Audio Embedding\n\nDownload a short sample audio and embed it using the `celeste.audio` domain namespace.", - "metadata": {} + "metadata": {}, + "source": [ + "---\n", + "\n", + "## Audio Embedding\n", + "\n", + "Download a short sample audio and embed it using the `celeste.audio` domain namespace." 
+ ] }, { "cell_type": "code", - "source": "from celeste.artifacts import AudioArtifact\nfrom celeste.mime_types import AudioMimeType\n\naudio_bytes = httpx.get(\"https://download.samplelib.com/mp3/sample-3s.mp3\").content\naudio = AudioArtifact(data=audio_bytes, mime_type=AudioMimeType.MP3)\ndisplay(Audio(data=audio_bytes, autoplay=False))\n\naud_emb = await celeste.audio.embed(audio, model=\"gemini-embedding-2-preview\")\nprint(f\"Dimensions: {len(aud_emb.content)}\")\nprint(f\"First 5 values: {aud_emb.content[:5]}\")", + "execution_count": null, "metadata": {}, "outputs": [], - "execution_count": null + "source": [ + "from celeste.artifacts import AudioArtifact\n", + "from celeste.mime_types import AudioMimeType\n", + "\n", + "audio_bytes = httpx.get(\"https://download.samplelib.com/mp3/sample-3s.mp3\").content\n", + "audio = AudioArtifact(data=audio_bytes, mime_type=AudioMimeType.MP3)\n", + "display(Audio(data=audio_bytes, autoplay=False))\n", + "\n", + "aud_emb = await celeste.audio.embed(audio, model=\"gemini-embedding-2-preview\", api_key=os.getenv(\"GOOGLE_API_KEY\"))\n", + "print(f\"Dimensions: {len(aud_emb.content)}\")\n", + "print(f\"First 5 values: {aud_emb.content[:5]}\")" + ] }, { "cell_type": "markdown", @@ -147,16 +168,18 @@ }, { "cell_type": "code", + "execution_count": null, "metadata": {}, + "outputs": [], "source": [ "!uv pip install matplotlib seaborn" - ], - "outputs": [], - "execution_count": null + ] }, { "cell_type": "code", + "execution_count": null, "metadata": {}, + "outputs": [], "source": [ "import seaborn as sns\n", "import pandas as pd\n", @@ -181,9 +204,7 @@ ")\n", "\n", "assert scores[\"dog\"] > scores[\"chair\"]" - ], - "outputs": [], - "execution_count": null + ] }, { "cell_type": "markdown", @@ -195,10 +216,10 @@ }, { "cell_type": "code", + "execution_count": null, "metadata": {}, - "source": [], "outputs": [], - "execution_count": null + "source": [] } ], "metadata": {