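diff --git a/docs/tutorials/v2-complete-tutorial.md b/docs/tutorials/v2-complete-tutorial.md
new file mode 100644
index 0000000..707440b
--- /dev/null
+++ b/docs/tutorials/v2-complete-tutorial.md
@@ -0,0 +1,993 @@
+# lang2sql v2 Complete Tutorial
+
+This document covers only v2, which is based on `src/lang2sql`.
+Follow the steps in order to test every supported path yourself, from beginner to advanced.
+
+- Increasing difficulty: sections get harder as you scroll down.
+- Code examples are written against the current implementation in the repo.
+- Out-of-scope features (e.g. built-in FAISS/PGVector in v2) are covered only through the "custom adapter" approach.
+
+---
+
+## Table of Contents
+
+1. Goals and Scope
+1-1. Why lang2sql
+2. Prerequisites
+3. Installation
+4. API Key Setup
+5. Preparing the Sample DB
+5-1. Generating Sample Documents Automatically
+6. The Easiest Local Smoke Test (No API Key)
+7. Basic BaselineNL2SQL Usage (KeywordRetriever)
+8. Connecting a Real LLM (OpenAI / Anthropic)
+9. VectorRetriever Basics (Quick Start)
+10. Document Parsing: MarkdownLoader / PlainTextLoader / DirectoryLoader / PDFLoader
+11. Explicit Pipeline: the from_chunks() Pattern
+12. Swapping Chunking Strategies: Recursive vs Semantic
+13. HybridRetriever / HybridNL2SQL
+14. Embedding Swap Test (v2 Built-in + User Implementation)
+15. Vector Store Swap Test (v2 Built-in + User Implementation)
+16. Fully Manual Advanced Flow Composition
+17. Observability (Tracing) and Debugging
+18. Best Practices Checklist
+19. Troubleshooting
+
+---
+
+## 1) Goals and Scope
+
+Goals of this tutorial:
+
+- Practice the v2 core from first install to the end
+- Verify the basic flow, vector indexing, document loading, and hybrid search
+- Swap the extension points for advanced users (Embedding/VectorStore/Chunker) and test them directly
+
+Important scope notes:
+
+- In this document, "official v2 built-in" means only the following.
+  - Embedding: `OpenAIEmbedding`
+  - Vector store: `InMemoryVectorStore`
+- Everything else is tested through Protocol-based "user-implemented adapters".
+
+---
+
+## 1-1) Why lang2sql
+
+Compared with other libraries, v2 emphasizes the following points.
+
+- **Operations-friendly baseline**: the `Retriever -> Generator -> Executor` path is short and its failure points are clear.
+- **Explicit indexing pipeline**: with the `chunker.split(docs)` → `VectorRetriever.from_chunks(chunks)` pattern, each split/embed/store step is visible in code.
+- **Separated extension points**: the core is Protocol-based, so the flow code stays the same when you swap the embedding, vector store, or chunker.
+- **Built-in observability**: Hook events (`start/end/error`, duration) can be collected per component.
+
+Notes:
+- v2 does not aim to be "a giant framework that implements every feature itself".
+- It focuses on core orchestration and operational stability, and treats advanced backends as swappable adapters.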
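The "separated extension points" idea above can be illustrated with a small sketch. It assumes nothing about lang2sql internals: `EmbeddingLike`, `LengthEmbedding`, `VowelEmbedding`, and `index_texts` are hypothetical names made up for this illustration, and only the `embed_query` / `embed_texts` method names mirror the embedding adapters shown later in this tutorial. The real `EmbeddingPort` definition may differ.

```python
from typing import Protocol


class EmbeddingLike(Protocol):
    """Hypothetical stand-in for an embedding port (not lang2sql's own definition)."""

    def embed_query(self, text: str) -> list[float]: ...

    def embed_texts(self, texts: list[str]) -> list[list[float]]: ...


class LengthEmbedding:
    """Toy embedding: [character count, word count]."""

    def embed_query(self, text: str) -> list[float]:
        return [float(len(text)), float(len(text.split()))]

    def embed_texts(self, texts: list[str]) -> list[list[float]]:
        return [self.embed_query(t) for t in texts]


class VowelEmbedding:
    """Second toy embedding: [vowel count, digit count]."""

    def embed_query(self, text: str) -> list[float]:
        vowels = sum(ch in "aeiou" for ch in text.lower())
        digits = sum(ch.isdigit() for ch in text)
        return [float(vowels), float(digits)]

    def embed_texts(self, texts: list[str]) -> list[list[float]]:
        return [self.embed_query(t) for t in texts]


def index_texts(embedding: EmbeddingLike, texts: list[str]) -> dict[str, list[float]]:
    """Flow code: written once against the protocol, unaware of the concrete class."""
    return dict(zip(texts, embedding.embed_texts(texts)))


if __name__ == "__main__":
    sample = ["orders table", "revenue 2024"]
    # The same flow function runs with either implementation.
    print(index_texts(LengthEmbedding(), sample))
    print(index_texts(VowelEmbedding(), sample))
```

Because `index_texts` depends only on the structural protocol, neither toy class needs to inherit from anything. This is the property that lets an adapter be swapped without touching the flow code.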
+
+---
+
+## 2) Prerequisites
+
+Recommended environment:
+
+- Python 3.11+
+- `uv` or `pip`
+- (Optional) OpenAI API key, Anthropic API key
+
+---
+
+## 3) Installation
+
+### Option A: pip
+```bash
+pip install lang2sql
+```
+
+### Option B: development install from source
+```bash
+uv venv --python 3.11
+source .venv/bin/activate
+uv pip install -e .
+```
+
+---
+
+## 4) API Key Setup
+
+The OpenAI and Anthropic SDKs read environment variables by default.
+
+```bash
+export OPENAI_API_KEY="sk-..."
+export ANTHROPIC_API_KEY="sk-ant-..."
+```
+
+---
+
+## 5) Preparing the Sample DB
+
+To reproduce the whole tutorial, create the sample DB first.
+
+```bash
+python scripts/setup_sample_db.py
+```
+
+When it finishes, `sample.db` is created in the project root.
+
+---
+
+## 5-1) Generating Sample Documents Automatically
+
+This step generates the files used in the document loader, chunking, and vector indexing exercises.
+
+```bash
+python scripts/setup_sample_docs.py
+```
+
+Output locations (default):
+- `docs/business/revenue.md`
+- `docs/business/order_status_policy.md`
+- `docs/business/rules.txt`
+
+To overwrite files that already exist:
+
+```bash
+python scripts/setup_sample_docs.py --force
+```
+
+---
+
+## 6) The Easiest Local Smoke Test (No API Key)
+
+First, confirm that the pipeline structure works without any external dependencies.
+
+```python
+from lang2sql import BaselineNL2SQL
+
+# 1) Test double that mimics an LLM
+class FakeLLM:
+    def invoke(self, messages):
+        # SQLGenerator expects a ```sql ... ``` block.
+        return "```sql\nSELECT 1 AS ok\n```"
+
+# 2) Test double that mimics a DB
+class FakeDB:
+    def execute(self, sql):
+        # Receives the SQL run by SQLExecutor and returns a fixed result
+        return [{"ok": 1, "sql_received": sql}]
+
+catalog = [
+    {
+        "name": "orders",
+        "description": "Orders table",
+        "columns": {"order_id": "Order ID", "amount": "Order amount"},
+    }
+]
+
+pipeline = BaselineNL2SQL(
+    catalog=catalog,
+    llm=FakeLLM(),  # test without an external API
+    db=FakeDB(),  # test without a real DB
+    db_dialect="sqlite",
+)
+
+rows = pipeline.run("How many orders are there?")
+print(rows)
+```
+
+Purpose of this step:
+
+- Confirm there are no install/import problems
+- Confirm the basic `Retriever -> Generator -> Executor` path
+
+---
+
+## 7) Basic BaselineNL2SQL Usage (KeywordRetriever)
+
+Now connect to a real DB.
+
+```python
+from lang2sql import BaselineNL2SQL
+from lang2sql.integrations.db import SQLAlchemyDB
+from lang2sql.integrations.llm import OpenAILLM
+
+catalog = [
+    {
+        "name": "orders",
+        "description": "Customer order information",
+        "columns": {
+            "order_id": "Unique order ID",
+            "customer_id": "Customer ID",
+            "order_date": "Order timestamp",
+            "amount": "Order amount",
+            "status": "Order status",
+        },
+    },
+    {
+        "name": "customers",
+        "description": "Customer master",
+        "columns": {
+            "customer_id": "Customer ID",
+            "name": "Customer name",
+            "grade": "Customer grade",
+        },
+    },
+]
+
+pipeline = BaselineNL2SQL(
+    catalog=catalog,
+    llm=OpenAILLM(model="gpt-4o-mini"),
+    db=SQLAlchemyDB("sqlite:///sample.db"),
+    db_dialect="sqlite",
+)
+
+rows = pipeline.run("Number of orders last month")
+print(rows)
+```
+
+Notes:
+
+- `BaselineNL2SQL` currently uses a keyword-based retriever internally.
+- For a vector-search-based flow, use `HybridNL2SQL` or the manual composition described below.
+
+---
+
+## 8) Connecting a Real LLM (OpenAI / Anthropic)
+
+The LLM backend is swappable.
+
+### OpenAI LLM
+```python
+from lang2sql.integrations.llm import OpenAILLM
+llm = OpenAILLM(model="gpt-4o-mini")
+```
+
+### Anthropic LLM
+```python
+from lang2sql.integrations.llm import AnthropicLLM
+llm = AnthropicLLM(model="claude-sonnet-4-6")
+```
+
+Both follow the `LLMPort.invoke(messages)` contract, so the flow code is identical.
+
+---
+
+## 9) VectorRetriever Basics
+
+Two construction patterns are available. Choose the one that fits your situation.
+
+### 9-1. from_sources(): one-touch (quick start)
+
+`VectorRetriever.from_sources()` handles split/embed/store in a single call.
+
+```python
+from lang2sql import VectorRetriever
+from lang2sql.integrations.embedding import OpenAIEmbedding
+
+catalog = [
+    {
+        "name": "orders",
+        "description": "Order information table",
+        "columns": {
+            "order_id": "Order ID",
+            "amount": "Order amount",
+            "discount_amount": "Discount amount",
+            "order_date": "Order date",
+        },
+    }
+]
+
+docs = [
+    {
+        "id": "biz_rules",
+        "title": "Revenue definition",
+        "content": "Revenue is net sales excluding returns. 
Discount amounts use the discount_amount column.",
+        "source": "docs/biz_rules.md",
+    }
+]
+
+retriever = VectorRetriever.from_sources(
+    catalog=catalog,
+    documents=docs,
+    embedding=OpenAIEmbedding(model="text-embedding-3-small"),
+    top_n=5,
+    score_threshold=0.0,
+)
+
+result = retriever.run("Discounted revenue last month")
+print("schemas:", [s["name"] for s in result.schemas])
+print("context:", result.context)
+```
+
+What happens internally:
+1. catalog and docs are split with `CatalogChunker` and `RecursiveCharacterChunker` respectively
+2. `from_chunks()` is called to embed + store
+3. A search-ready `VectorRetriever` is returned
+
+### 9-2. from_chunks(): explicit pipeline (LangChain style)
+
+Use this when you want to control the split step yourself.
+
+```python
+from lang2sql import CatalogChunker, RecursiveCharacterChunker, VectorRetriever
+from lang2sql.integrations.embedding import OpenAIEmbedding
+
+# The split step is visible in code
+catalog_chunks = CatalogChunker().split(catalog)
+doc_chunks = RecursiveCharacterChunker(chunk_size=800, chunk_overlap=80).split(docs)
+
+# Combine chunks freely
+retriever = VectorRetriever.from_chunks(
+    catalog_chunks + doc_chunks,
+    embedding=OpenAIEmbedding(model="text-embedding-3-small"),
+    top_n=5,
+)
+
+result = retriever.run("Discounted revenue last month")
+print("schemas:", [s["name"] for s in result.schemas])
+print("context:", result.context)
+```
+
+Advantages of `from_chunks()`:
+- Sources other than catalog/doc can be merged in freely by creating `IndexedChunk` objects directly
+- Easy to combine with a custom chunker
+- Incremental additions follow the same pattern: `retriever.add(chunker.split(new_docs))`
+
+---
+
+## 10) Document Parsing: MarkdownLoader / PlainTextLoader / DirectoryLoader / PDFLoader
+
+Instead of writing the document list by hand, you can read documents from files.
+
+### 10-1. MarkdownLoader
+```python
+from lang2sql import MarkdownLoader
+
+docs = MarkdownLoader().load("docs/business/revenue.md")
+print(docs[0]["id"], docs[0]["title"], docs[0]["source"])
+```
+
+### 10-2. PlainTextLoader
+```python
+from lang2sql import PlainTextLoader
+
+docs = PlainTextLoader().load("docs/business/rules.txt")
+print(docs[0]["id"], docs[0]["title"], docs[0]["source"])
+```
+
+### 10-3. 
DirectoryLoader (recommended)
+```python
+from lang2sql import DirectoryLoader
+
+# Default mapping:
+# .md  -> MarkdownLoader
+# .txt -> PlainTextLoader
+docs = DirectoryLoader("docs/business").load()
+print("loaded docs:", len(docs))
+for d in docs[:3]:
+    print(d["id"], d["source"])
+```
+
+### 10-4. Connecting loader output to vector indexing
+```python
+from lang2sql import VectorRetriever, DirectoryLoader
+from lang2sql.integrations.embedding import OpenAIEmbedding
+
+docs = DirectoryLoader("docs/business").load()
+
+retriever = VectorRetriever.from_sources(
+    catalog=catalog,
+    documents=docs,
+    embedding=OpenAIEmbedding(),
+)
+```
+
+### 10-5. Spelling out the Loader → split → from_chunks flow in code
+
+```python
+from lang2sql import (
+    CatalogChunker,
+    DirectoryLoader,
+    RecursiveCharacterChunker,
+    VectorRetriever,
+)
+from lang2sql.integrations.embedding import OpenAIEmbedding
+
+catalog = [
+    {
+        "name": "orders",
+        "description": "Order information",
+        "columns": {
+            "order_id": "Order ID",
+            "order_date": "Order timestamp",
+            "amount": "Payment amount",
+            "discount_amount": "Discount amount",
+        },
+    }
+]
+
+# 1) document loader
+docs = DirectoryLoader("docs/business").load()
+
+# 2) split each source explicitly
+catalog_chunks = CatalogChunker().split(catalog)
+doc_chunks = RecursiveCharacterChunker(chunk_size=800, chunk_overlap=80).split(docs)
+
+# 3) from_chunks: embed + store in one call
+retriever = VectorRetriever.from_chunks(
+    catalog_chunks + doc_chunks,
+    embedding=OpenAIEmbedding(model="text-embedding-3-small"),
+    top_n=5,
+)
+
+result = retriever.run("Rules for calculating last month's net revenue")
+print("total chunks:", len(catalog_chunks) + len(doc_chunks))
+print("schemas:", [s["name"] for s in result.schemas])
+print("context sample:", result.context[:2])
+```
+
+Summary:
+- `DirectoryLoader` produces `TextDocument` objects.
+- `chunker.split(docs)` returns `list[IndexedChunk]`.
+- `from_chunks()` handles embed + upsert + registry.
+- `VectorRetriever` only performs search at query time.
+
+### 10-6. Fully manual flow (inspecting the internals directly)
+
+To see `chunk → embed → vectorstore.upsert` with your own eyes, run the steps directly as shown below.
+
+```python
+from lang2sql import CatalogChunker, RecursiveCharacterChunker, VectorRetriever
+from lang2sql.integrations.embedding import OpenAIEmbedding
+from lang2sql.integrations.vectorstore import InMemoryVectorStore
+
+# 1) chunk: batch call to .split()
+catalog_chunks = CatalogChunker(max_columns_per_chunk=20).split(catalog)
+doc_chunks = RecursiveCharacterChunker(chunk_size=800, chunk_overlap=80).split(docs)
+chunks = catalog_chunks + doc_chunks
+
+# 2) embed
+embedding = OpenAIEmbedding(model="text-embedding-3-small")
+texts = [c["text"] for c in chunks]
+vectors = embedding.embed_texts(texts)
+
+# 3) save to the vector store (upsert)
+store = InMemoryVectorStore()
+ids = [c["chunk_id"] for c in chunks]
+store.upsert(ids, vectors)
+
+# 4) build the registry
+registry = {c["chunk_id"]: c for c in chunks}
+
+# 5) verify retrieval
+retriever = VectorRetriever(
+    vectorstore=store,
+    embedding=embedding,
+    registry=registry,
+    top_n=5,
+)
+result = retriever.run("Rules for calculating last month's net revenue")
+print("schemas:", [s["name"] for s in result.schemas])
+print("context:", result.context[:2])
+```
+
+### 10-7. PDFLoader: indexing PDF files
+
+PDF support is provided as an opt-in under `integrations.loaders` (requires `pip install pymupdf`).
+
+```python
+from lang2sql import CatalogChunker, DirectoryLoader, MarkdownLoader, VectorRetriever
+from lang2sql.integrations.embedding import OpenAIEmbedding
+from lang2sql.integrations.loaders import PDFLoader
+
+# Register PDFLoader with DirectoryLoader
+docs = DirectoryLoader(
+    "docs/",
+    loaders={
+        ".md": MarkdownLoader(),
+        ".pdf": PDFLoader(),
+    },
+).load()
+
+# From here on, same as the usual from_chunks pattern
+from lang2sql import RecursiveCharacterChunker
+
+chunks = (
+    CatalogChunker().split(catalog)
+    + RecursiveCharacterChunker().split(docs)
+)
+retriever = VectorRetriever.from_chunks(
+    chunks,
+    embedding=OpenAIEmbedding(model="text-embedding-3-small"),
+)
+```
+
+PDFLoader creates one `TextDocument` per page:
+- `id`: `"{filename}__p{page_number}"` (1-indexed)
+- `title`: `"{filename} page {page_number}"`
+- `content`: text extracted from that page
+
+---
+
+## 11) Explicit Pipeline: the from_chunks() Pattern
+
+Advanced users control each split/embed/store step explicitly in code.
+
+### 11-1. Basic from_chunks() pattern
+
+```python
+from lang2sql import CatalogChunker, RecursiveCharacterChunker, VectorRetriever
+from lang2sql.integrations.embedding import OpenAIEmbedding
+
+# 1) split each source explicitly
+catalog_chunks = CatalogChunker().split(catalog)
+doc_chunks = RecursiveCharacterChunker().split(docs)
+
+# 2) from_chunks: embed + store + registry handled automatically
+retriever = VectorRetriever.from_chunks(
+    catalog_chunks + doc_chunks,
+    embedding=OpenAIEmbedding(model="text-embedding-3-small"),
+    top_n=5,
+)
+
+result = retriever.run("Discounted revenue")
+print(result.schemas)
+print(result.context)
+```
+
+### 11-2. Using it with a custom VectorStore
+
+```python
+from lang2sql.integrations.vectorstore import InMemoryVectorStore
+
+store = InMemoryVectorStore()
+
+retriever = VectorRetriever.from_chunks(
+    catalog_chunks + doc_chunks,
+    embedding=OpenAIEmbedding(model="text-embedding-3-small"),
+    vectorstore=store,  # inject the store (a custom store goes here)
+    top_n=5,
+    score_threshold=0.2,
+)
+```
+
+### 11-3. Incremental additions (add)
+
+`add()` accepts only a pre-split `list[IndexedChunk]`. You must split before adding.
+
+```python
+# Initial indexing of the catalog
+retriever = VectorRetriever.from_chunks(
+    CatalogChunker().split(catalog),
+    embedding=OpenAIEmbedding(model="text-embedding-3-small"),
+)
+
+# Add documents incrementally later
+new_docs = DirectoryLoader("docs/new").load()
+retriever.add(RecursiveCharacterChunker().split(new_docs))
+
+result = retriever.run("Discounted revenue")
+print(result.schemas)
+```
+
+Best practice:
+
+- `from_chunks()` handles embed + upsert internally, so there is no need to manage the store/registry yourself
+- catalog and doc chunks can be merged freely with Python list `+`
+- Always pass the result of `chunker.split(docs)` to `add()`
+
+---
+
+## 12) Swapping Chunking Strategies: Recursive vs Semantic
+
+### 12-1. Default chunker (RecursiveCharacterChunker)
+
+With `from_sources()` (the one-touch pattern), pass it through the `splitter` parameter.
+
+```python
+from lang2sql import VectorRetriever, RecursiveCharacterChunker
+from lang2sql.integrations.embedding import OpenAIEmbedding
+
+chunker = RecursiveCharacterChunker(
+    chunk_size=1000,
+    chunk_overlap=100,  # must be smaller than chunk_size
+)
+
+retriever = VectorRetriever.from_sources(
+    catalog=catalog,
+    documents=docs,
+    embedding=OpenAIEmbedding(),
+    splitter=chunker,  # the parameter is splitter, not document_chunker
+)
+```
+
+With `from_chunks()` (the explicit pattern), call `.split()` directly.
+
+```python
+doc_chunks = RecursiveCharacterChunker(chunk_size=1000, chunk_overlap=100).split(docs)
+retriever = VectorRetriever.from_chunks(
+    CatalogChunker().split(catalog) + doc_chunks,
+    embedding=OpenAIEmbedding(),
+)
+```
+
+### 12-2. 
Semantic chunker (SemanticChunker, opt-in)
+
+```python
+from lang2sql import CatalogChunker, VectorRetriever
+from lang2sql.integrations.chunking import SemanticChunker
+from lang2sql.integrations.embedding import OpenAIEmbedding
+
+embedding = OpenAIEmbedding(model="text-embedding-3-small")
+
+semantic_chunker = SemanticChunker(
+    embedding=embedding,  # the embedding is also called during chunking
+    breakpoint_threshold=0.3,
+    min_chunk_size=100,
+)
+
+# from_chunks pattern: use the chunker directly for split
+doc_chunks = semantic_chunker.split(docs)
+retriever = VectorRetriever.from_chunks(
+    CatalogChunker().split(catalog) + doc_chunks,
+    embedding=embedding,
+)
+
+# Or the from_sources pattern: pass it through the splitter parameter
+retriever = VectorRetriever.from_sources(
+    catalog=catalog,
+    documents=docs,
+    embedding=embedding,
+    splitter=semantic_chunker,
+)
+```
+
+Notes:
+
+- SemanticChunker increases indexing cost and time.
+- Sentence splitting is punctuation/newline based, so tuning may be needed depending on the document format.
+
+---
+
+## 13) HybridRetriever / HybridNL2SQL
+
+`HybridRetriever` merges BM25 and Vector results with RRF to provide stable search results.
+
+### 13-1. Using the retriever on its own
+```python
+from lang2sql import HybridRetriever
+from lang2sql.integrations.embedding import OpenAIEmbedding
+
+retriever = HybridRetriever(
+    catalog=catalog,
+    embedding=OpenAIEmbedding(),
+    documents=docs,
+    top_n=5,
+    rrf_k=60,
+    score_threshold=0.0,
+)
+
+result = retriever.run("Discounted revenue last month")
+print("schemas:", [s["name"] for s in result.schemas])
+print("context:", result.context)
+```
+
+### 13-2. 
Using it directly as a flow (recommended)
+```python
+from lang2sql import HybridNL2SQL
+from lang2sql.integrations.db import SQLAlchemyDB
+from lang2sql.integrations.embedding import OpenAIEmbedding
+from lang2sql.integrations.llm import OpenAILLM
+
+pipeline = HybridNL2SQL(
+    catalog=catalog,
+    llm=OpenAILLM(model="gpt-4o-mini"),
+    db=SQLAlchemyDB("sqlite:///sample.db"),
+    embedding=OpenAIEmbedding(),
+    documents=docs,
+    db_dialect="sqlite",
+    top_n=5,
+)
+
+rows = pipeline.run("Discounted revenue last month")
+print(rows)
+```
+
+---
+
+## 14) Embedding Swap Test (v2 Built-in + User Implementation)
+
+v2 ships one built-in embedding, `OpenAIEmbedding`.
+However, if you implement a class that satisfies `EmbeddingPort`, you can test other embeddings right away.
+
+### 14-1. Built-in OpenAIEmbedding
+```python
+from lang2sql.integrations.embedding import OpenAIEmbedding
+embedding = OpenAIEmbedding(model="text-embedding-3-small")
+```
+
+### 14-2. FakeEmbedding for testing without an API key
+```python
+class FakeEmbedding:
+    # Simple embedding based on string length / keyword counts (for testing)
+    def _vec(self, text: str) -> list[float]:
+        return [
+            float(len(text)),
+            float(text.count("revenue")),
+            float(text.count("order")),
+            float(text.count("customer")),
+        ]
+
+    def embed_query(self, text: str) -> list[float]:
+        return self._vec(text)
+
+    def embed_texts(self, texts: list[str]) -> list[list[float]]:
+        return [self._vec(t) for t in texts]
+```
+
+### 14-3. External embedding adapter example (optional)
+```python
+class SentenceTransformerEmbedding:
+    def __init__(self, model_name: str = "sentence-transformers/all-MiniLM-L6-v2"):
+        from sentence_transformers import SentenceTransformer
+        self._model = SentenceTransformer(model_name)
+
+    def embed_query(self, text: str) -> list[float]:
+        return self._model.encode([text], normalize_embeddings=True)[0].tolist()
+
+    def embed_texts(self, texts: list[str]) -> list[list[float]]:
+        return self._model.encode(texts, normalize_embeddings=True).tolist()
+```
+
+---
+
+## 15) Vector Store Swap Test (v2 Built-in + User Implementation)
+
+v2 ships one built-in VectorStore, `InMemoryVectorStore`.
+However, any backend can be connected as long as it satisfies `VectorStorePort`.
+
+### 15-1. 
Built-in InMemoryVectorStore
+```python
+from lang2sql.integrations.vectorstore import InMemoryVectorStore
+store = InMemoryVectorStore()
+```
+
+### 15-2. User-implemented VectorStore adapter (for testing)
+The code below is a minimal implementation for verifying that the swap actually works.
+
+```python
+class TinyVectorStore:
+    """
+    Minimal VectorStore implementation for learning/testing.
+    Stores id->vector in memory and performs brute-force cosine search.
+    """
+
+    def __init__(self):
+        self._rows = {}
+
+    def upsert(self, ids: list[str], vectors: list[list[float]]) -> None:
+        for i, v in zip(ids, vectors):
+            self._rows[i] = v
+
+    def search(self, vector: list[float], k: int) -> list[tuple[str, float]]:
+        import math
+
+        def cosine(a, b):
+            dot = sum(x * y for x, y in zip(a, b))
+            na = math.sqrt(sum(x * x for x in a)) + 1e-8
+            nb = math.sqrt(sum(y * y for y in b)) + 1e-8
+            return dot / (na * nb)
+
+        ranked = sorted(
+            ((i, cosine(v, vector)) for i, v in self._rows.items()),
+            key=lambda x: x[1],
+            reverse=True,
+        )
+        return ranked[:k]
+```
+
+### 15-3. Swapping only the store in the same code
+```python
+from lang2sql import VectorRetriever
+
+# A) built-in store
+store_a = InMemoryVectorStore()
+
+# B) user-implemented store
+store_b = TinyVectorStore()
+
+# The rest of the code (from_chunks/VectorRetriever) stays the same
+```
+
+What this means:
+
+- The search policy (lang2sql core) stays the same
+- Only the storage implementation is swapped
+
+---
+
+## 16) Fully Manual Advanced Flow Composition
+
+Below is a pattern that advanced users rely on often in practice.
+
+```python
+from lang2sql import (
+    CatalogChunker,
+    DirectoryLoader,
+    RecursiveCharacterChunker,
+    SQLExecutor,
+    SQLGenerator,
+    VectorRetriever,
+)
+from lang2sql.integrations.db import SQLAlchemyDB
+from lang2sql.integrations.embedding import OpenAIEmbedding
+from lang2sql.integrations.llm import OpenAILLM
+
+# 1) load documents
+docs = DirectoryLoader("docs/business").load()
+
+# 2) explicit pipeline: split → from_chunks
+embedding = OpenAIEmbedding(model="text-embedding-3-small")
+
+chunks = (
+    CatalogChunker().split(catalog)
+    + RecursiveCharacterChunker().split(docs)
+)
+
+retriever = VectorRetriever.from_chunks(
+    chunks,
+    embedding=embedding,
+    top_n=5,
+    score_threshold=0.2,
+)
+
+# 3) configure the generation / execution components individually
+generator = SQLGenerator(
+    llm=OpenAILLM(model="gpt-4o-mini"),
+    db_dialect="sqlite",
+)
+executor = SQLExecutor(db=SQLAlchemyDB("sqlite:///sample.db"))
+
+# 4) run the flow manually
+query = "Net revenue after discounts last month"
+retrieval = retriever.run(query)
+sql = generator.run(query, retrieval.schemas, context=retrieval.context)
+rows = executor.run(sql)
+
+print("SQL:", sql)
+print("Rows:", rows)
+```
+
+Advantages of this pattern:
+
+- The split step is visible in code, so tuning chunking parameters is intuitive
+- The result of every step can be observed
+- Threshold, chunking, embedding, and storage can be tuned independently
+- Failure points are isolated, which makes debugging easier
+
+---
+
+## 17) Observability (Tracing) and Debugging
+
+`MemoryHook` lets you trace component and flow events.
+
+```python
+from lang2sql import HybridNL2SQL, MemoryHook
+from lang2sql.integrations.db import SQLAlchemyDB
+from lang2sql.integrations.embedding import OpenAIEmbedding
+from lang2sql.integrations.llm import OpenAILLM
+
+hook = MemoryHook()
+
+pipeline = HybridNL2SQL(
+    catalog=catalog,
+    llm=OpenAILLM(model="gpt-4o-mini"),
+    db=SQLAlchemyDB("sqlite:///sample.db"),
+    embedding=OpenAIEmbedding(),
+    documents=docs,
+    db_dialect="sqlite",
+    top_n=5,
+    hook=hook,
+)
+
+pipeline.run("Number of orders last month")
+
+for e in hook.snapshot():
+    print(e.name, e.component, e.phase, e.duration_ms)
+```
+
+Recommendations from an operations perspective:
+
+- Record `duration_ms` per component to find bottlenecks
+- Collect `error` events to analyze failure patterns
+
+---
+
+## 18) Best Practices Checklist
+
+### Search / Indexing
+- Fill in at least `name`, `description`, and `columns` thoroughly for each `catalog` entry
+- Do not pack too many topics into one document file; split by topic
+- Start `top_n` in the 3~8 range and experiment
+- Start `score_threshold` at 0.0 and raise it gradually
+
+### Chunking
+- The default is `RecursiveCharacterChunker`
+- Consider `SemanticChunker` when document quality matters and the cost is acceptable
+- Always set `chunk_overlap` smaller than `chunk_size`
+
+### Choosing a flow
+- Quick start: `BaselineNL2SQL`
+- Search quality first: `HybridNL2SQL`
+- Full control: manual component composition
+
+### Operations
+- Store Hook events and monitor p95 metrics
+- Run regression tests regularly
+
+```bash
+pytest tests/test_components_vector_retriever.py -q
+pytest tests/test_components_hybrid_retriever.py -q
+pytest tests/test_components_loaders.py -q
+```
+
+---
+
+## 19) Troubleshooting
+
+### Q1. `IntegrationMissingError: openai`
+- Cause: the `openai` package is not installed
+- Fix:
+```bash
+pip install openai
+```
+
+### Q2. `chunk_overlap must be less than chunk_size`
+- Cause: incorrect `RecursiveCharacterChunker` parameters
+- Fix: change them so that `chunk_overlap < chunk_size`
+
+### Q3. VectorRetriever returns empty results
+- Check in this order:
+1. Whether `from_chunks(chunks, ...)` or `from_sources(catalog=..., ...)` was actually called
+2. Whether `len(retriever._registry) > 0`
+3. Whether `score_threshold` is too high (lower it to 0.0 and test)
+
+### Q4. Type error when calling `retriever.add()`
+- Cause: `add()` accepts only `list[IndexedChunk]`. Passing `TextDocument` directly raises an error.
+- Fix: always convert with `chunker.split(docs)` before adding:
+```python
+# ❌ does not work
+retriever.add(docs)
+
+# ✅ correct
+retriever.add(RecursiveCharacterChunker().split(docs))
+```
+
+### Q5. `IntegrationMissingError: pymupdf`
+- Cause: `pymupdf` is not installed while using `PDFLoader`
+- Fix:
+```bash
+pip install pymupdf
+```
+
+---
+
+## Wrapping Up
+
+Following this document in order lets you verify all of the paths below in practice.
+
+- Baseline keyword flow
+- VectorRetriever + document indexing
+- HybridRetriever / HybridNL2SQL
+- Swapping Loader/Chunker/Embedding/VectorStore
+- Manual Advanced Flow and tracing
+
+To get started quickly:
+
+1. Step 6 (local smoke test)
+2. Step 7 (Baseline)
+3. Step 13 (HybridNL2SQL)
+
+To go as far as advanced operational tuning:
+
+4. Continue through steps 11~16 (from_chunks / adapters / manual composition).
diff --git a/docs/tutorials/v2-usage-guide.md b/docs/tutorials/v2-usage-guide.md
new file mode 100644
index 0000000..341041a
--- /dev/null
+++ b/docs/tutorials/v2-usage-guide.md
@@ -0,0 +1,263 @@
+# lang2sql v2 Usage Guide
+
+This document covers only the new v2 API based on `src/lang2sql`.
+The legacy `engine/`, `interface/`, and `utils/llm/` paths are out of scope.
+
+For a detailed step-by-step walkthrough, see [v2-complete-tutorial.md](./v2-complete-tutorial.md).
+
+## 0) Why lang2sql
+
+- **Operations-friendly default path**: `Retriever -> Generator -> Executor` is simple and its debugging points are clear.
+- **Explicit indexing pipeline**: with the `chunker.split(docs)` → `VectorRetriever.from_chunks(chunks)` pattern, each split/embed/store step is visible in code.
+- **Minimal framework lock-in**: the core is based on Protocols (`EmbeddingPort`, `VectorStorePort`, `DocumentChunkerPort`), so implementations are easy to swap.
+- **Built-in observability**: Hooks (`TraceHook`, `MemoryHook`) collect per-component execution events out of the box.
+
+## 0-1) Preparing the Tutorial Data Automatically
+
+```bash
+python scripts/setup_sample_db.py
+python scripts/setup_sample_docs.py
+```
+
+After the documents are generated, the loader examples use the files under `docs/business` as they are.
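Before moving on, it can help to confirm that both setup scripts actually produced their outputs. The following is a minimal sketch using only the standard library. The file names are the ones listed in the complete tutorial's setup steps, and `missing_sample_files` is a hypothetical helper written for this illustration, not part of lang2sql.

```python
from pathlib import Path

# File names taken from the complete tutorial's setup steps;
# adjust this list if your project layout differs.
EXPECTED_FILES = [
    "sample.db",
    "docs/business/revenue.md",
    "docs/business/order_status_policy.md",
    "docs/business/rules.txt",
]


def missing_sample_files(root: str = ".") -> list[str]:
    """Return the expected sample files that do not exist under root."""
    base = Path(root)
    return [name for name in EXPECTED_FILES if not (base / name).is_file()]


if __name__ == "__main__":
    missing = missing_sample_files()
    if missing:
        print("Missing sample files:", ", ".join(missing))
    else:
        print("All sample files are in place.")
```

If anything is reported missing, rerunning the corresponding setup script from the project root is the first thing to try.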
+
+## 1) Features Actually Supported in v2
+
+### Flows
+- `BaselineNL2SQL`: default pipeline based on the BM25 `KeywordRetriever`
+- `HybridNL2SQL`: pipeline based on the BM25 + Vector `HybridRetriever`
+
+### Retrievers
+- `KeywordRetriever`
+- `VectorRetriever`
+- `HybridRetriever`
+
+### Vector / Embedding (v2 built-in)
+- Embedding: `OpenAIEmbedding` (one built-in)
+- Vector store: `InMemoryVectorStore` (one built-in)
+
+### Chunking / Loading
+- Chunkers: `CatalogChunker`, `RecursiveCharacterChunker`, `SemanticChunker`
+  - All provide a `.split(list)` method: LangChain-style batch input/output
+- Loaders: `MarkdownLoader`, `PlainTextLoader`, `DirectoryLoader`
+  - `PDFLoader` (optional, `pip install pymupdf`)
+
+### Extensibility (Protocol)
+- `EmbeddingPort`, `VectorStorePort`, `DocumentChunkerPort`, `DocumentLoaderPort`
+- In other words, user adapters can be connected in addition to the built-in implementations.
+
+## 2) Quick Selection Guide
+
+### The easiest start
+- Goal: check NL2SQL right after installing
+- Choice: `BaselineNL2SQL`
+- Characteristics: usable immediately, with no vector indexing
+
+### When you want to raise search quality quickly
+- Goal: compensate for the limits of keyword matching
+- Choice: `HybridNL2SQL` + `OpenAIEmbedding`
+- Characteristics: stable search quality from combining BM25 + Vector with RRF
+
+### When you need advanced control
+- Goal: fine-grained control over the chunking/embedding/indexing/search pipeline
+- Choice: `chunker.split()` + `VectorRetriever.from_chunks()` + manual component composition
+- Characteristics: supports incremental indexing and custom Chunker/VectorStore/Embedding integration
+
+## 3) Minimal Examples
+
+### A. BaselineNL2SQL (keyword-based)
+```python
+from lang2sql import BaselineNL2SQL
+from lang2sql.integrations.db import SQLAlchemyDB
+from lang2sql.integrations.llm import OpenAILLM
+
+catalog = [
+    {
+        "name": "orders",
+        "description": "order table",
+        "columns": {"order_id": "pk", "amount": "order amount"},
+    }
+]
+
+pipeline = BaselineNL2SQL(
+    catalog=catalog,
+    llm=OpenAILLM(model="gpt-4o-mini"),
+    db=SQLAlchemyDB("sqlite:///sample.db"),
+    db_dialect="sqlite",
+)
+
+rows = pipeline.run("Number of orders last month")
+print(rows)
+```
+
+### B. 
HybridNL2SQL (keyword + vector)
+```python
+from lang2sql import HybridNL2SQL
+from lang2sql.integrations.db import SQLAlchemyDB
+from lang2sql.integrations.embedding import OpenAIEmbedding
+from lang2sql.integrations.llm import OpenAILLM
+
+catalog = [
+    {
+        "name": "orders",
+        "description": "order table",
+        "columns": {"order_id": "pk", "amount": "order amount"},
+    }
+]
+
+docs = [
+    {
+        "id": "biz_rules",
+        "title": "Revenue definition",
+        "content": "Revenue is net sales excluding returns.",
+        "source": "docs/biz_rules.md",
+    }
+]
+
+pipeline = HybridNL2SQL(
+    catalog=catalog,
+    llm=OpenAILLM(model="gpt-4o-mini"),
+    db=SQLAlchemyDB("sqlite:///sample.db"),
+    embedding=OpenAIEmbedding(model="text-embedding-3-small"),
+    documents=docs,
+    db_dialect="sqlite",
+    top_n=5,
+)
+
+rows = pipeline.run("Net revenue last month")
+print(rows)
+```
+
+### C. Explicit pipeline: split → from_chunks (LangChain style)
+```python
+from lang2sql import (
+    CatalogChunker,
+    DirectoryLoader,
+    RecursiveCharacterChunker,
+    VectorRetriever,
+)
+from lang2sql.integrations.embedding import OpenAIEmbedding
+
+catalog = [
+    {
+        "name": "orders",
+        "description": "order table",
+        "columns": {"order_id": "pk", "amount": "order amount"},
+    }
+]
+
+# 1) load documents
+docs = DirectoryLoader("docs/business").load()
+
+# 2) split each source explicitly
+catalog_chunks = CatalogChunker().split(catalog)
+doc_chunks = RecursiveCharacterChunker(chunk_size=800, chunk_overlap=80).split(docs)
+
+# 3) merge the chunks and build the retriever (embed + store are automatic)
+retriever = VectorRetriever.from_chunks(
+    catalog_chunks + doc_chunks,
+    embedding=OpenAIEmbedding(model="text-embedding-3-small"),
+    top_n=5,
+)
+
+result = retriever.run("Net revenue calculation criteria")
+print("schemas:", [s["name"] for s in result.schemas])
+print("context:", result.context[:2])
+```
+
+Advantages of the explicit flow:
+
+1. The split step is visible in code: `chunker.split(docs)` is explicit
+2. catalog chunks + doc chunks can be combined freely as Python lists
+3. 
Users do not need to manage internal state such as `registry = {}` themselves
+
+For incremental additions, split the chunks first and then pass them in:
+
+```python
+new_docs = DirectoryLoader("docs/new").load()
+retriever.add(RecursiveCharacterChunker().split(new_docs))
+```
+
+### D. DirectoryLoader → HybridNL2SQL directly
+
+This is the most concise pattern: load the documents and pass them straight to HybridNL2SQL.
+
+```python
+from lang2sql import DirectoryLoader, HybridNL2SQL
+from lang2sql.integrations.db import SQLAlchemyDB
+from lang2sql.integrations.embedding import OpenAIEmbedding
+from lang2sql.integrations.llm import OpenAILLM
+
+docs = DirectoryLoader("docs/business").load()
+
+pipeline = HybridNL2SQL(
+    catalog=catalog,
+    llm=OpenAILLM(model="gpt-4o-mini"),
+    db=SQLAlchemyDB("sqlite:///sample.db"),
+    embedding=OpenAIEmbedding(model="text-embedding-3-small"),
+    documents=docs,
+    db_dialect="sqlite",
+)
+
+rows = pipeline.run("Net revenue last month")
+print(rows)
+```
+
+### E. PDFLoader: indexing PDF files
+
+PDF files are loaded with `PDFLoader` (requires `pip install pymupdf`).
+
+```python
+from lang2sql import DirectoryLoader, MarkdownLoader, VectorRetriever
+from lang2sql.integrations.embedding import OpenAIEmbedding
+from lang2sql.integrations.loaders import PDFLoader
+
+# Register PDFLoader with DirectoryLoader
+docs = DirectoryLoader(
+    "docs/",
+    loaders={
+        ".md": MarkdownLoader(),
+        ".pdf": PDFLoader(),
+    },
+).load()
+
+# Then index with the from_chunks pattern
+from lang2sql import CatalogChunker, RecursiveCharacterChunker
+
+chunks = (
+    CatalogChunker().split(catalog)
+    + RecursiveCharacterChunker().split(docs)
+)
+retriever = VectorRetriever.from_chunks(
+    chunks,
+    embedding=OpenAIEmbedding(model="text-embedding-3-small"),
+)
+```
+
+For PDFs, one `TextDocument` is created per page:
+- `id`: `"{filename}__p{page_number}"`
+- `title`: `"{filename} page {page_number}"`
+
+## 4) Important Current Limitations
+
+- The only officially provided built-in VectorStore in v2 is currently `InMemoryVectorStore`.
+- `BaselineNL2SQL` currently does not accept a `retriever` injection parameter.
+  - For a vector-based pipeline, use `HybridNL2SQL` or manual composition.
+- The `context` in a `VectorRetriever` result is currently `list[str]`.
+  - If you need structured document sources, look up `metadata` separately or add a custom wrapper.
+- `retriever.add()` **accepts only `list[IndexedChunk]`**; `TextDocument` cannot be passed directly.
+  - Always pass the result of splitting with `chunker.split(docs)` before adding:
+  ```python
+  # ❌ does not work
+  retriever.add(docs)
+
+  # ✅ correct
+  retriever.add(RecursiveCharacterChunker().split(docs))
+  ```
+
+## 5) Recommended Practice Order
+
+1. Steps 1~6 of [v2-complete-tutorial.md](./v2-complete-tutorial.md) for setup and the local smoke test
+2. Steps 7~8 of the same document for connecting a real DB/LLM
+3. Steps 9~13 of the same document for vector indexing / document parsing / chunking tuning
+4. Steps 14~18 of the same document for advanced composition and custom adapter tests
diff --git a/docs/tutorials/vector-retriever.md b/docs/tutorials/vector-retriever.md
new file mode 100644
index 0000000..bef59ac
--- /dev/null
+++ b/docs/tutorials/vector-retriever.md
@@ -0,0 +1,587 @@
+# VectorRetriever Tutorial: Improving NL2SQL Accuracy with Vector Similarity Search
+
+This tutorial is a step-by-step guide for first-time users of `VectorRetriever`.
+It explains how it differs from `KeywordRetriever` (BM25 keyword search), how to configure it, and how to connect it to a pipeline.
+
+---
+
+## Table of Contents
+
+1. KeywordRetriever vs VectorRetriever: When to Use Which?
+2. Installation
+3. Fastest Start: from_sources()
+4. Adding Business Documents as Context
+5. Connecting to a Pipeline: BaselineNL2SQL
+6. Adding to the Index Incrementally: add()
+7. Advanced: Using IndexBuilder Directly
+8. Advanced: Swapping the Chunker
+9. Tuning the Score Threshold and top_n
+10. Full Checklist: Running Without an API Key
+
+---
+
+## 1. KeywordRetriever vs VectorRetriever: When to Use Which?
+
+| | `KeywordRetriever` | `VectorRetriever` |
+|---|---|---|
+| **Search method** | BM25 keyword matching | Vector cosine similarity |
+| **Strengths** | Fast, no external dependencies | Strong on synonyms and semantically similar queries |
+| **Weaknesses** | Misses results when the question and column names differ | Requires an embedding API or model |
+| **Best suited for** | Small catalogs with clear column names | Large catalogs, or business terms that differ from column names |
+| **Business document support** | No | Yes (passed to the LLM through the `context` field) |
+
+> **Rule of thumb**: if asking about `"revenue"` does not retrieve the `amount` column, switch to VectorRetriever.
+
+---
+
+## 2. 
Installation
+
+```bash
+pip install lang2sql
+```
+
+`openai` is included in lang2sql's default dependencies, so no separate install is needed.
+
+> If you want to test without an embedding API, run the `FakeEmbedding` from **Section 10** first.
+
+---
+
+## 3. Fastest Start: from_sources()
+
+A single `VectorRetriever.from_sources()` call builds the index and lets you search immediately.
+
+```python
+from lang2sql import VectorRetriever, CatalogEntry
+from lang2sql.integrations.embedding import OpenAIEmbedding
+
+CATALOG: list[CatalogEntry] = [
+    {
+        "name": "orders",
+        "description": "Customer order information table. Used to query order counts, revenue, and dates.",
+        "columns": {
+            "order_id": "Unique order ID (PK)",
+            "customer_id": "ID of the ordering customer (FK → customers)",
+            "order_date": "Order timestamp (TIMESTAMP)",
+            "amount": "Order amount (DECIMAL)",
+            "status": "Order status: pending / confirmed / shipped / cancelled",
+        },
+    },
+    {
+        "name": "customers",
+        "description": "Customer master data. Used to query customer name, signup date, and grade.",
+        "columns": {
+            "customer_id": "Unique customer ID (PK)",
+            "name": "Customer name",
+            "grade": "Customer grade: bronze / silver / gold",
+        },
+    },
+]
+
+retriever = VectorRetriever.from_sources(
+    catalog=CATALOG,
+    embedding=OpenAIEmbedding(),  # requires the OPENAI_API_KEY environment variable
+)
+
+result = retriever("List of top customers by revenue")
+
+print(result.schemas)
+# [{'name': 'orders', ...}, {'name': 'customers', ...}]
+
+print(result.context)
+# [] (an empty list, because no documents were added)
+```
+
+Internally, `from_sources()` handles the following automatically:
+- Creates an `InMemoryVectorStore` (no external DB required)
+- Uses `IndexBuilder` to chunk the catalog → embed → store
+- Returns a search-ready `VectorRetriever`
+
+---
+
+## 4. Adding Business Documents as Context
+
+If you register business rules such as the definition of "revenue" or how a KPI is calculated as documents,
+the LLM refers to them when generating SQL.
+
+```python
+from lang2sql import TextDocument
+
+DOCS: list[TextDocument] = [
+    {
+        "id": "revenue_def",
+        "title": "Revenue definition",
+        "content": "Revenue is based on net sales, excluding returns. "
+                   "Orders with cancelled status are excluded from revenue.",
+        "source": "docs/revenue_definition.md",
+    },
+    {
+        "id": "grade_policy",
+        "title": "Customer grade policy",
+        "content": "gold grade: cumulative purchases of 500,000 KRW or more in the last 3 months. "
+                   "silver grade: 200,000 KRW or more. 
bronze: 그 외.", + "source": "docs/customer_grade.md", + }, +] + +retriever = VectorRetriever.from_sources( + catalog=CATALOG, + documents=DOCS, # ← 문서 동시 인덱싱 + embedding=OpenAIEmbedding(), +) + +result = retriever("이번 달 매출을 집계해줘") + +print(result.schemas) # 관련 테이블 목록 +print(result.context) # 관련 문서 텍스트 — LLM 프롬프트에 포함됨 +# ['매출 정의: 매출은 반품을 제외한 순매출...'] +``` + +> `result.context`의 내용은 `SQLGenerator`가 프롬프트에 "Business Context" 섹션으로 자동 삽입합니다. + +--- + +## 5. 파이프라인에 연결하기 — BaselineNL2SQL + +`BaselineNL2SQL`의 `retriever=` 파라미터로 `VectorRetriever`를 주입합니다. +기본 `KeywordRetriever`를 대체합니다. + +```python +from lang2sql import BaselineNL2SQL, VectorRetriever +from lang2sql.integrations.llm import AnthropicLLM +from lang2sql.integrations.db import SQLAlchemyDB +from lang2sql.integrations.embedding import OpenAIEmbedding + +# 1. VectorRetriever 준비 +retriever = VectorRetriever.from_sources( + catalog=CATALOG, + documents=DOCS, + embedding=OpenAIEmbedding(), +) + +# 2. 파이프라인에 주입 +pipeline = BaselineNL2SQL( + catalog=CATALOG, # KeywordRetriever 기본값용 (retriever 주입 시 무시됨) + llm=AnthropicLLM(model="claude-sonnet-4-6"), + db=SQLAlchemyDB("sqlite:///sample.db"), + db_dialect="sqlite", + retriever=retriever, # ← VectorRetriever 주입 +) + +result = pipeline.run("취소 제외한 이번 달 매출 합계") +print(result) +``` + +> `retriever=`가 주어지면 `catalog=`는 내부적으로 사용되지 않습니다. +> 하지만 API 일관성을 위해 `catalog=`를 함께 전달하는 것을 권장합니다. + +--- + +## 6. 인덱스 점진적으로 추가하기 — add() + +파이프라인이 실행 중일 때 새 문서를 동적으로 추가할 수 있습니다. +기존 카탈로그 인덱스는 그대로 유지됩니다. 
+ +```python +retriever = VectorRetriever.from_sources( + catalog=CATALOG, + embedding=OpenAIEmbedding(), +) + +# 나중에 문서가 생겼을 때 추가 +NEW_DOCS: list[TextDocument] = [ + { + "id": "discount_policy", + "title": "할인 정책", + "content": "VIP 고객(gold 등급)에게는 정가의 10% 할인을 적용한다.", + "source": "docs/discount.md", + }, +] + +retriever.add(NEW_DOCS) # 기존 카탈로그 인덱스 유지 + 새 문서 추가 + +result = retriever("VIP 고객 할인 금액 계산") +print(result.context) +# ['할인 정책: VIP 고객(gold 등급)에게는...'] +``` + +> `add()`는 `from_sources()`로 만든 retriever에서만 사용할 수 있습니다. +> 직접 생성한 경우엔 `IndexBuilder.run()`을 호출하세요 (섹션 7 참고). + +--- + +## 7. 고급 — IndexBuilder 직접 사용하기 + +여러 소스를 단계별로 인덱싱하거나, 커스텀 벡터 저장소를 쓰고 싶을 때 +`IndexBuilder`를 직접 조작합니다. + +```python +from lang2sql import VectorRetriever, IndexBuilder +from lang2sql.integrations.vectorstore import InMemoryVectorStore +from lang2sql.integrations.embedding import OpenAIEmbedding + +embedding = OpenAIEmbedding() +store = InMemoryVectorStore() +registry: dict = {} # IndexBuilder와 VectorRetriever가 공유하는 저장소 + +builder = IndexBuilder( + embedding=embedding, + vectorstore=store, + registry=registry, +) + +retriever = VectorRetriever( + vectorstore=store, + embedding=embedding, + registry=registry, # 같은 registry 공유 +) + +# 단계별 인덱싱 — 기존 데이터 유지됨 +builder.run(CATALOG) # 카탈로그 인덱싱 +builder.run(DOCS) # 문서 인덱싱 (카탈로그 유지) +builder.run(NEW_DOCS) # 추가 문서 (기존 모두 유지) + +result = retriever("매출 정의") +``` + +`from_sources()` 대비 직접 제어가 필요한 경우: +- 벡터 저장소를 외부 DB(FAISS 파일, pgvector)로 교체할 때 +- 인덱스를 디스크에 저장하고 재사용할 때 +- 카탈로그와 문서를 따로 스케줄링할 때 + +--- + +## 8. 고급 — 청커 교체하기 + +### 기본 청커 비교 + +| 청커 | 위치 | 특징 | +|------|------|------| +| `CatalogChunker` | `components/retrieval/chunker.py` | 테이블 헤더 + 컬럼 그룹으로 분할. 스키마 검색 전용. | +| `RecursiveCharacterChunker` | `components/retrieval/chunker.py` | 문단→줄→문장 순 재귀 분할. 외부 의존성 없음. | +| `SemanticChunker` | `integrations/chunking/semantic_.py` | 임베딩 기반 의미 단위 분할. 품질 우선 시 사용. 
| + +### SemanticChunker 사용하기 (opt-in) + +```bash +pip install sentence-transformers # 또는 openai 패키지 +``` + +```python +from lang2sql import VectorRetriever +from lang2sql.integrations.chunking import SemanticChunker +from lang2sql.integrations.embedding import OpenAIEmbedding + +embedding = OpenAIEmbedding() + +retriever = VectorRetriever.from_sources( + catalog=CATALOG, + documents=DOCS, + embedding=embedding, + document_chunker=SemanticChunker(embedding=embedding), # ← 의미 기반 청킹 +) +``` + +### LangChain 청커 어댑터 (외부 라이브러리 연결) + +```python +from langchain_text_splitters import RecursiveCharacterTextSplitter +from lang2sql import IndexedChunk, TextDocument + +class LangChainChunkerAdapter: + """LangChain 텍스트 스플리터를 lang2sql DocumentChunkerPort에 맞게 감쌉니다.""" + + def __init__(self, splitter): + self._splitter = splitter + + def chunk(self, doc: TextDocument) -> list[IndexedChunk]: + texts = self._splitter.split_text(doc["content"]) + title = doc.get("title", "") + return [ + IndexedChunk( + chunk_id=f"{doc['id']}__{i}", + text=f"{title}: {text}" if title else text, + source_type="document", + source_id=doc["id"], + chunk_index=i, + metadata={"title": title, "source": doc.get("source", "")}, + ) + for i, text in enumerate(texts) + ] + + +lc_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50) + +retriever = VectorRetriever.from_sources( + catalog=CATALOG, + documents=DOCS, + embedding=OpenAIEmbedding(), + document_chunker=LangChainChunkerAdapter(lc_splitter), +) +``` + +--- + +## 9. 점수 임계값과 top_n 조정 + +```python +retriever = VectorRetriever.from_sources( + catalog=CATALOG, + embedding=OpenAIEmbedding(), + top_n=3, # 반환할 최대 스키마/문서 수 (기본값: 5) + score_threshold=0.5, # 유사도가 이 값보다 낮은 결과는 제외 (기본값: 0.0) +) +``` + +| 파라미터 | 기본값 | 설명 | +|----------|--------|------| +| `top_n` | 5 | 반환하는 스키마(schemas)와 문서(context) 각각의 최대 수 | +| `score_threshold` | 0.0 | 이 값 **이하**의 유사도 점수는 결과에서 제외. 
낮은 관련성 결과를 걸러낼 때 사용 | + +> 관련 없는 테이블이 자꾸 검색된다면 `score_threshold`를 0.3~0.5 사이로 높여보세요. + +--- + +## 10. 전체 체크리스트 — API 키 없이 실행 + +아래 코드는 실제 임베딩 API 없이 `FakeEmbedding`으로 모든 기능을 확인합니다. + +```python +""" +VectorRetriever 전체 체크리스트 +API 키 없이 FakeEmbedding으로 실행 가능합니다. +""" + +# ── 0. FakeEmbedding 정의 ───────────────────────────────────────────────────── + +class FakeEmbedding: + """테스트용 고정 벡터 임베딩. 실제 유사도 계산은 하지 않습니다.""" + def embed_query(self, text: str) -> list[float]: + return [0.1, 0.2, 0.3, 0.4] + + def embed_texts(self, texts: list[str]) -> list[list[float]]: + return [[0.1, 0.2, 0.3, 0.4]] * len(texts) + + +# ── 1. 카탈로그와 문서 준비 ────────────────────────────────────────────────── + +from lang2sql import CatalogEntry, TextDocument + +CATALOG: list[CatalogEntry] = [ + { + "name": "orders", + "description": "고객 주문 정보 테이블", + "columns": {"order_id": "PK", "amount": "금액", "status": "상태"}, + }, + { + "name": "customers", + "description": "고객 마스터 데이터", + "columns": {"customer_id": "PK", "name": "이름", "grade": "등급"}, + }, +] + +DOCS: list[TextDocument] = [ + { + "id": "revenue_def", + "title": "매출 정의", + "content": "매출은 반품 제외 순매출이며 cancelled 주문은 제외한다.", + "source": "docs/revenue.md", + }, +] + + +# ── 2. from_sources() — 카탈로그만 ─────────────────────────────────────────── + +from lang2sql import VectorRetriever + +retriever = VectorRetriever.from_sources( + catalog=CATALOG, + embedding=FakeEmbedding(), +) + +result = retriever("주문 건수") +print("✓ from_sources() — 카탈로그만") +print(f" schemas: {[s['name'] for s in result.schemas]}") +print(f" context: {result.context}") +assert isinstance(result.schemas, list) +assert result.context == [] + + +# ── 3. 
from_sources() — 문서 포함 ──────────────────────────────────────────── + +retriever2 = VectorRetriever.from_sources( + catalog=CATALOG, + documents=DOCS, + embedding=FakeEmbedding(), +) + +result2 = retriever2("매출 정의") +print("\n✓ from_sources() — 문서 포함") +print(f" schemas: {[s['name'] for s in result2.schemas]}") +print(f" context: {result2.context}") +assert len(result2.context) >= 1 + + +# ── 4. add() — 점진적 인덱싱 ───────────────────────────────────────────────── + +initial_count = len(retriever._registry) + +NEW_DOC: list[TextDocument] = [ + { + "id": "grade_policy", + "title": "등급 정책", + "content": "gold 등급은 최근 3개월 50만원 이상 구매 고객이다.", + "source": "docs/grade.md", + }, +] + +retriever.add(NEW_DOC) + +print("\n✓ add() — 점진적 인덱싱") +print(f" registry 크기: {initial_count} → {len(retriever._registry)}") +assert len(retriever._registry) > initial_count + + +# ── 5. score_threshold 필터링 ───────────────────────────────────────────────── + +from lang2sql.integrations.vectorstore import InMemoryVectorStore + +store = InMemoryVectorStore() +registry: dict = {} + +from lang2sql import IndexBuilder + +builder = IndexBuilder( + embedding=FakeEmbedding(), + vectorstore=store, + registry=registry, +) +builder.run(CATALOG) + +strict_retriever = VectorRetriever( + vectorstore=store, + embedding=FakeEmbedding(), + registry=registry, + # FakeEmbedding은 항상 동일 벡터 반환 → 코사인 유사도 = 1.0 + # threshold=1.0 이면 1.0 <= 1.0 조건 충족 → 전부 필터링됨 + score_threshold=1.0, +) + +result3 = strict_retriever("주문") +print("\n✓ score_threshold=1.0 — 결과 필터링") +print(f" schemas: {result3.schemas} (빈 리스트 예상)") +assert result3.schemas == [] + + +# ── 6. 
IndexBuilder 직접 사용 ───────────────────────────────────────────────── + +store2 = InMemoryVectorStore() +registry2: dict = {} +builder2 = IndexBuilder( + embedding=FakeEmbedding(), + vectorstore=store2, + registry=registry2, +) +retriever3 = VectorRetriever( + vectorstore=store2, + embedding=FakeEmbedding(), + registry=registry2, +) + +builder2.run(CATALOG) +catalog_ids = set(registry2.keys()) + +builder2.run(DOCS) # 카탈로그가 유지되는지 확인 +for chunk_id in catalog_ids: + assert chunk_id in registry2, f"카탈로그 청크 '{chunk_id}' 유실!" + +print("\n✓ IndexBuilder — 카탈로그 유지 확인") +print(f" 카탈로그 청크 수: {len(catalog_ids)} (모두 유지됨)") + + +# ── 7. BaselineNL2SQL 파이프라인 주입 ──────────────────────────────────────── + +class FakeLLM: + def invoke(self, messages): + return "```sql\nSELECT COUNT(*) FROM orders\n```" + +class FakeDB: + def execute(self, sql): + return [{"cnt": 44}] + +from lang2sql import BaselineNL2SQL + +pipeline = BaselineNL2SQL( + catalog=CATALOG, + llm=FakeLLM(), + db=FakeDB(), + retriever=VectorRetriever.from_sources( + catalog=CATALOG, + embedding=FakeEmbedding(), + ), +) + +result4 = pipeline.run("주문 건수") +print("\n✓ BaselineNL2SQL — VectorRetriever 주입") +print(f" 결과: {result4}") +assert result4 == [{"cnt": 44}] + + +# ── 8. public import 확인 ──────────────────────────────────────────────────── + +from lang2sql import ( + VectorRetriever, + IndexBuilder, + CatalogChunker, + RecursiveCharacterChunker, + DocumentChunkerPort, + RetrievalResult, + TextDocument, + IndexedChunk, + EmbeddingPort, + VectorStorePort, +) +print("\n✓ 모든 VectorRetriever 관련 import 성공") + +print("\n" + "=" * 50) +print("모든 체크리스트 통과! 
VectorRetriever 사용 준비 완료.") +print("=" * 50) +``` + +--- + +## 참고: 아키텍처 한눈에 보기 + +``` +[CATALOG / DOCS] + │ + ▼ + IndexBuilder.run() + ├── CatalogChunker — 테이블 헤더 + 컬럼 그룹 분할 + └── RecursiveCharacterChunker / SemanticChunker — 문서 분할 + │ + ▼ embed_texts() + EmbeddingPort — OpenAIEmbedding 등 + │ + ▼ upsert() + VectorStorePort — InMemoryVectorStore (기본) + + registry dict 공유 + │ + ▼ + VectorRetriever.__call__(query) + ├── embed_query(query) + ├── vectorstore.search(vector, k) + └── RetrievalResult + ├── .schemas — 관련 CatalogEntry 목록 (중복 제거됨) + └── .context — 관련 문서 텍스트 목록 + │ + ▼ + SQLGenerator — "Business Context" 섹션으로 프롬프트에 포함 +``` + +**확장 포인트:** + +| 인터페이스 | 구현할 메서드 | 용도 | +|-----------|------------|------| +| `EmbeddingPort` | `embed_query()`, `embed_texts()` | 임베딩 백엔드 교체 | +| `VectorStorePort` | `search()`, `upsert()` | 벡터 저장소 교체 (FAISS, pgvector 등) | +| `DocumentChunkerPort` | `chunk(doc)` | 청킹 전략 교체 | diff --git a/scripts/setup_sample_docs.py b/scripts/setup_sample_docs.py new file mode 100644 index 0000000..fb22870 --- /dev/null +++ b/scripts/setup_sample_docs.py @@ -0,0 +1,88 @@ +""" +v2 튜토리얼용 샘플 문서 생성 스크립트. + +사용법 +------ +python scripts/setup_sample_docs.py +python scripts/setup_sample_docs.py --out-dir docs/business --force +""" + +from __future__ import annotations + +import argparse +from pathlib import Path + +_SAMPLE_FILES: dict[str, str] = { + "revenue.md": """ +# 순매출 정의 + +순매출은 주문 금액(`orders.amount`)에서 취소 주문을 제외한 값으로 계산한다. +주문 상태가 `cancelled` 인 레코드는 집계에서 제외한다. + +월간 매출 지표는 `orders.order_date` 기준으로 월 단위 그룹핑한다. +""".strip(), + "order_status_policy.md": """ +# 주문 상태 정책 + +주문 상태 컬럼은 `orders.status` 이다. + +- pending: 결제 대기 +- confirmed: 결제 완료 +- shipped: 배송 완료 +- cancelled: 취소 + +운영 리포트에서 "완료 주문"은 `confirmed`, `shipped`를 포함한다. +""".strip(), + "rules.txt": """ +고객 등급은 customers.grade 컬럼을 사용한다. +gold, silver, bronze 세 등급이 존재한다. +고객별 주문 빈도 분석은 orders.customer_id 와 customers.customer_id 를 조인해서 수행한다. 
+""".strip(), +} + + +def setup_sample_docs(out_dir: Path, force: bool = False) -> tuple[int, int]: + out_dir.mkdir(parents=True, exist_ok=True) + + created_or_updated = 0 + skipped = 0 + + for filename, content in _SAMPLE_FILES.items(): + path = out_dir / filename + if path.exists() and not force: + skipped += 1 + continue + path.write_text(content + "\n", encoding="utf-8") + created_or_updated += 1 + + return created_or_updated, skipped + + +def main() -> None: + parser = argparse.ArgumentParser(description="v2 튜토리얼 샘플 문서 생성") + parser.add_argument( + "--out-dir", + default="docs/business", + help="샘플 문서를 생성할 디렉터리 (기본값: docs/business)", + ) + parser.add_argument( + "--force", + action="store_true", + help="기존 파일이 있어도 덮어쓰기", + ) + args = parser.parse_args() + + out_dir = Path(args.out_dir) + created_or_updated, skipped = setup_sample_docs(out_dir=out_dir, force=args.force) + + print(f"문서 출력 경로: {out_dir}") + print(f"생성/갱신: {created_or_updated}개") + print(f"건너뜀: {skipped}개") + print() + print("생성 파일:") + for filename in sorted(_SAMPLE_FILES): + print(f" - {out_dir / filename}") + + +if __name__ == "__main__": + main() diff --git a/src/lang2sql/__init__.py b/src/lang2sql/__init__.py index cef1f74..9ab711e 100644 --- a/src/lang2sql/__init__.py +++ b/src/lang2sql/__init__.py @@ -1,24 +1,56 @@ from .components.execution.sql_executor import SQLExecutor from .components.generation.sql_generator import SQLGenerator +from .components.loaders.directory_ import DirectoryLoader +from .components.loaders.markdown_ import MarkdownLoader +from .components.loaders.plaintext_ import PlainTextLoader +from .components.retrieval.chunker import ( + CatalogChunker, + DocumentChunkerPort, + RecursiveCharacterChunker, +) +from .components.retrieval.hybrid import HybridRetriever from .components.retrieval.keyword import KeywordRetriever -from .core.catalog import CatalogEntry +from .components.retrieval.vector import VectorRetriever +from .core.catalog import CatalogEntry, IndexedChunk, 
RetrievalResult, TextDocument from .core.exceptions import ComponentError, IntegrationMissingError, Lang2SQLError from .core.hooks import MemoryHook, NullHook, TraceHook -from .core.ports import DBPort, LLMPort +from .core.ports import ( + DBPort, + DocumentLoaderPort, + EmbeddingPort, + LLMPort, + VectorStorePort, +) +from .flows.hybrid import HybridNL2SQL from .flows.nl2sql import BaselineNL2SQL __all__ = [ # Data types "CatalogEntry", + "TextDocument", + "IndexedChunk", + "RetrievalResult", # Ports (protocols) "LLMPort", "DBPort", + "EmbeddingPort", + "VectorStorePort", + "DocumentLoaderPort", # Components "KeywordRetriever", + "VectorRetriever", + "HybridRetriever", + "DocumentChunkerPort", + "CatalogChunker", + "RecursiveCharacterChunker", "SQLGenerator", "SQLExecutor", + "MarkdownLoader", + "PlainTextLoader", + "DirectoryLoader", # Flows "BaselineNL2SQL", + "HybridNL2SQL", # Hooks "TraceHook", "MemoryHook", diff --git a/src/lang2sql/components/generation/sql_generator.py b/src/lang2sql/components/generation/sql_generator.py index eba4841..efc7bb3 100644 --- a/src/lang2sql/components/generation/sql_generator.py +++ b/src/lang2sql/components/generation/sql_generator.py @@ -53,11 +53,21 @@ def __init__( else: self._system_prompt = _load_prompt("default") - def _run(self, query: str, schemas: list[CatalogEntry]) -> str: - context = self._build_context(schemas) + def _run( + self, + query: str, + schemas: list[CatalogEntry], + context: Optional[list[str]] = None, + ) -> str: + schema_text = self._build_context(schemas) + user_parts: list[str] = [] + if context: + user_parts.append("Business Context:\n" + "\n\n".join(context)) + user_parts.append(f"Schemas:\n{schema_text}\n\nQuestion: {query}") + user_content = "\n\n".join(user_parts) messages = [ {"role": "system", "content": self._system_prompt}, - {"role": "user", "content": f"Schemas:\n{context}\n\nQuestion: {query}"}, + {"role": "user", "content": user_content}, ] response = self._llm.invoke(messages) sql = 
self._extract_sql(response) diff --git a/src/lang2sql/components/loaders/__init__.py b/src/lang2sql/components/loaders/__init__.py new file mode 100644 index 0000000..80ae315 --- /dev/null +++ b/src/lang2sql/components/loaders/__init__.py @@ -0,0 +1,9 @@ +from .directory_ import DirectoryLoader +from .markdown_ import MarkdownLoader +from .plaintext_ import PlainTextLoader + +__all__ = [ + "MarkdownLoader", + "PlainTextLoader", + "DirectoryLoader", +] diff --git a/src/lang2sql/components/loaders/directory_.py b/src/lang2sql/components/loaders/directory_.py new file mode 100644 index 0000000..ccee0b9 --- /dev/null +++ b/src/lang2sql/components/loaders/directory_.py @@ -0,0 +1,57 @@ +from __future__ import annotations + +from pathlib import Path + +from ...core.catalog import TextDocument +from ...core.ports import DocumentLoaderPort +from .markdown_ import MarkdownLoader +from .plaintext_ import PlainTextLoader + + +class DirectoryLoader: + """ + Recursively loads a directory by dispatching each file to the loader + registered for its extension. + + Default mapping:: + + .md → MarkdownLoader + .txt → PlainTextLoader + + Custom loaders can be added or override defaults:: + + from lang2sql.integrations.loaders import PDFLoader + + docs = DirectoryLoader( + "docs/", + loaders={".md": MarkdownLoader(), ".pdf": PDFLoader()}, + ).load() + + Args: + path: Directory path to load. + loaders: Mapping of lowercase extension → DocumentLoaderPort. + Defaults to ``{".md": MarkdownLoader(), ".txt": PlainTextLoader()}``. 
+ """ + + def __init__( + self, + path: str, + loaders: dict[str, DocumentLoaderPort] | None = None, + ) -> None: + self._path = Path(path) + self._loaders: dict[str, DocumentLoaderPort] = loaders or { + ".md": MarkdownLoader(), + ".txt": PlainTextLoader(), + } + + def load(self) -> list[TextDocument]: + """Recursively walk the directory and load all files with a registered extension.""" + docs: list[TextDocument] = [] + for file in sorted(self._path.rglob("*")): + if not file.is_file(): + continue + loader = self._loaders.get(file.suffix.lower()) + if loader is None: + continue + docs.extend(loader.load(str(file))) + return docs diff --git a/src/lang2sql/components/loaders/markdown_.py b/src/lang2sql/components/loaders/markdown_.py new file mode 100644 index 0000000..e5dc008 --- /dev/null +++ b/src/lang2sql/components/loaders/markdown_.py @@ -0,0 +1,41 @@ +from __future__ import annotations + +from pathlib import Path + +from ...core.catalog import TextDocument + + +class MarkdownLoader: + """ + Markdown file(s) (.md) → list[TextDocument]. + + Standard library only. No external dependencies. + + - Single file: ``load("docs/revenue.md")`` → [TextDocument] + - Directory: ``load("docs/")`` → [TextDocument, ...] (recursive) + + The first ``# heading`` becomes ``title``; the full file text becomes ``content``. + Falls back to the filename stem when no heading is found. 
+ """ + + def load(self, path: str) -> list[TextDocument]: + p = Path(path) + if p.is_dir(): + return [doc for f in sorted(p.rglob("*.md")) for doc in self._load_file(f)] + return self._load_file(p) + + def _load_file(self, path: Path) -> list[TextDocument]: + content = path.read_text(encoding="utf-8") + title = "" + for line in content.splitlines(): + if line.startswith("# "): + title = line[2:].strip() + break + return [ + TextDocument( + id=path.stem, + title=title or path.stem, + content=content, + source=str(path), + ) + ] diff --git a/src/lang2sql/components/loaders/plaintext_.py b/src/lang2sql/components/loaders/plaintext_.py new file mode 100644 index 0000000..1f4ad5a --- /dev/null +++ b/src/lang2sql/components/loaders/plaintext_.py @@ -0,0 +1,45 @@ +from __future__ import annotations + +from pathlib import Path + +from ...core.catalog import TextDocument + + +class PlainTextLoader: + """ + Plain text file(s) (.txt, etc.) → list[TextDocument]. + + Standard library only. No external dependencies. + + - Single file: ``load("notes.txt")`` → [TextDocument] + - Directory: ``load("data/")`` → [TextDocument, ...] (recursive) + + ``title`` = filename stem, ``content`` = full file text. + + Args: + extensions: File extensions to load. Default ``[".txt"]``. 
+ """ + + def __init__(self, extensions: list[str] | None = None) -> None: + self._extensions = extensions or [".txt"] + + def load(self, path: str) -> list[TextDocument]: + p = Path(path) + if p.is_dir(): + docs: list[TextDocument] = [] + for ext in self._extensions: + for f in sorted(p.rglob(f"*{ext}")): + docs.extend(self._load_file(f)) + return docs + return self._load_file(p) + + def _load_file(self, path: Path) -> list[TextDocument]: + content = path.read_text(encoding="utf-8") + return [ + TextDocument( + id=path.stem, + title=path.stem, + content=content, + source=str(path), + ) + ] diff --git a/src/lang2sql/components/retrieval/__init__.py b/src/lang2sql/components/retrieval/__init__.py index d912694..997cc7c 100644 --- a/src/lang2sql/components/retrieval/__init__.py +++ b/src/lang2sql/components/retrieval/__init__.py @@ -1,4 +1,18 @@ +from .chunker import CatalogChunker, DocumentChunkerPort, RecursiveCharacterChunker +from .hybrid import HybridRetriever from .keyword import KeywordRetriever -from ...core.catalog import CatalogEntry +from .vector import VectorRetriever +from ...core.catalog import CatalogEntry, IndexedChunk, RetrievalResult, TextDocument -__all__ = ["KeywordRetriever", "CatalogEntry"] +__all__ = [ + "KeywordRetriever", + "VectorRetriever", + "HybridRetriever", + "DocumentChunkerPort", + "CatalogChunker", + "RecursiveCharacterChunker", + "CatalogEntry", + "TextDocument", + "IndexedChunk", + "RetrievalResult", +] diff --git a/src/lang2sql/components/retrieval/chunker.py b/src/lang2sql/components/retrieval/chunker.py new file mode 100644 index 0000000..20ac9e5 --- /dev/null +++ b/src/lang2sql/components/retrieval/chunker.py @@ -0,0 +1,204 @@ +from __future__ import annotations + +from typing import Protocol, runtime_checkable + +from ...core.catalog import CatalogEntry, IndexedChunk, TextDocument + + +@runtime_checkable +class DocumentChunkerPort(Protocol): + """ + Interface for TextDocument → list[IndexedChunk] conversion. 
+ + Default implementation: RecursiveCharacterChunker + Advanced implementation: SemanticChunker (integrations/chunking/semantic_.py) + Custom implementation: any class satisfying this Protocol can be passed as splitter. + + Example (wrapping LangChain):: + + class LangChainChunkerAdapter: + def __init__(self, lc_splitter): + self._splitter = lc_splitter + + def chunk(self, doc: TextDocument) -> list[IndexedChunk]: + texts = self._splitter.split_text(doc["content"]) + return [ + IndexedChunk( + chunk_id=f"{doc['id']}__{i}", text=t, + source_type="document", source_id=doc["id"], + chunk_index=i, metadata={"title": doc.get("title", "")}, + ) + for i, t in enumerate(texts) + ] + + retriever = VectorRetriever.from_sources(..., splitter=LangChainChunkerAdapter(...)) + """ + + def chunk(self, doc: TextDocument) -> list[IndexedChunk]: ... + + +class CatalogChunker: + """ + Converts a CatalogEntry into a list of IndexedChunks. + + Solves the problem where a table with 100+ columns loses column-level + semantics when represented as a single vector. + Chunk 0 is a table header summary; subsequent chunks are column groups. + Each chunk's metadata preserves the full CatalogEntry so VectorRetriever + can reconstruct it on retrieval. + + Args: + max_columns_per_chunk: Maximum columns per column-group chunk. Default 20. 
+ """ + + def __init__(self, max_columns_per_chunk: int = 20) -> None: + self._max_cols = max_columns_per_chunk + + def split(self, catalog: list[CatalogEntry]) -> list[IndexedChunk]: + """LangChain-style batch split: list input → list output.""" + return [c for entry in catalog for c in self.chunk(entry)] + + def chunk(self, entry: CatalogEntry) -> list[IndexedChunk]: + name = entry.get("name", "") + description = entry.get("description", "") + columns = entry.get("columns", {}) + chunks: list[IndexedChunk] = [] + + # Chunk 0: table header + chunks.append( + IndexedChunk( + chunk_id=f"{name}__0", + text=f"{name}: {description}".strip(), + source_type="catalog", + source_id=name, + chunk_index=0, + metadata=dict(entry), # preserve full CatalogEntry for reconstruction + ) + ) + + # Chunk 1+: column groups + col_items = list(columns.items()) + for i, start in enumerate(range(0, len(col_items), self._max_cols)): + group = col_items[start : start + self._max_cols] + col_text = " ".join(f"{k} {v}" for k, v in group) + chunks.append( + IndexedChunk( + chunk_id=f"{name}__col_{i + 1}", + text=f"{name} columns: {col_text}", + source_type="catalog", + source_id=name, + chunk_index=i + 1, + metadata=dict( + entry + ), # preserve full CatalogEntry in every column chunk + ) + ) + + return chunks + + +class RecursiveCharacterChunker: + """ + Hierarchical separator-based document chunker. No external dependencies. + + Separator priority: ["\\n\\n", "\\n", ". ", " ", ""] + — tries paragraph → line → sentence → word boundaries in order. + Character-count-based so it works for both Korean and English + (unlike str.split() which assumes whitespace-delimited words). + + For higher chunking quality, use SemanticChunker (integrations/chunking/semantic_.py). + + Args: + chunk_size: Maximum characters per chunk. Default 1000. + chunk_overlap: Overlap characters between consecutive chunks. Default 100. + separators: Separator priority list. None uses the default list above. 
+ """ + + _DEFAULT_SEPARATORS = ["\n\n", "\n", ". ", " ", ""] + + def __init__( + self, + chunk_size: int = 1000, + chunk_overlap: int = 100, + separators: list[str] | None = None, + ) -> None: + if chunk_overlap >= chunk_size: + raise ValueError( + f"chunk_overlap ({chunk_overlap}) must be less than chunk_size ({chunk_size})" + ) + self._chunk_size = chunk_size + self._chunk_overlap = chunk_overlap + self._separators = separators or self._DEFAULT_SEPARATORS + + def split(self, docs: list[TextDocument]) -> list[IndexedChunk]: + """LangChain-style batch split: list input → list output.""" + return [c for doc in docs for c in self.chunk(doc)] + + def chunk(self, doc: TextDocument) -> list[IndexedChunk]: + content = doc.get("content", "") + if not content: + return [] + + raw_chunks = self._split(content, self._separators) + title = doc.get("title", "") + doc_id = doc.get("id", "") + + return [ + IndexedChunk( + chunk_id=f"{doc_id}__{i}", + text=f"{title}: {text}" if title else text, + source_type="document", + source_id=doc_id, + chunk_index=i, + metadata={ + "id": doc_id, + "title": title, + "source": doc.get("source", ""), + }, + ) + for i, text in enumerate(raw_chunks) + ] + + def _split(self, text: str, separators: list[str]) -> list[str]: + """Recursively try separators until all chunks fit within chunk_size.""" + chunks: list[str] = [] + separator = separators[-1] # fallback: character-level split + + for sep in separators: + if sep and sep in text: + separator = sep + break + + parts = text.split(separator) if separator else list(text) + current = "" + + for part in parts: + candidate = ( + (current + separator + part).lstrip(separator) if current else part + ) + if len(candidate) <= self._chunk_size: + current = candidate + else: + if current: + chunks.append(current) + # part itself exceeds chunk_size → recurse with finer separators + if len(part) > self._chunk_size and len(separators) > 1: + chunks.extend(self._split(part, separators[1:])) + current = "" + 
else: + current = part + + if current: + chunks.append(current) + + if self._chunk_overlap > 0 and len(chunks) > 1: + chunks = self._apply_overlap(chunks) + + return chunks + + def _apply_overlap(self, chunks: list[str]) -> list[str]: + overlapped = [chunks[0]] + for i in range(1, len(chunks)): + prev_tail = chunks[i - 1][-self._chunk_overlap :] + overlapped.append(prev_tail + chunks[i]) + return overlapped diff --git a/src/lang2sql/components/retrieval/hybrid.py b/src/lang2sql/components/retrieval/hybrid.py new file mode 100644 index 0000000..5a3df77 --- /dev/null +++ b/src/lang2sql/components/retrieval/hybrid.py @@ -0,0 +1,112 @@ +from __future__ import annotations + +from typing import Optional + +from ...core.base import BaseComponent +from ...core.catalog import CatalogEntry, RetrievalResult, TextDocument +from ...core.hooks import TraceHook +from ...core.ports import EmbeddingPort +from .chunker import DocumentChunkerPort +from .keyword import KeywordRetriever +from .vector import VectorRetriever + + +class HybridRetriever(BaseComponent): + """ + BM25 + vector hybrid retriever. Merges results with Reciprocal Rank Fusion (RRF). + + RRF algorithm:: + + RRF_score(table) = Σ 1/(k + rank_i) for each ranker i + k = 60 (default, recommended by the RRF paper) + + Over-fetches ``top_n * 2`` candidates from each retriever, merges via RRF, + and returns the final ``top_n``. + Context is taken from VectorRetriever only (BM25 has no document context). + + Args: + catalog: List of CatalogEntry dicts. + embedding: EmbeddingPort implementation. + documents: Optional list of business documents to index. + splitter: Chunking strategy for documents (default: RecursiveCharacterChunker). + top_n: Maximum number of schemas to return. Default 5. + rrf_k: RRF smoothing constant. Default 60. + score_threshold: Minimum vector similarity score. Default 0.0. + name: Component name for tracing. + hook: TraceHook for observability. 
+ + Usage:: + + retriever = HybridRetriever( + catalog=[{"name": "orders", "description": "...", "columns": {...}}], + embedding=OpenAIEmbedding(model="text-embedding-3-small"), + ) + result = retriever("How many orders last month?") # RetrievalResult + """ + + def __init__( + self, + *, + catalog: list[CatalogEntry], + embedding: EmbeddingPort, + documents: Optional[list[TextDocument]] = None, + splitter: Optional[DocumentChunkerPort] = None, + top_n: int = 5, + rrf_k: int = 60, + score_threshold: float = 0.0, + name: Optional[str] = None, + hook: Optional[TraceHook] = None, + ) -> None: + super().__init__(name=name or "HybridRetriever", hook=hook) + fetch = top_n * 2 + self._keyword = KeywordRetriever(catalog=catalog, top_n=fetch) + self._vector = VectorRetriever.from_sources( + catalog=catalog, + embedding=embedding, + documents=documents, + splitter=splitter, + top_n=fetch, + score_threshold=score_threshold, + ) + self._top_n = top_n + self._rrf_k = rrf_k + + def _run(self, query: str) -> RetrievalResult: + """ + Args: + query: Natural language search query. 
+ + Returns: + RetrievalResult: + .schemas — top_n schemas after RRF merge + .context — business document context from VectorRetriever + """ + keyword_schemas = self._keyword(query) # list[CatalogEntry] + vector_result = self._vector(query) # RetrievalResult + + merged = self._rrf_merge(keyword_schemas, vector_result.schemas) + return RetrievalResult(schemas=merged, context=vector_result.context) + + def _rrf_merge( + self, + keyword_schemas: list[CatalogEntry], + vector_schemas: list[CatalogEntry], + ) -> list[CatalogEntry]: + """Merge results from both retrievers via RRF and return top_n entries.""" + k = self._rrf_k + scores: dict[str, float] = {} + entries: dict[str, CatalogEntry] = {} + + for rank, entry in enumerate(keyword_schemas, start=1): + name = entry["name"] + scores[name] = scores.get(name, 0.0) + 1.0 / (k + rank) + entries[name] = entry + + for rank, entry in enumerate(vector_schemas, start=1): + name = entry["name"] + scores[name] = scores.get(name, 0.0) + 1.0 / (k + rank) + if name not in entries: + entries[name] = entry + + sorted_names = sorted(scores, key=lambda n: scores[n], reverse=True) + return [entries[n] for n in sorted_names[: self._top_n]] diff --git a/src/lang2sql/components/retrieval/vector.py b/src/lang2sql/components/retrieval/vector.py new file mode 100644 index 0000000..ca1c454 --- /dev/null +++ b/src/lang2sql/components/retrieval/vector.py @@ -0,0 +1,208 @@ +from __future__ import annotations + +from typing import Optional + +from ...core.base import BaseComponent +from ...core.catalog import CatalogEntry, IndexedChunk, RetrievalResult, TextDocument +from ...core.hooks import TraceHook +from ...core.ports import EmbeddingPort, VectorStorePort +from .chunker import CatalogChunker, DocumentChunkerPort + + +class VectorRetriever(BaseComponent): + """ + Catalog + business document retrieval via vector similarity. 
+ + RetrievalResult.schemas is deduplicated by source table — multiple chunks + from the same table produce only one CatalogEntry in the result. + + Two construction patterns: + + 1. One-touch factory (quick start): + retriever = VectorRetriever.from_sources(catalog=..., embedding=...) + retriever.add(RecursiveCharacterChunker().split(more_docs)) # incremental + + 2. Explicit pipeline (full control, LangChain-style): + chunks = ( + CatalogChunker().split(catalog) + + RecursiveCharacterChunker().split(docs) + ) + retriever = VectorRetriever.from_chunks(chunks, embedding=embedding, top_n=5) + retriever.add(RecursiveCharacterChunker().split(new_docs)) # incremental + + Args: + vectorstore: VectorStorePort implementation. + embedding: EmbeddingPort implementation. + registry: dict[chunk_id, IndexedChunk] mapping. + top_n: Maximum schemas and context items to return. Default 5. + score_threshold: Chunks with score <= this value are excluded. Default 0.0. + name: Component name for tracing. + hook: TraceHook for observability. + """ + + def __init__( + self, + *, + vectorstore: VectorStorePort, + embedding: EmbeddingPort, + registry: dict, + top_n: int = 5, + score_threshold: float = 0.0, + name: Optional[str] = None, + hook: Optional[TraceHook] = None, + ) -> None: + super().__init__(name=name or "VectorRetriever", hook=hook) + self._vectorstore = vectorstore + self._embedding = embedding + self._registry = registry + self._top_n = top_n + self._score_threshold = score_threshold + + @classmethod + def from_chunks( + cls, + chunks: list[IndexedChunk], + *, + embedding: EmbeddingPort, + vectorstore: Optional[VectorStorePort] = None, + top_n: int = 5, + score_threshold: float = 0.0, + name: Optional[str] = None, + hook: Optional[TraceHook] = None, + ) -> "VectorRetriever": + """ + LangChain-style factory: build from pre-split chunks. + + Embeds and stores the given chunks; no splitting is performed here. + Use chunker.split(docs) before calling this method. 
+ + Args: + chunks: Pre-split list[IndexedChunk] (e.g. from CatalogChunker.split()). + embedding: EmbeddingPort implementation. + vectorstore: Defaults to InMemoryVectorStore. + top_n: Maximum schemas and context items to return. Default 5. + score_threshold: Score cutoff. Default 0.0. + """ + from ...integrations.vectorstore.inmemory_ import InMemoryVectorStore + + store = vectorstore or InMemoryVectorStore() + registry: dict = {} + if chunks: + ids = [c["chunk_id"] for c in chunks] + texts = [c["text"] for c in chunks] + vectors = embedding.embed_texts(texts) + store.upsert(ids, vectors) + registry.update({c["chunk_id"]: c for c in chunks}) + + return cls( + vectorstore=store, + embedding=embedding, + registry=registry, + top_n=top_n, + score_threshold=score_threshold, + name=name, + hook=hook, + ) + + @classmethod + def from_sources( + cls, + *, + catalog: list[CatalogEntry], + embedding: EmbeddingPort, + documents: Optional[list[TextDocument]] = None, + splitter: Optional[DocumentChunkerPort] = None, + vectorstore: Optional[VectorStorePort] = None, + top_n: int = 5, + score_threshold: float = 0.0, + name: Optional[str] = None, + hook: Optional[TraceHook] = None, + ) -> "VectorRetriever": + """ + One-touch factory: chunk, embed, and index in a single call. + + Internally calls from_chunks() after splitting catalog and documents. + For incremental addition after construction, use retriever.add(chunks). + + Args: + catalog: List of CatalogEntry dicts to index. + embedding: EmbeddingPort implementation. + documents: Optional list of TextDocument to index alongside catalog. + splitter: Chunker for documents. Defaults to RecursiveCharacterChunker. + Pass SemanticChunker(embedding=...) for higher quality. + vectorstore: Defaults to InMemoryVectorStore. + top_n: Maximum schemas and context items to return. Default 5. + score_threshold: Score cutoff. Default 0.0. 
+ """ + from .chunker import RecursiveCharacterChunker + + _splitter = splitter or RecursiveCharacterChunker() + chunks = CatalogChunker().split(catalog) + if documents: + chunks = chunks + _splitter.split(documents) + + return cls.from_chunks( + chunks, + embedding=embedding, + vectorstore=vectorstore, + top_n=top_n, + score_threshold=score_threshold, + name=name, + hook=hook, + ) + + def add(self, chunks: list[IndexedChunk]) -> None: + """ + Add pre-split chunks to the index incrementally. + + Use chunker.split(docs) before calling this method. + + Args: + chunks: list[IndexedChunk] from chunker.split(). + """ + if not chunks: + return + ids = [c["chunk_id"] for c in chunks] + texts = [c["text"] for c in chunks] + vectors = self._embedding.embed_texts(texts) + self._vectorstore.upsert(ids, vectors) + self._registry.update({c["chunk_id"]: c for c in chunks}) + + def _run(self, query: str) -> RetrievalResult: + """ + Args: + query: Natural language search query. + + Returns: + RetrievalResult: + .schemas — relevant CatalogEntry list (deduplicated, at most top_n) + .context — relevant business document chunk texts (at most top_n) + """ + if not self._registry: + return RetrievalResult(schemas=[], context=[]) + + query_vector = self._embedding.embed_query(query) + # over-fetch by 3x so deduplication still yields top_n catalog entries + raw = self._vectorstore.search(query_vector, k=self._top_n * 3) + + seen_tables: dict[str, CatalogEntry] = {} # source_id → CatalogEntry (dedup) + context: list[str] = [] + + for chunk_id, score in raw: + if score <= self._score_threshold: + continue + chunk = self._registry.get(chunk_id) + if chunk is None: + continue + + if chunk["source_type"] == "catalog": + src = chunk["source_id"] + if src not in seen_tables: + seen_tables[src] = chunk["metadata"] # full CatalogEntry + elif chunk["source_type"] == "document": + context.append(chunk["text"]) + + return RetrievalResult( + schemas=list(seen_tables.values())[: self._top_n], + 
context=context[: self._top_n], + ) diff --git a/src/lang2sql/core/catalog.py b/src/lang2sql/core/catalog.py index 469a250..bda5bf9 100644 --- a/src/lang2sql/core/catalog.py +++ b/src/lang2sql/core/catalog.py @@ -1,5 +1,6 @@ from __future__ import annotations +from dataclasses import dataclass, field from typing import TypedDict @@ -7,3 +8,32 @@ class CatalogEntry(TypedDict, total=False): name: str description: str columns: dict[str, str] + + +class TextDocument(TypedDict, total=False): + """Business document — data dictionaries, business rules, FAQ, etc.""" + + id: str # Unique identifier (required) + title: str # Document title (required) + content: str # Full body text (required) + source: str # File path, URL, etc. (optional) + metadata: dict # Free-form additional info (optional) + + +class IndexedChunk(TypedDict): + """Minimum indexing unit stored in the vector store. Shared by catalog and document chunks.""" + + chunk_id: str # e.g. "orders__0", "orders__col_1", "bizrule__2" + text: str # Text to embed + source_type: str # "catalog" | "document" + source_id: str # Table name (catalog) or doc id (document) + chunk_index: int # Position within the source (0-based) + metadata: dict # catalog → full CatalogEntry / document → document meta + + +@dataclass +class RetrievalResult: + """Return value of VectorRetriever — schema list + domain context.""" + + schemas: list[CatalogEntry] = field(default_factory=list) + context: list[str] = field(default_factory=list) diff --git a/src/lang2sql/core/ports.py b/src/lang2sql/core/ports.py index fa6b3e4..ff6bc6f 100644 --- a/src/lang2sql/core/ports.py +++ b/src/lang2sql/core/ports.py @@ -1,6 +1,8 @@ from __future__ import annotations -from typing import Any, Protocol +from typing import Any, Protocol, runtime_checkable + +from .catalog import TextDocument class LLMPort(Protocol): @@ -16,13 +18,42 @@ def execute(self, sql: str) -> list[dict[str, Any]]: ... 
class EmbeddingPort(Protocol): - """ - Placeholder — will be implemented in OQ-2 (VectorRetriever). - - Abstracts embedding backends (OpenAI, Azure, Bedrock, etc.) - so VectorRetriever is not coupled to any specific provider. - """ + """Abstracts embedding backends (OpenAI, Azure, Bedrock, etc.).""" def embed_query(self, text: str) -> list[float]: ... def embed_texts(self, texts: list[str]) -> list[list[float]]: ... + + +class VectorStorePort(Protocol): + """Abstracts vector store backends (InMemory, FAISS, pgvector, etc.).""" + + def search(self, vector: list[float], k: int) -> list[tuple[str, float]]: + """ + Return the k nearest vectors. + + Returns: + List of (chunk_id, score) sorted by score descending. + Score range: [-1, 1] (cosine similarity). + """ + ... + + def upsert(self, ids: list[str], vectors: list[list[float]]) -> None: + """ + Store or update vectors by chunk_id. + + Implementations must merge incoming entries into existing ones — + calling upsert twice must not lose entries from the first call. + + Args: + ids: List of chunk_ids. + vectors: Corresponding embedding vectors (len(ids) == len(vectors)). + """ + ... + + +@runtime_checkable +class DocumentLoaderPort(Protocol): + """Converts a file path or directory to list[TextDocument].""" + + def load(self, path: str) -> list[TextDocument]: ... 
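The `VectorStorePort` contract above (merge on `upsert`, `(chunk_id, score)` pairs sorted by descending cosine similarity from `search`) can be sketched without any third-party dependency. The class below is only an illustration: `DictVectorStore` is a hypothetical name and not part of lang2sql, and the brute-force scan is an assumption chosen for brevity, not how a production backend would work.

```python
import math


class DictVectorStore:
    """Hypothetical pure-Python store that follows the VectorStorePort contract."""

    def __init__(self) -> None:
        self._vectors: dict[str, list[float]] = {}  # chunk_id -> vector

    def upsert(self, ids: list[str], vectors: list[list[float]]) -> None:
        if len(ids) != len(vectors):
            raise ValueError("ids and vectors must have the same length")
        # Merge into the existing mapping so earlier upserts are preserved;
        # a repeated chunk_id overwrites its previous vector.
        self._vectors.update(zip(ids, vectors))

    def search(self, vector: list[float], k: int) -> list[tuple[str, float]]:
        q_norm = math.sqrt(sum(x * x for x in vector))
        scored: list[tuple[str, float]] = []
        for chunk_id, stored in self._vectors.items():
            dot = sum(a * b for a, b in zip(stored, vector))
            norm = math.sqrt(sum(x * x for x in stored)) * q_norm
            # Treat zero-length vectors as having similarity 0.0.
            scored.append((chunk_id, dot / norm if norm else 0.0))
        # Highest cosine similarity first, at most k results.
        scored.sort(key=lambda pair: pair[1], reverse=True)
        return scored[:k]


store = DictVectorStore()
store.upsert(["orders__0"], [[1.0, 0.0]])
store.upsert(["customers__0"], [[0.0, 1.0]])  # second call must not drop the first
hits = store.search([1.0, 0.0], k=2)
```

Because the port is a `Protocol`, a store shaped like this should be usable wherever a `vectorstore` argument is accepted (for example `VectorRetriever.from_chunks(..., vectorstore=store)`), without inheriting from any lang2sql base class.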
diff --git a/src/lang2sql/flows/hybrid.py b/src/lang2sql/flows/hybrid.py new file mode 100644 index 0000000..b30575a --- /dev/null +++ b/src/lang2sql/flows/hybrid.py @@ -0,0 +1,73 @@ +from __future__ import annotations + +from typing import Optional + +from ..components.execution.sql_executor import SQLExecutor +from ..components.generation.sql_generator import SQLGenerator +from ..components.retrieval.hybrid import HybridRetriever +from ..core.base import BaseFlow +from ..core.catalog import TextDocument +from ..core.hooks import TraceHook +from ..core.ports import DBPort, EmbeddingPort, LLMPort + + +class HybridNL2SQL(BaseFlow): + """ + NL→SQL pipeline backed by BM25 + vector hybrid retrieval. + + Provides higher retrieval quality than ``BaselineNL2SQL`` with only the + addition of an ``embedding`` parameter. + + Pipeline: HybridRetriever → SQLGenerator → SQLExecutor + + Args: + catalog: List of CatalogEntry dicts. + llm: LLMPort implementation. + db: DBPort implementation. + embedding: EmbeddingPort implementation. + documents: Optional list of business documents to index. + db_dialect: SQL dialect. Supported values: ``"sqlite"``, ``"postgresql"``, + ``"mysql"``, ``"bigquery"``, ``"duckdb"``, ``"default"`` + (or ``None`` for default). + top_n: Maximum number of schemas to return. Default 5. + hook: TraceHook for observability. 
+ + Usage:: + + pipeline = HybridNL2SQL( + catalog=[{"name": "orders", "description": "...", "columns": {...}}], + llm=AnthropicLLM(model="claude-sonnet-4-6"), + db=SQLAlchemyDB("sqlite:///sample.db"), + embedding=OpenAIEmbedding(model="text-embedding-3-small"), + db_dialect="sqlite", + ) + rows = pipeline.run("How many orders last month?") + """ + + def __init__( + self, + *, + catalog: list[dict], + llm: LLMPort, + db: DBPort, + embedding: EmbeddingPort, + documents: Optional[list[TextDocument]] = None, + db_dialect: Optional[str] = None, + top_n: int = 5, + hook: Optional[TraceHook] = None, + ) -> None: + super().__init__(name="HybridNL2SQL", hook=hook) + self._retriever = HybridRetriever( + catalog=catalog, + embedding=embedding, + documents=documents, + top_n=top_n, + hook=hook, + ) + self._generator = SQLGenerator(llm=llm, db_dialect=db_dialect, hook=hook) + self._executor = SQLExecutor(db=db, hook=hook) + + def _run(self, query: str) -> list[dict]: + result = self._retriever(query) # RetrievalResult + sql = self._generator(query, result.schemas, context=result.context) + return self._executor(sql) diff --git a/src/lang2sql/flows/nl2sql.py b/src/lang2sql/flows/nl2sql.py index 26b480c..530798a 100644 --- a/src/lang2sql/flows/nl2sql.py +++ b/src/lang2sql/flows/nl2sql.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Any, Optional +from typing import Optional from ..components.execution.sql_executor import SQLExecutor from ..components.generation.sql_generator import SQLGenerator @@ -24,7 +24,7 @@ class BaselineNL2SQL(BaseFlow): db=SQLAlchemyDB("sqlite:///sample.db"), db_dialect="sqlite", ) - rows = pipeline.run("지난달 주문 건수") + rows = pipeline.run("How many orders last month?") Supported ``db_dialect`` values: ``"sqlite"``, ``"postgresql"``, ``"mysql"``, ``"bigquery"``, ``"duckdb"``, ``"default"`` (or ``None`` for default). 
@@ -44,7 +44,7 @@ def __init__( self._generator = SQLGenerator(llm=llm, db_dialect=db_dialect, hook=hook) self._executor = SQLExecutor(db=db, hook=hook) - def _run(self, query: str) -> list[dict[str, Any]]: - schemas = self._retriever(query) + def _run(self, query: str) -> list[dict]: + schemas = self._retriever(query) # list[CatalogEntry] sql = self._generator(query, schemas) return self._executor(sql) diff --git a/src/lang2sql/integrations/chunking/__init__.py b/src/lang2sql/integrations/chunking/__init__.py new file mode 100644 index 0000000..f80fad9 --- /dev/null +++ b/src/lang2sql/integrations/chunking/__init__.py @@ -0,0 +1,3 @@ +from .semantic_ import SemanticChunker + +__all__ = ["SemanticChunker"] diff --git a/src/lang2sql/integrations/chunking/semantic_.py b/src/lang2sql/integrations/chunking/semantic_.py new file mode 100644 index 0000000..523b5a6 --- /dev/null +++ b/src/lang2sql/integrations/chunking/semantic_.py @@ -0,0 +1,122 @@ +from __future__ import annotations + +from ...core.catalog import IndexedChunk, TextDocument +from ...core.exceptions import IntegrationMissingError +from ...core.ports import EmbeddingPort + + +class SemanticChunker: + """ + Embedding-based semantic chunker. Optional — explicit opt-in only. + + Splits text into sentences, computes cosine similarity between adjacent + sentences, and treats similarity drop-offs as chunk boundaries. + Produces more semantically coherent chunks than RecursiveCharacterChunker. + + Note: embedding API is called during indexing (not just at query time). + Consider the cost and latency before choosing this over RecursiveCharacterChunker. + + Limitation: sentence splitting uses punctuation and newlines. Languages without + whitespace word boundaries (e.g. Korean) rely on sentence-ending punctuation + (., !, ?, 。) or newlines for splits. Accuracy may vary for dense prose. + + Args: + embedding: EmbeddingPort implementation (can be shared with VectorRetriever). 
+ breakpoint_threshold: Similarity drop threshold for boundary detection. Default 0.3. + Lower values produce more (smaller) chunks. + min_chunk_size: Minimum characters per chunk; short chunks are merged + into the previous one. Default 100. + + Example:: + + from lang2sql.integrations.embedding import OpenAIEmbedding + from lang2sql.integrations.chunking import SemanticChunker + + embedding = OpenAIEmbedding() + chunker = SemanticChunker(embedding=embedding) + + retriever = VectorRetriever.from_sources(..., splitter=chunker) + """ + + def __init__( + self, + *, + embedding: EmbeddingPort, + breakpoint_threshold: float = 0.3, + min_chunk_size: int = 100, + ) -> None: + try: + import numpy as _np # noqa: F401 + except ImportError: + raise IntegrationMissingError("numpy", hint="pip install numpy") + self._embedding = embedding + self._threshold = breakpoint_threshold + self._min_size = min_chunk_size + + def split(self, docs: list[TextDocument]) -> list[IndexedChunk]: + """LangChain-style batch split: list input → list output.""" + return [c for doc in docs for c in self.chunk(doc)] + + def chunk(self, doc: TextDocument) -> list[IndexedChunk]: + import numpy as np + + content = doc.get("content", "") + if not content: + return [] + + sentences = self._split_sentences(content) + if len(sentences) <= 1: + return self._make_chunks(doc, [content]) + + embeddings = self._embedding.embed_texts(sentences) + mat = np.array(embeddings, dtype=np.float32) + norms = np.linalg.norm(mat, axis=1, keepdims=True) + mat = mat / (norms + 1e-8) + + # cosine similarity between adjacent sentences — shape: (n-1,) + sims = (mat[:-1] * mat[1:]).sum(axis=1) + + # positions where similarity drops sharply are chunk boundaries + boundaries = [0] + for i, sim in enumerate(sims): + if sim < (1.0 - self._threshold): + boundaries.append(i + 1) + boundaries.append(len(sentences)) + + raw_chunks: list[str] = [] + for start, end in zip(boundaries, boundaries[1:]): + chunk_text = " 
".join(sentences[start:end]) + if len(chunk_text) < self._min_size and raw_chunks: + raw_chunks[-1] += ( + " " + chunk_text + ) # merge short trailing chunk into previous + else: + raw_chunks.append(chunk_text) + + return self._make_chunks(doc, raw_chunks) + + def _split_sentences(self, text: str) -> list[str]: + """Split on sentence-ending punctuation or newlines. No external tokenizer needed.""" + import re + + parts = re.split(r"(?<=[.!?。])\s+|\n+", text.strip()) + return [p.strip() for p in parts if p.strip()] + + def _make_chunks(self, doc: TextDocument, texts: list[str]) -> list[IndexedChunk]: + title = doc.get("title", "") + doc_id = doc.get("id", "") + return [ + IndexedChunk( + chunk_id=f"{doc_id}__{i}", + text=f"{title}: {text}" if title else text, + source_type="document", + source_id=doc_id, + chunk_index=i, + metadata={ + "id": doc_id, + "title": title, + "source": doc.get("source", ""), + }, + ) + for i, text in enumerate(texts) + ] diff --git a/src/lang2sql/integrations/embedding/__init__.py b/src/lang2sql/integrations/embedding/__init__.py new file mode 100644 index 0000000..b427c1e --- /dev/null +++ b/src/lang2sql/integrations/embedding/__init__.py @@ -0,0 +1,3 @@ +from .openai_ import OpenAIEmbedding + +__all__ = ["OpenAIEmbedding"] diff --git a/src/lang2sql/integrations/embedding/openai_.py b/src/lang2sql/integrations/embedding/openai_.py new file mode 100644 index 0000000..902a764 --- /dev/null +++ b/src/lang2sql/integrations/embedding/openai_.py @@ -0,0 +1,31 @@ +from __future__ import annotations + +from ...core.exceptions import IntegrationMissingError + +try: + import openai as _openai +except ImportError: + _openai = None # type: ignore[assignment] + + +class OpenAIEmbedding: + """EmbeddingPort implementation backed by OpenAI Embeddings API.""" + + def __init__( + self, *, model: str = "text-embedding-3-small", api_key: str | None = None + ) -> None: + if _openai is None: + raise IntegrationMissingError("openai", hint="pip install openai") 
+ self._client = _openai.OpenAI(api_key=api_key) + self._model = model + + def embed_query(self, text: str) -> list[float]: + return ( + self._client.embeddings.create(input=text, model=self._model) + .data[0] + .embedding + ) + + def embed_texts(self, texts: list[str]) -> list[list[float]]: + resp = self._client.embeddings.create(input=texts, model=self._model) + return [item.embedding for item in resp.data] diff --git a/src/lang2sql/integrations/loaders/__init__.py b/src/lang2sql/integrations/loaders/__init__.py new file mode 100644 index 0000000..c5d095f --- /dev/null +++ b/src/lang2sql/integrations/loaders/__init__.py @@ -0,0 +1,3 @@ +from .pdf_ import PDFLoader + +__all__ = ["PDFLoader"] diff --git a/src/lang2sql/integrations/loaders/pdf_.py b/src/lang2sql/integrations/loaders/pdf_.py new file mode 100644 index 0000000..1ddb901 --- /dev/null +++ b/src/lang2sql/integrations/loaders/pdf_.py @@ -0,0 +1,57 @@ +from __future__ import annotations + +from pathlib import Path + +from ...core.catalog import TextDocument +from ...core.exceptions import IntegrationMissingError + +try: + import fitz as _fitz +except ImportError: + _fitz = None # type: ignore[assignment] + + +class PDFLoader: + """ + PDF file → list[TextDocument]. + + Requires pymupdf (fitz) as an optional dependency. + Raises ``IntegrationMissingError`` if not installed. 
+ + Produces one TextDocument per page:: + + id = "{filename}__p{page_number}" (1-indexed) + title = "{filename} page {page_number}" + content = extracted text of that page + source = file path + + Usage:: + + from lang2sql.integrations.loaders import PDFLoader + + docs = PDFLoader().load("report.pdf") + + Installation:: + + pip install pymupdf + """ + + def load(self, path: str) -> list[TextDocument]: + if _fitz is None: + raise IntegrationMissingError( + "pymupdf", + hint="pip install pymupdf", + ) + p = Path(path) + pdf = _fitz.open(str(p)) + docs: list[TextDocument] = [] + for i, page in enumerate(pdf, start=1): + docs.append( + TextDocument( + id=f"{p.stem}__p{i}", + title=f"{p.stem} page {i}", + content=page.get_text(), + source=str(p), + ) + ) + return docs diff --git a/src/lang2sql/integrations/vectorstore/__init__.py b/src/lang2sql/integrations/vectorstore/__init__.py new file mode 100644 index 0000000..bddace4 --- /dev/null +++ b/src/lang2sql/integrations/vectorstore/__init__.py @@ -0,0 +1,3 @@ +from .inmemory_ import InMemoryVectorStore + +__all__ = ["InMemoryVectorStore"] diff --git a/src/lang2sql/integrations/vectorstore/inmemory_.py b/src/lang2sql/integrations/vectorstore/inmemory_.py new file mode 100644 index 0000000..ad0169f --- /dev/null +++ b/src/lang2sql/integrations/vectorstore/inmemory_.py @@ -0,0 +1,50 @@ +from __future__ import annotations + +from ...core.exceptions import IntegrationMissingError + +try: + import numpy as _np +except ImportError: + _np = None # type: ignore[assignment] + + +class InMemoryVectorStore: + """ + Brute-force cosine similarity vector store backed by numpy. + + upsert merges new entries into the existing store (true upsert — no data loss + across multiple calls). Rebuilds the matrix from a dict on each search call, + so duplicate chunk_ids are overwritten rather than duplicated. + + Handles tens of thousands of vectors without issue. No faiss dependency required. 
+ + For larger scale or persistence, use FAISSVectorStore / PGVectorStore (next PR). + For advanced vector stores (Chroma, Qdrant, etc.), implement VectorStorePort directly. + """ + + def __init__(self) -> None: + if _np is None: + raise IntegrationMissingError("numpy", hint="pip install numpy") + self._store: dict[str, list[float]] = {} # chunk_id → vector + + def upsert(self, ids: list[str], vectors: list[list[float]]) -> None: + # Merge into existing store — preserves vectors from previous upsert calls. + # Duplicate ids overwrite the existing entry (true upsert semantics). + for id_, vec in zip(ids, vectors): + self._store[id_] = vec + + def search(self, vector: list[float], k: int) -> list[tuple[str, float]]: + if not self._store: + return [] + + ids = list(self._store.keys()) + matrix = _np.array(list(self._store.values()), dtype=_np.float32) + q = _np.array(vector, dtype=_np.float32) + + norms = _np.linalg.norm(matrix, axis=1) + q_norm = _np.linalg.norm(q) + sims = matrix @ q / (norms * q_norm + 1e-8) + + k = min(k, len(ids)) + top_k = _np.argsort(sims)[::-1][:k] + return [(ids[int(i)], float(sims[i])) for i in top_k] diff --git a/tests/test_components_hybrid_retriever.py b/tests/test_components_hybrid_retriever.py new file mode 100644 index 0000000..5fb7f9f --- /dev/null +++ b/tests/test_components_hybrid_retriever.py @@ -0,0 +1,282 @@ +""" +Tests for HybridRetriever and HybridNL2SQL — 8 cases. + +SmartFakeEmbedding maps marker keywords in text to orthogonal unit vectors: + - "bothdim" → [0.7071, 0.0, 0.7071, 0.0] (found by both BM25 and vector) + - "kwcommon" → [0.0, 1.0, 0.0, 0.0] (positive BM25, zero cosine with query) + - "vdimonly" → [0.0, 0.0, 1.0, 0.0] (zero BM25, positive cosine with query) + - other → [0.0, 0.0, 0.0, 1.0] + +Query "BOTHDIM KWCOMMON" embeds to [0.7071, 0, 0.7071, 0] (has "bothdim"). 
+ - both_table (bothdim): cosine 1.0 → in vector ✓ ; BM25 matches "bothdim"+"kwcommon" ✓ + - kwonly_table(kwcommon): cosine 0.0 → NOT in vector ✓ ; BM25 matches "kwcommon" ✓ + - veconly_table(vdimonly): cosine 0.707 → in vector ✓ ; BM25 score 0 → not returned ✓ +""" + +from __future__ import annotations + +import pytest + +from lang2sql.components.retrieval.hybrid import HybridRetriever +from lang2sql.core.catalog import RetrievalResult +from lang2sql.core.hooks import MemoryHook +from lang2sql.flows.hybrid import HybridNL2SQL + +# --------------------------------------------------------------------------- +# Fakes +# --------------------------------------------------------------------------- + + +class FakeEmbedding: + """Uniform embedding — all texts and queries map to the same unit vector.""" + + def embed_query(self, text: str) -> list[float]: + return [0.5, 0.5, 0.5, 0.5] + + def embed_texts(self, texts: list[str]) -> list[list[float]]: + return [[0.5, 0.5, 0.5, 0.5]] * len(texts) + + +class SmartFakeEmbedding: + """ + Marker-based deterministic embedding for controlled retrieval testing. + + Marker priority (first match wins): + "bothdim" → [√½, 0, √½, 0] — high cosine with query on dims 0 and 2 + "vdimonly" → [0, 0, 1, 0] — matches query dim 2 only + "kwcommon" → [0, 1, 0, 0] — orthogonal to query → cosine 0.0 (excluded) + else → [0, 0, 0, 1] — orthogonal to query → cosine 0.0 (excluded) + + Query "BOTHDIM KWCOMMON" has "bothdim" → embeds to [√½, 0, √½, 0]. 
+ """ + + _SQRT2_INV = 2.0**-0.5 # ≈ 0.7071 + + def _embed(self, text: str) -> list[float]: + t = text.lower() + if "bothdim" in t: + return [self._SQRT2_INV, 0.0, self._SQRT2_INV, 0.0] + if "vdimonly" in t: + return [0.0, 0.0, 1.0, 0.0] + if "kwcommon" in t: + return [0.0, 1.0, 0.0, 0.0] + return [0.0, 0.0, 0.0, 1.0] + + def embed_query(self, text: str) -> list[float]: + return self._embed(text) + + def embed_texts(self, texts: list[str]) -> list[list[float]]: + return [self._embed(t) for t in texts] + + +class FakeLLM: + def __init__(self, response: str = "```sql\nSELECT COUNT(*) FROM orders\n```"): + self._response = response + + def invoke(self, messages: list[dict]) -> str: + return self._response + + +class FakeDB: + def __init__(self, rows: list[dict] | None = None): + self._rows = rows if rows is not None else [{"count": 1}] + + def execute(self, sql: str) -> list[dict]: + return self._rows + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + +CATALOG_SIMPLE = [ + {"name": "orders", "description": "order records", "columns": {"id": "pk"}}, + {"name": "customers", "description": "customer data", "columns": {"id": "pk"}}, +] + +# Three-table catalog designed for deterministic BM25 vs vector split: +# - both_table: found by BOTH BM25 ("bothdim"+"kwcommon" in description) AND vector (cosine 1.0) +# - kwonly_table: found by BM25 only ("kwcommon" matches query token) → vector cosine 0.0 +# - veconly_table:found by vector only (vdimonly → cosine 0.707) → BM25 score 0 +CATALOG_HYBRID = [ + { + "name": "both_table", + "description": "BOTHDIM KWCOMMON data", + "columns": {"id": "pk"}, + }, + { + "name": "kwonly_table", + "description": "KWCOMMON unique info", + "columns": {"id": "pk"}, + }, + { + "name": "veconly_table", + "description": "VDIMONLY specific", + "columns": {"id": "pk"}, + }, +] + +DOCS = [ + { + "id": "revenue_doc", + "title": 
"Revenue Rules", + "content": "Revenue is net sales minus returns.", + "source": "docs/revenue.md", + } +] + +# Query that: +# - BM25-matches "both_table" (has "bothdim" and "kwcommon") and "kwonly_table" (has "kwcommon") +# - Embeds to [√½, 0, √½, 0] via SmartFakeEmbedding (triggered by "bothdim") +# - cosine("both_table") = 1.0, cosine("kwonly_table") = 0.0, cosine("veconly_table") = 0.707 +HYBRID_QUERY = "BOTHDIM KWCOMMON" + + +# --------------------------------------------------------------------------- +# 1. Return type is RetrievalResult +# --------------------------------------------------------------------------- + + +def test_hybrid_returns_retrieval_result(): + retriever = HybridRetriever( + catalog=CATALOG_SIMPLE, + embedding=FakeEmbedding(), + ) + result = retriever("orders") + assert isinstance(result, RetrievalResult) + assert isinstance(result.schemas, list) + assert isinstance(result.context, list) + + +# --------------------------------------------------------------------------- +# 2. RRF combines both retrievers — all three match types present +# --------------------------------------------------------------------------- + + +def test_hybrid_rrf_combines_both_retrievers(): + """keyword-only, vector-only, and both-matched tables must all appear in results.""" + retriever = HybridRetriever( + catalog=CATALOG_HYBRID, + embedding=SmartFakeEmbedding(), + top_n=3, + ) + result = retriever(HYBRID_QUERY) + names = {s["name"] for s in result.schemas} + + assert "both_table" in names, "both_table (found by both) must be in results" + assert "kwonly_table" in names, "kwonly_table (BM25-only) must be in results" + assert "veconly_table" in names, "veconly_table (vector-only) must be in results" + + +# --------------------------------------------------------------------------- +# 3. 
Tables found by both retrievers rank higher than single-retriever tables +# --------------------------------------------------------------------------- + + +def test_hybrid_rrf_ranks_overlap_higher(): + """A table found by both retrievers must rank higher than any single-retriever table.""" + retriever = HybridRetriever( + catalog=CATALOG_HYBRID, + embedding=SmartFakeEmbedding(), + top_n=3, + ) + result = retriever(HYBRID_QUERY) + names = [s["name"] for s in result.schemas] + + assert names[0] == "both_table", "both_table (highest RRF score) must be rank 1" + assert names.index("both_table") < names.index("kwonly_table") + assert names.index("both_table") < names.index("veconly_table") + + +# --------------------------------------------------------------------------- +# 4. top_n limits schemas count +# --------------------------------------------------------------------------- + + +def test_hybrid_top_n_limits_schemas(): + retriever = HybridRetriever( + catalog=CATALOG_HYBRID, + embedding=SmartFakeEmbedding(), + top_n=2, + ) + result = retriever(HYBRID_QUERY) + assert len(result.schemas) <= 2 + + +# --------------------------------------------------------------------------- +# 5. context comes only from VectorRetriever +# --------------------------------------------------------------------------- + + +def test_hybrid_context_from_vector(): + """Context must come from VectorRetriever only (document chunk text).""" + retriever = HybridRetriever( + catalog=CATALOG_SIMPLE, + embedding=FakeEmbedding(), + documents=DOCS, + top_n=5, + ) + result = retriever("revenue rules") + + assert isinstance(result.context, list) + assert len(result.context) > 0, "document text must appear in context" + assert any("Revenue" in c for c in result.context) + + +# --------------------------------------------------------------------------- +# 6. 
Hook records start/end events for HybridRetriever itself
+# ---------------------------------------------------------------------------
+
+
+def test_hybrid_hook_events():
+    hook = MemoryHook()
+    retriever = HybridRetriever(
+        catalog=CATALOG_SIMPLE,
+        embedding=FakeEmbedding(),
+        hook=hook,
+    )
+    retriever("orders")
+
+    hybrid_events = [e for e in hook.snapshot() if e.component == "HybridRetriever"]
+    assert any(e.phase == "start" for e in hybrid_events)
+    assert any(e.phase == "end" for e in hybrid_events)
+    end_event = next(e for e in hybrid_events if e.phase == "end")
+    assert end_event.duration_ms is not None
+    assert end_event.duration_ms >= 0.0
+
+
+# ---------------------------------------------------------------------------
+# 7. HybridNL2SQL end-to-end pipeline
+# ---------------------------------------------------------------------------
+
+
+def test_hybrid_nl2sql_pipeline():
+    """HybridNL2SQL end-to-end with FakeLLM + FakeDB."""
+    rows = [{"count": 7}]
+    pipeline = HybridNL2SQL(
+        catalog=CATALOG_SIMPLE,
+        llm=FakeLLM(),
+        db=FakeDB(rows),
+        embedding=FakeEmbedding(),
+    )
+    result = pipeline.run("How many orders last month?")
+    assert result == rows
+
+
+# ---------------------------------------------------------------------------
+# 8. _rrf_merge deduplication — same table in both → score combined, no duplicate
+# ---------------------------------------------------------------------------
+
+
+def test_rrf_merge_deduplication():
+    """Same table in both retrievers → scores combined, no duplicate in results."""
+    retriever = HybridRetriever(
+        catalog=CATALOG_SIMPLE,
+        embedding=FakeEmbedding(),
+    )
+    entry = {"name": "orders", "description": "order data", "columns": {}}
+
+    merged = retriever._rrf_merge([entry], [entry])
+
+    assert len(merged) == 1, "duplicate table must appear only once"
+    assert merged[0]["name"] == "orders"
diff --git a/tests/test_components_loaders.py b/tests/test_components_loaders.py
new file mode 100644
index 0000000..5b0e525
--- /dev/null
+++ b/tests/test_components_loaders.py
@@ -0,0 +1,141 @@
+"""
+Tests for MarkdownLoader, PlainTextLoader, DirectoryLoader, DocumentLoaderPort — 8 cases.
+
+Uses pytest tmp_path fixture to create temporary files in isolation.
+"""
+
+from __future__ import annotations
+
+import pytest
+
+from lang2sql.components.loaders import DirectoryLoader, MarkdownLoader, PlainTextLoader
+from lang2sql.core.ports import DocumentLoaderPort
+
+# ---------------------------------------------------------------------------
+# 1. MarkdownLoader — single file: TextDocument fields are correct
+# ---------------------------------------------------------------------------
+
+
+def test_markdown_loader_single_file(tmp_path):
+    md_file = tmp_path / "revenue.md"
+    md_file.write_text(
+        "# Revenue Definition\n\nRevenue is net sales.", encoding="utf-8"
+    )
+
+    docs = MarkdownLoader().load(str(md_file))
+
+    assert len(docs) == 1
+    doc = docs[0]
+    assert doc["id"] == "revenue"
+    assert doc["title"] == "Revenue Definition"
+    assert "Revenue is net sales" in doc["content"]
+    assert doc["source"] == str(md_file)
+
+
+# ---------------------------------------------------------------------------
+# 2. MarkdownLoader — directory: returns one doc per .md file
+# ---------------------------------------------------------------------------
+
+
+def test_markdown_loader_directory(tmp_path):
+    (tmp_path / "a.md").write_text("# A\ncontent a", encoding="utf-8")
+    (tmp_path / "b.md").write_text("# B\ncontent b", encoding="utf-8")
+    (tmp_path / "notes.txt").write_text("plain text", encoding="utf-8")  # ignored
+
+    docs = MarkdownLoader().load(str(tmp_path))
+
+    assert len(docs) == 2
+    ids = {d["id"] for d in docs}
+    assert ids == {"a", "b"}
+
+
+# ---------------------------------------------------------------------------
+# 3. MarkdownLoader — title extracted from first # heading
+# ---------------------------------------------------------------------------
+
+
+def test_markdown_loader_title_from_heading(tmp_path):
+    md_file = tmp_path / "doc.md"
+    md_file.write_text("Some intro\n# My Title\nBody text.", encoding="utf-8")
+
+    docs = MarkdownLoader().load(str(md_file))
+
+    # The first # heading (even if not the first line) is used as title
+    assert docs[0]["title"] == "My Title"
+
+
+# ---------------------------------------------------------------------------
+# 4. MarkdownLoader — no heading → title falls back to filename stem
+# ---------------------------------------------------------------------------
+
+
+def test_markdown_loader_no_heading(tmp_path):
+    md_file = tmp_path / "quarterly_report.md"
+    md_file.write_text("Just some content without a heading.", encoding="utf-8")
+
+    docs = MarkdownLoader().load(str(md_file))
+
+    assert docs[0]["title"] == "quarterly_report"
+
+
+# ---------------------------------------------------------------------------
+# 5. PlainTextLoader — single file: content is correct
+# ---------------------------------------------------------------------------
+
+
+def test_plaintext_loader_single_file(tmp_path):
+    txt_file = tmp_path / "notes.txt"
+    txt_file.write_text("line one\nline two\n", encoding="utf-8")
+
+    docs = PlainTextLoader().load(str(txt_file))
+
+    assert len(docs) == 1
+    doc = docs[0]
+    assert doc["id"] == "notes"
+    assert doc["title"] == "notes"
+    assert "line one" in doc["content"]
+    assert doc["source"] == str(txt_file)
+
+
+# ---------------------------------------------------------------------------
+# 6. DirectoryLoader — dispatches by extension
+# ---------------------------------------------------------------------------
+
+
+def test_directory_loader_dispatches_by_extension(tmp_path):
+    (tmp_path / "guide.md").write_text("# Guide\nMarkdown content.", encoding="utf-8")
+    (tmp_path / "data.txt").write_text("plain text content", encoding="utf-8")
+
+    docs = DirectoryLoader(str(tmp_path)).load()
+
+    ids = {d["id"] for d in docs}
+    assert "guide" in ids  # loaded by MarkdownLoader
+    assert "data" in ids  # loaded by PlainTextLoader
+
+
+# ---------------------------------------------------------------------------
+# 7. DirectoryLoader — skips unknown extensions
+# ---------------------------------------------------------------------------
+
+
+def test_directory_loader_skips_unknown_extension(tmp_path):
+    (tmp_path / "script.py").write_text("print('hello')", encoding="utf-8")
+    (tmp_path / "data.csv").write_text("a,b,c\n1,2,3\n", encoding="utf-8")
+    (tmp_path / "readme.md").write_text("# Readme\ncontent", encoding="utf-8")
+
+    docs = DirectoryLoader(str(tmp_path)).load()
+
+    ids = {d["id"] for d in docs}
+    assert "readme" in ids  # .md is loaded
+    assert "script" not in ids  # .py is skipped
+    assert "data" not in ids  # .csv is skipped (not in default loaders)
+
+
+# ---------------------------------------------------------------------------
+# 8. Protocol check — MarkdownLoader and PlainTextLoader satisfy DocumentLoaderPort
+# ---------------------------------------------------------------------------
+
+
+def test_document_loader_port_protocol():
+    assert isinstance(MarkdownLoader(), DocumentLoaderPort)
+    assert isinstance(PlainTextLoader(), DocumentLoaderPort)
diff --git a/tests/test_components_vector_retriever.py b/tests/test_components_vector_retriever.py
new file mode 100644
index 0000000..39d2515
--- /dev/null
+++ b/tests/test_components_vector_retriever.py
@@ -0,0 +1,504 @@
+"""
+Tests for VectorRetriever, CatalogChunker, RecursiveCharacterChunker — 16 cases.
+
+Mock strategy:
+- FakeVectorStore + FakeEmbedding for tests that control search results explicitly.
+- InMemoryVectorStore + FakeEmbedding for tests that verify actual storage/merge behavior
+  (tests 10, from_chunks_add_incremental). FakeVectorStore.search() returns pre-configured
+  results, so it cannot catch real storage bugs.
+"""
+
+import pytest
+
+from lang2sql.components.retrieval.chunker import (
+    CatalogChunker,
+    RecursiveCharacterChunker,
+)
+from lang2sql.components.retrieval.vector import VectorRetriever
+from lang2sql.core.catalog import RetrievalResult
+from lang2sql.core.hooks import MemoryHook
+from lang2sql.flows.baseline import SequentialFlow
+from lang2sql.integrations.vectorstore import InMemoryVectorStore
+
+# ---------------------------------------------------------------------------
+# Fakes
+# ---------------------------------------------------------------------------
+
+
+class FakeVectorStore:
+    """
+    Controlled search results for unit tests.
+    search() returns whatever was passed to __init__(results=...).
+    upsert() implements merge semantics to match InMemoryVectorStore contract.
+    Do NOT use for tests that verify storage correctness — use InMemoryVectorStore instead.
+    """
+
+    def __init__(self, results=None):
+        self._results = results or []
+        self.upserted: dict = {}
+
+    def search(self, vector, k):
+        return self._results[:k]
+
+    def upsert(self, ids, vectors):
+        # merge semantics — consistent with VectorStorePort contract
+        for id_, vec in zip(ids, vectors):
+            self.upserted[id_] = vec
+
+
+class FakeEmbedding:
+    def embed_query(self, text):
+        return [0.0] * 4
+
+    def embed_texts(self, texts):
+        return [[0.0] * 4] * len(texts)
+
+
+# ---------------------------------------------------------------------------
+# Shared fixtures
+# ---------------------------------------------------------------------------
+
+CATALOG = [
+    {
+        "name": "orders",
+        "description": "Order information table",
+        "columns": {"order_id": "Unique order ID", "amount": "Order amount"},
+    }
+]
+
+DOCS = [
+    {
+        "id": "biz_rules",
+        "title": "Revenue Definition",
+        "content": "Revenue is defined as net sales excluding returns.",
+        "source": "docs/biz_rules.md",
+    }
+]
+
+
+def _make_catalog_registry():
+    """Registry pre-populated with one catalog chunk."""
+    return {
+        "orders__0": {
+            "chunk_id": "orders__0",
+            "text": "orders: Order information table",
+            "source_type": "catalog",
+            "source_id": "orders",
+            "chunk_index": 0,
+            "metadata": CATALOG[0],
+        }
+    }
+
+
+def _make_doc_registry():
+    """Registry pre-populated with one document chunk."""
+    return {
+        "biz_rules__0": {
+            "chunk_id": "biz_rules__0",
+            "text": "Revenue Definition: Revenue is defined as net sales.",
+            "source_type": "document",
+            "source_id": "biz_rules",
+            "chunk_index": 0,
+            "metadata": {
+                "id": "biz_rules",
+                "title": "Revenue Definition",
+                "source": "",
+            },
+        }
+    }
+
+
+# ---------------------------------------------------------------------------
+# 1. Catalog chunk deduplication
+# ---------------------------------------------------------------------------
+
+
+def test_catalog_chunk_dedup():
+    """Multiple chunks from the same table → only 1 CatalogEntry returned."""
+    registry = {
+        "orders__0": {
+            "chunk_id": "orders__0",
+            "text": "orders: Order table",
+            "source_type": "catalog",
+            "source_id": "orders",
+            "chunk_index": 0,
+            "metadata": CATALOG[0],
+        },
+        "orders__col_1": {
+            "chunk_id": "orders__col_1",
+            "text": "orders columns: order_id amount",
+            "source_type": "catalog",
+            "source_id": "orders",
+            "chunk_index": 1,
+            "metadata": CATALOG[0],
+        },
+    }
+    store = FakeVectorStore(results=[("orders__0", 0.9), ("orders__col_1", 0.8)])
+    retriever = VectorRetriever(
+        vectorstore=store, embedding=FakeEmbedding(), registry=registry
+    )
+    result = retriever("order amount")
+
+    assert len(result.schemas) == 1
+    assert result.schemas[0]["name"] == "orders"
+
+
+# ---------------------------------------------------------------------------
+# 2. Document chunk in context
+# ---------------------------------------------------------------------------
+
+
+def test_document_chunk_in_context():
+    """Document chunk appears in RetrievalResult.context."""
+    registry = _make_doc_registry()
+    store = FakeVectorStore(results=[("biz_rules__0", 0.8)])
+    retriever = VectorRetriever(
+        vectorstore=store, embedding=FakeEmbedding(), registry=registry
+    )
+    result = retriever("revenue definition")
+
+    assert len(result.context) == 1
+    assert "Revenue" in result.context[0]
+
+
+# ---------------------------------------------------------------------------
+# 3. Empty registry
+# ---------------------------------------------------------------------------
+
+
+def test_empty_registry_returns_empty_result():
+    """Empty registry → empty RetrievalResult."""
+    store = FakeVectorStore(results=[("orders__0", 0.9)])
+    retriever = VectorRetriever(
+        vectorstore=store, embedding=FakeEmbedding(), registry={}
+    )
+    result = retriever("any query")
+
+    assert result.schemas == []
+    assert result.context == []
+
+
+# ---------------------------------------------------------------------------
+# 4. Score threshold filtering
+# ---------------------------------------------------------------------------
+
+
+def test_score_threshold_filters_results():
+    """Chunks at or below threshold are excluded."""
+    registry = _make_catalog_registry()
+    store = FakeVectorStore(results=[("orders__0", 0.3)])
+    retriever = VectorRetriever(
+        vectorstore=store,
+        embedding=FakeEmbedding(),
+        registry=registry,
+        score_threshold=0.3,  # score must be strictly greater than threshold
+    )
+    result = retriever("orders")
+
+    assert result.schemas == []
+
+
+# ---------------------------------------------------------------------------
+# 5. top_n limits schemas
+# ---------------------------------------------------------------------------
+
+
+def test_top_n_limits_schemas():
+    """schemas capped at top_n."""
+    registry = {
+        f"table_{i}__0": {
+            "chunk_id": f"table_{i}__0",
+            "text": f"table_{i}",
+            "source_type": "catalog",
+            "source_id": f"table_{i}",
+            "chunk_index": 0,
+            "metadata": {"name": f"table_{i}", "description": "", "columns": {}},
+        }
+        for i in range(10)
+    }
+    store = FakeVectorStore(
+        results=[(f"table_{i}__0", 0.9 - i * 0.05) for i in range(10)]
+    )
+    retriever = VectorRetriever(
+        vectorstore=store, embedding=FakeEmbedding(), registry=registry, top_n=3
+    )
+    result = retriever("table")
+
+    assert len(result.schemas) <= 3
+
+
+# ---------------------------------------------------------------------------
+# 6. top_n limits context
+# ---------------------------------------------------------------------------
+
+
+def test_top_n_limits_context():
+    """context capped at top_n."""
+    registry = {
+        f"doc_{i}__0": {
+            "chunk_id": f"doc_{i}__0",
+            "text": f"doc chunk {i}",
+            "source_type": "document",
+            "source_id": f"doc_{i}",
+            "chunk_index": 0,
+            "metadata": {"id": f"doc_{i}", "title": "", "source": ""},
+        }
+        for i in range(10)
+    }
+    store = FakeVectorStore(
+        results=[(f"doc_{i}__0", 0.9 - i * 0.05) for i in range(10)]
+    )
+    retriever = VectorRetriever(
+        vectorstore=store, embedding=FakeEmbedding(), registry=registry, top_n=3
+    )
+    result = retriever("doc")
+
+    assert len(result.context) <= 3
+
+
+# ---------------------------------------------------------------------------
+# 7. Hook events
+# ---------------------------------------------------------------------------
+
+
+def test_hook_start_end_events():
+    """MemoryHook records start/end events + duration_ms."""
+    hook = MemoryHook()
+    store = FakeVectorStore()
+    retriever = VectorRetriever(
+        vectorstore=store, embedding=FakeEmbedding(), registry={}, hook=hook
+    )
+    retriever("test query")
+
+    assert len(hook.events) == 2
+    assert hook.events[0].phase == "start"
+    assert hook.events[1].phase == "end"
+    assert hook.events[1].duration_ms is not None
+    assert hook.events[1].duration_ms >= 0.0
+
+
+# ---------------------------------------------------------------------------
+# 8. from_chunks() — catalog chunks populate registry
+# ---------------------------------------------------------------------------
+
+
+def test_from_chunks_catalog_populates_registry():
+    """from_chunks(catalog_chunks) populates registry with catalog source_type."""
+    chunks = CatalogChunker().split(CATALOG)
+    retriever = VectorRetriever.from_chunks(chunks, embedding=FakeEmbedding())
+
+    assert len(retriever._registry) > 0
+    for chunk in retriever._registry.values():
+        assert chunk["source_type"] == "catalog"
+        assert chunk["source_id"] == "orders"
+
+
+# ---------------------------------------------------------------------------
+# 9. from_chunks() — document chunks populate registry
+# ---------------------------------------------------------------------------
+
+
+def test_from_chunks_doc_populates_registry():
+    """from_chunks(doc_chunks) populates registry with document source_type."""
+    chunks = RecursiveCharacterChunker().split(DOCS)
+    retriever = VectorRetriever.from_chunks(chunks, embedding=FakeEmbedding())
+
+    assert len(retriever._registry) > 0
+    for chunk in retriever._registry.values():
+        assert chunk["source_type"] == "document"
+        assert chunk["source_id"] == "biz_rules"
+
+
+# ---------------------------------------------------------------------------
+# 10. InMemoryVectorStore merge — catalog survives after doc chunks added
+# ---------------------------------------------------------------------------
+
+
+def test_from_chunks_preserves_catalog_after_doc_run():
+    """
+    catalog vectors survive when doc chunks are combined.
+    Uses InMemoryVectorStore to verify real merge behavior.
+    """
+    store = InMemoryVectorStore()
+    catalog_chunks = CatalogChunker().split(CATALOG)
+    retriever = VectorRetriever.from_chunks(
+        catalog_chunks, embedding=FakeEmbedding(), vectorstore=store
+    )
+    catalog_chunk_ids = set(retriever._registry.keys())
+    assert len(catalog_chunk_ids) > 0
+
+    doc_chunks = RecursiveCharacterChunker().split(DOCS)
+    retriever.add(doc_chunks)
+
+    for chunk_id in catalog_chunk_ids:
+        assert chunk_id in store._store, f"catalog chunk '{chunk_id}' lost after add()"
+
+
+# ---------------------------------------------------------------------------
+# 11. CatalogChunker — column groups
+# ---------------------------------------------------------------------------
+
+
+def test_catalog_chunker_column_groups():
+    """25 columns → CatalogChunker (max 20) produces at least 2 chunks beyond header."""
+    entry = {
+        "name": "big_table",
+        "description": "Large table",
+        "columns": {f"col_{i}": f"column {i}" for i in range(25)},
+    }
+    chunker = CatalogChunker(max_columns_per_chunk=20)
+    chunks = chunker.chunk(entry)
+
+    # chunk 0 = header, chunk 1 = first 20 cols, chunk 2 = remaining 5 cols
+    assert len(chunks) >= 3
+    assert chunks[0]["chunk_id"] == "big_table__0"
+    assert all(c["source_type"] == "catalog" for c in chunks)
+    assert all(c["metadata"]["name"] == "big_table" for c in chunks)
+
+
+# ---------------------------------------------------------------------------
+# 12. RecursiveCharacterChunker — respects chunk_size
+# ---------------------------------------------------------------------------
+
+
+def test_recursive_chunker_respects_chunk_size():
+    """Text exceeding chunk_size is split into multiple chunks."""
+    chunker = RecursiveCharacterChunker(chunk_size=50, chunk_overlap=0)
+    doc = {
+        "id": "doc1",
+        "title": "",
+        "content": "A" * 50 + "\n\n" + "B" * 50,
+        "source": "",
+    }
+    chunks = chunker.chunk(doc)
+
+    assert len(chunks) >= 2
+    for chunk in chunks:
+        # title prefix is empty so raw text length should respect chunk_size
+        assert len(chunk["text"]) <= 50 + 10  # small tolerance for separator
+
+
+# ---------------------------------------------------------------------------
+# 13. RecursiveCharacterChunker — overlap
+# ---------------------------------------------------------------------------
+
+
+def test_recursive_chunker_overlap():
+    """Second chunk contains tail characters of the first chunk."""
+    overlap = 20
+    chunker = RecursiveCharacterChunker(chunk_size=40, chunk_overlap=overlap)
+    long_content = "Hello world this is a test. " * 10
+    doc = {"id": "d1", "title": "", "content": long_content, "source": ""}
+    chunks = chunker.chunk(doc)
+
+    if len(chunks) >= 2:
+        tail_of_first = chunks[0]["text"][-overlap:]
+        assert tail_of_first in chunks[1]["text"]
+
+
+# ---------------------------------------------------------------------------
+# 14. from_sources() — builds retriever with non-empty registry
+# ---------------------------------------------------------------------------
+
+
+def test_from_sources_builds_retriever():
+    """from_sources() returns a VectorRetriever with non-empty registry."""
+    retriever = VectorRetriever.from_sources(
+        catalog=CATALOG,
+        embedding=FakeEmbedding(),
+    )
+
+    assert isinstance(retriever, VectorRetriever)
+    assert len(retriever._registry) > 0
+
+
+# ---------------------------------------------------------------------------
+# 15. from_sources() + add() — incremental indexing
+# ---------------------------------------------------------------------------
+
+
+def test_from_sources_add_incremental():
+    """retriever.add(chunks) adds chunks without losing existing catalog chunks."""
+    retriever = VectorRetriever.from_sources(
+        catalog=CATALOG,
+        embedding=FakeEmbedding(),
+    )
+    initial_ids = set(retriever._registry.keys())
+    assert len(initial_ids) > 0
+
+    doc_chunks = RecursiveCharacterChunker().split(DOCS)
+    retriever.add(doc_chunks)
+
+    final_ids = set(retriever._registry.keys())
+    # new doc chunks were added
+    assert len(final_ids) > len(initial_ids)
+    # original catalog chunks are still present
+    for chunk_id in initial_ids:
+        assert chunk_id in final_ids, f"catalog chunk '{chunk_id}' lost after add()"
+
+
+# ---------------------------------------------------------------------------
+# 16. from_chunks() — empty chunks → empty result
+# ---------------------------------------------------------------------------
+
+
+def test_from_chunks_empty():
+    """from_chunks([]) → retriever with empty registry returns empty result."""
+    retriever = VectorRetriever.from_chunks([], embedding=FakeEmbedding())
+
+    assert retriever._registry == {}
+    result = retriever("any query")
+    assert result.schemas == []
+    assert result.context == []
+
+
+# ---------------------------------------------------------------------------
+# 17. from_chunks() — mixed catalog + doc chunks
+# ---------------------------------------------------------------------------
+
+
+def test_from_chunks_mixed_catalog_and_docs():
+    """from_chunks with catalog + doc chunks → both source_types in registry."""
+    chunks = CatalogChunker().split(CATALOG) + RecursiveCharacterChunker().split(DOCS)
+    retriever = VectorRetriever.from_chunks(chunks, embedding=FakeEmbedding())
+
+    source_types = {c["source_type"] for c in retriever._registry.values()}
+    assert "catalog" in source_types
+    assert "document" in source_types
+
+
+# ---------------------------------------------------------------------------
+# 18. from_chunks() + add() — incremental after from_chunks
+# ---------------------------------------------------------------------------
+
+
+def test_from_chunks_add_incremental():
+    """from_chunks() followed by add(more_chunks) preserves original chunks."""
+    store = InMemoryVectorStore()
+    catalog_chunks = CatalogChunker().split(CATALOG)
+    retriever = VectorRetriever.from_chunks(
+        catalog_chunks, embedding=FakeEmbedding(), vectorstore=store
+    )
+    initial_ids = set(retriever._registry.keys())
+
+    doc_chunks = RecursiveCharacterChunker().split(DOCS)
+    retriever.add(doc_chunks)
+
+    final_ids = set(retriever._registry.keys())
+    assert len(final_ids) > len(initial_ids)
+    for chunk_id in initial_ids:
+        assert chunk_id in final_ids, f"chunk '{chunk_id}' lost after add()"
+
+
+# ---------------------------------------------------------------------------
+# 19. CatalogChunker.split() — batch convenience method
+# ---------------------------------------------------------------------------
+
+
+def test_catalog_chunker_split_batch():
+    """CatalogChunker.split(catalog) returns same chunks as calling chunk() per entry."""
+    chunker = CatalogChunker()
+    by_split = chunker.split(CATALOG)
+    by_chunk = [c for entry in CATALOG for c in chunker.chunk(entry)]
+
+    assert [c["chunk_id"] for c in by_split] == [c["chunk_id"] for c in by_chunk]