Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
114 changes: 113 additions & 1 deletion backend/app/routes/documents.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import asyncio
import concurrent.futures
from datetime import datetime, timezone
from typing import Optional
from typing import Optional, List
from pathlib import Path
import shutil
import socket
Expand All @@ -36,6 +36,8 @@
DocumentRename,
ChunkSettings,
UploadUrl,
BatchUploadResponse,
BatchUploadResult,
)
from app.auth import get_current_user
from app.config import get_settings
Expand Down Expand Up @@ -281,6 +283,116 @@ async def upload_document(

return DocumentResponse.model_validate(document).model_copy(update={"task_id": task_id})

@router.post("/upload/batch", response_model=BatchUploadResponse, status_code=status.HTTP_202_ACCEPTED)
async def upload_documents_batch(
files: List[UploadFile] = File(...),
background_tasks: BackgroundTasks = None,
user: User = Depends(get_current_user),
db: Session = Depends(get_db),
):
"""
Upload multiple documents simultaneously and enqueue parallel RAG processing.

Accepts up to 20 files in a single multipart/form-data request. Each file
is validated and saved independently; failures for individual files do not
abort the entire batch. A Celery task is enqueued for each successfully
saved file, allowing parallel ingestion without blocking the API response.

Args:
files: List of uploaded files, provided as multipart/form-data fields.
background_tasks: FastAPI BackgroundTasks instance for in-process fallback.
user: The currently authenticated user, injected by get_current_user.
db: Database session, injected by get_db.

Returns:
BatchUploadResponse: Per-file results with succeeded/failed counts.

Raises:
HTTPException 400: If no files are provided or batch exceeds 20 files.
"""
if not files:
raise ValidationException("No files provided")

MAX_BATCH_SIZE = 20
if len(files) > MAX_BATCH_SIZE:
raise ValidationException(f"Batch upload limited to {MAX_BATCH_SIZE} files at once")

user_dir = os.path.join(settings.UPLOAD_DIR, user.id)
os.makedirs(user_dir, exist_ok=True)

results: List[BatchUploadResult] = []

for file in files:
filename = file.filename or "unknown"
try:
# Validate extension before paying the disk I/O cost
if not filename:
raise ValidationException("No filename provided")

ext = filename.rsplit(".", 1)[-1].lower()
if ext not in settings.ALLOWED_EXTENSIONS:
raise ValidationException(
f"File type '.{ext}' not supported. Allowed: {', '.join(settings.ALLOWED_EXTENSIONS)}"
)

temp_path = await validate_upload(file)

stored_filename = f"{uuid.uuid4().hex}.{ext}"
filepath = os.path.join(user_dir, stored_filename)
shutil.move(temp_path, filepath)

file_size = Path(filepath).stat().st_size

document = Document(
user_id=user.id,
filename=stored_filename,
original_name=filename,
file_size=file_size,
status="pending",
)
db.add(document)
db.commit()
db.refresh(document)

task_id = None
try:
task = process_document.delay(
document_id=document.id,
filepath=filepath,
original_name=filename,
user_id=user.id,
)
task_id = task.id
except Exception as e:
logger.warning(f"Celery queue failed for {filename}, falling back: {e}")
if background_tasks:
background_tasks.add_task(
ingest_document,
document_id=document.id,
filepath=filepath,
original_name=filename,
user_id=user.id,
)
task_id = f"local_{uuid.uuid4().hex}"

doc_response = DocumentResponse.model_validate(document).model_copy(
update={"task_id": task_id}
)
results.append(BatchUploadResult(filename=filename, success=True, document=doc_response))

except Exception as exc:
logger.warning(f"Batch upload: file '{filename}' failed — {exc}")
results.append(BatchUploadResult(filename=filename, success=False, error=str(exc)))

succeeded = sum(1 for r in results if r.success)
return BatchUploadResponse(
results=results,
total=len(results),
succeeded=succeeded,
failed=len(results) - succeeded,
)


@router.post("/urlupload", status_code=status.HTTP_202_ACCEPTED)
async def upload_document_url(
payload: UploadUrl,
Expand Down
14 changes: 14 additions & 0 deletions backend/app/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,20 @@ class DocumentListResponse(BaseModel):
pages: int


class BatchUploadResult(BaseModel):
filename: str
success: bool
document: Optional[DocumentResponse] = None
error: Optional[str] = None


class BatchUploadResponse(BaseModel):
results: List[BatchUploadResult]
total: int
succeeded: int
failed: int


# Admin

class DiskUsageResponse(BaseModel):
Expand Down
193 changes: 193 additions & 0 deletions backend/tests/test_batch_upload.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,193 @@
"""
Tests for POST /documents/upload/batch — issue #435.
"""
import io
import os
from unittest.mock import MagicMock, patch

import pytest


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------

def _make_pdf_bytes() -> bytes:
"""Return the minimal bytes of a valid single-page PDF."""
return (
b"%PDF-1.4\n"
b"1 0 obj<</Type/Catalog/Pages 2 0 R>>endobj\n"
b"2 0 obj<</Type/Pages/Kids[3 0 R]/Count 1>>endobj\n"
b"3 0 obj<</Type/Page/MediaBox[0 0 612 792]/Parent 2 0 R>>endobj\n"
b"xref\n0 4\n0000000000 65535 f \n"
b"0000000009 00000 n \n"
b"0000000058 00000 n \n"
b"0000000115 00000 n \n"
b"trailer<</Size 4/Root 1 0 R>>\nstartxref\n190\n%%EOF"
)


def _pdf_file(name: str = "test.pdf") -> tuple[str, tuple]:
"""Return a (field_name, (filename, bytes_io, mimetype)) tuple for requests."""
return ("files", (name, io.BytesIO(_make_pdf_bytes()), "application/pdf"))


# ---------------------------------------------------------------------------
# Tests
# ---------------------------------------------------------------------------

class TestBatchUpload:
URL = "/documents/upload/batch"

def test_no_auth_rejected(self, client):
response = client.post(self.URL, files=[_pdf_file()])
assert response.status_code == 401

def test_empty_file_list_rejected(self, client, auth_headers, monkeypatch, tmp_path):
monkeypatch.setenv("UPLOAD_DIR", str(tmp_path))
# FastAPI will return 422 when the required `files` field is missing
response = client.post(self.URL, headers=auth_headers)
assert response.status_code == 422

def test_too_many_files_rejected(self, client, auth_headers, monkeypatch, tmp_path):
monkeypatch.setenv("UPLOAD_DIR", str(tmp_path))
with (
patch("app.routes.documents.validate_upload", side_effect=Exception("mocked")),
):
files = [_pdf_file(f"file{i}.pdf") for i in range(21)]
response = client.post(self.URL, headers=auth_headers, files=files)
# Our ValidationException maps to 400
assert response.status_code == 400

def test_single_file_success(self, client, auth_headers, monkeypatch, tmp_path):
upload_dir = str(tmp_path)
monkeypatch.setenv("UPLOAD_DIR", upload_dir)

fake_temp = tmp_path / "fake_tmp.pdf"
fake_temp.write_bytes(_make_pdf_bytes())

with (
patch("app.routes.documents.settings") as mock_settings,
patch("app.routes.documents.validate_upload", return_value=str(fake_temp)),
patch("app.routes.documents.process_document") as mock_task,
patch("app.routes.documents.shutil.move"),
):
mock_settings.UPLOAD_DIR = upload_dir
mock_settings.ALLOWED_EXTENSIONS = {"pdf", "docx", "txt", "md"}

fake_celery_result = MagicMock()
fake_celery_result.id = "celery-task-id-1"
mock_task.delay.return_value = fake_celery_result

response = client.post(
self.URL,
headers=auth_headers,
files=[_pdf_file("doc1.pdf")],
)

assert response.status_code == 202
body = response.json()
assert body["total"] == 1
assert body["succeeded"] == 1
assert body["failed"] == 0
assert body["results"][0]["success"] is True
assert body["results"][0]["filename"] == "doc1.pdf"

def test_multi_file_partial_failure(self, client, auth_headers, monkeypatch, tmp_path):
"""One valid file + one file that fails validation → partial success."""
upload_dir = str(tmp_path)
monkeypatch.setenv("UPLOAD_DIR", upload_dir)

fake_temp = tmp_path / "fake_tmp.pdf"
fake_temp.write_bytes(_make_pdf_bytes())

call_count = {"n": 0}

async def fake_validate(file):
call_count["n"] += 1
if call_count["n"] == 1:
return str(fake_temp)
raise Exception("Corrupted or invalid file")

with (
patch("app.routes.documents.settings") as mock_settings,
patch("app.routes.documents.validate_upload", side_effect=fake_validate),
patch("app.routes.documents.process_document") as mock_task,
patch("app.routes.documents.shutil.move"),
):
mock_settings.UPLOAD_DIR = upload_dir
mock_settings.ALLOWED_EXTENSIONS = {"pdf", "docx", "txt", "md"}

fake_celery_result = MagicMock()
fake_celery_result.id = "celery-task-id-2"
mock_task.delay.return_value = fake_celery_result

response = client.post(
self.URL,
headers=auth_headers,
files=[_pdf_file("good.pdf"), _pdf_file("bad.pdf")],
)

assert response.status_code == 202
body = response.json()
assert body["total"] == 2
assert body["succeeded"] == 1
assert body["failed"] == 1

successes = [r for r in body["results"] if r["success"]]
failures = [r for r in body["results"] if not r["success"]]
assert len(successes) == 1
assert len(failures) == 1
assert failures[0]["error"] is not None

def test_celery_fallback_to_background_task(self, client, auth_headers, monkeypatch, tmp_path):
"""When Celery is unavailable the endpoint falls back gracefully."""
upload_dir = str(tmp_path)
monkeypatch.setenv("UPLOAD_DIR", upload_dir)

fake_temp = tmp_path / "fake_tmp.pdf"
fake_temp.write_bytes(_make_pdf_bytes())

with (
patch("app.routes.documents.settings") as mock_settings,
patch("app.routes.documents.validate_upload", return_value=str(fake_temp)),
patch("app.routes.documents.process_document") as mock_task,
patch("app.routes.documents.shutil.move"),
):
mock_settings.UPLOAD_DIR = upload_dir
mock_settings.ALLOWED_EXTENSIONS = {"pdf", "docx", "txt", "md"}
mock_task.delay.side_effect = Exception("Redis unavailable")

response = client.post(
self.URL,
headers=auth_headers,
files=[_pdf_file("celery_fail.pdf")],
)

assert response.status_code == 202
body = response.json()
assert body["succeeded"] == 1
# task_id should start with "local_" when falling back
assert body["results"][0]["document"]["task_id"].startswith("local_")

def test_unsupported_extension_counted_as_failure(self, client, auth_headers, monkeypatch, tmp_path):
upload_dir = str(tmp_path)
monkeypatch.setenv("UPLOAD_DIR", upload_dir)

with (
patch("app.routes.documents.settings") as mock_settings,
):
mock_settings.UPLOAD_DIR = upload_dir
mock_settings.ALLOWED_EXTENSIONS = {"pdf", "docx", "txt", "md"}

response = client.post(
self.URL,
headers=auth_headers,
files=[("files", ("malware.exe", io.BytesIO(b"MZ"), "application/octet-stream"))],
)

assert response.status_code == 202
body = response.json()
assert body["total"] == 1
assert body["failed"] == 1
assert body["succeeded"] == 0
Loading