Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
268 changes: 212 additions & 56 deletions backend/app/rag/vision.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,129 @@
"""Image captioning / vision helpers for RAG pipeline.

Provides a simple, pluggable interface to generate textual descriptions
for images extracted from PDFs. By default it uses local OCR (pytesseract)
when available as a robust fallback. An external VLM provider (OpenAI)
can be integrated by setting `VISION_PROVIDER` and appropriate API keys
in settings; the provider hook is intentionally small and optional.
Caption resolution order for each image chunk:
1. Bounding-box proximity — nearest text block below/above the image in the PDF
(rich, zero-cost, works offline).
2. OCR (pytesseract) — when proximity yields nothing and tesseract is installed.
3. Placeholder — "Figure on page N (WxH px)" as a guaranteed non-empty fallback.

An optional OpenAI GPT-4o-mini vision hook is provided for deployments that set
VISION_PROVIDER=openai and OPENAI_API_KEY in settings.
"""
import base64
import logging
from typing import List, Dict, Any
from io import BytesIO
from typing import Any, Dict, List, Optional

import fitz # PyMuPDF

from app.config import get_settings

# Module-level logger and application settings shared by the helpers below.
logger = logging.getLogger(__name__)
settings = get_settings()

# Minimum image area (px²) — smaller images are decorative and skipped.
# Compared against width * height of the image's rectangle on the page.
# NOTE(review): that rectangle is presumably in page coordinates rather than
# pixels — confirm the intended unit.
_MIN_IMAGE_AREA = 1_000


# ── 1. Proximity-based caption extraction ────────────────────────────────────

def _find_caption_near_image(
    page: fitz.Page,
    img_bbox: fitz.Rect,
    search_margin: float = 60.0,
) -> str:
    """Return the text of the block nearest to an image, or "" if none.

    A strip ``search_margin`` tall and as wide as the image is searched
    directly below the image first; only when that yields no text is the
    equivalent strip directly above the image searched.

    Args:
        page: Page whose text blocks are inspected.
        img_bbox: Rectangle occupied by the image on ``page``.
        search_margin: Height of each search strip, in the same units as
            ``img_bbox``.

    Returns:
        The space-joined span text of the winning block, stripped of
        surrounding whitespace, or an empty string.
    """
    # Only text blocks (type 0) can act as captions.
    text_blocks = [
        blk
        for blk in page.get_text(
            "dict", flags=fitz.TEXT_PRESERVE_WHITESPACE
        ).get("blocks", [])
        if blk.get("type") == 0
    ]
    image_bottom = img_bbox.y1

    def _nearest_text(strip: fitz.Rect) -> str:
        """Pick the non-empty block touching *strip* that ranks closest."""
        best_gap = None
        best_text = ""
        for blk in text_blocks:
            left, top, right, bottom = blk["bbox"]
            if not fitz.Rect(left, top, right, bottom).intersects(strip):
                continue
            pieces = []
            for line in blk.get("lines", []):
                for span in line.get("spans", []):
                    pieces.append(span["text"])
            joined = " ".join(pieces).strip()
            if not joined:
                continue
            # Ranking is always by distance from the block's top edge to the
            # image's bottom edge, for both the lower and the upper strip.
            gap = abs(top - image_bottom)
            if best_gap is None or gap < best_gap:
                best_gap, best_text = gap, joined
        return best_text

    # Order matters: the strip below the image takes precedence.
    strips = (
        fitz.Rect(img_bbox.x0, img_bbox.y1, img_bbox.x1, img_bbox.y1 + search_margin),
        fitz.Rect(img_bbox.x0, img_bbox.y0 - search_margin, img_bbox.x1, img_bbox.y0),
    )
    for strip in strips:
        found = _nearest_text(strip)
        if found:
            return found
    return ""


def extract_captions_from_pdf(filepath: str) -> List[Dict[str, Any]]:
    """Collect proximity-based captions for the images of a PDF.

    Each page's images are visited in the order ``page.get_images`` yields
    them. Images without a placement rectangle, or whose rectangle area is
    below ``_MIN_IMAGE_AREA``, are skipped and do not consume a figure index.
    A failure on one image is logged as a warning and does not abort the run.

    Args:
        filepath: Path of the PDF to open.

    Returns:
        One dict per retained image, in (page, figure_index) order::

            {
                "page": int,          # 1-based
                "figure_index": int,  # 0-based within the page
                "caption": str,       # may be an empty string
                "bbox": list[float],  # [x0, y0, x1, y1] divided by page size
            }
    """
    captions: List[Dict[str, Any]] = []
    pdf = fitz.open(filepath)

    try:
        for page_no, page in enumerate(pdf, start=1):
            page_w = float(page.rect.width)
            page_h = float(page.rect.height)
            next_index = 0

            for image_entry in page.get_images(full=True):
                xref = image_entry[0]
                try:
                    placements = page.get_image_rects(xref)
                    if not placements:
                        continue
                    # Only the first placement of the image is considered.
                    rect = placements[0]

                    if rect.width * rect.height < _MIN_IMAGE_AREA:
                        continue  # skip decorative images

                    caption = _find_caption_near_image(page, rect)
                    normalised = [
                        round(rect.x0 / page_w, 4),
                        round(rect.y0 / page_h, 4),
                        round(rect.x1 / page_w, 4),
                        round(rect.y1 / page_h, 4),
                    ]
                    captions.append(
                        {
                            "page": page_no,
                            "figure_index": next_index,
                            "caption": caption,
                            "bbox": normalised,
                        }
                    )
                    next_index += 1
                except Exception as exc:
                    # Best effort: one bad image must not lose the others.
                    logger.warning(
                        "Skipping image xref=%s on page %s: %s", xref, page_no, exc
                    )
    finally:
        pdf.close()

    return captions


# ── 2. OCR fallback ──────────────────────────────────────────────────────────

def _ocr_caption(image_bytes: bytes) -> str:
"""Try to produce a caption using pytesseract OCR; returns empty string if not available."""
"""Attempt OCR via pytesseract; returns empty string if unavailable."""
try:
from PIL import Image
import pytesseract
Expand All @@ -26,14 +132,66 @@ def _ocr_caption(image_bytes: bytes) -> str:

try:
img = Image.open(BytesIO(image_bytes)).convert("RGB")
text = pytesseract.image_to_string(img)
text = text.strip()
return text
except Exception as e:
logger.debug(f"OCR failed: {e}")
text = pytesseract.image_to_string(img).strip()
return (text[:500] + "...") if len(text) > 500 else text
except Exception as exc:
logger.debug("OCR failed: %s", exc)
return ""


# ── 3. Optional OpenAI GPT-4o-mini vision hook ───────────────────────────────

def _openai_caption(image_bytes: bytes) -> str:
    """Caption an image via the OpenAI Chat Completions vision API.

    This is a best-effort hook: every failure (missing SDK, network error,
    unexpected response) is logged at debug level and reported as "".

    Args:
        image_bytes: Raw encoded image data.

    Returns:
        The stripped model reply, or an empty string when no API key is
        configured, ``image_bytes`` is empty, the reply has no text, or the
        call fails.
    """
    api_key = getattr(settings, "OPENAI_API_KEY", None)
    # Without a key or a payload there is nothing worth a network round trip.
    if not api_key or not image_bytes:
        return ""

    # Declare the real container format in the data URL instead of always
    # claiming PNG; unrecognised data keeps the previous "image/png" label.
    if image_bytes.startswith(b"\xff\xd8\xff"):
        mime = "image/jpeg"
    elif image_bytes.startswith((b"GIF87a", b"GIF89a")):
        mime = "image/gif"
    elif image_bytes[:4] == b"RIFF" and image_bytes[8:12] == b"WEBP":
        mime = "image/webp"
    else:
        mime = "image/png"

    try:
        # Imported lazily so the SDK stays an optional dependency.
        from openai import OpenAI

        client = OpenAI(api_key=api_key)
        b64 = base64.b64encode(image_bytes).decode("utf-8")

        response = client.chat.completions.create(
            model="gpt-4o-mini",
            max_tokens=120,
            messages=[
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "image_url",
                            "image_url": {
                                "url": f"data:{mime};base64,{b64}",
                                "detail": "low",
                            },
                        },
                        {
                            "type": "text",
                            "text": (
                                "Describe this figure or diagram in one concise sentence "
                                "suitable for use as a search index caption."
                            ),
                        },
                    ],
                }
            ],
        )
        # The message content may be None; treat that as "no caption".
        content = response.choices[0].message.content
        return (content or "").strip()

    except Exception as exc:
        logger.debug("OpenAI vision caption failed: %s", exc)
        return ""


# ── Public API ───────────────────────────────────────────────────────────────

def caption_image(image_bytes: bytes, page: Optional[int] = None) -> str:
"""Generate a caption for a single image (bytes).

Resolution order: OpenAI (if configured) β†’ OCR β†’ placeholder.
"""
def caption_image(image_bytes: bytes | List[bytes], page: int | List[int] | None = None) -> str | List[str]:
"""Generate a caption for a single image or a batch of images.

Expand All @@ -49,56 +207,54 @@ def caption_image(image_bytes: bytes | List[bytes], page: int | List[int] | None

# Placeholder for provider-based captioning (e.g., OpenAI / LLaVA hooks)
provider = getattr(settings, "VISION_PROVIDER", None)

if provider == "openai":
try:
import openai
# Minimal integration: attempt a text-only caption via responses if available.
# This is a best-effort hook; users should adapt to their provider's API.
api_key = getattr(settings, "OPENAI_API_KEY", None)
if api_key:
openai.api_key = api_key
# Use a generic prompt: "Describe the following image"
# Note: concrete multimodal API usage may vary across SDK versions.
resp = openai.Image.create(
prompt="Describe this image in one concise sentence.",
n=1,
# We do not re-upload image bytes here; this is a placeholder to show
# where provider code would be invoked. For production, follow
# provider docs for sending image data.
)
# openai.Image.create returns generated images, not captions β€” so skip.
except Exception:
# If provider integration fails, fall back to OCR below
logger.debug("OpenAI vision provider failed, falling back to OCR")

# Try OCR caption
caption = _openai_caption(image_bytes)
if caption:
return caption

ocr = _ocr_caption(image_bytes)
if ocr:
# Keep it short if very long
return (ocr[:500] + "...") if len(ocr) > 500 else ocr
return ocr

# Last-resort caption
if page:
return f"Image on page {page}."
return "Image."
# Derive dimensions for the placeholder
try:
pix = fitz.Pixmap(image_bytes)
dims = f"{pix.width}x{pix.height} px"
except Exception:
dims = "unknown size"

return f"Figure on page {page} ({dims})." if page else f"Figure ({dims})."


def generate_captions_for_chunks(chunks: List[Dict[str, Any]]) -> None:
    """Mutate image chunks in-place: fill empty ``text`` with a caption.

    A chunk is processed only when it carries ``image_bytes`` and its
    ``text`` is missing, ``None`` or blank. For such a chunk the caption is
    taken from an existing ``chunk["image_caption"]`` when that is non-empty,
    otherwise from ``caption_image()``. The chunk then gets ``text``,
    ``image_caption`` and ``is_image=True`` set, and its ``image_bytes``
    removed, even when captioning raised.

    Args:
        chunks: Chunk dicts to update. Chunks without ``image_bytes`` and
            chunks that already have non-blank ``text`` are left untouched.

    Returns:
        None. All results are written into the chunk dicts.
    """
    for chunk in chunks:
        if not chunk.get("image_bytes"):
            continue
        # ``text`` may be present but None, so normalise before stripping.
        if (chunk.get("text") or "").strip():
            # NOTE(review): these chunks keep their ``image_bytes``; confirm
            # whether they should be stripped here as well.
            continue  # already captioned

        try:
            # Use a pre-extracted caption if available.
            caption = chunk.get("image_caption") or caption_image(
                chunk["image_bytes"], page=chunk.get("page")
            )
            chunk["text"] = caption
            chunk["is_image"] = True
            chunk["image_caption"] = caption
        except Exception as exc:
            logger.debug("Failed to caption image chunk: %s", exc)
            # Still mark the chunk as an image so it is not lost.
            chunk["is_image"] = True
            page = chunk.get("page")
            # ``text`` is known to be blank here, so assign the fallback
            # outright; setdefault() would keep an existing empty value.
            chunk["text"] = f"Image on page {'?' if page is None else page}"
            chunk["image_caption"] = chunk["text"]
        finally:
            # Always drop the raw bytes of a processed chunk so they are
            # never serialised downstream.
            chunk.pop("image_bytes", None)
32 changes: 32 additions & 0 deletions backend/app/services/document_ingestion.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,38 @@ def ingest_document(document_id: str, filepath: str, original_name: str, user_id
except TypeError:
chunks = chunk_document(filepath)

# ── Proximity caption pass (PDF only) ────────────────────────────────
# Write bounding-box-derived captions into image chunks BEFORE store_chunks()
# so generate_captions_for_chunks() in vectorstore.py only needs to handle
# the OCR / placeholder fallback for any images without adjacent text.
ext = filepath.rsplit(".", 1)[-1].lower()
if ext == "pdf":
try:
from app.rag.vision import extract_captions_from_pdf

pdf_captions = extract_captions_from_pdf(filepath)
# Build lookup: page -> [captions in figure_index order]
caption_map: dict = {}
for cap in pdf_captions:
caption_map.setdefault(cap["page"], []).append(cap)

fig_counters: dict = {}
for chunk in chunks:
if not chunk.get("image_bytes"):
continue
page = chunk.get("page", 1)
idx = fig_counters.get(page, 0)
page_caps = caption_map.get(page, [])
if idx < len(page_caps) and page_caps[idx]["caption"]:
chunk["image_caption"] = page_caps[idx]["caption"]
chunk["bbox"] = str(page_caps[idx]["bbox"])
fig_counters[page] = idx + 1
except Exception as exc:
logger.warning(
"Proximity caption extraction failed for %s: %s", document_id, exc
)
# ── End proximity caption pass ────────────────────────────────────────

if not chunks:
doc.status = "failed"
doc.processing_progress = 0
Expand Down
Loading