diff --git a/sql_agent_workflow.png b/sql_agent_workflow.png new file mode 100644 index 0000000..cecf785 Binary files /dev/null and b/sql_agent_workflow.png differ diff --git a/src/agents/__init__.py b/src/agents/__init__.py new file mode 100644 index 0000000..f677fbe --- /dev/null +++ b/src/agents/__init__.py @@ -0,0 +1,6 @@ +""" +에이전트 루트 패키지 +""" + + + diff --git a/src/agents/sql_agent/__init__.py b/src/agents/sql_agent/__init__.py new file mode 100644 index 0000000..cd71571 --- /dev/null +++ b/src/agents/sql_agent/__init__.py @@ -0,0 +1,27 @@ +# src/agents/sql_agent/__init__.py + +from .state import SqlAgentState +from .nodes import SqlAgentNodes +from .edges import SqlAgentEdges +from .graph import SqlAgentGraph +from .exceptions import ( + SqlAgentException, + ValidationException, + ExecutionException, + DatabaseConnectionException, + LLMProviderException, + MaxRetryExceededException +) + +__all__ = [ + 'SqlAgentState', + 'SqlAgentNodes', + 'SqlAgentEdges', + 'SqlAgentGraph', + 'SqlAgentException', + 'ValidationException', + 'ExecutionException', + 'DatabaseConnectionException', + 'LLMProviderException', + 'MaxRetryExceededException' +] diff --git a/src/agents/sql_agent/edges.py b/src/agents/sql_agent/edges.py new file mode 100644 index 0000000..9a24fe3 --- /dev/null +++ b/src/agents/sql_agent/edges.py @@ -0,0 +1,51 @@ +# src/agents/sql_agent/edges.py + +from .state import SqlAgentState + +# 상수 정의 +MAX_ERROR_COUNT = 3 + +class SqlAgentEdges: + """SQL Agent의 모든 엣지 로직을 담당하는 클래스""" + + @staticmethod + def route_after_intent_classification(state: SqlAgentState) -> str: + """의도 분류 결과에 따라 라우팅을 결정합니다.""" + if state['intent'] == "SQL": + print("--- 의도: SQL 관련 질문 ---") + return "db_classifier" + print("--- 의도: SQL과 관련 없는 질문 ---") + return "unsupported_question" + + @staticmethod + def should_execute_sql(state: SqlAgentState) -> str: + """SQL 검증 결과에 따라 다음 단계를 결정합니다.""" + validation_error_count = state.get("validation_error_count", 0) + + if validation_error_count >= MAX_ERROR_COUNT: + print(f"--- 검증 실패 {MAX_ERROR_COUNT}회 초과: 답변 생성으로 이동 ---") + return "synthesize_failure" + + if state.get("validation_error"): + print(f"--- 검증 실패 {validation_error_count}회: SQL 재생성 ---") + return "regenerate" + + print("--- 검증 성공: SQL 실행 ---") + return "execute" + + @staticmethod + def should_retry_or_respond(state: SqlAgentState) -> str: + """SQL 실행 결과에 따라 다음 단계를 결정합니다.""" + execution_error_count = state.get("execution_error_count", 0) + execution_result = state.get("execution_result", "") + + if execution_error_count >= MAX_ERROR_COUNT: + print(f"--- 실행 실패 {MAX_ERROR_COUNT}회 초과: 답변 생성으로 이동 ---") + return "synthesize_failure" + + if "오류" in execution_result: + print(f"--- 실행 실패 {execution_error_count}회: SQL 재생성 ---") + return "regenerate" + + print("--- 실행 성공: 최종 답변 생성 ---") + return "synthesize_success" diff --git a/src/agents/sql_agent/exceptions.py b/src/agents/sql_agent/exceptions.py new file mode 100644 index 0000000..0e0171a --- /dev/null +++ b/src/agents/sql_agent/exceptions.py @@ -0,0 +1,31 @@ +# src/agents/sql_agent/exceptions.py + +class SqlAgentException(Exception): + """SQL Agent 관련 기본 예외 클래스""" + pass + +class ValidationException(SqlAgentException): + """SQL 검증 실패 예외""" + def __init__(self, message: str, error_count: int = 0): + super().__init__(message) + self.error_count = error_count + +class ExecutionException(SqlAgentException): + """SQL 실행 실패 예외""" + def __init__(self, message: str, error_count: int = 0): + super().__init__(message) + self.error_count = error_count + +class DatabaseConnectionException(SqlAgentException): + """데이터베이스 연결 실패 예외""" + pass + +class LLMProviderException(SqlAgentException): + """LLM 제공자 관련 예외""" + pass + +class MaxRetryExceededException(SqlAgentException): + """최대 재시도 횟수 초과 예외""" + def __init__(self, message: str, max_retries: int): + super().__init__(f"{message} (최대 재시도 {max_retries}회 초과)") + self.max_retries = max_retries diff --git a/src/agents/sql_agent/graph.py b/src/agents/sql_agent/graph.py new file mode 100644 index 0000000..cf48421 --- /dev/null +++ b/src/agents/sql_agent/graph.py @@ -0,0 +1,129 @@ +# src/agents/sql_agent/graph.py + +from langgraph.graph import StateGraph, END +from core.providers.llm_provider import LLMProvider +from services.database.database_service import DatabaseService +from .state import SqlAgentState +from .nodes import SqlAgentNodes +from .edges import SqlAgentEdges + +class SqlAgentGraph: + """SQL Agent 그래프를 구성하고 관리하는 클래스""" + + def __init__(self, llm_provider: LLMProvider, database_service: DatabaseService): + self.llm_provider = llm_provider + self.database_service = database_service + self.nodes = SqlAgentNodes(llm_provider, database_service) + self.edges = SqlAgentEdges() + self._graph = None + + def create_graph(self) -> StateGraph: + """SQL Agent 그래프를 생성하고 구성합니다.""" + if self._graph is not None: + return self._graph + + graph = StateGraph(SqlAgentState) + + # 노드 추가 + self._add_nodes(graph) + + # 엣지 추가 + self._add_edges(graph) + + # 진입점 설정 + graph.set_entry_point("intent_classifier") + + # 그래프 컴파일 + self._graph = graph.compile() + return self._graph + + def _add_nodes(self, graph: StateGraph): + """그래프에 모든 노드를 추가합니다.""" + graph.add_node("intent_classifier", self.nodes.intent_classifier_node) + graph.add_node("db_classifier", self.nodes.db_classifier_node) + graph.add_node("unsupported_question", self.nodes.unsupported_question_node) + graph.add_node("sql_generator", self.nodes.sql_generator_node) + graph.add_node("sql_validator", self.nodes.sql_validator_node) + graph.add_node("sql_executor", self.nodes.sql_executor_node) + graph.add_node("response_synthesizer", self.nodes.response_synthesizer_node) + + def _add_edges(self, graph: StateGraph): + """그래프에 모든 엣지를 추가합니다.""" + # 의도 분류 후 조건부 라우팅 + graph.add_conditional_edges( + "intent_classifier", + self.edges.route_after_intent_classification, + { + "db_classifier": "db_classifier", + "unsupported_question": "unsupported_question" + } + ) + + # 지원되지 않는 질문 처리 후 종료 + graph.add_edge("unsupported_question", END) + + # DB 분류 후 SQL 생성으로 이동 + graph.add_edge("db_classifier", "sql_generator") + + # SQL 생성 후 검증으로 이동 + graph.add_edge("sql_generator", "sql_validator") + + # SQL 검증 후 조건부 라우팅 + graph.add_conditional_edges( + "sql_validator", + self.edges.should_execute_sql, + { + "regenerate": "sql_generator", + "execute": "sql_executor", + "synthesize_failure": "response_synthesizer" + } + ) + + # SQL 실행 후 조건부 라우팅 + graph.add_conditional_edges( + "sql_executor", + self.edges.should_retry_or_respond, + { + "regenerate": "sql_generator", + "synthesize_success": "response_synthesizer", + "synthesize_failure": "response_synthesizer" + } + ) + + # 응답 생성 후 종료 + graph.add_edge("response_synthesizer", END) + + async def run(self, initial_state: dict) -> dict: + """그래프를 실행하고 결과를 반환합니다.""" + try: + if self._graph is None: + self.create_graph() + + result = await self._graph.ainvoke(initial_state) + return result + + except Exception as e: + print(f"그래프 실행 중 오류 발생: {e}") + # 에러 발생 시 예외를 다시 발생시켜 상위 레벨에서 HTTP 에러로 처리되도록 함 + raise e + + def save_graph_visualization(self, file_path: str = "sql_agent_graph.png") -> bool: + """그래프 시각화를 파일로 저장합니다.""" + try: + if self._graph is None: + self.create_graph() + + # PNG 이미지 생성 + png_data = self._graph.get_graph(xray=True).draw_mermaid_png() + + # 파일로 저장 + with open(file_path, "wb") as f: + f.write(png_data) + + print(f"그래프 시각화가 {file_path}에 저장되었습니다.") + return True + + except Exception as e: + print(f"그래프 시각화 저장 실패: {e}") + return False + \ No newline at end of file diff --git a/src/agents/sql_agent/nodes.py b/src/agents/sql_agent/nodes.py new file mode 100644 index 0000000..b381c17 --- /dev/null +++ b/src/agents/sql_agent/nodes.py @@ -0,0 +1,319 @@ +# src/agents/sql_agent/nodes.py + +import os +import sys +import asyncio +from typing import List, Optional +from langchain.output_parsers.pydantic import PydanticOutputParser +from langchain_core.output_parsers import StrOutputParser +from langchain.prompts import load_prompt + +from schemas.agent.sql_schemas import SqlQuery +from services.database.database_service import DatabaseService +from core.providers.llm_provider import LLMProvider +from .state import SqlAgentState +from .exceptions import ( + ValidationException, + ExecutionException, + DatabaseConnectionException, + MaxRetryExceededException +) + +# 상수 정의 +MAX_ERROR_COUNT = 3 +PROMPT_VERSION = "v1" +PROMPT_DIR = os.path.join("prompts", PROMPT_VERSION, "sql_agent") + +def resource_path(relative_path): + """PyInstaller 경로 해결 함수""" + try: + base_path = sys._MEIPASS + except Exception: + base_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")) + return os.path.join(base_path, relative_path) + +class SqlAgentNodes: + """SQL Agent의 모든 노드 로직을 담당하는 클래스""" + + def __init__(self, llm_provider: LLMProvider, database_service: DatabaseService): + self.llm_provider = llm_provider + self.database_service = database_service + + # 프롬프트 로드 + self._load_prompts() + + def _load_prompts(self): + """프롬프트 파일들을 로드합니다.""" + try: + self.intent_classifier_prompt = load_prompt( + resource_path(os.path.join(PROMPT_DIR, "intent_classifier.yaml")) + ) + self.db_classifier_prompt = load_prompt( + resource_path(os.path.join(PROMPT_DIR, "db_classifier.yaml")) + ) + self.sql_generator_prompt = load_prompt( + resource_path(os.path.join(PROMPT_DIR, "sql_generator.yaml")) + ) + self.response_synthesizer_prompt = load_prompt( + resource_path(os.path.join(PROMPT_DIR, "response_synthesizer.yaml")) + ) + except Exception as e: + raise FileNotFoundError(f"프롬프트 파일 로드 실패: {e}") + + async def intent_classifier_node(self, state: SqlAgentState) -> SqlAgentState: + """사용자 질문의 의도를 분류하는 노드""" + print("--- 0. 의도 분류 중 ---") + + try: + llm = await self.llm_provider.get_llm() + + # 채팅 내역을 활용하여 의도 분류 + input_data = { + "question": state['question'], + "chat_history": state.get('chat_history', []) + } + + chain = self.intent_classifier_prompt | llm | StrOutputParser() + intent = await chain.ainvoke(input_data) + state['intent'] = intent.strip() + + print(f"의도 분류 결과: {state['intent']}") + return state + + except Exception as e: + print(f"의도 분류 실패: {e}") + # 기본값으로 SQL 처리 + state['intent'] = "SQL" + return state + + async def unsupported_question_node(self, state: SqlAgentState) -> SqlAgentState: + """SQL과 관련 없는 질문을 처리하는 노드""" + print("--- SQL 관련 없는 질문 ---") + + state['final_response'] = """죄송합니다, 해당 질문에는 답변할 수 없습니다. +저는 데이터베이스 관련 질문만 처리할 수 있습니다. +SQL 쿼리나 데이터 분석과 관련된 질문을 해주세요.""" + + return state + + async def db_classifier_node(self, state: SqlAgentState) -> SqlAgentState: + """데이터베이스를 분류하고 스키마를 가져오는 노드""" + print("--- 0.5. DB 분류 중 ---") + + try: + # 데이터베이스 목록 가져오기 + available_dbs = await self.database_service.get_available_databases() + + if not available_dbs: + raise DatabaseConnectionException("사용 가능한 데이터베이스가 없습니다.") + + # 데이터베이스 옵션 생성 + db_options = "\n".join([ + f"- {db.database_name}: {db.description}" + for db in available_dbs + ]) + + # LLM을 사용하여 적절한 데이터베이스 선택 + llm = await self.llm_provider.get_llm() + chain = self.db_classifier_prompt | llm | StrOutputParser() + selected_db_name = await chain.ainvoke({ + "db_options": db_options, + "chat_history": state['chat_history'], + "question": state['question'] + }) + + selected_db_name = selected_db_name.strip() + state['selected_db'] = selected_db_name + + print(f'--- 선택된 DB: {selected_db_name} ---') + + # 선택된 데이터베이스의 스키마 정보 가져오기 + db_schema = await self.database_service.get_schema_for_db(selected_db_name) + state['db_schema'] = db_schema + + return state + + except Exception as e: + print(f"데이터베이스 분류 실패: {e}") + print(f"에러 타입: {type(e).__name__}") + print(f"에러 상세: {str(e)}") + + # 폴백 없이 에러를 다시 발생시킴 + raise e + + async def sql_generator_node(self, state: SqlAgentState) -> SqlAgentState: + """SQL 쿼리를 생성하는 노드""" + print("--- 1. SQL 생성 중 ---") + + try: + parser = PydanticOutputParser(pydantic_object=SqlQuery) + + # 에러 피드백 컨텍스트 생성 + error_feedback = self._build_error_feedback(state) + + prompt = self.sql_generator_prompt.format( + format_instructions=parser.get_format_instructions(), + db_schema=state['db_schema'], + chat_history=state['chat_history'], + question=state['question'], + error_feedback=error_feedback + ) + + llm = await self.llm_provider.get_llm() + response = await llm.ainvoke(prompt) + parsed_query = parser.invoke(response.content) + + state['sql_query'] = parsed_query.query + state['validation_error'] = None + state['execution_result'] = None + + return state + + except Exception as e: + raise ExecutionException(f"SQL 생성 실패: {e}") + + def _build_error_feedback(self, state: SqlAgentState) -> str: + """에러 피드백 컨텍스트를 생성합니다.""" + error_feedback = "" + + # 검증 오류가 있었을 경우 + if state.get("validation_error") and state.get("validation_error_count", 0) > 0: + error_feedback = f""" + Your previous query was rejected for the following reason: {state['validation_error']} + Please generate a new, safe query that does not contain forbidden keywords. + """ + # 실행 오류가 있었을 경우 + elif (state.get("execution_result") and + "오류" in state.get("execution_result", "") and + state.get("execution_error_count", 0) > 0): + error_feedback = f""" + Your previously generated SQL query failed with the following database error: + FAILED SQL: {state['sql_query']} + DATABASE ERROR: {state['execution_result']} + Please correct the SQL query based on the error. + """ + + return error_feedback + + async def sql_validator_node(self, state: SqlAgentState) -> SqlAgentState: + """SQL 쿼리의 안전성을 검증하는 노드""" + print("--- 2. SQL 검증 중 ---") + + try: + query_words = state['sql_query'].lower().split() + dangerous_keywords = [ + "drop", "delete", "update", "insert", "truncate", + "alter", "create", "grant", "revoke" + ] + found_keywords = [keyword for keyword in dangerous_keywords if keyword in query_words] + + if found_keywords: + keyword_str = ', '.join(f"'{k}'" for k in found_keywords) + error_msg = f'위험한 키워드 {keyword_str}가 포함되어 있습니다.' + state['validation_error'] = error_msg + state['validation_error_count'] = state.get('validation_error_count', 0) + 1 + + if state['validation_error_count'] >= MAX_ERROR_COUNT: + raise MaxRetryExceededException( + f"SQL 검증 실패가 {MAX_ERROR_COUNT}회 반복됨", MAX_ERROR_COUNT + ) + else: + state['validation_error'] = None + state['validation_error_count'] = 0 + + return state + + except MaxRetryExceededException: + raise + except Exception as e: + raise ValidationException(f"SQL 검증 중 오류 발생: {e}") + + async def sql_executor_node(self, state: SqlAgentState) -> SqlAgentState: + """SQL 쿼리를 실행하는 노드""" + print("--- 3. SQL 실행 중 ---") + + try: + selected_db = state.get('selected_db', 'default') + user_db_id = state.get('user_db_id', 'TEST-USER-DB-12345') + + result = await self.database_service.execute_query( + state['sql_query'], + database_name=selected_db, + user_db_id=user_db_id + ) + + state['execution_result'] = result + state['validation_error_count'] = 0 + state['execution_error_count'] = 0 + + return state + + except Exception as e: + error_msg = f"실행 오류: {e}" + state['execution_result'] = error_msg + state['validation_error_count'] = 0 + state['execution_error_count'] = state.get('execution_error_count', 0) + 1 + + print(f"⚠️ SQL 실행 실패 ({state['execution_error_count']}/{MAX_ERROR_COUNT}): {error_msg}") + + if state['execution_error_count'] >= MAX_ERROR_COUNT: + print(f"🚫 SQL 실행 실패 {MAX_ERROR_COUNT}회 도달, 재시도 중단") + print(f"최종 에러: {error_msg}") + + # 최종 실패 시 기본 응답 설정 + state['final_response'] = f"죄송합니다. SQL 쿼리 실행에 실패했습니다. 오류: {error_msg}" + + raise MaxRetryExceededException( + f"SQL 실행 실패가 {MAX_ERROR_COUNT}회 반복됨", MAX_ERROR_COUNT + ) + + return state + + async def response_synthesizer_node(self, state: SqlAgentState) -> SqlAgentState: + """최종 답변을 생성하는 노드""" + print("--- 4. 최종 답변 생성 중 ---") + + try: + is_failure = (state.get('validation_error_count', 0) >= MAX_ERROR_COUNT or + state.get('execution_error_count', 0) >= MAX_ERROR_COUNT) + + if is_failure: + context_message = self._build_failure_context(state) + else: + context_message = f""" + Successfully executed the SQL query to answer the user's question. + SQL Query: {state['sql_query']} + SQL Result: {state['execution_result']} + """ + + prompt = self.response_synthesizer_prompt.format( + question=state['question'], + chat_history=state['chat_history'], + context_message=context_message + ) + + llm = await self.llm_provider.get_llm() + response = await llm.ainvoke(prompt) + state['final_response'] = response.content + + return state + + except Exception as e: + # 최종 답변 생성 실패 시 기본 메시지 제공 + state['final_response'] = f"죄송합니다. 답변 생성 중 오류가 발생했습니다: {e}" + return state + + def _build_failure_context(self, state: SqlAgentState) -> str: + """실패 상황에 대한 컨텍스트 메시지를 생성합니다.""" + if state.get('validation_error_count', 0) >= MAX_ERROR_COUNT: + error_type = "SQL 검증" + error_details = state.get('validation_error') + else: + error_type = "SQL 실행" + error_details = state.get('execution_result') + + return f""" + An attempt to answer the user's question failed after multiple retries. + Failure Type: {error_type} + Last Error: {error_details} + """ diff --git a/src/agents/sql_agent/state.py b/src/agents/sql_agent/state.py new file mode 100644 index 0000000..ab6aa1c --- /dev/null +++ b/src/agents/sql_agent/state.py @@ -0,0 +1,30 @@ +# src/agents/sql_agent/state.py + +from typing import List, TypedDict, Optional +from langchain_core.messages import BaseMessage + +class SqlAgentState(TypedDict): + """SQL Agent의 상태를 정의하는 TypedDict""" + + # 입력 정보 + question: str + chat_history: List[BaseMessage] + + # 데이터베이스 관련 + selected_db: Optional[str] + db_schema: str + + # 의도 분류 결과 + intent: str + + # SQL 생성 및 검증 + sql_query: str + validation_error: Optional[str] + validation_error_count: int + + # SQL 실행 결과 + execution_result: Optional[str] + execution_error_count: int + + # 최종 응답 + final_response: str diff --git a/src/agents/sql_agent_graph.py b/src/agents/sql_agent_graph.py deleted file mode 100644 index 71d59b0..0000000 --- a/src/agents/sql_agent_graph.py +++ /dev/null @@ -1,285 +0,0 @@ -# src/agents/sql_agent_graph.py - -import os -import sys -from typing import List, TypedDict, Optional -from langchain_core.messages import BaseMessage -from langgraph.graph import StateGraph, END -from langchain.output_parsers.pydantic import PydanticOutputParser -from langchain_core.output_parsers import StrOutputParser -from langchain.prompts import load_prompt -from schemas.sql_schemas import SqlQuery -from core.db_manager import db_instance -from core.llm_provider import llm_instance - -# --- PyInstaller 경로 해결 함수 --- -def resource_path(relative_path): - try: - # PyInstaller creates a temp folder and stores path in _MEIPASS - base_path = sys._MEIPASS - except Exception: - # 개발 환경에서는 src 폴더를 기준으로 경로 설정 - base_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) - return os.path.join(base_path, relative_path) - -# --- 상수 정의 --- -MAX_ERROR_COUNT = 3 -PROMPT_VERSION = "v1" -PROMPT_DIR = os.path.join("prompts", PROMPT_VERSION, "sql_agent") - -# --- 프롬프트 로드 --- -INTENT_CLASSIFIER_PROMPT = load_prompt(resource_path(os.path.join(PROMPT_DIR, "intent_classifier.yaml"))) -DB_CLASSIFIER_PROMPT = load_prompt(resource_path(os.path.join(PROMPT_DIR, "db_classifier.yaml"))) -SQL_GENERATOR_PROMPT = load_prompt(resource_path(os.path.join(PROMPT_DIR, "sql_generator.yaml"))) -RESPONSE_SYNTHESIZER_PROMPT = load_prompt(resource_path(os.path.join(PROMPT_DIR, "response_synthesizer.yaml"))) - -# Agent 상태 정의 -class SqlAgentState(TypedDict): - question: str - chat_history: List[BaseMessage] - db_schema: str - intent: str - sql_query: str - validation_error: Optional[str] - validation_error_count: int - execution_result: Optional[str] - execution_error_count: int - final_response: str - -# --- 노드 함수 정의 --- -def intent_classifier_node(state: SqlAgentState): - print("--- 0. 의도 분류 중 ---") - chain = INTENT_CLASSIFIER_PROMPT | llm_instance | StrOutputParser() - intent = chain.invoke({"question": state['question']}) - state['intent'] = intent - return state - -def unsupported_question_node(state: SqlAgentState): - print("--- SQL 관련 없는 질문 ---") - state['final_response'] = "죄송합니다, 해당 질문에는 답변할 수 없습니다. 데이터베이스 관련 질문만 가능합니다." - return state - -def db_classifier_node(state: SqlAgentState): - print("--- 0.5. DB 분류 중 ---") - - # TODO: BE API 호출로 대체 필요 - available_dbs = [ - { - "connection_name": "local_mysql", - "database_name": "sakila", - "description": "DVD 대여점 비즈니스 모델을 다루는 샘플 데이터베이스로, 영화, 배우, 고객, 대여 기록 등의 정보를 포함합니다." - }, - { - "connection_name": "local_mysql", - "database_name": "ecom_prod", - "description": "온라인 쇼핑몰의 운영 데이터베이스로, 상품 카탈로그, 고객 주문, 재고 및 배송 정보를 관리합니다." - }, - { - "connection_name": "local_mysql", - "database_name": "hr_analytics", - "description": "회사의 인사 관리 데이터베이스로, 직원 정보, 급여, 부서, 성과 평가 기록을 포함합니다." - }, - { - "connection_name": "local_mysql", - "database_name": "web_logs", - "description": "웹사이트 트래픽 분석을 위한 로그 데이터베이스로, 사용자 방문 기록, 페이지 뷰, 에러 로그 등을 저장합니다." - } - ] - - db_options = "\n".join([f"- {db['database_name']}: {db['description']}" for db in available_dbs]) - - chain = DB_CLASSIFIER_PROMPT | llm_instance | StrOutputParser() - selected_db_name = chain.invoke({ - "db_options": db_options, - "chat_history": state['chat_history'], - "question": state['question'] - }) - - state['selected_db'] = selected_db_name.strip() - - # 선택된 DB의 스키마 정보를 가져와서 상태에 업데이트합니다. - print(f'--- 선택된 DB: {selected_db_name} ---') - - # TODO: get_schema_for_db - state['db_schema'] = db_instance.get_schema_for_db(db_name=selected_db_name) - - return state - -def sql_generator_node(state: SqlAgentState): - print("--- 1. SQL 생성 중 ---") - parser = PydanticOutputParser(pydantic_object=SqlQuery) - - # --- 에러 피드백 컨텍스트 생성 --- - error_feedback = "" - # 1. 검증 오류가 있었을 경우 - if state.get("validation_error") and state.get("validation_error_count", 0) > 0: - error_feedback = f""" - Your previous query was rejected for the following reason: {state['validation_error']} - Please generate a new, safe query that does not contain forbidden keywords. - """ - # 2. 실행 오류가 있었을 경우 - elif state.get("execution_result") and "오류" in state.get("execution_result") and state.get("execution_error_count", 0) > 0: - error_feedback = f""" - Your previously generated SQL query failed with the following database error: - FAILED SQL: {state['sql_query']} - DATABASE ERROR: {state['execution_result']} - Please correct the SQL query based on the error. - """ - - prompt = SQL_GENERATOR_PROMPT.format( - format_instructions=parser.get_format_instructions(), - db_schema=state['db_schema'], - chat_history=state['chat_history'], - question=state['question'], - error_feedback=error_feedback - ) - - response = llm_instance.invoke(prompt) - parsed_query = parser.invoke(response.content) - state['sql_query'] = parsed_query.query - state['validation_error'] = None - state['execution_result'] = None - return state - -def sql_validator_node(state: SqlAgentState): - print("--- 2. SQL 검증 중 ---") - query_words = state['sql_query'].lower().split() - dangerous_keywords = [ - "drop", "delete", "update", "insert", "truncate", - "alter", "create", "grant", "revoke" - ] - found_keywords = [keyword for keyword in dangerous_keywords if keyword in query_words] - - if found_keywords: - keyword_str = ', '.join(f"'{k}'" for k in found_keywords) - state['validation_error'] = f'위험한 키워드 {keyword_str}가 포함되어 있습니다.' - state['validation_error_count'] += 1 # sql 검증 횟수 추가 - else: - state['validation_error'] = None - state['validation_error_count'] = 0 # sql 검증 횟수 초기화 - return state - -def sql_executor_node(state: SqlAgentState): - print("--- 3. SQL 실행 중 ---") - try: - result = db_instance.run(state['sql_query']) - state['execution_result'] = str(result) - state['validation_error_count'] = 0 # sql 검증 횟수 초기화 - state['execution_error_count'] = 0 # sql 실행 횟수 초기화 - except Exception as e: - state['execution_result'] = f"실행 오류: {e}" - state['validation_error_count'] = 0 # sql 검증 횟수 초기화 - state['execution_error_count'] += 1 # sql 실행 횟수 추가 - return state - -def response_synthesizer_node(state: SqlAgentState): - print("--- 4. 최종 답변 생성 중 ---") - - is_failure = state.get('validation_error_count', 0) >= MAX_ERROR_COUNT or \ - state.get('execution_error_count', 0) >= MAX_ERROR_COUNT - - if is_failure: - if state.get('validation_error_count', 0) >= MAX_ERROR_COUNT: - error_type = "SQL 검증" - error_details = state.get('validation_error') - else: - error_type = "SQL 실행" - error_details = state.get('execution_result') - - context_message = f""" - An attempt to answer the user's question failed after multiple retries. - Failure Type: {error_type} - Last Error: {error_details} - """ - else: - context_message = f""" - Successfully executed the SQL query to answer the user's question. - SQL Query: {state['sql_query']} - SQL Result: {state['execution_result']} - """ - - prompt = RESPONSE_SYNTHESIZER_PROMPT.format( - question=state['question'], - chat_history=state['chat_history'], - context_message=context_message - ) - response = llm_instance.invoke(prompt) - state['final_response'] = response.content - return state - -# --- 엣지 함수 정의 --- -def route_after_intent_classification(state: SqlAgentState): - if state['intent'] == "SQL": - print("--- 의도: SQL 관련 질문 ---") - return "db_classifier" - print("--- 의도: SQL과 관련 없는 질문 ---") - return "unsupported_question" - -def should_execute_sql(state: SqlAgentState): - if state.get("validation_error_count", 0) >= MAX_ERROR_COUNT: - print(f"--- 검증 실패 {MAX_ERROR_COUNT}회 초과: 답변 생성으로 이동 ---") - return "synthesize_failure" - if state.get("validation_error"): - print(f"--- 검증 실패 {state['validation_error_count']}회: SQL 재생성 ---") - return "regenerate" - print("--- 검증 성공: SQL 실행 ---") - return "execute" - -def should_retry_or_respond(state: SqlAgentState): - if state.get("execution_error_count", 0) >= MAX_ERROR_COUNT: - print(f"--- 실행 실패 {MAX_ERROR_COUNT}회 초과: 답변 생성으로 이동 ---") - return "synthesize_failure" - if "오류" in (state.get("execution_result") or ""): - print(f"--- 실행 실패 {state['execution_error_count']}회: SQL 재생성 ---") - return "regenerate" - print("--- 실행 성공: 최종 답변 생성 ---") - return "synthesize_success" - -# --- 그래프 구성 --- -def create_sql_agent_graph() -> StateGraph: - graph = StateGraph(SqlAgentState) - - graph.add_node("intent_classifier", intent_classifier_node) - graph.add_node("db_classifier", db_classifier_node) - graph.add_node("unsupported_question", unsupported_question_node) - graph.add_node("sql_generator", sql_generator_node) - graph.add_node("sql_validator", sql_validator_node) - graph.add_node("sql_executor", sql_executor_node) - graph.add_node("response_synthesizer", response_synthesizer_node) - - graph.set_entry_point("intent_classifier") - - graph.add_conditional_edges( - "intent_classifier", - route_after_intent_classification, - { - "db_classifier": "db_classifier", - "unsupported_question": "unsupported_question" - } - ) - graph.add_edge("unsupported_question", END) - - graph.add_edge("db_classifier", "sql_generator") - - graph.add_edge("sql_generator", "sql_validator") - - graph.add_conditional_edges("sql_validator", should_execute_sql, { - "regenerate": "sql_generator", - "execute": "sql_executor", - "synthesize_failure": "response_synthesizer" - }) - graph.add_conditional_edges("sql_executor", should_retry_or_respond, { - "regenerate": "sql_generator", - "synthesize_success": "response_synthesizer", - "synthesize_failure": "response_synthesizer" - }) - graph.add_edge("response_synthesizer", END) - - return graph.compile() - -sql_agent_app = create_sql_agent_graph() - -# 워크 플로우 그림 작성 -# graph_image_bytes = sql_agent_app.get_graph(xray=True).draw_mermaid_png() -# with open("workflow_graph.png", "wb") as f: -# f.write(graph_image_bytes) \ No newline at end of file diff --git a/src/api/__init__.py b/src/api/__init__.py new file mode 100644 index 0000000..365cd08 --- /dev/null +++ b/src/api/__init__.py @@ -0,0 +1,6 @@ +""" +API 패키지 루트 +""" + + + diff --git a/src/api/v1/__init__.py b/src/api/v1/__init__.py new file mode 100644 index 0000000..cd0ff49 --- /dev/null +++ b/src/api/v1/__init__.py @@ -0,0 +1,6 @@ +""" +API v1 패키지 +""" + + + diff --git a/src/api/v1/endpoints/annotator.py b/src/api/v1/endpoints/annotator.py deleted file mode 100644 index 508fb94..0000000 --- a/src/api/v1/endpoints/annotator.py +++ /dev/null @@ -1,24 +0,0 @@ - -from fastapi import APIRouter, Depends -from core.llm_provider import llm_instance -from services.annotation_service import AnnotationService -from api.v1.schemas.annotator_schemas import AnnotationRequest, AnnotationResponse - -router = APIRouter() - -# AnnotationService 인스턴스를 싱글턴으로 관리 -annotation_service_instance = AnnotationService(llm=llm_instance) - -def get_annotation_service(): - """의존성 주입을 통해 AnnotationService 인스턴스를 제공합니다.""" - return annotation_service_instance - -@router.post("/annotator", response_model=AnnotationResponse) -async def create_annotations( - request: AnnotationRequest, - service: AnnotationService = Depends(get_annotation_service) -): - """ - DB 스키마 정보를 받아 각 요소에 대한 설명을 비동기적으로 생성하여 반환합니다. - """ - return await service.generate_for_schema(request) diff --git a/src/api/v1/endpoints/chat.py b/src/api/v1/endpoints/chat.py deleted file mode 100644 index 69ee393..0000000 --- a/src/api/v1/endpoints/chat.py +++ /dev/null @@ -1,28 +0,0 @@ -# src/api/v1/endpoints/chat.py - -from fastapi import APIRouter, Depends -from api.v1.schemas.chatbot_schemas import ChatRequest, ChatResponse -from services.chatbot_service import ChatbotService - -router = APIRouter() - -def get_chatbot_service(): - return ChatbotService() - -@router.post("/chat", response_model=ChatResponse) -def handle_chat_request( - request: ChatRequest, - service: ChatbotService = Depends(get_chatbot_service) -): - """ - 사용자의 채팅 요청을 받아 챗봇의 답변을 반환합니다. - Args: - request: 챗봇 요청 - service: 챗봇 서비스 로직 - - Returns: - ChatRespone: 챗봇 응답 - """ - final_answer = service.handle_request(request.question, request.chat_history) - - return ChatResponse(answer=final_answer) \ No newline at end of file diff --git a/src/api/v1/routers/annotator.py b/src/api/v1/routers/annotator.py new file mode 100644 index 0000000..e918ce8 --- /dev/null +++ b/src/api/v1/routers/annotator.py @@ -0,0 +1,67 @@ +# src/api/v1/routers/annotator.py + +from fastapi import APIRouter, HTTPException, Depends +from typing import Dict, Any + +from schemas.api.annotator_schemas import AnnotationRequest, AnnotationResponse +from services.annotation.annotation_service import AnnotationService, get_annotation_service +import logging + +logger = logging.getLogger(__name__) + +router = APIRouter() + +@router.post("/annotator", response_model=AnnotationResponse) +async def create_annotations( + request: AnnotationRequest, + service: AnnotationService = Depends(get_annotation_service) +) -> AnnotationResponse: + """ + DB 스키마 정보를 받아 각 요소에 대한 설명을 비동기적으로 생성하여 반환합니다. + + Args: + request: 어노테이션 요청 (DB 스키마 정보) + service: 어노테이션 서비스 로직 + + Returns: + AnnotationResponse: 어노테이션이 추가된 스키마 정보 + + Raises: + HTTPException: 요청 처리 실패 시 + """ + try: + logger.info(f"Received annotation request for {len(request.databases)} databases") + + response = await service.generate_for_schema(request) + + logger.info("Annotation request processed successfully") + + return response + + except Exception as e: + logger.error(f"Annotation request failed: {e}") + raise HTTPException( + status_code=500, + detail=f"어노테이션 생성 중 오류가 발생했습니다: {e}" + ) + +@router.get("/annotator/health") +async def annotator_health_check( + service: AnnotationService = Depends(get_annotation_service) +) -> Dict[str, Any]: + """ + 어노테이션 서비스의 상태를 확인합니다. + + Returns: + Dict: 서비스 상태 정보 + """ + try: + health_status = await service.health_check() + return health_status + + except Exception as e: + logger.error(f"Annotator health check failed: {e}") + return { + "status": "unhealthy", + "error": str(e) + } diff --git a/src/api/v1/routers/chat.py b/src/api/v1/routers/chat.py new file mode 100644 index 0000000..113196b --- /dev/null +++ b/src/api/v1/routers/chat.py @@ -0,0 +1,91 @@ +# src/api/v1/routers/chat.py + +from fastapi import APIRouter, HTTPException, Depends +from typing import Dict, Any, List + +from schemas.api.chat_schemas import ChatRequest, ChatResponse +from services.chat.chatbot_service import ChatbotService, get_chatbot_service +import logging + +logger = logging.getLogger(__name__) + +router = APIRouter() + +@router.post("/chat", response_model=ChatResponse) +async def handle_chat_request( + request: ChatRequest, + service: ChatbotService = Depends(get_chatbot_service) +) -> ChatResponse: + """ + 사용자의 채팅 요청을 받아 챗봇의 답변을 반환합니다. + + Args: + request: 챗봇 요청 (질문과 채팅 히스토리) + service: 챗봇 서비스 로직 + + Returns: + ChatResponse: 챗봇 응답 + + Raises: + HTTPException: 요청 처리 실패 시 + """ + try: + logger.info(f"Received chat request: {request.question[:100]}...") + + final_answer = await service.handle_request( + user_question=request.question, + chat_history=request.chat_history + ) + + logger.info("Chat request processed successfully") + + return ChatResponse(answer=final_answer) + + except Exception as e: + logger.error(f"Chat request failed: {e}") + raise HTTPException( + status_code=500, + detail=f"채팅 요청 처리 중 오류가 발생했습니다: {e}" + ) + +@router.get("/chat/health") +async def chat_health_check( + service: ChatbotService = Depends(get_chatbot_service) +) -> Dict[str, Any]: + """ + 챗봇 서비스의 상태를 확인합니다. + + Returns: + Dict: 서비스 상태 정보 + """ + try: + health_status = await service.health_check() + return health_status + + except Exception as e: + logger.error(f"Chat health check failed: {e}") + return { + "status": "unhealthy", + "error": str(e) + } + +@router.get("/chat/databases") +async def get_available_databases( + service: ChatbotService = Depends(get_chatbot_service) +) -> Dict[str, List[Dict[str, str]]]: + """ + 사용 가능한 데이터베이스 목록을 반환합니다. + + Returns: + Dict: 데이터베이스 목록 + """ + try: + databases = await service.get_available_databases() + return {"databases": databases} + + except Exception as e: + logger.error(f"Failed to get databases: {e}") + raise HTTPException( + status_code=500, + detail=f"데이터베이스 목록 조회 중 오류가 발생했습니다: {e}" + ) diff --git a/src/api/v1/routers/health.py b/src/api/v1/routers/health.py new file mode 100644 index 0000000..1ea2ce0 --- /dev/null +++ b/src/api/v1/routers/health.py @@ -0,0 +1,77 @@ +# src/api/v1/routers/health.py + +from fastapi import APIRouter, Depends +from typing import Dict, Any + +from services.chat.chatbot_service import ChatbotService, get_chatbot_service +from services.annotation.annotation_service import AnnotationService, get_annotation_service +from services.database.database_service import DatabaseService, get_database_service +import logging + +logger = logging.getLogger(__name__) + +router = APIRouter() + +@router.get("/health") +async def root_health_check() -> Dict[str, str]: + """ + 루트 헬스체크 엔드포인트, 서버 상태가 정상이면 'ok' 반환합니다. + + Returns: + Dict: 기본 상태 정보 + """ + return { + "status": "ok", + "message": "Welcome to the QGenie Chatbot AI!", + "version": "2.0.0" + } + +@router.get("/health/detailed") +async def detailed_health_check( + chatbot_service: ChatbotService = Depends(get_chatbot_service), + annotation_service: AnnotationService = Depends(get_annotation_service), + database_service: DatabaseService = Depends(get_database_service) +) -> Dict[str, Any]: + """ + 전체 시스템의 상세 헬스체크를 수행합니다. + + Returns: + Dict: 상세 상태 정보 + """ + try: + # 모든 서비스의 헬스체크를 병렬로 실행 + import asyncio + + chatbot_health, annotation_health, database_health = await asyncio.gather( + chatbot_service.health_check(), + annotation_service.health_check(), + database_service.health_check(), + return_exceptions=True + ) + + # 각 서비스 상태 처리 + services_status = { + "chatbot": chatbot_health if not isinstance(chatbot_health, Exception) else {"status": "unhealthy", "error": str(chatbot_health)}, + "annotation": annotation_health if not isinstance(annotation_health, Exception) else {"status": "unhealthy", "error": str(annotation_health)}, + "database": {"status": "healthy" if database_health and not isinstance(database_health, Exception) else "unhealthy"} + } + + # 전체 상태 결정 + all_healthy = all( + service.get("status") == "healthy" + for service in services_status.values() + ) + + return { + "status": "healthy" if all_healthy else "partial", + "services": services_status, + "timestamp": __import__("datetime").datetime.now().isoformat() + } + + except Exception as e: + logger.error(f"Detailed health check failed: {e}") + return { + "status": "unhealthy", + "error": str(e), + "timestamp": __import__("datetime").datetime.now().isoformat() + } diff --git a/src/core/__init__.py b/src/core/__init__.py index e69de29..9c22b0b 100644 --- a/src/core/__init__.py +++ b/src/core/__init__.py @@ -0,0 +1,15 @@ +# src/core/__init__.py + +""" +코어 모듈 - 기본 인프라스트럭처 구성 요소들 +""" + +from .providers.llm_provider import LLMProvider, get_llm_provider +from .clients.api_client import APIClient, get_api_client + +__all__ = [ + 'LLMProvider', + 'get_llm_provider', + 'APIClient', + 'get_api_client' +] diff --git a/src/core/clients/api_client.py b/src/core/clients/api_client.py new file mode 100644 index 0000000..1c2c9a4 --- /dev/null +++ b/src/core/clients/api_client.py @@ -0,0 +1,261 @@ +# src/core/clients/api_client.py + +import httpx +import asyncio +from typing import List, Dict, Any, Optional, Union +from pydantic import BaseModel +import logging + +# 로깅 설정 +logger = logging.getLogger(__name__) + +class DatabaseInfo(BaseModel): + """데이터베이스 정보 모델""" + connection_name: str + database_name: str + description: str + +class QueryExecutionRequest(BaseModel): + """쿼리 실행 요청 모델""" + user_db_id: str + database: str + query_text: str + +class QueryResultData(BaseModel): + """쿼리 실행 결과 데이터 모델""" + columns: List[str] + data: List[Dict[str, Any]] + +class QueryExecutionResponse(BaseModel): + """쿼리 실행 응답 모델""" + code: str + message: str + data: Union[QueryResultData, str, bool] # 결과 데이터, 에러 메시지 + +class APIClient: + """백엔드 API와 통신하는 클라이언트 클래스""" + + def __init__(self, base_url: str = "http://localhost:39722"): + self.base_url = base_url + self.timeout = httpx.Timeout(30.0) + self.headers = { + "Content-Type": "application/json" + } + self._client: Optional[httpx.AsyncClient] = None + + async def _get_client(self) -> httpx.AsyncClient: + """재사용 가능한 HTTP 클라이언트를 반환합니다.""" + if self._client is None or self._client.is_closed: + self._client = httpx.AsyncClient(timeout=self.timeout) + return self._client + + async def close(self): + """HTTP 클라이언트 연결을 닫습니다.""" + if self._client and not self._client.is_closed: + await self._client.aclose() + # TODO: DB 어노테이션 조회 + async def get_available_databases(self) -> List[DatabaseInfo]: + """사용 가능한 데이터베이스 목록을 가져옵니다.""" + try: + client = await self._get_client() + response = await client.get( + f"{self.base_url}/api/v1/databases", + headers=self.headers + ) + response.raise_for_status() + + data = response.json() + databases = [DatabaseInfo(**db) for db in data.get("databases", [])] + logger.info(f"Successfully fetched {len(databases)} databases") + return databases + + except httpx.HTTPStatusError as e: + logger.error(f"HTTP error occurred: {e.response.status_code} - {e.response.text}") + raise + except httpx.RequestError as e: + logger.error(f"Request error occurred: {e}") + raise + except Exception as e: + logger.error(f"Unexpected error: {e}") + raise + # TODO: DB 스키마 조회 API 필요 + async def get_database_schema(self, database_name: str) -> str: + """특정 데이터베이스의 스키마 정보를 가져옵니다.""" + try: + client = await self._get_client() + response = await client.get( + f"{self.base_url}/api/v1/databases/{database_name}/schema", + headers=self.headers + ) + response.raise_for_status() + + data = response.json() + schema = data.get("schema", "") + logger.info(f"Successfully fetched schema for database: {database_name}") + return schema + + except httpx.HTTPStatusError as e: + logger.error(f"HTTP error occurred: {e.response.status_code} - {e.response.text}") + raise + except httpx.RequestError as e: + logger.error(f"Request error occurred: {e}") + raise + except Exception as e: + logger.error(f"Unexpected error: {e}") + raise + + async def execute_query( + self, + sql_query: str, + database_name: str, + user_db_id: str = None + ) -> QueryExecutionResponse: + """SQL 쿼리를 Backend 서버에 전송하여 실행하고 결과를 받아옵니다.""" + try: + logger.info(f"Sending SQL query to backend: {sql_query}") + + request_data = QueryExecutionRequest( + user_db_id=user_db_id, + database=database_name, + query_text=sql_query + ) + + client = await self._get_client() + response = await client.post( + f"{self.base_url}/api/query/execute/test", + json=request_data.model_dump(), + headers=self.headers, + timeout=httpx.Timeout(35.0) # 고정 타임아웃 + ) + + response.raise_for_status() # HTTP 에러 시 예외 발생 + + response_data = response.json() + + # data 필드 타입에 따라 처리 + raw_data = response_data.get("data") + parsed_data = raw_data + + # data가 객체 형태(쿼리 결과)인지 확인 + if isinstance(raw_data, dict) and "columns" in raw_data and "data" in raw_data: + try: + parsed_data = QueryResultData(**raw_data) + except Exception as e: + logger.warning(f"Failed to parse query result data: {e}, using raw data") + parsed_data = raw_data + + result = QueryExecutionResponse( + code=response_data.get("code"), + message=response_data.get("message"), + data=parsed_data + ) + + if result.code == "2400": + logger.info(f"Query executed successfully: {result.message}") + else: + logger.warning(f"Query execution returned non-success code: {result.code} - {result.message}") + + return result + + except httpx.TimeoutException: + logger.error("Backend API 요청 시간 초과") + raise + except httpx.ConnectError: + logger.error("Backend 서버 연결 실패") + raise + except httpx.HTTPStatusError as e: + logger.error(f"HTTP error occurred: {e.response.status_code} - {e.response.text}") + raise + except Exception as e: + logger.error(f"Unexpected error during query execution: {e}") + raise + + async def health_check(self) -> bool: + """API 서버 상태를 확인합니다.""" + try: + client = await self._get_client() + response = await client.get( + f"{self.base_url}/health", + timeout=httpx.Timeout(5.0) + ) + return response.status_code == 200 + except Exception as e: + logger.error(f"Health check failed: {e}") + return False + + async def get_openai_api_key(self) -> str: + """백엔드에서 OpenAI API 키를 가져옵니다.""" + try: + client = await self._get_client() + + # 1단계: 암호화된 API 키 조회 + response = await client.get( + f"{self.base_url}/api/keys/result", + headers=self.headers, + timeout=httpx.Timeout(10.0) + ) + response.raise_for_status() + + data = response.json() + + # data 배열에서 OpenAI 서비스 찾기 + api_keys = data.get("data", []) + openai_key = None + + # 가장 첫번째 OpenAI 키 사용 + for key_info in api_keys: + if key_info.get("service_name") == "OpenAI": + openai_key = key_info.get("id") + break + + if not openai_key: + raise ValueError("백엔드에서 OpenAI API 키를 찾을 수 없습니다.") + + # 2단계: 복호화된 실제 API 키 조회 + decrypt_response = await client.get( + f"{self.base_url}/api/keys/find/decrypted/OpenAI", + headers=self.headers, + timeout=httpx.Timeout(10.0) + ) + decrypt_response.raise_for_status() + + decrypt_data = decrypt_response.json() + + # 복호화된 키 데이터에서 실제 API 키 추출 + data_field = decrypt_data.get("data", {}) + + if isinstance(data_field, dict) and "api_key" in data_field: + actual_api_key = data_field["api_key"] + else: + raise ValueError("백엔드 응답에서 API 키를 찾을 수 없습니다.") + + if not actual_api_key: + raise ValueError("백엔드에서 복호화된 OpenAI API 키를 가져올 수 없습니다.") + + logger.info("Successfully fetched decrypted OpenAI API key from backend") + return actual_api_key + + except httpx.HTTPStatusError as e: + logger.error(f"HTTP error occurred while fetching API key: {e.response.status_code} - {e.response.text}") + raise + except httpx.RequestError as e: + logger.error(f"Request error occurred while fetching API key: {e}") + raise + except Exception as e: + logger.error(f"Unexpected error while fetching API key: {e}") + raise + + async def __aenter__(self): + """비동기 컨텍스트 매니저 진입""" + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + """비동기 컨텍스트 매니저 종료""" + await self.close() + +# 싱글톤 인스턴스 +_api_client = APIClient() + +async def get_api_client() -> APIClient: + """API Client 인스턴스를 반환합니다.""" + return _api_client diff --git a/src/core/db_manager.py b/src/core/db_manager.py deleted file mode 100644 index c06ee1a..0000000 --- a/src/core/db_manager.py +++ /dev/null @@ -1,40 +0,0 @@ -# src/core/db_manager.py - -from langchain_community.utilities import SQLDatabase -import os -from dotenv import load_dotenv - -load_dotenv() - -def get_db_connection() -> SQLDatabase: - """SQLDatabase 객체를 생성하고 반환합니다. - - Args: - - Returns: - SQLDatabase: DB와 연결된 SQLDatabase 객체 - """ - db_uri = os.getenv("MYSQL_URI") - if not db_uri: - raise ValueError("DATABASE_URI 환경 변수가 .env 파일에 설정되지 않았습니다.") - return SQLDatabase.from_uri(db_uri) - -def load_predefined_schema(db: SQLDatabase) -> str: - """SQLDatabase 객체를 사용하여 모든 테이블의 스키마 정보를 반환합니다. - - Args: - db: DB와 연결된 SQLDatabase 객체 - - Returns: - str: DB에 포함된 모든 Table의 schema - - """ - try: - all_table_names = db.get_usable_table_names() - return db.get_table_info(table_names=all_table_names) - except Exception as e: - return f"스키마 조회 중 오류 발생: {e}" - -# 앱 전체에서 동일한 객체를 참조(싱글턴 패턴) -db_instance = get_db_connection() -schema_instance = load_predefined_schema(db_instance) diff --git a/src/core/llm_provider.py b/src/core/llm_provider.py deleted file mode 100644 index 3a5a94c..0000000 --- a/src/core/llm_provider.py +++ /dev/null @@ -1,30 +0,0 @@ -# src/core/llm_provider.py - -import os -from langchain_openai import ChatOpenAI -from dotenv import load_dotenv - -load_dotenv() - -def get_llm() -> ChatOpenAI: - """ 사전 설정된 ChatOpenAI 인스턴스를 생성하고 반환합니다. - Prams: - - Returns: - llm: 생성된 ChatOpenAI 객체 - """ - # 환경 변수에서 OpenAI API 키를 가져옵니다. - api_key = os.getenv("OPENAI_API_KEY") - if not api_key: - raise ValueError("OPENAI_API_KEY 환경 변수가 설정되지 않았습니다.") - - # 기본값으로 gpt-4o-mini 모델 사용 - llm = ChatOpenAI( - model="gpt-4o-mini", - temperature=0, - api_key=api_key - ) - - return llm - -llm_instance = get_llm() \ No newline at end of file diff --git a/src/core/providers/llm_provider.py b/src/core/providers/llm_provider.py new file mode 100644 index 0000000..80e4c60 --- /dev/null +++ b/src/core/providers/llm_provider.py @@ -0,0 +1,88 @@ +# src/core/providers/llm_provider.py + +import os +import asyncio +import logging +from typing import Optional +from langchain_openai import ChatOpenAI +from core.clients.api_client import get_api_client + +logger = logging.getLogger(__name__) + +class LLMProvider: + """LLM 제공자를 관리하는 클래스""" + + def __init__(self, model_name: str = "gpt-4o-mini", temperature: float = 0): + self.model_name = model_name + self.temperature = temperature + self._llm: Optional[ChatOpenAI] = None + self._api_key: Optional[str] = None + self._api_client = None + + async def _load_api_key(self) -> str: + """백엔드에서 OpenAI API 키를 로드합니다.""" + try: + if self._api_key is None: + if self._api_client is None: + self._api_client = await get_api_client() + + self._api_key = await self._api_client.get_openai_api_key() + return self._api_key + + except Exception as e: + logger.error(f"Failed to fetch API key from backend: {e}") + raise ValueError("백엔드에서 OpenAI API 키를 가져올 수 없습니다. 백엔드 서버를 확인해주세요.") + + async def get_llm(self) -> ChatOpenAI: + """LLM 인스턴스를 비동기적으로 반환합니다.""" + if self._llm is None: + self._llm = await self._create_llm() + return self._llm + + async def _create_llm(self) -> ChatOpenAI: + """ChatOpenAI 인스턴스를 생성합니다.""" + try: + # API 키를 비동기적으로 로드 + api_key = await self._load_api_key() + logger.info("✅ 백엔드에서 OpenAI API 키를 성공적으로 가져왔습니다") + + llm = ChatOpenAI( + model=self.model_name, + temperature=self.temperature, + api_key=api_key + ) + return llm + + except Exception as e: + raise RuntimeError(f"LLM 인스턴스 생성 실패: {e}") + + def update_model(self, model_name: str, temperature: float = None): + """모델 설정을 업데이트하고 인스턴스를 재생성합니다.""" + self.model_name = model_name + if temperature is not None: + self.temperature = temperature + self._llm = None # 다음 호출 시 재생성되도록 함 + + async def refresh_api_key(self): + """API 키를 새로고침합니다.""" + self._api_key = None + self._llm = None # LLM 인스턴스도 재생성 + logger.info("API key refreshed") + + async def test_connection(self) -> bool: + """LLM 연결을 테스트합니다.""" + try: + llm = await self.get_llm() + test_response = await llm.ainvoke("테스트") + return test_response is not None + + except Exception as e: + print(f"LLM 연결 테스트 실패: {e}") + return False + +# 싱글톤 인스턴스 +_llm_provider = LLMProvider() + +async def get_llm_provider() -> LLMProvider: + """LLM Provider 인스턴스를 반환합니다.""" + return _llm_provider diff --git a/src/health_check/__init__.py b/src/health_check/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/src/health_check/router.py b/src/health_check/router.py deleted file mode 100644 index 8fbbb2e..0000000 --- a/src/health_check/router.py +++ /dev/null @@ -1,9 +0,0 @@ -# src/health_check/router.py -from flask import Flask, jsonify - -app = Flask(__name__) - -@app.route("/health") -def health_check(): - """헬스체크 엔드포인트, 서버 상태가 정상이면 'ok'를 반환합니다.""" - return jsonify(status="ok"), 200 \ No newline at end of file diff --git a/src/main.py b/src/main.py index b25548d..d023e72 100644 --- a/src/main.py +++ b/src/main.py @@ -1,50 +1,100 @@ # src/main.py -import socket -from contextlib import closing -import uvicorn +import logging +from contextlib import asynccontextmanager from fastapi import FastAPI -from api.v1.endpoints import chat, annotator -def find_free_port(): - """사용 가능한 비어있는 포트를 찾는 함수""" - with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s: - s.bind(('', 0)) - s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - return s.getsockname()[1] +from api.v1.routers import chat, annotator, health + +# 로깅 설정 +logging.basicConfig( + level=logging.INFO, + format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' +) +logger = logging.getLogger(__name__) + +@asynccontextmanager +async def lifespan(app: FastAPI): + """애플리케이션 라이프사이클 관리""" + logger.info("QGenie AI Chatbot 시작 중...") + + # 시작 시 초기화 작업 + try: + # 필요한 경우 여기에 초기화 로직 추가 + logger.info("애플리케이션 초기화 완료") + yield + finally: + # 종료 시 정리 작업 + logger.info("애플리케이션 종료 중...") + + # API 클라이언트 정리 + try: + from core.clients.api_client import get_api_client + api_client = await get_api_client() + await api_client.close() + logger.info("API 클라이언트 정리 완료") + except Exception as e: + logger.error(f"API 클라이언트 정리 실패: {e}") + + logger.info("애플리케이션 종료 완료") # FastAPI 앱 인스턴스 생성 app = FastAPI( - title="Qgenie - Agentic SQL Chatbot", - description="LangGraph로 구현된 사전 스키마를 지원하는 SQL 챗봇", - version="1.0.0" + title="QGenie - Agentic SQL Chatbot", + description="LangGraph로 구현된 사전 스키마를 지원하는 SQL 챗봇 (리팩터링 버전)", + version="2.0.0", + lifespan=lifespan +) + +# 라우터 등록 +app.include_router( + health.router, + prefix="/api/v1", + tags=["Health"] ) -# '/api/v1' 경로에 chat 라우터 포함 app.include_router( chat.router, prefix="/api/v1", tags=["Chatbot"] ) -# '/api/v1' 경로에 annotator 라우터 포함 app.include_router( annotator.router, prefix="/api/v1", tags=["Annotator"] ) +# 루트 엔드포인트 @app.get("/") -def health_check(): - """헬스체크 엔드포인트, 서버 상태가 정상이면 'ok' 반환합니다.""" - return {"status": "ok", "message": "Welcome to the QGenie Chatbot AI!"} +async def root(): + """루트 엔드포인트 - 기본 상태 확인""" + return { + "status": "ok", + "message": "Welcome to the QGenie Chatbot AI! (Refactored)", + "version": "2.0.0", + "endpoints": { + "chat": "/api/v1/chat", + "annotator": "/api/v1/annotator", + "health": "/api/v1/health", + "detailed_health": "/api/v1/health/detailed" + } + } if __name__ == "__main__": - # 1. 비어있는 포트 동적 할당 - free_port = find_free_port() - - # 2. 할당된 포트 번호를 콘솔에 특정 형식으로 출력 + import uvicorn + + # 포트 번호 고정 (기존 설정 유지) + free_port = 35816 + + # 할당된 포트 번호를 콘솔에 특정 형식으로 출력 (Electron 연동을 위해) print(f"PYTHON_SERVER_PORT:{free_port}") - - # 3. 할당된 포트로 FastAPI 서버 실행 - uvicorn.run(app, host="127.0.0.1", port=free_port, reload=False) \ No newline at end of file + + # FastAPI 서버 실행 + uvicorn.run( + app, + host="127.0.0.1", + port=free_port, + reload=False, + log_level="info" + ) \ No newline at end of file diff --git a/src/prompts/v1/sql_agent/intent_classifier.yaml b/src/prompts/v1/sql_agent/intent_classifier.yaml index 1c2be08..9aa3a34 100644 --- a/src/prompts/v1/sql_agent/intent_classifier.yaml +++ b/src/prompts/v1/sql_agent/intent_classifier.yaml @@ -2,6 +2,7 @@ _type: prompt input_variables: - question + - chat_history template: | You are an intelligent assistant responsible for classifying user questions. Your task is to determine whether a user's question is related to retrieving information from a database using SQL. @@ -9,6 +10,9 @@ template: | - If the question can be answered with a SQL query, respond with "SQL". - If the question is a simple greeting, a question about your identity, or anything that does not require database access, respond with "non-SQL". + Consider the chat history context when classifying the current question. + If the current question is a follow-up or continuation of a previous SQL-related conversation, classify it as "SQL". + Example 1: Question: "Show me the list of users who signed up last month." Classification: SQL @@ -25,6 +29,11 @@ template: | Question: "What is the weather like today?" Classification: non-SQL - Now, classify the following question: - Question: {question} + Example 5 (Follow-up): + Previous: "Show me sales data for January" + Current: "How about February?" + Classification: SQL (continuation of data query) + + Chat History: {chat_history} + Current Question: {question} Classification: diff --git a/src/schemas/__init__.py b/src/schemas/__init__.py new file mode 100644 index 0000000..2aec2f4 --- /dev/null +++ b/src/schemas/__init__.py @@ -0,0 +1,6 @@ +""" +스키마 루트 패키지 +""" + + + diff --git a/src/schemas/sql_schemas.py b/src/schemas/agent/sql_schemas.py similarity index 55% rename from src/schemas/sql_schemas.py rename to src/schemas/agent/sql_schemas.py index 2201103..0fcbc2b 100644 --- a/src/schemas/sql_schemas.py +++ b/src/schemas/agent/sql_schemas.py @@ -1,6 +1,7 @@ -# src/schemas/sql_schemas.py +# src/schemas/agent/sql_schemas.py + from pydantic import BaseModel, Field class SqlQuery(BaseModel): """SQL 쿼리를 나타내는 Pydantic 모델""" - query: str = Field(description="생성된 SQL 쿼리") \ No newline at end of file + query: str = Field(description="생성된 SQL 쿼리") diff --git a/src/schemas/api/__init__.py b/src/schemas/api/__init__.py new file mode 100644 index 0000000..93577a6 --- /dev/null +++ b/src/schemas/api/__init__.py @@ -0,0 +1,6 @@ +""" +API 스키마 패키지 +""" + + + diff --git a/src/api/v1/schemas/annotator_schemas.py b/src/schemas/api/annotator_schemas.py similarity index 63% rename from src/api/v1/schemas/annotator_schemas.py rename to src/schemas/api/annotator_schemas.py index 7db3174..6dda6d7 100644 --- a/src/api/v1/schemas/annotator_schemas.py +++ b/src/schemas/api/annotator_schemas.py @@ -1,47 +1,60 @@ -# src/api/v1/schemas/annotator_schemas.py +# src/schemas/api/annotator_schemas.py from pydantic import BaseModel, Field from typing import List, Dict, Any class Column(BaseModel): + """데이터베이스 컬럼 모델""" column_name: str data_type: str class Table(BaseModel): + """데이터베이스 테이블 모델""" table_name: str columns: List[Column] sample_rows: List[Dict[str, Any]] class Relationship(BaseModel): + """테이블 관계 모델""" from_table: str from_columns: List[str] to_table: str to_columns: List[str] class Database(BaseModel): + """데이터베이스 모델""" database_name: str tables: List[Table] relationships: List[Relationship] class AnnotationRequest(BaseModel): + """어노테이션 요청 모델""" dbms_type: str databases: List[Database] -class AnnotatedColumn(Column): +class AnnotatedColumn(BaseModel): + """어노테이션이 추가된 컬럼 모델""" + column_name: str description: str = Field(..., description="AI가 생성한 컬럼 설명") -class AnnotatedTable(Table): +class AnnotatedTable(BaseModel): + """어노테이션이 추가된 테이블 모델""" + table_name: str description: str = Field(..., description="AI가 생성한 테이블 설명") columns: List[AnnotatedColumn] class AnnotatedRelationship(Relationship): + """어노테이션이 추가된 관계 모델""" description: str = Field(..., description="AI가 생성한 관계 설명") -class AnnotatedDatabase(Database): +class AnnotatedDatabase(BaseModel): + """어노테이션이 추가된 데이터베이스 모델""" + database_name: str description: str = Field(..., description="AI가 생성한 데이터베이스 설명") tables: List[AnnotatedTable] relationships: List[AnnotatedRelationship] class AnnotationResponse(BaseModel): + """어노테이션 응답 모델""" dbms_type: str databases: List[AnnotatedDatabase] diff --git a/src/api/v1/schemas/chatbot_schemas.py b/src/schemas/api/chat_schemas.py similarity index 79% rename from src/api/v1/schemas/chatbot_schemas.py rename to src/schemas/api/chat_schemas.py index f3ae457..c18d956 100644 --- a/src/api/v1/schemas/chatbot_schemas.py +++ b/src/schemas/api/chat_schemas.py @@ -1,4 +1,4 @@ -# src/api/v1/schemas/chatbot_schemas.py +# src/schemas/api/chat_schemas.py from pydantic import BaseModel from typing import List, Optional @@ -9,8 +9,10 @@ class ChatMessage(BaseModel): content: str class ChatRequest(BaseModel): + """채팅 요청 모델""" question: str chat_history: Optional[List[ChatMessage]] = None class ChatResponse(BaseModel): + """채팅 응답 모델""" answer: str diff --git a/src/services/__init__.py b/src/services/__init__.py new file mode 100644 index 0000000..5d88207 --- /dev/null +++ b/src/services/__init__.py @@ -0,0 +1,18 @@ +# src/services/__init__.py + +""" +서비스 계층 - 비즈니스 로직 구성 요소들 +""" + +from .chat.chatbot_service import ChatbotService, get_chatbot_service +from .annotation.annotation_service import AnnotationService, get_annotation_service +from .database.database_service import DatabaseService, get_database_service + +__all__ = [ + 'ChatbotService', + 'get_chatbot_service', + 'AnnotationService', + 'get_annotation_service', + 'DatabaseService', + 'get_database_service' +] diff --git a/src/services/annotation/__init__.py b/src/services/annotation/__init__.py new file mode 100644 index 0000000..85f7091 --- /dev/null +++ b/src/services/annotation/__init__.py @@ -0,0 +1,6 @@ +""" +어노테이션 서비스 패키지 +""" + + + diff --git a/src/services/annotation/annotation_service.py b/src/services/annotation/annotation_service.py new file mode 100644 index 0000000..356e3f9 --- /dev/null +++ b/src/services/annotation/annotation_service.py @@ -0,0 +1,260 @@ +# src/services/annotation/annotation_service.py + +import asyncio +from typing import List, Dict, Any +from langchain_core.prompts import ChatPromptTemplate +from langchain_core.output_parsers import StrOutputParser + +from schemas.api.annotator_schemas import ( + AnnotationRequest, AnnotationResponse, + Database, Table, Column, Relationship, + AnnotatedDatabase, AnnotatedTable, AnnotatedColumn, AnnotatedRelationship +) +from core.providers.llm_provider import LLMProvider, get_llm_provider +import logging + +logger = logging.getLogger(__name__) + +class AnnotationService: + """어노테이션 생성과 관련된 모든 비즈니스 로직을 담당하는 서비스 클래스""" + + def __init__(self, llm_provider: LLMProvider = None): + self.llm_provider = llm_provider + + async def _initialize_dependencies(self): + """필요한 의존성들을 초기화합니다.""" + if self.llm_provider is None: + self.llm_provider = await get_llm_provider() + + async def _generate_description(self, template: str, **kwargs) -> str: + """LLM을 비동기적으로 호출하여 설명을 생성하는 헬퍼 함수""" + try: + await self._initialize_dependencies() + + prompt = ChatPromptTemplate.from_template(template) + llm = await self.llm_provider.get_llm() + chain = prompt | llm | StrOutputParser() + + result = await chain.ainvoke(kwargs) + return result.strip() + + except Exception as e: + logger.error(f"Failed to generate description: {e}") + return f"설명 생성 실패: {e}" + + async def _annotate_column( + self, + table_name: str, + sample_rows: str, + column: Column + ) -> AnnotatedColumn: + """단일 컬럼을 비동기적으로 어노테이트합니다.""" + try: + column_desc = await self._generate_description( + """ + 테이블 '{table_name}'의 컬럼 '{column_name}'(타입: {data_type})의 역할을 한국어로 간결하게 설명해줘. + 샘플 데이터: {sample_rows} + """, + table_name=table_name, + column_name=column.column_name, + data_type=column.data_type, + sample_rows=sample_rows + ) + + return AnnotatedColumn( + **column.model_dump(), + description=column_desc + ) + + except Exception as e: + logger.error(f"Failed to annotate column {column.column_name}: {e}") + return AnnotatedColumn( + **column.model_dump(), + description=f"설명 생성 실패: {e}" + ) + + async def _annotate_table(self, db_name: str, table: Table) -> AnnotatedTable: + """단일 테이블과 그 컬럼들을 비동기적으로 어노테이트합니다.""" + try: + sample_rows_str = str(table.sample_rows[:3]) + + # 테이블 설명 생성과 모든 컬럼 설명을 동시에 병렬로 처리 + table_desc_task = self._generate_description( + "데이터베이스 '{db_name}'에 속한 테이블 '{table_name}'의 역할을 한국어로 간결하게 설명해줘.", + db_name=db_name, + table_name=table.table_name + ) + + column_tasks = [ + self._annotate_column(table.table_name, sample_rows_str, col) + for col in table.columns + ] + + # 모든 작업을 병렬 실행 + results = await asyncio.gather( + table_desc_task, + *column_tasks, + return_exceptions=True + ) + + # 결과 처리 + table_desc = results[0] if not isinstance(results[0], Exception) else "테이블 설명 생성 실패" + annotated_columns = [ + result for result in results[1:] + if not isinstance(result, Exception) + ] + + return AnnotatedTable( + **table.model_dump(exclude={'columns'}), + description=table_desc, + columns=annotated_columns + ) + + except Exception as e: + logger.error(f"Failed to annotate table {table.table_name}: {e}") + # 실패 시 기본 어노테이션 반환 + annotated_columns = [ + AnnotatedColumn(**col.model_dump(), description="설명 생성 실패") + for col in table.columns + ] + return AnnotatedTable( + **table.model_dump(exclude={'columns'}), + description=f"테이블 설명 생성 실패: {e}", + columns=annotated_columns + ) + + async def _annotate_relationship(self, relationship: Relationship) -> AnnotatedRelationship: + """단일 관계를 비동기적으로 어노테이트합니다.""" + try: + rel_desc = await self._generate_description( + """ + 테이블 '{from_table}'이(가) 테이블 '{to_table}'을(를) 참조하고 있습니다. + 이 관계를 한국어 문장으로 설명해줘. + """, + from_table=relationship.from_table, + to_table=relationship.to_table + ) + + return AnnotatedRelationship( + **relationship.model_dump(), + description=rel_desc + ) + + except Exception as e: + logger.error(f"Failed to annotate relationship: {e}") + return AnnotatedRelationship( + **relationship.model_dump(), + description=f"관계 설명 생성 실패: {e}" + ) + + async def generate_for_schema(self, request: AnnotationRequest) -> AnnotationResponse: + """입력된 스키마 전체에 대한 어노테이션을 비동기적으로 생성합니다.""" + try: + logger.info(f"Starting annotation generation for {len(request.databases)} databases") + + annotated_databases = [] + + for db in request.databases: + try: + # DB 설명, 모든 테이블, 모든 관계 설명을 동시에 병렬로 처리 + db_desc_task = self._generate_description( + "데이터베이스 '{db_name}'의 역할을 한국어로 간결하게 설명해줘.", + db_name=db.database_name + ) + + table_tasks = [ + self._annotate_table(db.database_name, table) + for table in db.tables + ] + + relationship_tasks = [ + self._annotate_relationship(rel) + for rel in db.relationships + ] + + # 모든 작업을 병렬 실행 + all_results = await asyncio.gather( + db_desc_task, + *table_tasks, + *relationship_tasks, + return_exceptions=True + ) + + # 결과 분리 + db_desc = all_results[0] if not isinstance(all_results[0], Exception) else "DB 설명 생성 실패" + + num_tables = len(table_tasks) + annotated_tables = [ + result for result in all_results[1:1+num_tables] + if not isinstance(result, Exception) + ] + + annotated_relationships = [ + result for result in all_results[1+num_tables:] + if not isinstance(result, Exception) + ] + + annotated_databases.append( + AnnotatedDatabase( + database_name=db.database_name, + description=db_desc, + tables=annotated_tables, + relationships=annotated_relationships + ) + ) + + logger.info(f"Completed annotation for database: {db.database_name}") + + except Exception as e: + logger.error(f"Failed to annotate database {db.database_name}: {e}") + # 실패한 데이터베이스도 기본값으로 포함 + annotated_databases.append( + AnnotatedDatabase( + database_name=db.database_name, + description=f"데이터베이스 어노테이션 생성 실패: {e}", + tables=[], + relationships=[] + ) + ) + + logger.info("Annotation generation completed successfully") + + return AnnotationResponse( + dbms_type=request.dbms_type, + databases=annotated_databases + ) + + except Exception as e: + logger.error(f"Failed to generate annotations: {e}") + # 전체 실패 시 기본 응답 반환 + return AnnotationResponse( + dbms_type=request.dbms_type, + databases=[] + ) + + async def health_check(self) -> Dict[str, Any]: + """어노테이션 서비스의 상태를 확인합니다.""" + try: + await self._initialize_dependencies() + + # LLM 연결 테스트 + llm_status = await self.llm_provider.test_connection() + + return { + "status": "healthy" if llm_status else "unhealthy", + "llm_provider": "connected" if llm_status else "disconnected" + } + + except Exception as e: + logger.error(f"Annotation service health check failed: {e}") + return { + "status": "unhealthy", + "error": str(e) + } + +# 싱글톤 인스턴스 +_annotation_service = AnnotationService() + +async def get_annotation_service() -> AnnotationService: + """Annotation Service 인스턴스를 반환합니다.""" + return _annotation_service diff --git a/src/services/annotation_service.py b/src/services/annotation_service.py deleted file mode 100644 index 4e0e823..0000000 --- a/src/services/annotation_service.py +++ /dev/null @@ -1,100 +0,0 @@ -# src/services/annotation_service.py - -import asyncio -from langchain_openai import ChatOpenAI -from langchain_core.prompts import ChatPromptTemplate -from langchain_core.output_parsers import StrOutputParser -from api.v1.schemas.annotator_schemas import ( - AnnotationRequest, AnnotationResponse, - Database, Table, Column, Relationship, - AnnotatedDatabase, AnnotatedTable, AnnotatedColumn, AnnotatedRelationship -) - -class AnnotationService(): - """ - 어노테이션 생성과 관련된 모든 비즈니스 로직을 담당하는 서비스 클래스. - LLM 호출을 비동기적으로 처리하여 성능을 최적화합니다. - """ - def __init__(self, llm: ChatOpenAI): - self.llm = llm - - async def _generate_description(self, template: str, **kwargs) -> str: - """LLM을 비동기적으로 호출하여 설명을 생성하는 헬퍼 함수""" - prompt = ChatPromptTemplate.from_template(template) - chain = prompt | self.llm | StrOutputParser() - return await chain.ainvoke(kwargs) - - async def _annotate_column(self, table_name: str, sample_rows: str, column: Column) -> AnnotatedColumn: - """단일 컬럼을 비동기적으로 어노테이트합니다.""" - column_desc = await self._generate_description( - """ - 테이블 '{table_name}'의 컬럼 '{column_name}'(타입: {data_type})의 역할을 한국어로 간결하게 설명해줘. - 샘플 데이터: {sample_rows} - """, - table_name=table_name, - column_name=column.column_name, - data_type=column.data_type, - sample_rows=sample_rows - ) - return AnnotatedColumn(**column.model_dump(), description=column_desc.strip()) - - async def _annotate_table(self, db_name: str, table: Table) -> AnnotatedTable: - """단일 테이블과 그 컬럼들을 비동기적으로 어노테이트합니다.""" - sample_rows_str = str(table.sample_rows[:3]) - - # 테이블 설명 생성과 모든 컬럼 설명을 동시에 병렬로 처리 - table_desc_task = self._generate_description( - "데이터베이스 '{db_name}'에 속한 테이블 '{table_name}'의 역할을 한국어로 간결하게 설명해줘.", - db_name=db_name, table_name=table.table_name - ) - column_tasks = [self._annotate_column(table.table_name, sample_rows_str, col) for col in table.columns] - - results = await asyncio.gather(table_desc_task, *column_tasks) - - table_desc = results[0].strip() - annotated_columns = results[1:] - - return AnnotatedTable(**table.model_dump(exclude={'columns'}), description=table_desc, columns=annotated_columns) - - async def _annotate_relationship(self, relationship: Relationship) -> AnnotatedRelationship: - """단일 관계를 비동기적으로 어노테이트합니다.""" - rel_desc = await self._generate_description( - """ - 테이블 '{from_table}'이(가) 테이블 '{to_table}'을(를) 참조하고 있습니다. - 이 관계를 한국어 문장으로 설명해줘. - """, - from_table=relationship.from_table, to_table=relationship.to_table - ) - return AnnotatedRelationship(**relationship.model_dump(), description=rel_desc.strip()) - - async def generate_for_schema(self, request: AnnotationRequest) -> AnnotationResponse: - """ - 입력된 스키마 전체에 대한 어노테이션을 비동기적으로 생성합니다. - """ - annotated_databases = [] - for db in request.databases: - # DB 설명, 모든 테이블, 모든 관계 설명을 동시에 병렬로 처리 - db_desc_task = self._generate_description( - "데이터베이스 '{db_name}'의 역할을 한국어로 간결하게 설명해줘.", - db_name=db.database_name - ) - - table_tasks = [self._annotate_table(db.database_name, table) for table in db.tables] - relationship_tasks = [self._annotate_relationship(rel) for rel in db.relationships] - - db_desc_result, *other_results = await asyncio.gather(db_desc_task, *table_tasks, *relationship_tasks) - - num_tables = len(table_tasks) - annotated_tables = other_results[:num_tables] - annotated_relationships = other_results[num_tables:] - - annotated_databases.append( - AnnotatedDatabase( - database_name=db.database_name, - description=db_desc_result.strip(), - tables=annotated_tables, - relationships=annotated_relationships - ) - ) - - return AnnotationResponse(dbms_type=request.dbms_type, databases=annotated_databases) diff --git a/src/services/chat/__init__.py b/src/services/chat/__init__.py new file mode 100644 index 0000000..3a90e18 --- /dev/null +++ b/src/services/chat/__init__.py @@ -0,0 +1,6 @@ +""" +챗 서비스 패키지 +""" + + + diff --git a/src/services/chat/chatbot_service.py b/src/services/chat/chatbot_service.py new file mode 100644 index 0000000..5f2d3ff --- /dev/null +++ b/src/services/chat/chatbot_service.py @@ -0,0 +1,142 @@ +# src/services/chat/chatbot_service.py + +import asyncio +from typing import List, Optional, Dict, Any +from langchain_core.messages import HumanMessage, AIMessage, BaseMessage + +from schemas.api.chat_schemas import ChatMessage +from agents.sql_agent.graph import SqlAgentGraph +from core.providers.llm_provider import LLMProvider, get_llm_provider +from services.database.database_service import DatabaseService, get_database_service +import logging + +logger = logging.getLogger(__name__) + +class ChatbotService: + """챗봇 관련 비즈니스 로직을 담당하는 서비스 클래스""" + + def __init__( + self, + llm_provider: LLMProvider = None, + database_service: DatabaseService = None + ): + self.llm_provider = llm_provider + self.database_service = database_service + self._sql_agent_graph: Optional[SqlAgentGraph] = None + + async def _initialize_dependencies(self): + """필요한 의존성들을 초기화합니다.""" + if self.llm_provider is None: + self.llm_provider = await get_llm_provider() + + if self.database_service is None: + self.database_service = await get_database_service() + + if self._sql_agent_graph is None: + self._sql_agent_graph = SqlAgentGraph( + self.llm_provider, + self.database_service + ) + + async def handle_request( + self, + user_question: str, + chat_history: Optional[List[ChatMessage]] = None + ) -> str: + """채팅 요청을 처리하고 응답을 반환합니다.""" + try: + # 의존성 초기화 + await self._initialize_dependencies() + + # 채팅 히스토리를 LangChain 메시지로 변환 + langchain_messages = await self._convert_chat_history(chat_history) + + # 초기 상태 구성 + initial_state = { + "question": user_question, + "chat_history": langchain_messages, + "validation_error_count": 0, + "execution_error_count": 0 + } + + # SQL Agent 그래프 실행 + final_state = await self._sql_agent_graph.run(initial_state) + + return final_state.get('final_response', "죄송합니다. 응답을 생성할 수 없습니다.") + + except Exception as e: + logger.error(f"Chat request handling failed: {e}") + # 에러 상황에서는 예외를 다시 발생시켜 라우터에서 HTTP 에러로 처리되도록 함 + raise e + + async def _convert_chat_history( + self, + chat_history: Optional[List[ChatMessage]] + ) -> List[BaseMessage]: + """채팅 히스토리를 LangChain 메시지 형식으로 변환합니다.""" + langchain_messages: List[BaseMessage] = [] + + if chat_history: + for message in chat_history: + try: + if message.role == 'u': + langchain_messages.append(HumanMessage(content=message.content)) + elif message.role == 'a': + langchain_messages.append(AIMessage(content=message.content)) + except Exception as e: + logger.warning(f"Failed to convert message: {e}") + continue + + return langchain_messages + + async def health_check(self) -> Dict[str, Any]: + """챗봇 서비스의 상태를 확인합니다.""" + try: + await self._initialize_dependencies() + + # LLM 연결 테스트 + llm_status = await self.llm_provider.test_connection() + + # 데이터베이스 서비스 상태 확인 + db_status = await self.database_service.health_check() + + overall_status = llm_status and db_status + + return { + "status": "healthy" if overall_status else "unhealthy", + "llm_provider": "connected" if llm_status else "disconnected", + "database_service": "connected" if db_status else "disconnected" + } + + except Exception as e: + logger.error(f"Health check failed: {e}") + return { + "status": "unhealthy", + "error": str(e) + } + + async def get_available_databases(self) -> List[Dict[str, str]]: + """사용 가능한 데이터베이스 목록을 반환합니다.""" + try: + await self._initialize_dependencies() + databases = await self.database_service.get_available_databases() + + return [ + { + "name": db.database_name, + "description": db.description, + "connection": db.connection_name + } + for db in databases + ] + + except Exception as e: + logger.error(f"Failed to get available databases: {e}") + return [] + +# 싱글톤 인스턴스 +_chatbot_service = ChatbotService() + +async def get_chatbot_service() -> ChatbotService: + """Chatbot Service 인스턴스를 반환합니다.""" + return _chatbot_service diff --git a/src/services/chatbot_service.py b/src/services/chatbot_service.py deleted file mode 100644 index 3a17cfc..0000000 --- a/src/services/chatbot_service.py +++ /dev/null @@ -1,34 +0,0 @@ -# src/services/chatbot_service.py - -from agents.sql_agent_graph import sql_agent_app -from api.v1.schemas.chatbot_schemas import ChatMessage # --- 추가된 부분 --- -from langchain_core.messages import HumanMessage, AIMessage, BaseMessage # --- 추가된 부분 --- -from typing import List, Optional # --- 추가된 부분 --- -#from core.db_manager import schema_instance - -class ChatbotService(): - # TODO: schema API 요청 - # def __init__(self): - # self.db_schema = schema_instance - - def handle_request(self, user_question: str, chat_history: Optional[List[ChatMessage]] = None) -> dict: - - langchain_messages: List[BaseMessage] = [] - if chat_history: - for message in chat_history: - if message.role == 'user': - langchain_messages.append(HumanMessage(content=message.content)) - elif message.role == 'assistant': - langchain_messages.append(AIMessage(content=message.content)) - - initial_state = { - "question": user_question, - "chat_history": langchain_messages, - # "db_schema": self.db_schema, - "validation_error_count": 0, - "execution_error_count": 0 - } - - final_state = sql_agent_app.invoke(initial_state) - - return final_state['final_response'] \ No newline at end of file diff --git a/src/services/database/__init__.py b/src/services/database/__init__.py new file mode 100644 index 0000000..435b095 --- /dev/null +++ b/src/services/database/__init__.py @@ -0,0 +1,6 @@ +""" +데이터베이스 서비스 패키지 +""" + + + diff --git a/src/services/database/database_service.py b/src/services/database/database_service.py new file mode 100644 index 0000000..40d3f14 --- /dev/null +++ b/src/services/database/database_service.py @@ -0,0 +1,157 @@ +# src/services/database/database_service.py + +import asyncio +from typing import List, Optional, Dict +from core.clients.api_client import APIClient, DatabaseInfo, get_api_client +import logging + +logger = logging.getLogger(__name__) + +class DatabaseService: + """데이터베이스 관련 비즈니스 로직을 담당하는 서비스 클래스""" + + def __init__(self, api_client: APIClient = None): + self.api_client = api_client + self._cached_databases: Optional[List[DatabaseInfo]] = None + self._cached_schemas: Dict[str, str] = {} + + async def _get_api_client(self) -> APIClient: + """API 클라이언트를 가져옵니다.""" + if self.api_client is None: + self.api_client = await get_api_client() + return self.api_client + + async def get_available_databases(self) -> List[DatabaseInfo]: + """사용 가능한 데이터베이스 목록을 가져옵니다.""" + try: + if self._cached_databases is None: + api_client = await self._get_api_client() + self._cached_databases = await api_client.get_available_databases() + logger.info(f"Cached {len(self._cached_databases)} databases") + + return self._cached_databases + + except Exception as e: + logger.error(f"Failed to fetch databases: {e}") + raise RuntimeError(f"데이터베이스 목록을 가져올 수 없습니다. 백엔드 서버를 확인해주세요: {e}") + + async def get_schema_for_db(self, db_name: str) -> str: + """특정 데이터베이스의 스키마를 가져옵니다.""" + try: + if db_name not in self._cached_schemas: + api_client = await self._get_api_client() + schema = await api_client.get_database_schema(db_name) + self._cached_schemas[db_name] = schema + logger.info(f"Cached schema for database: {db_name}") + + return self._cached_schemas[db_name] + + except Exception as e: + logger.error(f"Failed to fetch schema for {db_name}: {e}") + raise RuntimeError(f"데이터베이스 '{db_name}' 스키마를 가져올 수 없습니다. 백엔드 서버를 확인해주세요: {e}") + + async def execute_query(self, sql_query: str, database_name: str = None, user_db_id: str = None) -> str: + """SQL 쿼리를 실행하고 결과를 반환합니다.""" + try: + if not database_name: + logger.warning("Database name not provided, using default") + database_name = "default" + + logger.info(f"Executing SQL query on database '{database_name}': {sql_query}") + + api_client = await self._get_api_client() + response = await api_client.execute_query( + sql_query=sql_query, + database_name=database_name, + user_db_id=user_db_id + ) + + # 백엔드 응답 코드 확인 + if response.code == "2400": + logger.info(f"Query executed successfully: {response.message}") + + # 응답 데이터 형태에 따라 다른 메시지 반환 + if hasattr(response.data, 'columns') and hasattr(response.data, 'data'): + # 쿼리 결과 데이터가 있는 경우 + row_count = len(response.data.data) + col_count = len(response.data.columns) + return f"쿼리가 성공적으로 실행되었습니다. {row_count}개 행, {col_count}개 컬럼의 결과를 반환했습니다." + else: + # 일반적인 성공 메시지 + return "쿼리가 성공적으로 실행되었습니다." + else: + # data에 에러 메시지가 있는지 확인 + error_detail = "" + if isinstance(response.data, str): + error_detail = f" 상세: {response.data}" + + error_msg = f"쿼리 실행 실패: {response.message} (코드: {response.code}){error_detail}" + logger.error(error_msg) + return error_msg + + except Exception as e: + logger.error(f"Error during query execution: {e}") + return f"쿼리 실행 중 오류 발생: {e}" + + async def _get_fallback_databases(self) -> List[DatabaseInfo]: + """API 실패 시 사용할 폴백 데이터베이스 목록""" + return [ + DatabaseInfo( + connection_name="local_mysql", + database_name="sakila", + description="DVD 대여점 비즈니스 모델을 다루는 샘플 데이터베이스" + ), + DatabaseInfo( + connection_name="local_mysql", + database_name="ecom_prod", + description="온라인 쇼핑몰의 운영 데이터베이스" + ), + DatabaseInfo( + connection_name="local_mysql", + database_name="hr_analytics", + description="회사의 인사 관리 데이터베이스" + ), + DatabaseInfo( + connection_name="local_mysql", + database_name="web_logs", + description="웹사이트 트래픽 분석을 위한 로그 데이터베이스" + ) + ] + + async def get_fallback_schema(self, db_name: str) -> str: + """API 실패 시 사용할 폴백 스키마""" + fallback_schemas = { + "sakila": "CREATE TABLE actor (actor_id INT, first_name VARCHAR(45), last_name VARCHAR(45))", + "ecom_prod": "CREATE TABLE products (product_id INT, name VARCHAR(100), price DECIMAL(10,2))", + "hr_analytics": "CREATE TABLE employees (employee_id INT, name VARCHAR(100), department VARCHAR(50))", + "web_logs": "CREATE TABLE access_logs (log_id INT, timestamp DATETIME, ip_address VARCHAR(45))" + } + return fallback_schemas.get(db_name, "Schema information not available") + + async def refresh_cache(self): + """캐시를 새로고침합니다.""" + self._cached_databases = None + self._cached_schemas.clear() + logger.info("Database cache refreshed") + + async def clear_cache(self): + """캐시를 클리어합니다.""" + self._cached_databases = None + self._cached_schemas.clear() + logger.info("Database cache cleared") + + async def health_check(self) -> bool: + """데이터베이스 서비스 상태를 확인합니다.""" + try: + api_client = await self._get_api_client() + return await api_client.health_check() + except Exception as e: + logger.error(f"Database service health check failed: {e}") + return False + +# 싱글톤 인스턴스 +_database_service = DatabaseService() + +async def get_database_service() -> DatabaseService: + """Database Service 인스턴스를 반환합니다.""" + return _database_service diff --git a/test_services.py b/test_services.py new file mode 100644 index 0000000..5b66de8 --- /dev/null +++ b/test_services.py @@ -0,0 +1,264 @@ +#!/usr/bin/env python3 +""" +서비스 테스트 스크립트 +""" + +import asyncio +import sys +import os + +# src 디렉토리를 Python 경로에 추가 +sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'src')) + +async def test_llm_provider(): + """LLM Provider 테스트""" + print("🔍 LLM Provider 테스트 중...") + try: + from core.providers.llm_provider import get_llm_provider + + provider = await get_llm_provider() + print(f"✅ LLM Provider 생성 성공: {provider.model_name}") + + # 연결 테스트 + is_connected = await provider.test_connection() + print(f"🔗 LLM 연결 상태: {'성공' if is_connected else '실패'}") + + # API 키 소스 확인 (로그에서 확인 가능) + print("💡 백엔드에서 API 키를 가져옵니다") + + except Exception as e: + print(f"❌ LLM Provider 테스트 실패: {e}") + +async def test_api_client(): + """API Client 테스트""" + print("\n🔍 API Client 테스트 중...") + try: + from core.clients.api_client import get_api_client + + client = await get_api_client() + print("✅ API Client 생성 성공") + + # OpenAI API 키 조회 테스트 + try: + api_key = await client.get_openai_api_key() + print(f"🔑 OpenAI API 키 조회 성공: {api_key[:20]}...") + except Exception as e: + print(f"⚠️ OpenAI API 키 조회 실패: {e}") + + # 헬스체크 테스트 + try: + is_healthy = await client.health_check() + print(f"🏥 백엔드 서버 상태: {'정상' if is_healthy else '비정상'}") + except Exception as e: + print(f"⚠️ 백엔드 서버 연결 실패: {e}") + + except Exception as e: + print(f"❌ API Client 테스트 실패: {e}") + +async def test_database_service(): + """Database Service 테스트""" + print("\n🔍 Database Service 테스트 중...") + try: + from services.database.database_service import get_database_service + + service = await get_database_service() + print("✅ Database Service 생성 성공") + + # 사용 가능한 데이터베이스 목록 조회 + try: + databases = await service.get_available_databases() + print(f"🗄️ 사용 가능한 데이터베이스: {len(databases)}개") + print("✅ 백엔드 API에서 데이터베이스 목록을 성공적으로 가져왔습니다") + + for db in databases[:3]: # 처음 3개만 출력 + print(f" - {db.database_name}: {db.description}") + except Exception as e: + print(f"⚠️ 데이터베이스 목록 조회 실패: {e}") + + except Exception as e: + print(f"❌ Database Service 테스트 실패: {e}") + +async def test_annotation_service(): + """Annotation Service 테스트""" + print("\n🔍 Annotation Service 테스트 중...") + try: + from services.annotation.annotation_service import get_annotation_service + + service = await get_annotation_service() + print("✅ Annotation Service 생성 성공") + + # 헬스체크 테스트 + try: + health = await service.health_check() + print(f"🏥 어노테이션 서비스 상태: {health}") + except Exception as e: + print(f"⚠️ 어노테이션 서비스 헬스체크 실패: {e}") + + except Exception as e: + print(f"❌ Annotation Service 테스트 실패: {e}") + +async def test_chatbot_service(): + """Chatbot Service 테스트""" + print("\n🔍 Chatbot Service 테스트 중...") + try: + from services.chat.chatbot_service import get_chatbot_service + + service = await get_chatbot_service() + print("✅ Chatbot Service 생성 성공") + + # 헬스체크 테스트 + try: + health = await service.health_check() + print(f"🏥 챗봇 서비스 상태: {health}") + except Exception as e: + print(f"⚠️ 챗봇 서비스 헬스체크 실패: {e}") + + except Exception as e: + print(f"❌ Chatbot Service 테스트 실패: {e}") + +async def test_sql_agent(): + """SQL Agent 테스트""" + print("\n🔍 SQL Agent 테스트 중...") + try: + from agents.sql_agent.graph import SqlAgentGraph + from core.providers.llm_provider import get_llm_provider + from services.database.database_service import get_database_service + + llm_provider = await get_llm_provider() + db_service = await get_database_service() + + agent = SqlAgentGraph(llm_provider, db_service) + print("✅ SQL Agent 생성 성공") + + # 그래프 시각화 PNG 저장 + try: + success = agent.save_graph_visualization("sql_agent_workflow.png") + if success: + print("📊 그래프 시각화 PNG 저장 성공: sql_agent_workflow.png") + else: + print("⚠️ 그래프 시각화 PNG 저장 실패") + except Exception as e: + print(f"⚠️ 그래프 시각화 생성 실패: {e}") + + except Exception as e: + print(f"❌ SQL Agent 테스트 실패: {e}") + +async def test_end_to_end_chat(): + """실제 채팅 요청 End-to-End 테스트""" + print("\n🔍 End-to-End 채팅 테스트 중...") + try: + from services.chat.chatbot_service import get_chatbot_service + import time + + service = await get_chatbot_service() + + # SQL 관련 질문으로 테스트 + test_questions = [ + "사용자 테이블에서 모든 데이터를 조회해주세요", + "가장 많이 주문한 고객을 찾아주세요", + ] + + for i, question in enumerate(test_questions, 1): + print(f"🤖 테스트 질문 {i}: {question}") + start_time = time.time() + + try: + response = await service.handle_request(user_question=question) + end_time = time.time() + response_time = round(end_time - start_time, 2) + + print(f"✅ 응답 시간: {response_time}초") + print(f"📝 응답: {response[:100]}{'...' if len(response) > 100 else ''}") + except Exception as e: + print(f"❌ 질문 {i} 실패: {e}") + + print("---") + + except Exception as e: + print(f"❌ End-to-End 테스트 실패: {e}") + +async def test_annotation_functionality(): + """어노테이션 기능 실제 사용 테스트""" + print("\n🔍 어노테이션 기능 테스트 중...") + try: + from services.annotation.annotation_service import get_annotation_service + from schemas.api.annotator_schemas import Database, Table, Column + + service = await get_annotation_service() + + # 샘플 데이터로 어노테이션 테스트 + sample_database = Database( + database_name="test_db", + tables=[ + Table( + table_name="users", + columns=[ + Column(column_name="id", data_type="int"), + Column(column_name="name", data_type="varchar"), + Column(column_name="email", data_type="varchar") + ], + sample_rows=["1, John Doe, john@example.com"] + ) + ], + relationships=[] + ) + + try: + result = await service.generate_annotations(sample_database) + print(f"✅ 어노테이션 생성 성공") + print(f"📝 생성된 테이블 수: {len(result.tables)}") + if result.tables: + print(f"📝 첫 번째 테이블 설명: {result.tables[0].description[:100]}...") + except Exception as e: + print(f"⚠️ 어노테이션 생성 실패: {e}") + + except Exception as e: + print(f"❌ 어노테이션 기능 테스트 실패: {e}") + +async def test_error_scenarios(): + """에러 시나리오 테스트""" + print("\n🔍 에러 시나리오 테스트 중...") + + # 잘못된 API 키로 LLM 테스트 + print("🧪 잘못된 API 키 시나리오...") + try: + from core.providers.llm_provider import LLMProvider + + # 일시적으로 잘못된 API 키 설정 테스트는 실제 환경에서는 위험하므로 스킵 + print("⚠️ 실제 환경에서는 API 키 에러 테스트 스킵") + + except Exception as e: + print(f"✅ 예상된 에러 발생: {e}") + + print("✅ 에러 시나리오 테스트 완료") + +async def main(): + """메인 테스트 함수""" + print("🚀 QGenie AI 서비스 테스트 시작\n") + + # 기본 서비스 테스트 + await test_llm_provider() + await test_api_client() + await test_database_service() + await test_annotation_service() + await test_chatbot_service() + await test_sql_agent() + + # 확장 테스트 (백엔드 연결이 가능한 경우에만) + try: + from core.clients.api_client import get_api_client + client = await get_api_client() + if await client.health_check(): + print("\n🧪 확장 테스트 시작 (백엔드 연결 확인됨)") + await test_end_to_end_chat() + await test_annotation_functionality() + await test_error_scenarios() + else: + print("\n⚠️ 백엔드 연결 불가 - 확장 테스트 스킵") + except Exception: + print("\n⚠️ 백엔드 연결 불가 - 확장 테스트 스킵") + + print("\n✨ 모든 테스트 완료!") + +if __name__ == "__main__": + asyncio.run(main())