[\"'])(?P[^\"']+)(?P=quote)\s*/?>", + re.IGNORECASE, +) _MEDIA_LABELS = { "image": "图片", "file": "文件", @@ -107,10 +115,11 @@ class RenderedRichMessage: delivery_text: str history_text: str attachments: list[dict[str, str]] + pending_file_sends: tuple[AttachmentRecord, ...] = () class AttachmentRenderError(RuntimeError): - """Raised when a ` ` tag cannot be rendered.""" + """Raised when an attachment tag cannot be rendered.""" def _now_iso() -> str: @@ -681,6 +690,25 @@ def _build_uid(self, prefix: str) -> str: if uid not in self._records: return uid + def _find_by_sha256( + self, scope_key: str, sha256: str, kind: str + ) -> AttachmentRecord | None: + """Find an existing record with matching scope, kind, and SHA-256. + + Only returns a record whose *local_path* still exists on disk. + Must be called while ``self._lock`` is held. + """ + for record in self._records.values(): + if ( + record.scope_key == scope_key + and record.sha256 == sha256 + and record.kind == kind + and record.local_path + and Path(record.local_path).is_file() + ): + return record + return None + async def register_bytes( self, scope_key: str, @@ -703,15 +731,18 @@ async def register_bytes( prefix = "pic" if normalized_media_type == "image" else "file" async with self._lock: + digest = await asyncio.to_thread(hashlib.sha256, content) + digest_hex = digest.hexdigest() + + existing = self._find_by_sha256(scope_key, digest_hex, normalized_kind) + if existing is not None: + return existing + uid = self._build_uid(prefix) file_name = f"{uid}{suffix}" cache_path = ensure_dir(self._cache_dir) / file_name + await asyncio.to_thread(cache_path.write_bytes, content) - def _write() -> str: - cache_path.write_bytes(content) - return hashlib.sha256(content).hexdigest() - - digest = await asyncio.to_thread(_write) record = AttachmentRecord( uid=uid, scope_key=scope_key, @@ -722,7 +753,7 @@ def _write() -> str: source_ref=source_ref, local_path=str(cache_path), mime_type=normalized_mime, - sha256=digest, + 
sha256=digest_hex, created_at=_now_iso(), segment_data={ str(k): str(v) @@ -1064,31 +1095,38 @@ async def _collect_from_segments( ) -async def render_message_with_pic_placeholders( +async def render_message_with_attachments( message: str, *, registry: AttachmentRegistry | None, scope_key: str | None, strict: bool, ) -> RenderedRichMessage: - if ( - not message - or registry is None - or not scope_key - or " `` and `` `` tags into delivery/history text. + + * `` `` — backward-compatible, image-only. + * `` `` — unified tag for any media type. + Images (``pic_*``) are inlined as CQ images; files (``file_*``) + are collected into *pending_file_sends* for later dispatch. + """ + has_tags = message and ( + " tag: strictly image-only + if tag_name == "pic" and record.media_type != "image": replacement = f"[图片 uid={uid} 类型错误]" if strict: raise AttachmentRenderError(f"UID 不是图片,不能用于 :{uid}") @@ -1110,31 +1151,19 @@ async def render_message_with_pic_placeholders( history_parts.append(replacement) continue - image_source = record.source_ref - if record.local_path: - image_source = Path(record.local_path).resolve().as_uri() - elif not image_source: - replacement = f"[图片 uid={uid} 缺少文件]" - if strict: - raise AttachmentRenderError(f"图片 UID 缺少可发送的文件:{uid}") - delivery_parts.append(replacement) - history_parts.append(replacement) - continue - - cq_args = [f"file={image_source}"] - for key, value in dict(getattr(record, "segment_data", {}) or {}).items(): - cleaned_key = str(key or "").strip() - cleaned_value = str(value or "").strip() - if not cleaned_key or not cleaned_value or cleaned_key == "file": - continue - cq_args.append( - f"{_escape_cq_component(cleaned_key)}={_escape_cq_component(cleaned_value)}" - ) - delivery_parts.append(f"[CQ:image,{','.join(cq_args)}]") - if record.display_name: - history_parts.append(f"[图片 uid={uid} name={record.display_name}]") + # Route by media type + if record.media_type == "image": + _render_image_tag(record, uid, strict, delivery_parts, 
history_parts) else: - history_parts.append(f"[图片 uid={uid}]") + _render_file_tag( + record, + uid, + strict, + delivery_parts, + history_parts, + pending_files, + ) + attachments.append(record.prompt_ref()) delivery_parts.append(message[last_index:]) @@ -1143,4 +1172,113 @@ async def render_message_with_pic_placeholders( delivery_text="".join(delivery_parts), history_text="".join(history_parts), attachments=attachments, + pending_file_sends=tuple(pending_files), ) + + +def _render_image_tag( + record: AttachmentRecord, + uid: str, + strict: bool, + delivery_parts: list[str], + history_parts: list[str], +) -> None: + """Render an image attachment as an inline CQ:image.""" + image_source = record.source_ref + if record.local_path: + image_source = Path(record.local_path).resolve().as_uri() + elif not image_source: + replacement = f"[图片 uid={uid} 缺少文件]" + if strict: + raise AttachmentRenderError(f"图片 UID 缺少可发送的文件:{uid}") + delivery_parts.append(replacement) + history_parts.append(replacement) + return + + cq_args = [f"file={image_source}"] + for key, value in dict(getattr(record, "segment_data", {}) or {}).items(): + cleaned_key = str(key or "").strip() + cleaned_value = str(value or "").strip() + if not cleaned_key or not cleaned_value or cleaned_key == "file": + continue + cq_args.append( + f"{_escape_cq_component(cleaned_key)}={_escape_cq_component(cleaned_value)}" + ) + delivery_parts.append(f"[CQ:image,{','.join(cq_args)}]") + if record.display_name: + history_parts.append(f"[图片 uid={uid} name={record.display_name}]") + else: + history_parts.append(f"[图片 uid={uid}]") + + +def _render_file_tag( + record: AttachmentRecord, + uid: str, + strict: bool, + delivery_parts: list[str], + history_parts: list[str], + pending_files: list[AttachmentRecord], +) -> None: + """Render a non-image attachment as a pending file send.""" + if not record.local_path or not Path(record.local_path).is_file(): + replacement = f"[文件 uid={uid} 缺少本地文件]" + if strict: + raise 
AttachmentRenderError(f"文件 UID 缺少本地文件,无法发送:{uid}") + delivery_parts.append(replacement) + history_parts.append(replacement) + return + + # Remove from delivery text (file sent separately) + # Keep a readable placeholder in history + name_part = f" name={record.display_name}" if record.display_name else "" + history_parts.append(f"[文件 uid={uid}{name_part}]") + pending_files.append(record) + + +# Backward-compatible alias +render_message_with_pic_placeholders = render_message_with_attachments + + +async def dispatch_pending_file_sends( + rendered: RenderedRichMessage, + *, + sender: Any, + target_type: str, + target_id: int, +) -> None: + """Send pending file attachments collected by *render_message_with_attachments*. + + This is best-effort: each file send failure is logged but does not interrupt + the remaining sends or the caller. + """ + if not rendered.pending_file_sends or sender is None: + return + for record in rendered.pending_file_sends: + if not record.local_path or not Path(record.local_path).is_file(): + logger.warning( + "[文件发送] 跳过:本地文件缺失 uid=%s path=%s", + record.uid, + record.local_path, + ) + continue + try: + if target_type == "group": + await sender.send_group_file( + target_id, + record.local_path, + name=record.display_name or None, + ) + else: + await sender.send_private_file( + target_id, + record.local_path, + name=record.display_name or None, + ) + except Exception: + logger.warning( + "[文件发送] 发送失败(最佳努力) uid=%s target=%s:%s", + record.uid, + target_type, + target_id, + exc_info=True, + ) diff --git a/src/Undefined/handlers.py b/src/Undefined/handlers.py index 651da57..8cfeae9 100644 --- a/src/Undefined/handlers.py +++ b/src/Undefined/handlers.py @@ -72,6 +72,7 @@ class GroupPokeRecord: group_name: str sender_role: str sender_title: str + sender_level: str class MessageHandler: @@ -113,6 +114,7 @@ def __init__( self.security, queue_manager=self.queue_manager, rate_limiter=self.rate_limiter, + history_manager=self.history_manager, ) 
self.ai_coordinator = AICoordinator( config, @@ -144,6 +146,77 @@ def _get_repeat_lock(self, group_id: int) -> asyncio.Lock: self._repeat_locks[group_id] = lock return lock + async def _annotate_meme_descriptions( + self, + attachments: list[dict[str, str]], + scope_key: str, + ) -> list[dict[str, str]]: + """为图片附件添加表情包描述(如果在表情库中找到)。 + + 采用批量查询:收集所有 SHA256 哈希值,一次性查询,然后映射结果。 + 最佳努力:任何失败时返回原始列表。 + """ + if not attachments: + return attachments + + ai_client = getattr(self, "ai", None) + if ai_client is None: + return attachments + + attachment_registry = getattr(ai_client, "attachment_registry", None) + if attachment_registry is None: + return attachments + + meme_service = getattr(ai_client, "_meme_service", None) + if meme_service is None or not getattr(meme_service, "enabled", False): + return attachments + + meme_store = getattr(meme_service, "_store", None) + if meme_store is None: + return attachments + + try: + # 1. 从图片附件收集唯一的 SHA256 哈希值 + uid_to_hash: dict[str, str] = {} + for att in attachments: + uid = att.get("uid", "") + if not uid.startswith("pic_"): + continue + record = attachment_registry.resolve(uid, scope_key) + if record and record.sha256: + uid_to_hash[uid] = record.sha256 + + if not uid_to_hash: + return attachments + + # 2. 批量查询:去重哈希值 + unique_hashes = set(uid_to_hash.values()) + hash_to_desc: dict[str, str] = {} + for h in unique_hashes: + meme = await meme_store.find_by_sha256(h) + if meme and meme.description: + hash_to_desc[h] = meme.description + + if not hash_to_desc: + return attachments + + # 3. 
构建带注释的新列表 + result: list[dict[str, str]] = [] + for att in attachments: + uid = att.get("uid", "") + sha = uid_to_hash.get(uid, "") + desc = hash_to_desc.get(sha, "") + if desc: + new_att = dict(att) + new_att["description"] = f"[表情包] {desc}" + result.append(new_att) + else: + result.append(att) + return result + except Exception: + logger.warning("表情包自动匹配失败,跳过", exc_info=True) + return attachments + async def _collect_message_attachments( self, message_content: list[dict[str, Any]], @@ -176,7 +249,10 @@ async def _collect_message_attachments( if onebot else None, ) - return result.attachments + attachments = result.attachments + # 为图片附件添加表情包描述 + attachments = await self._annotate_meme_descriptions(attachments, scope_key) + return attachments def _schedule_meme_ingest( self, @@ -384,6 +460,7 @@ async def handle_message(self, event: dict[str, Any]) -> None: group_name=group_poke.group_name, sender_role=group_poke.sender_role, sender_title=group_poke.sender_title, + sender_level=group_poke.sender_level, ) return @@ -570,6 +647,7 @@ async def handle_message(self, event: dict[str, Any]) -> None: sender_nickname: str = group_sender.get("nickname", "") sender_role: str = group_sender.get("role", "member") sender_title: str = group_sender.get("title", "") + sender_level: str = str(group_sender.get("level", "")).strip() # 提取文本内容 text = extract_text(message_content, self.config.bot_qq) @@ -623,6 +701,7 @@ async def handle_message(self, event: dict[str, Any]) -> None: group_name=group_name, role=sender_role, title=sender_title, + level=sender_level, message_id=trigger_message_id, attachments=group_attachments, ) @@ -775,6 +854,7 @@ async def handle_message(self, event: dict[str, Any]) -> None: group_name=group_name, sender_role=sender_role, sender_title=sender_title, + sender_level=sender_level, trigger_message_id=trigger_message_id, ) @@ -837,11 +917,13 @@ async def _record_group_poke_history( sender_nickname = "" sender_role = "member" sender_title = "" + sender_level = "" 
if isinstance(sender, dict): sender_card = str(sender.get("card", "")).strip() sender_nickname = str(sender.get("nickname", "")).strip() sender_role = str(sender.get("role", "member")).strip() or "member" sender_title = str(sender.get("title", "")).strip() + sender_level = str(sender.get("level", "")).strip() if not sender_card and not sender_nickname: try: @@ -855,6 +937,7 @@ async def _record_group_poke_history( str(member_info.get("role", "member")).strip() or "member" ) sender_title = str(member_info.get("title", "")).strip() + sender_level = str(member_info.get("level", "")).strip() except Exception as exc: logger.warning( "[通知] 获取拍一拍群成员信息失败: group=%s user=%s err=%s", @@ -898,6 +981,7 @@ async def _record_group_poke_history( group_name=normalized_group_name, role=sender_role, title=sender_title, + level=sender_level, ) except Exception as exc: logger.warning( @@ -912,6 +996,7 @@ async def _record_group_poke_history( group_name=normalized_group_name, sender_role=sender_role, sender_title=sender_title, + sender_level=sender_level, ) async def _extract_bilibili_ids( diff --git a/src/Undefined/services/ai_coordinator.py b/src/Undefined/services/ai_coordinator.py index 5319a7d..321e555 100644 --- a/src/Undefined/services/ai_coordinator.py +++ b/src/Undefined/services/ai_coordinator.py @@ -6,6 +6,7 @@ from Undefined.attachments import ( attachment_refs_to_xml, build_attachment_scope, + dispatch_pending_file_sends, render_message_with_pic_placeholders, ) from Undefined.config import Config @@ -72,6 +73,7 @@ async def handle_auto_reply( group_name: str = "未知群聊", sender_role: str = "member", sender_title: str = "", + sender_level: str = "", trigger_message_id: int | None = None, ) -> None: """群聊自动回复入口:根据消息内容、命中情况和安全检测决定是否回复 @@ -130,6 +132,7 @@ async def handle_auto_reply( text, attachments=attachments, message_id=trigger_message_id, + level=sender_level, ) logger.debug( "[自动回复] full_question_len=%s group=%s sender=%s", @@ -472,6 +475,12 @@ async def send_private_cb( 
rendered.delivery_text, history_message=rendered.history_text, ) + await dispatch_pending_file_sends( + rendered, + sender=self.sender, + target_type="private", + target_id=user_id, + ) except Exception: logger.exception("私聊回复执行出错") raise @@ -707,6 +716,7 @@ def _build_prompt( text: str, attachments: list[dict[str, str]] | None = None, message_id: int | None = None, + level: str = "", ) -> str: """构建最终发送给 AI 的结构化 XML 消息 Prompt @@ -724,10 +734,11 @@ def _build_prompt( message_id_attr = "" if message_id is not None: message_id_attr = f' message_id="{escape_xml_attr(message_id)}"' + level_attr = f' level="{escape_xml_attr(level)}"' if level else "" attachment_xml = ( f"\n{attachment_refs_to_xml(attachments)}" if attachments else "" ) - return f"""{prefix} + return f"""{prefix} diff --git a/src/Undefined/services/command.py b/src/Undefined/services/command.py index 8abfcf5..075fd95 100644 --- a/src/Undefined/services/command.py +++ b/src/Undefined/services/command.py @@ -94,6 +94,7 @@ def __init__( security: SecurityService, queue_manager: Any = None, rate_limiter: Any = None, + history_manager: Any = None, ) -> None: """初始化命令分发器 @@ -106,6 +107,7 @@ def __init__( security: 安全审计与限流服务 queue_manager: AI 请求队列管理器 rate_limiter: 速率限制器 + history_manager: 消息历史记录管理器 """ self.config = config self.sender = sender @@ -115,6 +117,7 @@ def __init__( self.security = security self.queue_manager = queue_manager self.rate_limiter = rate_limiter + self.history_manager = history_manager self.naga_store: Any = None self._token_usage_storage = TokenUsageStorage() # 存储 stats 分析结果,用于队列回调 @@ -1078,6 +1081,8 @@ async def _send_target_message(message: str) -> None: scope=scope, user_id=user_id, is_webui_session=is_webui_session, + cognitive_service=getattr(self.ai, "_cognitive_service", None), + history_manager=self.history_manager, ) try: diff --git a/src/Undefined/services/commands/context.py b/src/Undefined/services/commands/context.py index 507af5e..56732b9 100644 --- 
a/src/Undefined/services/commands/context.py +++ b/src/Undefined/services/commands/context.py @@ -32,3 +32,5 @@ class CommandContext: scope: str = "group" user_id: int | None = None is_webui_session: bool = False + cognitive_service: Any = None + history_manager: Any = None diff --git a/src/Undefined/skills/agents/summary_agent/__init__.py b/src/Undefined/skills/agents/summary_agent/__init__.py new file mode 100644 index 0000000..13726fa --- /dev/null +++ b/src/Undefined/skills/agents/summary_agent/__init__.py @@ -0,0 +1 @@ +# summary_agent diff --git a/src/Undefined/skills/agents/summary_agent/config.json b/src/Undefined/skills/agents/summary_agent/config.json new file mode 100644 index 0000000..1536027 --- /dev/null +++ b/src/Undefined/skills/agents/summary_agent/config.json @@ -0,0 +1,17 @@ +{ + "type": "function", + "function": { + "name": "summary_agent", + "description": "消息总结助手,拉取指定范围的聊天消息并进行智能总结。支持按条数或时间范围筛选。", + "parameters": { + "type": "object", + "properties": { + "prompt": { + "type": "string", + "description": "用户的总结需求,例如:'总结最近50条消息'、'总结过去1小时的聊天'、'总结今天的技术讨论'" + } + }, + "required": ["prompt"] + } + } +} diff --git a/src/Undefined/skills/agents/summary_agent/handler.py b/src/Undefined/skills/agents/summary_agent/handler.py new file mode 100644 index 0000000..757f8a2 --- /dev/null +++ b/src/Undefined/skills/agents/summary_agent/handler.py @@ -0,0 +1,25 @@ +from __future__ import annotations + +import logging +from pathlib import Path +from typing import Any + +from Undefined.skills.agents.runner import run_agent_with_tools + +logger = logging.getLogger(__name__) + + +async def execute(args: dict[str, Any], context: dict[str, Any]) -> str: + """执行 summary_agent。""" + user_prompt = str(args.get("prompt", "")).strip() + return await run_agent_with_tools( + agent_name="summary_agent", + user_content=user_prompt, + empty_user_content_message="请提供您的总结需求", + default_prompt="你是一个消息总结助手。使用 fetch_messages 工具获取聊天记录,然后进行智能总结。", + context=context, + 
agent_dir=Path(__file__).parent, + logger=logger, + max_iterations=10, + tool_error_prefix="错误", + ) diff --git a/src/Undefined/skills/agents/summary_agent/intro.md b/src/Undefined/skills/agents/summary_agent/intro.md new file mode 100644 index 0000000..a7343db --- /dev/null +++ b/src/Undefined/skills/agents/summary_agent/intro.md @@ -0,0 +1,30 @@ +# summary_agent - 消息总结助手 + +## 定位 +专业的聊天消息总结智能体,从海量聊天记录中提取关键信息并生成结构化报告。 + +## 擅长 +- ✅ 按**条数**拉取消息 (默认50条,最大500条) +- ✅ 按**时间范围**拉取消息 (支持1h/6h/1d/7d等格式) +- ✅ **话题提取**: 识别主要讨论主题 +- ✅ **要点归纳**: 总结重要决策、结论、共识 +- ✅ **参与者分析**: 识别活跃用户及其贡献 +- ✅ **资源收集**: 提取链接、代码片段等 +- ✅ 输出**结构化、清晰**的总结报告 + +## 边界 +- ❌ **不做实时监控**: 仅分析历史消息,不监听新消息 +- ❌ **不做情感分析**: 专注事实总结,不评价情绪 +- ❌ **不做预测**: 不推测未来走向 + +## 输入偏好 +- 明确的**条数要求**: "总结最近100条消息" +- 明确的**时间范围**: "总结过去6小时的聊天"、"总结今天的讨论" +- 特定的**总结目标**: "总结技术讨论"、"提取所有链接" +- 如果用户仅说"总结一下",默认总结最近50条消息 + +## 适用场景 +- 快速了解错过的聊天内容 +- 回顾会议或讨论要点 +- 整理特定时间段的聊天记录 +- 提取重要链接和资源 diff --git a/src/Undefined/skills/agents/summary_agent/prompt.md b/src/Undefined/skills/agents/summary_agent/prompt.md new file mode 100644 index 0000000..3b873f9 --- /dev/null +++ b/src/Undefined/skills/agents/summary_agent/prompt.md @@ -0,0 +1,79 @@ +# 消息总结助手 + +你是一个专业的聊天消息总结助手,擅长从大量聊天记录中提取关键信息并生成结构化的总结报告。 + +## 核心能力 + +- 使用 `fetch_messages` 工具拉取指定范围的聊天消息 +- 支持按**消息条数**(如最近50条)或**时间范围**(如过去1小时、今天)筛选 +- 提取主题、关键参与者、重要决策、链接资源等 +- 生成清晰、结构化的总结报告 + +## 工作流程 + +1. **理解需求**: 分析用户的总结需求,确定查询参数 + - 如果用户指定了条数(如"最近50条"),使用 `count` 参数 + - 如果用户指定了时间范围(如"过去1小时"、"今天"),使用 `time_range` 参数 + - 如果用户未明确指定,**默认使用最近50条消息** + +2. **拉取消息**: 调用 `fetch_messages` 工具获取聊天记录 + - `count`: 消息条数,默认50,最大500 + - `time_range`: 时间范围,支持 "1h"(1小时)、"6h"(6小时)、"1d"(1天)、"7d"(7天) + +3. **分析总结**: 对获取的消息进行智能分析 + - 识别主要讨论话题 + - 提取关键参与者及其贡献 + - 总结重要决策、结论或共识 + - 收集提到的链接、资源、代码片段 + - 标注特别重要或需要关注的信息 + +4. 
**生成报告**: 以清晰的结构化格式输出总结 + - 使用**要点列表**而非长段落 + - 保持**简洁但全面** + - 突出**关键信息** + +## 输出格式建议 + +``` +📊 消息总结 (时间范围/条数) + +🔍 主要话题: +- 话题1: 简要描述 +- 话题2: 简要描述 + +👥 活跃参与者: +- 用户A: 主要贡献 +- 用户B: 主要贡献 + +💡 重要要点: +- 要点1 +- 要点2 + +🔗 相关链接/资源: +- 链接1 +- 链接2 + +⚠️ 需要关注: +- 重要事项或待办 +``` + +## 注意事项 + +- 保持**客观中立**,不加入主观评价 +- 如果消息量很大,**优先突出重点**,可省略琐碎细节 +- 如果讨论涉及敏感话题,**谨慎措辞** +- 如果消息为空或无有效内容,**明确说明** + +## 示例场景 + +**用户**: "总结最近100条消息" +→ 调用 `fetch_messages(count=100)` → 分析并总结 + +**用户**: "总结过去6小时的聊天" +→ 调用 `fetch_messages(time_range="6h")` → 分析并总结 + +**用户**: "总结今天大家讨论了什么技术问题" +→ 调用 `fetch_messages(time_range="1d")` → 重点提取技术相关话题 + +**用户**: "总结一下" +→ 调用 `fetch_messages(count=50)` (使用默认值) → 总结最近消息 diff --git a/src/Undefined/skills/agents/summary_agent/tools/__init__.py b/src/Undefined/skills/agents/summary_agent/tools/__init__.py new file mode 100644 index 0000000..fb43c21 --- /dev/null +++ b/src/Undefined/skills/agents/summary_agent/tools/__init__.py @@ -0,0 +1 @@ +# summary_agent tools diff --git a/src/Undefined/skills/agents/summary_agent/tools/fetch_messages/config.json b/src/Undefined/skills/agents/summary_agent/tools/fetch_messages/config.json new file mode 100644 index 0000000..d64a015 --- /dev/null +++ b/src/Undefined/skills/agents/summary_agent/tools/fetch_messages/config.json @@ -0,0 +1,20 @@ +{ + "type": "function", + "function": { + "name": "fetch_messages", + "description": "从当前会话拉取聊天消息。支持按条数或时间范围筛选。", + "parameters": { + "type": "object", + "properties": { + "count": { + "type": "integer", + "description": "要获取的消息条数,默认50,最大500。" + }, + "time_range": { + "type": "string", + "description": "时间范围,如 '1h'、'6h'、'1d'、'7d'。与 count 互斥,优先使用 time_range。" + } + } + } + } +} diff --git a/src/Undefined/skills/agents/summary_agent/tools/fetch_messages/handler.py b/src/Undefined/skills/agents/summary_agent/tools/fetch_messages/handler.py new file mode 100644 index 0000000..7ea5534 --- /dev/null +++ b/src/Undefined/skills/agents/summary_agent/tools/fetch_messages/handler.py @@ -0,0 +1,108 @@ 
+from __future__ import annotations + +import logging +import re +from datetime import datetime, timedelta +from typing import Any + +logger = logging.getLogger(__name__) + +_TIME_RANGE_PATTERN = re.compile(r"^(\d+)([hHdDwW])$") +_TIME_UNIT_SECONDS = {"h": 3600, "d": 86400, "w": 604800} +_MAX_COUNT = 500 +_DEFAULT_COUNT = 50 +_MAX_FETCH_FOR_TIME_FILTER = 2000 + + +def _parse_time_range(value: str) -> int | None: + """Parse time range string like '1h', '6h', '1d', '7d' into seconds.""" + match = _TIME_RANGE_PATTERN.match(value.strip()) + if not match: + return None + amount = int(match.group(1)) + unit = match.group(2).lower() + return amount * _TIME_UNIT_SECONDS.get(unit, 3600) + + +def _filter_by_time( + messages: list[dict[str, Any]], seconds: int +) -> list[dict[str, Any]]: + """Filter messages to only include those within the given time range from now.""" + cutoff = datetime.now() - timedelta(seconds=seconds) + result = [] + for msg in messages: + ts_str = msg.get("timestamp", "") + if not ts_str: + continue + try: + ts = datetime.strptime(ts_str, "%Y-%m-%d %H:%M:%S") + except ValueError: + continue + if ts >= cutoff: + result.append(msg) + return result + + +def _format_messages(messages: list[dict[str, Any]]) -> str: + """Format messages into readable text for the summary agent.""" + lines = [] + for msg in messages: + ts = msg.get("timestamp", "") + name = msg.get("display_name", "未知用户") + text = msg.get("message", "") + role = msg.get("role", "") + title = msg.get("title", "") + + prefix = f"[{ts}] " + if title: + prefix += f"[{title}] " + if role and role not in ("member", ""): + prefix += f"({role}) " + prefix += f"{name}: " + lines.append(f"{prefix}{text}") + return "\n".join(lines) + + +async def execute(args: dict[str, Any], context: dict[str, Any]) -> str: + """拉取当前会话的聊天消息。""" + history_manager = context.get("history_manager") + if not history_manager: + return "历史记录管理器未配置" + + group_id = context.get("group_id", 0) or 0 + user_id = 
context.get("user_id", 0) or 0 + + if int(group_id) > 0: + chat_type = "group" + chat_id = str(group_id) + else: + chat_type = "private" + chat_id = str(user_id) + + time_range_str = str(args.get("time_range", "")).strip() + raw_count = args.get("count", _DEFAULT_COUNT) + try: + count = min(max(int(raw_count), 1), _MAX_COUNT) + except (TypeError, ValueError): + count = _DEFAULT_COUNT + + if time_range_str: + seconds = _parse_time_range(time_range_str) + if seconds is None: + return f"无法解析时间范围: {time_range_str}(支持格式: 1h, 6h, 1d, 7d)" + fetch_count = max(count * 2, _MAX_FETCH_FOR_TIME_FILTER) + messages = history_manager.get_recent(chat_id, chat_type, 0, fetch_count) + if messages: + messages = _filter_by_time(messages, seconds) + else: + messages = history_manager.get_recent(chat_id, chat_type, 0, count) + + if not messages: + return "当前会话暂无消息记录" + + formatted = _format_messages(messages) + total = len(messages) + header = f"共获取 {total} 条消息" + if time_range_str: + header += f"(时间范围: {time_range_str})" + return f"{header}\n\n{formatted}" diff --git a/src/Undefined/skills/commands/profile/config.json b/src/Undefined/skills/commands/profile/config.json new file mode 100644 index 0000000..6b47c63 --- /dev/null +++ b/src/Undefined/skills/commands/profile/config.json @@ -0,0 +1,16 @@ +{ + "name": "profile", + "description": "查看用户或群聊侧写", + "usage": "/profile [group]", + "example": "/profile", + "permission": "public", + "rate_limit": { + "user": 60, + "admin": 10, + "superadmin": 0 + }, + "show_in_help": true, + "order": 25, + "allow_in_private": true, + "aliases": ["me", "p"] +} diff --git a/src/Undefined/skills/commands/profile/handler.py b/src/Undefined/skills/commands/profile/handler.py new file mode 100644 index 0000000..aba240a --- /dev/null +++ b/src/Undefined/skills/commands/profile/handler.py @@ -0,0 +1,54 @@ +from __future__ import annotations + +from Undefined.services.commands.context import CommandContext + +_MAX_PROFILE_LENGTH = 3000 + + +def 
_is_private(context: CommandContext) -> bool: + return context.scope == "private" + + +def _truncate(text: str, limit: int = _MAX_PROFILE_LENGTH) -> str: + if len(text) <= limit: + return text + return text[:limit].rstrip() + "\n\n[侧写过长,已截断]" + + +async def _send(context: CommandContext, text: str) -> None: + """Send message to appropriate channel.""" + if _is_private(context): + user_id = int(context.user_id or context.sender_id) + await context.sender.send_private_message(user_id, text) + else: + await context.sender.send_group_message(context.group_id, text) + + +async def execute(args: list[str], context: CommandContext) -> None: + """处理 /profile 命令。""" + cognitive_service = context.cognitive_service + if cognitive_service is None: + await _send(context, "❌ 侧写服务未启用") + return + + # Parse subcommand + sub = args[0].lower().strip() if args else "" + + if sub == "group": + if _is_private(context): + await _send(context, "❌ 私聊中不支持查看群聊侧写") + return + entity_type = "groups" + entity_id = str(context.group_id) + empty_hint = "暂无群聊侧写数据" + else: + entity_type = "users" + entity_id = str(context.sender_id) + empty_hint = "暂无侧写数据" + + profile = await cognitive_service.get_profile(entity_type, entity_id) + if not profile: + await _send(context, f"📭 {empty_hint}") + return + + await _send(context, _truncate(profile)) diff --git a/src/Undefined/skills/commands/summary/config.json b/src/Undefined/skills/commands/summary/config.json new file mode 100644 index 0000000..9135eda --- /dev/null +++ b/src/Undefined/skills/commands/summary/config.json @@ -0,0 +1,16 @@ +{ + "name": "summary", + "description": "总结聊天消息", + "usage": "/summary [条数|时间范围] [自定义描述]", + "example": "/summary 50", + "permission": "public", + "rate_limit": { + "user": 120, + "admin": 30, + "superadmin": 0 + }, + "show_in_help": true, + "order": 26, + "allow_in_private": true, + "aliases": ["sum"] +} diff --git a/src/Undefined/skills/commands/summary/handler.py b/src/Undefined/skills/commands/summary/handler.py 
new file mode 100644 index 0000000..a35dc4d --- /dev/null +++ b/src/Undefined/skills/commands/summary/handler.py @@ -0,0 +1,103 @@ +from __future__ import annotations + +import logging +import re +from typing import Any + +from Undefined.services.commands.context import CommandContext + +logger = logging.getLogger(__name__) + +_TIME_RANGE_RE = re.compile(r"^\d+[hHdDwW]$") +_DEFAULT_COUNT = 50 + + +def _parse_args(args: list[str]) -> tuple[int | None, str | None, str]: + """Parse command arguments into (count, time_range, custom_prompt). + + Returns: + Tuple of (count, time_range, custom_prompt). + count and time_range are mutually exclusive; at most one is non-None. + """ + if not args: + return _DEFAULT_COUNT, None, "" + + first = args[0] + rest = " ".join(args[1:]).strip() + + if first.isdigit(): + count = max(1, min(int(first), 500)) + return count, None, rest + + if _TIME_RANGE_RE.match(first): + return None, first, rest + + # First arg is not a number or time range — treat everything as prompt + return _DEFAULT_COUNT, None, " ".join(args).strip() + + +def _build_prompt(count: int | None, time_range: str | None, custom_prompt: str) -> str: + """Build the natural language prompt for summary_agent.""" + parts: list[str] = ["请总结"] + if time_range: + parts.append(f"过去 {time_range} 内的聊天消息") + elif count: + parts.append(f"最近 {count} 条聊天消息") + else: + parts.append(f"最近 {_DEFAULT_COUNT} 条聊天消息") + + if custom_prompt: + parts.append(f",重点关注:{custom_prompt}") + + return "".join(parts) + + +def _is_private(context: CommandContext) -> bool: + return context.scope == "private" + + +async def _send(context: CommandContext, text: str) -> None: + if _is_private(context): + user_id = int(context.user_id or context.sender_id) + await context.sender.send_private_message(user_id, text) + else: + await context.sender.send_group_message(context.group_id, text) + + +async def execute(args: list[str], context: CommandContext) -> None: + """处理 /summary 命令。""" + if 
context.history_manager is None: + await _send(context, "❌ 历史记录管理器未配置") + return + + count, time_range, custom_prompt = _parse_args(args) + prompt = _build_prompt(count, time_range, custom_prompt) + + # Build agent context + agent_context: dict[str, Any] = { + "ai_client": context.ai, + "history_manager": context.history_manager, + "group_id": context.group_id, + "user_id": int(context.user_id or context.sender_id), + "sender_id": context.sender_id, + "request_type": "group" if int(context.group_id) > 0 else "private", + "runtime_config": getattr(context.ai, "runtime_config", None), + "queue_lane": None, + } + + await _send(context, "📝 正在总结消息,请稍候...") + + try: + from Undefined.skills.agents.summary_agent.handler import execute as run_summary + + result = await run_summary({"prompt": prompt}, agent_context) + except Exception: + logger.exception("[/summary] 执行总结失败") + await _send(context, "❌ 消息总结失败,请稍后重试") + return + + if not result or not result.strip(): + await _send(context, "📭 未能生成总结内容") + return + + await _send(context, result) diff --git a/src/Undefined/skills/toolsets/messages/get_messages_by_time/handler.py b/src/Undefined/skills/toolsets/messages/get_messages_by_time/handler.py index 05f4ba6..12ff836 100644 --- a/src/Undefined/skills/toolsets/messages/get_messages_by_time/handler.py +++ b/src/Undefined/skills/toolsets/messages/get_messages_by_time/handler.py @@ -178,6 +178,9 @@ async def execute(args: Dict[str, Any], context: Dict[str, Any]) -> str: timestamp_str = msg.get("timestamp", "") text = msg.get("message", "") message_id = msg.get("message_id") + role = msg.get("role", "") + title = msg.get("title", "") + level = msg.get("level", "") if msg_type_val == "group": # 确保群名以"群"结尾 @@ -189,8 +192,17 @@ async def execute(args: Dict[str, Any], context: Dict[str, Any]) -> str: if message_id is not None: msg_id_attr = f' message_id="{message_id}"' + extra_attrs = "" + if msg_type_val == "group": + if role: + extra_attrs += f' role="{role}"' + if title: + 
extra_attrs += f' title="{title}"' + if level: + extra_attrs += f' level="{level}"' + # 格式:XML 标准化 - formatted.append(f""" {safe_text} {attachment_xml}+ formatted.append(f""" """) diff --git a/src/Undefined/skills/toolsets/messages/get_recent_messages/handler.py b/src/Undefined/skills/toolsets/messages/get_recent_messages/handler.py index fda39ef..54874bb 100644 --- a/src/Undefined/skills/toolsets/messages/get_recent_messages/handler.py +++ b/src/Undefined/skills/toolsets/messages/get_recent_messages/handler.py @@ -105,6 +105,9 @@ def _format_message_xml(msg: dict[str, Any]) -> str: timestamp = msg.get("timestamp", "") text = msg.get("message", "") message_id = msg.get("message_id") + role = msg.get("role", "") + title = msg.get("title", "") + level = msg.get("level", "") location = _format_message_location(msg_type_val, chat_name) @@ -112,7 +115,16 @@ def _format_message_xml(msg: dict[str, Any]) -> str: if message_id is not None: msg_id_attr = f' message_id="{message_id}"' - return f""" {text} + extra_attrs = "" + if msg_type_val == "group": + if role: + extra_attrs += f' role="{role}"' + if title: + extra_attrs += f' title="{title}"' + if level: + extra_attrs += f' level="{level}"' + + return f""" """ diff --git a/src/Undefined/skills/toolsets/messages/send_message/handler.py b/src/Undefined/skills/toolsets/messages/send_message/handler.py index 9d21aab..f5e8b13 100644 --- a/src/Undefined/skills/toolsets/messages/send_message/handler.py +++ b/src/Undefined/skills/toolsets/messages/send_message/handler.py @@ -2,6 +2,7 @@ import logging from Undefined.attachments import ( + dispatch_pending_file_sends, render_message_with_pic_placeholders, scope_from_context, ) @@ -168,6 +169,12 @@ async def execute(args: Dict[str, Any], context: Dict[str, Any]) -> str: **send_kwargs, ) context["message_sent_this_turn"] = True + await dispatch_pending_file_sends( + rendered, + sender=sender, + target_type=target_type, + target_id=target_id, + ) return 
_format_send_success(sent_message_id) except Exception as e: logger.exception( diff --git a/src/Undefined/skills/toolsets/messages/send_private_message/handler.py b/src/Undefined/skills/toolsets/messages/send_private_message/handler.py index 727cb2d..f92cab8 100644 --- a/src/Undefined/skills/toolsets/messages/send_private_message/handler.py +++ b/src/Undefined/skills/toolsets/messages/send_private_message/handler.py @@ -2,6 +2,7 @@ import logging from Undefined.attachments import ( + dispatch_pending_file_sends, render_message_with_pic_placeholders, scope_from_context, ) @@ -115,6 +116,12 @@ async def execute(args: Dict[str, Any], context: Dict[str, Any]) -> str: **send_kwargs, ) context["message_sent_this_turn"] = True + await dispatch_pending_file_sends( + rendered, + sender=sender, + target_type="private", + target_id=user_id, + ) return _format_send_success(user_id, sent_message_id) except Exception as e: logger.exception( diff --git a/src/Undefined/skills/toolsets/render/render_latex/config.json b/src/Undefined/skills/toolsets/render/render_latex/config.json index a351083..16b1f99 100644 --- a/src/Undefined/skills/toolsets/render/render_latex/config.json +++ b/src/Undefined/skills/toolsets/render/render_latex/config.json @@ -2,27 +2,18 @@ "type": "function", "function": { "name": "render_latex", - "description": "将 LaTeX 公式或文档渲染为图片,使用系统外部 LaTeX(需预先安装 TeX Live / MiKTeX)。默认返回可嵌入回复的图片 UID(embed),也可直接发送到指定目标(send)。", + "description": "将 LaTeX 数学公式渲染为图片或 PDF 文档,使用 MathJax(不依赖系统 TeX 安装)。支持 LaTeX 数学子集(amsmath、equation、align、matrix 等),但不支持自定义 TeX 包。返回可嵌入回复的附件 UID。", "parameters": { "type": "object", "properties": { "content": { "type": "string", - "description": "要渲染的 LaTeX 内容。支持 $...$、$$...$$、\\[...\\]、\\(...\\) 及完整环境(\\begin{align}...\\end{align} 等);\\begin{document}...\\end{document} 外层包装会自动去掉。" + "description": "要渲染的 LaTeX 数学内容。支持 $...$(行内)、$$...$$(块级)、\\[...\\]、\\(...\\) 及标准数学环境(\\begin{align}、\\begin{equation}、\\begin{matrix} 等)。如果不包含分隔符,会自动用 \\[ ... 
\\] 包装。\\begin{document}...\\end{document} 外层包装会自动去掉。" }, - "delivery": { + "output_format": { "type": "string", - "description": "图片交付方式:embed 返回可插入回复的图片 UID;send 立即发送到目标", - "enum": ["embed", "send"] - }, - "target_id": { - "type": "integer", - "description": "目标 ID(群号或用户 QQ 号,仅 delivery=send 时需要,不提供则从当前会话推断)" - }, - "message_type": { - "type": "string", - "description": "消息类型(仅 delivery=send 时需要,不提供则从当前会话推断)", - "enum": ["group", "private"] + "description": "输出格式:\"png\"(默认,图片)或 \"pdf\"(PDF 文档)", + "enum": ["png", "pdf"] } }, "required": ["content"] diff --git a/src/Undefined/skills/toolsets/render/render_latex/handler.py b/src/Undefined/skills/toolsets/render/render_latex/handler.py index 8ac6e06..a449cd2 100644 --- a/src/Undefined/skills/toolsets/render/render_latex/handler.py +++ b/src/Undefined/skills/toolsets/render/render_latex/handler.py @@ -1,15 +1,8 @@ from __future__ import annotations -from pathlib import Path +import logging import re from typing import Any, Dict -import logging -import uuid - -import matplotlib - -matplotlib.use("Agg") -import matplotlib.pyplot as plt from Undefined.attachments import scope_from_context @@ -20,9 +13,15 @@ re.DOTALL, ) +# MathJax 数学分隔符模式 +_MATH_DELIMITER_PATTERN = re.compile( + r"(\$\$|\\\[|\\\(|\\begin\{)", + re.MULTILINE, +) + def _strip_document_wrappers(content: str) -> str: - """去掉 \\begin{document}...\\end{document} 外层包装;matplotlib 会自行构造文档。""" + """去掉 \\begin{document}...\\end{document} 外层包装。""" text = content.strip() match = _DOCUMENT_PATTERN.fullmatch(text) if match is None: @@ -30,136 +29,175 @@ def _strip_document_wrappers(content: str) -> str: return match.group("body").strip() -def _render_latex_image(filepath: Path, content: str) -> None: - text = _strip_document_wrappers(content) - fig = plt.figure(figsize=(6, 2.5)) +def _has_math_delimiters(content: str) -> bool: + """检查内容是否已包含数学分隔符。""" + return bool(_MATH_DELIMITER_PATTERN.search(content)) + + +def _prepare_content(raw_content: str) -> str: + """ + 准备 
LaTeX 内容: + 1. 去掉 document 包装 + 2. 处理字面量 \\n(LLM 输出常见问题) + 3. 如果没有数学分隔符,自动用 \\[ ... \\] 包装 + """ + content = _strip_document_wrappers(raw_content) + # 替换字面量 \\n 为真实换行符 + content = content.replace("\\n", "\n") + + if not _has_math_delimiters(content): + # 没有分隔符,自动包装为块级数学环境 + content = f"\\[\n{content}\n\\]" + + return content + + +def _build_html(latex_content: str) -> str: + """构建包含 MathJax 的 HTML 页面。""" + # HTML 转义(防止内容中的 < > & 破坏结构) + import html + + escaped_content = html.escape(latex_content) + + return f""" + + + + + + + + {text} +{escaped_content} ++ +""" + + +async def _render_latex_to_bytes(content: str, output_format: str) -> tuple[bytes, str]: + """ + 使用 MathJax + Playwright 渲染 LaTeX 内容。 + + 返回: (渲染后的字节流, MIME 类型) + """ try: - fig.patch.set_facecolor("white") - fig.text( - 0.5, - 0.5, - text, - fontsize=20, - verticalalignment="center", - horizontalalignment="center", - usetex=True, - wrap=True, + from playwright.async_api import ( + async_playwright, + TimeoutError as PwTimeoutError, ) - fig.savefig(filepath, dpi=200, bbox_inches="tight", pad_inches=0.25) - finally: - plt.close(fig) - - -def _resolve_send_target( - target_id: Any, - message_type: Any, - context: Dict[str, Any], -) -> tuple[int | str | None, str | None, str | None]: - """从参数或 context 推断发送目标。""" - if target_id is not None and message_type is not None: - return target_id, message_type, None - request_type = str(context.get("request_type", "") or "").strip().lower() - if request_type == "group": - gid = context.get("group_id") - if gid is not None: - return gid, "group", None - if request_type == "private": - uid = context.get("user_id") - if uid is not None: - return uid, "private", None - return None, None, "渲染成功,但缺少发送目标参数" + except ImportError: + raise ImportError( + "请运行 `uv run playwright install` 安装浏览器运行时" + ) from None + + html_content = _build_html(content) + + async with async_playwright() as p: + browser = await p.chromium.launch(headless=True) + try: + page = await 
browser.new_page() + await page.set_content(html_content) + + # 等待 MathJax 完成排版 + try: + await page.wait_for_function( + "() => window.MathJax?.startup?.promise?.then(() => true) ?? false", + timeout=15000, + ) + except PwTimeoutError: + logger.warning("MathJax 排版超时,内容可能过于复杂或网络不可达") + raise RuntimeError( + "LaTeX 内容可能过于复杂或网络不可达(MathJax 加载超时)" + ) from None + + if output_format == "pdf": + # 获取容器尺寸 + container = await page.query_selector("#math-container") + if container is None: + raise RuntimeError("无法定位数学容器元素") + + bbox = await container.bounding_box() + if bbox is None: + raise RuntimeError("无法获取数学容器的边界框") + + # PDF 输出,设置合适的页面尺寸 + pdf_bytes = await page.pdf( + width=f"{bbox['width'] + 40}px", + height=f"{bbox['height'] + 40}px", + print_background=True, + ) + return pdf_bytes, "application/pdf" + else: + # PNG 输出 + container = await page.query_selector("#math-container") + if container is None: + raise RuntimeError("无法定位数学容器元素") + + screenshot_bytes = await container.screenshot(type="png") + return screenshot_bytes, "image/png" + + finally: + await browser.close() async def execute(args: Dict[str, Any], context: Dict[str, Any]) -> str: - """渲染 LaTeX 数学公式为图片""" - content = str(args.get("content", "") or "") - delivery = str(args.get("delivery", "embed") or "embed").strip().lower() - target_id = args.get("target_id") - message_type = args.get("message_type") + """渲染 LaTeX 数学公式为图片或 PDF""" + raw_content = str(args.get("content", "") or "") + output_format = str(args.get("output_format", "png") or "png").strip().lower() - if not content: - return "内容不能为空" - if delivery not in {"embed", "send"}: - return f"delivery 无效:{delivery}。仅支持 embed 或 send" + # 参数校验 + if not raw_content or not raw_content.strip(): + return "LaTeX 内容不能为空" - if delivery == "send" and message_type and message_type not in ("group", "private"): - return "消息类型必须是 group 或 private" + if output_format not in {"png", "pdf"}: + return f"output_format 无效:{output_format}。仅支持 png 或 pdf" try: - from 
Undefined.utils.cache import cleanup_cache_dir - from Undefined.utils.paths import RENDER_CACHE_DIR, ensure_dir - - filename = f"render_{uuid.uuid4().hex[:16]}.png" - filepath = ensure_dir(RENDER_CACHE_DIR) / filename + # 准备内容 + prepared_content = _prepare_content(raw_content) - _render_latex_image(filepath, content) + # 渲染 + rendered_bytes, mime_type = await _render_latex_to_bytes( + prepared_content, output_format + ) # 注册到附件系统 attachment_registry = context.get("attachment_registry") scope_key = scope_from_context(context) - record: Any = None - if attachment_registry is not None and scope_key: - try: - record = await attachment_registry.register_local_file( - scope_key, - filepath, - kind="image", - display_name=filename, - source_kind="rendered_image", - source_ref="render_latex", - ) - except Exception as exc: - logger.warning("注册渲染图片到附件系统失败: %s", exc) - - if delivery == "embed": - cleanup_cache_dir(RENDER_CACHE_DIR) - if record is None: - return "渲染成功,但无法注册到附件系统(缺少 attachment_registry 或 scope_key)" - return f'' - - # delivery == "send" - resolved_target_id, resolved_message_type, target_error = _resolve_send_target( - target_id, message_type, context - ) - if target_error or resolved_target_id is None or resolved_message_type is None: - return target_error or "渲染成功,但缺少发送目标参数" - - sender = context.get("sender") - send_image_callback = context.get("send_image_callback") - - if sender: - from pathlib import Path - - cq_message = f"[CQ:image,file={Path(filepath).resolve().as_uri()}]" - if resolved_message_type == "group": - await sender.send_group_message(int(resolved_target_id), cq_message) - elif resolved_message_type == "private": - await sender.send_private_message(int(resolved_target_id), cq_message) - cleanup_cache_dir(RENDER_CACHE_DIR) - return ( - f"LaTeX 图片已渲染并发送到 {resolved_message_type} {resolved_target_id}" - ) - elif send_image_callback: - await send_image_callback( - resolved_target_id, resolved_message_type, str(filepath) - ) - 
cleanup_cache_dir(RENDER_CACHE_DIR) - return ( - f"LaTeX 图片已渲染并发送到 {resolved_message_type} {resolved_target_id}" + + if attachment_registry is None or not scope_key: + return "渲染成功,但无法注册到附件系统(缺少 attachment_registry 或 scope_key)" + + kind = "image" if output_format == "png" else "file" + extension = "png" if output_format == "png" else "pdf" + display_name = f"latex.{extension}" + + try: + record = await attachment_registry.register_bytes( + scope_key, + rendered_bytes, + kind=kind, + display_name=display_name, + mime_type=mime_type, + source_kind="rendered_latex", + source_ref="render_latex", ) - else: - return "发送图片回调未设置" + tag = "pic" if output_format == "png" else "attachment" + return f'<{tag} uid="{record.uid}"/>' + + except Exception as exc: + logger.exception("注册渲染结果到附件系统失败: %s", exc) + return f"渲染成功,但注册到附件系统失败: {exc}" except ImportError as e: - missing_pkg = str(e).split("'")[1] if "'" in str(e) else "未知包" - return f"渲染失败:缺少依赖包 {missing_pkg},请运行: uv add {missing_pkg}" + logger.error("Playwright 导入失败: %s", e) + return "请运行 `uv run playwright install` 安装浏览器运行时" except RuntimeError as e: - err = str(e).lower() - if "latex" in err or "dvipng" in err or "dvi" in err: - logger.error("LaTeX 渲染失败(系统 TeX 环境不可用): %s", e) - return "渲染失败:系统 LaTeX 环境未安装或不完整,请按部署文档安装 TeX Live / MiKTeX 后重试。" - logger.exception("渲染并发送 LaTeX 图片失败: %s", e) - return "渲染失败,请稍后重试" + logger.error("LaTeX 渲染运行时错误: %s", e) + return str(e) except Exception as e: - logger.exception(f"渲染并发送 LaTeX 图片失败: {e}") - return "渲染失败,请稍后重试" + logger.exception("渲染 LaTeX 失败: %s", e) + return f"渲染失败:{e}" diff --git a/src/Undefined/utils/history.py b/src/Undefined/utils/history.py index 46dfafe..b6dfa4b 100644 --- a/src/Undefined/utils/history.py +++ b/src/Undefined/utils/history.py @@ -306,6 +306,7 @@ async def add_group_message( group_name: str = "", role: str = "member", title: str = "", + level: str = "", message_id: int | None = None, attachments: list[dict[str, str]] | None = None, ) -> None: @@ -334,6 +335,7 
@@ async def add_group_message( "display_name": display_name, "role": role, "title": title, + "level": level, "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"), "message": text_content, } diff --git a/tests/test_arxiv_sender.py b/tests/test_arxiv_sender.py index 120ab0a..312075b 100644 --- a/tests/test_arxiv_sender.py +++ b/tests/test_arxiv_sender.py @@ -17,6 +17,7 @@ @pytest.fixture(autouse=True) def _clear_inflight() -> None: arxiv_sender._INFLIGHT_SENDS.clear() + arxiv_sender._RECENT_SENDS.clear() def _paper_info() -> PaperInfo: @@ -165,3 +166,187 @@ async def _fake_once(**_: object) -> str: assert first_result == "ok" assert second_result == "ok" assert called == 1 + + +@pytest.mark.asyncio +async def test_recent_send_blocks_duplicate( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """After a successful send, duplicate should be blocked within cooldown.""" + sender = _sender() + call_count = 0 + + async def _fake_once(**_: object) -> str: + nonlocal call_count + call_count += 1 + return "ok" + + monkeypatch.setattr(arxiv_sender, "_send_arxiv_paper_once", _fake_once) + + # First send should succeed + result1 = await send_arxiv_paper( + paper_id="2501.01234", + sender=sender, + target_type="group", + target_id=123456, + max_file_size=100, + author_preview_limit=20, + summary_preview_chars=1000, + ) + assert result1 == "ok" + assert call_count == 1 + + # Second send should be blocked by time-based dedup + result2 = await send_arxiv_paper( + paper_id="2501.01234", + sender=sender, + target_type="group", + target_id=123456, + max_file_size=100, + author_preview_limit=20, + summary_preview_chars=1000, + ) + assert "近期已发送过" in result2 + assert call_count == 1 # Still only 1 call + + +@pytest.mark.asyncio +async def test_recent_send_expires_after_cooldown( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """After cooldown expires, the same paper can be sent again.""" + sender = _sender() + call_count = 0 + mock_time = 0.0 + + async def _fake_once(**_: object) 
-> str: + nonlocal call_count + call_count += 1 + return "ok" + + def _fake_monotonic() -> float: + return mock_time + + monkeypatch.setattr(arxiv_sender, "_send_arxiv_paper_once", _fake_once) + # Patch time.monotonic in the arxiv_sender module + monkeypatch.setattr( + arxiv_sender, "time", SimpleNamespace(monotonic=_fake_monotonic) + ) + + # First send at time 0 + result1 = await send_arxiv_paper( + paper_id="2501.01234", + sender=sender, + target_type="group", + target_id=123456, + max_file_size=100, + author_preview_limit=20, + summary_preview_chars=1000, + ) + assert result1 == "ok" + assert call_count == 1 + + # Advance time past cooldown + mock_time = arxiv_sender._DEDUP_COOLDOWN_SECONDS + 1.0 + + # Second send should succeed now + result2 = await send_arxiv_paper( + paper_id="2501.01234", + sender=sender, + target_type="group", + target_id=123456, + max_file_size=100, + author_preview_limit=20, + summary_preview_chars=1000, + ) + assert result2 == "ok" + assert call_count == 2 # Second call executed + + +@pytest.mark.asyncio +async def test_recent_send_capacity_limit() -> None: + """When _RECENT_SENDS exceeds max size, oldest entries are evicted.""" + # Fill with max_size + 100 entries + for i in range(arxiv_sender._RECENT_SENDS_MAX_SIZE + 100): + key = ("group", i, f"paper_{i}") + arxiv_sender._RECENT_SENDS[key] = float(i) + + # Trigger eviction + arxiv_sender._evict_oldest_recent_sends() + + # Should have evicted back to max size + assert len(arxiv_sender._RECENT_SENDS) == arxiv_sender._RECENT_SENDS_MAX_SIZE + + # Oldest 100 should be gone + for i in range(100): + key = ("group", i, f"paper_{i}") + assert key not in arxiv_sender._RECENT_SENDS + + # Newest max_size should remain + for i in range(100, arxiv_sender._RECENT_SENDS_MAX_SIZE + 100): + key = ("group", i, f"paper_{i}") + assert key in arxiv_sender._RECENT_SENDS + + +@pytest.mark.asyncio +async def test_cleanup_expired_recent_sends() -> None: + """Test that expired entries are removed while 
non-expired remain.""" + import time as time_module + + # Get current time + now = time_module.monotonic() + + # Add expired entries (old timestamps, more than 1 hour ago) + expired_key1 = ("group", 1, "expired1") + expired_key2 = ("group", 2, "expired2") + arxiv_sender._RECENT_SENDS[expired_key1] = ( + now - arxiv_sender._DEDUP_COOLDOWN_SECONDS - 100.0 + ) + arxiv_sender._RECENT_SENDS[expired_key2] = ( + now - arxiv_sender._DEDUP_COOLDOWN_SECONDS - 50.0 + ) + + # Add non-expired entries (recent timestamps, within 1 hour) + recent_key1 = ("group", 3, "recent1") + recent_key2 = ("group", 4, "recent2") + arxiv_sender._RECENT_SENDS[recent_key1] = now - 100.0 # 100 seconds ago + arxiv_sender._RECENT_SENDS[recent_key2] = now - 10.0 # 10 seconds ago + + # Cleanup + arxiv_sender._cleanup_expired_recent_sends() + + # Expired should be gone + assert expired_key1 not in arxiv_sender._RECENT_SENDS + assert expired_key2 not in arxiv_sender._RECENT_SENDS + + # Recent should remain + assert recent_key1 in arxiv_sender._RECENT_SENDS + assert recent_key2 in arxiv_sender._RECENT_SENDS + + +@pytest.mark.asyncio +async def test_failed_send_not_recorded_in_recent( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """A failed send should NOT record in _RECENT_SENDS.""" + sender = _sender() + + async def _fake_once(**_: object) -> str: + raise RuntimeError("Send failed") + + monkeypatch.setattr(arxiv_sender, "_send_arxiv_paper_once", _fake_once) + + # Attempt send, should fail + with pytest.raises(RuntimeError, match="Send failed"): + await send_arxiv_paper( + paper_id="2501.01234", + sender=sender, + target_type="group", + target_id=123456, + max_file_size=100, + author_preview_limit=20, + summary_preview_chars=1000, + ) + + # Should NOT be recorded in recent sends + assert len(arxiv_sender._RECENT_SENDS) == 0 diff --git a/tests/test_attachment_tags.py b/tests/test_attachment_tags.py new file mode 100644 index 0000000..762ffd4 --- /dev/null +++ b/tests/test_attachment_tags.py @@ -0,0 
+1,336 @@ +"""Tests for unified / tag rendering.""" + +from __future__ import annotations + +from pathlib import Path +from typing import Any + +import pytest + +from Undefined.attachments import ( + AttachmentRegistry, + AttachmentRenderError, + RenderedRichMessage, + dispatch_pending_file_sends, + render_message_with_attachments, + render_message_with_pic_placeholders, +) + + +_PNG_BYTES = ( + b"\x89PNG\r\n\x1a\n" + b"\x00\x00\x00\rIHDR" + b"\x00\x00\x00\x01\x00\x00\x00\x01\x08\x02\x00\x00\x00" + b"\x90wS\xde" + b"\x00\x00\x00\x0cIDATx\x9cc``\x00\x00\x00\x02\x00\x01" + b"\x0b\xe7\x02\x9d" + b"\x00\x00\x00\x00IEND\xaeB`\x82" +) + +_PDF_BYTES = b"%PDF-1.4 fake content for testing" + + +def _make_registry(tmp_path: Path) -> AttachmentRegistry: + return AttachmentRegistry( + registry_path=tmp_path / "reg.json", + cache_dir=tmp_path / "cache", + ) + + +# ---------- backward compatibility ---------- + + +@pytest.mark.asyncio +async def test_pic_tag_still_works(tmp_path: Path) -> None: + """ backward compat: renders image as CQ.""" + reg = _make_registry(tmp_path) + rec = await reg.register_bytes( + "group:1", _PNG_BYTES, kind="image", display_name="cat.png", source_kind="test" + ) + msg = f'Look: ' + result = await render_message_with_attachments( + msg, registry=reg, scope_key="group:1", strict=False + ) + assert "[CQ:image" in result.delivery_text + assert rec.uid in result.history_text + assert len(result.attachments) == 1 + assert result.pending_file_sends == () + + +@pytest.mark.asyncio +async def test_alias_is_same_function() -> None: + """render_message_with_pic_placeholders is an alias.""" + assert render_message_with_pic_placeholders is render_message_with_attachments + + +# ---------- unified tag ---------- + + +@pytest.mark.asyncio +async def test_attachment_tag_image(tmp_path: Path) -> None: + """ renders as CQ image (same as ).""" + reg = _make_registry(tmp_path) + rec = await reg.register_bytes( + "group:1", _PNG_BYTES, kind="image", 
display_name="cat.png", source_kind="test" + ) + msg = f'Here: ' + result = await render_message_with_attachments( + msg, registry=reg, scope_key="group:1", strict=False + ) + assert "[CQ:image" in result.delivery_text + assert rec.uid in result.history_text + assert result.pending_file_sends == () + + +@pytest.mark.asyncio +async def test_attachment_tag_file(tmp_path: Path) -> None: + """ collects into pending_file_sends.""" + reg = _make_registry(tmp_path) + rec = await reg.register_bytes( + "group:1", _PDF_BYTES, kind="file", display_name="doc.pdf", source_kind="test" + ) + msg = f'See doc: ' + result = await render_message_with_attachments( + msg, registry=reg, scope_key="group:1", strict=False + ) + # File tag removed from delivery text + assert rec.uid not in result.delivery_text + assert "See doc: " in result.delivery_text + # Readable placeholder in history + assert f"[文件 uid={rec.uid}" in result.history_text + assert "doc.pdf" in result.history_text + # Collected in pending + assert len(result.pending_file_sends) == 1 + assert result.pending_file_sends[0].uid == rec.uid + + +@pytest.mark.asyncio +async def test_mixed_pic_and_attachment_tags(tmp_path: Path) -> None: + """Mix of and tags in the same message.""" + reg = _make_registry(tmp_path) + img = await reg.register_bytes( + "group:1", _PNG_BYTES, kind="image", display_name="img.png", source_kind="test" + ) + doc = await reg.register_bytes( + "group:1", _PDF_BYTES, kind="file", display_name="doc.pdf", source_kind="test" + ) + msg = f' and ' + result = await render_message_with_attachments( + msg, registry=reg, scope_key="group:1", strict=False + ) + assert "[CQ:image" in result.delivery_text + assert len(result.pending_file_sends) == 1 + assert len(result.attachments) == 2 + + +@pytest.mark.asyncio +async def test_pic_tag_rejects_non_image(tmp_path: Path) -> None: + """ tag with file UID shows type error.""" + reg = _make_registry(tmp_path) + rec = await reg.register_bytes( + "group:1", _PDF_BYTES, 
kind="file", display_name="doc.pdf", source_kind="test" + ) + msg = f' ' + result = await render_message_with_attachments( + msg, registry=reg, scope_key="group:1", strict=False + ) + assert "类型错误" in result.delivery_text + + +@pytest.mark.asyncio +async def test_pic_tag_rejects_non_image_strict(tmp_path: Path) -> None: + """ tag with file UID raises in strict mode.""" + reg = _make_registry(tmp_path) + rec = await reg.register_bytes( + "group:1", _PDF_BYTES, kind="file", display_name="doc.pdf", source_kind="test" + ) + msg = f' ' + with pytest.raises(AttachmentRenderError, match="不是图片"): + await render_message_with_attachments( + msg, registry=reg, scope_key="group:1", strict=True + ) + + +@pytest.mark.asyncio +async def test_attachment_tag_allows_any_type(tmp_path: Path) -> None: + """ tag does NOT reject file UIDs (unlike ).""" + reg = _make_registry(tmp_path) + rec = await reg.register_bytes( + "group:1", _PDF_BYTES, kind="file", display_name="doc.pdf", source_kind="test" + ) + msg = f' ' + result = await render_message_with_attachments( + msg, registry=reg, scope_key="group:1", strict=True + ) + assert "类型错误" not in result.delivery_text + assert len(result.pending_file_sends) == 1 + + +@pytest.mark.asyncio +async def test_invalid_uid_non_strict(tmp_path: Path) -> None: + """Unknown UID → placeholder in non-strict mode.""" + reg = _make_registry(tmp_path) + msg = ' ' + result = await render_message_with_attachments( + msg, registry=reg, scope_key="group:1", strict=False + ) + assert "不可用" in result.delivery_text + + +@pytest.mark.asyncio +async def test_invalid_uid_strict(tmp_path: Path) -> None: + """Unknown UID → exception in strict mode.""" + reg = _make_registry(tmp_path) + msg = ' ' + with pytest.raises(AttachmentRenderError, match="不可用"): + await render_message_with_attachments( + msg, registry=reg, scope_key="group:1", strict=True + ) + + +@pytest.mark.asyncio +async def test_file_tag_missing_local_path(tmp_path: Path) -> None: + """File with deleted 
local_path → error placeholder.""" + reg = _make_registry(tmp_path) + rec = await reg.register_bytes( + "group:1", _PDF_BYTES, kind="file", display_name="doc.pdf", source_kind="test" + ) + assert rec.local_path is not None + Path(rec.local_path).unlink() + + msg = f' ' + result = await render_message_with_attachments( + msg, registry=reg, scope_key="group:1", strict=False + ) + assert "缺少本地文件" in result.delivery_text + assert len(result.pending_file_sends) == 0 + + +@pytest.mark.asyncio +async def test_no_tags_passthrough() -> None: + """Message without tags passes through unchanged.""" + result = await render_message_with_attachments( + "Hello world", registry=None, scope_key="group:1", strict=False + ) + assert result.delivery_text == "Hello world" + assert result.history_text == "Hello world" + assert result.attachments == [] + assert result.pending_file_sends == () + + +@pytest.mark.asyncio +async def test_rendered_rich_message_default_pending() -> None: + """RenderedRichMessage.pending_file_sends defaults to empty tuple.""" + msg = RenderedRichMessage(delivery_text="hi", history_text="hi", attachments=[]) + assert msg.pending_file_sends == () + + +# ---------- dispatch_pending_file_sends ---------- + + +@pytest.mark.asyncio +async def test_dispatch_pending_file_sends_group(tmp_path: Path) -> None: + """dispatch_pending_file_sends calls sender.send_group_file for group targets.""" + reg = _make_registry(tmp_path) + rec = await reg.register_bytes( + "group:1", _PDF_BYTES, kind="file", display_name="doc.pdf", source_kind="test" + ) + rendered = RenderedRichMessage( + delivery_text="text", + history_text="text", + attachments=[], + pending_file_sends=(rec,), + ) + + calls: list[tuple[Any, ...]] = [] + + class FakeSender: + async def send_group_file( + self, group_id: int, file_path: str, name: str | None = None + ) -> None: + calls.append(("group", group_id, file_path, name)) + + async def send_private_file( + self, user_id: int, file_path: str, name: str | None = 
None + ) -> None: + calls.append(("private", user_id, file_path, name)) + + await dispatch_pending_file_sends( + rendered, sender=FakeSender(), target_type="group", target_id=12345 + ) + assert len(calls) == 1 + assert calls[0][0] == "group" + assert calls[0][1] == 12345 + + +@pytest.mark.asyncio +async def test_dispatch_pending_file_sends_private(tmp_path: Path) -> None: + """dispatch_pending_file_sends calls sender.send_private_file for private targets.""" + reg = _make_registry(tmp_path) + rec = await reg.register_bytes( + "group:1", + _PDF_BYTES, + kind="file", + display_name="report.pdf", + source_kind="test", + ) + rendered = RenderedRichMessage( + delivery_text="text", + history_text="text", + attachments=[], + pending_file_sends=(rec,), + ) + + calls: list[tuple[Any, ...]] = [] + + class FakeSender: + async def send_group_file(self, *a: Any, **kw: Any) -> None: + calls.append(("group", *a)) + + async def send_private_file(self, *a: Any, **kw: Any) -> None: + calls.append(("private", *a)) + + await dispatch_pending_file_sends( + rendered, sender=FakeSender(), target_type="private", target_id=99999 + ) + assert len(calls) == 1 + assert calls[0][0] == "private" + assert calls[0][1] == 99999 + + +@pytest.mark.asyncio +async def test_dispatch_best_effort_on_failure(tmp_path: Path) -> None: + """File send failure doesn't propagate — best-effort.""" + reg = _make_registry(tmp_path) + rec = await reg.register_bytes( + "group:1", _PDF_BYTES, kind="file", display_name="doc.pdf", source_kind="test" + ) + rendered = RenderedRichMessage( + delivery_text="text", + history_text="text", + attachments=[], + pending_file_sends=(rec,), + ) + + class FailingSender: + async def send_group_file(self, *a: Any, **kw: Any) -> None: + raise RuntimeError("network error") + + async def send_private_file(self, *a: Any, **kw: Any) -> None: + raise RuntimeError("network error") + + # Should not raise + await dispatch_pending_file_sends( + rendered, sender=FailingSender(), 
target_type="group", target_id=1 + ) + + +@pytest.mark.asyncio +async def test_dispatch_no_pending_is_noop() -> None: + """No pending files → no calls.""" + rendered = RenderedRichMessage( + delivery_text="text", history_text="text", attachments=[] + ) + await dispatch_pending_file_sends( + rendered, sender=None, target_type="group", target_id=1 + ) diff --git a/tests/test_attachments_dedup.py b/tests/test_attachments_dedup.py new file mode 100644 index 0000000..c1114bf --- /dev/null +++ b/tests/test_attachments_dedup.py @@ -0,0 +1,147 @@ +"""Tests for attachment SHA-256 hash deduplication.""" + +from __future__ import annotations + +import asyncio +from pathlib import Path + +import pytest + +from Undefined.attachments import AttachmentRegistry + + +_PNG_BYTES = ( + b"\x89PNG\r\n\x1a\n" + b"\x00\x00\x00\rIHDR" + b"\x00\x00\x00\x01\x00\x00\x00\x01\x08\x02\x00\x00\x00" + b"\x90wS\xde" + b"\x00\x00\x00\x0cIDATx\x9cc``\x00\x00\x00\x02\x00\x01" + b"\x0b\xe7\x02\x9d" + b"\x00\x00\x00\x00IEND\xaeB`\x82" +) + +# Different content → different hash +_PNG_BYTES_ALT = _PNG_BYTES + b"\x00" + + +def _make_registry(tmp_path: Path) -> AttachmentRegistry: + return AttachmentRegistry( + registry_path=tmp_path / "reg.json", + cache_dir=tmp_path / "cache", + ) + + +@pytest.mark.asyncio +async def test_same_hash_same_scope_same_kind_returns_same_uid( + tmp_path: Path, +) -> None: + """Identical bytes + scope + kind → dedup returns same record.""" + reg = _make_registry(tmp_path) + r1 = await reg.register_bytes( + "group:1", _PNG_BYTES, kind="image", display_name="a.png", source_kind="test" + ) + r2 = await reg.register_bytes( + "group:1", _PNG_BYTES, kind="image", display_name="b.png", source_kind="test" + ) + assert r1.uid == r2.uid + assert r1.sha256 == r2.sha256 + + +@pytest.mark.asyncio +async def test_different_content_gets_different_uid(tmp_path: Path) -> None: + """Different bytes → different records even in same scope/kind.""" + reg = _make_registry(tmp_path) + r1 = await 
reg.register_bytes( + "group:1", _PNG_BYTES, kind="image", display_name="a.png", source_kind="test" + ) + r2 = await reg.register_bytes( + "group:1", + _PNG_BYTES_ALT, + kind="image", + display_name="b.png", + source_kind="test", + ) + assert r1.uid != r2.uid + + +@pytest.mark.asyncio +async def test_cross_scope_no_dedup(tmp_path: Path) -> None: + """Same hash but different scope → separate records (scope isolation).""" + reg = _make_registry(tmp_path) + r1 = await reg.register_bytes( + "group:1", _PNG_BYTES, kind="image", display_name="a.png", source_kind="test" + ) + r2 = await reg.register_bytes( + "group:2", _PNG_BYTES, kind="image", display_name="a.png", source_kind="test" + ) + assert r1.uid != r2.uid + + +@pytest.mark.asyncio +async def test_cross_kind_no_dedup(tmp_path: Path) -> None: + """Same hash + scope but different kind → separate records.""" + reg = _make_registry(tmp_path) + r1 = await reg.register_bytes( + "group:1", _PNG_BYTES, kind="image", display_name="a.png", source_kind="test" + ) + r2 = await reg.register_bytes( + "group:1", _PNG_BYTES, kind="file", display_name="a.bin", source_kind="test" + ) + assert r1.uid != r2.uid + assert r1.uid.startswith("pic_") + assert r2.uid.startswith("file_") + + +@pytest.mark.asyncio +async def test_file_deleted_causes_new_registration(tmp_path: Path) -> None: + """If the cached file is deleted, a new record is created.""" + reg = _make_registry(tmp_path) + r1 = await reg.register_bytes( + "group:1", _PNG_BYTES, kind="image", display_name="a.png", source_kind="test" + ) + assert r1.local_path is not None + Path(r1.local_path).unlink() + + r2 = await reg.register_bytes( + "group:1", _PNG_BYTES, kind="image", display_name="a.png", source_kind="test" + ) + assert r2.uid != r1.uid + assert r2.sha256 == r1.sha256 + + +@pytest.mark.asyncio +async def test_concurrent_identical_registrations(tmp_path: Path) -> None: + """Concurrent registrations with the same content should be safe.""" + reg = _make_registry(tmp_path) 
+ + results = await asyncio.gather( + *( + reg.register_bytes( + "group:1", + _PNG_BYTES, + kind="image", + display_name="pic.png", + source_kind="test", + ) + for _ in range(5) + ) + ) + uids = {r.uid for r in results} + # All should resolve to the same UID (dedup) + assert len(uids) == 1 + + +@pytest.mark.asyncio +async def test_register_local_file_deduplicates(tmp_path: Path) -> None: + """register_local_file delegates to register_bytes, so dedup applies.""" + reg = _make_registry(tmp_path) + file_path = tmp_path / "input.png" + file_path.write_bytes(_PNG_BYTES) + + r1 = await reg.register_bytes( + "group:1", _PNG_BYTES, kind="image", display_name="a.png", source_kind="test" + ) + r2 = await reg.register_local_file( + "group:1", file_path, kind="image", display_name="input.png" + ) + assert r1.uid == r2.uid diff --git a/tests/test_coordinator_level.py b/tests/test_coordinator_level.py new file mode 100644 index 0000000..38ecaa6 --- /dev/null +++ b/tests/test_coordinator_level.py @@ -0,0 +1,155 @@ +from __future__ import annotations + +from unittest.mock import MagicMock + + +from Undefined.services.ai_coordinator import AICoordinator + + +def _make_coordinator() -> AICoordinator: + """创建用于测试的 AICoordinator 实例""" + config = MagicMock() + ai = MagicMock() + queue_manager = MagicMock() + history_manager = MagicMock() + sender = MagicMock() + onebot = MagicMock() + scheduler = MagicMock() + security = MagicMock() + + coordinator = AICoordinator( + config=config, + ai=ai, + queue_manager=queue_manager, + history_manager=history_manager, + sender=sender, + onebot=onebot, + scheduler=scheduler, + security=security, + ) + + return coordinator + + +def test_build_prompt_with_level_includes_level_attribute() -> None: + """测试 _build_prompt 带 level 参数时 XML 包含 level 属性""" + coordinator = _make_coordinator() + + result = coordinator._build_prompt( + prefix="", + name="测试用户", + uid=10001, + gid=20001, + gname="测试群", + loc="测试群", + role="member", + title="", + 
time_str="2026-04-11 10:00:00", + text="测试消息", + attachments=None, + message_id=123456, + level="Lv.5", + ) + + assert 'level="Lv.5"' in result + assert " None: + """测试 _build_prompt level 为空字符串时 XML 不包含 level 属性""" + coordinator = _make_coordinator() + + result = coordinator._build_prompt( + prefix="", + name="测试用户", + uid=10001, + gid=20001, + gname="测试群", + loc="测试群", + role="member", + title="", + time_str="2026-04-11 10:00:00", + text="测试消息", + level="", + ) + + assert "level=" not in result + assert " None: + """测试 _build_prompt 不传 level 参数时 XML 不包含 level 属性""" + coordinator = _make_coordinator() + + result = coordinator._build_prompt( + prefix="", + name="测试用户", + uid=10001, + gid=20001, + gname="测试群", + loc="测试群", + role="member", + title="", + time_str="2026-04-11 10:00:00", + text="测试消息", + ) + + assert "level=" not in result + assert " None: + """测试 _build_prompt level 包含特殊字符时能正确转义""" + coordinator = _make_coordinator() + + result = coordinator._build_prompt( + prefix="", + name="测试用户", + uid=10001, + gid=20001, + gname="测试群", + loc="测试群", + role="member", + title="", + time_str="2026-04-11 10:00:00", + text="测试消息", + level='Lv.5 ', + ) + + assert "level=" in result + assert ' ' not in result + assert "<" in result or "&" in result or """ in result + + +def test_build_prompt_with_all_attributes() -> None: + """测试 _build_prompt 包含所有属性时的输出""" + coordinator = _make_coordinator() + + result = coordinator._build_prompt( + prefix="系统提示:\n", + name="测试用户", + uid=10001, + gid=20001, + gname="测试群", + loc="测试群", + role="admin", + title="管理员", + time_str="2026-04-11 10:00:00", + text="测试消息内容", + attachments=[{"uid": "pic_001", "kind": "image"}], + message_id=123456, + level="Lv.10", + ) + + assert "系统提示:\n" in result + assert 'sender="测试用户"' in result + assert 'sender_id="10001"' in result + assert 'group_id="20001"' in result + assert 'group_name="测试群"' in result + assert 'role="admin"' in result + assert 'title="管理员"' in result + assert 'level="Lv.10"' in result 
+ assert 'message_id="123456"' in result + assert " 测试消息内容 " in result + assert "" in result diff --git a/tests/test_fetch_messages_tool.py b/tests/test_fetch_messages_tool.py new file mode 100644 index 0000000..ddba794 --- /dev/null +++ b/tests/test_fetch_messages_tool.py @@ -0,0 +1,466 @@ +from __future__ import annotations + +from datetime import datetime, timedelta +from typing import Any +from unittest.mock import MagicMock + +import pytest + +from Undefined.skills.agents.summary_agent.tools.fetch_messages.handler import ( + _filter_by_time, + _format_messages, + _parse_time_range, + execute as fetch_messages_execute, +) + + +# -- _parse_time_range unit tests -- + + +def test_parse_time_range_1h() -> None: + """'1h' → 3600.""" + assert _parse_time_range("1h") == 3600 + + +def test_parse_time_range_6h() -> None: + """'6h' → 21600.""" + assert _parse_time_range("6h") == 21600 + + +def test_parse_time_range_1d() -> None: + """'1d' → 86400.""" + assert _parse_time_range("1d") == 86400 + + +def test_parse_time_range_7d() -> None: + """'7d' → 604800.""" + assert _parse_time_range("7d") == 604800 + + +def test_parse_time_range_1w() -> None: + """'1w' → 604800.""" + assert _parse_time_range("1w") == 604800 + + +def test_parse_time_range_case_insensitive() -> None: + """'1H', '1D' → correct values.""" + assert _parse_time_range("1H") == 3600 + assert _parse_time_range("1D") == 86400 + assert _parse_time_range("1W") == 604800 + + +def test_parse_time_range_invalid() -> None: + """'invalid' → None.""" + assert _parse_time_range("invalid") is None + assert _parse_time_range("") is None + assert _parse_time_range("abc") is None + assert _parse_time_range("1x") is None + + +def test_parse_time_range_with_whitespace() -> None: + """' 1d ' → 86400 (strips whitespace).""" + assert _parse_time_range(" 1d ") == 86400 + + +def test_parse_time_range_multi_digit() -> None: + """'24h' → 86400.""" + assert _parse_time_range("24h") == 86400 + + +# -- _filter_by_time unit tests -- + + 
+def test_filter_by_time_keeps_recent() -> None: + """Messages within time range are kept.""" + now = datetime.now() + recent = (now - timedelta(seconds=1800)).strftime("%Y-%m-%d %H:%M:%S") + old = (now - timedelta(seconds=7200)).strftime("%Y-%m-%d %H:%M:%S") + + messages = [ + {"timestamp": recent, "message": "recent"}, + {"timestamp": old, "message": "old"}, + ] + + result = _filter_by_time(messages, 3600) # 1 hour + assert len(result) == 1 + assert result[0]["message"] == "recent" + + +def test_filter_by_time_removes_old() -> None: + """Messages outside time range are removed.""" + now = datetime.now() + old1 = (now - timedelta(days=2)).strftime("%Y-%m-%d %H:%M:%S") + old2 = (now - timedelta(days=3)).strftime("%Y-%m-%d %H:%M:%S") + + messages = [ + {"timestamp": old1, "message": "old1"}, + {"timestamp": old2, "message": "old2"}, + ] + + result = _filter_by_time(messages, 86400) # 1 day + assert len(result) == 0 + + +def test_filter_by_time_missing_timestamp() -> None: + """Messages without timestamp are filtered out.""" + messages = [ + {"message": "no timestamp"}, + {"timestamp": "", "message": "empty timestamp"}, + ] + + result = _filter_by_time(messages, 3600) + assert len(result) == 0 + + +def test_filter_by_time_invalid_timestamp() -> None: + """Messages with invalid timestamp format are filtered out.""" + messages = [ + {"timestamp": "invalid-format", "message": "bad format"}, + {"timestamp": "2024-13-45 99:99:99", "message": "impossible date"}, + ] + + result = _filter_by_time(messages, 3600) + assert len(result) == 0 + + +# -- _format_messages unit tests -- + + +def test_format_messages_basic() -> None: + """Basic formatting with timestamp, name, and message.""" + messages = [ + { + "timestamp": "2024-01-01 12:00:00", + "display_name": "Alice", + "message": "Hello", + "role": "", + "title": "", + }, + ] + + result = _format_messages(messages) + assert result == "[2024-01-01 12:00:00] Alice: Hello" + + +def test_format_messages_with_role() -> None: + 
"""Role is included when not 'member' or empty.""" + messages = [ + { + "timestamp": "2024-01-01 12:00:00", + "display_name": "Bob", + "message": "Hi", + "role": "admin", + "title": "", + }, + ] + + result = _format_messages(messages) + assert result == "[2024-01-01 12:00:00] (admin) Bob: Hi" + + +def test_format_messages_with_title() -> None: + """Title is included when present.""" + messages = [ + { + "timestamp": "2024-01-01 12:00:00", + "display_name": "Charlie", + "message": "Test", + "role": "", + "title": "群主", + }, + ] + + result = _format_messages(messages) + assert result == "[2024-01-01 12:00:00] [群主] Charlie: Test" + + +def test_format_messages_with_title_and_role() -> None: + """Both title and role are included.""" + messages = [ + { + "timestamp": "2024-01-01 12:00:00", + "display_name": "Dave", + "message": "Message", + "role": "owner", + "title": "管理员", + }, + ] + + result = _format_messages(messages) + assert result == "[2024-01-01 12:00:00] [管理员] (owner) Dave: Message" + + +def test_format_messages_role_member_excluded() -> None: + """Role 'member' is not included.""" + messages = [ + { + "timestamp": "2024-01-01 12:00:00", + "display_name": "Eve", + "message": "Text", + "role": "member", + "title": "", + }, + ] + + result = _format_messages(messages) + assert result == "[2024-01-01 12:00:00] Eve: Text" + + +def test_format_messages_multiple() -> None: + """Multiple messages are separated by newlines.""" + messages = [ + { + "timestamp": "2024-01-01 12:00:00", + "display_name": "Alice", + "message": "First", + "role": "", + "title": "", + }, + { + "timestamp": "2024-01-01 12:01:00", + "display_name": "Bob", + "message": "Second", + "role": "admin", + "title": "", + }, + ] + + result = _format_messages(messages) + expected = ( + "[2024-01-01 12:00:00] Alice: First\n[2024-01-01 12:01:00] (admin) Bob: Second" + ) + assert result == expected + + +def test_format_messages_missing_fields() -> None: + """Missing fields default to empty or '未知用户'.""" + 
messages = [ + { + "timestamp": "", + "message": "No timestamp", + }, + ] + + result = _format_messages(messages) + assert "未知用户" in result + assert "No timestamp" in result + + +# -- execute function tests -- + + +@pytest.mark.asyncio +async def test_fetch_messages_count_based_group() -> None: + """Count-based fetch in group context.""" + history_manager = MagicMock() + history_manager.get_recent.return_value = [ + { + "timestamp": "2024-01-01 12:00:00", + "display_name": "Alice", + "message": "Message 1", + "role": "", + "title": "", + }, + { + "timestamp": "2024-01-01 12:01:00", + "display_name": "Bob", + "message": "Message 2", + "role": "", + "title": "", + }, + ] + + context: dict[str, Any] = { + "history_manager": history_manager, + "group_id": 123456, + "user_id": 0, + } + + result = await fetch_messages_execute({"count": 50}, context) + + assert "共获取 2 条消息" in result + assert "Alice: Message 1" in result + assert "Bob: Message 2" in result + history_manager.get_recent.assert_called_once_with("123456", "group", 0, 50) + + +@pytest.mark.asyncio +async def test_fetch_messages_count_based_private() -> None: + """Count-based fetch in private context.""" + history_manager = MagicMock() + history_manager.get_recent.return_value = [ + { + "timestamp": "2024-01-01 12:00:00", + "display_name": "User", + "message": "Private message", + "role": "", + "title": "", + }, + ] + + context: dict[str, Any] = { + "history_manager": history_manager, + "group_id": 0, + "user_id": 99999, + } + + result = await fetch_messages_execute({"count": 20}, context) + + assert "共获取 1 条消息" in result + assert "Private message" in result + history_manager.get_recent.assert_called_once_with("99999", "private", 0, 20) + + +@pytest.mark.asyncio +async def test_fetch_messages_time_range() -> None: + """Time-range fetch filters by time.""" + now = datetime.now() + recent = (now - timedelta(hours=1)).strftime("%Y-%m-%d %H:%M:%S") + old = (now - timedelta(hours=25)).strftime("%Y-%m-%d %H:%M:%S") + 
+ history_manager = MagicMock() + history_manager.get_recent.return_value = [ + { + "timestamp": recent, + "display_name": "Alice", + "message": "Recent", + "role": "", + "title": "", + }, + { + "timestamp": old, + "display_name": "Bob", + "message": "Old", + "role": "", + "title": "", + }, + ] + + context: dict[str, Any] = { + "history_manager": history_manager, + "group_id": 123456, + "user_id": 0, + } + + result = await fetch_messages_execute( + {"count": 50, "time_range": "1d"}, + context, + ) + + assert "共获取 1 条消息" in result + assert "(时间范围: 1d)" in result + assert "Recent" in result + assert "Old" not in result + + +@pytest.mark.asyncio +async def test_fetch_messages_invalid_time_range() -> None: + """Invalid time range returns error message.""" + context: dict[str, Any] = { + "history_manager": MagicMock(), + "group_id": 123456, + } + + result = await fetch_messages_execute( + {"time_range": "invalid"}, + context, + ) + + assert "无法解析时间范围: invalid" in result + assert "支持格式: 1h, 6h, 1d, 7d" in result + + +@pytest.mark.asyncio +async def test_fetch_messages_empty_history() -> None: + """Empty history returns '当前会话暂无消息记录'.""" + history_manager = MagicMock() + history_manager.get_recent.return_value = [] + + context: dict[str, Any] = { + "history_manager": history_manager, + "group_id": 123456, + } + + result = await fetch_messages_execute({}, context) + + assert "当前会话暂无消息记录" in result + + +@pytest.mark.asyncio +async def test_fetch_messages_no_history_manager() -> None: + """No history_manager returns error.""" + context: dict[str, Any] = { + "group_id": 123456, + } + + result = await fetch_messages_execute({}, context) + + assert "历史记录管理器未配置" in result + + +@pytest.mark.asyncio +async def test_fetch_messages_count_capped_at_500() -> None: + """Count is capped at 500.""" + history_manager = MagicMock() + history_manager.get_recent.return_value = [] + + context: dict[str, Any] = { + "history_manager": history_manager, + "group_id": 123456, + } + + await 
fetch_messages_execute({"count": 9999}, context) + + history_manager.get_recent.assert_called_once_with("123456", "group", 0, 500) + + +@pytest.mark.asyncio +async def test_fetch_messages_default_count() -> None: + """Default count is 50 when not specified.""" + history_manager = MagicMock() + history_manager.get_recent.return_value = [] + + context: dict[str, Any] = { + "history_manager": history_manager, + "group_id": 123456, + } + + await fetch_messages_execute({}, context) + + history_manager.get_recent.assert_called_once_with("123456", "group", 0, 50) + + +@pytest.mark.asyncio +async def test_fetch_messages_invalid_count_defaults() -> None: + """Invalid count defaults to 50.""" + history_manager = MagicMock() + history_manager.get_recent.return_value = [] + + context: dict[str, Any] = { + "history_manager": history_manager, + "group_id": 123456, + } + + await fetch_messages_execute({"count": "invalid"}, context) + + history_manager.get_recent.assert_called_once_with("123456", "group", 0, 50) + + +@pytest.mark.asyncio +async def test_fetch_messages_time_range_fetch_larger_batch() -> None: + """Time range mode fetches larger batch (max(count*2, 2000)).""" + history_manager = MagicMock() + history_manager.get_recent.return_value = [] + + context: dict[str, Any] = { + "history_manager": history_manager, + "group_id": 123456, + } + + await fetch_messages_execute( + {"count": 50, "time_range": "1d"}, + context, + ) + + # max(50*2, 2000) = 2000 + history_manager.get_recent.assert_called_once_with("123456", "group", 0, 2000) diff --git a/tests/test_handlers_meme_annotation.py b/tests/test_handlers_meme_annotation.py new file mode 100644 index 0000000..a8ff4d6 --- /dev/null +++ b/tests/test_handlers_meme_annotation.py @@ -0,0 +1,277 @@ +"""测试 handlers.py 中的表情包自动匹配功能""" + +from __future__ import annotations + +from pathlib import Path +from typing import Any +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from Undefined.attachments import 
AttachmentRecord, AttachmentRegistry + + +@pytest.mark.asyncio +async def test_annotate_meme_descriptions_success(tmp_path: Path) -> None: + """测试成功匹配表情包时添加描述""" + # 创建 mock handler + handler = MagicMock() + + # 设置 attachment registry + registry_path = tmp_path / "registry.json" + cache_dir = tmp_path / "cache" + attachment_registry = AttachmentRegistry( + registry_path=registry_path, cache_dir=cache_dir + ) + + # 添加一个附件记录 + test_sha256 = "abc123def456" + attachment_record = AttachmentRecord( + uid="pic_001", + scope_key="group:10001", + kind="image", + media_type="image/png", + display_name="test.png", + source_kind="test", + source_ref="", + local_path=None, + mime_type="image/png", + sha256=test_sha256, + created_at="2024-01-01T00:00:00Z", + segment_data={}, + ) + attachment_registry._records["pic_001"] = attachment_record + + # 设置 meme service mock + mock_meme_record = MagicMock() + mock_meme_record.description = "可爱的猫猫" + + mock_meme_store = MagicMock() + mock_meme_store.find_by_sha256 = AsyncMock(return_value=mock_meme_record) + + mock_meme_service = MagicMock() + mock_meme_service.enabled = True + mock_meme_service._store = mock_meme_store + + ai = MagicMock() + ai.attachment_registry = attachment_registry + ai._meme_service = mock_meme_service + + handler.ai = ai + + # 导入实现 + from Undefined.handlers import MessageHandler + + # 测试输入 + input_attachments = [ + {"uid": "pic_001", "kind": "image"}, + {"uid": "file_002", "kind": "file"}, + ] + + # 调用函数 + result = await MessageHandler._annotate_meme_descriptions( + handler, input_attachments, "group:10001" + ) + + # 验证结果 + assert len(result) == 2 + # 第一个附件应该有表情包描述 + assert result[0]["uid"] == "pic_001" + assert result[0]["description"] == "[表情包] 可爱的猫猫" + # 第二个附件不应该被修改 + assert result[1]["uid"] == "file_002" + assert "description" not in result[1] + + +@pytest.mark.asyncio +async def test_annotate_meme_descriptions_no_match(tmp_path: Path) -> None: + """测试没有匹配表情包时返回原始列表""" + handler = MagicMock() + + registry_path 
= tmp_path / "registry.json" + cache_dir = tmp_path / "cache" + attachment_registry = AttachmentRegistry( + registry_path=registry_path, cache_dir=cache_dir + ) + + test_sha256 = "xyz789" + attachment_record = AttachmentRecord( + uid="pic_001", + scope_key="group:10001", + kind="image", + media_type="image/png", + display_name="test.png", + source_kind="test", + source_ref="", + local_path=None, + mime_type="image/png", + sha256=test_sha256, + created_at="2024-01-01T00:00:00Z", + segment_data={}, + ) + attachment_registry._records["pic_001"] = attachment_record + + # meme store 返回 None(没找到) + mock_meme_store = MagicMock() + mock_meme_store.find_by_sha256 = AsyncMock(return_value=None) + + mock_meme_service = MagicMock() + mock_meme_service.enabled = True + mock_meme_service._store = mock_meme_store + + ai = MagicMock() + ai.attachment_registry = attachment_registry + ai._meme_service = mock_meme_service + + handler.ai = ai + + from Undefined.handlers import MessageHandler + + input_attachments = [{"uid": "pic_001", "kind": "image"}] + + result = await MessageHandler._annotate_meme_descriptions( + handler, input_attachments, "group:10001" + ) + + # 应该返回原始列表 + assert len(result) == 1 + assert result[0]["uid"] == "pic_001" + assert "description" not in result[0] + + +@pytest.mark.asyncio +async def test_annotate_meme_descriptions_meme_disabled() -> None: + """测试 meme service 禁用时返回原始列表""" + handler = MagicMock() + + # meme service 禁用 + mock_meme_service = MagicMock() + mock_meme_service.enabled = False + + ai = MagicMock() + ai._meme_service = mock_meme_service + + handler.ai = ai + + from Undefined.handlers import MessageHandler + + input_attachments = [{"uid": "pic_001", "kind": "image"}] + + result = await MessageHandler._annotate_meme_descriptions( + handler, input_attachments, "group:10001" + ) + + # 应该返回原始列表 + assert result == input_attachments + + +@pytest.mark.asyncio +async def test_annotate_meme_descriptions_error_handling() -> None: + """测试异常处理:失败时返回原始列表""" 
+ handler = MagicMock() + + # 设置会抛出异常的 attachment registry + mock_attachment_registry = MagicMock() + mock_attachment_registry.resolve = MagicMock( + side_effect=Exception("Registry error") + ) + + # 设置 meme service + mock_meme_store = MagicMock() + mock_meme_store.find_by_sha256 = AsyncMock(side_effect=Exception("Database error")) + + mock_meme_service = MagicMock() + mock_meme_service.enabled = True + mock_meme_service._store = mock_meme_store + + ai = MagicMock() + ai.attachment_registry = mock_attachment_registry + ai._meme_service = mock_meme_service + + handler.ai = ai + + from Undefined.handlers import MessageHandler + + input_attachments = [{"uid": "pic_001", "kind": "image"}] + + result = await MessageHandler._annotate_meme_descriptions( + handler, input_attachments, "group:10001" + ) + + # 应该返回原始列表(异常被捕获) + assert result == input_attachments + + +@pytest.mark.asyncio +async def test_annotate_meme_descriptions_batch_query(tmp_path: Path) -> None: + """测试批量查询:多个附件共享同一个哈希值""" + handler = MagicMock() + + registry_path = tmp_path / "registry.json" + cache_dir = tmp_path / "cache" + attachment_registry = AttachmentRegistry( + registry_path=registry_path, cache_dir=cache_dir + ) + + # 两个附件,相同的 SHA256 + shared_sha256 = "shared123" + for uid in ["pic_001", "pic_002"]: + record = AttachmentRecord( + uid=uid, + scope_key="group:10001", + kind="image", + media_type="image/png", + display_name=f"{uid}.png", + source_kind="test", + source_ref="", + local_path=None, + mime_type="image/png", + sha256=shared_sha256, + created_at="2024-01-01T00:00:00Z", + segment_data={}, + ) + attachment_registry._records[uid] = record + + # 记录 find_by_sha256 被调用的次数 + call_count = 0 + + async def mock_find_by_sha256(sha: str) -> Any: + nonlocal call_count + call_count += 1 + if sha == shared_sha256: + meme = MagicMock() + meme.description = "共享表情包" + return meme + return None + + mock_meme_store = MagicMock() + mock_meme_store.find_by_sha256 = mock_find_by_sha256 + + mock_meme_service = 
MagicMock() + mock_meme_service.enabled = True + mock_meme_service._store = mock_meme_store + + ai = MagicMock() + ai.attachment_registry = attachment_registry + ai._meme_service = mock_meme_service + + handler.ai = ai + + from Undefined.handlers import MessageHandler + + input_attachments = [ + {"uid": "pic_001", "kind": "image"}, + {"uid": "pic_002", "kind": "image"}, + ] + + result = await MessageHandler._annotate_meme_descriptions( + handler, input_attachments, "group:10001" + ) + + # 验证结果 + assert len(result) == 2 + assert result[0]["description"] == "[表情包] 共享表情包" + assert result[1]["description"] == "[表情包] 共享表情包" + + # 验证 find_by_sha256 只被调用一次(批量查询去重) + assert call_count == 1 diff --git a/tests/test_history_level.py b/tests/test_history_level.py new file mode 100644 index 0000000..59a88d4 --- /dev/null +++ b/tests/test_history_level.py @@ -0,0 +1,174 @@ +from __future__ import annotations + +import asyncio + +import pytest + +from Undefined.utils.history import MessageHistoryManager + + +@pytest.mark.asyncio +async def test_add_group_message_with_level_stores_level_field( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """测试添加带 level 的群消息会正确存储 level 字段""" + manager = MessageHistoryManager.__new__(MessageHistoryManager) + manager._message_history = {} + manager._max_records = 10000 + manager._initialized = asyncio.Event() + manager._initialized.set() + manager._group_locks = {} + + saved_data: dict[str, list[dict[str, object]]] = {} + + async def fake_save(data: list[dict[str, object]], path: str) -> None: + saved_data[path] = data + + monkeypatch.setattr(manager, "_save_history_to_file", fake_save) + + await manager.add_group_message( + group_id=20001, + sender_id=10001, + text_content="测试消息", + sender_card="测试用户", + group_name="测试群", + role="member", + title="", + level="Lv.5", + message_id=123456, + ) + + assert "20001" in manager._message_history + assert len(manager._message_history["20001"]) == 1 + + record = manager._message_history["20001"][0] + 
assert record["level"] == "Lv.5" + assert record["type"] == "group" + assert record["chat_id"] == "20001" + assert record["user_id"] == "10001" + assert record["message"] == "测试消息" + assert record["role"] == "member" + assert record["title"] == "" + assert record["message_id"] == 123456 + + +@pytest.mark.asyncio +async def test_add_group_message_without_level_stores_empty_level( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """测试添加群消息不传 level 参数时默认存储空字符串""" + manager = MessageHistoryManager.__new__(MessageHistoryManager) + manager._message_history = {} + manager._max_records = 10000 + manager._initialized = asyncio.Event() + manager._initialized.set() + manager._group_locks = {} + + saved_data: dict[str, list[dict[str, object]]] = {} + + async def fake_save(data: list[dict[str, object]], path: str) -> None: + saved_data[path] = data + + monkeypatch.setattr(manager, "_save_history_to_file", fake_save) + + await manager.add_group_message( + group_id=20001, + sender_id=10001, + text_content="测试消息", + sender_card="测试用户", + group_name="测试群", + ) + + assert "20001" in manager._message_history + record = manager._message_history["20001"][0] + assert record["level"] == "" + + +@pytest.mark.asyncio +async def test_get_recent_returns_messages_with_level_intact( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """测试 get_recent 能正确返回带 level 的消息""" + manager = MessageHistoryManager.__new__(MessageHistoryManager) + manager._message_history = { + "20001": [ + { + "type": "group", + "chat_id": "20001", + "chat_name": "测试群", + "user_id": "10001", + "display_name": "测试用户", + "role": "admin", + "title": "管理员", + "level": "Lv.10", + "timestamp": "2026-04-11 10:00:00", + "message": "第一条消息", + }, + { + "type": "group", + "chat_id": "20001", + "chat_name": "测试群", + "user_id": "10002", + "display_name": "普通用户", + "role": "member", + "title": "", + "level": "Lv.2", + "timestamp": "2026-04-11 10:01:00", + "message": "第二条消息", + }, + { + "type": "group", + "chat_id": "20001", + 
"chat_name": "测试群", + "user_id": "10003", + "display_name": "新用户", + "role": "member", + "title": "", + "level": "", + "timestamp": "2026-04-11 10:02:00", + "message": "第三条消息", + }, + ] + } + manager._max_records = 10000 + _evt = asyncio.Event() + _evt.set() + manager._initialized = _evt + + messages = manager.get_recent("20001", "group", 0, 10) + + assert len(messages) == 3 + assert messages[0]["level"] == "Lv.10" + assert messages[1]["level"] == "Lv.2" + assert messages[2]["level"] == "" + + +@pytest.mark.asyncio +async def test_add_group_message_with_empty_level_string( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """测试添加群消息显式传入空字符串 level""" + manager = MessageHistoryManager.__new__(MessageHistoryManager) + manager._message_history = {} + manager._max_records = 10000 + manager._initialized = asyncio.Event() + manager._initialized.set() + manager._group_locks = {} + + saved_data: dict[str, list[dict[str, object]]] = {} + + async def fake_save(data: list[dict[str, object]], path: str) -> None: + saved_data[path] = data + + monkeypatch.setattr(manager, "_save_history_to_file", fake_save) + + await manager.add_group_message( + group_id=20001, + sender_id=10001, + text_content="测试消息", + level="", + ) + + record = manager._message_history["20001"][0] + assert "level" in record + assert record["level"] == "" diff --git a/tests/test_message_tools_level.py b/tests/test_message_tools_level.py new file mode 100644 index 0000000..8199502 --- /dev/null +++ b/tests/test_message_tools_level.py @@ -0,0 +1,161 @@ +from __future__ import annotations + +from typing import Any + + +from Undefined.skills.toolsets.messages.get_recent_messages.handler import ( + _format_message_xml, +) + + +def test_format_message_xml_group_with_all_attributes() -> None: + """测试群消息包含 role/title/level 时 XML 全部显示""" + msg: dict[str, Any] = { + "type": "group", + "display_name": "测试用户", + "user_id": "10001", + "chat_id": "20001", + "chat_name": "测试群", + "timestamp": "2026-04-11 10:00:00", + 
"message": "测试消息", + "message_id": 123456, + "role": "admin", + "title": "管理员", + "level": "Lv.10", + } + + result = _format_message_xml(msg) + + assert 'role="admin"' in result + assert 'title="管理员"' in result + assert 'level="Lv.10"' in result + assert 'message_id="123456"' in result + assert " 测试消息 " in result + + +def test_format_message_xml_group_with_empty_level() -> None: + """测试群消息 level 为空时 XML 不包含 level 属性""" + msg: dict[str, Any] = { + "type": "group", + "display_name": "测试用户", + "user_id": "10001", + "chat_id": "20001", + "chat_name": "测试群", + "timestamp": "2026-04-11 10:00:00", + "message": "测试消息", + "role": "member", + "title": "", + "level": "", + } + + result = _format_message_xml(msg) + + assert "level=" not in result + assert 'role="member"' in result + assert "title=" not in result + + +def test_format_message_xml_private_without_level() -> None: + """测试私聊消息不包含 role/title/level 属性""" + msg: dict[str, Any] = { + "type": "private", + "display_name": "测试用户", + "user_id": "10001", + "chat_id": "10001", + "chat_name": "QQ用户10001", + "timestamp": "2026-04-11 10:00:00", + "message": "私聊消息", + } + + result = _format_message_xml(msg) + + assert "role=" not in result + assert "title=" not in result + assert "level=" not in result + assert "私聊消息 " in result + + +def test_format_message_xml_group_with_only_level() -> None: + """测试群消息只设置 level 时仅显示 level 属性""" + msg: dict[str, Any] = { + "type": "group", + "display_name": "测试用户", + "user_id": "10001", + "chat_id": "20001", + "chat_name": "测试群", + "timestamp": "2026-04-11 10:00:00", + "message": "测试消息", + "role": "", + "title": "", + "level": "Lv.5", + } + + result = _format_message_xml(msg) + + assert 'level="Lv.5"' in result + assert "role=" not in result + assert "title=" not in result + + +def test_format_message_xml_group_without_level_key() -> None: + """测试群消息没有 level 键时不显示 level 属性""" + msg: dict[str, Any] = { + "type": "group", + "display_name": "测试用户", + "user_id": "10001", + "chat_id": "20001", + 
"chat_name": "测试群", + "timestamp": "2026-04-11 10:00:00", + "message": "测试消息", + "role": "member", + "title": "", + } + + result = _format_message_xml(msg) + + assert "level=" not in result + assert 'role="member"' in result + + +def test_format_message_xml_group_with_role_and_title() -> None: + """测试群消息有 role 和 title 但无 level""" + msg: dict[str, Any] = { + "type": "group", + "display_name": "测试用户", + "user_id": "10001", + "chat_id": "20001", + "chat_name": "测试群", + "timestamp": "2026-04-11 10:00:00", + "message": "测试消息", + "role": "admin", + "title": "群主", + "level": "", + } + + result = _format_message_xml(msg) + + assert 'role="admin"' in result + assert 'title="群主"' in result + assert "level=" not in result + + +def test_format_message_xml_private_with_level_ignored() -> None: + """测试私聊消息即使有 level 也不显示""" + msg: dict[str, Any] = { + "type": "private", + "display_name": "测试用户", + "user_id": "10001", + "chat_id": "10001", + "chat_name": "QQ用户10001", + "timestamp": "2026-04-11 10:00:00", + "message": "私聊消息", + "role": "member", + "title": "测试", + "level": "Lv.99", + } + + result = _format_message_xml(msg) + + assert "role=" not in result + assert "title=" not in result + assert "level=" not in result diff --git a/tests/test_profile_command.py b/tests/test_profile_command.py new file mode 100644 index 0000000..c4fd024 --- /dev/null +++ b/tests/test_profile_command.py @@ -0,0 +1,252 @@ +from __future__ import annotations + +from types import SimpleNamespace +from typing import Any, cast +from unittest.mock import AsyncMock + +import pytest + +from Undefined.services.commands.context import CommandContext +from Undefined.skills.commands.profile.handler import execute as profile_execute + + +class _DummySender: + def __init__(self) -> None: + self.group_messages: list[tuple[int, str]] = [] + self.private_messages: list[tuple[int, str]] = [] + + async def send_group_message( + self, group_id: int, message: str, mark_sent: bool = False + ) -> None: + _ = mark_sent + 
self.group_messages.append((group_id, message)) + + async def send_private_message( + self, + user_id: int, + message: str, + auto_history: bool = True, + *, + mark_sent: bool = True, + ) -> None: + _ = (auto_history, mark_sent) + self.private_messages.append((user_id, message)) + + +def _build_context( + *, + sender: _DummySender | None = None, + cognitive_service: Any = None, + scope: str = "group", + group_id: int = 123456, + sender_id: int = 10002, + user_id: int | None = None, +) -> CommandContext: + stub = cast(Any, SimpleNamespace()) + if sender is None: + sender = _DummySender() + return CommandContext( + group_id=group_id, + sender_id=sender_id, + config=stub, + sender=cast(Any, sender), + ai=stub, + faq_storage=stub, + onebot=stub, + security=stub, + queue_manager=None, + rate_limiter=None, + dispatcher=stub, + registry=stub, + scope=scope, + user_id=user_id, + cognitive_service=cognitive_service, + ) + + +# -- Private chat tests -- + + +@pytest.mark.asyncio +async def test_profile_private_own_profile_found() -> None: + """Private chat, own profile found → sends profile via send_private_message.""" + sender = _DummySender() + cognitive_service = AsyncMock() + cognitive_service.get_profile = AsyncMock(return_value="这是一个用户侧写") + + context = _build_context( + sender=sender, + cognitive_service=cognitive_service, + scope="private", + group_id=0, + sender_id=99999, + user_id=99999, + ) + + await profile_execute([], context) + + assert len(sender.private_messages) == 1 + assert sender.private_messages[0][0] == 99999 + assert "这是一个用户侧写" in sender.private_messages[0][1] + cognitive_service.get_profile.assert_called_once_with("users", "99999") + + +@pytest.mark.asyncio +async def test_profile_private_own_profile_not_found() -> None: + """Private chat, own profile not found → sends '暂无侧写数据'.""" + sender = _DummySender() + cognitive_service = AsyncMock() + cognitive_service.get_profile = AsyncMock(return_value="") + + context = _build_context( + sender=sender, + 
cognitive_service=cognitive_service, + scope="private", + group_id=0, + sender_id=88888, + user_id=88888, + ) + + await profile_execute([], context) + + assert len(sender.private_messages) == 1 + assert "📭 暂无侧写数据" in sender.private_messages[0][1] + + +@pytest.mark.asyncio +async def test_profile_private_group_subcommand_rejected() -> None: + """Private chat, `/profile group` rejected → sends error message.""" + sender = _DummySender() + cognitive_service = AsyncMock() + + context = _build_context( + sender=sender, + cognitive_service=cognitive_service, + scope="private", + group_id=0, + sender_id=77777, + user_id=77777, + ) + + await profile_execute(["group"], context) + + assert len(sender.private_messages) == 1 + assert "❌ 私聊中不支持查看群聊侧写" in sender.private_messages[0][1] + cognitive_service.get_profile.assert_not_called() + + +# -- Group chat tests -- + + +@pytest.mark.asyncio +async def test_profile_group_own_profile() -> None: + """Group chat, own profile → sends profile via send_group_message.""" + sender = _DummySender() + cognitive_service = AsyncMock() + cognitive_service.get_profile = AsyncMock(return_value="群成员侧写数据") + + context = _build_context( + sender=sender, + cognitive_service=cognitive_service, + scope="group", + group_id=123456, + sender_id=55555, + ) + + await profile_execute([], context) + + assert len(sender.group_messages) == 1 + assert sender.group_messages[0][0] == 123456 + assert "群成员侧写数据" in sender.group_messages[0][1] + cognitive_service.get_profile.assert_called_once_with("users", "55555") + + +@pytest.mark.asyncio +async def test_profile_group_profile_subcommand() -> None: + """Group chat, `/profile group` → sends group profile via send_group_message.""" + sender = _DummySender() + cognitive_service = AsyncMock() + cognitive_service.get_profile = AsyncMock(return_value="群聊整体侧写") + + context = _build_context( + sender=sender, + cognitive_service=cognitive_service, + scope="group", + group_id=654321, + sender_id=44444, + ) + + await 
profile_execute(["GROUP"], context) # Test case-insensitive + + assert len(sender.group_messages) == 1 + assert sender.group_messages[0][0] == 654321 + assert "群聊整体侧写" in sender.group_messages[0][1] + cognitive_service.get_profile.assert_called_once_with("groups", "654321") + + +@pytest.mark.asyncio +async def test_profile_group_profile_not_found() -> None: + """Group chat, group profile not found → sends '暂无群聊侧写数据'.""" + sender = _DummySender() + cognitive_service = AsyncMock() + cognitive_service.get_profile = AsyncMock(return_value=None) + + context = _build_context( + sender=sender, + cognitive_service=cognitive_service, + scope="group", + group_id=111111, + sender_id=33333, + ) + + await profile_execute(["group"], context) + + assert len(sender.group_messages) == 1 + assert "📭 暂无群聊侧写数据" in sender.group_messages[0][1] + + +# -- Edge cases -- + + +@pytest.mark.asyncio +async def test_profile_no_cognitive_service() -> None: + """No cognitive_service → sends '侧写服务未启用'.""" + sender = _DummySender() + + context = _build_context( + sender=sender, + cognitive_service=None, + scope="group", + group_id=123456, + sender_id=22222, + ) + + await profile_execute([], context) + + assert len(sender.group_messages) == 1 + assert "❌ 侧写服务未启用" in sender.group_messages[0][1] + + +@pytest.mark.asyncio +async def test_profile_truncation() -> None: + """Profile > 3000 chars gets truncated.""" + sender = _DummySender() + cognitive_service = AsyncMock() + long_profile = "A" * 3500 # Longer than 3000 chars + cognitive_service.get_profile = AsyncMock(return_value=long_profile) + + context = _build_context( + sender=sender, + cognitive_service=cognitive_service, + scope="group", + group_id=123456, + sender_id=11111, + ) + + await profile_execute([], context) + + assert len(sender.group_messages) == 1 + message = sender.group_messages[0][1] + assert len(message) <= 3100 # 3000 + truncation notice + assert "[侧写过长,已截断]" in message + assert message.count("A") == 3000 # Exactly 3000 'A's 
before truncation diff --git a/tests/test_prompts_level.py b/tests/test_prompts_level.py new file mode 100644 index 0000000..239786a --- /dev/null +++ b/tests/test_prompts_level.py @@ -0,0 +1,285 @@ +from __future__ import annotations + +from types import SimpleNamespace +from typing import Any, cast + +import pytest + +from Undefined.ai.prompts import PromptBuilder +from Undefined.end_summary_storage import EndSummaryRecord + + +class _FakeEndSummaryStorage: + async def load(self) -> list[EndSummaryRecord]: + return [] + + +class _FakeCognitiveService: + enabled = False + + +class _FakeAnthropicSkillRegistry: + def has_skills(self) -> bool: + return False + + +def _make_builder() -> PromptBuilder: + """创建用于测试的 PromptBuilder 实例""" + runtime_config = SimpleNamespace( + keyword_reply_enabled=False, + knowledge_enabled=False, + grok_search_enabled=False, + chat_model=SimpleNamespace( + model_name="gpt-4.1", + pool=SimpleNamespace(enabled=False), + thinking_enabled=False, + reasoning_enabled=False, + ), + vision_model=SimpleNamespace(model_name="gpt-4.1-mini"), + agent_model=SimpleNamespace(model_name="gpt-4.1-mini"), + embedding_model=SimpleNamespace(model_name="text-embedding-3-small"), + security_model=SimpleNamespace(model_name="gpt-4.1-mini"), + grok_model=SimpleNamespace(model_name="grok-4-search"), + cognitive=SimpleNamespace(enabled=False, recent_end_summaries_inject_k=0), + memes=SimpleNamespace( + enabled=False, + query_default_mode="hybrid", + allow_gif=False, + max_source_image_bytes=512000, + ), + ) + return PromptBuilder( + bot_qq=123456, + memory_storage=None, + end_summary_storage=cast(Any, _FakeEndSummaryStorage()), + runtime_config_getter=lambda: runtime_config, + anthropic_skill_registry=cast(Any, _FakeAnthropicSkillRegistry()), + cognitive_service=cast(Any, _FakeCognitiveService()), + ) + + +@pytest.mark.asyncio +async def test_group_message_with_level_includes_level_attribute( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """测试群消息有 level 时 XML 
包含 level 属性""" + builder = _make_builder() + + async def _fake_load_system_prompt() -> str: + return "系统提示词" + + async def _fake_load_each_rules() -> str: + return "规则" + + monkeypatch.setattr(builder, "_load_system_prompt", _fake_load_system_prompt) + monkeypatch.setattr(builder, "_load_each_rules", _fake_load_each_rules) + + async def _fake_recent_messages( + chat_id: str, msg_type: str, start: int, end: int + ) -> list[dict[str, Any]]: + _ = chat_id, msg_type, start, end + return [ + { + "type": "group", + "display_name": "测试用户", + "user_id": "10001", + "chat_id": "20001", + "chat_name": "测试群", + "timestamp": "2026-04-11 10:00:00", + "message": "测试消息", + "attachments": [], + "role": "member", + "title": "", + "level": "Lv.5", + } + ] + + messages = await builder.build_messages( + '测试 ', + get_recent_messages_callback=_fake_recent_messages, + extra_context={ + "group_id": 20001, + "sender_id": 10001, + "sender_name": "测试用户", + "group_name": "测试群", + "request_type": "group", + }, + ) + + history_message = next( + str(msg.get("content", "")) + for msg in messages + if "【历史消息存档】" in str(msg.get("content", "")) + ) + + assert 'level="Lv.5"' in history_message + assert "None: + """测试群消息 level 为空字符串时 XML 不包含 level 属性""" + builder = _make_builder() + + async def _fake_load_system_prompt() -> str: + return "系统提示词" + + async def _fake_load_each_rules() -> str: + return "规则" + + monkeypatch.setattr(builder, "_load_system_prompt", _fake_load_system_prompt) + monkeypatch.setattr(builder, "_load_each_rules", _fake_load_each_rules) + + async def _fake_recent_messages( + chat_id: str, msg_type: str, start: int, end: int + ) -> list[dict[str, Any]]: + _ = chat_id, msg_type, start, end + return [ + { + "type": "group", + "display_name": "测试用户", + "user_id": "10001", + "chat_id": "20001", + "chat_name": "测试群", + "timestamp": "2026-04-11 10:00:00", + "message": "测试消息", + "attachments": [], + "role": "member", + "title": "", + "level": "", + } + ] + + messages = await 
builder.build_messages( + ' 测试 ', + get_recent_messages_callback=_fake_recent_messages, + extra_context={ + "group_id": 20001, + "sender_id": 10001, + "sender_name": "测试用户", + "group_name": "测试群", + "request_type": "group", + }, + ) + + history_message = next( + str(msg.get("content", "")) + for msg in messages + if "【历史消息存档】" in str(msg.get("content", "")) + ) + + assert "level=" not in history_message + assert "None: + """测试群消息没有 level 键时 XML 不包含 level 属性""" + builder = _make_builder() + + async def _fake_load_system_prompt() -> str: + return "系统提示词" + + async def _fake_load_each_rules() -> str: + return "规则" + + monkeypatch.setattr(builder, "_load_system_prompt", _fake_load_system_prompt) + monkeypatch.setattr(builder, "_load_each_rules", _fake_load_each_rules) + + async def _fake_recent_messages( + chat_id: str, msg_type: str, start: int, end: int + ) -> list[dict[str, Any]]: + _ = chat_id, msg_type, start, end + return [ + { + "type": "group", + "display_name": "测试用户", + "user_id": "10001", + "chat_id": "20001", + "chat_name": "测试群", + "timestamp": "2026-04-11 10:00:00", + "message": "测试消息", + "attachments": [], + "role": "member", + "title": "", + } + ] + + messages = await builder.build_messages( + ' 测试 ', + get_recent_messages_callback=_fake_recent_messages, + extra_context={ + "group_id": 20001, + "sender_id": 10001, + "sender_name": "测试用户", + "group_name": "测试群", + "request_type": "group", + }, + ) + + history_message = next( + str(msg.get("content", "")) + for msg in messages + if "【历史消息存档】" in str(msg.get("content", "")) + ) + + assert "level=" not in history_message + assert "None: + """测试私聊消息无论是否有 level 都不会出现 level 属性""" + builder = _make_builder() + + async def _fake_load_system_prompt() -> str: + return "系统提示词" + + async def _fake_load_each_rules() -> str: + return "规则" + + monkeypatch.setattr(builder, "_load_system_prompt", _fake_load_system_prompt) + monkeypatch.setattr(builder, "_load_each_rules", _fake_load_each_rules) + + async def 
_fake_recent_messages( + chat_id: str, msg_type: str, start: int, end: int + ) -> list[dict[str, Any]]: + _ = chat_id, msg_type, start, end + return [ + { + "type": "private", + "display_name": "测试用户", + "user_id": "10001", + "chat_id": "10001", + "chat_name": "QQ用户10001", + "timestamp": "2026-04-11 10:00:00", + "message": "私聊测试消息", + "attachments": [], + "level": "Lv.10", + } + ] + + messages = await builder.build_messages( + ' 测试 ', + get_recent_messages_callback=_fake_recent_messages, + extra_context={ + "sender_id": 10001, + "sender_name": "测试用户", + "request_type": "private", + }, + ) + + history_message = next( + str(msg.get("content", "")) + for msg in messages + if "【历史消息存档】" in str(msg.get("content", "")) + ) + + assert "level=" not in history_message + assert "None: + self.registered_items: list[dict[str, Any]] = [] + + async def register_bytes( + self, + scope_key: str, + data: bytes, + kind: str, + display_name: str, + mime_type: str, + source_kind: str, + source_ref: str, + ) -> Any: + class MockRecord: + uid = "test-uid-12345" + + record = MockRecord() + self.registered_items.append( + { + "scope_key": scope_key, + "size": len(data), + "kind": kind, + "display_name": display_name, + "mime_type": mime_type, + "source_kind": source_kind, + "source_ref": source_ref, + "uid": record.uid, + } + ) + return record -_PNG_HEADER = b"\x89PNG\r\n\x1a\n" +@pytest.mark.asyncio +async def test_render_simple_equation() -> None: + """测试渲染简单方程(无分隔符,自动包装)""" + from Undefined.skills.toolsets.render.render_latex.handler import execute -def _build_context(registry: AttachmentRegistry) -> dict[str, Any]: - return { + mock_registry = MockAttachmentRegistry() + context = { + "attachment_registry": mock_registry, "request_type": "group", - "group_id": 10001, - "sender_id": 20002, - "user_id": 20002, - "attachment_registry": registry, + "group_id": 123456, } + args = {"content": "E = mc^2", "output_format": "png"} -def test_strip_document_wrappers_removes_document_env() -> 
None: - content = ( - "\\begin{document}\n" - "\\[\n" - "\\int_{-\\infty}^{+\\infty} e^{-x^2} dx = \\sqrt{\\pi}\n" - "\\]\n" - "\\end{document}" - ) - assert handler._strip_document_wrappers(content) == ( - "\\[\n\\int_{-\\infty}^{+\\infty} e^{-x^2} dx = \\sqrt{\\pi}\n\\]" - ) + try: + result = await execute(args, context) + assert result == ' ' + assert len(mock_registry.registered_items) == 1 + assert mock_registry.registered_items[0]["kind"] == "image" + assert mock_registry.registered_items[0]["mime_type"] == "image/png" + assert mock_registry.registered_items[0]["size"] > 0 + except ImportError as e: + if "playwright" in str(e).lower(): + pytest.skip("Playwright 未安装,跳过测试") + raise -def test_strip_document_wrappers_passthrough_for_plain_formula() -> None: - content = r"\[ E = mc^2 \]" - assert handler._strip_document_wrappers(content) == content +@pytest.mark.asyncio +async def test_render_with_delimiters() -> None: + """测试带分隔符的内容(不自动包装)""" + from Undefined.skills.toolsets.render.render_latex.handler import execute + + mock_registry = MockAttachmentRegistry() + context = { + "attachment_registry": mock_registry, + "request_type": "private", + "user_id": 987654, + } + + args = {"content": r"\[ \int_0^\infty e^{-x^2} dx = \frac{\sqrt{\pi}}{2} \]"} + + try: + result = await execute(args, context) + assert result == ' ' + assert len(mock_registry.registered_items) == 1 + except ImportError as e: + if "playwright" in str(e).lower(): + pytest.skip("Playwright 未安装,跳过测试") + raise @pytest.mark.asyncio -async def test_render_latex_embed_success( - tmp_path: Path, - monkeypatch: pytest.MonkeyPatch, -) -> None: - registry = AttachmentRegistry( - registry_path=tmp_path / "attachment_registry.json", - cache_dir=tmp_path / "attachments", - ) - content = r"\[ \int_{-\infty}^{+\infty} e^{-x^2} dx = \sqrt{\pi} \]" - rendered_contents: list[str] = [] +async def test_render_pdf_output() -> None: + """测试 PDF 输出格式使用 attachment 标签""" + from 
Undefined.skills.toolsets.render.render_latex.handler import execute - monkeypatch.setattr(paths, "RENDER_CACHE_DIR", tmp_path / "render") + mock_registry = MockAttachmentRegistry() + context = { + "attachment_registry": mock_registry, + "request_type": "group", + "group_id": 123456, + } - def _fake_render(filepath: Path, render_content: str) -> None: - rendered_contents.append(render_content) - filepath.parent.mkdir(parents=True, exist_ok=True) - filepath.write_bytes(_PNG_HEADER) + args = {"content": r"\frac{a}{b} + \sqrt{c}", "output_format": "pdf"} - monkeypatch.setattr(handler, "_render_latex_image", _fake_render) + try: + result = await execute(args, context) + assert result == ' ' + assert len(mock_registry.registered_items) == 1 + assert mock_registry.registered_items[0]["kind"] == "file" + assert mock_registry.registered_items[0]["mime_type"] == "application/pdf" + assert mock_registry.registered_items[0]["display_name"] == "latex.pdf" + except ImportError as e: + if "playwright" in str(e).lower(): + pytest.skip("Playwright 未安装,跳过测试") + raise - result = await handler.execute( - {"content": content, "delivery": "embed"}, - _build_context(registry), - ) - assert result.startswith(' None: + """测试空内容错误处理""" + from Undefined.skills.toolsets.render.render_latex.handler import execute + + context = {"attachment_registry": MockAttachmentRegistry()} + + args = {"content": " "} + + result = await execute(args, context) + assert "不能为空" in result @pytest.mark.asyncio -async def test_render_latex_returns_helpful_message_when_tex_missing( - tmp_path: Path, - monkeypatch: pytest.MonkeyPatch, -) -> None: - registry = AttachmentRegistry( - registry_path=tmp_path / "attachment_registry.json", - cache_dir=tmp_path / "attachments", +async def test_invalid_output_format() -> None: + """测试无效输出格式""" + from Undefined.skills.toolsets.render.render_latex.handler import execute + + context = {"attachment_registry": MockAttachmentRegistry()} + + args = {"content": "x = 1", 
"output_format": "svg"} + + result = await execute(args, context) + assert "无效" in result or "仅支持" in result + + +def test_strip_document_wrappers() -> None: + """测试去除 document 包装""" + from Undefined.skills.toolsets.render.render_latex.handler import ( + _strip_document_wrappers, ) - content = r"\[ a = b \]" - monkeypatch.setattr(paths, "RENDER_CACHE_DIR", tmp_path / "render") + content = r"\begin{document}E = mc^2\end{document}" + result = _strip_document_wrappers(content) + assert result == "E = mc^2" - def _raise_runtime(_: Path, __: str) -> None: - raise RuntimeError("latex was not able to process the following string") + # 没有包装的内容应该原样返回 + content_no_wrapper = r"E = mc^2" + result_no_wrapper = _strip_document_wrappers(content_no_wrapper) + assert result_no_wrapper == "E = mc^2" - monkeypatch.setattr(handler, "_render_latex_image", _raise_runtime) - result = await handler.execute( - {"content": content, "delivery": "embed"}, - _build_context(registry), +def test_has_math_delimiters() -> None: + """测试数学分隔符检测""" + from Undefined.skills.toolsets.render.render_latex.handler import ( + _has_math_delimiters, ) - assert "TeX Live" in result or "MiKTeX" in result + assert _has_math_delimiters(r"\[ x = 1 \]") is True + assert _has_math_delimiters(r"\( x = 1 \)") is True + assert _has_math_delimiters(r"$$ x = 1 $$") is True + assert _has_math_delimiters(r"\begin{equation}") is True + assert _has_math_delimiters("x = 1") is False + + +def test_prepare_content() -> None: + """测试内容准备逻辑""" + from Undefined.skills.toolsets.render.render_latex.handler import _prepare_content + + # 无分隔符,自动包装 + result = _prepare_content("E = mc^2") + assert result.startswith(r"\[") + assert result.endswith(r"\]") + assert "E = mc^2" in result + + # 有分隔符,不包装 + result_with_delim = _prepare_content(r"\[ E = mc^2 \]") + assert result_with_delim == r"\[ E = mc^2 \]" + + # 字面量 \\n 处理 + result_newline = _prepare_content(r"x = 1\\ny = 2") + assert "\n" in result_newline + assert "\\n" not in 
result_newline.replace(r"\[", "").replace(r"\]", "") diff --git a/tests/test_summary_agent.py b/tests/test_summary_agent.py new file mode 100644 index 0000000..72fe659 --- /dev/null +++ b/tests/test_summary_agent.py @@ -0,0 +1,148 @@ +from __future__ import annotations + +from pathlib import Path +from typing import Any +from unittest.mock import AsyncMock, patch + +import pytest + +from Undefined.skills.agents.summary_agent.handler import ( + execute as summary_agent_execute, +) + + +@pytest.mark.asyncio +async def test_summary_agent_normal_execution() -> None: + """Normal execution → calls run_agent_with_tools with correct params.""" + context: dict[str, Any] = { + "ai_client": AsyncMock(), + "history_manager": AsyncMock(), + "group_id": 123456, + "user_id": 10002, + "sender_id": 10002, + "request_type": "group", + "runtime_config": None, + "queue_lane": None, + } + + with patch( + "Undefined.skills.agents.summary_agent.handler.run_agent_with_tools", + new=AsyncMock(return_value="总结结果:讨论了技术话题"), + ) as mock_run_agent: + result = await summary_agent_execute( + {"prompt": "请总结最近 50 条消息"}, + context, + ) + + assert result == "总结结果:讨论了技术话题" + mock_run_agent.assert_called_once() + + call_kwargs = mock_run_agent.call_args.kwargs + assert call_kwargs["agent_name"] == "summary_agent" + assert call_kwargs["user_content"] == "请总结最近 50 条消息" + assert call_kwargs["empty_user_content_message"] == "请提供您的总结需求" + assert "消息总结助手" in call_kwargs["default_prompt"] + assert call_kwargs["context"] is context + assert isinstance(call_kwargs["agent_dir"], Path) + assert call_kwargs["max_iterations"] == 10 + assert call_kwargs["tool_error_prefix"] == "错误" + + +@pytest.mark.asyncio +async def test_summary_agent_empty_prompt() -> None: + """Empty prompt → returns '请提供您的总结需求'.""" + context: dict[str, Any] = { + "ai_client": AsyncMock(), + "history_manager": AsyncMock(), + } + + with patch( + "Undefined.skills.agents.summary_agent.handler.run_agent_with_tools", + 
new=AsyncMock(return_value="请提供您的总结需求"), + ) as mock_run_agent: + result = await summary_agent_execute({"prompt": ""}, context) + + assert result == "请提供您的总结需求" + mock_run_agent.assert_called_once() + call_kwargs = mock_run_agent.call_args.kwargs + assert call_kwargs["user_content"] == "" + + +@pytest.mark.asyncio +async def test_summary_agent_whitespace_prompt() -> None: + """Whitespace-only prompt → treated as empty.""" + context: dict[str, Any] = { + "ai_client": AsyncMock(), + "history_manager": AsyncMock(), + } + + with patch( + "Undefined.skills.agents.summary_agent.handler.run_agent_with_tools", + new=AsyncMock(return_value="请提供您的总结需求"), + ) as mock_run_agent: + result = await summary_agent_execute({"prompt": " "}, context) + + assert result == "请提供您的总结需求" + call_kwargs = mock_run_agent.call_args.kwargs + assert call_kwargs["user_content"] == "" + + +@pytest.mark.asyncio +async def test_summary_agent_missing_prompt_arg() -> None: + """Missing 'prompt' arg → defaults to empty string.""" + context: dict[str, Any] = { + "ai_client": AsyncMock(), + "history_manager": AsyncMock(), + } + + with patch( + "Undefined.skills.agents.summary_agent.handler.run_agent_with_tools", + new=AsyncMock(return_value="请提供您的总结需求"), + ) as mock_run_agent: + result = await summary_agent_execute({}, context) + + assert result == "请提供您的总结需求" + call_kwargs = mock_run_agent.call_args.kwargs + assert call_kwargs["user_content"] == "" + + +@pytest.mark.asyncio +async def test_summary_agent_complex_prompt() -> None: + """Complex prompt with time range and custom instructions.""" + context: dict[str, Any] = { + "ai_client": AsyncMock(), + "history_manager": AsyncMock(), + "group_id": 654321, + "user_id": 99999, + } + + with patch( + "Undefined.skills.agents.summary_agent.handler.run_agent_with_tools", + new=AsyncMock(return_value="详细总结内容"), + ) as mock_run_agent: + result = await summary_agent_execute( + {"prompt": "请总结过去 1d 内的聊天消息,重点关注:技术讨论"}, + context, + ) + + assert result == "详细总结内容" + 
call_kwargs = mock_run_agent.call_args.kwargs + assert ( + call_kwargs["user_content"] == "请总结过去 1d 内的聊天消息,重点关注:技术讨论" + ) + + +@pytest.mark.asyncio +async def test_summary_agent_propagates_exception() -> None: + """Exception from run_agent_with_tools propagates up.""" + context: dict[str, Any] = { + "ai_client": AsyncMock(), + "history_manager": AsyncMock(), + } + + with patch( + "Undefined.skills.agents.summary_agent.handler.run_agent_with_tools", + new=AsyncMock(side_effect=RuntimeError("Agent failure")), + ): + with pytest.raises(RuntimeError, match="Agent failure"): + await summary_agent_execute({"prompt": "test"}, context) diff --git a/tests/test_summary_command.py b/tests/test_summary_command.py new file mode 100644 index 0000000..7ce7550 --- /dev/null +++ b/tests/test_summary_command.py @@ -0,0 +1,341 @@ +from __future__ import annotations + +from types import SimpleNamespace +from typing import Any, cast +from unittest.mock import AsyncMock, patch + +import pytest + +from Undefined.services.commands.context import CommandContext +from Undefined.skills.commands.summary.handler import ( + _build_prompt, + _parse_args, + execute as summary_execute, +) + + +class _DummySender: + def __init__(self) -> None: + self.group_messages: list[tuple[int, str]] = [] + self.private_messages: list[tuple[int, str]] = [] + + async def send_group_message( + self, group_id: int, message: str, mark_sent: bool = False + ) -> None: + _ = mark_sent + self.group_messages.append((group_id, message)) + + async def send_private_message( + self, + user_id: int, + message: str, + auto_history: bool = True, + *, + mark_sent: bool = True, + ) -> None: + _ = (auto_history, mark_sent) + self.private_messages.append((user_id, message)) + + +def _build_context( + *, + sender: _DummySender | None = None, + history_manager: Any = None, + scope: str = "group", + group_id: int = 123456, + sender_id: int = 10002, + user_id: int | None = None, + ai: Any = None, +) -> CommandContext: + stub = 
cast(Any, SimpleNamespace()) + if sender is None: + sender = _DummySender() + if ai is None: + ai = stub + return CommandContext( + group_id=group_id, + sender_id=sender_id, + config=stub, + sender=cast(Any, sender), + ai=ai, + faq_storage=stub, + onebot=stub, + security=stub, + queue_manager=None, + rate_limiter=None, + dispatcher=stub, + registry=stub, + scope=scope, + user_id=user_id, + history_manager=history_manager, + ) + + +# -- _parse_args unit tests -- + + +def test_parse_args_empty() -> None: + """Empty args → (50, None, '').""" + count, time_range, custom_prompt = _parse_args([]) + assert count == 50 + assert time_range is None + assert custom_prompt == "" + + +def test_parse_args_count_only() -> None: + """['100'] → (100, None, '').""" + count, time_range, custom_prompt = _parse_args(["100"]) + assert count == 100 + assert time_range is None + assert custom_prompt == "" + + +def test_parse_args_time_range_only() -> None: + """['1d'] → (None, '1d', '').""" + count, time_range, custom_prompt = _parse_args(["1d"]) + assert count is None + assert time_range == "1d" + assert custom_prompt == "" + + +def test_parse_args_count_with_custom_prompt() -> None: + """['100', '技术讨论'] → (100, None, '技术讨论').""" + count, time_range, custom_prompt = _parse_args(["100", "技术讨论"]) + assert count == 100 + assert time_range is None + assert custom_prompt == "技术讨论" + + +def test_parse_args_time_range_with_custom_prompt() -> None: + """['1d', '总结技术'] → (None, '1d', '总结技术').""" + count, time_range, custom_prompt = _parse_args(["1d", "总结技术"]) + assert count is None + assert time_range == "1d" + assert custom_prompt == "总结技术" + + +def test_parse_args_custom_prompt_only() -> None: + """['技术讨论'] → (50, None, '技术讨论').""" + count, time_range, custom_prompt = _parse_args(["技术讨论"]) + assert count == 50 + assert time_range is None + assert custom_prompt == "技术讨论" + + +def test_parse_args_count_capped_at_500() -> None: + """['999'] → (500, None, '') (capped).""" + count, time_range, 
custom_prompt = _parse_args(["999"]) + assert count == 500 + assert time_range is None + assert custom_prompt == "" + + +def test_parse_args_multiple_words_prompt() -> None: + """['技术', '讨论', '总结'] → (50, None, '技术 讨论 总结').""" + count, time_range, custom_prompt = _parse_args(["技术", "讨论", "总结"]) + assert count == 50 + assert time_range is None + assert custom_prompt == "技术 讨论 总结" + + +# -- _build_prompt unit tests -- + + +def test_build_prompt_with_count() -> None: + """With count → '请总结最近 X 条聊天消息'.""" + prompt = _build_prompt(100, None, "") + assert "请总结最近 100 条聊天消息" in prompt + + +def test_build_prompt_with_time_range() -> None: + """With time_range → '请总结过去 X 内的聊天消息'.""" + prompt = _build_prompt(None, "1d", "") + assert "请总结过去 1d 内的聊天消息" in prompt + + +def test_build_prompt_with_custom_prompt() -> None: + """With custom_prompt → adds '重点关注:...'.""" + prompt = _build_prompt(50, None, "技术讨论") + assert "请总结最近 50 条聊天消息" in prompt + assert "重点关注:技术讨论" in prompt + + +def test_build_prompt_default_count() -> None: + """Default count when both are None.""" + prompt = _build_prompt(None, None, "") + assert "请总结最近 50 条聊天消息" in prompt + + +def test_build_prompt_time_range_and_custom() -> None: + """Time range with custom prompt.""" + prompt = _build_prompt(None, "6h", "重要公告") + assert "请总结过去 6h 内的聊天消息" in prompt + assert "重点关注:重要公告" in prompt + + +# -- Command execution tests -- + + +@pytest.mark.asyncio +async def test_summary_no_history_manager() -> None: + """No history_manager → sends error message.""" + sender = _DummySender() + context = _build_context( + sender=sender, + history_manager=None, + scope="group", + group_id=123456, + sender_id=10002, + ) + + await summary_execute([], context) + + assert len(sender.group_messages) == 1 + assert "❌ 历史记录管理器未配置" in sender.group_messages[0][1] + + +@pytest.mark.asyncio +async def test_summary_agent_call_success() -> None: + """Agent call success → result forwarded to user.""" + sender = _DummySender() + history_manager = 
AsyncMock() + ai = AsyncMock() + ai.runtime_config = None + + context = _build_context( + sender=sender, + history_manager=history_manager, + ai=ai, + scope="group", + group_id=123456, + sender_id=10002, + ) + + with patch( + "Undefined.skills.agents.summary_agent.handler.execute", + new=AsyncMock(return_value="总结内容:最近讨论了技术话题。"), + ) as mock_agent: + await summary_execute(["50"], context) + + assert len(sender.group_messages) == 2 + assert "📝 正在总结消息,请稍候..." in sender.group_messages[0][1] + assert "总结内容:最近讨论了技术话题。" in sender.group_messages[1][1] + mock_agent.assert_called_once() + call_args = mock_agent.call_args + assert call_args[0][0]["prompt"] == "请总结最近 50 条聊天消息" + + +@pytest.mark.asyncio +async def test_summary_agent_call_failure() -> None: + """Agent call failure → sends error message.""" + sender = _DummySender() + history_manager = AsyncMock() + ai = AsyncMock() + + context = _build_context( + sender=sender, + history_manager=history_manager, + ai=ai, + scope="group", + group_id=123456, + sender_id=10002, + ) + + with patch( + "Undefined.skills.agents.summary_agent.handler.execute", + new=AsyncMock(side_effect=Exception("Agent error")), + ): + await summary_execute([], context) + + assert len(sender.group_messages) == 2 + assert "📝 正在总结消息,请稍候..." 
in sender.group_messages[0][1] + assert "❌ 消息总结失败,请稍后重试" in sender.group_messages[1][1] + + +@pytest.mark.asyncio +async def test_summary_agent_returns_empty() -> None: + """Agent returns empty result → sends '未能生成总结内容'.""" + sender = _DummySender() + history_manager = AsyncMock() + ai = AsyncMock() + + context = _build_context( + sender=sender, + history_manager=history_manager, + ai=ai, + scope="group", + group_id=123456, + sender_id=10002, + ) + + with patch( + "Undefined.skills.agents.summary_agent.handler.execute", + new=AsyncMock(return_value=" "), + ): + await summary_execute([], context) + + assert len(sender.group_messages) == 2 + assert "📭 未能生成总结内容" in sender.group_messages[1][1] + + +@pytest.mark.asyncio +async def test_summary_private_chat() -> None: + """Private chat → uses send_private_message.""" + sender = _DummySender() + history_manager = AsyncMock() + ai = AsyncMock() + + context = _build_context( + sender=sender, + history_manager=history_manager, + ai=ai, + scope="private", + group_id=0, + sender_id=88888, + user_id=88888, + ) + + with patch( + "Undefined.skills.agents.summary_agent.handler.execute", + new=AsyncMock(return_value="私聊总结结果"), + ): + await summary_execute(["1d", "重要消息"], context) + + assert len(sender.private_messages) == 2 + assert "📝 正在总结消息,请稍候..." 
in sender.private_messages[0][1] + assert "私聊总结结果" in sender.private_messages[1][1] + + +@pytest.mark.asyncio +async def test_summary_passes_correct_context_to_agent() -> None: + """Agent receives correct context parameters.""" + sender = _DummySender() + history_manager = AsyncMock() + ai = AsyncMock() + ai.runtime_config = SimpleNamespace(some_config="value") + + context = _build_context( + sender=sender, + history_manager=history_manager, + ai=ai, + scope="group", + group_id=999888, + sender_id=777666, + user_id=None, + ) + + with patch( + "Undefined.skills.agents.summary_agent.handler.execute", + new=AsyncMock(return_value="总结"), + ) as mock_agent: + await summary_execute([], context) + + call_args = mock_agent.call_args + agent_context = call_args[0][1] + assert agent_context["ai_client"] is ai + assert agent_context["history_manager"] is history_manager + assert agent_context["group_id"] == 999888 + assert agent_context["sender_id"] == 777666 + assert agent_context["user_id"] == 777666 + assert agent_context["request_type"] == "group" + assert agent_context["runtime_config"] is ai.runtime_config From 0e735cba1268a9c36ce8daa262b20337925637f8 Mon Sep 17 00:00:00 2001 From: Null <1708213363@qq.com> Date: Sat, 18 Apr 2026 14:51:56 +0800 Subject: [PATCH 06/57] fix(render_latex): pass proxy config to Playwright for CDN access Read use_proxy/http_proxy/https_proxy from runtime_config and forward to chromium.launch(proxy=...) so MathJax CDN loads correctly on servers requiring a proxy. Also update use_proxy comment to be generic. 
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- config.toml.example | 4 +-- .../toolsets/render/render_latex/handler.py | 32 +++++++++++++++++-- 2 files changed, 31 insertions(+), 5 deletions(-) diff --git a/config.toml.example b/config.toml.example index 0e7fe2c..678dd09 100644 --- a/config.toml.example +++ b/config.toml.example @@ -716,8 +716,8 @@ grok_search_enabled = false # zh: 代理设置(可选)。 # en: Proxy settings (optional). [proxy] -# zh: 是否在 "web_agent" 中启用代理。 -# en: Enable proxy in "web_agent". +# zh: 是否使用代理。 +# en: Whether to use proxy. use_proxy = true # zh: 例如 http://127.0.0.1:7890(也可使用环境变量 "HTTP_PROXY")。 # en: e.g. http://127.0.0.1:7890 (or use the "HTTP_PROXY" environment variable). diff --git a/src/Undefined/skills/toolsets/render/render_latex/handler.py b/src/Undefined/skills/toolsets/render/render_latex/handler.py index a449cd2..01e8ba3 100644 --- a/src/Undefined/skills/toolsets/render/render_latex/handler.py +++ b/src/Undefined/skills/toolsets/render/render_latex/handler.py @@ -77,7 +77,9 @@ def _build_html(latex_content: str) -> str: """ -async def _render_latex_to_bytes(content: str, output_format: str) -> tuple[bytes, str]: +async def _render_latex_to_bytes( + content: str, output_format: str, proxy: str | None = None +) -> tuple[bytes, str]: """ 使用 MathJax + Playwright 渲染 LaTeX 内容。 @@ -95,8 +97,13 @@ async def _render_latex_to_bytes(content: str, output_format: str) -> tuple[byte html_content = _build_html(content) + launch_kwargs: dict[str, object] = {"headless": True} + if proxy: + launch_kwargs["proxy"] = {"server": proxy} + logger.info("LaTeX 渲染使用代理: %s", proxy) + async with async_playwright() as p: - browser = await p.chromium.launch(headless=True) + browser = await p.chromium.launch(**launch_kwargs) # type: ignore[arg-type] try: page = await browser.new_page() await page.set_content(html_content) @@ -143,6 +150,22 @@ async def _render_latex_to_bytes(content: str, output_format: str) -> tuple[byte await browser.close() 
+async def _resolve_proxy(context: Dict[str, Any]) -> str | None: + """从 context 的 runtime_config 中解析代理地址。""" + from Undefined.config import get_config + + runtime_config = context.get("runtime_config") or get_config(strict=False) + if runtime_config is None: + return None + use_proxy: bool = getattr(runtime_config, "use_proxy", False) + if not use_proxy: + return None + proxy: str = getattr(runtime_config, "http_proxy", "") or getattr( + runtime_config, "https_proxy", "" + ) + return proxy or None + + async def execute(args: Dict[str, Any], context: Dict[str, Any]) -> str: """渲染 LaTeX 数学公式为图片或 PDF""" raw_content = str(args.get("content", "") or "") @@ -159,9 +182,12 @@ async def execute(args: Dict[str, Any], context: Dict[str, Any]) -> str: # 准备内容 prepared_content = _prepare_content(raw_content) + # 解析代理 + proxy = await _resolve_proxy(context) + # 渲染 rendered_bytes, mime_type = await _render_latex_to_bytes( - prepared_content, output_format + prepared_content, output_format, proxy=proxy ) # 注册到附件系统 From 1571d7e60428fa749f9883c98eceadfd5fb0c696 Mon Sep 17 00:00:00 2001 From: Null <1708213363@qq.com> Date: Sat, 18 Apr 2026 15:08:29 +0800 Subject: [PATCH 07/57] =?UTF-8?q?feat(config):=20sync=5Fconfig=5Ftemplate?= =?UTF-8?q?=20=E6=8A=A5=E5=91=8A=E6=B3=A8=E9=87=8A=E5=8F=98=E6=9B=B4?= =?UTF-8?q?=E8=B7=AF=E5=BE=84=EF=BC=9Bgif=5Fanalysis=5Fmode=20=E4=B8=8B?= =?UTF-8?q?=E6=8B=89=E9=80=89=E6=8B=A9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - ConfigTemplateSyncResult 新增 updated_comment_paths 字段 - sync_config_text() 对比 current/example 注释,记录有改动的路径 - 脚本以 ~ 前缀展示注释更新项数量和路径列表 - 修复测试 mock 对象缺少新字段 - config-form.js:gif_analysis_mode 渲染为 grid/multi 下拉框 Co-authored-by: Claude Sonnet 4.6 Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- scripts/sync_config_template.py | 5 ++++ src/Undefined/webui/static/js/config-form.js | 1 + src/Undefined/webui/utils/config_sync.py | 24 +++++++++++++++++--- 
tests/test_sync_config_template_script.py | 1 + 4 files changed, 28 insertions(+), 3 deletions(-) diff --git a/scripts/sync_config_template.py b/scripts/sync_config_template.py index e432001..583b9eb 100755 --- a/scripts/sync_config_template.py +++ b/scripts/sync_config_template.py @@ -96,6 +96,11 @@ def main() -> int: for path in result.added_paths: print(f" + {path}") + if result.updated_comment_paths: + print(f"[sync-config] 注释更新数量: {len(result.updated_comment_paths)}") + for path in result.updated_comment_paths: + print(f" ~ {path}") + if result.removed_paths: print(f"[sync-config] 多余路径数量: {len(result.removed_paths)}") for path in result.removed_paths: diff --git a/src/Undefined/webui/static/js/config-form.js b/src/Undefined/webui/static/js/config-form.js index ffadfe0..111d64a 100644 --- a/src/Undefined/webui/static/js/config-form.js +++ b/src/Undefined/webui/static/js/config-form.js @@ -251,6 +251,7 @@ function isLongText(value) { const FIELD_SELECT_OPTIONS = { api_mode: ["chat_completions", "responses"], + gif_analysis_mode: ["grid", "multi"], reasoning_effort_style: ["openai", "anthropic"], // path -> options key mapping (underscore-separated segments) image_gen_provider: ["xingzhige", "models"], diff --git a/src/Undefined/webui/utils/config_sync.py b/src/Undefined/webui/utils/config_sync.py index 46c16ee..c3df03a 100644 --- a/src/Undefined/webui/utils/config_sync.py +++ b/src/Undefined/webui/utils/config_sync.py @@ -1,6 +1,7 @@ from __future__ import annotations import copy +import dataclasses import tomllib from dataclasses import dataclass from pathlib import Path @@ -21,6 +22,7 @@ class ConfigTemplateSyncResult: added_paths: list[str] removed_paths: list[str] comments: CommentMap + updated_comment_paths: list[str] = dataclasses.field(default_factory=list) def _parse_toml_text(content: str, *, label: str) -> TomlData: @@ -226,6 +228,17 @@ def _augment_pool_model_comments( return merged +def _collect_updated_comment_paths( + current: CommentMap, example: 
CommentMap +) -> list[str]: + """收集在 example 与 current 中都存在、但注释文本不同的路径。""" + return [ + key + for key, example_value in example.items() + if key in current and current[key] != example_value + ] + + def _merge_comment_maps(current: CommentMap, example: CommentMap) -> CommentMap: merged: CommentMap = dict(current) for key, value in example.items(): @@ -248,16 +261,21 @@ def sync_config_text( if prune and removed_paths: merged = _prune_to_template(merged, prepared_example_data) example_comments = parse_comment_map_text(example_text) - comments = _merge_comment_maps( - parse_comment_map_text(current_text), - _augment_pool_model_comments(example_comments, example_data, current_data), + current_comments = parse_comment_map_text(current_text) + augmented_example_comments = _augment_pool_model_comments( + example_comments, example_data, current_data + ) + updated_comment_paths = _collect_updated_comment_paths( + current_comments, augmented_example_comments ) + comments = _merge_comment_maps(current_comments, augmented_example_comments) content = render_toml(merged, comments=comments) return ConfigTemplateSyncResult( content=content, added_paths=added_paths, removed_paths=removed_paths, comments=comments, + updated_comment_paths=updated_comment_paths, ) diff --git a/tests/test_sync_config_template_script.py b/tests/test_sync_config_template_script.py index 3f14aec..e422138 100644 --- a/tests/test_sync_config_template_script.py +++ b/tests/test_sync_config_template_script.py @@ -43,6 +43,7 @@ def fake_sync_config_file( added_paths=[], removed_paths=["models.chat.extra"], comments={}, + updated_comment_paths=[], ) monkeypatch.setattr(module, "sync_config_file", fake_sync_config_file) From 987ac132fb7feabf9c4a61d91ac35b4bfa36e3f5 Mon Sep 17 00:00:00 2001 From: Null <1708213363@qq.com> Date: Sat, 18 Apr 2026 15:21:57 +0800 Subject: [PATCH 08/57] =?UTF-8?q?feat(repeat):=20bot=20=E5=8F=91=E8=A8=80?= =?UTF-8?q?=E4=B8=8D=E8=AE=A1=E5=85=A5=E5=A4=8D=E8=AF=BB=E9=93=BE=EF=BC=9B?= 
=?UTF-8?q?=E9=98=88=E5=80=BC=E5=8F=AF=E9=85=8D=E7=BD=AE?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - bot 消息写入 counter(而非过滤)使窗口对其可见 - 触发条件增加 bot_qq not in senders 检查 覆盖三种情形:bot 先发、bot 中间插入、bot 滑出窗口后正常触发 - RuntimeConfig 新增 repeat_threshold(范围 2–20,默认 3) - 硬编码 3 / 5 替换为 n = repeat_threshold - config.toml.example 更新 repeat_enabled 注释并新增 repeat_threshold 字段 - 新增 5 个测试,共 17 个复读测试全部通过 Co-authored-by: Claude Sonnet 4.6 Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- config.toml.example | 7 ++- src/Undefined/config/loader.py | 14 ++++++ src/Undefined/handlers.py | 33 ++++++++++---- tests/test_handlers_repeat.py | 78 +++++++++++++++++++++++++++++++++- 4 files changed, 120 insertions(+), 12 deletions(-) diff --git a/config.toml.example b/config.toml.example index 678dd09..ad25b66 100644 --- a/config.toml.example +++ b/config.toml.example @@ -659,9 +659,12 @@ agent_call_message_enabled = "none" # zh: 是否启用群聊关键词("心理委员")自动回复。 # en: Enable keyword auto-replies("心理委员") in group chats. keyword_reply_enabled = false -# zh: 是否启用群聊复读功能(连续3条相同消息时复读)。 -# en: Enable repeat feature in group chats (repeat when 3 consecutive identical messages). +# zh: 是否启用群聊复读功能(连续 N 条相同消息时复读,N 由 repeat_threshold 控制;若期间有 bot 自身发言则重置链)。 +# en: Enable repeat feature in group chats (repeat when N consecutive identical messages arrive, N set by repeat_threshold; resets if bot itself sent the same text in between). repeat_enabled = false +# zh: 复读触发所需的连续相同消息条数(来自不同发送者),范围 2–20,默认 3。 +# en: Number of consecutive identical messages (from different senders) required to trigger repeat, range 2–20, default 3. +repeat_threshold = 3 # zh: 是否启用倒问号(复读触发时,若消息为问号则发送倒问号 ¿)。 # en: Enable inverted question mark (when repeat triggers on "?" messages, send "¿" instead). 
inverted_question_enabled = false diff --git a/src/Undefined/config/loader.py b/src/Undefined/config/loader.py index f7d7f52..00b8c3a 100644 --- a/src/Undefined/config/loader.py +++ b/src/Undefined/config/loader.py @@ -469,6 +469,7 @@ class Config: process_poke_message: bool keyword_reply_enabled: bool repeat_enabled: bool + repeat_threshold: int inverted_question_enabled: bool context_recent_messages_limit: int ai_request_max_retries: int @@ -723,6 +724,18 @@ def load(cls, config_path: Optional[Path] = None, strict: bool = True) -> "Confi ), False, ) + repeat_threshold = _coerce_int( + _get_value( + data, + ("easter_egg", "repeat_threshold"), + "EASTER_EGG_REPEAT_THRESHOLD", + ), + 3, + ) + if repeat_threshold < 2: + repeat_threshold = 2 + if repeat_threshold > 20: + repeat_threshold = 20 context_recent_messages_limit = _coerce_int( _get_value( data, @@ -1396,6 +1409,7 @@ def load(cls, config_path: Optional[Path] = None, strict: bool = True) -> "Confi process_poke_message=process_poke_message, keyword_reply_enabled=keyword_reply_enabled, repeat_enabled=repeat_enabled, + repeat_threshold=repeat_threshold, inverted_question_enabled=inverted_question_enabled, context_recent_messages_limit=context_recent_messages_limit, ai_request_max_retries=ai_request_max_retries, diff --git a/src/Undefined/handlers.py b/src/Undefined/handlers.py index 8cfeae9..db8fba2 100644 --- a/src/Undefined/handlers.py +++ b/src/Undefined/handlers.py @@ -707,7 +707,17 @@ async def handle_message(self, event: dict[str, Any]) -> None: ) # 如果是 bot 自己的消息,只保存不触发回复,避免无限循环 + # 同时把 bot 自身的发言写入复读计数器,使窗口中留有 bot 标记, + # 后续触发检查时会排除含 bot 的窗口,防止"bot 先发 → 用户跟发"或 + # "用户发到一半 bot 插入"等情况误触复读。 if sender_id == self.config.bot_qq: + if self.config.repeat_enabled and text: + async with self._get_repeat_lock(group_id): + counter = self._repeat_counter.setdefault(group_id, []) + counter.append((text, sender_id)) + n = self.config.repeat_threshold + if len(counter) > n: + self._repeat_counter[group_id] = counter[-n:] 
return self._schedule_meme_ingest( @@ -776,21 +786,26 @@ async def handle_message(self, event: dict[str, Any]) -> None: ) return - # 复读功能:连续3条相同消息(来自不同发送者)时复读 + # 复读功能:连续 N 条相同消息(来自不同发送者)时复读,N = repeat_threshold if self.config.repeat_enabled: + n = self.config.repeat_threshold async with self._get_repeat_lock(group_id): counter = self._repeat_counter.setdefault(group_id, []) counter.append((text, sender_id)) - # 只保留最近5条 - if len(counter) > 5: - self._repeat_counter[group_id] = counter[-5:] + # 只保留最近 n 条 + if len(counter) > n: + self._repeat_counter[group_id] = counter[-n:] counter = self._repeat_counter[group_id] - if len(counter) >= 3: - last3 = counter[-3:] - texts = [t for t, _ in last3] - senders = [s for _, s in last3] - if len(set(texts)) == 1 and len(set(senders)) == 3: + if len(counter) >= n: + last_n = counter[-n:] + texts = [t for t, _ in last_n] + senders = [s for _, s in last_n] + if ( + len(set(texts)) == 1 + and len(set(senders)) == n + and self.config.bot_qq not in senders + ): reply_text = texts[0] if self.config.inverted_question_enabled: stripped = reply_text.strip() diff --git a/tests/test_handlers_repeat.py b/tests/test_handlers_repeat.py index 3adf332..2bb6aef 100644 --- a/tests/test_handlers_repeat.py +++ b/tests/test_handlers_repeat.py @@ -17,6 +17,7 @@ def _build_handler( *, repeat_enabled: bool = False, + repeat_threshold: int = 3, inverted_question_enabled: bool = False, keyword_reply_enabled: bool = False, ) -> Any: @@ -24,6 +25,7 @@ def _build_handler( handler.config = SimpleNamespace( bot_qq=10000, repeat_enabled=repeat_enabled, + repeat_threshold=repeat_threshold, inverted_question_enabled=inverted_question_enabled, keyword_reply_enabled=keyword_reply_enabled, bilibili_auto_extract_enabled=False, @@ -248,7 +250,7 @@ async def test_repeat_groups_are_independent() -> None: assert call.args[0] == 30002 -# ── 计数器窗口:只看最近5条 ── +# ── 计数器窗口:只看最近 N 条 ── @pytest.mark.asyncio @@ -264,3 +266,77 @@ async def test_repeat_counter_sliding_window() -> 
None: handler.sender.send_group_message.assert_called_once() call = handler.sender.send_group_message.call_args assert call.args[1] == "hello" + + +# ── bot 自身发言后不触发复读 ── + +BOT_QQ = 10000 + + +@pytest.mark.asyncio +async def test_repeat_no_trigger_when_bot_sends_before_users() -> None: + """bot 先发,后面用户再发相同消息,不应触发复读。""" + handler = _build_handler(repeat_enabled=True) + # bot 先发 + await handler.handle_message(_group_event(sender_id=BOT_QQ, text="hello")) + # 两个用户跟发 + for uid in [20001, 20002]: + await handler.handle_message(_group_event(sender_id=uid, text="hello")) + + handler.sender.send_group_message.assert_not_called() + + +@pytest.mark.asyncio +async def test_repeat_no_trigger_when_bot_sends_in_middle() -> None: + """用户发到一半,bot 插入相同消息,之后用户再凑满阈值,不应触发复读。""" + handler = _build_handler(repeat_enabled=True) + # 两个用户先发 + await handler.handle_message(_group_event(sender_id=20001, text="hello")) + await handler.handle_message(_group_event(sender_id=20002, text="hello")) + # bot 插入 + await handler.handle_message(_group_event(sender_id=BOT_QQ, text="hello")) + # 第三个用户发:此时窗口 [user2, bot, user3],含 bot → 不触发 + await handler.handle_message(_group_event(sender_id=20003, text="hello")) + + handler.sender.send_group_message.assert_not_called() + + +@pytest.mark.asyncio +async def test_repeat_triggers_after_bot_window_slides_out() -> None: + """bot 消息滑出窗口后,纯用户序列应正常触发复读(threshold=3)。""" + handler = _build_handler(repeat_enabled=True, repeat_threshold=3) + # bot 先发(进入窗口) + await handler.handle_message(_group_event(sender_id=BOT_QQ, text="hello")) + # 三个不同用户依次发:窗口变为 [user1, user2, user3](bot 已滑出) + for uid in [20001, 20002, 20003]: + await handler.handle_message(_group_event(sender_id=uid, text="hello")) + + handler.sender.send_group_message.assert_called_once() + assert handler.sender.send_group_message.call_args.args[1] == "hello" + + +# ── 可配置阈值 ── + + +@pytest.mark.asyncio +async def test_repeat_custom_threshold_2() -> None: + """threshold=2 时,2 条不同发送者相同消息即触发复读。""" + handler = 
_build_handler(repeat_enabled=True, repeat_threshold=2) + for uid in [20001, 20002]: + await handler.handle_message(_group_event(sender_id=uid, text="hi")) + + handler.sender.send_group_message.assert_called_once() + assert handler.sender.send_group_message.call_args.args[1] == "hi" + + +@pytest.mark.asyncio +async def test_repeat_custom_threshold_4() -> None: + """threshold=4 时,3 条不同发送者相同消息不触发,第 4 条才触发。""" + handler = _build_handler(repeat_enabled=True, repeat_threshold=4) + for uid in [20001, 20002, 20003]: + await handler.handle_message(_group_event(sender_id=uid, text="hey")) + handler.sender.send_group_message.assert_not_called() + + await handler.handle_message(_group_event(sender_id=20004, text="hey")) + handler.sender.send_group_message.assert_called_once() + assert handler.sender.send_group_message.call_args.args[1] == "hey" From f6712278a993a7fb7a8fc2433b4fd8373ff84e26 Mon Sep 17 00:00:00 2001 From: Null <1708213363@qq.com> Date: Sat, 18 Apr 2026 15:55:15 +0800 Subject: [PATCH 09/57] =?UTF-8?q?fix(profile,latex):=20=E4=BF=AE=E5=A4=8D?= =?UTF-8?q?=E4=BE=A7=E5=86=99=E5=8F=8C=E5=A4=8D=E6=95=B0=E8=B7=AF=E5=BE=84?= =?UTF-8?q?=E5=92=8CMathJax=E7=AD=89=E5=BE=85=E6=9D=A1=E4=BB=B6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit profile 命令: - 修复 entity_type 双复数 bug ("users"→"user", "groups"→"group") profile_storage._profile_path 会追加 "s", handler 传 "users" 导致 路径变为 "userss/", 永远找不到侧写文件 - 新增 "g" 快捷子命令 (/p g 等同于 /p group) - 更新 config.json 帮助文本和使用说明 LaTeX 渲染: - 修复 MathJax wait_for_function 逻辑: 旧代码返回 Promise 而非 boolean, Playwright 无法正确判断完成状态, 必然超时 - 改用 pageReady 回调设 window._mjReady 标记, wait 检查该标记 - 超时从 15s 增至 30s - 添加 MathJax 配置块支持行内数学 ($...$) 测试: 新增 g 快捷方式和私聊拒绝测试, HTML 模板测试 Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .../skills/commands/profile/config.json | 6 +-- .../skills/commands/profile/handler.py | 8 +-- .../toolsets/render/render_latex/handler.py | 18 +++++-- tests/test_profile_command.py | 
50 +++++++++++++++++-- tests/test_render_latex_tool.py | 11 ++++ 5 files changed, 80 insertions(+), 13 deletions(-) diff --git a/src/Undefined/skills/commands/profile/config.json b/src/Undefined/skills/commands/profile/config.json index 6b47c63..1dd5eb5 100644 --- a/src/Undefined/skills/commands/profile/config.json +++ b/src/Undefined/skills/commands/profile/config.json @@ -1,8 +1,8 @@ { "name": "profile", - "description": "查看用户或群聊侧写", - "usage": "/profile [group]", - "example": "/profile", + "description": "查看侧写。默认显示你的用户侧写;加 g 查看当前群聊侧写(仅群聊可用)", + "usage": "/p [g]", + "example": "/p g", "permission": "public", "rate_limit": { "user": 60, diff --git a/src/Undefined/skills/commands/profile/handler.py b/src/Undefined/skills/commands/profile/handler.py index aba240a..143f3b4 100644 --- a/src/Undefined/skills/commands/profile/handler.py +++ b/src/Undefined/skills/commands/profile/handler.py @@ -31,18 +31,18 @@ async def execute(args: list[str], context: CommandContext) -> None: await _send(context, "❌ 侧写服务未启用") return - # Parse subcommand + # Parse subcommand: "g" or "group" → group profile sub = args[0].lower().strip() if args else "" - if sub == "group": + if sub in ("group", "g"): if _is_private(context): await _send(context, "❌ 私聊中不支持查看群聊侧写") return - entity_type = "groups" + entity_type = "group" entity_id = str(context.group_id) empty_hint = "暂无群聊侧写数据" else: - entity_type = "users" + entity_type = "user" entity_id = str(context.sender_id) empty_hint = "暂无侧写数据" diff --git a/src/Undefined/skills/toolsets/render/render_latex/handler.py b/src/Undefined/skills/toolsets/render/render_latex/handler.py index 01e8ba3..0637b9f 100644 --- a/src/Undefined/skills/toolsets/render/render_latex/handler.py +++ b/src/Undefined/skills/toolsets/render/render_latex/handler.py @@ -63,6 +63,18 @@ def _build_html(latex_content: str) -> str: + + {safe_title}
+{safe_body}+""" + + output_dir = ensure_dir(RENDER_CACHE_DIR) + output_path = str(output_dir / f"profile_{uuid.uuid4().hex[:8]}.png") + + await render_html_to_image(html_content, output_path) + + abs_path = Path(output_path).resolve() + image_cq = f"[CQ:image,file=file://{abs_path}]" + + if _is_private(context): + user_id = int(context.user_id or context.sender_id) + await context.sender.send_private_message(user_id, image_cq) + else: + await context.sender.send_group_message(context.group_id, image_cq) + + +# ── 入口 ───────────────────────────────────────────────────── + + async def execute(args: list[str], context: CommandContext) -> None: - """处理 /profile 命令。""" + """处理 /profile 命令。 + + 用法: /p [g] [-t|--text] [-f|--forward] [-r|--render] + g / group 查看群聊侧写(仅群聊可用) + -t / --text 纯文本直接发出 + -f / --forward 合并转发发出(群聊默认) + -r / --render 渲染为图片发出 + """ cognitive_service = context.cognitive_service if cognitive_service is None: - await _send(context, "❌ 侧写服务未启用") + await _send_text(context, "❌ 侧写服务未启用") return - # Parse subcommand: "g" or "group" → group profile - sub = args[0].lower().strip() if args else "" + sub, mode = _parse_args(args) if sub in ("group", "g"): if _is_private(context): - await _send(context, "❌ 私聊中不支持查看群聊侧写") + await _send_text(context, "❌ 私聊中不支持查看群聊侧写") return entity_type = "group" entity_id = str(context.group_id) + title = "📋 群聊侧写" empty_hint = "暂无群聊侧写数据" else: entity_type = "user" entity_id = str(context.sender_id) + title = "📋 用户侧写" empty_hint = "暂无侧写数据" profile = await cognitive_service.get_profile(entity_type, entity_id) if not profile: - await _send(context, f"📭 {empty_hint}") + await _send_text(context, f"📭 {empty_hint}") return - await _send(context, _truncate(profile)) + profile = _truncate(profile) + + # 私聊始终纯文本 + if _is_private(context): + mode = _MODE_TEXT + + # 未指定模式:群聊默认合并转发 + if not mode: + mode = _MODE_FORWARD + + if mode == _MODE_TEXT: + await _send_text(context, profile) + elif mode == _MODE_RENDER: + try: + await 
_send_render(context, title, profile) + except Exception: + logger.exception("渲染侧写图片失败,回退到纯文本") + await _send_text(context, profile) + else: + try: + await _send_forward(context, title, profile) + except Exception: + logger.exception("发送合并转发失败,回退到纯文本") + await _send_text(context, profile) diff --git a/src/Undefined/utils/fake_at.py b/src/Undefined/utils/fake_at.py new file mode 100644 index 0000000..0ac0200 --- /dev/null +++ b/src/Undefined/utils/fake_at.py @@ -0,0 +1,180 @@ +"""假@检测:识别群聊中 ``@昵称`` 纯文本形式的"假 at"。 + +设计要点: +- ``BotNicknameCache`` 自动通过 OneBot API 获取 bot 在各群的群名片 / QQ 昵称, + 带 TTL 缓存 + per-group asyncio.Lock 防竞态。 +- ``strip_fake_at`` 是无状态纯函数,负责文本匹配与剥离。 +- 匹配规则:半角 ``@`` / 全角 ``@`` + 昵称 (casefold) + 边界(空白/标点/行尾), + 昵称按长度降序匹配以避免短昵称吃掉长昵称的前缀。 +""" + +from __future__ import annotations + +import asyncio +import logging +import re +import time +import unicodedata +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from Undefined.onebot import OneBotClient + +logger = logging.getLogger(__name__) + +# 昵称后必须跟的边界:空白、常见标点、行尾 +_BOUNDARY_RE = re.compile(r"[\s\u3000,.:;!?,。:;!?/\-\]))>》]+|$") + +# 缓存默认 TTL (秒) +_DEFAULT_TTL: float = 600.0 + + +def _normalize(text: str) -> str: + """将全角 @ 统一为半角 @,然后 casefold。""" + return unicodedata.normalize("NFKC", text).casefold() + + +def _sorted_nicknames(names: frozenset[str]) -> tuple[str, ...]: + """按长度降序排列,长昵称优先匹配。""" + return tuple(sorted(names, key=len, reverse=True)) + + +def strip_fake_at( + text: str, + nicknames: frozenset[str], +) -> tuple[bool, str]: + """检测并剥离文本开头的假 @ 前缀。 + + 参数: + text: 原始消息文本(已做 extract_text 处理)。 + nicknames: 当前群 bot 所有有效昵称 (已 casefold)。 + + 返回: + (is_fake_at, stripped_text) + - is_fake_at: 是否命中假 @。 + - stripped_text: 剥离假 @ 后的文本(未命中时返回原文)。 + """ + if not nicknames or not text: + return False, text + + normalized = _normalize(text) + # 以 @ 开头才可能是假 @ + if not normalized.startswith("@"): + return False, text + + # 去掉开头的 @ + after_at = normalized[1:] + raw_after_at = text[1:] # 保持原始大小写的切片 + + for 
nick in _sorted_nicknames(nicknames): + if not after_at.startswith(nick): + continue + # 检查昵称后是否为合法边界 + rest_pos = len(nick) + rest_normalized = after_at[rest_pos:] + if rest_normalized and not _BOUNDARY_RE.match(rest_normalized): + continue + # 命中——用原始文本切出剥离后的内容 + stripped = raw_after_at[rest_pos:].lstrip() + return True, stripped + + return False, text + + +class BotNicknameCache: + """按群缓存 bot 昵称,用于假 @ 检测。 + + 线程安全性: + - 只在单一 asyncio 事件循环中使用。 + - 每个 group_id 使用独立 ``asyncio.Lock``,保证同一群的并发 + 消息不会触发重复 API 请求。 + - ``_global_lock`` 仅保护 ``_locks`` 字典本身的创建。 + """ + + def __init__( + self, + onebot: OneBotClient, + bot_qq: int, + ttl: float = _DEFAULT_TTL, + ) -> None: + self._onebot = onebot + self._bot_qq = bot_qq + self._ttl = ttl + # group_id → (nicknames_frozenset, timestamp) + self._cache: dict[int, tuple[frozenset[str], float]] = {} + # group_id → asyncio.Lock + self._locks: dict[int, asyncio.Lock] = {} + self._global_lock = asyncio.Lock() + + def _get_lock(self, group_id: int) -> asyncio.Lock: + """获取指定群的锁(同步快路径 + 异步慢路径)。 + + 在 asyncio 单线程模型下,dict 读写本身是原子的, + 但为严谨起见仍使用 ``_global_lock`` 保护 ``_locks`` 的创建。 + 注意:此方法不是 coroutine,但会在 _ensure_lock 里 await。 + """ + lock = self._locks.get(group_id) + if lock is not None: + return lock + # 需要创建——由调用方在 _ensure_lock 中持有 _global_lock + lock = asyncio.Lock() + self._locks[group_id] = lock + return lock + + async def _ensure_lock(self, group_id: int) -> asyncio.Lock: + """确保 group lock 存在,如有必要在 global lock 下创建。""" + lock = self._locks.get(group_id) + if lock is not None: + return lock + async with self._global_lock: + return self._get_lock(group_id) + + async def get_nicknames(self, group_id: int) -> frozenset[str]: + """获取 bot 在指定群的所有有效昵称(含手动配置)。 + + 会自动缓存,过期后异步刷新。API 失败时返回上次缓存或仅手动昵称。 + """ + now = time.monotonic() + cached = self._cache.get(group_id) + if cached is not None: + names, ts = cached + if now - ts < self._ttl: + return names + + lock = await self._ensure_lock(group_id) + async with lock: + # 
Double-check:可能在等锁期间已被其他协程刷新 + cached = self._cache.get(group_id) + if cached is not None: + names, ts = cached + if now - ts < self._ttl: + return names + + fetched = await self._fetch(group_id) + self._cache[group_id] = (fetched, time.monotonic()) + return fetched + + async def _fetch(self, group_id: int) -> frozenset[str]: + """从 OneBot API 获取 bot 在指定群的群名片 + QQ 昵称。""" + names: set[str] = set() + try: + info = await self._onebot.get_group_member_info(group_id, self._bot_qq) + if isinstance(info, dict): + for key in ("card", "nickname"): + val = str(info.get(key, "") or "").strip() + if val: + names.add(val.casefold()) + except Exception as exc: + logger.debug( + "[假@] 获取 bot 群成员信息失败: group=%s err=%s", + group_id, + exc, + ) + return frozenset(names) + + def invalidate(self, group_id: int | None = None) -> None: + """手动失效缓存。group_id=None 清空全部。""" + if group_id is None: + self._cache.clear() + else: + self._cache.pop(group_id, None) diff --git a/tests/test_handlers_repeat.py b/tests/test_handlers_repeat.py index 2bb6aef..8d508de 100644 --- a/tests/test_handlers_repeat.py +++ b/tests/test_handlers_repeat.py @@ -70,6 +70,9 @@ def _build_handler( handler._repeat_counter = {} handler._repeat_locks = {} handler._profile_name_refresh_cache = {} + handler._bot_nickname_cache = SimpleNamespace( + get_nicknames=AsyncMock(return_value=frozenset()), + ) return handler From 1595d888b8aa55b30a54c89d410b8d6cc306c321 Mon Sep 17 00:00:00 2001 From: Null <1708213363@qq.com> Date: Sat, 18 Apr 2026 22:21:32 +0800 Subject: [PATCH 15/57] =?UTF-8?q?refactor(profile):=20=E4=BC=98=E5=8C=96?= =?UTF-8?q?=E5=90=88=E5=B9=B6=E8=BD=AC=E5=8F=91=E5=92=8C=E6=B8=B2=E6=9F=93?= =?UTF-8?q?=E8=BE=93=E5=87=BA=E6=A0=BC=E5=BC=8F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 移除 emoji 标题,合并转发改为 2 节点:元数据 + 完整侧写 - 不再截断分消息发送 - 渲染模式:卡片式布局,渐变色元数据头 + 正文区 - 元数据含类型/ID/字数/更新时间 Co-authored-by: Claude Opus 4.6Co-authored-by: Copilot 
<223556219+Copilot@users.noreply.github.com> --- .../skills/commands/profile/handler.py | 156 +++++++++++------- 1 file changed, 98 insertions(+), 58 deletions(-) diff --git a/src/Undefined/skills/commands/profile/handler.py b/src/Undefined/skills/commands/profile/handler.py index 6eb0428..5eabaa2 100644 --- a/src/Undefined/skills/commands/profile/handler.py +++ b/src/Undefined/skills/commands/profile/handler.py @@ -3,18 +3,20 @@ import html import logging import uuid +from datetime import datetime, timezone from pathlib import Path from typing import Any from Undefined.services.commands.context import CommandContext -from Undefined.utils.paths import RENDER_CACHE_DIR, ensure_dir +from Undefined.utils.paths import COGNITIVE_PROFILES_DIR, RENDER_CACHE_DIR, ensure_dir logger = logging.getLogger("profile") _MAX_PROFILE_LENGTH = 3000 -# 合并转发单条消息最大字符数(过长拆分) -_FORWARD_NODE_LIMIT = 2000 +_MODE_TEXT = "text" +_MODE_FORWARD = "forward" +_MODE_RENDER = "render" def _is_private(context: CommandContext) -> bool: @@ -27,24 +29,10 @@ def _truncate(text: str, limit: int = _MAX_PROFILE_LENGTH) -> str: return text[:limit].rstrip() + "\n\n[侧写过长,已截断]" -# ── 输出模式枚举 ────────────────────────────────────────────── - -_MODE_TEXT = "text" -_MODE_FORWARD = "forward" -_MODE_RENDER = "render" - - -def _parse_args( - args: list[str], -) -> tuple[str, str]: - """解析参数,返回 (子命令, 输出模式)。 - - 子命令: "" | "g" | "group" - 输出模式: "text" | "forward" | "render" - """ +def _parse_args(args: list[str]) -> tuple[str, str]: + """解析参数,返回 (子命令, 输出模式)。""" sub = "" mode = "" - for arg in args: lower = arg.lower().strip() if lower in ("-t", "--text"): @@ -55,11 +43,38 @@ def _parse_args( mode = _MODE_RENDER elif lower in ("g", "group"): sub = lower - # 忽略无法识别的参数 - return sub, mode +def _profile_mtime(entity_type: str, entity_id: str) -> str | None: + """读取侧写文件最后修改时间,返回人类可读字符串。""" + p = COGNITIVE_PROFILES_DIR / f"{entity_type}s" / f"{entity_id}.md" + try: + mtime = p.stat().st_mtime + dt = 
datetime.fromtimestamp(mtime, tz=timezone.utc).astimezone() + return dt.strftime("%Y-%m-%d %H:%M") + except OSError: + return None + + +def _build_metadata( + entity_type: str, + entity_id: str, + profile_len: int, +) -> str: + """构建元数据摘要文本。""" + type_label = "用户" if entity_type == "user" else "群聊" + lines = [ + f"类型: {type_label}侧写", + f"ID: {entity_id}", + f"长度: {profile_len} 字", + ] + mtime = _profile_mtime(entity_type, entity_id) + if mtime: + lines.append(f"更新: {mtime}") + return "\n".join(lines) + + # ── 发送方法 ────────────────────────────────────────────────── @@ -72,57 +87,83 @@ async def _send_text(context: CommandContext, text: str) -> None: await context.sender.send_group_message(context.group_id, text) -async def _send_forward(context: CommandContext, title: str, profile_text: str) -> None: - """合并转发发送。""" +async def _send_forward( + context: CommandContext, + metadata: str, + profile_text: str, +) -> None: + """合并转发:节点1=元数据,节点2=完整侧写内容。""" bot_qq = str(getattr(context.config, "bot_qq", 0)) - nodes: list[dict[str, Any]] = [] - def _add_node(content: str) -> None: - nodes.append( - { - "type": "node", - "data": {"name": "Undefined", "uin": bot_qq, "content": content}, - } - ) - - _add_node(title) - - # 按长度拆分成多个节点 - remaining = profile_text - while remaining: - chunk = remaining[:_FORWARD_NODE_LIMIT] - remaining = remaining[_FORWARD_NODE_LIMIT:] - _add_node(chunk) + def _node(content: str) -> dict[str, Any]: + return { + "type": "node", + "data": {"name": "Undefined", "uin": bot_qq, "content": content}, + } + nodes = [_node(metadata), _node(profile_text)] await context.onebot.send_forward_msg(context.group_id, nodes) -async def _send_render(context: CommandContext, title: str, profile_text: str) -> None: - """渲染为图片发送。""" +async def _send_render( + context: CommandContext, + metadata: str, + profile_text: str, +) -> None: + """渲染为图片发送——元数据区 + 侧写正文区。""" from Undefined.render import render_html_to_image - safe_title = html.escape(title) + safe_meta = 
html.escape(metadata) safe_body = html.escape(profile_text) + + # 将元数据行拆分为 key: value 形式渲染 + meta_rows = "" + for line in safe_meta.split("\n"): + if ": " in line: + key, _, val = line.partition(": ") + meta_rows += ( + f' \n' + ) + html_content = f""" - {key} {val} {safe_title}
-{safe_body}+.meta .mv {{ font-size: 13px; padding: 3px 0; }} +.body {{ + padding: 24px; line-height: 1.8; font-size: 14px; + white-space: pre-wrap; word-wrap: break-word; +}} + + ++ +""" output_dir = ensure_dir(RENDER_CACHE_DIR) output_path = str(output_dir / f"profile_{uuid.uuid4().hex[:8]}.png") - await render_html_to_image(html_content, output_path) abs_path = Path(output_path).resolve() @@ -160,12 +201,10 @@ async def execute(args: list[str], context: CommandContext) -> None: return entity_type = "group" entity_id = str(context.group_id) - title = "📋 群聊侧写" empty_hint = "暂无群聊侧写数据" else: entity_type = "user" entity_id = str(context.sender_id) - title = "📋 用户侧写" empty_hint = "暂无侧写数据" profile = await cognitive_service.get_profile(entity_type, entity_id) @@ -174,6 +213,7 @@ async def execute(args: list[str], context: CommandContext) -> None: return profile = _truncate(profile) + metadata = _build_metadata(entity_type, entity_id, len(profile)) # 私聊始终纯文本 if _is_private(context): @@ -187,13 +227,13 @@ async def execute(args: list[str], context: CommandContext) -> None: await _send_text(context, profile) elif mode == _MODE_RENDER: try: - await _send_render(context, title, profile) + await _send_render(context, metadata, profile) except Exception: logger.exception("渲染侧写图片失败,回退到纯文本") await _send_text(context, profile) else: try: - await _send_forward(context, title, profile) + await _send_forward(context, metadata, profile) except Exception: logger.exception("发送合并转发失败,回退到纯文本") await _send_text(context, profile) From a96e369773deacbd78580dca07394260d8514493 Mon Sep 17 00:00:00 2001 From: Null <1708213363@qq.com> Date: Sat, 18 Apr 2026 22:25:37 +0800 Subject: [PATCH 16/57] =?UTF-8?q?fix(profile):=20=E4=BD=BF=E7=94=A8=20WebU?= =?UTF-8?q?I=20=E9=85=8D=E8=89=B2=E3=80=81=E6=8F=90=E9=AB=98=E6=88=AA?= =?UTF-8?q?=E6=96=AD=E4=B8=8A=E9=99=90=E8=87=B3=205000?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 渲染样式改为 WebUI 暖色调 
(#f9f5f1/#e6e0d8/#3d3935) - 截断上限从 3000 提高到 5000 - 更新对应测试用例 Co-authored-by: Claude Opus 4.6{safe_body}+Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .../skills/commands/profile/handler.py | 31 +++++++++---------- tests/test_profile_command.py | 8 ++--- 2 files changed, 19 insertions(+), 20 deletions(-) diff --git a/src/Undefined/skills/commands/profile/handler.py b/src/Undefined/skills/commands/profile/handler.py index 5eabaa2..6d6f473 100644 --- a/src/Undefined/skills/commands/profile/handler.py +++ b/src/Undefined/skills/commands/profile/handler.py @@ -12,7 +12,7 @@ logger = logging.getLogger("profile") -_MAX_PROFILE_LENGTH = 3000 +_MAX_PROFILE_LENGTH = 5000 _MODE_TEXT = "text" _MODE_FORWARD = "forward" @@ -116,7 +116,6 @@ async def _send_render( safe_meta = html.escape(metadata) safe_body = html.escape(profile_text) - # 将元数据行拆分为 key: value 形式渲染 meta_rows = "" for line in safe_meta.split("\n"): if ": " in line: @@ -129,29 +128,29 @@ async def _send_render( diff --git a/tests/test_profile_command.py b/tests/test_profile_command.py index fc01cb8..fe7a105 100644 --- a/tests/test_profile_command.py +++ b/tests/test_profile_command.py @@ -273,10 +273,10 @@ async def test_profile_no_cognitive_service() -> None: @pytest.mark.asyncio async def test_profile_truncation() -> None: - """Profile > 3000 chars gets truncated.""" + """Profile > 5000 chars gets truncated.""" sender = _DummySender() cognitive_service = AsyncMock() - long_profile = "A" * 3500 # Longer than 3000 chars + long_profile = "A" * 5500 # Longer than 5000 chars cognitive_service.get_profile = AsyncMock(return_value=long_profile) context = _build_context( @@ -291,6 +291,6 @@ async def test_profile_truncation() -> None: assert len(sender.group_messages) == 1 message = sender.group_messages[0][1] - assert len(message) <= 3100 # 3000 + truncation notice + assert len(message) <= 5100 # 5000 + truncation notice assert "[侧写过长,已截断]" in message - assert message.count("A") == 3000 # Exactly 
3000 'A's before truncation + assert message.count("A") == 5000 # Exactly 5000 'A's before truncation From f3a802ccbd885c241fbb3b45e945c2282bab7409 Mon Sep 17 00:00:00 2001 From: Null <1708213363@qq.com> Date: Sat, 18 Apr 2026 22:53:28 +0800 Subject: [PATCH 17/57] =?UTF-8?q?feat(agents):=20=E6=96=B0=E5=A2=9E=20arXi?= =?UTF-8?q?v=20=E8=AE=BA=E6=96=87=E6=B7=B1=E5=BA=A6=E5=88=86=E6=9E=90=20ag?= =?UTF-8?q?ent?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 新增 arxiv_analysis_agent: 下载 arXiv PDF 全文进行结构化学术分析 - 分页读取设计: fetch_paper 获取元数据, read_paper_pages 分批读取避免 token 溢出 - 复用 Undefined.arxiv 模块 (client/downloader/parser) - 更新 web_agent/callable.json 添加 summary_agent 和 arxiv_analysis_agent - 新增 18 个单元测试覆盖 handler + tools + config 结构 Co-authored-by: Claude Opus 4.6 Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .../agents/arxiv_analysis_agent/__init__.py | 0 .../agents/arxiv_analysis_agent/callable.json | 4 + .../agents/arxiv_analysis_agent/config.json | 21 ++ .../agents/arxiv_analysis_agent/handler.py | 48 +++ .../agents/arxiv_analysis_agent/intro.md | 1 + .../agents/arxiv_analysis_agent/prompt.md | 28 ++ .../arxiv_analysis_agent/tools/__init__.py | 0 .../tools/fetch_paper/config.json | 17 + .../tools/fetch_paper/handler.py | 75 ++++ .../tools/read_paper_pages/config.json | 21 ++ .../tools/read_paper_pages/handler.py | 68 ++++ .../skills/agents/web_agent/callable.json | 2 +- tests/test_arxiv_analysis_agent.py | 351 ++++++++++++++++++ 13 files changed, 635 insertions(+), 1 deletion(-) create mode 100644 src/Undefined/skills/agents/arxiv_analysis_agent/__init__.py create mode 100644 src/Undefined/skills/agents/arxiv_analysis_agent/callable.json create mode 100644 src/Undefined/skills/agents/arxiv_analysis_agent/config.json create mode 100644 src/Undefined/skills/agents/arxiv_analysis_agent/handler.py create mode 100644 src/Undefined/skills/agents/arxiv_analysis_agent/intro.md create mode 100644 
src/Undefined/skills/agents/arxiv_analysis_agent/prompt.md create mode 100644 src/Undefined/skills/agents/arxiv_analysis_agent/tools/__init__.py create mode 100644 src/Undefined/skills/agents/arxiv_analysis_agent/tools/fetch_paper/config.json create mode 100644 src/Undefined/skills/agents/arxiv_analysis_agent/tools/fetch_paper/handler.py create mode 100644 src/Undefined/skills/agents/arxiv_analysis_agent/tools/read_paper_pages/config.json create mode 100644 src/Undefined/skills/agents/arxiv_analysis_agent/tools/read_paper_pages/handler.py create mode 100644 tests/test_arxiv_analysis_agent.py diff --git a/src/Undefined/skills/agents/arxiv_analysis_agent/__init__.py b/src/Undefined/skills/agents/arxiv_analysis_agent/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/Undefined/skills/agents/arxiv_analysis_agent/callable.json b/src/Undefined/skills/agents/arxiv_analysis_agent/callable.json new file mode 100644 index 0000000..855776f --- /dev/null +++ b/src/Undefined/skills/agents/arxiv_analysis_agent/callable.json @@ -0,0 +1,4 @@ +{ + "enabled": true, + "allowed_callers": ["info_agent", "web_agent", "naga_code_analysis_agent"] +} diff --git a/src/Undefined/skills/agents/arxiv_analysis_agent/config.json b/src/Undefined/skills/agents/arxiv_analysis_agent/config.json new file mode 100644 index 0000000..a97b0a9 --- /dev/null +++ b/src/Undefined/skills/agents/arxiv_analysis_agent/config.json @@ -0,0 +1,21 @@ +{ + "type": "function", + "function": { + "name": "arxiv_analysis_agent", + "description": "arXiv 论文深度分析助手,下载并解析 arXiv 论文 PDF 全文,进行结构化学术深度分析。支持按 arXiv ID 或 URL 获取论文。", + "parameters": { + "type": "object", + "properties": { + "paper_id": { + "type": "string", + "description": "arXiv 论文 ID(如 '2301.07041')、arXiv URL(如 'https://arxiv.org/abs/2301.07041')或 'arXiv:2301.07041' 格式" + }, + "prompt": { + "type": "string", + "description": "可选的分析需求,例如:'重点分析方法论和实验设计'、'关注与 diffusion model 的对比'" + } + }, + "required": ["paper_id"] + } + } +} diff --git 
a/src/Undefined/skills/agents/arxiv_analysis_agent/handler.py b/src/Undefined/skills/agents/arxiv_analysis_agent/handler.py new file mode 100644 index 0000000..7b6d605 --- /dev/null +++ b/src/Undefined/skills/agents/arxiv_analysis_agent/handler.py @@ -0,0 +1,48 @@ +from __future__ import annotations + +import logging +from pathlib import Path +from typing import Any + +from Undefined.arxiv.parser import normalize_arxiv_id +from Undefined.skills.agents.runner import run_agent_with_tools + +logger = logging.getLogger(__name__) + + +async def execute(args: dict[str, Any], context: dict[str, Any]) -> str: + """执行 arxiv_analysis_agent。""" + + raw_paper_id = str(args.get("paper_id", "")).strip() + user_prompt = str(args.get("prompt", "")).strip() + + if not raw_paper_id: + return "请提供 arXiv 论文 ID 或 URL" + + paper_id = normalize_arxiv_id(raw_paper_id) + if paper_id is None: + return f"无法解析 arXiv 标识:{raw_paper_id}" + + context["arxiv_paper_id"] = paper_id + + context_messages = [ + { + "role": "system", + "content": f"当前任务:深度分析 arXiv 论文 {paper_id}", + } + ] + + user_content = user_prompt if user_prompt else f"请深度分析论文 arXiv:{paper_id}" + + return await run_agent_with_tools( + agent_name="arxiv_analysis_agent", + user_content=user_content, + context_messages=context_messages, + empty_user_content_message="请提供 arXiv 论文 ID 或 URL", + default_prompt="你是一个学术论文深度分析助手。", + context=context, + agent_dir=Path(__file__).parent, + logger=logger, + max_iterations=15, + tool_error_prefix="错误", + ) diff --git a/src/Undefined/skills/agents/arxiv_analysis_agent/intro.md b/src/Undefined/skills/agents/arxiv_analysis_agent/intro.md new file mode 100644 index 0000000..1b306e3 --- /dev/null +++ b/src/Undefined/skills/agents/arxiv_analysis_agent/intro.md @@ -0,0 +1 @@ +arXiv 论文深度分析助手:下载 arXiv 论文 PDF 全文并进行结构化学术深度分析,涵盖方法论、实验、创新点、局限性等维度。 diff --git a/src/Undefined/skills/agents/arxiv_analysis_agent/prompt.md b/src/Undefined/skills/agents/arxiv_analysis_agent/prompt.md new file mode 100644 index 
0000000..029ab1e --- /dev/null +++ b/src/Undefined/skills/agents/arxiv_analysis_agent/prompt.md @@ -0,0 +1,28 @@ +你是学术论文深度分析助手,专门对 arXiv 论文进行全面、结构化的学术分析。 + +工作流程: +1. 先调用 `fetch_paper` 获取论文元数据和摘要 +2. 调用 `read_paper_pages` 分批阅读论文全文(每次读取一定页数范围) +3. 读完后产出结构化深度分析 + +阅读策略: +- 先通过 `fetch_paper` 了解总页数和摘要 +- 用 `read_paper_pages` 按区间阅读(如 1-5、6-10、11-15 等),每次读 5 页 +- 对于较长的论文(>20 页),可以选择性跳过附录/参考文献,集中精力分析正文 +- 短论文可以一次性读完 + +分析输出结构: +- **概要**:一句话总结论文核心贡献 +- **研究背景与动机**:论文要解决什么问题、为什么重要 +- **方法论**:核心技术方案、算法、模型架构的详细解析 +- **实验与结果**:实验设置、基准对比、主要结论 +- **创新点与贡献**:论文的主要新颖之处 +- **局限性与未来方向**:作者提到的或你分析出的不足和可改进之处 +- **总评**:论文的整体质量和影响力评估 + +注意事项: +- 保持学术严谨,用客观语言描述 +- 如果用户指定了分析侧重(prompt),要重点展开相关部分 +- 对公式和算法用自然语言解释,不要直接复制 LaTeX +- PDF 提取可能有格式问题(乱码/缺失),遇到时说明情况继续分析 +- 如果论文是中文,用中文输出分析;否则用中文输出分析但保留关键术语英文原文 diff --git a/src/Undefined/skills/agents/arxiv_analysis_agent/tools/__init__.py b/src/Undefined/skills/agents/arxiv_analysis_agent/tools/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/Undefined/skills/agents/arxiv_analysis_agent/tools/fetch_paper/config.json b/src/Undefined/skills/agents/arxiv_analysis_agent/tools/fetch_paper/config.json new file mode 100644 index 0000000..5e69a69 --- /dev/null +++ b/src/Undefined/skills/agents/arxiv_analysis_agent/tools/fetch_paper/config.json @@ -0,0 +1,17 @@ +{ + "type": "function", + "function": { + "name": "fetch_paper", + "description": "获取 arXiv 论文元数据并下载 PDF,返回论文基本信息(标题、作者、摘要、页数)。调用后可用 read_paper_pages 分页阅读正文。", + "parameters": { + "type": "object", + "properties": { + "paper_id": { + "type": "string", + "description": "arXiv 论文 ID,如 '2301.07041'" + } + }, + "required": ["paper_id"] + } + } +} diff --git a/src/Undefined/skills/agents/arxiv_analysis_agent/tools/fetch_paper/handler.py b/src/Undefined/skills/agents/arxiv_analysis_agent/tools/fetch_paper/handler.py new file mode 100644 index 0000000..d75cfcd --- /dev/null +++ b/src/Undefined/skills/agents/arxiv_analysis_agent/tools/fetch_paper/handler.py @@ -0,0 +1,75 @@ +"""获取 
arXiv 论文元数据并下载 PDF 到本地缓存。""" + +from __future__ import annotations + +import logging +from typing import Any + +import fitz + +from Undefined.arxiv.client import get_paper_info +from Undefined.arxiv.downloader import download_paper_pdf +from Undefined.arxiv.parser import normalize_arxiv_id + +logger = logging.getLogger(__name__) + +_MAX_FILE_SIZE_MB = 50 + + +async def execute(args: dict[str, Any], context: dict[str, Any]) -> str: + raw_id = str(args.get("paper_id", "")).strip() + if not raw_id: + return "错误:请提供 arXiv 论文 ID" + + paper_id = normalize_arxiv_id(raw_id) or raw_id + + request_context = {"request_id": context.get("request_id", "-")} + + try: + paper = await get_paper_info(paper_id, context=request_context) + except Exception as exc: + logger.exception("[arxiv_analysis] 获取论文元数据失败: %s", exc) + return f"错误:获取论文 {paper_id} 元数据失败 — {exc}" + + if paper is None: + return f"错误:未找到论文 arXiv:{paper_id}" + + lines: list[str] = [ + f"论文: {paper.title}", + f"ID: {paper.paper_id}", + f"作者: {'、'.join(paper.authors[:10])}{'(等 ' + str(len(paper.authors)) + ' 位)' if len(paper.authors) > 10 else ''}", + f"分类: {paper.primary_category}", + f"发布: {paper.published[:10]}", + f"更新: {paper.updated[:10]}", + f"链接: {paper.abs_url}", + f"\n摘要:\n{paper.summary}", + ] + + try: + result, task_dir = await download_paper_pdf( + paper, max_file_size_mb=_MAX_FILE_SIZE_MB, context=request_context + ) + except Exception as exc: + logger.exception("[arxiv_analysis] PDF 下载失败: %s", exc) + lines.append(f"\nPDF 下载失败: {exc}(可基于摘要进行分析)") + return "\n".join(lines) + + if result.path is None: + lines.append(f"\nPDF 不可用(状态: {result.status}),请基于摘要进行分析") + return "\n".join(lines) + + try: + doc = fitz.open(str(result.path)) + try: + page_count = len(doc) + lines.append(f"\nPDF 已下载: {page_count} 页") + context["_arxiv_pdf_path"] = str(result.path) + context["_arxiv_pdf_pages"] = page_count + context["_arxiv_task_dir"] = str(task_dir) + finally: + doc.close() + except Exception as exc: + 
logger.exception("[arxiv_analysis] PDF 打开失败: %s", exc) + lines.append(f"\nPDF 无法打开: {exc}(可基于摘要进行分析)") + + return "\n".join(lines) diff --git a/src/Undefined/skills/agents/arxiv_analysis_agent/tools/read_paper_pages/config.json b/src/Undefined/skills/agents/arxiv_analysis_agent/tools/read_paper_pages/config.json new file mode 100644 index 0000000..abc268e --- /dev/null +++ b/src/Undefined/skills/agents/arxiv_analysis_agent/tools/read_paper_pages/config.json @@ -0,0 +1,21 @@ +{ + "type": "function", + "function": { + "name": "read_paper_pages", + "description": "读取已下载论文 PDF 的指定页范围文本。必须先调用 fetch_paper 下载论文。", + "parameters": { + "type": "object", + "properties": { + "start_page": { + "type": "integer", + "description": "起始页码(从 1 开始)" + }, + "end_page": { + "type": "integer", + "description": "结束页码(包含该页)" + } + }, + "required": ["start_page", "end_page"] + } + } +} diff --git a/src/Undefined/skills/agents/arxiv_analysis_agent/tools/read_paper_pages/handler.py b/src/Undefined/skills/agents/arxiv_analysis_agent/tools/read_paper_pages/handler.py new file mode 100644 index 0000000..ed013d1 --- /dev/null +++ b/src/Undefined/skills/agents/arxiv_analysis_agent/tools/read_paper_pages/handler.py @@ -0,0 +1,68 @@ +"""分页读取已下载的 arXiv 论文 PDF 文本。""" + +from __future__ import annotations + +import logging +from typing import Any + +import fitz + +logger = logging.getLogger(__name__) + +_MAX_CHARS_PER_READ = 15000 + + +async def execute(args: dict[str, Any], context: dict[str, Any]) -> str: + pdf_path = context.get("_arxiv_pdf_path") + total_pages = context.get("_arxiv_pdf_pages") + + if not pdf_path: + return "错误:请先调用 fetch_paper 下载论文" + + try: + start_page = int(args.get("start_page", 1)) + end_page = int(args.get("end_page", start_page)) + except (TypeError, ValueError): + return "错误:页码必须为整数" + + if start_page < 1: + start_page = 1 + if total_pages and end_page > total_pages: + end_page = total_pages + if start_page > end_page: + return f"错误:起始页 {start_page} 大于结束页 {end_page}" + + 
try: + doc = fitz.open(pdf_path) + except Exception as exc: + logger.exception("[arxiv_analysis] PDF 打开失败: %s", exc) + return f"错误:PDF 文件无法打开 — {exc}" + + try: + actual_pages = len(doc) + if start_page > actual_pages: + return f"错误:论文共 {actual_pages} 页,请求的起始页 {start_page} 超出范围" + + end_page = min(end_page, actual_pages) + text_parts: list[str] = [] + total_chars = 0 + + for page_num in range(start_page - 1, end_page): + page = doc.load_page(page_num) + raw_text = page.get_text() + page_text = str(raw_text) if raw_text else "" + + if total_chars + len(page_text) > _MAX_CHARS_PER_READ and text_parts: + text_parts.append( + f"\n[第 {page_num + 1} 页起文本已截断,请用更小的页范围重新读取]" + ) + break + + text_parts.append(f"--- 第 {page_num + 1} 页 ---") + text_parts.append(page_text if page_text.strip() else "(此页无可提取文本)") + total_chars += len(page_text) + + header = f"论文内容(第 {start_page}-{end_page} 页,共 {actual_pages} 页)" + return f"{header}\n\n" + "\n".join(text_parts) + finally: + doc.close() diff --git a/src/Undefined/skills/agents/web_agent/callable.json b/src/Undefined/skills/agents/web_agent/callable.json index ab8b02a..bc537f2 100644 --- a/src/Undefined/skills/agents/web_agent/callable.json +++ b/src/Undefined/skills/agents/web_agent/callable.json @@ -1,4 +1,4 @@ { "enabled": true, - "allowed_callers": ["naga_code_analysis_agent", "code_delivery_agent", "info_agent"] + "allowed_callers": ["naga_code_analysis_agent", "code_delivery_agent", "info_agent", "summary_agent", "arxiv_analysis_agent"] } diff --git a/tests/test_arxiv_analysis_agent.py b/tests/test_arxiv_analysis_agent.py new file mode 100644 index 0000000..82f8c53 --- /dev/null +++ b/tests/test_arxiv_analysis_agent.py @@ -0,0 +1,351 @@ +"""arxiv_analysis_agent 单元测试。""" + +from __future__ import annotations + +import json +from pathlib import Path +from typing import Any +from unittest.mock import AsyncMock, patch + +import pytest + + +# --------------------------------------------------------------------------- +# config.json 
/ callable.json 结构检查 +# --------------------------------------------------------------------------- + +AGENT_DIR = ( + Path(__file__).resolve().parent.parent + / "src" + / "Undefined" + / "skills" + / "agents" + / "arxiv_analysis_agent" +) + + +def _load_json(name: str) -> dict[str, Any]: + result: dict[str, Any] = json.loads((AGENT_DIR / name).read_text(encoding="utf-8")) + return result + + +class TestAgentConfig: + def test_config_json_schema(self) -> None: + cfg = _load_json("config.json") + assert cfg["type"] == "function" + func = cfg["function"] + assert func["name"] == "arxiv_analysis_agent" + assert "paper_id" in func["parameters"]["properties"] + assert "paper_id" in func["parameters"]["required"] + + def test_callable_json(self) -> None: + cfg = _load_json("callable.json") + assert cfg["enabled"] is True + assert isinstance(cfg["allowed_callers"], list) + assert len(cfg["allowed_callers"]) > 0 + + def test_tools_exist(self) -> None: + tools_dir = AGENT_DIR / "tools" + assert (tools_dir / "fetch_paper" / "config.json").exists() + assert (tools_dir / "fetch_paper" / "handler.py").exists() + assert (tools_dir / "read_paper_pages" / "config.json").exists() + assert (tools_dir / "read_paper_pages" / "handler.py").exists() + + def test_prompt_md_exists(self) -> None: + assert (AGENT_DIR / "prompt.md").exists() + content = (AGENT_DIR / "prompt.md").read_text(encoding="utf-8") + assert len(content) > 100 + + +# --------------------------------------------------------------------------- +# handler.py +# --------------------------------------------------------------------------- + + +class TestHandler: + @pytest.mark.asyncio + async def test_empty_paper_id(self) -> None: + from Undefined.skills.agents.arxiv_analysis_agent.handler import execute + + result = await execute({"paper_id": ""}, {}) + assert "请提供" in result + + @pytest.mark.asyncio + async def test_invalid_paper_id(self) -> None: + from Undefined.skills.agents.arxiv_analysis_agent.handler import execute 
+ + result = await execute({"paper_id": "not-a-valid-id"}, {}) + assert "无法解析" in result + + @pytest.mark.asyncio + async def test_valid_paper_id_calls_runner(self) -> None: + from Undefined.skills.agents.arxiv_analysis_agent.handler import execute + + mock_result = "分析结果" + with patch( + "Undefined.skills.agents.arxiv_analysis_agent.handler.run_agent_with_tools", + new_callable=AsyncMock, + return_value=mock_result, + ) as mock_run: + result = await execute( + {"paper_id": "2301.07041", "prompt": "分析方法论"}, + {"request_id": "test-123"}, + ) + assert result == mock_result + mock_run.assert_awaited_once() + call_kwargs = mock_run.call_args[1] + assert call_kwargs["agent_name"] == "arxiv_analysis_agent" + assert "分析方法论" in call_kwargs["user_content"] + + @pytest.mark.asyncio + async def test_url_input_normalized(self) -> None: + from Undefined.skills.agents.arxiv_analysis_agent.handler import execute + + with patch( + "Undefined.skills.agents.arxiv_analysis_agent.handler.run_agent_with_tools", + new_callable=AsyncMock, + return_value="ok", + ): + ctx: dict[str, Any] = {} + await execute( + {"paper_id": "https://arxiv.org/abs/2301.07041"}, + ctx, + ) + assert ctx["arxiv_paper_id"] == "2301.07041" + + +# --------------------------------------------------------------------------- +# fetch_paper tool +# --------------------------------------------------------------------------- + + +def _make_paper_info( + paper_id: str = "2301.07041", +) -> Any: + from Undefined.arxiv.models import PaperInfo + + return PaperInfo( + paper_id=paper_id, + title="Test Paper Title", + authors=("Author A", "Author B"), + summary="This is a test abstract.", + published="2023-01-17T00:00:00Z", + updated="2023-01-18T00:00:00Z", + primary_category="cs.AI", + abs_url=f"https://arxiv.org/abs/{paper_id}", + pdf_url=f"https://arxiv.org/pdf/{paper_id}.pdf", + ) + + +class TestFetchPaper: + @pytest.mark.asyncio + async def test_empty_paper_id(self) -> None: + from 
Undefined.skills.agents.arxiv_analysis_agent.tools.fetch_paper.handler import ( + execute, + ) + + result = await execute({"paper_id": ""}, {}) + assert "错误" in result + + @pytest.mark.asyncio + async def test_metadata_only_on_download_failure(self) -> None: + from Undefined.skills.agents.arxiv_analysis_agent.tools.fetch_paper.handler import ( + execute, + ) + + paper = _make_paper_info() + with ( + patch( + "Undefined.skills.agents.arxiv_analysis_agent.tools.fetch_paper.handler.get_paper_info", + new_callable=AsyncMock, + return_value=paper, + ), + patch( + "Undefined.skills.agents.arxiv_analysis_agent.tools.fetch_paper.handler.download_paper_pdf", + new_callable=AsyncMock, + side_effect=RuntimeError("network error"), + ), + ): + result = await execute({"paper_id": "2301.07041"}, {}) + assert "Test Paper Title" in result + assert "Author A" in result + assert "PDF 下载失败" in result + + @pytest.mark.asyncio + async def test_paper_not_found(self) -> None: + from Undefined.skills.agents.arxiv_analysis_agent.tools.fetch_paper.handler import ( + execute, + ) + + with patch( + "Undefined.skills.agents.arxiv_analysis_agent.tools.fetch_paper.handler.get_paper_info", + new_callable=AsyncMock, + return_value=None, + ): + result = await execute({"paper_id": "9999.99999"}, {}) + assert "未找到" in result + + @pytest.mark.asyncio + async def test_successful_fetch_with_pdf(self, tmp_path: Path) -> None: + from Undefined.skills.agents.arxiv_analysis_agent.tools.fetch_paper.handler import ( + execute, + ) + + paper = _make_paper_info() + pdf_path = tmp_path / "test.pdf" + + import fitz + + doc = fitz.open() + page = doc.new_page() + page.insert_text((72, 72), "Hello world") + doc.save(str(pdf_path)) + doc.close() + + from Undefined.arxiv.downloader import PaperDownloadResult + + download_result = PaperDownloadResult( + path=pdf_path, size_bytes=1024, status="downloaded" + ) + + with ( + patch( + "Undefined.skills.agents.arxiv_analysis_agent.tools.fetch_paper.handler.get_paper_info", + 
new_callable=AsyncMock, + return_value=paper, + ), + patch( + "Undefined.skills.agents.arxiv_analysis_agent.tools.fetch_paper.handler.download_paper_pdf", + new_callable=AsyncMock, + return_value=(download_result, tmp_path), + ), + ): + ctx: dict[str, Any] = {} + result = await execute({"paper_id": "2301.07041"}, ctx) + assert "Test Paper Title" in result + assert "1 页" in result + assert ctx["_arxiv_pdf_path"] == str(pdf_path) + assert ctx["_arxiv_pdf_pages"] == 1 + + +# --------------------------------------------------------------------------- +# read_paper_pages tool +# --------------------------------------------------------------------------- + + +class TestReadPaperPages: + @pytest.mark.asyncio + async def test_no_pdf_downloaded(self) -> None: + from Undefined.skills.agents.arxiv_analysis_agent.tools.read_paper_pages.handler import ( + execute, + ) + + result = await execute({"start_page": 1, "end_page": 1}, {}) + assert "先调用 fetch_paper" in result + + @pytest.mark.asyncio + async def test_read_single_page(self, tmp_path: Path) -> None: + from Undefined.skills.agents.arxiv_analysis_agent.tools.read_paper_pages.handler import ( + execute, + ) + + import fitz + + pdf_path = tmp_path / "paper.pdf" + doc = fitz.open() + page = doc.new_page() + page.insert_text((72, 72), "Page 1 content") + page2 = doc.new_page() + page2.insert_text((72, 72), "Page 2 content") + doc.save(str(pdf_path)) + doc.close() + + ctx: dict[str, Any] = { + "_arxiv_pdf_path": str(pdf_path), + "_arxiv_pdf_pages": 2, + } + result = await execute({"start_page": 1, "end_page": 1}, ctx) + assert "第 1 页" in result + assert "Page 1 content" in result + assert "Page 2 content" not in result + + @pytest.mark.asyncio + async def test_read_page_range(self, tmp_path: Path) -> None: + from Undefined.skills.agents.arxiv_analysis_agent.tools.read_paper_pages.handler import ( + execute, + ) + + import fitz + + pdf_path = tmp_path / "paper.pdf" + doc = fitz.open() + for i in range(3): + page = doc.new_page() 
+ page.insert_text((72, 72), f"Content of page {i + 1}") + doc.save(str(pdf_path)) + doc.close() + + ctx: dict[str, Any] = { + "_arxiv_pdf_path": str(pdf_path), + "_arxiv_pdf_pages": 3, + } + result = await execute({"start_page": 1, "end_page": 3}, ctx) + assert "第 1-3 页" in result or "第 1 页" in result + assert "Content of page 1" in result + assert "Content of page 3" in result + + @pytest.mark.asyncio + async def test_out_of_range_page(self, tmp_path: Path) -> None: + from Undefined.skills.agents.arxiv_analysis_agent.tools.read_paper_pages.handler import ( + execute, + ) + + import fitz + + pdf_path = tmp_path / "paper.pdf" + doc = fitz.open() + doc.new_page() + doc.save(str(pdf_path)) + doc.close() + + ctx: dict[str, Any] = { + "_arxiv_pdf_path": str(pdf_path), + "_arxiv_pdf_pages": 1, + } + result = await execute({"start_page": 5, "end_page": 10}, ctx) + assert "错误" in result + + @pytest.mark.asyncio + async def test_invalid_page_numbers(self) -> None: + from Undefined.skills.agents.arxiv_analysis_agent.tools.read_paper_pages.handler import ( + execute, + ) + + ctx: dict[str, Any] = { + "_arxiv_pdf_path": "/some/path.pdf", + "_arxiv_pdf_pages": 10, + } + result = await execute({"start_page": "abc", "end_page": "def"}, ctx) + assert "整数" in result + + +# --------------------------------------------------------------------------- +# web_agent callable.json 更新检查 +# --------------------------------------------------------------------------- + + +class TestWebAgentCallable: + def test_web_agent_allows_new_callers(self) -> None: + web_agent_callable = ( + Path(__file__).resolve().parent.parent + / "src" + / "Undefined" + / "skills" + / "agents" + / "web_agent" + / "callable.json" + ) + cfg = json.loads(web_agent_callable.read_text(encoding="utf-8")) + callers = cfg["allowed_callers"] + assert "summary_agent" in callers + assert "arxiv_analysis_agent" in callers From 45e5a367b1894f1f4027881692ff583f3ea14d8f Mon Sep 17 00:00:00 2001 From: Null <1708213363@qq.com> Date: 
Sat, 18 Apr 2026 22:54:06 +0800 Subject: [PATCH 18/57] =?UTF-8?q?docs:=20=E6=9B=B4=E6=96=B0=20CLAUDE.md=20?= =?UTF-8?q?agents=20=E6=95=B0=E9=87=8F=E4=B8=BA=207?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- CLAUDE.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CLAUDE.md b/CLAUDE.md index f2e67a0..3e3df12 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -55,7 +55,7 @@ OneBot WebSocket → onebot.py → handlers.py → SecurityService(注入检测) |------|------| | `ai/` | AI 核心:client.py(主入口)、llm.py(模型请求)、prompts.py(Prompt构建)、tooling.py(工具管理)、multimodal.py(多模态)、model_selector.py(模型选择) | | `services/` | 运行服务:ai_coordinator.py(协调器+队列投递)、queue_manager.py(车站-列车队列)、command.py(命令分发)、model_pool.py(多模型池)、security.py(安全防护) | -| `skills/` | 热重载技能系统:tools/(原子工具)、toolsets/(11类工具集)、agents/(6个智能体)、commands/(斜杠指令)、anthropic_skills/(SKILL.md知识注入) | +| `skills/` | 热重载技能系统:tools/(原子工具)、toolsets/(11类工具集)、agents/(7个智能体)、commands/(斜杠指令)、anthropic_skills/(SKILL.md知识注入) | | `cognitive/` | 认知记忆:service.py(入口)、vector_store.py(ChromaDB)、historian.py(后台史官异步改写+侧写合并)、job_queue.py、profile_storage.py | | `memes/` | 表情包库:service.py(两阶段AI管线)、worker.py(异步处理队列)、store.py(SQLite)、vector_store.py(ChromaDB)、models.py | | `config/` | 配置系统:loader.py(TOML解析+类型化)、models.py(数据模型)、hot_reload.py(热更新) | From 9439b0ebc5f8a2b2e4a181ef9aaeb638310b3953 Mon Sep 17 00:00:00 2001 From: Null <1708213363@qq.com> Date: Sun, 19 Apr 2026 08:37:50 +0800 Subject: [PATCH 19/57] =?UTF-8?q?feat(tools):=20=E6=96=B0=E5=A2=9E=20calcu?= =?UTF-8?q?lator=20=E5=A4=9A=E5=8A=9F=E8=83=BD=E5=AE=89=E5=85=A8=E8=AE=A1?= =?UTF-8?q?=E7=AE=97=E5=99=A8?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 基于 AST 的安全表达式求值,拒绝任何非数学操作 - 支持:算术、幂运算、科学函数、三角函数、统计函数、组合数学 - 常量:pi, e, tau, inf; 函数:sqrt, log, sin, cos, factorial, gcd, mean 等 - allowed_callers: ["*"] 允许所有 agent 调用 - 安全限制:指数上限 
10000、表达式长度上限 500 - 55 个单元测试覆盖算术/科学/统计/比较/安全拒绝 Co-authored-by: Claude Opus 4.6 Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .../skills/tools/calculator/callable.json | 4 + .../skills/tools/calculator/config.json | 17 ++ .../skills/tools/calculator/handler.py | 258 +++++++++++++++++ tests/test_calculator.py | 261 ++++++++++++++++++ 4 files changed, 540 insertions(+) create mode 100644 src/Undefined/skills/tools/calculator/callable.json create mode 100644 src/Undefined/skills/tools/calculator/config.json create mode 100644 src/Undefined/skills/tools/calculator/handler.py create mode 100644 tests/test_calculator.py diff --git a/src/Undefined/skills/tools/calculator/callable.json b/src/Undefined/skills/tools/calculator/callable.json new file mode 100644 index 0000000..0a69975 --- /dev/null +++ b/src/Undefined/skills/tools/calculator/callable.json @@ -0,0 +1,4 @@ +{ + "enabled": true, + "allowed_callers": ["*"] +} diff --git a/src/Undefined/skills/tools/calculator/config.json b/src/Undefined/skills/tools/calculator/config.json new file mode 100644 index 0000000..ad86fb4 --- /dev/null +++ b/src/Undefined/skills/tools/calculator/config.json @@ -0,0 +1,17 @@ +{ + "type": "function", + "function": { + "name": "calculator", + "description": "安全的多功能数学计算器。支持算术运算、科学函数、统计函数和常量。输入数学表达式即可计算。", + "parameters": { + "type": "object", + "properties": { + "expression": { + "type": "string", + "description": "数学表达式,例如:'2+3*4'、'sqrt(144)'、'sin(pi/6)'、'log(1000,10)'、'mean(1,2,3,4,5)'、'2**10'、'factorial(10)'、'gcd(48,18)'" + } + }, + "required": ["expression"] + } + } +} diff --git a/src/Undefined/skills/tools/calculator/handler.py b/src/Undefined/skills/tools/calculator/handler.py new file mode 100644 index 0000000..cd26438 --- /dev/null +++ b/src/Undefined/skills/tools/calculator/handler.py @@ -0,0 +1,258 @@ +"""安全的多功能数学计算器。 + +通过 AST 解析数学表达式,仅允许数学运算,拒绝任何危险操作。 +支持:算术、幂运算、科学函数、统计函数、常量。 +""" + +from __future__ import annotations + +import ast +import math +import 
operator +import statistics +from typing import Any + +_UNARY_OPS: dict[type, Any] = { + ast.UAdd: operator.pos, + ast.USub: operator.neg, +} + +_BIN_OPS: dict[type, Any] = { + ast.Add: operator.add, + ast.Sub: operator.sub, + ast.Mult: operator.mul, + ast.Div: operator.truediv, + ast.FloorDiv: operator.floordiv, + ast.Mod: operator.mod, + ast.Pow: operator.pow, + ast.BitXor: operator.pow, # ^ 也当幂运算(常见误用) +} + +_COMPARE_OPS: dict[type, Any] = { + ast.Eq: operator.eq, + ast.NotEq: operator.ne, + ast.Lt: operator.lt, + ast.LtE: operator.le, + ast.Gt: operator.gt, + ast.GtE: operator.ge, +} + +_CONSTANTS: dict[str, float] = { + "pi": math.pi, + "PI": math.pi, + "e": math.e, + "E": math.e, + "tau": math.tau, + "inf": math.inf, + "nan": math.nan, +} + + +def _stat_fn(fn_name: str, args: list[float | int]) -> float | int: + """统计函数分发。""" + if len(args) < 1: + raise ValueError(f"{fn_name} 至少需要 1 个参数") + fn_map: dict[str, Any] = { + "mean": statistics.mean, + "median": statistics.median, + "stdev": statistics.stdev, + "variance": statistics.variance, + "pstdev": statistics.pstdev, + "pvariance": statistics.pvariance, + "harmonic_mean": statistics.harmonic_mean, + } + fn = fn_map.get(fn_name) + if fn is None: + raise ValueError(f"未知统计函数: {fn_name}") + if fn_name in ("stdev", "variance") and len(args) < 2: + raise ValueError(f"{fn_name} 至少需要 2 个数据点") + result: float | int = fn(args) + return result + + +_MATH_FUNCS: dict[str, Any] = { + # 基础 + "abs": abs, + "round": round, + "int": int, + "float": float, + # 幂与对数 + "sqrt": math.sqrt, + "cbrt": lambda x: x ** (1 / 3), + "pow": math.pow, + "exp": math.exp, + "log": math.log, + "log2": math.log2, + "log10": math.log10, + "ln": math.log, + # 三角 + "sin": math.sin, + "cos": math.cos, + "tan": math.tan, + "asin": math.asin, + "acos": math.acos, + "atan": math.atan, + "atan2": math.atan2, + "sinh": math.sinh, + "cosh": math.cosh, + "tanh": math.tanh, + # 角度转换 + "degrees": math.degrees, + "radians": math.radians, + # 取整 + "ceil": 
math.ceil, + "floor": math.floor, + "trunc": math.trunc, + # 组合数学 + "factorial": math.factorial, + "comb": math.comb, + "perm": math.perm, + "gcd": math.gcd, + "lcm": math.lcm, + # 其他 + "hypot": math.hypot, + "copysign": math.copysign, + "fmod": math.fmod, + "isqrt": math.isqrt, + # 统计(占位,实际调用 _stat_fn) + "mean": None, + "median": None, + "stdev": None, + "variance": None, + "pstdev": None, + "pvariance": None, + "harmonic_mean": None, + # 最值 + "max": max, + "min": min, + "sum": sum, +} + +_STAT_FUNCS = frozenset( + {"mean", "median", "stdev", "variance", "pstdev", "pvariance", "harmonic_mean"} +) + +_MAX_POWER = 10000 +_MAX_EXPRESSION_LENGTH = 500 + + +class _SafeEvaluator(ast.NodeVisitor): + """安全 AST 求值器,仅允许数学运算。""" + + def visit(self, node: ast.AST) -> Any: + return super().visit(node) + + def generic_visit(self, node: ast.AST) -> Any: + raise ValueError(f"不支持的表达式语法: {type(node).__name__}") + + def visit_Expression(self, node: ast.Expression) -> Any: + return self.visit(node.body) + + def visit_Constant(self, node: ast.Constant) -> Any: + if isinstance(node.value, (int, float, complex)): + return node.value + raise ValueError(f"不支持的常量类型: {type(node.value).__name__}") + + def visit_Name(self, node: ast.Name) -> Any: + if node.id in _CONSTANTS: + return _CONSTANTS[node.id] + raise ValueError(f"未知变量: {node.id}") + + def visit_UnaryOp(self, node: ast.UnaryOp) -> Any: + op_fn = _UNARY_OPS.get(type(node.op)) + if op_fn is None: + raise ValueError(f"不支持的一元运算: {type(node.op).__name__}") + return op_fn(self.visit(node.operand)) + + def visit_BinOp(self, node: ast.BinOp) -> Any: + op_fn = _BIN_OPS.get(type(node.op)) + if op_fn is None: + raise ValueError(f"不支持的运算: {type(node.op).__name__}") + left = self.visit(node.left) + right = self.visit(node.right) + if isinstance(node.op, (ast.Pow, ast.BitXor)): + if isinstance(right, (int, float)) and abs(right) > _MAX_POWER: + raise ValueError(f"指数过大: {right}(上限 {_MAX_POWER})") + return op_fn(left, right) + + def 
visit_Compare(self, node: ast.Compare) -> Any: + left = self.visit(node.left) + for op, comparator in zip(node.ops, node.comparators): + op_fn = _COMPARE_OPS.get(type(op)) + if op_fn is None: + raise ValueError(f"不支持的比较: {type(op).__name__}") + right = self.visit(comparator) + if not op_fn(left, right): + return False + left = right + return True + + def visit_Call(self, node: ast.Call) -> Any: + if not isinstance(node.func, ast.Name): + raise ValueError("仅支持直接函数调用") + fn_name = node.func.id + if fn_name not in _MATH_FUNCS: + raise ValueError(f"未知函数: {fn_name}") + if node.keywords: + raise ValueError("不支持关键字参数") + + args = [self.visit(arg) for arg in node.args] + + if fn_name in _STAT_FUNCS: + return _stat_fn(fn_name, args) + + fn = _MATH_FUNCS[fn_name] + return fn(*args) + + def visit_IfExp(self, node: ast.IfExp) -> Any: + condition = self.visit(node.test) + return self.visit(node.body) if condition else self.visit(node.orelse) + + def visit_Tuple(self, node: ast.Tuple) -> Any: + return tuple(self.visit(elt) for elt in node.elts) + + def visit_List(self, node: ast.List) -> Any: + return [self.visit(elt) for elt in node.elts] + + +def safe_eval(expression: str) -> str: + """安全计算数学表达式,返回字符串结果。""" + expr = expression.strip() + if not expr: + raise ValueError("表达式为空") + if len(expr) > _MAX_EXPRESSION_LENGTH: + raise ValueError( + f"表达式过长({len(expr)} 字符,上限 {_MAX_EXPRESSION_LENGTH})" + ) + + tree = ast.parse(expr, mode="eval") + result = _SafeEvaluator().visit(tree) + + if isinstance(result, float): + if result == int(result) and not (math.isinf(result) or math.isnan(result)): + return str(int(result)) + return f"{result:.10g}" + if isinstance(result, complex): + return str(result) + if isinstance(result, bool): + return str(result) + return str(result) + + +async def execute(args: dict[str, Any], context: dict[str, Any]) -> str: + """计算数学表达式。""" + expression = str(args.get("expression", "")).strip() + if not expression: + return "请提供数学表达式" + + try: + result = 
safe_eval(expression) + return f"{expression} = {result}" + except ZeroDivisionError: + return f"计算错误:除以零 ({expression})" + except OverflowError: + return f"计算错误:结果溢出 ({expression})" + except (ValueError, TypeError) as exc: + return f"计算错误:{exc}" + except SyntaxError: + return f"表达式语法错误:{expression}" diff --git a/tests/test_calculator.py b/tests/test_calculator.py new file mode 100644 index 0000000..118759a --- /dev/null +++ b/tests/test_calculator.py @@ -0,0 +1,261 @@ +"""calculator 工具单元测试。""" + +from __future__ import annotations + +import json +import math +from pathlib import Path +from typing import Any + +import pytest + +from Undefined.skills.tools.calculator.handler import safe_eval + +TOOL_DIR = ( + Path(__file__).resolve().parent.parent + / "src" + / "Undefined" + / "skills" + / "tools" + / "calculator" +) + + +# --------------------------------------------------------------------------- +# 配置文件检查 +# --------------------------------------------------------------------------- + + +class TestConfig: + def test_config_json(self) -> None: + cfg: dict[str, Any] = json.loads( + (TOOL_DIR / "config.json").read_text(encoding="utf-8") + ) + assert cfg["function"]["name"] == "calculator" + assert "expression" in cfg["function"]["parameters"]["properties"] + + def test_callable_json(self) -> None: + cfg: dict[str, Any] = json.loads( + (TOOL_DIR / "callable.json").read_text(encoding="utf-8") + ) + assert cfg["enabled"] is True + assert "*" in cfg["allowed_callers"] + + +# --------------------------------------------------------------------------- +# safe_eval 单元测试 +# --------------------------------------------------------------------------- + + +class TestArithmetic: + def test_addition(self) -> None: + assert safe_eval("2+3") == "5" + + def test_subtraction(self) -> None: + assert safe_eval("10-3") == "7" + + def test_multiplication(self) -> None: + assert safe_eval("4*5") == "20" + + def test_division(self) -> None: + assert safe_eval("10/3") == "3.333333333" + + 
def test_floor_division(self) -> None: + assert safe_eval("10//3") == "3" + + def test_modulo(self) -> None: + assert safe_eval("10%3") == "1" + + def test_power(self) -> None: + assert safe_eval("2**10") == "1024" + + def test_caret_as_power(self) -> None: + assert safe_eval("2^10") == "1024" + + def test_negative_number(self) -> None: + assert safe_eval("-5+3") == "-2" + + def test_complex_expression(self) -> None: + assert safe_eval("(2+3)*4-1") == "19" + + def test_floating_point(self) -> None: + assert safe_eval("0.1+0.2") == "0.3" + + +class TestConstants: + def test_pi(self) -> None: + result = float(safe_eval("pi")) + assert abs(result - math.pi) < 1e-8 + + def test_e(self) -> None: + result = float(safe_eval("e")) + assert abs(result - math.e) < 1e-8 + + def test_tau(self) -> None: + result = float(safe_eval("tau")) + assert abs(result - math.tau) < 1e-8 + + +class TestScientificFunctions: + def test_sqrt(self) -> None: + assert safe_eval("sqrt(144)") == "12" + + def test_sin(self) -> None: + result = float(safe_eval("sin(pi/6)")) + assert abs(result - 0.5) < 1e-10 + + def test_cos(self) -> None: + result = float(safe_eval("cos(0)")) + assert abs(result - 1.0) < 1e-10 + + def test_log10(self) -> None: + assert safe_eval("log10(1000)") == "3" + + def test_log_with_base(self) -> None: + assert safe_eval("log(8, 2)") == "3" + + def test_ln(self) -> None: + result = float(safe_eval("ln(e)")) + assert abs(result - 1.0) < 1e-10 + + def test_factorial(self) -> None: + assert safe_eval("factorial(10)") == "3628800" + + def test_degrees(self) -> None: + result = float(safe_eval("degrees(pi)")) + assert abs(result - 180.0) < 1e-10 + + def test_radians(self) -> None: + result = float(safe_eval("radians(180)")) + assert abs(result - math.pi) < 1e-8 + + def test_ceil(self) -> None: + assert safe_eval("ceil(3.2)") == "4" + + def test_floor(self) -> None: + assert safe_eval("floor(3.8)") == "3" + + def test_abs(self) -> None: + assert safe_eval("abs(-42)") == "42" + + 
def test_gcd(self) -> None: + assert safe_eval("gcd(48, 18)") == "6" + + def test_lcm(self) -> None: + assert safe_eval("lcm(4, 6)") == "12" + + def test_comb(self) -> None: + assert safe_eval("comb(10, 3)") == "120" + + def test_perm(self) -> None: + assert safe_eval("perm(5, 3)") == "60" + + def test_hypot(self) -> None: + assert safe_eval("hypot(3, 4)") == "5" + + +class TestStatistics: + def test_mean(self) -> None: + assert safe_eval("mean(1, 2, 3, 4, 5)") == "3" + + def test_median(self) -> None: + assert safe_eval("median(1, 3, 5, 7, 9)") == "5" + + def test_stdev(self) -> None: + result = float(safe_eval("stdev(2, 4, 4, 4, 5, 5, 7, 9)")) + assert result > 0 + + def test_min_max(self) -> None: + assert safe_eval("min(3, 1, 4, 1, 5)") == "1" + assert safe_eval("max(3, 1, 4, 1, 5)") == "5" + + def test_sum(self) -> None: + assert safe_eval("sum([1, 2, 3, 4, 5])") == "15" + + +class TestComparison: + def test_equal(self) -> None: + assert safe_eval("2+2 == 4") == "True" + + def test_not_equal(self) -> None: + assert safe_eval("2+2 != 5") == "True" + + def test_less_than(self) -> None: + assert safe_eval("3 < 5") == "True" + + def test_greater_than(self) -> None: + assert safe_eval("5 > 3") == "True" + + +class TestIfExpression: + def test_ternary(self) -> None: + assert safe_eval("42 if 2>1 else 0") == "42" + + +class TestSafety: + def test_rejects_import(self) -> None: + with pytest.raises(ValueError, match="未知函数"): + safe_eval("__import__('os')") + + def test_rejects_attribute_access(self) -> None: + with pytest.raises(ValueError, match="不支持"): + safe_eval("math.pi") + + def test_rejects_unknown_function(self) -> None: + with pytest.raises(ValueError, match="未知函数"): + safe_eval("eval('1+1')") + + def test_rejects_unknown_variable(self) -> None: + with pytest.raises(ValueError, match="未知变量"): + safe_eval("x + 1") + + def test_rejects_too_large_exponent(self) -> None: + with pytest.raises(ValueError, match="指数过大"): + safe_eval("2**100000") + + def 
test_rejects_too_long_expression(self) -> None: + with pytest.raises(ValueError, match="表达式过长"): + safe_eval("1+" * 300 + "1") + + def test_rejects_empty(self) -> None: + with pytest.raises(ValueError, match="表达式为空"): + safe_eval("") + + def test_division_by_zero(self) -> None: + with pytest.raises(ZeroDivisionError): + safe_eval("1/0") + + +# --------------------------------------------------------------------------- +# execute() 集成测试 +# --------------------------------------------------------------------------- + + +class TestExecute: + @pytest.mark.asyncio + async def test_basic_calculation(self) -> None: + from Undefined.skills.tools.calculator.handler import execute + + result = await execute({"expression": "2+3*4"}, {}) + assert "= 14" in result + + @pytest.mark.asyncio + async def test_empty_expression(self) -> None: + from Undefined.skills.tools.calculator.handler import execute + + result = await execute({"expression": ""}, {}) + assert "请提供" in result + + @pytest.mark.asyncio + async def test_error_message(self) -> None: + from Undefined.skills.tools.calculator.handler import execute + + result = await execute({"expression": "1/0"}, {}) + assert "除以零" in result + + @pytest.mark.asyncio + async def test_syntax_error(self) -> None: + from Undefined.skills.tools.calculator.handler import execute + + result = await execute({"expression": "2+++"}, {}) + assert "语法错误" in result From 56d277c9ab680757da8874e66ba257039faa2963 Mon Sep 17 00:00:00 2001 From: Null <1708213363@qq.com> Date: Sun, 19 Apr 2026 08:56:18 +0800 Subject: [PATCH 20/57] =?UTF-8?q?feat(config):=20=E6=B6=88=E6=81=AF?= =?UTF-8?q?=E5=8E=86=E5=8F=B2=E9=99=90=E5=88=B6=E5=85=A8=E9=9D=A2=E5=8F=AF?= =?UTF-8?q?=E9=85=8D=E7=BD=AE=E5=8C=96?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Config 新增 7 个 history_* 配置项,通过 [history] 配置节管理 - 所有消息获取/搜索/分析工具的硬编码限制改为从配置读取 - history_max_records 支持 0=无限制(默认 10000) - 提升默认值:filtered_result_limit 50→200, summary_fetch_limit 
500→1000, summary_time_fetch_limit 2000→5000, onebot_fetch_limit 5000→10000, group_analysis_limit 100→500 - config.toml.example 新增双语注释的完整配置节 - 新增 test_history_config.py 验证配置字段和 helper 逻辑 - 更新 test_fetch_messages_tool.py 适配新默认值 Co-authored-by: Claude Opus 4.6 Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- CLAUDE.md | 2 +- config.toml.example | 22 ++++- src/Undefined/config/loader.py | 86 ++++++++++++++++++- .../tools/fetch_messages/config.json | 2 +- .../tools/fetch_messages/handler.py | 28 +++++- .../analyze_join_statistics/config.json | 2 +- .../analyze_join_statistics/handler.py | 4 +- .../analyze_member_messages/config.json | 4 +- .../analyze_member_messages/handler.py | 7 +- .../analyze_new_member_activity/config.json | 2 +- .../analyze_new_member_activity/handler.py | 4 +- .../messages/get_messages_by_time/handler.py | 23 +++-- .../messages/get_recent_messages/handler.py | 21 +++-- src/Undefined/utils/history.py | 42 +++++---- src/Undefined/utils/recent_messages.py | 9 +- tests/test_fetch_messages_tool.py | 55 ++++++++++-- tests/test_history_config.py | 76 ++++++++++++++++ 17 files changed, 337 insertions(+), 52 deletions(-) create mode 100644 tests/test_history_config.py diff --git a/CLAUDE.md b/CLAUDE.md index 3e3df12..d699463 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -80,7 +80,7 @@ OneBot WebSocket → onebot.py → handlers.py → SecurityService(注入检测) 车站-列车模型(QueueManager):按模型隔离队列组,4 级优先级(超管 > 私聊 > @提及 > 普通群聊),普通队列自动修剪保留最新 2 条,非阻塞按节奏发车(默认 1Hz)。 ### 存储与数据 -- `data/history/` — 消息历史(group_*.json / private_*.json,10000 条限制) +- `data/history/` — 消息历史(group_*.json / private_*.json,默认 10000 条,可通过 `[history]` 配置节调整,0=无限制) - `data/cognitive/` — ChromaDB 向量库 + profiles/ 侧写 + queues/ 任务队列 - `data/memes/` — 表情包库(blobs原图、previews预览图、memes.sqlite3元数据、chromadb向量检索) - `data/memory.json` — 置顶备忘录(500 条上限) diff --git a/config.toml.example b/config.toml.example index 4893aff..4ab0d04 100644 --- a/config.toml.example +++ b/config.toml.example @@ -726,9 
+726,27 @@ inverted_question_enabled = false # zh: 历史记录配置。 # en: History settings. [history] -# zh: 每个会话最多保留的消息条数。 -# en: Max messages to keep per conversation. +# zh: 每个会话最多保留的消息条数(0 = 无限制,注意内存占用)。 +# en: Max messages to keep per conversation (0 = unlimited, mind memory usage). max_records = 10000 +# zh: 工具过滤查询返回的最大消息条数。 +# en: Max messages returned by tool filtered queries. +filtered_result_limit = 200 +# zh: 工具过滤搜索时扫描的最大消息条数。 +# en: Max messages to scan when tools perform filtered searches. +search_scan_limit = 10000 +# zh: 总结 agent 单次拉取的最大消息条数。 +# en: Max messages the summary agent can fetch per call. +summary_fetch_limit = 1000 +# zh: 总结 agent 按时间范围查询时的最大拉取条数。 +# en: Max messages the summary agent fetches for time-range queries. +summary_time_fetch_limit = 5000 +# zh: OneBot API 回退获取的最大消息条数。 +# en: Max messages to fetch via OneBot API fallback. +onebot_fetch_limit = 10000 +# zh: 群分析工具的消息/成员返回上限。 +# en: Max messages/members returned by group analysis tools. +group_analysis_limit = 500 # zh: Skills 热重载配置(可选)。 # en: Skills hot reload settings (optional). 
diff --git a/src/Undefined/config/loader.py b/src/Undefined/config/loader.py index 998a3d7..ac65473 100644 --- a/src/Undefined/config/loader.py +++ b/src/Undefined/config/loader.py @@ -504,6 +504,12 @@ class Config: token_usage_max_total_mb: int token_usage_archive_prune_mode: str history_max_records: int + history_filtered_result_limit: int + history_search_scan_limit: int + history_summary_fetch_limit: int + history_summary_time_fetch_limit: int + history_onebot_fetch_limit: int + history_group_analysis_limit: int skills_hot_reload: bool skills_hot_reload_interval: float skills_hot_reload_debounce: float @@ -1036,8 +1042,78 @@ def load(cls, config_path: Optional[Path] = None, strict: bool = True) -> "Confi "delete", ) - history_max_records = _coerce_int( - _get_value(data, ("history", "max_records"), "HISTORY_MAX_RECORDS"), 10000 + history_max_records = max( + 0, + _coerce_int( + _get_value(data, ("history", "max_records"), "HISTORY_MAX_RECORDS"), + 10000, + ), + ) + history_filtered_result_limit = max( + 1, + _coerce_int( + _get_value( + data, + ("history", "filtered_result_limit"), + "HISTORY_FILTERED_RESULT_LIMIT", + ), + 200, + ), + ) + history_search_scan_limit = max( + 1, + _coerce_int( + _get_value( + data, + ("history", "search_scan_limit"), + "HISTORY_SEARCH_SCAN_LIMIT", + ), + 10000, + ), + ) + history_summary_fetch_limit = max( + 1, + _coerce_int( + _get_value( + data, + ("history", "summary_fetch_limit"), + "HISTORY_SUMMARY_FETCH_LIMIT", + ), + 1000, + ), + ) + history_summary_time_fetch_limit = max( + 1, + _coerce_int( + _get_value( + data, + ("history", "summary_time_fetch_limit"), + "HISTORY_SUMMARY_TIME_FETCH_LIMIT", + ), + 5000, + ), + ) + history_onebot_fetch_limit = max( + 1, + _coerce_int( + _get_value( + data, + ("history", "onebot_fetch_limit"), + "HISTORY_ONEBOT_FETCH_LIMIT", + ), + 10000, + ), + ) + history_group_analysis_limit = max( + 1, + _coerce_int( + _get_value( + data, + ("history", "group_analysis_limit"), + 
"HISTORY_GROUP_ANALYSIS_LIMIT", + ), + 500, + ), ) skills_hot_reload = _coerce_bool( @@ -1451,6 +1527,12 @@ def load(cls, config_path: Optional[Path] = None, strict: bool = True) -> "Confi token_usage_archive_prune_mode=token_usage_archive_prune_mode, skills_hot_reload=skills_hot_reload, history_max_records=history_max_records, + history_filtered_result_limit=history_filtered_result_limit, + history_search_scan_limit=history_search_scan_limit, + history_summary_fetch_limit=history_summary_fetch_limit, + history_summary_time_fetch_limit=history_summary_time_fetch_limit, + history_onebot_fetch_limit=history_onebot_fetch_limit, + history_group_analysis_limit=history_group_analysis_limit, skills_hot_reload_interval=skills_hot_reload_interval, skills_hot_reload_debounce=skills_hot_reload_debounce, agent_intro_autogen_enabled=agent_intro_autogen_enabled, diff --git a/src/Undefined/skills/agents/summary_agent/tools/fetch_messages/config.json b/src/Undefined/skills/agents/summary_agent/tools/fetch_messages/config.json index d64a015..30042bc 100644 --- a/src/Undefined/skills/agents/summary_agent/tools/fetch_messages/config.json +++ b/src/Undefined/skills/agents/summary_agent/tools/fetch_messages/config.json @@ -8,7 +8,7 @@ "properties": { "count": { "type": "integer", - "description": "要获取的消息条数,默认50,最大500。" + "description": "要获取的消息条数,默认50,上限由服务器配置决定。" }, "time_range": { "type": "string", diff --git a/src/Undefined/skills/agents/summary_agent/tools/fetch_messages/handler.py b/src/Undefined/skills/agents/summary_agent/tools/fetch_messages/handler.py index c79cb07..4426752 100644 --- a/src/Undefined/skills/agents/summary_agent/tools/fetch_messages/handler.py +++ b/src/Undefined/skills/agents/summary_agent/tools/fetch_messages/handler.py @@ -12,9 +12,21 @@ _TIME_RANGE_PATTERN = re.compile(r"^(\d+)([hHdDwW])$") _TIME_UNIT_SECONDS = {"h": 3600, "d": 86400, "w": 604800} -_MAX_COUNT = 500 _DEFAULT_COUNT = 50 -_MAX_FETCH_FOR_TIME_FILTER = 2000 + +# 以下值仅作为 runtime_config 缺失时的回退 
+_FALLBACK_MAX_COUNT = 1000 +_FALLBACK_MAX_FETCH_FOR_TIME_FILTER = 5000 + + +def _get_history_limit(context: dict[str, Any], key: str, fallback: int) -> int: + """从 runtime_config 读取历史限制配置。""" + cfg = context.get("runtime_config") + if cfg is not None: + val = getattr(cfg, key, None) + if isinstance(val, int) and val > 0: + return val + return fallback def _parse_time_range(value: str) -> int | None: @@ -146,8 +158,11 @@ async def execute(args: dict[str, Any], context: dict[str, Any]) -> str: time_range_str = str(args.get("time_range", "")).strip() raw_count = args.get("count", _DEFAULT_COUNT) + max_count = _get_history_limit( + context, "history_summary_fetch_limit", _FALLBACK_MAX_COUNT + ) try: - count = min(max(int(raw_count), 1), _MAX_COUNT) + count = min(max(int(raw_count), 1), max_count) except (TypeError, ValueError): count = _DEFAULT_COUNT @@ -155,7 +170,12 @@ async def execute(args: dict[str, Any], context: dict[str, Any]) -> str: seconds = _parse_time_range(time_range_str) if seconds is None: return f"无法解析时间范围: {time_range_str}(支持格式: 1h, 6h, 1d, 7d)" - fetch_count = max(count * 2, _MAX_FETCH_FOR_TIME_FILTER) + max_time_fetch = _get_history_limit( + context, + "history_summary_time_fetch_limit", + _FALLBACK_MAX_FETCH_FOR_TIME_FILTER, + ) + fetch_count = max(count * 2, max_time_fetch) messages = history_manager.get_recent(chat_id, chat_type, 0, fetch_count) if messages: messages = _filter_by_time(messages, seconds) diff --git a/src/Undefined/skills/toolsets/group_analysis/analyze_join_statistics/config.json b/src/Undefined/skills/toolsets/group_analysis/analyze_join_statistics/config.json index 410c0a6..c20eebe 100644 --- a/src/Undefined/skills/toolsets/group_analysis/analyze_join_statistics/config.json +++ b/src/Undefined/skills/toolsets/group_analysis/analyze_join_statistics/config.json @@ -30,7 +30,7 @@ }, "member_limit": { "type": "integer", - "description": "返回成员列表的数量限制(最大100),默认 20", + "description": "返回成员列表的数量限制,上限由服务器配置决定,默认 20", "default": 20 } }, 
diff --git a/src/Undefined/skills/toolsets/group_analysis/analyze_join_statistics/handler.py b/src/Undefined/skills/toolsets/group_analysis/analyze_join_statistics/handler.py index 9ba6267..fe660ed 100644 --- a/src/Undefined/skills/toolsets/group_analysis/analyze_join_statistics/handler.py +++ b/src/Undefined/skills/toolsets/group_analysis/analyze_join_statistics/handler.py @@ -34,7 +34,9 @@ async def execute(args: Dict[str, Any], context: Dict[str, Any]) -> str: member_limit = int(member_limit_raw) if member_limit_raw is not None else 20 if member_limit < 0: return "参数错误:member_limit 必须是非负整数" - member_limit = min(member_limit, 100) + cfg = context.get("runtime_config") + analysis_cap = getattr(cfg, "history_group_analysis_limit", 500) if cfg else 500 + member_limit = min(member_limit, analysis_cap) except (ValueError, TypeError): return "参数类型错误:member_limit 必须是整数" diff --git a/src/Undefined/skills/toolsets/group_analysis/analyze_member_messages/config.json b/src/Undefined/skills/toolsets/group_analysis/analyze_member_messages/config.json index 188e235..6b3ce00 100644 --- a/src/Undefined/skills/toolsets/group_analysis/analyze_member_messages/config.json +++ b/src/Undefined/skills/toolsets/group_analysis/analyze_member_messages/config.json @@ -29,12 +29,12 @@ }, "message_limit": { "type": "integer", - "description": "返回消息内容的数量限制(最大100),默认 20", + "description": "返回消息内容的数量限制,上限由服务器配置决定,默认 20", "default": 20 }, "max_history_count": { "type": "integer", - "description": "最多获取的历史消息数量(最大5000),默认 2000", + "description": "最多获取的历史消息数量,上限由服务器配置决定,默认 2000", "default": 2000 } }, diff --git a/src/Undefined/skills/toolsets/group_analysis/analyze_member_messages/handler.py b/src/Undefined/skills/toolsets/group_analysis/analyze_member_messages/handler.py index 31b8ef8..dd00117 100644 --- a/src/Undefined/skills/toolsets/group_analysis/analyze_member_messages/handler.py +++ b/src/Undefined/skills/toolsets/group_analysis/analyze_member_messages/handler.py @@ -42,7 +42,9 @@ async def 
execute(args: Dict[str, Any], context: Dict[str, Any]) -> str: message_limit = int(message_limit_raw) if message_limit_raw is not None else 20 if message_limit < 0: return "参数错误:message_limit 必须是非负整数" - message_limit = min(message_limit, 100) + cfg = context.get("runtime_config") + analysis_cap = getattr(cfg, "history_group_analysis_limit", 500) if cfg else 500 + message_limit = min(message_limit, analysis_cap) except (ValueError, TypeError): return "参数类型错误:message_limit 必须是整数" @@ -53,7 +55,8 @@ async def execute(args: Dict[str, Any], context: Dict[str, Any]) -> str: ) if max_history_count < 0: return "参数错误:max_history_count 必须是非负整数" - max_history_count = min(max_history_count, 5000) + fetch_cap = getattr(cfg, "history_search_scan_limit", 10000) if cfg else 10000 + max_history_count = min(max_history_count, fetch_cap) except (ValueError, TypeError): return "参数类型错误:max_history_count 必须是整数" diff --git a/src/Undefined/skills/toolsets/group_analysis/analyze_new_member_activity/config.json b/src/Undefined/skills/toolsets/group_analysis/analyze_new_member_activity/config.json index 118b896..2a31c98 100644 --- a/src/Undefined/skills/toolsets/group_analysis/analyze_new_member_activity/config.json +++ b/src/Undefined/skills/toolsets/group_analysis/analyze_new_member_activity/config.json @@ -20,7 +20,7 @@ }, "max_history_count": { "type": "integer", - "description": "最多获取的历史消息数量(最大5000),默认 2000", + "description": "最多获取的历史消息数量,上限由服务器配置决定,默认 2000", "default": 2000 }, "top_count": { diff --git a/src/Undefined/skills/toolsets/group_analysis/analyze_new_member_activity/handler.py b/src/Undefined/skills/toolsets/group_analysis/analyze_new_member_activity/handler.py index fbbec91..a8c5d7c 100644 --- a/src/Undefined/skills/toolsets/group_analysis/analyze_new_member_activity/handler.py +++ b/src/Undefined/skills/toolsets/group_analysis/analyze_new_member_activity/handler.py @@ -34,7 +34,9 @@ async def execute(args: Dict[str, Any], context: Dict[str, Any]) -> str: ) if 
max_history_count < 0: return "参数错误:max_history_count 必须是非负整数" - max_history_count = min(max_history_count, 5000) + cfg = context.get("runtime_config") + fetch_cap = getattr(cfg, "history_search_scan_limit", 10000) if cfg else 10000 + max_history_count = min(max_history_count, fetch_cap) except (ValueError, TypeError): return "参数类型错误:max_history_count 必须是整数" diff --git a/src/Undefined/skills/toolsets/messages/get_messages_by_time/handler.py b/src/Undefined/skills/toolsets/messages/get_messages_by_time/handler.py index 12ff836..96aaa89 100644 --- a/src/Undefined/skills/toolsets/messages/get_messages_by_time/handler.py +++ b/src/Undefined/skills/toolsets/messages/get_messages_by_time/handler.py @@ -1,7 +1,15 @@ from typing import Any, Dict from datetime import datetime -_FILTERED_RESULT_LIMIT = 50 + +def _get_history_limit(context: Dict[str, Any], key: str, fallback: int) -> int: + """从 runtime_config 读取历史限制配置。""" + cfg = context.get("runtime_config") + if cfg is not None: + val = getattr(cfg, key, None) + if isinstance(val, int) and val > 0: + return val + return fallback def _resolve_chat_id(chat_id: str, msg_type: str, history_manager: Any) -> str: @@ -151,8 +159,11 @@ async def execute(args: Dict[str, Any], context: Dict[str, Any]) -> str: resolved_chat_id = _resolve_chat_id(chat_id, msg_type, history_manager) if get_recent_messages_callback: + search_scan_limit = _get_history_limit( + context, "history_search_scan_limit", 10000 + ) messages = await get_recent_messages_callback( - resolved_chat_id, msg_type, 0, 10000 + resolved_chat_id, msg_type, 0, search_scan_limit ) # 时间范围过滤 @@ -164,10 +175,12 @@ async def execute(args: Dict[str, Any], context: Dict[str, Any]) -> str: filtered_messages, keyword, sender ) - # 限制最多返回 50 条 + filtered_result_limit = _get_history_limit( + context, "history_filtered_result_limit", 200 + ) total_matched = len(filtered_messages) - if total_matched > _FILTERED_RESULT_LIMIT: - filtered_messages = 
filtered_messages[:_FILTERED_RESULT_LIMIT] + if total_matched > filtered_result_limit: + filtered_messages = filtered_messages[:filtered_result_limit] formatted = [] for msg in filtered_messages: diff --git a/src/Undefined/skills/toolsets/messages/get_recent_messages/handler.py b/src/Undefined/skills/toolsets/messages/get_recent_messages/handler.py index 54874bb..fbfd906 100644 --- a/src/Undefined/skills/toolsets/messages/get_recent_messages/handler.py +++ b/src/Undefined/skills/toolsets/messages/get_recent_messages/handler.py @@ -2,7 +2,15 @@ from typing import Any, Dict -_FILTERED_RESULT_LIMIT = 50 + +def _get_history_limit(context: Dict[str, Any], key: str, fallback: int) -> int: + """从 runtime_config 读取历史限制配置。""" + cfg = context.get("runtime_config") + if cfg is not None: + val = getattr(cfg, key, None) + if isinstance(val, int) and val > 0: + return val + return fallback def _find_group_id_by_name(chat_id: str, history_manager: Any) -> str | None: @@ -230,8 +238,9 @@ async def execute(args: Dict[str, Any], context: Dict[str, Any]) -> str: # 获取消息:有过滤条件时扩大获取范围 messages = [] + search_scan_limit = _get_history_limit(context, "history_search_scan_limit", 10000) fetch_start = 0 if has_filter else start - fetch_end = 10000 if has_filter else end + fetch_end = search_scan_limit if has_filter else end if get_recent_messages_callback: messages = await get_recent_messages_callback( @@ -265,9 +274,11 @@ async def execute(args: Dict[str, Any], context: Dict[str, Any]) -> str: # 对过滤结果进行分页切片 messages = messages[start:end] - # 限制最多返回 50 条 - if len(messages) > _FILTERED_RESULT_LIMIT: - messages = messages[:_FILTERED_RESULT_LIMIT] + filtered_result_limit = _get_history_limit( + context, "history_filtered_result_limit", 200 + ) + if len(messages) > filtered_result_limit: + messages = messages[:filtered_result_limit] # 格式化消息 formatted = [_format_message_xml(msg) for msg in messages] diff --git a/src/Undefined/utils/history.py b/src/Undefined/utils/history.py index 
b6dfa4b..578711e 100644 --- a/src/Undefined/utils/history.py +++ b/src/Undefined/utils/history.py @@ -81,17 +81,20 @@ def _get_private_history_path(self, user_id: int) -> str: async def _save_history_to_file( self, history: list[dict[str, Any]], path: str ) -> None: - """异步保存历史记录到文件(最多 10000 条)""" + """异步保存历史记录到文件""" from Undefined.utils import io try: - # 只保留最近的 self._max_records 条 - truncated_history = ( - history[-self._max_records :] - if len(history) > self._max_records - else history - ) - truncated = len(history) > self._max_records + if self._max_records > 0: + truncated_history = ( + history[-self._max_records :] + if len(history) > self._max_records + else history + ) + truncated = len(history) > self._max_records + else: + truncated_history = history + truncated = False logger.debug( f"[历史记录] 准备保存: path={path}, total={len(history)}, truncated={truncated}" @@ -210,12 +213,13 @@ async def _load_history_from_file(self, path: str) -> list[dict[str, Any]]: normalized_history.append(msg) - # 只保留最近的 self._max_records 条 - return ( - normalized_history[-self._max_records :] - if len(normalized_history) > self._max_records - else normalized_history - ) + if self._max_records > 0: + return ( + normalized_history[-self._max_records :] + if len(normalized_history) > self._max_records + else normalized_history + ) + return normalized_history except Exception as e: logger.error(f"加载历史记录失败 {path}: {e}") @@ -346,7 +350,10 @@ async def add_group_message( self._message_history[group_id_str].append(record) - if len(self._message_history[group_id_str]) > self._max_records: + if ( + self._max_records > 0 + and len(self._message_history[group_id_str]) > self._max_records + ): self._message_history[group_id_str] = self._message_history[ group_id_str ][-self._max_records :] @@ -395,7 +402,10 @@ async def add_private_message( self._private_message_history[user_id_str].append(record) - if len(self._private_message_history[user_id_str]) > self._max_records: + if ( + 
self._max_records > 0 + and len(self._private_message_history[user_id_str]) > self._max_records + ): self._private_message_history[user_id_str] = ( self._private_message_history[user_id_str][-self._max_records :] ) diff --git a/src/Undefined/utils/recent_messages.py b/src/Undefined/utils/recent_messages.py index 84953ae..10adca8 100644 --- a/src/Undefined/utils/recent_messages.py +++ b/src/Undefined/utils/recent_messages.py @@ -247,9 +247,14 @@ async def get_recent_messages_prefer_local( bot_qq: int, attachment_registry: Any | None = None, group_name_hint: str | None = None, - max_onebot_count: int = 5000, + max_onebot_count: int | None = None, ) -> list[dict[str, Any]]: """优先从本地 history 获取最近消息,必要时回退到 OneBot。""" + if max_onebot_count is None: + from Undefined.config import get_config + + cfg = get_config(strict=False) + max_onebot_count = getattr(cfg, "history_onebot_fetch_limit", 10000) norm_start, norm_end = _normalize_range(start, end) if norm_end <= 0: return [] @@ -311,7 +316,7 @@ async def get_recent_messages_prefer_onebot( bot_qq: int, attachment_registry: Any | None = None, group_name_hint: str | None = None, - max_onebot_count: int = 5000, + max_onebot_count: int | None = None, ) -> list[dict[str, Any]]: """兼容旧名称,当前行为等同于本地优先。""" return await get_recent_messages_prefer_local( diff --git a/tests/test_fetch_messages_tool.py b/tests/test_fetch_messages_tool.py index 337b006..1804c72 100644 --- a/tests/test_fetch_messages_tool.py +++ b/tests/test_fetch_messages_tool.py @@ -404,8 +404,8 @@ async def test_fetch_messages_no_history_manager() -> None: @pytest.mark.asyncio -async def test_fetch_messages_count_capped_at_500() -> None: - """Count is capped at 500.""" +async def test_fetch_messages_count_capped_at_config_limit() -> None: + """Count is capped at the configured summary fetch limit (fallback 1000).""" history_manager = MagicMock() history_manager.get_recent.return_value = [] @@ -416,7 +416,7 @@ async def test_fetch_messages_count_capped_at_500() -> None: 
await fetch_messages_execute({"count": 9999}, context) - history_manager.get_recent.assert_called_once_with("123456", "group", 0, 500) + history_manager.get_recent.assert_called_once_with("123456", "group", 0, 1000) @pytest.mark.asyncio @@ -453,7 +453,7 @@ async def test_fetch_messages_invalid_count_defaults() -> None: @pytest.mark.asyncio async def test_fetch_messages_time_range_fetch_larger_batch() -> None: - """Time range mode fetches larger batch (max(count*2, 2000)).""" + """Time range mode fetches larger batch (max(count*2, fallback_time_limit)).""" history_manager = MagicMock() history_manager.get_recent.return_value = [] @@ -467,5 +467,48 @@ async def test_fetch_messages_time_range_fetch_larger_batch() -> None: context, ) - # max(50*2, 2000) = 2000 - history_manager.get_recent.assert_called_once_with("123456", "group", 0, 2000) + # max(50*2, 5000) = 5000 + history_manager.get_recent.assert_called_once_with("123456", "group", 0, 5000) + + +@pytest.mark.asyncio +async def test_fetch_messages_count_uses_runtime_config() -> None: + """When runtime_config provides history_summary_fetch_limit, use it.""" + history_manager = MagicMock() + history_manager.get_recent.return_value = [] + + cfg = MagicMock() + cfg.history_summary_fetch_limit = 300 + cfg.history_summary_time_fetch_limit = 800 + + context: dict[str, Any] = { + "history_manager": history_manager, + "group_id": 123456, + "runtime_config": cfg, + } + + await fetch_messages_execute({"count": 9999}, context) + + history_manager.get_recent.assert_called_once_with("123456", "group", 0, 300) + + +@pytest.mark.asyncio +async def test_fetch_messages_time_range_uses_runtime_config() -> None: + """When runtime_config provides history_summary_time_fetch_limit, use it.""" + history_manager = MagicMock() + history_manager.get_recent.return_value = [] + + cfg = MagicMock() + cfg.history_summary_fetch_limit = 300 + cfg.history_summary_time_fetch_limit = 800 + + context: dict[str, Any] = { + "history_manager": 
history_manager, + "group_id": 123456, + "runtime_config": cfg, + } + + await fetch_messages_execute({"count": 50, "time_range": "1d"}, context) + + # max(50*2, 800) = 800 + history_manager.get_recent.assert_called_once_with("123456", "group", 0, 800) diff --git a/tests/test_history_config.py b/tests/test_history_config.py new file mode 100644 index 0000000..c5b0d3e --- /dev/null +++ b/tests/test_history_config.py @@ -0,0 +1,76 @@ +"""Tests for configurable history limits in Config and tool handlers.""" + +from __future__ import annotations + +from typing import Any +from unittest.mock import MagicMock + + +# --------------------------------------------------------------------------- +# Helper: simulate _get_history_limit as used across multiple handlers +# --------------------------------------------------------------------------- +def _get_history_limit(context: dict[str, Any], key: str, fallback: int) -> int: + """Mirror of the helper used in tool handlers.""" + cfg = context.get("runtime_config") + if cfg is not None: + val = getattr(cfg, key, None) + if isinstance(val, int) and val > 0: + return val + return fallback + + +class TestGetHistoryLimit: + """Tests for the _get_history_limit helper pattern.""" + + def test_returns_fallback_when_no_config(self) -> None: + assert _get_history_limit({}, "history_filtered_result_limit", 200) == 200 + + def test_returns_fallback_when_config_is_none(self) -> None: + ctx: dict[str, Any] = {"runtime_config": None} + assert _get_history_limit(ctx, "history_filtered_result_limit", 200) == 200 + + def test_returns_config_value(self) -> None: + cfg = MagicMock() + cfg.history_filtered_result_limit = 500 + ctx: dict[str, Any] = {"runtime_config": cfg} + assert _get_history_limit(ctx, "history_filtered_result_limit", 200) == 500 + + def test_returns_fallback_when_config_attr_missing(self) -> None: + cfg = MagicMock(spec=[]) # no attributes + ctx: dict[str, Any] = {"runtime_config": cfg} + assert _get_history_limit(ctx, 
"nonexistent_field", 42) == 42 + + def test_returns_fallback_when_config_value_zero(self) -> None: + cfg = MagicMock() + cfg.history_filtered_result_limit = 0 + ctx: dict[str, Any] = {"runtime_config": cfg} + assert _get_history_limit(ctx, "history_filtered_result_limit", 200) == 200 + + def test_returns_fallback_when_config_value_negative(self) -> None: + cfg = MagicMock() + cfg.history_filtered_result_limit = -1 + ctx: dict[str, Any] = {"runtime_config": cfg} + assert _get_history_limit(ctx, "history_filtered_result_limit", 200) == 200 + + +class TestConfigHistoryFieldDefaults: + """Verify that Config dataclass has the expected history fields.""" + + def test_config_has_history_fields(self) -> None: + from Undefined.config.loader import Config + + fields = Config.__dataclass_fields__ + expected_fields = [ + "history_max_records", + "history_filtered_result_limit", + "history_search_scan_limit", + "history_summary_fetch_limit", + "history_summary_time_fetch_limit", + "history_onebot_fetch_limit", + "history_group_analysis_limit", + ] + for field_name in expected_fields: + assert field_name in fields, f"Missing field: {field_name}" + assert fields[field_name].type == "int", ( + f"{field_name} should be int, got {fields[field_name].type}" + ) From 6b274c0f28f513efe63727a4e6ccbd6b7803ad0b Mon Sep 17 00:00:00 2001 From: Null <1708213363@qq.com> Date: Sun, 19 Apr 2026 09:29:45 +0800 Subject: [PATCH 21/57] =?UTF-8?q?feat(webui):=20=E9=95=BF=E6=9C=9F?= =?UTF-8?q?=E8=AE=B0=E5=BF=86=E5=AE=8C=E6=95=B4=20CRUD=20=E7=AE=A1?= =?UTF-8?q?=E7=90=86?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Runtime API: 新增 POST/PATCH/DELETE /api/v1/memory 端点 - WebUI proxy: 新增记忆增删改代理路由 - 前端: 内联创建表单、行内编辑(Ctrl+Enter保存/Esc取消)、确认删除 - 修复 api.js Content-Type 仅对 POST 自动设置的问题,扩展至 PATCH/PUT/DELETE - 修复 CORS Allow-Methods 缺少 PATCH/DELETE - 新增 CSS 样式支持编辑/删除按钮与内联编辑区域 Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> Co-authored-by: Claude 
Opus 4.6 --- src/Undefined/api/app.py | 73 ++++++++- src/Undefined/webui/routes/_runtime.py | 45 +++++ src/Undefined/webui/static/css/components.css | 55 +++++++ src/Undefined/webui/static/js/api.js | 8 +- src/Undefined/webui/static/js/runtime.js | 154 +++++++++++++++++- src/Undefined/webui/templates/index.html | 12 +- 6 files changed, 341 insertions(+), 6 deletions(-) diff --git a/src/Undefined/api/app.py b/src/Undefined/api/app.py index 7c58116..cbe5824 100644 --- a/src/Undefined/api/app.py +++ b/src/Undefined/api/app.py @@ -186,7 +186,9 @@ def _apply_cors_headers(request: web.Request, response: web.StreamResponse) -> N origin = normalize_origin(str(request.headers.get("Origin") or "")) settings = load_webui_settings() response.headers.setdefault("Vary", "Origin") - response.headers.setdefault("Access-Control-Allow-Methods", "GET,POST,OPTIONS") + response.headers.setdefault( + "Access-Control-Allow-Methods", "GET,POST,PATCH,DELETE,OPTIONS" + ) response.headers.setdefault( "Access-Control-Allow-Headers", "Authorization, Content-Type, X-Undefined-API-Key", @@ -544,7 +546,14 @@ def _build_openapi_spec(ctx: RuntimeAPIContext, request: web.Request) -> dict[st ), } }, - "/api/v1/memory": {"get": {"summary": "List/search manual memories"}}, + "/api/v1/memory": { + "get": {"summary": "List/search manual memories"}, + "post": {"summary": "Create a manual memory"}, + }, + "/api/v1/memory/{uuid}": { + "patch": {"summary": "Update a manual memory by UUID"}, + "delete": {"summary": "Delete a manual memory by UUID"}, + }, "/api/v1/memes": {"get": {"summary": "List/search meme library items"}}, "/api/v1/memes/stats": {"get": {"summary": "Get meme library stats"}}, "/api/v1/memes/{uid}": { @@ -749,6 +758,9 @@ async def _auth_middleware( web.get("/api/v1/probes/internal", self._internal_probe_handler), web.get("/api/v1/probes/external", self._external_probe_handler), web.get("/api/v1/memory", self._memory_handler), + web.post("/api/v1/memory", self._memory_create_handler), + 
web.patch("/api/v1/memory/{uuid}", self._memory_update_handler), + web.delete("/api/v1/memory/{uuid}", self._memory_delete_handler), web.get("/api/v1/memes", self._meme_list_handler), web.get("/api/v1/memes/stats", self._meme_stats_handler), web.get("/api/v1/memes/{uid}", self._meme_detail_handler), @@ -1187,6 +1199,63 @@ def _created_sort_key(item: dict[str, Any]) -> float: } ) + async def _memory_create_handler(self, request: web.Request) -> Response: + memory_storage = getattr(self._ctx.ai, "memory_storage", None) + if memory_storage is None: + return _json_error("Memory storage not ready", status=503) + try: + body = await request.json() + except Exception: + return _json_error("Invalid JSON", status=400) + fact = str(body.get("fact", "") or "").strip() + if not fact: + return _json_error("fact must not be empty", status=400) + new_uuid = await memory_storage.add(fact) + if new_uuid is None: + return _json_error("Failed to create memory", status=500) + # add() returns existing UUID on duplicate + existing = [m for m in memory_storage.get_all() if m.uuid == new_uuid] + item = existing[0] if existing else None + return web.json_response( + { + "uuid": new_uuid, + "fact": item.fact if item else fact, + "created_at": item.created_at if item else "", + }, + status=201, + ) + + async def _memory_update_handler(self, request: web.Request) -> Response: + memory_storage = getattr(self._ctx.ai, "memory_storage", None) + if memory_storage is None: + return _json_error("Memory storage not ready", status=503) + target_uuid = str(request.match_info.get("uuid", "")).strip() + if not target_uuid: + return _json_error("uuid is required", status=400) + try: + body = await request.json() + except Exception: + return _json_error("Invalid JSON", status=400) + fact = str(body.get("fact", "") or "").strip() + if not fact: + return _json_error("fact must not be empty", status=400) + ok = await memory_storage.update(target_uuid, fact) + if not ok: + return _json_error(f"Memory 
{target_uuid} not found", status=404) + return web.json_response({"uuid": target_uuid, "fact": fact, "updated": True}) + + async def _memory_delete_handler(self, request: web.Request) -> Response: + memory_storage = getattr(self._ctx.ai, "memory_storage", None) + if memory_storage is None: + return _json_error("Memory storage not ready", status=503) + target_uuid = str(request.match_info.get("uuid", "")).strip() + if not target_uuid: + return _json_error("uuid is required", status=400) + ok = await memory_storage.delete(target_uuid) + if not ok: + return _json_error(f"Memory {target_uuid} not found", status=404) + return web.json_response({"uuid": target_uuid, "deleted": True}) + async def _meme_list_handler(self, request: web.Request) -> Response: meme_service = self._ctx.meme_service if meme_service is None or not meme_service.enabled: diff --git a/src/Undefined/webui/routes/_runtime.py b/src/Undefined/webui/routes/_runtime.py index c6de3eb..371c46b 100644 --- a/src/Undefined/webui/routes/_runtime.py +++ b/src/Undefined/webui/routes/_runtime.py @@ -296,6 +296,51 @@ async def runtime_memory_handler(request: web.Request) -> Response: ) +@routes.post("/api/v1/management/runtime/memory") +@routes.post("/api/runtime/memory") +async def runtime_memory_create_handler(request: web.Request) -> Response: + if not check_auth(request): + return _unauthorized() + try: + payload = await request.json() + except (json.JSONDecodeError, UnicodeDecodeError, ValueError): + return web.json_response({"error": "Invalid JSON payload"}, status=400) + return await _proxy_runtime( + method="POST", + path="/api/v1/memory", + payload=payload, + ) + + +@routes.patch("/api/v1/management/runtime/memory/{uuid}") +@routes.patch("/api/runtime/memory/{uuid}") +async def runtime_memory_update_handler(request: web.Request) -> Response: + if not check_auth(request): + return _unauthorized() + target_uuid = _url_quote(str(request.match_info.get("uuid", "")).strip(), safe="") + try: + payload = await 
request.json() + except (json.JSONDecodeError, UnicodeDecodeError, ValueError): + return web.json_response({"error": "Invalid JSON payload"}, status=400) + return await _proxy_runtime( + method="PATCH", + path=f"/api/v1/memory/{target_uuid}", + payload=payload, + ) + + +@routes.delete("/api/v1/management/runtime/memory/{uuid}") +@routes.delete("/api/runtime/memory/{uuid}") +async def runtime_memory_delete_handler(request: web.Request) -> Response: + if not check_auth(request): + return _unauthorized() + target_uuid = _url_quote(str(request.match_info.get("uuid", "")).strip(), safe="") + return await _proxy_runtime( + method="DELETE", + path=f"/api/v1/memory/{target_uuid}", + ) + + @routes.get("/api/v1/management/runtime/cognitive/events") @routes.get("/api/runtime/cognitive/events") async def runtime_cognitive_events_handler(request: web.Request) -> Response: diff --git a/src/Undefined/webui/static/css/components.css b/src/Undefined/webui/static/css/components.css index 7118ef6..4d4c7f4 100644 --- a/src/Undefined/webui/static/css/components.css +++ b/src/Undefined/webui/static/css/components.css @@ -428,6 +428,61 @@ white-space: pre-wrap; word-break: break-word; } +/* Memory CRUD */ +.memory-create-form { + display: flex; + gap: 8px; + align-items: flex-start; + padding: 8px 0; +} +.memory-create-form textarea { + flex: 1; + resize: vertical; + min-height: 40px; +} +.memory-create-actions { + flex-shrink: 0; + align-self: flex-end; +} +.memory-item-actions { + display: flex; + gap: 4px; + flex-shrink: 0; + margin-left: 8px; +} +.memory-item-actions button { + background: none; + border: 1px solid var(--border-color); + border-radius: var(--radius-sm); + cursor: pointer; + padding: 2px 6px; + font-size: 12px; + color: var(--text-tertiary); + transition: color 0.15s, border-color 0.15s; +} +.memory-item-actions button:hover { + color: var(--accent-color); + border-color: var(--accent-color); +} +.memory-item-actions button.memory-btn-delete:hover { + color: 
var(--error); + border-color: var(--error); +} +.memory-edit-area { + width: 100%; + min-height: 60px; + font-size: 14px; + line-height: 1.6; + font-family: inherit; + resize: vertical; + margin-top: 4px; +} +.memory-edit-actions { + display: flex; + gap: 6px; + margin-top: 6px; + justify-content: flex-end; +} .runtime-doc { font-size: 13px; line-height: 1.6; diff --git a/src/Undefined/webui/static/js/api.js b/src/Undefined/webui/static/js/api.js index 8ab88a3..82b427c 100644 --- a/src/Undefined/webui/static/js/api.js +++ b/src/Undefined/webui/static/js/api.js @@ -31,7 +31,13 @@ function shouldRetryCandidate(res) { async function requestOnce(path, options = {}) { const headers = { ...(options.headers || {}) }; - if (options.method === "POST" && options.body && !headers["Content-Type"]) { + const needsJson = + options.body && + !headers["Content-Type"] && + ["POST", "PATCH", "PUT", "DELETE"].includes( + String(options.method || "").toUpperCase(), + ); + if (needsJson) { headers["Content-Type"] = "application/json"; } if (state.authAccessToken && !headers.Authorization) { diff --git a/src/Undefined/webui/static/js/runtime.js b/src/Undefined/webui/static/js/runtime.js index 6968c0f..42655a5 100644 --- a/src/Undefined/webui/static/js/runtime.js +++ b/src/Undefined/webui/static/js/runtime.js @@ -636,6 +636,8 @@ if (buffer.trim()) emitBlock(buffer); } + let _memoryMutating = false; + function renderMemoryItems(payload) { const container = get("runtimeMemoryList"); const meta = get("runtimeMemoryMeta"); @@ -666,9 +668,155 @@ const uuid = escapeHtml(item.uuid || ""); const fact = escapeHtml(item.fact || ""); const created = escapeHtml(item.created_at || ""); - return ` `; + return `${uuid}${created}${fact}`; }) .join(""); + + container.querySelectorAll(".memory-btn-edit").forEach((btn) => { + btn.addEventListener("click", () => + startEditMemory(btn.dataset.uuid), + ); + }); + container.querySelectorAll(".memory-btn-delete").forEach((btn) => { + 
btn.addEventListener("click", () => deleteMemory(btn.dataset.uuid)); + }); + } + + function startEditMemory(uuid) { + const container = get("runtimeMemoryList"); + if (!container) return; + const itemEl = container.querySelector( + `.runtime-list-item[data-uuid="${CSS.escape(uuid)}"]`, + ); + if (!itemEl) return; + const factEl = itemEl.querySelector(".runtime-list-fact"); + if (!factEl || factEl.dataset.editing === "true") return; + + const currentText = factEl.textContent || ""; + factEl.dataset.editing = "true"; + factEl.innerHTML = ""; + + const textarea = document.createElement("textarea"); + textarea.className = "form-control memory-edit-area"; + textarea.value = currentText; + + const actions = document.createElement("div"); + actions.className = "memory-edit-actions"; + const saveBtn = document.createElement("button"); + saveBtn.className = "btn btn-sm"; + saveBtn.textContent = "保存"; + const cancelBtn = document.createElement("button"); + cancelBtn.className = "btn btn-sm"; + cancelBtn.textContent = "取消"; + actions.append(saveBtn, cancelBtn); + factEl.append(textarea, actions); + textarea.focus(); + + cancelBtn.addEventListener("click", () => { + delete factEl.dataset.editing; + factEl.innerHTML = ""; + factEl.textContent = currentText; + }); + + saveBtn.addEventListener("click", () => + updateMemory(uuid, textarea.value), + ); + + textarea.addEventListener("keydown", (e) => { + if (e.key === "Escape") { + e.preventDefault(); + cancelBtn.click(); + } + if (e.key === "Enter" && e.ctrlKey) { + e.preventDefault(); + saveBtn.click(); + } + }); + } + + async function createMemory() { + if (_memoryMutating) return; + const input = get("memoryCreateInput"); + if (!input) return; + const fact = String(input.value || "").trim(); + if (!fact) { + showToast("记忆内容不能为空", "warning"); + return; + } + _memoryMutating = true; + const btn = get("btnMemoryCreate"); + if (btn) btn.disabled = true; + try { + const res = await api("/api/runtime/memory", { + method: "POST", + 
body: JSON.stringify({ fact }), + }); + const data = await parseJsonSafe(res); + if (!res.ok || (data && data.error)) { + throw new Error(buildRequestError(res, data)); + } + showToast("记忆已添加", "success"); + input.value = ""; + await searchMemory(); + } catch (err) { + showToast(`添加失败: ${err.message || err}`, "error"); + } finally { + _memoryMutating = false; + if (btn) btn.disabled = false; + } + } + + async function updateMemory(uuid, newFact) { + const fact = String(newFact || "").trim(); + if (!fact) { + showToast("记忆内容不能为空", "warning"); + return; + } + if (_memoryMutating) return; + _memoryMutating = true; + try { + const res = await api( + `/api/runtime/memory/${encodeURIComponent(uuid)}`, + { + method: "PATCH", + body: JSON.stringify({ fact }), + }, + ); + const data = await parseJsonSafe(res); + if (!res.ok || (data && data.error)) { + throw new Error(buildRequestError(res, data)); + } + showToast("记忆已更新", "success"); + await searchMemory(); + } catch (err) { + showToast(`更新失败: ${err.message || err}`, "error"); + } finally { + _memoryMutating = false; + } + } + + async function deleteMemory(uuid) { + if (_memoryMutating) return; + if (!confirm(`确认删除记忆 ${uuid.slice(0, 8)}…?`)) return; + _memoryMutating = true; + try { + const res = await api( + `/api/runtime/memory/${encodeURIComponent(uuid)}`, + { + method: "DELETE", + }, + ); + const data = await parseJsonSafe(res); + if (!res.ok || (data && data.error)) { + throw new Error(buildRequestError(res, data)); + } + showToast("记忆已删除", "success"); + await searchMemory(); + } catch (err) { + showToast(`删除失败: ${err.message || err}`, "error"); + } finally { + _memoryMutating = false; + } } function setListMessage(metaId, listId, message) { @@ -1204,6 +1352,10 @@ if (memoryRefresh) memoryRefresh.addEventListener("click", refreshMemory); + const memoryCreateBtn = get("btnMemoryCreate"); + if (memoryCreateBtn) + memoryCreateBtn.addEventListener("click", createMemory); + const runMemorySearch = () => 
runQueryAction("memory", "btnRuntimeMemorySearch", searchMemory); const runEventsSearch = () => diff --git a/src/Undefined/webui/templates/index.html b/src/Undefined/webui/templates/index.html index 3f4c8a9..3c9886f 100644 --- a/src/Undefined/webui/templates/index.html +++ b/src/Undefined/webui/templates/index.html @@ -446,7 +446,7 @@${uuid}${created}${fact}探针
+记忆检索
-只读检索记忆、认知事件与侧写。
+管理长期记忆,检索认知事件与侧写。
++资源趋势+ ++ CPU + Memory ++@@ -372,6 +374,12 @@运行环境@@ -347,6 +356,7 @@配置修改
更多操作 + @@ -403,6 +416,10 @@系统日志
+ + From 0e034fc05a7f5485217db08d6cbec76cfcfa8483 Mon Sep 17 00:00:00 2001 From: Null <1708213363@qq.com> Date: Sun, 19 Apr 2026 11:44:49 +0800 Subject: [PATCH 27/57] feat(webui): add Cmd/Ctrl+K command palette - CSS: overlay, palette card, input, list items, keyboard hint styles - HTML: modal overlay with input and list container - i18n: zh/en strings for all palette commands - JS: command list (tab nav, refresh, logout), open/close, keyboard navigation (arrows + enter), fuzzy filtering, Ctrl/Cmd+K toggle, Escape to close, click-outside dismiss Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- src/Undefined/webui/static/css/components.css | 11 + src/Undefined/webui/static/js/i18n.js | 34 +++ src/Undefined/webui/static/js/main.js | 247 ++++++++++++++++++ src/Undefined/webui/templates/index.html | 16 ++ 4 files changed, 308 insertions(+) diff --git a/src/Undefined/webui/static/css/components.css b/src/Undefined/webui/static/css/components.css index 08044e1..bd91f6d 100644 --- a/src/Undefined/webui/static/css/components.css +++ b/src/Undefined/webui/static/css/components.css @@ -796,3 +796,14 @@ .skeleton-text.medium { width: 80%; } .skeleton-block { height: 48px; margin-bottom: 12px; } .skeleton-bar { height: 8px; border-radius: 999px; } + +/* Command Palette */ +.cmd-palette-overlay { position: fixed; inset: 0; background: rgba(0,0,0,0.5); z-index: 9998; display: flex; align-items: flex-start; justify-content: center; padding-top: 20vh; } +.cmd-palette { width: 480px; max-width: 90vw; background: var(--bg-card); border: 1px solid var(--border-color); border-radius: var(--radius-lg); box-shadow: var(--shadow-lg); overflow: hidden; } +.cmd-palette-input { width: 100%; padding: 14px 18px; border: none; border-bottom: 1px solid var(--border-color); background: transparent; color: var(--text-primary); font-size: 15px; outline: none; box-sizing: border-box; } +.cmd-palette-input::placeholder { color: var(--text-tertiary); } +.cmd-palette-list { 
max-height: 320px; overflow-y: auto; padding: 6px; } +.cmd-palette-item { padding: 10px 14px; border-radius: var(--radius-sm); cursor: pointer; display: flex; align-items: center; gap: 10px; color: var(--text-secondary); font-size: 14px; } +.cmd-palette-item:hover, .cmd-palette-item.active { background: var(--accent-subtle); color: var(--text-primary); } +.cmd-palette-item .cmd-key { font-size: 11px; color: var(--text-tertiary); margin-left: auto; font-family: var(--font-mono); } +.cmd-palette-empty { padding: 20px; text-align: center; color: var(--text-tertiary); font-size: 13px; } diff --git a/src/Undefined/webui/static/js/i18n.js b/src/Undefined/webui/static/js/i18n.js index 1de6b78..1d60d0b 100644 --- a/src/Undefined/webui/static/js/i18n.js +++ b/src/Undefined/webui/static/js/i18n.js @@ -108,6 +108,12 @@ const I18N = { "config.reload_error": "配置重载失败。", "config.bootstrap_created": "检测到缺少 config.toml,已从示例生成;请在此页完善配置并保存。", + "config.history": "版本历史", + "config.history_empty": "暂无备份", + "config.history_restore": "恢复", + "config.history_restore_confirm": + "确定恢复到此版本?当前配置将自动备份。", + "config.history_restored": "已恢复配置", "logs.title": "运行日志", "logs.subtitle": "实时查看日志尾部输出。", "logs.auto": "自动刷新", @@ -276,6 +282,17 @@ const I18N = { "update.not_eligible": "未满足更新条件(仅支持官方 origin/main)", "update.failed": "更新失败", "update.no_restart": "更新已完成但未重启(请检查 uv sync 输出)", + "cmd.placeholder": "输入命令...", + "cmd.empty": "没有匹配的命令", + "cmd.tab_overview": "跳转到 概览", + "cmd.tab_config": "跳转到 配置", + "cmd.tab_logs": "跳转到 日志", + "cmd.tab_probes": "跳转到 探针", + "cmd.tab_memory": "跳转到 备忘录", + "cmd.tab_memes": "跳转到 表情包", + "cmd.tab_cognitive": "跳转到 认知记忆", + "cmd.refresh": "刷新当前页面", + "cmd.logout": "退出登录", }, en: { "landing.title": "Undefined Console", @@ -392,6 +409,12 @@ const I18N = { "config.reload_error": "Failed to reload configuration.", "config.bootstrap_created": "config.toml was missing and has been generated from the example. 
Please review and save your configuration.", + "config.history": "Version History", + "config.history_empty": "No backups yet", + "config.history_restore": "Restore", + "config.history_restore_confirm": + "Restore to this version? Current config will be backed up automatically.", + "config.history_restored": "Configuration restored", "logs.title": "System Logs", "logs.subtitle": "Real-time view of recent log output.", "logs.auto": "Auto Refresh", @@ -564,5 +587,16 @@ const I18N = { "update.no_restart": "Updated but not restarted (check uv sync output)", "config.aot_add": "+ Add Entry", "config.aot_remove": "Remove", + "cmd.placeholder": "Type a command...", + "cmd.empty": "No matching commands", + "cmd.tab_overview": "Go to Overview", + "cmd.tab_config": "Go to Config", + "cmd.tab_logs": "Go to Logs", + "cmd.tab_probes": "Go to Probes", + "cmd.tab_memory": "Go to Memory", + "cmd.tab_memes": "Go to Memes", + "cmd.tab_cognitive": "Go to Cognitive", + "cmd.refresh": "Refresh current page", + "cmd.logout": "Logout", }, }; diff --git a/src/Undefined/webui/static/js/main.js b/src/Undefined/webui/static/js/main.js index 949ae83..2be7aee 100644 --- a/src/Undefined/webui/static/js/main.js +++ b/src/Undefined/webui/static/js/main.js @@ -176,6 +176,142 @@ function setMobileInlineActionsOpen(key, open) { syncMobileChrome(); } +// Command Palette +const _cmdCommands = [ + { + id: "overview", + label: () => t("cmd.tab_overview"), + action: () => switchTab("overview"), + keys: "1", + }, + { + id: "config", + label: () => t("cmd.tab_config"), + action: () => switchTab("config"), + keys: "2", + }, + { + id: "logs", + label: () => t("cmd.tab_logs"), + action: () => switchTab("logs"), + keys: "3", + }, + { + id: "probes", + label: () => t("cmd.tab_probes"), + action: () => switchTab("probes"), + keys: "4", + }, + { + id: "memory", + label: () => t("cmd.tab_memory"), + action: () => switchTab("memory"), + keys: "5", + }, + { + id: "memes", + label: () => t("cmd.tab_memes"), + action: 
() => switchTab("memes"), + keys: "6", + }, + { + id: "cognitive", + label: () => t("cmd.tab_cognitive"), + action: () => switchTab("cognitive"), + keys: "7", + }, + { + id: "refresh", + label: () => t("cmd.refresh"), + action: () => location.reload(), + keys: "R", + }, + { + id: "logout", + label: () => t("cmd.logout"), + action: () => { + get("btnLogout")?.click(); + }, + keys: "", + }, +]; + +let _cmdActiveIndex = 0; + +function openCmdPalette() { + const overlay = get("cmdPaletteOverlay"); + const input = get("cmdPaletteInput"); + if (!overlay || !input) return; + input.placeholder = t("cmd.placeholder"); + overlay.style.display = "flex"; + input.value = ""; + _cmdActiveIndex = 0; + _renderCmdList(""); + input.focus(); +} + +function closeCmdPalette() { + const overlay = get("cmdPaletteOverlay"); + if (overlay) overlay.style.display = "none"; +} + +function _renderCmdList(query) { + const list = get("cmdPaletteList"); + if (!list) return; + const q = query.trim().toLowerCase(); + const filtered = q + ? _cmdCommands.filter( + (c) => c.label().toLowerCase().includes(q) || c.id.includes(q), + ) + : _cmdCommands; + if (filtered.length === 0) { + list.innerHTML = `${escapeHtml(t("cmd.empty"))}`; + return; + } + _cmdActiveIndex = Math.min(_cmdActiveIndex, filtered.length - 1); + list.innerHTML = filtered + .map( + (c, i) => + `${escapeHtml(c.label())}${c.keys ? `${escapeHtml(c.keys)}` : ""}`, + ) + .join(""); + list.querySelectorAll(".cmd-palette-item").forEach((el, i) => { + el.addEventListener("click", () => { + closeCmdPalette(); + filtered[i].action(); + }); + el.addEventListener("mouseenter", () => { + _cmdActiveIndex = i; + _renderCmdList(query); + }); + }); +} + +function _handleCmdKey(e, query) { + const list = get("cmdPaletteList"); + if (!list) return; + const q = query.trim().toLowerCase(); + const filtered = q + ? 
_cmdCommands.filter( + (c) => c.label().toLowerCase().includes(q) || c.id.includes(q), + ) + : _cmdCommands; + if (e.key === "ArrowDown") { + e.preventDefault(); + _cmdActiveIndex = (_cmdActiveIndex + 1) % filtered.length; + _renderCmdList(query); + } else if (e.key === "ArrowUp") { + e.preventDefault(); + _cmdActiveIndex = + (_cmdActiveIndex - 1 + filtered.length) % filtered.length; + _renderCmdList(query); + } else if (e.key === "Enter" && filtered.length > 0) { + e.preventDefault(); + closeCmdPalette(); + filtered[_cmdActiveIndex].action(); + } +} + async function init() { // Global error handlers window.onerror = function (message, source, lineno, colno, error) { @@ -558,9 +694,12 @@ async function init() { get("btnToggleToml").onclick = async function () { const formGrid = get("formSections"); const tomlViewer = get("tomlViewer"); + const historyPanel = get("configHistoryPanel"); const btn = get("btnToggleToml"); if (!formGrid || !tomlViewer || !btn) return; + if (historyPanel) historyPanel.style.display = "none"; + const isShowingToml = tomlViewer.style.display !== "none"; if (isShowingToml) { tomlViewer.style.display = "none"; @@ -581,6 +720,80 @@ async function init() { } }; + get("btnConfigHistory")?.addEventListener("click", async () => { + const panel = get("configHistoryPanel"); + const formGrid = get("formSections"); + const tomlViewer = get("tomlViewer"); + if (!panel) return; + + const isShowing = panel.style.display !== "none"; + if (isShowing) { + panel.style.display = "none"; + if (formGrid) formGrid.style.display = ""; + return; + } + + try { + const res = await api("/api/config/history"); + const data = await res.json(); + const backups = data.backups || []; + const list = get("configHistoryList"); + if (!list) return; + + if (backups.length === 0) { + list.innerHTML = `${escapeHtml(t("config.history_empty"))}
`; + } else { + list.innerHTML = backups + .map((b) => { + const date = new Date(b.mtime * 1000); + const sizeKB = (b.size / 1024).toFixed(1); + return `+`; + }) + .join(""); + + list.querySelectorAll("[data-restore-name]").forEach((btn) => { + btn.addEventListener("click", async () => { + const name = btn.getAttribute("data-restore-name"); + if (!confirm(t("config.history_restore_confirm"))) + return; + try { + await api("/api/config/history/restore", { + method: "POST", + headers: { + "Content-Type": "application/json", + }, + body: JSON.stringify({ name }), + }); + showToast(t("config.history_restored"), "success"); + panel.style.display = "none"; + if (formGrid) formGrid.style.display = ""; + loadConfig(); + } catch (e) { + showToast( + `${t("common.error")}: ${e.message}`, + "error", + ); + } + }); + }); + } + + if (formGrid) formGrid.style.display = "none"; + if (tomlViewer) tomlViewer.style.display = "none"; + panel.style.display = "block"; + } catch (e) { + showToast(`${t("common.error")}: ${e.message}`, "error"); + } + }); + const logout = async () => { try { await api(authEndpointCandidates("logout"), { method: "POST" }); @@ -656,6 +869,40 @@ async function init() { switchTab("config"); } + // Command Palette keybinding + document.addEventListener("keydown", (e) => { + if ((e.metaKey || e.ctrlKey) && e.key === "k") { + e.preventDefault(); + const overlay = get("cmdPaletteOverlay"); + if (overlay && overlay.style.display !== "none") { + closeCmdPalette(); + } else { + openCmdPalette(); + } + } + if (e.key === "Escape") { + closeCmdPalette(); + } + }); + + const cmdInput = get("cmdPaletteInput"); + if (cmdInput) { + cmdInput.addEventListener("input", () => { + _cmdActiveIndex = 0; + _renderCmdList(cmdInput.value); + }); + cmdInput.addEventListener("keydown", (e) => { + _handleCmdKey(e, cmdInput.value); + }); + } + + const cmdOverlay = get("cmdPaletteOverlay"); + if (cmdOverlay) { + cmdOverlay.addEventListener("click", (e) => { + if (e.target === cmdOverlay) 
closeCmdPalette(); + }); + } + document.addEventListener("visibilitychange", () => { if (document.hidden) { stopStatusTimer(); diff --git a/src/Undefined/webui/templates/index.html b/src/Undefined/webui/templates/index.html index a4953a5..214ee63 100644 --- a/src/Undefined/webui/templates/index.html +++ b/src/Undefined/webui/templates/index.html @@ -363,6 +363,8 @@+++ ${escapeHtml(b.name)} + ${sizeKB} KB · ${date.toLocaleString()} ++ +配置修改
data-i18n="config.collapse_all">全部折叠 +配置修改
+ @@ -808,6 +816,14 @@MIT License
+ + + From 759aeb65792ca517144dfee2af6b054dc26623ca Mon Sep 17 00:00:00 2001 From: Null <1708213363@qq.com> Date: Sun, 19 Apr 2026 11:46:23 +0800 Subject: [PATCH 28/57] feat(webui): add config version history and rollback backend MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Add _backup_config() helper that creates timestamped backups in data/config_backups/ with a 50-backup cap - Auto-backup before every config save (POST /api/config) and patch (POST /api/patch) - GET /api/config/history — list all backups (newest first) - POST /api/config/history/restore — restore a backup by name with TOML validation and auto-backup of current config Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> Co-authored-by: Claude--- src/Undefined/webui/routes/_config.py | 71 +++++++++++++++++++++++++++ 1 file changed, 71 insertions(+) diff --git a/src/Undefined/webui/routes/_config.py b/src/Undefined/webui/routes/_config.py index 0afeaa0..c6c65b3 100644 --- a/src/Undefined/webui/routes/_config.py +++ b/src/Undefined/webui/routes/_config.py @@ -1,5 +1,7 @@ import asyncio +import shutil import tomllib +from datetime import datetime, timezone from pathlib import Path from tempfile import NamedTemporaryFile @@ -23,6 +25,27 @@ sync_config_file, ) +_BACKUP_DIR = Path("data") / "config_backups" +_MAX_BACKUPS = 50 + + +def _backup_config() -> str | None: + """Create a timestamped backup of the current config.toml. + + Returns the backup filename, or *None* if the source file does not exist. 
+ """ + if not CONFIG_PATH.exists(): + return None + _BACKUP_DIR.mkdir(parents=True, exist_ok=True) + ts = datetime.now(timezone.utc).strftime("%Y%m%d_%H%M%S") + backup_name = f"config_{ts}.toml" + shutil.copy2(CONFIG_PATH, _BACKUP_DIR / backup_name) + # Trim old backups beyond the cap + backups = sorted(_BACKUP_DIR.glob("config_*.toml")) + while len(backups) > _MAX_BACKUPS: + backups.pop(0).unlink(missing_ok=True) + return backup_name + @routes.get("/api/v1/management/config") @routes.get("/api/config") @@ -45,6 +68,7 @@ async def save_config_handler(request: web.Request) -> Response: if not valid: return web.json_response({"success": False, "error": msg}, status=400) try: + _backup_config() CONFIG_PATH.write_text(content, encoding="utf-8") get_config_manager().reload() logic_valid, logic_msg = validate_required_config() @@ -133,6 +157,7 @@ async def config_patch_handler(request: web.Request) -> Response: data = {} patched = apply_patch(data, patch) + _backup_config() CONFIG_PATH.write_text( render_toml(patched, comments=load_comment_map()), encoding="utf-8" ) @@ -181,3 +206,49 @@ async def sync_config_template_handler(request: web.Request) -> Response: return web.json_response({"success": False, "error": str(exc)}, status=400) except Exception as exc: return web.json_response({"success": False, "error": str(exc)}, status=500) + + +# --------------------------------------------------------------------------- +# Config version history +# --------------------------------------------------------------------------- + + +@routes.get("/api/v1/management/config/history") +@routes.get("/api/config/history") +async def config_history_handler(request: web.Request) -> Response: + """Return the list of config backups, newest first.""" + if not check_auth(request): + return web.json_response({"error": "Unauthorized"}, status=401) + if not _BACKUP_DIR.exists(): + return web.json_response({"backups": []}) + backups: list[dict[str, object]] = [] + for f in 
sorted(_BACKUP_DIR.glob("config_*.toml"), reverse=True): + stat = f.stat() + backups.append({"name": f.name, "size": stat.st_size, "mtime": stat.st_mtime}) + return web.json_response({"backups": backups}) + + +@routes.post("/api/v1/management/config/history/restore") +@routes.post("/api/config/history/restore") +async def config_restore_handler(request: web.Request) -> Response: + """Restore a config backup by name (auto-backs-up current config first).""" + if not check_auth(request): + return web.json_response({"error": "Unauthorized"}, status=401) + try: + data = await request.json() + except Exception: + return web.json_response({"error": "Invalid JSON"}, status=400) + name = str(data.get("name", "")) + if not name or ".." in name or "/" in name: + return web.json_response({"error": "Invalid backup name"}, status=400) + backup_path = _BACKUP_DIR / name + if not backup_path.exists(): + return web.json_response({"error": "Backup not found"}, status=404) + content = backup_path.read_text(encoding="utf-8") + valid, msg = validate_toml(content) + if not valid: + return web.json_response({"error": f"Backup TOML invalid: {msg}"}, status=400) + _backup_config() + CONFIG_PATH.write_text(content, encoding="utf-8") + get_config_manager().reload() + return web.json_response({"success": True, "message": "Restored"}) From 8f478b93ee156fc0228a4ceb67507de0dd68d027 Mon Sep 17 00:00:00 2001 From: Null <1708213363@qq.com> Date: Sun, 19 Apr 2026 11:52:43 +0800 Subject: [PATCH 29/57] feat(webui): add modal focus trap and wire into command palette - Add stack-based trapFocus/releaseFocus to ui.js - Wire focus trap into openCmdPalette/closeCmdPalette - Tab/Shift+Tab cycles within modal boundaries Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- src/Undefined/webui/static/js/main.js | 6 ++-- src/Undefined/webui/static/js/ui.js | 40 +++++++++++++++++++++++++++ 2 files changed, 44 insertions(+), 2 deletions(-) diff --git a/src/Undefined/webui/static/js/main.js 
b/src/Undefined/webui/static/js/main.js index 2be7aee..5745ef3 100644 --- a/src/Undefined/webui/static/js/main.js +++ b/src/Undefined/webui/static/js/main.js @@ -247,12 +247,14 @@ function openCmdPalette() { input.value = ""; _cmdActiveIndex = 0; _renderCmdList(""); - input.focus(); + trapFocus(overlay); } function closeCmdPalette() { const overlay = get("cmdPaletteOverlay"); - if (overlay) overlay.style.display = "none"; + if (!overlay) return; + releaseFocus(overlay); + overlay.style.display = "none"; } function _renderCmdList(query) { diff --git a/src/Undefined/webui/static/js/ui.js b/src/Undefined/webui/static/js/ui.js index c42a4dd..f455ba2 100644 --- a/src/Undefined/webui/static/js/ui.js +++ b/src/Undefined/webui/static/js/ui.js @@ -186,6 +186,46 @@ function showToast(message, type = "info", duration = 3000) { }, duration); } +// Focus trap for modals +const _focusTrapStack = []; + +function trapFocus(container) { + if (!container) return; + const focusable = container.querySelectorAll( + 'a[href], button:not([disabled]), input:not([disabled]), select:not([disabled]), textarea:not([disabled]), [tabindex]:not([tabindex="-1"])', + ); + if (focusable.length === 0) return; + const first = focusable[0]; + const last = focusable[focusable.length - 1]; + + function handler(e) { + if (e.key !== "Tab") return; + if (e.shiftKey) { + if (document.activeElement === first) { + e.preventDefault(); + last.focus(); + } + } else { + if (document.activeElement === last) { + e.preventDefault(); + first.focus(); + } + } + } + + container.addEventListener("keydown", handler); + _focusTrapStack.push({ container, handler }); + first.focus(); +} + +function releaseFocus(container) { + const idx = _focusTrapStack.findIndex((t) => t.container === container); + if (idx === -1) return; + const entry = _focusTrapStack[idx]; + entry.container.removeEventListener("keydown", entry.handler); + _focusTrapStack.splice(idx, 1); +} + function setConfigState(mode) { const stateEl = 
get("configState"); const grid = get("formSections"); From a410707dcc6c49242e7b0d3c6a71d4c461ac1404 Mon Sep 17 00:00:00 2001 From: Null <1708213363@qq.com> Date: Sun, 19 Apr 2026 12:04:23 +0800 Subject: [PATCH 30/57] refactor(config): split loader.py into sub-modules Extract 6 sub-modules from the 3365-line monolith: - coercers.py: type coercion/normalization helpers (~150 lines) - resolvers.py: config value resolution (~106 lines) - admin.py: local admin management (~44 lines) - webui_settings.py: WebUI settings class (~61 lines) - model_parsers.py: all model config parsers (~1250 lines) - domain_parsers.py: domain config parsers + _update_dataclass (~284 lines) loader.py retains Config class, TOML loading, and all re-exports for backward compatibility. All 1429 tests pass, mypy strict clean. Co-authored-by: Claude Opus 4.6 Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- src/Undefined/config/admin.py | 44 + src/Undefined/config/coercers.py | 150 ++ src/Undefined/config/domain_parsers.py | 284 +++ src/Undefined/config/loader.py | 1957 +---------------- src/Undefined/config/model_parsers.py | 1250 +++++++++++ src/Undefined/config/resolvers.py | 106 + src/Undefined/config/webui_settings.py | 61 + .../test_config_cognitive_historian_limits.py | 6 +- 8 files changed, 1992 insertions(+), 1866 deletions(-) create mode 100644 src/Undefined/config/admin.py create mode 100644 src/Undefined/config/coercers.py create mode 100644 src/Undefined/config/domain_parsers.py create mode 100644 src/Undefined/config/model_parsers.py create mode 100644 src/Undefined/config/resolvers.py create mode 100644 src/Undefined/config/webui_settings.py diff --git a/src/Undefined/config/admin.py b/src/Undefined/config/admin.py new file mode 100644 index 0000000..16dcfe6 --- /dev/null +++ b/src/Undefined/config/admin.py @@ -0,0 +1,44 @@ +"""Local admin management (config.local.json).""" + +from __future__ import annotations + +import json +import logging +from pathlib 
import Path + +logger = logging.getLogger(__name__) + +LOCAL_CONFIG_PATH = Path("config.local.json") + + +def load_local_admins() -> list[int]: + """从本地配置文件加载动态管理员列表""" + if not LOCAL_CONFIG_PATH.exists(): + return [] + try: + with open(LOCAL_CONFIG_PATH, "r", encoding="utf-8") as f: + data = json.load(f) + admin_qqs: list[int] = data.get("admin_qqs", []) + return admin_qqs + except Exception as exc: + logger.warning("读取本地配置失败: %s", exc) + return [] + + +def save_local_admins(admin_qqs: list[int]) -> None: + """保存动态管理员列表到本地配置文件""" + try: + data: dict[str, list[int]] = {} + if LOCAL_CONFIG_PATH.exists(): + with open(LOCAL_CONFIG_PATH, "r", encoding="utf-8") as f: + data = json.load(f) + + data["admin_qqs"] = admin_qqs + + with open(LOCAL_CONFIG_PATH, "w", encoding="utf-8") as f: + json.dump(data, f, ensure_ascii=False, indent=2) + + logger.info("已保存管理员列表到 %s", LOCAL_CONFIG_PATH) + except Exception as exc: + logger.error("保存本地配置失败: %s", exc) + raise diff --git a/src/Undefined/config/coercers.py b/src/Undefined/config/coercers.py new file mode 100644 index 0000000..4edff34 --- /dev/null +++ b/src/Undefined/config/coercers.py @@ -0,0 +1,150 @@ +"""Type coercion and normalization helpers for config loading.""" + +from __future__ import annotations + +import logging +import os +from typing import Any, Optional + +from Undefined.utils.request_params import normalize_request_params + +logger = logging.getLogger(__name__) + +_ENV_WARNED_KEYS: set[str] = set() + + +def _warn_env_fallback(name: str) -> None: + if name in _ENV_WARNED_KEYS: + return + _ENV_WARNED_KEYS.add(name) + logger.warning("检测到环境变量 %s,建议迁移到 config.toml", name) + + +def _get_nested(data: dict[str, Any], path: tuple[str, ...]) -> Any: + node: Any = data + for key in path: + if not isinstance(node, dict) or key not in node: + return None + node = node[key] + return node + + +def _normalize_str(value: Any) -> Optional[str]: + if value is None: + return None + if isinstance(value, str): + stripped = 
value.strip() + return stripped if stripped else None + return str(value).strip() + + +def _coerce_int(value: Any, default: int) -> int: + if value is None: + return default + try: + return int(value) + except (TypeError, ValueError): + return default + + +def _coerce_float(value: Any, default: float) -> float: + if value is None: + return default + try: + return float(value) + except (TypeError, ValueError): + return default + + +def _normalize_queue_interval(value: float, default: float = 1.0) -> float: + """规范化队列发车间隔。 + + `0` 表示立即发车,负数回退到默认值。 + """ + + return default if value < 0 else value + + +def _coerce_bool(value: Any, default: bool) -> bool: + if isinstance(value, bool): + return value + if isinstance(value, (int, float)): + return bool(value) + if isinstance(value, str): + return value.strip().lower() in {"1", "true", "yes", "on"} + return default + + +def _coerce_str(value: Any, default: str) -> str: + normalized = _normalize_str(value) + return normalized if normalized is not None else default + + +def _normalize_base_url(value: str, default: str) -> str: + normalized = value.strip().rstrip("/") + if normalized: + return normalized + return default.rstrip("/") + + +def _coerce_int_list(value: Any) -> list[int]: + if value is None: + return [] + if isinstance(value, list): + items: list[int] = [] + for item in value: + try: + items.append(int(item)) + except (TypeError, ValueError): + continue + return items + if isinstance(value, str): + parts = [part.strip() for part in value.split(",") if part.strip()] + items = [] + for part in parts: + try: + items.append(int(part)) + except ValueError: + continue + return items + return [] + + +def _coerce_str_list(value: Any) -> list[str]: + if value is None: + return [] + if isinstance(value, list): + return [str(item).strip() for item in value if str(item).strip()] + if isinstance(value, str): + return [part.strip() for part in value.split(",") if part.strip()] + return [] + + +def _coerce_request_params(value: 
Any) -> dict[str, Any]: + return normalize_request_params(value) + + +def _get_model_request_params(data: dict[str, Any], model_name: str) -> dict[str, Any]: + return _coerce_request_params( + _get_nested(data, ("models", model_name, "request_params")) + ) + + +def _get_value( + data: dict[str, Any], + path: tuple[str, ...], + env_key: Optional[str], +) -> Any: + value = _get_nested(data, path) + if value is not None: + return value + if env_key: + env_value = os.getenv(env_key) + if env_value is not None and str(env_value).strip() != "": + _warn_env_fallback(env_key) + return env_value + return None + + +_VALID_API_MODES = {"chat_completions", "responses"} +_VALID_REASONING_EFFORT_STYLES = {"openai", "anthropic"} diff --git a/src/Undefined/config/domain_parsers.py b/src/Undefined/config/domain_parsers.py new file mode 100644 index 0000000..ebb2222 --- /dev/null +++ b/src/Undefined/config/domain_parsers.py @@ -0,0 +1,284 @@ +"""Domain configuration parsers (cognitive, memes, API, naga).""" + +from __future__ import annotations + +from dataclasses import fields +from typing import Any + +from .coercers import ( + _coerce_bool, + _coerce_float, + _coerce_int, + _coerce_str, + _coerce_int_list, + _coerce_str_list, +) +from .models import ( + APIConfig, + CognitiveConfig, + MemeConfig, + NagaConfig, +) + +DEFAULT_API_HOST = "127.0.0.1" +DEFAULT_API_PORT = 8788 +DEFAULT_API_AUTH_KEY = "changeme" + + +def _parse_cognitive_config(data: dict[str, Any]) -> CognitiveConfig: + cog = data.get("cognitive", {}) + vs = cog.get("vector_store", {}) if isinstance(cog, dict) else {} + q = cog.get("query", {}) if isinstance(cog, dict) else {} + hist = cog.get("historian", {}) if isinstance(cog, dict) else {} + prof = cog.get("profile", {}) if isinstance(cog, dict) else {} + que = cog.get("queue", {}) if isinstance(cog, dict) else {} + return CognitiveConfig( + enabled=_coerce_bool( + cog.get("enabled") if isinstance(cog, dict) else None, True + ), + bot_name=_coerce_str( + 
cog.get("bot_name") if isinstance(cog, dict) else None, + "Undefined", + ), + vector_store_path=_coerce_str( + vs.get("path") if isinstance(vs, dict) else None, + "data/cognitive/chromadb", + ), + queue_path=_coerce_str( + que.get("path") if isinstance(que, dict) else None, + "data/cognitive/queues", + ), + profiles_path=_coerce_str( + prof.get("path") if isinstance(prof, dict) else None, + "data/cognitive/profiles", + ), + auto_top_k=_coerce_int(q.get("auto_top_k") if isinstance(q, dict) else None, 3), + auto_scope_candidate_multiplier=_coerce_int( + q.get("auto_scope_candidate_multiplier") if isinstance(q, dict) else None, + 2, + ), + auto_current_group_boost=_coerce_float( + q.get("auto_current_group_boost") if isinstance(q, dict) else None, + 1.15, + ), + auto_current_private_boost=_coerce_float( + q.get("auto_current_private_boost") if isinstance(q, dict) else None, + 1.25, + ), + enable_rerank=_coerce_bool( + q.get("enable_rerank") if isinstance(q, dict) else None, True + ), + recent_end_summaries_inject_k=_coerce_int( + q.get("recent_end_summaries_inject_k") if isinstance(q, dict) else None, + 30, + ), + time_decay_enabled=_coerce_bool( + q.get("time_decay_enabled") if isinstance(q, dict) else None, True + ), + time_decay_half_life_days_auto=_coerce_float( + q.get("time_decay_half_life_days_auto") if isinstance(q, dict) else None, + 14.0, + ), + time_decay_half_life_days_tool=_coerce_float( + q.get("time_decay_half_life_days_tool") if isinstance(q, dict) else None, + 60.0, + ), + time_decay_boost=_coerce_float( + q.get("time_decay_boost") if isinstance(q, dict) else None, 0.2 + ), + time_decay_min_similarity=_coerce_float( + q.get("time_decay_min_similarity") if isinstance(q, dict) else None, + 0.35, + ), + tool_default_top_k=_coerce_int( + q.get("tool_default_top_k") if isinstance(q, dict) else None, 12 + ), + profile_top_k=_coerce_int( + q.get("profile_top_k") if isinstance(q, dict) else None, 8 + ), + rerank_candidate_multiplier=_coerce_int( + 
q.get("rerank_candidate_multiplier") if isinstance(q, dict) else None, 3 + ), + rewrite_max_retry=_coerce_int( + hist.get("rewrite_max_retry") if isinstance(hist, dict) else None, 2 + ), + historian_recent_messages_inject_k=_coerce_int( + hist.get("recent_messages_inject_k") if isinstance(hist, dict) else None, + 12, + ), + historian_recent_message_line_max_len=_coerce_int( + hist.get("recent_message_line_max_len") if isinstance(hist, dict) else None, + 240, + ), + historian_source_message_max_len=_coerce_int( + hist.get("source_message_max_len") if isinstance(hist, dict) else None, + 800, + ), + poll_interval_seconds=_coerce_float( + hist.get("poll_interval_seconds") if isinstance(hist, dict) else None, + 1.0, + ), + stale_job_timeout_seconds=_coerce_float( + hist.get("stale_job_timeout_seconds") if isinstance(hist, dict) else None, + 300.0, + ), + profile_revision_keep=_coerce_int( + prof.get("revision_keep") if isinstance(prof, dict) else None, 5 + ), + failed_max_age_days=_coerce_int( + que.get("failed_max_age_days") if isinstance(que, dict) else None, 30 + ), + failed_max_files=_coerce_int( + que.get("failed_max_files") if isinstance(que, dict) else None, 500 + ), + failed_cleanup_interval=_coerce_int( + que.get("failed_cleanup_interval") if isinstance(que, dict) else None, + 100, + ), + job_max_retries=_coerce_int( + que.get("job_max_retries") if isinstance(que, dict) else None, 3 + ), + ) + + +def _parse_memes_config(data: dict[str, Any]) -> MemeConfig: + section_raw = data.get("memes", {}) + section = section_raw if isinstance(section_raw, dict) else {} + return MemeConfig( + enabled=_coerce_bool(section.get("enabled"), True), + query_default_mode=_coerce_str(section.get("query_default_mode"), "hybrid"), + max_source_image_bytes=max( + 1, + _coerce_int(section.get("max_source_image_bytes"), 500 * 1024), + ), + blob_dir=_coerce_str(section.get("blob_dir"), "data/memes/blobs"), + preview_dir=_coerce_str(section.get("preview_dir"), "data/memes/previews"), + 
db_path=_coerce_str(section.get("db_path"), "data/memes/memes.sqlite3"), + vector_store_path=_coerce_str( + section.get("vector_store_path"), "data/memes/chromadb" + ), + queue_path=_coerce_str(section.get("queue_path"), "data/memes/queues"), + max_items=max(1, _coerce_int(section.get("max_items"), 10000)), + max_total_bytes=max( + 1, + _coerce_int(section.get("max_total_bytes"), 5 * 1024 * 1024 * 1024), + ), + allow_gif=_coerce_bool(section.get("allow_gif"), True), + auto_ingest_group=_coerce_bool(section.get("auto_ingest_group"), True), + auto_ingest_private=_coerce_bool(section.get("auto_ingest_private"), True), + keyword_top_k=max(1, _coerce_int(section.get("keyword_top_k"), 30)), + semantic_top_k=max(1, _coerce_int(section.get("semantic_top_k"), 30)), + rerank_top_k=max(1, _coerce_int(section.get("rerank_top_k"), 20)), + worker_max_concurrency=max( + 1, _coerce_int(section.get("worker_max_concurrency"), 4) + ), + ) + + +def _parse_api_config(data: dict[str, Any]) -> APIConfig: + section_raw = data.get("api", {}) + section = section_raw if isinstance(section_raw, dict) else {} + + enabled = _coerce_bool(section.get("enabled"), True) + host = _coerce_str(section.get("host"), DEFAULT_API_HOST) + port = _coerce_int(section.get("port"), DEFAULT_API_PORT) + if port <= 0 or port > 65535: + port = DEFAULT_API_PORT + + auth_key = _coerce_str(section.get("auth_key"), DEFAULT_API_AUTH_KEY) + if not auth_key: + auth_key = DEFAULT_API_AUTH_KEY + + openapi_enabled = _coerce_bool(section.get("openapi_enabled"), True) + + tool_invoke_enabled = _coerce_bool(section.get("tool_invoke_enabled"), False) + tool_invoke_expose = _coerce_str( + section.get("tool_invoke_expose"), "tools+toolsets" + ) + valid_expose = {"tools", "toolsets", "tools+toolsets", "agents", "all"} + if tool_invoke_expose not in valid_expose: + tool_invoke_expose = "tools+toolsets" + tool_invoke_allowlist = _coerce_str_list(section.get("tool_invoke_allowlist")) + tool_invoke_denylist = 
_coerce_str_list(section.get("tool_invoke_denylist")) + tool_invoke_timeout = _coerce_int(section.get("tool_invoke_timeout"), 120) + if tool_invoke_timeout <= 0: + tool_invoke_timeout = 120 + tool_invoke_callback_timeout = _coerce_int( + section.get("tool_invoke_callback_timeout"), 10 + ) + if tool_invoke_callback_timeout <= 0: + tool_invoke_callback_timeout = 10 + + return APIConfig( + enabled=enabled, + host=host, + port=port, + auth_key=auth_key, + openapi_enabled=openapi_enabled, + tool_invoke_enabled=tool_invoke_enabled, + tool_invoke_expose=tool_invoke_expose, + tool_invoke_allowlist=tool_invoke_allowlist, + tool_invoke_denylist=tool_invoke_denylist, + tool_invoke_timeout=tool_invoke_timeout, + tool_invoke_callback_timeout=tool_invoke_callback_timeout, + ) + + +def _parse_naga_config(data: dict[str, Any]) -> NagaConfig: + section_raw = data.get("naga", {}) + section = section_raw if isinstance(section_raw, dict) else {} + + enabled = _coerce_bool(section.get("enabled"), False) + api_url = _coerce_str(section.get("api_url"), "") + api_key = _coerce_str(section.get("api_key"), "") + moderation_enabled = _coerce_bool(section.get("moderation_enabled"), True) + allowed_groups = frozenset(_coerce_int_list(section.get("allowed_groups"))) + + return NagaConfig( + enabled=enabled, + api_url=api_url, + api_key=api_key, + moderation_enabled=moderation_enabled, + allowed_groups=allowed_groups, + ) + + +def _parse_easter_egg_call_mode(value: Any) -> str: + """解析彩蛋提示模式。 + + 兼容旧版布尔值: + - True => agent + - False => none + """ + if isinstance(value, bool): + return "agent" if value else "none" + if isinstance(value, (int, float)): + return "agent" if bool(value) else "none" + if value is None: + return "none" + + text = str(value).strip().lower() + if text in {"true", "1", "yes", "on"}: + return "agent" + if text in {"false", "0", "no", "off"}: + return "none" + if text in {"none", "agent", "tools", "all", "clean"}: + return text + return "none" + + +def _update_dataclass( + 
old_value: Any, new_value: Any, prefix: str +) -> dict[str, tuple[Any, Any]]: + changes: dict[str, tuple[Any, Any]] = {} + if not isinstance(old_value, type(new_value)): + changes[prefix] = (old_value, new_value) + return changes + for field in fields(old_value): + name = field.name + old_attr = getattr(old_value, name) + new_attr = getattr(new_value, name) + if old_attr != new_attr: + setattr(old_value, name, new_attr) + changes[f"{prefix}.{name}"] = (old_attr, new_attr) + return changes diff --git a/src/Undefined/config/loader.py b/src/Undefined/config/loader.py index ac65473..26f669c 100644 --- a/src/Undefined/config/loader.py +++ b/src/Undefined/config/loader.py @@ -2,7 +2,6 @@ from __future__ import annotations -import json import logging import os import re @@ -37,38 +36,90 @@ def load_dotenv( ImageGenConfig, ImageGenModelConfig, MemeConfig, - ModelPool, - ModelPoolEntry, NagaConfig, RerankModelConfig, SecurityModelConfig, VisionModelConfig, ) -from Undefined.utils.request_params import ( - merge_request_params, - normalize_request_params, +from .coercers import ( # noqa: F401 — re-exported for backward compat + _coerce_bool, + _coerce_float, + _coerce_int, + _coerce_int_list, + _coerce_str, + _coerce_str_list, + _get_model_request_params, + _get_value, + _normalize_base_url, + _normalize_queue_interval, + _normalize_str, + _warn_env_fallback, +) +from .resolvers import ( # noqa: F401 — re-exported for backward compat + _resolve_api_mode, + _resolve_reasoning_effort, + _resolve_reasoning_effort_style, + _resolve_responses_force_stateless_replay, + _resolve_responses_tool_choice_compat, + _resolve_thinking_compat_flags, +) +from .admin import ( # noqa: F401 — re-exported for backward compat + LOCAL_CONFIG_PATH, + load_local_admins, + save_local_admins, +) +from .webui_settings import ( # noqa: F401 — re-exported for backward compat + DEFAULT_WEBUI_PASSWORD, + DEFAULT_WEBUI_PORT, + DEFAULT_WEBUI_URL, + WebUISettings, + load_webui_settings, +) +from 
.model_parsers import ( + _log_debug_info, + _merge_admins, + _parse_agent_model_config, + _parse_chat_model_config, + _parse_embedding_model_config, + _parse_grok_model_config, + _parse_historian_model_config, + _parse_image_edit_model_config, + _parse_image_gen_config, + _parse_image_gen_model_config, + _parse_naga_model_config, + _parse_rerank_model_config, + _parse_security_model_config, + _parse_summary_model_config, + _parse_vision_model_config, + _verify_required_fields, +) +from .domain_parsers import ( + _parse_api_config, + _parse_cognitive_config, + _parse_easter_egg_call_mode, + _parse_memes_config, + _parse_naga_config, + _update_dataclass, ) logger = logging.getLogger(__name__) CONFIG_PATH = Path("config.toml") -LOCAL_CONFIG_PATH = Path("config.local.json") - -DEFAULT_WEBUI_URL = "127.0.0.1" -DEFAULT_WEBUI_PORT = 8787 -DEFAULT_WEBUI_PASSWORD = "changeme" -DEFAULT_API_HOST = "127.0.0.1" -DEFAULT_API_PORT = 8788 -DEFAULT_API_AUTH_KEY = "changeme" - -_ENV_WARNED_KEYS: set[str] = set() - -def _warn_env_fallback(name: str) -> None: - if name in _ENV_WARNED_KEYS: - return - _ENV_WARNED_KEYS.add(name) - logger.warning("检测到环境变量 %s,建议迁移到 config.toml", name) +# Re-export symbols that external modules import from this module. 
+__all__ = [ + "CONFIG_PATH", + "Config", + "DEFAULT_WEBUI_PASSWORD", + "DEFAULT_WEBUI_PORT", + "DEFAULT_WEBUI_URL", + "LOCAL_CONFIG_PATH", + "WebUISettings", + "load_local_admins", + "load_toml_data", + "load_webui_settings", + "save_local_admins", +] def _load_env() -> None: @@ -143,308 +194,6 @@ def load_toml_data( return {} -def _get_nested(data: dict[str, Any], path: tuple[str, ...]) -> Any: - node: Any = data - for key in path: - if not isinstance(node, dict) or key not in node: - return None - node = node[key] - return node - - -def _normalize_str(value: Any) -> Optional[str]: - if value is None: - return None - if isinstance(value, str): - stripped = value.strip() - return stripped if stripped else None - return str(value).strip() - - -def _coerce_int(value: Any, default: int) -> int: - if value is None: - return default - try: - return int(value) - except (TypeError, ValueError): - return default - - -def _coerce_float(value: Any, default: float) -> float: - if value is None: - return default - try: - return float(value) - except (TypeError, ValueError): - return default - - -def _normalize_queue_interval(value: float, default: float = 1.0) -> float: - """规范化队列发车间隔。 - - `0` 表示立即发车,负数回退到默认值。 - """ - - return default if value < 0 else value - - -def _coerce_bool(value: Any, default: bool) -> bool: - if isinstance(value, bool): - return value - if isinstance(value, (int, float)): - return bool(value) - if isinstance(value, str): - return value.strip().lower() in {"1", "true", "yes", "on"} - return default - - -def _coerce_str(value: Any, default: str) -> str: - normalized = _normalize_str(value) - return normalized if normalized is not None else default - - -def _normalize_base_url(value: str, default: str) -> str: - normalized = value.strip().rstrip("/") - if normalized: - return normalized - return default.rstrip("/") - - -def _coerce_int_list(value: Any) -> list[int]: - if value is None: - return [] - if isinstance(value, list): - items: list[int] = [] - 
for item in value: - try: - items.append(int(item)) - except (TypeError, ValueError): - continue - return items - if isinstance(value, str): - parts = [part.strip() for part in value.split(",") if part.strip()] - items = [] - for part in parts: - try: - items.append(int(part)) - except ValueError: - continue - return items - return [] - - -def _coerce_str_list(value: Any) -> list[str]: - if value is None: - return [] - if isinstance(value, list): - return [str(item).strip() for item in value if str(item).strip()] - if isinstance(value, str): - return [part.strip() for part in value.split(",") if part.strip()] - return [] - - -def _coerce_request_params(value: Any) -> dict[str, Any]: - return normalize_request_params(value) - - -def _get_model_request_params(data: dict[str, Any], model_name: str) -> dict[str, Any]: - return _coerce_request_params( - _get_nested(data, ("models", model_name, "request_params")) - ) - - -def _get_value( - data: dict[str, Any], - path: tuple[str, ...], - env_key: Optional[str], -) -> Any: - value = _get_nested(data, path) - if value is not None: - return value - if env_key: - env_value = os.getenv(env_key) - if env_value is not None and str(env_value).strip() != "": - _warn_env_fallback(env_key) - return env_value - return None - - -_VALID_API_MODES = {"chat_completions", "responses"} -_VALID_REASONING_EFFORT_STYLES = {"openai", "anthropic"} - - -def _resolve_reasoning_effort_style(value: Any, default: str = "openai") -> str: - style = _coerce_str(value, default).strip().lower() - if style not in _VALID_REASONING_EFFORT_STYLES: - return default - return style - - -def _resolve_thinking_compat_flags( - data: dict[str, Any], - model_name: str, - include_budget_env_key: str, - tool_call_compat_env_key: str, - legacy_env_key: str, -) -> tuple[bool, bool]: - """解析思维链兼容配置,并兼容旧字段 deepseek_new_cot_support。""" - include_budget_value = _get_value( - data, - ("models", model_name, "thinking_include_budget"), - include_budget_env_key, - ) - 
tool_call_compat_value = _get_value( - data, - ("models", model_name, "thinking_tool_call_compat"), - tool_call_compat_env_key, - ) - legacy_value = _get_value( - data, - ("models", model_name, "deepseek_new_cot_support"), - legacy_env_key, - ) - - include_budget_default = True - tool_call_compat_default = True - if legacy_value is not None: - legacy_enabled = _coerce_bool(legacy_value, False) - include_budget_default = not legacy_enabled - tool_call_compat_default = legacy_enabled - - return ( - _coerce_bool(include_budget_value, include_budget_default), - _coerce_bool(tool_call_compat_value, tool_call_compat_default), - ) - - -def _resolve_api_mode( - data: dict[str, Any], - model_name: str, - env_key: str, - default: str = "chat_completions", -) -> str: - raw_value = _get_value(data, ("models", model_name, "api_mode"), env_key) - value = _coerce_str(raw_value, default).strip().lower() - if value not in _VALID_API_MODES: - return default - return value - - -def _resolve_reasoning_effort(value: Any, default: str = "medium") -> str: - return _coerce_str(value, default).strip().lower() - - -def _resolve_responses_tool_choice_compat( - data: dict[str, Any], - model_name: str, - env_key: str, - default: bool = False, -) -> bool: - return _coerce_bool( - _get_value( - data, - ("models", model_name, "responses_tool_choice_compat"), - env_key, - ), - default, - ) - - -def _resolve_responses_force_stateless_replay( - data: dict[str, Any], - model_name: str, - env_key: str, - default: bool = False, -) -> bool: - return _coerce_bool( - _get_value( - data, - ("models", model_name, "responses_force_stateless_replay"), - env_key, - ), - default, - ) - - -def load_local_admins() -> list[int]: - """从本地配置文件加载动态管理员列表""" - if not LOCAL_CONFIG_PATH.exists(): - return [] - try: - with open(LOCAL_CONFIG_PATH, "r", encoding="utf-8") as f: - data = json.load(f) - admin_qqs: list[int] = data.get("admin_qqs", []) - return admin_qqs - except Exception as exc: - logger.warning("读取本地配置失败: 
%s", exc) - return [] - - -def save_local_admins(admin_qqs: list[int]) -> None: - """保存动态管理员列表到本地配置文件""" - try: - data: dict[str, list[int]] = {} - if LOCAL_CONFIG_PATH.exists(): - with open(LOCAL_CONFIG_PATH, "r", encoding="utf-8") as f: - data = json.load(f) - - data["admin_qqs"] = admin_qqs - - with open(LOCAL_CONFIG_PATH, "w", encoding="utf-8") as f: - json.dump(data, f, ensure_ascii=False, indent=2) - - logger.info("已保存管理员列表到 %s", LOCAL_CONFIG_PATH) - except Exception as exc: - logger.error("保存本地配置失败: %s", exc) - raise - - -@dataclass -class WebUISettings: - url: str - port: int - password: str - using_default_password: bool - config_exists: bool - - @property - def display_url(self) -> str: - """用于日志和展示的格式化 URL。""" - from Undefined.config.models import format_netloc - - return f"http://{format_netloc(self.url or '0.0.0.0', self.port)}" - - -def load_webui_settings(config_path: Optional[Path] = None) -> WebUISettings: - data = load_toml_data(config_path) - config_exists = bool(data) - url_value = _get_value(data, ("webui", "url"), None) - port_value = _get_value(data, ("webui", "port"), None) - password_value = _get_value(data, ("webui", "password"), None) - - url = _coerce_str(url_value, DEFAULT_WEBUI_URL) - port = _coerce_int(port_value, DEFAULT_WEBUI_PORT) - if port <= 0 or port > 65535: - port = DEFAULT_WEBUI_PORT - - password_normalized = _normalize_str(password_value) - if not password_normalized: - return WebUISettings( - url=url, - port=port, - password=DEFAULT_WEBUI_PASSWORD, - using_default_password=True, - config_exists=config_exists, - ) - return WebUISettings( - url=url, - port=port, - password=password_normalized, - using_default_password=False, - config_exists=config_exists, - ) - - @dataclass class Config: """应用配置""" @@ -783,8 +532,8 @@ def load(cls, config_path: Optional[Path] = None, strict: bool = True) -> "Confi _get_value(data, ("onebot", "token"), "ONEBOT_TOKEN"), "" ) - embedding_model = cls._parse_embedding_model_config(data) - 
rerank_model = cls._parse_rerank_model_config(data) + embedding_model = _parse_embedding_model_config(data) + rerank_model = _parse_rerank_model_config(data) knowledge_enabled = _coerce_bool( _get_value(data, ("knowledge", "enabled"), None), False @@ -858,8 +607,8 @@ def load(cls, config_path: Optional[Path] = None, strict: bool = True) -> "Confi ) knowledge_rerank_top_k = fallback - chat_model = cls._parse_chat_model_config(data) - vision_model = cls._parse_vision_model_config(data) + chat_model = _parse_chat_model_config(data) + vision_model = _parse_vision_model_config(data) security_model_enabled = _coerce_bool( _get_value( data, @@ -868,20 +617,20 @@ def load(cls, config_path: Optional[Path] = None, strict: bool = True) -> "Confi ), True, ) - security_model = cls._parse_security_model_config(data, chat_model) - naga_model = cls._parse_naga_model_config(data, security_model) - agent_model = cls._parse_agent_model_config(data) - historian_model = cls._parse_historian_model_config(data, agent_model) - summary_model, summary_model_configured = cls._parse_summary_model_config( + security_model = _parse_security_model_config(data, chat_model) + naga_model = _parse_naga_model_config(data, security_model) + agent_model = _parse_agent_model_config(data) + historian_model = _parse_historian_model_config(data, agent_model) + summary_model, summary_model_configured = _parse_summary_model_config( data, agent_model ) - grok_model = cls._parse_grok_model_config(data) + grok_model = _parse_grok_model_config(data) model_pool_enabled = _coerce_bool( _get_value(data, ("features", "pool_enabled"), "MODEL_POOL_ENABLED"), False ) - superadmin_qq, admin_qqs = cls._merge_admins( + superadmin_qq, admin_qqs = _merge_admins( superadmin_qq=superadmin_qq, admin_qqs=admin_qqs ) @@ -1013,7 +762,7 @@ def load(cls, config_path: Optional[Path] = None, strict: bool = True) -> "Confi if easter_egg_mode_raw is not None: _warn_env_fallback("EASTER_EGG_CALL_MESSAGE_MODE") - 
easter_egg_agent_call_message_mode = cls._parse_easter_egg_call_mode( + easter_egg_agent_call_message_mode = _parse_easter_egg_call_mode( easter_egg_mode_raw ) @@ -1443,17 +1192,17 @@ def load(cls, config_path: Optional[Path] = None, strict: bool = True) -> "Confi messages_send_url_file_max_size_mb = 100 webui_settings = load_webui_settings(config_path) - api_config = cls._parse_api_config(data) + api_config = _parse_api_config(data) - cognitive = cls._parse_cognitive_config(data) - memes = cls._parse_memes_config(data) - naga = cls._parse_naga_config(data) - models_image_gen = cls._parse_image_gen_model_config(data) - models_image_edit = cls._parse_image_edit_model_config(data) - image_gen = cls._parse_image_gen_config(data) + cognitive = _parse_cognitive_config(data) + memes = _parse_memes_config(data) + naga = _parse_naga_config(data) + models_image_gen = _parse_image_gen_model_config(data) + models_image_edit = _parse_image_edit_model_config(data) + image_gen = _parse_image_gen_config(data) if strict: - cls._verify_required_fields( + _verify_required_fields( bot_qq=bot_qq, superadmin_qq=superadmin_qq, onebot_ws_url=onebot_ws_url, @@ -1464,7 +1213,7 @@ def load(cls, config_path: Optional[Path] = None, strict: bool = True) -> "Confi embedding_model=embedding_model, ) - cls._log_debug_info( + _log_debug_info( chat_model, vision_model, security_model, @@ -1793,1056 +1542,6 @@ def security_check_enabled(self) -> bool: return bool(self.security_model_enabled) - @staticmethod - def _parse_model_pool( - data: dict[str, Any], - model_section: str, - primary_config: ChatModelConfig | AgentModelConfig, - ) -> ModelPool | None: - """解析模型池配置,缺省字段继承 primary_config""" - pool_data = data.get("models", {}).get(model_section, {}).get("pool") - if not isinstance(pool_data, dict): - return None - - enabled = _coerce_bool(pool_data.get("enabled"), False) - strategy = _coerce_str(pool_data.get("strategy"), "default").strip().lower() - if strategy not in ("default", "round_robin", 
"random"): - strategy = "default" - - raw_models = pool_data.get("models") - if not isinstance(raw_models, list): - return ModelPool(enabled=enabled, strategy=strategy) - - entries: list[ModelPoolEntry] = [] - for item in raw_models: - if not isinstance(item, dict): - continue - name = _coerce_str(item.get("model_name"), "").strip() - if not name: - continue - entries.append( - ModelPoolEntry( - api_url=_coerce_str(item.get("api_url"), primary_config.api_url), - api_key=_coerce_str(item.get("api_key"), primary_config.api_key), - model_name=name, - max_tokens=_coerce_int( - item.get("max_tokens"), primary_config.max_tokens - ), - queue_interval_seconds=_normalize_queue_interval( - _coerce_float( - item.get("queue_interval_seconds"), - primary_config.queue_interval_seconds, - ), - primary_config.queue_interval_seconds, - ), - api_mode=( - _coerce_str(item.get("api_mode"), primary_config.api_mode) - .strip() - .lower() - ) - if _coerce_str(item.get("api_mode"), primary_config.api_mode) - .strip() - .lower() - in _VALID_API_MODES - else primary_config.api_mode, - thinking_enabled=_coerce_bool( - item.get("thinking_enabled"), primary_config.thinking_enabled - ), - thinking_budget_tokens=_coerce_int( - item.get("thinking_budget_tokens"), - primary_config.thinking_budget_tokens, - ), - thinking_include_budget=_coerce_bool( - item.get("thinking_include_budget"), - primary_config.thinking_include_budget, - ), - reasoning_effort_style=_resolve_reasoning_effort_style( - item.get("reasoning_effort_style"), - primary_config.reasoning_effort_style, - ), - thinking_tool_call_compat=_coerce_bool( - item.get("thinking_tool_call_compat"), - primary_config.thinking_tool_call_compat, - ), - responses_tool_choice_compat=_coerce_bool( - item.get("responses_tool_choice_compat"), - primary_config.responses_tool_choice_compat, - ), - responses_force_stateless_replay=_coerce_bool( - item.get("responses_force_stateless_replay"), - primary_config.responses_force_stateless_replay, - ), - 
prompt_cache_enabled=_coerce_bool( - item.get("prompt_cache_enabled"), - primary_config.prompt_cache_enabled, - ), - reasoning_enabled=_coerce_bool( - item.get("reasoning_enabled"), - primary_config.reasoning_enabled, - ), - reasoning_effort=_resolve_reasoning_effort( - item.get("reasoning_effort"), - primary_config.reasoning_effort, - ), - request_params=merge_request_params( - primary_config.request_params, - item.get("request_params"), - ), - ) - ) - - return ModelPool(enabled=enabled, strategy=strategy, models=entries) - - @staticmethod - def _parse_embedding_model_config(data: dict[str, Any]) -> EmbeddingModelConfig: - return EmbeddingModelConfig( - api_url=_coerce_str( - _get_value( - data, ("models", "embedding", "api_url"), "EMBEDDING_MODEL_API_URL" - ), - "", - ), - api_key=_coerce_str( - _get_value( - data, ("models", "embedding", "api_key"), "EMBEDDING_MODEL_API_KEY" - ), - "", - ), - model_name=_coerce_str( - _get_value( - data, ("models", "embedding", "model_name"), "EMBEDDING_MODEL_NAME" - ), - "", - ), - queue_interval_seconds=_normalize_queue_interval( - _coerce_float( - _get_value( - data, ("models", "embedding", "queue_interval_seconds"), None - ), - 0.0, - ), - 0.0, - ), - dimensions=_coerce_int( - _get_value(data, ("models", "embedding", "dimensions"), None), 0 - ) - or None, - query_instruction=_coerce_str( - _get_value(data, ("models", "embedding", "query_instruction"), None), "" - ), - document_instruction=_coerce_str( - _get_value(data, ("models", "embedding", "document_instruction"), None), - "", - ), - request_params=_get_model_request_params(data, "embedding"), - ) - - @staticmethod - def _parse_rerank_model_config(data: dict[str, Any]) -> RerankModelConfig: - queue_interval_seconds = _normalize_queue_interval( - _coerce_float( - _get_value(data, ("models", "rerank", "queue_interval_seconds"), None), - 0.0, - ), - 0.0, - ) - return RerankModelConfig( - api_url=_coerce_str( - _get_value( - data, ("models", "rerank", "api_url"), 
"RERANK_MODEL_API_URL" - ), - "", - ), - api_key=_coerce_str( - _get_value( - data, ("models", "rerank", "api_key"), "RERANK_MODEL_API_KEY" - ), - "", - ), - model_name=_coerce_str( - _get_value( - data, ("models", "rerank", "model_name"), "RERANK_MODEL_NAME" - ), - "", - ), - queue_interval_seconds=queue_interval_seconds, - query_instruction=_coerce_str( - _get_value(data, ("models", "rerank", "query_instruction"), None), "" - ), - request_params=_get_model_request_params(data, "rerank"), - ) - - @staticmethod - def _parse_chat_model_config(data: dict[str, Any]) -> ChatModelConfig: - queue_interval_seconds = _normalize_queue_interval( - _coerce_float( - _get_value( - data, - ("models", "chat", "queue_interval_seconds"), - "CHAT_MODEL_QUEUE_INTERVAL", - ), - 1.0, - ) - ) - thinking_include_budget, thinking_tool_call_compat = ( - _resolve_thinking_compat_flags( - data=data, - model_name="chat", - include_budget_env_key="CHAT_MODEL_THINKING_INCLUDE_BUDGET", - tool_call_compat_env_key="CHAT_MODEL_THINKING_TOOL_CALL_COMPAT", - legacy_env_key="CHAT_MODEL_DEEPSEEK_NEW_COT_SUPPORT", - ) - ) - api_mode = _resolve_api_mode(data, "chat", "CHAT_MODEL_API_MODE") - responses_tool_choice_compat = _resolve_responses_tool_choice_compat( - data, "chat", "CHAT_MODEL_RESPONSES_TOOL_CHOICE_COMPAT" - ) - responses_force_stateless_replay = _resolve_responses_force_stateless_replay( - data, "chat", "CHAT_MODEL_RESPONSES_FORCE_STATELESS_REPLAY" - ) - prompt_cache_enabled = _coerce_bool( - _get_value( - data, - ("models", "chat", "prompt_cache_enabled"), - "CHAT_MODEL_PROMPT_CACHE_ENABLED", - ), - True, - ) - reasoning_enabled = _coerce_bool( - _get_value( - data, - ("models", "chat", "reasoning_enabled"), - "CHAT_MODEL_REASONING_ENABLED", - ), - False, - ) - reasoning_effort = _resolve_reasoning_effort( - _get_value( - data, - ("models", "chat", "reasoning_effort"), - "CHAT_MODEL_REASONING_EFFORT", - ), - "medium", - ) - config = ChatModelConfig( - api_url=_coerce_str( - _get_value(data, 
("models", "chat", "api_url"), "CHAT_MODEL_API_URL"), - "", - ), - api_key=_coerce_str( - _get_value(data, ("models", "chat", "api_key"), "CHAT_MODEL_API_KEY"), - "", - ), - model_name=_coerce_str( - _get_value(data, ("models", "chat", "model_name"), "CHAT_MODEL_NAME"), - "", - ), - max_tokens=_coerce_int( - _get_value( - data, ("models", "chat", "max_tokens"), "CHAT_MODEL_MAX_TOKENS" - ), - 8192, - ), - queue_interval_seconds=queue_interval_seconds, - api_mode=api_mode, - thinking_enabled=_coerce_bool( - _get_value( - data, - ("models", "chat", "thinking_enabled"), - "CHAT_MODEL_THINKING_ENABLED", - ), - False, - ), - thinking_budget_tokens=_coerce_int( - _get_value( - data, - ("models", "chat", "thinking_budget_tokens"), - "CHAT_MODEL_THINKING_BUDGET_TOKENS", - ), - 20000, - ), - thinking_include_budget=thinking_include_budget, - reasoning_effort_style=_resolve_reasoning_effort_style( - _get_value( - data, - ("models", "chat", "reasoning_effort_style"), - "CHAT_MODEL_REASONING_EFFORT_STYLE", - ), - ), - thinking_tool_call_compat=thinking_tool_call_compat, - responses_tool_choice_compat=responses_tool_choice_compat, - responses_force_stateless_replay=responses_force_stateless_replay, - prompt_cache_enabled=prompt_cache_enabled, - reasoning_enabled=reasoning_enabled, - reasoning_effort=reasoning_effort, - request_params=_get_model_request_params(data, "chat"), - ) - config.pool = Config._parse_model_pool(data, "chat", config) - return config - - @staticmethod - def _parse_vision_model_config(data: dict[str, Any]) -> VisionModelConfig: - queue_interval_seconds = _normalize_queue_interval( - _coerce_float( - _get_value( - data, - ("models", "vision", "queue_interval_seconds"), - "VISION_MODEL_QUEUE_INTERVAL", - ), - 1.0, - ) - ) - thinking_include_budget, thinking_tool_call_compat = ( - _resolve_thinking_compat_flags( - data=data, - model_name="vision", - include_budget_env_key="VISION_MODEL_THINKING_INCLUDE_BUDGET", - 
tool_call_compat_env_key="VISION_MODEL_THINKING_TOOL_CALL_COMPAT", - legacy_env_key="VISION_MODEL_DEEPSEEK_NEW_COT_SUPPORT", - ) - ) - api_mode = _resolve_api_mode(data, "vision", "VISION_MODEL_API_MODE") - responses_tool_choice_compat = _resolve_responses_tool_choice_compat( - data, "vision", "VISION_MODEL_RESPONSES_TOOL_CHOICE_COMPAT" - ) - responses_force_stateless_replay = _resolve_responses_force_stateless_replay( - data, "vision", "VISION_MODEL_RESPONSES_FORCE_STATELESS_REPLAY" - ) - prompt_cache_enabled = _coerce_bool( - _get_value( - data, - ("models", "vision", "prompt_cache_enabled"), - "VISION_MODEL_PROMPT_CACHE_ENABLED", - ), - True, - ) - reasoning_enabled = _coerce_bool( - _get_value( - data, - ("models", "vision", "reasoning_enabled"), - "VISION_MODEL_REASONING_ENABLED", - ), - False, - ) - reasoning_effort = _resolve_reasoning_effort( - _get_value( - data, - ("models", "vision", "reasoning_effort"), - "VISION_MODEL_REASONING_EFFORT", - ), - "medium", - ) - return VisionModelConfig( - api_url=_coerce_str( - _get_value( - data, ("models", "vision", "api_url"), "VISION_MODEL_API_URL" - ), - "", - ), - api_key=_coerce_str( - _get_value( - data, ("models", "vision", "api_key"), "VISION_MODEL_API_KEY" - ), - "", - ), - model_name=_coerce_str( - _get_value( - data, ("models", "vision", "model_name"), "VISION_MODEL_NAME" - ), - "", - ), - queue_interval_seconds=queue_interval_seconds, - api_mode=api_mode, - thinking_enabled=_coerce_bool( - _get_value( - data, - ("models", "vision", "thinking_enabled"), - "VISION_MODEL_THINKING_ENABLED", - ), - False, - ), - thinking_budget_tokens=_coerce_int( - _get_value( - data, - ("models", "vision", "thinking_budget_tokens"), - "VISION_MODEL_THINKING_BUDGET_TOKENS", - ), - 20000, - ), - thinking_include_budget=thinking_include_budget, - reasoning_effort_style=_resolve_reasoning_effort_style( - _get_value( - data, - ("models", "vision", "reasoning_effort_style"), - "VISION_MODEL_REASONING_EFFORT_STYLE", - ), - ), - 
thinking_tool_call_compat=thinking_tool_call_compat, - responses_tool_choice_compat=responses_tool_choice_compat, - responses_force_stateless_replay=responses_force_stateless_replay, - prompt_cache_enabled=prompt_cache_enabled, - reasoning_enabled=reasoning_enabled, - reasoning_effort=reasoning_effort, - request_params=_get_model_request_params(data, "vision"), - ) - - @staticmethod - def _parse_security_model_config( - data: dict[str, Any], chat_model: ChatModelConfig - ) -> SecurityModelConfig: - api_url = _coerce_str( - _get_value( - data, ("models", "security", "api_url"), "SECURITY_MODEL_API_URL" - ), - "", - ) - api_key = _coerce_str( - _get_value( - data, ("models", "security", "api_key"), "SECURITY_MODEL_API_KEY" - ), - "", - ) - model_name = _coerce_str( - _get_value( - data, ("models", "security", "model_name"), "SECURITY_MODEL_NAME" - ), - "", - ) - queue_interval_seconds = _coerce_float( - _get_value( - data, - ("models", "security", "queue_interval_seconds"), - "SECURITY_MODEL_QUEUE_INTERVAL", - ), - 1.0, - ) - queue_interval_seconds = _normalize_queue_interval(queue_interval_seconds) - - thinking_include_budget, thinking_tool_call_compat = ( - _resolve_thinking_compat_flags( - data=data, - model_name="security", - include_budget_env_key="SECURITY_MODEL_THINKING_INCLUDE_BUDGET", - tool_call_compat_env_key="SECURITY_MODEL_THINKING_TOOL_CALL_COMPAT", - legacy_env_key="SECURITY_MODEL_DEEPSEEK_NEW_COT_SUPPORT", - ) - ) - api_mode = _resolve_api_mode(data, "security", "SECURITY_MODEL_API_MODE") - responses_tool_choice_compat = _resolve_responses_tool_choice_compat( - data, "security", "SECURITY_MODEL_RESPONSES_TOOL_CHOICE_COMPAT" - ) - responses_force_stateless_replay = _resolve_responses_force_stateless_replay( - data, "security", "SECURITY_MODEL_RESPONSES_FORCE_STATELESS_REPLAY" - ) - prompt_cache_enabled = _coerce_bool( - _get_value( - data, - ("models", "security", "prompt_cache_enabled"), - "SECURITY_MODEL_PROMPT_CACHE_ENABLED", - ), - True, - ) - 
reasoning_enabled = _coerce_bool( - _get_value( - data, - ("models", "security", "reasoning_enabled"), - "SECURITY_MODEL_REASONING_ENABLED", - ), - False, - ) - reasoning_effort = _resolve_reasoning_effort( - _get_value( - data, - ("models", "security", "reasoning_effort"), - "SECURITY_MODEL_REASONING_EFFORT", - ), - "medium", - ) - - if api_url and api_key and model_name: - return SecurityModelConfig( - api_url=api_url, - api_key=api_key, - model_name=model_name, - max_tokens=_coerce_int( - _get_value( - data, - ("models", "security", "max_tokens"), - "SECURITY_MODEL_MAX_TOKENS", - ), - 100, - ), - queue_interval_seconds=queue_interval_seconds, - api_mode=api_mode, - thinking_enabled=_coerce_bool( - _get_value( - data, - ("models", "security", "thinking_enabled"), - "SECURITY_MODEL_THINKING_ENABLED", - ), - False, - ), - thinking_budget_tokens=_coerce_int( - _get_value( - data, - ("models", "security", "thinking_budget_tokens"), - "SECURITY_MODEL_THINKING_BUDGET_TOKENS", - ), - 0, - ), - thinking_include_budget=thinking_include_budget, - reasoning_effort_style=_resolve_reasoning_effort_style( - _get_value( - data, - ("models", "security", "reasoning_effort_style"), - "SECURITY_MODEL_REASONING_EFFORT_STYLE", - ), - ), - thinking_tool_call_compat=thinking_tool_call_compat, - responses_tool_choice_compat=responses_tool_choice_compat, - responses_force_stateless_replay=responses_force_stateless_replay, - prompt_cache_enabled=prompt_cache_enabled, - reasoning_enabled=reasoning_enabled, - reasoning_effort=reasoning_effort, - request_params=_get_model_request_params(data, "security"), - ) - - logger.warning("未配置安全模型,将使用对话模型作为后备") - return SecurityModelConfig( - api_url=chat_model.api_url, - api_key=chat_model.api_key, - model_name=chat_model.model_name, - max_tokens=chat_model.max_tokens, - queue_interval_seconds=chat_model.queue_interval_seconds, - api_mode=chat_model.api_mode, - thinking_enabled=False, - thinking_budget_tokens=0, - thinking_include_budget=True, - 
reasoning_effort_style="openai", - thinking_tool_call_compat=chat_model.thinking_tool_call_compat, - responses_tool_choice_compat=chat_model.responses_tool_choice_compat, - responses_force_stateless_replay=chat_model.responses_force_stateless_replay, - prompt_cache_enabled=chat_model.prompt_cache_enabled, - reasoning_enabled=chat_model.reasoning_enabled, - reasoning_effort=chat_model.reasoning_effort, - request_params=merge_request_params(chat_model.request_params), - ) - - @staticmethod - def _parse_naga_model_config( - data: dict[str, Any], security_model: SecurityModelConfig - ) -> SecurityModelConfig: - api_url = _coerce_str( - _get_value(data, ("models", "naga", "api_url"), "NAGA_MODEL_API_URL"), - "", - ) - api_key = _coerce_str( - _get_value(data, ("models", "naga", "api_key"), "NAGA_MODEL_API_KEY"), - "", - ) - model_name = _coerce_str( - _get_value(data, ("models", "naga", "model_name"), "NAGA_MODEL_NAME"), - "", - ) - queue_interval_seconds = _coerce_float( - _get_value( - data, - ("models", "naga", "queue_interval_seconds"), - "NAGA_MODEL_QUEUE_INTERVAL", - ), - security_model.queue_interval_seconds, - ) - queue_interval_seconds = _normalize_queue_interval(queue_interval_seconds) - - thinking_include_budget, thinking_tool_call_compat = ( - _resolve_thinking_compat_flags( - data=data, - model_name="naga", - include_budget_env_key="NAGA_MODEL_THINKING_INCLUDE_BUDGET", - tool_call_compat_env_key="NAGA_MODEL_THINKING_TOOL_CALL_COMPAT", - legacy_env_key="NAGA_MODEL_DEEPSEEK_NEW_COT_SUPPORT", - ) - ) - api_mode = _resolve_api_mode(data, "naga", "NAGA_MODEL_API_MODE") - responses_tool_choice_compat = _resolve_responses_tool_choice_compat( - data, "naga", "NAGA_MODEL_RESPONSES_TOOL_CHOICE_COMPAT" - ) - responses_force_stateless_replay = _resolve_responses_force_stateless_replay( - data, "naga", "NAGA_MODEL_RESPONSES_FORCE_STATELESS_REPLAY" - ) - prompt_cache_enabled = _coerce_bool( - _get_value( - data, - ("models", "naga", "prompt_cache_enabled"), - 
"NAGA_MODEL_PROMPT_CACHE_ENABLED", - ), - getattr(security_model, "prompt_cache_enabled", True), - ) - reasoning_enabled = _coerce_bool( - _get_value( - data, - ("models", "naga", "reasoning_enabled"), - "NAGA_MODEL_REASONING_ENABLED", - ), - getattr(security_model, "reasoning_enabled", False), - ) - reasoning_effort = _resolve_reasoning_effort( - _get_value( - data, - ("models", "naga", "reasoning_effort"), - "NAGA_MODEL_REASONING_EFFORT", - ), - getattr(security_model, "reasoning_effort", "medium"), - ) - - if api_url and api_key and model_name: - return SecurityModelConfig( - api_url=api_url, - api_key=api_key, - model_name=model_name, - max_tokens=_coerce_int( - _get_value( - data, - ("models", "naga", "max_tokens"), - "NAGA_MODEL_MAX_TOKENS", - ), - 160, - ), - queue_interval_seconds=queue_interval_seconds, - api_mode=api_mode, - thinking_enabled=_coerce_bool( - _get_value( - data, - ("models", "naga", "thinking_enabled"), - "NAGA_MODEL_THINKING_ENABLED", - ), - False, - ), - thinking_budget_tokens=_coerce_int( - _get_value( - data, - ("models", "naga", "thinking_budget_tokens"), - "NAGA_MODEL_THINKING_BUDGET_TOKENS", - ), - 0, - ), - thinking_include_budget=thinking_include_budget, - reasoning_effort_style=_resolve_reasoning_effort_style( - _get_value( - data, - ("models", "naga", "reasoning_effort_style"), - "NAGA_MODEL_REASONING_EFFORT_STYLE", - ), - ), - thinking_tool_call_compat=thinking_tool_call_compat, - responses_tool_choice_compat=responses_tool_choice_compat, - responses_force_stateless_replay=responses_force_stateless_replay, - prompt_cache_enabled=prompt_cache_enabled, - reasoning_enabled=reasoning_enabled, - reasoning_effort=reasoning_effort, - request_params=_get_model_request_params(data, "naga"), - ) - - logger.info( - "未配置 Naga 审核模型,将使用已解析的安全模型配置作为后备(安全模型本身可能已回退)" - ) - return SecurityModelConfig( - api_url=security_model.api_url, - api_key=security_model.api_key, - model_name=security_model.model_name, - max_tokens=security_model.max_tokens, 
- queue_interval_seconds=security_model.queue_interval_seconds, - api_mode=security_model.api_mode, - thinking_enabled=security_model.thinking_enabled, - thinking_budget_tokens=security_model.thinking_budget_tokens, - thinking_include_budget=security_model.thinking_include_budget, - reasoning_effort_style=security_model.reasoning_effort_style, - thinking_tool_call_compat=security_model.thinking_tool_call_compat, - responses_tool_choice_compat=security_model.responses_tool_choice_compat, - responses_force_stateless_replay=security_model.responses_force_stateless_replay, - prompt_cache_enabled=security_model.prompt_cache_enabled, - reasoning_enabled=security_model.reasoning_enabled, - reasoning_effort=security_model.reasoning_effort, - request_params=merge_request_params(security_model.request_params), - ) - - @staticmethod - def _parse_agent_model_config(data: dict[str, Any]) -> AgentModelConfig: - queue_interval_seconds = _normalize_queue_interval( - _coerce_float( - _get_value( - data, - ("models", "agent", "queue_interval_seconds"), - "AGENT_MODEL_QUEUE_INTERVAL", - ), - 1.0, - ) - ) - thinking_include_budget, thinking_tool_call_compat = ( - _resolve_thinking_compat_flags( - data=data, - model_name="agent", - include_budget_env_key="AGENT_MODEL_THINKING_INCLUDE_BUDGET", - tool_call_compat_env_key="AGENT_MODEL_THINKING_TOOL_CALL_COMPAT", - legacy_env_key="AGENT_MODEL_DEEPSEEK_NEW_COT_SUPPORT", - ) - ) - api_mode = _resolve_api_mode(data, "agent", "AGENT_MODEL_API_MODE") - responses_tool_choice_compat = _resolve_responses_tool_choice_compat( - data, "agent", "AGENT_MODEL_RESPONSES_TOOL_CHOICE_COMPAT" - ) - responses_force_stateless_replay = _resolve_responses_force_stateless_replay( - data, "agent", "AGENT_MODEL_RESPONSES_FORCE_STATELESS_REPLAY" - ) - prompt_cache_enabled = _coerce_bool( - _get_value( - data, - ("models", "agent", "prompt_cache_enabled"), - "AGENT_MODEL_PROMPT_CACHE_ENABLED", - ), - True, - ) - reasoning_enabled = _coerce_bool( - _get_value( - 
data, - ("models", "agent", "reasoning_enabled"), - "AGENT_MODEL_REASONING_ENABLED", - ), - False, - ) - reasoning_effort = _resolve_reasoning_effort( - _get_value( - data, - ("models", "agent", "reasoning_effort"), - "AGENT_MODEL_REASONING_EFFORT", - ), - "medium", - ) - config = AgentModelConfig( - api_url=_coerce_str( - _get_value(data, ("models", "agent", "api_url"), "AGENT_MODEL_API_URL"), - "", - ), - api_key=_coerce_str( - _get_value(data, ("models", "agent", "api_key"), "AGENT_MODEL_API_KEY"), - "", - ), - model_name=_coerce_str( - _get_value(data, ("models", "agent", "model_name"), "AGENT_MODEL_NAME"), - "", - ), - max_tokens=_coerce_int( - _get_value( - data, ("models", "agent", "max_tokens"), "AGENT_MODEL_MAX_TOKENS" - ), - 4096, - ), - queue_interval_seconds=queue_interval_seconds, - api_mode=api_mode, - thinking_enabled=_coerce_bool( - _get_value( - data, - ("models", "agent", "thinking_enabled"), - "AGENT_MODEL_THINKING_ENABLED", - ), - False, - ), - thinking_budget_tokens=_coerce_int( - _get_value( - data, - ("models", "agent", "thinking_budget_tokens"), - "AGENT_MODEL_THINKING_BUDGET_TOKENS", - ), - 0, - ), - thinking_include_budget=thinking_include_budget, - reasoning_effort_style=_resolve_reasoning_effort_style( - _get_value( - data, - ("models", "agent", "reasoning_effort_style"), - "AGENT_MODEL_REASONING_EFFORT_STYLE", - ), - ), - thinking_tool_call_compat=thinking_tool_call_compat, - responses_tool_choice_compat=responses_tool_choice_compat, - responses_force_stateless_replay=responses_force_stateless_replay, - prompt_cache_enabled=prompt_cache_enabled, - reasoning_enabled=reasoning_enabled, - reasoning_effort=reasoning_effort, - request_params=_get_model_request_params(data, "agent"), - ) - config.pool = Config._parse_model_pool(data, "agent", config) - return config - - @staticmethod - def _parse_grok_model_config(data: dict[str, Any]) -> GrokModelConfig: - queue_interval_seconds = _normalize_queue_interval( - _coerce_float( - _get_value( - 
data, - ("models", "grok", "queue_interval_seconds"), - "GROK_MODEL_QUEUE_INTERVAL", - ), - 1.0, - ) - ) - return GrokModelConfig( - api_url=_coerce_str( - _get_value(data, ("models", "grok", "api_url"), "GROK_MODEL_API_URL"), - "", - ), - api_key=_coerce_str( - _get_value(data, ("models", "grok", "api_key"), "GROK_MODEL_API_KEY"), - "", - ), - model_name=_coerce_str( - _get_value(data, ("models", "grok", "model_name"), "GROK_MODEL_NAME"), - "", - ), - max_tokens=_coerce_int( - _get_value( - data, ("models", "grok", "max_tokens"), "GROK_MODEL_MAX_TOKENS" - ), - 8192, - ), - queue_interval_seconds=queue_interval_seconds, - thinking_enabled=_coerce_bool( - _get_value( - data, - ("models", "grok", "thinking_enabled"), - "GROK_MODEL_THINKING_ENABLED", - ), - False, - ), - thinking_budget_tokens=_coerce_int( - _get_value( - data, - ("models", "grok", "thinking_budget_tokens"), - "GROK_MODEL_THINKING_BUDGET_TOKENS", - ), - 20000, - ), - thinking_include_budget=_coerce_bool( - _get_value( - data, - ("models", "grok", "thinking_include_budget"), - "GROK_MODEL_THINKING_INCLUDE_BUDGET", - ), - True, - ), - reasoning_effort_style=_resolve_reasoning_effort_style( - _get_value( - data, - ("models", "grok", "reasoning_effort_style"), - "GROK_MODEL_REASONING_EFFORT_STYLE", - ), - ), - prompt_cache_enabled=_coerce_bool( - _get_value( - data, - ("models", "grok", "prompt_cache_enabled"), - "GROK_MODEL_PROMPT_CACHE_ENABLED", - ), - True, - ), - reasoning_enabled=_coerce_bool( - _get_value( - data, - ("models", "grok", "reasoning_enabled"), - "GROK_MODEL_REASONING_ENABLED", - ), - False, - ), - reasoning_effort=_resolve_reasoning_effort( - _get_value( - data, - ("models", "grok", "reasoning_effort"), - "GROK_MODEL_REASONING_EFFORT", - ), - "medium", - ), - request_params=_get_model_request_params(data, "grok"), - ) - - @staticmethod - def _parse_image_gen_model_config(data: dict[str, Any]) -> ImageGenModelConfig: - """解析 [models.image_gen] 生图模型配置""" - return ImageGenModelConfig( - 
api_url=_coerce_str( - _get_value( - data, ("models", "image_gen", "api_url"), "IMAGE_GEN_MODEL_API_URL" - ), - "", - ), - api_key=_coerce_str( - _get_value( - data, ("models", "image_gen", "api_key"), "IMAGE_GEN_MODEL_API_KEY" - ), - "", - ), - model_name=_coerce_str( - _get_value( - data, ("models", "image_gen", "model_name"), "IMAGE_GEN_MODEL_NAME" - ), - "", - ), - request_params=_get_model_request_params(data, "image_gen"), - ) - - @staticmethod - def _parse_image_edit_model_config(data: dict[str, Any]) -> ImageGenModelConfig: - """解析 [models.image_edit] 参考图生图模型配置""" - return ImageGenModelConfig( - api_url=_coerce_str( - _get_value( - data, - ("models", "image_edit", "api_url"), - "IMAGE_EDIT_MODEL_API_URL", - ), - "", - ), - api_key=_coerce_str( - _get_value( - data, - ("models", "image_edit", "api_key"), - "IMAGE_EDIT_MODEL_API_KEY", - ), - "", - ), - model_name=_coerce_str( - _get_value( - data, - ("models", "image_edit", "model_name"), - "IMAGE_EDIT_MODEL_NAME", - ), - "", - ), - request_params=_get_model_request_params(data, "image_edit"), - ) - - @staticmethod - def _parse_image_gen_config(data: dict[str, Any]) -> ImageGenConfig: - """解析 [image_gen] 生图工具配置""" - return ImageGenConfig( - provider=_coerce_str( - _get_value(data, ("image_gen", "provider"), "IMAGE_GEN_PROVIDER"), - "xingzhige", - ), - xingzhige_size=_coerce_str( - _get_value(data, ("image_gen", "xingzhige_size"), None), "1:1" - ), - openai_size=_coerce_str( - _get_value(data, ("image_gen", "openai_size"), None), "" - ), - openai_quality=_coerce_str( - _get_value(data, ("image_gen", "openai_quality"), None), "" - ), - openai_style=_coerce_str( - _get_value(data, ("image_gen", "openai_style"), None), "" - ), - openai_timeout=_coerce_float( - _get_value(data, ("image_gen", "openai_timeout"), None), 120.0 - ), - ) - - @staticmethod - def _merge_admins( - superadmin_qq: int, admin_qqs: list[int] - ) -> tuple[int, list[int]]: - local_admins = load_local_admins() - all_admins = list(set(admin_qqs + 
local_admins)) - if superadmin_qq and superadmin_qq not in all_admins: - all_admins.append(superadmin_qq) - return superadmin_qq, all_admins - - @staticmethod - def _verify_required_fields( - bot_qq: int, - superadmin_qq: int, - onebot_ws_url: str, - chat_model: ChatModelConfig, - vision_model: VisionModelConfig, - agent_model: AgentModelConfig, - knowledge_enabled: bool, - embedding_model: EmbeddingModelConfig, - ) -> None: - missing: list[str] = [] - if bot_qq <= 0: - missing.append("core.bot_qq") - if superadmin_qq <= 0: - missing.append("core.superadmin_qq") - if not onebot_ws_url: - missing.append("onebot.ws_url") - if not chat_model.api_url: - missing.append("models.chat.api_url") - if not chat_model.api_key: - missing.append("models.chat.api_key") - if not chat_model.model_name: - missing.append("models.chat.model_name") - if not vision_model.api_url: - missing.append("models.vision.api_url") - if not vision_model.api_key: - missing.append("models.vision.api_key") - if not vision_model.model_name: - missing.append("models.vision.model_name") - if not agent_model.api_url: - missing.append("models.agent.api_url") - if not agent_model.api_key: - missing.append("models.agent.api_key") - if not agent_model.model_name: - missing.append("models.agent.model_name") - if knowledge_enabled: - if not embedding_model.api_url: - missing.append("models.embedding.api_url") - if not embedding_model.model_name: - missing.append("models.embedding.model_name") - if missing: - raise ValueError(f"缺少必需配置: {', '.join(missing)}") - - @staticmethod - def _log_debug_info( - chat_model: ChatModelConfig, - vision_model: VisionModelConfig, - security_model: SecurityModelConfig, - naga_model: SecurityModelConfig, - agent_model: AgentModelConfig, - summary_model: AgentModelConfig, - grok_model: GrokModelConfig, - ) -> None: - configs: list[ - tuple[ - str, - ChatModelConfig - | VisionModelConfig - | SecurityModelConfig - | AgentModelConfig - | GrokModelConfig, - ] - ] = [ - ("chat", 
chat_model), - ("vision", vision_model), - ("security", security_model), - ("naga", naga_model), - ("agent", agent_model), - ("summary", summary_model), - ("grok", grok_model), - ] - for name, cfg in configs: - logger.debug( - "[配置] %s_model=%s api_url=%s api_key_set=%s api_mode=%s thinking=%s reasoning=%s/%s cot_compat=%s responses_tool_choice_compat=%s responses_force_stateless_replay=%s", - name, - cfg.model_name, - cfg.api_url, - bool(cfg.api_key), - getattr(cfg, "api_mode", "chat_completions"), - cfg.thinking_enabled, - getattr(cfg, "reasoning_enabled", False), - getattr(cfg, "reasoning_effort", "medium"), - getattr(cfg, "thinking_tool_call_compat", False), - getattr(cfg, "responses_tool_choice_compat", False), - getattr(cfg, "responses_force_stateless_replay", False), - ) - def update_from(self, new_config: "Config") -> dict[str, tuple[Any, Any]]: changes: dict[str, tuple[Any, Any]] = {} for field in fields(self): @@ -2866,457 +1565,6 @@ def update_from(self, new_config: "Config") -> dict[str, tuple[Any, Any]]: changes[name] = (old_value, new_value) return changes - @staticmethod - def _parse_historian_model_config( - data: dict[str, Any], fallback: AgentModelConfig - ) -> AgentModelConfig: - h = data.get("models", {}).get("historian", {}) - if not isinstance(h, dict) or not h: - return fallback - queue_interval_seconds = _coerce_float( - h.get("queue_interval_seconds"), fallback.queue_interval_seconds - ) - queue_interval_seconds = _normalize_queue_interval( - queue_interval_seconds, fallback.queue_interval_seconds - ) - thinking_include_budget, thinking_tool_call_compat = ( - _resolve_thinking_compat_flags( - data={"models": {"historian": h}}, - model_name="historian", - include_budget_env_key="HISTORIAN_MODEL_THINKING_INCLUDE_BUDGET", - tool_call_compat_env_key="HISTORIAN_MODEL_THINKING_TOOL_CALL_COMPAT", - legacy_env_key="HISTORIAN_MODEL_DEEPSEEK_NEW_COT_SUPPORT", - ) - ) - api_mode = _resolve_api_mode( - {"models": {"historian": h}}, - "historian", - 
"HISTORIAN_MODEL_API_MODE", - fallback.api_mode, - ) - responses_tool_choice_compat = _resolve_responses_tool_choice_compat( - {"models": {"historian": h}}, - "historian", - "HISTORIAN_MODEL_RESPONSES_TOOL_CHOICE_COMPAT", - fallback.responses_tool_choice_compat, - ) - responses_force_stateless_replay = _resolve_responses_force_stateless_replay( - {"models": {"historian": h}}, - "historian", - "HISTORIAN_MODEL_RESPONSES_FORCE_STATELESS_REPLAY", - fallback.responses_force_stateless_replay, - ) - prompt_cache_enabled = _coerce_bool( - _get_value( - {"models": {"historian": h}}, - ("models", "historian", "prompt_cache_enabled"), - "HISTORIAN_MODEL_PROMPT_CACHE_ENABLED", - ), - fallback.prompt_cache_enabled, - ) - return AgentModelConfig( - api_url=_coerce_str(h.get("api_url"), fallback.api_url), - api_key=_coerce_str(h.get("api_key"), fallback.api_key), - model_name=_coerce_str(h.get("model_name"), fallback.model_name), - max_tokens=_coerce_int(h.get("max_tokens"), fallback.max_tokens), - queue_interval_seconds=queue_interval_seconds, - api_mode=api_mode, - thinking_enabled=_coerce_bool( - h.get("thinking_enabled"), fallback.thinking_enabled - ), - thinking_budget_tokens=_coerce_int( - h.get("thinking_budget_tokens"), fallback.thinking_budget_tokens - ), - thinking_include_budget=thinking_include_budget, - reasoning_effort_style=_resolve_reasoning_effort_style( - _get_value( - {"models": {"historian": h}}, - ("models", "historian", "reasoning_effort_style"), - "HISTORIAN_MODEL_REASONING_EFFORT_STYLE", - ), - fallback.reasoning_effort_style, - ), - thinking_tool_call_compat=thinking_tool_call_compat, - responses_tool_choice_compat=responses_tool_choice_compat, - responses_force_stateless_replay=responses_force_stateless_replay, - prompt_cache_enabled=prompt_cache_enabled, - reasoning_enabled=_coerce_bool( - _get_value( - {"models": {"historian": h}}, - ("models", "historian", "reasoning_enabled"), - "HISTORIAN_MODEL_REASONING_ENABLED", - ), - fallback.reasoning_enabled, 
- ), - reasoning_effort=_resolve_reasoning_effort( - _get_value( - {"models": {"historian": h}}, - ("models", "historian", "reasoning_effort"), - "HISTORIAN_MODEL_REASONING_EFFORT", - ), - fallback.reasoning_effort, - ), - request_params=merge_request_params( - fallback.request_params, - h.get("request_params"), - ), - ) - - @staticmethod - def _parse_summary_model_config( - data: dict[str, Any], fallback: AgentModelConfig - ) -> tuple[AgentModelConfig, bool]: - s = data.get("models", {}).get("summary", {}) - if not isinstance(s, dict) or not s: - return fallback, False - queue_interval_seconds = _coerce_float( - s.get("queue_interval_seconds"), fallback.queue_interval_seconds - ) - queue_interval_seconds = _normalize_queue_interval( - queue_interval_seconds, fallback.queue_interval_seconds - ) - thinking_include_budget, thinking_tool_call_compat = ( - _resolve_thinking_compat_flags( - data={"models": {"summary": s}}, - model_name="summary", - include_budget_env_key="SUMMARY_MODEL_THINKING_INCLUDE_BUDGET", - tool_call_compat_env_key="SUMMARY_MODEL_THINKING_TOOL_CALL_COMPAT", - legacy_env_key="SUMMARY_MODEL_DEEPSEEK_NEW_COT_SUPPORT", - ) - ) - api_mode = _resolve_api_mode( - {"models": {"summary": s}}, - "summary", - "SUMMARY_MODEL_API_MODE", - fallback.api_mode, - ) - responses_tool_choice_compat = _resolve_responses_tool_choice_compat( - {"models": {"summary": s}}, - "summary", - "SUMMARY_MODEL_RESPONSES_TOOL_CHOICE_COMPAT", - fallback.responses_tool_choice_compat, - ) - responses_force_stateless_replay = _resolve_responses_force_stateless_replay( - {"models": {"summary": s}}, - "summary", - "SUMMARY_MODEL_RESPONSES_FORCE_STATELESS_REPLAY", - fallback.responses_force_stateless_replay, - ) - prompt_cache_enabled = _coerce_bool( - _get_value( - {"models": {"summary": s}}, - ("models", "summary", "prompt_cache_enabled"), - "SUMMARY_MODEL_PROMPT_CACHE_ENABLED", - ), - fallback.prompt_cache_enabled, - ) - return ( - AgentModelConfig( - 
api_url=_coerce_str(s.get("api_url"), fallback.api_url), - api_key=_coerce_str(s.get("api_key"), fallback.api_key), - model_name=_coerce_str(s.get("model_name"), fallback.model_name), - max_tokens=_coerce_int(s.get("max_tokens"), fallback.max_tokens), - queue_interval_seconds=queue_interval_seconds, - api_mode=api_mode, - thinking_enabled=_coerce_bool( - s.get("thinking_enabled"), fallback.thinking_enabled - ), - thinking_budget_tokens=_coerce_int( - s.get("thinking_budget_tokens"), fallback.thinking_budget_tokens - ), - thinking_include_budget=thinking_include_budget, - reasoning_effort_style=_resolve_reasoning_effort_style( - _get_value( - {"models": {"summary": s}}, - ("models", "summary", "reasoning_effort_style"), - "SUMMARY_MODEL_REASONING_EFFORT_STYLE", - ), - fallback.reasoning_effort_style, - ), - thinking_tool_call_compat=thinking_tool_call_compat, - responses_tool_choice_compat=responses_tool_choice_compat, - responses_force_stateless_replay=responses_force_stateless_replay, - prompt_cache_enabled=prompt_cache_enabled, - reasoning_enabled=_coerce_bool( - _get_value( - {"models": {"summary": s}}, - ("models", "summary", "reasoning_enabled"), - "SUMMARY_MODEL_REASONING_ENABLED", - ), - fallback.reasoning_enabled, - ), - reasoning_effort=_resolve_reasoning_effort( - _get_value( - {"models": {"summary": s}}, - ("models", "summary", "reasoning_effort"), - "SUMMARY_MODEL_REASONING_EFFORT", - ), - fallback.reasoning_effort, - ), - request_params=merge_request_params( - fallback.request_params, - s.get("request_params"), - ), - ), - True, - ) - - @staticmethod - def _parse_cognitive_config(data: dict[str, Any]) -> CognitiveConfig: - cog = data.get("cognitive", {}) - vs = cog.get("vector_store", {}) if isinstance(cog, dict) else {} - q = cog.get("query", {}) if isinstance(cog, dict) else {} - hist = cog.get("historian", {}) if isinstance(cog, dict) else {} - prof = cog.get("profile", {}) if isinstance(cog, dict) else {} - que = cog.get("queue", {}) if 
isinstance(cog, dict) else {} - return CognitiveConfig( - enabled=_coerce_bool( - cog.get("enabled") if isinstance(cog, dict) else None, True - ), - bot_name=_coerce_str( - cog.get("bot_name") if isinstance(cog, dict) else None, - "Undefined", - ), - vector_store_path=_coerce_str( - vs.get("path") if isinstance(vs, dict) else None, - "data/cognitive/chromadb", - ), - queue_path=_coerce_str( - que.get("path") if isinstance(que, dict) else None, - "data/cognitive/queues", - ), - profiles_path=_coerce_str( - prof.get("path") if isinstance(prof, dict) else None, - "data/cognitive/profiles", - ), - auto_top_k=_coerce_int( - q.get("auto_top_k") if isinstance(q, dict) else None, 3 - ), - auto_scope_candidate_multiplier=_coerce_int( - q.get("auto_scope_candidate_multiplier") - if isinstance(q, dict) - else None, - 2, - ), - auto_current_group_boost=_coerce_float( - q.get("auto_current_group_boost") if isinstance(q, dict) else None, - 1.15, - ), - auto_current_private_boost=_coerce_float( - q.get("auto_current_private_boost") if isinstance(q, dict) else None, - 1.25, - ), - enable_rerank=_coerce_bool( - q.get("enable_rerank") if isinstance(q, dict) else None, True - ), - recent_end_summaries_inject_k=_coerce_int( - q.get("recent_end_summaries_inject_k") if isinstance(q, dict) else None, - 30, - ), - time_decay_enabled=_coerce_bool( - q.get("time_decay_enabled") if isinstance(q, dict) else None, True - ), - time_decay_half_life_days_auto=_coerce_float( - q.get("time_decay_half_life_days_auto") - if isinstance(q, dict) - else None, - 14.0, - ), - time_decay_half_life_days_tool=_coerce_float( - q.get("time_decay_half_life_days_tool") - if isinstance(q, dict) - else None, - 60.0, - ), - time_decay_boost=_coerce_float( - q.get("time_decay_boost") if isinstance(q, dict) else None, 0.2 - ), - time_decay_min_similarity=_coerce_float( - q.get("time_decay_min_similarity") if isinstance(q, dict) else None, - 0.35, - ), - tool_default_top_k=_coerce_int( - q.get("tool_default_top_k") if 
isinstance(q, dict) else None, 12 - ), - profile_top_k=_coerce_int( - q.get("profile_top_k") if isinstance(q, dict) else None, 8 - ), - rerank_candidate_multiplier=_coerce_int( - q.get("rerank_candidate_multiplier") if isinstance(q, dict) else None, 3 - ), - rewrite_max_retry=_coerce_int( - hist.get("rewrite_max_retry") if isinstance(hist, dict) else None, 2 - ), - historian_recent_messages_inject_k=_coerce_int( - hist.get("recent_messages_inject_k") - if isinstance(hist, dict) - else None, - 12, - ), - historian_recent_message_line_max_len=_coerce_int( - hist.get("recent_message_line_max_len") - if isinstance(hist, dict) - else None, - 240, - ), - historian_source_message_max_len=_coerce_int( - hist.get("source_message_max_len") if isinstance(hist, dict) else None, - 800, - ), - poll_interval_seconds=_coerce_float( - hist.get("poll_interval_seconds") if isinstance(hist, dict) else None, - 1.0, - ), - stale_job_timeout_seconds=_coerce_float( - hist.get("stale_job_timeout_seconds") - if isinstance(hist, dict) - else None, - 300.0, - ), - profile_revision_keep=_coerce_int( - prof.get("revision_keep") if isinstance(prof, dict) else None, 5 - ), - failed_max_age_days=_coerce_int( - que.get("failed_max_age_days") if isinstance(que, dict) else None, 30 - ), - failed_max_files=_coerce_int( - que.get("failed_max_files") if isinstance(que, dict) else None, 500 - ), - failed_cleanup_interval=_coerce_int( - que.get("failed_cleanup_interval") if isinstance(que, dict) else None, - 100, - ), - job_max_retries=_coerce_int( - que.get("job_max_retries") if isinstance(que, dict) else None, 3 - ), - ) - - @staticmethod - def _parse_memes_config(data: dict[str, Any]) -> MemeConfig: - section_raw = data.get("memes", {}) - section = section_raw if isinstance(section_raw, dict) else {} - return MemeConfig( - enabled=_coerce_bool(section.get("enabled"), True), - query_default_mode=_coerce_str(section.get("query_default_mode"), "hybrid"), - max_source_image_bytes=max( - 1, - 
_coerce_int(section.get("max_source_image_bytes"), 500 * 1024), - ), - blob_dir=_coerce_str(section.get("blob_dir"), "data/memes/blobs"), - preview_dir=_coerce_str(section.get("preview_dir"), "data/memes/previews"), - db_path=_coerce_str(section.get("db_path"), "data/memes/memes.sqlite3"), - vector_store_path=_coerce_str( - section.get("vector_store_path"), "data/memes/chromadb" - ), - queue_path=_coerce_str(section.get("queue_path"), "data/memes/queues"), - max_items=max(1, _coerce_int(section.get("max_items"), 10000)), - max_total_bytes=max( - 1, - _coerce_int(section.get("max_total_bytes"), 5 * 1024 * 1024 * 1024), - ), - allow_gif=_coerce_bool(section.get("allow_gif"), True), - auto_ingest_group=_coerce_bool(section.get("auto_ingest_group"), True), - auto_ingest_private=_coerce_bool(section.get("auto_ingest_private"), True), - keyword_top_k=max(1, _coerce_int(section.get("keyword_top_k"), 30)), - semantic_top_k=max(1, _coerce_int(section.get("semantic_top_k"), 30)), - rerank_top_k=max(1, _coerce_int(section.get("rerank_top_k"), 20)), - worker_max_concurrency=max( - 1, _coerce_int(section.get("worker_max_concurrency"), 4) - ), - ) - - @staticmethod - def _parse_api_config(data: dict[str, Any]) -> APIConfig: - section_raw = data.get("api", {}) - section = section_raw if isinstance(section_raw, dict) else {} - - enabled = _coerce_bool(section.get("enabled"), True) - host = _coerce_str(section.get("host"), DEFAULT_API_HOST) - port = _coerce_int(section.get("port"), DEFAULT_API_PORT) - if port <= 0 or port > 65535: - port = DEFAULT_API_PORT - - auth_key = _coerce_str(section.get("auth_key"), DEFAULT_API_AUTH_KEY) - if not auth_key: - auth_key = DEFAULT_API_AUTH_KEY - - openapi_enabled = _coerce_bool(section.get("openapi_enabled"), True) - - tool_invoke_enabled = _coerce_bool(section.get("tool_invoke_enabled"), False) - tool_invoke_expose = _coerce_str( - section.get("tool_invoke_expose"), "tools+toolsets" - ) - valid_expose = {"tools", "toolsets", "tools+toolsets", 
"agents", "all"} - if tool_invoke_expose not in valid_expose: - tool_invoke_expose = "tools+toolsets" - tool_invoke_allowlist = _coerce_str_list(section.get("tool_invoke_allowlist")) - tool_invoke_denylist = _coerce_str_list(section.get("tool_invoke_denylist")) - tool_invoke_timeout = _coerce_int(section.get("tool_invoke_timeout"), 120) - if tool_invoke_timeout <= 0: - tool_invoke_timeout = 120 - tool_invoke_callback_timeout = _coerce_int( - section.get("tool_invoke_callback_timeout"), 10 - ) - if tool_invoke_callback_timeout <= 0: - tool_invoke_callback_timeout = 10 - - return APIConfig( - enabled=enabled, - host=host, - port=port, - auth_key=auth_key, - openapi_enabled=openapi_enabled, - tool_invoke_enabled=tool_invoke_enabled, - tool_invoke_expose=tool_invoke_expose, - tool_invoke_allowlist=tool_invoke_allowlist, - tool_invoke_denylist=tool_invoke_denylist, - tool_invoke_timeout=tool_invoke_timeout, - tool_invoke_callback_timeout=tool_invoke_callback_timeout, - ) - - @staticmethod - def _parse_naga_config(data: dict[str, Any]) -> NagaConfig: - section_raw = data.get("naga", {}) - section = section_raw if isinstance(section_raw, dict) else {} - - enabled = _coerce_bool(section.get("enabled"), False) - api_url = _coerce_str(section.get("api_url"), "") - api_key = _coerce_str(section.get("api_key"), "") - moderation_enabled = _coerce_bool(section.get("moderation_enabled"), True) - allowed_groups = frozenset(_coerce_int_list(section.get("allowed_groups"))) - - return NagaConfig( - enabled=enabled, - api_url=api_url, - api_key=api_key, - moderation_enabled=moderation_enabled, - allowed_groups=allowed_groups, - ) - - @staticmethod - def _parse_easter_egg_call_mode(value: Any) -> str: - """解析彩蛋提示模式。 - - 兼容旧版布尔值: - - True => agent - - False => none - """ - if isinstance(value, bool): - return "agent" if value else "none" - if isinstance(value, (int, float)): - return "agent" if bool(value) else "none" - if value is None: - return "none" - - text = 
str(value).strip().lower() - if text in {"true", "1", "yes", "on"}: - return "agent" - if text in {"false", "0", "no", "off"}: - return "none" - if text in {"none", "agent", "tools", "all", "clean"}: - return text - return "none" - def reload(self, strict: bool = False) -> dict[str, tuple[Any, Any]]: new_config = Config.load(strict=strict) return self.update_from(new_config) @@ -3346,20 +1594,3 @@ def is_superadmin(self, qq: int) -> bool: def is_admin(self, qq: int) -> bool: return qq in self.admin_qqs - - -def _update_dataclass( - old_value: Any, new_value: Any, prefix: str -) -> dict[str, tuple[Any, Any]]: - changes: dict[str, tuple[Any, Any]] = {} - if not isinstance(old_value, type(new_value)): - changes[prefix] = (old_value, new_value) - return changes - for field in fields(old_value): - name = field.name - old_attr = getattr(old_value, name) - new_attr = getattr(new_value, name) - if old_attr != new_attr: - setattr(old_value, name, new_attr) - changes[f"{prefix}.{name}"] = (old_attr, new_attr) - return changes diff --git a/src/Undefined/config/model_parsers.py b/src/Undefined/config/model_parsers.py new file mode 100644 index 0000000..e17fc85 --- /dev/null +++ b/src/Undefined/config/model_parsers.py @@ -0,0 +1,1250 @@ +"""Model configuration parsers extracted from Config class.""" + +from __future__ import annotations + +import logging +from typing import Any + +from Undefined.utils.request_params import merge_request_params + +from .admin import load_local_admins +from .coercers import ( + _coerce_bool, + _coerce_float, + _coerce_int, + _coerce_str, + _get_model_request_params, + _get_value, + _normalize_queue_interval, + _VALID_API_MODES, +) +from .models import ( + AgentModelConfig, + ChatModelConfig, + EmbeddingModelConfig, + GrokModelConfig, + ImageGenConfig, + ImageGenModelConfig, + ModelPool, + ModelPoolEntry, + RerankModelConfig, + SecurityModelConfig, + VisionModelConfig, +) +from .resolvers import ( + _resolve_api_mode, + _resolve_reasoning_effort, 
+ _resolve_reasoning_effort_style, + _resolve_responses_force_stateless_replay, + _resolve_responses_tool_choice_compat, + _resolve_thinking_compat_flags, +) + +logger = logging.getLogger(__name__) + + +def _parse_model_pool( + data: dict[str, Any], + model_section: str, + primary_config: ChatModelConfig | AgentModelConfig, +) -> ModelPool | None: + """解析模型池配置,缺省字段继承 primary_config""" + pool_data = data.get("models", {}).get(model_section, {}).get("pool") + if not isinstance(pool_data, dict): + return None + + enabled = _coerce_bool(pool_data.get("enabled"), False) + strategy = _coerce_str(pool_data.get("strategy"), "default").strip().lower() + if strategy not in ("default", "round_robin", "random"): + strategy = "default" + + raw_models = pool_data.get("models") + if not isinstance(raw_models, list): + return ModelPool(enabled=enabled, strategy=strategy) + + entries: list[ModelPoolEntry] = [] + for item in raw_models: + if not isinstance(item, dict): + continue + name = _coerce_str(item.get("model_name"), "").strip() + if not name: + continue + entries.append( + ModelPoolEntry( + api_url=_coerce_str(item.get("api_url"), primary_config.api_url), + api_key=_coerce_str(item.get("api_key"), primary_config.api_key), + model_name=name, + max_tokens=_coerce_int( + item.get("max_tokens"), primary_config.max_tokens + ), + queue_interval_seconds=_normalize_queue_interval( + _coerce_float( + item.get("queue_interval_seconds"), + primary_config.queue_interval_seconds, + ), + primary_config.queue_interval_seconds, + ), + api_mode=( + _coerce_str(item.get("api_mode"), primary_config.api_mode) + .strip() + .lower() + ) + if _coerce_str(item.get("api_mode"), primary_config.api_mode) + .strip() + .lower() + in _VALID_API_MODES + else primary_config.api_mode, + thinking_enabled=_coerce_bool( + item.get("thinking_enabled"), primary_config.thinking_enabled + ), + thinking_budget_tokens=_coerce_int( + item.get("thinking_budget_tokens"), + primary_config.thinking_budget_tokens, + ), + 
thinking_include_budget=_coerce_bool( + item.get("thinking_include_budget"), + primary_config.thinking_include_budget, + ), + reasoning_effort_style=_resolve_reasoning_effort_style( + item.get("reasoning_effort_style"), + primary_config.reasoning_effort_style, + ), + thinking_tool_call_compat=_coerce_bool( + item.get("thinking_tool_call_compat"), + primary_config.thinking_tool_call_compat, + ), + responses_tool_choice_compat=_coerce_bool( + item.get("responses_tool_choice_compat"), + primary_config.responses_tool_choice_compat, + ), + responses_force_stateless_replay=_coerce_bool( + item.get("responses_force_stateless_replay"), + primary_config.responses_force_stateless_replay, + ), + prompt_cache_enabled=_coerce_bool( + item.get("prompt_cache_enabled"), + primary_config.prompt_cache_enabled, + ), + reasoning_enabled=_coerce_bool( + item.get("reasoning_enabled"), + primary_config.reasoning_enabled, + ), + reasoning_effort=_resolve_reasoning_effort( + item.get("reasoning_effort"), + primary_config.reasoning_effort, + ), + request_params=merge_request_params( + primary_config.request_params, + item.get("request_params"), + ), + ) + ) + + return ModelPool(enabled=enabled, strategy=strategy, models=entries) + + +def _parse_embedding_model_config(data: dict[str, Any]) -> EmbeddingModelConfig: + return EmbeddingModelConfig( + api_url=_coerce_str( + _get_value( + data, ("models", "embedding", "api_url"), "EMBEDDING_MODEL_API_URL" + ), + "", + ), + api_key=_coerce_str( + _get_value( + data, ("models", "embedding", "api_key"), "EMBEDDING_MODEL_API_KEY" + ), + "", + ), + model_name=_coerce_str( + _get_value( + data, ("models", "embedding", "model_name"), "EMBEDDING_MODEL_NAME" + ), + "", + ), + queue_interval_seconds=_normalize_queue_interval( + _coerce_float( + _get_value( + data, ("models", "embedding", "queue_interval_seconds"), None + ), + 0.0, + ), + 0.0, + ), + dimensions=_coerce_int( + _get_value(data, ("models", "embedding", "dimensions"), None), 0 + ) + or None, + 
query_instruction=_coerce_str( + _get_value(data, ("models", "embedding", "query_instruction"), None), "" + ), + document_instruction=_coerce_str( + _get_value(data, ("models", "embedding", "document_instruction"), None), + "", + ), + request_params=_get_model_request_params(data, "embedding"), + ) + + +def _parse_rerank_model_config(data: dict[str, Any]) -> RerankModelConfig: + queue_interval_seconds = _normalize_queue_interval( + _coerce_float( + _get_value(data, ("models", "rerank", "queue_interval_seconds"), None), + 0.0, + ), + 0.0, + ) + return RerankModelConfig( + api_url=_coerce_str( + _get_value(data, ("models", "rerank", "api_url"), "RERANK_MODEL_API_URL"), + "", + ), + api_key=_coerce_str( + _get_value(data, ("models", "rerank", "api_key"), "RERANK_MODEL_API_KEY"), + "", + ), + model_name=_coerce_str( + _get_value(data, ("models", "rerank", "model_name"), "RERANK_MODEL_NAME"), + "", + ), + queue_interval_seconds=queue_interval_seconds, + query_instruction=_coerce_str( + _get_value(data, ("models", "rerank", "query_instruction"), None), "" + ), + request_params=_get_model_request_params(data, "rerank"), + ) + + +def _parse_chat_model_config(data: dict[str, Any]) -> ChatModelConfig: + queue_interval_seconds = _normalize_queue_interval( + _coerce_float( + _get_value( + data, + ("models", "chat", "queue_interval_seconds"), + "CHAT_MODEL_QUEUE_INTERVAL", + ), + 1.0, + ) + ) + thinking_include_budget, thinking_tool_call_compat = _resolve_thinking_compat_flags( + data=data, + model_name="chat", + include_budget_env_key="CHAT_MODEL_THINKING_INCLUDE_BUDGET", + tool_call_compat_env_key="CHAT_MODEL_THINKING_TOOL_CALL_COMPAT", + legacy_env_key="CHAT_MODEL_DEEPSEEK_NEW_COT_SUPPORT", + ) + api_mode = _resolve_api_mode(data, "chat", "CHAT_MODEL_API_MODE") + responses_tool_choice_compat = _resolve_responses_tool_choice_compat( + data, "chat", "CHAT_MODEL_RESPONSES_TOOL_CHOICE_COMPAT" + ) + responses_force_stateless_replay = _resolve_responses_force_stateless_replay( + 
data, "chat", "CHAT_MODEL_RESPONSES_FORCE_STATELESS_REPLAY" + ) + prompt_cache_enabled = _coerce_bool( + _get_value( + data, + ("models", "chat", "prompt_cache_enabled"), + "CHAT_MODEL_PROMPT_CACHE_ENABLED", + ), + True, + ) + reasoning_enabled = _coerce_bool( + _get_value( + data, + ("models", "chat", "reasoning_enabled"), + "CHAT_MODEL_REASONING_ENABLED", + ), + False, + ) + reasoning_effort = _resolve_reasoning_effort( + _get_value( + data, + ("models", "chat", "reasoning_effort"), + "CHAT_MODEL_REASONING_EFFORT", + ), + "medium", + ) + config = ChatModelConfig( + api_url=_coerce_str( + _get_value(data, ("models", "chat", "api_url"), "CHAT_MODEL_API_URL"), + "", + ), + api_key=_coerce_str( + _get_value(data, ("models", "chat", "api_key"), "CHAT_MODEL_API_KEY"), + "", + ), + model_name=_coerce_str( + _get_value(data, ("models", "chat", "model_name"), "CHAT_MODEL_NAME"), + "", + ), + max_tokens=_coerce_int( + _get_value(data, ("models", "chat", "max_tokens"), "CHAT_MODEL_MAX_TOKENS"), + 8192, + ), + queue_interval_seconds=queue_interval_seconds, + api_mode=api_mode, + thinking_enabled=_coerce_bool( + _get_value( + data, + ("models", "chat", "thinking_enabled"), + "CHAT_MODEL_THINKING_ENABLED", + ), + False, + ), + thinking_budget_tokens=_coerce_int( + _get_value( + data, + ("models", "chat", "thinking_budget_tokens"), + "CHAT_MODEL_THINKING_BUDGET_TOKENS", + ), + 20000, + ), + thinking_include_budget=thinking_include_budget, + reasoning_effort_style=_resolve_reasoning_effort_style( + _get_value( + data, + ("models", "chat", "reasoning_effort_style"), + "CHAT_MODEL_REASONING_EFFORT_STYLE", + ), + ), + thinking_tool_call_compat=thinking_tool_call_compat, + responses_tool_choice_compat=responses_tool_choice_compat, + responses_force_stateless_replay=responses_force_stateless_replay, + prompt_cache_enabled=prompt_cache_enabled, + reasoning_enabled=reasoning_enabled, + reasoning_effort=reasoning_effort, + request_params=_get_model_request_params(data, "chat"), + ) + 
config.pool = _parse_model_pool(data, "chat", config) + return config + + +def _parse_vision_model_config(data: dict[str, Any]) -> VisionModelConfig: + queue_interval_seconds = _normalize_queue_interval( + _coerce_float( + _get_value( + data, + ("models", "vision", "queue_interval_seconds"), + "VISION_MODEL_QUEUE_INTERVAL", + ), + 1.0, + ) + ) + thinking_include_budget, thinking_tool_call_compat = _resolve_thinking_compat_flags( + data=data, + model_name="vision", + include_budget_env_key="VISION_MODEL_THINKING_INCLUDE_BUDGET", + tool_call_compat_env_key="VISION_MODEL_THINKING_TOOL_CALL_COMPAT", + legacy_env_key="VISION_MODEL_DEEPSEEK_NEW_COT_SUPPORT", + ) + api_mode = _resolve_api_mode(data, "vision", "VISION_MODEL_API_MODE") + responses_tool_choice_compat = _resolve_responses_tool_choice_compat( + data, "vision", "VISION_MODEL_RESPONSES_TOOL_CHOICE_COMPAT" + ) + responses_force_stateless_replay = _resolve_responses_force_stateless_replay( + data, "vision", "VISION_MODEL_RESPONSES_FORCE_STATELESS_REPLAY" + ) + prompt_cache_enabled = _coerce_bool( + _get_value( + data, + ("models", "vision", "prompt_cache_enabled"), + "VISION_MODEL_PROMPT_CACHE_ENABLED", + ), + True, + ) + reasoning_enabled = _coerce_bool( + _get_value( + data, + ("models", "vision", "reasoning_enabled"), + "VISION_MODEL_REASONING_ENABLED", + ), + False, + ) + reasoning_effort = _resolve_reasoning_effort( + _get_value( + data, + ("models", "vision", "reasoning_effort"), + "VISION_MODEL_REASONING_EFFORT", + ), + "medium", + ) + return VisionModelConfig( + api_url=_coerce_str( + _get_value(data, ("models", "vision", "api_url"), "VISION_MODEL_API_URL"), + "", + ), + api_key=_coerce_str( + _get_value(data, ("models", "vision", "api_key"), "VISION_MODEL_API_KEY"), + "", + ), + model_name=_coerce_str( + _get_value(data, ("models", "vision", "model_name"), "VISION_MODEL_NAME"), + "", + ), + queue_interval_seconds=queue_interval_seconds, + api_mode=api_mode, + thinking_enabled=_coerce_bool( + _get_value( + 
data, + ("models", "vision", "thinking_enabled"), + "VISION_MODEL_THINKING_ENABLED", + ), + False, + ), + thinking_budget_tokens=_coerce_int( + _get_value( + data, + ("models", "vision", "thinking_budget_tokens"), + "VISION_MODEL_THINKING_BUDGET_TOKENS", + ), + 20000, + ), + thinking_include_budget=thinking_include_budget, + reasoning_effort_style=_resolve_reasoning_effort_style( + _get_value( + data, + ("models", "vision", "reasoning_effort_style"), + "VISION_MODEL_REASONING_EFFORT_STYLE", + ), + ), + thinking_tool_call_compat=thinking_tool_call_compat, + responses_tool_choice_compat=responses_tool_choice_compat, + responses_force_stateless_replay=responses_force_stateless_replay, + prompt_cache_enabled=prompt_cache_enabled, + reasoning_enabled=reasoning_enabled, + reasoning_effort=reasoning_effort, + request_params=_get_model_request_params(data, "vision"), + ) + + +def _parse_security_model_config( + data: dict[str, Any], chat_model: ChatModelConfig +) -> SecurityModelConfig: + api_url = _coerce_str( + _get_value(data, ("models", "security", "api_url"), "SECURITY_MODEL_API_URL"), + "", + ) + api_key = _coerce_str( + _get_value(data, ("models", "security", "api_key"), "SECURITY_MODEL_API_KEY"), + "", + ) + model_name = _coerce_str( + _get_value(data, ("models", "security", "model_name"), "SECURITY_MODEL_NAME"), + "", + ) + queue_interval_seconds = _coerce_float( + _get_value( + data, + ("models", "security", "queue_interval_seconds"), + "SECURITY_MODEL_QUEUE_INTERVAL", + ), + 1.0, + ) + queue_interval_seconds = _normalize_queue_interval(queue_interval_seconds) + + thinking_include_budget, thinking_tool_call_compat = _resolve_thinking_compat_flags( + data=data, + model_name="security", + include_budget_env_key="SECURITY_MODEL_THINKING_INCLUDE_BUDGET", + tool_call_compat_env_key="SECURITY_MODEL_THINKING_TOOL_CALL_COMPAT", + legacy_env_key="SECURITY_MODEL_DEEPSEEK_NEW_COT_SUPPORT", + ) + api_mode = _resolve_api_mode(data, "security", "SECURITY_MODEL_API_MODE") + 
responses_tool_choice_compat = _resolve_responses_tool_choice_compat( + data, "security", "SECURITY_MODEL_RESPONSES_TOOL_CHOICE_COMPAT" + ) + responses_force_stateless_replay = _resolve_responses_force_stateless_replay( + data, "security", "SECURITY_MODEL_RESPONSES_FORCE_STATELESS_REPLAY" + ) + prompt_cache_enabled = _coerce_bool( + _get_value( + data, + ("models", "security", "prompt_cache_enabled"), + "SECURITY_MODEL_PROMPT_CACHE_ENABLED", + ), + True, + ) + reasoning_enabled = _coerce_bool( + _get_value( + data, + ("models", "security", "reasoning_enabled"), + "SECURITY_MODEL_REASONING_ENABLED", + ), + False, + ) + reasoning_effort = _resolve_reasoning_effort( + _get_value( + data, + ("models", "security", "reasoning_effort"), + "SECURITY_MODEL_REASONING_EFFORT", + ), + "medium", + ) + + if api_url and api_key and model_name: + return SecurityModelConfig( + api_url=api_url, + api_key=api_key, + model_name=model_name, + max_tokens=_coerce_int( + _get_value( + data, + ("models", "security", "max_tokens"), + "SECURITY_MODEL_MAX_TOKENS", + ), + 100, + ), + queue_interval_seconds=queue_interval_seconds, + api_mode=api_mode, + thinking_enabled=_coerce_bool( + _get_value( + data, + ("models", "security", "thinking_enabled"), + "SECURITY_MODEL_THINKING_ENABLED", + ), + False, + ), + thinking_budget_tokens=_coerce_int( + _get_value( + data, + ("models", "security", "thinking_budget_tokens"), + "SECURITY_MODEL_THINKING_BUDGET_TOKENS", + ), + 0, + ), + thinking_include_budget=thinking_include_budget, + reasoning_effort_style=_resolve_reasoning_effort_style( + _get_value( + data, + ("models", "security", "reasoning_effort_style"), + "SECURITY_MODEL_REASONING_EFFORT_STYLE", + ), + ), + thinking_tool_call_compat=thinking_tool_call_compat, + responses_tool_choice_compat=responses_tool_choice_compat, + responses_force_stateless_replay=responses_force_stateless_replay, + prompt_cache_enabled=prompt_cache_enabled, + reasoning_enabled=reasoning_enabled, + 
reasoning_effort=reasoning_effort, + request_params=_get_model_request_params(data, "security"), + ) + + logger.warning("未配置安全模型,将使用对话模型作为后备") + return SecurityModelConfig( + api_url=chat_model.api_url, + api_key=chat_model.api_key, + model_name=chat_model.model_name, + max_tokens=chat_model.max_tokens, + queue_interval_seconds=chat_model.queue_interval_seconds, + api_mode=chat_model.api_mode, + thinking_enabled=False, + thinking_budget_tokens=0, + thinking_include_budget=True, + reasoning_effort_style="openai", + thinking_tool_call_compat=chat_model.thinking_tool_call_compat, + responses_tool_choice_compat=chat_model.responses_tool_choice_compat, + responses_force_stateless_replay=chat_model.responses_force_stateless_replay, + prompt_cache_enabled=chat_model.prompt_cache_enabled, + reasoning_enabled=chat_model.reasoning_enabled, + reasoning_effort=chat_model.reasoning_effort, + request_params=merge_request_params(chat_model.request_params), + ) + + +def _parse_naga_model_config( + data: dict[str, Any], security_model: SecurityModelConfig +) -> SecurityModelConfig: + api_url = _coerce_str( + _get_value(data, ("models", "naga", "api_url"), "NAGA_MODEL_API_URL"), + "", + ) + api_key = _coerce_str( + _get_value(data, ("models", "naga", "api_key"), "NAGA_MODEL_API_KEY"), + "", + ) + model_name = _coerce_str( + _get_value(data, ("models", "naga", "model_name"), "NAGA_MODEL_NAME"), + "", + ) + queue_interval_seconds = _coerce_float( + _get_value( + data, + ("models", "naga", "queue_interval_seconds"), + "NAGA_MODEL_QUEUE_INTERVAL", + ), + security_model.queue_interval_seconds, + ) + queue_interval_seconds = _normalize_queue_interval(queue_interval_seconds) + + thinking_include_budget, thinking_tool_call_compat = _resolve_thinking_compat_flags( + data=data, + model_name="naga", + include_budget_env_key="NAGA_MODEL_THINKING_INCLUDE_BUDGET", + tool_call_compat_env_key="NAGA_MODEL_THINKING_TOOL_CALL_COMPAT", + legacy_env_key="NAGA_MODEL_DEEPSEEK_NEW_COT_SUPPORT", + ) + 
api_mode = _resolve_api_mode(data, "naga", "NAGA_MODEL_API_MODE") + responses_tool_choice_compat = _resolve_responses_tool_choice_compat( + data, "naga", "NAGA_MODEL_RESPONSES_TOOL_CHOICE_COMPAT" + ) + responses_force_stateless_replay = _resolve_responses_force_stateless_replay( + data, "naga", "NAGA_MODEL_RESPONSES_FORCE_STATELESS_REPLAY" + ) + prompt_cache_enabled = _coerce_bool( + _get_value( + data, + ("models", "naga", "prompt_cache_enabled"), + "NAGA_MODEL_PROMPT_CACHE_ENABLED", + ), + getattr(security_model, "prompt_cache_enabled", True), + ) + reasoning_enabled = _coerce_bool( + _get_value( + data, + ("models", "naga", "reasoning_enabled"), + "NAGA_MODEL_REASONING_ENABLED", + ), + getattr(security_model, "reasoning_enabled", False), + ) + reasoning_effort = _resolve_reasoning_effort( + _get_value( + data, + ("models", "naga", "reasoning_effort"), + "NAGA_MODEL_REASONING_EFFORT", + ), + getattr(security_model, "reasoning_effort", "medium"), + ) + + if api_url and api_key and model_name: + return SecurityModelConfig( + api_url=api_url, + api_key=api_key, + model_name=model_name, + max_tokens=_coerce_int( + _get_value( + data, + ("models", "naga", "max_tokens"), + "NAGA_MODEL_MAX_TOKENS", + ), + 160, + ), + queue_interval_seconds=queue_interval_seconds, + api_mode=api_mode, + thinking_enabled=_coerce_bool( + _get_value( + data, + ("models", "naga", "thinking_enabled"), + "NAGA_MODEL_THINKING_ENABLED", + ), + False, + ), + thinking_budget_tokens=_coerce_int( + _get_value( + data, + ("models", "naga", "thinking_budget_tokens"), + "NAGA_MODEL_THINKING_BUDGET_TOKENS", + ), + 0, + ), + thinking_include_budget=thinking_include_budget, + reasoning_effort_style=_resolve_reasoning_effort_style( + _get_value( + data, + ("models", "naga", "reasoning_effort_style"), + "NAGA_MODEL_REASONING_EFFORT_STYLE", + ), + ), + thinking_tool_call_compat=thinking_tool_call_compat, + responses_tool_choice_compat=responses_tool_choice_compat, + 
responses_force_stateless_replay=responses_force_stateless_replay, + prompt_cache_enabled=prompt_cache_enabled, + reasoning_enabled=reasoning_enabled, + reasoning_effort=reasoning_effort, + request_params=_get_model_request_params(data, "naga"), + ) + + logger.info( + "未配置 Naga 审核模型,将使用已解析的安全模型配置作为后备(安全模型本身可能已回退)" + ) + return SecurityModelConfig( + api_url=security_model.api_url, + api_key=security_model.api_key, + model_name=security_model.model_name, + max_tokens=security_model.max_tokens, + queue_interval_seconds=security_model.queue_interval_seconds, + api_mode=security_model.api_mode, + thinking_enabled=security_model.thinking_enabled, + thinking_budget_tokens=security_model.thinking_budget_tokens, + thinking_include_budget=security_model.thinking_include_budget, + reasoning_effort_style=security_model.reasoning_effort_style, + thinking_tool_call_compat=security_model.thinking_tool_call_compat, + responses_tool_choice_compat=security_model.responses_tool_choice_compat, + responses_force_stateless_replay=security_model.responses_force_stateless_replay, + prompt_cache_enabled=security_model.prompt_cache_enabled, + reasoning_enabled=security_model.reasoning_enabled, + reasoning_effort=security_model.reasoning_effort, + request_params=merge_request_params(security_model.request_params), + ) + + +def _parse_agent_model_config(data: dict[str, Any]) -> AgentModelConfig: + queue_interval_seconds = _normalize_queue_interval( + _coerce_float( + _get_value( + data, + ("models", "agent", "queue_interval_seconds"), + "AGENT_MODEL_QUEUE_INTERVAL", + ), + 1.0, + ) + ) + thinking_include_budget, thinking_tool_call_compat = _resolve_thinking_compat_flags( + data=data, + model_name="agent", + include_budget_env_key="AGENT_MODEL_THINKING_INCLUDE_BUDGET", + tool_call_compat_env_key="AGENT_MODEL_THINKING_TOOL_CALL_COMPAT", + legacy_env_key="AGENT_MODEL_DEEPSEEK_NEW_COT_SUPPORT", + ) + api_mode = _resolve_api_mode(data, "agent", "AGENT_MODEL_API_MODE") + 
responses_tool_choice_compat = _resolve_responses_tool_choice_compat( + data, "agent", "AGENT_MODEL_RESPONSES_TOOL_CHOICE_COMPAT" + ) + responses_force_stateless_replay = _resolve_responses_force_stateless_replay( + data, "agent", "AGENT_MODEL_RESPONSES_FORCE_STATELESS_REPLAY" + ) + prompt_cache_enabled = _coerce_bool( + _get_value( + data, + ("models", "agent", "prompt_cache_enabled"), + "AGENT_MODEL_PROMPT_CACHE_ENABLED", + ), + True, + ) + reasoning_enabled = _coerce_bool( + _get_value( + data, + ("models", "agent", "reasoning_enabled"), + "AGENT_MODEL_REASONING_ENABLED", + ), + False, + ) + reasoning_effort = _resolve_reasoning_effort( + _get_value( + data, + ("models", "agent", "reasoning_effort"), + "AGENT_MODEL_REASONING_EFFORT", + ), + "medium", + ) + config = AgentModelConfig( + api_url=_coerce_str( + _get_value(data, ("models", "agent", "api_url"), "AGENT_MODEL_API_URL"), + "", + ), + api_key=_coerce_str( + _get_value(data, ("models", "agent", "api_key"), "AGENT_MODEL_API_KEY"), + "", + ), + model_name=_coerce_str( + _get_value(data, ("models", "agent", "model_name"), "AGENT_MODEL_NAME"), + "", + ), + max_tokens=_coerce_int( + _get_value( + data, ("models", "agent", "max_tokens"), "AGENT_MODEL_MAX_TOKENS" + ), + 4096, + ), + queue_interval_seconds=queue_interval_seconds, + api_mode=api_mode, + thinking_enabled=_coerce_bool( + _get_value( + data, + ("models", "agent", "thinking_enabled"), + "AGENT_MODEL_THINKING_ENABLED", + ), + False, + ), + thinking_budget_tokens=_coerce_int( + _get_value( + data, + ("models", "agent", "thinking_budget_tokens"), + "AGENT_MODEL_THINKING_BUDGET_TOKENS", + ), + 0, + ), + thinking_include_budget=thinking_include_budget, + reasoning_effort_style=_resolve_reasoning_effort_style( + _get_value( + data, + ("models", "agent", "reasoning_effort_style"), + "AGENT_MODEL_REASONING_EFFORT_STYLE", + ), + ), + thinking_tool_call_compat=thinking_tool_call_compat, + responses_tool_choice_compat=responses_tool_choice_compat, + 
responses_force_stateless_replay=responses_force_stateless_replay, + prompt_cache_enabled=prompt_cache_enabled, + reasoning_enabled=reasoning_enabled, + reasoning_effort=reasoning_effort, + request_params=_get_model_request_params(data, "agent"), + ) + config.pool = _parse_model_pool(data, "agent", config) + return config + + +def _parse_grok_model_config(data: dict[str, Any]) -> GrokModelConfig: + queue_interval_seconds = _normalize_queue_interval( + _coerce_float( + _get_value( + data, + ("models", "grok", "queue_interval_seconds"), + "GROK_MODEL_QUEUE_INTERVAL", + ), + 1.0, + ) + ) + return GrokModelConfig( + api_url=_coerce_str( + _get_value(data, ("models", "grok", "api_url"), "GROK_MODEL_API_URL"), + "", + ), + api_key=_coerce_str( + _get_value(data, ("models", "grok", "api_key"), "GROK_MODEL_API_KEY"), + "", + ), + model_name=_coerce_str( + _get_value(data, ("models", "grok", "model_name"), "GROK_MODEL_NAME"), + "", + ), + max_tokens=_coerce_int( + _get_value(data, ("models", "grok", "max_tokens"), "GROK_MODEL_MAX_TOKENS"), + 8192, + ), + queue_interval_seconds=queue_interval_seconds, + thinking_enabled=_coerce_bool( + _get_value( + data, + ("models", "grok", "thinking_enabled"), + "GROK_MODEL_THINKING_ENABLED", + ), + False, + ), + thinking_budget_tokens=_coerce_int( + _get_value( + data, + ("models", "grok", "thinking_budget_tokens"), + "GROK_MODEL_THINKING_BUDGET_TOKENS", + ), + 20000, + ), + thinking_include_budget=_coerce_bool( + _get_value( + data, + ("models", "grok", "thinking_include_budget"), + "GROK_MODEL_THINKING_INCLUDE_BUDGET", + ), + True, + ), + reasoning_effort_style=_resolve_reasoning_effort_style( + _get_value( + data, + ("models", "grok", "reasoning_effort_style"), + "GROK_MODEL_REASONING_EFFORT_STYLE", + ), + ), + prompt_cache_enabled=_coerce_bool( + _get_value( + data, + ("models", "grok", "prompt_cache_enabled"), + "GROK_MODEL_PROMPT_CACHE_ENABLED", + ), + True, + ), + reasoning_enabled=_coerce_bool( + _get_value( + data, + ("models", 
"grok", "reasoning_enabled"), + "GROK_MODEL_REASONING_ENABLED", + ), + False, + ), + reasoning_effort=_resolve_reasoning_effort( + _get_value( + data, + ("models", "grok", "reasoning_effort"), + "GROK_MODEL_REASONING_EFFORT", + ), + "medium", + ), + request_params=_get_model_request_params(data, "grok"), + ) + + +def _parse_image_gen_model_config(data: dict[str, Any]) -> ImageGenModelConfig: + """解析 [models.image_gen] 生图模型配置""" + return ImageGenModelConfig( + api_url=_coerce_str( + _get_value( + data, ("models", "image_gen", "api_url"), "IMAGE_GEN_MODEL_API_URL" + ), + "", + ), + api_key=_coerce_str( + _get_value( + data, ("models", "image_gen", "api_key"), "IMAGE_GEN_MODEL_API_KEY" + ), + "", + ), + model_name=_coerce_str( + _get_value( + data, ("models", "image_gen", "model_name"), "IMAGE_GEN_MODEL_NAME" + ), + "", + ), + request_params=_get_model_request_params(data, "image_gen"), + ) + + +def _parse_image_edit_model_config(data: dict[str, Any]) -> ImageGenModelConfig: + """解析 [models.image_edit] 参考图生图模型配置""" + return ImageGenModelConfig( + api_url=_coerce_str( + _get_value( + data, + ("models", "image_edit", "api_url"), + "IMAGE_EDIT_MODEL_API_URL", + ), + "", + ), + api_key=_coerce_str( + _get_value( + data, + ("models", "image_edit", "api_key"), + "IMAGE_EDIT_MODEL_API_KEY", + ), + "", + ), + model_name=_coerce_str( + _get_value( + data, + ("models", "image_edit", "model_name"), + "IMAGE_EDIT_MODEL_NAME", + ), + "", + ), + request_params=_get_model_request_params(data, "image_edit"), + ) + + +def _parse_image_gen_config(data: dict[str, Any]) -> ImageGenConfig: + """解析 [image_gen] 生图工具配置""" + return ImageGenConfig( + provider=_coerce_str( + _get_value(data, ("image_gen", "provider"), "IMAGE_GEN_PROVIDER"), + "xingzhige", + ), + xingzhige_size=_coerce_str( + _get_value(data, ("image_gen", "xingzhige_size"), None), "1:1" + ), + openai_size=_coerce_str( + _get_value(data, ("image_gen", "openai_size"), None), "" + ), + openai_quality=_coerce_str( + 
_get_value(data, ("image_gen", "openai_quality"), None), "" + ), + openai_style=_coerce_str( + _get_value(data, ("image_gen", "openai_style"), None), "" + ), + openai_timeout=_coerce_float( + _get_value(data, ("image_gen", "openai_timeout"), None), 120.0 + ), + ) + + +def _merge_admins(superadmin_qq: int, admin_qqs: list[int]) -> tuple[int, list[int]]: + local_admins = load_local_admins() + all_admins = list(set(admin_qqs + local_admins)) + if superadmin_qq and superadmin_qq not in all_admins: + all_admins.append(superadmin_qq) + return superadmin_qq, all_admins + + +def _verify_required_fields( + bot_qq: int, + superadmin_qq: int, + onebot_ws_url: str, + chat_model: ChatModelConfig, + vision_model: VisionModelConfig, + agent_model: AgentModelConfig, + knowledge_enabled: bool, + embedding_model: EmbeddingModelConfig, +) -> None: + missing: list[str] = [] + if bot_qq <= 0: + missing.append("core.bot_qq") + if superadmin_qq <= 0: + missing.append("core.superadmin_qq") + if not onebot_ws_url: + missing.append("onebot.ws_url") + if not chat_model.api_url: + missing.append("models.chat.api_url") + if not chat_model.api_key: + missing.append("models.chat.api_key") + if not chat_model.model_name: + missing.append("models.chat.model_name") + if not vision_model.api_url: + missing.append("models.vision.api_url") + if not vision_model.api_key: + missing.append("models.vision.api_key") + if not vision_model.model_name: + missing.append("models.vision.model_name") + if not agent_model.api_url: + missing.append("models.agent.api_url") + if not agent_model.api_key: + missing.append("models.agent.api_key") + if not agent_model.model_name: + missing.append("models.agent.model_name") + if knowledge_enabled: + if not embedding_model.api_url: + missing.append("models.embedding.api_url") + if not embedding_model.model_name: + missing.append("models.embedding.model_name") + if missing: + raise ValueError(f"缺少必需配置: {', '.join(missing)}") + + +def _log_debug_info( + chat_model: 
ChatModelConfig, + vision_model: VisionModelConfig, + security_model: SecurityModelConfig, + naga_model: SecurityModelConfig, + agent_model: AgentModelConfig, + summary_model: AgentModelConfig, + grok_model: GrokModelConfig, +) -> None: + configs: list[ + tuple[ + str, + ChatModelConfig + | VisionModelConfig + | SecurityModelConfig + | AgentModelConfig + | GrokModelConfig, + ] + ] = [ + ("chat", chat_model), + ("vision", vision_model), + ("security", security_model), + ("naga", naga_model), + ("agent", agent_model), + ("summary", summary_model), + ("grok", grok_model), + ] + for name, cfg in configs: + logger.debug( + "[配置] %s_model=%s api_url=%s api_key_set=%s api_mode=%s thinking=%s reasoning=%s/%s cot_compat=%s responses_tool_choice_compat=%s responses_force_stateless_replay=%s", + name, + cfg.model_name, + cfg.api_url, + bool(cfg.api_key), + getattr(cfg, "api_mode", "chat_completions"), + cfg.thinking_enabled, + getattr(cfg, "reasoning_enabled", False), + getattr(cfg, "reasoning_effort", "medium"), + getattr(cfg, "thinking_tool_call_compat", False), + getattr(cfg, "responses_tool_choice_compat", False), + getattr(cfg, "responses_force_stateless_replay", False), + ) + + +def _parse_historian_model_config( + data: dict[str, Any], fallback: AgentModelConfig +) -> AgentModelConfig: + h = data.get("models", {}).get("historian", {}) + if not isinstance(h, dict) or not h: + return fallback + queue_interval_seconds = _coerce_float( + h.get("queue_interval_seconds"), fallback.queue_interval_seconds + ) + queue_interval_seconds = _normalize_queue_interval( + queue_interval_seconds, fallback.queue_interval_seconds + ) + thinking_include_budget, thinking_tool_call_compat = _resolve_thinking_compat_flags( + data={"models": {"historian": h}}, + model_name="historian", + include_budget_env_key="HISTORIAN_MODEL_THINKING_INCLUDE_BUDGET", + tool_call_compat_env_key="HISTORIAN_MODEL_THINKING_TOOL_CALL_COMPAT", + legacy_env_key="HISTORIAN_MODEL_DEEPSEEK_NEW_COT_SUPPORT", + ) + 
api_mode = _resolve_api_mode( + {"models": {"historian": h}}, + "historian", + "HISTORIAN_MODEL_API_MODE", + fallback.api_mode, + ) + responses_tool_choice_compat = _resolve_responses_tool_choice_compat( + {"models": {"historian": h}}, + "historian", + "HISTORIAN_MODEL_RESPONSES_TOOL_CHOICE_COMPAT", + fallback.responses_tool_choice_compat, + ) + responses_force_stateless_replay = _resolve_responses_force_stateless_replay( + {"models": {"historian": h}}, + "historian", + "HISTORIAN_MODEL_RESPONSES_FORCE_STATELESS_REPLAY", + fallback.responses_force_stateless_replay, + ) + prompt_cache_enabled = _coerce_bool( + _get_value( + {"models": {"historian": h}}, + ("models", "historian", "prompt_cache_enabled"), + "HISTORIAN_MODEL_PROMPT_CACHE_ENABLED", + ), + fallback.prompt_cache_enabled, + ) + return AgentModelConfig( + api_url=_coerce_str(h.get("api_url"), fallback.api_url), + api_key=_coerce_str(h.get("api_key"), fallback.api_key), + model_name=_coerce_str(h.get("model_name"), fallback.model_name), + max_tokens=_coerce_int(h.get("max_tokens"), fallback.max_tokens), + queue_interval_seconds=queue_interval_seconds, + api_mode=api_mode, + thinking_enabled=_coerce_bool( + h.get("thinking_enabled"), fallback.thinking_enabled + ), + thinking_budget_tokens=_coerce_int( + h.get("thinking_budget_tokens"), fallback.thinking_budget_tokens + ), + thinking_include_budget=thinking_include_budget, + reasoning_effort_style=_resolve_reasoning_effort_style( + _get_value( + {"models": {"historian": h}}, + ("models", "historian", "reasoning_effort_style"), + "HISTORIAN_MODEL_REASONING_EFFORT_STYLE", + ), + fallback.reasoning_effort_style, + ), + thinking_tool_call_compat=thinking_tool_call_compat, + responses_tool_choice_compat=responses_tool_choice_compat, + responses_force_stateless_replay=responses_force_stateless_replay, + prompt_cache_enabled=prompt_cache_enabled, + reasoning_enabled=_coerce_bool( + _get_value( + {"models": {"historian": h}}, + ("models", "historian", 
"reasoning_enabled"), + "HISTORIAN_MODEL_REASONING_ENABLED", + ), + fallback.reasoning_enabled, + ), + reasoning_effort=_resolve_reasoning_effort( + _get_value( + {"models": {"historian": h}}, + ("models", "historian", "reasoning_effort"), + "HISTORIAN_MODEL_REASONING_EFFORT", + ), + fallback.reasoning_effort, + ), + request_params=merge_request_params( + fallback.request_params, + h.get("request_params"), + ), + ) + + +def _parse_summary_model_config( + data: dict[str, Any], fallback: AgentModelConfig +) -> tuple[AgentModelConfig, bool]: + s = data.get("models", {}).get("summary", {}) + if not isinstance(s, dict) or not s: + return fallback, False + queue_interval_seconds = _coerce_float( + s.get("queue_interval_seconds"), fallback.queue_interval_seconds + ) + queue_interval_seconds = _normalize_queue_interval( + queue_interval_seconds, fallback.queue_interval_seconds + ) + thinking_include_budget, thinking_tool_call_compat = _resolve_thinking_compat_flags( + data={"models": {"summary": s}}, + model_name="summary", + include_budget_env_key="SUMMARY_MODEL_THINKING_INCLUDE_BUDGET", + tool_call_compat_env_key="SUMMARY_MODEL_THINKING_TOOL_CALL_COMPAT", + legacy_env_key="SUMMARY_MODEL_DEEPSEEK_NEW_COT_SUPPORT", + ) + api_mode = _resolve_api_mode( + {"models": {"summary": s}}, + "summary", + "SUMMARY_MODEL_API_MODE", + fallback.api_mode, + ) + responses_tool_choice_compat = _resolve_responses_tool_choice_compat( + {"models": {"summary": s}}, + "summary", + "SUMMARY_MODEL_RESPONSES_TOOL_CHOICE_COMPAT", + fallback.responses_tool_choice_compat, + ) + responses_force_stateless_replay = _resolve_responses_force_stateless_replay( + {"models": {"summary": s}}, + "summary", + "SUMMARY_MODEL_RESPONSES_FORCE_STATELESS_REPLAY", + fallback.responses_force_stateless_replay, + ) + prompt_cache_enabled = _coerce_bool( + _get_value( + {"models": {"summary": s}}, + ("models", "summary", "prompt_cache_enabled"), + "SUMMARY_MODEL_PROMPT_CACHE_ENABLED", + ), + 
fallback.prompt_cache_enabled, + ) + return ( + AgentModelConfig( + api_url=_coerce_str(s.get("api_url"), fallback.api_url), + api_key=_coerce_str(s.get("api_key"), fallback.api_key), + model_name=_coerce_str(s.get("model_name"), fallback.model_name), + max_tokens=_coerce_int(s.get("max_tokens"), fallback.max_tokens), + queue_interval_seconds=queue_interval_seconds, + api_mode=api_mode, + thinking_enabled=_coerce_bool( + s.get("thinking_enabled"), fallback.thinking_enabled + ), + thinking_budget_tokens=_coerce_int( + s.get("thinking_budget_tokens"), fallback.thinking_budget_tokens + ), + thinking_include_budget=thinking_include_budget, + reasoning_effort_style=_resolve_reasoning_effort_style( + _get_value( + {"models": {"summary": s}}, + ("models", "summary", "reasoning_effort_style"), + "SUMMARY_MODEL_REASONING_EFFORT_STYLE", + ), + fallback.reasoning_effort_style, + ), + thinking_tool_call_compat=thinking_tool_call_compat, + responses_tool_choice_compat=responses_tool_choice_compat, + responses_force_stateless_replay=responses_force_stateless_replay, + prompt_cache_enabled=prompt_cache_enabled, + reasoning_enabled=_coerce_bool( + _get_value( + {"models": {"summary": s}}, + ("models", "summary", "reasoning_enabled"), + "SUMMARY_MODEL_REASONING_ENABLED", + ), + fallback.reasoning_enabled, + ), + reasoning_effort=_resolve_reasoning_effort( + _get_value( + {"models": {"summary": s}}, + ("models", "summary", "reasoning_effort"), + "SUMMARY_MODEL_REASONING_EFFORT", + ), + fallback.reasoning_effort, + ), + request_params=merge_request_params( + fallback.request_params, + s.get("request_params"), + ), + ), + True, + ) diff --git a/src/Undefined/config/resolvers.py b/src/Undefined/config/resolvers.py new file mode 100644 index 0000000..43533c3 --- /dev/null +++ b/src/Undefined/config/resolvers.py @@ -0,0 +1,106 @@ +"""Configuration value resolution helpers.""" + +from __future__ import annotations + +from typing import Any + +from .coercers import ( + _coerce_bool, + 
_coerce_str, + _get_value, + _VALID_API_MODES, + _VALID_REASONING_EFFORT_STYLES, +) + + +def _resolve_reasoning_effort_style(value: Any, default: str = "openai") -> str: + style = _coerce_str(value, default).strip().lower() + if style not in _VALID_REASONING_EFFORT_STYLES: + return default + return style + + +def _resolve_thinking_compat_flags( + data: dict[str, Any], + model_name: str, + include_budget_env_key: str, + tool_call_compat_env_key: str, + legacy_env_key: str, +) -> tuple[bool, bool]: + """解析思维链兼容配置,并兼容旧字段 deepseek_new_cot_support。""" + include_budget_value = _get_value( + data, + ("models", model_name, "thinking_include_budget"), + include_budget_env_key, + ) + tool_call_compat_value = _get_value( + data, + ("models", model_name, "thinking_tool_call_compat"), + tool_call_compat_env_key, + ) + legacy_value = _get_value( + data, + ("models", model_name, "deepseek_new_cot_support"), + legacy_env_key, + ) + + include_budget_default = True + tool_call_compat_default = True + if legacy_value is not None: + legacy_enabled = _coerce_bool(legacy_value, False) + include_budget_default = not legacy_enabled + tool_call_compat_default = legacy_enabled + + return ( + _coerce_bool(include_budget_value, include_budget_default), + _coerce_bool(tool_call_compat_value, tool_call_compat_default), + ) + + +def _resolve_api_mode( + data: dict[str, Any], + model_name: str, + env_key: str, + default: str = "chat_completions", +) -> str: + raw_value = _get_value(data, ("models", model_name, "api_mode"), env_key) + value = _coerce_str(raw_value, default).strip().lower() + if value not in _VALID_API_MODES: + return default + return value + + +def _resolve_reasoning_effort(value: Any, default: str = "medium") -> str: + return _coerce_str(value, default).strip().lower() + + +def _resolve_responses_tool_choice_compat( + data: dict[str, Any], + model_name: str, + env_key: str, + default: bool = False, +) -> bool: + return _coerce_bool( + _get_value( + data, + ("models", model_name, 
"responses_tool_choice_compat"), + env_key, + ), + default, + ) + + +def _resolve_responses_force_stateless_replay( + data: dict[str, Any], + model_name: str, + env_key: str, + default: bool = False, +) -> bool: + return _coerce_bool( + _get_value( + data, + ("models", model_name, "responses_force_stateless_replay"), + env_key, + ), + default, + ) diff --git a/src/Undefined/config/webui_settings.py b/src/Undefined/config/webui_settings.py new file mode 100644 index 0000000..45a5392 --- /dev/null +++ b/src/Undefined/config/webui_settings.py @@ -0,0 +1,61 @@ +"""WebUI settings management.""" + +from __future__ import annotations + +from dataclasses import dataclass +from pathlib import Path +from typing import Optional + +from .coercers import _coerce_int, _coerce_str, _normalize_str, _get_value + +DEFAULT_WEBUI_URL = "127.0.0.1" +DEFAULT_WEBUI_PORT = 8787 +DEFAULT_WEBUI_PASSWORD = "changeme" + + +@dataclass +class WebUISettings: + url: str + port: int + password: str + using_default_password: bool + config_exists: bool + + @property + def display_url(self) -> str: + """用于日志和展示的格式化 URL。""" + from Undefined.config.models import format_netloc + + return f"http://{format_netloc(self.url or '0.0.0.0', self.port)}" + + +def load_webui_settings(config_path: Optional[Path] = None) -> WebUISettings: + from .loader import load_toml_data # lazy to avoid circular + + data = load_toml_data(config_path) + config_exists = bool(data) + url_value = _get_value(data, ("webui", "url"), None) + port_value = _get_value(data, ("webui", "port"), None) + password_value = _get_value(data, ("webui", "password"), None) + + url = _coerce_str(url_value, DEFAULT_WEBUI_URL) + port = _coerce_int(port_value, DEFAULT_WEBUI_PORT) + if port <= 0 or port > 65535: + port = DEFAULT_WEBUI_PORT + + password_normalized = _normalize_str(password_value) + if not password_normalized: + return WebUISettings( + url=url, + port=port, + password=DEFAULT_WEBUI_PASSWORD, + using_default_password=True, + 
config_exists=config_exists, + ) + return WebUISettings( + url=url, + port=port, + password=password_normalized, + using_default_password=False, + config_exists=config_exists, + ) diff --git a/tests/test_config_cognitive_historian_limits.py b/tests/test_config_cognitive_historian_limits.py index c7f843f..bb58f56 100644 --- a/tests/test_config_cognitive_historian_limits.py +++ b/tests/test_config_cognitive_historian_limits.py @@ -1,10 +1,10 @@ from __future__ import annotations -from Undefined.config.loader import Config +from Undefined.config.domain_parsers import _parse_cognitive_config def test_parse_cognitive_historian_reference_limits() -> None: - cfg = Config._parse_cognitive_config( + cfg = _parse_cognitive_config( { "cognitive": { "query": { @@ -32,7 +32,7 @@ def test_parse_cognitive_historian_reference_limits() -> None: def test_parse_cognitive_historian_reference_limits_defaults() -> None: - cfg = Config._parse_cognitive_config({"cognitive": {}}) + cfg = _parse_cognitive_config({"cognitive": {}}) assert cfg.historian_recent_messages_inject_k == 12 assert cfg.historian_recent_message_line_max_len == 240 From 30b9a26dd6846b9900405fdad50b9275ebe26f10 Mon Sep 17 00:00:00 2001 From: Null <1708213363@qq.com> Date: Sun, 19 Apr 2026 12:09:43 +0800 Subject: [PATCH 31/57] refactor(api): extract helpers, probes, and OpenAPI from app.py Split 3077-line monolith into focused modules: - _helpers.py: utility classes and functions (~345 lines) - _probes.py: HTTP/WS endpoint health probes (~145 lines) - _openapi.py: OpenAPI spec builder (~180 lines) app.py retains RuntimeAPIContext + RuntimeAPIServer (~2500 lines). TYPE_CHECKING guard prevents circular imports for _openapi. All 1429 tests pass, mypy strict clean. 
Co-authored-by: Claude Opus 4.6 Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- src/Undefined/api/_helpers.py | 333 +++++++++++++ src/Undefined/api/_openapi.py | 183 +++++++ src/Undefined/api/_probes.py | 147 ++++++ src/Undefined/api/app.py | 652 ++----------------------- tests/test_runtime_api_sender_proxy.py | 2 +- tests/test_runtime_api_tool_invoke.py | 2 +- tests/test_webui_management_api.py | 10 +- 7 files changed, 703 insertions(+), 626 deletions(-) create mode 100644 src/Undefined/api/_helpers.py create mode 100644 src/Undefined/api/_openapi.py create mode 100644 src/Undefined/api/_probes.py diff --git a/src/Undefined/api/_helpers.py b/src/Undefined/api/_helpers.py new file mode 100644 index 0000000..40dcfb0 --- /dev/null +++ b/src/Undefined/api/_helpers.py @@ -0,0 +1,333 @@ +"""Helper classes and utility functions for the Runtime API.""" + +from __future__ import annotations + +import asyncio +import hashlib +import json +import logging +from contextlib import suppress +from dataclasses import dataclass +from datetime import datetime +from typing import Any, Awaitable, Callable +from urllib.parse import urlsplit + +from aiohttp import web +from aiohttp.web_response import Response + +from Undefined.config import load_webui_settings +from Undefined.utils.cors import is_allowed_cors_origin, normalize_origin + +logger = logging.getLogger(__name__) + +_AUTH_HEADER = "X-Undefined-API-Key" +_VIRTUAL_USER_ID = 42 + + +class _ToolInvokeExecutionTimeoutError(asyncio.TimeoutError): + """由 Runtime API 工具调用超时包装器抛出的超时异常。""" + + +@dataclass +class _NagaRequestResult: + payload_hash: str + status: int + payload: dict[str, Any] + finished_at: float + + +class _WebUIVirtualSender: + """将工具发送行为重定向到 WebUI 会话,不触发 OneBot 实际发送。""" + + def __init__( + self, + virtual_user_id: int, + send_private_callback: Callable[[int, str], Awaitable[None]], + onebot: Any = None, + ) -> None: + self._virtual_user_id = virtual_user_id + self._send_private_callback = 
send_private_callback + # 保留 onebot 属性,兼容依赖 sender.onebot 的工具读取能力。 + self.onebot = onebot + + async def send_private_message( + self, + user_id: int, + message: str, + auto_history: bool = True, + *, + mark_sent: bool = True, + reply_to: int | None = None, + preferred_temp_group_id: int | None = None, + history_message: str | None = None, + ) -> int | None: + _ = ( + user_id, + auto_history, + mark_sent, + reply_to, + preferred_temp_group_id, + history_message, + ) + await self._send_private_callback(self._virtual_user_id, message) + return None + + async def send_group_message( + self, + group_id: int, + message: str, + auto_history: bool = True, + history_prefix: str = "", + *, + mark_sent: bool = True, + reply_to: int | None = None, + history_message: str | None = None, + ) -> int | None: + _ = ( + group_id, + auto_history, + history_prefix, + mark_sent, + reply_to, + history_message, + ) + await self._send_private_callback(self._virtual_user_id, message) + return None + + async def send_private_file( + self, + user_id: int, + file_path: str, + name: str | None = None, + auto_history: bool = True, + ) -> None: + """将文件拷贝到 WebUI 缓存并发送文件卡片消息。""" + _ = user_id, auto_history + import shutil + import uuid as _uuid + from pathlib import Path as _Path + + from Undefined.utils.paths import WEBUI_FILE_CACHE_DIR, ensure_dir + + src = _Path(file_path) + display_name = name or src.name + file_id = _uuid.uuid4().hex + dest_dir = ensure_dir(WEBUI_FILE_CACHE_DIR / file_id) + dest = dest_dir / display_name + + def _copy_and_stat() -> int: + shutil.copy2(src, dest) + return dest.stat().st_size + + try: + file_size = await asyncio.to_thread(_copy_and_stat) + except OSError: + file_size = 0 + + message = f"[CQ:file,id={file_id},name={display_name},size={file_size}]" + await self._send_private_callback(self._virtual_user_id, message) + + async def send_group_file( + self, + group_id: int, + file_path: str, + name: str | None = None, + auto_history: bool = True, + ) -> None: + 
"""群文件在虚拟会话中同样重定向为文本消息。""" + await self.send_private_file(group_id, file_path, name, auto_history) + + +def _json_error(message: str, status: int = 400) -> Response: + return web.json_response({"error": message}, status=status) + + +def _apply_cors_headers(request: web.Request, response: web.StreamResponse) -> None: + origin = normalize_origin(str(request.headers.get("Origin") or "")) + settings = load_webui_settings() + response.headers.setdefault("Vary", "Origin") + response.headers.setdefault( + "Access-Control-Allow-Methods", "GET,POST,PATCH,DELETE,OPTIONS" + ) + response.headers.setdefault( + "Access-Control-Allow-Headers", + "Authorization, Content-Type, X-Undefined-API-Key", + ) + response.headers.setdefault("Access-Control-Max-Age", "86400") + if origin and is_allowed_cors_origin( + origin, + configured_host=str(settings.url or ""), + configured_port=settings.port, + ): + response.headers.setdefault("Access-Control-Allow-Origin", origin) + response.headers.setdefault("Access-Control-Allow-Credentials", "true") + + +def _optional_query_param(request: web.Request, key: str) -> str | None: + raw = request.query.get(key) + if raw is None: + return None + text = str(raw).strip() + if not text: + return None + return text + + +def _parse_query_time(value: str | None) -> datetime | None: + text = str(value or "").strip() + if not text: + return None + candidates = [text, text.replace("Z", "+00:00")] + if "T" in text: + candidates.append(text.replace("T", " ")) + for candidate in candidates: + with suppress(ValueError): + return datetime.fromisoformat(candidate) + return None + + +def _to_bool(value: Any) -> bool: + if isinstance(value, bool): + return value + if isinstance(value, (int, float)): + return value != 0 + text = str(value or "").strip().lower() + return text in {"1", "true", "yes", "on"} + + +def _build_chat_response_payload(mode: str, outputs: list[str]) -> dict[str, Any]: + return { + "mode": mode, + "virtual_user_id": _VIRTUAL_USER_ID, + 
"permission": "superadmin", + "messages": outputs, + "reply": "\n\n".join(outputs).strip(), + } + + +def _sse_event(event: str, payload: dict[str, Any]) -> bytes: + data = json.dumps(payload, ensure_ascii=False) + return f"event: {event}\ndata: {data}\n\n".encode("utf-8") + + +def _mask_url(url: str) -> str: + """保留 scheme + host,隐藏 path 细节。""" + text = str(url or "").strip().rstrip("/") + if not text: + return "" + parsed = urlsplit(text) + host = parsed.hostname or "" + port_part = f":{parsed.port}" if parsed.port else "" + scheme = parsed.scheme or "https" + return f"{scheme}://{host}{port_part}/..." + + +def _naga_runtime_enabled(cfg: Any) -> bool: + naga_cfg = getattr(cfg, "naga", None) + return bool(getattr(cfg, "nagaagent_mode_enabled", False)) and bool( + getattr(naga_cfg, "enabled", False) + ) + + +def _naga_routes_enabled(cfg: Any, naga_store: Any) -> bool: + return _naga_runtime_enabled(cfg) and naga_store is not None + + +def _short_text_preview(text: str, limit: int = 80) -> str: + compact = " ".join(str(text or "").split()) + if len(compact) <= limit: + return compact + return compact[:limit] + "..." 
+ + +def _naga_message_digest( + *, + bind_uuid: str, + naga_id: str, + target_qq: int, + target_group: int, + mode: str, + message_format: str, + content: str, +) -> str: + raw = json.dumps( + { + "bind_uuid": bind_uuid, + "naga_id": naga_id, + "target_qq": target_qq, + "target_group": target_group, + "mode": mode, + "format": message_format, + "content": content, + }, + ensure_ascii=False, + sort_keys=True, + separators=(",", ":"), + ) + return hashlib.sha1(raw.encode("utf-8")).hexdigest()[:16] + + +def _parse_response_payload(response: Response) -> dict[str, Any]: + text = response.text or "" + if not text: + return {} + payload = json.loads(text) + return payload if isinstance(payload, dict) else {"data": payload} + + +def _registry_summary(registry: Any) -> dict[str, Any]: + """从 BaseRegistry 提取轻量摘要。""" + if registry is None: + return {"count": 0, "loaded": 0, "items": []} + items: dict[str, Any] = getattr(registry, "_items", {}) + stats: dict[str, Any] = {} + get_stats = getattr(registry, "get_stats", None) + if callable(get_stats): + stats = get_stats() + summary_items: list[dict[str, Any]] = [] + for name, item in items.items(): + st = stats.get(name) + entry: dict[str, Any] = { + "name": name, + "loaded": getattr(item, "loaded", False), + } + if st is not None: + entry["calls"] = getattr(st, "count", 0) + entry["success"] = getattr(st, "success", 0) + entry["failure"] = getattr(st, "failure", 0) + summary_items.append(entry) + return { + "count": len(items), + "loaded": sum(1 for i in items.values() if getattr(i, "loaded", False)), + "items": summary_items, + } + + +def _validate_callback_url(url: str) -> str | None: + """校验回调 URL,返回错误信息或 None 表示通过。 + + 拒绝非 HTTP(S) scheme,以及直接使用私有/回环 IP 字面量的 URL 以防止 SSRF。 + 域名形式的 URL 放行(DNS 解析阶段不适合在校验函数中做阻塞调用)。 + """ + import ipaddress + + parsed = urlsplit(url) + scheme = (parsed.scheme or "").lower() + + if scheme not in ("http", "https"): + return "callback.url must use http or https" + + hostname = parsed.hostname or 
"" + if not hostname: + return "callback.url must include a hostname" + + # 仅检查 IP 字面量(如 http://127.0.0.1/、http://[::1]/、http://10.0.0.1/) + try: + addr = ipaddress.ip_address(hostname) + except ValueError: + pass # 域名形式,放行 + else: + if addr.is_private or addr.is_loopback or addr.is_link_local: + return "callback.url must not point to a private/loopback address" + + return None diff --git a/src/Undefined/api/_openapi.py b/src/Undefined/api/_openapi.py new file mode 100644 index 0000000..6c4566c --- /dev/null +++ b/src/Undefined/api/_openapi.py @@ -0,0 +1,183 @@ +"""OpenAPI / Swagger specification builder.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +from aiohttp import web + +from Undefined import __version__ +from ._helpers import _AUTH_HEADER, _naga_routes_enabled + +if TYPE_CHECKING: + from .app import RuntimeAPIContext + + +def _build_openapi_spec(ctx: RuntimeAPIContext, request: web.Request) -> dict[str, Any]: + server_url = f"{request.scheme}://{request.host}" + cfg = ctx.config_getter() + naga_routes_enabled = _naga_routes_enabled(cfg, ctx.naga_store) + paths: dict[str, Any] = { + "/health": { + "get": { + "summary": "Health check", + "security": [], + "responses": {"200": {"description": "OK"}}, + } + }, + "/openapi.json": { + "get": { + "summary": "OpenAPI schema", + "security": [], + "responses": {"200": {"description": "Schema JSON"}}, + } + }, + "/api/v1/probes/internal": { + "get": { + "summary": "Internal runtime probes", + "description": ( + "Returns system info (version, Python, platform, uptime), " + "OneBot connection status, request queue snapshot, " + "memory count, cognitive service status, API config, " + "skill statistics (tools/agents/anthropic_skills with call counts), " + "and model configuration (names, masked URLs, thinking flags)." 
+ ), + } + }, + "/api/v1/probes/external": { + "get": { + "summary": "External dependency probes", + "description": ( + "Concurrently probes all configured model API endpoints " + "(chat, vision, security, naga, agent, embedding, rerank) " + "and OneBot WebSocket. Each result includes status, " + "model name, masked URL, HTTP status code, and latency." + ), + } + }, + "/api/v1/memory": { + "get": {"summary": "List/search manual memories"}, + "post": {"summary": "Create a manual memory"}, + }, + "/api/v1/memory/{uuid}": { + "patch": {"summary": "Update a manual memory by UUID"}, + "delete": {"summary": "Delete a manual memory by UUID"}, + }, + "/api/v1/memes": {"get": {"summary": "List/search meme library items"}}, + "/api/v1/memes/stats": {"get": {"summary": "Get meme library stats"}}, + "/api/v1/memes/{uid}": { + "get": {"summary": "Get a meme by uid"}, + "patch": {"summary": "Update a meme by uid"}, + "delete": {"summary": "Delete a meme by uid"}, + }, + "/api/v1/memes/{uid}/blob": {"get": {"summary": "Get meme blob file"}}, + "/api/v1/memes/{uid}/preview": {"get": {"summary": "Get meme preview file"}}, + "/api/v1/memes/{uid}/reanalyze": { + "post": {"summary": "Queue a meme reanalyze job"} + }, + "/api/v1/memes/{uid}/reindex": { + "post": {"summary": "Queue a meme reindex job"} + }, + "/api/v1/cognitive/events": { + "get": {"summary": "Search cognitive event memories"} + }, + "/api/v1/cognitive/profiles": {"get": {"summary": "Search cognitive profiles"}}, + "/api/v1/cognitive/profile/{entity_type}/{entity_id}": { + "get": {"summary": "Get a profile by entity type/id"} + }, + "/api/v1/chat": { + "post": { + "summary": "WebUI special private chat", + "description": ( + "POST JSON {message, stream?}. " + "When stream=true, response is SSE with keep-alive comments." 
+ ), + } + }, + "/api/v1/chat/history": { + "get": {"summary": "Get virtual private chat history for WebUI"} + }, + "/api/v1/tools": { + "get": { + "summary": "List available tools", + "description": ( + "Returns currently available tools filtered by " + "tool_invoke_expose / allowlist / denylist config. " + "Each item follows the OpenAI function calling schema." + ), + } + }, + "/api/v1/tools/invoke": { + "post": { + "summary": "Invoke a tool", + "description": ( + "Execute a specific tool by name. Supports synchronous " + "response and optional async webhook callback." + ), + } + }, + } + + if naga_routes_enabled: + paths.update( + { + "/api/v1/naga/bind/callback": { + "post": { + "summary": "Finalize a Naga bind request", + "description": ( + "Internal callback used by Naga to approve or reject " + "a bind_uuid." + ), + "security": [{"BearerAuth": []}], + } + }, + "/api/v1/naga/messages/send": { + "post": { + "summary": "Send a Naga-authenticated message", + "description": ( + "Validates bind_uuid + delivery_signature, runs " + "moderation, then delivers the message. " + "Caller may provide uuid for idempotent retry deduplication." + ), + "security": [{"BearerAuth": []}], + } + }, + "/api/v1/naga/unbind": { + "post": { + "summary": "Revoke an active Naga binding", + "description": ( + "Allows Naga to proactively revoke a binding using " + "Authorization: Bearer ." 
+ ), + "security": [{"BearerAuth": []}], + } + }, + } + ) + + return { + "openapi": "3.0.3", + "info": { + "title": "Undefined Runtime API", + "version": __version__, + "description": "API exposed by the main Undefined process.", + }, + "servers": [ + { + "url": server_url, + "description": "Runtime endpoint", + } + ], + "components": { + "securitySchemes": { + "ApiKeyAuth": { + "type": "apiKey", + "in": "header", + "name": _AUTH_HEADER, + }, + "BearerAuth": {"type": "http", "scheme": "bearer"}, + } + }, + "security": [{"ApiKeyAuth": []}], + "paths": paths, + } diff --git a/src/Undefined/api/_probes.py b/src/Undefined/api/_probes.py new file mode 100644 index 0000000..fc51bb3 --- /dev/null +++ b/src/Undefined/api/_probes.py @@ -0,0 +1,147 @@ +"""HTTP/WebSocket endpoint probe functions for health checks.""" + +from __future__ import annotations + +import asyncio +import logging +import socket +import time +from typing import Any +from urllib.parse import urlsplit + +from aiohttp import ClientSession, ClientTimeout + +from ._helpers import _mask_url + +logger = logging.getLogger(__name__) + + +async def _probe_http_endpoint( + *, + name: str, + base_url: str, + api_key: str, + model_name: str = "", + timeout_seconds: float = 5.0, +) -> dict[str, Any]: + normalized = str(base_url or "").strip().rstrip("/") + if not normalized: + return { + "name": name, + "status": "skipped", + "reason": "empty_url", + "model_name": model_name, + } + + headers: dict[str, str] = {} + if api_key: + headers["Authorization"] = f"Bearer {api_key}" + + candidates = [normalized, f"{normalized}/models"] + last_error = "" + for url in candidates: + start = time.perf_counter() + try: + timeout = ClientTimeout(total=timeout_seconds) + async with ClientSession(timeout=timeout) as session: + async with session.get(url, headers=headers) as resp: + elapsed_ms = round((time.perf_counter() - start) * 1000, 2) + return { + "name": name, + "status": "ok", + "url": _mask_url(url), + "http_status": 
resp.status, + "latency_ms": elapsed_ms, + "model_name": model_name, + } + except Exception as exc: + last_error = str(exc) + continue + + return { + "name": name, + "status": "error", + "url": _mask_url(normalized), + "error": last_error or "request_failed", + "model_name": model_name, + } + + +async def _skipped_probe( + *, name: str, reason: str, model_name: str = "" +) -> dict[str, Any]: + payload: dict[str, Any] = {"name": name, "status": "skipped", "reason": reason} + if model_name: + payload["model_name"] = model_name + return payload + + +def _build_internal_model_probe_payload(mcfg: Any) -> dict[str, Any]: + payload = { + "model_name": getattr(mcfg, "model_name", ""), + "api_url": _mask_url(getattr(mcfg, "api_url", "")), + } + if hasattr(mcfg, "api_mode"): + payload["api_mode"] = getattr(mcfg, "api_mode", "chat_completions") + if hasattr(mcfg, "thinking_enabled"): + payload["thinking_enabled"] = getattr(mcfg, "thinking_enabled", False) + if hasattr(mcfg, "thinking_tool_call_compat"): + payload["thinking_tool_call_compat"] = getattr( + mcfg, "thinking_tool_call_compat", True + ) + if hasattr(mcfg, "responses_tool_choice_compat"): + payload["responses_tool_choice_compat"] = getattr( + mcfg, "responses_tool_choice_compat", False + ) + if hasattr(mcfg, "responses_force_stateless_replay"): + payload["responses_force_stateless_replay"] = getattr( + mcfg, "responses_force_stateless_replay", False + ) + if hasattr(mcfg, "prompt_cache_enabled"): + payload["prompt_cache_enabled"] = getattr(mcfg, "prompt_cache_enabled", True) + if hasattr(mcfg, "reasoning_enabled"): + payload["reasoning_enabled"] = getattr(mcfg, "reasoning_enabled", False) + if hasattr(mcfg, "reasoning_effort"): + payload["reasoning_effort"] = getattr(mcfg, "reasoning_effort", "medium") + return payload + + +async def _probe_ws_endpoint(url: str, timeout_seconds: float = 5.0) -> dict[str, Any]: + normalized = str(url or "").strip() + if not normalized: + return {"name": "onebot_ws", "status": 
"skipped", "reason": "empty_url"} + + parsed = urlsplit(normalized) + host = parsed.hostname + if not host: + return {"name": "onebot_ws", "status": "error", "error": "invalid_url"} + + if parsed.port is not None: + port = parsed.port + elif parsed.scheme == "wss": + port = 443 + else: + port = 80 + + start = time.perf_counter() + try: + conn = asyncio.open_connection(host, port) + reader, writer = await asyncio.wait_for(conn, timeout=timeout_seconds) + writer.close() + await writer.wait_closed() + elapsed_ms = round((time.perf_counter() - start) * 1000, 2) + return { + "name": "onebot_ws", + "status": "ok", + "host": host, + "port": port, + "latency_ms": elapsed_ms, + } + except (OSError, TimeoutError, socket.gaierror, asyncio.TimeoutError) as exc: + return { + "name": "onebot_ws", + "status": "error", + "host": host, + "port": port, + "error": str(exc), + } diff --git a/src/Undefined/api/app.py b/src/Undefined/api/app.py index cbe5824..91ceaf4 100644 --- a/src/Undefined/api/app.py +++ b/src/Undefined/api/app.py @@ -1,14 +1,13 @@ +"""Runtime API server for Undefined.""" + from __future__ import annotations import asyncio from copy import deepcopy -import hashlib -import json import logging import os import platform from pathlib import Path -import socket import sys import time import uuid as _uuid @@ -17,7 +16,7 @@ from datetime import datetime from typing import Any, Awaitable, Callable from typing import cast -from urllib.parse import urlsplit + from aiohttp import ClientSession, ClientTimeout, web from aiohttp.web_response import Response @@ -28,140 +27,51 @@ register_message_attachments, render_message_with_pic_placeholders, ) -from Undefined.config import load_webui_settings from Undefined.context import RequestContext from Undefined.context_resource_registry import collect_context_resources from Undefined.render import render_html_to_image, render_markdown_to_html # noqa: F401 from Undefined.services.queue_manager import QUEUE_LANE_SUPERADMIN from 
Undefined.utils.common import message_to_segments -from Undefined.utils.cors import is_allowed_cors_origin, normalize_origin from Undefined.utils.recent_messages import get_recent_messages_prefer_local from Undefined.utils.xml import escape_xml_attr, escape_xml_text -logger = logging.getLogger(__name__) +from ._helpers import ( + _apply_cors_headers, + _build_chat_response_payload, + _json_error, + _mask_url, + _naga_message_digest, + _naga_routes_enabled, + _naga_runtime_enabled, + _NagaRequestResult, + _optional_query_param, + _parse_query_time, + _parse_response_payload, + _registry_summary, + _short_text_preview, + _sse_event, + _to_bool, + _ToolInvokeExecutionTimeoutError, + _validate_callback_url, + _WebUIVirtualSender, + _AUTH_HEADER, + _VIRTUAL_USER_ID, +) +from ._probes import ( + _build_internal_model_probe_payload, + _probe_http_endpoint, + _probe_ws_endpoint, + _skipped_probe, +) +from ._openapi import _build_openapi_spec -_VIRTUAL_USER_ID = 42 +logger = logging.getLogger(__name__) _VIRTUAL_USER_NAME = "system" -_AUTH_HEADER = "X-Undefined-API-Key" _CHAT_SSE_KEEPALIVE_SECONDS = 10.0 _PROCESS_START_TIME = time.time() _NAGA_REQUEST_UUID_TTL_SECONDS = 6 * 60 * 60 -class _ToolInvokeExecutionTimeoutError(asyncio.TimeoutError): - """由 Runtime API 工具调用超时包装器抛出的超时异常。""" - - -@dataclass -class _NagaRequestResult: - payload_hash: str - status: int - payload: dict[str, Any] - finished_at: float - - -class _WebUIVirtualSender: - """将工具发送行为重定向到 WebUI 会话,不触发 OneBot 实际发送。""" - - def __init__( - self, - virtual_user_id: int, - send_private_callback: Callable[[int, str], Awaitable[None]], - onebot: Any = None, - ) -> None: - self._virtual_user_id = virtual_user_id - self._send_private_callback = send_private_callback - # 保留 onebot 属性,兼容依赖 sender.onebot 的工具读取能力。 - self.onebot = onebot - - async def send_private_message( - self, - user_id: int, - message: str, - auto_history: bool = True, - *, - mark_sent: bool = True, - reply_to: int | None = None, - 
preferred_temp_group_id: int | None = None, - history_message: str | None = None, - ) -> int | None: - _ = ( - user_id, - auto_history, - mark_sent, - reply_to, - preferred_temp_group_id, - history_message, - ) - await self._send_private_callback(self._virtual_user_id, message) - return None - - async def send_group_message( - self, - group_id: int, - message: str, - auto_history: bool = True, - history_prefix: str = "", - *, - mark_sent: bool = True, - reply_to: int | None = None, - history_message: str | None = None, - ) -> int | None: - _ = ( - group_id, - auto_history, - history_prefix, - mark_sent, - reply_to, - history_message, - ) - await self._send_private_callback(self._virtual_user_id, message) - return None - - async def send_private_file( - self, - user_id: int, - file_path: str, - name: str | None = None, - auto_history: bool = True, - ) -> None: - """将文件拷贝到 WebUI 缓存并发送文件卡片消息。""" - _ = user_id, auto_history - import shutil - import uuid as _uuid - from pathlib import Path as _Path - - from Undefined.utils.paths import WEBUI_FILE_CACHE_DIR, ensure_dir - - src = _Path(file_path) - display_name = name or src.name - file_id = _uuid.uuid4().hex - dest_dir = ensure_dir(WEBUI_FILE_CACHE_DIR / file_id) - dest = dest_dir / display_name - - def _copy_and_stat() -> int: - shutil.copy2(src, dest) - return dest.stat().st_size - - try: - file_size = await asyncio.to_thread(_copy_and_stat) - except OSError: - file_size = 0 - - message = f"[CQ:file,id={file_id},name={display_name},size={file_size}]" - await self._send_private_callback(self._virtual_user_id, message) - - async def send_group_file( - self, - group_id: int, - file_path: str, - name: str | None = None, - auto_history: bool = True, - ) -> None: - """群文件在虚拟会话中同样重定向为文本消息。""" - await self.send_private_file(group_id, file_path, name, auto_history) - - @dataclass class RuntimeAPIContext: config_getter: Callable[[], Any] @@ -178,502 +88,6 @@ class RuntimeAPIContext: naga_store: Any = None -def 
_json_error(message: str, status: int = 400) -> Response: - return web.json_response({"error": message}, status=status) - - -def _apply_cors_headers(request: web.Request, response: web.StreamResponse) -> None: - origin = normalize_origin(str(request.headers.get("Origin") or "")) - settings = load_webui_settings() - response.headers.setdefault("Vary", "Origin") - response.headers.setdefault( - "Access-Control-Allow-Methods", "GET,POST,PATCH,DELETE,OPTIONS" - ) - response.headers.setdefault( - "Access-Control-Allow-Headers", - "Authorization, Content-Type, X-Undefined-API-Key", - ) - response.headers.setdefault("Access-Control-Max-Age", "86400") - if origin and is_allowed_cors_origin( - origin, - configured_host=str(settings.url or ""), - configured_port=settings.port, - ): - response.headers.setdefault("Access-Control-Allow-Origin", origin) - response.headers.setdefault("Access-Control-Allow-Credentials", "true") - - -def _optional_query_param(request: web.Request, key: str) -> str | None: - raw = request.query.get(key) - if raw is None: - return None - text = str(raw).strip() - if not text: - return None - return text - - -def _parse_query_time(value: str | None) -> datetime | None: - text = str(value or "").strip() - if not text: - return None - candidates = [text, text.replace("Z", "+00:00")] - if "T" in text: - candidates.append(text.replace("T", " ")) - for candidate in candidates: - with suppress(ValueError): - return datetime.fromisoformat(candidate) - return None - - -def _to_bool(value: Any) -> bool: - if isinstance(value, bool): - return value - if isinstance(value, (int, float)): - return value != 0 - text = str(value or "").strip().lower() - return text in {"1", "true", "yes", "on"} - - -def _build_chat_response_payload(mode: str, outputs: list[str]) -> dict[str, Any]: - return { - "mode": mode, - "virtual_user_id": _VIRTUAL_USER_ID, - "permission": "superadmin", - "messages": outputs, - "reply": "\n\n".join(outputs).strip(), - } - - -def 
_sse_event(event: str, payload: dict[str, Any]) -> bytes: - data = json.dumps(payload, ensure_ascii=False) - return f"event: {event}\ndata: {data}\n\n".encode("utf-8") - - -def _mask_url(url: str) -> str: - """保留 scheme + host,隐藏 path 细节。""" - text = str(url or "").strip().rstrip("/") - if not text: - return "" - parsed = urlsplit(text) - host = parsed.hostname or "" - port_part = f":{parsed.port}" if parsed.port else "" - scheme = parsed.scheme or "https" - return f"{scheme}://{host}{port_part}/..." - - -def _naga_runtime_enabled(cfg: Any) -> bool: - naga_cfg = getattr(cfg, "naga", None) - return bool(getattr(cfg, "nagaagent_mode_enabled", False)) and bool( - getattr(naga_cfg, "enabled", False) - ) - - -def _naga_routes_enabled(cfg: Any, naga_store: Any) -> bool: - return _naga_runtime_enabled(cfg) and naga_store is not None - - -def _short_text_preview(text: str, limit: int = 80) -> str: - compact = " ".join(str(text or "").split()) - if len(compact) <= limit: - return compact - return compact[:limit] + "..." 
- - -def _naga_message_digest( - *, - bind_uuid: str, - naga_id: str, - target_qq: int, - target_group: int, - mode: str, - message_format: str, - content: str, -) -> str: - raw = json.dumps( - { - "bind_uuid": bind_uuid, - "naga_id": naga_id, - "target_qq": target_qq, - "target_group": target_group, - "mode": mode, - "format": message_format, - "content": content, - }, - ensure_ascii=False, - sort_keys=True, - separators=(",", ":"), - ) - return hashlib.sha1(raw.encode("utf-8")).hexdigest()[:16] - - -def _parse_response_payload(response: Response) -> dict[str, Any]: - text = response.text or "" - if not text: - return {} - payload = json.loads(text) - return payload if isinstance(payload, dict) else {"data": payload} - - -def _registry_summary(registry: Any) -> dict[str, Any]: - """从 BaseRegistry 提取轻量摘要。""" - if registry is None: - return {"count": 0, "loaded": 0, "items": []} - items: dict[str, Any] = getattr(registry, "_items", {}) - stats: dict[str, Any] = {} - get_stats = getattr(registry, "get_stats", None) - if callable(get_stats): - stats = get_stats() - summary_items: list[dict[str, Any]] = [] - for name, item in items.items(): - st = stats.get(name) - entry: dict[str, Any] = { - "name": name, - "loaded": getattr(item, "loaded", False), - } - if st is not None: - entry["calls"] = getattr(st, "count", 0) - entry["success"] = getattr(st, "success", 0) - entry["failure"] = getattr(st, "failure", 0) - summary_items.append(entry) - return { - "count": len(items), - "loaded": sum(1 for i in items.values() if getattr(i, "loaded", False)), - "items": summary_items, - } - - -def _validate_callback_url(url: str) -> str | None: - """校验回调 URL,返回错误信息或 None 表示通过。 - - 拒绝非 HTTP(S) scheme,以及直接使用私有/回环 IP 字面量的 URL 以防止 SSRF。 - 域名形式的 URL 放行(DNS 解析阶段不适合在校验函数中做阻塞调用)。 - """ - import ipaddress - - parsed = urlsplit(url) - scheme = (parsed.scheme or "").lower() - - if scheme not in ("http", "https"): - return "callback.url must use http or https" - - hostname = parsed.hostname or 
"" - if not hostname: - return "callback.url must include a hostname" - - # 仅检查 IP 字面量(如 http://127.0.0.1/、http://[::1]/、http://10.0.0.1/) - try: - addr = ipaddress.ip_address(hostname) - except ValueError: - pass # 域名形式,放行 - else: - if addr.is_private or addr.is_loopback or addr.is_link_local: - return "callback.url must not point to a private/loopback address" - - return None - - -async def _probe_http_endpoint( - *, - name: str, - base_url: str, - api_key: str, - model_name: str = "", - timeout_seconds: float = 5.0, -) -> dict[str, Any]: - normalized = str(base_url or "").strip().rstrip("/") - if not normalized: - return { - "name": name, - "status": "skipped", - "reason": "empty_url", - "model_name": model_name, - } - - headers: dict[str, str] = {} - if api_key: - headers["Authorization"] = f"Bearer {api_key}" - - candidates = [normalized, f"{normalized}/models"] - last_error = "" - for url in candidates: - start = time.perf_counter() - try: - timeout = ClientTimeout(total=timeout_seconds) - async with ClientSession(timeout=timeout) as session: - async with session.get(url, headers=headers) as resp: - elapsed_ms = round((time.perf_counter() - start) * 1000, 2) - return { - "name": name, - "status": "ok", - "url": _mask_url(url), - "http_status": resp.status, - "latency_ms": elapsed_ms, - "model_name": model_name, - } - except Exception as exc: - last_error = str(exc) - continue - - return { - "name": name, - "status": "error", - "url": _mask_url(normalized), - "error": last_error or "request_failed", - "model_name": model_name, - } - - -async def _skipped_probe( - *, name: str, reason: str, model_name: str = "" -) -> dict[str, Any]: - payload: dict[str, Any] = {"name": name, "status": "skipped", "reason": reason} - if model_name: - payload["model_name"] = model_name - return payload - - -def _build_internal_model_probe_payload(mcfg: Any) -> dict[str, Any]: - payload = { - "model_name": getattr(mcfg, "model_name", ""), - "api_url": _mask_url(getattr(mcfg, 
"api_url", "")), - } - if hasattr(mcfg, "api_mode"): - payload["api_mode"] = getattr(mcfg, "api_mode", "chat_completions") - if hasattr(mcfg, "thinking_enabled"): - payload["thinking_enabled"] = getattr(mcfg, "thinking_enabled", False) - if hasattr(mcfg, "thinking_tool_call_compat"): - payload["thinking_tool_call_compat"] = getattr( - mcfg, "thinking_tool_call_compat", True - ) - if hasattr(mcfg, "responses_tool_choice_compat"): - payload["responses_tool_choice_compat"] = getattr( - mcfg, "responses_tool_choice_compat", False - ) - if hasattr(mcfg, "responses_force_stateless_replay"): - payload["responses_force_stateless_replay"] = getattr( - mcfg, "responses_force_stateless_replay", False - ) - if hasattr(mcfg, "prompt_cache_enabled"): - payload["prompt_cache_enabled"] = getattr(mcfg, "prompt_cache_enabled", True) - if hasattr(mcfg, "reasoning_enabled"): - payload["reasoning_enabled"] = getattr(mcfg, "reasoning_enabled", False) - if hasattr(mcfg, "reasoning_effort"): - payload["reasoning_effort"] = getattr(mcfg, "reasoning_effort", "medium") - return payload - - -async def _probe_ws_endpoint(url: str, timeout_seconds: float = 5.0) -> dict[str, Any]: - normalized = str(url or "").strip() - if not normalized: - return {"name": "onebot_ws", "status": "skipped", "reason": "empty_url"} - - parsed = urlsplit(normalized) - host = parsed.hostname - if not host: - return {"name": "onebot_ws", "status": "error", "error": "invalid_url"} - - if parsed.port is not None: - port = parsed.port - elif parsed.scheme == "wss": - port = 443 - else: - port = 80 - - start = time.perf_counter() - try: - conn = asyncio.open_connection(host, port) - reader, writer = await asyncio.wait_for(conn, timeout=timeout_seconds) - writer.close() - await writer.wait_closed() - elapsed_ms = round((time.perf_counter() - start) * 1000, 2) - return { - "name": "onebot_ws", - "status": "ok", - "host": host, - "port": port, - "latency_ms": elapsed_ms, - } - except (OSError, TimeoutError, socket.gaierror, 
asyncio.TimeoutError) as exc: - return { - "name": "onebot_ws", - "status": "error", - "host": host, - "port": port, - "error": str(exc), - } - - -def _build_openapi_spec(ctx: RuntimeAPIContext, request: web.Request) -> dict[str, Any]: - server_url = f"{request.scheme}://{request.host}" - cfg = ctx.config_getter() - naga_routes_enabled = _naga_routes_enabled(cfg, ctx.naga_store) - paths: dict[str, Any] = { - "/health": { - "get": { - "summary": "Health check", - "security": [], - "responses": {"200": {"description": "OK"}}, - } - }, - "/openapi.json": { - "get": { - "summary": "OpenAPI schema", - "security": [], - "responses": {"200": {"description": "Schema JSON"}}, - } - }, - "/api/v1/probes/internal": { - "get": { - "summary": "Internal runtime probes", - "description": ( - "Returns system info (version, Python, platform, uptime), " - "OneBot connection status, request queue snapshot, " - "memory count, cognitive service status, API config, " - "skill statistics (tools/agents/anthropic_skills with call counts), " - "and model configuration (names, masked URLs, thinking flags)." - ), - } - }, - "/api/v1/probes/external": { - "get": { - "summary": "External dependency probes", - "description": ( - "Concurrently probes all configured model API endpoints " - "(chat, vision, security, naga, agent, embedding, rerank) " - "and OneBot WebSocket. Each result includes status, " - "model name, masked URL, HTTP status code, and latency." 
- ), - } - }, - "/api/v1/memory": { - "get": {"summary": "List/search manual memories"}, - "post": {"summary": "Create a manual memory"}, - }, - "/api/v1/memory/{uuid}": { - "patch": {"summary": "Update a manual memory by UUID"}, - "delete": {"summary": "Delete a manual memory by UUID"}, - }, - "/api/v1/memes": {"get": {"summary": "List/search meme library items"}}, - "/api/v1/memes/stats": {"get": {"summary": "Get meme library stats"}}, - "/api/v1/memes/{uid}": { - "get": {"summary": "Get a meme by uid"}, - "patch": {"summary": "Update a meme by uid"}, - "delete": {"summary": "Delete a meme by uid"}, - }, - "/api/v1/memes/{uid}/blob": {"get": {"summary": "Get meme blob file"}}, - "/api/v1/memes/{uid}/preview": {"get": {"summary": "Get meme preview file"}}, - "/api/v1/memes/{uid}/reanalyze": { - "post": {"summary": "Queue a meme reanalyze job"} - }, - "/api/v1/memes/{uid}/reindex": { - "post": {"summary": "Queue a meme reindex job"} - }, - "/api/v1/cognitive/events": { - "get": {"summary": "Search cognitive event memories"} - }, - "/api/v1/cognitive/profiles": {"get": {"summary": "Search cognitive profiles"}}, - "/api/v1/cognitive/profile/{entity_type}/{entity_id}": { - "get": {"summary": "Get a profile by entity type/id"} - }, - "/api/v1/chat": { - "post": { - "summary": "WebUI special private chat", - "description": ( - "POST JSON {message, stream?}. " - "When stream=true, response is SSE with keep-alive comments." - ), - } - }, - "/api/v1/chat/history": { - "get": {"summary": "Get virtual private chat history for WebUI"} - }, - "/api/v1/tools": { - "get": { - "summary": "List available tools", - "description": ( - "Returns currently available tools filtered by " - "tool_invoke_expose / allowlist / denylist config. " - "Each item follows the OpenAI function calling schema." - ), - } - }, - "/api/v1/tools/invoke": { - "post": { - "summary": "Invoke a tool", - "description": ( - "Execute a specific tool by name. 
Supports synchronous " - "response and optional async webhook callback." - ), - } - }, - } - - if naga_routes_enabled: - paths.update( - { - "/api/v1/naga/bind/callback": { - "post": { - "summary": "Finalize a Naga bind request", - "description": ( - "Internal callback used by Naga to approve or reject " - "a bind_uuid." - ), - "security": [{"BearerAuth": []}], - } - }, - "/api/v1/naga/messages/send": { - "post": { - "summary": "Send a Naga-authenticated message", - "description": ( - "Validates bind_uuid + delivery_signature, runs " - "moderation, then delivers the message. " - "Caller may provide uuid for idempotent retry deduplication." - ), - "security": [{"BearerAuth": []}], - } - }, - "/api/v1/naga/unbind": { - "post": { - "summary": "Revoke an active Naga binding", - "description": ( - "Allows Naga to proactively revoke a binding using " - "Authorization: Bearer ." - ), - "security": [{"BearerAuth": []}], - } - }, - } - ) - - return { - "openapi": "3.0.3", - "info": { - "title": "Undefined Runtime API", - "version": __version__, - "description": "API exposed by the main Undefined process.", - }, - "servers": [ - { - "url": server_url, - "description": "Runtime endpoint", - } - ], - "components": { - "securitySchemes": { - "ApiKeyAuth": { - "type": "apiKey", - "in": "header", - "name": _AUTH_HEADER, - }, - "BearerAuth": {"type": "http", "scheme": "bearer"}, - } - }, - "security": [{"ApiKeyAuth": []}], - "paths": paths, - } - - class RuntimeAPIServer: def __init__( self, diff --git a/tests/test_runtime_api_sender_proxy.py b/tests/test_runtime_api_sender_proxy.py index be0c82c..77fc805 100644 --- a/tests/test_runtime_api_sender_proxy.py +++ b/tests/test_runtime_api_sender_proxy.py @@ -2,7 +2,7 @@ import pytest -from Undefined.api.app import _WebUIVirtualSender +from Undefined.api._helpers import _WebUIVirtualSender @pytest.mark.asyncio diff --git a/tests/test_runtime_api_tool_invoke.py b/tests/test_runtime_api_tool_invoke.py index 9db1330..16dc141 100644 --- 
a/tests/test_runtime_api_tool_invoke.py +++ b/tests/test_runtime_api_tool_invoke.py @@ -10,7 +10,7 @@ from aiohttp.web_response import Response from Undefined.api import RuntimeAPIContext, RuntimeAPIServer -from Undefined.api.app import _validate_callback_url +from Undefined.api._helpers import _validate_callback_url def _json(response: Response) -> Any: diff --git a/tests/test_webui_management_api.py b/tests/test_webui_management_api.py index 435b710..9391279 100644 --- a/tests/test_webui_management_api.py +++ b/tests/test_webui_management_api.py @@ -6,7 +6,7 @@ from aiohttp import web -from Undefined.api import app as runtime_api_app +from Undefined.api import _helpers as runtime_api_helpers from Undefined.webui import app as webui_app from Undefined.webui.app import create_app from Undefined.webui.core import SessionStore @@ -350,7 +350,7 @@ def test_webui_cors_only_allows_trusted_origins(monkeypatch: Any) -> None: def test_runtime_api_cors_only_allows_trusted_origins(monkeypatch: Any) -> None: monkeypatch.setattr( - runtime_api_app, + runtime_api_helpers, "load_webui_settings", lambda: SimpleNamespace(url="127.0.0.1", port=8787), ) @@ -359,7 +359,7 @@ def test_runtime_api_cors_only_allows_trusted_origins(monkeypatch: Any) -> None: cast(Any, DummyRequest(headers={"Origin": "tauri://localhost"})), ) trusted_response = web.Response(status=200) - runtime_api_app._apply_cors_headers(trusted_request, trusted_response) + runtime_api_helpers._apply_cors_headers(trusted_request, trusted_response) assert trusted_response.headers.get("Access-Control-Allow-Origin") == ( "tauri://localhost" ) @@ -370,7 +370,7 @@ def test_runtime_api_cors_only_allows_trusted_origins(monkeypatch: Any) -> None: cast(Any, DummyRequest(headers={"Origin": "http://localhost:1420"})), ) loopback_response = web.Response(status=200) - runtime_api_app._apply_cors_headers(loopback_request, loopback_response) + runtime_api_helpers._apply_cors_headers(loopback_request, loopback_response) assert 
loopback_response.headers.get("Access-Control-Allow-Origin") == ( "http://localhost:1420" ) @@ -380,7 +380,7 @@ def test_runtime_api_cors_only_allows_trusted_origins(monkeypatch: Any) -> None: cast(Any, DummyRequest(headers={"Origin": "https://evil.example"})), ) untrusted_response = web.Response(status=200) - runtime_api_app._apply_cors_headers(untrusted_request, untrusted_response) + runtime_api_helpers._apply_cors_headers(untrusted_request, untrusted_response) assert "Access-Control-Allow-Origin" not in untrusted_response.headers assert "Access-Control-Allow-Credentials" not in untrusted_response.headers From 72ac2d9cd97f15bbadfa62e7e05bd65d9839de78 Mon Sep 17 00:00:00 2001 From: Null <1708213363@qq.com> Date: Sun, 19 Apr 2026 12:44:21 +0800 Subject: [PATCH 32/57] fix(queue): register historian model in queue interval builder The build_model_queue_intervals() function was missing the historian model (glm-5), causing it to fall back to the default 1.0s dispatch interval instead of the configured 0s. This slowed background historian tasks unnecessarily. 
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> Co-authored-by: Claude Opus 4.6 --- src/Undefined/utils/queue_intervals.py | 4 ++++ tests/test_config_hot_reload.py | 8 ++++++++ 2 files changed, 12 insertions(+) diff --git a/src/Undefined/utils/queue_intervals.py b/src/Undefined/utils/queue_intervals.py index 3b068b8..0596d7d 100644 --- a/src/Undefined/utils/queue_intervals.py +++ b/src/Undefined/utils/queue_intervals.py @@ -24,6 +24,10 @@ def build_model_queue_intervals(config: Config) -> dict[str, float]: summary_model.queue_interval_seconds, ), (config.grok_model.model_name, config.grok_model.queue_interval_seconds), + ( + config.historian_model.model_name, + config.historian_model.queue_interval_seconds, + ), ) intervals: dict[str, float] = {} for model_name, interval in pairs: diff --git a/tests/test_config_hot_reload.py b/tests/test_config_hot_reload.py index a244162..4e0158a 100644 --- a/tests/test_config_hot_reload.py +++ b/tests/test_config_hot_reload.py @@ -62,6 +62,10 @@ def test_apply_config_updates_propagates_to_security_service() -> None: model_name="grok", queue_interval_seconds=1.0, ), + historian_model=SimpleNamespace( + model_name="historian", + queue_interval_seconds=1.0, + ), ), ) security_service = _FakeSecurityService() @@ -120,6 +124,10 @@ def test_apply_config_updates_hot_reloads_ai_request_max_retries() -> None: model_name="grok", queue_interval_seconds=1.0, ), + historian_model=SimpleNamespace( + model_name="historian", + queue_interval_seconds=1.0, + ), ), ) security_service = _FakeSecurityService() From e972c6e07d0be7a1fbac2eeb9f166c9e5caab31c Mon Sep 17 00:00:00 2001 From: Null <1708213363@qq.com> Date: Sun, 19 Apr 2026 12:57:30 +0800 Subject: [PATCH 33/57] perf(handlers): parallelize message preprocessing with asyncio.gather MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Group messages: run attachment collection, group info fetch, and history content parsing concurrently 
instead of serially. Reduces preprocessing latency from sum(A+B+C) to max(A,B,C). Private messages: run attachment collection and history content parsing concurrently. No behavioral changes — all data is still fully prepared before any feature checks (keyword, repeat, bilibili, arxiv, commands, AI). Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> Co-authored-by: Claude Opus 4.6 --- src/Undefined/handlers.py | 77 ++++++++++++++++++++++----------------- 1 file changed, 43 insertions(+), 34 deletions(-) diff --git a/src/Undefined/handlers.py b/src/Undefined/handlers.py index 4bd21a8..726fa5c 100644 --- a/src/Undefined/handlers.py +++ b/src/Undefined/handlers.py @@ -496,10 +496,19 @@ async def handle_message(self, event: dict[str, Any]) -> None: logger.warning("获取用户昵称失败: %s", exc) text = extract_text(private_message_content, self.config.bot_qq) - private_attachments = await self._collect_message_attachments( - private_message_content, - user_id=private_sender_id, - request_type="private", + # 并行执行附件收集和历史内容解析 + private_attachments, parsed_content_raw = await asyncio.gather( + self._collect_message_attachments( + private_message_content, + user_id=private_sender_id, + request_type="private", + ), + parse_message_content_for_history( + private_message_content, + self.config.bot_qq, + self.onebot.get_msg, + self.onebot.get_forward_msg, + ), ) safe_text = redact_string(text) logger.info( @@ -515,15 +524,10 @@ async def handle_message(self, event: dict[str, Any]) -> None: sender_name=resolved_private_name, ) - # 保存私聊消息到历史记录(保存处理后的内容) - # 使用新的工具函数解析内容 - parsed_content = await parse_message_content_for_history( - private_message_content, - self.config.bot_qq, - self.onebot.get_msg, - self.onebot.get_forward_msg, + # 保存私聊消息到历史记录 + parsed_content = append_attachment_text( + parsed_content_raw, private_attachments ) - parsed_content = append_attachment_text(parsed_content, private_attachments) safe_parsed = redact_string(parsed_content) logger.debug( "[历史记录] 
保存私聊: user=%s content=%s...", @@ -648,26 +652,37 @@ async def handle_message(self, event: dict[str, Any]) -> None: # 提取文本内容 text = extract_text(message_content, self.config.bot_qq) - group_attachments = await self._collect_message_attachments( - message_content, - group_id=group_id, - request_type="group", - ) safe_text = redact_string(text) logger.info( f"[群消息] group={group_id} sender={sender_id} name={sender_card or sender_nickname} " f"role={sender_role} | {safe_text[:100]}" ) - # 保存消息到历史记录 (使用处理后的内容) - # 获取群聊名 - group_name = "" - try: - group_info = await self.onebot.get_group_info(group_id) - if group_info: - group_name = group_info.get("group_name", "") - except Exception as e: - logger.warning(f"获取群聊名失败: {e}") + # 并行执行 3 个独立的异步操作:附件收集、群信息获取、历史内容解析 + async def _fetch_group_name() -> str: + try: + info = await self.onebot.get_group_info(group_id) + if info: + return str(info.get("group_name", "") or "") + except Exception as e: + logger.warning(f"获取群聊名失败: {e}") + return "" + + group_attachments, group_name, parsed_content_raw = await asyncio.gather( + self._collect_message_attachments( + message_content, + group_id=group_id, + request_type="group", + ), + _fetch_group_name(), + parse_message_content_for_history( + message_content, + self.config.bot_qq, + self.onebot.get_msg, + self.onebot.get_forward_msg, + ), + ) + resolved_group_sender_name = (sender_card or sender_nickname or "").strip() self._schedule_profile_display_name_refresh( task_name=f"profile_name_refresh_group:{group_id}:{sender_id}", @@ -677,14 +692,8 @@ async def handle_message(self, event: dict[str, Any]) -> None: group_name=str(group_name or "").strip(), ) - # 使用新的 utils - parsed_content = await parse_message_content_for_history( - message_content, - self.config.bot_qq, - self.onebot.get_msg, - self.onebot.get_forward_msg, - ) - parsed_content = append_attachment_text(parsed_content, group_attachments) + # 保存消息到历史记录 + parsed_content = append_attachment_text(parsed_content_raw, 
group_attachments) safe_parsed = redact_string(parsed_content) logger.debug( f"[历史记录] 保存群聊: group={group_id}, sender={sender_id}, content={safe_parsed[:50]}..." From a3fc737a102f6b295bda84f082949dd31379d855 Mon Sep 17 00:00:00 2001 From: Null <1708213363@qq.com> Date: Sun, 19 Apr 2026 13:55:45 +0800 Subject: [PATCH 34/57] refactor(api): split app.py into route submodules MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Extract 2491-line app.py into focused route modules under api/routes/: - health.py (23 lines) — health endpoint - system.py (228 lines) — OpenAPI + internal/external probes - memory.py (156 lines) — memory CRUD - memes.py (222 lines) — meme management (9 handlers) - cognitive.py (86 lines) — cognitive events & profiles - chat.py (359 lines) — WebUI chat with SSE streaming - tools.py (462 lines) — tool invoke with async callbacks - naga.py (897 lines) — Naga bind/send/unbind + moderation Infrastructure: - _context.py: RuntimeAPIContext dataclass (shared import root) - _naga_state.py: NagaState class (request dedup + inflight tracking) app.py reduced to 333 lines: server class + thin delegation wrappers that preserve test compatibility (no test API changes needed). Updated 4 test files to monkeypatch route-module-level symbols instead of app-module-level symbols. 
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> Co-authored-by: Claude Opus 4.6 --- src/Undefined/api/__init__.py | 3 +- src/Undefined/api/_context.py | 22 + src/Undefined/api/_naga_state.py | 115 ++ src/Undefined/api/_openapi.py | 2 +- src/Undefined/api/app.py | 2292 +------------------------ src/Undefined/api/routes/__init__.py | 1 + src/Undefined/api/routes/chat.py | 359 ++++ src/Undefined/api/routes/cognitive.py | 86 + src/Undefined/api/routes/health.py | 23 + src/Undefined/api/routes/memes.py | 222 +++ src/Undefined/api/routes/memory.py | 156 ++ src/Undefined/api/routes/naga.py | 897 ++++++++++ src/Undefined/api/routes/system.py | 228 +++ src/Undefined/api/routes/tools.py | 462 +++++ tests/test_runtime_api_chat_stream.py | 12 +- tests/test_runtime_api_naga.py | 8 +- tests/test_runtime_api_probes.py | 8 +- tests/test_runtime_api_tool_invoke.py | 4 +- 18 files changed, 2659 insertions(+), 2241 deletions(-) create mode 100644 src/Undefined/api/_context.py create mode 100644 src/Undefined/api/_naga_state.py create mode 100644 src/Undefined/api/routes/__init__.py create mode 100644 src/Undefined/api/routes/chat.py create mode 100644 src/Undefined/api/routes/cognitive.py create mode 100644 src/Undefined/api/routes/health.py create mode 100644 src/Undefined/api/routes/memes.py create mode 100644 src/Undefined/api/routes/memory.py create mode 100644 src/Undefined/api/routes/naga.py create mode 100644 src/Undefined/api/routes/system.py create mode 100644 src/Undefined/api/routes/tools.py diff --git a/src/Undefined/api/__init__.py b/src/Undefined/api/__init__.py index f71c40d..1ab331b 100644 --- a/src/Undefined/api/__init__.py +++ b/src/Undefined/api/__init__.py @@ -1,5 +1,6 @@ """Runtime API server for Undefined main process.""" -from .app import RuntimeAPIContext, RuntimeAPIServer +from ._context import RuntimeAPIContext +from .app import RuntimeAPIServer __all__ = ["RuntimeAPIContext", "RuntimeAPIServer"] diff --git 
a/src/Undefined/api/_context.py b/src/Undefined/api/_context.py new file mode 100644 index 0000000..07d8c92 --- /dev/null +++ b/src/Undefined/api/_context.py @@ -0,0 +1,22 @@ +"""Shared context types for the Runtime API.""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any, Callable + + +@dataclass +class RuntimeAPIContext: + config_getter: Callable[[], Any] + onebot: Any + ai: Any + command_dispatcher: Any + queue_manager: Any + history_manager: Any + sender: Any = None + scheduler: Any = None + cognitive_service: Any = None + cognitive_job_queue: Any = None + meme_service: Any = None + naga_store: Any = None diff --git a/src/Undefined/api/_naga_state.py b/src/Undefined/api/_naga_state.py new file mode 100644 index 0000000..26bff9f --- /dev/null +++ b/src/Undefined/api/_naga_state.py @@ -0,0 +1,115 @@ +"""Naga request deduplication and inflight tracking state.""" + +from __future__ import annotations + +import asyncio +import time +from copy import deepcopy +from typing import Any + +from ._helpers import _NagaRequestResult + +_NAGA_REQUEST_UUID_TTL_SECONDS = 6 * 60 * 60 + + +class NagaState: + """Tracks inflight Naga sends and provides request-uuid deduplication.""" + + def __init__(self) -> None: + self.send_registry_lock = asyncio.Lock() + self.send_inflight: dict[str, int] = {} + self.request_uuid_lock = asyncio.Lock() + self.request_uuid_inflight: dict[ + str, tuple[str, asyncio.Future[tuple[int, dict[str, Any]]]] + ] = {} + self.request_uuid_results: dict[str, _NagaRequestResult] = {} + + async def track_send_start(self, message_key: str) -> int: + async with self.send_registry_lock: + next_count = self.send_inflight.get(message_key, 0) + 1 + self.send_inflight[message_key] = next_count + return next_count + + async def track_send_done(self, message_key: str) -> int: + async with self.send_registry_lock: + current = self.send_inflight.get(message_key, 0) + if current <= 1: + self.send_inflight.pop(message_key, 
None) + return 0 + next_count = current - 1 + self.send_inflight[message_key] = next_count + return next_count + + def _prune_request_uuid_state_locked(self) -> None: + now = time.time() + expired = [ + request_uuid + for request_uuid, result in self.request_uuid_results.items() + if now - result.finished_at > _NAGA_REQUEST_UUID_TTL_SECONDS + ] + for request_uuid in expired: + self.request_uuid_results.pop(request_uuid, None) + + async def register_request_uuid( + self, request_uuid: str, payload_hash: str + ) -> tuple[str, Any]: + async with self.request_uuid_lock: + self._prune_request_uuid_state_locked() + + cached = self.request_uuid_results.get(request_uuid) + if cached is not None: + if cached.payload_hash != payload_hash: + return "conflict", cached.payload_hash + return "cached", (cached.status, deepcopy(cached.payload)) + + inflight = self.request_uuid_inflight.get(request_uuid) + if inflight is not None: + existing_hash, inflight_future = inflight + if existing_hash != payload_hash: + return "conflict", existing_hash + return "await", inflight_future + + owner_future: asyncio.Future[tuple[int, dict[str, Any]]] = ( + asyncio.get_running_loop().create_future() + ) + self.request_uuid_inflight[request_uuid] = ( + payload_hash, + owner_future, + ) + return "owner", owner_future + + async def finish_request_uuid( + self, + request_uuid: str, + payload_hash: str, + *, + status: int, + payload: dict[str, Any], + ) -> None: + async with self.request_uuid_lock: + inflight = self.request_uuid_inflight.pop(request_uuid, None) + future = inflight[1] if inflight is not None else None + result_payload = deepcopy(payload) + self.request_uuid_results[request_uuid] = _NagaRequestResult( + payload_hash=payload_hash, + status=status, + payload=result_payload, + finished_at=time.time(), + ) + self._prune_request_uuid_state_locked() + if future is not None and not future.done(): + future.set_result((status, deepcopy(result_payload))) + + async def fail_request_uuid( + self, + 
request_uuid: str, + payload_hash: str, + exc: BaseException, + ) -> None: + _ = payload_hash + async with self.request_uuid_lock: + inflight = self.request_uuid_inflight.pop(request_uuid, None) + future = inflight[1] if inflight is not None else None + self._prune_request_uuid_state_locked() + if future is not None and not future.done(): + future.set_exception(exc) diff --git a/src/Undefined/api/_openapi.py b/src/Undefined/api/_openapi.py index 6c4566c..1076256 100644 --- a/src/Undefined/api/_openapi.py +++ b/src/Undefined/api/_openapi.py @@ -10,7 +10,7 @@ from ._helpers import _AUTH_HEADER, _naga_routes_enabled if TYPE_CHECKING: - from .app import RuntimeAPIContext + from ._context import RuntimeAPIContext def _build_openapi_spec(ctx: RuntimeAPIContext, request: web.Request) -> dict[str, Any]: diff --git a/src/Undefined/api/app.py b/src/Undefined/api/app.py index 91ceaf4..7ad5e56 100644 --- a/src/Undefined/api/app.py +++ b/src/Undefined/api/app.py @@ -1,91 +1,32 @@ -"""Runtime API server for Undefined.""" +"""Runtime API server for Undefined. + +Route handler logic lives in ``routes/`` sub-modules. This file keeps only +the ``RuntimeAPIServer`` class (init / start / stop / middleware / routing) +and thin one-liner wrappers that delegate to the route functions so that +existing tests calling ``server._xxx_handler(request)`` keep working. 
+""" from __future__ import annotations import asyncio -from copy import deepcopy import logging -import os -import platform -from pathlib import Path -import sys -import time -import uuid as _uuid -from contextlib import suppress -from dataclasses import dataclass -from datetime import datetime from typing import Any, Awaitable, Callable -from typing import cast -from aiohttp import ClientSession, ClientTimeout, web +from aiohttp import web from aiohttp.web_response import Response -from Undefined import __version__ -from Undefined.attachments import ( - attachment_refs_to_xml, - build_attachment_scope, - register_message_attachments, - render_message_with_pic_placeholders, -) -from Undefined.context import RequestContext -from Undefined.context_resource_registry import collect_context_resources -from Undefined.render import render_html_to_image, render_markdown_to_html # noqa: F401 -from Undefined.services.queue_manager import QUEUE_LANE_SUPERADMIN -from Undefined.utils.common import message_to_segments -from Undefined.utils.recent_messages import get_recent_messages_prefer_local -from Undefined.utils.xml import escape_xml_attr, escape_xml_text - +from ._context import RuntimeAPIContext from ._helpers import ( _apply_cors_headers, - _build_chat_response_payload, _json_error, - _mask_url, - _naga_message_digest, _naga_routes_enabled, _naga_runtime_enabled, - _NagaRequestResult, - _optional_query_param, - _parse_query_time, - _parse_response_payload, - _registry_summary, - _short_text_preview, - _sse_event, - _to_bool, - _ToolInvokeExecutionTimeoutError, - _validate_callback_url, - _WebUIVirtualSender, _AUTH_HEADER, - _VIRTUAL_USER_ID, ) -from ._probes import ( - _build_internal_model_probe_payload, - _probe_http_endpoint, - _probe_ws_endpoint, - _skipped_probe, -) -from ._openapi import _build_openapi_spec +from ._naga_state import NagaState +from .routes import chat, cognitive, health, memes, memory, naga, system, tools logger = logging.getLogger(__name__) 
-_VIRTUAL_USER_NAME = "system" -_CHAT_SSE_KEEPALIVE_SECONDS = 10.0 -_PROCESS_START_TIME = time.time() -_NAGA_REQUEST_UUID_TTL_SECONDS = 6 * 60 * 60 - - -@dataclass -class RuntimeAPIContext: - config_getter: Callable[[], Any] - onebot: Any - ai: Any - command_dispatcher: Any - queue_manager: Any - history_manager: Any - sender: Any = None - scheduler: Any = None - cognitive_service: Any = None - cognitive_job_queue: Any = None - meme_service: Any = None - naga_store: Any = None class RuntimeAPIServer: @@ -101,13 +42,7 @@ def __init__( self._runner: web.AppRunner | None = None self._sites: list[web.TCPSite] = [] self._background_tasks: set[asyncio.Task[Any]] = set() - self._naga_send_registry_lock = asyncio.Lock() - self._naga_send_inflight: dict[str, int] = {} - self._naga_request_uuid_lock = asyncio.Lock() - self._naga_request_uuid_inflight: dict[ - str, tuple[str, asyncio.Future[tuple[int, dict[str, Any]]]] - ] = {} - self._naga_request_uuid_results: dict[str, _NagaRequestResult] = {} + self._naga_state = NagaState() async def start(self) -> None: from Undefined.config.models import resolve_bind_hosts @@ -123,7 +58,6 @@ async def start(self) -> None: logger.info("[RuntimeAPI] 已启动: %s", cfg.api.display_url) async def stop(self) -> None: - # 取消所有后台任务(如异步 tool invoke 回调) for task in self._background_tasks: task.cancel() if self._background_tasks: @@ -148,7 +82,6 @@ async def _auth_middleware( _apply_cors_headers(request, response) return response if request.path.startswith("/api/"): - # Naga 端点使用独立鉴权,仅在总开关+子开关均启用时跳过主 auth cfg = self._context.config_getter() is_naga_path = request.path.startswith("/api/v1/naga/") skip_auth = is_naga_path and _naga_runtime_enabled(cfg) @@ -205,7 +138,6 @@ async def _auth_middleware( web.post("/api/v1/tools/invoke", self._tools_invoke_handler), ] ) - # Naga 端点仅在总开关+子开关均启用时注册 cfg = self._context.config_getter() if _naga_routes_enabled(cfg, self._context.naga_store): app.add_routes( @@ -237,1233 +169,108 @@ async def _auth_middleware( ) 
return app - async def _track_naga_send_start(self, message_key: str) -> int: - async with self._naga_send_registry_lock: - next_count = self._naga_send_inflight.get(message_key, 0) + 1 - self._naga_send_inflight[message_key] = next_count - return next_count - - async def _track_naga_send_done(self, message_key: str) -> int: - async with self._naga_send_registry_lock: - current = self._naga_send_inflight.get(message_key, 0) - if current <= 1: - self._naga_send_inflight.pop(message_key, None) - return 0 - next_count = current - 1 - self._naga_send_inflight[message_key] = next_count - return next_count - - def _prune_naga_request_uuid_state_locked(self) -> None: - now = time.time() - expired = [ - request_uuid - for request_uuid, result in self._naga_request_uuid_results.items() - if now - result.finished_at > _NAGA_REQUEST_UUID_TTL_SECONDS - ] - for request_uuid in expired: - self._naga_request_uuid_results.pop(request_uuid, None) - - async def _register_naga_request_uuid( - self, request_uuid: str, payload_hash: str - ) -> tuple[str, Any]: - async with self._naga_request_uuid_lock: - self._prune_naga_request_uuid_state_locked() - - cached = self._naga_request_uuid_results.get(request_uuid) - if cached is not None: - if cached.payload_hash != payload_hash: - return "conflict", cached.payload_hash - return "cached", (cached.status, deepcopy(cached.payload)) - - inflight = self._naga_request_uuid_inflight.get(request_uuid) - if inflight is not None: - existing_hash, inflight_future = inflight - if existing_hash != payload_hash: - return "conflict", existing_hash - return "await", inflight_future - - owner_future: asyncio.Future[tuple[int, dict[str, Any]]] = ( - asyncio.get_running_loop().create_future() - ) - self._naga_request_uuid_inflight[request_uuid] = ( - payload_hash, - owner_future, - ) - return "owner", owner_future - - async def _finish_naga_request_uuid( - self, - request_uuid: str, - payload_hash: str, - *, - status: int, - payload: dict[str, Any], - ) -> 
None: - async with self._naga_request_uuid_lock: - inflight = self._naga_request_uuid_inflight.pop(request_uuid, None) - future = inflight[1] if inflight is not None else None - result_payload = deepcopy(payload) - self._naga_request_uuid_results[request_uuid] = _NagaRequestResult( - payload_hash=payload_hash, - status=status, - payload=result_payload, - finished_at=time.time(), - ) - self._prune_naga_request_uuid_state_locked() - if future is not None and not future.done(): - future.set_result((status, deepcopy(result_payload))) - - async def _fail_naga_request_uuid( - self, - request_uuid: str, - payload_hash: str, - exc: BaseException, - ) -> None: - async with self._naga_request_uuid_lock: - inflight = self._naga_request_uuid_inflight.pop(request_uuid, None) - future = inflight[1] if inflight is not None else None - self._prune_naga_request_uuid_state_locked() - if future is not None and not future.done(): - future.set_exception(exc) - @property def _ctx(self) -> RuntimeAPIContext: return self._context + # ------------------------------------------------------------------ + # Thin delegation wrappers — keep tests calling server._xxx(request) + # ------------------------------------------------------------------ + + # Health / System async def _health_handler(self, request: web.Request) -> Response: - _ = request - return web.json_response( - { - "ok": True, - "service": "undefined-runtime-api", - "version": __version__, - "timestamp": datetime.now().isoformat(), - } - ) + return await health.health_handler(self._ctx, request) async def _openapi_handler(self, request: web.Request) -> Response: - cfg = self._ctx.config_getter() - if not bool(getattr(cfg.api, "openapi_enabled", True)): - logger.info( - "[RuntimeAPI] OpenAPI 请求被拒绝: disabled remote=%s", request.remote - ) - return _json_error("OpenAPI disabled", status=404) - naga_routes_enabled = _naga_routes_enabled(cfg, self._ctx.naga_store) - logger.info( - "[RuntimeAPI] OpenAPI 请求: remote=%s 
naga_routes_enabled=%s", - request.remote, - naga_routes_enabled, - ) - return web.json_response(_build_openapi_spec(self._ctx, request)) + return await system.openapi_handler(self._ctx, request) async def _internal_probe_handler(self, request: web.Request) -> Response: - _ = request - cfg = self._ctx.config_getter() - queue_snapshot = ( - self._ctx.queue_manager.snapshot() if self._ctx.queue_manager else {} - ) - cognitive_queue_snapshot = ( - self._ctx.cognitive_job_queue.snapshot() - if self._ctx.cognitive_job_queue - else {} - ) - memory_storage = getattr(self._ctx.ai, "memory_storage", None) - memory_count = memory_storage.count() if memory_storage is not None else 0 - - # Skills 统计 - ai = self._ctx.ai - skills_info: dict[str, Any] = {} - if ai is not None: - tool_reg = getattr(ai, "tool_registry", None) - agent_reg = getattr(ai, "agent_registry", None) - anthropic_reg = getattr(ai, "anthropic_skill_registry", None) - skills_info["tools"] = _registry_summary(tool_reg) - skills_info["agents"] = _registry_summary(agent_reg) - skills_info["anthropic_skills"] = _registry_summary(anthropic_reg) - - # 模型配置(脱敏) - models_info: dict[str, Any] = {} - summary_model = getattr( - cfg, - "summary_model", - getattr(cfg, "agent_model", getattr(cfg, "chat_model", None)), - ) - for label in ( - "chat_model", - "vision_model", - "agent_model", - "security_model", - "naga_model", - "grok_model", - ): - mcfg = getattr(cfg, label, None) - if mcfg is not None: - models_info[label] = _build_internal_model_probe_payload(mcfg) - if summary_model is not None: - models_info["summary_model"] = _build_internal_model_probe_payload( - summary_model - ) - for label in ("embedding_model", "rerank_model"): - mcfg = getattr(cfg, label, None) - if mcfg is not None: - models_info[label] = { - "model_name": getattr(mcfg, "model_name", ""), - "api_url": _mask_url(getattr(mcfg, "api_url", "")), - } - - uptime_seconds = round(time.time() - _PROCESS_START_TIME, 1) - payload = { - "timestamp": 
datetime.now().isoformat(), - "version": __version__, - "python": f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}", - "platform": platform.system(), - "uptime_seconds": uptime_seconds, - "onebot": self._ctx.onebot.connection_status() - if self._ctx.onebot is not None - else {}, - "queues": queue_snapshot, - "memory": {"count": memory_count, "virtual_user_id": _VIRTUAL_USER_ID}, - "cognitive": { - "enabled": bool( - self._ctx.cognitive_service and self._ctx.cognitive_service.enabled - ), - "queue": cognitive_queue_snapshot, - }, - "api": { - "enabled": bool(cfg.api.enabled), - "host": cfg.api.host, - "port": cfg.api.port, - "openapi_enabled": bool(cfg.api.openapi_enabled), - }, - "skills": skills_info, - "models": models_info, - } - return web.json_response(payload) + return await system.internal_probe_handler(self._ctx, request) async def _external_probe_handler(self, request: web.Request) -> Response: - _ = request - cfg = self._ctx.config_getter() - summary_model = getattr( - cfg, - "summary_model", - getattr(cfg, "agent_model", getattr(cfg, "chat_model", None)), - ) - naga_probe = ( - _probe_http_endpoint( - name="naga_model", - base_url=cfg.naga_model.api_url, - api_key=cfg.naga_model.api_key, - model_name=cfg.naga_model.model_name, - ) - if bool(cfg.api.enabled and cfg.nagaagent_mode_enabled and cfg.naga.enabled) - else _skipped_probe( - name="naga_model", - reason="naga_integration_disabled", - model_name=cfg.naga_model.model_name, - ) - ) - checks = [ - _probe_http_endpoint( - name="chat_model", - base_url=cfg.chat_model.api_url, - api_key=cfg.chat_model.api_key, - model_name=cfg.chat_model.model_name, - ), - _probe_http_endpoint( - name="vision_model", - base_url=cfg.vision_model.api_url, - api_key=cfg.vision_model.api_key, - model_name=cfg.vision_model.model_name, - ), - _probe_http_endpoint( - name="security_model", - base_url=cfg.security_model.api_url, - api_key=cfg.security_model.api_key, - 
model_name=cfg.security_model.model_name, - ), - naga_probe, - _probe_http_endpoint( - name="agent_model", - base_url=cfg.agent_model.api_url, - api_key=cfg.agent_model.api_key, - model_name=cfg.agent_model.model_name, - ), - ] - if summary_model is not None: - checks.append( - _probe_http_endpoint( - name="summary_model", - base_url=summary_model.api_url, - api_key=summary_model.api_key, - model_name=summary_model.model_name, - ) - ) - grok_model = getattr(cfg, "grok_model", None) - if grok_model is not None: - checks.append( - _probe_http_endpoint( - name="grok_model", - base_url=getattr(grok_model, "api_url", ""), - api_key=getattr(grok_model, "api_key", ""), - model_name=getattr(grok_model, "model_name", ""), - ) - ) - checks.extend( - [ - _probe_http_endpoint( - name="embedding_model", - base_url=cfg.embedding_model.api_url, - api_key=cfg.embedding_model.api_key, - model_name=getattr(cfg.embedding_model, "model_name", ""), - ), - _probe_http_endpoint( - name="rerank_model", - base_url=cfg.rerank_model.api_url, - api_key=cfg.rerank_model.api_key, - model_name=getattr(cfg.rerank_model, "model_name", ""), - ), - _probe_ws_endpoint(cfg.onebot_ws_url), - ] - ) - results = await asyncio.gather(*checks) - ok = all(item.get("status") in {"ok", "skipped"} for item in results) - return web.json_response( - { - "ok": ok, - "timestamp": datetime.now().isoformat(), - "results": results, - } - ) + return await system.external_probe_handler(self._ctx, request) + # Memory CRUD async def _memory_handler(self, request: web.Request) -> Response: - query = str(request.query.get("q", "") or "").strip().lower() - top_k_raw = _optional_query_param(request, "top_k") - time_from_raw = _optional_query_param(request, "time_from") - time_to_raw = _optional_query_param(request, "time_to") - memory_storage = getattr(self._ctx.ai, "memory_storage", None) - if memory_storage is None: - return _json_error("Memory storage not ready", status=503) - - limit: int | None = None - if top_k_raw is 
not None: - try: - limit = int(top_k_raw) - except ValueError: - return _json_error("top_k must be an integer", status=400) - if limit <= 0: - return _json_error("top_k must be > 0", status=400) - - time_from_dt = _parse_query_time(time_from_raw) - if time_from_raw is not None and time_from_dt is None: - return _json_error("time_from must be ISO datetime", status=400) - time_to_dt = _parse_query_time(time_to_raw) - if time_to_raw is not None and time_to_dt is None: - return _json_error("time_to must be ISO datetime", status=400) - if time_from_dt and time_to_dt and time_from_dt > time_to_dt: - time_from_dt, time_to_dt = time_to_dt, time_from_dt - - records = memory_storage.get_all() - items: list[dict[str, Any]] = [] - for item in records: - created_at = str(item.created_at or "").strip() - created_dt = _parse_query_time(created_at) - if time_from_dt and created_dt and created_dt < time_from_dt: - continue - if time_to_dt and created_dt and created_dt > time_to_dt: - continue - if (time_from_dt or time_to_dt) and created_dt is None: - continue - items.append( - { - "uuid": item.uuid, - "fact": item.fact, - "created_at": created_at, - } - ) - if query: - items = [ - item - for item in items - if query in str(item.get("fact", "")).lower() - or query in str(item.get("uuid", "")).lower() - ] - - def _created_sort_key(item: dict[str, Any]) -> float: - created_dt = _parse_query_time(str(item.get("created_at") or "")) - if created_dt is None: - return float("-inf") - with suppress(OSError, OverflowError, ValueError): - return float(created_dt.timestamp()) - return float("-inf") - - items.sort(key=_created_sort_key) - if limit is not None: - items = items[:limit] - - return web.json_response( - { - "total": len(items), - "items": items, - "query": { - "q": query or "", - "top_k": limit, - "time_from": time_from_raw, - "time_to": time_to_raw, - }, - } - ) + return await memory.memory_list_handler(self._ctx, request) async def _memory_create_handler(self, request: 
web.Request) -> Response: - memory_storage = getattr(self._ctx.ai, "memory_storage", None) - if memory_storage is None: - return _json_error("Memory storage not ready", status=503) - try: - body = await request.json() - except Exception: - return _json_error("Invalid JSON", status=400) - fact = str(body.get("fact", "") or "").strip() - if not fact: - return _json_error("fact must not be empty", status=400) - new_uuid = await memory_storage.add(fact) - if new_uuid is None: - return _json_error("Failed to create memory", status=500) - # add() returns existing UUID on duplicate - existing = [m for m in memory_storage.get_all() if m.uuid == new_uuid] - item = existing[0] if existing else None - return web.json_response( - { - "uuid": new_uuid, - "fact": item.fact if item else fact, - "created_at": item.created_at if item else "", - }, - status=201, - ) + return await memory.memory_create_handler(self._ctx, request) async def _memory_update_handler(self, request: web.Request) -> Response: - memory_storage = getattr(self._ctx.ai, "memory_storage", None) - if memory_storage is None: - return _json_error("Memory storage not ready", status=503) - target_uuid = str(request.match_info.get("uuid", "")).strip() - if not target_uuid: - return _json_error("uuid is required", status=400) - try: - body = await request.json() - except Exception: - return _json_error("Invalid JSON", status=400) - fact = str(body.get("fact", "") or "").strip() - if not fact: - return _json_error("fact must not be empty", status=400) - ok = await memory_storage.update(target_uuid, fact) - if not ok: - return _json_error(f"Memory {target_uuid} not found", status=404) - return web.json_response({"uuid": target_uuid, "fact": fact, "updated": True}) + return await memory.memory_update_handler(self._ctx, request) async def _memory_delete_handler(self, request: web.Request) -> Response: - memory_storage = getattr(self._ctx.ai, "memory_storage", None) - if memory_storage is None: - return _json_error("Memory 
storage not ready", status=503) - target_uuid = str(request.match_info.get("uuid", "")).strip() - if not target_uuid: - return _json_error("uuid is required", status=400) - ok = await memory_storage.delete(target_uuid) - if not ok: - return _json_error(f"Memory {target_uuid} not found", status=404) - return web.json_response({"uuid": target_uuid, "deleted": True}) + return await memory.memory_delete_handler(self._ctx, request) + # Memes async def _meme_list_handler(self, request: web.Request) -> Response: - meme_service = self._ctx.meme_service - if meme_service is None or not meme_service.enabled: - return _json_error("Meme service disabled", status=400) - - def _parse_optional_bool(name: str) -> bool | None: - raw = request.query.get(name) - if raw is None or str(raw).strip() == "": - return None - return _to_bool(raw) - - page_raw = _optional_query_param(request, "page") - page_size_raw = _optional_query_param(request, "page_size") - top_k_raw = _optional_query_param(request, "top_k") - query = str(request.query.get("q", "") or "").strip() - query_mode = str(request.query.get("query_mode", "") or "").strip().lower() - keyword_query = str(request.query.get("keyword_query", "") or "").strip() - semantic_query = str(request.query.get("semantic_query", "") or "").strip() - try: - page = int(page_raw) if page_raw is not None else 1 - page_size = int(page_size_raw) if page_size_raw is not None else 50 - top_k = int(top_k_raw) if top_k_raw is not None else page_size - except ValueError: - return _json_error("page/page_size/top_k must be integers", status=400) - page = max(1, page) - page_size = max(1, min(200, page_size)) - top_k = max(1, top_k) - sort = str(request.query.get("sort", "updated_at") or "updated_at").strip() - - enabled_filter = _parse_optional_bool("enabled") - animated_filter = _parse_optional_bool("animated") - pinned_filter = _parse_optional_bool("pinned") - if not (query or keyword_query or semantic_query) and sort == "relevance": - sort = 
"updated_at" - - if query or keyword_query or semantic_query: - has_post_filter = any( - f is not None for f in (enabled_filter, animated_filter, pinned_filter) - ) - requested_window = max(page * page_size, top_k) - if has_post_filter or page > 1 or sort != "relevance": - fetch_k = min(500, max(requested_window * 4, top_k)) - else: - fetch_k = min(500, requested_window) - search_payload = await meme_service.search_memes( - query, - query_mode=query_mode or meme_service.default_query_mode, - keyword_query=keyword_query or None, - semantic_query=semantic_query or None, - top_k=fetch_k, - include_disabled=enabled_filter is not True, - sort=sort, - ) - filtered_items: list[dict[str, Any]] = [] - for item in list(search_payload.get("items") or []): - if ( - enabled_filter is not None - and bool(item.get("enabled")) != enabled_filter - ): - continue - if ( - animated_filter is not None - and bool(item.get("is_animated")) != animated_filter - ): - continue - if ( - pinned_filter is not None - and bool(item.get("pinned")) != pinned_filter - ): - continue - filtered_items.append(item) - offset = (page - 1) * page_size - paged_items = filtered_items[offset : offset + page_size] - window_total = len(filtered_items) - fetched_window_count = len(list(search_payload.get("items") or [])) - window_exhausted = fetched_window_count < fetch_k - has_more = bool(paged_items) and ( - offset + page_size < window_total - or (not window_exhausted and window_total >= offset + page_size) - ) - return web.json_response( - { - "ok": True, - "total": None, - "window_total": window_total, - "total_exact": False, - "page": page, - "page_size": page_size, - "has_more": has_more, - "query_mode": search_payload.get("query_mode"), - "keyword_query": search_payload.get("keyword_query"), - "semantic_query": search_payload.get("semantic_query"), - "sort": search_payload.get("sort", sort), - "items": paged_items, - } - ) - - payload = await meme_service.list_memes( - query=query, - 
enabled=enabled_filter, - animated=animated_filter, - pinned=pinned_filter, - sort=sort, - page=page, - page_size=page_size, - summary=True, - ) - return web.json_response(payload) + return await memes.meme_list_handler(self._ctx, request) async def _meme_stats_handler(self, request: web.Request) -> Response: - _ = request - meme_service = self._ctx.meme_service - if meme_service is None or not meme_service.enabled: - return _json_error("Meme service disabled", status=400) - return web.json_response(await meme_service.stats()) + return await memes.meme_stats_handler(self._ctx, request) async def _meme_detail_handler(self, request: web.Request) -> Response: - meme_service = self._ctx.meme_service - if meme_service is None or not meme_service.enabled: - return _json_error("Meme service disabled", status=400) - uid = str(request.match_info.get("uid", "")).strip() - detail = await meme_service.get_meme(uid) - if detail is None: - return _json_error("Meme not found", status=404) - return web.json_response(detail) + return await memes.meme_detail_handler(self._ctx, request) async def _meme_blob_handler(self, request: web.Request) -> Response: - meme_service = self._ctx.meme_service - if meme_service is None or not meme_service.enabled: - return _json_error("Meme service disabled", status=400) - uid = str(request.match_info.get("uid", "")).strip() - path = await meme_service.blob_path_for_uid(uid, preview=False) - if path is None: - return _json_error("Meme blob not found", status=404) - return cast(Response, web.FileResponse(path=path)) + return await memes.meme_blob_handler(self._ctx, request) async def _meme_preview_handler(self, request: web.Request) -> Response: - meme_service = self._ctx.meme_service - if meme_service is None or not meme_service.enabled: - return _json_error("Meme service disabled", status=400) - uid = str(request.match_info.get("uid", "")).strip() - path = await meme_service.blob_path_for_uid(uid, preview=True) - if path is None: - return 
_json_error("Meme preview not found", status=404) - return cast(Response, web.FileResponse(path=path)) + return await memes.meme_preview_handler(self._ctx, request) async def _meme_update_handler(self, request: web.Request) -> Response: - meme_service = self._ctx.meme_service - if meme_service is None or not meme_service.enabled: - return _json_error("Meme service disabled", status=400) - uid = str(request.match_info.get("uid", "")).strip() - try: - payload = await request.json() - except Exception: - return _json_error("Invalid JSON body", status=400) - if not isinstance(payload, dict): - return _json_error("JSON body must be an object", status=400) - updated = await meme_service.update_meme( - uid, - manual_description=payload.get("manual_description"), - tags=payload.get("tags"), - aliases=payload.get("aliases"), - enabled=payload.get("enabled") if "enabled" in payload else None, - pinned=payload.get("pinned") if "pinned" in payload else None, - ) - if updated is None: - return _json_error("Meme not found", status=404) - return web.json_response({"ok": True, "record": updated}) + return await memes.meme_update_handler(self._ctx, request) async def _meme_delete_handler(self, request: web.Request) -> Response: - meme_service = self._ctx.meme_service - if meme_service is None or not meme_service.enabled: - return _json_error("Meme service disabled", status=400) - uid = str(request.match_info.get("uid", "")).strip() - deleted = await meme_service.delete_meme(uid) - if not deleted: - return _json_error("Meme not found", status=404) - return web.json_response({"ok": True, "uid": uid}) + return await memes.meme_delete_handler(self._ctx, request) async def _meme_reanalyze_handler(self, request: web.Request) -> Response: - meme_service = self._ctx.meme_service - if meme_service is None or not meme_service.enabled: - return _json_error("Meme service disabled", status=400) - uid = str(request.match_info.get("uid", "")).strip() - job_id = await 
meme_service.enqueue_reanalyze(uid) - if not job_id: - return _json_error("Meme queue unavailable", status=503) - return web.json_response({"ok": True, "uid": uid, "job_id": job_id}) + return await memes.meme_reanalyze_handler(self._ctx, request) async def _meme_reindex_handler(self, request: web.Request) -> Response: - meme_service = self._ctx.meme_service - if meme_service is None or not meme_service.enabled: - return _json_error("Meme service disabled", status=400) - uid = str(request.match_info.get("uid", "")).strip() - job_id = await meme_service.enqueue_reindex(uid) - if not job_id: - return _json_error("Meme queue unavailable", status=503) - return web.json_response({"ok": True, "uid": uid, "job_id": job_id}) + return await memes.meme_reindex_handler(self._ctx, request) + # Cognitive async def _cognitive_events_handler(self, request: web.Request) -> Response: - cognitive_service = self._ctx.cognitive_service - if not cognitive_service or not cognitive_service.enabled: - return _json_error("Cognitive service disabled", status=400) - - query = str(request.query.get("q", "") or "").strip() - if not query: - return _json_error("q is required", status=400) - - search_kwargs: dict[str, Any] = {"query": query} - for key in ( - "target_user_id", - "target_group_id", - "sender_id", - "request_type", - "top_k", - "time_from", - "time_to", - ): - value = _optional_query_param(request, key) - if value is not None: - search_kwargs[key] = value - - results = await cognitive_service.search_events(**search_kwargs) - return web.json_response({"count": len(results), "items": results}) + return await cognitive.cognitive_events_handler(self._ctx, request) async def _cognitive_profiles_handler(self, request: web.Request) -> Response: - cognitive_service = self._ctx.cognitive_service - if not cognitive_service or not cognitive_service.enabled: - return _json_error("Cognitive service disabled", status=400) - - query = str(request.query.get("q", "") or "").strip() - if not query: - 
return _json_error("q is required", status=400) - - search_kwargs: dict[str, Any] = {"query": query} - entity_type = _optional_query_param(request, "entity_type") - if entity_type is not None: - search_kwargs["entity_type"] = entity_type - top_k = _optional_query_param(request, "top_k") - if top_k is not None: - search_kwargs["top_k"] = top_k - - results = await cognitive_service.search_profiles(**search_kwargs) - return web.json_response({"count": len(results), "items": results}) + return await cognitive.cognitive_profiles_handler(self._ctx, request) async def _cognitive_profile_handler(self, request: web.Request) -> Response: - cognitive_service = self._ctx.cognitive_service - if not cognitive_service or not cognitive_service.enabled: - return _json_error("Cognitive service disabled", status=400) - - entity_type = str(request.match_info.get("entity_type", "")).strip() - entity_id = str(request.match_info.get("entity_id", "")).strip() - if not entity_type or not entity_id: - return _json_error("entity_type/entity_id are required", status=400) - - profile = await cognitive_service.get_profile(entity_type, entity_id) - return web.json_response( - { - "entity_type": entity_type, - "entity_id": entity_id, - "profile": profile or "", - "found": bool(profile), - } - ) + return await cognitive.cognitive_profile_handler(self._ctx, request) + # Chat async def _run_webui_chat( self, *, text: str, send_output: Callable[[int, str], Awaitable[None]], ) -> str: - cfg = self._ctx.config_getter() - permission_sender_id = int(cfg.superadmin_qq) - webui_scope_key = build_attachment_scope( - user_id=_VIRTUAL_USER_ID, - request_type="private", - webui_session=True, - ) - input_segments = message_to_segments(text) - registered_input = await register_message_attachments( - registry=self._ctx.ai.attachment_registry, - segments=input_segments, - scope_key=webui_scope_key, - resolve_image_url=self._ctx.onebot.get_image, - get_forward_messages=self._ctx.onebot.get_forward_msg, - ) - 
normalized_text = registered_input.normalized_text or text - await self._ctx.history_manager.add_private_message( - user_id=_VIRTUAL_USER_ID, - text_content=normalized_text, - display_name=_VIRTUAL_USER_NAME, - user_name=_VIRTUAL_USER_NAME, - attachments=registered_input.attachments, - ) - - command = self._ctx.command_dispatcher.parse_command(normalized_text) - if command: - await self._ctx.command_dispatcher.dispatch_private( - user_id=_VIRTUAL_USER_ID, - sender_id=permission_sender_id, - command=command, - send_private_callback=send_output, - is_webui_session=True, - ) - return "command" - - current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S") - attachment_xml = ( - f"\n{attachment_refs_to_xml(registered_input.attachments)}" - if registered_input.attachments - else "" - ) - full_question = f""" - - -【WebUI 会话】 -这是一条来自 WebUI 控制台的会话请求。 -会话身份:虚拟用户 system(42)。 -权限等级:superadmin(你可按最高管理权限处理)。 -请正常进行私聊对话;如果需要结束会话,调用 end 工具。""" - virtual_sender = _WebUIVirtualSender( - _VIRTUAL_USER_ID, send_output, onebot=self._ctx.onebot - ) - - async def _get_recent_cb( - chat_id: str, msg_type: str, start: int, end: int - ) -> list[dict[str, Any]]: - return await get_recent_messages_prefer_local( - chat_id=chat_id, - msg_type=msg_type, - start=start, - end=end, - onebot_client=self._ctx.onebot, - history_manager=self._ctx.history_manager, - bot_qq=cfg.bot_qq, - attachment_registry=getattr(self._ctx.ai, "attachment_registry", None), - ) - - async with RequestContext( - request_type="private", - user_id=_VIRTUAL_USER_ID, - sender_id=permission_sender_id, - ) as ctx: - # 与 ai_coordinator 保持一致:通过 collect_context_resources 自动注入 - ai_client = self._ctx.ai - memory_storage = self._ctx.ai.memory_storage - runtime_config = self._ctx.ai.runtime_config - sender = virtual_sender - history_manager = self._ctx.history_manager - onebot_client = self._ctx.onebot - scheduler = self._ctx.scheduler - - def send_message_callback( - msg: str, reply_to: int | None = None - ) -> Awaitable[None]: - 
_ = reply_to - return send_output(_VIRTUAL_USER_ID, msg) - - get_recent_messages_callback = _get_recent_cb - get_image_url_callback = self._ctx.onebot.get_image - get_forward_msg_callback = self._ctx.onebot.get_forward_msg - resource_vars = dict(globals()) - resource_vars.update(locals()) - resources = collect_context_resources(resource_vars) - for key, value in resources.items(): - if value is not None: - ctx.set_resource(key, value) - ctx.set_resource("queue_lane", QUEUE_LANE_SUPERADMIN) - ctx.set_resource("webui_session", True) - ctx.set_resource("webui_permission", "superadmin") - - result = await self._ctx.ai.ask( - full_question, - send_message_callback=send_message_callback, - get_recent_messages_callback=get_recent_messages_callback, - get_image_url_callback=get_image_url_callback, - get_forward_msg_callback=get_forward_msg_callback, - sender=sender, - history_manager=history_manager, - onebot_client=onebot_client, - scheduler=scheduler, - extra_context={ - "is_private_chat": True, - "request_type": "private", - "user_id": _VIRTUAL_USER_ID, - "sender_name": _VIRTUAL_USER_NAME, - "webui_session": True, - "webui_permission": "superadmin", - }, - ) - - final_reply = str(result or "").strip() - if final_reply: - await send_output(_VIRTUAL_USER_ID, final_reply) - - return "chat" + return await chat.run_webui_chat(self._ctx, text=text, send_output=send_output) async def _chat_history_handler(self, request: web.Request) -> Response: - limit_raw = str(request.query.get("limit", "200") or "200").strip() - try: - limit = int(limit_raw) - except ValueError: - limit = 200 - limit = max(1, min(limit, 500)) - - getter = getattr(self._ctx.history_manager, "get_recent_private", None) - if not callable(getter): - return _json_error("History manager not ready", status=503) - - records = getter(_VIRTUAL_USER_ID, limit) - items: list[dict[str, Any]] = [] - for item in records: - if not isinstance(item, dict): - continue - content = str(item.get("message", "")).strip() - if not 
content: - continue - display_name = str(item.get("display_name", "")).strip().lower() - role = "bot" if display_name == "bot" else "user" - items.append( - { - "role": role, - "content": content, - "timestamp": str(item.get("timestamp", "") or "").strip(), - } - ) - - return web.json_response( - { - "virtual_user_id": _VIRTUAL_USER_ID, - "permission": "superadmin", - "count": len(items), - "items": items, - } - ) + return await chat.chat_history_handler(self._ctx, request) async def _chat_handler(self, request: web.Request) -> web.StreamResponse: - try: - body = await request.json() - except Exception: - return _json_error("Invalid JSON", status=400) - - text = str(body.get("message", "") or "").strip() - if not text: - return _json_error("message is required", status=400) - - stream = _to_bool(body.get("stream")) - outputs: list[str] = [] - webui_scope_key = build_attachment_scope( - user_id=_VIRTUAL_USER_ID, - request_type="private", - webui_session=True, - ) - - async def _capture_private_message(user_id: int, message: str) -> None: - _ = user_id - content = str(message or "").strip() - if not content: - return - rendered = await render_message_with_pic_placeholders( - content, - registry=self._ctx.ai.attachment_registry, - scope_key=webui_scope_key, - strict=False, - ) - if not rendered.delivery_text.strip(): - return - outputs.append(rendered.delivery_text) - await self._ctx.history_manager.add_private_message( - user_id=_VIRTUAL_USER_ID, - text_content=rendered.history_text, - display_name="Bot", - user_name="Bot", - attachments=rendered.attachments, - ) - - if not stream: - try: - mode = await self._run_webui_chat( - text=text, send_output=_capture_private_message - ) - except Exception as exc: - logger.exception("[RuntimeAPI] chat failed: %s", exc) - return _json_error("Chat failed", status=502) - return web.json_response(_build_chat_response_payload(mode, outputs)) - - response = web.StreamResponse( - status=200, - reason="OK", - headers={ - 
"Content-Type": "text/event-stream", - "Cache-Control": "no-cache", - "Connection": "keep-alive", - }, - ) - await response.prepare(request) - - message_queue: asyncio.Queue[str] = asyncio.Queue() - - async def _capture_private_message_stream(user_id: int, message: str) -> None: - output_count = len(outputs) - await _capture_private_message(user_id, message) - if len(outputs) <= output_count: - return - content = outputs[-1].strip() - if content: - await message_queue.put(content) - - task = asyncio.create_task( - self._run_webui_chat(text=text, send_output=_capture_private_message_stream) - ) - mode = "chat" - client_disconnected = False - try: - await response.write( - _sse_event( - "meta", - { - "virtual_user_id": _VIRTUAL_USER_ID, - "permission": "superadmin", - }, - ) - ) - - while True: - if request.transport is None or request.transport.is_closing(): - client_disconnected = True - break - if task.done() and message_queue.empty(): - break - try: - message = await asyncio.wait_for( - message_queue.get(), - timeout=_CHAT_SSE_KEEPALIVE_SECONDS, - ) - await response.write(_sse_event("message", {"content": message})) - except asyncio.TimeoutError: - await response.write(b": keep-alive\n\n") - - if client_disconnected: - task.cancel() - with suppress(asyncio.CancelledError): - await task - return response - - mode = await task - await response.write( - _sse_event("done", _build_chat_response_payload(mode, outputs)) - ) - except asyncio.CancelledError: - task.cancel() - with suppress(asyncio.CancelledError): - await task - raise - except (ConnectionResetError, RuntimeError): - task.cancel() - with suppress(asyncio.CancelledError): - await task - except Exception as exc: - logger.exception("[RuntimeAPI] chat stream failed: %s", exc) - if not task.done(): - task.cancel() - with suppress(asyncio.CancelledError): - await task - with suppress(Exception): - await response.write(_sse_event("error", {"error": str(exc)})) - finally: - with suppress(Exception): - await 
response.write_eof() - - return response - - # ------------------------------------------------------------------ - # Tool Invoke API - # ------------------------------------------------------------------ + return await chat.chat_handler(self._ctx, request) + # Tools def _get_filtered_tools(self) -> list[dict[str, Any]]: - """按配置过滤可用工具,返回 OpenAI function calling schema 列表。""" - cfg = self._ctx.config_getter() - api_cfg = cfg.api - ai = self._ctx.ai - if ai is None: - return [] - - tool_reg = getattr(ai, "tool_registry", None) - agent_reg = getattr(ai, "agent_registry", None) - - all_schemas: list[dict[str, Any]] = [] - if tool_reg is not None: - all_schemas.extend(tool_reg.get_tools_schema()) - - # 收集 agent schema 并缓存名称集合(避免重复调用) - agent_names: set[str] = set() - if agent_reg is not None: - agent_schemas = agent_reg.get_agents_schema() - all_schemas.extend(agent_schemas) - for schema in agent_schemas: - func = schema.get("function", {}) - name = str(func.get("name", "")) - if name: - agent_names.add(name) - - denylist: set[str] = set(api_cfg.tool_invoke_denylist) - allowlist: set[str] = set(api_cfg.tool_invoke_allowlist) - expose = api_cfg.tool_invoke_expose - - def _get_name(schema: dict[str, Any]) -> str: - func = schema.get("function", {}) - return str(func.get("name", "")) - - # 1. 先排除黑名单 - if denylist: - all_schemas = [s for s in all_schemas if _get_name(s) not in denylist] - - # 2. 白名单非空时仅保留匹配项 - if allowlist: - return [s for s in all_schemas if _get_name(s) in allowlist] - - # 3. 按 expose 过滤 - if expose == "all": - return all_schemas - - def _is_tool(name: str) -> bool: - return "." not in name and name not in agent_names - - def _is_toolset(name: str) -> bool: - return "." 
in name and not name.startswith("mcp.") - - filtered: list[dict[str, Any]] = [] - for schema in all_schemas: - name = _get_name(schema) - if not name: - continue - if expose == "tools" and _is_tool(name): - filtered.append(schema) - elif expose == "toolsets" and _is_toolset(name): - filtered.append(schema) - elif expose == "tools+toolsets" and (_is_tool(name) or _is_toolset(name)): - filtered.append(schema) - elif expose == "agents" and name in agent_names: - filtered.append(schema) - - return filtered + return tools.get_filtered_tools(self._ctx) def _get_agent_tool_names(self) -> set[str]: - ai = self._ctx.ai - if ai is None: - return set() - - agent_reg = getattr(ai, "agent_registry", None) - if agent_reg is None: - return set() - - agent_names: set[str] = set() - for schema in agent_reg.get_agents_schema(): - func = schema.get("function", {}) - name = str(func.get("name", "")) - if name: - agent_names.add(name) - return agent_names - - def _resolve_tool_invoke_timeout( - self, tool_name: str, timeout: int - ) -> float | None: - if tool_name in self._get_agent_tool_names(): - return None - return float(timeout) - - async def _await_tool_invoke_result( - self, - awaitable: Awaitable[Any], - *, - timeout: float | None, - ) -> Any: - if timeout is None or timeout <= 0: - return await awaitable - try: - return await asyncio.wait_for(awaitable, timeout=timeout) - except asyncio.TimeoutError as exc: - raise _ToolInvokeExecutionTimeoutError from exc + return tools.get_agent_tool_names(self._ctx) async def _tools_list_handler(self, request: web.Request) -> Response: - _ = request - cfg = self._ctx.config_getter() - if not cfg.api.tool_invoke_enabled: - return _json_error("Tool invoke API is disabled", status=403) - - tools = self._get_filtered_tools() - return web.json_response({"count": len(tools), "tools": tools}) + return await tools.tools_list_handler(self._ctx, request) async def _tools_invoke_handler(self, request: web.Request) -> Response: - cfg = 
self._ctx.config_getter() - if not cfg.api.tool_invoke_enabled: - return _json_error("Tool invoke API is disabled", status=403) - - try: - body = await request.json() - except Exception: - return _json_error("Invalid JSON", status=400) - - if not isinstance(body, dict): - return _json_error("Request body must be a JSON object", status=400) - - tool_name = str(body.get("tool_name", "") or "").strip() - if not tool_name: - return _json_error("tool_name is required", status=400) - - args = body.get("args") - if not isinstance(args, dict): - return _json_error("args must be a JSON object", status=400) - - # 验证工具是否在允许列表中 - filtered_tools = self._get_filtered_tools() - available_names: set[str] = set() - for schema in filtered_tools: - func = schema.get("function", {}) - name = str(func.get("name", "")) - if name: - available_names.add(name) - - if tool_name not in available_names: - caller_ip = request.remote or "unknown" - logger.warning( - "[ToolInvoke] 请求拒绝: tool=%s reason=not_available caller_ip=%s", - tool_name, - caller_ip, - ) - return _json_error(f"Tool '{tool_name}' is not available", status=404) - - # 解析回调配置 - callback_cfg = body.get("callback") - use_callback = False - callback_url = "" - callback_headers: dict[str, str] = {} - if isinstance(callback_cfg, dict) and _to_bool(callback_cfg.get("enabled")): - callback_url = str(callback_cfg.get("url", "") or "").strip() - if not callback_url: - return _json_error( - "callback.url is required when callback is enabled", - status=400, - ) - url_error = _validate_callback_url(callback_url) - if url_error: - return _json_error(url_error, status=400) - raw_headers = callback_cfg.get("headers") - if isinstance(raw_headers, dict): - callback_headers = {str(k): str(v) for k, v in raw_headers.items()} - use_callback = True - - request_id = _uuid.uuid4().hex - caller_ip = request.remote or "unknown" - logger.info( - "[ToolInvoke] 收到请求: request_id=%s tool=%s caller_ip=%s", - request_id, - tool_name, - caller_ip, + return 
await tools.tools_invoke_handler( + self._ctx, self._background_tasks, request ) - if use_callback: - # 异步执行 + 回调 - task = asyncio.create_task( - self._execute_and_callback( - request_id=request_id, - tool_name=tool_name, - args=args, - body_context=body.get("context"), - callback_url=callback_url, - callback_headers=callback_headers, - timeout=cfg.api.tool_invoke_timeout, - callback_timeout=cfg.api.tool_invoke_callback_timeout, - ) - ) - self._background_tasks.add(task) - task.add_done_callback(self._background_tasks.discard) - return web.json_response( - { - "ok": True, - "request_id": request_id, - "tool_name": tool_name, - "status": "accepted", - } - ) - - # 同步执行 - result = await self._execute_tool_invoke( - request_id=request_id, - tool_name=tool_name, - args=args, - body_context=body.get("context"), - timeout=cfg.api.tool_invoke_timeout, - ) - return web.json_response(result) - async def _execute_tool_invoke( self, *, @@ -1473,146 +280,8 @@ async def _execute_tool_invoke( body_context: Any, timeout: int, ) -> dict[str, Any]: - """执行工具调用并返回结果字典。""" - ai = self._ctx.ai - if ai is None: - return { - "ok": False, - "request_id": request_id, - "tool_name": tool_name, - "error": "AI client not ready", - "duration_ms": 0, - } - - # 解析请求上下文 - ctx_data: dict[str, Any] = {} - if isinstance(body_context, dict): - ctx_data = body_context - request_type = str(ctx_data.get("request_type", "api") or "api") - group_id = ctx_data.get("group_id") - user_id = ctx_data.get("user_id") - sender_id = ctx_data.get("sender_id") - - args_keys = list(args.keys()) - logger.info( - "[ToolInvoke] 开始执行: request_id=%s tool=%s args_keys=%s", - request_id, - tool_name, - args_keys, - ) - - start = time.perf_counter() - effective_timeout = self._resolve_tool_invoke_timeout(tool_name, timeout) - try: - async with RequestContext( - request_type=request_type, - group_id=int(group_id) if group_id is not None else None, - user_id=int(user_id) if user_id is not None else None, - 
sender_id=int(sender_id) if sender_id is not None else None, - ) as ctx: - # 注入核心服务资源 - if self._ctx.sender is not None: - ctx.set_resource("sender", self._ctx.sender) - if self._ctx.history_manager is not None: - ctx.set_resource("history_manager", self._ctx.history_manager) - runtime_config = getattr(ai, "runtime_config", None) - if runtime_config is not None: - ctx.set_resource("runtime_config", runtime_config) - memory_storage = getattr(ai, "memory_storage", None) - if memory_storage is not None: - ctx.set_resource("memory_storage", memory_storage) - if self._ctx.onebot is not None: - ctx.set_resource("onebot_client", self._ctx.onebot) - if self._ctx.scheduler is not None: - ctx.set_resource("scheduler", self._ctx.scheduler) - if self._ctx.cognitive_service is not None: - ctx.set_resource("cognitive_service", self._ctx.cognitive_service) - if self._ctx.meme_service is not None: - ctx.set_resource("meme_service", self._ctx.meme_service) - - tool_context: dict[str, Any] = { - "request_type": request_type, - "request_id": request_id, - } - if group_id is not None: - tool_context["group_id"] = int(group_id) - if user_id is not None: - tool_context["user_id"] = int(user_id) - if sender_id is not None: - tool_context["sender_id"] = int(sender_id) - - tool_manager = getattr(ai, "tool_manager", None) - if tool_manager is None: - raise RuntimeError("ToolManager not available") - - raw_result = await self._await_tool_invoke_result( - tool_manager.execute_tool(tool_name, args, tool_context), - timeout=effective_timeout, - ) - - elapsed_ms = round((time.perf_counter() - start) * 1000, 1) - result_str = str(raw_result or "") - logger.info( - "[ToolInvoke] 执行完成: request_id=%s tool=%s ok=true " - "duration_ms=%s result_len=%d", - request_id, - tool_name, - elapsed_ms, - len(result_str), - ) - return { - "ok": True, - "request_id": request_id, - "tool_name": tool_name, - "result": result_str, - "duration_ms": elapsed_ms, - } - - except _ToolInvokeExecutionTimeoutError: - 
elapsed_ms = round((time.perf_counter() - start) * 1000, 1) - logger.warning( - "[ToolInvoke] 执行超时: request_id=%s tool=%s timeout=%ds", - request_id, - tool_name, - timeout, - ) - return { - "ok": False, - "request_id": request_id, - "tool_name": tool_name, - "error": f"Execution timed out after {timeout}s", - "duration_ms": elapsed_ms, - } - except Exception as exc: - elapsed_ms = round((time.perf_counter() - start) * 1000, 1) - logger.exception( - "[ToolInvoke] 执行失败: request_id=%s tool=%s error=%s", - request_id, - tool_name, - exc, - ) - return { - "ok": False, - "request_id": request_id, - "tool_name": tool_name, - "error": str(exc), - "duration_ms": elapsed_ms, - } - - async def _execute_and_callback( - self, - *, - request_id: str, - tool_name: str, - args: dict[str, Any], - body_context: Any, - callback_url: str, - callback_headers: dict[str, str], - timeout: int, - callback_timeout: int, - ) -> None: - """异步执行工具并发送回调。""" - result = await self._execute_tool_invoke( + return await tools.execute_tool_invoke( + self._ctx, request_id=request_id, tool_name=tool_name, args=args, @@ -1620,371 +289,17 @@ async def _execute_and_callback( timeout=timeout, ) - payload = { - "request_id": result["request_id"], - "tool_name": result["tool_name"], - "ok": result["ok"], - "result": result.get("result"), - "duration_ms": result.get("duration_ms", 0), - "error": result.get("error"), - } - - try: - cb_timeout = ClientTimeout(total=callback_timeout) - async with ClientSession(timeout=cb_timeout) as session: - # aiohttp json= 自动设置 Content-Type,无需手动指定 - async with session.post( - callback_url, - json=payload, - headers=callback_headers or None, - ) as resp: - logger.info( - "[ToolInvoke] 回调发送: request_id=%s url=%s status=%d", - request_id, - _mask_url(callback_url), - resp.status, - ) - except Exception as exc: - logger.warning( - "[ToolInvoke] 回调失败: request_id=%s url=%s error=%s", - request_id, - _mask_url(callback_url), - exc, - ) - - # 
------------------------------------------------------------------ - # Naga Bind / Send / Unbind API - # ------------------------------------------------------------------ - + # Naga def _verify_naga_api_key(self, request: web.Request) -> str | None: - """校验 Naga 共享密钥,返回错误信息或 None 表示通过。""" - cfg = self._ctx.config_getter() - expected = cfg.naga.api_key - if not expected: - return "naga api_key not configured" - auth_header = request.headers.get("Authorization", "") - if not auth_header.startswith("Bearer "): - return "missing or invalid Authorization header" - provided = auth_header[7:] - import secrets as _secrets - - if not _secrets.compare_digest(provided, expected): - return "invalid api_key" - return None + return naga.verify_naga_api_key(self._ctx, request) async def _naga_bind_callback_handler(self, request: web.Request) -> Response: - """POST /api/v1/naga/bind/callback — Naga 绑定回调。""" - trace_id = _uuid.uuid4().hex[:8] - auth_err = self._verify_naga_api_key(request) - if auth_err is not None: - logger.warning( - "[NagaBindCallback] 鉴权失败: trace=%s remote=%s err=%s", - trace_id, - getattr(request, "remote", None), - auth_err, - ) - return _json_error("Unauthorized", status=401) - - try: - body = await request.json() - except Exception: - return _json_error("Invalid JSON", status=400) - - bind_uuid = str(body.get("bind_uuid", "") or "").strip() - naga_id = str(body.get("naga_id", "") or "").strip() - status = str(body.get("status", "") or "").strip().lower() - delivery_signature = str(body.get("delivery_signature", "") or "").strip() - reason = str(body.get("reason", "") or "").strip() - if not bind_uuid or not naga_id: - return _json_error("bind_uuid and naga_id are required", status=400) - if status not in {"approved", "rejected"}: - return _json_error("status must be 'approved' or 'rejected'", status=400) - logger.info( - "[NagaBindCallback] 请求开始: trace=%s remote=%s naga_id=%s bind_uuid=%s status=%s reason=%s signature=%s", - trace_id, - getattr(request, 
"remote", None), - naga_id, - bind_uuid, - status, - _short_text_preview(reason, limit=60), - delivery_signature[:12] + "..." if delivery_signature else "", - ) - - naga_store = self._ctx.naga_store - if naga_store is None: - return _json_error("Naga integration not available", status=503) - - sender = self._ctx.sender - if status == "approved": - if not delivery_signature: - return _json_error( - "delivery_signature is required when approved", status=400 - ) - binding, created, err = await naga_store.activate_binding( - bind_uuid=bind_uuid, - naga_id=naga_id, - delivery_signature=delivery_signature, - ) - if err: - logger.warning( - "[NagaBindCallback] 激活失败: trace=%s naga_id=%s bind_uuid=%s err=%s", - trace_id, - naga_id, - bind_uuid, - err.message, - ) - return _json_error(err.message, status=err.http_status) - logger.info( - "[NagaBindCallback] 激活完成: trace=%s naga_id=%s bind_uuid=%s created=%s qq=%s", - trace_id, - naga_id, - bind_uuid, - created, - binding.qq_id if binding is not None else "", - ) - if created and binding is not None and sender is not None: - try: - await sender.send_private_message( - binding.qq_id, - f"🎉 你的 Naga 绑定已生效\nnaga_id: {naga_id}", - ) - except Exception as exc: - logger.warning("[NagaBindCallback] 通知绑定成功失败: %s", exc) - return web.json_response( - { - "ok": True, - "status": "approved", - "idempotent": not created, - "naga_id": naga_id, - "bind_uuid": bind_uuid, - } - ) - - pending, removed, err = await naga_store.reject_binding( - bind_uuid=bind_uuid, - naga_id=naga_id, - reason=reason, - ) - if err: - logger.warning( - "[NagaBindCallback] 拒绝失败: trace=%s naga_id=%s bind_uuid=%s err=%s", - trace_id, - naga_id, - bind_uuid, - err.message, - ) - return _json_error(err.message, status=err.http_status) - logger.info( - "[NagaBindCallback] 拒绝完成: trace=%s naga_id=%s bind_uuid=%s removed=%s qq=%s", - trace_id, - naga_id, - bind_uuid, - removed, - pending.qq_id if pending is not None else "", - ) - if removed and pending is not None and 
sender is not None: - try: - detail = f"\n原因: {reason}" if reason else "" - await sender.send_private_message( - pending.qq_id, - f"❌ 你的 Naga 绑定被远端拒绝\nnaga_id: {naga_id}{detail}", - ) - except Exception as exc: - logger.warning("[NagaBindCallback] 通知绑定拒绝失败: %s", exc) - return web.json_response( - { - "ok": True, - "status": "rejected", - "idempotent": not removed, - "naga_id": naga_id, - "bind_uuid": bind_uuid, - } - ) + return await naga.naga_bind_callback_handler(self._ctx, request) async def _naga_messages_send_handler(self, request: web.Request) -> Response: - """POST /api/v1/naga/messages/send — 验签后发送消息。""" - from Undefined.api.naga_store import mask_token - - trace_id = _uuid.uuid4().hex[:8] - auth_err = self._verify_naga_api_key(request) - if auth_err is not None: - logger.warning("[NagaSend] 鉴权失败: trace=%s err=%s", trace_id, auth_err) - return _json_error("Unauthorized", status=401) - - try: - body = await request.json() - except Exception: - return _json_error("Invalid JSON", status=400) - - bind_uuid = str(body.get("bind_uuid", "") or "").strip() - naga_id = str(body.get("naga_id", "") or "").strip() - delivery_signature = str(body.get("delivery_signature", "") or "").strip() - request_uuid = str(body.get("uuid", "") or "").strip() - target = body.get("target") - message = body.get("message") - if not bind_uuid or not naga_id or not delivery_signature: - return _json_error( - "bind_uuid, naga_id and delivery_signature are required", - status=400, - ) - if not isinstance(target, dict): - return _json_error("target object is required", status=400) - if not isinstance(message, dict): - return _json_error("message object is required", status=400) - - raw_target_qq = target.get("qq_id") - raw_target_group = target.get("group_id") - if raw_target_qq is None or raw_target_group is None: - return _json_error( - "target.qq_id and target.group_id are required", status=400 - ) - try: - target_qq = int(raw_target_qq) - target_group = int(raw_target_group) - except 
Exception: - return _json_error( - "target.qq_id and target.group_id must be integers", status=400 - ) - mode = str(target.get("mode", "") or "").strip().lower() - if mode not in {"private", "group", "both"}: - return _json_error( - "target.mode must be 'private', 'group', or 'both'", status=400 - ) - - fmt = str(message.get("format", "text") or "text").strip().lower() - content = str(message.get("content", "") or "").strip() - if fmt not in {"text", "markdown", "html"}: - return _json_error( - "message.format must be 'text', 'markdown', or 'html'", status=400 - ) - if not content: - return _json_error("message.content is required", status=400) - - message_key = _naga_message_digest( - bind_uuid=bind_uuid, - naga_id=naga_id, - target_qq=target_qq, - target_group=target_group, - mode=mode, - message_format=fmt, - content=content, - ) - logger.info( - "[NagaSend] 请求开始: trace=%s remote=%s naga_id=%s bind_uuid=%s request_uuid=%s mode=%s fmt=%s qq=%s group=%s key=%s content_len=%s preview=%s signature=%s", - trace_id, - getattr(request, "remote", None), - naga_id, - bind_uuid, - request_uuid, - mode, - fmt, - target_qq, - target_group, - message_key, - len(content), - _short_text_preview(content), - mask_token(delivery_signature), + return await naga.naga_messages_send_handler( + self._ctx, self._naga_state, request ) - if mode == "both": - logger.warning( - "[NagaSend] 上游请求显式要求双路投递: trace=%s naga_id=%s bind_uuid=%s request_uuid=%s key=%s", - trace_id, - naga_id, - bind_uuid, - request_uuid, - message_key, - ) - inflight_count = await self._track_naga_send_start(message_key) - if inflight_count > 1: - logger.warning( - "[NagaSend] 检测到相同 payload 并发请求: trace=%s naga_id=%s bind_uuid=%s request_uuid=%s key=%s inflight=%s", - trace_id, - naga_id, - bind_uuid, - request_uuid, - message_key, - inflight_count, - ) - try: - if request_uuid: - dedupe_action, dedupe_value = await self._register_naga_request_uuid( - request_uuid, message_key - ) - if dedupe_action == "conflict": - 
logger.warning( - "[NagaSend] uuid 与历史 payload 冲突: trace=%s naga_id=%s bind_uuid=%s uuid=%s key=%s", - trace_id, - naga_id, - bind_uuid, - request_uuid, - message_key, - ) - return _json_error("uuid reused with different payload", status=409) - if dedupe_action == "cached": - cached_status, cached_payload = dedupe_value - logger.warning( - "[NagaSend] 命中已完成幂等结果,直接复用: trace=%s naga_id=%s bind_uuid=%s request_uuid=%s key=%s", - trace_id, - naga_id, - bind_uuid, - request_uuid, - message_key, - ) - return web.json_response( - deepcopy(cached_payload), - status=int(cached_status), - ) - if dedupe_action == "await": - wait_future = dedupe_value - logger.warning( - "[NagaSend] 命中进行中幂等请求,等待首个结果: trace=%s naga_id=%s bind_uuid=%s request_uuid=%s key=%s", - trace_id, - naga_id, - bind_uuid, - request_uuid, - message_key, - ) - cached_status, cached_payload = await wait_future - return web.json_response( - deepcopy(cached_payload), - status=int(cached_status), - ) - - response = await self._naga_messages_send_impl( - naga_id=naga_id, - bind_uuid=bind_uuid, - delivery_signature=delivery_signature, - target_qq=target_qq, - target_group=target_group, - mode=mode, - message_format=fmt, - content=content, - trace_id=trace_id, - message_key=message_key, - ) - if request_uuid: - await self._finish_naga_request_uuid( - request_uuid, - message_key, - status=response.status, - payload=_parse_response_payload(response), - ) - return response - except Exception as exc: - if request_uuid: - await self._fail_naga_request_uuid(request_uuid, message_key, exc) - raise - finally: - remaining = await self._track_naga_send_done(message_key) - logger.info( - "[NagaSend] 请求退出: trace=%s naga_id=%s bind_uuid=%s request_uuid=%s key=%s inflight_remaining=%s", - trace_id, - naga_id, - bind_uuid, - request_uuid, - message_key, - remaining, - ) async def _naga_messages_send_impl( self, @@ -2000,492 +315,19 @@ async def _naga_messages_send_impl( trace_id: str, message_key: str, ) -> Response: - from 
Undefined.api.naga_store import mask_token - - naga_store = self._ctx.naga_store - if naga_store is None: - logger.warning( - "[NagaSend] NagaStore 不可用: trace=%s naga_id=%s bind_uuid=%s", - trace_id, - naga_id, - bind_uuid, - ) - return _json_error("Naga integration not available", status=503) - - binding, err_msg = await naga_store.acquire_delivery( + return await naga.naga_messages_send_impl( + self._ctx, naga_id=naga_id, bind_uuid=bind_uuid, delivery_signature=delivery_signature, + target_qq=target_qq, + target_group=target_group, + mode=mode, + message_format=message_format, + content=content, + trace_id=trace_id, + message_key=message_key, ) - if binding is None: - logger.warning( - "[NagaSend] 签名校验失败: trace=%s naga_id=%s bind_uuid=%s reason=%s signature=%s", - trace_id, - naga_id, - bind_uuid, - err_msg.message if err_msg is not None else "unknown_error", - mask_token(delivery_signature), - ) - return _json_error( - err_msg.message if err_msg is not None else "delivery not available", - status=err_msg.http_status if err_msg is not None else 403, - ) - - logger.info( - "[NagaSend] 投递凭证已占用: trace=%s naga_id=%s bind_uuid=%s key=%s qq=%s group=%s", - trace_id, - naga_id, - bind_uuid, - message_key, - binding.qq_id, - binding.group_id, - ) - try: - if target_qq != binding.qq_id or target_group != binding.group_id: - logger.warning( - "[NagaSend] 目标不匹配: trace=%s naga_id=%s bind_uuid=%s target_qq=%s target_group=%s bound_qq=%s bound_group=%s", - trace_id, - naga_id, - bind_uuid, - target_qq, - target_group, - binding.qq_id, - binding.group_id, - ) - return _json_error("target does not match bound qq/group", status=403) - - cfg = self._ctx.config_getter() - if mode == "group" and binding.group_id not in cfg.naga.allowed_groups: - logger.warning( - "[NagaSend] 群投递被策略拒绝: trace=%s naga_id=%s bind_uuid=%s group=%s", - trace_id, - naga_id, - bind_uuid, - binding.group_id, - ) - return _json_error( - "bound group is not in naga.allowed_groups", status=403 - ) - - sender = 
self._ctx.sender - if sender is None: - logger.warning( - "[NagaSend] sender 不可用: trace=%s naga_id=%s bind_uuid=%s", - trace_id, - naga_id, - bind_uuid, - ) - return _json_error("sender not available", status=503) - - moderation: dict[str, Any] - naga_cfg = getattr(cfg, "naga", None) - moderation_enabled = bool(getattr(naga_cfg, "moderation_enabled", True)) - security = getattr(self._ctx.command_dispatcher, "security", None) - if not moderation_enabled: - moderation = { - "status": "skipped_disabled", - "blocked": False, - "categories": [], - "message": "Naga moderation disabled by config; message sent without moderation block", - "model_name": "", - } - logger.warning( - "[NagaSend] 审核已禁用,直接放行: trace=%s naga_id=%s bind_uuid=%s key=%s", - trace_id, - naga_id, - bind_uuid, - message_key, - ) - elif security is None or not hasattr(security, "moderate_naga_message"): - moderation = { - "status": "error_allowed", - "blocked": False, - "categories": [], - "message": "Naga moderation service unavailable; message sent without moderation block", - "model_name": "", - } - logger.warning( - "[NagaSend] 审核服务不可用,按允许发送: trace=%s naga_id=%s bind_uuid=%s", - trace_id, - naga_id, - bind_uuid, - ) - else: - logger.info( - "[NagaSend] 审核开始: trace=%s naga_id=%s bind_uuid=%s key=%s fmt=%s content_len=%s", - trace_id, - naga_id, - bind_uuid, - message_key, - message_format, - len(content), - ) - result = await security.moderate_naga_message( - message_format=message_format, - content=content, - ) - moderation = { - "status": result.status, - "blocked": result.blocked, - "categories": result.categories, - "message": result.message, - "model_name": result.model_name, - } - logger.info( - "[NagaSend] 审核完成: trace=%s naga_id=%s bind_uuid=%s key=%s blocked=%s status=%s model=%s categories=%s", - trace_id, - naga_id, - bind_uuid, - message_key, - result.blocked, - result.status, - result.model_name, - ",".join(result.categories) or "-", - ) - if moderation["blocked"]: - logger.warning( - 
"[NagaSend] 审核拦截: trace=%s naga_id=%s bind_uuid=%s key=%s reason=%s", - trace_id, - naga_id, - bind_uuid, - message_key, - moderation["message"], - ) - return web.json_response( - { - "ok": False, - "error": "message blocked by moderation", - "moderation": moderation, - }, - status=403, - ) - - send_content: str | None = content if message_format == "text" else None - image_path: str | None = None - tmp_path: str | None = None - rendered = False - render_fallback = False - if message_format in {"markdown", "html"}: - import tempfile - - try: - html_str = content - if message_format == "markdown": - html_str = await render_markdown_to_html(content) - fd, tmp_path = tempfile.mkstemp(suffix=".png", prefix="naga_send_") - os.close(fd) - await render_html_to_image(html_str, tmp_path) - image_path = tmp_path - rendered = True - logger.info( - "[NagaSend] 富文本渲染成功: trace=%s naga_id=%s bind_uuid=%s key=%s fmt=%s image=%s", - trace_id, - naga_id, - bind_uuid, - message_key, - message_format, - Path(tmp_path).name if tmp_path is not None else "", - ) - except Exception as exc: - logger.warning( - "[NagaSend] 渲染失败,回退文本发送: trace=%s naga_id=%s bind_uuid=%s key=%s err=%s", - trace_id, - naga_id, - bind_uuid, - message_key, - exc, - ) - send_content = content - render_fallback = True - - sent_private = False - sent_group = False - group_policy_blocked = False - - async def _ensure_delivery_active() -> tuple[Any, Response | None]: - current_binding, live_err = await naga_store.ensure_delivery_active( - naga_id=naga_id, - bind_uuid=bind_uuid, - ) - if current_binding is None: - logger.warning( - "[NagaSend] 投递中止: trace=%s naga_id=%s bind_uuid=%s key=%s reason=%s", - trace_id, - naga_id, - bind_uuid, - message_key, - live_err.message - if live_err is not None - else "delivery no longer active", - ) - return None, web.json_response( - { - "ok": False, - "error": ( - live_err.message - if live_err is not None - else "delivery no longer active" - ), - "sent_private": sent_private, - 
"sent_group": sent_group, - "moderation": moderation, - }, - status=live_err.http_status if live_err is not None else 409, - ) - return current_binding, None - - try: - cq_image: str | None = None - if image_path is not None: - file_uri = Path(image_path).resolve().as_uri() - cq_image = f"[CQ:image,file={file_uri}]" - - if mode in {"private", "both"}: - current_binding, abort_response = await _ensure_delivery_active() - if abort_response is not None: - return abort_response - logger.info( - "[NagaSend] 私聊投递开始: trace=%s naga_id=%s bind_uuid=%s key=%s qq=%s", - trace_id, - naga_id, - bind_uuid, - message_key, - current_binding.qq_id, - ) - try: - if send_content is not None: - await sender.send_private_message( - current_binding.qq_id, send_content - ) - elif cq_image is not None: - await sender.send_private_message( - current_binding.qq_id, cq_image - ) - sent_private = True - logger.info( - "[NagaSend] 私聊投递成功: trace=%s naga_id=%s bind_uuid=%s key=%s qq=%s", - trace_id, - naga_id, - bind_uuid, - message_key, - current_binding.qq_id, - ) - except Exception as exc: - logger.warning( - "[NagaSend] 私聊发送失败: trace=%s naga_id=%s qq=%d key=%s err=%s", - trace_id, - naga_id, - current_binding.qq_id, - message_key, - exc, - ) - - if mode in {"group", "both"}: - current_binding, abort_response = await _ensure_delivery_active() - if abort_response is not None: - return abort_response - current_cfg = self._ctx.config_getter() - if current_binding.group_id not in current_cfg.naga.allowed_groups: - group_policy_blocked = True - logger.warning( - "[NagaSend] 群投递被策略阻止: trace=%s naga_id=%s bind_uuid=%s key=%s group=%s", - trace_id, - naga_id, - bind_uuid, - message_key, - current_binding.group_id, - ) - else: - logger.info( - "[NagaSend] 群投递开始: trace=%s naga_id=%s bind_uuid=%s key=%s group=%s", - trace_id, - naga_id, - bind_uuid, - message_key, - current_binding.group_id, - ) - try: - if send_content is not None: - await sender.send_group_message( - current_binding.group_id, 
send_content - ) - elif cq_image is not None: - await sender.send_group_message( - current_binding.group_id, cq_image - ) - sent_group = True - logger.info( - "[NagaSend] 群投递成功: trace=%s naga_id=%s bind_uuid=%s key=%s group=%s", - trace_id, - naga_id, - bind_uuid, - message_key, - current_binding.group_id, - ) - except Exception as exc: - logger.warning( - "[NagaSend] 群聊发送失败: trace=%s naga_id=%s group=%d key=%s err=%s", - trace_id, - naga_id, - current_binding.group_id, - message_key, - exc, - ) - finally: - if tmp_path is not None: - try: - os.unlink(tmp_path) - except OSError: - pass - - if mode == "private" and not sent_private: - return web.json_response( - { - "ok": False, - "error": "private delivery failed", - "sent_private": sent_private, - "sent_group": sent_group, - "moderation": moderation, - }, - status=502, - ) - if mode == "group" and not sent_group: - return web.json_response( - { - "ok": False, - "error": "group delivery failed", - "sent_private": sent_private, - "sent_group": sent_group, - "moderation": moderation, - }, - status=502, - ) - if mode == "both" and not (sent_private or sent_group): - if group_policy_blocked: - return web.json_response( - { - "ok": False, - "error": "bound group is not in naga.allowed_groups", - "sent_private": sent_private, - "sent_group": sent_group, - "moderation": moderation, - }, - status=403, - ) - return web.json_response( - { - "ok": False, - "error": "all deliveries failed", - "sent_private": sent_private, - "sent_group": sent_group, - "moderation": moderation, - }, - status=502, - ) - - await naga_store.record_usage(naga_id, bind_uuid=bind_uuid) - partial_success = mode == "both" and (sent_private != sent_group) - logger.info( - "[NagaSend] 请求完成: trace=%s naga_id=%s bind_uuid=%s key=%s sent_private=%s sent_group=%s partial=%s rendered=%s fallback=%s", - trace_id, - naga_id, - bind_uuid, - message_key, - sent_private, - sent_group, - partial_success, - rendered, - render_fallback, - ) - return 
web.json_response( - { - "ok": True, - "naga_id": naga_id, - "bind_uuid": bind_uuid, - "sent_private": sent_private, - "sent_group": sent_group, - "partial_success": partial_success, - "delivery_status": ( - "partial_success" if partial_success else "full_success" - ), - "rendered": rendered, - "render_fallback": render_fallback, - "moderation": moderation, - } - ) - finally: - await naga_store.release_delivery(bind_uuid=bind_uuid) async def _naga_unbind_handler(self, request: web.Request) -> Response: - """POST /api/v1/naga/unbind — 远端主动解绑。""" - trace_id = _uuid.uuid4().hex[:8] - auth_err = self._verify_naga_api_key(request) - if auth_err is not None: - logger.warning( - "[NagaUnbind] 鉴权失败: trace=%s remote=%s err=%s", - trace_id, - getattr(request, "remote", None), - auth_err, - ) - return _json_error("Unauthorized", status=401) - - try: - body = await request.json() - except Exception: - return _json_error("Invalid JSON", status=400) - - bind_uuid = str(body.get("bind_uuid", "") or "").strip() - naga_id = str(body.get("naga_id", "") or "").strip() - delivery_signature = str(body.get("delivery_signature", "") or "").strip() - if not bind_uuid or not naga_id or not delivery_signature: - return _json_error( - "bind_uuid, naga_id and delivery_signature are required", - status=400, - ) - logger.info( - "[NagaUnbind] 请求开始: trace=%s remote=%s naga_id=%s bind_uuid=%s signature=%s", - trace_id, - getattr(request, "remote", None), - naga_id, - bind_uuid, - delivery_signature[:12] + "...", - ) - - naga_store = self._ctx.naga_store - if naga_store is None: - return _json_error("Naga integration not available", status=503) - - binding, changed, err = await naga_store.revoke_binding( - naga_id, - expected_bind_uuid=bind_uuid, - delivery_signature=delivery_signature, - ) - if binding is None: - logger.warning( - "[NagaUnbind] 吊销失败: trace=%s naga_id=%s bind_uuid=%s err=%s", - trace_id, - naga_id, - bind_uuid, - err.message if err is not None else "binding not found", - ) - 
return _json_error( - err.message if err is not None else "binding not found", - status=err.http_status if err is not None else 404, - ) - logger.info( - "[NagaUnbind] 吊销完成: trace=%s naga_id=%s bind_uuid=%s changed=%s qq=%s group=%s", - trace_id, - naga_id, - bind_uuid, - changed, - binding.qq_id, - binding.group_id, - ) - return web.json_response( - { - "ok": True, - "idempotent": not changed, - "naga_id": naga_id, - "bind_uuid": bind_uuid, - } - ) + return await naga.naga_unbind_handler(self._ctx, request) diff --git a/src/Undefined/api/routes/__init__.py b/src/Undefined/api/routes/__init__.py new file mode 100644 index 0000000..8a999d1 --- /dev/null +++ b/src/Undefined/api/routes/__init__.py @@ -0,0 +1 @@ +"""Runtime API route modules.""" diff --git a/src/Undefined/api/routes/chat.py b/src/Undefined/api/routes/chat.py new file mode 100644 index 0000000..536435c --- /dev/null +++ b/src/Undefined/api/routes/chat.py @@ -0,0 +1,359 @@ +"""Chat route handlers extracted from the Runtime API application.""" + +from __future__ import annotations + +import asyncio +import logging +from contextlib import suppress +from datetime import datetime +from typing import Any, Awaitable, Callable + +from aiohttp import web +from aiohttp.web_response import Response + +from Undefined.api._context import RuntimeAPIContext +from Undefined.api._helpers import ( + _VIRTUAL_USER_ID, + _WebUIVirtualSender, + _build_chat_response_payload, + _json_error, + _sse_event, + _to_bool, +) +from Undefined.attachments import ( + attachment_refs_to_xml, + build_attachment_scope, + register_message_attachments, + render_message_with_pic_placeholders, +) +from Undefined.context import RequestContext +from Undefined.context_resource_registry import collect_context_resources +from Undefined.services.queue_manager import QUEUE_LANE_SUPERADMIN +from Undefined.utils.common import message_to_segments +from Undefined.utils.recent_messages import get_recent_messages_prefer_local +from Undefined.utils.xml 
import escape_xml_attr, escape_xml_text + +logger = logging.getLogger(__name__) + +_VIRTUAL_USER_NAME = "system" +_CHAT_SSE_KEEPALIVE_SECONDS = 10.0 + + +async def run_webui_chat( + ctx: RuntimeAPIContext, + *, + text: str, + send_output: Callable[[int, str], Awaitable[None]], +) -> str: + """Execute a single WebUI chat turn (command dispatch or AI ask).""" + + cfg = ctx.config_getter() + permission_sender_id = int(cfg.superadmin_qq) + webui_scope_key = build_attachment_scope( + user_id=_VIRTUAL_USER_ID, + request_type="private", + webui_session=True, + ) + input_segments = message_to_segments(text) + registered_input = await register_message_attachments( + registry=ctx.ai.attachment_registry, + segments=input_segments, + scope_key=webui_scope_key, + resolve_image_url=ctx.onebot.get_image, + get_forward_messages=ctx.onebot.get_forward_msg, + ) + normalized_text = registered_input.normalized_text or text + await ctx.history_manager.add_private_message( + user_id=_VIRTUAL_USER_ID, + text_content=normalized_text, + display_name=_VIRTUAL_USER_NAME, + user_name=_VIRTUAL_USER_NAME, + attachments=registered_input.attachments, + ) + + command = ctx.command_dispatcher.parse_command(normalized_text) + if command: + await ctx.command_dispatcher.dispatch_private( + user_id=_VIRTUAL_USER_ID, + sender_id=permission_sender_id, + command=command, + send_private_callback=send_output, + is_webui_session=True, + ) + return "command" + + current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S") + attachment_xml = ( + f"\n{attachment_refs_to_xml(registered_input.attachments)}" + if registered_input.attachments + else "" + ) + full_question = f"""{escape_xml_text(normalized_text)} {attachment_xml} -+ + +【WebUI 会话】 +这是一条来自 WebUI 控制台的会话请求。 +会话身份:虚拟用户 system(42)。 +权限等级:superadmin(你可按最高管理权限处理)。 +请正常进行私聊对话;如果需要结束会话,调用 end 工具。""" + virtual_sender = _WebUIVirtualSender( + _VIRTUAL_USER_ID, send_output, onebot=ctx.onebot + ) + + async def _get_recent_cb( + chat_id: str, msg_type: str, start: 
int, end: int + ) -> list[dict[str, Any]]: + return await get_recent_messages_prefer_local( + chat_id=chat_id, + msg_type=msg_type, + start=start, + end=end, + onebot_client=ctx.onebot, + history_manager=ctx.history_manager, + bot_qq=cfg.bot_qq, + attachment_registry=getattr(ctx.ai, "attachment_registry", None), + ) + + async with RequestContext( + request_type="private", + user_id=_VIRTUAL_USER_ID, + sender_id=permission_sender_id, + ) as rctx: + ai_client = ctx.ai # noqa: F841 + memory_storage = ctx.ai.memory_storage # noqa: F841 + runtime_config = ctx.ai.runtime_config # noqa: F841 + sender = virtual_sender # noqa: F841 + history_manager = ctx.history_manager # noqa: F841 + onebot_client = ctx.onebot # noqa: F841 + scheduler = ctx.scheduler # noqa: F841 + + def send_message_callback( + msg: str, reply_to: int | None = None + ) -> Awaitable[None]: + _ = reply_to + return send_output(_VIRTUAL_USER_ID, msg) + + get_recent_messages_callback = _get_recent_cb # noqa: F841 + get_image_url_callback = ctx.onebot.get_image # noqa: F841 + get_forward_msg_callback = ctx.onebot.get_forward_msg # noqa: F841 + resource_vars = dict(globals()) + resource_vars.update(locals()) + resources = collect_context_resources(resource_vars) + for key, value in resources.items(): + if value is not None: + rctx.set_resource(key, value) + rctx.set_resource("queue_lane", QUEUE_LANE_SUPERADMIN) + rctx.set_resource("webui_session", True) + rctx.set_resource("webui_permission", "superadmin") + + result = await ctx.ai.ask( + full_question, + send_message_callback=send_message_callback, + get_recent_messages_callback=get_recent_messages_callback, + get_image_url_callback=get_image_url_callback, + get_forward_msg_callback=get_forward_msg_callback, + sender=sender, + history_manager=history_manager, + onebot_client=onebot_client, + scheduler=scheduler, + extra_context={ + "is_private_chat": True, + "request_type": "private", + "user_id": _VIRTUAL_USER_ID, + "sender_name": _VIRTUAL_USER_NAME, + 
"webui_session": True, + "webui_permission": "superadmin", + }, + ) + + final_reply = str(result or "").strip() + if final_reply: + await send_output(_VIRTUAL_USER_ID, final_reply) + + return "chat" + + +async def chat_history_handler( + ctx: RuntimeAPIContext, request: web.Request +) -> Response: + """Return recent WebUI chat history.""" + + limit_raw = str(request.query.get("limit", "200") or "200").strip() + try: + limit = int(limit_raw) + except ValueError: + limit = 200 + limit = max(1, min(limit, 500)) + + getter = getattr(ctx.history_manager, "get_recent_private", None) + if not callable(getter): + return _json_error("History manager not ready", status=503) + + records = getter(_VIRTUAL_USER_ID, limit) + items: list[dict[str, Any]] = [] + for item in records: + if not isinstance(item, dict): + continue + content = str(item.get("message", "")).strip() + if not content: + continue + display_name = str(item.get("display_name", "")).strip().lower() + role = "bot" if display_name == "bot" else "user" + items.append( + { + "role": role, + "content": content, + "timestamp": str(item.get("timestamp", "") or "").strip(), + } + ) + + return web.json_response( + { + "virtual_user_id": _VIRTUAL_USER_ID, + "permission": "superadmin", + "count": len(items), + "items": items, + } + ) + + +async def chat_handler( + ctx: RuntimeAPIContext, request: web.Request +) -> web.StreamResponse: + """Handle a WebUI chat request (non-streaming or SSE streaming).""" + + try: + body = await request.json() + except Exception: + return _json_error("Invalid JSON", status=400) + + text = str(body.get("message", "") or "").strip() + if not text: + return _json_error("message is required", status=400) + + stream = _to_bool(body.get("stream")) + outputs: list[str] = [] + webui_scope_key = build_attachment_scope( + user_id=_VIRTUAL_USER_ID, + request_type="private", + webui_session=True, + ) + + async def _capture_private_message(user_id: int, message: str) -> None: + _ = user_id + content = 
str(message or "").strip() + if not content: + return + rendered = await render_message_with_pic_placeholders( + content, + registry=ctx.ai.attachment_registry, + scope_key=webui_scope_key, + strict=False, + ) + if not rendered.delivery_text.strip(): + return + outputs.append(rendered.delivery_text) + await ctx.history_manager.add_private_message( + user_id=_VIRTUAL_USER_ID, + text_content=rendered.history_text, + display_name="Bot", + user_name="Bot", + attachments=rendered.attachments, + ) + + if not stream: + try: + mode = await run_webui_chat( + ctx, text=text, send_output=_capture_private_message + ) + except Exception as exc: + logger.exception("[RuntimeAPI] chat failed: %s", exc) + return _json_error("Chat failed", status=502) + return web.json_response(_build_chat_response_payload(mode, outputs)) + + response = web.StreamResponse( + status=200, + reason="OK", + headers={ + "Content-Type": "text/event-stream", + "Cache-Control": "no-cache", + "Connection": "keep-alive", + }, + ) + await response.prepare(request) + + message_queue: asyncio.Queue[str] = asyncio.Queue() + + async def _capture_private_message_stream(user_id: int, message: str) -> None: + output_count = len(outputs) + await _capture_private_message(user_id, message) + if len(outputs) <= output_count: + return + content = outputs[-1].strip() + if content: + await message_queue.put(content) + + task = asyncio.create_task( + run_webui_chat(ctx, text=text, send_output=_capture_private_message_stream) + ) + mode = "chat" + client_disconnected = False + try: + await response.write( + _sse_event( + "meta", + { + "virtual_user_id": _VIRTUAL_USER_ID, + "permission": "superadmin", + }, + ) + ) + + while True: + if request.transport is None or request.transport.is_closing(): + client_disconnected = True + break + if task.done() and message_queue.empty(): + break + try: + message = await asyncio.wait_for( + message_queue.get(), + timeout=_CHAT_SSE_KEEPALIVE_SECONDS, + ) + await 
response.write(_sse_event("message", {"content": message})) + except asyncio.TimeoutError: + await response.write(b": keep-alive\n\n") + + if client_disconnected: + task.cancel() + with suppress(asyncio.CancelledError): + await task + return response + + mode = await task + await response.write( + _sse_event("done", _build_chat_response_payload(mode, outputs)) + ) + except asyncio.CancelledError: + task.cancel() + with suppress(asyncio.CancelledError): + await task + raise + except (ConnectionResetError, RuntimeError): + task.cancel() + with suppress(asyncio.CancelledError): + await task + except Exception as exc: + logger.exception("[RuntimeAPI] chat stream failed: %s", exc) + if not task.done(): + task.cancel() + with suppress(asyncio.CancelledError): + await task + with suppress(Exception): + await response.write(_sse_event("error", {"error": str(exc)})) + finally: + with suppress(Exception): + await response.write_eof() + + return response diff --git a/src/Undefined/api/routes/cognitive.py b/src/Undefined/api/routes/cognitive.py new file mode 100644 index 0000000..6609699 --- /dev/null +++ b/src/Undefined/api/routes/cognitive.py @@ -0,0 +1,86 @@ +"""Cognitive event & profile routes.""" + +from __future__ import annotations + +from typing import Any + +from aiohttp import web +from aiohttp.web_response import Response + +from Undefined.api._context import RuntimeAPIContext +from Undefined.api._helpers import _json_error, _optional_query_param + + +async def cognitive_events_handler( + ctx: RuntimeAPIContext, request: web.Request +) -> Response: + cognitive_service = ctx.cognitive_service + if not cognitive_service or not cognitive_service.enabled: + return _json_error("Cognitive service disabled", status=400) + + query = str(request.query.get("q", "") or "").strip() + if not query: + return _json_error("q is required", status=400) + + search_kwargs: dict[str, Any] = {"query": query} + for key in ( + "target_user_id", + "target_group_id", + "sender_id", + 
"request_type", + "top_k", + "time_from", + "time_to", + ): + value = _optional_query_param(request, key) + if value is not None: + search_kwargs[key] = value + + results = await cognitive_service.search_events(**search_kwargs) + return web.json_response({"count": len(results), "items": results}) + + +async def cognitive_profiles_handler( + ctx: RuntimeAPIContext, request: web.Request +) -> Response: + cognitive_service = ctx.cognitive_service + if not cognitive_service or not cognitive_service.enabled: + return _json_error("Cognitive service disabled", status=400) + + query = str(request.query.get("q", "") or "").strip() + if not query: + return _json_error("q is required", status=400) + + search_kwargs: dict[str, Any] = {"query": query} + entity_type = _optional_query_param(request, "entity_type") + if entity_type is not None: + search_kwargs["entity_type"] = entity_type + top_k = _optional_query_param(request, "top_k") + if top_k is not None: + search_kwargs["top_k"] = top_k + + results = await cognitive_service.search_profiles(**search_kwargs) + return web.json_response({"count": len(results), "items": results}) + + +async def cognitive_profile_handler( + ctx: RuntimeAPIContext, request: web.Request +) -> Response: + cognitive_service = ctx.cognitive_service + if not cognitive_service or not cognitive_service.enabled: + return _json_error("Cognitive service disabled", status=400) + + entity_type = str(request.match_info.get("entity_type", "")).strip() + entity_id = str(request.match_info.get("entity_id", "")).strip() + if not entity_type or not entity_id: + return _json_error("entity_type/entity_id are required", status=400) + + profile = await cognitive_service.get_profile(entity_type, entity_id) + return web.json_response( + { + "entity_type": entity_type, + "entity_id": entity_id, + "profile": profile or "", + "found": bool(profile), + } + ) diff --git a/src/Undefined/api/routes/health.py b/src/Undefined/api/routes/health.py new file mode 100644 index 
0000000..bbdb020 --- /dev/null +++ b/src/Undefined/api/routes/health.py @@ -0,0 +1,23 @@ +"""Health check route.""" + +from __future__ import annotations + +from datetime import datetime + +from aiohttp.web_response import Response +from aiohttp import web + +from Undefined import __version__ +from Undefined.api._context import RuntimeAPIContext + + +async def health_handler(ctx: RuntimeAPIContext, request: web.Request) -> Response: + _ = ctx, request + return web.json_response( + { + "ok": True, + "service": "undefined-runtime-api", + "version": __version__, + "timestamp": datetime.now().isoformat(), + } + ) diff --git a/src/Undefined/api/routes/memes.py b/src/Undefined/api/routes/memes.py new file mode 100644 index 0000000..dcef98a --- /dev/null +++ b/src/Undefined/api/routes/memes.py @@ -0,0 +1,222 @@ +"""Meme management route handlers.""" + +from __future__ import annotations + +from typing import Any, cast + +from aiohttp import web +from aiohttp.web_response import Response + +from Undefined.api._context import RuntimeAPIContext +from Undefined.api._helpers import _json_error, _optional_query_param, _to_bool + + +async def meme_list_handler(ctx: RuntimeAPIContext, request: web.Request) -> Response: + meme_service = ctx.meme_service + if meme_service is None or not meme_service.enabled: + return _json_error("Meme service disabled", status=400) + + def _parse_optional_bool(name: str) -> bool | None: + raw = request.query.get(name) + if raw is None or str(raw).strip() == "": + return None + return _to_bool(raw) + + page_raw = _optional_query_param(request, "page") + page_size_raw = _optional_query_param(request, "page_size") + top_k_raw = _optional_query_param(request, "top_k") + query = str(request.query.get("q", "") or "").strip() + query_mode = str(request.query.get("query_mode", "") or "").strip().lower() + keyword_query = str(request.query.get("keyword_query", "") or "").strip() + semantic_query = str(request.query.get("semantic_query", "") or "").strip() + 
try: + page = int(page_raw) if page_raw is not None else 1 + page_size = int(page_size_raw) if page_size_raw is not None else 50 + top_k = int(top_k_raw) if top_k_raw is not None else page_size + except ValueError: + return _json_error("page/page_size/top_k must be integers", status=400) + page = max(1, page) + page_size = max(1, min(200, page_size)) + top_k = max(1, top_k) + sort = str(request.query.get("sort", "updated_at") or "updated_at").strip() + + enabled_filter = _parse_optional_bool("enabled") + animated_filter = _parse_optional_bool("animated") + pinned_filter = _parse_optional_bool("pinned") + if not (query or keyword_query or semantic_query) and sort == "relevance": + sort = "updated_at" + + if query or keyword_query or semantic_query: + has_post_filter = any( + f is not None for f in (enabled_filter, animated_filter, pinned_filter) + ) + requested_window = max(page * page_size, top_k) + if has_post_filter or page > 1 or sort != "relevance": + fetch_k = min(500, max(requested_window * 4, top_k)) + else: + fetch_k = min(500, requested_window) + search_payload = await meme_service.search_memes( + query, + query_mode=query_mode or meme_service.default_query_mode, + keyword_query=keyword_query or None, + semantic_query=semantic_query or None, + top_k=fetch_k, + include_disabled=enabled_filter is not True, + sort=sort, + ) + filtered_items: list[dict[str, Any]] = [] + for item in list(search_payload.get("items") or []): + if ( + enabled_filter is not None + and bool(item.get("enabled")) != enabled_filter + ): + continue + if ( + animated_filter is not None + and bool(item.get("is_animated")) != animated_filter + ): + continue + if pinned_filter is not None and bool(item.get("pinned")) != pinned_filter: + continue + filtered_items.append(item) + offset = (page - 1) * page_size + paged_items = filtered_items[offset : offset + page_size] + window_total = len(filtered_items) + fetched_window_count = len(list(search_payload.get("items") or [])) + window_exhausted 
= fetched_window_count < fetch_k + has_more = bool(paged_items) and ( + offset + page_size < window_total + or (not window_exhausted and window_total >= offset + page_size) + ) + return web.json_response( + { + "ok": True, + "total": None, + "window_total": window_total, + "total_exact": False, + "page": page, + "page_size": page_size, + "has_more": has_more, + "query_mode": search_payload.get("query_mode"), + "keyword_query": search_payload.get("keyword_query"), + "semantic_query": search_payload.get("semantic_query"), + "sort": search_payload.get("sort", sort), + "items": paged_items, + } + ) + + payload = await meme_service.list_memes( + query=query, + enabled=enabled_filter, + animated=animated_filter, + pinned=pinned_filter, + sort=sort, + page=page, + page_size=page_size, + summary=True, + ) + return web.json_response(payload) + + +async def meme_stats_handler(ctx: RuntimeAPIContext, request: web.Request) -> Response: + _ = request + meme_service = ctx.meme_service + if meme_service is None or not meme_service.enabled: + return _json_error("Meme service disabled", status=400) + return web.json_response(await meme_service.stats()) + + +async def meme_detail_handler(ctx: RuntimeAPIContext, request: web.Request) -> Response: + meme_service = ctx.meme_service + if meme_service is None or not meme_service.enabled: + return _json_error("Meme service disabled", status=400) + uid = str(request.match_info.get("uid", "")).strip() + detail = await meme_service.get_meme(uid) + if detail is None: + return _json_error("Meme not found", status=404) + return web.json_response(detail) + + +async def meme_blob_handler(ctx: RuntimeAPIContext, request: web.Request) -> Response: + meme_service = ctx.meme_service + if meme_service is None or not meme_service.enabled: + return _json_error("Meme service disabled", status=400) + uid = str(request.match_info.get("uid", "")).strip() + path = await meme_service.blob_path_for_uid(uid, preview=False) + if path is None: + return 
_json_error("Meme blob not found", status=404) + return cast(Response, web.FileResponse(path=path)) + + +async def meme_preview_handler( + ctx: RuntimeAPIContext, request: web.Request +) -> Response: + meme_service = ctx.meme_service + if meme_service is None or not meme_service.enabled: + return _json_error("Meme service disabled", status=400) + uid = str(request.match_info.get("uid", "")).strip() + path = await meme_service.blob_path_for_uid(uid, preview=True) + if path is None: + return _json_error("Meme preview not found", status=404) + return cast(Response, web.FileResponse(path=path)) + + +async def meme_update_handler(ctx: RuntimeAPIContext, request: web.Request) -> Response: + meme_service = ctx.meme_service + if meme_service is None or not meme_service.enabled: + return _json_error("Meme service disabled", status=400) + uid = str(request.match_info.get("uid", "")).strip() + try: + payload = await request.json() + except Exception: + return _json_error("Invalid JSON body", status=400) + if not isinstance(payload, dict): + return _json_error("JSON body must be an object", status=400) + updated = await meme_service.update_meme( + uid, + manual_description=payload.get("manual_description"), + tags=payload.get("tags"), + aliases=payload.get("aliases"), + enabled=payload.get("enabled") if "enabled" in payload else None, + pinned=payload.get("pinned") if "pinned" in payload else None, + ) + if updated is None: + return _json_error("Meme not found", status=404) + return web.json_response({"ok": True, "record": updated}) + + +async def meme_delete_handler(ctx: RuntimeAPIContext, request: web.Request) -> Response: + meme_service = ctx.meme_service + if meme_service is None or not meme_service.enabled: + return _json_error("Meme service disabled", status=400) + uid = str(request.match_info.get("uid", "")).strip() + deleted = await meme_service.delete_meme(uid) + if not deleted: + return _json_error("Meme not found", status=404) + return web.json_response({"ok": True, 
"uid": uid}) + + +async def meme_reanalyze_handler( + ctx: RuntimeAPIContext, request: web.Request +) -> Response: + meme_service = ctx.meme_service + if meme_service is None or not meme_service.enabled: + return _json_error("Meme service disabled", status=400) + uid = str(request.match_info.get("uid", "")).strip() + job_id = await meme_service.enqueue_reanalyze(uid) + if not job_id: + return _json_error("Meme queue unavailable", status=503) + return web.json_response({"ok": True, "uid": uid, "job_id": job_id}) + + +async def meme_reindex_handler( + ctx: RuntimeAPIContext, request: web.Request +) -> Response: + meme_service = ctx.meme_service + if meme_service is None or not meme_service.enabled: + return _json_error("Meme service disabled", status=400) + uid = str(request.match_info.get("uid", "")).strip() + job_id = await meme_service.enqueue_reindex(uid) + if not job_id: + return _json_error("Meme queue unavailable", status=503) + return web.json_response({"ok": True, "uid": uid, "job_id": job_id}) diff --git a/src/Undefined/api/routes/memory.py b/src/Undefined/api/routes/memory.py new file mode 100644 index 0000000..30c114e --- /dev/null +++ b/src/Undefined/api/routes/memory.py @@ -0,0 +1,156 @@ +"""Memory CRUD routes.""" + +from __future__ import annotations + +from contextlib import suppress +from typing import Any + +from aiohttp import web +from aiohttp.web_response import Response + +from Undefined.api._context import RuntimeAPIContext +from Undefined.api._helpers import _json_error, _optional_query_param, _parse_query_time + + +async def memory_list_handler(ctx: RuntimeAPIContext, request: web.Request) -> Response: + query = str(request.query.get("q", "") or "").strip().lower() + top_k_raw = _optional_query_param(request, "top_k") + time_from_raw = _optional_query_param(request, "time_from") + time_to_raw = _optional_query_param(request, "time_to") + memory_storage = getattr(ctx.ai, "memory_storage", None) + if memory_storage is None: + return 
_json_error("Memory storage not ready", status=503) + + limit: int | None = None + if top_k_raw is not None: + try: + limit = int(top_k_raw) + except ValueError: + return _json_error("top_k must be an integer", status=400) + if limit <= 0: + return _json_error("top_k must be > 0", status=400) + + time_from_dt = _parse_query_time(time_from_raw) + if time_from_raw is not None and time_from_dt is None: + return _json_error("time_from must be ISO datetime", status=400) + time_to_dt = _parse_query_time(time_to_raw) + if time_to_raw is not None and time_to_dt is None: + return _json_error("time_to must be ISO datetime", status=400) + if time_from_dt and time_to_dt and time_from_dt > time_to_dt: + time_from_dt, time_to_dt = time_to_dt, time_from_dt + + records = memory_storage.get_all() + items: list[dict[str, Any]] = [] + for item in records: + created_at = str(item.created_at or "").strip() + created_dt = _parse_query_time(created_at) + if time_from_dt and created_dt and created_dt < time_from_dt: + continue + if time_to_dt and created_dt and created_dt > time_to_dt: + continue + if (time_from_dt or time_to_dt) and created_dt is None: + continue + items.append( + { + "uuid": item.uuid, + "fact": item.fact, + "created_at": created_at, + } + ) + if query: + items = [ + item + for item in items + if query in str(item.get("fact", "")).lower() + or query in str(item.get("uuid", "")).lower() + ] + + def _created_sort_key(item: dict[str, Any]) -> float: + created_dt = _parse_query_time(str(item.get("created_at") or "")) + if created_dt is None: + return float("-inf") + with suppress(OSError, OverflowError, ValueError): + return float(created_dt.timestamp()) + return float("-inf") + + items.sort(key=_created_sort_key) + if limit is not None: + items = items[:limit] + + return web.json_response( + { + "total": len(items), + "items": items, + "query": { + "q": query or "", + "top_k": limit, + "time_from": time_from_raw, + "time_to": time_to_raw, + }, + } + ) + + +async def 
memory_create_handler( + ctx: RuntimeAPIContext, request: web.Request +) -> Response: + memory_storage = getattr(ctx.ai, "memory_storage", None) + if memory_storage is None: + return _json_error("Memory storage not ready", status=503) + try: + body = await request.json() + except Exception: + return _json_error("Invalid JSON", status=400) + fact = str(body.get("fact", "") or "").strip() + if not fact: + return _json_error("fact must not be empty", status=400) + new_uuid = await memory_storage.add(fact) + if new_uuid is None: + return _json_error("Failed to create memory", status=500) + existing = [m for m in memory_storage.get_all() if m.uuid == new_uuid] + item = existing[0] if existing else None + return web.json_response( + { + "uuid": new_uuid, + "fact": item.fact if item else fact, + "created_at": item.created_at if item else "", + }, + status=201, + ) + + +async def memory_update_handler( + ctx: RuntimeAPIContext, request: web.Request +) -> Response: + memory_storage = getattr(ctx.ai, "memory_storage", None) + if memory_storage is None: + return _json_error("Memory storage not ready", status=503) + target_uuid = str(request.match_info.get("uuid", "")).strip() + if not target_uuid: + return _json_error("uuid is required", status=400) + try: + body = await request.json() + except Exception: + return _json_error("Invalid JSON", status=400) + fact = str(body.get("fact", "") or "").strip() + if not fact: + return _json_error("fact must not be empty", status=400) + ok = await memory_storage.update(target_uuid, fact) + if not ok: + return _json_error(f"Memory {target_uuid} not found", status=404) + return web.json_response({"uuid": target_uuid, "fact": fact, "updated": True}) + + +async def memory_delete_handler( + ctx: RuntimeAPIContext, request: web.Request +) -> Response: + memory_storage = getattr(ctx.ai, "memory_storage", None) + if memory_storage is None: + return _json_error("Memory storage not ready", status=503) + target_uuid = 
str(request.match_info.get("uuid", "")).strip() + if not target_uuid: + return _json_error("uuid is required", status=400) + ok = await memory_storage.delete(target_uuid) + if not ok: + return _json_error(f"Memory {target_uuid} not found", status=404) + return web.json_response({"uuid": target_uuid, "deleted": True}) diff --git a/src/Undefined/api/routes/naga.py b/src/Undefined/api/routes/naga.py new file mode 100644 index 0000000..b76375d --- /dev/null +++ b/src/Undefined/api/routes/naga.py @@ -0,0 +1,897 @@ +"""Naga integration route handlers. + +Extracted from ``RuntimeAPI`` methods into free functions so they can be +registered declaratively in the route table. +""" + +from __future__ import annotations + +import logging +import os +import uuid as _uuid +from copy import deepcopy +from pathlib import Path +from typing import Any + +from aiohttp import web +from aiohttp.web_response import Response + +from Undefined.api._context import RuntimeAPIContext +from Undefined.api._helpers import ( + _json_error, + _naga_message_digest, + _parse_response_payload, + _short_text_preview, +) +from Undefined.api._naga_state import NagaState +from Undefined.render import render_html_to_image, render_markdown_to_html + +logger = logging.getLogger(__name__) + + +# ------------------------------------------------------------------ +# Auth helper +# ------------------------------------------------------------------ + + +def verify_naga_api_key(ctx: RuntimeAPIContext, request: web.Request) -> str | None: + """校验 Naga 共享密钥,返回错误信息或 ``None`` 表示通过。""" + import secrets as _secrets + + cfg = ctx.config_getter() + expected = cfg.naga.api_key + if not expected: + return "naga api_key not configured" + auth_header = request.headers.get("Authorization", "") + if not auth_header.startswith("Bearer "): + return "missing or invalid Authorization header" + provided = auth_header[7:] + if not _secrets.compare_digest(provided, expected): + return "invalid api_key" + return None + + +# 
# ------------------------------------------------------------------
# POST /api/v1/naga/bind/callback
# ------------------------------------------------------------------


async def naga_bind_callback_handler(
    ctx: RuntimeAPIContext, request: web.Request
) -> Response:
    """Handle ``POST /api/v1/naga/bind/callback`` (Naga bind callback).

    The remote side reports whether a pending binding was ``approved`` or
    ``rejected``. The request is authenticated with ``verify_naga_api_key``.

    Args:
        ctx: Runtime context; ``naga_store`` and ``sender`` are read from it.
        request: Incoming request whose JSON object carries ``bind_uuid``,
            ``naga_id``, ``status`` and optionally ``delivery_signature``
            and ``reason``.

    Returns:
        A JSON response:

        * 401 when the API key check fails.
        * 400 for an unparsable body, a non-object body, missing
          ``bind_uuid``/``naga_id``, an unknown ``status``, or an approval
          without ``delivery_signature``.
        * 503 when ``ctx.naga_store`` is ``None``.
        * The store's own ``http_status`` when activation/rejection fails.
        * 200 with ``ok``/``status``/``idempotent``/``naga_id``/``bind_uuid``
          otherwise.
    """
    trace_id = _uuid.uuid4().hex[:8]
    auth_err = verify_naga_api_key(ctx, request)
    if auth_err is not None:
        logger.warning(
            "[NagaBindCallback] 鉴权失败: trace=%s remote=%s err=%s",
            trace_id,
            getattr(request, "remote", None),
            auth_err,
        )
        return _json_error("Unauthorized", status=401)

    try:
        body = await request.json()
    except Exception:
        return _json_error("Invalid JSON", status=400)
    # A syntactically valid body may still be a list/str/number/null; without
    # this guard the ``body.get`` calls below would raise and surface as 500.
    if not isinstance(body, dict):
        return _json_error("JSON body must be an object", status=400)

    bind_uuid = str(body.get("bind_uuid", "") or "").strip()
    naga_id = str(body.get("naga_id", "") or "").strip()
    status = str(body.get("status", "") or "").strip().lower()
    delivery_signature = str(body.get("delivery_signature", "") or "").strip()
    reason = str(body.get("reason", "") or "").strip()
    if not bind_uuid or not naga_id:
        return _json_error("bind_uuid and naga_id are required", status=400)
    if status not in {"approved", "rejected"}:
        return _json_error("status must be 'approved' or 'rejected'", status=400)
    logger.info(
        "[NagaBindCallback] 请求开始: trace=%s remote=%s naga_id=%s bind_uuid=%s status=%s reason=%s signature=%s",
        trace_id,
        getattr(request, "remote", None),
        naga_id,
        bind_uuid,
        status,
        _short_text_preview(reason, limit=60),
        # Only a short prefix of the signature is logged.
        delivery_signature[:12] + "..." if delivery_signature else "",
    )

    naga_store = ctx.naga_store
    if naga_store is None:
        return _json_error("Naga integration not available", status=503)

    sender = ctx.sender
    if status == "approved":
        if not delivery_signature:
            return _json_error(
                "delivery_signature is required when approved", status=400
            )
        binding, created, err = await naga_store.activate_binding(
            bind_uuid=bind_uuid,
            naga_id=naga_id,
            delivery_signature=delivery_signature,
        )
        if err:
            logger.warning(
                "[NagaBindCallback] 激活失败: trace=%s naga_id=%s bind_uuid=%s err=%s",
                trace_id,
                naga_id,
                bind_uuid,
                err.message,
            )
            return _json_error(err.message, status=err.http_status)
        logger.info(
            "[NagaBindCallback] 激活完成: trace=%s naga_id=%s bind_uuid=%s created=%s qq=%s",
            trace_id,
            naga_id,
            bind_uuid,
            created,
            binding.qq_id if binding is not None else "",
        )
        # Notify the user only when this call actually created the binding;
        # the notification is best-effort and never fails the callback.
        if created and binding is not None and sender is not None:
            try:
                await sender.send_private_message(
                    binding.qq_id,
                    f"🎉 你的 Naga 绑定已生效\nnaga_id: {naga_id}",
                )
            except Exception as exc:
                logger.warning("[NagaBindCallback] 通知绑定成功失败: %s", exc)
        return web.json_response(
            {
                "ok": True,
                "status": "approved",
                "idempotent": not created,
                "naga_id": naga_id,
                "bind_uuid": bind_uuid,
            }
        )

    # --- rejected ---
    pending, removed, err = await naga_store.reject_binding(
        bind_uuid=bind_uuid,
        naga_id=naga_id,
        reason=reason,
    )
    if err:
        logger.warning(
            "[NagaBindCallback] 拒绝失败: trace=%s naga_id=%s bind_uuid=%s err=%s",
            trace_id,
            naga_id,
            bind_uuid,
            err.message,
        )
        return _json_error(err.message, status=err.http_status)
    logger.info(
        "[NagaBindCallback] 拒绝完成: trace=%s naga_id=%s bind_uuid=%s removed=%s qq=%s",
        trace_id,
        naga_id,
        bind_uuid,
        removed,
        pending.qq_id if pending is not None else "",
    )
    # Same best-effort notification as the approved branch.
    if removed and pending is not None and sender is not None:
        try:
            detail = f"\n原因: {reason}" if reason else ""
            await sender.send_private_message(
                pending.qq_id,
                f"❌ 你的 Naga 绑定被远端拒绝\nnaga_id: {naga_id}{detail}",
            )
        except Exception as exc:
            logger.warning("[NagaBindCallback] 通知绑定拒绝失败: %s", exc)
    return web.json_response(
        {
            "ok": True,
            "status": "rejected",
            "idempotent": not removed,
            "naga_id": naga_id,
            "bind_uuid": bind_uuid,
        }
    )


# ------------------------------------------------------------------
# POST /api/v1/naga/messages/send
# ------------------------------------------------------------------


async def naga_messages_send_handler(
    ctx: RuntimeAPIContext,
    naga_state: NagaState,
    request: web.Request,
) -> Response:
    """Handle ``POST /api/v1/naga/messages/send``.

    Authenticates and validates the request, applies request-``uuid``
    de-duplication through ``naga_state``, then delegates the actual
    delivery to ``naga_messages_send_impl``.

    Args:
        ctx: Runtime context forwarded to ``naga_messages_send_impl``.
        naga_state: Tracks in-flight sends per payload digest and the
            outcome registered for each request ``uuid``.
        request: Incoming request whose JSON object carries ``bind_uuid``,
            ``naga_id``, ``delivery_signature``, optional ``uuid``, a
            ``target`` object (``qq_id``, ``group_id``, ``mode``) and a
            ``message`` object (``format``, ``content``).

    Returns:
        * 401 when the API key check fails.
        * 400 for any validation failure.
        * 409 when ``uuid`` was already used with a different payload.
        * A replay of the stored status/payload when ``uuid`` already
          completed or is still in flight.
        * Otherwise the response produced by ``naga_messages_send_impl``.

    Raises:
        Exception: Anything raised while sending is re-raised after being
            reported to ``naga_state.fail_request_uuid``.
    """
    from Undefined.api.naga_store import mask_token

    trace_id = _uuid.uuid4().hex[:8]
    auth_err = verify_naga_api_key(ctx, request)
    if auth_err is not None:
        logger.warning("[NagaSend] 鉴权失败: trace=%s err=%s", trace_id, auth_err)
        return _json_error("Unauthorized", status=401)

    try:
        body = await request.json()
    except Exception:
        return _json_error("Invalid JSON", status=400)
    # Reject valid-but-non-object JSON up front instead of crashing on
    # ``body.get`` below.
    if not isinstance(body, dict):
        return _json_error("JSON body must be an object", status=400)

    bind_uuid = str(body.get("bind_uuid", "") or "").strip()
    naga_id = str(body.get("naga_id", "") or "").strip()
    delivery_signature = str(body.get("delivery_signature", "") or "").strip()
    request_uuid = str(body.get("uuid", "") or "").strip()
    target = body.get("target")
    message = body.get("message")
    if not bind_uuid or not naga_id or not delivery_signature:
        return _json_error(
            "bind_uuid, naga_id and delivery_signature are required",
            status=400,
        )
    if not isinstance(target, dict):
        return _json_error("target object is required", status=400)
    if not isinstance(message, dict):
        return _json_error("message object is required", status=400)

    raw_target_qq = target.get("qq_id")
    raw_target_group = target.get("group_id")
    if raw_target_qq is None or raw_target_group is None:
        return _json_error("target.qq_id and target.group_id are required", status=400)
    try:
        target_qq = int(raw_target_qq)
        target_group = int(raw_target_group)
    except (TypeError, ValueError, OverflowError):
        # TypeError: list/dict; ValueError: non-numeric str or NaN;
        # OverflowError: +/-Infinity (accepted by the stdlib JSON decoder).
        return _json_error(
            "target.qq_id and target.group_id must be integers", status=400
        )
    mode = str(target.get("mode", "") or "").strip().lower()
    if mode not in {"private", "group", "both"}:
        return _json_error(
            "target.mode must be 'private', 'group', or 'both'", status=400
        )

    fmt = str(message.get("format", "text") or "text").strip().lower()
    content = str(message.get("content", "") or "").strip()
    if fmt not in {"text", "markdown", "html"}:
        return _json_error(
            "message.format must be 'text', 'markdown', or 'html'", status=400
        )
    if not content:
        return _json_error("message.content is required", status=400)

    # Digest of the normalized payload; used both for in-flight tracking and
    # to detect a ``uuid`` being reused with different content.
    message_key = _naga_message_digest(
        bind_uuid=bind_uuid,
        naga_id=naga_id,
        target_qq=target_qq,
        target_group=target_group,
        mode=mode,
        message_format=fmt,
        content=content,
    )
    logger.info(
        "[NagaSend] 请求开始: trace=%s remote=%s naga_id=%s bind_uuid=%s request_uuid=%s mode=%s fmt=%s qq=%s group=%s key=%s content_len=%s preview=%s signature=%s",
        trace_id,
        getattr(request, "remote", None),
        naga_id,
        bind_uuid,
        request_uuid,
        mode,
        fmt,
        target_qq,
        target_group,
        message_key,
        len(content),
        _short_text_preview(content),
        mask_token(delivery_signature),
    )
    if mode == "both":
        logger.warning(
            "[NagaSend] 上游请求显式要求双路投递: trace=%s naga_id=%s bind_uuid=%s request_uuid=%s key=%s",
            trace_id,
            naga_id,
            bind_uuid,
            request_uuid,
            message_key,
        )
    inflight_count = await naga_state.track_send_start(message_key)
    if inflight_count > 1:
        logger.warning(
            "[NagaSend] 检测到相同 payload 并发请求: trace=%s naga_id=%s bind_uuid=%s request_uuid=%s key=%s inflight=%s",
            trace_id,
            naga_id,
            bind_uuid,
            request_uuid,
            message_key,
            inflight_count,
        )
    try:
        if request_uuid:
            dedupe_action, dedupe_value = await naga_state.register_request_uuid(
                request_uuid, message_key
            )
            if dedupe_action == "conflict":
                logger.warning(
                    "[NagaSend] uuid 与历史 payload 冲突: trace=%s naga_id=%s bind_uuid=%s uuid=%s key=%s",
                    trace_id,
                    naga_id,
                    bind_uuid,
                    request_uuid,
                    message_key,
                )
                return _json_error("uuid reused with different payload", status=409)
            if dedupe_action == "cached":
                cached_status, cached_payload = dedupe_value
                logger.warning(
                    "[NagaSend] 命中已完成幂等结果,直接复用: trace=%s naga_id=%s bind_uuid=%s request_uuid=%s key=%s",
                    trace_id,
                    naga_id,
                    bind_uuid,
                    request_uuid,
                    message_key,
                )
                # deepcopy so the response cannot alias the stored payload.
                return web.json_response(
                    deepcopy(cached_payload),
                    status=int(cached_status),
                )
            if dedupe_action == "await":
                wait_future = dedupe_value
                logger.warning(
                    "[NagaSend] 命中进行中幂等请求,等待首个结果: trace=%s naga_id=%s bind_uuid=%s request_uuid=%s key=%s",
                    trace_id,
                    naga_id,
                    bind_uuid,
                    request_uuid,
                    message_key,
                )
                cached_status, cached_payload = await wait_future
                return web.json_response(
                    deepcopy(cached_payload),
                    status=int(cached_status),
                )

        response = await naga_messages_send_impl(
            ctx,
            naga_id=naga_id,
            bind_uuid=bind_uuid,
            delivery_signature=delivery_signature,
            target_qq=target_qq,
            target_group=target_group,
            mode=mode,
            message_format=fmt,
            content=content,
            trace_id=trace_id,
            message_key=message_key,
        )
        if request_uuid:
            await naga_state.finish_request_uuid(
                request_uuid,
                message_key,
                status=response.status,
                payload=_parse_response_payload(response),
            )
        return response
    except Exception as exc:
        # NOTE(review): this branch also runs when awaiting another request's
        # future raises; confirm fail_request_uuid tolerates a non-owner call.
        if request_uuid:
            await naga_state.fail_request_uuid(request_uuid, message_key, exc)
        raise
    finally:
        # Always balance track_send_start, including on early returns above.
        remaining = await naga_state.track_send_done(message_key)
        logger.info(
            "[NagaSend] 请求退出: trace=%s naga_id=%s bind_uuid=%s request_uuid=%s key=%s inflight_remaining=%s",
            trace_id,
            naga_id,
            bind_uuid,
            request_uuid,
            message_key,
            remaining,
        )


# ------------------------------------------------------------------
# Core send implementation (no NagaState dependency)
# ------------------------------------------------------------------

+ +async def naga_messages_send_impl( + ctx: RuntimeAPIContext, + *, + naga_id: str, + bind_uuid: str, + delivery_signature: str, + target_qq: int, + target_group: int, + mode: str, + message_format: str, + content: str, + trace_id: str, + message_key: str, +) -> Response: + from Undefined.api.naga_store import mask_token + + naga_store = ctx.naga_store + if naga_store is None: + logger.warning( + "[NagaSend] NagaStore 不可用: trace=%s naga_id=%s bind_uuid=%s", + trace_id, + naga_id, + bind_uuid, + ) + return _json_error("Naga integration not available", status=503) + + binding, err_msg = await naga_store.acquire_delivery( + naga_id=naga_id, + bind_uuid=bind_uuid, + delivery_signature=delivery_signature, + ) + if binding is None: + logger.warning( + "[NagaSend] 签名校验失败: trace=%s naga_id=%s bind_uuid=%s reason=%s signature=%s", + trace_id, + naga_id, + bind_uuid, + err_msg.message if err_msg is not None else "unknown_error", + mask_token(delivery_signature), + ) + return _json_error( + err_msg.message if err_msg is not None else "delivery not available", + status=err_msg.http_status if err_msg is not None else 403, + ) + + logger.info( + "[NagaSend] 投递凭证已占用: trace=%s naga_id=%s bind_uuid=%s key=%s qq=%s group=%s", + trace_id, + naga_id, + bind_uuid, + message_key, + binding.qq_id, + binding.group_id, + ) + try: + if target_qq != binding.qq_id or target_group != binding.group_id: + logger.warning( + "[NagaSend] 目标不匹配: trace=%s naga_id=%s bind_uuid=%s target_qq=%s target_group=%s bound_qq=%s bound_group=%s", + trace_id, + naga_id, + bind_uuid, + target_qq, + target_group, + binding.qq_id, + binding.group_id, + ) + return _json_error("target does not match bound qq/group", status=403) + + cfg = ctx.config_getter() + if mode == "group" and binding.group_id not in cfg.naga.allowed_groups: + logger.warning( + "[NagaSend] 群投递被策略拒绝: trace=%s naga_id=%s bind_uuid=%s group=%s", + trace_id, + naga_id, + bind_uuid, + binding.group_id, + ) + return _json_error("bound group is not in 
naga.allowed_groups", status=403) + + sender = ctx.sender + if sender is None: + logger.warning( + "[NagaSend] sender 不可用: trace=%s naga_id=%s bind_uuid=%s", + trace_id, + naga_id, + bind_uuid, + ) + return _json_error("sender not available", status=503) + + moderation: dict[str, Any] + naga_cfg = getattr(cfg, "naga", None) + moderation_enabled = bool(getattr(naga_cfg, "moderation_enabled", True)) + security = getattr(ctx.command_dispatcher, "security", None) + if not moderation_enabled: + moderation = { + "status": "skipped_disabled", + "blocked": False, + "categories": [], + "message": "Naga moderation disabled by config; message sent without moderation block", + "model_name": "", + } + logger.warning( + "[NagaSend] 审核已禁用,直接放行: trace=%s naga_id=%s bind_uuid=%s key=%s", + trace_id, + naga_id, + bind_uuid, + message_key, + ) + elif security is None or not hasattr(security, "moderate_naga_message"): + moderation = { + "status": "error_allowed", + "blocked": False, + "categories": [], + "message": "Naga moderation service unavailable; message sent without moderation block", + "model_name": "", + } + logger.warning( + "[NagaSend] 审核服务不可用,按允许发送: trace=%s naga_id=%s bind_uuid=%s", + trace_id, + naga_id, + bind_uuid, + ) + else: + logger.info( + "[NagaSend] 审核开始: trace=%s naga_id=%s bind_uuid=%s key=%s fmt=%s content_len=%s", + trace_id, + naga_id, + bind_uuid, + message_key, + message_format, + len(content), + ) + result = await security.moderate_naga_message( + message_format=message_format, + content=content, + ) + moderation = { + "status": result.status, + "blocked": result.blocked, + "categories": result.categories, + "message": result.message, + "model_name": result.model_name, + } + logger.info( + "[NagaSend] 审核完成: trace=%s naga_id=%s bind_uuid=%s key=%s blocked=%s status=%s model=%s categories=%s", + trace_id, + naga_id, + bind_uuid, + message_key, + result.blocked, + result.status, + result.model_name, + ",".join(result.categories) or "-", + ) + if 
moderation["blocked"]: + logger.warning( + "[NagaSend] 审核拦截: trace=%s naga_id=%s bind_uuid=%s key=%s reason=%s", + trace_id, + naga_id, + bind_uuid, + message_key, + moderation["message"], + ) + return web.json_response( + { + "ok": False, + "error": "message blocked by moderation", + "moderation": moderation, + }, + status=403, + ) + + send_content: str | None = content if message_format == "text" else None + image_path: str | None = None + tmp_path: str | None = None + rendered = False + render_fallback = False + if message_format in {"markdown", "html"}: + import tempfile + + try: + html_str = content + if message_format == "markdown": + html_str = await render_markdown_to_html(content) + fd, tmp_path = tempfile.mkstemp(suffix=".png", prefix="naga_send_") + os.close(fd) + await render_html_to_image(html_str, tmp_path) + image_path = tmp_path + rendered = True + logger.info( + "[NagaSend] 富文本渲染成功: trace=%s naga_id=%s bind_uuid=%s key=%s fmt=%s image=%s", + trace_id, + naga_id, + bind_uuid, + message_key, + message_format, + Path(tmp_path).name if tmp_path is not None else "", + ) + except Exception as exc: + logger.warning( + "[NagaSend] 渲染失败,回退文本发送: trace=%s naga_id=%s bind_uuid=%s key=%s err=%s", + trace_id, + naga_id, + bind_uuid, + message_key, + exc, + ) + send_content = content + render_fallback = True + + sent_private = False + sent_group = False + group_policy_blocked = False + + async def _ensure_delivery_active() -> tuple[Any, Response | None]: + current_binding, live_err = await naga_store.ensure_delivery_active( + naga_id=naga_id, + bind_uuid=bind_uuid, + ) + if current_binding is None: + logger.warning( + "[NagaSend] 投递中止: trace=%s naga_id=%s bind_uuid=%s key=%s reason=%s", + trace_id, + naga_id, + bind_uuid, + message_key, + live_err.message + if live_err is not None + else "delivery no longer active", + ) + return None, web.json_response( + { + "ok": False, + "error": ( + live_err.message + if live_err is not None + else "delivery no longer active" 
+ ), + "sent_private": sent_private, + "sent_group": sent_group, + "moderation": moderation, + }, + status=live_err.http_status if live_err is not None else 409, + ) + return current_binding, None + + try: + cq_image: str | None = None + if image_path is not None: + file_uri = Path(image_path).resolve().as_uri() + cq_image = f"[CQ:image,file={file_uri}]" + + if mode in {"private", "both"}: + current_binding, abort_response = await _ensure_delivery_active() + if abort_response is not None: + return abort_response + logger.info( + "[NagaSend] 私聊投递开始: trace=%s naga_id=%s bind_uuid=%s key=%s qq=%s", + trace_id, + naga_id, + bind_uuid, + message_key, + current_binding.qq_id, + ) + try: + if send_content is not None: + await sender.send_private_message( + current_binding.qq_id, send_content + ) + elif cq_image is not None: + await sender.send_private_message( + current_binding.qq_id, cq_image + ) + sent_private = True + logger.info( + "[NagaSend] 私聊投递成功: trace=%s naga_id=%s bind_uuid=%s key=%s qq=%s", + trace_id, + naga_id, + bind_uuid, + message_key, + current_binding.qq_id, + ) + except Exception as exc: + logger.warning( + "[NagaSend] 私聊发送失败: trace=%s naga_id=%s qq=%d key=%s err=%s", + trace_id, + naga_id, + current_binding.qq_id, + message_key, + exc, + ) + + if mode in {"group", "both"}: + current_binding, abort_response = await _ensure_delivery_active() + if abort_response is not None: + return abort_response + current_cfg = ctx.config_getter() + if current_binding.group_id not in current_cfg.naga.allowed_groups: + group_policy_blocked = True + logger.warning( + "[NagaSend] 群投递被策略阻止: trace=%s naga_id=%s bind_uuid=%s key=%s group=%s", + trace_id, + naga_id, + bind_uuid, + message_key, + current_binding.group_id, + ) + else: + logger.info( + "[NagaSend] 群投递开始: trace=%s naga_id=%s bind_uuid=%s key=%s group=%s", + trace_id, + naga_id, + bind_uuid, + message_key, + current_binding.group_id, + ) + try: + if send_content is not None: + await sender.send_group_message( + 
current_binding.group_id, send_content + ) + elif cq_image is not None: + await sender.send_group_message( + current_binding.group_id, cq_image + ) + sent_group = True + logger.info( + "[NagaSend] 群投递成功: trace=%s naga_id=%s bind_uuid=%s key=%s group=%s", + trace_id, + naga_id, + bind_uuid, + message_key, + current_binding.group_id, + ) + except Exception as exc: + logger.warning( + "[NagaSend] 群聊发送失败: trace=%s naga_id=%s group=%d key=%s err=%s", + trace_id, + naga_id, + current_binding.group_id, + message_key, + exc, + ) + finally: + if tmp_path is not None: + try: + os.unlink(tmp_path) + except OSError: + pass + + if mode == "private" and not sent_private: + return web.json_response( + { + "ok": False, + "error": "private delivery failed", + "sent_private": sent_private, + "sent_group": sent_group, + "moderation": moderation, + }, + status=502, + ) + if mode == "group" and not sent_group: + return web.json_response( + { + "ok": False, + "error": "group delivery failed", + "sent_private": sent_private, + "sent_group": sent_group, + "moderation": moderation, + }, + status=502, + ) + if mode == "both" and not (sent_private or sent_group): + if group_policy_blocked: + return web.json_response( + { + "ok": False, + "error": "bound group is not in naga.allowed_groups", + "sent_private": sent_private, + "sent_group": sent_group, + "moderation": moderation, + }, + status=403, + ) + return web.json_response( + { + "ok": False, + "error": "all deliveries failed", + "sent_private": sent_private, + "sent_group": sent_group, + "moderation": moderation, + }, + status=502, + ) + + await naga_store.record_usage(naga_id, bind_uuid=bind_uuid) + partial_success = mode == "both" and (sent_private != sent_group) + logger.info( + "[NagaSend] 请求完成: trace=%s naga_id=%s bind_uuid=%s key=%s sent_private=%s sent_group=%s partial=%s rendered=%s fallback=%s", + trace_id, + naga_id, + bind_uuid, + message_key, + sent_private, + sent_group, + partial_success, + rendered, + render_fallback, + ) 
+ return web.json_response( + { + "ok": True, + "naga_id": naga_id, + "bind_uuid": bind_uuid, + "sent_private": sent_private, + "sent_group": sent_group, + "partial_success": partial_success, + "delivery_status": ( + "partial_success" if partial_success else "full_success" + ), + "rendered": rendered, + "render_fallback": render_fallback, + "moderation": moderation, + } + ) + finally: + await naga_store.release_delivery(bind_uuid=bind_uuid) + + +# ------------------------------------------------------------------ +# POST /api/v1/naga/unbind +# ------------------------------------------------------------------ + + +async def naga_unbind_handler(ctx: RuntimeAPIContext, request: web.Request) -> Response: + """POST /api/v1/naga/unbind — 远端主动解绑。""" + trace_id = _uuid.uuid4().hex[:8] + auth_err = verify_naga_api_key(ctx, request) + if auth_err is not None: + logger.warning( + "[NagaUnbind] 鉴权失败: trace=%s remote=%s err=%s", + trace_id, + getattr(request, "remote", None), + auth_err, + ) + return _json_error("Unauthorized", status=401) + + try: + body = await request.json() + except Exception: + return _json_error("Invalid JSON", status=400) + + bind_uuid = str(body.get("bind_uuid", "") or "").strip() + naga_id = str(body.get("naga_id", "") or "").strip() + delivery_signature = str(body.get("delivery_signature", "") or "").strip() + if not bind_uuid or not naga_id or not delivery_signature: + return _json_error( + "bind_uuid, naga_id and delivery_signature are required", + status=400, + ) + logger.info( + "[NagaUnbind] 请求开始: trace=%s remote=%s naga_id=%s bind_uuid=%s signature=%s", + trace_id, + getattr(request, "remote", None), + naga_id, + bind_uuid, + delivery_signature[:12] + "...", + ) + + naga_store = ctx.naga_store + if naga_store is None: + return _json_error("Naga integration not available", status=503) + + binding, changed, err = await naga_store.revoke_binding( + naga_id, + expected_bind_uuid=bind_uuid, + delivery_signature=delivery_signature, + ) + if binding 
is None: + logger.warning( + "[NagaUnbind] 吊销失败: trace=%s naga_id=%s bind_uuid=%s err=%s", + trace_id, + naga_id, + bind_uuid, + err.message if err is not None else "binding not found", + ) + return _json_error( + err.message if err is not None else "binding not found", + status=err.http_status if err is not None else 404, + ) + logger.info( + "[NagaUnbind] 吊销完成: trace=%s naga_id=%s bind_uuid=%s changed=%s qq=%s group=%s", + trace_id, + naga_id, + bind_uuid, + changed, + binding.qq_id, + binding.group_id, + ) + return web.json_response( + { + "ok": True, + "idempotent": not changed, + "naga_id": naga_id, + "bind_uuid": bind_uuid, + } + ) diff --git a/src/Undefined/api/routes/system.py b/src/Undefined/api/routes/system.py new file mode 100644 index 0000000..2cf3915 --- /dev/null +++ b/src/Undefined/api/routes/system.py @@ -0,0 +1,228 @@ +"""System / probe route handlers for the Runtime API.""" + +from __future__ import annotations + +import asyncio +import logging +import platform +import sys +import time +from datetime import datetime +from typing import Any + +from aiohttp import web +from aiohttp.web_response import Response + +from Undefined import __version__ +from Undefined.api._context import RuntimeAPIContext +from Undefined.api._helpers import ( + _VIRTUAL_USER_ID, + _json_error, + _mask_url, + _naga_routes_enabled, + _registry_summary, +) +from Undefined.api._openapi import _build_openapi_spec +from Undefined.api._probes import ( + _build_internal_model_probe_payload, + _probe_http_endpoint, + _probe_ws_endpoint, + _skipped_probe, +) + +logger = logging.getLogger(__name__) + +_PROCESS_START_TIME = time.time() + + +async def openapi_handler(ctx: RuntimeAPIContext, request: web.Request) -> Response: + cfg = ctx.config_getter() + if not bool(getattr(cfg.api, "openapi_enabled", True)): + logger.info( + "[RuntimeAPI] OpenAPI 请求被拒绝: disabled remote=%s", request.remote + ) + return _json_error("OpenAPI disabled", status=404) + naga_routes_enabled = 
_naga_routes_enabled(cfg, ctx.naga_store) + logger.info( + "[RuntimeAPI] OpenAPI 请求: remote=%s naga_routes_enabled=%s", + request.remote, + naga_routes_enabled, + ) + return web.json_response(_build_openapi_spec(ctx, request)) + + +async def internal_probe_handler( + ctx: RuntimeAPIContext, request: web.Request +) -> Response: + _ = request + cfg = ctx.config_getter() + queue_snapshot = ctx.queue_manager.snapshot() if ctx.queue_manager else {} + cognitive_queue_snapshot = ( + ctx.cognitive_job_queue.snapshot() if ctx.cognitive_job_queue else {} + ) + memory_storage = getattr(ctx.ai, "memory_storage", None) + memory_count = memory_storage.count() if memory_storage is not None else 0 + + # Skills 统计 + ai = ctx.ai + skills_info: dict[str, Any] = {} + if ai is not None: + tool_reg = getattr(ai, "tool_registry", None) + agent_reg = getattr(ai, "agent_registry", None) + anthropic_reg = getattr(ai, "anthropic_skill_registry", None) + skills_info["tools"] = _registry_summary(tool_reg) + skills_info["agents"] = _registry_summary(agent_reg) + skills_info["anthropic_skills"] = _registry_summary(anthropic_reg) + + # 模型配置(脱敏) + models_info: dict[str, Any] = {} + summary_model = getattr( + cfg, + "summary_model", + getattr(cfg, "agent_model", getattr(cfg, "chat_model", None)), + ) + for label in ( + "chat_model", + "vision_model", + "agent_model", + "security_model", + "naga_model", + "grok_model", + ): + mcfg = getattr(cfg, label, None) + if mcfg is not None: + models_info[label] = _build_internal_model_probe_payload(mcfg) + if summary_model is not None: + models_info["summary_model"] = _build_internal_model_probe_payload( + summary_model + ) + for label in ("embedding_model", "rerank_model"): + mcfg = getattr(cfg, label, None) + if mcfg is not None: + models_info[label] = { + "model_name": getattr(mcfg, "model_name", ""), + "api_url": _mask_url(getattr(mcfg, "api_url", "")), + } + + uptime_seconds = round(time.time() - _PROCESS_START_TIME, 1) + payload = { + "timestamp": 
datetime.now().isoformat(), + "version": __version__, + "python": f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}", + "platform": platform.system(), + "uptime_seconds": uptime_seconds, + "onebot": ctx.onebot.connection_status() if ctx.onebot is not None else {}, + "queues": queue_snapshot, + "memory": {"count": memory_count, "virtual_user_id": _VIRTUAL_USER_ID}, + "cognitive": { + "enabled": bool(ctx.cognitive_service and ctx.cognitive_service.enabled), + "queue": cognitive_queue_snapshot, + }, + "api": { + "enabled": bool(cfg.api.enabled), + "host": cfg.api.host, + "port": cfg.api.port, + "openapi_enabled": bool(cfg.api.openapi_enabled), + }, + "skills": skills_info, + "models": models_info, + } + return web.json_response(payload) + + +async def external_probe_handler( + ctx: RuntimeAPIContext, request: web.Request +) -> Response: + _ = request + cfg = ctx.config_getter() + summary_model = getattr( + cfg, + "summary_model", + getattr(cfg, "agent_model", getattr(cfg, "chat_model", None)), + ) + naga_probe = ( + _probe_http_endpoint( + name="naga_model", + base_url=cfg.naga_model.api_url, + api_key=cfg.naga_model.api_key, + model_name=cfg.naga_model.model_name, + ) + if bool(cfg.api.enabled and cfg.nagaagent_mode_enabled and cfg.naga.enabled) + else _skipped_probe( + name="naga_model", + reason="naga_integration_disabled", + model_name=cfg.naga_model.model_name, + ) + ) + checks = [ + _probe_http_endpoint( + name="chat_model", + base_url=cfg.chat_model.api_url, + api_key=cfg.chat_model.api_key, + model_name=cfg.chat_model.model_name, + ), + _probe_http_endpoint( + name="vision_model", + base_url=cfg.vision_model.api_url, + api_key=cfg.vision_model.api_key, + model_name=cfg.vision_model.model_name, + ), + _probe_http_endpoint( + name="security_model", + base_url=cfg.security_model.api_url, + api_key=cfg.security_model.api_key, + model_name=cfg.security_model.model_name, + ), + naga_probe, + _probe_http_endpoint( + name="agent_model", + 
base_url=cfg.agent_model.api_url, + api_key=cfg.agent_model.api_key, + model_name=cfg.agent_model.model_name, + ), + ] + if summary_model is not None: + checks.append( + _probe_http_endpoint( + name="summary_model", + base_url=summary_model.api_url, + api_key=summary_model.api_key, + model_name=summary_model.model_name, + ) + ) + grok_model = getattr(cfg, "grok_model", None) + if grok_model is not None: + checks.append( + _probe_http_endpoint( + name="grok_model", + base_url=getattr(grok_model, "api_url", ""), + api_key=getattr(grok_model, "api_key", ""), + model_name=getattr(grok_model, "model_name", ""), + ) + ) + checks.extend( + [ + _probe_http_endpoint( + name="embedding_model", + base_url=cfg.embedding_model.api_url, + api_key=cfg.embedding_model.api_key, + model_name=getattr(cfg.embedding_model, "model_name", ""), + ), + _probe_http_endpoint( + name="rerank_model", + base_url=cfg.rerank_model.api_url, + api_key=cfg.rerank_model.api_key, + model_name=getattr(cfg.rerank_model, "model_name", ""), + ), + _probe_ws_endpoint(cfg.onebot_ws_url), + ] + ) + results = await asyncio.gather(*checks) + ok = all(item.get("status") in {"ok", "skipped"} for item in results) + return web.json_response( + { + "ok": ok, + "timestamp": datetime.now().isoformat(), + "results": results, + } + ) diff --git a/src/Undefined/api/routes/tools.py b/src/Undefined/api/routes/tools.py new file mode 100644 index 0000000..f305179 --- /dev/null +++ b/src/Undefined/api/routes/tools.py @@ -0,0 +1,462 @@ +"""Tools route handlers for the Runtime API.""" + +from __future__ import annotations + +import asyncio +import logging +import time +import uuid as _uuid +from typing import Any, Awaitable + +from aiohttp import ClientSession, ClientTimeout, web +from aiohttp.web_response import Response + +from Undefined.api._context import RuntimeAPIContext +from Undefined.api._helpers import ( + _ToolInvokeExecutionTimeoutError, + _json_error, + _mask_url, + _to_bool, + _validate_callback_url, +) +from 
Undefined.context import RequestContext + +logger = logging.getLogger(__name__) + + +# ------------------------------------------------------------------ +# Pure helpers +# ------------------------------------------------------------------ + + +def get_filtered_tools(ctx: RuntimeAPIContext) -> list[dict[str, Any]]: + """按配置过滤可用工具,返回 OpenAI function calling schema 列表。""" + cfg = ctx.config_getter() + api_cfg = cfg.api + ai = ctx.ai + if ai is None: + return [] + + tool_reg = getattr(ai, "tool_registry", None) + agent_reg = getattr(ai, "agent_registry", None) + + all_schemas: list[dict[str, Any]] = [] + if tool_reg is not None: + all_schemas.extend(tool_reg.get_tools_schema()) + + # 收集 agent schema 并缓存名称集合(避免重复调用) + agent_names: set[str] = set() + if agent_reg is not None: + agent_schemas = agent_reg.get_agents_schema() + all_schemas.extend(agent_schemas) + for schema in agent_schemas: + func = schema.get("function", {}) + name = str(func.get("name", "")) + if name: + agent_names.add(name) + + denylist: set[str] = set(api_cfg.tool_invoke_denylist) + allowlist: set[str] = set(api_cfg.tool_invoke_allowlist) + expose = api_cfg.tool_invoke_expose + + def _get_name(schema: dict[str, Any]) -> str: + func = schema.get("function", {}) + return str(func.get("name", "")) + + # 1. 先排除黑名单 + if denylist: + all_schemas = [s for s in all_schemas if _get_name(s) not in denylist] + + # 2. 白名单非空时仅保留匹配项 + if allowlist: + return [s for s in all_schemas if _get_name(s) in allowlist] + + # 3. 按 expose 过滤 + if expose == "all": + return all_schemas + + def _is_tool(name: str) -> bool: + return "." not in name and name not in agent_names + + def _is_toolset(name: str) -> bool: + return "." 
in name and not name.startswith("mcp.") + + filtered: list[dict[str, Any]] = [] + for schema in all_schemas: + name = _get_name(schema) + if not name: + continue + if expose == "tools" and _is_tool(name): + filtered.append(schema) + elif expose == "toolsets" and _is_toolset(name): + filtered.append(schema) + elif expose == "tools+toolsets" and (_is_tool(name) or _is_toolset(name)): + filtered.append(schema) + elif expose == "agents" and name in agent_names: + filtered.append(schema) + + return filtered + + +def get_agent_tool_names(ctx: RuntimeAPIContext) -> set[str]: + ai = ctx.ai + if ai is None: + return set() + + agent_reg = getattr(ai, "agent_registry", None) + if agent_reg is None: + return set() + + agent_names: set[str] = set() + for schema in agent_reg.get_agents_schema(): + func = schema.get("function", {}) + name = str(func.get("name", "")) + if name: + agent_names.add(name) + return agent_names + + +def resolve_tool_invoke_timeout( + ctx: RuntimeAPIContext, tool_name: str, timeout: int +) -> float | None: + if tool_name in get_agent_tool_names(ctx): + return None + return float(timeout) + + +# ------------------------------------------------------------------ +# Async helpers +# ------------------------------------------------------------------ + + +async def await_tool_invoke_result( + awaitable: Awaitable[Any], + *, + timeout: float | None, +) -> Any: + if timeout is None or timeout <= 0: + return await awaitable + try: + return await asyncio.wait_for(awaitable, timeout=timeout) + except asyncio.TimeoutError as exc: + raise _ToolInvokeExecutionTimeoutError from exc + + +# ------------------------------------------------------------------ +# Route handlers +# ------------------------------------------------------------------ + + +async def tools_list_handler(ctx: RuntimeAPIContext, request: web.Request) -> Response: + _ = request + cfg = ctx.config_getter() + if not cfg.api.tool_invoke_enabled: + return _json_error("Tool invoke API is disabled", 
status=403) + + tools = get_filtered_tools(ctx) + return web.json_response({"count": len(tools), "tools": tools}) + + +async def tools_invoke_handler( + ctx: RuntimeAPIContext, + background_tasks: set[asyncio.Task[Any]], + request: web.Request, +) -> Response: + cfg = ctx.config_getter() + if not cfg.api.tool_invoke_enabled: + return _json_error("Tool invoke API is disabled", status=403) + + try: + body = await request.json() + except Exception: + return _json_error("Invalid JSON", status=400) + + if not isinstance(body, dict): + return _json_error("Request body must be a JSON object", status=400) + + tool_name = str(body.get("tool_name", "") or "").strip() + if not tool_name: + return _json_error("tool_name is required", status=400) + + args = body.get("args") + if not isinstance(args, dict): + return _json_error("args must be a JSON object", status=400) + + # 验证工具是否在允许列表中 + filtered_tools = get_filtered_tools(ctx) + available_names: set[str] = set() + for schema in filtered_tools: + func = schema.get("function", {}) + name = str(func.get("name", "")) + if name: + available_names.add(name) + + if tool_name not in available_names: + caller_ip = request.remote or "unknown" + logger.warning( + "[ToolInvoke] 请求拒绝: tool=%s reason=not_available caller_ip=%s", + tool_name, + caller_ip, + ) + return _json_error(f"Tool '{tool_name}' is not available", status=404) + + # 解析回调配置 + callback_cfg = body.get("callback") + use_callback = False + callback_url = "" + callback_headers: dict[str, str] = {} + if isinstance(callback_cfg, dict) and _to_bool(callback_cfg.get("enabled")): + callback_url = str(callback_cfg.get("url", "") or "").strip() + if not callback_url: + return _json_error( + "callback.url is required when callback is enabled", + status=400, + ) + url_error = _validate_callback_url(callback_url) + if url_error: + return _json_error(url_error, status=400) + raw_headers = callback_cfg.get("headers") + if isinstance(raw_headers, dict): + callback_headers = {str(k): 
str(v) for k, v in raw_headers.items()} + use_callback = True + + request_id = _uuid.uuid4().hex + caller_ip = request.remote or "unknown" + logger.info( + "[ToolInvoke] 收到请求: request_id=%s tool=%s caller_ip=%s", + request_id, + tool_name, + caller_ip, + ) + + if use_callback: + # 异步执行 + 回调 + task = asyncio.create_task( + execute_and_callback( + ctx, + request_id=request_id, + tool_name=tool_name, + args=args, + body_context=body.get("context"), + callback_url=callback_url, + callback_headers=callback_headers, + timeout=cfg.api.tool_invoke_timeout, + callback_timeout=cfg.api.tool_invoke_callback_timeout, + ) + ) + background_tasks.add(task) + task.add_done_callback(background_tasks.discard) + return web.json_response( + { + "ok": True, + "request_id": request_id, + "tool_name": tool_name, + "status": "accepted", + } + ) + + # 同步执行 + result = await execute_tool_invoke( + ctx, + request_id=request_id, + tool_name=tool_name, + args=args, + body_context=body.get("context"), + timeout=cfg.api.tool_invoke_timeout, + ) + return web.json_response(result) + + +# ------------------------------------------------------------------ +# Execution core +# ------------------------------------------------------------------ + + +async def execute_tool_invoke( + ctx: RuntimeAPIContext, + *, + request_id: str, + tool_name: str, + args: dict[str, Any], + body_context: Any, + timeout: int, +) -> dict[str, Any]: + """执行工具调用并返回结果字典。""" + ai = ctx.ai + if ai is None: + return { + "ok": False, + "request_id": request_id, + "tool_name": tool_name, + "error": "AI client not ready", + "duration_ms": 0, + } + + # 解析请求上下文 + ctx_data: dict[str, Any] = {} + if isinstance(body_context, dict): + ctx_data = body_context + request_type = str(ctx_data.get("request_type", "api") or "api") + group_id = ctx_data.get("group_id") + user_id = ctx_data.get("user_id") + sender_id = ctx_data.get("sender_id") + + args_keys = list(args.keys()) + logger.info( + "[ToolInvoke] 开始执行: request_id=%s tool=%s 
args_keys=%s", + request_id, + tool_name, + args_keys, + ) + + start = time.perf_counter() + effective_timeout = resolve_tool_invoke_timeout(ctx, tool_name, timeout) + try: + async with RequestContext( + request_type=request_type, + group_id=int(group_id) if group_id is not None else None, + user_id=int(user_id) if user_id is not None else None, + sender_id=int(sender_id) if sender_id is not None else None, + ) as req_ctx: + # 注入核心服务资源 + if ctx.sender is not None: + req_ctx.set_resource("sender", ctx.sender) + if ctx.history_manager is not None: + req_ctx.set_resource("history_manager", ctx.history_manager) + runtime_config = getattr(ai, "runtime_config", None) + if runtime_config is not None: + req_ctx.set_resource("runtime_config", runtime_config) + memory_storage = getattr(ai, "memory_storage", None) + if memory_storage is not None: + req_ctx.set_resource("memory_storage", memory_storage) + if ctx.onebot is not None: + req_ctx.set_resource("onebot_client", ctx.onebot) + if ctx.scheduler is not None: + req_ctx.set_resource("scheduler", ctx.scheduler) + if ctx.cognitive_service is not None: + req_ctx.set_resource("cognitive_service", ctx.cognitive_service) + if ctx.meme_service is not None: + req_ctx.set_resource("meme_service", ctx.meme_service) + + tool_context: dict[str, Any] = { + "request_type": request_type, + "request_id": request_id, + } + if group_id is not None: + tool_context["group_id"] = int(group_id) + if user_id is not None: + tool_context["user_id"] = int(user_id) + if sender_id is not None: + tool_context["sender_id"] = int(sender_id) + + tool_manager = getattr(ai, "tool_manager", None) + if tool_manager is None: + raise RuntimeError("ToolManager not available") + + raw_result = await await_tool_invoke_result( + tool_manager.execute_tool(tool_name, args, tool_context), + timeout=effective_timeout, + ) + + elapsed_ms = round((time.perf_counter() - start) * 1000, 1) + result_str = str(raw_result or "") + logger.info( + "[ToolInvoke] 执行完成: 
request_id=%s tool=%s ok=true " + "duration_ms=%s result_len=%d", + request_id, + tool_name, + elapsed_ms, + len(result_str), + ) + return { + "ok": True, + "request_id": request_id, + "tool_name": tool_name, + "result": result_str, + "duration_ms": elapsed_ms, + } + + except _ToolInvokeExecutionTimeoutError: + elapsed_ms = round((time.perf_counter() - start) * 1000, 1) + logger.warning( + "[ToolInvoke] 执行超时: request_id=%s tool=%s timeout=%ds", + request_id, + tool_name, + timeout, + ) + return { + "ok": False, + "request_id": request_id, + "tool_name": tool_name, + "error": f"Execution timed out after {timeout}s", + "duration_ms": elapsed_ms, + } + except Exception as exc: + elapsed_ms = round((time.perf_counter() - start) * 1000, 1) + logger.exception( + "[ToolInvoke] 执行失败: request_id=%s tool=%s error=%s", + request_id, + tool_name, + exc, + ) + return { + "ok": False, + "request_id": request_id, + "tool_name": tool_name, + "error": str(exc), + "duration_ms": elapsed_ms, + } + + +async def execute_and_callback( + ctx: RuntimeAPIContext, + *, + request_id: str, + tool_name: str, + args: dict[str, Any], + body_context: Any, + callback_url: str, + callback_headers: dict[str, str], + timeout: int, + callback_timeout: int, +) -> None: + """异步执行工具并发送回调。""" + result = await execute_tool_invoke( + ctx, + request_id=request_id, + tool_name=tool_name, + args=args, + body_context=body_context, + timeout=timeout, + ) + + payload = { + "request_id": result["request_id"], + "tool_name": result["tool_name"], + "ok": result["ok"], + "result": result.get("result"), + "duration_ms": result.get("duration_ms", 0), + "error": result.get("error"), + } + + try: + cb_timeout = ClientTimeout(total=callback_timeout) + async with ClientSession(timeout=cb_timeout) as session: + async with session.post( + callback_url, + json=payload, + headers=callback_headers or None, + ) as resp: + logger.info( + "[ToolInvoke] 回调发送: request_id=%s url=%s status=%d", + request_id, + _mask_url(callback_url), 
+ resp.status, + ) + except Exception as exc: + logger.warning( + "[ToolInvoke] 回调失败: request_id=%s url=%s error=%s", + request_id, + _mask_url(callback_url), + exc, + ) diff --git a/tests/test_runtime_api_chat_stream.py b/tests/test_runtime_api_chat_stream.py index c1f6e99..09738cb 100644 --- a/tests/test_runtime_api_chat_stream.py +++ b/tests/test_runtime_api_chat_stream.py @@ -8,7 +8,7 @@ from aiohttp import web from Undefined.api import RuntimeAPIContext, RuntimeAPIServer -from Undefined.api import app as runtime_api_app +from Undefined.api.routes import chat as runtime_api_chat class _DummyTransport: @@ -90,18 +90,18 @@ async def _fake_render_message_with_pic_placeholders( ) server = RuntimeAPIServer(context, host="127.0.0.1", port=8788) - async def _fake_run_webui_chat(*, text: str, send_output: Any) -> str: + async def _fake_run_webui_chat(_ctx: Any, *, text: str, send_output: Any) -> str: assert text == "hello" await send_output(42, "bot reply with{escape_xml_text(normalized_text)} {attachment_xml} +") return "chat" monkeypatch.setattr( - runtime_api_app, + runtime_api_chat, "render_message_with_pic_placeholders", _fake_render_message_with_pic_placeholders, ) monkeypatch.setattr(web, "StreamResponse", _DummyStreamResponse) - monkeypatch.setattr(server, "_run_webui_chat", _fake_run_webui_chat) + monkeypatch.setattr(runtime_api_chat, "run_webui_chat", _fake_run_webui_chat) request = cast( web.Request, @@ -173,11 +173,11 @@ async def _fake_ask(full_question: str, **kwargs: Any) -> str: server = RuntimeAPIServer(context, host="127.0.0.1", port=8788) monkeypatch.setattr( - runtime_api_app, + runtime_api_chat, "register_message_attachments", _fake_register_message_attachments, ) - monkeypatch.setattr(runtime_api_app, "collect_context_resources", lambda _vars: {}) + monkeypatch.setattr(runtime_api_chat, "collect_context_resources", lambda _vars: {}) sent_messages: list[tuple[int, str]] = [] diff --git a/tests/test_runtime_api_naga.py 
b/tests/test_runtime_api_naga.py index 55a66b4..739919d 100644 --- a/tests/test_runtime_api_naga.py +++ b/tests/test_runtime_api_naga.py @@ -608,9 +608,11 @@ async def _render_html_to_image(_: str, __: str) -> None: raise RuntimeError("render failed") monkeypatch.setattr( - "Undefined.api.app.render_markdown_to_html", _render_markdown_to_html + "Undefined.api.routes.naga.render_markdown_to_html", _render_markdown_to_html + ) + monkeypatch.setattr( + "Undefined.api.routes.naga.render_html_to_image", _render_html_to_image ) - monkeypatch.setattr("Undefined.api.app.render_html_to_image", _render_html_to_image) response = await server._naga_messages_send_handler( _make_request( @@ -665,7 +667,7 @@ async def _render_markdown_to_html(_: str) -> str: raise RuntimeError("markdown failed") monkeypatch.setattr( - "Undefined.api.app.render_markdown_to_html", _render_markdown_to_html + "Undefined.api.routes.naga.render_markdown_to_html", _render_markdown_to_html ) response = await server._naga_messages_send_handler( diff --git a/tests/test_runtime_api_probes.py b/tests/test_runtime_api_probes.py index 3d6e0dd..940f32f 100644 --- a/tests/test_runtime_api_probes.py +++ b/tests/test_runtime_api_probes.py @@ -8,7 +8,7 @@ from aiohttp import web from Undefined.api import RuntimeAPIContext, RuntimeAPIServer -from Undefined.api import app as runtime_api_app +from Undefined.api.routes import system as runtime_api_system @pytest.mark.asyncio @@ -166,9 +166,11 @@ async def _fake_probe_ws_endpoint(_: str) -> dict[str, Any]: } monkeypatch.setattr( - runtime_api_app, "_probe_http_endpoint", _fake_probe_http_endpoint + runtime_api_system, "_probe_http_endpoint", _fake_probe_http_endpoint + ) + monkeypatch.setattr( + runtime_api_system, "_probe_ws_endpoint", _fake_probe_ws_endpoint ) - monkeypatch.setattr(runtime_api_app, "_probe_ws_endpoint", _fake_probe_ws_endpoint) context = RuntimeAPIContext( config_getter=lambda: SimpleNamespace( diff --git a/tests/test_runtime_api_tool_invoke.py 
b/tests/test_runtime_api_tool_invoke.py index 16dc141..e3c6615 100644 --- a/tests/test_runtime_api_tool_invoke.py +++ b/tests/test_runtime_api_tool_invoke.py @@ -365,7 +365,7 @@ async def _wait_for(awaitable: Any, timeout: float) -> Any: seen["timeout"] = timeout return await original_wait_for(awaitable, timeout) - monkeypatch.setattr("Undefined.api.app.asyncio.wait_for", _wait_for) + monkeypatch.setattr("Undefined.api.routes.tools.asyncio.wait_for", _wait_for) payload = await server._execute_tool_invoke( request_id="req-tool", @@ -391,7 +391,7 @@ async def _wait_for(awaitable: Any, timeout: float) -> Any: seen["timeout"] = timeout return await original_wait_for(awaitable, timeout) - monkeypatch.setattr("Undefined.api.app.asyncio.wait_for", _wait_for) + monkeypatch.setattr("Undefined.api.routes.tools.asyncio.wait_for", _wait_for) payload = await server._execute_tool_invoke( request_id="req-agent", From 0480ac7c01e1d0f75b7be364e40f2ac03ba513e2 Mon Sep 17 00:00:00 2001 From: Null <1708213363@qq.com> Date: Sun, 19 Apr 2026 14:11:32 +0800 Subject: [PATCH 35/57] feat(repeat): add cooldown to prevent re-repeating same content MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit After the bot repeats a message, the same content enters a configurable cooldown period (default 60 minutes) during which it won't be repeated again, even if the repeat chain conditions are met again. Features: - New config: easter_egg.repeat_cooldown_minutes (default 60, 0=disabled) - Question mark normalization: ? and ? 
treated as equivalent for cooldown - Per-group, per-text independent cooldown tracking - Cooldown uses monotonic clock (immune to wall-clock changes) Tests: 8 new repeat cooldown tests + 2 config parsing tests Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> Co-authored-by: Claude Opus 4.6 --- config.toml.example | 3 + src/Undefined/config/loader.py | 12 +++ src/Undefined/handlers.py | 70 ++++++++++--- tests/test_config_easter_egg_repeat.py | 18 ++++ tests/test_handlers_repeat.py | 137 ++++++++++++++++++++++++- 5 files changed, 222 insertions(+), 18 deletions(-) diff --git a/config.toml.example b/config.toml.example index 4ab0d04..c2a7fed 100644 --- a/config.toml.example +++ b/config.toml.example @@ -719,6 +719,9 @@ repeat_enabled = false # zh: 复读触发所需的连续相同消息条数(来自不同发送者),范围 2–20,默认 3。 # en: Number of consecutive identical messages (from different senders) required to trigger repeat, range 2–20, default 3. repeat_threshold = 3 +# zh: 复读冷却时间(分钟)。同一内容被复读后,在冷却时间内不再重复复读。0 = 无冷却。问号类消息(?/?)视为等价。 +# en: Repeat cooldown (minutes). After repeating the same content, won't repeat it again within this cooldown. 0 = no cooldown. Question marks (?/?) are treated as equivalent. +repeat_cooldown_minutes = 60 # zh: 是否启用倒问号(复读触发时,若消息为问号则发送倒问号 ¿)。 # en: Enable inverted question mark (when repeat triggers on "?" messages, send "¿" instead).
inverted_question_enabled = false diff --git a/src/Undefined/config/loader.py b/src/Undefined/config/loader.py index 26f669c..992829a 100644 --- a/src/Undefined/config/loader.py +++ b/src/Undefined/config/loader.py @@ -219,6 +219,7 @@ class Config: keyword_reply_enabled: bool repeat_enabled: bool repeat_threshold: int + repeat_cooldown_minutes: int inverted_question_enabled: bool context_recent_messages_limit: int ai_request_max_retries: int @@ -493,6 +494,16 @@ def load(cls, config_path: Optional[Path] = None, strict: bool = True) -> "Confi repeat_threshold = 2 if repeat_threshold > 20: repeat_threshold = 20 + repeat_cooldown_minutes = _coerce_int( + _get_value( + data, + ("easter_egg", "repeat_cooldown_minutes"), + "EASTER_EGG_REPEAT_COOLDOWN_MINUTES", + ), + 60, + ) + if repeat_cooldown_minutes < 0: + repeat_cooldown_minutes = 0 context_recent_messages_limit = _coerce_int( _get_value( data, @@ -1241,6 +1252,7 @@ def load(cls, config_path: Optional[Path] = None, strict: bool = True) -> "Confi keyword_reply_enabled=keyword_reply_enabled, repeat_enabled=repeat_enabled, repeat_threshold=repeat_threshold, + repeat_cooldown_minutes=repeat_cooldown_minutes, inverted_question_enabled=inverted_question_enabled, context_recent_messages_limit=context_recent_messages_limit, ai_request_max_retries=ai_request_max_retries, diff --git a/src/Undefined/handlers.py b/src/Undefined/handlers.py index 726fa5c..357e0c6 100644 --- a/src/Undefined/handlers.py +++ b/src/Undefined/handlers.py @@ -8,6 +8,7 @@ import os from pathlib import Path import random +import time from typing import Any, Coroutine from Undefined.attachments import ( @@ -131,6 +132,8 @@ def __init__( # 复读功能状态(按群跟踪最近消息文本与发送者) self._repeat_counter: dict[int, list[tuple[str, int]]] = {} self._repeat_locks: dict[int, asyncio.Lock] = {} + # 复读冷却:group_id → {normalized_text → monotonic_timestamp} + self._repeat_cooldown: dict[int, dict[str, float]] = {} # 启动队列 
self.ai_coordinator.queue_manager.start(self.ai_coordinator.execute_reply) @@ -143,6 +146,31 @@ def _get_repeat_lock(self, group_id: int) -> asyncio.Lock: self._repeat_locks[group_id] = lock return lock + @staticmethod + def _normalize_repeat_text(text: str) -> str: + """规范化复读文本用于冷却比较(?→?)。""" + return text.replace("?", "?") + + def _is_repeat_on_cooldown(self, group_id: int, text: str) -> bool: + """检查指定群的文本是否在复读冷却期内。""" + cooldown_minutes = self.config.repeat_cooldown_minutes + if cooldown_minutes <= 0: + return False + group_cd = self._repeat_cooldown.get(group_id) + if not group_cd: + return False + key = self._normalize_repeat_text(text) + last_time = group_cd.get(key) + if last_time is None: + return False + return (time.monotonic() - last_time) < cooldown_minutes * 60 + + def _record_repeat_cooldown(self, group_id: int, text: str) -> None: + """记录复读冷却时间戳。""" + key = self._normalize_repeat_text(text) + group_cd = self._repeat_cooldown.setdefault(group_id, {}) + group_cd[key] = time.monotonic() + async def _annotate_meme_descriptions( self, attachments: list[dict[str, str]], @@ -829,22 +857,32 @@ async def _fetch_group_name() -> str: and self.config.bot_qq not in senders ): reply_text = texts[0] - if self.config.inverted_question_enabled: - stripped = reply_text.strip() - if set(stripped) <= {"?", "?"}: - reply_text = "¿" * len(stripped) - # 清空计数器防止重复触发 - self._repeat_counter[group_id] = [] - logger.info( - "[复读] 触发复读: group=%s text=%s", - group_id, - redact_string(reply_text)[:50], - ) - await self.sender.send_group_message( - group_id, - reply_text, - history_prefix=REPEAT_REPLY_HISTORY_PREFIX, - ) + # 冷却检查:同一内容在冷却期内不再复读 + if self._is_repeat_on_cooldown(group_id, reply_text): + self._repeat_counter[group_id] = [] + logger.debug( + "[复读] 冷却中跳过: group=%s text=%s", + group_id, + redact_string(reply_text)[:50], + ) + else: + if self.config.inverted_question_enabled: + stripped = reply_text.strip() + if set(stripped) <= {"?", "?"}: + reply_text = "¿" * 
len(stripped) + # 清空计数器防止重复触发 + self._repeat_counter[group_id] = [] + self._record_repeat_cooldown(group_id, texts[0]) + logger.info( + "[复读] 触发复读: group=%s text=%s", + group_id, + redact_string(reply_text)[:50], + ) + await self.sender.send_group_message( + group_id, + reply_text, + history_prefix=REPEAT_REPLY_HISTORY_PREFIX, + ) return # Bilibili 视频自动提取 diff --git a/tests/test_config_easter_egg_repeat.py b/tests/test_config_easter_egg_repeat.py index 89872e5..e27ae05 100644 --- a/tests/test_config_easter_egg_repeat.py +++ b/tests/test_config_easter_egg_repeat.py @@ -27,12 +27,14 @@ def test_repeat_defaults_to_false(tmp_path: Path) -> None: cfg = _load(tmp_path, _MINIMAL) assert cfg.repeat_enabled is False assert cfg.inverted_question_enabled is False + assert cfg.repeat_cooldown_minutes == 60 def test_repeat_enabled_explicit(tmp_path: Path) -> None: cfg = _load(tmp_path, _MINIMAL + "\n[easter_egg]\nrepeat_enabled = true\n") assert cfg.repeat_enabled is True assert cfg.inverted_question_enabled is False + assert cfg.repeat_cooldown_minutes == 60 def test_inverted_question_enabled_explicit(tmp_path: Path) -> None: @@ -60,3 +62,19 @@ def test_keyword_reply_still_parsed_from_easter_egg(tmp_path: Path) -> None: _MINIMAL + "\n[easter_egg]\nkeyword_reply_enabled = true\n", ) assert cfg.keyword_reply_enabled is True + + +def test_repeat_cooldown_custom_value(tmp_path: Path) -> None: + cfg = _load( + tmp_path, + _MINIMAL + "\n[easter_egg]\nrepeat_cooldown_minutes = 30\n", + ) + assert cfg.repeat_cooldown_minutes == 30 + + +def test_repeat_cooldown_negative_clamped_to_zero(tmp_path: Path) -> None: + cfg = _load( + tmp_path, + _MINIMAL + "\n[easter_egg]\nrepeat_cooldown_minutes = -5\n", + ) + assert cfg.repeat_cooldown_minutes == 0 diff --git a/tests/test_handlers_repeat.py b/tests/test_handlers_repeat.py index 8d508de..4f9a830 100644 --- a/tests/test_handlers_repeat.py +++ b/tests/test_handlers_repeat.py @@ -18,6 +18,7 @@ def _build_handler( *, repeat_enabled: bool = 
False, repeat_threshold: int = 3, + repeat_cooldown_minutes: int = 60, inverted_question_enabled: bool = False, keyword_reply_enabled: bool = False, ) -> Any: @@ -26,6 +27,7 @@ def _build_handler( bot_qq=10000, repeat_enabled=repeat_enabled, repeat_threshold=repeat_threshold, + repeat_cooldown_minutes=repeat_cooldown_minutes, inverted_question_enabled=inverted_question_enabled, keyword_reply_enabled=keyword_reply_enabled, bilibili_auto_extract_enabled=False, @@ -69,6 +71,7 @@ def _build_handler( handler._background_tasks = set() handler._repeat_counter = {} handler._repeat_locks = {} + handler._repeat_cooldown = {} handler._profile_name_refresh_cache = {} handler._bot_nickname_cache = SimpleNamespace( get_nicknames=AsyncMock(return_value=frozenset()), @@ -155,14 +158,14 @@ async def test_repeat_does_not_trigger_for_different_texts() -> None: @pytest.mark.asyncio async def test_repeat_clears_counter_after_trigger() -> None: - handler = _build_handler(repeat_enabled=True) + handler = _build_handler(repeat_enabled=True, repeat_cooldown_minutes=0) # 第一轮:3条相同触发复读 for uid in [20001, 20002, 20003]: await handler.handle_message(_group_event(sender_id=uid, text="hello")) assert handler.sender.send_group_message.call_count == 1 - # 第二轮:再来3条相同应再次触发 + # 第二轮:再来3条相同应再次触发(无冷却) for uid in [20004, 20005, 20006]: await handler.handle_message(_group_event(sender_id=uid, text="hello")) @@ -343,3 +346,133 @@ async def test_repeat_custom_threshold_4() -> None: await handler.handle_message(_group_event(sender_id=20004, text="hey")) handler.sender.send_group_message.assert_called_once() assert handler.sender.send_group_message.call_args.args[1] == "hey" + + +# ── 冷却机制:复读后同一内容在冷却期内不再触发 ── + + +@pytest.mark.asyncio +async def test_repeat_cooldown_suppresses_same_text() -> None: + """复读触发后,同一内容在冷却期内再次满足条件也不触发。""" + handler = _build_handler(repeat_enabled=True, repeat_cooldown_minutes=60) + # 第一轮:触发复读 + for uid in [20001, 20002, 20003]: + await 
handler.handle_message(_group_event(sender_id=uid, text="草")) + assert handler.sender.send_group_message.call_count == 1 + + # 第二轮:同一内容,应被冷却抑制 + for uid in [20004, 20005, 20006]: + await handler.handle_message(_group_event(sender_id=uid, text="草")) + assert handler.sender.send_group_message.call_count == 1 # 不增加 + + +@pytest.mark.asyncio +async def test_repeat_cooldown_allows_different_text() -> None: + """复读 "草" 后,不同内容 "lol" 仍可正常复读。""" + handler = _build_handler(repeat_enabled=True, repeat_cooldown_minutes=60) + for uid in [20001, 20002, 20003]: + await handler.handle_message(_group_event(sender_id=uid, text="草")) + assert handler.sender.send_group_message.call_count == 1 + + for uid in [20004, 20005, 20006]: + await handler.handle_message(_group_event(sender_id=uid, text="lol")) + assert handler.sender.send_group_message.call_count == 2 + assert handler.sender.send_group_message.call_args.args[1] == "lol" + + +@pytest.mark.asyncio +async def test_repeat_cooldown_expired_allows_retrigger( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """冷却过期后,相同内容可以再次触发。""" + import time as _time + + handler = _build_handler(repeat_enabled=True, repeat_cooldown_minutes=60) + # 第一轮:触发 + for uid in [20001, 20002, 20003]: + await handler.handle_message(_group_event(sender_id=uid, text="草")) + assert handler.sender.send_group_message.call_count == 1 + + # 模拟时间流逝 61 分钟 + original_monotonic = _time.monotonic + monkeypatch.setattr(_time, "monotonic", lambda: original_monotonic() + 3660) + + # 第二轮:冷却已过期,应触发 + for uid in [20004, 20005, 20006]: + await handler.handle_message(_group_event(sender_id=uid, text="草")) + assert handler.sender.send_group_message.call_count == 2 + + +@pytest.mark.asyncio +async def test_repeat_cooldown_zero_disables() -> None: + """cooldown=0 时不启用冷却,连续复读同一内容仍可触发。""" + handler = _build_handler(repeat_enabled=True, repeat_cooldown_minutes=0) + for uid in [20001, 20002, 20003]: + await handler.handle_message(_group_event(sender_id=uid, text="草")) + assert 
handler.sender.send_group_message.call_count == 1 + + for uid in [20004, 20005, 20006]: + await handler.handle_message(_group_event(sender_id=uid, text="草")) + assert handler.sender.send_group_message.call_count == 2 + + +@pytest.mark.asyncio +async def test_repeat_cooldown_question_mark_normalization() -> None: + """全角问号 ? 和半角问号 ? 视为等价,复读后互相抑制。""" + handler = _build_handler( + repeat_enabled=True, + repeat_cooldown_minutes=60, + inverted_question_enabled=True, + ) + # 全角问号触发复读 + for uid in [20001, 20002, 20003]: + await handler.handle_message(_group_event(sender_id=uid, text="???")) + assert handler.sender.send_group_message.call_count == 1 + assert handler.sender.send_group_message.call_args.args[1] == "¿¿¿" + + # 半角问号——应被冷却抑制(?和 ? 等价) + for uid in [20004, 20005, 20006]: + await handler.handle_message(_group_event(sender_id=uid, text="???")) + assert handler.sender.send_group_message.call_count == 1 # 不增加 + + +@pytest.mark.asyncio +async def test_repeat_cooldown_multiple_texts_tracked() -> None: + """多种不同内容各自独立冷却。""" + handler = _build_handler(repeat_enabled=True, repeat_cooldown_minutes=60) + # 复读 "草" + for uid in [20001, 20002, 20003]: + await handler.handle_message(_group_event(sender_id=uid, text="草")) + # 复读 "lol" + for uid in [20004, 20005, 20006]: + await handler.handle_message(_group_event(sender_id=uid, text="lol")) + assert handler.sender.send_group_message.call_count == 2 + + # "草" 再次满足条件 → 冷却中,不触发 + for uid in [20007, 20008, 20009]: + await handler.handle_message(_group_event(sender_id=uid, text="草")) + assert handler.sender.send_group_message.call_count == 2 + + # "lol" 再次满足条件 → 冷却中,不触发 + for uid in [20010, 20011, 20012]: + await handler.handle_message(_group_event(sender_id=uid, text="lol")) + assert handler.sender.send_group_message.call_count == 2 + + +@pytest.mark.asyncio +async def test_repeat_cooldown_groups_independent() -> None: + """不同群的冷却互不影响。""" + handler = _build_handler(repeat_enabled=True, repeat_cooldown_minutes=60) + # 群A 复读 "草" + for 
uid in [20001, 20002, 20003]: + await handler.handle_message( + _group_event(group_id=30001, sender_id=uid, text="草") + ) + assert handler.sender.send_group_message.call_count == 1 + + # 群B 复读 "草" — 不同群,不受群A冷却影响 + for uid in [20004, 20005, 20006]: + await handler.handle_message( + _group_event(group_id=30002, sender_id=uid, text="草") + ) + assert handler.sender.send_group_message.call_count == 2 From b024427d4a8f9e785969b9163ceba106ef1c709f Mon Sep 17 00:00:00 2001 From: Null <1708213363@qq.com> Date: Sun, 19 Apr 2026 14:19:00 +0800 Subject: [PATCH 36/57] =?UTF-8?q?docs:=20=E6=9B=B4=E6=96=B0=E6=96=87?= =?UTF-8?q?=E6=A1=A3=E8=A6=86=E7=9B=96=E5=A4=8D=E8=AF=BB=E5=86=B7=E5=8D=B4?= =?UTF-8?q?=E3=80=81=E5=81=87@=E6=A3=80=E6=B5=8B=E3=80=81profile=E8=BE=93?= =?UTF-8?q?=E5=87=BA=E6=A8=A1=E5=BC=8F=E4=B8=8EAPI=E6=8B=86=E5=88=86?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - configuration.md: 新增 repeat_threshold、repeat_cooldown_minutes 字段 - slash-commands.md: 补充 /profile 输出模式(-f/-r/-t)与超管定向查看 - ARCHITECTURE.md: 补充 Runtime API 路由子模块层级 - development.md: 补充 config/、api/routes/、utils/ 目录说明 - CHANGELOG.md: 新增 v3.3.2 版本条目 Co-authored-by: Claude Opus 4.6 --- ARCHITECTURE.md | 2 +- CHANGELOG.md | 15 +++++++++++++++ docs/configuration.md | 6 ++++-- docs/development.md | 8 +++++++- docs/slash-commands.md | 24 ++++++++++++++++++------ 5 files changed, 45 insertions(+), 10 deletions(-) diff --git a/ARCHITECTURE.md b/ARCHITECTURE.md index e2d4176..53d35be 100644 --- a/ARCHITECTURE.md +++ b/ARCHITECTURE.md @@ -819,7 +819,7 @@ description: 从 PDF 文件中提取文本和表格,填写表单。当用户 ### 8层架构分层 1. **外部实体层**:用户、管理员、OneBot 协议端 (NapCat/Lagrange.Core)、大模型 API 服务商 -2. **核心入口层**:main.py 启动入口、配置管理器 (config/loader.py)、热更新应用器 (config/hot_reload.py)、OneBotClient (onebot.py)、RequestContext (context.py) +2. 
**核心入口层**:main.py 启动入口、配置管理器 (config/loader.py)、热更新应用器 (config/hot_reload.py)、OneBotClient (onebot.py)、RequestContext (context.py)、Runtime API Server (api/app.py → api/routes/ 路由子模块) 3. **消息处理层**:MessageHandler (handlers.py)、SecurityService (security.py)、CommandDispatcher (services/command.py)、AICoordinator (ai_coordinator.py)、QueueManager (queue_manager.py)、Bilibili 自动提取 (bilibili/) 4. **AI 核心能力层**:AIClient (client.py)、PromptBuilder (prompts.py)、ModelRequester (llm.py)、ToolManager (tooling.py)、MultimodalAnalyzer (multimodal.py)、SummaryService (summaries.py)、TokenCounter (tokens.py) 5. **存储与上下文层**:MessageHistoryManager (utils/history.py, 10000条限制)、MemoryStorage (memory.py, 置顶备忘录, 500条上限)、EndSummaryStorage、CognitiveService + JobQueue + HistorianWorker + VectorStore + ProfileStorage、MemeService + MemeWorker + MemeStore + MemeVectorStore (表情包库)、FAQStorage、ScheduledTaskStorage、TokenUsageStorage (自动归档) diff --git a/CHANGELOG.md b/CHANGELOG.md index eb1da18..e7b4f5d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,18 @@ +## v3.3.2 架构拆分、假@检测与多模式侧写 + +Runtime API 大规模拆分、假@检测、/profile 多输出模式、超管跨目标侧写查看、复读冷却机制,以及配置子模块化与大量工具函数增强。 + +- **api/app.py 拆分**:将 2491 行的巨型路由文件拆分为 8 个独立路由子模块 (`api/routes/`),app.py 仅保留薄包装委派层(333 行)。 +- **假@检测**:群聊中 `@+Bot昵称` 的文本形式也被识别为@消息(自动获取群内昵称,防竞态)。支持 `@昵称 /命令` 正常触发斜杠指令。 +- **/profile 多输出模式**:新增 `-f`(合并转发,默认)、`-r`(渲染图片)、`-t`(直接文本)三种侧写输出模式。合并转发分元数据与内容两条消息;渲染模式优化了元数据与正文的视觉排版。 +- **超管跨目标侧写**:超级管理员可通过 `/p ` 和 `/p g <群号>` 查看任意用户/群的侧写。 +- **复读冷却**:同一内容被复读后,在可配置的冷却期(默认 60 分钟)内不再重复复读。?和 ? 
视为等价。 +- **配置子模块化**:config/ 拆分为 loader.py、models.py、hot_reload.py,新增 `repeat_threshold`(2–20)与 `repeat_cooldown_minutes` 配置项。 +- **utils 增强**:新增 `coerce.py`(安全类型强转)、`fake_at.py`(假@文本检测与解析)。 +- 大量测试覆盖新增(1438+ 测试),ruff + mypy 零错误。 + +--- + ## v3.3.1 /version 命令的添加 添加了 /version 命令以查看版本号和更改内容。 diff --git a/docs/configuration.md b/docs/configuration.md index eceb9ea..7764a2d 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -426,8 +426,10 @@ Prompt caching 补充: |---|---:|---|---| | `agent_call_message_enabled` | `"none"` | 调用提示模式 | `none` / `agent` / `tools` / `all` / `clean` | | `keyword_reply_enabled` | `false` | 群聊关键词自动回复 | 布尔 | -| `repeat_enabled` | `false` | 群聊复读(连续3条相同消息时复读) | 布尔 | -| `inverted_question_enabled` | `false` | 倒问号(复读触发时若消息为问号则发送¿) | 布尔 | +| `repeat_enabled` | `false` | 群聊复读(连续 N 条相同消息时复读) | 布尔 | +| `repeat_threshold` | `3` | 触发复读所需的连续相同消息条数(来自不同发送者) | 整数,2–20 | +| `repeat_cooldown_minutes` | `60` | 复读冷却时间(分钟)。同一内容被复读后,在冷却期内不再重复复读。?和 ? 视为等价。0 = 无冷却 | 整数,≥ 0 | +| `inverted_question_enabled` | `false` | 倒问号(复读触发时若消息为问号则发送 ¿) | 布尔 | 兼容:历史字段 `[core].keyword_reply_enabled` 仍可读取,建议迁移到 `[easter_egg]`。 diff --git a/docs/development.md b/docs/development.md index 6dab818..522c1f6 100644 --- a/docs/development.md +++ b/docs/development.md @@ -22,8 +22,14 @@ src/Undefined/ │ ├── agents/ # 智能体 (独立自主的子 AI,负责处理诸如 Web 搜索、文件分析的具体长时任务) │ ├── commands/ # 中心化斜杠指令系统 (实现如 /help, /stats, /addadmin 等平台功能) │ └── anthropic_skills/# Anthropic 协议集成的外部 Skills (兼容 SKILL.md 格式) +├── config/ # 配置系统 (loader.py TOML 解析, models.py 数据模型, hot_reload.py 热更新) +├── api/ # Management API + Runtime API +│ ├── routes/ # 路由子模块 (chat, tools, naga, system, memes, memory, cognitive, health) +│ ├── app.py # aiohttp 服务主入口 (薄包装委派到 routes/) +│ └── _openapi.py # OpenAPI 文档生成 +├── memes/ # 表情包库 (两阶段 AI 管线, SQLite + ChromaDB) ├── services/ # 核心运行服务 (Queue 任务队列, Command 命令分发, Security 安全防护拦截) -├── utils/ # 通用支持工具组 (包含历史处理、JSON原子读写加锁 IO 操作等) +├── utils/ # 通用支持工具组 (io.py 异步原子读写, history.py, 
coerce.py 类型强转, fake_at.py 假@检测) ├── handlers.py # 最外层 OneBot 消息分流处理层 └── onebot.py # OneBot WebSocket 客户端核心连接 ``` diff --git a/docs/slash-commands.md b/docs/slash-commands.md index 6acf72b..f0baa17 100644 --- a/docs/slash-commands.md +++ b/docs/slash-commands.md @@ -86,25 +86,37 @@ Undefined 提供了一套强大的斜杠指令(Slash Commands)系统。管 #### 2. 消息总结与侧写查看 -- **/profile [group]** +- **/profile [group] [-f|-r|-t] [目标ID]** - **说明**:查看用户或群聊的认知侧写。侧写由系统根据聊天历史自动生成和更新。 - **别名**:`/me`、`/p` - **参数**: | 参数 | 是否必填 | 说明 | |------|----------|------| - | `group` | 可选 | 传入 `group` 时查看当前群聊的侧写(仅群聊可用) | + | `group` / `g` | 可选 | 查看群聊侧写(仅群聊可用) | + | `-f` / `--forward` | 可选 | 合并转发模式输出(默认) | + | `-r` / `--render` | 可选 | 渲染为图片发送 | + | `-t` / `--text` | 可选 | 直接文本消息发送 | + | ` ` | 可选 | 🔒 超管专用:查看指定用户的侧写 | + | `g <群号>` | 可选 | 🔒 超管专用:查看指定群聊的侧写 | - **行为**: - - **私聊**:只能查看自己的用户侧写,不支持 `group` 参数。 - - **群聊**:不带参数查看自己的用户侧写,带 `group` 查看当前群聊侧写。 - - 过长侧写会自动截断(3000 字符上限)。 + - **私聊**:查看自己的用户侧写,不支持 `group` 参数。 + - **群聊**:不带参数查看自己的用户侧写,带 `group` / `g` 查看当前群聊侧写。 + - **超管指定目标**:超级管理员可传入 QQ 号或群号查看任意用户/群的侧写,非超管使用时提示无权限。 + - **输出模式**:默认合并转发;`-r` 渲染为图片;`-t` 直接文本发送。 - **限流**:普通用户 60 秒,管理员 10 秒,超管无限制。 - **示例**: ``` - /profile → 查看自己的侧写 + /profile → 查看自己的侧写(合并转发) + /p -r → 查看自己的侧写(渲染图片) + /p -t → 查看自己的侧写(直接文本) /me → 同上(别名) /profile group → 查看当前群聊的侧写 + /p g → 同上 + /p 123456 → 🔒 超管:查看QQ号123456的侧写 + /p g 789012 → 🔒 超管:查看群号789012的侧写 + /p 123456 -r → 🔒 超管:查看指定用户侧写(渲染图片) ``` - **/summary [条数|时间范围] [自定义描述]** From 33d715859c59f07a6dda85187cd39ff7ec543aa9 Mon Sep 17 00:00:00 2001 From: Null <1708213363@qq.com> Date: Sun, 19 Apr 2026 14:21:59 +0800 Subject: [PATCH 37/57] =?UTF-8?q?docs(changelog):=20=E9=87=8D=E5=86=99=20v?= =?UTF-8?q?3.3.2=20=E6=9D=A1=E7=9B=AE=E8=A6=86=E7=9B=96=E6=89=80=E6=9C=89?= =?UTF-8?q?=20feature/u-guess=20=E5=8F=98=E6=9B=B4?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Claude Opus 4.6 --- CHANGELOG.md | 31 +++++++++++++++++++------------ 1 file changed, 
19 insertions(+), 12 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index e7b4f5d..83f2f07 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,15 +1,22 @@ -## v3.3.2 架构拆分、假@检测与多模式侧写 - -Runtime API 大规模拆分、假@检测、/profile 多输出模式、超管跨目标侧写查看、复读冷却机制,以及配置子模块化与大量工具函数增强。 - -- **api/app.py 拆分**:将 2491 行的巨型路由文件拆分为 8 个独立路由子模块 (`api/routes/`),app.py 仅保留薄包装委派层(333 行)。 -- **假@检测**:群聊中 `@+Bot昵称` 的文本形式也被识别为@消息(自动获取群内昵称,防竞态)。支持 `@昵称 /命令` 正常触发斜杠指令。 -- **/profile 多输出模式**:新增 `-f`(合并转发,默认)、`-r`(渲染图片)、`-t`(直接文本)三种侧写输出模式。合并转发分元数据与内容两条消息;渲染模式优化了元数据与正文的视觉排版。 -- **超管跨目标侧写**:超级管理员可通过 `/p ` 和 `/p g <群号>` 查看任意用户/群的侧写。 -- **复读冷却**:同一内容被复读后,在可配置的冷却期(默认 60 分钟)内不再重复复读。?和 ? 视为等价。 -- **配置子模块化**:config/ 拆分为 loader.py、models.py、hot_reload.py,新增 `repeat_threshold`(2–20)与 `repeat_cooldown_minutes` 配置项。 -- **utils 增强**:新增 `coerce.py`(安全类型强转)、`fake_at.py`(假@文本检测与解析)。 -- 大量测试覆盖新增(1438+ 测试),ruff + mypy 零错误。 +## v3.3.2 架构重构、假@检测与认知侧写增强 + +围绕核心架构进行了大规模重构与功能增强:Runtime API 拆分为路由子模块、配置系统模块化拆分、新增假@检测机制与 /profile 多输出模式。同步引入复读机制全面升级(可配置阈值与冷却)、消息预处理并行化、WebUI 多项交互功能,以及 arXiv 论文分析 Agent 和安全计算器工具。测试覆盖从约 800 提升至 1438+。 + +- 新增假@检测:群聊中 `@+Bot昵称` 的文本形式也被识别为@消息,自动从群上下文获取昵称(防竞态),`@昵称 /命令` 可正常触发斜杠指令。 +- `/profile` 命令支持三种输出模式:`-f` 合并转发(默认,分元数据与内容两条消息)、`-r` 渲染为图片、`-t` 直接文本发送。 +- 超级管理员可通过 `/p ` 和 `/p g <群号>` 跨目标查看任意用户或群聊的认知侧写。 +- 复读系统全面升级:触发阈值可配置(`repeat_threshold`,2–20)、Bot 发言不计入复读链、新增复读冷却机制(`repeat_cooldown_minutes`,默认 60 分钟,?与 ? 
等价)。 +- Runtime API 从 2491 行的单体 `app.py` 拆分为 8 个路由子模块 (`api/routes/`),主文件仅保留薄包装委派层。 +- 配置系统模块化拆分:`config/` 拆为 `loader.py`、`models.py`、`hot_reload.py`,`sync_config_template` 脚本支持报告注释变更路径。 +- 消息预处理流程并行化:使用 `asyncio.gather` 并行执行安全检查、认知检索和假@检测,降低消息处理延迟。 +- 新增 arXiv 论文深度分析 Agent,提供论文搜索、摘要提取与关键信息分析能力。 +- 新增 `calculator` 多功能安全计算器工具。 +- 新增消息历史限制全面可配置化(`[history].max_records`)。 +- 新增 `utils/coerce.py`(安全类型强转)与 `utils/fake_at.py`(假@文本检测与解析)公共模块。 +- WebUI 新增功能:Cmd/Ctrl+K 命令面板、骨架屏加载态、日志时间过滤、资源趋势图、TOML 原始视图、配置版本历史与回滚、长期记忆完整 CRUD 管理、Modal 焦点陷阱。 +- 修复队列系统 historian 模型未注册的调度问题。 +- 修复 `/profile` 渲染留白和字体过小问题,使用 WebUI 配色方案并提高截断上限至 5000 字符。 +- 测试覆盖大幅补齐(804 → 1438+),ruff + mypy 零错误。 --- From dcf02ab40c584554f26309eaaa65528c60dabf8d Mon Sep 17 00:00:00 2001 From: Null <1708213363@qq.com> Date: Sun, 19 Apr 2026 14:34:40 +0800 Subject: [PATCH 38/57] chore(version): bump version to 3.3.2 --- apps/undefined-console/package-lock.json | 4 ++-- apps/undefined-console/package.json | 2 +- apps/undefined-console/src-tauri/Cargo.lock | 2 +- apps/undefined-console/src-tauri/Cargo.toml | 2 +- apps/undefined-console/src-tauri/tauri.conf.json | 2 +- pyproject.toml | 2 +- src/Undefined/__init__.py | 2 +- uv.lock | 2 +- 8 files changed, 9 insertions(+), 9 deletions(-) diff --git a/apps/undefined-console/package-lock.json b/apps/undefined-console/package-lock.json index 4c32b00..cb66706 100644 --- a/apps/undefined-console/package-lock.json +++ b/apps/undefined-console/package-lock.json @@ -1,12 +1,12 @@ { "name": "undefined-console", - "version": "3.3.1", + "version": "3.3.2", "lockfileVersion": 3, "requires": true, "packages": { "": { "name": "undefined-console", - "version": "3.3.1", + "version": "3.3.2", "dependencies": { "@tauri-apps/api": "^2.3.0", "@tauri-apps/plugin-http": "^2.3.0" diff --git a/apps/undefined-console/package.json b/apps/undefined-console/package.json index 1a85c4c..d7b6e21 100644 --- a/apps/undefined-console/package.json +++ b/apps/undefined-console/package.json @@ -1,7 +1,7 @@ { 
"name": "undefined-console", "private": true, - "version": "3.3.1", + "version": "3.3.2", "type": "module", "scripts": { "tauri": "tauri", diff --git a/apps/undefined-console/src-tauri/Cargo.lock b/apps/undefined-console/src-tauri/Cargo.lock index 190f9f8..f9b61ac 100644 --- a/apps/undefined-console/src-tauri/Cargo.lock +++ b/apps/undefined-console/src-tauri/Cargo.lock @@ -4063,7 +4063,7 @@ checksum = "562d481066bde0658276a35467c4af00bdc6ee726305698a55b86e61d7ad82bb" [[package]] name = "undefined_console" -version = "3.3.1" +version = "3.3.2" dependencies = [ "serde", "serde_json", diff --git a/apps/undefined-console/src-tauri/Cargo.toml b/apps/undefined-console/src-tauri/Cargo.toml index c3a9356..3acc528 100644 --- a/apps/undefined-console/src-tauri/Cargo.toml +++ b/apps/undefined-console/src-tauri/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "undefined_console" -version = "3.3.1" +version = "3.3.2" description = "Undefined cross-platform management console" authors = ["Undefined contributors"] license = "MIT" diff --git a/apps/undefined-console/src-tauri/tauri.conf.json b/apps/undefined-console/src-tauri/tauri.conf.json index a8a61f1..7fe5107 100644 --- a/apps/undefined-console/src-tauri/tauri.conf.json +++ b/apps/undefined-console/src-tauri/tauri.conf.json @@ -1,7 +1,7 @@ { "$schema": "https://schema.tauri.app/config/2", "productName": "Undefined Console", - "version": "3.3.1", + "version": "3.3.2", "identifier": "com.undefined.console", "build": { "beforeDevCommand": "npm run dev", diff --git a/pyproject.toml b/pyproject.toml index 9128920..6dcc6a3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "Undefined-bot" -version = "3.3.1" +version = "3.3.2" description = "QQ bot platform with cognitive memory architecture and multi-agent Skills, via OneBot V11." 
readme = "README.md" authors = [ diff --git a/src/Undefined/__init__.py b/src/Undefined/__init__.py index 5722e02..c2c8768 100644 --- a/src/Undefined/__init__.py +++ b/src/Undefined/__init__.py @@ -1,3 +1,3 @@ """Undefined - A high-performance, highly scalable QQ group and private chat robot based on a self-developed architecture.""" -__version__ = "3.3.1" +__version__ = "3.3.2" diff --git a/uv.lock b/uv.lock index 810f4a0..d1d2053 100644 --- a/uv.lock +++ b/uv.lock @@ -4638,7 +4638,7 @@ wheels = [ [[package]] name = "undefined-bot" -version = "3.3.1" +version = "3.3.2" source = { editable = "." } dependencies = [ { name = "aiofiles" }, From 64af9b3b803268018d826f93d44194fb092efea7 Mon Sep 17 00:00:00 2001 From: Null <1708213363@qq.com> Date: Sun, 19 Apr 2026 14:39:02 +0800 Subject: [PATCH 39/57] fix(test): skip LaTeX render tests when Playwright browser binary missing The tests only caught ImportError but CI has Playwright installed without browser binaries, causing a runtime Error. Now checks the error message returned by execute() for 'Executable doesn't exist' and skips. 
Co-authored-by: Claude Opus 4.6 --- tests/test_render_latex_tool.py | 51 ++++++++++++++------------------- 1 file changed, 21 insertions(+), 30 deletions(-) diff --git a/tests/test_render_latex_tool.py b/tests/test_render_latex_tool.py index 0595c99..d4d3e80 100644 --- a/tests/test_render_latex_tool.py +++ b/tests/test_render_latex_tool.py @@ -58,17 +58,14 @@ async def test_render_simple_equation() -> None: args = {"content": "E = mc^2", "output_format": "png"} - try: - result = await execute(args, context) - assert result == ' ' - assert len(mock_registry.registered_items) == 1 - assert mock_registry.registered_items[0]["kind"] == "image" - assert mock_registry.registered_items[0]["mime_type"] == "image/png" - assert mock_registry.registered_items[0]["size"] > 0 - except ImportError as e: - if "playwright" in str(e).lower(): - pytest.skip("Playwright 未安装,跳过测试") - raise + result = await execute(args, context) + if "渲染失败" in result and "Executable doesn't exist" in result: + pytest.skip("Playwright 浏览器未安装,跳过测试") + assert result == ' ' + assert len(mock_registry.registered_items) == 1 + assert mock_registry.registered_items[0]["kind"] == "image" + assert mock_registry.registered_items[0]["mime_type"] == "image/png" + assert mock_registry.registered_items[0]["size"] > 0 @pytest.mark.asyncio @@ -85,14 +82,11 @@ async def test_render_with_delimiters() -> None: args = {"content": r"\[ \int_0^\infty e^{-x^2} dx = \frac{\sqrt{\pi}}{2} \]"} - try: - result = await execute(args, context) - assert result == ' ' - assert len(mock_registry.registered_items) == 1 - except ImportError as e: - if "playwright" in str(e).lower(): - pytest.skip("Playwright 未安装,跳过测试") - raise + result = await execute(args, context) + if "渲染失败" in result and "Executable doesn't exist" in result: + pytest.skip("Playwright 浏览器未安装,跳过测试") + assert result == ' ' + assert len(mock_registry.registered_items) == 1 @pytest.mark.asyncio @@ -109,17 +103,14 @@ async def test_render_pdf_output() -> None: args = 
{"content": r"\frac{a}{b} + \sqrt{c}", "output_format": "pdf"} - try: - result = await execute(args, context) - assert result == ' ' - assert len(mock_registry.registered_items) == 1 - assert mock_registry.registered_items[0]["kind"] == "file" - assert mock_registry.registered_items[0]["mime_type"] == "application/pdf" - assert mock_registry.registered_items[0]["display_name"] == "latex.pdf" - except ImportError as e: - if "playwright" in str(e).lower(): - pytest.skip("Playwright 未安装,跳过测试") - raise + result = await execute(args, context) + if "渲染失败" in result and "Executable doesn't exist" in result: + pytest.skip("Playwright 浏览器未安装,跳过测试") + assert result == ' ' + assert len(mock_registry.registered_items) == 1 + assert mock_registry.registered_items[0]["kind"] == "file" + assert mock_registry.registered_items[0]["mime_type"] == "application/pdf" + assert mock_registry.registered_items[0]["display_name"] == "latex.pdf" @pytest.mark.asyncio From 73954722efa2f6028b7419c8fd3ed21441e92f8a Mon Sep 17 00:00:00 2001 From: Null <1708213363@qq.com> Date: Sun, 19 Apr 2026 14:47:23 +0800 Subject: [PATCH 40/57] fix: address PR review findings from Devin and Codex MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - fix(prompt): split concatenated bullet points in judge_meme_image.txt - fix(prompts): use configurable repeat_threshold instead of hardcoded 3 - fix(memes): guard safe_int(0) → None for group_id sentinel value - fix(fake_at): normalize cached nicknames with NFKC (not just casefold) - fix(handlers): skip repeat counting for empty text messages Co-authored-by: Claude Opus 4.6 --- res/prompts/judge_meme_image.txt | 3 ++- src/Undefined/ai/prompts.py | 7 +++++-- src/Undefined/handlers.py | 2 +- src/Undefined/memes/service.py | 2 +- src/Undefined/utils/fake_at.py | 2 +- 5 files changed, 10 insertions(+), 6 deletions(-) diff --git a/res/prompts/judge_meme_image.txt b/res/prompts/judge_meme_image.txt index 5a3d02e..fdb480b 100644 --- 
a/res/prompts/judge_meme_image.txt +++ b/res/prompts/judge_meme_image.txt @@ -21,4 +21,5 @@ - 必须且只能调用 `submit_meme_judgement` - `is_meme` 仅表示“适不适合放进聊天表情包库”,不是“图里有没有梗” - `reason` 用一句简短中文说明依据 -- 只有当“整张图整体上就是一张可直接发送的表情包”时,才能给 `is_meme=true`- 如果收到的是一张网格图(多帧拼接)或多张图片,说明原图是动图/GIF:请综合所有帧判断这个动图整体是否适合作为表情包,不要只看单帧 \ No newline at end of file +- 只有当“整张图整体上就是一张可直接发送的表情包”时,才能给 `is_meme=true` +- 如果收到的是一张网格图(多帧拼接)或多张图片,说明原图是动图/GIF:请综合所有帧判断这个动图整体是否适合作为表情包,不要只看单帧 \ No newline at end of file diff --git a/src/Undefined/ai/prompts.py b/src/Undefined/ai/prompts.py index dda6b04..b1dd63b 100644 --- a/src/Undefined/ai/prompts.py +++ b/src/Undefined/ai/prompts.py @@ -222,7 +222,8 @@ def _build_model_config_info(self, runtime_config: Any) -> str: '关键词自动回复(触发词"心理委员"等,系统自动发送固定回复)' ) if repeat_enabled: - desc = "复读(群聊连续3条相同消息时自动复读)" + threshold = int(getattr(runtime_config, "repeat_threshold", 3)) + desc = f"复读(群聊连续{threshold}条相同消息时自动复读)" if inverted_question_enabled: desc += ",倒问号(复读触发时若消息为问号则发送¿)" easter_egg_parts.append(desc) @@ -342,6 +343,7 @@ async def build_messages( keyword_reply_enabled = False repeat_enabled = False + repeat_threshold = 3 inverted_question_enabled = False if self._runtime_config_getter is not None: try: @@ -350,6 +352,7 @@ async def build_messages( getattr(runtime_config, "keyword_reply_enabled", False) ) repeat_enabled = bool(getattr(runtime_config, "repeat_enabled", False)) + repeat_threshold = int(getattr(runtime_config, "repeat_threshold", 3)) inverted_question_enabled = bool( getattr(runtime_config, "inverted_question_enabled", False) ) @@ -375,7 +378,7 @@ async def build_messages( if is_group_context and repeat_enabled: repeat_desc = ( "【系统行为说明】\n" - "当前群聊已开启复读彩蛋:当群聊中连续出现3条内容相同且来自不同人的消息时," + f"当前群聊已开启复读彩蛋:当群聊中连续出现{repeat_threshold}条内容相同且来自不同人的消息时," "系统会自动复读一条相同的消息,并在历史中写入" '以"[系统复读] "开头的消息。' ) diff --git a/src/Undefined/handlers.py b/src/Undefined/handlers.py index 357e0c6..37d7e8f 100644 --- a/src/Undefined/handlers.py +++ b/src/Undefined/handlers.py @@ -837,7 
+837,7 @@ async def _fetch_group_name() -> str: return # 复读功能:连续 N 条相同消息(来自不同发送者)时复读,N = repeat_threshold - if self.config.repeat_enabled: + if self.config.repeat_enabled and text: n = self.config.repeat_threshold async with self._get_repeat_lock(group_id): counter = self._repeat_counter.setdefault(group_id, []) diff --git a/src/Undefined/memes/service.py b/src/Undefined/memes/service.py index 7712232..1159d0c 100644 --- a/src/Undefined/memes/service.py +++ b/src/Undefined/memes/service.py @@ -739,7 +739,7 @@ async def send_meme_by_uid(self, uid: str, context: dict[str, Any]) -> str: attachments=history_attachments, ) else: - preferred_temp_group_id = safe_int(context.get("group_id")) + preferred_temp_group_id = safe_int(context.get("group_id")) or None sent_message_id = await sender.send_private_message( int(target_id), cq_message, diff --git a/src/Undefined/utils/fake_at.py b/src/Undefined/utils/fake_at.py index 0ac0200..c3df1ae 100644 --- a/src/Undefined/utils/fake_at.py +++ b/src/Undefined/utils/fake_at.py @@ -163,7 +163,7 @@ async def _fetch(self, group_id: int) -> frozenset[str]: for key in ("card", "nickname"): val = str(info.get(key, "") or "").strip() if val: - names.add(val.casefold()) + names.add(_normalize(val)) except Exception as exc: logger.debug( "[假@] 获取 bot 群成员信息失败: group=%s err=%s", From 5f659a83eac39b9f8af46d3e513ded5972932f5a Mon Sep 17 00:00:00 2001 From: Null <1708213363@qq.com> Date: Sun, 19 Apr 2026 14:57:45 +0800 Subject: [PATCH 41/57] fix: address additional Devin review flags - Add 'pic' key to _MEDIA_LABELS for correct tag-based fallback label - Harden config backup path validation with backslash check and resolve()-based containment verification Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> Co-authored-by: Claude Opus 4.6 --- src/Undefined/attachments.py | 1 + src/Undefined/webui/routes/_config.py | 6 ++++-- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/src/Undefined/attachments.py 
b/src/Undefined/attachments.py index e8fa8fc..951a75b 100644 --- a/src/Undefined/attachments.py +++ b/src/Undefined/attachments.py @@ -47,6 +47,7 @@ "audio": "音频", "video": "视频", "record": "语音", + "pic": "图片", } _WINDOWS_ABS_PATH_RE = re.compile(r"^[A-Za-z]:[\\/]") _DEFAULT_REMOTE_TIMEOUT_SECONDS = 120.0 diff --git a/src/Undefined/webui/routes/_config.py b/src/Undefined/webui/routes/_config.py index c6c65b3..e3f44e4 100644 --- a/src/Undefined/webui/routes/_config.py +++ b/src/Undefined/webui/routes/_config.py @@ -239,9 +239,11 @@ async def config_restore_handler(request: web.Request) -> Response: except Exception: return web.json_response({"error": "Invalid JSON"}, status=400) name = str(data.get("name", "")) - if not name or ".." in name or "/" in name: + if not name or ".." in name or "/" in name or "\\" in name: + return web.json_response({"error": "Invalid backup name"}, status=400) + backup_path = (_BACKUP_DIR / name).resolve() + if not str(backup_path).startswith(str(_BACKUP_DIR.resolve())): return web.json_response({"error": "Invalid backup name"}, status=400) - backup_path = _BACKUP_DIR / name if not backup_path.exists(): return web.json_response({"error": "Backup not found"}, status=404) content = backup_path.read_text(encoding="utf-8") From 4c973b934ee679f5c1b987df303c2d6f6ce356e5 Mon Sep 17 00:00:00 2001 From: Null <1708213363@qq.com> Date: Sun, 19 Apr 2026 15:06:40 +0800 Subject: [PATCH 42/57] fix(config): parse gif_analysis_mode/gif_analysis_frames from TOML MemeConfig had these fields with defaults but _parse_memes_config never read them from the [memes] section, silently ignoring user configuration. 
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> Co-authored-by: Claude Opus 4.6 --- src/Undefined/config/domain_parsers.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/Undefined/config/domain_parsers.py b/src/Undefined/config/domain_parsers.py index ebb2222..07f807b 100644 --- a/src/Undefined/config/domain_parsers.py +++ b/src/Undefined/config/domain_parsers.py @@ -172,6 +172,8 @@ def _parse_memes_config(data: dict[str, Any]) -> MemeConfig: worker_max_concurrency=max( 1, _coerce_int(section.get("worker_max_concurrency"), 4) ), + gif_analysis_mode=_coerce_str(section.get("gif_analysis_mode"), "grid"), + gif_analysis_frames=max(1, _coerce_int(section.get("gif_analysis_frames"), 6)), ) From 589f949a84ad5ec67ac4a72f9fc9d86a22f9ac0b Mon Sep 17 00:00:00 2001 From: Null <1708213363@qq.com> Date: Sun, 19 Apr 2026 15:26:11 +0800 Subject: [PATCH 43/57] fix(memes): clean up GIF multi-frame analysis temp files _prepare_gif_multi_frames creates per-frame PNG files ({uid}_f{i}.png) for LLM analysis but they were never cleaned up. 
Add _cleanup_gif_frame_files helper and call it: - In _cleanup_meme_artifacts when uid is provided - In delete_meme - After judge/describe AI calls complete (frames no longer needed) Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> Co-authored-by: Claude Opus 4.6 --- src/Undefined/memes/service.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/src/Undefined/memes/service.py b/src/Undefined/memes/service.py index 1159d0c..9b02d41 100644 --- a/src/Undefined/memes/service.py +++ b/src/Undefined/memes/service.py @@ -438,6 +438,7 @@ async def delete_meme(self, uid: str) -> bool: self._delete_file_if_exists, Path(record.preview_path), ) + await asyncio.to_thread(self._cleanup_gif_frame_files, uid) return True def _delete_file_if_exists(self, path: Path) -> None: @@ -446,6 +447,17 @@ def _delete_file_if_exists(self, path: Path) -> None: except OSError: logger.debug("[memes] 删除文件失败: path=%s", path, exc_info=True) + def _cleanup_gif_frame_files(self, uid: str) -> None: + """清理 GIF 多帧分析产生的临时帧文件 ({uid}_f{i}.png)。""" + preview_dir = self._preview_dir() + for frame_file in preview_dir.glob(f"{uid}_f*.png"): + try: + frame_file.unlink(missing_ok=True) + except OSError: + logger.debug( + "[memes] 删除帧文件失败: path=%s", frame_file, exc_info=True + ) + async def _cleanup_meme_artifacts( self, *, @@ -472,6 +484,8 @@ async def _cleanup_meme_artifacts( await asyncio.to_thread(self._delete_file_if_exists, blob_path) if preview_path is not None and preview_path != blob_path: await asyncio.to_thread(self._delete_file_if_exists, preview_path) + if uid: + await asyncio.to_thread(self._cleanup_gif_frame_files, uid) async def search_memes( self, @@ -1025,6 +1039,8 @@ async def _process_ingest_job(self, job: Mapping[str, Any]) -> None: ) judgement = {"is_meme": False} if not bool(judgement.get("is_meme", False)): + if isinstance(analyze_path, list): + await asyncio.to_thread(self._cleanup_gif_frame_files, uid) await self._cleanup_meme_artifacts( 
uid=None, blob_path=blob_path, @@ -1041,6 +1057,9 @@ async def _process_ingest_job(self, job: Mapping[str, Any]) -> None: "[memes] describe stage failed, drop uid=%s err=%s", uid, exc ) described = {"description": "", "tags": []} + # GIF 多帧文件用完即清理 + if isinstance(analyze_path, list): + await asyncio.to_thread(self._cleanup_gif_frame_files, uid) tags = _normalize_tags(described.get("tags")) auto_description = str(described.get("description") or "").strip() if not auto_description and not tags: From b80277cc7fede94db543473a7995be75126e59f4 Mon Sep 17 00:00:00 2001 From: Null <1708213363@qq.com> Date: Sun, 19 Apr 2026 15:38:11 +0800 Subject: [PATCH 44/57] fix(attachments): skip prompt_ref append on render error _render_image_tag and _render_file_tag error paths returned early but attachments.append(record.prompt_ref()) still executed unconditionally. Change both helpers to return bool; only append on success. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> Co-authored-by: Claude Opus 4.6 --- src/Undefined/attachments.py | 21 ++++++++++++--------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/src/Undefined/attachments.py b/src/Undefined/attachments.py index 951a75b..403154d 100644 --- a/src/Undefined/attachments.py +++ b/src/Undefined/attachments.py @@ -1154,9 +1154,9 @@ async def render_message_with_attachments( # Route by media type if record.media_type == "image": - _render_image_tag(record, uid, strict, delivery_parts, history_parts) + ok = _render_image_tag(record, uid, strict, delivery_parts, history_parts) else: - _render_file_tag( + ok = _render_file_tag( record, uid, strict, @@ -1165,7 +1165,8 @@ async def render_message_with_attachments( pending_files, ) - attachments.append(record.prompt_ref()) + if ok: + attachments.append(record.prompt_ref()) delivery_parts.append(message[last_index:]) history_parts.append(message[last_index:]) @@ -1183,8 +1184,8 @@ def _render_image_tag( strict: bool, delivery_parts: list[str], 
history_parts: list[str], -) -> None: - """Render an image attachment as an inline CQ:image.""" +) -> bool: + """Render an image attachment as an inline CQ:image. Returns True on success.""" image_source = record.source_ref if record.local_path: image_source = Path(record.local_path).resolve().as_uri() @@ -1194,7 +1195,7 @@ def _render_image_tag( raise AttachmentRenderError(f"图片 UID 缺少可发送的文件:{uid}") delivery_parts.append(replacement) history_parts.append(replacement) - return + return False cq_args = [f"file={image_source}"] for key, value in dict(getattr(record, "segment_data", {}) or {}).items(): @@ -1210,6 +1211,7 @@ def _render_image_tag( history_parts.append(f"[图片 uid={uid} name={record.display_name}]") else: history_parts.append(f"[图片 uid={uid}]") + return True def _render_file_tag( @@ -1219,21 +1221,22 @@ def _render_file_tag( delivery_parts: list[str], history_parts: list[str], pending_files: list[AttachmentRecord], -) -> None: - """Render a non-image attachment as a pending file send.""" +) -> bool: + """Render a non-image attachment as a pending file send. Returns True on success.""" if not record.local_path or not Path(record.local_path).is_file(): replacement = f"[文件 uid={uid} 缺少本地文件]" if strict: raise AttachmentRenderError(f"文件 UID 缺少本地文件,无法发送:{uid}") delivery_parts.append(replacement) history_parts.append(replacement) - return + return False # Remove from delivery text (file sent separately) # Keep a readable placeholder in history name_part = f" name={record.display_name}" if record.display_name else "" history_parts.append(f"[文件 uid={uid}{name_part}]") pending_files.append(record) + return True # Backward-compatible alias From 58df7907771774f42ee08b1b9d6bcb3f478eead4 Mon Sep 17 00:00:00 2001 From: Null <1708213363@qq.com> Date: Sun, 19 Apr 2026 15:38:42 +0800 Subject: [PATCH 45/57] docs(prompts): strengthen naga_code_analysis_agent call guidance Reinforce that NagaAgent technical questions must call the agent before replying. 
Add explicit when_to_call scenarios and emphasize not relying on memory for frequently-updated project details. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> Co-authored-by: Claude Opus 4.6 --- res/prompts/undefined_nagaagent.xml | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/res/prompts/undefined_nagaagent.xml b/res/prompts/undefined_nagaagent.xml index acd3cc3..205b99f 100644 --- a/res/prompts/undefined_nagaagent.xml +++ b/res/prompts/undefined_nagaagent.xml @@ -305,8 +305,8 @@ 明确的 NagaAgent 技术问题或讨论 -直接调用 naga_code_analysis_agent,确认相关性后再回复 -如果只是泛泛提到naga但不是技术讨论,不要回复 +**必须**先调用 naga_code_analysis_agent 获取信息,再基于返回结果回复 +如果只是泛泛提到naga但不是技术讨论,不要回复;但只要涉及技术细节,一定要先调 agent @@ -437,9 +437,20 @@ - 对于任何涉及 NagaAgent 的技术问题,直接调用 naga_code_analysis_agent 处理。 + 对于任何涉及 NagaAgent 的技术问题,**必须先调用 naga_code_analysis_agent 获取准确信息后再回复**。 + 不要依赖自身记忆或猜测来回答 NagaAgent 相关问题——该项目代码频繁更新,只有通过 agent 实时查阅才能保证准确。 该 Agent 内部拥有自己的工具集(read_naga_intro、read_file、search_file_content 等), 这些内部工具你无法直接调用,你只需要调用 naga_code_analysis_agent 即可。 + + From e02e0665d7a3b17d6d499cb3b3a95ca8e7f6c6f5 Mon Sep 17 00:00:00 2001 From: Null <1708213363@qq.com> Date: Sun, 19 Apr 2026 15:57:14 +0800 Subject: [PATCH 46/57] fix(prompt): clarify keyword auto-reply is code-path-only to prevent AI mimicry MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The AI model (kimi-k2.5) read the system prompt about [系统关键词自动回复] and mimicked the format via send_message tool, fabricating a reply that didn't exist in the codebase. Reworded the prompt to explicitly state these messages are generated by a separate code path (handlers.py), use fixed responses, and never go through the AI's tool calls. 
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> Co-authored-by: Claude Opus 4.6+ 以下场景必须调用 naga_code_analysis_agent: + - 用户询问 NagaAgent 的功能、配置、部署、构建方式 + - 用户遇到 NagaAgent 相关的报错或问题 + - 用户想了解 NagaAgent 的架构、代码逻辑、技能系统等 + - 用户提到 NagaAgent 的任何技术细节(API、openclaw、干员、技能等) + - 讨论涉及 NagaAgent 与其他系统的集成或对比 + 只有纯闲聊式提及(如"naga好用吗"这类不需要技术细节的对话)才可以不调用。 + --- src/Undefined/ai/prompts.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/Undefined/ai/prompts.py b/src/Undefined/ai/prompts.py index b1dd63b..e938181 100644 --- a/src/Undefined/ai/prompts.py +++ b/src/Undefined/ai/prompts.py @@ -364,11 +364,13 @@ async def build_messages( { "role": "system", "content": ( - "【系统行为说明】\n" + "【系统行为说明 — 关键词自动回复】\n" '当前群聊已开启关键词自动回复彩蛋(例如触发词"心理委员")。' - "命中时,系统可能直接发送固定回复,并在历史中写入" - '以"[系统关键词自动回复] "开头的消息。\n\n' - "这类消息属于系统预设机制,不代表你在该轮主动决策。" + "该功能由 handlers.py 中的独立代码路径处理," + "在消息到达你之前就已完成发送。\n\n" + '发送后,历史中会出现以"[系统关键词自动回复] "开头的消息。' + "这些消息完全由系统代码生成(固定文案如'受着''那咋了'等)," + "不经过你的工具调用,与你的决策无关。\n\n" "阅读历史时请识别该前缀,避免误判为人格漂移或上下文异常。" "除非用户主动询问,否则不要主动解释此机制。" ), From d7a294f69686c257a4adea161e3e5dbd59ede514 Mon Sep 17 00:00:00 2001 From: Null <1708213363@qq.com> Date: Sun, 19 Apr 2026 15:59:12 +0800 Subject: [PATCH 47/57] fix(repeat): don't silently drop messages when cooldown suppresses repeat Move the early 'return' inside the else branch (actual repeat sent) so that when cooldown suppresses a repeat, the message continues through downstream handlers (bilibili/arxiv/command/AI auto-reply) instead of being silently dropped. Update test_repeat_cooldown_suppresses_same_text to verify that ai_coordinator.handle_auto_reply is still called after suppression. 
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> Co-authored-by: Claude Opus 4.6 --- src/Undefined/handlers.py | 2 +- tests/test_handlers_repeat.py | 7 +++++-- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/src/Undefined/handlers.py b/src/Undefined/handlers.py index 37d7e8f..33d6a07 100644 --- a/src/Undefined/handlers.py +++ b/src/Undefined/handlers.py @@ -883,7 +883,7 @@ async def _fetch_group_name() -> str: reply_text, history_prefix=REPEAT_REPLY_HISTORY_PREFIX, ) - return + return # Bilibili 视频自动提取 if self.config.bilibili_auto_extract_enabled: diff --git a/tests/test_handlers_repeat.py b/tests/test_handlers_repeat.py index 4f9a830..c29219d 100644 --- a/tests/test_handlers_repeat.py +++ b/tests/test_handlers_repeat.py @@ -353,18 +353,21 @@ async def test_repeat_custom_threshold_4() -> None: @pytest.mark.asyncio async def test_repeat_cooldown_suppresses_same_text() -> None: - """复读触发后,同一内容在冷却期内再次满足条件也不触发。""" + """复读触发后,同一内容在冷却期内再次满足条件也不触发,但消息继续处理。""" handler = _build_handler(repeat_enabled=True, repeat_cooldown_minutes=60) # 第一轮:触发复读 for uid in [20001, 20002, 20003]: await handler.handle_message(_group_event(sender_id=uid, text="草")) assert handler.sender.send_group_message.call_count == 1 - # 第二轮:同一内容,应被冷却抑制 + # 第二轮:同一内容,应被冷却抑制——复读不发送 for uid in [20004, 20005, 20006]: await handler.handle_message(_group_event(sender_id=uid, text="草")) assert handler.sender.send_group_message.call_count == 1 # 不增加 + # 冷却抑制后消息应继续处理(到达 AI auto-reply),不应被 silently dropped + assert handler.ai_coordinator.handle_auto_reply.call_count > 0 + @pytest.mark.asyncio async def test_repeat_cooldown_allows_different_text() -> None: From 1daa53fbac985e97085a5f0c472ab9b5b39fbf44 Mon Sep 17 00:00:00 2001 From: Null <1708213363@qq.com> Date: Sun, 19 Apr 2026 16:24:31 +0800 Subject: [PATCH 48/57] fix(memes,repeat): address 3 Devin review bugs 1. 
Meme reanalyze now uses GIF multi-frame analysis (Flag #1): _process_reanalyze_job checks record.is_animated + gif_analysis_mode and calls _prepare_gif_multi_frames, matching the ingest code path. Frame files are cleaned up in all exit paths. 2. Repeat cooldown dict no longer grows unboundedly (Flag #17): _record_repeat_cooldown now prunes expired entries on each insert, preventing slow memory leak from accumulated unique texts per group. 3. GIF frame files cleaned up on retryable LLM errors (Flag #25): Both judge and describe stages in the ingest path now clean up multi-frame temp files before re-raising retryable exceptions. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> Co-authored-by: Claude Opus 4.6 --- src/Undefined/handlers.py | 13 +++++++++++-- src/Undefined/memes/service.py | 28 +++++++++++++++++++++++++++- 2 files changed, 38 insertions(+), 3 deletions(-) diff --git a/src/Undefined/handlers.py b/src/Undefined/handlers.py index 33d6a07..34fa9f6 100644 --- a/src/Undefined/handlers.py +++ b/src/Undefined/handlers.py @@ -166,10 +166,19 @@ def _is_repeat_on_cooldown(self, group_id: int, text: str) -> bool: return (time.monotonic() - last_time) < cooldown_minutes * 60 def _record_repeat_cooldown(self, group_id: int, text: str) -> None: - """记录复读冷却时间戳。""" + """记录复读冷却时间戳,同时清理已过期条目防止内存泄漏。""" key = self._normalize_repeat_text(text) group_cd = self._repeat_cooldown.setdefault(group_id, {}) - group_cd[key] = time.monotonic() + now = time.monotonic() + cooldown_seconds = self.config.repeat_cooldown_minutes * 60 + # 清理已过期条目 + if cooldown_seconds > 0: + expired = [ + k for k, ts in group_cd.items() if (now - ts) >= cooldown_seconds + ] + for k in expired: + del group_cd[k] + group_cd[key] = now async def _annotate_meme_descriptions( self, diff --git a/src/Undefined/memes/service.py b/src/Undefined/memes/service.py index 9b02d41..96f80ab 100644 --- a/src/Undefined/memes/service.py +++ b/src/Undefined/memes/service.py @@ -872,30 +872,52 @@ async def 
_process_reanalyze_job(self, job: Mapping[str, Any]) -> None: return if self._ai_client is None: raise RuntimeError("reanalyze requires ai_client") - analyze_path = record.preview_path if record.preview_path else record.blob_path + analyze_path: str | list[str] = ( + record.preview_path if record.preview_path else record.blob_path + ) + # GIF 多帧模式:与 ingest 路径保持一致 + if record.is_animated: + cfg = self._cfg() + if str(getattr(cfg, "gif_analysis_mode", "grid")).lower() == "multi": + analyze_path = await self._prepare_gif_multi_frames( + Path(record.blob_path), uid + ) try: judgement = await self._ai_client.judge_meme_image(analyze_path) except Exception as exc: if _is_retryable_llm_error(exc): + if isinstance(analyze_path, list): + await asyncio.to_thread(self._cleanup_gif_frame_files, uid) raise logger.exception( "[memes] judge stage failed during reanalyze: uid=%s err=%s", uid, exc ) + if isinstance(analyze_path, list): + await asyncio.to_thread(self._cleanup_gif_frame_files, uid) return if not bool(judgement.get("is_meme", False)): + if isinstance(analyze_path, list): + await asyncio.to_thread(self._cleanup_gif_frame_files, uid) await self.delete_meme(uid) return try: described = await self._ai_client.describe_meme_image(analyze_path) except Exception as exc: if _is_retryable_llm_error(exc): + if isinstance(analyze_path, list): + await asyncio.to_thread(self._cleanup_gif_frame_files, uid) raise logger.exception( "[memes] describe stage failed during reanalyze: uid=%s err=%s", uid, exc, ) + if isinstance(analyze_path, list): + await asyncio.to_thread(self._cleanup_gif_frame_files, uid) return + # GIF 多帧文件用完即清理 + if isinstance(analyze_path, list): + await asyncio.to_thread(self._cleanup_gif_frame_files, uid) auto_description = str(described.get("description") or "").strip() next_tags = _normalize_tags(described.get("tags")) if not auto_description and not next_tags: @@ -1031,6 +1053,8 @@ async def _process_ingest_job(self, job: Mapping[str, Any]) -> None: judgement = 
await self._ai_client.judge_meme_image(analyze_path) except Exception as exc: if _is_retryable_llm_error(exc): + if isinstance(analyze_path, list): + await asyncio.to_thread(self._cleanup_gif_frame_files, uid) raise logger.exception( "[memes] judge stage failed, treat as non-meme: uid=%s err=%s", @@ -1052,6 +1076,8 @@ async def _process_ingest_job(self, job: Mapping[str, Any]) -> None: described = await self._ai_client.describe_meme_image(analyze_path) except Exception as exc: if _is_retryable_llm_error(exc): + if isinstance(analyze_path, list): + await asyncio.to_thread(self._cleanup_gif_frame_files, uid) raise logger.exception( "[memes] describe stage failed, drop uid=%s err=%s", uid, exc From 60e723bad8d24c0780f68c1b1b3f1072a596c8e7 Mon Sep 17 00:00:00 2001 From: Null <1708213363@qq.com> Date: Sun, 19 Apr 2026 16:37:43 +0800 Subject: [PATCH 49/57] fix(latex): preserve LaTeX commands like \nu \nabla \neq during \n replacement The aggressive content.replace('\\n', '\n') destroyed any LaTeX command starting with \n (\nu, \nabla, \neq, \neg, etc). Use regex with a negative lookahead (\\n(?![a-zA-Z])) so that only a literal \n not followed by a letter is converted to a real newline. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> Co-authored-by: Claude Opus 4.6 --- .../skills/toolsets/render/render_latex/handler.py | 4 ++-- tests/test_render_latex_tool.py | 13 ++++++++++--- 2 files changed, 12 insertions(+), 5 deletions(-) diff --git a/src/Undefined/skills/toolsets/render/render_latex/handler.py b/src/Undefined/skills/toolsets/render/render_latex/handler.py index 0637b9f..5da83f9 100644 --- a/src/Undefined/skills/toolsets/render/render_latex/handler.py +++ b/src/Undefined/skills/toolsets/render/render_latex/handler.py @@ -42,8 +42,8 @@ def _prepare_content(raw_content: str) -> str: 3. 如果没有数学分隔符,自动用 \\[ ...
\\] 包装 """ content = _strip_document_wrappers(raw_content) - # 替换字面量 \\n 为真实换行符 - content = content.replace("\\n", "\n") + # 替换字面量 \\n 为真实换行符,但保留 LaTeX 命令如 \nu \nabla \neq 等 + content = re.sub(r"\\n(?![a-zA-Z])", "\n", content) if not _has_math_delimiters(content): # 没有分隔符,自动包装为块级数学环境 diff --git a/tests/test_render_latex_tool.py b/tests/test_render_latex_tool.py index d4d3e80..130991a 100644 --- a/tests/test_render_latex_tool.py +++ b/tests/test_render_latex_tool.py @@ -182,10 +182,17 @@ def test_prepare_content() -> None: result_with_delim = _prepare_content(r"\[ E = mc^2 \]") assert result_with_delim == r"\[ E = mc^2 \]" - # 字面量 \\n 处理 - result_newline = _prepare_content(r"x = 1\\ny = 2") + # 字面量 \\n 处理(后面不跟字母时替换为换行) + result_newline = _prepare_content("x = 1\\n2 = y") assert "\n" in result_newline - assert "\\n" not in result_newline.replace(r"\[", "").replace(r"\]", "") + assert "x = 1" in result_newline + assert "2 = y" in result_newline + + # LaTeX 命令不被破坏:\nu \nabla \neq 保持不变 + result_latex = _prepare_content(r"\nu + \nabla \neq 0") + assert r"\nu" in result_latex + assert r"\nabla" in result_latex + assert r"\neq" in result_latex def test_build_html_contains_mathjax_ready_flag() -> None: From 9f1c756e474995e15cead887730b4533a51a179a Mon Sep 17 00:00:00 2001 From: Null <1708213363@qq.com> Date: Sun, 19 Apr 2026 16:57:31 +0800 Subject: [PATCH 50/57] feat(vision): add configurable max_tokens to VisionModelConfig Previously vision model max_tokens was hardcoded (256/512/8192) at call sites. With thinking-enabled models like kimi-k2.5, the small budgets were entirely consumed by the thinking chain, leaving no room for tool-call output. 
- Add max_tokens field to VisionModelConfig (default 8192) - Parse from config.toml [models.vision] and VISION_MODEL_MAX_TOKENS env - Replace all hardcoded max_tokens in multimodal.py with config value - Update config.toml.example with documentation Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- config.toml.example | 3 +++ src/Undefined/ai/multimodal.py | 6 +++--- src/Undefined/config/model_parsers.py | 8 ++++++++ src/Undefined/config/models.py | 1 + 4 files changed, 15 insertions(+), 3 deletions(-) diff --git a/config.toml.example b/config.toml.example index c2a7fed..4448a04 100644 --- a/config.toml.example +++ b/config.toml.example @@ -152,6 +152,9 @@ api_key = "" # zh: Vision 模型名称。 # en: Vision model name. model_name = "" +# zh: Vision 模型最大输出 tokens。启用 thinking 时建议设大(如 8192),确保思维链消耗后仍有余量输出工具调用。 +# en: Vision model max output tokens. When thinking is enabled, use a larger value (e.g. 8192) so there is still room for tool-call output after thinking. +max_tokens = 8192 # zh: 队列发车间隔(秒,0 表示立即发车)。 # en: Queue interval (seconds; 0 dispatches immediately). 
queue_interval_seconds = 1.0 diff --git a/src/Undefined/ai/multimodal.py b/src/Undefined/ai/multimodal.py index f5f1ab2..e4dda36 100644 --- a/src/Undefined/ai/multimodal.py +++ b/src/Undefined/ai/multimodal.py @@ -696,7 +696,7 @@ async def analyze( result = await self._requester.request( model_config=self._vision_config, messages=[{"role": "user", "content": content_items}], - max_tokens=8192, + max_tokens=self._vision_config.max_tokens, call_type=f"vision_{detected_type}", ) content = extract_choices_content(result) @@ -861,7 +861,7 @@ async def judge_meme_image(self, image_url: str | list[str]) -> dict[str, Any]: tool_schema=_MEME_JUDGE_TOOL, tool_name="submit_meme_judgement", call_type="vision_meme_judge", - max_tokens=256, + max_tokens=self._vision_config.max_tokens, ) except Exception as exc: logger.exception("[媒体分析] 表情包判定失败,按非表情包处理: %s", exc) @@ -899,7 +899,7 @@ async def describe_meme_image(self, image_url: str | list[str]) -> dict[str, Any tool_schema=_MEME_DESCRIBE_TOOL, tool_name="submit_meme_description", call_type="vision_meme_describe", - max_tokens=512, + max_tokens=self._vision_config.max_tokens, ) except Exception as exc: logger.exception("[媒体分析] 表情包描述失败: %s", exc) diff --git a/src/Undefined/config/model_parsers.py b/src/Undefined/config/model_parsers.py index e17fc85..13e16c2 100644 --- a/src/Undefined/config/model_parsers.py +++ b/src/Undefined/config/model_parsers.py @@ -382,6 +382,14 @@ def _parse_vision_model_config(data: dict[str, Any]) -> VisionModelConfig: _get_value(data, ("models", "vision", "model_name"), "VISION_MODEL_NAME"), "", ), + max_tokens=_coerce_int( + _get_value( + data, + ("models", "vision", "max_tokens"), + "VISION_MODEL_MAX_TOKENS", + ), + 8192, + ), queue_interval_seconds=queue_interval_seconds, api_mode=api_mode, thinking_enabled=_coerce_bool( diff --git a/src/Undefined/config/models.py b/src/Undefined/config/models.py index 3190b38..118fb9a 100644 --- a/src/Undefined/config/models.py +++ b/src/Undefined/config/models.py 
@@ -92,6 +92,7 @@ class VisionModelConfig: api_url: str api_key: str model_name: str + max_tokens: int = 8192 # 最大输出 tokens queue_interval_seconds: float = 1.0 api_mode: str = "chat_completions" # 请求 API 模式 thinking_enabled: bool = False # 是否启用 thinking From 4a6c57c464dd10f3acda0338c8b372a13d8204f7 Mon Sep 17 00:00:00 2001 From: Null <1708213363@qq.com> Date: Sun, 19 Apr 2026 17:21:42 +0800 Subject: [PATCH 51/57] refactor(historian): use context_recent_messages_limit and XML format The historian now sees the same message count and XML format as the main AI, improving disambiguation quality: - Extract shared format_message_xml() / format_messages_xml() into utils/xml.py; deduplicate prompts.py and fetch_messages handler - Historian recent messages use context_recent_messages_limit (default 20) instead of historian_recent_messages_inject_k (was 12) - Messages formatted as XML (matching main AI) instead of plain text bullet list, including attachments and full metadata - Update historian_rewrite.md prompt to note XML format - Update tests for new format and import paths Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- res/prompts/historian_rewrite.md | 2 +- src/Undefined/ai/prompts.py | 56 +------------ src/Undefined/cognitive/historian.py | 2 +- .../tools/fetch_messages/handler.py | 68 +--------------- src/Undefined/skills/tools/end/handler.py | 61 +++++++------- src/Undefined/utils/xml.py | 81 ++++++++++++++++++- tests/test_end_tool.py | 12 +-- tests/test_fetch_messages_tool.py | 25 +++--- 8 files changed, 132 insertions(+), 175 deletions(-) diff --git a/res/prompts/historian_rewrite.md b/res/prompts/historian_rewrite.md index 72a8a2a..ac706d2 100644 --- a/res/prompts/historian_rewrite.md +++ b/res/prompts/historian_rewrite.md @@ -35,7 +35,7 @@ observations: {observations} 当前消息原文(触发本轮): {source_message} -最近消息参考(用于消歧,不要求逐字复述): +最近消息参考(XML 格式,与主对话一致,消息间以 `---` 分隔,用于消歧,不要求逐字复述): {recent_messages} 必须通过 `submit_rewrite` 工具提交结果,禁止输出普通文本内容。 diff 
--git a/src/Undefined/ai/prompts.py b/src/Undefined/ai/prompts.py index e938181..5a87c8d 100644 --- a/src/Undefined/ai/prompts.py +++ b/src/Undefined/ai/prompts.py @@ -11,7 +11,6 @@ import aiofiles -from Undefined.attachments import attachment_refs_to_xml from Undefined.utils.coerce import safe_int from Undefined.context import RequestContext from Undefined.end_summary_storage import ( @@ -23,7 +22,7 @@ from Undefined.skills.anthropic_skills import AnthropicSkillRegistry from Undefined.utils.logging import log_debug_json from Undefined.utils.resources import read_text_resource -from Undefined.utils.xml import escape_xml_attr, escape_xml_text +from Undefined.utils.xml import format_message_xml logger = logging.getLogger(__name__) @@ -725,58 +724,7 @@ async def _inject_recent_messages( recent_msgs = self._drop_current_message_if_duplicated( recent_msgs, question ) - context_lines: list[str] = [] - for msg in recent_msgs: - msg_type_val = msg.get("type", "group") - sender_name = msg.get("display_name", "未知用户") - sender_id = msg.get("user_id", "") - chat_id = msg.get("chat_id", "") - chat_name = msg.get("chat_name", "未知群聊") - timestamp = msg.get("timestamp", "") - text = msg.get("message", "") - attachments = msg.get("attachments", []) - role = msg.get("role", "member") - title = msg.get("title", "") - level = msg.get("level", "") - message_id = msg.get("message_id") - - safe_sender = escape_xml_attr(sender_name) - safe_sender_id = escape_xml_attr(sender_id) - safe_chat_id = escape_xml_attr(chat_id) - safe_chat_name = escape_xml_attr(chat_name) - safe_role = escape_xml_attr(role) - safe_title = escape_xml_attr(title) - safe_time = escape_xml_attr(timestamp) - safe_text = escape_xml_text(str(text)) - - msg_id_attr = "" - if message_id is not None: - msg_id_attr = f' message_id="{escape_xml_attr(str(message_id))}"' - attachment_xml = ( - f"\n{attachment_refs_to_xml(attachments)}" - if isinstance(attachments, list) and attachments - else "" - ) - - if msg_type_val == 
"group": - location = ( - chat_name if chat_name.endswith("群") else f"{chat_name}群" - ) - safe_location = escape_xml_attr(location) - level_attr = f' level="{escape_xml_attr(level)}"' if level else "" - xml_msg = ( - f' \n ' - ) - else: - location = "私聊" - safe_location = escape_xml_attr(location) - xml_msg = ( - f'{safe_text} {attachment_xml}\n\n ' - ) - context_lines.append(xml_msg) + context_lines: list[str] = [format_message_xml(msg) for msg in recent_msgs] formatted_context = "\n---\n".join(context_lines) diff --git a/src/Undefined/cognitive/historian.py b/src/Undefined/cognitive/historian.py index fde5d24..fda333c 100644 --- a/src/Undefined/cognitive/historian.py +++ b/src/Undefined/cognitive/historian.py @@ -461,7 +461,7 @@ async def _rewrite( recent_messages = [ str(item).strip() for item in recent_messages_raw if str(item).strip() ] - recent_messages_text = "\n".join(f"- {line}" for line in recent_messages) + recent_messages_text = "\n---\n".join(recent_messages) prompt = template.format( request_id=job.get("request_id", ""), end_seq=job.get("end_seq", 0), diff --git a/src/Undefined/skills/agents/summary_agent/tools/fetch_messages/handler.py b/src/Undefined/skills/agents/summary_agent/tools/fetch_messages/handler.py index 4426752..56be9c3 100644 --- a/src/Undefined/skills/agents/summary_agent/tools/fetch_messages/handler.py +++ b/src/Undefined/skills/agents/summary_agent/tools/fetch_messages/handler.py @@ -5,8 +5,7 @@ from datetime import datetime, timedelta from typing import Any -from Undefined.attachments import attachment_refs_to_xml -from Undefined.utils.xml import escape_xml_attr, escape_xml_text +from Undefined.utils.xml import format_messages_xml logger = logging.getLogger(__name__) @@ -58,12 +57,6 @@ def _filter_by_time( return result -def _format_message_location(msg_type_val: str, chat_name: str) -> str: - if msg_type_val == "group": - return chat_name if chat_name.endswith("群") else f"{chat_name}群" - return "私聊" - - def 
_normalize_messages_for_chat( messages: list[dict[str, Any]], *, @@ -83,63 +76,6 @@ def _normalize_messages_for_chat( return normalized -def _format_message_xml(msg: dict[str, Any]) -> str: - msg_type_val = str(msg.get("type", "group") or "group") - sender_name = str(msg.get("display_name", "未知用户") or "未知用户") - sender_id = str(msg.get("user_id", "") or "") - chat_id = str(msg.get("chat_id", "") or "") - chat_name = str(msg.get("chat_name", "未知群聊") or "未知群聊") - timestamp = str(msg.get("timestamp", "") or "") - text = str(msg.get("message", "") or "") - message_id = msg.get("message_id") - role = str(msg.get("role", "member") or "member") - title = str(msg.get("title", "") or "") - level = str(msg.get("level", "") or "") - attachments = msg.get("attachments", []) - - safe_sender = escape_xml_attr(sender_name) - safe_sender_id = escape_xml_attr(sender_id) - safe_chat_id = escape_xml_attr(chat_id) - safe_chat_name = escape_xml_attr(chat_name) - safe_role = escape_xml_attr(role) - safe_title = escape_xml_attr(title) - safe_time = escape_xml_attr(timestamp) - safe_text = escape_xml_text(text) - safe_location = escape_xml_attr(_format_message_location(msg_type_val, chat_name)) - - msg_id_attr = "" - if message_id is not None: - msg_id_attr = f' message_id="{escape_xml_attr(str(message_id))}"' - - attachment_xml = ( - f"\n{attachment_refs_to_xml(attachments)}" - if isinstance(attachments, list) and attachments - else "" - ) - - if msg_type_val == "group": - level_attr = f' level="{escape_xml_attr(level)}"' if level else "" - return ( - f'{safe_text} {attachment_xml}\n\n' - f" " - ) - - return ( - f'{safe_text} {attachment_xml}\n" - f"\n' - f" " - ) - - -def _format_messages(messages: list[dict[str, Any]]) -> str: - """Format messages into main-AI-compatible XML for the summary agent.""" - return "\n---\n".join(_format_message_xml(msg) for msg in messages) - - async def execute(args: dict[str, Any], context: dict[str, Any]) -> str: """拉取当前会话的聊天消息。""" history_manager = 
context.get("history_manager") @@ -189,7 +125,7 @@ async def execute(args: dict[str, Any], context: dict[str, Any]) -> str: messages, chat_type=chat_type, chat_id=chat_id ) - formatted = _format_messages(messages) + formatted = format_messages_xml(messages) total = len(messages) header = f"共获取 {total} 条消息" if time_range_str: diff --git a/src/Undefined/skills/tools/end/handler.py b/src/Undefined/skills/tools/end/handler.py index 8afb076..42a9883 100644 --- a/src/Undefined/skills/tools/end/handler.py +++ b/src/Undefined/skills/tools/end/handler.py @@ -7,6 +7,7 @@ from Undefined.context import RequestContext from Undefined.utils.coerce import safe_int +from Undefined.utils.xml import format_message_xml from Undefined.end_summary_storage import ( EndSummaryLocation, @@ -92,37 +93,38 @@ def _clamp_int(value: int, min_value: int, max_value: int) -> int: return value -def _resolve_historian_limits(context: Dict[str, Any]) -> tuple[int, int, int]: +def _resolve_historian_limits(context: Dict[str, Any]) -> tuple[int, int]: + """Return (max_source_len, recent_k) for historian context injection. + + ``recent_k`` is derived from ``context_recent_messages_limit`` (same as + main AI) so the historian sees the same message window. 
+ """ max_source_len = _DEFAULT_HISTORIAN_TEXT_LEN recent_k = _DEFAULT_HISTORIAN_LINES - max_recent_line_len = _DEFAULT_HISTORIAN_LINE_LEN runtime_config = context.get("runtime_config") + + # 消息数量:优先使用 context_recent_messages_limit(与主 AI 一致) + if runtime_config is not None and hasattr( + runtime_config, "get_context_recent_messages_limit" + ): + try: + recent_k = int(runtime_config.get_context_recent_messages_limit()) + except Exception: + pass + cognitive = getattr(runtime_config, "cognitive", None) if runtime_config else None if cognitive is not None: max_source_len = safe_int( getattr(cognitive, "historian_source_message_max_len", max_source_len), max_source_len, ) - recent_k = safe_int( - getattr(cognitive, "historian_recent_messages_inject_k", recent_k), - recent_k, - ) - max_recent_line_len = safe_int( - getattr( - cognitive, "historian_recent_message_line_max_len", max_recent_line_len - ), - max_recent_line_len, - ) max_source_len = _clamp_int( max_source_len, _MIN_HISTORIAN_TEXT_LEN, _MAX_HISTORIAN_TEXT_LEN ) recent_k = _clamp_int(recent_k, _MIN_HISTORIAN_LINES, _MAX_HISTORIAN_LINES) - max_recent_line_len = _clamp_int( - max_recent_line_len, _MIN_HISTORIAN_LINE_LEN, _MAX_HISTORIAN_LINE_LEN - ) - return max_source_len, recent_k, max_recent_line_len + return max_source_len, recent_k def _extract_current_content_from_question(question: str, *, max_len: int) -> str: @@ -136,8 +138,13 @@ def _extract_current_content_from_question(question: str, *, max_len: int) -> st def _build_historian_recent_messages( - context: Dict[str, Any], *, recent_k: int, max_line_len: int + context: Dict[str, Any], *, recent_k: int ) -> list[str]: + """Build XML-formatted recent messages for historian context. + + Uses the same XML schema as the main AI prompt so the historian LLM + sees identical message structure for better disambiguation. 
+ """ if recent_k <= 0: return [] @@ -173,24 +180,14 @@ def _build_historian_recent_messages( for msg in recent: if not isinstance(msg, dict): continue - timestamp = str(msg.get("timestamp", "")).strip() - display_name = str(msg.get("display_name", "")).strip() - user_id = str(msg.get("user_id", "")).strip() - message_text = _clip_text(msg.get("message", ""), max_line_len) - if not message_text: + if not str(msg.get("message", "") or "").strip(): continue - who = display_name or (f"UID:{user_id}" if user_id else "未知用户") - if user_id: - who = f"{who}({user_id})" - if timestamp: - lines.append(f"[{timestamp}] {who}: {message_text}") - else: - lines.append(f"{who}: {message_text}") + lines.append(format_message_xml(msg)) return lines[-recent_k:] def _inject_historian_reference_context(context: Dict[str, Any]) -> None: - max_source_len, recent_k, max_recent_line_len = _resolve_historian_limits(context) + max_source_len, recent_k = _resolve_historian_limits(context) current_question = str(context.get("current_question") or "").strip() source_message = _extract_current_content_from_question( current_question, max_len=max_source_len @@ -202,9 +199,7 @@ def _inject_historian_reference_context(context: Dict[str, Any]) -> None: current_question, max_source_len ) - recent_lines = _build_historian_recent_messages( - context, recent_k=recent_k, max_line_len=max_recent_line_len - ) + recent_lines = _build_historian_recent_messages(context, recent_k=recent_k) if recent_lines: context["historian_recent_messages"] = recent_lines diff --git a/src/Undefined/utils/xml.py b/src/Undefined/utils/xml.py index 865ac24..16d2415 100644 --- a/src/Undefined/utils/xml.py +++ b/src/Undefined/utils/xml.py @@ -1,7 +1,9 @@ -"""Minimal XML escaping helpers.""" +"""Minimal XML escaping helpers and message formatting.""" from __future__ import annotations +from typing import Any, Callable, Sequence, Mapping + from xml.sax.saxutils import escape @@ -12,3 +14,80 @@ def escape_xml_text(value: str) -> 
str: def escape_xml_attr(value: object) -> str: text = "" if value is None else str(value) return escape(text, {'"': """, "'": "'"}) + + +def _message_location(msg_type: str, chat_name: str) -> str: + """Derive the human-readable location label from message type.""" + if msg_type == "group": + return chat_name if chat_name.endswith("群") else f"{chat_name}群" + return "私聊" + + +def format_message_xml( + msg: dict[str, Any], + *, + attachment_formatter: (Callable[[Sequence[Mapping[str, str]]], str] | None) = None, +) -> str: + """Format a single history record dict into main-AI-compatible XML. + + ``attachment_formatter`` is an optional callable that turns the attachments + list into an XML fragment. When *None* (the default) a lazy import of + :func:`Undefined.attachments.attachment_refs_to_xml` is used so that + lightweight callers do not pay the import cost. + """ + msg_type_val = str(msg.get("type", "group") or "group") + sender_name = str(msg.get("display_name", "未知用户") or "未知用户") + sender_id = str(msg.get("user_id", "") or "") + chat_id = str(msg.get("chat_id", "") or "") + chat_name = str(msg.get("chat_name", "未知群聊") or "未知群聊") + timestamp = str(msg.get("timestamp", "") or "") + text = str(msg.get("message", "") or "") + message_id = msg.get("message_id") + role = str(msg.get("role", "member") or "member") + title = str(msg.get("title", "") or "") + level = str(msg.get("level", "") or "") + attachments = msg.get("attachments", []) + + safe_sender = escape_xml_attr(sender_name) + safe_sender_id = escape_xml_attr(sender_id) + safe_chat_id = escape_xml_attr(chat_id) + safe_chat_name = escape_xml_attr(chat_name) + safe_role = escape_xml_attr(role) + safe_title = escape_xml_attr(title) + safe_time = escape_xml_attr(timestamp) + safe_text = escape_xml_text(text) + safe_location = escape_xml_attr(_message_location(msg_type_val, chat_name)) + + msg_id_attr = "" + if message_id is not None: + msg_id_attr = f' message_id="{escape_xml_attr(str(message_id))}"' + + 
attachment_xml = "" + if isinstance(attachments, list) and attachments: + if attachment_formatter is None: + from Undefined.attachments import attachment_refs_to_xml + + attachment_formatter = attachment_refs_to_xml + attachment_xml = f"\n{attachment_formatter(attachments)}" + + if msg_type_val == "group": + level_attr = f' level="{escape_xml_attr(level)}"' if level else "" + return ( + f'{safe_text} {attachment_xml}\n" - f"\n' + f" " + ) + + return ( + f'{safe_text} {attachment_xml}\n" + f"\n' + f" " + ) + + +def format_messages_xml(messages: list[dict[str, Any]]) -> str: + """Format a list of history records into ``\\n---\\n``-separated XML.""" + return "\n---\n".join(format_message_xml(msg) for msg in messages) diff --git a/tests/test_end_tool.py b/tests/test_end_tool.py index 45e4106..4b65584 100644 --- a/tests/test_end_tool.py +++ b/tests/test_end_tool.py @@ -178,10 +178,9 @@ async def test_end_uses_runtime_config_for_historian_reference_limits() -> None: cognitive_service = _FakeCognitiveService() runtime_config = SimpleNamespace( cognitive=SimpleNamespace( - historian_recent_messages_inject_k=2, - historian_recent_message_line_max_len=60, historian_source_message_max_len=40, - ) + ), + get_context_recent_messages_limit=lambda: 2, ) long_content = "A" * 300 context: dict[str, Any] = { @@ -207,6 +206,7 @@ async def test_end_uses_runtime_config_for_historian_reference_limits() -> None: assert len(source) <= 40 assert isinstance(recent, list) assert len(recent) == 2 - assert all( - len(str(line).split(": ", 1)[1]) <= 60 for line in recent if ": " in str(line) - ) + # Recent messages now use XML format (same as main AI) + for line in recent: + assert "{safe_text} {attachment_xml}\n" + f"" in str(line) diff --git a/tests/test_fetch_messages_tool.py b/tests/test_fetch_messages_tool.py index 1804c72..65fff3d 100644 --- a/tests/test_fetch_messages_tool.py +++ b/tests/test_fetch_messages_tool.py @@ -8,11 +8,10 @@ from 
Undefined.skills.agents.summary_agent.tools.fetch_messages.handler import ( _filter_by_time, - _format_messages, - _format_message_xml, _parse_time_range, execute as fetch_messages_execute, ) +from Undefined.utils.xml import format_message_xml, format_messages_xml # -- _parse_time_range unit tests -- @@ -124,10 +123,10 @@ def test_filter_by_time_invalid_timestamp() -> None: assert len(result) == 0 -# -- _format_messages unit tests -- +# -- format_messages_xml unit tests -- -def test_format_message_xml_group_basic() -> None: +def testformat_message_xml_group_basic() -> None: """Group message is formatted into main-AI-compatible XML.""" messages = [ { @@ -145,7 +144,7 @@ def test_format_message_xml_group_basic() -> None: }, ] - result = _format_message_xml(messages[0]) + result = format_message_xml(messages[0]) assert 'message_id="123"' in result assert 'sender="Alice"' in result assert 'sender_id="10001"' in result @@ -158,7 +157,7 @@ def test_format_message_xml_group_basic() -> None: assert " Hello " in result -def test_format_message_xml_private_basic() -> None: +def testformat_message_xml_private_basic() -> None: """Private message uses the private XML shape.""" msg = { "type": "private", @@ -169,7 +168,7 @@ def test_format_message_xml_private_basic() -> None: "message_id": 456, } - result = _format_message_xml(msg) + result = format_message_xml(msg) assert 'message_id="456"' in result assert 'sender="Bob"' in result assert 'sender_id="10002"' in result @@ -179,7 +178,7 @@ def test_format_message_xml_private_basic() -> None: assert "Hi " in result -def test_format_message_xml_includes_attachments() -> None: +def testformat_message_xml_includes_attachments() -> None: """Attachment refs are rendered as XML below content.""" msg = { "type": "group", @@ -200,14 +199,14 @@ def test_format_message_xml_includes_attachments() -> None: ], } - result = _format_message_xml(msg) + result = format_message_xml(msg) assert "" in result assert 'uid="pic_abcd1234"' in result 
assert 'type="image"' in result assert 'description="截图"' in result -def test_format_messages_multiple() -> None: +def testformat_messages_xml_multiple() -> None: """Multiple messages are separated by main-AI-style delimiters.""" messages = [ { @@ -228,12 +227,12 @@ def test_format_messages_multiple() -> None: }, ] - result = _format_messages(messages) + result = format_messages_xml(messages) assert "\n---\n" in result assert result.count(" None: +def testformat_messages_xml_missing_fields() -> None: """Missing fields still produce valid XML.""" messages = [ { @@ -242,7 +241,7 @@ def test_format_messages_missing_fields() -> None: }, ] - result = _format_messages(messages) + result = format_messages_xml(messages) assert "未知用户" in result assert "No timestamp" in result assert " Date: Sun, 19 Apr 2026 17:44:49 +0800 Subject: [PATCH 52/57] fix(webui): cap top_k overflow, debounce meme search, reorder dashboard MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Cap top_k to 500 in frontend (appendPositiveIntParam) and backend (cognitive service + vector_store _safe_positive_int) to prevent ChromaDB OverflowError when large integers are passed - Add max=500 to all top_k HTML inputs - Cap fetch_k to 10000 in vector_store._query() as safety net - Add debounced auto-search (350ms) for meme text inputs with pending-refresh pattern to avoid stale results - Enter key flushes debounce timer for instant search - Move 运行环境 card before 资源趋势 chart in dashboard layout Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- src/Undefined/cognitive/service.py | 3 ++ src/Undefined/cognitive/vector_store.py | 7 +++- src/Undefined/webui/static/js/memes.js | 47 +++++++++++++++++++----- src/Undefined/webui/static/js/runtime.js | 4 +- src/Undefined/webui/templates/index.html | 24 ++++++------ 5 files changed, 60 insertions(+), 25 deletions(-) diff --git a/src/Undefined/cognitive/service.py b/src/Undefined/cognitive/service.py index 
1febe85..13cf3f1 100644 --- a/src/Undefined/cognitive/service.py +++ b/src/Undefined/cognitive/service.py @@ -725,6 +725,7 @@ async def build_context( top_k = default_top_k if top_k <= 0: top_k = default_top_k + top_k = min(top_k, 500) try: events = await self._query_events_for_auto_context( query=query, @@ -819,6 +820,7 @@ async def search_events(self, query: str, **kwargs: Any) -> list[dict[str, Any]] top_k = default_top_k if top_k <= 0: top_k = default_top_k + top_k = min(top_k, 500) logger.info( "[认知服务] 搜索事件: query_len=%s top_k=%s where=%s time_from=%s time_to=%s", len(query or ""), @@ -872,6 +874,7 @@ async def search_profiles(self, query: str, **kwargs: Any) -> list[dict[str, Any top_k = default_top_k if top_k <= 0: top_k = default_top_k + top_k = min(top_k, 500) where: dict[str, Any] | None = None entity_type_raw = kwargs.get("entity_type") diff --git a/src/Undefined/cognitive/vector_store.py b/src/Undefined/cognitive/vector_store.py index 0c84f5a..48b80eb 100644 --- a/src/Undefined/cognitive/vector_store.py +++ b/src/Undefined/cognitive/vector_store.py @@ -34,13 +34,15 @@ def _clamp(value: float, lower: float, upper: float) -> float: return value -def _safe_positive_int(value: Any, default: int) -> int: +def _safe_positive_int(value: Any, default: int, maximum: int = 0) -> int: try: parsed = int(value) except Exception: return max(1, int(default)) if parsed <= 0: return max(1, int(default)) + if maximum > 0 and parsed > maximum: + return maximum return parsed @@ -427,7 +429,7 @@ async def _query( query_embedding: list[float] | None = None, ) -> list[dict[str, Any]]: col_name = getattr(col, "name", "unknown") - safe_top_k = _safe_positive_int(top_k, default=1) + safe_top_k = _safe_positive_int(top_k, default=1, maximum=500) safe_multiplier = _safe_positive_int(candidate_multiplier, default=1) total_started = time.perf_counter() logger.debug( @@ -455,6 +457,7 @@ async def _query( use_reranker or apply_time_decay or apply_mmr ) fetch_k = safe_top_k * 
safe_multiplier if use_extra_candidates else safe_top_k + fetch_k = min(fetch_k, 10000) include: list[str] = ["documents", "metadatas", "distances"] if apply_mmr: include.append("embeddings") diff --git a/src/Undefined/webui/static/js/memes.js b/src/Undefined/webui/static/js/memes.js index 0afef67..465c3cd 100644 --- a/src/Undefined/webui/static/js/memes.js +++ b/src/Undefined/webui/static/js/memes.js @@ -314,6 +314,27 @@ return payload; } + let _debounceTimer = null; + let _pendingRefresh = false; + + function debouncedFetchList(delayMs = 350) { + if (_debounceTimer !== null) { + clearTimeout(_debounceTimer); + } + _debounceTimer = setTimeout(() => { + _debounceTimer = null; + fetchList().catch(showError); + }, delayMs); + } + + function flushDebouncedFetchList() { + if (_debounceTimer !== null) { + clearTimeout(_debounceTimer); + _debounceTimer = null; + } + fetchList().catch(showError); + } + async function fetchList(options = {}) { const append = !!options.append; if (append) { @@ -321,6 +342,7 @@ return null; } } else if (state.loading) { + _pendingRefresh = true; return null; } @@ -415,6 +437,10 @@ } renderLoadMore(); } + if (!append && _pendingRefresh) { + _pendingRefresh = false; + fetchList().catch(showError); + } } } @@ -558,7 +584,7 @@ state.initialized = true; get("btnMemesRefresh")?.addEventListener("click", refreshAll); get("btnMemesSearch")?.addEventListener("click", () => { - fetchList().catch(showError); + flushDebouncedFetchList(); }); get("btnMemesLoadMore")?.addEventListener("click", () => { fetchList({ append: true }).catch(showError); @@ -579,14 +605,17 @@ get("btnMemesDelete")?.addEventListener("click", () => { deleteSelected().catch(showError); }); - bindEnter("memesSearchInput", () => { - fetchList().catch(showError); - }); - bindEnter("memesKeywordQuery", () => { - fetchList().catch(showError); - }); - bindEnter("memesSemanticQuery", () => { - fetchList().catch(showError); + bindEnter("memesSearchInput", flushDebouncedFetchList); + 
bindEnter("memesKeywordQuery", flushDebouncedFetchList); + bindEnter("memesSemanticQuery", flushDebouncedFetchList); + [ + "memesSearchInput", + "memesKeywordQuery", + "memesSemanticQuery", + ].forEach((id) => { + const el = get(id); + if (el) + el.addEventListener("input", () => debouncedFetchList()); }); [ "memesQueryMode", diff --git a/src/Undefined/webui/static/js/runtime.js b/src/Undefined/webui/static/js/runtime.js index 42655a5..c1e1282 100644 --- a/src/Undefined/webui/static/js/runtime.js +++ b/src/Undefined/webui/static/js/runtime.js @@ -303,12 +303,12 @@ params.set(key, text); } - function appendPositiveIntParam(params, key, value) { + function appendPositiveIntParam(params, key, value, max = 500) { const text = String(value || "").trim(); if (!text) return; const num = Number.parseInt(text, 10); if (!Number.isFinite(num) || num <= 0) return; - params.set(key, String(num)); + params.set(key, String(Math.min(num, max))); } function formatNumeric(value, digits = 4) { diff --git a/src/Undefined/webui/templates/index.html b/src/Undefined/webui/templates/index.html index 214ee63..a306720 100644 --- a/src/Undefined/webui/templates/index.html +++ b/src/Undefined/webui/templates/index.html @@ -299,15 +299,6 @@ 运行概览
---资源趋势- -- CPU - Memory --+ +运行环境@@ -323,6 +314,15 @@运行概览
--+@@ -492,7 +492,7 @@资源趋势+ ++ CPU + Memory ++记忆检索
-- @@ -571,7 +571,7 @@记忆检索
- From c347f6d744e6a200bc4a16befc445dacff86bad0 Mon Sep 17 00:00:00 2001 From: Null <1708213363@qq.com> Date: Sun, 19 Apr 2026 17:56:05 +0800 Subject: [PATCH 53/57] fix(hot_reload,repeat): address 2 Devin review bugs - Add historian_model to hot_reload tracking sets (_QUEUE_INTERVAL_KEYS and _MODEL_NAME_KEYS) so config changes take effect without restart - Fix memory leak in _record_repeat_cooldown: skip recording entirely when cooldown_minutes=0 instead of accumulating never-evicted entries - Add test assertion verifying no cooldown entries when disabled Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- src/Undefined/config/hot_reload.py | 2 ++ src/Undefined/handlers.py | 13 ++++++------- tests/test_handlers_repeat.py | 2 ++ 3 files changed, 10 insertions(+), 7 deletions(-) diff --git a/src/Undefined/config/hot_reload.py b/src/Undefined/config/hot_reload.py index f603524..f0d83c9 100644 --- a/src/Undefined/config/hot_reload.py +++ b/src/Undefined/config/hot_reload.py @@ -43,6 +43,7 @@ "naga_model.queue_interval_seconds", "agent_model.queue_interval_seconds", "summary_model.queue_interval_seconds", + "historian_model.queue_interval_seconds", "grok_model.queue_interval_seconds", "chat_model.pool", "agent_model.pool", @@ -55,6 +56,7 @@ "naga_model.model_name", "agent_model.model_name", "summary_model.model_name", + "historian_model.model_name", "grok_model.model_name", } diff --git a/src/Undefined/handlers.py b/src/Undefined/handlers.py index 34fa9f6..9d5bfcb 100644 --- a/src/Undefined/handlers.py +++ b/src/Undefined/handlers.py @@ -167,17 +167,16 @@ def _is_repeat_on_cooldown(self, group_id: int, text: str) -> bool: def _record_repeat_cooldown(self, group_id: int, text: str) -> None: """记录复读冷却时间戳,同时清理已过期条目防止内存泄漏。""" + cooldown_seconds = self.config.repeat_cooldown_minutes * 60 + if cooldown_seconds <= 0: + return key = self._normalize_repeat_text(text) group_cd = self._repeat_cooldown.setdefault(group_id, {}) now = time.monotonic() - 
cooldown_seconds = self.config.repeat_cooldown_minutes * 60 # 清理已过期条目 - if cooldown_seconds > 0: - expired = [ - k for k, ts in group_cd.items() if (now - ts) >= cooldown_seconds - ] - for k in expired: - del group_cd[k] + expired = [k for k, ts in group_cd.items() if (now - ts) >= cooldown_seconds] + for k in expired: + del group_cd[k] group_cd[key] = now async def _annotate_meme_descriptions( diff --git a/tests/test_handlers_repeat.py b/tests/test_handlers_repeat.py index c29219d..c305b6b 100644 --- a/tests/test_handlers_repeat.py +++ b/tests/test_handlers_repeat.py @@ -417,6 +417,8 @@ async def test_repeat_cooldown_zero_disables() -> None: for uid in [20004, 20005, 20006]: await handler.handle_message(_group_event(sender_id=uid, text="草")) assert handler.sender.send_group_message.call_count == 2 + # cooldown=0 不应写入任何冷却记录(防止内存泄漏) + assert len(handler._repeat_cooldown) == 0 @pytest.mark.asyncio From ce5db0f46a3d70cfa37539be57935746929f0404 Mon Sep 17 00:00:00 2001 From: Null <1708213363@qq.com> Date: Sun, 19 Apr 2026 18:03:46 +0800 Subject: [PATCH 54/57] docs(changelog): update v3.3.2 with all session fixes and improvements Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- CHANGELOG.md | 35 ++++++++++++++++++++++++++++++----- 1 file changed, 30 insertions(+), 5 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 83f2f07..ebe9f48 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,19 +2,44 @@ 围绕核心架构进行了大规模重构与功能增强:Runtime API 拆分为路由子模块、配置系统模块化拆分、新增假@检测机制与 /profile 多输出模式。同步引入复读机制全面升级(可配置阈值与冷却)、消息预处理并行化、WebUI 多项交互功能,以及 arXiv 论文分析 Agent 和安全计算器工具。测试覆盖从约 800 提升至 1438+。 +### 新功能 + - 新增假@检测:群聊中 `@+Bot昵称` 的文本形式也被识别为@消息,自动从群上下文获取昵称(防竞态),`@昵称 /命令` 可正常触发斜杠指令。 - `/profile` 命令支持三种输出模式:`-f` 合并转发(默认,分元数据与内容两条消息)、`-r` 渲染为图片、`-t` 直接文本发送。 - 超级管理员可通过 `/p` 和 `/p g <群号>` 跨目标查看任意用户或群聊的认知侧写。 - 复读系统全面升级:触发阈值可配置(`repeat_threshold`,2–20)、Bot 发言不计入复读链、新增复读冷却机制(`repeat_cooldown_minutes`,默认 60 分钟,?与 ? 
等价)。 -- Runtime API 从 2491 行的单体 `app.py` 拆分为 8 个路由子模块 (`api/routes/`),主文件仅保留薄包装委派层。 -- 配置系统模块化拆分:`config/` 拆为 `loader.py`、`models.py`、`hot_reload.py`,`sync_config_template` 脚本支持报告注释变更路径。 -- 消息预处理流程并行化:使用 `asyncio.gather` 并行执行安全检查、认知检索和假@检测,降低消息处理延迟。 +- Vision 模型 `max_tokens` 可配置(`[models.vision].max_tokens`,默认 8192),解决 thinking 模型消耗全部 token 导致工具调用截断的问题。 - 新增 arXiv 论文深度分析 Agent,提供论文搜索、摘要提取与关键信息分析能力。 - 新增 `calculator` 多功能安全计算器工具。 - 新增消息历史限制全面可配置化(`[history].max_records`)。 -- 新增 `utils/coerce.py`(安全类型强转)与 `utils/fake_at.py`(假@文本检测与解析)公共模块。 -- WebUI 新增功能:Cmd/Ctrl+K 命令面板、骨架屏加载态、日志时间过滤、资源趋势图、TOML 原始视图、配置版本历史与回滚、长期记忆完整 CRUD 管理、Modal 焦点陷阱。 +- 新增 `utils/coerce.py`(安全类型强转)、`utils/fake_at.py`(假@检测与解析)与 `utils/xml.py`(统一 XML 消息格式化)公共模块。 + +### 架构重构 + +- Runtime API 从 2491 行的单体 `app.py` 拆分为 8 个路由子模块 (`api/routes/`),主文件仅保留薄包装委派层。 +- 配置系统模块化拆分:`config/` 拆为 `loader.py`、`models.py`、`hot_reload.py`,`sync_config_template` 脚本支持报告注释变更路径。 +- 消息预处理流程并行化:使用 `asyncio.gather` 并行执行安全检查、认知检索和假@检测,降低消息处理延迟。 +- 认知史官(historian)消息格式从纯文本改为 XML,与主 AI 上下文格式统一;消息数量改用 `context_recent_messages_limit`(默认 20)。 +- 统一 XML 消息格式化逻辑至 `utils/xml.py`,消除 `prompts.py`、`fetch_messages` 和 `end` tool 之间的重复代码。 + +### WebUI + +- Cmd/Ctrl+K 命令面板、骨架屏加载态、日志时间过滤、资源趋势图、TOML 原始视图、配置版本历史与回滚、长期记忆完整 CRUD 管理、Modal 焦点陷阱。 +- 修复 top_k 参数溢出导致 ChromaDB Rust 整数越界崩溃的问题,前端/后端四层防护(HTML max 属性、JS 参数上限、Service 层 clamp、向量存储 fetch_k 硬上限)。 +- 表情包搜索页新增 350ms 防抖自动搜索,Enter 立即刷新,`_pendingRefresh` 模式防止并发请求返回过期结果。 +- 仪表盘布局优化:三张信息卡共占一行,资源趋势图全宽置底。 + +### Bug 修复 + +- 修复 AI 模仿系统关键词自动回复前缀生成伪系统消息的问题,在提示词中明确标注该前缀仅由代码路径使用。 +- 修复 LaTeX 渲染中 `\n` 替换破坏 `\nu`、`\nabla`、`\neq` 等命令的问题,改用负向前瞻正则 `\\n(?![a-zA-Z])`。 +- 修复复读冷却抑制后错误丢弃消息(不再触发自动回复)的问题。 +- 修复 `repeat_cooldown_minutes=0` 时 `_repeat_cooldown` 字典持续积累不被清理的内存泄漏。 +- 修复 `historian_model` 未加入热重载追踪集导致配置热更新后队列间隔和模型名失效的问题。 - 修复队列系统 historian 模型未注册的调度问题。 +- 修复表情包 GIF 多帧分析 `_process_reanalyze_job` 未检查 `gif_analysis_mode` 的问题。 +- 修复表情包 GIF 帧临时文件在可重试 LLM 错误时未被清理的资源泄漏。 +- 修复附件渲染错误时仍追加 `prompt_ref` 导致后续异常的问题。 - 修复 
`/profile` 渲染留白和字体过小问题,使用 WebUI 配色方案并提高截断上限至 5000 字符。 - 测试覆盖大幅补齐(804 → 1438+),ruff + mypy 零错误。 From 0b009db57d79b3d146138f46213631945dda5578 Mon Sep 17 00:00:00 2001 From: Null <1708213363@qq.com> Date: Sun, 19 Apr 2026 18:07:01 +0800 Subject: [PATCH 55/57] docs(changelog): simplify v3.3.2 format to match prior versions Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- CHANGELOG.md | 51 +++++++++++---------------------------------------- 1 file changed, 11 insertions(+), 40 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index ebe9f48..5db78b0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,46 +2,17 @@ 围绕核心架构进行了大规模重构与功能增强:Runtime API 拆分为路由子模块、配置系统模块化拆分、新增假@检测机制与 /profile 多输出模式。同步引入复读机制全面升级(可配置阈值与冷却)、消息预处理并行化、WebUI 多项交互功能,以及 arXiv 论文分析 Agent 和安全计算器工具。测试覆盖从约 800 提升至 1438+。 -### 新功能 - -- 新增假@检测:群聊中 `@+Bot昵称` 的文本形式也被识别为@消息,自动从群上下文获取昵称(防竞态),`@昵称 /命令` 可正常触发斜杠指令。 -- `/profile` 命令支持三种输出模式:`-f` 合并转发(默认,分元数据与内容两条消息)、`-r` 渲染为图片、`-t` 直接文本发送。 -- 超级管理员可通过 `/p ` 和 `/p g <群号>` 跨目标查看任意用户或群聊的认知侧写。 -- 复读系统全面升级:触发阈值可配置(`repeat_threshold`,2–20)、Bot 发言不计入复读链、新增复读冷却机制(`repeat_cooldown_minutes`,默认 60 分钟,?与 ? 
等价)。 -- Vision 模型 `max_tokens` 可配置(`[models.vision].max_tokens`,默认 8192),解决 thinking 模型消耗全部 token 导致工具调用截断的问题。 -- 新增 arXiv 论文深度分析 Agent,提供论文搜索、摘要提取与关键信息分析能力。 -- 新增 `calculator` 多功能安全计算器工具。 -- 新增消息历史限制全面可配置化(`[history].max_records`)。 -- 新增 `utils/coerce.py`(安全类型强转)、`utils/fake_at.py`(假@检测与解析)与 `utils/xml.py`(统一 XML 消息格式化)公共模块。 - -### 架构重构 - -- Runtime API 从 2491 行的单体 `app.py` 拆分为 8 个路由子模块 (`api/routes/`),主文件仅保留薄包装委派层。 -- 配置系统模块化拆分:`config/` 拆为 `loader.py`、`models.py`、`hot_reload.py`,`sync_config_template` 脚本支持报告注释变更路径。 -- 消息预处理流程并行化:使用 `asyncio.gather` 并行执行安全检查、认知检索和假@检测,降低消息处理延迟。 -- 认知史官(historian)消息格式从纯文本改为 XML,与主 AI 上下文格式统一;消息数量改用 `context_recent_messages_limit`(默认 20)。 -- 统一 XML 消息格式化逻辑至 `utils/xml.py`,消除 `prompts.py`、`fetch_messages` 和 `end` tool 之间的重复代码。 - -### WebUI - -- Cmd/Ctrl+K 命令面板、骨架屏加载态、日志时间过滤、资源趋势图、TOML 原始视图、配置版本历史与回滚、长期记忆完整 CRUD 管理、Modal 焦点陷阱。 -- 修复 top_k 参数溢出导致 ChromaDB Rust 整数越界崩溃的问题,前端/后端四层防护(HTML max 属性、JS 参数上限、Service 层 clamp、向量存储 fetch_k 硬上限)。 -- 表情包搜索页新增 350ms 防抖自动搜索,Enter 立即刷新,`_pendingRefresh` 模式防止并发请求返回过期结果。 -- 仪表盘布局优化:三张信息卡共占一行,资源趋势图全宽置底。 - -### Bug 修复 - -- 修复 AI 模仿系统关键词自动回复前缀生成伪系统消息的问题,在提示词中明确标注该前缀仅由代码路径使用。 -- 修复 LaTeX 渲染中 `\n` 替换破坏 `\nu`、`\nabla`、`\neq` 等命令的问题,改用负向前瞻正则 `\\n(?![a-zA-Z])`。 -- 修复复读冷却抑制后错误丢弃消息(不再触发自动回复)的问题。 -- 修复 `repeat_cooldown_minutes=0` 时 `_repeat_cooldown` 字典持续积累不被清理的内存泄漏。 -- 修复 `historian_model` 未加入热重载追踪集导致配置热更新后队列间隔和模型名失效的问题。 -- 修复队列系统 historian 模型未注册的调度问题。 -- 修复表情包 GIF 多帧分析 `_process_reanalyze_job` 未检查 `gif_analysis_mode` 的问题。 -- 修复表情包 GIF 帧临时文件在可重试 LLM 错误时未被清理的资源泄漏。 -- 修复附件渲染错误时仍追加 `prompt_ref` 导致后续异常的问题。 -- 修复 `/profile` 渲染留白和字体过小问题,使用 WebUI 配色方案并提高截断上限至 5000 字符。 -- 测试覆盖大幅补齐(804 → 1438+),ruff + mypy 零错误。 +- 新增假@检测:群聊中 `@+Bot昵称` 的文本形式也被识别为@消息,自动获取昵称(防竞态),`@昵称 /命令` 可正常触发斜杠指令。 +- `/profile` 命令支持三种输出模式:`-f` 合并转发(默认)、`-r` 渲染为图片、`-t` 直接文本发送;超管可通过 `/p ` 和 `/p g <群号>` 跨目标查看。 +- 复读系统全面升级:触发阈值可配置(`repeat_threshold`)、Bot 发言不计入复读链、新增冷却机制(`repeat_cooldown_minutes`,?与 ? 
等价)。 +- Vision 模型 `max_tokens` 可配置(`[models.vision].max_tokens`,默认 8192),解决 thinking 模型 token 耗尽导致工具调用截断。 +- 新增 arXiv 论文深度分析 Agent 与 `calculator` 安全计算器工具;新增消息历史限制可配置化(`[history].max_records`)。 +- Runtime API 拆分为 8 个路由子模块(`api/routes/`);配置系统拆为 `loader.py`、`models.py`、`hot_reload.py`。 +- 消息预处理并行化(`asyncio.gather`),认知史官改用 XML 格式并统一至 `utils/xml.py`。 +- WebUI:Cmd/Ctrl+K 命令面板、骨架屏、日志时间过滤、资源趋势图、TOML 原始视图、配置版本历史与回滚、长期记忆 CRUD、Modal 焦点陷阱。 +- WebUI 修复 top_k 溢出崩溃(四层防护)、表情包搜索防抖、仪表盘布局优化。 +- 修复 AI 模仿系统关键词自动回复前缀、LaTeX `\n` 替换破坏数学命令、复读冷却消息丢弃与内存泄漏、historian 热重载追踪缺失、GIF 多帧分析与临时文件泄漏、附件渲染异常等多项问题。 +- 修复 `/profile` 渲染留白和字体过小问题,使用 WebUI 配色并提高截断上限至 5000 字符。 --- From a2b8b2669e0e9d4bbf92d59e4d0ab82b0e246a3d Mon Sep 17 00:00:00 2001 From: Null <1708213363@qq.com> Date: Sun, 19 Apr 2026 18:14:14 +0800 Subject: [PATCH 56/57] docs(webui): add WebUI usage guide and cross-reference links - Create docs/webui-guide.md covering all 8 tabs, config, shortcuts, FAQ - Add link in README.md documentation navigation section - Add reference in docs/deployment.md startup section - Add reference in docs/management-api.md recommended entry section Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- README.md | 1 + docs/deployment.md | 2 + docs/management-api.md | 2 + docs/webui-guide.md | 181 +++++++++++++++++++++++++++++++++++++++++ 4 files changed, 186 insertions(+) create mode 100644 docs/webui-guide.md diff --git a/README.md b/README.md index 48f231c..0cc94f0 100644 --- a/README.md +++ b/README.md @@ -72,6 +72,7 @@ Undefined 的功能极为丰富,为了让本页面不过于臃肿,我们将各个模块的深入解析与高阶玩法整理成了专题游览图。这里是开启探索的钥匙: - ⚙️ **[安装与部署指南](docs/deployment.md)**:不管你是需要 `pip` 无脑一键安装,还是源码二次开发,这里的排坑指南应有尽有。 +- 🖥️ **[WebUI 使用指南](docs/webui-guide.md)**:管理控制台功能一览——配置编辑、日志查看、认知记忆管理、表情包库、AI 对话与系统监控。 - 🧭 **[Management API 与远程管理](docs/management-api.md)**:WebUI / App 共用的管理接口、认证、配置/日志/Bot 控制与引导探针说明。 - 🛠️ **[配置与热更新说明](docs/configuration.md)**:从模型切换到 MCP 库挂载,全方位掌握 `config.toml` 的高阶配置。 - 😶 **[表情包系统 (Memes)](docs/memes.md)**:查看表情包两阶段判定管线、统一图片 
`uid` 发送机制、检索模式及库存管理说明。 diff --git a/docs/deployment.md b/docs/deployment.md index 335b51d..bc5eb1b 100644 --- a/docs/deployment.md +++ b/docs/deployment.md @@ -113,6 +113,8 @@ uv run Undefined-webui ``` > **重要**:两种方式 **二选一即可**,不要同时运行。若你选择 `Undefined-webui`,请在 WebUI 中管理机器人进程的启停。 +> +> WebUI 功能详见 [WebUI 使用指南](webui-guide.md)。 ### 6. 跨平台与资源路径(重要) diff --git a/docs/management-api.md b/docs/management-api.md index 4daec18..346ae86 100644 --- a/docs/management-api.md +++ b/docs/management-api.md @@ -25,6 +25,8 @@ uv run Undefined-webui 4. 直接在控制台启动 Bot 5. 如需远程管理,再让桌面端或 Android App 连接同一个 Management API 地址 +> WebUI 各页面的功能和操作详见 [WebUI 使用指南](webui-guide.md)。 + ## 2. 鉴权模型 Management API 兼容两套鉴权: diff --git a/docs/webui-guide.md b/docs/webui-guide.md new file mode 100644 index 0000000..9a77ff1 --- /dev/null +++ b/docs/webui-guide.md @@ -0,0 +1,181 @@ +# WebUI 使用指南 + +WebUI 是 Undefined 的主要管理入口,提供配置编辑、日志查看、认知记忆管理、表情包库、AI 对话和系统监控等一站式功能。即使 `config.toml` 尚未创建,也可以通过 WebUI 补全配置并启动 Bot。 + +--- + +## 快速开始 + +### 启动 + +```bash +uv run Undefined-webui +``` + +默认监听 `http://127.0.0.1:8787`。相关配置位于 `config.toml` 的 `[webui]` 节: + +```toml +[webui] +url = "127.0.0.1" # 监听地址 +port = 8787 # 端口 +password = "changeme" # 密码(必须在首次登录时修改) +``` + +> 如需远程访问,将 `url` 改为 `0.0.0.0` 或实际 IP。 + +### 首次登录 + +1. 浏览器打开 `http://127.0.0.1:8787` +2. 输入初始密码 `changeme` +3. 
系统强制要求修改默认密码后才能进入管理界面 + +修改后的密码会写入 `config.toml`。桌面端 / Android 客户端也使用同一密码连接。 + +--- + +## 功能概览 + +WebUI 共有 8 个主要页签(Tab),下面逐一介绍。 + +### 概览(Overview) + +仪表盘页面,展示系统运行状态: + +| 指标 | 说明 | +|------|------| +| Bot 状态 | 运行中 / 已停止 | +| 系统运行时间 | 自启动以来的持续时间 | +| CPU 使用率 | 实时百分比 | +| 内存占用 | 已用 / 总量 / 百分比 | +| 运行环境 | CPU 型号、操作系统、Python 版本 | +| 资源趋势图 | CPU / 内存随时间的变化曲线 | + +页面会自动刷新,也可手动触发。 + +### 配置管理(Config) + +提供两种编辑方式: + +- **表单模式**:按分组展示所有配置项,带字段说明和类型校验,适合日常修改。 +- **TOML 原始编辑器**:直接编辑 `config.toml` 文本,带语法高亮,适合批量变更。 + +其他能力: + +- **验证**:保存前自动进行 TOML 语法检查和严格配置校验,错误项会标红提示。 +- **配置历史**:每次保存自动生成带时间戳的备份(最多 50 版本),可随时回滚到任意历史版本。 +- **模板同步**:一键将 `config.toml.example` 中新增的配置项合并到当前配置,不覆盖已有值。 + +> 配置支持热更新——大多数配置项修改后即时生效,无需重启 Bot。需要重启的项(如 `onebot_ws_url`、`webui_port`)会在保存时提示。 + +### 日志查看(Logs) + +- 支持 **Bot 日志** 和 **WebUI 日志** 切换。 +- **实时流式推送**(SSE):日志实时滚动到最新。 +- 可暂停 / 恢复流式推送,调整显示行数(1–2000)。 +- 左侧文件列表展示所有日志文件(含归档),可选择查看历史日志。 +- 支持下载日志文件。 + +### 探针与诊断(Probes) + +三类探针帮助排查问题: + +- **内部探针**:版本号、Python 版本、平台、运行时间、OneBot 连接状态及 WebSocket 地址。 +- **外部探针**:Runtime API 的可用端点和能力列表。 +- **引导探针**:检查 `config.toml` 是否存在、TOML 语法是否合法、配置值是否有效,并给出修复建议。 + +### 认知记忆(Memory) + +分为三个子面板,对应 [认知记忆系统](cognitive-memory.md) 的不同层次: + +**认知事件搜索** + +输入关键词进行语义搜索,查看 AI 提取并存储的用户 / 群聊事实记录。支持调整返回条数和排序方式。 + +**认知侧写查看** + +- 按关键词搜索侧写。 +- 按 QQ 号或群号精确查看对应的完整侧写内容。 + +**长期记忆管理** + +AI 的置顶备忘录(自我约束、待办事项等),支持完整 CRUD: + +- 新建记忆条目 +- 编辑现有内容 +- 删除条目 +- 关键词搜索 + +### 表情包库(Memes) + +管理 [全局表情包库](memes.md) 的 Web 界面: + +- **浏览与搜索**:分页列表展示,支持关键词搜索和语义搜索(以及混合模式)。输入时自动防抖搜索,Enter 立即触发。 +- **筛选与排序**:按启用 / 禁用、静态 / 动态、置顶等条件筛选,按创建或更新时间排序。 +- **详情与操作**:查看元数据(UID、描述、标签)和预览图;支持编辑描述 / 标签、启用 / 禁用、置顶 / 取消、删除。 +- **重分析 / 重索引**:对单张表情包重新触发 AI 描述生成或搜索索引更新。 +- **统计概览**:总数、启用 / 禁用数、静态 / 动态数等。 + +### AI 对话(Chat) + +WebUI 内置的对话界面,直接与 Bot 的 AI 进行交互: + +- 支持文本和图片消息。 +- AI 回复支持 Markdown 渲染。 +- 消息历史分页浏览。 +- 发出的消息会经过与 QQ 侧相同的处理流程(安全检查、工具调用等)。 + +### 关于(About) + +显示当前版本号和 MIT 许可证文本。 + +--- + +## Bot 控制 + +WebUI 首页(Landing Page)提供 Bot 的启停控制: + +- **启动 Bot**:点击启动按钮,Bot 
进程在后台运行。 +- **停止 Bot**:安全停止当前 Bot 进程。 +- **状态指示**:实时显示 Bot 运行状态。 + +首页还会检测是否有可用更新(基于 Git),并提供更新 + 重启功能。 + +--- + +## 键盘快捷键 + +| 快捷键 | 功能 | +|--------|------| +| `Cmd/Ctrl + K` | 打开命令面板,可快速跳转到任意页签或执行操作 | + +--- + +## 远程访问 + +WebUI 和桌面端 / Android 客户端共享同一 Management API: + +1. 将 `[webui].url` 设为 `0.0.0.0`(或你的 LAN/公网 IP)。 +2. 确保防火墙放行 `[webui].port`(默认 8787)。 +3. 桌面端 / Android 客户端输入 `http:// :8787` 和密码即可连接。 + +如果启用了 Runtime API(`[api].enabled = true`),WebUI 会自动代理 Runtime API 的功能(探针、记忆查询、AI Chat 等),无需单独暴露 Runtime API 端口。 + +--- + +## 常见问题 + +**Q: 忘记密码怎么办?** + +直接编辑 `config.toml` 中的 `[webui].password` 字段,重启 WebUI 即可。 + +**Q: 配置保存后 Bot 没反应?** + +大多数配置项支持热更新。但少数关键配置(如 WebSocket 地址、WebUI 端口、API 端口)需要重启才能生效,保存时会有提示。 + +**Q: 日志不滚动了?** + +检查是否意外暂停了日志流。点击日志页面的播放按钮恢复实时推送。 + +**Q: 探针显示 Runtime API 不可达?** + +确认 `[api].enabled = true` 且 Bot 正在运行。Runtime API 由 Bot 主进程提供,Bot 未启动时自然不可达。 From 06a831149448d3fb7fbb6c1163bec4cd8acef17c Mon Sep 17 00:00:00 2001 From: Null <1708213363@qq.com> Date: Sun, 19 Apr 2026 21:07:17 +0800 Subject: [PATCH 57/57] fix(calculator): cap combinatorial function args to prevent CPU exhaustion Add _MAX_COMBINATORIAL_ARG=1000 limit for factorial/perm/comb to prevent adversarial inputs like factorial(99999) from consuming excessive CPU. 
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .../skills/tools/calculator/handler.py | 10 ++++++++++ tests/test_calculator.py | 17 +++++++++++++++++ 2 files changed, 27 insertions(+) diff --git a/src/Undefined/skills/tools/calculator/handler.py b/src/Undefined/skills/tools/calculator/handler.py index cd26438..fd73024 100644 --- a/src/Undefined/skills/tools/calculator/handler.py +++ b/src/Undefined/skills/tools/calculator/handler.py @@ -134,6 +134,9 @@ def _stat_fn(fn_name: str, args: list[float | int]) -> float | int: _MAX_POWER = 10000 _MAX_EXPRESSION_LENGTH = 500 +_MAX_COMBINATORIAL_ARG = 1000 + +_COMBINATORIAL_FUNCS = frozenset({"factorial", "perm", "comb"}) class _SafeEvaluator(ast.NodeVisitor): @@ -198,6 +201,13 @@ def visit_Call(self, node: ast.Call) -> Any: args = [self.visit(arg) for arg in node.args] + if fn_name in _COMBINATORIAL_FUNCS: + for a in args: + if isinstance(a, (int, float)) and abs(a) > _MAX_COMBINATORIAL_ARG: + raise ValueError( + f"{fn_name}() 参数过大: {a}(上限 {_MAX_COMBINATORIAL_ARG})" + ) + if fn_name in _STAT_FUNCS: return _stat_fn(fn_name, args) diff --git a/tests/test_calculator.py b/tests/test_calculator.py index 118759a..d0e6d2f 100644 --- a/tests/test_calculator.py +++ b/tests/test_calculator.py @@ -150,6 +150,23 @@ def test_comb(self) -> None: def test_perm(self) -> None: assert safe_eval("perm(5, 3)") == "60" + def test_factorial_too_large(self) -> None: + with pytest.raises(ValueError, match="参数过大"): + safe_eval("factorial(9999)") + + def test_comb_too_large(self) -> None: + with pytest.raises(ValueError, match="参数过大"): + safe_eval("comb(9999, 5000)") + + def test_perm_too_large(self) -> None: + with pytest.raises(ValueError, match="参数过大"): + safe_eval("perm(9999, 5000)") + + def test_factorial_at_limit(self) -> None: + """factorial(1000) should succeed (within limit).""" + result = safe_eval("factorial(1000)") + assert int(result) > 0 + def test_hypot(self) -> None: assert safe_eval("hypot(3, 4)") == "5"