Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions rag-service/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -488,6 +488,12 @@ async def global_exception_handler(request: Request, exc: Exception):
# Semantic-chunking tunables. Each value is read from the environment variable
# of the same name when this module is imported and falls back to the default
# shown; a non-numeric value makes int()/float() raise ValueError at import.
# NOTE(review): the three size settings below are presumably measured in
# characters - confirm against the splitter code.
SEMANTIC_CHUNK_SOFT_MAX = int(os.getenv("SEMANTIC_CHUNK_SOFT_MAX", "1200"))
SEMANTIC_CHUNK_MERGE_MIN = int(os.getenv("SEMANTIC_CHUNK_MERGE_MIN", "150"))
SEMANTIC_CHUNK_MERGE_MAX = int(os.getenv("SEMANTIC_CHUNK_MERGE_MAX", "1400"))
# Upper bound on the number of tiny chunks _split_pass2 will try to merge.
# Above this count the merge is skipped with a warning and the raw chunks are
# returned unchanged.
SEMANTIC_CHUNK_MAX_TINY_CHUNKS = int(
os.getenv("SEMANTIC_CHUNK_MAX_TINY_CHUNKS", "256")
)
# Upper bound on the number of chunks (tiny chunks plus their immediate
# neighbours) that _split_pass2 will send for batch embedding. Above this
# count the merge is likewise skipped and the raw chunks are returned.
SEMANTIC_CHUNK_MAX_MERGE_CANDIDATES = int(
os.getenv("SEMANTIC_CHUNK_MAX_MERGE_CANDIDATES", "384")
)
# NOTE(review): presumably the minimum similarity required to merge two
# neighbouring chunks - its use is not visible in this hunk; confirm.
SEMANTIC_CHUNK_SIMILARITY_THRESHOLD = float(
os.getenv("SEMANTIC_CHUNK_SIMILARITY_THRESHOLD", "0.75")
)
Expand Down Expand Up @@ -2134,6 +2140,16 @@ def _split_pass2(
if not tiny_indices:
return list(raw_chunks) # fast-path: nothing to merge

# Keep semantic merge work bounded for adversarial inputs that fragment a page
# into a large number of tiny chunks. Normal PDFs stay on the merge path.
if len(tiny_indices) > SEMANTIC_CHUNK_MAX_TINY_CHUNKS:
logger.warning(
"Semantic merge skipped tiny_chunks=%s limit=%s",
len(tiny_indices),
SEMANTIC_CHUNK_MAX_TINY_CHUNKS,
)
return list(raw_chunks)
Comment on lines +2143 to +2151

# Collect tiny chunks + their immediate neighbours for batch embedding
neighbour_indices = set()
for idx in tiny_indices:
Expand All @@ -2144,6 +2160,16 @@ def _split_pass2(
neighbour_indices.add(idx + 1)

sorted_indices = sorted(neighbour_indices)

if len(sorted_indices) > SEMANTIC_CHUNK_MAX_MERGE_CANDIDATES:
logger.warning(
"Semantic merge skipped candidates=%s tiny_chunks=%s limit=%s",
len(sorted_indices),
len(tiny_indices),
SEMANTIC_CHUNK_MAX_MERGE_CANDIDATES,
)
return list(raw_chunks)
Comment on lines +2164 to +2171

texts_to_embed = [raw_chunks[i] for i in sorted_indices]

try:
Expand Down
124 changes: 124 additions & 0 deletions rag-service/tests/test_semantic_merge_limits.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
"""Regression tests for semantic merge resource limits."""

import os
import sys
import types
import unittest
from unittest.mock import MagicMock, patch


def _stub_heavy_deps():
for name in [
"torch",
"numpy",
"langchain_community",
"langchain_community.vectorstores",
"langchain_community.embeddings",
"transformers",
"rank_bm25",
"pdf_parse_worker",
]:
if name not in sys.modules:
sys.modules[name] = types.ModuleType(name)
Comment on lines +10 to +22

torch_stub = sys.modules["torch"]
torch_stub.no_grad = lambda: _NullCtx()
torch_stub.cuda = types.SimpleNamespace(is_available=lambda: False)

tf = sys.modules["transformers"]
for attr in [
"AutoConfig",
"AutoTokenizer",
"AutoModelForSeq2SeqLM",
"AutoModelForCausalLM",
"TextIteratorStreamer",
]:
setattr(tf, attr, MagicMock())

sys.modules["rank_bm25"].BM25Okapi = MagicMock()
sys.modules["pdf_parse_worker"]._extract_pdf_text_worker = MagicMock()

lc_vs = sys.modules["langchain_community.vectorstores"]
lc_vs.FAISS = MagicMock()
lc_emb = sys.modules["langchain_community.embeddings"]
lc_emb.HuggingFaceEmbeddings = MagicMock()


class _NullCtx:
def __enter__(self):
return self

def __exit__(self, *_):
return False


# Provide placeholder secrets unless the environment already defines them.
# NOTE(review): presumably `main` needs these at import time - confirm.
os.environ.setdefault("JWT_SECRET", "test-secret-for-ci")
os.environ.setdefault("INTERNAL_RAG_TOKEN", "test-secret")

# Register the stub modules before `main` is imported further down.
_stub_heavy_deps()
Comment on lines +55 to +58

# Put this file's parent directory at the front of sys.path so the bare
# `import main` below can be resolved from there.
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
# Deliberately imported after the setup statements above (hence the E402
# suppression).
import main  # noqa: E402
Comment on lines +60 to +61


class _FakeEmbeddingModel:
def __init__(self, embeddings):
self.embeddings = embeddings
self.calls = []

def embed_documents(self, texts):
self.calls.append(list(texts))
return self.embeddings


class TestSemanticMergeLimits(unittest.TestCase):
    """Exercise ``main._split_pass2`` on a small input and past each limit."""

    def test_split_pass2_merges_normal_small_input(self):
        # Two chunks with identical canned vectors: expect a single merged
        # chunk and exactly one embed_documents() call.
        embedder = _FakeEmbeddingModel([[1.0, 0.0], [1.0, 0.0]])
        model_patch = patch.object(
            main, "get_embedding_model", return_value=embedder
        )

        with model_patch:
            merged = main._split_pass2(
                ["alpha", "beta"],
                threshold=0.5,
                merge_min=32,
                merge_max=128,
            )

        self.assertEqual(merged, ["alpha beta"])
        self.assertEqual(len(embedder.calls), 1)

    def test_split_pass2_skips_merge_when_tiny_chunk_limit_is_exceeded(self):
        # Five chunks against a tiny-chunk limit of 3: the input must come
        # back unchanged and the embedding model must never be called.
        embedder = _FakeEmbeddingModel([])
        chunks = [f"tiny-{idx}" for idx in range(5)]

        with patch.object(main, "SEMANTIC_CHUNK_MAX_TINY_CHUNKS", 3):
            with patch.object(
                main, "get_embedding_model", return_value=embedder
            ):
                outcome = main._split_pass2(
                    chunks,
                    threshold=0.5,
                    merge_min=64,
                    merge_max=256,
                )

        self.assertEqual(outcome, chunks)
        self.assertEqual(embedder.calls, [])

    def test_split_pass2_skips_merge_when_candidate_limit_is_exceeded(self):
        # Tiny-chunk limit raised to 10 and candidate limit lowered to 4 for
        # a five-chunk input: again expect a pass-through with no embedding.
        embedder = _FakeEmbeddingModel([])
        chunks = ["a", "long chunk", "b", "long chunk", "c"]

        with patch.object(main, "SEMANTIC_CHUNK_MAX_TINY_CHUNKS", 10):
            with patch.object(main, "SEMANTIC_CHUNK_MAX_MERGE_CANDIDATES", 4):
                with patch.object(
                    main, "get_embedding_model", return_value=embedder
                ):
                    outcome = main._split_pass2(
                        chunks,
                        threshold=0.5,
                        merge_min=32,
                        merge_max=256,
                    )

        self.assertEqual(outcome, chunks)
        self.assertEqual(embedder.calls, [])


# Run the test cases when this file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()
Loading