diff --git a/README.md b/README.md index 91379fd7..730d894b 100644 --- a/README.md +++ b/README.md @@ -34,6 +34,7 @@ RAGLite is a Python toolkit for Retrieval-Augmented Generation (RAG) with DuckDB - 🔌 A built-in [Model Context Protocol](https://modelcontextprotocol.io) (MCP) server that any MCP client like [Claude desktop](https://claude.ai/download) can connect with - 💬 Optional customizable ChatGPT-like frontend for [web](https://docs.chainlit.io/deploy/copilot), [Slack](https://docs.chainlit.io/deploy/slack), and [Teams](https://docs.chainlit.io/deploy/teams) with [Chainlit](https://github.com/Chainlit/chainlit) - ✍️ Optional conversion of any input document to Markdown with [Pandoc](https://github.com/jgm/pandoc) +- 🔎 Optional high-quality document processing with [Mistral OCR](https://docs.mistral.ai/capabilities/document/) for PDFs, images, DOCX, and PPTX with automatic image descriptions - ✅ Optional evaluation of retrieval and generation performance with [Ragas](https://github.com/explodinggradients/ragas) ## Installing @@ -69,6 +70,12 @@ To add support for filetypes other than PDF, use the `pandoc` extra: pip install raglite[pandoc] ``` +To add support for high-quality document processing with [Mistral OCR](https://docs.mistral.ai/capabilities/document/), use the `mistral-ocr` extra: + +```sh +pip install raglite[mistral-ocr] +``` + To add support for evaluation, use the `ragas` extra: ```sh @@ -152,6 +159,21 @@ my_config = RAGLiteConfig( > [!TIP] > ✍️ To insert documents other than PDF, install the `pandoc` extra with `pip install raglite[pandoc]`. 
+> [!TIP] +> 🔎 For higher-quality document processing with automatic image descriptions, install the `mistral-ocr` extra with `pip install raglite[mistral-ocr]` and configure it as follows: +> ```python +> from raglite import RAGLiteConfig, MistralOCRConfig +> +> my_config = RAGLiteConfig( +> document_processor=MistralOCRConfig( +> include_image_descriptions=True, # Describe images, charts, and diagrams as text +> image_types=frozenset({"chart", "diagram", "photo", "table", "logo", "icon"}), # Custom image categories +> exclude_image_types=frozenset({"logo", "icon"}), # Filter out specific types from the output +> ), +> ) +> ``` +> The `image_types` parameter defines the categories that Mistral classifies each image into — you can use the defaults or provide your own domain-specific types. Use `exclude_image_types` to filter out any classified types that are not useful for retrieval. + Next, insert some documents into the database. RAGLite will take care of the [conversion to Markdown](src/raglite/_markdown.py), [optimal level 4 semantic chunking](src/raglite/_split_chunks.py), and [multi-vector embedding with late chunking](src/raglite/_embed.py): ```python diff --git a/pyproject.toml b/pyproject.toml index cdde3628..85e1dc49 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -70,6 +70,7 @@ dev = [ "pytest (>=8.3.4)", "pytest-mock (>=3.14.0)", "pytest-xdist (>=3.6.1)", + "python-dotenv (>=1.0.0)", "ruff (>=0.10.0)", "typeguard (>=4.4.1)", ] @@ -80,6 +81,7 @@ chainlit = ["chainlit (>=2.0.0)"] # Large Language Models: llama-cpp-python = ["llama-cpp-python (>=0.3.9)"] # Markdown conversion: +mistral-ocr = ["mistralai (>=1.10.1)"] pandoc = ["pypandoc-binary (>=1.13)"] # Evaluation: ragas = ["pandas (>=2.1.1)", "ragas (>=0.3.3)"] diff --git a/src/raglite/__init__.py b/src/raglite/__init__.py index 52a9b598..5f568c68 100644 --- a/src/raglite/__init__.py +++ b/src/raglite/__init__.py @@ -1,11 +1,12 @@ """RAGLite.""" -from raglite._config import RAGLiteConfig +from 
raglite._config import MistralOCRConfig, RAGLiteConfig from raglite._database import Document from raglite._delete import delete_documents, delete_documents_by_metadata from raglite._eval import answer_evals, evaluate, insert_evals from raglite._extract import expand_document_metadata from raglite._insert import insert_documents +from raglite._mistral_ocr import MistralOCRError from raglite._query_adapter import update_query_adapter from raglite._rag import add_context, async_rag, rag, retrieve_context from raglite._search import ( @@ -22,6 +23,8 @@ __all__ = [ # Config "RAGLiteConfig", + "MistralOCRConfig", + "MistralOCRError", # Insert "Document", "insert_documents", diff --git a/src/raglite/_chainlit.py b/src/raglite/_chainlit.py index d6ff12ea..5bd8d939 100644 --- a/src/raglite/_chainlit.py +++ b/src/raglite/_chainlit.py @@ -75,7 +75,7 @@ async def handle_message(user_message: cl.Message) -> None: inline_attachments = [] for file in user_message.elements: if file.path: - doc_md = document_to_markdown(Path(file.path)) + doc_md = document_to_markdown(Path(file.path), config=config) if len(doc_md) // 3 <= 5 * (config.chunk_max_size // 3): # Document is small enough to attach to the context. inline_attachments.append(f"{Path(file.path).name}:\n\n{doc_md}") @@ -83,7 +83,7 @@ async def handle_message(user_message: cl.Message) -> None: # Document is too large and must be inserted into the database. async with cl.Step(name="insert", type="run") as step: step.input = Path(file.path).name - document = Document.from_path(Path(file.path)) + document = Document.from_path(Path(file.path), config=config) await async_insert_documents([document], config=config) # Append any inline attachments to the user prompt. 
user_prompt = ( diff --git a/src/raglite/_config.py b/src/raglite/_config.py index de715c68..da0471d3 100644 --- a/src/raglite/_config.py +++ b/src/raglite/_config.py @@ -23,6 +23,26 @@ cache_path = Path(user_data_dir("raglite", ensure_exists=True)) +DEFAULT_IMAGE_TYPES = frozenset( + {"graph", "chart", "diagram", "table", "photo", "screenshot", "logo", "icon", "other"} +) + + +@dataclass(frozen=True) +class MistralOCRConfig: + """Configuration for MistralOCR document processor.""" + + # API key - falls back to MISTRAL_API_KEY env var if None. + api_key: str | None = None + # Whether to use vision to describe images in documents. + include_image_descriptions: bool = True + # Image types that Mistral classifies each image into. + image_types: frozenset[str] = DEFAULT_IMAGE_TYPES + # Image types to exclude from the output (e.g., {"logo", "icon"}). + exclude_image_types: frozenset[str] = frozenset() + model: str = "mistral-ocr-latest" + + # Lazily load the default search method to avoid circular imports. # TODO: Replace with search_and_rerank_chunk_spans after benchmarking. def _vector_search( @@ -65,6 +85,8 @@ class RAGLiteConfig: embedder_normalize: bool = True # Chunk config used to partition documents into chunks. chunk_max_size: int = 2048 # Max number of characters per chunk. + # Document processing config. None = default processor. + document_processor: MistralOCRConfig | None = None # Vector search config. vector_search_distance_metric: Literal["cosine", "dot", "l2"] = "cosine" vector_search_multivector: bool = True diff --git a/src/raglite/_database.py b/src/raglite/_database.py index 2276f46b..7ea5cc62 100644 --- a/src/raglite/_database.py +++ b/src/raglite/_database.py @@ -108,6 +108,7 @@ def from_path( *, id: DocumentId | None = None, # noqa: A002 url: str | None = None, + config: RAGLiteConfig | None = None, **kwargs: Any, ) -> "Document": """Create a document from a file path. @@ -120,6 +121,8 @@ def from_path( The document id to use. 
def document_to_markdown(doc_path: Path, *, config: RAGLiteConfig | None = None) -> str:
    """Convert any document to GitHub Flavored Markdown.

    Parameters
    ----------
    doc_path
        Path to the document file.
    config
        Optional RAGLite configuration. If its ``document_processor`` is a
        MistralOCRConfig and the file extension is supported by MistralOCR,
        the document is converted with MistralOCR. In every other case the
        default processor is used.

    Returns
    -------
    str
        Document content as GitHub Flavored Markdown.
    """
    # Read the processor directly from the caller's config. Without a config the
    # processor is simply None, so there is no reason to construct a throwaway
    # RAGLiteConfig just to look up a field whose default is None.
    processor_config = config.document_processor if config is not None else None

    if isinstance(processor_config, MistralOCRConfig):
        # Lazy import to avoid requiring mistralai when not using MistralOCR.
        from raglite._mistral_ocr import SUPPORTED_EXTENSIONS, mistral_ocr_to_markdown

        # Match the extension case-insensitively so that e.g. ".PDF" is accepted.
        if doc_path.suffix.lower() in SUPPORTED_EXTENSIONS:
            return mistral_ocr_to_markdown(doc_path, processor_config=processor_config)
        logger.debug(
            "Mistral does not support file type: %s\nFalling back to default processor.", doc_path
        )

    return _default_document_to_markdown(doc_path)
_MIME_TYPES: dict[str, str] = {
    ".pdf": "application/pdf",
    ".png": "image/png",
    ".jpg": "image/jpeg",
    ".jpeg": "image/jpeg",
    ".avif": "image/avif",
    ".webp": "image/webp",
    ".docx": "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
    ".pptx": "application/vnd.openxmlformats-officedocument.presentationml.presentation",
}
# Both sets are derived from _MIME_TYPES so that they cannot drift from it.
SUPPORTED_EXTENSIONS = frozenset(_MIME_TYPES)
_IMAGE_EXTENSIONS = frozenset(ext for ext, mime in _MIME_TYPES.items() if mime.startswith("image/"))

# Shared by every lazy `mistralai` import so that the hint stays consistent.
_INSTALL_HINT = (
    "To use MistralOCR, please install the `mistral-ocr` extra: "
    "`pip install raglite[mistral-ocr]` or `uv add raglite[mistral-ocr]`."
)


class MistralOCRError(Exception):
    """Error during MistralOCR processing."""


def _build_image_annotation_model(image_types: frozenset[str]) -> type[BaseModel]:
    """Build an ImageAnnotation Pydantic model restricted to the given image types.

    Parameters
    ----------
    image_types
        The image type strings that become the allowed values of the model's
        ``image_type`` field.

    Returns
    -------
    type[BaseModel]
        A model class with an ``image_type`` enum field and a ``description`` field.
    """
    # Sort once so that the enum members and the field description are deterministic.
    ordered_types = sorted(image_types)
    image_type_enum = Enum("ImageType", {t.upper(): t for t in ordered_types}, type=str)  # type: ignore[misc]
    image_type_values = ", ".join(ordered_types)

    class ImageAnnotation(BaseModel):
        """Schema for vision-based image annotation."""

        image_type: image_type_enum = Field(  # type: ignore[valid-type]
            ...,
            description=f"The type of the image. Must be one of: {image_type_values}.",
        )
        description: str = Field(
            ...,
            description=(
                "A concise description of the image content. For diagrams and charts, "
                "describe what is being illustrated. For tables, summarize the data. "
                "For photos, describe the subject matter."
            ),
        )

    return ImageAnnotation


def _get_api_key(processor_config: MistralOCRConfig) -> str:
    """Get the Mistral API key from config or environment variable.

    Raises
    ------
    ValueError
        If neither ``processor_config.api_key`` nor MISTRAL_API_KEY provides a key.
    """
    api_key = processor_config.api_key or os.environ.get("MISTRAL_API_KEY")
    if not api_key:
        error_msg = (
            "MISTRAL_API_KEY environment variable is not set and MistralOCRConfig.api_key is None."
        )
        raise ValueError(error_msg)
    return api_key


def _get_mistral_client(processor_config: MistralOCRConfig) -> Any:
    """Get a Mistral client instance.

    Raises
    ------
    ImportError
        If the mistralai package is not installed.
    ValueError
        If no API key is configured.
    """
    try:
        from mistralai import Mistral
    except ImportError as e:
        raise ImportError(_INSTALL_HINT) from e

    api_key = _get_api_key(processor_config)
    return Mistral(api_key=api_key)


def _get_response_format_converter() -> Any:
    """Get the response_format_from_pydantic_model function from mistralai.

    Raises
    ------
    ImportError
        If the mistralai package is not installed.
    """
    try:
        from mistralai.extra import response_format_from_pydantic_model
    except ImportError as e:
        raise ImportError(_INSTALL_HINT) from e
    return response_format_from_pydantic_model


def _encode_document_base64(doc_path: Path) -> tuple[str, str]:
    """Encode a document as base64 with appropriate MIME type.

    Returns
    -------
    tuple[str, str]
        The base64-encoded file content and its MIME type. Unknown extensions
        map to ``application/octet-stream``.
    """
    mime_type = _MIME_TYPES.get(doc_path.suffix.lower(), "application/octet-stream")
    data = base64.standard_b64encode(doc_path.read_bytes()).decode("utf-8")
    return data, mime_type


def _process_ocr_response(
    ocr_response: Any,
    *,
    annotation_model: type[BaseModel],
    include_image_descriptions: bool = True,
    exclude_image_types: frozenset[str] | None = None,
) -> str:
    """Convert MistralOCR response to markdown string.

    When include_image_descriptions is True and bbox_annotation_format was used,
    image placeholders are replaced with their annotations.

    Parameters
    ----------
    ocr_response
        Response from Mistral OCR API.
    annotation_model
        The Pydantic model used to parse image annotations.
    include_image_descriptions
        Whether to replace image placeholders with annotations.
    exclude_image_types
        Set of image type strings to exclude from output.

    Returns
    -------
    str
        Document content as markdown, with pages separated by a blank line.
    """
    excluded_types = exclude_image_types or frozenset()
    pages_md = []

    for page in ocr_response.pages:
        page_md = page.markdown

        if include_image_descriptions and page.images:
            for img in page.images:
                # Check if the image has an annotation (from bbox_annotation_format).
                annotation = getattr(img, "image_annotation", None)
                if not annotation:
                    continue
                placeholder_pattern = rf"!\[[^\]]*\]\({re.escape(img.id)}\)"
                try:
                    # Parse the annotation to learn the image type for filtering.
                    parsed: Any = annotation_model.model_validate_json(annotation)
                    image_type = parsed.image_type.value
                    if image_type in excluded_types:
                        # Remove the image placeholder entirely.
                        replacement = ""
                    else:
                        replacement = f"[Image ({image_type}): {parsed.description}]"
                except (ValueError, TypeError):
                    # If parsing fails, use raw annotation.
                    replacement = f"[Image: {annotation}]"
                # Substitute through a callable: a plain string would be read as a
                # regex replacement template, so backslashes or group references in
                # the model-generated text would raise re.error or garble the output.
                page_md = re.sub(
                    placeholder_pattern, lambda _match, text=replacement: text, page_md
                )

        pages_md.append(page_md)

    return "\n\n".join(pages_md)


def mistral_ocr_to_markdown(doc_path: Path, *, processor_config: MistralOCRConfig) -> str:
    """Convert a document to markdown using Mistral OCR with vision annotations.

    Uses Mistral's bbox_annotation_format to automatically describe images and
    diagrams found in the document, making visual content searchable.

    Parameters
    ----------
    doc_path
        Path to the document file.
    processor_config
        MistralOCR processor configuration.

    Returns
    -------
    str
        Document content as GitHub Flavored Markdown with image descriptions.

    Raises
    ------
    ImportError
        If the mistralai package is not installed.
    ValueError
        If MISTRAL_API_KEY is not set and MistralOCRConfig.api_key is None.
    MistralOCRError
        If the OCR processing fails.
    """
    data, mime_type = _encode_document_base64(doc_path)

    # Images and documents are sent under different payload keys.
    document_payload: dict[str, str]
    if doc_path.suffix.lower() in _IMAGE_EXTENSIONS:
        document_payload = {
            "type": "image_url",
            "image_url": f"data:{mime_type};base64,{data}",
        }
    else:
        # PDF, DOCX, PPTX.
        document_payload = {
            "type": "document_url",
            "document_url": f"data:{mime_type};base64,{data}",
        }

    # Build OCR request parameters.
    ocr_params: dict[str, Any] = {
        "model": processor_config.model,
        "document": document_payload,
        "include_image_base64": False,  # We don't need base64, just annotations.
    }

    annotation_model = _build_image_annotation_model(processor_config.image_types)

    # Resolve the dependency and the API key outside the try block, so that their
    # ImportError/ValueError reach the caller unchanged while every failure of the
    # request itself, including a ValueError, is reported as MistralOCRError.
    client = _get_mistral_client(processor_config)
    response_format_from_pydantic_model = (
        _get_response_format_converter() if processor_config.include_image_descriptions else None
    )

    try:
        # Add bbox annotation format if image descriptions are enabled.
        if response_format_from_pydantic_model is not None:
            ocr_params["bbox_annotation_format"] = response_format_from_pydantic_model(
                annotation_model
            )
        ocr_response = client.ocr.process(**ocr_params)
    except Exception as e:
        error_msg = f"MistralOCR failed to process {doc_path}: {e}"
        raise MistralOCRError(error_msg) from e

    # Process response and replace image placeholders with annotations.
    return _process_ocr_response(
        ocr_response,
        annotation_model=annotation_model,
        include_image_descriptions=processor_config.include_image_descriptions,
        exclude_image_types=processor_config.exclude_image_types,
    )
def _mock_ocr_response(pages: list[tuple[str, list[tuple[str, str | None]]]]) -> SimpleNamespace:
    """Build a stand-in OCR response from ``(markdown, [(image id, annotation), ...])`` pairs."""
    mock_pages = []
    for markdown, images in pages:
        mock_images = [
            SimpleNamespace(id=image_id, image_annotation=annotation)
            for image_id, annotation in images
        ]
        mock_pages.append(SimpleNamespace(markdown=markdown, images=mock_images))
    return SimpleNamespace(pages=mock_pages)


def test_process_ocr_response() -> None:
    """Check annotation substitution, type filtering, raw fallback, and page joining."""
    # Page 1 holds a described diagram and a logo that must be filtered out.
    first_page = (
        "![](img-d.jpeg)\n\n![](img-l.jpeg)",
        [
            ("img-d.jpeg", '{"image_type": "diagram", "description": "A flowchart"}'),
            ("img-l.jpeg", '{"image_type": "logo", "description": "Company logo"}'),
        ],
    )
    # Page 2 holds an annotation that is not valid JSON and so triggers the fallback.
    second_page = ("![](img-r.jpeg)", [("img-r.jpeg", "raw fallback text")])
    model = _build_image_annotation_model(frozenset({"diagram", "logo"}))

    markdown = _process_ocr_response(
        _mock_ocr_response([first_page, second_page]),
        annotation_model=model,
        include_image_descriptions=True,
        exclude_image_types=frozenset({"logo"}),
    )

    assert "[Image (diagram): A flowchart]" in markdown
    assert "Company logo" not in markdown
    assert "[Image: raw fallback text]" in markdown  # fallback + page 2 joined


@pytest.mark.skipif(not os.environ.get("MISTRAL_API_KEY"), reason="MISTRAL_API_KEY not set")
@pytest.mark.slow
def test_real_pdf_conversion() -> None:
    """Run Mistral OCR end to end on the NVIDIA report (tables, charts, financial data)."""
    report_path = Path(__file__).parent / "NVIDIA-report.pdf"
    ocr_config = MistralOCRConfig(include_image_descriptions=True)
    markdown = mistral_ocr_to_markdown(report_path, processor_config=ocr_config)
    assert len(markdown) > 500  # noqa: PLR2004  # substantial multi-page content
    assert "| " in markdown  # tables rendered as markdown
    assert "[Image (" in markdown  # image descriptions with type classification
    assert "$130.5 billion" in markdown  # financial data from table cells