Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 9 additions & 9 deletions src/config/agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,13 @@

# Define agent-LLM mapping
AGENT_LLM_MAP: dict[str, LLMType] = {
"coordinator": "basic",
"planner": "basic",
"researcher": "basic",
"coder": "basic",
"reporter": "basic",
"podcast_script_writer": "basic",
"ppt_composer": "basic",
"prose_writer": "basic",
"prompt_enhancer": "basic",
"coordinator": "reasoning",
"planner": "reasoning",
"researcher": "reasoning",
"coder": "reasoning",
"reporter": "reasoning",
"podcast_script_writer": "reasoning",
"ppt_composer": "reasoning",
"prose_writer": "reasoning",
"prompt_enhancer": "reasoning",
}
15 changes: 4 additions & 11 deletions src/graph/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,10 +104,7 @@ def planner_node(
if configurable.enable_deep_thinking:
llm = get_llm_by_type("reasoning")
elif AGENT_LLM_MAP["planner"] == "basic":
llm = get_llm_by_type("basic").with_structured_output(
Plan,
method="json_mode",
)
llm = get_llm_by_type("basic")
else:
llm = get_llm_by_type(AGENT_LLM_MAP["planner"])

Expand All @@ -116,13 +113,9 @@ def planner_node(
return Command(goto="reporter")

full_response = ""
if AGENT_LLM_MAP["planner"] == "basic" and not configurable.enable_deep_thinking:
response = llm.invoke(messages)
full_response = response.model_dump_json(indent=4, exclude_none=True)
else:
response = llm.stream(messages)
for chunk in response:
full_response += chunk.content
response = llm.stream(messages)
for chunk in response:
full_response += chunk.content
logger.debug(f"Current state messages: {state['messages']}")
logger.info(f"Planner response: {full_response}")

Expand Down
26 changes: 22 additions & 4 deletions src/podcast/graph/script_writer_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,32 @@

def script_writer_node(state: PodcastState):
logger.info("Generating script for podcast...")
model = get_llm_by_type(
AGENT_LLM_MAP["podcast_script_writer"]
).with_structured_output(Script, method="json_mode")
script = model.invoke(
model = get_llm_by_type(AGENT_LLM_MAP["podcast_script_writer"])
response = model.invoke(
[
SystemMessage(content=get_prompt_template("podcast/podcast_script_writer")),
HumanMessage(content=state["input"]),
],
)

# 手动解析 JSON 响应
try:
import json
from src.utils.json_utils import repair_json_output

# 修复和解析 JSON 响应
content = repair_json_output(response.content)
script_data = json.loads(content)

# 创建 Script 对象
script = Script(**script_data)
except Exception as e:
logger.error(f"Failed to parse script JSON: {e}")
# 如果解析失败,创建一个默认的 Script 对象
script = Script(
locale="en",
lines=[]
)

print(script)
return {"script": script, "audio_chunks": []}
Loading