Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
114 changes: 45 additions & 69 deletions cli/commands/quary.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@
`query` CLI 명령어를 제공합니다.
"""

import os

import click

from cli.utils.logger import configure_logging
Expand All @@ -16,14 +14,10 @@
@click.command(name="query")
@click.argument("question", type=str)
@click.option(
"--database-env",
default="clickhouse",
help="사용할 데이터베이스 환경 (기본값: clickhouse)",
)
@click.option(
"--retriever-name",
default="기본",
help="테이블 검색기 이름 (기본값: 기본)",
"--flow",
type=click.Choice(["baseline", "enriched"]),
default="baseline",
help="사용할 플로우 (기본값: baseline)",
)
@click.option(
"--top-n",
Expand All @@ -32,81 +26,63 @@
help="검색된 상위 테이블 수 제한 (기본값: 5)",
)
@click.option(
"--device",
default="cpu",
help="LLM 실행에 사용할 디바이스 (기본값: cpu)",
"--dialect",
default=None,
help="SQL 방언 (예: sqlite, postgresql, mysql, bigquery, duckdb)",
)
@click.option(
"--use-enriched-graph",
"--no-gate",
is_flag=True,
help="확장된 그래프(프로파일 추출 + 컨텍스트 보강) 사용 여부",
)
@click.option(
"--vectordb-type",
type=click.Choice(["faiss", "pgvector"]),
default="faiss",
help="사용할 벡터 데이터베이스 타입 (기본값: faiss)",
)
@click.option(
"--vectordb-location",
help=(
"VectorDB 위치 설정\n"
"- FAISS: 디렉토리 경로 (예: ./my_vectordb)\n"
"- pgvector: 연결 문자열 (예: postgresql://user:pass@host:port/db)\n"
"기본값: FAISS는 './dev/table_info_db', pgvector는 환경변수 사용"
),
help="QuestionGate 비활성화 (enriched 플로우 전용)",
)
def query_command(
question: str,
database_env: str,
retriever_name: str,
flow: str,
top_n: int,
device: str,
use_enriched_graph: bool,
vectordb_type: str = "faiss",
vectordb_location: str = None,
dialect: str,
no_gate: bool,
) -> None:
"""자연어 질문을 SQL 쿼리로 변환하여 출력합니다.
"""자연어 질문을 SQL 쿼리로 변환하여 실행 결과를 출력합니다.

Args:
question (str): SQL로 변환할 자연어 질문
database_env (str): 사용할 데이터베이스 환경
retriever_name (str): 테이블 검색기 이름
top_n (int): 검색된 상위 테이블 수 제한
device (str): LLM 실행 디바이스
use_enriched_graph (bool): 확장된 그래프 사용 여부
vectordb_type (str): 벡터 데이터베이스 타입 ("faiss" 또는 "pgvector")
vectordb_location (Optional[str]): 벡터DB 경로 또는 연결 URL
환경변수(LLM_PROVIDER, EMBEDDING_PROVIDER, DB_TYPE 등)로 설정을 제어합니다.
"""
try:
from engine.query_executor import execute_query, extract_sql_from_result
from lang2sql.factory import (
build_db_from_env,
build_embedding_from_env,
build_llm_from_env,
)
from lang2sql.flows import BaselineNL2SQL, EnrichedNL2SQL

os.environ["VECTORDB_TYPE"] = vectordb_type
llm = build_llm_from_env()
db = build_db_from_env()

if vectordb_location:
os.environ["VECTORDB_LOCATION"] = vectordb_location
if flow == "baseline":
pipeline = BaselineNL2SQL(
catalog=[],
llm=llm,
db=db,
db_dialect=dialect,
)
else:
embedding = build_embedding_from_env()
pipeline = EnrichedNL2SQL(
catalog=[],
llm=llm,
db=db,
embedding=embedding,
db_dialect=dialect,
gate_enabled=not no_gate,
top_n=top_n,
)

res = execute_query(
query=question,
database_env=database_env,
retriever_name=retriever_name,
top_n=top_n,
device=device,
use_enriched_graph=use_enriched_graph,
)
rows = pipeline.run(question)
if rows:
import json

sql = extract_sql_from_result(res)
if sql:
print(sql)
print(json.dumps(rows, ensure_ascii=False, indent=2))
else:
generated_query = res.get("generated_query")
if generated_query:
query_text = (
generated_query.content
if hasattr(generated_query, "content")
else str(generated_query)
)
print(query_text)
print("(결과 없음)")

except Exception as e:
logger.error("쿼리 처리 중 오류 발생: %s", e)
Expand Down
Loading