diff --git a/requirements.txt b/requirements.txt
index 892ea35..92d159d 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -80,6 +80,7 @@ pyzmq==27.0.0
 regex==2024.11.6
 requests==2.32.4
 requests-toolbelt==1.0.0
+setuptools==80.9.0
 six==1.17.0
 sniffio==1.3.1
 SQLAlchemy==2.0.41
diff --git a/src/agents/sql_agent_graph.py b/src/agents/sql_agent_graph.py
index 52c86ac..6e39284 100644
--- a/src/agents/sql_agent_graph.py
+++ b/src/agents/sql_agent_graph.py
@@ -6,6 +6,7 @@ from langchain_core.messages import BaseMessage
 
 from langgraph.graph import StateGraph, END
 from langchain.output_parsers.pydantic import PydanticOutputParser
+from langchain_core.output_parsers import StrOutputParser
 from langchain.prompts import load_prompt
 from schemas.sql_schemas import SqlQuery
 from core.db_manager import db_instance
@@ -27,6 +28,7 @@ def resource_path(relative_path):
 PROMPT_DIR = os.path.join("prompts", PROMPT_VERSION, "sql_agent")
 
 # --- 프롬프트 로드 ---
+INTENT_CLASSIFIER_PROMPT = load_prompt(resource_path(os.path.join(PROMPT_DIR, "intent_classifier.yaml")))
 SQL_GENERATOR_PROMPT = load_prompt(resource_path(os.path.join(PROMPT_DIR, "sql_generator.yaml")))
 RESPONSE_SYNTHESIZER_PROMPT = load_prompt(resource_path(os.path.join(PROMPT_DIR, "response_synthesizer.yaml")))
 
@@ -35,6 +37,7 @@ class SqlAgentState(TypedDict):
     question: str
     chat_history: List[BaseMessage]
     db_schema: str
+    intent: str
     sql_query: str
     validation_error: Optional[str]
     validation_error_count: int
@@ -43,6 +46,25 @@ class SqlAgentState(TypedDict):
     final_response: str
 
 # --- 노드 함수 정의 ---
+def intent_classifier_node(state: SqlAgentState):
+    """Classify the user's question and store the label in state['intent'].
+
+    Pipes INTENT_CLASSIFIER_PROMPT into llm_instance and StrOutputParser.
+    The reply is stripped because chat models commonly append whitespace or
+    a trailing newline, which would defeat the "SQL" comparison in the router.
+    """
+    print("--- 0. 의도 분류 중 ---")
+    chain = INTENT_CLASSIFIER_PROMPT | llm_instance | StrOutputParser()
+    intent = chain.invoke({"question": state['question']})
+    state['intent'] = intent.strip()
+    return state
+
+def unsupported_question_node(state: SqlAgentState):
+    """Set a fixed refusal message as the final response and return the state."""
+    print("--- SQL 관련 없는 질문 ---")
+    state['final_response'] = "죄송합니다, 해당 질문에는 답변할 수 없습니다. 데이터베이스 관련 질문만 가능합니다."
+    return state
+
 def sql_generator_node(state: SqlAgentState):
     print("--- 1. SQL 생성 중 ---")
     parser = PydanticOutputParser(pydantic_object=SqlQuery)
@@ -73,7 +95,7 @@ def sql_generator_node(state: SqlAgentState):
     )
 
     response = llm_instance.invoke(prompt)
-    parsed_query = parser.invoke(response)
+    parsed_query = parser.invoke(response.content)
     state['sql_query'] = parsed_query.query
     state['validation_error'] = None
     state['execution_result'] = None
@@ -145,6 +167,20 @@ def response_synthesizer_node(state: SqlAgentState):
     return state
 
 # --- 엣지 함수 정의 ---
+def route_after_intent_classification(state: SqlAgentState):
+    """Return the name of the next node based on the classified intent.
+
+    The label is normalised (outer whitespace, surrounding quotes, letter
+    case) before comparison, so replies such as "SQL" followed by a newline
+    or a quoted lowercase "sql" still reach the SQL pipeline. Any other
+    value, including a missing intent, routes to "unsupported_question".
+    """
+    intent = (state.get('intent') or "").strip().strip("\"'").upper()
+    if intent == "SQL":
+        print("--- 의도: SQL 관련 질문 ---")
+        return "sql_generator"
+    print("--- 의도: SQL과 관련 없는 질문 ---")
+    return "unsupported_question"
+
 def should_execute_sql(state: SqlAgentState):
     if state.get("validation_error_count", 0) >= MAX_ERROR_COUNT:
         print(f"--- 검증 실패 {MAX_ERROR_COUNT}회 초과: 답변 생성으로 이동 ---")
@@ -168,12 +204,26 @@ def should_retry_or_respond(state: SqlAgentState):
 # --- 그래프 구성 ---
 def create_sql_agent_graph() -> StateGraph:
     graph = StateGraph(SqlAgentState)
+
+    graph.add_node("intent_classifier", intent_classifier_node)
+    graph.add_node("unsupported_question", unsupported_question_node)
     graph.add_node("sql_generator", sql_generator_node)
     graph.add_node("sql_validator", sql_validator_node)
     graph.add_node("sql_executor", sql_executor_node)
     graph.add_node("response_synthesizer", response_synthesizer_node)
 
-    graph.set_entry_point("sql_generator")
+    graph.set_entry_point("intent_classifier")
+
+    graph.add_conditional_edges(
+        "intent_classifier",
+        route_after_intent_classification,
+        {
+            "sql_generator": "sql_generator",
+            "unsupported_question": "unsupported_question"
+        }
+    )
+    graph.add_edge("unsupported_question", END)
+
     graph.add_edge("sql_generator", "sql_validator")
 
     graph.add_conditional_edges("sql_validator", should_execute_sql, {
diff --git a/src/prompts/v1/sql_agent/intent_classifier.yaml b/src/prompts/v1/sql_agent/intent_classifier.yaml
new file mode 100644
index 0000000..1c2be08
--- /dev/null
+++ b/src/prompts/v1/sql_agent/intent_classifier.yaml
@@ -0,0 +1,30 @@
+
+_type: prompt
+input_variables:
+  - question
+template: |
+  You are an intelligent assistant responsible for classifying user questions.
+  Your task is to determine whether a user's question is related to retrieving information from a database using SQL.
+
+  - If the question can be answered with a SQL query, respond with "SQL".
+  - If the question is a simple greeting, a question about your identity, or anything that does not require database access, respond with "non-SQL".
+
+  Example 1:
+  Question: "Show me the list of users who signed up last month."
+  Classification: SQL
+
+  Example 2:
+  Question: "What is the total revenue for the last quarter?"
+  Classification: SQL
+
+  Example 3:
+  Question: "Hello, who are you?"
+  Classification: non-SQL
+
+  Example 4:
+  Question: "What is the weather like today?"
+  Classification: non-SQL
+
+  Now, classify the following question:
+  Question: {question}
+  Classification: