diff --git a/rag-service/main.py b/rag-service/main.py
index e508e6a..0ae9ca8 100644
--- a/rag-service/main.py
+++ b/rag-service/main.py
@@ -488,6 +488,14 @@ async def global_exception_handler(request: Request, exc: Exception):
 SEMANTIC_CHUNK_SOFT_MAX = int(os.getenv("SEMANTIC_CHUNK_SOFT_MAX", "1200"))
 SEMANTIC_CHUNK_MERGE_MIN = int(os.getenv("SEMANTIC_CHUNK_MERGE_MIN", "150"))
 SEMANTIC_CHUNK_MERGE_MAX = int(os.getenv("SEMANTIC_CHUNK_MERGE_MAX", "1400"))
+# Upper bounds on semantic-merge work; above either one _split_pass2 skips
+# merging and returns the raw chunks unchanged.
+SEMANTIC_CHUNK_MAX_TINY_CHUNKS = int(
+    os.getenv("SEMANTIC_CHUNK_MAX_TINY_CHUNKS", "256")
+)
+SEMANTIC_CHUNK_MAX_MERGE_CANDIDATES = int(
+    os.getenv("SEMANTIC_CHUNK_MAX_MERGE_CANDIDATES", "384")
+)
 SEMANTIC_CHUNK_SIMILARITY_THRESHOLD = float(
     os.getenv("SEMANTIC_CHUNK_SIMILARITY_THRESHOLD", "0.75")
 )
@@ -2134,6 +2142,16 @@ def _split_pass2(
     if not tiny_indices:
         return list(raw_chunks)  # fast-path: nothing to merge
 
+    # Keep semantic merge work bounded for adversarial inputs that fragment a page
+    # into a large number of tiny chunks. Normal PDFs stay on the merge path.
+    if len(tiny_indices) > SEMANTIC_CHUNK_MAX_TINY_CHUNKS:
+        logger.warning(
+            "Semantic merge skipped tiny_chunks=%s limit=%s",
+            len(tiny_indices),
+            SEMANTIC_CHUNK_MAX_TINY_CHUNKS,
+        )
+        return list(raw_chunks)
+
     # Collect tiny chunks + their immediate neighbours for batch embedding
     neighbour_indices = set()
     for idx in tiny_indices:
@@ -2144,6 +2162,18 @@ def _split_pass2(
             neighbour_indices.add(idx + 1)
 
     sorted_indices = sorted(neighbour_indices)
+
+    # Neighbours are collected along with the tiny chunks, so the candidate
+    # total is capped separately from the tiny-chunk count.
+    if len(sorted_indices) > SEMANTIC_CHUNK_MAX_MERGE_CANDIDATES:
+        logger.warning(
+            "Semantic merge skipped candidates=%s tiny_chunks=%s limit=%s",
+            len(sorted_indices),
+            len(tiny_indices),
+            SEMANTIC_CHUNK_MAX_MERGE_CANDIDATES,
+        )
+        return list(raw_chunks)
+
     texts_to_embed = [raw_chunks[i] for i in sorted_indices]
 
     try:
diff --git a/rag-service/tests/test_semantic_merge_limits.py b/rag-service/tests/test_semantic_merge_limits.py
new file mode 100644
index 0000000..37a7a19
--- /dev/null
+++ b/rag-service/tests/test_semantic_merge_limits.py
@@ -0,0 +1,136 @@
+"""Regression tests for semantic merge resource limits."""
+
+import os
+import sys
+import types
+import unittest
+from unittest.mock import MagicMock, patch
+
+
+def _stub_heavy_deps():
+    """Register stand-in modules for heavy imports that are not loaded yet."""
+    for name in [
+        "torch",
+        "numpy",
+        "langchain_community",
+        "langchain_community.vectorstores",
+        "langchain_community.embeddings",
+        "transformers",
+        "rank_bm25",
+        "pdf_parse_worker",
+    ]:
+        if name not in sys.modules:
+            sys.modules[name] = types.ModuleType(name)
+
+    torch_stub = sys.modules["torch"]
+    torch_stub.no_grad = lambda: _NullCtx()
+    torch_stub.cuda = types.SimpleNamespace(is_available=lambda: False)
+
+    tf = sys.modules["transformers"]
+    for attr in [
+        "AutoConfig",
+        "AutoTokenizer",
+        "AutoModelForSeq2SeqLM",
+        "AutoModelForCausalLM",
+        "TextIteratorStreamer",
+    ]:
+        setattr(tf, attr, MagicMock())
+
+    sys.modules["rank_bm25"].BM25Okapi = MagicMock()
+    sys.modules["pdf_parse_worker"]._extract_pdf_text_worker = MagicMock()
+
+    lc_vs = sys.modules["langchain_community.vectorstores"]
+    lc_vs.FAISS = MagicMock()
+    lc_emb = sys.modules["langchain_community.embeddings"]
+    lc_emb.HuggingFaceEmbeddings = MagicMock()
+
+
+class _NullCtx:
+    """No-op context manager returned by the stubbed ``torch.no_grad()``."""
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, *_):
+        return False
+
+
+# Set before importing main, which presumably reads them at import time.
+os.environ.setdefault("JWT_SECRET", "test-secret-for-ci")
+os.environ.setdefault("INTERNAL_RAG_TOKEN", "test-secret")
+
+_stub_heavy_deps()
+
+sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
+import main  # noqa: E402
+
+
+class _FakeEmbeddingModel:
+    """Embedding-model double that records every ``embed_documents`` call."""
+
+    def __init__(self, embeddings):
+        self.embeddings = embeddings
+        self.calls = []
+
+    def embed_documents(self, texts):
+        self.calls.append(list(texts))
+        return self.embeddings
+
+
+class TestSemanticMergeLimits(unittest.TestCase):
+    """``_split_pass2`` merges small inputs and skips oversized ones."""
+
+    def test_split_pass2_merges_normal_small_input(self):
+        fake_model = _FakeEmbeddingModel([[1.0, 0.0], [1.0, 0.0]])
+
+        with patch.object(main, "get_embedding_model", return_value=fake_model):
+            result = main._split_pass2(
+                ["alpha", "beta"],
+                threshold=0.5,
+                merge_min=32,
+                merge_max=128,
+            )
+
+        self.assertEqual(result, ["alpha beta"])
+        self.assertEqual(len(fake_model.calls), 1)
+
+    def test_split_pass2_skips_merge_when_tiny_chunk_limit_is_exceeded(self):
+        fake_model = _FakeEmbeddingModel([])
+        raw_chunks = [f"tiny-{idx}" for idx in range(5)]
+
+        with patch.object(main, "SEMANTIC_CHUNK_MAX_TINY_CHUNKS", 3), \
+             patch.object(main, "get_embedding_model", return_value=fake_model):
+            result = main._split_pass2(
+                raw_chunks,
+                threshold=0.5,
+                merge_min=64,
+                merge_max=256,
+            )
+
+        self.assertEqual(result, raw_chunks)
+        self.assertEqual(fake_model.calls, [])
+
+    def test_split_pass2_skips_merge_when_candidate_limit_is_exceeded(self):
+        fake_model = _FakeEmbeddingModel([])
+        # The filler chunks must be well above merge_min so that (presumably)
+        # only "a", "b" and "c" count as tiny; the candidate total then
+        # exceeds the tiny-chunk total and only the candidate guard can fire.
+        long_chunk = " ".join(["long"] * 40)
+        raw_chunks = ["a", long_chunk, "b", long_chunk, "c"]
+
+        with patch.object(main, "SEMANTIC_CHUNK_MAX_TINY_CHUNKS", 10), \
+             patch.object(main, "SEMANTIC_CHUNK_MAX_MERGE_CANDIDATES", 4), \
+             patch.object(main, "get_embedding_model", return_value=fake_model):
+            result = main._split_pass2(
+                raw_chunks,
+                threshold=0.5,
+                merge_min=32,
+                merge_max=256,
+            )
+
+        self.assertEqual(result, raw_chunks)
+        self.assertEqual(fake_model.calls, [])
+
+
+if __name__ == "__main__":
+    unittest.main()