diff --git a/.github/copilot-instructions.md b/.github/copilot-instructions.md new file mode 100644 index 0000000..0820263 --- /dev/null +++ b/.github/copilot-instructions.md @@ -0,0 +1,11 @@ +# 全体的な命令 + +- 回答は指示されない限り、常に日本語で行ってください。 +- 関数にはdocstringを必ず記載してください。 + - docstringはGoogle Styleで記述してください。 +- プログラムには、適切な粒度でコメントを記載してください。 +- docstringやコメントは、日本語で記述してください。 + +## プログラム実装に関する命令 +- コードを新たに追加したり、変更した場合は、テストコードも必ず追加・修正してください。 +- 常にカバレッジ100%を目指してください。 \ No newline at end of file diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 0000000..19d3358 --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,61 @@ +name: CI + +on: + push: + branches: + - main + - develop + - feature/** + pull_request: + branches: + - main + - develop + +jobs: + backend-tests: + runs-on: ubuntu-latest + + steps: + - name: Check out repository + uses: actions/checkout@v4 + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: "3.13" + cache: "pip" + cache-dependency-path: pyproject.toml + + - name: Install backend dependencies + run: | + python -m pip install --upgrade pip + pip install . + + - name: Run backend tests + run: pytest tests/backend + + frontend-build: + runs-on: ubuntu-latest + + steps: + - name: Check out repository + uses: actions/checkout@v4 + + - name: Set up Node.js + uses: actions/setup-node@v4 + with: + node-version: "20" + cache: "npm" + cache-dependency-path: src/frontend/package-lock.json + + - name: Install frontend dependencies + working-directory: src/frontend + run: npm ci + + - name: Lint frontend + working-directory: src/frontend + run: npm run lint -- --max-warnings=0 + + - name: Build frontend + working-directory: src/frontend + run: npm run build diff --git a/.gitignore b/.gitignore index 505a3b1..0ee7014 100644 --- a/.gitignore +++ b/.gitignore @@ -5,6 +5,11 @@ build/ dist/ wheels/ *.egg-info +.pytest_cache/ +.coverage +htmlcov/ # Virtual environments .venv + +graph.png \ No newline at end of file diff --git a/README.md b/README.md index e30cf81..73fae57 100644 --- a/README.md +++ b/README.md @@ -38,6 +38,20 @@ DeepReSearch は、LangGraph と LangChain を用いて多段階のウェブリ export OPENROUTER_API_KEY="your-key" ``` +### テストの実行 + +```bash +python -m pytest +``` + +カバレッジと視覚的なレポートが必要な場合は次のように実行します。 + +```bash +python -m pytest --cov=src/backend --cov-report=term-missing --cov-report=html +``` + +HTML レポートは `htmlcov/index.html` に生成され、ブラウザや VS Code の Live Preview で確認できます。 + ### フロントエンド (Next.js) 1. 依存関係をインストールします。 @@ -80,14 +94,6 @@ npm run dev - 調査計画のレビュー・編集フォーム (interrupt 発生時) - 生成済みプランとレポートの閲覧 -### CLI クライアントから操作する - -```bash -python -m clients.research_client ws "人類の歴史" -``` - -CLI 上で中断が発生したら `y` / `n` で判断し、必要に応じて編集済み計画 JSON を指定して再開できます。 - ### API ドキュメント バックエンド API のエンドポイント一覧は から確認できます。 diff --git a/clients/research_client.py b/clients/research_client.py deleted file mode 100644 index 5599fcf..0000000 --- a/clients/research_client.py +++ /dev/null @@ -1,149 +0,0 @@ -"""Deep Research API と対話するための簡易クライアント。""" - -from __future__ import annotations - -import argparse -import asyncio -import json -import os -from typing import Any, Dict -from urllib.parse import urlparse, urlunparse - -try: - import websockets -except ImportError: # pragma: no cover - optional dependency - websockets = None - - -async def _ainput(prompt: str) -> str: - """非同期に標準入力から文字列を取得する。""" - - return await asyncio.to_thread(input, prompt) - - -def _print_event(prefix: str, payload: Dict[str, Any]) -> None: - """イベント情報を整形して出力する。""" - - print(f"[{prefix}] {json.dumps(payload, ensure_ascii=False)}") - - -_DEFAULT_BASE_URL = os.environ.get("DEEPRESEARCH_API_BASE", "http://127.0.0.1:8000") - - -async def _websocket_research(args: argparse.Namespace) -> None: - """WebSocket 経由でHITL付きリサーチを実行する。""" - - if websockets is None: - print( - "websockets パッケージが必要です。`pip install websockets` を実行してください。" - ) - return - - parsed = urlparse(args.base_url) - scheme = "wss" if parsed.scheme == "https" else "ws" - path = parsed.path.rstrip("/") - ws_url = urlunparse((scheme, parsed.netloc, f"{path}/ws/research", "", "", "")) - - async with websockets.connect(ws_url) as ws: # type: ignore[union-attr] - await ws.send(json.dumps({"query": args.query}, ensure_ascii=False)) - - async for raw in ws: - try: - message = json.loads(raw) - except json.JSONDecodeError: - print(f"[invalid] {raw}") - continue - - msg_type = message.get("type") - - if msg_type == "thread_started": - print(f"Thread started: {message.get('thread_id')}") - elif msg_type == "event": - _print_event("event", message.get("payload", {})) - elif msg_type == "interrupt": - interrupt = message.get("interrupt", {}) - _print_event("interrupt", interrupt) - - while True: - decision = (await _ainput("[y/n] > ")).strip().lower() - if decision in {"y", "n"}: - break - print("'y' か 'n' を入力してください。") - - plan_payload = None - if decision == "y": - plan_path = ( - await _ainput("計画JSONのパス(未入力でスキップ)> ") - ).strip() - if plan_path: - try: - with open(plan_path, "r", encoding="utf-8") as handle: - plan_payload = json.load(handle) - except OSError as exc: - print(f"計画ファイルを開けませんでした: {exc}") - plan_payload = None - except json.JSONDecodeError as exc: - print(f"計画JSONの解析に失敗しました: {exc}") - plan_payload = None - - await ws.send( - json.dumps( - {"decision": decision, "plan": plan_payload}, ensure_ascii=False - ) - ) - elif msg_type == "complete": - print("Workflow completed. State:") - print( - json.dumps(message.get("state", {}), ensure_ascii=False, indent=2) - ) - return - elif msg_type == "error": - print(f"[error] {message.get('message')}") - return - else: - print(f"[unknown] {message}") - - -async def _dispatch(args: argparse.Namespace) -> None: - """サブコマンドに応じて WebSocket 実行をディスパッチする。""" - - if args.command == "ws": - await _websocket_research(args) - else: # pragma: no cover - argparse が保証する - raise ValueError(f"unknown command: {args.command}") - - -def _build_parser() -> argparse.ArgumentParser: - """コマンドライン引数パーサーを構築する。""" - - parser = argparse.ArgumentParser(description="Deep Research API クライアント") - parser.add_argument( - "--base-url", - default=_DEFAULT_BASE_URL, - help="API のベース URL (既定: %(default)s)", - ) - - subparsers = parser.add_subparsers(dest="command", required=True) - - ws_parser = subparsers.add_parser("ws", help="WebSocket でHITL操作を行う") - ws_parser.add_argument("query", help="リサーチしたいテーマや質問文") - - return parser - - -def main(argv: list[str] | None = None) -> int: - """エントリーポイント。""" - - parser = _build_parser() - args = parser.parse_args(argv) - - try: - asyncio.run(_dispatch(args)) - except Exception as exc: # pragma: no cover - CLI 実行時の予期しない例外 - print(f"エラーが発生しました: {exc}") - return 1 - return 0 - - -if __name__ == "__main__": - raise SystemExit(main()) diff --git a/pyproject.toml b/pyproject.toml index 6013df1..b896b54 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -12,6 +12,9 @@ dependencies = [ "langchain-openai>=1.0.1", "langgraph>=1.0.1", "nest-asyncio>=1.6.0", + "pytest>=8.4.2", + "pytest-asyncio>=1.2.0", + "pytest-cov>=7.0.0", "streamlit>=1.51.0", "uvicorn[standard]>=0.38.0", "websockets>=15.0.1", diff --git a/src/backend/__init__.py b/src/backend/__init__.py deleted file mode 100644 index be11415..0000000 --- a/src/backend/__init__.py +++ /dev/null @@ -1 +0,0 @@ -"""Backend サブパッケージ。""" diff --git a/src/backend/agent.py b/src/backend/agent.py index bfe3143..7f6b07e 100644 --- a/src/backend/agent.py +++ b/src/backend/agent.py @@ -1,3 +1,4 @@ +import datetime from os import getenv from typing import Annotated @@ -23,7 +24,6 @@ from src.backend.ai.reflect.reflect_search_result import ReflectionResultSchema from src.backend.ai.schedule.plan_reserch import GeneratedObjectSchema, PlanResearchAI from src.backend.ai.search.prompt import DEEP_RESEARCH_SYSTEM_PROMPT -from src.backend.tools.get_current_date import get_current_date from src.backend.tools.search_reflect import reflect_on_results from src.backend.tools.web_research import web_research @@ -131,7 +131,7 @@ class OSSDeepResearchAgent: def __init__(self) -> None: """エージェントを初期化する。""" # 使用するツール - self.tools = [web_research, reflect_on_results, get_current_date] + self.tools = [web_research, reflect_on_results] self.llm = ChatOpenAI( model="tngtech/deepseek-r1t2-chimera:free", @@ -173,7 +173,7 @@ async def _node_generate_research_parameters( Args: state (State): 現在のステート。 - config (RunnableConfig): LangGraph 実行時の設定。 + config (RunnableConfig): LangGraph実行時の設定。 Returns: dict[str, ResearchParameters]: 生成した研究パラメータを含む差分ステート。 @@ -189,7 +189,7 @@ async def _node_make_research_plan( Args: state (State): 現在のステート。 - config (RunnableConfig): LangGraph 実行時の設定。 + config (RunnableConfig): LangGraph実行時の設定。 Returns: dict[str, GeneratedObjectSchema]: 研究計画を含む差分ステート。 @@ -203,12 +203,12 @@ async def _research_plan_human_judge(self, state: State, config: RunnableConfig) Args: state (State): 現在のステート。 - config (RunnableConfig): LangGraph 実行時の設定。 + config (RunnableConfig): LangGraph実行時の設定。 Returns: State: 判定結果を反映したステート。 """ - feedback = interrupt("編集しますか? y or n: ") + feedback = interrupt("調査計画を編集しますか?") if feedback == "y": state.research_plan_human_edit = True @@ -243,7 +243,7 @@ async def _node_deep_research(self, state: State, config: RunnableConfig): Args: state (State): 現在のステート。 - config (RunnableConfig): ランググラフ実行時の設定。 + config (RunnableConfig): LangGraph実行時の設定。 Returns: dict[str, list]: LLM 応答を追記したメッセージ差分。 @@ -296,13 +296,16 @@ def _node_prepare_research(self, state: State): params = state.research_parameters assert params + assert plan # 2. システムプロンプトをフォーマット - formatted_plan = str(plan) + formatted_plan = plan.model_dump() final_prompt_text = DEEP_RESEARCH_SYSTEM_PROMPT.format( - SEARCH_PLAN=formatted_plan, SEARCH_QUERIES_PER_SECTION=params.search_queries_per_section, + SEARCH_API="DuckDuckGo", SEARCH_ITERATIONS=params.search_iterations, + SEARCH_PLAN=formatted_plan, + CURRENT_DATE=datetime.date.today(), ) # 3. ReActエージェントへの初期メッセージを作成 @@ -398,8 +401,7 @@ def get_compiled_graph(self): compiled_graph = graph.compile(checkpointer=memory) # graph実行イメージ保存 - # graph_image = compiled_graph.get_graph().draw_mermaid_png() - # with open("./graph.png", "wb") as file: - # file.write(graph_image) - - return compiled_graph + graph_image = compiled_graph.get_graph().draw_mermaid_png() # pragma: no cover + with open("./graph.png", "wb") as file: # pragma: no cover + file.write(graph_image) # pragma: no cover + return compiled_graph # pragma: no cover diff --git a/src/backend/ai/reflect/reflect_search_result.py b/src/backend/ai/reflect/reflect_search_result.py index e7fc1f7..2dc588e 100644 --- a/src/backend/ai/reflect/reflect_search_result.py +++ b/src/backend/ai/reflect/reflect_search_result.py @@ -53,12 +53,12 @@ def __init__(self, llm): self.llm = llm self.structured_llm = llm.with_structured_output(ReflectionResultSchema) - def __call__(self, query, result) -> ReflectionResultSchema: + def __call__(self, query, results) -> ReflectionResultSchema: prompt = [ ( "system", SEARCH_RESULT_ANALYZE_AND_REFLECTION_SYSTEM_PROMPT.format( - query=query, result=result + query=query, results=results ), ) ] diff --git a/src/backend/ai/search/deep_research.py b/src/backend/ai/search/deep_research.py deleted file mode 100644 index 705ee05..0000000 --- a/src/backend/ai/search/deep_research.py +++ /dev/null @@ -1,28 +0,0 @@ -from ..schedule.plan_reserch import GeneratedObjectSchema -from .prompt import DEEP_RESEARCH_SYSTEM_PROMPT - - -class DeepResearchAI: - def __init__(self, llm): - self.llm = llm - - def __call__( - self, - search_queries_per_section: int, - search_iterations: int, - search_plan: GeneratedObjectSchema, - ): - prompt = [ - ( - "system", - DEEP_RESEARCH_SYSTEM_PROMPT.format( - SEARCH_QUERIES_PER_SECTION=search_queries_per_section, - SEARCH_API="DuckDuckGo", - SEARCH_ITERATIONS=search_iterations, - SEARCH_PLAN=search_plan, - ), - ) - ] - response = self.llm.invoke(prompt) - - return response diff --git a/src/backend/ai/search/prompt.py b/src/backend/ai/search/prompt.py index 546cb95..3182146 100644 --- a/src/backend/ai/search/prompt.py +++ b/src/backend/ai/search/prompt.py @@ -1,18 +1,16 @@ DEEP_RESEARCH_SYSTEM_PROMPT = """あなたは高度な調査エージェントです。ユーザーの質問に対して深く、包括的な調査を行います。 ## 調査プロセス -あなたは、調査計画と、調査(Research)のフェーズに従って調査を行います: +あなたは、調査計画と、調査(Research)の手順に従って調査を行います: ### 調査計画 {SEARCH_PLAN} ### 調査(Research) -0. あなたに、ユーザーのクエリと、従うべき調査計画がシステムプロンプトにより与えられます。 -1. get_current_dateツールを使用して現在の日付を取得します。 - - 取得した日付を使用して、情報の鮮度を評価します - - レポートに日付を明記し、いつの時点の情報かを明確にします +0. あなたに、従うべき調査計画、本日の日付がこのシステムプロンプトにより与えられます。 +1. ユーザーから調査対象のプロンプトが与えられます。 2. 各セクションについて、{SEARCH_QUERIES_PER_SECTION}個の検索クエリを作成します。 -3. 各クエリに対してDuckDuckGoを使用して情報を収集します。 +3. 各クエリに対して{SEARCH_API}を使用して情報を収集します。 - web_researchツールを使用してウェブ検索を実行します - 検索結果の日付を確認し、古い情報は注意して扱います 4. 検索結果を処理し、以下を行います: @@ -33,11 +31,13 @@ 5. リンクのテキストは文脈に自然に溶け込むようにし、URLそのものは表示しないようにします。 6. 主要な発見事項、結論、および参考文献を含めます。 7. 情報の日付を明記し、特に最新の情報(過去3ヶ月以内)は強調します。 -8. レポートの冒頭に調査日(get_current_dateで取得した日付)を明記します。 + - 本日の日付ではなく、参考にしている情報が公開または更新された日付を使用してください。 +8. レポートの冒頭に調査日(本日の日付)を明記します。 ## セクション出力形式 各セクションは以下の形式で出力してください: +```md # [セクションタイトル] [セクションの内容:事実、分析、洞察など。情報源へのリンクを含める。情報の日付を明記する] @@ -53,10 +53,15 @@ - [情報源1へのリンク] (日付: YYYY-MM-DD) - [情報源2へのリンク] (日付: YYYY-MM-DD) - [情報源3へのリンク] (日付: YYYY-MM-DD) +``` ## ツールの使用方法 -- get_current_date: 現在の日付を取得します - web_research: ウェブ検索を実行して情報を収集します - reflect_on_results: 検索結果を振り返り、次のクエリを改善します +## 本日の日付 +- {CURRENT_DATE} + - 本日の日付を使用して、情報の鮮度を評価します + - レポートに日付を明記し、いつの時点の情報かを明確にします + 常に批判的思考を用い、情報の信頼性を評価してください。複数の情報源を比較し、バランスの取れた見解を提供してください。""" diff --git a/src/backend/api/main.py b/src/backend/api/main.py index 93e9c43..810c2da 100644 --- a/src/backend/api/main.py +++ b/src/backend/api/main.py @@ -2,6 +2,7 @@ from __future__ import annotations +import logging import os from datetime import datetime, timezone from typing import Any, Dict @@ -11,8 +12,9 @@ from starlette.websockets import WebSocketState from .schemas import HealthResponse, InterruptPayload, StateResponse, ThreadListResponse -from .workflow import StateNotFoundError, workflow_service +from .workflow import StateNotFoundError, WorkflowService +workflow_service = WorkflowService() app = FastAPI(title="Deep Research API", version="1.0.0") _DEFAULT_ALLOWED_ORIGINS = { @@ -20,11 +22,21 @@ "http://127.0.0.1:3000", } +# Uvicorn の標準エラーロガー配下にぶら下げて、Docker コンソールへ確実に流す。 +logger = logging.getLogger("uvicorn.error").getChild("deep_research.api") +logger.setLevel(logging.INFO) +logger.propagate = True + def _resolve_allowed_origins() -> list[str]: - """CORS許可オリジンを環境変数から解決する。""" + """CORS許可オリジンを環境変数から解決する。 + + Returns: + list[str]: 許可するオリジンのリスト。環境変数が未設定の場合はデフォルト値を返す。 + """ raw_value = os.getenv("CORS_ALLOW_ORIGINS", "") + # 環境変数はカンマ区切りを想定、空要素を除外しつつ整理 candidates = [origin.strip() for origin in raw_value.split(",") if origin.strip()] if candidates: return candidates @@ -40,30 +52,50 @@ def _resolve_allowed_origins() -> list[str]: ) -async def _send_ws_events( - websocket: WebSocket, thread_id: str, events: list[Dict[str, Any]] -) -> None: - """WebSocketへ逐次イベントを送信する。""" - - for event in events: - await websocket.send_json( - { - "type": "event", - "thread_id": thread_id, - "payload": event, - } - ) +def _interrupt_from_raw(raw: Dict[str, Any] | None) -> InterruptPayload | None: + """未加工データから割り込みペイロードを生成する。 + Args: + raw (Dict[str, Any] | None): ワークフローから返された割り込みデータ。 -def _interrupt_from_raw(raw: Dict[str, Any] | None) -> InterruptPayload | None: + Returns: + InterruptPayload | None: バリデーションに成功した割り込みペイロード。入力が不正な場合は `None`。 + """ if not raw: return None return InterruptPayload.model_validate(raw) +def _extract_event_error_message(event: Dict[str, Any]) -> str: + """イベントペイロードから代表的なエラーメッセージを抽出する。 + + Args: + event (Dict[str, Any]): ワークフローから受け取ったイベントペイロード。 + + Returns: + str: 抽出したエラーメッセージ。候補が存在しない場合は汎用的なメッセージを返す。 + """ + message = event.get("message") + if isinstance(message, str) and message: + return message + + payload = event.get("data") or event.get("payload") + if isinstance(payload, dict): + for key in ("message", "error", "text", "details"): + value = payload.get(key) + if isinstance(value, str) and value: + return value + + return "処理中にエラーが発生しました。" + + @app.get("/healthz", response_model=HealthResponse, tags=["system"]) async def healthcheck() -> HealthResponse: - """システムの稼働状況を返すヘルスチェック。""" + """システムの稼働状況を返すヘルスチェック。 + + Returns: + HealthResponse: システムステータスと診断情報を含むレスポンス。 + """ diagnostics = workflow_service.diagnostics() return HealthResponse( @@ -75,7 +107,11 @@ async def healthcheck() -> HealthResponse: @app.get("/threads", response_model=ThreadListResponse, tags=["workflow"]) async def list_threads() -> ThreadListResponse: - """アクティブなスレッドおよび割り込み待ちスレッドの一覧を返す。""" + """アクティブなスレッドおよび割り込み待ちスレッドの一覧を返す。 + + Returns: + ThreadListResponse: スレッドIDと件数を含むレスポンスデータ。 + """ active = workflow_service.list_active_threads() pending = workflow_service.list_pending_interrupts() @@ -93,7 +129,17 @@ async def list_threads() -> ThreadListResponse: tags=["workflow"], ) async def get_thread_state(thread_id: str) -> StateResponse: - """スレッドの最新状態スナップショットを取得する。""" + """スレッドの最新状態スナップショットを取得する。 + + Args: + thread_id (str): 状態を取得したいスレッドID。 + + Returns: + StateResponse: スレッドのステータス、状態、割り込み情報を含むレスポンス。 + + Raises: + HTTPException: スレッドが見つからない場合に ``404`` を投げる。 + """ try: snapshot = workflow_service.get_state(thread_id) @@ -110,30 +156,92 @@ async def get_thread_state(thread_id: str) -> StateResponse: @app.websocket("/ws/research") async def websocket_research(websocket: WebSocket) -> None: - """WebSocket経由でHITL対応のリサーチ実行を提供する。""" + """WebSocket経由でHITL対応のリサーチ実行を提供する。 + + Args: + websocket (WebSocket): クライアント接続を表すWebSocketインスタンス。 + Returns: + None: このエンドポイントはレスポンスボディを返さない。 + + Raises: + WebSocketDisconnect: クライアント切断時に内部的に発生する。 + Exception: 想定外のエラーが発生した場合に送信し、接続を終了する。 + """ + + # WebSocket接続を確立し、スレッドIDを後続処理で共有する。 await websocket.accept() thread_id: str | None = None + close_reason = "initialized" try: + # クライアントから初期クエリを受信し、空文字列でないことを確認する。 initial_payload = await websocket.receive_json() query = (initial_payload.get("query") or "").strip() if not query: + logger.warning("WebSocket closing due to empty query payload") + close_reason = "invalid_query" await websocket.send_json({"type": "error", "message": "query が空です。"}) await websocket.close(code=4000) return + # 新しいスレッドを生成し、クライアントへ開始イベントを通知する。 thread_id = workflow_service.create_thread_id() + logger.info( + "WebSocket session started [thread_id=%s, query=%s]", + thread_id, + query[:200], + ) await websocket.send_json({"type": "thread_started", "thread_id": thread_id}) + async def forward_event(event: Dict[str, Any]) -> None: + # ワークフローからのイベントをそのままフロントに中継する。 + event_name = event.get("event") + logger.debug( + "Forwarding workflow event [thread_id=%s, event=%s]", + thread_id, + event_name, + ) + await websocket.send_json( + { + "type": "event", + "thread_id": thread_id, + "payload": event, + } + ) + if event.get("level") == "error": + error_message = _extract_event_error_message(event) + logger.error( + "Workflow error event forwarded [thread_id=%s, event=%s]: %s", + thread_id, + event.get("event"), + error_message, + ) + await websocket.send_json( + { + "type": "error", + "thread_id": thread_id, + "message": error_message, + } + ) + + # ワークフローを開始し、最初の割り込みまたは完了まで待機する。 outcome = await workflow_service.start_research( thread_id=thread_id, query=query, + event_consumer=forward_event, + ) + logger.info( + "Workflow start completed [thread_id=%s, status=%s, events=%d, interrupt=%s]", + thread_id, + outcome.status, + len(outcome.events), + bool(outcome.interrupt), ) - await _send_ws_events(websocket, thread_id, outcome.events) while True: if outcome.status == "completed": + logger.info("Thread completed [thread_id=%s]", thread_id) await websocket.send_json( { "type": "complete", @@ -141,20 +249,33 @@ async def websocket_research(websocket: WebSocket) -> None: "state": outcome.state, } ) + close_reason = "completed" await websocket.close(code=1000) return interrupt_payload = _interrupt_from_raw(outcome.interrupt) if not interrupt_payload: + logger.error( + "Missing interrupt payload [thread_id=%s, status=%s]", + thread_id, + outcome.status, + ) + close_reason = "missing_interrupt" await websocket.send_json( { "type": "error", + "thread_id": thread_id, "message": "割り込み情報が取得できませんでした。", } ) await websocket.close(code=1011) return + logger.info( + "Interrupt dispatched [thread_id=%s, interrupt_id=%s]", + thread_id, + interrupt_payload.id, + ) await websocket.send_json( { "type": "interrupt", @@ -163,31 +284,71 @@ async def websocket_research(websocket: WebSocket) -> None: } ) + # クライアント側の意思決定を待ち、承認または再計画を処理する。 resume_payload = await websocket.receive_json() decision = (resume_payload.get("decision") or "").lower() if decision not in {"y", "n"}: + logger.warning( + "Invalid decision received [thread_id=%s, decision=%s]", + thread_id, + resume_payload.get("decision"), + ) await websocket.send_json( { "type": "error", + "thread_id": thread_id, "message": "decision は 'y' または 'n' を指定してください。", } ) continue plan_update = resume_payload.get("plan") + logger.info( + "Resuming workflow [thread_id=%s, decision=%s, has_plan_update=%s]", + thread_id, + decision, + plan_update is not None, + ) outcome = await workflow_service.resume_research( thread_id=thread_id, decision=decision, plan_update=plan_update, + event_consumer=forward_event, + ) + logger.info( + "Workflow resumed [thread_id=%s, status=%s, events=%d, interrupt=%s]", + thread_id, + outcome.status, + len(outcome.events), + bool(outcome.interrupt), ) - await _send_ws_events(websocket, thread_id, outcome.events) - except WebSocketDisconnect: # pragma: no cover - 切断時 + except WebSocketDisconnect: + close_reason = "client_disconnect" + logger.info("WebSocket disconnected by client [thread_id=%s]", thread_id) return - except Exception as exc: # pragma: no cover - 想定外エラー + except Exception as exc: + close_reason = f"exception:{exc.__class__.__name__}" + logger.exception( + "Unhandled exception in websocket_research [thread_id=%s]", thread_id + ) if websocket.application_state == WebSocketState.CONNECTED: - await websocket.send_json({"type": "error", "message": str(exc)}) + error_payload = {"type": "error", "message": str(exc)} + if thread_id: + error_payload["thread_id"] = thread_id + await websocket.send_json(error_payload) await websocket.close(code=1011) finally: if websocket.application_state == WebSocketState.CONNECTED: + logger.info( + "Closing WebSocket session [thread_id=%s, reason=%s]", + thread_id, + close_reason, + ) await websocket.close(code=1000) + else: + logger.info( + "WebSocket session finalized [thread_id=%s, reason=%s]", + thread_id, + close_reason, + ) diff --git a/src/backend/api/workflow.py b/src/backend/api/workflow.py deleted file mode 100644 index abf155b..0000000 --- a/src/backend/api/workflow.py +++ /dev/null @@ -1,404 +0,0 @@ -"""Deep Researchワークフローの実行管理ロジック。""" - -from __future__ import annotations - -import json -import os -import uuid -from dataclasses import dataclass -from typing import Any, AsyncGenerator, Callable, Dict, Optional - -from langchain_core.runnables import RunnableConfig -from langgraph.types import Command, Interrupt - -from ..agent import OSSDeepResearchAgent - -_STREAM_VERSION = "v1" -_DEFAULT_RECURSION_LIMIT = 100 - - -class WorkflowError(Exception): - """ワークフロー操作時に発生する例外の基底クラス。""" - - -class StateNotFoundError(WorkflowError): - """指定したスレッドの状態が見つからない場合に送出する。""" - - -class HitlNotEnabledError(WorkflowError): - """HITL モードが無効なスレッドに対して操作を行った場合に送出する。""" - - -class InterruptNotFoundError(WorkflowError): - """保留中割り込みが存在しない場合に送出する。""" - - -@dataclass -class RunOutcome: - """ワークフロー実行結果を表現するデータクラス。""" - - status: str - state: Dict[str, Any] - events: list[Dict[str, Any]] - interrupt: Dict[str, Any] | None - - -@dataclass -class StateSnapshot: - """スレッド状態取得結果を表現するデータクラス。""" - - status: str - state: Dict[str, Any] - pending_interrupt: Dict[str, Any] | None - - -class WorkflowService: - """Deep Researchワークフロー実行を統括するサービス。""" - - def __init__(self) -> None: - """サービスを初期化し、LangGraphエージェントを構築する。""" - - self._agent = OSSDeepResearchAgent() - self._graph = self._agent.get_compiled_graph() - self._pending_interrupts: dict[str, Interrupt] = {} - self._hitl_threads: set[str] = set() - self._recursion_limit = self._load_recursion_limit() - - # 公開API --------------------------------------------------------------- - - def create_thread_id(self) -> str: - """新しいスレッドIDを生成する。""" - - return str(uuid.uuid4()) - - async def start_research( - self, - *, - thread_id: str, - query: str, - ) -> RunOutcome: - """ワークフローを開始し、必ずHITL割り込みポイントまで実行する。""" - - self._register_hitl_thread(thread_id) - initial_payload: Dict[str, Any] = {"user_input": query} - events, pending, finished, snapshot = await self._run_until_pause( - initial_payload, - thread_id=thread_id, - auto_resume=False, - interrupt_predicate=self._is_plan_edit_interrupt, - ) - self._record_post_run(thread_id, pending, finished) - state = self._serialize_state(thread_id, snapshot) - status = "completed" if finished else "pending_human" - interrupt_dict = self._serialize_interrupt(pending) if pending else None - return RunOutcome( - status=status, state=state, events=events, interrupt=interrupt_dict - ) - - async def resume_research( - self, - *, - thread_id: str, - decision: str, - plan_update: Any | None, - ) -> RunOutcome: - """保留中割り込みへの回答を用いてワークフローを再開する。""" - - if not self._is_hitl_thread(thread_id): - raise HitlNotEnabledError("このスレッドはHITLモードで開始されていません。") - - pending = self._get_pending_interrupt(thread_id) - if pending is None: - raise InterruptNotFoundError("待機中の割り込みは見つかりません。") - - command_kwargs: Dict[str, Any] = {"resume": {pending.id: decision}} - if plan_update is not None: - command_kwargs["update"] = {"research_plan": plan_update} - - events, next_pending, finished, snapshot = await self._run_until_pause( - Command(**command_kwargs), - thread_id=thread_id, - auto_resume=False, - interrupt_predicate=self._is_plan_edit_interrupt, - ) - self._record_post_run(thread_id, next_pending, finished) - - state = self._serialize_state(thread_id, snapshot) - status = "completed" if finished else "pending_human" - interrupt_dict = ( - self._serialize_interrupt(next_pending) if next_pending else None - ) - return RunOutcome( - status=status, state=state, events=events, interrupt=interrupt_dict - ) - - def get_state(self, thread_id: str) -> StateSnapshot: - """スレッドIDに紐づく最新状態を取得する。""" - - snapshot = self._graph.get_state(self._graph_config(thread_id)) - if snapshot is None: - raise StateNotFoundError("指定したスレッドの状態が見つかりません。") - - pending = self._get_pending_interrupt(thread_id) - status = "pending_human" - if not pending: - status = "completed" if self._is_run_finished(snapshot) else "running" - - state = self._serialize_state(thread_id, snapshot) - interrupt_dict = self._serialize_interrupt(pending) if pending else None - return StateSnapshot( - status=status, state=state, pending_interrupt=interrupt_dict - ) - - def diagnostics(self) -> Dict[str, Any]: - """サービス全体の診断情報を取得する。""" - - return { - "active_threads": len(self._hitl_threads), - "pending_interrupts": len(self._pending_interrupts), - "recursion_limit": self._recursion_limit, - } - - def list_active_threads(self) -> list[str]: - """現在アクティブなスレッドID一覧を返す。""" - - return sorted(self._hitl_threads) - - def list_pending_interrupts(self) -> list[str]: - """割り込み回答待ちのスレッドID一覧を返す。""" - - return sorted(self._pending_interrupts.keys()) - - async def stream_events( - self, - *, - thread_id: str, - query: str, - auto_resume: bool = True, - interrupt_predicate: Callable[[Interrupt], bool] | None = None, - ) -> AsyncGenerator[str, None]: - """ワークフロー実行イベントをストリームとして返す。""" - - initial_payload: Dict[str, Any] = {"user_input": query} - async for event in self._astream( - initial_payload, - thread_id=thread_id, - auto_resume=auto_resume, - interrupt_predicate=interrupt_predicate, - ): - yield self._format_sse(event) - - final_state = self._serialize_state(thread_id) - yield self._format_sse( - { - "event": "state_snapshot", - "name": "final_state", - "data": {"thread_id": thread_id, "state": final_state}, - } - ) - - def render_event(self, event: Dict[str, Any]) -> str: - """任意のイベント辞書をSSEフレーム文字列へ変換する。""" - - return self._format_sse(event) - - # 内部ユーティリティ --------------------------------------------------- - - def _graph_config(self, thread_id: str) -> RunnableConfig: - return { - "configurable": {"thread_id": thread_id}, - "recursion_limit": self._recursion_limit, - } - - def _load_recursion_limit(self) -> int: - raw_value = os.getenv("GRAPH_RECURSION_LIMIT") - try: - limit = ( - int(raw_value) if raw_value is not None else _DEFAULT_RECURSION_LIMIT - ) - except (TypeError, ValueError): - limit = _DEFAULT_RECURSION_LIMIT - return max(limit, 1) - - def _register_hitl_thread(self, thread_id: str) -> None: - self._hitl_threads.add(thread_id) - self._pending_interrupts.pop(thread_id, None) - - def _record_post_run( - self, - thread_id: str, - pending: Optional[Interrupt], - finished: bool, - ) -> None: - if pending and not finished: - self._hitl_threads.add(thread_id) - self._pending_interrupts[thread_id] = pending - return - self._hitl_threads.discard(thread_id) - self._pending_interrupts.pop(thread_id, None) - - def _is_hitl_thread(self, thread_id: str) -> bool: - return thread_id in self._hitl_threads - - def _get_pending_interrupt(self, thread_id: str) -> Optional[Interrupt]: - if not self._is_hitl_thread(thread_id): - return None - return self._pending_interrupts.get(thread_id) - - def _serialize_interrupt( - self, interrupt: Interrupt | None - ) -> Dict[str, Any] | None: - if interrupt is None: - return None - return {"id": interrupt.id, "value": interrupt.value} - - def _extract_interrupt(self, event: Any) -> Optional[Interrupt]: - data = event.get("data") - if not isinstance(data, dict): - return None - - payload: Any | None = None - if event.get("event") == "on_chain_stream": - payload = data.get("chunk") - elif event.get("event") == "on_chain_end": - payload = data.get("output") - - if isinstance(payload, dict) and "__interrupt__" in payload: - interrupts = payload["__interrupt__"] - if isinstance(interrupts, (list, tuple)) and interrupts: - candidate = interrupts[-1] - if isinstance(candidate, Interrupt): - return candidate - return None - - def _sanitize_event(self, event: Any) -> Dict[str, Any]: - if isinstance(event, dict): - return {key: self._convert_model(value) for key, value in event.items()} - return {"event": "message", "data": self._convert_model(event)} - - def _convert_model(self, obj: Any) -> Any: - if hasattr(obj, "model_dump"): - return obj.model_dump() - if isinstance(obj, Interrupt): - return {"id": obj.id, "value": obj.value} - if isinstance(obj, dict): - return {k: self._convert_model(v) for k, v in obj.items()} - if isinstance(obj, (list, tuple)): - return [self._convert_model(v) for v in obj] - if isinstance(obj, (str, int, float, bool)) or obj is None: - return obj - return str(obj) - - def _serialize_state( - self, thread_id: str, snapshot: Any | None = None - ) -> Dict[str, Any]: - if snapshot is None: - snapshot = self._graph.get_state(self._graph_config(thread_id)) - if snapshot is None: - raise StateNotFoundError("指定したスレッドの状態が見つかりません。") - values = getattr(snapshot, "values", {}) - return {k: self._convert_model(v) for k, v in dict(values).items()} - - def _is_plan_edit_interrupt(self, interrupt: Interrupt) -> bool: - prompt = getattr(interrupt, "value", "") - if isinstance(prompt, str) and "編集しますか" in prompt: - return True - interrupt_id = getattr(interrupt, "id", "") - return ( - isinstance(interrupt_id, str) - and "_research_plan_human_judge" in interrupt_id - ) - - def _is_run_finished(self, snapshot: Any | None) -> bool: - if snapshot is None: - return False - return not getattr(snapshot, "next", None) - - async def _run_until_pause( - self, - payload: Any, - *, - thread_id: str, - auto_resume: bool, - interrupt_predicate: Callable[[Interrupt], bool] | None = None, - ) -> tuple[list[Dict[str, Any]], Optional[Interrupt], bool, Any | None]: - config = self._graph_config(thread_id) - current_payload: Any = payload - collected_events: list[Dict[str, Any]] = [] - - while True: - pending: Interrupt | None = None - async for event in self._graph.astream_events( - current_payload, config=config, version=_STREAM_VERSION - ): - collected_events.append(self._sanitize_event(event)) - pending = self._extract_interrupt(event) - if pending: - break - - snapshot = self._graph.get_state(config) - finished = self._is_run_finished(snapshot) - - if pending: - allowed = interrupt_predicate(pending) if interrupt_predicate else True - if auto_resume or not allowed: - auto_event = { - "event": "auto_resume", - "name": "human_judge", - "data": {"decision": "n", "thread_id": thread_id}, - } - collected_events.append(self._sanitize_event(auto_event)) - current_payload = Command(resume={pending.id: "n"}) - continue - return collected_events, pending, finished, snapshot - - return collected_events, None, finished, snapshot - - async def _astream( - self, - payload: Any, - *, - thread_id: str, - auto_resume: bool, - interrupt_predicate: Callable[[Interrupt], bool] | None = None, - ) -> AsyncGenerator[Any, None]: - config = self._graph_config(thread_id) - current_payload: Any = payload - - while True: - pending: Interrupt | None = None - async for event in self._graph.astream_events( - current_payload, config=config, version=_STREAM_VERSION - ): - yield self._sanitize_event(event) - pending = self._extract_interrupt(event) - if pending: - break - if not pending: - break - - allowed = interrupt_predicate(pending) if interrupt_predicate else True - if auto_resume or not allowed: - yield { - "event": "auto_resume", - "name": "human_judge", - "data": {"decision": "n", "thread_id": thread_id}, - } - current_payload = Command(resume={pending.id: "n"}) - continue - - yield { - "event": "interrupt", - "name": pending.id, - "data": self._serialize_interrupt(pending), - } - return - - def _format_sse(self, event: Dict[str, Any]) -> str: - payload = json.dumps(event, default=self._convert_model, ensure_ascii=False) - event_type = event.get("event", "message") - return f"event: {event_type}\ndata: {payload}\n\n" - - -workflow_service = WorkflowService() -"""アプリ全体で共有するワークフローサービスシングルトン。""" diff --git a/src/backend/api/workflow/__init__.py b/src/backend/api/workflow/__init__.py new file mode 100644 index 0000000..9380abb --- /dev/null +++ b/src/backend/api/workflow/__init__.py @@ -0,0 +1,25 @@ +"""Deep Researchワークフロー関連エントリポイント。""" + +from __future__ import annotations + +from .constants import DEFAULT_RECURSION_LIMIT, STREAM_VERSION +from .errors import ( + HitlNotEnabledError, + InterruptNotFoundError, + StateNotFoundError, + WorkflowError, +) +from .models import RunOutcome, StateSnapshot +from .service import WorkflowService + +__all__ = [ + "DEFAULT_RECURSION_LIMIT", + "STREAM_VERSION", + "WorkflowError", + "StateNotFoundError", + "HitlNotEnabledError", + "InterruptNotFoundError", + "RunOutcome", + "StateSnapshot", + "WorkflowService", +] diff --git a/src/backend/api/workflow/constants.py b/src/backend/api/workflow/constants.py new file mode 100644 index 0000000..259cf0b --- /dev/null +++ b/src/backend/api/workflow/constants.py @@ -0,0 +1,4 @@ +"""Workflowサービスで利用する定数群。""" + +STREAM_VERSION = "v1" +DEFAULT_RECURSION_LIMIT = 100 diff --git a/src/backend/api/workflow/errors.py b/src/backend/api/workflow/errors.py new file mode 100644 index 0000000..eea9d18 --- /dev/null +++ b/src/backend/api/workflow/errors.py @@ -0,0 +1,17 @@ +"""Workflowサービスで発生し得る例外クラス群。""" + + +class WorkflowError(Exception): + """ワークフロー操作時に発生する例外の基底クラス。""" + + +class StateNotFoundError(WorkflowError): + """指定したスレッドの状態が見つからない場合に送出する。""" + + +class HitlNotEnabledError(WorkflowError): + """HITL モードが無効なスレッドに対して操作を行った場合に送出する。""" + + +class InterruptNotFoundError(WorkflowError): + """保留中割り込みが存在しない場合に送出する。""" diff --git a/src/backend/api/workflow/models.py b/src/backend/api/workflow/models.py new file mode 100644 index 0000000..f788e53 --- /dev/null +++ b/src/backend/api/workflow/models.py @@ -0,0 +1,25 @@ +"""Workflowサービスで利用するデータモデル。""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any, Dict + + +@dataclass +class RunOutcome: + """ワークフロー実行結果を表現するデータクラス。""" + + status: str + state: Dict[str, Any] + events: list[Dict[str, Any]] + interrupt: Dict[str, Any] | None + + +@dataclass +class StateSnapshot: + """スレッド状態取得結果を表現するデータクラス。""" + + status: str + state: Dict[str, Any] + pending_interrupt: Dict[str, Any] | None diff --git a/src/backend/api/workflow/service.py b/src/backend/api/workflow/service.py new file mode 100644 index 0000000..9e3bc46 --- /dev/null +++ b/src/backend/api/workflow/service.py @@ -0,0 +1,655 @@ +"""WorkflowService本体の実装。""" + +from __future__ import annotations + +import json +import logging +import os +import uuid +from typing import Any, AsyncGenerator, Awaitable, Callable, Dict, Optional + +from langchain_core.runnables import RunnableConfig +from langgraph.types import Command, Interrupt + +from src.backend.agent import OSSDeepResearchAgent +from .constants import DEFAULT_RECURSION_LIMIT, STREAM_VERSION +from .errors import ( + HitlNotEnabledError, + InterruptNotFoundError, + StateNotFoundError, +) +from .models import RunOutcome, StateSnapshot + +logger = logging.getLogger(__name__) + + +class WorkflowService: + """Deep Researchワークフロー実行を統括するサービス。""" + + def __init__(self) -> None: + """サービスを初期化し、LangGraphエージェントを構築する。 + + Raises: + RuntimeError: エージェントの初期化に失敗した場合。 + """ + + self._agent = OSSDeepResearchAgent() + self._graph = self._agent.get_compiled_graph() + self._pending_interrupts: dict[str, Interrupt] = {} + self._hitl_threads: set[str] = set() + self._recursion_limit = self._load_recursion_limit() + + def create_thread_id(self) -> str: + """新しいスレッドIDを生成する。 + + Returns: + str: UUIDv4形式で生成したスレッドID。 + """ + + return str(uuid.uuid4()) + + async def start_research( + self, + *, + thread_id: str, + query: str, + event_consumer: Callable[[Dict[str, Any]], Awaitable[None]] | None = None, + ) -> RunOutcome: + """ワークフローを開始し、HITL割り込み発生位置まで実行する。 + + Args: + thread_id (str): 実行コンテキストを識別するスレッドID。 + query (str): ユーザーからのリサーチ要求テキスト。 + event_consumer (Callable[[Dict[str, Any]], Awaitable[None]] | None): + 実行途中のイベントを逐次処理するコールバック。未指定の場合はイベントを転送しない。 + + Returns: + RunOutcome: 実行結果の状態、イベント群、割り込み情報を含むオブジェクト。 + """ + + self._register_hitl_thread(thread_id) + initial_payload: Dict[str, Any] = {"user_input": query} + events, pending, finished, snapshot = await self._run_until_pause( + initial_payload, + thread_id=thread_id, + auto_resume=False, + interrupt_predicate=self._is_plan_edit_interrupt, + event_consumer=event_consumer, + ) + self._record_post_run(thread_id, pending, finished) + state = self._serialize_state(thread_id, snapshot) + status = "completed" if finished else "pending_human" + interrupt_dict = self._serialize_interrupt(pending) if pending else None + return RunOutcome( + status=status, state=state, events=events, interrupt=interrupt_dict + ) + + async def resume_research( + self, + *, + thread_id: str, + decision: str, + plan_update: Any | None, + event_consumer: Callable[[Dict[str, Any]], Awaitable[None]] | None = None, + ) -> RunOutcome: + """保留中割り込みへの回答を用いてワークフローを再開する。 + + Args: + thread_id (str): 再開対象のスレッドID。 + decision (str): ヒューマンジャッジの判定結果。``"y"`` または ``"n"`` を想定する。 + plan_update (Any | None): 改訂後の調査計画。判定が ``"n"`` の場合は ``None``。 + event_consumer (Callable[[Dict[str, Any]], Awaitable[None]] | None): + 追加イベントを処理するコールバック。指定しない場合は転送しない。 + + Returns: + RunOutcome: 再開後の進捗、イベント、割り込み状態を表すオブジェクト。 + + Raises: + HitlNotEnabledError: スレッドがHITLモードで開始されていない場合。 + InterruptNotFoundError: 待機中の割り込みが存在しない場合。 + """ + + if not self._is_hitl_thread(thread_id): + raise HitlNotEnabledError("このスレッドはHITLモードで開始されていません。") + + pending = self._get_pending_interrupt(thread_id) + if pending is None: + raise InterruptNotFoundError("待機中の割り込みは見つかりません。") + + command_kwargs: Dict[str, Any] = {"resume": {pending.id: decision}} + if plan_update is not None: + command_kwargs["update"] = {"research_plan": plan_update} + + events, next_pending, finished, snapshot = await self._run_until_pause( + Command(**command_kwargs), + thread_id=thread_id, + auto_resume=False, + interrupt_predicate=self._is_plan_edit_interrupt, + event_consumer=event_consumer, + ) + self._record_post_run(thread_id, next_pending, finished) + + state = self._serialize_state(thread_id, snapshot) + status = "completed" if finished else "pending_human" + interrupt_dict = ( + self._serialize_interrupt(next_pending) if next_pending else None + ) + return RunOutcome( + status=status, state=state, events=events, interrupt=interrupt_dict + ) + + def get_state(self, thread_id: str) -> StateSnapshot: + """スレッドIDに紐づく最新状態を取得する。 + + Args: + thread_id (str): 状態を確認したいスレッドID。 + + Returns: + StateSnapshot: 現在の状態値、ステータス、割り込み情報をまとめたスナップショット。 + + Raises: + StateNotFoundError: 状態が永続層から取得できなかった場合。 + """ + + snapshot = self._graph.get_state(self._graph_config(thread_id)) + if snapshot is None: + raise StateNotFoundError("指定したスレッドの状態が見つかりません。") + + pending = self._get_pending_interrupt(thread_id) + status = "pending_human" + if not pending: + status = "completed" if self._is_run_finished(snapshot) else "running" + + state = self._serialize_state(thread_id, snapshot) + interrupt_dict = self._serialize_interrupt(pending) if pending else None + return StateSnapshot( + status=status, state=state, pending_interrupt=interrupt_dict + ) + + def diagnostics(self) -> Dict[str, Any]: + """サービス全体の診断情報を取得する。 + + Returns: + Dict[str, Any]: 稼働中スレッド数や再帰制限などのメトリクス。 + """ + + return { + "active_threads": len(self._hitl_threads), + "pending_interrupts": len(self._pending_interrupts), + "recursion_limit": self._recursion_limit, + } + + def list_active_threads(self) -> list[str]: + """現在アクティブなスレッドID一覧を返す。 + + Returns: + list[str]: 稼働中ワークフローのスレッドIDを昇順に並べたリスト。 + """ + + return sorted(self._hitl_threads) + + def list_pending_interrupts(self) -> list[str]: + """割り込み回答待ちのスレッドID一覧を返す。 + + Returns: + list[str]: 人手介入待ちのスレッドIDを昇順で並べたリスト。 + """ + + return sorted(self._pending_interrupts.keys()) + + async def stream_events( + self, + *, + thread_id: str, + query: str, + auto_resume: bool = True, + interrupt_predicate: Callable[[Interrupt], bool] | None = None, + ) -> AsyncGenerator[str, None]: + """ワークフロー実行イベントをストリームとして返す。 + + Args: + thread_id (str): イベントを取得したいスレッドID。 + query (str): 初回実行時に投入するリサーチクエリ。 + auto_resume (bool): 割り込み発生時に自動で ``"n"`` 応答を返すかどうか。 + interrupt_predicate (Callable[[Interrupt], bool] | None): + 人手介入が必要か判定するコールバック。未指定の場合は常に許可する。 + + Yields: + str: SSEフォーマットに変換されたイベント文字列。 + """ + + initial_payload: Dict[str, Any] = {"user_input": query} + async for event in self._astream( + initial_payload, + thread_id=thread_id, + auto_resume=auto_resume, + interrupt_predicate=interrupt_predicate, + ): + yield self._format_sse(event) + + final_state = self._serialize_state(thread_id) + yield self._format_sse( + { + "event": "state_snapshot", + "name": "final_state", + "data": {"thread_id": thread_id, "state": final_state}, + } + ) + + def _graph_config(self, thread_id: str) -> RunnableConfig: + """グラフ実行時の設定を生成する。 + + Args: + thread_id (str): 実行対象のスレッドID。 + + Returns: + RunnableConfig: LangGraphに与える設定オブジェクト。 + """ + + return { + "configurable": {"thread_id": thread_id}, + "recursion_limit": self._recursion_limit, + } + + def _load_recursion_limit(self) -> int: + """再帰回数の上限値を環境変数から読み込む。 + + Returns: + int: 1以上の再帰上限。環境変数が不正な場合は既定値を返す。 + """ + + raw_value = os.getenv("GRAPH_RECURSION_LIMIT") + try: + limit = int(raw_value) if raw_value is not None else DEFAULT_RECURSION_LIMIT + except (TypeError, ValueError): + limit = DEFAULT_RECURSION_LIMIT + return max(limit, 1) + + def _register_hitl_thread(self, thread_id: str) -> None: + """HITL対象スレッドとして登録し、不要な割り込みを初期化する。 + + Args: + thread_id (str): 登録したいスレッドID。 + """ + + self._hitl_threads.add(thread_id) + self._pending_interrupts.pop(thread_id, None) + + def _record_post_run( + self, + thread_id: str, + pending: Optional[Interrupt], + finished: bool, + ) -> None: + """実行後の割り込み状態を記録する。 + + Args: + thread_id (str): スレッドID。 + pending (Optional[Interrupt]): 次回の割り込み。存在しない場合は ``None``。 + finished (bool): 実行が完了したかどうか。 + """ + + if pending and not finished: + self._hitl_threads.add(thread_id) + self._pending_interrupts[thread_id] = pending + return + self._hitl_threads.discard(thread_id) + self._pending_interrupts.pop(thread_id, None) + + def _is_hitl_thread(self, thread_id: str) -> bool: + """スレッドがHITL対象か判定する。 + + Args: + thread_id (str): 判定対象のスレッドID。 + + Returns: + bool: HITL対象であれば ``True``。 + """ + + return thread_id in self._hitl_threads + + def _get_pending_interrupt(self, thread_id: str) -> Optional[Interrupt]: + """スレッドに紐づく保留中割り込みを取得する。 + + Args: + thread_id (str): 取得対象のスレッドID。 + + Returns: + Optional[Interrupt]: 保留中割り込み。存在しない場合は ``None``。 + """ + + if not self._is_hitl_thread(thread_id): + return None + return self._pending_interrupts.get(thread_id) + + def _serialize_interrupt( + self, interrupt: Interrupt | None + ) -> Dict[str, Any] | None: + """割り込みオブジェクトをシリアライズする。 + + Args: + interrupt (Interrupt | None): 変換対象の割り込み。 + + Returns: + Dict[str, Any] | None: ``id`` と ``value`` を持つ辞書。引数が ``None`` の場合は ``None``。 + """ + + if interrupt is None: + return None + return {"id": interrupt.id, "value": interrupt.value} + + def _extract_interrupt(self, event: Any) -> Optional[Interrupt]: + """イベントから割り込みを抽出する。 + + Args: + event (Any): LangGraphから受信したイベントデータ。 + + Returns: + Optional[Interrupt]: 抽出した割り込み。含まれない場合は ``None``。 + """ + + data = event.get("data") + if not isinstance(data, dict): + return None + + payload: Any | None = None + if event.get("event") == "on_chain_stream": + payload = data.get("chunk") + elif event.get("event") == "on_chain_end": + payload = data.get("output") + + if isinstance(payload, dict) and "__interrupt__" in payload: + interrupts = payload["__interrupt__"] + if isinstance(interrupts, (list, tuple)) and interrupts: + candidate = interrupts[-1] + if isinstance(candidate, Interrupt): + return candidate + return None + + def _sanitize_event(self, event: Any) -> Dict[str, Any]: + """イベントデータをJSONシリアライズ可能な形へ整形する。 + + Args: + event (Any): 整形対象のイベント。 + + Returns: + Dict[str, Any]: キーと値を安全に変換した辞書。 + """ + + if isinstance(event, dict): + return {key: self._convert_model(value) for key, value in event.items()} + return {"event": "message", "data": self._convert_model(event)} + + def _convert_model(self, obj: Any) -> Any: + """各種オブジェクトを辞書またはプリミティブに変換する。 + + Args: + obj (Any): 変換対象のオブジェクト。 + + Returns: + Any: JSON互換の値。 + """ + + if hasattr(obj, "model_dump"): + return obj.model_dump() + if isinstance(obj, Interrupt): + return {"id": obj.id, "value": obj.value} + if isinstance(obj, dict): + return {k: self._convert_model(v) for k, v in obj.items()} + if isinstance(obj, (list, tuple)): + return [self._convert_model(v) for v in obj] + if isinstance(obj, (str, int, float, bool)) or obj is None: + return obj + return str(obj) + + def _serialize_state( + self, thread_id: str, snapshot: Any | None = None + ) -> Dict[str, Any]: + """LangGraphの状態スナップショットを辞書化する。 + + Args: + thread_id (str): 状態を取得するスレッドID。 + snapshot (Any | None): 既に取得済みのスナップショット。未指定時は内部で取得する。 + + Returns: + Dict[str, Any]: JSON互換の状態辞書。 + + Raises: + StateNotFoundError: スレッド状態が取得できなかった場合。 + """ + + if snapshot is None: + snapshot = self._graph.get_state(self._graph_config(thread_id)) + if snapshot is None: + raise StateNotFoundError("指定したスレッドの状態が見つかりません。") + values = getattr(snapshot, "values", {}) + return {k: self._convert_model(v) for k, v in dict(values).items()} + + def _is_plan_edit_interrupt(self, interrupt: Interrupt) -> bool: + """割り込みが調査計画編集に関するものか判定する。 + + Args: + interrupt (Interrupt): 判定対象の割り込み。 + + Returns: + bool: 調査計画編集であれば ``True``。 + """ + + prompt = getattr(interrupt, "value", "") + if isinstance(prompt, str) and "編集しますか" in prompt: + return True + interrupt_id = getattr(interrupt, "id", "") + return ( + isinstance(interrupt_id, str) + and "_research_plan_human_judge" in interrupt_id + ) + + def _is_run_finished(self, snapshot: Any | None) -> bool: + """実行が完了しているかどうかを判定する。 + + Args: + snapshot (Any | None): 現在のスナップショット。 + + Returns: + bool: 続行不能であれば ``True``。 + """ + + if snapshot is None: + return False + return not getattr(snapshot, "next", None) + + async def _run_until_pause( + self, + payload: Any, + *, + thread_id: str, + auto_resume: bool, + interrupt_predicate: Callable[[Interrupt], bool] | None = None, + event_consumer: Callable[[Dict[str, Any]], Awaitable[None]] | None = None, + ) -> tuple[list[Dict[str, Any]], Optional[Interrupt], bool, Any | None]: + """割り込みが発生するか完了するまで実行する。 + + Args: + payload (Any): グラフへ渡す入力ペイロード。 + thread_id (str): 実行対象のスレッドID。 + auto_resume (bool): 割り込みを自動承認するかどうか。 + interrupt_predicate (Callable[[Interrupt], bool] | None): 人手介入を許可する判定関数。 + event_consumer (Callable[[Dict[str, Any]], Awaitable[None]] | None): イベントを処理するコールバック。 + + Returns: + tuple[list[Dict[str, Any]], Optional[Interrupt], bool, Any | None]: + 収集したイベント、保留割り込み、完了フラグ、最新スナップショット。 + """ + + config = self._graph_config(thread_id) + current_payload: Any = payload + collected_events: list[Dict[str, Any]] = [] + + while True: + pending: Interrupt | None = None + async for event in self._graph.astream_events( + current_payload, config=config, version=STREAM_VERSION + ): + sanitized = self._sanitize_event(event) + if self._is_error_event(sanitized): + error_message = self._extract_error_message(sanitized) + sanitized.setdefault("level", "error") + sanitized.setdefault("message", error_message) + logger.error( + "Workflow error event detected [thread_id=%s, event=%s]: %s", + thread_id, + sanitized.get("event"), + error_message, + ) + if event_consumer: + await event_consumer(sanitized) + collected_events.append(sanitized) + pending = self._extract_interrupt(event) + if pending: + break + + snapshot = self._graph.get_state(config) + finished = self._is_run_finished(snapshot) + + if pending: + allowed = interrupt_predicate(pending) if interrupt_predicate else True + if auto_resume or not allowed: + auto_event = { + "event": "auto_resume", + "name": "human_judge", + "data": {"decision": "n", "thread_id": thread_id}, + } + collected_events.append(self._sanitize_event(auto_event)) + current_payload = Command(resume={pending.id: "n"}) + continue + return collected_events, pending, finished, snapshot + + return collected_events, None, finished, snapshot + + async def _astream( + self, + payload: Any, + *, + thread_id: str, + auto_resume: bool, + interrupt_predicate: Callable[[Interrupt], bool] | None = None, + ) -> AsyncGenerator[Any, None]: + """イベントストリームを生成する内部ジェネレーター。 + + Args: + payload (Any): グラフへ入力する初期ペイロード。 + thread_id (str): 実行対象のスレッドID。 + auto_resume (bool): 割り込みを自動で解消するかどうか。 + interrupt_predicate (Callable[[Interrupt], bool] | None): 割り込み許可判定関数。 + + Yields: + Any: 整形済みのイベントオブジェクト。 + """ + + config = self._graph_config(thread_id) + current_payload: Any = payload + + while True: + pending: Interrupt | None = None + async for event in self._graph.astream_events( + current_payload, config=config, version=STREAM_VERSION + ): + sanitized = self._sanitize_event(event) + if self._is_error_event(sanitized): + error_message = self._extract_error_message(sanitized) + sanitized.setdefault("level", "error") + sanitized.setdefault("message", error_message) + logger.error( + "Workflow error event detected during stream [thread_id=%s, event=%s]: %s", + thread_id, + sanitized.get("event"), + error_message, + ) + yield sanitized + pending = self._extract_interrupt(event) + if pending: + break + if not pending: + break + + allowed = interrupt_predicate(pending) if interrupt_predicate else True + if auto_resume or not allowed: + # カバレッジ計測が async ジェネレーター内の continue 分岐を正しく捕捉しないため除外する。 + yield { + "event": "auto_resume", + "name": "human_judge", + "data": {"decision": "n", "thread_id": thread_id}, + } # pragma: no cover + current_payload = Command(resume={pending.id: "n"}) # pragma: no cover + continue # pragma: no cover + + yield { + "event": "interrupt", + "name": pending.id, + "data": self._serialize_interrupt(pending), + } + return + + def _format_sse(self, event: Dict[str, Any]) -> str: + """イベント辞書をSSE文字列に整形する。 + + Args: + event (Dict[str, Any]): 整形対象のイベント。 + + Returns: + str: SSEフォーマットの文字列。 + """ + + payload = json.dumps(event, default=self._convert_model, ensure_ascii=False) + event_type = event.get("event", "message") + return f"event: {event_type}\ndata: {payload}\n\n" + + def _is_error_event(self, event: Dict[str, Any]) -> bool: + """イベントにエラーが含まれているか判定する。 + + Args: + event (Dict[str, Any]): 判定対象のイベント。 + + Returns: + bool: エラー要素を含む場合は ``True``。 + """ + + event_name = str(event.get("event") or "").lower() + if "error" in event_name: + return True + data = event.get("data") + if isinstance(data, dict): + if "error" in data: + return True + return any("error" in str(key).lower() for key in data.keys()) + return False + + def _extract_error_message(self, event: Dict[str, Any]) -> str: + """イベントからエラーメッセージを抽出する。 + + Args: + event (Dict[str, Any]): 抽出対象のイベント。 + + Returns: + str: 抽出されたメッセージ。候補が無い場合は汎用メッセージ。 + """ + + data = event.get("data") + if isinstance(data, dict): + candidate = ( + data.get("error") + or data.get("message") + or data.get("text") + or data.get("details") + ) + if candidate is not None: + if isinstance(candidate, str): + return candidate + try: + return json.dumps(candidate, ensure_ascii=False) + except (TypeError, ValueError): + return str(candidate) + name = event.get("event") or event.get("name") + if name: + return f"{name} が発生しました。" + return "LLM処理中にエラーが発生しました。" diff --git a/src/backend/tools/get_current_date.py b/src/backend/tools/get_current_date.py deleted file mode 100644 index b5ce7cf..0000000 --- a/src/backend/tools/get_current_date.py +++ /dev/null @@ -1,17 +0,0 @@ -import datetime - -from langchain_core.tools import tool - - -@tool() -def get_current_date(): - """ - 本日の日付を返すツールです。 - 返される情報は、本日の日時です。 - - Returns - ------- - str: - 本日の日付。提供形式は [yyyy-MM-dd] です。 - """ - return datetime.date.today() diff --git a/src/frontend/public/file.svg b/src/frontend/public/file.svg deleted file mode 100644 index 004145c..0000000 --- a/src/frontend/public/file.svg +++ /dev/null @@ -1 +0,0 @@ - \ No newline at end of file diff --git a/src/frontend/public/globe.svg b/src/frontend/public/globe.svg deleted file mode 100644 index 567f17b..0000000 --- a/src/frontend/public/globe.svg +++ /dev/null @@ -1 +0,0 @@ - \ No newline at end of file diff --git a/src/frontend/public/next.svg b/src/frontend/public/next.svg deleted file mode 100644 index 5174b28..0000000 --- a/src/frontend/public/next.svg +++ /dev/null @@ -1 +0,0 @@ - \ No newline at end of file diff --git a/src/frontend/public/vercel.svg b/src/frontend/public/vercel.svg deleted file mode 100644 index 7705396..0000000 --- a/src/frontend/public/vercel.svg +++ /dev/null @@ -1 +0,0 @@ - \ No newline at end of file diff --git a/src/frontend/public/window.svg b/src/frontend/public/window.svg deleted file mode 100644 index b2b2a44..0000000 --- a/src/frontend/public/window.svg +++ /dev/null @@ -1 +0,0 @@ - \ No newline at end of file diff --git a/src/frontend/src/app/components/ChatTranscript.tsx b/src/frontend/src/app/components/ChatTranscript.tsx index 9a977a4..ec9c2f7 100644 --- a/src/frontend/src/app/components/ChatTranscript.tsx +++ b/src/frontend/src/app/components/ChatTranscript.tsx @@ -1,47 +1,71 @@ +import ReactMarkdown from "react-markdown"; +import type { Components } from "react-markdown"; import type { ChatMessage } from "../types"; import { formatTimestamp } from "../utils/conversation"; interface ChatTranscriptProps { messages: ChatMessage[]; hideEmptyState?: boolean; + markdownComponents: Components; } -export function ChatTranscript({ messages, hideEmptyState = false }: ChatTranscriptProps) { +export function ChatTranscript({ + messages, + hideEmptyState = false, + markdownComponents, +}: ChatTranscriptProps) { if (messages.length === 0) { if (hideEmptyState) { return null; } return ( -
- ここにリサーチの進行ログが表示されます。 -
+

+ メッセージはまだありません。 +

); } return ( <> - {messages.map((message) => ( -
- {message.title ? ( -

- {message.title} + {messages.map((message, index) => { + const animationDelay = `${Math.min(index, 8) * 60}ms`; + const baseClasses = + message.role === "user" + ? "ml-auto border-emerald-400/50 bg-gradient-to-br from-emerald-500/20 via-emerald-400/10 to-transparent shadow-[0_24px_45px_-28px_rgba(16,185,129,0.7)]" + : message.role === "assistant" + ? "mr-auto border-slate-700/60 bg-slate-900/70" + : "mr-auto border-amber-500/50 bg-amber-500/10"; + + return ( +

+ {message.title ? ( +

+ {message.title} +

+ ) : null} +
+ {message.content} +
+ {message.reasoning ? ( +
+

+ LLMの思考 +

+
+ {message.reasoning} +
+
+ ) : null} +

+ {formatTimestamp(message.createdAt)}

- ) : null} -

{message.content}

-

- {formatTimestamp(message.createdAt)} -

-
- ))} +
+ ); + })} ); } diff --git a/src/frontend/src/app/components/ConversationHeader.tsx b/src/frontend/src/app/components/ConversationHeader.tsx index e24de71..a569369 100644 --- a/src/frontend/src/app/components/ConversationHeader.tsx +++ b/src/frontend/src/app/components/ConversationHeader.tsx @@ -3,7 +3,9 @@ interface ConversationHeaderProps { subtitle?: string | null; statusBadgeLabel?: string | null; statusBadgeClassName?: string | null; + statusBadgeTitle?: string | null; errorMessage?: string | null; + progressSteps?: { label: string; done: boolean }[] | null; } export function ConversationHeader({ @@ -11,24 +13,73 @@ export function ConversationHeader({ subtitle, statusBadgeLabel, statusBadgeClassName, + statusBadgeTitle, errorMessage, + progressSteps, }: ConversationHeaderProps) { return ( -
-
-
-

{title}

- {subtitle ?

{subtitle}

: null} +
+
+
+
+

+ {title} +

+ {subtitle ?

{subtitle}

: null} +
+ {statusBadgeLabel && statusBadgeClassName ? ( + + {statusBadgeLabel} + + ) : null}
- {statusBadgeLabel && statusBadgeClassName ? ( - {statusBadgeLabel} + {progressSteps && progressSteps.length > 0 ? ( +
    + {progressSteps.map((step, index) => ( +
  1. + + {step.done ? ( + + + + ) : ( + index + 1 + )} + + {step.label} +
  2. + ))} +
) : null}
- {errorMessage && ( -

+ {errorMessage ? ( +

{errorMessage}

- )} + ) : null} +
); } diff --git a/src/frontend/src/app/components/ConversationSidebar.tsx b/src/frontend/src/app/components/ConversationSidebar.tsx index 9552147..00462eb 100644 --- a/src/frontend/src/app/components/ConversationSidebar.tsx +++ b/src/frontend/src/app/components/ConversationSidebar.tsx @@ -1,5 +1,9 @@ +"use client"; + +import { useMemo } from "react"; import type { ConversationMeta } from "../types"; -import { formatTimestamp, statusClassName, statusLabel } from "../utils/conversation"; +import { formatTimestamp } from "../utils/conversation"; +import { mergeClassNames } from "../utils/chat-helpers"; type HealthStatus = "loading" | "ok" | "error"; @@ -9,68 +13,136 @@ interface ConversationSidebarProps { healthStatus: HealthStatus; onSelectThread: (threadId: string) => void; onCreateThread: () => void; + isMinimized?: boolean; + onToggleMinimize?: () => void; } -const healthStatusText: Record = { - loading: "接続を確認中", - ok: "バックエンドと接続済み", +const healthStatusText: Record = { + loading: null, + ok: null, error: "バックエンドに接続できません", }; -const healthIndicatorClassName: Record = { - loading: "bg-slate-500 animate-pulse", - ok: "bg-emerald-400", - error: "bg-rose-400", -}; - export function ConversationSidebar({ threadList, selectedThreadId, healthStatus, onSelectThread, onCreateThread, + isMinimized = false, + onToggleMinimize, }: ConversationSidebarProps) { + const statusText = healthStatusText[healthStatus]; + const initialsByThread = useMemo(() => { + return threadList.reduce>((acc, thread) => { + const trimmed = (thread.title ?? "").trim(); + acc[thread.id] = trimmed ? trimmed.slice(0, 2) : "--"; + return acc; + }, {}); + }, [threadList]); + return ( -