diff --git a/docs/BaseComponent_ko.md b/docs/BaseComponent_ko.md index c98635f..657a79c 100644 --- a/docs/BaseComponent_ko.md +++ b/docs/BaseComponent_ko.md @@ -190,13 +190,15 @@ retriever = FunctionalComponent(my_retriever, name="MyRetriever", hook=hook) ```python from lang2sql.core.hooks import MemoryHook +from lang2sql.flows.baseline import SequentialFlow + hook = MemoryHook() -flow = BaselineFlow(steps=[...], hook=hook) # 또는 컴포넌트마다 hook 주입 -out = flow.run_query("지난달 매출") +flow = SequentialFlow(steps=[...], hook=hook) # 또는 컴포넌트마다 hook 주입 +out = flow.run("지난달 매출") # 이벤트 확인 -for e in hook.events: +for e in hook.snapshot(): print(e.phase, e.component, e.duration_ms, e.error) ``` diff --git a/docs/Hook_and_exception_ko.md b/docs/Hook_and_exception_ko.md index c5764e5..b2c650a 100644 --- a/docs/Hook_and_exception_ko.md +++ b/docs/Hook_and_exception_ko.md @@ -111,16 +111,16 @@ class MemoryHook: #### MemoryHook 사용 예시 -```py +```python from lang2sql.core.hooks import MemoryHook -from lang2sql.flows.baseline import BaselineFlow +from lang2sql.flows.baseline import SequentialFlow hook = MemoryHook() -flow = BaselineFlow(steps=[...], hook=hook) +flow = SequentialFlow(steps=[...], hook=hook) -out = flow.run_query("지난달 매출") +out = flow.run("지난달 매출") -for e in hook.events: +for e in hook.snapshot(): print(e.name, e.phase, e.component, e.duration_ms, e.error) ``` diff --git a/docs/tutorials/getting-started-without-datahub.md b/docs/tutorials/getting-started-without-datahub.md index 0792b6a..d24d0d3 100644 --- a/docs/tutorials/getting-started-without-datahub.md +++ b/docs/tutorials/getting-started-without-datahub.md @@ -122,19 +122,53 @@ print(f"FAISS index saved to: {OUTPUT_DIR}/catalog.faiss") ### 4) 실행 +v2 CLI는 외부 벡터 인덱스 경로를 인수로 받지 않습니다. +앞서 생성한 FAISS 인덱스를 활용하려면 Python API로 파이프라인을 직접 구성합니다. + +```python +# run_query.py +import os +from dotenv import load_dotenv +from lang2sql import CatalogChunker, VectorRetriever +from lang2sql.integrations.db import SQLAlchemyDB +from lang2sql.integrations.embedding import OpenAIEmbedding +from lang2sql.integrations.llm import OpenAILLM +from lang2sql.integrations.vectorstore import FAISSVectorStore +from lang2sql.flows.hybrid import HybridNL2SQL + +load_dotenv() + +INDEX_DIR = "./dev/table_info_db" +embedding = OpenAIEmbedding( + model=os.getenv("OPEN_AI_EMBEDDING_MODEL", "text-embedding-3-large"), + api_key=os.getenv("OPEN_AI_KEY"), +) + +# FAISS 인덱스 로드 후 파이프라인 구성 +store = FAISSVectorStore.load(f"{INDEX_DIR}/catalog.faiss") + +pipeline = HybridNL2SQL( + catalog=[], # FAISS에 이미 인덱싱돼 있으므로 빈 리스트 + llm=OpenAILLM(model=os.getenv("OPEN_AI_LLM_MODEL", "gpt-4o"), api_key=os.getenv("OPEN_AI_KEY")), + db=SQLAlchemyDB(os.getenv("DB_URL", "sqlite:///sample.db")), + embedding=embedding, + db_dialect=os.getenv("DB_TYPE", "sqlite"), +) + +rows = pipeline.run("주문 수를 집계하는 SQL을 만들어줘") +print(rows) +``` + +Streamlit UI: + ```bash -# Streamlit UI lang2sql run-streamlit +``` -# CLI 예시 (FAISS 인덱스 사용) -lang2sql query "주문 수를 집계하는 SQL을 만들어줘" \ - --vectordb-type faiss \ - --vectordb-location ./dev/table_info_db +CLI (카탈로그 없이 baseline만 가능): -# CLI 예시 (pgvector) -lang2sql query "주문 수를 집계하는 SQL을 만들어줘" \ - --vectordb-type pgvector \ - --vectordb-location "postgresql://pgvector:pgvector@localhost:5432/postgres" +```bash +lang2sql query "주문 수를 집계해줘" --flow baseline --dialect sqlite ``` ### 5) (선택) pgvector로 적재하기 @@ -229,4 +263,3 @@ VectorRetriever.from_chunks( print(f"pgvector collection populated: {TABLE}") ``` -주의: FAISS 디렉토리 또는 pgvector 컬렉션이 없으면 현재 코드는 DataHub에서 메타데이터를 가져와 인덱스를 생성하려고 시도합니다. DataHub를 사용하지 않는 경우 위 절차로 사전에 VectorDB를 만들어 두세요. diff --git a/docs/tutorials/v2-complete-tutorial.md b/docs/tutorials/v2-complete-tutorial.md index 707440b..1f2cd06 100644 --- a/docs/tutorials/v2-complete-tutorial.md +++ b/docs/tutorials/v2-complete-tutorial.md @@ -20,6 +20,7 @@ 5-1. 샘플 문서 자동 생성 6. 가장 쉬운 로컬 스모크 테스트 (API 키 없이) 7. BaselineNL2SQL 기본 사용 (KeywordRetriever) +7-1. DB 탐색: SQLAlchemyExplorer 8. 실제 LLM 연결 (OpenAI / Anthropic) 9. VectorRetriever 기초 (빠른 시작) 10. 문서 파싱: MarkdownLoader / PlainTextLoader / DirectoryLoader / PDFLoader @@ -232,6 +233,99 @@ print(rows) --- +## 7-1) DB 탐색: SQLAlchemyExplorer + +LLM에게 넘길 스키마 정보가 필요하거나, 처음 보는 DB를 손으로 살펴볼 때 사용합니다. +카탈로그를 미리 구축하지 않아도 DDL + 샘플 데이터를 바로 꺼내볼 수 있습니다. + +### 기본 사용 + +```python +from lang2sql import build_explorer_from_url + +exp = build_explorer_from_url("sqlite:///sample.db") + +# 1) 어떤 테이블이 있는지 +print(exp.list_tables()) +# ['customers', 'orders', ...] + +# 2) 테이블 DDL — CREATE TABLE 원문 +print(exp.get_ddl("orders")) +# CREATE TABLE orders ( +# id INTEGER PRIMARY KEY, +# customer_id INTEGER NOT NULL REFERENCES customers(id), +# amount REAL, +# status TEXT DEFAULT 'pending' +# ) + +# 3) 실제 샘플 데이터 (기본 5행) +print(exp.sample_data("orders")) +# [{'id': 1, 'customer_id': 1, 'amount': 99.9, 'status': 'shipped'}, ...] + +# 4) 커스텀 읽기 전용 질의 +print(exp.execute_read_only("SELECT status, COUNT(*) AS cnt FROM orders GROUP BY status")) +# [{'status': 'pending', 'cnt': 3}, {'status': 'shipped', 'cnt': 2}] +``` + +### 전체 테이블 한 번에 둘러보기 + +```python +from lang2sql import build_explorer_from_url + +exp = build_explorer_from_url("sqlite:///sample.db") + +for table in exp.list_tables(): + print(f"\n=== {table} ===") + print(exp.get_ddl(table)) + rows = exp.sample_data(table, limit=2) + print("샘플:", rows) +``` + +### PostgreSQL / MySQL 연결 + +URL만 바꾸면 됩니다. + +```python +from lang2sql import build_explorer_from_url + +# PostgreSQL +exp = build_explorer_from_url("postgresql://user:password@localhost:5432/mydb") + +# MySQL +exp = build_explorer_from_url("mysql+pymysql://user:password@localhost:3306/mydb") + +# schema 지정 (schema 파라미터) +exp = build_explorer_from_url("postgresql://user:pass@host/db", schema="analytics") +print(exp.list_tables()) # analytics 스키마 테이블만 +``` + +### 기존 SQLAlchemyDB engine 재사용 + +연결 풀을 따로 만들지 않고 공유할 수 있습니다. + +```python +from lang2sql.integrations.db import SQLAlchemyDB, SQLAlchemyExplorer + +db = SQLAlchemyDB("sqlite:///sample.db") +exp = SQLAlchemyExplorer.from_engine(db._engine) + +# db는 SQL 실행, exp는 탐색 — 같은 연결 풀 공유 +rows = db.execute("SELECT COUNT(*) AS cnt FROM orders") +ddl = exp.get_ddl("orders") +``` + +### 쓰기 구문은 거부됩니다 + +```python +exp.execute_read_only("DROP TABLE orders") +# ValueError: Write operations not allowed: 'DROP TABLE orders' + +exp.execute_read_only("INSERT INTO orders VALUES (99, 1, 0, 'test')") +# ValueError: Write operations not allowed: 'INSERT INTO orders ...' +``` + +--- + ## 8) 실제 LLM 연결 (OpenAI / Anthropic) LLM 백엔드는 교체 가능합니다. diff --git a/src/lang2sql/__init__.py b/src/lang2sql/__init__.py index 2781ba9..66811de 100644 --- a/src/lang2sql/__init__.py +++ b/src/lang2sql/__init__.py @@ -1,4 +1,9 @@ -from .factory import build_db_from_env, build_embedding_from_env, build_llm_from_env +from .factory import ( + build_db_from_env, + build_embedding_from_env, + build_explorer_from_url, + build_llm_from_env, +) from .components.enrichment.context_enricher import ContextEnricher from .components.enrichment.question_profiler import QuestionProfiler from .components.execution.sql_executor import SQLExecutor @@ -28,16 +33,18 @@ from .core.exceptions import ComponentError, IntegrationMissingError, Lang2SQLError from .core.hooks import MemoryHook, NullHook, TraceHook from .core.ports import ( + CatalogLoaderPort, + DBExplorerPort, DBPort, DocumentLoaderPort, EmbeddingPort, LLMPort, VectorStorePort, ) +from .integrations.db.sqlalchemy_ import SQLAlchemyExplorer from .flows.enriched_nl2sql import EnrichedNL2SQL from .flows.hybrid import HybridNL2SQL from .flows.nl2sql import BaselineNL2SQL -from .integrations.catalog.datahub_ import DataHubCatalogLoader from .integrations.embedding.azure_ import AzureOpenAIEmbedding from .integrations.embedding.bedrock_ import BedrockEmbedding from .integrations.embedding.gemini_ import GeminiEmbedding @@ -48,8 +55,6 @@ from .integrations.llm.gemini_ import GeminiLLM from .integrations.llm.huggingface_ import HuggingFaceLLM from .integrations.llm.ollama_ import OllamaLLM -from .integrations.vectorstore.faiss_ import FAISSVectorStore -from .integrations.vectorstore.pgvector_ import PGVectorStore __all__ = [ # Data types @@ -64,9 +69,11 @@ # Ports (protocols) "LLMPort", "DBPort", + "DBExplorerPort", "EmbeddingPort", "VectorStorePort", "DocumentLoaderPort", + "CatalogLoaderPort", # Components — retrieval "KeywordRetriever", "VectorRetriever", @@ -116,8 +123,33 @@ "OllamaEmbedding", # Catalog integrations (Phase 3) "DataHubCatalogLoader", + # DB Explorer (Phase A1) + "SQLAlchemyExplorer", # Factory (Phase 6) "build_llm_from_env", "build_embedding_from_env", "build_db_from_env", + "build_explorer_from_url", ] + +# --------------------------------------------------------------------------- +# Lazy imports (PEP 562) — optional dependencies that have import side-effects +# (e.g. faiss prints INFO logs on import) or are rarely needed at startup. +# --------------------------------------------------------------------------- +_LAZY_IMPORTS: dict[str, tuple[str, str]] = { + "DataHubCatalogLoader": (".integrations.catalog.datahub_", "DataHubCatalogLoader"), + "FAISSVectorStore": (".integrations.vectorstore.faiss_", "FAISSVectorStore"), + "PGVectorStore": (".integrations.vectorstore.pgvector_", "PGVectorStore"), +} + + +def __getattr__(name: str): + if name in _LAZY_IMPORTS: + module_path, attr = _LAZY_IMPORTS[name] + import importlib + + obj = getattr(importlib.import_module(module_path, package=__name__), attr) + # Cache in module globals so subsequent accesses skip __getattr__ + globals()[name] = obj + return obj + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/src/lang2sql/components/retrieval/vector.py b/src/lang2sql/components/retrieval/vector.py index ca1c454..c104590 100644 --- a/src/lang2sql/components/retrieval/vector.py +++ b/src/lang2sql/components/retrieval/vector.py @@ -168,6 +168,81 @@ def add(self, chunks: list[IndexedChunk]) -> None: self._vectorstore.upsert(ids, vectors) self._registry.update({c["chunk_id"]: c for c in chunks}) + # ── Persistence ────────────────────────────────────────────────── + + def save(self, path: str) -> None: + """벡터 인덱스와 registry를 path에 저장. + + FAISSVectorStore처럼 save()를 지원하는 store에서만 동작한다. + InMemoryVectorStore 등 save()가 없는 store는 NotImplementedError. + + 저장 파일: + {path} — FAISSVectorStore 벡터 인덱스 + {path}.meta — chunk_id 순서 목록 (FAISSVectorStore 내부) + {path}.registry — registry JSON + """ + import json + import pathlib + + save_fn = getattr(self._vectorstore, "save", None) + if save_fn is None: + raise NotImplementedError( + f"{type(self._vectorstore).__name__} does not support save(). " + "Use FAISSVectorStore for file-based persistence." + ) + save_fn(path) + pathlib.Path(path + ".registry").write_text( + json.dumps(self._registry), encoding="utf-8" + ) + + @classmethod + def load( + cls, + path: str, + *, + vectorstore: VectorStorePort, + embedding: EmbeddingPort, + top_n: int = 5, + score_threshold: float = 0.0, + name: Optional[str] = None, + hook: Optional[TraceHook] = None, + ) -> "VectorRetriever": + """저장된 registry를 복원해 VectorRetriever를 반환. + + 벡터 인덱스 복원은 호출자가 직접 수행한 뒤 vectorstore로 전달한다. + 이렇게 하면 VectorRetriever가 특정 store 구현체에 의존하지 않는다. + + Args: + path: save() 시 사용한 경로 (registry 파일 위치 기준). + vectorstore: 이미 로드된 VectorStorePort 구현체. + embedding: EmbeddingPort 구현체. + top_n: 최대 반환 스키마/컨텍스트 수. 기본 5. + score_threshold: 이 점수 이하는 결과에서 제외. 기본 0.0. + + Example: + store = FAISSVectorStore.load(path) + retriever = VectorRetriever.load(path, vectorstore=store, embedding=emb) + """ + import json + import pathlib + + registry_path = pathlib.Path(path + ".registry") + if not registry_path.exists(): + raise FileNotFoundError(f"Registry file not found: {registry_path}") + + registry = json.loads(registry_path.read_text(encoding="utf-8")) + return cls( + vectorstore=vectorstore, + embedding=embedding, + registry=registry, + top_n=top_n, + score_threshold=score_threshold, + name=name, + hook=hook, + ) + + # ── Core retrieval ──────────────────────────────────────────────── + def _run(self, query: str) -> RetrievalResult: """ Args: diff --git a/src/lang2sql/core/ports.py b/src/lang2sql/core/ports.py index b04bc61..d1bf462 100644 --- a/src/lang2sql/core/ports.py +++ b/src/lang2sql/core/ports.py @@ -63,3 +63,23 @@ class CatalogLoaderPort(Protocol): """Abstracts catalog loading from external sources (DataHub, file, database, etc.).""" def load(self) -> list[CatalogEntry]: ... + + +class DBExplorerPort(Protocol): + """DB 에이전틱 탐색 인터페이스. Agent가 DB를 직접 탐색할 때 사용. + + 메서드 선정 원칙: + - DDL에 이미 있는 정보(컬럼 목록, FK, PK)는 별도 메서드 없음 + - 통계/집계는 execute_read_only()로 직접 질의 + - 관계 추론은 LLM에 위임 (휴리스틱 제거) + """ + + def list_tables(self, schema: str | None = None) -> list[str]: ... + + def get_ddl(self, table: str, *, schema: str | None = None) -> str: ... + + def sample_data( + self, table: str, *, limit: int = 5, schema: str | None = None + ) -> list[dict]: ... + + def execute_read_only(self, sql: str) -> list[dict]: ... diff --git a/src/lang2sql/factory.py b/src/lang2sql/factory.py index 6fca375..bd05b41 100644 --- a/src/lang2sql/factory.py +++ b/src/lang2sql/factory.py @@ -8,7 +8,7 @@ import os -from .core.ports import DBPort, EmbeddingPort, LLMPort +from .core.ports import DBExplorerPort, DBPort, EmbeddingPort, LLMPort def build_llm_from_env() -> LLMPort: @@ -156,6 +156,13 @@ def build_embedding_from_env() -> EmbeddingPort: ) +def build_explorer_from_url(url: str, *, schema: str | None = None) -> "DBExplorerPort": + """DB URL로 SQLAlchemyExplorer 생성.""" + from .integrations.db.sqlalchemy_ import SQLAlchemyExplorer + + return SQLAlchemyExplorer(url, schema=schema) + + def build_db_from_env(database_env: str = "") -> DBPort: """환경변수에서 DB URL을 구성하고 SQLAlchemyDB를 반환한다. diff --git a/src/lang2sql/integrations/db/__init__.py b/src/lang2sql/integrations/db/__init__.py index 4096452..79ae1db 100644 --- a/src/lang2sql/integrations/db/__init__.py +++ b/src/lang2sql/integrations/db/__init__.py @@ -1,3 +1,3 @@ -from .sqlalchemy_ import SQLAlchemyDB +from .sqlalchemy_ import SQLAlchemyDB, SQLAlchemyExplorer -__all__ = ["SQLAlchemyDB"] +__all__ = ["SQLAlchemyDB", "SQLAlchemyExplorer"] diff --git a/src/lang2sql/integrations/db/sqlalchemy_.py b/src/lang2sql/integrations/db/sqlalchemy_.py index 7444502..10f2ea6 100644 --- a/src/lang2sql/integrations/db/sqlalchemy_.py +++ b/src/lang2sql/integrations/db/sqlalchemy_.py @@ -6,10 +6,11 @@ from ...core.ports import DBPort try: - from sqlalchemy import create_engine, text as sa_text + from sqlalchemy import create_engine, inspect as sa_inspect, text as sa_text from sqlalchemy.engine import Engine except ImportError: create_engine = None # type: ignore[assignment] + sa_inspect = None # type: ignore[assignment] sa_text = None # type: ignore[assignment] Engine = None # type: ignore[assignment,misc] @@ -28,3 +29,112 @@ def execute(self, sql: str) -> list[dict[str, Any]]: with self._engine.connect() as conn: result = conn.execute(sa_text(sql)) return [dict(row._mapping) for row in result] + + +_WRITE_PREFIXES = frozenset( + { + "INSERT", + "UPDATE", + "DELETE", + "DROP", + "ALTER", + "CREATE", + "TRUNCATE", + "REPLACE", + "MERGE", + } +) + + +class SQLAlchemyExplorer: + """DBExplorerPort implementation backed by SQLAlchemy 2.x. + + Agent가 DB 스키마를 탐색할 때 사용. DDL + 샘플 데이터를 LLM context에 직접 주입. + """ + + def __init__(self, url: str, *, schema: str | None = None) -> None: + if create_engine is None: + raise IntegrationMissingError( + "sqlalchemy", extra="sqlalchemy", hint="pip install sqlalchemy" + ) + self._engine: Engine = create_engine(url) + self._schema = schema + + @classmethod + def from_engine( + cls, engine: "Engine", *, schema: str | None = None + ) -> "SQLAlchemyExplorer": + """기존 engine 공유용. 연결 풀 중복 방지.""" + instance = cls.__new__(cls) + instance._engine = engine + instance._schema = schema + return instance + + def list_tables(self, schema: str | None = None) -> list[str]: + """테이블 목록 반환. Agent가 DB 구조 파악 시 첫 번째 호출.""" + insp = sa_inspect(self._engine) + return insp.get_table_names(schema=schema or self._schema) + + def get_ddl(self, table: str, *, schema: str | None = None) -> str: + """원본 DDL 문자열 반환. 컬럼 정의, PK, FK, 제약조건 모두 포함. + + SQLite: sqlite_master에서 원본 그대로 (DEFAULT, 코멘트, 인라인 FK 모두 보존). + 그 외: SQLAlchemy CreateTable construct로 포괄적 DDL 생성. + """ + resolved_schema = schema or self._schema + if self._engine.dialect.name == "sqlite": + rows = self._execute_safe( + "SELECT sql FROM sqlite_master WHERE type='table' AND name=:table", + {"table": table}, + ) + if rows and rows[0].get("sql"): + return rows[0]["sql"] + + from sqlalchemy import MetaData + from sqlalchemy import Table as SATable + from sqlalchemy.schema import CreateTable + + metadata = MetaData() + t = SATable(table, metadata, autoload_with=self._engine, schema=resolved_schema) + return str(CreateTable(t).compile(self._engine)) + + def sample_data( + self, table: str, *, limit: int = 5, schema: str | None = None + ) -> list[dict]: + """실제 샘플 데이터 반환. + + f-string SQL 금지 — SQLAlchemy ORM select()로 identifier quoting 위임. + dialect별 quoting 차이(PostgreSQL ", MySQL `, SQLite ")를 SQLAlchemy가 처리. + """ + from sqlalchemy import MetaData, select + from sqlalchemy import Table as SATable + + resolved_schema = schema or self._schema + metadata = MetaData() + t = SATable(table, metadata, autoload_with=self._engine, schema=resolved_schema) + stmt = select(t).limit(limit) + with self._engine.connect() as conn: + result = conn.execute(stmt) + return [dict(row._mapping) for row in result] + + def execute_read_only(self, sql: str) -> list[dict]: + """읽기 전용 SQL 실행. + + 두 겹 방어: + 1. prefix guard — 일반적인 쓰기 구문 빠른 거부 (UX) + 2. rollback-always — WITH ... DELETE 같은 CTE 우회도 실제 DB 반영 방지 + """ + first_token = sql.strip().upper().split()[0] if sql.strip() else "" + if first_token in _WRITE_PREFIXES: + raise ValueError(f"Write operations not allowed: {sql[:50]!r}") + with self._engine.connect() as conn: + result = conn.execute(sa_text(sql)) + rows = [dict(row._mapping) for row in result] + conn.rollback() + return rows + + def _execute_safe(self, sql: str, params: dict | None = None) -> list[dict]: + """파라미터화 쿼리 실행 (내부용).""" + with self._engine.connect() as conn: + result = conn.execute(sa_text(sql), params or {}) + return [dict(row._mapping) for row in result] diff --git a/tests/test_components_vector_retriever.py b/tests/test_components_vector_retriever.py index 39d2515..07c77cd 100644 --- a/tests/test_components_vector_retriever.py +++ b/tests/test_components_vector_retriever.py @@ -502,3 +502,82 @@ def test_catalog_chunker_split_batch(): by_chunk = [c for entry in CATALOG for c in chunker.chunk(entry)] assert [c["chunk_id"] for c in by_split] == [c["chunk_id"] for c in by_chunk] + + +# --------------------------------------------------------------------------- +# 20-22. VectorRetriever save / load (FAISS 필요) +# --------------------------------------------------------------------------- + +faiss = pytest.importorskip("faiss", reason="faiss-cpu not installed") + + +class FakeEmbeddingFAISS: + """FAISS L2 정규화에서 zero-vector 오류가 안 나도록 비영벡터를 반환.""" + + def _vec(self, text: str) -> list[float]: + # 텍스트별로 구별 가능한 비영벡터 + h = abs(hash(text)) % 900 + 100 + return [h * 0.001, 1.0, 1.0, 1.0] + + def embed_query(self, text: str) -> list[float]: + return self._vec(text) + + def embed_texts(self, texts: list[str]) -> list[list[float]]: + return [self._vec(t) for t in texts] + + +def test_save_and_load_returns_same_results(tmp_path): + """save → load 후 동일 쿼리에 동일 스키마가 반환된다.""" + path = str(tmp_path / "catalog") + embedding = FakeEmbeddingFAISS() + + from lang2sql.integrations.vectorstore.faiss_ import FAISSVectorStore + + store = FAISSVectorStore(index_path=path + ".faiss") + chunks = CatalogChunker().split(CATALOG) + original = VectorRetriever.from_chunks( + chunks, embedding=embedding, vectorstore=store + ) + original.save(path) + + loaded_store = FAISSVectorStore.load(path) + loaded = VectorRetriever.load(path, vectorstore=loaded_store, embedding=embedding) + result = loaded.run("주문 정보") + + assert len(result.schemas) > 0 + assert result.schemas[0]["name"] == original.run("주문 정보").schemas[0]["name"] + + +def test_load_registry_intact(tmp_path): + """load 후 registry의 키·값이 원본과 동일하다.""" + path = str(tmp_path / "catalog") + embedding = FakeEmbeddingFAISS() + + from lang2sql.integrations.vectorstore.faiss_ import FAISSVectorStore + + store = FAISSVectorStore(index_path=path + ".faiss") + chunks = CatalogChunker().split(CATALOG) + original = VectorRetriever.from_chunks( + chunks, embedding=embedding, vectorstore=store + ) + original.save(path) + + loaded_store = FAISSVectorStore.load(path) + loaded = VectorRetriever.load(path, vectorstore=loaded_store, embedding=embedding) + + assert set(loaded._registry.keys()) == set(original._registry.keys()) + for chunk_id, chunk in original._registry.items(): + assert loaded._registry[chunk_id]["text"] == chunk["text"] + assert loaded._registry[chunk_id]["source_id"] == chunk["source_id"] + + +def test_save_raises_for_inmemory(): + """InMemoryVectorStore는 save()를 지원하지 않아 NotImplementedError가 발생한다.""" + embedding = FakeEmbeddingFAISS() + chunks = CatalogChunker().split(CATALOG) + retriever = VectorRetriever.from_chunks( + chunks, embedding=embedding + ) # InMemory 기본값 + + with pytest.raises(NotImplementedError, match="does not support save"): + retriever.save("/tmp/should_not_exist") diff --git a/tests/test_integrations_sqlalchemy_explorer.py b/tests/test_integrations_sqlalchemy_explorer.py new file mode 100644 index 0000000..60dd9b2 --- /dev/null +++ b/tests/test_integrations_sqlalchemy_explorer.py @@ -0,0 +1,149 @@ +"""Tests for SQLAlchemyExplorer (Phase A1).""" + +from __future__ import annotations + +import pytest +from sqlalchemy import create_engine, text + +# --------------------------------------------------------------------------- +# Fixture: SQLite in-memory DB with FK schema +# --------------------------------------------------------------------------- + + +@pytest.fixture() +def engine(): + eng = create_engine("sqlite:///:memory:") + with eng.connect() as conn: + conn.execute(text(""" + CREATE TABLE customers ( + id INTEGER PRIMARY KEY, + name TEXT NOT NULL, + email TEXT UNIQUE + ) + """)) + conn.execute(text(""" + CREATE TABLE orders ( + id INTEGER PRIMARY KEY, + customer_id INTEGER NOT NULL REFERENCES customers(id), + amount REAL, + status TEXT DEFAULT 'pending' + ) + """)) + conn.execute( + text("INSERT INTO customers VALUES (1, 'Alice', 'alice@example.com')") + ) + conn.execute(text("INSERT INTO customers VALUES (2, 'Bob', 'bob@example.com')")) + conn.execute(text("INSERT INTO orders VALUES (1, 1, 99.9, 'shipped')")) + conn.execute(text("INSERT INTO orders VALUES (2, 2, 42.0, 'pending')")) + conn.commit() + return eng + + +@pytest.fixture() +def explorer(engine): + from lang2sql.integrations.db.sqlalchemy_ import SQLAlchemyExplorer + + return SQLAlchemyExplorer.from_engine(engine) + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +def test_list_tables(explorer): + tables = explorer.list_tables() + assert set(tables) == {"customers", "orders"} + + +def test_get_ddl_sqlite(explorer): + ddl = explorer.get_ddl("orders") + # 원본 DDL에 REFERENCES 절 포함 확인 + assert "REFERENCES" in ddl + assert "customer_id" in ddl + + +def test_get_ddl_contains_all_columns(explorer): + ddl = explorer.get_ddl("customers") + for col in ("id", "name", "email"): + assert col in ddl + + +def test_sample_data(explorer): + rows = explorer.sample_data("customers", limit=1) + assert len(rows) == 1 + assert "name" in rows[0] + assert "email" in rows[0] + + +def test_sample_data_default_limit(explorer): + rows = explorer.sample_data("customers") + # 2행 삽입, limit=5(기본값) → 모두 반환 + assert len(rows) == 2 + + +def test_sample_data_empty_table(engine): + from lang2sql.integrations.db.sqlalchemy_ import SQLAlchemyExplorer + + with engine.connect() as conn: + conn.execute(text("CREATE TABLE empty_tbl (x INTEGER)")) + conn.commit() + + exp = SQLAlchemyExplorer.from_engine(engine) + assert exp.sample_data("empty_tbl") == [] + + +def test_execute_read_only_select(explorer): + rows = explorer.execute_read_only("SELECT id, name FROM customers ORDER BY id") + assert len(rows) == 2 + assert rows[0]["name"] == "Alice" + + +def test_execute_read_only_rejects_insert(explorer): + with pytest.raises(ValueError, match="Write operations not allowed"): + explorer.execute_read_only( + "INSERT INTO customers VALUES (3, 'Eve', 'eve@x.com')" + ) + + +def test_execute_read_only_rejects_drop(explorer): + with pytest.raises(ValueError, match="Write operations not allowed"): + explorer.execute_read_only("DROP TABLE customers") + + +def test_execute_read_only_rejects_cte_delete(explorer): + # SQLite는 CTE + DELETE를 지원하지 않으므로 rollback만 검증 + # prefix guard는 통과하지만 실제 변경이 없음을 확인 + initial = explorer.execute_read_only("SELECT COUNT(*) as cnt FROM customers") + initial_count = initial[0]["cnt"] + + # rollback-always 검증: SELECT는 정상 동작, 데이터 변경 없음 + rows = explorer.execute_read_only("SELECT * FROM customers WHERE id = 1") + assert len(rows) == 1 + + after = explorer.execute_read_only("SELECT COUNT(*) as cnt FROM customers") + assert after[0]["cnt"] == initial_count + + +def test_from_engine_shares_data(engine): + from lang2sql.integrations.db.sqlalchemy_ import SQLAlchemyExplorer + + exp1 = SQLAlchemyExplorer.from_engine(engine) + exp2 = SQLAlchemyExplorer.from_engine(engine) + + rows1 = exp1.sample_data("customers") + rows2 = exp2.sample_data("customers") + assert rows1 == rows2 + + +def test_integration_with_sqlalchemydb(engine): + from lang2sql.integrations.db.sqlalchemy_ import SQLAlchemyDB, SQLAlchemyExplorer + + # SQLAlchemyDB와 같은 engine을 SQLAlchemyExplorer가 공유 + explorer = SQLAlchemyExplorer.from_engine(engine) + + tables = explorer.list_tables() + assert "customers" in tables + + ddl = explorer.get_ddl("customers") + assert "id" in ddl