diff --git a/ARCHITECTURE.md b/ARCHITECTURE.md index e2d41764..53d35be7 100644 --- a/ARCHITECTURE.md +++ b/ARCHITECTURE.md @@ -819,7 +819,7 @@ description: 从 PDF 文件中提取文本和表格,填写表单。当用户 ### 8层架构分层 1. **外部实体层**:用户、管理员、OneBot 协议端 (NapCat/Lagrange.Core)、大模型 API 服务商 -2. **核心入口层**:main.py 启动入口、配置管理器 (config/loader.py)、热更新应用器 (config/hot_reload.py)、OneBotClient (onebot.py)、RequestContext (context.py) +2. **核心入口层**:main.py 启动入口、配置管理器 (config/loader.py)、热更新应用器 (config/hot_reload.py)、OneBotClient (onebot.py)、RequestContext (context.py)、Runtime API Server (api/app.py → api/routes/ 路由子模块) 3. **消息处理层**:MessageHandler (handlers.py)、SecurityService (security.py)、CommandDispatcher (services/command.py)、AICoordinator (ai_coordinator.py)、QueueManager (queue_manager.py)、Bilibili 自动提取 (bilibili/) 4. **AI 核心能力层**:AIClient (client.py)、PromptBuilder (prompts.py)、ModelRequester (llm.py)、ToolManager (tooling.py)、MultimodalAnalyzer (multimodal.py)、SummaryService (summaries.py)、TokenCounter (tokens.py) 5. **存储与上下文层**:MessageHistoryManager (utils/history.py, 10000条限制)、MemoryStorage (memory.py, 置顶备忘录, 500条上限)、EndSummaryStorage、CognitiveService + JobQueue + HistorianWorker + VectorStore + ProfileStorage、MemeService + MemeWorker + MemeStore + MemeVectorStore (表情包库)、FAQStorage、ScheduledTaskStorage、TokenUsageStorage (自动归档) diff --git a/CHANGELOG.md b/CHANGELOG.md index eb1da182..5db78b07 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,21 @@ +## v3.3.2 架构重构、假@检测与认知侧写增强 + +围绕核心架构进行了大规模重构与功能增强:Runtime API 拆分为路由子模块、配置系统模块化拆分、新增假@检测机制与 /profile 多输出模式。同步引入复读机制全面升级(可配置阈值与冷却)、消息预处理并行化、WebUI 多项交互功能,以及 arXiv 论文分析 Agent 和安全计算器工具。测试覆盖从约 800 提升至 1438+。 + +- 新增假@检测:群聊中 `@+Bot昵称` 的文本形式也被识别为@消息,自动获取昵称(防竞态),`@昵称 /命令` 可正常触发斜杠指令。 +- `/profile` 命令支持三种输出模式:`-f` 合并转发(默认)、`-r` 渲染为图片、`-t` 直接文本发送;超管可通过 `/p ` 和 `/p g <群号>` 跨目标查看。 +- 复读系统全面升级:触发阈值可配置(`repeat_threshold`)、Bot 发言不计入复读链、新增冷却机制(`repeat_cooldown_minutes`,?与 ? 等价)。 +- Vision 模型 `max_tokens` 可配置(`[models.vision].max_tokens`,默认 8192),解决 thinking 模型 token 耗尽导致工具调用截断。 +- 新增 arXiv 论文深度分析 Agent 与 `calculator` 安全计算器工具;新增消息历史限制可配置化(`[history].max_records`)。 +- Runtime API 拆分为 8 个路由子模块(`api/routes/`);配置系统拆为 `loader.py`、`models.py`、`hot_reload.py`。 +- 消息预处理并行化(`asyncio.gather`),认知史官改用 XML 格式并统一至 `utils/xml.py`。 +- WebUI:Cmd/Ctrl+K 命令面板、骨架屏、日志时间过滤、资源趋势图、TOML 原始视图、配置版本历史与回滚、长期记忆 CRUD、Modal 焦点陷阱。 +- WebUI 修复 top_k 溢出崩溃(四层防护)、表情包搜索防抖、仪表盘布局优化。 +- 修复 AI 模仿系统关键词自动回复前缀、LaTeX `\n` 替换破坏数学命令、复读冷却消息丢弃与内存泄漏、historian 热重载追踪缺失、GIF 多帧分析与临时文件泄漏、附件渲染异常等多项问题。 +- 修复 `/profile` 渲染留白和字体过小问题,使用 WebUI 配色并提高截断上限至 5000 字符。 + +--- + ## v3.3.1 /version 命令的添加 添加了 /version 命令以查看版本号和更改内容。 diff --git a/CLAUDE.md b/CLAUDE.md index f2e67a0b..d6994630 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -55,7 +55,7 @@ OneBot WebSocket → onebot.py → handlers.py → SecurityService(注入检测) |------|------| | `ai/` | AI 核心:client.py(主入口)、llm.py(模型请求)、prompts.py(Prompt构建)、tooling.py(工具管理)、multimodal.py(多模态)、model_selector.py(模型选择) | | `services/` | 运行服务:ai_coordinator.py(协调器+队列投递)、queue_manager.py(车站-列车队列)、command.py(命令分发)、model_pool.py(多模型池)、security.py(安全防护) | -| `skills/` | 热重载技能系统:tools/(原子工具)、toolsets/(11类工具集)、agents/(6个智能体)、commands/(斜杠指令)、anthropic_skills/(SKILL.md知识注入) | +| `skills/` | 热重载技能系统:tools/(原子工具)、toolsets/(11类工具集)、agents/(7个智能体)、commands/(斜杠指令)、anthropic_skills/(SKILL.md知识注入) | | `cognitive/` | 认知记忆:service.py(入口)、vector_store.py(ChromaDB)、historian.py(后台史官异步改写+侧写合并)、job_queue.py、profile_storage.py | | `memes/` | 表情包库:service.py(两阶段AI管线)、worker.py(异步处理队列)、store.py(SQLite)、vector_store.py(ChromaDB)、models.py | | `config/` | 配置系统:loader.py(TOML解析+类型化)、models.py(数据模型)、hot_reload.py(热更新) | @@ -80,7 +80,7 @@ OneBot WebSocket → onebot.py → handlers.py → SecurityService(注入检测) 车站-列车模型(QueueManager):按模型隔离队列组,4 级优先级(超管 > 私聊 > @提及 > 普通群聊),普通队列自动修剪保留最新 2 条,非阻塞按节奏发车(默认 1Hz)。 ### 存储与数据 -- `data/history/` — 消息历史(group_*.json / private_*.json,10000 条限制) +- `data/history/` — 消息历史(group_*.json / private_*.json,默认 10000 条,可通过 `[history]` 配置节调整,0=无限制) - `data/cognitive/` — ChromaDB 向量库 + profiles/ 侧写 + queues/ 任务队列 - `data/memes/` — 表情包库(blobs原图、previews预览图、memes.sqlite3元数据、chromadb向量检索) - `data/memory.json` — 置顶备忘录(500 条上限) diff --git a/README.md b/README.md index 48f231c5..0cc94f04 100644 --- a/README.md +++ b/README.md @@ -72,6 +72,7 @@ Undefined 的功能极为丰富,为了让本页面不过于臃肿,我们将各个模块的深入解析与高阶玩法整理成了专题游览图。这里是开启探索的钥匙: - ⚙️ **[安装与部署指南](docs/deployment.md)**:不管你是需要 `pip` 无脑一键安装,还是源码二次开发,这里的排坑指南应有尽有。 +- 🖥️ **[WebUI 使用指南](docs/webui-guide.md)**:管理控制台功能一览——配置编辑、日志查看、认知记忆管理、表情包库、AI 对话与系统监控。 - 🧭 **[Management API 与远程管理](docs/management-api.md)**:WebUI / App 共用的管理接口、认证、配置/日志/Bot 控制与引导探针说明。 - 🛠️ **[配置与热更新说明](docs/configuration.md)**:从模型切换到 MCP 库挂载,全方位掌握 `config.toml` 的高阶配置。 - 😶 **[表情包系统 (Memes)](docs/memes.md)**:查看表情包两阶段判定管线、统一图片 `uid` 发送机制、检索模式及库存管理说明。 diff --git a/apps/undefined-console/package-lock.json b/apps/undefined-console/package-lock.json index 4c32b005..cb667065 100644 --- a/apps/undefined-console/package-lock.json +++ b/apps/undefined-console/package-lock.json @@ -1,12 +1,12 @@ { "name": "undefined-console", - "version": "3.3.1", + "version": "3.3.2", "lockfileVersion": 3, "requires": true, "packages": { "": { "name": "undefined-console", - "version": "3.3.1", + "version": "3.3.2", "dependencies": { "@tauri-apps/api": "^2.3.0", "@tauri-apps/plugin-http": "^2.3.0" diff --git a/apps/undefined-console/package.json b/apps/undefined-console/package.json index 1a85c4c0..d7b6e213 100644 --- a/apps/undefined-console/package.json +++ b/apps/undefined-console/package.json @@ -1,7 +1,7 @@ { "name": "undefined-console", "private": true, - "version": "3.3.1", + "version": "3.3.2", "type": "module", "scripts": { "tauri": "tauri", diff --git a/apps/undefined-console/src-tauri/Cargo.lock b/apps/undefined-console/src-tauri/Cargo.lock index 190f9f8b..f9b61ac6 100644 --- a/apps/undefined-console/src-tauri/Cargo.lock +++ b/apps/undefined-console/src-tauri/Cargo.lock @@ -4063,7 +4063,7 @@ checksum = "562d481066bde0658276a35467c4af00bdc6ee726305698a55b86e61d7ad82bb" [[package]] name = "undefined_console" -version = "3.3.1" +version = "3.3.2" dependencies = [ "serde", "serde_json", diff --git a/apps/undefined-console/src-tauri/Cargo.toml b/apps/undefined-console/src-tauri/Cargo.toml index c3a9356e..3acc5288 100644 --- a/apps/undefined-console/src-tauri/Cargo.toml +++ b/apps/undefined-console/src-tauri/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "undefined_console" -version = "3.3.1" +version = "3.3.2" description = "Undefined cross-platform management console" authors = ["Undefined contributors"] license = "MIT" diff --git a/apps/undefined-console/src-tauri/tauri.conf.json b/apps/undefined-console/src-tauri/tauri.conf.json index a8a61f15..7fe51078 100644 --- a/apps/undefined-console/src-tauri/tauri.conf.json +++ b/apps/undefined-console/src-tauri/tauri.conf.json @@ -1,7 +1,7 @@ { "$schema": "https://schema.tauri.app/config/2", "productName": "Undefined Console", - "version": "3.3.1", + "version": "3.3.2", "identifier": "com.undefined.console", "build": { "beforeDevCommand": "npm run dev", diff --git a/config.toml.example b/config.toml.example index 84756ee2..4448a048 100644 --- a/config.toml.example +++ b/config.toml.example @@ -152,6 +152,9 @@ api_key = "" # zh: Vision 模型名称。 # en: Vision model name. model_name = "" +# zh: Vision 模型最大输出 tokens。启用 thinking 时建议设大(如 8192),确保思维链消耗后仍有余量输出工具调用。 +# en: Vision model max output tokens. When thinking is enabled, use a larger value (e.g. 8192) so there is still room for tool-call output after thinking. +max_tokens = 8192 # zh: 队列发车间隔(秒,0 表示立即发车)。 # en: Queue interval (seconds; 0 dispatches immediately). queue_interval_seconds = 1.0 @@ -424,6 +427,60 @@ prompt_cache_enabled = true # en: Extra request-body params (optional), e.g. temperature or vendor-specific fields. [models.historian.request_params] +# zh: 消息总结模型配置(/summary /sum 专用;未配置时回退到 agent 模型)。 +# en: Message-summary model config (used by /summary and /sum; falls back to the agent model when unset). +[models.summary] +# zh: OpenAI-compatible 基址 URL,例如 https://api.openai.com/v1(legacy "/chat/completions" 已弃用但仍兼容)。 +# en: OpenAI-compatible base URL, e.g. https://api.openai.com/v1. Note: legacy "/chat/completions" is deprecated but still supported. +api_url = "" +# zh: 消息总结模型 API Key。 +# en: Message-summary model API key. +api_key = "" +# zh: 消息总结模型名称。 +# en: Message-summary model name. +model_name = "" +# zh: 可选限制:最大生成 tokens。 +# en: Optional limit: max generation tokens. +max_tokens = 4096 +# zh: 队列发车间隔(秒,0 表示立即发车)。 +# en: Queue interval (seconds; 0 dispatches immediately). +queue_interval_seconds = 1.0 +# zh: API 模式:传统 chat.completions 或新版 responses。 +# en: API mode: classic chat.completions or the newer responses API. +api_mode = "chat_completions" +# zh: 是否启用 reasoning.effort。 +# en: Enable reasoning.effort. +reasoning_enabled = false +# zh: reasoning effort 档位。 +# en: reasoning effort level. +reasoning_effort = "medium" +# zh: 是否启用 thinking(思维链)。 +# en: Enable thinking (reasoning). +thinking_enabled = false +# zh: thinking 预算 tokens。 +# en: Thinking-budget tokens. +thinking_budget_tokens = 0 +# zh: 是否在请求中发送 budget_tokens(关闭后由提供商决定思维预算)。 +# en: Whether to include budget_tokens in the request (if disabled, the provider decides the thinking budget). +thinking_include_budget = true +# zh: reasoning effort 传参风格:openai(reasoning.effort)/ anthropic(output_config.effort)。 +# en: Reasoning effort wire format: openai (reasoning.effort) / anthropic (output_config.effort). +reasoning_effort_style = "openai" +# zh: 思维链工具调用兼容:启用后在多轮工具调用中回传 reasoning_content,避免部分模型返回 400。 +# en: Thinking tool-call compatibility: pass back reasoning_content in multi-turn tool calls to avoid 400 errors from some models. +thinking_tool_call_compat = true +# zh: Responses API 的 tool_choice 兼容模式:仅在关闭时请求仍返回 500、怀疑上游不兼容对象型 tool_choice 时再尝试开启;开启后上报为 "required" 并只保留目标工具。默认关闭。 +# en: Responses API tool_choice compatibility mode: only try enabling this when requests still return 500 with the default setting and you suspect the upstream does not support object-style tool_choice; it sends "required" and keeps only the selected tool. Disabled by default. +responses_tool_choice_compat = false +# zh: Responses API 续轮强制降级:启用后,多轮工具调用将始终跳过 previous_response_id,直接使用完整消息重放(stateless replay)。仅在上游不兼容 responses 状态续轮时使用。默认关闭。 +# en: Responses API force stateless replay: when enabled, multi-turn tool follow-ups always skip previous_response_id and replay the full message history instead. Use only when the upstream does not handle stateful responses follow-ups correctly. Disabled by default. +responses_force_stateless_replay = false +prompt_cache_enabled = true + +# zh: 额外请求体参数(可选),可用于 temperature 或供应商私有参数。 +# en: Extra request-body params (optional), e.g. temperature or vendor-specific fields. +[models.summary.request_params] + # zh: Grok 搜索模型配置(仅供 web_agent 内的 grok_search 使用;固定走 chat.completions,不支持 tool call 兼容字段)。 # en: Grok search model config (used only by grok_search inside web_agent; always uses chat completions and does not expose tool-call compatibility fields). [models.grok] @@ -656,16 +713,46 @@ pool_enabled = false # zh: 彩蛋提示发送模式。模式:"none"(关闭)/"agent"(主 AI 调用 Agent 时发送)/"tools"(主 AI 或 Agent 调用 Tool 时发送)/"clean"(过滤噪声;对自动预取的工具如 "get_current_time"、"send_message"、"end" 不予提示)/"all"(包括 Agent 内部调用其子工具即 "agent_tool" 的场景也发送)。默认:"none"。 # en: Easter-egg announcement mode. Modes: "none" (off) / "agent" (send when the main AI calls an Agent) / "tools" (send when the main AI or an Agent calls a Tool) / "clean" (filter noise; automatically prefetched tools such as "get_current_time", "send_message", and "end" are not announced) / "all" (also send when an Agent internally calls its sub-tools, i.e. "agent_tool"). Default: "none". agent_call_message_enabled = "none" -# zh: 是否启用群聊关键词(“心理委员”)自动回复。 +# zh: 是否启用群聊关键词("心理委员")自动回复。 # en: Enable keyword auto-replies("心理委员") in group chats. keyword_reply_enabled = false +# zh: 是否启用群聊复读功能(连续 N 条相同消息时复读,N 由 repeat_threshold 控制;若期间有 bot 自身发言则重置链)。 +# en: Enable repeat feature in group chats (repeat when N consecutive identical messages arrive, N set by repeat_threshold; resets if bot itself sent the same text in between). +repeat_enabled = false +# zh: 复读触发所需的连续相同消息条数(来自不同发送者),范围 2–20,默认 3。 +# en: Number of consecutive identical messages (from different senders) required to trigger repeat, range 2–20, default 3. +repeat_threshold = 3 +# zh: 复读冷却时间(分钟)。同一内容被复读后,在冷却时间内不再重复复读。0 = 无冷却。问号类消息(?/?)视为等价。 +# en: Repeat cooldown (minutes). After repeating the same content, won't repeat it again within this cooldown. 0 = no cooldown. Question marks (?/?) are treated as equivalent. +repeat_cooldown_minutes = 60 +# zh: 是否启用倒问号(复读触发时,若消息为问号则发送倒问号 ¿)。 +# en: Enable inverted question mark (when repeat triggers on "?" messages, send "¿" instead). +inverted_question_enabled = false # zh: 历史记录配置。 # en: History settings. [history] -# zh: 每个会话最多保留的消息条数。 -# en: Max messages to keep per conversation. +# zh: 每个会话最多保留的消息条数(0 = 无限制,注意内存占用)。 +# en: Max messages to keep per conversation (0 = unlimited, mind memory usage). max_records = 10000 +# zh: 工具过滤查询返回的最大消息条数。 +# en: Max messages returned by tool filtered queries. +filtered_result_limit = 200 +# zh: 工具过滤搜索时扫描的最大消息条数。 +# en: Max messages to scan when tools perform filtered searches. +search_scan_limit = 10000 +# zh: 总结 agent 单次拉取的最大消息条数。 +# en: Max messages the summary agent can fetch per call. +summary_fetch_limit = 1000 +# zh: 总结 agent 按时间范围查询时的最大拉取条数。 +# en: Max messages the summary agent fetches for time-range queries. +summary_time_fetch_limit = 5000 +# zh: OneBot API 回退获取的最大消息条数。 +# en: Max messages to fetch via OneBot API fallback. +onebot_fetch_limit = 10000 +# zh: 群分析工具的消息/成员返回上限。 +# en: Max messages/members returned by group analysis tools. +group_analysis_limit = 500 # zh: Skills 热重载配置(可选)。 # en: Skills hot reload settings (optional). @@ -710,8 +797,8 @@ grok_search_enabled = false # zh: 代理设置(可选)。 # en: Proxy settings (optional). [proxy] -# zh: 是否在 "web_agent" 中启用代理。 -# en: Enable proxy in "web_agent". +# zh: 是否使用代理。 +# en: Whether to use proxy. use_proxy = true # zh: 例如 http://127.0.0.1:7890(也可使用环境变量 "HTTP_PROXY")。 # en: e.g. http://127.0.0.1:7890 (or use the "HTTP_PROXY" environment variable). @@ -1095,6 +1182,12 @@ max_total_bytes = 5368709120 # zh: 是否允许 GIF 入库。 # en: Whether GIF files are allowed. allow_gif = true +# zh: GIF 分析模式:grid(多帧拼接为网格图)或 multi(多帧分开发送给模型)。 +# en: GIF analysis mode: grid (composite frames into grid) or multi (send frames separately). +gif_analysis_mode = "grid" +# zh: GIF 分析帧数(包括首末帧,均匀采样)。 +# en: Number of frames to extract for GIF analysis (including first/last, evenly sampled). +gif_analysis_frames = 6 # zh: 是否自动处理群聊图片。 # en: Auto-ingest group chat images. auto_ingest_group = true diff --git a/docs/build.md b/docs/build.md index 297a6995..17e7f3d1 100644 --- a/docs/build.md +++ b/docs/build.md @@ -28,6 +28,57 @@ uv sync --group dev -p 3.12 uv run playwright install ``` +### 系统级 LaTeX 环境(必装,用于 `render.render_latex`) + +`render.render_latex` 使用系统外部 LaTeX(`usetex=True`)渲染公式,**必须提前安装**,否则渲染会失败并返回错误。 + +**Debian / Ubuntu** + +```bash +sudo apt-get update +sudo apt-get install -y texlive-full dvipng ghostscript +``` + +**Arch Linux** + +```bash +sudo pacman -S --needed \ + texlive-basic \ + texlive-bin \ + texlive-latex \ + texlive-latexrecommended \ + texlive-latexextra \ + texlive-fontsrecommended \ + texlive-binextra \ + texlive-mathscience \ + ghostscript +``` + +**macOS** + +```bash +# 推荐 MacTeX(完整,约 4 GB) +brew install --cask mactex-no-gui + +# 或体积更小的 BasicTeX,之后按需补包 +brew install --cask basictex +sudo tlmgr update --self +sudo tlmgr install dvipng type1cm type1ec cm-super collection-fontsrecommended +``` + +**Windows** + +安装 [MiKTeX](https://miktex.org/download)(推荐,缺包时自动下载)或 [TeX Live](https://tug.org/texlive/windows.html)。安装完成后在 MiKTeX Console 里手动安装 `dvipng` 包,并确保 `latex.exe` 在 PATH 中。 + +**验证** + +```bash +latex --version +dvipng --version +``` + +若日志出现 `type1ec.sty not found` 或 `latex was not able to process`,TeX 包仍不完整:Debian / Ubuntu 已装 `texlive-full` 则无需额外操作;Arch 补装 `texlive-latexextra` `texlive-fontsrecommended` `texlive-binextra`;macOS BasicTeX 用户运行 `sudo tlmgr install cm-super`。 + ### Node.js / Rust / Tauri 如果需要构建跨平台控制台,请额外准备: diff --git a/docs/configuration.md b/docs/configuration.md index fc739c2d..7764a2db 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -426,6 +426,10 @@ Prompt caching 补充: |---|---:|---|---| | `agent_call_message_enabled` | `"none"` | 调用提示模式 | `none` / `agent` / `tools` / `all` / `clean` | | `keyword_reply_enabled` | `false` | 群聊关键词自动回复 | 布尔 | +| `repeat_enabled` | `false` | 群聊复读(连续 N 条相同消息时复读) | 布尔 | +| `repeat_threshold` | `3` | 触发复读所需的连续相同消息条数(来自不同发送者) | 整数,2–20 | +| `repeat_cooldown_minutes` | `60` | 复读冷却时间(分钟)。同一内容被复读后,在冷却期内不再重复复读。?和 ? 视为等价。0 = 无冷却 | 整数,≥ 0 | +| `inverted_question_enabled` | `false` | 倒问号(复读触发时若消息为问号则发送 ¿) | 布尔 | 兼容:历史字段 `[core].keyword_reply_enabled` 仍可读取,建议迁移到 `[easter_egg]`。 @@ -722,6 +726,8 @@ Prompt caching 补充: | `max_items` | `10000` | 表情包条目上限 | `<=0` 回退 `10000` | | `max_total_bytes` | `5368709120` | 表情包总磁盘占用上限(字节) | `<=0` 回退 `5368709120` | | `allow_gif` | `true` | 是否允许 GIF 入库 | | +| `gif_analysis_mode` | `"grid"` | GIF 动图判定分析模式:`"grid"` (网格拼图)、`"multi"` (多图逐帧)、`"first_frame"` (仅第一帧) | 非法值回退 `"grid"` | +| `gif_analysis_frames` | `6` | GIF 动图抽帧供模型识别的数量 | `<=0` 回退 `6` | | `auto_ingest_group` | `true` | 是否自动处理群聊图片 | | | `auto_ingest_private` | `true` | 是否自动处理私聊图片 | | | `keyword_top_k` | `30` | 关键词候选召回数 | `<=0` 回退 `30` | @@ -732,7 +738,8 @@ Prompt caching 补充: - 表情包入库走两阶段 LLM 管线: 1. 判定是否为表情包 2. 对通过判定的图片生成纯文本描述与标签 -- 第一阶段失败时,按“不是表情包”处理,直接丢弃。 +- 第一阶段失败时,按“不是表情包”处理,直接丢弃(如果是网络和服务器限流等异常,系统会在后台自动重试)。 +- 对于 GIF 格式图片的分析,`"grid"` 模式会将多个抽帧横向并排或拼接在一张大图中降低计费单元,`"multi"` 模式则将各帧作为独立图像输入至多模态大模型。 - 第二阶段不做 OCR;向量存储和检索文本只使用纯文本 `description + tags + aliases`。 - 同一图片内容在单进程内会按 `SHA256` 串行入库,避免并发表情包重复写入。 - 若入库在写入来源记录或向量索引阶段失败,会回滚已写入的元数据与本地文件,避免残留孤儿记录。 diff --git a/docs/deployment.md b/docs/deployment.md index 9e24783b..bc5eb1be 100644 --- a/docs/deployment.md +++ b/docs/deployment.md @@ -1,6 +1,6 @@ # 安装与部署指南 -提供 pip/uv tool 安装与源码部署两种方式:前者适合直接使用;后者适合深度自定义与二次开发。 +提供源码部署与 pip/uv tool 安装两种方式:**源码部署是推荐的首选方式**,功能完整且经过充分测试;pip/uv tool 安装适合快速体验,但部分功能支持尚不完善。 > Python 版本要求:`3.11`~`3.13`(包含)。 > @@ -8,7 +8,137 @@ --- -## pip/uv tool 部署(快速,适合默认行为) +## 源码部署(推荐) + +### 1. 克隆项目 + +由于项目中使用了 `NagaAgent` 作为子模块,请使用以下命令克隆项目: + +```bash +git clone --recursive https://github.com/69gg/Undefined.git +cd Undefined +``` + +如果已经克隆了项目但没有初始化子模块: + +```bash +git submodule update --init --recursive +``` + +### 2. 安装依赖 + +推荐使用 `uv` 进行现代化的 Python 依赖管理(速度极快): + +```bash +# 安装 uv (如果尚未安装) +pip install uv + +# 可选:预装一个兼容解释器(推荐 3.12) +# uv python install 3.12 + +# 同步依赖 +# uv 会根据 pyproject.toml 自动处理 3.11~3.13 的解释器选择 +uv sync +``` + +同时需要安装 Playwright 浏览器内核(用于网页浏览功能): + +```bash +uv run playwright install +``` + +### 3. 安装系统级依赖(必装) + +Bot 内置的数学公式等功能直接强依赖系统级的渲染环境,你**必须**提前在宿主机配置以下依赖。若是缺失该依赖,渲染图像或公式时后台将直接报错。 + +**安装 LaTeX 与工具链**: + +- **Ubuntu / Debian** + 直接无脑安装完整的 TeX Live 环境最为稳妥: + ```bash + sudo apt-get update + sudo apt-get install -y texlive-full dvipng ghostscript + ``` + +- **Arch Linux** + 通过 pacman 安装基础包: + ```bash + sudo pacman -S --needed texlive-basic texlive-bin texlive-latex texlive-latexrecommended texlive-latexextra texlive-fontsrecommended texlive-binextra texlive-mathscience ghostscript + ``` + +- **macOS** + 推荐通过 Homebrew 安装 MacTeX 环境,提供完整(省心,体积较大)或者精简两个版本: + ```bash + # 方式 1:完整环境(推荐) + brew install --cask mactex-no-gui + + # 方式 2:精简版(体积小,需手动拉取补包) + brew install --cask basictex + sudo tlmgr update --self + sudo tlmgr install dvipng type1cm type1ec cm-super collection-fontsrecommended + ``` + +- **Windows** + 安装 [MiKTeX](https://miktex.org/download) (推荐,能自动下载缺失宏包)或者 [TeX Live](https://tug.org/texlive/windows.html)。 + 1. 打开 MiKTeX Console。 + 2. 搜索 `dvipng` 手动将其安装上。 + 3. 确认环境变量 `PATH` 中已经包含了 `latex.exe`。 + +> 验证安装:使用 `latex --version` 与 `dvipng --version` 命令检测是否识别。如日志报错 `type1ec.sty not found` 或 `dvipng: command not found`,一般是由于所处的系统少安装了包或可执行文件不在环境变量中。 + +### 4. 配置环境 + +复制示例配置文件 `config.toml.example` 为 `config.toml` 并填写你的配置信息。 + +```bash +cp config.toml.example config.toml +``` + +#### 源码部署的自定义指南 + +- **自定义提示词/预置文案**:直接修改仓库根目录的 `res/`(例如 `res/prompts/`)。 +- **自定义图片资源**:修改 `img/` 下的对应文件(例如 `img/xlwy.jpg`)。 +- **优先级**:若你希望“运行目录覆盖优先”:在启动目录放置 `./res/...`,会优先于默认资源生效(便于一套安装,多套运行配置)。 + +### 5. 启动运行 + +启动方式(二选一): + +```bash +# 1) 直接启动机器人(无 WebUI) +uv run Undefined + +# 2) 启动 WebUI(在浏览器里编辑配置,并在 WebUI 内启停机器人) +uv run Undefined-webui +``` + +> **重要**:两种方式 **二选一即可**,不要同时运行。若你选择 `Undefined-webui`,请在 WebUI 中管理机器人进程的启停。 +> +> WebUI 功能详见 [WebUI 使用指南](webui-guide.md)。 + +### 6. 跨平台与资源路径(重要) + +- **资源读取**:运行时会优先从运行目录加载同名 `res/...` / `img/...`(便于覆盖),若不存在再使用安装包自带资源;并提供仓库结构兜底查找,因此从任意目录启动也能正常加载提示词与资源文案。 +- **并发写入**:运行时会为 JSON/日志类文件使用”锁文件 + 原子替换”写入策略,Windows/Linux/macOS 行为一致(会生成 `*.lock` 文件)。 + +### Management-first 推荐流程 + +推荐把 `Undefined-webui` 当作默认入口: + +1. 运行 `uv run Undefined-webui` +2. 在浏览器中打开管理控制台 +3. 若 `config.toml` 缺失,WebUI 会自动生成模板 +4. 在控制台中补齐配置、保存并校验 +5. 直接点击启动 Bot +6. 若需要远程管理,再使用桌面端或 Android App 连接到这个 Management API + +这样可以避免"先手写配置、再反复命令行重启"的冷启动成本,尤其适合首次部署与远程运维。 + +--- + +## pip/uv tool 部署(快速体验) + +> **注意**:pip/uv tool 安装方式的功能支持尚不如源码部署完善,也未经过充分测试。如遇问题,建议优先切换到源码部署。 适合只想“安装后直接跑”的场景,`Undefined`/`Undefined-webui` 命令会作为可执行入口安装到你的环境中。 @@ -28,6 +158,8 @@ uv tool install Undefined-bot uv tool run --from Undefined-bot playwright install ``` +> **系统依赖提醒**:同源码部署要求一致,你必须在宿主机上预先安装所需的 LaTeX/dvipng 渲染环境。请参考上文 [3. 安装系统级依赖(必装)](#3-安装系统级依赖必装) 查阅你操作系统的对应安装命令,未配置前若触发公式与 Markdown 的图片渲染则会报错执行失败。 + 安装完成后,在任意目录准备 `config.toml` 并启动: ```bash @@ -46,20 +178,7 @@ Undefined-webui > - 选择 `Undefined-webui`:启动后访问 WebUI(默认 `http://127.0.0.1:8787`,密码默认 `changeme`;**首次启动必须修改默认密码,默认密码不可登录**;可在 `config.toml` 的 `[webui]` 中修改),在 WebUI 中在线编辑/校验配置,并通过 WebUI 启动/停止机器人进程。 > `Undefined-webui` 会在检测到当前目录缺少 `config.toml` 时,自动从 `config.toml.example` 生成一份,便于直接在 WebUI 中修改。 -> 提示:资源文件已随包发布,支持在非项目根目录启动;如需自定义内容,请参考下方说明。 - -## Management-first 推荐流程 - -推荐把 `Undefined-webui` 当作默认入口: - -1. 运行 `Undefined-webui` 或 `uv run Undefined-webui` -2. 在浏览器中打开管理控制台 -3. 若 `config.toml` 缺失,WebUI 会自动生成模板 -4. 在控制台中补齐配置、保存并校验 -5. 直接点击启动 Bot -6. 若需要远程管理,再使用桌面端或 Android App 连接到这个 Management API - -这样可以避免“先手写配置、再反复命令行重启”的冷启动成本,尤其适合首次部署与远程运维。 +> 提示:资源文件已随包发布,支持在非项目根目录启动;如需自定义内容,请参考上方源码部署的自定义指南。 ### 完整日志(排查用) @@ -90,7 +209,7 @@ mkdir -p res/prompts # 然后把你想改的提示词放到对应路径(文件名与目录层级保持一致) ``` -如果你希望直接修改“默认提示词/默认文案”(而不是每个运行目录做覆盖),推荐使用下面的“源码部署”,在仓库里修改 `res/` 后运行;不建议直接修改已安装环境的 `site-packages/res`(升级会被覆盖)。 +如果你希望直接修改“默认提示词/默认文案”(而不是每个运行目录做覆盖),推荐使用上面的“源码部署”,在仓库里修改 `res/` 后运行;不建议直接修改已安装环境的 `site-packages/res`(升级会被覆盖)。 如果你不知道安装包内默认提示词文件在哪,可以用下面方式打印路径(用于复制一份出来改): @@ -106,80 +225,6 @@ python -c "from Undefined.utils.resources import read_text_resource; print(len(r --- -## 源码部署(推荐开发/高定使用) - -### 1. 克隆项目 - -由于项目中使用了 `NagaAgent` 作为子模块,请使用以下命令克隆项目: - -```bash -git clone --recursive https://github.com/69gg/Undefined.git -cd Undefined -``` - -如果已经克隆了项目但没有初始化子模块: - -```bash -git submodule update --init --recursive -``` - -### 2. 安装依赖 - -推荐使用 `uv` 进行现代化的 Python 依赖管理(速度极快): - -```bash -# 安装 uv (如果尚未安装) -pip install uv - -# 可选:预装一个兼容解释器(推荐 3.12) -# uv python install 3.12 - -# 同步依赖 -# uv 会根据 pyproject.toml 自动处理 3.11~3.13 的解释器选择 -uv sync -``` - -同时需要安装 Playwright 浏览器内核(用于网页浏览功能): - -```bash -uv run playwright install -``` - -### 3. 配置环境 - -复制示例配置文件 `config.toml.example` 为 `config.toml` 并填写你的配置信息。 - -```bash -cp config.toml.example config.toml -``` - -#### 源码部署的自定义指南 - -- **自定义提示词/预置文案**:直接修改仓库根目录的 `res/`(例如 `res/prompts/`)。 -- **自定义图片资源**:修改 `img/` 下的对应文件(例如 `img/xlwy.jpg`)。 -- **优先级**:若你希望“运行目录覆盖优先”:在启动目录放置 `./res/...`,会优先于默认资源生效(便于一套安装,多套运行配置)。 - -### 4. 启动运行 - -启动方式(二选一): - -```bash -# 1) 直接启动机器人(无 WebUI) -uv run Undefined - -# 2) 启动 WebUI(在浏览器里编辑配置,并在 WebUI 内启停机器人) -uv run Undefined-webui -``` - -> **重要**:两种方式 **二选一即可**,不要同时运行。若你选择 `Undefined-webui`,请在 WebUI 中管理机器人进程的启停。 - -### 5. 跨平台与资源路径(重要) - -- **资源读取**:运行时会优先从运行目录加载同名 `res/...` / `img/...`(便于覆盖),若不存在再使用安装包自带资源;并提供仓库结构兜底查找,因此从任意目录启动也能正常加载提示词与资源文案。 -- **并发写入**:运行时会为 JSON/日志类文件使用”锁文件 + 原子替换”写入策略,Windows/Linux/macOS 行为一致(会生成 `*.lock` 文件)。 - ---- - ## NapCat / Lagrange.Core 部署要求 **NapCat(或 Lagrange.Core)必须与 Bot 进程共享同一文件系统,不能将 NapCat 单独放在无法访问 Bot 数据目录的 Docker 容器内。** diff --git a/docs/development.md b/docs/development.md index 6dab8180..522c1f65 100644 --- a/docs/development.md +++ b/docs/development.md @@ -22,8 +22,14 @@ src/Undefined/ │ ├── agents/ # 智能体 (独立自主的子 AI,负责处理诸如 Web 搜索、文件分析的具体长时任务) │ ├── commands/ # 中心化斜杠指令系统 (实现如 /help, /stats, /addadmin 等平台功能) │ └── anthropic_skills/# Anthropic 协议集成的外部 Skills (兼容 SKILL.md 格式) +├── config/ # 配置系统 (loader.py TOML 解析, models.py 数据模型, hot_reload.py 热更新) +├── api/ # Management API + Runtime API +│ ├── routes/ # 路由子模块 (chat, tools, naga, system, memes, memory, cognitive, health) +│ ├── app.py # aiohttp 服务主入口 (薄包装委派到 routes/) +│ └── _openapi.py # OpenAPI 文档生成 +├── memes/ # 表情包库 (两阶段 AI 管线, SQLite + ChromaDB) ├── services/ # 核心运行服务 (Queue 任务队列, Command 命令分发, Security 安全防护拦截) -├── utils/ # 通用支持工具组 (包含历史处理、JSON原子读写加锁 IO 操作等) +├── utils/ # 通用支持工具组 (io.py 异步原子读写, history.py, coerce.py 类型强转, fake_at.py 假@检测) ├── handlers.py # 最外层 OneBot 消息分流处理层 └── onebot.py # OneBot WebSocket 客户端核心连接 ``` diff --git a/docs/management-api.md b/docs/management-api.md index 4daec184..346ae86b 100644 --- a/docs/management-api.md +++ b/docs/management-api.md @@ -25,6 +25,8 @@ uv run Undefined-webui 4. 直接在控制台启动 Bot 5. 如需远程管理,再让桌面端或 Android App 连接同一个 Management API 地址 +> WebUI 各页面的功能和操作详见 [WebUI 使用指南](webui-guide.md)。 + ## 2. 鉴权模型 Management API 兼容两套鉴权: diff --git a/docs/memes.md b/docs/memes.md index 29e993d4..e0f93e86 100644 --- a/docs/memes.md +++ b/docs/memes.md @@ -11,12 +11,14 @@ Undefined 平台自 3.3.0 版本起内置了强大的**全局表情包库**功 2. **第一阶段 - 属性判定 (Judge)**: 提交给视觉模型(通过 `judge_meme_image.txt` 提示词)分析图片本质。如果图片只是普通的自拍、系统截图或者无法表现梗(Meme)的内容,流程将在此终止。 + 对于 GIF 动图,系统可根据配置(网格拼接或多张多帧)进行抽帧重组以提供更连贯的视觉上下文。若在此交互期间遇到暂时的网络错误或接口报错,处理管线会自动重试,确保判定过程的高可用性。 3. **第二阶段 - 语义解析 (Describe)**: 对于被判定为表情包的图片,模型会进一步(通过 `describe_meme_image.txt` 提示词)提取: - 图片的关键视觉元素与构图。 - 隐喻、情感与适合的回复语境。 - 高质量的搜索标签(Tags)。 + 同样,该阶段依然享有自动重试逻辑保护,从而保障长流程分析和描述内容的成功入库。 4. **向量化与持久化**: 提取出的结构化文本与标签被存入 SQLite (`MemeStore`),并通过嵌入模型向量化后存入 ChromaDB (`MemeVectorStore`)。原图及其生成的预览图(如 GIF 抽帧)持久化存放至数据目录。 @@ -41,6 +43,8 @@ Undefined 平台自 3.3.0 版本起内置了强大的**全局表情包库**功 [memes] enabled = true # 是否启用 query_default_mode = "hybrid" # 默认搜索策略:keyword / semantic / hybrid +gif_analysis_mode = "grid" # GIF 的多帧识别模式:grid(网格拼接)、multi(多张散图)、first_frame(仅首帧) +gif_analysis_frames = 6 # GIF 的抽帧数量 ``` 更多细节请查阅 [配置文档](configuration.md#425-memes-表情包库)。 diff --git a/docs/slash-commands.md b/docs/slash-commands.md index 9be0f8e0..f0baa17e 100644 --- a/docs/slash-commands.md +++ b/docs/slash-commands.md @@ -84,7 +84,63 @@ Undefined 提供了一套强大的斜杠指令(Slash Commands)系统。管 /changelog latest ``` -#### 2. 统计与分析服务 +#### 2. 消息总结与侧写查看 + +- **/profile [group] [-f|-r|-t] [目标ID]** + - **说明**:查看用户或群聊的认知侧写。侧写由系统根据聊天历史自动生成和更新。 + - **别名**:`/me`、`/p` + - **参数**: + + | 参数 | 是否必填 | 说明 | + |------|----------|------| + | `group` / `g` | 可选 | 查看群聊侧写(仅群聊可用) | + | `-f` / `--forward` | 可选 | 合并转发模式输出(默认) | + | `-r` / `--render` | 可选 | 渲染为图片发送 | + | `-t` / `--text` | 可选 | 直接文本消息发送 | + | `` | 可选 | 🔒 超管专用:查看指定用户的侧写 | + | `g <群号>` | 可选 | 🔒 超管专用:查看指定群聊的侧写 | + + - **行为**: + - **私聊**:查看自己的用户侧写,不支持 `group` 参数。 + - **群聊**:不带参数查看自己的用户侧写,带 `group` / `g` 查看当前群聊侧写。 + - **超管指定目标**:超级管理员可传入 QQ 号或群号查看任意用户/群的侧写,非超管使用时提示无权限。 + - **输出模式**:默认合并转发;`-r` 渲染为图片;`-t` 直接文本发送。 + - **限流**:普通用户 60 秒,管理员 10 秒,超管无限制。 + - **示例**: + ``` + /profile → 查看自己的侧写(合并转发) + /p -r → 查看自己的侧写(渲染图片) + /p -t → 查看自己的侧写(直接文本) + /me → 同上(别名) + /profile group → 查看当前群聊的侧写 + /p g → 同上 + /p 123456 → 🔒 超管:查看QQ号123456的侧写 + /p g 789012 → 🔒 超管:查看群号789012的侧写 + /p 123456 -r → 🔒 超管:查看指定用户侧写(渲染图片) + ``` + +- **/summary [条数|时间范围] [自定义描述]** + - **说明**:调用消息总结 Agent,拉取指定范围的聊天消息并进行智能总结。 + - **别名**:`/sum` + - **参数**: + + | 参数 | 是否必填 | 说明 | + |------|----------|------| + | `条数` | 可选 | 纯数字,表示总结最近 N 条消息(默认 50,最大 500) | + | `时间范围` | 可选 | 格式如 `1h`、`6h`、`1d`、`7d`,与条数互斥 | + | `自定义描述` | 可选 | 总结的重点方向,如"技术讨论"、"项目进展" | + + - **限流**:普通用户 120 秒,管理员 30 秒,超管无限制。 + - **示例**: + ``` + /summary → 总结最近 50 条消息 + /summary 100 → 总结最近 100 条消息 + /summary 1d → 总结过去 1 天的消息 + /summary 50 技术讨论 → 总结最近 50 条,重点关注技术讨论 + /sum 1d 项目进展 → 总结过去 1 天,重点关注项目进展 + ``` + +#### 3. 统计与分析服务 - **/stats [时间范围] [--ai]** - **说明**:生成过去一段时间内 Token 的使用统计数据、模型消耗排行、输入输出比例,并输出可视化图表。默认不启用 AI 分析,显式传 `--ai`(或 `-a`)才会触发。 - **参数**: @@ -116,7 +172,7 @@ Undefined 提供了一套强大的斜杠指令(Slash Commands)系统。管 /stats 30d --ai → 最近 30 天并启用 AI 分析 ``` -#### 3. 权限管理 (动态 Admin) +#### 4. 权限管理 (动态 Admin) 通过指令动态管理管理员列表,变更会自动持久化到 `config.local.json`,无需重启。超管(Superadmin)拥有最高权限,由配置文件的 `core.super_admins` 静态定义。 - **/lsadmin** @@ -151,7 +207,7 @@ Undefined 提供了一套强大的斜杠指令(Slash Commands)系统。管 - 若目标本身不是管理员,返回"不是管理员"提示。 - **示例**:`/rmadmin 123456789` -#### 4. 本地群级 FAQ 系统 +#### 5. 本地群级 FAQ 系统 用于对常见问题(FAQ)进行检索和管理。FAQ 不必每次请求 AI 大模型,极大地节省 Token 并加快响应。 - **/lsfaq** @@ -192,7 +248,7 @@ Undefined 提供了一套强大的斜杠指令(Slash Commands)系统。管 - **边界行为**:若 ID 不存在,返回"FAQ 不存在"提示。 - **示例**:`/delfaq 20241205-001` -#### 5. 排障与反馈 +#### 6. 排障与反馈 - **/bugfix \ [QQ号2...] \<开始时间\> \<结束时间\>** - **说明**:从群历史记录中抓取指定用户在指定时间段内的消息(包含文字、图片的 OCR 描述),交给 AI 进行分析并生成 Bug 修复报告,结果自动存入 FAQ 库。 - **参数**: @@ -214,7 +270,7 @@ Undefined 提供了一套强大的斜杠指令(Slash Commands)系统。管 /bugfix 111111 222222 2024/12/01/09:00 2024/12/01/18:00 ``` -#### 6. Naga 集成管理 +#### 7. Naga 集成管理 > **⚠️ 此功能面向与 NagaAgent 对接的高级场景,普通用户不建议开启。** 需要在 `config.toml` 中同时启用 `[api].enabled`、`[features].nagaagent_mode_enabled` 和 `[naga].enabled`。 @@ -347,6 +403,8 @@ async def execute(args: list[str], context: CommandContext) -> None: | `ctx.bot_qq` | `int` | 当前机器人的自身 QQ 号 | | `ctx.ai` | `AIClient` | 主 AI Client,可以用于进行分析、总结等大模型调用 | | `ctx.faq_storage` | `FAQStorage` | FAQ 的键值操作入口 | +| `ctx.cognitive_service` | `Any \| None` | 认知侧写服务,可调用 `get_profile(entity_type, entity_id)` | +| `ctx.history_manager` | `Any \| None` | 消息历史管理器,可调用 `get_recent(chat_id, msg_type, start, end)` | ### 3. 可用的 `permission` (权限级别) diff --git a/docs/usage.md b/docs/usage.md index 8cb2e96b..66dbc307 100644 --- a/docs/usage.md +++ b/docs/usage.md @@ -1,64 +1,390 @@ # 使用与功能说明 -## 开始使用 +本文档对 Undefined 的功能模块进行系统性介绍。完成[部署配置](deployment.md)并成功与 QQ 端建立连接后,即可通过自然语言或结构化指令使用以下全部能力。 -1. 启动 OneBot 协议端(如 NapCat 或 Lagrange.Core)并登录 QQ。 -2. 配置好 `config.toml` 并[启动 Undefined](deployment.md)。 -3. 连接成功后,机器人即可在群聊或私聊中响应。 +--- + +## 目录 + +1. [基础交互方式](#1-基础交互方式) +2. [认知记忆系统](#2-认知记忆系统) +3. [内置智能体 (Agents)](#3-内置智能体-agents) +4. [工具集能力一览 (Toolsets & Tools)](#4-工具集能力一览-toolsets--tools) +5. [定时任务与调度](#5-定时任务与调度) +6. [FAQ 知识库管理](#6-faq-知识库管理) +7. [内置斜杠指令参考](#7-内置斜杠指令参考) +8. [多模型池(私聊模型切换)](#8-多模型池私聊模型切换) +9. [WebUI 与跨平台管理](#9-webui-与跨平台管理) + +--- + +## 1. 基础交互方式 + +### 私聊场景 +在私聊会话中,可以直接向 Bot 发送任意消息,无需附加任何前缀或格式要求。系统会自动维护当前对话的完整上下文。 + +### 群聊场景 +在群聊环境中,Bot 默认仅响应以下方式触发的消息: +- **@提及**:在消息中 `@Bot` 并附带指令内容。 +- **指令前缀**:使用 `config.toml` 中配置的前缀(如有)。 + +> **队列优先级说明**:系统底层采用四级消息队列调度模型,优先级从高到低为:超级管理员 > 私聊 > @提及 > 普通群聊。在群聊高并发场景下,管理请求和直接提及将优先得到响应。 + +--- + +## 2. 认知记忆系统 + +Undefined 搭载了基于 ChromaDB 向量数据库的后台认知系统,无需手动录入,即可实现跨会话的长期上下文追踪。 + +| 能力 | 说明 | +|---|---| +| **聊天侧写(Profile)** | 系统实时静默分析对话内容,自动提取并持久化用户的偏好、待办、身份与观点等信息,在后续对话中作为参考背景 | +| **历史事件检索** | 基于向量语义检索,支持按用户、群组、时间段查询历史记忆,并应用时间衰减加权排序 | +| **群聊宏观总结** | 可对历史消息进行语义召回与整合,快速梳理出大量消息中的重点内容 | + +**示例:** +> *"请回忆一下我们上周讨论过的项目规划内容。"* +> *"请总结一下本群过去三天内讨论的主要话题。"* + +--- + +## 3. 内置智能体 (Agents) + +智能体(Agent)是由独立大模型驱动的高自治任务处理器。主 AI 在理解到任务超出自身直接能力范围时,会自动将任务委托给相应的专业 Agent,由其递归调用子工具完成任务后汇报结果。 + +### `web_agent` — 网络信息检索助手 + +负责网页搜索和网页内容爬取,能够获取互联网上的实时最新信息。 + +**子工具**:`grok_search`(Grok 搜索)、`web_search`(通用搜索)、`crawl_webpage`(网页内容提取) + +**示例:** +> *"请搜索最近三天关于 DeepSeek 的最新动态并生成摘要。"* +> *"帮我爬取这个网页的主要内容并整理成结构化笔记。"* + +--- + +### `file_analysis_agent` — 文件分析助手 + +支持对代码、PDF、Word、Excel 等多种格式文件进行解析与分析。用户只需将文件发送至对话中即可。 + +**子工具**:`analyze_pdf`、`analyze_docx`、`analyze_xlsx`、`analyze_code`、`read_file` + +**示例:** +> *"请分析这份 PDF 文档,提取其中第三章的核心数据。"* +> *"请检查这份 Python 代码,找出其中潜在的性能瓶颈。"* + +--- + +### `info_agent` — 信息查询助手 + +整合了多种公开信息查询能力,覆盖天气、热搜、域名、哔哩哔哩以及学术论文等信息源。 + +**子工具**:`weather_query`、`*hot`(热搜榜)、`whois`(域名查询)、`bilibili_search`、`bilibili_user_info`、`arxiv_search` + +**示例:** +> *"北京明天的天气怎么样?"* +> *"查一下今天的微博热搜前十名。"* +> *"帮我查询 arxiv 上关于 Chain-of-Thought 的最新论文。"* +> *"查一下 B 站 UP 主 xxx 的近期投稿情况。"* + +--- + +### `entertainment_agent` — 娱乐助手 + +提供运势、小说、随机图片和随机视频等休闲娱乐类功能。 + +**子工具**:`horoscope`(星座运势)、`novel_search`(小说检索)、`ai_draw_one`(AI 绘图)、`video_random_recommend`(随机视频推荐) + +**示例:** +> *"查一下天蝎座今天的运势。"* +> *"随机推荐几个有趣的视频。"* +> *"帮我画一张赛博朋克风格的城市夜景。"* + +--- + +### `code_delivery_agent` — 代码分析与交付助手 + +支持沙盒级别的代码代写、本地执行验证与自动打包。测试通过后,代码成果会自动打包为 `.zip` 文件并通过 QQ 发送给用户。 + +**示例:** +> *"请使用 Python 编写一个 HTTP 测速脚本,监听 8080 端口,验证跑通后将整个项目打包发到这个群。"* + +--- + +### `naga_code_analysis_agent` — NagaAgent 代码分析助手 + +专门用于深度分析 NagaAgent 框架及本项目的源代码结构。 + +**子工具**:`read_file`、`search_code`、`analyze_structure` + +--- + +## 4. 工具集能力一览 (Toolsets & Tools) + +除了通过 Agent 按需调用外,以下工具在对话中均可以通过自然语言直接触发。 + +### 渲染 (`render.*`) + +| 工具 | 说明 | +|---|---| +| `render.render_markdown` | 将 Markdown 文本(含表格、代码块、标题等)渲染为图片发送 | +| `render.render_latex` | 将 LaTeX 数学公式渲染为图片(**依赖系统 TeX 环境**,需提前安装,详见[部署文档](deployment.md#3-安装系统级依赖必装)) | +| `render.render_html` | 将 HTML 内容渲染为图片 | + +支持 `embed`(嵌入回复)和 `send`(直接发送)两种图片交付方式。 + +**示例:** +> *"请把这段数学公式渲染成图片发给我:$E=mc^2$"* +> *"请把下面这份 Markdown 表格渲染成图片。"* --- -## Agent 能力展示 +### 表情包 (`memes.*`) -机器人通过自然语言理解用户意图,自动调度相应的专业 Agent,具有高度独立和自动化的能力: +| 工具 | 说明 | +|---|---| +| `memes.search_memes` | 支持 `keyword`(关键词精确匹配)、`semantic`(语义联想检索)、`hybrid`(混合模式)三种检索方式 | +| `memes.send_meme_by_uid` | 根据图片统一 uid 以独立消息发送原图表情包 | -* **网络搜索提取**:"搜索一下 DeepSeek 的最新动态" -* **多模态文件分析**:"总结一下群里长图说了什么","帮我提取这份 PDF 中的数据" -* **表情包检索与发送**:在轻松聊天场景中,AI 可使用 `memes.search_memes` 按关键词检索、按语义检索,或混合检索表情包,再按统一图片 `uid` 独立发送 -* **B站视频解析**:发送 B 站链接/BV 号自动下载发送 1080p 视频,或指令 AI "下载这个 B 站视频 BV1xx411c7mD" -* **代码分析与交付**:"用 Python 写一个 HTTP 服务器,监听 8080 端口,返回 Hello World,验证通过后打包发到这个群" (交由 Code Delivery Agent) -* **定时任务管理**:"每天早上 8 点提醒我看新闻" -* **向未来的自己发指令**:"明天早上 9 点提醒你自己先总结今天群里的待办,再把前三项发给我" +两者通常配合使用:先由 `search_memes` 检索到目标表情包的 uid,再由 `send_meme_by_uid` 独立发送原图。 -### 定时任务进阶:调用未来的自己 +**示例:** +> *"请根据现在的群聊气氛,发一个应景的表情包。"* -定时任务除了调用普通工具外,还支持 `self_instruction` 模式。你可以把一段自然语言指令留给未来触发时刻的 AI 自己执行。 +--- -示例意图: -- “每周一 09:00,先回顾上周群聊重点,再提醒本周计划” -- “今天晚上 23:30,帮我生成明天的复盘提纲” +### 消息操作 (`messages.*`) -实现上由 `scheduler.create_schedule_task` / `scheduler.update_schedule_task` 的 `self_instruction` 参数承载(与 `tool_name`/`tools` 三选一)。 +| 工具 | 说明 | +|---|---| +| `messages.send_message` | 向当前会话发送消息 | +| `messages.send_private_message` | 向指定用户发送私聊消息 | +| `messages.get_recent_messages` | 获取最近若干条历史消息 | +| `messages.get_messages_by_time` | 按时间范围检索历史消息 | +| `messages.react_message_emoji` | 对指定消息添加表情回应 | +| `messages.send_poke` | 发送戳一戳 | +| `messages.send_text_file` | 将文本内容生成文件后发送 | +| `messages.send_url_file` | 下载指定 URL 的文件后发送 | +| `messages.send_group_sign` | 执行群签到操作 | +| `messages.get_forward_msg` | 获取合并转发消息的内容 | --- -## 斜杠指令 +### 群组信息查询 (`group.*`) + +| 工具 | 说明 | +|---|---| +| `group.get_member_list` | 获取群成员列表 | +| `group.get_member_info` | 查询指定成员的详细信息 | +| `group.find_member` | 按昵称/备注搜索群成员 | +| `group.get_member_title` | 获取成员群头衔 | +| `group.get_honor_info` | 查询群荣誉(龙王、话唠等) | +| `group.get_member_activity` | 分析群成员活跃度(支持 member_list / history / hybrid 三种数据源模式) | +| `group.rank_members` | 对群成员进行多维度排名 | +| `group.filter_members` | 按条件过滤群成员 | +| `group.detect_inactive_risk` | 检测长期潜水有流失风险的成员 | +| `group.activity_trend` | 分析群活跃度趋势变化 | +| `group.level_distribution` | 统计群成员等级分布 | +| `group.get_files` | 获取群文件列表 | + +**示例:** +> *"帮我查一下这个群里近 30 天没说过话的成员有哪些。"* +> *"请列出本群最近发言最多的前 10 名成员。"* + +--- -> 💡 **进阶玩法**:想了解每个命令的具体使用参数,或者学习如何通过写几行代码**自定义属于你的独家斜杠指令**?请前往 [命令系统与斜杠指令配置指南](slash-commands.md)。 +### 群聊深度分析 (`group_analysis.*`) -在群聊或私聊中可使用以下指令。除明确说明外,管理类命令需要具备被设置的超级管理员或管理员权限: +| 工具 | 说明 | +|---|---| +| `group_analysis.analyze_member_messages` | 深度分析指定成员的消息数量、类型分布和活跃时段 | +| `group_analysis.analyze_join_statistics` | 统计群成员加入趋势与留存情况 | +| `group_analysis.analyze_new_member_activity` | 分析新成员加入后的活跃度变化 | + +--- + +### 认知记忆查询 (`cognitive.*`) + +| 工具 | 说明 | +|---|---| +| `cognitive.search_events` | 按关键词语义检索历史记忆事件,支持用户、群组、时间段过滤 | +| `cognitive.get_profile` | 获取指定用户的认知侧写画像 | +| `cognitive.search_profiles` | 跨用户语义搜索侧写信息 | + +--- + +### 置顶备忘录 (`memory.*`) + +用于管理 AI 的自我约束事项和高优先级待办。此备忘录会在每轮对话时被固定注入上下文(上限 500 条),优先级高于认知记忆。 + +| 工具 | 说明 | +|---|---| +| `memory.add` | 添加一条置顶备忘(如"用户要求以后用英文回复") | +| `memory.update` | 更新指定备忘内容 | +| `memory.delete` | 删除指定备忘 | +| `memory.list` | 列出当前所有置顶备忘 | +| `memory.query_archive` | 查询已归档的历史备忘 | +| `memory.search_summaries` | 语义搜索历史备忘 | + +> **注意**:用户偏好、身份等长期用户事实请通过对话让 AI 记入**认知记忆**(`cognitive.*`),而非此处。置顶备忘专用于 AI 自身的行为约束与短期高优待办。 + +--- + +### 知识库检索 (`knowledge_*`) + +如果管理员在 `config.toml` 中配置了知识库,AI 可通过以下工具检索其中的内容: + +| 工具 | 说明 | +|---|---| +| `knowledge_semantic_search` | 基于向量语义检索(支持重排序与相关度过滤) | +| `knowledge_text_search` | 基于关键词的精确文本检索 | +| `knowledge_list` | 列出当前可用的知识库 | + +--- + +### 通讯录查询 (`contacts.*`) + +| 工具 | 说明 | +|---|---| +| `contacts.query_friends` | 查询 Bot 的好友列表 | +| `contacts.query_groups` | 查询 Bot 所在的群列表 | + +--- + +### 独立原子工具 + +| 工具 | 说明 | +|---|---| +| `get_current_time` | 获取当前系统时间,支持公历、农历、黄历等多种格式输出 | +| `get_picture` | 获取指定类型的图片(二次元、壁纸、白丝、黑丝、JK、历史上的今天等 10 余种类别) | +| `qq_like` | 给指定 QQ 号的资料卡点赞(默认 10 次) | +| `python_interpreter` | 在隔离的 **Docker 容器**中执行 Python 代码,支持按需安装第三方库,可在执行后自动发送生成的文件(图片、CSV 等) | +| `bilibili_video` | 下载并发送哔哩哔哩视频(支持 BV 号、链接) | +| `arxiv_paper` | 下载并发送 arXiv 论文 PDF(支持 arXiv ID、链接) | +| `fetch_image_uid` | 将指定 URL 的图片下载并转换为系统内部 uid | +| `task_progress` | 向用户发送长任务的阶段性进度通知 | +| `changelog_query` | 查询系统内置版本更新日志 | + +**示例:** +> *"请下载 arXiv 论文 2501.01234 并发到这个群。"* +> *"请在 Docker 里安装 matplotlib 后绘制一张正弦函数图像并发给我。"* +> *"帮我给 QQ 号 123456 点 10 个赞。"* + +--- + +## 5. 定时任务与调度 + +调度器基于标准 crontab 语法,支持三种执行模式,适用于从简单报时到复杂 AI 自主任务的全部场景。 + +### 执行模式 + +| 模式 | 描述 | 配置字段 | +|---|---|---| +| **单工具模式** | 定时调用一个指定的工具,传入固定参数 | `tool_name` + `tool_args` | +| **多工具串/并行模式** | 定时依次(serial)或同时(parallel)调用多个工具 | `tools` + `execution_mode` | +| **AI 自我督办模式** | 在触发时刻,以一段自然语言指令唤醒 AI 自主完成任务 | `self_instruction` | + +### 自我督办模式示例 + +这是调度器最灵活的功能:您可以通过自然语言预约将任意复杂的指令投递给"未来的 AI 自己"来执行。 + +> *"每天上午 9:00,请回顾昨日遗留的待办事项,并把最重要的前三项通过私聊发给我。"* +> *"每周一 08:30,请总结上周群内的高频讨论话题,生成一份周报并发送至群聊。"* +> *"明天晚上 23:00,帮我生成今天的话痨统计图表发到本群。"*(仅执行一次:设置 `max_executions: 1`) + +### 任务管理工具 + +| 工具 | 说明 | +|---|---| +| `scheduler.create_schedule_task` | 创建定时任务,支持 `max_executions`(达到次数后自动删除) | +| `scheduler.update_schedule_task` | 修改任务的触发规则、执行内容或参数 | +| `scheduler.delete_schedule_task` | 删除指定定时任务 | +| `scheduler.list_schedule_tasks` | 列出当前所有定时任务及其运行状态 | + +--- + +## 6. FAQ 知识库管理 + +Bot 支持在运行时维护一个结构化的群专属 FAQ 知识库,可通过斜杠指令进行增删查操作。 + +| 指令 | 权限 | 说明 | +|---|---|---| +| `/lsfaq` | 公开 | 列出当前群的全部 FAQ 条目 | +| `/viewfaq ` | 公开 | 查看指定 FAQ 的详细内容 | +| `/searchfaq <关键词>` | 公开 | 按关键词搜索匹配的 FAQ | +| `/delfaq ` | 管理员 | 删除指定 ID 的 FAQ 条目 | + +--- + +## 7. 内置斜杠指令参考 + +所有斜杠指令均以 `/` 开头,在群聊或私聊中直接输入即可触发。下表基于代码实际配置整理: + +| 指令 | 别名 | 权限 | 私聊 | 说明 | +|---|---|---|---|---| +| `/help [命令名]` | — | 公开 | ✅ | 显示命令列表;附带命令名时展示该命令的详细帮助文档 | +| `/version` | `/v` | 公开 | ✅ | 查看当前版本号及最新版本变更标题 | +| `/changelog [子命令]` | `/cl` | 公开 | ✅ | 查看版本更新日志(详见下方说明) | +| `/copyright` | `/about` `/license` `/cprt` | 公开 | ✅ | 查看版权信息与 MIT 许可证声明 | +| `/stats [天数] [--ai]` | — | 公开 | ✅ | 查看 Token 使用统计图表;附加 `--ai` 启用 AI 智能分析报告 | +| `/lsfaq` | — | 公开 | ❌ | 列出当前群的全部 FAQ | +| `/viewfaq ` | — | 公开 | ❌ | 查看指定 FAQ 详情 | +| `/searchfaq <关键词>` | — | 公开 | ❌ | 按关键词搜索 FAQ | +| `/delfaq ` | — | 管理员 | ❌ | 删除指定 FAQ | +| `/bugfix [起止时间]` | — | 管理员 | ❌ | 基于目标用户近期发言生成娱乐性 Bug 修复报告 | +| `/lsadmin` | — | 管理员 | ✅ | 查看系统当前的超管与管理员列表 | +| `/naga ` | — | 公开 | ✅ | 绑定或解绑关联的 NagaAgent 实例 | +| `/addadmin ` | — | **超级管理员** | ✅ | 将指定用户提权为普通管理员 | +| `/rmadmin ` | — | **超级管理员** | ✅ | 撤销指定用户的管理员权限 | + +### `/changelog` 子命令详解 -```bash -/help # 查看帮助菜单 -/changelog # 查看最近版本历史(公开命令) -/changelog show v3.2.6 # 查看指定版本详情(公开命令) -/lsadmin # 查看当前所有的系统管理员列表 -/addadmin # 添加新的普通管理员(仅限超级管理员使用) -/rmadmin # 移除某位普通管理员 -/bugfix # 根据最近用户在群里的聊天上下文生成该用户的 Bug 修复报告 (幽默搞笑用) -/stats [时间范围] [--ai] # 核心统计功能:获取 Token 使用统计 + 成本计算;加 --ai 才启用智能分析 ``` +/changelog # 列出最近 8 个版本(版本号 + 标题) +/changelog list <数量> # 列出更多版本,最大 20 条 +/changelog latest # 展示最新一个版本的完整变更详情 +/changelog show <版本号> # 展示指定版本的完整详情(带或不带 v 均可) +/changelog <版本号> # 等同于 show +``` + +### `/stats` 说明 -### 关于 `/changelog` 的详细说明: +- 默认统计最近 **7 天**的数据,可传入天数参数(允许范围:1 ~ 365 天)。 +- 默认仅生成统计图表与数字摘要,**不触发** AI 智能分析。 +- 附加 `--ai`(或 `-a`)时,向 AI 发起分析请求;若分析超时,系统会先返回图表与摘要并附带超时提示。 +- 普通用户频率限制为每 3600 秒一次;管理员与超级管理员无限制。 -- `/changelog` 默认列最近 8 个版本,按新到旧展示 `版本号 + 标题`。 -- `/changelog list 12` 可查看更多版本,最大 20 条。 -- `/changelog show <版本号>` 会展示单个版本的标题、摘要和变更点,版本号支持带或不带 `v`。 -- `/changelog latest` 会直接展示 `CHANGELOG.md` 中最新一条版本记录。 -- 版本内容直接来自仓库内维护的 `CHANGELOG.md`,不是运行时临时扫描 git tag。 +### 扩展自定义指令 -### 关于 `/stats` 的详细说明: +系统支持热插拔机制,创建对应目录结构并保存文件即刻生效,无需重启服务。详细的开发步骤与参数说明请参阅 [《命令系统与斜杠指令》](slash-commands.md)。 + +--- + +## 8. 多模型池(私聊模型切换) + +在 `config.toml` 中全局开启 `[features] pool_enabled = true` 后,Bot 支持在多个配置的大模型之间进行灵活调度: + +- **自动轮换**:配置 `strategy = "round_robin"` 或 `"random"` 后,私聊请求会自动按策略在池中模型之间切换。 +- **手动指定**:在私聊中,可通过发送"选 1"、"选 2"等指令来手动锁定本次使用的模型。 + +> 群聊场景始终使用主模型,不参与多模型池调度。 + +完整配置方式及 Agent 模型池说明请参阅 [《多模型池功能》](multi-model.md)。 + +--- + +## 9. WebUI 与跨平台管理 + +Undefined 提供了一套完整的可视化管理控制台,无需修改配置文件或重启服务即可对系统进行动态管理: + +- 实时切换底层驱动的大模型(如 GPT-4o、Claude 3.5 Sonnet 等)。 +- 在线编辑系统 Prompt 与人格设定面板。 +- 监控并干预运行时任务队列与内存状态。 +- 查看完整的 Token 消耗统计与调用日志。 + +WebUI 通过浏览器访问(默认地址 `http://127.0.0.1:8787`,默认密码 `changeme`,**首次启动必须在 `config.toml` 的 `[webui]` 中修改默认密码**)。如需通过手机或其他设备进行远程管理,可使用配套的多端控制台 App,详见 [《跨平台控制台 App》](app.md)。 + +--- -- 默认统计最近 7 天的数据,时间参数范围会自动被系统钳制在 1 天 - 365 天之间。 -- 默认只发送图表与基本摘要,不会触发 AI 智能分析。 -- 仅在显式传入 `--ai`(或 `-a`)时才会请求 AI 分析;若分析超时,系统会先发图表与摘要并附超时提示。 +*如需查阅各模块的底层设计原理与 API 集成说明,请参阅本目录下的其余技术文档。* diff --git a/docs/webui-guide.md b/docs/webui-guide.md new file mode 100644 index 00000000..9a77ff1d --- /dev/null +++ b/docs/webui-guide.md @@ -0,0 +1,181 @@ +# WebUI 使用指南 + +WebUI 是 Undefined 的主要管理入口,提供配置编辑、日志查看、认知记忆管理、表情包库、AI 对话和系统监控等一站式功能。即使 `config.toml` 尚未创建,也可以通过 WebUI 补全配置并启动 Bot。 + +--- + +## 快速开始 + +### 启动 + +```bash +uv run Undefined-webui +``` + +默认监听 `http://127.0.0.1:8787`。相关配置位于 `config.toml` 的 `[webui]` 节: + +```toml +[webui] +url = "127.0.0.1" # 监听地址 +port = 8787 # 端口 +password = "changeme" # 密码(必须在首次登录时修改) +``` + +> 如需远程访问,将 `url` 改为 `0.0.0.0` 或实际 IP。 + +### 首次登录 + +1. 浏览器打开 `http://127.0.0.1:8787` +2. 输入初始密码 `changeme` +3. 系统强制要求修改默认密码后才能进入管理界面 + +修改后的密码会写入 `config.toml`。桌面端 / Android 客户端也使用同一密码连接。 + +--- + +## 功能概览 + +WebUI 共有 8 个主要页签(Tab),下面逐一介绍。 + +### 概览(Overview) + +仪表盘页面,展示系统运行状态: + +| 指标 | 说明 | +|------|------| +| Bot 状态 | 运行中 / 已停止 | +| 系统运行时间 | 自启动以来的持续时间 | +| CPU 使用率 | 实时百分比 | +| 内存占用 | 已用 / 总量 / 百分比 | +| 运行环境 | CPU 型号、操作系统、Python 版本 | +| 资源趋势图 | CPU / 内存随时间的变化曲线 | + +页面会自动刷新,也可手动触发。 + +### 配置管理(Config) + +提供两种编辑方式: + +- **表单模式**:按分组展示所有配置项,带字段说明和类型校验,适合日常修改。 +- **TOML 原始编辑器**:直接编辑 `config.toml` 文本,带语法高亮,适合批量变更。 + +其他能力: + +- **验证**:保存前自动进行 TOML 语法检查和严格配置校验,错误项会标红提示。 +- **配置历史**:每次保存自动生成带时间戳的备份(最多 50 版本),可随时回滚到任意历史版本。 +- **模板同步**:一键将 `config.toml.example` 中新增的配置项合并到当前配置,不覆盖已有值。 + +> 配置支持热更新——大多数配置项修改后即时生效,无需重启 Bot。需要重启的项(如 `onebot_ws_url`、`webui_port`)会在保存时提示。 + +### 日志查看(Logs) + +- 支持 **Bot 日志** 和 **WebUI 日志** 切换。 +- **实时流式推送**(SSE):日志实时滚动到最新。 +- 可暂停 / 恢复流式推送,调整显示行数(1–2000)。 +- 左侧文件列表展示所有日志文件(含归档),可选择查看历史日志。 +- 支持下载日志文件。 + +### 探针与诊断(Probes) + +三类探针帮助排查问题: + +- **内部探针**:版本号、Python 版本、平台、运行时间、OneBot 连接状态及 WebSocket 地址。 +- **外部探针**:Runtime API 的可用端点和能力列表。 +- **引导探针**:检查 `config.toml` 是否存在、TOML 语法是否合法、配置值是否有效,并给出修复建议。 + +### 认知记忆(Memory) + +分为三个子面板,对应 [认知记忆系统](cognitive-memory.md) 的不同层次: + +**认知事件搜索** + +输入关键词进行语义搜索,查看 AI 提取并存储的用户 / 群聊事实记录。支持调整返回条数和排序方式。 + +**认知侧写查看** + +- 按关键词搜索侧写。 +- 按 QQ 号或群号精确查看对应的完整侧写内容。 + +**长期记忆管理** + +AI 的置顶备忘录(自我约束、待办事项等),支持完整 CRUD: + +- 新建记忆条目 +- 编辑现有内容 +- 删除条目 +- 关键词搜索 + +### 表情包库(Memes) + +管理 [全局表情包库](memes.md) 的 Web 界面: + +- **浏览与搜索**:分页列表展示,支持关键词搜索和语义搜索(以及混合模式)。输入时自动防抖搜索,Enter 立即触发。 +- **筛选与排序**:按启用 / 禁用、静态 / 动态、置顶等条件筛选,按创建或更新时间排序。 +- **详情与操作**:查看元数据(UID、描述、标签)和预览图;支持编辑描述 / 标签、启用 / 禁用、置顶 / 取消、删除。 +- **重分析 / 重索引**:对单张表情包重新触发 AI 描述生成或搜索索引更新。 +- **统计概览**:总数、启用 / 禁用数、静态 / 动态数等。 + +### AI 对话(Chat) + +WebUI 内置的对话界面,直接与 Bot 的 AI 进行交互: + +- 支持文本和图片消息。 +- AI 回复支持 Markdown 渲染。 +- 消息历史分页浏览。 +- 发出的消息会经过与 QQ 侧相同的处理流程(安全检查、工具调用等)。 + +### 关于(About) + +显示当前版本号和 MIT 许可证文本。 + +--- + +## Bot 控制 + +WebUI 首页(Landing Page)提供 Bot 的启停控制: + +- **启动 Bot**:点击启动按钮,Bot 进程在后台运行。 +- **停止 Bot**:安全停止当前 Bot 进程。 +- **状态指示**:实时显示 Bot 运行状态。 + +首页还会检测是否有可用更新(基于 Git),并提供更新 + 重启功能。 + +--- + +## 键盘快捷键 + +| 快捷键 | 功能 | +|--------|------| +| `Cmd/Ctrl + K` | 打开命令面板,可快速跳转到任意页签或执行操作 | + +--- + +## 远程访问 + +WebUI 和桌面端 / Android 客户端共享同一 Management API: + +1. 将 `[webui].url` 设为 `0.0.0.0`(或你的 LAN/公网 IP)。 +2. 确保防火墙放行 `[webui].port`(默认 8787)。 +3. 桌面端 / Android 客户端输入 `http://:8787` 和密码即可连接。 + +如果启用了 Runtime API(`[api].enabled = true`),WebUI 会自动代理 Runtime API 的功能(探针、记忆查询、AI Chat 等),无需单独暴露 Runtime API 端口。 + +--- + +## 常见问题 + +**Q: 忘记密码怎么办?** + +直接编辑 `config.toml` 中的 `[webui].password` 字段,重启 WebUI 即可。 + +**Q: 配置保存后 Bot 没反应?** + +大多数配置项支持热更新。但少数关键配置(如 WebSocket 地址、WebUI 端口、API 端口)需要重启才能生效,保存时会有提示。 + +**Q: 日志不滚动了?** + +检查是否意外暂停了日志流。点击日志页面的播放按钮恢复实时推送。 + +**Q: 探针显示 Runtime API 不可达?** + +确认 `[api].enabled = true` 且 Bot 正在运行。Runtime API 由 Bot 主进程提供,Bot 未启动时自然不可达。 diff --git a/pyproject.toml b/pyproject.toml index 9128920b..6dcc6a32 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "Undefined-bot" -version = "3.3.1" +version = "3.3.2" description = "QQ bot platform with cognitive memory architecture and multi-agent Skills, via OneBot V11." readme = "README.md" authors = [ diff --git a/res/prompts/describe_meme_image.txt b/res/prompts/describe_meme_image.txt index af24d2c2..e223e567 100644 --- a/res/prompts/describe_meme_image.txt +++ b/res/prompts/describe_meme_image.txt @@ -56,6 +56,7 @@ - 不要逐字抄录图中文字。 - 不要输出 Markdown,不要输出额外解释。 - 你必须且只能调用 `submit_meme_description`。 +- 如果收到的是一张网格图(多帧拼接)或多张图片,说明原图是动图/GIF:描述应涵盖动图的动态变化过程和整体语义,不要只描述单帧;可以在 tags 里加上 `动图` 标签。 好的例子: - `description`: `猫猫无语翻白眼反应图` diff --git a/res/prompts/historian_rewrite.md b/res/prompts/historian_rewrite.md index 72a8a2a7..ac706d29 100644 --- a/res/prompts/historian_rewrite.md +++ b/res/prompts/historian_rewrite.md @@ -35,7 +35,7 @@ observations: {observations} 当前消息原文(触发本轮): {source_message} -最近消息参考(用于消歧,不要求逐字复述): +最近消息参考(XML 格式,与主对话一致,消息间以 `---` 分隔,用于消歧,不要求逐字复述): {recent_messages} 必须通过 `submit_rewrite` 工具提交结果,禁止输出普通文本内容。 diff --git a/res/prompts/judge_meme_image.txt b/res/prompts/judge_meme_image.txt index e22cf642..fdb480b1 100644 --- a/res/prompts/judge_meme_image.txt +++ b/res/prompts/judge_meme_image.txt @@ -22,3 +22,4 @@ - `is_meme` 仅表示“适不适合放进聊天表情包库”,不是“图里有没有梗” - `reason` 用一句简短中文说明依据 - 只有当“整张图整体上就是一张可直接发送的表情包”时,才能给 `is_meme=true` +- 如果收到的是一张网格图(多帧拼接)或多张图片,说明原图是动图/GIF:请综合所有帧判断这个动图整体是否适合作为表情包,不要只看单帧 \ No newline at end of file diff --git a/res/prompts/undefined.xml b/res/prompts/undefined.xml index 2e9b52e5..b8e503a5 100644 --- a/res/prompts/undefined.xml +++ b/res/prompts/undefined.xml @@ -141,12 +141,16 @@ - 除非 `memes.search_memes` 没找到合适结果,或表情包会干扰信息传递,否则不要把本来适合发图的反应先写成一句话来代替发图 - 表情包相关规则只决定“怎么回复”,不单独构成“该不该回复”的参与许可;是否回复仍以前面的回复触发逻辑为准 - 默认不要把表情包和正文写进同一条消息;需要补一句解释时,优先分成两条消息发送 - - 如果上下文或工具结果给了图片 UID(例如 `pic_ab12cd34`),你可以在 `send_message.message` 里直接插入 `` - - `` 是唯一允许的内嵌图片语法;不要改成 Markdown 图片、HTML ``、代码块或自然语言描述 - - 可以图文混排,例如:`我给你介绍一下`\n``\n`如图所示` - - 表情包库返回的图片 UID 也可以直接用于 ``;当前会话临时图片和表情包库图片共用同一套 `uid` 语义 + - 推荐使用统一标签 `` 引用任何附件(图片或文件),系统根据 UID 前缀自动处理: + - `pic_*` UID → 内嵌为图片(等效于旧 `` 语法) + - `file_*` UID → 作为独立文件消息在文字之后发出 + - `` 语法仍然可用且仅限图片 UID(向后兼容) + - `` 是推荐的统一语法,适用于所有类型的附件 + - 可以图文混排,例如:`我给你介绍一下`\n``\n`如图所示` + - 文件附件在文字消息发出后作为独立文件消息依次发送,不会混排在文字中 + - 表情包库返回的图片 UID 也可以直接用于 `` - 只能引用工具结果或上下文里明确给出的图片 UID,禁止臆造 UID - - 只有 `pic_*` 这类图片 UID 能放进 ``;普通文件 UID 不能放进去 + - 不要把 `file_*` UID 放进 `` 标签(会报类型错误) diff --git a/res/prompts/undefined_nagaagent.xml b/res/prompts/undefined_nagaagent.xml index 4bbebbf8..205b99f5 100644 --- a/res/prompts/undefined_nagaagent.xml +++ b/res/prompts/undefined_nagaagent.xml @@ -141,12 +141,16 @@ - 除非 `memes.search_memes` 没找到合适结果,或表情包会干扰信息传递,否则不要把本来适合发图的反应先写成一句话来代替发图 - 表情包相关规则只决定“怎么回复”,不单独构成“该不该回复”的参与许可;是否回复仍以前面的回复触发逻辑为准 - 默认不要把表情包和正文写进同一条消息;需要补一句解释时,优先分成两条消息发送 - - 如果上下文或工具结果给了图片 UID(例如 `pic_ab12cd34`),你可以在 `send_message.message` 里直接插入 `` - - `` 是唯一允许的内嵌图片语法;不要改成 Markdown 图片、HTML ``、代码块或自然语言描述 - - 可以图文混排,例如:`我给你介绍一下`\n``\n`如图所示` - - 表情包库返回的图片 UID 也可以直接用于 ``;当前会话临时图片和表情包库图片共用同一套 `uid` 语义 + - 推荐使用统一标签 `` 引用任何附件(图片或文件),系统根据 UID 前缀自动处理: + - `pic_*` UID → 内嵌为图片(等效于旧 `` 语法) + - `file_*` UID → 作为独立文件消息在文字之后发出 + - `` 语法仍然可用且仅限图片 UID(向后兼容) + - `` 是推荐的统一语法,适用于所有类型的附件 + - 可以图文混排,例如:`我给你介绍一下`\n``\n`如图所示` + - 文件附件在文字消息发出后作为独立文件消息依次发送,不会混排在文字中 + - 表情包库返回的图片 UID 也可以直接用于 `` - 只能引用工具结果或上下文里明确给出的图片 UID,禁止臆造 UID - - 只有 `pic_*` 这类图片 UID 能放进 ``;普通文件 UID 不能放进去 + - 不要把 `file_*` UID 放进 `` 标签(会报类型错误) @@ -301,8 +305,8 @@ 明确的 NagaAgent 技术问题或讨论 - 直接调用 naga_code_analysis_agent,确认相关性后再回复 - 如果只是泛泛提到naga但不是技术讨论,不要回复 + **必须**先调用 naga_code_analysis_agent 获取信息,再基于返回结果回复 + 如果只是泛泛提到naga但不是技术讨论,不要回复;但只要涉及技术细节,一定要先调 agent @@ -433,9 +437,20 @@ - 对于任何涉及 NagaAgent 的技术问题,直接调用 naga_code_analysis_agent 处理。 + 对于任何涉及 NagaAgent 的技术问题,**必须先调用 naga_code_analysis_agent 获取准确信息后再回复**。 + 不要依赖自身记忆或猜测来回答 NagaAgent 相关问题——该项目代码频繁更新,只有通过 agent 实时查阅才能保证准确。 该 Agent 内部拥有自己的工具集(read_naga_intro、read_file、search_file_content 等), 这些内部工具你无法直接调用,你只需要调用 naga_code_analysis_agent 即可。 + + + 以下场景必须调用 naga_code_analysis_agent: + - 用户询问 NagaAgent 的功能、配置、部署、构建方式 + - 用户遇到 NagaAgent 相关的报错或问题 + - 用户想了解 NagaAgent 的架构、代码逻辑、技能系统等 + - 用户提到 NagaAgent 的任何技术细节(API、openclaw、干员、技能等) + - 讨论涉及 NagaAgent 与其他系统的集成或对比 + 只有纯闲聊式提及(如"naga好用吗"这类不需要技术细节的对话)才可以不调用。 + diff --git a/scripts/sync_config_template.py b/scripts/sync_config_template.py index e4320012..583b9eb5 100755 --- a/scripts/sync_config_template.py +++ b/scripts/sync_config_template.py @@ -96,6 +96,11 @@ def main() -> int: for path in result.added_paths: print(f" + {path}") + if result.updated_comment_paths: + print(f"[sync-config] 注释更新数量: {len(result.updated_comment_paths)}") + for path in result.updated_comment_paths: + print(f" ~ {path}") + if result.removed_paths: print(f"[sync-config] 多余路径数量: {len(result.removed_paths)}") for path in result.removed_paths: diff --git a/src/Undefined/__init__.py b/src/Undefined/__init__.py index 5722e020..c2c8768d 100644 --- a/src/Undefined/__init__.py +++ b/src/Undefined/__init__.py @@ -1,3 +1,3 @@ """Undefined - A high-performance, highly scalable QQ group and private chat robot based on a self-developed architecture.""" -__version__ = "3.3.1" +__version__ = "3.3.2" diff --git a/src/Undefined/ai/multimodal.py b/src/Undefined/ai/multimodal.py index 94288827..e4dda36f 100644 --- a/src/Undefined/ai/multimodal.py +++ b/src/Undefined/ai/multimodal.py @@ -16,6 +16,7 @@ import httpx from Undefined.ai.parsing import extract_choices_content +from Undefined.utils.coerce import safe_float from Undefined.ai.llm import ModelRequester from Undefined.config import VisionModelConfig from Undefined.ai.transports import API_MODE_CHAT_COMPLETIONS, get_api_mode @@ -352,19 +353,12 @@ def _parse_meme_analysis_response(content: str) -> dict[str, Any]: parsed = _extract_json_object(content) return { "is_meme": bool(parsed.get("is_meme", False)), - "confidence": _safe_float(parsed.get("confidence", 0.0), default=0.0), + "confidence": safe_float(parsed.get("confidence", 0.0), default=0.0), "description": str(parsed.get("description") or "").strip(), "tags": _normalize_meme_tags(parsed.get("tags")), } -def _safe_float(value: Any, default: float = 0.0) -> float: - try: - return float(value) - except (TypeError, ValueError): - return default - - class MultimodalAnalyzer: """多模态媒体分析器。 @@ -607,13 +601,13 @@ async def _prune_url_cache_locks( self._url_cache_locks.pop(key, None) async def _build_content_items( - self, media_type: str, media_content: str, prompt: str + self, media_type: str, media_content: str | list[str], prompt: str ) -> list[dict[str, Any]]: """构建请求内容项。 Args: media_type: 媒体类型 - media_content: 媒体内容(URL 或 data URL) + media_content: 媒体内容(URL/data URL),或其列表 prompt: 提示词 Returns: @@ -623,9 +617,9 @@ async def _build_content_items( # 添加媒体内容项 media_item_key = f"{media_type}_url" - content_items.append( - {"type": media_item_key, media_item_key: {"url": media_content}} - ) + contents = media_content if isinstance(media_content, list) else [media_content] + for mc in contents: + content_items.append({"type": media_item_key, media_item_key: {"url": mc}}) return content_items @@ -702,7 +696,7 @@ async def analyze( result = await self._requester.request( model_config=self._vision_config, messages=[{"role": "user", "content": content_items}], - max_tokens=8192, + max_tokens=self._vision_config.max_tokens, call_type=f"vision_{detected_type}", ) content = extract_choices_content(result) @@ -822,13 +816,19 @@ async def _request_required_tool_args( self, *, prompt_path: str, - image_url: str, + image_url: str | list[str], tool_schema: dict[str, Any], tool_name: str, call_type: str, max_tokens: int, ) -> dict[str, Any]: - media_content = await self._load_media_content(image_url, "image") + if isinstance(image_url, list): + media_contents: list[str] = [] + for url in image_url: + media_contents.append(await self._load_media_content(url, "image")) + media_content: str | list[str] = media_contents + else: + media_content = await self._load_media_content(image_url, "image") prompt = await self._load_prompt_text(prompt_path) content_items = await self._build_content_items("image", media_content, prompt) response = await self._requester.request( @@ -847,11 +847,13 @@ async def _request_required_tool_args( expected_tool_name=tool_name, stage=call_type, logger=logger, - error_context=f"image={redact_string(image_url)[:120]}", + error_context=f"image={redact_string(str(image_url) if isinstance(image_url, list) else image_url)[:120]}", ) - async def judge_meme_image(self, image_url: str) -> dict[str, Any]: - safe_url = redact_string(image_url) + async def judge_meme_image(self, image_url: str | list[str]) -> dict[str, Any]: + safe_url = redact_string( + str(image_url) if isinstance(image_url, list) else image_url + ) try: args = await self._request_required_tool_args( prompt_path=_MEME_JUDGE_PROMPT_PATH, @@ -859,7 +861,7 @@ async def judge_meme_image(self, image_url: str) -> dict[str, Any]: tool_schema=_MEME_JUDGE_TOOL, tool_name="submit_meme_judgement", call_type="vision_meme_judge", - max_tokens=256, + max_tokens=self._vision_config.max_tokens, ) except Exception as exc: logger.exception("[媒体分析] 表情包判定失败,按非表情包处理: %s", exc) @@ -872,7 +874,7 @@ async def judge_meme_image(self, image_url: str) -> dict[str, Any]: try: parsed = { "is_meme": bool(args.get("is_meme", False)), - "confidence": _safe_float(args.get("confidence", 0.0), default=0.0), + "confidence": safe_float(args.get("confidence", 0.0), default=0.0), "reason": str(args.get("reason") or "").strip(), } except Exception: @@ -881,13 +883,15 @@ async def judge_meme_image(self, image_url: str) -> dict[str, Any]: "[媒体分析] 表情包判定完成: url=%s is_meme=%s confidence=%.3f reason=%s", safe_url[:50], parsed.get("is_meme", False), - _safe_float(parsed.get("confidence", 0.0), default=0.0), + safe_float(parsed.get("confidence", 0.0), default=0.0), str(parsed.get("reason", ""))[:80], ) return parsed - async def describe_meme_image(self, image_url: str) -> dict[str, Any]: - safe_url = redact_string(image_url) + async def describe_meme_image(self, image_url: str | list[str]) -> dict[str, Any]: + safe_url = redact_string( + str(image_url) if isinstance(image_url, list) else image_url + ) try: args = await self._request_required_tool_args( prompt_path=_MEME_DESCRIBE_PROMPT_PATH, @@ -895,7 +899,7 @@ async def describe_meme_image(self, image_url: str) -> dict[str, Any]: tool_schema=_MEME_DESCRIBE_TOOL, tool_name="submit_meme_description", call_type="vision_meme_describe", - max_tokens=512, + max_tokens=self._vision_config.max_tokens, ) except Exception as exc: logger.exception("[媒体分析] 表情包描述失败: %s", exc) diff --git a/src/Undefined/ai/prompts.py b/src/Undefined/ai/prompts.py index 62ae9558..5a87c8dd 100644 --- a/src/Undefined/ai/prompts.py +++ b/src/Undefined/ai/prompts.py @@ -11,7 +11,7 @@ import aiofiles -from Undefined.attachments import attachment_refs_to_xml +from Undefined.utils.coerce import safe_int from Undefined.context import RequestContext from Undefined.end_summary_storage import ( EndSummaryStorage, @@ -22,7 +22,7 @@ from Undefined.skills.anthropic_skills import AnthropicSkillRegistry from Undefined.utils.logging import log_debug_json from Undefined.utils.resources import read_text_resource -from Undefined.utils.xml import escape_xml_attr, escape_xml_text +from Undefined.utils.xml import format_message_xml logger = logging.getLogger(__name__) @@ -204,6 +204,43 @@ def _build_model_config_info(self, runtime_config: Any) -> str: else: parts.append("- 思维链: 未启用") + # 彩蛋功能状态 + keyword_reply_enabled = bool( + getattr(runtime_config, "keyword_reply_enabled", False) + ) + repeat_enabled = bool(getattr(runtime_config, "repeat_enabled", False)) + inverted_question_enabled = bool( + getattr(runtime_config, "inverted_question_enabled", False) + ) + agent_call_mode = str( + getattr(runtime_config, "easter_egg_agent_call_message_mode", "none") + ) + easter_egg_parts: list[str] = [] + if keyword_reply_enabled: + easter_egg_parts.append( + '关键词自动回复(触发词"心理委员"等,系统自动发送固定回复)' + ) + if repeat_enabled: + threshold = int(getattr(runtime_config, "repeat_threshold", 3)) + desc = f"复读(群聊连续{threshold}条相同消息时自动复读)" + if inverted_question_enabled: + desc += ",倒问号(复读触发时若消息为问号则发送¿)" + easter_egg_parts.append(desc) + elif inverted_question_enabled: + easter_egg_parts.append("倒问号(复读未启用,此功能不生效)") + if agent_call_mode != "none": + mode_desc = { + "agent": "Agent调用提示", + "tools": "工具调用提示", + "clean": "降噪调用提示", + "all": "全量调用提示", + }.get(agent_call_mode, agent_call_mode) + easter_egg_parts.append(f"调用提示模式={mode_desc}") + if easter_egg_parts: + parts.append("- 彩蛋功能: " + ";".join(easter_egg_parts)) + else: + parts.append("- 彩蛋功能: 未启用") + parts.append("") parts.append( "重要:以上是你的模型配置信息。\n" @@ -304,31 +341,60 @@ async def build_messages( is_group_context = True keyword_reply_enabled = False + repeat_enabled = False + repeat_threshold = 3 + inverted_question_enabled = False if self._runtime_config_getter is not None: try: runtime_config = self._runtime_config_getter() keyword_reply_enabled = bool( getattr(runtime_config, "keyword_reply_enabled", False) ) + repeat_enabled = bool(getattr(runtime_config, "repeat_enabled", False)) + repeat_threshold = int(getattr(runtime_config, "repeat_threshold", 3)) + inverted_question_enabled = bool( + getattr(runtime_config, "inverted_question_enabled", False) + ) except Exception as exc: - logger.debug("读取关键词自动回复配置失败: %s", exc) + logger.debug("读取彩蛋功能配置失败: %s", exc) if is_group_context and keyword_reply_enabled: messages.append( { "role": "system", "content": ( - "【系统行为说明】\n" + "【系统行为说明 — 关键词自动回复】\n" '当前群聊已开启关键词自动回复彩蛋(例如触发词"心理委员")。' - "命中时,系统可能直接发送固定回复,并在历史中写入" - '以"[系统关键词自动回复] "开头的消息。\n\n' - "这类消息属于系统预设机制,不代表你在该轮主动决策。" + "该功能由 handlers.py 中的独立代码路径处理," + "在消息到达你之前就已完成发送。\n\n" + '发送后,历史中会出现以"[系统关键词自动回复] "开头的消息。' + "这些消息完全由系统代码生成(固定文案如'受着''那咋了'等)," + "不经过你的工具调用,与你的决策无关。\n\n" "阅读历史时请识别该前缀,避免误判为人格漂移或上下文异常。" "除非用户主动询问,否则不要主动解释此机制。" ), } ) + if is_group_context and repeat_enabled: + repeat_desc = ( + "【系统行为说明】\n" + f"当前群聊已开启复读彩蛋:当群聊中连续出现{repeat_threshold}条内容相同且来自不同人的消息时," + "系统会自动复读一条相同的消息,并在历史中写入" + '以"[系统复读] "开头的消息。' + ) + if inverted_question_enabled: + repeat_desc += ( + "\n此外,若复读触发时消息内容仅由问号组成(如?或???)," + "系统会发送对应数量的倒问号(¿)代替。" + ) + repeat_desc += ( + "\n\n这类消息属于系统预设机制,不代表你在该轮主动决策。" + "阅读历史时请识别该前缀,避免误判为人格漂移或上下文异常。" + "除非用户主动询问,否则不要主动解释此机制。" + ) + messages.append({"role": "system", "content": repeat_desc}) + # 注入 Anthropic Skills 元数据(Level 1: 始终加载 name + description) if ( self._anthropic_skill_registry @@ -575,39 +641,24 @@ def _resolve_chat_scope( ) -> tuple[Literal["group", "private"], int] | None: ctx = RequestContext.current() - def _safe_int(value: Any) -> int | None: - if isinstance(value, bool): - return None - if isinstance(value, int): - return value - if isinstance(value, str): - text = value.strip() - if not text: - return None - try: - return int(text) - except ValueError: - return None - return None - if ctx and ctx.request_type == "group" and ctx.group_id is not None: - group_id = _safe_int(ctx.group_id) + group_id = safe_int(ctx.group_id) if group_id is not None: return ("group", group_id) return None if ctx and ctx.request_type == "private" and ctx.user_id is not None: - user_id = _safe_int(ctx.user_id) + user_id = safe_int(ctx.user_id) if user_id is not None: return ("private", user_id) return None if extra_context and extra_context.get("group_id") is not None: - group_id = _safe_int(extra_context.get("group_id")) + group_id = safe_int(extra_context.get("group_id")) if group_id is not None: return ("group", group_id) return None if extra_context and extra_context.get("user_id") is not None: - user_id = _safe_int(extra_context.get("user_id")) + user_id = safe_int(extra_context.get("user_id")) if user_id is not None: return ("private", user_id) return None @@ -673,56 +724,7 @@ async def _inject_recent_messages( recent_msgs = self._drop_current_message_if_duplicated( recent_msgs, question ) - context_lines: list[str] = [] - for msg in recent_msgs: - msg_type_val = msg.get("type", "group") - sender_name = msg.get("display_name", "未知用户") - sender_id = msg.get("user_id", "") - chat_id = msg.get("chat_id", "") - chat_name = msg.get("chat_name", "未知群聊") - timestamp = msg.get("timestamp", "") - text = msg.get("message", "") - attachments = msg.get("attachments", []) - role = msg.get("role", "member") - title = msg.get("title", "") - message_id = msg.get("message_id") - - safe_sender = escape_xml_attr(sender_name) - safe_sender_id = escape_xml_attr(sender_id) - safe_chat_id = escape_xml_attr(chat_id) - safe_chat_name = escape_xml_attr(chat_name) - safe_role = escape_xml_attr(role) - safe_title = escape_xml_attr(title) - safe_time = escape_xml_attr(timestamp) - safe_text = escape_xml_text(str(text)) - - msg_id_attr = "" - if message_id is not None: - msg_id_attr = f' message_id="{escape_xml_attr(str(message_id))}"' - attachment_xml = ( - f"\n{attachment_refs_to_xml(attachments)}" - if isinstance(attachments, list) and attachments - else "" - ) - - if msg_type_val == "group": - location = ( - chat_name if chat_name.endswith("群") else f"{chat_name}群" - ) - safe_location = escape_xml_attr(location) - xml_msg = ( - f'\n{safe_text}{attachment_xml}\n' - ) - else: - location = "私聊" - safe_location = escape_xml_attr(location) - xml_msg = ( - f'\n{safe_text}{attachment_xml}\n' - ) - context_lines.append(xml_msg) + context_lines: list[str] = [format_message_xml(msg) for msg in recent_msgs] formatted_context = "\n---\n".join(context_lines) diff --git a/src/Undefined/ai/retrieval.py b/src/Undefined/ai/retrieval.py index 82c42d42..5fe8337d 100644 --- a/src/Undefined/ai/retrieval.py +++ b/src/Undefined/ai/retrieval.py @@ -10,6 +10,7 @@ from openai import NOT_GIVEN, AsyncOpenAI from Undefined.ai.tokens import TokenCounter +from Undefined.utils.coerce import safe_int from Undefined.config import EmbeddingModelConfig, RerankModelConfig from Undefined.utils.request_params import split_reserved_request_params @@ -224,13 +225,13 @@ def _extract_usage(self, response_dict: dict[str, Any]) -> tuple[int, int, int]: usage = response_dict.get("usage", {}) or {} if not isinstance(usage, dict): usage = {} - prompt_tokens = self._safe_int( - usage.get("prompt_tokens", usage.get("input_tokens", 0)) + prompt_tokens = safe_int( + usage.get("prompt_tokens", usage.get("input_tokens", 0)), 0 ) - completion_tokens = self._safe_int( - usage.get("completion_tokens", usage.get("output_tokens", 0)) + completion_tokens = safe_int( + usage.get("completion_tokens", usage.get("output_tokens", 0)), 0 ) - total_tokens = self._safe_int(usage.get("total_tokens", 0)) + total_tokens = safe_int(usage.get("total_tokens", 0), 0) if total_tokens <= 0 and (prompt_tokens > 0 or completion_tokens > 0): total_tokens = prompt_tokens + completion_tokens return prompt_tokens, completion_tokens, total_tokens @@ -275,7 +276,7 @@ def _normalize_rerank_results( for idx, item in enumerate(raw_results): if not isinstance(item, dict): continue - doc_index = self._safe_int(item.get("index", idx)) + doc_index = safe_int(item.get("index", idx), 0) if doc_index < 0: continue @@ -322,9 +323,3 @@ def _normalize_rerank_results( } for i in range(limit) ] - - def _safe_int(self, value: Any) -> int: - try: - return int(value or 0) - except (TypeError, ValueError): - return 0 diff --git a/src/Undefined/api/__init__.py b/src/Undefined/api/__init__.py index f71c40d3..1ab331b9 100644 --- a/src/Undefined/api/__init__.py +++ b/src/Undefined/api/__init__.py @@ -1,5 +1,6 @@ """Runtime API server for Undefined main process.""" -from .app import RuntimeAPIContext, RuntimeAPIServer +from ._context import RuntimeAPIContext +from .app import RuntimeAPIServer __all__ = ["RuntimeAPIContext", "RuntimeAPIServer"] diff --git a/src/Undefined/api/_context.py b/src/Undefined/api/_context.py new file mode 100644 index 00000000..07d8c92c --- /dev/null +++ b/src/Undefined/api/_context.py @@ -0,0 +1,22 @@ +"""Shared context types for the Runtime API.""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any, Callable + + +@dataclass +class RuntimeAPIContext: + config_getter: Callable[[], Any] + onebot: Any + ai: Any + command_dispatcher: Any + queue_manager: Any + history_manager: Any + sender: Any = None + scheduler: Any = None + cognitive_service: Any = None + cognitive_job_queue: Any = None + meme_service: Any = None + naga_store: Any = None diff --git a/src/Undefined/api/_helpers.py b/src/Undefined/api/_helpers.py new file mode 100644 index 00000000..40dcfb00 --- /dev/null +++ b/src/Undefined/api/_helpers.py @@ -0,0 +1,333 @@ +"""Helper classes and utility functions for the Runtime API.""" + +from __future__ import annotations + +import asyncio +import hashlib +import json +import logging +from contextlib import suppress +from dataclasses import dataclass +from datetime import datetime +from typing import Any, Awaitable, Callable +from urllib.parse import urlsplit + +from aiohttp import web +from aiohttp.web_response import Response + +from Undefined.config import load_webui_settings +from Undefined.utils.cors import is_allowed_cors_origin, normalize_origin + +logger = logging.getLogger(__name__) + +_AUTH_HEADER = "X-Undefined-API-Key" +_VIRTUAL_USER_ID = 42 + + +class _ToolInvokeExecutionTimeoutError(asyncio.TimeoutError): + """由 Runtime API 工具调用超时包装器抛出的超时异常。""" + + +@dataclass +class _NagaRequestResult: + payload_hash: str + status: int + payload: dict[str, Any] + finished_at: float + + +class _WebUIVirtualSender: + """将工具发送行为重定向到 WebUI 会话,不触发 OneBot 实际发送。""" + + def __init__( + self, + virtual_user_id: int, + send_private_callback: Callable[[int, str], Awaitable[None]], + onebot: Any = None, + ) -> None: + self._virtual_user_id = virtual_user_id + self._send_private_callback = send_private_callback + # 保留 onebot 属性,兼容依赖 sender.onebot 的工具读取能力。 + self.onebot = onebot + + async def send_private_message( + self, + user_id: int, + message: str, + auto_history: bool = True, + *, + mark_sent: bool = True, + reply_to: int | None = None, + preferred_temp_group_id: int | None = None, + history_message: str | None = None, + ) -> int | None: + _ = ( + user_id, + auto_history, + mark_sent, + reply_to, + preferred_temp_group_id, + history_message, + ) + await self._send_private_callback(self._virtual_user_id, message) + return None + + async def send_group_message( + self, + group_id: int, + message: str, + auto_history: bool = True, + history_prefix: str = "", + *, + mark_sent: bool = True, + reply_to: int | None = None, + history_message: str | None = None, + ) -> int | None: + _ = ( + group_id, + auto_history, + history_prefix, + mark_sent, + reply_to, + history_message, + ) + await self._send_private_callback(self._virtual_user_id, message) + return None + + async def send_private_file( + self, + user_id: int, + file_path: str, + name: str | None = None, + auto_history: bool = True, + ) -> None: + """将文件拷贝到 WebUI 缓存并发送文件卡片消息。""" + _ = user_id, auto_history + import shutil + import uuid as _uuid + from pathlib import Path as _Path + + from Undefined.utils.paths import WEBUI_FILE_CACHE_DIR, ensure_dir + + src = _Path(file_path) + display_name = name or src.name + file_id = _uuid.uuid4().hex + dest_dir = ensure_dir(WEBUI_FILE_CACHE_DIR / file_id) + dest = dest_dir / display_name + + def _copy_and_stat() -> int: + shutil.copy2(src, dest) + return dest.stat().st_size + + try: + file_size = await asyncio.to_thread(_copy_and_stat) + except OSError: + file_size = 0 + + message = f"[CQ:file,id={file_id},name={display_name},size={file_size}]" + await self._send_private_callback(self._virtual_user_id, message) + + async def send_group_file( + self, + group_id: int, + file_path: str, + name: str | None = None, + auto_history: bool = True, + ) -> None: + """群文件在虚拟会话中同样重定向为文本消息。""" + await self.send_private_file(group_id, file_path, name, auto_history) + + +def _json_error(message: str, status: int = 400) -> Response: + return web.json_response({"error": message}, status=status) + + +def _apply_cors_headers(request: web.Request, response: web.StreamResponse) -> None: + origin = normalize_origin(str(request.headers.get("Origin") or "")) + settings = load_webui_settings() + response.headers.setdefault("Vary", "Origin") + response.headers.setdefault( + "Access-Control-Allow-Methods", "GET,POST,PATCH,DELETE,OPTIONS" + ) + response.headers.setdefault( + "Access-Control-Allow-Headers", + "Authorization, Content-Type, X-Undefined-API-Key", + ) + response.headers.setdefault("Access-Control-Max-Age", "86400") + if origin and is_allowed_cors_origin( + origin, + configured_host=str(settings.url or ""), + configured_port=settings.port, + ): + response.headers.setdefault("Access-Control-Allow-Origin", origin) + response.headers.setdefault("Access-Control-Allow-Credentials", "true") + + +def _optional_query_param(request: web.Request, key: str) -> str | None: + raw = request.query.get(key) + if raw is None: + return None + text = str(raw).strip() + if not text: + return None + return text + + +def _parse_query_time(value: str | None) -> datetime | None: + text = str(value or "").strip() + if not text: + return None + candidates = [text, text.replace("Z", "+00:00")] + if "T" in text: + candidates.append(text.replace("T", " ")) + for candidate in candidates: + with suppress(ValueError): + return datetime.fromisoformat(candidate) + return None + + +def _to_bool(value: Any) -> bool: + if isinstance(value, bool): + return value + if isinstance(value, (int, float)): + return value != 0 + text = str(value or "").strip().lower() + return text in {"1", "true", "yes", "on"} + + +def _build_chat_response_payload(mode: str, outputs: list[str]) -> dict[str, Any]: + return { + "mode": mode, + "virtual_user_id": _VIRTUAL_USER_ID, + "permission": "superadmin", + "messages": outputs, + "reply": "\n\n".join(outputs).strip(), + } + + +def _sse_event(event: str, payload: dict[str, Any]) -> bytes: + data = json.dumps(payload, ensure_ascii=False) + return f"event: {event}\ndata: {data}\n\n".encode("utf-8") + + +def _mask_url(url: str) -> str: + """保留 scheme + host,隐藏 path 细节。""" + text = str(url or "").strip().rstrip("/") + if not text: + return "" + parsed = urlsplit(text) + host = parsed.hostname or "" + port_part = f":{parsed.port}" if parsed.port else "" + scheme = parsed.scheme or "https" + return f"{scheme}://{host}{port_part}/..." + + +def _naga_runtime_enabled(cfg: Any) -> bool: + naga_cfg = getattr(cfg, "naga", None) + return bool(getattr(cfg, "nagaagent_mode_enabled", False)) and bool( + getattr(naga_cfg, "enabled", False) + ) + + +def _naga_routes_enabled(cfg: Any, naga_store: Any) -> bool: + return _naga_runtime_enabled(cfg) and naga_store is not None + + +def _short_text_preview(text: str, limit: int = 80) -> str: + compact = " ".join(str(text or "").split()) + if len(compact) <= limit: + return compact + return compact[:limit] + "..." + + +def _naga_message_digest( + *, + bind_uuid: str, + naga_id: str, + target_qq: int, + target_group: int, + mode: str, + message_format: str, + content: str, +) -> str: + raw = json.dumps( + { + "bind_uuid": bind_uuid, + "naga_id": naga_id, + "target_qq": target_qq, + "target_group": target_group, + "mode": mode, + "format": message_format, + "content": content, + }, + ensure_ascii=False, + sort_keys=True, + separators=(",", ":"), + ) + return hashlib.sha1(raw.encode("utf-8")).hexdigest()[:16] + + +def _parse_response_payload(response: Response) -> dict[str, Any]: + text = response.text or "" + if not text: + return {} + payload = json.loads(text) + return payload if isinstance(payload, dict) else {"data": payload} + + +def _registry_summary(registry: Any) -> dict[str, Any]: + """从 BaseRegistry 提取轻量摘要。""" + if registry is None: + return {"count": 0, "loaded": 0, "items": []} + items: dict[str, Any] = getattr(registry, "_items", {}) + stats: dict[str, Any] = {} + get_stats = getattr(registry, "get_stats", None) + if callable(get_stats): + stats = get_stats() + summary_items: list[dict[str, Any]] = [] + for name, item in items.items(): + st = stats.get(name) + entry: dict[str, Any] = { + "name": name, + "loaded": getattr(item, "loaded", False), + } + if st is not None: + entry["calls"] = getattr(st, "count", 0) + entry["success"] = getattr(st, "success", 0) + entry["failure"] = getattr(st, "failure", 0) + summary_items.append(entry) + return { + "count": len(items), + "loaded": sum(1 for i in items.values() if getattr(i, "loaded", False)), + "items": summary_items, + } + + +def _validate_callback_url(url: str) -> str | None: + """校验回调 URL,返回错误信息或 None 表示通过。 + + 拒绝非 HTTP(S) scheme,以及直接使用私有/回环 IP 字面量的 URL 以防止 SSRF。 + 域名形式的 URL 放行(DNS 解析阶段不适合在校验函数中做阻塞调用)。 + """ + import ipaddress + + parsed = urlsplit(url) + scheme = (parsed.scheme or "").lower() + + if scheme not in ("http", "https"): + return "callback.url must use http or https" + + hostname = parsed.hostname or "" + if not hostname: + return "callback.url must include a hostname" + + # 仅检查 IP 字面量(如 http://127.0.0.1/、http://[::1]/、http://10.0.0.1/) + try: + addr = ipaddress.ip_address(hostname) + except ValueError: + pass # 域名形式,放行 + else: + if addr.is_private or addr.is_loopback or addr.is_link_local: + return "callback.url must not point to a private/loopback address" + + return None diff --git a/src/Undefined/api/_naga_state.py b/src/Undefined/api/_naga_state.py new file mode 100644 index 00000000..26bff9f0 --- /dev/null +++ b/src/Undefined/api/_naga_state.py @@ -0,0 +1,115 @@ +"""Naga request deduplication and inflight tracking state.""" + +from __future__ import annotations + +import asyncio +import time +from copy import deepcopy +from typing import Any + +from ._helpers import _NagaRequestResult + +_NAGA_REQUEST_UUID_TTL_SECONDS = 6 * 60 * 60 + + +class NagaState: + """Tracks inflight Naga sends and provides request-uuid deduplication.""" + + def __init__(self) -> None: + self.send_registry_lock = asyncio.Lock() + self.send_inflight: dict[str, int] = {} + self.request_uuid_lock = asyncio.Lock() + self.request_uuid_inflight: dict[ + str, tuple[str, asyncio.Future[tuple[int, dict[str, Any]]]] + ] = {} + self.request_uuid_results: dict[str, _NagaRequestResult] = {} + + async def track_send_start(self, message_key: str) -> int: + async with self.send_registry_lock: + next_count = self.send_inflight.get(message_key, 0) + 1 + self.send_inflight[message_key] = next_count + return next_count + + async def track_send_done(self, message_key: str) -> int: + async with self.send_registry_lock: + current = self.send_inflight.get(message_key, 0) + if current <= 1: + self.send_inflight.pop(message_key, None) + return 0 + next_count = current - 1 + self.send_inflight[message_key] = next_count + return next_count + + def _prune_request_uuid_state_locked(self) -> None: + now = time.time() + expired = [ + request_uuid + for request_uuid, result in self.request_uuid_results.items() + if now - result.finished_at > _NAGA_REQUEST_UUID_TTL_SECONDS + ] + for request_uuid in expired: + self.request_uuid_results.pop(request_uuid, None) + + async def register_request_uuid( + self, request_uuid: str, payload_hash: str + ) -> tuple[str, Any]: + async with self.request_uuid_lock: + self._prune_request_uuid_state_locked() + + cached = self.request_uuid_results.get(request_uuid) + if cached is not None: + if cached.payload_hash != payload_hash: + return "conflict", cached.payload_hash + return "cached", (cached.status, deepcopy(cached.payload)) + + inflight = self.request_uuid_inflight.get(request_uuid) + if inflight is not None: + existing_hash, inflight_future = inflight + if existing_hash != payload_hash: + return "conflict", existing_hash + return "await", inflight_future + + owner_future: asyncio.Future[tuple[int, dict[str, Any]]] = ( + asyncio.get_running_loop().create_future() + ) + self.request_uuid_inflight[request_uuid] = ( + payload_hash, + owner_future, + ) + return "owner", owner_future + + async def finish_request_uuid( + self, + request_uuid: str, + payload_hash: str, + *, + status: int, + payload: dict[str, Any], + ) -> None: + async with self.request_uuid_lock: + inflight = self.request_uuid_inflight.pop(request_uuid, None) + future = inflight[1] if inflight is not None else None + result_payload = deepcopy(payload) + self.request_uuid_results[request_uuid] = _NagaRequestResult( + payload_hash=payload_hash, + status=status, + payload=result_payload, + finished_at=time.time(), + ) + self._prune_request_uuid_state_locked() + if future is not None and not future.done(): + future.set_result((status, deepcopy(result_payload))) + + async def fail_request_uuid( + self, + request_uuid: str, + payload_hash: str, + exc: BaseException, + ) -> None: + _ = payload_hash + async with self.request_uuid_lock: + inflight = self.request_uuid_inflight.pop(request_uuid, None) + future = inflight[1] if inflight is not None else None + self._prune_request_uuid_state_locked() + if future is not None and not future.done(): + future.set_exception(exc) diff --git a/src/Undefined/api/_openapi.py b/src/Undefined/api/_openapi.py new file mode 100644 index 00000000..1076256b --- /dev/null +++ b/src/Undefined/api/_openapi.py @@ -0,0 +1,183 @@ +"""OpenAPI / Swagger specification builder.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +from aiohttp import web + +from Undefined import __version__ +from ._helpers import _AUTH_HEADER, _naga_routes_enabled + +if TYPE_CHECKING: + from ._context import RuntimeAPIContext + + +def _build_openapi_spec(ctx: RuntimeAPIContext, request: web.Request) -> dict[str, Any]: + server_url = f"{request.scheme}://{request.host}" + cfg = ctx.config_getter() + naga_routes_enabled = _naga_routes_enabled(cfg, ctx.naga_store) + paths: dict[str, Any] = { + "/health": { + "get": { + "summary": "Health check", + "security": [], + "responses": {"200": {"description": "OK"}}, + } + }, + "/openapi.json": { + "get": { + "summary": "OpenAPI schema", + "security": [], + "responses": {"200": {"description": "Schema JSON"}}, + } + }, + "/api/v1/probes/internal": { + "get": { + "summary": "Internal runtime probes", + "description": ( + "Returns system info (version, Python, platform, uptime), " + "OneBot connection status, request queue snapshot, " + "memory count, cognitive service status, API config, " + "skill statistics (tools/agents/anthropic_skills with call counts), " + "and model configuration (names, masked URLs, thinking flags)." + ), + } + }, + "/api/v1/probes/external": { + "get": { + "summary": "External dependency probes", + "description": ( + "Concurrently probes all configured model API endpoints " + "(chat, vision, security, naga, agent, embedding, rerank) " + "and OneBot WebSocket. Each result includes status, " + "model name, masked URL, HTTP status code, and latency." + ), + } + }, + "/api/v1/memory": { + "get": {"summary": "List/search manual memories"}, + "post": {"summary": "Create a manual memory"}, + }, + "/api/v1/memory/{uuid}": { + "patch": {"summary": "Update a manual memory by UUID"}, + "delete": {"summary": "Delete a manual memory by UUID"}, + }, + "/api/v1/memes": {"get": {"summary": "List/search meme library items"}}, + "/api/v1/memes/stats": {"get": {"summary": "Get meme library stats"}}, + "/api/v1/memes/{uid}": { + "get": {"summary": "Get a meme by uid"}, + "patch": {"summary": "Update a meme by uid"}, + "delete": {"summary": "Delete a meme by uid"}, + }, + "/api/v1/memes/{uid}/blob": {"get": {"summary": "Get meme blob file"}}, + "/api/v1/memes/{uid}/preview": {"get": {"summary": "Get meme preview file"}}, + "/api/v1/memes/{uid}/reanalyze": { + "post": {"summary": "Queue a meme reanalyze job"} + }, + "/api/v1/memes/{uid}/reindex": { + "post": {"summary": "Queue a meme reindex job"} + }, + "/api/v1/cognitive/events": { + "get": {"summary": "Search cognitive event memories"} + }, + "/api/v1/cognitive/profiles": {"get": {"summary": "Search cognitive profiles"}}, + "/api/v1/cognitive/profile/{entity_type}/{entity_id}": { + "get": {"summary": "Get a profile by entity type/id"} + }, + "/api/v1/chat": { + "post": { + "summary": "WebUI special private chat", + "description": ( + "POST JSON {message, stream?}. " + "When stream=true, response is SSE with keep-alive comments." + ), + } + }, + "/api/v1/chat/history": { + "get": {"summary": "Get virtual private chat history for WebUI"} + }, + "/api/v1/tools": { + "get": { + "summary": "List available tools", + "description": ( + "Returns currently available tools filtered by " + "tool_invoke_expose / allowlist / denylist config. " + "Each item follows the OpenAI function calling schema." + ), + } + }, + "/api/v1/tools/invoke": { + "post": { + "summary": "Invoke a tool", + "description": ( + "Execute a specific tool by name. Supports synchronous " + "response and optional async webhook callback." + ), + } + }, + } + + if naga_routes_enabled: + paths.update( + { + "/api/v1/naga/bind/callback": { + "post": { + "summary": "Finalize a Naga bind request", + "description": ( + "Internal callback used by Naga to approve or reject " + "a bind_uuid." + ), + "security": [{"BearerAuth": []}], + } + }, + "/api/v1/naga/messages/send": { + "post": { + "summary": "Send a Naga-authenticated message", + "description": ( + "Validates bind_uuid + delivery_signature, runs " + "moderation, then delivers the message. " + "Caller may provide uuid for idempotent retry deduplication." + ), + "security": [{"BearerAuth": []}], + } + }, + "/api/v1/naga/unbind": { + "post": { + "summary": "Revoke an active Naga binding", + "description": ( + "Allows Naga to proactively revoke a binding using " + "Authorization: Bearer ." + ), + "security": [{"BearerAuth": []}], + } + }, + } + ) + + return { + "openapi": "3.0.3", + "info": { + "title": "Undefined Runtime API", + "version": __version__, + "description": "API exposed by the main Undefined process.", + }, + "servers": [ + { + "url": server_url, + "description": "Runtime endpoint", + } + ], + "components": { + "securitySchemes": { + "ApiKeyAuth": { + "type": "apiKey", + "in": "header", + "name": _AUTH_HEADER, + }, + "BearerAuth": {"type": "http", "scheme": "bearer"}, + } + }, + "security": [{"ApiKeyAuth": []}], + "paths": paths, + } diff --git a/src/Undefined/api/_probes.py b/src/Undefined/api/_probes.py new file mode 100644 index 00000000..fc51bb3a --- /dev/null +++ b/src/Undefined/api/_probes.py @@ -0,0 +1,147 @@ +"""HTTP/WebSocket endpoint probe functions for health checks.""" + +from __future__ import annotations + +import asyncio +import logging +import socket +import time +from typing import Any +from urllib.parse import urlsplit + +from aiohttp import ClientSession, ClientTimeout + +from ._helpers import _mask_url + +logger = logging.getLogger(__name__) + + +async def _probe_http_endpoint( + *, + name: str, + base_url: str, + api_key: str, + model_name: str = "", + timeout_seconds: float = 5.0, +) -> dict[str, Any]: + normalized = str(base_url or "").strip().rstrip("/") + if not normalized: + return { + "name": name, + "status": "skipped", + "reason": "empty_url", + "model_name": model_name, + } + + headers: dict[str, str] = {} + if api_key: + headers["Authorization"] = f"Bearer {api_key}" + + candidates = [normalized, f"{normalized}/models"] + last_error = "" + for url in candidates: + start = time.perf_counter() + try: + timeout = ClientTimeout(total=timeout_seconds) + async with ClientSession(timeout=timeout) as session: + async with session.get(url, headers=headers) as resp: + elapsed_ms = round((time.perf_counter() - start) * 1000, 2) + return { + "name": name, + "status": "ok", + "url": _mask_url(url), + "http_status": resp.status, + "latency_ms": elapsed_ms, + "model_name": model_name, + } + except Exception as exc: + last_error = str(exc) + continue + + return { + "name": name, + "status": "error", + "url": _mask_url(normalized), + "error": last_error or "request_failed", + "model_name": model_name, + } + + +async def _skipped_probe( + *, name: str, reason: str, model_name: str = "" +) -> dict[str, Any]: + payload: dict[str, Any] = {"name": name, "status": "skipped", "reason": reason} + if model_name: + payload["model_name"] = model_name + return payload + + +def _build_internal_model_probe_payload(mcfg: Any) -> dict[str, Any]: + payload = { + "model_name": getattr(mcfg, "model_name", ""), + "api_url": _mask_url(getattr(mcfg, "api_url", "")), + } + if hasattr(mcfg, "api_mode"): + payload["api_mode"] = getattr(mcfg, "api_mode", "chat_completions") + if hasattr(mcfg, "thinking_enabled"): + payload["thinking_enabled"] = getattr(mcfg, "thinking_enabled", False) + if hasattr(mcfg, "thinking_tool_call_compat"): + payload["thinking_tool_call_compat"] = getattr( + mcfg, "thinking_tool_call_compat", True + ) + if hasattr(mcfg, "responses_tool_choice_compat"): + payload["responses_tool_choice_compat"] = getattr( + mcfg, "responses_tool_choice_compat", False + ) + if hasattr(mcfg, "responses_force_stateless_replay"): + payload["responses_force_stateless_replay"] = getattr( + mcfg, "responses_force_stateless_replay", False + ) + if hasattr(mcfg, "prompt_cache_enabled"): + payload["prompt_cache_enabled"] = getattr(mcfg, "prompt_cache_enabled", True) + if hasattr(mcfg, "reasoning_enabled"): + payload["reasoning_enabled"] = getattr(mcfg, "reasoning_enabled", False) + if hasattr(mcfg, "reasoning_effort"): + payload["reasoning_effort"] = getattr(mcfg, "reasoning_effort", "medium") + return payload + + +async def _probe_ws_endpoint(url: str, timeout_seconds: float = 5.0) -> dict[str, Any]: + normalized = str(url or "").strip() + if not normalized: + return {"name": "onebot_ws", "status": "skipped", "reason": "empty_url"} + + parsed = urlsplit(normalized) + host = parsed.hostname + if not host: + return {"name": "onebot_ws", "status": "error", "error": "invalid_url"} + + if parsed.port is not None: + port = parsed.port + elif parsed.scheme == "wss": + port = 443 + else: + port = 80 + + start = time.perf_counter() + try: + conn = asyncio.open_connection(host, port) + reader, writer = await asyncio.wait_for(conn, timeout=timeout_seconds) + writer.close() + await writer.wait_closed() + elapsed_ms = round((time.perf_counter() - start) * 1000, 2) + return { + "name": "onebot_ws", + "status": "ok", + "host": host, + "port": port, + "latency_ms": elapsed_ms, + } + except (OSError, TimeoutError, socket.gaierror, asyncio.TimeoutError) as exc: + return { + "name": "onebot_ws", + "status": "error", + "host": host, + "port": port, + "error": str(exc), + } diff --git a/src/Undefined/api/app.py b/src/Undefined/api/app.py index 2ba4ae48..7ad5e56c 100644 --- a/src/Undefined/api/app.py +++ b/src/Undefined/api/app.py @@ -1,669 +1,33 @@ +"""Runtime API server for Undefined. + +Route handler logic lives in ``routes/`` sub-modules. This file keeps only +the ``RuntimeAPIServer`` class (init / start / stop / middleware / routing) +and thin one-liner wrappers that delegate to the route functions so that +existing tests calling ``server._xxx_handler(request)`` keep working. +""" + from __future__ import annotations import asyncio -from copy import deepcopy -import hashlib -import json import logging -import os -import platform -from pathlib import Path -import socket -import sys -import time -import uuid as _uuid -from contextlib import suppress -from dataclasses import dataclass -from datetime import datetime from typing import Any, Awaitable, Callable -from typing import cast -from urllib.parse import urlsplit -from aiohttp import ClientSession, ClientTimeout, web + +from aiohttp import web from aiohttp.web_response import Response -from Undefined import __version__ -from Undefined.attachments import ( - attachment_refs_to_xml, - build_attachment_scope, - register_message_attachments, - render_message_with_pic_placeholders, +from ._context import RuntimeAPIContext +from ._helpers import ( + _apply_cors_headers, + _json_error, + _naga_routes_enabled, + _naga_runtime_enabled, + _AUTH_HEADER, ) -from Undefined.config import load_webui_settings -from Undefined.context import RequestContext -from Undefined.context_resource_registry import collect_context_resources -from Undefined.render import render_html_to_image, render_markdown_to_html # noqa: F401 -from Undefined.services.queue_manager import QUEUE_LANE_SUPERADMIN -from Undefined.utils.common import message_to_segments -from Undefined.utils.cors import is_allowed_cors_origin, normalize_origin -from Undefined.utils.recent_messages import get_recent_messages_prefer_local -from Undefined.utils.xml import escape_xml_attr, escape_xml_text +from ._naga_state import NagaState +from .routes import chat, cognitive, health, memes, memory, naga, system, tools logger = logging.getLogger(__name__) -_VIRTUAL_USER_ID = 42 -_VIRTUAL_USER_NAME = "system" -_AUTH_HEADER = "X-Undefined-API-Key" -_CHAT_SSE_KEEPALIVE_SECONDS = 10.0 -_PROCESS_START_TIME = time.time() -_NAGA_REQUEST_UUID_TTL_SECONDS = 6 * 60 * 60 - - -class _ToolInvokeExecutionTimeoutError(asyncio.TimeoutError): - """由 Runtime API 工具调用超时包装器抛出的超时异常。""" - - -@dataclass -class _NagaRequestResult: - payload_hash: str - status: int - payload: dict[str, Any] - finished_at: float - - -class _WebUIVirtualSender: - """将工具发送行为重定向到 WebUI 会话,不触发 OneBot 实际发送。""" - - def __init__( - self, - virtual_user_id: int, - send_private_callback: Callable[[int, str], Awaitable[None]], - onebot: Any = None, - ) -> None: - self._virtual_user_id = virtual_user_id - self._send_private_callback = send_private_callback - # 保留 onebot 属性,兼容依赖 sender.onebot 的工具读取能力。 - self.onebot = onebot - - async def send_private_message( - self, - user_id: int, - message: str, - auto_history: bool = True, - *, - mark_sent: bool = True, - reply_to: int | None = None, - preferred_temp_group_id: int | None = None, - history_message: str | None = None, - ) -> int | None: - _ = ( - user_id, - auto_history, - mark_sent, - reply_to, - preferred_temp_group_id, - history_message, - ) - await self._send_private_callback(self._virtual_user_id, message) - return None - - async def send_group_message( - self, - group_id: int, - message: str, - auto_history: bool = True, - history_prefix: str = "", - *, - mark_sent: bool = True, - reply_to: int | None = None, - history_message: str | None = None, - ) -> int | None: - _ = ( - group_id, - auto_history, - history_prefix, - mark_sent, - reply_to, - history_message, - ) - await self._send_private_callback(self._virtual_user_id, message) - return None - - async def send_private_file( - self, - user_id: int, - file_path: str, - name: str | None = None, - auto_history: bool = True, - ) -> None: - """将文件拷贝到 WebUI 缓存并发送文件卡片消息。""" - _ = user_id, auto_history - import shutil - import uuid as _uuid - from pathlib import Path as _Path - - from Undefined.utils.paths import WEBUI_FILE_CACHE_DIR, ensure_dir - - src = _Path(file_path) - display_name = name or src.name - file_id = _uuid.uuid4().hex - dest_dir = ensure_dir(WEBUI_FILE_CACHE_DIR / file_id) - dest = dest_dir / display_name - - def _copy_and_stat() -> int: - shutil.copy2(src, dest) - return dest.stat().st_size - - try: - file_size = await asyncio.to_thread(_copy_and_stat) - except OSError: - file_size = 0 - - message = f"[CQ:file,id={file_id},name={display_name},size={file_size}]" - await self._send_private_callback(self._virtual_user_id, message) - - async def send_group_file( - self, - group_id: int, - file_path: str, - name: str | None = None, - auto_history: bool = True, - ) -> None: - """群文件在虚拟会话中同样重定向为文本消息。""" - await self.send_private_file(group_id, file_path, name, auto_history) - - -@dataclass -class RuntimeAPIContext: - config_getter: Callable[[], Any] - onebot: Any - ai: Any - command_dispatcher: Any - queue_manager: Any - history_manager: Any - sender: Any = None - scheduler: Any = None - cognitive_service: Any = None - cognitive_job_queue: Any = None - meme_service: Any = None - naga_store: Any = None - - -def _json_error(message: str, status: int = 400) -> Response: - return web.json_response({"error": message}, status=status) - - -def _apply_cors_headers(request: web.Request, response: web.StreamResponse) -> None: - origin = normalize_origin(str(request.headers.get("Origin") or "")) - settings = load_webui_settings() - response.headers.setdefault("Vary", "Origin") - response.headers.setdefault("Access-Control-Allow-Methods", "GET,POST,OPTIONS") - response.headers.setdefault( - "Access-Control-Allow-Headers", - "Authorization, Content-Type, X-Undefined-API-Key", - ) - response.headers.setdefault("Access-Control-Max-Age", "86400") - if origin and is_allowed_cors_origin( - origin, - configured_host=str(settings.url or ""), - configured_port=settings.port, - ): - response.headers.setdefault("Access-Control-Allow-Origin", origin) - response.headers.setdefault("Access-Control-Allow-Credentials", "true") - - -def _optional_query_param(request: web.Request, key: str) -> str | None: - raw = request.query.get(key) - if raw is None: - return None - text = str(raw).strip() - if not text: - return None - return text - - -def _parse_query_time(value: str | None) -> datetime | None: - text = str(value or "").strip() - if not text: - return None - candidates = [text, text.replace("Z", "+00:00")] - if "T" in text: - candidates.append(text.replace("T", " ")) - for candidate in candidates: - with suppress(ValueError): - return datetime.fromisoformat(candidate) - return None - - -def _to_bool(value: Any) -> bool: - if isinstance(value, bool): - return value - if isinstance(value, (int, float)): - return value != 0 - text = str(value or "").strip().lower() - return text in {"1", "true", "yes", "on"} - - -def _build_chat_response_payload(mode: str, outputs: list[str]) -> dict[str, Any]: - return { - "mode": mode, - "virtual_user_id": _VIRTUAL_USER_ID, - "permission": "superadmin", - "messages": outputs, - "reply": "\n\n".join(outputs).strip(), - } - - -def _sse_event(event: str, payload: dict[str, Any]) -> bytes: - data = json.dumps(payload, ensure_ascii=False) - return f"event: {event}\ndata: {data}\n\n".encode("utf-8") - - -def _mask_url(url: str) -> str: - """保留 scheme + host,隐藏 path 细节。""" - text = str(url or "").strip().rstrip("/") - if not text: - return "" - parsed = urlsplit(text) - host = parsed.hostname or "" - port_part = f":{parsed.port}" if parsed.port else "" - scheme = parsed.scheme or "https" - return f"{scheme}://{host}{port_part}/..." - - -def _naga_runtime_enabled(cfg: Any) -> bool: - naga_cfg = getattr(cfg, "naga", None) - return bool(getattr(cfg, "nagaagent_mode_enabled", False)) and bool( - getattr(naga_cfg, "enabled", False) - ) - - -def _naga_routes_enabled(cfg: Any, naga_store: Any) -> bool: - return _naga_runtime_enabled(cfg) and naga_store is not None - - -def _short_text_preview(text: str, limit: int = 80) -> str: - compact = " ".join(str(text or "").split()) - if len(compact) <= limit: - return compact - return compact[:limit] + "..." - - -def _naga_message_digest( - *, - bind_uuid: str, - naga_id: str, - target_qq: int, - target_group: int, - mode: str, - message_format: str, - content: str, -) -> str: - raw = json.dumps( - { - "bind_uuid": bind_uuid, - "naga_id": naga_id, - "target_qq": target_qq, - "target_group": target_group, - "mode": mode, - "format": message_format, - "content": content, - }, - ensure_ascii=False, - sort_keys=True, - separators=(",", ":"), - ) - return hashlib.sha1(raw.encode("utf-8")).hexdigest()[:16] - - -def _parse_response_payload(response: Response) -> dict[str, Any]: - text = response.text or "" - if not text: - return {} - payload = json.loads(text) - return payload if isinstance(payload, dict) else {"data": payload} - - -def _registry_summary(registry: Any) -> dict[str, Any]: - """从 BaseRegistry 提取轻量摘要。""" - if registry is None: - return {"count": 0, "loaded": 0, "items": []} - items: dict[str, Any] = getattr(registry, "_items", {}) - stats: dict[str, Any] = {} - get_stats = getattr(registry, "get_stats", None) - if callable(get_stats): - stats = get_stats() - summary_items: list[dict[str, Any]] = [] - for name, item in items.items(): - st = stats.get(name) - entry: dict[str, Any] = { - "name": name, - "loaded": getattr(item, "loaded", False), - } - if st is not None: - entry["calls"] = getattr(st, "count", 0) - entry["success"] = getattr(st, "success", 0) - entry["failure"] = getattr(st, "failure", 0) - summary_items.append(entry) - return { - "count": len(items), - "loaded": sum(1 for i in items.values() if getattr(i, "loaded", False)), - "items": summary_items, - } - - -def _validate_callback_url(url: str) -> str | None: - """校验回调 URL,返回错误信息或 None 表示通过。 - - 拒绝非 HTTP(S) scheme,以及直接使用私有/回环 IP 字面量的 URL 以防止 SSRF。 - 域名形式的 URL 放行(DNS 解析阶段不适合在校验函数中做阻塞调用)。 - """ - import ipaddress - - parsed = urlsplit(url) - scheme = (parsed.scheme or "").lower() - - if scheme not in ("http", "https"): - return "callback.url must use http or https" - - hostname = parsed.hostname or "" - if not hostname: - return "callback.url must include a hostname" - - # 仅检查 IP 字面量(如 http://127.0.0.1/、http://[::1]/、http://10.0.0.1/) - try: - addr = ipaddress.ip_address(hostname) - except ValueError: - pass # 域名形式,放行 - else: - if addr.is_private or addr.is_loopback or addr.is_link_local: - return "callback.url must not point to a private/loopback address" - - return None - - -async def _probe_http_endpoint( - *, - name: str, - base_url: str, - api_key: str, - model_name: str = "", - timeout_seconds: float = 5.0, -) -> dict[str, Any]: - normalized = str(base_url or "").strip().rstrip("/") - if not normalized: - return { - "name": name, - "status": "skipped", - "reason": "empty_url", - "model_name": model_name, - } - - headers: dict[str, str] = {} - if api_key: - headers["Authorization"] = f"Bearer {api_key}" - - candidates = [normalized, f"{normalized}/models"] - last_error = "" - for url in candidates: - start = time.perf_counter() - try: - timeout = ClientTimeout(total=timeout_seconds) - async with ClientSession(timeout=timeout) as session: - async with session.get(url, headers=headers) as resp: - elapsed_ms = round((time.perf_counter() - start) * 1000, 2) - return { - "name": name, - "status": "ok", - "url": _mask_url(url), - "http_status": resp.status, - "latency_ms": elapsed_ms, - "model_name": model_name, - } - except Exception as exc: - last_error = str(exc) - continue - - return { - "name": name, - "status": "error", - "url": _mask_url(normalized), - "error": last_error or "request_failed", - "model_name": model_name, - } - - -async def _skipped_probe( - *, name: str, reason: str, model_name: str = "" -) -> dict[str, Any]: - payload: dict[str, Any] = {"name": name, "status": "skipped", "reason": reason} - if model_name: - payload["model_name"] = model_name - return payload - - -def _build_internal_model_probe_payload(mcfg: Any) -> dict[str, Any]: - payload = { - "model_name": getattr(mcfg, "model_name", ""), - "api_url": _mask_url(getattr(mcfg, "api_url", "")), - } - if hasattr(mcfg, "api_mode"): - payload["api_mode"] = getattr(mcfg, "api_mode", "chat_completions") - if hasattr(mcfg, "thinking_enabled"): - payload["thinking_enabled"] = getattr(mcfg, "thinking_enabled", False) - if hasattr(mcfg, "thinking_tool_call_compat"): - payload["thinking_tool_call_compat"] = getattr( - mcfg, "thinking_tool_call_compat", True - ) - if hasattr(mcfg, "responses_tool_choice_compat"): - payload["responses_tool_choice_compat"] = getattr( - mcfg, "responses_tool_choice_compat", False - ) - if hasattr(mcfg, "responses_force_stateless_replay"): - payload["responses_force_stateless_replay"] = getattr( - mcfg, "responses_force_stateless_replay", False - ) - if hasattr(mcfg, "prompt_cache_enabled"): - payload["prompt_cache_enabled"] = getattr(mcfg, "prompt_cache_enabled", True) - if hasattr(mcfg, "reasoning_enabled"): - payload["reasoning_enabled"] = getattr(mcfg, "reasoning_enabled", False) - if hasattr(mcfg, "reasoning_effort"): - payload["reasoning_effort"] = getattr(mcfg, "reasoning_effort", "medium") - return payload - - -async def _probe_ws_endpoint(url: str, timeout_seconds: float = 5.0) -> dict[str, Any]: - normalized = str(url or "").strip() - if not normalized: - return {"name": "onebot_ws", "status": "skipped", "reason": "empty_url"} - - parsed = urlsplit(normalized) - host = parsed.hostname - if not host: - return {"name": "onebot_ws", "status": "error", "error": "invalid_url"} - - if parsed.port is not None: - port = parsed.port - elif parsed.scheme == "wss": - port = 443 - else: - port = 80 - - start = time.perf_counter() - try: - conn = asyncio.open_connection(host, port) - reader, writer = await asyncio.wait_for(conn, timeout=timeout_seconds) - writer.close() - await writer.wait_closed() - elapsed_ms = round((time.perf_counter() - start) * 1000, 2) - return { - "name": "onebot_ws", - "status": "ok", - "host": host, - "port": port, - "latency_ms": elapsed_ms, - } - except (OSError, TimeoutError, socket.gaierror, asyncio.TimeoutError) as exc: - return { - "name": "onebot_ws", - "status": "error", - "host": host, - "port": port, - "error": str(exc), - } - - -def _build_openapi_spec(ctx: RuntimeAPIContext, request: web.Request) -> dict[str, Any]: - server_url = f"{request.scheme}://{request.host}" - cfg = ctx.config_getter() - naga_routes_enabled = _naga_routes_enabled(cfg, ctx.naga_store) - paths: dict[str, Any] = { - "/health": { - "get": { - "summary": "Health check", - "security": [], - "responses": {"200": {"description": "OK"}}, - } - }, - "/openapi.json": { - "get": { - "summary": "OpenAPI schema", - "security": [], - "responses": {"200": {"description": "Schema JSON"}}, - } - }, - "/api/v1/probes/internal": { - "get": { - "summary": "Internal runtime probes", - "description": ( - "Returns system info (version, Python, platform, uptime), " - "OneBot connection status, request queue snapshot, " - "memory count, cognitive service status, API config, " - "skill statistics (tools/agents/anthropic_skills with call counts), " - "and model configuration (names, masked URLs, thinking flags)." - ), - } - }, - "/api/v1/probes/external": { - "get": { - "summary": "External dependency probes", - "description": ( - "Concurrently probes all configured model API endpoints " - "(chat, vision, security, naga, agent, embedding, rerank) " - "and OneBot WebSocket. Each result includes status, " - "model name, masked URL, HTTP status code, and latency." - ), - } - }, - "/api/v1/memory": {"get": {"summary": "List/search manual memories"}}, - "/api/v1/memes": {"get": {"summary": "List/search meme library items"}}, - "/api/v1/memes/stats": {"get": {"summary": "Get meme library stats"}}, - "/api/v1/memes/{uid}": { - "get": {"summary": "Get a meme by uid"}, - "patch": {"summary": "Update a meme by uid"}, - "delete": {"summary": "Delete a meme by uid"}, - }, - "/api/v1/memes/{uid}/blob": {"get": {"summary": "Get meme blob file"}}, - "/api/v1/memes/{uid}/preview": {"get": {"summary": "Get meme preview file"}}, - "/api/v1/memes/{uid}/reanalyze": { - "post": {"summary": "Queue a meme reanalyze job"} - }, - "/api/v1/memes/{uid}/reindex": { - "post": {"summary": "Queue a meme reindex job"} - }, - "/api/v1/cognitive/events": { - "get": {"summary": "Search cognitive event memories"} - }, - "/api/v1/cognitive/profiles": {"get": {"summary": "Search cognitive profiles"}}, - "/api/v1/cognitive/profile/{entity_type}/{entity_id}": { - "get": {"summary": "Get a profile by entity type/id"} - }, - "/api/v1/chat": { - "post": { - "summary": "WebUI special private chat", - "description": ( - "POST JSON {message, stream?}. " - "When stream=true, response is SSE with keep-alive comments." - ), - } - }, - "/api/v1/chat/history": { - "get": {"summary": "Get virtual private chat history for WebUI"} - }, - "/api/v1/tools": { - "get": { - "summary": "List available tools", - "description": ( - "Returns currently available tools filtered by " - "tool_invoke_expose / allowlist / denylist config. " - "Each item follows the OpenAI function calling schema." - ), - } - }, - "/api/v1/tools/invoke": { - "post": { - "summary": "Invoke a tool", - "description": ( - "Execute a specific tool by name. Supports synchronous " - "response and optional async webhook callback." - ), - } - }, - } - - if naga_routes_enabled: - paths.update( - { - "/api/v1/naga/bind/callback": { - "post": { - "summary": "Finalize a Naga bind request", - "description": ( - "Internal callback used by Naga to approve or reject " - "a bind_uuid." - ), - "security": [{"BearerAuth": []}], - } - }, - "/api/v1/naga/messages/send": { - "post": { - "summary": "Send a Naga-authenticated message", - "description": ( - "Validates bind_uuid + delivery_signature, runs " - "moderation, then delivers the message. " - "Caller may provide uuid for idempotent retry deduplication." - ), - "security": [{"BearerAuth": []}], - } - }, - "/api/v1/naga/unbind": { - "post": { - "summary": "Revoke an active Naga binding", - "description": ( - "Allows Naga to proactively revoke a binding using " - "Authorization: Bearer ." - ), - "security": [{"BearerAuth": []}], - } - }, - } - ) - - return { - "openapi": "3.0.3", - "info": { - "title": "Undefined Runtime API", - "version": __version__, - "description": "API exposed by the main Undefined process.", - }, - "servers": [ - { - "url": server_url, - "description": "Runtime endpoint", - } - ], - "components": { - "securitySchemes": { - "ApiKeyAuth": { - "type": "apiKey", - "in": "header", - "name": _AUTH_HEADER, - }, - "BearerAuth": {"type": "http", "scheme": "bearer"}, - } - }, - "security": [{"ApiKeyAuth": []}], - "paths": paths, - } - class RuntimeAPIServer: def __init__( @@ -678,13 +42,7 @@ def __init__( self._runner: web.AppRunner | None = None self._sites: list[web.TCPSite] = [] self._background_tasks: set[asyncio.Task[Any]] = set() - self._naga_send_registry_lock = asyncio.Lock() - self._naga_send_inflight: dict[str, int] = {} - self._naga_request_uuid_lock = asyncio.Lock() - self._naga_request_uuid_inflight: dict[ - str, tuple[str, asyncio.Future[tuple[int, dict[str, Any]]]] - ] = {} - self._naga_request_uuid_results: dict[str, _NagaRequestResult] = {} + self._naga_state = NagaState() async def start(self) -> None: from Undefined.config.models import resolve_bind_hosts @@ -700,7 +58,6 @@ async def start(self) -> None: logger.info("[RuntimeAPI] 已启动: %s", cfg.api.display_url) async def stop(self) -> None: - # 取消所有后台任务(如异步 tool invoke 回调) for task in self._background_tasks: task.cancel() if self._background_tasks: @@ -725,7 +82,6 @@ async def _auth_middleware( _apply_cors_headers(request, response) return response if request.path.startswith("/api/"): - # Naga 端点使用独立鉴权,仅在总开关+子开关均启用时跳过主 auth cfg = self._context.config_getter() is_naga_path = request.path.startswith("/api/v1/naga/") skip_auth = is_naga_path and _naga_runtime_enabled(cfg) @@ -749,6 +105,9 @@ async def _auth_middleware( web.get("/api/v1/probes/internal", self._internal_probe_handler), web.get("/api/v1/probes/external", self._external_probe_handler), web.get("/api/v1/memory", self._memory_handler), + web.post("/api/v1/memory", self._memory_create_handler), + web.patch("/api/v1/memory/{uuid}", self._memory_update_handler), + web.delete("/api/v1/memory/{uuid}", self._memory_delete_handler), web.get("/api/v1/memes", self._meme_list_handler), web.get("/api/v1/memes/stats", self._meme_stats_handler), web.get("/api/v1/memes/{uid}", self._meme_detail_handler), @@ -779,7 +138,6 @@ async def _auth_middleware( web.post("/api/v1/tools/invoke", self._tools_invoke_handler), ] ) - # Naga 端点仅在总开关+子开关均启用时注册 cfg = self._context.config_getter() if _naga_routes_enabled(cfg, self._context.naga_store): app.add_routes( @@ -811,1152 +169,107 @@ async def _auth_middleware( ) return app - async def _track_naga_send_start(self, message_key: str) -> int: - async with self._naga_send_registry_lock: - next_count = self._naga_send_inflight.get(message_key, 0) + 1 - self._naga_send_inflight[message_key] = next_count - return next_count - - async def _track_naga_send_done(self, message_key: str) -> int: - async with self._naga_send_registry_lock: - current = self._naga_send_inflight.get(message_key, 0) - if current <= 1: - self._naga_send_inflight.pop(message_key, None) - return 0 - next_count = current - 1 - self._naga_send_inflight[message_key] = next_count - return next_count - - def _prune_naga_request_uuid_state_locked(self) -> None: - now = time.time() - expired = [ - request_uuid - for request_uuid, result in self._naga_request_uuid_results.items() - if now - result.finished_at > _NAGA_REQUEST_UUID_TTL_SECONDS - ] - for request_uuid in expired: - self._naga_request_uuid_results.pop(request_uuid, None) - - async def _register_naga_request_uuid( - self, request_uuid: str, payload_hash: str - ) -> tuple[str, Any]: - async with self._naga_request_uuid_lock: - self._prune_naga_request_uuid_state_locked() - - cached = self._naga_request_uuid_results.get(request_uuid) - if cached is not None: - if cached.payload_hash != payload_hash: - return "conflict", cached.payload_hash - return "cached", (cached.status, deepcopy(cached.payload)) - - inflight = self._naga_request_uuid_inflight.get(request_uuid) - if inflight is not None: - existing_hash, inflight_future = inflight - if existing_hash != payload_hash: - return "conflict", existing_hash - return "await", inflight_future - - owner_future: asyncio.Future[tuple[int, dict[str, Any]]] = ( - asyncio.get_running_loop().create_future() - ) - self._naga_request_uuid_inflight[request_uuid] = ( - payload_hash, - owner_future, - ) - return "owner", owner_future - - async def _finish_naga_request_uuid( - self, - request_uuid: str, - payload_hash: str, - *, - status: int, - payload: dict[str, Any], - ) -> None: - async with self._naga_request_uuid_lock: - inflight = self._naga_request_uuid_inflight.pop(request_uuid, None) - future = inflight[1] if inflight is not None else None - result_payload = deepcopy(payload) - self._naga_request_uuid_results[request_uuid] = _NagaRequestResult( - payload_hash=payload_hash, - status=status, - payload=result_payload, - finished_at=time.time(), - ) - self._prune_naga_request_uuid_state_locked() - if future is not None and not future.done(): - future.set_result((status, deepcopy(result_payload))) - - async def _fail_naga_request_uuid( - self, - request_uuid: str, - payload_hash: str, - exc: BaseException, - ) -> None: - async with self._naga_request_uuid_lock: - inflight = self._naga_request_uuid_inflight.pop(request_uuid, None) - future = inflight[1] if inflight is not None else None - self._prune_naga_request_uuid_state_locked() - if future is not None and not future.done(): - future.set_exception(exc) - @property def _ctx(self) -> RuntimeAPIContext: return self._context + # ------------------------------------------------------------------ + # Thin delegation wrappers — keep tests calling server._xxx(request) + # ------------------------------------------------------------------ + + # Health / System async def _health_handler(self, request: web.Request) -> Response: - _ = request - return web.json_response( - { - "ok": True, - "service": "undefined-runtime-api", - "version": __version__, - "timestamp": datetime.now().isoformat(), - } - ) + return await health.health_handler(self._ctx, request) async def _openapi_handler(self, request: web.Request) -> Response: - cfg = self._ctx.config_getter() - if not bool(getattr(cfg.api, "openapi_enabled", True)): - logger.info( - "[RuntimeAPI] OpenAPI 请求被拒绝: disabled remote=%s", request.remote - ) - return _json_error("OpenAPI disabled", status=404) - naga_routes_enabled = _naga_routes_enabled(cfg, self._ctx.naga_store) - logger.info( - "[RuntimeAPI] OpenAPI 请求: remote=%s naga_routes_enabled=%s", - request.remote, - naga_routes_enabled, - ) - return web.json_response(_build_openapi_spec(self._ctx, request)) + return await system.openapi_handler(self._ctx, request) async def _internal_probe_handler(self, request: web.Request) -> Response: - _ = request - cfg = self._ctx.config_getter() - queue_snapshot = ( - self._ctx.queue_manager.snapshot() if self._ctx.queue_manager else {} - ) - cognitive_queue_snapshot = ( - self._ctx.cognitive_job_queue.snapshot() - if self._ctx.cognitive_job_queue - else {} - ) - memory_storage = getattr(self._ctx.ai, "memory_storage", None) - memory_count = memory_storage.count() if memory_storage is not None else 0 - - # Skills 统计 - ai = self._ctx.ai - skills_info: dict[str, Any] = {} - if ai is not None: - tool_reg = getattr(ai, "tool_registry", None) - agent_reg = getattr(ai, "agent_registry", None) - anthropic_reg = getattr(ai, "anthropic_skill_registry", None) - skills_info["tools"] = _registry_summary(tool_reg) - skills_info["agents"] = _registry_summary(agent_reg) - skills_info["anthropic_skills"] = _registry_summary(anthropic_reg) - - # 模型配置(脱敏) - models_info: dict[str, Any] = {} - for label in ( - "chat_model", - "vision_model", - "agent_model", - "security_model", - "naga_model", - "grok_model", - ): - mcfg = getattr(cfg, label, None) - if mcfg is not None: - models_info[label] = _build_internal_model_probe_payload(mcfg) - for label in ("embedding_model", "rerank_model"): - mcfg = getattr(cfg, label, None) - if mcfg is not None: - models_info[label] = { - "model_name": getattr(mcfg, "model_name", ""), - "api_url": _mask_url(getattr(mcfg, "api_url", "")), - } - - uptime_seconds = round(time.time() - _PROCESS_START_TIME, 1) - payload = { - "timestamp": datetime.now().isoformat(), - "version": __version__, - "python": f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}", - "platform": platform.system(), - "uptime_seconds": uptime_seconds, - "onebot": self._ctx.onebot.connection_status() - if self._ctx.onebot is not None - else {}, - "queues": queue_snapshot, - "memory": {"count": memory_count, "virtual_user_id": _VIRTUAL_USER_ID}, - "cognitive": { - "enabled": bool( - self._ctx.cognitive_service and self._ctx.cognitive_service.enabled - ), - "queue": cognitive_queue_snapshot, - }, - "api": { - "enabled": bool(cfg.api.enabled), - "host": cfg.api.host, - "port": cfg.api.port, - "openapi_enabled": bool(cfg.api.openapi_enabled), - }, - "skills": skills_info, - "models": models_info, - } - return web.json_response(payload) + return await system.internal_probe_handler(self._ctx, request) async def _external_probe_handler(self, request: web.Request) -> Response: - _ = request - cfg = self._ctx.config_getter() - naga_probe = ( - _probe_http_endpoint( - name="naga_model", - base_url=cfg.naga_model.api_url, - api_key=cfg.naga_model.api_key, - model_name=cfg.naga_model.model_name, - ) - if bool(cfg.api.enabled and cfg.nagaagent_mode_enabled and cfg.naga.enabled) - else _skipped_probe( - name="naga_model", - reason="naga_integration_disabled", - model_name=cfg.naga_model.model_name, - ) - ) - checks = [ - _probe_http_endpoint( - name="chat_model", - base_url=cfg.chat_model.api_url, - api_key=cfg.chat_model.api_key, - model_name=cfg.chat_model.model_name, - ), - _probe_http_endpoint( - name="vision_model", - base_url=cfg.vision_model.api_url, - api_key=cfg.vision_model.api_key, - model_name=cfg.vision_model.model_name, - ), - _probe_http_endpoint( - name="security_model", - base_url=cfg.security_model.api_url, - api_key=cfg.security_model.api_key, - model_name=cfg.security_model.model_name, - ), - naga_probe, - _probe_http_endpoint( - name="agent_model", - base_url=cfg.agent_model.api_url, - api_key=cfg.agent_model.api_key, - model_name=cfg.agent_model.model_name, - ), - ] - grok_model = getattr(cfg, "grok_model", None) - if grok_model is not None: - checks.append( - _probe_http_endpoint( - name="grok_model", - base_url=getattr(grok_model, "api_url", ""), - api_key=getattr(grok_model, "api_key", ""), - model_name=getattr(grok_model, "model_name", ""), - ) - ) - checks.extend( - [ - _probe_http_endpoint( - name="embedding_model", - base_url=cfg.embedding_model.api_url, - api_key=cfg.embedding_model.api_key, - model_name=getattr(cfg.embedding_model, "model_name", ""), - ), - _probe_http_endpoint( - name="rerank_model", - base_url=cfg.rerank_model.api_url, - api_key=cfg.rerank_model.api_key, - model_name=getattr(cfg.rerank_model, "model_name", ""), - ), - _probe_ws_endpoint(cfg.onebot_ws_url), - ] - ) - results = await asyncio.gather(*checks) - ok = all(item.get("status") in {"ok", "skipped"} for item in results) - return web.json_response( - { - "ok": ok, - "timestamp": datetime.now().isoformat(), - "results": results, - } - ) + return await system.external_probe_handler(self._ctx, request) + # Memory CRUD async def _memory_handler(self, request: web.Request) -> Response: - query = str(request.query.get("q", "") or "").strip().lower() - top_k_raw = _optional_query_param(request, "top_k") - time_from_raw = _optional_query_param(request, "time_from") - time_to_raw = _optional_query_param(request, "time_to") - memory_storage = getattr(self._ctx.ai, "memory_storage", None) - if memory_storage is None: - return _json_error("Memory storage not ready", status=503) - - limit: int | None = None - if top_k_raw is not None: - try: - limit = int(top_k_raw) - except ValueError: - return _json_error("top_k must be an integer", status=400) - if limit <= 0: - return _json_error("top_k must be > 0", status=400) - - time_from_dt = _parse_query_time(time_from_raw) - if time_from_raw is not None and time_from_dt is None: - return _json_error("time_from must be ISO datetime", status=400) - time_to_dt = _parse_query_time(time_to_raw) - if time_to_raw is not None and time_to_dt is None: - return _json_error("time_to must be ISO datetime", status=400) - if time_from_dt and time_to_dt and time_from_dt > time_to_dt: - time_from_dt, time_to_dt = time_to_dt, time_from_dt - - records = memory_storage.get_all() - items: list[dict[str, Any]] = [] - for item in records: - created_at = str(item.created_at or "").strip() - created_dt = _parse_query_time(created_at) - if time_from_dt and created_dt and created_dt < time_from_dt: - continue - if time_to_dt and created_dt and created_dt > time_to_dt: - continue - if (time_from_dt or time_to_dt) and created_dt is None: - continue - items.append( - { - "uuid": item.uuid, - "fact": item.fact, - "created_at": created_at, - } - ) - if query: - items = [ - item - for item in items - if query in str(item.get("fact", "")).lower() - or query in str(item.get("uuid", "")).lower() - ] + return await memory.memory_list_handler(self._ctx, request) - def _created_sort_key(item: dict[str, Any]) -> float: - created_dt = _parse_query_time(str(item.get("created_at") or "")) - if created_dt is None: - return float("-inf") - with suppress(OSError, OverflowError, ValueError): - return float(created_dt.timestamp()) - return float("-inf") + async def _memory_create_handler(self, request: web.Request) -> Response: + return await memory.memory_create_handler(self._ctx, request) - items.sort(key=_created_sort_key) - if limit is not None: - items = items[:limit] + async def _memory_update_handler(self, request: web.Request) -> Response: + return await memory.memory_update_handler(self._ctx, request) - return web.json_response( - { - "total": len(items), - "items": items, - "query": { - "q": query or "", - "top_k": limit, - "time_from": time_from_raw, - "time_to": time_to_raw, - }, - } - ) + async def _memory_delete_handler(self, request: web.Request) -> Response: + return await memory.memory_delete_handler(self._ctx, request) + # Memes async def _meme_list_handler(self, request: web.Request) -> Response: - meme_service = self._ctx.meme_service - if meme_service is None or not meme_service.enabled: - return _json_error("Meme service disabled", status=400) - - def _parse_optional_bool(name: str) -> bool | None: - raw = request.query.get(name) - if raw is None or str(raw).strip() == "": - return None - return _to_bool(raw) - - page_raw = _optional_query_param(request, "page") - page_size_raw = _optional_query_param(request, "page_size") - top_k_raw = _optional_query_param(request, "top_k") - query = str(request.query.get("q", "") or "").strip() - query_mode = str(request.query.get("query_mode", "") or "").strip().lower() - keyword_query = str(request.query.get("keyword_query", "") or "").strip() - semantic_query = str(request.query.get("semantic_query", "") or "").strip() - try: - page = int(page_raw) if page_raw is not None else 1 - page_size = int(page_size_raw) if page_size_raw is not None else 50 - top_k = int(top_k_raw) if top_k_raw is not None else page_size - except ValueError: - return _json_error("page/page_size/top_k must be integers", status=400) - page = max(1, page) - page_size = max(1, min(200, page_size)) - top_k = max(1, top_k) - sort = str(request.query.get("sort", "updated_at") or "updated_at").strip() - - enabled_filter = _parse_optional_bool("enabled") - animated_filter = _parse_optional_bool("animated") - pinned_filter = _parse_optional_bool("pinned") - if not (query or keyword_query or semantic_query) and sort == "relevance": - sort = "updated_at" - - if query or keyword_query or semantic_query: - has_post_filter = any( - f is not None for f in (enabled_filter, animated_filter, pinned_filter) - ) - requested_window = max(page * page_size, top_k) - if has_post_filter or page > 1 or sort != "relevance": - fetch_k = min(500, max(requested_window * 4, top_k)) - else: - fetch_k = min(500, requested_window) - search_payload = await meme_service.search_memes( - query, - query_mode=query_mode or meme_service.default_query_mode, - keyword_query=keyword_query or None, - semantic_query=semantic_query or None, - top_k=fetch_k, - include_disabled=enabled_filter is not True, - sort=sort, - ) - filtered_items: list[dict[str, Any]] = [] - for item in list(search_payload.get("items") or []): - if ( - enabled_filter is not None - and bool(item.get("enabled")) != enabled_filter - ): - continue - if ( - animated_filter is not None - and bool(item.get("is_animated")) != animated_filter - ): - continue - if ( - pinned_filter is not None - and bool(item.get("pinned")) != pinned_filter - ): - continue - filtered_items.append(item) - offset = (page - 1) * page_size - paged_items = filtered_items[offset : offset + page_size] - window_total = len(filtered_items) - fetched_window_count = len(list(search_payload.get("items") or [])) - window_exhausted = fetched_window_count < fetch_k - has_more = bool(paged_items) and ( - offset + page_size < window_total - or (not window_exhausted and window_total >= offset + page_size) - ) - return web.json_response( - { - "ok": True, - "total": None, - "window_total": window_total, - "total_exact": False, - "page": page, - "page_size": page_size, - "has_more": has_more, - "query_mode": search_payload.get("query_mode"), - "keyword_query": search_payload.get("keyword_query"), - "semantic_query": search_payload.get("semantic_query"), - "sort": search_payload.get("sort", sort), - "items": paged_items, - } - ) - - payload = await meme_service.list_memes( - query=query, - enabled=enabled_filter, - animated=animated_filter, - pinned=pinned_filter, - sort=sort, - page=page, - page_size=page_size, - summary=True, - ) - return web.json_response(payload) + return await memes.meme_list_handler(self._ctx, request) async def _meme_stats_handler(self, request: web.Request) -> Response: - _ = request - meme_service = self._ctx.meme_service - if meme_service is None or not meme_service.enabled: - return _json_error("Meme service disabled", status=400) - return web.json_response(await meme_service.stats()) + return await memes.meme_stats_handler(self._ctx, request) async def _meme_detail_handler(self, request: web.Request) -> Response: - meme_service = self._ctx.meme_service - if meme_service is None or not meme_service.enabled: - return _json_error("Meme service disabled", status=400) - uid = str(request.match_info.get("uid", "")).strip() - detail = await meme_service.get_meme(uid) - if detail is None: - return _json_error("Meme not found", status=404) - return web.json_response(detail) + return await memes.meme_detail_handler(self._ctx, request) async def _meme_blob_handler(self, request: web.Request) -> Response: - meme_service = self._ctx.meme_service - if meme_service is None or not meme_service.enabled: - return _json_error("Meme service disabled", status=400) - uid = str(request.match_info.get("uid", "")).strip() - path = await meme_service.blob_path_for_uid(uid, preview=False) - if path is None: - return _json_error("Meme blob not found", status=404) - return cast(Response, web.FileResponse(path=path)) + return await memes.meme_blob_handler(self._ctx, request) async def _meme_preview_handler(self, request: web.Request) -> Response: - meme_service = self._ctx.meme_service - if meme_service is None or not meme_service.enabled: - return _json_error("Meme service disabled", status=400) - uid = str(request.match_info.get("uid", "")).strip() - path = await meme_service.blob_path_for_uid(uid, preview=True) - if path is None: - return _json_error("Meme preview not found", status=404) - return cast(Response, web.FileResponse(path=path)) + return await memes.meme_preview_handler(self._ctx, request) async def _meme_update_handler(self, request: web.Request) -> Response: - meme_service = self._ctx.meme_service - if meme_service is None or not meme_service.enabled: - return _json_error("Meme service disabled", status=400) - uid = str(request.match_info.get("uid", "")).strip() - try: - payload = await request.json() - except Exception: - return _json_error("Invalid JSON body", status=400) - if not isinstance(payload, dict): - return _json_error("JSON body must be an object", status=400) - updated = await meme_service.update_meme( - uid, - manual_description=payload.get("manual_description"), - tags=payload.get("tags"), - aliases=payload.get("aliases"), - enabled=payload.get("enabled") if "enabled" in payload else None, - pinned=payload.get("pinned") if "pinned" in payload else None, - ) - if updated is None: - return _json_error("Meme not found", status=404) - return web.json_response({"ok": True, "record": updated}) + return await memes.meme_update_handler(self._ctx, request) async def _meme_delete_handler(self, request: web.Request) -> Response: - meme_service = self._ctx.meme_service - if meme_service is None or not meme_service.enabled: - return _json_error("Meme service disabled", status=400) - uid = str(request.match_info.get("uid", "")).strip() - deleted = await meme_service.delete_meme(uid) - if not deleted: - return _json_error("Meme not found", status=404) - return web.json_response({"ok": True, "uid": uid}) + return await memes.meme_delete_handler(self._ctx, request) async def _meme_reanalyze_handler(self, request: web.Request) -> Response: - meme_service = self._ctx.meme_service - if meme_service is None or not meme_service.enabled: - return _json_error("Meme service disabled", status=400) - uid = str(request.match_info.get("uid", "")).strip() - job_id = await meme_service.enqueue_reanalyze(uid) - if not job_id: - return _json_error("Meme queue unavailable", status=503) - return web.json_response({"ok": True, "uid": uid, "job_id": job_id}) + return await memes.meme_reanalyze_handler(self._ctx, request) async def _meme_reindex_handler(self, request: web.Request) -> Response: - meme_service = self._ctx.meme_service - if meme_service is None or not meme_service.enabled: - return _json_error("Meme service disabled", status=400) - uid = str(request.match_info.get("uid", "")).strip() - job_id = await meme_service.enqueue_reindex(uid) - if not job_id: - return _json_error("Meme queue unavailable", status=503) - return web.json_response({"ok": True, "uid": uid, "job_id": job_id}) + return await memes.meme_reindex_handler(self._ctx, request) + # Cognitive async def _cognitive_events_handler(self, request: web.Request) -> Response: - cognitive_service = self._ctx.cognitive_service - if not cognitive_service or not cognitive_service.enabled: - return _json_error("Cognitive service disabled", status=400) - - query = str(request.query.get("q", "") or "").strip() - if not query: - return _json_error("q is required", status=400) - - search_kwargs: dict[str, Any] = {"query": query} - for key in ( - "target_user_id", - "target_group_id", - "sender_id", - "request_type", - "top_k", - "time_from", - "time_to", - ): - value = _optional_query_param(request, key) - if value is not None: - search_kwargs[key] = value - - results = await cognitive_service.search_events(**search_kwargs) - return web.json_response({"count": len(results), "items": results}) + return await cognitive.cognitive_events_handler(self._ctx, request) async def _cognitive_profiles_handler(self, request: web.Request) -> Response: - cognitive_service = self._ctx.cognitive_service - if not cognitive_service or not cognitive_service.enabled: - return _json_error("Cognitive service disabled", status=400) - - query = str(request.query.get("q", "") or "").strip() - if not query: - return _json_error("q is required", status=400) - - search_kwargs: dict[str, Any] = {"query": query} - entity_type = _optional_query_param(request, "entity_type") - if entity_type is not None: - search_kwargs["entity_type"] = entity_type - top_k = _optional_query_param(request, "top_k") - if top_k is not None: - search_kwargs["top_k"] = top_k - - results = await cognitive_service.search_profiles(**search_kwargs) - return web.json_response({"count": len(results), "items": results}) + return await cognitive.cognitive_profiles_handler(self._ctx, request) async def _cognitive_profile_handler(self, request: web.Request) -> Response: - cognitive_service = self._ctx.cognitive_service - if not cognitive_service or not cognitive_service.enabled: - return _json_error("Cognitive service disabled", status=400) - - entity_type = str(request.match_info.get("entity_type", "")).strip() - entity_id = str(request.match_info.get("entity_id", "")).strip() - if not entity_type or not entity_id: - return _json_error("entity_type/entity_id are required", status=400) - - profile = await cognitive_service.get_profile(entity_type, entity_id) - return web.json_response( - { - "entity_type": entity_type, - "entity_id": entity_id, - "profile": profile or "", - "found": bool(profile), - } - ) + return await cognitive.cognitive_profile_handler(self._ctx, request) + # Chat async def _run_webui_chat( self, *, text: str, send_output: Callable[[int, str], Awaitable[None]], ) -> str: - cfg = self._ctx.config_getter() - permission_sender_id = int(cfg.superadmin_qq) - webui_scope_key = build_attachment_scope( - user_id=_VIRTUAL_USER_ID, - request_type="private", - webui_session=True, - ) - input_segments = message_to_segments(text) - registered_input = await register_message_attachments( - registry=self._ctx.ai.attachment_registry, - segments=input_segments, - scope_key=webui_scope_key, - resolve_image_url=self._ctx.onebot.get_image, - get_forward_messages=self._ctx.onebot.get_forward_msg, - ) - normalized_text = registered_input.normalized_text or text - await self._ctx.history_manager.add_private_message( - user_id=_VIRTUAL_USER_ID, - text_content=normalized_text, - display_name=_VIRTUAL_USER_NAME, - user_name=_VIRTUAL_USER_NAME, - attachments=registered_input.attachments, - ) - - command = self._ctx.command_dispatcher.parse_command(normalized_text) - if command: - await self._ctx.command_dispatcher.dispatch_private( - user_id=_VIRTUAL_USER_ID, - sender_id=permission_sender_id, - command=command, - send_private_callback=send_output, - is_webui_session=True, - ) - return "command" - - current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S") - attachment_xml = ( - f"\n{attachment_refs_to_xml(registered_input.attachments)}" - if registered_input.attachments - else "" - ) - full_question = f""" - {escape_xml_text(normalized_text)}{attachment_xml} - - -【WebUI 会话】 -这是一条来自 WebUI 控制台的会话请求。 -会话身份:虚拟用户 system(42)。 -权限等级:superadmin(你可按最高管理权限处理)。 -请正常进行私聊对话;如果需要结束会话,调用 end 工具。""" - virtual_sender = _WebUIVirtualSender( - _VIRTUAL_USER_ID, send_output, onebot=self._ctx.onebot - ) - - async def _get_recent_cb( - chat_id: str, msg_type: str, start: int, end: int - ) -> list[dict[str, Any]]: - return await get_recent_messages_prefer_local( - chat_id=chat_id, - msg_type=msg_type, - start=start, - end=end, - onebot_client=self._ctx.onebot, - history_manager=self._ctx.history_manager, - bot_qq=cfg.bot_qq, - attachment_registry=getattr(self._ctx.ai, "attachment_registry", None), - ) - - async with RequestContext( - request_type="private", - user_id=_VIRTUAL_USER_ID, - sender_id=permission_sender_id, - ) as ctx: - # 与 ai_coordinator 保持一致:通过 collect_context_resources 自动注入 - ai_client = self._ctx.ai - memory_storage = self._ctx.ai.memory_storage - runtime_config = self._ctx.ai.runtime_config - sender = virtual_sender - history_manager = self._ctx.history_manager - onebot_client = self._ctx.onebot - scheduler = self._ctx.scheduler - - def send_message_callback( - msg: str, reply_to: int | None = None - ) -> Awaitable[None]: - _ = reply_to - return send_output(_VIRTUAL_USER_ID, msg) - - get_recent_messages_callback = _get_recent_cb - get_image_url_callback = self._ctx.onebot.get_image - get_forward_msg_callback = self._ctx.onebot.get_forward_msg - resource_vars = dict(globals()) - resource_vars.update(locals()) - resources = collect_context_resources(resource_vars) - for key, value in resources.items(): - if value is not None: - ctx.set_resource(key, value) - ctx.set_resource("queue_lane", QUEUE_LANE_SUPERADMIN) - ctx.set_resource("webui_session", True) - ctx.set_resource("webui_permission", "superadmin") - - result = await self._ctx.ai.ask( - full_question, - send_message_callback=send_message_callback, - get_recent_messages_callback=get_recent_messages_callback, - get_image_url_callback=get_image_url_callback, - get_forward_msg_callback=get_forward_msg_callback, - sender=sender, - history_manager=history_manager, - onebot_client=onebot_client, - scheduler=scheduler, - extra_context={ - "is_private_chat": True, - "request_type": "private", - "user_id": _VIRTUAL_USER_ID, - "sender_name": _VIRTUAL_USER_NAME, - "webui_session": True, - "webui_permission": "superadmin", - }, - ) - - final_reply = str(result or "").strip() - if final_reply: - await send_output(_VIRTUAL_USER_ID, final_reply) - - return "chat" + return await chat.run_webui_chat(self._ctx, text=text, send_output=send_output) async def _chat_history_handler(self, request: web.Request) -> Response: - limit_raw = str(request.query.get("limit", "200") or "200").strip() - try: - limit = int(limit_raw) - except ValueError: - limit = 200 - limit = max(1, min(limit, 500)) - - getter = getattr(self._ctx.history_manager, "get_recent_private", None) - if not callable(getter): - return _json_error("History manager not ready", status=503) - - records = getter(_VIRTUAL_USER_ID, limit) - items: list[dict[str, Any]] = [] - for item in records: - if not isinstance(item, dict): - continue - content = str(item.get("message", "")).strip() - if not content: - continue - display_name = str(item.get("display_name", "")).strip().lower() - role = "bot" if display_name == "bot" else "user" - items.append( - { - "role": role, - "content": content, - "timestamp": str(item.get("timestamp", "") or "").strip(), - } - ) - - return web.json_response( - { - "virtual_user_id": _VIRTUAL_USER_ID, - "permission": "superadmin", - "count": len(items), - "items": items, - } - ) + return await chat.chat_history_handler(self._ctx, request) async def _chat_handler(self, request: web.Request) -> web.StreamResponse: - try: - body = await request.json() - except Exception: - return _json_error("Invalid JSON", status=400) - - text = str(body.get("message", "") or "").strip() - if not text: - return _json_error("message is required", status=400) - - stream = _to_bool(body.get("stream")) - outputs: list[str] = [] - webui_scope_key = build_attachment_scope( - user_id=_VIRTUAL_USER_ID, - request_type="private", - webui_session=True, - ) - - async def _capture_private_message(user_id: int, message: str) -> None: - _ = user_id - content = str(message or "").strip() - if not content: - return - rendered = await render_message_with_pic_placeholders( - content, - registry=self._ctx.ai.attachment_registry, - scope_key=webui_scope_key, - strict=False, - ) - if not rendered.delivery_text.strip(): - return - outputs.append(rendered.delivery_text) - await self._ctx.history_manager.add_private_message( - user_id=_VIRTUAL_USER_ID, - text_content=rendered.history_text, - display_name="Bot", - user_name="Bot", - attachments=rendered.attachments, - ) - - if not stream: - try: - mode = await self._run_webui_chat( - text=text, send_output=_capture_private_message - ) - except Exception as exc: - logger.exception("[RuntimeAPI] chat failed: %s", exc) - return _json_error("Chat failed", status=502) - return web.json_response(_build_chat_response_payload(mode, outputs)) - - response = web.StreamResponse( - status=200, - reason="OK", - headers={ - "Content-Type": "text/event-stream", - "Cache-Control": "no-cache", - "Connection": "keep-alive", - }, - ) - await response.prepare(request) - - message_queue: asyncio.Queue[str] = asyncio.Queue() - - async def _capture_private_message_stream(user_id: int, message: str) -> None: - output_count = len(outputs) - await _capture_private_message(user_id, message) - if len(outputs) <= output_count: - return - content = outputs[-1].strip() - if content: - await message_queue.put(content) - - task = asyncio.create_task( - self._run_webui_chat(text=text, send_output=_capture_private_message_stream) - ) - mode = "chat" - client_disconnected = False - try: - await response.write( - _sse_event( - "meta", - { - "virtual_user_id": _VIRTUAL_USER_ID, - "permission": "superadmin", - }, - ) - ) - - while True: - if request.transport is None or request.transport.is_closing(): - client_disconnected = True - break - if task.done() and message_queue.empty(): - break - try: - message = await asyncio.wait_for( - message_queue.get(), - timeout=_CHAT_SSE_KEEPALIVE_SECONDS, - ) - await response.write(_sse_event("message", {"content": message})) - except asyncio.TimeoutError: - await response.write(b": keep-alive\n\n") - - if client_disconnected: - task.cancel() - with suppress(asyncio.CancelledError): - await task - return response - - mode = await task - await response.write( - _sse_event("done", _build_chat_response_payload(mode, outputs)) - ) - except asyncio.CancelledError: - task.cancel() - with suppress(asyncio.CancelledError): - await task - raise - except (ConnectionResetError, RuntimeError): - task.cancel() - with suppress(asyncio.CancelledError): - await task - except Exception as exc: - logger.exception("[RuntimeAPI] chat stream failed: %s", exc) - if not task.done(): - task.cancel() - with suppress(asyncio.CancelledError): - await task - with suppress(Exception): - await response.write(_sse_event("error", {"error": str(exc)})) - finally: - with suppress(Exception): - await response.write_eof() - - return response - - # ------------------------------------------------------------------ - # Tool Invoke API - # ------------------------------------------------------------------ + return await chat.chat_handler(self._ctx, request) + # Tools def _get_filtered_tools(self) -> list[dict[str, Any]]: - """按配置过滤可用工具,返回 OpenAI function calling schema 列表。""" - cfg = self._ctx.config_getter() - api_cfg = cfg.api - ai = self._ctx.ai - if ai is None: - return [] - - tool_reg = getattr(ai, "tool_registry", None) - agent_reg = getattr(ai, "agent_registry", None) - - all_schemas: list[dict[str, Any]] = [] - if tool_reg is not None: - all_schemas.extend(tool_reg.get_tools_schema()) - - # 收集 agent schema 并缓存名称集合(避免重复调用) - agent_names: set[str] = set() - if agent_reg is not None: - agent_schemas = agent_reg.get_agents_schema() - all_schemas.extend(agent_schemas) - for schema in agent_schemas: - func = schema.get("function", {}) - name = str(func.get("name", "")) - if name: - agent_names.add(name) - - denylist: set[str] = set(api_cfg.tool_invoke_denylist) - allowlist: set[str] = set(api_cfg.tool_invoke_allowlist) - expose = api_cfg.tool_invoke_expose - - def _get_name(schema: dict[str, Any]) -> str: - func = schema.get("function", {}) - return str(func.get("name", "")) - - # 1. 先排除黑名单 - if denylist: - all_schemas = [s for s in all_schemas if _get_name(s) not in denylist] - - # 2. 白名单非空时仅保留匹配项 - if allowlist: - return [s for s in all_schemas if _get_name(s) in allowlist] - - # 3. 按 expose 过滤 - if expose == "all": - return all_schemas - - def _is_tool(name: str) -> bool: - return "." not in name and name not in agent_names - - def _is_toolset(name: str) -> bool: - return "." in name and not name.startswith("mcp.") - - filtered: list[dict[str, Any]] = [] - for schema in all_schemas: - name = _get_name(schema) - if not name: - continue - if expose == "tools" and _is_tool(name): - filtered.append(schema) - elif expose == "toolsets" and _is_toolset(name): - filtered.append(schema) - elif expose == "tools+toolsets" and (_is_tool(name) or _is_toolset(name)): - filtered.append(schema) - elif expose == "agents" and name in agent_names: - filtered.append(schema) - - return filtered + return tools.get_filtered_tools(self._ctx) def _get_agent_tool_names(self) -> set[str]: - ai = self._ctx.ai - if ai is None: - return set() - - agent_reg = getattr(ai, "agent_registry", None) - if agent_reg is None: - return set() - - agent_names: set[str] = set() - for schema in agent_reg.get_agents_schema(): - func = schema.get("function", {}) - name = str(func.get("name", "")) - if name: - agent_names.add(name) - return agent_names - - def _resolve_tool_invoke_timeout( - self, tool_name: str, timeout: int - ) -> float | None: - if tool_name in self._get_agent_tool_names(): - return None - return float(timeout) - - async def _await_tool_invoke_result( - self, - awaitable: Awaitable[Any], - *, - timeout: float | None, - ) -> Any: - if timeout is None or timeout <= 0: - return await awaitable - try: - return await asyncio.wait_for(awaitable, timeout=timeout) - except asyncio.TimeoutError as exc: - raise _ToolInvokeExecutionTimeoutError from exc + return tools.get_agent_tool_names(self._ctx) async def _tools_list_handler(self, request: web.Request) -> Response: - _ = request - cfg = self._ctx.config_getter() - if not cfg.api.tool_invoke_enabled: - return _json_error("Tool invoke API is disabled", status=403) - - tools = self._get_filtered_tools() - return web.json_response({"count": len(tools), "tools": tools}) + return await tools.tools_list_handler(self._ctx, request) async def _tools_invoke_handler(self, request: web.Request) -> Response: - cfg = self._ctx.config_getter() - if not cfg.api.tool_invoke_enabled: - return _json_error("Tool invoke API is disabled", status=403) - - try: - body = await request.json() - except Exception: - return _json_error("Invalid JSON", status=400) - - if not isinstance(body, dict): - return _json_error("Request body must be a JSON object", status=400) - - tool_name = str(body.get("tool_name", "") or "").strip() - if not tool_name: - return _json_error("tool_name is required", status=400) - - args = body.get("args") - if not isinstance(args, dict): - return _json_error("args must be a JSON object", status=400) - - # 验证工具是否在允许列表中 - filtered_tools = self._get_filtered_tools() - available_names: set[str] = set() - for schema in filtered_tools: - func = schema.get("function", {}) - name = str(func.get("name", "")) - if name: - available_names.add(name) - - if tool_name not in available_names: - caller_ip = request.remote or "unknown" - logger.warning( - "[ToolInvoke] 请求拒绝: tool=%s reason=not_available caller_ip=%s", - tool_name, - caller_ip, - ) - return _json_error(f"Tool '{tool_name}' is not available", status=404) - - # 解析回调配置 - callback_cfg = body.get("callback") - use_callback = False - callback_url = "" - callback_headers: dict[str, str] = {} - if isinstance(callback_cfg, dict) and _to_bool(callback_cfg.get("enabled")): - callback_url = str(callback_cfg.get("url", "") or "").strip() - if not callback_url: - return _json_error( - "callback.url is required when callback is enabled", - status=400, - ) - url_error = _validate_callback_url(callback_url) - if url_error: - return _json_error(url_error, status=400) - raw_headers = callback_cfg.get("headers") - if isinstance(raw_headers, dict): - callback_headers = {str(k): str(v) for k, v in raw_headers.items()} - use_callback = True - - request_id = _uuid.uuid4().hex - caller_ip = request.remote or "unknown" - logger.info( - "[ToolInvoke] 收到请求: request_id=%s tool=%s caller_ip=%s", - request_id, - tool_name, - caller_ip, - ) - - if use_callback: - # 异步执行 + 回调 - task = asyncio.create_task( - self._execute_and_callback( - request_id=request_id, - tool_name=tool_name, - args=args, - body_context=body.get("context"), - callback_url=callback_url, - callback_headers=callback_headers, - timeout=cfg.api.tool_invoke_timeout, - callback_timeout=cfg.api.tool_invoke_callback_timeout, - ) - ) - self._background_tasks.add(task) - task.add_done_callback(self._background_tasks.discard) - return web.json_response( - { - "ok": True, - "request_id": request_id, - "tool_name": tool_name, - "status": "accepted", - } - ) - - # 同步执行 - result = await self._execute_tool_invoke( - request_id=request_id, - tool_name=tool_name, - args=args, - body_context=body.get("context"), - timeout=cfg.api.tool_invoke_timeout, + return await tools.tools_invoke_handler( + self._ctx, self._background_tasks, request ) - return web.json_response(result) async def _execute_tool_invoke( self, @@ -1967,146 +280,8 @@ async def _execute_tool_invoke( body_context: Any, timeout: int, ) -> dict[str, Any]: - """执行工具调用并返回结果字典。""" - ai = self._ctx.ai - if ai is None: - return { - "ok": False, - "request_id": request_id, - "tool_name": tool_name, - "error": "AI client not ready", - "duration_ms": 0, - } - - # 解析请求上下文 - ctx_data: dict[str, Any] = {} - if isinstance(body_context, dict): - ctx_data = body_context - request_type = str(ctx_data.get("request_type", "api") or "api") - group_id = ctx_data.get("group_id") - user_id = ctx_data.get("user_id") - sender_id = ctx_data.get("sender_id") - - args_keys = list(args.keys()) - logger.info( - "[ToolInvoke] 开始执行: request_id=%s tool=%s args_keys=%s", - request_id, - tool_name, - args_keys, - ) - - start = time.perf_counter() - effective_timeout = self._resolve_tool_invoke_timeout(tool_name, timeout) - try: - async with RequestContext( - request_type=request_type, - group_id=int(group_id) if group_id is not None else None, - user_id=int(user_id) if user_id is not None else None, - sender_id=int(sender_id) if sender_id is not None else None, - ) as ctx: - # 注入核心服务资源 - if self._ctx.sender is not None: - ctx.set_resource("sender", self._ctx.sender) - if self._ctx.history_manager is not None: - ctx.set_resource("history_manager", self._ctx.history_manager) - runtime_config = getattr(ai, "runtime_config", None) - if runtime_config is not None: - ctx.set_resource("runtime_config", runtime_config) - memory_storage = getattr(ai, "memory_storage", None) - if memory_storage is not None: - ctx.set_resource("memory_storage", memory_storage) - if self._ctx.onebot is not None: - ctx.set_resource("onebot_client", self._ctx.onebot) - if self._ctx.scheduler is not None: - ctx.set_resource("scheduler", self._ctx.scheduler) - if self._ctx.cognitive_service is not None: - ctx.set_resource("cognitive_service", self._ctx.cognitive_service) - if self._ctx.meme_service is not None: - ctx.set_resource("meme_service", self._ctx.meme_service) - - tool_context: dict[str, Any] = { - "request_type": request_type, - "request_id": request_id, - } - if group_id is not None: - tool_context["group_id"] = int(group_id) - if user_id is not None: - tool_context["user_id"] = int(user_id) - if sender_id is not None: - tool_context["sender_id"] = int(sender_id) - - tool_manager = getattr(ai, "tool_manager", None) - if tool_manager is None: - raise RuntimeError("ToolManager not available") - - raw_result = await self._await_tool_invoke_result( - tool_manager.execute_tool(tool_name, args, tool_context), - timeout=effective_timeout, - ) - - elapsed_ms = round((time.perf_counter() - start) * 1000, 1) - result_str = str(raw_result or "") - logger.info( - "[ToolInvoke] 执行完成: request_id=%s tool=%s ok=true " - "duration_ms=%s result_len=%d", - request_id, - tool_name, - elapsed_ms, - len(result_str), - ) - return { - "ok": True, - "request_id": request_id, - "tool_name": tool_name, - "result": result_str, - "duration_ms": elapsed_ms, - } - - except _ToolInvokeExecutionTimeoutError: - elapsed_ms = round((time.perf_counter() - start) * 1000, 1) - logger.warning( - "[ToolInvoke] 执行超时: request_id=%s tool=%s timeout=%ds", - request_id, - tool_name, - timeout, - ) - return { - "ok": False, - "request_id": request_id, - "tool_name": tool_name, - "error": f"Execution timed out after {timeout}s", - "duration_ms": elapsed_ms, - } - except Exception as exc: - elapsed_ms = round((time.perf_counter() - start) * 1000, 1) - logger.exception( - "[ToolInvoke] 执行失败: request_id=%s tool=%s error=%s", - request_id, - tool_name, - exc, - ) - return { - "ok": False, - "request_id": request_id, - "tool_name": tool_name, - "error": str(exc), - "duration_ms": elapsed_ms, - } - - async def _execute_and_callback( - self, - *, - request_id: str, - tool_name: str, - args: dict[str, Any], - body_context: Any, - callback_url: str, - callback_headers: dict[str, str], - timeout: int, - callback_timeout: int, - ) -> None: - """异步执行工具并发送回调。""" - result = await self._execute_tool_invoke( + return await tools.execute_tool_invoke( + self._ctx, request_id=request_id, tool_name=tool_name, args=args, @@ -2114,371 +289,17 @@ async def _execute_and_callback( timeout=timeout, ) - payload = { - "request_id": result["request_id"], - "tool_name": result["tool_name"], - "ok": result["ok"], - "result": result.get("result"), - "duration_ms": result.get("duration_ms", 0), - "error": result.get("error"), - } - - try: - cb_timeout = ClientTimeout(total=callback_timeout) - async with ClientSession(timeout=cb_timeout) as session: - # aiohttp json= 自动设置 Content-Type,无需手动指定 - async with session.post( - callback_url, - json=payload, - headers=callback_headers or None, - ) as resp: - logger.info( - "[ToolInvoke] 回调发送: request_id=%s url=%s status=%d", - request_id, - _mask_url(callback_url), - resp.status, - ) - except Exception as exc: - logger.warning( - "[ToolInvoke] 回调失败: request_id=%s url=%s error=%s", - request_id, - _mask_url(callback_url), - exc, - ) - - # ------------------------------------------------------------------ - # Naga Bind / Send / Unbind API - # ------------------------------------------------------------------ - + # Naga def _verify_naga_api_key(self, request: web.Request) -> str | None: - """校验 Naga 共享密钥,返回错误信息或 None 表示通过。""" - cfg = self._ctx.config_getter() - expected = cfg.naga.api_key - if not expected: - return "naga api_key not configured" - auth_header = request.headers.get("Authorization", "") - if not auth_header.startswith("Bearer "): - return "missing or invalid Authorization header" - provided = auth_header[7:] - import secrets as _secrets - - if not _secrets.compare_digest(provided, expected): - return "invalid api_key" - return None + return naga.verify_naga_api_key(self._ctx, request) async def _naga_bind_callback_handler(self, request: web.Request) -> Response: - """POST /api/v1/naga/bind/callback — Naga 绑定回调。""" - trace_id = _uuid.uuid4().hex[:8] - auth_err = self._verify_naga_api_key(request) - if auth_err is not None: - logger.warning( - "[NagaBindCallback] 鉴权失败: trace=%s remote=%s err=%s", - trace_id, - getattr(request, "remote", None), - auth_err, - ) - return _json_error("Unauthorized", status=401) - - try: - body = await request.json() - except Exception: - return _json_error("Invalid JSON", status=400) - - bind_uuid = str(body.get("bind_uuid", "") or "").strip() - naga_id = str(body.get("naga_id", "") or "").strip() - status = str(body.get("status", "") or "").strip().lower() - delivery_signature = str(body.get("delivery_signature", "") or "").strip() - reason = str(body.get("reason", "") or "").strip() - if not bind_uuid or not naga_id: - return _json_error("bind_uuid and naga_id are required", status=400) - if status not in {"approved", "rejected"}: - return _json_error("status must be 'approved' or 'rejected'", status=400) - logger.info( - "[NagaBindCallback] 请求开始: trace=%s remote=%s naga_id=%s bind_uuid=%s status=%s reason=%s signature=%s", - trace_id, - getattr(request, "remote", None), - naga_id, - bind_uuid, - status, - _short_text_preview(reason, limit=60), - delivery_signature[:12] + "..." if delivery_signature else "", - ) - - naga_store = self._ctx.naga_store - if naga_store is None: - return _json_error("Naga integration not available", status=503) - - sender = self._ctx.sender - if status == "approved": - if not delivery_signature: - return _json_error( - "delivery_signature is required when approved", status=400 - ) - binding, created, err = await naga_store.activate_binding( - bind_uuid=bind_uuid, - naga_id=naga_id, - delivery_signature=delivery_signature, - ) - if err: - logger.warning( - "[NagaBindCallback] 激活失败: trace=%s naga_id=%s bind_uuid=%s err=%s", - trace_id, - naga_id, - bind_uuid, - err.message, - ) - return _json_error(err.message, status=err.http_status) - logger.info( - "[NagaBindCallback] 激活完成: trace=%s naga_id=%s bind_uuid=%s created=%s qq=%s", - trace_id, - naga_id, - bind_uuid, - created, - binding.qq_id if binding is not None else "", - ) - if created and binding is not None and sender is not None: - try: - await sender.send_private_message( - binding.qq_id, - f"🎉 你的 Naga 绑定已生效\nnaga_id: {naga_id}", - ) - except Exception as exc: - logger.warning("[NagaBindCallback] 通知绑定成功失败: %s", exc) - return web.json_response( - { - "ok": True, - "status": "approved", - "idempotent": not created, - "naga_id": naga_id, - "bind_uuid": bind_uuid, - } - ) - - pending, removed, err = await naga_store.reject_binding( - bind_uuid=bind_uuid, - naga_id=naga_id, - reason=reason, - ) - if err: - logger.warning( - "[NagaBindCallback] 拒绝失败: trace=%s naga_id=%s bind_uuid=%s err=%s", - trace_id, - naga_id, - bind_uuid, - err.message, - ) - return _json_error(err.message, status=err.http_status) - logger.info( - "[NagaBindCallback] 拒绝完成: trace=%s naga_id=%s bind_uuid=%s removed=%s qq=%s", - trace_id, - naga_id, - bind_uuid, - removed, - pending.qq_id if pending is not None else "", - ) - if removed and pending is not None and sender is not None: - try: - detail = f"\n原因: {reason}" if reason else "" - await sender.send_private_message( - pending.qq_id, - f"❌ 你的 Naga 绑定被远端拒绝\nnaga_id: {naga_id}{detail}", - ) - except Exception as exc: - logger.warning("[NagaBindCallback] 通知绑定拒绝失败: %s", exc) - return web.json_response( - { - "ok": True, - "status": "rejected", - "idempotent": not removed, - "naga_id": naga_id, - "bind_uuid": bind_uuid, - } - ) + return await naga.naga_bind_callback_handler(self._ctx, request) async def _naga_messages_send_handler(self, request: web.Request) -> Response: - """POST /api/v1/naga/messages/send — 验签后发送消息。""" - from Undefined.api.naga_store import mask_token - - trace_id = _uuid.uuid4().hex[:8] - auth_err = self._verify_naga_api_key(request) - if auth_err is not None: - logger.warning("[NagaSend] 鉴权失败: trace=%s err=%s", trace_id, auth_err) - return _json_error("Unauthorized", status=401) - - try: - body = await request.json() - except Exception: - return _json_error("Invalid JSON", status=400) - - bind_uuid = str(body.get("bind_uuid", "") or "").strip() - naga_id = str(body.get("naga_id", "") or "").strip() - delivery_signature = str(body.get("delivery_signature", "") or "").strip() - request_uuid = str(body.get("uuid", "") or "").strip() - target = body.get("target") - message = body.get("message") - if not bind_uuid or not naga_id or not delivery_signature: - return _json_error( - "bind_uuid, naga_id and delivery_signature are required", - status=400, - ) - if not isinstance(target, dict): - return _json_error("target object is required", status=400) - if not isinstance(message, dict): - return _json_error("message object is required", status=400) - - raw_target_qq = target.get("qq_id") - raw_target_group = target.get("group_id") - if raw_target_qq is None or raw_target_group is None: - return _json_error( - "target.qq_id and target.group_id are required", status=400 - ) - try: - target_qq = int(raw_target_qq) - target_group = int(raw_target_group) - except Exception: - return _json_error( - "target.qq_id and target.group_id must be integers", status=400 - ) - mode = str(target.get("mode", "") or "").strip().lower() - if mode not in {"private", "group", "both"}: - return _json_error( - "target.mode must be 'private', 'group', or 'both'", status=400 - ) - - fmt = str(message.get("format", "text") or "text").strip().lower() - content = str(message.get("content", "") or "").strip() - if fmt not in {"text", "markdown", "html"}: - return _json_error( - "message.format must be 'text', 'markdown', or 'html'", status=400 - ) - if not content: - return _json_error("message.content is required", status=400) - - message_key = _naga_message_digest( - bind_uuid=bind_uuid, - naga_id=naga_id, - target_qq=target_qq, - target_group=target_group, - mode=mode, - message_format=fmt, - content=content, - ) - logger.info( - "[NagaSend] 请求开始: trace=%s remote=%s naga_id=%s bind_uuid=%s request_uuid=%s mode=%s fmt=%s qq=%s group=%s key=%s content_len=%s preview=%s signature=%s", - trace_id, - getattr(request, "remote", None), - naga_id, - bind_uuid, - request_uuid, - mode, - fmt, - target_qq, - target_group, - message_key, - len(content), - _short_text_preview(content), - mask_token(delivery_signature), + return await naga.naga_messages_send_handler( + self._ctx, self._naga_state, request ) - if mode == "both": - logger.warning( - "[NagaSend] 上游请求显式要求双路投递: trace=%s naga_id=%s bind_uuid=%s request_uuid=%s key=%s", - trace_id, - naga_id, - bind_uuid, - request_uuid, - message_key, - ) - inflight_count = await self._track_naga_send_start(message_key) - if inflight_count > 1: - logger.warning( - "[NagaSend] 检测到相同 payload 并发请求: trace=%s naga_id=%s bind_uuid=%s request_uuid=%s key=%s inflight=%s", - trace_id, - naga_id, - bind_uuid, - request_uuid, - message_key, - inflight_count, - ) - try: - if request_uuid: - dedupe_action, dedupe_value = await self._register_naga_request_uuid( - request_uuid, message_key - ) - if dedupe_action == "conflict": - logger.warning( - "[NagaSend] uuid 与历史 payload 冲突: trace=%s naga_id=%s bind_uuid=%s uuid=%s key=%s", - trace_id, - naga_id, - bind_uuid, - request_uuid, - message_key, - ) - return _json_error("uuid reused with different payload", status=409) - if dedupe_action == "cached": - cached_status, cached_payload = dedupe_value - logger.warning( - "[NagaSend] 命中已完成幂等结果,直接复用: trace=%s naga_id=%s bind_uuid=%s request_uuid=%s key=%s", - trace_id, - naga_id, - bind_uuid, - request_uuid, - message_key, - ) - return web.json_response( - deepcopy(cached_payload), - status=int(cached_status), - ) - if dedupe_action == "await": - wait_future = dedupe_value - logger.warning( - "[NagaSend] 命中进行中幂等请求,等待首个结果: trace=%s naga_id=%s bind_uuid=%s request_uuid=%s key=%s", - trace_id, - naga_id, - bind_uuid, - request_uuid, - message_key, - ) - cached_status, cached_payload = await wait_future - return web.json_response( - deepcopy(cached_payload), - status=int(cached_status), - ) - - response = await self._naga_messages_send_impl( - naga_id=naga_id, - bind_uuid=bind_uuid, - delivery_signature=delivery_signature, - target_qq=target_qq, - target_group=target_group, - mode=mode, - message_format=fmt, - content=content, - trace_id=trace_id, - message_key=message_key, - ) - if request_uuid: - await self._finish_naga_request_uuid( - request_uuid, - message_key, - status=response.status, - payload=_parse_response_payload(response), - ) - return response - except Exception as exc: - if request_uuid: - await self._fail_naga_request_uuid(request_uuid, message_key, exc) - raise - finally: - remaining = await self._track_naga_send_done(message_key) - logger.info( - "[NagaSend] 请求退出: trace=%s naga_id=%s bind_uuid=%s request_uuid=%s key=%s inflight_remaining=%s", - trace_id, - naga_id, - bind_uuid, - request_uuid, - message_key, - remaining, - ) async def _naga_messages_send_impl( self, @@ -2494,492 +315,19 @@ async def _naga_messages_send_impl( trace_id: str, message_key: str, ) -> Response: - from Undefined.api.naga_store import mask_token - - naga_store = self._ctx.naga_store - if naga_store is None: - logger.warning( - "[NagaSend] NagaStore 不可用: trace=%s naga_id=%s bind_uuid=%s", - trace_id, - naga_id, - bind_uuid, - ) - return _json_error("Naga integration not available", status=503) - - binding, err_msg = await naga_store.acquire_delivery( + return await naga.naga_messages_send_impl( + self._ctx, naga_id=naga_id, bind_uuid=bind_uuid, delivery_signature=delivery_signature, + target_qq=target_qq, + target_group=target_group, + mode=mode, + message_format=message_format, + content=content, + trace_id=trace_id, + message_key=message_key, ) - if binding is None: - logger.warning( - "[NagaSend] 签名校验失败: trace=%s naga_id=%s bind_uuid=%s reason=%s signature=%s", - trace_id, - naga_id, - bind_uuid, - err_msg.message if err_msg is not None else "unknown_error", - mask_token(delivery_signature), - ) - return _json_error( - err_msg.message if err_msg is not None else "delivery not available", - status=err_msg.http_status if err_msg is not None else 403, - ) - - logger.info( - "[NagaSend] 投递凭证已占用: trace=%s naga_id=%s bind_uuid=%s key=%s qq=%s group=%s", - trace_id, - naga_id, - bind_uuid, - message_key, - binding.qq_id, - binding.group_id, - ) - try: - if target_qq != binding.qq_id or target_group != binding.group_id: - logger.warning( - "[NagaSend] 目标不匹配: trace=%s naga_id=%s bind_uuid=%s target_qq=%s target_group=%s bound_qq=%s bound_group=%s", - trace_id, - naga_id, - bind_uuid, - target_qq, - target_group, - binding.qq_id, - binding.group_id, - ) - return _json_error("target does not match bound qq/group", status=403) - - cfg = self._ctx.config_getter() - if mode == "group" and binding.group_id not in cfg.naga.allowed_groups: - logger.warning( - "[NagaSend] 群投递被策略拒绝: trace=%s naga_id=%s bind_uuid=%s group=%s", - trace_id, - naga_id, - bind_uuid, - binding.group_id, - ) - return _json_error( - "bound group is not in naga.allowed_groups", status=403 - ) - - sender = self._ctx.sender - if sender is None: - logger.warning( - "[NagaSend] sender 不可用: trace=%s naga_id=%s bind_uuid=%s", - trace_id, - naga_id, - bind_uuid, - ) - return _json_error("sender not available", status=503) - - moderation: dict[str, Any] - naga_cfg = getattr(cfg, "naga", None) - moderation_enabled = bool(getattr(naga_cfg, "moderation_enabled", True)) - security = getattr(self._ctx.command_dispatcher, "security", None) - if not moderation_enabled: - moderation = { - "status": "skipped_disabled", - "blocked": False, - "categories": [], - "message": "Naga moderation disabled by config; message sent without moderation block", - "model_name": "", - } - logger.warning( - "[NagaSend] 审核已禁用,直接放行: trace=%s naga_id=%s bind_uuid=%s key=%s", - trace_id, - naga_id, - bind_uuid, - message_key, - ) - elif security is None or not hasattr(security, "moderate_naga_message"): - moderation = { - "status": "error_allowed", - "blocked": False, - "categories": [], - "message": "Naga moderation service unavailable; message sent without moderation block", - "model_name": "", - } - logger.warning( - "[NagaSend] 审核服务不可用,按允许发送: trace=%s naga_id=%s bind_uuid=%s", - trace_id, - naga_id, - bind_uuid, - ) - else: - logger.info( - "[NagaSend] 审核开始: trace=%s naga_id=%s bind_uuid=%s key=%s fmt=%s content_len=%s", - trace_id, - naga_id, - bind_uuid, - message_key, - message_format, - len(content), - ) - result = await security.moderate_naga_message( - message_format=message_format, - content=content, - ) - moderation = { - "status": result.status, - "blocked": result.blocked, - "categories": result.categories, - "message": result.message, - "model_name": result.model_name, - } - logger.info( - "[NagaSend] 审核完成: trace=%s naga_id=%s bind_uuid=%s key=%s blocked=%s status=%s model=%s categories=%s", - trace_id, - naga_id, - bind_uuid, - message_key, - result.blocked, - result.status, - result.model_name, - ",".join(result.categories) or "-", - ) - if moderation["blocked"]: - logger.warning( - "[NagaSend] 审核拦截: trace=%s naga_id=%s bind_uuid=%s key=%s reason=%s", - trace_id, - naga_id, - bind_uuid, - message_key, - moderation["message"], - ) - return web.json_response( - { - "ok": False, - "error": "message blocked by moderation", - "moderation": moderation, - }, - status=403, - ) - - send_content: str | None = content if message_format == "text" else None - image_path: str | None = None - tmp_path: str | None = None - rendered = False - render_fallback = False - if message_format in {"markdown", "html"}: - import tempfile - - try: - html_str = content - if message_format == "markdown": - html_str = await render_markdown_to_html(content) - fd, tmp_path = tempfile.mkstemp(suffix=".png", prefix="naga_send_") - os.close(fd) - await render_html_to_image(html_str, tmp_path) - image_path = tmp_path - rendered = True - logger.info( - "[NagaSend] 富文本渲染成功: trace=%s naga_id=%s bind_uuid=%s key=%s fmt=%s image=%s", - trace_id, - naga_id, - bind_uuid, - message_key, - message_format, - Path(tmp_path).name if tmp_path is not None else "", - ) - except Exception as exc: - logger.warning( - "[NagaSend] 渲染失败,回退文本发送: trace=%s naga_id=%s bind_uuid=%s key=%s err=%s", - trace_id, - naga_id, - bind_uuid, - message_key, - exc, - ) - send_content = content - render_fallback = True - - sent_private = False - sent_group = False - group_policy_blocked = False - - async def _ensure_delivery_active() -> tuple[Any, Response | None]: - current_binding, live_err = await naga_store.ensure_delivery_active( - naga_id=naga_id, - bind_uuid=bind_uuid, - ) - if current_binding is None: - logger.warning( - "[NagaSend] 投递中止: trace=%s naga_id=%s bind_uuid=%s key=%s reason=%s", - trace_id, - naga_id, - bind_uuid, - message_key, - live_err.message - if live_err is not None - else "delivery no longer active", - ) - return None, web.json_response( - { - "ok": False, - "error": ( - live_err.message - if live_err is not None - else "delivery no longer active" - ), - "sent_private": sent_private, - "sent_group": sent_group, - "moderation": moderation, - }, - status=live_err.http_status if live_err is not None else 409, - ) - return current_binding, None - - try: - cq_image: str | None = None - if image_path is not None: - file_uri = Path(image_path).resolve().as_uri() - cq_image = f"[CQ:image,file={file_uri}]" - - if mode in {"private", "both"}: - current_binding, abort_response = await _ensure_delivery_active() - if abort_response is not None: - return abort_response - logger.info( - "[NagaSend] 私聊投递开始: trace=%s naga_id=%s bind_uuid=%s key=%s qq=%s", - trace_id, - naga_id, - bind_uuid, - message_key, - current_binding.qq_id, - ) - try: - if send_content is not None: - await sender.send_private_message( - current_binding.qq_id, send_content - ) - elif cq_image is not None: - await sender.send_private_message( - current_binding.qq_id, cq_image - ) - sent_private = True - logger.info( - "[NagaSend] 私聊投递成功: trace=%s naga_id=%s bind_uuid=%s key=%s qq=%s", - trace_id, - naga_id, - bind_uuid, - message_key, - current_binding.qq_id, - ) - except Exception as exc: - logger.warning( - "[NagaSend] 私聊发送失败: trace=%s naga_id=%s qq=%d key=%s err=%s", - trace_id, - naga_id, - current_binding.qq_id, - message_key, - exc, - ) - - if mode in {"group", "both"}: - current_binding, abort_response = await _ensure_delivery_active() - if abort_response is not None: - return abort_response - current_cfg = self._ctx.config_getter() - if current_binding.group_id not in current_cfg.naga.allowed_groups: - group_policy_blocked = True - logger.warning( - "[NagaSend] 群投递被策略阻止: trace=%s naga_id=%s bind_uuid=%s key=%s group=%s", - trace_id, - naga_id, - bind_uuid, - message_key, - current_binding.group_id, - ) - else: - logger.info( - "[NagaSend] 群投递开始: trace=%s naga_id=%s bind_uuid=%s key=%s group=%s", - trace_id, - naga_id, - bind_uuid, - message_key, - current_binding.group_id, - ) - try: - if send_content is not None: - await sender.send_group_message( - current_binding.group_id, send_content - ) - elif cq_image is not None: - await sender.send_group_message( - current_binding.group_id, cq_image - ) - sent_group = True - logger.info( - "[NagaSend] 群投递成功: trace=%s naga_id=%s bind_uuid=%s key=%s group=%s", - trace_id, - naga_id, - bind_uuid, - message_key, - current_binding.group_id, - ) - except Exception as exc: - logger.warning( - "[NagaSend] 群聊发送失败: trace=%s naga_id=%s group=%d key=%s err=%s", - trace_id, - naga_id, - current_binding.group_id, - message_key, - exc, - ) - finally: - if tmp_path is not None: - try: - os.unlink(tmp_path) - except OSError: - pass - - if mode == "private" and not sent_private: - return web.json_response( - { - "ok": False, - "error": "private delivery failed", - "sent_private": sent_private, - "sent_group": sent_group, - "moderation": moderation, - }, - status=502, - ) - if mode == "group" and not sent_group: - return web.json_response( - { - "ok": False, - "error": "group delivery failed", - "sent_private": sent_private, - "sent_group": sent_group, - "moderation": moderation, - }, - status=502, - ) - if mode == "both" and not (sent_private or sent_group): - if group_policy_blocked: - return web.json_response( - { - "ok": False, - "error": "bound group is not in naga.allowed_groups", - "sent_private": sent_private, - "sent_group": sent_group, - "moderation": moderation, - }, - status=403, - ) - return web.json_response( - { - "ok": False, - "error": "all deliveries failed", - "sent_private": sent_private, - "sent_group": sent_group, - "moderation": moderation, - }, - status=502, - ) - - await naga_store.record_usage(naga_id, bind_uuid=bind_uuid) - partial_success = mode == "both" and (sent_private != sent_group) - logger.info( - "[NagaSend] 请求完成: trace=%s naga_id=%s bind_uuid=%s key=%s sent_private=%s sent_group=%s partial=%s rendered=%s fallback=%s", - trace_id, - naga_id, - bind_uuid, - message_key, - sent_private, - sent_group, - partial_success, - rendered, - render_fallback, - ) - return web.json_response( - { - "ok": True, - "naga_id": naga_id, - "bind_uuid": bind_uuid, - "sent_private": sent_private, - "sent_group": sent_group, - "partial_success": partial_success, - "delivery_status": ( - "partial_success" if partial_success else "full_success" - ), - "rendered": rendered, - "render_fallback": render_fallback, - "moderation": moderation, - } - ) - finally: - await naga_store.release_delivery(bind_uuid=bind_uuid) async def _naga_unbind_handler(self, request: web.Request) -> Response: - """POST /api/v1/naga/unbind — 远端主动解绑。""" - trace_id = _uuid.uuid4().hex[:8] - auth_err = self._verify_naga_api_key(request) - if auth_err is not None: - logger.warning( - "[NagaUnbind] 鉴权失败: trace=%s remote=%s err=%s", - trace_id, - getattr(request, "remote", None), - auth_err, - ) - return _json_error("Unauthorized", status=401) - - try: - body = await request.json() - except Exception: - return _json_error("Invalid JSON", status=400) - - bind_uuid = str(body.get("bind_uuid", "") or "").strip() - naga_id = str(body.get("naga_id", "") or "").strip() - delivery_signature = str(body.get("delivery_signature", "") or "").strip() - if not bind_uuid or not naga_id or not delivery_signature: - return _json_error( - "bind_uuid, naga_id and delivery_signature are required", - status=400, - ) - logger.info( - "[NagaUnbind] 请求开始: trace=%s remote=%s naga_id=%s bind_uuid=%s signature=%s", - trace_id, - getattr(request, "remote", None), - naga_id, - bind_uuid, - delivery_signature[:12] + "...", - ) - - naga_store = self._ctx.naga_store - if naga_store is None: - return _json_error("Naga integration not available", status=503) - - binding, changed, err = await naga_store.revoke_binding( - naga_id, - expected_bind_uuid=bind_uuid, - delivery_signature=delivery_signature, - ) - if binding is None: - logger.warning( - "[NagaUnbind] 吊销失败: trace=%s naga_id=%s bind_uuid=%s err=%s", - trace_id, - naga_id, - bind_uuid, - err.message if err is not None else "binding not found", - ) - return _json_error( - err.message if err is not None else "binding not found", - status=err.http_status if err is not None else 404, - ) - logger.info( - "[NagaUnbind] 吊销完成: trace=%s naga_id=%s bind_uuid=%s changed=%s qq=%s group=%s", - trace_id, - naga_id, - bind_uuid, - changed, - binding.qq_id, - binding.group_id, - ) - return web.json_response( - { - "ok": True, - "idempotent": not changed, - "naga_id": naga_id, - "bind_uuid": bind_uuid, - } - ) + return await naga.naga_unbind_handler(self._ctx, request) diff --git a/src/Undefined/api/routes/__init__.py b/src/Undefined/api/routes/__init__.py new file mode 100644 index 00000000..8a999d1f --- /dev/null +++ b/src/Undefined/api/routes/__init__.py @@ -0,0 +1 @@ +"""Runtime API route modules.""" diff --git a/src/Undefined/api/routes/chat.py b/src/Undefined/api/routes/chat.py new file mode 100644 index 00000000..536435ca --- /dev/null +++ b/src/Undefined/api/routes/chat.py @@ -0,0 +1,359 @@ +"""Chat route handlers extracted from the Runtime API application.""" + +from __future__ import annotations + +import asyncio +import logging +from contextlib import suppress +from datetime import datetime +from typing import Any, Awaitable, Callable + +from aiohttp import web +from aiohttp.web_response import Response + +from Undefined.api._context import RuntimeAPIContext +from Undefined.api._helpers import ( + _VIRTUAL_USER_ID, + _WebUIVirtualSender, + _build_chat_response_payload, + _json_error, + _sse_event, + _to_bool, +) +from Undefined.attachments import ( + attachment_refs_to_xml, + build_attachment_scope, + register_message_attachments, + render_message_with_pic_placeholders, +) +from Undefined.context import RequestContext +from Undefined.context_resource_registry import collect_context_resources +from Undefined.services.queue_manager import QUEUE_LANE_SUPERADMIN +from Undefined.utils.common import message_to_segments +from Undefined.utils.recent_messages import get_recent_messages_prefer_local +from Undefined.utils.xml import escape_xml_attr, escape_xml_text + +logger = logging.getLogger(__name__) + +_VIRTUAL_USER_NAME = "system" +_CHAT_SSE_KEEPALIVE_SECONDS = 10.0 + + +async def run_webui_chat( + ctx: RuntimeAPIContext, + *, + text: str, + send_output: Callable[[int, str], Awaitable[None]], +) -> str: + """Execute a single WebUI chat turn (command dispatch or AI ask).""" + + cfg = ctx.config_getter() + permission_sender_id = int(cfg.superadmin_qq) + webui_scope_key = build_attachment_scope( + user_id=_VIRTUAL_USER_ID, + request_type="private", + webui_session=True, + ) + input_segments = message_to_segments(text) + registered_input = await register_message_attachments( + registry=ctx.ai.attachment_registry, + segments=input_segments, + scope_key=webui_scope_key, + resolve_image_url=ctx.onebot.get_image, + get_forward_messages=ctx.onebot.get_forward_msg, + ) + normalized_text = registered_input.normalized_text or text + await ctx.history_manager.add_private_message( + user_id=_VIRTUAL_USER_ID, + text_content=normalized_text, + display_name=_VIRTUAL_USER_NAME, + user_name=_VIRTUAL_USER_NAME, + attachments=registered_input.attachments, + ) + + command = ctx.command_dispatcher.parse_command(normalized_text) + if command: + await ctx.command_dispatcher.dispatch_private( + user_id=_VIRTUAL_USER_ID, + sender_id=permission_sender_id, + command=command, + send_private_callback=send_output, + is_webui_session=True, + ) + return "command" + + current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S") + attachment_xml = ( + f"\n{attachment_refs_to_xml(registered_input.attachments)}" + if registered_input.attachments + else "" + ) + full_question = f""" + {escape_xml_text(normalized_text)}{attachment_xml} + + +【WebUI 会话】 +这是一条来自 WebUI 控制台的会话请求。 +会话身份:虚拟用户 system(42)。 +权限等级:superadmin(你可按最高管理权限处理)。 +请正常进行私聊对话;如果需要结束会话,调用 end 工具。""" + virtual_sender = _WebUIVirtualSender( + _VIRTUAL_USER_ID, send_output, onebot=ctx.onebot + ) + + async def _get_recent_cb( + chat_id: str, msg_type: str, start: int, end: int + ) -> list[dict[str, Any]]: + return await get_recent_messages_prefer_local( + chat_id=chat_id, + msg_type=msg_type, + start=start, + end=end, + onebot_client=ctx.onebot, + history_manager=ctx.history_manager, + bot_qq=cfg.bot_qq, + attachment_registry=getattr(ctx.ai, "attachment_registry", None), + ) + + async with RequestContext( + request_type="private", + user_id=_VIRTUAL_USER_ID, + sender_id=permission_sender_id, + ) as rctx: + ai_client = ctx.ai # noqa: F841 + memory_storage = ctx.ai.memory_storage # noqa: F841 + runtime_config = ctx.ai.runtime_config # noqa: F841 + sender = virtual_sender # noqa: F841 + history_manager = ctx.history_manager # noqa: F841 + onebot_client = ctx.onebot # noqa: F841 + scheduler = ctx.scheduler # noqa: F841 + + def send_message_callback( + msg: str, reply_to: int | None = None + ) -> Awaitable[None]: + _ = reply_to + return send_output(_VIRTUAL_USER_ID, msg) + + get_recent_messages_callback = _get_recent_cb # noqa: F841 + get_image_url_callback = ctx.onebot.get_image # noqa: F841 + get_forward_msg_callback = ctx.onebot.get_forward_msg # noqa: F841 + resource_vars = dict(globals()) + resource_vars.update(locals()) + resources = collect_context_resources(resource_vars) + for key, value in resources.items(): + if value is not None: + rctx.set_resource(key, value) + rctx.set_resource("queue_lane", QUEUE_LANE_SUPERADMIN) + rctx.set_resource("webui_session", True) + rctx.set_resource("webui_permission", "superadmin") + + result = await ctx.ai.ask( + full_question, + send_message_callback=send_message_callback, + get_recent_messages_callback=get_recent_messages_callback, + get_image_url_callback=get_image_url_callback, + get_forward_msg_callback=get_forward_msg_callback, + sender=sender, + history_manager=history_manager, + onebot_client=onebot_client, + scheduler=scheduler, + extra_context={ + "is_private_chat": True, + "request_type": "private", + "user_id": _VIRTUAL_USER_ID, + "sender_name": _VIRTUAL_USER_NAME, + "webui_session": True, + "webui_permission": "superadmin", + }, + ) + + final_reply = str(result or "").strip() + if final_reply: + await send_output(_VIRTUAL_USER_ID, final_reply) + + return "chat" + + +async def chat_history_handler( + ctx: RuntimeAPIContext, request: web.Request +) -> Response: + """Return recent WebUI chat history.""" + + limit_raw = str(request.query.get("limit", "200") or "200").strip() + try: + limit = int(limit_raw) + except ValueError: + limit = 200 + limit = max(1, min(limit, 500)) + + getter = getattr(ctx.history_manager, "get_recent_private", None) + if not callable(getter): + return _json_error("History manager not ready", status=503) + + records = getter(_VIRTUAL_USER_ID, limit) + items: list[dict[str, Any]] = [] + for item in records: + if not isinstance(item, dict): + continue + content = str(item.get("message", "")).strip() + if not content: + continue + display_name = str(item.get("display_name", "")).strip().lower() + role = "bot" if display_name == "bot" else "user" + items.append( + { + "role": role, + "content": content, + "timestamp": str(item.get("timestamp", "") or "").strip(), + } + ) + + return web.json_response( + { + "virtual_user_id": _VIRTUAL_USER_ID, + "permission": "superadmin", + "count": len(items), + "items": items, + } + ) + + +async def chat_handler( + ctx: RuntimeAPIContext, request: web.Request +) -> web.StreamResponse: + """Handle a WebUI chat request (non-streaming or SSE streaming).""" + + try: + body = await request.json() + except Exception: + return _json_error("Invalid JSON", status=400) + + text = str(body.get("message", "") or "").strip() + if not text: + return _json_error("message is required", status=400) + + stream = _to_bool(body.get("stream")) + outputs: list[str] = [] + webui_scope_key = build_attachment_scope( + user_id=_VIRTUAL_USER_ID, + request_type="private", + webui_session=True, + ) + + async def _capture_private_message(user_id: int, message: str) -> None: + _ = user_id + content = str(message or "").strip() + if not content: + return + rendered = await render_message_with_pic_placeholders( + content, + registry=ctx.ai.attachment_registry, + scope_key=webui_scope_key, + strict=False, + ) + if not rendered.delivery_text.strip(): + return + outputs.append(rendered.delivery_text) + await ctx.history_manager.add_private_message( + user_id=_VIRTUAL_USER_ID, + text_content=rendered.history_text, + display_name="Bot", + user_name="Bot", + attachments=rendered.attachments, + ) + + if not stream: + try: + mode = await run_webui_chat( + ctx, text=text, send_output=_capture_private_message + ) + except Exception as exc: + logger.exception("[RuntimeAPI] chat failed: %s", exc) + return _json_error("Chat failed", status=502) + return web.json_response(_build_chat_response_payload(mode, outputs)) + + response = web.StreamResponse( + status=200, + reason="OK", + headers={ + "Content-Type": "text/event-stream", + "Cache-Control": "no-cache", + "Connection": "keep-alive", + }, + ) + await response.prepare(request) + + message_queue: asyncio.Queue[str] = asyncio.Queue() + + async def _capture_private_message_stream(user_id: int, message: str) -> None: + output_count = len(outputs) + await _capture_private_message(user_id, message) + if len(outputs) <= output_count: + return + content = outputs[-1].strip() + if content: + await message_queue.put(content) + + task = asyncio.create_task( + run_webui_chat(ctx, text=text, send_output=_capture_private_message_stream) + ) + mode = "chat" + client_disconnected = False + try: + await response.write( + _sse_event( + "meta", + { + "virtual_user_id": _VIRTUAL_USER_ID, + "permission": "superadmin", + }, + ) + ) + + while True: + if request.transport is None or request.transport.is_closing(): + client_disconnected = True + break + if task.done() and message_queue.empty(): + break + try: + message = await asyncio.wait_for( + message_queue.get(), + timeout=_CHAT_SSE_KEEPALIVE_SECONDS, + ) + await response.write(_sse_event("message", {"content": message})) + except asyncio.TimeoutError: + await response.write(b": keep-alive\n\n") + + if client_disconnected: + task.cancel() + with suppress(asyncio.CancelledError): + await task + return response + + mode = await task + await response.write( + _sse_event("done", _build_chat_response_payload(mode, outputs)) + ) + except asyncio.CancelledError: + task.cancel() + with suppress(asyncio.CancelledError): + await task + raise + except (ConnectionResetError, RuntimeError): + task.cancel() + with suppress(asyncio.CancelledError): + await task + except Exception as exc: + logger.exception("[RuntimeAPI] chat stream failed: %s", exc) + if not task.done(): + task.cancel() + with suppress(asyncio.CancelledError): + await task + with suppress(Exception): + await response.write(_sse_event("error", {"error": str(exc)})) + finally: + with suppress(Exception): + await response.write_eof() + + return response diff --git a/src/Undefined/api/routes/cognitive.py b/src/Undefined/api/routes/cognitive.py new file mode 100644 index 00000000..66096994 --- /dev/null +++ b/src/Undefined/api/routes/cognitive.py @@ -0,0 +1,86 @@ +"""Cognitive event & profile routes.""" + +from __future__ import annotations + +from typing import Any + +from aiohttp import web +from aiohttp.web_response import Response + +from Undefined.api._context import RuntimeAPIContext +from Undefined.api._helpers import _json_error, _optional_query_param + + +async def cognitive_events_handler( + ctx: RuntimeAPIContext, request: web.Request +) -> Response: + cognitive_service = ctx.cognitive_service + if not cognitive_service or not cognitive_service.enabled: + return _json_error("Cognitive service disabled", status=400) + + query = str(request.query.get("q", "") or "").strip() + if not query: + return _json_error("q is required", status=400) + + search_kwargs: dict[str, Any] = {"query": query} + for key in ( + "target_user_id", + "target_group_id", + "sender_id", + "request_type", + "top_k", + "time_from", + "time_to", + ): + value = _optional_query_param(request, key) + if value is not None: + search_kwargs[key] = value + + results = await cognitive_service.search_events(**search_kwargs) + return web.json_response({"count": len(results), "items": results}) + + +async def cognitive_profiles_handler( + ctx: RuntimeAPIContext, request: web.Request +) -> Response: + cognitive_service = ctx.cognitive_service + if not cognitive_service or not cognitive_service.enabled: + return _json_error("Cognitive service disabled", status=400) + + query = str(request.query.get("q", "") or "").strip() + if not query: + return _json_error("q is required", status=400) + + search_kwargs: dict[str, Any] = {"query": query} + entity_type = _optional_query_param(request, "entity_type") + if entity_type is not None: + search_kwargs["entity_type"] = entity_type + top_k = _optional_query_param(request, "top_k") + if top_k is not None: + search_kwargs["top_k"] = top_k + + results = await cognitive_service.search_profiles(**search_kwargs) + return web.json_response({"count": len(results), "items": results}) + + +async def cognitive_profile_handler( + ctx: RuntimeAPIContext, request: web.Request +) -> Response: + cognitive_service = ctx.cognitive_service + if not cognitive_service or not cognitive_service.enabled: + return _json_error("Cognitive service disabled", status=400) + + entity_type = str(request.match_info.get("entity_type", "")).strip() + entity_id = str(request.match_info.get("entity_id", "")).strip() + if not entity_type or not entity_id: + return _json_error("entity_type/entity_id are required", status=400) + + profile = await cognitive_service.get_profile(entity_type, entity_id) + return web.json_response( + { + "entity_type": entity_type, + "entity_id": entity_id, + "profile": profile or "", + "found": bool(profile), + } + ) diff --git a/src/Undefined/api/routes/health.py b/src/Undefined/api/routes/health.py new file mode 100644 index 00000000..bbdb0208 --- /dev/null +++ b/src/Undefined/api/routes/health.py @@ -0,0 +1,23 @@ +"""Health check route.""" + +from __future__ import annotations + +from datetime import datetime + +from aiohttp.web_response import Response +from aiohttp import web + +from Undefined import __version__ +from Undefined.api._context import RuntimeAPIContext + + +async def health_handler(ctx: RuntimeAPIContext, request: web.Request) -> Response: + _ = ctx, request + return web.json_response( + { + "ok": True, + "service": "undefined-runtime-api", + "version": __version__, + "timestamp": datetime.now().isoformat(), + } + ) diff --git a/src/Undefined/api/routes/memes.py b/src/Undefined/api/routes/memes.py new file mode 100644 index 00000000..dcef98ad --- /dev/null +++ b/src/Undefined/api/routes/memes.py @@ -0,0 +1,222 @@ +"""Meme management route handlers.""" + +from __future__ import annotations + +from typing import Any, cast + +from aiohttp import web +from aiohttp.web_response import Response + +from Undefined.api._context import RuntimeAPIContext +from Undefined.api._helpers import _json_error, _optional_query_param, _to_bool + + +async def meme_list_handler(ctx: RuntimeAPIContext, request: web.Request) -> Response: + meme_service = ctx.meme_service + if meme_service is None or not meme_service.enabled: + return _json_error("Meme service disabled", status=400) + + def _parse_optional_bool(name: str) -> bool | None: + raw = request.query.get(name) + if raw is None or str(raw).strip() == "": + return None + return _to_bool(raw) + + page_raw = _optional_query_param(request, "page") + page_size_raw = _optional_query_param(request, "page_size") + top_k_raw = _optional_query_param(request, "top_k") + query = str(request.query.get("q", "") or "").strip() + query_mode = str(request.query.get("query_mode", "") or "").strip().lower() + keyword_query = str(request.query.get("keyword_query", "") or "").strip() + semantic_query = str(request.query.get("semantic_query", "") or "").strip() + try: + page = int(page_raw) if page_raw is not None else 1 + page_size = int(page_size_raw) if page_size_raw is not None else 50 + top_k = int(top_k_raw) if top_k_raw is not None else page_size + except ValueError: + return _json_error("page/page_size/top_k must be integers", status=400) + page = max(1, page) + page_size = max(1, min(200, page_size)) + top_k = max(1, top_k) + sort = str(request.query.get("sort", "updated_at") or "updated_at").strip() + + enabled_filter = _parse_optional_bool("enabled") + animated_filter = _parse_optional_bool("animated") + pinned_filter = _parse_optional_bool("pinned") + if not (query or keyword_query or semantic_query) and sort == "relevance": + sort = "updated_at" + + if query or keyword_query or semantic_query: + has_post_filter = any( + f is not None for f in (enabled_filter, animated_filter, pinned_filter) + ) + requested_window = max(page * page_size, top_k) + if has_post_filter or page > 1 or sort != "relevance": + fetch_k = min(500, max(requested_window * 4, top_k)) + else: + fetch_k = min(500, requested_window) + search_payload = await meme_service.search_memes( + query, + query_mode=query_mode or meme_service.default_query_mode, + keyword_query=keyword_query or None, + semantic_query=semantic_query or None, + top_k=fetch_k, + include_disabled=enabled_filter is not True, + sort=sort, + ) + filtered_items: list[dict[str, Any]] = [] + for item in list(search_payload.get("items") or []): + if ( + enabled_filter is not None + and bool(item.get("enabled")) != enabled_filter + ): + continue + if ( + animated_filter is not None + and bool(item.get("is_animated")) != animated_filter + ): + continue + if pinned_filter is not None and bool(item.get("pinned")) != pinned_filter: + continue + filtered_items.append(item) + offset = (page - 1) * page_size + paged_items = filtered_items[offset : offset + page_size] + window_total = len(filtered_items) + fetched_window_count = len(list(search_payload.get("items") or [])) + window_exhausted = fetched_window_count < fetch_k + has_more = bool(paged_items) and ( + offset + page_size < window_total + or (not window_exhausted and window_total >= offset + page_size) + ) + return web.json_response( + { + "ok": True, + "total": None, + "window_total": window_total, + "total_exact": False, + "page": page, + "page_size": page_size, + "has_more": has_more, + "query_mode": search_payload.get("query_mode"), + "keyword_query": search_payload.get("keyword_query"), + "semantic_query": search_payload.get("semantic_query"), + "sort": search_payload.get("sort", sort), + "items": paged_items, + } + ) + + payload = await meme_service.list_memes( + query=query, + enabled=enabled_filter, + animated=animated_filter, + pinned=pinned_filter, + sort=sort, + page=page, + page_size=page_size, + summary=True, + ) + return web.json_response(payload) + + +async def meme_stats_handler(ctx: RuntimeAPIContext, request: web.Request) -> Response: + _ = request + meme_service = ctx.meme_service + if meme_service is None or not meme_service.enabled: + return _json_error("Meme service disabled", status=400) + return web.json_response(await meme_service.stats()) + + +async def meme_detail_handler(ctx: RuntimeAPIContext, request: web.Request) -> Response: + meme_service = ctx.meme_service + if meme_service is None or not meme_service.enabled: + return _json_error("Meme service disabled", status=400) + uid = str(request.match_info.get("uid", "")).strip() + detail = await meme_service.get_meme(uid) + if detail is None: + return _json_error("Meme not found", status=404) + return web.json_response(detail) + + +async def meme_blob_handler(ctx: RuntimeAPIContext, request: web.Request) -> Response: + meme_service = ctx.meme_service + if meme_service is None or not meme_service.enabled: + return _json_error("Meme service disabled", status=400) + uid = str(request.match_info.get("uid", "")).strip() + path = await meme_service.blob_path_for_uid(uid, preview=False) + if path is None: + return _json_error("Meme blob not found", status=404) + return cast(Response, web.FileResponse(path=path)) + + +async def meme_preview_handler( + ctx: RuntimeAPIContext, request: web.Request +) -> Response: + meme_service = ctx.meme_service + if meme_service is None or not meme_service.enabled: + return _json_error("Meme service disabled", status=400) + uid = str(request.match_info.get("uid", "")).strip() + path = await meme_service.blob_path_for_uid(uid, preview=True) + if path is None: + return _json_error("Meme preview not found", status=404) + return cast(Response, web.FileResponse(path=path)) + + +async def meme_update_handler(ctx: RuntimeAPIContext, request: web.Request) -> Response: + meme_service = ctx.meme_service + if meme_service is None or not meme_service.enabled: + return _json_error("Meme service disabled", status=400) + uid = str(request.match_info.get("uid", "")).strip() + try: + payload = await request.json() + except Exception: + return _json_error("Invalid JSON body", status=400) + if not isinstance(payload, dict): + return _json_error("JSON body must be an object", status=400) + updated = await meme_service.update_meme( + uid, + manual_description=payload.get("manual_description"), + tags=payload.get("tags"), + aliases=payload.get("aliases"), + enabled=payload.get("enabled") if "enabled" in payload else None, + pinned=payload.get("pinned") if "pinned" in payload else None, + ) + if updated is None: + return _json_error("Meme not found", status=404) + return web.json_response({"ok": True, "record": updated}) + + +async def meme_delete_handler(ctx: RuntimeAPIContext, request: web.Request) -> Response: + meme_service = ctx.meme_service + if meme_service is None or not meme_service.enabled: + return _json_error("Meme service disabled", status=400) + uid = str(request.match_info.get("uid", "")).strip() + deleted = await meme_service.delete_meme(uid) + if not deleted: + return _json_error("Meme not found", status=404) + return web.json_response({"ok": True, "uid": uid}) + + +async def meme_reanalyze_handler( + ctx: RuntimeAPIContext, request: web.Request +) -> Response: + meme_service = ctx.meme_service + if meme_service is None or not meme_service.enabled: + return _json_error("Meme service disabled", status=400) + uid = str(request.match_info.get("uid", "")).strip() + job_id = await meme_service.enqueue_reanalyze(uid) + if not job_id: + return _json_error("Meme queue unavailable", status=503) + return web.json_response({"ok": True, "uid": uid, "job_id": job_id}) + + +async def meme_reindex_handler( + ctx: RuntimeAPIContext, request: web.Request +) -> Response: + meme_service = ctx.meme_service + if meme_service is None or not meme_service.enabled: + return _json_error("Meme service disabled", status=400) + uid = str(request.match_info.get("uid", "")).strip() + job_id = await meme_service.enqueue_reindex(uid) + if not job_id: + return _json_error("Meme queue unavailable", status=503) + return web.json_response({"ok": True, "uid": uid, "job_id": job_id}) diff --git a/src/Undefined/api/routes/memory.py b/src/Undefined/api/routes/memory.py new file mode 100644 index 00000000..30c114ed --- /dev/null +++ b/src/Undefined/api/routes/memory.py @@ -0,0 +1,156 @@ +"""Memory CRUD routes.""" + +from __future__ import annotations + +from contextlib import suppress +from typing import Any + +from aiohttp import web +from aiohttp.web_response import Response + +from Undefined.api._context import RuntimeAPIContext +from Undefined.api._helpers import _json_error, _optional_query_param, _parse_query_time + + +async def memory_list_handler(ctx: RuntimeAPIContext, request: web.Request) -> Response: + query = str(request.query.get("q", "") or "").strip().lower() + top_k_raw = _optional_query_param(request, "top_k") + time_from_raw = _optional_query_param(request, "time_from") + time_to_raw = _optional_query_param(request, "time_to") + memory_storage = getattr(ctx.ai, "memory_storage", None) + if memory_storage is None: + return _json_error("Memory storage not ready", status=503) + + limit: int | None = None + if top_k_raw is not None: + try: + limit = int(top_k_raw) + except ValueError: + return _json_error("top_k must be an integer", status=400) + if limit <= 0: + return _json_error("top_k must be > 0", status=400) + + time_from_dt = _parse_query_time(time_from_raw) + if time_from_raw is not None and time_from_dt is None: + return _json_error("time_from must be ISO datetime", status=400) + time_to_dt = _parse_query_time(time_to_raw) + if time_to_raw is not None and time_to_dt is None: + return _json_error("time_to must be ISO datetime", status=400) + if time_from_dt and time_to_dt and time_from_dt > time_to_dt: + time_from_dt, time_to_dt = time_to_dt, time_from_dt + + records = memory_storage.get_all() + items: list[dict[str, Any]] = [] + for item in records: + created_at = str(item.created_at or "").strip() + created_dt = _parse_query_time(created_at) + if time_from_dt and created_dt and created_dt < time_from_dt: + continue + if time_to_dt and created_dt and created_dt > time_to_dt: + continue + if (time_from_dt or time_to_dt) and created_dt is None: + continue + items.append( + { + "uuid": item.uuid, + "fact": item.fact, + "created_at": created_at, + } + ) + if query: + items = [ + item + for item in items + if query in str(item.get("fact", "")).lower() + or query in str(item.get("uuid", "")).lower() + ] + + def _created_sort_key(item: dict[str, Any]) -> float: + created_dt = _parse_query_time(str(item.get("created_at") or "")) + if created_dt is None: + return float("-inf") + with suppress(OSError, OverflowError, ValueError): + return float(created_dt.timestamp()) + return float("-inf") + + items.sort(key=_created_sort_key) + if limit is not None: + items = items[:limit] + + return web.json_response( + { + "total": len(items), + "items": items, + "query": { + "q": query or "", + "top_k": limit, + "time_from": time_from_raw, + "time_to": time_to_raw, + }, + } + ) + + +async def memory_create_handler( + ctx: RuntimeAPIContext, request: web.Request +) -> Response: + memory_storage = getattr(ctx.ai, "memory_storage", None) + if memory_storage is None: + return _json_error("Memory storage not ready", status=503) + try: + body = await request.json() + except Exception: + return _json_error("Invalid JSON", status=400) + fact = str(body.get("fact", "") or "").strip() + if not fact: + return _json_error("fact must not be empty", status=400) + new_uuid = await memory_storage.add(fact) + if new_uuid is None: + return _json_error("Failed to create memory", status=500) + existing = [m for m in memory_storage.get_all() if m.uuid == new_uuid] + item = existing[0] if existing else None + return web.json_response( + { + "uuid": new_uuid, + "fact": item.fact if item else fact, + "created_at": item.created_at if item else "", + }, + status=201, + ) + + +async def memory_update_handler( + ctx: RuntimeAPIContext, request: web.Request +) -> Response: + memory_storage = getattr(ctx.ai, "memory_storage", None) + if memory_storage is None: + return _json_error("Memory storage not ready", status=503) + target_uuid = str(request.match_info.get("uuid", "")).strip() + if not target_uuid: + return _json_error("uuid is required", status=400) + try: + body = await request.json() + except Exception: + return _json_error("Invalid JSON", status=400) + fact = str(body.get("fact", "") or "").strip() + if not fact: + return _json_error("fact must not be empty", status=400) + ok = await memory_storage.update(target_uuid, fact) + if not ok: + return _json_error(f"Memory {target_uuid} not found", status=404) + return web.json_response({"uuid": target_uuid, "fact": fact, "updated": True}) + + +async def memory_delete_handler( + ctx: RuntimeAPIContext, request: web.Request +) -> Response: + memory_storage = getattr(ctx.ai, "memory_storage", None) + if memory_storage is None: + return _json_error("Memory storage not ready", status=503) + target_uuid = str(request.match_info.get("uuid", "")).strip() + if not target_uuid: + return _json_error("uuid is required", status=400) + ok = await memory_storage.delete(target_uuid) + if not ok: + return _json_error(f"Memory {target_uuid} not found", status=404) + return web.json_response({"uuid": target_uuid, "deleted": True}) diff --git a/src/Undefined/api/routes/naga.py b/src/Undefined/api/routes/naga.py new file mode 100644 index 00000000..b76375de --- /dev/null +++ b/src/Undefined/api/routes/naga.py @@ -0,0 +1,897 @@ +"""Naga integration route handlers. + +Extracted from ``RuntimeAPI`` methods into free functions so they can be +registered declaratively in the route table. +""" + +from __future__ import annotations + +import logging +import os +import uuid as _uuid +from copy import deepcopy +from pathlib import Path +from typing import Any + +from aiohttp import web +from aiohttp.web_response import Response + +from Undefined.api._context import RuntimeAPIContext +from Undefined.api._helpers import ( + _json_error, + _naga_message_digest, + _parse_response_payload, + _short_text_preview, +) +from Undefined.api._naga_state import NagaState +from Undefined.render import render_html_to_image, render_markdown_to_html + +logger = logging.getLogger(__name__) + + +# ------------------------------------------------------------------ +# Auth helper +# ------------------------------------------------------------------ + + +def verify_naga_api_key(ctx: RuntimeAPIContext, request: web.Request) -> str | None: + """校验 Naga 共享密钥,返回错误信息或 ``None`` 表示通过。""" + import secrets as _secrets + + cfg = ctx.config_getter() + expected = cfg.naga.api_key + if not expected: + return "naga api_key not configured" + auth_header = request.headers.get("Authorization", "") + if not auth_header.startswith("Bearer "): + return "missing or invalid Authorization header" + provided = auth_header[7:] + if not _secrets.compare_digest(provided, expected): + return "invalid api_key" + return None + + +# ------------------------------------------------------------------ +# POST /api/v1/naga/bind/callback +# ------------------------------------------------------------------ + + +async def naga_bind_callback_handler( + ctx: RuntimeAPIContext, request: web.Request +) -> Response: + """POST /api/v1/naga/bind/callback — Naga 绑定回调。""" + trace_id = _uuid.uuid4().hex[:8] + auth_err = verify_naga_api_key(ctx, request) + if auth_err is not None: + logger.warning( + "[NagaBindCallback] 鉴权失败: trace=%s remote=%s err=%s", + trace_id, + getattr(request, "remote", None), + auth_err, + ) + return _json_error("Unauthorized", status=401) + + try: + body = await request.json() + except Exception: + return _json_error("Invalid JSON", status=400) + + bind_uuid = str(body.get("bind_uuid", "") or "").strip() + naga_id = str(body.get("naga_id", "") or "").strip() + status = str(body.get("status", "") or "").strip().lower() + delivery_signature = str(body.get("delivery_signature", "") or "").strip() + reason = str(body.get("reason", "") or "").strip() + if not bind_uuid or not naga_id: + return _json_error("bind_uuid and naga_id are required", status=400) + if status not in {"approved", "rejected"}: + return _json_error("status must be 'approved' or 'rejected'", status=400) + logger.info( + "[NagaBindCallback] 请求开始: trace=%s remote=%s naga_id=%s bind_uuid=%s status=%s reason=%s signature=%s", + trace_id, + getattr(request, "remote", None), + naga_id, + bind_uuid, + status, + _short_text_preview(reason, limit=60), + delivery_signature[:12] + "..." if delivery_signature else "", + ) + + naga_store = ctx.naga_store + if naga_store is None: + return _json_error("Naga integration not available", status=503) + + sender = ctx.sender + if status == "approved": + if not delivery_signature: + return _json_error( + "delivery_signature is required when approved", status=400 + ) + binding, created, err = await naga_store.activate_binding( + bind_uuid=bind_uuid, + naga_id=naga_id, + delivery_signature=delivery_signature, + ) + if err: + logger.warning( + "[NagaBindCallback] 激活失败: trace=%s naga_id=%s bind_uuid=%s err=%s", + trace_id, + naga_id, + bind_uuid, + err.message, + ) + return _json_error(err.message, status=err.http_status) + logger.info( + "[NagaBindCallback] 激活完成: trace=%s naga_id=%s bind_uuid=%s created=%s qq=%s", + trace_id, + naga_id, + bind_uuid, + created, + binding.qq_id if binding is not None else "", + ) + if created and binding is not None and sender is not None: + try: + await sender.send_private_message( + binding.qq_id, + f"🎉 你的 Naga 绑定已生效\nnaga_id: {naga_id}", + ) + except Exception as exc: + logger.warning("[NagaBindCallback] 通知绑定成功失败: %s", exc) + return web.json_response( + { + "ok": True, + "status": "approved", + "idempotent": not created, + "naga_id": naga_id, + "bind_uuid": bind_uuid, + } + ) + + # --- rejected --- + pending, removed, err = await naga_store.reject_binding( + bind_uuid=bind_uuid, + naga_id=naga_id, + reason=reason, + ) + if err: + logger.warning( + "[NagaBindCallback] 拒绝失败: trace=%s naga_id=%s bind_uuid=%s err=%s", + trace_id, + naga_id, + bind_uuid, + err.message, + ) + return _json_error(err.message, status=err.http_status) + logger.info( + "[NagaBindCallback] 拒绝完成: trace=%s naga_id=%s bind_uuid=%s removed=%s qq=%s", + trace_id, + naga_id, + bind_uuid, + removed, + pending.qq_id if pending is not None else "", + ) + if removed and pending is not None and sender is not None: + try: + detail = f"\n原因: {reason}" if reason else "" + await sender.send_private_message( + pending.qq_id, + f"❌ 你的 Naga 绑定被远端拒绝\nnaga_id: {naga_id}{detail}", + ) + except Exception as exc: + logger.warning("[NagaBindCallback] 通知绑定拒绝失败: %s", exc) + return web.json_response( + { + "ok": True, + "status": "rejected", + "idempotent": not removed, + "naga_id": naga_id, + "bind_uuid": bind_uuid, + } + ) + + +# ------------------------------------------------------------------ +# POST /api/v1/naga/messages/send +# ------------------------------------------------------------------ + + +async def naga_messages_send_handler( + ctx: RuntimeAPIContext, + naga_state: NagaState, + request: web.Request, +) -> Response: + """POST /api/v1/naga/messages/send — 验签后发送消息。""" + from Undefined.api.naga_store import mask_token + + trace_id = _uuid.uuid4().hex[:8] + auth_err = verify_naga_api_key(ctx, request) + if auth_err is not None: + logger.warning("[NagaSend] 鉴权失败: trace=%s err=%s", trace_id, auth_err) + return _json_error("Unauthorized", status=401) + + try: + body = await request.json() + except Exception: + return _json_error("Invalid JSON", status=400) + + bind_uuid = str(body.get("bind_uuid", "") or "").strip() + naga_id = str(body.get("naga_id", "") or "").strip() + delivery_signature = str(body.get("delivery_signature", "") or "").strip() + request_uuid = str(body.get("uuid", "") or "").strip() + target = body.get("target") + message = body.get("message") + if not bind_uuid or not naga_id or not delivery_signature: + return _json_error( + "bind_uuid, naga_id and delivery_signature are required", + status=400, + ) + if not isinstance(target, dict): + return _json_error("target object is required", status=400) + if not isinstance(message, dict): + return _json_error("message object is required", status=400) + + raw_target_qq = target.get("qq_id") + raw_target_group = target.get("group_id") + if raw_target_qq is None or raw_target_group is None: + return _json_error("target.qq_id and target.group_id are required", status=400) + try: + target_qq = int(raw_target_qq) + target_group = int(raw_target_group) + except Exception: + return _json_error( + "target.qq_id and target.group_id must be integers", status=400 + ) + mode = str(target.get("mode", "") or "").strip().lower() + if mode not in {"private", "group", "both"}: + return _json_error( + "target.mode must be 'private', 'group', or 'both'", status=400 + ) + + fmt = str(message.get("format", "text") or "text").strip().lower() + content = str(message.get("content", "") or "").strip() + if fmt not in {"text", "markdown", "html"}: + return _json_error( + "message.format must be 'text', 'markdown', or 'html'", status=400 + ) + if not content: + return _json_error("message.content is required", status=400) + + message_key = _naga_message_digest( + bind_uuid=bind_uuid, + naga_id=naga_id, + target_qq=target_qq, + target_group=target_group, + mode=mode, + message_format=fmt, + content=content, + ) + logger.info( + "[NagaSend] 请求开始: trace=%s remote=%s naga_id=%s bind_uuid=%s request_uuid=%s mode=%s fmt=%s qq=%s group=%s key=%s content_len=%s preview=%s signature=%s", + trace_id, + getattr(request, "remote", None), + naga_id, + bind_uuid, + request_uuid, + mode, + fmt, + target_qq, + target_group, + message_key, + len(content), + _short_text_preview(content), + mask_token(delivery_signature), + ) + if mode == "both": + logger.warning( + "[NagaSend] 上游请求显式要求双路投递: trace=%s naga_id=%s bind_uuid=%s request_uuid=%s key=%s", + trace_id, + naga_id, + bind_uuid, + request_uuid, + message_key, + ) + inflight_count = await naga_state.track_send_start(message_key) + if inflight_count > 1: + logger.warning( + "[NagaSend] 检测到相同 payload 并发请求: trace=%s naga_id=%s bind_uuid=%s request_uuid=%s key=%s inflight=%s", + trace_id, + naga_id, + bind_uuid, + request_uuid, + message_key, + inflight_count, + ) + try: + if request_uuid: + dedupe_action, dedupe_value = await naga_state.register_request_uuid( + request_uuid, message_key + ) + if dedupe_action == "conflict": + logger.warning( + "[NagaSend] uuid 与历史 payload 冲突: trace=%s naga_id=%s bind_uuid=%s uuid=%s key=%s", + trace_id, + naga_id, + bind_uuid, + request_uuid, + message_key, + ) + return _json_error("uuid reused with different payload", status=409) + if dedupe_action == "cached": + cached_status, cached_payload = dedupe_value + logger.warning( + "[NagaSend] 命中已完成幂等结果,直接复用: trace=%s naga_id=%s bind_uuid=%s request_uuid=%s key=%s", + trace_id, + naga_id, + bind_uuid, + request_uuid, + message_key, + ) + return web.json_response( + deepcopy(cached_payload), + status=int(cached_status), + ) + if dedupe_action == "await": + wait_future = dedupe_value + logger.warning( + "[NagaSend] 命中进行中幂等请求,等待首个结果: trace=%s naga_id=%s bind_uuid=%s request_uuid=%s key=%s", + trace_id, + naga_id, + bind_uuid, + request_uuid, + message_key, + ) + cached_status, cached_payload = await wait_future + return web.json_response( + deepcopy(cached_payload), + status=int(cached_status), + ) + + response = await naga_messages_send_impl( + ctx, + naga_id=naga_id, + bind_uuid=bind_uuid, + delivery_signature=delivery_signature, + target_qq=target_qq, + target_group=target_group, + mode=mode, + message_format=fmt, + content=content, + trace_id=trace_id, + message_key=message_key, + ) + if request_uuid: + await naga_state.finish_request_uuid( + request_uuid, + message_key, + status=response.status, + payload=_parse_response_payload(response), + ) + return response + except Exception as exc: + if request_uuid: + await naga_state.fail_request_uuid(request_uuid, message_key, exc) + raise + finally: + remaining = await naga_state.track_send_done(message_key) + logger.info( + "[NagaSend] 请求退出: trace=%s naga_id=%s bind_uuid=%s request_uuid=%s key=%s inflight_remaining=%s", + trace_id, + naga_id, + bind_uuid, + request_uuid, + message_key, + remaining, + ) + + +# ------------------------------------------------------------------ +# Core send implementation (no NagaState dependency) +# ------------------------------------------------------------------ + + +async def naga_messages_send_impl( + ctx: RuntimeAPIContext, + *, + naga_id: str, + bind_uuid: str, + delivery_signature: str, + target_qq: int, + target_group: int, + mode: str, + message_format: str, + content: str, + trace_id: str, + message_key: str, +) -> Response: + from Undefined.api.naga_store import mask_token + + naga_store = ctx.naga_store + if naga_store is None: + logger.warning( + "[NagaSend] NagaStore 不可用: trace=%s naga_id=%s bind_uuid=%s", + trace_id, + naga_id, + bind_uuid, + ) + return _json_error("Naga integration not available", status=503) + + binding, err_msg = await naga_store.acquire_delivery( + naga_id=naga_id, + bind_uuid=bind_uuid, + delivery_signature=delivery_signature, + ) + if binding is None: + logger.warning( + "[NagaSend] 签名校验失败: trace=%s naga_id=%s bind_uuid=%s reason=%s signature=%s", + trace_id, + naga_id, + bind_uuid, + err_msg.message if err_msg is not None else "unknown_error", + mask_token(delivery_signature), + ) + return _json_error( + err_msg.message if err_msg is not None else "delivery not available", + status=err_msg.http_status if err_msg is not None else 403, + ) + + logger.info( + "[NagaSend] 投递凭证已占用: trace=%s naga_id=%s bind_uuid=%s key=%s qq=%s group=%s", + trace_id, + naga_id, + bind_uuid, + message_key, + binding.qq_id, + binding.group_id, + ) + try: + if target_qq != binding.qq_id or target_group != binding.group_id: + logger.warning( + "[NagaSend] 目标不匹配: trace=%s naga_id=%s bind_uuid=%s target_qq=%s target_group=%s bound_qq=%s bound_group=%s", + trace_id, + naga_id, + bind_uuid, + target_qq, + target_group, + binding.qq_id, + binding.group_id, + ) + return _json_error("target does not match bound qq/group", status=403) + + cfg = ctx.config_getter() + if mode == "group" and binding.group_id not in cfg.naga.allowed_groups: + logger.warning( + "[NagaSend] 群投递被策略拒绝: trace=%s naga_id=%s bind_uuid=%s group=%s", + trace_id, + naga_id, + bind_uuid, + binding.group_id, + ) + return _json_error("bound group is not in naga.allowed_groups", status=403) + + sender = ctx.sender + if sender is None: + logger.warning( + "[NagaSend] sender 不可用: trace=%s naga_id=%s bind_uuid=%s", + trace_id, + naga_id, + bind_uuid, + ) + return _json_error("sender not available", status=503) + + moderation: dict[str, Any] + naga_cfg = getattr(cfg, "naga", None) + moderation_enabled = bool(getattr(naga_cfg, "moderation_enabled", True)) + security = getattr(ctx.command_dispatcher, "security", None) + if not moderation_enabled: + moderation = { + "status": "skipped_disabled", + "blocked": False, + "categories": [], + "message": "Naga moderation disabled by config; message sent without moderation block", + "model_name": "", + } + logger.warning( + "[NagaSend] 审核已禁用,直接放行: trace=%s naga_id=%s bind_uuid=%s key=%s", + trace_id, + naga_id, + bind_uuid, + message_key, + ) + elif security is None or not hasattr(security, "moderate_naga_message"): + moderation = { + "status": "error_allowed", + "blocked": False, + "categories": [], + "message": "Naga moderation service unavailable; message sent without moderation block", + "model_name": "", + } + logger.warning( + "[NagaSend] 审核服务不可用,按允许发送: trace=%s naga_id=%s bind_uuid=%s", + trace_id, + naga_id, + bind_uuid, + ) + else: + logger.info( + "[NagaSend] 审核开始: trace=%s naga_id=%s bind_uuid=%s key=%s fmt=%s content_len=%s", + trace_id, + naga_id, + bind_uuid, + message_key, + message_format, + len(content), + ) + result = await security.moderate_naga_message( + message_format=message_format, + content=content, + ) + moderation = { + "status": result.status, + "blocked": result.blocked, + "categories": result.categories, + "message": result.message, + "model_name": result.model_name, + } + logger.info( + "[NagaSend] 审核完成: trace=%s naga_id=%s bind_uuid=%s key=%s blocked=%s status=%s model=%s categories=%s", + trace_id, + naga_id, + bind_uuid, + message_key, + result.blocked, + result.status, + result.model_name, + ",".join(result.categories) or "-", + ) + if moderation["blocked"]: + logger.warning( + "[NagaSend] 审核拦截: trace=%s naga_id=%s bind_uuid=%s key=%s reason=%s", + trace_id, + naga_id, + bind_uuid, + message_key, + moderation["message"], + ) + return web.json_response( + { + "ok": False, + "error": "message blocked by moderation", + "moderation": moderation, + }, + status=403, + ) + + send_content: str | None = content if message_format == "text" else None + image_path: str | None = None + tmp_path: str | None = None + rendered = False + render_fallback = False + if message_format in {"markdown", "html"}: + import tempfile + + try: + html_str = content + if message_format == "markdown": + html_str = await render_markdown_to_html(content) + fd, tmp_path = tempfile.mkstemp(suffix=".png", prefix="naga_send_") + os.close(fd) + await render_html_to_image(html_str, tmp_path) + image_path = tmp_path + rendered = True + logger.info( + "[NagaSend] 富文本渲染成功: trace=%s naga_id=%s bind_uuid=%s key=%s fmt=%s image=%s", + trace_id, + naga_id, + bind_uuid, + message_key, + message_format, + Path(tmp_path).name if tmp_path is not None else "", + ) + except Exception as exc: + logger.warning( + "[NagaSend] 渲染失败,回退文本发送: trace=%s naga_id=%s bind_uuid=%s key=%s err=%s", + trace_id, + naga_id, + bind_uuid, + message_key, + exc, + ) + send_content = content + render_fallback = True + + sent_private = False + sent_group = False + group_policy_blocked = False + + async def _ensure_delivery_active() -> tuple[Any, Response | None]: + current_binding, live_err = await naga_store.ensure_delivery_active( + naga_id=naga_id, + bind_uuid=bind_uuid, + ) + if current_binding is None: + logger.warning( + "[NagaSend] 投递中止: trace=%s naga_id=%s bind_uuid=%s key=%s reason=%s", + trace_id, + naga_id, + bind_uuid, + message_key, + live_err.message + if live_err is not None + else "delivery no longer active", + ) + return None, web.json_response( + { + "ok": False, + "error": ( + live_err.message + if live_err is not None + else "delivery no longer active" + ), + "sent_private": sent_private, + "sent_group": sent_group, + "moderation": moderation, + }, + status=live_err.http_status if live_err is not None else 409, + ) + return current_binding, None + + try: + cq_image: str | None = None + if image_path is not None: + file_uri = Path(image_path).resolve().as_uri() + cq_image = f"[CQ:image,file={file_uri}]" + + if mode in {"private", "both"}: + current_binding, abort_response = await _ensure_delivery_active() + if abort_response is not None: + return abort_response + logger.info( + "[NagaSend] 私聊投递开始: trace=%s naga_id=%s bind_uuid=%s key=%s qq=%s", + trace_id, + naga_id, + bind_uuid, + message_key, + current_binding.qq_id, + ) + try: + if send_content is not None: + await sender.send_private_message( + current_binding.qq_id, send_content + ) + elif cq_image is not None: + await sender.send_private_message( + current_binding.qq_id, cq_image + ) + sent_private = True + logger.info( + "[NagaSend] 私聊投递成功: trace=%s naga_id=%s bind_uuid=%s key=%s qq=%s", + trace_id, + naga_id, + bind_uuid, + message_key, + current_binding.qq_id, + ) + except Exception as exc: + logger.warning( + "[NagaSend] 私聊发送失败: trace=%s naga_id=%s qq=%d key=%s err=%s", + trace_id, + naga_id, + current_binding.qq_id, + message_key, + exc, + ) + + if mode in {"group", "both"}: + current_binding, abort_response = await _ensure_delivery_active() + if abort_response is not None: + return abort_response + current_cfg = ctx.config_getter() + if current_binding.group_id not in current_cfg.naga.allowed_groups: + group_policy_blocked = True + logger.warning( + "[NagaSend] 群投递被策略阻止: trace=%s naga_id=%s bind_uuid=%s key=%s group=%s", + trace_id, + naga_id, + bind_uuid, + message_key, + current_binding.group_id, + ) + else: + logger.info( + "[NagaSend] 群投递开始: trace=%s naga_id=%s bind_uuid=%s key=%s group=%s", + trace_id, + naga_id, + bind_uuid, + message_key, + current_binding.group_id, + ) + try: + if send_content is not None: + await sender.send_group_message( + current_binding.group_id, send_content + ) + elif cq_image is not None: + await sender.send_group_message( + current_binding.group_id, cq_image + ) + sent_group = True + logger.info( + "[NagaSend] 群投递成功: trace=%s naga_id=%s bind_uuid=%s key=%s group=%s", + trace_id, + naga_id, + bind_uuid, + message_key, + current_binding.group_id, + ) + except Exception as exc: + logger.warning( + "[NagaSend] 群聊发送失败: trace=%s naga_id=%s group=%d key=%s err=%s", + trace_id, + naga_id, + current_binding.group_id, + message_key, + exc, + ) + finally: + if tmp_path is not None: + try: + os.unlink(tmp_path) + except OSError: + pass + + if mode == "private" and not sent_private: + return web.json_response( + { + "ok": False, + "error": "private delivery failed", + "sent_private": sent_private, + "sent_group": sent_group, + "moderation": moderation, + }, + status=502, + ) + if mode == "group" and not sent_group: + return web.json_response( + { + "ok": False, + "error": "group delivery failed", + "sent_private": sent_private, + "sent_group": sent_group, + "moderation": moderation, + }, + status=502, + ) + if mode == "both" and not (sent_private or sent_group): + if group_policy_blocked: + return web.json_response( + { + "ok": False, + "error": "bound group is not in naga.allowed_groups", + "sent_private": sent_private, + "sent_group": sent_group, + "moderation": moderation, + }, + status=403, + ) + return web.json_response( + { + "ok": False, + "error": "all deliveries failed", + "sent_private": sent_private, + "sent_group": sent_group, + "moderation": moderation, + }, + status=502, + ) + + await naga_store.record_usage(naga_id, bind_uuid=bind_uuid) + partial_success = mode == "both" and (sent_private != sent_group) + logger.info( + "[NagaSend] 请求完成: trace=%s naga_id=%s bind_uuid=%s key=%s sent_private=%s sent_group=%s partial=%s rendered=%s fallback=%s", + trace_id, + naga_id, + bind_uuid, + message_key, + sent_private, + sent_group, + partial_success, + rendered, + render_fallback, + ) + return web.json_response( + { + "ok": True, + "naga_id": naga_id, + "bind_uuid": bind_uuid, + "sent_private": sent_private, + "sent_group": sent_group, + "partial_success": partial_success, + "delivery_status": ( + "partial_success" if partial_success else "full_success" + ), + "rendered": rendered, + "render_fallback": render_fallback, + "moderation": moderation, + } + ) + finally: + await naga_store.release_delivery(bind_uuid=bind_uuid) + + +# ------------------------------------------------------------------ +# POST /api/v1/naga/unbind +# ------------------------------------------------------------------ + + +async def naga_unbind_handler(ctx: RuntimeAPIContext, request: web.Request) -> Response: + """POST /api/v1/naga/unbind — 远端主动解绑。""" + trace_id = _uuid.uuid4().hex[:8] + auth_err = verify_naga_api_key(ctx, request) + if auth_err is not None: + logger.warning( + "[NagaUnbind] 鉴权失败: trace=%s remote=%s err=%s", + trace_id, + getattr(request, "remote", None), + auth_err, + ) + return _json_error("Unauthorized", status=401) + + try: + body = await request.json() + except Exception: + return _json_error("Invalid JSON", status=400) + + bind_uuid = str(body.get("bind_uuid", "") or "").strip() + naga_id = str(body.get("naga_id", "") or "").strip() + delivery_signature = str(body.get("delivery_signature", "") or "").strip() + if not bind_uuid or not naga_id or not delivery_signature: + return _json_error( + "bind_uuid, naga_id and delivery_signature are required", + status=400, + ) + logger.info( + "[NagaUnbind] 请求开始: trace=%s remote=%s naga_id=%s bind_uuid=%s signature=%s", + trace_id, + getattr(request, "remote", None), + naga_id, + bind_uuid, + delivery_signature[:12] + "...", + ) + + naga_store = ctx.naga_store + if naga_store is None: + return _json_error("Naga integration not available", status=503) + + binding, changed, err = await naga_store.revoke_binding( + naga_id, + expected_bind_uuid=bind_uuid, + delivery_signature=delivery_signature, + ) + if binding is None: + logger.warning( + "[NagaUnbind] 吊销失败: trace=%s naga_id=%s bind_uuid=%s err=%s", + trace_id, + naga_id, + bind_uuid, + err.message if err is not None else "binding not found", + ) + return _json_error( + err.message if err is not None else "binding not found", + status=err.http_status if err is not None else 404, + ) + logger.info( + "[NagaUnbind] 吊销完成: trace=%s naga_id=%s bind_uuid=%s changed=%s qq=%s group=%s", + trace_id, + naga_id, + bind_uuid, + changed, + binding.qq_id, + binding.group_id, + ) + return web.json_response( + { + "ok": True, + "idempotent": not changed, + "naga_id": naga_id, + "bind_uuid": bind_uuid, + } + ) diff --git a/src/Undefined/api/routes/system.py b/src/Undefined/api/routes/system.py new file mode 100644 index 00000000..2cf3915f --- /dev/null +++ b/src/Undefined/api/routes/system.py @@ -0,0 +1,228 @@ +"""System / probe route handlers for the Runtime API.""" + +from __future__ import annotations + +import asyncio +import logging +import platform +import sys +import time +from datetime import datetime +from typing import Any + +from aiohttp import web +from aiohttp.web_response import Response + +from Undefined import __version__ +from Undefined.api._context import RuntimeAPIContext +from Undefined.api._helpers import ( + _VIRTUAL_USER_ID, + _json_error, + _mask_url, + _naga_routes_enabled, + _registry_summary, +) +from Undefined.api._openapi import _build_openapi_spec +from Undefined.api._probes import ( + _build_internal_model_probe_payload, + _probe_http_endpoint, + _probe_ws_endpoint, + _skipped_probe, +) + +logger = logging.getLogger(__name__) + +_PROCESS_START_TIME = time.time() + + +async def openapi_handler(ctx: RuntimeAPIContext, request: web.Request) -> Response: + cfg = ctx.config_getter() + if not bool(getattr(cfg.api, "openapi_enabled", True)): + logger.info( + "[RuntimeAPI] OpenAPI 请求被拒绝: disabled remote=%s", request.remote + ) + return _json_error("OpenAPI disabled", status=404) + naga_routes_enabled = _naga_routes_enabled(cfg, ctx.naga_store) + logger.info( + "[RuntimeAPI] OpenAPI 请求: remote=%s naga_routes_enabled=%s", + request.remote, + naga_routes_enabled, + ) + return web.json_response(_build_openapi_spec(ctx, request)) + + +async def internal_probe_handler( + ctx: RuntimeAPIContext, request: web.Request +) -> Response: + _ = request + cfg = ctx.config_getter() + queue_snapshot = ctx.queue_manager.snapshot() if ctx.queue_manager else {} + cognitive_queue_snapshot = ( + ctx.cognitive_job_queue.snapshot() if ctx.cognitive_job_queue else {} + ) + memory_storage = getattr(ctx.ai, "memory_storage", None) + memory_count = memory_storage.count() if memory_storage is not None else 0 + + # Skills 统计 + ai = ctx.ai + skills_info: dict[str, Any] = {} + if ai is not None: + tool_reg = getattr(ai, "tool_registry", None) + agent_reg = getattr(ai, "agent_registry", None) + anthropic_reg = getattr(ai, "anthropic_skill_registry", None) + skills_info["tools"] = _registry_summary(tool_reg) + skills_info["agents"] = _registry_summary(agent_reg) + skills_info["anthropic_skills"] = _registry_summary(anthropic_reg) + + # 模型配置(脱敏) + models_info: dict[str, Any] = {} + summary_model = getattr( + cfg, + "summary_model", + getattr(cfg, "agent_model", getattr(cfg, "chat_model", None)), + ) + for label in ( + "chat_model", + "vision_model", + "agent_model", + "security_model", + "naga_model", + "grok_model", + ): + mcfg = getattr(cfg, label, None) + if mcfg is not None: + models_info[label] = _build_internal_model_probe_payload(mcfg) + if summary_model is not None: + models_info["summary_model"] = _build_internal_model_probe_payload( + summary_model + ) + for label in ("embedding_model", "rerank_model"): + mcfg = getattr(cfg, label, None) + if mcfg is not None: + models_info[label] = { + "model_name": getattr(mcfg, "model_name", ""), + "api_url": _mask_url(getattr(mcfg, "api_url", "")), + } + + uptime_seconds = round(time.time() - _PROCESS_START_TIME, 1) + payload = { + "timestamp": datetime.now().isoformat(), + "version": __version__, + "python": f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}", + "platform": platform.system(), + "uptime_seconds": uptime_seconds, + "onebot": ctx.onebot.connection_status() if ctx.onebot is not None else {}, + "queues": queue_snapshot, + "memory": {"count": memory_count, "virtual_user_id": _VIRTUAL_USER_ID}, + "cognitive": { + "enabled": bool(ctx.cognitive_service and ctx.cognitive_service.enabled), + "queue": cognitive_queue_snapshot, + }, + "api": { + "enabled": bool(cfg.api.enabled), + "host": cfg.api.host, + "port": cfg.api.port, + "openapi_enabled": bool(cfg.api.openapi_enabled), + }, + "skills": skills_info, + "models": models_info, + } + return web.json_response(payload) + + +async def external_probe_handler( + ctx: RuntimeAPIContext, request: web.Request +) -> Response: + _ = request + cfg = ctx.config_getter() + summary_model = getattr( + cfg, + "summary_model", + getattr(cfg, "agent_model", getattr(cfg, "chat_model", None)), + ) + naga_probe = ( + _probe_http_endpoint( + name="naga_model", + base_url=cfg.naga_model.api_url, + api_key=cfg.naga_model.api_key, + model_name=cfg.naga_model.model_name, + ) + if bool(cfg.api.enabled and cfg.nagaagent_mode_enabled and cfg.naga.enabled) + else _skipped_probe( + name="naga_model", + reason="naga_integration_disabled", + model_name=cfg.naga_model.model_name, + ) + ) + checks = [ + _probe_http_endpoint( + name="chat_model", + base_url=cfg.chat_model.api_url, + api_key=cfg.chat_model.api_key, + model_name=cfg.chat_model.model_name, + ), + _probe_http_endpoint( + name="vision_model", + base_url=cfg.vision_model.api_url, + api_key=cfg.vision_model.api_key, + model_name=cfg.vision_model.model_name, + ), + _probe_http_endpoint( + name="security_model", + base_url=cfg.security_model.api_url, + api_key=cfg.security_model.api_key, + model_name=cfg.security_model.model_name, + ), + naga_probe, + _probe_http_endpoint( + name="agent_model", + base_url=cfg.agent_model.api_url, + api_key=cfg.agent_model.api_key, + model_name=cfg.agent_model.model_name, + ), + ] + if summary_model is not None: + checks.append( + _probe_http_endpoint( + name="summary_model", + base_url=summary_model.api_url, + api_key=summary_model.api_key, + model_name=summary_model.model_name, + ) + ) + grok_model = getattr(cfg, "grok_model", None) + if grok_model is not None: + checks.append( + _probe_http_endpoint( + name="grok_model", + base_url=getattr(grok_model, "api_url", ""), + api_key=getattr(grok_model, "api_key", ""), + model_name=getattr(grok_model, "model_name", ""), + ) + ) + checks.extend( + [ + _probe_http_endpoint( + name="embedding_model", + base_url=cfg.embedding_model.api_url, + api_key=cfg.embedding_model.api_key, + model_name=getattr(cfg.embedding_model, "model_name", ""), + ), + _probe_http_endpoint( + name="rerank_model", + base_url=cfg.rerank_model.api_url, + api_key=cfg.rerank_model.api_key, + model_name=getattr(cfg.rerank_model, "model_name", ""), + ), + _probe_ws_endpoint(cfg.onebot_ws_url), + ] + ) + results = await asyncio.gather(*checks) + ok = all(item.get("status") in {"ok", "skipped"} for item in results) + return web.json_response( + { + "ok": ok, + "timestamp": datetime.now().isoformat(), + "results": results, + } + ) diff --git a/src/Undefined/api/routes/tools.py b/src/Undefined/api/routes/tools.py new file mode 100644 index 00000000..f305179e --- /dev/null +++ b/src/Undefined/api/routes/tools.py @@ -0,0 +1,462 @@ +"""Tools route handlers for the Runtime API.""" + +from __future__ import annotations + +import asyncio +import logging +import time +import uuid as _uuid +from typing import Any, Awaitable + +from aiohttp import ClientSession, ClientTimeout, web +from aiohttp.web_response import Response + +from Undefined.api._context import RuntimeAPIContext +from Undefined.api._helpers import ( + _ToolInvokeExecutionTimeoutError, + _json_error, + _mask_url, + _to_bool, + _validate_callback_url, +) +from Undefined.context import RequestContext + +logger = logging.getLogger(__name__) + + +# ------------------------------------------------------------------ +# Pure helpers +# ------------------------------------------------------------------ + + +def get_filtered_tools(ctx: RuntimeAPIContext) -> list[dict[str, Any]]: + """按配置过滤可用工具,返回 OpenAI function calling schema 列表。""" + cfg = ctx.config_getter() + api_cfg = cfg.api + ai = ctx.ai + if ai is None: + return [] + + tool_reg = getattr(ai, "tool_registry", None) + agent_reg = getattr(ai, "agent_registry", None) + + all_schemas: list[dict[str, Any]] = [] + if tool_reg is not None: + all_schemas.extend(tool_reg.get_tools_schema()) + + # 收集 agent schema 并缓存名称集合(避免重复调用) + agent_names: set[str] = set() + if agent_reg is not None: + agent_schemas = agent_reg.get_agents_schema() + all_schemas.extend(agent_schemas) + for schema in agent_schemas: + func = schema.get("function", {}) + name = str(func.get("name", "")) + if name: + agent_names.add(name) + + denylist: set[str] = set(api_cfg.tool_invoke_denylist) + allowlist: set[str] = set(api_cfg.tool_invoke_allowlist) + expose = api_cfg.tool_invoke_expose + + def _get_name(schema: dict[str, Any]) -> str: + func = schema.get("function", {}) + return str(func.get("name", "")) + + # 1. 先排除黑名单 + if denylist: + all_schemas = [s for s in all_schemas if _get_name(s) not in denylist] + + # 2. 白名单非空时仅保留匹配项 + if allowlist: + return [s for s in all_schemas if _get_name(s) in allowlist] + + # 3. 按 expose 过滤 + if expose == "all": + return all_schemas + + def _is_tool(name: str) -> bool: + return "." not in name and name not in agent_names + + def _is_toolset(name: str) -> bool: + return "." in name and not name.startswith("mcp.") + + filtered: list[dict[str, Any]] = [] + for schema in all_schemas: + name = _get_name(schema) + if not name: + continue + if expose == "tools" and _is_tool(name): + filtered.append(schema) + elif expose == "toolsets" and _is_toolset(name): + filtered.append(schema) + elif expose == "tools+toolsets" and (_is_tool(name) or _is_toolset(name)): + filtered.append(schema) + elif expose == "agents" and name in agent_names: + filtered.append(schema) + + return filtered + + +def get_agent_tool_names(ctx: RuntimeAPIContext) -> set[str]: + ai = ctx.ai + if ai is None: + return set() + + agent_reg = getattr(ai, "agent_registry", None) + if agent_reg is None: + return set() + + agent_names: set[str] = set() + for schema in agent_reg.get_agents_schema(): + func = schema.get("function", {}) + name = str(func.get("name", "")) + if name: + agent_names.add(name) + return agent_names + + +def resolve_tool_invoke_timeout( + ctx: RuntimeAPIContext, tool_name: str, timeout: int +) -> float | None: + if tool_name in get_agent_tool_names(ctx): + return None + return float(timeout) + + +# ------------------------------------------------------------------ +# Async helpers +# ------------------------------------------------------------------ + + +async def await_tool_invoke_result( + awaitable: Awaitable[Any], + *, + timeout: float | None, +) -> Any: + if timeout is None or timeout <= 0: + return await awaitable + try: + return await asyncio.wait_for(awaitable, timeout=timeout) + except asyncio.TimeoutError as exc: + raise _ToolInvokeExecutionTimeoutError from exc + + +# ------------------------------------------------------------------ +# Route handlers +# ------------------------------------------------------------------ + + +async def tools_list_handler(ctx: RuntimeAPIContext, request: web.Request) -> Response: + _ = request + cfg = ctx.config_getter() + if not cfg.api.tool_invoke_enabled: + return _json_error("Tool invoke API is disabled", status=403) + + tools = get_filtered_tools(ctx) + return web.json_response({"count": len(tools), "tools": tools}) + + +async def tools_invoke_handler( + ctx: RuntimeAPIContext, + background_tasks: set[asyncio.Task[Any]], + request: web.Request, +) -> Response: + cfg = ctx.config_getter() + if not cfg.api.tool_invoke_enabled: + return _json_error("Tool invoke API is disabled", status=403) + + try: + body = await request.json() + except Exception: + return _json_error("Invalid JSON", status=400) + + if not isinstance(body, dict): + return _json_error("Request body must be a JSON object", status=400) + + tool_name = str(body.get("tool_name", "") or "").strip() + if not tool_name: + return _json_error("tool_name is required", status=400) + + args = body.get("args") + if not isinstance(args, dict): + return _json_error("args must be a JSON object", status=400) + + # 验证工具是否在允许列表中 + filtered_tools = get_filtered_tools(ctx) + available_names: set[str] = set() + for schema in filtered_tools: + func = schema.get("function", {}) + name = str(func.get("name", "")) + if name: + available_names.add(name) + + if tool_name not in available_names: + caller_ip = request.remote or "unknown" + logger.warning( + "[ToolInvoke] 请求拒绝: tool=%s reason=not_available caller_ip=%s", + tool_name, + caller_ip, + ) + return _json_error(f"Tool '{tool_name}' is not available", status=404) + + # 解析回调配置 + callback_cfg = body.get("callback") + use_callback = False + callback_url = "" + callback_headers: dict[str, str] = {} + if isinstance(callback_cfg, dict) and _to_bool(callback_cfg.get("enabled")): + callback_url = str(callback_cfg.get("url", "") or "").strip() + if not callback_url: + return _json_error( + "callback.url is required when callback is enabled", + status=400, + ) + url_error = _validate_callback_url(callback_url) + if url_error: + return _json_error(url_error, status=400) + raw_headers = callback_cfg.get("headers") + if isinstance(raw_headers, dict): + callback_headers = {str(k): str(v) for k, v in raw_headers.items()} + use_callback = True + + request_id = _uuid.uuid4().hex + caller_ip = request.remote or "unknown" + logger.info( + "[ToolInvoke] 收到请求: request_id=%s tool=%s caller_ip=%s", + request_id, + tool_name, + caller_ip, + ) + + if use_callback: + # 异步执行 + 回调 + task = asyncio.create_task( + execute_and_callback( + ctx, + request_id=request_id, + tool_name=tool_name, + args=args, + body_context=body.get("context"), + callback_url=callback_url, + callback_headers=callback_headers, + timeout=cfg.api.tool_invoke_timeout, + callback_timeout=cfg.api.tool_invoke_callback_timeout, + ) + ) + background_tasks.add(task) + task.add_done_callback(background_tasks.discard) + return web.json_response( + { + "ok": True, + "request_id": request_id, + "tool_name": tool_name, + "status": "accepted", + } + ) + + # 同步执行 + result = await execute_tool_invoke( + ctx, + request_id=request_id, + tool_name=tool_name, + args=args, + body_context=body.get("context"), + timeout=cfg.api.tool_invoke_timeout, + ) + return web.json_response(result) + + +# ------------------------------------------------------------------ +# Execution core +# ------------------------------------------------------------------ + + +async def execute_tool_invoke( + ctx: RuntimeAPIContext, + *, + request_id: str, + tool_name: str, + args: dict[str, Any], + body_context: Any, + timeout: int, +) -> dict[str, Any]: + """执行工具调用并返回结果字典。""" + ai = ctx.ai + if ai is None: + return { + "ok": False, + "request_id": request_id, + "tool_name": tool_name, + "error": "AI client not ready", + "duration_ms": 0, + } + + # 解析请求上下文 + ctx_data: dict[str, Any] = {} + if isinstance(body_context, dict): + ctx_data = body_context + request_type = str(ctx_data.get("request_type", "api") or "api") + group_id = ctx_data.get("group_id") + user_id = ctx_data.get("user_id") + sender_id = ctx_data.get("sender_id") + + args_keys = list(args.keys()) + logger.info( + "[ToolInvoke] 开始执行: request_id=%s tool=%s args_keys=%s", + request_id, + tool_name, + args_keys, + ) + + start = time.perf_counter() + effective_timeout = resolve_tool_invoke_timeout(ctx, tool_name, timeout) + try: + async with RequestContext( + request_type=request_type, + group_id=int(group_id) if group_id is not None else None, + user_id=int(user_id) if user_id is not None else None, + sender_id=int(sender_id) if sender_id is not None else None, + ) as req_ctx: + # 注入核心服务资源 + if ctx.sender is not None: + req_ctx.set_resource("sender", ctx.sender) + if ctx.history_manager is not None: + req_ctx.set_resource("history_manager", ctx.history_manager) + runtime_config = getattr(ai, "runtime_config", None) + if runtime_config is not None: + req_ctx.set_resource("runtime_config", runtime_config) + memory_storage = getattr(ai, "memory_storage", None) + if memory_storage is not None: + req_ctx.set_resource("memory_storage", memory_storage) + if ctx.onebot is not None: + req_ctx.set_resource("onebot_client", ctx.onebot) + if ctx.scheduler is not None: + req_ctx.set_resource("scheduler", ctx.scheduler) + if ctx.cognitive_service is not None: + req_ctx.set_resource("cognitive_service", ctx.cognitive_service) + if ctx.meme_service is not None: + req_ctx.set_resource("meme_service", ctx.meme_service) + + tool_context: dict[str, Any] = { + "request_type": request_type, + "request_id": request_id, + } + if group_id is not None: + tool_context["group_id"] = int(group_id) + if user_id is not None: + tool_context["user_id"] = int(user_id) + if sender_id is not None: + tool_context["sender_id"] = int(sender_id) + + tool_manager = getattr(ai, "tool_manager", None) + if tool_manager is None: + raise RuntimeError("ToolManager not available") + + raw_result = await await_tool_invoke_result( + tool_manager.execute_tool(tool_name, args, tool_context), + timeout=effective_timeout, + ) + + elapsed_ms = round((time.perf_counter() - start) * 1000, 1) + result_str = str(raw_result or "") + logger.info( + "[ToolInvoke] 执行完成: request_id=%s tool=%s ok=true " + "duration_ms=%s result_len=%d", + request_id, + tool_name, + elapsed_ms, + len(result_str), + ) + return { + "ok": True, + "request_id": request_id, + "tool_name": tool_name, + "result": result_str, + "duration_ms": elapsed_ms, + } + + except _ToolInvokeExecutionTimeoutError: + elapsed_ms = round((time.perf_counter() - start) * 1000, 1) + logger.warning( + "[ToolInvoke] 执行超时: request_id=%s tool=%s timeout=%ds", + request_id, + tool_name, + timeout, + ) + return { + "ok": False, + "request_id": request_id, + "tool_name": tool_name, + "error": f"Execution timed out after {timeout}s", + "duration_ms": elapsed_ms, + } + except Exception as exc: + elapsed_ms = round((time.perf_counter() - start) * 1000, 1) + logger.exception( + "[ToolInvoke] 执行失败: request_id=%s tool=%s error=%s", + request_id, + tool_name, + exc, + ) + return { + "ok": False, + "request_id": request_id, + "tool_name": tool_name, + "error": str(exc), + "duration_ms": elapsed_ms, + } + + +async def execute_and_callback( + ctx: RuntimeAPIContext, + *, + request_id: str, + tool_name: str, + args: dict[str, Any], + body_context: Any, + callback_url: str, + callback_headers: dict[str, str], + timeout: int, + callback_timeout: int, +) -> None: + """异步执行工具并发送回调。""" + result = await execute_tool_invoke( + ctx, + request_id=request_id, + tool_name=tool_name, + args=args, + body_context=body_context, + timeout=timeout, + ) + + payload = { + "request_id": result["request_id"], + "tool_name": result["tool_name"], + "ok": result["ok"], + "result": result.get("result"), + "duration_ms": result.get("duration_ms", 0), + "error": result.get("error"), + } + + try: + cb_timeout = ClientTimeout(total=callback_timeout) + async with ClientSession(timeout=cb_timeout) as session: + async with session.post( + callback_url, + json=payload, + headers=callback_headers or None, + ) as resp: + logger.info( + "[ToolInvoke] 回调发送: request_id=%s url=%s status=%d", + request_id, + _mask_url(callback_url), + resp.status, + ) + except Exception as exc: + logger.warning( + "[ToolInvoke] 回调失败: request_id=%s url=%s error=%s", + request_id, + _mask_url(callback_url), + exc, + ) diff --git a/src/Undefined/arxiv/sender.py b/src/Undefined/arxiv/sender.py index 214ccebb..1cecf06b 100644 --- a/src/Undefined/arxiv/sender.py +++ b/src/Undefined/arxiv/sender.py @@ -4,6 +4,7 @@ import asyncio import logging +import time from typing import TYPE_CHECKING, Literal from Undefined.arxiv.client import get_paper_info @@ -19,6 +20,31 @@ _INFLIGHT_LOCK = asyncio.Lock() _INFLIGHT_SENDS: dict[tuple[str, int, str], asyncio.Future[str]] = {} +# Time-based dedup: maps (target_type, target_id, paper_id) → monotonic timestamp +_RECENT_SENDS: dict[tuple[str, int, str], float] = {} +_DEDUP_COOLDOWN_SECONDS: float = 3600.0 # 1 hour +_RECENT_SENDS_MAX_SIZE: int = 1000 + + +def _cleanup_expired_recent_sends() -> None: + """Remove expired entries from _RECENT_SENDS. Must be called under _INFLIGHT_LOCK.""" + now = time.monotonic() + expired = [ + k for k, v in _RECENT_SENDS.items() if now - v >= _DEDUP_COOLDOWN_SECONDS + ] + for k in expired: + del _RECENT_SENDS[k] + + +def _evict_oldest_recent_sends() -> None: + """Evict oldest entries if _RECENT_SENDS exceeds max size. Must be called under _INFLIGHT_LOCK.""" + if len(_RECENT_SENDS) <= _RECENT_SENDS_MAX_SIZE: + return + sorted_keys = sorted(_RECENT_SENDS, key=lambda k: _RECENT_SENDS[k]) + excess = len(_RECENT_SENDS) - _RECENT_SENDS_MAX_SIZE + for k in sorted_keys[:excess]: + del _RECENT_SENDS[k] + def _build_abs_url(paper_id: str) -> str: return f"https://arxiv.org/abs/{paper_id}" @@ -203,6 +229,24 @@ async def send_arxiv_paper( created = False async with _INFLIGHT_LOCK: + # Lazy cleanup of expired entries + _cleanup_expired_recent_sends() + + # Check time-based dedup first + recent_ts = _RECENT_SENDS.get(key) + if ( + recent_ts is not None + and (time.monotonic() - recent_ts) < _DEDUP_COOLDOWN_SECONDS + ): + logger.info( + "[arXiv] 论文近期已发送,跳过: paper=%s target=%s:%s", + normalized, + target_type, + target_id, + ) + return f"论文 {normalized} 近期已发送过,已跳过" + + # Check inflight dedup future = _INFLIGHT_SENDS.get(key) if future is None: future = asyncio.get_running_loop().create_future() @@ -242,3 +286,7 @@ async def send_arxiv_paper( current = _INFLIGHT_SENDS.get(key) if current is future: _INFLIGHT_SENDS.pop(key, None) + # Record successful send time for dedup cooldown + if future.done() and not future.cancelled() and future.exception() is None: + _RECENT_SENDS[key] = time.monotonic() + _evict_oldest_recent_sends() diff --git a/src/Undefined/attachments.py b/src/Undefined/attachments.py index b0c09af6..403154dc 100644 --- a/src/Undefined/attachments.py +++ b/src/Undefined/attachments.py @@ -33,12 +33,21 @@ r"[\"'])(?P[^\"']+)(?P=quote)\s*/?>", re.IGNORECASE, ) +_ATTACHMENT_TAG_PATTERN = re.compile( + r"[\"'])(?P[^\"']+)(?P=quote)\s*/?>", + re.IGNORECASE, +) +_UNIFIED_TAG_PATTERN = re.compile( + r"<(?Ppic|attachment)\s+uid=(?P[\"'])(?P[^\"']+)(?P=quote)\s*/?>", + re.IGNORECASE, +) _MEDIA_LABELS = { "image": "图片", "file": "文件", "audio": "音频", "video": "视频", "record": "语音", + "pic": "图片", } _WINDOWS_ABS_PATH_RE = re.compile(r"^[A-Za-z]:[\\/]") _DEFAULT_REMOTE_TIMEOUT_SECONDS = 120.0 @@ -107,10 +116,11 @@ class RenderedRichMessage: delivery_text: str history_text: str attachments: list[dict[str, str]] + pending_file_sends: tuple[AttachmentRecord, ...] = () class AttachmentRenderError(RuntimeError): - """Raised when a `` tag cannot be rendered.""" + """Raised when an attachment tag cannot be rendered.""" def _now_iso() -> str: @@ -681,6 +691,25 @@ def _build_uid(self, prefix: str) -> str: if uid not in self._records: return uid + def _find_by_sha256( + self, scope_key: str, sha256: str, kind: str + ) -> AttachmentRecord | None: + """Find an existing record with matching scope, kind, and SHA-256. + + Only returns a record whose *local_path* still exists on disk. + Must be called while ``self._lock`` is held. + """ + for record in self._records.values(): + if ( + record.scope_key == scope_key + and record.sha256 == sha256 + and record.kind == kind + and record.local_path + and Path(record.local_path).is_file() + ): + return record + return None + async def register_bytes( self, scope_key: str, @@ -703,15 +732,18 @@ async def register_bytes( prefix = "pic" if normalized_media_type == "image" else "file" async with self._lock: + digest = await asyncio.to_thread(hashlib.sha256, content) + digest_hex = digest.hexdigest() + + existing = self._find_by_sha256(scope_key, digest_hex, normalized_kind) + if existing is not None: + return existing + uid = self._build_uid(prefix) file_name = f"{uid}{suffix}" cache_path = ensure_dir(self._cache_dir) / file_name + await asyncio.to_thread(cache_path.write_bytes, content) - def _write() -> str: - cache_path.write_bytes(content) - return hashlib.sha256(content).hexdigest() - - digest = await asyncio.to_thread(_write) record = AttachmentRecord( uid=uid, scope_key=scope_key, @@ -722,7 +754,7 @@ def _write() -> str: source_ref=source_ref, local_path=str(cache_path), mime_type=normalized_mime, - sha256=digest, + sha256=digest_hex, created_at=_now_iso(), segment_data={ str(k): str(v) @@ -1064,31 +1096,38 @@ async def _collect_from_segments( ) -async def render_message_with_pic_placeholders( +async def render_message_with_attachments( message: str, *, registry: AttachmentRegistry | None, scope_key: str | None, strict: bool, ) -> RenderedRichMessage: - if ( - not message - or registry is None - or not scope_key - or "`` and ```` tags into delivery/history text. + + * ```` — backward-compatible, image-only. + * ```` — unified tag for any media type. + Images (``pic_*``) are inlined as CQ images; files (``file_*``) + are collected into *pending_file_sends* for later dispatch. + """ + has_tags = message and ( + " tag: strictly image-only + if tag_name == "pic" and record.media_type != "image": replacement = f"[图片 uid={uid} 类型错误]" if strict: raise AttachmentRenderError(f"UID 不是图片,不能用于 :{uid}") @@ -1110,32 +1152,21 @@ async def render_message_with_pic_placeholders( history_parts.append(replacement) continue - image_source = record.source_ref - if record.local_path: - image_source = Path(record.local_path).resolve().as_uri() - elif not image_source: - replacement = f"[图片 uid={uid} 缺少文件]" - if strict: - raise AttachmentRenderError(f"图片 UID 缺少可发送的文件:{uid}") - delivery_parts.append(replacement) - history_parts.append(replacement) - continue - - cq_args = [f"file={image_source}"] - for key, value in dict(getattr(record, "segment_data", {}) or {}).items(): - cleaned_key = str(key or "").strip() - cleaned_value = str(value or "").strip() - if not cleaned_key or not cleaned_value or cleaned_key == "file": - continue - cq_args.append( - f"{_escape_cq_component(cleaned_key)}={_escape_cq_component(cleaned_value)}" - ) - delivery_parts.append(f"[CQ:image,{','.join(cq_args)}]") - if record.display_name: - history_parts.append(f"[图片 uid={uid} name={record.display_name}]") + # Route by media type + if record.media_type == "image": + ok = _render_image_tag(record, uid, strict, delivery_parts, history_parts) else: - history_parts.append(f"[图片 uid={uid}]") - attachments.append(record.prompt_ref()) + ok = _render_file_tag( + record, + uid, + strict, + delivery_parts, + history_parts, + pending_files, + ) + + if ok: + attachments.append(record.prompt_ref()) delivery_parts.append(message[last_index:]) history_parts.append(message[last_index:]) @@ -1143,4 +1174,115 @@ async def render_message_with_pic_placeholders( delivery_text="".join(delivery_parts), history_text="".join(history_parts), attachments=attachments, + pending_file_sends=tuple(pending_files), ) + + +def _render_image_tag( + record: AttachmentRecord, + uid: str, + strict: bool, + delivery_parts: list[str], + history_parts: list[str], +) -> bool: + """Render an image attachment as an inline CQ:image. Returns True on success.""" + image_source = record.source_ref + if record.local_path: + image_source = Path(record.local_path).resolve().as_uri() + elif not image_source: + replacement = f"[图片 uid={uid} 缺少文件]" + if strict: + raise AttachmentRenderError(f"图片 UID 缺少可发送的文件:{uid}") + delivery_parts.append(replacement) + history_parts.append(replacement) + return False + + cq_args = [f"file={image_source}"] + for key, value in dict(getattr(record, "segment_data", {}) or {}).items(): + cleaned_key = str(key or "").strip() + cleaned_value = str(value or "").strip() + if not cleaned_key or not cleaned_value or cleaned_key == "file": + continue + cq_args.append( + f"{_escape_cq_component(cleaned_key)}={_escape_cq_component(cleaned_value)}" + ) + delivery_parts.append(f"[CQ:image,{','.join(cq_args)}]") + if record.display_name: + history_parts.append(f"[图片 uid={uid} name={record.display_name}]") + else: + history_parts.append(f"[图片 uid={uid}]") + return True + + +def _render_file_tag( + record: AttachmentRecord, + uid: str, + strict: bool, + delivery_parts: list[str], + history_parts: list[str], + pending_files: list[AttachmentRecord], +) -> bool: + """Render a non-image attachment as a pending file send. Returns True on success.""" + if not record.local_path or not Path(record.local_path).is_file(): + replacement = f"[文件 uid={uid} 缺少本地文件]" + if strict: + raise AttachmentRenderError(f"文件 UID 缺少本地文件,无法发送:{uid}") + delivery_parts.append(replacement) + history_parts.append(replacement) + return False + + # Remove from delivery text (file sent separately) + # Keep a readable placeholder in history + name_part = f" name={record.display_name}" if record.display_name else "" + history_parts.append(f"[文件 uid={uid}{name_part}]") + pending_files.append(record) + return True + + +# Backward-compatible alias +render_message_with_pic_placeholders = render_message_with_attachments + + +async def dispatch_pending_file_sends( + rendered: RenderedRichMessage, + *, + sender: Any, + target_type: str, + target_id: int, +) -> None: + """Send pending file attachments collected by *render_message_with_attachments*. + + This is best-effort: each file send failure is logged but does not interrupt + the remaining sends or the caller. + """ + if not rendered.pending_file_sends or sender is None: + return + for record in rendered.pending_file_sends: + if not record.local_path or not Path(record.local_path).is_file(): + logger.warning( + "[文件发送] 跳过:本地文件缺失 uid=%s path=%s", + record.uid, + record.local_path, + ) + continue + try: + if target_type == "group": + await sender.send_group_file( + target_id, + record.local_path, + name=record.display_name or None, + ) + else: + await sender.send_private_file( + target_id, + record.local_path, + name=record.display_name or None, + ) + except Exception: + logger.warning( + "[文件发送] 发送失败(最佳努力) uid=%s target=%s:%s", + record.uid, + target_type, + target_id, + exc_info=True, + ) diff --git a/src/Undefined/cognitive/historian.py b/src/Undefined/cognitive/historian.py index fde5d24a..fda333c3 100644 --- a/src/Undefined/cognitive/historian.py +++ b/src/Undefined/cognitive/historian.py @@ -461,7 +461,7 @@ async def _rewrite( recent_messages = [ str(item).strip() for item in recent_messages_raw if str(item).strip() ] - recent_messages_text = "\n".join(f"- {line}" for line in recent_messages) + recent_messages_text = "\n---\n".join(recent_messages) prompt = template.format( request_id=job.get("request_id", ""), end_seq=job.get("end_seq", 0), diff --git a/src/Undefined/cognitive/service.py b/src/Undefined/cognitive/service.py index 9b7405e9..13cf3f1c 100644 --- a/src/Undefined/cognitive/service.py +++ b/src/Undefined/cognitive/service.py @@ -9,6 +9,7 @@ from typing import TYPE_CHECKING, Any, Callable, cast from Undefined.context import RequestContext +from Undefined.utils.coerce import safe_float logger = logging.getLogger(__name__) @@ -39,17 +40,6 @@ def _compose_where(clauses: list[dict[str, Any]]) -> dict[str, Any] | None: return {"$and": clauses} -def _safe_float(value: Any, default: float = 0.0) -> float: - if isinstance(value, (int, float)): - return float(value) - if isinstance(value, str): - try: - return float(value.strip()) - except Exception: - return default - return default - - def _event_base_score(item: dict[str, Any]) -> float: rerank_score = item.get("rerank_score") if isinstance(rerank_score, (int, float)): @@ -59,7 +49,7 @@ def _event_base_score(item: dict[str, Any]) -> float: return max(0.0, float(rerank_score.strip())) except Exception: pass - similarity = 1.0 - _safe_float(item.get("distance"), default=1.0) + similarity = 1.0 - safe_float(item.get("distance"), default=1.0) if similarity < 0.0: return 0.0 if similarity > 1.0: @@ -336,7 +326,7 @@ def _merge_weighted_events( ] = [] serial = 0 for scoped_events, scope_weight in scoped_results: - safe_scope_weight = max(0.0, _safe_float(scope_weight, default=1.0)) + safe_scope_weight = max(0.0, safe_float(scope_weight, default=1.0)) scope_size = max(1, len(scoped_events)) for rank_idx, event in enumerate(scoped_events): dedupe_key = _event_dedupe_key(event) @@ -399,12 +389,12 @@ async def _query_events_for_auto_context( if scope_candidate_multiplier <= 0: scope_candidate_multiplier = 2 scoped_top_k = max(safe_top_k, safe_top_k * scope_candidate_multiplier) - current_group_boost = _safe_float( + current_group_boost = safe_float( getattr(config, "auto_current_group_boost", 1.15), default=1.15 ) if current_group_boost <= 0: current_group_boost = 1.15 - current_private_boost = _safe_float( + current_private_boost = safe_float( getattr(config, "auto_current_private_boost", 1.25), default=1.25 ) if current_private_boost <= 0: @@ -735,6 +725,7 @@ async def build_context( top_k = default_top_k if top_k <= 0: top_k = default_top_k + top_k = min(top_k, 500) try: events = await self._query_events_for_auto_context( query=query, @@ -829,6 +820,7 @@ async def search_events(self, query: str, **kwargs: Any) -> list[dict[str, Any]] top_k = default_top_k if top_k <= 0: top_k = default_top_k + top_k = min(top_k, 500) logger.info( "[认知服务] 搜索事件: query_len=%s top_k=%s where=%s time_from=%s time_to=%s", len(query or ""), @@ -882,6 +874,7 @@ async def search_profiles(self, query: str, **kwargs: Any) -> list[dict[str, Any top_k = default_top_k if top_k <= 0: top_k = default_top_k + top_k = min(top_k, 500) where: dict[str, Any] | None = None entity_type_raw = kwargs.get("entity_type") diff --git a/src/Undefined/cognitive/vector_store.py b/src/Undefined/cognitive/vector_store.py index b6d1c7c1..48b80eb1 100644 --- a/src/Undefined/cognitive/vector_store.py +++ b/src/Undefined/cognitive/vector_store.py @@ -11,6 +11,8 @@ from typing import Any import chromadb + +from Undefined.utils.coerce import safe_float from chromadb.errors import InternalError as ChromaInternalError import numpy as np from numba import njit @@ -32,24 +34,15 @@ def _clamp(value: float, lower: float, upper: float) -> float: return value -def _safe_float(value: Any, default: float = 0.0) -> float: - if isinstance(value, (int, float)): - return float(value) - if isinstance(value, str): - try: - return float(value.strip()) - except Exception: - return default - return default - - -def _safe_positive_int(value: Any, default: int) -> int: +def _safe_positive_int(value: Any, default: int, maximum: int = 0) -> int: try: parsed = int(value) except Exception: return max(1, int(default)) if parsed <= 0: return max(1, int(default)) + if maximum > 0 and parsed > maximum: + return maximum return parsed @@ -120,7 +113,7 @@ def _sanitize_metadata(metadata: dict[str, Any]) -> dict[str, Any]: def _similarity_from_distance(distance: Any) -> float: - dist = _safe_float(distance, default=1.0) + dist = safe_float(distance, default=1.0) return _clamp(1.0 - dist, 0.0, 1.0) @@ -436,7 +429,7 @@ async def _query( query_embedding: list[float] | None = None, ) -> list[dict[str, Any]]: col_name = getattr(col, "name", "unknown") - safe_top_k = _safe_positive_int(top_k, default=1) + safe_top_k = _safe_positive_int(top_k, default=1, maximum=500) safe_multiplier = _safe_positive_int(candidate_multiplier, default=1) total_started = time.perf_counter() logger.debug( @@ -464,6 +457,7 @@ async def _query( use_reranker or apply_time_decay or apply_mmr ) fetch_k = safe_top_k * safe_multiplier if use_extra_candidates else safe_top_k + fetch_k = min(fetch_k, 10000) include: list[str] = ["documents", "metadatas", "distances"] if apply_mmr: include.append("embeddings") @@ -553,14 +547,14 @@ def _q() -> Any: else: reranked_results: list[dict[str, Any]] = [] for item in reranked[:rerank_top_n]: - index = int(_safe_float(item.get("index"), default=-1)) + index = int(safe_float(item.get("index"), default=-1)) if index < 0 or index >= len(results): continue entry: dict[str, Any] = { "document": item.get("document", results[index]["document"]), "metadata": results[index]["metadata"], "distance": results[index]["distance"], - "rerank_score": _safe_float( + "rerank_score": safe_float( item.get("relevance_score"), default=0.0 ), } @@ -637,11 +631,9 @@ def _apply_time_decay_ranking( collection_name: str, ) -> list[dict[str, Any]]: safe_top_k = max(1, int(top_k)) - safe_half_life_days = _safe_float(half_life_days, default=14.0) - safe_boost = max(0.0, _safe_float(boost, default=0.2)) - safe_min_similarity = _clamp( - _safe_float(min_similarity, default=0.35), 0.0, 1.0 - ) + safe_half_life_days = safe_float(half_life_days, default=14.0) + safe_boost = max(0.0, safe_float(boost, default=0.2)) + safe_min_similarity = _clamp(safe_float(min_similarity, default=0.35), 0.0, 1.0) if safe_half_life_days <= 0: logger.warning( "[认知向量库] 时间衰减参数非法,跳过时间加权: collection=%s half_life_days=%s", diff --git a/src/Undefined/config/admin.py b/src/Undefined/config/admin.py new file mode 100644 index 00000000..16dcfe65 --- /dev/null +++ b/src/Undefined/config/admin.py @@ -0,0 +1,44 @@ +"""Local admin management (config.local.json).""" + +from __future__ import annotations + +import json +import logging +from pathlib import Path + +logger = logging.getLogger(__name__) + +LOCAL_CONFIG_PATH = Path("config.local.json") + + +def load_local_admins() -> list[int]: + """从本地配置文件加载动态管理员列表""" + if not LOCAL_CONFIG_PATH.exists(): + return [] + try: + with open(LOCAL_CONFIG_PATH, "r", encoding="utf-8") as f: + data = json.load(f) + admin_qqs: list[int] = data.get("admin_qqs", []) + return admin_qqs + except Exception as exc: + logger.warning("读取本地配置失败: %s", exc) + return [] + + +def save_local_admins(admin_qqs: list[int]) -> None: + """保存动态管理员列表到本地配置文件""" + try: + data: dict[str, list[int]] = {} + if LOCAL_CONFIG_PATH.exists(): + with open(LOCAL_CONFIG_PATH, "r", encoding="utf-8") as f: + data = json.load(f) + + data["admin_qqs"] = admin_qqs + + with open(LOCAL_CONFIG_PATH, "w", encoding="utf-8") as f: + json.dump(data, f, ensure_ascii=False, indent=2) + + logger.info("已保存管理员列表到 %s", LOCAL_CONFIG_PATH) + except Exception as exc: + logger.error("保存本地配置失败: %s", exc) + raise diff --git a/src/Undefined/config/coercers.py b/src/Undefined/config/coercers.py new file mode 100644 index 00000000..4edff34a --- /dev/null +++ b/src/Undefined/config/coercers.py @@ -0,0 +1,150 @@ +"""Type coercion and normalization helpers for config loading.""" + +from __future__ import annotations + +import logging +import os +from typing import Any, Optional + +from Undefined.utils.request_params import normalize_request_params + +logger = logging.getLogger(__name__) + +_ENV_WARNED_KEYS: set[str] = set() + + +def _warn_env_fallback(name: str) -> None: + if name in _ENV_WARNED_KEYS: + return + _ENV_WARNED_KEYS.add(name) + logger.warning("检测到环境变量 %s,建议迁移到 config.toml", name) + + +def _get_nested(data: dict[str, Any], path: tuple[str, ...]) -> Any: + node: Any = data + for key in path: + if not isinstance(node, dict) or key not in node: + return None + node = node[key] + return node + + +def _normalize_str(value: Any) -> Optional[str]: + if value is None: + return None + if isinstance(value, str): + stripped = value.strip() + return stripped if stripped else None + return str(value).strip() + + +def _coerce_int(value: Any, default: int) -> int: + if value is None: + return default + try: + return int(value) + except (TypeError, ValueError): + return default + + +def _coerce_float(value: Any, default: float) -> float: + if value is None: + return default + try: + return float(value) + except (TypeError, ValueError): + return default + + +def _normalize_queue_interval(value: float, default: float = 1.0) -> float: + """规范化队列发车间隔。 + + `0` 表示立即发车,负数回退到默认值。 + """ + + return default if value < 0 else value + + +def _coerce_bool(value: Any, default: bool) -> bool: + if isinstance(value, bool): + return value + if isinstance(value, (int, float)): + return bool(value) + if isinstance(value, str): + return value.strip().lower() in {"1", "true", "yes", "on"} + return default + + +def _coerce_str(value: Any, default: str) -> str: + normalized = _normalize_str(value) + return normalized if normalized is not None else default + + +def _normalize_base_url(value: str, default: str) -> str: + normalized = value.strip().rstrip("/") + if normalized: + return normalized + return default.rstrip("/") + + +def _coerce_int_list(value: Any) -> list[int]: + if value is None: + return [] + if isinstance(value, list): + items: list[int] = [] + for item in value: + try: + items.append(int(item)) + except (TypeError, ValueError): + continue + return items + if isinstance(value, str): + parts = [part.strip() for part in value.split(",") if part.strip()] + items = [] + for part in parts: + try: + items.append(int(part)) + except ValueError: + continue + return items + return [] + + +def _coerce_str_list(value: Any) -> list[str]: + if value is None: + return [] + if isinstance(value, list): + return [str(item).strip() for item in value if str(item).strip()] + if isinstance(value, str): + return [part.strip() for part in value.split(",") if part.strip()] + return [] + + +def _coerce_request_params(value: Any) -> dict[str, Any]: + return normalize_request_params(value) + + +def _get_model_request_params(data: dict[str, Any], model_name: str) -> dict[str, Any]: + return _coerce_request_params( + _get_nested(data, ("models", model_name, "request_params")) + ) + + +def _get_value( + data: dict[str, Any], + path: tuple[str, ...], + env_key: Optional[str], +) -> Any: + value = _get_nested(data, path) + if value is not None: + return value + if env_key: + env_value = os.getenv(env_key) + if env_value is not None and str(env_value).strip() != "": + _warn_env_fallback(env_key) + return env_value + return None + + +_VALID_API_MODES = {"chat_completions", "responses"} +_VALID_REASONING_EFFORT_STYLES = {"openai", "anthropic"} diff --git a/src/Undefined/config/domain_parsers.py b/src/Undefined/config/domain_parsers.py new file mode 100644 index 00000000..07f807b1 --- /dev/null +++ b/src/Undefined/config/domain_parsers.py @@ -0,0 +1,286 @@ +"""Domain configuration parsers (cognitive, memes, API, naga).""" + +from __future__ import annotations + +from dataclasses import fields +from typing import Any + +from .coercers import ( + _coerce_bool, + _coerce_float, + _coerce_int, + _coerce_str, + _coerce_int_list, + _coerce_str_list, +) +from .models import ( + APIConfig, + CognitiveConfig, + MemeConfig, + NagaConfig, +) + +DEFAULT_API_HOST = "127.0.0.1" +DEFAULT_API_PORT = 8788 +DEFAULT_API_AUTH_KEY = "changeme" + + +def _parse_cognitive_config(data: dict[str, Any]) -> CognitiveConfig: + cog = data.get("cognitive", {}) + vs = cog.get("vector_store", {}) if isinstance(cog, dict) else {} + q = cog.get("query", {}) if isinstance(cog, dict) else {} + hist = cog.get("historian", {}) if isinstance(cog, dict) else {} + prof = cog.get("profile", {}) if isinstance(cog, dict) else {} + que = cog.get("queue", {}) if isinstance(cog, dict) else {} + return CognitiveConfig( + enabled=_coerce_bool( + cog.get("enabled") if isinstance(cog, dict) else None, True + ), + bot_name=_coerce_str( + cog.get("bot_name") if isinstance(cog, dict) else None, + "Undefined", + ), + vector_store_path=_coerce_str( + vs.get("path") if isinstance(vs, dict) else None, + "data/cognitive/chromadb", + ), + queue_path=_coerce_str( + que.get("path") if isinstance(que, dict) else None, + "data/cognitive/queues", + ), + profiles_path=_coerce_str( + prof.get("path") if isinstance(prof, dict) else None, + "data/cognitive/profiles", + ), + auto_top_k=_coerce_int(q.get("auto_top_k") if isinstance(q, dict) else None, 3), + auto_scope_candidate_multiplier=_coerce_int( + q.get("auto_scope_candidate_multiplier") if isinstance(q, dict) else None, + 2, + ), + auto_current_group_boost=_coerce_float( + q.get("auto_current_group_boost") if isinstance(q, dict) else None, + 1.15, + ), + auto_current_private_boost=_coerce_float( + q.get("auto_current_private_boost") if isinstance(q, dict) else None, + 1.25, + ), + enable_rerank=_coerce_bool( + q.get("enable_rerank") if isinstance(q, dict) else None, True + ), + recent_end_summaries_inject_k=_coerce_int( + q.get("recent_end_summaries_inject_k") if isinstance(q, dict) else None, + 30, + ), + time_decay_enabled=_coerce_bool( + q.get("time_decay_enabled") if isinstance(q, dict) else None, True + ), + time_decay_half_life_days_auto=_coerce_float( + q.get("time_decay_half_life_days_auto") if isinstance(q, dict) else None, + 14.0, + ), + time_decay_half_life_days_tool=_coerce_float( + q.get("time_decay_half_life_days_tool") if isinstance(q, dict) else None, + 60.0, + ), + time_decay_boost=_coerce_float( + q.get("time_decay_boost") if isinstance(q, dict) else None, 0.2 + ), + time_decay_min_similarity=_coerce_float( + q.get("time_decay_min_similarity") if isinstance(q, dict) else None, + 0.35, + ), + tool_default_top_k=_coerce_int( + q.get("tool_default_top_k") if isinstance(q, dict) else None, 12 + ), + profile_top_k=_coerce_int( + q.get("profile_top_k") if isinstance(q, dict) else None, 8 + ), + rerank_candidate_multiplier=_coerce_int( + q.get("rerank_candidate_multiplier") if isinstance(q, dict) else None, 3 + ), + rewrite_max_retry=_coerce_int( + hist.get("rewrite_max_retry") if isinstance(hist, dict) else None, 2 + ), + historian_recent_messages_inject_k=_coerce_int( + hist.get("recent_messages_inject_k") if isinstance(hist, dict) else None, + 12, + ), + historian_recent_message_line_max_len=_coerce_int( + hist.get("recent_message_line_max_len") if isinstance(hist, dict) else None, + 240, + ), + historian_source_message_max_len=_coerce_int( + hist.get("source_message_max_len") if isinstance(hist, dict) else None, + 800, + ), + poll_interval_seconds=_coerce_float( + hist.get("poll_interval_seconds") if isinstance(hist, dict) else None, + 1.0, + ), + stale_job_timeout_seconds=_coerce_float( + hist.get("stale_job_timeout_seconds") if isinstance(hist, dict) else None, + 300.0, + ), + profile_revision_keep=_coerce_int( + prof.get("revision_keep") if isinstance(prof, dict) else None, 5 + ), + failed_max_age_days=_coerce_int( + que.get("failed_max_age_days") if isinstance(que, dict) else None, 30 + ), + failed_max_files=_coerce_int( + que.get("failed_max_files") if isinstance(que, dict) else None, 500 + ), + failed_cleanup_interval=_coerce_int( + que.get("failed_cleanup_interval") if isinstance(que, dict) else None, + 100, + ), + job_max_retries=_coerce_int( + que.get("job_max_retries") if isinstance(que, dict) else None, 3 + ), + ) + + +def _parse_memes_config(data: dict[str, Any]) -> MemeConfig: + section_raw = data.get("memes", {}) + section = section_raw if isinstance(section_raw, dict) else {} + return MemeConfig( + enabled=_coerce_bool(section.get("enabled"), True), + query_default_mode=_coerce_str(section.get("query_default_mode"), "hybrid"), + max_source_image_bytes=max( + 1, + _coerce_int(section.get("max_source_image_bytes"), 500 * 1024), + ), + blob_dir=_coerce_str(section.get("blob_dir"), "data/memes/blobs"), + preview_dir=_coerce_str(section.get("preview_dir"), "data/memes/previews"), + db_path=_coerce_str(section.get("db_path"), "data/memes/memes.sqlite3"), + vector_store_path=_coerce_str( + section.get("vector_store_path"), "data/memes/chromadb" + ), + queue_path=_coerce_str(section.get("queue_path"), "data/memes/queues"), + max_items=max(1, _coerce_int(section.get("max_items"), 10000)), + max_total_bytes=max( + 1, + _coerce_int(section.get("max_total_bytes"), 5 * 1024 * 1024 * 1024), + ), + allow_gif=_coerce_bool(section.get("allow_gif"), True), + auto_ingest_group=_coerce_bool(section.get("auto_ingest_group"), True), + auto_ingest_private=_coerce_bool(section.get("auto_ingest_private"), True), + keyword_top_k=max(1, _coerce_int(section.get("keyword_top_k"), 30)), + semantic_top_k=max(1, _coerce_int(section.get("semantic_top_k"), 30)), + rerank_top_k=max(1, _coerce_int(section.get("rerank_top_k"), 20)), + worker_max_concurrency=max( + 1, _coerce_int(section.get("worker_max_concurrency"), 4) + ), + gif_analysis_mode=_coerce_str(section.get("gif_analysis_mode"), "grid"), + gif_analysis_frames=max(1, _coerce_int(section.get("gif_analysis_frames"), 6)), + ) + + +def _parse_api_config(data: dict[str, Any]) -> APIConfig: + section_raw = data.get("api", {}) + section = section_raw if isinstance(section_raw, dict) else {} + + enabled = _coerce_bool(section.get("enabled"), True) + host = _coerce_str(section.get("host"), DEFAULT_API_HOST) + port = _coerce_int(section.get("port"), DEFAULT_API_PORT) + if port <= 0 or port > 65535: + port = DEFAULT_API_PORT + + auth_key = _coerce_str(section.get("auth_key"), DEFAULT_API_AUTH_KEY) + if not auth_key: + auth_key = DEFAULT_API_AUTH_KEY + + openapi_enabled = _coerce_bool(section.get("openapi_enabled"), True) + + tool_invoke_enabled = _coerce_bool(section.get("tool_invoke_enabled"), False) + tool_invoke_expose = _coerce_str( + section.get("tool_invoke_expose"), "tools+toolsets" + ) + valid_expose = {"tools", "toolsets", "tools+toolsets", "agents", "all"} + if tool_invoke_expose not in valid_expose: + tool_invoke_expose = "tools+toolsets" + tool_invoke_allowlist = _coerce_str_list(section.get("tool_invoke_allowlist")) + tool_invoke_denylist = _coerce_str_list(section.get("tool_invoke_denylist")) + tool_invoke_timeout = _coerce_int(section.get("tool_invoke_timeout"), 120) + if tool_invoke_timeout <= 0: + tool_invoke_timeout = 120 + tool_invoke_callback_timeout = _coerce_int( + section.get("tool_invoke_callback_timeout"), 10 + ) + if tool_invoke_callback_timeout <= 0: + tool_invoke_callback_timeout = 10 + + return APIConfig( + enabled=enabled, + host=host, + port=port, + auth_key=auth_key, + openapi_enabled=openapi_enabled, + tool_invoke_enabled=tool_invoke_enabled, + tool_invoke_expose=tool_invoke_expose, + tool_invoke_allowlist=tool_invoke_allowlist, + tool_invoke_denylist=tool_invoke_denylist, + tool_invoke_timeout=tool_invoke_timeout, + tool_invoke_callback_timeout=tool_invoke_callback_timeout, + ) + + +def _parse_naga_config(data: dict[str, Any]) -> NagaConfig: + section_raw = data.get("naga", {}) + section = section_raw if isinstance(section_raw, dict) else {} + + enabled = _coerce_bool(section.get("enabled"), False) + api_url = _coerce_str(section.get("api_url"), "") + api_key = _coerce_str(section.get("api_key"), "") + moderation_enabled = _coerce_bool(section.get("moderation_enabled"), True) + allowed_groups = frozenset(_coerce_int_list(section.get("allowed_groups"))) + + return NagaConfig( + enabled=enabled, + api_url=api_url, + api_key=api_key, + moderation_enabled=moderation_enabled, + allowed_groups=allowed_groups, + ) + + +def _parse_easter_egg_call_mode(value: Any) -> str: + """解析彩蛋提示模式。 + + 兼容旧版布尔值: + - True => agent + - False => none + """ + if isinstance(value, bool): + return "agent" if value else "none" + if isinstance(value, (int, float)): + return "agent" if bool(value) else "none" + if value is None: + return "none" + + text = str(value).strip().lower() + if text in {"true", "1", "yes", "on"}: + return "agent" + if text in {"false", "0", "no", "off"}: + return "none" + if text in {"none", "agent", "tools", "all", "clean"}: + return text + return "none" + + +def _update_dataclass( + old_value: Any, new_value: Any, prefix: str +) -> dict[str, tuple[Any, Any]]: + changes: dict[str, tuple[Any, Any]] = {} + if not isinstance(old_value, type(new_value)): + changes[prefix] = (old_value, new_value) + return changes + for field in fields(old_value): + name = field.name + old_attr = getattr(old_value, name) + new_attr = getattr(new_value, name) + if old_attr != new_attr: + setattr(old_value, name, new_attr) + changes[f"{prefix}.{name}"] = (old_attr, new_attr) + return changes diff --git a/src/Undefined/config/hot_reload.py b/src/Undefined/config/hot_reload.py index c82d8c19..f0d83c9d 100644 --- a/src/Undefined/config/hot_reload.py +++ b/src/Undefined/config/hot_reload.py @@ -42,6 +42,8 @@ "security_model.queue_interval_seconds", "naga_model.queue_interval_seconds", "agent_model.queue_interval_seconds", + "summary_model.queue_interval_seconds", + "historian_model.queue_interval_seconds", "grok_model.queue_interval_seconds", "chat_model.pool", "agent_model.pool", @@ -53,6 +55,8 @@ "security_model.model_name", "naga_model.model_name", "agent_model.model_name", + "summary_model.model_name", + "historian_model.model_name", "grok_model.model_name", } diff --git a/src/Undefined/config/loader.py b/src/Undefined/config/loader.py index ff8486af..992829a5 100644 --- a/src/Undefined/config/loader.py +++ b/src/Undefined/config/loader.py @@ -2,7 +2,6 @@ from __future__ import annotations -import json import logging import os import re @@ -37,38 +36,90 @@ def load_dotenv( ImageGenConfig, ImageGenModelConfig, MemeConfig, - ModelPool, - ModelPoolEntry, NagaConfig, RerankModelConfig, SecurityModelConfig, VisionModelConfig, ) -from Undefined.utils.request_params import ( - merge_request_params, - normalize_request_params, +from .coercers import ( # noqa: F401 — re-exported for backward compat + _coerce_bool, + _coerce_float, + _coerce_int, + _coerce_int_list, + _coerce_str, + _coerce_str_list, + _get_model_request_params, + _get_value, + _normalize_base_url, + _normalize_queue_interval, + _normalize_str, + _warn_env_fallback, +) +from .resolvers import ( # noqa: F401 — re-exported for backward compat + _resolve_api_mode, + _resolve_reasoning_effort, + _resolve_reasoning_effort_style, + _resolve_responses_force_stateless_replay, + _resolve_responses_tool_choice_compat, + _resolve_thinking_compat_flags, +) +from .admin import ( # noqa: F401 — re-exported for backward compat + LOCAL_CONFIG_PATH, + load_local_admins, + save_local_admins, +) +from .webui_settings import ( # noqa: F401 — re-exported for backward compat + DEFAULT_WEBUI_PASSWORD, + DEFAULT_WEBUI_PORT, + DEFAULT_WEBUI_URL, + WebUISettings, + load_webui_settings, +) +from .model_parsers import ( + _log_debug_info, + _merge_admins, + _parse_agent_model_config, + _parse_chat_model_config, + _parse_embedding_model_config, + _parse_grok_model_config, + _parse_historian_model_config, + _parse_image_edit_model_config, + _parse_image_gen_config, + _parse_image_gen_model_config, + _parse_naga_model_config, + _parse_rerank_model_config, + _parse_security_model_config, + _parse_summary_model_config, + _parse_vision_model_config, + _verify_required_fields, +) +from .domain_parsers import ( + _parse_api_config, + _parse_cognitive_config, + _parse_easter_egg_call_mode, + _parse_memes_config, + _parse_naga_config, + _update_dataclass, ) logger = logging.getLogger(__name__) CONFIG_PATH = Path("config.toml") -LOCAL_CONFIG_PATH = Path("config.local.json") - -DEFAULT_WEBUI_URL = "127.0.0.1" -DEFAULT_WEBUI_PORT = 8787 -DEFAULT_WEBUI_PASSWORD = "changeme" -DEFAULT_API_HOST = "127.0.0.1" -DEFAULT_API_PORT = 8788 -DEFAULT_API_AUTH_KEY = "changeme" - -_ENV_WARNED_KEYS: set[str] = set() - -def _warn_env_fallback(name: str) -> None: - if name in _ENV_WARNED_KEYS: - return - _ENV_WARNED_KEYS.add(name) - logger.warning("检测到环境变量 %s,建议迁移到 config.toml", name) +# Re-export symbols that external modules import from this module. +__all__ = [ + "CONFIG_PATH", + "Config", + "DEFAULT_WEBUI_PASSWORD", + "DEFAULT_WEBUI_PORT", + "DEFAULT_WEBUI_URL", + "LOCAL_CONFIG_PATH", + "WebUISettings", + "load_local_admins", + "load_toml_data", + "load_webui_settings", + "save_local_admins", +] def _load_env() -> None: @@ -143,308 +194,6 @@ def load_toml_data( return {} -def _get_nested(data: dict[str, Any], path: tuple[str, ...]) -> Any: - node: Any = data - for key in path: - if not isinstance(node, dict) or key not in node: - return None - node = node[key] - return node - - -def _normalize_str(value: Any) -> Optional[str]: - if value is None: - return None - if isinstance(value, str): - stripped = value.strip() - return stripped if stripped else None - return str(value).strip() - - -def _coerce_int(value: Any, default: int) -> int: - if value is None: - return default - try: - return int(value) - except (TypeError, ValueError): - return default - - -def _coerce_float(value: Any, default: float) -> float: - if value is None: - return default - try: - return float(value) - except (TypeError, ValueError): - return default - - -def _normalize_queue_interval(value: float, default: float = 1.0) -> float: - """规范化队列发车间隔。 - - `0` 表示立即发车,负数回退到默认值。 - """ - - return default if value < 0 else value - - -def _coerce_bool(value: Any, default: bool) -> bool: - if isinstance(value, bool): - return value - if isinstance(value, (int, float)): - return bool(value) - if isinstance(value, str): - return value.strip().lower() in {"1", "true", "yes", "on"} - return default - - -def _coerce_str(value: Any, default: str) -> str: - normalized = _normalize_str(value) - return normalized if normalized is not None else default - - -def _normalize_base_url(value: str, default: str) -> str: - normalized = value.strip().rstrip("/") - if normalized: - return normalized - return default.rstrip("/") - - -def _coerce_int_list(value: Any) -> list[int]: - if value is None: - return [] - if isinstance(value, list): - items: list[int] = [] - for item in value: - try: - items.append(int(item)) - except (TypeError, ValueError): - continue - return items - if isinstance(value, str): - parts = [part.strip() for part in value.split(",") if part.strip()] - items = [] - for part in parts: - try: - items.append(int(part)) - except ValueError: - continue - return items - return [] - - -def _coerce_str_list(value: Any) -> list[str]: - if value is None: - return [] - if isinstance(value, list): - return [str(item).strip() for item in value if str(item).strip()] - if isinstance(value, str): - return [part.strip() for part in value.split(",") if part.strip()] - return [] - - -def _coerce_request_params(value: Any) -> dict[str, Any]: - return normalize_request_params(value) - - -def _get_model_request_params(data: dict[str, Any], model_name: str) -> dict[str, Any]: - return _coerce_request_params( - _get_nested(data, ("models", model_name, "request_params")) - ) - - -def _get_value( - data: dict[str, Any], - path: tuple[str, ...], - env_key: Optional[str], -) -> Any: - value = _get_nested(data, path) - if value is not None: - return value - if env_key: - env_value = os.getenv(env_key) - if env_value is not None and str(env_value).strip() != "": - _warn_env_fallback(env_key) - return env_value - return None - - -_VALID_API_MODES = {"chat_completions", "responses"} -_VALID_REASONING_EFFORT_STYLES = {"openai", "anthropic"} - - -def _resolve_reasoning_effort_style(value: Any, default: str = "openai") -> str: - style = _coerce_str(value, default).strip().lower() - if style not in _VALID_REASONING_EFFORT_STYLES: - return default - return style - - -def _resolve_thinking_compat_flags( - data: dict[str, Any], - model_name: str, - include_budget_env_key: str, - tool_call_compat_env_key: str, - legacy_env_key: str, -) -> tuple[bool, bool]: - """解析思维链兼容配置,并兼容旧字段 deepseek_new_cot_support。""" - include_budget_value = _get_value( - data, - ("models", model_name, "thinking_include_budget"), - include_budget_env_key, - ) - tool_call_compat_value = _get_value( - data, - ("models", model_name, "thinking_tool_call_compat"), - tool_call_compat_env_key, - ) - legacy_value = _get_value( - data, - ("models", model_name, "deepseek_new_cot_support"), - legacy_env_key, - ) - - include_budget_default = True - tool_call_compat_default = True - if legacy_value is not None: - legacy_enabled = _coerce_bool(legacy_value, False) - include_budget_default = not legacy_enabled - tool_call_compat_default = legacy_enabled - - return ( - _coerce_bool(include_budget_value, include_budget_default), - _coerce_bool(tool_call_compat_value, tool_call_compat_default), - ) - - -def _resolve_api_mode( - data: dict[str, Any], - model_name: str, - env_key: str, - default: str = "chat_completions", -) -> str: - raw_value = _get_value(data, ("models", model_name, "api_mode"), env_key) - value = _coerce_str(raw_value, default).strip().lower() - if value not in _VALID_API_MODES: - return default - return value - - -def _resolve_reasoning_effort(value: Any, default: str = "medium") -> str: - return _coerce_str(value, default).strip().lower() - - -def _resolve_responses_tool_choice_compat( - data: dict[str, Any], - model_name: str, - env_key: str, - default: bool = False, -) -> bool: - return _coerce_bool( - _get_value( - data, - ("models", model_name, "responses_tool_choice_compat"), - env_key, - ), - default, - ) - - -def _resolve_responses_force_stateless_replay( - data: dict[str, Any], - model_name: str, - env_key: str, - default: bool = False, -) -> bool: - return _coerce_bool( - _get_value( - data, - ("models", model_name, "responses_force_stateless_replay"), - env_key, - ), - default, - ) - - -def load_local_admins() -> list[int]: - """从本地配置文件加载动态管理员列表""" - if not LOCAL_CONFIG_PATH.exists(): - return [] - try: - with open(LOCAL_CONFIG_PATH, "r", encoding="utf-8") as f: - data = json.load(f) - admin_qqs: list[int] = data.get("admin_qqs", []) - return admin_qqs - except Exception as exc: - logger.warning("读取本地配置失败: %s", exc) - return [] - - -def save_local_admins(admin_qqs: list[int]) -> None: - """保存动态管理员列表到本地配置文件""" - try: - data: dict[str, list[int]] = {} - if LOCAL_CONFIG_PATH.exists(): - with open(LOCAL_CONFIG_PATH, "r", encoding="utf-8") as f: - data = json.load(f) - - data["admin_qqs"] = admin_qqs - - with open(LOCAL_CONFIG_PATH, "w", encoding="utf-8") as f: - json.dump(data, f, ensure_ascii=False, indent=2) - - logger.info("已保存管理员列表到 %s", LOCAL_CONFIG_PATH) - except Exception as exc: - logger.error("保存本地配置失败: %s", exc) - raise - - -@dataclass -class WebUISettings: - url: str - port: int - password: str - using_default_password: bool - config_exists: bool - - @property - def display_url(self) -> str: - """用于日志和展示的格式化 URL。""" - from Undefined.config.models import format_netloc - - return f"http://{format_netloc(self.url or '0.0.0.0', self.port)}" - - -def load_webui_settings(config_path: Optional[Path] = None) -> WebUISettings: - data = load_toml_data(config_path) - config_exists = bool(data) - url_value = _get_value(data, ("webui", "url"), None) - port_value = _get_value(data, ("webui", "port"), None) - password_value = _get_value(data, ("webui", "password"), None) - - url = _coerce_str(url_value, DEFAULT_WEBUI_URL) - port = _coerce_int(port_value, DEFAULT_WEBUI_PORT) - if port <= 0 or port > 65535: - port = DEFAULT_WEBUI_PORT - - password_normalized = _normalize_str(password_value) - if not password_normalized: - return WebUISettings( - url=url, - port=port, - password=DEFAULT_WEBUI_PASSWORD, - using_default_password=True, - config_exists=config_exists, - ) - return WebUISettings( - url=url, - port=port, - password=password_normalized, - using_default_password=False, - config_exists=config_exists, - ) - - @dataclass class Config: """应用配置""" @@ -468,6 +217,10 @@ class Config: process_private_message: bool process_poke_message: bool keyword_reply_enabled: bool + repeat_enabled: bool + repeat_threshold: int + repeat_cooldown_minutes: int + inverted_question_enabled: bool context_recent_messages_limit: int ai_request_max_retries: int nagaagent_mode_enabled: bool @@ -480,6 +233,8 @@ class Config: naga_model: SecurityModelConfig agent_model: AgentModelConfig historian_model: AgentModelConfig + summary_model: AgentModelConfig + summary_model_configured: bool grok_model: GrokModelConfig model_pool_enabled: bool log_level: str @@ -499,6 +254,12 @@ class Config: token_usage_max_total_mb: int token_usage_archive_prune_mode: str history_max_records: int + history_filtered_result_limit: int + history_search_scan_limit: int + history_summary_fetch_limit: int + history_summary_time_fetch_limit: int + history_onebot_fetch_limit: int + history_group_analysis_limit: int skills_hot_reload: bool skills_hot_reload_interval: float skills_hot_reload_debounce: float @@ -705,6 +466,44 @@ def load(cls, config_path: Optional[Path] = None, strict: bool = True) -> "Confi None, ) keyword_reply_enabled = _coerce_bool(keyword_reply_raw, False) + repeat_enabled = _coerce_bool( + _get_value( + data, + ("easter_egg", "repeat_enabled"), + "EASTER_EGG_REPEAT_ENABLED", + ), + False, + ) + inverted_question_enabled = _coerce_bool( + _get_value( + data, + ("easter_egg", "inverted_question_enabled"), + "EASTER_EGG_INVERTED_QUESTION_ENABLED", + ), + False, + ) + repeat_threshold = _coerce_int( + _get_value( + data, + ("easter_egg", "repeat_threshold"), + "EASTER_EGG_REPEAT_THRESHOLD", + ), + 3, + ) + if repeat_threshold < 2: + repeat_threshold = 2 + if repeat_threshold > 20: + repeat_threshold = 20 + repeat_cooldown_minutes = _coerce_int( + _get_value( + data, + ("easter_egg", "repeat_cooldown_minutes"), + "EASTER_EGG_REPEAT_COOLDOWN_MINUTES", + ), + 60, + ) + if repeat_cooldown_minutes < 0: + repeat_cooldown_minutes = 0 context_recent_messages_limit = _coerce_int( _get_value( data, @@ -744,8 +543,8 @@ def load(cls, config_path: Optional[Path] = None, strict: bool = True) -> "Confi _get_value(data, ("onebot", "token"), "ONEBOT_TOKEN"), "" ) - embedding_model = cls._parse_embedding_model_config(data) - rerank_model = cls._parse_rerank_model_config(data) + embedding_model = _parse_embedding_model_config(data) + rerank_model = _parse_rerank_model_config(data) knowledge_enabled = _coerce_bool( _get_value(data, ("knowledge", "enabled"), None), False @@ -819,8 +618,8 @@ def load(cls, config_path: Optional[Path] = None, strict: bool = True) -> "Confi ) knowledge_rerank_top_k = fallback - chat_model = cls._parse_chat_model_config(data) - vision_model = cls._parse_vision_model_config(data) + chat_model = _parse_chat_model_config(data) + vision_model = _parse_vision_model_config(data) security_model_enabled = _coerce_bool( _get_value( data, @@ -829,17 +628,20 @@ def load(cls, config_path: Optional[Path] = None, strict: bool = True) -> "Confi ), True, ) - security_model = cls._parse_security_model_config(data, chat_model) - naga_model = cls._parse_naga_model_config(data, security_model) - agent_model = cls._parse_agent_model_config(data) - historian_model = cls._parse_historian_model_config(data, agent_model) - grok_model = cls._parse_grok_model_config(data) + security_model = _parse_security_model_config(data, chat_model) + naga_model = _parse_naga_model_config(data, security_model) + agent_model = _parse_agent_model_config(data) + historian_model = _parse_historian_model_config(data, agent_model) + summary_model, summary_model_configured = _parse_summary_model_config( + data, agent_model + ) + grok_model = _parse_grok_model_config(data) model_pool_enabled = _coerce_bool( _get_value(data, ("features", "pool_enabled"), "MODEL_POOL_ENABLED"), False ) - superadmin_qq, admin_qqs = cls._merge_admins( + superadmin_qq, admin_qqs = _merge_admins( superadmin_qq=superadmin_qq, admin_qqs=admin_qqs ) @@ -971,7 +773,7 @@ def load(cls, config_path: Optional[Path] = None, strict: bool = True) -> "Confi if easter_egg_mode_raw is not None: _warn_env_fallback("EASTER_EGG_CALL_MESSAGE_MODE") - easter_egg_agent_call_message_mode = cls._parse_easter_egg_call_mode( + easter_egg_agent_call_message_mode = _parse_easter_egg_call_mode( easter_egg_mode_raw ) @@ -1000,8 +802,78 @@ def load(cls, config_path: Optional[Path] = None, strict: bool = True) -> "Confi "delete", ) - history_max_records = _coerce_int( - _get_value(data, ("history", "max_records"), "HISTORY_MAX_RECORDS"), 10000 + history_max_records = max( + 0, + _coerce_int( + _get_value(data, ("history", "max_records"), "HISTORY_MAX_RECORDS"), + 10000, + ), + ) + history_filtered_result_limit = max( + 1, + _coerce_int( + _get_value( + data, + ("history", "filtered_result_limit"), + "HISTORY_FILTERED_RESULT_LIMIT", + ), + 200, + ), + ) + history_search_scan_limit = max( + 1, + _coerce_int( + _get_value( + data, + ("history", "search_scan_limit"), + "HISTORY_SEARCH_SCAN_LIMIT", + ), + 10000, + ), + ) + history_summary_fetch_limit = max( + 1, + _coerce_int( + _get_value( + data, + ("history", "summary_fetch_limit"), + "HISTORY_SUMMARY_FETCH_LIMIT", + ), + 1000, + ), + ) + history_summary_time_fetch_limit = max( + 1, + _coerce_int( + _get_value( + data, + ("history", "summary_time_fetch_limit"), + "HISTORY_SUMMARY_TIME_FETCH_LIMIT", + ), + 5000, + ), + ) + history_onebot_fetch_limit = max( + 1, + _coerce_int( + _get_value( + data, + ("history", "onebot_fetch_limit"), + "HISTORY_ONEBOT_FETCH_LIMIT", + ), + 10000, + ), + ) + history_group_analysis_limit = max( + 1, + _coerce_int( + _get_value( + data, + ("history", "group_analysis_limit"), + "HISTORY_GROUP_ANALYSIS_LIMIT", + ), + 500, + ), ) skills_hot_reload = _coerce_bool( @@ -1331,17 +1203,17 @@ def load(cls, config_path: Optional[Path] = None, strict: bool = True) -> "Confi messages_send_url_file_max_size_mb = 100 webui_settings = load_webui_settings(config_path) - api_config = cls._parse_api_config(data) + api_config = _parse_api_config(data) - cognitive = cls._parse_cognitive_config(data) - memes = cls._parse_memes_config(data) - naga = cls._parse_naga_config(data) - models_image_gen = cls._parse_image_gen_model_config(data) - models_image_edit = cls._parse_image_edit_model_config(data) - image_gen = cls._parse_image_gen_config(data) + cognitive = _parse_cognitive_config(data) + memes = _parse_memes_config(data) + naga = _parse_naga_config(data) + models_image_gen = _parse_image_gen_model_config(data) + models_image_edit = _parse_image_edit_model_config(data) + image_gen = _parse_image_gen_config(data) if strict: - cls._verify_required_fields( + _verify_required_fields( bot_qq=bot_qq, superadmin_qq=superadmin_qq, onebot_ws_url=onebot_ws_url, @@ -1352,12 +1224,13 @@ def load(cls, config_path: Optional[Path] = None, strict: bool = True) -> "Confi embedding_model=embedding_model, ) - cls._log_debug_info( + _log_debug_info( chat_model, vision_model, security_model, naga_model, agent_model, + summary_model, grok_model, ) @@ -1377,6 +1250,10 @@ def load(cls, config_path: Optional[Path] = None, strict: bool = True) -> "Confi process_private_message=process_private_message, process_poke_message=process_poke_message, keyword_reply_enabled=keyword_reply_enabled, + repeat_enabled=repeat_enabled, + repeat_threshold=repeat_threshold, + repeat_cooldown_minutes=repeat_cooldown_minutes, + inverted_question_enabled=inverted_question_enabled, context_recent_messages_limit=context_recent_messages_limit, ai_request_max_retries=ai_request_max_retries, nagaagent_mode_enabled=nagaagent_mode_enabled, @@ -1389,6 +1266,8 @@ def load(cls, config_path: Optional[Path] = None, strict: bool = True) -> "Confi naga_model=naga_model, agent_model=agent_model, historian_model=historian_model, + summary_model=summary_model, + summary_model_configured=summary_model_configured, grok_model=grok_model, model_pool_enabled=model_pool_enabled, log_level=log_level, @@ -1409,6 +1288,12 @@ def load(cls, config_path: Optional[Path] = None, strict: bool = True) -> "Confi token_usage_archive_prune_mode=token_usage_archive_prune_mode, skills_hot_reload=skills_hot_reload, history_max_records=history_max_records, + history_filtered_result_limit=history_filtered_result_limit, + history_search_scan_limit=history_search_scan_limit, + history_summary_fetch_limit=history_summary_fetch_limit, + history_summary_time_fetch_limit=history_summary_time_fetch_limit, + history_onebot_fetch_limit=history_onebot_fetch_limit, + history_group_analysis_limit=history_group_analysis_limit, skills_hot_reload_interval=skills_hot_reload_interval, skills_hot_reload_debounce=skills_hot_reload_debounce, agent_intro_autogen_enabled=agent_intro_autogen_enabled, @@ -1669,1054 +1554,6 @@ def security_check_enabled(self) -> bool: return bool(self.security_model_enabled) - @staticmethod - def _parse_model_pool( - data: dict[str, Any], - model_section: str, - primary_config: ChatModelConfig | AgentModelConfig, - ) -> ModelPool | None: - """解析模型池配置,缺省字段继承 primary_config""" - pool_data = data.get("models", {}).get(model_section, {}).get("pool") - if not isinstance(pool_data, dict): - return None - - enabled = _coerce_bool(pool_data.get("enabled"), False) - strategy = _coerce_str(pool_data.get("strategy"), "default").strip().lower() - if strategy not in ("default", "round_robin", "random"): - strategy = "default" - - raw_models = pool_data.get("models") - if not isinstance(raw_models, list): - return ModelPool(enabled=enabled, strategy=strategy) - - entries: list[ModelPoolEntry] = [] - for item in raw_models: - if not isinstance(item, dict): - continue - name = _coerce_str(item.get("model_name"), "").strip() - if not name: - continue - entries.append( - ModelPoolEntry( - api_url=_coerce_str(item.get("api_url"), primary_config.api_url), - api_key=_coerce_str(item.get("api_key"), primary_config.api_key), - model_name=name, - max_tokens=_coerce_int( - item.get("max_tokens"), primary_config.max_tokens - ), - queue_interval_seconds=_normalize_queue_interval( - _coerce_float( - item.get("queue_interval_seconds"), - primary_config.queue_interval_seconds, - ), - primary_config.queue_interval_seconds, - ), - api_mode=( - _coerce_str(item.get("api_mode"), primary_config.api_mode) - .strip() - .lower() - ) - if _coerce_str(item.get("api_mode"), primary_config.api_mode) - .strip() - .lower() - in _VALID_API_MODES - else primary_config.api_mode, - thinking_enabled=_coerce_bool( - item.get("thinking_enabled"), primary_config.thinking_enabled - ), - thinking_budget_tokens=_coerce_int( - item.get("thinking_budget_tokens"), - primary_config.thinking_budget_tokens, - ), - thinking_include_budget=_coerce_bool( - item.get("thinking_include_budget"), - primary_config.thinking_include_budget, - ), - reasoning_effort_style=_resolve_reasoning_effort_style( - item.get("reasoning_effort_style"), - primary_config.reasoning_effort_style, - ), - thinking_tool_call_compat=_coerce_bool( - item.get("thinking_tool_call_compat"), - primary_config.thinking_tool_call_compat, - ), - responses_tool_choice_compat=_coerce_bool( - item.get("responses_tool_choice_compat"), - primary_config.responses_tool_choice_compat, - ), - responses_force_stateless_replay=_coerce_bool( - item.get("responses_force_stateless_replay"), - primary_config.responses_force_stateless_replay, - ), - prompt_cache_enabled=_coerce_bool( - item.get("prompt_cache_enabled"), - primary_config.prompt_cache_enabled, - ), - reasoning_enabled=_coerce_bool( - item.get("reasoning_enabled"), - primary_config.reasoning_enabled, - ), - reasoning_effort=_resolve_reasoning_effort( - item.get("reasoning_effort"), - primary_config.reasoning_effort, - ), - request_params=merge_request_params( - primary_config.request_params, - item.get("request_params"), - ), - ) - ) - - return ModelPool(enabled=enabled, strategy=strategy, models=entries) - - @staticmethod - def _parse_embedding_model_config(data: dict[str, Any]) -> EmbeddingModelConfig: - return EmbeddingModelConfig( - api_url=_coerce_str( - _get_value( - data, ("models", "embedding", "api_url"), "EMBEDDING_MODEL_API_URL" - ), - "", - ), - api_key=_coerce_str( - _get_value( - data, ("models", "embedding", "api_key"), "EMBEDDING_MODEL_API_KEY" - ), - "", - ), - model_name=_coerce_str( - _get_value( - data, ("models", "embedding", "model_name"), "EMBEDDING_MODEL_NAME" - ), - "", - ), - queue_interval_seconds=_normalize_queue_interval( - _coerce_float( - _get_value( - data, ("models", "embedding", "queue_interval_seconds"), None - ), - 0.0, - ), - 0.0, - ), - dimensions=_coerce_int( - _get_value(data, ("models", "embedding", "dimensions"), None), 0 - ) - or None, - query_instruction=_coerce_str( - _get_value(data, ("models", "embedding", "query_instruction"), None), "" - ), - document_instruction=_coerce_str( - _get_value(data, ("models", "embedding", "document_instruction"), None), - "", - ), - request_params=_get_model_request_params(data, "embedding"), - ) - - @staticmethod - def _parse_rerank_model_config(data: dict[str, Any]) -> RerankModelConfig: - queue_interval_seconds = _normalize_queue_interval( - _coerce_float( - _get_value(data, ("models", "rerank", "queue_interval_seconds"), None), - 0.0, - ), - 0.0, - ) - return RerankModelConfig( - api_url=_coerce_str( - _get_value( - data, ("models", "rerank", "api_url"), "RERANK_MODEL_API_URL" - ), - "", - ), - api_key=_coerce_str( - _get_value( - data, ("models", "rerank", "api_key"), "RERANK_MODEL_API_KEY" - ), - "", - ), - model_name=_coerce_str( - _get_value( - data, ("models", "rerank", "model_name"), "RERANK_MODEL_NAME" - ), - "", - ), - queue_interval_seconds=queue_interval_seconds, - query_instruction=_coerce_str( - _get_value(data, ("models", "rerank", "query_instruction"), None), "" - ), - request_params=_get_model_request_params(data, "rerank"), - ) - - @staticmethod - def _parse_chat_model_config(data: dict[str, Any]) -> ChatModelConfig: - queue_interval_seconds = _normalize_queue_interval( - _coerce_float( - _get_value( - data, - ("models", "chat", "queue_interval_seconds"), - "CHAT_MODEL_QUEUE_INTERVAL", - ), - 1.0, - ) - ) - thinking_include_budget, thinking_tool_call_compat = ( - _resolve_thinking_compat_flags( - data=data, - model_name="chat", - include_budget_env_key="CHAT_MODEL_THINKING_INCLUDE_BUDGET", - tool_call_compat_env_key="CHAT_MODEL_THINKING_TOOL_CALL_COMPAT", - legacy_env_key="CHAT_MODEL_DEEPSEEK_NEW_COT_SUPPORT", - ) - ) - api_mode = _resolve_api_mode(data, "chat", "CHAT_MODEL_API_MODE") - responses_tool_choice_compat = _resolve_responses_tool_choice_compat( - data, "chat", "CHAT_MODEL_RESPONSES_TOOL_CHOICE_COMPAT" - ) - responses_force_stateless_replay = _resolve_responses_force_stateless_replay( - data, "chat", "CHAT_MODEL_RESPONSES_FORCE_STATELESS_REPLAY" - ) - prompt_cache_enabled = _coerce_bool( - _get_value( - data, - ("models", "chat", "prompt_cache_enabled"), - "CHAT_MODEL_PROMPT_CACHE_ENABLED", - ), - True, - ) - reasoning_enabled = _coerce_bool( - _get_value( - data, - ("models", "chat", "reasoning_enabled"), - "CHAT_MODEL_REASONING_ENABLED", - ), - False, - ) - reasoning_effort = _resolve_reasoning_effort( - _get_value( - data, - ("models", "chat", "reasoning_effort"), - "CHAT_MODEL_REASONING_EFFORT", - ), - "medium", - ) - config = ChatModelConfig( - api_url=_coerce_str( - _get_value(data, ("models", "chat", "api_url"), "CHAT_MODEL_API_URL"), - "", - ), - api_key=_coerce_str( - _get_value(data, ("models", "chat", "api_key"), "CHAT_MODEL_API_KEY"), - "", - ), - model_name=_coerce_str( - _get_value(data, ("models", "chat", "model_name"), "CHAT_MODEL_NAME"), - "", - ), - max_tokens=_coerce_int( - _get_value( - data, ("models", "chat", "max_tokens"), "CHAT_MODEL_MAX_TOKENS" - ), - 8192, - ), - queue_interval_seconds=queue_interval_seconds, - api_mode=api_mode, - thinking_enabled=_coerce_bool( - _get_value( - data, - ("models", "chat", "thinking_enabled"), - "CHAT_MODEL_THINKING_ENABLED", - ), - False, - ), - thinking_budget_tokens=_coerce_int( - _get_value( - data, - ("models", "chat", "thinking_budget_tokens"), - "CHAT_MODEL_THINKING_BUDGET_TOKENS", - ), - 20000, - ), - thinking_include_budget=thinking_include_budget, - reasoning_effort_style=_resolve_reasoning_effort_style( - _get_value( - data, - ("models", "chat", "reasoning_effort_style"), - "CHAT_MODEL_REASONING_EFFORT_STYLE", - ), - ), - thinking_tool_call_compat=thinking_tool_call_compat, - responses_tool_choice_compat=responses_tool_choice_compat, - responses_force_stateless_replay=responses_force_stateless_replay, - prompt_cache_enabled=prompt_cache_enabled, - reasoning_enabled=reasoning_enabled, - reasoning_effort=reasoning_effort, - request_params=_get_model_request_params(data, "chat"), - ) - config.pool = Config._parse_model_pool(data, "chat", config) - return config - - @staticmethod - def _parse_vision_model_config(data: dict[str, Any]) -> VisionModelConfig: - queue_interval_seconds = _normalize_queue_interval( - _coerce_float( - _get_value( - data, - ("models", "vision", "queue_interval_seconds"), - "VISION_MODEL_QUEUE_INTERVAL", - ), - 1.0, - ) - ) - thinking_include_budget, thinking_tool_call_compat = ( - _resolve_thinking_compat_flags( - data=data, - model_name="vision", - include_budget_env_key="VISION_MODEL_THINKING_INCLUDE_BUDGET", - tool_call_compat_env_key="VISION_MODEL_THINKING_TOOL_CALL_COMPAT", - legacy_env_key="VISION_MODEL_DEEPSEEK_NEW_COT_SUPPORT", - ) - ) - api_mode = _resolve_api_mode(data, "vision", "VISION_MODEL_API_MODE") - responses_tool_choice_compat = _resolve_responses_tool_choice_compat( - data, "vision", "VISION_MODEL_RESPONSES_TOOL_CHOICE_COMPAT" - ) - responses_force_stateless_replay = _resolve_responses_force_stateless_replay( - data, "vision", "VISION_MODEL_RESPONSES_FORCE_STATELESS_REPLAY" - ) - prompt_cache_enabled = _coerce_bool( - _get_value( - data, - ("models", "vision", "prompt_cache_enabled"), - "VISION_MODEL_PROMPT_CACHE_ENABLED", - ), - True, - ) - reasoning_enabled = _coerce_bool( - _get_value( - data, - ("models", "vision", "reasoning_enabled"), - "VISION_MODEL_REASONING_ENABLED", - ), - False, - ) - reasoning_effort = _resolve_reasoning_effort( - _get_value( - data, - ("models", "vision", "reasoning_effort"), - "VISION_MODEL_REASONING_EFFORT", - ), - "medium", - ) - return VisionModelConfig( - api_url=_coerce_str( - _get_value( - data, ("models", "vision", "api_url"), "VISION_MODEL_API_URL" - ), - "", - ), - api_key=_coerce_str( - _get_value( - data, ("models", "vision", "api_key"), "VISION_MODEL_API_KEY" - ), - "", - ), - model_name=_coerce_str( - _get_value( - data, ("models", "vision", "model_name"), "VISION_MODEL_NAME" - ), - "", - ), - queue_interval_seconds=queue_interval_seconds, - api_mode=api_mode, - thinking_enabled=_coerce_bool( - _get_value( - data, - ("models", "vision", "thinking_enabled"), - "VISION_MODEL_THINKING_ENABLED", - ), - False, - ), - thinking_budget_tokens=_coerce_int( - _get_value( - data, - ("models", "vision", "thinking_budget_tokens"), - "VISION_MODEL_THINKING_BUDGET_TOKENS", - ), - 20000, - ), - thinking_include_budget=thinking_include_budget, - reasoning_effort_style=_resolve_reasoning_effort_style( - _get_value( - data, - ("models", "vision", "reasoning_effort_style"), - "VISION_MODEL_REASONING_EFFORT_STYLE", - ), - ), - thinking_tool_call_compat=thinking_tool_call_compat, - responses_tool_choice_compat=responses_tool_choice_compat, - responses_force_stateless_replay=responses_force_stateless_replay, - prompt_cache_enabled=prompt_cache_enabled, - reasoning_enabled=reasoning_enabled, - reasoning_effort=reasoning_effort, - request_params=_get_model_request_params(data, "vision"), - ) - - @staticmethod - def _parse_security_model_config( - data: dict[str, Any], chat_model: ChatModelConfig - ) -> SecurityModelConfig: - api_url = _coerce_str( - _get_value( - data, ("models", "security", "api_url"), "SECURITY_MODEL_API_URL" - ), - "", - ) - api_key = _coerce_str( - _get_value( - data, ("models", "security", "api_key"), "SECURITY_MODEL_API_KEY" - ), - "", - ) - model_name = _coerce_str( - _get_value( - data, ("models", "security", "model_name"), "SECURITY_MODEL_NAME" - ), - "", - ) - queue_interval_seconds = _coerce_float( - _get_value( - data, - ("models", "security", "queue_interval_seconds"), - "SECURITY_MODEL_QUEUE_INTERVAL", - ), - 1.0, - ) - queue_interval_seconds = _normalize_queue_interval(queue_interval_seconds) - - thinking_include_budget, thinking_tool_call_compat = ( - _resolve_thinking_compat_flags( - data=data, - model_name="security", - include_budget_env_key="SECURITY_MODEL_THINKING_INCLUDE_BUDGET", - tool_call_compat_env_key="SECURITY_MODEL_THINKING_TOOL_CALL_COMPAT", - legacy_env_key="SECURITY_MODEL_DEEPSEEK_NEW_COT_SUPPORT", - ) - ) - api_mode = _resolve_api_mode(data, "security", "SECURITY_MODEL_API_MODE") - responses_tool_choice_compat = _resolve_responses_tool_choice_compat( - data, "security", "SECURITY_MODEL_RESPONSES_TOOL_CHOICE_COMPAT" - ) - responses_force_stateless_replay = _resolve_responses_force_stateless_replay( - data, "security", "SECURITY_MODEL_RESPONSES_FORCE_STATELESS_REPLAY" - ) - prompt_cache_enabled = _coerce_bool( - _get_value( - data, - ("models", "security", "prompt_cache_enabled"), - "SECURITY_MODEL_PROMPT_CACHE_ENABLED", - ), - True, - ) - reasoning_enabled = _coerce_bool( - _get_value( - data, - ("models", "security", "reasoning_enabled"), - "SECURITY_MODEL_REASONING_ENABLED", - ), - False, - ) - reasoning_effort = _resolve_reasoning_effort( - _get_value( - data, - ("models", "security", "reasoning_effort"), - "SECURITY_MODEL_REASONING_EFFORT", - ), - "medium", - ) - - if api_url and api_key and model_name: - return SecurityModelConfig( - api_url=api_url, - api_key=api_key, - model_name=model_name, - max_tokens=_coerce_int( - _get_value( - data, - ("models", "security", "max_tokens"), - "SECURITY_MODEL_MAX_TOKENS", - ), - 100, - ), - queue_interval_seconds=queue_interval_seconds, - api_mode=api_mode, - thinking_enabled=_coerce_bool( - _get_value( - data, - ("models", "security", "thinking_enabled"), - "SECURITY_MODEL_THINKING_ENABLED", - ), - False, - ), - thinking_budget_tokens=_coerce_int( - _get_value( - data, - ("models", "security", "thinking_budget_tokens"), - "SECURITY_MODEL_THINKING_BUDGET_TOKENS", - ), - 0, - ), - thinking_include_budget=thinking_include_budget, - reasoning_effort_style=_resolve_reasoning_effort_style( - _get_value( - data, - ("models", "security", "reasoning_effort_style"), - "SECURITY_MODEL_REASONING_EFFORT_STYLE", - ), - ), - thinking_tool_call_compat=thinking_tool_call_compat, - responses_tool_choice_compat=responses_tool_choice_compat, - responses_force_stateless_replay=responses_force_stateless_replay, - prompt_cache_enabled=prompt_cache_enabled, - reasoning_enabled=reasoning_enabled, - reasoning_effort=reasoning_effort, - request_params=_get_model_request_params(data, "security"), - ) - - logger.warning("未配置安全模型,将使用对话模型作为后备") - return SecurityModelConfig( - api_url=chat_model.api_url, - api_key=chat_model.api_key, - model_name=chat_model.model_name, - max_tokens=chat_model.max_tokens, - queue_interval_seconds=chat_model.queue_interval_seconds, - api_mode=chat_model.api_mode, - thinking_enabled=False, - thinking_budget_tokens=0, - thinking_include_budget=True, - reasoning_effort_style="openai", - thinking_tool_call_compat=chat_model.thinking_tool_call_compat, - responses_tool_choice_compat=chat_model.responses_tool_choice_compat, - responses_force_stateless_replay=chat_model.responses_force_stateless_replay, - prompt_cache_enabled=chat_model.prompt_cache_enabled, - reasoning_enabled=chat_model.reasoning_enabled, - reasoning_effort=chat_model.reasoning_effort, - request_params=merge_request_params(chat_model.request_params), - ) - - @staticmethod - def _parse_naga_model_config( - data: dict[str, Any], security_model: SecurityModelConfig - ) -> SecurityModelConfig: - api_url = _coerce_str( - _get_value(data, ("models", "naga", "api_url"), "NAGA_MODEL_API_URL"), - "", - ) - api_key = _coerce_str( - _get_value(data, ("models", "naga", "api_key"), "NAGA_MODEL_API_KEY"), - "", - ) - model_name = _coerce_str( - _get_value(data, ("models", "naga", "model_name"), "NAGA_MODEL_NAME"), - "", - ) - queue_interval_seconds = _coerce_float( - _get_value( - data, - ("models", "naga", "queue_interval_seconds"), - "NAGA_MODEL_QUEUE_INTERVAL", - ), - security_model.queue_interval_seconds, - ) - queue_interval_seconds = _normalize_queue_interval(queue_interval_seconds) - - thinking_include_budget, thinking_tool_call_compat = ( - _resolve_thinking_compat_flags( - data=data, - model_name="naga", - include_budget_env_key="NAGA_MODEL_THINKING_INCLUDE_BUDGET", - tool_call_compat_env_key="NAGA_MODEL_THINKING_TOOL_CALL_COMPAT", - legacy_env_key="NAGA_MODEL_DEEPSEEK_NEW_COT_SUPPORT", - ) - ) - api_mode = _resolve_api_mode(data, "naga", "NAGA_MODEL_API_MODE") - responses_tool_choice_compat = _resolve_responses_tool_choice_compat( - data, "naga", "NAGA_MODEL_RESPONSES_TOOL_CHOICE_COMPAT" - ) - responses_force_stateless_replay = _resolve_responses_force_stateless_replay( - data, "naga", "NAGA_MODEL_RESPONSES_FORCE_STATELESS_REPLAY" - ) - prompt_cache_enabled = _coerce_bool( - _get_value( - data, - ("models", "naga", "prompt_cache_enabled"), - "NAGA_MODEL_PROMPT_CACHE_ENABLED", - ), - getattr(security_model, "prompt_cache_enabled", True), - ) - reasoning_enabled = _coerce_bool( - _get_value( - data, - ("models", "naga", "reasoning_enabled"), - "NAGA_MODEL_REASONING_ENABLED", - ), - getattr(security_model, "reasoning_enabled", False), - ) - reasoning_effort = _resolve_reasoning_effort( - _get_value( - data, - ("models", "naga", "reasoning_effort"), - "NAGA_MODEL_REASONING_EFFORT", - ), - getattr(security_model, "reasoning_effort", "medium"), - ) - - if api_url and api_key and model_name: - return SecurityModelConfig( - api_url=api_url, - api_key=api_key, - model_name=model_name, - max_tokens=_coerce_int( - _get_value( - data, - ("models", "naga", "max_tokens"), - "NAGA_MODEL_MAX_TOKENS", - ), - 160, - ), - queue_interval_seconds=queue_interval_seconds, - api_mode=api_mode, - thinking_enabled=_coerce_bool( - _get_value( - data, - ("models", "naga", "thinking_enabled"), - "NAGA_MODEL_THINKING_ENABLED", - ), - False, - ), - thinking_budget_tokens=_coerce_int( - _get_value( - data, - ("models", "naga", "thinking_budget_tokens"), - "NAGA_MODEL_THINKING_BUDGET_TOKENS", - ), - 0, - ), - thinking_include_budget=thinking_include_budget, - reasoning_effort_style=_resolve_reasoning_effort_style( - _get_value( - data, - ("models", "naga", "reasoning_effort_style"), - "NAGA_MODEL_REASONING_EFFORT_STYLE", - ), - ), - thinking_tool_call_compat=thinking_tool_call_compat, - responses_tool_choice_compat=responses_tool_choice_compat, - responses_force_stateless_replay=responses_force_stateless_replay, - prompt_cache_enabled=prompt_cache_enabled, - reasoning_enabled=reasoning_enabled, - reasoning_effort=reasoning_effort, - request_params=_get_model_request_params(data, "naga"), - ) - - logger.info( - "未配置 Naga 审核模型,将使用已解析的安全模型配置作为后备(安全模型本身可能已回退)" - ) - return SecurityModelConfig( - api_url=security_model.api_url, - api_key=security_model.api_key, - model_name=security_model.model_name, - max_tokens=security_model.max_tokens, - queue_interval_seconds=security_model.queue_interval_seconds, - api_mode=security_model.api_mode, - thinking_enabled=security_model.thinking_enabled, - thinking_budget_tokens=security_model.thinking_budget_tokens, - thinking_include_budget=security_model.thinking_include_budget, - reasoning_effort_style=security_model.reasoning_effort_style, - thinking_tool_call_compat=security_model.thinking_tool_call_compat, - responses_tool_choice_compat=security_model.responses_tool_choice_compat, - responses_force_stateless_replay=security_model.responses_force_stateless_replay, - prompt_cache_enabled=security_model.prompt_cache_enabled, - reasoning_enabled=security_model.reasoning_enabled, - reasoning_effort=security_model.reasoning_effort, - request_params=merge_request_params(security_model.request_params), - ) - - @staticmethod - def _parse_agent_model_config(data: dict[str, Any]) -> AgentModelConfig: - queue_interval_seconds = _normalize_queue_interval( - _coerce_float( - _get_value( - data, - ("models", "agent", "queue_interval_seconds"), - "AGENT_MODEL_QUEUE_INTERVAL", - ), - 1.0, - ) - ) - thinking_include_budget, thinking_tool_call_compat = ( - _resolve_thinking_compat_flags( - data=data, - model_name="agent", - include_budget_env_key="AGENT_MODEL_THINKING_INCLUDE_BUDGET", - tool_call_compat_env_key="AGENT_MODEL_THINKING_TOOL_CALL_COMPAT", - legacy_env_key="AGENT_MODEL_DEEPSEEK_NEW_COT_SUPPORT", - ) - ) - api_mode = _resolve_api_mode(data, "agent", "AGENT_MODEL_API_MODE") - responses_tool_choice_compat = _resolve_responses_tool_choice_compat( - data, "agent", "AGENT_MODEL_RESPONSES_TOOL_CHOICE_COMPAT" - ) - responses_force_stateless_replay = _resolve_responses_force_stateless_replay( - data, "agent", "AGENT_MODEL_RESPONSES_FORCE_STATELESS_REPLAY" - ) - prompt_cache_enabled = _coerce_bool( - _get_value( - data, - ("models", "agent", "prompt_cache_enabled"), - "AGENT_MODEL_PROMPT_CACHE_ENABLED", - ), - True, - ) - reasoning_enabled = _coerce_bool( - _get_value( - data, - ("models", "agent", "reasoning_enabled"), - "AGENT_MODEL_REASONING_ENABLED", - ), - False, - ) - reasoning_effort = _resolve_reasoning_effort( - _get_value( - data, - ("models", "agent", "reasoning_effort"), - "AGENT_MODEL_REASONING_EFFORT", - ), - "medium", - ) - config = AgentModelConfig( - api_url=_coerce_str( - _get_value(data, ("models", "agent", "api_url"), "AGENT_MODEL_API_URL"), - "", - ), - api_key=_coerce_str( - _get_value(data, ("models", "agent", "api_key"), "AGENT_MODEL_API_KEY"), - "", - ), - model_name=_coerce_str( - _get_value(data, ("models", "agent", "model_name"), "AGENT_MODEL_NAME"), - "", - ), - max_tokens=_coerce_int( - _get_value( - data, ("models", "agent", "max_tokens"), "AGENT_MODEL_MAX_TOKENS" - ), - 4096, - ), - queue_interval_seconds=queue_interval_seconds, - api_mode=api_mode, - thinking_enabled=_coerce_bool( - _get_value( - data, - ("models", "agent", "thinking_enabled"), - "AGENT_MODEL_THINKING_ENABLED", - ), - False, - ), - thinking_budget_tokens=_coerce_int( - _get_value( - data, - ("models", "agent", "thinking_budget_tokens"), - "AGENT_MODEL_THINKING_BUDGET_TOKENS", - ), - 0, - ), - thinking_include_budget=thinking_include_budget, - reasoning_effort_style=_resolve_reasoning_effort_style( - _get_value( - data, - ("models", "agent", "reasoning_effort_style"), - "AGENT_MODEL_REASONING_EFFORT_STYLE", - ), - ), - thinking_tool_call_compat=thinking_tool_call_compat, - responses_tool_choice_compat=responses_tool_choice_compat, - responses_force_stateless_replay=responses_force_stateless_replay, - prompt_cache_enabled=prompt_cache_enabled, - reasoning_enabled=reasoning_enabled, - reasoning_effort=reasoning_effort, - request_params=_get_model_request_params(data, "agent"), - ) - config.pool = Config._parse_model_pool(data, "agent", config) - return config - - @staticmethod - def _parse_grok_model_config(data: dict[str, Any]) -> GrokModelConfig: - queue_interval_seconds = _normalize_queue_interval( - _coerce_float( - _get_value( - data, - ("models", "grok", "queue_interval_seconds"), - "GROK_MODEL_QUEUE_INTERVAL", - ), - 1.0, - ) - ) - return GrokModelConfig( - api_url=_coerce_str( - _get_value(data, ("models", "grok", "api_url"), "GROK_MODEL_API_URL"), - "", - ), - api_key=_coerce_str( - _get_value(data, ("models", "grok", "api_key"), "GROK_MODEL_API_KEY"), - "", - ), - model_name=_coerce_str( - _get_value(data, ("models", "grok", "model_name"), "GROK_MODEL_NAME"), - "", - ), - max_tokens=_coerce_int( - _get_value( - data, ("models", "grok", "max_tokens"), "GROK_MODEL_MAX_TOKENS" - ), - 8192, - ), - queue_interval_seconds=queue_interval_seconds, - thinking_enabled=_coerce_bool( - _get_value( - data, - ("models", "grok", "thinking_enabled"), - "GROK_MODEL_THINKING_ENABLED", - ), - False, - ), - thinking_budget_tokens=_coerce_int( - _get_value( - data, - ("models", "grok", "thinking_budget_tokens"), - "GROK_MODEL_THINKING_BUDGET_TOKENS", - ), - 20000, - ), - thinking_include_budget=_coerce_bool( - _get_value( - data, - ("models", "grok", "thinking_include_budget"), - "GROK_MODEL_THINKING_INCLUDE_BUDGET", - ), - True, - ), - reasoning_effort_style=_resolve_reasoning_effort_style( - _get_value( - data, - ("models", "grok", "reasoning_effort_style"), - "GROK_MODEL_REASONING_EFFORT_STYLE", - ), - ), - prompt_cache_enabled=_coerce_bool( - _get_value( - data, - ("models", "grok", "prompt_cache_enabled"), - "GROK_MODEL_PROMPT_CACHE_ENABLED", - ), - True, - ), - reasoning_enabled=_coerce_bool( - _get_value( - data, - ("models", "grok", "reasoning_enabled"), - "GROK_MODEL_REASONING_ENABLED", - ), - False, - ), - reasoning_effort=_resolve_reasoning_effort( - _get_value( - data, - ("models", "grok", "reasoning_effort"), - "GROK_MODEL_REASONING_EFFORT", - ), - "medium", - ), - request_params=_get_model_request_params(data, "grok"), - ) - - @staticmethod - def _parse_image_gen_model_config(data: dict[str, Any]) -> ImageGenModelConfig: - """解析 [models.image_gen] 生图模型配置""" - return ImageGenModelConfig( - api_url=_coerce_str( - _get_value( - data, ("models", "image_gen", "api_url"), "IMAGE_GEN_MODEL_API_URL" - ), - "", - ), - api_key=_coerce_str( - _get_value( - data, ("models", "image_gen", "api_key"), "IMAGE_GEN_MODEL_API_KEY" - ), - "", - ), - model_name=_coerce_str( - _get_value( - data, ("models", "image_gen", "model_name"), "IMAGE_GEN_MODEL_NAME" - ), - "", - ), - request_params=_get_model_request_params(data, "image_gen"), - ) - - @staticmethod - def _parse_image_edit_model_config(data: dict[str, Any]) -> ImageGenModelConfig: - """解析 [models.image_edit] 参考图生图模型配置""" - return ImageGenModelConfig( - api_url=_coerce_str( - _get_value( - data, - ("models", "image_edit", "api_url"), - "IMAGE_EDIT_MODEL_API_URL", - ), - "", - ), - api_key=_coerce_str( - _get_value( - data, - ("models", "image_edit", "api_key"), - "IMAGE_EDIT_MODEL_API_KEY", - ), - "", - ), - model_name=_coerce_str( - _get_value( - data, - ("models", "image_edit", "model_name"), - "IMAGE_EDIT_MODEL_NAME", - ), - "", - ), - request_params=_get_model_request_params(data, "image_edit"), - ) - - @staticmethod - def _parse_image_gen_config(data: dict[str, Any]) -> ImageGenConfig: - """解析 [image_gen] 生图工具配置""" - return ImageGenConfig( - provider=_coerce_str( - _get_value(data, ("image_gen", "provider"), "IMAGE_GEN_PROVIDER"), - "xingzhige", - ), - xingzhige_size=_coerce_str( - _get_value(data, ("image_gen", "xingzhige_size"), None), "1:1" - ), - openai_size=_coerce_str( - _get_value(data, ("image_gen", "openai_size"), None), "" - ), - openai_quality=_coerce_str( - _get_value(data, ("image_gen", "openai_quality"), None), "" - ), - openai_style=_coerce_str( - _get_value(data, ("image_gen", "openai_style"), None), "" - ), - openai_timeout=_coerce_float( - _get_value(data, ("image_gen", "openai_timeout"), None), 120.0 - ), - ) - - @staticmethod - def _merge_admins( - superadmin_qq: int, admin_qqs: list[int] - ) -> tuple[int, list[int]]: - local_admins = load_local_admins() - all_admins = list(set(admin_qqs + local_admins)) - if superadmin_qq and superadmin_qq not in all_admins: - all_admins.append(superadmin_qq) - return superadmin_qq, all_admins - - @staticmethod - def _verify_required_fields( - bot_qq: int, - superadmin_qq: int, - onebot_ws_url: str, - chat_model: ChatModelConfig, - vision_model: VisionModelConfig, - agent_model: AgentModelConfig, - knowledge_enabled: bool, - embedding_model: EmbeddingModelConfig, - ) -> None: - missing: list[str] = [] - if bot_qq <= 0: - missing.append("core.bot_qq") - if superadmin_qq <= 0: - missing.append("core.superadmin_qq") - if not onebot_ws_url: - missing.append("onebot.ws_url") - if not chat_model.api_url: - missing.append("models.chat.api_url") - if not chat_model.api_key: - missing.append("models.chat.api_key") - if not chat_model.model_name: - missing.append("models.chat.model_name") - if not vision_model.api_url: - missing.append("models.vision.api_url") - if not vision_model.api_key: - missing.append("models.vision.api_key") - if not vision_model.model_name: - missing.append("models.vision.model_name") - if not agent_model.api_url: - missing.append("models.agent.api_url") - if not agent_model.api_key: - missing.append("models.agent.api_key") - if not agent_model.model_name: - missing.append("models.agent.model_name") - if knowledge_enabled: - if not embedding_model.api_url: - missing.append("models.embedding.api_url") - if not embedding_model.model_name: - missing.append("models.embedding.model_name") - if missing: - raise ValueError(f"缺少必需配置: {', '.join(missing)}") - - @staticmethod - def _log_debug_info( - chat_model: ChatModelConfig, - vision_model: VisionModelConfig, - security_model: SecurityModelConfig, - naga_model: SecurityModelConfig, - agent_model: AgentModelConfig, - grok_model: GrokModelConfig, - ) -> None: - configs: list[ - tuple[ - str, - ChatModelConfig - | VisionModelConfig - | SecurityModelConfig - | AgentModelConfig - | GrokModelConfig, - ] - ] = [ - ("chat", chat_model), - ("vision", vision_model), - ("security", security_model), - ("naga", naga_model), - ("agent", agent_model), - ("grok", grok_model), - ] - for name, cfg in configs: - logger.debug( - "[配置] %s_model=%s api_url=%s api_key_set=%s api_mode=%s thinking=%s reasoning=%s/%s cot_compat=%s responses_tool_choice_compat=%s responses_force_stateless_replay=%s", - name, - cfg.model_name, - cfg.api_url, - bool(cfg.api_key), - getattr(cfg, "api_mode", "chat_completions"), - cfg.thinking_enabled, - getattr(cfg, "reasoning_enabled", False), - getattr(cfg, "reasoning_effort", "medium"), - getattr(cfg, "thinking_tool_call_compat", False), - getattr(cfg, "responses_tool_choice_compat", False), - getattr(cfg, "responses_force_stateless_replay", False), - ) - def update_from(self, new_config: "Config") -> dict[str, tuple[Any, Any]]: changes: dict[str, tuple[Any, Any]] = {} for field in fields(self): @@ -2740,358 +1577,6 @@ def update_from(self, new_config: "Config") -> dict[str, tuple[Any, Any]]: changes[name] = (old_value, new_value) return changes - @staticmethod - def _parse_historian_model_config( - data: dict[str, Any], fallback: AgentModelConfig - ) -> AgentModelConfig: - h = data.get("models", {}).get("historian", {}) - if not isinstance(h, dict) or not h: - return fallback - queue_interval_seconds = _coerce_float( - h.get("queue_interval_seconds"), fallback.queue_interval_seconds - ) - queue_interval_seconds = _normalize_queue_interval( - queue_interval_seconds, fallback.queue_interval_seconds - ) - thinking_include_budget, thinking_tool_call_compat = ( - _resolve_thinking_compat_flags( - data={"models": {"historian": h}}, - model_name="historian", - include_budget_env_key="HISTORIAN_MODEL_THINKING_INCLUDE_BUDGET", - tool_call_compat_env_key="HISTORIAN_MODEL_THINKING_TOOL_CALL_COMPAT", - legacy_env_key="HISTORIAN_MODEL_DEEPSEEK_NEW_COT_SUPPORT", - ) - ) - api_mode = _resolve_api_mode( - {"models": {"historian": h}}, - "historian", - "HISTORIAN_MODEL_API_MODE", - fallback.api_mode, - ) - responses_tool_choice_compat = _resolve_responses_tool_choice_compat( - {"models": {"historian": h}}, - "historian", - "HISTORIAN_MODEL_RESPONSES_TOOL_CHOICE_COMPAT", - fallback.responses_tool_choice_compat, - ) - responses_force_stateless_replay = _resolve_responses_force_stateless_replay( - {"models": {"historian": h}}, - "historian", - "HISTORIAN_MODEL_RESPONSES_FORCE_STATELESS_REPLAY", - fallback.responses_force_stateless_replay, - ) - prompt_cache_enabled = _coerce_bool( - _get_value( - {"models": {"historian": h}}, - ("models", "historian", "prompt_cache_enabled"), - "HISTORIAN_MODEL_PROMPT_CACHE_ENABLED", - ), - fallback.prompt_cache_enabled, - ) - return AgentModelConfig( - api_url=_coerce_str(h.get("api_url"), fallback.api_url), - api_key=_coerce_str(h.get("api_key"), fallback.api_key), - model_name=_coerce_str(h.get("model_name"), fallback.model_name), - max_tokens=_coerce_int(h.get("max_tokens"), fallback.max_tokens), - queue_interval_seconds=queue_interval_seconds, - api_mode=api_mode, - thinking_enabled=_coerce_bool( - h.get("thinking_enabled"), fallback.thinking_enabled - ), - thinking_budget_tokens=_coerce_int( - h.get("thinking_budget_tokens"), fallback.thinking_budget_tokens - ), - thinking_include_budget=thinking_include_budget, - reasoning_effort_style=_resolve_reasoning_effort_style( - _get_value( - {"models": {"historian": h}}, - ("models", "historian", "reasoning_effort_style"), - "HISTORIAN_MODEL_REASONING_EFFORT_STYLE", - ), - fallback.reasoning_effort_style, - ), - thinking_tool_call_compat=thinking_tool_call_compat, - responses_tool_choice_compat=responses_tool_choice_compat, - responses_force_stateless_replay=responses_force_stateless_replay, - prompt_cache_enabled=prompt_cache_enabled, - reasoning_enabled=_coerce_bool( - _get_value( - {"models": {"historian": h}}, - ("models", "historian", "reasoning_enabled"), - "HISTORIAN_MODEL_REASONING_ENABLED", - ), - fallback.reasoning_enabled, - ), - reasoning_effort=_resolve_reasoning_effort( - _get_value( - {"models": {"historian": h}}, - ("models", "historian", "reasoning_effort"), - "HISTORIAN_MODEL_REASONING_EFFORT", - ), - fallback.reasoning_effort, - ), - request_params=merge_request_params( - fallback.request_params, - h.get("request_params"), - ), - ) - - @staticmethod - def _parse_cognitive_config(data: dict[str, Any]) -> CognitiveConfig: - cog = data.get("cognitive", {}) - vs = cog.get("vector_store", {}) if isinstance(cog, dict) else {} - q = cog.get("query", {}) if isinstance(cog, dict) else {} - hist = cog.get("historian", {}) if isinstance(cog, dict) else {} - prof = cog.get("profile", {}) if isinstance(cog, dict) else {} - que = cog.get("queue", {}) if isinstance(cog, dict) else {} - return CognitiveConfig( - enabled=_coerce_bool( - cog.get("enabled") if isinstance(cog, dict) else None, True - ), - bot_name=_coerce_str( - cog.get("bot_name") if isinstance(cog, dict) else None, - "Undefined", - ), - vector_store_path=_coerce_str( - vs.get("path") if isinstance(vs, dict) else None, - "data/cognitive/chromadb", - ), - queue_path=_coerce_str( - que.get("path") if isinstance(que, dict) else None, - "data/cognitive/queues", - ), - profiles_path=_coerce_str( - prof.get("path") if isinstance(prof, dict) else None, - "data/cognitive/profiles", - ), - auto_top_k=_coerce_int( - q.get("auto_top_k") if isinstance(q, dict) else None, 3 - ), - auto_scope_candidate_multiplier=_coerce_int( - q.get("auto_scope_candidate_multiplier") - if isinstance(q, dict) - else None, - 2, - ), - auto_current_group_boost=_coerce_float( - q.get("auto_current_group_boost") if isinstance(q, dict) else None, - 1.15, - ), - auto_current_private_boost=_coerce_float( - q.get("auto_current_private_boost") if isinstance(q, dict) else None, - 1.25, - ), - enable_rerank=_coerce_bool( - q.get("enable_rerank") if isinstance(q, dict) else None, True - ), - recent_end_summaries_inject_k=_coerce_int( - q.get("recent_end_summaries_inject_k") if isinstance(q, dict) else None, - 30, - ), - time_decay_enabled=_coerce_bool( - q.get("time_decay_enabled") if isinstance(q, dict) else None, True - ), - time_decay_half_life_days_auto=_coerce_float( - q.get("time_decay_half_life_days_auto") - if isinstance(q, dict) - else None, - 14.0, - ), - time_decay_half_life_days_tool=_coerce_float( - q.get("time_decay_half_life_days_tool") - if isinstance(q, dict) - else None, - 60.0, - ), - time_decay_boost=_coerce_float( - q.get("time_decay_boost") if isinstance(q, dict) else None, 0.2 - ), - time_decay_min_similarity=_coerce_float( - q.get("time_decay_min_similarity") if isinstance(q, dict) else None, - 0.35, - ), - tool_default_top_k=_coerce_int( - q.get("tool_default_top_k") if isinstance(q, dict) else None, 12 - ), - profile_top_k=_coerce_int( - q.get("profile_top_k") if isinstance(q, dict) else None, 8 - ), - rerank_candidate_multiplier=_coerce_int( - q.get("rerank_candidate_multiplier") if isinstance(q, dict) else None, 3 - ), - rewrite_max_retry=_coerce_int( - hist.get("rewrite_max_retry") if isinstance(hist, dict) else None, 2 - ), - historian_recent_messages_inject_k=_coerce_int( - hist.get("recent_messages_inject_k") - if isinstance(hist, dict) - else None, - 12, - ), - historian_recent_message_line_max_len=_coerce_int( - hist.get("recent_message_line_max_len") - if isinstance(hist, dict) - else None, - 240, - ), - historian_source_message_max_len=_coerce_int( - hist.get("source_message_max_len") if isinstance(hist, dict) else None, - 800, - ), - poll_interval_seconds=_coerce_float( - hist.get("poll_interval_seconds") if isinstance(hist, dict) else None, - 1.0, - ), - stale_job_timeout_seconds=_coerce_float( - hist.get("stale_job_timeout_seconds") - if isinstance(hist, dict) - else None, - 300.0, - ), - profile_revision_keep=_coerce_int( - prof.get("revision_keep") if isinstance(prof, dict) else None, 5 - ), - failed_max_age_days=_coerce_int( - que.get("failed_max_age_days") if isinstance(que, dict) else None, 30 - ), - failed_max_files=_coerce_int( - que.get("failed_max_files") if isinstance(que, dict) else None, 500 - ), - failed_cleanup_interval=_coerce_int( - que.get("failed_cleanup_interval") if isinstance(que, dict) else None, - 100, - ), - job_max_retries=_coerce_int( - que.get("job_max_retries") if isinstance(que, dict) else None, 3 - ), - ) - - @staticmethod - def _parse_memes_config(data: dict[str, Any]) -> MemeConfig: - section_raw = data.get("memes", {}) - section = section_raw if isinstance(section_raw, dict) else {} - return MemeConfig( - enabled=_coerce_bool(section.get("enabled"), True), - query_default_mode=_coerce_str(section.get("query_default_mode"), "hybrid"), - max_source_image_bytes=max( - 1, - _coerce_int(section.get("max_source_image_bytes"), 500 * 1024), - ), - blob_dir=_coerce_str(section.get("blob_dir"), "data/memes/blobs"), - preview_dir=_coerce_str(section.get("preview_dir"), "data/memes/previews"), - db_path=_coerce_str(section.get("db_path"), "data/memes/memes.sqlite3"), - vector_store_path=_coerce_str( - section.get("vector_store_path"), "data/memes/chromadb" - ), - queue_path=_coerce_str(section.get("queue_path"), "data/memes/queues"), - max_items=max(1, _coerce_int(section.get("max_items"), 10000)), - max_total_bytes=max( - 1, - _coerce_int(section.get("max_total_bytes"), 5 * 1024 * 1024 * 1024), - ), - allow_gif=_coerce_bool(section.get("allow_gif"), True), - auto_ingest_group=_coerce_bool(section.get("auto_ingest_group"), True), - auto_ingest_private=_coerce_bool(section.get("auto_ingest_private"), True), - keyword_top_k=max(1, _coerce_int(section.get("keyword_top_k"), 30)), - semantic_top_k=max(1, _coerce_int(section.get("semantic_top_k"), 30)), - rerank_top_k=max(1, _coerce_int(section.get("rerank_top_k"), 20)), - worker_max_concurrency=max( - 1, _coerce_int(section.get("worker_max_concurrency"), 4) - ), - ) - - @staticmethod - def _parse_api_config(data: dict[str, Any]) -> APIConfig: - section_raw = data.get("api", {}) - section = section_raw if isinstance(section_raw, dict) else {} - - enabled = _coerce_bool(section.get("enabled"), True) - host = _coerce_str(section.get("host"), DEFAULT_API_HOST) - port = _coerce_int(section.get("port"), DEFAULT_API_PORT) - if port <= 0 or port > 65535: - port = DEFAULT_API_PORT - - auth_key = _coerce_str(section.get("auth_key"), DEFAULT_API_AUTH_KEY) - if not auth_key: - auth_key = DEFAULT_API_AUTH_KEY - - openapi_enabled = _coerce_bool(section.get("openapi_enabled"), True) - - tool_invoke_enabled = _coerce_bool(section.get("tool_invoke_enabled"), False) - tool_invoke_expose = _coerce_str( - section.get("tool_invoke_expose"), "tools+toolsets" - ) - valid_expose = {"tools", "toolsets", "tools+toolsets", "agents", "all"} - if tool_invoke_expose not in valid_expose: - tool_invoke_expose = "tools+toolsets" - tool_invoke_allowlist = _coerce_str_list(section.get("tool_invoke_allowlist")) - tool_invoke_denylist = _coerce_str_list(section.get("tool_invoke_denylist")) - tool_invoke_timeout = _coerce_int(section.get("tool_invoke_timeout"), 120) - if tool_invoke_timeout <= 0: - tool_invoke_timeout = 120 - tool_invoke_callback_timeout = _coerce_int( - section.get("tool_invoke_callback_timeout"), 10 - ) - if tool_invoke_callback_timeout <= 0: - tool_invoke_callback_timeout = 10 - - return APIConfig( - enabled=enabled, - host=host, - port=port, - auth_key=auth_key, - openapi_enabled=openapi_enabled, - tool_invoke_enabled=tool_invoke_enabled, - tool_invoke_expose=tool_invoke_expose, - tool_invoke_allowlist=tool_invoke_allowlist, - tool_invoke_denylist=tool_invoke_denylist, - tool_invoke_timeout=tool_invoke_timeout, - tool_invoke_callback_timeout=tool_invoke_callback_timeout, - ) - - @staticmethod - def _parse_naga_config(data: dict[str, Any]) -> NagaConfig: - section_raw = data.get("naga", {}) - section = section_raw if isinstance(section_raw, dict) else {} - - enabled = _coerce_bool(section.get("enabled"), False) - api_url = _coerce_str(section.get("api_url"), "") - api_key = _coerce_str(section.get("api_key"), "") - moderation_enabled = _coerce_bool(section.get("moderation_enabled"), True) - allowed_groups = frozenset(_coerce_int_list(section.get("allowed_groups"))) - - return NagaConfig( - enabled=enabled, - api_url=api_url, - api_key=api_key, - moderation_enabled=moderation_enabled, - allowed_groups=allowed_groups, - ) - - @staticmethod - def _parse_easter_egg_call_mode(value: Any) -> str: - """解析彩蛋提示模式。 - - 兼容旧版布尔值: - - True => agent - - False => none - """ - if isinstance(value, bool): - return "agent" if value else "none" - if isinstance(value, (int, float)): - return "agent" if bool(value) else "none" - if value is None: - return "none" - - text = str(value).strip().lower() - if text in {"true", "1", "yes", "on"}: - return "agent" - if text in {"false", "0", "no", "off"}: - return "none" - if text in {"none", "agent", "tools", "all", "clean"}: - return text - return "none" - def reload(self, strict: bool = False) -> dict[str, tuple[Any, Any]]: new_config = Config.load(strict=strict) return self.update_from(new_config) @@ -3121,20 +1606,3 @@ def is_superadmin(self, qq: int) -> bool: def is_admin(self, qq: int) -> bool: return qq in self.admin_qqs - - -def _update_dataclass( - old_value: Any, new_value: Any, prefix: str -) -> dict[str, tuple[Any, Any]]: - changes: dict[str, tuple[Any, Any]] = {} - if not isinstance(old_value, type(new_value)): - changes[prefix] = (old_value, new_value) - return changes - for field in fields(old_value): - name = field.name - old_attr = getattr(old_value, name) - new_attr = getattr(new_value, name) - if old_attr != new_attr: - setattr(old_value, name, new_attr) - changes[f"{prefix}.{name}"] = (old_attr, new_attr) - return changes diff --git a/src/Undefined/config/model_parsers.py b/src/Undefined/config/model_parsers.py new file mode 100644 index 00000000..13e16c2f --- /dev/null +++ b/src/Undefined/config/model_parsers.py @@ -0,0 +1,1258 @@ +"""Model configuration parsers extracted from Config class.""" + +from __future__ import annotations + +import logging +from typing import Any + +from Undefined.utils.request_params import merge_request_params + +from .admin import load_local_admins +from .coercers import ( + _coerce_bool, + _coerce_float, + _coerce_int, + _coerce_str, + _get_model_request_params, + _get_value, + _normalize_queue_interval, + _VALID_API_MODES, +) +from .models import ( + AgentModelConfig, + ChatModelConfig, + EmbeddingModelConfig, + GrokModelConfig, + ImageGenConfig, + ImageGenModelConfig, + ModelPool, + ModelPoolEntry, + RerankModelConfig, + SecurityModelConfig, + VisionModelConfig, +) +from .resolvers import ( + _resolve_api_mode, + _resolve_reasoning_effort, + _resolve_reasoning_effort_style, + _resolve_responses_force_stateless_replay, + _resolve_responses_tool_choice_compat, + _resolve_thinking_compat_flags, +) + +logger = logging.getLogger(__name__) + + +def _parse_model_pool( + data: dict[str, Any], + model_section: str, + primary_config: ChatModelConfig | AgentModelConfig, +) -> ModelPool | None: + """解析模型池配置,缺省字段继承 primary_config""" + pool_data = data.get("models", {}).get(model_section, {}).get("pool") + if not isinstance(pool_data, dict): + return None + + enabled = _coerce_bool(pool_data.get("enabled"), False) + strategy = _coerce_str(pool_data.get("strategy"), "default").strip().lower() + if strategy not in ("default", "round_robin", "random"): + strategy = "default" + + raw_models = pool_data.get("models") + if not isinstance(raw_models, list): + return ModelPool(enabled=enabled, strategy=strategy) + + entries: list[ModelPoolEntry] = [] + for item in raw_models: + if not isinstance(item, dict): + continue + name = _coerce_str(item.get("model_name"), "").strip() + if not name: + continue + entries.append( + ModelPoolEntry( + api_url=_coerce_str(item.get("api_url"), primary_config.api_url), + api_key=_coerce_str(item.get("api_key"), primary_config.api_key), + model_name=name, + max_tokens=_coerce_int( + item.get("max_tokens"), primary_config.max_tokens + ), + queue_interval_seconds=_normalize_queue_interval( + _coerce_float( + item.get("queue_interval_seconds"), + primary_config.queue_interval_seconds, + ), + primary_config.queue_interval_seconds, + ), + api_mode=( + _coerce_str(item.get("api_mode"), primary_config.api_mode) + .strip() + .lower() + ) + if _coerce_str(item.get("api_mode"), primary_config.api_mode) + .strip() + .lower() + in _VALID_API_MODES + else primary_config.api_mode, + thinking_enabled=_coerce_bool( + item.get("thinking_enabled"), primary_config.thinking_enabled + ), + thinking_budget_tokens=_coerce_int( + item.get("thinking_budget_tokens"), + primary_config.thinking_budget_tokens, + ), + thinking_include_budget=_coerce_bool( + item.get("thinking_include_budget"), + primary_config.thinking_include_budget, + ), + reasoning_effort_style=_resolve_reasoning_effort_style( + item.get("reasoning_effort_style"), + primary_config.reasoning_effort_style, + ), + thinking_tool_call_compat=_coerce_bool( + item.get("thinking_tool_call_compat"), + primary_config.thinking_tool_call_compat, + ), + responses_tool_choice_compat=_coerce_bool( + item.get("responses_tool_choice_compat"), + primary_config.responses_tool_choice_compat, + ), + responses_force_stateless_replay=_coerce_bool( + item.get("responses_force_stateless_replay"), + primary_config.responses_force_stateless_replay, + ), + prompt_cache_enabled=_coerce_bool( + item.get("prompt_cache_enabled"), + primary_config.prompt_cache_enabled, + ), + reasoning_enabled=_coerce_bool( + item.get("reasoning_enabled"), + primary_config.reasoning_enabled, + ), + reasoning_effort=_resolve_reasoning_effort( + item.get("reasoning_effort"), + primary_config.reasoning_effort, + ), + request_params=merge_request_params( + primary_config.request_params, + item.get("request_params"), + ), + ) + ) + + return ModelPool(enabled=enabled, strategy=strategy, models=entries) + + +def _parse_embedding_model_config(data: dict[str, Any]) -> EmbeddingModelConfig: + return EmbeddingModelConfig( + api_url=_coerce_str( + _get_value( + data, ("models", "embedding", "api_url"), "EMBEDDING_MODEL_API_URL" + ), + "", + ), + api_key=_coerce_str( + _get_value( + data, ("models", "embedding", "api_key"), "EMBEDDING_MODEL_API_KEY" + ), + "", + ), + model_name=_coerce_str( + _get_value( + data, ("models", "embedding", "model_name"), "EMBEDDING_MODEL_NAME" + ), + "", + ), + queue_interval_seconds=_normalize_queue_interval( + _coerce_float( + _get_value( + data, ("models", "embedding", "queue_interval_seconds"), None + ), + 0.0, + ), + 0.0, + ), + dimensions=_coerce_int( + _get_value(data, ("models", "embedding", "dimensions"), None), 0 + ) + or None, + query_instruction=_coerce_str( + _get_value(data, ("models", "embedding", "query_instruction"), None), "" + ), + document_instruction=_coerce_str( + _get_value(data, ("models", "embedding", "document_instruction"), None), + "", + ), + request_params=_get_model_request_params(data, "embedding"), + ) + + +def _parse_rerank_model_config(data: dict[str, Any]) -> RerankModelConfig: + queue_interval_seconds = _normalize_queue_interval( + _coerce_float( + _get_value(data, ("models", "rerank", "queue_interval_seconds"), None), + 0.0, + ), + 0.0, + ) + return RerankModelConfig( + api_url=_coerce_str( + _get_value(data, ("models", "rerank", "api_url"), "RERANK_MODEL_API_URL"), + "", + ), + api_key=_coerce_str( + _get_value(data, ("models", "rerank", "api_key"), "RERANK_MODEL_API_KEY"), + "", + ), + model_name=_coerce_str( + _get_value(data, ("models", "rerank", "model_name"), "RERANK_MODEL_NAME"), + "", + ), + queue_interval_seconds=queue_interval_seconds, + query_instruction=_coerce_str( + _get_value(data, ("models", "rerank", "query_instruction"), None), "" + ), + request_params=_get_model_request_params(data, "rerank"), + ) + + +def _parse_chat_model_config(data: dict[str, Any]) -> ChatModelConfig: + queue_interval_seconds = _normalize_queue_interval( + _coerce_float( + _get_value( + data, + ("models", "chat", "queue_interval_seconds"), + "CHAT_MODEL_QUEUE_INTERVAL", + ), + 1.0, + ) + ) + thinking_include_budget, thinking_tool_call_compat = _resolve_thinking_compat_flags( + data=data, + model_name="chat", + include_budget_env_key="CHAT_MODEL_THINKING_INCLUDE_BUDGET", + tool_call_compat_env_key="CHAT_MODEL_THINKING_TOOL_CALL_COMPAT", + legacy_env_key="CHAT_MODEL_DEEPSEEK_NEW_COT_SUPPORT", + ) + api_mode = _resolve_api_mode(data, "chat", "CHAT_MODEL_API_MODE") + responses_tool_choice_compat = _resolve_responses_tool_choice_compat( + data, "chat", "CHAT_MODEL_RESPONSES_TOOL_CHOICE_COMPAT" + ) + responses_force_stateless_replay = _resolve_responses_force_stateless_replay( + data, "chat", "CHAT_MODEL_RESPONSES_FORCE_STATELESS_REPLAY" + ) + prompt_cache_enabled = _coerce_bool( + _get_value( + data, + ("models", "chat", "prompt_cache_enabled"), + "CHAT_MODEL_PROMPT_CACHE_ENABLED", + ), + True, + ) + reasoning_enabled = _coerce_bool( + _get_value( + data, + ("models", "chat", "reasoning_enabled"), + "CHAT_MODEL_REASONING_ENABLED", + ), + False, + ) + reasoning_effort = _resolve_reasoning_effort( + _get_value( + data, + ("models", "chat", "reasoning_effort"), + "CHAT_MODEL_REASONING_EFFORT", + ), + "medium", + ) + config = ChatModelConfig( + api_url=_coerce_str( + _get_value(data, ("models", "chat", "api_url"), "CHAT_MODEL_API_URL"), + "", + ), + api_key=_coerce_str( + _get_value(data, ("models", "chat", "api_key"), "CHAT_MODEL_API_KEY"), + "", + ), + model_name=_coerce_str( + _get_value(data, ("models", "chat", "model_name"), "CHAT_MODEL_NAME"), + "", + ), + max_tokens=_coerce_int( + _get_value(data, ("models", "chat", "max_tokens"), "CHAT_MODEL_MAX_TOKENS"), + 8192, + ), + queue_interval_seconds=queue_interval_seconds, + api_mode=api_mode, + thinking_enabled=_coerce_bool( + _get_value( + data, + ("models", "chat", "thinking_enabled"), + "CHAT_MODEL_THINKING_ENABLED", + ), + False, + ), + thinking_budget_tokens=_coerce_int( + _get_value( + data, + ("models", "chat", "thinking_budget_tokens"), + "CHAT_MODEL_THINKING_BUDGET_TOKENS", + ), + 20000, + ), + thinking_include_budget=thinking_include_budget, + reasoning_effort_style=_resolve_reasoning_effort_style( + _get_value( + data, + ("models", "chat", "reasoning_effort_style"), + "CHAT_MODEL_REASONING_EFFORT_STYLE", + ), + ), + thinking_tool_call_compat=thinking_tool_call_compat, + responses_tool_choice_compat=responses_tool_choice_compat, + responses_force_stateless_replay=responses_force_stateless_replay, + prompt_cache_enabled=prompt_cache_enabled, + reasoning_enabled=reasoning_enabled, + reasoning_effort=reasoning_effort, + request_params=_get_model_request_params(data, "chat"), + ) + config.pool = _parse_model_pool(data, "chat", config) + return config + + +def _parse_vision_model_config(data: dict[str, Any]) -> VisionModelConfig: + queue_interval_seconds = _normalize_queue_interval( + _coerce_float( + _get_value( + data, + ("models", "vision", "queue_interval_seconds"), + "VISION_MODEL_QUEUE_INTERVAL", + ), + 1.0, + ) + ) + thinking_include_budget, thinking_tool_call_compat = _resolve_thinking_compat_flags( + data=data, + model_name="vision", + include_budget_env_key="VISION_MODEL_THINKING_INCLUDE_BUDGET", + tool_call_compat_env_key="VISION_MODEL_THINKING_TOOL_CALL_COMPAT", + legacy_env_key="VISION_MODEL_DEEPSEEK_NEW_COT_SUPPORT", + ) + api_mode = _resolve_api_mode(data, "vision", "VISION_MODEL_API_MODE") + responses_tool_choice_compat = _resolve_responses_tool_choice_compat( + data, "vision", "VISION_MODEL_RESPONSES_TOOL_CHOICE_COMPAT" + ) + responses_force_stateless_replay = _resolve_responses_force_stateless_replay( + data, "vision", "VISION_MODEL_RESPONSES_FORCE_STATELESS_REPLAY" + ) + prompt_cache_enabled = _coerce_bool( + _get_value( + data, + ("models", "vision", "prompt_cache_enabled"), + "VISION_MODEL_PROMPT_CACHE_ENABLED", + ), + True, + ) + reasoning_enabled = _coerce_bool( + _get_value( + data, + ("models", "vision", "reasoning_enabled"), + "VISION_MODEL_REASONING_ENABLED", + ), + False, + ) + reasoning_effort = _resolve_reasoning_effort( + _get_value( + data, + ("models", "vision", "reasoning_effort"), + "VISION_MODEL_REASONING_EFFORT", + ), + "medium", + ) + return VisionModelConfig( + api_url=_coerce_str( + _get_value(data, ("models", "vision", "api_url"), "VISION_MODEL_API_URL"), + "", + ), + api_key=_coerce_str( + _get_value(data, ("models", "vision", "api_key"), "VISION_MODEL_API_KEY"), + "", + ), + model_name=_coerce_str( + _get_value(data, ("models", "vision", "model_name"), "VISION_MODEL_NAME"), + "", + ), + max_tokens=_coerce_int( + _get_value( + data, + ("models", "vision", "max_tokens"), + "VISION_MODEL_MAX_TOKENS", + ), + 8192, + ), + queue_interval_seconds=queue_interval_seconds, + api_mode=api_mode, + thinking_enabled=_coerce_bool( + _get_value( + data, + ("models", "vision", "thinking_enabled"), + "VISION_MODEL_THINKING_ENABLED", + ), + False, + ), + thinking_budget_tokens=_coerce_int( + _get_value( + data, + ("models", "vision", "thinking_budget_tokens"), + "VISION_MODEL_THINKING_BUDGET_TOKENS", + ), + 20000, + ), + thinking_include_budget=thinking_include_budget, + reasoning_effort_style=_resolve_reasoning_effort_style( + _get_value( + data, + ("models", "vision", "reasoning_effort_style"), + "VISION_MODEL_REASONING_EFFORT_STYLE", + ), + ), + thinking_tool_call_compat=thinking_tool_call_compat, + responses_tool_choice_compat=responses_tool_choice_compat, + responses_force_stateless_replay=responses_force_stateless_replay, + prompt_cache_enabled=prompt_cache_enabled, + reasoning_enabled=reasoning_enabled, + reasoning_effort=reasoning_effort, + request_params=_get_model_request_params(data, "vision"), + ) + + +def _parse_security_model_config( + data: dict[str, Any], chat_model: ChatModelConfig +) -> SecurityModelConfig: + api_url = _coerce_str( + _get_value(data, ("models", "security", "api_url"), "SECURITY_MODEL_API_URL"), + "", + ) + api_key = _coerce_str( + _get_value(data, ("models", "security", "api_key"), "SECURITY_MODEL_API_KEY"), + "", + ) + model_name = _coerce_str( + _get_value(data, ("models", "security", "model_name"), "SECURITY_MODEL_NAME"), + "", + ) + queue_interval_seconds = _coerce_float( + _get_value( + data, + ("models", "security", "queue_interval_seconds"), + "SECURITY_MODEL_QUEUE_INTERVAL", + ), + 1.0, + ) + queue_interval_seconds = _normalize_queue_interval(queue_interval_seconds) + + thinking_include_budget, thinking_tool_call_compat = _resolve_thinking_compat_flags( + data=data, + model_name="security", + include_budget_env_key="SECURITY_MODEL_THINKING_INCLUDE_BUDGET", + tool_call_compat_env_key="SECURITY_MODEL_THINKING_TOOL_CALL_COMPAT", + legacy_env_key="SECURITY_MODEL_DEEPSEEK_NEW_COT_SUPPORT", + ) + api_mode = _resolve_api_mode(data, "security", "SECURITY_MODEL_API_MODE") + responses_tool_choice_compat = _resolve_responses_tool_choice_compat( + data, "security", "SECURITY_MODEL_RESPONSES_TOOL_CHOICE_COMPAT" + ) + responses_force_stateless_replay = _resolve_responses_force_stateless_replay( + data, "security", "SECURITY_MODEL_RESPONSES_FORCE_STATELESS_REPLAY" + ) + prompt_cache_enabled = _coerce_bool( + _get_value( + data, + ("models", "security", "prompt_cache_enabled"), + "SECURITY_MODEL_PROMPT_CACHE_ENABLED", + ), + True, + ) + reasoning_enabled = _coerce_bool( + _get_value( + data, + ("models", "security", "reasoning_enabled"), + "SECURITY_MODEL_REASONING_ENABLED", + ), + False, + ) + reasoning_effort = _resolve_reasoning_effort( + _get_value( + data, + ("models", "security", "reasoning_effort"), + "SECURITY_MODEL_REASONING_EFFORT", + ), + "medium", + ) + + if api_url and api_key and model_name: + return SecurityModelConfig( + api_url=api_url, + api_key=api_key, + model_name=model_name, + max_tokens=_coerce_int( + _get_value( + data, + ("models", "security", "max_tokens"), + "SECURITY_MODEL_MAX_TOKENS", + ), + 100, + ), + queue_interval_seconds=queue_interval_seconds, + api_mode=api_mode, + thinking_enabled=_coerce_bool( + _get_value( + data, + ("models", "security", "thinking_enabled"), + "SECURITY_MODEL_THINKING_ENABLED", + ), + False, + ), + thinking_budget_tokens=_coerce_int( + _get_value( + data, + ("models", "security", "thinking_budget_tokens"), + "SECURITY_MODEL_THINKING_BUDGET_TOKENS", + ), + 0, + ), + thinking_include_budget=thinking_include_budget, + reasoning_effort_style=_resolve_reasoning_effort_style( + _get_value( + data, + ("models", "security", "reasoning_effort_style"), + "SECURITY_MODEL_REASONING_EFFORT_STYLE", + ), + ), + thinking_tool_call_compat=thinking_tool_call_compat, + responses_tool_choice_compat=responses_tool_choice_compat, + responses_force_stateless_replay=responses_force_stateless_replay, + prompt_cache_enabled=prompt_cache_enabled, + reasoning_enabled=reasoning_enabled, + reasoning_effort=reasoning_effort, + request_params=_get_model_request_params(data, "security"), + ) + + logger.warning("未配置安全模型,将使用对话模型作为后备") + return SecurityModelConfig( + api_url=chat_model.api_url, + api_key=chat_model.api_key, + model_name=chat_model.model_name, + max_tokens=chat_model.max_tokens, + queue_interval_seconds=chat_model.queue_interval_seconds, + api_mode=chat_model.api_mode, + thinking_enabled=False, + thinking_budget_tokens=0, + thinking_include_budget=True, + reasoning_effort_style="openai", + thinking_tool_call_compat=chat_model.thinking_tool_call_compat, + responses_tool_choice_compat=chat_model.responses_tool_choice_compat, + responses_force_stateless_replay=chat_model.responses_force_stateless_replay, + prompt_cache_enabled=chat_model.prompt_cache_enabled, + reasoning_enabled=chat_model.reasoning_enabled, + reasoning_effort=chat_model.reasoning_effort, + request_params=merge_request_params(chat_model.request_params), + ) + + +def _parse_naga_model_config( + data: dict[str, Any], security_model: SecurityModelConfig +) -> SecurityModelConfig: + api_url = _coerce_str( + _get_value(data, ("models", "naga", "api_url"), "NAGA_MODEL_API_URL"), + "", + ) + api_key = _coerce_str( + _get_value(data, ("models", "naga", "api_key"), "NAGA_MODEL_API_KEY"), + "", + ) + model_name = _coerce_str( + _get_value(data, ("models", "naga", "model_name"), "NAGA_MODEL_NAME"), + "", + ) + queue_interval_seconds = _coerce_float( + _get_value( + data, + ("models", "naga", "queue_interval_seconds"), + "NAGA_MODEL_QUEUE_INTERVAL", + ), + security_model.queue_interval_seconds, + ) + queue_interval_seconds = _normalize_queue_interval(queue_interval_seconds) + + thinking_include_budget, thinking_tool_call_compat = _resolve_thinking_compat_flags( + data=data, + model_name="naga", + include_budget_env_key="NAGA_MODEL_THINKING_INCLUDE_BUDGET", + tool_call_compat_env_key="NAGA_MODEL_THINKING_TOOL_CALL_COMPAT", + legacy_env_key="NAGA_MODEL_DEEPSEEK_NEW_COT_SUPPORT", + ) + api_mode = _resolve_api_mode(data, "naga", "NAGA_MODEL_API_MODE") + responses_tool_choice_compat = _resolve_responses_tool_choice_compat( + data, "naga", "NAGA_MODEL_RESPONSES_TOOL_CHOICE_COMPAT" + ) + responses_force_stateless_replay = _resolve_responses_force_stateless_replay( + data, "naga", "NAGA_MODEL_RESPONSES_FORCE_STATELESS_REPLAY" + ) + prompt_cache_enabled = _coerce_bool( + _get_value( + data, + ("models", "naga", "prompt_cache_enabled"), + "NAGA_MODEL_PROMPT_CACHE_ENABLED", + ), + getattr(security_model, "prompt_cache_enabled", True), + ) + reasoning_enabled = _coerce_bool( + _get_value( + data, + ("models", "naga", "reasoning_enabled"), + "NAGA_MODEL_REASONING_ENABLED", + ), + getattr(security_model, "reasoning_enabled", False), + ) + reasoning_effort = _resolve_reasoning_effort( + _get_value( + data, + ("models", "naga", "reasoning_effort"), + "NAGA_MODEL_REASONING_EFFORT", + ), + getattr(security_model, "reasoning_effort", "medium"), + ) + + if api_url and api_key and model_name: + return SecurityModelConfig( + api_url=api_url, + api_key=api_key, + model_name=model_name, + max_tokens=_coerce_int( + _get_value( + data, + ("models", "naga", "max_tokens"), + "NAGA_MODEL_MAX_TOKENS", + ), + 160, + ), + queue_interval_seconds=queue_interval_seconds, + api_mode=api_mode, + thinking_enabled=_coerce_bool( + _get_value( + data, + ("models", "naga", "thinking_enabled"), + "NAGA_MODEL_THINKING_ENABLED", + ), + False, + ), + thinking_budget_tokens=_coerce_int( + _get_value( + data, + ("models", "naga", "thinking_budget_tokens"), + "NAGA_MODEL_THINKING_BUDGET_TOKENS", + ), + 0, + ), + thinking_include_budget=thinking_include_budget, + reasoning_effort_style=_resolve_reasoning_effort_style( + _get_value( + data, + ("models", "naga", "reasoning_effort_style"), + "NAGA_MODEL_REASONING_EFFORT_STYLE", + ), + ), + thinking_tool_call_compat=thinking_tool_call_compat, + responses_tool_choice_compat=responses_tool_choice_compat, + responses_force_stateless_replay=responses_force_stateless_replay, + prompt_cache_enabled=prompt_cache_enabled, + reasoning_enabled=reasoning_enabled, + reasoning_effort=reasoning_effort, + request_params=_get_model_request_params(data, "naga"), + ) + + logger.info( + "未配置 Naga 审核模型,将使用已解析的安全模型配置作为后备(安全模型本身可能已回退)" + ) + return SecurityModelConfig( + api_url=security_model.api_url, + api_key=security_model.api_key, + model_name=security_model.model_name, + max_tokens=security_model.max_tokens, + queue_interval_seconds=security_model.queue_interval_seconds, + api_mode=security_model.api_mode, + thinking_enabled=security_model.thinking_enabled, + thinking_budget_tokens=security_model.thinking_budget_tokens, + thinking_include_budget=security_model.thinking_include_budget, + reasoning_effort_style=security_model.reasoning_effort_style, + thinking_tool_call_compat=security_model.thinking_tool_call_compat, + responses_tool_choice_compat=security_model.responses_tool_choice_compat, + responses_force_stateless_replay=security_model.responses_force_stateless_replay, + prompt_cache_enabled=security_model.prompt_cache_enabled, + reasoning_enabled=security_model.reasoning_enabled, + reasoning_effort=security_model.reasoning_effort, + request_params=merge_request_params(security_model.request_params), + ) + + +def _parse_agent_model_config(data: dict[str, Any]) -> AgentModelConfig: + queue_interval_seconds = _normalize_queue_interval( + _coerce_float( + _get_value( + data, + ("models", "agent", "queue_interval_seconds"), + "AGENT_MODEL_QUEUE_INTERVAL", + ), + 1.0, + ) + ) + thinking_include_budget, thinking_tool_call_compat = _resolve_thinking_compat_flags( + data=data, + model_name="agent", + include_budget_env_key="AGENT_MODEL_THINKING_INCLUDE_BUDGET", + tool_call_compat_env_key="AGENT_MODEL_THINKING_TOOL_CALL_COMPAT", + legacy_env_key="AGENT_MODEL_DEEPSEEK_NEW_COT_SUPPORT", + ) + api_mode = _resolve_api_mode(data, "agent", "AGENT_MODEL_API_MODE") + responses_tool_choice_compat = _resolve_responses_tool_choice_compat( + data, "agent", "AGENT_MODEL_RESPONSES_TOOL_CHOICE_COMPAT" + ) + responses_force_stateless_replay = _resolve_responses_force_stateless_replay( + data, "agent", "AGENT_MODEL_RESPONSES_FORCE_STATELESS_REPLAY" + ) + prompt_cache_enabled = _coerce_bool( + _get_value( + data, + ("models", "agent", "prompt_cache_enabled"), + "AGENT_MODEL_PROMPT_CACHE_ENABLED", + ), + True, + ) + reasoning_enabled = _coerce_bool( + _get_value( + data, + ("models", "agent", "reasoning_enabled"), + "AGENT_MODEL_REASONING_ENABLED", + ), + False, + ) + reasoning_effort = _resolve_reasoning_effort( + _get_value( + data, + ("models", "agent", "reasoning_effort"), + "AGENT_MODEL_REASONING_EFFORT", + ), + "medium", + ) + config = AgentModelConfig( + api_url=_coerce_str( + _get_value(data, ("models", "agent", "api_url"), "AGENT_MODEL_API_URL"), + "", + ), + api_key=_coerce_str( + _get_value(data, ("models", "agent", "api_key"), "AGENT_MODEL_API_KEY"), + "", + ), + model_name=_coerce_str( + _get_value(data, ("models", "agent", "model_name"), "AGENT_MODEL_NAME"), + "", + ), + max_tokens=_coerce_int( + _get_value( + data, ("models", "agent", "max_tokens"), "AGENT_MODEL_MAX_TOKENS" + ), + 4096, + ), + queue_interval_seconds=queue_interval_seconds, + api_mode=api_mode, + thinking_enabled=_coerce_bool( + _get_value( + data, + ("models", "agent", "thinking_enabled"), + "AGENT_MODEL_THINKING_ENABLED", + ), + False, + ), + thinking_budget_tokens=_coerce_int( + _get_value( + data, + ("models", "agent", "thinking_budget_tokens"), + "AGENT_MODEL_THINKING_BUDGET_TOKENS", + ), + 0, + ), + thinking_include_budget=thinking_include_budget, + reasoning_effort_style=_resolve_reasoning_effort_style( + _get_value( + data, + ("models", "agent", "reasoning_effort_style"), + "AGENT_MODEL_REASONING_EFFORT_STYLE", + ), + ), + thinking_tool_call_compat=thinking_tool_call_compat, + responses_tool_choice_compat=responses_tool_choice_compat, + responses_force_stateless_replay=responses_force_stateless_replay, + prompt_cache_enabled=prompt_cache_enabled, + reasoning_enabled=reasoning_enabled, + reasoning_effort=reasoning_effort, + request_params=_get_model_request_params(data, "agent"), + ) + config.pool = _parse_model_pool(data, "agent", config) + return config + + +def _parse_grok_model_config(data: dict[str, Any]) -> GrokModelConfig: + queue_interval_seconds = _normalize_queue_interval( + _coerce_float( + _get_value( + data, + ("models", "grok", "queue_interval_seconds"), + "GROK_MODEL_QUEUE_INTERVAL", + ), + 1.0, + ) + ) + return GrokModelConfig( + api_url=_coerce_str( + _get_value(data, ("models", "grok", "api_url"), "GROK_MODEL_API_URL"), + "", + ), + api_key=_coerce_str( + _get_value(data, ("models", "grok", "api_key"), "GROK_MODEL_API_KEY"), + "", + ), + model_name=_coerce_str( + _get_value(data, ("models", "grok", "model_name"), "GROK_MODEL_NAME"), + "", + ), + max_tokens=_coerce_int( + _get_value(data, ("models", "grok", "max_tokens"), "GROK_MODEL_MAX_TOKENS"), + 8192, + ), + queue_interval_seconds=queue_interval_seconds, + thinking_enabled=_coerce_bool( + _get_value( + data, + ("models", "grok", "thinking_enabled"), + "GROK_MODEL_THINKING_ENABLED", + ), + False, + ), + thinking_budget_tokens=_coerce_int( + _get_value( + data, + ("models", "grok", "thinking_budget_tokens"), + "GROK_MODEL_THINKING_BUDGET_TOKENS", + ), + 20000, + ), + thinking_include_budget=_coerce_bool( + _get_value( + data, + ("models", "grok", "thinking_include_budget"), + "GROK_MODEL_THINKING_INCLUDE_BUDGET", + ), + True, + ), + reasoning_effort_style=_resolve_reasoning_effort_style( + _get_value( + data, + ("models", "grok", "reasoning_effort_style"), + "GROK_MODEL_REASONING_EFFORT_STYLE", + ), + ), + prompt_cache_enabled=_coerce_bool( + _get_value( + data, + ("models", "grok", "prompt_cache_enabled"), + "GROK_MODEL_PROMPT_CACHE_ENABLED", + ), + True, + ), + reasoning_enabled=_coerce_bool( + _get_value( + data, + ("models", "grok", "reasoning_enabled"), + "GROK_MODEL_REASONING_ENABLED", + ), + False, + ), + reasoning_effort=_resolve_reasoning_effort( + _get_value( + data, + ("models", "grok", "reasoning_effort"), + "GROK_MODEL_REASONING_EFFORT", + ), + "medium", + ), + request_params=_get_model_request_params(data, "grok"), + ) + + +def _parse_image_gen_model_config(data: dict[str, Any]) -> ImageGenModelConfig: + """解析 [models.image_gen] 生图模型配置""" + return ImageGenModelConfig( + api_url=_coerce_str( + _get_value( + data, ("models", "image_gen", "api_url"), "IMAGE_GEN_MODEL_API_URL" + ), + "", + ), + api_key=_coerce_str( + _get_value( + data, ("models", "image_gen", "api_key"), "IMAGE_GEN_MODEL_API_KEY" + ), + "", + ), + model_name=_coerce_str( + _get_value( + data, ("models", "image_gen", "model_name"), "IMAGE_GEN_MODEL_NAME" + ), + "", + ), + request_params=_get_model_request_params(data, "image_gen"), + ) + + +def _parse_image_edit_model_config(data: dict[str, Any]) -> ImageGenModelConfig: + """解析 [models.image_edit] 参考图生图模型配置""" + return ImageGenModelConfig( + api_url=_coerce_str( + _get_value( + data, + ("models", "image_edit", "api_url"), + "IMAGE_EDIT_MODEL_API_URL", + ), + "", + ), + api_key=_coerce_str( + _get_value( + data, + ("models", "image_edit", "api_key"), + "IMAGE_EDIT_MODEL_API_KEY", + ), + "", + ), + model_name=_coerce_str( + _get_value( + data, + ("models", "image_edit", "model_name"), + "IMAGE_EDIT_MODEL_NAME", + ), + "", + ), + request_params=_get_model_request_params(data, "image_edit"), + ) + + +def _parse_image_gen_config(data: dict[str, Any]) -> ImageGenConfig: + """解析 [image_gen] 生图工具配置""" + return ImageGenConfig( + provider=_coerce_str( + _get_value(data, ("image_gen", "provider"), "IMAGE_GEN_PROVIDER"), + "xingzhige", + ), + xingzhige_size=_coerce_str( + _get_value(data, ("image_gen", "xingzhige_size"), None), "1:1" + ), + openai_size=_coerce_str( + _get_value(data, ("image_gen", "openai_size"), None), "" + ), + openai_quality=_coerce_str( + _get_value(data, ("image_gen", "openai_quality"), None), "" + ), + openai_style=_coerce_str( + _get_value(data, ("image_gen", "openai_style"), None), "" + ), + openai_timeout=_coerce_float( + _get_value(data, ("image_gen", "openai_timeout"), None), 120.0 + ), + ) + + +def _merge_admins(superadmin_qq: int, admin_qqs: list[int]) -> tuple[int, list[int]]: + local_admins = load_local_admins() + all_admins = list(set(admin_qqs + local_admins)) + if superadmin_qq and superadmin_qq not in all_admins: + all_admins.append(superadmin_qq) + return superadmin_qq, all_admins + + +def _verify_required_fields( + bot_qq: int, + superadmin_qq: int, + onebot_ws_url: str, + chat_model: ChatModelConfig, + vision_model: VisionModelConfig, + agent_model: AgentModelConfig, + knowledge_enabled: bool, + embedding_model: EmbeddingModelConfig, +) -> None: + missing: list[str] = [] + if bot_qq <= 0: + missing.append("core.bot_qq") + if superadmin_qq <= 0: + missing.append("core.superadmin_qq") + if not onebot_ws_url: + missing.append("onebot.ws_url") + if not chat_model.api_url: + missing.append("models.chat.api_url") + if not chat_model.api_key: + missing.append("models.chat.api_key") + if not chat_model.model_name: + missing.append("models.chat.model_name") + if not vision_model.api_url: + missing.append("models.vision.api_url") + if not vision_model.api_key: + missing.append("models.vision.api_key") + if not vision_model.model_name: + missing.append("models.vision.model_name") + if not agent_model.api_url: + missing.append("models.agent.api_url") + if not agent_model.api_key: + missing.append("models.agent.api_key") + if not agent_model.model_name: + missing.append("models.agent.model_name") + if knowledge_enabled: + if not embedding_model.api_url: + missing.append("models.embedding.api_url") + if not embedding_model.model_name: + missing.append("models.embedding.model_name") + if missing: + raise ValueError(f"缺少必需配置: {', '.join(missing)}") + + +def _log_debug_info( + chat_model: ChatModelConfig, + vision_model: VisionModelConfig, + security_model: SecurityModelConfig, + naga_model: SecurityModelConfig, + agent_model: AgentModelConfig, + summary_model: AgentModelConfig, + grok_model: GrokModelConfig, +) -> None: + configs: list[ + tuple[ + str, + ChatModelConfig + | VisionModelConfig + | SecurityModelConfig + | AgentModelConfig + | GrokModelConfig, + ] + ] = [ + ("chat", chat_model), + ("vision", vision_model), + ("security", security_model), + ("naga", naga_model), + ("agent", agent_model), + ("summary", summary_model), + ("grok", grok_model), + ] + for name, cfg in configs: + logger.debug( + "[配置] %s_model=%s api_url=%s api_key_set=%s api_mode=%s thinking=%s reasoning=%s/%s cot_compat=%s responses_tool_choice_compat=%s responses_force_stateless_replay=%s", + name, + cfg.model_name, + cfg.api_url, + bool(cfg.api_key), + getattr(cfg, "api_mode", "chat_completions"), + cfg.thinking_enabled, + getattr(cfg, "reasoning_enabled", False), + getattr(cfg, "reasoning_effort", "medium"), + getattr(cfg, "thinking_tool_call_compat", False), + getattr(cfg, "responses_tool_choice_compat", False), + getattr(cfg, "responses_force_stateless_replay", False), + ) + + +def _parse_historian_model_config( + data: dict[str, Any], fallback: AgentModelConfig +) -> AgentModelConfig: + h = data.get("models", {}).get("historian", {}) + if not isinstance(h, dict) or not h: + return fallback + queue_interval_seconds = _coerce_float( + h.get("queue_interval_seconds"), fallback.queue_interval_seconds + ) + queue_interval_seconds = _normalize_queue_interval( + queue_interval_seconds, fallback.queue_interval_seconds + ) + thinking_include_budget, thinking_tool_call_compat = _resolve_thinking_compat_flags( + data={"models": {"historian": h}}, + model_name="historian", + include_budget_env_key="HISTORIAN_MODEL_THINKING_INCLUDE_BUDGET", + tool_call_compat_env_key="HISTORIAN_MODEL_THINKING_TOOL_CALL_COMPAT", + legacy_env_key="HISTORIAN_MODEL_DEEPSEEK_NEW_COT_SUPPORT", + ) + api_mode = _resolve_api_mode( + {"models": {"historian": h}}, + "historian", + "HISTORIAN_MODEL_API_MODE", + fallback.api_mode, + ) + responses_tool_choice_compat = _resolve_responses_tool_choice_compat( + {"models": {"historian": h}}, + "historian", + "HISTORIAN_MODEL_RESPONSES_TOOL_CHOICE_COMPAT", + fallback.responses_tool_choice_compat, + ) + responses_force_stateless_replay = _resolve_responses_force_stateless_replay( + {"models": {"historian": h}}, + "historian", + "HISTORIAN_MODEL_RESPONSES_FORCE_STATELESS_REPLAY", + fallback.responses_force_stateless_replay, + ) + prompt_cache_enabled = _coerce_bool( + _get_value( + {"models": {"historian": h}}, + ("models", "historian", "prompt_cache_enabled"), + "HISTORIAN_MODEL_PROMPT_CACHE_ENABLED", + ), + fallback.prompt_cache_enabled, + ) + return AgentModelConfig( + api_url=_coerce_str(h.get("api_url"), fallback.api_url), + api_key=_coerce_str(h.get("api_key"), fallback.api_key), + model_name=_coerce_str(h.get("model_name"), fallback.model_name), + max_tokens=_coerce_int(h.get("max_tokens"), fallback.max_tokens), + queue_interval_seconds=queue_interval_seconds, + api_mode=api_mode, + thinking_enabled=_coerce_bool( + h.get("thinking_enabled"), fallback.thinking_enabled + ), + thinking_budget_tokens=_coerce_int( + h.get("thinking_budget_tokens"), fallback.thinking_budget_tokens + ), + thinking_include_budget=thinking_include_budget, + reasoning_effort_style=_resolve_reasoning_effort_style( + _get_value( + {"models": {"historian": h}}, + ("models", "historian", "reasoning_effort_style"), + "HISTORIAN_MODEL_REASONING_EFFORT_STYLE", + ), + fallback.reasoning_effort_style, + ), + thinking_tool_call_compat=thinking_tool_call_compat, + responses_tool_choice_compat=responses_tool_choice_compat, + responses_force_stateless_replay=responses_force_stateless_replay, + prompt_cache_enabled=prompt_cache_enabled, + reasoning_enabled=_coerce_bool( + _get_value( + {"models": {"historian": h}}, + ("models", "historian", "reasoning_enabled"), + "HISTORIAN_MODEL_REASONING_ENABLED", + ), + fallback.reasoning_enabled, + ), + reasoning_effort=_resolve_reasoning_effort( + _get_value( + {"models": {"historian": h}}, + ("models", "historian", "reasoning_effort"), + "HISTORIAN_MODEL_REASONING_EFFORT", + ), + fallback.reasoning_effort, + ), + request_params=merge_request_params( + fallback.request_params, + h.get("request_params"), + ), + ) + + +def _parse_summary_model_config( + data: dict[str, Any], fallback: AgentModelConfig +) -> tuple[AgentModelConfig, bool]: + s = data.get("models", {}).get("summary", {}) + if not isinstance(s, dict) or not s: + return fallback, False + queue_interval_seconds = _coerce_float( + s.get("queue_interval_seconds"), fallback.queue_interval_seconds + ) + queue_interval_seconds = _normalize_queue_interval( + queue_interval_seconds, fallback.queue_interval_seconds + ) + thinking_include_budget, thinking_tool_call_compat = _resolve_thinking_compat_flags( + data={"models": {"summary": s}}, + model_name="summary", + include_budget_env_key="SUMMARY_MODEL_THINKING_INCLUDE_BUDGET", + tool_call_compat_env_key="SUMMARY_MODEL_THINKING_TOOL_CALL_COMPAT", + legacy_env_key="SUMMARY_MODEL_DEEPSEEK_NEW_COT_SUPPORT", + ) + api_mode = _resolve_api_mode( + {"models": {"summary": s}}, + "summary", + "SUMMARY_MODEL_API_MODE", + fallback.api_mode, + ) + responses_tool_choice_compat = _resolve_responses_tool_choice_compat( + {"models": {"summary": s}}, + "summary", + "SUMMARY_MODEL_RESPONSES_TOOL_CHOICE_COMPAT", + fallback.responses_tool_choice_compat, + ) + responses_force_stateless_replay = _resolve_responses_force_stateless_replay( + {"models": {"summary": s}}, + "summary", + "SUMMARY_MODEL_RESPONSES_FORCE_STATELESS_REPLAY", + fallback.responses_force_stateless_replay, + ) + prompt_cache_enabled = _coerce_bool( + _get_value( + {"models": {"summary": s}}, + ("models", "summary", "prompt_cache_enabled"), + "SUMMARY_MODEL_PROMPT_CACHE_ENABLED", + ), + fallback.prompt_cache_enabled, + ) + return ( + AgentModelConfig( + api_url=_coerce_str(s.get("api_url"), fallback.api_url), + api_key=_coerce_str(s.get("api_key"), fallback.api_key), + model_name=_coerce_str(s.get("model_name"), fallback.model_name), + max_tokens=_coerce_int(s.get("max_tokens"), fallback.max_tokens), + queue_interval_seconds=queue_interval_seconds, + api_mode=api_mode, + thinking_enabled=_coerce_bool( + s.get("thinking_enabled"), fallback.thinking_enabled + ), + thinking_budget_tokens=_coerce_int( + s.get("thinking_budget_tokens"), fallback.thinking_budget_tokens + ), + thinking_include_budget=thinking_include_budget, + reasoning_effort_style=_resolve_reasoning_effort_style( + _get_value( + {"models": {"summary": s}}, + ("models", "summary", "reasoning_effort_style"), + "SUMMARY_MODEL_REASONING_EFFORT_STYLE", + ), + fallback.reasoning_effort_style, + ), + thinking_tool_call_compat=thinking_tool_call_compat, + responses_tool_choice_compat=responses_tool_choice_compat, + responses_force_stateless_replay=responses_force_stateless_replay, + prompt_cache_enabled=prompt_cache_enabled, + reasoning_enabled=_coerce_bool( + _get_value( + {"models": {"summary": s}}, + ("models", "summary", "reasoning_enabled"), + "SUMMARY_MODEL_REASONING_ENABLED", + ), + fallback.reasoning_enabled, + ), + reasoning_effort=_resolve_reasoning_effort( + _get_value( + {"models": {"summary": s}}, + ("models", "summary", "reasoning_effort"), + "SUMMARY_MODEL_REASONING_EFFORT", + ), + fallback.reasoning_effort, + ), + request_params=merge_request_params( + fallback.request_params, + s.get("request_params"), + ), + ), + True, + ) diff --git a/src/Undefined/config/models.py b/src/Undefined/config/models.py index 64aae69b..118fb9ad 100644 --- a/src/Undefined/config/models.py +++ b/src/Undefined/config/models.py @@ -92,6 +92,7 @@ class VisionModelConfig: api_url: str api_key: str model_name: str + max_tokens: int = 8192 # 最大输出 tokens queue_interval_seconds: float = 1.0 api_mode: str = "chat_completions" # 请求 API 模式 thinking_enabled: bool = False # 是否启用 thinking @@ -338,6 +339,8 @@ class MemeConfig: semantic_top_k: int = 30 rerank_top_k: int = 20 worker_max_concurrency: int = 4 + gif_analysis_mode: str = "grid" + gif_analysis_frames: int = 6 @dataclass diff --git a/src/Undefined/config/resolvers.py b/src/Undefined/config/resolvers.py new file mode 100644 index 00000000..43533c37 --- /dev/null +++ b/src/Undefined/config/resolvers.py @@ -0,0 +1,106 @@ +"""Configuration value resolution helpers.""" + +from __future__ import annotations + +from typing import Any + +from .coercers import ( + _coerce_bool, + _coerce_str, + _get_value, + _VALID_API_MODES, + _VALID_REASONING_EFFORT_STYLES, +) + + +def _resolve_reasoning_effort_style(value: Any, default: str = "openai") -> str: + style = _coerce_str(value, default).strip().lower() + if style not in _VALID_REASONING_EFFORT_STYLES: + return default + return style + + +def _resolve_thinking_compat_flags( + data: dict[str, Any], + model_name: str, + include_budget_env_key: str, + tool_call_compat_env_key: str, + legacy_env_key: str, +) -> tuple[bool, bool]: + """解析思维链兼容配置,并兼容旧字段 deepseek_new_cot_support。""" + include_budget_value = _get_value( + data, + ("models", model_name, "thinking_include_budget"), + include_budget_env_key, + ) + tool_call_compat_value = _get_value( + data, + ("models", model_name, "thinking_tool_call_compat"), + tool_call_compat_env_key, + ) + legacy_value = _get_value( + data, + ("models", model_name, "deepseek_new_cot_support"), + legacy_env_key, + ) + + include_budget_default = True + tool_call_compat_default = True + if legacy_value is not None: + legacy_enabled = _coerce_bool(legacy_value, False) + include_budget_default = not legacy_enabled + tool_call_compat_default = legacy_enabled + + return ( + _coerce_bool(include_budget_value, include_budget_default), + _coerce_bool(tool_call_compat_value, tool_call_compat_default), + ) + + +def _resolve_api_mode( + data: dict[str, Any], + model_name: str, + env_key: str, + default: str = "chat_completions", +) -> str: + raw_value = _get_value(data, ("models", model_name, "api_mode"), env_key) + value = _coerce_str(raw_value, default).strip().lower() + if value not in _VALID_API_MODES: + return default + return value + + +def _resolve_reasoning_effort(value: Any, default: str = "medium") -> str: + return _coerce_str(value, default).strip().lower() + + +def _resolve_responses_tool_choice_compat( + data: dict[str, Any], + model_name: str, + env_key: str, + default: bool = False, +) -> bool: + return _coerce_bool( + _get_value( + data, + ("models", model_name, "responses_tool_choice_compat"), + env_key, + ), + default, + ) + + +def _resolve_responses_force_stateless_replay( + data: dict[str, Any], + model_name: str, + env_key: str, + default: bool = False, +) -> bool: + return _coerce_bool( + _get_value( + data, + ("models", model_name, "responses_force_stateless_replay"), + env_key, + ), + default, + ) diff --git a/src/Undefined/config/webui_settings.py b/src/Undefined/config/webui_settings.py new file mode 100644 index 00000000..45a5392b --- /dev/null +++ b/src/Undefined/config/webui_settings.py @@ -0,0 +1,61 @@ +"""WebUI settings management.""" + +from __future__ import annotations + +from dataclasses import dataclass +from pathlib import Path +from typing import Optional + +from .coercers import _coerce_int, _coerce_str, _normalize_str, _get_value + +DEFAULT_WEBUI_URL = "127.0.0.1" +DEFAULT_WEBUI_PORT = 8787 +DEFAULT_WEBUI_PASSWORD = "changeme" + + +@dataclass +class WebUISettings: + url: str + port: int + password: str + using_default_password: bool + config_exists: bool + + @property + def display_url(self) -> str: + """用于日志和展示的格式化 URL。""" + from Undefined.config.models import format_netloc + + return f"http://{format_netloc(self.url or '0.0.0.0', self.port)}" + + +def load_webui_settings(config_path: Optional[Path] = None) -> WebUISettings: + from .loader import load_toml_data # lazy to avoid circular + + data = load_toml_data(config_path) + config_exists = bool(data) + url_value = _get_value(data, ("webui", "url"), None) + port_value = _get_value(data, ("webui", "port"), None) + password_value = _get_value(data, ("webui", "password"), None) + + url = _coerce_str(url_value, DEFAULT_WEBUI_URL) + port = _coerce_int(port_value, DEFAULT_WEBUI_PORT) + if port <= 0 or port > 65535: + port = DEFAULT_WEBUI_PORT + + password_normalized = _normalize_str(password_value) + if not password_normalized: + return WebUISettings( + url=url, + port=port, + password=DEFAULT_WEBUI_PASSWORD, + using_default_password=True, + config_exists=config_exists, + ) + return WebUISettings( + url=url, + port=port, + password=password_normalized, + using_default_password=False, + config_exists=config_exists, + ) diff --git a/src/Undefined/handlers.py b/src/Undefined/handlers.py index 849b8cfd..9d5bfcba 100644 --- a/src/Undefined/handlers.py +++ b/src/Undefined/handlers.py @@ -1,11 +1,14 @@ """消息处理和命令分发""" +from __future__ import annotations + import asyncio from dataclasses import dataclass import logging import os from pathlib import Path import random +import time from typing import Any, Coroutine from Undefined.attachments import ( @@ -28,6 +31,7 @@ parse_message_content_for_history, matches_xinliweiyuan, ) +from Undefined.utils.fake_at import BotNicknameCache, strip_fake_at from Undefined.utils.history import MessageHistoryManager from Undefined.utils.scheduler import TaskScheduler from Undefined.utils.sender import MessageSender @@ -39,18 +43,12 @@ from Undefined.scheduled_task_storage import ScheduledTaskStorage from Undefined.utils.logging import log_debug_json, redact_string +from Undefined.utils.coerce import safe_int logger = logging.getLogger(__name__) KEYWORD_REPLY_HISTORY_PREFIX = "[系统关键词自动回复] " - - -def _safe_int(value: Any) -> int | None: - try: - parsed = int(value) - except (TypeError, ValueError): - return None - return parsed if parsed > 0 else None +REPEAT_REPLY_HISTORY_PREFIX = "[系统复读] " def _format_poke_history_text(display_name: str, user_id: int) -> str: @@ -71,6 +69,7 @@ class GroupPokeRecord: group_name: str sender_role: str sender_title: str + sender_level: str class MessageHandler: @@ -112,6 +111,7 @@ def __init__( self.security, queue_manager=self.queue_manager, rate_limiter=self.rate_limiter, + history_manager=self.history_manager, ) self.ai_coordinator = AICoordinator( config, @@ -127,10 +127,129 @@ def __init__( self._background_tasks: set[asyncio.Task[None]] = set() self._profile_name_refresh_cache: dict[tuple[str, int], str] = {} + self._bot_nickname_cache = BotNicknameCache(onebot, config.bot_qq) + + # 复读功能状态(按群跟踪最近消息文本与发送者) + self._repeat_counter: dict[int, list[tuple[str, int]]] = {} + self._repeat_locks: dict[int, asyncio.Lock] = {} + # 复读冷却:group_id → {normalized_text → monotonic_timestamp} + self._repeat_cooldown: dict[int, dict[str, float]] = {} # 启动队列 self.ai_coordinator.queue_manager.start(self.ai_coordinator.execute_reply) + def _get_repeat_lock(self, group_id: int) -> asyncio.Lock: + """获取或创建指定群的复读竞态保护锁。""" + lock = self._repeat_locks.get(group_id) + if lock is None: + lock = asyncio.Lock() + self._repeat_locks[group_id] = lock + return lock + + @staticmethod + def _normalize_repeat_text(text: str) -> str: + """规范化复读文本用于冷却比较(?→?)。""" + return text.replace("?", "?") + + def _is_repeat_on_cooldown(self, group_id: int, text: str) -> bool: + """检查指定群的文本是否在复读冷却期内。""" + cooldown_minutes = self.config.repeat_cooldown_minutes + if cooldown_minutes <= 0: + return False + group_cd = self._repeat_cooldown.get(group_id) + if not group_cd: + return False + key = self._normalize_repeat_text(text) + last_time = group_cd.get(key) + if last_time is None: + return False + return (time.monotonic() - last_time) < cooldown_minutes * 60 + + def _record_repeat_cooldown(self, group_id: int, text: str) -> None: + """记录复读冷却时间戳,同时清理已过期条目防止内存泄漏。""" + cooldown_seconds = self.config.repeat_cooldown_minutes * 60 + if cooldown_seconds <= 0: + return + key = self._normalize_repeat_text(text) + group_cd = self._repeat_cooldown.setdefault(group_id, {}) + now = time.monotonic() + # 清理已过期条目 + expired = [k for k, ts in group_cd.items() if (now - ts) >= cooldown_seconds] + for k in expired: + del group_cd[k] + group_cd[key] = now + + async def _annotate_meme_descriptions( + self, + attachments: list[dict[str, str]], + scope_key: str, + ) -> list[dict[str, str]]: + """为图片附件添加表情包描述(如果在表情库中找到)。 + + 采用批量查询:收集所有 SHA256 哈希值,一次性查询,然后映射结果。 + 最佳努力:任何失败时返回原始列表。 + """ + if not attachments: + return attachments + + ai_client = getattr(self, "ai", None) + if ai_client is None: + return attachments + + attachment_registry = getattr(ai_client, "attachment_registry", None) + if attachment_registry is None: + return attachments + + meme_service = getattr(ai_client, "_meme_service", None) + if meme_service is None or not getattr(meme_service, "enabled", False): + return attachments + + meme_store = getattr(meme_service, "_store", None) + if meme_store is None: + return attachments + + try: + # 1. 从图片附件收集唯一的 SHA256 哈希值 + uid_to_hash: dict[str, str] = {} + for att in attachments: + uid = att.get("uid", "") + if not uid.startswith("pic_"): + continue + record = attachment_registry.resolve(uid, scope_key) + if record and record.sha256: + uid_to_hash[uid] = record.sha256 + + if not uid_to_hash: + return attachments + + # 2. 批量查询:去重哈希值 + unique_hashes = set(uid_to_hash.values()) + hash_to_desc: dict[str, str] = {} + for h in unique_hashes: + meme = await meme_store.find_by_sha256(h) + if meme and meme.description: + hash_to_desc[h] = meme.description + + if not hash_to_desc: + return attachments + + # 3. 构建带注释的新列表 + result: list[dict[str, str]] = [] + for att in attachments: + uid = att.get("uid", "") + sha = uid_to_hash.get(uid, "") + desc = hash_to_desc.get(sha, "") + if desc: + new_att = dict(att) + new_att["description"] = f"[表情包] {desc}" + result.append(new_att) + else: + result.append(att) + return result + except Exception: + logger.warning("表情包自动匹配失败,跳过", exc_info=True) + return attachments + async def _collect_message_attachments( self, message_content: list[dict[str, Any]], @@ -163,7 +282,10 @@ async def _collect_message_attachments( if onebot else None, ) - return result.attachments + attachments = result.attachments + # 为图片附件添加表情包描述 + attachments = await self._annotate_meme_descriptions(attachments, scope_key) + return attachments def _schedule_meme_ingest( self, @@ -371,6 +493,7 @@ async def handle_message(self, event: dict[str, Any]) -> None: group_name=group_poke.group_name, sender_role=group_poke.sender_role, sender_title=group_poke.sender_title, + sender_level=group_poke.sender_level, ) return @@ -409,10 +532,19 @@ async def handle_message(self, event: dict[str, Any]) -> None: logger.warning("获取用户昵称失败: %s", exc) text = extract_text(private_message_content, self.config.bot_qq) - private_attachments = await self._collect_message_attachments( - private_message_content, - user_id=private_sender_id, - request_type="private", + # 并行执行附件收集和历史内容解析 + private_attachments, parsed_content_raw = await asyncio.gather( + self._collect_message_attachments( + private_message_content, + user_id=private_sender_id, + request_type="private", + ), + parse_message_content_for_history( + private_message_content, + self.config.bot_qq, + self.onebot.get_msg, + self.onebot.get_forward_msg, + ), ) safe_text = redact_string(text) logger.info( @@ -428,15 +560,10 @@ async def handle_message(self, event: dict[str, Any]) -> None: sender_name=resolved_private_name, ) - # 保存私聊消息到历史记录(保存处理后的内容) - # 使用新的工具函数解析内容 - parsed_content = await parse_message_content_for_history( - private_message_content, - self.config.bot_qq, - self.onebot.get_msg, - self.onebot.get_forward_msg, + # 保存私聊消息到历史记录 + parsed_content = append_attachment_text( + parsed_content_raw, private_attachments ) - parsed_content = append_attachment_text(parsed_content, private_attachments) safe_parsed = redact_string(parsed_content) logger.debug( "[历史记录] 保存私聊: user=%s content=%s...", @@ -461,7 +588,7 @@ async def handle_message(self, event: dict[str, Any]) -> None: chat_type="private", chat_id=private_sender_id, sender_id=private_sender_id, - message_id=_safe_int(trigger_message_id), + message_id=safe_int(trigger_message_id), scope_key=build_attachment_scope( user_id=private_sender_id, request_type="private", @@ -557,29 +684,41 @@ async def handle_message(self, event: dict[str, Any]) -> None: sender_nickname: str = group_sender.get("nickname", "") sender_role: str = group_sender.get("role", "member") sender_title: str = group_sender.get("title", "") + sender_level: str = str(group_sender.get("level", "")).strip() # 提取文本内容 text = extract_text(message_content, self.config.bot_qq) - group_attachments = await self._collect_message_attachments( - message_content, - group_id=group_id, - request_type="group", - ) safe_text = redact_string(text) logger.info( f"[群消息] group={group_id} sender={sender_id} name={sender_card or sender_nickname} " f"role={sender_role} | {safe_text[:100]}" ) - # 保存消息到历史记录 (使用处理后的内容) - # 获取群聊名 - group_name = "" - try: - group_info = await self.onebot.get_group_info(group_id) - if group_info: - group_name = group_info.get("group_name", "") - except Exception as e: - logger.warning(f"获取群聊名失败: {e}") + # 并行执行 3 个独立的异步操作:附件收集、群信息获取、历史内容解析 + async def _fetch_group_name() -> str: + try: + info = await self.onebot.get_group_info(group_id) + if info: + return str(info.get("group_name", "") or "") + except Exception as e: + logger.warning(f"获取群聊名失败: {e}") + return "" + + group_attachments, group_name, parsed_content_raw = await asyncio.gather( + self._collect_message_attachments( + message_content, + group_id=group_id, + request_type="group", + ), + _fetch_group_name(), + parse_message_content_for_history( + message_content, + self.config.bot_qq, + self.onebot.get_msg, + self.onebot.get_forward_msg, + ), + ) + resolved_group_sender_name = (sender_card or sender_nickname or "").strip() self._schedule_profile_display_name_refresh( task_name=f"profile_name_refresh_group:{group_id}:{sender_id}", @@ -589,14 +728,8 @@ async def handle_message(self, event: dict[str, Any]) -> None: group_name=str(group_name or "").strip(), ) - # 使用新的 utils - parsed_content = await parse_message_content_for_history( - message_content, - self.config.bot_qq, - self.onebot.get_msg, - self.onebot.get_forward_msg, - ) - parsed_content = append_attachment_text(parsed_content, group_attachments) + # 保存消息到历史记录 + parsed_content = append_attachment_text(parsed_content_raw, group_attachments) safe_parsed = redact_string(parsed_content) logger.debug( f"[历史记录] 保存群聊: group={group_id}, sender={sender_id}, content={safe_parsed[:50]}..." @@ -610,12 +743,23 @@ async def handle_message(self, event: dict[str, Any]) -> None: group_name=group_name, role=sender_role, title=sender_title, + level=sender_level, message_id=trigger_message_id, attachments=group_attachments, ) # 如果是 bot 自己的消息,只保存不触发回复,避免无限循环 + # 同时把 bot 自身的发言写入复读计数器,使窗口中留有 bot 标记, + # 后续触发检查时会排除含 bot 的窗口,防止"bot 先发 → 用户跟发"或 + # "用户发到一半 bot 插入"等情况误触复读。 if sender_id == self.config.bot_qq: + if self.config.repeat_enabled and text: + async with self._get_repeat_lock(group_id): + counter = self._repeat_counter.setdefault(group_id, []) + counter.append((text, sender_id)) + n = self.config.repeat_threshold + if len(counter) > n: + self._repeat_counter[group_id] = counter[-n:] return self._schedule_meme_ingest( @@ -623,13 +767,29 @@ async def handle_message(self, event: dict[str, Any]) -> None: chat_type="group", chat_id=group_id, sender_id=sender_id, - message_id=_safe_int(trigger_message_id), + message_id=safe_int(trigger_message_id), scope_key=build_attachment_scope(group_id=group_id, request_type="group"), ) # 检查是否 @ 了机器人(后续分流共用) is_at_bot = self.ai_coordinator._is_at_bot(message_content) + # 假@检测:识别 "@昵称" 纯文本形式 + # normalized_text 用于命令解析和 AI 路由,原始 text 已用于历史/日志 + is_fake_at = False + normalized_text = text + if not is_at_bot: + nicknames = await self._bot_nickname_cache.get_nicknames(group_id) + if nicknames: + is_fake_at, normalized_text = strip_fake_at(text, nicknames) + if is_fake_at: + is_at_bot = True + logger.info( + "[假@] 识别到假@: group=%s sender=%s", + group_id, + sender_id, + ) + # 关闭“每条消息处理”后,仅处理 @ 消息(私聊/拍一拍在其他分支中处理) if not self.config.should_process_group_message(is_at_bot=is_at_bot): logger.debug( @@ -684,6 +844,55 @@ async def handle_message(self, event: dict[str, Any]) -> None: ) return + # 复读功能:连续 N 条相同消息(来自不同发送者)时复读,N = repeat_threshold + if self.config.repeat_enabled and text: + n = self.config.repeat_threshold + async with self._get_repeat_lock(group_id): + counter = self._repeat_counter.setdefault(group_id, []) + counter.append((text, sender_id)) + # 只保留最近 n 条 + if len(counter) > n: + self._repeat_counter[group_id] = counter[-n:] + counter = self._repeat_counter[group_id] + + if len(counter) >= n: + last_n = counter[-n:] + texts = [t for t, _ in last_n] + senders = [s for _, s in last_n] + if ( + len(set(texts)) == 1 + and len(set(senders)) == n + and self.config.bot_qq not in senders + ): + reply_text = texts[0] + # 冷却检查:同一内容在冷却期内不再复读 + if self._is_repeat_on_cooldown(group_id, reply_text): + self._repeat_counter[group_id] = [] + logger.debug( + "[复读] 冷却中跳过: group=%s text=%s", + group_id, + redact_string(reply_text)[:50], + ) + else: + if self.config.inverted_question_enabled: + stripped = reply_text.strip() + if set(stripped) <= {"?", "?"}: + reply_text = "¿" * len(stripped) + # 清空计数器防止重复触发 + self._repeat_counter[group_id] = [] + self._record_repeat_cooldown(group_id, texts[0]) + logger.info( + "[复读] 触发复读: group=%s text=%s", + group_id, + redact_string(reply_text)[:50], + ) + await self.sender.send_group_message( + group_id, + reply_text, + history_prefix=REPEAT_REPLY_HISTORY_PREFIX, + ) + return + # Bilibili 视频自动提取 if self.config.bilibili_auto_extract_enabled: if self.config.is_bilibili_auto_extract_allowed_group(group_id): @@ -709,26 +918,28 @@ async def handle_message(self, event: dict[str, Any]) -> None: # 提取文本内容 # (已在上方提取用于日志记录) - # 只有被@时才处理斜杠命令 + # 只有被@时才处理斜杠命令(使用 normalized_text 以支持假@后的命令) if is_at_bot: - command = self.command_dispatcher.parse_command(text) + command = self.command_dispatcher.parse_command(normalized_text) if command: await self.command_dispatcher.dispatch(group_id, sender_id, command) return - # 自动回复处理 + # 自动回复处理(使用 normalized_text 以去除假@前缀) display_name = sender_card or sender_nickname or str(sender_id) await self.ai_coordinator.handle_auto_reply( group_id, sender_id, - text, + normalized_text, message_content, attachments=group_attachments, sender_name=display_name, group_name=group_name, sender_role=sender_role, sender_title=sender_title, + sender_level=sender_level, trigger_message_id=trigger_message_id, + is_fake_at=is_fake_at, ) async def _record_private_poke_history( @@ -790,11 +1001,13 @@ async def _record_group_poke_history( sender_nickname = "" sender_role = "member" sender_title = "" + sender_level = "" if isinstance(sender, dict): sender_card = str(sender.get("card", "")).strip() sender_nickname = str(sender.get("nickname", "")).strip() sender_role = str(sender.get("role", "member")).strip() or "member" sender_title = str(sender.get("title", "")).strip() + sender_level = str(sender.get("level", "")).strip() if not sender_card and not sender_nickname: try: @@ -808,6 +1021,7 @@ async def _record_group_poke_history( str(member_info.get("role", "member")).strip() or "member" ) sender_title = str(member_info.get("title", "")).strip() + sender_level = str(member_info.get("level", "")).strip() except Exception as exc: logger.warning( "[通知] 获取拍一拍群成员信息失败: group=%s user=%s err=%s", @@ -851,6 +1065,7 @@ async def _record_group_poke_history( group_name=normalized_group_name, role=sender_role, title=sender_title, + level=sender_level, ) except Exception as exc: logger.warning( @@ -865,6 +1080,7 @@ async def _record_group_poke_history( group_name=normalized_group_name, sender_role=sender_role, sender_title=sender_title, + sender_level=sender_level, ) async def _extract_bilibili_ids( diff --git a/src/Undefined/memes/service.py b/src/Undefined/memes/service.py index e853f1e5..96f80abe 100644 --- a/src/Undefined/memes/service.py +++ b/src/Undefined/memes/service.py @@ -6,6 +6,7 @@ from datetime import datetime import hashlib import logging +import math import mimetypes from pathlib import Path import re @@ -14,6 +15,7 @@ from typing import Any from uuid import uuid4 +from openai import APIConnectionError, APIStatusError, APITimeoutError from PIL import Image from Undefined.attachments import AttachmentRecord @@ -27,6 +29,7 @@ from Undefined.memes.store import MemeStore from Undefined.memes.vector_store import MemeVectorStore from Undefined.utils.message_targets import resolve_message_target +from Undefined.utils.coerce import safe_int from Undefined.utils.paths import ensure_dir logger = logging.getLogger(__name__) @@ -46,16 +49,6 @@ def _now_iso() -> str: return datetime.now().isoformat(timespec="seconds") -def _safe_int(value: Any) -> int | None: - if value is None: - return None - try: - parsed = int(value) - except (TypeError, ValueError): - return None - return parsed if parsed > 0 else None - - def _guess_suffix(path: Path, mime_type: str) -> str: suffix = path.suffix.lower() if suffix: @@ -78,6 +71,77 @@ def _normalize_tags(raw_tags: list[str] | str | None) -> list[str]: return normalize_string_list(raw_tags) +def _is_retryable_llm_error(exc: Exception) -> bool: + """判断 LLM 调用异常是否应触发 worker 级重试。""" + if isinstance(exc, (APIConnectionError, APITimeoutError)): + return True + if isinstance(exc, APIStatusError): + return exc.status_code == 429 or exc.status_code >= 500 + return False + + +def _extract_gif_frames(source_path: Path, n_frames: int) -> list[Image.Image]: + """从 GIF 中均匀采样 *n_frames* 帧(含首末帧),返回 RGBA Image 列表。""" + with Image.open(source_path) as image: + total = getattr(image, "n_frames", 1) + if total <= 1: + image.seek(0) + return [image.convert("RGBA").copy()] + n = min(n_frames, total) + if n <= 1: + image.seek(0) + return [image.convert("RGBA").copy()] + indices = _sample_frame_indices(total, n) + frames: list[Image.Image] = [] + for idx in indices: + image.seek(idx) + frames.append(image.convert("RGBA").copy()) + return frames + + +def _sample_frame_indices(total: int, n: int) -> list[int]: + """生成均匀采样的帧索引列表(始终包含首帧和末帧)。""" + if n >= total: + return list(range(total)) + if n == 1: + return [0] + if n == 2: + return [0, total - 1] + indices = [round(i * (total - 1) / (n - 1)) for i in range(n)] + # 去重并保持顺序 + seen: set[int] = set() + result: list[int] = [] + for idx in indices: + if idx not in seen: + seen.add(idx) + result.append(idx) + return result + + +def _compose_grid(frames: list[Image.Image], output_path: Path) -> None: + """将多帧拼接为网格图并保存为 PNG。""" + n = len(frames) + if n == 0: + return + if n == 1: + frames[0].save(output_path, format="PNG") + return + cols = math.ceil(math.sqrt(n)) + rows = math.ceil(n / cols) + fw, fh = frames[0].size + grid = Image.new("RGBA", (cols * fw, rows * fh), (0, 0, 0, 0)) + for i, frame in enumerate(frames): + resized = ( + frame.resize((fw, fh), Image.Resampling.LANCZOS) + if frame.size != (fw, fh) + else frame + ) + x = (i % cols) * fw + y = (i // cols) * fh + grid.paste(resized, (x, y)) + grid.save(output_path, format="PNG") + + @dataclass class _IngestDigestLockEntry: lock: asyncio.Lock @@ -374,6 +438,7 @@ async def delete_meme(self, uid: str) -> bool: self._delete_file_if_exists, Path(record.preview_path), ) + await asyncio.to_thread(self._cleanup_gif_frame_files, uid) return True def _delete_file_if_exists(self, path: Path) -> None: @@ -382,6 +447,17 @@ def _delete_file_if_exists(self, path: Path) -> None: except OSError: logger.debug("[memes] 删除文件失败: path=%s", path, exc_info=True) + def _cleanup_gif_frame_files(self, uid: str) -> None: + """清理 GIF 多帧分析产生的临时帧文件 ({uid}_f{i}.png)。""" + preview_dir = self._preview_dir() + for frame_file in preview_dir.glob(f"{uid}_f*.png"): + try: + frame_file.unlink(missing_ok=True) + except OSError: + logger.debug( + "[memes] 删除帧文件失败: path=%s", frame_file, exc_info=True + ) + async def _cleanup_meme_artifacts( self, *, @@ -408,6 +484,8 @@ async def _cleanup_meme_artifacts( await asyncio.to_thread(self._delete_file_if_exists, blob_path) if preview_path is not None and preview_path != blob_path: await asyncio.to_thread(self._delete_file_if_exists, preview_path) + if uid: + await asyncio.to_thread(self._cleanup_gif_frame_files, uid) async def search_memes( self, @@ -675,7 +753,7 @@ async def send_meme_by_uid(self, uid: str, context: dict[str, Any]) -> str: attachments=history_attachments, ) else: - preferred_temp_group_id = _safe_int(context.get("group_id")) + preferred_temp_group_id = safe_int(context.get("group_id")) or None sent_message_id = await sender.send_private_message( int(target_id), cq_message, @@ -794,25 +872,52 @@ async def _process_reanalyze_job(self, job: Mapping[str, Any]) -> None: return if self._ai_client is None: raise RuntimeError("reanalyze requires ai_client") + analyze_path: str | list[str] = ( + record.preview_path if record.preview_path else record.blob_path + ) + # GIF 多帧模式:与 ingest 路径保持一致 + if record.is_animated: + cfg = self._cfg() + if str(getattr(cfg, "gif_analysis_mode", "grid")).lower() == "multi": + analyze_path = await self._prepare_gif_multi_frames( + Path(record.blob_path), uid + ) try: - judgement = await self._ai_client.judge_meme_image(record.blob_path) + judgement = await self._ai_client.judge_meme_image(analyze_path) except Exception as exc: + if _is_retryable_llm_error(exc): + if isinstance(analyze_path, list): + await asyncio.to_thread(self._cleanup_gif_frame_files, uid) + raise logger.exception( "[memes] judge stage failed during reanalyze: uid=%s err=%s", uid, exc ) + if isinstance(analyze_path, list): + await asyncio.to_thread(self._cleanup_gif_frame_files, uid) return if not bool(judgement.get("is_meme", False)): + if isinstance(analyze_path, list): + await asyncio.to_thread(self._cleanup_gif_frame_files, uid) await self.delete_meme(uid) return try: - described = await self._ai_client.describe_meme_image(record.blob_path) + described = await self._ai_client.describe_meme_image(analyze_path) except Exception as exc: + if _is_retryable_llm_error(exc): + if isinstance(analyze_path, list): + await asyncio.to_thread(self._cleanup_gif_frame_files, uid) + raise logger.exception( "[memes] describe stage failed during reanalyze: uid=%s err=%s", uid, exc, ) + if isinstance(analyze_path, list): + await asyncio.to_thread(self._cleanup_gif_frame_files, uid) return + # GIF 多帧文件用完即清理 + if isinstance(analyze_path, list): + await asyncio.to_thread(self._cleanup_gif_frame_files, uid) auto_description = str(described.get("description") or "").strip() next_tags = _normalize_tags(described.get("tags")) if not auto_description and not next_tags: @@ -933,12 +1038,24 @@ async def _process_ingest_job(self, job: Mapping[str, Any]) -> None: or mimetypes.guess_type(source_path.name)[0] or "application/octet-stream" ) - analyze_path = str( + analyze_path: str | list[str] = str( preview_path if preview_path is not None else blob_path ) + if ( + is_animated + and str(getattr(cfg, "gif_analysis_mode", "grid")).lower() + == "multi" + ): + analyze_path = await self._prepare_gif_multi_frames( + source_path, uid + ) try: judgement = await self._ai_client.judge_meme_image(analyze_path) except Exception as exc: + if _is_retryable_llm_error(exc): + if isinstance(analyze_path, list): + await asyncio.to_thread(self._cleanup_gif_frame_files, uid) + raise logger.exception( "[memes] judge stage failed, treat as non-meme: uid=%s err=%s", uid, @@ -946,6 +1063,8 @@ async def _process_ingest_job(self, job: Mapping[str, Any]) -> None: ) judgement = {"is_meme": False} if not bool(judgement.get("is_meme", False)): + if isinstance(analyze_path, list): + await asyncio.to_thread(self._cleanup_gif_frame_files, uid) await self._cleanup_meme_artifacts( uid=None, blob_path=blob_path, @@ -956,10 +1075,17 @@ async def _process_ingest_job(self, job: Mapping[str, Any]) -> None: try: described = await self._ai_client.describe_meme_image(analyze_path) except Exception as exc: + if _is_retryable_llm_error(exc): + if isinstance(analyze_path, list): + await asyncio.to_thread(self._cleanup_gif_frame_files, uid) + raise logger.exception( "[memes] describe stage failed, drop uid=%s err=%s", uid, exc ) described = {"description": "", "tags": []} + # GIF 多帧文件用完即清理 + if isinstance(analyze_path, list): + await asyncio.to_thread(self._cleanup_gif_frame_files, uid) tags = _normalize_tags(described.get("tags")) auto_description = str(described.get("description") or "").strip() if not auto_description and not tags: @@ -1043,17 +1169,44 @@ def _copy() -> None: if not is_animated: return blob_path + cfg = self._cfg() + mode = str(getattr(cfg, "gif_analysis_mode", "grid")).lower() + n_frames = max(2, int(getattr(cfg, "gif_analysis_frames", 6))) preview_path = self._preview_dir() / f"{target_uid}.png" def _render_preview() -> None: - with Image.open(source_path) as image: - image.seek(0) - frame = image.convert("RGBA") - frame.save(preview_path, format="PNG") + frames = _extract_gif_frames(source_path, n_frames) + if mode == "multi": + # multi 模式也需要生成一张预览用于存储/展示,取首帧 + frames[0].save(preview_path, format="PNG") + else: + _compose_grid(frames, preview_path) + for f in frames: + f.close() await asyncio.to_thread(_render_preview) return preview_path + async def _prepare_gif_multi_frames( + self, source_path: Path, target_uid: str + ) -> list[str]: + """multi 模式:将 GIF 各帧单独保存为 PNG,返回路径列表。""" + cfg = self._cfg() + n_frames = max(2, int(getattr(cfg, "gif_analysis_frames", 6))) + preview_dir = self._preview_dir() + + def _render_frames() -> list[str]: + frames = _extract_gif_frames(source_path, n_frames) + paths: list[str] = [] + for i, frame in enumerate(frames): + p = preview_dir / f"{target_uid}_f{i}.png" + frame.save(p, format="PNG") + frame.close() + paths.append(str(p)) + return paths + + return await asyncio.to_thread(_render_frames) + def _hash_file(self, path: Path) -> str: hasher = hashlib.sha256() with path.open("rb") as handle: diff --git a/src/Undefined/render.py b/src/Undefined/render.py index dba3e136..3678ce1b 100644 --- a/src/Undefined/render.py +++ b/src/Undefined/render.py @@ -91,19 +91,28 @@ def _parse() -> str: return full_html -async def render_html_to_image(html_content: str, output_path: str) -> None: +async def render_html_to_image( + html_content: str, + output_path: str, + *, + viewport_width: int = 1280, +) -> None: """ 将 HTML 字符串转换为 PNG 图片 参数: html_content: 完整的 HTML 字符串 output_path: 输出图片路径 (例如 'result.png') + viewport_width: 视口宽度(像素),默认 1280 """ async with async_playwright() as p: # 启动无头浏览器 browser = await p.chromium.launch(headless=True) # 设置上下文,可以指定缩放比例(device_scale_factor),2代表2倍清晰度(Retina) - context = await browser.new_context(device_scale_factor=2) + context = await browser.new_context( + device_scale_factor=2, + viewport={"width": viewport_width, "height": 800}, + ) page = await context.new_page() # 设置页面内容 diff --git a/src/Undefined/services/ai_coordinator.py b/src/Undefined/services/ai_coordinator.py index 5319a7d5..98ca1ea7 100644 --- a/src/Undefined/services/ai_coordinator.py +++ b/src/Undefined/services/ai_coordinator.py @@ -6,6 +6,7 @@ from Undefined.attachments import ( attachment_refs_to_xml, build_attachment_scope, + dispatch_pending_file_sends, render_message_with_pic_placeholders, ) from Undefined.config import Config @@ -72,7 +73,9 @@ async def handle_auto_reply( group_name: str = "未知群聊", sender_role: str = "member", sender_title: str = "", + sender_level: str = "", trigger_message_id: int | None = None, + is_fake_at: bool = False, ) -> None: """群聊自动回复入口:根据消息内容、命中情况和安全检测决定是否回复 @@ -86,13 +89,15 @@ async def handle_auto_reply( group_name: 群名称 sender_role: 发送者角色 (owner/admin/member) sender_title: 发送者群头衔 + is_fake_at: 是否为假@(纯文本 @昵称)触发 """ - is_at_bot = is_poke or self._is_at_bot(message_content) + is_at_bot = is_poke or is_fake_at or self._is_at_bot(message_content) logger.debug( - "[自动回复] group=%s sender=%s at_bot=%s text_len=%s", + "[自动回复] group=%s sender=%s at_bot=%s fake_at=%s text_len=%s", group_id, sender_id, is_at_bot, + is_fake_at, len(text), ) @@ -130,6 +135,7 @@ async def handle_auto_reply( text, attachments=attachments, message_id=trigger_message_id, + level=sender_level, ) logger.debug( "[自动回复] full_question_len=%s group=%s sender=%s", @@ -472,6 +478,12 @@ async def send_private_cb( rendered.delivery_text, history_message=rendered.history_text, ) + await dispatch_pending_file_sends( + rendered, + sender=self.sender, + target_type="private", + target_id=user_id, + ) except Exception: logger.exception("私聊回复执行出错") raise @@ -707,6 +719,7 @@ def _build_prompt( text: str, attachments: list[dict[str, str]] | None = None, message_id: int | None = None, + level: str = "", ) -> str: """构建最终发送给 AI 的结构化 XML 消息 Prompt @@ -724,10 +737,11 @@ def _build_prompt( message_id_attr = "" if message_id is not None: message_id_attr = f' message_id="{escape_xml_attr(message_id)}"' + level_attr = f' level="{escape_xml_attr(level)}"' if level else "" attachment_xml = ( f"\n{attachment_refs_to_xml(attachments)}" if attachments else "" ) - return f"""{prefix} + return f"""{prefix} {safe_text}{attachment_xml} diff --git a/src/Undefined/services/command.py b/src/Undefined/services/command.py index 8abfcf55..075fd954 100644 --- a/src/Undefined/services/command.py +++ b/src/Undefined/services/command.py @@ -94,6 +94,7 @@ def __init__( security: SecurityService, queue_manager: Any = None, rate_limiter: Any = None, + history_manager: Any = None, ) -> None: """初始化命令分发器 @@ -106,6 +107,7 @@ def __init__( security: 安全审计与限流服务 queue_manager: AI 请求队列管理器 rate_limiter: 速率限制器 + history_manager: 消息历史记录管理器 """ self.config = config self.sender = sender @@ -115,6 +117,7 @@ def __init__( self.security = security self.queue_manager = queue_manager self.rate_limiter = rate_limiter + self.history_manager = history_manager self.naga_store: Any = None self._token_usage_storage = TokenUsageStorage() # 存储 stats 分析结果,用于队列回调 @@ -1078,6 +1081,8 @@ async def _send_target_message(message: str) -> None: scope=scope, user_id=user_id, is_webui_session=is_webui_session, + cognitive_service=getattr(self.ai, "_cognitive_service", None), + history_manager=self.history_manager, ) try: diff --git a/src/Undefined/services/commands/context.py b/src/Undefined/services/commands/context.py index 507af5e8..56732b97 100644 --- a/src/Undefined/services/commands/context.py +++ b/src/Undefined/services/commands/context.py @@ -32,3 +32,5 @@ class CommandContext: scope: str = "group" user_id: int | None = None is_webui_session: bool = False + cognitive_service: Any = None + history_manager: Any = None diff --git a/src/Undefined/skills/agents/arxiv_analysis_agent/__init__.py b/src/Undefined/skills/agents/arxiv_analysis_agent/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/Undefined/skills/agents/arxiv_analysis_agent/callable.json b/src/Undefined/skills/agents/arxiv_analysis_agent/callable.json new file mode 100644 index 00000000..855776f7 --- /dev/null +++ b/src/Undefined/skills/agents/arxiv_analysis_agent/callable.json @@ -0,0 +1,4 @@ +{ + "enabled": true, + "allowed_callers": ["info_agent", "web_agent", "naga_code_analysis_agent"] +} diff --git a/src/Undefined/skills/agents/arxiv_analysis_agent/config.json b/src/Undefined/skills/agents/arxiv_analysis_agent/config.json new file mode 100644 index 00000000..a97b0a97 --- /dev/null +++ b/src/Undefined/skills/agents/arxiv_analysis_agent/config.json @@ -0,0 +1,21 @@ +{ + "type": "function", + "function": { + "name": "arxiv_analysis_agent", + "description": "arXiv 论文深度分析助手,下载并解析 arXiv 论文 PDF 全文,进行结构化学术深度分析。支持按 arXiv ID 或 URL 获取论文。", + "parameters": { + "type": "object", + "properties": { + "paper_id": { + "type": "string", + "description": "arXiv 论文 ID(如 '2301.07041')、arXiv URL(如 'https://arxiv.org/abs/2301.07041')或 'arXiv:2301.07041' 格式" + }, + "prompt": { + "type": "string", + "description": "可选的分析需求,例如:'重点分析方法论和实验设计'、'关注与 diffusion model 的对比'" + } + }, + "required": ["paper_id"] + } + } +} diff --git a/src/Undefined/skills/agents/arxiv_analysis_agent/handler.py b/src/Undefined/skills/agents/arxiv_analysis_agent/handler.py new file mode 100644 index 00000000..7b6d6052 --- /dev/null +++ b/src/Undefined/skills/agents/arxiv_analysis_agent/handler.py @@ -0,0 +1,48 @@ +from __future__ import annotations + +import logging +from pathlib import Path +from typing import Any + +from Undefined.arxiv.parser import normalize_arxiv_id +from Undefined.skills.agents.runner import run_agent_with_tools + +logger = logging.getLogger(__name__) + + +async def execute(args: dict[str, Any], context: dict[str, Any]) -> str: + """执行 arxiv_analysis_agent。""" + + raw_paper_id = str(args.get("paper_id", "")).strip() + user_prompt = str(args.get("prompt", "")).strip() + + if not raw_paper_id: + return "请提供 arXiv 论文 ID 或 URL" + + paper_id = normalize_arxiv_id(raw_paper_id) + if paper_id is None: + return f"无法解析 arXiv 标识:{raw_paper_id}" + + context["arxiv_paper_id"] = paper_id + + context_messages = [ + { + "role": "system", + "content": f"当前任务:深度分析 arXiv 论文 {paper_id}", + } + ] + + user_content = user_prompt if user_prompt else f"请深度分析论文 arXiv:{paper_id}" + + return await run_agent_with_tools( + agent_name="arxiv_analysis_agent", + user_content=user_content, + context_messages=context_messages, + empty_user_content_message="请提供 arXiv 论文 ID 或 URL", + default_prompt="你是一个学术论文深度分析助手。", + context=context, + agent_dir=Path(__file__).parent, + logger=logger, + max_iterations=15, + tool_error_prefix="错误", + ) diff --git a/src/Undefined/skills/agents/arxiv_analysis_agent/intro.md b/src/Undefined/skills/agents/arxiv_analysis_agent/intro.md new file mode 100644 index 00000000..1b306e32 --- /dev/null +++ b/src/Undefined/skills/agents/arxiv_analysis_agent/intro.md @@ -0,0 +1 @@ +arXiv 论文深度分析助手:下载 arXiv 论文 PDF 全文并进行结构化学术深度分析,涵盖方法论、实验、创新点、局限性等维度。 diff --git a/src/Undefined/skills/agents/arxiv_analysis_agent/prompt.md b/src/Undefined/skills/agents/arxiv_analysis_agent/prompt.md new file mode 100644 index 00000000..029ab1e8 --- /dev/null +++ b/src/Undefined/skills/agents/arxiv_analysis_agent/prompt.md @@ -0,0 +1,28 @@ +你是学术论文深度分析助手,专门对 arXiv 论文进行全面、结构化的学术分析。 + +工作流程: +1. 先调用 `fetch_paper` 获取论文元数据和摘要 +2. 调用 `read_paper_pages` 分批阅读论文全文(每次读取一定页数范围) +3. 读完后产出结构化深度分析 + +阅读策略: +- 先通过 `fetch_paper` 了解总页数和摘要 +- 用 `read_paper_pages` 按区间阅读(如 1-5、6-10、11-15 等),每次读 5 页 +- 对于较长的论文(>20 页),可以选择性跳过附录/参考文献,集中精力分析正文 +- 短论文可以一次性读完 + +分析输出结构: +- **概要**:一句话总结论文核心贡献 +- **研究背景与动机**:论文要解决什么问题、为什么重要 +- **方法论**:核心技术方案、算法、模型架构的详细解析 +- **实验与结果**:实验设置、基准对比、主要结论 +- **创新点与贡献**:论文的主要新颖之处 +- **局限性与未来方向**:作者提到的或你分析出的不足和可改进之处 +- **总评**:论文的整体质量和影响力评估 + +注意事项: +- 保持学术严谨,用客观语言描述 +- 如果用户指定了分析侧重(prompt),要重点展开相关部分 +- 对公式和算法用自然语言解释,不要直接复制 LaTeX +- PDF 提取可能有格式问题(乱码/缺失),遇到时说明情况继续分析 +- 如果论文是中文,用中文输出分析;否则用中文输出分析但保留关键术语英文原文 diff --git a/src/Undefined/skills/agents/arxiv_analysis_agent/tools/__init__.py b/src/Undefined/skills/agents/arxiv_analysis_agent/tools/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/Undefined/skills/agents/arxiv_analysis_agent/tools/fetch_paper/config.json b/src/Undefined/skills/agents/arxiv_analysis_agent/tools/fetch_paper/config.json new file mode 100644 index 00000000..5e69a691 --- /dev/null +++ b/src/Undefined/skills/agents/arxiv_analysis_agent/tools/fetch_paper/config.json @@ -0,0 +1,17 @@ +{ + "type": "function", + "function": { + "name": "fetch_paper", + "description": "获取 arXiv 论文元数据并下载 PDF,返回论文基本信息(标题、作者、摘要、页数)。调用后可用 read_paper_pages 分页阅读正文。", + "parameters": { + "type": "object", + "properties": { + "paper_id": { + "type": "string", + "description": "arXiv 论文 ID,如 '2301.07041'" + } + }, + "required": ["paper_id"] + } + } +} diff --git a/src/Undefined/skills/agents/arxiv_analysis_agent/tools/fetch_paper/handler.py b/src/Undefined/skills/agents/arxiv_analysis_agent/tools/fetch_paper/handler.py new file mode 100644 index 00000000..d75cfcd1 --- /dev/null +++ b/src/Undefined/skills/agents/arxiv_analysis_agent/tools/fetch_paper/handler.py @@ -0,0 +1,75 @@ +"""获取 arXiv 论文元数据并下载 PDF 到本地缓存。""" + +from __future__ import annotations + +import logging +from typing import Any + +import fitz + +from Undefined.arxiv.client import get_paper_info +from Undefined.arxiv.downloader import download_paper_pdf +from Undefined.arxiv.parser import normalize_arxiv_id + +logger = logging.getLogger(__name__) + +_MAX_FILE_SIZE_MB = 50 + + +async def execute(args: dict[str, Any], context: dict[str, Any]) -> str: + raw_id = str(args.get("paper_id", "")).strip() + if not raw_id: + return "错误:请提供 arXiv 论文 ID" + + paper_id = normalize_arxiv_id(raw_id) or raw_id + + request_context = {"request_id": context.get("request_id", "-")} + + try: + paper = await get_paper_info(paper_id, context=request_context) + except Exception as exc: + logger.exception("[arxiv_analysis] 获取论文元数据失败: %s", exc) + return f"错误:获取论文 {paper_id} 元数据失败 — {exc}" + + if paper is None: + return f"错误:未找到论文 arXiv:{paper_id}" + + lines: list[str] = [ + f"论文: {paper.title}", + f"ID: {paper.paper_id}", + f"作者: {'、'.join(paper.authors[:10])}{'(等 ' + str(len(paper.authors)) + ' 位)' if len(paper.authors) > 10 else ''}", + f"分类: {paper.primary_category}", + f"发布: {paper.published[:10]}", + f"更新: {paper.updated[:10]}", + f"链接: {paper.abs_url}", + f"\n摘要:\n{paper.summary}", + ] + + try: + result, task_dir = await download_paper_pdf( + paper, max_file_size_mb=_MAX_FILE_SIZE_MB, context=request_context + ) + except Exception as exc: + logger.exception("[arxiv_analysis] PDF 下载失败: %s", exc) + lines.append(f"\nPDF 下载失败: {exc}(可基于摘要进行分析)") + return "\n".join(lines) + + if result.path is None: + lines.append(f"\nPDF 不可用(状态: {result.status}),请基于摘要进行分析") + return "\n".join(lines) + + try: + doc = fitz.open(str(result.path)) + try: + page_count = len(doc) + lines.append(f"\nPDF 已下载: {page_count} 页") + context["_arxiv_pdf_path"] = str(result.path) + context["_arxiv_pdf_pages"] = page_count + context["_arxiv_task_dir"] = str(task_dir) + finally: + doc.close() + except Exception as exc: + logger.exception("[arxiv_analysis] PDF 打开失败: %s", exc) + lines.append(f"\nPDF 无法打开: {exc}(可基于摘要进行分析)") + + return "\n".join(lines) diff --git a/src/Undefined/skills/agents/arxiv_analysis_agent/tools/read_paper_pages/config.json b/src/Undefined/skills/agents/arxiv_analysis_agent/tools/read_paper_pages/config.json new file mode 100644 index 00000000..abc268e5 --- /dev/null +++ b/src/Undefined/skills/agents/arxiv_analysis_agent/tools/read_paper_pages/config.json @@ -0,0 +1,21 @@ +{ + "type": "function", + "function": { + "name": "read_paper_pages", + "description": "读取已下载论文 PDF 的指定页范围文本。必须先调用 fetch_paper 下载论文。", + "parameters": { + "type": "object", + "properties": { + "start_page": { + "type": "integer", + "description": "起始页码(从 1 开始)" + }, + "end_page": { + "type": "integer", + "description": "结束页码(包含该页)" + } + }, + "required": ["start_page", "end_page"] + } + } +} diff --git a/src/Undefined/skills/agents/arxiv_analysis_agent/tools/read_paper_pages/handler.py b/src/Undefined/skills/agents/arxiv_analysis_agent/tools/read_paper_pages/handler.py new file mode 100644 index 00000000..ed013d1d --- /dev/null +++ b/src/Undefined/skills/agents/arxiv_analysis_agent/tools/read_paper_pages/handler.py @@ -0,0 +1,68 @@ +"""分页读取已下载的 arXiv 论文 PDF 文本。""" + +from __future__ import annotations + +import logging +from typing import Any + +import fitz + +logger = logging.getLogger(__name__) + +_MAX_CHARS_PER_READ = 15000 + + +async def execute(args: dict[str, Any], context: dict[str, Any]) -> str: + pdf_path = context.get("_arxiv_pdf_path") + total_pages = context.get("_arxiv_pdf_pages") + + if not pdf_path: + return "错误:请先调用 fetch_paper 下载论文" + + try: + start_page = int(args.get("start_page", 1)) + end_page = int(args.get("end_page", start_page)) + except (TypeError, ValueError): + return "错误:页码必须为整数" + + if start_page < 1: + start_page = 1 + if total_pages and end_page > total_pages: + end_page = total_pages + if start_page > end_page: + return f"错误:起始页 {start_page} 大于结束页 {end_page}" + + try: + doc = fitz.open(pdf_path) + except Exception as exc: + logger.exception("[arxiv_analysis] PDF 打开失败: %s", exc) + return f"错误:PDF 文件无法打开 — {exc}" + + try: + actual_pages = len(doc) + if start_page > actual_pages: + return f"错误:论文共 {actual_pages} 页,请求的起始页 {start_page} 超出范围" + + end_page = min(end_page, actual_pages) + text_parts: list[str] = [] + total_chars = 0 + + for page_num in range(start_page - 1, end_page): + page = doc.load_page(page_num) + raw_text = page.get_text() + page_text = str(raw_text) if raw_text else "" + + if total_chars + len(page_text) > _MAX_CHARS_PER_READ and text_parts: + text_parts.append( + f"\n[第 {page_num + 1} 页起文本已截断,请用更小的页范围重新读取]" + ) + break + + text_parts.append(f"--- 第 {page_num + 1} 页 ---") + text_parts.append(page_text if page_text.strip() else "(此页无可提取文本)") + total_chars += len(page_text) + + header = f"论文内容(第 {start_page}-{end_page} 页,共 {actual_pages} 页)" + return f"{header}\n\n" + "\n".join(text_parts) + finally: + doc.close() diff --git a/src/Undefined/skills/agents/runner.py b/src/Undefined/skills/agents/runner.py index 69d9e15c..6e568d8b 100644 --- a/src/Undefined/skills/agents/runner.py +++ b/src/Undefined/skills/agents/runner.py @@ -7,6 +7,7 @@ import aiofiles +from Undefined.config.models import AgentModelConfig from Undefined.ai.transports.openai_transport import RESPONSES_OUTPUT_ITEMS_KEY from Undefined.skills.agents.agent_tool_registry import AgentToolRegistry from Undefined.skills.anthropic_skills import AnthropicSkillRegistry @@ -96,14 +97,21 @@ async def run_agent_with_tools( if not ai_client: return "AI client 未在上下文中提供" - agent_config = ai_client.agent_config - # 动态选择 agent 模型 - group_id = context.get("group_id", 0) or 0 - user_id = context.get("user_id", 0) or 0 - global_enabled = runtime_config.model_pool_enabled if runtime_config else False - agent_config = ai_client.model_selector.select_agent_config( - agent_config, group_id=group_id, user_id=user_id, global_enabled=global_enabled - ) + model_config_override = context.get("model_config_override") + if isinstance(model_config_override, AgentModelConfig): + agent_config = model_config_override + else: + agent_config = ai_client.agent_config + # 动态选择 agent 模型 + group_id = context.get("group_id", 0) or 0 + user_id = context.get("user_id", 0) or 0 + global_enabled = runtime_config.model_pool_enabled if runtime_config else False + agent_config = ai_client.model_selector.select_agent_config( + agent_config, + group_id=group_id, + user_id=user_id, + global_enabled=global_enabled, + ) system_prompt = await load_prompt_text(agent_dir, default_prompt) # 注入 agent 私有 Anthropic Skills 元数据到 system prompt diff --git a/src/Undefined/skills/agents/summary_agent/__init__.py b/src/Undefined/skills/agents/summary_agent/__init__.py new file mode 100644 index 00000000..13726fa1 --- /dev/null +++ b/src/Undefined/skills/agents/summary_agent/__init__.py @@ -0,0 +1 @@ +# summary_agent diff --git a/src/Undefined/skills/agents/summary_agent/config.json b/src/Undefined/skills/agents/summary_agent/config.json new file mode 100644 index 00000000..2e30e126 --- /dev/null +++ b/src/Undefined/skills/agents/summary_agent/config.json @@ -0,0 +1,28 @@ +{ + "type": "function", + "function": { + "name": "summary_agent", + "description": "消息总结助手,拉取指定范围的聊天消息并进行智能总结。支持按条数或时间范围筛选。", + "parameters": { + "type": "object", + "properties": { + "prompt": { + "type": "string", + "description": "用户的总结需求,例如:'总结最近50条消息'、'总结过去1小时的聊天'、'总结今天的技术讨论'" + }, + "count": { + "type": "integer", + "description": "要总结的最近消息条数。与 time_range 二选一。" + }, + "time_range": { + "type": "string", + "description": "要总结的时间范围,如 '1h'、'6h'、'1d'、'7d'。与 count 二选一。" + }, + "focus": { + "type": "string", + "description": "可选的重点关注主题。" + } + } + } + } +} diff --git a/src/Undefined/skills/agents/summary_agent/handler.py b/src/Undefined/skills/agents/summary_agent/handler.py new file mode 100644 index 00000000..ba64eb51 --- /dev/null +++ b/src/Undefined/skills/agents/summary_agent/handler.py @@ -0,0 +1,81 @@ +from __future__ import annotations + +import logging +from pathlib import Path +from typing import Any + +from Undefined.skills.agents.runner import run_agent_with_tools + +logger = logging.getLogger(__name__) + + +def _normalize_positive_int(value: Any) -> int | None: + try: + parsed = int(value) + except (TypeError, ValueError): + return None + return parsed if parsed > 0 else None + + +def _build_user_content(args: dict[str, Any]) -> str: + prompt = str(args.get("prompt", "")).strip() + count = _normalize_positive_int(args.get("count")) + time_range = str(args.get("time_range", "") or "").strip() + focus = str(args.get("focus", "") or "").strip() + + if not prompt: + if time_range: + prompt = f"请总结过去 {time_range} 内的聊天消息" + elif count is not None: + prompt = f"请总结最近 {count} 条聊天消息" + + instructions: list[str] = [] + if time_range: + instructions.append(f"必须调用 fetch_messages,并使用 time_range={time_range}") + elif count is not None: + instructions.append(f"必须调用 fetch_messages,并使用 count={count}") + else: + instructions.append("必须调用 fetch_messages,并使用默认的 count=50") + + if focus: + instructions.append(f"总结时重点关注:{focus}") + + instructions.append("输出尽量精炼,控制在 2 到 3 个短段落内") + instructions.append("不要使用 emoji、markdown、项目符号或标题") + + if not prompt: + return "" + + return f"{prompt}\n\n执行要求:\n" + "\n".join( + f"{index}. {item}" for index, item in enumerate(instructions, start=1) + ) + + +async def execute(args: dict[str, Any], context: dict[str, Any]) -> str: + """执行 summary_agent。""" + user_prompt = _build_user_content(args) + runtime_config = context.get("runtime_config") + run_context = context + if ( + runtime_config is not None + and getattr(runtime_config, "summary_model_configured", False) + and getattr(runtime_config, "summary_model", None) is not None + ): + run_context = dict(context) + run_context["model_config_override"] = runtime_config.summary_model + return await run_agent_with_tools( + agent_name="summary_agent", + user_content=user_prompt, + empty_user_content_message="请提供您的总结需求", + default_prompt=( + "你是一个消息总结助手。" + "必须严格按照用户给定的 count/time_range/focus 约束调用 fetch_messages," + "不要擅自扩大范围。" + "输出要简短、朴素、信息密度高。" + ), + context=run_context, + agent_dir=Path(__file__).parent, + logger=logger, + max_iterations=10, + tool_error_prefix="错误", + ) diff --git a/src/Undefined/skills/agents/summary_agent/intro.md b/src/Undefined/skills/agents/summary_agent/intro.md new file mode 100644 index 00000000..a7343dbc --- /dev/null +++ b/src/Undefined/skills/agents/summary_agent/intro.md @@ -0,0 +1,30 @@ +# summary_agent - 消息总结助手 + +## 定位 +专业的聊天消息总结智能体,从海量聊天记录中提取关键信息并生成结构化报告。 + +## 擅长 +- ✅ 按**条数**拉取消息 (默认50条,最大500条) +- ✅ 按**时间范围**拉取消息 (支持1h/6h/1d/7d等格式) +- ✅ **话题提取**: 识别主要讨论主题 +- ✅ **要点归纳**: 总结重要决策、结论、共识 +- ✅ **参与者分析**: 识别活跃用户及其贡献 +- ✅ **资源收集**: 提取链接、代码片段等 +- ✅ 输出**结构化、清晰**的总结报告 + +## 边界 +- ❌ **不做实时监控**: 仅分析历史消息,不监听新消息 +- ❌ **不做情感分析**: 专注事实总结,不评价情绪 +- ❌ **不做预测**: 不推测未来走向 + +## 输入偏好 +- 明确的**条数要求**: "总结最近100条消息" +- 明确的**时间范围**: "总结过去6小时的聊天"、"总结今天的讨论" +- 特定的**总结目标**: "总结技术讨论"、"提取所有链接" +- 如果用户仅说"总结一下",默认总结最近50条消息 + +## 适用场景 +- 快速了解错过的聊天内容 +- 回顾会议或讨论要点 +- 整理特定时间段的聊天记录 +- 提取重要链接和资源 diff --git a/src/Undefined/skills/agents/summary_agent/prompt.md b/src/Undefined/skills/agents/summary_agent/prompt.md new file mode 100644 index 00000000..8e0a1d65 --- /dev/null +++ b/src/Undefined/skills/agents/summary_agent/prompt.md @@ -0,0 +1,49 @@ +# 消息总结助手 + +你是一个专业的聊天消息总结助手,擅长从大量聊天记录中提取关键信息并生成简洁的总结。 + +## 核心能力 + +- 使用 `fetch_messages` 工具拉取指定范围的聊天消息 +- 支持按消息条数(如最近50条)或时间范围(如过去1小时、今天)筛选 +- 提取主题、关键参与者、重要决策、链接资源等 +- 生成清晰、自然的总结 + +## 工作流程 + +1. 理解需求: 分析用户的总结需求,确定查询参数 + - 如果用户消息里明确给了 `count` / `time_range` / 重点关注内容,必须严格照着执行 + - 如果用户指定了条数(如"最近50条"),使用 `count` 参数 + - 如果用户指定了时间范围(如"过去1小时"、"今天"),使用 `time_range` 参数 + - 如果用户未明确指定,默认使用最近50条消息 + +2. 拉取消息: 调用 `fetch_messages` 工具获取聊天记录 + - `count`: 消息条数,默认50,最大500 + - `time_range`: 时间范围,支持 "1h"(1小时)、"6h"(6小时)、"1d"(1天)、"7d"(7天) + +3. 分析总结: 对获取的消息进行智能分析 + - 识别主要讨论话题 + - 提取关键参与者及其贡献 + - 总结重要决策、结论或共识 + - 收集提到的链接、资源、代码片段 + - 标注特别重要或需要关注的信息 + +4. 生成报告: 以自然、朴素的文字段落输出总结 + - 语言精炼准确 + - 只保留高信息密度内容,不要把聊天流水账全复述一遍 + +## 输出格式要求 + +务必遵守以下格式规则: +- 不要使用 emoji 表情符号 +- 不要使用 markdown 格式(不要用 #、**、- 列表等) +- 使用朴素的纯文字段落,自然地组织内容 +- 分段描述不同话题 +- 保持简洁但全面,用正常的叙述语气 +- 如果没有特别重要的参与者、链接或待办,就不要硬凑这些内容 + +## 注意事项 +- 保持客观中立,不加入主观评价 +- 如果消息量很大,优先突出重点,可省略琐碎细节 +- 如果讨论涉及敏感话题,谨慎措辞 +- 如果消息为空或无有效内容,明确说明 diff --git a/src/Undefined/skills/agents/summary_agent/tools/__init__.py b/src/Undefined/skills/agents/summary_agent/tools/__init__.py new file mode 100644 index 00000000..fb43c212 --- /dev/null +++ b/src/Undefined/skills/agents/summary_agent/tools/__init__.py @@ -0,0 +1 @@ +# summary_agent tools diff --git a/src/Undefined/skills/agents/summary_agent/tools/fetch_messages/config.json b/src/Undefined/skills/agents/summary_agent/tools/fetch_messages/config.json new file mode 100644 index 00000000..30042bc1 --- /dev/null +++ b/src/Undefined/skills/agents/summary_agent/tools/fetch_messages/config.json @@ -0,0 +1,20 @@ +{ + "type": "function", + "function": { + "name": "fetch_messages", + "description": "从当前会话拉取聊天消息。支持按条数或时间范围筛选。", + "parameters": { + "type": "object", + "properties": { + "count": { + "type": "integer", + "description": "要获取的消息条数,默认50,上限由服务器配置决定。" + }, + "time_range": { + "type": "string", + "description": "时间范围,如 '1h'、'6h'、'1d'、'7d'。与 count 互斥,优先使用 time_range。" + } + } + } + } +} diff --git a/src/Undefined/skills/agents/summary_agent/tools/fetch_messages/handler.py b/src/Undefined/skills/agents/summary_agent/tools/fetch_messages/handler.py new file mode 100644 index 00000000..56be9c33 --- /dev/null +++ b/src/Undefined/skills/agents/summary_agent/tools/fetch_messages/handler.py @@ -0,0 +1,133 @@ +from __future__ import annotations + +import logging +import re +from datetime import datetime, timedelta +from typing import Any + +from Undefined.utils.xml import format_messages_xml + +logger = logging.getLogger(__name__) + +_TIME_RANGE_PATTERN = re.compile(r"^(\d+)([hHdDwW])$") +_TIME_UNIT_SECONDS = {"h": 3600, "d": 86400, "w": 604800} +_DEFAULT_COUNT = 50 + +# 以下值仅作为 runtime_config 缺失时的回退 +_FALLBACK_MAX_COUNT = 1000 +_FALLBACK_MAX_FETCH_FOR_TIME_FILTER = 5000 + + +def _get_history_limit(context: dict[str, Any], key: str, fallback: int) -> int: + """从 runtime_config 读取历史限制配置。""" + cfg = context.get("runtime_config") + if cfg is not None: + val = getattr(cfg, key, None) + if isinstance(val, int) and val > 0: + return val + return fallback + + +def _parse_time_range(value: str) -> int | None: + """Parse time range string like '1h', '6h', '1d', '7d' into seconds.""" + match = _TIME_RANGE_PATTERN.match(value.strip()) + if not match: + return None + amount = int(match.group(1)) + unit = match.group(2).lower() + return amount * _TIME_UNIT_SECONDS.get(unit, 3600) + + +def _filter_by_time( + messages: list[dict[str, Any]], seconds: int +) -> list[dict[str, Any]]: + """Filter messages to only include those within the given time range from now.""" + cutoff = datetime.now() - timedelta(seconds=seconds) + result = [] + for msg in messages: + ts_str = msg.get("timestamp", "") + if not ts_str: + continue + try: + ts = datetime.strptime(ts_str, "%Y-%m-%d %H:%M:%S") + except ValueError: + continue + if ts >= cutoff: + result.append(msg) + return result + + +def _normalize_messages_for_chat( + messages: list[dict[str, Any]], + *, + chat_type: str, + chat_id: str, +) -> list[dict[str, Any]]: + normalized: list[dict[str, Any]] = [] + for raw in messages: + msg = dict(raw) + if not str(msg.get("type", "") or "").strip(): + msg["type"] = chat_type + if not str(msg.get("chat_id", "") or "").strip(): + msg["chat_id"] = chat_id + if chat_type == "private" and not str(msg.get("chat_name", "") or "").strip(): + msg["chat_name"] = f"QQ用户{chat_id}" + normalized.append(msg) + return normalized + + +async def execute(args: dict[str, Any], context: dict[str, Any]) -> str: + """拉取当前会话的聊天消息。""" + history_manager = context.get("history_manager") + if not history_manager: + return "历史记录管理器未配置" + + group_id = context.get("group_id", 0) or 0 + user_id = context.get("user_id", 0) or 0 + + if int(group_id) > 0: + chat_type = "group" + chat_id = str(group_id) + else: + chat_type = "private" + chat_id = str(user_id) + + time_range_str = str(args.get("time_range", "")).strip() + raw_count = args.get("count", _DEFAULT_COUNT) + max_count = _get_history_limit( + context, "history_summary_fetch_limit", _FALLBACK_MAX_COUNT + ) + try: + count = min(max(int(raw_count), 1), max_count) + except (TypeError, ValueError): + count = _DEFAULT_COUNT + + if time_range_str: + seconds = _parse_time_range(time_range_str) + if seconds is None: + return f"无法解析时间范围: {time_range_str}(支持格式: 1h, 6h, 1d, 7d)" + max_time_fetch = _get_history_limit( + context, + "history_summary_time_fetch_limit", + _FALLBACK_MAX_FETCH_FOR_TIME_FILTER, + ) + fetch_count = max(count * 2, max_time_fetch) + messages = history_manager.get_recent(chat_id, chat_type, 0, fetch_count) + if messages: + messages = _filter_by_time(messages, seconds) + else: + messages = history_manager.get_recent(chat_id, chat_type, 0, count) + + if not messages: + return "当前会话暂无消息记录" + + messages = _normalize_messages_for_chat( + messages, chat_type=chat_type, chat_id=chat_id + ) + + formatted = format_messages_xml(messages) + total = len(messages) + header = f"共获取 {total} 条消息" + if time_range_str: + header += f"(时间范围: {time_range_str})" + return f"{header}\n\n{formatted}" diff --git a/src/Undefined/skills/agents/web_agent/callable.json b/src/Undefined/skills/agents/web_agent/callable.json index ab8b02aa..bc537f29 100644 --- a/src/Undefined/skills/agents/web_agent/callable.json +++ b/src/Undefined/skills/agents/web_agent/callable.json @@ -1,4 +1,4 @@ { "enabled": true, - "allowed_callers": ["naga_code_analysis_agent", "code_delivery_agent", "info_agent"] + "allowed_callers": ["naga_code_analysis_agent", "code_delivery_agent", "info_agent", "summary_agent", "arxiv_analysis_agent"] } diff --git a/src/Undefined/skills/commands/help/config.json b/src/Undefined/skills/commands/help/config.json index 700dcdca..8639fb7f 100644 --- a/src/Undefined/skills/commands/help/config.json +++ b/src/Undefined/skills/commands/help/config.json @@ -12,7 +12,7 @@ "show_in_help": true, "order": 10, "allow_in_private": true, - "aliases": [], + "aliases": ["h"], "help_footer": [ "查看详细帮助:/help ", "详细版权与免责声明:/cprt", diff --git a/src/Undefined/skills/commands/profile/config.json b/src/Undefined/skills/commands/profile/config.json new file mode 100644 index 00000000..e14c037a --- /dev/null +++ b/src/Undefined/skills/commands/profile/config.json @@ -0,0 +1,16 @@ +{ + "name": "profile", + "description": "查看侧写。默认显示你的用户侧写;加 g 查看当前群聊侧写(仅群聊可用)。群聊默认合并转发,可用 -t 直接发送、-r 渲染图片", + "usage": "/p [g] [-t|-f|-r]", + "example": "/p g -r", + "permission": "public", + "rate_limit": { + "user": 60, + "admin": 10, + "superadmin": 0 + }, + "show_in_help": true, + "order": 25, + "allow_in_private": true, + "aliases": ["me", "p"] +} diff --git a/src/Undefined/skills/commands/profile/handler.py b/src/Undefined/skills/commands/profile/handler.py new file mode 100644 index 00000000..f3d94f67 --- /dev/null +++ b/src/Undefined/skills/commands/profile/handler.py @@ -0,0 +1,251 @@ +from __future__ import annotations + +import html +import logging +import uuid +from datetime import datetime, timezone +from pathlib import Path +from typing import Any + +from Undefined.services.commands.context import CommandContext +from Undefined.utils.paths import COGNITIVE_PROFILES_DIR, RENDER_CACHE_DIR, ensure_dir + +logger = logging.getLogger("profile") + +_MAX_PROFILE_LENGTH = 5000 + +_MODE_TEXT = "text" +_MODE_FORWARD = "forward" +_MODE_RENDER = "render" + + +def _is_private(context: CommandContext) -> bool: + return context.scope == "private" + + +def _truncate(text: str, limit: int = _MAX_PROFILE_LENGTH) -> str: + if len(text) <= limit: + return text + return text[:limit].rstrip() + "\n\n[侧写过长,已截断]" + + +def _parse_args(args: list[str]) -> tuple[str, str, str]: + """解析参数,返回 (子命令, 输出模式, 目标ID)。 + + 目标 ID 为纯数字参数,仅超级管理员可使用。 + """ + sub = "" + mode = "" + target = "" + for arg in args: + lower = arg.lower().strip() + if lower in ("-t", "--text"): + mode = _MODE_TEXT + elif lower in ("-f", "--forward"): + mode = _MODE_FORWARD + elif lower in ("-r", "--render"): + mode = _MODE_RENDER + elif lower in ("g", "group"): + sub = lower + elif arg.strip().isdigit(): + target = arg.strip() + return sub, mode, target + + +def _profile_mtime(entity_type: str, entity_id: str) -> str | None: + """读取侧写文件最后修改时间,返回人类可读字符串。""" + p = COGNITIVE_PROFILES_DIR / f"{entity_type}s" / f"{entity_id}.md" + try: + mtime = p.stat().st_mtime + dt = datetime.fromtimestamp(mtime, tz=timezone.utc).astimezone() + return dt.strftime("%Y-%m-%d %H:%M") + except OSError: + return None + + +def _build_metadata( + entity_type: str, + entity_id: str, + profile_len: int, +) -> str: + """构建元数据摘要文本。""" + type_label = "用户" if entity_type == "user" else "群聊" + lines = [ + f"类型: {type_label}侧写", + f"ID: {entity_id}", + f"长度: {profile_len} 字", + ] + mtime = _profile_mtime(entity_type, entity_id) + if mtime: + lines.append(f"更新: {mtime}") + return "\n".join(lines) + + +# ── 发送方法 ────────────────────────────────────────────────── + + +async def _send_text(context: CommandContext, text: str) -> None: + """纯文本直接发送。""" + if _is_private(context): + user_id = int(context.user_id or context.sender_id) + await context.sender.send_private_message(user_id, text) + else: + await context.sender.send_group_message(context.group_id, text) + + +async def _send_forward( + context: CommandContext, + metadata: str, + profile_text: str, +) -> None: + """合并转发:节点1=元数据,节点2=完整侧写内容。""" + bot_qq = str(getattr(context.config, "bot_qq", 0)) + + def _node(content: str) -> dict[str, Any]: + return { + "type": "node", + "data": {"name": "Undefined", "uin": bot_qq, "content": content}, + } + + nodes = [_node(metadata), _node(profile_text)] + await context.onebot.send_forward_msg(context.group_id, nodes) + + +async def _send_render( + context: CommandContext, + metadata: str, + profile_text: str, +) -> None: + """渲染为图片发送——元数据区 + 侧写正文区。""" + from Undefined.render import render_html_to_image + + safe_meta = html.escape(metadata) + safe_body = html.escape(profile_text) + + meta_rows = "" + for line in safe_meta.split("\n"): + if ": " in line: + key, _, val = line.partition(": ") + meta_rows += ( + f'{key}{val}\n' + ) + + html_content = f""" + + +
+
{meta_rows}
+
{safe_body}
+
+""" + + output_dir = ensure_dir(RENDER_CACHE_DIR) + output_path = str(output_dir / f"profile_{uuid.uuid4().hex[:8]}.png") + await render_html_to_image(html_content, output_path, viewport_width=480) + + abs_path = Path(output_path).resolve() + image_cq = f"[CQ:image,file=file://{abs_path}]" + + if _is_private(context): + user_id = int(context.user_id or context.sender_id) + await context.sender.send_private_message(user_id, image_cq) + else: + await context.sender.send_group_message(context.group_id, image_cq) + + +# ── 入口 ───────────────────────────────────────────────────── + + +async def execute(args: list[str], context: CommandContext) -> None: + """处理 /profile 命令。 + + 用法: /p [g] [-t|--text] [-f|--forward] [-r|--render] [目标ID] + g / group 查看群聊侧写(仅群聊可用) + -t / --text 纯文本直接发出 + -f / --forward 合并转发发出(群聊默认) + -r / --render 渲染为图片发出 + 目标ID 指定查询对象(仅超级管理员) + """ + cognitive_service = context.cognitive_service + if cognitive_service is None: + await _send_text(context, "❌ 侧写服务未启用") + return + + sub, mode, target = _parse_args(args) + + # 超管指定目标 + if target: + if not context.config.is_superadmin(context.sender_id): + await _send_text(context, "❌ 仅超级管理员可查看他人侧写") + return + + if sub in ("group", "g"): + if _is_private(context) and not target: + await _send_text(context, "❌ 私聊中不支持查看群聊侧写(可指定群号)") + return + entity_type = "group" + entity_id = target or str(context.group_id) + empty_hint = "暂无群聊侧写数据" + else: + entity_type = "user" + entity_id = target or str(context.sender_id) + empty_hint = "暂无侧写数据" + + profile = await cognitive_service.get_profile(entity_type, entity_id) + if not profile: + await _send_text(context, f"📭 {empty_hint}") + return + + profile = _truncate(profile) + metadata = _build_metadata(entity_type, entity_id, len(profile)) + + # 私聊始终纯文本 + if _is_private(context): + mode = _MODE_TEXT + + # 未指定模式:群聊默认合并转发 + if not mode: + mode = _MODE_FORWARD + + if mode == _MODE_TEXT: + await _send_text(context, profile) + elif mode == _MODE_RENDER: + try: + await _send_render(context, metadata, profile) + except Exception: + logger.exception("渲染侧写图片失败,回退到纯文本") + await _send_text(context, profile) + else: + try: + await _send_forward(context, metadata, profile) + except Exception: + logger.exception("发送合并转发失败,回退到纯文本") + await _send_text(context, profile) diff --git a/src/Undefined/skills/commands/summary/config.json b/src/Undefined/skills/commands/summary/config.json new file mode 100644 index 00000000..9135eda7 --- /dev/null +++ b/src/Undefined/skills/commands/summary/config.json @@ -0,0 +1,16 @@ +{ + "name": "summary", + "description": "总结聊天消息", + "usage": "/summary [条数|时间范围] [自定义描述]", + "example": "/summary 50", + "permission": "public", + "rate_limit": { + "user": 120, + "admin": 30, + "superadmin": 0 + }, + "show_in_help": true, + "order": 26, + "allow_in_private": true, + "aliases": ["sum"] +} diff --git a/src/Undefined/skills/commands/summary/handler.py b/src/Undefined/skills/commands/summary/handler.py new file mode 100644 index 00000000..39587bbf --- /dev/null +++ b/src/Undefined/skills/commands/summary/handler.py @@ -0,0 +1,109 @@ +from __future__ import annotations + +import logging +import re +from typing import Any + +from Undefined.services.commands.context import CommandContext + +logger = logging.getLogger(__name__) + +_TIME_RANGE_RE = re.compile(r"^\d+[hHdDwW]$") +_DEFAULT_COUNT = 50 + + +def _parse_args(args: list[str]) -> tuple[int | None, str | None, str]: + """Parse command arguments into (count, time_range, custom_prompt). + + Returns: + Tuple of (count, time_range, custom_prompt). + count and time_range are mutually exclusive; at most one is non-None. + """ + if not args: + return _DEFAULT_COUNT, None, "" + + first = args[0] + rest = " ".join(args[1:]).strip() + + if first.isdigit(): + count = max(1, min(int(first), 500)) + return count, None, rest + + if _TIME_RANGE_RE.match(first): + return None, first, rest + + # First arg is not a number or time range — treat everything as prompt + return _DEFAULT_COUNT, None, " ".join(args).strip() + + +def _build_prompt(count: int | None, time_range: str | None, custom_prompt: str) -> str: + """Build the natural language prompt for summary_agent.""" + parts: list[str] = ["请总结"] + if time_range: + parts.append(f"过去 {time_range} 内的聊天消息") + elif count: + parts.append(f"最近 {count} 条聊天消息") + else: + parts.append(f"最近 {_DEFAULT_COUNT} 条聊天消息") + + if custom_prompt: + parts.append(f",重点关注:{custom_prompt}") + + return "".join(parts) + + +def _is_private(context: CommandContext) -> bool: + return context.scope == "private" + + +async def _send(context: CommandContext, text: str) -> None: + if _is_private(context): + user_id = int(context.user_id or context.sender_id) + await context.sender.send_private_message(user_id, text) + else: + await context.sender.send_group_message(context.group_id, text) + + +async def execute(args: list[str], context: CommandContext) -> None: + """处理 /summary 命令。""" + if context.history_manager is None: + await _send(context, "❌ 历史记录管理器未配置") + return + + count, time_range, custom_prompt = _parse_args(args) + prompt = _build_prompt(count, time_range, custom_prompt) + + # Build agent context + agent_context: dict[str, Any] = { + "ai_client": context.ai, + "history_manager": context.history_manager, + "group_id": context.group_id, + "user_id": int(context.user_id or context.sender_id), + "sender_id": context.sender_id, + "request_type": "group" if int(context.group_id) > 0 else "private", + "runtime_config": getattr(context.ai, "runtime_config", None), + "queue_lane": None, + } + + try: + from Undefined.skills.agents.summary_agent.handler import execute as run_summary + + result = await run_summary( + { + "prompt": prompt, + "count": count, + "time_range": time_range, + "focus": custom_prompt, + }, + agent_context, + ) + except Exception: + logger.exception("[/summary] 执行总结失败") + await _send(context, "❌ 消息总结失败,请稍后重试") + return + + if not result or not result.strip(): + await _send(context, "📭 未能生成总结内容") + return + + await _send(context, result) diff --git a/src/Undefined/skills/tools/calculator/callable.json b/src/Undefined/skills/tools/calculator/callable.json new file mode 100644 index 00000000..0a69975c --- /dev/null +++ b/src/Undefined/skills/tools/calculator/callable.json @@ -0,0 +1,4 @@ +{ + "enabled": true, + "allowed_callers": ["*"] +} diff --git a/src/Undefined/skills/tools/calculator/config.json b/src/Undefined/skills/tools/calculator/config.json new file mode 100644 index 00000000..ad86fb49 --- /dev/null +++ b/src/Undefined/skills/tools/calculator/config.json @@ -0,0 +1,17 @@ +{ + "type": "function", + "function": { + "name": "calculator", + "description": "安全的多功能数学计算器。支持算术运算、科学函数、统计函数和常量。输入数学表达式即可计算。", + "parameters": { + "type": "object", + "properties": { + "expression": { + "type": "string", + "description": "数学表达式,例如:'2+3*4'、'sqrt(144)'、'sin(pi/6)'、'log(1000,10)'、'mean(1,2,3,4,5)'、'2**10'、'factorial(10)'、'gcd(48,18)'" + } + }, + "required": ["expression"] + } + } +} diff --git a/src/Undefined/skills/tools/calculator/handler.py b/src/Undefined/skills/tools/calculator/handler.py new file mode 100644 index 00000000..fd730249 --- /dev/null +++ b/src/Undefined/skills/tools/calculator/handler.py @@ -0,0 +1,268 @@ +"""安全的多功能数学计算器。 + +通过 AST 解析数学表达式,仅允许数学运算,拒绝任何危险操作。 +支持:算术、幂运算、科学函数、统计函数、常量。 +""" + +from __future__ import annotations + +import ast +import math +import operator +import statistics +from typing import Any + +_UNARY_OPS: dict[type, Any] = { + ast.UAdd: operator.pos, + ast.USub: operator.neg, +} + +_BIN_OPS: dict[type, Any] = { + ast.Add: operator.add, + ast.Sub: operator.sub, + ast.Mult: operator.mul, + ast.Div: operator.truediv, + ast.FloorDiv: operator.floordiv, + ast.Mod: operator.mod, + ast.Pow: operator.pow, + ast.BitXor: operator.pow, # ^ 也当幂运算(常见误用) +} + +_COMPARE_OPS: dict[type, Any] = { + ast.Eq: operator.eq, + ast.NotEq: operator.ne, + ast.Lt: operator.lt, + ast.LtE: operator.le, + ast.Gt: operator.gt, + ast.GtE: operator.ge, +} + +_CONSTANTS: dict[str, float] = { + "pi": math.pi, + "PI": math.pi, + "e": math.e, + "E": math.e, + "tau": math.tau, + "inf": math.inf, + "nan": math.nan, +} + + +def _stat_fn(fn_name: str, args: list[float | int]) -> float | int: + """统计函数分发。""" + if len(args) < 1: + raise ValueError(f"{fn_name} 至少需要 1 个参数") + fn_map: dict[str, Any] = { + "mean": statistics.mean, + "median": statistics.median, + "stdev": statistics.stdev, + "variance": statistics.variance, + "pstdev": statistics.pstdev, + "pvariance": statistics.pvariance, + "harmonic_mean": statistics.harmonic_mean, + } + fn = fn_map.get(fn_name) + if fn is None: + raise ValueError(f"未知统计函数: {fn_name}") + if fn_name in ("stdev", "variance") and len(args) < 2: + raise ValueError(f"{fn_name} 至少需要 2 个数据点") + result: float | int = fn(args) + return result + + +_MATH_FUNCS: dict[str, Any] = { + # 基础 + "abs": abs, + "round": round, + "int": int, + "float": float, + # 幂与对数 + "sqrt": math.sqrt, + "cbrt": lambda x: x ** (1 / 3), + "pow": math.pow, + "exp": math.exp, + "log": math.log, + "log2": math.log2, + "log10": math.log10, + "ln": math.log, + # 三角 + "sin": math.sin, + "cos": math.cos, + "tan": math.tan, + "asin": math.asin, + "acos": math.acos, + "atan": math.atan, + "atan2": math.atan2, + "sinh": math.sinh, + "cosh": math.cosh, + "tanh": math.tanh, + # 角度转换 + "degrees": math.degrees, + "radians": math.radians, + # 取整 + "ceil": math.ceil, + "floor": math.floor, + "trunc": math.trunc, + # 组合数学 + "factorial": math.factorial, + "comb": math.comb, + "perm": math.perm, + "gcd": math.gcd, + "lcm": math.lcm, + # 其他 + "hypot": math.hypot, + "copysign": math.copysign, + "fmod": math.fmod, + "isqrt": math.isqrt, + # 统计(占位,实际调用 _stat_fn) + "mean": None, + "median": None, + "stdev": None, + "variance": None, + "pstdev": None, + "pvariance": None, + "harmonic_mean": None, + # 最值 + "max": max, + "min": min, + "sum": sum, +} + +_STAT_FUNCS = frozenset( + {"mean", "median", "stdev", "variance", "pstdev", "pvariance", "harmonic_mean"} +) + +_MAX_POWER = 10000 +_MAX_EXPRESSION_LENGTH = 500 +_MAX_COMBINATORIAL_ARG = 1000 + +_COMBINATORIAL_FUNCS = frozenset({"factorial", "perm", "comb"}) + + +class _SafeEvaluator(ast.NodeVisitor): + """安全 AST 求值器,仅允许数学运算。""" + + def visit(self, node: ast.AST) -> Any: + return super().visit(node) + + def generic_visit(self, node: ast.AST) -> Any: + raise ValueError(f"不支持的表达式语法: {type(node).__name__}") + + def visit_Expression(self, node: ast.Expression) -> Any: + return self.visit(node.body) + + def visit_Constant(self, node: ast.Constant) -> Any: + if isinstance(node.value, (int, float, complex)): + return node.value + raise ValueError(f"不支持的常量类型: {type(node.value).__name__}") + + def visit_Name(self, node: ast.Name) -> Any: + if node.id in _CONSTANTS: + return _CONSTANTS[node.id] + raise ValueError(f"未知变量: {node.id}") + + def visit_UnaryOp(self, node: ast.UnaryOp) -> Any: + op_fn = _UNARY_OPS.get(type(node.op)) + if op_fn is None: + raise ValueError(f"不支持的一元运算: {type(node.op).__name__}") + return op_fn(self.visit(node.operand)) + + def visit_BinOp(self, node: ast.BinOp) -> Any: + op_fn = _BIN_OPS.get(type(node.op)) + if op_fn is None: + raise ValueError(f"不支持的运算: {type(node.op).__name__}") + left = self.visit(node.left) + right = self.visit(node.right) + if isinstance(node.op, (ast.Pow, ast.BitXor)): + if isinstance(right, (int, float)) and abs(right) > _MAX_POWER: + raise ValueError(f"指数过大: {right}(上限 {_MAX_POWER})") + return op_fn(left, right) + + def visit_Compare(self, node: ast.Compare) -> Any: + left = self.visit(node.left) + for op, comparator in zip(node.ops, node.comparators): + op_fn = _COMPARE_OPS.get(type(op)) + if op_fn is None: + raise ValueError(f"不支持的比较: {type(op).__name__}") + right = self.visit(comparator) + if not op_fn(left, right): + return False + left = right + return True + + def visit_Call(self, node: ast.Call) -> Any: + if not isinstance(node.func, ast.Name): + raise ValueError("仅支持直接函数调用") + fn_name = node.func.id + if fn_name not in _MATH_FUNCS: + raise ValueError(f"未知函数: {fn_name}") + if node.keywords: + raise ValueError("不支持关键字参数") + + args = [self.visit(arg) for arg in node.args] + + if fn_name in _COMBINATORIAL_FUNCS: + for a in args: + if isinstance(a, (int, float)) and abs(a) > _MAX_COMBINATORIAL_ARG: + raise ValueError( + f"{fn_name}() 参数过大: {a}(上限 {_MAX_COMBINATORIAL_ARG})" + ) + + if fn_name in _STAT_FUNCS: + return _stat_fn(fn_name, args) + + fn = _MATH_FUNCS[fn_name] + return fn(*args) + + def visit_IfExp(self, node: ast.IfExp) -> Any: + condition = self.visit(node.test) + return self.visit(node.body) if condition else self.visit(node.orelse) + + def visit_Tuple(self, node: ast.Tuple) -> Any: + return tuple(self.visit(elt) for elt in node.elts) + + def visit_List(self, node: ast.List) -> Any: + return [self.visit(elt) for elt in node.elts] + + +def safe_eval(expression: str) -> str: + """安全计算数学表达式,返回字符串结果。""" + expr = expression.strip() + if not expr: + raise ValueError("表达式为空") + if len(expr) > _MAX_EXPRESSION_LENGTH: + raise ValueError( + f"表达式过长({len(expr)} 字符,上限 {_MAX_EXPRESSION_LENGTH})" + ) + + tree = ast.parse(expr, mode="eval") + result = _SafeEvaluator().visit(tree) + + if isinstance(result, float): + if result == int(result) and not (math.isinf(result) or math.isnan(result)): + return str(int(result)) + return f"{result:.10g}" + if isinstance(result, complex): + return str(result) + if isinstance(result, bool): + return str(result) + return str(result) + + +async def execute(args: dict[str, Any], context: dict[str, Any]) -> str: + """计算数学表达式。""" + expression = str(args.get("expression", "")).strip() + if not expression: + return "请提供数学表达式" + + try: + result = safe_eval(expression) + return f"{expression} = {result}" + except ZeroDivisionError: + return f"计算错误:除以零 ({expression})" + except OverflowError: + return f"计算错误:结果溢出 ({expression})" + except (ValueError, TypeError) as exc: + return f"计算错误:{exc}" + except SyntaxError: + return f"表达式语法错误:{expression}" diff --git a/src/Undefined/skills/tools/end/handler.py b/src/Undefined/skills/tools/end/handler.py index d122ae25..42a98831 100644 --- a/src/Undefined/skills/tools/end/handler.py +++ b/src/Undefined/skills/tools/end/handler.py @@ -1,9 +1,14 @@ +from __future__ import annotations + from collections import deque from typing import Any, Dict import logging import re from Undefined.context import RequestContext +from Undefined.utils.coerce import safe_int +from Undefined.utils.xml import format_message_xml + from Undefined.end_summary_storage import ( EndSummaryLocation, EndSummaryRecord, @@ -80,13 +85,6 @@ def _clip_text(value: Any, max_len: int) -> str: return text[: max_len - 3].rstrip() + "..." -def _safe_int(value: Any, default: int) -> int: - try: - return int(value) - except (TypeError, ValueError): - return default - - def _clamp_int(value: int, min_value: int, max_value: int) -> int: if value < min_value: return min_value @@ -95,37 +93,38 @@ def _clamp_int(value: int, min_value: int, max_value: int) -> int: return value -def _resolve_historian_limits(context: Dict[str, Any]) -> tuple[int, int, int]: +def _resolve_historian_limits(context: Dict[str, Any]) -> tuple[int, int]: + """Return (max_source_len, recent_k) for historian context injection. + + ``recent_k`` is derived from ``context_recent_messages_limit`` (same as + main AI) so the historian sees the same message window. + """ max_source_len = _DEFAULT_HISTORIAN_TEXT_LEN recent_k = _DEFAULT_HISTORIAN_LINES - max_recent_line_len = _DEFAULT_HISTORIAN_LINE_LEN runtime_config = context.get("runtime_config") + + # 消息数量:优先使用 context_recent_messages_limit(与主 AI 一致) + if runtime_config is not None and hasattr( + runtime_config, "get_context_recent_messages_limit" + ): + try: + recent_k = int(runtime_config.get_context_recent_messages_limit()) + except Exception: + pass + cognitive = getattr(runtime_config, "cognitive", None) if runtime_config else None if cognitive is not None: - max_source_len = _safe_int( + max_source_len = safe_int( getattr(cognitive, "historian_source_message_max_len", max_source_len), max_source_len, ) - recent_k = _safe_int( - getattr(cognitive, "historian_recent_messages_inject_k", recent_k), - recent_k, - ) - max_recent_line_len = _safe_int( - getattr( - cognitive, "historian_recent_message_line_max_len", max_recent_line_len - ), - max_recent_line_len, - ) max_source_len = _clamp_int( max_source_len, _MIN_HISTORIAN_TEXT_LEN, _MAX_HISTORIAN_TEXT_LEN ) recent_k = _clamp_int(recent_k, _MIN_HISTORIAN_LINES, _MAX_HISTORIAN_LINES) - max_recent_line_len = _clamp_int( - max_recent_line_len, _MIN_HISTORIAN_LINE_LEN, _MAX_HISTORIAN_LINE_LEN - ) - return max_source_len, recent_k, max_recent_line_len + return max_source_len, recent_k def _extract_current_content_from_question(question: str, *, max_len: int) -> str: @@ -139,8 +138,13 @@ def _extract_current_content_from_question(question: str, *, max_len: int) -> st def _build_historian_recent_messages( - context: Dict[str, Any], *, recent_k: int, max_line_len: int + context: Dict[str, Any], *, recent_k: int ) -> list[str]: + """Build XML-formatted recent messages for historian context. + + Uses the same XML schema as the main AI prompt so the historian LLM + sees identical message structure for better disambiguation. + """ if recent_k <= 0: return [] @@ -176,24 +180,14 @@ def _build_historian_recent_messages( for msg in recent: if not isinstance(msg, dict): continue - timestamp = str(msg.get("timestamp", "")).strip() - display_name = str(msg.get("display_name", "")).strip() - user_id = str(msg.get("user_id", "")).strip() - message_text = _clip_text(msg.get("message", ""), max_line_len) - if not message_text: + if not str(msg.get("message", "") or "").strip(): continue - who = display_name or (f"UID:{user_id}" if user_id else "未知用户") - if user_id: - who = f"{who}({user_id})" - if timestamp: - lines.append(f"[{timestamp}] {who}: {message_text}") - else: - lines.append(f"{who}: {message_text}") + lines.append(format_message_xml(msg)) return lines[-recent_k:] def _inject_historian_reference_context(context: Dict[str, Any]) -> None: - max_source_len, recent_k, max_recent_line_len = _resolve_historian_limits(context) + max_source_len, recent_k = _resolve_historian_limits(context) current_question = str(context.get("current_question") or "").strip() source_message = _extract_current_content_from_question( current_question, max_len=max_source_len @@ -205,9 +199,7 @@ def _inject_historian_reference_context(context: Dict[str, Any]) -> None: current_question, max_source_len ) - recent_lines = _build_historian_recent_messages( - context, recent_k=recent_k, max_line_len=max_recent_line_len - ) + recent_lines = _build_historian_recent_messages(context, recent_k=recent_k) if recent_lines: context["historian_recent_messages"] = recent_lines diff --git a/src/Undefined/skills/toolsets/README.md b/src/Undefined/skills/toolsets/README.md index 750dc647..92209988 100644 --- a/src/Undefined/skills/toolsets/README.md +++ b/src/Undefined/skills/toolsets/README.md @@ -136,7 +136,7 @@ async def execute(args: dict[str, Any], context: dict[str, Any]) -> str: ### Render(渲染) - `render.render_html`: 将 HTML 渲染为图片 -- `render.render_latex`: 将 LaTeX 渲染为图片 +- `render.render_latex`: 将 LaTeX 渲染为图片(依赖系统 TeX 环境,需安装 TeX Live / MiKTeX) - `render.render_markdown`: 将 Markdown 渲染为图片 ### Memes(表情包) diff --git a/src/Undefined/skills/toolsets/group_analysis/analyze_join_statistics/config.json b/src/Undefined/skills/toolsets/group_analysis/analyze_join_statistics/config.json index 410c0a6e..c20eebea 100644 --- a/src/Undefined/skills/toolsets/group_analysis/analyze_join_statistics/config.json +++ b/src/Undefined/skills/toolsets/group_analysis/analyze_join_statistics/config.json @@ -30,7 +30,7 @@ }, "member_limit": { "type": "integer", - "description": "返回成员列表的数量限制(最大100),默认 20", + "description": "返回成员列表的数量限制,上限由服务器配置决定,默认 20", "default": 20 } }, diff --git a/src/Undefined/skills/toolsets/group_analysis/analyze_join_statistics/handler.py b/src/Undefined/skills/toolsets/group_analysis/analyze_join_statistics/handler.py index 9ba62679..fe660ed9 100644 --- a/src/Undefined/skills/toolsets/group_analysis/analyze_join_statistics/handler.py +++ b/src/Undefined/skills/toolsets/group_analysis/analyze_join_statistics/handler.py @@ -34,7 +34,9 @@ async def execute(args: Dict[str, Any], context: Dict[str, Any]) -> str: member_limit = int(member_limit_raw) if member_limit_raw is not None else 20 if member_limit < 0: return "参数错误:member_limit 必须是非负整数" - member_limit = min(member_limit, 100) + cfg = context.get("runtime_config") + analysis_cap = getattr(cfg, "history_group_analysis_limit", 500) if cfg else 500 + member_limit = min(member_limit, analysis_cap) except (ValueError, TypeError): return "参数类型错误:member_limit 必须是整数" diff --git a/src/Undefined/skills/toolsets/group_analysis/analyze_member_messages/config.json b/src/Undefined/skills/toolsets/group_analysis/analyze_member_messages/config.json index 188e2356..6b3ce00f 100644 --- a/src/Undefined/skills/toolsets/group_analysis/analyze_member_messages/config.json +++ b/src/Undefined/skills/toolsets/group_analysis/analyze_member_messages/config.json @@ -29,12 +29,12 @@ }, "message_limit": { "type": "integer", - "description": "返回消息内容的数量限制(最大100),默认 20", + "description": "返回消息内容的数量限制,上限由服务器配置决定,默认 20", "default": 20 }, "max_history_count": { "type": "integer", - "description": "最多获取的历史消息数量(最大5000),默认 2000", + "description": "最多获取的历史消息数量,上限由服务器配置决定,默认 2000", "default": 2000 } }, diff --git a/src/Undefined/skills/toolsets/group_analysis/analyze_member_messages/handler.py b/src/Undefined/skills/toolsets/group_analysis/analyze_member_messages/handler.py index 31b8ef82..dd001179 100644 --- a/src/Undefined/skills/toolsets/group_analysis/analyze_member_messages/handler.py +++ b/src/Undefined/skills/toolsets/group_analysis/analyze_member_messages/handler.py @@ -42,7 +42,9 @@ async def execute(args: Dict[str, Any], context: Dict[str, Any]) -> str: message_limit = int(message_limit_raw) if message_limit_raw is not None else 20 if message_limit < 0: return "参数错误:message_limit 必须是非负整数" - message_limit = min(message_limit, 100) + cfg = context.get("runtime_config") + analysis_cap = getattr(cfg, "history_group_analysis_limit", 500) if cfg else 500 + message_limit = min(message_limit, analysis_cap) except (ValueError, TypeError): return "参数类型错误:message_limit 必须是整数" @@ -53,7 +55,8 @@ async def execute(args: Dict[str, Any], context: Dict[str, Any]) -> str: ) if max_history_count < 0: return "参数错误:max_history_count 必须是非负整数" - max_history_count = min(max_history_count, 5000) + fetch_cap = getattr(cfg, "history_search_scan_limit", 10000) if cfg else 10000 + max_history_count = min(max_history_count, fetch_cap) except (ValueError, TypeError): return "参数类型错误:max_history_count 必须是整数" diff --git a/src/Undefined/skills/toolsets/group_analysis/analyze_new_member_activity/config.json b/src/Undefined/skills/toolsets/group_analysis/analyze_new_member_activity/config.json index 118b896c..2a31c984 100644 --- a/src/Undefined/skills/toolsets/group_analysis/analyze_new_member_activity/config.json +++ b/src/Undefined/skills/toolsets/group_analysis/analyze_new_member_activity/config.json @@ -20,7 +20,7 @@ }, "max_history_count": { "type": "integer", - "description": "最多获取的历史消息数量(最大5000),默认 2000", + "description": "最多获取的历史消息数量,上限由服务器配置决定,默认 2000", "default": 2000 }, "top_count": { diff --git a/src/Undefined/skills/toolsets/group_analysis/analyze_new_member_activity/handler.py b/src/Undefined/skills/toolsets/group_analysis/analyze_new_member_activity/handler.py index fbbec911..a8c5d7c8 100644 --- a/src/Undefined/skills/toolsets/group_analysis/analyze_new_member_activity/handler.py +++ b/src/Undefined/skills/toolsets/group_analysis/analyze_new_member_activity/handler.py @@ -34,7 +34,9 @@ async def execute(args: Dict[str, Any], context: Dict[str, Any]) -> str: ) if max_history_count < 0: return "参数错误:max_history_count 必须是非负整数" - max_history_count = min(max_history_count, 5000) + cfg = context.get("runtime_config") + fetch_cap = getattr(cfg, "history_search_scan_limit", 10000) if cfg else 10000 + max_history_count = min(max_history_count, fetch_cap) except (ValueError, TypeError): return "参数类型错误:max_history_count 必须是整数" diff --git a/src/Undefined/skills/toolsets/messages/get_messages_by_time/handler.py b/src/Undefined/skills/toolsets/messages/get_messages_by_time/handler.py index 05f4ba6c..96aaa89b 100644 --- a/src/Undefined/skills/toolsets/messages/get_messages_by_time/handler.py +++ b/src/Undefined/skills/toolsets/messages/get_messages_by_time/handler.py @@ -1,7 +1,15 @@ from typing import Any, Dict from datetime import datetime -_FILTERED_RESULT_LIMIT = 50 + +def _get_history_limit(context: Dict[str, Any], key: str, fallback: int) -> int: + """从 runtime_config 读取历史限制配置。""" + cfg = context.get("runtime_config") + if cfg is not None: + val = getattr(cfg, key, None) + if isinstance(val, int) and val > 0: + return val + return fallback def _resolve_chat_id(chat_id: str, msg_type: str, history_manager: Any) -> str: @@ -151,8 +159,11 @@ async def execute(args: Dict[str, Any], context: Dict[str, Any]) -> str: resolved_chat_id = _resolve_chat_id(chat_id, msg_type, history_manager) if get_recent_messages_callback: + search_scan_limit = _get_history_limit( + context, "history_search_scan_limit", 10000 + ) messages = await get_recent_messages_callback( - resolved_chat_id, msg_type, 0, 10000 + resolved_chat_id, msg_type, 0, search_scan_limit ) # 时间范围过滤 @@ -164,10 +175,12 @@ async def execute(args: Dict[str, Any], context: Dict[str, Any]) -> str: filtered_messages, keyword, sender ) - # 限制最多返回 50 条 + filtered_result_limit = _get_history_limit( + context, "history_filtered_result_limit", 200 + ) total_matched = len(filtered_messages) - if total_matched > _FILTERED_RESULT_LIMIT: - filtered_messages = filtered_messages[:_FILTERED_RESULT_LIMIT] + if total_matched > filtered_result_limit: + filtered_messages = filtered_messages[:filtered_result_limit] formatted = [] for msg in filtered_messages: @@ -178,6 +191,9 @@ async def execute(args: Dict[str, Any], context: Dict[str, Any]) -> str: timestamp_str = msg.get("timestamp", "") text = msg.get("message", "") message_id = msg.get("message_id") + role = msg.get("role", "") + title = msg.get("title", "") + level = msg.get("level", "") if msg_type_val == "group": # 确保群名以"群"结尾 @@ -189,8 +205,17 @@ async def execute(args: Dict[str, Any], context: Dict[str, Any]) -> str: if message_id is not None: msg_id_attr = f' message_id="{message_id}"' + extra_attrs = "" + if msg_type_val == "group": + if role: + extra_attrs += f' role="{role}"' + if title: + extra_attrs += f' title="{title}"' + if level: + extra_attrs += f' level="{level}"' + # 格式:XML 标准化 - formatted.append(f""" + formatted.append(f""" {text} """) diff --git a/src/Undefined/skills/toolsets/messages/get_recent_messages/handler.py b/src/Undefined/skills/toolsets/messages/get_recent_messages/handler.py index fda39efe..fbfd906a 100644 --- a/src/Undefined/skills/toolsets/messages/get_recent_messages/handler.py +++ b/src/Undefined/skills/toolsets/messages/get_recent_messages/handler.py @@ -2,7 +2,15 @@ from typing import Any, Dict -_FILTERED_RESULT_LIMIT = 50 + +def _get_history_limit(context: Dict[str, Any], key: str, fallback: int) -> int: + """从 runtime_config 读取历史限制配置。""" + cfg = context.get("runtime_config") + if cfg is not None: + val = getattr(cfg, key, None) + if isinstance(val, int) and val > 0: + return val + return fallback def _find_group_id_by_name(chat_id: str, history_manager: Any) -> str | None: @@ -105,6 +113,9 @@ def _format_message_xml(msg: dict[str, Any]) -> str: timestamp = msg.get("timestamp", "") text = msg.get("message", "") message_id = msg.get("message_id") + role = msg.get("role", "") + title = msg.get("title", "") + level = msg.get("level", "") location = _format_message_location(msg_type_val, chat_name) @@ -112,7 +123,16 @@ def _format_message_xml(msg: dict[str, Any]) -> str: if message_id is not None: msg_id_attr = f' message_id="{message_id}"' - return f""" + extra_attrs = "" + if msg_type_val == "group": + if role: + extra_attrs += f' role="{role}"' + if title: + extra_attrs += f' title="{title}"' + if level: + extra_attrs += f' level="{level}"' + + return f""" {text} """ @@ -218,8 +238,9 @@ async def execute(args: Dict[str, Any], context: Dict[str, Any]) -> str: # 获取消息:有过滤条件时扩大获取范围 messages = [] + search_scan_limit = _get_history_limit(context, "history_search_scan_limit", 10000) fetch_start = 0 if has_filter else start - fetch_end = 10000 if has_filter else end + fetch_end = search_scan_limit if has_filter else end if get_recent_messages_callback: messages = await get_recent_messages_callback( @@ -253,9 +274,11 @@ async def execute(args: Dict[str, Any], context: Dict[str, Any]) -> str: # 对过滤结果进行分页切片 messages = messages[start:end] - # 限制最多返回 50 条 - if len(messages) > _FILTERED_RESULT_LIMIT: - messages = messages[:_FILTERED_RESULT_LIMIT] + filtered_result_limit = _get_history_limit( + context, "history_filtered_result_limit", 200 + ) + if len(messages) > filtered_result_limit: + messages = messages[:filtered_result_limit] # 格式化消息 formatted = [_format_message_xml(msg) for msg in messages] diff --git a/src/Undefined/skills/toolsets/messages/send_message/handler.py b/src/Undefined/skills/toolsets/messages/send_message/handler.py index 9d21aab9..f5e8b137 100644 --- a/src/Undefined/skills/toolsets/messages/send_message/handler.py +++ b/src/Undefined/skills/toolsets/messages/send_message/handler.py @@ -2,6 +2,7 @@ import logging from Undefined.attachments import ( + dispatch_pending_file_sends, render_message_with_pic_placeholders, scope_from_context, ) @@ -168,6 +169,12 @@ async def execute(args: Dict[str, Any], context: Dict[str, Any]) -> str: **send_kwargs, ) context["message_sent_this_turn"] = True + await dispatch_pending_file_sends( + rendered, + sender=sender, + target_type=target_type, + target_id=target_id, + ) return _format_send_success(sent_message_id) except Exception as e: logger.exception( diff --git a/src/Undefined/skills/toolsets/messages/send_private_message/handler.py b/src/Undefined/skills/toolsets/messages/send_private_message/handler.py index 727cb2d3..f92cab87 100644 --- a/src/Undefined/skills/toolsets/messages/send_private_message/handler.py +++ b/src/Undefined/skills/toolsets/messages/send_private_message/handler.py @@ -2,6 +2,7 @@ import logging from Undefined.attachments import ( + dispatch_pending_file_sends, render_message_with_pic_placeholders, scope_from_context, ) @@ -115,6 +116,12 @@ async def execute(args: Dict[str, Any], context: Dict[str, Any]) -> str: **send_kwargs, ) context["message_sent_this_turn"] = True + await dispatch_pending_file_sends( + rendered, + sender=sender, + target_type="private", + target_id=user_id, + ) return _format_send_success(user_id, sent_message_id) except Exception as e: logger.exception( diff --git a/src/Undefined/skills/toolsets/render/render_latex/config.json b/src/Undefined/skills/toolsets/render/render_latex/config.json index 1d3c48ef..16b1f995 100644 --- a/src/Undefined/skills/toolsets/render/render_latex/config.json +++ b/src/Undefined/skills/toolsets/render/render_latex/config.json @@ -2,27 +2,18 @@ "type": "function", "function": { "name": "render_latex", - "description": "将 LaTeX 文本渲染为图片。默认返回可嵌入回复的图片 UID(embed),也可直接发送到指定目标(send)。支持完整的 LaTeX 语法(包含 \\begin 和 \\end)。", + "description": "将 LaTeX 数学公式渲染为图片或 PDF 文档,使用 MathJax(不依赖系统 TeX 安装)。支持 LaTeX 数学子集(amsmath、equation、align、matrix 等),但不支持自定义 TeX 包。返回可嵌入回复的附件 UID。", "parameters": { "type": "object", "properties": { "content": { "type": "string", - "description": "要渲染的 LaTeX 内容。必须是完整格式(包含 \\begin 和 \\end)。" + "description": "要渲染的 LaTeX 数学内容。支持 $...$(行内)、$$...$$(块级)、\\[...\\]、\\(...\\) 及标准数学环境(\\begin{align}、\\begin{equation}、\\begin{matrix} 等)。如果不包含分隔符,会自动用 \\[ ... \\] 包装。\\begin{document}...\\end{document} 外层包装会自动去掉。" }, - "delivery": { + "output_format": { "type": "string", - "description": "图片交付方式:embed 返回可插入回复的图片 UID;send 立即发送到目标", - "enum": ["embed", "send"] - }, - "target_id": { - "type": "integer", - "description": "目标 ID(群号或用户 QQ 号,仅 delivery=send 时需要,不提供则从当前会话推断)" - }, - "message_type": { - "type": "string", - "description": "消息类型(仅 delivery=send 时需要,不提供则从当前会话推断)", - "enum": ["group", "private"] + "description": "输出格式:\"png\"(默认,图片)或 \"pdf\"(PDF 文档)", + "enum": ["png", "pdf"] } }, "required": ["content"] diff --git a/src/Undefined/skills/toolsets/render/render_latex/handler.py b/src/Undefined/skills/toolsets/render/render_latex/handler.py index 26774d64..5da83f9b 100644 --- a/src/Undefined/skills/toolsets/render/render_latex/handler.py +++ b/src/Undefined/skills/toolsets/render/render_latex/handler.py @@ -1,138 +1,241 @@ from __future__ import annotations -from typing import Any, Dict import logging -import uuid -import matplotlib.pyplot as plt -import matplotlib +import re +from typing import Any, Dict from Undefined.attachments import scope_from_context logger = logging.getLogger(__name__) +_DOCUMENT_PATTERN = re.compile( + r"^\s*\\begin\{document\}(?P.*?)\\end\{document\}\s*$", + re.DOTALL, +) + +# MathJax 数学分隔符模式 +_MATH_DELIMITER_PATTERN = re.compile( + r"(\$\$|\\\[|\\\(|\\begin\{)", + re.MULTILINE, +) + + +def _strip_document_wrappers(content: str) -> str: + """去掉 \\begin{document}...\\end{document} 外层包装。""" + text = content.strip() + match = _DOCUMENT_PATTERN.fullmatch(text) + if match is None: + return text + return match.group("body").strip() + + +def _has_math_delimiters(content: str) -> bool: + """检查内容是否已包含数学分隔符。""" + return bool(_MATH_DELIMITER_PATTERN.search(content)) + + +def _prepare_content(raw_content: str) -> str: + """ + 准备 LaTeX 内容: + 1. 去掉 document 包装 + 2. 处理字面量 \\n(LLM 输出常见问题) + 3. 如果没有数学分隔符,自动用 \\[ ... \\] 包装 + """ + content = _strip_document_wrappers(raw_content) + # 替换字面量 \\n 为真实换行符,但保留 LaTeX 命令如 \nu \nabla \neq 等 + content = re.sub(r"\\n(?![a-zA-Z])", "\n", content) + + if not _has_math_delimiters(content): + # 没有分隔符,自动包装为块级数学环境 + content = f"\\[\n{content}\n\\]" + + return content + + +def _build_html(latex_content: str) -> str: + """构建包含 MathJax 的 HTML 页面。""" + # HTML 转义(防止内容中的 < > & 破坏结构) + import html + + escaped_content = html.escape(latex_content) + + return f""" + + + + + + + + +
+{escaped_content} +
+ +""" + + +async def _render_latex_to_bytes( + content: str, output_format: str, proxy: str | None = None +) -> tuple[bytes, str]: + """ + 使用 MathJax + Playwright 渲染 LaTeX 内容。 + + 返回: (渲染后的字节流, MIME 类型) + """ + try: + from playwright.async_api import ( + async_playwright, + TimeoutError as PwTimeoutError, + ) + except ImportError: + raise ImportError( + "请运行 `uv run playwright install` 安装浏览器运行时" + ) from None + + html_content = _build_html(content) + + launch_kwargs: dict[str, object] = {"headless": True} + if proxy: + launch_kwargs["proxy"] = {"server": proxy} + logger.info("LaTeX 渲染使用代理: %s", proxy) + + async with async_playwright() as p: + browser = await p.chromium.launch(**launch_kwargs) # type: ignore[arg-type] + try: + page = await browser.new_page() + await page.set_content(html_content) + + # 等待 MathJax 完成排版(pageReady 回调设置 window._mjReady) + try: + await page.wait_for_function( + "() => window._mjReady === true", + timeout=30000, + ) + except PwTimeoutError: + logger.warning("MathJax 排版超时,内容可能过于复杂或网络不可达") + raise RuntimeError( + "LaTeX 内容可能过于复杂或网络不可达(MathJax 加载超时)" + ) from None + + if output_format == "pdf": + # 获取容器尺寸 + container = await page.query_selector("#math-container") + if container is None: + raise RuntimeError("无法定位数学容器元素") + + bbox = await container.bounding_box() + if bbox is None: + raise RuntimeError("无法获取数学容器的边界框") + + # PDF 输出,设置合适的页面尺寸 + pdf_bytes = await page.pdf( + width=f"{bbox['width'] + 40}px", + height=f"{bbox['height'] + 40}px", + print_background=True, + ) + return pdf_bytes, "application/pdf" + else: + # PNG 输出 + container = await page.query_selector("#math-container") + if container is None: + raise RuntimeError("无法定位数学容器元素") + + screenshot_bytes = await container.screenshot(type="png") + return screenshot_bytes, "image/png" + + finally: + await browser.close() -def _resolve_send_target( - target_id: Any, - message_type: Any, - context: Dict[str, Any], -) -> tuple[int | str | None, str | None, str | None]: - """从参数或 context 推断发送目标。""" - if target_id is not None and message_type is not None: - return target_id, message_type, None - request_type = str(context.get("request_type", "") or "").strip().lower() - if request_type == "group": - gid = context.get("group_id") - if gid is not None: - return gid, "group", None - if request_type == "private": - uid = context.get("user_id") - if uid is not None: - return uid, "private", None - return None, None, "渲染成功,但缺少发送目标参数" + +async def _resolve_proxy(context: Dict[str, Any]) -> str | None: + """从 context 的 runtime_config 中解析代理地址。""" + from Undefined.config import get_config + + runtime_config = context.get("runtime_config") or get_config(strict=False) + if runtime_config is None: + return None + use_proxy: bool = getattr(runtime_config, "use_proxy", False) + if not use_proxy: + return None + proxy: str = getattr(runtime_config, "http_proxy", "") or getattr( + runtime_config, "https_proxy", "" + ) + return proxy or None async def execute(args: Dict[str, Any], context: Dict[str, Any]) -> str: - """渲染 LaTeX 数学公式为图片""" - content = args.get("content", "") - delivery = str(args.get("delivery", "embed") or "embed").strip().lower() - target_id = args.get("target_id") - message_type = args.get("message_type") + """渲染 LaTeX 数学公式为图片或 PDF""" + raw_content = str(args.get("content", "") or "") + output_format = str(args.get("output_format", "png") or "png").strip().lower() - if not content: - return "内容不能为空" - if delivery not in {"embed", "send"}: - return f"delivery 无效:{delivery}。仅支持 embed 或 send" + # 参数校验 + if not raw_content or not raw_content.strip(): + return "LaTeX 内容不能为空" - if delivery == "send" and message_type and message_type not in ("group", "private"): - return "消息类型必须是 group 或 private" + if output_format not in {"png", "pdf"}: + return f"output_format 无效:{output_format}。仅支持 png 或 pdf" try: - from Undefined.utils.cache import cleanup_cache_dir - from Undefined.utils.paths import RENDER_CACHE_DIR, ensure_dir - - filename = f"render_{uuid.uuid4().hex[:16]}.png" - filepath = ensure_dir(RENDER_CACHE_DIR) / filename - - matplotlib.use("Agg") - - fig, ax = plt.subplots(figsize=(10, 6)) - ax.axis("off") - - ax.text( - 0.5, - 0.5, - content, - transform=ax.transAxes, - fontsize=12, - verticalalignment="center", - horizontalalignment="center", - usetex=True, - wrap=True, - ) + # 准备内容 + prepared_content = _prepare_content(raw_content) - plt.tight_layout() - plt.savefig(filepath, dpi=150, bbox_inches="tight", pad_inches=0.1) - plt.close(fig) + # 解析代理 + proxy = await _resolve_proxy(context) + + # 渲染 + rendered_bytes, mime_type = await _render_latex_to_bytes( + prepared_content, output_format, proxy=proxy + ) # 注册到附件系统 attachment_registry = context.get("attachment_registry") scope_key = scope_from_context(context) - record: Any = None - if attachment_registry is not None and scope_key: - try: - record = await attachment_registry.register_local_file( - scope_key, - filepath, - kind="image", - display_name=filename, - source_kind="rendered_image", - source_ref="render_latex", - ) - except Exception as exc: - logger.warning("注册渲染图片到附件系统失败: %s", exc) - - if delivery == "embed": - cleanup_cache_dir(RENDER_CACHE_DIR) - if record is None: - return "渲染成功,但无法注册到附件系统(缺少 attachment_registry 或 scope_key)" - return f'' - - # delivery == "send" - resolved_target_id, resolved_message_type, target_error = _resolve_send_target( - target_id, message_type, context - ) - if target_error or resolved_target_id is None or resolved_message_type is None: - return target_error or "渲染成功,但缺少发送目标参数" - - sender = context.get("sender") - send_image_callback = context.get("send_image_callback") - - if sender: - from pathlib import Path - - cq_message = f"[CQ:image,file={Path(filepath).resolve().as_uri()}]" - if resolved_message_type == "group": - await sender.send_group_message(int(resolved_target_id), cq_message) - elif resolved_message_type == "private": - await sender.send_private_message(int(resolved_target_id), cq_message) - cleanup_cache_dir(RENDER_CACHE_DIR) - return ( - f"LaTeX 图片已渲染并发送到 {resolved_message_type} {resolved_target_id}" - ) - elif send_image_callback: - await send_image_callback( - resolved_target_id, resolved_message_type, str(filepath) - ) - cleanup_cache_dir(RENDER_CACHE_DIR) - return ( - f"LaTeX 图片已渲染并发送到 {resolved_message_type} {resolved_target_id}" + + if attachment_registry is None or not scope_key: + return "渲染成功,但无法注册到附件系统(缺少 attachment_registry 或 scope_key)" + + kind = "image" if output_format == "png" else "file" + extension = "png" if output_format == "png" else "pdf" + display_name = f"latex.{extension}" + + try: + record = await attachment_registry.register_bytes( + scope_key, + rendered_bytes, + kind=kind, + display_name=display_name, + mime_type=mime_type, + source_kind="rendered_latex", + source_ref="render_latex", ) - else: - return "发送图片回调未设置" + tag = "pic" if output_format == "png" else "attachment" + return f'<{tag} uid="{record.uid}"/>' + + except Exception as exc: + logger.exception("注册渲染结果到附件系统失败: %s", exc) + return f"渲染成功,但注册到附件系统失败: {exc}" except ImportError as e: - missing_pkg = str(e).split("'")[1] if "'" in str(e) else "未知包" - return f"渲染失败:缺少依赖包 {missing_pkg},请运行: uv add {missing_pkg}" + logger.error("Playwright 导入失败: %s", e) + return "请运行 `uv run playwright install` 安装浏览器运行时" + except RuntimeError as e: + logger.error("LaTeX 渲染运行时错误: %s", e) + return str(e) except Exception as e: - logger.exception(f"渲染并发送 LaTeX 图片失败: {e}") - return "渲染失败,请稍后重试" + logger.exception("渲染 LaTeX 失败: %s", e) + return f"渲染失败:{e}" diff --git a/src/Undefined/utils/coerce.py b/src/Undefined/utils/coerce.py new file mode 100644 index 00000000..20c81eff --- /dev/null +++ b/src/Undefined/utils/coerce.py @@ -0,0 +1,37 @@ +"""Type-safe coercion helpers shared across the codebase.""" + +from __future__ import annotations + +from typing import Any, overload + + +@overload +def safe_int(value: Any) -> int | None: ... + + +@overload +def safe_int(value: Any, default: int) -> int: ... + + +@overload +def safe_int(value: Any, default: None) -> int | None: ... + + +def safe_int(value: Any, default: int | None = None) -> int | None: + """Safely convert *value* to int, returning *default* on failure.""" + if value is None: + return default + try: + return int(value) + except (TypeError, ValueError): + return default + + +def safe_float(value: Any, default: float = 0.0) -> float: + """Safely convert *value* to float, returning *default* on failure.""" + if value is None: + return default + try: + return float(value) + except (TypeError, ValueError): + return default diff --git a/src/Undefined/utils/fake_at.py b/src/Undefined/utils/fake_at.py new file mode 100644 index 00000000..c3df1ae2 --- /dev/null +++ b/src/Undefined/utils/fake_at.py @@ -0,0 +1,180 @@ +"""假@检测:识别群聊中 ``@昵称`` 纯文本形式的"假 at"。 + +设计要点: +- ``BotNicknameCache`` 自动通过 OneBot API 获取 bot 在各群的群名片 / QQ 昵称, + 带 TTL 缓存 + per-group asyncio.Lock 防竞态。 +- ``strip_fake_at`` 是无状态纯函数,负责文本匹配与剥离。 +- 匹配规则:半角 ``@`` / 全角 ``@`` + 昵称 (casefold) + 边界(空白/标点/行尾), + 昵称按长度降序匹配以避免短昵称吃掉长昵称的前缀。 +""" + +from __future__ import annotations + +import asyncio +import logging +import re +import time +import unicodedata +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from Undefined.onebot import OneBotClient + +logger = logging.getLogger(__name__) + +# 昵称后必须跟的边界:空白、常见标点、行尾 +_BOUNDARY_RE = re.compile(r"[\s\u3000,.:;!?,。:;!?/\-\]))>》]+|$") + +# 缓存默认 TTL (秒) +_DEFAULT_TTL: float = 600.0 + + +def _normalize(text: str) -> str: + """将全角 @ 统一为半角 @,然后 casefold。""" + return unicodedata.normalize("NFKC", text).casefold() + + +def _sorted_nicknames(names: frozenset[str]) -> tuple[str, ...]: + """按长度降序排列,长昵称优先匹配。""" + return tuple(sorted(names, key=len, reverse=True)) + + +def strip_fake_at( + text: str, + nicknames: frozenset[str], +) -> tuple[bool, str]: + """检测并剥离文本开头的假 @ 前缀。 + + 参数: + text: 原始消息文本(已做 extract_text 处理)。 + nicknames: 当前群 bot 所有有效昵称 (已 casefold)。 + + 返回: + (is_fake_at, stripped_text) + - is_fake_at: 是否命中假 @。 + - stripped_text: 剥离假 @ 后的文本(未命中时返回原文)。 + """ + if not nicknames or not text: + return False, text + + normalized = _normalize(text) + # 以 @ 开头才可能是假 @ + if not normalized.startswith("@"): + return False, text + + # 去掉开头的 @ + after_at = normalized[1:] + raw_after_at = text[1:] # 保持原始大小写的切片 + + for nick in _sorted_nicknames(nicknames): + if not after_at.startswith(nick): + continue + # 检查昵称后是否为合法边界 + rest_pos = len(nick) + rest_normalized = after_at[rest_pos:] + if rest_normalized and not _BOUNDARY_RE.match(rest_normalized): + continue + # 命中——用原始文本切出剥离后的内容 + stripped = raw_after_at[rest_pos:].lstrip() + return True, stripped + + return False, text + + +class BotNicknameCache: + """按群缓存 bot 昵称,用于假 @ 检测。 + + 线程安全性: + - 只在单一 asyncio 事件循环中使用。 + - 每个 group_id 使用独立 ``asyncio.Lock``,保证同一群的并发 + 消息不会触发重复 API 请求。 + - ``_global_lock`` 仅保护 ``_locks`` 字典本身的创建。 + """ + + def __init__( + self, + onebot: OneBotClient, + bot_qq: int, + ttl: float = _DEFAULT_TTL, + ) -> None: + self._onebot = onebot + self._bot_qq = bot_qq + self._ttl = ttl + # group_id → (nicknames_frozenset, timestamp) + self._cache: dict[int, tuple[frozenset[str], float]] = {} + # group_id → asyncio.Lock + self._locks: dict[int, asyncio.Lock] = {} + self._global_lock = asyncio.Lock() + + def _get_lock(self, group_id: int) -> asyncio.Lock: + """获取指定群的锁(同步快路径 + 异步慢路径)。 + + 在 asyncio 单线程模型下,dict 读写本身是原子的, + 但为严谨起见仍使用 ``_global_lock`` 保护 ``_locks`` 的创建。 + 注意:此方法不是 coroutine,但会在 _ensure_lock 里 await。 + """ + lock = self._locks.get(group_id) + if lock is not None: + return lock + # 需要创建——由调用方在 _ensure_lock 中持有 _global_lock + lock = asyncio.Lock() + self._locks[group_id] = lock + return lock + + async def _ensure_lock(self, group_id: int) -> asyncio.Lock: + """确保 group lock 存在,如有必要在 global lock 下创建。""" + lock = self._locks.get(group_id) + if lock is not None: + return lock + async with self._global_lock: + return self._get_lock(group_id) + + async def get_nicknames(self, group_id: int) -> frozenset[str]: + """获取 bot 在指定群的所有有效昵称(含手动配置)。 + + 会自动缓存,过期后异步刷新。API 失败时返回上次缓存或仅手动昵称。 + """ + now = time.monotonic() + cached = self._cache.get(group_id) + if cached is not None: + names, ts = cached + if now - ts < self._ttl: + return names + + lock = await self._ensure_lock(group_id) + async with lock: + # Double-check:可能在等锁期间已被其他协程刷新 + cached = self._cache.get(group_id) + if cached is not None: + names, ts = cached + if now - ts < self._ttl: + return names + + fetched = await self._fetch(group_id) + self._cache[group_id] = (fetched, time.monotonic()) + return fetched + + async def _fetch(self, group_id: int) -> frozenset[str]: + """从 OneBot API 获取 bot 在指定群的群名片 + QQ 昵称。""" + names: set[str] = set() + try: + info = await self._onebot.get_group_member_info(group_id, self._bot_qq) + if isinstance(info, dict): + for key in ("card", "nickname"): + val = str(info.get(key, "") or "").strip() + if val: + names.add(_normalize(val)) + except Exception as exc: + logger.debug( + "[假@] 获取 bot 群成员信息失败: group=%s err=%s", + group_id, + exc, + ) + return frozenset(names) + + def invalidate(self, group_id: int | None = None) -> None: + """手动失效缓存。group_id=None 清空全部。""" + if group_id is None: + self._cache.clear() + else: + self._cache.pop(group_id, None) diff --git a/src/Undefined/utils/history.py b/src/Undefined/utils/history.py index 46dfafee..578711eb 100644 --- a/src/Undefined/utils/history.py +++ b/src/Undefined/utils/history.py @@ -81,17 +81,20 @@ def _get_private_history_path(self, user_id: int) -> str: async def _save_history_to_file( self, history: list[dict[str, Any]], path: str ) -> None: - """异步保存历史记录到文件(最多 10000 条)""" + """异步保存历史记录到文件""" from Undefined.utils import io try: - # 只保留最近的 self._max_records 条 - truncated_history = ( - history[-self._max_records :] - if len(history) > self._max_records - else history - ) - truncated = len(history) > self._max_records + if self._max_records > 0: + truncated_history = ( + history[-self._max_records :] + if len(history) > self._max_records + else history + ) + truncated = len(history) > self._max_records + else: + truncated_history = history + truncated = False logger.debug( f"[历史记录] 准备保存: path={path}, total={len(history)}, truncated={truncated}" @@ -210,12 +213,13 @@ async def _load_history_from_file(self, path: str) -> list[dict[str, Any]]: normalized_history.append(msg) - # 只保留最近的 self._max_records 条 - return ( - normalized_history[-self._max_records :] - if len(normalized_history) > self._max_records - else normalized_history - ) + if self._max_records > 0: + return ( + normalized_history[-self._max_records :] + if len(normalized_history) > self._max_records + else normalized_history + ) + return normalized_history except Exception as e: logger.error(f"加载历史记录失败 {path}: {e}") @@ -306,6 +310,7 @@ async def add_group_message( group_name: str = "", role: str = "member", title: str = "", + level: str = "", message_id: int | None = None, attachments: list[dict[str, str]] | None = None, ) -> None: @@ -334,6 +339,7 @@ async def add_group_message( "display_name": display_name, "role": role, "title": title, + "level": level, "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"), "message": text_content, } @@ -344,7 +350,10 @@ async def add_group_message( self._message_history[group_id_str].append(record) - if len(self._message_history[group_id_str]) > self._max_records: + if ( + self._max_records > 0 + and len(self._message_history[group_id_str]) > self._max_records + ): self._message_history[group_id_str] = self._message_history[ group_id_str ][-self._max_records :] @@ -393,7 +402,10 @@ async def add_private_message( self._private_message_history[user_id_str].append(record) - if len(self._private_message_history[user_id_str]) > self._max_records: + if ( + self._max_records > 0 + and len(self._private_message_history[user_id_str]) > self._max_records + ): self._private_message_history[user_id_str] = ( self._private_message_history[user_id_str][-self._max_records :] ) diff --git a/src/Undefined/utils/queue_intervals.py b/src/Undefined/utils/queue_intervals.py index 094dceba..0596d7d5 100644 --- a/src/Undefined/utils/queue_intervals.py +++ b/src/Undefined/utils/queue_intervals.py @@ -6,6 +6,7 @@ def build_model_queue_intervals(config: Config) -> dict[str, float]: + summary_model = getattr(config, "summary_model", config.agent_model) pairs: Iterable[tuple[str, float]] = ( (config.chat_model.model_name, config.chat_model.queue_interval_seconds), (config.agent_model.model_name, config.agent_model.queue_interval_seconds), @@ -18,7 +19,15 @@ def build_model_queue_intervals(config: Config) -> dict[str, float]: config.naga_model.model_name, config.naga_model.queue_interval_seconds, ), + ( + summary_model.model_name, + summary_model.queue_interval_seconds, + ), (config.grok_model.model_name, config.grok_model.queue_interval_seconds), + ( + config.historian_model.model_name, + config.historian_model.queue_interval_seconds, + ), ) intervals: dict[str, float] = {} for model_name, interval in pairs: diff --git a/src/Undefined/utils/recent_messages.py b/src/Undefined/utils/recent_messages.py index 84953ae2..10adca85 100644 --- a/src/Undefined/utils/recent_messages.py +++ b/src/Undefined/utils/recent_messages.py @@ -247,9 +247,14 @@ async def get_recent_messages_prefer_local( bot_qq: int, attachment_registry: Any | None = None, group_name_hint: str | None = None, - max_onebot_count: int = 5000, + max_onebot_count: int | None = None, ) -> list[dict[str, Any]]: """优先从本地 history 获取最近消息,必要时回退到 OneBot。""" + if max_onebot_count is None: + from Undefined.config import get_config + + cfg = get_config(strict=False) + max_onebot_count = getattr(cfg, "history_onebot_fetch_limit", 10000) norm_start, norm_end = _normalize_range(start, end) if norm_end <= 0: return [] @@ -311,7 +316,7 @@ async def get_recent_messages_prefer_onebot( bot_qq: int, attachment_registry: Any | None = None, group_name_hint: str | None = None, - max_onebot_count: int = 5000, + max_onebot_count: int | None = None, ) -> list[dict[str, Any]]: """兼容旧名称,当前行为等同于本地优先。""" return await get_recent_messages_prefer_local( diff --git a/src/Undefined/utils/xml.py b/src/Undefined/utils/xml.py index 865ac249..16d24153 100644 --- a/src/Undefined/utils/xml.py +++ b/src/Undefined/utils/xml.py @@ -1,7 +1,9 @@ -"""Minimal XML escaping helpers.""" +"""Minimal XML escaping helpers and message formatting.""" from __future__ import annotations +from typing import Any, Callable, Sequence, Mapping + from xml.sax.saxutils import escape @@ -12,3 +14,80 @@ def escape_xml_text(value: str) -> str: def escape_xml_attr(value: object) -> str: text = "" if value is None else str(value) return escape(text, {'"': """, "'": "'"}) + + +def _message_location(msg_type: str, chat_name: str) -> str: + """Derive the human-readable location label from message type.""" + if msg_type == "group": + return chat_name if chat_name.endswith("群") else f"{chat_name}群" + return "私聊" + + +def format_message_xml( + msg: dict[str, Any], + *, + attachment_formatter: (Callable[[Sequence[Mapping[str, str]]], str] | None) = None, +) -> str: + """Format a single history record dict into main-AI-compatible XML. + + ``attachment_formatter`` is an optional callable that turns the attachments + list into an XML fragment. When *None* (the default) a lazy import of + :func:`Undefined.attachments.attachment_refs_to_xml` is used so that + lightweight callers do not pay the import cost. + """ + msg_type_val = str(msg.get("type", "group") or "group") + sender_name = str(msg.get("display_name", "未知用户") or "未知用户") + sender_id = str(msg.get("user_id", "") or "") + chat_id = str(msg.get("chat_id", "") or "") + chat_name = str(msg.get("chat_name", "未知群聊") or "未知群聊") + timestamp = str(msg.get("timestamp", "") or "") + text = str(msg.get("message", "") or "") + message_id = msg.get("message_id") + role = str(msg.get("role", "member") or "member") + title = str(msg.get("title", "") or "") + level = str(msg.get("level", "") or "") + attachments = msg.get("attachments", []) + + safe_sender = escape_xml_attr(sender_name) + safe_sender_id = escape_xml_attr(sender_id) + safe_chat_id = escape_xml_attr(chat_id) + safe_chat_name = escape_xml_attr(chat_name) + safe_role = escape_xml_attr(role) + safe_title = escape_xml_attr(title) + safe_time = escape_xml_attr(timestamp) + safe_text = escape_xml_text(text) + safe_location = escape_xml_attr(_message_location(msg_type_val, chat_name)) + + msg_id_attr = "" + if message_id is not None: + msg_id_attr = f' message_id="{escape_xml_attr(str(message_id))}"' + + attachment_xml = "" + if isinstance(attachments, list) and attachments: + if attachment_formatter is None: + from Undefined.attachments import attachment_refs_to_xml + + attachment_formatter = attachment_refs_to_xml + attachment_xml = f"\n{attachment_formatter(attachments)}" + + if msg_type_val == "group": + level_attr = f' level="{escape_xml_attr(level)}"' if level else "" + return ( + f'\n' + f"{safe_text}{attachment_xml}\n" + f"" + ) + + return ( + f'\n' + f"{safe_text}{attachment_xml}\n" + f"" + ) + + +def format_messages_xml(messages: list[dict[str, Any]]) -> str: + """Format a list of history records into ``\\n---\\n``-separated XML.""" + return "\n---\n".join(format_message_xml(msg) for msg in messages) diff --git a/src/Undefined/webui/routes/_config.py b/src/Undefined/webui/routes/_config.py index 0afeaa01..e3f44e43 100644 --- a/src/Undefined/webui/routes/_config.py +++ b/src/Undefined/webui/routes/_config.py @@ -1,5 +1,7 @@ import asyncio +import shutil import tomllib +from datetime import datetime, timezone from pathlib import Path from tempfile import NamedTemporaryFile @@ -23,6 +25,27 @@ sync_config_file, ) +_BACKUP_DIR = Path("data") / "config_backups" +_MAX_BACKUPS = 50 + + +def _backup_config() -> str | None: + """Create a timestamped backup of the current config.toml. + + Returns the backup filename, or *None* if the source file does not exist. + """ + if not CONFIG_PATH.exists(): + return None + _BACKUP_DIR.mkdir(parents=True, exist_ok=True) + ts = datetime.now(timezone.utc).strftime("%Y%m%d_%H%M%S") + backup_name = f"config_{ts}.toml" + shutil.copy2(CONFIG_PATH, _BACKUP_DIR / backup_name) + # Trim old backups beyond the cap + backups = sorted(_BACKUP_DIR.glob("config_*.toml")) + while len(backups) > _MAX_BACKUPS: + backups.pop(0).unlink(missing_ok=True) + return backup_name + @routes.get("/api/v1/management/config") @routes.get("/api/config") @@ -45,6 +68,7 @@ async def save_config_handler(request: web.Request) -> Response: if not valid: return web.json_response({"success": False, "error": msg}, status=400) try: + _backup_config() CONFIG_PATH.write_text(content, encoding="utf-8") get_config_manager().reload() logic_valid, logic_msg = validate_required_config() @@ -133,6 +157,7 @@ async def config_patch_handler(request: web.Request) -> Response: data = {} patched = apply_patch(data, patch) + _backup_config() CONFIG_PATH.write_text( render_toml(patched, comments=load_comment_map()), encoding="utf-8" ) @@ -181,3 +206,51 @@ async def sync_config_template_handler(request: web.Request) -> Response: return web.json_response({"success": False, "error": str(exc)}, status=400) except Exception as exc: return web.json_response({"success": False, "error": str(exc)}, status=500) + + +# --------------------------------------------------------------------------- +# Config version history +# --------------------------------------------------------------------------- + + +@routes.get("/api/v1/management/config/history") +@routes.get("/api/config/history") +async def config_history_handler(request: web.Request) -> Response: + """Return the list of config backups, newest first.""" + if not check_auth(request): + return web.json_response({"error": "Unauthorized"}, status=401) + if not _BACKUP_DIR.exists(): + return web.json_response({"backups": []}) + backups: list[dict[str, object]] = [] + for f in sorted(_BACKUP_DIR.glob("config_*.toml"), reverse=True): + stat = f.stat() + backups.append({"name": f.name, "size": stat.st_size, "mtime": stat.st_mtime}) + return web.json_response({"backups": backups}) + + +@routes.post("/api/v1/management/config/history/restore") +@routes.post("/api/config/history/restore") +async def config_restore_handler(request: web.Request) -> Response: + """Restore a config backup by name (auto-backs-up current config first).""" + if not check_auth(request): + return web.json_response({"error": "Unauthorized"}, status=401) + try: + data = await request.json() + except Exception: + return web.json_response({"error": "Invalid JSON"}, status=400) + name = str(data.get("name", "")) + if not name or ".." in name or "/" in name or "\\" in name: + return web.json_response({"error": "Invalid backup name"}, status=400) + backup_path = (_BACKUP_DIR / name).resolve() + if not str(backup_path).startswith(str(_BACKUP_DIR.resolve())): + return web.json_response({"error": "Invalid backup name"}, status=400) + if not backup_path.exists(): + return web.json_response({"error": "Backup not found"}, status=404) + content = backup_path.read_text(encoding="utf-8") + valid, msg = validate_toml(content) + if not valid: + return web.json_response({"error": f"Backup TOML invalid: {msg}"}, status=400) + _backup_config() + CONFIG_PATH.write_text(content, encoding="utf-8") + get_config_manager().reload() + return web.json_response({"success": True, "message": "Restored"}) diff --git a/src/Undefined/webui/routes/_runtime.py b/src/Undefined/webui/routes/_runtime.py index c6de3ebb..371c46b9 100644 --- a/src/Undefined/webui/routes/_runtime.py +++ b/src/Undefined/webui/routes/_runtime.py @@ -296,6 +296,51 @@ async def runtime_memory_handler(request: web.Request) -> Response: ) +@routes.post("/api/v1/management/runtime/memory") +@routes.post("/api/runtime/memory") +async def runtime_memory_create_handler(request: web.Request) -> Response: + if not check_auth(request): + return _unauthorized() + try: + payload = await request.json() + except (json.JSONDecodeError, UnicodeDecodeError, ValueError): + return web.json_response({"error": "Invalid JSON payload"}, status=400) + return await _proxy_runtime( + method="POST", + path="/api/v1/memory", + payload=payload, + ) + + +@routes.patch("/api/v1/management/runtime/memory/{uuid}") +@routes.patch("/api/runtime/memory/{uuid}") +async def runtime_memory_update_handler(request: web.Request) -> Response: + if not check_auth(request): + return _unauthorized() + target_uuid = _url_quote(str(request.match_info.get("uuid", "")).strip(), safe="") + try: + payload = await request.json() + except (json.JSONDecodeError, UnicodeDecodeError, ValueError): + return web.json_response({"error": "Invalid JSON payload"}, status=400) + return await _proxy_runtime( + method="PATCH", + path=f"/api/v1/memory/{target_uuid}", + payload=payload, + ) + + +@routes.delete("/api/v1/management/runtime/memory/{uuid}") +@routes.delete("/api/runtime/memory/{uuid}") +async def runtime_memory_delete_handler(request: web.Request) -> Response: + if not check_auth(request): + return _unauthorized() + target_uuid = _url_quote(str(request.match_info.get("uuid", "")).strip(), safe="") + return await _proxy_runtime( + method="DELETE", + path=f"/api/v1/memory/{target_uuid}", + ) + + @routes.get("/api/v1/management/runtime/cognitive/events") @routes.get("/api/runtime/cognitive/events") async def runtime_cognitive_events_handler(request: web.Request) -> Response: diff --git a/src/Undefined/webui/static/css/components.css b/src/Undefined/webui/static/css/components.css index 7118ef69..bd91f6d2 100644 --- a/src/Undefined/webui/static/css/components.css +++ b/src/Undefined/webui/static/css/components.css @@ -428,6 +428,61 @@ white-space: pre-wrap; word-break: break-word; } +/* Memory CRUD */ +.memory-create-form { + display: flex; + gap: 8px; + align-items: flex-start; + padding: 8px 0; +} +.memory-create-form textarea { + flex: 1; + resize: vertical; + min-height: 40px; +} +.memory-create-actions { + flex-shrink: 0; + align-self: flex-end; +} +.memory-item-actions { + display: flex; + gap: 4px; + flex-shrink: 0; + margin-left: 8px; +} +.memory-item-actions button { + background: none; + border: 1px solid var(--border-color); + border-radius: var(--radius-sm); + cursor: pointer; + padding: 2px 6px; + font-size: 12px; + color: var(--text-tertiary); + transition: color 0.15s, border-color 0.15s; +} +.memory-item-actions button:hover { + color: var(--accent-color); + border-color: var(--accent-color); +} +.memory-item-actions button.memory-btn-delete:hover { + color: var(--error); + border-color: var(--error); +} +.memory-edit-area { + width: 100%; + min-height: 60px; + font-size: 14px; + line-height: 1.6; + font-family: inherit; + resize: vertical; + margin-top: 4px; +} +.memory-edit-actions { + display: flex; + gap: 6px; + margin-top: 6px; + justify-content: flex-end; +} .runtime-doc { font-size: 13px; line-height: 1.6; @@ -724,3 +779,31 @@ white-space: nowrap; border: 0; } + +/* Skeleton loading */ +@keyframes shimmer { + 0% { background-position: -400px 0; } + 100% { background-position: 400px 0; } +} +.skeleton { + background: linear-gradient(90deg, var(--bg-card) 25%, var(--bg-app) 50%, var(--bg-card) 75%); + background-size: 800px 100%; + animation: shimmer 1.5s infinite linear; + border-radius: var(--radius-sm); +} +.skeleton-text { height: 14px; margin-bottom: 10px; } +.skeleton-text.short { width: 60%; } +.skeleton-text.medium { width: 80%; } +.skeleton-block { height: 48px; margin-bottom: 12px; } +.skeleton-bar { height: 8px; border-radius: 999px; } + +/* Command Palette */ +.cmd-palette-overlay { position: fixed; inset: 0; background: rgba(0,0,0,0.5); z-index: 9998; display: flex; align-items: flex-start; justify-content: center; padding-top: 20vh; } +.cmd-palette { width: 480px; max-width: 90vw; background: var(--bg-card); border: 1px solid var(--border-color); border-radius: var(--radius-lg); box-shadow: var(--shadow-lg); overflow: hidden; } +.cmd-palette-input { width: 100%; padding: 14px 18px; border: none; border-bottom: 1px solid var(--border-color); background: transparent; color: var(--text-primary); font-size: 15px; outline: none; box-sizing: border-box; } +.cmd-palette-input::placeholder { color: var(--text-tertiary); } +.cmd-palette-list { max-height: 320px; overflow-y: auto; padding: 6px; } +.cmd-palette-item { padding: 10px 14px; border-radius: var(--radius-sm); cursor: pointer; display: flex; align-items: center; gap: 10px; color: var(--text-secondary); font-size: 14px; } +.cmd-palette-item:hover, .cmd-palette-item.active { background: var(--accent-subtle); color: var(--text-primary); } +.cmd-palette-item .cmd-key { font-size: 11px; color: var(--text-tertiary); margin-left: auto; font-family: var(--font-mono); } +.cmd-palette-empty { padding: 20px; text-align: center; color: var(--text-tertiary); font-size: 13px; } diff --git a/src/Undefined/webui/static/js/api.js b/src/Undefined/webui/static/js/api.js index 8ab88a3e..ed528e10 100644 --- a/src/Undefined/webui/static/js/api.js +++ b/src/Undefined/webui/static/js/api.js @@ -1,3 +1,31 @@ +// Active request controllers for cancellation on tab switch +const _activeControllers = new Map(); + +function abortPendingRequests(kind) { + if (kind) { + const controller = _activeControllers.get(kind); + if (controller) { + controller.abort(); + _activeControllers.delete(kind); + } + } else { + for (const controller of _activeControllers.values()) { + controller.abort(); + } + _activeControllers.clear(); + } +} + +function getAbortSignal(kind) { + if (kind) { + abortPendingRequests(kind); + const controller = new AbortController(); + _activeControllers.set(kind, controller); + return controller.signal; + } + return undefined; +} + const AUTH_ENDPOINTS = { login: [ "/api/v1/management/auth/login", @@ -31,7 +59,13 @@ function shouldRetryCandidate(res) { async function requestOnce(path, options = {}) { const headers = { ...(options.headers || {}) }; - if (options.method === "POST" && options.body && !headers["Content-Type"]) { + const needsJson = + options.body && + !headers["Content-Type"] && + ["POST", "PATCH", "PUT", "DELETE"].includes( + String(options.method || "").toUpperCase(), + ); + if (needsJson) { headers["Content-Type"] = "application/json"; } if (state.authAccessToken && !headers.Authorization) { @@ -41,6 +75,7 @@ async function requestOnce(path, options = {}) { ...options, headers, credentials: options.credentials || "same-origin", + signal: options.signal || undefined, }); } diff --git a/src/Undefined/webui/static/js/bot.js b/src/Undefined/webui/static/js/bot.js index 4d079724..28f7808a 100644 --- a/src/Undefined/webui/static/js/bot.js +++ b/src/Undefined/webui/static/js/bot.js @@ -1,3 +1,124 @@ +// Metrics history for time series chart +const METRICS_HISTORY_SIZE = 120; +const _metricsHistory = { cpu: [], memory: [], timestamps: [] }; + +function pushMetrics(cpuPercent, memPercent) { + const now = new Date(); + _metricsHistory.cpu.push(cpuPercent); + _metricsHistory.memory.push(memPercent); + _metricsHistory.timestamps.push(now); + if (_metricsHistory.cpu.length > METRICS_HISTORY_SIZE) { + _metricsHistory.cpu.shift(); + _metricsHistory.memory.shift(); + _metricsHistory.timestamps.shift(); + } +} + +function drawMetricsChart() { + const canvas = get("metricsChart"); + if (!canvas) return; + const ctx = canvas.getContext("2d"); + if (!ctx) return; + + const dpr = window.devicePixelRatio || 1; + const rect = canvas.getBoundingClientRect(); + canvas.width = rect.width * dpr; + canvas.height = rect.height * dpr; + ctx.scale(dpr, dpr); + + const w = rect.width; + const h = rect.height; + const pad = { top: 10, right: 12, bottom: 24, left: 36 }; + const plotW = w - pad.left - pad.right; + const plotH = h - pad.top - pad.bottom; + + ctx.clearRect(0, 0, w, h); + + const len = _metricsHistory.cpu.length; + if (len < 2) { + ctx.fillStyle = + getComputedStyle(document.documentElement) + .getPropertyValue("--text-tertiary") + .trim() || "#999"; + ctx.font = "12px sans-serif"; + ctx.textAlign = "center"; + ctx.fillText("Collecting data...", w / 2, h / 2); + return; + } + + const textColor = + getComputedStyle(document.documentElement) + .getPropertyValue("--text-tertiary") + .trim() || "#999"; + const gridColor = + getComputedStyle(document.documentElement) + .getPropertyValue("--border-color") + .trim() || "#333"; + const cpuColor = + getComputedStyle(document.documentElement) + .getPropertyValue("--accent-color") + .trim() || "#d97757"; + const memColor = + getComputedStyle(document.documentElement) + .getPropertyValue("--success") + .trim() || "#4a7c59"; + + // Y axis gridlines + ctx.strokeStyle = gridColor; + ctx.lineWidth = 0.5; + ctx.fillStyle = textColor; + ctx.font = "10px sans-serif"; + ctx.textAlign = "right"; + for (let pct = 0; pct <= 100; pct += 25) { + const y = pad.top + plotH - (pct / 100) * plotH; + ctx.beginPath(); + ctx.moveTo(pad.left, y); + ctx.lineTo(pad.left + plotW, y); + ctx.stroke(); + ctx.fillText(`${pct}%`, pad.left - 4, y + 3); + } + + // X axis time labels + ctx.textAlign = "center"; + const timestamps = _metricsHistory.timestamps; + const labelCount = Math.min(4, len); + for (let i = 0; i < labelCount; i++) { + const idx = Math.round((i / (labelCount - 1)) * (len - 1)); + const x = pad.left + (idx / (len - 1)) * plotW; + const t = timestamps[idx]; + const label = `${String(t.getMinutes()).padStart(2, "0")}:${String(t.getSeconds()).padStart(2, "0")}`; + ctx.fillText(label, x, h - 4); + } + + function drawLine(data, color) { + ctx.strokeStyle = color; + ctx.lineWidth = 1.5; + ctx.lineJoin = "round"; + ctx.beginPath(); + for (let i = 0; i < data.length; i++) { + const x = pad.left + (i / (len - 1)) * plotW; + const y = + pad.top + + plotH - + (Math.min(100, Math.max(0, data[i])) / 100) * plotH; + if (i === 0) ctx.moveTo(x, y); + else ctx.lineTo(x, y); + } + ctx.stroke(); + + ctx.globalAlpha = 0.08; + ctx.fillStyle = color; + ctx.lineTo(pad.left + plotW, pad.top + plotH); + ctx.lineTo(pad.left, pad.top + plotH); + ctx.closePath(); + ctx.fill(); + ctx.globalAlpha = 1; + } + + drawLine(_metricsHistory.cpu, cpuColor); + drawLine(_metricsHistory.memory, memColor); +} + async function fetchStatus() { if (!shouldFetch("status")) return; try { @@ -136,6 +257,8 @@ async function fetchSystemInfo() { get("systemMemoryBar").style.width = `${Math.min(100, Math.max(0, memUsage))}%`; recordFetchSuccess("system"); + pushMetrics(cpuUsage, memUsage); + drawMetricsChart(); } catch (e) { recordFetchError("system"); } diff --git a/src/Undefined/webui/static/js/config-form.js b/src/Undefined/webui/static/js/config-form.js index ffadfe0c..111d64ac 100644 --- a/src/Undefined/webui/static/js/config-form.js +++ b/src/Undefined/webui/static/js/config-form.js @@ -251,6 +251,7 @@ function isLongText(value) { const FIELD_SELECT_OPTIONS = { api_mode: ["chat_completions", "responses"], + gif_analysis_mode: ["grid", "multi"], reasoning_effort_style: ["openai", "anthropic"], // path -> options key mapping (underscore-separated segments) image_gen_provider: ["xingzhige", "models"], diff --git a/src/Undefined/webui/static/js/i18n.js b/src/Undefined/webui/static/js/i18n.js index c7b1ab53..1d60d0ba 100644 --- a/src/Undefined/webui/static/js/i18n.js +++ b/src/Undefined/webui/static/js/i18n.js @@ -36,6 +36,7 @@ const I18N = { "overview.refresh": "刷新", "overview.system": "系统信息", "overview.resources": "资源使用", + "overview.chart": "资源趋势", "overview.runtime": "运行环境", "overview.cpu_model": "CPU 型号", "overview.cpu_usage": "CPU 占用率", @@ -91,6 +92,8 @@ const I18N = { "config.clear_search": "清除搜索", "config.expand_all": "全部展开", "config.collapse_all": "全部折叠", + "config.view_toml": "查看 TOML", + "config.view_form": "表单视图", "config.expand_section": "展开", "config.collapse_section": "折叠", "config.loading": "正在加载配置...", @@ -105,6 +108,12 @@ const I18N = { "config.reload_error": "配置重载失败。", "config.bootstrap_created": "检测到缺少 config.toml,已从示例生成;请在此页完善配置并保存。", + "config.history": "版本历史", + "config.history_empty": "暂无备份", + "config.history_restore": "恢复", + "config.history_restore_confirm": + "确定恢复到此版本?当前配置将自动备份。", + "config.history_restored": "已恢复配置", "logs.title": "运行日志", "logs.subtitle": "实时查看日志尾部输出。", "logs.auto": "自动刷新", @@ -273,6 +282,17 @@ const I18N = { "update.not_eligible": "未满足更新条件(仅支持官方 origin/main)", "update.failed": "更新失败", "update.no_restart": "更新已完成但未重启(请检查 uv sync 输出)", + "cmd.placeholder": "输入命令...", + "cmd.empty": "没有匹配的命令", + "cmd.tab_overview": "跳转到 概览", + "cmd.tab_config": "跳转到 配置", + "cmd.tab_logs": "跳转到 日志", + "cmd.tab_probes": "跳转到 探针", + "cmd.tab_memory": "跳转到 备忘录", + "cmd.tab_memes": "跳转到 表情包", + "cmd.tab_cognitive": "跳转到 认知记忆", + "cmd.refresh": "刷新当前页面", + "cmd.logout": "退出登录", }, en: { "landing.title": "Undefined Console", @@ -312,6 +332,7 @@ const I18N = { "overview.refresh": "Refresh", "overview.system": "System", "overview.resources": "Resources", + "overview.chart": "Resource Trends", "overview.runtime": "Runtime", "overview.cpu_model": "CPU Model", "overview.cpu_usage": "CPU Usage", @@ -372,6 +393,8 @@ const I18N = { "config.clear_search": "Clear search", "config.expand_all": "Expand all", "config.collapse_all": "Collapse all", + "config.view_toml": "View TOML", + "config.view_form": "Form View", "config.expand_section": "Expand", "config.collapse_section": "Collapse", "config.loading": "Loading configuration...", @@ -386,6 +409,12 @@ const I18N = { "config.reload_error": "Failed to reload configuration.", "config.bootstrap_created": "config.toml was missing and has been generated from the example. Please review and save your configuration.", + "config.history": "Version History", + "config.history_empty": "No backups yet", + "config.history_restore": "Restore", + "config.history_restore_confirm": + "Restore to this version? Current config will be backed up automatically.", + "config.history_restored": "Configuration restored", "logs.title": "System Logs", "logs.subtitle": "Real-time view of recent log output.", "logs.auto": "Auto Refresh", @@ -558,5 +587,16 @@ const I18N = { "update.no_restart": "Updated but not restarted (check uv sync output)", "config.aot_add": "+ Add Entry", "config.aot_remove": "Remove", + "cmd.placeholder": "Type a command...", + "cmd.empty": "No matching commands", + "cmd.tab_overview": "Go to Overview", + "cmd.tab_config": "Go to Config", + "cmd.tab_logs": "Go to Logs", + "cmd.tab_probes": "Go to Probes", + "cmd.tab_memory": "Go to Memory", + "cmd.tab_memes": "Go to Memes", + "cmd.tab_cognitive": "Go to Cognitive", + "cmd.refresh": "Refresh current page", + "cmd.logout": "Logout", }, }; diff --git a/src/Undefined/webui/static/js/log-view.js b/src/Undefined/webui/static/js/log-view.js index 23d6ff8b..9959474b 100644 --- a/src/Undefined/webui/static/js/log-view.js +++ b/src/Undefined/webui/static/js/log-view.js @@ -49,6 +49,27 @@ function filterLogLines(raw) { line.toLowerCase().includes(query), ); + // Time range filtering + const timeFrom = state.logTimeFrom + ? new Date(state.logTimeFrom).getTime() + : 0; + const timeTo = state.logTimeTo ? new Date(state.logTimeTo).getTime() : 0; + if (timeFrom || timeTo) { + const tsRe = /^(\d{4}-\d{2}-\d{2}[ T]\d{2}:\d{2}:\d{2})/; + const result = []; + let include = true; + for (const line of filtered) { + const m = line.match(tsRe); + if (m) { + const ts = new Date(m[1].replace(" ", "T")).getTime(); + include = + (!timeFrom || ts >= timeFrom) && (!timeTo || ts <= timeTo); + } + if (include) result.push(line); + } + filtered = result; + } + const total = base.total ?? rawLines.length; const matched = filtered.filter((line) => line.length > 0).length; return { filtered, total, matched }; @@ -115,7 +136,9 @@ function updateLogMeta(total, matched) { if ( state.logLevel !== "all" || state.logSearch.trim() || - state.logLevelGte + state.logLevelGte || + state.logTimeFrom || + state.logTimeTo ) { parts.push( `${t("logs.filtered")}: ${total > 0 ? `${matched}/${total}` : "0/0"}`, diff --git a/src/Undefined/webui/static/js/main.js b/src/Undefined/webui/static/js/main.js index baf380c7..5745ef3d 100644 --- a/src/Undefined/webui/static/js/main.js +++ b/src/Undefined/webui/static/js/main.js @@ -54,6 +54,7 @@ function refreshUI() { } function switchTab(tab) { + abortPendingRequests(); // Cancel pending requests from previous tab state.tab = tab; state.mobileDrawerOpen = false; const mainContent = document.querySelector(".main-content"); @@ -175,7 +176,176 @@ function setMobileInlineActionsOpen(key, open) { syncMobileChrome(); } +// Command Palette +const _cmdCommands = [ + { + id: "overview", + label: () => t("cmd.tab_overview"), + action: () => switchTab("overview"), + keys: "1", + }, + { + id: "config", + label: () => t("cmd.tab_config"), + action: () => switchTab("config"), + keys: "2", + }, + { + id: "logs", + label: () => t("cmd.tab_logs"), + action: () => switchTab("logs"), + keys: "3", + }, + { + id: "probes", + label: () => t("cmd.tab_probes"), + action: () => switchTab("probes"), + keys: "4", + }, + { + id: "memory", + label: () => t("cmd.tab_memory"), + action: () => switchTab("memory"), + keys: "5", + }, + { + id: "memes", + label: () => t("cmd.tab_memes"), + action: () => switchTab("memes"), + keys: "6", + }, + { + id: "cognitive", + label: () => t("cmd.tab_cognitive"), + action: () => switchTab("cognitive"), + keys: "7", + }, + { + id: "refresh", + label: () => t("cmd.refresh"), + action: () => location.reload(), + keys: "R", + }, + { + id: "logout", + label: () => t("cmd.logout"), + action: () => { + get("btnLogout")?.click(); + }, + keys: "", + }, +]; + +let _cmdActiveIndex = 0; + +function openCmdPalette() { + const overlay = get("cmdPaletteOverlay"); + const input = get("cmdPaletteInput"); + if (!overlay || !input) return; + input.placeholder = t("cmd.placeholder"); + overlay.style.display = "flex"; + input.value = ""; + _cmdActiveIndex = 0; + _renderCmdList(""); + trapFocus(overlay); +} + +function closeCmdPalette() { + const overlay = get("cmdPaletteOverlay"); + if (!overlay) return; + releaseFocus(overlay); + overlay.style.display = "none"; +} + +function _renderCmdList(query) { + const list = get("cmdPaletteList"); + if (!list) return; + const q = query.trim().toLowerCase(); + const filtered = q + ? _cmdCommands.filter( + (c) => c.label().toLowerCase().includes(q) || c.id.includes(q), + ) + : _cmdCommands; + if (filtered.length === 0) { + list.innerHTML = `
${escapeHtml(t("cmd.empty"))}
`; + return; + } + _cmdActiveIndex = Math.min(_cmdActiveIndex, filtered.length - 1); + list.innerHTML = filtered + .map( + (c, i) => + `
${escapeHtml(c.label())}${c.keys ? `${escapeHtml(c.keys)}` : ""}
`, + ) + .join(""); + list.querySelectorAll(".cmd-palette-item").forEach((el, i) => { + el.addEventListener("click", () => { + closeCmdPalette(); + filtered[i].action(); + }); + el.addEventListener("mouseenter", () => { + _cmdActiveIndex = i; + _renderCmdList(query); + }); + }); +} + +function _handleCmdKey(e, query) { + const list = get("cmdPaletteList"); + if (!list) return; + const q = query.trim().toLowerCase(); + const filtered = q + ? _cmdCommands.filter( + (c) => c.label().toLowerCase().includes(q) || c.id.includes(q), + ) + : _cmdCommands; + if (e.key === "ArrowDown") { + e.preventDefault(); + _cmdActiveIndex = (_cmdActiveIndex + 1) % filtered.length; + _renderCmdList(query); + } else if (e.key === "ArrowUp") { + e.preventDefault(); + _cmdActiveIndex = + (_cmdActiveIndex - 1 + filtered.length) % filtered.length; + _renderCmdList(query); + } else if (e.key === "Enter" && filtered.length > 0) { + e.preventDefault(); + closeCmdPalette(); + filtered[_cmdActiveIndex].action(); + } +} + async function init() { + // Global error handlers + window.onerror = function (message, source, lineno, colno, error) { + console.error("[GlobalError]", { + message, + source, + lineno, + colno, + error, + }); + if (typeof showToast === "function") { + showToast(`⚠️ ${message}`, "error", 5000); + } + return false; + }; + + window.onunhandledrejection = function (event) { + const reason = event.reason; + const msg = reason instanceof Error ? reason.message : String(reason); + // Don't toast for routine auth errors or aborted requests + if ( + msg === "Unauthorized" || + msg === "The user aborted a request." || + reason?.name === "AbortError" + ) + return; + console.error("[UnhandledRejection]", reason); + if (typeof showToast === "function") { + showToast(`⚠️ ${msg}`, "error", 5000); + } + }; + if ( window.RuntimeController && typeof window.RuntimeController.init === "function" @@ -430,6 +600,21 @@ async function init() { }; } + const logTimeFrom = get("logTimeFrom"); + if (logTimeFrom) { + logTimeFrom.addEventListener("change", () => { + state.logTimeFrom = logTimeFrom.value; + renderLogs(); + }); + } + const logTimeTo = get("logTimeTo"); + if (logTimeTo) { + logTimeTo.addEventListener("change", () => { + state.logTimeTo = logTimeTo.value; + renderLogs(); + }); + } + const logSearchInput = get("logSearchInput"); if (logSearchInput) { logSearchInput.addEventListener("input", () => { @@ -508,6 +693,109 @@ async function init() { if (collapseAllBtn) collapseAllBtn.onclick = () => setAllSectionsCollapsed(true); + get("btnToggleToml").onclick = async function () { + const formGrid = get("formSections"); + const tomlViewer = get("tomlViewer"); + const historyPanel = get("configHistoryPanel"); + const btn = get("btnToggleToml"); + if (!formGrid || !tomlViewer || !btn) return; + + if (historyPanel) historyPanel.style.display = "none"; + + const isShowingToml = tomlViewer.style.display !== "none"; + if (isShowingToml) { + tomlViewer.style.display = "none"; + formGrid.style.display = ""; + btn.innerText = t("config.view_toml"); + } else { + try { + const res = await api("/api/config"); + const data = await res.json(); + const content = data.content || ""; + get("tomlContent").textContent = content; + formGrid.style.display = "none"; + tomlViewer.style.display = "block"; + btn.innerText = t("config.view_form"); + } catch (e) { + showToast(`${t("common.error")}: ${e.message}`, "error", 5000); + } + } + }; + + get("btnConfigHistory")?.addEventListener("click", async () => { + const panel = get("configHistoryPanel"); + const formGrid = get("formSections"); + const tomlViewer = get("tomlViewer"); + if (!panel) return; + + const isShowing = panel.style.display !== "none"; + if (isShowing) { + panel.style.display = "none"; + if (formGrid) formGrid.style.display = ""; + return; + } + + try { + const res = await api("/api/config/history"); + const data = await res.json(); + const backups = data.backups || []; + const list = get("configHistoryList"); + if (!list) return; + + if (backups.length === 0) { + list.innerHTML = `

${escapeHtml(t("config.history_empty"))}

`; + } else { + list.innerHTML = backups + .map((b) => { + const date = new Date(b.mtime * 1000); + const sizeKB = (b.size / 1024).toFixed(1); + return `
+
+
+ ${escapeHtml(b.name)} + ${sizeKB} KB · ${date.toLocaleString()} +
+ +
+
`; + }) + .join(""); + + list.querySelectorAll("[data-restore-name]").forEach((btn) => { + btn.addEventListener("click", async () => { + const name = btn.getAttribute("data-restore-name"); + if (!confirm(t("config.history_restore_confirm"))) + return; + try { + await api("/api/config/history/restore", { + method: "POST", + headers: { + "Content-Type": "application/json", + }, + body: JSON.stringify({ name }), + }); + showToast(t("config.history_restored"), "success"); + panel.style.display = "none"; + if (formGrid) formGrid.style.display = ""; + loadConfig(); + } catch (e) { + showToast( + `${t("common.error")}: ${e.message}`, + "error", + ); + } + }); + }); + } + + if (formGrid) formGrid.style.display = "none"; + if (tomlViewer) tomlViewer.style.display = "none"; + panel.style.display = "block"; + } catch (e) { + showToast(`${t("common.error")}: ${e.message}`, "error"); + } + }); + const logout = async () => { try { await api(authEndpointCandidates("logout"), { method: "POST" }); @@ -583,6 +871,40 @@ async function init() { switchTab("config"); } + // Command Palette keybinding + document.addEventListener("keydown", (e) => { + if ((e.metaKey || e.ctrlKey) && e.key === "k") { + e.preventDefault(); + const overlay = get("cmdPaletteOverlay"); + if (overlay && overlay.style.display !== "none") { + closeCmdPalette(); + } else { + openCmdPalette(); + } + } + if (e.key === "Escape") { + closeCmdPalette(); + } + }); + + const cmdInput = get("cmdPaletteInput"); + if (cmdInput) { + cmdInput.addEventListener("input", () => { + _cmdActiveIndex = 0; + _renderCmdList(cmdInput.value); + }); + cmdInput.addEventListener("keydown", (e) => { + _handleCmdKey(e, cmdInput.value); + }); + } + + const cmdOverlay = get("cmdPaletteOverlay"); + if (cmdOverlay) { + cmdOverlay.addEventListener("click", (e) => { + if (e.target === cmdOverlay) closeCmdPalette(); + }); + } + document.addEventListener("visibilitychange", () => { if (document.hidden) { stopStatusTimer(); diff --git a/src/Undefined/webui/static/js/memes.js b/src/Undefined/webui/static/js/memes.js index 0afef677..465c3cd7 100644 --- a/src/Undefined/webui/static/js/memes.js +++ b/src/Undefined/webui/static/js/memes.js @@ -314,6 +314,27 @@ return payload; } + let _debounceTimer = null; + let _pendingRefresh = false; + + function debouncedFetchList(delayMs = 350) { + if (_debounceTimer !== null) { + clearTimeout(_debounceTimer); + } + _debounceTimer = setTimeout(() => { + _debounceTimer = null; + fetchList().catch(showError); + }, delayMs); + } + + function flushDebouncedFetchList() { + if (_debounceTimer !== null) { + clearTimeout(_debounceTimer); + _debounceTimer = null; + } + fetchList().catch(showError); + } + async function fetchList(options = {}) { const append = !!options.append; if (append) { @@ -321,6 +342,7 @@ return null; } } else if (state.loading) { + _pendingRefresh = true; return null; } @@ -415,6 +437,10 @@ } renderLoadMore(); } + if (!append && _pendingRefresh) { + _pendingRefresh = false; + fetchList().catch(showError); + } } } @@ -558,7 +584,7 @@ state.initialized = true; get("btnMemesRefresh")?.addEventListener("click", refreshAll); get("btnMemesSearch")?.addEventListener("click", () => { - fetchList().catch(showError); + flushDebouncedFetchList(); }); get("btnMemesLoadMore")?.addEventListener("click", () => { fetchList({ append: true }).catch(showError); @@ -579,14 +605,17 @@ get("btnMemesDelete")?.addEventListener("click", () => { deleteSelected().catch(showError); }); - bindEnter("memesSearchInput", () => { - fetchList().catch(showError); - }); - bindEnter("memesKeywordQuery", () => { - fetchList().catch(showError); - }); - bindEnter("memesSemanticQuery", () => { - fetchList().catch(showError); + bindEnter("memesSearchInput", flushDebouncedFetchList); + bindEnter("memesKeywordQuery", flushDebouncedFetchList); + bindEnter("memesSemanticQuery", flushDebouncedFetchList); + [ + "memesSearchInput", + "memesKeywordQuery", + "memesSemanticQuery", + ].forEach((id) => { + const el = get(id); + if (el) + el.addEventListener("input", () => debouncedFetchList()); }); [ "memesQueryMode", diff --git a/src/Undefined/webui/static/js/runtime.js b/src/Undefined/webui/static/js/runtime.js index 6968c0fa..c1e1282d 100644 --- a/src/Undefined/webui/static/js/runtime.js +++ b/src/Undefined/webui/static/js/runtime.js @@ -303,12 +303,12 @@ params.set(key, text); } - function appendPositiveIntParam(params, key, value) { + function appendPositiveIntParam(params, key, value, max = 500) { const text = String(value || "").trim(); if (!text) return; const num = Number.parseInt(text, 10); if (!Number.isFinite(num) || num <= 0) return; - params.set(key, String(num)); + params.set(key, String(Math.min(num, max))); } function formatNumeric(value, digits = 4) { @@ -636,6 +636,8 @@ if (buffer.trim()) emitBlock(buffer); } + let _memoryMutating = false; + function renderMemoryItems(payload) { const container = get("runtimeMemoryList"); const meta = get("runtimeMemoryMeta"); @@ -666,9 +668,155 @@ const uuid = escapeHtml(item.uuid || ""); const fact = escapeHtml(item.fact || ""); const created = escapeHtml(item.created_at || ""); - return `
${uuid}${created}
${fact}
`; + return `
${uuid}
${created}
${fact}
`; }) .join(""); + + container.querySelectorAll(".memory-btn-edit").forEach((btn) => { + btn.addEventListener("click", () => + startEditMemory(btn.dataset.uuid), + ); + }); + container.querySelectorAll(".memory-btn-delete").forEach((btn) => { + btn.addEventListener("click", () => deleteMemory(btn.dataset.uuid)); + }); + } + + function startEditMemory(uuid) { + const container = get("runtimeMemoryList"); + if (!container) return; + const itemEl = container.querySelector( + `.runtime-list-item[data-uuid="${CSS.escape(uuid)}"]`, + ); + if (!itemEl) return; + const factEl = itemEl.querySelector(".runtime-list-fact"); + if (!factEl || factEl.dataset.editing === "true") return; + + const currentText = factEl.textContent || ""; + factEl.dataset.editing = "true"; + factEl.innerHTML = ""; + + const textarea = document.createElement("textarea"); + textarea.className = "form-control memory-edit-area"; + textarea.value = currentText; + + const actions = document.createElement("div"); + actions.className = "memory-edit-actions"; + const saveBtn = document.createElement("button"); + saveBtn.className = "btn btn-sm"; + saveBtn.textContent = "保存"; + const cancelBtn = document.createElement("button"); + cancelBtn.className = "btn btn-sm"; + cancelBtn.textContent = "取消"; + actions.append(saveBtn, cancelBtn); + factEl.append(textarea, actions); + textarea.focus(); + + cancelBtn.addEventListener("click", () => { + delete factEl.dataset.editing; + factEl.innerHTML = ""; + factEl.textContent = currentText; + }); + + saveBtn.addEventListener("click", () => + updateMemory(uuid, textarea.value), + ); + + textarea.addEventListener("keydown", (e) => { + if (e.key === "Escape") { + e.preventDefault(); + cancelBtn.click(); + } + if (e.key === "Enter" && e.ctrlKey) { + e.preventDefault(); + saveBtn.click(); + } + }); + } + + async function createMemory() { + if (_memoryMutating) return; + const input = get("memoryCreateInput"); + if (!input) return; + const fact = String(input.value || "").trim(); + if (!fact) { + showToast("记忆内容不能为空", "warning"); + return; + } + _memoryMutating = true; + const btn = get("btnMemoryCreate"); + if (btn) btn.disabled = true; + try { + const res = await api("/api/runtime/memory", { + method: "POST", + body: JSON.stringify({ fact }), + }); + const data = await parseJsonSafe(res); + if (!res.ok || (data && data.error)) { + throw new Error(buildRequestError(res, data)); + } + showToast("记忆已添加", "success"); + input.value = ""; + await searchMemory(); + } catch (err) { + showToast(`添加失败: ${err.message || err}`, "error"); + } finally { + _memoryMutating = false; + if (btn) btn.disabled = false; + } + } + + async function updateMemory(uuid, newFact) { + const fact = String(newFact || "").trim(); + if (!fact) { + showToast("记忆内容不能为空", "warning"); + return; + } + if (_memoryMutating) return; + _memoryMutating = true; + try { + const res = await api( + `/api/runtime/memory/${encodeURIComponent(uuid)}`, + { + method: "PATCH", + body: JSON.stringify({ fact }), + }, + ); + const data = await parseJsonSafe(res); + if (!res.ok || (data && data.error)) { + throw new Error(buildRequestError(res, data)); + } + showToast("记忆已更新", "success"); + await searchMemory(); + } catch (err) { + showToast(`更新失败: ${err.message || err}`, "error"); + } finally { + _memoryMutating = false; + } + } + + async function deleteMemory(uuid) { + if (_memoryMutating) return; + if (!confirm(`确认删除记忆 ${uuid.slice(0, 8)}…?`)) return; + _memoryMutating = true; + try { + const res = await api( + `/api/runtime/memory/${encodeURIComponent(uuid)}`, + { + method: "DELETE", + }, + ); + const data = await parseJsonSafe(res); + if (!res.ok || (data && data.error)) { + throw new Error(buildRequestError(res, data)); + } + showToast("记忆已删除", "success"); + await searchMemory(); + } catch (err) { + showToast(`删除失败: ${err.message || err}`, "error"); + } finally { + _memoryMutating = false; + } } function setListMessage(metaId, listId, message) { @@ -1204,6 +1352,10 @@ if (memoryRefresh) memoryRefresh.addEventListener("click", refreshMemory); + const memoryCreateBtn = get("btnMemoryCreate"); + if (memoryCreateBtn) + memoryCreateBtn.addEventListener("click", createMemory); + const runMemorySearch = () => runQueryAction("memory", "btnRuntimeMemorySearch", searchMemory); const runEventsSearch = () => diff --git a/src/Undefined/webui/static/js/state.js b/src/Undefined/webui/static/js/state.js index 2b1561b3..284f0da1 100644 --- a/src/Undefined/webui/static/js/state.js +++ b/src/Undefined/webui/static/js/state.js @@ -190,6 +190,8 @@ const state = { bot: { running: false, pid: null, uptime: 0 }, logsRaw: "", logSearch: "", + logTimeFrom: "", + logTimeTo: "", logLevel: "all", logLevelGte: false, logType: "bot", diff --git a/src/Undefined/webui/static/js/ui.js b/src/Undefined/webui/static/js/ui.js index c42a4dd5..f455ba2c 100644 --- a/src/Undefined/webui/static/js/ui.js +++ b/src/Undefined/webui/static/js/ui.js @@ -186,6 +186,46 @@ function showToast(message, type = "info", duration = 3000) { }, duration); } +// Focus trap for modals +const _focusTrapStack = []; + +function trapFocus(container) { + if (!container) return; + const focusable = container.querySelectorAll( + 'a[href], button:not([disabled]), input:not([disabled]), select:not([disabled]), textarea:not([disabled]), [tabindex]:not([tabindex="-1"])', + ); + if (focusable.length === 0) return; + const first = focusable[0]; + const last = focusable[focusable.length - 1]; + + function handler(e) { + if (e.key !== "Tab") return; + if (e.shiftKey) { + if (document.activeElement === first) { + e.preventDefault(); + last.focus(); + } + } else { + if (document.activeElement === last) { + e.preventDefault(); + first.focus(); + } + } + } + + container.addEventListener("keydown", handler); + _focusTrapStack.push({ container, handler }); + first.focus(); +} + +function releaseFocus(container) { + const idx = _focusTrapStack.findIndex((t) => t.container === container); + if (idx === -1) return; + const entry = _focusTrapStack[idx]; + entry.container.removeEventListener("keydown", entry.handler); + _focusTrapStack.splice(idx, 1); +} + function setConfigState(mode) { const stateEl = get("configState"); const grid = get("formSections"); diff --git a/src/Undefined/webui/templates/index.html b/src/Undefined/webui/templates/index.html index 3f4c8a97..a3067209 100644 --- a/src/Undefined/webui/templates/index.html +++ b/src/Undefined/webui/templates/index.html @@ -314,6 +314,15 @@

运行概览

-- + +
+
资源趋势
+ +
+ CPU + Memory +
+
@@ -347,18 +356,30 @@

配置修改

更多操作
+ +
+ + @@ -403,6 +424,10 @@

系统日志

+ + @@ -446,7 +471,7 @@

探针

记忆检索

-

只读检索记忆、认知事件与侧写。

+

管理长期记忆,检索认知事件与侧写。

@@ -467,7 +492,7 @@

记忆检索

- @@ -535,7 +560,7 @@

记忆检索

-
长期记忆查询
+
长期记忆管理
记忆检索
- @@ -555,6 +580,14 @@

记忆检索

+ +
+ +
+ +
+
@@ -783,6 +816,14 @@

MIT License

+ + + diff --git a/src/Undefined/webui/utils/config_sync.py b/src/Undefined/webui/utils/config_sync.py index 46c16eef..c3df03a6 100644 --- a/src/Undefined/webui/utils/config_sync.py +++ b/src/Undefined/webui/utils/config_sync.py @@ -1,6 +1,7 @@ from __future__ import annotations import copy +import dataclasses import tomllib from dataclasses import dataclass from pathlib import Path @@ -21,6 +22,7 @@ class ConfigTemplateSyncResult: added_paths: list[str] removed_paths: list[str] comments: CommentMap + updated_comment_paths: list[str] = dataclasses.field(default_factory=list) def _parse_toml_text(content: str, *, label: str) -> TomlData: @@ -226,6 +228,17 @@ def _augment_pool_model_comments( return merged +def _collect_updated_comment_paths( + current: CommentMap, example: CommentMap +) -> list[str]: + """收集在 example 与 current 中都存在、但注释文本不同的路径。""" + return [ + key + for key, example_value in example.items() + if key in current and current[key] != example_value + ] + + def _merge_comment_maps(current: CommentMap, example: CommentMap) -> CommentMap: merged: CommentMap = dict(current) for key, value in example.items(): @@ -248,16 +261,21 @@ def sync_config_text( if prune and removed_paths: merged = _prune_to_template(merged, prepared_example_data) example_comments = parse_comment_map_text(example_text) - comments = _merge_comment_maps( - parse_comment_map_text(current_text), - _augment_pool_model_comments(example_comments, example_data, current_data), + current_comments = parse_comment_map_text(current_text) + augmented_example_comments = _augment_pool_model_comments( + example_comments, example_data, current_data + ) + updated_comment_paths = _collect_updated_comment_paths( + current_comments, augmented_example_comments ) + comments = _merge_comment_maps(current_comments, augmented_example_comments) content = render_toml(merged, comments=comments) return ConfigTemplateSyncResult( content=content, added_paths=added_paths, removed_paths=removed_paths, comments=comments, + updated_comment_paths=updated_comment_paths, ) diff --git a/tests/test_ai_parsing.py b/tests/test_ai_parsing.py new file mode 100644 index 00000000..ff7015b9 --- /dev/null +++ b/tests/test_ai_parsing.py @@ -0,0 +1,101 @@ +"""Tests for Undefined.ai.parsing module.""" + +from __future__ import annotations + +import pytest + +from Undefined.ai.parsing import extract_choices_content + + +class TestExtractChoicesContent: + """Tests for extract_choices_content().""" + + def test_standard_response(self) -> None: + result: dict[str, object] = { + "choices": [{"message": {"content": "Hello, world!"}}] + } + assert extract_choices_content(result) == "Hello, world!" + + def test_data_wrapped_response(self) -> None: + result: dict[str, object] = { + "data": {"choices": [{"message": {"content": "nested content"}}]} + } + assert extract_choices_content(result) == "nested content" + + def test_output_text_field(self) -> None: + result: dict[str, object] = { + "output_text": "direct output", + "choices": [{"message": {"content": "ignored"}}], + } + assert extract_choices_content(result) == "direct output" + + def test_output_text_preferred_over_choices(self) -> None: + result: dict[str, object] = {"output_text": "preferred"} + assert extract_choices_content(result) == "preferred" + + def test_output_text_non_string_falls_through(self) -> None: + result: dict[str, object] = { + "output_text": 42, + "choices": [{"message": {"content": "fallback"}}], + } + assert extract_choices_content(result) == "fallback" + + def test_empty_choices_raises(self) -> None: + result: dict[str, object] = {"choices": []} + with pytest.raises(KeyError): + extract_choices_content(result) + + def test_no_choices_key_raises(self) -> None: + result: dict[str, object] = {"id": "123", "object": "chat.completion"} + with pytest.raises(KeyError): + extract_choices_content(result) + + def test_no_content_in_message(self) -> None: + result: dict[str, object] = {"choices": [{"message": {}}]} + assert extract_choices_content(result) == "" + + def test_message_is_none(self) -> None: + """message=None triggers AttributeError in tool_calls check (known bug).""" + result: dict[str, object] = {"choices": [{"message": None}]} + with pytest.raises(AttributeError): + extract_choices_content(result) + + def test_choice_with_content_directly(self) -> None: + result: dict[str, object] = {"choices": [{"content": "direct"}]} + assert extract_choices_content(result) == "direct" + + def test_tool_calls_no_content(self) -> None: + result: dict[str, object] = { + "choices": [{"message": {"tool_calls": [{"function": {"name": "test"}}]}}] + } + assert extract_choices_content(result) == "" + + def test_refusal_field_content_still_extracted(self) -> None: + result: dict[str, object] = { + "choices": [ + { + "message": { + "content": "I can help with that.", + "refusal": None, + } + } + ] + } + assert extract_choices_content(result) == "I can help with that." + + def test_multiple_choices_returns_first(self) -> None: + result: dict[str, object] = { + "choices": [ + {"message": {"content": "first"}}, + {"message": {"content": "second"}}, + ] + } + assert extract_choices_content(result) == "first" + + def test_empty_dict_raises(self) -> None: + with pytest.raises(KeyError): + extract_choices_content({}) + + def test_message_is_string(self) -> None: + result: dict[str, object] = {"choices": [{"message": "plain string"}]} + assert extract_choices_content(result) == "plain string" diff --git a/tests/test_ai_queue_budget.py b/tests/test_ai_queue_budget.py new file mode 100644 index 00000000..001dec59 --- /dev/null +++ b/tests/test_ai_queue_budget.py @@ -0,0 +1,212 @@ +"""Tests for Undefined.ai.queue_budget module.""" + +from __future__ import annotations + +from types import SimpleNamespace + +from Undefined.ai.queue_budget import ( + DEFAULT_QUEUED_LLM_ATTEMPT_TIMEOUT_SECONDS, + QUEUED_LLM_TIMEOUT_GRACE_SECONDS, + compute_queued_llm_timeout_seconds, + resolve_effective_retry_count, +) + + +class TestResolveEffectiveRetryCount: + """Tests for resolve_effective_retry_count().""" + + def test_from_runtime_config(self) -> None: + cfg = SimpleNamespace(ai_request_max_retries=3) + assert resolve_effective_retry_count(cfg) == 3 + + def test_from_queue_manager(self) -> None: + cfg = SimpleNamespace(ai_request_max_retries=5) + qm = SimpleNamespace(get_max_retries=lambda: 2) + assert resolve_effective_retry_count(cfg, qm) == 2 + + def test_queue_manager_takes_precedence(self) -> None: + cfg = SimpleNamespace(ai_request_max_retries=10) + qm = SimpleNamespace(get_max_retries=lambda: 1) + assert resolve_effective_retry_count(cfg, qm) == 1 + + def test_negative_retries_clamped_to_zero(self) -> None: + cfg = SimpleNamespace(ai_request_max_retries=-5) + assert resolve_effective_retry_count(cfg) == 0 + + def test_none_retries_returns_zero(self) -> None: + cfg = SimpleNamespace(ai_request_max_retries=None) + assert resolve_effective_retry_count(cfg) == 0 + + def test_missing_attribute_returns_zero(self) -> None: + cfg = SimpleNamespace() + assert resolve_effective_retry_count(cfg) == 0 + + def test_queue_manager_invalid_return(self) -> None: + cfg = SimpleNamespace(ai_request_max_retries=3) + qm = SimpleNamespace(get_max_retries=lambda: "invalid") + assert resolve_effective_retry_count(cfg, qm) == 0 + + def test_queue_manager_negative_clamped(self) -> None: + cfg = SimpleNamespace(ai_request_max_retries=3) + qm = SimpleNamespace(get_max_retries=lambda: -1) + assert resolve_effective_retry_count(cfg, qm) == 0 + + def test_queue_manager_none_no_get_max_retries(self) -> None: + cfg = SimpleNamespace(ai_request_max_retries=4) + qm = SimpleNamespace() + assert resolve_effective_retry_count(cfg, qm) == 4 + + def test_zero_retries(self) -> None: + cfg = SimpleNamespace(ai_request_max_retries=0) + assert resolve_effective_retry_count(cfg) == 0 + + +class TestComputeQueuedLlmTimeoutSeconds: + """Tests for compute_queued_llm_timeout_seconds().""" + + def _make_model_config( + self, + interval: float = 0.0, + pool_enabled: bool = False, + pool_models: list[SimpleNamespace] | None = None, + ) -> SimpleNamespace: + pool = SimpleNamespace( + enabled=pool_enabled, + models=pool_models or [], + ) + return SimpleNamespace(queue_interval_seconds=interval, pool=pool) + + def test_defaults_zero_retries(self) -> None: + cfg = SimpleNamespace(ai_request_max_retries=0) + model_cfg = self._make_model_config() + result = compute_queued_llm_timeout_seconds(cfg, model_cfg) + expected = ( + 0.0 + + DEFAULT_QUEUED_LLM_ATTEMPT_TIMEOUT_SECONDS * 1 + + 0.0 * 1 + + QUEUED_LLM_TIMEOUT_GRACE_SECONDS + ) + assert result == expected + + def test_with_retries(self) -> None: + cfg = SimpleNamespace(ai_request_max_retries=2) + model_cfg = self._make_model_config(interval=1.0) + result = compute_queued_llm_timeout_seconds(cfg, model_cfg) + # attempts=3, dispatch_intervals=3 (2 retries + 1 first) + expected = ( + 0.0 + + DEFAULT_QUEUED_LLM_ATTEMPT_TIMEOUT_SECONDS * 3 + + 1.0 * 3 + + QUEUED_LLM_TIMEOUT_GRACE_SECONDS + ) + assert result == expected + + def test_explicit_retry_count(self) -> None: + cfg = SimpleNamespace(ai_request_max_retries=99) + model_cfg = self._make_model_config() + result = compute_queued_llm_timeout_seconds(cfg, model_cfg, retry_count=1) + # explicit retry_count=1 overrides config + expected = ( + 0.0 + + DEFAULT_QUEUED_LLM_ATTEMPT_TIMEOUT_SECONDS * 2 + + 0.0 * 2 + + QUEUED_LLM_TIMEOUT_GRACE_SECONDS + ) + assert result == expected + + def test_initial_wait(self) -> None: + cfg = SimpleNamespace(ai_request_max_retries=0) + model_cfg = self._make_model_config() + result = compute_queued_llm_timeout_seconds( + cfg, model_cfg, initial_wait_seconds=10.0 + ) + expected = ( + 10.0 + + DEFAULT_QUEUED_LLM_ATTEMPT_TIMEOUT_SECONDS * 1 + + 0.0 * 1 + + QUEUED_LLM_TIMEOUT_GRACE_SECONDS + ) + assert result == expected + + def test_no_first_dispatch_interval(self) -> None: + cfg = SimpleNamespace(ai_request_max_retries=2) + model_cfg = self._make_model_config(interval=5.0) + result = compute_queued_llm_timeout_seconds( + cfg, model_cfg, include_first_dispatch_interval=False + ) + # dispatch_intervals = retries only = 2 + expected = ( + 0.0 + + DEFAULT_QUEUED_LLM_ATTEMPT_TIMEOUT_SECONDS * 3 + + 5.0 * 2 + + QUEUED_LLM_TIMEOUT_GRACE_SECONDS + ) + assert result == expected + + def test_custom_attempt_timeout(self) -> None: + cfg = SimpleNamespace(ai_request_max_retries=0) + model_cfg = self._make_model_config() + result = compute_queued_llm_timeout_seconds( + cfg, model_cfg, attempt_timeout_seconds=60.0 + ) + expected = 0.0 + 60.0 * 1 + 0.0 * 1 + QUEUED_LLM_TIMEOUT_GRACE_SECONDS + assert result == expected + + def test_custom_grace_seconds(self) -> None: + cfg = SimpleNamespace(ai_request_max_retries=0) + model_cfg = self._make_model_config() + result = compute_queued_llm_timeout_seconds(cfg, model_cfg, grace_seconds=100.0) + expected = ( + 0.0 + DEFAULT_QUEUED_LLM_ATTEMPT_TIMEOUT_SECONDS * 1 + 0.0 * 1 + 100.0 + ) + assert result == expected + + def test_pool_models_max_interval(self) -> None: + pool_models = [ + SimpleNamespace(queue_interval_seconds=2.0), + SimpleNamespace(queue_interval_seconds=5.0), + SimpleNamespace(queue_interval_seconds=1.0), + ] + cfg = SimpleNamespace(ai_request_max_retries=1) + model_cfg = self._make_model_config( + interval=3.0, pool_enabled=True, pool_models=pool_models + ) + result = compute_queued_llm_timeout_seconds(cfg, model_cfg) + # max interval = max(3.0, 2.0, 5.0, 1.0) = 5.0 + # attempts=2, dispatch=2 + expected = ( + 0.0 + + DEFAULT_QUEUED_LLM_ATTEMPT_TIMEOUT_SECONDS * 2 + + 5.0 * 2 + + QUEUED_LLM_TIMEOUT_GRACE_SECONDS + ) + assert result == expected + + def test_pool_disabled_ignores_pool_models(self) -> None: + pool_models = [SimpleNamespace(queue_interval_seconds=100.0)] + cfg = SimpleNamespace(ai_request_max_retries=0) + model_cfg = self._make_model_config( + interval=1.0, pool_enabled=False, pool_models=pool_models + ) + result = compute_queued_llm_timeout_seconds(cfg, model_cfg) + # pool disabled: only base interval=1.0 + expected = ( + 0.0 + + DEFAULT_QUEUED_LLM_ATTEMPT_TIMEOUT_SECONDS * 1 + + 1.0 * 1 + + QUEUED_LLM_TIMEOUT_GRACE_SECONDS + ) + assert result == expected + + def test_negative_retry_count_clamped(self) -> None: + cfg = SimpleNamespace(ai_request_max_retries=0) + model_cfg = self._make_model_config() + result = compute_queued_llm_timeout_seconds(cfg, model_cfg, retry_count=-5) + # clamped to 0 → 1 attempt + expected = ( + 0.0 + + DEFAULT_QUEUED_LLM_ATTEMPT_TIMEOUT_SECONDS * 1 + + 0.0 * 1 + + QUEUED_LLM_TIMEOUT_GRACE_SECONDS + ) + assert result == expected diff --git a/tests/test_ai_tokens.py b/tests/test_ai_tokens.py new file mode 100644 index 00000000..bd701652 --- /dev/null +++ b/tests/test_ai_tokens.py @@ -0,0 +1,83 @@ +"""Tests for Undefined.ai.tokens module.""" + +from __future__ import annotations + +from unittest.mock import patch + +from Undefined.ai.tokens import TokenCounter + + +class TestTokenCounter: + """Tests for TokenCounter.""" + + def test_empty_string(self) -> None: + counter = TokenCounter() + result = counter.count("") + assert result == 0 or isinstance(result, int) + + def test_normal_text(self) -> None: + counter = TokenCounter() + result = counter.count("Hello, world!") + assert result > 0 + + def test_unicode_text(self) -> None: + counter = TokenCounter() + result = counter.count("你好世界!🌍") + assert result > 0 + + def test_long_text(self) -> None: + counter = TokenCounter() + short_count = counter.count("hello") + long_count = counter.count("hello " * 1000) + assert long_count > short_count + + def test_whitespace_only(self) -> None: + counter = TokenCounter() + result = counter.count(" \n\t ") + assert isinstance(result, int) + + def test_single_character(self) -> None: + counter = TokenCounter() + result = counter.count("a") + assert result >= 1 + + def test_fallback_when_tiktoken_unavailable(self) -> None: + counter = TokenCounter() + counter._tokenizer = None + result = counter.count("hello world") + expected = len("hello world") // 3 + 1 + assert result == expected + + def test_fallback_empty_string(self) -> None: + counter = TokenCounter() + counter._tokenizer = None + result = counter.count("") + assert result == 1 # len("") // 3 + 1 == 1 + + def test_fallback_short_text(self) -> None: + counter = TokenCounter() + counter._tokenizer = None + assert counter.count("ab") == 1 # 2 // 3 + 1 + + def test_fallback_exact_multiple(self) -> None: + counter = TokenCounter() + counter._tokenizer = None + assert counter.count("abc") == 2 # 3 // 3 + 1 + + def test_default_model_name(self) -> None: + counter = TokenCounter() + assert counter._model_name == "gpt-4" + + def test_custom_model_name(self) -> None: + counter = TokenCounter(model_name="gpt-3.5-turbo") + assert counter._model_name == "gpt-3.5-turbo" + + def test_tiktoken_load_failure_graceful(self) -> None: + with patch("builtins.__import__", side_effect=ImportError("no tiktoken")): + counter = TokenCounter.__new__(TokenCounter) + counter._model_name = "gpt-4" + counter._tokenizer = None + counter._try_load_tokenizer() + assert counter._tokenizer is None + result = counter.count("test text") + assert result == len("test text") // 3 + 1 diff --git a/tests/test_arxiv_analysis_agent.py b/tests/test_arxiv_analysis_agent.py new file mode 100644 index 00000000..82f8c53d --- /dev/null +++ b/tests/test_arxiv_analysis_agent.py @@ -0,0 +1,351 @@ +"""arxiv_analysis_agent 单元测试。""" + +from __future__ import annotations + +import json +from pathlib import Path +from typing import Any +from unittest.mock import AsyncMock, patch + +import pytest + + +# --------------------------------------------------------------------------- +# config.json / callable.json 结构检查 +# --------------------------------------------------------------------------- + +AGENT_DIR = ( + Path(__file__).resolve().parent.parent + / "src" + / "Undefined" + / "skills" + / "agents" + / "arxiv_analysis_agent" +) + + +def _load_json(name: str) -> dict[str, Any]: + result: dict[str, Any] = json.loads((AGENT_DIR / name).read_text(encoding="utf-8")) + return result + + +class TestAgentConfig: + def test_config_json_schema(self) -> None: + cfg = _load_json("config.json") + assert cfg["type"] == "function" + func = cfg["function"] + assert func["name"] == "arxiv_analysis_agent" + assert "paper_id" in func["parameters"]["properties"] + assert "paper_id" in func["parameters"]["required"] + + def test_callable_json(self) -> None: + cfg = _load_json("callable.json") + assert cfg["enabled"] is True + assert isinstance(cfg["allowed_callers"], list) + assert len(cfg["allowed_callers"]) > 0 + + def test_tools_exist(self) -> None: + tools_dir = AGENT_DIR / "tools" + assert (tools_dir / "fetch_paper" / "config.json").exists() + assert (tools_dir / "fetch_paper" / "handler.py").exists() + assert (tools_dir / "read_paper_pages" / "config.json").exists() + assert (tools_dir / "read_paper_pages" / "handler.py").exists() + + def test_prompt_md_exists(self) -> None: + assert (AGENT_DIR / "prompt.md").exists() + content = (AGENT_DIR / "prompt.md").read_text(encoding="utf-8") + assert len(content) > 100 + + +# --------------------------------------------------------------------------- +# handler.py +# --------------------------------------------------------------------------- + + +class TestHandler: + @pytest.mark.asyncio + async def test_empty_paper_id(self) -> None: + from Undefined.skills.agents.arxiv_analysis_agent.handler import execute + + result = await execute({"paper_id": ""}, {}) + assert "请提供" in result + + @pytest.mark.asyncio + async def test_invalid_paper_id(self) -> None: + from Undefined.skills.agents.arxiv_analysis_agent.handler import execute + + result = await execute({"paper_id": "not-a-valid-id"}, {}) + assert "无法解析" in result + + @pytest.mark.asyncio + async def test_valid_paper_id_calls_runner(self) -> None: + from Undefined.skills.agents.arxiv_analysis_agent.handler import execute + + mock_result = "分析结果" + with patch( + "Undefined.skills.agents.arxiv_analysis_agent.handler.run_agent_with_tools", + new_callable=AsyncMock, + return_value=mock_result, + ) as mock_run: + result = await execute( + {"paper_id": "2301.07041", "prompt": "分析方法论"}, + {"request_id": "test-123"}, + ) + assert result == mock_result + mock_run.assert_awaited_once() + call_kwargs = mock_run.call_args[1] + assert call_kwargs["agent_name"] == "arxiv_analysis_agent" + assert "分析方法论" in call_kwargs["user_content"] + + @pytest.mark.asyncio + async def test_url_input_normalized(self) -> None: + from Undefined.skills.agents.arxiv_analysis_agent.handler import execute + + with patch( + "Undefined.skills.agents.arxiv_analysis_agent.handler.run_agent_with_tools", + new_callable=AsyncMock, + return_value="ok", + ): + ctx: dict[str, Any] = {} + await execute( + {"paper_id": "https://arxiv.org/abs/2301.07041"}, + ctx, + ) + assert ctx["arxiv_paper_id"] == "2301.07041" + + +# --------------------------------------------------------------------------- +# fetch_paper tool +# --------------------------------------------------------------------------- + + +def _make_paper_info( + paper_id: str = "2301.07041", +) -> Any: + from Undefined.arxiv.models import PaperInfo + + return PaperInfo( + paper_id=paper_id, + title="Test Paper Title", + authors=("Author A", "Author B"), + summary="This is a test abstract.", + published="2023-01-17T00:00:00Z", + updated="2023-01-18T00:00:00Z", + primary_category="cs.AI", + abs_url=f"https://arxiv.org/abs/{paper_id}", + pdf_url=f"https://arxiv.org/pdf/{paper_id}.pdf", + ) + + +class TestFetchPaper: + @pytest.mark.asyncio + async def test_empty_paper_id(self) -> None: + from Undefined.skills.agents.arxiv_analysis_agent.tools.fetch_paper.handler import ( + execute, + ) + + result = await execute({"paper_id": ""}, {}) + assert "错误" in result + + @pytest.mark.asyncio + async def test_metadata_only_on_download_failure(self) -> None: + from Undefined.skills.agents.arxiv_analysis_agent.tools.fetch_paper.handler import ( + execute, + ) + + paper = _make_paper_info() + with ( + patch( + "Undefined.skills.agents.arxiv_analysis_agent.tools.fetch_paper.handler.get_paper_info", + new_callable=AsyncMock, + return_value=paper, + ), + patch( + "Undefined.skills.agents.arxiv_analysis_agent.tools.fetch_paper.handler.download_paper_pdf", + new_callable=AsyncMock, + side_effect=RuntimeError("network error"), + ), + ): + result = await execute({"paper_id": "2301.07041"}, {}) + assert "Test Paper Title" in result + assert "Author A" in result + assert "PDF 下载失败" in result + + @pytest.mark.asyncio + async def test_paper_not_found(self) -> None: + from Undefined.skills.agents.arxiv_analysis_agent.tools.fetch_paper.handler import ( + execute, + ) + + with patch( + "Undefined.skills.agents.arxiv_analysis_agent.tools.fetch_paper.handler.get_paper_info", + new_callable=AsyncMock, + return_value=None, + ): + result = await execute({"paper_id": "9999.99999"}, {}) + assert "未找到" in result + + @pytest.mark.asyncio + async def test_successful_fetch_with_pdf(self, tmp_path: Path) -> None: + from Undefined.skills.agents.arxiv_analysis_agent.tools.fetch_paper.handler import ( + execute, + ) + + paper = _make_paper_info() + pdf_path = tmp_path / "test.pdf" + + import fitz + + doc = fitz.open() + page = doc.new_page() + page.insert_text((72, 72), "Hello world") + doc.save(str(pdf_path)) + doc.close() + + from Undefined.arxiv.downloader import PaperDownloadResult + + download_result = PaperDownloadResult( + path=pdf_path, size_bytes=1024, status="downloaded" + ) + + with ( + patch( + "Undefined.skills.agents.arxiv_analysis_agent.tools.fetch_paper.handler.get_paper_info", + new_callable=AsyncMock, + return_value=paper, + ), + patch( + "Undefined.skills.agents.arxiv_analysis_agent.tools.fetch_paper.handler.download_paper_pdf", + new_callable=AsyncMock, + return_value=(download_result, tmp_path), + ), + ): + ctx: dict[str, Any] = {} + result = await execute({"paper_id": "2301.07041"}, ctx) + assert "Test Paper Title" in result + assert "1 页" in result + assert ctx["_arxiv_pdf_path"] == str(pdf_path) + assert ctx["_arxiv_pdf_pages"] == 1 + + +# --------------------------------------------------------------------------- +# read_paper_pages tool +# --------------------------------------------------------------------------- + + +class TestReadPaperPages: + @pytest.mark.asyncio + async def test_no_pdf_downloaded(self) -> None: + from Undefined.skills.agents.arxiv_analysis_agent.tools.read_paper_pages.handler import ( + execute, + ) + + result = await execute({"start_page": 1, "end_page": 1}, {}) + assert "先调用 fetch_paper" in result + + @pytest.mark.asyncio + async def test_read_single_page(self, tmp_path: Path) -> None: + from Undefined.skills.agents.arxiv_analysis_agent.tools.read_paper_pages.handler import ( + execute, + ) + + import fitz + + pdf_path = tmp_path / "paper.pdf" + doc = fitz.open() + page = doc.new_page() + page.insert_text((72, 72), "Page 1 content") + page2 = doc.new_page() + page2.insert_text((72, 72), "Page 2 content") + doc.save(str(pdf_path)) + doc.close() + + ctx: dict[str, Any] = { + "_arxiv_pdf_path": str(pdf_path), + "_arxiv_pdf_pages": 2, + } + result = await execute({"start_page": 1, "end_page": 1}, ctx) + assert "第 1 页" in result + assert "Page 1 content" in result + assert "Page 2 content" not in result + + @pytest.mark.asyncio + async def test_read_page_range(self, tmp_path: Path) -> None: + from Undefined.skills.agents.arxiv_analysis_agent.tools.read_paper_pages.handler import ( + execute, + ) + + import fitz + + pdf_path = tmp_path / "paper.pdf" + doc = fitz.open() + for i in range(3): + page = doc.new_page() + page.insert_text((72, 72), f"Content of page {i + 1}") + doc.save(str(pdf_path)) + doc.close() + + ctx: dict[str, Any] = { + "_arxiv_pdf_path": str(pdf_path), + "_arxiv_pdf_pages": 3, + } + result = await execute({"start_page": 1, "end_page": 3}, ctx) + assert "第 1-3 页" in result or "第 1 页" in result + assert "Content of page 1" in result + assert "Content of page 3" in result + + @pytest.mark.asyncio + async def test_out_of_range_page(self, tmp_path: Path) -> None: + from Undefined.skills.agents.arxiv_analysis_agent.tools.read_paper_pages.handler import ( + execute, + ) + + import fitz + + pdf_path = tmp_path / "paper.pdf" + doc = fitz.open() + doc.new_page() + doc.save(str(pdf_path)) + doc.close() + + ctx: dict[str, Any] = { + "_arxiv_pdf_path": str(pdf_path), + "_arxiv_pdf_pages": 1, + } + result = await execute({"start_page": 5, "end_page": 10}, ctx) + assert "错误" in result + + @pytest.mark.asyncio + async def test_invalid_page_numbers(self) -> None: + from Undefined.skills.agents.arxiv_analysis_agent.tools.read_paper_pages.handler import ( + execute, + ) + + ctx: dict[str, Any] = { + "_arxiv_pdf_path": "/some/path.pdf", + "_arxiv_pdf_pages": 10, + } + result = await execute({"start_page": "abc", "end_page": "def"}, ctx) + assert "整数" in result + + +# --------------------------------------------------------------------------- +# web_agent callable.json 更新检查 +# --------------------------------------------------------------------------- + + +class TestWebAgentCallable: + def test_web_agent_allows_new_callers(self) -> None: + web_agent_callable = ( + Path(__file__).resolve().parent.parent + / "src" + / "Undefined" + / "skills" + / "agents" + / "web_agent" + / "callable.json" + ) + cfg = json.loads(web_agent_callable.read_text(encoding="utf-8")) + callers = cfg["allowed_callers"] + assert "summary_agent" in callers + assert "arxiv_analysis_agent" in callers diff --git a/tests/test_arxiv_sender.py b/tests/test_arxiv_sender.py index 120ab0af..312075b5 100644 --- a/tests/test_arxiv_sender.py +++ b/tests/test_arxiv_sender.py @@ -17,6 +17,7 @@ @pytest.fixture(autouse=True) def _clear_inflight() -> None: arxiv_sender._INFLIGHT_SENDS.clear() + arxiv_sender._RECENT_SENDS.clear() def _paper_info() -> PaperInfo: @@ -165,3 +166,187 @@ async def _fake_once(**_: object) -> str: assert first_result == "ok" assert second_result == "ok" assert called == 1 + + +@pytest.mark.asyncio +async def test_recent_send_blocks_duplicate( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """After a successful send, duplicate should be blocked within cooldown.""" + sender = _sender() + call_count = 0 + + async def _fake_once(**_: object) -> str: + nonlocal call_count + call_count += 1 + return "ok" + + monkeypatch.setattr(arxiv_sender, "_send_arxiv_paper_once", _fake_once) + + # First send should succeed + result1 = await send_arxiv_paper( + paper_id="2501.01234", + sender=sender, + target_type="group", + target_id=123456, + max_file_size=100, + author_preview_limit=20, + summary_preview_chars=1000, + ) + assert result1 == "ok" + assert call_count == 1 + + # Second send should be blocked by time-based dedup + result2 = await send_arxiv_paper( + paper_id="2501.01234", + sender=sender, + target_type="group", + target_id=123456, + max_file_size=100, + author_preview_limit=20, + summary_preview_chars=1000, + ) + assert "近期已发送过" in result2 + assert call_count == 1 # Still only 1 call + + +@pytest.mark.asyncio +async def test_recent_send_expires_after_cooldown( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """After cooldown expires, the same paper can be sent again.""" + sender = _sender() + call_count = 0 + mock_time = 0.0 + + async def _fake_once(**_: object) -> str: + nonlocal call_count + call_count += 1 + return "ok" + + def _fake_monotonic() -> float: + return mock_time + + monkeypatch.setattr(arxiv_sender, "_send_arxiv_paper_once", _fake_once) + # Patch time.monotonic in the arxiv_sender module + monkeypatch.setattr( + arxiv_sender, "time", SimpleNamespace(monotonic=_fake_monotonic) + ) + + # First send at time 0 + result1 = await send_arxiv_paper( + paper_id="2501.01234", + sender=sender, + target_type="group", + target_id=123456, + max_file_size=100, + author_preview_limit=20, + summary_preview_chars=1000, + ) + assert result1 == "ok" + assert call_count == 1 + + # Advance time past cooldown + mock_time = arxiv_sender._DEDUP_COOLDOWN_SECONDS + 1.0 + + # Second send should succeed now + result2 = await send_arxiv_paper( + paper_id="2501.01234", + sender=sender, + target_type="group", + target_id=123456, + max_file_size=100, + author_preview_limit=20, + summary_preview_chars=1000, + ) + assert result2 == "ok" + assert call_count == 2 # Second call executed + + +@pytest.mark.asyncio +async def test_recent_send_capacity_limit() -> None: + """When _RECENT_SENDS exceeds max size, oldest entries are evicted.""" + # Fill with max_size + 100 entries + for i in range(arxiv_sender._RECENT_SENDS_MAX_SIZE + 100): + key = ("group", i, f"paper_{i}") + arxiv_sender._RECENT_SENDS[key] = float(i) + + # Trigger eviction + arxiv_sender._evict_oldest_recent_sends() + + # Should have evicted back to max size + assert len(arxiv_sender._RECENT_SENDS) == arxiv_sender._RECENT_SENDS_MAX_SIZE + + # Oldest 100 should be gone + for i in range(100): + key = ("group", i, f"paper_{i}") + assert key not in arxiv_sender._RECENT_SENDS + + # Newest max_size should remain + for i in range(100, arxiv_sender._RECENT_SENDS_MAX_SIZE + 100): + key = ("group", i, f"paper_{i}") + assert key in arxiv_sender._RECENT_SENDS + + +@pytest.mark.asyncio +async def test_cleanup_expired_recent_sends() -> None: + """Test that expired entries are removed while non-expired remain.""" + import time as time_module + + # Get current time + now = time_module.monotonic() + + # Add expired entries (old timestamps, more than 1 hour ago) + expired_key1 = ("group", 1, "expired1") + expired_key2 = ("group", 2, "expired2") + arxiv_sender._RECENT_SENDS[expired_key1] = ( + now - arxiv_sender._DEDUP_COOLDOWN_SECONDS - 100.0 + ) + arxiv_sender._RECENT_SENDS[expired_key2] = ( + now - arxiv_sender._DEDUP_COOLDOWN_SECONDS - 50.0 + ) + + # Add non-expired entries (recent timestamps, within 1 hour) + recent_key1 = ("group", 3, "recent1") + recent_key2 = ("group", 4, "recent2") + arxiv_sender._RECENT_SENDS[recent_key1] = now - 100.0 # 100 seconds ago + arxiv_sender._RECENT_SENDS[recent_key2] = now - 10.0 # 10 seconds ago + + # Cleanup + arxiv_sender._cleanup_expired_recent_sends() + + # Expired should be gone + assert expired_key1 not in arxiv_sender._RECENT_SENDS + assert expired_key2 not in arxiv_sender._RECENT_SENDS + + # Recent should remain + assert recent_key1 in arxiv_sender._RECENT_SENDS + assert recent_key2 in arxiv_sender._RECENT_SENDS + + +@pytest.mark.asyncio +async def test_failed_send_not_recorded_in_recent( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """A failed send should NOT record in _RECENT_SENDS.""" + sender = _sender() + + async def _fake_once(**_: object) -> str: + raise RuntimeError("Send failed") + + monkeypatch.setattr(arxiv_sender, "_send_arxiv_paper_once", _fake_once) + + # Attempt send, should fail + with pytest.raises(RuntimeError, match="Send failed"): + await send_arxiv_paper( + paper_id="2501.01234", + sender=sender, + target_type="group", + target_id=123456, + max_file_size=100, + author_preview_limit=20, + summary_preview_chars=1000, + ) + + # Should NOT be recorded in recent sends + assert len(arxiv_sender._RECENT_SENDS) == 0 diff --git a/tests/test_attachment_tags.py b/tests/test_attachment_tags.py new file mode 100644 index 00000000..762ffd40 --- /dev/null +++ b/tests/test_attachment_tags.py @@ -0,0 +1,336 @@ +"""Tests for unified / tag rendering.""" + +from __future__ import annotations + +from pathlib import Path +from typing import Any + +import pytest + +from Undefined.attachments import ( + AttachmentRegistry, + AttachmentRenderError, + RenderedRichMessage, + dispatch_pending_file_sends, + render_message_with_attachments, + render_message_with_pic_placeholders, +) + + +_PNG_BYTES = ( + b"\x89PNG\r\n\x1a\n" + b"\x00\x00\x00\rIHDR" + b"\x00\x00\x00\x01\x00\x00\x00\x01\x08\x02\x00\x00\x00" + b"\x90wS\xde" + b"\x00\x00\x00\x0cIDATx\x9cc``\x00\x00\x00\x02\x00\x01" + b"\x0b\xe7\x02\x9d" + b"\x00\x00\x00\x00IEND\xaeB`\x82" +) + +_PDF_BYTES = b"%PDF-1.4 fake content for testing" + + +def _make_registry(tmp_path: Path) -> AttachmentRegistry: + return AttachmentRegistry( + registry_path=tmp_path / "reg.json", + cache_dir=tmp_path / "cache", + ) + + +# ---------- backward compatibility ---------- + + +@pytest.mark.asyncio +async def test_pic_tag_still_works(tmp_path: Path) -> None: + """ backward compat: renders image as CQ.""" + reg = _make_registry(tmp_path) + rec = await reg.register_bytes( + "group:1", _PNG_BYTES, kind="image", display_name="cat.png", source_kind="test" + ) + msg = f'Look: ' + result = await render_message_with_attachments( + msg, registry=reg, scope_key="group:1", strict=False + ) + assert "[CQ:image" in result.delivery_text + assert rec.uid in result.history_text + assert len(result.attachments) == 1 + assert result.pending_file_sends == () + + +@pytest.mark.asyncio +async def test_alias_is_same_function() -> None: + """render_message_with_pic_placeholders is an alias.""" + assert render_message_with_pic_placeholders is render_message_with_attachments + + +# ---------- unified tag ---------- + + +@pytest.mark.asyncio +async def test_attachment_tag_image(tmp_path: Path) -> None: + """ renders as CQ image (same as ).""" + reg = _make_registry(tmp_path) + rec = await reg.register_bytes( + "group:1", _PNG_BYTES, kind="image", display_name="cat.png", source_kind="test" + ) + msg = f'Here: ' + result = await render_message_with_attachments( + msg, registry=reg, scope_key="group:1", strict=False + ) + assert "[CQ:image" in result.delivery_text + assert rec.uid in result.history_text + assert result.pending_file_sends == () + + +@pytest.mark.asyncio +async def test_attachment_tag_file(tmp_path: Path) -> None: + """ collects into pending_file_sends.""" + reg = _make_registry(tmp_path) + rec = await reg.register_bytes( + "group:1", _PDF_BYTES, kind="file", display_name="doc.pdf", source_kind="test" + ) + msg = f'See doc: ' + result = await render_message_with_attachments( + msg, registry=reg, scope_key="group:1", strict=False + ) + # File tag removed from delivery text + assert rec.uid not in result.delivery_text + assert "See doc: " in result.delivery_text + # Readable placeholder in history + assert f"[文件 uid={rec.uid}" in result.history_text + assert "doc.pdf" in result.history_text + # Collected in pending + assert len(result.pending_file_sends) == 1 + assert result.pending_file_sends[0].uid == rec.uid + + +@pytest.mark.asyncio +async def test_mixed_pic_and_attachment_tags(tmp_path: Path) -> None: + """Mix of and tags in the same message.""" + reg = _make_registry(tmp_path) + img = await reg.register_bytes( + "group:1", _PNG_BYTES, kind="image", display_name="img.png", source_kind="test" + ) + doc = await reg.register_bytes( + "group:1", _PDF_BYTES, kind="file", display_name="doc.pdf", source_kind="test" + ) + msg = f' and ' + result = await render_message_with_attachments( + msg, registry=reg, scope_key="group:1", strict=False + ) + assert "[CQ:image" in result.delivery_text + assert len(result.pending_file_sends) == 1 + assert len(result.attachments) == 2 + + +@pytest.mark.asyncio +async def test_pic_tag_rejects_non_image(tmp_path: Path) -> None: + """ tag with file UID shows type error.""" + reg = _make_registry(tmp_path) + rec = await reg.register_bytes( + "group:1", _PDF_BYTES, kind="file", display_name="doc.pdf", source_kind="test" + ) + msg = f'' + result = await render_message_with_attachments( + msg, registry=reg, scope_key="group:1", strict=False + ) + assert "类型错误" in result.delivery_text + + +@pytest.mark.asyncio +async def test_pic_tag_rejects_non_image_strict(tmp_path: Path) -> None: + """ tag with file UID raises in strict mode.""" + reg = _make_registry(tmp_path) + rec = await reg.register_bytes( + "group:1", _PDF_BYTES, kind="file", display_name="doc.pdf", source_kind="test" + ) + msg = f'' + with pytest.raises(AttachmentRenderError, match="不是图片"): + await render_message_with_attachments( + msg, registry=reg, scope_key="group:1", strict=True + ) + + +@pytest.mark.asyncio +async def test_attachment_tag_allows_any_type(tmp_path: Path) -> None: + """ tag does NOT reject file UIDs (unlike ).""" + reg = _make_registry(tmp_path) + rec = await reg.register_bytes( + "group:1", _PDF_BYTES, kind="file", display_name="doc.pdf", source_kind="test" + ) + msg = f'' + result = await render_message_with_attachments( + msg, registry=reg, scope_key="group:1", strict=True + ) + assert "类型错误" not in result.delivery_text + assert len(result.pending_file_sends) == 1 + + +@pytest.mark.asyncio +async def test_invalid_uid_non_strict(tmp_path: Path) -> None: + """Unknown UID → placeholder in non-strict mode.""" + reg = _make_registry(tmp_path) + msg = '' + result = await render_message_with_attachments( + msg, registry=reg, scope_key="group:1", strict=False + ) + assert "不可用" in result.delivery_text + + +@pytest.mark.asyncio +async def test_invalid_uid_strict(tmp_path: Path) -> None: + """Unknown UID → exception in strict mode.""" + reg = _make_registry(tmp_path) + msg = '' + with pytest.raises(AttachmentRenderError, match="不可用"): + await render_message_with_attachments( + msg, registry=reg, scope_key="group:1", strict=True + ) + + +@pytest.mark.asyncio +async def test_file_tag_missing_local_path(tmp_path: Path) -> None: + """File with deleted local_path → error placeholder.""" + reg = _make_registry(tmp_path) + rec = await reg.register_bytes( + "group:1", _PDF_BYTES, kind="file", display_name="doc.pdf", source_kind="test" + ) + assert rec.local_path is not None + Path(rec.local_path).unlink() + + msg = f'' + result = await render_message_with_attachments( + msg, registry=reg, scope_key="group:1", strict=False + ) + assert "缺少本地文件" in result.delivery_text + assert len(result.pending_file_sends) == 0 + + +@pytest.mark.asyncio +async def test_no_tags_passthrough() -> None: + """Message without tags passes through unchanged.""" + result = await render_message_with_attachments( + "Hello world", registry=None, scope_key="group:1", strict=False + ) + assert result.delivery_text == "Hello world" + assert result.history_text == "Hello world" + assert result.attachments == [] + assert result.pending_file_sends == () + + +@pytest.mark.asyncio +async def test_rendered_rich_message_default_pending() -> None: + """RenderedRichMessage.pending_file_sends defaults to empty tuple.""" + msg = RenderedRichMessage(delivery_text="hi", history_text="hi", attachments=[]) + assert msg.pending_file_sends == () + + +# ---------- dispatch_pending_file_sends ---------- + + +@pytest.mark.asyncio +async def test_dispatch_pending_file_sends_group(tmp_path: Path) -> None: + """dispatch_pending_file_sends calls sender.send_group_file for group targets.""" + reg = _make_registry(tmp_path) + rec = await reg.register_bytes( + "group:1", _PDF_BYTES, kind="file", display_name="doc.pdf", source_kind="test" + ) + rendered = RenderedRichMessage( + delivery_text="text", + history_text="text", + attachments=[], + pending_file_sends=(rec,), + ) + + calls: list[tuple[Any, ...]] = [] + + class FakeSender: + async def send_group_file( + self, group_id: int, file_path: str, name: str | None = None + ) -> None: + calls.append(("group", group_id, file_path, name)) + + async def send_private_file( + self, user_id: int, file_path: str, name: str | None = None + ) -> None: + calls.append(("private", user_id, file_path, name)) + + await dispatch_pending_file_sends( + rendered, sender=FakeSender(), target_type="group", target_id=12345 + ) + assert len(calls) == 1 + assert calls[0][0] == "group" + assert calls[0][1] == 12345 + + +@pytest.mark.asyncio +async def test_dispatch_pending_file_sends_private(tmp_path: Path) -> None: + """dispatch_pending_file_sends calls sender.send_private_file for private targets.""" + reg = _make_registry(tmp_path) + rec = await reg.register_bytes( + "group:1", + _PDF_BYTES, + kind="file", + display_name="report.pdf", + source_kind="test", + ) + rendered = RenderedRichMessage( + delivery_text="text", + history_text="text", + attachments=[], + pending_file_sends=(rec,), + ) + + calls: list[tuple[Any, ...]] = [] + + class FakeSender: + async def send_group_file(self, *a: Any, **kw: Any) -> None: + calls.append(("group", *a)) + + async def send_private_file(self, *a: Any, **kw: Any) -> None: + calls.append(("private", *a)) + + await dispatch_pending_file_sends( + rendered, sender=FakeSender(), target_type="private", target_id=99999 + ) + assert len(calls) == 1 + assert calls[0][0] == "private" + assert calls[0][1] == 99999 + + +@pytest.mark.asyncio +async def test_dispatch_best_effort_on_failure(tmp_path: Path) -> None: + """File send failure doesn't propagate — best-effort.""" + reg = _make_registry(tmp_path) + rec = await reg.register_bytes( + "group:1", _PDF_BYTES, kind="file", display_name="doc.pdf", source_kind="test" + ) + rendered = RenderedRichMessage( + delivery_text="text", + history_text="text", + attachments=[], + pending_file_sends=(rec,), + ) + + class FailingSender: + async def send_group_file(self, *a: Any, **kw: Any) -> None: + raise RuntimeError("network error") + + async def send_private_file(self, *a: Any, **kw: Any) -> None: + raise RuntimeError("network error") + + # Should not raise + await dispatch_pending_file_sends( + rendered, sender=FailingSender(), target_type="group", target_id=1 + ) + + +@pytest.mark.asyncio +async def test_dispatch_no_pending_is_noop() -> None: + """No pending files → no calls.""" + rendered = RenderedRichMessage( + delivery_text="text", history_text="text", attachments=[] + ) + await dispatch_pending_file_sends( + rendered, sender=None, target_type="group", target_id=1 + ) diff --git a/tests/test_attachments_dedup.py b/tests/test_attachments_dedup.py new file mode 100644 index 00000000..c1114bf5 --- /dev/null +++ b/tests/test_attachments_dedup.py @@ -0,0 +1,147 @@ +"""Tests for attachment SHA-256 hash deduplication.""" + +from __future__ import annotations + +import asyncio +from pathlib import Path + +import pytest + +from Undefined.attachments import AttachmentRegistry + + +_PNG_BYTES = ( + b"\x89PNG\r\n\x1a\n" + b"\x00\x00\x00\rIHDR" + b"\x00\x00\x00\x01\x00\x00\x00\x01\x08\x02\x00\x00\x00" + b"\x90wS\xde" + b"\x00\x00\x00\x0cIDATx\x9cc``\x00\x00\x00\x02\x00\x01" + b"\x0b\xe7\x02\x9d" + b"\x00\x00\x00\x00IEND\xaeB`\x82" +) + +# Different content → different hash +_PNG_BYTES_ALT = _PNG_BYTES + b"\x00" + + +def _make_registry(tmp_path: Path) -> AttachmentRegistry: + return AttachmentRegistry( + registry_path=tmp_path / "reg.json", + cache_dir=tmp_path / "cache", + ) + + +@pytest.mark.asyncio +async def test_same_hash_same_scope_same_kind_returns_same_uid( + tmp_path: Path, +) -> None: + """Identical bytes + scope + kind → dedup returns same record.""" + reg = _make_registry(tmp_path) + r1 = await reg.register_bytes( + "group:1", _PNG_BYTES, kind="image", display_name="a.png", source_kind="test" + ) + r2 = await reg.register_bytes( + "group:1", _PNG_BYTES, kind="image", display_name="b.png", source_kind="test" + ) + assert r1.uid == r2.uid + assert r1.sha256 == r2.sha256 + + +@pytest.mark.asyncio +async def test_different_content_gets_different_uid(tmp_path: Path) -> None: + """Different bytes → different records even in same scope/kind.""" + reg = _make_registry(tmp_path) + r1 = await reg.register_bytes( + "group:1", _PNG_BYTES, kind="image", display_name="a.png", source_kind="test" + ) + r2 = await reg.register_bytes( + "group:1", + _PNG_BYTES_ALT, + kind="image", + display_name="b.png", + source_kind="test", + ) + assert r1.uid != r2.uid + + +@pytest.mark.asyncio +async def test_cross_scope_no_dedup(tmp_path: Path) -> None: + """Same hash but different scope → separate records (scope isolation).""" + reg = _make_registry(tmp_path) + r1 = await reg.register_bytes( + "group:1", _PNG_BYTES, kind="image", display_name="a.png", source_kind="test" + ) + r2 = await reg.register_bytes( + "group:2", _PNG_BYTES, kind="image", display_name="a.png", source_kind="test" + ) + assert r1.uid != r2.uid + + +@pytest.mark.asyncio +async def test_cross_kind_no_dedup(tmp_path: Path) -> None: + """Same hash + scope but different kind → separate records.""" + reg = _make_registry(tmp_path) + r1 = await reg.register_bytes( + "group:1", _PNG_BYTES, kind="image", display_name="a.png", source_kind="test" + ) + r2 = await reg.register_bytes( + "group:1", _PNG_BYTES, kind="file", display_name="a.bin", source_kind="test" + ) + assert r1.uid != r2.uid + assert r1.uid.startswith("pic_") + assert r2.uid.startswith("file_") + + +@pytest.mark.asyncio +async def test_file_deleted_causes_new_registration(tmp_path: Path) -> None: + """If the cached file is deleted, a new record is created.""" + reg = _make_registry(tmp_path) + r1 = await reg.register_bytes( + "group:1", _PNG_BYTES, kind="image", display_name="a.png", source_kind="test" + ) + assert r1.local_path is not None + Path(r1.local_path).unlink() + + r2 = await reg.register_bytes( + "group:1", _PNG_BYTES, kind="image", display_name="a.png", source_kind="test" + ) + assert r2.uid != r1.uid + assert r2.sha256 == r1.sha256 + + +@pytest.mark.asyncio +async def test_concurrent_identical_registrations(tmp_path: Path) -> None: + """Concurrent registrations with the same content should be safe.""" + reg = _make_registry(tmp_path) + + results = await asyncio.gather( + *( + reg.register_bytes( + "group:1", + _PNG_BYTES, + kind="image", + display_name="pic.png", + source_kind="test", + ) + for _ in range(5) + ) + ) + uids = {r.uid for r in results} + # All should resolve to the same UID (dedup) + assert len(uids) == 1 + + +@pytest.mark.asyncio +async def test_register_local_file_deduplicates(tmp_path: Path) -> None: + """register_local_file delegates to register_bytes, so dedup applies.""" + reg = _make_registry(tmp_path) + file_path = tmp_path / "input.png" + file_path.write_bytes(_PNG_BYTES) + + r1 = await reg.register_bytes( + "group:1", _PNG_BYTES, kind="image", display_name="a.png", source_kind="test" + ) + r2 = await reg.register_local_file( + "group:1", file_path, kind="image", display_name="input.png" + ) + assert r1.uid == r2.uid diff --git a/tests/test_cache_cleanup.py b/tests/test_cache_cleanup.py new file mode 100644 index 00000000..72a0f086 --- /dev/null +++ b/tests/test_cache_cleanup.py @@ -0,0 +1,105 @@ +"""Tests for Undefined.utils.cache.cleanup_cache_dir.""" + +from __future__ import annotations + +import os +import time +from pathlib import Path + +from Undefined.utils.cache import cleanup_cache_dir + + +class TestCleanupCacheDir: + def test_empty_dir(self, tmp_path: Path) -> None: + assert cleanup_cache_dir(tmp_path) == 0 + + def test_old_files_removed(self, tmp_path: Path) -> None: + old_file = tmp_path / "old.txt" + old_file.write_text("data") + # Set mtime to 30 days ago + old_time = time.time() - 30 * 24 * 3600 + os.utime(old_file, (old_time, old_time)) + + deleted = cleanup_cache_dir(tmp_path, max_age_seconds=7 * 24 * 3600) + assert deleted == 1 + assert not old_file.exists() + + def test_new_files_kept(self, tmp_path: Path) -> None: + new_file = tmp_path / "new.txt" + new_file.write_text("data") + # mtime = now (default), so it's fresh + + deleted = cleanup_cache_dir(tmp_path, max_age_seconds=7 * 24 * 3600) + assert deleted == 0 + assert new_file.exists() + + def test_max_files_cap(self, tmp_path: Path) -> None: + now = time.time() + for i in range(5): + f = tmp_path / f"file_{i}.txt" + f.write_text("data") + os.utime(f, (now - i, now - i)) # stagger mtime + + deleted = cleanup_cache_dir(tmp_path, max_age_seconds=0, max_files=3) + assert deleted == 2 + remaining = list(tmp_path.iterdir()) + assert len(remaining) == 3 + + def test_nonexistent_dir_created(self, tmp_path: Path) -> None: + new_dir = tmp_path / "subdir" / "cache" + assert not new_dir.exists() + deleted = cleanup_cache_dir(new_dir) + assert deleted == 0 + assert new_dir.is_dir() + + def test_mixed_ages(self, tmp_path: Path) -> None: + now = time.time() + # 1 old, 2 fresh + old_f = tmp_path / "old.txt" + old_f.write_text("old") + os.utime(old_f, (now - 999999, now - 999999)) + + for i in range(2): + f = tmp_path / f"fresh_{i}.txt" + f.write_text("fresh") + + deleted = cleanup_cache_dir(tmp_path, max_age_seconds=7 * 24 * 3600) + assert deleted == 1 + assert not old_f.exists() + + def test_zero_max_age_skips_age_check(self, tmp_path: Path) -> None: + old_file = tmp_path / "old.txt" + old_file.write_text("data") + old_time = time.time() - 999999 + os.utime(old_file, (old_time, old_time)) + + deleted = cleanup_cache_dir(tmp_path, max_age_seconds=0, max_files=0) + assert deleted == 0 + assert old_file.exists() + + def test_zero_max_files_skips_cap(self, tmp_path: Path) -> None: + for i in range(10): + (tmp_path / f"f{i}.txt").write_text("x") + + deleted = cleanup_cache_dir(tmp_path, max_age_seconds=0, max_files=0) + assert deleted == 0 + assert len(list(tmp_path.iterdir())) == 10 + + def test_both_age_and_cap(self, tmp_path: Path) -> None: + now = time.time() + # Create 5 files: 2 old (removed by age), 3 fresh + for i in range(2): + f = tmp_path / f"old_{i}.txt" + f.write_text("old") + os.utime(f, (now - 999999, now - 999999)) + for i in range(3): + f = tmp_path / f"new_{i}.txt" + f.write_text("new") + os.utime(f, (now - i, now - i)) + + deleted = cleanup_cache_dir( + tmp_path, max_age_seconds=7 * 24 * 3600, max_files=2 + ) + # 2 removed by age + 1 removed by cap = 3 + assert deleted == 3 + assert len(list(tmp_path.iterdir())) == 2 diff --git a/tests/test_calculator.py b/tests/test_calculator.py new file mode 100644 index 00000000..d0e6d2f7 --- /dev/null +++ b/tests/test_calculator.py @@ -0,0 +1,278 @@ +"""calculator 工具单元测试。""" + +from __future__ import annotations + +import json +import math +from pathlib import Path +from typing import Any + +import pytest + +from Undefined.skills.tools.calculator.handler import safe_eval + +TOOL_DIR = ( + Path(__file__).resolve().parent.parent + / "src" + / "Undefined" + / "skills" + / "tools" + / "calculator" +) + + +# --------------------------------------------------------------------------- +# 配置文件检查 +# --------------------------------------------------------------------------- + + +class TestConfig: + def test_config_json(self) -> None: + cfg: dict[str, Any] = json.loads( + (TOOL_DIR / "config.json").read_text(encoding="utf-8") + ) + assert cfg["function"]["name"] == "calculator" + assert "expression" in cfg["function"]["parameters"]["properties"] + + def test_callable_json(self) -> None: + cfg: dict[str, Any] = json.loads( + (TOOL_DIR / "callable.json").read_text(encoding="utf-8") + ) + assert cfg["enabled"] is True + assert "*" in cfg["allowed_callers"] + + +# --------------------------------------------------------------------------- +# safe_eval 单元测试 +# --------------------------------------------------------------------------- + + +class TestArithmetic: + def test_addition(self) -> None: + assert safe_eval("2+3") == "5" + + def test_subtraction(self) -> None: + assert safe_eval("10-3") == "7" + + def test_multiplication(self) -> None: + assert safe_eval("4*5") == "20" + + def test_division(self) -> None: + assert safe_eval("10/3") == "3.333333333" + + def test_floor_division(self) -> None: + assert safe_eval("10//3") == "3" + + def test_modulo(self) -> None: + assert safe_eval("10%3") == "1" + + def test_power(self) -> None: + assert safe_eval("2**10") == "1024" + + def test_caret_as_power(self) -> None: + assert safe_eval("2^10") == "1024" + + def test_negative_number(self) -> None: + assert safe_eval("-5+3") == "-2" + + def test_complex_expression(self) -> None: + assert safe_eval("(2+3)*4-1") == "19" + + def test_floating_point(self) -> None: + assert safe_eval("0.1+0.2") == "0.3" + + +class TestConstants: + def test_pi(self) -> None: + result = float(safe_eval("pi")) + assert abs(result - math.pi) < 1e-8 + + def test_e(self) -> None: + result = float(safe_eval("e")) + assert abs(result - math.e) < 1e-8 + + def test_tau(self) -> None: + result = float(safe_eval("tau")) + assert abs(result - math.tau) < 1e-8 + + +class TestScientificFunctions: + def test_sqrt(self) -> None: + assert safe_eval("sqrt(144)") == "12" + + def test_sin(self) -> None: + result = float(safe_eval("sin(pi/6)")) + assert abs(result - 0.5) < 1e-10 + + def test_cos(self) -> None: + result = float(safe_eval("cos(0)")) + assert abs(result - 1.0) < 1e-10 + + def test_log10(self) -> None: + assert safe_eval("log10(1000)") == "3" + + def test_log_with_base(self) -> None: + assert safe_eval("log(8, 2)") == "3" + + def test_ln(self) -> None: + result = float(safe_eval("ln(e)")) + assert abs(result - 1.0) < 1e-10 + + def test_factorial(self) -> None: + assert safe_eval("factorial(10)") == "3628800" + + def test_degrees(self) -> None: + result = float(safe_eval("degrees(pi)")) + assert abs(result - 180.0) < 1e-10 + + def test_radians(self) -> None: + result = float(safe_eval("radians(180)")) + assert abs(result - math.pi) < 1e-8 + + def test_ceil(self) -> None: + assert safe_eval("ceil(3.2)") == "4" + + def test_floor(self) -> None: + assert safe_eval("floor(3.8)") == "3" + + def test_abs(self) -> None: + assert safe_eval("abs(-42)") == "42" + + def test_gcd(self) -> None: + assert safe_eval("gcd(48, 18)") == "6" + + def test_lcm(self) -> None: + assert safe_eval("lcm(4, 6)") == "12" + + def test_comb(self) -> None: + assert safe_eval("comb(10, 3)") == "120" + + def test_perm(self) -> None: + assert safe_eval("perm(5, 3)") == "60" + + def test_factorial_too_large(self) -> None: + with pytest.raises(ValueError, match="参数过大"): + safe_eval("factorial(9999)") + + def test_comb_too_large(self) -> None: + with pytest.raises(ValueError, match="参数过大"): + safe_eval("comb(9999, 5000)") + + def test_perm_too_large(self) -> None: + with pytest.raises(ValueError, match="参数过大"): + safe_eval("perm(9999, 5000)") + + def test_factorial_at_limit(self) -> None: + """factorial(1000) should succeed (within limit).""" + result = safe_eval("factorial(1000)") + assert int(result) > 0 + + def test_hypot(self) -> None: + assert safe_eval("hypot(3, 4)") == "5" + + +class TestStatistics: + def test_mean(self) -> None: + assert safe_eval("mean(1, 2, 3, 4, 5)") == "3" + + def test_median(self) -> None: + assert safe_eval("median(1, 3, 5, 7, 9)") == "5" + + def test_stdev(self) -> None: + result = float(safe_eval("stdev(2, 4, 4, 4, 5, 5, 7, 9)")) + assert result > 0 + + def test_min_max(self) -> None: + assert safe_eval("min(3, 1, 4, 1, 5)") == "1" + assert safe_eval("max(3, 1, 4, 1, 5)") == "5" + + def test_sum(self) -> None: + assert safe_eval("sum([1, 2, 3, 4, 5])") == "15" + + +class TestComparison: + def test_equal(self) -> None: + assert safe_eval("2+2 == 4") == "True" + + def test_not_equal(self) -> None: + assert safe_eval("2+2 != 5") == "True" + + def test_less_than(self) -> None: + assert safe_eval("3 < 5") == "True" + + def test_greater_than(self) -> None: + assert safe_eval("5 > 3") == "True" + + +class TestIfExpression: + def test_ternary(self) -> None: + assert safe_eval("42 if 2>1 else 0") == "42" + + +class TestSafety: + def test_rejects_import(self) -> None: + with pytest.raises(ValueError, match="未知函数"): + safe_eval("__import__('os')") + + def test_rejects_attribute_access(self) -> None: + with pytest.raises(ValueError, match="不支持"): + safe_eval("math.pi") + + def test_rejects_unknown_function(self) -> None: + with pytest.raises(ValueError, match="未知函数"): + safe_eval("eval('1+1')") + + def test_rejects_unknown_variable(self) -> None: + with pytest.raises(ValueError, match="未知变量"): + safe_eval("x + 1") + + def test_rejects_too_large_exponent(self) -> None: + with pytest.raises(ValueError, match="指数过大"): + safe_eval("2**100000") + + def test_rejects_too_long_expression(self) -> None: + with pytest.raises(ValueError, match="表达式过长"): + safe_eval("1+" * 300 + "1") + + def test_rejects_empty(self) -> None: + with pytest.raises(ValueError, match="表达式为空"): + safe_eval("") + + def test_division_by_zero(self) -> None: + with pytest.raises(ZeroDivisionError): + safe_eval("1/0") + + +# --------------------------------------------------------------------------- +# execute() 集成测试 +# --------------------------------------------------------------------------- + + +class TestExecute: + @pytest.mark.asyncio + async def test_basic_calculation(self) -> None: + from Undefined.skills.tools.calculator.handler import execute + + result = await execute({"expression": "2+3*4"}, {}) + assert "= 14" in result + + @pytest.mark.asyncio + async def test_empty_expression(self) -> None: + from Undefined.skills.tools.calculator.handler import execute + + result = await execute({"expression": ""}, {}) + assert "请提供" in result + + @pytest.mark.asyncio + async def test_error_message(self) -> None: + from Undefined.skills.tools.calculator.handler import execute + + result = await execute({"expression": "1/0"}, {}) + assert "除以零" in result + + @pytest.mark.asyncio + async def test_syntax_error(self) -> None: + from Undefined.skills.tools.calculator.handler import execute + + result = await execute({"expression": "2+++"}, {}) + assert "语法错误" in result diff --git a/tests/test_config_cognitive_historian_limits.py b/tests/test_config_cognitive_historian_limits.py index c7f843f6..bb58f56b 100644 --- a/tests/test_config_cognitive_historian_limits.py +++ b/tests/test_config_cognitive_historian_limits.py @@ -1,10 +1,10 @@ from __future__ import annotations -from Undefined.config.loader import Config +from Undefined.config.domain_parsers import _parse_cognitive_config def test_parse_cognitive_historian_reference_limits() -> None: - cfg = Config._parse_cognitive_config( + cfg = _parse_cognitive_config( { "cognitive": { "query": { @@ -32,7 +32,7 @@ def test_parse_cognitive_historian_reference_limits() -> None: def test_parse_cognitive_historian_reference_limits_defaults() -> None: - cfg = Config._parse_cognitive_config({"cognitive": {}}) + cfg = _parse_cognitive_config({"cognitive": {}}) assert cfg.historian_recent_messages_inject_k == 12 assert cfg.historian_recent_message_line_max_len == 240 diff --git a/tests/test_config_easter_egg_repeat.py b/tests/test_config_easter_egg_repeat.py new file mode 100644 index 00000000..e27ae051 --- /dev/null +++ b/tests/test_config_easter_egg_repeat.py @@ -0,0 +1,80 @@ +"""Config 加载:[easter_egg] repeat_enabled / inverted_question_enabled""" + +from __future__ import annotations + +from pathlib import Path + +from Undefined.config.loader import Config + + +def _load(tmp_path: Path, text: str) -> Config: + p = tmp_path / "config.toml" + p.write_text(text, "utf-8") + return Config.load(p, strict=False) + + +_MINIMAL = """ +[onebot] +ws_url = "ws://127.0.0.1:3001" +[models.chat] +api_url = "https://api.example/v1" +api_key = "sk-test" +model_name = "gpt-test" +""" + + +def test_repeat_defaults_to_false(tmp_path: Path) -> None: + cfg = _load(tmp_path, _MINIMAL) + assert cfg.repeat_enabled is False + assert cfg.inverted_question_enabled is False + assert cfg.repeat_cooldown_minutes == 60 + + +def test_repeat_enabled_explicit(tmp_path: Path) -> None: + cfg = _load(tmp_path, _MINIMAL + "\n[easter_egg]\nrepeat_enabled = true\n") + assert cfg.repeat_enabled is True + assert cfg.inverted_question_enabled is False + assert cfg.repeat_cooldown_minutes == 60 + + +def test_inverted_question_enabled_explicit(tmp_path: Path) -> None: + cfg = _load( + tmp_path, + _MINIMAL + + "\n[easter_egg]\nrepeat_enabled = true\ninverted_question_enabled = true\n", + ) + assert cfg.repeat_enabled is True + assert cfg.inverted_question_enabled is True + + +def test_inverted_question_without_repeat(tmp_path: Path) -> None: + cfg = _load( + tmp_path, + _MINIMAL + "\n[easter_egg]\ninverted_question_enabled = true\n", + ) + assert cfg.repeat_enabled is False + assert cfg.inverted_question_enabled is True + + +def test_keyword_reply_still_parsed_from_easter_egg(tmp_path: Path) -> None: + cfg = _load( + tmp_path, + _MINIMAL + "\n[easter_egg]\nkeyword_reply_enabled = true\n", + ) + assert cfg.keyword_reply_enabled is True + + +def test_repeat_cooldown_custom_value(tmp_path: Path) -> None: + cfg = _load( + tmp_path, + _MINIMAL + "\n[easter_egg]\nrepeat_cooldown_minutes = 30\n", + ) + assert cfg.repeat_cooldown_minutes == 30 + + +def test_repeat_cooldown_negative_clamped_to_zero(tmp_path: Path) -> None: + cfg = _load( + tmp_path, + _MINIMAL + "\n[easter_egg]\nrepeat_cooldown_minutes = -5\n", + ) + assert cfg.repeat_cooldown_minutes == 0 diff --git a/tests/test_config_hot_reload.py b/tests/test_config_hot_reload.py index a2441627..4e0158ae 100644 --- a/tests/test_config_hot_reload.py +++ b/tests/test_config_hot_reload.py @@ -62,6 +62,10 @@ def test_apply_config_updates_propagates_to_security_service() -> None: model_name="grok", queue_interval_seconds=1.0, ), + historian_model=SimpleNamespace( + model_name="historian", + queue_interval_seconds=1.0, + ), ), ) security_service = _FakeSecurityService() @@ -120,6 +124,10 @@ def test_apply_config_updates_hot_reloads_ai_request_max_retries() -> None: model_name="grok", queue_interval_seconds=1.0, ), + historian_model=SimpleNamespace( + model_name="historian", + queue_interval_seconds=1.0, + ), ), ) security_service = _FakeSecurityService() diff --git a/tests/test_config_models.py b/tests/test_config_models.py new file mode 100644 index 00000000..b4bf05ff --- /dev/null +++ b/tests/test_config_models.py @@ -0,0 +1,53 @@ +"""Tests for Undefined.config.models — config model helpers.""" + +from __future__ import annotations + +from Undefined.config.models import format_netloc, resolve_bind_hosts + + +class TestFormatNetloc: + def test_ipv4(self) -> None: + assert format_netloc("127.0.0.1", 8080) == "127.0.0.1:8080" + + def test_hostname(self) -> None: + assert format_netloc("example.com", 443) == "example.com:443" + + def test_ipv6_wrapped(self) -> None: + assert format_netloc("::1", 8080) == "[::1]:8080" + + def test_ipv6_full(self) -> None: + result = format_netloc("2001:db8::1", 9090) + assert result == "[2001:db8::1]:9090" + + def test_ipv6_all_zeros(self) -> None: + assert format_netloc("::", 80) == "[::]:80" + + def test_ipv4_default_port(self) -> None: + assert format_netloc("0.0.0.0", 80) == "0.0.0.0:80" + + def test_localhost(self) -> None: + assert format_netloc("localhost", 3000) == "localhost:3000" + + def test_empty_host(self) -> None: + # No colon in empty string → treated as IPv4-style + assert format_netloc("", 8080) == ":8080" + + +class TestResolveBindHosts: + def test_empty_string(self) -> None: + assert resolve_bind_hosts("") == ["0.0.0.0", "::"] + + def test_double_colon(self) -> None: + assert resolve_bind_hosts("::") == ["0.0.0.0", "::"] + + def test_ipv4_any(self) -> None: + assert resolve_bind_hosts("0.0.0.0") == ["0.0.0.0"] + + def test_specific_ipv4(self) -> None: + assert resolve_bind_hosts("127.0.0.1") == ["127.0.0.1"] + + def test_specific_ipv6(self) -> None: + assert resolve_bind_hosts("::1") == ["::1"] + + def test_hostname(self) -> None: + assert resolve_bind_hosts("myhost.local") == ["myhost.local"] diff --git a/tests/test_config_request_params.py b/tests/test_config_request_params.py index eaa002d8..ace8b115 100644 --- a/tests/test_config_request_params.py +++ b/tests/test_config_request_params.py @@ -90,6 +90,15 @@ def test_model_request_params_load_inherit_and_new_transport_fields( temperature = 0.1 metadata = { source = "historian" } +[models.summary] +model_name = "gpt-summary" +api_mode = "chat_completions" +reasoning_effort = "xhigh" + +[models.summary.request_params] +temperature = 0.15 +metadata = { source = "summary" } + [models.grok] api_url = "https://grok.example/v1" api_key = "sk-grok" @@ -207,6 +216,18 @@ def test_model_request_params_load_inherit_and_new_transport_fields( "metadata": {"source": "historian"}, "response_format": {"type": "json_object"}, } + assert cfg.summary_model.api_mode == "chat_completions" + assert cfg.summary_model.reasoning_enabled is True + assert cfg.summary_model.reasoning_effort == "xhigh" + assert cfg.summary_model.thinking_tool_call_compat is True + assert cfg.summary_model.responses_tool_choice_compat is True + assert cfg.summary_model.responses_force_stateless_replay is True + assert cfg.summary_model.prompt_cache_enabled is False + assert cfg.summary_model.request_params == { + "temperature": 0.15, + "metadata": {"source": "summary"}, + "response_format": {"type": "json_object"}, + } assert cfg.grok_model.reasoning_enabled is True assert cfg.grok_model.reasoning_effort == "low" assert cfg.grok_model.prompt_cache_enabled is True diff --git a/tests/test_context.py b/tests/test_context.py new file mode 100644 index 00000000..e68a0d86 --- /dev/null +++ b/tests/test_context.py @@ -0,0 +1,195 @@ +"""Tests for Undefined.context module.""" + +from __future__ import annotations + +import logging + +import pytest + +from Undefined.context import ( + RequestContext, + RequestContextFilter, + get_group_id, + get_request_id, + get_request_type, + get_sender_id, + get_user_id, +) + + +class TestRequestContextManager: + """Tests for RequestContext async context manager.""" + + async def test_enter_sets_current(self) -> None: + async with RequestContext(request_type="group", group_id=123) as ctx: + assert RequestContext.current() is ctx + + async def test_exit_clears_current(self) -> None: + async with RequestContext(request_type="group"): + pass + assert RequestContext.current() is None + + async def test_request_id_generated(self) -> None: + async with RequestContext(request_type="private") as ctx: + assert ctx.request_id is not None + assert len(ctx.request_id) > 0 + + async def test_request_id_is_uuid(self) -> None: + import uuid + + async with RequestContext(request_type="private") as ctx: + uuid.UUID(ctx.request_id) + + async def test_nested_contexts(self) -> None: + async with RequestContext(request_type="group", group_id=1) as outer: + assert RequestContext.current() is outer + async with RequestContext(request_type="private", user_id=99) as inner: + assert RequestContext.current() is inner + assert inner.user_id == 99 + assert RequestContext.current() is outer + assert outer.group_id == 1 + + async def test_metadata(self) -> None: + async with RequestContext(request_type="api", extra_key="value") as ctx: + assert ctx.metadata["extra_key"] == "value" + + +class TestRequestContextResources: + """Tests for resource management.""" + + async def test_set_and_get_resource(self) -> None: + async with RequestContext(request_type="group") as ctx: + ctx.set_resource("sender", {"name": "test"}) + assert ctx.get_resource("sender") == {"name": "test"} + + async def test_get_missing_resource_default(self) -> None: + async with RequestContext(request_type="group") as ctx: + assert ctx.get_resource("missing") is None + assert ctx.get_resource("missing", "fallback") == "fallback" + + async def test_resources_cleared_on_exit(self) -> None: + ctx = RequestContext(request_type="group") + async with ctx: + ctx.set_resource("key", "value") + assert ctx.get_resource("key") is None + + async def test_get_resources_returns_copy(self) -> None: + async with RequestContext(request_type="group") as ctx: + ctx.set_resource("a", 1) + ctx.set_resource("b", 2) + resources = ctx.get_resources() + assert resources == {"a": 1, "b": 2} + resources["c"] = 3 + assert ctx.get_resource("c") is None + + +class TestRequireContext: + """Tests for RequestContext.require().""" + + async def test_require_inside_context(self) -> None: + async with RequestContext(request_type="group") as ctx: + assert RequestContext.require() is ctx + + async def test_require_outside_context_raises(self) -> None: + with pytest.raises(RuntimeError): + RequestContext.require() + + +class TestHelperFunctions: + """Tests for module-level helper functions.""" + + async def test_get_group_id_inside_context(self) -> None: + async with RequestContext(request_type="group", group_id=42): + assert get_group_id() == 42 + + async def test_get_group_id_outside_context(self) -> None: + assert get_group_id() is None + + async def test_get_user_id_inside_context(self) -> None: + async with RequestContext(request_type="private", user_id=7): + assert get_user_id() == 7 + + async def test_get_user_id_outside_context(self) -> None: + assert get_user_id() is None + + async def test_get_request_id_inside_context(self) -> None: + async with RequestContext(request_type="group"): + rid = get_request_id() + assert rid is not None + assert len(rid) > 0 + + async def test_get_request_id_outside_context(self) -> None: + assert get_request_id() is None + + async def test_get_sender_id_inside_context(self) -> None: + async with RequestContext(request_type="group", sender_id=100): + assert get_sender_id() == 100 + + async def test_get_sender_id_outside_context(self) -> None: + assert get_sender_id() is None + + async def test_get_request_type_inside_context(self) -> None: + async with RequestContext(request_type="private"): + assert get_request_type() == "private" + + async def test_get_request_type_outside_context(self) -> None: + assert get_request_type() is None + + +class TestRequestContextFilter: + """Tests for RequestContextFilter logging filter.""" + + async def test_filter_with_context(self) -> None: + filt = RequestContextFilter() + record = logging.LogRecord( + name="test", + level=logging.INFO, + pathname="", + lineno=0, + msg="test message", + args=(), + exc_info=None, + ) + async with RequestContext( + request_type="group", group_id=10, user_id=20, sender_id=30 + ) as ctx: + result = filt.filter(record) + assert result is True + assert getattr(record, "request_id") == ctx.request_id[:8] + assert getattr(record, "group_id") == 10 + assert getattr(record, "user_id") == 20 + assert getattr(record, "sender_id") == 30 + + def test_filter_without_context(self) -> None: + filt = RequestContextFilter() + record = logging.LogRecord( + name="test", + level=logging.INFO, + pathname="", + lineno=0, + msg="test", + args=(), + exc_info=None, + ) + result = filt.filter(record) + assert result is True + assert getattr(record, "request_id") == "-" + assert getattr(record, "group_id") == "-" + assert getattr(record, "user_id") == "-" + assert getattr(record, "sender_id") == "-" + + async def test_filter_partial_context(self) -> None: + filt = RequestContextFilter() + record = logging.LogRecord( + name="test", + level=logging.INFO, + pathname="", + lineno=0, + msg="test", + args=(), + exc_info=None, + ) + async with RequestContext(request_type="private", user_id=5): + filt.filter(record) + assert getattr(record, "group_id") == "-" + assert getattr(record, "user_id") == 5 diff --git a/tests/test_coordinator_level.py b/tests/test_coordinator_level.py new file mode 100644 index 00000000..38ecaa68 --- /dev/null +++ b/tests/test_coordinator_level.py @@ -0,0 +1,155 @@ +from __future__ import annotations + +from unittest.mock import MagicMock + + +from Undefined.services.ai_coordinator import AICoordinator + + +def _make_coordinator() -> AICoordinator: + """创建用于测试的 AICoordinator 实例""" + config = MagicMock() + ai = MagicMock() + queue_manager = MagicMock() + history_manager = MagicMock() + sender = MagicMock() + onebot = MagicMock() + scheduler = MagicMock() + security = MagicMock() + + coordinator = AICoordinator( + config=config, + ai=ai, + queue_manager=queue_manager, + history_manager=history_manager, + sender=sender, + onebot=onebot, + scheduler=scheduler, + security=security, + ) + + return coordinator + + +def test_build_prompt_with_level_includes_level_attribute() -> None: + """测试 _build_prompt 带 level 参数时 XML 包含 level 属性""" + coordinator = _make_coordinator() + + result = coordinator._build_prompt( + prefix="", + name="测试用户", + uid=10001, + gid=20001, + gname="测试群", + loc="测试群", + role="member", + title="", + time_str="2026-04-11 10:00:00", + text="测试消息", + attachments=None, + message_id=123456, + level="Lv.5", + ) + + assert 'level="Lv.5"' in result + assert " None: + """测试 _build_prompt level 为空字符串时 XML 不包含 level 属性""" + coordinator = _make_coordinator() + + result = coordinator._build_prompt( + prefix="", + name="测试用户", + uid=10001, + gid=20001, + gname="测试群", + loc="测试群", + role="member", + title="", + time_str="2026-04-11 10:00:00", + text="测试消息", + level="", + ) + + assert "level=" not in result + assert " None: + """测试 _build_prompt 不传 level 参数时 XML 不包含 level 属性""" + coordinator = _make_coordinator() + + result = coordinator._build_prompt( + prefix="", + name="测试用户", + uid=10001, + gid=20001, + gname="测试群", + loc="测试群", + role="member", + title="", + time_str="2026-04-11 10:00:00", + text="测试消息", + ) + + assert "level=" not in result + assert " None: + """测试 _build_prompt level 包含特殊字符时能正确转义""" + coordinator = _make_coordinator() + + result = coordinator._build_prompt( + prefix="", + name="测试用户", + uid=10001, + gid=20001, + gname="测试群", + loc="测试群", + role="member", + title="", + time_str="2026-04-11 10:00:00", + text="测试消息", + level='Lv.5 ', + ) + + assert "level=" in result + assert '' not in result + assert "<" in result or "&" in result or """ in result + + +def test_build_prompt_with_all_attributes() -> None: + """测试 _build_prompt 包含所有属性时的输出""" + coordinator = _make_coordinator() + + result = coordinator._build_prompt( + prefix="系统提示:\n", + name="测试用户", + uid=10001, + gid=20001, + gname="测试群", + loc="测试群", + role="admin", + title="管理员", + time_str="2026-04-11 10:00:00", + text="测试消息内容", + attachments=[{"uid": "pic_001", "kind": "image"}], + message_id=123456, + level="Lv.10", + ) + + assert "系统提示:\n" in result + assert 'sender="测试用户"' in result + assert 'sender_id="10001"' in result + assert 'group_id="20001"' in result + assert 'group_name="测试群"' in result + assert 'role="admin"' in result + assert 'title="管理员"' in result + assert 'level="Lv.10"' in result + assert 'message_id="123456"' in result + assert "测试消息内容" in result + assert "" in result diff --git a/tests/test_cors_utils.py b/tests/test_cors_utils.py new file mode 100644 index 00000000..efea14e2 --- /dev/null +++ b/tests/test_cors_utils.py @@ -0,0 +1,137 @@ +"""Tests for Undefined.utils.cors — CORS origin helpers.""" + +from __future__ import annotations + +from Undefined.utils.cors import is_allowed_cors_origin, normalize_origin + + +class TestNormalizeOrigin: + def test_simple_origin(self) -> None: + assert normalize_origin("http://example.com") == "http://example.com" + + def test_trailing_slash(self) -> None: + assert normalize_origin("http://example.com/") == "http://example.com" + + def test_multiple_trailing_slashes(self) -> None: + assert normalize_origin("http://example.com///") == "http://example.com" + + def test_case_insensitive(self) -> None: + assert normalize_origin("HTTP://EXAMPLE.COM") == "http://example.com" + + def test_whitespace_stripped(self) -> None: + assert normalize_origin(" http://example.com ") == "http://example.com" + + def test_empty_string(self) -> None: + assert normalize_origin("") == "" + + def test_none_like_empty(self) -> None: + # The function casts to str via `str(origin or "")`. + assert normalize_origin("") == "" + + def test_with_port(self) -> None: + assert normalize_origin("http://localhost:8080/") == "http://localhost:8080" + + +class TestIsAllowedCorsOrigin: + def test_empty_origin_rejected(self) -> None: + assert is_allowed_cors_origin("") is False + + def test_whitespace_only_rejected(self) -> None: + assert is_allowed_cors_origin(" ") is False + + def test_localhost_http_allowed(self) -> None: + assert is_allowed_cors_origin("http://localhost") is True + + def test_localhost_with_port_allowed(self) -> None: + assert is_allowed_cors_origin("http://localhost:3000") is True + + def test_localhost_https_allowed(self) -> None: + assert is_allowed_cors_origin("https://localhost") is True + + def test_ipv4_loopback_allowed(self) -> None: + assert is_allowed_cors_origin("http://127.0.0.1") is True + + def test_ipv4_loopback_with_port_allowed(self) -> None: + assert is_allowed_cors_origin("http://127.0.0.1:8080") is True + + def test_ipv6_loopback_allowed(self) -> None: + assert is_allowed_cors_origin("http://[::1]") is True + + def test_ipv6_loopback_with_port_allowed(self) -> None: + assert is_allowed_cors_origin("http://[::1]:8080") is True + + def test_tauri_localhost_allowed(self) -> None: + assert is_allowed_cors_origin("tauri://localhost") is True + + def test_external_origin_rejected(self) -> None: + assert is_allowed_cors_origin("http://evil.com") is False + + def test_configured_host_allowed(self) -> None: + assert ( + is_allowed_cors_origin( + "http://myhost.local", + configured_host="myhost.local", + ) + is True + ) + + def test_configured_host_with_port(self) -> None: + assert ( + is_allowed_cors_origin( + "https://myhost.local:9090", + configured_host="myhost.local", + configured_port=9090, + ) + is True + ) + + def test_configured_host_wrong_port_rejected(self) -> None: + assert ( + is_allowed_cors_origin( + "http://myhost.local:1234", + configured_host="myhost.local", + configured_port=9090, + ) + is False + ) + + def test_extra_origins_allowed(self) -> None: + assert ( + is_allowed_cors_origin( + "https://cdn.example.com", + extra_origins={"https://cdn.example.com"}, + ) + is True + ) + + def test_extra_origins_case_insensitive(self) -> None: + assert ( + is_allowed_cors_origin( + "HTTPS://CDN.EXAMPLE.COM", + extra_origins={"https://cdn.example.com"}, + ) + is True + ) + + def test_extra_origins_not_matching_rejected(self) -> None: + assert ( + is_allowed_cors_origin( + "https://other.com", + extra_origins={"https://cdn.example.com"}, + ) + is False + ) + + def test_no_scheme_rejected(self) -> None: + # "example.com" without scheme is not a valid loopback HTTP origin + assert is_allowed_cors_origin("example.com") is False + + def test_ftp_scheme_rejected(self) -> None: + assert is_allowed_cors_origin("ftp://localhost") is False + + def test_configured_host_empty(self) -> None: + # Empty configured_host should not add anything + assert is_allowed_cors_origin("http://evil.com", configured_host="") is False + + def test_extra_origins_none(self) -> None: + assert is_allowed_cors_origin("http://localhost", extra_origins=None) is True diff --git a/tests/test_end_summary_storage.py b/tests/test_end_summary_storage.py new file mode 100644 index 00000000..aa3f4aa8 --- /dev/null +++ b/tests/test_end_summary_storage.py @@ -0,0 +1,124 @@ +"""EndSummaryStorage 单元测试""" + +from __future__ import annotations + +from typing import Any + + +from Undefined.end_summary_storage import ( + EndSummaryLocation, + EndSummaryStorage, +) + + +# --------------------------------------------------------------------------- +# make_record +# --------------------------------------------------------------------------- + + +class TestMakeRecord: + def test_basic(self) -> None: + record = EndSummaryStorage.make_record( + "summary text", "2025-01-01T00:00:00+08:00" + ) + assert record["summary"] == "summary text" + assert record["timestamp"] == "2025-01-01T00:00:00+08:00" + assert "location" not in record + + def test_strips_summary(self) -> None: + record = EndSummaryStorage.make_record(" spaces ", "ts") + assert record["summary"] == "spaces" + + def test_none_timestamp_auto_generates(self) -> None: + record = EndSummaryStorage.make_record("text", None) + assert record["timestamp"] # 非空 + assert "T" in record["timestamp"] # ISO 格式 + + def test_empty_timestamp_auto_generates(self) -> None: + record = EndSummaryStorage.make_record("text", " ") + assert record["timestamp"] + assert record["timestamp"].strip() != "" + + def test_with_location(self) -> None: + loc: EndSummaryLocation = {"type": "group", "name": "测试群"} + record = EndSummaryStorage.make_record("text", "ts", location=loc) + assert record.get("location") is not None + assert record["location"]["type"] == "group" + assert record["location"]["name"] == "测试群" + + def test_with_private_location(self) -> None: + loc: EndSummaryLocation = {"type": "private", "name": "好友"} + record = EndSummaryStorage.make_record("text", "ts", location=loc) + assert record["location"]["type"] == "private" + + def test_location_none_omitted(self) -> None: + record = EndSummaryStorage.make_record("text", "ts", location=None) + assert "location" not in record + + def test_invalid_location_type_ignored(self) -> None: + bad_loc: Any = {"type": "invalid", "name": "x"} + record = EndSummaryStorage.make_record("text", "ts", location=bad_loc) + assert "location" not in record + + def test_location_missing_name_ignored(self) -> None: + bad_loc: Any = {"type": "group"} + record = EndSummaryStorage.make_record("text", "ts", location=bad_loc) + assert "location" not in record + + def test_location_empty_name_ignored(self) -> None: + bad_loc: Any = {"type": "group", "name": " "} + record = EndSummaryStorage.make_record("text", "ts", location=bad_loc) + assert "location" not in record + + def test_location_non_string_name_ignored(self) -> None: + bad_loc: Any = {"type": "group", "name": 123} + record = EndSummaryStorage.make_record("text", "ts", location=bad_loc) + assert "location" not in record + + def test_location_not_dict_ignored(self) -> None: + bad: Any = "bad" + record = EndSummaryStorage.make_record("text", "ts", location=bad) + assert "location" not in record + + +# --------------------------------------------------------------------------- +# _normalize_records +# --------------------------------------------------------------------------- + + +class TestNormalizeRecords: + def _storage(self) -> EndSummaryStorage: + return EndSummaryStorage() + + def test_none_returns_empty(self) -> None: + assert self._storage()._normalize_records(None) == [] + + def test_non_list_returns_empty(self) -> None: + assert self._storage()._normalize_records("not a list") == [] + + def test_string_items_converted(self) -> None: + records = self._storage()._normalize_records(["hello", "world"]) + assert len(records) == 2 + assert records[0]["summary"] == "hello" + + def test_empty_string_items_skipped(self) -> None: + records = self._storage()._normalize_records(["", " ", "valid"]) + assert len(records) == 1 + assert records[0]["summary"] == "valid" + + def test_dict_items_normalized(self) -> None: + data: list[dict[str, Any]] = [ + {"summary": "text", "timestamp": "2025-01-01"}, + ] + records = self._storage()._normalize_records(data) + assert len(records) == 1 + assert records[0]["summary"] == "text" + + def test_dict_missing_summary_skipped(self) -> None: + records = self._storage()._normalize_records([{"timestamp": "t"}]) + assert len(records) == 0 + + def test_max_records_trimmed(self) -> None: + data = [f"summary-{i}" for i in range(250)] + records = self._storage()._normalize_records(data) + assert len(records) == 200 # MAX_END_SUMMARIES diff --git a/tests/test_end_tool.py b/tests/test_end_tool.py index 45e4106d..4b65584a 100644 --- a/tests/test_end_tool.py +++ b/tests/test_end_tool.py @@ -178,10 +178,9 @@ async def test_end_uses_runtime_config_for_historian_reference_limits() -> None: cognitive_service = _FakeCognitiveService() runtime_config = SimpleNamespace( cognitive=SimpleNamespace( - historian_recent_messages_inject_k=2, - historian_recent_message_line_max_len=60, historian_source_message_max_len=40, - ) + ), + get_context_recent_messages_limit=lambda: 2, ) long_content = "A" * 300 context: dict[str, Any] = { @@ -207,6 +206,7 @@ async def test_end_uses_runtime_config_for_historian_reference_limits() -> None: assert len(source) <= 40 assert isinstance(recent, list) assert len(recent) == 2 - assert all( - len(str(line).split(": ", 1)[1]) <= 60 for line in recent if ": " in str(line) - ) + # Recent messages now use XML format (same as main AI) + for line in recent: + assert "" in str(line) diff --git a/tests/test_fake_at.py b/tests/test_fake_at.py new file mode 100644 index 00000000..f56efafd --- /dev/null +++ b/tests/test_fake_at.py @@ -0,0 +1,179 @@ +"""Tests for Undefined.utils.fake_at.""" + +from __future__ import annotations + +from typing import Any +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from Undefined.utils.fake_at import ( + BotNicknameCache, + _normalize, + _sorted_nicknames, + strip_fake_at, +) + + +# --------------------------------------------------------------------------- +# _normalize +# --------------------------------------------------------------------------- + + +class TestNormalize: + def test_fullwidth_at_to_halfwidth(self) -> None: + assert "@" in _normalize("@") + + def test_casefold(self) -> None: + assert _normalize("ABC") == "abc" + + def test_nfkc_normalization(self) -> None: + # Fullwidth letters → ASCII + assert _normalize("A") == "a" + + def test_combined(self) -> None: + result = _normalize("@Hello") + assert result == "@hello" + + +# --------------------------------------------------------------------------- +# _sorted_nicknames +# --------------------------------------------------------------------------- + + +class TestSortedNicknames: + def test_sorted_by_length_desc(self) -> None: + names = frozenset({"ab", "abcd", "a"}) + result = _sorted_nicknames(names) + assert result == ("abcd", "ab", "a") + + def test_empty(self) -> None: + assert _sorted_nicknames(frozenset()) == () + + +# --------------------------------------------------------------------------- +# strip_fake_at +# --------------------------------------------------------------------------- + + +class TestStripFakeAt: + def test_empty_nicknames(self) -> None: + hit, text = strip_fake_at("@bot hello", frozenset()) + assert hit is False + assert text == "@bot hello" + + def test_empty_text(self) -> None: + hit, text = strip_fake_at("", frozenset({"bot"})) + assert hit is False + assert text == "" + + def test_no_at_prefix(self) -> None: + hit, text = strip_fake_at("hello bot", frozenset({"bot"})) + assert hit is False + assert text == "hello bot" + + def test_simple_match(self) -> None: + hit, text = strip_fake_at("@bot hello", frozenset({"bot"})) + assert hit is True + assert text == "hello" + + def test_match_with_fullwidth_at(self) -> None: + hit, text = strip_fake_at("@bot hello", frozenset({"bot"})) + assert hit is True + assert text == "hello" + + def test_case_insensitive(self) -> None: + hit, text = strip_fake_at("@BOT hello", frozenset({"bot"})) + assert hit is True + assert text == "hello" + + def test_longer_nickname_preferred(self) -> None: + nicks = frozenset({"bot", "bot助手"}) + hit, text = strip_fake_at("@bot助手 hello", nicks) + assert hit is True + assert text == "hello" + + def test_no_boundary_after_nickname(self) -> None: + hit, text = strip_fake_at("@botextrastuff", frozenset({"bot"})) + assert hit is False + assert text == "@botextrastuff" + + def test_boundary_punctuation(self) -> None: + hit, text = strip_fake_at("@bot,你好", frozenset({"bot"})) + assert hit is True + + def test_boundary_end_of_string(self) -> None: + hit, text = strip_fake_at("@bot", frozenset({"bot"})) + assert hit is True + assert text == "" + + def test_no_match_returns_original(self) -> None: + hit, text = strip_fake_at("@nobody hello", frozenset({"bot"})) + assert hit is False + assert text == "@nobody hello" + + def test_stripped_text_lstripped(self) -> None: + hit, text = strip_fake_at("@bot hello", frozenset({"bot"})) + assert hit is True + assert text == "hello" + + +# --------------------------------------------------------------------------- +# BotNicknameCache +# --------------------------------------------------------------------------- + + +class TestBotNicknameCache: + @pytest.fixture() + def mock_onebot(self) -> MagicMock: + ob = MagicMock() + ob.get_group_member_info = AsyncMock( + return_value={"card": "BotCard", "nickname": "BotNick"} + ) + return ob + + async def test_get_nicknames_fetches_and_caches( + self, mock_onebot: MagicMock + ) -> None: + cache = BotNicknameCache(mock_onebot, bot_qq=10000, ttl=60.0) + names = await cache.get_nicknames(12345) + assert "botcard" in names + assert "botnick" in names + mock_onebot.get_group_member_info.assert_awaited_once_with(12345, 10000) + + async def test_get_nicknames_uses_cache(self, mock_onebot: MagicMock) -> None: + cache = BotNicknameCache(mock_onebot, bot_qq=10000, ttl=600.0) + await cache.get_nicknames(12345) + await cache.get_nicknames(12345) + # Should only call API once thanks to caching + mock_onebot.get_group_member_info.assert_awaited_once() + + async def test_invalidate_specific_group(self, mock_onebot: MagicMock) -> None: + cache = BotNicknameCache(mock_onebot, bot_qq=10000, ttl=600.0) + await cache.get_nicknames(12345) + cache.invalidate(12345) + await cache.get_nicknames(12345) + assert mock_onebot.get_group_member_info.await_count == 2 + + async def test_invalidate_all(self, mock_onebot: MagicMock) -> None: + cache = BotNicknameCache(mock_onebot, bot_qq=10000, ttl=600.0) + await cache.get_nicknames(111) + await cache.get_nicknames(222) + cache.invalidate() + await cache.get_nicknames(111) + # 111 fetched twice, 222 fetched once = 3 + assert mock_onebot.get_group_member_info.await_count == 3 + + async def test_api_failure_returns_empty(self) -> None: + ob: Any = MagicMock() + ob.get_group_member_info = AsyncMock(side_effect=RuntimeError("API error")) + cache = BotNicknameCache(ob, bot_qq=10000, ttl=60.0) + names = await cache.get_nicknames(99999) + assert names == frozenset() + + async def test_empty_card_and_nickname(self) -> None: + ob: Any = MagicMock() + ob.get_group_member_info = AsyncMock(return_value={"card": "", "nickname": ""}) + cache = BotNicknameCache(ob, bot_qq=10000, ttl=60.0) + names = await cache.get_nicknames(123) + assert names == frozenset() diff --git a/tests/test_faq_unit.py b/tests/test_faq_unit.py new file mode 100644 index 00000000..ae6b2682 --- /dev/null +++ b/tests/test_faq_unit.py @@ -0,0 +1,305 @@ +"""FAQ 存储管理 单元测试""" + +from __future__ import annotations + +from pathlib import Path +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from Undefined.faq import FAQ, FAQStorage, extract_faq_title + +_WRITE_JSON = "Undefined.utils.io.write_json" +_READ_JSON = "Undefined.utils.io.read_json" +_DELETE_FILE = "Undefined.utils.io.delete_file" + + +# --------------------------------------------------------------------------- +# FAQ dataclass +# --------------------------------------------------------------------------- + + +class TestFAQDataclass: + def _sample(self) -> FAQ: + return FAQ( + id="20250101-001", + group_id=12345, + target_qq=67890, + start_time="2025-01-01T00:00:00", + end_time="2025-01-02T00:00:00", + created_at="2025-01-01T00:00:00", + title="测试标题", + content="测试内容", + ) + + def test_to_dict(self) -> None: + faq = self._sample() + d = faq.to_dict() + assert d["id"] == "20250101-001" + assert d["group_id"] == 12345 + assert d["title"] == "测试标题" + + def test_from_dict(self) -> None: + faq = self._sample() + d = faq.to_dict() + restored = FAQ.from_dict(d) + assert restored == faq + + def test_roundtrip(self) -> None: + faq = self._sample() + assert FAQ.from_dict(faq.to_dict()) == faq + + +# --------------------------------------------------------------------------- +# extract_faq_title +# --------------------------------------------------------------------------- + + +class TestExtractFaqTitle: + def test_extract_from_question_colon(self) -> None: + content = "**问题**: 如何重启服务?\n回答是这样的" + assert extract_faq_title(content) == "如何重启服务?" + + def test_extract_from_question_chinese_colon(self) -> None: + content = "**问题**:如何重启服务?\n回答是这样的" + assert extract_faq_title(content) == "如何重启服务?" + + def test_extract_truncates_long_title(self) -> None: + long_question = "x" * 200 + content = f"**问题**: {long_question}" + result = extract_faq_title(content) + assert len(result) <= 100 + + def test_extract_from_bug_section(self) -> None: + content = "## Bug 问题描述\n登录页面崩溃\n更多细节" + assert extract_faq_title(content) == "登录页面崩溃" + + def test_extract_bug_section_truncates(self) -> None: + long_desc = "y" * 200 + content = f"## Bug 问题描述\n{long_desc}" + result = extract_faq_title(content) + assert len(result) <= 100 + + def test_extract_bug_section_skips_heading(self) -> None: + content = "## Bug 问题描述\n# 子标题\n实际描述" + assert extract_faq_title(content) == "实际描述" + + def test_extract_no_match_returns_default(self) -> None: + content = "一段普通文本" + assert extract_faq_title(content) == "未命名问题" + + def test_extract_empty_content(self) -> None: + assert extract_faq_title("") == "未命名问题" + + def test_question_priority_over_bug(self) -> None: + content = "**问题**: 优先问题\n## Bug 问题描述\nbug 内容" + assert extract_faq_title(content) == "优先问题" + + +# --------------------------------------------------------------------------- +# FAQStorage +# --------------------------------------------------------------------------- + + +class TestFAQStorage: + def _make_storage(self) -> FAQStorage: + with patch.object(Path, "mkdir"): + return FAQStorage(base_dir="data/faq") + + @pytest.mark.asyncio + async def test_create(self) -> None: + storage = self._make_storage() + with ( + patch.object(Path, "mkdir"), + patch.object(Path, "glob", return_value=[]), + patch(_WRITE_JSON, new_callable=AsyncMock), + ): + faq = await storage.create( + group_id=100, + target_qq=200, + start_time="2025-01-01", + end_time="2025-01-02", + title="标题", + content="内容", + ) + assert faq.group_id == 100 + assert faq.title == "标题" + assert faq.id # 有生成 ID + + @pytest.mark.asyncio + async def test_get_existing(self) -> None: + storage = self._make_storage() + sample = FAQ( + id="20250101-001", + group_id=100, + target_qq=200, + start_time="s", + end_time="e", + created_at="c", + title="t", + content="body", + ) + with ( + patch.object(Path, "mkdir"), + patch(_READ_JSON, new_callable=AsyncMock, return_value=sample.to_dict()), + ): + result = await storage.get(100, "20250101-001") + assert result is not None + assert result.title == "t" + + @pytest.mark.asyncio + async def test_get_nonexistent(self) -> None: + storage = self._make_storage() + with ( + patch.object(Path, "mkdir"), + patch(_READ_JSON, new_callable=AsyncMock, return_value=None), + ): + result = await storage.get(100, "nonexist") + assert result is None + + @pytest.mark.asyncio + async def test_list_all(self) -> None: + storage = self._make_storage() + faq1 = FAQ( + id="001", + group_id=1, + target_qq=2, + start_time="s", + end_time="e", + created_at="c", + title="t1", + content="c1", + ) + faq2 = FAQ( + id="002", + group_id=1, + target_qq=2, + start_time="s", + end_time="e", + created_at="c", + title="t2", + content="c2", + ) + mock_files = [Path("a.json"), Path("b.json")] + results_iter = iter([faq1.to_dict(), faq2.to_dict()]) + + with ( + patch.object(Path, "mkdir"), + patch.object(Path, "glob", return_value=mock_files), + patch( + _READ_JSON, + new_callable=AsyncMock, + side_effect=lambda *a, **kw: next(results_iter), + ), + ): + faqs = await storage.list_all(1) + + assert len(faqs) == 2 + + @pytest.mark.asyncio + async def test_search_matches(self) -> None: + storage = self._make_storage() + faq_match = FAQ( + id="001", + group_id=1, + target_qq=2, + start_time="s", + end_time="e", + created_at="c", + title="Python 教程", + content="内容", + ) + faq_no_match = FAQ( + id="002", + group_id=1, + target_qq=2, + start_time="s", + end_time="e", + created_at="c", + title="其他", + content="其他内容", + ) + mock_files = [Path("a.json"), Path("b.json")] + results_iter = iter([faq_match.to_dict(), faq_no_match.to_dict()]) + + with ( + patch.object(Path, "mkdir"), + patch.object(Path, "glob", return_value=mock_files), + patch( + _READ_JSON, + new_callable=AsyncMock, + side_effect=lambda *a, **kw: next(results_iter), + ), + ): + results = await storage.search(1, "python") + + assert len(results) == 1 + assert results[0].title == "Python 教程" + + @pytest.mark.asyncio + async def test_search_case_insensitive(self) -> None: + storage = self._make_storage() + faq = FAQ( + id="001", + group_id=1, + target_qq=2, + start_time="s", + end_time="e", + created_at="c", + title="UPPER", + content="body", + ) + mock_files = [MagicMock(spec=Path)] + + with ( + patch.object(Path, "mkdir"), + patch.object(Path, "glob", return_value=mock_files), + patch(_READ_JSON, new_callable=AsyncMock, return_value=faq.to_dict()), + ): + results = await storage.search(1, "upper") + + assert len(results) == 1 + + @pytest.mark.asyncio + async def test_search_in_content(self) -> None: + storage = self._make_storage() + faq = FAQ( + id="001", + group_id=1, + target_qq=2, + start_time="s", + end_time="e", + created_at="c", + title="无关标题", + content="详细的 Python 教程", + ) + mock_files = [MagicMock(spec=Path)] + + with ( + patch.object(Path, "mkdir"), + patch.object(Path, "glob", return_value=mock_files), + patch(_READ_JSON, new_callable=AsyncMock, return_value=faq.to_dict()), + ): + results = await storage.search(1, "python") + + assert len(results) == 1 + + @pytest.mark.asyncio + async def test_delete(self) -> None: + storage = self._make_storage() + with ( + patch.object(Path, "mkdir"), + patch(_DELETE_FILE, new_callable=AsyncMock, return_value=True), + ): + result = await storage.delete(100, "20250101-001") + assert result is True + + @pytest.mark.asyncio + async def test_delete_nonexistent(self) -> None: + storage = self._make_storage() + with ( + patch.object(Path, "mkdir"), + patch(_DELETE_FILE, new_callable=AsyncMock, return_value=False), + ): + result = await storage.delete(100, "nonexist") + assert result is False diff --git a/tests/test_fetch_messages_tool.py b/tests/test_fetch_messages_tool.py new file mode 100644 index 00000000..65fff3d6 --- /dev/null +++ b/tests/test_fetch_messages_tool.py @@ -0,0 +1,513 @@ +from __future__ import annotations + +from datetime import datetime, timedelta +from typing import Any +from unittest.mock import MagicMock + +import pytest + +from Undefined.skills.agents.summary_agent.tools.fetch_messages.handler import ( + _filter_by_time, + _parse_time_range, + execute as fetch_messages_execute, +) +from Undefined.utils.xml import format_message_xml, format_messages_xml + + +# -- _parse_time_range unit tests -- + + +def test_parse_time_range_1h() -> None: + """'1h' → 3600.""" + assert _parse_time_range("1h") == 3600 + + +def test_parse_time_range_6h() -> None: + """'6h' → 21600.""" + assert _parse_time_range("6h") == 21600 + + +def test_parse_time_range_1d() -> None: + """'1d' → 86400.""" + assert _parse_time_range("1d") == 86400 + + +def test_parse_time_range_7d() -> None: + """'7d' → 604800.""" + assert _parse_time_range("7d") == 604800 + + +def test_parse_time_range_1w() -> None: + """'1w' → 604800.""" + assert _parse_time_range("1w") == 604800 + + +def test_parse_time_range_case_insensitive() -> None: + """'1H', '1D' → correct values.""" + assert _parse_time_range("1H") == 3600 + assert _parse_time_range("1D") == 86400 + assert _parse_time_range("1W") == 604800 + + +def test_parse_time_range_invalid() -> None: + """'invalid' → None.""" + assert _parse_time_range("invalid") is None + assert _parse_time_range("") is None + assert _parse_time_range("abc") is None + assert _parse_time_range("1x") is None + + +def test_parse_time_range_with_whitespace() -> None: + """' 1d ' → 86400 (strips whitespace).""" + assert _parse_time_range(" 1d ") == 86400 + + +def test_parse_time_range_multi_digit() -> None: + """'24h' → 86400.""" + assert _parse_time_range("24h") == 86400 + + +# -- _filter_by_time unit tests -- + + +def test_filter_by_time_keeps_recent() -> None: + """Messages within time range are kept.""" + now = datetime.now() + recent = (now - timedelta(seconds=1800)).strftime("%Y-%m-%d %H:%M:%S") + old = (now - timedelta(seconds=7200)).strftime("%Y-%m-%d %H:%M:%S") + + messages = [ + {"timestamp": recent, "message": "recent"}, + {"timestamp": old, "message": "old"}, + ] + + result = _filter_by_time(messages, 3600) # 1 hour + assert len(result) == 1 + assert result[0]["message"] == "recent" + + +def test_filter_by_time_removes_old() -> None: + """Messages outside time range are removed.""" + now = datetime.now() + old1 = (now - timedelta(days=2)).strftime("%Y-%m-%d %H:%M:%S") + old2 = (now - timedelta(days=3)).strftime("%Y-%m-%d %H:%M:%S") + + messages = [ + {"timestamp": old1, "message": "old1"}, + {"timestamp": old2, "message": "old2"}, + ] + + result = _filter_by_time(messages, 86400) # 1 day + assert len(result) == 0 + + +def test_filter_by_time_missing_timestamp() -> None: + """Messages without timestamp are filtered out.""" + messages = [ + {"message": "no timestamp"}, + {"timestamp": "", "message": "empty timestamp"}, + ] + + result = _filter_by_time(messages, 3600) + assert len(result) == 0 + + +def test_filter_by_time_invalid_timestamp() -> None: + """Messages with invalid timestamp format are filtered out.""" + messages = [ + {"timestamp": "invalid-format", "message": "bad format"}, + {"timestamp": "2024-13-45 99:99:99", "message": "impossible date"}, + ] + + result = _filter_by_time(messages, 3600) + assert len(result) == 0 + + +# -- format_messages_xml unit tests -- + + +def testformat_message_xml_group_basic() -> None: + """Group message is formatted into main-AI-compatible XML.""" + messages = [ + { + "type": "group", + "chat_id": "123456", + "chat_name": "测试群", + "timestamp": "2024-01-01 12:00:00", + "display_name": "Alice", + "user_id": "10001", + "message": "Hello", + "role": "member", + "title": "群主", + "level": "42", + "message_id": 123, + }, + ] + + result = format_message_xml(messages[0]) + assert 'message_id="123"' in result + assert 'sender="Alice"' in result + assert 'sender_id="10001"' in result + assert 'group_id="123456"' in result + assert 'group_name="测试群"' in result + assert 'location="测试群"' in result + assert 'role="member"' in result + assert 'title="群主"' in result + assert 'level="42"' in result + assert "Hello" in result + + +def testformat_message_xml_private_basic() -> None: + """Private message uses the private XML shape.""" + msg = { + "type": "private", + "timestamp": "2024-01-01 12:00:00", + "display_name": "Bob", + "user_id": "10002", + "message": "Hi", + "message_id": 456, + } + + result = format_message_xml(msg) + assert 'message_id="456"' in result + assert 'sender="Bob"' in result + assert 'sender_id="10002"' in result + assert 'location="私聊"' in result + assert "group_id=" not in result + assert "role=" not in result + assert "Hi" in result + + +def testformat_message_xml_includes_attachments() -> None: + """Attachment refs are rendered as XML below content.""" + msg = { + "type": "group", + "chat_id": "123456", + "chat_name": "测试群", + "timestamp": "2024-01-01 12:00:00", + "display_name": "Charlie", + "user_id": "10003", + "message": "看这个", + "attachments": [ + { + "uid": "pic_abcd1234", + "kind": "image", + "media_type": "image", + "display_name": "a.png", + "description": "截图", + } + ], + } + + result = format_message_xml(msg) + assert "" in result + assert 'uid="pic_abcd1234"' in result + assert 'type="image"' in result + assert 'description="截图"' in result + + +def testformat_messages_xml_multiple() -> None: + """Multiple messages are separated by main-AI-style delimiters.""" + messages = [ + { + "type": "group", + "chat_id": "123456", + "chat_name": "测试群", + "timestamp": "2024-01-01 12:00:00", + "display_name": "Alice", + "user_id": "10001", + "message": "First", + }, + { + "type": "private", + "timestamp": "2024-01-01 12:00:00", + "display_name": "Bob", + "user_id": "10002", + "message": "Second", + }, + ] + + result = format_messages_xml(messages) + assert "\n---\n" in result + assert result.count(" None: + """Missing fields still produce valid XML.""" + messages = [ + { + "timestamp": "", + "message": "No timestamp", + }, + ] + + result = format_messages_xml(messages) + assert "未知用户" in result + assert "No timestamp" in result + assert " None: + """Count-based fetch in group context.""" + history_manager = MagicMock() + history_manager.get_recent.return_value = [ + { + "timestamp": "2024-01-01 12:00:00", + "display_name": "Alice", + "message": "Message 1", + "role": "", + "title": "", + }, + { + "timestamp": "2024-01-01 12:01:00", + "display_name": "Bob", + "message": "Message 2", + "role": "", + "title": "", + }, + ] + + context: dict[str, Any] = { + "history_manager": history_manager, + "group_id": 123456, + "user_id": 0, + } + + result = await fetch_messages_execute({"count": 50}, context) + + assert "共获取 2 条消息" in result + assert 'sender="Alice"' in result + assert "Message 1" in result + assert 'sender="Bob"' in result + assert "Message 2" in result + history_manager.get_recent.assert_called_once_with("123456", "group", 0, 50) + + +@pytest.mark.asyncio +async def test_fetch_messages_count_based_private() -> None: + """Count-based fetch in private context.""" + history_manager = MagicMock() + history_manager.get_recent.return_value = [ + { + "timestamp": "2024-01-01 12:00:00", + "display_name": "User", + "message": "Private message", + "role": "", + "title": "", + }, + ] + + context: dict[str, Any] = { + "history_manager": history_manager, + "group_id": 0, + "user_id": 99999, + } + + result = await fetch_messages_execute({"count": 20}, context) + + assert "共获取 1 条消息" in result + assert 'location="私聊"' in result + assert "Private message" in result + history_manager.get_recent.assert_called_once_with("99999", "private", 0, 20) + + +@pytest.mark.asyncio +async def test_fetch_messages_time_range() -> None: + """Time-range fetch filters by time.""" + now = datetime.now() + recent = (now - timedelta(hours=1)).strftime("%Y-%m-%d %H:%M:%S") + old = (now - timedelta(hours=25)).strftime("%Y-%m-%d %H:%M:%S") + + history_manager = MagicMock() + history_manager.get_recent.return_value = [ + { + "timestamp": recent, + "display_name": "Alice", + "message": "Recent", + "role": "", + "title": "", + }, + { + "timestamp": old, + "display_name": "Bob", + "message": "Old", + "role": "", + "title": "", + }, + ] + + context: dict[str, Any] = { + "history_manager": history_manager, + "group_id": 123456, + "user_id": 0, + } + + result = await fetch_messages_execute( + {"count": 50, "time_range": "1d"}, + context, + ) + + assert "共获取 1 条消息" in result + assert "(时间范围: 1d)" in result + assert "Recent" in result + assert "Old" not in result + + +@pytest.mark.asyncio +async def test_fetch_messages_invalid_time_range() -> None: + """Invalid time range returns error message.""" + context: dict[str, Any] = { + "history_manager": MagicMock(), + "group_id": 123456, + } + + result = await fetch_messages_execute( + {"time_range": "invalid"}, + context, + ) + + assert "无法解析时间范围: invalid" in result + assert "支持格式: 1h, 6h, 1d, 7d" in result + + +@pytest.mark.asyncio +async def test_fetch_messages_empty_history() -> None: + """Empty history returns '当前会话暂无消息记录'.""" + history_manager = MagicMock() + history_manager.get_recent.return_value = [] + + context: dict[str, Any] = { + "history_manager": history_manager, + "group_id": 123456, + } + + result = await fetch_messages_execute({}, context) + + assert "当前会话暂无消息记录" in result + + +@pytest.mark.asyncio +async def test_fetch_messages_no_history_manager() -> None: + """No history_manager returns error.""" + context: dict[str, Any] = { + "group_id": 123456, + } + + result = await fetch_messages_execute({}, context) + + assert "历史记录管理器未配置" in result + + +@pytest.mark.asyncio +async def test_fetch_messages_count_capped_at_config_limit() -> None: + """Count is capped at the configured summary fetch limit (fallback 1000).""" + history_manager = MagicMock() + history_manager.get_recent.return_value = [] + + context: dict[str, Any] = { + "history_manager": history_manager, + "group_id": 123456, + } + + await fetch_messages_execute({"count": 9999}, context) + + history_manager.get_recent.assert_called_once_with("123456", "group", 0, 1000) + + +@pytest.mark.asyncio +async def test_fetch_messages_default_count() -> None: + """Default count is 50 when not specified.""" + history_manager = MagicMock() + history_manager.get_recent.return_value = [] + + context: dict[str, Any] = { + "history_manager": history_manager, + "group_id": 123456, + } + + await fetch_messages_execute({}, context) + + history_manager.get_recent.assert_called_once_with("123456", "group", 0, 50) + + +@pytest.mark.asyncio +async def test_fetch_messages_invalid_count_defaults() -> None: + """Invalid count defaults to 50.""" + history_manager = MagicMock() + history_manager.get_recent.return_value = [] + + context: dict[str, Any] = { + "history_manager": history_manager, + "group_id": 123456, + } + + await fetch_messages_execute({"count": "invalid"}, context) + + history_manager.get_recent.assert_called_once_with("123456", "group", 0, 50) + + +@pytest.mark.asyncio +async def test_fetch_messages_time_range_fetch_larger_batch() -> None: + """Time range mode fetches larger batch (max(count*2, fallback_time_limit)).""" + history_manager = MagicMock() + history_manager.get_recent.return_value = [] + + context: dict[str, Any] = { + "history_manager": history_manager, + "group_id": 123456, + } + + await fetch_messages_execute( + {"count": 50, "time_range": "1d"}, + context, + ) + + # max(50*2, 5000) = 5000 + history_manager.get_recent.assert_called_once_with("123456", "group", 0, 5000) + + +@pytest.mark.asyncio +async def test_fetch_messages_count_uses_runtime_config() -> None: + """When runtime_config provides history_summary_fetch_limit, use it.""" + history_manager = MagicMock() + history_manager.get_recent.return_value = [] + + cfg = MagicMock() + cfg.history_summary_fetch_limit = 300 + cfg.history_summary_time_fetch_limit = 800 + + context: dict[str, Any] = { + "history_manager": history_manager, + "group_id": 123456, + "runtime_config": cfg, + } + + await fetch_messages_execute({"count": 9999}, context) + + history_manager.get_recent.assert_called_once_with("123456", "group", 0, 300) + + +@pytest.mark.asyncio +async def test_fetch_messages_time_range_uses_runtime_config() -> None: + """When runtime_config provides history_summary_time_fetch_limit, use it.""" + history_manager = MagicMock() + history_manager.get_recent.return_value = [] + + cfg = MagicMock() + cfg.history_summary_fetch_limit = 300 + cfg.history_summary_time_fetch_limit = 800 + + context: dict[str, Any] = { + "history_manager": history_manager, + "group_id": 123456, + "runtime_config": cfg, + } + + await fetch_messages_execute({"count": 50, "time_range": "1d"}, context) + + # max(50*2, 800) = 800 + history_manager.get_recent.assert_called_once_with("123456", "group", 0, 800) diff --git a/tests/test_group_metrics.py b/tests/test_group_metrics.py new file mode 100644 index 00000000..04f832d3 --- /dev/null +++ b/tests/test_group_metrics.py @@ -0,0 +1,190 @@ +"""Tests for Undefined.utils.group_metrics — group member metric helpers.""" + +from __future__ import annotations + +from datetime import datetime + +from Undefined.utils.group_metrics import ( + clamp_int, + datetime_to_ts, + format_timestamp, + member_display_name, + parse_member_level, + parse_unix_timestamp, + role_to_cn, +) + + +class TestClampInt: + def test_within_range(self) -> None: + assert clamp_int(5, 0, 1, 10) == 5 + + def test_below_min(self) -> None: + assert clamp_int(-5, 0, 1, 10) == 1 + + def test_above_max(self) -> None: + assert clamp_int(20, 0, 1, 10) == 10 + + def test_at_min(self) -> None: + assert clamp_int(1, 0, 1, 10) == 1 + + def test_at_max(self) -> None: + assert clamp_int(10, 0, 1, 10) == 10 + + def test_non_numeric_returns_default(self) -> None: + assert clamp_int("abc", 7, 1, 10) == 7 + + def test_none_returns_default(self) -> None: + assert clamp_int(None, 5, 1, 10) == 5 + + def test_string_int(self) -> None: + assert clamp_int("3", 0, 1, 10) == 3 + + def test_float_truncated(self) -> None: + assert clamp_int(3.9, 0, 1, 10) == 3 + + def test_bool_as_int(self) -> None: + assert clamp_int(True, 0, 0, 10) == 1 + + +class TestParseUnixTimestamp: + def test_valid_positive(self) -> None: + assert parse_unix_timestamp(1700000000) == 1700000000 + + def test_zero(self) -> None: + assert parse_unix_timestamp(0) == 0 + + def test_negative(self) -> None: + assert parse_unix_timestamp(-100) == 0 + + def test_none(self) -> None: + assert parse_unix_timestamp(None) == 0 + + def test_non_numeric(self) -> None: + assert parse_unix_timestamp("abc") == 0 + + def test_string_number(self) -> None: + assert parse_unix_timestamp("1700000000") == 1700000000 + + def test_float(self) -> None: + assert parse_unix_timestamp(1700000000.5) == 1700000000 + + +class TestParseMemberLevel: + def test_integer(self) -> None: + assert parse_member_level(5) == 5 + + def test_zero(self) -> None: + assert parse_member_level(0) == 0 + + def test_negative_int(self) -> None: + assert parse_member_level(-1) is None + + def test_none(self) -> None: + assert parse_member_level(None) is None + + def test_bool_returns_none(self) -> None: + assert parse_member_level(True) is None + assert parse_member_level(False) is None + + def test_float(self) -> None: + assert parse_member_level(3.7) == 3 + + def test_digit_string(self) -> None: + assert parse_member_level("10") == 10 + + def test_string_with_digits(self) -> None: + assert parse_member_level("Lv.5") == 5 + + def test_string_no_digits(self) -> None: + assert parse_member_level("无") is None + + def test_empty_string(self) -> None: + assert parse_member_level("") is None + + def test_whitespace_string(self) -> None: + assert parse_member_level(" ") is None + + def test_complex_string(self) -> None: + assert parse_member_level("等级42勋章") == 42 + + +class TestMemberDisplayName: + def test_card_preferred(self) -> None: + member = {"card": "CardName", "nickname": "Nick", "user_id": 123} + assert member_display_name(member) == "CardName" + + def test_nickname_fallback(self) -> None: + member = {"card": "", "nickname": "Nick", "user_id": 123} + assert member_display_name(member) == "Nick" + + def test_user_id_fallback(self) -> None: + member = {"card": "", "nickname": "", "user_id": 123} + assert member_display_name(member) == "123" + + def test_none_card(self) -> None: + member = {"card": None, "nickname": "Nick"} + assert member_display_name(member) == "Nick" + + def test_all_missing(self) -> None: + member: dict[str, object] = {} + assert member_display_name(member) == "未知" + + def test_whitespace_card(self) -> None: + member = {"card": " ", "nickname": "Nick"} + assert member_display_name(member) == "Nick" + + +class TestRoleToCn: + def test_owner(self) -> None: + assert role_to_cn("owner") == "群主" + + def test_admin(self) -> None: + assert role_to_cn("admin") == "管理员" + + def test_member(self) -> None: + assert role_to_cn("member") == "成员" + + def test_none_defaults_to_member(self) -> None: + assert role_to_cn(None) == "成员" + + def test_unknown_role_passthrough(self) -> None: + assert role_to_cn("moderator") == "moderator" + + def test_empty_string_defaults_to_member(self) -> None: + # str("" or "member") -> "member" + assert role_to_cn("") == "成员" + + +class TestFormatTimestamp: + def test_valid_timestamp(self) -> None: + ts = int(datetime(2024, 1, 15, 12, 0, 0).timestamp()) + result = format_timestamp(ts) + assert "2024-01-15" in result + + def test_zero(self) -> None: + assert format_timestamp(0) == "无" + + def test_negative(self) -> None: + assert format_timestamp(-1) == "无" + + def test_overflow(self) -> None: + assert format_timestamp(999999999999999) == "无" + + +class TestDatetimeToTs: + def test_none(self) -> None: + assert datetime_to_ts(None) is None + + def test_valid_datetime(self) -> None: + dt = datetime(2024, 6, 15, 12, 0, 0) + result = datetime_to_ts(dt) + assert result is not None + assert isinstance(result, int) + # Round-trip check + assert datetime.fromtimestamp(result).replace(second=0) == dt.replace(second=0) + + def test_epoch(self) -> None: + dt = datetime(1970, 1, 1, 0, 0, 0) + result = datetime_to_ts(dt) + assert result is not None diff --git a/tests/test_handlers_meme_annotation.py b/tests/test_handlers_meme_annotation.py new file mode 100644 index 00000000..a8ff4d62 --- /dev/null +++ b/tests/test_handlers_meme_annotation.py @@ -0,0 +1,277 @@ +"""测试 handlers.py 中的表情包自动匹配功能""" + +from __future__ import annotations + +from pathlib import Path +from typing import Any +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from Undefined.attachments import AttachmentRecord, AttachmentRegistry + + +@pytest.mark.asyncio +async def test_annotate_meme_descriptions_success(tmp_path: Path) -> None: + """测试成功匹配表情包时添加描述""" + # 创建 mock handler + handler = MagicMock() + + # 设置 attachment registry + registry_path = tmp_path / "registry.json" + cache_dir = tmp_path / "cache" + attachment_registry = AttachmentRegistry( + registry_path=registry_path, cache_dir=cache_dir + ) + + # 添加一个附件记录 + test_sha256 = "abc123def456" + attachment_record = AttachmentRecord( + uid="pic_001", + scope_key="group:10001", + kind="image", + media_type="image/png", + display_name="test.png", + source_kind="test", + source_ref="", + local_path=None, + mime_type="image/png", + sha256=test_sha256, + created_at="2024-01-01T00:00:00Z", + segment_data={}, + ) + attachment_registry._records["pic_001"] = attachment_record + + # 设置 meme service mock + mock_meme_record = MagicMock() + mock_meme_record.description = "可爱的猫猫" + + mock_meme_store = MagicMock() + mock_meme_store.find_by_sha256 = AsyncMock(return_value=mock_meme_record) + + mock_meme_service = MagicMock() + mock_meme_service.enabled = True + mock_meme_service._store = mock_meme_store + + ai = MagicMock() + ai.attachment_registry = attachment_registry + ai._meme_service = mock_meme_service + + handler.ai = ai + + # 导入实现 + from Undefined.handlers import MessageHandler + + # 测试输入 + input_attachments = [ + {"uid": "pic_001", "kind": "image"}, + {"uid": "file_002", "kind": "file"}, + ] + + # 调用函数 + result = await MessageHandler._annotate_meme_descriptions( + handler, input_attachments, "group:10001" + ) + + # 验证结果 + assert len(result) == 2 + # 第一个附件应该有表情包描述 + assert result[0]["uid"] == "pic_001" + assert result[0]["description"] == "[表情包] 可爱的猫猫" + # 第二个附件不应该被修改 + assert result[1]["uid"] == "file_002" + assert "description" not in result[1] + + +@pytest.mark.asyncio +async def test_annotate_meme_descriptions_no_match(tmp_path: Path) -> None: + """测试没有匹配表情包时返回原始列表""" + handler = MagicMock() + + registry_path = tmp_path / "registry.json" + cache_dir = tmp_path / "cache" + attachment_registry = AttachmentRegistry( + registry_path=registry_path, cache_dir=cache_dir + ) + + test_sha256 = "xyz789" + attachment_record = AttachmentRecord( + uid="pic_001", + scope_key="group:10001", + kind="image", + media_type="image/png", + display_name="test.png", + source_kind="test", + source_ref="", + local_path=None, + mime_type="image/png", + sha256=test_sha256, + created_at="2024-01-01T00:00:00Z", + segment_data={}, + ) + attachment_registry._records["pic_001"] = attachment_record + + # meme store 返回 None(没找到) + mock_meme_store = MagicMock() + mock_meme_store.find_by_sha256 = AsyncMock(return_value=None) + + mock_meme_service = MagicMock() + mock_meme_service.enabled = True + mock_meme_service._store = mock_meme_store + + ai = MagicMock() + ai.attachment_registry = attachment_registry + ai._meme_service = mock_meme_service + + handler.ai = ai + + from Undefined.handlers import MessageHandler + + input_attachments = [{"uid": "pic_001", "kind": "image"}] + + result = await MessageHandler._annotate_meme_descriptions( + handler, input_attachments, "group:10001" + ) + + # 应该返回原始列表 + assert len(result) == 1 + assert result[0]["uid"] == "pic_001" + assert "description" not in result[0] + + +@pytest.mark.asyncio +async def test_annotate_meme_descriptions_meme_disabled() -> None: + """测试 meme service 禁用时返回原始列表""" + handler = MagicMock() + + # meme service 禁用 + mock_meme_service = MagicMock() + mock_meme_service.enabled = False + + ai = MagicMock() + ai._meme_service = mock_meme_service + + handler.ai = ai + + from Undefined.handlers import MessageHandler + + input_attachments = [{"uid": "pic_001", "kind": "image"}] + + result = await MessageHandler._annotate_meme_descriptions( + handler, input_attachments, "group:10001" + ) + + # 应该返回原始列表 + assert result == input_attachments + + +@pytest.mark.asyncio +async def test_annotate_meme_descriptions_error_handling() -> None: + """测试异常处理:失败时返回原始列表""" + handler = MagicMock() + + # 设置会抛出异常的 attachment registry + mock_attachment_registry = MagicMock() + mock_attachment_registry.resolve = MagicMock( + side_effect=Exception("Registry error") + ) + + # 设置 meme service + mock_meme_store = MagicMock() + mock_meme_store.find_by_sha256 = AsyncMock(side_effect=Exception("Database error")) + + mock_meme_service = MagicMock() + mock_meme_service.enabled = True + mock_meme_service._store = mock_meme_store + + ai = MagicMock() + ai.attachment_registry = mock_attachment_registry + ai._meme_service = mock_meme_service + + handler.ai = ai + + from Undefined.handlers import MessageHandler + + input_attachments = [{"uid": "pic_001", "kind": "image"}] + + result = await MessageHandler._annotate_meme_descriptions( + handler, input_attachments, "group:10001" + ) + + # 应该返回原始列表(异常被捕获) + assert result == input_attachments + + +@pytest.mark.asyncio +async def test_annotate_meme_descriptions_batch_query(tmp_path: Path) -> None: + """测试批量查询:多个附件共享同一个哈希值""" + handler = MagicMock() + + registry_path = tmp_path / "registry.json" + cache_dir = tmp_path / "cache" + attachment_registry = AttachmentRegistry( + registry_path=registry_path, cache_dir=cache_dir + ) + + # 两个附件,相同的 SHA256 + shared_sha256 = "shared123" + for uid in ["pic_001", "pic_002"]: + record = AttachmentRecord( + uid=uid, + scope_key="group:10001", + kind="image", + media_type="image/png", + display_name=f"{uid}.png", + source_kind="test", + source_ref="", + local_path=None, + mime_type="image/png", + sha256=shared_sha256, + created_at="2024-01-01T00:00:00Z", + segment_data={}, + ) + attachment_registry._records[uid] = record + + # 记录 find_by_sha256 被调用的次数 + call_count = 0 + + async def mock_find_by_sha256(sha: str) -> Any: + nonlocal call_count + call_count += 1 + if sha == shared_sha256: + meme = MagicMock() + meme.description = "共享表情包" + return meme + return None + + mock_meme_store = MagicMock() + mock_meme_store.find_by_sha256 = mock_find_by_sha256 + + mock_meme_service = MagicMock() + mock_meme_service.enabled = True + mock_meme_service._store = mock_meme_store + + ai = MagicMock() + ai.attachment_registry = attachment_registry + ai._meme_service = mock_meme_service + + handler.ai = ai + + from Undefined.handlers import MessageHandler + + input_attachments = [ + {"uid": "pic_001", "kind": "image"}, + {"uid": "pic_002", "kind": "image"}, + ] + + result = await MessageHandler._annotate_meme_descriptions( + handler, input_attachments, "group:10001" + ) + + # 验证结果 + assert len(result) == 2 + assert result[0]["description"] == "[表情包] 共享表情包" + assert result[1]["description"] == "[表情包] 共享表情包" + + # 验证 find_by_sha256 只被调用一次(批量查询去重) + assert call_count == 1 diff --git a/tests/test_handlers_repeat.py b/tests/test_handlers_repeat.py new file mode 100644 index 00000000..c305b6bb --- /dev/null +++ b/tests/test_handlers_repeat.py @@ -0,0 +1,483 @@ +"""MessageHandler 复读功能测试""" + +from __future__ import annotations + +from types import SimpleNamespace +from typing import Any +from unittest.mock import AsyncMock + +import pytest + +from Undefined.handlers import ( + MessageHandler, + REPEAT_REPLY_HISTORY_PREFIX, +) + + +def _build_handler( + *, + repeat_enabled: bool = False, + repeat_threshold: int = 3, + repeat_cooldown_minutes: int = 60, + inverted_question_enabled: bool = False, + keyword_reply_enabled: bool = False, +) -> Any: + handler: Any = MessageHandler.__new__(MessageHandler) + handler.config = SimpleNamespace( + bot_qq=10000, + repeat_enabled=repeat_enabled, + repeat_threshold=repeat_threshold, + repeat_cooldown_minutes=repeat_cooldown_minutes, + inverted_question_enabled=inverted_question_enabled, + keyword_reply_enabled=keyword_reply_enabled, + bilibili_auto_extract_enabled=False, + arxiv_auto_extract_enabled=False, + should_process_group_message=lambda is_at_bot=False: True, + should_process_private_message=lambda: True, + is_group_allowed=lambda _gid: True, + is_private_allowed=lambda _uid: True, + access_control_enabled=lambda: False, + process_every_message=True, + ) + handler.history_manager = SimpleNamespace( + add_group_message=AsyncMock(), + add_private_message=AsyncMock(), + ) + handler.sender = SimpleNamespace( + send_group_message=AsyncMock(), + send_private_message=AsyncMock(), + ) + handler.ai_coordinator = SimpleNamespace( + handle_auto_reply=AsyncMock(), + handle_private_reply=AsyncMock(), + _is_at_bot=lambda _mc: False, + ) + handler.ai = SimpleNamespace( + _cognitive_service=None, + memory_storage=None, + model_pool=SimpleNamespace( + handle_private_message=AsyncMock(return_value=False) + ), + ) + handler.onebot = SimpleNamespace( + get_group_info=AsyncMock(return_value={"group_name": "测试群"}), + get_stranger_info=AsyncMock(return_value={"nickname": "用户"}), + get_msg=AsyncMock(return_value=None), + get_forward_msg=AsyncMock(return_value=None), + ) + handler.command_dispatcher = SimpleNamespace( + parse_command=lambda _t: None, + ) + handler._background_tasks = set() + handler._repeat_counter = {} + handler._repeat_locks = {} + handler._repeat_cooldown = {} + handler._profile_name_refresh_cache = {} + handler._bot_nickname_cache = SimpleNamespace( + get_nicknames=AsyncMock(return_value=frozenset()), + ) + return handler + + +def _group_event( + group_id: int = 30001, + sender_id: int = 20001, + text: str = "hello", +) -> dict[str, Any]: + return { + "post_type": "message", + "message_type": "group", + "group_id": group_id, + "user_id": sender_id, + "message_id": 1, + "sender": { + "user_id": sender_id, + "card": f"用户{sender_id}", + "nickname": f"昵称{sender_id}", + "role": "member", + "title": "", + }, + "message": [{"type": "text", "data": {"text": text}}], + } + + +# ── 基础:复读未启用时不触发 ── + + +@pytest.mark.asyncio +async def test_repeat_disabled_does_not_repeat() -> None: + handler = _build_handler(repeat_enabled=False) + for uid in [20001, 20002, 20003]: + await handler.handle_message(_group_event(sender_id=uid, text="hello")) + + handler.sender.send_group_message.assert_not_called() + + +# ── 复读触发:3条相同消息来自不同人 ── + + +@pytest.mark.asyncio +async def test_repeat_triggers_on_3_identical_from_different_senders() -> None: + handler = _build_handler(repeat_enabled=True) + for uid in [20001, 20002, 20003]: + await handler.handle_message(_group_event(sender_id=uid, text="hello")) + + handler.sender.send_group_message.assert_called_once() + call = handler.sender.send_group_message.call_args + assert call.args[0] == 30001 + assert call.args[1] == "hello" + assert call.kwargs.get("history_prefix") == REPEAT_REPLY_HISTORY_PREFIX + + +# ── 不触发:3条相同消息来自同一人 ── + + +@pytest.mark.asyncio +async def test_repeat_does_not_trigger_from_same_sender() -> None: + handler = _build_handler(repeat_enabled=True) + for _ in range(3): + await handler.handle_message(_group_event(sender_id=20001, text="hello")) + + handler.sender.send_group_message.assert_not_called() + + +# ── 不触发:消息内容不同 ── + + +@pytest.mark.asyncio +async def test_repeat_does_not_trigger_for_different_texts() -> None: + handler = _build_handler(repeat_enabled=True) + for uid, text in [(20001, "hello"), (20002, "world"), (20003, "hello")]: + await handler.handle_message(_group_event(sender_id=uid, text=text)) + + handler.sender.send_group_message.assert_not_called() + + +# ── 防重复:触发后计数器清空 ── + + +@pytest.mark.asyncio +async def test_repeat_clears_counter_after_trigger() -> None: + handler = _build_handler(repeat_enabled=True, repeat_cooldown_minutes=0) + # 第一轮:3条相同触发复读 + for uid in [20001, 20002, 20003]: + await handler.handle_message(_group_event(sender_id=uid, text="hello")) + + assert handler.sender.send_group_message.call_count == 1 + + # 第二轮:再来3条相同应再次触发(无冷却) + for uid in [20004, 20005, 20006]: + await handler.handle_message(_group_event(sender_id=uid, text="hello")) + + assert handler.sender.send_group_message.call_count == 2 + + +# ── 倒问号:问号消息触发倒问号 ── + + +@pytest.mark.asyncio +async def test_inverted_question_sends_inverted_mark() -> None: + handler = _build_handler(repeat_enabled=True, inverted_question_enabled=True) + for uid in [20001, 20002, 20003]: + await handler.handle_message(_group_event(sender_id=uid, text="?")) + + handler.sender.send_group_message.assert_called_once() + call = handler.sender.send_group_message.call_args + assert call.args[1] == "¿" + + +@pytest.mark.asyncio +async def test_inverted_question_multiple_marks() -> None: + handler = _build_handler(repeat_enabled=True, inverted_question_enabled=True) + for uid in [20001, 20002, 20003]: + await handler.handle_message(_group_event(sender_id=uid, text="???")) + + handler.sender.send_group_message.assert_called_once() + call = handler.sender.send_group_message.call_args + assert call.args[1] == "¿¿¿" + + +@pytest.mark.asyncio +async def test_inverted_question_chinese_question_mark() -> None: + handler = _build_handler(repeat_enabled=True, inverted_question_enabled=True) + for uid in [20001, 20002, 20003]: + await handler.handle_message(_group_event(sender_id=uid, text="?")) + + handler.sender.send_group_message.assert_called_once() + call = handler.sender.send_group_message.call_args + assert call.args[1] == "¿" + + +@pytest.mark.asyncio +async def test_inverted_question_disabled_sends_normal_text() -> None: + handler = _build_handler(repeat_enabled=True, inverted_question_enabled=False) + for uid in [20001, 20002, 20003]: + await handler.handle_message(_group_event(sender_id=uid, text="?")) + + handler.sender.send_group_message.assert_called_once() + call = handler.sender.send_group_message.call_args + assert call.args[1] == "?" + + +@pytest.mark.asyncio +async def test_inverted_question_mixed_text_not_triggered() -> None: + """非纯问号消息不受倒问号影响,正常复读。""" + handler = _build_handler(repeat_enabled=True, inverted_question_enabled=True) + for uid in [20001, 20002, 20003]: + await handler.handle_message(_group_event(sender_id=uid, text="what?")) + + handler.sender.send_group_message.assert_called_once() + call = handler.sender.send_group_message.call_args + assert call.args[1] == "what?" + + +# ── 不同群互不干扰 ── + + +@pytest.mark.asyncio +async def test_repeat_groups_are_independent() -> None: + handler = _build_handler(repeat_enabled=True) + # 群A: 2条相同 + await handler.handle_message( + _group_event(group_id=30001, sender_id=20001, text="hi") + ) + await handler.handle_message( + _group_event(group_id=30001, sender_id=20002, text="hi") + ) + # 群B: 3条相同 + for uid in [30001, 30002, 30003]: + await handler.handle_message( + _group_event(group_id=30002, sender_id=uid, text="hi") + ) + + # 群B触发,群A未触发 + assert handler.sender.send_group_message.call_count == 1 + call = handler.sender.send_group_message.call_args + assert call.args[0] == 30002 + + +# ── 计数器窗口:只看最近 N 条 ── + + +@pytest.mark.asyncio +async def test_repeat_counter_sliding_window() -> None: + handler = _build_handler(repeat_enabled=True) + # 发5条不同消息 + for i in range(5): + await handler.handle_message(_group_event(sender_id=20001 + i, text=f"msg{i}")) + # 再发3条相同 + for uid in [20010, 20011, 20012]: + await handler.handle_message(_group_event(sender_id=uid, text="hello")) + + handler.sender.send_group_message.assert_called_once() + call = handler.sender.send_group_message.call_args + assert call.args[1] == "hello" + + +# ── bot 自身发言后不触发复读 ── + +BOT_QQ = 10000 + + +@pytest.mark.asyncio +async def test_repeat_no_trigger_when_bot_sends_before_users() -> None: + """bot 先发,后面用户再发相同消息,不应触发复读。""" + handler = _build_handler(repeat_enabled=True) + # bot 先发 + await handler.handle_message(_group_event(sender_id=BOT_QQ, text="hello")) + # 两个用户跟发 + for uid in [20001, 20002]: + await handler.handle_message(_group_event(sender_id=uid, text="hello")) + + handler.sender.send_group_message.assert_not_called() + + +@pytest.mark.asyncio +async def test_repeat_no_trigger_when_bot_sends_in_middle() -> None: + """用户发到一半,bot 插入相同消息,之后用户再凑满阈值,不应触发复读。""" + handler = _build_handler(repeat_enabled=True) + # 两个用户先发 + await handler.handle_message(_group_event(sender_id=20001, text="hello")) + await handler.handle_message(_group_event(sender_id=20002, text="hello")) + # bot 插入 + await handler.handle_message(_group_event(sender_id=BOT_QQ, text="hello")) + # 第三个用户发:此时窗口 [user2, bot, user3],含 bot → 不触发 + await handler.handle_message(_group_event(sender_id=20003, text="hello")) + + handler.sender.send_group_message.assert_not_called() + + +@pytest.mark.asyncio +async def test_repeat_triggers_after_bot_window_slides_out() -> None: + """bot 消息滑出窗口后,纯用户序列应正常触发复读(threshold=3)。""" + handler = _build_handler(repeat_enabled=True, repeat_threshold=3) + # bot 先发(进入窗口) + await handler.handle_message(_group_event(sender_id=BOT_QQ, text="hello")) + # 三个不同用户依次发:窗口变为 [user1, user2, user3](bot 已滑出) + for uid in [20001, 20002, 20003]: + await handler.handle_message(_group_event(sender_id=uid, text="hello")) + + handler.sender.send_group_message.assert_called_once() + assert handler.sender.send_group_message.call_args.args[1] == "hello" + + +# ── 可配置阈值 ── + + +@pytest.mark.asyncio +async def test_repeat_custom_threshold_2() -> None: + """threshold=2 时,2 条不同发送者相同消息即触发复读。""" + handler = _build_handler(repeat_enabled=True, repeat_threshold=2) + for uid in [20001, 20002]: + await handler.handle_message(_group_event(sender_id=uid, text="hi")) + + handler.sender.send_group_message.assert_called_once() + assert handler.sender.send_group_message.call_args.args[1] == "hi" + + +@pytest.mark.asyncio +async def test_repeat_custom_threshold_4() -> None: + """threshold=4 时,3 条不同发送者相同消息不触发,第 4 条才触发。""" + handler = _build_handler(repeat_enabled=True, repeat_threshold=4) + for uid in [20001, 20002, 20003]: + await handler.handle_message(_group_event(sender_id=uid, text="hey")) + handler.sender.send_group_message.assert_not_called() + + await handler.handle_message(_group_event(sender_id=20004, text="hey")) + handler.sender.send_group_message.assert_called_once() + assert handler.sender.send_group_message.call_args.args[1] == "hey" + + +# ── 冷却机制:复读后同一内容在冷却期内不再触发 ── + + +@pytest.mark.asyncio +async def test_repeat_cooldown_suppresses_same_text() -> None: + """复读触发后,同一内容在冷却期内再次满足条件也不触发,但消息继续处理。""" + handler = _build_handler(repeat_enabled=True, repeat_cooldown_minutes=60) + # 第一轮:触发复读 + for uid in [20001, 20002, 20003]: + await handler.handle_message(_group_event(sender_id=uid, text="草")) + assert handler.sender.send_group_message.call_count == 1 + + # 第二轮:同一内容,应被冷却抑制——复读不发送 + for uid in [20004, 20005, 20006]: + await handler.handle_message(_group_event(sender_id=uid, text="草")) + assert handler.sender.send_group_message.call_count == 1 # 不增加 + + # 冷却抑制后消息应继续处理(到达 AI auto-reply),不应被 silently dropped + assert handler.ai_coordinator.handle_auto_reply.call_count > 0 + + +@pytest.mark.asyncio +async def test_repeat_cooldown_allows_different_text() -> None: + """复读 "草" 后,不同内容 "lol" 仍可正常复读。""" + handler = _build_handler(repeat_enabled=True, repeat_cooldown_minutes=60) + for uid in [20001, 20002, 20003]: + await handler.handle_message(_group_event(sender_id=uid, text="草")) + assert handler.sender.send_group_message.call_count == 1 + + for uid in [20004, 20005, 20006]: + await handler.handle_message(_group_event(sender_id=uid, text="lol")) + assert handler.sender.send_group_message.call_count == 2 + assert handler.sender.send_group_message.call_args.args[1] == "lol" + + +@pytest.mark.asyncio +async def test_repeat_cooldown_expired_allows_retrigger( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """冷却过期后,相同内容可以再次触发。""" + import time as _time + + handler = _build_handler(repeat_enabled=True, repeat_cooldown_minutes=60) + # 第一轮:触发 + for uid in [20001, 20002, 20003]: + await handler.handle_message(_group_event(sender_id=uid, text="草")) + assert handler.sender.send_group_message.call_count == 1 + + # 模拟时间流逝 61 分钟 + original_monotonic = _time.monotonic + monkeypatch.setattr(_time, "monotonic", lambda: original_monotonic() + 3660) + + # 第二轮:冷却已过期,应触发 + for uid in [20004, 20005, 20006]: + await handler.handle_message(_group_event(sender_id=uid, text="草")) + assert handler.sender.send_group_message.call_count == 2 + + +@pytest.mark.asyncio +async def test_repeat_cooldown_zero_disables() -> None: + """cooldown=0 时不启用冷却,连续复读同一内容仍可触发。""" + handler = _build_handler(repeat_enabled=True, repeat_cooldown_minutes=0) + for uid in [20001, 20002, 20003]: + await handler.handle_message(_group_event(sender_id=uid, text="草")) + assert handler.sender.send_group_message.call_count == 1 + + for uid in [20004, 20005, 20006]: + await handler.handle_message(_group_event(sender_id=uid, text="草")) + assert handler.sender.send_group_message.call_count == 2 + # cooldown=0 不应写入任何冷却记录(防止内存泄漏) + assert len(handler._repeat_cooldown) == 0 + + +@pytest.mark.asyncio +async def test_repeat_cooldown_question_mark_normalization() -> None: + """全角问号 ? 和半角问号 ? 视为等价,复读后互相抑制。""" + handler = _build_handler( + repeat_enabled=True, + repeat_cooldown_minutes=60, + inverted_question_enabled=True, + ) + # 全角问号触发复读 + for uid in [20001, 20002, 20003]: + await handler.handle_message(_group_event(sender_id=uid, text="???")) + assert handler.sender.send_group_message.call_count == 1 + assert handler.sender.send_group_message.call_args.args[1] == "¿¿¿" + + # 半角问号——应被冷却抑制(?和 ? 等价) + for uid in [20004, 20005, 20006]: + await handler.handle_message(_group_event(sender_id=uid, text="???")) + assert handler.sender.send_group_message.call_count == 1 # 不增加 + + +@pytest.mark.asyncio +async def test_repeat_cooldown_multiple_texts_tracked() -> None: + """多种不同内容各自独立冷却。""" + handler = _build_handler(repeat_enabled=True, repeat_cooldown_minutes=60) + # 复读 "草" + for uid in [20001, 20002, 20003]: + await handler.handle_message(_group_event(sender_id=uid, text="草")) + # 复读 "lol" + for uid in [20004, 20005, 20006]: + await handler.handle_message(_group_event(sender_id=uid, text="lol")) + assert handler.sender.send_group_message.call_count == 2 + + # "草" 再次满足条件 → 冷却中,不触发 + for uid in [20007, 20008, 20009]: + await handler.handle_message(_group_event(sender_id=uid, text="草")) + assert handler.sender.send_group_message.call_count == 2 + + # "lol" 再次满足条件 → 冷却中,不触发 + for uid in [20010, 20011, 20012]: + await handler.handle_message(_group_event(sender_id=uid, text="lol")) + assert handler.sender.send_group_message.call_count == 2 + + +@pytest.mark.asyncio +async def test_repeat_cooldown_groups_independent() -> None: + """不同群的冷却互不影响。""" + handler = _build_handler(repeat_enabled=True, repeat_cooldown_minutes=60) + # 群A 复读 "草" + for uid in [20001, 20002, 20003]: + await handler.handle_message( + _group_event(group_id=30001, sender_id=uid, text="草") + ) + assert handler.sender.send_group_message.call_count == 1 + + # 群B 复读 "草" — 不同群,不受群A冷却影响 + for uid in [20004, 20005, 20006]: + await handler.handle_message( + _group_event(group_id=30002, sender_id=uid, text="草") + ) + assert handler.sender.send_group_message.call_count == 2 diff --git a/tests/test_history_config.py b/tests/test_history_config.py new file mode 100644 index 00000000..c5b0d3e6 --- /dev/null +++ b/tests/test_history_config.py @@ -0,0 +1,76 @@ +"""Tests for configurable history limits in Config and tool handlers.""" + +from __future__ import annotations + +from typing import Any +from unittest.mock import MagicMock + + +# --------------------------------------------------------------------------- +# Helper: simulate _get_history_limit as used across multiple handlers +# --------------------------------------------------------------------------- +def _get_history_limit(context: dict[str, Any], key: str, fallback: int) -> int: + """Mirror of the helper used in tool handlers.""" + cfg = context.get("runtime_config") + if cfg is not None: + val = getattr(cfg, key, None) + if isinstance(val, int) and val > 0: + return val + return fallback + + +class TestGetHistoryLimit: + """Tests for the _get_history_limit helper pattern.""" + + def test_returns_fallback_when_no_config(self) -> None: + assert _get_history_limit({}, "history_filtered_result_limit", 200) == 200 + + def test_returns_fallback_when_config_is_none(self) -> None: + ctx: dict[str, Any] = {"runtime_config": None} + assert _get_history_limit(ctx, "history_filtered_result_limit", 200) == 200 + + def test_returns_config_value(self) -> None: + cfg = MagicMock() + cfg.history_filtered_result_limit = 500 + ctx: dict[str, Any] = {"runtime_config": cfg} + assert _get_history_limit(ctx, "history_filtered_result_limit", 200) == 500 + + def test_returns_fallback_when_config_attr_missing(self) -> None: + cfg = MagicMock(spec=[]) # no attributes + ctx: dict[str, Any] = {"runtime_config": cfg} + assert _get_history_limit(ctx, "nonexistent_field", 42) == 42 + + def test_returns_fallback_when_config_value_zero(self) -> None: + cfg = MagicMock() + cfg.history_filtered_result_limit = 0 + ctx: dict[str, Any] = {"runtime_config": cfg} + assert _get_history_limit(ctx, "history_filtered_result_limit", 200) == 200 + + def test_returns_fallback_when_config_value_negative(self) -> None: + cfg = MagicMock() + cfg.history_filtered_result_limit = -1 + ctx: dict[str, Any] = {"runtime_config": cfg} + assert _get_history_limit(ctx, "history_filtered_result_limit", 200) == 200 + + +class TestConfigHistoryFieldDefaults: + """Verify that Config dataclass has the expected history fields.""" + + def test_config_has_history_fields(self) -> None: + from Undefined.config.loader import Config + + fields = Config.__dataclass_fields__ + expected_fields = [ + "history_max_records", + "history_filtered_result_limit", + "history_search_scan_limit", + "history_summary_fetch_limit", + "history_summary_time_fetch_limit", + "history_onebot_fetch_limit", + "history_group_analysis_limit", + ] + for field_name in expected_fields: + assert field_name in fields, f"Missing field: {field_name}" + assert fields[field_name].type == "int", ( + f"{field_name} should be int, got {fields[field_name].type}" + ) diff --git a/tests/test_history_level.py b/tests/test_history_level.py new file mode 100644 index 00000000..59a88d4e --- /dev/null +++ b/tests/test_history_level.py @@ -0,0 +1,174 @@ +from __future__ import annotations + +import asyncio + +import pytest + +from Undefined.utils.history import MessageHistoryManager + + +@pytest.mark.asyncio +async def test_add_group_message_with_level_stores_level_field( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """测试添加带 level 的群消息会正确存储 level 字段""" + manager = MessageHistoryManager.__new__(MessageHistoryManager) + manager._message_history = {} + manager._max_records = 10000 + manager._initialized = asyncio.Event() + manager._initialized.set() + manager._group_locks = {} + + saved_data: dict[str, list[dict[str, object]]] = {} + + async def fake_save(data: list[dict[str, object]], path: str) -> None: + saved_data[path] = data + + monkeypatch.setattr(manager, "_save_history_to_file", fake_save) + + await manager.add_group_message( + group_id=20001, + sender_id=10001, + text_content="测试消息", + sender_card="测试用户", + group_name="测试群", + role="member", + title="", + level="Lv.5", + message_id=123456, + ) + + assert "20001" in manager._message_history + assert len(manager._message_history["20001"]) == 1 + + record = manager._message_history["20001"][0] + assert record["level"] == "Lv.5" + assert record["type"] == "group" + assert record["chat_id"] == "20001" + assert record["user_id"] == "10001" + assert record["message"] == "测试消息" + assert record["role"] == "member" + assert record["title"] == "" + assert record["message_id"] == 123456 + + +@pytest.mark.asyncio +async def test_add_group_message_without_level_stores_empty_level( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """测试添加群消息不传 level 参数时默认存储空字符串""" + manager = MessageHistoryManager.__new__(MessageHistoryManager) + manager._message_history = {} + manager._max_records = 10000 + manager._initialized = asyncio.Event() + manager._initialized.set() + manager._group_locks = {} + + saved_data: dict[str, list[dict[str, object]]] = {} + + async def fake_save(data: list[dict[str, object]], path: str) -> None: + saved_data[path] = data + + monkeypatch.setattr(manager, "_save_history_to_file", fake_save) + + await manager.add_group_message( + group_id=20001, + sender_id=10001, + text_content="测试消息", + sender_card="测试用户", + group_name="测试群", + ) + + assert "20001" in manager._message_history + record = manager._message_history["20001"][0] + assert record["level"] == "" + + +@pytest.mark.asyncio +async def test_get_recent_returns_messages_with_level_intact( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """测试 get_recent 能正确返回带 level 的消息""" + manager = MessageHistoryManager.__new__(MessageHistoryManager) + manager._message_history = { + "20001": [ + { + "type": "group", + "chat_id": "20001", + "chat_name": "测试群", + "user_id": "10001", + "display_name": "测试用户", + "role": "admin", + "title": "管理员", + "level": "Lv.10", + "timestamp": "2026-04-11 10:00:00", + "message": "第一条消息", + }, + { + "type": "group", + "chat_id": "20001", + "chat_name": "测试群", + "user_id": "10002", + "display_name": "普通用户", + "role": "member", + "title": "", + "level": "Lv.2", + "timestamp": "2026-04-11 10:01:00", + "message": "第二条消息", + }, + { + "type": "group", + "chat_id": "20001", + "chat_name": "测试群", + "user_id": "10003", + "display_name": "新用户", + "role": "member", + "title": "", + "level": "", + "timestamp": "2026-04-11 10:02:00", + "message": "第三条消息", + }, + ] + } + manager._max_records = 10000 + _evt = asyncio.Event() + _evt.set() + manager._initialized = _evt + + messages = manager.get_recent("20001", "group", 0, 10) + + assert len(messages) == 3 + assert messages[0]["level"] == "Lv.10" + assert messages[1]["level"] == "Lv.2" + assert messages[2]["level"] == "" + + +@pytest.mark.asyncio +async def test_add_group_message_with_empty_level_string( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """测试添加群消息显式传入空字符串 level""" + manager = MessageHistoryManager.__new__(MessageHistoryManager) + manager._message_history = {} + manager._max_records = 10000 + manager._initialized = asyncio.Event() + manager._initialized.set() + manager._group_locks = {} + + saved_data: dict[str, list[dict[str, object]]] = {} + + async def fake_save(data: list[dict[str, object]], path: str) -> None: + saved_data[path] = data + + monkeypatch.setattr(manager, "_save_history_to_file", fake_save) + + await manager.add_group_message( + group_id=20001, + sender_id=10001, + text_content="测试消息", + level="", + ) + + record = manager._message_history["20001"][0] + assert "level" in record + assert record["level"] == "" diff --git a/tests/test_member_utils.py b/tests/test_member_utils.py new file mode 100644 index 00000000..b509b79f --- /dev/null +++ b/tests/test_member_utils.py @@ -0,0 +1,196 @@ +"""Tests for Undefined.utils.member_utils — member analysis helpers.""" + +from __future__ import annotations + +from datetime import datetime +from typing import Any + +from Undefined.utils.member_utils import ( + analyze_join_trend, + analyze_member_activity, + filter_by_join_time, +) + + +def _make_member( + user_id: int, + join_time: int | float | None = None, + card: str = "", + nickname: str = "", +) -> dict[str, Any]: + m: dict[str, Any] = {"user_id": user_id, "card": card, "nickname": nickname} + if join_time is not None: + m["join_time"] = join_time + return m + + +# A fixed reference timestamp: 2024-06-15 12:00:00 +_REF_TS = int(datetime(2024, 6, 15, 12, 0, 0).timestamp()) +_DAY = 86400 + + +class TestFilterByJoinTime: + def test_empty_list(self) -> None: + assert filter_by_join_time([], None, None) == [] + + def test_no_filters(self) -> None: + members = [_make_member(1, _REF_TS), _make_member(2, _REF_TS + _DAY)] + result = filter_by_join_time(members, None, None) + assert len(result) == 2 + + def test_start_filter(self) -> None: + members = [ + _make_member(1, _REF_TS - _DAY), + _make_member(2, _REF_TS + _DAY), + ] + start_dt = datetime(2024, 6, 15, 0, 0, 0) + result = filter_by_join_time(members, start_dt, None) + assert len(result) == 1 + assert result[0]["user_id"] == 2 + + def test_end_filter(self) -> None: + members = [ + _make_member(1, _REF_TS - _DAY), + _make_member(2, _REF_TS + _DAY), + ] + end_dt = datetime(2024, 6, 15, 0, 0, 0) + result = filter_by_join_time(members, None, end_dt) + assert len(result) == 1 + assert result[0]["user_id"] == 1 + + def test_both_filters(self) -> None: + members = [ + _make_member(1, _REF_TS - 2 * _DAY), + _make_member(2, _REF_TS), + _make_member(3, _REF_TS + 2 * _DAY), + ] + start_dt = datetime.fromtimestamp(_REF_TS - _DAY) + end_dt = datetime.fromtimestamp(_REF_TS + _DAY) + result = filter_by_join_time(members, start_dt, end_dt) + assert len(result) == 1 + assert result[0]["user_id"] == 2 + + def test_member_without_join_time_skipped(self) -> None: + members = [_make_member(1), _make_member(2, _REF_TS)] + result = filter_by_join_time(members, None, None) + assert len(result) == 1 + assert result[0]["user_id"] == 2 + + def test_non_numeric_join_time_skipped(self) -> None: + members: list[dict[str, Any]] = [{"user_id": 1, "join_time": "not-a-number"}] + result = filter_by_join_time(members, None, None) + assert len(result) == 0 + + def test_float_join_time(self) -> None: + members = [_make_member(1, float(_REF_TS) + 0.5)] + result = filter_by_join_time(members, None, None) + assert len(result) == 1 + + +class TestAnalyzeJoinTrend: + def test_empty_list(self) -> None: + assert analyze_join_trend([]) == {} + + def test_single_member(self) -> None: + members = [_make_member(1, _REF_TS)] + result = analyze_join_trend(members) + assert result["peak_count"] == 1 + assert result["avg_per_day"] == 1.0 + assert result["first_time"] is not None + assert result["last_time"] is not None + assert result["first_time"] == result["last_time"] + + def test_multiple_members_same_day(self) -> None: + members = [ + _make_member(1, _REF_TS), + _make_member(2, _REF_TS + 3600), + ] + result = analyze_join_trend(members) + assert result["peak_count"] == 2 + assert result["avg_per_day"] == 2.0 + + def test_multiple_days(self) -> None: + members = [ + _make_member(1, _REF_TS), + _make_member(2, _REF_TS + _DAY), + _make_member(3, _REF_TS + _DAY), + ] + result = analyze_join_trend(members) + assert len(result["daily_stats"]) == 2 + assert result["peak_count"] == 2 + assert result["avg_per_day"] == 1.5 + + def test_members_without_join_time_ignored(self) -> None: + members = [_make_member(1), _make_member(2, _REF_TS)] + result = analyze_join_trend(members) + # Only one member has join_time, but total uses all members + assert result["avg_per_day"] == 2.0 # 2 members / 1 day + assert result["peak_count"] == 1 + + def test_daily_stats_populated(self) -> None: + members = [_make_member(1, _REF_TS)] + result = analyze_join_trend(members) + assert isinstance(result["daily_stats"], dict) + assert len(result["daily_stats"]) == 1 + + +class TestAnalyzeMemberActivity: + def test_empty_members(self) -> None: + result = analyze_member_activity([], {}, 5) + assert result["total_members"] == 0 + assert result["active_members"] == 0 + assert result["total_messages"] == 0 + assert result["top_members"] == [] + + def test_basic_activity(self) -> None: + members = [ + _make_member(1, _REF_TS, nickname="Alice"), + _make_member(2, _REF_TS, nickname="Bob"), + _make_member(3, _REF_TS, nickname="Charlie"), + ] + counts: dict[int, int] = {1: 100, 2: 50, 3: 0} + result = analyze_member_activity(members, counts, 5) + assert result["total_members"] == 3 + assert result["active_members"] == 2 + assert result["inactive_members"] == 1 + assert result["total_messages"] == 150 + assert result["avg_messages"] == 50.0 + assert len(result["top_members"]) == 2 + assert result["top_members"][0]["user_id"] == 1 + + def test_top_count_limit(self) -> None: + members = [_make_member(i, _REF_TS) for i in range(1, 11)] + counts: dict[int, int] = {i: i * 10 for i in range(1, 11)} + result = analyze_member_activity(members, counts, 3) + assert len(result["top_members"]) == 3 + assert result["top_members"][0]["user_id"] == 10 + + def test_active_rate_calculation(self) -> None: + members = [_make_member(1), _make_member(2)] + counts: dict[int, int] = {1: 10, 2: 0} + result = analyze_member_activity(members, counts, 5) + assert result["active_rate"] == 50.0 + + def test_zero_count_excluded_from_top(self) -> None: + members = [_make_member(1, _REF_TS, nickname="A")] + counts: dict[int, int] = {1: 0} + result = analyze_member_activity(members, counts, 5) + assert result["top_members"] == [] + + def test_member_with_card_name(self) -> None: + members = [_make_member(1, _REF_TS, card="CardName", nickname="Nick")] + counts: dict[int, int] = {1: 10} + result = analyze_member_activity(members, counts, 5) + assert result["top_members"][0]["nickname"] == "CardName" + + def test_join_time_formatted_in_top(self) -> None: + members = [_make_member(1, _REF_TS, nickname="A")] + counts: dict[int, int] = {1: 5} + result = analyze_member_activity(members, counts, 5) + assert result["top_members"][0]["join_time"] != "" + + def test_no_join_time_empty_string(self) -> None: + members = [_make_member(1, nickname="A")] + counts: dict[int, int] = {1: 5} + result = analyze_member_activity(members, counts, 5) + assert result["top_members"][0]["join_time"] == "" diff --git a/tests/test_meme_gif_frames.py b/tests/test_meme_gif_frames.py new file mode 100644 index 00000000..aea1f2b1 --- /dev/null +++ b/tests/test_meme_gif_frames.py @@ -0,0 +1,132 @@ +from __future__ import annotations + +import math +from pathlib import Path + +from PIL import Image + +from Undefined.memes.service import ( + _compose_grid, + _extract_gif_frames, + _sample_frame_indices, +) + + +def _make_gif(path: Path, n_frames: int, size: tuple[int, int] = (4, 4)) -> None: + """创建一个包含 *n_frames* 帧的 GIF 文件。""" + frames = [ + Image.new("RGBA", size, (i * 30 % 256, i * 60 % 256, i * 90 % 256, 255)) + for i in range(n_frames) + ] + frames[0].save( + path, + format="GIF", + save_all=True, + append_images=frames[1:], + loop=0, + duration=100, + ) + + +# ── _sample_frame_indices ── + + +def test_sample_indices_basic() -> None: + result = _sample_frame_indices(10, 4) + assert result[0] == 0 + assert result[-1] == 9 + assert len(result) == 4 + + +def test_sample_indices_more_than_total() -> None: + result = _sample_frame_indices(3, 10) + assert result == [0, 1, 2] + + +def test_sample_indices_two() -> None: + result = _sample_frame_indices(20, 2) + assert result == [0, 19] + + +def test_sample_indices_one() -> None: + result = _sample_frame_indices(5, 1) + assert result == [0] + + +def test_sample_indices_no_duplicates() -> None: + result = _sample_frame_indices(3, 3) + assert len(result) == len(set(result)) + + +# ── _extract_gif_frames ── + + +def test_extract_frames_count(tmp_path: Path) -> None: + gif_path = tmp_path / "test.gif" + _make_gif(gif_path, 12) + frames = _extract_gif_frames(gif_path, 6) + assert len(frames) == 6 + for f in frames: + assert f.mode == "RGBA" + f.close() + + +def test_extract_frames_fewer_than_requested(tmp_path: Path) -> None: + gif_path = tmp_path / "test.gif" + _make_gif(gif_path, 3) + frames = _extract_gif_frames(gif_path, 6) + assert len(frames) == 3 + for f in frames: + f.close() + + +def test_extract_frames_single_frame(tmp_path: Path) -> None: + gif_path = tmp_path / "test.gif" + _make_gif(gif_path, 1) + frames = _extract_gif_frames(gif_path, 6) + assert len(frames) == 1 + frames[0].close() + + +# ── _compose_grid ── + + +def test_compose_grid_output(tmp_path: Path) -> None: + frames = [ + Image.new("RGBA", (10, 10), (255, 0, 0, 255)), + Image.new("RGBA", (10, 10), (0, 255, 0, 255)), + Image.new("RGBA", (10, 10), (0, 0, 255, 255)), + Image.new("RGBA", (10, 10), (255, 255, 0, 255)), + ] + output = tmp_path / "grid.png" + _compose_grid(frames, output) + assert output.is_file() + with Image.open(output) as grid: + cols = math.ceil(math.sqrt(4)) + rows = math.ceil(4 / cols) + assert grid.size == (cols * 10, rows * 10) + for f in frames: + f.close() + + +def test_compose_grid_single_frame(tmp_path: Path) -> None: + frames = [Image.new("RGBA", (8, 8), (0, 0, 0, 255))] + output = tmp_path / "grid_single.png" + _compose_grid(frames, output) + assert output.is_file() + with Image.open(output) as grid: + assert grid.size == (8, 8) + frames[0].close() + + +def test_compose_grid_six_frames(tmp_path: Path) -> None: + frames = [Image.new("RGBA", (10, 10), (i * 40, 0, 0, 255)) for i in range(6)] + output = tmp_path / "grid6.png" + _compose_grid(frames, output) + assert output.is_file() + with Image.open(output) as grid: + cols = math.ceil(math.sqrt(6)) + rows = math.ceil(6 / cols) + assert grid.size == (cols * 10, rows * 10) + for f in frames: + f.close() diff --git a/tests/test_meme_retry.py b/tests/test_meme_retry.py new file mode 100644 index 00000000..b07ef0e7 --- /dev/null +++ b/tests/test_meme_retry.py @@ -0,0 +1,63 @@ +from __future__ import annotations + + +from openai import APIConnectionError, APIStatusError, APITimeoutError +from unittest.mock import MagicMock + +from Undefined.memes.service import _is_retryable_llm_error + + +def _make_api_status_error(status_code: int) -> APIStatusError: + response = MagicMock() + response.status_code = status_code + response.headers = {} + response.text = "" + response.json.return_value = {} + return APIStatusError( + message=f"Error {status_code}", + response=response, + body=None, + ) + + +def test_connection_error_is_retryable() -> None: + exc = APIConnectionError(request=MagicMock()) + assert _is_retryable_llm_error(exc) is True + + +def test_timeout_error_is_retryable() -> None: + exc = APITimeoutError(request=MagicMock()) + assert _is_retryable_llm_error(exc) is True + + +def test_status_429_is_retryable() -> None: + exc = _make_api_status_error(429) + assert _is_retryable_llm_error(exc) is True + + +def test_status_500_is_retryable() -> None: + exc = _make_api_status_error(500) + assert _is_retryable_llm_error(exc) is True + + +def test_status_503_is_retryable() -> None: + exc = _make_api_status_error(503) + assert _is_retryable_llm_error(exc) is True + + +def test_status_401_not_retryable() -> None: + exc = _make_api_status_error(401) + assert _is_retryable_llm_error(exc) is False + + +def test_status_400_not_retryable() -> None: + exc = _make_api_status_error(400) + assert _is_retryable_llm_error(exc) is False + + +def test_generic_exception_not_retryable() -> None: + assert _is_retryable_llm_error(ValueError("parse fail")) is False + + +def test_runtime_error_not_retryable() -> None: + assert _is_retryable_llm_error(RuntimeError("oops")) is False diff --git a/tests/test_memory_unit.py b/tests/test_memory_unit.py new file mode 100644 index 00000000..75917c88 --- /dev/null +++ b/tests/test_memory_unit.py @@ -0,0 +1,229 @@ +"""MemoryStorage 单元测试""" + +from __future__ import annotations + +from dataclasses import asdict +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from Undefined.memory import Memory, MemoryStorage + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _make_storage( + initial_data: list[dict[str, str]] | None = None, + max_memories: int = 500, +) -> MemoryStorage: + """构造 MemoryStorage 并跳过真实文件 I/O。""" + with patch("Undefined.memory.MEMORY_FILE_PATH") as mock_path: + if initial_data is not None: + import io as _io + import json + + mock_path.exists.return_value = True + mock_file = _io.StringIO(json.dumps(initial_data)) + mock_open = MagicMock(return_value=mock_file) + with patch("builtins.open", mock_open): + storage = MemoryStorage(max_memories=max_memories) + else: + mock_path.exists.return_value = False + storage = MemoryStorage(max_memories=max_memories) + return storage + + +_WRITE_JSON = "Undefined.utils.io.write_json" + + +# --------------------------------------------------------------------------- +# Memory dataclass +# --------------------------------------------------------------------------- + + +class TestMemoryDataclass: + def test_fields(self) -> None: + m = Memory(uuid="u1", fact="hello", created_at="2025-01-01") + assert m.uuid == "u1" + assert m.fact == "hello" + assert m.created_at == "2025-01-01" + + def test_asdict(self) -> None: + m = Memory(uuid="u1", fact="hello", created_at="2025-01-01") + d = asdict(m) + assert d == {"uuid": "u1", "fact": "hello", "created_at": "2025-01-01"} + + +# --------------------------------------------------------------------------- +# MemoryStorage +# --------------------------------------------------------------------------- + + +class TestMemoryStorageInit: + def test_empty_init(self) -> None: + storage = _make_storage() + assert storage.count() == 0 + assert storage.get_all() == [] + + def test_init_with_data(self) -> None: + data = [ + {"uuid": "u1", "fact": "fact1", "created_at": "2025-01-01"}, + {"uuid": "u2", "fact": "fact2", "created_at": "2025-01-02"}, + ] + storage = _make_storage(initial_data=data) + assert storage.count() == 2 + + def test_init_with_legacy_data_without_uuid(self) -> None: + """旧格式记录不含 uuid,应自动生成。""" + data: list[dict[str, str]] = [ + {"fact": "old fact", "created_at": "2024-01-01"}, + ] + storage = _make_storage(initial_data=data) + assert storage.count() == 1 + memories = storage.get_all() + assert memories[0].fact == "old fact" + assert memories[0].uuid # 自动生成了 UUID + + +class TestMemoryStorageAdd: + @pytest.mark.asyncio + async def test_add_returns_uuid(self) -> None: + storage = _make_storage() + with patch(_WRITE_JSON, new_callable=AsyncMock): + result = await storage.add("new fact") + assert result is not None + assert storage.count() == 1 + + @pytest.mark.asyncio + async def test_add_strips_whitespace(self) -> None: + storage = _make_storage() + with patch(_WRITE_JSON, new_callable=AsyncMock): + await storage.add(" spaced fact ") + assert storage.get_all()[0].fact == "spaced fact" + + @pytest.mark.asyncio + async def test_add_empty_returns_none(self) -> None: + storage = _make_storage() + result = await storage.add("") + assert result is None + assert storage.count() == 0 + + @pytest.mark.asyncio + async def test_add_whitespace_only_returns_none(self) -> None: + storage = _make_storage() + result = await storage.add(" ") + assert result is None + assert storage.count() == 0 + + @pytest.mark.asyncio + async def test_add_duplicate_returns_existing_uuid(self) -> None: + storage = _make_storage() + with patch(_WRITE_JSON, new_callable=AsyncMock): + uuid1 = await storage.add("duplicate fact") + uuid2 = await storage.add("duplicate fact") + assert uuid1 == uuid2 + assert storage.count() == 1 + + @pytest.mark.asyncio + async def test_add_max_memories_evicts_oldest(self) -> None: + storage = _make_storage(max_memories=3) + with patch(_WRITE_JSON, new_callable=AsyncMock): + await storage.add("fact1") + await storage.add("fact2") + await storage.add("fact3") + assert storage.count() == 3 + await storage.add("fact4") + assert storage.count() == 3 + facts = [m.fact for m in storage.get_all()] + assert "fact1" not in facts + assert "fact4" in facts + + @pytest.mark.asyncio + async def test_add_calls_save(self) -> None: + storage = _make_storage() + with patch(_WRITE_JSON, new_callable=AsyncMock) as mock_write: + await storage.add("fact") + mock_write.assert_awaited_once() + + +class TestMemoryStorageUpdate: + @pytest.mark.asyncio + async def test_update_existing(self) -> None: + storage = _make_storage() + with patch(_WRITE_JSON, new_callable=AsyncMock): + uid = await storage.add("old") + assert uid is not None + result = await storage.update(uid, "new") + assert result is True + assert storage.get_all()[0].fact == "new" + + @pytest.mark.asyncio + async def test_update_nonexistent_returns_false(self) -> None: + storage = _make_storage() + with patch(_WRITE_JSON, new_callable=AsyncMock): + result = await storage.update("nonexistent-uuid", "new") + assert result is False + + @pytest.mark.asyncio + async def test_update_strips_whitespace(self) -> None: + storage = _make_storage() + with patch(_WRITE_JSON, new_callable=AsyncMock): + uid = await storage.add("old") + assert uid is not None + await storage.update(uid, " updated ") + assert storage.get_all()[0].fact == "updated" + + +class TestMemoryStorageDelete: + @pytest.mark.asyncio + async def test_delete_existing(self) -> None: + storage = _make_storage() + with patch(_WRITE_JSON, new_callable=AsyncMock): + uid = await storage.add("to delete") + assert uid is not None + result = await storage.delete(uid) + assert result is True + assert storage.count() == 0 + + @pytest.mark.asyncio + async def test_delete_nonexistent_returns_false(self) -> None: + storage = _make_storage() + with patch(_WRITE_JSON, new_callable=AsyncMock): + result = await storage.delete("nonexistent-uuid") + assert result is False + + +class TestMemoryStorageGetAll: + def test_get_all_returns_copy(self) -> None: + data = [{"uuid": "u1", "fact": "fact1", "created_at": "2025-01-01"}] + storage = _make_storage(initial_data=data) + list1 = storage.get_all() + list2 = storage.get_all() + assert list1 is not list2 + assert list1 == list2 + + +class TestMemoryStorageClear: + @pytest.mark.asyncio + async def test_clear(self) -> None: + data = [{"uuid": "u1", "fact": "fact1", "created_at": "2025-01-01"}] + storage = _make_storage(initial_data=data) + with patch(_WRITE_JSON, new_callable=AsyncMock): + await storage.clear() + assert storage.count() == 0 + assert storage.get_all() == [] + + +class TestMemoryStorageCount: + @pytest.mark.asyncio + async def test_count_tracks_additions(self) -> None: + storage = _make_storage() + assert storage.count() == 0 + with patch(_WRITE_JSON, new_callable=AsyncMock): + await storage.add("a") + assert storage.count() == 1 + await storage.add("b") + assert storage.count() == 2 diff --git a/tests/test_message_targets.py b/tests/test_message_targets.py new file mode 100644 index 00000000..69ee58f9 --- /dev/null +++ b/tests/test_message_targets.py @@ -0,0 +1,185 @@ +"""Tests for Undefined.utils.message_targets — target resolution helpers.""" + +from __future__ import annotations + +from typing import Any + +from Undefined.utils.message_targets import parse_positive_int, resolve_message_target + + +class TestParsePositiveInt: + def test_valid_int(self) -> None: + val, err = parse_positive_int(42, "field") + assert val == 42 + assert err is None + + def test_valid_string_int(self) -> None: + val, err = parse_positive_int("123", "field") + assert val == 123 + assert err is None + + def test_none_input(self) -> None: + val, err = parse_positive_int(None, "field") + assert val is None + assert err is None + + def test_zero_rejected(self) -> None: + val, err = parse_positive_int(0, "field") + assert val is None + assert err is not None + assert "正整数" in (err or "") + + def test_negative_rejected(self) -> None: + val, err = parse_positive_int(-5, "field") + assert val is None + assert err is not None + + def test_non_numeric_string(self) -> None: + val, err = parse_positive_int("abc", "field") + assert val is None + assert err is not None + assert "整数" in (err or "") + + def test_float_truncated(self) -> None: + val, err = parse_positive_int(3.9, "field") + assert val == 3 + assert err is None + + def test_float_string_rejected(self) -> None: + val, err = parse_positive_int("3.5", "field") + assert val is None + assert err is not None + + def test_bool_treated_as_int(self) -> None: + # bool is subclass of int; True -> 1 + val, err = parse_positive_int(True, "field") + assert val == 1 + assert err is None + + def test_field_name_in_error(self) -> None: + _, err = parse_positive_int("bad", "target_id") + assert err is not None + assert "target_id" in err + + +class TestResolveMessageTarget: + @staticmethod + def _call( + args: dict[str, Any] | None = None, + context: dict[str, Any] | None = None, + ) -> tuple[tuple[str, int] | None, str | None]: + result: tuple[tuple[str, int] | None, str | None] = resolve_message_target( + args or {}, context or {} + ) + return result + + def test_explicit_group_target(self) -> None: + target, err = self._call( + args={"target_type": "group", "target_id": 12345}, + ) + assert target == ("group", 12345) + assert err is None + + def test_explicit_private_target(self) -> None: + target, err = self._call( + args={"target_type": "private", "target_id": 67890}, + ) + assert target == ("private", 67890) + assert err is None + + def test_target_type_case_insensitive(self) -> None: + target, err = self._call( + args={"target_type": "GROUP", "target_id": 1}, + ) + assert target == ("group", 1) + + def test_target_type_without_id_infers_from_context(self) -> None: + target, err = self._call( + args={"target_type": "group"}, + context={"request_type": "group", "group_id": 100}, + ) + assert target == ("group", 100) + assert err is None + + def test_target_type_without_id_mismatch_context(self) -> None: + target, err = self._call( + args={"target_type": "group"}, + context={"request_type": "private", "user_id": 100}, + ) + assert target is None + assert err is not None + assert "不一致" in (err or "") + + def test_target_id_without_type_error(self) -> None: + target, err = self._call(args={"target_id": 123}) + assert target is None + assert err is not None + assert "同时提供" in (err or "") + + def test_invalid_target_type(self) -> None: + target, err = self._call( + args={"target_type": "channel", "target_id": 1}, + ) + assert target is None + assert err is not None + + def test_target_type_non_string(self) -> None: + target, err = self._call( + args={"target_type": 123, "target_id": 1}, + ) + assert target is None + assert err is not None + assert "字符串" in (err or "") + + def test_legacy_group_id(self) -> None: + target, err = self._call(args={"group_id": 999}) + assert target == ("group", 999) + assert err is None + + def test_legacy_user_id(self) -> None: + target, err = self._call(args={"user_id": 888}) + assert target == ("private", 888) + assert err is None + + def test_legacy_invalid_group_id(self) -> None: + target, err = self._call(args={"group_id": -1}) + assert target is None + assert err is not None + + def test_fallback_to_context_group(self) -> None: + target, err = self._call( + context={"request_type": "group", "group_id": 555}, + ) + assert target == ("group", 555) + assert err is None + + def test_fallback_to_context_private(self) -> None: + target, err = self._call( + context={"request_type": "private", "user_id": 444}, + ) + assert target == ("private", 444) + assert err is None + + def test_fallback_context_group_id_only(self) -> None: + target, err = self._call(context={"group_id": 333}) + assert target == ("group", 333) + assert err is None + + def test_fallback_context_user_id_only(self) -> None: + target, err = self._call(context={"user_id": 222}) + assert target == ("private", 222) + assert err is None + + def test_no_target_info_at_all(self) -> None: + target, err = self._call() + assert target is None + assert err is not None + assert "无法确定" in (err or "") + + def test_target_type_private_infer_from_context(self) -> None: + target, err = self._call( + args={"target_type": "private"}, + context={"request_type": "private", "user_id": 77}, + ) + assert target == ("private", 77) + assert err is None diff --git a/tests/test_message_tools_level.py b/tests/test_message_tools_level.py new file mode 100644 index 00000000..81995024 --- /dev/null +++ b/tests/test_message_tools_level.py @@ -0,0 +1,161 @@ +from __future__ import annotations + +from typing import Any + + +from Undefined.skills.toolsets.messages.get_recent_messages.handler import ( + _format_message_xml, +) + + +def test_format_message_xml_group_with_all_attributes() -> None: + """测试群消息包含 role/title/level 时 XML 全部显示""" + msg: dict[str, Any] = { + "type": "group", + "display_name": "测试用户", + "user_id": "10001", + "chat_id": "20001", + "chat_name": "测试群", + "timestamp": "2026-04-11 10:00:00", + "message": "测试消息", + "message_id": 123456, + "role": "admin", + "title": "管理员", + "level": "Lv.10", + } + + result = _format_message_xml(msg) + + assert 'role="admin"' in result + assert 'title="管理员"' in result + assert 'level="Lv.10"' in result + assert 'message_id="123456"' in result + assert "测试消息" in result + + +def test_format_message_xml_group_with_empty_level() -> None: + """测试群消息 level 为空时 XML 不包含 level 属性""" + msg: dict[str, Any] = { + "type": "group", + "display_name": "测试用户", + "user_id": "10001", + "chat_id": "20001", + "chat_name": "测试群", + "timestamp": "2026-04-11 10:00:00", + "message": "测试消息", + "role": "member", + "title": "", + "level": "", + } + + result = _format_message_xml(msg) + + assert "level=" not in result + assert 'role="member"' in result + assert "title=" not in result + + +def test_format_message_xml_private_without_level() -> None: + """测试私聊消息不包含 role/title/level 属性""" + msg: dict[str, Any] = { + "type": "private", + "display_name": "测试用户", + "user_id": "10001", + "chat_id": "10001", + "chat_name": "QQ用户10001", + "timestamp": "2026-04-11 10:00:00", + "message": "私聊消息", + } + + result = _format_message_xml(msg) + + assert "role=" not in result + assert "title=" not in result + assert "level=" not in result + assert "私聊消息" in result + + +def test_format_message_xml_group_with_only_level() -> None: + """测试群消息只设置 level 时仅显示 level 属性""" + msg: dict[str, Any] = { + "type": "group", + "display_name": "测试用户", + "user_id": "10001", + "chat_id": "20001", + "chat_name": "测试群", + "timestamp": "2026-04-11 10:00:00", + "message": "测试消息", + "role": "", + "title": "", + "level": "Lv.5", + } + + result = _format_message_xml(msg) + + assert 'level="Lv.5"' in result + assert "role=" not in result + assert "title=" not in result + + +def test_format_message_xml_group_without_level_key() -> None: + """测试群消息没有 level 键时不显示 level 属性""" + msg: dict[str, Any] = { + "type": "group", + "display_name": "测试用户", + "user_id": "10001", + "chat_id": "20001", + "chat_name": "测试群", + "timestamp": "2026-04-11 10:00:00", + "message": "测试消息", + "role": "member", + "title": "", + } + + result = _format_message_xml(msg) + + assert "level=" not in result + assert 'role="member"' in result + + +def test_format_message_xml_group_with_role_and_title() -> None: + """测试群消息有 role 和 title 但无 level""" + msg: dict[str, Any] = { + "type": "group", + "display_name": "测试用户", + "user_id": "10001", + "chat_id": "20001", + "chat_name": "测试群", + "timestamp": "2026-04-11 10:00:00", + "message": "测试消息", + "role": "admin", + "title": "群主", + "level": "", + } + + result = _format_message_xml(msg) + + assert 'role="admin"' in result + assert 'title="群主"' in result + assert "level=" not in result + + +def test_format_message_xml_private_with_level_ignored() -> None: + """测试私聊消息即使有 level 也不显示""" + msg: dict[str, Any] = { + "type": "private", + "display_name": "测试用户", + "user_id": "10001", + "chat_id": "10001", + "chat_name": "QQ用户10001", + "timestamp": "2026-04-11 10:00:00", + "message": "私聊消息", + "role": "member", + "title": "测试", + "level": "Lv.99", + } + + result = _format_message_xml(msg) + + assert "role=" not in result + assert "title=" not in result + assert "level=" not in result diff --git a/tests/test_message_utils.py b/tests/test_message_utils.py new file mode 100644 index 00000000..3341f26f --- /dev/null +++ b/tests/test_message_utils.py @@ -0,0 +1,232 @@ +"""Tests for Undefined.utils.message_utils.""" + +from __future__ import annotations + +from datetime import datetime +from typing import Any, Dict + +import pytest + +from Undefined.utils.message_utils import ( + analyze_activity_pattern, + count_message_types, + count_messages_by_user, + filter_user_messages, + format_messages, +) + + +@pytest.fixture(autouse=True) +def _mock_parse_message_time(monkeypatch: pytest.MonkeyPatch) -> None: + """Patch parse_message_time so tests don't depend on onebot imports.""" + + def _fake_parse(msg: Dict[str, Any]) -> datetime: + return datetime.fromtimestamp(msg.get("time", 0)) + + monkeypatch.setattr( + "Undefined.utils.message_utils.parse_message_time", + _fake_parse, + ) + + +def _msg( + user_id: int = 100, + ts: int = 1700000000, + message: Any = None, + nickname: str = "TestUser", +) -> Dict[str, Any]: + """Helper to build a minimal message dict.""" + return { + "sender": {"user_id": user_id, "nickname": nickname}, + "time": ts, + "message": message if message is not None else "hello", + } + + +# --------------------------------------------------------------------------- +# filter_user_messages +# --------------------------------------------------------------------------- + + +class TestFilterUserMessages: + def test_filters_by_user_id(self) -> None: + msgs = [_msg(user_id=1), _msg(user_id=2), _msg(user_id=1)] + result = filter_user_messages(msgs, user_id=1, start_dt=None, end_dt=None) + assert len(result) == 2 + + def test_filters_by_time_range(self) -> None: + msgs = [ + _msg(ts=1700000000), + _msg(ts=1700000100), + _msg(ts=1700000200), + ] + start = datetime.fromtimestamp(1700000050) + end = datetime.fromtimestamp(1700000150) + result = filter_user_messages(msgs, user_id=100, start_dt=start, end_dt=end) + assert len(result) == 1 + + def test_empty_messages(self) -> None: + result = filter_user_messages([], user_id=1, start_dt=None, end_dt=None) + assert result == [] + + def test_no_time_bounds(self) -> None: + msgs = [_msg(user_id=100, ts=1700000000)] + result = filter_user_messages(msgs, user_id=100, start_dt=None, end_dt=None) + assert len(result) == 1 + + +# --------------------------------------------------------------------------- +# count_message_types +# --------------------------------------------------------------------------- + + +class TestCountMessageTypes: + def test_string_message_is_text(self) -> None: + msgs = [_msg(message="hi")] + result = count_message_types(msgs) + assert result == {"文本消息": 1} + + def test_image_segment(self) -> None: + msgs = [_msg(message=[{"type": "image", "data": {}}])] + result = count_message_types(msgs) + assert result == {"图片消息": 1} + + def test_reply_priority_over_text(self) -> None: + msgs = [ + _msg( + message=[ + {"type": "reply", "data": {}}, + {"type": "text", "data": {"text": "hi"}}, + ] + ) + ] + result = count_message_types(msgs) + assert result == {"回复消息": 1} + + def test_face_segment(self) -> None: + msgs = [_msg(message=[{"type": "face", "data": {}}])] + result = count_message_types(msgs) + assert result == {"表情消息": 1} + + def test_empty_segment_list(self) -> None: + msgs: list[Dict[str, Any]] = [_msg(message=[])] + result = count_message_types(msgs) + assert result == {"空消息": 1} + + def test_other_segment_type(self) -> None: + msgs = [_msg(message=[{"type": "forward", "data": {}}])] + result = count_message_types(msgs) + assert result == {"其他消息": 1} + + def test_text_only_segments(self) -> None: + msgs = [_msg(message=[{"type": "text", "data": {"text": "hello"}}])] + result = count_message_types(msgs) + assert result == {"文本消息": 1} + + def test_mixed_messages(self) -> None: + msgs = [ + _msg(message="hi"), + _msg(message=[{"type": "image", "data": {}}]), + _msg(message=[{"type": "face", "data": {}}]), + ] + result = count_message_types(msgs) + assert result == {"文本消息": 1, "图片消息": 1, "表情消息": 1} + + +# --------------------------------------------------------------------------- +# analyze_activity_pattern +# --------------------------------------------------------------------------- + + +class TestAnalyzeActivityPattern: + def test_empty_returns_empty_dict(self) -> None: + assert analyze_activity_pattern([]) == {} + + def test_single_message(self) -> None: + ts = 1700000000 + msgs = [_msg(ts=ts)] + result = analyze_activity_pattern(msgs) + assert result["avg_per_day"] == 1.0 + assert result["first_time"] is not None + assert result["last_time"] is not None + assert result["first_time"] == result["last_time"] + + def test_multiple_messages_avg_per_day(self) -> None: + # Two messages on the same day + msgs = [_msg(ts=1700000000), _msg(ts=1700000100)] + result = analyze_activity_pattern(msgs) + assert result["avg_per_day"] == 2.0 + + def test_most_active_hour_format(self) -> None: + msgs = [_msg(ts=1700000000)] + result = analyze_activity_pattern(msgs) + hour_str: str = result["most_active_hour"] + assert ":00-" in hour_str + assert ":59" in hour_str + + def test_weekday_is_chinese(self) -> None: + msgs = [_msg(ts=1700000000)] + result = analyze_activity_pattern(msgs) + weekday_str: str = result["most_active_weekday"] + assert weekday_str.startswith("周") + + +# --------------------------------------------------------------------------- +# count_messages_by_user +# --------------------------------------------------------------------------- + + +class TestCountMessagesByUser: + def test_counts_correctly(self) -> None: + msgs = [_msg(user_id=1), _msg(user_id=2), _msg(user_id=1)] + result = count_messages_by_user(msgs, {1, 2, 3}) + assert result == {1: 2, 2: 1, 3: 0} + + def test_unknown_user_ignored(self) -> None: + msgs = [_msg(user_id=99)] + result = count_messages_by_user(msgs, {1}) + assert result == {1: 0} + + def test_empty_messages(self) -> None: + result = count_messages_by_user([], {1, 2}) + assert result == {1: 0, 2: 0} + + +# --------------------------------------------------------------------------- +# format_messages +# --------------------------------------------------------------------------- + + +class TestFormatMessages: + def test_basic_format(self) -> None: + msgs = [_msg(user_id=42, ts=1700000000, nickname="Alice")] + result = format_messages(msgs) + assert len(result) == 1 + assert result[0]["sender"] == "Alice" + assert result[0]["sender_id"] == 42 + assert "2023" in result[0]["time"] + assert result[0]["content"] == "hello" + + def test_segment_format(self) -> None: + msg = _msg( + message=[ + {"type": "text", "data": {"text": "hi "}}, + {"type": "image", "data": {}}, + ] + ) + result = format_messages([msg]) + assert result[0]["content"] == "hi [图片]" + + def test_empty_content_placeholder(self) -> None: + msg = _msg(message=[]) + result = format_messages([msg]) + assert result[0]["content"] == "(空消息)" + + def test_card_preferred_over_nickname(self) -> None: + msg: Dict[str, Any] = { + "sender": {"user_id": 1, "card": "CardName", "nickname": "Nick"}, + "time": 1700000000, + "message": "hi", + } + result = format_messages([msg]) + assert result[0]["sender"] == "CardName" diff --git a/tests/test_profile_command.py b/tests/test_profile_command.py new file mode 100644 index 00000000..b00cd7b0 --- /dev/null +++ b/tests/test_profile_command.py @@ -0,0 +1,442 @@ +from __future__ import annotations + +from types import SimpleNamespace +from typing import Any, cast +from unittest.mock import AsyncMock + +import pytest + +from Undefined.services.commands.context import CommandContext +from Undefined.skills.commands.profile.handler import execute as profile_execute + + +class _DummySender: + def __init__(self) -> None: + self.group_messages: list[tuple[int, str]] = [] + self.private_messages: list[tuple[int, str]] = [] + + async def send_group_message( + self, group_id: int, message: str, mark_sent: bool = False + ) -> None: + _ = mark_sent + self.group_messages.append((group_id, message)) + + async def send_private_message( + self, + user_id: int, + message: str, + auto_history: bool = True, + *, + mark_sent: bool = True, + ) -> None: + _ = (auto_history, mark_sent) + self.private_messages.append((user_id, message)) + + +def _build_context( + *, + sender: _DummySender | None = None, + cognitive_service: Any = None, + scope: str = "group", + group_id: int = 123456, + sender_id: int = 10002, + user_id: int | None = None, + superadmin_qq: int = 0, +) -> CommandContext: + config_stub = cast(Any, SimpleNamespace()) + config_stub.is_superadmin = lambda qq: qq == superadmin_qq + config_stub.bot_qq = 0 + stub = cast(Any, SimpleNamespace()) + if sender is None: + sender = _DummySender() + return CommandContext( + group_id=group_id, + sender_id=sender_id, + config=config_stub, + sender=cast(Any, sender), + ai=stub, + faq_storage=stub, + onebot=stub, + security=stub, + queue_manager=None, + rate_limiter=None, + dispatcher=stub, + registry=stub, + scope=scope, + user_id=user_id, + cognitive_service=cognitive_service, + ) + + +# -- Private chat tests -- + + +@pytest.mark.asyncio +async def test_profile_private_own_profile_found() -> None: + """Private chat, own profile found → sends profile via send_private_message.""" + sender = _DummySender() + cognitive_service = AsyncMock() + cognitive_service.get_profile = AsyncMock(return_value="这是一个用户侧写") + + context = _build_context( + sender=sender, + cognitive_service=cognitive_service, + scope="private", + group_id=0, + sender_id=99999, + user_id=99999, + ) + + await profile_execute([], context) + + assert len(sender.private_messages) == 1 + assert sender.private_messages[0][0] == 99999 + assert "这是一个用户侧写" in sender.private_messages[0][1] + cognitive_service.get_profile.assert_called_once_with("user", "99999") + + +@pytest.mark.asyncio +async def test_profile_private_own_profile_not_found() -> None: + """Private chat, own profile not found → sends '暂无侧写数据'.""" + sender = _DummySender() + cognitive_service = AsyncMock() + cognitive_service.get_profile = AsyncMock(return_value="") + + context = _build_context( + sender=sender, + cognitive_service=cognitive_service, + scope="private", + group_id=0, + sender_id=88888, + user_id=88888, + ) + + await profile_execute([], context) + + assert len(sender.private_messages) == 1 + assert "📭 暂无侧写数据" in sender.private_messages[0][1] + + +@pytest.mark.asyncio +async def test_profile_private_group_subcommand_rejected() -> None: + """Private chat, `/profile group` rejected → sends error message.""" + sender = _DummySender() + cognitive_service = AsyncMock() + + context = _build_context( + sender=sender, + cognitive_service=cognitive_service, + scope="private", + group_id=0, + sender_id=77777, + user_id=77777, + ) + + await profile_execute(["group"], context) + + assert len(sender.private_messages) == 1 + assert "❌ 私聊中不支持查看群聊侧写" in sender.private_messages[0][1] + cognitive_service.get_profile.assert_not_called() + + +# -- Group chat tests -- + + +@pytest.mark.asyncio +async def test_profile_group_own_profile() -> None: + """Group chat, own profile → sends profile via send_group_message.""" + sender = _DummySender() + cognitive_service = AsyncMock() + cognitive_service.get_profile = AsyncMock(return_value="群成员侧写数据") + + context = _build_context( + sender=sender, + cognitive_service=cognitive_service, + scope="group", + group_id=123456, + sender_id=55555, + ) + + await profile_execute([], context) + + assert len(sender.group_messages) == 1 + assert sender.group_messages[0][0] == 123456 + assert "群成员侧写数据" in sender.group_messages[0][1] + cognitive_service.get_profile.assert_called_once_with("user", "55555") + + +@pytest.mark.asyncio +async def test_profile_group_profile_subcommand() -> None: + """Group chat, `/profile group` → sends group profile via send_group_message.""" + sender = _DummySender() + cognitive_service = AsyncMock() + cognitive_service.get_profile = AsyncMock(return_value="群聊整体侧写") + + context = _build_context( + sender=sender, + cognitive_service=cognitive_service, + scope="group", + group_id=654321, + sender_id=44444, + ) + + await profile_execute(["GROUP"], context) # Test case-insensitive + + assert len(sender.group_messages) == 1 + assert sender.group_messages[0][0] == 654321 + assert "群聊整体侧写" in sender.group_messages[0][1] + cognitive_service.get_profile.assert_called_once_with("group", "654321") + + +@pytest.mark.asyncio +async def test_profile_group_profile_g_shorthand() -> None: + """Group chat, `/p g` shorthand → shows group profile.""" + sender = _DummySender() + cognitive_service = AsyncMock() + cognitive_service.get_profile = AsyncMock(return_value="群聊简称侧写") + + context = _build_context( + sender=sender, + cognitive_service=cognitive_service, + scope="group", + group_id=654321, + sender_id=44444, + ) + + await profile_execute(["g"], context) + + assert len(sender.group_messages) == 1 + assert "群聊简称侧写" in sender.group_messages[0][1] + cognitive_service.get_profile.assert_called_once_with("group", "654321") + + +@pytest.mark.asyncio +async def test_profile_private_g_shorthand_rejected() -> None: + """Private chat, `/p g` shorthand also rejected.""" + sender = _DummySender() + cognitive_service = AsyncMock() + + context = _build_context( + sender=sender, + cognitive_service=cognitive_service, + scope="private", + group_id=0, + sender_id=77777, + user_id=77777, + ) + + await profile_execute(["g"], context) + + assert len(sender.private_messages) == 1 + assert "❌ 私聊中不支持查看群聊侧写" in sender.private_messages[0][1] + cognitive_service.get_profile.assert_not_called() + + +@pytest.mark.asyncio +async def test_profile_group_profile_not_found() -> None: + """Group chat, group profile not found → sends '暂无群聊侧写数据'.""" + sender = _DummySender() + cognitive_service = AsyncMock() + cognitive_service.get_profile = AsyncMock(return_value=None) + + context = _build_context( + sender=sender, + cognitive_service=cognitive_service, + scope="group", + group_id=111111, + sender_id=33333, + ) + + await profile_execute(["group"], context) + + assert len(sender.group_messages) == 1 + assert "📭 暂无群聊侧写数据" in sender.group_messages[0][1] + + +# -- Edge cases -- + + +@pytest.mark.asyncio +async def test_profile_no_cognitive_service() -> None: + """No cognitive_service → sends '侧写服务未启用'.""" + sender = _DummySender() + + context = _build_context( + sender=sender, + cognitive_service=None, + scope="group", + group_id=123456, + sender_id=22222, + ) + + await profile_execute([], context) + + assert len(sender.group_messages) == 1 + assert "❌ 侧写服务未启用" in sender.group_messages[0][1] + + +@pytest.mark.asyncio +async def test_profile_truncation() -> None: + """Profile > 5000 chars gets truncated.""" + sender = _DummySender() + cognitive_service = AsyncMock() + long_profile = "A" * 5500 # Longer than 5000 chars + cognitive_service.get_profile = AsyncMock(return_value=long_profile) + + context = _build_context( + sender=sender, + cognitive_service=cognitive_service, + scope="group", + group_id=123456, + sender_id=11111, + ) + + await profile_execute([], context) + + assert len(sender.group_messages) == 1 + message = sender.group_messages[0][1] + assert len(message) <= 5100 # 5000 + truncation notice + assert "[侧写过长,已截断]" in message + assert message.count("A") == 5000 # Exactly 5000 'A's before truncation + + +# -- Superadmin target tests -- + + +@pytest.mark.asyncio +async def test_profile_superadmin_target_user() -> None: + """Superadmin can query another user's profile with /p .""" + sender = _DummySender() + cognitive_service = AsyncMock() + cognitive_service.get_profile = AsyncMock(return_value="目标用户侧写") + + context = _build_context( + sender=sender, + cognitive_service=cognitive_service, + scope="group", + group_id=123456, + sender_id=10001, + superadmin_qq=10001, + ) + + await profile_execute(["99999"], context) + + assert len(sender.group_messages) == 1 + assert "目标用户侧写" in sender.group_messages[0][1] + cognitive_service.get_profile.assert_called_once_with("user", "99999") + + +@pytest.mark.asyncio +async def test_profile_superadmin_target_group() -> None: + """Superadmin can query a group profile with /p g <群号>.""" + sender = _DummySender() + cognitive_service = AsyncMock() + cognitive_service.get_profile = AsyncMock(return_value="目标群侧写") + + context = _build_context( + sender=sender, + cognitive_service=cognitive_service, + scope="group", + group_id=123456, + sender_id=10001, + superadmin_qq=10001, + ) + + await profile_execute(["g", "789000"], context) + + assert len(sender.group_messages) == 1 + assert "目标群侧写" in sender.group_messages[0][1] + cognitive_service.get_profile.assert_called_once_with("group", "789000") + + +@pytest.mark.asyncio +async def test_profile_nonadmin_target_rejected() -> None: + """Non-superadmin cannot specify a target QQ → permission error.""" + sender = _DummySender() + cognitive_service = AsyncMock() + + context = _build_context( + sender=sender, + cognitive_service=cognitive_service, + scope="group", + group_id=123456, + sender_id=22222, + superadmin_qq=10001, + ) + + await profile_execute(["99999"], context) + + assert len(sender.group_messages) == 1 + assert "❌ 仅超级管理员可查看他人侧写" in sender.group_messages[0][1] + cognitive_service.get_profile.assert_not_called() + + +@pytest.mark.asyncio +async def test_profile_superadmin_target_with_mode() -> None: + """Superadmin with render mode + target: /p -r 12345.""" + sender = _DummySender() + cognitive_service = AsyncMock() + cognitive_service.get_profile = AsyncMock(return_value="带模式的侧写") + + context = _build_context( + sender=sender, + cognitive_service=cognitive_service, + scope="group", + group_id=123456, + sender_id=10001, + superadmin_qq=10001, + ) + + # render mode will fail (no Playwright) → fallback to text + await profile_execute(["-t", "12345"], context) + + assert len(sender.group_messages) == 1 + assert "带模式的侧写" in sender.group_messages[0][1] + cognitive_service.get_profile.assert_called_once_with("user", "12345") + + +@pytest.mark.asyncio +async def test_profile_superadmin_private_group_with_target() -> None: + """Superadmin in private chat can query a group with /p g <群号>.""" + sender = _DummySender() + cognitive_service = AsyncMock() + cognitive_service.get_profile = AsyncMock(return_value="远程群侧写") + + context = _build_context( + sender=sender, + cognitive_service=cognitive_service, + scope="private", + group_id=0, + sender_id=10001, + user_id=10001, + superadmin_qq=10001, + ) + + await profile_execute(["g", "654321"], context) + + # Private + group + target → still works for superadmin + assert len(sender.private_messages) == 1 + assert "远程群侧写" in sender.private_messages[0][1] + cognitive_service.get_profile.assert_called_once_with("group", "654321") + + +@pytest.mark.asyncio +async def test_profile_superadmin_target_not_found() -> None: + """Superadmin queries non-existent target → '暂无侧写数据'.""" + sender = _DummySender() + cognitive_service = AsyncMock() + cognitive_service.get_profile = AsyncMock(return_value=None) + + context = _build_context( + sender=sender, + cognitive_service=cognitive_service, + scope="group", + group_id=123456, + sender_id=10001, + superadmin_qq=10001, + ) + + await profile_execute(["11111"], context) + + assert len(sender.group_messages) == 1 + assert "📭 暂无侧写数据" in sender.group_messages[0][1] diff --git a/tests/test_prompt_builder_easter_egg.py b/tests/test_prompt_builder_easter_egg.py new file mode 100644 index 00000000..332269c0 --- /dev/null +++ b/tests/test_prompt_builder_easter_egg.py @@ -0,0 +1,188 @@ +"""PromptBuilder 彩蛋功能注入测试""" + +from __future__ import annotations + +from types import SimpleNamespace +from typing import Any, cast + +import pytest + +from Undefined.ai.prompts import PromptBuilder +from Undefined.end_summary_storage import EndSummaryRecord +from Undefined.memory import Memory + + +class _FakeEndSummaryStorage: + async def load(self) -> list[EndSummaryRecord]: + return [] + + +class _FakeCognitiveService: + enabled = False + + async def build_context(self, **kwargs: Any) -> str: + return "" + + +class _FakeMemoryStorage: + def get_all(self) -> list[Memory]: + return [] + + +def _make_builder( + *, + keyword_reply_enabled: bool = False, + repeat_enabled: bool = False, + inverted_question_enabled: bool = False, + easter_egg_agent_call_message_mode: str = "none", +) -> PromptBuilder: + runtime_config = SimpleNamespace( + keyword_reply_enabled=keyword_reply_enabled, + repeat_enabled=repeat_enabled, + inverted_question_enabled=inverted_question_enabled, + easter_egg_agent_call_message_mode=easter_egg_agent_call_message_mode, + knowledge_enabled=False, + grok_search_enabled=False, + chat_model=SimpleNamespace( + model_name="gpt-test", + pool=SimpleNamespace(enabled=False), + thinking_enabled=False, + reasoning_enabled=False, + ), + vision_model=None, + agent_model=None, + embedding_model=None, + security_model=None, + grok_model=None, + cognitive=None, + memes=None, + ) + return PromptBuilder( + bot_qq=123456, + memory_storage=cast(Any, _FakeMemoryStorage()), + end_summary_storage=cast(Any, _FakeEndSummaryStorage()), + runtime_config_getter=lambda: runtime_config, + anthropic_skill_registry=cast(Any, None), + cognitive_service=cast(Any, _FakeCognitiveService()), + ) + + +async def _build_messages( + builder: PromptBuilder, + *, + group_id: int | None = None, +) -> list[dict[str, Any]]: + async def _fake_load_system_prompt() -> str: + return "系统提示词" + + async def _fake_load_each_rules() -> str: + return "" + + async def _fake_recent_messages( + chat_id: str, msg_type: str, start: int, end: int + ) -> list[dict[str, Any]]: + return [] + + # Patch internal loaders + builder._load_system_prompt = _fake_load_system_prompt # type: ignore[method-assign,unused-ignore] + builder._load_each_rules = _fake_load_each_rules # type: ignore[method-assign,unused-ignore] + + extra_context: dict[str, Any] = {} + if group_id is not None: + extra_context["group_id"] = group_id + + result = await builder.build_messages( + '\n你好\n', + get_recent_messages_callback=_fake_recent_messages, + extra_context=extra_context if extra_context else None, + ) + return list(result) + + +# ── _build_model_config_info 彩蛋状态 ── + + +def _get_config_info(builder: PromptBuilder) -> str: + getter = builder._runtime_config_getter + assert getter is not None + info = builder._build_model_config_info(getter()) + return str(info) + + +def test_model_config_info_shows_easter_egg_disabled() -> None: + builder = _make_builder() + info = _get_config_info(builder) + assert "彩蛋功能: 未启用" in info + + +def test_model_config_info_shows_keyword_reply_enabled() -> None: + builder = _make_builder(keyword_reply_enabled=True) + info = _get_config_info(builder) + assert "关键词自动回复" in info + assert "彩蛋功能: " in info + + +def test_model_config_info_shows_repeat_enabled() -> None: + builder = _make_builder(repeat_enabled=True) + info = _get_config_info(builder) + assert "复读" in info + assert "连续3条相同消息" in info + + +def test_model_config_info_shows_repeat_with_inverted_question() -> None: + builder = _make_builder(repeat_enabled=True, inverted_question_enabled=True) + info = _get_config_info(builder) + assert "倒问号" in info + assert "¿" in info + + +def test_model_config_info_shows_inverted_question_without_repeat() -> None: + builder = _make_builder(inverted_question_enabled=True) + info = _get_config_info(builder) + assert "倒问号" in info + assert "复读未启用" in info + + +def test_model_config_info_shows_agent_call_mode() -> None: + builder = _make_builder(easter_egg_agent_call_message_mode="clean") + info = _get_config_info(builder) + assert "降噪调用提示" in info + + +# ── 群聊上下文系统行为注入 ── + + +@pytest.mark.asyncio +async def test_repeat_injection_in_group_context() -> None: + builder = _make_builder(repeat_enabled=True) + messages = await _build_messages(builder, group_id=30001) + system_contents = [m["content"] for m in messages if m["role"] == "system"] + repeat_injected = any("[系统复读]" in c for c in system_contents) + assert repeat_injected, "复读彩蛋说明应注入群聊上下文" + + +@pytest.mark.asyncio +async def test_repeat_injection_not_in_private_context() -> None: + builder = _make_builder(repeat_enabled=True) + messages = await _build_messages(builder, group_id=None) + system_contents = [m["content"] for m in messages if m["role"] == "system"] + repeat_injected = any("[系统复读]" in c for c in system_contents) + assert not repeat_injected, "复读彩蛋说明不应注入非群聊上下文" + + +@pytest.mark.asyncio +async def test_inverted_question_mentioned_in_repeat_injection() -> None: + builder = _make_builder(repeat_enabled=True, inverted_question_enabled=True) + messages = await _build_messages(builder, group_id=30001) + system_contents = [m["content"] for m in messages if m["role"] == "system"] + inverted_injected = any("倒问号" in c for c in system_contents) + assert inverted_injected, "倒问号说明应在复读注入中出现" + + +@pytest.mark.asyncio +async def test_keyword_reply_injection_still_works() -> None: + builder = _make_builder(keyword_reply_enabled=True) + messages = await _build_messages(builder, group_id=30001) + system_contents = [m["content"] for m in messages if m["role"] == "system"] + keyword_injected = any("[系统关键词自动回复]" in c for c in system_contents) + assert keyword_injected, "关键词自动回复说明仍应注入" diff --git a/tests/test_prompts_level.py b/tests/test_prompts_level.py new file mode 100644 index 00000000..239786af --- /dev/null +++ b/tests/test_prompts_level.py @@ -0,0 +1,285 @@ +from __future__ import annotations + +from types import SimpleNamespace +from typing import Any, cast + +import pytest + +from Undefined.ai.prompts import PromptBuilder +from Undefined.end_summary_storage import EndSummaryRecord + + +class _FakeEndSummaryStorage: + async def load(self) -> list[EndSummaryRecord]: + return [] + + +class _FakeCognitiveService: + enabled = False + + +class _FakeAnthropicSkillRegistry: + def has_skills(self) -> bool: + return False + + +def _make_builder() -> PromptBuilder: + """创建用于测试的 PromptBuilder 实例""" + runtime_config = SimpleNamespace( + keyword_reply_enabled=False, + knowledge_enabled=False, + grok_search_enabled=False, + chat_model=SimpleNamespace( + model_name="gpt-4.1", + pool=SimpleNamespace(enabled=False), + thinking_enabled=False, + reasoning_enabled=False, + ), + vision_model=SimpleNamespace(model_name="gpt-4.1-mini"), + agent_model=SimpleNamespace(model_name="gpt-4.1-mini"), + embedding_model=SimpleNamespace(model_name="text-embedding-3-small"), + security_model=SimpleNamespace(model_name="gpt-4.1-mini"), + grok_model=SimpleNamespace(model_name="grok-4-search"), + cognitive=SimpleNamespace(enabled=False, recent_end_summaries_inject_k=0), + memes=SimpleNamespace( + enabled=False, + query_default_mode="hybrid", + allow_gif=False, + max_source_image_bytes=512000, + ), + ) + return PromptBuilder( + bot_qq=123456, + memory_storage=None, + end_summary_storage=cast(Any, _FakeEndSummaryStorage()), + runtime_config_getter=lambda: runtime_config, + anthropic_skill_registry=cast(Any, _FakeAnthropicSkillRegistry()), + cognitive_service=cast(Any, _FakeCognitiveService()), + ) + + +@pytest.mark.asyncio +async def test_group_message_with_level_includes_level_attribute( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """测试群消息有 level 时 XML 包含 level 属性""" + builder = _make_builder() + + async def _fake_load_system_prompt() -> str: + return "系统提示词" + + async def _fake_load_each_rules() -> str: + return "规则" + + monkeypatch.setattr(builder, "_load_system_prompt", _fake_load_system_prompt) + monkeypatch.setattr(builder, "_load_each_rules", _fake_load_each_rules) + + async def _fake_recent_messages( + chat_id: str, msg_type: str, start: int, end: int + ) -> list[dict[str, Any]]: + _ = chat_id, msg_type, start, end + return [ + { + "type": "group", + "display_name": "测试用户", + "user_id": "10001", + "chat_id": "20001", + "chat_name": "测试群", + "timestamp": "2026-04-11 10:00:00", + "message": "测试消息", + "attachments": [], + "role": "member", + "title": "", + "level": "Lv.5", + } + ] + + messages = await builder.build_messages( + '测试', + get_recent_messages_callback=_fake_recent_messages, + extra_context={ + "group_id": 20001, + "sender_id": 10001, + "sender_name": "测试用户", + "group_name": "测试群", + "request_type": "group", + }, + ) + + history_message = next( + str(msg.get("content", "")) + for msg in messages + if "【历史消息存档】" in str(msg.get("content", "")) + ) + + assert 'level="Lv.5"' in history_message + assert " None: + """测试群消息 level 为空字符串时 XML 不包含 level 属性""" + builder = _make_builder() + + async def _fake_load_system_prompt() -> str: + return "系统提示词" + + async def _fake_load_each_rules() -> str: + return "规则" + + monkeypatch.setattr(builder, "_load_system_prompt", _fake_load_system_prompt) + monkeypatch.setattr(builder, "_load_each_rules", _fake_load_each_rules) + + async def _fake_recent_messages( + chat_id: str, msg_type: str, start: int, end: int + ) -> list[dict[str, Any]]: + _ = chat_id, msg_type, start, end + return [ + { + "type": "group", + "display_name": "测试用户", + "user_id": "10001", + "chat_id": "20001", + "chat_name": "测试群", + "timestamp": "2026-04-11 10:00:00", + "message": "测试消息", + "attachments": [], + "role": "member", + "title": "", + "level": "", + } + ] + + messages = await builder.build_messages( + '测试', + get_recent_messages_callback=_fake_recent_messages, + extra_context={ + "group_id": 20001, + "sender_id": 10001, + "sender_name": "测试用户", + "group_name": "测试群", + "request_type": "group", + }, + ) + + history_message = next( + str(msg.get("content", "")) + for msg in messages + if "【历史消息存档】" in str(msg.get("content", "")) + ) + + assert "level=" not in history_message + assert " None: + """测试群消息没有 level 键时 XML 不包含 level 属性""" + builder = _make_builder() + + async def _fake_load_system_prompt() -> str: + return "系统提示词" + + async def _fake_load_each_rules() -> str: + return "规则" + + monkeypatch.setattr(builder, "_load_system_prompt", _fake_load_system_prompt) + monkeypatch.setattr(builder, "_load_each_rules", _fake_load_each_rules) + + async def _fake_recent_messages( + chat_id: str, msg_type: str, start: int, end: int + ) -> list[dict[str, Any]]: + _ = chat_id, msg_type, start, end + return [ + { + "type": "group", + "display_name": "测试用户", + "user_id": "10001", + "chat_id": "20001", + "chat_name": "测试群", + "timestamp": "2026-04-11 10:00:00", + "message": "测试消息", + "attachments": [], + "role": "member", + "title": "", + } + ] + + messages = await builder.build_messages( + '测试', + get_recent_messages_callback=_fake_recent_messages, + extra_context={ + "group_id": 20001, + "sender_id": 10001, + "sender_name": "测试用户", + "group_name": "测试群", + "request_type": "group", + }, + ) + + history_message = next( + str(msg.get("content", "")) + for msg in messages + if "【历史消息存档】" in str(msg.get("content", "")) + ) + + assert "level=" not in history_message + assert " None: + """测试私聊消息无论是否有 level 都不会出现 level 属性""" + builder = _make_builder() + + async def _fake_load_system_prompt() -> str: + return "系统提示词" + + async def _fake_load_each_rules() -> str: + return "规则" + + monkeypatch.setattr(builder, "_load_system_prompt", _fake_load_system_prompt) + monkeypatch.setattr(builder, "_load_each_rules", _fake_load_each_rules) + + async def _fake_recent_messages( + chat_id: str, msg_type: str, start: int, end: int + ) -> list[dict[str, Any]]: + _ = chat_id, msg_type, start, end + return [ + { + "type": "private", + "display_name": "测试用户", + "user_id": "10001", + "chat_id": "10001", + "chat_name": "QQ用户10001", + "timestamp": "2026-04-11 10:00:00", + "message": "私聊测试消息", + "attachments": [], + "level": "Lv.10", + } + ] + + messages = await builder.build_messages( + '测试', + get_recent_messages_callback=_fake_recent_messages, + extra_context={ + "sender_id": 10001, + "sender_name": "测试用户", + "request_type": "private", + }, + ) + + history_message = next( + str(msg.get("content", "")) + for msg in messages + if "【历史消息存档】" in str(msg.get("content", "")) + ) + + assert "level=" not in history_message + assert " None: + assert resolve_emoji_id_by_alias("微笑") == 14 + + def test_known_english_alias(self) -> None: + assert resolve_emoji_id_by_alias("smile") == 14 + + def test_known_unicode_emoji(self) -> None: + assert resolve_emoji_id_by_alias("👍") == 76 + + def test_case_insensitive(self) -> None: + assert resolve_emoji_id_by_alias("SMILE") == 14 + assert resolve_emoji_id_by_alias("Smile") == 14 + + def test_whitespace_stripped(self) -> None: + assert resolve_emoji_id_by_alias(" smile ") == 14 + + def test_unknown_alias(self) -> None: + assert resolve_emoji_id_by_alias("completely_unknown_emoji_xyz") is None + + def test_empty_string(self) -> None: + assert resolve_emoji_id_by_alias("") is None + + def test_whitespace_only(self) -> None: + assert resolve_emoji_id_by_alias(" ") is None + + +# --------------------------------------------------------------------------- +# search_emoji_aliases +# --------------------------------------------------------------------------- + + +class TestSearchEmojiAliases: + def test_search_finds_matching(self) -> None: + results = search_emoji_aliases("笑") + assert len(results) > 0 + for alias, _eid in results: + assert "笑" in alias + + def test_search_limit(self) -> None: + results = search_emoji_aliases("笑", limit=2) + assert len(results) <= 2 + + def test_search_no_match(self) -> None: + results = search_emoji_aliases("zzz_no_match_xyz") + assert results == [] + + def test_search_empty_keyword(self) -> None: + results = search_emoji_aliases("") + assert results == [] + + def test_search_returns_tuples(self) -> None: + results = search_emoji_aliases("赞") + assert len(results) > 0 + for item in results: + assert isinstance(item, tuple) + assert isinstance(item[0], str) + assert isinstance(item[1], int) + + def test_search_case_insensitive(self) -> None: + r1 = search_emoji_aliases("ok") + r2 = search_emoji_aliases("OK") + assert r1 == r2 + + def test_search_sorted_by_id_then_alias(self) -> None: + results = search_emoji_aliases("笑") + if len(results) >= 2: + for i in range(len(results) - 1): + assert (results[i][1], results[i][0]) <= ( + results[i + 1][1], + results[i + 1][0], + ) + + +# --------------------------------------------------------------------------- +# get_emoji_id_entries +# --------------------------------------------------------------------------- + + +class TestGetEmojiIdEntries: + def test_returns_list(self) -> None: + entries = get_emoji_id_entries() + assert isinstance(entries, list) + assert len(entries) > 0 + + def test_entries_structure(self) -> None: + entries = get_emoji_id_entries() + for emoji_id, aliases in entries: + assert isinstance(emoji_id, int) + assert isinstance(aliases, list) + assert all(isinstance(a, str) for a in aliases) + + def test_entries_sorted_by_id(self) -> None: + entries = get_emoji_id_entries() + ids = [eid for eid, _ in entries] + assert ids == sorted(ids) + + def test_aliases_sorted(self) -> None: + entries = get_emoji_id_entries() + for _, aliases in entries: + assert aliases == sorted(aliases) + + def test_known_emoji_in_entries(self) -> None: + entries = get_emoji_id_entries() + id_map = {eid: aliases for eid, aliases in entries} + assert 76 in id_map + assert "赞" in id_map[76] + + +# --------------------------------------------------------------------------- +# get_emoji_alias_map +# --------------------------------------------------------------------------- + + +class TestGetEmojiAliasMap: + def test_returns_dict(self) -> None: + m = get_emoji_alias_map() + assert isinstance(m, dict) + assert len(m) > 0 + + def test_contains_known_entries(self) -> None: + m = get_emoji_alias_map() + assert m.get("微笑") == 14 + assert m.get("👍") == 76 diff --git a/tests/test_queue_intervals.py b/tests/test_queue_intervals.py index 014ff48e..211440f7 100644 --- a/tests/test_queue_intervals.py +++ b/tests/test_queue_intervals.py @@ -63,6 +63,10 @@ def test_zero_queue_intervals_are_preserved_for_immediate_dispatch( model_name = "historian-model" queue_interval_seconds = 0 +[models.summary] +model_name = "summary-model" +queue_interval_seconds = 0 + [models.grok] api_url = "https://grok.example/v1" api_key = "sk-grok" @@ -92,6 +96,7 @@ def test_zero_queue_intervals_are_preserved_for_immediate_dispatch( assert cfg.naga_model.queue_interval_seconds == 0.0 assert cfg.agent_model.queue_interval_seconds == 0.0 assert cfg.historian_model.queue_interval_seconds == 0.0 + assert cfg.summary_model.queue_interval_seconds == 0.0 assert cfg.grok_model.queue_interval_seconds == 0.0 assert cfg.embedding_model.queue_interval_seconds == 0.0 assert cfg.rerank_model.queue_interval_seconds == 0.0 @@ -100,6 +105,7 @@ def test_zero_queue_intervals_are_preserved_for_immediate_dispatch( assert queue_manager.get_interval("chat-model") == 0.0 assert queue_manager.get_interval("chat-pool-model") == 0.0 assert queue_manager.get_interval("agent-model") == 0.0 + assert queue_manager.get_interval("summary-model") == 0.0 assert queue_manager.get_interval("grok-model") == 0.0 assert queue_manager.get_interval("naga-model") == 0.0 @@ -141,6 +147,10 @@ def test_negative_queue_intervals_still_fall_back_to_defaults(tmp_path: Path) -> model_name = "historian-model" queue_interval_seconds = -1 +[models.summary] +model_name = "summary-model" +queue_interval_seconds = -1 + [models.grok] api_url = "https://grok.example/v1" api_key = "sk-grok" @@ -168,6 +178,7 @@ def test_negative_queue_intervals_still_fall_back_to_defaults(tmp_path: Path) -> assert cfg.vision_model.queue_interval_seconds == 1.0 assert cfg.agent_model.queue_interval_seconds == 0.5 assert cfg.historian_model.queue_interval_seconds == 0.5 + assert cfg.summary_model.queue_interval_seconds == 0.5 assert cfg.grok_model.queue_interval_seconds == 1.0 assert cfg.embedding_model.queue_interval_seconds == 0.0 assert cfg.rerank_model.queue_interval_seconds == 0.0 diff --git a/tests/test_rate_limit.py b/tests/test_rate_limit.py new file mode 100644 index 00000000..a38c0efa --- /dev/null +++ b/tests/test_rate_limit.py @@ -0,0 +1,267 @@ +"""RateLimiter 单元测试""" + +from __future__ import annotations + +import time +from dataclasses import dataclass +from typing import Any, cast + +from Undefined.rate_limit import RateLimiter + + +# --------------------------------------------------------------------------- +# Mock helpers +# --------------------------------------------------------------------------- + + +class _MockConfig: + """最小化的 Config mock。""" + + def __init__( + self, + superadmins: set[int] | None = None, + admins: set[int] | None = None, + ) -> None: + self._superadmins = superadmins or set() + self._admins = admins or set() + + def is_superadmin(self, user_id: int) -> bool: + return user_id in self._superadmins + + def is_admin(self, user_id: int) -> bool: + return user_id in self._admins + + +@dataclass +class _MockCommandRateLimit: + """模拟 CommandRateLimit。""" + + user: int = 10 + admin: int = 5 + superadmin: int = 0 + + +# --------------------------------------------------------------------------- +# 基本限流 (check / record) +# --------------------------------------------------------------------------- + + +class TestRateLimiterCheck: + def test_first_call_allowed(self) -> None: + limiter = RateLimiter(cast(Any, _MockConfig)()) + allowed, remaining = limiter.check(1001) + assert allowed is True + assert remaining == 0 + + def test_second_call_blocked(self) -> None: + limiter = RateLimiter(cast(Any, _MockConfig)()) + limiter.record(1001) + allowed, remaining = limiter.check(1001) + assert allowed is False + assert remaining > 0 + + def test_superadmin_always_allowed(self) -> None: + cfg = _MockConfig(superadmins={1001}) + limiter = RateLimiter(cast(Any, cfg)) + limiter.record(1001) + allowed, _ = limiter.check(1001) + assert allowed is True + + def test_admin_shorter_cooldown(self) -> None: + cfg = _MockConfig(admins={2001}) + limiter = RateLimiter(cast(Any, cfg)) + # 模拟 admin 在较短冷却期后可以调用 + limiter._last_calls[2001] = time.time() - RateLimiter.ADMIN_COOLDOWN - 1 + allowed, _ = limiter.check(2001) + assert allowed is True + + def test_normal_user_cooldown(self) -> None: + cfg = _MockConfig() + limiter = RateLimiter(cast(Any, cfg)) + limiter._last_calls[3001] = time.time() - RateLimiter.USER_COOLDOWN + 2 + allowed, remaining = limiter.check(3001) + assert allowed is False + assert remaining >= 1 + + def test_cooldown_expires(self) -> None: + cfg = _MockConfig() + limiter = RateLimiter(cast(Any, cfg)) + limiter._last_calls[3001] = time.time() - RateLimiter.USER_COOLDOWN - 1 + allowed, _ = limiter.check(3001) + assert allowed is True + + +class TestRateLimiterRecord: + def test_record_stores_time(self) -> None: + limiter = RateLimiter(cast(Any, _MockConfig)()) + limiter.record(1001) + assert 1001 in limiter._last_calls + + def test_record_superadmin_skipped(self) -> None: + cfg = _MockConfig(superadmins={1001}) + limiter = RateLimiter(cast(Any, cfg)) + limiter.record(1001) + assert 1001 not in limiter._last_calls + + +# --------------------------------------------------------------------------- +# /ask 限流 +# --------------------------------------------------------------------------- + + +class TestRateLimiterAsk: + def test_ask_first_call_allowed(self) -> None: + limiter = RateLimiter(cast(Any, _MockConfig)()) + allowed, _ = limiter.check_ask(1001) + assert allowed is True + + def test_ask_blocked_within_cooldown(self) -> None: + limiter = RateLimiter(cast(Any, _MockConfig)()) + limiter.record_ask(1001) + allowed, remaining = limiter.check_ask(1001) + assert allowed is False + assert remaining > 0 + + def test_ask_superadmin_bypass(self) -> None: + cfg = _MockConfig(superadmins={1001}) + limiter = RateLimiter(cast(Any, cfg)) + limiter.record_ask(1001) + allowed, _ = limiter.check_ask(1001) + assert allowed is True + + def test_ask_cooldown_expires(self) -> None: + cfg = _MockConfig() + limiter = RateLimiter(cast(Any, cfg)) + limiter._last_ask_calls[1001] = time.time() - RateLimiter.ASK_COOLDOWN - 1 + allowed, _ = limiter.check_ask(1001) + assert allowed is True + + +# --------------------------------------------------------------------------- +# /stats 限流 +# --------------------------------------------------------------------------- + + +class TestRateLimiterStats: + def test_stats_first_call_allowed(self) -> None: + limiter = RateLimiter(cast(Any, _MockConfig)()) + allowed, _ = limiter.check_stats(1001) + assert allowed is True + + def test_stats_blocked_for_normal_user(self) -> None: + limiter = RateLimiter(cast(Any, _MockConfig)()) + limiter.record_stats(1001) + allowed, remaining = limiter.check_stats(1001) + assert allowed is False + assert remaining > 0 + + def test_stats_admin_bypass(self) -> None: + cfg = _MockConfig(admins={2001}) + limiter = RateLimiter(cast(Any, cfg)) + limiter.record_stats(2001) + allowed, _ = limiter.check_stats(2001) + assert allowed is True + + def test_stats_superadmin_bypass(self) -> None: + cfg = _MockConfig(superadmins={1001}) + limiter = RateLimiter(cast(Any, cfg)) + limiter.record_stats(1001) + allowed, _ = limiter.check_stats(1001) + assert allowed is True + + def test_stats_record_skipped_for_admin(self) -> None: + cfg = _MockConfig(admins={2001}) + limiter = RateLimiter(cast(Any, cfg)) + limiter.record_stats(2001) + assert 2001 not in limiter._last_stats_calls + + +# --------------------------------------------------------------------------- +# 动态命令限流 (check_command / record_command) +# --------------------------------------------------------------------------- + + +class TestRateLimiterCommand: + def test_command_first_call_allowed(self) -> None: + limiter = RateLimiter(cast(Any, _MockConfig)()) + limits = _MockCommandRateLimit() + allowed, _ = limiter.check_command(1001, "test_cmd", cast(Any, limits)) + assert allowed is True + + def test_command_blocked_after_record(self) -> None: + limiter = RateLimiter(cast(Any, _MockConfig)()) + limits = _MockCommandRateLimit(user=10) + limiter.record_command(1001, "cmd", cast(Any, limits)) + allowed, remaining = limiter.check_command(1001, "cmd", cast(Any, limits)) + assert allowed is False + assert remaining > 0 + + def test_command_superadmin_zero_cooldown(self) -> None: + cfg = _MockConfig(superadmins={1001}) + limiter = RateLimiter(cast(Any, cfg)) + limits = _MockCommandRateLimit(superadmin=0) + limiter.record_command(1001, "cmd", cast(Any, limits)) + allowed, _ = limiter.check_command(1001, "cmd", cast(Any, limits)) + assert allowed is True + + def test_command_admin_shorter_cooldown(self) -> None: + cfg = _MockConfig(admins={2001}) + limiter = RateLimiter(cast(Any, cfg)) + limits = _MockCommandRateLimit(admin=5, user=60) + limiter._command_calls.setdefault("cmd", {})[2001] = time.time() - 6 + allowed, _ = limiter.check_command(2001, "cmd", cast(Any, limits)) + assert allowed is True + + def test_command_different_commands_independent(self) -> None: + limiter = RateLimiter(cast(Any, _MockConfig)()) + limits = _MockCommandRateLimit(user=60) + limiter.record_command(1001, "cmd_a", cast(Any, limits)) + allowed, _ = limiter.check_command(1001, "cmd_b", cast(Any, limits)) + assert allowed is True + + +# --------------------------------------------------------------------------- +# clear 方法 +# --------------------------------------------------------------------------- + + +class TestRateLimiterClear: + def test_clear_removes_user(self) -> None: + limiter = RateLimiter(cast(Any, _MockConfig)()) + limiter.record(1001) + limiter.clear(1001) + allowed, _ = limiter.check(1001) + assert allowed is True + + def test_clear_ask(self) -> None: + limiter = RateLimiter(cast(Any, _MockConfig)()) + limiter.record_ask(1001) + limiter.clear_ask(1001) + allowed, _ = limiter.check_ask(1001) + assert allowed is True + + def test_clear_stats(self) -> None: + limiter = RateLimiter(cast(Any, _MockConfig)()) + limiter.record_stats(1001) + limiter.clear_stats(1001) + allowed, _ = limiter.check_stats(1001) + assert allowed is True + + def test_clear_all(self) -> None: + limiter = RateLimiter(cast(Any, _MockConfig)()) + limiter.record(1001) + limiter.record_ask(1001) + limiter.record_stats(1001) + limits = _MockCommandRateLimit() + limiter.record_command(1001, "cmd", cast(Any, limits)) + limiter.clear_all() + assert limiter._last_calls == {} + assert limiter._last_ask_calls == {} + assert limiter._last_stats_calls == {} + assert limiter._command_calls == {} + + def test_clear_nonexistent_user_no_error(self) -> None: + limiter = RateLimiter(cast(Any, _MockConfig)()) + limiter.clear(9999) # 不应抛出异常 + limiter.clear_ask(9999) + limiter.clear_stats(9999) diff --git a/tests/test_render_latex_tool.py b/tests/test_render_latex_tool.py new file mode 100644 index 00000000..130991a3 --- /dev/null +++ b/tests/test_render_latex_tool.py @@ -0,0 +1,206 @@ +"""测试 LaTeX 渲染工具(MathJax + Playwright 实现)""" + +from __future__ import annotations + +import pytest +from typing import Any + +# 这个测试需要 Playwright 浏览器运行时,所以标记为可选 +pytest_plugins = ("pytest_asyncio",) + + +class MockAttachmentRegistry: + """模拟附件注册表""" + + def __init__(self) -> None: + self.registered_items: list[dict[str, Any]] = [] + + async def register_bytes( + self, + scope_key: str, + data: bytes, + kind: str, + display_name: str, + mime_type: str, + source_kind: str, + source_ref: str, + ) -> Any: + class MockRecord: + uid = "test-uid-12345" + + record = MockRecord() + self.registered_items.append( + { + "scope_key": scope_key, + "size": len(data), + "kind": kind, + "display_name": display_name, + "mime_type": mime_type, + "source_kind": source_kind, + "source_ref": source_ref, + "uid": record.uid, + } + ) + return record + + +@pytest.mark.asyncio +async def test_render_simple_equation() -> None: + """测试渲染简单方程(无分隔符,自动包装)""" + from Undefined.skills.toolsets.render.render_latex.handler import execute + + mock_registry = MockAttachmentRegistry() + context = { + "attachment_registry": mock_registry, + "request_type": "group", + "group_id": 123456, + } + + args = {"content": "E = mc^2", "output_format": "png"} + + result = await execute(args, context) + if "渲染失败" in result and "Executable doesn't exist" in result: + pytest.skip("Playwright 浏览器未安装,跳过测试") + assert result == '' + assert len(mock_registry.registered_items) == 1 + assert mock_registry.registered_items[0]["kind"] == "image" + assert mock_registry.registered_items[0]["mime_type"] == "image/png" + assert mock_registry.registered_items[0]["size"] > 0 + + +@pytest.mark.asyncio +async def test_render_with_delimiters() -> None: + """测试带分隔符的内容(不自动包装)""" + from Undefined.skills.toolsets.render.render_latex.handler import execute + + mock_registry = MockAttachmentRegistry() + context = { + "attachment_registry": mock_registry, + "request_type": "private", + "user_id": 987654, + } + + args = {"content": r"\[ \int_0^\infty e^{-x^2} dx = \frac{\sqrt{\pi}}{2} \]"} + + result = await execute(args, context) + if "渲染失败" in result and "Executable doesn't exist" in result: + pytest.skip("Playwright 浏览器未安装,跳过测试") + assert result == '' + assert len(mock_registry.registered_items) == 1 + + +@pytest.mark.asyncio +async def test_render_pdf_output() -> None: + """测试 PDF 输出格式使用 attachment 标签""" + from Undefined.skills.toolsets.render.render_latex.handler import execute + + mock_registry = MockAttachmentRegistry() + context = { + "attachment_registry": mock_registry, + "request_type": "group", + "group_id": 123456, + } + + args = {"content": r"\frac{a}{b} + \sqrt{c}", "output_format": "pdf"} + + result = await execute(args, context) + if "渲染失败" in result and "Executable doesn't exist" in result: + pytest.skip("Playwright 浏览器未安装,跳过测试") + assert result == '' + assert len(mock_registry.registered_items) == 1 + assert mock_registry.registered_items[0]["kind"] == "file" + assert mock_registry.registered_items[0]["mime_type"] == "application/pdf" + assert mock_registry.registered_items[0]["display_name"] == "latex.pdf" + + +@pytest.mark.asyncio +async def test_empty_content_error() -> None: + """测试空内容错误处理""" + from Undefined.skills.toolsets.render.render_latex.handler import execute + + context = {"attachment_registry": MockAttachmentRegistry()} + + args = {"content": " "} + + result = await execute(args, context) + assert "不能为空" in result + + +@pytest.mark.asyncio +async def test_invalid_output_format() -> None: + """测试无效输出格式""" + from Undefined.skills.toolsets.render.render_latex.handler import execute + + context = {"attachment_registry": MockAttachmentRegistry()} + + args = {"content": "x = 1", "output_format": "svg"} + + result = await execute(args, context) + assert "无效" in result or "仅支持" in result + + +def test_strip_document_wrappers() -> None: + """测试去除 document 包装""" + from Undefined.skills.toolsets.render.render_latex.handler import ( + _strip_document_wrappers, + ) + + content = r"\begin{document}E = mc^2\end{document}" + result = _strip_document_wrappers(content) + assert result == "E = mc^2" + + # 没有包装的内容应该原样返回 + content_no_wrapper = r"E = mc^2" + result_no_wrapper = _strip_document_wrappers(content_no_wrapper) + assert result_no_wrapper == "E = mc^2" + + +def test_has_math_delimiters() -> None: + """测试数学分隔符检测""" + from Undefined.skills.toolsets.render.render_latex.handler import ( + _has_math_delimiters, + ) + + assert _has_math_delimiters(r"\[ x = 1 \]") is True + assert _has_math_delimiters(r"\( x = 1 \)") is True + assert _has_math_delimiters(r"$$ x = 1 $$") is True + assert _has_math_delimiters(r"\begin{equation}") is True + assert _has_math_delimiters("x = 1") is False + + +def test_prepare_content() -> None: + """测试内容准备逻辑""" + from Undefined.skills.toolsets.render.render_latex.handler import _prepare_content + + # 无分隔符,自动包装 + result = _prepare_content("E = mc^2") + assert result.startswith(r"\[") + assert result.endswith(r"\]") + assert "E = mc^2" in result + + # 有分隔符,不包装 + result_with_delim = _prepare_content(r"\[ E = mc^2 \]") + assert result_with_delim == r"\[ E = mc^2 \]" + + # 字面量 \\n 处理(后面不跟字母时替换为换行) + result_newline = _prepare_content("x = 1\\n2 = y") + assert "\n" in result_newline + assert "x = 1" in result_newline + assert "2 = y" in result_newline + + # LaTeX 命令不被破坏:\nu \nabla \neq 保持不变 + result_latex = _prepare_content(r"\nu + \nabla \neq 0") + assert r"\nu" in result_latex + assert r"\nabla" in result_latex + assert r"\neq" in result_latex + + +def test_build_html_contains_mathjax_ready_flag() -> None: + """HTML 模板包含 MathJax pageReady 回调设置 _mjReady 标记""" + from Undefined.skills.toolsets.render.render_latex.handler import _build_html + + html = _build_html(r"\[ x = 1 \]") + assert "window._mjReady = true" in html + assert "pageReady" in html + assert "tex-svg.js" in html + assert "math-container" in html diff --git a/tests/test_request_params.py b/tests/test_request_params.py new file mode 100644 index 00000000..d5e09d3a --- /dev/null +++ b/tests/test_request_params.py @@ -0,0 +1,141 @@ +"""Tests for Undefined.utils.request_params — request param helpers.""" + +from __future__ import annotations + +from collections import OrderedDict +from typing import Any + +from Undefined.utils.request_params import ( + merge_request_params, + normalize_request_params, + split_reserved_request_params, +) + + +class TestNormalizeRequestParams: + def test_dict_passthrough(self) -> None: + result = normalize_request_params({"a": 1, "b": "two"}) + assert result == {"a": 1, "b": "two"} + + def test_none_returns_empty(self) -> None: + assert normalize_request_params(None) == {} + + def test_non_dict_returns_empty(self) -> None: + assert normalize_request_params("string") == {} + assert normalize_request_params(42) == {} + assert normalize_request_params([1, 2]) == {} + + def test_empty_dict(self) -> None: + assert normalize_request_params({}) == {} + + def test_nested_dict_cloned(self) -> None: + original: dict[str, Any] = {"inner": {"x": 1}} + result = normalize_request_params(original) + assert result == {"inner": {"x": 1}} + # Must be a deep copy + assert result["inner"] is not original["inner"] + + def test_list_cloned(self) -> None: + original: dict[str, Any] = {"items": [1, 2, {"a": 3}]} + result = normalize_request_params(original) + assert result["items"] == [1, 2, {"a": 3}] + assert result["items"] is not original["items"] + + def test_tuple_converted_to_list(self) -> None: + result = normalize_request_params({"t": (1, 2, 3)}) + assert result["t"] == [1, 2, 3] + assert isinstance(result["t"], list) + + def test_non_json_value_stringified(self) -> None: + result = normalize_request_params({"obj": object()}) + assert isinstance(result["obj"], str) + + def test_keys_stringified(self) -> None: + result = normalize_request_params({1: "a", 2: "b"}) + assert "1" in result + assert "2" in result + + def test_ordered_dict_accepted(self) -> None: + od = OrderedDict([("z", 1), ("a", 2)]) + result = normalize_request_params(od) + assert result == {"z": 1, "a": 2} + + def test_bool_preserved(self) -> None: + result = normalize_request_params({"flag": True}) + assert result["flag"] is True + + def test_none_value_preserved(self) -> None: + result = normalize_request_params({"key": None}) + assert result["key"] is None + + +class TestMergeRequestParams: + def test_single_dict(self) -> None: + result = merge_request_params({"a": 1}) + assert result == {"a": 1} + + def test_two_dicts_merged(self) -> None: + result = merge_request_params({"a": 1}, {"b": 2}) + assert result == {"a": 1, "b": 2} + + def test_later_overrides_earlier(self) -> None: + result = merge_request_params({"a": 1}, {"a": 2}) + assert result["a"] == 2 + + def test_none_skipped(self) -> None: + result = merge_request_params(None, {"a": 1}) + assert result == {"a": 1} + + def test_non_dict_skipped(self) -> None: + result = merge_request_params("bad", {"a": 1}, 42) + assert result == {"a": 1} + + def test_empty_args(self) -> None: + result = merge_request_params() + assert result == {} + + def test_multiple_merges(self) -> None: + result = merge_request_params({"a": 1}, {"b": 2}, {"c": 3}) + assert result == {"a": 1, "b": 2, "c": 3} + + +class TestSplitReservedRequestParams: + def test_basic_split(self) -> None: + allowed, reserved = split_reserved_request_params( + {"a": 1, "b": 2, "c": 3}, {"b", "c"} + ) + assert allowed == {"a": 1} + assert reserved == {"b": 2, "c": 3} + + def test_no_reserved_keys(self) -> None: + allowed, reserved = split_reserved_request_params({"a": 1}, set()) + assert allowed == {"a": 1} + assert reserved == {} + + def test_all_reserved(self) -> None: + allowed, reserved = split_reserved_request_params({"a": 1, "b": 2}, {"a", "b"}) + assert allowed == {} + assert reserved == {"a": 1, "b": 2} + + def test_none_params(self) -> None: + allowed, reserved = split_reserved_request_params(None, {"a"}) + assert allowed == {} + assert reserved == {} + + def test_empty_params(self) -> None: + allowed, reserved = split_reserved_request_params({}, {"a"}) + assert allowed == {} + assert reserved == {} + + def test_frozenset_keys(self) -> None: + allowed, reserved = split_reserved_request_params( + {"x": 1, "y": 2}, frozenset({"x"}) + ) + assert allowed == {"y": 2} + assert reserved == {"x": 1} + + def test_nested_values_cloned(self) -> None: + original: dict[str, Any] = {"deep": {"nested": True}, "keep": "val"} + allowed, reserved = split_reserved_request_params(original, {"deep"}) + assert reserved["deep"] == {"nested": True} + assert reserved["deep"] is not original["deep"] diff --git a/tests/test_runtime_api_chat_stream.py b/tests/test_runtime_api_chat_stream.py index c1f6e995..09738cbf 100644 --- a/tests/test_runtime_api_chat_stream.py +++ b/tests/test_runtime_api_chat_stream.py @@ -8,7 +8,7 @@ from aiohttp import web from Undefined.api import RuntimeAPIContext, RuntimeAPIServer -from Undefined.api import app as runtime_api_app +from Undefined.api.routes import chat as runtime_api_chat class _DummyTransport: @@ -90,18 +90,18 @@ async def _fake_render_message_with_pic_placeholders( ) server = RuntimeAPIServer(context, host="127.0.0.1", port=8788) - async def _fake_run_webui_chat(*, text: str, send_output: Any) -> str: + async def _fake_run_webui_chat(_ctx: Any, *, text: str, send_output: Any) -> str: assert text == "hello" await send_output(42, "bot reply with ") return "chat" monkeypatch.setattr( - runtime_api_app, + runtime_api_chat, "render_message_with_pic_placeholders", _fake_render_message_with_pic_placeholders, ) monkeypatch.setattr(web, "StreamResponse", _DummyStreamResponse) - monkeypatch.setattr(server, "_run_webui_chat", _fake_run_webui_chat) + monkeypatch.setattr(runtime_api_chat, "run_webui_chat", _fake_run_webui_chat) request = cast( web.Request, @@ -173,11 +173,11 @@ async def _fake_ask(full_question: str, **kwargs: Any) -> str: server = RuntimeAPIServer(context, host="127.0.0.1", port=8788) monkeypatch.setattr( - runtime_api_app, + runtime_api_chat, "register_message_attachments", _fake_register_message_attachments, ) - monkeypatch.setattr(runtime_api_app, "collect_context_resources", lambda _vars: {}) + monkeypatch.setattr(runtime_api_chat, "collect_context_resources", lambda _vars: {}) sent_messages: list[tuple[int, str]] = [] diff --git a/tests/test_runtime_api_naga.py b/tests/test_runtime_api_naga.py index 55a66b46..739919d8 100644 --- a/tests/test_runtime_api_naga.py +++ b/tests/test_runtime_api_naga.py @@ -608,9 +608,11 @@ async def _render_html_to_image(_: str, __: str) -> None: raise RuntimeError("render failed") monkeypatch.setattr( - "Undefined.api.app.render_markdown_to_html", _render_markdown_to_html + "Undefined.api.routes.naga.render_markdown_to_html", _render_markdown_to_html + ) + monkeypatch.setattr( + "Undefined.api.routes.naga.render_html_to_image", _render_html_to_image ) - monkeypatch.setattr("Undefined.api.app.render_html_to_image", _render_html_to_image) response = await server._naga_messages_send_handler( _make_request( @@ -665,7 +667,7 @@ async def _render_markdown_to_html(_: str) -> str: raise RuntimeError("markdown failed") monkeypatch.setattr( - "Undefined.api.app.render_markdown_to_html", _render_markdown_to_html + "Undefined.api.routes.naga.render_markdown_to_html", _render_markdown_to_html ) response = await server._naga_messages_send_handler( diff --git a/tests/test_runtime_api_probes.py b/tests/test_runtime_api_probes.py index 3d6e0dd9..940f32f9 100644 --- a/tests/test_runtime_api_probes.py +++ b/tests/test_runtime_api_probes.py @@ -8,7 +8,7 @@ from aiohttp import web from Undefined.api import RuntimeAPIContext, RuntimeAPIServer -from Undefined.api import app as runtime_api_app +from Undefined.api.routes import system as runtime_api_system @pytest.mark.asyncio @@ -166,9 +166,11 @@ async def _fake_probe_ws_endpoint(_: str) -> dict[str, Any]: } monkeypatch.setattr( - runtime_api_app, "_probe_http_endpoint", _fake_probe_http_endpoint + runtime_api_system, "_probe_http_endpoint", _fake_probe_http_endpoint + ) + monkeypatch.setattr( + runtime_api_system, "_probe_ws_endpoint", _fake_probe_ws_endpoint ) - monkeypatch.setattr(runtime_api_app, "_probe_ws_endpoint", _fake_probe_ws_endpoint) context = RuntimeAPIContext( config_getter=lambda: SimpleNamespace( diff --git a/tests/test_runtime_api_sender_proxy.py b/tests/test_runtime_api_sender_proxy.py index be0c82c0..77fc8058 100644 --- a/tests/test_runtime_api_sender_proxy.py +++ b/tests/test_runtime_api_sender_proxy.py @@ -2,7 +2,7 @@ import pytest -from Undefined.api.app import _WebUIVirtualSender +from Undefined.api._helpers import _WebUIVirtualSender @pytest.mark.asyncio diff --git a/tests/test_runtime_api_tool_invoke.py b/tests/test_runtime_api_tool_invoke.py index 9db13303..e3c66159 100644 --- a/tests/test_runtime_api_tool_invoke.py +++ b/tests/test_runtime_api_tool_invoke.py @@ -10,7 +10,7 @@ from aiohttp.web_response import Response from Undefined.api import RuntimeAPIContext, RuntimeAPIServer -from Undefined.api.app import _validate_callback_url +from Undefined.api._helpers import _validate_callback_url def _json(response: Response) -> Any: @@ -365,7 +365,7 @@ async def _wait_for(awaitable: Any, timeout: float) -> Any: seen["timeout"] = timeout return await original_wait_for(awaitable, timeout) - monkeypatch.setattr("Undefined.api.app.asyncio.wait_for", _wait_for) + monkeypatch.setattr("Undefined.api.routes.tools.asyncio.wait_for", _wait_for) payload = await server._execute_tool_invoke( request_id="req-tool", @@ -391,7 +391,7 @@ async def _wait_for(awaitable: Any, timeout: float) -> Any: seen["timeout"] = timeout return await original_wait_for(awaitable, timeout) - monkeypatch.setattr("Undefined.api.app.asyncio.wait_for", _wait_for) + monkeypatch.setattr("Undefined.api.routes.tools.asyncio.wait_for", _wait_for) payload = await server._execute_tool_invoke( request_id="req-agent", diff --git a/tests/test_scheduled_task_unit.py b/tests/test_scheduled_task_unit.py new file mode 100644 index 00000000..93866a4d --- /dev/null +++ b/tests/test_scheduled_task_unit.py @@ -0,0 +1,192 @@ +"""ScheduledTask / ToolCall 序列化 单元测试""" + +from __future__ import annotations + +from typing import Any + + +from Undefined.scheduled_task_storage import ScheduledTask, ToolCall + + +# --------------------------------------------------------------------------- +# ToolCall +# --------------------------------------------------------------------------- + + +class TestToolCall: + def test_fields(self) -> None: + tc = ToolCall(tool_name="search", tool_args={"q": "test"}) + assert tc.tool_name == "search" + assert tc.tool_args == {"q": "test"} + + +# --------------------------------------------------------------------------- +# ScheduledTask — to_dict / from_dict 往返 +# --------------------------------------------------------------------------- + + +def _sample_task_dict() -> dict[str, Any]: + return { + "task_id": "task-001", + "tool_name": "search", + "tool_args": {"q": "test"}, + "cron": "0 9 * * *", + "target_id": 12345, + "target_type": "group", + "task_name": "每日搜索", + "max_executions": 10, + "current_executions": 3, + "created_at": "2025-01-01T00:00:00", + "context_id": "ctx-1", + "tools": [ + {"tool_name": "search", "tool_args": {"q": "test"}}, + {"tool_name": "notify", "tool_args": {"msg": "done"}}, + ], + "execution_mode": "parallel", + } + + +class TestScheduledTaskRoundtrip: + def test_basic_roundtrip(self) -> None: + d = _sample_task_dict() + task = ScheduledTask.from_dict(d) + restored = task.to_dict() + assert restored["task_id"] == "task-001" + assert restored["cron"] == "0 9 * * *" + assert restored["execution_mode"] == "parallel" + assert len(restored["tools"]) == 2 + + def test_tools_are_toolcall_instances(self) -> None: + d = _sample_task_dict() + task = ScheduledTask.from_dict(d) + assert task.tools is not None + for tc in task.tools: + assert isinstance(tc, ToolCall) + + def test_to_dict_tools_are_dicts(self) -> None: + d = _sample_task_dict() + task = ScheduledTask.from_dict(d) + restored = task.to_dict() + for tool in restored["tools"]: + assert isinstance(tool, dict) + assert "tool_name" in tool + + +# --------------------------------------------------------------------------- +# 向后兼容 — 旧格式无 tools +# --------------------------------------------------------------------------- + + +class TestScheduledTaskBackwardCompat: + def test_legacy_without_tools_field(self) -> None: + """旧格式只有 tool_name/tool_args,没有 tools 字段。""" + d: dict[str, Any] = { + "task_id": "legacy-1", + "tool_name": "old_tool", + "tool_args": {"key": "val"}, + "cron": "*/5 * * * *", + "target_id": None, + "target_type": "private", + "task_name": "旧任务", + "max_executions": None, + } + task = ScheduledTask.from_dict(d) + assert task.tools is not None + assert len(task.tools) == 1 + assert task.tools[0].tool_name == "old_tool" + assert task.tools[0].tool_args == {"key": "val"} + + def test_legacy_empty_tools_uses_tool_name(self) -> None: + """tools 为空列表时,回退到 tool_name。""" + d: dict[str, Any] = { + "task_id": "legacy-2", + "tool_name": "fallback", + "tool_args": {}, + "tools": [], + "cron": "0 0 * * *", + "target_id": 1, + "target_type": "group", + "task_name": "fallback task", + "max_executions": None, + } + task = ScheduledTask.from_dict(d) + assert task.tools is not None + assert len(task.tools) == 1 + assert task.tools[0].tool_name == "fallback" + + +# --------------------------------------------------------------------------- +# 可选字段缺失 +# --------------------------------------------------------------------------- + + +class TestScheduledTaskOptionalFields: + def test_missing_context_id(self) -> None: + d: dict[str, Any] = { + "task_id": "t1", + "tool_name": "x", + "tool_args": {}, + "cron": "0 0 * * *", + "target_id": None, + "target_type": "group", + "task_name": "n", + "max_executions": None, + } + task = ScheduledTask.from_dict(d) + assert task.context_id is None + + def test_missing_current_executions(self) -> None: + d: dict[str, Any] = { + "task_id": "t2", + "tool_name": "x", + "tool_args": {}, + "cron": "0 0 * * *", + "target_id": 1, + "target_type": "private", + "task_name": "n", + "max_executions": 5, + } + task = ScheduledTask.from_dict(d) + assert task.current_executions == 0 + + def test_missing_created_at(self) -> None: + d: dict[str, Any] = { + "task_id": "t3", + "tool_name": "x", + "tool_args": {}, + "cron": "0 0 * * *", + "target_id": None, + "target_type": "group", + "task_name": "n", + "max_executions": None, + } + task = ScheduledTask.from_dict(d) + assert task.created_at == "" + + def test_default_execution_mode(self) -> None: + d: dict[str, Any] = { + "task_id": "t4", + "tool_name": "x", + "tool_args": {}, + "cron": "0 0 * * *", + "target_id": None, + "target_type": "group", + "task_name": "n", + "max_executions": None, + } + task = ScheduledTask.from_dict(d) + assert task.execution_mode == "serial" + + def test_max_executions_none(self) -> None: + d: dict[str, Any] = { + "task_id": "t5", + "tool_name": "x", + "tool_args": {}, + "cron": "0 0 * * *", + "target_id": None, + "target_type": "group", + "task_name": "n", + "max_executions": None, + } + task = ScheduledTask.from_dict(d) + assert task.max_executions is None diff --git a/tests/test_skills_http_client.py b/tests/test_skills_http_client.py new file mode 100644 index 00000000..61479e9c --- /dev/null +++ b/tests/test_skills_http_client.py @@ -0,0 +1,76 @@ +"""Tests for Undefined.skills.http_client module (pure functions only).""" + +from __future__ import annotations + +from Undefined.skills.http_client import _retry_delay, _should_retry_http_status + + +class TestShouldRetryHttpStatus: + """Tests for _should_retry_http_status().""" + + def test_429_should_retry(self) -> None: + assert _should_retry_http_status(429) is True + + def test_500_should_retry(self) -> None: + assert _should_retry_http_status(500) is True + + def test_502_should_retry(self) -> None: + assert _should_retry_http_status(502) is True + + def test_503_should_retry(self) -> None: + assert _should_retry_http_status(503) is True + + def test_504_should_retry(self) -> None: + assert _should_retry_http_status(504) is True + + def test_599_should_retry(self) -> None: + assert _should_retry_http_status(599) is True + + def test_200_should_not_retry(self) -> None: + assert _should_retry_http_status(200) is False + + def test_201_should_not_retry(self) -> None: + assert _should_retry_http_status(201) is False + + def test_400_should_not_retry(self) -> None: + assert _should_retry_http_status(400) is False + + def test_401_should_not_retry(self) -> None: + assert _should_retry_http_status(401) is False + + def test_403_should_not_retry(self) -> None: + assert _should_retry_http_status(403) is False + + def test_404_should_not_retry(self) -> None: + assert _should_retry_http_status(404) is False + + def test_600_should_not_retry(self) -> None: + assert _should_retry_http_status(600) is False + + def test_428_should_not_retry(self) -> None: + assert _should_retry_http_status(428) is False + + +class TestRetryDelay: + """Tests for _retry_delay().""" + + def test_attempt_0(self) -> None: + assert _retry_delay(0) == 0.25 # min(2.0, 0.25 * 2^0) = 0.25 + + def test_attempt_1(self) -> None: + assert _retry_delay(1) == 0.5 # min(2.0, 0.25 * 2^1) = 0.5 + + def test_attempt_2(self) -> None: + assert _retry_delay(2) == 1.0 # min(2.0, 0.25 * 2^2) = 1.0 + + def test_attempt_3(self) -> None: + assert _retry_delay(3) == 2.0 # min(2.0, 0.25 * 2^3) = 2.0 + + def test_attempt_4_capped(self) -> None: + assert _retry_delay(4) == 2.0 # min(2.0, 0.25 * 2^4) = min(2.0, 4.0) = 2.0 + + def test_attempt_5_capped(self) -> None: + assert _retry_delay(5) == 2.0 # capped at 2.0 + + def test_returns_float(self) -> None: + assert isinstance(_retry_delay(0), float) diff --git a/tests/test_skills_http_config.py b/tests/test_skills_http_config.py new file mode 100644 index 00000000..7ed41ebe --- /dev/null +++ b/tests/test_skills_http_config.py @@ -0,0 +1,98 @@ +"""Tests for Undefined.skills.http_config module (pure functions only).""" + +from __future__ import annotations + +from Undefined.skills.http_config import _normalize_base_url, build_url + + +class TestBuildUrl: + """Tests for build_url().""" + + def test_simple_join(self) -> None: + assert ( + build_url("https://api.example.com", "/v1/data") + == "https://api.example.com/v1/data" + ) + + def test_trailing_slash_on_base(self) -> None: + assert ( + build_url("https://api.example.com/", "/v1/data") + == "https://api.example.com/v1/data" + ) + + def test_multiple_trailing_slashes(self) -> None: + assert ( + build_url("https://api.example.com///", "/v1") + == "https://api.example.com/v1" + ) + + def test_path_without_leading_slash(self) -> None: + assert ( + build_url("https://api.example.com", "v1/data") + == "https://api.example.com/v1/data" + ) + + def test_empty_path(self) -> None: + assert build_url("https://api.example.com", "") == "https://api.example.com/" + + def test_path_is_slash_only(self) -> None: + assert build_url("https://api.example.com", "/") == "https://api.example.com/" + + def test_base_with_subpath(self) -> None: + assert ( + build_url("https://api.example.com/v2", "/users") + == "https://api.example.com/v2/users" + ) + + def test_base_with_subpath_trailing_slash(self) -> None: + assert ( + build_url("https://api.example.com/v2/", "/users") + == "https://api.example.com/v2/users" + ) + + +class TestNormalizeBaseUrl: + """Tests for _normalize_base_url().""" + + def test_normal_url(self) -> None: + assert ( + _normalize_base_url("https://api.example.com", "https://fallback.com") + == "https://api.example.com" + ) + + def test_trailing_slash_removed(self) -> None: + assert ( + _normalize_base_url("https://api.example.com/", "https://fallback.com") + == "https://api.example.com" + ) + + def test_multiple_trailing_slashes(self) -> None: + assert ( + _normalize_base_url("https://api.example.com///", "https://fallback.com") + == "https://api.example.com" + ) + + def test_empty_value_uses_fallback(self) -> None: + assert _normalize_base_url("", "https://fallback.com") == "https://fallback.com" + + def test_whitespace_only_uses_fallback(self) -> None: + assert ( + _normalize_base_url(" ", "https://fallback.com") == "https://fallback.com" + ) + + def test_fallback_trailing_slash_stripped(self) -> None: + assert ( + _normalize_base_url("", "https://fallback.com/") == "https://fallback.com" + ) + + def test_leading_trailing_whitespace_stripped(self) -> None: + assert ( + _normalize_base_url(" https://api.example.com ", "https://fallback.com") + == "https://api.example.com" + ) + + def test_value_with_path(self) -> None: + assert ( + _normalize_base_url("https://api.example.com/v2/", "https://fallback.com") + == "https://api.example.com/v2" + ) diff --git a/tests/test_skills_registry_stats.py b/tests/test_skills_registry_stats.py new file mode 100644 index 00000000..d46fabdb --- /dev/null +++ b/tests/test_skills_registry_stats.py @@ -0,0 +1,100 @@ +"""Tests for Undefined.skills.registry.SkillStats dataclass.""" + +from __future__ import annotations + +from Undefined.skills.registry import SkillStats + + +class TestSkillStats: + """Tests for SkillStats dataclass.""" + + def test_initial_state(self) -> None: + stats = SkillStats() + assert stats.count == 0 + assert stats.success == 0 + assert stats.failure == 0 + assert stats.total_duration == 0.0 + assert stats.last_duration == 0.0 + assert stats.last_error is None + assert stats.last_called_at is None + + def test_record_success(self) -> None: + stats = SkillStats() + stats.record_success(1.5) + assert stats.count == 1 + assert stats.success == 1 + assert stats.failure == 0 + assert stats.total_duration == 1.5 + assert stats.last_duration == 1.5 + assert stats.last_error is None + assert stats.last_called_at is not None + + def test_record_failure(self) -> None: + stats = SkillStats() + stats.record_failure(2.0, "timeout") + assert stats.count == 1 + assert stats.success == 0 + assert stats.failure == 1 + assert stats.total_duration == 2.0 + assert stats.last_duration == 2.0 + assert stats.last_error == "timeout" + assert stats.last_called_at is not None + + def test_multiple_successes(self) -> None: + stats = SkillStats() + stats.record_success(1.0) + stats.record_success(2.0) + stats.record_success(3.0) + assert stats.count == 3 + assert stats.success == 3 + assert stats.failure == 0 + assert stats.total_duration == 6.0 + assert stats.last_duration == 3.0 + + def test_mixed_success_and_failure(self) -> None: + stats = SkillStats() + stats.record_success(1.0) + stats.record_failure(0.5, "error A") + stats.record_success(2.0) + assert stats.count == 3 + assert stats.success == 2 + assert stats.failure == 1 + assert stats.total_duration == 3.5 + assert stats.last_duration == 2.0 + assert stats.last_error is None # cleared by success + + def test_success_clears_last_error(self) -> None: + stats = SkillStats() + stats.record_failure(1.0, "something broke") + assert stats.last_error == "something broke" + stats.record_success(0.5) + assert stats.last_error is None + + def test_failure_overwrites_last_error(self) -> None: + stats = SkillStats() + stats.record_failure(1.0, "error 1") + stats.record_failure(2.0, "error 2") + assert stats.last_error == "error 2" + + def test_average_duration(self) -> None: + stats = SkillStats() + stats.record_success(2.0) + stats.record_success(4.0) + avg = stats.total_duration / stats.count + assert avg == 3.0 + + def test_last_called_at_updates(self) -> None: + stats = SkillStats() + stats.record_success(1.0) + first_called = stats.last_called_at + assert first_called is not None + stats.record_failure(1.0, "err") + assert stats.last_called_at is not None + assert stats.last_called_at >= first_called + + def test_zero_duration(self) -> None: + stats = SkillStats() + stats.record_success(0.0) + assert stats.total_duration == 0.0 + assert stats.last_duration == 0.0 + assert stats.count == 1 diff --git a/tests/test_summary_agent.py b/tests/test_summary_agent.py new file mode 100644 index 00000000..00c86258 --- /dev/null +++ b/tests/test_summary_agent.py @@ -0,0 +1,197 @@ +from __future__ import annotations + +from pathlib import Path +from typing import Any +from unittest.mock import AsyncMock, patch + +import pytest + +from Undefined.config.models import AgentModelConfig +from Undefined.skills.agents.summary_agent.handler import ( + _build_user_content, + execute as summary_agent_execute, +) + + +@pytest.mark.asyncio +async def test_summary_agent_normal_execution() -> None: + """Normal execution → calls run_agent_with_tools with correct params.""" + context: dict[str, Any] = { + "ai_client": AsyncMock(), + "history_manager": AsyncMock(), + "group_id": 123456, + "user_id": 10002, + "sender_id": 10002, + "request_type": "group", + "runtime_config": None, + "queue_lane": None, + } + + with patch( + "Undefined.skills.agents.summary_agent.handler.run_agent_with_tools", + new=AsyncMock(return_value="总结结果:讨论了技术话题"), + ) as mock_run_agent: + result = await summary_agent_execute( + {"prompt": "请总结最近 50 条消息"}, + context, + ) + + assert result == "总结结果:讨论了技术话题" + mock_run_agent.assert_called_once() + + call_kwargs = mock_run_agent.call_args.kwargs + assert call_kwargs["agent_name"] == "summary_agent" + assert "请总结最近 50 条消息" in call_kwargs["user_content"] + assert "count=50" in call_kwargs["user_content"] + assert call_kwargs["empty_user_content_message"] == "请提供您的总结需求" + assert "消息总结助手" in call_kwargs["default_prompt"] + assert call_kwargs["context"] is context + assert isinstance(call_kwargs["agent_dir"], Path) + assert call_kwargs["max_iterations"] == 10 + assert call_kwargs["tool_error_prefix"] == "错误" + + +@pytest.mark.asyncio +async def test_summary_agent_empty_prompt() -> None: + """Empty prompt → returns '请提供您的总结需求'.""" + context: dict[str, Any] = { + "ai_client": AsyncMock(), + "history_manager": AsyncMock(), + } + + with patch( + "Undefined.skills.agents.summary_agent.handler.run_agent_with_tools", + new=AsyncMock(return_value="请提供您的总结需求"), + ) as mock_run_agent: + result = await summary_agent_execute({"prompt": ""}, context) + + assert result == "请提供您的总结需求" + mock_run_agent.assert_called_once() + call_kwargs = mock_run_agent.call_args.kwargs + assert call_kwargs["user_content"] == "" + + +@pytest.mark.asyncio +async def test_summary_agent_whitespace_prompt() -> None: + """Whitespace-only prompt → treated as empty.""" + context: dict[str, Any] = { + "ai_client": AsyncMock(), + "history_manager": AsyncMock(), + } + + with patch( + "Undefined.skills.agents.summary_agent.handler.run_agent_with_tools", + new=AsyncMock(return_value="请提供您的总结需求"), + ) as mock_run_agent: + result = await summary_agent_execute({"prompt": " "}, context) + + assert result == "请提供您的总结需求" + call_kwargs = mock_run_agent.call_args.kwargs + assert call_kwargs["user_content"] == "" + + +@pytest.mark.asyncio +async def test_summary_agent_missing_prompt_arg() -> None: + """Missing 'prompt' arg → defaults to empty string.""" + context: dict[str, Any] = { + "ai_client": AsyncMock(), + "history_manager": AsyncMock(), + } + + with patch( + "Undefined.skills.agents.summary_agent.handler.run_agent_with_tools", + new=AsyncMock(return_value="请提供您的总结需求"), + ) as mock_run_agent: + result = await summary_agent_execute({}, context) + + assert result == "请提供您的总结需求" + call_kwargs = mock_run_agent.call_args.kwargs + assert call_kwargs["user_content"] == "" + + +@pytest.mark.asyncio +async def test_summary_agent_complex_prompt() -> None: + """Complex prompt with time range and custom instructions.""" + context: dict[str, Any] = { + "ai_client": AsyncMock(), + "history_manager": AsyncMock(), + "group_id": 654321, + "user_id": 99999, + } + + with patch( + "Undefined.skills.agents.summary_agent.handler.run_agent_with_tools", + new=AsyncMock(return_value="详细总结内容"), + ) as mock_run_agent: + result = await summary_agent_execute( + {"prompt": "请总结过去 1d 内的聊天消息,重点关注:技术讨论"}, + context, + ) + + assert result == "详细总结内容" + call_kwargs = mock_run_agent.call_args.kwargs + assert ( + "请总结过去 1d 内的聊天消息,重点关注:技术讨论" in call_kwargs["user_content"] + ) + + +@pytest.mark.asyncio +async def test_summary_agent_propagates_exception() -> None: + """Exception from run_agent_with_tools propagates up.""" + context: dict[str, Any] = { + "ai_client": AsyncMock(), + "history_manager": AsyncMock(), + } + + with patch( + "Undefined.skills.agents.summary_agent.handler.run_agent_with_tools", + new=AsyncMock(side_effect=RuntimeError("Agent failure")), + ): + with pytest.raises(RuntimeError, match="Agent failure"): + await summary_agent_execute({"prompt": "test"}, context) + + +def test_build_user_content_prefers_structured_args() -> None: + content = _build_user_content( + { + "prompt": "请总结最近 80 条聊天消息,重点关注:发布计划", + "count": 80, + "time_range": "", + "focus": "发布计划", + } + ) + + assert "请总结最近 80 条聊天消息,重点关注:发布计划" in content + assert "count=80" in content + assert "总结时重点关注:发布计划" in content + assert "2 到 3 个短段落" in content + + +@pytest.mark.asyncio +async def test_summary_agent_uses_summary_model_override_when_configured() -> None: + runtime_config = AsyncMock() + runtime_config.summary_model_configured = True + runtime_config.summary_model = AgentModelConfig( + api_url="https://summary.example/v1", + api_key="sk-summary", + model_name="summary-model", + ) + context: dict[str, Any] = { + "ai_client": AsyncMock(), + "history_manager": AsyncMock(), + "runtime_config": runtime_config, + } + + with patch( + "Undefined.skills.agents.summary_agent.handler.run_agent_with_tools", + new=AsyncMock(return_value="总结结果"), + ) as mock_run_agent: + result = await summary_agent_execute( + {"prompt": "请总结最近 20 条消息"}, context + ) + + assert result == "总结结果" + call_kwargs = mock_run_agent.call_args.kwargs + assert ( + call_kwargs["context"]["model_config_override"] is runtime_config.summary_model + ) diff --git a/tests/test_summary_command.py b/tests/test_summary_command.py new file mode 100644 index 00000000..31d5c317 --- /dev/null +++ b/tests/test_summary_command.py @@ -0,0 +1,370 @@ +from __future__ import annotations + +from types import SimpleNamespace +from typing import Any, cast +from unittest.mock import AsyncMock, patch + +import pytest + +from Undefined.services.commands.context import CommandContext +from Undefined.skills.commands.summary.handler import ( + _build_prompt, + _parse_args, + execute as summary_execute, +) + + +class _DummySender: + def __init__(self) -> None: + self.group_messages: list[tuple[int, str]] = [] + self.private_messages: list[tuple[int, str]] = [] + + async def send_group_message( + self, group_id: int, message: str, mark_sent: bool = False + ) -> None: + _ = mark_sent + self.group_messages.append((group_id, message)) + + async def send_private_message( + self, + user_id: int, + message: str, + auto_history: bool = True, + *, + mark_sent: bool = True, + ) -> None: + _ = (auto_history, mark_sent) + self.private_messages.append((user_id, message)) + + +def _build_context( + *, + sender: _DummySender | None = None, + history_manager: Any = None, + scope: str = "group", + group_id: int = 123456, + sender_id: int = 10002, + user_id: int | None = None, + ai: Any = None, +) -> CommandContext: + stub = cast(Any, SimpleNamespace()) + if sender is None: + sender = _DummySender() + if ai is None: + ai = stub + return CommandContext( + group_id=group_id, + sender_id=sender_id, + config=stub, + sender=cast(Any, sender), + ai=ai, + faq_storage=stub, + onebot=stub, + security=stub, + queue_manager=None, + rate_limiter=None, + dispatcher=stub, + registry=stub, + scope=scope, + user_id=user_id, + history_manager=history_manager, + ) + + +# -- _parse_args unit tests -- + + +def test_parse_args_empty() -> None: + """Empty args → (50, None, '').""" + count, time_range, custom_prompt = _parse_args([]) + assert count == 50 + assert time_range is None + assert custom_prompt == "" + + +def test_parse_args_count_only() -> None: + """['100'] → (100, None, '').""" + count, time_range, custom_prompt = _parse_args(["100"]) + assert count == 100 + assert time_range is None + assert custom_prompt == "" + + +def test_parse_args_time_range_only() -> None: + """['1d'] → (None, '1d', '').""" + count, time_range, custom_prompt = _parse_args(["1d"]) + assert count is None + assert time_range == "1d" + assert custom_prompt == "" + + +def test_parse_args_count_with_custom_prompt() -> None: + """['100', '技术讨论'] → (100, None, '技术讨论').""" + count, time_range, custom_prompt = _parse_args(["100", "技术讨论"]) + assert count == 100 + assert time_range is None + assert custom_prompt == "技术讨论" + + +def test_parse_args_time_range_with_custom_prompt() -> None: + """['1d', '总结技术'] → (None, '1d', '总结技术').""" + count, time_range, custom_prompt = _parse_args(["1d", "总结技术"]) + assert count is None + assert time_range == "1d" + assert custom_prompt == "总结技术" + + +def test_parse_args_custom_prompt_only() -> None: + """['技术讨论'] → (50, None, '技术讨论').""" + count, time_range, custom_prompt = _parse_args(["技术讨论"]) + assert count == 50 + assert time_range is None + assert custom_prompt == "技术讨论" + + +def test_parse_args_count_capped_at_500() -> None: + """['999'] → (500, None, '') (capped).""" + count, time_range, custom_prompt = _parse_args(["999"]) + assert count == 500 + assert time_range is None + assert custom_prompt == "" + + +def test_parse_args_multiple_words_prompt() -> None: + """['技术', '讨论', '总结'] → (50, None, '技术 讨论 总结').""" + count, time_range, custom_prompt = _parse_args(["技术", "讨论", "总结"]) + assert count == 50 + assert time_range is None + assert custom_prompt == "技术 讨论 总结" + + +# -- _build_prompt unit tests -- + + +def test_build_prompt_with_count() -> None: + """With count → '请总结最近 X 条聊天消息'.""" + prompt = _build_prompt(100, None, "") + assert "请总结最近 100 条聊天消息" in prompt + + +def test_build_prompt_with_time_range() -> None: + """With time_range → '请总结过去 X 内的聊天消息'.""" + prompt = _build_prompt(None, "1d", "") + assert "请总结过去 1d 内的聊天消息" in prompt + + +def test_build_prompt_with_custom_prompt() -> None: + """With custom_prompt → adds '重点关注:...'.""" + prompt = _build_prompt(50, None, "技术讨论") + assert "请总结最近 50 条聊天消息" in prompt + assert "重点关注:技术讨论" in prompt + + +def test_build_prompt_default_count() -> None: + """Default count when both are None.""" + prompt = _build_prompt(None, None, "") + assert "请总结最近 50 条聊天消息" in prompt + + +def test_build_prompt_time_range_and_custom() -> None: + """Time range with custom prompt.""" + prompt = _build_prompt(None, "6h", "重要公告") + assert "请总结过去 6h 内的聊天消息" in prompt + assert "重点关注:重要公告" in prompt + + +# -- Command execution tests -- + + +@pytest.mark.asyncio +async def test_summary_no_history_manager() -> None: + """No history_manager → sends error message.""" + sender = _DummySender() + context = _build_context( + sender=sender, + history_manager=None, + scope="group", + group_id=123456, + sender_id=10002, + ) + + await summary_execute([], context) + + assert len(sender.group_messages) == 1 + assert "❌ 历史记录管理器未配置" in sender.group_messages[0][1] + + +@pytest.mark.asyncio +async def test_summary_agent_call_success() -> None: + """Agent call success → result forwarded to user.""" + sender = _DummySender() + history_manager = AsyncMock() + ai = AsyncMock() + ai.runtime_config = None + + context = _build_context( + sender=sender, + history_manager=history_manager, + ai=ai, + scope="group", + group_id=123456, + sender_id=10002, + ) + + with patch( + "Undefined.skills.agents.summary_agent.handler.execute", + new=AsyncMock(return_value="总结内容:最近讨论了技术话题。"), + ) as mock_agent: + await summary_execute(["50"], context) + + assert len(sender.group_messages) == 1 + assert "总结内容:最近讨论了技术话题。" in sender.group_messages[0][1] + mock_agent.assert_called_once() + call_args = mock_agent.call_args + assert call_args[0][0]["prompt"] == "请总结最近 50 条聊天消息" + assert call_args[0][0]["count"] == 50 + assert call_args[0][0]["time_range"] is None + assert call_args[0][0]["focus"] == "" + + +@pytest.mark.asyncio +async def test_summary_agent_call_failure() -> None: + """Agent call failure → sends error message.""" + sender = _DummySender() + history_manager = AsyncMock() + ai = AsyncMock() + + context = _build_context( + sender=sender, + history_manager=history_manager, + ai=ai, + scope="group", + group_id=123456, + sender_id=10002, + ) + + with patch( + "Undefined.skills.agents.summary_agent.handler.execute", + new=AsyncMock(side_effect=Exception("Agent error")), + ): + await summary_execute([], context) + + assert len(sender.group_messages) == 1 + assert "❌ 消息总结失败,请稍后重试" in sender.group_messages[0][1] + + +@pytest.mark.asyncio +async def test_summary_agent_returns_empty() -> None: + """Agent returns empty result → sends '未能生成总结内容'.""" + sender = _DummySender() + history_manager = AsyncMock() + ai = AsyncMock() + + context = _build_context( + sender=sender, + history_manager=history_manager, + ai=ai, + scope="group", + group_id=123456, + sender_id=10002, + ) + + with patch( + "Undefined.skills.agents.summary_agent.handler.execute", + new=AsyncMock(return_value=" "), + ): + await summary_execute([], context) + + assert len(sender.group_messages) == 1 + assert "📭 未能生成总结内容" in sender.group_messages[0][1] + + +@pytest.mark.asyncio +async def test_summary_private_chat() -> None: + """Private chat → uses send_private_message.""" + sender = _DummySender() + history_manager = AsyncMock() + ai = AsyncMock() + + context = _build_context( + sender=sender, + history_manager=history_manager, + ai=ai, + scope="private", + group_id=0, + sender_id=88888, + user_id=88888, + ) + + with patch( + "Undefined.skills.agents.summary_agent.handler.execute", + new=AsyncMock(return_value="私聊总结结果"), + ): + await summary_execute(["1d", "重要消息"], context) + + assert len(sender.private_messages) == 1 + assert "私聊总结结果" in sender.private_messages[0][1] + + +@pytest.mark.asyncio +async def test_summary_passes_correct_context_to_agent() -> None: + """Agent receives correct context parameters.""" + sender = _DummySender() + history_manager = AsyncMock() + ai = AsyncMock() + ai.runtime_config = SimpleNamespace(some_config="value") + + context = _build_context( + sender=sender, + history_manager=history_manager, + ai=ai, + scope="group", + group_id=999888, + sender_id=777666, + user_id=None, + ) + + with patch( + "Undefined.skills.agents.summary_agent.handler.execute", + new=AsyncMock(return_value="总结"), + ) as mock_agent: + await summary_execute([], context) + + call_args = mock_agent.call_args + agent_context = call_args[0][1] + assert agent_context["ai_client"] is ai + assert agent_context["history_manager"] is history_manager + assert agent_context["group_id"] == 999888 + assert agent_context["sender_id"] == 777666 + assert agent_context["user_id"] == 777666 + assert agent_context["request_type"] == "group" + assert agent_context["runtime_config"] is ai.runtime_config + + +@pytest.mark.asyncio +async def test_summary_passes_time_range_and_focus_to_agent() -> None: + """Time range and focus are passed structurally to the agent.""" + sender = _DummySender() + history_manager = AsyncMock() + ai = AsyncMock() + + context = _build_context( + sender=sender, + history_manager=history_manager, + ai=ai, + scope="group", + group_id=123456, + sender_id=10002, + ) + + with patch( + "Undefined.skills.agents.summary_agent.handler.execute", + new=AsyncMock(return_value="总结"), + ) as mock_agent: + await summary_execute(["1d", "技术讨论"], context) + + call_args = mock_agent.call_args + assert call_args[0][0]["prompt"] == "请总结过去 1d 内的聊天消息,重点关注:技术讨论" + assert call_args[0][0]["count"] is None + assert call_args[0][0]["time_range"] == "1d" + assert call_args[0][0]["focus"] == "技术讨论" diff --git a/tests/test_sync_config_template_script.py b/tests/test_sync_config_template_script.py index 3f14aec4..e422138f 100644 --- a/tests/test_sync_config_template_script.py +++ b/tests/test_sync_config_template_script.py @@ -43,6 +43,7 @@ def fake_sync_config_file( added_paths=[], removed_paths=["models.chat.extra"], comments={}, + updated_comment_paths=[], ) monkeypatch.setattr(module, "sync_config_file", fake_sync_config_file) diff --git a/tests/test_time_utils.py b/tests/test_time_utils.py new file mode 100644 index 00000000..d4c49cd1 --- /dev/null +++ b/tests/test_time_utils.py @@ -0,0 +1,84 @@ +"""Tests for Undefined.utils.time_utils — time parsing/formatting helpers.""" + +from __future__ import annotations + +from datetime import datetime + +from Undefined.utils.time_utils import format_datetime, parse_time_range + + +class TestParseTimeRange: + def test_both_valid(self) -> None: + start, end = parse_time_range("2024-01-15 08:30:00", "2024-06-20 17:45:00") + assert start == datetime(2024, 1, 15, 8, 30, 0) + assert end == datetime(2024, 6, 20, 17, 45, 0) + + def test_only_start(self) -> None: + start, end = parse_time_range("2024-01-01 00:00:00", None) + assert start == datetime(2024, 1, 1, 0, 0, 0) + assert end is None + + def test_only_end(self) -> None: + start, end = parse_time_range(None, "2024-12-31 23:59:59") + assert start is None + assert end == datetime(2024, 12, 31, 23, 59, 59) + + def test_both_none(self) -> None: + start, end = parse_time_range(None, None) + assert start is None + assert end is None + + def test_invalid_start_format(self) -> None: + start, end = parse_time_range("2024/01/01", None) + assert start is None + assert end is None + + def test_invalid_end_format(self) -> None: + start, end = parse_time_range(None, "not-a-date") + assert start is None + assert end is None + + def test_both_invalid(self) -> None: + start, end = parse_time_range("bad", "worse") + assert start is None + assert end is None + + def test_empty_strings(self) -> None: + start, end = parse_time_range("", "") + assert start is None + assert end is None + + def test_date_only_format_rejected(self) -> None: + start, end = parse_time_range("2024-01-01", None) + assert start is None + + def test_midnight(self) -> None: + start, end = parse_time_range("2024-01-01 00:00:00", None) + assert start == datetime(2024, 1, 1, 0, 0, 0) + + def test_end_of_day(self) -> None: + start, end = parse_time_range(None, "2024-12-31 23:59:59") + assert end == datetime(2024, 12, 31, 23, 59, 59) + + +class TestFormatDatetime: + def test_none_input(self) -> None: + assert format_datetime(None) == "未指定" + + def test_normal_datetime(self) -> None: + dt = datetime(2024, 3, 15, 14, 30, 45) + assert format_datetime(dt) == "2024-03-15 14:30:45" + + def test_midnight(self) -> None: + dt = datetime(2024, 1, 1, 0, 0, 0) + assert format_datetime(dt) == "2024-01-01 00:00:00" + + def test_end_of_day(self) -> None: + dt = datetime(2024, 12, 31, 23, 59, 59) + assert format_datetime(dt) == "2024-12-31 23:59:59" + + def test_roundtrip(self) -> None: + original = "2024-06-15 10:20:30" + start, _ = parse_time_range(original, None) + assert start is not None + assert format_datetime(start) == original diff --git a/tests/test_token_usage_unit.py b/tests/test_token_usage_unit.py new file mode 100644 index 00000000..821f6ea8 --- /dev/null +++ b/tests/test_token_usage_unit.py @@ -0,0 +1,169 @@ +"""TokenUsage 序列化/反序列化 单元测试""" + +from __future__ import annotations + +from typing import Any + +import pytest + +from Undefined.token_usage_storage import TokenUsage + + +def _sample_dict() -> dict[str, Any]: + return { + "timestamp": "2025-01-01T00:00:00", + "model_name": "gpt-4", + "prompt_tokens": 100, + "completion_tokens": 50, + "total_tokens": 150, + "duration_seconds": 1.5, + "call_type": "chat", + "success": True, + } + + +# --------------------------------------------------------------------------- +# to_dict / from_dict 往返 +# --------------------------------------------------------------------------- + + +class TestTokenUsageRoundtrip: + def test_basic_roundtrip(self) -> None: + d = _sample_dict() + usage = TokenUsage.from_dict(d) + assert usage.to_dict() == d + + def test_all_fields_preserved(self) -> None: + usage = TokenUsage( + timestamp="ts", + model_name="m", + prompt_tokens=10, + completion_tokens=20, + total_tokens=30, + duration_seconds=0.5, + call_type="agent", + success=False, + ) + restored = TokenUsage.from_dict(usage.to_dict()) + assert restored == usage + + +# --------------------------------------------------------------------------- +# from_dict — 缺失字段回退 +# --------------------------------------------------------------------------- + + +class TestTokenUsageFromDictDefaults: + def test_empty_dict(self) -> None: + usage = TokenUsage.from_dict({}) + assert usage.timestamp == "" + assert usage.model_name == "" + assert usage.prompt_tokens == 0 + assert usage.completion_tokens == 0 + assert usage.total_tokens == 0 + assert usage.duration_seconds == 0.0 + assert usage.call_type == "unknown" + assert usage.success is True # 默认为 True + + def test_timestamp_fallback_to_time(self) -> None: + usage = TokenUsage.from_dict({"time": "2025-01-01"}) + assert usage.timestamp == "2025-01-01" + + def test_timestamp_fallback_to_created_at(self) -> None: + usage = TokenUsage.from_dict({"created_at": "2025-02-02"}) + assert usage.timestamp == "2025-02-02" + + def test_model_name_fallback_to_model(self) -> None: + usage = TokenUsage.from_dict({"model": "claude"}) + assert usage.model_name == "claude" + + def test_prompt_tokens_fallback_to_input_tokens(self) -> None: + usage = TokenUsage.from_dict({"input_tokens": 42}) + assert usage.prompt_tokens == 42 + + def test_completion_tokens_fallback_to_output_tokens(self) -> None: + usage = TokenUsage.from_dict({"output_tokens": 24}) + assert usage.completion_tokens == 24 + + def test_total_tokens_auto_sum(self) -> None: + usage = TokenUsage.from_dict({"prompt_tokens": 10, "completion_tokens": 20}) + assert usage.total_tokens == 30 + + def test_total_tokens_explicit(self) -> None: + usage = TokenUsage.from_dict( + {"prompt_tokens": 10, "completion_tokens": 20, "total_tokens": 99} + ) + assert usage.total_tokens == 99 + + def test_call_type_fallback_to_type(self) -> None: + usage = TokenUsage.from_dict({"type": "vision"}) + assert usage.call_type == "vision" + + def test_duration_fallback_to_duration(self) -> None: + usage = TokenUsage.from_dict({"duration": 3.14}) + assert usage.duration_seconds == pytest.approx(3.14) + + +# --------------------------------------------------------------------------- +# from_dict — success 字段各种类型 +# --------------------------------------------------------------------------- + + +class TestTokenUsageSuccess: + def test_success_bool_true(self) -> None: + assert TokenUsage.from_dict({"success": True}).success is True + + def test_success_bool_false(self) -> None: + assert TokenUsage.from_dict({"success": False}).success is False + + def test_success_string_false(self) -> None: + assert TokenUsage.from_dict({"success": "false"}).success is False + + def test_success_string_0(self) -> None: + assert TokenUsage.from_dict({"success": "0"}).success is False + + def test_success_string_no(self) -> None: + assert TokenUsage.from_dict({"success": "no"}).success is False + + def test_success_string_yes(self) -> None: + assert TokenUsage.from_dict({"success": "yes"}).success is True + + def test_success_int_1(self) -> None: + assert TokenUsage.from_dict({"success": 1}).success is True + + def test_success_int_0(self) -> None: + assert TokenUsage.from_dict({"success": 0}).success is False + + +# --------------------------------------------------------------------------- +# from_dict — 类型转换容错 +# --------------------------------------------------------------------------- + + +class TestTokenUsageTypeCoercion: + def test_string_tokens(self) -> None: + usage = TokenUsage.from_dict({"prompt_tokens": "42"}) + assert usage.prompt_tokens == 42 + + def test_invalid_tokens_default_zero(self) -> None: + usage = TokenUsage.from_dict({"prompt_tokens": "abc"}) + assert usage.prompt_tokens == 0 + + def test_none_tokens_default_zero(self) -> None: + usage = TokenUsage.from_dict({"prompt_tokens": None}) + assert usage.prompt_tokens == 0 + + def test_non_string_timestamp(self) -> None: + usage = TokenUsage.from_dict({"timestamp": 12345}) + assert usage.timestamp == "12345" + + def test_non_string_model_name(self) -> None: + usage = TokenUsage.from_dict({"model_name": 42}) + assert usage.model_name == "42" + + def test_extra_fields_ignored(self) -> None: + d = _sample_dict() + d["extra_field"] = "ignored" + d["another"] = 999 + usage = TokenUsage.from_dict(d) + assert usage.model_name == "gpt-4" diff --git a/tests/test_tool_calls.py b/tests/test_tool_calls.py new file mode 100644 index 00000000..4bac35f7 --- /dev/null +++ b/tests/test_tool_calls.py @@ -0,0 +1,273 @@ +"""Tests for Undefined.utils.tool_calls.""" + +from __future__ import annotations + +import json +import logging +from typing import Any + +import pytest + +from Undefined.utils.tool_calls import ( + _clean_json_string, + _repair_json_like_string, + _strip_code_fences, + extract_required_tool_call_arguments, + normalize_tool_arguments_json, + parse_tool_arguments, +) + + +@pytest.fixture() +def test_logger() -> logging.Logger: + return logging.getLogger("test") + + +# --------------------------------------------------------------------------- +# _strip_code_fences +# --------------------------------------------------------------------------- + + +class TestStripCodeFences: + def test_strip_json_fence(self) -> None: + raw = '```json\n{"a": 1}\n```' + assert _strip_code_fences(raw) == '{"a": 1}' + + def test_strip_generic_fence(self) -> None: + raw = '```\n{"a": 1}\n```' + assert _strip_code_fences(raw) == '{"a": 1}' + + def test_no_fence(self) -> None: + raw = '{"a": 1}' + assert _strip_code_fences(raw) == '{"a": 1}' + + +# --------------------------------------------------------------------------- +# _clean_json_string +# --------------------------------------------------------------------------- + + +class TestCleanJsonString: + def test_removes_control_chars(self) -> None: + raw = '{"key":\r\n\t"val"}' + result = _clean_json_string(raw) + assert "\r" not in result + assert "\n" not in result + assert "\t" not in result + + +# --------------------------------------------------------------------------- +# _repair_json_like_string +# --------------------------------------------------------------------------- + + +class TestRepairJsonLikeString: + def test_missing_closing_brace(self) -> None: + raw = '{"a": 1' + repaired = _repair_json_like_string(raw) + assert json.loads(repaired) == {"a": 1} + + def test_trailing_comma(self) -> None: + raw = '{"a": 1, ' + repaired = _repair_json_like_string(raw) + assert json.loads(repaired) == {"a": 1} + + def test_empty_string(self) -> None: + assert _repair_json_like_string("") == "" + + +# --------------------------------------------------------------------------- +# parse_tool_arguments +# --------------------------------------------------------------------------- + + +class TestParseToolArguments: + def test_dict_passthrough(self) -> None: + d: dict[str, Any] = {"key": "val"} + assert parse_tool_arguments(d) is d + + def test_none_returns_empty(self) -> None: + assert parse_tool_arguments(None) == {} + + def test_empty_string_returns_empty(self) -> None: + assert parse_tool_arguments("") == {} + + def test_whitespace_returns_empty(self) -> None: + assert parse_tool_arguments(" ") == {} + + def test_valid_json_string(self) -> None: + result = parse_tool_arguments('{"x": 42}') + assert result == {"x": 42} + + def test_json_with_code_fences(self) -> None: + raw = '```json\n{"x": 42}\n```' + assert parse_tool_arguments(raw) == {"x": 42} + + def test_json_with_control_chars(self, test_logger: logging.Logger) -> None: + raw = '{"x":\r\n42}' + result = parse_tool_arguments(raw, logger=test_logger, tool_name="t") + assert result == {"x": 42} + + def test_truncated_json_repaired(self, test_logger: logging.Logger) -> None: + raw = '{"a": "hello"' + result = parse_tool_arguments(raw, logger=test_logger, tool_name="t") + assert result == {"a": "hello"} + + def test_json_with_trailing_content(self, test_logger: logging.Logger) -> None: + raw = '{"a": 1} some trailing text' + result = parse_tool_arguments(raw, logger=test_logger, tool_name="t") + assert result == {"a": 1} + + def test_non_dict_json_returns_empty(self, test_logger: logging.Logger) -> None: + raw = "[1, 2, 3]" + result = parse_tool_arguments(raw, logger=test_logger, tool_name="t") + assert result == {} + + def test_completely_invalid_returns_empty( + self, test_logger: logging.Logger + ) -> None: + raw = "this is not json at all" + result = parse_tool_arguments(raw, logger=test_logger, tool_name="t") + assert result == {} + + def test_unsupported_type_returns_empty(self, test_logger: logging.Logger) -> None: + result = parse_tool_arguments(42, logger=test_logger, tool_name="t") + assert result == {} + + +# --------------------------------------------------------------------------- +# normalize_tool_arguments_json +# --------------------------------------------------------------------------- + + +class TestNormalizeToolArgumentsJson: + def test_none(self) -> None: + assert normalize_tool_arguments_json(None) == "{}" + + def test_dict(self) -> None: + result = normalize_tool_arguments_json({"a": 1}) + parsed = json.loads(result) + assert parsed == {"a": 1} + + def test_empty_string(self) -> None: + assert normalize_tool_arguments_json("") == "{}" + + def test_valid_json_object_string(self) -> None: + result = normalize_tool_arguments_json('{"key": "val"}') + parsed = json.loads(result) + assert parsed == {"key": "val"} + + def test_non_object_json_wrapped(self) -> None: + result = normalize_tool_arguments_json("[1,2,3]") + parsed = json.loads(result) + assert parsed == {"_value": [1, 2, 3]} + + def test_invalid_json_wrapped_raw(self) -> None: + result = normalize_tool_arguments_json("not json") + parsed = json.loads(result) + assert parsed == {"_raw": "not json"} + + def test_non_string_non_dict_wrapped(self) -> None: + result = normalize_tool_arguments_json(42) + parsed = json.loads(result) + assert parsed == {"_value": 42} + + def test_number_json_string_wrapped(self) -> None: + result = normalize_tool_arguments_json("123") + parsed = json.loads(result) + assert parsed == {"_value": 123} + + +# --------------------------------------------------------------------------- +# extract_required_tool_call_arguments +# --------------------------------------------------------------------------- + + +class TestExtractRequiredToolCallArguments: + def _build_response( + self, + name: str = "my_tool", + arguments: Any = '{"x": 1}', + ) -> dict[str, Any]: + return { + "choices": [ + { + "message": { + "tool_calls": [ + { + "function": { + "name": name, + "arguments": arguments, + } + } + ] + } + } + ] + } + + def test_happy_path(self) -> None: + resp = self._build_response() + result = extract_required_tool_call_arguments( + resp, expected_tool_name="my_tool", stage="test" + ) + assert result == {"x": 1} + + def test_missing_choices_raises(self) -> None: + with pytest.raises(ValueError, match="choices"): + extract_required_tool_call_arguments({}, expected_tool_name="t", stage="s") + + def test_non_dict_choice_raises(self) -> None: + with pytest.raises(ValueError, match="choice"): + extract_required_tool_call_arguments( + {"choices": ["bad"]}, expected_tool_name="t", stage="s" + ) + + def test_missing_message_raises(self) -> None: + with pytest.raises(ValueError, match="message"): + extract_required_tool_call_arguments( + {"choices": [{"no_message": True}]}, + expected_tool_name="t", + stage="s", + ) + + def test_missing_tool_calls_raises(self) -> None: + with pytest.raises(ValueError, match="tool_calls"): + extract_required_tool_call_arguments( + {"choices": [{"message": {"content": "hi"}}]}, + expected_tool_name="t", + stage="s", + ) + + def test_non_dict_tool_call_raises(self) -> None: + with pytest.raises(ValueError, match="tool_call"): + extract_required_tool_call_arguments( + {"choices": [{"message": {"tool_calls": ["bad"]}}]}, + expected_tool_name="t", + stage="s", + ) + + def test_missing_function_raises(self) -> None: + with pytest.raises(ValueError, match="function"): + extract_required_tool_call_arguments( + {"choices": [{"message": {"tool_calls": [{"id": "1"}]}}]}, + expected_tool_name="t", + stage="s", + ) + + def test_name_mismatch_raises(self) -> None: + resp = self._build_response(name="wrong_name") + with pytest.raises(ValueError, match="不匹配"): + extract_required_tool_call_arguments( + resp, expected_tool_name="my_tool", stage="s" + ) + + def test_with_logger(self, test_logger: logging.Logger) -> None: + resp = self._build_response() + result = extract_required_tool_call_arguments( + resp, + expected_tool_name="my_tool", + stage="test", + logger=test_logger, + ) + assert result == {"x": 1} diff --git a/tests/test_utils_common.py b/tests/test_utils_common.py new file mode 100644 index 00000000..c1075b45 --- /dev/null +++ b/tests/test_utils_common.py @@ -0,0 +1,305 @@ +"""Tests for Undefined.utils.common.""" + +from __future__ import annotations + +from typing import Any + +from Undefined.utils.common import ( + FORWARD_EXPAND_MAX_CHARS, + _format_forward_node_time, + _normalize_message_content, + _parse_at_segment, + _parse_media_segment, + _parse_segment, + _truncate_forward_text, + extract_text, + matches_xinliweiyuan, + message_to_segments, + process_at_mentions, +) + + +# --------------------------------------------------------------------------- +# extract_text +# --------------------------------------------------------------------------- + + +class TestExtractText: + def test_text_segment(self) -> None: + segments: list[dict[str, Any]] = [{"type": "text", "data": {"text": "hello"}}] + assert extract_text(segments) == "hello" + + def test_multiple_text_segments_joined(self) -> None: + segments: list[dict[str, Any]] = [ + {"type": "text", "data": {"text": "hello "}}, + {"type": "text", "data": {"text": "world"}}, + ] + assert extract_text(segments) == "hello world" + + def test_at_segment_without_name(self) -> None: + segments: list[dict[str, Any]] = [{"type": "at", "data": {"qq": "123456"}}] + assert extract_text(segments) == "[@123456]" + + def test_at_segment_with_name(self) -> None: + segments: list[dict[str, Any]] = [ + {"type": "at", "data": {"qq": "123456", "name": "Bob"}} + ] + assert extract_text(segments) == "[@123456(Bob)]" + + def test_face_segment(self) -> None: + segments: list[dict[str, Any]] = [{"type": "face", "data": {}}] + assert extract_text(segments) == "[表情]" + + def test_image_segment(self) -> None: + segments: list[dict[str, Any]] = [{"type": "image", "data": {"file": "a.png"}}] + assert extract_text(segments) == "[图片: a.png]" + + def test_file_segment(self) -> None: + segments: list[dict[str, Any]] = [{"type": "file", "data": {"file": "doc.pdf"}}] + assert extract_text(segments) == "[文件: doc.pdf]" + + def test_video_segment(self) -> None: + segments: list[dict[str, Any]] = [{"type": "video", "data": {"file": "v.mp4"}}] + assert extract_text(segments) == "[视频: v.mp4]" + + def test_record_segment(self) -> None: + segments: list[dict[str, Any]] = [{"type": "record", "data": {"file": "r.amr"}}] + assert extract_text(segments) == "[语音: r.amr]" + + def test_audio_segment(self) -> None: + segments: list[dict[str, Any]] = [{"type": "audio", "data": {"file": "a.mp3"}}] + assert extract_text(segments) == "[音频: a.mp3]" + + def test_forward_segment_with_id(self) -> None: + segments: list[dict[str, Any]] = [{"type": "forward", "data": {"id": "fw123"}}] + assert extract_text(segments) == "[合并转发: fw123]" + + def test_forward_segment_without_id(self) -> None: + segments: list[dict[str, Any]] = [{"type": "forward", "data": {}}] + assert extract_text(segments) == "[合并转发]" + + def test_reply_segment_with_id(self) -> None: + segments: list[dict[str, Any]] = [{"type": "reply", "data": {"id": "42"}}] + assert extract_text(segments) == "[引用: 42]" + + def test_reply_segment_without_id(self) -> None: + segments: list[dict[str, Any]] = [{"type": "reply", "data": {}}] + assert extract_text(segments) == "[引用]" + + def test_unknown_segment_skipped(self) -> None: + segments: list[dict[str, Any]] = [ + {"type": "unknown_custom", "data": {}}, + {"type": "text", "data": {"text": "ok"}}, + ] + assert extract_text(segments) == "ok" + + def test_empty_segments(self) -> None: + assert extract_text([]) == "" + + def test_mixed_segments(self) -> None: + segments: list[dict[str, Any]] = [ + {"type": "text", "data": {"text": "hi "}}, + {"type": "face", "data": {}}, + {"type": "text", "data": {"text": " bye"}}, + ] + assert extract_text(segments) == "hi [表情] bye" + + def test_data_not_dict_fallback(self) -> None: + """If data is not a dict, segment should still be handled safely.""" + segments: list[dict[str, Any]] = [ + {"type": "text", "data": "not_a_dict"}, + ] + # data becomes {}, so text = "" + assert extract_text(segments) == "" + + +# --------------------------------------------------------------------------- +# process_at_mentions +# --------------------------------------------------------------------------- + + +class TestProcessAtMentions: + def test_basic_at(self) -> None: + assert process_at_mentions("[@123456]") == "[CQ:at,qq=123456]" + + def test_at_with_braces(self) -> None: + assert process_at_mentions("[@{123456}]") == "[CQ:at,qq=123456]" + + def test_multiple_ats(self) -> None: + result = process_at_mentions("[@11111] hi [@22222]") + assert result == "[CQ:at,qq=11111] hi [CQ:at,qq=22222]" + + def test_escaped_brackets(self) -> None: + result = process_at_mentions("\\[@123456\\]") + assert result == "[@123456]" + + def test_no_match(self) -> None: + assert process_at_mentions("hello world") == "hello world" + + +# --------------------------------------------------------------------------- +# message_to_segments +# --------------------------------------------------------------------------- + + +class TestMessageToSegments: + def test_plain_text_only(self) -> None: + segs = message_to_segments("hello world") + assert segs == [{"type": "text", "data": {"text": "hello world"}}] + + def test_cq_at(self) -> None: + segs = message_to_segments("[CQ:at,qq=123]") + assert segs == [{"type": "at", "data": {"qq": "123"}}] + + def test_text_around_cq(self) -> None: + segs = message_to_segments("hi [CQ:face,id=178] bye") + assert len(segs) == 3 + assert segs[0] == {"type": "text", "data": {"text": "hi "}} + assert segs[1] == {"type": "face", "data": {"id": "178"}} + assert segs[2] == {"type": "text", "data": {"text": " bye"}} + + def test_empty_string(self) -> None: + assert message_to_segments("") == [] + + def test_cq_without_args(self) -> None: + segs = message_to_segments("[CQ:face]") + assert segs == [{"type": "face", "data": {}}] + + +# --------------------------------------------------------------------------- +# matches_xinliweiyuan +# --------------------------------------------------------------------------- + + +class TestMatchesXinliweiyuan: + def test_exact_keyword(self) -> None: + assert matches_xinliweiyuan("心理委员") is True + + def test_keyword_with_prefix(self) -> None: + assert matches_xinliweiyuan("找心理委员") is True + + def test_keyword_with_suffix(self) -> None: + assert matches_xinliweiyuan("心理委员在吗") is True + + def test_keyword_both_sides_fails(self) -> None: + assert matches_xinliweiyuan("我找心理委员吧") is False + + def test_no_keyword(self) -> None: + assert matches_xinliweiyuan("你好世界") is False + + def test_too_many_extra_chars(self) -> None: + assert matches_xinliweiyuan("abcdef心理委员") is False + + def test_punctuation_not_counted(self) -> None: + # Punctuation is removed before counting + assert matches_xinliweiyuan("!!心理委员") is True + + def test_five_chars_suffix(self) -> None: + assert matches_xinliweiyuan("心理委员abcde") is True + + def test_six_chars_suffix(self) -> None: + assert matches_xinliweiyuan("心理委员abcdef") is False + + +# --------------------------------------------------------------------------- +# _normalize_message_content +# --------------------------------------------------------------------------- + + +class TestNormalizeMessageContent: + def test_list_of_dicts(self) -> None: + content: list[dict[str, Any]] = [{"type": "text", "data": {"text": "hi"}}] + result = _normalize_message_content(content) + assert result == content + + def test_single_dict(self) -> None: + seg: dict[str, Any] = {"type": "text", "data": {"text": "hi"}} + result = _normalize_message_content(seg) + assert result == [seg] + + def test_string(self) -> None: + result = _normalize_message_content("hello [CQ:face]") + assert len(result) == 2 + assert result[0]["type"] == "text" + assert result[1]["type"] == "face" + + def test_list_with_string_items(self) -> None: + result = _normalize_message_content(["hello"]) + assert result == [{"type": "text", "data": {"text": "hello"}}] + + def test_unsupported_type_returns_empty(self) -> None: + result = _normalize_message_content(12345) + assert result == [] + + +# --------------------------------------------------------------------------- +# _format_forward_node_time +# --------------------------------------------------------------------------- + + +class TestFormatForwardNodeTime: + def test_valid_timestamp(self) -> None: + result = _format_forward_node_time(1700000000) + assert "2023" in result + + def test_millisecond_timestamp(self) -> None: + result = _format_forward_node_time(1700000000000) + assert "2023" in result + + def test_zero_returns_empty(self) -> None: + assert _format_forward_node_time(0) == "" + + def test_none_returns_empty(self) -> None: + assert _format_forward_node_time(None) == "" + + def test_empty_string_returns_empty(self) -> None: + assert _format_forward_node_time("") == "" + + def test_invalid_string_returns_as_is(self) -> None: + assert _format_forward_node_time("not_a_time") == "not_a_time" + + +# --------------------------------------------------------------------------- +# _truncate_forward_text +# --------------------------------------------------------------------------- + + +class TestTruncateForwardText: + def test_short_text_not_truncated(self) -> None: + text = "hello" + assert _truncate_forward_text(text) == text + + def test_long_text_truncated(self) -> None: + text = "a" * (FORWARD_EXPAND_MAX_CHARS + 100) + result = _truncate_forward_text(text) + assert "[合并转发内容过长,已截断]" in result + assert len(result) <= FORWARD_EXPAND_MAX_CHARS + 50 # marker included + + +# --------------------------------------------------------------------------- +# _parse_segment / _parse_at_segment / _parse_media_segment +# --------------------------------------------------------------------------- + + +class TestParseHelpers: + def test_parse_at_segment_with_nickname(self) -> None: + result = _parse_at_segment({"qq": "999", "nickname": "Nick"}, bot_qq=0) + assert result == "[@999(Nick)]" + + def test_parse_at_segment_no_name(self) -> None: + result = _parse_at_segment({"qq": "999"}, bot_qq=0) + assert result == "[@999]" + + def test_parse_media_segment_image(self) -> None: + result = _parse_media_segment("image", {"file": "pic.jpg"}) + assert result == "[图片: pic.jpg]" + + def test_parse_media_segment_unknown(self) -> None: + result = _parse_media_segment("custom_type", {}) + assert result is None + + def test_parse_segment_missing_type(self) -> None: + seg: dict[str, Any] = {"data": {"text": "hello"}} + # type="" → falls through to _parse_media_segment → None + result = _parse_segment(seg) + assert result is None diff --git a/tests/test_webui_management_api.py b/tests/test_webui_management_api.py index 435b710d..93912790 100644 --- a/tests/test_webui_management_api.py +++ b/tests/test_webui_management_api.py @@ -6,7 +6,7 @@ from aiohttp import web -from Undefined.api import app as runtime_api_app +from Undefined.api import _helpers as runtime_api_helpers from Undefined.webui import app as webui_app from Undefined.webui.app import create_app from Undefined.webui.core import SessionStore @@ -350,7 +350,7 @@ def test_webui_cors_only_allows_trusted_origins(monkeypatch: Any) -> None: def test_runtime_api_cors_only_allows_trusted_origins(monkeypatch: Any) -> None: monkeypatch.setattr( - runtime_api_app, + runtime_api_helpers, "load_webui_settings", lambda: SimpleNamespace(url="127.0.0.1", port=8787), ) @@ -359,7 +359,7 @@ def test_runtime_api_cors_only_allows_trusted_origins(monkeypatch: Any) -> None: cast(Any, DummyRequest(headers={"Origin": "tauri://localhost"})), ) trusted_response = web.Response(status=200) - runtime_api_app._apply_cors_headers(trusted_request, trusted_response) + runtime_api_helpers._apply_cors_headers(trusted_request, trusted_response) assert trusted_response.headers.get("Access-Control-Allow-Origin") == ( "tauri://localhost" ) @@ -370,7 +370,7 @@ def test_runtime_api_cors_only_allows_trusted_origins(monkeypatch: Any) -> None: cast(Any, DummyRequest(headers={"Origin": "http://localhost:1420"})), ) loopback_response = web.Response(status=200) - runtime_api_app._apply_cors_headers(loopback_request, loopback_response) + runtime_api_helpers._apply_cors_headers(loopback_request, loopback_response) assert loopback_response.headers.get("Access-Control-Allow-Origin") == ( "http://localhost:1420" ) @@ -380,7 +380,7 @@ def test_runtime_api_cors_only_allows_trusted_origins(monkeypatch: Any) -> None: cast(Any, DummyRequest(headers={"Origin": "https://evil.example"})), ) untrusted_response = web.Response(status=200) - runtime_api_app._apply_cors_headers(untrusted_request, untrusted_response) + runtime_api_helpers._apply_cors_headers(untrusted_request, untrusted_response) assert "Access-Control-Allow-Origin" not in untrusted_response.headers assert "Access-Control-Allow-Credentials" not in untrusted_response.headers diff --git a/tests/test_xml_utils.py b/tests/test_xml_utils.py new file mode 100644 index 00000000..fbb89a0d --- /dev/null +++ b/tests/test_xml_utils.py @@ -0,0 +1,100 @@ +"""Tests for Undefined.utils.xml — XML escaping helpers.""" + +from __future__ import annotations + +from Undefined.utils.xml import escape_xml_attr, escape_xml_text + + +class TestEscapeXmlText: + def test_plain_text(self) -> None: + assert escape_xml_text("hello world") == "hello world" + + def test_ampersand(self) -> None: + assert escape_xml_text("a & b") == "a & b" + + def test_less_than(self) -> None: + assert escape_xml_text("a < b") == "a < b" + + def test_greater_than(self) -> None: + assert escape_xml_text("a > b") == "a > b" + + def test_double_quote(self) -> None: + assert escape_xml_text('say "hello"') == "say "hello"" + + def test_single_quote(self) -> None: + assert escape_xml_text("it's") == "it's" + + def test_all_special_chars(self) -> None: + result = escape_xml_text("""""") + assert "<" in result + assert ">" in result + assert "&" in result + assert """ in result + assert "'" in result + + def test_empty_string(self) -> None: + assert escape_xml_text("") == "" + + def test_unicode(self) -> None: + assert escape_xml_text("こんにちは") == "こんにちは" + + def test_unicode_with_special(self) -> None: + assert escape_xml_text("价格 < 100 & > 50") == "价格 < 100 & > 50" + + def test_nested_quotes(self) -> None: + result = escape_xml_text("""He said "it's fine" """) + assert """ in result + assert "'" in result + + def test_multiline(self) -> None: + text = "line1\nline2\n" + result = escape_xml_text(text) + assert "\n" in result + assert "<tag>" in result + + def test_already_escaped(self) -> None: + result = escape_xml_text("&") + assert result == "&amp;" + + +class TestEscapeXmlAttr: + def test_plain_string(self) -> None: + assert escape_xml_attr("hello") == "hello" + + def test_special_chars(self) -> None: + result = escape_xml_attr('') + assert "<" in result + assert "&" in result + assert """ in result + assert ">" in result + + def test_none_input(self) -> None: + assert escape_xml_attr(None) == "" + + def test_integer_input(self) -> None: + assert escape_xml_attr(42) == "42" + + def test_float_input(self) -> None: + assert escape_xml_attr(3.14) == "3.14" + + def test_bool_input(self) -> None: + assert escape_xml_attr(True) == "True" + assert escape_xml_attr(False) == "False" + + def test_empty_string(self) -> None: + assert escape_xml_attr("") == "" + + def test_object_with_str(self) -> None: + class Obj: + def __str__(self) -> str: + return '' + + result = escape_xml_attr(Obj()) + assert "<script>" in result + assert """ in result + + def test_unicode(self) -> None: + assert escape_xml_attr("日本語") == "日本語" + + def test_zero(self) -> None: + assert escape_xml_attr(0) == "0" diff --git a/uv.lock b/uv.lock index 810f4a08..d1d20530 100644 --- a/uv.lock +++ b/uv.lock @@ -4638,7 +4638,7 @@ wheels = [ [[package]] name = "undefined-bot" -version = "3.3.1" +version = "3.3.2" source = { editable = "." } dependencies = [ { name = "aiofiles" },