From 415cd99d089bae1f9d2341ef31296136e46115d4 Mon Sep 17 00:00:00 2001 From: seyeong Date: Sat, 28 Feb 2026 15:31:03 +0900 Subject: [PATCH 01/10] feat(integrations/llm): add Azure, Bedrock, Gemini, Ollama, HuggingFace LLM providers --- src/lang2sql/integrations/llm/__init__.py | 15 +++++- src/lang2sql/integrations/llm/azure_.py | 42 +++++++++++++++ src/lang2sql/integrations/llm/bedrock_.py | 49 +++++++++++++++++ src/lang2sql/integrations/llm/gemini_.py | 46 ++++++++++++++++ src/lang2sql/integrations/llm/huggingface_.py | 35 ++++++++++++ src/lang2sql/integrations/llm/ollama_.py | 33 ++++++++++++ tests/test_integrations_llm_azure.py | 53 +++++++++++++++++++ tests/test_integrations_llm_ollama.py | 26 +++++++++ 8 files changed, 298 insertions(+), 1 deletion(-) create mode 100644 src/lang2sql/integrations/llm/azure_.py create mode 100644 src/lang2sql/integrations/llm/bedrock_.py create mode 100644 src/lang2sql/integrations/llm/gemini_.py create mode 100644 src/lang2sql/integrations/llm/huggingface_.py create mode 100644 src/lang2sql/integrations/llm/ollama_.py create mode 100644 tests/test_integrations_llm_azure.py create mode 100644 tests/test_integrations_llm_ollama.py diff --git a/src/lang2sql/integrations/llm/__init__.py b/src/lang2sql/integrations/llm/__init__.py index 528072c..09ecf03 100644 --- a/src/lang2sql/integrations/llm/__init__.py +++ b/src/lang2sql/integrations/llm/__init__.py @@ -1,4 +1,17 @@ from .anthropic_ import AnthropicLLM +from .azure_ import AzureOpenAILLM +from .bedrock_ import BedrockLLM +from .gemini_ import GeminiLLM +from .huggingface_ import HuggingFaceLLM +from .ollama_ import OllamaLLM from .openai_ import OpenAILLM -__all__ = ["AnthropicLLM", "OpenAILLM"] +__all__ = [ + "AnthropicLLM", + "AzureOpenAILLM", + "BedrockLLM", + "GeminiLLM", + "HuggingFaceLLM", + "OllamaLLM", + "OpenAILLM", +] diff --git a/src/lang2sql/integrations/llm/azure_.py b/src/lang2sql/integrations/llm/azure_.py new file mode 100644 index 0000000..cc32b9d --- /dev/null +++ b/src/lang2sql/integrations/llm/azure_.py @@ -0,0 +1,42 @@ +from __future__ import annotations + +from ...core.exceptions import IntegrationMissingError +from ...core.ports import LLMPort + +try: + import openai as _openai +except ImportError: + _openai = None # type: ignore[assignment] + + +class AzureOpenAILLM(LLMPort): + """LLMPort implementation backed by the Azure OpenAI Chat Completions API.""" + + def __init__( + self, + *, + azure_deployment: str, + azure_endpoint: str, + api_version: str = "2023-07-01-preview", + api_key: str | None = None, + max_tokens: int = 4096, + ) -> None: + if _openai is None: + raise IntegrationMissingError( + "openai", hint="pip install openai # or: uv sync" + ) + self._client = _openai.AzureOpenAI( + api_key=api_key, + azure_endpoint=azure_endpoint, + api_version=api_version, + ) + self._deployment = azure_deployment + self._max_tokens = max_tokens + + def invoke(self, messages: list[dict[str, str]]) -> str: + resp = self._client.chat.completions.create( + model=self._deployment, + messages=messages, # type: ignore[arg-type] + max_tokens=self._max_tokens, + ) + return resp.choices[0].message.content or "" diff --git a/src/lang2sql/integrations/llm/bedrock_.py b/src/lang2sql/integrations/llm/bedrock_.py new file mode 100644 index 0000000..428e5d4 --- /dev/null +++ b/src/lang2sql/integrations/llm/bedrock_.py @@ -0,0 +1,49 @@ +from __future__ import annotations + +from ...core.exceptions import IntegrationMissingError +from ...core.ports import LLMPort + +try: + import boto3 as _boto3 # type: ignore[import] +except ImportError: + _boto3 = None # type: ignore[assignment] + + +class BedrockLLM(LLMPort): + """LLMPort implementation backed by the AWS Bedrock Converse API.""" + + def __init__( + self, + *, + model: str, + aws_access_key_id: str | None = None, + aws_secret_access_key: str | None = None, + region_name: str = "us-east-1", + ) -> None: + if _boto3 is None: + raise IntegrationMissingError( + "boto3", hint="pip install boto3" + ) + self._model = model + self._client = _boto3.client( + "bedrock-runtime", + aws_access_key_id=aws_access_key_id, + aws_secret_access_key=aws_secret_access_key, + region_name=region_name, + ) + + def invoke(self, messages: list[dict[str, str]]) -> str: + system_parts = [m["content"] for m in messages if m["role"] == "system"] + user_msgs = [m for m in messages if m["role"] != "system"] + + converse_messages = [ + {"role": m["role"], "content": [{"text": m["content"]}]} + for m in user_msgs + ] + + kwargs: dict = {"modelId": self._model, "messages": converse_messages} + if system_parts: + kwargs["system"] = [{"text": system_parts[0]}] + + resp = self._client.converse(**kwargs) + return resp["output"]["message"]["content"][0]["text"] diff --git a/src/lang2sql/integrations/llm/gemini_.py b/src/lang2sql/integrations/llm/gemini_.py new file mode 100644 index 0000000..6c70799 --- /dev/null +++ b/src/lang2sql/integrations/llm/gemini_.py @@ -0,0 +1,46 @@ +from __future__ import annotations + +from ...core.exceptions import IntegrationMissingError +from ...core.ports import LLMPort + +try: + import google.generativeai as _genai # type: ignore[import] +except ImportError: + _genai = None # type: ignore[assignment] + + +class GeminiLLM(LLMPort): + """LLMPort implementation backed by the Google Gemini Generative AI API.""" + + def __init__( + self, + *, + model: str, + api_key: str | None = None, + ) -> None: + if _genai is None: + raise IntegrationMissingError( + "google-generativeai", + hint="pip install google-generativeai", + ) + if api_key: + _genai.configure(api_key=api_key) + self._model_name = model + + def invoke(self, messages: list[dict[str, str]]) -> str: + system_parts = [m["content"] for m in messages if m["role"] == "system"] + system_instruction = system_parts[0] if system_parts else None + + contents = [] + for m in messages: + if m["role"] == "system": + continue + role = "model" if m["role"] == "assistant" else "user" + contents.append({"role": role, "parts": [m["content"]]}) + + model = _genai.GenerativeModel( + model_name=self._model_name, + system_instruction=system_instruction, + ) + response = model.generate_content(contents) + return response.text diff --git a/src/lang2sql/integrations/llm/huggingface_.py b/src/lang2sql/integrations/llm/huggingface_.py new file mode 100644 index 0000000..08cca81 --- /dev/null +++ b/src/lang2sql/integrations/llm/huggingface_.py @@ -0,0 +1,35 @@ +from __future__ import annotations + +from ...core.exceptions import IntegrationMissingError +from ...core.ports import LLMPort + +try: + from huggingface_hub import InferenceClient as _InferenceClient # type: ignore[import] +except ImportError: + _InferenceClient = None # type: ignore[assignment] + + +class HuggingFaceLLM(LLMPort): + """LLMPort implementation backed by the HuggingFace Inference API.""" + + def __init__( + self, + *, + repo_id: str | None = None, + endpoint_url: str | None = None, + api_token: str | None = None, + ) -> None: + if _InferenceClient is None: + raise IntegrationMissingError( + "huggingface_hub", hint="pip install huggingface_hub" + ) + if repo_id is None and endpoint_url is None: + raise ValueError("Either repo_id or endpoint_url must be provided.") + self._client = _InferenceClient( + model=endpoint_url or repo_id, + token=api_token, + ) + + def invoke(self, messages: list[dict[str, str]]) -> str: + resp = self._client.chat_completion(messages=messages) # type: ignore[arg-type] + return resp.choices[0].message.content or "" diff --git a/src/lang2sql/integrations/llm/ollama_.py b/src/lang2sql/integrations/llm/ollama_.py new file mode 100644 index 0000000..2c182cf --- /dev/null +++ b/src/lang2sql/integrations/llm/ollama_.py @@ -0,0 +1,33 @@ +from __future__ import annotations + +from ...core.exceptions import IntegrationMissingError +from ...core.ports import LLMPort + +try: + import ollama as _ollama # type: ignore[import] +except ImportError: + _ollama = None # type: ignore[assignment] + + +class OllamaLLM(LLMPort): + """LLMPort implementation backed by the Ollama chat API.""" + + def __init__( + self, + *, + model: str, + base_url: str = "http://localhost:11434", + ) -> None: + if _ollama is None: + raise IntegrationMissingError( + "ollama", hint="pip install ollama" + ) + self._model = model + self._client = _ollama.Client(host=base_url) + + def invoke(self, messages: list[dict[str, str]]) -> str: + resp = self._client.chat( + model=self._model, + messages=messages, # type: ignore[arg-type] + ) + return resp.message.content diff --git a/tests/test_integrations_llm_azure.py b/tests/test_integrations_llm_azure.py new file mode 100644 index 0000000..8c1592e --- /dev/null +++ b/tests/test_integrations_llm_azure.py @@ -0,0 +1,53 @@ +"""Tests for AzureOpenAILLM integration.""" + +from __future__ import annotations + +from unittest.mock import MagicMock, patch + +import pytest + +openai = pytest.importorskip("openai", reason="openai not installed") + +from lang2sql.integrations.llm.azure_ import AzureOpenAILLM + + +def _make_llm() -> AzureOpenAILLM: + return AzureOpenAILLM( + azure_deployment="gpt-4o", + azure_endpoint="https://test.openai.azure.com/", + api_key="test-key", + ) + + +def test_azure_llm_invoke_returns_string(): + mock_resp = MagicMock() + mock_resp.choices[0].message.content = "SELECT 1" + + with patch("openai.AzureOpenAI") as MockClient: + instance = MockClient.return_value + instance.chat.completions.create.return_value = mock_resp + + llm = _make_llm() + llm._client = instance + result = llm.invoke([{"role": "user", "content": "hello"}]) + + assert result == "SELECT 1" + + +def test_azure_llm_missing_dependency_raises(): + import sys + + with patch.dict(sys.modules, {"openai": None}): + # Re-import to trigger the ImportError guard + import importlib + + import lang2sql.integrations.llm.azure_ as mod + + importlib.reload(mod) + with pytest.raises(Exception): + mod.AzureOpenAILLM( + azure_deployment="x", + azure_endpoint="https://x.openai.azure.com/", + ) + # Reload back + importlib.reload(mod) diff --git a/tests/test_integrations_llm_ollama.py b/tests/test_integrations_llm_ollama.py new file mode 100644 index 0000000..95adf17 --- /dev/null +++ b/tests/test_integrations_llm_ollama.py @@ -0,0 +1,26 @@ +"""Tests for OllamaLLM integration.""" + +from __future__ import annotations + +from unittest.mock import MagicMock, patch + +import pytest + +ollama = pytest.importorskip("ollama", reason="ollama not installed") + +from lang2sql.integrations.llm.ollama_ import OllamaLLM + + +def test_ollama_llm_invoke_returns_string(): + mock_resp = MagicMock() + mock_resp.message.content = "SELECT 1" + + with patch("ollama.Client") as MockClient: + instance = MockClient.return_value + instance.chat.return_value = mock_resp + + llm = OllamaLLM(model="llama3", base_url="http://localhost:11434") + llm._client = instance + result = llm.invoke([{"role": "user", "content": "hello"}]) + + assert result == "SELECT 1" From e368e5f75b7ee2fc304969af7a59ea620a1560b3 Mon Sep 17 00:00:00 2001 From: seyeong Date: Sat, 28 Feb 2026 15:31:18 +0900 Subject: [PATCH 02/10] feat(integrations/embedding): add Azure, Bedrock, Gemini, Ollama, HuggingFace embedding providers --- .../integrations/embedding/__init__.py | 14 ++++- src/lang2sql/integrations/embedding/azure_.py | 43 +++++++++++++++ .../integrations/embedding/bedrock_.py | 55 +++++++++++++++++++ .../integrations/embedding/gemini_.py | 46 ++++++++++++++++ .../integrations/embedding/huggingface_.py | 27 +++++++++ .../integrations/embedding/ollama_.py | 34 ++++++++++++ tests/test_integrations_embedding_azure.py | 51 +++++++++++++++++ 7 files changed, 269 insertions(+), 1 deletion(-) create mode 100644 src/lang2sql/integrations/embedding/azure_.py create mode 100644 src/lang2sql/integrations/embedding/bedrock_.py create mode 100644 src/lang2sql/integrations/embedding/gemini_.py create mode 100644 src/lang2sql/integrations/embedding/huggingface_.py create mode 100644 src/lang2sql/integrations/embedding/ollama_.py create mode 100644 tests/test_integrations_embedding_azure.py diff --git a/src/lang2sql/integrations/embedding/__init__.py b/src/lang2sql/integrations/embedding/__init__.py index b427c1e..3c4d54b 100644 --- a/src/lang2sql/integrations/embedding/__init__.py +++ b/src/lang2sql/integrations/embedding/__init__.py @@ -1,3 +1,15 @@ +from .azure_ import AzureOpenAIEmbedding +from .bedrock_ import BedrockEmbedding +from .gemini_ import GeminiEmbedding +from .huggingface_ import HuggingFaceEmbedding +from .ollama_ import OllamaEmbedding from .openai_ import OpenAIEmbedding -__all__ = ["OpenAIEmbedding"] +__all__ = [ + "AzureOpenAIEmbedding", + "BedrockEmbedding", + "GeminiEmbedding", + "HuggingFaceEmbedding", + "OllamaEmbedding", + "OpenAIEmbedding", +] diff --git a/src/lang2sql/integrations/embedding/azure_.py b/src/lang2sql/integrations/embedding/azure_.py new file mode 100644 index 0000000..824203e --- /dev/null +++ b/src/lang2sql/integrations/embedding/azure_.py @@ -0,0 +1,43 @@ +from __future__ import annotations + +from ...core.exceptions import IntegrationMissingError +from ...core.ports import EmbeddingPort + +try: + import openai as _openai +except ImportError: + _openai = None # type: ignore[assignment] + + +class AzureOpenAIEmbedding(EmbeddingPort): + """EmbeddingPort implementation backed by the Azure OpenAI Embeddings API.""" + + def __init__( + self, + *, + azure_deployment: str, + azure_endpoint: str, + api_version: str = "2023-07-01-preview", + api_key: str | None = None, + ) -> None: + if _openai is None: + raise IntegrationMissingError( + "openai", hint="pip install openai # or: uv sync" + ) + self._client = _openai.AzureOpenAI( + api_key=api_key, + azure_endpoint=azure_endpoint, + api_version=api_version, + ) + self._deployment = azure_deployment + + def embed_query(self, text: str) -> list[float]: + return ( + self._client.embeddings.create(input=text, model=self._deployment) + .data[0] + .embedding + ) + + def embed_texts(self, texts: list[str]) -> list[list[float]]: + resp = self._client.embeddings.create(input=texts, model=self._deployment) + return [item.embedding for item in resp.data] diff --git a/src/lang2sql/integrations/embedding/bedrock_.py b/src/lang2sql/integrations/embedding/bedrock_.py new file mode 100644 index 0000000..74d5e1e --- /dev/null +++ b/src/lang2sql/integrations/embedding/bedrock_.py @@ -0,0 +1,55 @@ +from __future__ import annotations + +import json + +from ...core.exceptions import IntegrationMissingError +from ...core.ports import EmbeddingPort + +try: + import boto3 as _boto3 # type: ignore[import] +except ImportError: + _boto3 = None # type: ignore[assignment] + + +class BedrockEmbedding(EmbeddingPort): + """EmbeddingPort implementation backed by AWS Bedrock Embeddings API. + + Supports Amazon Titan embedding models (e.g., amazon.titan-embed-text-v1). + """ + + def __init__( + self, + *, + model_id: str, + aws_access_key_id: str | None = None, + aws_secret_access_key: str | None = None, + region_name: str = "us-east-1", + ) -> None: + if _boto3 is None: + raise IntegrationMissingError( + "boto3", hint="pip install boto3" + ) + self._model_id = model_id + self._client = _boto3.client( + "bedrock-runtime", + aws_access_key_id=aws_access_key_id, + aws_secret_access_key=aws_secret_access_key, + region_name=region_name, + ) + + def _embed_single(self, text: str) -> list[float]: + body = json.dumps({"inputText": text}) + resp = self._client.invoke_model( + modelId=self._model_id, + body=body, + contentType="application/json", + accept="application/json", + ) + result = json.loads(resp["body"].read()) + return result["embedding"] + + def embed_query(self, text: str) -> list[float]: + return self._embed_single(text) + + def embed_texts(self, texts: list[str]) -> list[list[float]]: + return [self._embed_single(t) for t in texts] diff --git a/src/lang2sql/integrations/embedding/gemini_.py b/src/lang2sql/integrations/embedding/gemini_.py new file mode 100644 index 0000000..ee1004f --- /dev/null +++ b/src/lang2sql/integrations/embedding/gemini_.py @@ -0,0 +1,46 @@ +from __future__ import annotations + +from ...core.exceptions import IntegrationMissingError +from ...core.ports import EmbeddingPort + +try: + import google.generativeai as _genai # type: ignore[import] +except ImportError: + _genai = None # type: ignore[assignment] + + +class GeminiEmbedding(EmbeddingPort): + """EmbeddingPort implementation backed by the Google Gemini Embeddings API.""" + + def __init__( + self, + *, + model: str = "models/embedding-001", + api_key: str | None = None, + ) -> None: + if _genai is None: + raise IntegrationMissingError( + "google-generativeai", + hint="pip install google-generativeai", + ) + if api_key: + _genai.configure(api_key=api_key) + self._model = model + + def embed_query(self, text: str) -> list[float]: + result = _genai.embed_content( + model=self._model, + content=text, + task_type="retrieval_query", + ) + return result["embedding"] + + def embed_texts(self, texts: list[str]) -> list[list[float]]: + return [ + _genai.embed_content( + model=self._model, + content=t, + task_type="retrieval_document", + )["embedding"] + for t in texts + ] diff --git a/src/lang2sql/integrations/embedding/huggingface_.py b/src/lang2sql/integrations/embedding/huggingface_.py new file mode 100644 index 0000000..b0a584e --- /dev/null +++ b/src/lang2sql/integrations/embedding/huggingface_.py @@ -0,0 +1,27 @@ +from __future__ import annotations + +from ...core.exceptions import IntegrationMissingError +from ...core.ports import EmbeddingPort + +try: + from sentence_transformers import SentenceTransformer as _SentenceTransformer # type: ignore[import] +except ImportError: + _SentenceTransformer = None # type: ignore[assignment] + + +class HuggingFaceEmbedding(EmbeddingPort): + """EmbeddingPort implementation backed by sentence-transformers.""" + + def __init__(self, *, model: str) -> None: + if _SentenceTransformer is None: + raise IntegrationMissingError( + "sentence-transformers", + hint="pip install sentence-transformers", + ) + self._model = _SentenceTransformer(model) + + def embed_query(self, text: str) -> list[float]: + return self._model.encode(text, convert_to_numpy=True).tolist() + + def embed_texts(self, texts: list[str]) -> list[list[float]]: + return self._model.encode(texts, convert_to_numpy=True).tolist() diff --git a/src/lang2sql/integrations/embedding/ollama_.py b/src/lang2sql/integrations/embedding/ollama_.py new file mode 100644 index 0000000..cc84989 --- /dev/null +++ b/src/lang2sql/integrations/embedding/ollama_.py @@ -0,0 +1,34 @@ +from __future__ import annotations + +from ...core.exceptions import IntegrationMissingError +from ...core.ports import EmbeddingPort + +try: + import ollama as _ollama # type: ignore[import] +except ImportError: + _ollama = None # type: ignore[assignment] + + +class OllamaEmbedding(EmbeddingPort): + """EmbeddingPort implementation backed by the Ollama Embeddings API.""" + + def __init__( + self, + *, + model: str, + base_url: str = "http://localhost:11434", + ) -> None: + if _ollama is None: + raise IntegrationMissingError( + "ollama", hint="pip install ollama" + ) + self._model = model + self._client = _ollama.Client(host=base_url) + + def embed_query(self, text: str) -> list[float]: + resp = self._client.embed(model=self._model, input=text) + return resp.embeddings[0] + + def embed_texts(self, texts: list[str]) -> list[list[float]]: + resp = self._client.embed(model=self._model, input=texts) + return resp.embeddings diff --git a/tests/test_integrations_embedding_azure.py b/tests/test_integrations_embedding_azure.py new file mode 100644 index 0000000..7e099ad --- /dev/null +++ b/tests/test_integrations_embedding_azure.py @@ -0,0 +1,51 @@ +"""Tests for AzureOpenAIEmbedding integration.""" + +from __future__ import annotations + +from unittest.mock import MagicMock, patch + +import pytest + +openai = pytest.importorskip("openai", reason="openai not installed") + +from lang2sql.integrations.embedding.azure_ import AzureOpenAIEmbedding + + +def _make_embedding() -> AzureOpenAIEmbedding: + return AzureOpenAIEmbedding( + azure_deployment="text-embedding-ada-002", + azure_endpoint="https://test.openai.azure.com/", + api_key="test-key", + ) + + +def _mock_embedding_response(vectors: list[list[float]]): + resp = MagicMock() + resp.data = [MagicMock(embedding=v) for v in vectors] + return resp + + +def test_embed_query_returns_vector(): + vec = [0.1, 0.2, 0.3] + with patch("openai.AzureOpenAI") as MockClient: + instance = MockClient.return_value + instance.embeddings.create.return_value = _mock_embedding_response([vec]) + + emb = _make_embedding() + emb._client = instance + result = emb.embed_query("hello") + + assert result == vec + + +def test_embed_texts_returns_multiple_vectors(): + vecs = [[0.1, 0.2], [0.3, 0.4]] + with patch("openai.AzureOpenAI") as MockClient: + instance = MockClient.return_value + instance.embeddings.create.return_value = _mock_embedding_response(vecs) + + emb = _make_embedding() + emb._client = instance + result = emb.embed_texts(["hello", "world"]) + + assert result == vecs From 636e4f36640f484a6e0bf72e434067c602a1bb7f Mon Sep 17 00:00:00 2001 From: seyeong Date: Sat, 28 Feb 2026 15:31:35 +0900 Subject: [PATCH 03/10] feat(integrations/catalog): add DataHubCatalogLoader bridging legacy fetcher to CatalogEntry --- src/lang2sql/integrations/catalog/__init__.py | 3 + src/lang2sql/integrations/catalog/datahub_.py | 61 +++++++++++++++++++ 2 files changed, 64 insertions(+) create mode 100644 src/lang2sql/integrations/catalog/__init__.py create mode 100644 src/lang2sql/integrations/catalog/datahub_.py diff --git a/src/lang2sql/integrations/catalog/__init__.py b/src/lang2sql/integrations/catalog/__init__.py new file mode 100644 index 0000000..eaa993e --- /dev/null +++ b/src/lang2sql/integrations/catalog/__init__.py @@ -0,0 +1,3 @@ +from .datahub_ import DataHubCatalogLoader + +__all__ = ["DataHubCatalogLoader"] diff --git a/src/lang2sql/integrations/catalog/datahub_.py b/src/lang2sql/integrations/catalog/datahub_.py new file mode 100644 index 0000000..cb3154f --- /dev/null +++ b/src/lang2sql/integrations/catalog/datahub_.py @@ -0,0 +1,61 @@ +from __future__ import annotations + +from ...core.catalog import CatalogEntry +from ...core.exceptions import IntegrationMissingError + +try: + import datahub as _datahub # type: ignore[import] +except ImportError: + _datahub = None # type: ignore[assignment] + + +class DataHubCatalogLoader: + """DataHub URN → list[CatalogEntry] 변환. + + DataHub GMS 서버에서 테이블 메타데이터를 조회하여 + v2 아키텍처의 CatalogEntry 포맷으로 변환한다. + """ + + def __init__( + self, + gms_server: str = "http://localhost:8080", + extra_headers: dict | None = None, + ) -> None: + if _datahub is None: + raise IntegrationMissingError( + "acryl-datahub", + hint="pip install acryl-datahub", + ) + # 레거시 DatahubMetadataFetcher를 내부에서 wrapping + from utils.data.datahub_source import DatahubMetadataFetcher # type: ignore[import] + + self._fetcher = DatahubMetadataFetcher( + gms_server=gms_server, + extra_headers=extra_headers or {}, + ) + + def load(self, urns: list[str] | None = None) -> list[CatalogEntry]: + """DataHub에서 CatalogEntry 목록을 로드한다. + + Args: + urns: 조회할 URN 목록. None이면 전체 URN을 조회한다. + + Returns: + CatalogEntry 목록 + """ + if urns is None: + urns = list(self._fetcher.get_urns()) + + entries: list[CatalogEntry] = [] + for urn in urns: + name = self._fetcher.get_table_name(urn) or "" + description = self._fetcher.get_table_description(urn) or "" + raw_cols = self._fetcher.get_column_names_and_descriptions(urn) or [] + columns: dict[str, str] = { + col["column_name"]: col.get("column_description") or "" + for col in raw_cols + } + entries.append( + CatalogEntry(name=name, description=description, columns=columns) + ) + return entries From 1a9c1f86f76996ffdd82b6183d852064a131fe57 Mon Sep 17 00:00:00 2001 From: seyeong Date: Sat, 28 Feb 2026 15:31:51 +0900 Subject: [PATCH 04/10] feat(components): add gate and enrichment components with domain type --- .../components/enrichment/__init__.py | 4 + .../components/enrichment/context_enricher.py | 58 +++++++++ .../enrichment/prompts/context_enricher.md | 22 ++++ .../enrichment/prompts/question_profiler.md | 28 +++++ .../enrichment/question_profiler.py | 61 ++++++++++ src/lang2sql/components/gate/__init__.py | 4 + .../components/gate/prompts/question_gate.md | 25 ++++ .../gate/prompts/table_suitability.md | 47 ++++++++ src/lang2sql/components/gate/question_gate.py | 55 +++++++++ .../components/gate/table_suitability.py | 80 +++++++++++++ src/lang2sql/core/catalog.py | 34 ++++++ tests/test_components_context_enricher.py | 76 ++++++++++++ tests/test_components_question_gate.py | 89 ++++++++++++++ tests/test_components_question_profiler.py | 75 ++++++++++++ tests/test_components_table_suitability.py | 111 ++++++++++++++++++ 15 files changed, 769 insertions(+) create mode 100644 src/lang2sql/components/enrichment/__init__.py create mode 100644 src/lang2sql/components/enrichment/context_enricher.py create mode 100644 src/lang2sql/components/enrichment/prompts/context_enricher.md create mode 100644 src/lang2sql/components/enrichment/prompts/question_profiler.md create mode 100644 src/lang2sql/components/enrichment/question_profiler.py create mode 100644 src/lang2sql/components/gate/__init__.py create mode 100644 src/lang2sql/components/gate/prompts/question_gate.md create mode 100644 src/lang2sql/components/gate/prompts/table_suitability.md create mode 100644 src/lang2sql/components/gate/question_gate.py create mode 100644 src/lang2sql/components/gate/table_suitability.py create mode 100644 tests/test_components_context_enricher.py create mode 100644 tests/test_components_question_gate.py create mode 100644 tests/test_components_question_profiler.py create mode 100644 tests/test_components_table_suitability.py diff --git a/src/lang2sql/components/enrichment/__init__.py b/src/lang2sql/components/enrichment/__init__.py new file mode 100644 index 0000000..3abd0d1 --- /dev/null +++ b/src/lang2sql/components/enrichment/__init__.py @@ -0,0 +1,4 @@ +from .context_enricher import ContextEnricher +from .question_profiler import QuestionProfiler + +__all__ = ["ContextEnricher", "QuestionProfiler"] diff --git a/src/lang2sql/components/enrichment/context_enricher.py b/src/lang2sql/components/enrichment/context_enricher.py new file mode 100644 index 0000000..7d4d231 --- /dev/null +++ b/src/lang2sql/components/enrichment/context_enricher.py @@ -0,0 +1,58 @@ +from __future__ import annotations + +import dataclasses +import json +from pathlib import Path +from typing import Optional + +from ...core.base import BaseComponent +from ...core.catalog import CatalogEntry, QuestionProfile +from ...core.hooks import TraceHook +from ...core.ports import LLMPort + +_PROMPT_PATH = Path(__file__).parent / "prompts" / "context_enricher.md" + + +def _load_prompt() -> str: + return _PROMPT_PATH.read_text(encoding="utf-8").strip() + + +class ContextEnricher(BaseComponent): + """질문 프로파일 + 스키마 메타데이터로 질문을 보강한다.""" + + def __init__( + self, + *, + llm: LLMPort, + name: Optional[str] = None, + hook: Optional[TraceHook] = None, + ) -> None: + super().__init__(name=name or "ContextEnricher", hook=hook) + self._llm = llm + self._system_prompt = _load_prompt() + + def _run( + self, + query: str, + schemas: list[CatalogEntry], + profile: QuestionProfile, + ) -> str: + profiles_json = json.dumps(dataclasses.asdict(profile), ensure_ascii=False) + + tables_map: dict[str, dict] = { + entry.get("name", ""): { + "description": entry.get("description", ""), + "columns": entry.get("columns", {}), + } + for entry in schemas + } + tables_json = json.dumps(tables_map, ensure_ascii=False) + + user_content = ( + self._system_prompt + .replace("{profiles}", profiles_json) + .replace("{related_tables}", tables_json) + .replace("{refined_question}", query) + ) + messages = [{"role": "user", "content": user_content}] + return self._llm.invoke(messages).strip() diff --git a/src/lang2sql/components/enrichment/prompts/context_enricher.md b/src/lang2sql/components/enrichment/prompts/context_enricher.md new file mode 100644 index 0000000..a148dd9 --- /dev/null +++ b/src/lang2sql/components/enrichment/prompts/context_enricher.md @@ -0,0 +1,22 @@ +# Role + +You are a smart assistant that takes a user question and enriches it using: +1. Question profiles: {profiles} +2. Table metadata (names, columns, descriptions): + {related_tables} + +# Tasks + +- Correct any wrong terms by matching them to actual column names. +- If the question is time-series or aggregation, add explicit hints (e.g., "over the last 30 days"). +- If needed, map natural language terms to actual column values (e.g., '미국' → 'USA' for country_code). +- Output the enriched question only. + +# Input + +Refined question: +{refined_question} + +# Notes + +Using the refined version for enrichment, but keep the original intent in mind. diff --git a/src/lang2sql/components/enrichment/prompts/question_profiler.md b/src/lang2sql/components/enrichment/prompts/question_profiler.md new file mode 100644 index 0000000..806cedc --- /dev/null +++ b/src/lang2sql/components/enrichment/prompts/question_profiler.md @@ -0,0 +1,28 @@ +# Role + +You are an assistant that analyzes a user question and extracts the following profiles as JSON: +- is_timeseries (boolean) +- is_aggregation (boolean) +- has_filter (boolean) +- is_grouped (boolean) +- has_ranking (boolean) +- has_temporal_comparison (boolean) +- intent_type (one of: trend, lookup, comparison, distribution) + +# Input + +Question: +{question} + +# Output + +The output must be a valid JSON matching the schema below (no extra keys): +{{ + "is_timeseries": boolean, + "is_aggregation": boolean, + "has_filter": boolean, + "is_grouped": boolean, + "has_ranking": boolean, + "has_temporal_comparison": boolean, + "intent_type": string +}} diff --git a/src/lang2sql/components/enrichment/question_profiler.py b/src/lang2sql/components/enrichment/question_profiler.py new file mode 100644 index 0000000..f0dc5c8 --- /dev/null +++ b/src/lang2sql/components/enrichment/question_profiler.py @@ -0,0 +1,61 @@ +from __future__ import annotations + +import json +import re +from pathlib import Path +from typing import Optional + +from ...core.base import BaseComponent +from ...core.catalog import QuestionProfile +from ...core.hooks import TraceHook +from ...core.ports import LLMPort + +_PROMPT_PATH = Path(__file__).parent / "prompts" / "question_profiler.md" + +_VALID_INTENT_TYPES = {"trend", "lookup", "comparison", "distribution"} + + +def _load_prompt() -> str: + return _PROMPT_PATH.read_text(encoding="utf-8").strip() + + +def _parse_json(text: str) -> dict: + match = re.search(r"```(?:json)?\s*(.*?)```", text, re.DOTALL | re.IGNORECASE) + if match: + text = match.group(1).strip() + return json.loads(text) + + +class QuestionProfiler(BaseComponent): + """질문에서 구조화된 특성(시계열, 집계, 필터 등)을 추출한다.""" + + def __init__( + self, + *, + llm: LLMPort, + name: Optional[str] = None, + hook: Optional[TraceHook] = None, + ) -> None: + super().__init__(name=name or "QuestionProfiler", hook=hook) + self._llm = llm + self._system_prompt = _load_prompt() + + def _run(self, query: str) -> QuestionProfile: + user_content = self._system_prompt.replace("{question}", query) + messages = [{"role": "user", "content": user_content}] + response = self._llm.invoke(messages) + data = _parse_json(response) + + intent_type = str(data.get("intent_type", "lookup")) + if intent_type not in _VALID_INTENT_TYPES: + intent_type = "lookup" + + return QuestionProfile( + is_timeseries=bool(data.get("is_timeseries", False)), + is_aggregation=bool(data.get("is_aggregation", False)), + has_filter=bool(data.get("has_filter", False)), + is_grouped=bool(data.get("is_grouped", False)), + has_ranking=bool(data.get("has_ranking", False)), + has_temporal_comparison=bool(data.get("has_temporal_comparison", False)), + intent_type=intent_type, + ) diff --git a/src/lang2sql/components/gate/__init__.py b/src/lang2sql/components/gate/__init__.py new file mode 100644 index 0000000..0fc026c --- /dev/null +++ b/src/lang2sql/components/gate/__init__.py @@ -0,0 +1,4 @@ +from .question_gate import QuestionGate +from .table_suitability import TableSuitabilityEvaluator + +__all__ = ["QuestionGate", "TableSuitabilityEvaluator"] diff --git a/src/lang2sql/components/gate/prompts/question_gate.md b/src/lang2sql/components/gate/prompts/question_gate.md new file mode 100644 index 0000000..2e08150 --- /dev/null +++ b/src/lang2sql/components/gate/prompts/question_gate.md @@ -0,0 +1,25 @@ +당신은 데이터 분석 도우미입니다. 아래 사용자 질문이 SQL로 답변 가능한지 판별하고, 구조화된 결과를 반환하세요. + +요건: +- suitable: 질문이 SQL로 답변 가능한지 여부(Boolean) +- reason: 한 줄 설명(어떤 보완이 필요한지 요약) +- missing_entities: 기간, 대상 엔터티, 측정값 등 누락된 핵심 요소 리스트(없으면 빈 리스트) +- requires_data_science: 통계/ML 분석이 필요한지 여부(Boolean) + +언어/출력 형식: +- 모든 텍스트 값은 한국어로 작성하세요. (reason는 한국어 문장, missing_entities 항목은 한국어 명사구) +- Boolean 값은 JSON의 true/false로 표기하세요. + +주의: +- 데이터 분석 맥락에서 SQL 집계/필터/조인으로 해결 가능한지 판단합니다. +- 정책/운영/가이드/설치/권한/오류 해결 등은 SQL 부적합으로 간주합니다. + +입력: {question} + +출력은 반드시 아래 JSON 스키마로만 반환하세요: +{{ + "suitable": boolean, + "reason": string, + "missing_entities": string[], + "requires_data_science": boolean +}} diff --git a/src/lang2sql/components/gate/prompts/table_suitability.md b/src/lang2sql/components/gate/prompts/table_suitability.md new file mode 100644 index 0000000..23b0e53 --- /dev/null +++ b/src/lang2sql/components/gate/prompts/table_suitability.md @@ -0,0 +1,47 @@ +## 문서 적합성 평가 프롬프트 (Table Search 재랭킹) + +당신은 데이터 카탈로그 평가자입니다. 주어진 사용자 질문과 검색 결과(테이블 → 칼럼 설명 맵)를 바탕으로, 각 테이블이 질문에 얼마나 적합한지 0~1 사이의 실수 점수로 평가하세요. + +### 입력 +- **question**: {question} +- **tables**: {tables} + +### 과업 +1. **핵심 신호 추출**: 질문에서 엔터티/지표/시간/필터/그룹화 단서를 추출합니다. +2. **테이블별 점수화**: 각 테이블의 칼럼·설명과의 연관성으로 적합도를 점수화합니다(0~1, 소수 셋째 자리 반올림). +3. **근거와 보완점 제시**: 매칭된 칼럼과 부족한 요소(엔터티/지표/기간 등)를 한국어로 설명합니다. +4. **정렬**: 결과를 점수 내림차순으로 정렬해 반환합니다. + +### 평가 규칙(가이드) +- **0.90~1.00**: 필요한 엔터티, 기간/시간 컬럼, 핵심 지표/측정 칼럼이 모두 존재. 직접 조회/집계만으로 답 가능. +- **0.60~0.89**: 주요 신호 매칭, 일부 보완(기간/그룹 키/보조 칼럼) 필요. 조인 없이 근사 가능. +- **0.30~0.59**: 일부만 매칭. 외부 컨텍스트나 조인 없이는 부정확/제한적. +- **0.00~0.29**: 연관성 낮음. 스키마/도메인 불일치 또는 정책/운영성 테이블. + +### 주의 +- 칼럼 이름/설명에 실제로 존재하지 않는 항목을 매칭하지 마세요(환각 금지). +- 시간 요구(특정 날짜/기간)가 있으면 timestamp/date/created_at 등 시간 계열 키를 중시하세요. +- 엔티티 키(예: id, user_id, product_id)의 존재 여부를 가산점으로 반영하세요. +- 키 이름은 정확히 입력 맵의 키만 사용하세요(자유 추측 금지). + +### 언어/출력 형식 +- 모든 텍스트 값은 한국어로 작성하세요. +- 결과는 반드시 아래 JSON 스키마로만 반환하세요(추가/누락 키 금지). + +### 출력(JSON 스키마) +{{ + "results": [ + {{ + "table_name": string, + "score": number, + "reason": string, + "matched_columns": string[], + "missing_entities": string[] + }} + ] +}} + +### 검증 규칙 +- score는 [0, 1] 범위로 클램핑하고 소수 셋째 자리까지 반올림하세요. +- matched_columns는 해당 테이블 객체의 실제 키만 포함하세요(단, table_description 제외). +- reason 및 missing_entities는 한국어로 작성하세요. diff --git a/src/lang2sql/components/gate/question_gate.py b/src/lang2sql/components/gate/question_gate.py new file mode 100644 index 0000000..6d559eb --- /dev/null +++ b/src/lang2sql/components/gate/question_gate.py @@ -0,0 +1,55 @@ +from __future__ import annotations + +import json +import re +from pathlib import Path +from typing import Optional + +from ...core.base import BaseComponent +from ...core.catalog import GateResult +from ...core.hooks import TraceHook +from ...core.ports import LLMPort + +_PROMPT_PATH = Path(__file__).parent / "prompts" / "question_gate.md" + + +def _load_prompt() -> str: + return _PROMPT_PATH.read_text(encoding="utf-8").strip() + + +def _parse_json(text: str) -> dict: + """LLM 응답에서 JSON을 추출한다. 마크다운 코드블록을 자동으로 제거한다.""" + # ```json ... ``` 또는 ``` ... ``` 블록 제거 + match = re.search(r"```(?:json)?\s*(.*?)```", text, re.DOTALL | re.IGNORECASE) + if match: + text = match.group(1).strip() + return json.loads(text) + + +class QuestionGate(BaseComponent): + """질문이 SQL로 답변 가능한지 판별한다.""" + + def __init__( + self, + *, + llm: LLMPort, + name: Optional[str] = None, + hook: Optional[TraceHook] = None, + ) -> None: + super().__init__(name=name or "QuestionGate", hook=hook) + self._llm = llm + self._system_prompt = _load_prompt() + + def _run(self, query: str) -> GateResult: + user_content = self._system_prompt.replace("{question}", query) + messages = [ + {"role": "user", "content": user_content}, + ] + response = self._llm.invoke(messages) + data = _parse_json(response) + return GateResult( + suitable=bool(data.get("suitable", True)), + reason=str(data.get("reason", "")), + missing_entities=list(data.get("missing_entities", [])), + requires_data_science=bool(data.get("requires_data_science", False)), + ) diff --git a/src/lang2sql/components/gate/table_suitability.py b/src/lang2sql/components/gate/table_suitability.py new file mode 100644 index 0000000..091bbb3 --- /dev/null +++ b/src/lang2sql/components/gate/table_suitability.py @@ -0,0 +1,80 @@ +from __future__ import annotations + +import json +import re +from pathlib import Path +from typing import Optional + +from ...core.base import BaseComponent +from ...core.catalog import CatalogEntry, TableScore +from ...core.hooks import TraceHook +from ...core.ports import LLMPort + +_PROMPT_PATH = Path(__file__).parent / "prompts" / "table_suitability.md" + + +def _load_prompt() -> str: + return _PROMPT_PATH.read_text(encoding="utf-8").strip() + + +def _parse_json(text: str) -> dict: + match = re.search(r"```(?:json)?\s*(.*?)```", text, re.DOTALL | re.IGNORECASE) + if match: + text = match.group(1).strip() + return json.loads(text) + + +class TableSuitabilityEvaluator(BaseComponent): + """검색된 테이블을 질문 관련도순으로 필터링한다.""" + + def __init__( + self, + *, + llm: LLMPort, + threshold: float = 0.3, + name: Optional[str] = None, + hook: Optional[TraceHook] = None, + ) -> None: + super().__init__(name=name or "TableSuitabilityEvaluator", hook=hook) + self._llm = llm + self._threshold = threshold + self._system_prompt = _load_prompt() + + def _run(self, query: str, schemas: list[CatalogEntry]) -> list[CatalogEntry]: + # 테이블을 {table_name: {col: desc, ...}} 구조로 직렬화 + tables_map: dict[str, dict] = {} + for entry in schemas: + name = entry.get("name", "") + cols = entry.get("columns", {}) + desc = entry.get("description", "") + tables_map[name] = {"table_description": desc, **cols} + + tables_json = json.dumps(tables_map, ensure_ascii=False) + user_content = ( + self._system_prompt + .replace("{question}", query) + .replace("{tables}", tables_json) + ) + messages = [{"role": "user", "content": user_content}] + response = self._llm.invoke(messages) + data = _parse_json(response) + + results: list[TableScore] = [ + TableScore( + table_name=r["table_name"], + score=float(r.get("score", 0.0)), + reason=str(r.get("reason", "")), + matched_columns=list(r.get("matched_columns", [])), + missing_entities=list(r.get("missing_entities", [])), + ) + for r in data.get("results", []) + ] + + # threshold 이상인 테이블만 score 내림차순으로 필터 + passing = {r.table_name for r in results if r.score >= self._threshold} + filtered = [e for e in schemas if e.get("name", "") in passing] + + # score 내림차순 정렬 + score_map = {r.table_name: r.score for r in results} + filtered.sort(key=lambda e: score_map.get(e.get("name", ""), 0.0), reverse=True) + return filtered diff --git a/src/lang2sql/core/catalog.py b/src/lang2sql/core/catalog.py index bda5bf9..56413d0 100644 --- a/src/lang2sql/core/catalog.py +++ b/src/lang2sql/core/catalog.py @@ -37,3 +37,37 @@ class RetrievalResult: schemas: list[CatalogEntry] = field(default_factory=list) context: list[str] = field(default_factory=list) + + +@dataclass +class GateResult: + """QuestionGate 컴포넌트의 반환 타입.""" + + suitable: bool + reason: str + missing_entities: list[str] = field(default_factory=list) + requires_data_science: bool = False + + +@dataclass +class QuestionProfile: + """질문 특성 프로파일.""" + + is_timeseries: bool = False + is_aggregation: bool = False + has_filter: bool = False + is_grouped: bool = False + has_ranking: bool = False + has_temporal_comparison: bool = False + intent_type: str = "lookup" # trend | lookup | comparison | distribution + + +@dataclass +class TableScore: + """개별 테이블의 적합도 평가 결과.""" + + table_name: str + score: float + reason: str + matched_columns: list[str] = field(default_factory=list) + missing_entities: list[str] = field(default_factory=list) diff --git a/tests/test_components_context_enricher.py b/tests/test_components_context_enricher.py new file mode 100644 index 0000000..3d4e927 --- /dev/null +++ b/tests/test_components_context_enricher.py @@ -0,0 +1,76 @@ +"""Tests for ContextEnricher component.""" + +from __future__ import annotations + +import pytest + +from lang2sql.components.enrichment.context_enricher import ContextEnricher +from lang2sql.core.catalog import CatalogEntry, QuestionProfile +from lang2sql.core.hooks import MemoryHook + + +class FakeLLM: + def __init__(self, response: str = "지난달(2024-03) 주문 건수를 COUNT하는 쿼리"): + self._response = response + + def invoke(self, messages: list[dict]) -> str: + return self._response + + +def _catalog() -> list[CatalogEntry]: + return [ + { + "name": "orders", + "description": "주문 테이블", + "columns": {"order_id": "주문 ID", "amount": "주문 금액", "created_at": "생성일"}, + } + ] + + +def _profile(is_aggregation: bool = True) -> QuestionProfile: + return QuestionProfile( + is_aggregation=is_aggregation, + has_filter=True, + intent_type="lookup", + ) + + +def test_context_enricher_returns_string(): + llm = FakeLLM("enriched question text") + enricher = ContextEnricher(llm=llm) + result = enricher.run("주문 수", _catalog(), _profile()) + assert isinstance(result, str) + assert result == "enriched question text" + + +def test_context_enricher_trims_whitespace(): + llm = FakeLLM(" enriched ") + enricher = ContextEnricher(llm=llm) + result = enricher.run("test", _catalog(), _profile()) + assert result == "enriched" + + +def test_context_enricher_emits_hook_events(): + hook = MemoryHook() + llm = FakeLLM("enriched") + enricher = ContextEnricher(llm=llm, hook=hook) + enricher.run("test", _catalog(), _profile()) + phases = [e.phase for e in hook.events] + assert "start" in phases + assert "end" in phases + + +def test_context_enricher_passes_profile_to_llm(): + received_messages = [] + + class CaptureLLM: + def invoke(self, messages: list[dict]) -> str: + received_messages.extend(messages) + return "ok" + + profiler = QuestionProfile(is_timeseries=True, intent_type="trend") + enricher = ContextEnricher(llm=CaptureLLM()) + enricher.run("월별 추이", _catalog(), profiler) + assert received_messages + # profile JSON should appear in the message content + assert "is_timeseries" in received_messages[0]["content"] diff --git a/tests/test_components_question_gate.py b/tests/test_components_question_gate.py new file mode 100644 index 0000000..21097ec --- /dev/null +++ b/tests/test_components_question_gate.py @@ -0,0 +1,89 @@ +"""Tests for QuestionGate component.""" + +from __future__ import annotations + +import json + +import pytest + +from lang2sql.components.gate.question_gate import QuestionGate +from lang2sql.core.catalog import GateResult +from lang2sql.core.hooks import MemoryHook + + +class FakeLLM: + def __init__(self, response: str): + self._response = response + + def invoke(self, messages: list[dict]) -> str: + return self._response + + +def _gate_json( + suitable: bool = True, + reason: str = "SQL로 답변 가능합니다.", + missing_entities: list | None = None, + requires_data_science: bool = False, +) -> str: + return json.dumps( + { + "suitable": suitable, + "reason": reason, + "missing_entities": missing_entities or [], + "requires_data_science": requires_data_science, + }, + ensure_ascii=False, + ) + + +def test_question_gate_suitable_true(): + llm = FakeLLM(_gate_json(suitable=True, reason="OK")) + gate = QuestionGate(llm=llm) + result = gate.run("지난달 주문 수는?") + assert isinstance(result, GateResult) + assert result.suitable is True + assert result.reason == "OK" + + +def test_question_gate_suitable_false(): + llm = FakeLLM( + _gate_json( + suitable=False, + reason="통계 분석이 필요합니다.", + requires_data_science=True, + ) + ) + gate = QuestionGate(llm=llm) + result = gate.run("이상 탐지 모델을 만들어줘") + assert result.suitable is False + assert result.requires_data_science is True + + +def test_question_gate_with_missing_entities(): + llm = FakeLLM( + _gate_json( + suitable=False, + missing_entities=["기간", "대상 엔터티"], + ) + ) + gate = QuestionGate(llm=llm) + result = gate.run("매출을 보여줘") + assert "기간" in result.missing_entities + + +def test_question_gate_strips_markdown_json(): + raw = "```json\n" + _gate_json(suitable=True) + "\n```" + llm = FakeLLM(raw) + gate = QuestionGate(llm=llm) + result = gate.run("test") + assert result.suitable is True + + +def test_question_gate_emits_hook_events(): + hook = MemoryHook() + llm = FakeLLM(_gate_json()) + gate = QuestionGate(llm=llm, hook=hook) + gate.run("test query") + phases = [e.phase for e in hook.events] + assert "start" in phases + assert "end" in phases diff --git a/tests/test_components_question_profiler.py b/tests/test_components_question_profiler.py new file mode 100644 index 0000000..deb5063 --- /dev/null +++ b/tests/test_components_question_profiler.py @@ -0,0 +1,75 @@ +"""Tests for QuestionProfiler component.""" + +from __future__ import annotations + +import json + +import pytest + +from lang2sql.components.enrichment.question_profiler import QuestionProfiler +from lang2sql.core.catalog import QuestionProfile +from lang2sql.core.hooks import MemoryHook + + +class FakeLLM: + def __init__(self, response: str): + self._response = response + + def invoke(self, messages: list[dict]) -> str: + return self._response + + +def _profile_json(**kwargs) -> str: + defaults = { + "is_timeseries": False, + "is_aggregation": True, + "has_filter": True, + "is_grouped": False, + "has_ranking": False, + "has_temporal_comparison": False, + "intent_type": "lookup", + } + defaults.update(kwargs) + return json.dumps(defaults) + + +def test_question_profiler_returns_profile(): + llm = FakeLLM(_profile_json(is_aggregation=True, has_filter=True)) + profiler = QuestionProfiler(llm=llm) + result = profiler.run("지난달 주문 수") + assert isinstance(result, QuestionProfile) + assert result.is_aggregation is True + assert result.has_filter is True + + +def test_question_profiler_timeseries(): + llm = FakeLLM(_profile_json(is_timeseries=True, intent_type="trend")) + profiler = QuestionProfiler(llm=llm) + result = profiler.run("월별 매출 추이") + assert result.is_timeseries is True + assert result.intent_type == "trend" + + +def test_question_profiler_invalid_intent_type_falls_back_to_lookup(): + llm = FakeLLM(_profile_json(intent_type="invalid_type")) + profiler = QuestionProfiler(llm=llm) + result = profiler.run("test") + assert result.intent_type == "lookup" + + +def test_question_profiler_strips_markdown_json(): + raw = "```json\n" + _profile_json() + "\n```" + llm = FakeLLM(raw) + profiler = QuestionProfiler(llm=llm) + result = profiler.run("test") + assert isinstance(result, QuestionProfile) + + +def test_question_profiler_emits_hook_events(): + hook = MemoryHook() + llm = FakeLLM(_profile_json()) + profiler = QuestionProfiler(llm=llm, hook=hook) + profiler.run("test") + phases = [e.phase for e in hook.events] + assert "start" in phases + assert "end" in phases diff --git a/tests/test_components_table_suitability.py b/tests/test_components_table_suitability.py new file mode 100644 index 0000000..e1c1b7d --- /dev/null +++ b/tests/test_components_table_suitability.py @@ -0,0 +1,111 @@ +"""Tests for TableSuitabilityEvaluator component.""" + +from __future__ import annotations + +import json + +import pytest + +from lang2sql.components.gate.table_suitability import TableSuitabilityEvaluator +from lang2sql.core.catalog import CatalogEntry +from lang2sql.core.hooks import MemoryHook + + +class FakeLLM: + def __init__(self, response: str): + self._response = response + + def invoke(self, messages: list[dict]) -> str: + return self._response + + +def _catalog() -> list[CatalogEntry]: + return [ + { + "name": "orders", + "description": "주문 테이블", + "columns": {"order_id": "주문 ID", "amount": "주문 금액", "created_at": "생성일"}, + }, + { + "name": "users", + "description": "사용자 테이블", + "columns": {"user_id": "사용자 ID", "name": "이름"}, + }, + ] + + +def _suitability_json(results: list[dict]) -> str: + return json.dumps({"results": results}, ensure_ascii=False) + + +def test_table_suitability_filters_below_threshold(): + resp = _suitability_json( + [ + { + "table_name": "orders", + "score": 0.9, + "reason": "핵심 지표 포함", + "matched_columns": ["amount"], + "missing_entities": [], + }, + { + "table_name": "users", + "score": 0.1, + "reason": "관련 없음", + "matched_columns": [], + "missing_entities": ["주문 정보"], + }, + ] + ) + evaluator = TableSuitabilityEvaluator(llm=FakeLLM(resp), threshold=0.3) + result = evaluator.run("지난달 주문 금액 합계", _catalog()) + assert len(result) == 1 + assert result[0]["name"] == "orders" + + +def test_table_suitability_sorted_by_score(): + resp = _suitability_json( + [ + { + "table_name": "users", + "score": 0.5, + "reason": "부분 매칭", + "matched_columns": [], + "missing_entities": [], + }, + { + "table_name": "orders", + "score": 0.9, + "reason": "완전 매칭", + "matched_columns": ["amount"], + "missing_entities": [], + }, + ] + ) + evaluator = TableSuitabilityEvaluator(llm=FakeLLM(resp), threshold=0.3) + result = evaluator.run("주문 금액", _catalog()) + assert result[0]["name"] == "orders" + assert result[1]["name"] == "users" + + +def test_table_suitability_empty_result_when_all_below_threshold(): + resp = _suitability_json( + [ + {"table_name": "orders", "score": 0.1, "reason": "낮은 관련성", "matched_columns": [], "missing_entities": []}, + ] + ) + evaluator = TableSuitabilityEvaluator(llm=FakeLLM(resp), threshold=0.3) + result = evaluator.run("관련 없는 질문", _catalog()) + assert result == [] + + +def test_table_suitability_emits_hook_events(): + hook = MemoryHook() + resp = _suitability_json( + [{"table_name": "orders", "score": 0.8, "reason": "ok", "matched_columns": [], "missing_entities": []}] + ) + evaluator = TableSuitabilityEvaluator(llm=FakeLLM(resp), hook=hook) + evaluator.run("test", _catalog()) + phases = [e.phase for e in hook.events] + assert "start" in phases + assert "end" in phases From 014fa857e16b5c4360e80990a5fe32b19c1c77e9 Mon Sep 17 00:00:00 2001 From: seyeong Date: Sat, 28 Feb 2026 15:32:08 +0900 Subject: [PATCH 05/10] =?UTF-8?q?feat(flows):=20add=20EnrichedNL2SQL=20?= =?UTF-8?q?=E2=80=94=207-step=20NL=E2=86=92SQL=20pipeline?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/lang2sql/__init__.py | 64 ++++++++- src/lang2sql/flows/__init__.py | 4 +- src/lang2sql/flows/enriched_nl2sql.py | 101 +++++++++++++ tests/test_flows_enriched_nl2sql.py | 196 ++++++++++++++++++++++++++ 4 files changed, 360 insertions(+), 5 deletions(-) create mode 100644 src/lang2sql/flows/enriched_nl2sql.py create mode 100644 tests/test_flows_enriched_nl2sql.py diff --git a/src/lang2sql/__init__.py b/src/lang2sql/__init__.py index 2dcdc49..2781ba9 100644 --- a/src/lang2sql/__init__.py +++ b/src/lang2sql/__init__.py @@ -1,6 +1,9 @@ -from .integrations.vectorstore.faiss_ import FAISSVectorStore -from .integrations.vectorstore.pgvector_ import PGVectorStore +from .factory import build_db_from_env, build_embedding_from_env, build_llm_from_env +from .components.enrichment.context_enricher import ContextEnricher +from .components.enrichment.question_profiler import QuestionProfiler from .components.execution.sql_executor import SQLExecutor +from .components.gate.question_gate import QuestionGate +from .components.gate.table_suitability import TableSuitabilityEvaluator from .components.generation.sql_generator import SQLGenerator from .components.loaders.directory_ import DirectoryLoader from .components.loaders.markdown_ import MarkdownLoader @@ -13,7 +16,15 @@ from .components.retrieval.hybrid import HybridRetriever from .components.retrieval.keyword import KeywordRetriever from .components.retrieval.vector import VectorRetriever -from .core.catalog import CatalogEntry, IndexedChunk, RetrievalResult, TextDocument +from .core.catalog import ( + CatalogEntry, + GateResult, + IndexedChunk, + QuestionProfile, + RetrievalResult, + TableScore, + TextDocument, +) from .core.exceptions import ComponentError, IntegrationMissingError, Lang2SQLError from .core.hooks import MemoryHook, NullHook, TraceHook from .core.ports import ( @@ -23,8 +34,22 @@ LLMPort, VectorStorePort, ) +from .flows.enriched_nl2sql import EnrichedNL2SQL from .flows.hybrid import HybridNL2SQL from .flows.nl2sql import BaselineNL2SQL +from .integrations.catalog.datahub_ import DataHubCatalogLoader +from .integrations.embedding.azure_ import AzureOpenAIEmbedding +from .integrations.embedding.bedrock_ import BedrockEmbedding +from .integrations.embedding.gemini_ import GeminiEmbedding +from .integrations.embedding.huggingface_ import HuggingFaceEmbedding +from .integrations.embedding.ollama_ import OllamaEmbedding +from .integrations.llm.azure_ import AzureOpenAILLM +from .integrations.llm.bedrock_ import BedrockLLM +from .integrations.llm.gemini_ import GeminiLLM +from .integrations.llm.huggingface_ import HuggingFaceLLM +from .integrations.llm.ollama_ import OllamaLLM +from .integrations.vectorstore.faiss_ import FAISSVectorStore +from .integrations.vectorstore.pgvector_ import PGVectorStore __all__ = [ # Data types @@ -32,27 +57,40 @@ "TextDocument", "IndexedChunk", "RetrievalResult", + # Domain types (Phase 4) + "GateResult", + "QuestionProfile", + "TableScore", # Ports (protocols) "LLMPort", "DBPort", "EmbeddingPort", "VectorStorePort", "DocumentLoaderPort", - # Components + # Components — retrieval "KeywordRetriever", "VectorRetriever", "HybridRetriever", "DocumentChunkerPort", "CatalogChunker", "RecursiveCharacterChunker", + # Components — generation & execution "SQLGenerator", "SQLExecutor", + # Components — gate (Phase 4) + "QuestionGate", + "TableSuitabilityEvaluator", + # Components — enrichment (Phase 4) + "QuestionProfiler", + "ContextEnricher", + # Components — loaders "MarkdownLoader", "PlainTextLoader", "DirectoryLoader", # Flows "BaselineNL2SQL", "HybridNL2SQL", + "EnrichedNL2SQL", # Hooks "TraceHook", "MemoryHook", @@ -64,4 +102,22 @@ # Vector store backends "FAISSVectorStore", "PGVectorStore", + # LLM integrations (Phase 1) + "AzureOpenAILLM", + "BedrockLLM", + "GeminiLLM", + "HuggingFaceLLM", + "OllamaLLM", + # Embedding integrations (Phase 2) + "AzureOpenAIEmbedding", + "BedrockEmbedding", + "GeminiEmbedding", + "HuggingFaceEmbedding", + "OllamaEmbedding", + # Catalog integrations (Phase 3) + "DataHubCatalogLoader", + # Factory (Phase 6) + "build_llm_from_env", + "build_embedding_from_env", + "build_db_from_env", ] diff --git a/src/lang2sql/flows/__init__.py b/src/lang2sql/flows/__init__.py index f50eb0e..b6aa9ac 100644 --- a/src/lang2sql/flows/__init__.py +++ b/src/lang2sql/flows/__init__.py @@ -1,3 +1,5 @@ +from .enriched_nl2sql import EnrichedNL2SQL +from .hybrid import HybridNL2SQL from .nl2sql import BaselineNL2SQL -__all__ = ["BaselineNL2SQL"] +__all__ = ["BaselineNL2SQL", "EnrichedNL2SQL", "HybridNL2SQL"] diff --git a/src/lang2sql/flows/enriched_nl2sql.py b/src/lang2sql/flows/enriched_nl2sql.py new file mode 100644 index 0000000..e01ea4b --- /dev/null +++ b/src/lang2sql/flows/enriched_nl2sql.py @@ -0,0 +1,101 @@ +from __future__ import annotations + +from typing import Optional + +from ..components.enrichment.context_enricher import ContextEnricher +from ..components.enrichment.question_profiler import QuestionProfiler +from ..components.execution.sql_executor import SQLExecutor +from ..components.gate.question_gate import QuestionGate +from ..components.gate.table_suitability import TableSuitabilityEvaluator +from ..components.generation.sql_generator import SQLGenerator +from ..components.retrieval.hybrid import HybridRetriever +from ..core.base import BaseFlow +from ..core.catalog import TextDocument +from ..core.exceptions import ContractError +from ..core.hooks import TraceHook +from ..core.ports import DBPort, EmbeddingPort, LLMPort + + +class EnrichedNL2SQL(BaseFlow): + """ + 풀 파이프라인 NL→SQL: + QuestionGate → HybridRetriever → TableSuitabilityEvaluator + → QuestionProfiler → ContextEnricher → SQLGenerator → SQLExecutor + + 레거시 LangGraph 기반 engine/query_executor.py + graph_utils/enriched_graph.py를 대체한다. + + Args: + catalog: list[CatalogEntry] — 검색 대상 테이블 메타데이터. + llm: LLMPort — 질문 평가, SQL 생성 등에 사용. + db: DBPort — SQL 실행 대상 데이터베이스. + embedding: EmbeddingPort — 벡터 검색용 임베딩 모델. + documents: Optional list of business documents to index. + db_dialect: SQL 방언. "sqlite", "postgresql", "mysql", "bigquery", "duckdb", "default". + gate_enabled: QuestionGate를 활성화할지 여부. Default True. + top_n: HybridRetriever가 반환할 최대 스키마 수. Default 5. + hook: TraceHook for observability. + + Usage:: + + pipeline = EnrichedNL2SQL( + catalog=[{"name": "orders", "description": "...", "columns": {...}}], + llm=AnthropicLLM(model="claude-sonnet-4-6"), + db=SQLAlchemyDB("sqlite:///sample.db"), + embedding=OpenAIEmbedding(model="text-embedding-3-small"), + db_dialect="sqlite", + ) + rows = pipeline.run("지난달 주문 건수를 알려줘") + """ + + def __init__( + self, + *, + catalog: list[dict], + llm: LLMPort, + db: DBPort, + embedding: EmbeddingPort, + documents: Optional[list[TextDocument]] = None, + db_dialect: Optional[str] = None, + gate_enabled: bool = True, + top_n: int = 5, + hook: Optional[TraceHook] = None, + ) -> None: + super().__init__(name="EnrichedNL2SQL", hook=hook) + self._gate = QuestionGate(llm=llm, hook=hook) if gate_enabled else None + self._retriever = HybridRetriever( + catalog=catalog, + embedding=embedding, + documents=documents, + top_n=top_n, + hook=hook, + ) + self._table_eval = TableSuitabilityEvaluator(llm=llm, hook=hook) + self._profiler = QuestionProfiler(llm=llm, hook=hook) + self._enricher = ContextEnricher(llm=llm, hook=hook) + self._generator = SQLGenerator(llm=llm, db_dialect=db_dialect, hook=hook) + self._executor = SQLExecutor(db=db, hook=hook) + + def _run(self, query: str) -> list[dict]: + # 1. Gate (선택적): 질문이 SQL 답변 불가능하면 ContractError 발생 + if self._gate is not None: + gate = self._gate(query) + if not gate.suitable: + raise ContractError(f"Query not suitable for SQL: {gate.reason}") + + # 2. Retrieval: HybridRetriever → RetrievalResult + result = self._retriever(query) + + # 3. Table filtering: 관련도 낮은 테이블 제거 + schemas = self._table_eval(query, result.schemas) + + # 4. Profiling: 질문 특성 추출 + profile = self._profiler(query) + + # 5. Context enrichment: 보강된 질문 텍스트 생성 + enriched = self._enricher(query, schemas, profile) + + # 6. SQL generation: 보강된 컨텍스트 + 도메인 문서를 함께 전달 + sql = self._generator(query, schemas, context=[enriched] + result.context) + + # 7. Execution + return self._executor(sql) diff --git a/tests/test_flows_enriched_nl2sql.py b/tests/test_flows_enriched_nl2sql.py new file mode 100644 index 0000000..f69692a --- /dev/null +++ b/tests/test_flows_enriched_nl2sql.py @@ -0,0 +1,196 @@ +"""Tests for EnrichedNL2SQL flow.""" + +from __future__ import annotations + +import json +from unittest.mock import MagicMock + +import pytest + +from lang2sql.core.catalog import CatalogEntry +from lang2sql.core.exceptions import ContractError +from lang2sql.core.hooks import MemoryHook +from lang2sql.flows.enriched_nl2sql import EnrichedNL2SQL + + +# --------------------------------------------------------------------------- +# Fakes +# --------------------------------------------------------------------------- + + +class FakeLLM: + """Configurable fake LLM that cycles through responses.""" + + def __init__(self, responses: list[str]): + self._responses = list(responses) + self._idx = 0 + + def invoke(self, messages: list[dict]) -> str: + resp = self._responses[self._idx % len(self._responses)] + self._idx += 1 + return resp + + +class FakeDB: + def execute(self, sql: str) -> list[dict]: + return [{"count": 42}] + + +class FakeEmbedding: + def embed_query(self, text: str) -> list[float]: + return [0.1, 0.2, 0.3] + + def embed_texts(self, texts: list[str]) -> list[list[float]]: + return [[0.1, 0.2, 0.3] for _ in texts] + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _gate_json(suitable: bool = True) -> str: + return json.dumps( + { + "suitable": suitable, + "reason": "ok" if suitable else "not suitable", + "missing_entities": [], + "requires_data_science": False, + } + ) + + +def _suitability_json(table_names: list[str], score: float = 0.9) -> str: + return json.dumps( + { + "results": [ + { + "table_name": name, + "score": score, + "reason": "ok", + "matched_columns": [], + "missing_entities": [], + } + for name in table_names + ] + } + ) + + +def _profile_json() -> str: + return json.dumps( + { + "is_timeseries": False, + "is_aggregation": True, + "has_filter": False, + "is_grouped": False, + "has_ranking": False, + "has_temporal_comparison": False, + "intent_type": "lookup", + } + ) + + +def _catalog() -> list[CatalogEntry]: + return [ + { + "name": "orders", + "description": "주문 테이블", + "columns": {"order_id": "ID", "amount": "금액", "created_at": "생성일"}, + } + ] + + +def _make_pipeline(gate_enabled: bool = True) -> EnrichedNL2SQL: + # LLM response order: + # 1. QuestionGate → gate JSON + # 2. TableSuitabilityEvaluator → suitability JSON + # 3. QuestionProfiler → profile JSON + # 4. ContextEnricher → enriched string + # 5. SQLGenerator → sql block + llm = FakeLLM( + [ + _gate_json(suitable=True), + _suitability_json(["orders"]), + _profile_json(), + "지난달 주문 건수를 구합니다", + "```sql\nSELECT COUNT(*) FROM orders\n```", + ] + ) + return EnrichedNL2SQL( + catalog=_catalog(), + llm=llm, + db=FakeDB(), + embedding=FakeEmbedding(), + db_dialect="sqlite", + gate_enabled=gate_enabled, + ) + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +def test_enriched_nl2sql_full_pipeline_returns_rows(): + pipeline = _make_pipeline() + rows = pipeline.run("지난달 주문 건수") + assert rows == [{"count": 42}] + + +def test_enriched_nl2sql_gate_disabled_skips_gate(): + # With gate disabled, LLM responses shift by one + llm = FakeLLM( + [ + _suitability_json(["orders"]), + _profile_json(), + "enriched query", + "```sql\nSELECT COUNT(*) FROM orders\n```", + ] + ) + pipeline = EnrichedNL2SQL( + catalog=_catalog(), + llm=llm, + db=FakeDB(), + embedding=FakeEmbedding(), + gate_enabled=False, + ) + rows = pipeline.run("주문 건수") + assert rows == [{"count": 42}] + + +def test_enriched_nl2sql_gate_raises_contract_error_when_not_suitable(): + llm = FakeLLM([_gate_json(suitable=False)]) + pipeline = EnrichedNL2SQL( + catalog=_catalog(), + llm=llm, + db=FakeDB(), + embedding=FakeEmbedding(), + gate_enabled=True, + ) + with pytest.raises(ContractError): + pipeline.run("통계 모델을 만들어줘") + + +def test_enriched_nl2sql_emits_hook_events(): + hook = MemoryHook() + llm = FakeLLM( + [ + _gate_json(suitable=True), + _suitability_json(["orders"]), + _profile_json(), + "enriched", + "```sql\nSELECT COUNT(*) FROM orders\n```", + ] + ) + pipeline = EnrichedNL2SQL( + catalog=_catalog(), + llm=llm, + db=FakeDB(), + embedding=FakeEmbedding(), + gate_enabled=True, + hook=hook, + ) + pipeline.run("주문 건수") + # At minimum, the flow itself emits start/end + assert len(hook.events) > 0 From 83129c87a51616de73032e531ef39010b1571a44 Mon Sep 17 00:00:00 2001 From: seyeong Date: Sat, 28 Feb 2026 15:32:22 +0900 Subject: [PATCH 06/10] feat(factory,cli): replace LangChain factory with env-based build_*_from_env() --- cli/commands/quary.py | 115 +++++++++------------ src/lang2sql/factory.py | 218 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 263 insertions(+), 70 deletions(-) create mode 100644 src/lang2sql/factory.py diff --git a/cli/commands/quary.py b/cli/commands/quary.py index 0db9b9c..da1a05a 100644 --- a/cli/commands/quary.py +++ b/cli/commands/quary.py @@ -4,8 +4,6 @@ `query` CLI 명령어를 제공합니다. """ -import os - import click from cli.utils.logger import configure_logging @@ -16,14 +14,10 @@ @click.command(name="query") @click.argument("question", type=str) @click.option( - "--database-env", - default="clickhouse", - help="사용할 데이터베이스 환경 (기본값: clickhouse)", -) -@click.option( - "--retriever-name", - default="기본", - help="테이블 검색기 이름 (기본값: 기본)", + "--flow", + type=click.Choice(["baseline", "enriched"]), + default="baseline", + help="사용할 플로우 (기본값: baseline)", ) @click.option( "--top-n", @@ -32,81 +26,62 @@ help="검색된 상위 테이블 수 제한 (기본값: 5)", ) @click.option( - "--device", - default="cpu", - help="LLM 실행에 사용할 디바이스 (기본값: cpu)", + "--dialect", + default=None, + help="SQL 방언 (예: sqlite, postgresql, mysql, bigquery, duckdb)", ) @click.option( - "--use-enriched-graph", + "--no-gate", is_flag=True, - help="확장된 그래프(프로파일 추출 + 컨텍스트 보강) 사용 여부", -) -@click.option( - "--vectordb-type", - type=click.Choice(["faiss", "pgvector"]), - default="faiss", - help="사용할 벡터 데이터베이스 타입 (기본값: faiss)", -) -@click.option( - "--vectordb-location", - help=( - "VectorDB 위치 설정\n" - "- FAISS: 디렉토리 경로 (예: ./my_vectordb)\n" - "- pgvector: 연결 문자열 (예: postgresql://user:pass@host:port/db)\n" - "기본값: FAISS는 './dev/table_info_db', pgvector는 환경변수 사용" - ), + help="QuestionGate 비활성화 (enriched 플로우 전용)", ) def query_command( question: str, - database_env: str, - retriever_name: str, + flow: str, top_n: int, - device: str, - use_enriched_graph: bool, - vectordb_type: str = "faiss", - vectordb_location: str = None, + dialect: str, + no_gate: bool, ) -> None: - """자연어 질문을 SQL 쿼리로 변환하여 출력합니다. + """자연어 질문을 SQL 쿼리로 변환하여 실행 결과를 출력합니다. - Args: - question (str): SQL로 변환할 자연어 질문 - database_env (str): 사용할 데이터베이스 환경 - retriever_name (str): 테이블 검색기 이름 - top_n (int): 검색된 상위 테이블 수 제한 - device (str): LLM 실행 디바이스 - use_enriched_graph (bool): 확장된 그래프 사용 여부 - vectordb_type (str): 벡터 데이터베이스 타입 ("faiss" 또는 "pgvector") - vectordb_location (Optional[str]): 벡터DB 경로 또는 연결 URL + 환경변수(LLM_PROVIDER, EMBEDDING_PROVIDER, DB_TYPE 등)로 설정을 제어합니다. """ try: - from engine.query_executor import execute_query, extract_sql_from_result - - os.environ["VECTORDB_TYPE"] = vectordb_type + from lang2sql.factory import ( + build_db_from_env, + build_embedding_from_env, + build_llm_from_env, + ) + from lang2sql.flows import BaselineNL2SQL, EnrichedNL2SQL - if vectordb_location: - os.environ["VECTORDB_LOCATION"] = vectordb_location + llm = build_llm_from_env() + db = build_db_from_env() - res = execute_query( - query=question, - database_env=database_env, - retriever_name=retriever_name, - top_n=top_n, - device=device, - use_enriched_graph=use_enriched_graph, - ) + if flow == "baseline": + pipeline = BaselineNL2SQL( + catalog=[], + llm=llm, + db=db, + db_dialect=dialect, + ) + else: + embedding = build_embedding_from_env() + pipeline = EnrichedNL2SQL( + catalog=[], + llm=llm, + db=db, + embedding=embedding, + db_dialect=dialect, + gate_enabled=not no_gate, + top_n=top_n, + ) - sql = extract_sql_from_result(res) - if sql: - print(sql) + rows = pipeline.run(question) + if rows: + import json + print(json.dumps(rows, ensure_ascii=False, indent=2)) else: - generated_query = res.get("generated_query") - if generated_query: - query_text = ( - generated_query.content - if hasattr(generated_query, "content") - else str(generated_query) - ) - print(query_text) + print("(결과 없음)") except Exception as e: logger.error("쿼리 처리 중 오류 발생: %s", e) diff --git a/src/lang2sql/factory.py b/src/lang2sql/factory.py new file mode 100644 index 0000000..69539a6 --- /dev/null +++ b/src/lang2sql/factory.py @@ -0,0 +1,218 @@ +"""환경변수 기반 LLM/Embedding/DB 인스턴스 팩토리. + +레거시 utils/llm/core/factory.py를 LangChain 없이 재구현한 것. +CLI와 Streamlit UI 양쪽에서 사용한다. +""" +from __future__ import annotations + +import os + +from .core.ports import DBPort, EmbeddingPort, LLMPort + + +def build_llm_from_env() -> LLMPort: + """환경변수 LLM_PROVIDER에 따라 적절한 LLMPort 인스턴스를 생성한다.""" + provider = os.getenv("LLM_PROVIDER", "openai").lower() + + if provider == "openai": + from .integrations.llm.openai_ import OpenAILLM + + return OpenAILLM( + model=os.getenv("OPEN_AI_LLM_MODEL", "gpt-4o"), + api_key=os.getenv("OPEN_AI_KEY"), + ) + + if provider == "anthropic": + from .integrations.llm.anthropic_ import AnthropicLLM + + return AnthropicLLM( + model=os.getenv("ANTHROPIC_LLM_MODEL", "claude-sonnet-4-6"), + api_key=os.getenv("ANTHROPIC_API_KEY"), + ) + + if provider == "azure": + from .integrations.llm.azure_ import AzureOpenAILLM + + return AzureOpenAILLM( + azure_deployment=os.environ["AZURE_OPENAI_LLM_MODEL"], + azure_endpoint=os.environ["AZURE_OPENAI_LLM_ENDPOINT"], + api_version=os.getenv("AZURE_OPENAI_LLM_API_VERSION", "2023-07-01-preview"), + api_key=os.getenv("AZURE_OPENAI_LLM_KEY"), + ) + + if provider == "gemini": + from .integrations.llm.gemini_ import GeminiLLM + + return GeminiLLM( + model=os.getenv("GEMINI_LLM_MODEL", "gemini-2.0-flash-lite"), + api_key=os.getenv("GEMINI_API_KEY"), + ) + + if provider == "bedrock": + from .integrations.llm.bedrock_ import BedrockLLM + + return BedrockLLM( + model=os.environ["AWS_BEDROCK_LLM_MODEL"], + aws_access_key_id=os.getenv("AWS_BEDROCK_LLM_ACCESS_KEY_ID"), + aws_secret_access_key=os.getenv("AWS_BEDROCK_LLM_SECRET_ACCESS_KEY"), + region_name=os.getenv("AWS_BEDROCK_LLM_REGION", "us-east-1"), + ) + + if provider == "ollama": + from .integrations.llm.ollama_ import OllamaLLM + + return OllamaLLM( + model=os.environ["OLLAMA_LLM_MODEL"], + base_url=os.getenv("OLLAMA_LLM_BASE_URL", "http://localhost:11434"), + ) + + if provider == "huggingface": + from .integrations.llm.huggingface_ import HuggingFaceLLM + + return HuggingFaceLLM( + repo_id=os.getenv("HUGGING_FACE_LLM_REPO_ID"), + endpoint_url=os.getenv("HUGGING_FACE_LLM_ENDPOINT"), + api_token=os.getenv("HUGGING_FACE_LLM_API_TOKEN"), + ) + + raise ValueError( + f"Unknown LLM_PROVIDER: {provider!r}. " + "Valid values: openai, anthropic, azure, gemini, bedrock, ollama, huggingface" + ) + + +def build_embedding_from_env() -> EmbeddingPort: + """환경변수 EMBEDDING_PROVIDER에 따라 EmbeddingPort 인스턴스를 생성한다.""" + provider = os.getenv("EMBEDDING_PROVIDER", "openai").lower().strip("'\"") + + if provider == "openai": + from .integrations.embedding.openai_ import OpenAIEmbedding + + return OpenAIEmbedding( + model=os.getenv("OPEN_AI_EMBEDDING_MODEL", "text-embedding-3-small"), + api_key=os.getenv("OPEN_AI_KEY"), + ) + + if provider == "azure": + from .integrations.embedding.azure_ import AzureOpenAIEmbedding + + return AzureOpenAIEmbedding( + azure_deployment=os.environ["AZURE_OPENAI_EMBEDDING_MODEL"], + azure_endpoint=os.environ["AZURE_OPENAI_EMBEDDING_ENDPOINT"], + api_version=os.getenv("AZURE_OPENAI_EMBEDDING_API_VERSION", "2023-09-15-preview"), + api_key=os.getenv("AZURE_OPENAI_EMBEDDING_KEY"), + ) + + if provider == "ollama": + from .integrations.embedding.ollama_ import OllamaEmbedding + + return OllamaEmbedding( + model=os.getenv("EMBEDDING_MODEL", os.getenv("OLLAMA_EMBEDDING_MODEL", "nomic-embed-text")), + base_url=os.getenv("EMBEDDING_BASE_PATH", os.getenv("OLLAMA_EMBEDDING_BASE_URL", "http://localhost:11434")), + ) + + if provider == "bedrock": + from .integrations.embedding.bedrock_ import BedrockEmbedding + + return BedrockEmbedding( + model_id=os.getenv("AWS_BEDROCK_EMBEDDING_MODEL", "amazon.titan-embed-text-v2:0"), + aws_access_key_id=os.getenv("AWS_BEDROCK_EMBEDDING_ACCESS_KEY_ID"), + aws_secret_access_key=os.getenv("AWS_BEDROCK_EMBEDDING_SECRET_ACCESS_KEY"), + region_name=os.getenv("AWS_BEDROCK_EMBEDDING_REGION", "us-east-1"), + ) + + if provider == "gemini": + from .integrations.embedding.gemini_ import GeminiEmbedding + + return GeminiEmbedding( + model=os.getenv("EMBEDDING_MODEL", "models/embedding-001"), + api_key=os.getenv("GEMINI_EMBEDDING_API_KEY"), + ) + + if provider == "huggingface": + from .integrations.embedding.huggingface_ import HuggingFaceEmbedding + + return HuggingFaceEmbedding( + model=os.getenv( + "HUGGING_FACE_EMBEDDING_MODEL", + os.getenv("HUGGING_FACE_EMBEDDING_REPO_ID", ""), + ) + ) + + raise ValueError( + f"Unknown EMBEDDING_PROVIDER: {provider!r}. " + "Valid values: openai, azure, ollama, bedrock, gemini, huggingface" + ) + + +def build_db_from_env(database_env: str = "") -> DBPort: + """환경변수에서 DB URL을 구성하고 SQLAlchemyDB를 반환한다. + + DB_TYPE 환경변수에 따라 적절한 SQLAlchemy 연결 URL을 구성한다. + """ + from .integrations.db.sqlalchemy_ import SQLAlchemyDB + + db_type = os.getenv("DB_TYPE", "sqlite").lower() + url = _build_db_url(db_type) + return SQLAlchemyDB(url) + + +def _build_db_url(db_type: str) -> str: + if db_type == "sqlite": + path = os.getenv("SQLITE_PATH", "./data/sqlite.db") + return f"sqlite:///{path}" + + if db_type == "postgresql": + host = os.getenv("POSTGRESQL_HOST", "localhost") + port = os.getenv("POSTGRESQL_PORT", "5432") + user = os.getenv("POSTGRESQL_USER", "postgres") + password = os.getenv("POSTGRESQL_PASSWORD", "") + database = os.getenv("POSTGRESQL_DATABASE", "postgres") + return f"postgresql://{user}:{password}@{host}:{port}/{database}" + + if db_type == "mysql": + host = os.getenv("MYSQL_HOST", "localhost") + port = os.getenv("MYSQL_PORT", "3306") + user = os.getenv("MYSQL_USER", "root") + password = os.getenv("MYSQL_PASSWORD", "") + database = os.getenv("MYSQL_DATABASE", "") + return f"mysql+pymysql://{user}:{password}@{host}:{port}/{database}" + + if db_type == "mariadb": + host = os.getenv("MARIADB_HOST", "localhost") + port = os.getenv("MARIADB_PORT", "3306") + user = os.getenv("MARIADB_USER", "root") + password = os.getenv("MARIADB_PASSWORD", "") + database = os.getenv("MARIADB_DATABASE", "") + return f"mariadb+pymysql://{user}:{password}@{host}:{port}/{database}" + + if db_type == "duckdb": + path = os.getenv("DUCKDB_PATH", "./data/duckdb.db") + return f"duckdb:///{path}" + + if db_type == "clickhouse": + host = os.getenv("CLICKHOUSE_HOST", "localhost") + port = os.getenv("CLICKHOUSE_PORT", "9001") + user = os.getenv("CLICKHOUSE_USER", "default") + password = os.getenv("CLICKHOUSE_PASSWORD", "") + database = os.getenv("CLICKHOUSE_DATABASE", "default") + return f"clickhouse+native://{user}:{password}@{host}:{port}/{database}" + + if db_type == "snowflake": + user = os.environ["SNOWFLAKE_USER"] + password = os.environ["SNOWFLAKE_PASSWORD"] + account = os.environ["SNOWFLAKE_ACCOUNT"] + return f"snowflake://{user}:{password}@{account}" + + if db_type == "oracle": + host = os.getenv("ORACLE_HOST", "localhost") + port = os.getenv("ORACLE_PORT", "1521") + user = os.getenv("ORACLE_USER", "") + password = os.getenv("ORACLE_PASSWORD", "") + service = os.getenv("ORACLE_SERVICE_NAME", os.getenv("ORACLE_DATABASE", "")) + return f"oracle+cx_oracle://{user}:{password}@{host}:{port}/?service_name={service}" + + raise ValueError( + f"Unknown DB_TYPE: {db_type!r}. " + "Valid values: sqlite, postgresql, mysql, mariadb, duckdb, clickhouse, snowflake, oracle" + ) From 8c4db4668ab841d5e516ae4a54dcf370aceadeab Mon Sep 17 00:00:00 2001 From: seyeong Date: Sat, 28 Feb 2026 15:32:38 +0900 Subject: [PATCH 07/10] feat(interface): migrate Streamlit UI to v2 flows --- interface/app_pages/lang2sql.py | 140 ++++++++--------------------- interface/core/lang2sql_runner.py | 82 +++++++++++------ interface/core/provider_factory.py | 140 +++++++++++++++++++++++++++++ interface/core/session_utils.py | 34 ++----- interface/pages_config.py | 1 - 5 files changed, 236 insertions(+), 161 deletions(-) create mode 100644 interface/core/provider_factory.py diff --git a/interface/app_pages/lang2sql.py b/interface/app_pages/lang2sql.py index 22f49aa..11f5830 100644 --- a/interface/app_pages/lang2sql.py +++ b/interface/app_pages/lang2sql.py @@ -2,45 +2,30 @@ Lang2SQL Streamlit 애플리케이션. 자연어 질의를 SQL 쿼리로 변환하고 실행 결과를 시각화하는 인터페이스를 제공합니다. -사용자는 데이터베이스 다이얼렉트 선택 및 편집, 검색기(retriever) 방식 지정, 토큰 사용량/결과 설명/시각화 등 다양한 출력 옵션을 설정할 수 있습니다. 주요 기능: - 사용자 질의를 SQL 쿼리로 변환 후 실행 - - DB 다이얼렉트(PRESET_DIALECTS) 선택 및 편집 지원 - - 검색기 유형 및 Top-N 테이블 검색 개수 설정 - - 쿼리 실행 결과를 표와 차트로 시각화 - - 토큰 사용량, 문서 적합성 평가, AI 재해석 질의 등 추가 정보 표시 + - SQL 방언(dialect) 선택 지원 + - 쿼리 실행 결과를 표로 시각화 + - Baseline / Enriched 워크플로우 선택 """ -from copy import deepcopy - +import pandas as pd import streamlit as st -from interface.core.dialects import PRESET_DIALECTS, DialectOption -from interface.core.lang2sql_runner import run_lang2sql -from interface.core.result_renderer import display_result -from interface.core.session_utils import init_graph from interface.core.config import load_config +from interface.core.lang2sql_runner import run_lang2sql from interface.app_pages.sidebar_components import ( render_sidebar_data_source_selector, - render_sidebar_llm_selector, - render_sidebar_embedding_selector, render_sidebar_db_selector, + render_sidebar_embedding_selector, + render_sidebar_llm_selector, ) TITLE = "Lang2SQL" DEFAULT_QUERY = "고객 데이터를 기반으로 유니크한 유저 수를 카운트하는 쿼리" -SIDEBAR_OPTIONS = { - "show_token_usage": "Show Token Usage", - "show_result_description": "Show Result Description", - "show_sql": "Show SQL", - "show_question_reinterpreted_by_ai": "Show User Question Reinterpreted by AI", - "show_referenced_tables": "Show List of Referenced Tables", - "show_question_gate_result": "Show Question Gate Result", - "show_document_suitability": "Show Document Suitability", - "show_table": "Show Table", - "show_chart": "Show Chart", -} + +DIALECT_OPTIONS = ["default", "sqlite", "postgresql", "mysql", "bigquery", "duckdb"] st.title(TITLE) @@ -55,92 +40,39 @@ render_sidebar_db_selector() st.sidebar.divider() -st.sidebar.title("Output Settings") -for key, label in SIDEBAR_OPTIONS.items(): - st.sidebar.checkbox(label, value=True, key=key) - st.sidebar.markdown("### 워크플로우 선택") use_enriched = st.sidebar.checkbox( - "프로파일 추출 & 컨텍스트 보강 워크플로우 사용", value=False + "프로파일 추출 & 컨텍스트 보강 워크플로우 사용 (Enriched)", value=False ) -if ( - "graph" not in st.session_state - or st.session_state.get("use_enriched") != use_enriched -): - GRAPH_TYPE = init_graph(use_enriched) - st.info(f"Lang2SQL 시작됨. ({GRAPH_TYPE} 워크플로우)") - -if st.sidebar.button("Lang2SQL 새로고침"): - GRAPH_TYPE = init_graph(st.session_state.get("use_enriched", False)) - st.sidebar.success( - f"Lang2SQL이 성공적으로 새로고침되었습니다. ({GRAPH_TYPE} 워크플로우)" - ) - -## moved to component: render_sidebar_llm_selector() - +# 쿼리 입력 user_query = st.text_area("쿼리를 입력하세요:", value=DEFAULT_QUERY) -if "dialects" not in st.session_state: - st.session_state["dialects"] = {k: v.to_dict() for k, v in PRESET_DIALECTS.items()} - -st.markdown("### DB 선택 및 관리") -cols = st.columns(2) -dialects = st.session_state["dialects"] -keys = list(dialects.keys()) -active = st.session_state.get("active_dialect", keys[0]) - -with cols[0]: - user_database_env = st.selectbox( - "사용할 DB를 선택하세요:", options=keys, index=keys.index(active) +# 설정 +col1, col2 = st.columns(2) +with col1: + user_dialect = st.selectbox( + "SQL 방언(Dialect):", options=DIALECT_OPTIONS, index=0 ) - st.session_state["active_dialect"] = user_database_env - st.session_state["selected_dialect_option"] = dialects[user_database_env] - -with cols[1]: - st.caption("선택된 DB 설정을 편집하거나 새로 추가할 수 있습니다.") - -with st.expander("DB 편집"): - edit_key = st.selectbox( - "편집할 DB를 선택하세요:", - options=keys, - index=keys.index(active), - key="dialect_edit_selector", - ) - current = deepcopy(dialects[edit_key]) - _supports_ilike = st.checkbox( - "ILIKE 지원", value=bool(current.get("supports_ilike", False)) - ) - _hints_text = st.text_area( - "hints (쉼표로 구분)", value=", ".join(current.get("hints", [])) - ) - if st.button("변경사항 저장", key="btn_save_dialect_edit"): - st.session_state["dialects"][edit_key] = DialectOption( - name=edit_key, - supports_ilike=_supports_ilike, - hints=[s.strip() for s in _hints_text.split(",") if s.strip()], - ).to_dict() - st.success(f"{edit_key} DB가 업데이트되었습니다.") - -device = st.selectbox("모델 실행 장치", options=["cpu", "cuda"], index=0) -retriever_options = { - "기본": "벡터 검색 (기본)", - "Reranker": "Reranker 검색 (정확도 향상)", -} -user_retriever = st.selectbox( - "검색기 유형을 선택하세요:", - options=list(retriever_options.keys()), - format_func=lambda x: retriever_options[x], -) -user_top_n = st.slider("검색할 테이블 정보 개수:", min_value=1, max_value=20, value=5) +with col2: + user_top_n = st.slider("검색할 테이블 정보 개수:", min_value=1, max_value=20, value=5) if st.button("쿼리 실행"): - res = run_lang2sql( - query=user_query, - database_env=user_database_env, - retriever_name=user_retriever, - top_n=user_top_n, - device=device, - use_enriched=use_enriched, - ) - display_result(res=res) + with st.spinner("쿼리 실행 중..."): + res = run_lang2sql( + query=user_query, + db_dialect=user_dialect if user_dialect != "default" else None, + top_n=user_top_n, + use_enriched=use_enriched, + ) + + if res.get("error"): + st.error(f"오류 발생: {res['error']}") + else: + rows = res.get("rows", []) + if rows: + st.success(f"{len(rows)}개 행 반환됨.") + st.markdown("**쿼리 실행 결과:**") + st.dataframe(pd.DataFrame(rows)) + else: + st.info("쿼리 실행 결과가 없습니다.") diff --git a/interface/core/lang2sql_runner.py b/interface/core/lang2sql_runner.py index f37bd4a..503f230 100644 --- a/interface/core/lang2sql_runner.py +++ b/interface/core/lang2sql_runner.py @@ -3,43 +3,69 @@ 이 모듈은 자연어로 입력된 질문을 SQL 쿼리로 변환하고, 지정된 데이터베이스 환경에서 실행하는 함수(`run_lang2sql`)를 제공합니다. -내부적으로 `engine.query_executor.execute_query`를 호출하여 -Lang2SQL 전체 파이프라인을 간단히 실행할 수 있도록 합니다. +내부적으로 v2 플로우(BaselineNL2SQL / EnrichedNL2SQL)를 사용한다. """ +from __future__ import annotations -from engine.query_executor import execute_query as execute_query_common +from typing import Any def run_lang2sql( - query, - database_env, - retriever_name, - top_n, - device, - use_enriched, -): - """ - Lang2SQL 실행 함수. + query: str, + db_dialect: str | None = None, + top_n: int = 5, + use_enriched: bool = False, + catalog: list | None = None, +) -> dict[str, Any]: + """Lang2SQL 실행 함수. - 주어진 자연어 질문을 SQL 쿼리로 변환하고 지정된 데이터베이스 환경에서 실행합니다. - 내부적으로 `engine.query_executor.execute_query`를 호출합니다. + 주어진 자연어 질문을 SQL 쿼리로 변환하고 데이터베이스에서 실행한다. + LLM, Embedding, DB는 환경변수(LLM_PROVIDER, EMBEDDING_PROVIDER, DB_TYPE 등)로 + 자동 설정된다. Args: - query (str): 사용자 입력 자연어 질문. - database_env (str): 사용할 데이터베이스 환경 이름. - retriever_name (str): 검색기(retriever) 유형 이름. - top_n (int): 검색할 테이블 정보 개수. - device (str): 모델 실행 장치 ("cpu" 또는 "cuda"). + query: 사용자 입력 자연어 질문. + db_dialect: SQL 방언 힌트 (None이면 default 프롬프트 사용). + top_n: 검색할 상위 테이블 수. + use_enriched: True이면 EnrichedNL2SQL 플로우 사용. + catalog: CatalogEntry 목록. None이면 빈 카탈로그로 실행. Returns: - dict: Lang2SQL 실행 결과를 담은 딕셔너리. + dict: {"rows": list[dict], "sql": str, "error": str | None} """ - - return execute_query_common( - query=query, - database_env=database_env, - retriever_name=retriever_name, - top_n=top_n, - device=device, - use_enriched_graph=use_enriched, + from lang2sql.factory import ( + build_db_from_env, + build_embedding_from_env, + build_llm_from_env, ) + from lang2sql.flows import BaselineNL2SQL, EnrichedNL2SQL + + catalog = catalog or [] + + try: + llm = build_llm_from_env() + db = build_db_from_env() + + if use_enriched: + embedding = build_embedding_from_env() + pipeline = EnrichedNL2SQL( + catalog=catalog, + llm=llm, + db=db, + embedding=embedding, + db_dialect=db_dialect, + top_n=top_n, + ) + else: + pipeline = BaselineNL2SQL( + catalog=catalog, + llm=llm, + db=db, + db_dialect=db_dialect, + ) + + rows = pipeline.run(query) + return {"rows": rows, "error": None} + + except Exception as exc: + return {"rows": [], "error": str(exc)} diff --git a/interface/core/provider_factory.py b/interface/core/provider_factory.py new file mode 100644 index 0000000..0c616f8 --- /dev/null +++ b/interface/core/provider_factory.py @@ -0,0 +1,140 @@ +"""Settings UI에서 선택된 프로파일을 LLMPort/EmbeddingPort로 변환하는 팩토리. + +LLMProfile과 EmbeddingProfile(config/models.py)을 받아 +lang2sql.integrations의 구현체를 반환한다. +""" +from __future__ import annotations + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from lang2sql.core.ports import EmbeddingPort, LLMPort + from interface.core.config.models import EmbeddingProfile, LLMProfile + + +def build_llm(profile: "LLMProfile") -> "LLMPort": + """LLMProfile → LLMPort 변환. Settings UI에서 호출한다.""" + f = profile.fields + provider = profile.provider.lower() + + if provider == "openai": + from lang2sql.integrations.llm.openai_ import OpenAILLM + + return OpenAILLM( + model=f.get("model", "gpt-4o"), + api_key=f.get("api_key"), + ) + + if provider == "anthropic": + from lang2sql.integrations.llm.anthropic_ import AnthropicLLM + + return AnthropicLLM( + model=f.get("model", "claude-sonnet-4-6"), + api_key=f.get("api_key"), + ) + + if provider == "azure": + from lang2sql.integrations.llm.azure_ import AzureOpenAILLM + + return AzureOpenAILLM( + azure_deployment=f["azure_deployment"], + azure_endpoint=f["azure_endpoint"], + api_version=f.get("api_version", "2023-07-01-preview"), + api_key=f.get("api_key"), + ) + + if provider == "gemini": + from lang2sql.integrations.llm.gemini_ import GeminiLLM + + return GeminiLLM( + model=f.get("model", "gemini-2.0-flash-lite"), + api_key=f.get("api_key"), + ) + + if provider == "bedrock": + from lang2sql.integrations.llm.bedrock_ import BedrockLLM + + return BedrockLLM( + model=f["model"], + aws_access_key_id=f.get("aws_access_key_id"), + aws_secret_access_key=f.get("aws_secret_access_key"), + region_name=f.get("region_name", "us-east-1"), + ) + + if provider == "ollama": + from lang2sql.integrations.llm.ollama_ import OllamaLLM + + return OllamaLLM( + model=f["model"], + base_url=f.get("base_url", "http://localhost:11434"), + ) + + if provider == "huggingface": + from lang2sql.integrations.llm.huggingface_ import HuggingFaceLLM + + return HuggingFaceLLM( + repo_id=f.get("repo_id"), + endpoint_url=f.get("endpoint_url"), + api_token=f.get("api_token"), + ) + + raise ValueError(f"Unknown LLM provider: {provider!r}") + + +def build_embedding(profile: "EmbeddingProfile") -> "EmbeddingPort": + """EmbeddingProfile → EmbeddingPort 변환. Settings UI에서 호출한다.""" + f = profile.fields + provider = profile.provider.lower() + + if provider == "openai": + from lang2sql.integrations.embedding.openai_ import OpenAIEmbedding + + return OpenAIEmbedding( + model=f.get("model", "text-embedding-3-small"), + api_key=f.get("api_key"), + ) + + if provider == "azure": + from lang2sql.integrations.embedding.azure_ import AzureOpenAIEmbedding + + return AzureOpenAIEmbedding( + azure_deployment=f["azure_deployment"], + azure_endpoint=f["azure_endpoint"], + api_version=f.get("api_version", "2023-09-15-preview"), + api_key=f.get("api_key"), + ) + + if provider == "ollama": + from lang2sql.integrations.embedding.ollama_ import OllamaEmbedding + + return OllamaEmbedding( + model=f.get("model", "nomic-embed-text"), + base_url=f.get("base_url", "http://localhost:11434"), + ) + + if provider == "bedrock": + from lang2sql.integrations.embedding.bedrock_ import BedrockEmbedding + + return BedrockEmbedding( + model_id=f.get("model_id", "amazon.titan-embed-text-v2:0"), + aws_access_key_id=f.get("aws_access_key_id"), + aws_secret_access_key=f.get("aws_secret_access_key"), + region_name=f.get("region_name", "us-east-1"), + ) + + if provider == "gemini": + from lang2sql.integrations.embedding.gemini_ import GeminiEmbedding + + return GeminiEmbedding( + model=f.get("model", "models/embedding-001"), + api_key=f.get("api_key"), + ) + + if provider == "huggingface": + from lang2sql.integrations.embedding.huggingface_ import HuggingFaceEmbedding + + return HuggingFaceEmbedding( + model=f.get("model", f.get("repo_id", "")), + ) + + raise ValueError(f"Unknown embedding provider: {provider!r}") diff --git a/interface/core/session_utils.py b/interface/core/session_utils.py index 5071fbd..fbabb2b 100644 --- a/interface/core/session_utils.py +++ b/interface/core/session_utils.py @@ -1,38 +1,16 @@ -""" -Streamlit 세션 상태에서 그래프 빌더를 초기화하는 모듈. +"""Streamlit 세션 상태 유틸리티 모듈.""" -이 모듈은 Lang2SQL 애플리케이션의 그래프 실행 파이프라인을 준비하기 위해 -기본 또는 확장(enriched) 그래프 빌더를 선택적으로 로드하고, -세션 상태에 초기화된 그래프 객체를 저장합니다. -Functions: - init_graph(use_enriched: bool) -> str: - 그래프 빌더를 초기화하고 세션 상태를 갱신합니다. -""" - -import streamlit as st - - -def init_graph(use_enriched: bool) -> str: - """그래프 빌더를 초기화하고 세션 상태를 갱신합니다. +def init_pipeline(use_enriched: bool) -> str: + """파이프라인 타입을 세션 상태에 기록한다. Args: - use_enriched (bool): 확장(enriched) 그래프 빌더를 사용할지 여부. + use_enriched: True이면 EnrichedNL2SQL, False이면 BaselineNL2SQL. Returns: - str: 초기화된 그래프 유형. "확장된" 또는 "기본". + str: "확장된" 또는 "기본". """ + import streamlit as st - builder_module = ( - "utils.llm.graph_utils.enriched_graph" - if use_enriched - else "utils.llm.graph_utils.basic_graph" - ) - - builder = __import__(builder_module, fromlist=["builder"]).builder - - st.session_state.setdefault("graph", builder.compile()) - st.session_state["graph"] = builder.compile() st.session_state["use_enriched"] = use_enriched - return "확장된" if use_enriched else "기본" diff --git a/interface/pages_config.py b/interface/pages_config.py index 8963f03..362f855 100644 --- a/interface/pages_config.py +++ b/interface/pages_config.py @@ -17,7 +17,6 @@ PAGES = [ st.Page("app_pages/home.py", title="🏠 홈"), st.Page("app_pages/lang2sql.py", title="🔍 Lang2SQL"), - st.Page("app_pages/graph_builder.py", title="📊 그래프 빌더"), st.Page("app_pages/chatbot.py", title="🤖 ChatBot"), st.Page("app_pages/settings.py", title="⚙️ 설정"), ] From f5e1774151e83c32f35ac22b95245cc1d4aa5d7f Mon Sep 17 00:00:00 2001 From: seyeong Date: Sat, 28 Feb 2026 15:32:55 +0900 Subject: [PATCH 08/10] chore(legacy): remove replaced modules and fix broken imports --- engine/README.md | 215 --------------- engine/__init__.py | 1 - engine/query_executor.py | 136 --------- interface/app_pages/graph_builder.py | 239 ---------------- prompt/document_suitability_prompt.md | 47 ---- prompt/profile_extraction_prompt.md | 19 -- prompt/query_enrichment_prompt.md | 22 -- prompt/query_maker_prompt.md | 47 ---- prompt/question_gate_prompt.md | 19 -- utils/llm/chains.py | 145 ---------- utils/llm/core/__init__.py | 34 +-- utils/llm/core/factory.py | 181 ------------ utils/llm/graph_utils/README.md | 185 ------------- utils/llm/graph_utils/__init__.py | 37 --- utils/llm/graph_utils/base.py | 258 ------------------ utils/llm/graph_utils/basic_graph.py | 49 ---- utils/llm/graph_utils/enriched_graph.py | 57 ---- utils/llm/graph_utils/profile_utils.py | 17 -- utils/llm/output_schema/README.md | 114 -------- .../llm/output_schema/document_suitability.py | 36 --- .../llm/output_schema/question_suitability.py | 23 -- utils/llm/retrieval.py | 104 ------- utils/llm/tools/chatbot_tool.py | 23 +- utils/llm/vectordb/README.md | 223 --------------- utils/llm/vectordb/__init__.py | 7 - utils/llm/vectordb/factory.py | 38 --- utils/llm/vectordb/faiss_db.py | 33 --- utils/llm/vectordb/pgvector_db.py | 81 ------ 28 files changed, 21 insertions(+), 2369 deletions(-) delete mode 100644 engine/README.md delete mode 100644 engine/__init__.py delete mode 100644 engine/query_executor.py delete mode 100644 interface/app_pages/graph_builder.py delete mode 100644 prompt/document_suitability_prompt.md delete mode 100644 prompt/profile_extraction_prompt.md delete mode 100644 prompt/query_enrichment_prompt.md delete mode 100644 prompt/query_maker_prompt.md delete mode 100644 prompt/question_gate_prompt.md delete mode 100644 utils/llm/chains.py delete mode 100644 utils/llm/core/factory.py delete mode 100644 utils/llm/graph_utils/README.md delete mode 100644 utils/llm/graph_utils/__init__.py delete mode 100644 utils/llm/graph_utils/base.py delete mode 100644 utils/llm/graph_utils/basic_graph.py delete mode 100644 utils/llm/graph_utils/enriched_graph.py delete mode 100644 utils/llm/graph_utils/profile_utils.py delete mode 100644 utils/llm/output_schema/README.md delete mode 100644 utils/llm/output_schema/document_suitability.py delete mode 100644 utils/llm/output_schema/question_suitability.py delete mode 100644 utils/llm/retrieval.py delete mode 100644 utils/llm/vectordb/README.md delete mode 100644 utils/llm/vectordb/__init__.py delete mode 100644 utils/llm/vectordb/factory.py delete mode 100644 utils/llm/vectordb/faiss_db.py delete mode 100644 utils/llm/vectordb/pgvector_db.py diff --git a/engine/README.md b/engine/README.md deleted file mode 100644 index b8b4445..0000000 --- a/engine/README.md +++ /dev/null @@ -1,215 +0,0 @@ -# engine 모듈 - -Lang2SQL 쿼리 실행을 위한 공용 모듈입니다. - -이 모듈은 CLI와 Streamlit 인터페이스에서 공통으로 사용할 수 있는 쿼리 실행 함수를 제공합니다. - -## 디렉토리 구조 - -``` -engine/ -├── __init__.py # 패키지 초기화 모듈 -├── query_executor.py # 쿼리 실행 공용 함수 -└── README.md # 이 파일 -``` - -## 모듈 개요 - -### `__init__.py` -- **위치**: `engine/__init__.py` -- **설명**: Lang2SQL Data Processing 진입점 패키지 -- **내용**: 패키지 초기화 모듈 - -### `query_executor.py` -- **위치**: `engine/query_executor.py` -- **설명**: Lang2SQL 쿼리 실행을 위한 공용 모듈 -- **주요 기능**: - - `execute_query()`: 자연어 쿼리를 SQL로 변환하고 실행 결과를 반환 - - `extract_sql_from_result()`: Lang2SQL 실행 결과에서 SQL 쿼리 추출 - -#### 주요 함수 - -##### `execute_query()` -자연어 쿼리를 SQL로 변환하고 실행 결과를 반환하는 공용 함수입니다. - -**매개변수:** -- `query` (str): 사용자가 입력한 자연어 기반 질문 -- `database_env` (str): 사용할 데이터베이스 환경 이름 또는 키 (예: "dev", "prod") -- `retriever_name` (str, optional): 테이블 검색기 이름. 기본값은 "기본" -- `top_n` (int, optional): 검색된 상위 테이블 수 제한. 기본값은 5 -- `device` (str, optional): LLM 실행에 사용할 디바이스 ("cpu" 또는 "cuda"). 기본값은 "cpu" -- `use_enriched_graph` (bool, optional): 확장된 그래프 사용 여부. 기본값은 False -- `session_state` (Optional[Union[Dict[str, Any], Any]], optional): Streamlit 세션 상태 (Streamlit에서만 사용) - -**반환값:** -- `Dict[str, Any]`: 다음 정보를 포함한 Lang2SQL 실행 결과 딕셔너리: - - `"generated_query"`: 생성된 SQL 쿼리 (`AIMessage`) - - `"messages"`: 전체 LLM 응답 메시지 목록 - - `"searched_tables"`: 참조된 테이블 목록 등 추가 정보 - -**동작 방식:** -1. 사용자가 지정한 옵션에 따라 기본 그래프 또는 확장 그래프를 선택 -2. Streamlit 환경에서는 세션 상태에서 그래프 재사용, CLI 환경에서는 매번 새로운 그래프 컴파일 -3. 선택된 그래프를 컴파일하고 invoke하여 결과 반환 - -**사용 예제:** -```python -from engine.query_executor import execute_query - -# CLI 환경에서 사용 -result = execute_query( - query="고객 데이터를 기반으로 유니크한 유저 수를 카운트하는 쿼리", - database_env="clickhouse", - retriever_name="기본", - top_n=5, - device="cpu", - use_enriched_graph=False -) - -# Streamlit 환경에서 사용 -result = execute_query( - query="고객 데이터를 기반으로 유니크한 유저 수를 카운트하는 쿼리", - database_env="clickhouse", - retriever_name="기본", - top_n=5, - device="cpu", - use_enriched_graph=False, - session_state=st.session_state # Streamlit 세션 상태 전달 -) -``` - -##### `extract_sql_from_result()` -Lang2SQL 실행 결과에서 SQL 쿼리를 추출합니다. - -**매개변수:** -- `res` (Dict[str, Any]): `execute_query()` 함수의 반환 결과 - -**반환값:** -- `Optional[str]`: 추출된 SQL 쿼리 문자열. 추출 실패 시 None - -**동작 방식:** -1. `generated_query` 필드에서 쿼리 메시지 추출 -2. `LLMResponseParser.extract_sql()`을 사용하여 SQL 쿼리 문자열 추출 -3. 추출 실패 시 None 반환 - -**사용 예제:** -```python -from engine.query_executor import execute_query, extract_sql_from_result - -result = execute_query( - query="고객 데이터를 기반으로 유니크한 유저 수를 카운트하는 쿼리", - database_env="clickhouse" -) - -sql = extract_sql_from_result(result) -if sql: - print(sql) -``` - -## 의존성 - -### 내부 모듈 -- `utils.llm.graph_utils.basic_graph.builder`: 기본 그래프 빌더 -- `utils.llm.graph_utils.enriched_graph.builder`: 확장 그래프 빌더 -- `utils.llm.llm_response_parser.LLMResponseParser`: LLM 응답 파서 - -### 외부 라이브러리 -- `langchain_core.messages.HumanMessage`: LangChain 메시지 클래스 - -## 사용 위치 - -### 1. CLI 명령어 (`cli/commands/quary.py`) -CLI 환경에서 `query` 명령어 실행 시 사용됩니다. - -```python -from engine.query_executor import execute_query, extract_sql_from_result - -# CLI 명령어에서 사용 -res = execute_query( - query=question, - database_env=database_env, - retriever_name=retriever_name, - top_n=top_n, - device=device, - use_enriched_graph=use_enriched_graph, -) - -sql = extract_sql_from_result(res) -``` - -### 2. Streamlit 인터페이스 (`interface/core/lang2sql_runner.py`) -Streamlit 인터페이스에서 Lang2SQL 실행을 위해 사용됩니다. - -```python -from engine.query_executor import execute_query as execute_query_common - -# Streamlit 러너에서 사용 -def run_lang2sql(query, database_env, retriever_name, top_n, device): - return execute_query_common( - query=query, - database_env=database_env, - retriever_name=retriever_name, - top_n=top_n, - device=device, - ) -``` - -### 3. Streamlit 메인 페이지 (`interface/app_pages/lang2sql.py`) -Streamlit 메인 페이지에서 `lang2sql_runner.run_lang2sql()`을 호출하여 사용됩니다. - -```python -from interface.core.lang2sql_runner import run_lang2sql - -# 메인 페이지에서 사용 -if st.button("쿼리 실행"): - res = run_lang2sql( - query=user_query, - database_env=user_database_env, - retriever_name=user_retriever, - top_n=user_top_n, - device=device, - ) - display_result(res=res) -``` - -## 워크플로우 - -### 기본 워크플로우 -1. 사용자가 자연어 질문 입력 -2. `execute_query()` 호출 -3. 기본 그래프 빌더 선택 및 컴파일 -4. 그래프 실행하여 SQL 쿼리 생성 -5. 결과 딕셔너리 반환 - -### 확장 워크플로우 (프로파일 추출 + 컨텍스트 보강) -1. 사용자가 자연어 질문 입력 -2. `execute_query(use_enriched_graph=True)` 호출 -3. 확장 그래프 빌더 선택 및 컴파일 -4. 그래프 실행하여 SQL 쿼리 생성 -5. 결과 딕셔너리 반환 - -## 환경별 동작 - -### CLI 환경 -- 세션 상태 없이 매번 새로운 그래프 컴파일 -- 중간 결과 저장/재사용 불가 - -### Streamlit 환경 -- 세션 상태를 통해 그래프 재사용 가능 -- 중간 결과 저장/재사용 가능 -- 다이얼렉트 정보 주입 지원 - -## 로깅 - -이 모듈은 `logging` 모듈을 사용하여 로그를 기록합니다: -- 처리 중인 쿼리 로그 -- 사용 중인 그래프 유형 로그 -- SQL 추출 실패 시 에러 로그 - -## 주의사항 - -1. `session_state` 파라미터는 Streamlit 환경에서만 유효합니다 -2. `use_enriched_graph=True`로 설정하면 더 많은 리소스가 소모될 수 있습니다 -3. `database_env`는 유효한 데이터베이스 환경 이름이어야 합니다 -4. 그래프 컴파일은 처음 실행 시 시간이 걸릴 수 있습니다 - diff --git a/engine/__init__.py b/engine/__init__.py deleted file mode 100644 index 1c8cdec..0000000 --- a/engine/__init__.py +++ /dev/null @@ -1 +0,0 @@ -"""Lang2SQL Data Processing 진입점 패키지""" diff --git a/engine/query_executor.py b/engine/query_executor.py deleted file mode 100644 index 01f5b7b..0000000 --- a/engine/query_executor.py +++ /dev/null @@ -1,136 +0,0 @@ -""" -Lang2SQL 쿼리 실행을 위한 공용 모듈입니다. - -이 모듈은 CLI와 Streamlit 인터페이스에서 공통으로 사용할 수 있는 -쿼리 실행 함수를 제공합니다. -""" - -import logging -from typing import Any, Dict, Optional, Union - -from langchain_core.messages import HumanMessage - -from utils.llm.graph_utils.basic_graph import builder as basic_builder -from utils.llm.graph_utils.enriched_graph import builder as enriched_builder -from utils.llm.llm_response_parser import LLMResponseParser - -logger = logging.getLogger(__name__) - - -def execute_query( - *, - query: str, - database_env: str, - retriever_name: str = "기본", - top_n: int = 5, - device: str = "cpu", - use_enriched_graph: bool = False, - session_state: Optional[Union[Dict[str, Any], Any]] = None, -) -> Dict[str, Any]: - """ - 자연어 쿼리를 SQL로 변환하고 실행 결과를 반환하는 공용 함수입니다. - - 이 함수는 Lang2SQL 파이프라인(graph)을 사용하여 사용자의 자연어 질문을 - SQL 쿼리로 변환하고 관련 메타데이터와 함께 결과를 반환합니다. - CLI와 Streamlit 인터페이스에서 공통으로 사용할 수 있습니다. - - Args: - query (str): 사용자가 입력한 자연어 기반 질문. - database_env (str): 사용할 데이터베이스 환경 이름 또는 키 (예: "dev", "prod"). - retriever_name (str, optional): 테이블 검색기 이름. 기본값은 "기본". - top_n (int, optional): 검색된 상위 테이블 수 제한. 기본값은 5. - device (str, optional): LLM 실행에 사용할 디바이스 ("cpu" 또는 "cuda"). 기본값은 "cpu". - use_enriched_graph (bool, optional): 확장된 그래프 사용 여부. 기본값은 False. - session_state (Optional[Union[Dict[str, Any], Any]], optional): Streamlit 세션 상태 (Streamlit에서만 사용). - - Returns: - Dict[str, Any]: 다음 정보를 포함한 Lang2SQL 실행 결과 딕셔너리: - - "generated_query": 생성된 SQL 쿼리 (`AIMessage`) - - "messages": 전체 LLM 응답 메시지 목록 - - "searched_tables": 참조된 테이블 목록 등 추가 정보 - """ - logger.info("Processing query: %s", query) - - # 그래프 선택 - if use_enriched_graph: - graph_type = "enriched" - graph_builder = enriched_builder - else: - graph_type = "basic" - graph_builder = basic_builder - - logger.info("Using %s graph", graph_type) - - # 그래프 선택 및 컴파일 - if session_state is not None: - # Streamlit 환경: 세션 상태에서 그래프 재사용 - graph = session_state.get("graph") - if graph is None: - graph = graph_builder.compile() - session_state["graph"] = graph - else: - # CLI 환경: 매번 새로운 그래프 컴파일 - graph = graph_builder.compile() - - # 그래프 실행 - res = graph.invoke( - input={ - "messages": [HumanMessage(content=query)], - "user_database_env": database_env, - "best_practice_query": "", - "retriever_name": retriever_name, - "top_n": top_n, - "device": device, - # 다이얼렉트 정보 주입 (있다면 세션에서, 없으면 기본값) - "dialect_name": ( - session_state.get("selected_dialect_option", {}).get("name") - if session_state is not None - else database_env - ), - "supports_ilike": ( - bool( - session_state.get("selected_dialect_option", {}).get( - "supports_ilike", False - ) - ) - if session_state is not None - else False - ), - "dialect_hints": ( - session_state.get("selected_dialect_option", {}).get("hints", []) - if session_state is not None - else [] - ), - } - ) - - return res - - -def extract_sql_from_result(res: Dict[str, Any]) -> Optional[str]: - """ - Lang2SQL 실행 결과에서 SQL 쿼리를 추출합니다. - - Args: - res (Dict[str, Any]): execute_query 함수의 반환 결과 - - Returns: - Optional[str]: 추출된 SQL 쿼리 문자열. 추출 실패 시 None - """ - generated_query = res.get("generated_query") - if not generated_query: - logger.error("생성된 쿼리가 없습니다.") - return None - - query_text = ( - generated_query.content - if hasattr(generated_query, "content") - else str(generated_query) - ) - - try: - sql = LLMResponseParser.extract_sql(query_text) - return sql - except ValueError: - logger.error("SQL을 추출할 수 없습니다.") - return None diff --git a/interface/app_pages/graph_builder.py b/interface/app_pages/graph_builder.py deleted file mode 100644 index 4792992..0000000 --- a/interface/app_pages/graph_builder.py +++ /dev/null @@ -1,239 +0,0 @@ -""" -LangGraph 워크플로우를 Streamlit에서 구성하고 세션에 적용하는 페이지. - -기능 개요: -- 프리셋(기본/확장) 또는 커스텀 토글로 노드 시퀀스를 구성 -- QUERY_MAKER 포함 여부를 토글하여 마지막 노드를 제어 -- 선택이 바뀌면 즉시 컴파일된 그래프를 세션 상태에 반영 -- 현재 적용된 그래프 설정을 확인 가능 -""" - -from typing import List - -import streamlit as st -from langgraph.graph import END, StateGraph - -from utils.llm.graph_utils.base import ( - CONTEXT_ENRICHMENT, - GET_TABLE_INFO, - PROFILE_EXTRACTION, - QUERY_MAKER, - QueryMakerState, - context_enrichment_node, - get_table_info_node, - profile_extraction_node, - query_maker_node, -) - - -def build_selected_sequence( - preset: str, use_profile: bool, use_context: bool -) -> List[str]: - """ - 프리셋과 커스텀 토글에 따라 실행할 노드 시퀀스를 생성합니다. - - Args: - preset (str): "기본" | "확장" | "커스텀" 중 하나 - use_profile (bool): 커스텀에서 PROFILE_EXTRACTION 포함 여부 - use_context (bool): 커스텀에서 CONTEXT_ENRICHMENT 포함 여부 - - Returns: - List[str]: 노드 식별자들의 실행 순서 - """ - sequence: List[str] = [GET_TABLE_INFO] - - if preset == "기본": - sequence += [QUERY_MAKER] - elif preset == "확장": - sequence += [PROFILE_EXTRACTION, CONTEXT_ENRICHMENT, QUERY_MAKER] - else: - if use_profile: - sequence.append(PROFILE_EXTRACTION) - if use_context: - sequence.append(CONTEXT_ENRICHMENT) - sequence.append(QUERY_MAKER) - - return sequence - - -def build_state_graph(sequence: List[str]) -> StateGraph: - """ - 주어진 시퀀스대로 노드를 추가하고, 인접 노드 간 엣지를 연결한 그래프 빌더를 반환합니다. - - 마지막 노드는 항상 END로 연결합니다. - - Args: - sequence (List[str]): 실행 순서에 따른 노드 식별자 목록 - - Returns: - StateGraph: 컴파일 전 그래프 빌더 객체 - """ - builder = StateGraph(QueryMakerState) - builder.set_entry_point(GET_TABLE_INFO) - - # 노드 등록 - for node_id in sequence: - if node_id == GET_TABLE_INFO: - builder.add_node(GET_TABLE_INFO, get_table_info_node) - elif node_id == PROFILE_EXTRACTION: - builder.add_node(PROFILE_EXTRACTION, profile_extraction_node) - elif node_id == CONTEXT_ENRICHMENT: - builder.add_node(CONTEXT_ENRICHMENT, context_enrichment_node) - elif node_id == QUERY_MAKER: - builder.add_node(QUERY_MAKER, query_maker_node) - - # 엣지 연결 - for i in range(len(sequence) - 1): - builder.add_edge(sequence[i], sequence[i + 1]) - - # 종료 연결: 마지막 노드가 무엇이든 END로 연결 - if len(sequence) > 0: - builder.add_edge(sequence[-1], END) - - return builder - - -def render_sequence(sequence: List[str]) -> str: - """ - 노드 시퀀스를 사람이 읽기 쉬운 문자열로 변환합니다. - - Args: - sequence (List[str]): 실행 순서에 따른 노드 식별자 목록 - - Returns: - str: 예) "GET_TABLE_INFO → PROFILE_EXTRACTION → ..." - """ - label_map = { - GET_TABLE_INFO: "GET_TABLE_INFO", - PROFILE_EXTRACTION: "PROFILE_EXTRACTION", - CONTEXT_ENRICHMENT: "CONTEXT_ENRICHMENT", - QUERY_MAKER: "QUERY_MAKER", - } - return " → ".join(label_map[s] for s in sequence) - - -st.title("LangGraph 구성 UI") -st.caption("기본/확장/커스텀으로 StateGraph를 구성하고 세션에 적용합니다.") - -preset = st.radio("프리셋 선택", ("기본", "확장", "커스텀"), horizontal=True) - -use_profile = False -use_context = False -if preset == "커스텀": - st.subheader("커스텀 옵션") - use_profile = st.checkbox("PROFILE_EXTRACTION 포함", value=True) - use_context = st.checkbox("CONTEXT_ENRICHMENT 포함", value=True) - use_query_maker = st.checkbox("QUERY_MAKER 포함", value=True) -else: - # 프리셋에서는 QUERY_MAKER 자동 포함 - use_query_maker = True - -# GET_TABLE_INFO 설정 -st.subheader("GET_TABLE_INFO 설정") -_prev_cfg = st.session_state.get("graph_config", {}) - -_retriever_options = { - "기본": "벡터 검색 (기본)", - "Reranker": "Reranker 검색 (정확도 향상)", -} -_retriever_keys = list(_retriever_options.keys()) -_retriever_default = _prev_cfg.get("retriever_name", "기본") -_retriever_index = ( - _retriever_keys.index(_retriever_default) - if _retriever_default in _retriever_keys - else 0 -) - -retriever_name = st.selectbox( - "테이블 검색기", - options=_retriever_keys, - format_func=lambda x: _retriever_options[x], - index=_retriever_index, -) - -top_n = st.slider( - "검색할 테이블 정보 개수", - min_value=1, - max_value=20, - value=int(_prev_cfg.get("top_n", 5)), - step=1, -) - -_device_options = ["cpu", "cuda"] -_device_default = _prev_cfg.get("device", "cpu") -_device_index = ( - _device_options.index(_device_default) if _device_default in _device_options else 0 -) -device = st.selectbox( - "모델 실행 장치", - options=_device_options, - index=_device_index, -) - - -def build_sequence_with_qm( - preset: str, use_profile: bool, use_context: bool, use_qm: bool -) -> List[str]: - """ - QUERY_MAKER 포함 여부를 반영하여 시퀀스를 생성합니다. - - - use_qm=False면 마지막 노드는 반드시 GET_TABLE_INFO입니다. - - use_qm=True면 프리셋/커스텀 로직에 따라 마지막 노드는 QUERY_MAKER가 됩니다. - - Args: - preset (str): "기본" | "확장" | "커스텀" 중 하나 - use_profile (bool): PROFILE_EXTRACTION 포함 여부(커스텀 전용) - use_context (bool): CONTEXT_ENRICHMENT 포함 여부(커스텀 전용) - use_qm (bool): QUERY_MAKER 포함 여부 - - Returns: - List[str]: 노드 식별자들의 실행 순서 - """ - # QUERY_MAKER가 비활성화되면 마지막 노드는 반드시 GET_TABLE_INFO - if not use_qm: - return [GET_TABLE_INFO] - # 활성화된 경우 프리셋/커스텀 구성에 따라 마지막 노드는 QUERY_MAKER - base_seq = build_selected_sequence(preset, use_profile, use_context) - return base_seq - - -sequence = build_sequence_with_qm(preset, use_profile, use_context, use_query_maker) - -st.subheader("실행 순서") -st.write(render_sequence(sequence)) - -st.subheader("그래프 생성") -config = { - "preset": preset, - "use_profile": use_profile, - "use_context": use_context, - "use_query_maker": use_query_maker, - "retriever_name": retriever_name, - "top_n": top_n, - "device": device, -} - -# 선택이 바뀌면 자동으로 세션 그래프 갱신 -prev_config = st.session_state.get("graph_config") -if ("graph" not in st.session_state) or (prev_config != config): - _builder = build_state_graph(sequence) - st.session_state["graph"] = _builder.compile() - st.session_state["graph_config"] = config - # Lang2SQL 메인 UI에서 기본값으로 사용할 옵션 전달 - st.session_state["default_retriever_name"] = retriever_name - st.session_state["default_top_n"] = top_n - st.session_state["default_device"] = device - st.info("그래프가 세션에 적용되었습니다.") - -# 수동 새로고침 버튼 -if st.button("세션 그래프 새로고침"): - _builder = build_state_graph(sequence) - st.session_state["graph"] = _builder.compile() - st.session_state["graph_config"] = config - st.session_state["default_retriever_name"] = retriever_name - st.session_state["default_top_n"] = top_n - st.session_state["default_device"] = device - st.success("세션 그래프가 새로고침되었습니다.") - -with st.expander("현재 세션 그래프 설정"): - st.json(st.session_state.get("graph_config", {})) diff --git a/prompt/document_suitability_prompt.md b/prompt/document_suitability_prompt.md deleted file mode 100644 index e8bcbd1..0000000 --- a/prompt/document_suitability_prompt.md +++ /dev/null @@ -1,47 +0,0 @@ -## 문서 적합성 평가 프롬프트 (Table Search 재랭킹) - -당신은 데이터 카탈로그 평가자입니다. 주어진 사용자 질문과 검색 결과(테이블 → 칼럼 설명 맵)를 바탕으로, 각 테이블이 질문에 얼마나 적합한지 0~1 사이의 실수 점수로 평가하세요. - -### 입력 -- **question**: {question} -- **tables**: {tables} - -### 과업 -1. **핵심 신호 추출**: 질문에서 엔터티/지표/시간/필터/그룹화 단서를 추출합니다. -2. **테이블별 점수화**: 각 테이블의 칼럼·설명과의 연관성으로 적합도를 점수화합니다(0~1, 소수 셋째 자리 반올림). -3. **근거와 보완점 제시**: 매칭된 칼럼과 부족한 요소(엔터티/지표/기간 등)를 한국어로 설명합니다. -4. **정렬**: 결과를 점수 내림차순으로 정렬해 반환합니다. - -### 평가 규칙(가이드) -- **0.90~1.00**: 필요한 엔터티, 기간/시간 컬럼, 핵심 지표/측정 칼럼이 모두 존재. 직접 조회/집계만으로 답 가능. -- **0.60~0.89**: 주요 신호 매칭, 일부 보완(기간/그룹 키/보조 칼럼) 필요. 조인 없이 근사 가능. -- **0.30~0.59**: 일부만 매칭. 외부 컨텍스트나 조인 없이는 부정확/제한적. -- **0.00~0.29**: 연관성 낮음. 스키마/도메인 불일치 또는 정책/운영성 테이블. - -### 주의 -- 칼럼 이름/설명에 실제로 존재하지 않는 항목을 매칭하지 마세요(환각 금지). -- 시간 요구(특정 날짜/기간)가 있으면 timestamp/date/created_at 등 시간 계열 키를 중시하세요. -- 엔티티 키(예: id, user_id, product_id)의 존재 여부를 가산점으로 반영하세요. -- 키 이름은 정확히 입력 맵의 키만 사용하세요(자유 추측 금지). - -### 언어/출력 형식 -- 모든 텍스트 값은 한국어로 작성하세요. -- 결과는 반드시 아래 JSON 스키마로만 반환하세요(추가/누락 키 금지). - -### 출력(JSON 스키마) -{{ - "results": [ - {{ - "table_name": string, - "score": number, // 0.0~1.0, 소수 셋째 자리 반올림 - "reason": string, // 한국어 한두 문장 근거 - "matched_columns": string[], - "missing_entities": string[] - }} - ] -}} - -### 검증 규칙 -- score는 [0, 1] 범위로 클램핑하고 소수 셋째 자리까지 반올림하세요. -- matched_columns는 해당 테이블 객체의 실제 키만 포함하세요(단, table_description 제외). -- reason 및 missing_entities는 한국어로 작성하세요. \ No newline at end of file diff --git a/prompt/profile_extraction_prompt.md b/prompt/profile_extraction_prompt.md deleted file mode 100644 index 606e037..0000000 --- a/prompt/profile_extraction_prompt.md +++ /dev/null @@ -1,19 +0,0 @@ -# Role - -You are an assistant that analyzes a user question and extracts the following profiles as JSON: -- is_timeseries (boolean) -- is_aggregation (boolean) -- has_filter (boolean) -- is_grouped (boolean) -- has_ranking (boolean) -- has_temporal_comparison (boolean) -- intent_type (one of: trend, lookup, comparison, distribution) - -# Input - -Question: -{question} - -# Output Example - -The output must be a valid JSON matching the QuestionProfile schema. diff --git a/prompt/query_enrichment_prompt.md b/prompt/query_enrichment_prompt.md deleted file mode 100644 index 98fbb6f..0000000 --- a/prompt/query_enrichment_prompt.md +++ /dev/null @@ -1,22 +0,0 @@ -# Role - -You are a smart assistant that takes a user question and enriches it using: -1. Question profiles: {profiles} -2. Table metadata (names, columns, descriptions): - {related_tables} - -# Tasks - -- Correct any wrong terms by matching them to actual column names. -- If the question is time-series or aggregation, add explicit hints (e.g., "over the last 30 days"). -- If needed, map natural language terms to actual column values (e.g., ‘미국’ → ‘USA’ for country_code). -- Output the enriched question only. - -# Input - -Refined question: -{refined_question} - -# Notes - -Using the refined version for enrichment, but keep the original intent in mind. diff --git a/prompt/query_maker_prompt.md b/prompt/query_maker_prompt.md deleted file mode 100644 index af3ee14..0000000 --- a/prompt/query_maker_prompt.md +++ /dev/null @@ -1,47 +0,0 @@ -# Role - -당신은 데이터 분석 전문가(데이터 분석가 페르소나)입니다. -사용자의 질문을 기반으로, 주어진 테이블과 컬럼 정보를 활용하여 적절한 SQL 쿼리를 생성하세요. - -# 주의사항 -- 사용자의 질문이 다소 모호하더라도, 주어진 데이터를 참고하여 합리적인 가정을 통해 SQL 쿼리를 완성하세요. -- 불필요한 재질문 없이, 가능한 가장 명확한 분석 쿼리를 만들어 주세요. -- 반드시 입력된 다이얼렉트 변수들을 준수하여 문법을 선택하세요. -- 최종 출력 형식은 반드시 아래와 같아야 합니다. - -# Output Example -최종 형태 예시: - -```sql - SELECT COUNT(DISTINCT user_id) - FROM stg_users -``` - -<해석> -```plaintext (max_length_per_line=100) - 이 쿼리는 stg_users 테이블에서 고유한 사용자의 수를 계산합니다. - 사용자는 유니크한 user_id를 가지고 있으며 - 중복을 제거하기 위해 COUNT(DISTINCT user_id)를 사용했습니다. -``` - -# Input - -- 사용자 질문: -{user_input} - -- DB 환경: -{user_database_env} - -- 관련 테이블 및 컬럼 정보: -{searched_tables} - -- 다이얼렉트 정보: - - dialect_name: {dialect_name} - - supports_ilike: {supports_ilike} - - dialect_hints: {dialect_hints} - -# Notes - -- 위 입력을 바탕으로 최적의 SQL을 생성하세요. -- {dialect_hints}를 참고하여 엔진에 맞는 함수/연산자를 우선 사용하세요. -- 출력은 위 '최종 형태 예시'와 동일한 구조로만 작성하세요. \ No newline at end of file diff --git a/prompt/question_gate_prompt.md b/prompt/question_gate_prompt.md deleted file mode 100644 index aea3865..0000000 --- a/prompt/question_gate_prompt.md +++ /dev/null @@ -1,19 +0,0 @@ -당신은 데이터 분석 도우미입니다. 아래 사용자 질문이 SQL로 답변 가능한지 판별하고, 구조화된 결과를 반환하세요. - -요건: -- reason: 한 줄 설명(어떤 보완이 필요한지 요약) -- missing_entities: 기간, 대상 엔터티, 측정값 등 누락된 핵심 요소 리스트(없으면 빈 리스트) -- requires_data_science: 통계/ML 분석이 필요한지 여부(Boolean) - -언어/출력 형식: -- 모든 텍스트 값은 한국어로 작성하세요. (reason는 한국어 문장, missing_entities 항목은 한국어 명사구) -- Boolean 값은 JSON의 true/false로 표기하세요. - -주의: -- 데이터 분석 맥락에서 SQL 집계/필터/조인으로 해결 가능한지 판단합니다. -- 정책/운영/가이드/설치/권한/오류 해결 등은 SQL 부적합으로 간주합니다. - -입력: {question} - -출력은 반드시 지정된 스키마의 JSON으로만 반환하세요. - diff --git a/utils/llm/chains.py b/utils/llm/chains.py deleted file mode 100644 index 6e66f8b..0000000 --- a/utils/llm/chains.py +++ /dev/null @@ -1,145 +0,0 @@ -""" -LLM 체인 생성 모듈. - -이 모듈은 Lang2SQL에서 사용하는 다양한 LangChain 기반 체인을 정의합니다. -- Query Maker -- Query Enrichment -- Profile Extraction -- Question Gate (SQL 적합성 분류) -""" - -from langchain_core.prompts import ChatPromptTemplate, SystemMessagePromptTemplate -from pydantic import BaseModel, Field - -from prompt.template_loader import get_prompt_template -from utils.llm.core import get_llm -from utils.llm.output_schema.document_suitability import DocumentSuitabilityList -from utils.llm.output_schema.question_suitability import QuestionSuitability - -llm = get_llm() - - -class QuestionProfile(BaseModel): - """ - 자연어 질문의 특징을 구조화해 표현하는 프로파일 모델. - - 이 프로파일은 이후 컨텍스트 보강 및 SQL 생성 시 힌트로 사용됩니다. - """ - - is_timeseries: bool = Field(description="시계열 분석 필요 여부") - is_aggregation: bool = Field(description="집계 함수 필요 여부") - has_filter: bool = Field(description="조건 필터 필요 여부") - is_grouped: bool = Field(description="그룹화 필요 여부") - has_ranking: bool = Field(description="정렬/순위 필요 여부") - has_temporal_comparison: bool = Field(description="기간 비교 포함 여부") - intent_type: str = Field(description="질문의 주요 의도 유형") - - -# QueryMakerChain -def create_query_maker_chain(llm): - """ - SQL 쿼리 생성을 위한 체인을 생성합니다. - - Args: - llm: LangChain 호환 LLM 인스턴스 - - Returns: - Runnable: 입력 프롬프트를 받아 SQL을 생성하는 체인 - """ - prompt = get_prompt_template("query_maker_prompt") - query_maker_prompt = ChatPromptTemplate.from_messages( - [ - SystemMessagePromptTemplate.from_template(prompt), - ] - ) - return query_maker_prompt | llm - - -def create_query_enrichment_chain(llm): - """ - 사용자 질문을 메타데이터로 보강하기 위한 체인을 생성합니다. - - Args: - llm: LangChain 호환 LLM 인스턴스 - - Returns: - Runnable: 보강된 질문 텍스트를 반환하는 체인 - """ - prompt = get_prompt_template("query_enrichment_prompt") - - enrichment_prompt = ChatPromptTemplate.from_messages( - [ - SystemMessagePromptTemplate.from_template(prompt), - ] - ) - - chain = enrichment_prompt | llm - return chain - - -def create_profile_extraction_chain(llm): - """ - 질문으로부터 `QuestionProfile`을 추출하는 체인을 생성합니다. - - Args: - llm: LangChain 호환 LLM 인스턴스 - - Returns: - Runnable: `QuestionProfile` 구조화 출력을 반환하는 체인 - """ - prompt = get_prompt_template("profile_extraction_prompt") - - profile_prompt = ChatPromptTemplate.from_messages( - [ - SystemMessagePromptTemplate.from_template(prompt), - ] - ) - - chain = profile_prompt | llm.with_structured_output(QuestionProfile) - return chain - - -def create_question_gate_chain(llm): - """ - 질문 적합성(Question Gate) 체인을 생성합니다. - - ChatPromptTemplate(SystemMessage) + LLM 구조화 출력으로 - `QuestionSuitability`를 반환합니다. - - Args: - llm: LangChain 호환 LLM 인스턴스 - - Returns: - Runnable: invoke({"question": str}) -> QuestionSuitability - """ - - prompt = get_prompt_template("question_gate_prompt") - gate_prompt = ChatPromptTemplate.from_messages( - [SystemMessagePromptTemplate.from_template(prompt)] - ) - return gate_prompt | llm.with_structured_output(QuestionSuitability) - - -def create_document_suitability_chain(llm): - """ - 문서 적합성 평가 체인을 생성합니다. - - 질문(question)과 검색 결과(tables)를 입력으로 받아 - 테이블별 적합도 점수를 포함한 JSON 딕셔너리를 반환합니다. - - Returns: - Runnable: invoke({"question": str, "tables": dict}) -> {"results": DocumentSuitability[]} - """ - - prompt = get_prompt_template("document_suitability_prompt") - doc_prompt = ChatPromptTemplate.from_messages( - [SystemMessagePromptTemplate.from_template(prompt)] - ) - return doc_prompt | llm.with_structured_output(DocumentSuitabilityList) - - -query_maker_chain = create_query_maker_chain(llm) -profile_extraction_chain = create_profile_extraction_chain(llm) -query_enrichment_chain = create_query_enrichment_chain(llm) -question_gate_chain = create_question_gate_chain(llm) -document_suitability_chain = create_document_suitability_chain(llm) diff --git a/utils/llm/core/__init__.py b/utils/llm/core/__init__.py index a842cfb..76e54fe 100644 --- a/utils/llm/core/__init__.py +++ b/utils/llm/core/__init__.py @@ -1,33 +1,3 @@ -from utils.llm.core.factory import ( - get_embeddings, - get_embeddings_azure, - get_embeddings_bedrock, - get_embeddings_gemini, - get_embeddings_huggingface, - get_embeddings_ollama, - get_embeddings_openai, - get_llm, - get_llm_azure, - get_llm_bedrock, - get_llm_gemini, - get_llm_huggingface, - get_llm_ollama, - get_llm_openai, -) +# Legacy LLM core module — factory removed (replaced by src/lang2sql/factory.py) -__all__ = [ - "get_llm", - "get_llm_openai", - "get_llm_azure", - "get_llm_bedrock", - "get_llm_gemini", - "get_llm_ollama", - "get_llm_huggingface", - "get_embeddings", - "get_embeddings_openai", - "get_embeddings_azure", - "get_embeddings_bedrock", - "get_embeddings_gemini", - "get_embeddings_ollama", - "get_embeddings_huggingface", -] +__all__ = [] diff --git a/utils/llm/core/factory.py b/utils/llm/core/factory.py deleted file mode 100644 index 3164220..0000000 --- a/utils/llm/core/factory.py +++ /dev/null @@ -1,181 +0,0 @@ -import os -from typing import Optional - -from langchain.llms.base import BaseLanguageModel -from langchain_aws import BedrockEmbeddings, ChatBedrockConverse -from langchain_google_genai import ChatGoogleGenerativeAI, GoogleGenerativeAIEmbeddings -from langchain_huggingface import ( - ChatHuggingFace, - HuggingFaceEndpoint, - HuggingFaceEndpointEmbeddings, -) -from langchain_ollama import ChatOllama, OllamaEmbeddings -from langchain_openai import ( - AzureChatOpenAI, - AzureOpenAIEmbeddings, - ChatOpenAI, - OpenAIEmbeddings, -) - - -def get_llm(**kwargs) -> BaseLanguageModel: - """ - return chat model interface - """ - provider = os.getenv("LLM_PROVIDER") - print(os.environ["LLM_PROVIDER"]) - - if provider is None: - raise ValueError("LLM_PROVIDER environment variable is not set.") - - if provider == "openai": - return get_llm_openai(**kwargs) - - elif provider == "azure": - return get_llm_azure(**kwargs) - - elif provider == "bedrock": - return get_llm_bedrock(**kwargs) - - elif provider == "gemini": - return get_llm_gemini(**kwargs) - - elif provider == "ollama": - return get_llm_ollama(**kwargs) - - elif provider == "huggingface": - return get_llm_huggingface(**kwargs) - - else: - raise ValueError(f"Invalid LLM API Provider: {provider}") - - -def get_llm_openai(**kwargs) -> BaseLanguageModel: - return ChatOpenAI( - model=os.getenv("OPEN_AI_LLM_MODEL", "gpt-4o"), - api_key=os.getenv("OPEN_AI_KEY"), - **kwargs, - ) - - -def get_llm_azure(**kwargs) -> BaseLanguageModel: - return AzureChatOpenAI( - api_key=os.getenv("AZURE_OPENAI_LLM_KEY"), - azure_endpoint=os.getenv("AZURE_OPENAI_LLM_ENDPOINT"), - azure_deployment=os.getenv("AZURE_OPENAI_LLM_MODEL"), # Deployment name - api_version=os.getenv("AZURE_OPENAI_LLM_API_VERSION", "2023-07-01-preview"), - **kwargs, - ) - - -def get_llm_bedrock(**kwargs) -> BaseLanguageModel: - return ChatBedrockConverse( - model=os.getenv("AWS_BEDROCK_LLM_MODEL"), - aws_access_key_id=os.getenv("AWS_BEDROCK_LLM_ACCESS_KEY_ID"), - aws_secret_access_key=os.getenv("AWS_BEDROCK_LLM_SECRET_ACCESS_KEY"), - region_name=os.getenv("AWS_BEDROCK_LLM_REGION", "us-east-1"), - **kwargs, - ) - - -def get_llm_gemini(**kwargs) -> BaseLanguageModel: - return ChatGoogleGenerativeAI(model=os.getenv("GEMINI_LLM_MODEL"), **kwargs) - - -def get_llm_ollama(**kwargs) -> BaseLanguageModel: - base_url = os.getenv("OLLAMA_LLM_BASE_URL") - if base_url: - return ChatOllama( - base_url=base_url, model=os.getenv("OLLAMA_LLM_MODEL"), **kwargs - ) - else: - return ChatOllama(model=os.getenv("OLLAMA_LLM_MODEL"), **kwargs) - - -def get_llm_huggingface(**kwargs) -> BaseLanguageModel: - return ChatHuggingFace( - llm=HuggingFaceEndpoint( - model=os.getenv("HUGGING_FACE_LLM_MODEL"), - repo_id=os.getenv("HUGGING_FACE_LLM_REPO_ID"), - task="text-generation", - endpoint_url=os.getenv("HUGGING_FACE_LLM_ENDPOINT"), - huggingfacehub_api_token=os.getenv("HUGGING_FACE_LLM_API_TOKEN"), - **kwargs, - ) - ) - - -def get_embeddings() -> Optional[BaseLanguageModel]: - """ - return embedding model interface - """ - provider = os.getenv("EMBEDDING_PROVIDER") - print(provider) - - if provider is None: - raise ValueError("EMBEDDING_PROVIDER environment variable is not set.") - - if provider == "openai": - return get_embeddings_openai() - - elif provider == "bedrock": - return get_embeddings_bedrock() - - elif provider == "azure": - return get_embeddings_azure() - - elif provider == "gemini": - return get_embeddings_gemini() - - elif provider == "ollama": - return get_embeddings_ollama() - - else: - raise ValueError(f"Invalid Embedding API Provider: {provider}") - - -def get_embeddings_openai() -> BaseLanguageModel: - return OpenAIEmbeddings( - model=os.getenv("OPEN_AI_EMBEDDING_MODEL"), - openai_api_key=os.getenv("OPEN_AI_KEY"), - ) - - -def get_embeddings_azure() -> BaseLanguageModel: - return AzureOpenAIEmbeddings( - api_key=os.getenv("AZURE_OPENAI_EMBEDDING_KEY"), - azure_endpoint=os.getenv("AZURE_OPENAI_EMBEDDING_ENDPOINT"), - azure_deployment=os.getenv("AZURE_OPENAI_EMBEDDING_MODEL"), - api_version=os.getenv("AZURE_OPENAI_EMBEDDING_API_VERSION"), - ) - - -def get_embeddings_bedrock() -> BaseLanguageModel: - return BedrockEmbeddings( - model_id=os.getenv("AWS_BEDROCK_EMBEDDING_MODEL"), - aws_access_key_id=os.getenv("AWS_BEDROCK_EMBEDDING_ACCESS_KEY_ID"), - aws_secret_access_key=os.getenv("AWS_BEDROCK_EMBEDDING_SECRET_ACCESS_KEY"), - region_name=os.getenv("AWS_BEDROCK_EMBEDDING_REGION", "us-east-1"), - ) - - -def get_embeddings_gemini() -> BaseLanguageModel: - return GoogleGenerativeAIEmbeddings( - model=os.getenv("GEMINI_EMBEDDING_MODEL"), - api_key=os.getenv("GEMINI_EMBEDDING_KEY"), - ) - - -def get_embeddings_ollama() -> BaseLanguageModel: - return OllamaEmbeddings( - model=os.getenv("OLLAMA_EMBEDDING_MODEL"), - base_url=os.getenv("OLLAMA_EMBEDDING_BASE_URL"), - ) - - -def get_embeddings_huggingface() -> BaseLanguageModel: - return HuggingFaceEndpointEmbeddings( - model=os.getenv("HUGGING_FACE_EMBEDDING_MODEL"), - repo_id=os.getenv("HUGGING_FACE_EMBEDDING_REPO_ID"), - huggingfacehub_api_token=os.getenv("HUGGING_FACE_EMBEDDING_API_TOKEN"), - ) diff --git a/utils/llm/graph_utils/README.md b/utils/llm/graph_utils/README.md deleted file mode 100644 index 595323a..0000000 --- a/utils/llm/graph_utils/README.md +++ /dev/null @@ -1,185 +0,0 @@ -# graph_utils - -이 모듈은 **LangGraph workflow**를 위한 그래프 유틸리티들을 제공합니다. Lang2SQL 프로젝트에서 자연어 질문을 SQL 쿼리로 변환하는 워크플로우를 LangGraph를 사용하여 구성합니다. - -## 디렉토리 구조 - -``` -graph_utils/ -├── __init__.py -├── base.py -├── basic_graph.py -├── enriched_graph.py -├── profile_utils.py -└── README.md -``` - -## 파일 설명 - -### `__init__.py` -그래프 관련 유틸리티 모듈의 공개 인터페이스를 정의합니다. - -**주요 사용:** -- **상태 및 노드 식별자:** - - `QueryMakerState`: 그래프의 상태 타입 정의 - - `GET_TABLE_INFO`, `QUERY_MAKER`, `PROFILE_EXTRACTION`, `CONTEXT_ENRICHMENT`: 노드 식별자 상수 - -- **노드 함수들:** - - `get_table_info_node`: 테이블 정보 검색 노드 - - `query_maker_node`: SQL 쿼리 생성 노드 - - `profile_extraction_node`: 질문 프로파일 추출 노드 - - `context_enrichment_node`: 컨텍스트 보강 노드 - -- **그래프 빌더들:** - - `basic_builder`: 기본 워크플로우 그래프 빌더 - - `enriched_builder`: 확장된 워크플로우 그래프 빌더 - -### `base.py` -LangGraph 워크플로우의 핵심 노드 함수들과 상태 정의를 포함합니다. - -**주요 내용:** -- **상태 타입 (`QueryMakerState`):** TypedDict를 사용하여 그래프 상태 구조를 정의 - - `messages`: LLM 메시지 리스트 - - `user_database_env`: 사용자 데이터베이스 환경 - - `searched_tables`: 검색된 테이블 정보 - - `question_profile`: 질문 프로파일 정보 - - `generated_query`: 생성된 SQL 쿼리 - - 기타 워크플로우에 필요한 상태 정보 - -- **노드 식별자 상수:** - - `QUESTION_GATE`, `EVALUATE_DOCUMENT_SUITABILITY`, `GET_TABLE_INFO`, `TOOL`, `TABLE_FILTER`, `QUERY_MAKER`, `PROFILE_EXTRACTION`, `CONTEXT_ENRICHMENT` - -- **노드 함수들:** - - `question_gate_node`: 사용자 질문이 SQL로 답변 가능한지 판별하는 게이트 노드 - - `get_table_info_node`: 벡터 검색을 통해 관련 테이블 정보를 가져오는 노드 - - `document_suitability_node`: 검색된 테이블들의 문서 적합성 점수를 계산하는 노드 - - `profile_extraction_node`: 자연어 쿼리로부터 질문 유형(시계열, 집계, 필터 등)을 추출하는 노드 - - `context_enrichment_node`: 질문과 관련된 메타데이터를 기반으로 질문을 풍부하게 만드는 노드 - - `query_maker_node`: 최종 SQL 쿼리를 생성하는 노드 - -### `basic_graph.py` -기본 워크플로우를 위한 StateGraph 구성을 정의합니다. - -**워크플로우 순서:** -``` -QUESTION_GATE → GET_TABLE_INFO → EVALUATE_DOCUMENT_SUITABILITY → QUERY_MAKER → END -``` - -**주요 내용:** -- `StateGraph`를 사용하여 기본 워크플로우 그래프 생성 -- `builder` 객체를 export하여 다른 모듈에서 사용 가능 -- 조건부 라우팅(`add_conditional_edges`)을 통해 게이트 노드 이후 흐름 제어 - -### `enriched_graph.py` -기본 워크플로우에 프로파일 추출과 컨텍스트 보강 단계를 추가한 확장된 그래프입니다. - -**워크플로우 순서:** -``` -QUESTION_GATE → GET_TABLE_INFO → EVALUATE_DOCUMENT_SUITABILITY → -PROFILE_EXTRACTION → CONTEXT_ENRICHMENT → QUERY_MAKER → END -``` - -**주요 내용:** -- `basic_graph`와 동일한 구조이지만 `PROFILE_EXTRACTION`과 `CONTEXT_ENRICHMENT` 노드가 추가됨 -- 더 정교한 질문 분석과 컨텍스트 보강을 통해 더 나은 SQL 쿼리 생성이 가능 - -### `profile_utils.py` -질문 프로파일 객체를 텍스트로 변환하는 유틸리티 함수를 제공합니다. - -**주요 함수:** -- `profile_to_text(profile_obj) -> str`: 질문 프로파일 객체를 읽기 쉬운 텍스트 형태로 변환 - - 시계열 분석 필요 여부 - - 집계 함수 필요 여부 - - WHERE 조건 필요 여부 - - GROUP BY 필요 여부 - - 정렬/순위 필요 여부 - - 기간 비교 필요 여부 - - 의도 유형 정보 - -## 사용 방법 - -### 1. `engine/query_executor.py`에서의 사용 - -기본 또는 확장된 그래프 빌더를 선택하여 쿼리를 실행합니다: - -```python -from utils.llm.graph_utils.basic_graph import builder as basic_builder -from utils.llm.graph_utils.enriched_graph import builder as enriched_builder - -# 그래프 선택 -if use_enriched_graph: - graph_builder = enriched_builder -else: - graph_builder = basic_builder - -# 그래프 컴파일 및 실행 -graph = graph_builder.compile() -result = graph.invoke({ - "messages": [HumanMessage(content=query)], - "user_database_env": database_env, - # ... 기타 상태 정보 -}) -``` - -**사용 위치:** `/home/dwlee/Lang2SQL/engine/query_executor.py`의 `execute_query()` 함수 - -### 2. `interface/core/session_utils.py`에서의 사용 - -Streamlit 세션 상태에서 그래프 빌더를 동적으로 초기화합니다: - -```python -def init_graph(use_enriched: bool) -> str: - builder_module = ( - "utils.llm.graph_utils.enriched_graph" - if use_enriched - else "utils.llm.graph_utils.basic_graph" - ) - builder = __import__(builder_module, fromlist=["builder"]).builder - st.session_state["graph"] = builder.compile() - return "확장된" if use_enriched else "기본" -``` - -**사용 위치:** `/home/dwlee/Lang2SQL/interface/core/session_utils.py`의 `init_graph()` 함수 - -### 3. `interface/app_pages/graph_builder.py`에서의 사용 - -Streamlit 인터페이스에서 커스텀 그래프를 구성할 때 개별 노드 함수들을 사용합니다: - -```python -from utils.llm.graph_utils.base import ( - CONTEXT_ENRICHMENT, - GET_TABLE_INFO, - PROFILE_EXTRACTION, - QUERY_MAKER, - QueryMakerState, - context_enrichment_node, - get_table_info_node, - profile_extraction_node, - query_maker_node, -) - -# 커스텀 시퀀스에 따라 노드 등록 -builder = StateGraph(QueryMakerState) -for node_id in sequence: - if node_id == GET_TABLE_INFO: - builder.add_node(GET_TABLE_INFO, get_table_info_node) - elif node_id == PROFILE_EXTRACTION: - builder.add_node(PROFILE_EXTRACTION, profile_extraction_node) - # ... 기타 노드들 -``` - -**사용 위치:** `/home/dwlee/Lang2SQL/interface/app_pages/graph_builder.py`의 `build_state_graph()` 함수 - -## 워크플로우 개요 - -이 모듈은 **LangGraph**를 사용하여 자연어 질문을 SQL 쿼리로 변환하는 워크플로우를 구현합니다: - -1. **QUESTION_GATE**: 질문이 SQL로 답변 가능한지 판별 -2. **GET_TABLE_INFO**: 벡터 검색을 통해 관련 테이블 정보 검색 -3. **EVALUATE_DOCUMENT_SUITABILITY**: 검색된 테이블들의 적합성 평가 -4. **PROFILE_EXTRACTION** (확장 그래프만): 질문의 특성 추출 (시계열, 집계 등) -5. **CONTEXT_ENRICHMENT** (확장 그래프만): 질문을 컨텍스트 정보로 보강 -6. **QUERY_MAKER**: 최종 SQL 쿼리 생성 - -각 노드는 `QueryMakerState`를 입력으로 받아 상태를 업데이트하고 반환합니다. - diff --git a/utils/llm/graph_utils/__init__.py b/utils/llm/graph_utils/__init__.py deleted file mode 100644 index c2012ad..0000000 --- a/utils/llm/graph_utils/__init__.py +++ /dev/null @@ -1,37 +0,0 @@ -""" -그래프 관련 유틸리티 모듈입니다. - -이 패키지는 Lang2SQL의 워크플로우 그래프 구성과 관련된 모듈들을 포함합니다. -""" - -from utils.llm.graph_utils.base import ( - CONTEXT_ENRICHMENT, - GET_TABLE_INFO, - PROFILE_EXTRACTION, - QUERY_MAKER, - QueryMakerState, - context_enrichment_node, - get_table_info_node, - profile_extraction_node, - query_maker_node, -) - -from .basic_graph import builder as basic_builder -from .enriched_graph import builder as enriched_builder - -__all__ = [ - # 상태 및 노드 식별자 - "QueryMakerState", - "GET_TABLE_INFO", - "QUERY_MAKER", - "PROFILE_EXTRACTION", - "CONTEXT_ENRICHMENT", - # 노드 함수들 - "get_table_info_node", - "query_maker_node", - "profile_extraction_node", - "context_enrichment_node", - # 그래프 빌더들 - "basic_builder", - "enriched_builder", -] diff --git a/utils/llm/graph_utils/base.py b/utils/llm/graph_utils/base.py deleted file mode 100644 index 41a4d49..0000000 --- a/utils/llm/graph_utils/base.py +++ /dev/null @@ -1,258 +0,0 @@ -import json - -from langgraph.graph.message import add_messages -from typing_extensions import Annotated, TypedDict - -from utils.llm.chains import ( - document_suitability_chain, - profile_extraction_chain, - query_enrichment_chain, - query_maker_chain, - question_gate_chain, -) -from utils.llm.retrieval import search_tables - -# 노드 식별자 정의 -QUESTION_GATE = "question_gate" -EVALUATE_DOCUMENT_SUITABILITY = "evaluate_document_suitability" -GET_TABLE_INFO = "get_table_info" -TOOL = "tool" -TABLE_FILTER = "table_filter" -QUERY_MAKER = "query_maker" -PROFILE_EXTRACTION = "profile_extraction" -CONTEXT_ENRICHMENT = "context_enrichment" - - -# 상태 타입 정의 (추가 상태 정보와 메시지들을 포함) -class QueryMakerState(TypedDict): - messages: Annotated[list, add_messages] - user_database_env: str - searched_tables: dict[str, dict[str, str]] - document_suitability: dict - best_practice_query: str - question_profile: dict - generated_query: str - retriever_name: str - top_n: int - device: str - question_gate_result: dict - # 다이얼렉트 정보 - dialect_name: str - supports_ilike: bool - dialect_hints: list[str] - - -# 노드 함수: QUESTION_GATE 노드 -def question_gate_node(state: QueryMakerState): - """ - 사용자의 질문이 SQL로 답변 가능한지 판별하고, 구조화된 결과를 반환하는 게이트 노드입니다. - - - question_gate_chain 으로 적합성을 판정하여 - `question_gate_result`를 설정합니다. - - Args: - state (QueryMakerState): 그래프 상태 - - Returns: - QueryMakerState: 게이트 판정 결과가 반영된 상태 - """ - - question_text = state["messages"][0].content - suitability = question_gate_chain.invoke({"question": question_text}) - state["question_gate_result"] = { - "reason": getattr(suitability, "reason", ""), - "missing_entities": getattr(suitability, "missing_entities", []), - "requires_data_science": getattr(suitability, "requires_data_science", False), - } - return state - - -# 노드 함수: PROFILE_EXTRACTION 노드 -def profile_extraction_node(state: QueryMakerState): - """ - 자연어 쿼리로부터 질문 유형(PROFILE)을 추출하는 노드입니다. - - 이 노드는 주어진 자연어 쿼리에서 질문의 특성을 분석하여, 해당 질문이 시계열 분석, 집계 함수 사용, 조건 필터 필요 여부, - 그룹화, 정렬/순위, 기간 비교 등 다양한 특성을 갖는지 여부를 추출합니다. - - 추출된 정보는 `QuestionProfile` 모델에 맞춰 저장됩니다. `QuestionProfile` 모델의 필드는 다음과 같습니다: - - `is_timeseries`: 시계열 분석 필요 여부 - - `is_aggregation`: 집계 함수 필요 여부 - - `has_filter`: 조건 필터 필요 여부 - - `is_grouped`: 그룹화 필요 여부 - - `has_ranking`: 정렬/순위 필요 여부 - - `has_temporal_comparison`: 기간 비교 포함 여부 - - `intent_type`: 질문의 주요 의도 유형 - - """ - result = profile_extraction_chain.invoke({"question": state["messages"][0].content}) - - state["question_profile"] = result - print("profile_extraction_node : ", result) - return state - - -# 노드 함수: CONTEXT_ENRICHMENT 노드 -def context_enrichment_node(state: QueryMakerState): - """ - 주어진 질문과 관련된 메타데이터를 기반으로 질문을 풍부하게 만드는 노드입니다. - - 이 함수는 `refined_question`, `profiles`, `related_tables` 정보를 이용하여 자연어 질문을 보강합니다. - 보강 과정에서는 질문의 의도를 유지하면서, 추가적인 세부 정보를 제공하거나 잘못된 용어를 수정합니다. - - 주요 작업: - - 주어진 질문의 메타데이터 (`question_profile` 및 `searched_tables`)를 활용하여, 질문을 수정하거나 추가 정보를 삽입합니다. - - 질문이 시계열 분석 또는 집계 함수 관련인 경우, 이를 명시적으로 강조합니다 (예: "지난 30일 동안"). - - 자연어에서 실제 열 이름 또는 값으로 잘못 매칭된 용어를 수정합니다 (예: '미국' → 'USA'). - - 보강된 질문을 출력합니다. - - Args: - state (QueryMakerState): 쿼리와 관련된 상태 정보를 담고 있는 객체. - 상태 객체는 `messages`, `question_profile`, `searched_tables` 등의 정보를 포함합니다. - - Returns: - QueryMakerState: 보강된 질문이 포함된 상태 객체. - - Example: - Given the refined question "What are the total sales in the last month?", - the function would enrich it with additional information such as: - - Ensuring the time period is specified correctly. - - Correcting any column names if necessary. - - Returning the enriched version of the question. - """ - - searched_tables = state["searched_tables"] - searched_tables_json = json.dumps(searched_tables, ensure_ascii=False, indent=2) - - # question_profile이 BaseModel인 경우 model_dump() 사용, dict인 경우 그대로 사용 - if hasattr(state["question_profile"], "model_dump"): - question_profile = state["question_profile"].model_dump() - else: - question_profile = state["question_profile"] - question_profile_json = json.dumps(question_profile, ensure_ascii=False, indent=2) - - # 초기 사용자 입력 사용 - refined_question = state["messages"][0].content - - enriched_text = query_enrichment_chain.invoke( - input={ - "refined_question": refined_question, - "profiles": question_profile_json, - "related_tables": searched_tables_json, - } - ) - - state["messages"].append(enriched_text) - print("After context enrichment : ", enriched_text.content) - - return state - - -def get_table_info_node(state: QueryMakerState): - # retriever_name과 top_n을 이용하여 검색 수행 - documents_dict = search_tables( - query=state["messages"][0].content, - retriever_name=state["retriever_name"], - top_n=state["top_n"], - device=state["device"], - ) - state["searched_tables"] = documents_dict - - return state - - -# 노드 함수: DOCUMENT_SUITABILITY 노드 -def document_suitability_node(state: QueryMakerState): - """ - GET_TABLE_INFO에서 수집된 테이블 후보들에 대해 문서 적합성 점수를 계산하는 노드입니다. - - 질문(`messages[0].content`)과 `searched_tables`(테이블→칼럼 설명 맵)를 입력으로 - 프롬프트 체인(`document_suitability_chain`)을 호출하고, 결과 딕셔너리를 - `document_suitability` 상태 키에 저장합니다. - - Returns: - QueryMakerState: 문서 적합성 평가 결과가 포함된 상태 - """ - - # 관련 테이블이 없으면 즉시 반환 - if not state.get("searched_tables"): - state["document_suitability"] = {} - return state - - res = document_suitability_chain.invoke( - { - "question": state["messages"][0].content, - "tables": state["searched_tables"], - } - ) - - items = ( - res.get("results", []) - if isinstance(res, dict) - else getattr(res, "results", None) - or (res.model_dump().get("results", []) if hasattr(res, "model_dump") else []) - ) - - normalized = {} - for x in items: - d = ( - x.model_dump() - if hasattr(x, "model_dump") - else ( - x - if isinstance(x, dict) - else { - "table_name": getattr(x, "table_name", ""), - "score": getattr(x, "score", 0), - "reason": getattr(x, "reason", ""), - "matched_columns": getattr(x, "matched_columns", []), - "missing_entities": getattr(x, "missing_entities", []), - } - ) - ) - t = d.get("table_name") - if not t: - continue - normalized[t] = { - "score": float(d.get("score", 0)), - "reason": d.get("reason", ""), - "matched_columns": d.get("matched_columns", []), - "missing_entities": d.get("missing_entities", []), - } - - state["document_suitability"] = normalized - - return state - - -# 노드 함수: QUERY_MAKER 노드 -def query_maker_node(state: QueryMakerState): - # 사용자 원 질문 + (있다면) 컨텍스트 보강 결과를 하나의 문자열로 결합 - parts = [state["messages"][0].content] - if len(state["messages"]) > 1: - last_msg = state["messages"][-1] - last_content = ( - last_msg.content if hasattr(last_msg, "content") else str(last_msg) - ) - if isinstance(last_content, str) and last_content.strip(): - parts.append(last_content) - - combined_input = "\n\n---\n\n".join(parts) - searched_tables_json = json.dumps( - state["searched_tables"], ensure_ascii=False, indent=2 - ) - - res = query_maker_chain.invoke( - input={ - "user_input": combined_input, - "user_database_env": state["user_database_env"], - "searched_tables": searched_tables_json, - # 다이얼렉트 변수 전달 - "dialect_name": state.get("dialect_name", ""), - "supports_ilike": state.get("supports_ilike", False), - "dialect_hints": ", ".join(state.get("dialect_hints", [])), - } - ) - state["generated_query"] = res - state["messages"].append(res) - return state diff --git a/utils/llm/graph_utils/basic_graph.py b/utils/llm/graph_utils/basic_graph.py deleted file mode 100644 index e7c7ade..0000000 --- a/utils/llm/graph_utils/basic_graph.py +++ /dev/null @@ -1,49 +0,0 @@ -""" -기본 워크플로우를 위한 StateGraph 구성입니다. -GET_TABLE_INFO -> QUERY_MAKER 순서로 실행됩니다. -""" - -from langgraph.graph import END, StateGraph - -from utils.llm.graph_utils.base import ( - EVALUATE_DOCUMENT_SUITABILITY, - GET_TABLE_INFO, - QUERY_MAKER, - QUESTION_GATE, - QueryMakerState, - document_suitability_node, - get_table_info_node, - query_maker_node, - question_gate_node, -) - -# StateGraph 생성 및 구성 -builder = StateGraph(QueryMakerState) -builder.set_entry_point(QUESTION_GATE) - -# 노드 추가 -builder.add_node(QUESTION_GATE, question_gate_node) -builder.add_node(GET_TABLE_INFO, get_table_info_node) -builder.add_node(EVALUATE_DOCUMENT_SUITABILITY, document_suitability_node) -builder.add_node(QUERY_MAKER, query_maker_node) - - -def _route_after_gate(state: QueryMakerState): - return GET_TABLE_INFO - - -builder.add_conditional_edges( - QUESTION_GATE, - _route_after_gate, - { - GET_TABLE_INFO: GET_TABLE_INFO, - END: END, - }, -) - -# 기본 엣지 설정 -builder.add_edge(GET_TABLE_INFO, EVALUATE_DOCUMENT_SUITABILITY) -builder.add_edge(EVALUATE_DOCUMENT_SUITABILITY, QUERY_MAKER) - -# QUERY_MAKER 노드 후 종료 -builder.add_edge(QUERY_MAKER, END) diff --git a/utils/llm/graph_utils/enriched_graph.py b/utils/llm/graph_utils/enriched_graph.py deleted file mode 100644 index 703726d..0000000 --- a/utils/llm/graph_utils/enriched_graph.py +++ /dev/null @@ -1,57 +0,0 @@ -""" -기본 워크플로우에 '프로파일 추출(PROFILE_EXTRACTION)'과 '컨텍스트 보강(CONTEXT_ENRICHMENT)'를 -추가한 확장된 그래프입니다. -""" - -from langgraph.graph import END, StateGraph - -from utils.llm.graph_utils.base import ( - CONTEXT_ENRICHMENT, - EVALUATE_DOCUMENT_SUITABILITY, - GET_TABLE_INFO, - PROFILE_EXTRACTION, - QUERY_MAKER, - QUESTION_GATE, - QueryMakerState, - context_enrichment_node, - document_suitability_node, - get_table_info_node, - profile_extraction_node, - query_maker_node, - question_gate_node, -) - -# StateGraph 생성 및 구성 -builder = StateGraph(QueryMakerState) -builder.set_entry_point(QUESTION_GATE) - -# 노드 추가 -builder.add_node(QUESTION_GATE, question_gate_node) -builder.add_node(GET_TABLE_INFO, get_table_info_node) -builder.add_node(EVALUATE_DOCUMENT_SUITABILITY, document_suitability_node) -builder.add_node(PROFILE_EXTRACTION, profile_extraction_node) -builder.add_node(CONTEXT_ENRICHMENT, context_enrichment_node) -builder.add_node(QUERY_MAKER, query_maker_node) - - -def _route_after_gate(state: QueryMakerState): - return GET_TABLE_INFO - - -builder.add_conditional_edges( - QUESTION_GATE, - _route_after_gate, - { - GET_TABLE_INFO: GET_TABLE_INFO, - END: END, - }, -) - -# 기본 엣지 설정 -builder.add_edge(GET_TABLE_INFO, EVALUATE_DOCUMENT_SUITABILITY) -builder.add_edge(EVALUATE_DOCUMENT_SUITABILITY, PROFILE_EXTRACTION) -builder.add_edge(PROFILE_EXTRACTION, CONTEXT_ENRICHMENT) -builder.add_edge(CONTEXT_ENRICHMENT, QUERY_MAKER) - -# QUERY_MAKER 노드 후 종료 -builder.add_edge(QUERY_MAKER, END) diff --git a/utils/llm/graph_utils/profile_utils.py b/utils/llm/graph_utils/profile_utils.py deleted file mode 100644 index 2057b5c..0000000 --- a/utils/llm/graph_utils/profile_utils.py +++ /dev/null @@ -1,17 +0,0 @@ -def profile_to_text(profile_obj) -> str: - mapping = { - "is_timeseries": "• 시계열 분석 필요", - "is_aggregation": "• 집계 함수 필요", - "has_filter": "• WHERE 조건 필요", - "is_grouped": "• GROUP BY 필요", - "has_ranking": "• 정렬/순위 필요", - "has_temporal_comparison": "• 기간 비교 필요", - } - bullets = [ - text for field, text in mapping.items() if getattr(profile_obj, field, False) - ] - intent = getattr(profile_obj, "intent_type", None) - if intent: - bullets.append(f"• 의도 유형 → {intent}") - - return "\n".join(bullets) diff --git a/utils/llm/output_schema/README.md b/utils/llm/output_schema/README.md deleted file mode 100644 index 2e220c2..0000000 --- a/utils/llm/output_schema/README.md +++ /dev/null @@ -1,114 +0,0 @@ -# output_schema 모듈 - -LLM 구조화 출력을 위한 Pydantic 모델 정의 모듈입니다. - -## 디렉토리 구조 - -``` -output_schema/ -├── __pycache__/ -├── document_suitability.py -└── question_suitability.py -``` - -## 파일 목록 및 설명 - -### document_suitability.py - -**목적**: LLM 구조화 출력으로부터 테이블별 적합성 평가 결과를 표현하는 Pydantic 모델을 정의합니다. - -**주요 클래스**: - -- `DocumentSuitability`: 단일 테이블에 대한 적합성 평가 결과를 표현하는 모델 - - `table_name` (str): 테이블명 - - `score` (float): 0.0~1.0 사이의 적합도 점수 - - `reason` (str): 한국어 한두 문장 근거 - - `matched_columns` (List[str]): 질문과 직접 연관된 컬럼명 목록 - - `missing_entities` (List[str]): 부족한 엔티티/지표/기간 등 - -- `DocumentSuitabilityList`: 문서 적합성 평가 결과 리스트 래퍼 - - `results` (List[DocumentSuitability]): 평가 결과 목록 - - OpenAI Structured Outputs 호환을 위해 명시적 최상위 키(`results`)를 제공 - -### question_suitability.py - -**목적**: LLM 구조화 출력으로부터 SQL 적합성 판단 결과를 표현하는 Pydantic 모델을 정의합니다. - -**주요 클래스**: - -- `QuestionSuitability`: SQL 생성 적합성 결과 모델 - - `reason` (str): 보완/설명 사유 요약 - - `missing_entities` (list[str]): 질문에서 누락된 핵심 엔터티/기간 등 - - `requires_data_science` (bool): SQL을 넘어 ML/통계 분석이 필요한지 여부 - -## 사용 방법 - -### Import 및 사용 위치 - -이 모듈의 클래스들은 `utils/llm/chains.py`에서 import되어 사용됩니다: - -```python -from utils.llm.output_schema.document_suitability import DocumentSuitabilityList -from utils.llm.output_schema.question_suitability import QuestionSuitability -``` - -### 사용 예시 - -#### 1. QuestionSuitability 사용 - -`create_question_gate_chain()` 함수에서 질문 적합성을 판단하는 체인을 생성할 때 사용됩니다: - -```python -def create_question_gate_chain(llm): - """ - 질문 적합성(Question Gate) 체인을 생성합니다. - - Returns: - Runnable: invoke({"question": str}) -> QuestionSuitability - """ - prompt = get_prompt_template("question_gate_prompt") - gate_prompt = ChatPromptTemplate.from_messages( - [SystemMessagePromptTemplate.from_template(prompt)] - ) - return gate_prompt | llm.with_structured_output(QuestionSuitability) -``` - -**사용 흐름**: -1. 사용자 질문을 입력으로 받음 -2. LLM이 구조화된 출력으로 `QuestionSuitability` 객체를 반환 -3. SQL 생성이 적합한지 여부와 필요 보완 사항을 판단 - -#### 2. DocumentSuitabilityList 사용 - -`create_document_suitability_chain()` 함수에서 문서(테이블) 적합성을 평가하는 체인을 생성할 때 사용됩니다: - -```python -def create_document_suitability_chain(llm): - """ - 문서 적합성 평가 체인을 생성합니다. - - Returns: - Runnable: invoke({"question": str, "tables": dict}) -> {"results": DocumentSuitability[]} - """ - prompt = get_prompt_template("document_suitability_prompt") - doc_prompt = ChatPromptTemplate.from_messages( - [SystemMessagePromptTemplate.from_template(prompt)] - ) - return doc_prompt | llm.with_structured_output(DocumentSuitabilityList) -``` - -**사용 흐름**: -1. 사용자 질문과 검색된 테이블 메타데이터를 입력으로 받음 -2. LLM이 각 테이블에 대한 적합도 점수와 평가 결과를 포함한 `DocumentSuitabilityList` 객체를 반환 -3. 가장 적합한 테이블을 선택하거나 적합도가 낮은 경우 사용자에게 알림 - -### 구조화 출력 활용 - -두 모델 모두 LangChain의 `with_structured_output()` 메서드와 함께 사용되어 LLM의 출력을 자동으로 Pydantic 모델로 변환합니다. 이를 통해: - -- 타입 안전성 보장 -- 자동 검증 및 직렬화 -- 명확한 API 계약 - -을 제공합니다. - diff --git a/utils/llm/output_schema/document_suitability.py b/utils/llm/output_schema/document_suitability.py deleted file mode 100644 index 7b4c11a..0000000 --- a/utils/llm/output_schema/document_suitability.py +++ /dev/null @@ -1,36 +0,0 @@ -""" -DocumentSuitability 출력 모델. - -LLM 구조화 출력으로부터 테이블별 적합성 평가 결과를 표현하는 Pydantic 모델입니다. -최상위는 테이블명(string) -> 평가 객체 매핑을 담는 Root 모델입니다. -""" - -from typing import List - -from pydantic import BaseModel, Field - - -class DocumentSuitability(BaseModel): - """ - 단일 테이블에 대한 적합성 평가 결과. - """ - - table_name: str = Field(description="테이블명") - score: float = Field(description="0.0~1.0 사이의 적합도 점수") - reason: str = Field(description="한국어 한두 문장 근거") - matched_columns: List[str] = Field( - default_factory=list, description="질문과 직접 연관된 컬럼명 목록" - ) - missing_entities: List[str] = Field( - default_factory=list, description="부족한 엔티티/지표/기간 등" - ) - - -class DocumentSuitabilityList(BaseModel): - """ - 문서 적합성 평가 결과 리스트 래퍼. - - OpenAI Structured Outputs 호환을 위해 명시적 최상위 키(`results`)를 둡니다. - """ - - results: List[DocumentSuitability] = Field(description="평가 결과 목록") diff --git a/utils/llm/output_schema/question_suitability.py b/utils/llm/output_schema/question_suitability.py deleted file mode 100644 index 210a307..0000000 --- a/utils/llm/output_schema/question_suitability.py +++ /dev/null @@ -1,23 +0,0 @@ -""" -QuestionSuitability 출력 모델. - -LLM 구조화 출력으로부터 SQL 적합성 판단 결과를 표현하는 Pydantic 모델입니다. -""" - -from pydantic import BaseModel, Field - - -class QuestionSuitability(BaseModel): - """ - SQL 생성 적합성 결과 모델. - - LLM 구조화 출력으로 직렬화 가능한 필드를 정의합니다. - """ - - reason: str = Field(description="보완/설명 사유 요약") - missing_entities: list[str] = Field( - default_factory=list, description="질문에서 누락된 핵심 엔터티/기간 등" - ) - requires_data_science: bool = Field( - default=False, description="SQL을 넘어 ML/통계 분석이 필요한지 여부" - ) diff --git a/utils/llm/retrieval.py b/utils/llm/retrieval.py deleted file mode 100644 index 0b5d916..0000000 --- a/utils/llm/retrieval.py +++ /dev/null @@ -1,104 +0,0 @@ -import os - -from langchain.retrievers import ContextualCompressionRetriever -from langchain.retrievers.document_compressors import CrossEncoderReranker -from langchain_community.cross_encoders import HuggingFaceCrossEncoder -from transformers import AutoModelForSequenceClassification, AutoTokenizer - -from utils.llm.vectordb import get_vector_db - - -def load_reranker_model(device: str = "cpu"): - """한국어 reranker 모델을 로드하거나 다운로드합니다.""" - local_model_path = os.path.join(os.getcwd(), "ko_reranker_local") - - # 로컬에 저장된 모델이 있으면 불러오고, 없으면 다운로드 후 저장 - if os.path.exists(local_model_path) and os.path.isdir(local_model_path): - print("🔄 ko-reranker 모델 로컬에서 로드 중...") - else: - print("⬇️ ko-reranker 모델 다운로드 및 저장 중...") - model = AutoModelForSequenceClassification.from_pretrained( - "Dongjin-kr/ko-reranker" - ) - tokenizer = AutoTokenizer.from_pretrained("Dongjin-kr/ko-reranker") - model.save_pretrained(local_model_path) - tokenizer.save_pretrained(local_model_path) - - return HuggingFaceCrossEncoder( - model_name=local_model_path, - model_kwargs={"device": device}, - ) - - -def get_retriever(retriever_name: str = "기본", top_n: int = 5, device: str = "cpu"): - """검색기 타입에 따라 적절한 검색기를 생성합니다. - - Args: - retriever_name: 사용할 검색기 이름 ("기본", "재순위", 등) - top_n: 반환할 상위 결과 개수 - """ - print(device) - retrievers = { - "기본": lambda: get_vector_db().as_retriever(search_kwargs={"k": top_n}), - "Reranker": lambda: ContextualCompressionRetriever( - base_compressor=CrossEncoderReranker( - model=load_reranker_model(device), top_n=top_n - ), - base_retriever=get_vector_db().as_retriever(search_kwargs={"k": top_n}), - ), - } - - if retriever_name not in retrievers: - print( - f"경고: '{retriever_name}' 검색기를 찾을 수 없습니다. 기본 검색기를 사용합니다." - ) - retriever_name = "기본" - - return retrievers[retriever_name]() - - -def search_tables( - query: str, retriever_name: str = "기본", top_n: int = 5, device: str = "cpu" -): - """쿼리에 맞는 테이블 정보를 검색합니다.""" - if retriever_name == "기본": - db = get_vector_db() - doc_res = db.similarity_search(query, k=top_n) - else: - retriever = get_retriever( - retriever_name=retriever_name, top_n=top_n, device=device - ) - doc_res = retriever.invoke(query) - - # 결과를 사전 형태로 변환 - documents_dict = {} - for doc in doc_res: - lines = doc.page_content.split("\n") - - # 테이블명 및 설명 추출 - table_name, table_desc = lines[0].split(": ", 1) - - # 섹션별로 정보 추출 (테이블/컬럼만 사용) - columns = {} - current_section = None - - for i, line in enumerate(lines[1:], 1): - line = line.strip() - - # 섹션 헤더 확인 - if line == "Columns:": - current_section = "columns" - continue - - # 각 섹션의 내용 파싱 - if current_section == "columns" and ": " in line: - col_name, col_desc = line.split(": ", 1) - columns[col_name.strip()] = col_desc.strip() - - # 딕셔너리 저장 - documents_dict[table_name] = { - "table_description": table_desc.strip(), - **columns, # 컬럼 정보 추가 - } - - return documents_dict diff --git a/utils/llm/tools/chatbot_tool.py b/utils/llm/tools/chatbot_tool.py index 9c496f0..1c3aad8 100644 --- a/utils/llm/tools/chatbot_tool.py +++ b/utils/llm/tools/chatbot_tool.py @@ -58,11 +58,26 @@ def search_database_tables( - "사용 가능한 컬럼을 보여줘" - SQL 쿼리를 생성하기 전에 스키마 정보가 필요할 때 """ - from utils.llm.retrieval import search_tables + try: + import os + + from lang2sql.components.retrieval.keyword_retriever import KeywordRetriever + from lang2sql.integrations.catalog.datahub_ import DataHubCatalogLoader - return search_tables( - query=query, retriever_name=retriever_name, top_n=top_n, device=device - ) + gms_server = os.getenv("DATAHUB_SERVER", "http://localhost:8080") + loader = DataHubCatalogLoader(gms_server=gms_server) + catalog = loader.load() + retriever = KeywordRetriever(catalog=catalog) + results = retriever.run(query, top_k=top_n) + return { + entry["name"]: { + "table_description": entry.get("description", ""), + **entry.get("columns", {}), + } + for entry in results + } + except Exception as e: + return {"error": True, "message": f"테이블 검색 중 오류 발생: {str(e)}"} def _simplify_glossary_data(glossary_data): diff --git a/utils/llm/vectordb/README.md b/utils/llm/vectordb/README.md deleted file mode 100644 index 356ffcd..0000000 --- a/utils/llm/vectordb/README.md +++ /dev/null @@ -1,223 +0,0 @@ -## utils.llm.vectordb 개요 - -Lang2SQL 파이프라인에서 테이블 메타데이터를 벡터화하여 저장하고 검색하기 위한 벡터 데이터베이스 모듈입니다. FAISS와 pgvector 두 가지 백엔드를 지원하며, 환경변수를 통해 선택이 가능합니다. - -### 파일 구조 - -``` -utils/llm/vectordb/ -├── __init__.py # 모듈 진입점, get_vector_db 함수 export -├── factory.py # VectorDB 팩토리 - 타입에 따라 적절한 인스턴스 생성 -├── faiss_db.py # FAISS 벡터 데이터베이스 구현 -└── pgvector_db.py # pgvector (PostgreSQL) 벡터 데이터베이스 구현 -``` - -### 각 파일 상세 설명 - -#### __init__.py - -**목적**: 벡터DB 모듈의 공개 인터페이스 정의 - -**Export 함수**: -- `get_vector_db`: 환경변수 기반으로 적절한 벡터DB 인스턴스 반환 - -#### factory.py - -**목적**: VectorDB 타입과 위치에 따라 적절한 VectorDB 인스턴스를 생성하는 팩토리 - -**주요 함수**: - -1. **`get_vector_db(vectordb_type=None, vectordb_location=None)`** - - 환경변수 또는 파라미터로 VectorDB 타입과 위치를 받아 적절한 인스턴스 반환 - - `vectordb_type`: "faiss" 또는 "pgvector" (기본: 환경변수 `VECTORDB_TYPE`, fallback: "faiss") - - `vectordb_location`: - - FAISS: 디렉토리 경로 - - pgvector: PostgreSQL 연결 문자열 - - 기본: 환경변수 `VECTORDB_LOCATION` - - 반환: FAISS 또는 PGVector 인스턴스 - - 에러: 지원하지 않는 타입 시 ValueError 발생 - -**의존성**: -- `utils.llm.vectordb.faiss_db.get_faiss_vector_db`: FAISS 인스턴스 생성 -- `utils.llm.vectordb.pgvector_db.get_pgvector_db`: PGVector 인스턴스 생성 - -#### faiss_db.py - -**목적**: FAISS 벡터 데이터베이스 구현 (로컬 디스크 기반) - -**주요 함수**: - -1. **`get_faiss_vector_db(vectordb_path=None)`** - - FAISS 벡터 데이터베이스를 로드하거나 새로 생성 - - `vectordb_path`: 저장 경로 (기본: `dev/table_info_db`) - - 동작 방식: - - 기존 DB가 있으면 `FAISS.load_local()`로 로드 - - 없으면 `get_info_from_db()`로 문서 수집 후 `FAISS.from_documents()` 생성 및 저장 - - 반환: FAISS 벡터스토어 인스턴스 - -**의존성**: -- `langchain_community.vectorstores.FAISS`: LangChain FAISS 래퍼 -- `utils.llm.core.get_embeddings`: 임베딩 모델 로드 -- `utils.llm.tools.get_info_from_db`: DataHub에서 테이블 메타데이터 수집 - -**특징**: -- 로컬 디스크에 저장되어 네트워크 연결 불필요 -- 빠른 검색 성능 -- 싱글 머신 환경에 최적화 - -#### pgvector_db.py - -**목적**: pgvector를 활용한 PostgreSQL 벡터 데이터베이스 구현 - -**주요 함수**: - -1. **`get_pgvector_db(connection_string=None, collection_name=None)`** - - pgvector 벡터 데이터베이스를 로드하거나 새로 생성 - - `connection_string`: PostgreSQL 연결 문자열 (기본: 환경변수 조합) - - `collection_name`: 컬렉션 이름 (기본: `lang2sql_table_info_db`) - - 환경변수 (기본값): - - `PGVECTOR_HOST`: "localhost" - - `PGVECTOR_PORT`: "5432" - - `PGVECTOR_USER`: "postgres" - - `PGVECTOR_PASSWORD`: "postgres" - - `PGVECTOR_DATABASE`: "postgres" - - `PGVECTOR_COLLECTION`: "lang2sql_table_info_db" - - 동작 방식: - - 기존 컬렉션이 있고 비어있지 않으면 로드 - - 없거나 비어있으면 `get_info_from_db()`로 문서 수집 후 `PGVector.from_documents()` 생성 - - 반환: PGVector 벡터스토어 인스턴스 - -2. **`_check_collection_exists(connection_string, collection_name)`** - - PostgreSQL에서 컬렉션 존재 여부 확인 - - `langchain_pg_embedding` 테이블에서 collection_name 조회 - - 반환: bool (존재 여부) - -**의존성**: -- `langchain_postgres.vectorstores.PGVector`: LangChain pgvector 래퍼 -- `psycopg2`: PostgreSQL 연결 -- `utils.llm.core.get_embeddings`: 임베딩 모델 로드 -- `utils.llm.tools.get_info_from_db`: DataHub에서 테이블 메타데이터 수집 - -**특징**: -- PostgreSQL 데이터베이스에 저장되어 다중 서버 환경에 적합 -- ACID 트랜잭션 지원 -- 확장 가능한 인프라 -- 네트워크 연결 필요 - -### 사용 방법 - -#### 1. 기본 사용법 (retrieval.py에서 실제 사용) - -```python -from utils.llm.vectordb import get_vector_db - -# 환경변수 기반으로 적절한 벡터DB 로드 -db = get_vector_db() - -# 유사도 검색 -documents = db.similarity_search("고객 테이블", k=5) - -# Retriever 인터페이스 사용 -retriever = db.as_retriever(search_kwargs={"k": 5}) -results = retriever.invoke("매출 관련 테이블") -``` - -#### 2. FAISS 명시적 사용 - -```python -from utils.llm.vectordb.factory import get_vector_db - -# FAISS 타입 지정 -db = get_vector_db(vectordb_type="faiss", vectordb_location="./my_faiss_db") - -# 검색 -results = db.similarity_search("사용자 정보", k=10) -``` - -#### 3. pgvector 명시적 사용 - -```python -from utils.llm.vectordb.factory import get_vector_db - -# pgvector 타입 지정 -connection_string = "postgresql://user:password@localhost:5432/mydb" -db = get_vector_db(vectordb_type="pgvector", vectordb_location=connection_string) - -# 검색 -results = db.similarity_search("주문 테이블", k=5) -``` - -#### 4. 통합 흐름 (Lang2SQL 파이프라인 내) - -`utils/llm/retrieval.py`의 `search_tables()` 함수에서 사용: - -1. `get_vector_db()`로 벡터DB 로드 (환경변수 기반) -2. `similarity_search()` 또는 `retriever.invoke()`로 유사도 기반 검색 -3. 결과를 테이블/컬럼 정보 딕셔너리로 파싱 및 반환 - -**경로**: `utils/llm/retrieval.py` (60-104번째 줄) - -#### 5. CLI 환경변수 설정 - -```bash -# FAISS 사용 -export VECTORDB_TYPE=faiss -export VECTORDB_LOCATION=./dev/table_info_db # 선택사항 - -# pgvector 사용 -export VECTORDB_TYPE=pgvector -export PGVECTOR_HOST=localhost -export PGVECTOR_PORT=5432 -export PGVECTOR_USER=postgres -export PGVECTOR_PASSWORD=postgres -export PGVECTOR_DATABASE=postgres -export PGVECTOR_COLLECTION=lang2sql_table_info_db -``` - -### import 관계 - -**import하는 파일**: -- `utils/llm/retrieval.py`: `from utils.llm.vectordb import get_vector_db` - -**내부 의존성**: -- `utils/llm/core/factory.py`: `get_embeddings()` - 임베딩 모델 로드 -- `utils/llm/tools/datahub.py`: `get_info_from_db()` - DataHub 메타데이터 수집 - -**외부 의존성**: -- `langchain_community.vectorstores.FAISS`: FAISS 벡터스토어 -- `langchain_postgres.vectorstores.PGVector`: pgvector 벡터스토어 -- `psycopg2`: PostgreSQL 연결 (pgvector 전용) - -### 환경 변수 요약 - -#### VectorDB 타입 선택 -- **`VECTORDB_TYPE`**: "faiss" 또는 "pgvector" (기본: "faiss") - -#### FAISS 환경변수 -- **`VECTORDB_LOCATION`**: FAISS 저장 디렉토리 경로 (기본: `./dev/table_info_db`) - -#### pgvector 환경변수 -- **`PGVECTOR_HOST`**: PostgreSQL 호스트 (기본: "localhost") -- **`PGVECTOR_PORT`**: PostgreSQL 포트 (기본: "5432") -- **`PGVECTOR_USER`**: PostgreSQL 사용자 (기본: "postgres") -- **`PGVECTOR_PASSWORD`**: PostgreSQL 비밀번호 (기본: "postgres") -- **`PGVECTOR_DATABASE`**: PostgreSQL 데이터베이스 (기본: "postgres") -- **`PGVECTOR_COLLECTION`**: 컬렉션 이름 (기본: "lang2sql_table_info_db") -- **`EMBEDDING_PROVIDER`**: 임베딩 모델 공급자 (필수, 모든 타입 공통) - -### 주요 특징 - -1. **이중 백엔드 지원**: FAISS(로컬) 및 pgvector(PostgreSQL) 자유 선택 -2. **자동 초기화**: 벡터DB가 없으면 DataHub에서 자동으로 생성 -3. **환경변수 기반 설정**: 코드 수정 없이 실행 시점에 선택 가능 -4. **LangChain 통합**: 표준 VectorStore 인터페이스 제공 -5. **유사도 검색**: 테이블 메타데이터에 대한 의미 기반 검색 - -### 개선 가능 영역 - -- 다른 벡터DB 지원 (Qdrant 등) -- 증분 인덱싱 지원 -- 벡터DB 버전 관리 -- 성능 최적화 (인덱스 튜닝) -- 모니터링 및 로깅 강화 - diff --git a/utils/llm/vectordb/__init__.py b/utils/llm/vectordb/__init__.py deleted file mode 100644 index 6265b0f..0000000 --- a/utils/llm/vectordb/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -""" -VectorDB 모듈 - FAISS와 pgvector를 지원하는 벡터 데이터베이스 추상화 -""" - -from utils.llm.vectordb.factory import get_vector_db - -__all__ = ["get_vector_db"] diff --git a/utils/llm/vectordb/factory.py b/utils/llm/vectordb/factory.py deleted file mode 100644 index 942a443..0000000 --- a/utils/llm/vectordb/factory.py +++ /dev/null @@ -1,38 +0,0 @@ -""" -VectorDB 팩토리 모듈 - 환경 변수에 따라 적절한 VectorDB 인스턴스를 생성 -""" - -import os -from typing import Optional - -from utils.llm.vectordb.faiss_db import get_faiss_vector_db -from utils.llm.vectordb.pgvector_db import get_pgvector_db - - -def get_vector_db( - vectordb_type: Optional[str] = None, vectordb_location: Optional[str] = None -): - """ - VectorDB 타입과 위치에 따라 적절한 VectorDB 인스턴스를 반환합니다. - - Args: - vectordb_type: VectorDB 타입 ("faiss" 또는 "pgvector"). None인 경우 환경 변수에서 읽음. - vectordb_location: VectorDB 위치 (FAISS: 디렉토리 경로, pgvector: 연결 문자열). None인 경우 환경 변수에서 읽음. - - Returns: - VectorDB 인스턴스 (FAISS 또는 PGVector) - """ - if vectordb_type is None: - vectordb_type = os.getenv("VECTORDB_TYPE", "faiss").lower() - - if vectordb_location is None: - vectordb_location = os.getenv("VECTORDB_LOCATION") - - if vectordb_type == "faiss": - return get_faiss_vector_db(vectordb_location) - elif vectordb_type == "pgvector": - return get_pgvector_db(vectordb_location) - else: - raise ValueError( - f"지원하지 않는 VectorDB 타입: {vectordb_type}. 'faiss' 또는 'pgvector'를 사용하세요." - ) diff --git a/utils/llm/vectordb/faiss_db.py b/utils/llm/vectordb/faiss_db.py deleted file mode 100644 index d4754a5..0000000 --- a/utils/llm/vectordb/faiss_db.py +++ /dev/null @@ -1,33 +0,0 @@ -""" -FAISS VectorDB 구현 -""" - -import os -from typing import Optional - -from langchain_community.vectorstores import FAISS - -from utils.llm.core import get_embeddings -from utils.llm.tools import get_info_from_db - - -def get_faiss_vector_db(vectordb_path: Optional[str] = None): - """FAISS 벡터 데이터베이스를 로드하거나 생성합니다.""" - embeddings = get_embeddings() - - # 기본 경로 설정 - if vectordb_path is None: - vectordb_path = os.path.join(os.getcwd(), "dev/table_info_db") - - try: - db = FAISS.load_local( - vectordb_path, - embeddings, - allow_dangerous_deserialization=True, - ) - except: - documents = get_info_from_db() - db = FAISS.from_documents(documents, embeddings) - db.save_local(vectordb_path) - print(f"VectorDB를 새로 생성했습니다: {vectordb_path}") - return db diff --git a/utils/llm/vectordb/pgvector_db.py b/utils/llm/vectordb/pgvector_db.py deleted file mode 100644 index d03f034..0000000 --- a/utils/llm/vectordb/pgvector_db.py +++ /dev/null @@ -1,81 +0,0 @@ -""" -pgvector VectorDB 구현 -""" - -import os -from typing import Optional - -import psycopg2 -from langchain_postgres.vectorstores import PGVector - -from utils.llm.core import get_embeddings -from utils.llm.tools import get_info_from_db - - -def _check_collection_exists(connection_string: str, collection_name: str) -> bool: - """PostgreSQL에서 collection이 존재하는지 확인합니다.""" - try: - # 연결 문자열에서 연결 정보 추출 - conn = psycopg2.connect(connection_string) - cursor = conn.cursor() - - # langchain_pg_embedding 테이블에서 collection_name이 존재하는지 확인 - cursor.execute( - "SELECT COUNT(*) FROM langchain_pg_embedding WHERE collection_name = %s", - (collection_name,), - ) - result = cursor.fetchone() - count = result[0] if result else 0 - - cursor.close() - conn.close() - - return count > 0 - except Exception as e: - print(f"Collection 존재 여부 확인 중 오류: {e}") - return False - - -def get_pgvector_db( - connection_string: Optional[str] = None, collection_name: Optional[str] = None -): - """pgvector 벡터 데이터베이스를 로드하거나 생성합니다.""" - embeddings = get_embeddings() - - if connection_string is None: - # 환경 변수에서 연결 정보 읽기 (기존 방식) - host = os.getenv("PGVECTOR_HOST", "localhost") - port = os.getenv("PGVECTOR_PORT", "5432") - user = os.getenv("PGVECTOR_USER", "postgres") - password = os.getenv("PGVECTOR_PASSWORD", "postgres") - database = os.getenv("PGVECTOR_DATABASE", "postgres") - connection_string = f"postgresql://{user}:{password}@{host}:{port}/{database}" - - if collection_name is None: - collection_name = os.getenv("PGVECTOR_COLLECTION", "lang2sql_table_info_db") - try: - vector_store = PGVector( - embeddings=embeddings, - collection_name=collection_name, - connection=connection_string, - ) - - results = vector_store.similarity_search("test", k=1) - if not results: - raise RuntimeError(f"Collection '{collection_name}' is empty") - - # 컬렉션이 존재하면 실제 검색도 진행해 볼 수 있습니다. - vector_store.similarity_search("test", k=1) - return vector_store - - except Exception as e: - print(f"exception: {e}") - # 컬렉션이 없거나 불러오기에 실패한 경우, 문서를 다시 인덱싱 - documents = get_info_from_db() - vector_store = PGVector.from_documents( - documents=documents, - embedding=embeddings, - connection=connection_string, - collection_name=collection_name, - ) - return vector_store From 8099c5a3f0ab0ec0251b299fc1b14c3c08ace2af Mon Sep 17 00:00:00 2001 From: seyeong Date: Sat, 28 Feb 2026 15:33:12 +0900 Subject: [PATCH 09/10] docs: add production test guide for v2 migration --- docs/production_test_guide.md | 1003 +++++++++++++++++++++++++++++++++ 1 file changed, 1003 insertions(+) create mode 100644 docs/production_test_guide.md diff --git a/docs/production_test_guide.md b/docs/production_test_guide.md new file mode 100644 index 0000000..e40bb2a --- /dev/null +++ b/docs/production_test_guide.md @@ -0,0 +1,1003 @@ +# Production Test Guide — v2 Migration + +이 문서는 v2 마이그레이션에서 수행된 모든 변경 사항을 **실제 API 키와 실제 DB**를 사용해 프로덕션 수준에서 검증하는 가이드입니다. + +--- + +## 전제 조건 + +```bash +# 의존성 설치 +uv sync --group dev + +# .env 설정 (아래 각 섹션에서 사용할 프로바이더 항목을 활성화) +cp .env.example .env +``` + +모든 Python 스니펫은 프로젝트 루트에서 실행합니다: + +```bash +cd /path/to/lang2sql +``` + +--- + +## 1. LLM 통합 — 7개 프로바이더 + +각 프로바이더는 `.env`에서 해당 항목을 설정하고 독립적으로 검증합니다. + +### 1-A. Anthropic + +``` +# .env +LLM_PROVIDER=anthropic +ANTHROPIC_API_KEY=sk-ant-... +ANTHROPIC_LLM_MODEL=claude-sonnet-4-6 +``` + +```python +from lang2sql.integrations.llm.anthropic_ import AnthropicLLM +import os + +llm = AnthropicLLM(model="claude-sonnet-4-6", api_key=os.getenv("ANTHROPIC_API_KEY")) +resp = llm.invoke([{"role": "user", "content": "Respond with just 'OK'"}]) +assert resp.strip() == "OK", f"Unexpected: {resp}" +print("Anthropic LLM ✓") +``` + +**확인 포인트** +- `invoke()` 반환값이 `str` 타입 +- system 메시지가 `role: system`으로 분리되어 Anthropic Messages API에 전달됨 + +--- + +### 1-B. OpenAI + +``` +# .env +LLM_PROVIDER=openai +OPEN_AI_KEY=sk-proj-... +OPEN_AI_LLM_MODEL=gpt-4o +``` + +```python +from lang2sql.integrations.llm.openai_ import OpenAILLM +import os + +llm = OpenAILLM(model="gpt-4o", api_key=os.getenv("OPEN_AI_KEY")) +resp = llm.invoke([{"role": "user", "content": "Respond with just 'OK'"}]) +assert isinstance(resp, str) and len(resp) > 0 +print("OpenAI LLM ✓") +``` + +--- + +### 1-C. Azure OpenAI + +``` +# .env +LLM_PROVIDER=azure +AZURE_OPENAI_LLM_ENDPOINT=https://RESOURCE.openai.azure.com/ +AZURE_OPENAI_LLM_KEY=... +AZURE_OPENAI_LLM_MODEL=gpt4o # Azure deployment name +AZURE_OPENAI_LLM_API_VERSION=2024-07-01-preview +``` + +```python +from lang2sql.integrations.llm.azure_ import AzureOpenAILLM +import os + +llm = AzureOpenAILLM( + azure_deployment=os.environ["AZURE_OPENAI_LLM_MODEL"], + azure_endpoint=os.environ["AZURE_OPENAI_LLM_ENDPOINT"], + api_version=os.getenv("AZURE_OPENAI_LLM_API_VERSION", "2024-07-01-preview"), + api_key=os.getenv("AZURE_OPENAI_LLM_KEY"), +) +resp = llm.invoke([{"role": "user", "content": "Respond with just 'OK'"}]) +assert isinstance(resp, str) and len(resp) > 0 +print("Azure OpenAI LLM ✓") +``` + +--- + +### 1-D. Google Gemini + +``` +# .env +LLM_PROVIDER=gemini +GEMINI_API_KEY=AIza... +GEMINI_LLM_MODEL=gemini-2.0-flash-lite +``` + +```python +from lang2sql.integrations.llm.gemini_ import GeminiLLM +import os + +llm = GeminiLLM(model="gemini-2.0-flash-lite", api_key=os.getenv("GEMINI_API_KEY")) +resp = llm.invoke([{"role": "user", "content": "Respond with just 'OK'"}]) +assert isinstance(resp, str) and len(resp) > 0 +print("Gemini LLM ✓") +``` + +--- + +### 1-E. AWS Bedrock + +``` +# .env +LLM_PROVIDER=bedrock +AWS_BEDROCK_LLM_ACCESS_KEY_ID=AKI... +AWS_BEDROCK_LLM_SECRET_ACCESS_KEY=... +AWS_BEDROCK_LLM_REGION=us-east-1 +AWS_BEDROCK_LLM_MODEL=anthropic.claude-3-5-sonnet-20241022-v2:0 +``` + +```python +from lang2sql.integrations.llm.bedrock_ import BedrockLLM +import os + +llm = BedrockLLM( + model=os.environ["AWS_BEDROCK_LLM_MODEL"], + aws_access_key_id=os.getenv("AWS_BEDROCK_LLM_ACCESS_KEY_ID"), + aws_secret_access_key=os.getenv("AWS_BEDROCK_LLM_SECRET_ACCESS_KEY"), + region_name=os.getenv("AWS_BEDROCK_LLM_REGION", "us-east-1"), +) +resp = llm.invoke([{"role": "user", "content": "Respond with just 'OK'"}]) +assert isinstance(resp, str) and len(resp) > 0 +print("Bedrock LLM ✓") +``` + +**확인 포인트**: Bedrock Converse API 포맷 — `role: system`이 `system` 블록으로 분리되는지 확인 + +```python +# system 메시지 분리 확인 +resp = llm.invoke([ + {"role": "system", "content": "Always respond in one word."}, + {"role": "user", "content": "Say hello"}, +]) +assert len(resp.split()) <= 3, f"System prompt not applied: {resp}" +print("Bedrock system message separation ✓") +``` + +--- + +### 1-F. Ollama (로컬) + +``` +# Ollama 서버 실행 필요 +# brew install ollama && ollama serve +# ollama pull llama3.2 + +# .env +LLM_PROVIDER=ollama +OLLAMA_LLM_BASE_URL=http://localhost:11434 +OLLAMA_LLM_MODEL=llama3.2 +``` + +```python +from lang2sql.integrations.llm.ollama_ import OllamaLLM +import os + +llm = OllamaLLM( + model=os.environ["OLLAMA_LLM_MODEL"], + base_url=os.getenv("OLLAMA_LLM_BASE_URL", "http://localhost:11434"), +) +resp = llm.invoke([{"role": "user", "content": "Say hello in one word"}]) +assert isinstance(resp, str) and len(resp) > 0 +print("Ollama LLM ✓") +``` + +--- + +### 1-G. HuggingFace Inference API + +``` +# .env +LLM_PROVIDER=huggingface +HUGGING_FACE_LLM_REPO_ID=mistralai/Mistral-7B-Instruct-v0.3 +HUGGING_FACE_LLM_API_TOKEN=hf_... +# HUGGING_FACE_LLM_ENDPOINT=https://... (Dedicated Endpoint 사용 시) +``` + +```python +from lang2sql.integrations.llm.huggingface_ import HuggingFaceLLM +import os + +llm = HuggingFaceLLM( + repo_id=os.getenv("HUGGING_FACE_LLM_REPO_ID"), + api_token=os.getenv("HUGGING_FACE_LLM_API_TOKEN"), +) +resp = llm.invoke([{"role": "user", "content": "Say hello"}]) +assert isinstance(resp, str) and len(resp) > 0 +print("HuggingFace LLM ✓") +``` + +--- + +## 2. Embedding 통합 — 6개 프로바이더 + +### 2-A. OpenAI Embedding + +```python +from lang2sql.integrations.embedding.openai_ import OpenAIEmbedding +import os + +emb = OpenAIEmbedding( + model="text-embedding-3-small", + api_key=os.getenv("OPEN_AI_KEY"), +) +vec = emb.embed_query("주문 테이블의 주문 ID") +assert isinstance(vec, list) and len(vec) == 1536 +print(f"OpenAI Embedding ✓ (dim={len(vec)})") + +vecs = emb.embed_texts(["orders", "customers"]) +assert len(vecs) == 2 and len(vecs[0]) == 1536 +print("OpenAI batch embed ✓") +``` + +--- + +### 2-B. Azure OpenAI Embedding + +``` +# .env +EMBEDDING_PROVIDER=azure +AZURE_OPENAI_EMBEDDING_ENDPOINT=https://RESOURCE.openai.azure.com/ +AZURE_OPENAI_EMBEDDING_KEY=... +AZURE_OPENAI_EMBEDDING_MODEL=textembeddingada002 +AZURE_OPENAI_EMBEDDING_API_VERSION=2023-09-15-preview +``` + +```python +from lang2sql.integrations.embedding.azure_ import AzureOpenAIEmbedding +import os + +emb = AzureOpenAIEmbedding( + azure_deployment=os.environ["AZURE_OPENAI_EMBEDDING_MODEL"], + azure_endpoint=os.environ["AZURE_OPENAI_EMBEDDING_ENDPOINT"], + api_version=os.getenv("AZURE_OPENAI_EMBEDDING_API_VERSION"), + api_key=os.getenv("AZURE_OPENAI_EMBEDDING_KEY"), +) +vec = emb.embed_query("주문 데이터") +assert isinstance(vec, list) and len(vec) > 0 +print(f"Azure Embedding ✓ (dim={len(vec)})") +``` + +--- + +### 2-C. Ollama Embedding + +``` +# .env +EMBEDDING_PROVIDER=ollama +OLLAMA_EMBEDDING_MODEL=nomic-embed-text +OLLAMA_EMBEDDING_BASE_URL=http://localhost:11434 +``` + +```python +# ollama pull nomic-embed-text 먼저 실행 필요 +from lang2sql.integrations.embedding.ollama_ import OllamaEmbedding +import os + +emb = OllamaEmbedding( + model=os.getenv("OLLAMA_EMBEDDING_MODEL", "nomic-embed-text"), + base_url=os.getenv("OLLAMA_EMBEDDING_BASE_URL", "http://localhost:11434"), +) +vec = emb.embed_query("test") +assert isinstance(vec, list) and len(vec) > 0 +print(f"Ollama Embedding ✓ (dim={len(vec)})") +``` + +--- + +### 2-D. AWS Bedrock Embedding + +``` +# .env +EMBEDDING_PROVIDER=bedrock +AWS_BEDROCK_EMBEDDING_ACCESS_KEY_ID=... +AWS_BEDROCK_EMBEDDING_SECRET_ACCESS_KEY=... +AWS_BEDROCK_EMBEDDING_REGION=us-east-1 +AWS_BEDROCK_EMBEDDING_MODEL=amazon.titan-embed-text-v2:0 +``` + +```python +from lang2sql.integrations.embedding.bedrock_ import BedrockEmbedding +import os + +emb = BedrockEmbedding( + model_id=os.getenv("AWS_BEDROCK_EMBEDDING_MODEL", "amazon.titan-embed-text-v2:0"), + aws_access_key_id=os.getenv("AWS_BEDROCK_EMBEDDING_ACCESS_KEY_ID"), + aws_secret_access_key=os.getenv("AWS_BEDROCK_EMBEDDING_SECRET_ACCESS_KEY"), + region_name=os.getenv("AWS_BEDROCK_EMBEDDING_REGION", "us-east-1"), +) +vec = emb.embed_query("주문 데이터") +assert isinstance(vec, list) and len(vec) == 1024 # Titan v2 기본 차원 +print(f"Bedrock Embedding ✓ (dim={len(vec)})") +``` + +--- + +### 2-E. Google Gemini Embedding + +``` +# .env +EMBEDDING_PROVIDER=gemini +GEMINI_EMBEDDING_API_KEY=AIza... +EMBEDDING_MODEL=models/embedding-001 +``` + +```python +from lang2sql.integrations.embedding.gemini_ import GeminiEmbedding +import os + +emb = GeminiEmbedding( + model=os.getenv("EMBEDDING_MODEL", "models/embedding-001"), + api_key=os.getenv("GEMINI_EMBEDDING_API_KEY"), +) +vec = emb.embed_query("주문 데이터") +assert isinstance(vec, list) and len(vec) == 768 +print(f"Gemini Embedding ✓ (dim={len(vec)})") +``` + +--- + +### 2-F. HuggingFace Embedding (로컬 모델) + +``` +# .env +EMBEDDING_PROVIDER=huggingface +HUGGING_FACE_EMBEDDING_MODEL=sentence-transformers/all-MiniLM-L6-v2 +``` + +```python +# pip install sentence-transformers 필요 +from lang2sql.integrations.embedding.huggingface_ import HuggingFaceEmbedding +import os + +emb = HuggingFaceEmbedding( + model=os.getenv("HUGGING_FACE_EMBEDDING_MODEL", "sentence-transformers/all-MiniLM-L6-v2") +) +vec = emb.embed_query("주문 데이터") +assert isinstance(vec, list) and len(vec) == 384 # all-MiniLM-L6-v2 차원 +print(f"HuggingFace Embedding ✓ (dim={len(vec)})") +``` + +--- + +## 3. 환경변수 기반 Factory (`build_*_from_env`) + +`.env`에 원하는 프로바이더 설정을 넣고 아래를 실행합니다. + +```python +from dotenv import load_dotenv +load_dotenv() + +from lang2sql.factory import build_llm_from_env, build_embedding_from_env, build_db_from_env + +# LLM +llm = build_llm_from_env() +resp = llm.invoke([{"role": "user", "content": "Say 'ready'"}]) +assert isinstance(resp, str) +print(f"build_llm_from_env ✓ → {resp[:40]}") + +# Embedding +emb = build_embedding_from_env() +vec = emb.embed_query("test") +assert isinstance(vec, list) and len(vec) > 0 +print(f"build_embedding_from_env ✓ dim={len(vec)}") + +# DB +db = build_db_from_env() +# DB_TYPE=sqlite 인 경우 간단한 쿼리 실행 확인 +rows = db.execute("SELECT 1 AS val") +assert rows[0]["val"] == 1 +print("build_db_from_env ✓") +``` + +--- + +## 4. 고급 컴포넌트 — 실제 LLM 호출 + +아래 예제는 Anthropic LLM으로 작성됐으나 어떤 프로바이더든 사용 가능합니다. + +### 4-A. QuestionGate + +```python +from dotenv import load_dotenv; load_dotenv() +from lang2sql.factory import build_llm_from_env +from lang2sql.components.gate.question_gate import QuestionGate +from lang2sql.core.catalog import GateResult + +llm = build_llm_from_env() +gate = QuestionGate(llm=llm) + +# 정상 쿼리 → suitable=True +result: GateResult = gate("지난달 주문 건수를 알려줘") +assert result.suitable is True, f"suitable False: {result.reason}" +print(f"QuestionGate (suitable) ✓ — reason: {result.reason}") + +# 비적합 쿼리 → suitable=False +result2: GateResult = gate("회사 전략 보고서를 통계 모델로 분석해줘") +assert result2.suitable is False, "Expected unsuitable for data-science request" +print(f"QuestionGate (not suitable) ✓ — reason: {result2.reason}") +``` + +--- + +### 4-B. TableSuitabilityEvaluator + +```python +from lang2sql.factory import build_llm_from_env +from lang2sql.components.gate.table_suitability import TableSuitabilityEvaluator + +llm = build_llm_from_env() +evaluator = TableSuitabilityEvaluator(llm=llm) + +catalog = [ + {"name": "orders", "description": "주문 정보 테이블", "columns": {"order_id": "주문 ID", "amount": "금액", "created_at": "생성일"}}, + {"name": "users", "description": "사용자 정보 테이블", "columns": {"user_id": "유저 ID", "name": "이름"}}, + {"name": "products", "description": "상품 정보 테이블", "columns": {"product_id": "상품 ID", "price": "가격"}}, +] + +filtered = evaluator("지난달 주문 건수", catalog) +# orders 테이블은 반드시 포함되어야 함 +names = [t["name"] for t in filtered] +assert "orders" in names, f"orders not found in {names}" +print(f"TableSuitabilityEvaluator ✓ → {names}") +``` + +--- + +### 4-C. QuestionProfiler + +```python +from lang2sql.factory import build_llm_from_env +from lang2sql.components.enrichment.question_profiler import QuestionProfiler +from lang2sql.core.catalog import QuestionProfile + +llm = build_llm_from_env() +profiler = QuestionProfiler(llm=llm) + +profile: QuestionProfile = profiler("월별 주문 금액 추이") +assert hasattr(profile, "is_timeseries") +assert hasattr(profile, "intent_type") +print(f"QuestionProfiler ✓ — is_timeseries={profile.is_timeseries}, intent={profile.intent_type}") + +profile2: QuestionProfile = profiler("상위 10개 고객 목록") +assert hasattr(profile2, "has_ranking") +print(f"QuestionProfiler ✓ — has_ranking={profile2.has_ranking}") +``` + +--- + +### 4-D. ContextEnricher + +```python +from lang2sql.factory import build_llm_from_env +from lang2sql.components.enrichment.context_enricher import ContextEnricher +from lang2sql.core.catalog import QuestionProfile + +llm = build_llm_from_env() +enricher = ContextEnricher(llm=llm) + +catalog = [ + {"name": "orders", "description": "주문 정보", "columns": {"order_id": "주문 ID", "amount": "금액", "created_at": "생성일"}}, +] +profile = QuestionProfile(is_aggregation=True, has_filter=True, intent_type="lookup") +enriched = enricher("지난달 주문 건수", catalog, profile) + +assert isinstance(enriched, str) and len(enriched) > 0 +print(f"ContextEnricher ✓ — enriched: {enriched[:100]}") +``` + +--- + +## 5. HybridRetriever (BM25 + Vector RRF) + +```python +from dotenv import load_dotenv; load_dotenv() +from lang2sql.factory import build_embedding_from_env +from lang2sql.components.retrieval.hybrid import HybridRetriever + +emb = build_embedding_from_env() + +catalog = [ + {"name": "orders", "description": "주문 정보 테이블", "columns": {"order_id": "주문 ID", "amount": "금액", "created_at": "생성일"}}, + {"name": "customers", "description": "고객 정보 테이블", "columns": {"customer_id": "고객 ID", "name": "이름", "email": "이메일"}}, + {"name": "products", "description": "상품 정보 테이블", "columns": {"product_id": "상품 ID", "price": "가격"}}, + {"name": "inventory", "description": "재고 테이블", "columns": {"product_id": "상품 ID", "stock": "재고 수량"}}, +] + +retriever = HybridRetriever(catalog=catalog, embedding=emb, top_n=2) +result = retriever("지난달 주문 건수") + +assert len(result.schemas) <= 2 +names = [s["name"] for s in result.schemas] +assert "orders" in names, f"orders missing from {names}" +print(f"HybridRetriever ✓ → schemas={names}") + +# 비즈니스 문서 context 테스트 +from lang2sql.core.catalog import TextDocument +docs = [TextDocument(id="doc1", content="주문은 created_at 컬럼 기준으로 집계합니다.")] +retriever2 = HybridRetriever(catalog=catalog, embedding=emb, documents=docs, top_n=2) +result2 = retriever2("주문 날짜 기준 집계") +print(f"HybridRetriever with docs ✓ — context={result2.context}") +``` + +--- + +## 6. BaselineNL2SQL — End-to-End + +SQLite 예제 (가장 빠르게 검증 가능) + +```bash +# 테스트 DB 준비 +python - <<'EOF' +import sqlite3 +conn = sqlite3.connect("test_e2e.db") +conn.execute("CREATE TABLE IF NOT EXISTS orders (order_id INTEGER PRIMARY KEY, amount REAL, created_at TEXT)") +conn.execute("INSERT OR IGNORE INTO orders VALUES (1, 10000, '2024-01-15')") +conn.execute("INSERT OR IGNORE INTO orders VALUES (2, 20000, '2024-01-20')") +conn.execute("INSERT OR IGNORE INTO orders VALUES (3, 15000, '2024-02-05')") +conn.commit() +conn.close() +print("test_e2e.db 생성 완료") +EOF +``` + +```python +from dotenv import load_dotenv; load_dotenv() +from lang2sql.factory import build_llm_from_env +from lang2sql.flows import BaselineNL2SQL +from lang2sql.integrations.db.sqlalchemy_ import SQLAlchemyDB + +catalog = [ + { + "name": "orders", + "description": "주문 정보 테이블", + "columns": {"order_id": "주문 ID", "amount": "주문 금액(원)", "created_at": "주문 생성일(YYYY-MM-DD)"}, + } +] + +llm = build_llm_from_env() +db = SQLAlchemyDB("sqlite:///test_e2e.db") + +pipeline = BaselineNL2SQL(catalog=catalog, llm=llm, db=db, db_dialect="sqlite") +rows = pipeline.run("전체 주문 건수") + +assert isinstance(rows, list) and len(rows) > 0 +print(f"BaselineNL2SQL ✓ — rows={rows}") +``` + +--- + +## 7. EnrichedNL2SQL — End-to-End (Full 7-Step Pipeline) + +```python +from dotenv import load_dotenv; load_dotenv() +from lang2sql.factory import build_llm_from_env, build_embedding_from_env +from lang2sql.flows import EnrichedNL2SQL +from lang2sql.integrations.db.sqlalchemy_ import SQLAlchemyDB +from lang2sql.core.hooks import MemoryHook + +catalog = [ + { + "name": "orders", + "description": "주문 정보 테이블. 고객이 결제한 주문 기록.", + "columns": {"order_id": "주문 ID", "amount": "주문 금액(원)", "created_at": "주문 생성일(YYYY-MM-DD)"}, + } +] + +llm = build_llm_from_env() +emb = build_embedding_from_env() +db = SQLAlchemyDB("sqlite:///test_e2e.db") +hook = MemoryHook() + +pipeline = EnrichedNL2SQL( + catalog=catalog, + llm=llm, + db=db, + embedding=emb, + db_dialect="sqlite", + gate_enabled=True, + top_n=3, + hook=hook, +) + +rows = pipeline.run("전체 주문 건수를 알려줘") +assert isinstance(rows, list) and len(rows) > 0 +print(f"EnrichedNL2SQL ✓ — rows={rows}") + +# Hook 이벤트 확인 (QuestionGate ~ SQLExecutor까지 7단계 이벤트 발생 확인) +components = {e.component for e in hook.events} +print(f" → 실행된 컴포넌트: {components}") +assert "QuestionGate" in components +assert "HybridRetriever" in components +assert "SQLGenerator" in components +assert "SQLExecutor" in components +print(" → Hook 이벤트 ✓") +``` + +### 7-A. QuestionGate — ContractError 발생 확인 + +```python +from lang2sql.core.exceptions import ContractError +import pytest + +try: + pipeline.run("우리 회사 마케팅 전략을 ML 모델로 예측해줘") + print("WARNING: ContractError가 발생해야 합니다") +except ContractError as e: + print(f"ContractError ✓ — {e}") +``` + +### 7-B. gate_enabled=False — Gate 비활성화 확인 + +```python +pipeline_no_gate = EnrichedNL2SQL( + catalog=catalog, llm=llm, db=db, embedding=emb, + db_dialect="sqlite", gate_enabled=False, +) +rows2 = pipeline_no_gate.run("전체 주문 금액 합계") +assert isinstance(rows2, list) +print(f"EnrichedNL2SQL (no gate) ✓ — rows={rows2}") +``` + +--- + +## 8. CLI 명령어 + +`.env`가 올바르게 설정된 상태에서 실행합니다. + +### 8-A. Baseline 플로우 + +```bash +lang2sql query "전체 주문 건수" \ + --flow baseline \ + --dialect sqlite +``` + +**예상 출력**: JSON 배열 (결과 행) 또는 `(결과 없음)` + +--- + +### 8-B. Enriched 플로우 + +```bash +lang2sql query "지난 1월 주문 금액 합계" \ + --flow enriched \ + --dialect sqlite \ + --top-n 3 +``` + +--- + +### 8-C. Gate 비활성화 + +```bash +lang2sql query "전체 주문 건수" \ + --flow enriched \ + --no-gate \ + --dialect sqlite +``` + +--- + +### 8-D. 에러 케이스 확인 + +```bash +# LLM_PROVIDER를 잘못된 값으로 설정한 경우 +LLM_PROVIDER=unknown lang2sql query "test" +# 예상: ValueError: Unknown LLM_PROVIDER: 'unknown' +``` + +--- + +## 9. DataHub 카탈로그 브릿지 + +> DataHub GMS 서버가 실행 중이어야 합니다. + +``` +# .env +DATAHUB_SERVER=http://localhost:8080 +``` + +```python +import os +from dotenv import load_dotenv; load_dotenv() +from lang2sql.integrations.catalog.datahub_ import DataHubCatalogLoader + +loader = DataHubCatalogLoader(gms_server=os.getenv("DATAHUB_SERVER", "http://localhost:8080")) +catalog = loader.load() + +assert isinstance(catalog, list) +assert len(catalog) > 0, "DataHub에 테이블이 하나 이상 존재해야 합니다" + +first = catalog[0] +assert "name" in first and "description" in first and "columns" in first +print(f"DataHubCatalogLoader ✓ — {len(catalog)}개 테이블 로드") +print(f" 첫 번째: name={first['name']}, columns={list(first['columns'].keys())[:5]}") +``` + +### DataHub Catalog → EnrichedNL2SQL 연동 + +```python +from lang2sql.factory import build_llm_from_env, build_embedding_from_env, build_db_from_env + +llm = build_llm_from_env() +emb = build_embedding_from_env() +db = build_db_from_env() +pipeline = EnrichedNL2SQL( + catalog=catalog, # DataHub에서 로드한 catalog 사용 + llm=llm, db=db, embedding=emb, + gate_enabled=True, +) +rows = pipeline.run("유니크한 유저 수를 카운트해줘") +print(f"DataHub catalog + EnrichedNL2SQL ✓ — {rows}") +``` + +--- + +## 10. FAISSVectorStore (v2 벡터 스토어) + +```python +from dotenv import load_dotenv; load_dotenv() +from lang2sql.factory import build_embedding_from_env +from lang2sql.integrations.vectorstore.faiss_ import FAISSVectorStore + +emb = build_embedding_from_env() + +# 문서 임베딩 및 저장 +texts = ["주문 테이블: 고객 주문 정보를 저장합니다", "고객 테이블: 회원 정보를 저장합니다"] +vectors = emb.embed_texts(texts) + +store = FAISSVectorStore(index_path="/tmp/test_faiss.idx") +store.upsert(ids=["doc0", "doc1"], vectors=vectors) + +# 검색 +query_vec = emb.embed_query("주문 정보") +results = store.search(query_vec, k=2) + +assert len(results) > 0 +assert results[0][0] in ["doc0", "doc1"] +print(f"FAISSVectorStore ✓ — top result: {results[0]}") + +# 저장/로드 +store.save() +loaded = FAISSVectorStore.load("/tmp/test_faiss.idx") +results2 = loaded.search(query_vec, k=1) +assert results2[0][0] == results[0][0] +print("FAISSVectorStore save/load ✓") +``` + +--- + +## 11. Streamlit UI 수동 검증 + +```bash +lang2sql run-streamlit +# 또는 +streamlit run interface/streamlit_app.py +``` + +### 체크리스트 + +| 항목 | 확인 방법 | 통과 조건 | +|------|-----------|-----------| +| 홈 페이지 | `http://localhost:8501` 접속 | 에러 없이 로드 | +| Lang2SQL — Baseline | 워크플로우 체크박스 해제 → "쿼리 실행" | 결과 테이블 렌더링 | +| Lang2SQL — Enriched | 체크박스 선택 → "쿼리 실행" | 결과 테이블 렌더링 | +| Dialect 선택 | `sqlite` → `postgresql` 전환 | 드롭다운 변경 반영 | +| 오류 표시 | 연결 불가 DB 설정 후 실행 | `st.error()` 에러 박스 | +| ChatBot 페이지 | `🤖 ChatBot` 탭 클릭 | 에러 없이 로드 | +| 설정 페이지 | `⚙️ 설정` 탭 클릭 | 에러 없이 로드 | +| Graph Builder 페이지 없음 | 네비게이션 탭 확인 | 탭 목록에 없어야 함 | + +--- + +## 12. ChatBot — LangGraph + 수정된 `search_database_tables` + +> `DATAHUB_SERVER`가 설정되어 있어야 합니다. DataHub 없이 검색 시 에러 응답(`{"error": True, ...}`)을 반환합니다. + +```python +import os +from dotenv import load_dotenv; load_dotenv() + +# 12-A. 모듈 임포트 무결성 확인 (핵심: retrieval.py 삭제 이후 임포트 성공 확인) +from utils.llm.tools import search_database_tables, get_glossary_terms, get_query_examples +print("utils.llm.tools import ✓") + +from utils.llm.chatbot import ChatBot +print("utils.llm.chatbot import ✓") + +# 12-B. ChatBot 인스턴스 생성 +bot = ChatBot( + openai_api_key=os.getenv("OPEN_AI_KEY"), + model_name="gpt-4o-mini", + gms_server=os.getenv("DATAHUB_SERVER", "http://localhost:8080"), +) +print("ChatBot instance ✓") + +# 12-C. 기본 대화 테스트 +result = bot.chat("안녕하세요", thread_id="test-001") +last_msg = result["messages"][-1] +assert hasattr(last_msg, "content") and len(last_msg.content) > 0 +print(f"ChatBot.chat() ✓ — 응답: {last_msg.content[:60]}") +``` + +### 12-D. `search_database_tables` 직접 호출 (DataHub 연결 시) + +```python +# DataHub가 연결된 환경에서만 성공적인 결과 반환 +result = search_database_tables.invoke({ + "query": "주문 테이블", + "top_n": 3 +}) +# DataHub 연결 성공 시: {"orders": {"table_description": "...", ...}, ...} +# DataHub 연결 실패 시: {"error": True, "message": "..."} +print(f"search_database_tables ✓ — result keys: {list(result.keys())}") +``` + +--- + +## 13. 레거시 정리 (삭제 확인) + +아래 모듈들은 마이그레이션에서 삭제되었습니다. **임포트 시 에러 발생이 정상**입니다. + +```python +import importlib, sys + +deleted_modules = [ + "engine", + "engine.query_executor", + "utils.llm.core.factory", + "utils.llm.chains", + "utils.llm.retrieval", + "utils.llm.vectordb", + "utils.llm.graph_utils", + "utils.llm.output_schema", +] + +for mod in deleted_modules: + try: + importlib.import_module(mod) + print(f"WARNING: {mod} — 삭제되었어야 하지만 여전히 존재합니다") + except (ImportError, ModuleNotFoundError): + print(f"✓ {mod} 삭제 확인") +``` + +--- + +## 14. 전체 회귀 테스트 + +```bash +# 유닛 테스트 전체 실행 (145 passed, 6 skipped 예상) +pytest tests/ -v --tb=short + +# 커버리지 포함 +pytest tests/ --cov=src/lang2sql --cov-report=term-missing +``` + +**예상 결과**: 145 passed, 6 skipped (pgvector 관련 — 실제 PostgreSQL 없이는 skip) + +--- + +## 15. DB 커넥터 검증 + +사용하는 DB에 맞게 `.env`를 설정하고 아래를 실행합니다. + +```python +from dotenv import load_dotenv; load_dotenv() +from lang2sql.factory import build_db_from_env + +db = build_db_from_env() + +# 실제 테이블에서 데이터 조회 +rows = db.execute("SELECT COUNT(*) AS cnt FROM 실제_테이블명") +assert isinstance(rows, list) and "cnt" in rows[0] +print(f"DB 연결 ✓ — count={rows[0]['cnt']}") +``` + +### 지원 DB 목록 및 `.env` 키 + +| DB | `DB_TYPE` | 필수 환경변수 | +|----|-----------|---------------| +| SQLite | `sqlite` | `SQLITE_PATH` | +| PostgreSQL | `postgresql` | `POSTGRESQL_HOST/PORT/USER/PASSWORD/DATABASE` | +| MySQL | `mysql` | `MYSQL_HOST/PORT/USER/PASSWORD/DATABASE` | +| MariaDB | `mariadb` | `MARIADB_HOST/PORT/USER/PASSWORD/DATABASE` | +| DuckDB | `duckdb` | `DUCKDB_PATH` | +| ClickHouse | `clickhouse` | `CLICKHOUSE_HOST/PORT/USER/PASSWORD/DATABASE` | +| Snowflake | `snowflake` | `SNOWFLAKE_USER/PASSWORD/ACCOUNT` | +| Oracle | `oracle` | `ORACLE_HOST/PORT/USER/PASSWORD/SERVICE_NAME` | + +--- + +## 빠른 스모크 테스트 스크립트 + +아래 스크립트를 `smoke_test.py`로 저장 후 실행하면 가장 중요한 경로를 빠르게 확인할 수 있습니다. + +```python +""" +smoke_test.py — 핵심 경로 빠른 검증 (Anthropic + SQLite 기준) +실행: python smoke_test.py +""" + +import os +import sqlite3 + +from dotenv import load_dotenv +load_dotenv() + +print("=" * 50) +print("Lang2SQL v2 Smoke Test") +print("=" * 50) + +# 1. 테스트 DB +print("\n[1] SQLite DB 준비") +conn = sqlite3.connect("/tmp/smoke.db") +conn.execute("CREATE TABLE IF NOT EXISTS orders (id INT, amount REAL, created_at TEXT)") +conn.execute("DELETE FROM orders") +conn.executemany("INSERT INTO orders VALUES (?,?,?)", [ + (1, 10000, "2024-01-10"), (2, 20000, "2024-01-20"), (3, 15000, "2024-02-01") +]) +conn.commit(); conn.close() +print(" ✓ /tmp/smoke.db") + +# 2. Factory +print("\n[2] Factory 인스턴스 생성") +from lang2sql.factory import build_llm_from_env, build_embedding_from_env, build_db_from_env +llm = build_llm_from_env() +emb = build_embedding_from_env() +db = build_db_from_env() if os.getenv("DB_TYPE") else None +print(f" ✓ LLM={llm.__class__.__name__}, Embedding={emb.__class__.__name__}") + +# 3. LLM 통신 +print("\n[3] LLM 호출") +resp = llm.invoke([{"role": "user", "content": "Respond with OK"}]) +assert isinstance(resp, str) and len(resp) > 0 +print(f" ✓ response={resp[:30]}") + +# 4. BaselineNL2SQL +print("\n[4] BaselineNL2SQL") +from lang2sql.flows import BaselineNL2SQL +from lang2sql.integrations.db.sqlalchemy_ import SQLAlchemyDB + +catalog = [{"name": "orders", "description": "주문 테이블", "columns": {"id": "주문 ID", "amount": "금액", "created_at": "생성일"}}] +pipe_base = BaselineNL2SQL(catalog=catalog, llm=llm, db=SQLAlchemyDB("sqlite:////tmp/smoke.db"), db_dialect="sqlite") +rows = pipe_base.run("전체 주문 건수") +assert isinstance(rows, list) and len(rows) > 0 +print(f" ✓ rows={rows}") + +# 5. EnrichedNL2SQL +print("\n[5] EnrichedNL2SQL") +from lang2sql.flows import EnrichedNL2SQL + +pipe_rich = EnrichedNL2SQL( + catalog=catalog, llm=llm, embedding=emb, + db=SQLAlchemyDB("sqlite:////tmp/smoke.db"), + db_dialect="sqlite", gate_enabled=True, +) +rows2 = pipe_rich.run("주문 총 건수") +assert isinstance(rows2, list) and len(rows2) > 0 +print(f" ✓ rows={rows2}") + +# 6. 삭제 확인 +print("\n[6] 삭제된 레거시 모듈 확인") +import importlib +for m in ["utils.llm.retrieval", "utils.llm.vectordb", "utils.llm.chains"]: + try: + importlib.import_module(m) + print(f" WARNING: {m} 존재 — 삭제 필요") + except (ImportError, ModuleNotFoundError): + print(f" ✓ {m} 삭제됨") + +print("\n" + "=" * 50) +print("Smoke Test 완료") +print("=" * 50) +``` + +```bash +python smoke_test.py +``` From ca1d73b948fdf11d90f4ca252235407835ee3346 Mon Sep 17 00:00:00 2001 From: seyeong Date: Sat, 28 Feb 2026 15:43:34 +0900 Subject: [PATCH 10/10] style: apply pre-commit formatting --- cli/commands/quary.py | 1 + interface/app_pages/lang2sql.py | 8 +++---- interface/core/lang2sql_runner.py | 1 + interface/core/provider_factory.py | 1 + .../components/enrichment/context_enricher.py | 3 +-- .../components/gate/table_suitability.py | 6 ++--- src/lang2sql/factory.py | 19 +++++++++++---- .../integrations/embedding/bedrock_.py | 4 +--- .../integrations/embedding/ollama_.py | 4 +--- src/lang2sql/integrations/llm/bedrock_.py | 7 ++---- src/lang2sql/integrations/llm/ollama_.py | 4 +--- .../integrations/vectorstore/faiss_.py | 8 ++----- tests/test_components_context_enricher.py | 6 ++++- tests/test_components_table_suitability.py | 24 ++++++++++++++++--- tests/test_flows_enriched_nl2sql.py | 1 - tests/test_integrations_faiss_vectorstore.py | 2 +- .../test_integrations_pgvector_vectorstore.py | 2 +- 17 files changed, 60 insertions(+), 41 deletions(-) diff --git a/cli/commands/quary.py b/cli/commands/quary.py index da1a05a..5a6e652 100644 --- a/cli/commands/quary.py +++ b/cli/commands/quary.py @@ -79,6 +79,7 @@ def query_command( rows = pipeline.run(question) if rows: import json + print(json.dumps(rows, ensure_ascii=False, indent=2)) else: print("(결과 없음)") diff --git a/interface/app_pages/lang2sql.py b/interface/app_pages/lang2sql.py index 11f5830..214e480 100644 --- a/interface/app_pages/lang2sql.py +++ b/interface/app_pages/lang2sql.py @@ -51,11 +51,11 @@ # 설정 col1, col2 = st.columns(2) with col1: - user_dialect = st.selectbox( - "SQL 방언(Dialect):", options=DIALECT_OPTIONS, index=0 - ) + user_dialect = st.selectbox("SQL 방언(Dialect):", options=DIALECT_OPTIONS, index=0) with col2: - user_top_n = st.slider("검색할 테이블 정보 개수:", min_value=1, max_value=20, value=5) + user_top_n = st.slider( + "검색할 테이블 정보 개수:", min_value=1, max_value=20, value=5 + ) if st.button("쿼리 실행"): with st.spinner("쿼리 실행 중..."): diff --git a/interface/core/lang2sql_runner.py b/interface/core/lang2sql_runner.py index 503f230..08ecdc0 100644 --- a/interface/core/lang2sql_runner.py +++ b/interface/core/lang2sql_runner.py @@ -5,6 +5,7 @@ 지정된 데이터베이스 환경에서 실행하는 함수(`run_lang2sql`)를 제공합니다. 내부적으로 v2 플로우(BaselineNL2SQL / EnrichedNL2SQL)를 사용한다. """ + from __future__ import annotations from typing import Any diff --git a/interface/core/provider_factory.py b/interface/core/provider_factory.py index 0c616f8..72da835 100644 --- a/interface/core/provider_factory.py +++ b/interface/core/provider_factory.py @@ -3,6 +3,7 @@ LLMProfile과 EmbeddingProfile(config/models.py)을 받아 lang2sql.integrations의 구현체를 반환한다. """ + from __future__ import annotations from typing import TYPE_CHECKING diff --git a/src/lang2sql/components/enrichment/context_enricher.py b/src/lang2sql/components/enrichment/context_enricher.py index 7d4d231..47012f9 100644 --- a/src/lang2sql/components/enrichment/context_enricher.py +++ b/src/lang2sql/components/enrichment/context_enricher.py @@ -49,8 +49,7 @@ def _run( tables_json = json.dumps(tables_map, ensure_ascii=False) user_content = ( - self._system_prompt - .replace("{profiles}", profiles_json) + self._system_prompt.replace("{profiles}", profiles_json) .replace("{related_tables}", tables_json) .replace("{refined_question}", query) ) diff --git a/src/lang2sql/components/gate/table_suitability.py b/src/lang2sql/components/gate/table_suitability.py index 091bbb3..4837805 100644 --- a/src/lang2sql/components/gate/table_suitability.py +++ b/src/lang2sql/components/gate/table_suitability.py @@ -50,10 +50,8 @@ def _run(self, query: str, schemas: list[CatalogEntry]) -> list[CatalogEntry]: tables_map[name] = {"table_description": desc, **cols} tables_json = json.dumps(tables_map, ensure_ascii=False) - user_content = ( - self._system_prompt - .replace("{question}", query) - .replace("{tables}", tables_json) + user_content = self._system_prompt.replace("{question}", query).replace( + "{tables}", tables_json ) messages = [{"role": "user", "content": user_content}] response = self._llm.invoke(messages) diff --git a/src/lang2sql/factory.py b/src/lang2sql/factory.py index 69539a6..6fca375 100644 --- a/src/lang2sql/factory.py +++ b/src/lang2sql/factory.py @@ -3,6 +3,7 @@ 레거시 utils/llm/core/factory.py를 LangChain 없이 재구현한 것. CLI와 Streamlit UI 양쪽에서 사용한다. """ + from __future__ import annotations import os @@ -99,7 +100,9 @@ def build_embedding_from_env() -> EmbeddingPort: return AzureOpenAIEmbedding( azure_deployment=os.environ["AZURE_OPENAI_EMBEDDING_MODEL"], azure_endpoint=os.environ["AZURE_OPENAI_EMBEDDING_ENDPOINT"], - api_version=os.getenv("AZURE_OPENAI_EMBEDDING_API_VERSION", "2023-09-15-preview"), + api_version=os.getenv( + "AZURE_OPENAI_EMBEDDING_API_VERSION", "2023-09-15-preview" + ), api_key=os.getenv("AZURE_OPENAI_EMBEDDING_KEY"), ) @@ -107,15 +110,23 @@ def build_embedding_from_env() -> EmbeddingPort: from .integrations.embedding.ollama_ import OllamaEmbedding return OllamaEmbedding( - model=os.getenv("EMBEDDING_MODEL", os.getenv("OLLAMA_EMBEDDING_MODEL", "nomic-embed-text")), - base_url=os.getenv("EMBEDDING_BASE_PATH", os.getenv("OLLAMA_EMBEDDING_BASE_URL", "http://localhost:11434")), + model=os.getenv( + "EMBEDDING_MODEL", + os.getenv("OLLAMA_EMBEDDING_MODEL", "nomic-embed-text"), + ), + base_url=os.getenv( + "EMBEDDING_BASE_PATH", + os.getenv("OLLAMA_EMBEDDING_BASE_URL", "http://localhost:11434"), + ), ) if provider == "bedrock": from .integrations.embedding.bedrock_ import BedrockEmbedding return BedrockEmbedding( - model_id=os.getenv("AWS_BEDROCK_EMBEDDING_MODEL", "amazon.titan-embed-text-v2:0"), + model_id=os.getenv( + "AWS_BEDROCK_EMBEDDING_MODEL", "amazon.titan-embed-text-v2:0" + ), aws_access_key_id=os.getenv("AWS_BEDROCK_EMBEDDING_ACCESS_KEY_ID"), aws_secret_access_key=os.getenv("AWS_BEDROCK_EMBEDDING_SECRET_ACCESS_KEY"), region_name=os.getenv("AWS_BEDROCK_EMBEDDING_REGION", "us-east-1"), diff --git a/src/lang2sql/integrations/embedding/bedrock_.py b/src/lang2sql/integrations/embedding/bedrock_.py index 74d5e1e..c0cc69a 100644 --- a/src/lang2sql/integrations/embedding/bedrock_.py +++ b/src/lang2sql/integrations/embedding/bedrock_.py @@ -26,9 +26,7 @@ def __init__( region_name: str = "us-east-1", ) -> None: if _boto3 is None: - raise IntegrationMissingError( - "boto3", hint="pip install boto3" - ) + raise IntegrationMissingError("boto3", hint="pip install boto3") self._model_id = model_id self._client = _boto3.client( "bedrock-runtime", diff --git a/src/lang2sql/integrations/embedding/ollama_.py b/src/lang2sql/integrations/embedding/ollama_.py index cc84989..d377c1b 100644 --- a/src/lang2sql/integrations/embedding/ollama_.py +++ b/src/lang2sql/integrations/embedding/ollama_.py @@ -19,9 +19,7 @@ def __init__( base_url: str = "http://localhost:11434", ) -> None: if _ollama is None: - raise IntegrationMissingError( - "ollama", hint="pip install ollama" - ) + raise IntegrationMissingError("ollama", hint="pip install ollama") self._model = model self._client = _ollama.Client(host=base_url) diff --git a/src/lang2sql/integrations/llm/bedrock_.py b/src/lang2sql/integrations/llm/bedrock_.py index 428e5d4..29fadfb 100644 --- a/src/lang2sql/integrations/llm/bedrock_.py +++ b/src/lang2sql/integrations/llm/bedrock_.py @@ -21,9 +21,7 @@ def __init__( region_name: str = "us-east-1", ) -> None: if _boto3 is None: - raise IntegrationMissingError( - "boto3", hint="pip install boto3" - ) + raise IntegrationMissingError("boto3", hint="pip install boto3") self._model = model self._client = _boto3.client( "bedrock-runtime", @@ -37,8 +35,7 @@ def invoke(self, messages: list[dict[str, str]]) -> str: user_msgs = [m for m in messages if m["role"] != "system"] converse_messages = [ - {"role": m["role"], "content": [{"text": m["content"]}]} - for m in user_msgs + {"role": m["role"], "content": [{"text": m["content"]}]} for m in user_msgs ] kwargs: dict = {"modelId": self._model, "messages": converse_messages} diff --git a/src/lang2sql/integrations/llm/ollama_.py b/src/lang2sql/integrations/llm/ollama_.py index 2c182cf..4172df4 100644 --- a/src/lang2sql/integrations/llm/ollama_.py +++ b/src/lang2sql/integrations/llm/ollama_.py @@ -19,9 +19,7 @@ def __init__( base_url: str = "http://localhost:11434", ) -> None: if _ollama is None: - raise IntegrationMissingError( - "ollama", hint="pip install ollama" - ) + raise IntegrationMissingError("ollama", hint="pip install ollama") self._model = model self._client = _ollama.Client(host=base_url) diff --git a/src/lang2sql/integrations/vectorstore/faiss_.py b/src/lang2sql/integrations/vectorstore/faiss_.py index 3252c3e..0e6a12f 100644 --- a/src/lang2sql/integrations/vectorstore/faiss_.py +++ b/src/lang2sql/integrations/vectorstore/faiss_.py @@ -85,9 +85,7 @@ def save(self, path: str | None = None) -> None: raise RuntimeError("Cannot save before any upsert() call.") pathlib.Path(path).parent.mkdir(parents=True, exist_ok=True) _faiss.write_index(self._index, path) - pathlib.Path(path + ".meta").write_text( - json.dumps(self._ids), encoding="utf-8" - ) + pathlib.Path(path + ".meta").write_text(json.dumps(self._ids), encoding="utf-8") @classmethod def load(cls, path: str) -> "FAISSVectorStore": @@ -99,9 +97,7 @@ def load(cls, path: str) -> "FAISSVectorStore": raise IntegrationMissingError("faiss", hint="pip install faiss-cpu") meta_path = pathlib.Path(path + ".meta") if not pathlib.Path(path).exists() or not meta_path.exists(): - raise FileNotFoundError( - f"Index files not found: {path}, {path}.meta" - ) + raise FileNotFoundError(f"Index files not found: {path}, {path}.meta") store = cls(index_path=path) store._index = _faiss.read_index(path) store._ids = json.loads(meta_path.read_text(encoding="utf-8")) diff --git a/tests/test_components_context_enricher.py b/tests/test_components_context_enricher.py index 3d4e927..20fd83c 100644 --- a/tests/test_components_context_enricher.py +++ b/tests/test_components_context_enricher.py @@ -22,7 +22,11 @@ def _catalog() -> list[CatalogEntry]: { "name": "orders", "description": "주문 테이블", - "columns": {"order_id": "주문 ID", "amount": "주문 금액", "created_at": "생성일"}, + "columns": { + "order_id": "주문 ID", + "amount": "주문 금액", + "created_at": "생성일", + }, } ] diff --git a/tests/test_components_table_suitability.py b/tests/test_components_table_suitability.py index e1c1b7d..0a37835 100644 --- a/tests/test_components_table_suitability.py +++ b/tests/test_components_table_suitability.py @@ -24,7 +24,11 @@ def _catalog() -> list[CatalogEntry]: { "name": "orders", "description": "주문 테이블", - "columns": {"order_id": "주문 ID", "amount": "주문 금액", "created_at": "생성일"}, + "columns": { + "order_id": "주문 ID", + "amount": "주문 금액", + "created_at": "생성일", + }, }, { "name": "users", @@ -91,7 +95,13 @@ def test_table_suitability_sorted_by_score(): def test_table_suitability_empty_result_when_all_below_threshold(): resp = _suitability_json( [ - {"table_name": "orders", "score": 0.1, "reason": "낮은 관련성", "matched_columns": [], "missing_entities": []}, + { + "table_name": "orders", + "score": 0.1, + "reason": "낮은 관련성", + "matched_columns": [], + "missing_entities": [], + }, ] ) evaluator = TableSuitabilityEvaluator(llm=FakeLLM(resp), threshold=0.3) @@ -102,7 +112,15 @@ def test_table_suitability_empty_result_when_all_below_threshold(): def test_table_suitability_emits_hook_events(): hook = MemoryHook() resp = _suitability_json( - [{"table_name": "orders", "score": 0.8, "reason": "ok", "matched_columns": [], "missing_entities": []}] + [ + { + "table_name": "orders", + "score": 0.8, + "reason": "ok", + "matched_columns": [], + "missing_entities": [], + } + ] ) evaluator = TableSuitabilityEvaluator(llm=FakeLLM(resp), hook=hook) evaluator.run("test", _catalog()) diff --git a/tests/test_flows_enriched_nl2sql.py b/tests/test_flows_enriched_nl2sql.py index f69692a..9683c37 100644 --- a/tests/test_flows_enriched_nl2sql.py +++ b/tests/test_flows_enriched_nl2sql.py @@ -12,7 +12,6 @@ from lang2sql.core.hooks import MemoryHook from lang2sql.flows.enriched_nl2sql import EnrichedNL2SQL - # --------------------------------------------------------------------------- # Fakes # --------------------------------------------------------------------------- diff --git a/tests/test_integrations_faiss_vectorstore.py b/tests/test_integrations_faiss_vectorstore.py index 2b4e1d6..0b97fcb 100644 --- a/tests/test_integrations_faiss_vectorstore.py +++ b/tests/test_integrations_faiss_vectorstore.py @@ -3,6 +3,7 @@ All tests are auto-skipped when faiss-cpu is not installed. """ + import pytest faiss = pytest.importorskip("faiss") # skip entire module if not installed @@ -12,7 +13,6 @@ from lang2sql.integrations.vectorstore.faiss_ import FAISSVectorStore - # ── helpers ────────────────────────────────────────────────────────────────── diff --git a/tests/test_integrations_pgvector_vectorstore.py b/tests/test_integrations_pgvector_vectorstore.py index c2245e2..789872b 100644 --- a/tests/test_integrations_pgvector_vectorstore.py +++ b/tests/test_integrations_pgvector_vectorstore.py @@ -8,6 +8,7 @@ TEST_POSTGRES_URL="postgresql://postgres:postgres@localhost:5432/test" \\ pytest tests/test_integrations_pgvector_vectorstore.py -v """ + import os import pytest from uuid import uuid4 @@ -19,7 +20,6 @@ from lang2sql.integrations.vectorstore.pgvector_ import PGVectorStore - # ── helpers ──────────────────────────────────────────────────────────────────