From c81b0d184f050d6a063ecaaffa7b2fa180a721de Mon Sep 17 00:00:00 2001
From: Mrguanglei
Date: Tue, 1 Jul 2025 11:48:30 +0800
Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E5=A4=8D=E4=BA=86=E6=B7=B1=E5=BA=A6?=
 =?UTF-8?q?=E6=80=9D=E8=80=83=E6=A8=A1=E5=BC=8F?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 src/config/agents.py                    | 18 ++++++++---------
 src/graph/nodes.py                      | 15 ++++----------
 src/podcast/graph/script_writer_node.py | 26 +++++++++++++++++++++----
 3 files changed, 35 insertions(+), 24 deletions(-)

diff --git a/src/config/agents.py b/src/config/agents.py
index 04c86e9..ccf63d3 100644
--- a/src/config/agents.py
+++ b/src/config/agents.py
@@ -8,13 +8,13 @@
 
 # Define agent-LLM mapping
 AGENT_LLM_MAP: dict[str, LLMType] = {
-    "coordinator": "basic",
-    "planner": "basic",
-    "researcher": "basic",
-    "coder": "basic",
-    "reporter": "basic",
-    "podcast_script_writer": "basic",
-    "ppt_composer": "basic",
-    "prose_writer": "basic",
-    "prompt_enhancer": "basic",
+    "coordinator": "reasoning",
+    "planner": "reasoning",
+    "researcher": "reasoning",
+    "coder": "reasoning",
+    "reporter": "reasoning",
+    "podcast_script_writer": "reasoning",
+    "ppt_composer": "reasoning",
+    "prose_writer": "reasoning",
+    "prompt_enhancer": "reasoning",
 }
diff --git a/src/graph/nodes.py b/src/graph/nodes.py
index 8370e21..11efaaf 100644
--- a/src/graph/nodes.py
+++ b/src/graph/nodes.py
@@ -104,10 +104,7 @@ def planner_node(
     if configurable.enable_deep_thinking:
         llm = get_llm_by_type("reasoning")
     elif AGENT_LLM_MAP["planner"] == "basic":
-        llm = get_llm_by_type("basic").with_structured_output(
-            Plan,
-            method="json_mode",
-        )
+        llm = get_llm_by_type("basic")
     else:
         llm = get_llm_by_type(AGENT_LLM_MAP["planner"])
 
@@ -116,13 +113,9 @@
         return Command(goto="reporter")
 
     full_response = ""
-    if AGENT_LLM_MAP["planner"] == "basic" and not configurable.enable_deep_thinking:
-        response = llm.invoke(messages)
-        full_response = response.model_dump_json(indent=4, exclude_none=True)
-    else:
-        response = llm.stream(messages)
-        for chunk in response:
-            full_response += chunk.content
+    response = llm.stream(messages)
+    for chunk in response:
+        full_response += chunk.content
 
     logger.debug(f"Current state messages: {state['messages']}")
     logger.info(f"Planner response: {full_response}")
diff --git a/src/podcast/graph/script_writer_node.py b/src/podcast/graph/script_writer_node.py
index 2b3831b..cd720ee 100644
--- a/src/podcast/graph/script_writer_node.py
+++ b/src/podcast/graph/script_writer_node.py
@@ -17,14 +17,32 @@
 
 def script_writer_node(state: PodcastState):
     logger.info("Generating script for podcast...")
-    model = get_llm_by_type(
-        AGENT_LLM_MAP["podcast_script_writer"]
-    ).with_structured_output(Script, method="json_mode")
-    script = model.invoke(
+    model = get_llm_by_type(AGENT_LLM_MAP["podcast_script_writer"])
+    response = model.invoke(
         [
             SystemMessage(content=get_prompt_template("podcast/podcast_script_writer")),
             HumanMessage(content=state["input"]),
         ],
     )
+
+    # Manually parse the JSON response
+    try:
+        import json
+        from src.utils.json_utils import repair_json_output
+
+        # Repair and parse the JSON response
+        content = repair_json_output(response.content)
+        script_data = json.loads(content)
+
+        # Create a Script object
+        script = Script(**script_data)
+    except Exception as e:
+        logger.error(f"Failed to parse script JSON: {e}")
+        # If parsing fails, fall back to a default (empty) Script object
+        script = Script(
+            locale="en",
+            lines=[]
+        )
+
     print(script)
     return {"script": script, "audio_chunks": []}