From 0a9eca80b4636cf779d72d4c1ef20db4d054f8df Mon Sep 17 00:00:00 2001 From: Null <1708213363@qq.com> Date: Fri, 17 Apr 2026 23:52:28 +0800 Subject: [PATCH 01/57] feat(easter_egg): add repeat and inverted question mark features MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Add repeat_enabled and inverted_question_enabled config fields under [easter_egg] - Implement group chat repeat: auto-repeat when 3 consecutive identical messages from different senders - Inverted question mark: send ¿ instead of ? when repeat triggers on question-mark-only messages - Race-condition protection via per-group asyncio.Lock - Inject easter egg status into AI prompt context (model config info + system behavior) - Update config.toml.example and docs/configuration.md - Add tests for config loading (5), handler logic (12), and prompt injection (10) --- config.toml.example | 8 +- docs/configuration.md | 2 + src/Undefined/ai/prompts.py | 63 +++++- src/Undefined/config/loader.py | 20 ++ src/Undefined/handlers.py | 47 +++++ tests/test_config_easter_egg_repeat.py | 62 ++++++ tests/test_handlers_repeat.py | 266 ++++++++++++++++++++++++ tests/test_prompt_builder_easter_egg.py | 188 +++++++++++++++++ 8 files changed, 654 insertions(+), 2 deletions(-) create mode 100644 tests/test_config_easter_egg_repeat.py create mode 100644 tests/test_handlers_repeat.py create mode 100644 tests/test_prompt_builder_easter_egg.py diff --git a/config.toml.example b/config.toml.example index 84756ee..a68a724 100644 --- a/config.toml.example +++ b/config.toml.example @@ -656,9 +656,15 @@ pool_enabled = false # zh: 彩蛋提示发送模式。模式:"none"(关闭)/"agent"(主 AI 调用 Agent 时发送)/"tools"(主 AI 或 Agent 调用 Tool 时发送)/"clean"(过滤噪声;对自动预取的工具如 "get_current_time"、"send_message"、"end" 不予提示)/"all"(包括 Agent 内部调用其子工具即 "agent_tool" 的场景也发送)。默认:"none"。 # en: Easter-egg announcement mode. 
Modes: "none" (off) / "agent" (send when the main AI calls an Agent) / "tools" (send when the main AI or an Agent calls a Tool) / "clean" (filter noise; automatically prefetched tools such as "get_current_time", "send_message", and "end" are not announced) / "all" (also send when an Agent internally calls its sub-tools, i.e. "agent_tool"). Default: "none". agent_call_message_enabled = "none" -# zh: 是否启用群聊关键词(“心理委员”)自动回复。 +# zh: 是否启用群聊关键词("心理委员")自动回复。 # en: Enable keyword auto-replies("心理委员") in group chats. keyword_reply_enabled = false +# zh: 是否启用群聊复读功能(连续3条相同消息时复读)。 +# en: Enable repeat feature in group chats (repeat when 3 consecutive identical messages). +repeat_enabled = false +# zh: 是否启用倒问号(复读触发时,若消息为问号则发送倒问号 ¿)。 +# en: Enable inverted question mark (when repeat triggers on "?" messages, send "¿" instead). +inverted_question_enabled = false # zh: 历史记录配置。 # en: History settings. diff --git a/docs/configuration.md b/docs/configuration.md index fc739c2..49c9465 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -426,6 +426,8 @@ Prompt caching 补充: |---|---:|---|---| | `agent_call_message_enabled` | `"none"` | 调用提示模式 | `none` / `agent` / `tools` / `all` / `clean` | | `keyword_reply_enabled` | `false` | 群聊关键词自动回复 | 布尔 | +| `repeat_enabled` | `false` | 群聊复读(连续3条相同消息时复读) | 布尔 | +| `inverted_question_enabled` | `false` | 倒问号(复读触发时若消息为问号则发送¿) | 布尔 | 兼容:历史字段 `[core].keyword_reply_enabled` 仍可读取,建议迁移到 `[easter_egg]`。 diff --git a/src/Undefined/ai/prompts.py b/src/Undefined/ai/prompts.py index 62ae955..3cd9b70 100644 --- a/src/Undefined/ai/prompts.py +++ b/src/Undefined/ai/prompts.py @@ -204,6 +204,42 @@ def _build_model_config_info(self, runtime_config: Any) -> str: else: parts.append("- 思维链: 未启用") + # 彩蛋功能状态 + keyword_reply_enabled = bool( + getattr(runtime_config, "keyword_reply_enabled", False) + ) + repeat_enabled = bool(getattr(runtime_config, "repeat_enabled", False)) + inverted_question_enabled = bool( + getattr(runtime_config, 
"inverted_question_enabled", False) + ) + agent_call_mode = str( + getattr(runtime_config, "easter_egg_agent_call_message_mode", "none") + ) + easter_egg_parts: list[str] = [] + if keyword_reply_enabled: + easter_egg_parts.append( + '关键词自动回复(触发词"心理委员"等,系统自动发送固定回复)' + ) + if repeat_enabled: + desc = "复读(群聊连续3条相同消息时自动复读)" + if inverted_question_enabled: + desc += ",倒问号(复读触发时若消息为问号则发送¿)" + easter_egg_parts.append(desc) + elif inverted_question_enabled: + easter_egg_parts.append("倒问号(复读未启用,此功能不生效)") + if agent_call_mode != "none": + mode_desc = { + "agent": "Agent调用提示", + "tools": "工具调用提示", + "clean": "降噪调用提示", + "all": "全量调用提示", + }.get(agent_call_mode, agent_call_mode) + easter_egg_parts.append(f"调用提示模式={mode_desc}") + if easter_egg_parts: + parts.append("- 彩蛋功能: " + ";".join(easter_egg_parts)) + else: + parts.append("- 彩蛋功能: 未启用") + parts.append("") parts.append( "重要:以上是你的模型配置信息。\n" @@ -304,14 +340,20 @@ async def build_messages( is_group_context = True keyword_reply_enabled = False + repeat_enabled = False + inverted_question_enabled = False if self._runtime_config_getter is not None: try: runtime_config = self._runtime_config_getter() keyword_reply_enabled = bool( getattr(runtime_config, "keyword_reply_enabled", False) ) + repeat_enabled = bool(getattr(runtime_config, "repeat_enabled", False)) + inverted_question_enabled = bool( + getattr(runtime_config, "inverted_question_enabled", False) + ) except Exception as exc: - logger.debug("读取关键词自动回复配置失败: %s", exc) + logger.debug("读取彩蛋功能配置失败: %s", exc) if is_group_context and keyword_reply_enabled: messages.append( @@ -329,6 +371,25 @@ async def build_messages( } ) + if is_group_context and repeat_enabled: + repeat_desc = ( + "【系统行为说明】\n" + "当前群聊已开启复读彩蛋:当群聊中连续出现3条内容相同且来自不同人的消息时," + "系统会自动复读一条相同的消息,并在历史中写入" + '以"[系统复读] "开头的消息。' + ) + if inverted_question_enabled: + repeat_desc += ( + "\n此外,若复读触发时消息内容仅由问号组成(如?或???)," + "系统会发送对应数量的倒问号(¿)代替。" + ) + repeat_desc += ( + "\n\n这类消息属于系统预设机制,不代表你在该轮主动决策。" + 
"阅读历史时请识别该前缀,避免误判为人格漂移或上下文异常。" + "除非用户主动询问,否则不要主动解释此机制。" + ) + messages.append({"role": "system", "content": repeat_desc}) + # 注入 Anthropic Skills 元数据(Level 1: 始终加载 name + description) if ( self._anthropic_skill_registry diff --git a/src/Undefined/config/loader.py b/src/Undefined/config/loader.py index ff8486a..f7d7f52 100644 --- a/src/Undefined/config/loader.py +++ b/src/Undefined/config/loader.py @@ -468,6 +468,8 @@ class Config: process_private_message: bool process_poke_message: bool keyword_reply_enabled: bool + repeat_enabled: bool + inverted_question_enabled: bool context_recent_messages_limit: int ai_request_max_retries: int nagaagent_mode_enabled: bool @@ -705,6 +707,22 @@ def load(cls, config_path: Optional[Path] = None, strict: bool = True) -> "Confi None, ) keyword_reply_enabled = _coerce_bool(keyword_reply_raw, False) + repeat_enabled = _coerce_bool( + _get_value( + data, + ("easter_egg", "repeat_enabled"), + "EASTER_EGG_REPEAT_ENABLED", + ), + False, + ) + inverted_question_enabled = _coerce_bool( + _get_value( + data, + ("easter_egg", "inverted_question_enabled"), + "EASTER_EGG_INVERTED_QUESTION_ENABLED", + ), + False, + ) context_recent_messages_limit = _coerce_int( _get_value( data, @@ -1377,6 +1395,8 @@ def load(cls, config_path: Optional[Path] = None, strict: bool = True) -> "Confi process_private_message=process_private_message, process_poke_message=process_poke_message, keyword_reply_enabled=keyword_reply_enabled, + repeat_enabled=repeat_enabled, + inverted_question_enabled=inverted_question_enabled, context_recent_messages_limit=context_recent_messages_limit, ai_request_max_retries=ai_request_max_retries, nagaagent_mode_enabled=nagaagent_mode_enabled, diff --git a/src/Undefined/handlers.py b/src/Undefined/handlers.py index 849b8cf..651da57 100644 --- a/src/Undefined/handlers.py +++ b/src/Undefined/handlers.py @@ -43,6 +43,7 @@ logger = logging.getLogger(__name__) KEYWORD_REPLY_HISTORY_PREFIX = "[系统关键词自动回复] " +REPEAT_REPLY_HISTORY_PREFIX = 
"[系统复读] " def _safe_int(value: Any) -> int | None: @@ -128,9 +129,21 @@ def __init__( self._background_tasks: set[asyncio.Task[None]] = set() self._profile_name_refresh_cache: dict[tuple[str, int], str] = {} + # 复读功能状态(按群跟踪最近消息文本与发送者) + self._repeat_counter: dict[int, list[tuple[str, int]]] = {} + self._repeat_locks: dict[int, asyncio.Lock] = {} + # 启动队列 self.ai_coordinator.queue_manager.start(self.ai_coordinator.execute_reply) + def _get_repeat_lock(self, group_id: int) -> asyncio.Lock: + """获取或创建指定群的复读竞态保护锁。""" + lock = self._repeat_locks.get(group_id) + if lock is None: + lock = asyncio.Lock() + self._repeat_locks[group_id] = lock + return lock + async def _collect_message_attachments( self, message_content: list[dict[str, Any]], @@ -684,6 +697,40 @@ async def handle_message(self, event: dict[str, Any]) -> None: ) return + # 复读功能:连续3条相同消息(来自不同发送者)时复读 + if self.config.repeat_enabled: + async with self._get_repeat_lock(group_id): + counter = self._repeat_counter.setdefault(group_id, []) + counter.append((text, sender_id)) + # 只保留最近5条 + if len(counter) > 5: + self._repeat_counter[group_id] = counter[-5:] + counter = self._repeat_counter[group_id] + + if len(counter) >= 3: + last3 = counter[-3:] + texts = [t for t, _ in last3] + senders = [s for _, s in last3] + if len(set(texts)) == 1 and len(set(senders)) == 3: + reply_text = texts[0] + if self.config.inverted_question_enabled: + stripped = reply_text.strip() + if set(stripped) <= {"?", "?"}: + reply_text = "¿" * len(stripped) + # 清空计数器防止重复触发 + self._repeat_counter[group_id] = [] + logger.info( + "[复读] 触发复读: group=%s text=%s", + group_id, + redact_string(reply_text)[:50], + ) + await self.sender.send_group_message( + group_id, + reply_text, + history_prefix=REPEAT_REPLY_HISTORY_PREFIX, + ) + return + # Bilibili 视频自动提取 if self.config.bilibili_auto_extract_enabled: if self.config.is_bilibili_auto_extract_allowed_group(group_id): diff --git a/tests/test_config_easter_egg_repeat.py 
b/tests/test_config_easter_egg_repeat.py new file mode 100644 index 0000000..89872e5 --- /dev/null +++ b/tests/test_config_easter_egg_repeat.py @@ -0,0 +1,62 @@ +"""Config 加载:[easter_egg] repeat_enabled / inverted_question_enabled""" + +from __future__ import annotations + +from pathlib import Path + +from Undefined.config.loader import Config + + +def _load(tmp_path: Path, text: str) -> Config: + p = tmp_path / "config.toml" + p.write_text(text, "utf-8") + return Config.load(p, strict=False) + + +_MINIMAL = """ +[onebot] +ws_url = "ws://127.0.0.1:3001" +[models.chat] +api_url = "https://api.example/v1" +api_key = "sk-test" +model_name = "gpt-test" +""" + + +def test_repeat_defaults_to_false(tmp_path: Path) -> None: + cfg = _load(tmp_path, _MINIMAL) + assert cfg.repeat_enabled is False + assert cfg.inverted_question_enabled is False + + +def test_repeat_enabled_explicit(tmp_path: Path) -> None: + cfg = _load(tmp_path, _MINIMAL + "\n[easter_egg]\nrepeat_enabled = true\n") + assert cfg.repeat_enabled is True + assert cfg.inverted_question_enabled is False + + +def test_inverted_question_enabled_explicit(tmp_path: Path) -> None: + cfg = _load( + tmp_path, + _MINIMAL + + "\n[easter_egg]\nrepeat_enabled = true\ninverted_question_enabled = true\n", + ) + assert cfg.repeat_enabled is True + assert cfg.inverted_question_enabled is True + + +def test_inverted_question_without_repeat(tmp_path: Path) -> None: + cfg = _load( + tmp_path, + _MINIMAL + "\n[easter_egg]\ninverted_question_enabled = true\n", + ) + assert cfg.repeat_enabled is False + assert cfg.inverted_question_enabled is True + + +def test_keyword_reply_still_parsed_from_easter_egg(tmp_path: Path) -> None: + cfg = _load( + tmp_path, + _MINIMAL + "\n[easter_egg]\nkeyword_reply_enabled = true\n", + ) + assert cfg.keyword_reply_enabled is True diff --git a/tests/test_handlers_repeat.py b/tests/test_handlers_repeat.py new file mode 100644 index 0000000..3adf332 --- /dev/null +++ b/tests/test_handlers_repeat.py @@ -0,0 
+1,266 @@ +"""MessageHandler 复读功能测试""" + +from __future__ import annotations + +from types import SimpleNamespace +from typing import Any +from unittest.mock import AsyncMock + +import pytest + +from Undefined.handlers import ( + MessageHandler, + REPEAT_REPLY_HISTORY_PREFIX, +) + + +def _build_handler( + *, + repeat_enabled: bool = False, + inverted_question_enabled: bool = False, + keyword_reply_enabled: bool = False, +) -> Any: + handler: Any = MessageHandler.__new__(MessageHandler) + handler.config = SimpleNamespace( + bot_qq=10000, + repeat_enabled=repeat_enabled, + inverted_question_enabled=inverted_question_enabled, + keyword_reply_enabled=keyword_reply_enabled, + bilibili_auto_extract_enabled=False, + arxiv_auto_extract_enabled=False, + should_process_group_message=lambda is_at_bot=False: True, + should_process_private_message=lambda: True, + is_group_allowed=lambda _gid: True, + is_private_allowed=lambda _uid: True, + access_control_enabled=lambda: False, + process_every_message=True, + ) + handler.history_manager = SimpleNamespace( + add_group_message=AsyncMock(), + add_private_message=AsyncMock(), + ) + handler.sender = SimpleNamespace( + send_group_message=AsyncMock(), + send_private_message=AsyncMock(), + ) + handler.ai_coordinator = SimpleNamespace( + handle_auto_reply=AsyncMock(), + handle_private_reply=AsyncMock(), + _is_at_bot=lambda _mc: False, + ) + handler.ai = SimpleNamespace( + _cognitive_service=None, + memory_storage=None, + model_pool=SimpleNamespace( + handle_private_message=AsyncMock(return_value=False) + ), + ) + handler.onebot = SimpleNamespace( + get_group_info=AsyncMock(return_value={"group_name": "测试群"}), + get_stranger_info=AsyncMock(return_value={"nickname": "用户"}), + get_msg=AsyncMock(return_value=None), + get_forward_msg=AsyncMock(return_value=None), + ) + handler.command_dispatcher = SimpleNamespace( + parse_command=lambda _t: None, + ) + handler._background_tasks = set() + handler._repeat_counter = {} + handler._repeat_locks = 
{} + handler._profile_name_refresh_cache = {} + return handler + + +def _group_event( + group_id: int = 30001, + sender_id: int = 20001, + text: str = "hello", +) -> dict[str, Any]: + return { + "post_type": "message", + "message_type": "group", + "group_id": group_id, + "user_id": sender_id, + "message_id": 1, + "sender": { + "user_id": sender_id, + "card": f"用户{sender_id}", + "nickname": f"昵称{sender_id}", + "role": "member", + "title": "", + }, + "message": [{"type": "text", "data": {"text": text}}], + } + + +# ── 基础:复读未启用时不触发 ── + + +@pytest.mark.asyncio +async def test_repeat_disabled_does_not_repeat() -> None: + handler = _build_handler(repeat_enabled=False) + for uid in [20001, 20002, 20003]: + await handler.handle_message(_group_event(sender_id=uid, text="hello")) + + handler.sender.send_group_message.assert_not_called() + + +# ── 复读触发:3条相同消息来自不同人 ── + + +@pytest.mark.asyncio +async def test_repeat_triggers_on_3_identical_from_different_senders() -> None: + handler = _build_handler(repeat_enabled=True) + for uid in [20001, 20002, 20003]: + await handler.handle_message(_group_event(sender_id=uid, text="hello")) + + handler.sender.send_group_message.assert_called_once() + call = handler.sender.send_group_message.call_args + assert call.args[0] == 30001 + assert call.args[1] == "hello" + assert call.kwargs.get("history_prefix") == REPEAT_REPLY_HISTORY_PREFIX + + +# ── 不触发:3条相同消息来自同一人 ── + + +@pytest.mark.asyncio +async def test_repeat_does_not_trigger_from_same_sender() -> None: + handler = _build_handler(repeat_enabled=True) + for _ in range(3): + await handler.handle_message(_group_event(sender_id=20001, text="hello")) + + handler.sender.send_group_message.assert_not_called() + + +# ── 不触发:消息内容不同 ── + + +@pytest.mark.asyncio +async def test_repeat_does_not_trigger_for_different_texts() -> None: + handler = _build_handler(repeat_enabled=True) + for uid, text in [(20001, "hello"), (20002, "world"), (20003, "hello")]: + await 
handler.handle_message(_group_event(sender_id=uid, text=text)) + + handler.sender.send_group_message.assert_not_called() + + +# ── 防重复:触发后计数器清空 ── + + +@pytest.mark.asyncio +async def test_repeat_clears_counter_after_trigger() -> None: + handler = _build_handler(repeat_enabled=True) + # 第一轮:3条相同触发复读 + for uid in [20001, 20002, 20003]: + await handler.handle_message(_group_event(sender_id=uid, text="hello")) + + assert handler.sender.send_group_message.call_count == 1 + + # 第二轮:再来3条相同应再次触发 + for uid in [20004, 20005, 20006]: + await handler.handle_message(_group_event(sender_id=uid, text="hello")) + + assert handler.sender.send_group_message.call_count == 2 + + +# ── 倒问号:问号消息触发倒问号 ── + + +@pytest.mark.asyncio +async def test_inverted_question_sends_inverted_mark() -> None: + handler = _build_handler(repeat_enabled=True, inverted_question_enabled=True) + for uid in [20001, 20002, 20003]: + await handler.handle_message(_group_event(sender_id=uid, text="?")) + + handler.sender.send_group_message.assert_called_once() + call = handler.sender.send_group_message.call_args + assert call.args[1] == "¿" + + +@pytest.mark.asyncio +async def test_inverted_question_multiple_marks() -> None: + handler = _build_handler(repeat_enabled=True, inverted_question_enabled=True) + for uid in [20001, 20002, 20003]: + await handler.handle_message(_group_event(sender_id=uid, text="???")) + + handler.sender.send_group_message.assert_called_once() + call = handler.sender.send_group_message.call_args + assert call.args[1] == "¿¿¿" + + +@pytest.mark.asyncio +async def test_inverted_question_chinese_question_mark() -> None: + handler = _build_handler(repeat_enabled=True, inverted_question_enabled=True) + for uid in [20001, 20002, 20003]: + await handler.handle_message(_group_event(sender_id=uid, text="?")) + + handler.sender.send_group_message.assert_called_once() + call = handler.sender.send_group_message.call_args + assert call.args[1] == "¿" + + +@pytest.mark.asyncio +async def 
test_inverted_question_disabled_sends_normal_text() -> None: + handler = _build_handler(repeat_enabled=True, inverted_question_enabled=False) + for uid in [20001, 20002, 20003]: + await handler.handle_message(_group_event(sender_id=uid, text="?")) + + handler.sender.send_group_message.assert_called_once() + call = handler.sender.send_group_message.call_args + assert call.args[1] == "?" + + +@pytest.mark.asyncio +async def test_inverted_question_mixed_text_not_triggered() -> None: + """非纯问号消息不受倒问号影响,正常复读。""" + handler = _build_handler(repeat_enabled=True, inverted_question_enabled=True) + for uid in [20001, 20002, 20003]: + await handler.handle_message(_group_event(sender_id=uid, text="what?")) + + handler.sender.send_group_message.assert_called_once() + call = handler.sender.send_group_message.call_args + assert call.args[1] == "what?" + + +# ── 不同群互不干扰 ── + + +@pytest.mark.asyncio +async def test_repeat_groups_are_independent() -> None: + handler = _build_handler(repeat_enabled=True) + # 群A: 2条相同 + await handler.handle_message( + _group_event(group_id=30001, sender_id=20001, text="hi") + ) + await handler.handle_message( + _group_event(group_id=30001, sender_id=20002, text="hi") + ) + # 群B: 3条相同 + for uid in [30001, 30002, 30003]: + await handler.handle_message( + _group_event(group_id=30002, sender_id=uid, text="hi") + ) + + # 群B触发,群A未触发 + assert handler.sender.send_group_message.call_count == 1 + call = handler.sender.send_group_message.call_args + assert call.args[0] == 30002 + + +# ── 计数器窗口:只看最近5条 ── + + +@pytest.mark.asyncio +async def test_repeat_counter_sliding_window() -> None: + handler = _build_handler(repeat_enabled=True) + # 发5条不同消息 + for i in range(5): + await handler.handle_message(_group_event(sender_id=20001 + i, text=f"msg{i}")) + # 再发3条相同 + for uid in [20010, 20011, 20012]: + await handler.handle_message(_group_event(sender_id=uid, text="hello")) + + handler.sender.send_group_message.assert_called_once() + call = 
handler.sender.send_group_message.call_args + assert call.args[1] == "hello" diff --git a/tests/test_prompt_builder_easter_egg.py b/tests/test_prompt_builder_easter_egg.py new file mode 100644 index 0000000..332269c --- /dev/null +++ b/tests/test_prompt_builder_easter_egg.py @@ -0,0 +1,188 @@ +"""PromptBuilder 彩蛋功能注入测试""" + +from __future__ import annotations + +from types import SimpleNamespace +from typing import Any, cast + +import pytest + +from Undefined.ai.prompts import PromptBuilder +from Undefined.end_summary_storage import EndSummaryRecord +from Undefined.memory import Memory + + +class _FakeEndSummaryStorage: + async def load(self) -> list[EndSummaryRecord]: + return [] + + +class _FakeCognitiveService: + enabled = False + + async def build_context(self, **kwargs: Any) -> str: + return "" + + +class _FakeMemoryStorage: + def get_all(self) -> list[Memory]: + return [] + + +def _make_builder( + *, + keyword_reply_enabled: bool = False, + repeat_enabled: bool = False, + inverted_question_enabled: bool = False, + easter_egg_agent_call_message_mode: str = "none", +) -> PromptBuilder: + runtime_config = SimpleNamespace( + keyword_reply_enabled=keyword_reply_enabled, + repeat_enabled=repeat_enabled, + inverted_question_enabled=inverted_question_enabled, + easter_egg_agent_call_message_mode=easter_egg_agent_call_message_mode, + knowledge_enabled=False, + grok_search_enabled=False, + chat_model=SimpleNamespace( + model_name="gpt-test", + pool=SimpleNamespace(enabled=False), + thinking_enabled=False, + reasoning_enabled=False, + ), + vision_model=None, + agent_model=None, + embedding_model=None, + security_model=None, + grok_model=None, + cognitive=None, + memes=None, + ) + return PromptBuilder( + bot_qq=123456, + memory_storage=cast(Any, _FakeMemoryStorage()), + end_summary_storage=cast(Any, _FakeEndSummaryStorage()), + runtime_config_getter=lambda: runtime_config, + anthropic_skill_registry=cast(Any, None), + cognitive_service=cast(Any, _FakeCognitiveService()), 
+ ) + + +async def _build_messages( + builder: PromptBuilder, + *, + group_id: int | None = None, +) -> list[dict[str, Any]]: + async def _fake_load_system_prompt() -> str: + return "系统提示词" + + async def _fake_load_each_rules() -> str: + return "" + + async def _fake_recent_messages( + chat_id: str, msg_type: str, start: int, end: int + ) -> list[dict[str, Any]]: + return [] + + # Patch internal loaders + builder._load_system_prompt = _fake_load_system_prompt # type: ignore[method-assign,unused-ignore] + builder._load_each_rules = _fake_load_each_rules # type: ignore[method-assign,unused-ignore] + + extra_context: dict[str, Any] = {} + if group_id is not None: + extra_context["group_id"] = group_id + + result = await builder.build_messages( + '\n你好\n', + get_recent_messages_callback=_fake_recent_messages, + extra_context=extra_context if extra_context else None, + ) + return list(result) + + +# ── _build_model_config_info 彩蛋状态 ── + + +def _get_config_info(builder: PromptBuilder) -> str: + getter = builder._runtime_config_getter + assert getter is not None + info = builder._build_model_config_info(getter()) + return str(info) + + +def test_model_config_info_shows_easter_egg_disabled() -> None: + builder = _make_builder() + info = _get_config_info(builder) + assert "彩蛋功能: 未启用" in info + + +def test_model_config_info_shows_keyword_reply_enabled() -> None: + builder = _make_builder(keyword_reply_enabled=True) + info = _get_config_info(builder) + assert "关键词自动回复" in info + assert "彩蛋功能: " in info + + +def test_model_config_info_shows_repeat_enabled() -> None: + builder = _make_builder(repeat_enabled=True) + info = _get_config_info(builder) + assert "复读" in info + assert "连续3条相同消息" in info + + +def test_model_config_info_shows_repeat_with_inverted_question() -> None: + builder = _make_builder(repeat_enabled=True, inverted_question_enabled=True) + info = _get_config_info(builder) + assert "倒问号" in info + assert "¿" in info + + +def 
test_model_config_info_shows_inverted_question_without_repeat() -> None: + builder = _make_builder(inverted_question_enabled=True) + info = _get_config_info(builder) + assert "倒问号" in info + assert "复读未启用" in info + + +def test_model_config_info_shows_agent_call_mode() -> None: + builder = _make_builder(easter_egg_agent_call_message_mode="clean") + info = _get_config_info(builder) + assert "降噪调用提示" in info + + +# ── 群聊上下文系统行为注入 ── + + +@pytest.mark.asyncio +async def test_repeat_injection_in_group_context() -> None: + builder = _make_builder(repeat_enabled=True) + messages = await _build_messages(builder, group_id=30001) + system_contents = [m["content"] for m in messages if m["role"] == "system"] + repeat_injected = any("[系统复读]" in c for c in system_contents) + assert repeat_injected, "复读彩蛋说明应注入群聊上下文" + + +@pytest.mark.asyncio +async def test_repeat_injection_not_in_private_context() -> None: + builder = _make_builder(repeat_enabled=True) + messages = await _build_messages(builder, group_id=None) + system_contents = [m["content"] for m in messages if m["role"] == "system"] + repeat_injected = any("[系统复读]" in c for c in system_contents) + assert not repeat_injected, "复读彩蛋说明不应注入非群聊上下文" + + +@pytest.mark.asyncio +async def test_inverted_question_mentioned_in_repeat_injection() -> None: + builder = _make_builder(repeat_enabled=True, inverted_question_enabled=True) + messages = await _build_messages(builder, group_id=30001) + system_contents = [m["content"] for m in messages if m["role"] == "system"] + inverted_injected = any("倒问号" in c for c in system_contents) + assert inverted_injected, "倒问号说明应在复读注入中出现" + + +@pytest.mark.asyncio +async def test_keyword_reply_injection_still_works() -> None: + builder = _make_builder(keyword_reply_enabled=True) + messages = await _build_messages(builder, group_id=30001) + system_contents = [m["content"] for m in messages if m["role"] == "system"] + keyword_injected = any("[系统关键词自动回复]" in c for c in system_contents) + assert 
keyword_injected, "关键词自动回复说明仍应注入" From 4af0beca3141d8cf4a6762431d5650c4120efbb7 Mon Sep 17 00:00:00 2001 From: Null <1708213363@qq.com> Date: Sat, 18 Apr 2026 00:04:28 +0800 Subject: [PATCH 02/57] docs(deployment): prioritize source deployment over pip/uv tool - Reorder sections: source deployment first, pip/uv tool second - Update intro to state source deployment is the recommended primary method - Add warning note on pip/uv tool section about incomplete support and testing - Move Management-first flow into source deployment section - Fix cross-references to point upward instead of downward --- docs/deployment.md | 184 +++++++++++++++++++++++---------------------- 1 file changed, 93 insertions(+), 91 deletions(-) diff --git a/docs/deployment.md b/docs/deployment.md index 9e24783..3dec661 100644 --- a/docs/deployment.md +++ b/docs/deployment.md @@ -1,6 +1,6 @@ # 安装与部署指南 -提供 pip/uv tool 安装与源码部署两种方式:前者适合直接使用;后者适合深度自定义与二次开发。 +提供源码部署与 pip/uv tool 安装两种方式:**源码部署是推荐的首选方式**,功能完整且经过充分测试;pip/uv tool 安装适合快速体验,但部分功能支持尚不完善。 > Python 版本要求:`3.11`~`3.13`(包含)。 > @@ -8,7 +8,96 @@ --- -## pip/uv tool 部署(快速,适合默认行为) +## 源码部署(推荐) + +### 1. 克隆项目 + +由于项目中使用了 `NagaAgent` 作为子模块,请使用以下命令克隆项目: + +```bash +git clone --recursive https://github.com/69gg/Undefined.git +cd Undefined +``` + +如果已经克隆了项目但没有初始化子模块: + +```bash +git submodule update --init --recursive +``` + +### 2. 安装依赖 + +推荐使用 `uv` 进行现代化的 Python 依赖管理(速度极快): + +```bash +# 安装 uv (如果尚未安装) +pip install uv + +# 可选:预装一个兼容解释器(推荐 3.12) +# uv python install 3.12 + +# 同步依赖 +# uv 会根据 pyproject.toml 自动处理 3.11~3.13 的解释器选择 +uv sync +``` + +同时需要安装 Playwright 浏览器内核(用于网页浏览功能): + +```bash +uv run playwright install +``` + +### 3. 配置环境 + +复制示例配置文件 `config.toml.example` 为 `config.toml` 并填写你的配置信息。 + +```bash +cp config.toml.example config.toml +``` + +#### 源码部署的自定义指南 + +- **自定义提示词/预置文案**:直接修改仓库根目录的 `res/`(例如 `res/prompts/`)。 +- **自定义图片资源**:修改 `img/` 下的对应文件(例如 `img/xlwy.jpg`)。 +- **优先级**:若你希望“运行目录覆盖优先”:在启动目录放置 `./res/...`,会优先于默认资源生效(便于一套安装,多套运行配置)。 + +### 4. 
启动运行 + +启动方式(二选一): + +```bash +# 1) 直接启动机器人(无 WebUI) +uv run Undefined + +# 2) 启动 WebUI(在浏览器里编辑配置,并在 WebUI 内启停机器人) +uv run Undefined-webui +``` + +> **重要**:两种方式 **二选一即可**,不要同时运行。若你选择 `Undefined-webui`,请在 WebUI 中管理机器人进程的启停。 + +### 5. 跨平台与资源路径(重要) + +- **资源读取**:运行时会优先从运行目录加载同名 `res/...` / `img/...`(便于覆盖),若不存在再使用安装包自带资源;并提供仓库结构兜底查找,因此从任意目录启动也能正常加载提示词与资源文案。 +- **并发写入**:运行时会为 JSON/日志类文件使用”锁文件 + 原子替换”写入策略,Windows/Linux/macOS 行为一致(会生成 `*.lock` 文件)。 + +### Management-first 推荐流程 + +推荐把 `Undefined-webui` 当作默认入口: + +1. 运行 `uv run Undefined-webui` +2. 在浏览器中打开管理控制台 +3. 若 `config.toml` 缺失,WebUI 会自动生成模板 +4. 在控制台中补齐配置、保存并校验 +5. 直接点击启动 Bot +6. 若需要远程管理,再使用桌面端或 Android App 连接到这个 Management API + +这样可以避免"先手写配置、再反复命令行重启"的冷启动成本,尤其适合首次部署与远程运维。 + +--- + +## pip/uv tool 部署(快速体验) + +> **注意**:pip/uv tool 安装方式的功能支持尚不如源码部署完善,也未经过充分测试。如遇问题,建议优先切换到源码部署。 适合只想“安装后直接跑”的场景,`Undefined`/`Undefined-webui` 命令会作为可执行入口安装到你的环境中。 @@ -46,20 +135,7 @@ Undefined-webui > - 选择 `Undefined-webui`:启动后访问 WebUI(默认 `http://127.0.0.1:8787`,密码默认 `changeme`;**首次启动必须修改默认密码,默认密码不可登录**;可在 `config.toml` 的 `[webui]` 中修改),在 WebUI 中在线编辑/校验配置,并通过 WebUI 启动/停止机器人进程。 > `Undefined-webui` 会在检测到当前目录缺少 `config.toml` 时,自动从 `config.toml.example` 生成一份,便于直接在 WebUI 中修改。 -> 提示:资源文件已随包发布,支持在非项目根目录启动;如需自定义内容,请参考下方说明。 - -## Management-first 推荐流程 - -推荐把 `Undefined-webui` 当作默认入口: - -1. 运行 `Undefined-webui` 或 `uv run Undefined-webui` -2. 在浏览器中打开管理控制台 -3. 若 `config.toml` 缺失,WebUI 会自动生成模板 -4. 在控制台中补齐配置、保存并校验 -5. 直接点击启动 Bot -6. 
若需要远程管理,再使用桌面端或 Android App 连接到这个 Management API - -这样可以避免“先手写配置、再反复命令行重启”的冷启动成本,尤其适合首次部署与远程运维。 +> 提示:资源文件已随包发布,支持在非项目根目录启动;如需自定义内容,请参考上方源码部署的自定义指南。 ### 完整日志(排查用) @@ -90,7 +166,7 @@ mkdir -p res/prompts # 然后把你想改的提示词放到对应路径(文件名与目录层级保持一致) ``` -如果你希望直接修改“默认提示词/默认文案”(而不是每个运行目录做覆盖),推荐使用下面的“源码部署”,在仓库里修改 `res/` 后运行;不建议直接修改已安装环境的 `site-packages/res`(升级会被覆盖)。 +如果你希望直接修改“默认提示词/默认文案”(而不是每个运行目录做覆盖),推荐使用上面的“源码部署”,在仓库里修改 `res/` 后运行;不建议直接修改已安装环境的 `site-packages/res`(升级会被覆盖)。 如果你不知道安装包内默认提示词文件在哪,可以用下面方式打印路径(用于复制一份出来改): @@ -106,80 +182,6 @@ python -c "from Undefined.utils.resources import read_text_resource; print(len(r --- -## 源码部署(推荐开发/高定使用) - -### 1. 克隆项目 - -由于项目中使用了 `NagaAgent` 作为子模块,请使用以下命令克隆项目: - -```bash -git clone --recursive https://github.com/69gg/Undefined.git -cd Undefined -``` - -如果已经克隆了项目但没有初始化子模块: - -```bash -git submodule update --init --recursive -``` - -### 2. 安装依赖 - -推荐使用 `uv` 进行现代化的 Python 依赖管理(速度极快): - -```bash -# 安装 uv (如果尚未安装) -pip install uv - -# 可选:预装一个兼容解释器(推荐 3.12) -# uv python install 3.12 - -# 同步依赖 -# uv 会根据 pyproject.toml 自动处理 3.11~3.13 的解释器选择 -uv sync -``` - -同时需要安装 Playwright 浏览器内核(用于网页浏览功能): - -```bash -uv run playwright install -``` - -### 3. 配置环境 - -复制示例配置文件 `config.toml.example` 为 `config.toml` 并填写你的配置信息。 - -```bash -cp config.toml.example config.toml -``` - -#### 源码部署的自定义指南 - -- **自定义提示词/预置文案**:直接修改仓库根目录的 `res/`(例如 `res/prompts/`)。 -- **自定义图片资源**:修改 `img/` 下的对应文件(例如 `img/xlwy.jpg`)。 -- **优先级**:若你希望“运行目录覆盖优先”:在启动目录放置 `./res/...`,会优先于默认资源生效(便于一套安装,多套运行配置)。 - -### 4. 启动运行 - -启动方式(二选一): - -```bash -# 1) 直接启动机器人(无 WebUI) -uv run Undefined - -# 2) 启动 WebUI(在浏览器里编辑配置,并在 WebUI 内启停机器人) -uv run Undefined-webui -``` - -> **重要**:两种方式 **二选一即可**,不要同时运行。若你选择 `Undefined-webui`,请在 WebUI 中管理机器人进程的启停。 - -### 5. 
跨平台与资源路径(重要) - -- **资源读取**:运行时会优先从运行目录加载同名 `res/...` / `img/...`(便于覆盖),若不存在再使用安装包自带资源;并提供仓库结构兜底查找,因此从任意目录启动也能正常加载提示词与资源文案。 -- **并发写入**:运行时会为 JSON/日志类文件使用”锁文件 + 原子替换”写入策略,Windows/Linux/macOS 行为一致(会生成 `*.lock` 文件)。 - ---- - ## NapCat / Lagrange.Core 部署要求 **NapCat(或 Lagrange.Core)必须与 Bot 进程共享同一文件系统,不能将 NapCat 单独放在无法访问 Bot 数据目录的 Docker 容器内。** From 379ae4d8ced920dda7fd85c5fe1b04fbf3181afb Mon Sep 17 00:00:00 2001 From: Null <1708213363@qq.com> Date: Sat, 18 Apr 2026 09:41:26 +0800 Subject: [PATCH 03/57] =?UTF-8?q?fix(meme):=20=E8=A1=A8=E6=83=85=E5=8C=85?= =?UTF-8?q?=E5=8A=A8=E5=9B=BE=E5=88=A4=E5=AE=9A=E6=94=B9=E8=BF=9B=E4=B8=8E?= =?UTF-8?q?=E6=9B=B4=E5=A5=BD=E7=9A=84=E9=87=8D=E8=AF=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- config.toml.example | 6 ++ docs/configuration.md | 5 +- docs/memes.md | 4 + res/prompts/describe_meme_image.txt | 1 + res/prompts/judge_meme_image.txt | 2 +- src/Undefined/ai/multimodal.py | 34 ++++--- src/Undefined/config/models.py | 2 + src/Undefined/memes/service.py | 131 +++++++++++++++++++++++++-- tests/test_meme_gif_frames.py | 132 ++++++++++++++++++++++++++++ tests/test_meme_retry.py | 63 +++++++++++++ 10 files changed, 359 insertions(+), 21 deletions(-) create mode 100644 tests/test_meme_gif_frames.py create mode 100644 tests/test_meme_retry.py diff --git a/config.toml.example b/config.toml.example index a68a724..0e7fe2c 100644 --- a/config.toml.example +++ b/config.toml.example @@ -1101,6 +1101,12 @@ max_total_bytes = 5368709120 # zh: 是否允许 GIF 入库。 # en: Whether GIF files are allowed. allow_gif = true +# zh: GIF 分析模式:grid(多帧拼接为网格图)或 multi(多帧分开发送给模型)。 +# en: GIF analysis mode: grid (composite frames into grid) or multi (send frames separately). +gif_analysis_mode = "grid" +# zh: GIF 分析帧数(包括首末帧,均匀采样)。 +# en: Number of frames to extract for GIF analysis (including first/last, evenly sampled). +gif_analysis_frames = 6 # zh: 是否自动处理群聊图片。 # en: Auto-ingest group chat images. 
auto_ingest_group = true diff --git a/docs/configuration.md b/docs/configuration.md index 49c9465..eceb9ea 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -724,6 +724,8 @@ Prompt caching 补充: | `max_items` | `10000` | 表情包条目上限 | `<=0` 回退 `10000` | | `max_total_bytes` | `5368709120` | 表情包总磁盘占用上限(字节) | `<=0` 回退 `5368709120` | | `allow_gif` | `true` | 是否允许 GIF 入库 | | +| `gif_analysis_mode` | `"grid"` | GIF 动图判定分析模式:`"grid"` (网格拼图)、`"multi"` (多图逐帧)、`"first_frame"` (仅第一帧) | 非法值回退 `"grid"` | +| `gif_analysis_frames` | `6` | GIF 动图抽帧供模型识别的数量 | `<=0` 回退 `6` | | `auto_ingest_group` | `true` | 是否自动处理群聊图片 | | | `auto_ingest_private` | `true` | 是否自动处理私聊图片 | | | `keyword_top_k` | `30` | 关键词候选召回数 | `<=0` 回退 `30` | @@ -734,7 +736,8 @@ Prompt caching 补充: - 表情包入库走两阶段 LLM 管线: 1. 判定是否为表情包 2. 对通过判定的图片生成纯文本描述与标签 -- 第一阶段失败时,按“不是表情包”处理,直接丢弃。 +- 第一阶段失败时,按“不是表情包”处理,直接丢弃(如果是网络和服务器限流等异常,系统会在后台自动重试)。 +- 对于 GIF 格式图片的分析,`"grid"` 模式会将多个抽帧横向并排或拼接在一张大图中降低计费单元,`"multi"` 模式则将各帧作为独立图像输入至多模态大模型。 - 第二阶段不做 OCR;向量存储和检索文本只使用纯文本 `description + tags + aliases`。 - 同一图片内容在单进程内会按 `SHA256` 串行入库,避免并发表情包重复写入。 - 若入库在写入来源记录或向量索引阶段失败,会回滚已写入的元数据与本地文件,避免残留孤儿记录。 diff --git a/docs/memes.md b/docs/memes.md index 29e993d..e0f93e8 100644 --- a/docs/memes.md +++ b/docs/memes.md @@ -11,12 +11,14 @@ Undefined 平台自 3.3.0 版本起内置了强大的**全局表情包库**功 2. **第一阶段 - 属性判定 (Judge)**: 提交给视觉模型(通过 `judge_meme_image.txt` 提示词)分析图片本质。如果图片只是普通的自拍、系统截图或者无法表现梗(Meme)的内容,流程将在此终止。 + 对于 GIF 动图,系统可根据配置(网格拼接或多张多帧)进行抽帧重组以提供更连贯的视觉上下文。若在此交互期间遇到暂时的网络错误或接口报错,处理管线会自动重试,确保判定过程的高可用性。 3. **第二阶段 - 语义解析 (Describe)**: 对于被判定为表情包的图片,模型会进一步(通过 `describe_meme_image.txt` 提示词)提取: - 图片的关键视觉元素与构图。 - 隐喻、情感与适合的回复语境。 - 高质量的搜索标签(Tags)。 + 同样,该阶段依然享有自动重试逻辑保护,从而保障长流程分析和描述内容的成功入库。 4. 
**向量化与持久化**: 提取出的结构化文本与标签被存入 SQLite (`MemeStore`),并通过嵌入模型向量化后存入 ChromaDB (`MemeVectorStore`)。原图及其生成的预览图(如 GIF 抽帧)持久化存放至数据目录。 @@ -41,6 +43,8 @@ Undefined 平台自 3.3.0 版本起内置了强大的**全局表情包库**功 [memes] enabled = true # 是否启用 query_default_mode = "hybrid" # 默认搜索策略:keyword / semantic / hybrid +gif_analysis_mode = "grid" # GIF 的多帧识别模式:grid(网格拼接)、multi(多张散图)、first_frame(仅首帧) +gif_analysis_frames = 6 # GIF 的抽帧数量 ``` 更多细节请查阅 [配置文档](configuration.md#425-memes-表情包库)。 diff --git a/res/prompts/describe_meme_image.txt b/res/prompts/describe_meme_image.txt index af24d2c..e223e56 100644 --- a/res/prompts/describe_meme_image.txt +++ b/res/prompts/describe_meme_image.txt @@ -56,6 +56,7 @@ - 不要逐字抄录图中文字。 - 不要输出 Markdown,不要输出额外解释。 - 你必须且只能调用 `submit_meme_description`。 +- 如果收到的是一张网格图(多帧拼接)或多张图片,说明原图是动图/GIF:描述应涵盖动图的动态变化过程和整体语义,不要只描述单帧;可以在 tags 里加上 `动图` 标签。 好的例子: - `description`: `猫猫无语翻白眼反应图` diff --git a/res/prompts/judge_meme_image.txt b/res/prompts/judge_meme_image.txt index e22cf64..5a3d02e 100644 --- a/res/prompts/judge_meme_image.txt +++ b/res/prompts/judge_meme_image.txt @@ -21,4 +21,4 @@ - 必须且只能调用 `submit_meme_judgement` - `is_meme` 仅表示“适不适合放进聊天表情包库”,不是“图里有没有梗” - `reason` 用一句简短中文说明依据 -- 只有当“整张图整体上就是一张可直接发送的表情包”时,才能给 `is_meme=true` +- 只有当“整张图整体上就是一张可直接发送的表情包”时,才能给 `is_meme=true`- 如果收到的是一张网格图(多帧拼接)或多张图片,说明原图是动图/GIF:请综合所有帧判断这个动图整体是否适合作为表情包,不要只看单帧 \ No newline at end of file diff --git a/src/Undefined/ai/multimodal.py b/src/Undefined/ai/multimodal.py index 9428882..3391a6e 100644 --- a/src/Undefined/ai/multimodal.py +++ b/src/Undefined/ai/multimodal.py @@ -607,13 +607,13 @@ async def _prune_url_cache_locks( self._url_cache_locks.pop(key, None) async def _build_content_items( - self, media_type: str, media_content: str, prompt: str + self, media_type: str, media_content: str | list[str], prompt: str ) -> list[dict[str, Any]]: """构建请求内容项。 Args: media_type: 媒体类型 - media_content: 媒体内容(URL 或 data URL) + media_content: 媒体内容(URL/data URL),或其列表 prompt: 提示词 Returns: @@ -623,9 +623,9 @@ async def 
_build_content_items( # 添加媒体内容项 media_item_key = f"{media_type}_url" - content_items.append( - {"type": media_item_key, media_item_key: {"url": media_content}} - ) + contents = media_content if isinstance(media_content, list) else [media_content] + for mc in contents: + content_items.append({"type": media_item_key, media_item_key: {"url": mc}}) return content_items @@ -822,13 +822,19 @@ async def _request_required_tool_args( self, *, prompt_path: str, - image_url: str, + image_url: str | list[str], tool_schema: dict[str, Any], tool_name: str, call_type: str, max_tokens: int, ) -> dict[str, Any]: - media_content = await self._load_media_content(image_url, "image") + if isinstance(image_url, list): + media_contents: list[str] = [] + for url in image_url: + media_contents.append(await self._load_media_content(url, "image")) + media_content: str | list[str] = media_contents + else: + media_content = await self._load_media_content(image_url, "image") prompt = await self._load_prompt_text(prompt_path) content_items = await self._build_content_items("image", media_content, prompt) response = await self._requester.request( @@ -847,11 +853,13 @@ async def _request_required_tool_args( expected_tool_name=tool_name, stage=call_type, logger=logger, - error_context=f"image={redact_string(image_url)[:120]}", + error_context=f"image={redact_string(str(image_url) if isinstance(image_url, list) else image_url)[:120]}", ) - async def judge_meme_image(self, image_url: str) -> dict[str, Any]: - safe_url = redact_string(image_url) + async def judge_meme_image(self, image_url: str | list[str]) -> dict[str, Any]: + safe_url = redact_string( + str(image_url) if isinstance(image_url, list) else image_url + ) try: args = await self._request_required_tool_args( prompt_path=_MEME_JUDGE_PROMPT_PATH, @@ -886,8 +894,10 @@ async def judge_meme_image(self, image_url: str) -> dict[str, Any]: ) return parsed - async def describe_meme_image(self, image_url: str) -> dict[str, Any]: - safe_url = 
redact_string(image_url) + async def describe_meme_image(self, image_url: str | list[str]) -> dict[str, Any]: + safe_url = redact_string( + str(image_url) if isinstance(image_url, list) else image_url + ) try: args = await self._request_required_tool_args( prompt_path=_MEME_DESCRIBE_PROMPT_PATH, diff --git a/src/Undefined/config/models.py b/src/Undefined/config/models.py index 64aae69..3190b38 100644 --- a/src/Undefined/config/models.py +++ b/src/Undefined/config/models.py @@ -338,6 +338,8 @@ class MemeConfig: semantic_top_k: int = 30 rerank_top_k: int = 20 worker_max_concurrency: int = 4 + gif_analysis_mode: str = "grid" + gif_analysis_frames: int = 6 @dataclass diff --git a/src/Undefined/memes/service.py b/src/Undefined/memes/service.py index e853f1e..851b6a6 100644 --- a/src/Undefined/memes/service.py +++ b/src/Undefined/memes/service.py @@ -6,6 +6,7 @@ from datetime import datetime import hashlib import logging +import math import mimetypes from pathlib import Path import re @@ -14,6 +15,7 @@ from typing import Any from uuid import uuid4 +from openai import APIConnectionError, APIStatusError, APITimeoutError from PIL import Image from Undefined.attachments import AttachmentRecord @@ -78,6 +80,77 @@ def _normalize_tags(raw_tags: list[str] | str | None) -> list[str]: return normalize_string_list(raw_tags) +def _is_retryable_llm_error(exc: Exception) -> bool: + """判断 LLM 调用异常是否应触发 worker 级重试。""" + if isinstance(exc, (APIConnectionError, APITimeoutError)): + return True + if isinstance(exc, APIStatusError): + return exc.status_code == 429 or exc.status_code >= 500 + return False + + +def _extract_gif_frames(source_path: Path, n_frames: int) -> list[Image.Image]: + """从 GIF 中均匀采样 *n_frames* 帧(含首末帧),返回 RGBA Image 列表。""" + with Image.open(source_path) as image: + total = getattr(image, "n_frames", 1) + if total <= 1: + image.seek(0) + return [image.convert("RGBA").copy()] + n = min(n_frames, total) + if n <= 1: + image.seek(0) + return [image.convert("RGBA").copy()] + 
indices = _sample_frame_indices(total, n) + frames: list[Image.Image] = [] + for idx in indices: + image.seek(idx) + frames.append(image.convert("RGBA").copy()) + return frames + + +def _sample_frame_indices(total: int, n: int) -> list[int]: + """生成均匀采样的帧索引列表(始终包含首帧和末帧)。""" + if n >= total: + return list(range(total)) + if n == 1: + return [0] + if n == 2: + return [0, total - 1] + indices = [round(i * (total - 1) / (n - 1)) for i in range(n)] + # 去重并保持顺序 + seen: set[int] = set() + result: list[int] = [] + for idx in indices: + if idx not in seen: + seen.add(idx) + result.append(idx) + return result + + +def _compose_grid(frames: list[Image.Image], output_path: Path) -> None: + """将多帧拼接为网格图并保存为 PNG。""" + n = len(frames) + if n == 0: + return + if n == 1: + frames[0].save(output_path, format="PNG") + return + cols = math.ceil(math.sqrt(n)) + rows = math.ceil(n / cols) + fw, fh = frames[0].size + grid = Image.new("RGBA", (cols * fw, rows * fh), (0, 0, 0, 0)) + for i, frame in enumerate(frames): + resized = ( + frame.resize((fw, fh), Image.Resampling.LANCZOS) + if frame.size != (fw, fh) + else frame + ) + x = (i % cols) * fw + y = (i // cols) * fh + grid.paste(resized, (x, y)) + grid.save(output_path, format="PNG") + + @dataclass class _IngestDigestLockEntry: lock: asyncio.Lock @@ -794,9 +867,12 @@ async def _process_reanalyze_job(self, job: Mapping[str, Any]) -> None: return if self._ai_client is None: raise RuntimeError("reanalyze requires ai_client") + analyze_path = record.preview_path if record.preview_path else record.blob_path try: - judgement = await self._ai_client.judge_meme_image(record.blob_path) + judgement = await self._ai_client.judge_meme_image(analyze_path) except Exception as exc: + if _is_retryable_llm_error(exc): + raise logger.exception( "[memes] judge stage failed during reanalyze: uid=%s err=%s", uid, exc ) @@ -805,8 +881,10 @@ async def _process_reanalyze_job(self, job: Mapping[str, Any]) -> None: await self.delete_meme(uid) return try: - 
described = await self._ai_client.describe_meme_image(record.blob_path) + described = await self._ai_client.describe_meme_image(analyze_path) except Exception as exc: + if _is_retryable_llm_error(exc): + raise logger.exception( "[memes] describe stage failed during reanalyze: uid=%s err=%s", uid, @@ -933,12 +1011,22 @@ async def _process_ingest_job(self, job: Mapping[str, Any]) -> None: or mimetypes.guess_type(source_path.name)[0] or "application/octet-stream" ) - analyze_path = str( + analyze_path: str | list[str] = str( preview_path if preview_path is not None else blob_path ) + if ( + is_animated + and str(getattr(cfg, "gif_analysis_mode", "grid")).lower() + == "multi" + ): + analyze_path = await self._prepare_gif_multi_frames( + source_path, uid + ) try: judgement = await self._ai_client.judge_meme_image(analyze_path) except Exception as exc: + if _is_retryable_llm_error(exc): + raise logger.exception( "[memes] judge stage failed, treat as non-meme: uid=%s err=%s", uid, @@ -956,6 +1044,8 @@ async def _process_ingest_job(self, job: Mapping[str, Any]) -> None: try: described = await self._ai_client.describe_meme_image(analyze_path) except Exception as exc: + if _is_retryable_llm_error(exc): + raise logger.exception( "[memes] describe stage failed, drop uid=%s err=%s", uid, exc ) @@ -1043,17 +1133,44 @@ def _copy() -> None: if not is_animated: return blob_path + cfg = self._cfg() + mode = str(getattr(cfg, "gif_analysis_mode", "grid")).lower() + n_frames = max(2, int(getattr(cfg, "gif_analysis_frames", 6))) preview_path = self._preview_dir() / f"{target_uid}.png" def _render_preview() -> None: - with Image.open(source_path) as image: - image.seek(0) - frame = image.convert("RGBA") - frame.save(preview_path, format="PNG") + frames = _extract_gif_frames(source_path, n_frames) + if mode == "multi": + # multi 模式也需要生成一张预览用于存储/展示,取首帧 + frames[0].save(preview_path, format="PNG") + else: + _compose_grid(frames, preview_path) + for f in frames: + f.close() await 
asyncio.to_thread(_render_preview) return preview_path + async def _prepare_gif_multi_frames( + self, source_path: Path, target_uid: str + ) -> list[str]: + """multi 模式:将 GIF 各帧单独保存为 PNG,返回路径列表。""" + cfg = self._cfg() + n_frames = max(2, int(getattr(cfg, "gif_analysis_frames", 6))) + preview_dir = self._preview_dir() + + def _render_frames() -> list[str]: + frames = _extract_gif_frames(source_path, n_frames) + paths: list[str] = [] + for i, frame in enumerate(frames): + p = preview_dir / f"{target_uid}_f{i}.png" + frame.save(p, format="PNG") + frame.close() + paths.append(str(p)) + return paths + + return await asyncio.to_thread(_render_frames) + def _hash_file(self, path: Path) -> str: hasher = hashlib.sha256() with path.open("rb") as handle: diff --git a/tests/test_meme_gif_frames.py b/tests/test_meme_gif_frames.py new file mode 100644 index 0000000..aea1f2b --- /dev/null +++ b/tests/test_meme_gif_frames.py @@ -0,0 +1,132 @@ +from __future__ import annotations + +import math +from pathlib import Path + +from PIL import Image + +from Undefined.memes.service import ( + _compose_grid, + _extract_gif_frames, + _sample_frame_indices, +) + + +def _make_gif(path: Path, n_frames: int, size: tuple[int, int] = (4, 4)) -> None: + """创建一个包含 *n_frames* 帧的 GIF 文件。""" + frames = [ + Image.new("RGBA", size, (i * 30 % 256, i * 60 % 256, i * 90 % 256, 255)) + for i in range(n_frames) + ] + frames[0].save( + path, + format="GIF", + save_all=True, + append_images=frames[1:], + loop=0, + duration=100, + ) + + +# ── _sample_frame_indices ── + + +def test_sample_indices_basic() -> None: + result = _sample_frame_indices(10, 4) + assert result[0] == 0 + assert result[-1] == 9 + assert len(result) == 4 + + +def test_sample_indices_more_than_total() -> None: + result = _sample_frame_indices(3, 10) + assert result == [0, 1, 2] + + +def test_sample_indices_two() -> None: + result = _sample_frame_indices(20, 2) + assert result == [0, 19] + + +def test_sample_indices_one() -> None: + result = 
_sample_frame_indices(5, 1) + assert result == [0] + + +def test_sample_indices_no_duplicates() -> None: + result = _sample_frame_indices(3, 3) + assert len(result) == len(set(result)) + + +# ── _extract_gif_frames ── + + +def test_extract_frames_count(tmp_path: Path) -> None: + gif_path = tmp_path / "test.gif" + _make_gif(gif_path, 12) + frames = _extract_gif_frames(gif_path, 6) + assert len(frames) == 6 + for f in frames: + assert f.mode == "RGBA" + f.close() + + +def test_extract_frames_fewer_than_requested(tmp_path: Path) -> None: + gif_path = tmp_path / "test.gif" + _make_gif(gif_path, 3) + frames = _extract_gif_frames(gif_path, 6) + assert len(frames) == 3 + for f in frames: + f.close() + + +def test_extract_frames_single_frame(tmp_path: Path) -> None: + gif_path = tmp_path / "test.gif" + _make_gif(gif_path, 1) + frames = _extract_gif_frames(gif_path, 6) + assert len(frames) == 1 + frames[0].close() + + +# ── _compose_grid ── + + +def test_compose_grid_output(tmp_path: Path) -> None: + frames = [ + Image.new("RGBA", (10, 10), (255, 0, 0, 255)), + Image.new("RGBA", (10, 10), (0, 255, 0, 255)), + Image.new("RGBA", (10, 10), (0, 0, 255, 255)), + Image.new("RGBA", (10, 10), (255, 255, 0, 255)), + ] + output = tmp_path / "grid.png" + _compose_grid(frames, output) + assert output.is_file() + with Image.open(output) as grid: + cols = math.ceil(math.sqrt(4)) + rows = math.ceil(4 / cols) + assert grid.size == (cols * 10, rows * 10) + for f in frames: + f.close() + + +def test_compose_grid_single_frame(tmp_path: Path) -> None: + frames = [Image.new("RGBA", (8, 8), (0, 0, 0, 255))] + output = tmp_path / "grid_single.png" + _compose_grid(frames, output) + assert output.is_file() + with Image.open(output) as grid: + assert grid.size == (8, 8) + frames[0].close() + + +def test_compose_grid_six_frames(tmp_path: Path) -> None: + frames = [Image.new("RGBA", (10, 10), (i * 40, 0, 0, 255)) for i in range(6)] + output = tmp_path / "grid6.png" + _compose_grid(frames, output) + 
assert output.is_file() + with Image.open(output) as grid: + cols = math.ceil(math.sqrt(6)) + rows = math.ceil(6 / cols) + assert grid.size == (cols * 10, rows * 10) + for f in frames: + f.close() diff --git a/tests/test_meme_retry.py b/tests/test_meme_retry.py new file mode 100644 index 0000000..b07ef0e --- /dev/null +++ b/tests/test_meme_retry.py @@ -0,0 +1,63 @@ +from __future__ import annotations + + +from openai import APIConnectionError, APIStatusError, APITimeoutError +from unittest.mock import MagicMock + +from Undefined.memes.service import _is_retryable_llm_error + + +def _make_api_status_error(status_code: int) -> APIStatusError: + response = MagicMock() + response.status_code = status_code + response.headers = {} + response.text = "" + response.json.return_value = {} + return APIStatusError( + message=f"Error {status_code}", + response=response, + body=None, + ) + + +def test_connection_error_is_retryable() -> None: + exc = APIConnectionError(request=MagicMock()) + assert _is_retryable_llm_error(exc) is True + + +def test_timeout_error_is_retryable() -> None: + exc = APITimeoutError(request=MagicMock()) + assert _is_retryable_llm_error(exc) is True + + +def test_status_429_is_retryable() -> None: + exc = _make_api_status_error(429) + assert _is_retryable_llm_error(exc) is True + + +def test_status_500_is_retryable() -> None: + exc = _make_api_status_error(500) + assert _is_retryable_llm_error(exc) is True + + +def test_status_503_is_retryable() -> None: + exc = _make_api_status_error(503) + assert _is_retryable_llm_error(exc) is True + + +def test_status_401_not_retryable() -> None: + exc = _make_api_status_error(401) + assert _is_retryable_llm_error(exc) is False + + +def test_status_400_not_retryable() -> None: + exc = _make_api_status_error(400) + assert _is_retryable_llm_error(exc) is False + + +def test_generic_exception_not_retryable() -> None: + assert _is_retryable_llm_error(ValueError("parse fail")) is False + + +def 
test_runtime_error_not_retryable() -> None: + assert _is_retryable_llm_error(RuntimeError("oops")) is False From 255bf8ca2fecb5465ab385b99c67174fa4529b24 Mon Sep 17 00:00:00 2001 From: Null <1708213363@qq.com> Date: Sat, 18 Apr 2026 11:06:32 +0800 Subject: [PATCH 04/57] docs: expand usage guide and enforce mandatory LaTeX dependency - render_latex: remove mathtext fallback, enforce usetex=True strictly; add _strip_document_wrappers to handle \begin{document} input; catch RuntimeError and return helpful install prompt on missing TeX - docs/build.md: add mandatory system LaTeX installation section with per-platform commands (Debian/Arch/macOS/Windows) and verification steps - docs/deployment.md: integrate LaTeX install steps into source deployment workflow as step 3; add reminder callout in pip/uv tool section - docs/usage.md: full rewrite with complete feature reference covering all Agents, Toolsets, Tools, scheduler modes, FAQ commands, slash command permission table, and multi-model pool - tests/test_render_latex_tool.py: add 4 tests covering wrapper stripping, successful embed delivery, and missing TeX error handling --- docs/build.md | 51 +++ docs/deployment.md | 47 +- docs/usage.md | 408 ++++++++++++++++-- src/Undefined/skills/toolsets/README.md | 2 +- .../toolsets/render/render_latex/config.json | 4 +- .../toolsets/render/render_latex/handler.py | 71 ++- tests/test_render_latex_tool.py | 98 +++++ 7 files changed, 612 insertions(+), 69 deletions(-) create mode 100644 tests/test_render_latex_tool.py diff --git a/docs/build.md b/docs/build.md index 297a699..17e7f3d 100644 --- a/docs/build.md +++ b/docs/build.md @@ -28,6 +28,57 @@ uv sync --group dev -p 3.12 uv run playwright install ``` +### 系统级 LaTeX 环境(必装,用于 `render.render_latex`) + +`render.render_latex` 使用系统外部 LaTeX(`usetex=True`)渲染公式,**必须提前安装**,否则渲染会失败并返回错误。 + +**Debian / Ubuntu** + +```bash +sudo apt-get update +sudo apt-get install -y texlive-full dvipng ghostscript +``` + +**Arch Linux** + +```bash +sudo 
pacman -S --needed \ + texlive-basic \ + texlive-bin \ + texlive-latex \ + texlive-latexrecommended \ + texlive-latexextra \ + texlive-fontsrecommended \ + texlive-binextra \ + texlive-mathscience \ + ghostscript +``` + +**macOS** + +```bash +# 推荐 MacTeX(完整,约 4 GB) +brew install --cask mactex-no-gui + +# 或体积更小的 BasicTeX,之后按需补包 +brew install --cask basictex +sudo tlmgr update --self +sudo tlmgr install dvipng type1cm type1ec cm-super collection-fontsrecommended +``` + +**Windows** + +安装 [MiKTeX](https://miktex.org/download)(推荐,缺包时自动下载)或 [TeX Live](https://tug.org/texlive/windows.html)。安装完成后在 MiKTeX Console 里手动安装 `dvipng` 包,并确保 `latex.exe` 在 PATH 中。 + +**验证** + +```bash +latex --version +dvipng --version +``` + +若日志出现 `type1ec.sty not found` 或 `latex was not able to process`,TeX 包仍不完整:Debian / Ubuntu 已装 `texlive-full` 则无需额外操作;Arch 补装 `texlive-latexextra` `texlive-fontsrecommended` `texlive-binextra`;macOS BasicTeX 用户运行 `sudo tlmgr install cm-super`。 + ### Node.js / Rust / Tauri 如果需要构建跨平台控制台,请额外准备: diff --git a/docs/deployment.md b/docs/deployment.md index 3dec661..335b51d 100644 --- a/docs/deployment.md +++ b/docs/deployment.md @@ -47,7 +47,46 @@ uv sync uv run playwright install ``` -### 3. 配置环境 +### 3. 
安装系统级依赖(必装) + +Bot 内置的数学公式等功能直接强依赖系统级的渲染环境,你**必须**提前在宿主机配置以下依赖。若是缺失该依赖,渲染图像或公式时后台将直接报错。 + +**安装 LaTeX 与工具链**: + +- **Ubuntu / Debian** + 直接无脑安装完整的 TeX Live 环境最为稳妥: + ```bash + sudo apt-get update + sudo apt-get install -y texlive-full dvipng ghostscript + ``` + +- **Arch Linux** + 通过 pacman 安装基础包: + ```bash + sudo pacman -S --needed texlive-basic texlive-bin texlive-latex texlive-latexrecommended texlive-latexextra texlive-fontsrecommended texlive-binextra texlive-mathscience ghostscript + ``` + +- **macOS** + 推荐通过 Homebrew 安装 MacTeX 环境,提供完整(省心,体积较大)或者精简两个版本: + ```bash + # 方式 1:完整环境(推荐) + brew install --cask mactex-no-gui + + # 方式 2:精简版(体积小,需手动拉取补包) + brew install --cask basictex + sudo tlmgr update --self + sudo tlmgr install dvipng type1cm type1ec cm-super collection-fontsrecommended + ``` + +- **Windows** + 安装 [MiKTeX](https://miktex.org/download) (推荐,能自动下载缺失宏包)或者 [TeX Live](https://tug.org/texlive/windows.html)。 + 1. 打开 MiKTeX Console。 + 2. 搜索 `dvipng` 手动将其安装上。 + 3. 确认环境变量 `PATH` 中已经包含了 `latex.exe`。 + +> 验证安装:使用 `latex --version` 与 `dvipng --version` 命令检测是否识别。如日志报错 `type1ec.sty not found` 或 `dvipng: command not found`,一般是由于所处的系统少安装了包或可执行文件不在环境变量中。 + +### 4. 配置环境 复制示例配置文件 `config.toml.example` 为 `config.toml` 并填写你的配置信息。 @@ -61,7 +100,7 @@ cp config.toml.example config.toml - **自定义图片资源**:修改 `img/` 下的对应文件(例如 `img/xlwy.jpg`)。 - **优先级**:若你希望“运行目录覆盖优先”:在启动目录放置 `./res/...`,会优先于默认资源生效(便于一套安装,多套运行配置)。 -### 4. 启动运行 +### 5. 启动运行 启动方式(二选一): @@ -75,7 +114,7 @@ uv run Undefined-webui > **重要**:两种方式 **二选一即可**,不要同时运行。若你选择 `Undefined-webui`,请在 WebUI 中管理机器人进程的启停。 -### 5. 跨平台与资源路径(重要) +### 6. 跨平台与资源路径(重要) - **资源读取**:运行时会优先从运行目录加载同名 `res/...` / `img/...`(便于覆盖),若不存在再使用安装包自带资源;并提供仓库结构兜底查找,因此从任意目录启动也能正常加载提示词与资源文案。 - **并发写入**:运行时会为 JSON/日志类文件使用”锁文件 + 原子替换”写入策略,Windows/Linux/macOS 行为一致(会生成 `*.lock` 文件)。 @@ -117,6 +156,8 @@ uv tool install Undefined-bot uv tool run --from Undefined-bot playwright install ``` +> **系统依赖提醒**:同源码部署要求一致,你必须在宿主机上预先安装所需的 LaTeX/dvipng 渲染环境。请参考上文 [3. 
安装系统级依赖(必装)](#3-安装系统级依赖必装) 查阅你操作系统的对应安装命令,未配置前若触发公式与 Markdown 的图片渲染则会报错执行失败。 + 安装完成后,在任意目录准备 `config.toml` 并启动: ```bash diff --git a/docs/usage.md b/docs/usage.md index 8cb2e96..66dbc30 100644 --- a/docs/usage.md +++ b/docs/usage.md @@ -1,64 +1,390 @@ # 使用与功能说明 -## 开始使用 +本文档对 Undefined 的功能模块进行系统性介绍。完成[部署配置](deployment.md)并成功与 QQ 端建立连接后,即可通过自然语言或结构化指令使用以下全部能力。 -1. 启动 OneBot 协议端(如 NapCat 或 Lagrange.Core)并登录 QQ。 -2. 配置好 `config.toml` 并[启动 Undefined](deployment.md)。 -3. 连接成功后,机器人即可在群聊或私聊中响应。 +--- + +## 目录 + +1. [基础交互方式](#1-基础交互方式) +2. [认知记忆系统](#2-认知记忆系统) +3. [内置智能体 (Agents)](#3-内置智能体-agents) +4. [工具集能力一览 (Toolsets & Tools)](#4-工具集能力一览-toolsets--tools) +5. [定时任务与调度](#5-定时任务与调度) +6. [FAQ 知识库管理](#6-faq-知识库管理) +7. [内置斜杠指令参考](#7-内置斜杠指令参考) +8. [多模型池(私聊模型切换)](#8-多模型池私聊模型切换) +9. [WebUI 与跨平台管理](#9-webui-与跨平台管理) + +--- + +## 1. 基础交互方式 + +### 私聊场景 +在私聊会话中,可以直接向 Bot 发送任意消息,无需附加任何前缀或格式要求。系统会自动维护当前对话的完整上下文。 + +### 群聊场景 +在群聊环境中,Bot 默认仅响应以下方式触发的消息: +- **@提及**:在消息中 `@Bot` 并附带指令内容。 +- **指令前缀**:使用 `config.toml` 中配置的前缀(如有)。 + +> **队列优先级说明**:系统底层采用四级消息队列调度模型,优先级从高到低为:超级管理员 > 私聊 > @提及 > 普通群聊。在群聊高并发场景下,管理请求和直接提及将优先得到响应。 + +--- + +## 2. 认知记忆系统 + +Undefined 搭载了基于 ChromaDB 向量数据库的后台认知系统,无需手动录入,即可实现跨会话的长期上下文追踪。 + +| 能力 | 说明 | +|---|---| +| **聊天侧写(Profile)** | 系统实时静默分析对话内容,自动提取并持久化用户的偏好、待办、身份与观点等信息,在后续对话中作为参考背景 | +| **历史事件检索** | 基于向量语义检索,支持按用户、群组、时间段查询历史记忆,并应用时间衰减加权排序 | +| **群聊宏观总结** | 可对历史消息进行语义召回与整合,快速梳理出大量消息中的重点内容 | + +**示例:** +> *"请回忆一下我们上周讨论过的项目规划内容。"* +> *"请总结一下本群过去三天内讨论的主要话题。"* + +--- + +## 3. 
内置智能体 (Agents) + +智能体(Agent)是由独立大模型驱动的高自治任务处理器。主 AI 在理解到任务超出自身直接能力范围时,会自动将任务委托给相应的专业 Agent,由其递归调用子工具完成任务后汇报结果。 + +### `web_agent` — 网络信息检索助手 + +负责网页搜索和网页内容爬取,能够获取互联网上的实时最新信息。 + +**子工具**:`grok_search`(Grok 搜索)、`web_search`(通用搜索)、`crawl_webpage`(网页内容提取) + +**示例:** +> *"请搜索最近三天关于 DeepSeek 的最新动态并生成摘要。"* +> *"帮我爬取这个网页的主要内容并整理成结构化笔记。"* + +--- + +### `file_analysis_agent` — 文件分析助手 + +支持对代码、PDF、Word、Excel 等多种格式文件进行解析与分析。用户只需将文件发送至对话中即可。 + +**子工具**:`analyze_pdf`、`analyze_docx`、`analyze_xlsx`、`analyze_code`、`read_file` + +**示例:** +> *"请分析这份 PDF 文档,提取其中第三章的核心数据。"* +> *"请检查这份 Python 代码,找出其中潜在的性能瓶颈。"* + +--- + +### `info_agent` — 信息查询助手 + +整合了多种公开信息查询能力,覆盖天气、热搜、域名、哔哩哔哩以及学术论文等信息源。 + +**子工具**:`weather_query`、`*hot`(热搜榜)、`whois`(域名查询)、`bilibili_search`、`bilibili_user_info`、`arxiv_search` + +**示例:** +> *"北京明天的天气怎么样?"* +> *"查一下今天的微博热搜前十名。"* +> *"帮我查询 arxiv 上关于 Chain-of-Thought 的最新论文。"* +> *"查一下 B 站 UP 主 xxx 的近期投稿情况。"* + +--- + +### `entertainment_agent` — 娱乐助手 + +提供运势、小说、随机图片和随机视频等休闲娱乐类功能。 + +**子工具**:`horoscope`(星座运势)、`novel_search`(小说检索)、`ai_draw_one`(AI 绘图)、`video_random_recommend`(随机视频推荐) + +**示例:** +> *"查一下天蝎座今天的运势。"* +> *"随机推荐几个有趣的视频。"* +> *"帮我画一张赛博朋克风格的城市夜景。"* + +--- + +### `code_delivery_agent` — 代码分析与交付助手 + +支持沙盒级别的代码代写、本地执行验证与自动打包。测试通过后,代码成果会自动打包为 `.zip` 文件并通过 QQ 发送给用户。 + +**示例:** +> *"请使用 Python 编写一个 HTTP 测速脚本,监听 8080 端口,验证跑通后将整个项目打包发到这个群。"* + +--- + +### `naga_code_analysis_agent` — NagaAgent 代码分析助手 + +专门用于深度分析 NagaAgent 框架及本项目的源代码结构。 + +**子工具**:`read_file`、`search_code`、`analyze_structure` + +--- + +## 4. 
工具集能力一览 (Toolsets & Tools) + +除了通过 Agent 按需调用外,以下工具在对话中均可以通过自然语言直接触发。 + +### 渲染 (`render.*`) + +| 工具 | 说明 | +|---|---| +| `render.render_markdown` | 将 Markdown 文本(含表格、代码块、标题等)渲染为图片发送 | +| `render.render_latex` | 将 LaTeX 数学公式渲染为图片(**依赖系统 TeX 环境**,需提前安装,详见[部署文档](deployment.md#3-安装系统级依赖必装)) | +| `render.render_html` | 将 HTML 内容渲染为图片 | + +支持 `embed`(嵌入回复)和 `send`(直接发送)两种图片交付方式。 + +**示例:** +> *"请把这段数学公式渲染成图片发给我:$E=mc^2$"* +> *"请把下面这份 Markdown 表格渲染成图片。"* --- -## Agent 能力展示 +### 表情包 (`memes.*`) -机器人通过自然语言理解用户意图,自动调度相应的专业 Agent,具有高度独立和自动化的能力: +| 工具 | 说明 | +|---|---| +| `memes.search_memes` | 支持 `keyword`(关键词精确匹配)、`semantic`(语义联想检索)、`hybrid`(混合模式)三种检索方式 | +| `memes.send_meme_by_uid` | 根据图片统一 uid 以独立消息发送原图表情包 | -* **网络搜索提取**:"搜索一下 DeepSeek 的最新动态" -* **多模态文件分析**:"总结一下群里长图说了什么","帮我提取这份 PDF 中的数据" -* **表情包检索与发送**:在轻松聊天场景中,AI 可使用 `memes.search_memes` 按关键词检索、按语义检索,或混合检索表情包,再按统一图片 `uid` 独立发送 -* **B站视频解析**:发送 B 站链接/BV 号自动下载发送 1080p 视频,或指令 AI "下载这个 B 站视频 BV1xx411c7mD" -* **代码分析与交付**:"用 Python 写一个 HTTP 服务器,监听 8080 端口,返回 Hello World,验证通过后打包发到这个群" (交由 Code Delivery Agent) -* **定时任务管理**:"每天早上 8 点提醒我看新闻" -* **向未来的自己发指令**:"明天早上 9 点提醒你自己先总结今天群里的待办,再把前三项发给我" +两者通常配合使用:先由 `search_memes` 检索到目标表情包的 uid,再由 `send_meme_by_uid` 独立发送原图。 -### 定时任务进阶:调用未来的自己 +**示例:** +> *"请根据现在的群聊气氛,发一个应景的表情包。"* -定时任务除了调用普通工具外,还支持 `self_instruction` 模式。你可以把一段自然语言指令留给未来触发时刻的 AI 自己执行。 +--- -示例意图: -- “每周一 09:00,先回顾上周群聊重点,再提醒本周计划” -- “今天晚上 23:30,帮我生成明天的复盘提纲” +### 消息操作 (`messages.*`) -实现上由 `scheduler.create_schedule_task` / `scheduler.update_schedule_task` 的 `self_instruction` 参数承载(与 `tool_name`/`tools` 三选一)。 +| 工具 | 说明 | +|---|---| +| `messages.send_message` | 向当前会话发送消息 | +| `messages.send_private_message` | 向指定用户发送私聊消息 | +| `messages.get_recent_messages` | 获取最近若干条历史消息 | +| `messages.get_messages_by_time` | 按时间范围检索历史消息 | +| `messages.react_message_emoji` | 对指定消息添加表情回应 | +| `messages.send_poke` | 发送戳一戳 | +| `messages.send_text_file` | 将文本内容生成文件后发送 | +| `messages.send_url_file` | 下载指定 URL 的文件后发送 | +| 
`messages.send_group_sign` | 执行群签到操作 | +| `messages.get_forward_msg` | 获取合并转发消息的内容 | --- -## 斜杠指令 +### 群组信息查询 (`group.*`) + +| 工具 | 说明 | +|---|---| +| `group.get_member_list` | 获取群成员列表 | +| `group.get_member_info` | 查询指定成员的详细信息 | +| `group.find_member` | 按昵称/备注搜索群成员 | +| `group.get_member_title` | 获取成员群头衔 | +| `group.get_honor_info` | 查询群荣誉(龙王、话唠等) | +| `group.get_member_activity` | 分析群成员活跃度(支持 member_list / history / hybrid 三种数据源模式) | +| `group.rank_members` | 对群成员进行多维度排名 | +| `group.filter_members` | 按条件过滤群成员 | +| `group.detect_inactive_risk` | 检测长期潜水有流失风险的成员 | +| `group.activity_trend` | 分析群活跃度趋势变化 | +| `group.level_distribution` | 统计群成员等级分布 | +| `group.get_files` | 获取群文件列表 | + +**示例:** +> *"帮我查一下这个群里近 30 天没说过话的成员有哪些。"* +> *"请列出本群最近发言最多的前 10 名成员。"* + +--- -> 💡 **进阶玩法**:想了解每个命令的具体使用参数,或者学习如何通过写几行代码**自定义属于你的独家斜杠指令**?请前往 [命令系统与斜杠指令配置指南](slash-commands.md)。 +### 群聊深度分析 (`group_analysis.*`) -在群聊或私聊中可使用以下指令。除明确说明外,管理类命令需要具备被设置的超级管理员或管理员权限: +| 工具 | 说明 | +|---|---| +| `group_analysis.analyze_member_messages` | 深度分析指定成员的消息数量、类型分布和活跃时段 | +| `group_analysis.analyze_join_statistics` | 统计群成员加入趋势与留存情况 | +| `group_analysis.analyze_new_member_activity` | 分析新成员加入后的活跃度变化 | + +--- + +### 认知记忆查询 (`cognitive.*`) + +| 工具 | 说明 | +|---|---| +| `cognitive.search_events` | 按关键词语义检索历史记忆事件,支持用户、群组、时间段过滤 | +| `cognitive.get_profile` | 获取指定用户的认知侧写画像 | +| `cognitive.search_profiles` | 跨用户语义搜索侧写信息 | + +--- + +### 置顶备忘录 (`memory.*`) + +用于管理 AI 的自我约束事项和高优先级待办。此备忘录会在每轮对话时被固定注入上下文(上限 500 条),优先级高于认知记忆。 + +| 工具 | 说明 | +|---|---| +| `memory.add` | 添加一条置顶备忘(如"用户要求以后用英文回复") | +| `memory.update` | 更新指定备忘内容 | +| `memory.delete` | 删除指定备忘 | +| `memory.list` | 列出当前所有置顶备忘 | +| `memory.query_archive` | 查询已归档的历史备忘 | +| `memory.search_summaries` | 语义搜索历史备忘 | + +> **注意**:用户偏好、身份等长期用户事实请通过对话让 AI 记入**认知记忆**(`cognitive.*`),而非此处。置顶备忘专用于 AI 自身的行为约束与短期高优待办。 + +--- + +### 知识库检索 (`knowledge_*`) + +如果管理员在 `config.toml` 中配置了知识库,AI 可通过以下工具检索其中的内容: + +| 工具 | 说明 | +|---|---| +| `knowledge_semantic_search` | 
基于向量语义检索(支持重排序与相关度过滤) | +| `knowledge_text_search` | 基于关键词的精确文本检索 | +| `knowledge_list` | 列出当前可用的知识库 | + +--- + +### 通讯录查询 (`contacts.*`) + +| 工具 | 说明 | +|---|---| +| `contacts.query_friends` | 查询 Bot 的好友列表 | +| `contacts.query_groups` | 查询 Bot 所在的群列表 | + +--- + +### 独立原子工具 + +| 工具 | 说明 | +|---|---| +| `get_current_time` | 获取当前系统时间,支持公历、农历、黄历等多种格式输出 | +| `get_picture` | 获取指定类型的图片(二次元、壁纸、白丝、黑丝、JK、历史上的今天等 10 余种类别) | +| `qq_like` | 给指定 QQ 号的资料卡点赞(默认 10 次) | +| `python_interpreter` | 在隔离的 **Docker 容器**中执行 Python 代码,支持按需安装第三方库,可在执行后自动发送生成的文件(图片、CSV 等) | +| `bilibili_video` | 下载并发送哔哩哔哩视频(支持 BV 号、链接) | +| `arxiv_paper` | 下载并发送 arXiv 论文 PDF(支持 arXiv ID、链接) | +| `fetch_image_uid` | 将指定 URL 的图片下载并转换为系统内部 uid | +| `task_progress` | 向用户发送长任务的阶段性进度通知 | +| `changelog_query` | 查询系统内置版本更新日志 | + +**示例:** +> *"请下载 arXiv 论文 2501.01234 并发到这个群。"* +> *"请在 Docker 里安装 matplotlib 后绘制一张正弦函数图像并发给我。"* +> *"帮我给 QQ 号 123456 点 10 个赞。"* + +--- + +## 5. 定时任务与调度 + +调度器基于标准 crontab 语法,支持三种执行模式,适用于从简单报时到复杂 AI 自主任务的全部场景。 + +### 执行模式 + +| 模式 | 描述 | 配置字段 | +|---|---|---| +| **单工具模式** | 定时调用一个指定的工具,传入固定参数 | `tool_name` + `tool_args` | +| **多工具串/并行模式** | 定时依次(serial)或同时(parallel)调用多个工具 | `tools` + `execution_mode` | +| **AI 自我督办模式** | 在触发时刻,以一段自然语言指令唤醒 AI 自主完成任务 | `self_instruction` | + +### 自我督办模式示例 + +这是调度器最灵活的功能:您可以通过自然语言预约将任意复杂的指令投递给"未来的 AI 自己"来执行。 + +> *"每天上午 9:00,请回顾昨日遗留的待办事项,并把最重要的前三项通过私聊发给我。"* +> *"每周一 08:30,请总结上周群内的高频讨论话题,生成一份周报并发送至群聊。"* +> *"明天晚上 23:00,帮我生成今天的话痨统计图表发到本群。"*(仅执行一次:设置 `max_executions: 1`) + +### 任务管理工具 + +| 工具 | 说明 | +|---|---| +| `scheduler.create_schedule_task` | 创建定时任务,支持 `max_executions`(达到次数后自动删除) | +| `scheduler.update_schedule_task` | 修改任务的触发规则、执行内容或参数 | +| `scheduler.delete_schedule_task` | 删除指定定时任务 | +| `scheduler.list_schedule_tasks` | 列出当前所有定时任务及其运行状态 | + +--- + +## 6. 
FAQ 知识库管理 + +Bot 支持在运行时维护一个结构化的群专属 FAQ 知识库,可通过斜杠指令进行增删查操作。 + +| 指令 | 权限 | 说明 | +|---|---|---| +| `/lsfaq` | 公开 | 列出当前群的全部 FAQ 条目 | +| `/viewfaq ` | 公开 | 查看指定 FAQ 的详细内容 | +| `/searchfaq <关键词>` | 公开 | 按关键词搜索匹配的 FAQ | +| `/delfaq ` | 管理员 | 删除指定 ID 的 FAQ 条目 | + +--- + +## 7. 内置斜杠指令参考 + +所有斜杠指令均以 `/` 开头,在群聊或私聊中直接输入即可触发。下表基于代码实际配置整理: + +| 指令 | 别名 | 权限 | 私聊 | 说明 | +|---|---|---|---|---| +| `/help [命令名]` | — | 公开 | ✅ | 显示命令列表;附带命令名时展示该命令的详细帮助文档 | +| `/version` | `/v` | 公开 | ✅ | 查看当前版本号及最新版本变更标题 | +| `/changelog [子命令]` | `/cl` | 公开 | ✅ | 查看版本更新日志(详见下方说明) | +| `/copyright` | `/about` `/license` `/cprt` | 公开 | ✅ | 查看版权信息与 MIT 许可证声明 | +| `/stats [天数] [--ai]` | — | 公开 | ✅ | 查看 Token 使用统计图表;附加 `--ai` 启用 AI 智能分析报告 | +| `/lsfaq` | — | 公开 | ❌ | 列出当前群的全部 FAQ | +| `/viewfaq ` | — | 公开 | ❌ | 查看指定 FAQ 详情 | +| `/searchfaq <关键词>` | — | 公开 | ❌ | 按关键词搜索 FAQ | +| `/delfaq ` | — | 管理员 | ❌ | 删除指定 FAQ | +| `/bugfix [起止时间]` | — | 管理员 | ❌ | 基于目标用户近期发言生成娱乐性 Bug 修复报告 | +| `/lsadmin` | — | 管理员 | ✅ | 查看系统当前的超管与管理员列表 | +| `/naga ` | — | 公开 | ✅ | 绑定或解绑关联的 NagaAgent 实例 | +| `/addadmin ` | — | **超级管理员** | ✅ | 将指定用户提权为普通管理员 | +| `/rmadmin ` | — | **超级管理员** | ✅ | 撤销指定用户的管理员权限 | + +### `/changelog` 子命令详解 -```bash -/help # 查看帮助菜单 -/changelog # 查看最近版本历史(公开命令) -/changelog show v3.2.6 # 查看指定版本详情(公开命令) -/lsadmin # 查看当前所有的系统管理员列表 -/addadmin # 添加新的普通管理员(仅限超级管理员使用) -/rmadmin # 移除某位普通管理员 -/bugfix # 根据最近用户在群里的聊天上下文生成该用户的 Bug 修复报告 (幽默搞笑用) -/stats [时间范围] [--ai] # 核心统计功能:获取 Token 使用统计 + 成本计算;加 --ai 才启用智能分析 ``` +/changelog # 列出最近 8 个版本(版本号 + 标题) +/changelog list <数量> # 列出更多版本,最大 20 条 +/changelog latest # 展示最新一个版本的完整变更详情 +/changelog show <版本号> # 展示指定版本的完整详情(带或不带 v 均可) +/changelog <版本号> # 等同于 show +``` + +### `/stats` 说明 -### 关于 `/changelog` 的详细说明: +- 默认统计最近 **7 天**的数据,可传入天数参数(允许范围:1 ~ 365 天)。 +- 默认仅生成统计图表与数字摘要,**不触发** AI 智能分析。 +- 附加 `--ai`(或 `-a`)时,向 AI 发起分析请求;若分析超时,系统会先返回图表与摘要并附带超时提示。 +- 普通用户频率限制为每 3600 秒一次;管理员与超级管理员无限制。 -- `/changelog` 默认列最近 8 个版本,按新到旧展示 `版本号 + 标题`。 -- `/changelog list 12` 可查看更多版本,最大 20 条。 -- 
`/changelog show <版本号>` 会展示单个版本的标题、摘要和变更点,版本号支持带或不带 `v`。 -- `/changelog latest` 会直接展示 `CHANGELOG.md` 中最新一条版本记录。 -- 版本内容直接来自仓库内维护的 `CHANGELOG.md`,不是运行时临时扫描 git tag。 +### 扩展自定义指令 -### 关于 `/stats` 的详细说明: +系统支持热插拔机制,创建对应目录结构并保存文件即刻生效,无需重启服务。详细的开发步骤与参数说明请参阅 [《命令系统与斜杠指令》](slash-commands.md)。 + +--- + +## 8. 多模型池(私聊模型切换) + +在 `config.toml` 中全局开启 `[features] pool_enabled = true` 后,Bot 支持在多个配置的大模型之间进行灵活调度: + +- **自动轮换**:配置 `strategy = "round_robin"` 或 `"random"` 后,私聊请求会自动按策略在池中模型之间切换。 +- **手动指定**:在私聊中,可通过发送"选 1"、"选 2"等指令来手动锁定本次使用的模型。 + +> 群聊场景始终使用主模型,不参与多模型池调度。 + +完整配置方式及 Agent 模型池说明请参阅 [《多模型池功能》](multi-model.md)。 + +--- + +## 9. WebUI 与跨平台管理 + +Undefined 提供了一套完整的可视化管理控制台,无需修改配置文件或重启服务即可对系统进行动态管理: + +- 实时切换底层驱动的大模型(如 GPT-4o、Claude 3.5 Sonnet 等)。 +- 在线编辑系统 Prompt 与人格设定面板。 +- 监控并干预运行时任务队列与内存状态。 +- 查看完整的 Token 消耗统计与调用日志。 + +WebUI 通过浏览器访问(默认地址 `http://127.0.0.1:8787`,默认密码 `changeme`,**首次启动必须在 `config.toml` 的 `[webui]` 中修改默认密码**)。如需通过手机或其他设备进行远程管理,可使用配套的多端控制台 App,详见 [《跨平台控制台 App》](app.md)。 + +--- -- 默认统计最近 7 天的数据,时间参数范围会自动被系统钳制在 1 天 - 365 天之间。 -- 默认只发送图表与基本摘要,不会触发 AI 智能分析。 -- 仅在显式传入 `--ai`(或 `-a`)时才会请求 AI 分析;若分析超时,系统会先发图表与摘要并附超时提示。 +*如需查阅各模块的底层设计原理与 API 集成说明,请参阅本目录下的其余技术文档。* diff --git a/src/Undefined/skills/toolsets/README.md b/src/Undefined/skills/toolsets/README.md index 750dc64..9220998 100644 --- a/src/Undefined/skills/toolsets/README.md +++ b/src/Undefined/skills/toolsets/README.md @@ -136,7 +136,7 @@ async def execute(args: dict[str, Any], context: dict[str, Any]) -> str: ### Render(渲染) - `render.render_html`: 将 HTML 渲染为图片 -- `render.render_latex`: 将 LaTeX 渲染为图片 +- `render.render_latex`: 将 LaTeX 渲染为图片(依赖系统 TeX 环境,需安装 TeX Live / MiKTeX) - `render.render_markdown`: 将 Markdown 渲染为图片 ### Memes(表情包) diff --git a/src/Undefined/skills/toolsets/render/render_latex/config.json b/src/Undefined/skills/toolsets/render/render_latex/config.json index 1d3c48e..a351083 100644 --- a/src/Undefined/skills/toolsets/render/render_latex/config.json +++ 
b/src/Undefined/skills/toolsets/render/render_latex/config.json @@ -2,13 +2,13 @@ "type": "function", "function": { "name": "render_latex", - "description": "将 LaTeX 文本渲染为图片。默认返回可嵌入回复的图片 UID(embed),也可直接发送到指定目标(send)。支持完整的 LaTeX 语法(包含 \\begin 和 \\end)。", + "description": "将 LaTeX 公式或文档渲染为图片,使用系统外部 LaTeX(需预先安装 TeX Live / MiKTeX)。默认返回可嵌入回复的图片 UID(embed),也可直接发送到指定目标(send)。", "parameters": { "type": "object", "properties": { "content": { "type": "string", - "description": "要渲染的 LaTeX 内容。必须是完整格式(包含 \\begin 和 \\end)。" + "description": "要渲染的 LaTeX 内容。支持 $...$、$$...$$、\\[...\\]、\\(...\\) 及完整环境(\\begin{align}...\\end{align} 等);\\begin{document}...\\end{document} 外层包装会自动去掉。" }, "delivery": { "type": "string", diff --git a/src/Undefined/skills/toolsets/render/render_latex/handler.py b/src/Undefined/skills/toolsets/render/render_latex/handler.py index 26774d6..8ac6e06 100644 --- a/src/Undefined/skills/toolsets/render/render_latex/handler.py +++ b/src/Undefined/skills/toolsets/render/render_latex/handler.py @@ -1,15 +1,54 @@ from __future__ import annotations +from pathlib import Path +import re from typing import Any, Dict import logging import uuid -import matplotlib.pyplot as plt + import matplotlib +matplotlib.use("Agg") +import matplotlib.pyplot as plt + from Undefined.attachments import scope_from_context logger = logging.getLogger(__name__) +_DOCUMENT_PATTERN = re.compile( + r"^\s*\\begin\{document\}(?P.*?)\\end\{document\}\s*$", + re.DOTALL, +) + + +def _strip_document_wrappers(content: str) -> str: + """去掉 \\begin{document}...\\end{document} 外层包装;matplotlib 会自行构造文档。""" + text = content.strip() + match = _DOCUMENT_PATTERN.fullmatch(text) + if match is None: + return text + return match.group("body").strip() + + +def _render_latex_image(filepath: Path, content: str) -> None: + text = _strip_document_wrappers(content) + fig = plt.figure(figsize=(6, 2.5)) + try: + fig.patch.set_facecolor("white") + fig.text( + 0.5, + 0.5, + text, + fontsize=20, + verticalalignment="center", 
+ horizontalalignment="center", + usetex=True, + wrap=True, + ) + fig.savefig(filepath, dpi=200, bbox_inches="tight", pad_inches=0.25) + finally: + plt.close(fig) + def _resolve_send_target( target_id: Any, @@ -33,7 +72,7 @@ def _resolve_send_target( async def execute(args: Dict[str, Any], context: Dict[str, Any]) -> str: """渲染 LaTeX 数学公式为图片""" - content = args.get("content", "") + content = str(args.get("content", "") or "") delivery = str(args.get("delivery", "embed") or "embed").strip().lower() target_id = args.get("target_id") message_type = args.get("message_type") @@ -53,26 +92,7 @@ async def execute(args: Dict[str, Any], context: Dict[str, Any]) -> str: filename = f"render_{uuid.uuid4().hex[:16]}.png" filepath = ensure_dir(RENDER_CACHE_DIR) / filename - matplotlib.use("Agg") - - fig, ax = plt.subplots(figsize=(10, 6)) - ax.axis("off") - - ax.text( - 0.5, - 0.5, - content, - transform=ax.transAxes, - fontsize=12, - verticalalignment="center", - horizontalalignment="center", - usetex=True, - wrap=True, - ) - - plt.tight_layout() - plt.savefig(filepath, dpi=150, bbox_inches="tight", pad_inches=0.1) - plt.close(fig) + _render_latex_image(filepath, content) # 注册到附件系统 attachment_registry = context.get("attachment_registry") @@ -133,6 +153,13 @@ async def execute(args: Dict[str, Any], context: Dict[str, Any]) -> str: except ImportError as e: missing_pkg = str(e).split("'")[1] if "'" in str(e) else "未知包" return f"渲染失败:缺少依赖包 {missing_pkg},请运行: uv add {missing_pkg}" + except RuntimeError as e: + err = str(e).lower() + if "latex" in err or "dvipng" in err or "dvi" in err: + logger.error("LaTeX 渲染失败(系统 TeX 环境不可用): %s", e) + return "渲染失败:系统 LaTeX 环境未安装或不完整,请按部署文档安装 TeX Live / MiKTeX 后重试。" + logger.exception("渲染并发送 LaTeX 图片失败: %s", e) + return "渲染失败,请稍后重试" except Exception as e: logger.exception(f"渲染并发送 LaTeX 图片失败: {e}") return "渲染失败,请稍后重试" diff --git a/tests/test_render_latex_tool.py b/tests/test_render_latex_tool.py new file mode 100644 index 0000000..3231d6d --- 
/dev/null +++ b/tests/test_render_latex_tool.py @@ -0,0 +1,98 @@ +from __future__ import annotations + +from pathlib import Path +from typing import Any + +import pytest + +from Undefined.attachments import AttachmentRegistry +from Undefined.skills.toolsets.render.render_latex import handler +from Undefined.utils import paths + +_PNG_HEADER = b"\x89PNG\r\n\x1a\n" + + +def _build_context(registry: AttachmentRegistry) -> dict[str, Any]: + return { + "request_type": "group", + "group_id": 10001, + "sender_id": 20002, + "user_id": 20002, + "attachment_registry": registry, + } + + +def test_strip_document_wrappers_removes_document_env() -> None: + content = ( + "\\begin{document}\n" + "\\[\n" + "\\int_{-\\infty}^{+\\infty} e^{-x^2} dx = \\sqrt{\\pi}\n" + "\\]\n" + "\\end{document}" + ) + assert handler._strip_document_wrappers(content) == ( + "\\[\n\\int_{-\\infty}^{+\\infty} e^{-x^2} dx = \\sqrt{\\pi}\n\\]" + ) + + +def test_strip_document_wrappers_passthrough_for_plain_formula() -> None: + content = r"\[ E = mc^2 \]" + assert handler._strip_document_wrappers(content) == content + + +@pytest.mark.asyncio +async def test_render_latex_embed_success( + tmp_path: Path, + monkeypatch: pytest.MonkeyPatch, +) -> None: + registry = AttachmentRegistry( + registry_path=tmp_path / "attachment_registry.json", + cache_dir=tmp_path / "attachments", + ) + content = r"\[ \int_{-\infty}^{+\infty} e^{-x^2} dx = \sqrt{\pi} \]" + rendered_contents: list[str] = [] + + monkeypatch.setattr(paths, "RENDER_CACHE_DIR", tmp_path / "render") + + def _fake_render(filepath: Path, render_content: str) -> None: + rendered_contents.append(render_content) + filepath.parent.mkdir(parents=True, exist_ok=True) + filepath.write_bytes(_PNG_HEADER) + + monkeypatch.setattr(handler, "_render_latex_image", _fake_render) + + result = await handler.execute( + {"content": content, "delivery": "embed"}, + _build_context(registry), + ) + + assert result.startswith(' None: + registry = AttachmentRegistry( + 
registry_path=tmp_path / "attachment_registry.json", + cache_dir=tmp_path / "attachments", + ) + content = r"\[ a = b \]" + + monkeypatch.setattr(paths, "RENDER_CACHE_DIR", tmp_path / "render") + + def _raise_runtime(_: Path, __: str) -> None: + raise RuntimeError("latex was not able to process the following string") + + monkeypatch.setattr(handler, "_render_latex_image", _raise_runtime) + + result = await handler.execute( + {"content": content, "delivery": "embed"}, + _build_context(registry), + ) + + assert "TeX Live" in result or "MiKTeX" in result From a42da3d0b1d06cfa4911f55e13000fd084a892b7 Mon Sep 17 00:00:00 2001 From: Null <1708213363@qq.com> Date: Sat, 18 Apr 2026 14:22:27 +0800 Subject: [PATCH 05/57] feat: attachment hash dedup, unified tags, LaTeX MathJax refactor, meme auto-match - Attachment hash dedup: same scope+kind+SHA256 returns existing record - Unified tag: routes image/file by UID prefix, backward-compat - Centralized dispatch_pending_file_sends() for non-image file delivery - LaTeX rendering: migrate from matplotlib to MathJax + Playwright (no system TeX) - LaTeX: support PNG and PDF output via output_format parameter - Meme auto-match: annotate incoming images with meme descriptions by SHA256 - Update both prompt XML files with unified attachment tag documentation - 37 new tests (713 total, all passing) Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> Co-authored-by: Claude Opus 4.6 --- docs/slash-commands.md | 56 ++- res/prompts/undefined.xml | 14 +- res/prompts/undefined_nagaagent.xml | 14 +- src/Undefined/ai/prompts.py | 4 +- src/Undefined/arxiv/sender.py | 48 ++ src/Undefined/attachments.py | 226 +++++++-- src/Undefined/handlers.py | 87 +++- src/Undefined/services/ai_coordinator.py | 13 +- src/Undefined/services/command.py | 5 + src/Undefined/services/commands/context.py | 2 + .../skills/agents/summary_agent/__init__.py | 1 + .../skills/agents/summary_agent/config.json | 17 + .../skills/agents/summary_agent/handler.py 
| 25 + .../skills/agents/summary_agent/intro.md | 30 ++ .../skills/agents/summary_agent/prompt.md | 79 +++ .../agents/summary_agent/tools/__init__.py | 1 + .../tools/fetch_messages/config.json | 20 + .../tools/fetch_messages/handler.py | 108 ++++ .../skills/commands/profile/config.json | 16 + .../skills/commands/profile/handler.py | 54 ++ .../skills/commands/summary/config.json | 16 + .../skills/commands/summary/handler.py | 103 ++++ .../messages/get_messages_by_time/handler.py | 14 +- .../messages/get_recent_messages/handler.py | 14 +- .../toolsets/messages/send_message/handler.py | 7 + .../messages/send_private_message/handler.py | 7 + .../toolsets/render/render_latex/config.json | 19 +- .../toolsets/render/render_latex/handler.py | 284 ++++++----- src/Undefined/utils/history.py | 2 + tests/test_arxiv_sender.py | 185 +++++++ tests/test_attachment_tags.py | 336 +++++++++++++ tests/test_attachments_dedup.py | 147 ++++++ tests/test_coordinator_level.py | 155 ++++++ tests/test_fetch_messages_tool.py | 466 ++++++++++++++++++ tests/test_handlers_meme_annotation.py | 277 +++++++++++ tests/test_history_level.py | 174 +++++++ tests/test_message_tools_level.py | 161 ++++++ tests/test_profile_command.py | 252 ++++++++++ tests/test_prompts_level.py | 285 +++++++++++ tests/test_render_latex_tool.py | 233 ++++++--- tests/test_summary_agent.py | 148 ++++++ tests/test_summary_command.py | 341 +++++++++++++ 42 files changed, 4178 insertions(+), 268 deletions(-) create mode 100644 src/Undefined/skills/agents/summary_agent/__init__.py create mode 100644 src/Undefined/skills/agents/summary_agent/config.json create mode 100644 src/Undefined/skills/agents/summary_agent/handler.py create mode 100644 src/Undefined/skills/agents/summary_agent/intro.md create mode 100644 src/Undefined/skills/agents/summary_agent/prompt.md create mode 100644 src/Undefined/skills/agents/summary_agent/tools/__init__.py create mode 100644 
src/Undefined/skills/agents/summary_agent/tools/fetch_messages/config.json create mode 100644 src/Undefined/skills/agents/summary_agent/tools/fetch_messages/handler.py create mode 100644 src/Undefined/skills/commands/profile/config.json create mode 100644 src/Undefined/skills/commands/profile/handler.py create mode 100644 src/Undefined/skills/commands/summary/config.json create mode 100644 src/Undefined/skills/commands/summary/handler.py create mode 100644 tests/test_attachment_tags.py create mode 100644 tests/test_attachments_dedup.py create mode 100644 tests/test_coordinator_level.py create mode 100644 tests/test_fetch_messages_tool.py create mode 100644 tests/test_handlers_meme_annotation.py create mode 100644 tests/test_history_level.py create mode 100644 tests/test_message_tools_level.py create mode 100644 tests/test_profile_command.py create mode 100644 tests/test_prompts_level.py create mode 100644 tests/test_summary_agent.py create mode 100644 tests/test_summary_command.py diff --git a/docs/slash-commands.md b/docs/slash-commands.md index 9be0f8e..6acf72b 100644 --- a/docs/slash-commands.md +++ b/docs/slash-commands.md @@ -84,7 +84,51 @@ Undefined 提供了一套强大的斜杠指令(Slash Commands)系统。管 /changelog latest ``` -#### 2. 统计与分析服务 +#### 2. 
消息总结与侧写查看 + +- **/profile [group]** + - **说明**:查看用户或群聊的认知侧写。侧写由系统根据聊天历史自动生成和更新。 + - **别名**:`/me`、`/p` + - **参数**: + + | 参数 | 是否必填 | 说明 | + |------|----------|------| + | `group` | 可选 | 传入 `group` 时查看当前群聊的侧写(仅群聊可用) | + + - **行为**: + - **私聊**:只能查看自己的用户侧写,不支持 `group` 参数。 + - **群聊**:不带参数查看自己的用户侧写,带 `group` 查看当前群聊侧写。 + - 过长侧写会自动截断(3000 字符上限)。 + - **限流**:普通用户 60 秒,管理员 10 秒,超管无限制。 + - **示例**: + ``` + /profile → 查看自己的侧写 + /me → 同上(别名) + /profile group → 查看当前群聊的侧写 + ``` + +- **/summary [条数|时间范围] [自定义描述]** + - **说明**:调用消息总结 Agent,拉取指定范围的聊天消息并进行智能总结。 + - **别名**:`/sum` + - **参数**: + + | 参数 | 是否必填 | 说明 | + |------|----------|------| + | `条数` | 可选 | 纯数字,表示总结最近 N 条消息(默认 50,最大 500) | + | `时间范围` | 可选 | 格式如 `1h`、`6h`、`1d`、`7d`,与条数互斥 | + | `自定义描述` | 可选 | 总结的重点方向,如"技术讨论"、"项目进展" | + + - **限流**:普通用户 120 秒,管理员 30 秒,超管无限制。 + - **示例**: + ``` + /summary → 总结最近 50 条消息 + /summary 100 → 总结最近 100 条消息 + /summary 1d → 总结过去 1 天的消息 + /summary 50 技术讨论 → 总结最近 50 条,重点关注技术讨论 + /sum 1d 项目进展 → 总结过去 1 天,重点关注项目进展 + ``` + +#### 3. 统计与分析服务 - **/stats [时间范围] [--ai]** - **说明**:生成过去一段时间内 Token 的使用统计数据、模型消耗排行、输入输出比例,并输出可视化图表。默认不启用 AI 分析,显式传 `--ai`(或 `-a`)才会触发。 - **参数**: @@ -116,7 +160,7 @@ Undefined 提供了一套强大的斜杠指令(Slash Commands)系统。管 /stats 30d --ai → 最近 30 天并启用 AI 分析 ``` -#### 3. 权限管理 (动态 Admin) +#### 4. 权限管理 (动态 Admin) 通过指令动态管理管理员列表,变更会自动持久化到 `config.local.json`,无需重启。超管(Superadmin)拥有最高权限,由配置文件的 `core.super_admins` 静态定义。 - **/lsadmin** @@ -151,7 +195,7 @@ Undefined 提供了一套强大的斜杠指令(Slash Commands)系统。管 - 若目标本身不是管理员,返回"不是管理员"提示。 - **示例**:`/rmadmin 123456789` -#### 4. 本地群级 FAQ 系统 +#### 5. 本地群级 FAQ 系统 用于对常见问题(FAQ)进行检索和管理。FAQ 不必每次请求 AI 大模型,极大地节省 Token 并加快响应。 - **/lsfaq** @@ -192,7 +236,7 @@ Undefined 提供了一套强大的斜杠指令(Slash Commands)系统。管 - **边界行为**:若 ID 不存在,返回"FAQ 不存在"提示。 - **示例**:`/delfaq 20241205-001` -#### 5. 排障与反馈 +#### 6. 排障与反馈 - **/bugfix \ [QQ号2...] 
\<开始时间\> \<结束时间\>** - **说明**:从群历史记录中抓取指定用户在指定时间段内的消息(包含文字、图片的 OCR 描述),交给 AI 进行分析并生成 Bug 修复报告,结果自动存入 FAQ 库。 - **参数**: @@ -214,7 +258,7 @@ Undefined 提供了一套强大的斜杠指令(Slash Commands)系统。管 /bugfix 111111 222222 2024/12/01/09:00 2024/12/01/18:00 ``` -#### 6. Naga 集成管理 +#### 7. Naga 集成管理 > **⚠️ 此功能面向与 NagaAgent 对接的高级场景,普通用户不建议开启。** 需要在 `config.toml` 中同时启用 `[api].enabled`、`[features].nagaagent_mode_enabled` 和 `[naga].enabled`。 @@ -347,6 +391,8 @@ async def execute(args: list[str], context: CommandContext) -> None: | `ctx.bot_qq` | `int` | 当前机器人的自身 QQ 号 | | `ctx.ai` | `AIClient` | 主 AI Client,可以用于进行分析、总结等大模型调用 | | `ctx.faq_storage` | `FAQStorage` | FAQ 的键值操作入口 | +| `ctx.cognitive_service` | `Any \| None` | 认知侧写服务,可调用 `get_profile(entity_type, entity_id)` | +| `ctx.history_manager` | `Any \| None` | 消息历史管理器,可调用 `get_recent(chat_id, msg_type, start, end)` | ### 3. 可用的 `permission` (权限级别) diff --git a/res/prompts/undefined.xml b/res/prompts/undefined.xml index 2e9b52e..b8e503a 100644 --- a/res/prompts/undefined.xml +++ b/res/prompts/undefined.xml @@ -141,12 +141,16 @@ - 除非 `memes.search_memes` 没找到合适结果,或表情包会干扰信息传递,否则不要把本来适合发图的反应先写成一句话来代替发图 - 表情包相关规则只决定“怎么回复”,不单独构成“该不该回复”的参与许可;是否回复仍以前面的回复触发逻辑为准 - 默认不要把表情包和正文写进同一条消息;需要补一句解释时,优先分成两条消息发送 - - 如果上下文或工具结果给了图片 UID(例如 `pic_ab12cd34`),你可以在 `send_message.message` 里直接插入 `` - - `` 是唯一允许的内嵌图片语法;不要改成 Markdown 图片、HTML ``、代码块或自然语言描述 - - 可以图文混排,例如:`我给你介绍一下`\n``\n`如图所示` - - 表情包库返回的图片 UID 也可以直接用于 ``;当前会话临时图片和表情包库图片共用同一套 `uid` 语义 + - 推荐使用统一标签 `` 引用任何附件(图片或文件),系统根据 UID 前缀自动处理: + - `pic_*` UID → 内嵌为图片(等效于旧 `` 语法) + - `file_*` UID → 作为独立文件消息在文字之后发出 + - `` 语法仍然可用且仅限图片 UID(向后兼容) + - `` 是推荐的统一语法,适用于所有类型的附件 + - 可以图文混排,例如:`我给你介绍一下`\n``\n`如图所示` + - 文件附件在文字消息发出后作为独立文件消息依次发送,不会混排在文字中 + - 表情包库返回的图片 UID 也可以直接用于 `` - 只能引用工具结果或上下文里明确给出的图片 UID,禁止臆造 UID - - 只有 `pic_*` 这类图片 UID 能放进 ``;普通文件 UID 不能放进去 + - 不要把 `file_*` UID 放进 `` 标签(会报类型错误) diff --git a/res/prompts/undefined_nagaagent.xml b/res/prompts/undefined_nagaagent.xml index 4bbebbf..acd3cc3 100644 --- 
a/res/prompts/undefined_nagaagent.xml +++ b/res/prompts/undefined_nagaagent.xml @@ -141,12 +141,16 @@ - 除非 `memes.search_memes` 没找到合适结果,或表情包会干扰信息传递,否则不要把本来适合发图的反应先写成一句话来代替发图 - 表情包相关规则只决定“怎么回复”,不单独构成“该不该回复”的参与许可;是否回复仍以前面的回复触发逻辑为准 - 默认不要把表情包和正文写进同一条消息;需要补一句解释时,优先分成两条消息发送 - - 如果上下文或工具结果给了图片 UID(例如 `pic_ab12cd34`),你可以在 `send_message.message` 里直接插入 `` - - `` 是唯一允许的内嵌图片语法;不要改成 Markdown 图片、HTML ``、代码块或自然语言描述 - - 可以图文混排,例如:`我给你介绍一下`\n``\n`如图所示` - - 表情包库返回的图片 UID 也可以直接用于 ``;当前会话临时图片和表情包库图片共用同一套 `uid` 语义 + - 推荐使用统一标签 `` 引用任何附件(图片或文件),系统根据 UID 前缀自动处理: + - `pic_*` UID → 内嵌为图片(等效于旧 `` 语法) + - `file_*` UID → 作为独立文件消息在文字之后发出 + - `` 语法仍然可用且仅限图片 UID(向后兼容) + - `` 是推荐的统一语法,适用于所有类型的附件 + - 可以图文混排,例如:`我给你介绍一下`\n``\n`如图所示` + - 文件附件在文字消息发出后作为独立文件消息依次发送,不会混排在文字中 + - 表情包库返回的图片 UID 也可以直接用于 `` - 只能引用工具结果或上下文里明确给出的图片 UID,禁止臆造 UID - - 只有 `pic_*` 这类图片 UID 能放进 ``;普通文件 UID 不能放进去 + - 不要把 `file_*` UID 放进 `` 标签(会报类型错误) diff --git a/src/Undefined/ai/prompts.py b/src/Undefined/ai/prompts.py index 3cd9b70..4fc9d81 100644 --- a/src/Undefined/ai/prompts.py +++ b/src/Undefined/ai/prompts.py @@ -746,6 +746,7 @@ async def _inject_recent_messages( attachments = msg.get("attachments", []) role = msg.get("role", "member") title = msg.get("title", "") + level = msg.get("level", "") message_id = msg.get("message_id") safe_sender = escape_xml_attr(sender_name) @@ -771,9 +772,10 @@ async def _inject_recent_messages( chat_name if chat_name.endswith("群") else f"{chat_name}群" ) safe_location = escape_xml_attr(location) + level_attr = f' level="{escape_xml_attr(level)}"' if level else "" xml_msg = ( f'\n{safe_text}{attachment_xml}\n' ) else: diff --git a/src/Undefined/arxiv/sender.py b/src/Undefined/arxiv/sender.py index 214cceb..1cecf06 100644 --- a/src/Undefined/arxiv/sender.py +++ b/src/Undefined/arxiv/sender.py @@ -4,6 +4,7 @@ import asyncio import logging +import time from typing import TYPE_CHECKING, Literal from Undefined.arxiv.client import get_paper_info @@ -19,6 +20,31 @@ _INFLIGHT_LOCK = asyncio.Lock() 
_INFLIGHT_SENDS: dict[tuple[str, int, str], asyncio.Future[str]] = {} +# Time-based dedup: maps (target_type, target_id, paper_id) → monotonic timestamp +_RECENT_SENDS: dict[tuple[str, int, str], float] = {} +_DEDUP_COOLDOWN_SECONDS: float = 3600.0 # 1 hour +_RECENT_SENDS_MAX_SIZE: int = 1000 + + +def _cleanup_expired_recent_sends() -> None: + """Remove expired entries from _RECENT_SENDS. Must be called under _INFLIGHT_LOCK.""" + now = time.monotonic() + expired = [ + k for k, v in _RECENT_SENDS.items() if now - v >= _DEDUP_COOLDOWN_SECONDS + ] + for k in expired: + del _RECENT_SENDS[k] + + +def _evict_oldest_recent_sends() -> None: + """Evict oldest entries if _RECENT_SENDS exceeds max size. Must be called under _INFLIGHT_LOCK.""" + if len(_RECENT_SENDS) <= _RECENT_SENDS_MAX_SIZE: + return + sorted_keys = sorted(_RECENT_SENDS, key=lambda k: _RECENT_SENDS[k]) + excess = len(_RECENT_SENDS) - _RECENT_SENDS_MAX_SIZE + for k in sorted_keys[:excess]: + del _RECENT_SENDS[k] + def _build_abs_url(paper_id: str) -> str: return f"https://arxiv.org/abs/{paper_id}" @@ -203,6 +229,24 @@ async def send_arxiv_paper( created = False async with _INFLIGHT_LOCK: + # Lazy cleanup of expired entries + _cleanup_expired_recent_sends() + + # Check time-based dedup first + recent_ts = _RECENT_SENDS.get(key) + if ( + recent_ts is not None + and (time.monotonic() - recent_ts) < _DEDUP_COOLDOWN_SECONDS + ): + logger.info( + "[arXiv] 论文近期已发送,跳过: paper=%s target=%s:%s", + normalized, + target_type, + target_id, + ) + return f"论文 {normalized} 近期已发送过,已跳过" + + # Check inflight dedup future = _INFLIGHT_SENDS.get(key) if future is None: future = asyncio.get_running_loop().create_future() @@ -242,3 +286,7 @@ async def send_arxiv_paper( current = _INFLIGHT_SENDS.get(key) if current is future: _INFLIGHT_SENDS.pop(key, None) + # Record successful send time for dedup cooldown + if future.done() and not future.cancelled() and future.exception() is None: + _RECENT_SENDS[key] = time.monotonic() + 
_evict_oldest_recent_sends() diff --git a/src/Undefined/attachments.py b/src/Undefined/attachments.py index b0c09af..e8fa8fc 100644 --- a/src/Undefined/attachments.py +++ b/src/Undefined/attachments.py @@ -33,6 +33,14 @@ r"[\"'])(?P[^\"']+)(?P=quote)\s*/?>", re.IGNORECASE, ) +_ATTACHMENT_TAG_PATTERN = re.compile( + r"[\"'])(?P[^\"']+)(?P=quote)\s*/?>", + re.IGNORECASE, +) +_UNIFIED_TAG_PATTERN = re.compile( + r"<(?Ppic|attachment)\s+uid=(?P[\"'])(?P[^\"']+)(?P=quote)\s*/?>", + re.IGNORECASE, +) _MEDIA_LABELS = { "image": "图片", "file": "文件", @@ -107,10 +115,11 @@ class RenderedRichMessage: delivery_text: str history_text: str attachments: list[dict[str, str]] + pending_file_sends: tuple[AttachmentRecord, ...] = () class AttachmentRenderError(RuntimeError): - """Raised when a `` tag cannot be rendered.""" + """Raised when an attachment tag cannot be rendered.""" def _now_iso() -> str: @@ -681,6 +690,25 @@ def _build_uid(self, prefix: str) -> str: if uid not in self._records: return uid + def _find_by_sha256( + self, scope_key: str, sha256: str, kind: str + ) -> AttachmentRecord | None: + """Find an existing record with matching scope, kind, and SHA-256. + + Only returns a record whose *local_path* still exists on disk. + Must be called while ``self._lock`` is held. 
+ """ + for record in self._records.values(): + if ( + record.scope_key == scope_key + and record.sha256 == sha256 + and record.kind == kind + and record.local_path + and Path(record.local_path).is_file() + ): + return record + return None + async def register_bytes( self, scope_key: str, @@ -703,15 +731,18 @@ async def register_bytes( prefix = "pic" if normalized_media_type == "image" else "file" async with self._lock: + digest = await asyncio.to_thread(hashlib.sha256, content) + digest_hex = digest.hexdigest() + + existing = self._find_by_sha256(scope_key, digest_hex, normalized_kind) + if existing is not None: + return existing + uid = self._build_uid(prefix) file_name = f"{uid}{suffix}" cache_path = ensure_dir(self._cache_dir) / file_name + await asyncio.to_thread(cache_path.write_bytes, content) - def _write() -> str: - cache_path.write_bytes(content) - return hashlib.sha256(content).hexdigest() - - digest = await asyncio.to_thread(_write) record = AttachmentRecord( uid=uid, scope_key=scope_key, @@ -722,7 +753,7 @@ def _write() -> str: source_ref=source_ref, local_path=str(cache_path), mime_type=normalized_mime, - sha256=digest, + sha256=digest_hex, created_at=_now_iso(), segment_data={ str(k): str(v) @@ -1064,31 +1095,38 @@ async def _collect_from_segments( ) -async def render_message_with_pic_placeholders( +async def render_message_with_attachments( message: str, *, registry: AttachmentRegistry | None, scope_key: str | None, strict: bool, ) -> RenderedRichMessage: - if ( - not message - or registry is None - or not scope_key - or "`` and ```` tags into delivery/history text. + + * ```` — backward-compatible, image-only. + * ```` — unified tag for any media type. + Images (``pic_*``) are inlined as CQ images; files (``file_*``) + are collected into *pending_file_sends* for later dispatch. 
+ """ + has_tags = message and ( + " tag: strictly image-only + if tag_name == "pic" and record.media_type != "image": replacement = f"[图片 uid={uid} 类型错误]" if strict: raise AttachmentRenderError(f"UID 不是图片,不能用于 :{uid}") @@ -1110,31 +1151,19 @@ async def render_message_with_pic_placeholders( history_parts.append(replacement) continue - image_source = record.source_ref - if record.local_path: - image_source = Path(record.local_path).resolve().as_uri() - elif not image_source: - replacement = f"[图片 uid={uid} 缺少文件]" - if strict: - raise AttachmentRenderError(f"图片 UID 缺少可发送的文件:{uid}") - delivery_parts.append(replacement) - history_parts.append(replacement) - continue - - cq_args = [f"file={image_source}"] - for key, value in dict(getattr(record, "segment_data", {}) or {}).items(): - cleaned_key = str(key or "").strip() - cleaned_value = str(value or "").strip() - if not cleaned_key or not cleaned_value or cleaned_key == "file": - continue - cq_args.append( - f"{_escape_cq_component(cleaned_key)}={_escape_cq_component(cleaned_value)}" - ) - delivery_parts.append(f"[CQ:image,{','.join(cq_args)}]") - if record.display_name: - history_parts.append(f"[图片 uid={uid} name={record.display_name}]") + # Route by media type + if record.media_type == "image": + _render_image_tag(record, uid, strict, delivery_parts, history_parts) else: - history_parts.append(f"[图片 uid={uid}]") + _render_file_tag( + record, + uid, + strict, + delivery_parts, + history_parts, + pending_files, + ) + attachments.append(record.prompt_ref()) delivery_parts.append(message[last_index:]) @@ -1143,4 +1172,113 @@ async def render_message_with_pic_placeholders( delivery_text="".join(delivery_parts), history_text="".join(history_parts), attachments=attachments, + pending_file_sends=tuple(pending_files), ) + + +def _render_image_tag( + record: AttachmentRecord, + uid: str, + strict: bool, + delivery_parts: list[str], + history_parts: list[str], +) -> None: + """Render an image attachment as an inline CQ:image.""" 
+ image_source = record.source_ref + if record.local_path: + image_source = Path(record.local_path).resolve().as_uri() + elif not image_source: + replacement = f"[图片 uid={uid} 缺少文件]" + if strict: + raise AttachmentRenderError(f"图片 UID 缺少可发送的文件:{uid}") + delivery_parts.append(replacement) + history_parts.append(replacement) + return + + cq_args = [f"file={image_source}"] + for key, value in dict(getattr(record, "segment_data", {}) or {}).items(): + cleaned_key = str(key or "").strip() + cleaned_value = str(value or "").strip() + if not cleaned_key or not cleaned_value or cleaned_key == "file": + continue + cq_args.append( + f"{_escape_cq_component(cleaned_key)}={_escape_cq_component(cleaned_value)}" + ) + delivery_parts.append(f"[CQ:image,{','.join(cq_args)}]") + if record.display_name: + history_parts.append(f"[图片 uid={uid} name={record.display_name}]") + else: + history_parts.append(f"[图片 uid={uid}]") + + +def _render_file_tag( + record: AttachmentRecord, + uid: str, + strict: bool, + delivery_parts: list[str], + history_parts: list[str], + pending_files: list[AttachmentRecord], +) -> None: + """Render a non-image attachment as a pending file send.""" + if not record.local_path or not Path(record.local_path).is_file(): + replacement = f"[文件 uid={uid} 缺少本地文件]" + if strict: + raise AttachmentRenderError(f"文件 UID 缺少本地文件,无法发送:{uid}") + delivery_parts.append(replacement) + history_parts.append(replacement) + return + + # Remove from delivery text (file sent separately) + # Keep a readable placeholder in history + name_part = f" name={record.display_name}" if record.display_name else "" + history_parts.append(f"[文件 uid={uid}{name_part}]") + pending_files.append(record) + + +# Backward-compatible alias +render_message_with_pic_placeholders = render_message_with_attachments + + +async def dispatch_pending_file_sends( + rendered: RenderedRichMessage, + *, + sender: Any, + target_type: str, + target_id: int, +) -> None: + """Send pending file attachments collected by 
*render_message_with_attachments*. + + This is best-effort: each file send failure is logged but does not interrupt + the remaining sends or the caller. + """ + if not rendered.pending_file_sends or sender is None: + return + for record in rendered.pending_file_sends: + if not record.local_path or not Path(record.local_path).is_file(): + logger.warning( + "[文件发送] 跳过:本地文件缺失 uid=%s path=%s", + record.uid, + record.local_path, + ) + continue + try: + if target_type == "group": + await sender.send_group_file( + target_id, + record.local_path, + name=record.display_name or None, + ) + else: + await sender.send_private_file( + target_id, + record.local_path, + name=record.display_name or None, + ) + except Exception: + logger.warning( + "[文件发送] 发送失败(最佳努力) uid=%s target=%s:%s", + record.uid, + target_type, + target_id, + exc_info=True, + ) diff --git a/src/Undefined/handlers.py b/src/Undefined/handlers.py index 651da57..8cfeae9 100644 --- a/src/Undefined/handlers.py +++ b/src/Undefined/handlers.py @@ -72,6 +72,7 @@ class GroupPokeRecord: group_name: str sender_role: str sender_title: str + sender_level: str class MessageHandler: @@ -113,6 +114,7 @@ def __init__( self.security, queue_manager=self.queue_manager, rate_limiter=self.rate_limiter, + history_manager=self.history_manager, ) self.ai_coordinator = AICoordinator( config, @@ -144,6 +146,77 @@ def _get_repeat_lock(self, group_id: int) -> asyncio.Lock: self._repeat_locks[group_id] = lock return lock + async def _annotate_meme_descriptions( + self, + attachments: list[dict[str, str]], + scope_key: str, + ) -> list[dict[str, str]]: + """为图片附件添加表情包描述(如果在表情库中找到)。 + + 采用批量查询:收集所有 SHA256 哈希值,一次性查询,然后映射结果。 + 最佳努力:任何失败时返回原始列表。 + """ + if not attachments: + return attachments + + ai_client = getattr(self, "ai", None) + if ai_client is None: + return attachments + + attachment_registry = getattr(ai_client, "attachment_registry", None) + if attachment_registry is None: + return attachments + + meme_service = getattr(ai_client, 
"_meme_service", None) + if meme_service is None or not getattr(meme_service, "enabled", False): + return attachments + + meme_store = getattr(meme_service, "_store", None) + if meme_store is None: + return attachments + + try: + # 1. 从图片附件收集唯一的 SHA256 哈希值 + uid_to_hash: dict[str, str] = {} + for att in attachments: + uid = att.get("uid", "") + if not uid.startswith("pic_"): + continue + record = attachment_registry.resolve(uid, scope_key) + if record and record.sha256: + uid_to_hash[uid] = record.sha256 + + if not uid_to_hash: + return attachments + + # 2. 批量查询:去重哈希值 + unique_hashes = set(uid_to_hash.values()) + hash_to_desc: dict[str, str] = {} + for h in unique_hashes: + meme = await meme_store.find_by_sha256(h) + if meme and meme.description: + hash_to_desc[h] = meme.description + + if not hash_to_desc: + return attachments + + # 3. 构建带注释的新列表 + result: list[dict[str, str]] = [] + for att in attachments: + uid = att.get("uid", "") + sha = uid_to_hash.get(uid, "") + desc = hash_to_desc.get(sha, "") + if desc: + new_att = dict(att) + new_att["description"] = f"[表情包] {desc}" + result.append(new_att) + else: + result.append(att) + return result + except Exception: + logger.warning("表情包自动匹配失败,跳过", exc_info=True) + return attachments + async def _collect_message_attachments( self, message_content: list[dict[str, Any]], @@ -176,7 +249,10 @@ async def _collect_message_attachments( if onebot else None, ) - return result.attachments + attachments = result.attachments + # 为图片附件添加表情包描述 + attachments = await self._annotate_meme_descriptions(attachments, scope_key) + return attachments def _schedule_meme_ingest( self, @@ -384,6 +460,7 @@ async def handle_message(self, event: dict[str, Any]) -> None: group_name=group_poke.group_name, sender_role=group_poke.sender_role, sender_title=group_poke.sender_title, + sender_level=group_poke.sender_level, ) return @@ -570,6 +647,7 @@ async def handle_message(self, event: dict[str, Any]) -> None: sender_nickname: str = 
group_sender.get("nickname", "") sender_role: str = group_sender.get("role", "member") sender_title: str = group_sender.get("title", "") + sender_level: str = str(group_sender.get("level", "")).strip() # 提取文本内容 text = extract_text(message_content, self.config.bot_qq) @@ -623,6 +701,7 @@ async def handle_message(self, event: dict[str, Any]) -> None: group_name=group_name, role=sender_role, title=sender_title, + level=sender_level, message_id=trigger_message_id, attachments=group_attachments, ) @@ -775,6 +854,7 @@ async def handle_message(self, event: dict[str, Any]) -> None: group_name=group_name, sender_role=sender_role, sender_title=sender_title, + sender_level=sender_level, trigger_message_id=trigger_message_id, ) @@ -837,11 +917,13 @@ async def _record_group_poke_history( sender_nickname = "" sender_role = "member" sender_title = "" + sender_level = "" if isinstance(sender, dict): sender_card = str(sender.get("card", "")).strip() sender_nickname = str(sender.get("nickname", "")).strip() sender_role = str(sender.get("role", "member")).strip() or "member" sender_title = str(sender.get("title", "")).strip() + sender_level = str(sender.get("level", "")).strip() if not sender_card and not sender_nickname: try: @@ -855,6 +937,7 @@ async def _record_group_poke_history( str(member_info.get("role", "member")).strip() or "member" ) sender_title = str(member_info.get("title", "")).strip() + sender_level = str(member_info.get("level", "")).strip() except Exception as exc: logger.warning( "[通知] 获取拍一拍群成员信息失败: group=%s user=%s err=%s", @@ -898,6 +981,7 @@ async def _record_group_poke_history( group_name=normalized_group_name, role=sender_role, title=sender_title, + level=sender_level, ) except Exception as exc: logger.warning( @@ -912,6 +996,7 @@ async def _record_group_poke_history( group_name=normalized_group_name, sender_role=sender_role, sender_title=sender_title, + sender_level=sender_level, ) async def _extract_bilibili_ids( diff --git 
a/src/Undefined/services/ai_coordinator.py b/src/Undefined/services/ai_coordinator.py index 5319a7d..321e555 100644 --- a/src/Undefined/services/ai_coordinator.py +++ b/src/Undefined/services/ai_coordinator.py @@ -6,6 +6,7 @@ from Undefined.attachments import ( attachment_refs_to_xml, build_attachment_scope, + dispatch_pending_file_sends, render_message_with_pic_placeholders, ) from Undefined.config import Config @@ -72,6 +73,7 @@ async def handle_auto_reply( group_name: str = "未知群聊", sender_role: str = "member", sender_title: str = "", + sender_level: str = "", trigger_message_id: int | None = None, ) -> None: """群聊自动回复入口:根据消息内容、命中情况和安全检测决定是否回复 @@ -130,6 +132,7 @@ async def handle_auto_reply( text, attachments=attachments, message_id=trigger_message_id, + level=sender_level, ) logger.debug( "[自动回复] full_question_len=%s group=%s sender=%s", @@ -472,6 +475,12 @@ async def send_private_cb( rendered.delivery_text, history_message=rendered.history_text, ) + await dispatch_pending_file_sends( + rendered, + sender=self.sender, + target_type="private", + target_id=user_id, + ) except Exception: logger.exception("私聊回复执行出错") raise @@ -707,6 +716,7 @@ def _build_prompt( text: str, attachments: list[dict[str, str]] | None = None, message_id: int | None = None, + level: str = "", ) -> str: """构建最终发送给 AI 的结构化 XML 消息 Prompt @@ -724,10 +734,11 @@ def _build_prompt( message_id_attr = "" if message_id is not None: message_id_attr = f' message_id="{escape_xml_attr(message_id)}"' + level_attr = f' level="{escape_xml_attr(level)}"' if level else "" attachment_xml = ( f"\n{attachment_refs_to_xml(attachments)}" if attachments else "" ) - return f"""{prefix} + return f"""{prefix} {safe_text}{attachment_xml} diff --git a/src/Undefined/services/command.py b/src/Undefined/services/command.py index 8abfcf5..075fd95 100644 --- a/src/Undefined/services/command.py +++ b/src/Undefined/services/command.py @@ -94,6 +94,7 @@ def __init__( security: SecurityService, queue_manager: Any = None, 
rate_limiter: Any = None, + history_manager: Any = None, ) -> None: """初始化命令分发器 @@ -106,6 +107,7 @@ def __init__( security: 安全审计与限流服务 queue_manager: AI 请求队列管理器 rate_limiter: 速率限制器 + history_manager: 消息历史记录管理器 """ self.config = config self.sender = sender @@ -115,6 +117,7 @@ def __init__( self.security = security self.queue_manager = queue_manager self.rate_limiter = rate_limiter + self.history_manager = history_manager self.naga_store: Any = None self._token_usage_storage = TokenUsageStorage() # 存储 stats 分析结果,用于队列回调 @@ -1078,6 +1081,8 @@ async def _send_target_message(message: str) -> None: scope=scope, user_id=user_id, is_webui_session=is_webui_session, + cognitive_service=getattr(self.ai, "_cognitive_service", None), + history_manager=self.history_manager, ) try: diff --git a/src/Undefined/services/commands/context.py b/src/Undefined/services/commands/context.py index 507af5e..56732b9 100644 --- a/src/Undefined/services/commands/context.py +++ b/src/Undefined/services/commands/context.py @@ -32,3 +32,5 @@ class CommandContext: scope: str = "group" user_id: int | None = None is_webui_session: bool = False + cognitive_service: Any = None + history_manager: Any = None diff --git a/src/Undefined/skills/agents/summary_agent/__init__.py b/src/Undefined/skills/agents/summary_agent/__init__.py new file mode 100644 index 0000000..13726fa --- /dev/null +++ b/src/Undefined/skills/agents/summary_agent/__init__.py @@ -0,0 +1 @@ +# summary_agent diff --git a/src/Undefined/skills/agents/summary_agent/config.json b/src/Undefined/skills/agents/summary_agent/config.json new file mode 100644 index 0000000..1536027 --- /dev/null +++ b/src/Undefined/skills/agents/summary_agent/config.json @@ -0,0 +1,17 @@ +{ + "type": "function", + "function": { + "name": "summary_agent", + "description": "消息总结助手,拉取指定范围的聊天消息并进行智能总结。支持按条数或时间范围筛选。", + "parameters": { + "type": "object", + "properties": { + "prompt": { + "type": "string", + "description": 
"用户的总结需求,例如:'总结最近50条消息'、'总结过去1小时的聊天'、'总结今天的技术讨论'" + } + }, + "required": ["prompt"] + } + } +} diff --git a/src/Undefined/skills/agents/summary_agent/handler.py b/src/Undefined/skills/agents/summary_agent/handler.py new file mode 100644 index 0000000..757f8a2 --- /dev/null +++ b/src/Undefined/skills/agents/summary_agent/handler.py @@ -0,0 +1,25 @@ +from __future__ import annotations + +import logging +from pathlib import Path +from typing import Any + +from Undefined.skills.agents.runner import run_agent_with_tools + +logger = logging.getLogger(__name__) + + +async def execute(args: dict[str, Any], context: dict[str, Any]) -> str: + """执行 summary_agent。""" + user_prompt = str(args.get("prompt", "")).strip() + return await run_agent_with_tools( + agent_name="summary_agent", + user_content=user_prompt, + empty_user_content_message="请提供您的总结需求", + default_prompt="你是一个消息总结助手。使用 fetch_messages 工具获取聊天记录,然后进行智能总结。", + context=context, + agent_dir=Path(__file__).parent, + logger=logger, + max_iterations=10, + tool_error_prefix="错误", + ) diff --git a/src/Undefined/skills/agents/summary_agent/intro.md b/src/Undefined/skills/agents/summary_agent/intro.md new file mode 100644 index 0000000..a7343db --- /dev/null +++ b/src/Undefined/skills/agents/summary_agent/intro.md @@ -0,0 +1,30 @@ +# summary_agent - 消息总结助手 + +## 定位 +专业的聊天消息总结智能体,从海量聊天记录中提取关键信息并生成结构化报告。 + +## 擅长 +- ✅ 按**条数**拉取消息 (默认50条,最大500条) +- ✅ 按**时间范围**拉取消息 (支持1h/6h/1d/7d等格式) +- ✅ **话题提取**: 识别主要讨论主题 +- ✅ **要点归纳**: 总结重要决策、结论、共识 +- ✅ **参与者分析**: 识别活跃用户及其贡献 +- ✅ **资源收集**: 提取链接、代码片段等 +- ✅ 输出**结构化、清晰**的总结报告 + +## 边界 +- ❌ **不做实时监控**: 仅分析历史消息,不监听新消息 +- ❌ **不做情感分析**: 专注事实总结,不评价情绪 +- ❌ **不做预测**: 不推测未来走向 + +## 输入偏好 +- 明确的**条数要求**: "总结最近100条消息" +- 明确的**时间范围**: "总结过去6小时的聊天"、"总结今天的讨论" +- 特定的**总结目标**: "总结技术讨论"、"提取所有链接" +- 如果用户仅说"总结一下",默认总结最近50条消息 + +## 适用场景 +- 快速了解错过的聊天内容 +- 回顾会议或讨论要点 +- 整理特定时间段的聊天记录 +- 提取重要链接和资源 diff --git a/src/Undefined/skills/agents/summary_agent/prompt.md b/src/Undefined/skills/agents/summary_agent/prompt.md new 
file mode 100644 index 0000000..3b873f9 --- /dev/null +++ b/src/Undefined/skills/agents/summary_agent/prompt.md @@ -0,0 +1,79 @@ +# 消息总结助手 + +你是一个专业的聊天消息总结助手,擅长从大量聊天记录中提取关键信息并生成结构化的总结报告。 + +## 核心能力 + +- 使用 `fetch_messages` 工具拉取指定范围的聊天消息 +- 支持按**消息条数**(如最近50条)或**时间范围**(如过去1小时、今天)筛选 +- 提取主题、关键参与者、重要决策、链接资源等 +- 生成清晰、结构化的总结报告 + +## 工作流程 + +1. **理解需求**: 分析用户的总结需求,确定查询参数 + - 如果用户指定了条数(如"最近50条"),使用 `count` 参数 + - 如果用户指定了时间范围(如"过去1小时"、"今天"),使用 `time_range` 参数 + - 如果用户未明确指定,**默认使用最近50条消息** + +2. **拉取消息**: 调用 `fetch_messages` 工具获取聊天记录 + - `count`: 消息条数,默认50,最大500 + - `time_range`: 时间范围,支持 "1h"(1小时)、"6h"(6小时)、"1d"(1天)、"7d"(7天) + +3. **分析总结**: 对获取的消息进行智能分析 + - 识别主要讨论话题 + - 提取关键参与者及其贡献 + - 总结重要决策、结论或共识 + - 收集提到的链接、资源、代码片段 + - 标注特别重要或需要关注的信息 + +4. **生成报告**: 以清晰的结构化格式输出总结 + - 使用**要点列表**而非长段落 + - 保持**简洁但全面** + - 突出**关键信息** + +## 输出格式建议 + +``` +📊 消息总结 (时间范围/条数) + +🔍 主要话题: +- 话题1: 简要描述 +- 话题2: 简要描述 + +👥 活跃参与者: +- 用户A: 主要贡献 +- 用户B: 主要贡献 + +💡 重要要点: +- 要点1 +- 要点2 + +🔗 相关链接/资源: +- 链接1 +- 链接2 + +⚠️ 需要关注: +- 重要事项或待办 +``` + +## 注意事项 + +- 保持**客观中立**,不加入主观评价 +- 如果消息量很大,**优先突出重点**,可省略琐碎细节 +- 如果讨论涉及敏感话题,**谨慎措辞** +- 如果消息为空或无有效内容,**明确说明** + +## 示例场景 + +**用户**: "总结最近100条消息" +→ 调用 `fetch_messages(count=100)` → 分析并总结 + +**用户**: "总结过去6小时的聊天" +→ 调用 `fetch_messages(time_range="6h")` → 分析并总结 + +**用户**: "总结今天大家讨论了什么技术问题" +→ 调用 `fetch_messages(time_range="1d")` → 重点提取技术相关话题 + +**用户**: "总结一下" +→ 调用 `fetch_messages(count=50)` (使用默认值) → 总结最近消息 diff --git a/src/Undefined/skills/agents/summary_agent/tools/__init__.py b/src/Undefined/skills/agents/summary_agent/tools/__init__.py new file mode 100644 index 0000000..fb43c21 --- /dev/null +++ b/src/Undefined/skills/agents/summary_agent/tools/__init__.py @@ -0,0 +1 @@ +# summary_agent tools diff --git a/src/Undefined/skills/agents/summary_agent/tools/fetch_messages/config.json b/src/Undefined/skills/agents/summary_agent/tools/fetch_messages/config.json new file mode 100644 index 0000000..d64a015 --- /dev/null +++ 
b/src/Undefined/skills/agents/summary_agent/tools/fetch_messages/config.json @@ -0,0 +1,20 @@ +{ + "type": "function", + "function": { + "name": "fetch_messages", + "description": "从当前会话拉取聊天消息。支持按条数或时间范围筛选。", + "parameters": { + "type": "object", + "properties": { + "count": { + "type": "integer", + "description": "要获取的消息条数,默认50,最大500。" + }, + "time_range": { + "type": "string", + "description": "时间范围,如 '1h'、'6h'、'1d'、'7d'。与 count 互斥,优先使用 time_range。" + } + } + } + } +} diff --git a/src/Undefined/skills/agents/summary_agent/tools/fetch_messages/handler.py b/src/Undefined/skills/agents/summary_agent/tools/fetch_messages/handler.py new file mode 100644 index 0000000..7ea5534 --- /dev/null +++ b/src/Undefined/skills/agents/summary_agent/tools/fetch_messages/handler.py @@ -0,0 +1,108 @@ +from __future__ import annotations + +import logging +import re +from datetime import datetime, timedelta +from typing import Any + +logger = logging.getLogger(__name__) + +_TIME_RANGE_PATTERN = re.compile(r"^(\d+)([hHdDwW])$") +_TIME_UNIT_SECONDS = {"h": 3600, "d": 86400, "w": 604800} +_MAX_COUNT = 500 +_DEFAULT_COUNT = 50 +_MAX_FETCH_FOR_TIME_FILTER = 2000 + + +def _parse_time_range(value: str) -> int | None: + """Parse time range string like '1h', '6h', '1d', '7d' into seconds.""" + match = _TIME_RANGE_PATTERN.match(value.strip()) + if not match: + return None + amount = int(match.group(1)) + unit = match.group(2).lower() + return amount * _TIME_UNIT_SECONDS.get(unit, 3600) + + +def _filter_by_time( + messages: list[dict[str, Any]], seconds: int +) -> list[dict[str, Any]]: + """Filter messages to only include those within the given time range from now.""" + cutoff = datetime.now() - timedelta(seconds=seconds) + result = [] + for msg in messages: + ts_str = msg.get("timestamp", "") + if not ts_str: + continue + try: + ts = datetime.strptime(ts_str, "%Y-%m-%d %H:%M:%S") + except ValueError: + continue + if ts >= cutoff: + result.append(msg) + return result + + +def 
_format_messages(messages: list[dict[str, Any]]) -> str: + """Format messages into readable text for the summary agent.""" + lines = [] + for msg in messages: + ts = msg.get("timestamp", "") + name = msg.get("display_name", "未知用户") + text = msg.get("message", "") + role = msg.get("role", "") + title = msg.get("title", "") + + prefix = f"[{ts}] " + if title: + prefix += f"[{title}] " + if role and role not in ("member", ""): + prefix += f"({role}) " + prefix += f"{name}: " + lines.append(f"{prefix}{text}") + return "\n".join(lines) + + +async def execute(args: dict[str, Any], context: dict[str, Any]) -> str: + """拉取当前会话的聊天消息。""" + history_manager = context.get("history_manager") + if not history_manager: + return "历史记录管理器未配置" + + group_id = context.get("group_id", 0) or 0 + user_id = context.get("user_id", 0) or 0 + + if int(group_id) > 0: + chat_type = "group" + chat_id = str(group_id) + else: + chat_type = "private" + chat_id = str(user_id) + + time_range_str = str(args.get("time_range", "")).strip() + raw_count = args.get("count", _DEFAULT_COUNT) + try: + count = min(max(int(raw_count), 1), _MAX_COUNT) + except (TypeError, ValueError): + count = _DEFAULT_COUNT + + if time_range_str: + seconds = _parse_time_range(time_range_str) + if seconds is None: + return f"无法解析时间范围: {time_range_str}(支持格式: 1h, 6h, 1d, 7d)" + fetch_count = max(count * 2, _MAX_FETCH_FOR_TIME_FILTER) + messages = history_manager.get_recent(chat_id, chat_type, 0, fetch_count) + if messages: + messages = _filter_by_time(messages, seconds) + else: + messages = history_manager.get_recent(chat_id, chat_type, 0, count) + + if not messages: + return "当前会话暂无消息记录" + + formatted = _format_messages(messages) + total = len(messages) + header = f"共获取 {total} 条消息" + if time_range_str: + header += f"(时间范围: {time_range_str})" + return f"{header}\n\n{formatted}" diff --git a/src/Undefined/skills/commands/profile/config.json b/src/Undefined/skills/commands/profile/config.json new file mode 100644 index 
0000000..6b47c63 --- /dev/null +++ b/src/Undefined/skills/commands/profile/config.json @@ -0,0 +1,16 @@ +{ + "name": "profile", + "description": "查看用户或群聊侧写", + "usage": "/profile [group]", + "example": "/profile", + "permission": "public", + "rate_limit": { + "user": 60, + "admin": 10, + "superadmin": 0 + }, + "show_in_help": true, + "order": 25, + "allow_in_private": true, + "aliases": ["me", "p"] +} diff --git a/src/Undefined/skills/commands/profile/handler.py b/src/Undefined/skills/commands/profile/handler.py new file mode 100644 index 0000000..aba240a --- /dev/null +++ b/src/Undefined/skills/commands/profile/handler.py @@ -0,0 +1,54 @@ +from __future__ import annotations + +from Undefined.services.commands.context import CommandContext + +_MAX_PROFILE_LENGTH = 3000 + + +def _is_private(context: CommandContext) -> bool: + return context.scope == "private" + + +def _truncate(text: str, limit: int = _MAX_PROFILE_LENGTH) -> str: + if len(text) <= limit: + return text + return text[:limit].rstrip() + "\n\n[侧写过长,已截断]" + + +async def _send(context: CommandContext, text: str) -> None: + """Send message to appropriate channel.""" + if _is_private(context): + user_id = int(context.user_id or context.sender_id) + await context.sender.send_private_message(user_id, text) + else: + await context.sender.send_group_message(context.group_id, text) + + +async def execute(args: list[str], context: CommandContext) -> None: + """处理 /profile 命令。""" + cognitive_service = context.cognitive_service + if cognitive_service is None: + await _send(context, "❌ 侧写服务未启用") + return + + # Parse subcommand + sub = args[0].lower().strip() if args else "" + + if sub == "group": + if _is_private(context): + await _send(context, "❌ 私聊中不支持查看群聊侧写") + return + entity_type = "groups" + entity_id = str(context.group_id) + empty_hint = "暂无群聊侧写数据" + else: + entity_type = "users" + entity_id = str(context.sender_id) + empty_hint = "暂无侧写数据" + + profile = await cognitive_service.get_profile(entity_type, 
entity_id) + if not profile: + await _send(context, f"📭 {empty_hint}") + return + + await _send(context, _truncate(profile)) diff --git a/src/Undefined/skills/commands/summary/config.json b/src/Undefined/skills/commands/summary/config.json new file mode 100644 index 0000000..9135eda --- /dev/null +++ b/src/Undefined/skills/commands/summary/config.json @@ -0,0 +1,16 @@ +{ + "name": "summary", + "description": "总结聊天消息", + "usage": "/summary [条数|时间范围] [自定义描述]", + "example": "/summary 50", + "permission": "public", + "rate_limit": { + "user": 120, + "admin": 30, + "superadmin": 0 + }, + "show_in_help": true, + "order": 26, + "allow_in_private": true, + "aliases": ["sum"] +} diff --git a/src/Undefined/skills/commands/summary/handler.py b/src/Undefined/skills/commands/summary/handler.py new file mode 100644 index 0000000..a35dc4d --- /dev/null +++ b/src/Undefined/skills/commands/summary/handler.py @@ -0,0 +1,103 @@ +from __future__ import annotations + +import logging +import re +from typing import Any + +from Undefined.services.commands.context import CommandContext + +logger = logging.getLogger(__name__) + +_TIME_RANGE_RE = re.compile(r"^\d+[hHdDwW]$") +_DEFAULT_COUNT = 50 + + +def _parse_args(args: list[str]) -> tuple[int | None, str | None, str]: + """Parse command arguments into (count, time_range, custom_prompt). + + Returns: + Tuple of (count, time_range, custom_prompt). + count and time_range are mutually exclusive; at most one is non-None. 
+ """ + if not args: + return _DEFAULT_COUNT, None, "" + + first = args[0] + rest = " ".join(args[1:]).strip() + + if first.isdigit(): + count = max(1, min(int(first), 500)) + return count, None, rest + + if _TIME_RANGE_RE.match(first): + return None, first, rest + + # First arg is not a number or time range — treat everything as prompt + return _DEFAULT_COUNT, None, " ".join(args).strip() + + +def _build_prompt(count: int | None, time_range: str | None, custom_prompt: str) -> str: + """Build the natural language prompt for summary_agent.""" + parts: list[str] = ["请总结"] + if time_range: + parts.append(f"过去 {time_range} 内的聊天消息") + elif count: + parts.append(f"最近 {count} 条聊天消息") + else: + parts.append(f"最近 {_DEFAULT_COUNT} 条聊天消息") + + if custom_prompt: + parts.append(f",重点关注:{custom_prompt}") + + return "".join(parts) + + +def _is_private(context: CommandContext) -> bool: + return context.scope == "private" + + +async def _send(context: CommandContext, text: str) -> None: + if _is_private(context): + user_id = int(context.user_id or context.sender_id) + await context.sender.send_private_message(user_id, text) + else: + await context.sender.send_group_message(context.group_id, text) + + +async def execute(args: list[str], context: CommandContext) -> None: + """处理 /summary 命令。""" + if context.history_manager is None: + await _send(context, "❌ 历史记录管理器未配置") + return + + count, time_range, custom_prompt = _parse_args(args) + prompt = _build_prompt(count, time_range, custom_prompt) + + # Build agent context + agent_context: dict[str, Any] = { + "ai_client": context.ai, + "history_manager": context.history_manager, + "group_id": context.group_id, + "user_id": int(context.user_id or context.sender_id), + "sender_id": context.sender_id, + "request_type": "group" if int(context.group_id) > 0 else "private", + "runtime_config": getattr(context.ai, "runtime_config", None), + "queue_lane": None, + } + + await _send(context, "📝 正在总结消息,请稍候...") + + try: + from 
Undefined.skills.agents.summary_agent.handler import execute as run_summary + + result = await run_summary({"prompt": prompt}, agent_context) + except Exception: + logger.exception("[/summary] 执行总结失败") + await _send(context, "❌ 消息总结失败,请稍后重试") + return + + if not result or not result.strip(): + await _send(context, "📭 未能生成总结内容") + return + + await _send(context, result) diff --git a/src/Undefined/skills/toolsets/messages/get_messages_by_time/handler.py b/src/Undefined/skills/toolsets/messages/get_messages_by_time/handler.py index 05f4ba6..12ff836 100644 --- a/src/Undefined/skills/toolsets/messages/get_messages_by_time/handler.py +++ b/src/Undefined/skills/toolsets/messages/get_messages_by_time/handler.py @@ -178,6 +178,9 @@ async def execute(args: Dict[str, Any], context: Dict[str, Any]) -> str: timestamp_str = msg.get("timestamp", "") text = msg.get("message", "") message_id = msg.get("message_id") + role = msg.get("role", "") + title = msg.get("title", "") + level = msg.get("level", "") if msg_type_val == "group": # 确保群名以"群"结尾 @@ -189,8 +192,17 @@ async def execute(args: Dict[str, Any], context: Dict[str, Any]) -> str: if message_id is not None: msg_id_attr = f' message_id="{message_id}"' + extra_attrs = "" + if msg_type_val == "group": + if role: + extra_attrs += f' role="{role}"' + if title: + extra_attrs += f' title="{title}"' + if level: + extra_attrs += f' level="{level}"' + # 格式:XML 标准化 - formatted.append(f""" + formatted.append(f""" {text} """) diff --git a/src/Undefined/skills/toolsets/messages/get_recent_messages/handler.py b/src/Undefined/skills/toolsets/messages/get_recent_messages/handler.py index fda39ef..54874bb 100644 --- a/src/Undefined/skills/toolsets/messages/get_recent_messages/handler.py +++ b/src/Undefined/skills/toolsets/messages/get_recent_messages/handler.py @@ -105,6 +105,9 @@ def _format_message_xml(msg: dict[str, Any]) -> str: timestamp = msg.get("timestamp", "") text = msg.get("message", "") message_id = msg.get("message_id") + role = 
msg.get("role", "") + title = msg.get("title", "") + level = msg.get("level", "") location = _format_message_location(msg_type_val, chat_name) @@ -112,7 +115,16 @@ def _format_message_xml(msg: dict[str, Any]) -> str: if message_id is not None: msg_id_attr = f' message_id="{message_id}"' - return f""" + extra_attrs = "" + if msg_type_val == "group": + if role: + extra_attrs += f' role="{role}"' + if title: + extra_attrs += f' title="{title}"' + if level: + extra_attrs += f' level="{level}"' + + return f""" {text} """ diff --git a/src/Undefined/skills/toolsets/messages/send_message/handler.py b/src/Undefined/skills/toolsets/messages/send_message/handler.py index 9d21aab..f5e8b13 100644 --- a/src/Undefined/skills/toolsets/messages/send_message/handler.py +++ b/src/Undefined/skills/toolsets/messages/send_message/handler.py @@ -2,6 +2,7 @@ import logging from Undefined.attachments import ( + dispatch_pending_file_sends, render_message_with_pic_placeholders, scope_from_context, ) @@ -168,6 +169,12 @@ async def execute(args: Dict[str, Any], context: Dict[str, Any]) -> str: **send_kwargs, ) context["message_sent_this_turn"] = True + await dispatch_pending_file_sends( + rendered, + sender=sender, + target_type=target_type, + target_id=target_id, + ) return _format_send_success(sent_message_id) except Exception as e: logger.exception( diff --git a/src/Undefined/skills/toolsets/messages/send_private_message/handler.py b/src/Undefined/skills/toolsets/messages/send_private_message/handler.py index 727cb2d..f92cab8 100644 --- a/src/Undefined/skills/toolsets/messages/send_private_message/handler.py +++ b/src/Undefined/skills/toolsets/messages/send_private_message/handler.py @@ -2,6 +2,7 @@ import logging from Undefined.attachments import ( + dispatch_pending_file_sends, render_message_with_pic_placeholders, scope_from_context, ) @@ -115,6 +116,12 @@ async def execute(args: Dict[str, Any], context: Dict[str, Any]) -> str: **send_kwargs, ) context["message_sent_this_turn"] = True + 
await dispatch_pending_file_sends( + rendered, + sender=sender, + target_type="private", + target_id=user_id, + ) return _format_send_success(user_id, sent_message_id) except Exception as e: logger.exception( diff --git a/src/Undefined/skills/toolsets/render/render_latex/config.json b/src/Undefined/skills/toolsets/render/render_latex/config.json index a351083..16b1f99 100644 --- a/src/Undefined/skills/toolsets/render/render_latex/config.json +++ b/src/Undefined/skills/toolsets/render/render_latex/config.json @@ -2,27 +2,18 @@ "type": "function", "function": { "name": "render_latex", - "description": "将 LaTeX 公式或文档渲染为图片,使用系统外部 LaTeX(需预先安装 TeX Live / MiKTeX)。默认返回可嵌入回复的图片 UID(embed),也可直接发送到指定目标(send)。", + "description": "将 LaTeX 数学公式渲染为图片或 PDF 文档,使用 MathJax(不依赖系统 TeX 安装)。支持 LaTeX 数学子集(amsmath、equation、align、matrix 等),但不支持自定义 TeX 包。返回可嵌入回复的附件 UID。", "parameters": { "type": "object", "properties": { "content": { "type": "string", - "description": "要渲染的 LaTeX 内容。支持 $...$、$$...$$、\\[...\\]、\\(...\\) 及完整环境(\\begin{align}...\\end{align} 等);\\begin{document}...\\end{document} 外层包装会自动去掉。" + "description": "要渲染的 LaTeX 数学内容。支持 $...$(行内)、$$...$$(块级)、\\[...\\]、\\(...\\) 及标准数学环境(\\begin{align}、\\begin{equation}、\\begin{matrix} 等)。如果不包含分隔符,会自动用 \\[ ... 
\\] 包装。\\begin{document}...\\end{document} 外层包装会自动去掉。" }, - "delivery": { + "output_format": { "type": "string", - "description": "图片交付方式:embed 返回可插入回复的图片 UID;send 立即发送到目标", - "enum": ["embed", "send"] - }, - "target_id": { - "type": "integer", - "description": "目标 ID(群号或用户 QQ 号,仅 delivery=send 时需要,不提供则从当前会话推断)" - }, - "message_type": { - "type": "string", - "description": "消息类型(仅 delivery=send 时需要,不提供则从当前会话推断)", - "enum": ["group", "private"] + "description": "输出格式:\"png\"(默认,图片)或 \"pdf\"(PDF 文档)", + "enum": ["png", "pdf"] } }, "required": ["content"] diff --git a/src/Undefined/skills/toolsets/render/render_latex/handler.py b/src/Undefined/skills/toolsets/render/render_latex/handler.py index 8ac6e06..a449cd2 100644 --- a/src/Undefined/skills/toolsets/render/render_latex/handler.py +++ b/src/Undefined/skills/toolsets/render/render_latex/handler.py @@ -1,15 +1,8 @@ from __future__ import annotations -from pathlib import Path +import logging import re from typing import Any, Dict -import logging -import uuid - -import matplotlib - -matplotlib.use("Agg") -import matplotlib.pyplot as plt from Undefined.attachments import scope_from_context @@ -20,9 +13,15 @@ re.DOTALL, ) +# MathJax 数学分隔符模式 +_MATH_DELIMITER_PATTERN = re.compile( + r"(\$\$|\\\[|\\\(|\\begin\{)", + re.MULTILINE, +) + def _strip_document_wrappers(content: str) -> str: - """去掉 \\begin{document}...\\end{document} 外层包装;matplotlib 会自行构造文档。""" + """去掉 \\begin{document}...\\end{document} 外层包装。""" text = content.strip() match = _DOCUMENT_PATTERN.fullmatch(text) if match is None: @@ -30,136 +29,175 @@ def _strip_document_wrappers(content: str) -> str: return match.group("body").strip() -def _render_latex_image(filepath: Path, content: str) -> None: - text = _strip_document_wrappers(content) - fig = plt.figure(figsize=(6, 2.5)) +def _has_math_delimiters(content: str) -> bool: + """检查内容是否已包含数学分隔符。""" + return bool(_MATH_DELIMITER_PATTERN.search(content)) + + +def _prepare_content(raw_content: str) -> str: + """ + 准备 
LaTeX 内容: + 1. 去掉 document 包装 + 2. 处理字面量 \\n(LLM 输出常见问题) + 3. 如果没有数学分隔符,自动用 \\[ ... \\] 包装 + """ + content = _strip_document_wrappers(raw_content) + # 替换字面量 \\n 为真实换行符 + content = content.replace("\\n", "\n") + + if not _has_math_delimiters(content): + # 没有分隔符,自动包装为块级数学环境 + content = f"\\[\n{content}\n\\]" + + return content + + +def _build_html(latex_content: str) -> str: + """构建包含 MathJax 的 HTML 页面。""" + # HTML 转义(防止内容中的 < > & 破坏结构) + import html + + escaped_content = html.escape(latex_content) + + return f""" + + + + + + + +
+{escaped_content} +
+ +""" + + +async def _render_latex_to_bytes(content: str, output_format: str) -> tuple[bytes, str]: + """ + 使用 MathJax + Playwright 渲染 LaTeX 内容。 + + 返回: (渲染后的字节流, MIME 类型) + """ try: - fig.patch.set_facecolor("white") - fig.text( - 0.5, - 0.5, - text, - fontsize=20, - verticalalignment="center", - horizontalalignment="center", - usetex=True, - wrap=True, + from playwright.async_api import ( + async_playwright, + TimeoutError as PwTimeoutError, ) - fig.savefig(filepath, dpi=200, bbox_inches="tight", pad_inches=0.25) - finally: - plt.close(fig) - - -def _resolve_send_target( - target_id: Any, - message_type: Any, - context: Dict[str, Any], -) -> tuple[int | str | None, str | None, str | None]: - """从参数或 context 推断发送目标。""" - if target_id is not None and message_type is not None: - return target_id, message_type, None - request_type = str(context.get("request_type", "") or "").strip().lower() - if request_type == "group": - gid = context.get("group_id") - if gid is not None: - return gid, "group", None - if request_type == "private": - uid = context.get("user_id") - if uid is not None: - return uid, "private", None - return None, None, "渲染成功,但缺少发送目标参数" + except ImportError: + raise ImportError( + "请运行 `uv run playwright install` 安装浏览器运行时" + ) from None + + html_content = _build_html(content) + + async with async_playwright() as p: + browser = await p.chromium.launch(headless=True) + try: + page = await browser.new_page() + await page.set_content(html_content) + + # 等待 MathJax 完成排版 + try: + await page.wait_for_function( + "() => window.MathJax?.startup?.promise?.then(() => true) ?? 
false", + timeout=15000, + ) + except PwTimeoutError: + logger.warning("MathJax 排版超时,内容可能过于复杂或网络不可达") + raise RuntimeError( + "LaTeX 内容可能过于复杂或网络不可达(MathJax 加载超时)" + ) from None + + if output_format == "pdf": + # 获取容器尺寸 + container = await page.query_selector("#math-container") + if container is None: + raise RuntimeError("无法定位数学容器元素") + + bbox = await container.bounding_box() + if bbox is None: + raise RuntimeError("无法获取数学容器的边界框") + + # PDF 输出,设置合适的页面尺寸 + pdf_bytes = await page.pdf( + width=f"{bbox['width'] + 40}px", + height=f"{bbox['height'] + 40}px", + print_background=True, + ) + return pdf_bytes, "application/pdf" + else: + # PNG 输出 + container = await page.query_selector("#math-container") + if container is None: + raise RuntimeError("无法定位数学容器元素") + + screenshot_bytes = await container.screenshot(type="png") + return screenshot_bytes, "image/png" + + finally: + await browser.close() async def execute(args: Dict[str, Any], context: Dict[str, Any]) -> str: - """渲染 LaTeX 数学公式为图片""" - content = str(args.get("content", "") or "") - delivery = str(args.get("delivery", "embed") or "embed").strip().lower() - target_id = args.get("target_id") - message_type = args.get("message_type") + """渲染 LaTeX 数学公式为图片或 PDF""" + raw_content = str(args.get("content", "") or "") + output_format = str(args.get("output_format", "png") or "png").strip().lower() - if not content: - return "内容不能为空" - if delivery not in {"embed", "send"}: - return f"delivery 无效:{delivery}。仅支持 embed 或 send" + # 参数校验 + if not raw_content or not raw_content.strip(): + return "LaTeX 内容不能为空" - if delivery == "send" and message_type and message_type not in ("group", "private"): - return "消息类型必须是 group 或 private" + if output_format not in {"png", "pdf"}: + return f"output_format 无效:{output_format}。仅支持 png 或 pdf" try: - from Undefined.utils.cache import cleanup_cache_dir - from Undefined.utils.paths import RENDER_CACHE_DIR, ensure_dir - - filename = f"render_{uuid.uuid4().hex[:16]}.png" - filepath = 
ensure_dir(RENDER_CACHE_DIR) / filename + # 准备内容 + prepared_content = _prepare_content(raw_content) - _render_latex_image(filepath, content) + # 渲染 + rendered_bytes, mime_type = await _render_latex_to_bytes( + prepared_content, output_format + ) # 注册到附件系统 attachment_registry = context.get("attachment_registry") scope_key = scope_from_context(context) - record: Any = None - if attachment_registry is not None and scope_key: - try: - record = await attachment_registry.register_local_file( - scope_key, - filepath, - kind="image", - display_name=filename, - source_kind="rendered_image", - source_ref="render_latex", - ) - except Exception as exc: - logger.warning("注册渲染图片到附件系统失败: %s", exc) - - if delivery == "embed": - cleanup_cache_dir(RENDER_CACHE_DIR) - if record is None: - return "渲染成功,但无法注册到附件系统(缺少 attachment_registry 或 scope_key)" - return f'' - - # delivery == "send" - resolved_target_id, resolved_message_type, target_error = _resolve_send_target( - target_id, message_type, context - ) - if target_error or resolved_target_id is None or resolved_message_type is None: - return target_error or "渲染成功,但缺少发送目标参数" - - sender = context.get("sender") - send_image_callback = context.get("send_image_callback") - - if sender: - from pathlib import Path - - cq_message = f"[CQ:image,file={Path(filepath).resolve().as_uri()}]" - if resolved_message_type == "group": - await sender.send_group_message(int(resolved_target_id), cq_message) - elif resolved_message_type == "private": - await sender.send_private_message(int(resolved_target_id), cq_message) - cleanup_cache_dir(RENDER_CACHE_DIR) - return ( - f"LaTeX 图片已渲染并发送到 {resolved_message_type} {resolved_target_id}" - ) - elif send_image_callback: - await send_image_callback( - resolved_target_id, resolved_message_type, str(filepath) - ) - cleanup_cache_dir(RENDER_CACHE_DIR) - return ( - f"LaTeX 图片已渲染并发送到 {resolved_message_type} {resolved_target_id}" + + if attachment_registry is None or not scope_key: + return "渲染成功,但无法注册到附件系统(缺少 
attachment_registry 或 scope_key)" + + kind = "image" if output_format == "png" else "file" + extension = "png" if output_format == "png" else "pdf" + display_name = f"latex.{extension}" + + try: + record = await attachment_registry.register_bytes( + scope_key, + rendered_bytes, + kind=kind, + display_name=display_name, + mime_type=mime_type, + source_kind="rendered_latex", + source_ref="render_latex", ) - else: - return "发送图片回调未设置" + tag = "pic" if output_format == "png" else "attachment" + return f'<{tag} uid="{record.uid}"/>' + + except Exception as exc: + logger.exception("注册渲染结果到附件系统失败: %s", exc) + return f"渲染成功,但注册到附件系统失败: {exc}" except ImportError as e: - missing_pkg = str(e).split("'")[1] if "'" in str(e) else "未知包" - return f"渲染失败:缺少依赖包 {missing_pkg},请运行: uv add {missing_pkg}" + logger.error("Playwright 导入失败: %s", e) + return "请运行 `uv run playwright install` 安装浏览器运行时" except RuntimeError as e: - err = str(e).lower() - if "latex" in err or "dvipng" in err or "dvi" in err: - logger.error("LaTeX 渲染失败(系统 TeX 环境不可用): %s", e) - return "渲染失败:系统 LaTeX 环境未安装或不完整,请按部署文档安装 TeX Live / MiKTeX 后重试。" - logger.exception("渲染并发送 LaTeX 图片失败: %s", e) - return "渲染失败,请稍后重试" + logger.error("LaTeX 渲染运行时错误: %s", e) + return str(e) except Exception as e: - logger.exception(f"渲染并发送 LaTeX 图片失败: {e}") - return "渲染失败,请稍后重试" + logger.exception("渲染 LaTeX 失败: %s", e) + return f"渲染失败:{e}" diff --git a/src/Undefined/utils/history.py b/src/Undefined/utils/history.py index 46dfafe..b6dfa4b 100644 --- a/src/Undefined/utils/history.py +++ b/src/Undefined/utils/history.py @@ -306,6 +306,7 @@ async def add_group_message( group_name: str = "", role: str = "member", title: str = "", + level: str = "", message_id: int | None = None, attachments: list[dict[str, str]] | None = None, ) -> None: @@ -334,6 +335,7 @@ async def add_group_message( "display_name": display_name, "role": role, "title": title, + "level": level, "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"), "message": text_content, } 
diff --git a/tests/test_arxiv_sender.py b/tests/test_arxiv_sender.py index 120ab0a..312075b 100644 --- a/tests/test_arxiv_sender.py +++ b/tests/test_arxiv_sender.py @@ -17,6 +17,7 @@ @pytest.fixture(autouse=True) def _clear_inflight() -> None: arxiv_sender._INFLIGHT_SENDS.clear() + arxiv_sender._RECENT_SENDS.clear() def _paper_info() -> PaperInfo: @@ -165,3 +166,187 @@ async def _fake_once(**_: object) -> str: assert first_result == "ok" assert second_result == "ok" assert called == 1 + + +@pytest.mark.asyncio +async def test_recent_send_blocks_duplicate( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """After a successful send, duplicate should be blocked within cooldown.""" + sender = _sender() + call_count = 0 + + async def _fake_once(**_: object) -> str: + nonlocal call_count + call_count += 1 + return "ok" + + monkeypatch.setattr(arxiv_sender, "_send_arxiv_paper_once", _fake_once) + + # First send should succeed + result1 = await send_arxiv_paper( + paper_id="2501.01234", + sender=sender, + target_type="group", + target_id=123456, + max_file_size=100, + author_preview_limit=20, + summary_preview_chars=1000, + ) + assert result1 == "ok" + assert call_count == 1 + + # Second send should be blocked by time-based dedup + result2 = await send_arxiv_paper( + paper_id="2501.01234", + sender=sender, + target_type="group", + target_id=123456, + max_file_size=100, + author_preview_limit=20, + summary_preview_chars=1000, + ) + assert "近期已发送过" in result2 + assert call_count == 1 # Still only 1 call + + +@pytest.mark.asyncio +async def test_recent_send_expires_after_cooldown( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """After cooldown expires, the same paper can be sent again.""" + sender = _sender() + call_count = 0 + mock_time = 0.0 + + async def _fake_once(**_: object) -> str: + nonlocal call_count + call_count += 1 + return "ok" + + def _fake_monotonic() -> float: + return mock_time + + monkeypatch.setattr(arxiv_sender, "_send_arxiv_paper_once", _fake_once) + 
# Patch time.monotonic in the arxiv_sender module + monkeypatch.setattr( + arxiv_sender, "time", SimpleNamespace(monotonic=_fake_monotonic) + ) + + # First send at time 0 + result1 = await send_arxiv_paper( + paper_id="2501.01234", + sender=sender, + target_type="group", + target_id=123456, + max_file_size=100, + author_preview_limit=20, + summary_preview_chars=1000, + ) + assert result1 == "ok" + assert call_count == 1 + + # Advance time past cooldown + mock_time = arxiv_sender._DEDUP_COOLDOWN_SECONDS + 1.0 + + # Second send should succeed now + result2 = await send_arxiv_paper( + paper_id="2501.01234", + sender=sender, + target_type="group", + target_id=123456, + max_file_size=100, + author_preview_limit=20, + summary_preview_chars=1000, + ) + assert result2 == "ok" + assert call_count == 2 # Second call executed + + +@pytest.mark.asyncio +async def test_recent_send_capacity_limit() -> None: + """When _RECENT_SENDS exceeds max size, oldest entries are evicted.""" + # Fill with max_size + 100 entries + for i in range(arxiv_sender._RECENT_SENDS_MAX_SIZE + 100): + key = ("group", i, f"paper_{i}") + arxiv_sender._RECENT_SENDS[key] = float(i) + + # Trigger eviction + arxiv_sender._evict_oldest_recent_sends() + + # Should have evicted back to max size + assert len(arxiv_sender._RECENT_SENDS) == arxiv_sender._RECENT_SENDS_MAX_SIZE + + # Oldest 100 should be gone + for i in range(100): + key = ("group", i, f"paper_{i}") + assert key not in arxiv_sender._RECENT_SENDS + + # Newest max_size should remain + for i in range(100, arxiv_sender._RECENT_SENDS_MAX_SIZE + 100): + key = ("group", i, f"paper_{i}") + assert key in arxiv_sender._RECENT_SENDS + + +@pytest.mark.asyncio +async def test_cleanup_expired_recent_sends() -> None: + """Test that expired entries are removed while non-expired remain.""" + import time as time_module + + # Get current time + now = time_module.monotonic() + + # Add expired entries (old timestamps, more than 1 hour ago) + expired_key1 = ("group", 1, 
"expired1") + expired_key2 = ("group", 2, "expired2") + arxiv_sender._RECENT_SENDS[expired_key1] = ( + now - arxiv_sender._DEDUP_COOLDOWN_SECONDS - 100.0 + ) + arxiv_sender._RECENT_SENDS[expired_key2] = ( + now - arxiv_sender._DEDUP_COOLDOWN_SECONDS - 50.0 + ) + + # Add non-expired entries (recent timestamps, within 1 hour) + recent_key1 = ("group", 3, "recent1") + recent_key2 = ("group", 4, "recent2") + arxiv_sender._RECENT_SENDS[recent_key1] = now - 100.0 # 100 seconds ago + arxiv_sender._RECENT_SENDS[recent_key2] = now - 10.0 # 10 seconds ago + + # Cleanup + arxiv_sender._cleanup_expired_recent_sends() + + # Expired should be gone + assert expired_key1 not in arxiv_sender._RECENT_SENDS + assert expired_key2 not in arxiv_sender._RECENT_SENDS + + # Recent should remain + assert recent_key1 in arxiv_sender._RECENT_SENDS + assert recent_key2 in arxiv_sender._RECENT_SENDS + + +@pytest.mark.asyncio +async def test_failed_send_not_recorded_in_recent( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """A failed send should NOT record in _RECENT_SENDS.""" + sender = _sender() + + async def _fake_once(**_: object) -> str: + raise RuntimeError("Send failed") + + monkeypatch.setattr(arxiv_sender, "_send_arxiv_paper_once", _fake_once) + + # Attempt send, should fail + with pytest.raises(RuntimeError, match="Send failed"): + await send_arxiv_paper( + paper_id="2501.01234", + sender=sender, + target_type="group", + target_id=123456, + max_file_size=100, + author_preview_limit=20, + summary_preview_chars=1000, + ) + + # Should NOT be recorded in recent sends + assert len(arxiv_sender._RECENT_SENDS) == 0 diff --git a/tests/test_attachment_tags.py b/tests/test_attachment_tags.py new file mode 100644 index 0000000..762ffd4 --- /dev/null +++ b/tests/test_attachment_tags.py @@ -0,0 +1,336 @@ +"""Tests for unified / tag rendering.""" + +from __future__ import annotations + +from pathlib import Path +from typing import Any + +import pytest + +from Undefined.attachments import ( + 
AttachmentRegistry, + AttachmentRenderError, + RenderedRichMessage, + dispatch_pending_file_sends, + render_message_with_attachments, + render_message_with_pic_placeholders, +) + + +_PNG_BYTES = ( + b"\x89PNG\r\n\x1a\n" + b"\x00\x00\x00\rIHDR" + b"\x00\x00\x00\x01\x00\x00\x00\x01\x08\x02\x00\x00\x00" + b"\x90wS\xde" + b"\x00\x00\x00\x0cIDATx\x9cc``\x00\x00\x00\x02\x00\x01" + b"\x0b\xe7\x02\x9d" + b"\x00\x00\x00\x00IEND\xaeB`\x82" +) + +_PDF_BYTES = b"%PDF-1.4 fake content for testing" + + +def _make_registry(tmp_path: Path) -> AttachmentRegistry: + return AttachmentRegistry( + registry_path=tmp_path / "reg.json", + cache_dir=tmp_path / "cache", + ) + + +# ---------- backward compatibility ---------- + + +@pytest.mark.asyncio +async def test_pic_tag_still_works(tmp_path: Path) -> None: + """ backward compat: renders image as CQ.""" + reg = _make_registry(tmp_path) + rec = await reg.register_bytes( + "group:1", _PNG_BYTES, kind="image", display_name="cat.png", source_kind="test" + ) + msg = f'Look: ' + result = await render_message_with_attachments( + msg, registry=reg, scope_key="group:1", strict=False + ) + assert "[CQ:image" in result.delivery_text + assert rec.uid in result.history_text + assert len(result.attachments) == 1 + assert result.pending_file_sends == () + + +@pytest.mark.asyncio +async def test_alias_is_same_function() -> None: + """render_message_with_pic_placeholders is an alias.""" + assert render_message_with_pic_placeholders is render_message_with_attachments + + +# ---------- unified tag ---------- + + +@pytest.mark.asyncio +async def test_attachment_tag_image(tmp_path: Path) -> None: + """ renders as CQ image (same as ).""" + reg = _make_registry(tmp_path) + rec = await reg.register_bytes( + "group:1", _PNG_BYTES, kind="image", display_name="cat.png", source_kind="test" + ) + msg = f'Here: ' + result = await render_message_with_attachments( + msg, registry=reg, scope_key="group:1", strict=False + ) + assert "[CQ:image" in result.delivery_text + 
assert rec.uid in result.history_text + assert result.pending_file_sends == () + + +@pytest.mark.asyncio +async def test_attachment_tag_file(tmp_path: Path) -> None: + """ collects into pending_file_sends.""" + reg = _make_registry(tmp_path) + rec = await reg.register_bytes( + "group:1", _PDF_BYTES, kind="file", display_name="doc.pdf", source_kind="test" + ) + msg = f'See doc: ' + result = await render_message_with_attachments( + msg, registry=reg, scope_key="group:1", strict=False + ) + # File tag removed from delivery text + assert rec.uid not in result.delivery_text + assert "See doc: " in result.delivery_text + # Readable placeholder in history + assert f"[文件 uid={rec.uid}" in result.history_text + assert "doc.pdf" in result.history_text + # Collected in pending + assert len(result.pending_file_sends) == 1 + assert result.pending_file_sends[0].uid == rec.uid + + +@pytest.mark.asyncio +async def test_mixed_pic_and_attachment_tags(tmp_path: Path) -> None: + """Mix of and tags in the same message.""" + reg = _make_registry(tmp_path) + img = await reg.register_bytes( + "group:1", _PNG_BYTES, kind="image", display_name="img.png", source_kind="test" + ) + doc = await reg.register_bytes( + "group:1", _PDF_BYTES, kind="file", display_name="doc.pdf", source_kind="test" + ) + msg = f' and ' + result = await render_message_with_attachments( + msg, registry=reg, scope_key="group:1", strict=False + ) + assert "[CQ:image" in result.delivery_text + assert len(result.pending_file_sends) == 1 + assert len(result.attachments) == 2 + + +@pytest.mark.asyncio +async def test_pic_tag_rejects_non_image(tmp_path: Path) -> None: + """ tag with file UID shows type error.""" + reg = _make_registry(tmp_path) + rec = await reg.register_bytes( + "group:1", _PDF_BYTES, kind="file", display_name="doc.pdf", source_kind="test" + ) + msg = f'' + result = await render_message_with_attachments( + msg, registry=reg, scope_key="group:1", strict=False + ) + assert "类型错误" in result.delivery_text + + 
+@pytest.mark.asyncio +async def test_pic_tag_rejects_non_image_strict(tmp_path: Path) -> None: + """ tag with file UID raises in strict mode.""" + reg = _make_registry(tmp_path) + rec = await reg.register_bytes( + "group:1", _PDF_BYTES, kind="file", display_name="doc.pdf", source_kind="test" + ) + msg = f'' + with pytest.raises(AttachmentRenderError, match="不是图片"): + await render_message_with_attachments( + msg, registry=reg, scope_key="group:1", strict=True + ) + + +@pytest.mark.asyncio +async def test_attachment_tag_allows_any_type(tmp_path: Path) -> None: + """ tag does NOT reject file UIDs (unlike ).""" + reg = _make_registry(tmp_path) + rec = await reg.register_bytes( + "group:1", _PDF_BYTES, kind="file", display_name="doc.pdf", source_kind="test" + ) + msg = f'' + result = await render_message_with_attachments( + msg, registry=reg, scope_key="group:1", strict=True + ) + assert "类型错误" not in result.delivery_text + assert len(result.pending_file_sends) == 1 + + +@pytest.mark.asyncio +async def test_invalid_uid_non_strict(tmp_path: Path) -> None: + """Unknown UID → placeholder in non-strict mode.""" + reg = _make_registry(tmp_path) + msg = '' + result = await render_message_with_attachments( + msg, registry=reg, scope_key="group:1", strict=False + ) + assert "不可用" in result.delivery_text + + +@pytest.mark.asyncio +async def test_invalid_uid_strict(tmp_path: Path) -> None: + """Unknown UID → exception in strict mode.""" + reg = _make_registry(tmp_path) + msg = '' + with pytest.raises(AttachmentRenderError, match="不可用"): + await render_message_with_attachments( + msg, registry=reg, scope_key="group:1", strict=True + ) + + +@pytest.mark.asyncio +async def test_file_tag_missing_local_path(tmp_path: Path) -> None: + """File with deleted local_path → error placeholder.""" + reg = _make_registry(tmp_path) + rec = await reg.register_bytes( + "group:1", _PDF_BYTES, kind="file", display_name="doc.pdf", source_kind="test" + ) + assert rec.local_path is not None + 
Path(rec.local_path).unlink() + + msg = f'' + result = await render_message_with_attachments( + msg, registry=reg, scope_key="group:1", strict=False + ) + assert "缺少本地文件" in result.delivery_text + assert len(result.pending_file_sends) == 0 + + +@pytest.mark.asyncio +async def test_no_tags_passthrough() -> None: + """Message without tags passes through unchanged.""" + result = await render_message_with_attachments( + "Hello world", registry=None, scope_key="group:1", strict=False + ) + assert result.delivery_text == "Hello world" + assert result.history_text == "Hello world" + assert result.attachments == [] + assert result.pending_file_sends == () + + +@pytest.mark.asyncio +async def test_rendered_rich_message_default_pending() -> None: + """RenderedRichMessage.pending_file_sends defaults to empty tuple.""" + msg = RenderedRichMessage(delivery_text="hi", history_text="hi", attachments=[]) + assert msg.pending_file_sends == () + + +# ---------- dispatch_pending_file_sends ---------- + + +@pytest.mark.asyncio +async def test_dispatch_pending_file_sends_group(tmp_path: Path) -> None: + """dispatch_pending_file_sends calls sender.send_group_file for group targets.""" + reg = _make_registry(tmp_path) + rec = await reg.register_bytes( + "group:1", _PDF_BYTES, kind="file", display_name="doc.pdf", source_kind="test" + ) + rendered = RenderedRichMessage( + delivery_text="text", + history_text="text", + attachments=[], + pending_file_sends=(rec,), + ) + + calls: list[tuple[Any, ...]] = [] + + class FakeSender: + async def send_group_file( + self, group_id: int, file_path: str, name: str | None = None + ) -> None: + calls.append(("group", group_id, file_path, name)) + + async def send_private_file( + self, user_id: int, file_path: str, name: str | None = None + ) -> None: + calls.append(("private", user_id, file_path, name)) + + await dispatch_pending_file_sends( + rendered, sender=FakeSender(), target_type="group", target_id=12345 + ) + assert len(calls) == 1 + assert 
calls[0][0] == "group" + assert calls[0][1] == 12345 + + +@pytest.mark.asyncio +async def test_dispatch_pending_file_sends_private(tmp_path: Path) -> None: + """dispatch_pending_file_sends calls sender.send_private_file for private targets.""" + reg = _make_registry(tmp_path) + rec = await reg.register_bytes( + "group:1", + _PDF_BYTES, + kind="file", + display_name="report.pdf", + source_kind="test", + ) + rendered = RenderedRichMessage( + delivery_text="text", + history_text="text", + attachments=[], + pending_file_sends=(rec,), + ) + + calls: list[tuple[Any, ...]] = [] + + class FakeSender: + async def send_group_file(self, *a: Any, **kw: Any) -> None: + calls.append(("group", *a)) + + async def send_private_file(self, *a: Any, **kw: Any) -> None: + calls.append(("private", *a)) + + await dispatch_pending_file_sends( + rendered, sender=FakeSender(), target_type="private", target_id=99999 + ) + assert len(calls) == 1 + assert calls[0][0] == "private" + assert calls[0][1] == 99999 + + +@pytest.mark.asyncio +async def test_dispatch_best_effort_on_failure(tmp_path: Path) -> None: + """File send failure doesn't propagate — best-effort.""" + reg = _make_registry(tmp_path) + rec = await reg.register_bytes( + "group:1", _PDF_BYTES, kind="file", display_name="doc.pdf", source_kind="test" + ) + rendered = RenderedRichMessage( + delivery_text="text", + history_text="text", + attachments=[], + pending_file_sends=(rec,), + ) + + class FailingSender: + async def send_group_file(self, *a: Any, **kw: Any) -> None: + raise RuntimeError("network error") + + async def send_private_file(self, *a: Any, **kw: Any) -> None: + raise RuntimeError("network error") + + # Should not raise + await dispatch_pending_file_sends( + rendered, sender=FailingSender(), target_type="group", target_id=1 + ) + + +@pytest.mark.asyncio +async def test_dispatch_no_pending_is_noop() -> None: + """No pending files → no calls.""" + rendered = RenderedRichMessage( + delivery_text="text", history_text="text", 
attachments=[] + ) + await dispatch_pending_file_sends( + rendered, sender=None, target_type="group", target_id=1 + ) diff --git a/tests/test_attachments_dedup.py b/tests/test_attachments_dedup.py new file mode 100644 index 0000000..c1114bf --- /dev/null +++ b/tests/test_attachments_dedup.py @@ -0,0 +1,147 @@ +"""Tests for attachment SHA-256 hash deduplication.""" + +from __future__ import annotations + +import asyncio +from pathlib import Path + +import pytest + +from Undefined.attachments import AttachmentRegistry + + +_PNG_BYTES = ( + b"\x89PNG\r\n\x1a\n" + b"\x00\x00\x00\rIHDR" + b"\x00\x00\x00\x01\x00\x00\x00\x01\x08\x02\x00\x00\x00" + b"\x90wS\xde" + b"\x00\x00\x00\x0cIDATx\x9cc``\x00\x00\x00\x02\x00\x01" + b"\x0b\xe7\x02\x9d" + b"\x00\x00\x00\x00IEND\xaeB`\x82" +) + +# Different content → different hash +_PNG_BYTES_ALT = _PNG_BYTES + b"\x00" + + +def _make_registry(tmp_path: Path) -> AttachmentRegistry: + return AttachmentRegistry( + registry_path=tmp_path / "reg.json", + cache_dir=tmp_path / "cache", + ) + + +@pytest.mark.asyncio +async def test_same_hash_same_scope_same_kind_returns_same_uid( + tmp_path: Path, +) -> None: + """Identical bytes + scope + kind → dedup returns same record.""" + reg = _make_registry(tmp_path) + r1 = await reg.register_bytes( + "group:1", _PNG_BYTES, kind="image", display_name="a.png", source_kind="test" + ) + r2 = await reg.register_bytes( + "group:1", _PNG_BYTES, kind="image", display_name="b.png", source_kind="test" + ) + assert r1.uid == r2.uid + assert r1.sha256 == r2.sha256 + + +@pytest.mark.asyncio +async def test_different_content_gets_different_uid(tmp_path: Path) -> None: + """Different bytes → different records even in same scope/kind.""" + reg = _make_registry(tmp_path) + r1 = await reg.register_bytes( + "group:1", _PNG_BYTES, kind="image", display_name="a.png", source_kind="test" + ) + r2 = await reg.register_bytes( + "group:1", + _PNG_BYTES_ALT, + kind="image", + display_name="b.png", + source_kind="test", + ) + 
assert r1.uid != r2.uid + + +@pytest.mark.asyncio +async def test_cross_scope_no_dedup(tmp_path: Path) -> None: + """Same hash but different scope → separate records (scope isolation).""" + reg = _make_registry(tmp_path) + r1 = await reg.register_bytes( + "group:1", _PNG_BYTES, kind="image", display_name="a.png", source_kind="test" + ) + r2 = await reg.register_bytes( + "group:2", _PNG_BYTES, kind="image", display_name="a.png", source_kind="test" + ) + assert r1.uid != r2.uid + + +@pytest.mark.asyncio +async def test_cross_kind_no_dedup(tmp_path: Path) -> None: + """Same hash + scope but different kind → separate records.""" + reg = _make_registry(tmp_path) + r1 = await reg.register_bytes( + "group:1", _PNG_BYTES, kind="image", display_name="a.png", source_kind="test" + ) + r2 = await reg.register_bytes( + "group:1", _PNG_BYTES, kind="file", display_name="a.bin", source_kind="test" + ) + assert r1.uid != r2.uid + assert r1.uid.startswith("pic_") + assert r2.uid.startswith("file_") + + +@pytest.mark.asyncio +async def test_file_deleted_causes_new_registration(tmp_path: Path) -> None: + """If the cached file is deleted, a new record is created.""" + reg = _make_registry(tmp_path) + r1 = await reg.register_bytes( + "group:1", _PNG_BYTES, kind="image", display_name="a.png", source_kind="test" + ) + assert r1.local_path is not None + Path(r1.local_path).unlink() + + r2 = await reg.register_bytes( + "group:1", _PNG_BYTES, kind="image", display_name="a.png", source_kind="test" + ) + assert r2.uid != r1.uid + assert r2.sha256 == r1.sha256 + + +@pytest.mark.asyncio +async def test_concurrent_identical_registrations(tmp_path: Path) -> None: + """Concurrent registrations with the same content should be safe.""" + reg = _make_registry(tmp_path) + + results = await asyncio.gather( + *( + reg.register_bytes( + "group:1", + _PNG_BYTES, + kind="image", + display_name="pic.png", + source_kind="test", + ) + for _ in range(5) + ) + ) + uids = {r.uid for r in results} + # All should 
resolve to the same UID (dedup) + assert len(uids) == 1 + + +@pytest.mark.asyncio +async def test_register_local_file_deduplicates(tmp_path: Path) -> None: + """register_local_file delegates to register_bytes, so dedup applies.""" + reg = _make_registry(tmp_path) + file_path = tmp_path / "input.png" + file_path.write_bytes(_PNG_BYTES) + + r1 = await reg.register_bytes( + "group:1", _PNG_BYTES, kind="image", display_name="a.png", source_kind="test" + ) + r2 = await reg.register_local_file( + "group:1", file_path, kind="image", display_name="input.png" + ) + assert r1.uid == r2.uid diff --git a/tests/test_coordinator_level.py b/tests/test_coordinator_level.py new file mode 100644 index 0000000..38ecaa6 --- /dev/null +++ b/tests/test_coordinator_level.py @@ -0,0 +1,155 @@ +from __future__ import annotations + +from unittest.mock import MagicMock + + +from Undefined.services.ai_coordinator import AICoordinator + + +def _make_coordinator() -> AICoordinator: + """创建用于测试的 AICoordinator 实例""" + config = MagicMock() + ai = MagicMock() + queue_manager = MagicMock() + history_manager = MagicMock() + sender = MagicMock() + onebot = MagicMock() + scheduler = MagicMock() + security = MagicMock() + + coordinator = AICoordinator( + config=config, + ai=ai, + queue_manager=queue_manager, + history_manager=history_manager, + sender=sender, + onebot=onebot, + scheduler=scheduler, + security=security, + ) + + return coordinator + + +def test_build_prompt_with_level_includes_level_attribute() -> None: + """测试 _build_prompt 带 level 参数时 XML 包含 level 属性""" + coordinator = _make_coordinator() + + result = coordinator._build_prompt( + prefix="", + name="测试用户", + uid=10001, + gid=20001, + gname="测试群", + loc="测试群", + role="member", + title="", + time_str="2026-04-11 10:00:00", + text="测试消息", + attachments=None, + message_id=123456, + level="Lv.5", + ) + + assert 'level="Lv.5"' in result + assert " None: + """测试 _build_prompt level 为空字符串时 XML 不包含 level 属性""" + coordinator = _make_coordinator() 
+ + result = coordinator._build_prompt( + prefix="", + name="测试用户", + uid=10001, + gid=20001, + gname="测试群", + loc="测试群", + role="member", + title="", + time_str="2026-04-11 10:00:00", + text="测试消息", + level="", + ) + + assert "level=" not in result + assert " None: + """测试 _build_prompt 不传 level 参数时 XML 不包含 level 属性""" + coordinator = _make_coordinator() + + result = coordinator._build_prompt( + prefix="", + name="测试用户", + uid=10001, + gid=20001, + gname="测试群", + loc="测试群", + role="member", + title="", + time_str="2026-04-11 10:00:00", + text="测试消息", + ) + + assert "level=" not in result + assert " None: + """测试 _build_prompt level 包含特殊字符时能正确转义""" + coordinator = _make_coordinator() + + result = coordinator._build_prompt( + prefix="", + name="测试用户", + uid=10001, + gid=20001, + gname="测试群", + loc="测试群", + role="member", + title="", + time_str="2026-04-11 10:00:00", + text="测试消息", + level='Lv.5 ', + ) + + assert "level=" in result + assert '' not in result + assert "<" in result or "&" in result or """ in result + + +def test_build_prompt_with_all_attributes() -> None: + """测试 _build_prompt 包含所有属性时的输出""" + coordinator = _make_coordinator() + + result = coordinator._build_prompt( + prefix="系统提示:\n", + name="测试用户", + uid=10001, + gid=20001, + gname="测试群", + loc="测试群", + role="admin", + title="管理员", + time_str="2026-04-11 10:00:00", + text="测试消息内容", + attachments=[{"uid": "pic_001", "kind": "image"}], + message_id=123456, + level="Lv.10", + ) + + assert "系统提示:\n" in result + assert 'sender="测试用户"' in result + assert 'sender_id="10001"' in result + assert 'group_id="20001"' in result + assert 'group_name="测试群"' in result + assert 'role="admin"' in result + assert 'title="管理员"' in result + assert 'level="Lv.10"' in result + assert 'message_id="123456"' in result + assert "测试消息内容" in result + assert "" in result diff --git a/tests/test_fetch_messages_tool.py b/tests/test_fetch_messages_tool.py new file mode 100644 index 0000000..ddba794 --- /dev/null +++ 
b/tests/test_fetch_messages_tool.py @@ -0,0 +1,466 @@ +from __future__ import annotations + +from datetime import datetime, timedelta +from typing import Any +from unittest.mock import MagicMock + +import pytest + +from Undefined.skills.agents.summary_agent.tools.fetch_messages.handler import ( + _filter_by_time, + _format_messages, + _parse_time_range, + execute as fetch_messages_execute, +) + + +# -- _parse_time_range unit tests -- + + +def test_parse_time_range_1h() -> None: + """'1h' → 3600.""" + assert _parse_time_range("1h") == 3600 + + +def test_parse_time_range_6h() -> None: + """'6h' → 21600.""" + assert _parse_time_range("6h") == 21600 + + +def test_parse_time_range_1d() -> None: + """'1d' → 86400.""" + assert _parse_time_range("1d") == 86400 + + +def test_parse_time_range_7d() -> None: + """'7d' → 604800.""" + assert _parse_time_range("7d") == 604800 + + +def test_parse_time_range_1w() -> None: + """'1w' → 604800.""" + assert _parse_time_range("1w") == 604800 + + +def test_parse_time_range_case_insensitive() -> None: + """'1H', '1D' → correct values.""" + assert _parse_time_range("1H") == 3600 + assert _parse_time_range("1D") == 86400 + assert _parse_time_range("1W") == 604800 + + +def test_parse_time_range_invalid() -> None: + """'invalid' → None.""" + assert _parse_time_range("invalid") is None + assert _parse_time_range("") is None + assert _parse_time_range("abc") is None + assert _parse_time_range("1x") is None + + +def test_parse_time_range_with_whitespace() -> None: + """' 1d ' → 86400 (strips whitespace).""" + assert _parse_time_range(" 1d ") == 86400 + + +def test_parse_time_range_multi_digit() -> None: + """'24h' → 86400.""" + assert _parse_time_range("24h") == 86400 + + +# -- _filter_by_time unit tests -- + + +def test_filter_by_time_keeps_recent() -> None: + """Messages within time range are kept.""" + now = datetime.now() + recent = (now - timedelta(seconds=1800)).strftime("%Y-%m-%d %H:%M:%S") + old = (now - 
timedelta(seconds=7200)).strftime("%Y-%m-%d %H:%M:%S") + + messages = [ + {"timestamp": recent, "message": "recent"}, + {"timestamp": old, "message": "old"}, + ] + + result = _filter_by_time(messages, 3600) # 1 hour + assert len(result) == 1 + assert result[0]["message"] == "recent" + + +def test_filter_by_time_removes_old() -> None: + """Messages outside time range are removed.""" + now = datetime.now() + old1 = (now - timedelta(days=2)).strftime("%Y-%m-%d %H:%M:%S") + old2 = (now - timedelta(days=3)).strftime("%Y-%m-%d %H:%M:%S") + + messages = [ + {"timestamp": old1, "message": "old1"}, + {"timestamp": old2, "message": "old2"}, + ] + + result = _filter_by_time(messages, 86400) # 1 day + assert len(result) == 0 + + +def test_filter_by_time_missing_timestamp() -> None: + """Messages without timestamp are filtered out.""" + messages = [ + {"message": "no timestamp"}, + {"timestamp": "", "message": "empty timestamp"}, + ] + + result = _filter_by_time(messages, 3600) + assert len(result) == 0 + + +def test_filter_by_time_invalid_timestamp() -> None: + """Messages with invalid timestamp format are filtered out.""" + messages = [ + {"timestamp": "invalid-format", "message": "bad format"}, + {"timestamp": "2024-13-45 99:99:99", "message": "impossible date"}, + ] + + result = _filter_by_time(messages, 3600) + assert len(result) == 0 + + +# -- _format_messages unit tests -- + + +def test_format_messages_basic() -> None: + """Basic formatting with timestamp, name, and message.""" + messages = [ + { + "timestamp": "2024-01-01 12:00:00", + "display_name": "Alice", + "message": "Hello", + "role": "", + "title": "", + }, + ] + + result = _format_messages(messages) + assert result == "[2024-01-01 12:00:00] Alice: Hello" + + +def test_format_messages_with_role() -> None: + """Role is included when not 'member' or empty.""" + messages = [ + { + "timestamp": "2024-01-01 12:00:00", + "display_name": "Bob", + "message": "Hi", + "role": "admin", + "title": "", + }, + ] + + result = 
_format_messages(messages) + assert result == "[2024-01-01 12:00:00] (admin) Bob: Hi" + + +def test_format_messages_with_title() -> None: + """Title is included when present.""" + messages = [ + { + "timestamp": "2024-01-01 12:00:00", + "display_name": "Charlie", + "message": "Test", + "role": "", + "title": "群主", + }, + ] + + result = _format_messages(messages) + assert result == "[2024-01-01 12:00:00] [群主] Charlie: Test" + + +def test_format_messages_with_title_and_role() -> None: + """Both title and role are included.""" + messages = [ + { + "timestamp": "2024-01-01 12:00:00", + "display_name": "Dave", + "message": "Message", + "role": "owner", + "title": "管理员", + }, + ] + + result = _format_messages(messages) + assert result == "[2024-01-01 12:00:00] [管理员] (owner) Dave: Message" + + +def test_format_messages_role_member_excluded() -> None: + """Role 'member' is not included.""" + messages = [ + { + "timestamp": "2024-01-01 12:00:00", + "display_name": "Eve", + "message": "Text", + "role": "member", + "title": "", + }, + ] + + result = _format_messages(messages) + assert result == "[2024-01-01 12:00:00] Eve: Text" + + +def test_format_messages_multiple() -> None: + """Multiple messages are separated by newlines.""" + messages = [ + { + "timestamp": "2024-01-01 12:00:00", + "display_name": "Alice", + "message": "First", + "role": "", + "title": "", + }, + { + "timestamp": "2024-01-01 12:01:00", + "display_name": "Bob", + "message": "Second", + "role": "admin", + "title": "", + }, + ] + + result = _format_messages(messages) + expected = ( + "[2024-01-01 12:00:00] Alice: First\n[2024-01-01 12:01:00] (admin) Bob: Second" + ) + assert result == expected + + +def test_format_messages_missing_fields() -> None: + """Missing fields default to empty or '未知用户'.""" + messages = [ + { + "timestamp": "", + "message": "No timestamp", + }, + ] + + result = _format_messages(messages) + assert "未知用户" in result + assert "No timestamp" in result + + +# -- execute function tests -- 
+ + +@pytest.mark.asyncio +async def test_fetch_messages_count_based_group() -> None: + """Count-based fetch in group context.""" + history_manager = MagicMock() + history_manager.get_recent.return_value = [ + { + "timestamp": "2024-01-01 12:00:00", + "display_name": "Alice", + "message": "Message 1", + "role": "", + "title": "", + }, + { + "timestamp": "2024-01-01 12:01:00", + "display_name": "Bob", + "message": "Message 2", + "role": "", + "title": "", + }, + ] + + context: dict[str, Any] = { + "history_manager": history_manager, + "group_id": 123456, + "user_id": 0, + } + + result = await fetch_messages_execute({"count": 50}, context) + + assert "共获取 2 条消息" in result + assert "Alice: Message 1" in result + assert "Bob: Message 2" in result + history_manager.get_recent.assert_called_once_with("123456", "group", 0, 50) + + +@pytest.mark.asyncio +async def test_fetch_messages_count_based_private() -> None: + """Count-based fetch in private context.""" + history_manager = MagicMock() + history_manager.get_recent.return_value = [ + { + "timestamp": "2024-01-01 12:00:00", + "display_name": "User", + "message": "Private message", + "role": "", + "title": "", + }, + ] + + context: dict[str, Any] = { + "history_manager": history_manager, + "group_id": 0, + "user_id": 99999, + } + + result = await fetch_messages_execute({"count": 20}, context) + + assert "共获取 1 条消息" in result + assert "Private message" in result + history_manager.get_recent.assert_called_once_with("99999", "private", 0, 20) + + +@pytest.mark.asyncio +async def test_fetch_messages_time_range() -> None: + """Time-range fetch filters by time.""" + now = datetime.now() + recent = (now - timedelta(hours=1)).strftime("%Y-%m-%d %H:%M:%S") + old = (now - timedelta(hours=25)).strftime("%Y-%m-%d %H:%M:%S") + + history_manager = MagicMock() + history_manager.get_recent.return_value = [ + { + "timestamp": recent, + "display_name": "Alice", + "message": "Recent", + "role": "", + "title": "", + }, + { + "timestamp": 
old, + "display_name": "Bob", + "message": "Old", + "role": "", + "title": "", + }, + ] + + context: dict[str, Any] = { + "history_manager": history_manager, + "group_id": 123456, + "user_id": 0, + } + + result = await fetch_messages_execute( + {"count": 50, "time_range": "1d"}, + context, + ) + + assert "共获取 1 条消息" in result + assert "(时间范围: 1d)" in result + assert "Recent" in result + assert "Old" not in result + + +@pytest.mark.asyncio +async def test_fetch_messages_invalid_time_range() -> None: + """Invalid time range returns error message.""" + context: dict[str, Any] = { + "history_manager": MagicMock(), + "group_id": 123456, + } + + result = await fetch_messages_execute( + {"time_range": "invalid"}, + context, + ) + + assert "无法解析时间范围: invalid" in result + assert "支持格式: 1h, 6h, 1d, 7d" in result + + +@pytest.mark.asyncio +async def test_fetch_messages_empty_history() -> None: + """Empty history returns '当前会话暂无消息记录'.""" + history_manager = MagicMock() + history_manager.get_recent.return_value = [] + + context: dict[str, Any] = { + "history_manager": history_manager, + "group_id": 123456, + } + + result = await fetch_messages_execute({}, context) + + assert "当前会话暂无消息记录" in result + + +@pytest.mark.asyncio +async def test_fetch_messages_no_history_manager() -> None: + """No history_manager returns error.""" + context: dict[str, Any] = { + "group_id": 123456, + } + + result = await fetch_messages_execute({}, context) + + assert "历史记录管理器未配置" in result + + +@pytest.mark.asyncio +async def test_fetch_messages_count_capped_at_500() -> None: + """Count is capped at 500.""" + history_manager = MagicMock() + history_manager.get_recent.return_value = [] + + context: dict[str, Any] = { + "history_manager": history_manager, + "group_id": 123456, + } + + await fetch_messages_execute({"count": 9999}, context) + + history_manager.get_recent.assert_called_once_with("123456", "group", 0, 500) + + +@pytest.mark.asyncio +async def test_fetch_messages_default_count() -> None: + 
"""Default count is 50 when not specified.""" + history_manager = MagicMock() + history_manager.get_recent.return_value = [] + + context: dict[str, Any] = { + "history_manager": history_manager, + "group_id": 123456, + } + + await fetch_messages_execute({}, context) + + history_manager.get_recent.assert_called_once_with("123456", "group", 0, 50) + + +@pytest.mark.asyncio +async def test_fetch_messages_invalid_count_defaults() -> None: + """Invalid count defaults to 50.""" + history_manager = MagicMock() + history_manager.get_recent.return_value = [] + + context: dict[str, Any] = { + "history_manager": history_manager, + "group_id": 123456, + } + + await fetch_messages_execute({"count": "invalid"}, context) + + history_manager.get_recent.assert_called_once_with("123456", "group", 0, 50) + + +@pytest.mark.asyncio +async def test_fetch_messages_time_range_fetch_larger_batch() -> None: + """Time range mode fetches larger batch (max(count*2, 2000)).""" + history_manager = MagicMock() + history_manager.get_recent.return_value = [] + + context: dict[str, Any] = { + "history_manager": history_manager, + "group_id": 123456, + } + + await fetch_messages_execute( + {"count": 50, "time_range": "1d"}, + context, + ) + + # max(50*2, 2000) = 2000 + history_manager.get_recent.assert_called_once_with("123456", "group", 0, 2000) diff --git a/tests/test_handlers_meme_annotation.py b/tests/test_handlers_meme_annotation.py new file mode 100644 index 0000000..a8ff4d6 --- /dev/null +++ b/tests/test_handlers_meme_annotation.py @@ -0,0 +1,277 @@ +"""测试 handlers.py 中的表情包自动匹配功能""" + +from __future__ import annotations + +from pathlib import Path +from typing import Any +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from Undefined.attachments import AttachmentRecord, AttachmentRegistry + + +@pytest.mark.asyncio +async def test_annotate_meme_descriptions_success(tmp_path: Path) -> None: + """测试成功匹配表情包时添加描述""" + # 创建 mock handler + handler = MagicMock() + + # 设置 attachment 
registry + registry_path = tmp_path / "registry.json" + cache_dir = tmp_path / "cache" + attachment_registry = AttachmentRegistry( + registry_path=registry_path, cache_dir=cache_dir + ) + + # 添加一个附件记录 + test_sha256 = "abc123def456" + attachment_record = AttachmentRecord( + uid="pic_001", + scope_key="group:10001", + kind="image", + media_type="image/png", + display_name="test.png", + source_kind="test", + source_ref="", + local_path=None, + mime_type="image/png", + sha256=test_sha256, + created_at="2024-01-01T00:00:00Z", + segment_data={}, + ) + attachment_registry._records["pic_001"] = attachment_record + + # 设置 meme service mock + mock_meme_record = MagicMock() + mock_meme_record.description = "可爱的猫猫" + + mock_meme_store = MagicMock() + mock_meme_store.find_by_sha256 = AsyncMock(return_value=mock_meme_record) + + mock_meme_service = MagicMock() + mock_meme_service.enabled = True + mock_meme_service._store = mock_meme_store + + ai = MagicMock() + ai.attachment_registry = attachment_registry + ai._meme_service = mock_meme_service + + handler.ai = ai + + # 导入实现 + from Undefined.handlers import MessageHandler + + # 测试输入 + input_attachments = [ + {"uid": "pic_001", "kind": "image"}, + {"uid": "file_002", "kind": "file"}, + ] + + # 调用函数 + result = await MessageHandler._annotate_meme_descriptions( + handler, input_attachments, "group:10001" + ) + + # 验证结果 + assert len(result) == 2 + # 第一个附件应该有表情包描述 + assert result[0]["uid"] == "pic_001" + assert result[0]["description"] == "[表情包] 可爱的猫猫" + # 第二个附件不应该被修改 + assert result[1]["uid"] == "file_002" + assert "description" not in result[1] + + +@pytest.mark.asyncio +async def test_annotate_meme_descriptions_no_match(tmp_path: Path) -> None: + """测试没有匹配表情包时返回原始列表""" + handler = MagicMock() + + registry_path = tmp_path / "registry.json" + cache_dir = tmp_path / "cache" + attachment_registry = AttachmentRegistry( + registry_path=registry_path, cache_dir=cache_dir + ) + + test_sha256 = "xyz789" + attachment_record = 
AttachmentRecord( + uid="pic_001", + scope_key="group:10001", + kind="image", + media_type="image/png", + display_name="test.png", + source_kind="test", + source_ref="", + local_path=None, + mime_type="image/png", + sha256=test_sha256, + created_at="2024-01-01T00:00:00Z", + segment_data={}, + ) + attachment_registry._records["pic_001"] = attachment_record + + # meme store 返回 None(没找到) + mock_meme_store = MagicMock() + mock_meme_store.find_by_sha256 = AsyncMock(return_value=None) + + mock_meme_service = MagicMock() + mock_meme_service.enabled = True + mock_meme_service._store = mock_meme_store + + ai = MagicMock() + ai.attachment_registry = attachment_registry + ai._meme_service = mock_meme_service + + handler.ai = ai + + from Undefined.handlers import MessageHandler + + input_attachments = [{"uid": "pic_001", "kind": "image"}] + + result = await MessageHandler._annotate_meme_descriptions( + handler, input_attachments, "group:10001" + ) + + # 应该返回原始列表 + assert len(result) == 1 + assert result[0]["uid"] == "pic_001" + assert "description" not in result[0] + + +@pytest.mark.asyncio +async def test_annotate_meme_descriptions_meme_disabled() -> None: + """测试 meme service 禁用时返回原始列表""" + handler = MagicMock() + + # meme service 禁用 + mock_meme_service = MagicMock() + mock_meme_service.enabled = False + + ai = MagicMock() + ai._meme_service = mock_meme_service + + handler.ai = ai + + from Undefined.handlers import MessageHandler + + input_attachments = [{"uid": "pic_001", "kind": "image"}] + + result = await MessageHandler._annotate_meme_descriptions( + handler, input_attachments, "group:10001" + ) + + # 应该返回原始列表 + assert result == input_attachments + + +@pytest.mark.asyncio +async def test_annotate_meme_descriptions_error_handling() -> None: + """测试异常处理:失败时返回原始列表""" + handler = MagicMock() + + # 设置会抛出异常的 attachment registry + mock_attachment_registry = MagicMock() + mock_attachment_registry.resolve = MagicMock( + side_effect=Exception("Registry error") + ) + + # 设置 meme 
service + mock_meme_store = MagicMock() + mock_meme_store.find_by_sha256 = AsyncMock(side_effect=Exception("Database error")) + + mock_meme_service = MagicMock() + mock_meme_service.enabled = True + mock_meme_service._store = mock_meme_store + + ai = MagicMock() + ai.attachment_registry = mock_attachment_registry + ai._meme_service = mock_meme_service + + handler.ai = ai + + from Undefined.handlers import MessageHandler + + input_attachments = [{"uid": "pic_001", "kind": "image"}] + + result = await MessageHandler._annotate_meme_descriptions( + handler, input_attachments, "group:10001" + ) + + # 应该返回原始列表(异常被捕获) + assert result == input_attachments + + +@pytest.mark.asyncio +async def test_annotate_meme_descriptions_batch_query(tmp_path: Path) -> None: + """测试批量查询:多个附件共享同一个哈希值""" + handler = MagicMock() + + registry_path = tmp_path / "registry.json" + cache_dir = tmp_path / "cache" + attachment_registry = AttachmentRegistry( + registry_path=registry_path, cache_dir=cache_dir + ) + + # 两个附件,相同的 SHA256 + shared_sha256 = "shared123" + for uid in ["pic_001", "pic_002"]: + record = AttachmentRecord( + uid=uid, + scope_key="group:10001", + kind="image", + media_type="image/png", + display_name=f"{uid}.png", + source_kind="test", + source_ref="", + local_path=None, + mime_type="image/png", + sha256=shared_sha256, + created_at="2024-01-01T00:00:00Z", + segment_data={}, + ) + attachment_registry._records[uid] = record + + # 记录 find_by_sha256 被调用的次数 + call_count = 0 + + async def mock_find_by_sha256(sha: str) -> Any: + nonlocal call_count + call_count += 1 + if sha == shared_sha256: + meme = MagicMock() + meme.description = "共享表情包" + return meme + return None + + mock_meme_store = MagicMock() + mock_meme_store.find_by_sha256 = mock_find_by_sha256 + + mock_meme_service = MagicMock() + mock_meme_service.enabled = True + mock_meme_service._store = mock_meme_store + + ai = MagicMock() + ai.attachment_registry = attachment_registry + ai._meme_service = mock_meme_service + + 
handler.ai = ai + + from Undefined.handlers import MessageHandler + + input_attachments = [ + {"uid": "pic_001", "kind": "image"}, + {"uid": "pic_002", "kind": "image"}, + ] + + result = await MessageHandler._annotate_meme_descriptions( + handler, input_attachments, "group:10001" + ) + + # 验证结果 + assert len(result) == 2 + assert result[0]["description"] == "[表情包] 共享表情包" + assert result[1]["description"] == "[表情包] 共享表情包" + + # 验证 find_by_sha256 只被调用一次(批量查询去重) + assert call_count == 1 diff --git a/tests/test_history_level.py b/tests/test_history_level.py new file mode 100644 index 0000000..59a88d4 --- /dev/null +++ b/tests/test_history_level.py @@ -0,0 +1,174 @@ +from __future__ import annotations + +import asyncio + +import pytest + +from Undefined.utils.history import MessageHistoryManager + + +@pytest.mark.asyncio +async def test_add_group_message_with_level_stores_level_field( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """测试添加带 level 的群消息会正确存储 level 字段""" + manager = MessageHistoryManager.__new__(MessageHistoryManager) + manager._message_history = {} + manager._max_records = 10000 + manager._initialized = asyncio.Event() + manager._initialized.set() + manager._group_locks = {} + + saved_data: dict[str, list[dict[str, object]]] = {} + + async def fake_save(data: list[dict[str, object]], path: str) -> None: + saved_data[path] = data + + monkeypatch.setattr(manager, "_save_history_to_file", fake_save) + + await manager.add_group_message( + group_id=20001, + sender_id=10001, + text_content="测试消息", + sender_card="测试用户", + group_name="测试群", + role="member", + title="", + level="Lv.5", + message_id=123456, + ) + + assert "20001" in manager._message_history + assert len(manager._message_history["20001"]) == 1 + + record = manager._message_history["20001"][0] + assert record["level"] == "Lv.5" + assert record["type"] == "group" + assert record["chat_id"] == "20001" + assert record["user_id"] == "10001" + assert record["message"] == "测试消息" + assert record["role"] == 
"member" + assert record["title"] == "" + assert record["message_id"] == 123456 + + +@pytest.mark.asyncio +async def test_add_group_message_without_level_stores_empty_level( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """测试添加群消息不传 level 参数时默认存储空字符串""" + manager = MessageHistoryManager.__new__(MessageHistoryManager) + manager._message_history = {} + manager._max_records = 10000 + manager._initialized = asyncio.Event() + manager._initialized.set() + manager._group_locks = {} + + saved_data: dict[str, list[dict[str, object]]] = {} + + async def fake_save(data: list[dict[str, object]], path: str) -> None: + saved_data[path] = data + + monkeypatch.setattr(manager, "_save_history_to_file", fake_save) + + await manager.add_group_message( + group_id=20001, + sender_id=10001, + text_content="测试消息", + sender_card="测试用户", + group_name="测试群", + ) + + assert "20001" in manager._message_history + record = manager._message_history["20001"][0] + assert record["level"] == "" + + +@pytest.mark.asyncio +async def test_get_recent_returns_messages_with_level_intact( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """测试 get_recent 能正确返回带 level 的消息""" + manager = MessageHistoryManager.__new__(MessageHistoryManager) + manager._message_history = { + "20001": [ + { + "type": "group", + "chat_id": "20001", + "chat_name": "测试群", + "user_id": "10001", + "display_name": "测试用户", + "role": "admin", + "title": "管理员", + "level": "Lv.10", + "timestamp": "2026-04-11 10:00:00", + "message": "第一条消息", + }, + { + "type": "group", + "chat_id": "20001", + "chat_name": "测试群", + "user_id": "10002", + "display_name": "普通用户", + "role": "member", + "title": "", + "level": "Lv.2", + "timestamp": "2026-04-11 10:01:00", + "message": "第二条消息", + }, + { + "type": "group", + "chat_id": "20001", + "chat_name": "测试群", + "user_id": "10003", + "display_name": "新用户", + "role": "member", + "title": "", + "level": "", + "timestamp": "2026-04-11 10:02:00", + "message": "第三条消息", + }, + ] + } + manager._max_records = 
10000 + _evt = asyncio.Event() + _evt.set() + manager._initialized = _evt + + messages = manager.get_recent("20001", "group", 0, 10) + + assert len(messages) == 3 + assert messages[0]["level"] == "Lv.10" + assert messages[1]["level"] == "Lv.2" + assert messages[2]["level"] == "" + + +@pytest.mark.asyncio +async def test_add_group_message_with_empty_level_string( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """测试添加群消息显式传入空字符串 level""" + manager = MessageHistoryManager.__new__(MessageHistoryManager) + manager._message_history = {} + manager._max_records = 10000 + manager._initialized = asyncio.Event() + manager._initialized.set() + manager._group_locks = {} + + saved_data: dict[str, list[dict[str, object]]] = {} + + async def fake_save(data: list[dict[str, object]], path: str) -> None: + saved_data[path] = data + + monkeypatch.setattr(manager, "_save_history_to_file", fake_save) + + await manager.add_group_message( + group_id=20001, + sender_id=10001, + text_content="测试消息", + level="", + ) + + record = manager._message_history["20001"][0] + assert "level" in record + assert record["level"] == "" diff --git a/tests/test_message_tools_level.py b/tests/test_message_tools_level.py new file mode 100644 index 0000000..8199502 --- /dev/null +++ b/tests/test_message_tools_level.py @@ -0,0 +1,161 @@ +from __future__ import annotations + +from typing import Any + + +from Undefined.skills.toolsets.messages.get_recent_messages.handler import ( + _format_message_xml, +) + + +def test_format_message_xml_group_with_all_attributes() -> None: + """测试群消息包含 role/title/level 时 XML 全部显示""" + msg: dict[str, Any] = { + "type": "group", + "display_name": "测试用户", + "user_id": "10001", + "chat_id": "20001", + "chat_name": "测试群", + "timestamp": "2026-04-11 10:00:00", + "message": "测试消息", + "message_id": 123456, + "role": "admin", + "title": "管理员", + "level": "Lv.10", + } + + result = _format_message_xml(msg) + + assert 'role="admin"' in result + assert 'title="管理员"' in result + assert 
'level="Lv.10"' in result + assert 'message_id="123456"' in result + assert "测试消息" in result + + +def test_format_message_xml_group_with_empty_level() -> None: + """测试群消息 level 为空时 XML 不包含 level 属性""" + msg: dict[str, Any] = { + "type": "group", + "display_name": "测试用户", + "user_id": "10001", + "chat_id": "20001", + "chat_name": "测试群", + "timestamp": "2026-04-11 10:00:00", + "message": "测试消息", + "role": "member", + "title": "", + "level": "", + } + + result = _format_message_xml(msg) + + assert "level=" not in result + assert 'role="member"' in result + assert "title=" not in result + + +def test_format_message_xml_private_without_level() -> None: + """测试私聊消息不包含 role/title/level 属性""" + msg: dict[str, Any] = { + "type": "private", + "display_name": "测试用户", + "user_id": "10001", + "chat_id": "10001", + "chat_name": "QQ用户10001", + "timestamp": "2026-04-11 10:00:00", + "message": "私聊消息", + } + + result = _format_message_xml(msg) + + assert "role=" not in result + assert "title=" not in result + assert "level=" not in result + assert "私聊消息" in result + + +def test_format_message_xml_group_with_only_level() -> None: + """测试群消息只设置 level 时仅显示 level 属性""" + msg: dict[str, Any] = { + "type": "group", + "display_name": "测试用户", + "user_id": "10001", + "chat_id": "20001", + "chat_name": "测试群", + "timestamp": "2026-04-11 10:00:00", + "message": "测试消息", + "role": "", + "title": "", + "level": "Lv.5", + } + + result = _format_message_xml(msg) + + assert 'level="Lv.5"' in result + assert "role=" not in result + assert "title=" not in result + + +def test_format_message_xml_group_without_level_key() -> None: + """测试群消息没有 level 键时不显示 level 属性""" + msg: dict[str, Any] = { + "type": "group", + "display_name": "测试用户", + "user_id": "10001", + "chat_id": "20001", + "chat_name": "测试群", + "timestamp": "2026-04-11 10:00:00", + "message": "测试消息", + "role": "member", + "title": "", + } + + result = _format_message_xml(msg) + + assert "level=" not in result + assert 'role="member"' in result + 
+ +def test_format_message_xml_group_with_role_and_title() -> None: + """测试群消息有 role 和 title 但无 level""" + msg: dict[str, Any] = { + "type": "group", + "display_name": "测试用户", + "user_id": "10001", + "chat_id": "20001", + "chat_name": "测试群", + "timestamp": "2026-04-11 10:00:00", + "message": "测试消息", + "role": "admin", + "title": "群主", + "level": "", + } + + result = _format_message_xml(msg) + + assert 'role="admin"' in result + assert 'title="群主"' in result + assert "level=" not in result + + +def test_format_message_xml_private_with_level_ignored() -> None: + """测试私聊消息即使有 level 也不显示""" + msg: dict[str, Any] = { + "type": "private", + "display_name": "测试用户", + "user_id": "10001", + "chat_id": "10001", + "chat_name": "QQ用户10001", + "timestamp": "2026-04-11 10:00:00", + "message": "私聊消息", + "role": "member", + "title": "测试", + "level": "Lv.99", + } + + result = _format_message_xml(msg) + + assert "role=" not in result + assert "title=" not in result + assert "level=" not in result diff --git a/tests/test_profile_command.py b/tests/test_profile_command.py new file mode 100644 index 0000000..c4fd024 --- /dev/null +++ b/tests/test_profile_command.py @@ -0,0 +1,252 @@ +from __future__ import annotations + +from types import SimpleNamespace +from typing import Any, cast +from unittest.mock import AsyncMock + +import pytest + +from Undefined.services.commands.context import CommandContext +from Undefined.skills.commands.profile.handler import execute as profile_execute + + +class _DummySender: + def __init__(self) -> None: + self.group_messages: list[tuple[int, str]] = [] + self.private_messages: list[tuple[int, str]] = [] + + async def send_group_message( + self, group_id: int, message: str, mark_sent: bool = False + ) -> None: + _ = mark_sent + self.group_messages.append((group_id, message)) + + async def send_private_message( + self, + user_id: int, + message: str, + auto_history: bool = True, + *, + mark_sent: bool = True, + ) -> None: + _ = (auto_history, mark_sent) + 
self.private_messages.append((user_id, message)) + + +def _build_context( + *, + sender: _DummySender | None = None, + cognitive_service: Any = None, + scope: str = "group", + group_id: int = 123456, + sender_id: int = 10002, + user_id: int | None = None, +) -> CommandContext: + stub = cast(Any, SimpleNamespace()) + if sender is None: + sender = _DummySender() + return CommandContext( + group_id=group_id, + sender_id=sender_id, + config=stub, + sender=cast(Any, sender), + ai=stub, + faq_storage=stub, + onebot=stub, + security=stub, + queue_manager=None, + rate_limiter=None, + dispatcher=stub, + registry=stub, + scope=scope, + user_id=user_id, + cognitive_service=cognitive_service, + ) + + +# -- Private chat tests -- + + +@pytest.mark.asyncio +async def test_profile_private_own_profile_found() -> None: + """Private chat, own profile found → sends profile via send_private_message.""" + sender = _DummySender() + cognitive_service = AsyncMock() + cognitive_service.get_profile = AsyncMock(return_value="这是一个用户侧写") + + context = _build_context( + sender=sender, + cognitive_service=cognitive_service, + scope="private", + group_id=0, + sender_id=99999, + user_id=99999, + ) + + await profile_execute([], context) + + assert len(sender.private_messages) == 1 + assert sender.private_messages[0][0] == 99999 + assert "这是一个用户侧写" in sender.private_messages[0][1] + cognitive_service.get_profile.assert_called_once_with("users", "99999") + + +@pytest.mark.asyncio +async def test_profile_private_own_profile_not_found() -> None: + """Private chat, own profile not found → sends '暂无侧写数据'.""" + sender = _DummySender() + cognitive_service = AsyncMock() + cognitive_service.get_profile = AsyncMock(return_value="") + + context = _build_context( + sender=sender, + cognitive_service=cognitive_service, + scope="private", + group_id=0, + sender_id=88888, + user_id=88888, + ) + + await profile_execute([], context) + + assert len(sender.private_messages) == 1 + assert "📭 暂无侧写数据" in 
sender.private_messages[0][1] + + +@pytest.mark.asyncio +async def test_profile_private_group_subcommand_rejected() -> None: + """Private chat, `/profile group` rejected → sends error message.""" + sender = _DummySender() + cognitive_service = AsyncMock() + + context = _build_context( + sender=sender, + cognitive_service=cognitive_service, + scope="private", + group_id=0, + sender_id=77777, + user_id=77777, + ) + + await profile_execute(["group"], context) + + assert len(sender.private_messages) == 1 + assert "❌ 私聊中不支持查看群聊侧写" in sender.private_messages[0][1] + cognitive_service.get_profile.assert_not_called() + + +# -- Group chat tests -- + + +@pytest.mark.asyncio +async def test_profile_group_own_profile() -> None: + """Group chat, own profile → sends profile via send_group_message.""" + sender = _DummySender() + cognitive_service = AsyncMock() + cognitive_service.get_profile = AsyncMock(return_value="群成员侧写数据") + + context = _build_context( + sender=sender, + cognitive_service=cognitive_service, + scope="group", + group_id=123456, + sender_id=55555, + ) + + await profile_execute([], context) + + assert len(sender.group_messages) == 1 + assert sender.group_messages[0][0] == 123456 + assert "群成员侧写数据" in sender.group_messages[0][1] + cognitive_service.get_profile.assert_called_once_with("users", "55555") + + +@pytest.mark.asyncio +async def test_profile_group_profile_subcommand() -> None: + """Group chat, `/profile group` → sends group profile via send_group_message.""" + sender = _DummySender() + cognitive_service = AsyncMock() + cognitive_service.get_profile = AsyncMock(return_value="群聊整体侧写") + + context = _build_context( + sender=sender, + cognitive_service=cognitive_service, + scope="group", + group_id=654321, + sender_id=44444, + ) + + await profile_execute(["GROUP"], context) # Test case-insensitive + + assert len(sender.group_messages) == 1 + assert sender.group_messages[0][0] == 654321 + assert "群聊整体侧写" in sender.group_messages[0][1] + 
cognitive_service.get_profile.assert_called_once_with("groups", "654321") + + +@pytest.mark.asyncio +async def test_profile_group_profile_not_found() -> None: + """Group chat, group profile not found → sends '暂无群聊侧写数据'.""" + sender = _DummySender() + cognitive_service = AsyncMock() + cognitive_service.get_profile = AsyncMock(return_value=None) + + context = _build_context( + sender=sender, + cognitive_service=cognitive_service, + scope="group", + group_id=111111, + sender_id=33333, + ) + + await profile_execute(["group"], context) + + assert len(sender.group_messages) == 1 + assert "📭 暂无群聊侧写数据" in sender.group_messages[0][1] + + +# -- Edge cases -- + + +@pytest.mark.asyncio +async def test_profile_no_cognitive_service() -> None: + """No cognitive_service → sends '侧写服务未启用'.""" + sender = _DummySender() + + context = _build_context( + sender=sender, + cognitive_service=None, + scope="group", + group_id=123456, + sender_id=22222, + ) + + await profile_execute([], context) + + assert len(sender.group_messages) == 1 + assert "❌ 侧写服务未启用" in sender.group_messages[0][1] + + +@pytest.mark.asyncio +async def test_profile_truncation() -> None: + """Profile > 3000 chars gets truncated.""" + sender = _DummySender() + cognitive_service = AsyncMock() + long_profile = "A" * 3500 # Longer than 3000 chars + cognitive_service.get_profile = AsyncMock(return_value=long_profile) + + context = _build_context( + sender=sender, + cognitive_service=cognitive_service, + scope="group", + group_id=123456, + sender_id=11111, + ) + + await profile_execute([], context) + + assert len(sender.group_messages) == 1 + message = sender.group_messages[0][1] + assert len(message) <= 3100 # 3000 + truncation notice + assert "[侧写过长,已截断]" in message + assert message.count("A") == 3000 # Exactly 3000 'A's before truncation diff --git a/tests/test_prompts_level.py b/tests/test_prompts_level.py new file mode 100644 index 0000000..239786a --- /dev/null +++ b/tests/test_prompts_level.py @@ -0,0 +1,285 @@ +from 
__future__ import annotations + +from types import SimpleNamespace +from typing import Any, cast + +import pytest + +from Undefined.ai.prompts import PromptBuilder +from Undefined.end_summary_storage import EndSummaryRecord + + +class _FakeEndSummaryStorage: + async def load(self) -> list[EndSummaryRecord]: + return [] + + +class _FakeCognitiveService: + enabled = False + + +class _FakeAnthropicSkillRegistry: + def has_skills(self) -> bool: + return False + + +def _make_builder() -> PromptBuilder: + """创建用于测试的 PromptBuilder 实例""" + runtime_config = SimpleNamespace( + keyword_reply_enabled=False, + knowledge_enabled=False, + grok_search_enabled=False, + chat_model=SimpleNamespace( + model_name="gpt-4.1", + pool=SimpleNamespace(enabled=False), + thinking_enabled=False, + reasoning_enabled=False, + ), + vision_model=SimpleNamespace(model_name="gpt-4.1-mini"), + agent_model=SimpleNamespace(model_name="gpt-4.1-mini"), + embedding_model=SimpleNamespace(model_name="text-embedding-3-small"), + security_model=SimpleNamespace(model_name="gpt-4.1-mini"), + grok_model=SimpleNamespace(model_name="grok-4-search"), + cognitive=SimpleNamespace(enabled=False, recent_end_summaries_inject_k=0), + memes=SimpleNamespace( + enabled=False, + query_default_mode="hybrid", + allow_gif=False, + max_source_image_bytes=512000, + ), + ) + return PromptBuilder( + bot_qq=123456, + memory_storage=None, + end_summary_storage=cast(Any, _FakeEndSummaryStorage()), + runtime_config_getter=lambda: runtime_config, + anthropic_skill_registry=cast(Any, _FakeAnthropicSkillRegistry()), + cognitive_service=cast(Any, _FakeCognitiveService()), + ) + + +@pytest.mark.asyncio +async def test_group_message_with_level_includes_level_attribute( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """测试群消息有 level 时 XML 包含 level 属性""" + builder = _make_builder() + + async def _fake_load_system_prompt() -> str: + return "系统提示词" + + async def _fake_load_each_rules() -> str: + return "规则" + + monkeypatch.setattr(builder, 
"_load_system_prompt", _fake_load_system_prompt) + monkeypatch.setattr(builder, "_load_each_rules", _fake_load_each_rules) + + async def _fake_recent_messages( + chat_id: str, msg_type: str, start: int, end: int + ) -> list[dict[str, Any]]: + _ = chat_id, msg_type, start, end + return [ + { + "type": "group", + "display_name": "测试用户", + "user_id": "10001", + "chat_id": "20001", + "chat_name": "测试群", + "timestamp": "2026-04-11 10:00:00", + "message": "测试消息", + "attachments": [], + "role": "member", + "title": "", + "level": "Lv.5", + } + ] + + messages = await builder.build_messages( + '测试', + get_recent_messages_callback=_fake_recent_messages, + extra_context={ + "group_id": 20001, + "sender_id": 10001, + "sender_name": "测试用户", + "group_name": "测试群", + "request_type": "group", + }, + ) + + history_message = next( + str(msg.get("content", "")) + for msg in messages + if "【历史消息存档】" in str(msg.get("content", "")) + ) + + assert 'level="Lv.5"' in history_message + assert " None: + """测试群消息 level 为空字符串时 XML 不包含 level 属性""" + builder = _make_builder() + + async def _fake_load_system_prompt() -> str: + return "系统提示词" + + async def _fake_load_each_rules() -> str: + return "规则" + + monkeypatch.setattr(builder, "_load_system_prompt", _fake_load_system_prompt) + monkeypatch.setattr(builder, "_load_each_rules", _fake_load_each_rules) + + async def _fake_recent_messages( + chat_id: str, msg_type: str, start: int, end: int + ) -> list[dict[str, Any]]: + _ = chat_id, msg_type, start, end + return [ + { + "type": "group", + "display_name": "测试用户", + "user_id": "10001", + "chat_id": "20001", + "chat_name": "测试群", + "timestamp": "2026-04-11 10:00:00", + "message": "测试消息", + "attachments": [], + "role": "member", + "title": "", + "level": "", + } + ] + + messages = await builder.build_messages( + '测试', + get_recent_messages_callback=_fake_recent_messages, + extra_context={ + "group_id": 20001, + "sender_id": 10001, + "sender_name": "测试用户", + "group_name": "测试群", + "request_type": 
"group", + }, + ) + + history_message = next( + str(msg.get("content", "")) + for msg in messages + if "【历史消息存档】" in str(msg.get("content", "")) + ) + + assert "level=" not in history_message + assert " None: + """测试群消息没有 level 键时 XML 不包含 level 属性""" + builder = _make_builder() + + async def _fake_load_system_prompt() -> str: + return "系统提示词" + + async def _fake_load_each_rules() -> str: + return "规则" + + monkeypatch.setattr(builder, "_load_system_prompt", _fake_load_system_prompt) + monkeypatch.setattr(builder, "_load_each_rules", _fake_load_each_rules) + + async def _fake_recent_messages( + chat_id: str, msg_type: str, start: int, end: int + ) -> list[dict[str, Any]]: + _ = chat_id, msg_type, start, end + return [ + { + "type": "group", + "display_name": "测试用户", + "user_id": "10001", + "chat_id": "20001", + "chat_name": "测试群", + "timestamp": "2026-04-11 10:00:00", + "message": "测试消息", + "attachments": [], + "role": "member", + "title": "", + } + ] + + messages = await builder.build_messages( + '测试', + get_recent_messages_callback=_fake_recent_messages, + extra_context={ + "group_id": 20001, + "sender_id": 10001, + "sender_name": "测试用户", + "group_name": "测试群", + "request_type": "group", + }, + ) + + history_message = next( + str(msg.get("content", "")) + for msg in messages + if "【历史消息存档】" in str(msg.get("content", "")) + ) + + assert "level=" not in history_message + assert " None: + """测试私聊消息无论是否有 level 都不会出现 level 属性""" + builder = _make_builder() + + async def _fake_load_system_prompt() -> str: + return "系统提示词" + + async def _fake_load_each_rules() -> str: + return "规则" + + monkeypatch.setattr(builder, "_load_system_prompt", _fake_load_system_prompt) + monkeypatch.setattr(builder, "_load_each_rules", _fake_load_each_rules) + + async def _fake_recent_messages( + chat_id: str, msg_type: str, start: int, end: int + ) -> list[dict[str, Any]]: + _ = chat_id, msg_type, start, end + return [ + { + "type": "private", + "display_name": "测试用户", + "user_id": "10001", + 
"chat_id": "10001", + "chat_name": "QQ用户10001", + "timestamp": "2026-04-11 10:00:00", + "message": "私聊测试消息", + "attachments": [], + "level": "Lv.10", + } + ] + + messages = await builder.build_messages( + '测试', + get_recent_messages_callback=_fake_recent_messages, + extra_context={ + "sender_id": 10001, + "sender_name": "测试用户", + "request_type": "private", + }, + ) + + history_message = next( + str(msg.get("content", "")) + for msg in messages + if "【历史消息存档】" in str(msg.get("content", "")) + ) + + assert "level=" not in history_message + assert " None: + self.registered_items: list[dict[str, Any]] = [] + + async def register_bytes( + self, + scope_key: str, + data: bytes, + kind: str, + display_name: str, + mime_type: str, + source_kind: str, + source_ref: str, + ) -> Any: + class MockRecord: + uid = "test-uid-12345" + + record = MockRecord() + self.registered_items.append( + { + "scope_key": scope_key, + "size": len(data), + "kind": kind, + "display_name": display_name, + "mime_type": mime_type, + "source_kind": source_kind, + "source_ref": source_ref, + "uid": record.uid, + } + ) + return record -_PNG_HEADER = b"\x89PNG\r\n\x1a\n" +@pytest.mark.asyncio +async def test_render_simple_equation() -> None: + """测试渲染简单方程(无分隔符,自动包装)""" + from Undefined.skills.toolsets.render.render_latex.handler import execute -def _build_context(registry: AttachmentRegistry) -> dict[str, Any]: - return { + mock_registry = MockAttachmentRegistry() + context = { + "attachment_registry": mock_registry, "request_type": "group", - "group_id": 10001, - "sender_id": 20002, - "user_id": 20002, - "attachment_registry": registry, + "group_id": 123456, } + args = {"content": "E = mc^2", "output_format": "png"} -def test_strip_document_wrappers_removes_document_env() -> None: - content = ( - "\\begin{document}\n" - "\\[\n" - "\\int_{-\\infty}^{+\\infty} e^{-x^2} dx = \\sqrt{\\pi}\n" - "\\]\n" - "\\end{document}" - ) - assert handler._strip_document_wrappers(content) == ( - 
"\\[\n\\int_{-\\infty}^{+\\infty} e^{-x^2} dx = \\sqrt{\\pi}\n\\]" - ) + try: + result = await execute(args, context) + assert result == '' + assert len(mock_registry.registered_items) == 1 + assert mock_registry.registered_items[0]["kind"] == "image" + assert mock_registry.registered_items[0]["mime_type"] == "image/png" + assert mock_registry.registered_items[0]["size"] > 0 + except ImportError as e: + if "playwright" in str(e).lower(): + pytest.skip("Playwright 未安装,跳过测试") + raise -def test_strip_document_wrappers_passthrough_for_plain_formula() -> None: - content = r"\[ E = mc^2 \]" - assert handler._strip_document_wrappers(content) == content +@pytest.mark.asyncio +async def test_render_with_delimiters() -> None: + """测试带分隔符的内容(不自动包装)""" + from Undefined.skills.toolsets.render.render_latex.handler import execute + + mock_registry = MockAttachmentRegistry() + context = { + "attachment_registry": mock_registry, + "request_type": "private", + "user_id": 987654, + } + + args = {"content": r"\[ \int_0^\infty e^{-x^2} dx = \frac{\sqrt{\pi}}{2} \]"} + + try: + result = await execute(args, context) + assert result == '' + assert len(mock_registry.registered_items) == 1 + except ImportError as e: + if "playwright" in str(e).lower(): + pytest.skip("Playwright 未安装,跳过测试") + raise @pytest.mark.asyncio -async def test_render_latex_embed_success( - tmp_path: Path, - monkeypatch: pytest.MonkeyPatch, -) -> None: - registry = AttachmentRegistry( - registry_path=tmp_path / "attachment_registry.json", - cache_dir=tmp_path / "attachments", - ) - content = r"\[ \int_{-\infty}^{+\infty} e^{-x^2} dx = \sqrt{\pi} \]" - rendered_contents: list[str] = [] +async def test_render_pdf_output() -> None: + """测试 PDF 输出格式使用 attachment 标签""" + from Undefined.skills.toolsets.render.render_latex.handler import execute - monkeypatch.setattr(paths, "RENDER_CACHE_DIR", tmp_path / "render") + mock_registry = MockAttachmentRegistry() + context = { + "attachment_registry": mock_registry, + 
"request_type": "group", + "group_id": 123456, + } - def _fake_render(filepath: Path, render_content: str) -> None: - rendered_contents.append(render_content) - filepath.parent.mkdir(parents=True, exist_ok=True) - filepath.write_bytes(_PNG_HEADER) + args = {"content": r"\frac{a}{b} + \sqrt{c}", "output_format": "pdf"} - monkeypatch.setattr(handler, "_render_latex_image", _fake_render) + try: + result = await execute(args, context) + assert result == '' + assert len(mock_registry.registered_items) == 1 + assert mock_registry.registered_items[0]["kind"] == "file" + assert mock_registry.registered_items[0]["mime_type"] == "application/pdf" + assert mock_registry.registered_items[0]["display_name"] == "latex.pdf" + except ImportError as e: + if "playwright" in str(e).lower(): + pytest.skip("Playwright 未安装,跳过测试") + raise - result = await handler.execute( - {"content": content, "delivery": "embed"}, - _build_context(registry), - ) - assert result.startswith(' None: + """测试空内容错误处理""" + from Undefined.skills.toolsets.render.render_latex.handler import execute + + context = {"attachment_registry": MockAttachmentRegistry()} + + args = {"content": " "} + + result = await execute(args, context) + assert "不能为空" in result @pytest.mark.asyncio -async def test_render_latex_returns_helpful_message_when_tex_missing( - tmp_path: Path, - monkeypatch: pytest.MonkeyPatch, -) -> None: - registry = AttachmentRegistry( - registry_path=tmp_path / "attachment_registry.json", - cache_dir=tmp_path / "attachments", +async def test_invalid_output_format() -> None: + """测试无效输出格式""" + from Undefined.skills.toolsets.render.render_latex.handler import execute + + context = {"attachment_registry": MockAttachmentRegistry()} + + args = {"content": "x = 1", "output_format": "svg"} + + result = await execute(args, context) + assert "无效" in result or "仅支持" in result + + +def test_strip_document_wrappers() -> None: + """测试去除 document 包装""" + from Undefined.skills.toolsets.render.render_latex.handler import 
( + _strip_document_wrappers, ) - content = r"\[ a = b \]" - monkeypatch.setattr(paths, "RENDER_CACHE_DIR", tmp_path / "render") + content = r"\begin{document}E = mc^2\end{document}" + result = _strip_document_wrappers(content) + assert result == "E = mc^2" - def _raise_runtime(_: Path, __: str) -> None: - raise RuntimeError("latex was not able to process the following string") + # 没有包装的内容应该原样返回 + content_no_wrapper = r"E = mc^2" + result_no_wrapper = _strip_document_wrappers(content_no_wrapper) + assert result_no_wrapper == "E = mc^2" - monkeypatch.setattr(handler, "_render_latex_image", _raise_runtime) - result = await handler.execute( - {"content": content, "delivery": "embed"}, - _build_context(registry), +def test_has_math_delimiters() -> None: + """测试数学分隔符检测""" + from Undefined.skills.toolsets.render.render_latex.handler import ( + _has_math_delimiters, ) - assert "TeX Live" in result or "MiKTeX" in result + assert _has_math_delimiters(r"\[ x = 1 \]") is True + assert _has_math_delimiters(r"\( x = 1 \)") is True + assert _has_math_delimiters(r"$$ x = 1 $$") is True + assert _has_math_delimiters(r"\begin{equation}") is True + assert _has_math_delimiters("x = 1") is False + + +def test_prepare_content() -> None: + """测试内容准备逻辑""" + from Undefined.skills.toolsets.render.render_latex.handler import _prepare_content + + # 无分隔符,自动包装 + result = _prepare_content("E = mc^2") + assert result.startswith(r"\[") + assert result.endswith(r"\]") + assert "E = mc^2" in result + + # 有分隔符,不包装 + result_with_delim = _prepare_content(r"\[ E = mc^2 \]") + assert result_with_delim == r"\[ E = mc^2 \]" + + # 字面量 \\n 处理 + result_newline = _prepare_content(r"x = 1\\ny = 2") + assert "\n" in result_newline + assert "\\n" not in result_newline.replace(r"\[", "").replace(r"\]", "") diff --git a/tests/test_summary_agent.py b/tests/test_summary_agent.py new file mode 100644 index 0000000..72fe659 --- /dev/null +++ b/tests/test_summary_agent.py @@ -0,0 +1,148 @@ +from __future__ import 
annotations + +from pathlib import Path +from typing import Any +from unittest.mock import AsyncMock, patch + +import pytest + +from Undefined.skills.agents.summary_agent.handler import ( + execute as summary_agent_execute, +) + + +@pytest.mark.asyncio +async def test_summary_agent_normal_execution() -> None: + """Normal execution → calls run_agent_with_tools with correct params.""" + context: dict[str, Any] = { + "ai_client": AsyncMock(), + "history_manager": AsyncMock(), + "group_id": 123456, + "user_id": 10002, + "sender_id": 10002, + "request_type": "group", + "runtime_config": None, + "queue_lane": None, + } + + with patch( + "Undefined.skills.agents.summary_agent.handler.run_agent_with_tools", + new=AsyncMock(return_value="总结结果:讨论了技术话题"), + ) as mock_run_agent: + result = await summary_agent_execute( + {"prompt": "请总结最近 50 条消息"}, + context, + ) + + assert result == "总结结果:讨论了技术话题" + mock_run_agent.assert_called_once() + + call_kwargs = mock_run_agent.call_args.kwargs + assert call_kwargs["agent_name"] == "summary_agent" + assert call_kwargs["user_content"] == "请总结最近 50 条消息" + assert call_kwargs["empty_user_content_message"] == "请提供您的总结需求" + assert "消息总结助手" in call_kwargs["default_prompt"] + assert call_kwargs["context"] is context + assert isinstance(call_kwargs["agent_dir"], Path) + assert call_kwargs["max_iterations"] == 10 + assert call_kwargs["tool_error_prefix"] == "错误" + + +@pytest.mark.asyncio +async def test_summary_agent_empty_prompt() -> None: + """Empty prompt → returns '请提供您的总结需求'.""" + context: dict[str, Any] = { + "ai_client": AsyncMock(), + "history_manager": AsyncMock(), + } + + with patch( + "Undefined.skills.agents.summary_agent.handler.run_agent_with_tools", + new=AsyncMock(return_value="请提供您的总结需求"), + ) as mock_run_agent: + result = await summary_agent_execute({"prompt": ""}, context) + + assert result == "请提供您的总结需求" + mock_run_agent.assert_called_once() + call_kwargs = mock_run_agent.call_args.kwargs + assert call_kwargs["user_content"] == 
"" + + +@pytest.mark.asyncio +async def test_summary_agent_whitespace_prompt() -> None: + """Whitespace-only prompt → treated as empty.""" + context: dict[str, Any] = { + "ai_client": AsyncMock(), + "history_manager": AsyncMock(), + } + + with patch( + "Undefined.skills.agents.summary_agent.handler.run_agent_with_tools", + new=AsyncMock(return_value="请提供您的总结需求"), + ) as mock_run_agent: + result = await summary_agent_execute({"prompt": " "}, context) + + assert result == "请提供您的总结需求" + call_kwargs = mock_run_agent.call_args.kwargs + assert call_kwargs["user_content"] == "" + + +@pytest.mark.asyncio +async def test_summary_agent_missing_prompt_arg() -> None: + """Missing 'prompt' arg → defaults to empty string.""" + context: dict[str, Any] = { + "ai_client": AsyncMock(), + "history_manager": AsyncMock(), + } + + with patch( + "Undefined.skills.agents.summary_agent.handler.run_agent_with_tools", + new=AsyncMock(return_value="请提供您的总结需求"), + ) as mock_run_agent: + result = await summary_agent_execute({}, context) + + assert result == "请提供您的总结需求" + call_kwargs = mock_run_agent.call_args.kwargs + assert call_kwargs["user_content"] == "" + + +@pytest.mark.asyncio +async def test_summary_agent_complex_prompt() -> None: + """Complex prompt with time range and custom instructions.""" + context: dict[str, Any] = { + "ai_client": AsyncMock(), + "history_manager": AsyncMock(), + "group_id": 654321, + "user_id": 99999, + } + + with patch( + "Undefined.skills.agents.summary_agent.handler.run_agent_with_tools", + new=AsyncMock(return_value="详细总结内容"), + ) as mock_run_agent: + result = await summary_agent_execute( + {"prompt": "请总结过去 1d 内的聊天消息,重点关注:技术讨论"}, + context, + ) + + assert result == "详细总结内容" + call_kwargs = mock_run_agent.call_args.kwargs + assert ( + call_kwargs["user_content"] == "请总结过去 1d 内的聊天消息,重点关注:技术讨论" + ) + + +@pytest.mark.asyncio +async def test_summary_agent_propagates_exception() -> None: + """Exception from run_agent_with_tools propagates up.""" + context: 
dict[str, Any] = { + "ai_client": AsyncMock(), + "history_manager": AsyncMock(), + } + + with patch( + "Undefined.skills.agents.summary_agent.handler.run_agent_with_tools", + new=AsyncMock(side_effect=RuntimeError("Agent failure")), + ): + with pytest.raises(RuntimeError, match="Agent failure"): + await summary_agent_execute({"prompt": "test"}, context) diff --git a/tests/test_summary_command.py b/tests/test_summary_command.py new file mode 100644 index 0000000..7ce7550 --- /dev/null +++ b/tests/test_summary_command.py @@ -0,0 +1,341 @@ +from __future__ import annotations + +from types import SimpleNamespace +from typing import Any, cast +from unittest.mock import AsyncMock, patch + +import pytest + +from Undefined.services.commands.context import CommandContext +from Undefined.skills.commands.summary.handler import ( + _build_prompt, + _parse_args, + execute as summary_execute, +) + + +class _DummySender: + def __init__(self) -> None: + self.group_messages: list[tuple[int, str]] = [] + self.private_messages: list[tuple[int, str]] = [] + + async def send_group_message( + self, group_id: int, message: str, mark_sent: bool = False + ) -> None: + _ = mark_sent + self.group_messages.append((group_id, message)) + + async def send_private_message( + self, + user_id: int, + message: str, + auto_history: bool = True, + *, + mark_sent: bool = True, + ) -> None: + _ = (auto_history, mark_sent) + self.private_messages.append((user_id, message)) + + +def _build_context( + *, + sender: _DummySender | None = None, + history_manager: Any = None, + scope: str = "group", + group_id: int = 123456, + sender_id: int = 10002, + user_id: int | None = None, + ai: Any = None, +) -> CommandContext: + stub = cast(Any, SimpleNamespace()) + if sender is None: + sender = _DummySender() + if ai is None: + ai = stub + return CommandContext( + group_id=group_id, + sender_id=sender_id, + config=stub, + sender=cast(Any, sender), + ai=ai, + faq_storage=stub, + onebot=stub, + security=stub, + 
queue_manager=None, + rate_limiter=None, + dispatcher=stub, + registry=stub, + scope=scope, + user_id=user_id, + history_manager=history_manager, + ) + + +# -- _parse_args unit tests -- + + +def test_parse_args_empty() -> None: + """Empty args → (50, None, '').""" + count, time_range, custom_prompt = _parse_args([]) + assert count == 50 + assert time_range is None + assert custom_prompt == "" + + +def test_parse_args_count_only() -> None: + """['100'] → (100, None, '').""" + count, time_range, custom_prompt = _parse_args(["100"]) + assert count == 100 + assert time_range is None + assert custom_prompt == "" + + +def test_parse_args_time_range_only() -> None: + """['1d'] → (None, '1d', '').""" + count, time_range, custom_prompt = _parse_args(["1d"]) + assert count is None + assert time_range == "1d" + assert custom_prompt == "" + + +def test_parse_args_count_with_custom_prompt() -> None: + """['100', '技术讨论'] → (100, None, '技术讨论').""" + count, time_range, custom_prompt = _parse_args(["100", "技术讨论"]) + assert count == 100 + assert time_range is None + assert custom_prompt == "技术讨论" + + +def test_parse_args_time_range_with_custom_prompt() -> None: + """['1d', '总结技术'] → (None, '1d', '总结技术').""" + count, time_range, custom_prompt = _parse_args(["1d", "总结技术"]) + assert count is None + assert time_range == "1d" + assert custom_prompt == "总结技术" + + +def test_parse_args_custom_prompt_only() -> None: + """['技术讨论'] → (50, None, '技术讨论').""" + count, time_range, custom_prompt = _parse_args(["技术讨论"]) + assert count == 50 + assert time_range is None + assert custom_prompt == "技术讨论" + + +def test_parse_args_count_capped_at_500() -> None: + """['999'] → (500, None, '') (capped).""" + count, time_range, custom_prompt = _parse_args(["999"]) + assert count == 500 + assert time_range is None + assert custom_prompt == "" + + +def test_parse_args_multiple_words_prompt() -> None: + """['技术', '讨论', '总结'] → (50, None, '技术 讨论 总结').""" + count, time_range, custom_prompt = _parse_args(["技术", 
"讨论", "总结"]) + assert count == 50 + assert time_range is None + assert custom_prompt == "技术 讨论 总结" + + +# -- _build_prompt unit tests -- + + +def test_build_prompt_with_count() -> None: + """With count → '请总结最近 X 条聊天消息'.""" + prompt = _build_prompt(100, None, "") + assert "请总结最近 100 条聊天消息" in prompt + + +def test_build_prompt_with_time_range() -> None: + """With time_range → '请总结过去 X 内的聊天消息'.""" + prompt = _build_prompt(None, "1d", "") + assert "请总结过去 1d 内的聊天消息" in prompt + + +def test_build_prompt_with_custom_prompt() -> None: + """With custom_prompt → adds '重点关注:...'.""" + prompt = _build_prompt(50, None, "技术讨论") + assert "请总结最近 50 条聊天消息" in prompt + assert "重点关注:技术讨论" in prompt + + +def test_build_prompt_default_count() -> None: + """Default count when both are None.""" + prompt = _build_prompt(None, None, "") + assert "请总结最近 50 条聊天消息" in prompt + + +def test_build_prompt_time_range_and_custom() -> None: + """Time range with custom prompt.""" + prompt = _build_prompt(None, "6h", "重要公告") + assert "请总结过去 6h 内的聊天消息" in prompt + assert "重点关注:重要公告" in prompt + + +# -- Command execution tests -- + + +@pytest.mark.asyncio +async def test_summary_no_history_manager() -> None: + """No history_manager → sends error message.""" + sender = _DummySender() + context = _build_context( + sender=sender, + history_manager=None, + scope="group", + group_id=123456, + sender_id=10002, + ) + + await summary_execute([], context) + + assert len(sender.group_messages) == 1 + assert "❌ 历史记录管理器未配置" in sender.group_messages[0][1] + + +@pytest.mark.asyncio +async def test_summary_agent_call_success() -> None: + """Agent call success → result forwarded to user.""" + sender = _DummySender() + history_manager = AsyncMock() + ai = AsyncMock() + ai.runtime_config = None + + context = _build_context( + sender=sender, + history_manager=history_manager, + ai=ai, + scope="group", + group_id=123456, + sender_id=10002, + ) + + with patch( + "Undefined.skills.agents.summary_agent.handler.execute", + 
new=AsyncMock(return_value="总结内容:最近讨论了技术话题。"), + ) as mock_agent: + await summary_execute(["50"], context) + + assert len(sender.group_messages) == 2 + assert "📝 正在总结消息,请稍候..." in sender.group_messages[0][1] + assert "总结内容:最近讨论了技术话题。" in sender.group_messages[1][1] + mock_agent.assert_called_once() + call_args = mock_agent.call_args + assert call_args[0][0]["prompt"] == "请总结最近 50 条聊天消息" + + +@pytest.mark.asyncio +async def test_summary_agent_call_failure() -> None: + """Agent call failure → sends error message.""" + sender = _DummySender() + history_manager = AsyncMock() + ai = AsyncMock() + + context = _build_context( + sender=sender, + history_manager=history_manager, + ai=ai, + scope="group", + group_id=123456, + sender_id=10002, + ) + + with patch( + "Undefined.skills.agents.summary_agent.handler.execute", + new=AsyncMock(side_effect=Exception("Agent error")), + ): + await summary_execute([], context) + + assert len(sender.group_messages) == 2 + assert "📝 正在总结消息,请稍候..." in sender.group_messages[0][1] + assert "❌ 消息总结失败,请稍后重试" in sender.group_messages[1][1] + + +@pytest.mark.asyncio +async def test_summary_agent_returns_empty() -> None: + """Agent returns empty result → sends '未能生成总结内容'.""" + sender = _DummySender() + history_manager = AsyncMock() + ai = AsyncMock() + + context = _build_context( + sender=sender, + history_manager=history_manager, + ai=ai, + scope="group", + group_id=123456, + sender_id=10002, + ) + + with patch( + "Undefined.skills.agents.summary_agent.handler.execute", + new=AsyncMock(return_value=" "), + ): + await summary_execute([], context) + + assert len(sender.group_messages) == 2 + assert "📭 未能生成总结内容" in sender.group_messages[1][1] + + +@pytest.mark.asyncio +async def test_summary_private_chat() -> None: + """Private chat → uses send_private_message.""" + sender = _DummySender() + history_manager = AsyncMock() + ai = AsyncMock() + + context = _build_context( + sender=sender, + history_manager=history_manager, + ai=ai, + scope="private", 
+ group_id=0, + sender_id=88888, + user_id=88888, + ) + + with patch( + "Undefined.skills.agents.summary_agent.handler.execute", + new=AsyncMock(return_value="私聊总结结果"), + ): + await summary_execute(["1d", "重要消息"], context) + + assert len(sender.private_messages) == 2 + assert "📝 正在总结消息,请稍候..." in sender.private_messages[0][1] + assert "私聊总结结果" in sender.private_messages[1][1] + + +@pytest.mark.asyncio +async def test_summary_passes_correct_context_to_agent() -> None: + """Agent receives correct context parameters.""" + sender = _DummySender() + history_manager = AsyncMock() + ai = AsyncMock() + ai.runtime_config = SimpleNamespace(some_config="value") + + context = _build_context( + sender=sender, + history_manager=history_manager, + ai=ai, + scope="group", + group_id=999888, + sender_id=777666, + user_id=None, + ) + + with patch( + "Undefined.skills.agents.summary_agent.handler.execute", + new=AsyncMock(return_value="总结"), + ) as mock_agent: + await summary_execute([], context) + + call_args = mock_agent.call_args + agent_context = call_args[0][1] + assert agent_context["ai_client"] is ai + assert agent_context["history_manager"] is history_manager + assert agent_context["group_id"] == 999888 + assert agent_context["sender_id"] == 777666 + assert agent_context["user_id"] == 777666 + assert agent_context["request_type"] == "group" + assert agent_context["runtime_config"] is ai.runtime_config From 0e735cba1268a9c36ce8daa262b20337925637f8 Mon Sep 17 00:00:00 2001 From: Null <1708213363@qq.com> Date: Sat, 18 Apr 2026 14:51:56 +0800 Subject: [PATCH 06/57] fix(render_latex): pass proxy config to Playwright for CDN access Read use_proxy/http_proxy/https_proxy from runtime_config and forward to chromium.launch(proxy=...) so MathJax CDN loads correctly on servers requiring a proxy. Also update use_proxy comment to be generic. 
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- config.toml.example | 4 +-- .../toolsets/render/render_latex/handler.py | 32 +++++++++++++++++-- 2 files changed, 31 insertions(+), 5 deletions(-) diff --git a/config.toml.example b/config.toml.example index 0e7fe2c..678dd09 100644 --- a/config.toml.example +++ b/config.toml.example @@ -716,8 +716,8 @@ grok_search_enabled = false # zh: 代理设置(可选)。 # en: Proxy settings (optional). [proxy] -# zh: 是否在 "web_agent" 中启用代理。 -# en: Enable proxy in "web_agent". +# zh: 是否使用代理。 +# en: Whether to use proxy. use_proxy = true # zh: 例如 http://127.0.0.1:7890(也可使用环境变量 "HTTP_PROXY")。 # en: e.g. http://127.0.0.1:7890 (or use the "HTTP_PROXY" environment variable). diff --git a/src/Undefined/skills/toolsets/render/render_latex/handler.py b/src/Undefined/skills/toolsets/render/render_latex/handler.py index a449cd2..01e8ba3 100644 --- a/src/Undefined/skills/toolsets/render/render_latex/handler.py +++ b/src/Undefined/skills/toolsets/render/render_latex/handler.py @@ -77,7 +77,9 @@ def _build_html(latex_content: str) -> str: """ -async def _render_latex_to_bytes(content: str, output_format: str) -> tuple[bytes, str]: +async def _render_latex_to_bytes( + content: str, output_format: str, proxy: str | None = None +) -> tuple[bytes, str]: """ 使用 MathJax + Playwright 渲染 LaTeX 内容。 @@ -95,8 +97,13 @@ async def _render_latex_to_bytes(content: str, output_format: str) -> tuple[byte html_content = _build_html(content) + launch_kwargs: dict[str, object] = {"headless": True} + if proxy: + launch_kwargs["proxy"] = {"server": proxy} + logger.info("LaTeX 渲染使用代理: %s", proxy) + async with async_playwright() as p: - browser = await p.chromium.launch(headless=True) + browser = await p.chromium.launch(**launch_kwargs) # type: ignore[arg-type] try: page = await browser.new_page() await page.set_content(html_content) @@ -143,6 +150,22 @@ async def _render_latex_to_bytes(content: str, output_format: str) -> tuple[byte await browser.close() 
+async def _resolve_proxy(context: Dict[str, Any]) -> str | None: + """从 context 的 runtime_config 中解析代理地址。""" + from Undefined.config import get_config + + runtime_config = context.get("runtime_config") or get_config(strict=False) + if runtime_config is None: + return None + use_proxy: bool = getattr(runtime_config, "use_proxy", False) + if not use_proxy: + return None + proxy: str = getattr(runtime_config, "http_proxy", "") or getattr( + runtime_config, "https_proxy", "" + ) + return proxy or None + + async def execute(args: Dict[str, Any], context: Dict[str, Any]) -> str: """渲染 LaTeX 数学公式为图片或 PDF""" raw_content = str(args.get("content", "") or "") @@ -159,9 +182,12 @@ async def execute(args: Dict[str, Any], context: Dict[str, Any]) -> str: # 准备内容 prepared_content = _prepare_content(raw_content) + # 解析代理 + proxy = await _resolve_proxy(context) + # 渲染 rendered_bytes, mime_type = await _render_latex_to_bytes( - prepared_content, output_format + prepared_content, output_format, proxy=proxy ) # 注册到附件系统 From 1571d7e60428fa749f9883c98eceadfd5fb0c696 Mon Sep 17 00:00:00 2001 From: Null <1708213363@qq.com> Date: Sat, 18 Apr 2026 15:08:29 +0800 Subject: [PATCH 07/57] =?UTF-8?q?feat(config):=20sync=5Fconfig=5Ftemplate?= =?UTF-8?q?=20=E6=8A=A5=E5=91=8A=E6=B3=A8=E9=87=8A=E5=8F=98=E6=9B=B4?= =?UTF-8?q?=E8=B7=AF=E5=BE=84=EF=BC=9Bgif=5Fanalysis=5Fmode=20=E4=B8=8B?= =?UTF-8?q?=E6=8B=89=E9=80=89=E6=8B=A9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - ConfigTemplateSyncResult 新增 updated_comment_paths 字段 - sync_config_text() 对比 current/example 注释,记录有改动的路径 - 脚本以 ~ 前缀展示注释更新项数量和路径列表 - 修复测试 mock 对象缺少新字段 - config-form.js:gif_analysis_mode 渲染为 grid/multi 下拉框 Co-authored-by: Claude Sonnet 4.6 Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- scripts/sync_config_template.py | 5 ++++ src/Undefined/webui/static/js/config-form.js | 1 + src/Undefined/webui/utils/config_sync.py | 24 +++++++++++++++++--- 
tests/test_sync_config_template_script.py | 1 + 4 files changed, 28 insertions(+), 3 deletions(-) diff --git a/scripts/sync_config_template.py b/scripts/sync_config_template.py index e432001..583b9eb 100755 --- a/scripts/sync_config_template.py +++ b/scripts/sync_config_template.py @@ -96,6 +96,11 @@ def main() -> int: for path in result.added_paths: print(f" + {path}") + if result.updated_comment_paths: + print(f"[sync-config] 注释更新数量: {len(result.updated_comment_paths)}") + for path in result.updated_comment_paths: + print(f" ~ {path}") + if result.removed_paths: print(f"[sync-config] 多余路径数量: {len(result.removed_paths)}") for path in result.removed_paths: diff --git a/src/Undefined/webui/static/js/config-form.js b/src/Undefined/webui/static/js/config-form.js index ffadfe0..111d64a 100644 --- a/src/Undefined/webui/static/js/config-form.js +++ b/src/Undefined/webui/static/js/config-form.js @@ -251,6 +251,7 @@ function isLongText(value) { const FIELD_SELECT_OPTIONS = { api_mode: ["chat_completions", "responses"], + gif_analysis_mode: ["grid", "multi"], reasoning_effort_style: ["openai", "anthropic"], // path -> options key mapping (underscore-separated segments) image_gen_provider: ["xingzhige", "models"], diff --git a/src/Undefined/webui/utils/config_sync.py b/src/Undefined/webui/utils/config_sync.py index 46c16ee..c3df03a 100644 --- a/src/Undefined/webui/utils/config_sync.py +++ b/src/Undefined/webui/utils/config_sync.py @@ -1,6 +1,7 @@ from __future__ import annotations import copy +import dataclasses import tomllib from dataclasses import dataclass from pathlib import Path @@ -21,6 +22,7 @@ class ConfigTemplateSyncResult: added_paths: list[str] removed_paths: list[str] comments: CommentMap + updated_comment_paths: list[str] = dataclasses.field(default_factory=list) def _parse_toml_text(content: str, *, label: str) -> TomlData: @@ -226,6 +228,17 @@ def _augment_pool_model_comments( return merged +def _collect_updated_comment_paths( + current: CommentMap, example: 
CommentMap +) -> list[str]: + """收集在 example 与 current 中都存在、但注释文本不同的路径。""" + return [ + key + for key, example_value in example.items() + if key in current and current[key] != example_value + ] + + def _merge_comment_maps(current: CommentMap, example: CommentMap) -> CommentMap: merged: CommentMap = dict(current) for key, value in example.items(): @@ -248,16 +261,21 @@ def sync_config_text( if prune and removed_paths: merged = _prune_to_template(merged, prepared_example_data) example_comments = parse_comment_map_text(example_text) - comments = _merge_comment_maps( - parse_comment_map_text(current_text), - _augment_pool_model_comments(example_comments, example_data, current_data), + current_comments = parse_comment_map_text(current_text) + augmented_example_comments = _augment_pool_model_comments( + example_comments, example_data, current_data + ) + updated_comment_paths = _collect_updated_comment_paths( + current_comments, augmented_example_comments ) + comments = _merge_comment_maps(current_comments, augmented_example_comments) content = render_toml(merged, comments=comments) return ConfigTemplateSyncResult( content=content, added_paths=added_paths, removed_paths=removed_paths, comments=comments, + updated_comment_paths=updated_comment_paths, ) diff --git a/tests/test_sync_config_template_script.py b/tests/test_sync_config_template_script.py index 3f14aec..e422138 100644 --- a/tests/test_sync_config_template_script.py +++ b/tests/test_sync_config_template_script.py @@ -43,6 +43,7 @@ def fake_sync_config_file( added_paths=[], removed_paths=["models.chat.extra"], comments={}, + updated_comment_paths=[], ) monkeypatch.setattr(module, "sync_config_file", fake_sync_config_file) From 987ac132fb7feabf9c4a61d91ac35b4bfa36e3f5 Mon Sep 17 00:00:00 2001 From: Null <1708213363@qq.com> Date: Sat, 18 Apr 2026 15:21:57 +0800 Subject: [PATCH 08/57] =?UTF-8?q?feat(repeat):=20bot=20=E5=8F=91=E8=A8=80?= =?UTF-8?q?=E4=B8=8D=E8=AE=A1=E5=85=A5=E5=A4=8D=E8=AF=BB=E9=93=BE=EF=BC=9B?= 
=?UTF-8?q?=E9=98=88=E5=80=BC=E5=8F=AF=E9=85=8D=E7=BD=AE?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - bot 消息写入 counter(而非过滤)使窗口对其可见 - 触发条件增加 bot_qq not in senders 检查 覆盖三种情形:bot 先发、bot 中间插入、bot 滑出窗口后正常触发 - RuntimeConfig 新增 repeat_threshold(范围 2–20,默认 3) - 硬编码 3 / 5 替换为 n = repeat_threshold - config.toml.example 更新 repeat_enabled 注释并新增 repeat_threshold 字段 - 新增 5 个测试,共 17 个复读测试全部通过 Co-authored-by: Claude Sonnet 4.6 Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- config.toml.example | 7 ++- src/Undefined/config/loader.py | 14 ++++++ src/Undefined/handlers.py | 33 ++++++++++---- tests/test_handlers_repeat.py | 78 +++++++++++++++++++++++++++++++++- 4 files changed, 120 insertions(+), 12 deletions(-) diff --git a/config.toml.example b/config.toml.example index 678dd09..ad25b66 100644 --- a/config.toml.example +++ b/config.toml.example @@ -659,9 +659,12 @@ agent_call_message_enabled = "none" # zh: 是否启用群聊关键词("心理委员")自动回复。 # en: Enable keyword auto-replies("心理委员") in group chats. keyword_reply_enabled = false -# zh: 是否启用群聊复读功能(连续3条相同消息时复读)。 -# en: Enable repeat feature in group chats (repeat when 3 consecutive identical messages). +# zh: 是否启用群聊复读功能(连续 N 条相同消息时复读,N 由 repeat_threshold 控制;若期间有 bot 自身发言则重置链)。 +# en: Enable repeat feature in group chats (repeat when N consecutive identical messages arrive, N set by repeat_threshold; resets if bot itself sent the same text in between). repeat_enabled = false +# zh: 复读触发所需的连续相同消息条数(来自不同发送者),范围 2–20,默认 3。 +# en: Number of consecutive identical messages (from different senders) required to trigger repeat, range 2–20, default 3. +repeat_threshold = 3 # zh: 是否启用倒问号(复读触发时,若消息为问号则发送倒问号 ¿)。 # en: Enable inverted question mark (when repeat triggers on "?" messages, send "¿" instead). 
inverted_question_enabled = false diff --git a/src/Undefined/config/loader.py b/src/Undefined/config/loader.py index f7d7f52..00b8c3a 100644 --- a/src/Undefined/config/loader.py +++ b/src/Undefined/config/loader.py @@ -469,6 +469,7 @@ class Config: process_poke_message: bool keyword_reply_enabled: bool repeat_enabled: bool + repeat_threshold: int inverted_question_enabled: bool context_recent_messages_limit: int ai_request_max_retries: int @@ -723,6 +724,18 @@ def load(cls, config_path: Optional[Path] = None, strict: bool = True) -> "Confi ), False, ) + repeat_threshold = _coerce_int( + _get_value( + data, + ("easter_egg", "repeat_threshold"), + "EASTER_EGG_REPEAT_THRESHOLD", + ), + 3, + ) + if repeat_threshold < 2: + repeat_threshold = 2 + if repeat_threshold > 20: + repeat_threshold = 20 context_recent_messages_limit = _coerce_int( _get_value( data, @@ -1396,6 +1409,7 @@ def load(cls, config_path: Optional[Path] = None, strict: bool = True) -> "Confi process_poke_message=process_poke_message, keyword_reply_enabled=keyword_reply_enabled, repeat_enabled=repeat_enabled, + repeat_threshold=repeat_threshold, inverted_question_enabled=inverted_question_enabled, context_recent_messages_limit=context_recent_messages_limit, ai_request_max_retries=ai_request_max_retries, diff --git a/src/Undefined/handlers.py b/src/Undefined/handlers.py index 8cfeae9..db8fba2 100644 --- a/src/Undefined/handlers.py +++ b/src/Undefined/handlers.py @@ -707,7 +707,17 @@ async def handle_message(self, event: dict[str, Any]) -> None: ) # 如果是 bot 自己的消息,只保存不触发回复,避免无限循环 + # 同时把 bot 自身的发言写入复读计数器,使窗口中留有 bot 标记, + # 后续触发检查时会排除含 bot 的窗口,防止"bot 先发 → 用户跟发"或 + # "用户发到一半 bot 插入"等情况误触复读。 if sender_id == self.config.bot_qq: + if self.config.repeat_enabled and text: + async with self._get_repeat_lock(group_id): + counter = self._repeat_counter.setdefault(group_id, []) + counter.append((text, sender_id)) + n = self.config.repeat_threshold + if len(counter) > n: + self._repeat_counter[group_id] = counter[-n:] 
return self._schedule_meme_ingest( @@ -776,21 +786,26 @@ async def handle_message(self, event: dict[str, Any]) -> None: ) return - # 复读功能:连续3条相同消息(来自不同发送者)时复读 + # 复读功能:连续 N 条相同消息(来自不同发送者)时复读,N = repeat_threshold if self.config.repeat_enabled: + n = self.config.repeat_threshold async with self._get_repeat_lock(group_id): counter = self._repeat_counter.setdefault(group_id, []) counter.append((text, sender_id)) - # 只保留最近5条 - if len(counter) > 5: - self._repeat_counter[group_id] = counter[-5:] + # 只保留最近 n 条 + if len(counter) > n: + self._repeat_counter[group_id] = counter[-n:] counter = self._repeat_counter[group_id] - if len(counter) >= 3: - last3 = counter[-3:] - texts = [t for t, _ in last3] - senders = [s for _, s in last3] - if len(set(texts)) == 1 and len(set(senders)) == 3: + if len(counter) >= n: + last_n = counter[-n:] + texts = [t for t, _ in last_n] + senders = [s for _, s in last_n] + if ( + len(set(texts)) == 1 + and len(set(senders)) == n + and self.config.bot_qq not in senders + ): reply_text = texts[0] if self.config.inverted_question_enabled: stripped = reply_text.strip() diff --git a/tests/test_handlers_repeat.py b/tests/test_handlers_repeat.py index 3adf332..2bb6aef 100644 --- a/tests/test_handlers_repeat.py +++ b/tests/test_handlers_repeat.py @@ -17,6 +17,7 @@ def _build_handler( *, repeat_enabled: bool = False, + repeat_threshold: int = 3, inverted_question_enabled: bool = False, keyword_reply_enabled: bool = False, ) -> Any: @@ -24,6 +25,7 @@ def _build_handler( handler.config = SimpleNamespace( bot_qq=10000, repeat_enabled=repeat_enabled, + repeat_threshold=repeat_threshold, inverted_question_enabled=inverted_question_enabled, keyword_reply_enabled=keyword_reply_enabled, bilibili_auto_extract_enabled=False, @@ -248,7 +250,7 @@ async def test_repeat_groups_are_independent() -> None: assert call.args[0] == 30002 -# ── 计数器窗口:只看最近5条 ── +# ── 计数器窗口:只看最近 N 条 ── @pytest.mark.asyncio @@ -264,3 +266,77 @@ async def test_repeat_counter_sliding_window() -> 
None: handler.sender.send_group_message.assert_called_once() call = handler.sender.send_group_message.call_args assert call.args[1] == "hello" + + +# ── bot 自身发言后不触发复读 ── + +BOT_QQ = 10000 + + +@pytest.mark.asyncio +async def test_repeat_no_trigger_when_bot_sends_before_users() -> None: + """bot 先发,后面用户再发相同消息,不应触发复读。""" + handler = _build_handler(repeat_enabled=True) + # bot 先发 + await handler.handle_message(_group_event(sender_id=BOT_QQ, text="hello")) + # 两个用户跟发 + for uid in [20001, 20002]: + await handler.handle_message(_group_event(sender_id=uid, text="hello")) + + handler.sender.send_group_message.assert_not_called() + + +@pytest.mark.asyncio +async def test_repeat_no_trigger_when_bot_sends_in_middle() -> None: + """用户发到一半,bot 插入相同消息,之后用户再凑满阈值,不应触发复读。""" + handler = _build_handler(repeat_enabled=True) + # 两个用户先发 + await handler.handle_message(_group_event(sender_id=20001, text="hello")) + await handler.handle_message(_group_event(sender_id=20002, text="hello")) + # bot 插入 + await handler.handle_message(_group_event(sender_id=BOT_QQ, text="hello")) + # 第三个用户发:此时窗口 [user2, bot, user3],含 bot → 不触发 + await handler.handle_message(_group_event(sender_id=20003, text="hello")) + + handler.sender.send_group_message.assert_not_called() + + +@pytest.mark.asyncio +async def test_repeat_triggers_after_bot_window_slides_out() -> None: + """bot 消息滑出窗口后,纯用户序列应正常触发复读(threshold=3)。""" + handler = _build_handler(repeat_enabled=True, repeat_threshold=3) + # bot 先发(进入窗口) + await handler.handle_message(_group_event(sender_id=BOT_QQ, text="hello")) + # 三个不同用户依次发:窗口变为 [user1, user2, user3](bot 已滑出) + for uid in [20001, 20002, 20003]: + await handler.handle_message(_group_event(sender_id=uid, text="hello")) + + handler.sender.send_group_message.assert_called_once() + assert handler.sender.send_group_message.call_args.args[1] == "hello" + + +# ── 可配置阈值 ── + + +@pytest.mark.asyncio +async def test_repeat_custom_threshold_2() -> None: + """threshold=2 时,2 条不同发送者相同消息即触发复读。""" + handler = 
_build_handler(repeat_enabled=True, repeat_threshold=2) + for uid in [20001, 20002]: + await handler.handle_message(_group_event(sender_id=uid, text="hi")) + + handler.sender.send_group_message.assert_called_once() + assert handler.sender.send_group_message.call_args.args[1] == "hi" + + +@pytest.mark.asyncio +async def test_repeat_custom_threshold_4() -> None: + """threshold=4 时,3 条不同发送者相同消息不触发,第 4 条才触发。""" + handler = _build_handler(repeat_enabled=True, repeat_threshold=4) + for uid in [20001, 20002, 20003]: + await handler.handle_message(_group_event(sender_id=uid, text="hey")) + handler.sender.send_group_message.assert_not_called() + + await handler.handle_message(_group_event(sender_id=20004, text="hey")) + handler.sender.send_group_message.assert_called_once() + assert handler.sender.send_group_message.call_args.args[1] == "hey" From f6712278a993a7fb7a8fc2433b4fd8373ff84e26 Mon Sep 17 00:00:00 2001 From: Null <1708213363@qq.com> Date: Sat, 18 Apr 2026 15:55:15 +0800 Subject: [PATCH 09/57] =?UTF-8?q?fix(profile,latex):=20=E4=BF=AE=E5=A4=8D?= =?UTF-8?q?=E4=BE=A7=E5=86=99=E5=8F=8C=E5=A4=8D=E6=95=B0=E8=B7=AF=E5=BE=84?= =?UTF-8?q?=E5=92=8CMathJax=E7=AD=89=E5=BE=85=E6=9D=A1=E4=BB=B6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit profile 命令: - 修复 entity_type 双复数 bug ("users"→"user", "groups"→"group") profile_storage._profile_path 会追加 "s", handler 传 "users" 导致 路径变为 "userss/", 永远找不到侧写文件 - 新增 "g" 快捷子命令 (/p g 等同于 /p group) - 更新 config.json 帮助文本和使用说明 LaTeX 渲染: - 修复 MathJax wait_for_function 逻辑: 旧代码返回 Promise 而非 boolean, Playwright 无法正确判断完成状态, 必然超时 - 改用 pageReady 回调设 window._mjReady 标记, wait 检查该标记 - 超时从 15s 增至 30s - 添加 MathJax 配置块支持行内数学 ($...$) 测试: 新增 g 快捷方式和私聊拒绝测试, HTML 模板测试 Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .../skills/commands/profile/config.json | 6 +-- .../skills/commands/profile/handler.py | 8 +-- .../toolsets/render/render_latex/handler.py | 18 +++++-- tests/test_profile_command.py | 
50 +++++++++++++++++-- tests/test_render_latex_tool.py | 11 ++++ 5 files changed, 80 insertions(+), 13 deletions(-) diff --git a/src/Undefined/skills/commands/profile/config.json b/src/Undefined/skills/commands/profile/config.json index 6b47c63..1dd5eb5 100644 --- a/src/Undefined/skills/commands/profile/config.json +++ b/src/Undefined/skills/commands/profile/config.json @@ -1,8 +1,8 @@ { "name": "profile", - "description": "查看用户或群聊侧写", - "usage": "/profile [group]", - "example": "/profile", + "description": "查看侧写。默认显示你的用户侧写;加 g 查看当前群聊侧写(仅群聊可用)", + "usage": "/p [g]", + "example": "/p g", "permission": "public", "rate_limit": { "user": 60, diff --git a/src/Undefined/skills/commands/profile/handler.py b/src/Undefined/skills/commands/profile/handler.py index aba240a..143f3b4 100644 --- a/src/Undefined/skills/commands/profile/handler.py +++ b/src/Undefined/skills/commands/profile/handler.py @@ -31,18 +31,18 @@ async def execute(args: list[str], context: CommandContext) -> None: await _send(context, "❌ 侧写服务未启用") return - # Parse subcommand + # Parse subcommand: "g" or "group" → group profile sub = args[0].lower().strip() if args else "" - if sub == "group": + if sub in ("group", "g"): if _is_private(context): await _send(context, "❌ 私聊中不支持查看群聊侧写") return - entity_type = "groups" + entity_type = "group" entity_id = str(context.group_id) empty_hint = "暂无群聊侧写数据" else: - entity_type = "users" + entity_type = "user" entity_id = str(context.sender_id) empty_hint = "暂无侧写数据" diff --git a/src/Undefined/skills/toolsets/render/render_latex/handler.py b/src/Undefined/skills/toolsets/render/render_latex/handler.py index 01e8ba3..0637b9f 100644 --- a/src/Undefined/skills/toolsets/render/render_latex/handler.py +++ b/src/Undefined/skills/toolsets/render/render_latex/handler.py @@ -63,6 +63,18 @@ def _build_html(latex_content: str) -> str: + +

{safe_title}

+
{safe_body}
+""" + + output_dir = ensure_dir(RENDER_CACHE_DIR) + output_path = str(output_dir / f"profile_{uuid.uuid4().hex[:8]}.png") + + await render_html_to_image(html_content, output_path) + + abs_path = Path(output_path).resolve() + image_cq = f"[CQ:image,file=file://{abs_path}]" + + if _is_private(context): + user_id = int(context.user_id or context.sender_id) + await context.sender.send_private_message(user_id, image_cq) + else: + await context.sender.send_group_message(context.group_id, image_cq) + + +# ── 入口 ───────────────────────────────────────────────────── + + async def execute(args: list[str], context: CommandContext) -> None: - """处理 /profile 命令。""" + """处理 /profile 命令。 + + 用法: /p [g] [-t|--text] [-f|--forward] [-r|--render] + g / group 查看群聊侧写(仅群聊可用) + -t / --text 纯文本直接发出 + -f / --forward 合并转发发出(群聊默认) + -r / --render 渲染为图片发出 + """ cognitive_service = context.cognitive_service if cognitive_service is None: - await _send(context, "❌ 侧写服务未启用") + await _send_text(context, "❌ 侧写服务未启用") return - # Parse subcommand: "g" or "group" → group profile - sub = args[0].lower().strip() if args else "" + sub, mode = _parse_args(args) if sub in ("group", "g"): if _is_private(context): - await _send(context, "❌ 私聊中不支持查看群聊侧写") + await _send_text(context, "❌ 私聊中不支持查看群聊侧写") return entity_type = "group" entity_id = str(context.group_id) + title = "📋 群聊侧写" empty_hint = "暂无群聊侧写数据" else: entity_type = "user" entity_id = str(context.sender_id) + title = "📋 用户侧写" empty_hint = "暂无侧写数据" profile = await cognitive_service.get_profile(entity_type, entity_id) if not profile: - await _send(context, f"📭 {empty_hint}") + await _send_text(context, f"📭 {empty_hint}") return - await _send(context, _truncate(profile)) + profile = _truncate(profile) + + # 私聊始终纯文本 + if _is_private(context): + mode = _MODE_TEXT + + # 未指定模式:群聊默认合并转发 + if not mode: + mode = _MODE_FORWARD + + if mode == _MODE_TEXT: + await _send_text(context, profile) + elif mode == _MODE_RENDER: + try: + await _send_render(context, title, 
profile) + except Exception: + logger.exception("渲染侧写图片失败,回退到纯文本") + await _send_text(context, profile) + else: + try: + await _send_forward(context, title, profile) + except Exception: + logger.exception("发送合并转发失败,回退到纯文本") + await _send_text(context, profile) diff --git a/src/Undefined/utils/fake_at.py b/src/Undefined/utils/fake_at.py new file mode 100644 index 0000000..0ac0200 --- /dev/null +++ b/src/Undefined/utils/fake_at.py @@ -0,0 +1,180 @@ +"""假@检测:识别群聊中 ``@昵称`` 纯文本形式的"假 at"。 + +设计要点: +- ``BotNicknameCache`` 自动通过 OneBot API 获取 bot 在各群的群名片 / QQ 昵称, + 带 TTL 缓存 + per-group asyncio.Lock 防竞态。 +- ``strip_fake_at`` 是无状态纯函数,负责文本匹配与剥离。 +- 匹配规则:半角 ``@`` / 全角 ``@`` + 昵称 (casefold) + 边界(空白/标点/行尾), + 昵称按长度降序匹配以避免短昵称吃掉长昵称的前缀。 +""" + +from __future__ import annotations + +import asyncio +import logging +import re +import time +import unicodedata +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from Undefined.onebot import OneBotClient + +logger = logging.getLogger(__name__) + +# 昵称后必须跟的边界:空白、常见标点、行尾 +_BOUNDARY_RE = re.compile(r"[\s\u3000,.:;!?,。:;!?/\-\]))>》]+|$") + +# 缓存默认 TTL (秒) +_DEFAULT_TTL: float = 600.0 + + +def _normalize(text: str) -> str: + """将全角 @ 统一为半角 @,然后 casefold。""" + return unicodedata.normalize("NFKC", text).casefold() + + +def _sorted_nicknames(names: frozenset[str]) -> tuple[str, ...]: + """按长度降序排列,长昵称优先匹配。""" + return tuple(sorted(names, key=len, reverse=True)) + + +def strip_fake_at( + text: str, + nicknames: frozenset[str], +) -> tuple[bool, str]: + """检测并剥离文本开头的假 @ 前缀。 + + 参数: + text: 原始消息文本(已做 extract_text 处理)。 + nicknames: 当前群 bot 所有有效昵称 (已 casefold)。 + + 返回: + (is_fake_at, stripped_text) + - is_fake_at: 是否命中假 @。 + - stripped_text: 剥离假 @ 后的文本(未命中时返回原文)。 + """ + if not nicknames or not text: + return False, text + + normalized = _normalize(text) + # 以 @ 开头才可能是假 @ + if not normalized.startswith("@"): + return False, text + + # 去掉开头的 @ + after_at = normalized[1:] + raw_after_at = text[1:] # 保持原始大小写的切片 + + for nick in 
_sorted_nicknames(nicknames): + if not after_at.startswith(nick): + continue + # 检查昵称后是否为合法边界 + rest_pos = len(nick) + rest_normalized = after_at[rest_pos:] + if rest_normalized and not _BOUNDARY_RE.match(rest_normalized): + continue + # 命中——用原始文本切出剥离后的内容 + stripped = raw_after_at[rest_pos:].lstrip() + return True, stripped + + return False, text + + +class BotNicknameCache: + """按群缓存 bot 昵称,用于假 @ 检测。 + + 线程安全性: + - 只在单一 asyncio 事件循环中使用。 + - 每个 group_id 使用独立 ``asyncio.Lock``,保证同一群的并发 + 消息不会触发重复 API 请求。 + - ``_global_lock`` 仅保护 ``_locks`` 字典本身的创建。 + """ + + def __init__( + self, + onebot: OneBotClient, + bot_qq: int, + ttl: float = _DEFAULT_TTL, + ) -> None: + self._onebot = onebot + self._bot_qq = bot_qq + self._ttl = ttl + # group_id → (nicknames_frozenset, timestamp) + self._cache: dict[int, tuple[frozenset[str], float]] = {} + # group_id → asyncio.Lock + self._locks: dict[int, asyncio.Lock] = {} + self._global_lock = asyncio.Lock() + + def _get_lock(self, group_id: int) -> asyncio.Lock: + """获取指定群的锁(同步快路径 + 异步慢路径)。 + + 在 asyncio 单线程模型下,dict 读写本身是原子的, + 但为严谨起见仍使用 ``_global_lock`` 保护 ``_locks`` 的创建。 + 注意:此方法不是 coroutine,但会在 _ensure_lock 里 await。 + """ + lock = self._locks.get(group_id) + if lock is not None: + return lock + # 需要创建——由调用方在 _ensure_lock 中持有 _global_lock + lock = asyncio.Lock() + self._locks[group_id] = lock + return lock + + async def _ensure_lock(self, group_id: int) -> asyncio.Lock: + """确保 group lock 存在,如有必要在 global lock 下创建。""" + lock = self._locks.get(group_id) + if lock is not None: + return lock + async with self._global_lock: + return self._get_lock(group_id) + + async def get_nicknames(self, group_id: int) -> frozenset[str]: + """获取 bot 在指定群的所有有效昵称(含手动配置)。 + + 会自动缓存,过期后异步刷新。API 失败时返回上次缓存或仅手动昵称。 + """ + now = time.monotonic() + cached = self._cache.get(group_id) + if cached is not None: + names, ts = cached + if now - ts < self._ttl: + return names + + lock = await self._ensure_lock(group_id) + async with lock: + # 
Double-check:可能在等锁期间已被其他协程刷新 + cached = self._cache.get(group_id) + if cached is not None: + names, ts = cached + if now - ts < self._ttl: + return names + + fetched = await self._fetch(group_id) + self._cache[group_id] = (fetched, time.monotonic()) + return fetched + + async def _fetch(self, group_id: int) -> frozenset[str]: + """从 OneBot API 获取 bot 在指定群的群名片 + QQ 昵称。""" + names: set[str] = set() + try: + info = await self._onebot.get_group_member_info(group_id, self._bot_qq) + if isinstance(info, dict): + for key in ("card", "nickname"): + val = str(info.get(key, "") or "").strip() + if val: + names.add(val.casefold()) + except Exception as exc: + logger.debug( + "[假@] 获取 bot 群成员信息失败: group=%s err=%s", + group_id, + exc, + ) + return frozenset(names) + + def invalidate(self, group_id: int | None = None) -> None: + """手动失效缓存。group_id=None 清空全部。""" + if group_id is None: + self._cache.clear() + else: + self._cache.pop(group_id, None) diff --git a/tests/test_handlers_repeat.py b/tests/test_handlers_repeat.py index 2bb6aef..8d508de 100644 --- a/tests/test_handlers_repeat.py +++ b/tests/test_handlers_repeat.py @@ -70,6 +70,9 @@ def _build_handler( handler._repeat_counter = {} handler._repeat_locks = {} handler._profile_name_refresh_cache = {} + handler._bot_nickname_cache = SimpleNamespace( + get_nicknames=AsyncMock(return_value=frozenset()), + ) return handler From 1595d888b8aa55b30a54c89d410b8d6cc306c321 Mon Sep 17 00:00:00 2001 From: Null <1708213363@qq.com> Date: Sat, 18 Apr 2026 22:21:32 +0800 Subject: [PATCH 15/57] =?UTF-8?q?refactor(profile):=20=E4=BC=98=E5=8C=96?= =?UTF-8?q?=E5=90=88=E5=B9=B6=E8=BD=AC=E5=8F=91=E5=92=8C=E6=B8=B2=E6=9F=93?= =?UTF-8?q?=E8=BE=93=E5=87=BA=E6=A0=BC=E5=BC=8F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 移除 emoji 标题,合并转发改为 2 节点:元数据 + 完整侧写 - 不再截断分消息发送 - 渲染模式:卡片式布局,渐变色元数据头 + 正文区 - 元数据含类型/ID/字数/更新时间 Co-authored-by: Claude Opus 4.6 Co-authored-by: Copilot 
<223556219+Copilot@users.noreply.github.com> --- .../skills/commands/profile/handler.py | 156 +++++++++++------- 1 file changed, 98 insertions(+), 58 deletions(-) diff --git a/src/Undefined/skills/commands/profile/handler.py b/src/Undefined/skills/commands/profile/handler.py index 6eb0428..5eabaa2 100644 --- a/src/Undefined/skills/commands/profile/handler.py +++ b/src/Undefined/skills/commands/profile/handler.py @@ -3,18 +3,20 @@ import html import logging import uuid +from datetime import datetime, timezone from pathlib import Path from typing import Any from Undefined.services.commands.context import CommandContext -from Undefined.utils.paths import RENDER_CACHE_DIR, ensure_dir +from Undefined.utils.paths import COGNITIVE_PROFILES_DIR, RENDER_CACHE_DIR, ensure_dir logger = logging.getLogger("profile") _MAX_PROFILE_LENGTH = 3000 -# 合并转发单条消息最大字符数(过长拆分) -_FORWARD_NODE_LIMIT = 2000 +_MODE_TEXT = "text" +_MODE_FORWARD = "forward" +_MODE_RENDER = "render" def _is_private(context: CommandContext) -> bool: @@ -27,24 +29,10 @@ def _truncate(text: str, limit: int = _MAX_PROFILE_LENGTH) -> str: return text[:limit].rstrip() + "\n\n[侧写过长,已截断]" -# ── 输出模式枚举 ────────────────────────────────────────────── - -_MODE_TEXT = "text" -_MODE_FORWARD = "forward" -_MODE_RENDER = "render" - - -def _parse_args( - args: list[str], -) -> tuple[str, str]: - """解析参数,返回 (子命令, 输出模式)。 - - 子命令: "" | "g" | "group" - 输出模式: "text" | "forward" | "render" - """ +def _parse_args(args: list[str]) -> tuple[str, str]: + """解析参数,返回 (子命令, 输出模式)。""" sub = "" mode = "" - for arg in args: lower = arg.lower().strip() if lower in ("-t", "--text"): @@ -55,11 +43,38 @@ def _parse_args( mode = _MODE_RENDER elif lower in ("g", "group"): sub = lower - # 忽略无法识别的参数 - return sub, mode +def _profile_mtime(entity_type: str, entity_id: str) -> str | None: + """读取侧写文件最后修改时间,返回人类可读字符串。""" + p = COGNITIVE_PROFILES_DIR / f"{entity_type}s" / f"{entity_id}.md" + try: + mtime = p.stat().st_mtime + dt = 
datetime.fromtimestamp(mtime, tz=timezone.utc).astimezone() + return dt.strftime("%Y-%m-%d %H:%M") + except OSError: + return None + + +def _build_metadata( + entity_type: str, + entity_id: str, + profile_len: int, +) -> str: + """构建元数据摘要文本。""" + type_label = "用户" if entity_type == "user" else "群聊" + lines = [ + f"类型: {type_label}侧写", + f"ID: {entity_id}", + f"长度: {profile_len} 字", + ] + mtime = _profile_mtime(entity_type, entity_id) + if mtime: + lines.append(f"更新: {mtime}") + return "\n".join(lines) + + # ── 发送方法 ────────────────────────────────────────────────── @@ -72,57 +87,83 @@ async def _send_text(context: CommandContext, text: str) -> None: await context.sender.send_group_message(context.group_id, text) -async def _send_forward(context: CommandContext, title: str, profile_text: str) -> None: - """合并转发发送。""" +async def _send_forward( + context: CommandContext, + metadata: str, + profile_text: str, +) -> None: + """合并转发:节点1=元数据,节点2=完整侧写内容。""" bot_qq = str(getattr(context.config, "bot_qq", 0)) - nodes: list[dict[str, Any]] = [] - def _add_node(content: str) -> None: - nodes.append( - { - "type": "node", - "data": {"name": "Undefined", "uin": bot_qq, "content": content}, - } - ) - - _add_node(title) - - # 按长度拆分成多个节点 - remaining = profile_text - while remaining: - chunk = remaining[:_FORWARD_NODE_LIMIT] - remaining = remaining[_FORWARD_NODE_LIMIT:] - _add_node(chunk) + def _node(content: str) -> dict[str, Any]: + return { + "type": "node", + "data": {"name": "Undefined", "uin": bot_qq, "content": content}, + } + nodes = [_node(metadata), _node(profile_text)] await context.onebot.send_forward_msg(context.group_id, nodes) -async def _send_render(context: CommandContext, title: str, profile_text: str) -> None: - """渲染为图片发送。""" +async def _send_render( + context: CommandContext, + metadata: str, + profile_text: str, +) -> None: + """渲染为图片发送——元数据区 + 侧写正文区。""" from Undefined.render import render_html_to_image - safe_title = html.escape(title) + safe_meta = 
html.escape(metadata) safe_body = html.escape(profile_text) + + # 将元数据行拆分为 key: value 形式渲染 + meta_rows = "" + for line in safe_meta.split("\n"): + if ": " in line: + key, _, val = line.partition(": ") + meta_rows += ( + f'{key}{val}\n' + ) + html_content = f""" -

{safe_title}

-
{safe_body}
+.meta .mv {{ font-size: 13px; padding: 3px 0; }} +.body {{ + padding: 24px; line-height: 1.8; font-size: 14px; + white-space: pre-wrap; word-wrap: break-word; +}} + + +
+
{meta_rows}
+
{safe_body}
+
""" output_dir = ensure_dir(RENDER_CACHE_DIR) output_path = str(output_dir / f"profile_{uuid.uuid4().hex[:8]}.png") - await render_html_to_image(html_content, output_path) abs_path = Path(output_path).resolve() @@ -160,12 +201,10 @@ async def execute(args: list[str], context: CommandContext) -> None: return entity_type = "group" entity_id = str(context.group_id) - title = "📋 群聊侧写" empty_hint = "暂无群聊侧写数据" else: entity_type = "user" entity_id = str(context.sender_id) - title = "📋 用户侧写" empty_hint = "暂无侧写数据" profile = await cognitive_service.get_profile(entity_type, entity_id) @@ -174,6 +213,7 @@ async def execute(args: list[str], context: CommandContext) -> None: return profile = _truncate(profile) + metadata = _build_metadata(entity_type, entity_id, len(profile)) # 私聊始终纯文本 if _is_private(context): @@ -187,13 +227,13 @@ async def execute(args: list[str], context: CommandContext) -> None: await _send_text(context, profile) elif mode == _MODE_RENDER: try: - await _send_render(context, title, profile) + await _send_render(context, metadata, profile) except Exception: logger.exception("渲染侧写图片失败,回退到纯文本") await _send_text(context, profile) else: try: - await _send_forward(context, title, profile) + await _send_forward(context, metadata, profile) except Exception: logger.exception("发送合并转发失败,回退到纯文本") await _send_text(context, profile) From a96e369773deacbd78580dca07394260d8514493 Mon Sep 17 00:00:00 2001 From: Null <1708213363@qq.com> Date: Sat, 18 Apr 2026 22:25:37 +0800 Subject: [PATCH 16/57] =?UTF-8?q?fix(profile):=20=E4=BD=BF=E7=94=A8=20WebU?= =?UTF-8?q?I=20=E9=85=8D=E8=89=B2=E3=80=81=E6=8F=90=E9=AB=98=E6=88=AA?= =?UTF-8?q?=E6=96=AD=E4=B8=8A=E9=99=90=E8=87=B3=205000?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 渲染样式改为 WebUI 暖色调 (#f9f5f1/#e6e0d8/#3d3935) - 截断上限从 3000 提高到 5000 - 更新对应测试用例 Co-authored-by: Claude Opus 4.6 Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- 
.../skills/commands/profile/handler.py | 31 +++++++++---------- tests/test_profile_command.py | 8 ++--- 2 files changed, 19 insertions(+), 20 deletions(-) diff --git a/src/Undefined/skills/commands/profile/handler.py b/src/Undefined/skills/commands/profile/handler.py index 5eabaa2..6d6f473 100644 --- a/src/Undefined/skills/commands/profile/handler.py +++ b/src/Undefined/skills/commands/profile/handler.py @@ -12,7 +12,7 @@ logger = logging.getLogger("profile") -_MAX_PROFILE_LENGTH = 3000 +_MAX_PROFILE_LENGTH = 5000 _MODE_TEXT = "text" _MODE_FORWARD = "forward" @@ -116,7 +116,6 @@ async def _send_render( safe_meta = html.escape(metadata) safe_body = html.escape(profile_text) - # 将元数据行拆分为 key: value 形式渲染 meta_rows = "" for line in safe_meta.split("\n"): if ": " in line: @@ -129,29 +128,29 @@ async def _send_render( diff --git a/tests/test_profile_command.py b/tests/test_profile_command.py index fc01cb8..fe7a105 100644 --- a/tests/test_profile_command.py +++ b/tests/test_profile_command.py @@ -273,10 +273,10 @@ async def test_profile_no_cognitive_service() -> None: @pytest.mark.asyncio async def test_profile_truncation() -> None: - """Profile > 3000 chars gets truncated.""" + """Profile > 5000 chars gets truncated.""" sender = _DummySender() cognitive_service = AsyncMock() - long_profile = "A" * 3500 # Longer than 3000 chars + long_profile = "A" * 5500 # Longer than 5000 chars cognitive_service.get_profile = AsyncMock(return_value=long_profile) context = _build_context( @@ -291,6 +291,6 @@ async def test_profile_truncation() -> None: assert len(sender.group_messages) == 1 message = sender.group_messages[0][1] - assert len(message) <= 3100 # 3000 + truncation notice + assert len(message) <= 5100 # 5000 + truncation notice assert "[侧写过长,已截断]" in message - assert message.count("A") == 3000 # Exactly 3000 'A's before truncation + assert message.count("A") == 5000 # Exactly 5000 'A's before truncation From f3a802ccbd885c241fbb3b45e945c2282bab7409 Mon Sep 17 00:00:00 2001 
From: Null <1708213363@qq.com> Date: Sat, 18 Apr 2026 22:53:28 +0800 Subject: [PATCH 17/57] =?UTF-8?q?feat(agents):=20=E6=96=B0=E5=A2=9E=20arXi?= =?UTF-8?q?v=20=E8=AE=BA=E6=96=87=E6=B7=B1=E5=BA=A6=E5=88=86=E6=9E=90=20ag?= =?UTF-8?q?ent?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 新增 arxiv_analysis_agent: 下载 arXiv PDF 全文进行结构化学术分析 - 分页读取设计: fetch_paper 获取元数据, read_paper_pages 分批读取避免 token 溢出 - 复用 Undefined.arxiv 模块 (client/downloader/parser) - 更新 web_agent/callable.json 添加 summary_agent 和 arxiv_analysis_agent - 新增 18 个单元测试覆盖 handler + tools + config 结构 Co-authored-by: Claude Opus 4.6 Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .../agents/arxiv_analysis_agent/__init__.py | 0 .../agents/arxiv_analysis_agent/callable.json | 4 + .../agents/arxiv_analysis_agent/config.json | 21 ++ .../agents/arxiv_analysis_agent/handler.py | 48 +++ .../agents/arxiv_analysis_agent/intro.md | 1 + .../agents/arxiv_analysis_agent/prompt.md | 28 ++ .../arxiv_analysis_agent/tools/__init__.py | 0 .../tools/fetch_paper/config.json | 17 + .../tools/fetch_paper/handler.py | 75 ++++ .../tools/read_paper_pages/config.json | 21 ++ .../tools/read_paper_pages/handler.py | 68 ++++ .../skills/agents/web_agent/callable.json | 2 +- tests/test_arxiv_analysis_agent.py | 351 ++++++++++++++++++ 13 files changed, 635 insertions(+), 1 deletion(-) create mode 100644 src/Undefined/skills/agents/arxiv_analysis_agent/__init__.py create mode 100644 src/Undefined/skills/agents/arxiv_analysis_agent/callable.json create mode 100644 src/Undefined/skills/agents/arxiv_analysis_agent/config.json create mode 100644 src/Undefined/skills/agents/arxiv_analysis_agent/handler.py create mode 100644 src/Undefined/skills/agents/arxiv_analysis_agent/intro.md create mode 100644 src/Undefined/skills/agents/arxiv_analysis_agent/prompt.md create mode 100644 src/Undefined/skills/agents/arxiv_analysis_agent/tools/__init__.py create mode 100644 
src/Undefined/skills/agents/arxiv_analysis_agent/tools/fetch_paper/config.json create mode 100644 src/Undefined/skills/agents/arxiv_analysis_agent/tools/fetch_paper/handler.py create mode 100644 src/Undefined/skills/agents/arxiv_analysis_agent/tools/read_paper_pages/config.json create mode 100644 src/Undefined/skills/agents/arxiv_analysis_agent/tools/read_paper_pages/handler.py create mode 100644 tests/test_arxiv_analysis_agent.py diff --git a/src/Undefined/skills/agents/arxiv_analysis_agent/__init__.py b/src/Undefined/skills/agents/arxiv_analysis_agent/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/Undefined/skills/agents/arxiv_analysis_agent/callable.json b/src/Undefined/skills/agents/arxiv_analysis_agent/callable.json new file mode 100644 index 0000000..855776f --- /dev/null +++ b/src/Undefined/skills/agents/arxiv_analysis_agent/callable.json @@ -0,0 +1,4 @@ +{ + "enabled": true, + "allowed_callers": ["info_agent", "web_agent", "naga_code_analysis_agent"] +} diff --git a/src/Undefined/skills/agents/arxiv_analysis_agent/config.json b/src/Undefined/skills/agents/arxiv_analysis_agent/config.json new file mode 100644 index 0000000..a97b0a9 --- /dev/null +++ b/src/Undefined/skills/agents/arxiv_analysis_agent/config.json @@ -0,0 +1,21 @@ +{ + "type": "function", + "function": { + "name": "arxiv_analysis_agent", + "description": "arXiv 论文深度分析助手,下载并解析 arXiv 论文 PDF 全文,进行结构化学术深度分析。支持按 arXiv ID 或 URL 获取论文。", + "parameters": { + "type": "object", + "properties": { + "paper_id": { + "type": "string", + "description": "arXiv 论文 ID(如 '2301.07041')、arXiv URL(如 'https://arxiv.org/abs/2301.07041')或 'arXiv:2301.07041' 格式" + }, + "prompt": { + "type": "string", + "description": "可选的分析需求,例如:'重点分析方法论和实验设计'、'关注与 diffusion model 的对比'" + } + }, + "required": ["paper_id"] + } + } +} diff --git a/src/Undefined/skills/agents/arxiv_analysis_agent/handler.py b/src/Undefined/skills/agents/arxiv_analysis_agent/handler.py new file mode 100644 index 0000000..7b6d605 --- 
/dev/null +++ b/src/Undefined/skills/agents/arxiv_analysis_agent/handler.py @@ -0,0 +1,48 @@ +from __future__ import annotations + +import logging +from pathlib import Path +from typing import Any + +from Undefined.arxiv.parser import normalize_arxiv_id +from Undefined.skills.agents.runner import run_agent_with_tools + +logger = logging.getLogger(__name__) + + +async def execute(args: dict[str, Any], context: dict[str, Any]) -> str: + """执行 arxiv_analysis_agent。""" + + raw_paper_id = str(args.get("paper_id", "")).strip() + user_prompt = str(args.get("prompt", "")).strip() + + if not raw_paper_id: + return "请提供 arXiv 论文 ID 或 URL" + + paper_id = normalize_arxiv_id(raw_paper_id) + if paper_id is None: + return f"无法解析 arXiv 标识:{raw_paper_id}" + + context["arxiv_paper_id"] = paper_id + + context_messages = [ + { + "role": "system", + "content": f"当前任务:深度分析 arXiv 论文 {paper_id}", + } + ] + + user_content = user_prompt if user_prompt else f"请深度分析论文 arXiv:{paper_id}" + + return await run_agent_with_tools( + agent_name="arxiv_analysis_agent", + user_content=user_content, + context_messages=context_messages, + empty_user_content_message="请提供 arXiv 论文 ID 或 URL", + default_prompt="你是一个学术论文深度分析助手。", + context=context, + agent_dir=Path(__file__).parent, + logger=logger, + max_iterations=15, + tool_error_prefix="错误", + ) diff --git a/src/Undefined/skills/agents/arxiv_analysis_agent/intro.md b/src/Undefined/skills/agents/arxiv_analysis_agent/intro.md new file mode 100644 index 0000000..1b306e3 --- /dev/null +++ b/src/Undefined/skills/agents/arxiv_analysis_agent/intro.md @@ -0,0 +1 @@ +arXiv 论文深度分析助手:下载 arXiv 论文 PDF 全文并进行结构化学术深度分析,涵盖方法论、实验、创新点、局限性等维度。 diff --git a/src/Undefined/skills/agents/arxiv_analysis_agent/prompt.md b/src/Undefined/skills/agents/arxiv_analysis_agent/prompt.md new file mode 100644 index 0000000..029ab1e --- /dev/null +++ b/src/Undefined/skills/agents/arxiv_analysis_agent/prompt.md @@ -0,0 +1,28 @@ +你是学术论文深度分析助手,专门对 arXiv 论文进行全面、结构化的学术分析。 + +工作流程: +1. 
先调用 `fetch_paper` 获取论文元数据和摘要 +2. 调用 `read_paper_pages` 分批阅读论文全文(每次读取一定页数范围) +3. 读完后产出结构化深度分析 + +阅读策略: +- 先通过 `fetch_paper` 了解总页数和摘要 +- 用 `read_paper_pages` 按区间阅读(如 1-5、6-10、11-15 等),每次读 5 页 +- 对于较长的论文(>20 页),可以选择性跳过附录/参考文献,集中精力分析正文 +- 短论文可以一次性读完 + +分析输出结构: +- **概要**:一句话总结论文核心贡献 +- **研究背景与动机**:论文要解决什么问题、为什么重要 +- **方法论**:核心技术方案、算法、模型架构的详细解析 +- **实验与结果**:实验设置、基准对比、主要结论 +- **创新点与贡献**:论文的主要新颖之处 +- **局限性与未来方向**:作者提到的或你分析出的不足和可改进之处 +- **总评**:论文的整体质量和影响力评估 + +注意事项: +- 保持学术严谨,用客观语言描述 +- 如果用户指定了分析侧重(prompt),要重点展开相关部分 +- 对公式和算法用自然语言解释,不要直接复制 LaTeX +- PDF 提取可能有格式问题(乱码/缺失),遇到时说明情况继续分析 +- 如果论文是中文,用中文输出分析;否则用中文输出分析但保留关键术语英文原文 diff --git a/src/Undefined/skills/agents/arxiv_analysis_agent/tools/__init__.py b/src/Undefined/skills/agents/arxiv_analysis_agent/tools/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/Undefined/skills/agents/arxiv_analysis_agent/tools/fetch_paper/config.json b/src/Undefined/skills/agents/arxiv_analysis_agent/tools/fetch_paper/config.json new file mode 100644 index 0000000..5e69a69 --- /dev/null +++ b/src/Undefined/skills/agents/arxiv_analysis_agent/tools/fetch_paper/config.json @@ -0,0 +1,17 @@ +{ + "type": "function", + "function": { + "name": "fetch_paper", + "description": "获取 arXiv 论文元数据并下载 PDF,返回论文基本信息(标题、作者、摘要、页数)。调用后可用 read_paper_pages 分页阅读正文。", + "parameters": { + "type": "object", + "properties": { + "paper_id": { + "type": "string", + "description": "arXiv 论文 ID,如 '2301.07041'" + } + }, + "required": ["paper_id"] + } + } +} diff --git a/src/Undefined/skills/agents/arxiv_analysis_agent/tools/fetch_paper/handler.py b/src/Undefined/skills/agents/arxiv_analysis_agent/tools/fetch_paper/handler.py new file mode 100644 index 0000000..d75cfcd --- /dev/null +++ b/src/Undefined/skills/agents/arxiv_analysis_agent/tools/fetch_paper/handler.py @@ -0,0 +1,75 @@ +"""获取 arXiv 论文元数据并下载 PDF 到本地缓存。""" + +from __future__ import annotations + +import logging +from typing import Any + +import fitz + +from Undefined.arxiv.client import 
get_paper_info +from Undefined.arxiv.downloader import download_paper_pdf +from Undefined.arxiv.parser import normalize_arxiv_id + +logger = logging.getLogger(__name__) + +_MAX_FILE_SIZE_MB = 50 + + +async def execute(args: dict[str, Any], context: dict[str, Any]) -> str: + raw_id = str(args.get("paper_id", "")).strip() + if not raw_id: + return "错误:请提供 arXiv 论文 ID" + + paper_id = normalize_arxiv_id(raw_id) or raw_id + + request_context = {"request_id": context.get("request_id", "-")} + + try: + paper = await get_paper_info(paper_id, context=request_context) + except Exception as exc: + logger.exception("[arxiv_analysis] 获取论文元数据失败: %s", exc) + return f"错误:获取论文 {paper_id} 元数据失败 — {exc}" + + if paper is None: + return f"错误:未找到论文 arXiv:{paper_id}" + + lines: list[str] = [ + f"论文: {paper.title}", + f"ID: {paper.paper_id}", + f"作者: {'、'.join(paper.authors[:10])}{'(等 ' + str(len(paper.authors)) + ' 位)' if len(paper.authors) > 10 else ''}", + f"分类: {paper.primary_category}", + f"发布: {paper.published[:10]}", + f"更新: {paper.updated[:10]}", + f"链接: {paper.abs_url}", + f"\n摘要:\n{paper.summary}", + ] + + try: + result, task_dir = await download_paper_pdf( + paper, max_file_size_mb=_MAX_FILE_SIZE_MB, context=request_context + ) + except Exception as exc: + logger.exception("[arxiv_analysis] PDF 下载失败: %s", exc) + lines.append(f"\nPDF 下载失败: {exc}(可基于摘要进行分析)") + return "\n".join(lines) + + if result.path is None: + lines.append(f"\nPDF 不可用(状态: {result.status}),请基于摘要进行分析") + return "\n".join(lines) + + try: + doc = fitz.open(str(result.path)) + try: + page_count = len(doc) + lines.append(f"\nPDF 已下载: {page_count} 页") + context["_arxiv_pdf_path"] = str(result.path) + context["_arxiv_pdf_pages"] = page_count + context["_arxiv_task_dir"] = str(task_dir) + finally: + doc.close() + except Exception as exc: + logger.exception("[arxiv_analysis] PDF 打开失败: %s", exc) + lines.append(f"\nPDF 无法打开: {exc}(可基于摘要进行分析)") + + return "\n".join(lines) diff --git 
a/src/Undefined/skills/agents/arxiv_analysis_agent/tools/read_paper_pages/config.json b/src/Undefined/skills/agents/arxiv_analysis_agent/tools/read_paper_pages/config.json new file mode 100644 index 0000000..abc268e --- /dev/null +++ b/src/Undefined/skills/agents/arxiv_analysis_agent/tools/read_paper_pages/config.json @@ -0,0 +1,21 @@ +{ + "type": "function", + "function": { + "name": "read_paper_pages", + "description": "读取已下载论文 PDF 的指定页范围文本。必须先调用 fetch_paper 下载论文。", + "parameters": { + "type": "object", + "properties": { + "start_page": { + "type": "integer", + "description": "起始页码(从 1 开始)" + }, + "end_page": { + "type": "integer", + "description": "结束页码(包含该页)" + } + }, + "required": ["start_page", "end_page"] + } + } +} diff --git a/src/Undefined/skills/agents/arxiv_analysis_agent/tools/read_paper_pages/handler.py b/src/Undefined/skills/agents/arxiv_analysis_agent/tools/read_paper_pages/handler.py new file mode 100644 index 0000000..ed013d1 --- /dev/null +++ b/src/Undefined/skills/agents/arxiv_analysis_agent/tools/read_paper_pages/handler.py @@ -0,0 +1,68 @@ +"""分页读取已下载的 arXiv 论文 PDF 文本。""" + +from __future__ import annotations + +import logging +from typing import Any + +import fitz + +logger = logging.getLogger(__name__) + +_MAX_CHARS_PER_READ = 15000 + + +async def execute(args: dict[str, Any], context: dict[str, Any]) -> str: + pdf_path = context.get("_arxiv_pdf_path") + total_pages = context.get("_arxiv_pdf_pages") + + if not pdf_path: + return "错误:请先调用 fetch_paper 下载论文" + + try: + start_page = int(args.get("start_page", 1)) + end_page = int(args.get("end_page", start_page)) + except (TypeError, ValueError): + return "错误:页码必须为整数" + + if start_page < 1: + start_page = 1 + if total_pages and end_page > total_pages: + end_page = total_pages + if start_page > end_page: + return f"错误:起始页 {start_page} 大于结束页 {end_page}" + + try: + doc = fitz.open(pdf_path) + except Exception as exc: + logger.exception("[arxiv_analysis] PDF 打开失败: %s", exc) + return f"错误:PDF 文件无法打开 
— {exc}" + + try: + actual_pages = len(doc) + if start_page > actual_pages: + return f"错误:论文共 {actual_pages} 页,请求的起始页 {start_page} 超出范围" + + end_page = min(end_page, actual_pages) + text_parts: list[str] = [] + total_chars = 0 + + for page_num in range(start_page - 1, end_page): + page = doc.load_page(page_num) + raw_text = page.get_text() + page_text = str(raw_text) if raw_text else "" + + if total_chars + len(page_text) > _MAX_CHARS_PER_READ and text_parts: + text_parts.append( + f"\n[第 {page_num + 1} 页起文本已截断,请用更小的页范围重新读取]" + ) + break + + text_parts.append(f"--- 第 {page_num + 1} 页 ---") + text_parts.append(page_text if page_text.strip() else "(此页无可提取文本)") + total_chars += len(page_text) + + header = f"论文内容(第 {start_page}-{end_page} 页,共 {actual_pages} 页)" + return f"{header}\n\n" + "\n".join(text_parts) + finally: + doc.close() diff --git a/src/Undefined/skills/agents/web_agent/callable.json b/src/Undefined/skills/agents/web_agent/callable.json index ab8b02a..bc537f2 100644 --- a/src/Undefined/skills/agents/web_agent/callable.json +++ b/src/Undefined/skills/agents/web_agent/callable.json @@ -1,4 +1,4 @@ { "enabled": true, - "allowed_callers": ["naga_code_analysis_agent", "code_delivery_agent", "info_agent"] + "allowed_callers": ["naga_code_analysis_agent", "code_delivery_agent", "info_agent", "summary_agent", "arxiv_analysis_agent"] } diff --git a/tests/test_arxiv_analysis_agent.py b/tests/test_arxiv_analysis_agent.py new file mode 100644 index 0000000..82f8c53 --- /dev/null +++ b/tests/test_arxiv_analysis_agent.py @@ -0,0 +1,351 @@ +"""arxiv_analysis_agent 单元测试。""" + +from __future__ import annotations + +import json +from pathlib import Path +from typing import Any +from unittest.mock import AsyncMock, patch + +import pytest + + +# --------------------------------------------------------------------------- +# config.json / callable.json 结构检查 +# --------------------------------------------------------------------------- + +AGENT_DIR = ( + 
# Root of the agent package under test, resolved relative to this test file.
AGENT_DIR = (
    Path(__file__).resolve().parent.parent
    / "src"
    / "Undefined"
    / "skills"
    / "agents"
    / "arxiv_analysis_agent"
)


def _load_json(name: str) -> dict[str, Any]:
    """Load a JSON file from the agent directory as a dict."""
    result: dict[str, Any] = json.loads((AGENT_DIR / name).read_text(encoding="utf-8"))
    return result


class TestAgentConfig:
    """Structural checks on the agent's config.json / callable.json / layout."""

    def test_config_json_schema(self) -> None:
        # The tool schema must declare paper_id as a required property.
        cfg = _load_json("config.json")
        assert cfg["type"] == "function"
        func = cfg["function"]
        assert func["name"] == "arxiv_analysis_agent"
        assert "paper_id" in func["parameters"]["properties"]
        assert "paper_id" in func["parameters"]["required"]

    def test_callable_json(self) -> None:
        cfg = _load_json("callable.json")
        assert cfg["enabled"] is True
        assert isinstance(cfg["allowed_callers"], list)
        assert len(cfg["allowed_callers"]) > 0

    def test_tools_exist(self) -> None:
        # Both sub-tools need a config.json + handler.py pair on disk.
        tools_dir = AGENT_DIR / "tools"
        assert (tools_dir / "fetch_paper" / "config.json").exists()
        assert (tools_dir / "fetch_paper" / "handler.py").exists()
        assert (tools_dir / "read_paper_pages" / "config.json").exists()
        assert (tools_dir / "read_paper_pages" / "handler.py").exists()

    def test_prompt_md_exists(self) -> None:
        assert (AGENT_DIR / "prompt.md").exists()
        content = (AGENT_DIR / "prompt.md").read_text(encoding="utf-8")
        assert len(content) > 100


# ---------------------------------------------------------------------------
# handler.py
# ---------------------------------------------------------------------------


class TestHandler:
    """Behavior of the top-level agent entry point (handler.execute)."""

    @pytest.mark.asyncio
    async def test_empty_paper_id(self) -> None:
        from Undefined.skills.agents.arxiv_analysis_agent.handler import execute

        result = await execute({"paper_id": ""}, {})
        assert "请提供" in result

    @pytest.mark.asyncio
    async def test_invalid_paper_id(self) -> None:
        from Undefined.skills.agents.arxiv_analysis_agent.handler import execute

        result = await execute({"paper_id": "not-a-valid-id"}, {})
        assert "无法解析" in result

    @pytest.mark.asyncio
    async def test_valid_paper_id_calls_runner(self) -> None:
        from Undefined.skills.agents.arxiv_analysis_agent.handler import execute

        # The agent runner is mocked; we only verify dispatch and arguments.
        mock_result = "分析结果"
        with patch(
            "Undefined.skills.agents.arxiv_analysis_agent.handler.run_agent_with_tools",
            new_callable=AsyncMock,
            return_value=mock_result,
        ) as mock_run:
            result = await execute(
                {"paper_id": "2301.07041", "prompt": "分析方法论"},
                {"request_id": "test-123"},
            )
            assert result == mock_result
            mock_run.assert_awaited_once()
            call_kwargs = mock_run.call_args[1]
            assert call_kwargs["agent_name"] == "arxiv_analysis_agent"
            assert "分析方法论" in call_kwargs["user_content"]

    @pytest.mark.asyncio
    async def test_url_input_normalized(self) -> None:
        from Undefined.skills.agents.arxiv_analysis_agent.handler import execute

        # An abs URL should be normalized to the bare arXiv ID in context.
        with patch(
            "Undefined.skills.agents.arxiv_analysis_agent.handler.run_agent_with_tools",
            new_callable=AsyncMock,
            return_value="ok",
        ):
            ctx: dict[str, Any] = {}
            await execute(
                {"paper_id": "https://arxiv.org/abs/2301.07041"},
                ctx,
            )
            assert ctx["arxiv_paper_id"] == "2301.07041"


# ---------------------------------------------------------------------------
# fetch_paper tool
# ---------------------------------------------------------------------------


def _make_paper_info(
    paper_id: str = "2301.07041",
) -> Any:
    """Build a minimal PaperInfo fixture for fetch_paper tests."""
    from Undefined.arxiv.models import PaperInfo

    return PaperInfo(
        paper_id=paper_id,
        title="Test Paper Title",
        authors=("Author A", "Author B"),
        summary="This is a test abstract.",
        published="2023-01-17T00:00:00Z",
        updated="2023-01-18T00:00:00Z",
        primary_category="cs.AI",
        abs_url=f"https://arxiv.org/abs/{paper_id}",
        pdf_url=f"https://arxiv.org/pdf/{paper_id}.pdf",
    )


class TestFetchPaper:
    """fetch_paper tool: metadata retrieval, download fallback, PDF registration."""

    @pytest.mark.asyncio
    async def test_empty_paper_id(self) -> None:
        from Undefined.skills.agents.arxiv_analysis_agent.tools.fetch_paper.handler import (
            execute,
        )

        result = await execute({"paper_id": ""}, {})
        assert "错误" in result

    @pytest.mark.asyncio
    async def test_metadata_only_on_download_failure(self) -> None:
        from Undefined.skills.agents.arxiv_analysis_agent.tools.fetch_paper.handler import (
            execute,
        )

        # A download failure must still yield the metadata report, with a note.
        paper = _make_paper_info()
        with (
            patch(
                "Undefined.skills.agents.arxiv_analysis_agent.tools.fetch_paper.handler.get_paper_info",
                new_callable=AsyncMock,
                return_value=paper,
            ),
            patch(
                "Undefined.skills.agents.arxiv_analysis_agent.tools.fetch_paper.handler.download_paper_pdf",
                new_callable=AsyncMock,
                side_effect=RuntimeError("network error"),
            ),
        ):
            result = await execute({"paper_id": "2301.07041"}, {})
            assert "Test Paper Title" in result
            assert "Author A" in result
            assert "PDF 下载失败" in result

    @pytest.mark.asyncio
    async def test_paper_not_found(self) -> None:
        from Undefined.skills.agents.arxiv_analysis_agent.tools.fetch_paper.handler import (
            execute,
        )

        with patch(
            "Undefined.skills.agents.arxiv_analysis_agent.tools.fetch_paper.handler.get_paper_info",
            new_callable=AsyncMock,
            return_value=None,
        ):
            result = await execute({"paper_id": "9999.99999"}, {})
            assert "未找到" in result

    @pytest.mark.asyncio
    async def test_successful_fetch_with_pdf(self, tmp_path: Path) -> None:
        from Undefined.skills.agents.arxiv_analysis_agent.tools.fetch_paper.handler import (
            execute,
        )

        paper = _make_paper_info()
        pdf_path = tmp_path / "test.pdf"

        # Create a real one-page PDF on disk so the handler can open it.
        import fitz

        doc = fitz.open()
        page = doc.new_page()
        page.insert_text((72, 72), "Hello world")
        doc.save(str(pdf_path))
        doc.close()

        from Undefined.arxiv.downloader import PaperDownloadResult

        download_result = PaperDownloadResult(
            path=pdf_path, size_bytes=1024, status="downloaded"
        )

        with (
            patch(
                "Undefined.skills.agents.arxiv_analysis_agent.tools.fetch_paper.handler.get_paper_info",
                new_callable=AsyncMock,
                return_value=paper,
            ),
            patch(
                "Undefined.skills.agents.arxiv_analysis_agent.tools.fetch_paper.handler.download_paper_pdf",
                new_callable=AsyncMock,
                return_value=(download_result, tmp_path),
            ),
        ):
            ctx: dict[str, Any] = {}
            result = await execute({"paper_id": "2301.07041"}, ctx)
            assert "Test Paper Title" in result
            assert "1 页" in result
            # Context keys consumed later by read_paper_pages.
            assert ctx["_arxiv_pdf_path"] == str(pdf_path)
            assert ctx["_arxiv_pdf_pages"] == 1


# ---------------------------------------------------------------------------
# read_paper_pages tool
# ---------------------------------------------------------------------------


class TestReadPaperPages:
    """read_paper_pages tool: page-range extraction and validation."""

    @pytest.mark.asyncio
    async def test_no_pdf_downloaded(self) -> None:
        from Undefined.skills.agents.arxiv_analysis_agent.tools.read_paper_pages.handler import (
            execute,
        )

        result = await execute({"start_page": 1, "end_page": 1}, {})
        assert "先调用 fetch_paper" in result

    @pytest.mark.asyncio
    async def test_read_single_page(self, tmp_path: Path) -> None:
        from Undefined.skills.agents.arxiv_analysis_agent.tools.read_paper_pages.handler import (
            execute,
        )

        import fitz

        # Two-page PDF; reading page 1 must exclude page 2's text.
        pdf_path = tmp_path / "paper.pdf"
        doc = fitz.open()
        page = doc.new_page()
        page.insert_text((72, 72), "Page 1 content")
        page2 = doc.new_page()
        page2.insert_text((72, 72), "Page 2 content")
        doc.save(str(pdf_path))
        doc.close()

        ctx: dict[str, Any] = {
            "_arxiv_pdf_path": str(pdf_path),
            "_arxiv_pdf_pages": 2,
        }
        result = await execute({"start_page": 1, "end_page": 1}, ctx)
        assert "第 1 页" in result
        assert "Page 1 content" in result
        assert "Page 2 content" not in result

    @pytest.mark.asyncio
    async def test_read_page_range(self, tmp_path: Path) -> None:
        from Undefined.skills.agents.arxiv_analysis_agent.tools.read_paper_pages.handler import (
            execute,
        )

        import fitz

        pdf_path = tmp_path / "paper.pdf"
        doc = fitz.open()
        for i in range(3):
            page = doc.new_page()
            page.insert_text((72, 72), f"Content of page {i + 1}")
        doc.save(str(pdf_path))
        doc.close()

        ctx: dict[str, Any] = {
            "_arxiv_pdf_path": str(pdf_path),
            "_arxiv_pdf_pages": 3,
        }
        result = await execute({"start_page": 1, "end_page": 3}, ctx)
        assert "第 1-3 页" in result or "第 1 页" in result
        assert "Content of page 1" in result
        assert "Content of page 3" in result

    @pytest.mark.asyncio
    async def test_out_of_range_page(self, tmp_path: Path) -> None:
        from Undefined.skills.agents.arxiv_analysis_agent.tools.read_paper_pages.handler import (
            execute,
        )

        import fitz

        # One-page PDF, but pages 5-10 are requested — must report an error.
        pdf_path = tmp_path / "paper.pdf"
        doc = fitz.open()
        doc.new_page()
        doc.save(str(pdf_path))
        doc.close()

        ctx: dict[str, Any] = {
            "_arxiv_pdf_path": str(pdf_path),
            "_arxiv_pdf_pages": 1,
        }
        result = await execute({"start_page": 5, "end_page": 10}, ctx)
        assert "错误" in result

    @pytest.mark.asyncio
    async def test_invalid_page_numbers(self) -> None:
        from Undefined.skills.agents.arxiv_analysis_agent.tools.read_paper_pages.handler import (
            execute,
        )

        ctx: dict[str, Any] = {
            "_arxiv_pdf_path": "/some/path.pdf",
            "_arxiv_pdf_pages": 10,
        }
        result = await execute({"start_page": "abc", "end_page": "def"}, ctx)
        assert "整数" in result


# ---------------------------------------------------------------------------
# web_agent callable.json 更新检查
# ---------------------------------------------------------------------------


class TestWebAgentCallable:
    """web_agent must now allow the new agents as callers."""

    def test_web_agent_allows_new_callers(self) -> None:
        web_agent_callable = (
            Path(__file__).resolve().parent.parent
            / "src"
            / "Undefined"
            / "skills"
            / "agents"
            / "web_agent"
            / "callable.json"
        )
        cfg = json.loads(web_agent_callable.read_text(encoding="utf-8"))
        callers = cfg["allowed_callers"]
        assert "summary_agent" in callers
        assert "arxiv_analysis_agent" in callers
=?UTF-8?q?docs:=20=E6=9B=B4=E6=96=B0=20CLAUDE.md=20?= =?UTF-8?q?agents=20=E6=95=B0=E9=87=8F=E4=B8=BA=207?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- CLAUDE.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CLAUDE.md b/CLAUDE.md index f2e67a0..3e3df12 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -55,7 +55,7 @@ OneBot WebSocket → onebot.py → handlers.py → SecurityService(注入检测) |------|------| | `ai/` | AI 核心:client.py(主入口)、llm.py(模型请求)、prompts.py(Prompt构建)、tooling.py(工具管理)、multimodal.py(多模态)、model_selector.py(模型选择) | | `services/` | 运行服务:ai_coordinator.py(协调器+队列投递)、queue_manager.py(车站-列车队列)、command.py(命令分发)、model_pool.py(多模型池)、security.py(安全防护) | -| `skills/` | 热重载技能系统:tools/(原子工具)、toolsets/(11类工具集)、agents/(6个智能体)、commands/(斜杠指令)、anthropic_skills/(SKILL.md知识注入) | +| `skills/` | 热重载技能系统:tools/(原子工具)、toolsets/(11类工具集)、agents/(7个智能体)、commands/(斜杠指令)、anthropic_skills/(SKILL.md知识注入) | | `cognitive/` | 认知记忆:service.py(入口)、vector_store.py(ChromaDB)、historian.py(后台史官异步改写+侧写合并)、job_queue.py、profile_storage.py | | `memes/` | 表情包库:service.py(两阶段AI管线)、worker.py(异步处理队列)、store.py(SQLite)、vector_store.py(ChromaDB)、models.py | | `config/` | 配置系统:loader.py(TOML解析+类型化)、models.py(数据模型)、hot_reload.py(热更新) | From 9439b0ebc5f8a2b2e4a181ef9aaeb638310b3953 Mon Sep 17 00:00:00 2001 From: Null <1708213363@qq.com> Date: Sun, 19 Apr 2026 08:37:50 +0800 Subject: [PATCH 19/57] =?UTF-8?q?feat(tools):=20=E6=96=B0=E5=A2=9E=20calcu?= =?UTF-8?q?lator=20=E5=A4=9A=E5=8A=9F=E8=83=BD=E5=AE=89=E5=85=A8=E8=AE=A1?= =?UTF-8?q?=E7=AE=97=E5=99=A8?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 基于 AST 的安全表达式求值,拒绝任何非数学操作 - 支持:算术、幂运算、科学函数、三角函数、统计函数、组合数学 - 常量:pi, e, tau, inf; 函数:sqrt, log, sin, cos, factorial, gcd, mean 等 - allowed_callers: ["*"] 允许所有 agent 调用 - 安全限制:指数上限 10000、表达式长度上限 500 - 55 个单元测试覆盖算术/科学/统计/比较/安全拒绝 Co-authored-by: 
Claude Opus 4.6 Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .../skills/tools/calculator/callable.json | 4 + .../skills/tools/calculator/config.json | 17 ++ .../skills/tools/calculator/handler.py | 258 +++++++++++++++++ tests/test_calculator.py | 261 ++++++++++++++++++ 4 files changed, 540 insertions(+) create mode 100644 src/Undefined/skills/tools/calculator/callable.json create mode 100644 src/Undefined/skills/tools/calculator/config.json create mode 100644 src/Undefined/skills/tools/calculator/handler.py create mode 100644 tests/test_calculator.py diff --git a/src/Undefined/skills/tools/calculator/callable.json b/src/Undefined/skills/tools/calculator/callable.json new file mode 100644 index 0000000..0a69975 --- /dev/null +++ b/src/Undefined/skills/tools/calculator/callable.json @@ -0,0 +1,4 @@ +{ + "enabled": true, + "allowed_callers": ["*"] +} diff --git a/src/Undefined/skills/tools/calculator/config.json b/src/Undefined/skills/tools/calculator/config.json new file mode 100644 index 0000000..ad86fb4 --- /dev/null +++ b/src/Undefined/skills/tools/calculator/config.json @@ -0,0 +1,17 @@ +{ + "type": "function", + "function": { + "name": "calculator", + "description": "安全的多功能数学计算器。支持算术运算、科学函数、统计函数和常量。输入数学表达式即可计算。", + "parameters": { + "type": "object", + "properties": { + "expression": { + "type": "string", + "description": "数学表达式,例如:'2+3*4'、'sqrt(144)'、'sin(pi/6)'、'log(1000,10)'、'mean(1,2,3,4,5)'、'2**10'、'factorial(10)'、'gcd(48,18)'" + } + }, + "required": ["expression"] + } + } +} diff --git a/src/Undefined/skills/tools/calculator/handler.py b/src/Undefined/skills/tools/calculator/handler.py new file mode 100644 index 0000000..cd26438 --- /dev/null +++ b/src/Undefined/skills/tools/calculator/handler.py @@ -0,0 +1,258 @@ +"""安全的多功能数学计算器。 + +通过 AST 解析数学表达式,仅允许数学运算,拒绝任何危险操作。 +支持:算术、幂运算、科学函数、统计函数、常量。 +""" + +from __future__ import annotations + +import ast +import math +import operator +import statistics +from typing import Any + 
+_UNARY_OPS: dict[type, Any] = { + ast.UAdd: operator.pos, + ast.USub: operator.neg, +} + +_BIN_OPS: dict[type, Any] = { + ast.Add: operator.add, + ast.Sub: operator.sub, + ast.Mult: operator.mul, + ast.Div: operator.truediv, + ast.FloorDiv: operator.floordiv, + ast.Mod: operator.mod, + ast.Pow: operator.pow, + ast.BitXor: operator.pow, # ^ 也当幂运算(常见误用) +} + +_COMPARE_OPS: dict[type, Any] = { + ast.Eq: operator.eq, + ast.NotEq: operator.ne, + ast.Lt: operator.lt, + ast.LtE: operator.le, + ast.Gt: operator.gt, + ast.GtE: operator.ge, +} + +_CONSTANTS: dict[str, float] = { + "pi": math.pi, + "PI": math.pi, + "e": math.e, + "E": math.e, + "tau": math.tau, + "inf": math.inf, + "nan": math.nan, +} + + +def _stat_fn(fn_name: str, args: list[float | int]) -> float | int: + """统计函数分发。""" + if len(args) < 1: + raise ValueError(f"{fn_name} 至少需要 1 个参数") + fn_map: dict[str, Any] = { + "mean": statistics.mean, + "median": statistics.median, + "stdev": statistics.stdev, + "variance": statistics.variance, + "pstdev": statistics.pstdev, + "pvariance": statistics.pvariance, + "harmonic_mean": statistics.harmonic_mean, + } + fn = fn_map.get(fn_name) + if fn is None: + raise ValueError(f"未知统计函数: {fn_name}") + if fn_name in ("stdev", "variance") and len(args) < 2: + raise ValueError(f"{fn_name} 至少需要 2 个数据点") + result: float | int = fn(args) + return result + + +_MATH_FUNCS: dict[str, Any] = { + # 基础 + "abs": abs, + "round": round, + "int": int, + "float": float, + # 幂与对数 + "sqrt": math.sqrt, + "cbrt": lambda x: x ** (1 / 3), + "pow": math.pow, + "exp": math.exp, + "log": math.log, + "log2": math.log2, + "log10": math.log10, + "ln": math.log, + # 三角 + "sin": math.sin, + "cos": math.cos, + "tan": math.tan, + "asin": math.asin, + "acos": math.acos, + "atan": math.atan, + "atan2": math.atan2, + "sinh": math.sinh, + "cosh": math.cosh, + "tanh": math.tanh, + # 角度转换 + "degrees": math.degrees, + "radians": math.radians, + # 取整 + "ceil": math.ceil, + "floor": math.floor, + "trunc": math.trunc, 
+ # 组合数学 + "factorial": math.factorial, + "comb": math.comb, + "perm": math.perm, + "gcd": math.gcd, + "lcm": math.lcm, + # 其他 + "hypot": math.hypot, + "copysign": math.copysign, + "fmod": math.fmod, + "isqrt": math.isqrt, + # 统计(占位,实际调用 _stat_fn) + "mean": None, + "median": None, + "stdev": None, + "variance": None, + "pstdev": None, + "pvariance": None, + "harmonic_mean": None, + # 最值 + "max": max, + "min": min, + "sum": sum, +} + +_STAT_FUNCS = frozenset( + {"mean", "median", "stdev", "variance", "pstdev", "pvariance", "harmonic_mean"} +) + +_MAX_POWER = 10000 +_MAX_EXPRESSION_LENGTH = 500 + + +class _SafeEvaluator(ast.NodeVisitor): + """安全 AST 求值器,仅允许数学运算。""" + + def visit(self, node: ast.AST) -> Any: + return super().visit(node) + + def generic_visit(self, node: ast.AST) -> Any: + raise ValueError(f"不支持的表达式语法: {type(node).__name__}") + + def visit_Expression(self, node: ast.Expression) -> Any: + return self.visit(node.body) + + def visit_Constant(self, node: ast.Constant) -> Any: + if isinstance(node.value, (int, float, complex)): + return node.value + raise ValueError(f"不支持的常量类型: {type(node.value).__name__}") + + def visit_Name(self, node: ast.Name) -> Any: + if node.id in _CONSTANTS: + return _CONSTANTS[node.id] + raise ValueError(f"未知变量: {node.id}") + + def visit_UnaryOp(self, node: ast.UnaryOp) -> Any: + op_fn = _UNARY_OPS.get(type(node.op)) + if op_fn is None: + raise ValueError(f"不支持的一元运算: {type(node.op).__name__}") + return op_fn(self.visit(node.operand)) + + def visit_BinOp(self, node: ast.BinOp) -> Any: + op_fn = _BIN_OPS.get(type(node.op)) + if op_fn is None: + raise ValueError(f"不支持的运算: {type(node.op).__name__}") + left = self.visit(node.left) + right = self.visit(node.right) + if isinstance(node.op, (ast.Pow, ast.BitXor)): + if isinstance(right, (int, float)) and abs(right) > _MAX_POWER: + raise ValueError(f"指数过大: {right}(上限 {_MAX_POWER})") + return op_fn(left, right) + + def visit_Compare(self, node: ast.Compare) -> Any: + left = 
self.visit(node.left) + for op, comparator in zip(node.ops, node.comparators): + op_fn = _COMPARE_OPS.get(type(op)) + if op_fn is None: + raise ValueError(f"不支持的比较: {type(op).__name__}") + right = self.visit(comparator) + if not op_fn(left, right): + return False + left = right + return True + + def visit_Call(self, node: ast.Call) -> Any: + if not isinstance(node.func, ast.Name): + raise ValueError("仅支持直接函数调用") + fn_name = node.func.id + if fn_name not in _MATH_FUNCS: + raise ValueError(f"未知函数: {fn_name}") + if node.keywords: + raise ValueError("不支持关键字参数") + + args = [self.visit(arg) for arg in node.args] + + if fn_name in _STAT_FUNCS: + return _stat_fn(fn_name, args) + + fn = _MATH_FUNCS[fn_name] + return fn(*args) + + def visit_IfExp(self, node: ast.IfExp) -> Any: + condition = self.visit(node.test) + return self.visit(node.body) if condition else self.visit(node.orelse) + + def visit_Tuple(self, node: ast.Tuple) -> Any: + return tuple(self.visit(elt) for elt in node.elts) + + def visit_List(self, node: ast.List) -> Any: + return [self.visit(elt) for elt in node.elts] + + +def safe_eval(expression: str) -> str: + """安全计算数学表达式,返回字符串结果。""" + expr = expression.strip() + if not expr: + raise ValueError("表达式为空") + if len(expr) > _MAX_EXPRESSION_LENGTH: + raise ValueError( + f"表达式过长({len(expr)} 字符,上限 {_MAX_EXPRESSION_LENGTH})" + ) + + tree = ast.parse(expr, mode="eval") + result = _SafeEvaluator().visit(tree) + + if isinstance(result, float): + if result == int(result) and not (math.isinf(result) or math.isnan(result)): + return str(int(result)) + return f"{result:.10g}" + if isinstance(result, complex): + return str(result) + if isinstance(result, bool): + return str(result) + return str(result) + + +async def execute(args: dict[str, Any], context: dict[str, Any]) -> str: + """计算数学表达式。""" + expression = str(args.get("expression", "")).strip() + if not expression: + return "请提供数学表达式" + + try: + result = safe_eval(expression) + return f"{expression} = {result}" + 
except ZeroDivisionError: + return f"计算错误:除以零 ({expression})" + except OverflowError: + return f"计算错误:结果溢出 ({expression})" + except (ValueError, TypeError) as exc: + return f"计算错误:{exc}" + except SyntaxError: + return f"表达式语法错误:{expression}" diff --git a/tests/test_calculator.py b/tests/test_calculator.py new file mode 100644 index 0000000..118759a --- /dev/null +++ b/tests/test_calculator.py @@ -0,0 +1,261 @@ +"""calculator 工具单元测试。""" + +from __future__ import annotations + +import json +import math +from pathlib import Path +from typing import Any + +import pytest + +from Undefined.skills.tools.calculator.handler import safe_eval + +TOOL_DIR = ( + Path(__file__).resolve().parent.parent + / "src" + / "Undefined" + / "skills" + / "tools" + / "calculator" +) + + +# --------------------------------------------------------------------------- +# 配置文件检查 +# --------------------------------------------------------------------------- + + +class TestConfig: + def test_config_json(self) -> None: + cfg: dict[str, Any] = json.loads( + (TOOL_DIR / "config.json").read_text(encoding="utf-8") + ) + assert cfg["function"]["name"] == "calculator" + assert "expression" in cfg["function"]["parameters"]["properties"] + + def test_callable_json(self) -> None: + cfg: dict[str, Any] = json.loads( + (TOOL_DIR / "callable.json").read_text(encoding="utf-8") + ) + assert cfg["enabled"] is True + assert "*" in cfg["allowed_callers"] + + +# --------------------------------------------------------------------------- +# safe_eval 单元测试 +# --------------------------------------------------------------------------- + + +class TestArithmetic: + def test_addition(self) -> None: + assert safe_eval("2+3") == "5" + + def test_subtraction(self) -> None: + assert safe_eval("10-3") == "7" + + def test_multiplication(self) -> None: + assert safe_eval("4*5") == "20" + + def test_division(self) -> None: + assert safe_eval("10/3") == "3.333333333" + + def test_floor_division(self) -> None: + assert 
class TestArithmetic:
    """Basic arithmetic operators and numeric formatting."""

    def test_addition(self) -> None:
        assert safe_eval("2+3") == "5"

    def test_subtraction(self) -> None:
        assert safe_eval("10-3") == "7"

    def test_multiplication(self) -> None:
        assert safe_eval("4*5") == "20"

    def test_division(self) -> None:
        # Non-integer results are formatted with 10 significant digits.
        assert safe_eval("10/3") == "3.333333333"

    def test_floor_division(self) -> None:
        assert safe_eval("10//3") == "3"

    def test_modulo(self) -> None:
        assert safe_eval("10%3") == "1"

    def test_power(self) -> None:
        assert safe_eval("2**10") == "1024"

    def test_caret_as_power(self) -> None:
        # "^" is intentionally treated as exponentiation, not XOR.
        assert safe_eval("2^10") == "1024"

    def test_negative_number(self) -> None:
        assert safe_eval("-5+3") == "-2"

    def test_complex_expression(self) -> None:
        assert safe_eval("(2+3)*4-1") == "19"

    def test_floating_point(self) -> None:
        # The 10-significant-digit format hides binary float noise.
        assert safe_eval("0.1+0.2") == "0.3"


class TestConstants:
    """Named constants resolve to their math-module values."""

    def test_pi(self) -> None:
        result = float(safe_eval("pi"))
        assert abs(result - math.pi) < 1e-8

    def test_e(self) -> None:
        result = float(safe_eval("e"))
        assert abs(result - math.e) < 1e-8

    def test_tau(self) -> None:
        result = float(safe_eval("tau"))
        assert abs(result - math.tau) < 1e-8


class TestScientificFunctions:
    """Whitelisted scientific / combinatoric functions."""

    def test_sqrt(self) -> None:
        assert safe_eval("sqrt(144)") == "12"

    def test_sin(self) -> None:
        result = float(safe_eval("sin(pi/6)"))
        assert abs(result - 0.5) < 1e-10

    def test_cos(self) -> None:
        result = float(safe_eval("cos(0)"))
        assert abs(result - 1.0) < 1e-10

    def test_log10(self) -> None:
        assert safe_eval("log10(1000)") == "3"

    def test_log_with_base(self) -> None:
        # log accepts an optional base as its second argument.
        assert safe_eval("log(8, 2)") == "3"

    def test_ln(self) -> None:
        result = float(safe_eval("ln(e)"))
        assert abs(result - 1.0) < 1e-10

    def test_factorial(self) -> None:
        assert safe_eval("factorial(10)") == "3628800"

    def test_degrees(self) -> None:
        result = float(safe_eval("degrees(pi)"))
        assert abs(result - 180.0) < 1e-10

    def test_radians(self) -> None:
        result = float(safe_eval("radians(180)"))
        assert abs(result - math.pi) < 1e-8

    def test_ceil(self) -> None:
        assert safe_eval("ceil(3.2)") == "4"

    def test_floor(self) -> None:
        assert safe_eval("floor(3.8)") == "3"

    def test_abs(self) -> None:
        assert safe_eval("abs(-42)") == "42"

    def test_gcd(self) -> None:
        assert safe_eval("gcd(48, 18)") == "6"

    def test_lcm(self) -> None:
        assert safe_eval("lcm(4, 6)") == "12"

    def test_comb(self) -> None:
        assert safe_eval("comb(10, 3)") == "120"

    def test_perm(self) -> None:
        assert safe_eval("perm(5, 3)") == "60"

    def test_hypot(self) -> None:
        assert safe_eval("hypot(3, 4)") == "5"


class TestStatistics:
    """Variadic statistics functions routed through _stat_fn."""

    def test_mean(self) -> None:
        assert safe_eval("mean(1, 2, 3, 4, 5)") == "3"

    def test_median(self) -> None:
        assert safe_eval("median(1, 3, 5, 7, 9)") == "5"

    def test_stdev(self) -> None:
        result = float(safe_eval("stdev(2, 4, 4, 4, 5, 5, 7, 9)"))
        assert result > 0

    def test_min_max(self) -> None:
        assert safe_eval("min(3, 1, 4, 1, 5)") == "1"
        assert safe_eval("max(3, 1, 4, 1, 5)") == "5"

    def test_sum(self) -> None:
        # List literals are supported as aggregation input.
        assert safe_eval("sum([1, 2, 3, 4, 5])") == "15"


class TestComparison:
    """Comparison operators return "True"/"False" strings."""

    def test_equal(self) -> None:
        assert safe_eval("2+2 == 4") == "True"

    def test_not_equal(self) -> None:
        assert safe_eval("2+2 != 5") == "True"

    def test_less_than(self) -> None:
        assert safe_eval("3 < 5") == "True"

    def test_greater_than(self) -> None:
        assert safe_eval("5 > 3") == "True"


class TestIfExpression:
    """Conditional (ternary) expressions are supported."""

    def test_ternary(self) -> None:
        assert safe_eval("42 if 2>1 else 0") == "42"


class TestSafety:
    """Dangerous constructs must be rejected with ValueError."""

    def test_rejects_import(self) -> None:
        with pytest.raises(ValueError, match="未知函数"):
            safe_eval("__import__('os')")

    def test_rejects_attribute_access(self) -> None:
        with pytest.raises(ValueError, match="不支持"):
            safe_eval("math.pi")

    def test_rejects_unknown_function(self) -> None:
        with pytest.raises(ValueError, match="未知函数"):
            safe_eval("eval('1+1')")

    def test_rejects_unknown_variable(self) -> None:
        with pytest.raises(ValueError, match="未知变量"):
            safe_eval("x + 1")

    def test_rejects_too_large_exponent(self) -> None:
        with pytest.raises(ValueError, match="指数过大"):
            safe_eval("2**100000")

    def test_rejects_too_long_expression(self) -> None:
        with pytest.raises(ValueError, match="表达式过长"):
            safe_eval("1+" * 300 + "1")

    def test_rejects_empty(self) -> None:
        with pytest.raises(ValueError, match="表达式为空"):
            safe_eval("")

    def test_division_by_zero(self) -> None:
        # safe_eval propagates arithmetic errors; execute() formats them.
        with pytest.raises(ZeroDivisionError):
            safe_eval("1/0")


# ---------------------------------------------------------------------------
# execute() 集成测试
# ---------------------------------------------------------------------------


class TestExecute:
    """The async tool entry point never raises; errors become messages."""

    @pytest.mark.asyncio
    async def test_basic_calculation(self) -> None:
        from Undefined.skills.tools.calculator.handler import execute

        result = await execute({"expression": "2+3*4"}, {})
        assert "= 14" in result

    @pytest.mark.asyncio
    async def test_empty_expression(self) -> None:
        from Undefined.skills.tools.calculator.handler import execute

        result = await execute({"expression": ""}, {})
        assert "请提供" in result

    @pytest.mark.asyncio
    async def test_error_message(self) -> None:
        from Undefined.skills.tools.calculator.handler import execute

        result = await execute({"expression": "1/0"}, {})
        assert "除以零" in result

    @pytest.mark.asyncio
    async def test_syntax_error(self) -> None:
        from Undefined.skills.tools.calculator.handler import execute

        result = await execute({"expression": "2+++"}, {})
        assert "语法错误" in result
onebot_fetch_limit 5000→10000, group_analysis_limit 100→500 - config.toml.example 新增双语注释的完整配置节 - 新增 test_history_config.py 验证配置字段和 helper 逻辑 - 更新 test_fetch_messages_tool.py 适配新默认值 Co-authored-by: Claude Opus 4.6 Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- CLAUDE.md | 2 +- config.toml.example | 22 ++++- src/Undefined/config/loader.py | 86 ++++++++++++++++++- .../tools/fetch_messages/config.json | 2 +- .../tools/fetch_messages/handler.py | 28 +++++- .../analyze_join_statistics/config.json | 2 +- .../analyze_join_statistics/handler.py | 4 +- .../analyze_member_messages/config.json | 4 +- .../analyze_member_messages/handler.py | 7 +- .../analyze_new_member_activity/config.json | 2 +- .../analyze_new_member_activity/handler.py | 4 +- .../messages/get_messages_by_time/handler.py | 23 +++-- .../messages/get_recent_messages/handler.py | 21 +++-- src/Undefined/utils/history.py | 42 +++++---- src/Undefined/utils/recent_messages.py | 9 +- tests/test_fetch_messages_tool.py | 55 ++++++++++-- tests/test_history_config.py | 76 ++++++++++++++++ 17 files changed, 337 insertions(+), 52 deletions(-) create mode 100644 tests/test_history_config.py diff --git a/CLAUDE.md b/CLAUDE.md index 3e3df12..d699463 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -80,7 +80,7 @@ OneBot WebSocket → onebot.py → handlers.py → SecurityService(注入检测) 车站-列车模型(QueueManager):按模型隔离队列组,4 级优先级(超管 > 私聊 > @提及 > 普通群聊),普通队列自动修剪保留最新 2 条,非阻塞按节奏发车(默认 1Hz)。 ### 存储与数据 -- `data/history/` — 消息历史(group_*.json / private_*.json,10000 条限制) +- `data/history/` — 消息历史(group_*.json / private_*.json,默认 10000 条,可通过 `[history]` 配置节调整,0=无限制) - `data/cognitive/` — ChromaDB 向量库 + profiles/ 侧写 + queues/ 任务队列 - `data/memes/` — 表情包库(blobs原图、previews预览图、memes.sqlite3元数据、chromadb向量检索) - `data/memory.json` — 置顶备忘录(500 条上限) diff --git a/config.toml.example b/config.toml.example index 4893aff..4ab0d04 100644 --- a/config.toml.example +++ b/config.toml.example @@ -726,9 +726,27 @@ inverted_question_enabled = false # 
zh: 历史记录配置。 # en: History settings. [history] -# zh: 每个会话最多保留的消息条数。 -# en: Max messages to keep per conversation. +# zh: 每个会话最多保留的消息条数(0 = 无限制,注意内存占用)。 +# en: Max messages to keep per conversation (0 = unlimited, mind memory usage). max_records = 10000 +# zh: 工具过滤查询返回的最大消息条数。 +# en: Max messages returned by tool filtered queries. +filtered_result_limit = 200 +# zh: 工具过滤搜索时扫描的最大消息条数。 +# en: Max messages to scan when tools perform filtered searches. +search_scan_limit = 10000 +# zh: 总结 agent 单次拉取的最大消息条数。 +# en: Max messages the summary agent can fetch per call. +summary_fetch_limit = 1000 +# zh: 总结 agent 按时间范围查询时的最大拉取条数。 +# en: Max messages the summary agent fetches for time-range queries. +summary_time_fetch_limit = 5000 +# zh: OneBot API 回退获取的最大消息条数。 +# en: Max messages to fetch via OneBot API fallback. +onebot_fetch_limit = 10000 +# zh: 群分析工具的消息/成员返回上限。 +# en: Max messages/members returned by group analysis tools. +group_analysis_limit = 500 # zh: Skills 热重载配置(可选)。 # en: Skills hot reload settings (optional). 
diff --git a/src/Undefined/config/loader.py b/src/Undefined/config/loader.py index 998a3d7..ac65473 100644 --- a/src/Undefined/config/loader.py +++ b/src/Undefined/config/loader.py @@ -504,6 +504,12 @@ class Config: token_usage_max_total_mb: int token_usage_archive_prune_mode: str history_max_records: int + history_filtered_result_limit: int + history_search_scan_limit: int + history_summary_fetch_limit: int + history_summary_time_fetch_limit: int + history_onebot_fetch_limit: int + history_group_analysis_limit: int skills_hot_reload: bool skills_hot_reload_interval: float skills_hot_reload_debounce: float @@ -1036,8 +1042,78 @@ def load(cls, config_path: Optional[Path] = None, strict: bool = True) -> "Confi "delete", ) - history_max_records = _coerce_int( - _get_value(data, ("history", "max_records"), "HISTORY_MAX_RECORDS"), 10000 + history_max_records = max( + 0, + _coerce_int( + _get_value(data, ("history", "max_records"), "HISTORY_MAX_RECORDS"), + 10000, + ), + ) + history_filtered_result_limit = max( + 1, + _coerce_int( + _get_value( + data, + ("history", "filtered_result_limit"), + "HISTORY_FILTERED_RESULT_LIMIT", + ), + 200, + ), + ) + history_search_scan_limit = max( + 1, + _coerce_int( + _get_value( + data, + ("history", "search_scan_limit"), + "HISTORY_SEARCH_SCAN_LIMIT", + ), + 10000, + ), + ) + history_summary_fetch_limit = max( + 1, + _coerce_int( + _get_value( + data, + ("history", "summary_fetch_limit"), + "HISTORY_SUMMARY_FETCH_LIMIT", + ), + 1000, + ), + ) + history_summary_time_fetch_limit = max( + 1, + _coerce_int( + _get_value( + data, + ("history", "summary_time_fetch_limit"), + "HISTORY_SUMMARY_TIME_FETCH_LIMIT", + ), + 5000, + ), + ) + history_onebot_fetch_limit = max( + 1, + _coerce_int( + _get_value( + data, + ("history", "onebot_fetch_limit"), + "HISTORY_ONEBOT_FETCH_LIMIT", + ), + 10000, + ), + ) + history_group_analysis_limit = max( + 1, + _coerce_int( + _get_value( + data, + ("history", "group_analysis_limit"), + 
"HISTORY_GROUP_ANALYSIS_LIMIT", + ), + 500, + ), ) skills_hot_reload = _coerce_bool( @@ -1451,6 +1527,12 @@ def load(cls, config_path: Optional[Path] = None, strict: bool = True) -> "Confi token_usage_archive_prune_mode=token_usage_archive_prune_mode, skills_hot_reload=skills_hot_reload, history_max_records=history_max_records, + history_filtered_result_limit=history_filtered_result_limit, + history_search_scan_limit=history_search_scan_limit, + history_summary_fetch_limit=history_summary_fetch_limit, + history_summary_time_fetch_limit=history_summary_time_fetch_limit, + history_onebot_fetch_limit=history_onebot_fetch_limit, + history_group_analysis_limit=history_group_analysis_limit, skills_hot_reload_interval=skills_hot_reload_interval, skills_hot_reload_debounce=skills_hot_reload_debounce, agent_intro_autogen_enabled=agent_intro_autogen_enabled, diff --git a/src/Undefined/skills/agents/summary_agent/tools/fetch_messages/config.json b/src/Undefined/skills/agents/summary_agent/tools/fetch_messages/config.json index d64a015..30042bc 100644 --- a/src/Undefined/skills/agents/summary_agent/tools/fetch_messages/config.json +++ b/src/Undefined/skills/agents/summary_agent/tools/fetch_messages/config.json @@ -8,7 +8,7 @@ "properties": { "count": { "type": "integer", - "description": "要获取的消息条数,默认50,最大500。" + "description": "要获取的消息条数,默认50,上限由服务器配置决定。" }, "time_range": { "type": "string", diff --git a/src/Undefined/skills/agents/summary_agent/tools/fetch_messages/handler.py b/src/Undefined/skills/agents/summary_agent/tools/fetch_messages/handler.py index c79cb07..4426752 100644 --- a/src/Undefined/skills/agents/summary_agent/tools/fetch_messages/handler.py +++ b/src/Undefined/skills/agents/summary_agent/tools/fetch_messages/handler.py @@ -12,9 +12,21 @@ _TIME_RANGE_PATTERN = re.compile(r"^(\d+)([hHdDwW])$") _TIME_UNIT_SECONDS = {"h": 3600, "d": 86400, "w": 604800} -_MAX_COUNT = 500 _DEFAULT_COUNT = 50 -_MAX_FETCH_FOR_TIME_FILTER = 2000 + +# 以下值仅作为 runtime_config 缺失时的回退 
+_FALLBACK_MAX_COUNT = 1000 +_FALLBACK_MAX_FETCH_FOR_TIME_FILTER = 5000 + + +def _get_history_limit(context: dict[str, Any], key: str, fallback: int) -> int: + """从 runtime_config 读取历史限制配置。""" + cfg = context.get("runtime_config") + if cfg is not None: + val = getattr(cfg, key, None) + if isinstance(val, int) and val > 0: + return val + return fallback def _parse_time_range(value: str) -> int | None: @@ -146,8 +158,11 @@ async def execute(args: dict[str, Any], context: dict[str, Any]) -> str: time_range_str = str(args.get("time_range", "")).strip() raw_count = args.get("count", _DEFAULT_COUNT) + max_count = _get_history_limit( + context, "history_summary_fetch_limit", _FALLBACK_MAX_COUNT + ) try: - count = min(max(int(raw_count), 1), _MAX_COUNT) + count = min(max(int(raw_count), 1), max_count) except (TypeError, ValueError): count = _DEFAULT_COUNT @@ -155,7 +170,12 @@ async def execute(args: dict[str, Any], context: dict[str, Any]) -> str: seconds = _parse_time_range(time_range_str) if seconds is None: return f"无法解析时间范围: {time_range_str}(支持格式: 1h, 6h, 1d, 7d)" - fetch_count = max(count * 2, _MAX_FETCH_FOR_TIME_FILTER) + max_time_fetch = _get_history_limit( + context, + "history_summary_time_fetch_limit", + _FALLBACK_MAX_FETCH_FOR_TIME_FILTER, + ) + fetch_count = max(count * 2, max_time_fetch) messages = history_manager.get_recent(chat_id, chat_type, 0, fetch_count) if messages: messages = _filter_by_time(messages, seconds) diff --git a/src/Undefined/skills/toolsets/group_analysis/analyze_join_statistics/config.json b/src/Undefined/skills/toolsets/group_analysis/analyze_join_statistics/config.json index 410c0a6..c20eebe 100644 --- a/src/Undefined/skills/toolsets/group_analysis/analyze_join_statistics/config.json +++ b/src/Undefined/skills/toolsets/group_analysis/analyze_join_statistics/config.json @@ -30,7 +30,7 @@ }, "member_limit": { "type": "integer", - "description": "返回成员列表的数量限制(最大100),默认 20", + "description": "返回成员列表的数量限制,上限由服务器配置决定,默认 20", "default": 20 } }, 
diff --git a/src/Undefined/skills/toolsets/group_analysis/analyze_join_statistics/handler.py b/src/Undefined/skills/toolsets/group_analysis/analyze_join_statistics/handler.py index 9ba6267..fe660ed 100644 --- a/src/Undefined/skills/toolsets/group_analysis/analyze_join_statistics/handler.py +++ b/src/Undefined/skills/toolsets/group_analysis/analyze_join_statistics/handler.py @@ -34,7 +34,9 @@ async def execute(args: Dict[str, Any], context: Dict[str, Any]) -> str: member_limit = int(member_limit_raw) if member_limit_raw is not None else 20 if member_limit < 0: return "参数错误:member_limit 必须是非负整数" - member_limit = min(member_limit, 100) + cfg = context.get("runtime_config") + analysis_cap = getattr(cfg, "history_group_analysis_limit", 500) if cfg else 500 + member_limit = min(member_limit, analysis_cap) except (ValueError, TypeError): return "参数类型错误:member_limit 必须是整数" diff --git a/src/Undefined/skills/toolsets/group_analysis/analyze_member_messages/config.json b/src/Undefined/skills/toolsets/group_analysis/analyze_member_messages/config.json index 188e235..6b3ce00 100644 --- a/src/Undefined/skills/toolsets/group_analysis/analyze_member_messages/config.json +++ b/src/Undefined/skills/toolsets/group_analysis/analyze_member_messages/config.json @@ -29,12 +29,12 @@ }, "message_limit": { "type": "integer", - "description": "返回消息内容的数量限制(最大100),默认 20", + "description": "返回消息内容的数量限制,上限由服务器配置决定,默认 20", "default": 20 }, "max_history_count": { "type": "integer", - "description": "最多获取的历史消息数量(最大5000),默认 2000", + "description": "最多获取的历史消息数量,上限由服务器配置决定,默认 2000", "default": 2000 } }, diff --git a/src/Undefined/skills/toolsets/group_analysis/analyze_member_messages/handler.py b/src/Undefined/skills/toolsets/group_analysis/analyze_member_messages/handler.py index 31b8ef8..dd00117 100644 --- a/src/Undefined/skills/toolsets/group_analysis/analyze_member_messages/handler.py +++ b/src/Undefined/skills/toolsets/group_analysis/analyze_member_messages/handler.py @@ -42,7 +42,9 @@ async def 
execute(args: Dict[str, Any], context: Dict[str, Any]) -> str: message_limit = int(message_limit_raw) if message_limit_raw is not None else 20 if message_limit < 0: return "参数错误:message_limit 必须是非负整数" - message_limit = min(message_limit, 100) + cfg = context.get("runtime_config") + analysis_cap = getattr(cfg, "history_group_analysis_limit", 500) if cfg else 500 + message_limit = min(message_limit, analysis_cap) except (ValueError, TypeError): return "参数类型错误:message_limit 必须是整数" @@ -53,7 +55,8 @@ async def execute(args: Dict[str, Any], context: Dict[str, Any]) -> str: ) if max_history_count < 0: return "参数错误:max_history_count 必须是非负整数" - max_history_count = min(max_history_count, 5000) + fetch_cap = getattr(cfg, "history_search_scan_limit", 10000) if cfg else 10000 + max_history_count = min(max_history_count, fetch_cap) except (ValueError, TypeError): return "参数类型错误:max_history_count 必须是整数" diff --git a/src/Undefined/skills/toolsets/group_analysis/analyze_new_member_activity/config.json b/src/Undefined/skills/toolsets/group_analysis/analyze_new_member_activity/config.json index 118b896..2a31c98 100644 --- a/src/Undefined/skills/toolsets/group_analysis/analyze_new_member_activity/config.json +++ b/src/Undefined/skills/toolsets/group_analysis/analyze_new_member_activity/config.json @@ -20,7 +20,7 @@ }, "max_history_count": { "type": "integer", - "description": "最多获取的历史消息数量(最大5000),默认 2000", + "description": "最多获取的历史消息数量,上限由服务器配置决定,默认 2000", "default": 2000 }, "top_count": { diff --git a/src/Undefined/skills/toolsets/group_analysis/analyze_new_member_activity/handler.py b/src/Undefined/skills/toolsets/group_analysis/analyze_new_member_activity/handler.py index fbbec91..a8c5d7c 100644 --- a/src/Undefined/skills/toolsets/group_analysis/analyze_new_member_activity/handler.py +++ b/src/Undefined/skills/toolsets/group_analysis/analyze_new_member_activity/handler.py @@ -34,7 +34,9 @@ async def execute(args: Dict[str, Any], context: Dict[str, Any]) -> str: ) if 
max_history_count < 0: return "参数错误:max_history_count 必须是非负整数" - max_history_count = min(max_history_count, 5000) + cfg = context.get("runtime_config") + fetch_cap = getattr(cfg, "history_search_scan_limit", 10000) if cfg else 10000 + max_history_count = min(max_history_count, fetch_cap) except (ValueError, TypeError): return "参数类型错误:max_history_count 必须是整数" diff --git a/src/Undefined/skills/toolsets/messages/get_messages_by_time/handler.py b/src/Undefined/skills/toolsets/messages/get_messages_by_time/handler.py index 12ff836..96aaa89 100644 --- a/src/Undefined/skills/toolsets/messages/get_messages_by_time/handler.py +++ b/src/Undefined/skills/toolsets/messages/get_messages_by_time/handler.py @@ -1,7 +1,15 @@ from typing import Any, Dict from datetime import datetime -_FILTERED_RESULT_LIMIT = 50 + +def _get_history_limit(context: Dict[str, Any], key: str, fallback: int) -> int: + """从 runtime_config 读取历史限制配置。""" + cfg = context.get("runtime_config") + if cfg is not None: + val = getattr(cfg, key, None) + if isinstance(val, int) and val > 0: + return val + return fallback def _resolve_chat_id(chat_id: str, msg_type: str, history_manager: Any) -> str: @@ -151,8 +159,11 @@ async def execute(args: Dict[str, Any], context: Dict[str, Any]) -> str: resolved_chat_id = _resolve_chat_id(chat_id, msg_type, history_manager) if get_recent_messages_callback: + search_scan_limit = _get_history_limit( + context, "history_search_scan_limit", 10000 + ) messages = await get_recent_messages_callback( - resolved_chat_id, msg_type, 0, 10000 + resolved_chat_id, msg_type, 0, search_scan_limit ) # 时间范围过滤 @@ -164,10 +175,12 @@ async def execute(args: Dict[str, Any], context: Dict[str, Any]) -> str: filtered_messages, keyword, sender ) - # 限制最多返回 50 条 + filtered_result_limit = _get_history_limit( + context, "history_filtered_result_limit", 200 + ) total_matched = len(filtered_messages) - if total_matched > _FILTERED_RESULT_LIMIT: - filtered_messages = 
filtered_messages[:_FILTERED_RESULT_LIMIT] + if total_matched > filtered_result_limit: + filtered_messages = filtered_messages[:filtered_result_limit] formatted = [] for msg in filtered_messages: diff --git a/src/Undefined/skills/toolsets/messages/get_recent_messages/handler.py b/src/Undefined/skills/toolsets/messages/get_recent_messages/handler.py index 54874bb..fbfd906 100644 --- a/src/Undefined/skills/toolsets/messages/get_recent_messages/handler.py +++ b/src/Undefined/skills/toolsets/messages/get_recent_messages/handler.py @@ -2,7 +2,15 @@ from typing import Any, Dict -_FILTERED_RESULT_LIMIT = 50 + +def _get_history_limit(context: Dict[str, Any], key: str, fallback: int) -> int: + """从 runtime_config 读取历史限制配置。""" + cfg = context.get("runtime_config") + if cfg is not None: + val = getattr(cfg, key, None) + if isinstance(val, int) and val > 0: + return val + return fallback def _find_group_id_by_name(chat_id: str, history_manager: Any) -> str | None: @@ -230,8 +238,9 @@ async def execute(args: Dict[str, Any], context: Dict[str, Any]) -> str: # 获取消息:有过滤条件时扩大获取范围 messages = [] + search_scan_limit = _get_history_limit(context, "history_search_scan_limit", 10000) fetch_start = 0 if has_filter else start - fetch_end = 10000 if has_filter else end + fetch_end = search_scan_limit if has_filter else end if get_recent_messages_callback: messages = await get_recent_messages_callback( @@ -265,9 +274,11 @@ async def execute(args: Dict[str, Any], context: Dict[str, Any]) -> str: # 对过滤结果进行分页切片 messages = messages[start:end] - # 限制最多返回 50 条 - if len(messages) > _FILTERED_RESULT_LIMIT: - messages = messages[:_FILTERED_RESULT_LIMIT] + filtered_result_limit = _get_history_limit( + context, "history_filtered_result_limit", 200 + ) + if len(messages) > filtered_result_limit: + messages = messages[:filtered_result_limit] # 格式化消息 formatted = [_format_message_xml(msg) for msg in messages] diff --git a/src/Undefined/utils/history.py b/src/Undefined/utils/history.py index 
b6dfa4b..578711e 100644 --- a/src/Undefined/utils/history.py +++ b/src/Undefined/utils/history.py @@ -81,17 +81,20 @@ def _get_private_history_path(self, user_id: int) -> str: async def _save_history_to_file( self, history: list[dict[str, Any]], path: str ) -> None: - """异步保存历史记录到文件(最多 10000 条)""" + """异步保存历史记录到文件""" from Undefined.utils import io try: - # 只保留最近的 self._max_records 条 - truncated_history = ( - history[-self._max_records :] - if len(history) > self._max_records - else history - ) - truncated = len(history) > self._max_records + if self._max_records > 0: + truncated_history = ( + history[-self._max_records :] + if len(history) > self._max_records + else history + ) + truncated = len(history) > self._max_records + else: + truncated_history = history + truncated = False logger.debug( f"[历史记录] 准备保存: path={path}, total={len(history)}, truncated={truncated}" @@ -210,12 +213,13 @@ async def _load_history_from_file(self, path: str) -> list[dict[str, Any]]: normalized_history.append(msg) - # 只保留最近的 self._max_records 条 - return ( - normalized_history[-self._max_records :] - if len(normalized_history) > self._max_records - else normalized_history - ) + if self._max_records > 0: + return ( + normalized_history[-self._max_records :] + if len(normalized_history) > self._max_records + else normalized_history + ) + return normalized_history except Exception as e: logger.error(f"加载历史记录失败 {path}: {e}") @@ -346,7 +350,10 @@ async def add_group_message( self._message_history[group_id_str].append(record) - if len(self._message_history[group_id_str]) > self._max_records: + if ( + self._max_records > 0 + and len(self._message_history[group_id_str]) > self._max_records + ): self._message_history[group_id_str] = self._message_history[ group_id_str ][-self._max_records :] @@ -395,7 +402,10 @@ async def add_private_message( self._private_message_history[user_id_str].append(record) - if len(self._private_message_history[user_id_str]) > self._max_records: + if ( + 
self._max_records > 0 + and len(self._private_message_history[user_id_str]) > self._max_records + ): self._private_message_history[user_id_str] = ( self._private_message_history[user_id_str][-self._max_records :] ) diff --git a/src/Undefined/utils/recent_messages.py b/src/Undefined/utils/recent_messages.py index 84953ae..10adca8 100644 --- a/src/Undefined/utils/recent_messages.py +++ b/src/Undefined/utils/recent_messages.py @@ -247,9 +247,14 @@ async def get_recent_messages_prefer_local( bot_qq: int, attachment_registry: Any | None = None, group_name_hint: str | None = None, - max_onebot_count: int = 5000, + max_onebot_count: int | None = None, ) -> list[dict[str, Any]]: """优先从本地 history 获取最近消息,必要时回退到 OneBot。""" + if max_onebot_count is None: + from Undefined.config import get_config + + cfg = get_config(strict=False) + max_onebot_count = getattr(cfg, "history_onebot_fetch_limit", 10000) norm_start, norm_end = _normalize_range(start, end) if norm_end <= 0: return [] @@ -311,7 +316,7 @@ async def get_recent_messages_prefer_onebot( bot_qq: int, attachment_registry: Any | None = None, group_name_hint: str | None = None, - max_onebot_count: int = 5000, + max_onebot_count: int | None = None, ) -> list[dict[str, Any]]: """兼容旧名称,当前行为等同于本地优先。""" return await get_recent_messages_prefer_local( diff --git a/tests/test_fetch_messages_tool.py b/tests/test_fetch_messages_tool.py index 337b006..1804c72 100644 --- a/tests/test_fetch_messages_tool.py +++ b/tests/test_fetch_messages_tool.py @@ -404,8 +404,8 @@ async def test_fetch_messages_no_history_manager() -> None: @pytest.mark.asyncio -async def test_fetch_messages_count_capped_at_500() -> None: - """Count is capped at 500.""" +async def test_fetch_messages_count_capped_at_config_limit() -> None: + """Count is capped at the configured summary fetch limit (fallback 1000).""" history_manager = MagicMock() history_manager.get_recent.return_value = [] @@ -416,7 +416,7 @@ async def test_fetch_messages_count_capped_at_500() -> None: 
await fetch_messages_execute({"count": 9999}, context) - history_manager.get_recent.assert_called_once_with("123456", "group", 0, 500) + history_manager.get_recent.assert_called_once_with("123456", "group", 0, 1000) @pytest.mark.asyncio @@ -453,7 +453,7 @@ async def test_fetch_messages_invalid_count_defaults() -> None: @pytest.mark.asyncio async def test_fetch_messages_time_range_fetch_larger_batch() -> None: - """Time range mode fetches larger batch (max(count*2, 2000)).""" + """Time range mode fetches larger batch (max(count*2, fallback_time_limit)).""" history_manager = MagicMock() history_manager.get_recent.return_value = [] @@ -467,5 +467,48 @@ async def test_fetch_messages_time_range_fetch_larger_batch() -> None: context, ) - # max(50*2, 2000) = 2000 - history_manager.get_recent.assert_called_once_with("123456", "group", 0, 2000) + # max(50*2, 5000) = 5000 + history_manager.get_recent.assert_called_once_with("123456", "group", 0, 5000) + + +@pytest.mark.asyncio +async def test_fetch_messages_count_uses_runtime_config() -> None: + """When runtime_config provides history_summary_fetch_limit, use it.""" + history_manager = MagicMock() + history_manager.get_recent.return_value = [] + + cfg = MagicMock() + cfg.history_summary_fetch_limit = 300 + cfg.history_summary_time_fetch_limit = 800 + + context: dict[str, Any] = { + "history_manager": history_manager, + "group_id": 123456, + "runtime_config": cfg, + } + + await fetch_messages_execute({"count": 9999}, context) + + history_manager.get_recent.assert_called_once_with("123456", "group", 0, 300) + + +@pytest.mark.asyncio +async def test_fetch_messages_time_range_uses_runtime_config() -> None: + """When runtime_config provides history_summary_time_fetch_limit, use it.""" + history_manager = MagicMock() + history_manager.get_recent.return_value = [] + + cfg = MagicMock() + cfg.history_summary_fetch_limit = 300 + cfg.history_summary_time_fetch_limit = 800 + + context: dict[str, Any] = { + "history_manager": 
history_manager, + "group_id": 123456, + "runtime_config": cfg, + } + + await fetch_messages_execute({"count": 50, "time_range": "1d"}, context) + + # max(50*2, 800) = 800 + history_manager.get_recent.assert_called_once_with("123456", "group", 0, 800) diff --git a/tests/test_history_config.py b/tests/test_history_config.py new file mode 100644 index 0000000..c5b0d3e --- /dev/null +++ b/tests/test_history_config.py @@ -0,0 +1,76 @@ +"""Tests for configurable history limits in Config and tool handlers.""" + +from __future__ import annotations + +from typing import Any +from unittest.mock import MagicMock + + +# --------------------------------------------------------------------------- +# Helper: simulate _get_history_limit as used across multiple handlers +# --------------------------------------------------------------------------- +def _get_history_limit(context: dict[str, Any], key: str, fallback: int) -> int: + """Mirror of the helper used in tool handlers.""" + cfg = context.get("runtime_config") + if cfg is not None: + val = getattr(cfg, key, None) + if isinstance(val, int) and val > 0: + return val + return fallback + + +class TestGetHistoryLimit: + """Tests for the _get_history_limit helper pattern.""" + + def test_returns_fallback_when_no_config(self) -> None: + assert _get_history_limit({}, "history_filtered_result_limit", 200) == 200 + + def test_returns_fallback_when_config_is_none(self) -> None: + ctx: dict[str, Any] = {"runtime_config": None} + assert _get_history_limit(ctx, "history_filtered_result_limit", 200) == 200 + + def test_returns_config_value(self) -> None: + cfg = MagicMock() + cfg.history_filtered_result_limit = 500 + ctx: dict[str, Any] = {"runtime_config": cfg} + assert _get_history_limit(ctx, "history_filtered_result_limit", 200) == 500 + + def test_returns_fallback_when_config_attr_missing(self) -> None: + cfg = MagicMock(spec=[]) # no attributes + ctx: dict[str, Any] = {"runtime_config": cfg} + assert _get_history_limit(ctx, 
"nonexistent_field", 42) == 42 + + def test_returns_fallback_when_config_value_zero(self) -> None: + cfg = MagicMock() + cfg.history_filtered_result_limit = 0 + ctx: dict[str, Any] = {"runtime_config": cfg} + assert _get_history_limit(ctx, "history_filtered_result_limit", 200) == 200 + + def test_returns_fallback_when_config_value_negative(self) -> None: + cfg = MagicMock() + cfg.history_filtered_result_limit = -1 + ctx: dict[str, Any] = {"runtime_config": cfg} + assert _get_history_limit(ctx, "history_filtered_result_limit", 200) == 200 + + +class TestConfigHistoryFieldDefaults: + """Verify that Config dataclass has the expected history fields.""" + + def test_config_has_history_fields(self) -> None: + from Undefined.config.loader import Config + + fields = Config.__dataclass_fields__ + expected_fields = [ + "history_max_records", + "history_filtered_result_limit", + "history_search_scan_limit", + "history_summary_fetch_limit", + "history_summary_time_fetch_limit", + "history_onebot_fetch_limit", + "history_group_analysis_limit", + ] + for field_name in expected_fields: + assert field_name in fields, f"Missing field: {field_name}" + assert fields[field_name].type == "int", ( + f"{field_name} should be int, got {fields[field_name].type}" + ) From 6b274c0f28f513efe63727a4e6ccbd6b7803ad0b Mon Sep 17 00:00:00 2001 From: Null <1708213363@qq.com> Date: Sun, 19 Apr 2026 09:29:45 +0800 Subject: [PATCH 21/57] =?UTF-8?q?feat(webui):=20=E9=95=BF=E6=9C=9F?= =?UTF-8?q?=E8=AE=B0=E5=BF=86=E5=AE=8C=E6=95=B4=20CRUD=20=E7=AE=A1?= =?UTF-8?q?=E7=90=86?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Runtime API: 新增 POST/PATCH/DELETE /api/v1/memory 端点 - WebUI proxy: 新增记忆增删改代理路由 - 前端: 内联创建表单、行内编辑(Ctrl+Enter保存/Esc取消)、确认删除 - 修复 api.js Content-Type 仅对 POST 自动设置的问题,扩展至 PATCH/PUT/DELETE - 修复 CORS Allow-Methods 缺少 PATCH/DELETE - 新增 CSS 样式支持编辑/删除按钮与内联编辑区域 Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> Co-authored-by: Claude 
Opus 4.6 --- src/Undefined/api/app.py | 73 ++++++++- src/Undefined/webui/routes/_runtime.py | 45 +++++ src/Undefined/webui/static/css/components.css | 55 +++++++ src/Undefined/webui/static/js/api.js | 8 +- src/Undefined/webui/static/js/runtime.js | 154 +++++++++++++++++- src/Undefined/webui/templates/index.html | 12 +- 6 files changed, 341 insertions(+), 6 deletions(-) diff --git a/src/Undefined/api/app.py b/src/Undefined/api/app.py index 7c58116..cbe5824 100644 --- a/src/Undefined/api/app.py +++ b/src/Undefined/api/app.py @@ -186,7 +186,9 @@ def _apply_cors_headers(request: web.Request, response: web.StreamResponse) -> N origin = normalize_origin(str(request.headers.get("Origin") or "")) settings = load_webui_settings() response.headers.setdefault("Vary", "Origin") - response.headers.setdefault("Access-Control-Allow-Methods", "GET,POST,OPTIONS") + response.headers.setdefault( + "Access-Control-Allow-Methods", "GET,POST,PATCH,DELETE,OPTIONS" + ) response.headers.setdefault( "Access-Control-Allow-Headers", "Authorization, Content-Type, X-Undefined-API-Key", @@ -544,7 +546,14 @@ def _build_openapi_spec(ctx: RuntimeAPIContext, request: web.Request) -> dict[st ), } }, - "/api/v1/memory": {"get": {"summary": "List/search manual memories"}}, + "/api/v1/memory": { + "get": {"summary": "List/search manual memories"}, + "post": {"summary": "Create a manual memory"}, + }, + "/api/v1/memory/{uuid}": { + "patch": {"summary": "Update a manual memory by UUID"}, + "delete": {"summary": "Delete a manual memory by UUID"}, + }, "/api/v1/memes": {"get": {"summary": "List/search meme library items"}}, "/api/v1/memes/stats": {"get": {"summary": "Get meme library stats"}}, "/api/v1/memes/{uid}": { @@ -749,6 +758,9 @@ async def _auth_middleware( web.get("/api/v1/probes/internal", self._internal_probe_handler), web.get("/api/v1/probes/external", self._external_probe_handler), web.get("/api/v1/memory", self._memory_handler), + web.post("/api/v1/memory", self._memory_create_handler), + 
web.patch("/api/v1/memory/{uuid}", self._memory_update_handler), + web.delete("/api/v1/memory/{uuid}", self._memory_delete_handler), web.get("/api/v1/memes", self._meme_list_handler), web.get("/api/v1/memes/stats", self._meme_stats_handler), web.get("/api/v1/memes/{uid}", self._meme_detail_handler), @@ -1187,6 +1199,63 @@ def _created_sort_key(item: dict[str, Any]) -> float: } ) + async def _memory_create_handler(self, request: web.Request) -> Response: + memory_storage = getattr(self._ctx.ai, "memory_storage", None) + if memory_storage is None: + return _json_error("Memory storage not ready", status=503) + try: + body = await request.json() + except Exception: + return _json_error("Invalid JSON", status=400) + fact = str(body.get("fact", "") or "").strip() + if not fact: + return _json_error("fact must not be empty", status=400) + new_uuid = await memory_storage.add(fact) + if new_uuid is None: + return _json_error("Failed to create memory", status=500) + # add() returns existing UUID on duplicate + existing = [m for m in memory_storage.get_all() if m.uuid == new_uuid] + item = existing[0] if existing else None + return web.json_response( + { + "uuid": new_uuid, + "fact": item.fact if item else fact, + "created_at": item.created_at if item else "", + }, + status=201, + ) + + async def _memory_update_handler(self, request: web.Request) -> Response: + memory_storage = getattr(self._ctx.ai, "memory_storage", None) + if memory_storage is None: + return _json_error("Memory storage not ready", status=503) + target_uuid = str(request.match_info.get("uuid", "")).strip() + if not target_uuid: + return _json_error("uuid is required", status=400) + try: + body = await request.json() + except Exception: + return _json_error("Invalid JSON", status=400) + fact = str(body.get("fact", "") or "").strip() + if not fact: + return _json_error("fact must not be empty", status=400) + ok = await memory_storage.update(target_uuid, fact) + if not ok: + return _json_error(f"Memory 
{target_uuid} not found", status=404) + return web.json_response({"uuid": target_uuid, "fact": fact, "updated": True}) + + async def _memory_delete_handler(self, request: web.Request) -> Response: + memory_storage = getattr(self._ctx.ai, "memory_storage", None) + if memory_storage is None: + return _json_error("Memory storage not ready", status=503) + target_uuid = str(request.match_info.get("uuid", "")).strip() + if not target_uuid: + return _json_error("uuid is required", status=400) + ok = await memory_storage.delete(target_uuid) + if not ok: + return _json_error(f"Memory {target_uuid} not found", status=404) + return web.json_response({"uuid": target_uuid, "deleted": True}) + async def _meme_list_handler(self, request: web.Request) -> Response: meme_service = self._ctx.meme_service if meme_service is None or not meme_service.enabled: diff --git a/src/Undefined/webui/routes/_runtime.py b/src/Undefined/webui/routes/_runtime.py index c6de3eb..371c46b 100644 --- a/src/Undefined/webui/routes/_runtime.py +++ b/src/Undefined/webui/routes/_runtime.py @@ -296,6 +296,51 @@ async def runtime_memory_handler(request: web.Request) -> Response: ) +@routes.post("/api/v1/management/runtime/memory") +@routes.post("/api/runtime/memory") +async def runtime_memory_create_handler(request: web.Request) -> Response: + if not check_auth(request): + return _unauthorized() + try: + payload = await request.json() + except (json.JSONDecodeError, UnicodeDecodeError, ValueError): + return web.json_response({"error": "Invalid JSON payload"}, status=400) + return await _proxy_runtime( + method="POST", + path="/api/v1/memory", + payload=payload, + ) + + +@routes.patch("/api/v1/management/runtime/memory/{uuid}") +@routes.patch("/api/runtime/memory/{uuid}") +async def runtime_memory_update_handler(request: web.Request) -> Response: + if not check_auth(request): + return _unauthorized() + target_uuid = _url_quote(str(request.match_info.get("uuid", "")).strip(), safe="") + try: + payload = await 
request.json() + except (json.JSONDecodeError, UnicodeDecodeError, ValueError): + return web.json_response({"error": "Invalid JSON payload"}, status=400) + return await _proxy_runtime( + method="PATCH", + path=f"/api/v1/memory/{target_uuid}", + payload=payload, + ) + + +@routes.delete("/api/v1/management/runtime/memory/{uuid}") +@routes.delete("/api/runtime/memory/{uuid}") +async def runtime_memory_delete_handler(request: web.Request) -> Response: + if not check_auth(request): + return _unauthorized() + target_uuid = _url_quote(str(request.match_info.get("uuid", "")).strip(), safe="") + return await _proxy_runtime( + method="DELETE", + path=f"/api/v1/memory/{target_uuid}", + ) + + @routes.get("/api/v1/management/runtime/cognitive/events") @routes.get("/api/runtime/cognitive/events") async def runtime_cognitive_events_handler(request: web.Request) -> Response: diff --git a/src/Undefined/webui/static/css/components.css b/src/Undefined/webui/static/css/components.css index 7118ef6..4d4c7f4 100644 --- a/src/Undefined/webui/static/css/components.css +++ b/src/Undefined/webui/static/css/components.css @@ -428,6 +428,61 @@ white-space: pre-wrap; word-break: break-word; } +/* Memory CRUD */ +.memory-create-form { + display: flex; + gap: 8px; + align-items: flex-start; + padding: 8px 0; +} +.memory-create-form textarea { + flex: 1; + resize: vertical; + min-height: 40px; +} +.memory-create-actions { + flex-shrink: 0; + align-self: flex-end; +} +.memory-item-actions { + display: flex; + gap: 4px; + flex-shrink: 0; + margin-left: 8px; +} +.memory-item-actions button { + background: none; + border: 1px solid var(--border-color); + border-radius: var(--radius-sm); + cursor: pointer; + padding: 2px 6px; + font-size: 12px; + color: var(--text-tertiary); + transition: color 0.15s, border-color 0.15s; +} +.memory-item-actions button:hover { + color: var(--accent-color); + border-color: var(--accent-color); +} +.memory-item-actions button.memory-btn-delete:hover { + color: 
var(--error); + border-color: var(--error); +} +.memory-edit-area { + width: 100%; + min-height: 60px; + font-size: 14px; + line-height: 1.6; + font-family: inherit; + resize: vertical; + margin-top: 4px; +} +.memory-edit-actions { + display: flex; + gap: 6px; + margin-top: 6px; + justify-content: flex-end; +} .runtime-doc { font-size: 13px; line-height: 1.6; diff --git a/src/Undefined/webui/static/js/api.js b/src/Undefined/webui/static/js/api.js index 8ab88a3..82b427c 100644 --- a/src/Undefined/webui/static/js/api.js +++ b/src/Undefined/webui/static/js/api.js @@ -31,7 +31,13 @@ function shouldRetryCandidate(res) { async function requestOnce(path, options = {}) { const headers = { ...(options.headers || {}) }; - if (options.method === "POST" && options.body && !headers["Content-Type"]) { + const needsJson = + options.body && + !headers["Content-Type"] && + ["POST", "PATCH", "PUT", "DELETE"].includes( + String(options.method || "").toUpperCase(), + ); + if (needsJson) { headers["Content-Type"] = "application/json"; } if (state.authAccessToken && !headers.Authorization) { diff --git a/src/Undefined/webui/static/js/runtime.js b/src/Undefined/webui/static/js/runtime.js index 6968c0f..42655a5 100644 --- a/src/Undefined/webui/static/js/runtime.js +++ b/src/Undefined/webui/static/js/runtime.js @@ -636,6 +636,8 @@ if (buffer.trim()) emitBlock(buffer); } + let _memoryMutating = false; + function renderMemoryItems(payload) { const container = get("runtimeMemoryList"); const meta = get("runtimeMemoryMeta"); @@ -666,9 +668,155 @@ const uuid = escapeHtml(item.uuid || ""); const fact = escapeHtml(item.fact || ""); const created = escapeHtml(item.created_at || ""); - return `
${uuid}${created}
${fact}
`; + return `
${uuid}
${created}
${fact}
`; }) .join(""); + + container.querySelectorAll(".memory-btn-edit").forEach((btn) => { + btn.addEventListener("click", () => + startEditMemory(btn.dataset.uuid), + ); + }); + container.querySelectorAll(".memory-btn-delete").forEach((btn) => { + btn.addEventListener("click", () => deleteMemory(btn.dataset.uuid)); + }); + } + + function startEditMemory(uuid) { + const container = get("runtimeMemoryList"); + if (!container) return; + const itemEl = container.querySelector( + `.runtime-list-item[data-uuid="${CSS.escape(uuid)}"]`, + ); + if (!itemEl) return; + const factEl = itemEl.querySelector(".runtime-list-fact"); + if (!factEl || factEl.dataset.editing === "true") return; + + const currentText = factEl.textContent || ""; + factEl.dataset.editing = "true"; + factEl.innerHTML = ""; + + const textarea = document.createElement("textarea"); + textarea.className = "form-control memory-edit-area"; + textarea.value = currentText; + + const actions = document.createElement("div"); + actions.className = "memory-edit-actions"; + const saveBtn = document.createElement("button"); + saveBtn.className = "btn btn-sm"; + saveBtn.textContent = "保存"; + const cancelBtn = document.createElement("button"); + cancelBtn.className = "btn btn-sm"; + cancelBtn.textContent = "取消"; + actions.append(saveBtn, cancelBtn); + factEl.append(textarea, actions); + textarea.focus(); + + cancelBtn.addEventListener("click", () => { + delete factEl.dataset.editing; + factEl.innerHTML = ""; + factEl.textContent = currentText; + }); + + saveBtn.addEventListener("click", () => + updateMemory(uuid, textarea.value), + ); + + textarea.addEventListener("keydown", (e) => { + if (e.key === "Escape") { + e.preventDefault(); + cancelBtn.click(); + } + if (e.key === "Enter" && e.ctrlKey) { + e.preventDefault(); + saveBtn.click(); + } + }); + } + + async function createMemory() { + if (_memoryMutating) return; + const input = get("memoryCreateInput"); + if (!input) return; + const fact = String(input.value || 
"").trim(); + if (!fact) { + showToast("记忆内容不能为空", "warning"); + return; + } + _memoryMutating = true; + const btn = get("btnMemoryCreate"); + if (btn) btn.disabled = true; + try { + const res = await api("/api/runtime/memory", { + method: "POST", + body: JSON.stringify({ fact }), + }); + const data = await parseJsonSafe(res); + if (!res.ok || (data && data.error)) { + throw new Error(buildRequestError(res, data)); + } + showToast("记忆已添加", "success"); + input.value = ""; + await searchMemory(); + } catch (err) { + showToast(`添加失败: ${err.message || err}`, "error"); + } finally { + _memoryMutating = false; + if (btn) btn.disabled = false; + } + } + + async function updateMemory(uuid, newFact) { + const fact = String(newFact || "").trim(); + if (!fact) { + showToast("记忆内容不能为空", "warning"); + return; + } + if (_memoryMutating) return; + _memoryMutating = true; + try { + const res = await api( + `/api/runtime/memory/${encodeURIComponent(uuid)}`, + { + method: "PATCH", + body: JSON.stringify({ fact }), + }, + ); + const data = await parseJsonSafe(res); + if (!res.ok || (data && data.error)) { + throw new Error(buildRequestError(res, data)); + } + showToast("记忆已更新", "success"); + await searchMemory(); + } catch (err) { + showToast(`更新失败: ${err.message || err}`, "error"); + } finally { + _memoryMutating = false; + } + } + + async function deleteMemory(uuid) { + if (_memoryMutating) return; + if (!confirm(`确认删除记忆 ${uuid.slice(0, 8)}…?`)) return; + _memoryMutating = true; + try { + const res = await api( + `/api/runtime/memory/${encodeURIComponent(uuid)}`, + { + method: "DELETE", + }, + ); + const data = await parseJsonSafe(res); + if (!res.ok || (data && data.error)) { + throw new Error(buildRequestError(res, data)); + } + showToast("记忆已删除", "success"); + await searchMemory(); + } catch (err) { + showToast(`删除失败: ${err.message || err}`, "error"); + } finally { + _memoryMutating = false; + } } function setListMessage(metaId, listId, message) { @@ -1204,6 +1352,10 @@ if 
(memoryRefresh) memoryRefresh.addEventListener("click", refreshMemory); + const memoryCreateBtn = get("btnMemoryCreate"); + if (memoryCreateBtn) + memoryCreateBtn.addEventListener("click", createMemory); + const runMemorySearch = () => runQueryAction("memory", "btnRuntimeMemorySearch", searchMemory); const runEventsSearch = () => diff --git a/src/Undefined/webui/templates/index.html b/src/Undefined/webui/templates/index.html index 3f4c8a9..3c9886f 100644 --- a/src/Undefined/webui/templates/index.html +++ b/src/Undefined/webui/templates/index.html @@ -446,7 +446,7 @@

探针

记忆检索

-

只读检索记忆、认知事件与侧写。

+

管理长期记忆,检索认知事件与侧写。

@@ -535,7 +535,7 @@

记忆检索

-
长期记忆查询
+
长期记忆管理
记忆检索
+ +
+ +
+ +
+
From 9d7415605c01e6a83b090467498aa61a1a3cf250 Mon Sep 17 00:00:00 2001 From: Null <1708213363@qq.com> Date: Sun, 19 Apr 2026 09:32:47 +0800 Subject: [PATCH 22/57] =?UTF-8?q?fix(profile):=20=E4=BF=AE=E5=A4=8D=20-r?= =?UTF-8?q?=20=E6=B8=B2=E6=9F=93=E7=95=99=E7=99=BD=E5=92=8C=E5=AD=97?= =?UTF-8?q?=E5=B0=8F=E9=97=AE=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - render_html_to_image 新增 viewport_width 参数,默认 1280 不影响其他调用 - profile 渲染使用 480px 窄视口,生成适合手机查看的长图 - 移除 max-width/margin:auto 居中,改为 width:100% 填满视口 - 字号从 12px/14px 提升到 14px/15px,行高 1.8 - 减少 padding 节省空间 Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> Co-authored-by: Claude Opus 4.6 --- src/Undefined/render.py | 13 +++++++++++-- src/Undefined/skills/commands/profile/handler.py | 16 ++++++++-------- 2 files changed, 19 insertions(+), 10 deletions(-) diff --git a/src/Undefined/render.py b/src/Undefined/render.py index dba3e13..3678ce1 100644 --- a/src/Undefined/render.py +++ b/src/Undefined/render.py @@ -91,19 +91,28 @@ def _parse() -> str: return full_html -async def render_html_to_image(html_content: str, output_path: str) -> None: +async def render_html_to_image( + html_content: str, + output_path: str, + *, + viewport_width: int = 1280, +) -> None: """ 将 HTML 字符串转换为 PNG 图片 参数: html_content: 完整的 HTML 字符串 output_path: 输出图片路径 (例如 'result.png') + viewport_width: 视口宽度(像素),默认 1280 """ async with async_playwright() as p: # 启动无头浏览器 browser = await p.chromium.launch(headless=True) # 设置上下文,可以指定缩放比例(device_scale_factor),2代表2倍清晰度(Retina) - context = await browser.new_context(device_scale_factor=2) + context = await browser.new_context( + device_scale_factor=2, + viewport={"width": viewport_width, "height": 800}, + ) page = await context.new_page() # 设置页面内容 diff --git a/src/Undefined/skills/commands/profile/handler.py b/src/Undefined/skills/commands/profile/handler.py index 6d6f473..4040ab2 100644 --- a/src/Undefined/skills/commands/profile/handler.py 
+++ b/src/Undefined/skills/commands/profile/handler.py @@ -129,28 +129,28 @@ async def _send_render( * {{ margin: 0; padding: 0; box-sizing: border-box; }} body {{ font-family: 'Microsoft YaHei', 'PingFang SC', 'Noto Sans CJK SC', sans-serif; - background: #f9f5f1; color: #3d3935; padding: 24px; + background: #f9f5f1; color: #3d3935; padding: 16px; }} .card {{ - max-width: 680px; margin: 0 auto; + width: 100%; background: #fff; border-radius: 10px; border: 1px solid #e6e0d8; overflow: hidden; }} .meta {{ background: #f9f5f1; border-bottom: 1px solid #e6e0d8; - padding: 16px 20px; + padding: 14px 18px; }} .meta table {{ border-collapse: collapse; }} .mk {{ - font-size: 12px; color: #6e675f; padding: 2px 10px 2px 0; - white-space: nowrap; vertical-align: top; + font-size: 14px; color: #6e675f; padding: 3px 12px 3px 0; + white-space: nowrap; vertical-align: top; font-weight: 600; }} .mv {{ - font-size: 12px; color: #3d3935; padding: 2px 0; + font-size: 14px; color: #3d3935; padding: 3px 0; }} .body {{ - padding: 20px; line-height: 1.75; font-size: 14px; + padding: 18px; line-height: 1.8; font-size: 15px; white-space: pre-wrap; word-wrap: break-word; }} @@ -163,7 +163,7 @@ async def _send_render( output_dir = ensure_dir(RENDER_CACHE_DIR) output_path = str(output_dir / f"profile_{uuid.uuid4().hex[:8]}.png") - await render_html_to_image(html_content, output_path) + await render_html_to_image(html_content, output_path, viewport_width=480) abs_path = Path(output_path).resolve() image_cq = f"[CQ:image,file=file://{abs_path}]" From fab3d10520096d291825d3b952e44790e1502a19 Mon Sep 17 00:00:00 2001 From: Null <1708213363@qq.com> Date: Sun, 19 Apr 2026 09:48:49 +0800 Subject: [PATCH 23/57] =?UTF-8?q?test:=20=E5=85=A8=E9=9D=A2=E8=A1=A5?= =?UTF-8?q?=E9=BD=90=E5=8D=95=E5=85=83=E6=B5=8B=E8=AF=95=E8=A6=86=E7=9B=96?= =?UTF-8?q?=20(804=20=E2=86=92=201423)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 新增 27 个测试文件,619 个测试用例,覆盖以下未测试模块: 纯工具函数: 
- utils/xml, cors, time_utils, message_targets, group_metrics - utils/request_params, member_utils, common, tool_calls - utils/message_utils, fake_at, cache AI/Skills: - ai/parsing, ai/tokens, ai/queue_budget - skills/http_config, http_client, registry (SkillStats) - context (RequestContext + helpers) 存储/数据: - memory (MemoryStorage CRUD + 去重 + 上限) - faq (FAQ dataclass + FAQStorage CRUD) - rate_limit (RateLimiter 分角色限流) - end_summary_storage, token_usage_storage, scheduled_task_storage - config/models (format_netloc, resolve_bind_hosts) - utils/qq_emoji 所有 1423 测试通过 ruff + mypy strict(新增文件零错误) Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> Co-authored-by: Claude Opus 4.6 --- tests/test_ai_parsing.py | 101 +++++++++ tests/test_ai_queue_budget.py | 212 +++++++++++++++++++ tests/test_ai_tokens.py | 83 ++++++++ tests/test_cache_cleanup.py | 105 ++++++++++ tests/test_config_models.py | 53 +++++ tests/test_context.py | 195 ++++++++++++++++++ tests/test_cors_utils.py | 137 +++++++++++++ tests/test_end_summary_storage.py | 124 +++++++++++ tests/test_fake_at.py | 179 ++++++++++++++++ tests/test_faq_unit.py | 305 ++++++++++++++++++++++++++++ tests/test_group_metrics.py | 190 +++++++++++++++++ tests/test_member_utils.py | 196 ++++++++++++++++++ tests/test_memory_unit.py | 229 +++++++++++++++++++++ tests/test_message_targets.py | 185 +++++++++++++++++ tests/test_message_utils.py | 232 +++++++++++++++++++++ tests/test_qq_emoji.py | 142 +++++++++++++ tests/test_rate_limit.py | 267 ++++++++++++++++++++++++ tests/test_request_params.py | 141 +++++++++++++ tests/test_scheduled_task_unit.py | 192 +++++++++++++++++ tests/test_skills_http_client.py | 76 +++++++ tests/test_skills_http_config.py | 98 +++++++++ tests/test_skills_registry_stats.py | 100 +++++++++ tests/test_time_utils.py | 84 ++++++++ tests/test_token_usage_unit.py | 169 +++++++++++++++ tests/test_tool_calls.py | 273 +++++++++++++++++++++++++ tests/test_utils_common.py | 305 
++++++++++++++++++++++++++++ tests/test_xml_utils.py | 100 +++++++++ 27 files changed, 4473 insertions(+) create mode 100644 tests/test_ai_parsing.py create mode 100644 tests/test_ai_queue_budget.py create mode 100644 tests/test_ai_tokens.py create mode 100644 tests/test_cache_cleanup.py create mode 100644 tests/test_config_models.py create mode 100644 tests/test_context.py create mode 100644 tests/test_cors_utils.py create mode 100644 tests/test_end_summary_storage.py create mode 100644 tests/test_fake_at.py create mode 100644 tests/test_faq_unit.py create mode 100644 tests/test_group_metrics.py create mode 100644 tests/test_member_utils.py create mode 100644 tests/test_memory_unit.py create mode 100644 tests/test_message_targets.py create mode 100644 tests/test_message_utils.py create mode 100644 tests/test_qq_emoji.py create mode 100644 tests/test_rate_limit.py create mode 100644 tests/test_request_params.py create mode 100644 tests/test_scheduled_task_unit.py create mode 100644 tests/test_skills_http_client.py create mode 100644 tests/test_skills_http_config.py create mode 100644 tests/test_skills_registry_stats.py create mode 100644 tests/test_time_utils.py create mode 100644 tests/test_token_usage_unit.py create mode 100644 tests/test_tool_calls.py create mode 100644 tests/test_utils_common.py create mode 100644 tests/test_xml_utils.py diff --git a/tests/test_ai_parsing.py b/tests/test_ai_parsing.py new file mode 100644 index 0000000..ff7015b --- /dev/null +++ b/tests/test_ai_parsing.py @@ -0,0 +1,101 @@ +"""Tests for Undefined.ai.parsing module.""" + +from __future__ import annotations + +import pytest + +from Undefined.ai.parsing import extract_choices_content + + +class TestExtractChoicesContent: + """Tests for extract_choices_content().""" + + def test_standard_response(self) -> None: + result: dict[str, object] = { + "choices": [{"message": {"content": "Hello, world!"}}] + } + assert extract_choices_content(result) == "Hello, world!" 
+ + def test_data_wrapped_response(self) -> None: + result: dict[str, object] = { + "data": {"choices": [{"message": {"content": "nested content"}}]} + } + assert extract_choices_content(result) == "nested content" + + def test_output_text_field(self) -> None: + result: dict[str, object] = { + "output_text": "direct output", + "choices": [{"message": {"content": "ignored"}}], + } + assert extract_choices_content(result) == "direct output" + + def test_output_text_preferred_over_choices(self) -> None: + result: dict[str, object] = {"output_text": "preferred"} + assert extract_choices_content(result) == "preferred" + + def test_output_text_non_string_falls_through(self) -> None: + result: dict[str, object] = { + "output_text": 42, + "choices": [{"message": {"content": "fallback"}}], + } + assert extract_choices_content(result) == "fallback" + + def test_empty_choices_raises(self) -> None: + result: dict[str, object] = {"choices": []} + with pytest.raises(KeyError): + extract_choices_content(result) + + def test_no_choices_key_raises(self) -> None: + result: dict[str, object] = {"id": "123", "object": "chat.completion"} + with pytest.raises(KeyError): + extract_choices_content(result) + + def test_no_content_in_message(self) -> None: + result: dict[str, object] = {"choices": [{"message": {}}]} + assert extract_choices_content(result) == "" + + def test_message_is_none(self) -> None: + """message=None triggers AttributeError in tool_calls check (known bug).""" + result: dict[str, object] = {"choices": [{"message": None}]} + with pytest.raises(AttributeError): + extract_choices_content(result) + + def test_choice_with_content_directly(self) -> None: + result: dict[str, object] = {"choices": [{"content": "direct"}]} + assert extract_choices_content(result) == "direct" + + def test_tool_calls_no_content(self) -> None: + result: dict[str, object] = { + "choices": [{"message": {"tool_calls": [{"function": {"name": "test"}}]}}] + } + assert extract_choices_content(result) == 
"" + + def test_refusal_field_content_still_extracted(self) -> None: + result: dict[str, object] = { + "choices": [ + { + "message": { + "content": "I can help with that.", + "refusal": None, + } + } + ] + } + assert extract_choices_content(result) == "I can help with that." + + def test_multiple_choices_returns_first(self) -> None: + result: dict[str, object] = { + "choices": [ + {"message": {"content": "first"}}, + {"message": {"content": "second"}}, + ] + } + assert extract_choices_content(result) == "first" + + def test_empty_dict_raises(self) -> None: + with pytest.raises(KeyError): + extract_choices_content({}) + + def test_message_is_string(self) -> None: + result: dict[str, object] = {"choices": [{"message": "plain string"}]} + assert extract_choices_content(result) == "plain string" diff --git a/tests/test_ai_queue_budget.py b/tests/test_ai_queue_budget.py new file mode 100644 index 0000000..001dec5 --- /dev/null +++ b/tests/test_ai_queue_budget.py @@ -0,0 +1,212 @@ +"""Tests for Undefined.ai.queue_budget module.""" + +from __future__ import annotations + +from types import SimpleNamespace + +from Undefined.ai.queue_budget import ( + DEFAULT_QUEUED_LLM_ATTEMPT_TIMEOUT_SECONDS, + QUEUED_LLM_TIMEOUT_GRACE_SECONDS, + compute_queued_llm_timeout_seconds, + resolve_effective_retry_count, +) + + +class TestResolveEffectiveRetryCount: + """Tests for resolve_effective_retry_count().""" + + def test_from_runtime_config(self) -> None: + cfg = SimpleNamespace(ai_request_max_retries=3) + assert resolve_effective_retry_count(cfg) == 3 + + def test_from_queue_manager(self) -> None: + cfg = SimpleNamespace(ai_request_max_retries=5) + qm = SimpleNamespace(get_max_retries=lambda: 2) + assert resolve_effective_retry_count(cfg, qm) == 2 + + def test_queue_manager_takes_precedence(self) -> None: + cfg = SimpleNamespace(ai_request_max_retries=10) + qm = SimpleNamespace(get_max_retries=lambda: 1) + assert resolve_effective_retry_count(cfg, qm) == 1 + + def 
test_negative_retries_clamped_to_zero(self) -> None: + cfg = SimpleNamespace(ai_request_max_retries=-5) + assert resolve_effective_retry_count(cfg) == 0 + + def test_none_retries_returns_zero(self) -> None: + cfg = SimpleNamespace(ai_request_max_retries=None) + assert resolve_effective_retry_count(cfg) == 0 + + def test_missing_attribute_returns_zero(self) -> None: + cfg = SimpleNamespace() + assert resolve_effective_retry_count(cfg) == 0 + + def test_queue_manager_invalid_return(self) -> None: + cfg = SimpleNamespace(ai_request_max_retries=3) + qm = SimpleNamespace(get_max_retries=lambda: "invalid") + assert resolve_effective_retry_count(cfg, qm) == 0 + + def test_queue_manager_negative_clamped(self) -> None: + cfg = SimpleNamespace(ai_request_max_retries=3) + qm = SimpleNamespace(get_max_retries=lambda: -1) + assert resolve_effective_retry_count(cfg, qm) == 0 + + def test_queue_manager_none_no_get_max_retries(self) -> None: + cfg = SimpleNamespace(ai_request_max_retries=4) + qm = SimpleNamespace() + assert resolve_effective_retry_count(cfg, qm) == 4 + + def test_zero_retries(self) -> None: + cfg = SimpleNamespace(ai_request_max_retries=0) + assert resolve_effective_retry_count(cfg) == 0 + + +class TestComputeQueuedLlmTimeoutSeconds: + """Tests for compute_queued_llm_timeout_seconds().""" + + def _make_model_config( + self, + interval: float = 0.0, + pool_enabled: bool = False, + pool_models: list[SimpleNamespace] | None = None, + ) -> SimpleNamespace: + pool = SimpleNamespace( + enabled=pool_enabled, + models=pool_models or [], + ) + return SimpleNamespace(queue_interval_seconds=interval, pool=pool) + + def test_defaults_zero_retries(self) -> None: + cfg = SimpleNamespace(ai_request_max_retries=0) + model_cfg = self._make_model_config() + result = compute_queued_llm_timeout_seconds(cfg, model_cfg) + expected = ( + 0.0 + + DEFAULT_QUEUED_LLM_ATTEMPT_TIMEOUT_SECONDS * 1 + + 0.0 * 1 + + QUEUED_LLM_TIMEOUT_GRACE_SECONDS + ) + assert result == expected + + def 
test_with_retries(self) -> None: + cfg = SimpleNamespace(ai_request_max_retries=2) + model_cfg = self._make_model_config(interval=1.0) + result = compute_queued_llm_timeout_seconds(cfg, model_cfg) + # attempts=3, dispatch_intervals=3 (2 retries + 1 first) + expected = ( + 0.0 + + DEFAULT_QUEUED_LLM_ATTEMPT_TIMEOUT_SECONDS * 3 + + 1.0 * 3 + + QUEUED_LLM_TIMEOUT_GRACE_SECONDS + ) + assert result == expected + + def test_explicit_retry_count(self) -> None: + cfg = SimpleNamespace(ai_request_max_retries=99) + model_cfg = self._make_model_config() + result = compute_queued_llm_timeout_seconds(cfg, model_cfg, retry_count=1) + # explicit retry_count=1 overrides config + expected = ( + 0.0 + + DEFAULT_QUEUED_LLM_ATTEMPT_TIMEOUT_SECONDS * 2 + + 0.0 * 2 + + QUEUED_LLM_TIMEOUT_GRACE_SECONDS + ) + assert result == expected + + def test_initial_wait(self) -> None: + cfg = SimpleNamespace(ai_request_max_retries=0) + model_cfg = self._make_model_config() + result = compute_queued_llm_timeout_seconds( + cfg, model_cfg, initial_wait_seconds=10.0 + ) + expected = ( + 10.0 + + DEFAULT_QUEUED_LLM_ATTEMPT_TIMEOUT_SECONDS * 1 + + 0.0 * 1 + + QUEUED_LLM_TIMEOUT_GRACE_SECONDS + ) + assert result == expected + + def test_no_first_dispatch_interval(self) -> None: + cfg = SimpleNamespace(ai_request_max_retries=2) + model_cfg = self._make_model_config(interval=5.0) + result = compute_queued_llm_timeout_seconds( + cfg, model_cfg, include_first_dispatch_interval=False + ) + # dispatch_intervals = retries only = 2 + expected = ( + 0.0 + + DEFAULT_QUEUED_LLM_ATTEMPT_TIMEOUT_SECONDS * 3 + + 5.0 * 2 + + QUEUED_LLM_TIMEOUT_GRACE_SECONDS + ) + assert result == expected + + def test_custom_attempt_timeout(self) -> None: + cfg = SimpleNamespace(ai_request_max_retries=0) + model_cfg = self._make_model_config() + result = compute_queued_llm_timeout_seconds( + cfg, model_cfg, attempt_timeout_seconds=60.0 + ) + expected = 0.0 + 60.0 * 1 + 0.0 * 1 + QUEUED_LLM_TIMEOUT_GRACE_SECONDS + assert result == 
expected + + def test_custom_grace_seconds(self) -> None: + cfg = SimpleNamespace(ai_request_max_retries=0) + model_cfg = self._make_model_config() + result = compute_queued_llm_timeout_seconds(cfg, model_cfg, grace_seconds=100.0) + expected = ( + 0.0 + DEFAULT_QUEUED_LLM_ATTEMPT_TIMEOUT_SECONDS * 1 + 0.0 * 1 + 100.0 + ) + assert result == expected + + def test_pool_models_max_interval(self) -> None: + pool_models = [ + SimpleNamespace(queue_interval_seconds=2.0), + SimpleNamespace(queue_interval_seconds=5.0), + SimpleNamespace(queue_interval_seconds=1.0), + ] + cfg = SimpleNamespace(ai_request_max_retries=1) + model_cfg = self._make_model_config( + interval=3.0, pool_enabled=True, pool_models=pool_models + ) + result = compute_queued_llm_timeout_seconds(cfg, model_cfg) + # max interval = max(3.0, 2.0, 5.0, 1.0) = 5.0 + # attempts=2, dispatch=2 + expected = ( + 0.0 + + DEFAULT_QUEUED_LLM_ATTEMPT_TIMEOUT_SECONDS * 2 + + 5.0 * 2 + + QUEUED_LLM_TIMEOUT_GRACE_SECONDS + ) + assert result == expected + + def test_pool_disabled_ignores_pool_models(self) -> None: + pool_models = [SimpleNamespace(queue_interval_seconds=100.0)] + cfg = SimpleNamespace(ai_request_max_retries=0) + model_cfg = self._make_model_config( + interval=1.0, pool_enabled=False, pool_models=pool_models + ) + result = compute_queued_llm_timeout_seconds(cfg, model_cfg) + # pool disabled: only base interval=1.0 + expected = ( + 0.0 + + DEFAULT_QUEUED_LLM_ATTEMPT_TIMEOUT_SECONDS * 1 + + 1.0 * 1 + + QUEUED_LLM_TIMEOUT_GRACE_SECONDS + ) + assert result == expected + + def test_negative_retry_count_clamped(self) -> None: + cfg = SimpleNamespace(ai_request_max_retries=0) + model_cfg = self._make_model_config() + result = compute_queued_llm_timeout_seconds(cfg, model_cfg, retry_count=-5) + # clamped to 0 → 1 attempt + expected = ( + 0.0 + + DEFAULT_QUEUED_LLM_ATTEMPT_TIMEOUT_SECONDS * 1 + + 0.0 * 1 + + QUEUED_LLM_TIMEOUT_GRACE_SECONDS + ) + assert result == expected diff --git a/tests/test_ai_tokens.py 
b/tests/test_ai_tokens.py new file mode 100644 index 0000000..bd70165 --- /dev/null +++ b/tests/test_ai_tokens.py @@ -0,0 +1,83 @@ +"""Tests for Undefined.ai.tokens module.""" + +from __future__ import annotations + +from unittest.mock import patch + +from Undefined.ai.tokens import TokenCounter + + +class TestTokenCounter: + """Tests for TokenCounter.""" + + def test_empty_string(self) -> None: + counter = TokenCounter() + result = counter.count("") + assert result == 0 or isinstance(result, int) + + def test_normal_text(self) -> None: + counter = TokenCounter() + result = counter.count("Hello, world!") + assert result > 0 + + def test_unicode_text(self) -> None: + counter = TokenCounter() + result = counter.count("你好世界!🌍") + assert result > 0 + + def test_long_text(self) -> None: + counter = TokenCounter() + short_count = counter.count("hello") + long_count = counter.count("hello " * 1000) + assert long_count > short_count + + def test_whitespace_only(self) -> None: + counter = TokenCounter() + result = counter.count(" \n\t ") + assert isinstance(result, int) + + def test_single_character(self) -> None: + counter = TokenCounter() + result = counter.count("a") + assert result >= 1 + + def test_fallback_when_tiktoken_unavailable(self) -> None: + counter = TokenCounter() + counter._tokenizer = None + result = counter.count("hello world") + expected = len("hello world") // 3 + 1 + assert result == expected + + def test_fallback_empty_string(self) -> None: + counter = TokenCounter() + counter._tokenizer = None + result = counter.count("") + assert result == 1 # len("") // 3 + 1 == 1 + + def test_fallback_short_text(self) -> None: + counter = TokenCounter() + counter._tokenizer = None + assert counter.count("ab") == 1 # 2 // 3 + 1 + + def test_fallback_exact_multiple(self) -> None: + counter = TokenCounter() + counter._tokenizer = None + assert counter.count("abc") == 2 # 3 // 3 + 1 + + def test_default_model_name(self) -> None: + counter = TokenCounter() + assert 
counter._model_name == "gpt-4" + + def test_custom_model_name(self) -> None: + counter = TokenCounter(model_name="gpt-3.5-turbo") + assert counter._model_name == "gpt-3.5-turbo" + + def test_tiktoken_load_failure_graceful(self) -> None: + with patch("builtins.__import__", side_effect=ImportError("no tiktoken")): + counter = TokenCounter.__new__(TokenCounter) + counter._model_name = "gpt-4" + counter._tokenizer = None + counter._try_load_tokenizer() + assert counter._tokenizer is None + result = counter.count("test text") + assert result == len("test text") // 3 + 1 diff --git a/tests/test_cache_cleanup.py b/tests/test_cache_cleanup.py new file mode 100644 index 0000000..72a0f08 --- /dev/null +++ b/tests/test_cache_cleanup.py @@ -0,0 +1,105 @@ +"""Tests for Undefined.utils.cache.cleanup_cache_dir.""" + +from __future__ import annotations + +import os +import time +from pathlib import Path + +from Undefined.utils.cache import cleanup_cache_dir + + +class TestCleanupCacheDir: + def test_empty_dir(self, tmp_path: Path) -> None: + assert cleanup_cache_dir(tmp_path) == 0 + + def test_old_files_removed(self, tmp_path: Path) -> None: + old_file = tmp_path / "old.txt" + old_file.write_text("data") + # Set mtime to 30 days ago + old_time = time.time() - 30 * 24 * 3600 + os.utime(old_file, (old_time, old_time)) + + deleted = cleanup_cache_dir(tmp_path, max_age_seconds=7 * 24 * 3600) + assert deleted == 1 + assert not old_file.exists() + + def test_new_files_kept(self, tmp_path: Path) -> None: + new_file = tmp_path / "new.txt" + new_file.write_text("data") + # mtime = now (default), so it's fresh + + deleted = cleanup_cache_dir(tmp_path, max_age_seconds=7 * 24 * 3600) + assert deleted == 0 + assert new_file.exists() + + def test_max_files_cap(self, tmp_path: Path) -> None: + now = time.time() + for i in range(5): + f = tmp_path / f"file_{i}.txt" + f.write_text("data") + os.utime(f, (now - i, now - i)) # stagger mtime + + deleted = cleanup_cache_dir(tmp_path, max_age_seconds=0, 
max_files=3) + assert deleted == 2 + remaining = list(tmp_path.iterdir()) + assert len(remaining) == 3 + + def test_nonexistent_dir_created(self, tmp_path: Path) -> None: + new_dir = tmp_path / "subdir" / "cache" + assert not new_dir.exists() + deleted = cleanup_cache_dir(new_dir) + assert deleted == 0 + assert new_dir.is_dir() + + def test_mixed_ages(self, tmp_path: Path) -> None: + now = time.time() + # 1 old, 2 fresh + old_f = tmp_path / "old.txt" + old_f.write_text("old") + os.utime(old_f, (now - 999999, now - 999999)) + + for i in range(2): + f = tmp_path / f"fresh_{i}.txt" + f.write_text("fresh") + + deleted = cleanup_cache_dir(tmp_path, max_age_seconds=7 * 24 * 3600) + assert deleted == 1 + assert not old_f.exists() + + def test_zero_max_age_skips_age_check(self, tmp_path: Path) -> None: + old_file = tmp_path / "old.txt" + old_file.write_text("data") + old_time = time.time() - 999999 + os.utime(old_file, (old_time, old_time)) + + deleted = cleanup_cache_dir(tmp_path, max_age_seconds=0, max_files=0) + assert deleted == 0 + assert old_file.exists() + + def test_zero_max_files_skips_cap(self, tmp_path: Path) -> None: + for i in range(10): + (tmp_path / f"f{i}.txt").write_text("x") + + deleted = cleanup_cache_dir(tmp_path, max_age_seconds=0, max_files=0) + assert deleted == 0 + assert len(list(tmp_path.iterdir())) == 10 + + def test_both_age_and_cap(self, tmp_path: Path) -> None: + now = time.time() + # Create 5 files: 2 old (removed by age), 3 fresh + for i in range(2): + f = tmp_path / f"old_{i}.txt" + f.write_text("old") + os.utime(f, (now - 999999, now - 999999)) + for i in range(3): + f = tmp_path / f"new_{i}.txt" + f.write_text("new") + os.utime(f, (now - i, now - i)) + + deleted = cleanup_cache_dir( + tmp_path, max_age_seconds=7 * 24 * 3600, max_files=2 + ) + # 2 removed by age + 1 removed by cap = 3 + assert deleted == 3 + assert len(list(tmp_path.iterdir())) == 2 diff --git a/tests/test_config_models.py b/tests/test_config_models.py new file mode 100644 
index 0000000..b4bf05f --- /dev/null +++ b/tests/test_config_models.py @@ -0,0 +1,53 @@ +"""Tests for Undefined.config.models — config model helpers.""" + +from __future__ import annotations + +from Undefined.config.models import format_netloc, resolve_bind_hosts + + +class TestFormatNetloc: + def test_ipv4(self) -> None: + assert format_netloc("127.0.0.1", 8080) == "127.0.0.1:8080" + + def test_hostname(self) -> None: + assert format_netloc("example.com", 443) == "example.com:443" + + def test_ipv6_wrapped(self) -> None: + assert format_netloc("::1", 8080) == "[::1]:8080" + + def test_ipv6_full(self) -> None: + result = format_netloc("2001:db8::1", 9090) + assert result == "[2001:db8::1]:9090" + + def test_ipv6_all_zeros(self) -> None: + assert format_netloc("::", 80) == "[::]:80" + + def test_ipv4_default_port(self) -> None: + assert format_netloc("0.0.0.0", 80) == "0.0.0.0:80" + + def test_localhost(self) -> None: + assert format_netloc("localhost", 3000) == "localhost:3000" + + def test_empty_host(self) -> None: + # No colon in empty string → treated as IPv4-style + assert format_netloc("", 8080) == ":8080" + + +class TestResolveBindHosts: + def test_empty_string(self) -> None: + assert resolve_bind_hosts("") == ["0.0.0.0", "::"] + + def test_double_colon(self) -> None: + assert resolve_bind_hosts("::") == ["0.0.0.0", "::"] + + def test_ipv4_any(self) -> None: + assert resolve_bind_hosts("0.0.0.0") == ["0.0.0.0"] + + def test_specific_ipv4(self) -> None: + assert resolve_bind_hosts("127.0.0.1") == ["127.0.0.1"] + + def test_specific_ipv6(self) -> None: + assert resolve_bind_hosts("::1") == ["::1"] + + def test_hostname(self) -> None: + assert resolve_bind_hosts("myhost.local") == ["myhost.local"] diff --git a/tests/test_context.py b/tests/test_context.py new file mode 100644 index 0000000..e68a0d8 --- /dev/null +++ b/tests/test_context.py @@ -0,0 +1,195 @@ +"""Tests for Undefined.context module.""" + +from __future__ import annotations + +import logging + 
import logging

import pytest

from Undefined.context import (
    RequestContext,
    RequestContextFilter,
    get_group_id,
    get_request_id,
    get_request_type,
    get_sender_id,
    get_user_id,
)


def _blank_record(msg: str) -> logging.LogRecord:
    """Build a minimal LogRecord carrying *msg* for filter tests."""
    return logging.LogRecord(
        name="test",
        level=logging.INFO,
        pathname="",
        lineno=0,
        msg=msg,
        args=(),
        exc_info=None,
    )


class TestRequestContextManager:
    """RequestContext as an async context manager."""

    async def test_enter_sets_current(self) -> None:
        # Entering makes this context the current one for the running task.
        async with RequestContext(request_type="group", group_id=123) as ctx:
            assert RequestContext.current() is ctx

    async def test_exit_clears_current(self) -> None:
        async with RequestContext(request_type="group"):
            pass
        assert RequestContext.current() is None

    async def test_request_id_generated(self) -> None:
        async with RequestContext(request_type="private") as ctx:
            assert ctx.request_id is not None
            assert len(ctx.request_id) > 0

    async def test_request_id_is_uuid(self) -> None:
        import uuid

        async with RequestContext(request_type="private") as ctx:
            # uuid.UUID() raises ValueError for a malformed id.
            uuid.UUID(ctx.request_id)

    async def test_nested_contexts(self) -> None:
        # The inner context shadows the outer one and is restored on exit.
        async with RequestContext(request_type="group", group_id=1) as outer:
            assert RequestContext.current() is outer
            async with RequestContext(request_type="private", user_id=99) as inner:
                assert RequestContext.current() is inner
                assert inner.user_id == 99
            assert RequestContext.current() is outer
            assert outer.group_id == 1

    async def test_metadata(self) -> None:
        # Extra keyword arguments are collected into `metadata`.
        async with RequestContext(request_type="api", extra_key="value") as ctx:
            assert ctx.metadata["extra_key"] == "value"


class TestRequestContextResources:
    """Per-request resource storage on RequestContext."""

    async def test_set_and_get_resource(self) -> None:
        async with RequestContext(request_type="group") as ctx:
            ctx.set_resource("sender", {"name": "test"})
            assert ctx.get_resource("sender") == {"name": "test"}

    async def test_get_missing_resource_default(self) -> None:
        async with RequestContext(request_type="group") as ctx:
            assert ctx.get_resource("missing") is None
            assert ctx.get_resource("missing", "fallback") == "fallback"

    async def test_resources_cleared_on_exit(self) -> None:
        ctx = RequestContext(request_type="group")
        async with ctx:
            ctx.set_resource("key", "value")
        # Resources do not survive the context.
        assert ctx.get_resource("key") is None

    async def test_get_resources_returns_copy(self) -> None:
        async with RequestContext(request_type="group") as ctx:
            ctx.set_resource("a", 1)
            ctx.set_resource("b", 2)
            snapshot = ctx.get_resources()
            assert snapshot == {"a": 1, "b": 2}
            # Mutating the snapshot must not write through to the context.
            snapshot["c"] = 3
            assert ctx.get_resource("c") is None


class TestRequireContext:
    """RequestContext.require() — the mandatory-context accessor."""

    async def test_require_inside_context(self) -> None:
        async with RequestContext(request_type="group") as ctx:
            assert RequestContext.require() is ctx

    async def test_require_outside_context_raises(self) -> None:
        with pytest.raises(RuntimeError):
            RequestContext.require()


class TestHelperFunctions:
    """Module-level getters mirror the current context (None outside one)."""

    async def test_get_group_id_inside_context(self) -> None:
        async with RequestContext(request_type="group", group_id=42):
            assert get_group_id() == 42

    async def test_get_group_id_outside_context(self) -> None:
        assert get_group_id() is None

    async def test_get_user_id_inside_context(self) -> None:
        async with RequestContext(request_type="private", user_id=7):
            assert get_user_id() == 7

    async def test_get_user_id_outside_context(self) -> None:
        assert get_user_id() is None

    async def test_get_request_id_inside_context(self) -> None:
        async with RequestContext(request_type="group"):
            rid = get_request_id()
            assert rid is not None
            assert len(rid) > 0

    async def test_get_request_id_outside_context(self) -> None:
        assert get_request_id() is None

    async def test_get_sender_id_inside_context(self) -> None:
        async with RequestContext(request_type="group", sender_id=100):
            assert get_sender_id() == 100

    async def test_get_sender_id_outside_context(self) -> None:
        assert get_sender_id() is None

    async def test_get_request_type_inside_context(self) -> None:
        async with RequestContext(request_type="private"):
            assert get_request_type() == "private"

    async def test_get_request_type_outside_context(self) -> None:
        assert get_request_type() is None


class TestRequestContextFilter:
    """RequestContextFilter stamps context fields onto log records."""

    async def test_filter_with_context(self) -> None:
        log_filter = RequestContextFilter()
        record = _blank_record("test message")
        async with RequestContext(
            request_type="group", group_id=10, user_id=20, sender_id=30
        ) as ctx:
            assert log_filter.filter(record) is True
            # request_id is abbreviated to its first 8 characters.
            assert getattr(record, "request_id") == ctx.request_id[:8]
            assert getattr(record, "group_id") == 10
            assert getattr(record, "user_id") == 20
            assert getattr(record, "sender_id") == 30

    def test_filter_without_context(self) -> None:
        log_filter = RequestContextFilter()
        record = _blank_record("test")
        # Outside any context every field falls back to the "-" placeholder.
        assert log_filter.filter(record) is True
        assert getattr(record, "request_id") == "-"
        assert getattr(record, "group_id") == "-"
        assert getattr(record, "user_id") == "-"
        assert getattr(record, "sender_id") == "-"

    async def test_filter_partial_context(self) -> None:
        log_filter = RequestContextFilter()
        record = _blank_record("test")
        async with RequestContext(request_type="private", user_id=5):
            log_filter.filter(record)
        # Fields absent from the context keep the placeholder.
        assert getattr(record, "group_id") == "-"
        assert getattr(record, "user_id") == 5
"""Tests for Undefined.utils.cors — CORS origin helpers."""

from __future__ import annotations

from Undefined.utils.cors import is_allowed_cors_origin, normalize_origin


class TestNormalizeOrigin:
    """normalize_origin() lowercases and strips whitespace / trailing slashes."""

    def test_simple_origin(self) -> None:
        assert normalize_origin("http://example.com") == "http://example.com"

    def test_trailing_slash(self) -> None:
        assert normalize_origin("http://example.com/") == "http://example.com"

    def test_multiple_trailing_slashes(self) -> None:
        assert normalize_origin("http://example.com///") == "http://example.com"

    def test_case_insensitive(self) -> None:
        assert normalize_origin("HTTP://EXAMPLE.COM") == "http://example.com"

    def test_whitespace_stripped(self) -> None:
        assert normalize_origin(" http://example.com ") == "http://example.com"

    def test_empty_string(self) -> None:
        assert normalize_origin("") == ""

    def test_none_like_empty(self) -> None:
        # The function casts to str via `str(origin or "")`.
        assert normalize_origin("") == ""

    def test_with_port(self) -> None:
        assert normalize_origin("http://localhost:8080/") == "http://localhost:8080"


class TestIsAllowedCorsOrigin:
    """Loopback/tauri origins, the configured host, and extras are allowed."""

    def test_empty_origin_rejected(self) -> None:
        assert is_allowed_cors_origin("") is False

    def test_whitespace_only_rejected(self) -> None:
        assert is_allowed_cors_origin(" ") is False

    def test_localhost_http_allowed(self) -> None:
        assert is_allowed_cors_origin("http://localhost") is True

    def test_localhost_with_port_allowed(self) -> None:
        assert is_allowed_cors_origin("http://localhost:3000") is True

    def test_localhost_https_allowed(self) -> None:
        assert is_allowed_cors_origin("https://localhost") is True

    def test_ipv4_loopback_allowed(self) -> None:
        assert is_allowed_cors_origin("http://127.0.0.1") is True

    def test_ipv4_loopback_with_port_allowed(self) -> None:
        assert is_allowed_cors_origin("http://127.0.0.1:8080") is True

    def test_ipv6_loopback_allowed(self) -> None:
        assert is_allowed_cors_origin("http://[::1]") is True

    def test_ipv6_loopback_with_port_allowed(self) -> None:
        assert is_allowed_cors_origin("http://[::1]:8080") is True

    def test_tauri_localhost_allowed(self) -> None:
        assert is_allowed_cors_origin("tauri://localhost") is True

    def test_external_origin_rejected(self) -> None:
        assert is_allowed_cors_origin("http://evil.com") is False

    def test_configured_host_allowed(self) -> None:
        verdict = is_allowed_cors_origin(
            "http://myhost.local",
            configured_host="myhost.local",
        )
        assert verdict is True

    def test_configured_host_with_port(self) -> None:
        verdict = is_allowed_cors_origin(
            "https://myhost.local:9090",
            configured_host="myhost.local",
            configured_port=9090,
        )
        assert verdict is True

    def test_configured_host_wrong_port_rejected(self) -> None:
        verdict = is_allowed_cors_origin(
            "http://myhost.local:1234",
            configured_host="myhost.local",
            configured_port=9090,
        )
        assert verdict is False

    def test_extra_origins_allowed(self) -> None:
        verdict = is_allowed_cors_origin(
            "https://cdn.example.com",
            extra_origins={"https://cdn.example.com"},
        )
        assert verdict is True

    def test_extra_origins_case_insensitive(self) -> None:
        verdict = is_allowed_cors_origin(
            "HTTPS://CDN.EXAMPLE.COM",
            extra_origins={"https://cdn.example.com"},
        )
        assert verdict is True

    def test_extra_origins_not_matching_rejected(self) -> None:
        verdict = is_allowed_cors_origin(
            "https://other.com",
            extra_origins={"https://cdn.example.com"},
        )
        assert verdict is False

    def test_no_scheme_rejected(self) -> None:
        # "example.com" without scheme is not a valid loopback HTTP origin
        assert is_allowed_cors_origin("example.com") is False

    def test_ftp_scheme_rejected(self) -> None:
        assert is_allowed_cors_origin("ftp://localhost") is False

    def test_configured_host_empty(self) -> None:
        # Empty configured_host should not add anything
        assert is_allowed_cors_origin("http://evil.com", configured_host="") is False

    def test_extra_origins_none(self) -> None:
        assert is_allowed_cors_origin("http://localhost", extra_origins=None) is True
test_extra_origins_none(self) -> None: + assert is_allowed_cors_origin("http://localhost", extra_origins=None) is True diff --git a/tests/test_end_summary_storage.py b/tests/test_end_summary_storage.py new file mode 100644 index 0000000..aa3f4aa --- /dev/null +++ b/tests/test_end_summary_storage.py @@ -0,0 +1,124 @@ +"""EndSummaryStorage 单元测试""" + +from __future__ import annotations + +from typing import Any + + +from Undefined.end_summary_storage import ( + EndSummaryLocation, + EndSummaryStorage, +) + + +# --------------------------------------------------------------------------- +# make_record +# --------------------------------------------------------------------------- + + +class TestMakeRecord: + def test_basic(self) -> None: + record = EndSummaryStorage.make_record( + "summary text", "2025-01-01T00:00:00+08:00" + ) + assert record["summary"] == "summary text" + assert record["timestamp"] == "2025-01-01T00:00:00+08:00" + assert "location" not in record + + def test_strips_summary(self) -> None: + record = EndSummaryStorage.make_record(" spaces ", "ts") + assert record["summary"] == "spaces" + + def test_none_timestamp_auto_generates(self) -> None: + record = EndSummaryStorage.make_record("text", None) + assert record["timestamp"] # 非空 + assert "T" in record["timestamp"] # ISO 格式 + + def test_empty_timestamp_auto_generates(self) -> None: + record = EndSummaryStorage.make_record("text", " ") + assert record["timestamp"] + assert record["timestamp"].strip() != "" + + def test_with_location(self) -> None: + loc: EndSummaryLocation = {"type": "group", "name": "测试群"} + record = EndSummaryStorage.make_record("text", "ts", location=loc) + assert record.get("location") is not None + assert record["location"]["type"] == "group" + assert record["location"]["name"] == "测试群" + + def test_with_private_location(self) -> None: + loc: EndSummaryLocation = {"type": "private", "name": "好友"} + record = EndSummaryStorage.make_record("text", "ts", location=loc) + assert 
record["location"]["type"] == "private" + + def test_location_none_omitted(self) -> None: + record = EndSummaryStorage.make_record("text", "ts", location=None) + assert "location" not in record + + def test_invalid_location_type_ignored(self) -> None: + bad_loc: Any = {"type": "invalid", "name": "x"} + record = EndSummaryStorage.make_record("text", "ts", location=bad_loc) + assert "location" not in record + + def test_location_missing_name_ignored(self) -> None: + bad_loc: Any = {"type": "group"} + record = EndSummaryStorage.make_record("text", "ts", location=bad_loc) + assert "location" not in record + + def test_location_empty_name_ignored(self) -> None: + bad_loc: Any = {"type": "group", "name": " "} + record = EndSummaryStorage.make_record("text", "ts", location=bad_loc) + assert "location" not in record + + def test_location_non_string_name_ignored(self) -> None: + bad_loc: Any = {"type": "group", "name": 123} + record = EndSummaryStorage.make_record("text", "ts", location=bad_loc) + assert "location" not in record + + def test_location_not_dict_ignored(self) -> None: + bad: Any = "bad" + record = EndSummaryStorage.make_record("text", "ts", location=bad) + assert "location" not in record + + +# --------------------------------------------------------------------------- +# _normalize_records +# --------------------------------------------------------------------------- + + +class TestNormalizeRecords: + def _storage(self) -> EndSummaryStorage: + return EndSummaryStorage() + + def test_none_returns_empty(self) -> None: + assert self._storage()._normalize_records(None) == [] + + def test_non_list_returns_empty(self) -> None: + assert self._storage()._normalize_records("not a list") == [] + + def test_string_items_converted(self) -> None: + records = self._storage()._normalize_records(["hello", "world"]) + assert len(records) == 2 + assert records[0]["summary"] == "hello" + + def test_empty_string_items_skipped(self) -> None: + records = 
self._storage()._normalize_records(["", " ", "valid"]) + assert len(records) == 1 + assert records[0]["summary"] == "valid" + + def test_dict_items_normalized(self) -> None: + data: list[dict[str, Any]] = [ + {"summary": "text", "timestamp": "2025-01-01"}, + ] + records = self._storage()._normalize_records(data) + assert len(records) == 1 + assert records[0]["summary"] == "text" + + def test_dict_missing_summary_skipped(self) -> None: + records = self._storage()._normalize_records([{"timestamp": "t"}]) + assert len(records) == 0 + + def test_max_records_trimmed(self) -> None: + data = [f"summary-{i}" for i in range(250)] + records = self._storage()._normalize_records(data) + assert len(records) == 200 # MAX_END_SUMMARIES diff --git a/tests/test_fake_at.py b/tests/test_fake_at.py new file mode 100644 index 0000000..f56efaf --- /dev/null +++ b/tests/test_fake_at.py @@ -0,0 +1,179 @@ +"""Tests for Undefined.utils.fake_at.""" + +from __future__ import annotations + +from typing import Any +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from Undefined.utils.fake_at import ( + BotNicknameCache, + _normalize, + _sorted_nicknames, + strip_fake_at, +) + + +# --------------------------------------------------------------------------- +# _normalize +# --------------------------------------------------------------------------- + + +class TestNormalize: + def test_fullwidth_at_to_halfwidth(self) -> None: + assert "@" in _normalize("@") + + def test_casefold(self) -> None: + assert _normalize("ABC") == "abc" + + def test_nfkc_normalization(self) -> None: + # Fullwidth letters → ASCII + assert _normalize("A") == "a" + + def test_combined(self) -> None: + result = _normalize("@Hello") + assert result == "@hello" + + +# --------------------------------------------------------------------------- +# _sorted_nicknames +# --------------------------------------------------------------------------- + + +class TestSortedNicknames: + def test_sorted_by_length_desc(self) 
-> None: + names = frozenset({"ab", "abcd", "a"}) + result = _sorted_nicknames(names) + assert result == ("abcd", "ab", "a") + + def test_empty(self) -> None: + assert _sorted_nicknames(frozenset()) == () + + +# --------------------------------------------------------------------------- +# strip_fake_at +# --------------------------------------------------------------------------- + + +class TestStripFakeAt: + def test_empty_nicknames(self) -> None: + hit, text = strip_fake_at("@bot hello", frozenset()) + assert hit is False + assert text == "@bot hello" + + def test_empty_text(self) -> None: + hit, text = strip_fake_at("", frozenset({"bot"})) + assert hit is False + assert text == "" + + def test_no_at_prefix(self) -> None: + hit, text = strip_fake_at("hello bot", frozenset({"bot"})) + assert hit is False + assert text == "hello bot" + + def test_simple_match(self) -> None: + hit, text = strip_fake_at("@bot hello", frozenset({"bot"})) + assert hit is True + assert text == "hello" + + def test_match_with_fullwidth_at(self) -> None: + hit, text = strip_fake_at("@bot hello", frozenset({"bot"})) + assert hit is True + assert text == "hello" + + def test_case_insensitive(self) -> None: + hit, text = strip_fake_at("@BOT hello", frozenset({"bot"})) + assert hit is True + assert text == "hello" + + def test_longer_nickname_preferred(self) -> None: + nicks = frozenset({"bot", "bot助手"}) + hit, text = strip_fake_at("@bot助手 hello", nicks) + assert hit is True + assert text == "hello" + + def test_no_boundary_after_nickname(self) -> None: + hit, text = strip_fake_at("@botextrastuff", frozenset({"bot"})) + assert hit is False + assert text == "@botextrastuff" + + def test_boundary_punctuation(self) -> None: + hit, text = strip_fake_at("@bot,你好", frozenset({"bot"})) + assert hit is True + + def test_boundary_end_of_string(self) -> None: + hit, text = strip_fake_at("@bot", frozenset({"bot"})) + assert hit is True + assert text == "" + + def test_no_match_returns_original(self) -> 
None: + hit, text = strip_fake_at("@nobody hello", frozenset({"bot"})) + assert hit is False + assert text == "@nobody hello" + + def test_stripped_text_lstripped(self) -> None: + hit, text = strip_fake_at("@bot hello", frozenset({"bot"})) + assert hit is True + assert text == "hello" + + +# --------------------------------------------------------------------------- +# BotNicknameCache +# --------------------------------------------------------------------------- + + +class TestBotNicknameCache: + @pytest.fixture() + def mock_onebot(self) -> MagicMock: + ob = MagicMock() + ob.get_group_member_info = AsyncMock( + return_value={"card": "BotCard", "nickname": "BotNick"} + ) + return ob + + async def test_get_nicknames_fetches_and_caches( + self, mock_onebot: MagicMock + ) -> None: + cache = BotNicknameCache(mock_onebot, bot_qq=10000, ttl=60.0) + names = await cache.get_nicknames(12345) + assert "botcard" in names + assert "botnick" in names + mock_onebot.get_group_member_info.assert_awaited_once_with(12345, 10000) + + async def test_get_nicknames_uses_cache(self, mock_onebot: MagicMock) -> None: + cache = BotNicknameCache(mock_onebot, bot_qq=10000, ttl=600.0) + await cache.get_nicknames(12345) + await cache.get_nicknames(12345) + # Should only call API once thanks to caching + mock_onebot.get_group_member_info.assert_awaited_once() + + async def test_invalidate_specific_group(self, mock_onebot: MagicMock) -> None: + cache = BotNicknameCache(mock_onebot, bot_qq=10000, ttl=600.0) + await cache.get_nicknames(12345) + cache.invalidate(12345) + await cache.get_nicknames(12345) + assert mock_onebot.get_group_member_info.await_count == 2 + + async def test_invalidate_all(self, mock_onebot: MagicMock) -> None: + cache = BotNicknameCache(mock_onebot, bot_qq=10000, ttl=600.0) + await cache.get_nicknames(111) + await cache.get_nicknames(222) + cache.invalidate() + await cache.get_nicknames(111) + # 111 fetched twice, 222 fetched once = 3 + assert 
mock_onebot.get_group_member_info.await_count == 3 + + async def test_api_failure_returns_empty(self) -> None: + ob: Any = MagicMock() + ob.get_group_member_info = AsyncMock(side_effect=RuntimeError("API error")) + cache = BotNicknameCache(ob, bot_qq=10000, ttl=60.0) + names = await cache.get_nicknames(99999) + assert names == frozenset() + + async def test_empty_card_and_nickname(self) -> None: + ob: Any = MagicMock() + ob.get_group_member_info = AsyncMock(return_value={"card": "", "nickname": ""}) + cache = BotNicknameCache(ob, bot_qq=10000, ttl=60.0) + names = await cache.get_nicknames(123) + assert names == frozenset() diff --git a/tests/test_faq_unit.py b/tests/test_faq_unit.py new file mode 100644 index 0000000..ae6b268 --- /dev/null +++ b/tests/test_faq_unit.py @@ -0,0 +1,305 @@ +"""FAQ 存储管理 单元测试""" + +from __future__ import annotations + +from pathlib import Path +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from Undefined.faq import FAQ, FAQStorage, extract_faq_title + +_WRITE_JSON = "Undefined.utils.io.write_json" +_READ_JSON = "Undefined.utils.io.read_json" +_DELETE_FILE = "Undefined.utils.io.delete_file" + + +# --------------------------------------------------------------------------- +# FAQ dataclass +# --------------------------------------------------------------------------- + + +class TestFAQDataclass: + def _sample(self) -> FAQ: + return FAQ( + id="20250101-001", + group_id=12345, + target_qq=67890, + start_time="2025-01-01T00:00:00", + end_time="2025-01-02T00:00:00", + created_at="2025-01-01T00:00:00", + title="测试标题", + content="测试内容", + ) + + def test_to_dict(self) -> None: + faq = self._sample() + d = faq.to_dict() + assert d["id"] == "20250101-001" + assert d["group_id"] == 12345 + assert d["title"] == "测试标题" + + def test_from_dict(self) -> None: + faq = self._sample() + d = faq.to_dict() + restored = FAQ.from_dict(d) + assert restored == faq + + def test_roundtrip(self) -> None: + faq = self._sample() + assert 
FAQ.from_dict(faq.to_dict()) == faq + + +# --------------------------------------------------------------------------- +# extract_faq_title +# --------------------------------------------------------------------------- + + +class TestExtractFaqTitle: + def test_extract_from_question_colon(self) -> None: + content = "**问题**: 如何重启服务?\n回答是这样的" + assert extract_faq_title(content) == "如何重启服务?" + + def test_extract_from_question_chinese_colon(self) -> None: + content = "**问题**:如何重启服务?\n回答是这样的" + assert extract_faq_title(content) == "如何重启服务?" + + def test_extract_truncates_long_title(self) -> None: + long_question = "x" * 200 + content = f"**问题**: {long_question}" + result = extract_faq_title(content) + assert len(result) <= 100 + + def test_extract_from_bug_section(self) -> None: + content = "## Bug 问题描述\n登录页面崩溃\n更多细节" + assert extract_faq_title(content) == "登录页面崩溃" + + def test_extract_bug_section_truncates(self) -> None: + long_desc = "y" * 200 + content = f"## Bug 问题描述\n{long_desc}" + result = extract_faq_title(content) + assert len(result) <= 100 + + def test_extract_bug_section_skips_heading(self) -> None: + content = "## Bug 问题描述\n# 子标题\n实际描述" + assert extract_faq_title(content) == "实际描述" + + def test_extract_no_match_returns_default(self) -> None: + content = "一段普通文本" + assert extract_faq_title(content) == "未命名问题" + + def test_extract_empty_content(self) -> None: + assert extract_faq_title("") == "未命名问题" + + def test_question_priority_over_bug(self) -> None: + content = "**问题**: 优先问题\n## Bug 问题描述\nbug 内容" + assert extract_faq_title(content) == "优先问题" + + +# --------------------------------------------------------------------------- +# FAQStorage +# --------------------------------------------------------------------------- + + +class TestFAQStorage: + def _make_storage(self) -> FAQStorage: + with patch.object(Path, "mkdir"): + return FAQStorage(base_dir="data/faq") + + @pytest.mark.asyncio + async def test_create(self) -> None: + storage = self._make_storage() + 
with ( + patch.object(Path, "mkdir"), + patch.object(Path, "glob", return_value=[]), + patch(_WRITE_JSON, new_callable=AsyncMock), + ): + faq = await storage.create( + group_id=100, + target_qq=200, + start_time="2025-01-01", + end_time="2025-01-02", + title="标题", + content="内容", + ) + assert faq.group_id == 100 + assert faq.title == "标题" + assert faq.id # 有生成 ID + + @pytest.mark.asyncio + async def test_get_existing(self) -> None: + storage = self._make_storage() + sample = FAQ( + id="20250101-001", + group_id=100, + target_qq=200, + start_time="s", + end_time="e", + created_at="c", + title="t", + content="body", + ) + with ( + patch.object(Path, "mkdir"), + patch(_READ_JSON, new_callable=AsyncMock, return_value=sample.to_dict()), + ): + result = await storage.get(100, "20250101-001") + assert result is not None + assert result.title == "t" + + @pytest.mark.asyncio + async def test_get_nonexistent(self) -> None: + storage = self._make_storage() + with ( + patch.object(Path, "mkdir"), + patch(_READ_JSON, new_callable=AsyncMock, return_value=None), + ): + result = await storage.get(100, "nonexist") + assert result is None + + @pytest.mark.asyncio + async def test_list_all(self) -> None: + storage = self._make_storage() + faq1 = FAQ( + id="001", + group_id=1, + target_qq=2, + start_time="s", + end_time="e", + created_at="c", + title="t1", + content="c1", + ) + faq2 = FAQ( + id="002", + group_id=1, + target_qq=2, + start_time="s", + end_time="e", + created_at="c", + title="t2", + content="c2", + ) + mock_files = [Path("a.json"), Path("b.json")] + results_iter = iter([faq1.to_dict(), faq2.to_dict()]) + + with ( + patch.object(Path, "mkdir"), + patch.object(Path, "glob", return_value=mock_files), + patch( + _READ_JSON, + new_callable=AsyncMock, + side_effect=lambda *a, **kw: next(results_iter), + ), + ): + faqs = await storage.list_all(1) + + assert len(faqs) == 2 + + @pytest.mark.asyncio + async def test_search_matches(self) -> None: + storage = self._make_storage() + 
faq_match = FAQ( + id="001", + group_id=1, + target_qq=2, + start_time="s", + end_time="e", + created_at="c", + title="Python 教程", + content="内容", + ) + faq_no_match = FAQ( + id="002", + group_id=1, + target_qq=2, + start_time="s", + end_time="e", + created_at="c", + title="其他", + content="其他内容", + ) + mock_files = [Path("a.json"), Path("b.json")] + results_iter = iter([faq_match.to_dict(), faq_no_match.to_dict()]) + + with ( + patch.object(Path, "mkdir"), + patch.object(Path, "glob", return_value=mock_files), + patch( + _READ_JSON, + new_callable=AsyncMock, + side_effect=lambda *a, **kw: next(results_iter), + ), + ): + results = await storage.search(1, "python") + + assert len(results) == 1 + assert results[0].title == "Python 教程" + + @pytest.mark.asyncio + async def test_search_case_insensitive(self) -> None: + storage = self._make_storage() + faq = FAQ( + id="001", + group_id=1, + target_qq=2, + start_time="s", + end_time="e", + created_at="c", + title="UPPER", + content="body", + ) + mock_files = [MagicMock(spec=Path)] + + with ( + patch.object(Path, "mkdir"), + patch.object(Path, "glob", return_value=mock_files), + patch(_READ_JSON, new_callable=AsyncMock, return_value=faq.to_dict()), + ): + results = await storage.search(1, "upper") + + assert len(results) == 1 + + @pytest.mark.asyncio + async def test_search_in_content(self) -> None: + storage = self._make_storage() + faq = FAQ( + id="001", + group_id=1, + target_qq=2, + start_time="s", + end_time="e", + created_at="c", + title="无关标题", + content="详细的 Python 教程", + ) + mock_files = [MagicMock(spec=Path)] + + with ( + patch.object(Path, "mkdir"), + patch.object(Path, "glob", return_value=mock_files), + patch(_READ_JSON, new_callable=AsyncMock, return_value=faq.to_dict()), + ): + results = await storage.search(1, "python") + + assert len(results) == 1 + + @pytest.mark.asyncio + async def test_delete(self) -> None: + storage = self._make_storage() + with ( + patch.object(Path, "mkdir"), + patch(_DELETE_FILE, 
new_callable=AsyncMock, return_value=True), + ): + result = await storage.delete(100, "20250101-001") + assert result is True + + @pytest.mark.asyncio + async def test_delete_nonexistent(self) -> None: + storage = self._make_storage() + with ( + patch.object(Path, "mkdir"), + patch(_DELETE_FILE, new_callable=AsyncMock, return_value=False), + ): + result = await storage.delete(100, "nonexist") + assert result is False diff --git a/tests/test_group_metrics.py b/tests/test_group_metrics.py new file mode 100644 index 0000000..04f832d --- /dev/null +++ b/tests/test_group_metrics.py @@ -0,0 +1,190 @@ +"""Tests for Undefined.utils.group_metrics — group member metric helpers.""" + +from __future__ import annotations + +from datetime import datetime + +from Undefined.utils.group_metrics import ( + clamp_int, + datetime_to_ts, + format_timestamp, + member_display_name, + parse_member_level, + parse_unix_timestamp, + role_to_cn, +) + + +class TestClampInt: + def test_within_range(self) -> None: + assert clamp_int(5, 0, 1, 10) == 5 + + def test_below_min(self) -> None: + assert clamp_int(-5, 0, 1, 10) == 1 + + def test_above_max(self) -> None: + assert clamp_int(20, 0, 1, 10) == 10 + + def test_at_min(self) -> None: + assert clamp_int(1, 0, 1, 10) == 1 + + def test_at_max(self) -> None: + assert clamp_int(10, 0, 1, 10) == 10 + + def test_non_numeric_returns_default(self) -> None: + assert clamp_int("abc", 7, 1, 10) == 7 + + def test_none_returns_default(self) -> None: + assert clamp_int(None, 5, 1, 10) == 5 + + def test_string_int(self) -> None: + assert clamp_int("3", 0, 1, 10) == 3 + + def test_float_truncated(self) -> None: + assert clamp_int(3.9, 0, 1, 10) == 3 + + def test_bool_as_int(self) -> None: + assert clamp_int(True, 0, 0, 10) == 1 + + +class TestParseUnixTimestamp: + def test_valid_positive(self) -> None: + assert parse_unix_timestamp(1700000000) == 1700000000 + + def test_zero(self) -> None: + assert parse_unix_timestamp(0) == 0 + + def test_negative(self) -> 
class TestParseMemberLevel:
    """parse_member_level() extracts a non-negative int level, else None."""

    def test_integer(self) -> None:
        assert parse_member_level(5) == 5

    def test_zero(self) -> None:
        assert parse_member_level(0) == 0

    def test_negative_int(self) -> None:
        assert parse_member_level(-1) is None

    def test_none(self) -> None:
        assert parse_member_level(None) is None

    def test_bool_returns_none(self) -> None:
        # bools are ints, but levels reject them explicitly.
        assert parse_member_level(True) is None
        assert parse_member_level(False) is None

    def test_float(self) -> None:
        assert parse_member_level(3.7) == 3

    def test_digit_string(self) -> None:
        assert parse_member_level("10") == 10

    def test_string_with_digits(self) -> None:
        assert parse_member_level("Lv.5") == 5

    def test_string_no_digits(self) -> None:
        assert parse_member_level("无") is None

    def test_empty_string(self) -> None:
        assert parse_member_level("") is None

    def test_whitespace_string(self) -> None:
        assert parse_member_level(" ") is None

    def test_complex_string(self) -> None:
        assert parse_member_level("等级42勋章") == 42


class TestMemberDisplayName:
    """member_display_name(): card → nickname → user_id → "未知"."""

    def test_card_preferred(self) -> None:
        member = {"card": "CardName", "nickname": "Nick", "user_id": 123}
        assert member_display_name(member) == "CardName"

    def test_nickname_fallback(self) -> None:
        member = {"card": "", "nickname": "Nick", "user_id": 123}
        assert member_display_name(member) == "Nick"

    def test_user_id_fallback(self) -> None:
        member = {"card": "", "nickname": "", "user_id": 123}
        assert member_display_name(member) == "123"

    def test_none_card(self) -> None:
        member = {"card": None, "nickname": "Nick"}
        assert member_display_name(member) == "Nick"

    def test_all_missing(self) -> None:
        empty: dict[str, object] = {}
        assert member_display_name(empty) == "未知"

    def test_whitespace_card(self) -> None:
        member = {"card": " ", "nickname": "Nick"}
        assert member_display_name(member) == "Nick"


class TestRoleToCn:
    """role_to_cn() maps OneBot role strings to Chinese labels."""

    def test_owner(self) -> None:
        assert role_to_cn("owner") == "群主"

    def test_admin(self) -> None:
        assert role_to_cn("admin") == "管理员"

    def test_member(self) -> None:
        assert role_to_cn("member") == "成员"

    def test_none_defaults_to_member(self) -> None:
        assert role_to_cn(None) == "成员"

    def test_unknown_role_passthrough(self) -> None:
        assert role_to_cn("moderator") == "moderator"

    def test_empty_string_defaults_to_member(self) -> None:
        # str("" or "member") -> "member"
        assert role_to_cn("") == "成员"


class TestFormatTimestamp:
    """format_timestamp() renders a unix ts, or "无" for invalid values."""

    def test_valid_timestamp(self) -> None:
        ts = int(datetime(2024, 1, 15, 12, 0, 0).timestamp())
        assert "2024-01-15" in format_timestamp(ts)

    def test_zero(self) -> None:
        assert format_timestamp(0) == "无"

    def test_negative(self) -> None:
        assert format_timestamp(-1) == "无"

    def test_overflow(self) -> None:
        assert format_timestamp(999999999999999) == "无"


class TestDatetimeToTs:
    """datetime_to_ts() converts datetime → unix seconds (None passthrough)."""

    def test_none(self) -> None:
        assert datetime_to_ts(None) is None

    def test_valid_datetime(self) -> None:
        dt = datetime(2024, 6, 15, 12, 0, 0)
        ts = datetime_to_ts(dt)
        assert ts is not None
        assert isinstance(ts, int)
        # Round-trip check
        assert datetime.fromtimestamp(ts).replace(second=0) == dt.replace(second=0)

    def test_epoch(self) -> None:
        ts = datetime_to_ts(datetime(1970, 1, 1, 0, 0, 0))
        assert ts is not None
Undefined.utils.member_utils — member analysis helpers.""" + +from __future__ import annotations + +from datetime import datetime +from typing import Any + +from Undefined.utils.member_utils import ( + analyze_join_trend, + analyze_member_activity, + filter_by_join_time, +) + + +def _make_member( + user_id: int, + join_time: int | float | None = None, + card: str = "", + nickname: str = "", +) -> dict[str, Any]: + m: dict[str, Any] = {"user_id": user_id, "card": card, "nickname": nickname} + if join_time is not None: + m["join_time"] = join_time + return m + + +# A fixed reference timestamp: 2024-06-15 12:00:00 +_REF_TS = int(datetime(2024, 6, 15, 12, 0, 0).timestamp()) +_DAY = 86400 + + +class TestFilterByJoinTime: + def test_empty_list(self) -> None: + assert filter_by_join_time([], None, None) == [] + + def test_no_filters(self) -> None: + members = [_make_member(1, _REF_TS), _make_member(2, _REF_TS + _DAY)] + result = filter_by_join_time(members, None, None) + assert len(result) == 2 + + def test_start_filter(self) -> None: + members = [ + _make_member(1, _REF_TS - _DAY), + _make_member(2, _REF_TS + _DAY), + ] + start_dt = datetime(2024, 6, 15, 0, 0, 0) + result = filter_by_join_time(members, start_dt, None) + assert len(result) == 1 + assert result[0]["user_id"] == 2 + + def test_end_filter(self) -> None: + members = [ + _make_member(1, _REF_TS - _DAY), + _make_member(2, _REF_TS + _DAY), + ] + end_dt = datetime(2024, 6, 15, 0, 0, 0) + result = filter_by_join_time(members, None, end_dt) + assert len(result) == 1 + assert result[0]["user_id"] == 1 + + def test_both_filters(self) -> None: + members = [ + _make_member(1, _REF_TS - 2 * _DAY), + _make_member(2, _REF_TS), + _make_member(3, _REF_TS + 2 * _DAY), + ] + start_dt = datetime.fromtimestamp(_REF_TS - _DAY) + end_dt = datetime.fromtimestamp(_REF_TS + _DAY) + result = filter_by_join_time(members, start_dt, end_dt) + assert len(result) == 1 + assert result[0]["user_id"] == 2 + + def 
test_member_without_join_time_skipped(self) -> None: + members = [_make_member(1), _make_member(2, _REF_TS)] + result = filter_by_join_time(members, None, None) + assert len(result) == 1 + assert result[0]["user_id"] == 2 + + def test_non_numeric_join_time_skipped(self) -> None: + members: list[dict[str, Any]] = [{"user_id": 1, "join_time": "not-a-number"}] + result = filter_by_join_time(members, None, None) + assert len(result) == 0 + + def test_float_join_time(self) -> None: + members = [_make_member(1, float(_REF_TS) + 0.5)] + result = filter_by_join_time(members, None, None) + assert len(result) == 1 + + +class TestAnalyzeJoinTrend: + def test_empty_list(self) -> None: + assert analyze_join_trend([]) == {} + + def test_single_member(self) -> None: + members = [_make_member(1, _REF_TS)] + result = analyze_join_trend(members) + assert result["peak_count"] == 1 + assert result["avg_per_day"] == 1.0 + assert result["first_time"] is not None + assert result["last_time"] is not None + assert result["first_time"] == result["last_time"] + + def test_multiple_members_same_day(self) -> None: + members = [ + _make_member(1, _REF_TS), + _make_member(2, _REF_TS + 3600), + ] + result = analyze_join_trend(members) + assert result["peak_count"] == 2 + assert result["avg_per_day"] == 2.0 + + def test_multiple_days(self) -> None: + members = [ + _make_member(1, _REF_TS), + _make_member(2, _REF_TS + _DAY), + _make_member(3, _REF_TS + _DAY), + ] + result = analyze_join_trend(members) + assert len(result["daily_stats"]) == 2 + assert result["peak_count"] == 2 + assert result["avg_per_day"] == 1.5 + + def test_members_without_join_time_ignored(self) -> None: + members = [_make_member(1), _make_member(2, _REF_TS)] + result = analyze_join_trend(members) + # Only one member has join_time, but total uses all members + assert result["avg_per_day"] == 2.0 # 2 members / 1 day + assert result["peak_count"] == 1 + + def test_daily_stats_populated(self) -> None: + members = [_make_member(1, 
_REF_TS)] + result = analyze_join_trend(members) + assert isinstance(result["daily_stats"], dict) + assert len(result["daily_stats"]) == 1 + + +class TestAnalyzeMemberActivity: + def test_empty_members(self) -> None: + result = analyze_member_activity([], {}, 5) + assert result["total_members"] == 0 + assert result["active_members"] == 0 + assert result["total_messages"] == 0 + assert result["top_members"] == [] + + def test_basic_activity(self) -> None: + members = [ + _make_member(1, _REF_TS, nickname="Alice"), + _make_member(2, _REF_TS, nickname="Bob"), + _make_member(3, _REF_TS, nickname="Charlie"), + ] + counts: dict[int, int] = {1: 100, 2: 50, 3: 0} + result = analyze_member_activity(members, counts, 5) + assert result["total_members"] == 3 + assert result["active_members"] == 2 + assert result["inactive_members"] == 1 + assert result["total_messages"] == 150 + assert result["avg_messages"] == 50.0 + assert len(result["top_members"]) == 2 + assert result["top_members"][0]["user_id"] == 1 + + def test_top_count_limit(self) -> None: + members = [_make_member(i, _REF_TS) for i in range(1, 11)] + counts: dict[int, int] = {i: i * 10 for i in range(1, 11)} + result = analyze_member_activity(members, counts, 3) + assert len(result["top_members"]) == 3 + assert result["top_members"][0]["user_id"] == 10 + + def test_active_rate_calculation(self) -> None: + members = [_make_member(1), _make_member(2)] + counts: dict[int, int] = {1: 10, 2: 0} + result = analyze_member_activity(members, counts, 5) + assert result["active_rate"] == 50.0 + + def test_zero_count_excluded_from_top(self) -> None: + members = [_make_member(1, _REF_TS, nickname="A")] + counts: dict[int, int] = {1: 0} + result = analyze_member_activity(members, counts, 5) + assert result["top_members"] == [] + + def test_member_with_card_name(self) -> None: + members = [_make_member(1, _REF_TS, card="CardName", nickname="Nick")] + counts: dict[int, int] = {1: 10} + result = analyze_member_activity(members, 
counts, 5) + assert result["top_members"][0]["nickname"] == "CardName" + + def test_join_time_formatted_in_top(self) -> None: + members = [_make_member(1, _REF_TS, nickname="A")] + counts: dict[int, int] = {1: 5} + result = analyze_member_activity(members, counts, 5) + assert result["top_members"][0]["join_time"] != "" + + def test_no_join_time_empty_string(self) -> None: + members = [_make_member(1, nickname="A")] + counts: dict[int, int] = {1: 5} + result = analyze_member_activity(members, counts, 5) + assert result["top_members"][0]["join_time"] == "" diff --git a/tests/test_memory_unit.py b/tests/test_memory_unit.py new file mode 100644 index 0000000..75917c8 --- /dev/null +++ b/tests/test_memory_unit.py @@ -0,0 +1,229 @@ +"""MemoryStorage 单元测试""" + +from __future__ import annotations + +from dataclasses import asdict +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from Undefined.memory import Memory, MemoryStorage + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _make_storage( + initial_data: list[dict[str, str]] | None = None, + max_memories: int = 500, +) -> MemoryStorage: + """构造 MemoryStorage 并跳过真实文件 I/O。""" + with patch("Undefined.memory.MEMORY_FILE_PATH") as mock_path: + if initial_data is not None: + import io as _io + import json + + mock_path.exists.return_value = True + mock_file = _io.StringIO(json.dumps(initial_data)) + mock_open = MagicMock(return_value=mock_file) + with patch("builtins.open", mock_open): + storage = MemoryStorage(max_memories=max_memories) + else: + mock_path.exists.return_value = False + storage = MemoryStorage(max_memories=max_memories) + return storage + + +_WRITE_JSON = "Undefined.utils.io.write_json" + + +# --------------------------------------------------------------------------- +# Memory dataclass +# 
class TestMemoryDataclass:
    """Memory: a plain record of (uuid, fact, created_at)."""

    def test_fields(self) -> None:
        m = Memory(uuid="u1", fact="hello", created_at="2025-01-01")
        assert m.uuid == "u1"
        assert m.fact == "hello"
        assert m.created_at == "2025-01-01"

    def test_asdict(self) -> None:
        m = Memory(uuid="u1", fact="hello", created_at="2025-01-01")
        d = asdict(m)
        assert d == {"uuid": "u1", "fact": "hello", "created_at": "2025-01-01"}


# ---------------------------------------------------------------------------
# MemoryStorage
# ---------------------------------------------------------------------------


class TestMemoryStorageInit:
    def test_empty_init(self) -> None:
        storage = _make_storage()
        assert storage.count() == 0
        assert storage.get_all() == []

    def test_init_with_data(self) -> None:
        data = [
            {"uuid": "u1", "fact": "fact1", "created_at": "2025-01-01"},
            {"uuid": "u2", "fact": "fact2", "created_at": "2025-01-02"},
        ]
        storage = _make_storage(initial_data=data)
        assert storage.count() == 2

    def test_init_with_legacy_data_without_uuid(self) -> None:
        """Legacy records have no uuid; one should be generated automatically."""
        data: list[dict[str, str]] = [
            {"fact": "old fact", "created_at": "2024-01-01"},
        ]
        storage = _make_storage(initial_data=data)
        assert storage.count() == 1
        memories = storage.get_all()
        assert memories[0].fact == "old fact"
        assert memories[0].uuid  # a UUID was auto-generated


class TestMemoryStorageAdd:
    @pytest.mark.asyncio
    async def test_add_returns_uuid(self) -> None:
        storage = _make_storage()
        with patch(_WRITE_JSON, new_callable=AsyncMock):
            result = await storage.add("new fact")
        assert result is not None
        assert storage.count() == 1

    @pytest.mark.asyncio
    async def test_add_strips_whitespace(self) -> None:
        storage = _make_storage()
        with patch(_WRITE_JSON, new_callable=AsyncMock):
            await storage.add("  spaced fact  ")
        assert storage.get_all()[0].fact == "spaced fact"

    @pytest.mark.asyncio
    async def test_add_empty_returns_none(self) -> None:
        storage = _make_storage()
        result = await storage.add("")
        assert result is None
        assert storage.count() == 0

    @pytest.mark.asyncio
    async def test_add_whitespace_only_returns_none(self) -> None:
        storage = _make_storage()
        result = await storage.add("   ")
        assert result is None
        assert storage.count() == 0

    @pytest.mark.asyncio
    async def test_add_duplicate_returns_existing_uuid(self) -> None:
        storage = _make_storage()
        with patch(_WRITE_JSON, new_callable=AsyncMock):
            uuid1 = await storage.add("duplicate fact")
            uuid2 = await storage.add("duplicate fact")
        assert uuid1 == uuid2
        assert storage.count() == 1

    @pytest.mark.asyncio
    async def test_add_max_memories_evicts_oldest(self) -> None:
        storage = _make_storage(max_memories=3)
        with patch(_WRITE_JSON, new_callable=AsyncMock):
            await storage.add("fact1")
            await storage.add("fact2")
            await storage.add("fact3")
            assert storage.count() == 3
            await storage.add("fact4")
        # Capacity stays at 3; the oldest entry is dropped.
        assert storage.count() == 3
        facts = [m.fact for m in storage.get_all()]
        assert "fact1" not in facts
        assert "fact4" in facts

    @pytest.mark.asyncio
    async def test_add_calls_save(self) -> None:
        storage = _make_storage()
        with patch(_WRITE_JSON, new_callable=AsyncMock) as mock_write:
            await storage.add("fact")
        mock_write.assert_awaited_once()


class TestMemoryStorageUpdate:
    @pytest.mark.asyncio
    async def test_update_existing(self) -> None:
        storage = _make_storage()
        with patch(_WRITE_JSON, new_callable=AsyncMock):
            uid = await storage.add("old")
            assert uid is not None
            result = await storage.update(uid, "new")
        assert result is True
        assert storage.get_all()[0].fact == "new"

    @pytest.mark.asyncio
    async def test_update_nonexistent_returns_false(self) -> None:
        storage = _make_storage()
        with patch(_WRITE_JSON, new_callable=AsyncMock):
            result = await storage.update("nonexistent-uuid", "new")
        assert result is False

    @pytest.mark.asyncio
    async def test_update_strips_whitespace(self) -> None:
        storage = _make_storage()
        with patch(_WRITE_JSON, new_callable=AsyncMock):
            uid = await storage.add("old")
            assert uid is not None
            await storage.update(uid, "  updated  ")
        assert storage.get_all()[0].fact == "updated"


class TestMemoryStorageDelete:
    @pytest.mark.asyncio
    async def test_delete_existing(self) -> None:
        storage = _make_storage()
        with patch(_WRITE_JSON, new_callable=AsyncMock):
            uid = await storage.add("to delete")
            assert uid is not None
            result = await storage.delete(uid)
        assert result is True
        assert storage.count() == 0

    @pytest.mark.asyncio
    async def test_delete_nonexistent_returns_false(self) -> None:
        storage = _make_storage()
        with patch(_WRITE_JSON, new_callable=AsyncMock):
            result = await storage.delete("nonexistent-uuid")
        assert result is False


class TestMemoryStorageGetAll:
    def test_get_all_returns_copy(self) -> None:
        data = [{"uuid": "u1", "fact": "fact1", "created_at": "2025-01-01"}]
        storage = _make_storage(initial_data=data)
        list1 = storage.get_all()
        list2 = storage.get_all()
        # Each call returns a fresh list (equal but not identical).
        assert list1 is not list2
        assert list1 == list2


class TestMemoryStorageClear:
    @pytest.mark.asyncio
    async def test_clear(self) -> None:
        data = [{"uuid": "u1", "fact": "fact1", "created_at": "2025-01-01"}]
        storage = _make_storage(initial_data=data)
        with patch(_WRITE_JSON, new_callable=AsyncMock):
            await storage.clear()
        assert storage.count() == 0
        assert storage.get_all() == []


class TestMemoryStorageCount:
    @pytest.mark.asyncio
    async def test_count_tracks_additions(self) -> None:
        storage = _make_storage()
        assert storage.count() == 0
        with patch(_WRITE_JSON, new_callable=AsyncMock):
            await storage.add("a")
            assert storage.count() == 1
            await storage.add("b")
            assert storage.count() == 2


# --- next diff entry: tests/test_message_targets.py (new file) ---
index 0000000..69ee58f --- /dev/null +++ b/tests/test_message_targets.py @@ -0,0 +1,185 @@ +"""Tests for Undefined.utils.message_targets — target resolution helpers.""" + +from __future__ import annotations + +from typing import Any + +from Undefined.utils.message_targets import parse_positive_int, resolve_message_target + + +class TestParsePositiveInt: + def test_valid_int(self) -> None: + val, err = parse_positive_int(42, "field") + assert val == 42 + assert err is None + + def test_valid_string_int(self) -> None: + val, err = parse_positive_int("123", "field") + assert val == 123 + assert err is None + + def test_none_input(self) -> None: + val, err = parse_positive_int(None, "field") + assert val is None + assert err is None + + def test_zero_rejected(self) -> None: + val, err = parse_positive_int(0, "field") + assert val is None + assert err is not None + assert "正整数" in (err or "") + + def test_negative_rejected(self) -> None: + val, err = parse_positive_int(-5, "field") + assert val is None + assert err is not None + + def test_non_numeric_string(self) -> None: + val, err = parse_positive_int("abc", "field") + assert val is None + assert err is not None + assert "整数" in (err or "") + + def test_float_truncated(self) -> None: + val, err = parse_positive_int(3.9, "field") + assert val == 3 + assert err is None + + def test_float_string_rejected(self) -> None: + val, err = parse_positive_int("3.5", "field") + assert val is None + assert err is not None + + def test_bool_treated_as_int(self) -> None: + # bool is subclass of int; True -> 1 + val, err = parse_positive_int(True, "field") + assert val == 1 + assert err is None + + def test_field_name_in_error(self) -> None: + _, err = parse_positive_int("bad", "target_id") + assert err is not None + assert "target_id" in err + + +class TestResolveMessageTarget: + @staticmethod + def _call( + args: dict[str, Any] | None = None, + context: dict[str, Any] | None = None, + ) -> tuple[tuple[str, int] | None, str | None]: 
+ result: tuple[tuple[str, int] | None, str | None] = resolve_message_target( + args or {}, context or {} + ) + return result + + def test_explicit_group_target(self) -> None: + target, err = self._call( + args={"target_type": "group", "target_id": 12345}, + ) + assert target == ("group", 12345) + assert err is None + + def test_explicit_private_target(self) -> None: + target, err = self._call( + args={"target_type": "private", "target_id": 67890}, + ) + assert target == ("private", 67890) + assert err is None + + def test_target_type_case_insensitive(self) -> None: + target, err = self._call( + args={"target_type": "GROUP", "target_id": 1}, + ) + assert target == ("group", 1) + + def test_target_type_without_id_infers_from_context(self) -> None: + target, err = self._call( + args={"target_type": "group"}, + context={"request_type": "group", "group_id": 100}, + ) + assert target == ("group", 100) + assert err is None + + def test_target_type_without_id_mismatch_context(self) -> None: + target, err = self._call( + args={"target_type": "group"}, + context={"request_type": "private", "user_id": 100}, + ) + assert target is None + assert err is not None + assert "不一致" in (err or "") + + def test_target_id_without_type_error(self) -> None: + target, err = self._call(args={"target_id": 123}) + assert target is None + assert err is not None + assert "同时提供" in (err or "") + + def test_invalid_target_type(self) -> None: + target, err = self._call( + args={"target_type": "channel", "target_id": 1}, + ) + assert target is None + assert err is not None + + def test_target_type_non_string(self) -> None: + target, err = self._call( + args={"target_type": 123, "target_id": 1}, + ) + assert target is None + assert err is not None + assert "字符串" in (err or "") + + def test_legacy_group_id(self) -> None: + target, err = self._call(args={"group_id": 999}) + assert target == ("group", 999) + assert err is None + + def test_legacy_user_id(self) -> None: + target, err = 
self._call(args={"user_id": 888}) + assert target == ("private", 888) + assert err is None + + def test_legacy_invalid_group_id(self) -> None: + target, err = self._call(args={"group_id": -1}) + assert target is None + assert err is not None + + def test_fallback_to_context_group(self) -> None: + target, err = self._call( + context={"request_type": "group", "group_id": 555}, + ) + assert target == ("group", 555) + assert err is None + + def test_fallback_to_context_private(self) -> None: + target, err = self._call( + context={"request_type": "private", "user_id": 444}, + ) + assert target == ("private", 444) + assert err is None + + def test_fallback_context_group_id_only(self) -> None: + target, err = self._call(context={"group_id": 333}) + assert target == ("group", 333) + assert err is None + + def test_fallback_context_user_id_only(self) -> None: + target, err = self._call(context={"user_id": 222}) + assert target == ("private", 222) + assert err is None + + def test_no_target_info_at_all(self) -> None: + target, err = self._call() + assert target is None + assert err is not None + assert "无法确定" in (err or "") + + def test_target_type_private_infer_from_context(self) -> None: + target, err = self._call( + args={"target_type": "private"}, + context={"request_type": "private", "user_id": 77}, + ) + assert target == ("private", 77) + assert err is None diff --git a/tests/test_message_utils.py b/tests/test_message_utils.py new file mode 100644 index 0000000..3341f26 --- /dev/null +++ b/tests/test_message_utils.py @@ -0,0 +1,232 @@ +"""Tests for Undefined.utils.message_utils.""" + +from __future__ import annotations + +from datetime import datetime +from typing import Any, Dict + +import pytest + +from Undefined.utils.message_utils import ( + analyze_activity_pattern, + count_message_types, + count_messages_by_user, + filter_user_messages, + format_messages, +) + + +@pytest.fixture(autouse=True) +def _mock_parse_message_time(monkeypatch: pytest.MonkeyPatch) -> None: 
+ """Patch parse_message_time so tests don't depend on onebot imports.""" + + def _fake_parse(msg: Dict[str, Any]) -> datetime: + return datetime.fromtimestamp(msg.get("time", 0)) + + monkeypatch.setattr( + "Undefined.utils.message_utils.parse_message_time", + _fake_parse, + ) + + +def _msg( + user_id: int = 100, + ts: int = 1700000000, + message: Any = None, + nickname: str = "TestUser", +) -> Dict[str, Any]: + """Helper to build a minimal message dict.""" + return { + "sender": {"user_id": user_id, "nickname": nickname}, + "time": ts, + "message": message if message is not None else "hello", + } + + +# --------------------------------------------------------------------------- +# filter_user_messages +# --------------------------------------------------------------------------- + + +class TestFilterUserMessages: + def test_filters_by_user_id(self) -> None: + msgs = [_msg(user_id=1), _msg(user_id=2), _msg(user_id=1)] + result = filter_user_messages(msgs, user_id=1, start_dt=None, end_dt=None) + assert len(result) == 2 + + def test_filters_by_time_range(self) -> None: + msgs = [ + _msg(ts=1700000000), + _msg(ts=1700000100), + _msg(ts=1700000200), + ] + start = datetime.fromtimestamp(1700000050) + end = datetime.fromtimestamp(1700000150) + result = filter_user_messages(msgs, user_id=100, start_dt=start, end_dt=end) + assert len(result) == 1 + + def test_empty_messages(self) -> None: + result = filter_user_messages([], user_id=1, start_dt=None, end_dt=None) + assert result == [] + + def test_no_time_bounds(self) -> None: + msgs = [_msg(user_id=100, ts=1700000000)] + result = filter_user_messages(msgs, user_id=100, start_dt=None, end_dt=None) + assert len(result) == 1 + + +# --------------------------------------------------------------------------- +# count_message_types +# --------------------------------------------------------------------------- + + +class TestCountMessageTypes: + def test_string_message_is_text(self) -> None: + msgs = [_msg(message="hi")] + 
result = count_message_types(msgs) + assert result == {"文本消息": 1} + + def test_image_segment(self) -> None: + msgs = [_msg(message=[{"type": "image", "data": {}}])] + result = count_message_types(msgs) + assert result == {"图片消息": 1} + + def test_reply_priority_over_text(self) -> None: + msgs = [ + _msg( + message=[ + {"type": "reply", "data": {}}, + {"type": "text", "data": {"text": "hi"}}, + ] + ) + ] + result = count_message_types(msgs) + assert result == {"回复消息": 1} + + def test_face_segment(self) -> None: + msgs = [_msg(message=[{"type": "face", "data": {}}])] + result = count_message_types(msgs) + assert result == {"表情消息": 1} + + def test_empty_segment_list(self) -> None: + msgs: list[Dict[str, Any]] = [_msg(message=[])] + result = count_message_types(msgs) + assert result == {"空消息": 1} + + def test_other_segment_type(self) -> None: + msgs = [_msg(message=[{"type": "forward", "data": {}}])] + result = count_message_types(msgs) + assert result == {"其他消息": 1} + + def test_text_only_segments(self) -> None: + msgs = [_msg(message=[{"type": "text", "data": {"text": "hello"}}])] + result = count_message_types(msgs) + assert result == {"文本消息": 1} + + def test_mixed_messages(self) -> None: + msgs = [ + _msg(message="hi"), + _msg(message=[{"type": "image", "data": {}}]), + _msg(message=[{"type": "face", "data": {}}]), + ] + result = count_message_types(msgs) + assert result == {"文本消息": 1, "图片消息": 1, "表情消息": 1} + + +# --------------------------------------------------------------------------- +# analyze_activity_pattern +# --------------------------------------------------------------------------- + + +class TestAnalyzeActivityPattern: + def test_empty_returns_empty_dict(self) -> None: + assert analyze_activity_pattern([]) == {} + + def test_single_message(self) -> None: + ts = 1700000000 + msgs = [_msg(ts=ts)] + result = analyze_activity_pattern(msgs) + assert result["avg_per_day"] == 1.0 + assert result["first_time"] is not None + assert result["last_time"] is not 
None + assert result["first_time"] == result["last_time"] + + def test_multiple_messages_avg_per_day(self) -> None: + # Two messages on the same day + msgs = [_msg(ts=1700000000), _msg(ts=1700000100)] + result = analyze_activity_pattern(msgs) + assert result["avg_per_day"] == 2.0 + + def test_most_active_hour_format(self) -> None: + msgs = [_msg(ts=1700000000)] + result = analyze_activity_pattern(msgs) + hour_str: str = result["most_active_hour"] + assert ":00-" in hour_str + assert ":59" in hour_str + + def test_weekday_is_chinese(self) -> None: + msgs = [_msg(ts=1700000000)] + result = analyze_activity_pattern(msgs) + weekday_str: str = result["most_active_weekday"] + assert weekday_str.startswith("周") + + +# --------------------------------------------------------------------------- +# count_messages_by_user +# --------------------------------------------------------------------------- + + +class TestCountMessagesByUser: + def test_counts_correctly(self) -> None: + msgs = [_msg(user_id=1), _msg(user_id=2), _msg(user_id=1)] + result = count_messages_by_user(msgs, {1, 2, 3}) + assert result == {1: 2, 2: 1, 3: 0} + + def test_unknown_user_ignored(self) -> None: + msgs = [_msg(user_id=99)] + result = count_messages_by_user(msgs, {1}) + assert result == {1: 0} + + def test_empty_messages(self) -> None: + result = count_messages_by_user([], {1, 2}) + assert result == {1: 0, 2: 0} + + +# --------------------------------------------------------------------------- +# format_messages +# --------------------------------------------------------------------------- + + +class TestFormatMessages: + def test_basic_format(self) -> None: + msgs = [_msg(user_id=42, ts=1700000000, nickname="Alice")] + result = format_messages(msgs) + assert len(result) == 1 + assert result[0]["sender"] == "Alice" + assert result[0]["sender_id"] == 42 + assert "2023" in result[0]["time"] + assert result[0]["content"] == "hello" + + def test_segment_format(self) -> None: + msg = _msg( + message=[ + 
{"type": "text", "data": {"text": "hi "}}, + {"type": "image", "data": {}}, + ] + ) + result = format_messages([msg]) + assert result[0]["content"] == "hi [图片]" + + def test_empty_content_placeholder(self) -> None: + msg = _msg(message=[]) + result = format_messages([msg]) + assert result[0]["content"] == "(空消息)" + + def test_card_preferred_over_nickname(self) -> None: + msg: Dict[str, Any] = { + "sender": {"user_id": 1, "card": "CardName", "nickname": "Nick"}, + "time": 1700000000, + "message": "hi", + } + result = format_messages([msg]) + assert result[0]["sender"] == "CardName" diff --git a/tests/test_qq_emoji.py b/tests/test_qq_emoji.py new file mode 100644 index 0000000..7f94651 --- /dev/null +++ b/tests/test_qq_emoji.py @@ -0,0 +1,142 @@ +"""QQ emoji 工具 单元测试""" + +from __future__ import annotations + + +from Undefined.utils.qq_emoji import ( + get_emoji_alias_map, + get_emoji_id_entries, + resolve_emoji_id_by_alias, + search_emoji_aliases, +) + + +# --------------------------------------------------------------------------- +# resolve_emoji_id_by_alias +# --------------------------------------------------------------------------- + + +class TestResolveEmojiIdByAlias: + def test_known_chinese_alias(self) -> None: + assert resolve_emoji_id_by_alias("微笑") == 14 + + def test_known_english_alias(self) -> None: + assert resolve_emoji_id_by_alias("smile") == 14 + + def test_known_unicode_emoji(self) -> None: + assert resolve_emoji_id_by_alias("👍") == 76 + + def test_case_insensitive(self) -> None: + assert resolve_emoji_id_by_alias("SMILE") == 14 + assert resolve_emoji_id_by_alias("Smile") == 14 + + def test_whitespace_stripped(self) -> None: + assert resolve_emoji_id_by_alias(" smile ") == 14 + + def test_unknown_alias(self) -> None: + assert resolve_emoji_id_by_alias("completely_unknown_emoji_xyz") is None + + def test_empty_string(self) -> None: + assert resolve_emoji_id_by_alias("") is None + + def test_whitespace_only(self) -> None: + assert 
resolve_emoji_id_by_alias(" ") is None + + +# --------------------------------------------------------------------------- +# search_emoji_aliases +# --------------------------------------------------------------------------- + + +class TestSearchEmojiAliases: + def test_search_finds_matching(self) -> None: + results = search_emoji_aliases("笑") + assert len(results) > 0 + for alias, _eid in results: + assert "笑" in alias + + def test_search_limit(self) -> None: + results = search_emoji_aliases("笑", limit=2) + assert len(results) <= 2 + + def test_search_no_match(self) -> None: + results = search_emoji_aliases("zzz_no_match_xyz") + assert results == [] + + def test_search_empty_keyword(self) -> None: + results = search_emoji_aliases("") + assert results == [] + + def test_search_returns_tuples(self) -> None: + results = search_emoji_aliases("赞") + assert len(results) > 0 + for item in results: + assert isinstance(item, tuple) + assert isinstance(item[0], str) + assert isinstance(item[1], int) + + def test_search_case_insensitive(self) -> None: + r1 = search_emoji_aliases("ok") + r2 = search_emoji_aliases("OK") + assert r1 == r2 + + def test_search_sorted_by_id_then_alias(self) -> None: + results = search_emoji_aliases("笑") + if len(results) >= 2: + for i in range(len(results) - 1): + assert (results[i][1], results[i][0]) <= ( + results[i + 1][1], + results[i + 1][0], + ) + + +# --------------------------------------------------------------------------- +# get_emoji_id_entries +# --------------------------------------------------------------------------- + + +class TestGetEmojiIdEntries: + def test_returns_list(self) -> None: + entries = get_emoji_id_entries() + assert isinstance(entries, list) + assert len(entries) > 0 + + def test_entries_structure(self) -> None: + entries = get_emoji_id_entries() + for emoji_id, aliases in entries: + assert isinstance(emoji_id, int) + assert isinstance(aliases, list) + assert all(isinstance(a, str) for a in aliases) + + def 
test_entries_sorted_by_id(self) -> None: + entries = get_emoji_id_entries() + ids = [eid for eid, _ in entries] + assert ids == sorted(ids) + + def test_aliases_sorted(self) -> None: + entries = get_emoji_id_entries() + for _, aliases in entries: + assert aliases == sorted(aliases) + + def test_known_emoji_in_entries(self) -> None: + entries = get_emoji_id_entries() + id_map = {eid: aliases for eid, aliases in entries} + assert 76 in id_map + assert "赞" in id_map[76] + + +# --------------------------------------------------------------------------- +# get_emoji_alias_map +# --------------------------------------------------------------------------- + + +class TestGetEmojiAliasMap: + def test_returns_dict(self) -> None: + m = get_emoji_alias_map() + assert isinstance(m, dict) + assert len(m) > 0 + + def test_contains_known_entries(self) -> None: + m = get_emoji_alias_map() + assert m.get("微笑") == 14 + assert m.get("👍") == 76 diff --git a/tests/test_rate_limit.py b/tests/test_rate_limit.py new file mode 100644 index 0000000..a38c0ef --- /dev/null +++ b/tests/test_rate_limit.py @@ -0,0 +1,267 @@ +"""RateLimiter 单元测试""" + +from __future__ import annotations + +import time +from dataclasses import dataclass +from typing import Any, cast + +from Undefined.rate_limit import RateLimiter + + +# --------------------------------------------------------------------------- +# Mock helpers +# --------------------------------------------------------------------------- + + +class _MockConfig: + """最小化的 Config mock。""" + + def __init__( + self, + superadmins: set[int] | None = None, + admins: set[int] | None = None, + ) -> None: + self._superadmins = superadmins or set() + self._admins = admins or set() + + def is_superadmin(self, user_id: int) -> bool: + return user_id in self._superadmins + + def is_admin(self, user_id: int) -> bool: + return user_id in self._admins + + +@dataclass +class _MockCommandRateLimit: + """模拟 CommandRateLimit。""" + + user: int = 10 + admin: int = 5 + 
superadmin: int = 0 + + +# --------------------------------------------------------------------------- +# 基本限流 (check / record) +# --------------------------------------------------------------------------- + + +class TestRateLimiterCheck: + def test_first_call_allowed(self) -> None: + limiter = RateLimiter(cast(Any, _MockConfig)()) + allowed, remaining = limiter.check(1001) + assert allowed is True + assert remaining == 0 + + def test_second_call_blocked(self) -> None: + limiter = RateLimiter(cast(Any, _MockConfig)()) + limiter.record(1001) + allowed, remaining = limiter.check(1001) + assert allowed is False + assert remaining > 0 + + def test_superadmin_always_allowed(self) -> None: + cfg = _MockConfig(superadmins={1001}) + limiter = RateLimiter(cast(Any, cfg)) + limiter.record(1001) + allowed, _ = limiter.check(1001) + assert allowed is True + + def test_admin_shorter_cooldown(self) -> None: + cfg = _MockConfig(admins={2001}) + limiter = RateLimiter(cast(Any, cfg)) + # 模拟 admin 在较短冷却期后可以调用 + limiter._last_calls[2001] = time.time() - RateLimiter.ADMIN_COOLDOWN - 1 + allowed, _ = limiter.check(2001) + assert allowed is True + + def test_normal_user_cooldown(self) -> None: + cfg = _MockConfig() + limiter = RateLimiter(cast(Any, cfg)) + limiter._last_calls[3001] = time.time() - RateLimiter.USER_COOLDOWN + 2 + allowed, remaining = limiter.check(3001) + assert allowed is False + assert remaining >= 1 + + def test_cooldown_expires(self) -> None: + cfg = _MockConfig() + limiter = RateLimiter(cast(Any, cfg)) + limiter._last_calls[3001] = time.time() - RateLimiter.USER_COOLDOWN - 1 + allowed, _ = limiter.check(3001) + assert allowed is True + + +class TestRateLimiterRecord: + def test_record_stores_time(self) -> None: + limiter = RateLimiter(cast(Any, _MockConfig)()) + limiter.record(1001) + assert 1001 in limiter._last_calls + + def test_record_superadmin_skipped(self) -> None: + cfg = _MockConfig(superadmins={1001}) + limiter = RateLimiter(cast(Any, cfg)) + 
limiter.record(1001) + assert 1001 not in limiter._last_calls + + +# --------------------------------------------------------------------------- +# /ask 限流 +# --------------------------------------------------------------------------- + + +class TestRateLimiterAsk: + def test_ask_first_call_allowed(self) -> None: + limiter = RateLimiter(cast(Any, _MockConfig)()) + allowed, _ = limiter.check_ask(1001) + assert allowed is True + + def test_ask_blocked_within_cooldown(self) -> None: + limiter = RateLimiter(cast(Any, _MockConfig)()) + limiter.record_ask(1001) + allowed, remaining = limiter.check_ask(1001) + assert allowed is False + assert remaining > 0 + + def test_ask_superadmin_bypass(self) -> None: + cfg = _MockConfig(superadmins={1001}) + limiter = RateLimiter(cast(Any, cfg)) + limiter.record_ask(1001) + allowed, _ = limiter.check_ask(1001) + assert allowed is True + + def test_ask_cooldown_expires(self) -> None: + cfg = _MockConfig() + limiter = RateLimiter(cast(Any, cfg)) + limiter._last_ask_calls[1001] = time.time() - RateLimiter.ASK_COOLDOWN - 1 + allowed, _ = limiter.check_ask(1001) + assert allowed is True + + +# --------------------------------------------------------------------------- +# /stats 限流 +# --------------------------------------------------------------------------- + + +class TestRateLimiterStats: + def test_stats_first_call_allowed(self) -> None: + limiter = RateLimiter(cast(Any, _MockConfig)()) + allowed, _ = limiter.check_stats(1001) + assert allowed is True + + def test_stats_blocked_for_normal_user(self) -> None: + limiter = RateLimiter(cast(Any, _MockConfig)()) + limiter.record_stats(1001) + allowed, remaining = limiter.check_stats(1001) + assert allowed is False + assert remaining > 0 + + def test_stats_admin_bypass(self) -> None: + cfg = _MockConfig(admins={2001}) + limiter = RateLimiter(cast(Any, cfg)) + limiter.record_stats(2001) + allowed, _ = limiter.check_stats(2001) + assert allowed is True + + def 
test_stats_superadmin_bypass(self) -> None: + cfg = _MockConfig(superadmins={1001}) + limiter = RateLimiter(cast(Any, cfg)) + limiter.record_stats(1001) + allowed, _ = limiter.check_stats(1001) + assert allowed is True + + def test_stats_record_skipped_for_admin(self) -> None: + cfg = _MockConfig(admins={2001}) + limiter = RateLimiter(cast(Any, cfg)) + limiter.record_stats(2001) + assert 2001 not in limiter._last_stats_calls + + +# --------------------------------------------------------------------------- +# 动态命令限流 (check_command / record_command) +# --------------------------------------------------------------------------- + + +class TestRateLimiterCommand: + def test_command_first_call_allowed(self) -> None: + limiter = RateLimiter(cast(Any, _MockConfig)()) + limits = _MockCommandRateLimit() + allowed, _ = limiter.check_command(1001, "test_cmd", cast(Any, limits)) + assert allowed is True + + def test_command_blocked_after_record(self) -> None: + limiter = RateLimiter(cast(Any, _MockConfig)()) + limits = _MockCommandRateLimit(user=10) + limiter.record_command(1001, "cmd", cast(Any, limits)) + allowed, remaining = limiter.check_command(1001, "cmd", cast(Any, limits)) + assert allowed is False + assert remaining > 0 + + def test_command_superadmin_zero_cooldown(self) -> None: + cfg = _MockConfig(superadmins={1001}) + limiter = RateLimiter(cast(Any, cfg)) + limits = _MockCommandRateLimit(superadmin=0) + limiter.record_command(1001, "cmd", cast(Any, limits)) + allowed, _ = limiter.check_command(1001, "cmd", cast(Any, limits)) + assert allowed is True + + def test_command_admin_shorter_cooldown(self) -> None: + cfg = _MockConfig(admins={2001}) + limiter = RateLimiter(cast(Any, cfg)) + limits = _MockCommandRateLimit(admin=5, user=60) + limiter._command_calls.setdefault("cmd", {})[2001] = time.time() - 6 + allowed, _ = limiter.check_command(2001, "cmd", cast(Any, limits)) + assert allowed is True + + def test_command_different_commands_independent(self) -> None: + 
limiter = RateLimiter(cast(Any, _MockConfig)()) + limits = _MockCommandRateLimit(user=60) + limiter.record_command(1001, "cmd_a", cast(Any, limits)) + allowed, _ = limiter.check_command(1001, "cmd_b", cast(Any, limits)) + assert allowed is True + + +# --------------------------------------------------------------------------- +# clear 方法 +# --------------------------------------------------------------------------- + + +class TestRateLimiterClear: + def test_clear_removes_user(self) -> None: + limiter = RateLimiter(cast(Any, _MockConfig)()) + limiter.record(1001) + limiter.clear(1001) + allowed, _ = limiter.check(1001) + assert allowed is True + + def test_clear_ask(self) -> None: + limiter = RateLimiter(cast(Any, _MockConfig)()) + limiter.record_ask(1001) + limiter.clear_ask(1001) + allowed, _ = limiter.check_ask(1001) + assert allowed is True + + def test_clear_stats(self) -> None: + limiter = RateLimiter(cast(Any, _MockConfig)()) + limiter.record_stats(1001) + limiter.clear_stats(1001) + allowed, _ = limiter.check_stats(1001) + assert allowed is True + + def test_clear_all(self) -> None: + limiter = RateLimiter(cast(Any, _MockConfig)()) + limiter.record(1001) + limiter.record_ask(1001) + limiter.record_stats(1001) + limits = _MockCommandRateLimit() + limiter.record_command(1001, "cmd", cast(Any, limits)) + limiter.clear_all() + assert limiter._last_calls == {} + assert limiter._last_ask_calls == {} + assert limiter._last_stats_calls == {} + assert limiter._command_calls == {} + + def test_clear_nonexistent_user_no_error(self) -> None: + limiter = RateLimiter(cast(Any, _MockConfig)()) + limiter.clear(9999) # 不应抛出异常 + limiter.clear_ask(9999) + limiter.clear_stats(9999) diff --git a/tests/test_request_params.py b/tests/test_request_params.py new file mode 100644 index 0000000..d5e09d3 --- /dev/null +++ b/tests/test_request_params.py @@ -0,0 +1,141 @@ +"""Tests for Undefined.utils.request_params — request param helpers.""" + +from __future__ import annotations + 
+from collections import OrderedDict +from typing import Any + +from Undefined.utils.request_params import ( + merge_request_params, + normalize_request_params, + split_reserved_request_params, +) + + +class TestNormalizeRequestParams: + def test_dict_passthrough(self) -> None: + result = normalize_request_params({"a": 1, "b": "two"}) + assert result == {"a": 1, "b": "two"} + + def test_none_returns_empty(self) -> None: + assert normalize_request_params(None) == {} + + def test_non_dict_returns_empty(self) -> None: + assert normalize_request_params("string") == {} + assert normalize_request_params(42) == {} + assert normalize_request_params([1, 2]) == {} + + def test_empty_dict(self) -> None: + assert normalize_request_params({}) == {} + + def test_nested_dict_cloned(self) -> None: + original: dict[str, Any] = {"inner": {"x": 1}} + result = normalize_request_params(original) + assert result == {"inner": {"x": 1}} + # Must be a deep copy + assert result["inner"] is not original["inner"] + + def test_list_cloned(self) -> None: + original: dict[str, Any] = {"items": [1, 2, {"a": 3}]} + result = normalize_request_params(original) + assert result["items"] == [1, 2, {"a": 3}] + assert result["items"] is not original["items"] + + def test_tuple_converted_to_list(self) -> None: + result = normalize_request_params({"t": (1, 2, 3)}) + assert result["t"] == [1, 2, 3] + assert isinstance(result["t"], list) + + def test_non_json_value_stringified(self) -> None: + result = normalize_request_params({"obj": object()}) + assert isinstance(result["obj"], str) + + def test_keys_stringified(self) -> None: + result = normalize_request_params({1: "a", 2: "b"}) + assert "1" in result + assert "2" in result + + def test_ordered_dict_accepted(self) -> None: + od = OrderedDict([("z", 1), ("a", 2)]) + result = normalize_request_params(od) + assert result == {"z": 1, "a": 2} + + def test_bool_preserved(self) -> None: + result = normalize_request_params({"flag": True}) + assert result["flag"] 
is True + + def test_none_value_preserved(self) -> None: + result = normalize_request_params({"key": None}) + assert result["key"] is None + + +class TestMergeRequestParams: + def test_single_dict(self) -> None: + result = merge_request_params({"a": 1}) + assert result == {"a": 1} + + def test_two_dicts_merged(self) -> None: + result = merge_request_params({"a": 1}, {"b": 2}) + assert result == {"a": 1, "b": 2} + + def test_later_overrides_earlier(self) -> None: + result = merge_request_params({"a": 1}, {"a": 2}) + assert result["a"] == 2 + + def test_none_skipped(self) -> None: + result = merge_request_params(None, {"a": 1}) + assert result == {"a": 1} + + def test_non_dict_skipped(self) -> None: + result = merge_request_params("bad", {"a": 1}, 42) + assert result == {"a": 1} + + def test_empty_args(self) -> None: + result = merge_request_params() + assert result == {} + + def test_multiple_merges(self) -> None: + result = merge_request_params({"a": 1}, {"b": 2}, {"c": 3}) + assert result == {"a": 1, "b": 2, "c": 3} + + +class TestSplitReservedRequestParams: + def test_basic_split(self) -> None: + allowed, reserved = split_reserved_request_params( + {"a": 1, "b": 2, "c": 3}, {"b", "c"} + ) + assert allowed == {"a": 1} + assert reserved == {"b": 2, "c": 3} + + def test_no_reserved_keys(self) -> None: + allowed, reserved = split_reserved_request_params({"a": 1}, set()) + assert allowed == {"a": 1} + assert reserved == {} + + def test_all_reserved(self) -> None: + allowed, reserved = split_reserved_request_params({"a": 1, "b": 2}, {"a", "b"}) + assert allowed == {} + assert reserved == {"a": 1, "b": 2} + + def test_none_params(self) -> None: + allowed, reserved = split_reserved_request_params(None, {"a"}) + assert allowed == {} + assert reserved == {} + + def test_empty_params(self) -> None: + allowed, reserved = split_reserved_request_params({}, {"a"}) + assert allowed == {} + assert reserved == {} + + def test_frozenset_keys(self) -> None: + allowed, reserved = 
split_reserved_request_params( + {"x": 1, "y": 2}, frozenset({"x"}) + ) + assert allowed == {"y": 2} + assert reserved == {"x": 1} + + def test_nested_values_cloned(self) -> None: + original: dict[str, Any] = {"deep": {"nested": True}, "keep": "val"} + allowed, reserved = split_reserved_request_params(original, {"deep"}) + assert reserved["deep"] == {"nested": True} + assert reserved["deep"] is not original["deep"] diff --git a/tests/test_scheduled_task_unit.py b/tests/test_scheduled_task_unit.py new file mode 100644 index 0000000..93866a4 --- /dev/null +++ b/tests/test_scheduled_task_unit.py @@ -0,0 +1,192 @@ +"""ScheduledTask / ToolCall 序列化 单元测试""" + +from __future__ import annotations + +from typing import Any + + +from Undefined.scheduled_task_storage import ScheduledTask, ToolCall + + +# --------------------------------------------------------------------------- +# ToolCall +# --------------------------------------------------------------------------- + + +class TestToolCall: + def test_fields(self) -> None: + tc = ToolCall(tool_name="search", tool_args={"q": "test"}) + assert tc.tool_name == "search" + assert tc.tool_args == {"q": "test"} + + +# --------------------------------------------------------------------------- +# ScheduledTask — to_dict / from_dict 往返 +# --------------------------------------------------------------------------- + + +def _sample_task_dict() -> dict[str, Any]: + return { + "task_id": "task-001", + "tool_name": "search", + "tool_args": {"q": "test"}, + "cron": "0 9 * * *", + "target_id": 12345, + "target_type": "group", + "task_name": "每日搜索", + "max_executions": 10, + "current_executions": 3, + "created_at": "2025-01-01T00:00:00", + "context_id": "ctx-1", + "tools": [ + {"tool_name": "search", "tool_args": {"q": "test"}}, + {"tool_name": "notify", "tool_args": {"msg": "done"}}, + ], + "execution_mode": "parallel", + } + + +class TestScheduledTaskRoundtrip: + def test_basic_roundtrip(self) -> None: + d = _sample_task_dict() + task = 
ScheduledTask.from_dict(d) + restored = task.to_dict() + assert restored["task_id"] == "task-001" + assert restored["cron"] == "0 9 * * *" + assert restored["execution_mode"] == "parallel" + assert len(restored["tools"]) == 2 + + def test_tools_are_toolcall_instances(self) -> None: + d = _sample_task_dict() + task = ScheduledTask.from_dict(d) + assert task.tools is not None + for tc in task.tools: + assert isinstance(tc, ToolCall) + + def test_to_dict_tools_are_dicts(self) -> None: + d = _sample_task_dict() + task = ScheduledTask.from_dict(d) + restored = task.to_dict() + for tool in restored["tools"]: + assert isinstance(tool, dict) + assert "tool_name" in tool + + +# --------------------------------------------------------------------------- +# 向后兼容 — 旧格式无 tools +# --------------------------------------------------------------------------- + + +class TestScheduledTaskBackwardCompat: + def test_legacy_without_tools_field(self) -> None: + """旧格式只有 tool_name/tool_args,没有 tools 字段。""" + d: dict[str, Any] = { + "task_id": "legacy-1", + "tool_name": "old_tool", + "tool_args": {"key": "val"}, + "cron": "*/5 * * * *", + "target_id": None, + "target_type": "private", + "task_name": "旧任务", + "max_executions": None, + } + task = ScheduledTask.from_dict(d) + assert task.tools is not None + assert len(task.tools) == 1 + assert task.tools[0].tool_name == "old_tool" + assert task.tools[0].tool_args == {"key": "val"} + + def test_legacy_empty_tools_uses_tool_name(self) -> None: + """tools 为空列表时,回退到 tool_name。""" + d: dict[str, Any] = { + "task_id": "legacy-2", + "tool_name": "fallback", + "tool_args": {}, + "tools": [], + "cron": "0 0 * * *", + "target_id": 1, + "target_type": "group", + "task_name": "fallback task", + "max_executions": None, + } + task = ScheduledTask.from_dict(d) + assert task.tools is not None + assert len(task.tools) == 1 + assert task.tools[0].tool_name == "fallback" + + +# --------------------------------------------------------------------------- +# 
可选字段缺失 +# --------------------------------------------------------------------------- + + +class TestScheduledTaskOptionalFields: + def test_missing_context_id(self) -> None: + d: dict[str, Any] = { + "task_id": "t1", + "tool_name": "x", + "tool_args": {}, + "cron": "0 0 * * *", + "target_id": None, + "target_type": "group", + "task_name": "n", + "max_executions": None, + } + task = ScheduledTask.from_dict(d) + assert task.context_id is None + + def test_missing_current_executions(self) -> None: + d: dict[str, Any] = { + "task_id": "t2", + "tool_name": "x", + "tool_args": {}, + "cron": "0 0 * * *", + "target_id": 1, + "target_type": "private", + "task_name": "n", + "max_executions": 5, + } + task = ScheduledTask.from_dict(d) + assert task.current_executions == 0 + + def test_missing_created_at(self) -> None: + d: dict[str, Any] = { + "task_id": "t3", + "tool_name": "x", + "tool_args": {}, + "cron": "0 0 * * *", + "target_id": None, + "target_type": "group", + "task_name": "n", + "max_executions": None, + } + task = ScheduledTask.from_dict(d) + assert task.created_at == "" + + def test_default_execution_mode(self) -> None: + d: dict[str, Any] = { + "task_id": "t4", + "tool_name": "x", + "tool_args": {}, + "cron": "0 0 * * *", + "target_id": None, + "target_type": "group", + "task_name": "n", + "max_executions": None, + } + task = ScheduledTask.from_dict(d) + assert task.execution_mode == "serial" + + def test_max_executions_none(self) -> None: + d: dict[str, Any] = { + "task_id": "t5", + "tool_name": "x", + "tool_args": {}, + "cron": "0 0 * * *", + "target_id": None, + "target_type": "group", + "task_name": "n", + "max_executions": None, + } + task = ScheduledTask.from_dict(d) + assert task.max_executions is None diff --git a/tests/test_skills_http_client.py b/tests/test_skills_http_client.py new file mode 100644 index 0000000..61479e9 --- /dev/null +++ b/tests/test_skills_http_client.py @@ -0,0 +1,76 @@ +"""Tests for Undefined.skills.http_client module (pure 
functions only).""" + +from __future__ import annotations + +from Undefined.skills.http_client import _retry_delay, _should_retry_http_status + + +class TestShouldRetryHttpStatus: + """Tests for _should_retry_http_status().""" + + def test_429_should_retry(self) -> None: + assert _should_retry_http_status(429) is True + + def test_500_should_retry(self) -> None: + assert _should_retry_http_status(500) is True + + def test_502_should_retry(self) -> None: + assert _should_retry_http_status(502) is True + + def test_503_should_retry(self) -> None: + assert _should_retry_http_status(503) is True + + def test_504_should_retry(self) -> None: + assert _should_retry_http_status(504) is True + + def test_599_should_retry(self) -> None: + assert _should_retry_http_status(599) is True + + def test_200_should_not_retry(self) -> None: + assert _should_retry_http_status(200) is False + + def test_201_should_not_retry(self) -> None: + assert _should_retry_http_status(201) is False + + def test_400_should_not_retry(self) -> None: + assert _should_retry_http_status(400) is False + + def test_401_should_not_retry(self) -> None: + assert _should_retry_http_status(401) is False + + def test_403_should_not_retry(self) -> None: + assert _should_retry_http_status(403) is False + + def test_404_should_not_retry(self) -> None: + assert _should_retry_http_status(404) is False + + def test_600_should_not_retry(self) -> None: + assert _should_retry_http_status(600) is False + + def test_428_should_not_retry(self) -> None: + assert _should_retry_http_status(428) is False + + +class TestRetryDelay: + """Tests for _retry_delay().""" + + def test_attempt_0(self) -> None: + assert _retry_delay(0) == 0.25 # min(2.0, 0.25 * 2^0) = 0.25 + + def test_attempt_1(self) -> None: + assert _retry_delay(1) == 0.5 # min(2.0, 0.25 * 2^1) = 0.5 + + def test_attempt_2(self) -> None: + assert _retry_delay(2) == 1.0 # min(2.0, 0.25 * 2^2) = 1.0 + + def test_attempt_3(self) -> None: + assert _retry_delay(3) == 2.0 
# min(2.0, 0.25 * 2^3) = 2.0 + + def test_attempt_4_capped(self) -> None: + assert _retry_delay(4) == 2.0 # min(2.0, 0.25 * 2^4) = min(2.0, 4.0) = 2.0 + + def test_attempt_5_capped(self) -> None: + assert _retry_delay(5) == 2.0 # capped at 2.0 + + def test_returns_float(self) -> None: + assert isinstance(_retry_delay(0), float) diff --git a/tests/test_skills_http_config.py b/tests/test_skills_http_config.py new file mode 100644 index 0000000..7ed41eb --- /dev/null +++ b/tests/test_skills_http_config.py @@ -0,0 +1,98 @@ +"""Tests for Undefined.skills.http_config module (pure functions only).""" + +from __future__ import annotations + +from Undefined.skills.http_config import _normalize_base_url, build_url + + +class TestBuildUrl: + """Tests for build_url().""" + + def test_simple_join(self) -> None: + assert ( + build_url("https://api.example.com", "/v1/data") + == "https://api.example.com/v1/data" + ) + + def test_trailing_slash_on_base(self) -> None: + assert ( + build_url("https://api.example.com/", "/v1/data") + == "https://api.example.com/v1/data" + ) + + def test_multiple_trailing_slashes(self) -> None: + assert ( + build_url("https://api.example.com///", "/v1") + == "https://api.example.com/v1" + ) + + def test_path_without_leading_slash(self) -> None: + assert ( + build_url("https://api.example.com", "v1/data") + == "https://api.example.com/v1/data" + ) + + def test_empty_path(self) -> None: + assert build_url("https://api.example.com", "") == "https://api.example.com/" + + def test_path_is_slash_only(self) -> None: + assert build_url("https://api.example.com", "/") == "https://api.example.com/" + + def test_base_with_subpath(self) -> None: + assert ( + build_url("https://api.example.com/v2", "/users") + == "https://api.example.com/v2/users" + ) + + def test_base_with_subpath_trailing_slash(self) -> None: + assert ( + build_url("https://api.example.com/v2/", "/users") + == "https://api.example.com/v2/users" + ) + + +class TestNormalizeBaseUrl: + """Tests for 
_normalize_base_url().""" + + def test_normal_url(self) -> None: + assert ( + _normalize_base_url("https://api.example.com", "https://fallback.com") + == "https://api.example.com" + ) + + def test_trailing_slash_removed(self) -> None: + assert ( + _normalize_base_url("https://api.example.com/", "https://fallback.com") + == "https://api.example.com" + ) + + def test_multiple_trailing_slashes(self) -> None: + assert ( + _normalize_base_url("https://api.example.com///", "https://fallback.com") + == "https://api.example.com" + ) + + def test_empty_value_uses_fallback(self) -> None: + assert _normalize_base_url("", "https://fallback.com") == "https://fallback.com" + + def test_whitespace_only_uses_fallback(self) -> None: + assert ( + _normalize_base_url(" ", "https://fallback.com") == "https://fallback.com" + ) + + def test_fallback_trailing_slash_stripped(self) -> None: + assert ( + _normalize_base_url("", "https://fallback.com/") == "https://fallback.com" + ) + + def test_leading_trailing_whitespace_stripped(self) -> None: + assert ( + _normalize_base_url(" https://api.example.com ", "https://fallback.com") + == "https://api.example.com" + ) + + def test_value_with_path(self) -> None: + assert ( + _normalize_base_url("https://api.example.com/v2/", "https://fallback.com") + == "https://api.example.com/v2" + ) diff --git a/tests/test_skills_registry_stats.py b/tests/test_skills_registry_stats.py new file mode 100644 index 0000000..d46fabd --- /dev/null +++ b/tests/test_skills_registry_stats.py @@ -0,0 +1,100 @@ +"""Tests for Undefined.skills.registry.SkillStats dataclass.""" + +from __future__ import annotations + +from Undefined.skills.registry import SkillStats + + +class TestSkillStats: + """Tests for SkillStats dataclass.""" + + def test_initial_state(self) -> None: + stats = SkillStats() + assert stats.count == 0 + assert stats.success == 0 + assert stats.failure == 0 + assert stats.total_duration == 0.0 + assert stats.last_duration == 0.0 + assert stats.last_error 
is None + assert stats.last_called_at is None + + def test_record_success(self) -> None: + stats = SkillStats() + stats.record_success(1.5) + assert stats.count == 1 + assert stats.success == 1 + assert stats.failure == 0 + assert stats.total_duration == 1.5 + assert stats.last_duration == 1.5 + assert stats.last_error is None + assert stats.last_called_at is not None + + def test_record_failure(self) -> None: + stats = SkillStats() + stats.record_failure(2.0, "timeout") + assert stats.count == 1 + assert stats.success == 0 + assert stats.failure == 1 + assert stats.total_duration == 2.0 + assert stats.last_duration == 2.0 + assert stats.last_error == "timeout" + assert stats.last_called_at is not None + + def test_multiple_successes(self) -> None: + stats = SkillStats() + stats.record_success(1.0) + stats.record_success(2.0) + stats.record_success(3.0) + assert stats.count == 3 + assert stats.success == 3 + assert stats.failure == 0 + assert stats.total_duration == 6.0 + assert stats.last_duration == 3.0 + + def test_mixed_success_and_failure(self) -> None: + stats = SkillStats() + stats.record_success(1.0) + stats.record_failure(0.5, "error A") + stats.record_success(2.0) + assert stats.count == 3 + assert stats.success == 2 + assert stats.failure == 1 + assert stats.total_duration == 3.5 + assert stats.last_duration == 2.0 + assert stats.last_error is None # cleared by success + + def test_success_clears_last_error(self) -> None: + stats = SkillStats() + stats.record_failure(1.0, "something broke") + assert stats.last_error == "something broke" + stats.record_success(0.5) + assert stats.last_error is None + + def test_failure_overwrites_last_error(self) -> None: + stats = SkillStats() + stats.record_failure(1.0, "error 1") + stats.record_failure(2.0, "error 2") + assert stats.last_error == "error 2" + + def test_average_duration(self) -> None: + stats = SkillStats() + stats.record_success(2.0) + stats.record_success(4.0) + avg = stats.total_duration / 
stats.count + assert avg == 3.0 + + def test_last_called_at_updates(self) -> None: + stats = SkillStats() + stats.record_success(1.0) + first_called = stats.last_called_at + assert first_called is not None + stats.record_failure(1.0, "err") + assert stats.last_called_at is not None + assert stats.last_called_at >= first_called + + def test_zero_duration(self) -> None: + stats = SkillStats() + stats.record_success(0.0) + assert stats.total_duration == 0.0 + assert stats.last_duration == 0.0 + assert stats.count == 1 diff --git a/tests/test_time_utils.py b/tests/test_time_utils.py new file mode 100644 index 0000000..d4c49cd --- /dev/null +++ b/tests/test_time_utils.py @@ -0,0 +1,84 @@ +"""Tests for Undefined.utils.time_utils — time parsing/formatting helpers.""" + +from __future__ import annotations + +from datetime import datetime + +from Undefined.utils.time_utils import format_datetime, parse_time_range + + +class TestParseTimeRange: + def test_both_valid(self) -> None: + start, end = parse_time_range("2024-01-15 08:30:00", "2024-06-20 17:45:00") + assert start == datetime(2024, 1, 15, 8, 30, 0) + assert end == datetime(2024, 6, 20, 17, 45, 0) + + def test_only_start(self) -> None: + start, end = parse_time_range("2024-01-01 00:00:00", None) + assert start == datetime(2024, 1, 1, 0, 0, 0) + assert end is None + + def test_only_end(self) -> None: + start, end = parse_time_range(None, "2024-12-31 23:59:59") + assert start is None + assert end == datetime(2024, 12, 31, 23, 59, 59) + + def test_both_none(self) -> None: + start, end = parse_time_range(None, None) + assert start is None + assert end is None + + def test_invalid_start_format(self) -> None: + start, end = parse_time_range("2024/01/01", None) + assert start is None + assert end is None + + def test_invalid_end_format(self) -> None: + start, end = parse_time_range(None, "not-a-date") + assert start is None + assert end is None + + def test_both_invalid(self) -> None: + start, end = parse_time_range("bad", 
"worse") + assert start is None + assert end is None + + def test_empty_strings(self) -> None: + start, end = parse_time_range("", "") + assert start is None + assert end is None + + def test_date_only_format_rejected(self) -> None: + start, end = parse_time_range("2024-01-01", None) + assert start is None + + def test_midnight(self) -> None: + start, end = parse_time_range("2024-01-01 00:00:00", None) + assert start == datetime(2024, 1, 1, 0, 0, 0) + + def test_end_of_day(self) -> None: + start, end = parse_time_range(None, "2024-12-31 23:59:59") + assert end == datetime(2024, 12, 31, 23, 59, 59) + + +class TestFormatDatetime: + def test_none_input(self) -> None: + assert format_datetime(None) == "未指定" + + def test_normal_datetime(self) -> None: + dt = datetime(2024, 3, 15, 14, 30, 45) + assert format_datetime(dt) == "2024-03-15 14:30:45" + + def test_midnight(self) -> None: + dt = datetime(2024, 1, 1, 0, 0, 0) + assert format_datetime(dt) == "2024-01-01 00:00:00" + + def test_end_of_day(self) -> None: + dt = datetime(2024, 12, 31, 23, 59, 59) + assert format_datetime(dt) == "2024-12-31 23:59:59" + + def test_roundtrip(self) -> None: + original = "2024-06-15 10:20:30" + start, _ = parse_time_range(original, None) + assert start is not None + assert format_datetime(start) == original diff --git a/tests/test_token_usage_unit.py b/tests/test_token_usage_unit.py new file mode 100644 index 0000000..821f6ea --- /dev/null +++ b/tests/test_token_usage_unit.py @@ -0,0 +1,169 @@ +"""TokenUsage 序列化/反序列化 单元测试""" + +from __future__ import annotations + +from typing import Any + +import pytest + +from Undefined.token_usage_storage import TokenUsage + + +def _sample_dict() -> dict[str, Any]: + return { + "timestamp": "2025-01-01T00:00:00", + "model_name": "gpt-4", + "prompt_tokens": 100, + "completion_tokens": 50, + "total_tokens": 150, + "duration_seconds": 1.5, + "call_type": "chat", + "success": True, + } + + +# 
--------------------------------------------------------------------------- +# to_dict / from_dict 往返 +# --------------------------------------------------------------------------- + + +class TestTokenUsageRoundtrip: + def test_basic_roundtrip(self) -> None: + d = _sample_dict() + usage = TokenUsage.from_dict(d) + assert usage.to_dict() == d + + def test_all_fields_preserved(self) -> None: + usage = TokenUsage( + timestamp="ts", + model_name="m", + prompt_tokens=10, + completion_tokens=20, + total_tokens=30, + duration_seconds=0.5, + call_type="agent", + success=False, + ) + restored = TokenUsage.from_dict(usage.to_dict()) + assert restored == usage + + +# --------------------------------------------------------------------------- +# from_dict — 缺失字段回退 +# --------------------------------------------------------------------------- + + +class TestTokenUsageFromDictDefaults: + def test_empty_dict(self) -> None: + usage = TokenUsage.from_dict({}) + assert usage.timestamp == "" + assert usage.model_name == "" + assert usage.prompt_tokens == 0 + assert usage.completion_tokens == 0 + assert usage.total_tokens == 0 + assert usage.duration_seconds == 0.0 + assert usage.call_type == "unknown" + assert usage.success is True # 默认为 True + + def test_timestamp_fallback_to_time(self) -> None: + usage = TokenUsage.from_dict({"time": "2025-01-01"}) + assert usage.timestamp == "2025-01-01" + + def test_timestamp_fallback_to_created_at(self) -> None: + usage = TokenUsage.from_dict({"created_at": "2025-02-02"}) + assert usage.timestamp == "2025-02-02" + + def test_model_name_fallback_to_model(self) -> None: + usage = TokenUsage.from_dict({"model": "claude"}) + assert usage.model_name == "claude" + + def test_prompt_tokens_fallback_to_input_tokens(self) -> None: + usage = TokenUsage.from_dict({"input_tokens": 42}) + assert usage.prompt_tokens == 42 + + def test_completion_tokens_fallback_to_output_tokens(self) -> None: + usage = TokenUsage.from_dict({"output_tokens": 24}) + assert 
usage.completion_tokens == 24 + + def test_total_tokens_auto_sum(self) -> None: + usage = TokenUsage.from_dict({"prompt_tokens": 10, "completion_tokens": 20}) + assert usage.total_tokens == 30 + + def test_total_tokens_explicit(self) -> None: + usage = TokenUsage.from_dict( + {"prompt_tokens": 10, "completion_tokens": 20, "total_tokens": 99} + ) + assert usage.total_tokens == 99 + + def test_call_type_fallback_to_type(self) -> None: + usage = TokenUsage.from_dict({"type": "vision"}) + assert usage.call_type == "vision" + + def test_duration_fallback_to_duration(self) -> None: + usage = TokenUsage.from_dict({"duration": 3.14}) + assert usage.duration_seconds == pytest.approx(3.14) + + +# --------------------------------------------------------------------------- +# from_dict — success 字段各种类型 +# --------------------------------------------------------------------------- + + +class TestTokenUsageSuccess: + def test_success_bool_true(self) -> None: + assert TokenUsage.from_dict({"success": True}).success is True + + def test_success_bool_false(self) -> None: + assert TokenUsage.from_dict({"success": False}).success is False + + def test_success_string_false(self) -> None: + assert TokenUsage.from_dict({"success": "false"}).success is False + + def test_success_string_0(self) -> None: + assert TokenUsage.from_dict({"success": "0"}).success is False + + def test_success_string_no(self) -> None: + assert TokenUsage.from_dict({"success": "no"}).success is False + + def test_success_string_yes(self) -> None: + assert TokenUsage.from_dict({"success": "yes"}).success is True + + def test_success_int_1(self) -> None: + assert TokenUsage.from_dict({"success": 1}).success is True + + def test_success_int_0(self) -> None: + assert TokenUsage.from_dict({"success": 0}).success is False + + +# --------------------------------------------------------------------------- +# from_dict — 类型转换容错 +# --------------------------------------------------------------------------- + + +class 
TestTokenUsageTypeCoercion: + def test_string_tokens(self) -> None: + usage = TokenUsage.from_dict({"prompt_tokens": "42"}) + assert usage.prompt_tokens == 42 + + def test_invalid_tokens_default_zero(self) -> None: + usage = TokenUsage.from_dict({"prompt_tokens": "abc"}) + assert usage.prompt_tokens == 0 + + def test_none_tokens_default_zero(self) -> None: + usage = TokenUsage.from_dict({"prompt_tokens": None}) + assert usage.prompt_tokens == 0 + + def test_non_string_timestamp(self) -> None: + usage = TokenUsage.from_dict({"timestamp": 12345}) + assert usage.timestamp == "12345" + + def test_non_string_model_name(self) -> None: + usage = TokenUsage.from_dict({"model_name": 42}) + assert usage.model_name == "42" + + def test_extra_fields_ignored(self) -> None: + d = _sample_dict() + d["extra_field"] = "ignored" + d["another"] = 999 + usage = TokenUsage.from_dict(d) + assert usage.model_name == "gpt-4" diff --git a/tests/test_tool_calls.py b/tests/test_tool_calls.py new file mode 100644 index 0000000..4bac35f --- /dev/null +++ b/tests/test_tool_calls.py @@ -0,0 +1,273 @@ +"""Tests for Undefined.utils.tool_calls.""" + +from __future__ import annotations + +import json +import logging +from typing import Any + +import pytest + +from Undefined.utils.tool_calls import ( + _clean_json_string, + _repair_json_like_string, + _strip_code_fences, + extract_required_tool_call_arguments, + normalize_tool_arguments_json, + parse_tool_arguments, +) + + +@pytest.fixture() +def test_logger() -> logging.Logger: + return logging.getLogger("test") + + +# --------------------------------------------------------------------------- +# _strip_code_fences +# --------------------------------------------------------------------------- + + +class TestStripCodeFences: + def test_strip_json_fence(self) -> None: + raw = '```json\n{"a": 1}\n```' + assert _strip_code_fences(raw) == '{"a": 1}' + + def test_strip_generic_fence(self) -> None: + raw = '```\n{"a": 1}\n```' + assert 
_strip_code_fences(raw) == '{"a": 1}' + + def test_no_fence(self) -> None: + raw = '{"a": 1}' + assert _strip_code_fences(raw) == '{"a": 1}' + + +# --------------------------------------------------------------------------- +# _clean_json_string +# --------------------------------------------------------------------------- + + +class TestCleanJsonString: + def test_removes_control_chars(self) -> None: + raw = '{"key":\r\n\t"val"}' + result = _clean_json_string(raw) + assert "\r" not in result + assert "\n" not in result + assert "\t" not in result + + +# --------------------------------------------------------------------------- +# _repair_json_like_string +# --------------------------------------------------------------------------- + + +class TestRepairJsonLikeString: + def test_missing_closing_brace(self) -> None: + raw = '{"a": 1' + repaired = _repair_json_like_string(raw) + assert json.loads(repaired) == {"a": 1} + + def test_trailing_comma(self) -> None: + raw = '{"a": 1, ' + repaired = _repair_json_like_string(raw) + assert json.loads(repaired) == {"a": 1} + + def test_empty_string(self) -> None: + assert _repair_json_like_string("") == "" + + +# --------------------------------------------------------------------------- +# parse_tool_arguments +# --------------------------------------------------------------------------- + + +class TestParseToolArguments: + def test_dict_passthrough(self) -> None: + d: dict[str, Any] = {"key": "val"} + assert parse_tool_arguments(d) is d + + def test_none_returns_empty(self) -> None: + assert parse_tool_arguments(None) == {} + + def test_empty_string_returns_empty(self) -> None: + assert parse_tool_arguments("") == {} + + def test_whitespace_returns_empty(self) -> None: + assert parse_tool_arguments(" ") == {} + + def test_valid_json_string(self) -> None: + result = parse_tool_arguments('{"x": 42}') + assert result == {"x": 42} + + def test_json_with_code_fences(self) -> None: + raw = '```json\n{"x": 42}\n```' + assert 
parse_tool_arguments(raw) == {"x": 42} + + def test_json_with_control_chars(self, test_logger: logging.Logger) -> None: + raw = '{"x":\r\n42}' + result = parse_tool_arguments(raw, logger=test_logger, tool_name="t") + assert result == {"x": 42} + + def test_truncated_json_repaired(self, test_logger: logging.Logger) -> None: + raw = '{"a": "hello"' + result = parse_tool_arguments(raw, logger=test_logger, tool_name="t") + assert result == {"a": "hello"} + + def test_json_with_trailing_content(self, test_logger: logging.Logger) -> None: + raw = '{"a": 1} some trailing text' + result = parse_tool_arguments(raw, logger=test_logger, tool_name="t") + assert result == {"a": 1} + + def test_non_dict_json_returns_empty(self, test_logger: logging.Logger) -> None: + raw = "[1, 2, 3]" + result = parse_tool_arguments(raw, logger=test_logger, tool_name="t") + assert result == {} + + def test_completely_invalid_returns_empty( + self, test_logger: logging.Logger + ) -> None: + raw = "this is not json at all" + result = parse_tool_arguments(raw, logger=test_logger, tool_name="t") + assert result == {} + + def test_unsupported_type_returns_empty(self, test_logger: logging.Logger) -> None: + result = parse_tool_arguments(42, logger=test_logger, tool_name="t") + assert result == {} + + +# --------------------------------------------------------------------------- +# normalize_tool_arguments_json +# --------------------------------------------------------------------------- + + +class TestNormalizeToolArgumentsJson: + def test_none(self) -> None: + assert normalize_tool_arguments_json(None) == "{}" + + def test_dict(self) -> None: + result = normalize_tool_arguments_json({"a": 1}) + parsed = json.loads(result) + assert parsed == {"a": 1} + + def test_empty_string(self) -> None: + assert normalize_tool_arguments_json("") == "{}" + + def test_valid_json_object_string(self) -> None: + result = normalize_tool_arguments_json('{"key": "val"}') + parsed = json.loads(result) + assert parsed == 
{"key": "val"} + + def test_non_object_json_wrapped(self) -> None: + result = normalize_tool_arguments_json("[1,2,3]") + parsed = json.loads(result) + assert parsed == {"_value": [1, 2, 3]} + + def test_invalid_json_wrapped_raw(self) -> None: + result = normalize_tool_arguments_json("not json") + parsed = json.loads(result) + assert parsed == {"_raw": "not json"} + + def test_non_string_non_dict_wrapped(self) -> None: + result = normalize_tool_arguments_json(42) + parsed = json.loads(result) + assert parsed == {"_value": 42} + + def test_number_json_string_wrapped(self) -> None: + result = normalize_tool_arguments_json("123") + parsed = json.loads(result) + assert parsed == {"_value": 123} + + +# --------------------------------------------------------------------------- +# extract_required_tool_call_arguments +# --------------------------------------------------------------------------- + + +class TestExtractRequiredToolCallArguments: + def _build_response( + self, + name: str = "my_tool", + arguments: Any = '{"x": 1}', + ) -> dict[str, Any]: + return { + "choices": [ + { + "message": { + "tool_calls": [ + { + "function": { + "name": name, + "arguments": arguments, + } + } + ] + } + } + ] + } + + def test_happy_path(self) -> None: + resp = self._build_response() + result = extract_required_tool_call_arguments( + resp, expected_tool_name="my_tool", stage="test" + ) + assert result == {"x": 1} + + def test_missing_choices_raises(self) -> None: + with pytest.raises(ValueError, match="choices"): + extract_required_tool_call_arguments({}, expected_tool_name="t", stage="s") + + def test_non_dict_choice_raises(self) -> None: + with pytest.raises(ValueError, match="choice"): + extract_required_tool_call_arguments( + {"choices": ["bad"]}, expected_tool_name="t", stage="s" + ) + + def test_missing_message_raises(self) -> None: + with pytest.raises(ValueError, match="message"): + extract_required_tool_call_arguments( + {"choices": [{"no_message": True}]}, + 
expected_tool_name="t", + stage="s", + ) + + def test_missing_tool_calls_raises(self) -> None: + with pytest.raises(ValueError, match="tool_calls"): + extract_required_tool_call_arguments( + {"choices": [{"message": {"content": "hi"}}]}, + expected_tool_name="t", + stage="s", + ) + + def test_non_dict_tool_call_raises(self) -> None: + with pytest.raises(ValueError, match="tool_call"): + extract_required_tool_call_arguments( + {"choices": [{"message": {"tool_calls": ["bad"]}}]}, + expected_tool_name="t", + stage="s", + ) + + def test_missing_function_raises(self) -> None: + with pytest.raises(ValueError, match="function"): + extract_required_tool_call_arguments( + {"choices": [{"message": {"tool_calls": [{"id": "1"}]}}]}, + expected_tool_name="t", + stage="s", + ) + + def test_name_mismatch_raises(self) -> None: + resp = self._build_response(name="wrong_name") + with pytest.raises(ValueError, match="不匹配"): + extract_required_tool_call_arguments( + resp, expected_tool_name="my_tool", stage="s" + ) + + def test_with_logger(self, test_logger: logging.Logger) -> None: + resp = self._build_response() + result = extract_required_tool_call_arguments( + resp, + expected_tool_name="my_tool", + stage="test", + logger=test_logger, + ) + assert result == {"x": 1} diff --git a/tests/test_utils_common.py b/tests/test_utils_common.py new file mode 100644 index 0000000..c1075b4 --- /dev/null +++ b/tests/test_utils_common.py @@ -0,0 +1,305 @@ +"""Tests for Undefined.utils.common.""" + +from __future__ import annotations + +from typing import Any + +from Undefined.utils.common import ( + FORWARD_EXPAND_MAX_CHARS, + _format_forward_node_time, + _normalize_message_content, + _parse_at_segment, + _parse_media_segment, + _parse_segment, + _truncate_forward_text, + extract_text, + matches_xinliweiyuan, + message_to_segments, + process_at_mentions, +) + + +# --------------------------------------------------------------------------- +# extract_text +# 
--------------------------------------------------------------------------- + + +class TestExtractText: + def test_text_segment(self) -> None: + segments: list[dict[str, Any]] = [{"type": "text", "data": {"text": "hello"}}] + assert extract_text(segments) == "hello" + + def test_multiple_text_segments_joined(self) -> None: + segments: list[dict[str, Any]] = [ + {"type": "text", "data": {"text": "hello "}}, + {"type": "text", "data": {"text": "world"}}, + ] + assert extract_text(segments) == "hello world" + + def test_at_segment_without_name(self) -> None: + segments: list[dict[str, Any]] = [{"type": "at", "data": {"qq": "123456"}}] + assert extract_text(segments) == "[@123456]" + + def test_at_segment_with_name(self) -> None: + segments: list[dict[str, Any]] = [ + {"type": "at", "data": {"qq": "123456", "name": "Bob"}} + ] + assert extract_text(segments) == "[@123456(Bob)]" + + def test_face_segment(self) -> None: + segments: list[dict[str, Any]] = [{"type": "face", "data": {}}] + assert extract_text(segments) == "[表情]" + + def test_image_segment(self) -> None: + segments: list[dict[str, Any]] = [{"type": "image", "data": {"file": "a.png"}}] + assert extract_text(segments) == "[图片: a.png]" + + def test_file_segment(self) -> None: + segments: list[dict[str, Any]] = [{"type": "file", "data": {"file": "doc.pdf"}}] + assert extract_text(segments) == "[文件: doc.pdf]" + + def test_video_segment(self) -> None: + segments: list[dict[str, Any]] = [{"type": "video", "data": {"file": "v.mp4"}}] + assert extract_text(segments) == "[视频: v.mp4]" + + def test_record_segment(self) -> None: + segments: list[dict[str, Any]] = [{"type": "record", "data": {"file": "r.amr"}}] + assert extract_text(segments) == "[语音: r.amr]" + + def test_audio_segment(self) -> None: + segments: list[dict[str, Any]] = [{"type": "audio", "data": {"file": "a.mp3"}}] + assert extract_text(segments) == "[音频: a.mp3]" + + def test_forward_segment_with_id(self) -> None: + segments: list[dict[str, Any]] = 
[{"type": "forward", "data": {"id": "fw123"}}] + assert extract_text(segments) == "[合并转发: fw123]" + + def test_forward_segment_without_id(self) -> None: + segments: list[dict[str, Any]] = [{"type": "forward", "data": {}}] + assert extract_text(segments) == "[合并转发]" + + def test_reply_segment_with_id(self) -> None: + segments: list[dict[str, Any]] = [{"type": "reply", "data": {"id": "42"}}] + assert extract_text(segments) == "[引用: 42]" + + def test_reply_segment_without_id(self) -> None: + segments: list[dict[str, Any]] = [{"type": "reply", "data": {}}] + assert extract_text(segments) == "[引用]" + + def test_unknown_segment_skipped(self) -> None: + segments: list[dict[str, Any]] = [ + {"type": "unknown_custom", "data": {}}, + {"type": "text", "data": {"text": "ok"}}, + ] + assert extract_text(segments) == "ok" + + def test_empty_segments(self) -> None: + assert extract_text([]) == "" + + def test_mixed_segments(self) -> None: + segments: list[dict[str, Any]] = [ + {"type": "text", "data": {"text": "hi "}}, + {"type": "face", "data": {}}, + {"type": "text", "data": {"text": " bye"}}, + ] + assert extract_text(segments) == "hi [表情] bye" + + def test_data_not_dict_fallback(self) -> None: + """If data is not a dict, segment should still be handled safely.""" + segments: list[dict[str, Any]] = [ + {"type": "text", "data": "not_a_dict"}, + ] + # data becomes {}, so text = "" + assert extract_text(segments) == "" + + +# --------------------------------------------------------------------------- +# process_at_mentions +# --------------------------------------------------------------------------- + + +class TestProcessAtMentions: + def test_basic_at(self) -> None: + assert process_at_mentions("[@123456]") == "[CQ:at,qq=123456]" + + def test_at_with_braces(self) -> None: + assert process_at_mentions("[@{123456}]") == "[CQ:at,qq=123456]" + + def test_multiple_ats(self) -> None: + result = process_at_mentions("[@11111] hi [@22222]") + assert result == "[CQ:at,qq=11111] hi 
[CQ:at,qq=22222]" + + def test_escaped_brackets(self) -> None: + result = process_at_mentions("\\[@123456\\]") + assert result == "[@123456]" + + def test_no_match(self) -> None: + assert process_at_mentions("hello world") == "hello world" + + +# --------------------------------------------------------------------------- +# message_to_segments +# --------------------------------------------------------------------------- + + +class TestMessageToSegments: + def test_plain_text_only(self) -> None: + segs = message_to_segments("hello world") + assert segs == [{"type": "text", "data": {"text": "hello world"}}] + + def test_cq_at(self) -> None: + segs = message_to_segments("[CQ:at,qq=123]") + assert segs == [{"type": "at", "data": {"qq": "123"}}] + + def test_text_around_cq(self) -> None: + segs = message_to_segments("hi [CQ:face,id=178] bye") + assert len(segs) == 3 + assert segs[0] == {"type": "text", "data": {"text": "hi "}} + assert segs[1] == {"type": "face", "data": {"id": "178"}} + assert segs[2] == {"type": "text", "data": {"text": " bye"}} + + def test_empty_string(self) -> None: + assert message_to_segments("") == [] + + def test_cq_without_args(self) -> None: + segs = message_to_segments("[CQ:face]") + assert segs == [{"type": "face", "data": {}}] + + +# --------------------------------------------------------------------------- +# matches_xinliweiyuan +# --------------------------------------------------------------------------- + + +class TestMatchesXinliweiyuan: + def test_exact_keyword(self) -> None: + assert matches_xinliweiyuan("心理委员") is True + + def test_keyword_with_prefix(self) -> None: + assert matches_xinliweiyuan("找心理委员") is True + + def test_keyword_with_suffix(self) -> None: + assert matches_xinliweiyuan("心理委员在吗") is True + + def test_keyword_both_sides_fails(self) -> None: + assert matches_xinliweiyuan("我找心理委员吧") is False + + def test_no_keyword(self) -> None: + assert matches_xinliweiyuan("你好世界") is False + + def 
test_too_many_extra_chars(self) -> None: + assert matches_xinliweiyuan("abcdef心理委员") is False + + def test_punctuation_not_counted(self) -> None: + # Punctuation is removed before counting + assert matches_xinliweiyuan("!!心理委员") is True + + def test_five_chars_suffix(self) -> None: + assert matches_xinliweiyuan("心理委员abcde") is True + + def test_six_chars_suffix(self) -> None: + assert matches_xinliweiyuan("心理委员abcdef") is False + + +# --------------------------------------------------------------------------- +# _normalize_message_content +# --------------------------------------------------------------------------- + + +class TestNormalizeMessageContent: + def test_list_of_dicts(self) -> None: + content: list[dict[str, Any]] = [{"type": "text", "data": {"text": "hi"}}] + result = _normalize_message_content(content) + assert result == content + + def test_single_dict(self) -> None: + seg: dict[str, Any] = {"type": "text", "data": {"text": "hi"}} + result = _normalize_message_content(seg) + assert result == [seg] + + def test_string(self) -> None: + result = _normalize_message_content("hello [CQ:face]") + assert len(result) == 2 + assert result[0]["type"] == "text" + assert result[1]["type"] == "face" + + def test_list_with_string_items(self) -> None: + result = _normalize_message_content(["hello"]) + assert result == [{"type": "text", "data": {"text": "hello"}}] + + def test_unsupported_type_returns_empty(self) -> None: + result = _normalize_message_content(12345) + assert result == [] + + +# --------------------------------------------------------------------------- +# _format_forward_node_time +# --------------------------------------------------------------------------- + + +class TestFormatForwardNodeTime: + def test_valid_timestamp(self) -> None: + result = _format_forward_node_time(1700000000) + assert "2023" in result + + def test_millisecond_timestamp(self) -> None: + result = _format_forward_node_time(1700000000000) + assert "2023" in result + + def 
test_zero_returns_empty(self) -> None: + assert _format_forward_node_time(0) == "" + + def test_none_returns_empty(self) -> None: + assert _format_forward_node_time(None) == "" + + def test_empty_string_returns_empty(self) -> None: + assert _format_forward_node_time("") == "" + + def test_invalid_string_returns_as_is(self) -> None: + assert _format_forward_node_time("not_a_time") == "not_a_time" + + +# --------------------------------------------------------------------------- +# _truncate_forward_text +# --------------------------------------------------------------------------- + + +class TestTruncateForwardText: + def test_short_text_not_truncated(self) -> None: + text = "hello" + assert _truncate_forward_text(text) == text + + def test_long_text_truncated(self) -> None: + text = "a" * (FORWARD_EXPAND_MAX_CHARS + 100) + result = _truncate_forward_text(text) + assert "[合并转发内容过长,已截断]" in result + assert len(result) <= FORWARD_EXPAND_MAX_CHARS + 50 # marker included + + +# --------------------------------------------------------------------------- +# _parse_segment / _parse_at_segment / _parse_media_segment +# --------------------------------------------------------------------------- + + +class TestParseHelpers: + def test_parse_at_segment_with_nickname(self) -> None: + result = _parse_at_segment({"qq": "999", "nickname": "Nick"}, bot_qq=0) + assert result == "[@999(Nick)]" + + def test_parse_at_segment_no_name(self) -> None: + result = _parse_at_segment({"qq": "999"}, bot_qq=0) + assert result == "[@999]" + + def test_parse_media_segment_image(self) -> None: + result = _parse_media_segment("image", {"file": "pic.jpg"}) + assert result == "[图片: pic.jpg]" + + def test_parse_media_segment_unknown(self) -> None: + result = _parse_media_segment("custom_type", {}) + assert result is None + + def test_parse_segment_missing_type(self) -> None: + seg: dict[str, Any] = {"data": {"text": "hello"}} + # type="" → falls through to _parse_media_segment → None + result = 
_parse_segment(seg) + assert result is None diff --git a/tests/test_xml_utils.py b/tests/test_xml_utils.py new file mode 100644 index 0000000..fbb89a0 --- /dev/null +++ b/tests/test_xml_utils.py @@ -0,0 +1,100 @@ +"""Tests for Undefined.utils.xml — XML escaping helpers.""" + +from __future__ import annotations + +from Undefined.utils.xml import escape_xml_attr, escape_xml_text + + +class TestEscapeXmlText: + def test_plain_text(self) -> None: + assert escape_xml_text("hello world") == "hello world" + + def test_ampersand(self) -> None: + assert escape_xml_text("a & b") == "a & b" + + def test_less_than(self) -> None: + assert escape_xml_text("a < b") == "a < b" + + def test_greater_than(self) -> None: + assert escape_xml_text("a > b") == "a > b" + + def test_double_quote(self) -> None: + assert escape_xml_text('say "hello"') == "say "hello"" + + def test_single_quote(self) -> None: + assert escape_xml_text("it's") == "it's" + + def test_all_special_chars(self) -> None: + result = escape_xml_text("""""") + assert "<" in result + assert ">" in result + assert "&" in result + assert """ in result + assert "'" in result + + def test_empty_string(self) -> None: + assert escape_xml_text("") == "" + + def test_unicode(self) -> None: + assert escape_xml_text("こんにちは") == "こんにちは" + + def test_unicode_with_special(self) -> None: + assert escape_xml_text("价格 < 100 & > 50") == "价格 < 100 & > 50" + + def test_nested_quotes(self) -> None: + result = escape_xml_text("""He said "it's fine" """) + assert """ in result + assert "'" in result + + def test_multiline(self) -> None: + text = "line1\nline2\n" + result = escape_xml_text(text) + assert "\n" in result + assert "<tag>" in result + + def test_already_escaped(self) -> None: + result = escape_xml_text("&") + assert result == "&amp;" + + +class TestEscapeXmlAttr: + def test_plain_string(self) -> None: + assert escape_xml_attr("hello") == "hello" + + def test_special_chars(self) -> None: + result = escape_xml_attr('') + assert "<" 
in result + assert "&" in result + assert """ in result + assert ">" in result + + def test_none_input(self) -> None: + assert escape_xml_attr(None) == "" + + def test_integer_input(self) -> None: + assert escape_xml_attr(42) == "42" + + def test_float_input(self) -> None: + assert escape_xml_attr(3.14) == "3.14" + + def test_bool_input(self) -> None: + assert escape_xml_attr(True) == "True" + assert escape_xml_attr(False) == "False" + + def test_empty_string(self) -> None: + assert escape_xml_attr("") == "" + + def test_object_with_str(self) -> None: + class Obj: + def __str__(self) -> str: + return '' + + result = escape_xml_attr(Obj()) + assert "<script>" in result + assert """ in result + + def test_unicode(self) -> None: + assert escape_xml_attr("日本語") == "日本語" + + def test_zero(self) -> None: + assert escape_xml_attr(0) == "0" From 3ccd2ab9923648294a3fd36cf9406221984aa6f9 Mon Sep 17 00:00:00 2001 From: Null <1708213363@qq.com> Date: Sun, 19 Apr 2026 10:48:54 +0800 Subject: [PATCH 24/57] =?UTF-8?q?feat(help):=20=E4=B8=BA=20/help=20?= =?UTF-8?q?=E6=B7=BB=E5=8A=A0=20/h=20=E5=88=AB=E5=90=8D?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> Co-authored-by: Claude Opus 4.6 --- src/Undefined/skills/commands/help/config.json | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Undefined/skills/commands/help/config.json b/src/Undefined/skills/commands/help/config.json index 700dcdc..8639fb7 100644 --- a/src/Undefined/skills/commands/help/config.json +++ b/src/Undefined/skills/commands/help/config.json @@ -12,7 +12,7 @@ "show_in_help": true, "order": 10, "allow_in_private": true, - "aliases": [], + "aliases": ["h"], "help_footer": [ "查看详细帮助:/help ", "详细版权与免责声明:/cprt", From d6076be8891ba31b7f80d539727c9dcbd1c3923e Mon Sep 17 00:00:00 2001 From: Null <1708213363@qq.com> Date: Sun, 19 Apr 2026 11:27:20 +0800 Subject: [PATCH 25/57] 
=?UTF-8?q?feat:=20profile=E8=B6=85=E7=AE=A1?= =?UTF-8?q?=E6=9F=A5=E8=AF=A2=20+=20utils/coerce=E5=85=AC=E5=85=B1?= =?UTF-8?q?=E6=A8=A1=E5=9D=97=20+=20WebUI=E5=AE=89=E5=85=A8=E5=A2=9E?= =?UTF-8?q?=E5=BC=BA?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - feat(profile): 超级管理员可指定目标 /p 或 /p g <群号> - refactor(utils): 统一 safe_int/safe_float 到 utils/coerce.py,替换8处重复 - feat(webui): 全局JS错误处理 window.onerror + toast - feat(webui): AbortController 请求取消,Tab切换时终止旧请求 - test: 新增6个profile超管指定目标测试 Co-authored-by: Claude Opus 4.6 Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- src/Undefined/ai/multimodal.py | 14 +- src/Undefined/ai/prompts.py | 24 +-- src/Undefined/ai/retrieval.py | 19 +-- src/Undefined/cognitive/service.py | 20 +-- src/Undefined/cognitive/vector_store.py | 27 +--- src/Undefined/handlers.py | 15 +- src/Undefined/memes/service.py | 13 +- .../skills/commands/profile/handler.py | 31 ++-- src/Undefined/skills/tools/end/handler.py | 17 +- src/Undefined/utils/coerce.py | 37 +++++ src/Undefined/webui/static/js/api.js | 29 ++++ src/Undefined/webui/static/js/main.js | 32 ++++ tests/test_profile_command.py | 148 +++++++++++++++++- 13 files changed, 310 insertions(+), 116 deletions(-) create mode 100644 src/Undefined/utils/coerce.py diff --git a/src/Undefined/ai/multimodal.py b/src/Undefined/ai/multimodal.py index 3391a6e..f5f1ab2 100644 --- a/src/Undefined/ai/multimodal.py +++ b/src/Undefined/ai/multimodal.py @@ -16,6 +16,7 @@ import httpx from Undefined.ai.parsing import extract_choices_content +from Undefined.utils.coerce import safe_float from Undefined.ai.llm import ModelRequester from Undefined.config import VisionModelConfig from Undefined.ai.transports import API_MODE_CHAT_COMPLETIONS, get_api_mode @@ -352,19 +353,12 @@ def _parse_meme_analysis_response(content: str) -> dict[str, Any]: parsed = _extract_json_object(content) return { "is_meme": bool(parsed.get("is_meme", False)), - "confidence": 
_safe_float(parsed.get("confidence", 0.0), default=0.0), + "confidence": safe_float(parsed.get("confidence", 0.0), default=0.0), "description": str(parsed.get("description") or "").strip(), "tags": _normalize_meme_tags(parsed.get("tags")), } -def _safe_float(value: Any, default: float = 0.0) -> float: - try: - return float(value) - except (TypeError, ValueError): - return default - - class MultimodalAnalyzer: """多模态媒体分析器。 @@ -880,7 +874,7 @@ async def judge_meme_image(self, image_url: str | list[str]) -> dict[str, Any]: try: parsed = { "is_meme": bool(args.get("is_meme", False)), - "confidence": _safe_float(args.get("confidence", 0.0), default=0.0), + "confidence": safe_float(args.get("confidence", 0.0), default=0.0), "reason": str(args.get("reason") or "").strip(), } except Exception: @@ -889,7 +883,7 @@ async def judge_meme_image(self, image_url: str | list[str]) -> dict[str, Any]: "[媒体分析] 表情包判定完成: url=%s is_meme=%s confidence=%.3f reason=%s", safe_url[:50], parsed.get("is_meme", False), - _safe_float(parsed.get("confidence", 0.0), default=0.0), + safe_float(parsed.get("confidence", 0.0), default=0.0), str(parsed.get("reason", ""))[:80], ) return parsed diff --git a/src/Undefined/ai/prompts.py b/src/Undefined/ai/prompts.py index 4fc9d81..dda6b04 100644 --- a/src/Undefined/ai/prompts.py +++ b/src/Undefined/ai/prompts.py @@ -12,6 +12,7 @@ import aiofiles from Undefined.attachments import attachment_refs_to_xml +from Undefined.utils.coerce import safe_int from Undefined.context import RequestContext from Undefined.end_summary_storage import ( EndSummaryStorage, @@ -636,39 +637,24 @@ def _resolve_chat_scope( ) -> tuple[Literal["group", "private"], int] | None: ctx = RequestContext.current() - def _safe_int(value: Any) -> int | None: - if isinstance(value, bool): - return None - if isinstance(value, int): - return value - if isinstance(value, str): - text = value.strip() - if not text: - return None - try: - return int(text) - except ValueError: - return None - return 
None - if ctx and ctx.request_type == "group" and ctx.group_id is not None: - group_id = _safe_int(ctx.group_id) + group_id = safe_int(ctx.group_id) if group_id is not None: return ("group", group_id) return None if ctx and ctx.request_type == "private" and ctx.user_id is not None: - user_id = _safe_int(ctx.user_id) + user_id = safe_int(ctx.user_id) if user_id is not None: return ("private", user_id) return None if extra_context and extra_context.get("group_id") is not None: - group_id = _safe_int(extra_context.get("group_id")) + group_id = safe_int(extra_context.get("group_id")) if group_id is not None: return ("group", group_id) return None if extra_context and extra_context.get("user_id") is not None: - user_id = _safe_int(extra_context.get("user_id")) + user_id = safe_int(extra_context.get("user_id")) if user_id is not None: return ("private", user_id) return None diff --git a/src/Undefined/ai/retrieval.py b/src/Undefined/ai/retrieval.py index 82c42d4..5fe8337 100644 --- a/src/Undefined/ai/retrieval.py +++ b/src/Undefined/ai/retrieval.py @@ -10,6 +10,7 @@ from openai import NOT_GIVEN, AsyncOpenAI from Undefined.ai.tokens import TokenCounter +from Undefined.utils.coerce import safe_int from Undefined.config import EmbeddingModelConfig, RerankModelConfig from Undefined.utils.request_params import split_reserved_request_params @@ -224,13 +225,13 @@ def _extract_usage(self, response_dict: dict[str, Any]) -> tuple[int, int, int]: usage = response_dict.get("usage", {}) or {} if not isinstance(usage, dict): usage = {} - prompt_tokens = self._safe_int( - usage.get("prompt_tokens", usage.get("input_tokens", 0)) + prompt_tokens = safe_int( + usage.get("prompt_tokens", usage.get("input_tokens", 0)), 0 ) - completion_tokens = self._safe_int( - usage.get("completion_tokens", usage.get("output_tokens", 0)) + completion_tokens = safe_int( + usage.get("completion_tokens", usage.get("output_tokens", 0)), 0 ) - total_tokens = self._safe_int(usage.get("total_tokens", 0)) + 
total_tokens = safe_int(usage.get("total_tokens", 0), 0) if total_tokens <= 0 and (prompt_tokens > 0 or completion_tokens > 0): total_tokens = prompt_tokens + completion_tokens return prompt_tokens, completion_tokens, total_tokens @@ -275,7 +276,7 @@ def _normalize_rerank_results( for idx, item in enumerate(raw_results): if not isinstance(item, dict): continue - doc_index = self._safe_int(item.get("index", idx)) + doc_index = safe_int(item.get("index", idx), 0) if doc_index < 0: continue @@ -322,9 +323,3 @@ def _normalize_rerank_results( } for i in range(limit) ] - - def _safe_int(self, value: Any) -> int: - try: - return int(value or 0) - except (TypeError, ValueError): - return 0 diff --git a/src/Undefined/cognitive/service.py b/src/Undefined/cognitive/service.py index 9b7405e..1febe85 100644 --- a/src/Undefined/cognitive/service.py +++ b/src/Undefined/cognitive/service.py @@ -9,6 +9,7 @@ from typing import TYPE_CHECKING, Any, Callable, cast from Undefined.context import RequestContext +from Undefined.utils.coerce import safe_float logger = logging.getLogger(__name__) @@ -39,17 +40,6 @@ def _compose_where(clauses: list[dict[str, Any]]) -> dict[str, Any] | None: return {"$and": clauses} -def _safe_float(value: Any, default: float = 0.0) -> float: - if isinstance(value, (int, float)): - return float(value) - if isinstance(value, str): - try: - return float(value.strip()) - except Exception: - return default - return default - - def _event_base_score(item: dict[str, Any]) -> float: rerank_score = item.get("rerank_score") if isinstance(rerank_score, (int, float)): @@ -59,7 +49,7 @@ def _event_base_score(item: dict[str, Any]) -> float: return max(0.0, float(rerank_score.strip())) except Exception: pass - similarity = 1.0 - _safe_float(item.get("distance"), default=1.0) + similarity = 1.0 - safe_float(item.get("distance"), default=1.0) if similarity < 0.0: return 0.0 if similarity > 1.0: @@ -336,7 +326,7 @@ def _merge_weighted_events( ] = [] serial = 0 for 
scoped_events, scope_weight in scoped_results: - safe_scope_weight = max(0.0, _safe_float(scope_weight, default=1.0)) + safe_scope_weight = max(0.0, safe_float(scope_weight, default=1.0)) scope_size = max(1, len(scoped_events)) for rank_idx, event in enumerate(scoped_events): dedupe_key = _event_dedupe_key(event) @@ -399,12 +389,12 @@ async def _query_events_for_auto_context( if scope_candidate_multiplier <= 0: scope_candidate_multiplier = 2 scoped_top_k = max(safe_top_k, safe_top_k * scope_candidate_multiplier) - current_group_boost = _safe_float( + current_group_boost = safe_float( getattr(config, "auto_current_group_boost", 1.15), default=1.15 ) if current_group_boost <= 0: current_group_boost = 1.15 - current_private_boost = _safe_float( + current_private_boost = safe_float( getattr(config, "auto_current_private_boost", 1.25), default=1.25 ) if current_private_boost <= 0: diff --git a/src/Undefined/cognitive/vector_store.py b/src/Undefined/cognitive/vector_store.py index b6d1c7c..0c84f5a 100644 --- a/src/Undefined/cognitive/vector_store.py +++ b/src/Undefined/cognitive/vector_store.py @@ -11,6 +11,8 @@ from typing import Any import chromadb + +from Undefined.utils.coerce import safe_float from chromadb.errors import InternalError as ChromaInternalError import numpy as np from numba import njit @@ -32,17 +34,6 @@ def _clamp(value: float, lower: float, upper: float) -> float: return value -def _safe_float(value: Any, default: float = 0.0) -> float: - if isinstance(value, (int, float)): - return float(value) - if isinstance(value, str): - try: - return float(value.strip()) - except Exception: - return default - return default - - def _safe_positive_int(value: Any, default: int) -> int: try: parsed = int(value) @@ -120,7 +111,7 @@ def _sanitize_metadata(metadata: dict[str, Any]) -> dict[str, Any]: def _similarity_from_distance(distance: Any) -> float: - dist = _safe_float(distance, default=1.0) + dist = safe_float(distance, default=1.0) return _clamp(1.0 - dist, 
0.0, 1.0) @@ -553,14 +544,14 @@ def _q() -> Any: else: reranked_results: list[dict[str, Any]] = [] for item in reranked[:rerank_top_n]: - index = int(_safe_float(item.get("index"), default=-1)) + index = int(safe_float(item.get("index"), default=-1)) if index < 0 or index >= len(results): continue entry: dict[str, Any] = { "document": item.get("document", results[index]["document"]), "metadata": results[index]["metadata"], "distance": results[index]["distance"], - "rerank_score": _safe_float( + "rerank_score": safe_float( item.get("relevance_score"), default=0.0 ), } @@ -637,11 +628,9 @@ def _apply_time_decay_ranking( collection_name: str, ) -> list[dict[str, Any]]: safe_top_k = max(1, int(top_k)) - safe_half_life_days = _safe_float(half_life_days, default=14.0) - safe_boost = max(0.0, _safe_float(boost, default=0.2)) - safe_min_similarity = _clamp( - _safe_float(min_similarity, default=0.35), 0.0, 1.0 - ) + safe_half_life_days = safe_float(half_life_days, default=14.0) + safe_boost = max(0.0, safe_float(boost, default=0.2)) + safe_min_similarity = _clamp(safe_float(min_similarity, default=0.35), 0.0, 1.0) if safe_half_life_days <= 0: logger.warning( "[认知向量库] 时间衰减参数非法,跳过时间加权: collection=%s half_life_days=%s", diff --git a/src/Undefined/handlers.py b/src/Undefined/handlers.py index a42899f..4bd21a8 100644 --- a/src/Undefined/handlers.py +++ b/src/Undefined/handlers.py @@ -1,5 +1,7 @@ """消息处理和命令分发""" +from __future__ import annotations + import asyncio from dataclasses import dataclass import logging @@ -40,6 +42,7 @@ from Undefined.scheduled_task_storage import ScheduledTaskStorage from Undefined.utils.logging import log_debug_json, redact_string +from Undefined.utils.coerce import safe_int logger = logging.getLogger(__name__) @@ -47,14 +50,6 @@ REPEAT_REPLY_HISTORY_PREFIX = "[系统复读] " -def _safe_int(value: Any) -> int | None: - try: - parsed = int(value) - except (TypeError, ValueError): - return None - return parsed if parsed > 0 else None - - def 
_format_poke_history_text(display_name: str, user_id: int) -> str: """格式化拍一拍历史文本。""" return f"{display_name}(暱称)[{user_id}(QQ号)] 拍了拍你。" @@ -553,7 +548,7 @@ async def handle_message(self, event: dict[str, Any]) -> None: chat_type="private", chat_id=private_sender_id, sender_id=private_sender_id, - message_id=_safe_int(trigger_message_id), + message_id=safe_int(trigger_message_id), scope_key=build_attachment_scope( user_id=private_sender_id, request_type="private", @@ -727,7 +722,7 @@ async def handle_message(self, event: dict[str, Any]) -> None: chat_type="group", chat_id=group_id, sender_id=sender_id, - message_id=_safe_int(trigger_message_id), + message_id=safe_int(trigger_message_id), scope_key=build_attachment_scope(group_id=group_id, request_type="group"), ) diff --git a/src/Undefined/memes/service.py b/src/Undefined/memes/service.py index 851b6a6..7712232 100644 --- a/src/Undefined/memes/service.py +++ b/src/Undefined/memes/service.py @@ -29,6 +29,7 @@ from Undefined.memes.store import MemeStore from Undefined.memes.vector_store import MemeVectorStore from Undefined.utils.message_targets import resolve_message_target +from Undefined.utils.coerce import safe_int from Undefined.utils.paths import ensure_dir logger = logging.getLogger(__name__) @@ -48,16 +49,6 @@ def _now_iso() -> str: return datetime.now().isoformat(timespec="seconds") -def _safe_int(value: Any) -> int | None: - if value is None: - return None - try: - parsed = int(value) - except (TypeError, ValueError): - return None - return parsed if parsed > 0 else None - - def _guess_suffix(path: Path, mime_type: str) -> str: suffix = path.suffix.lower() if suffix: @@ -748,7 +739,7 @@ async def send_meme_by_uid(self, uid: str, context: dict[str, Any]) -> str: attachments=history_attachments, ) else: - preferred_temp_group_id = _safe_int(context.get("group_id")) + preferred_temp_group_id = safe_int(context.get("group_id")) sent_message_id = await sender.send_private_message( int(target_id), cq_message, diff 
--git a/src/Undefined/skills/commands/profile/handler.py b/src/Undefined/skills/commands/profile/handler.py index 4040ab2..f3d94f6 100644 --- a/src/Undefined/skills/commands/profile/handler.py +++ b/src/Undefined/skills/commands/profile/handler.py @@ -29,10 +29,14 @@ def _truncate(text: str, limit: int = _MAX_PROFILE_LENGTH) -> str: return text[:limit].rstrip() + "\n\n[侧写过长,已截断]" -def _parse_args(args: list[str]) -> tuple[str, str]: - """解析参数,返回 (子命令, 输出模式)。""" +def _parse_args(args: list[str]) -> tuple[str, str, str]: + """解析参数,返回 (子命令, 输出模式, 目标ID)。 + + 目标 ID 为纯数字参数,仅超级管理员可使用。 + """ sub = "" mode = "" + target = "" for arg in args: lower = arg.lower().strip() if lower in ("-t", "--text"): @@ -43,7 +47,9 @@ def _parse_args(args: list[str]) -> tuple[str, str]: mode = _MODE_RENDER elif lower in ("g", "group"): sub = lower - return sub, mode + elif arg.strip().isdigit(): + target = arg.strip() + return sub, mode, target def _profile_mtime(entity_type: str, entity_id: str) -> str | None: @@ -181,29 +187,36 @@ async def _send_render( async def execute(args: list[str], context: CommandContext) -> None: """处理 /profile 命令。 - 用法: /p [g] [-t|--text] [-f|--forward] [-r|--render] + 用法: /p [g] [-t|--text] [-f|--forward] [-r|--render] [目标ID] g / group 查看群聊侧写(仅群聊可用) -t / --text 纯文本直接发出 -f / --forward 合并转发发出(群聊默认) -r / --render 渲染为图片发出 + 目标ID 指定查询对象(仅超级管理员) """ cognitive_service = context.cognitive_service if cognitive_service is None: await _send_text(context, "❌ 侧写服务未启用") return - sub, mode = _parse_args(args) + sub, mode, target = _parse_args(args) + + # 超管指定目标 + if target: + if not context.config.is_superadmin(context.sender_id): + await _send_text(context, "❌ 仅超级管理员可查看他人侧写") + return if sub in ("group", "g"): - if _is_private(context): - await _send_text(context, "❌ 私聊中不支持查看群聊侧写") + if _is_private(context) and not target: + await _send_text(context, "❌ 私聊中不支持查看群聊侧写(可指定群号)") return entity_type = "group" - entity_id = str(context.group_id) + entity_id = target or 
str(context.group_id) empty_hint = "暂无群聊侧写数据" else: entity_type = "user" - entity_id = str(context.sender_id) + entity_id = target or str(context.sender_id) empty_hint = "暂无侧写数据" profile = await cognitive_service.get_profile(entity_type, entity_id) diff --git a/src/Undefined/skills/tools/end/handler.py b/src/Undefined/skills/tools/end/handler.py index d122ae2..8afb076 100644 --- a/src/Undefined/skills/tools/end/handler.py +++ b/src/Undefined/skills/tools/end/handler.py @@ -1,9 +1,13 @@ +from __future__ import annotations + from collections import deque from typing import Any, Dict import logging import re from Undefined.context import RequestContext +from Undefined.utils.coerce import safe_int + from Undefined.end_summary_storage import ( EndSummaryLocation, EndSummaryRecord, @@ -80,13 +84,6 @@ def _clip_text(value: Any, max_len: int) -> str: return text[: max_len - 3].rstrip() + "..." -def _safe_int(value: Any, default: int) -> int: - try: - return int(value) - except (TypeError, ValueError): - return default - - def _clamp_int(value: int, min_value: int, max_value: int) -> int: if value < min_value: return min_value @@ -103,15 +100,15 @@ def _resolve_historian_limits(context: Dict[str, Any]) -> tuple[int, int, int]: runtime_config = context.get("runtime_config") cognitive = getattr(runtime_config, "cognitive", None) if runtime_config else None if cognitive is not None: - max_source_len = _safe_int( + max_source_len = safe_int( getattr(cognitive, "historian_source_message_max_len", max_source_len), max_source_len, ) - recent_k = _safe_int( + recent_k = safe_int( getattr(cognitive, "historian_recent_messages_inject_k", recent_k), recent_k, ) - max_recent_line_len = _safe_int( + max_recent_line_len = safe_int( getattr( cognitive, "historian_recent_message_line_max_len", max_recent_line_len ), diff --git a/src/Undefined/utils/coerce.py b/src/Undefined/utils/coerce.py new file mode 100644 index 0000000..20c81ef --- /dev/null +++ b/src/Undefined/utils/coerce.py @@ -0,0 
+1,37 @@ +"""Type-safe coercion helpers shared across the codebase.""" + +from __future__ import annotations + +from typing import Any, overload + + +@overload +def safe_int(value: Any) -> int | None: ... + + +@overload +def safe_int(value: Any, default: int) -> int: ... + + +@overload +def safe_int(value: Any, default: None) -> int | None: ... + + +def safe_int(value: Any, default: int | None = None) -> int | None: + """Safely convert *value* to int, returning *default* on failure.""" + if value is None: + return default + try: + return int(value) + except (TypeError, ValueError): + return default + + +def safe_float(value: Any, default: float = 0.0) -> float: + """Safely convert *value* to float, returning *default* on failure.""" + if value is None: + return default + try: + return float(value) + except (TypeError, ValueError): + return default diff --git a/src/Undefined/webui/static/js/api.js b/src/Undefined/webui/static/js/api.js index 82b427c..ed528e1 100644 --- a/src/Undefined/webui/static/js/api.js +++ b/src/Undefined/webui/static/js/api.js @@ -1,3 +1,31 @@ +// Active request controllers for cancellation on tab switch +const _activeControllers = new Map(); + +function abortPendingRequests(kind) { + if (kind) { + const controller = _activeControllers.get(kind); + if (controller) { + controller.abort(); + _activeControllers.delete(kind); + } + } else { + for (const controller of _activeControllers.values()) { + controller.abort(); + } + _activeControllers.clear(); + } +} + +function getAbortSignal(kind) { + if (kind) { + abortPendingRequests(kind); + const controller = new AbortController(); + _activeControllers.set(kind, controller); + return controller.signal; + } + return undefined; +} + const AUTH_ENDPOINTS = { login: [ "/api/v1/management/auth/login", @@ -47,6 +75,7 @@ async function requestOnce(path, options = {}) { ...options, headers, credentials: options.credentials || "same-origin", + signal: options.signal || undefined, }); } diff --git 
a/src/Undefined/webui/static/js/main.js b/src/Undefined/webui/static/js/main.js index baf380c..fd724c9 100644 --- a/src/Undefined/webui/static/js/main.js +++ b/src/Undefined/webui/static/js/main.js @@ -54,6 +54,7 @@ function refreshUI() { } function switchTab(tab) { + abortPendingRequests(); // Cancel pending requests from previous tab state.tab = tab; state.mobileDrawerOpen = false; const mainContent = document.querySelector(".main-content"); @@ -176,6 +177,37 @@ function setMobileInlineActionsOpen(key, open) { } async function init() { + // Global error handlers + window.onerror = function (message, source, lineno, colno, error) { + console.error("[GlobalError]", { + message, + source, + lineno, + colno, + error, + }); + if (typeof showToast === "function") { + showToast(`⚠️ ${message}`, "error", 5000); + } + return false; + }; + + window.onunhandledrejection = function (event) { + const reason = event.reason; + const msg = reason instanceof Error ? reason.message : String(reason); + // Don't toast for routine auth errors or aborted requests + if ( + msg === "Unauthorized" || + msg === "The user aborted a request." 
|| + reason?.name === "AbortError" + ) + return; + console.error("[UnhandledRejection]", reason); + if (typeof showToast === "function") { + showToast(`⚠️ ${msg}`, "error", 5000); + } + }; + if ( window.RuntimeController && typeof window.RuntimeController.init === "function" diff --git a/tests/test_profile_command.py b/tests/test_profile_command.py index fe7a105..b00cd7b 100644 --- a/tests/test_profile_command.py +++ b/tests/test_profile_command.py @@ -41,14 +41,18 @@ def _build_context( group_id: int = 123456, sender_id: int = 10002, user_id: int | None = None, + superadmin_qq: int = 0, ) -> CommandContext: + config_stub = cast(Any, SimpleNamespace()) + config_stub.is_superadmin = lambda qq: qq == superadmin_qq + config_stub.bot_qq = 0 stub = cast(Any, SimpleNamespace()) if sender is None: sender = _DummySender() return CommandContext( group_id=group_id, sender_id=sender_id, - config=stub, + config=config_stub, sender=cast(Any, sender), ai=stub, faq_storage=stub, @@ -294,3 +298,145 @@ async def test_profile_truncation() -> None: assert len(message) <= 5100 # 5000 + truncation notice assert "[侧写过长,已截断]" in message assert message.count("A") == 5000 # Exactly 5000 'A's before truncation + + +# -- Superadmin target tests -- + + +@pytest.mark.asyncio +async def test_profile_superadmin_target_user() -> None: + """Superadmin can query another user's profile with /p .""" + sender = _DummySender() + cognitive_service = AsyncMock() + cognitive_service.get_profile = AsyncMock(return_value="目标用户侧写") + + context = _build_context( + sender=sender, + cognitive_service=cognitive_service, + scope="group", + group_id=123456, + sender_id=10001, + superadmin_qq=10001, + ) + + await profile_execute(["99999"], context) + + assert len(sender.group_messages) == 1 + assert "目标用户侧写" in sender.group_messages[0][1] + cognitive_service.get_profile.assert_called_once_with("user", "99999") + + +@pytest.mark.asyncio +async def test_profile_superadmin_target_group() -> None: + """Superadmin can 
query a group profile with /p g <群号>.""" + sender = _DummySender() + cognitive_service = AsyncMock() + cognitive_service.get_profile = AsyncMock(return_value="目标群侧写") + + context = _build_context( + sender=sender, + cognitive_service=cognitive_service, + scope="group", + group_id=123456, + sender_id=10001, + superadmin_qq=10001, + ) + + await profile_execute(["g", "789000"], context) + + assert len(sender.group_messages) == 1 + assert "目标群侧写" in sender.group_messages[0][1] + cognitive_service.get_profile.assert_called_once_with("group", "789000") + + +@pytest.mark.asyncio +async def test_profile_nonadmin_target_rejected() -> None: + """Non-superadmin cannot specify a target QQ → permission error.""" + sender = _DummySender() + cognitive_service = AsyncMock() + + context = _build_context( + sender=sender, + cognitive_service=cognitive_service, + scope="group", + group_id=123456, + sender_id=22222, + superadmin_qq=10001, + ) + + await profile_execute(["99999"], context) + + assert len(sender.group_messages) == 1 + assert "❌ 仅超级管理员可查看他人侧写" in sender.group_messages[0][1] + cognitive_service.get_profile.assert_not_called() + + +@pytest.mark.asyncio +async def test_profile_superadmin_target_with_mode() -> None: + """Superadmin with render mode + target: /p -r 12345.""" + sender = _DummySender() + cognitive_service = AsyncMock() + cognitive_service.get_profile = AsyncMock(return_value="带模式的侧写") + + context = _build_context( + sender=sender, + cognitive_service=cognitive_service, + scope="group", + group_id=123456, + sender_id=10001, + superadmin_qq=10001, + ) + + # render mode will fail (no Playwright) → fallback to text + await profile_execute(["-t", "12345"], context) + + assert len(sender.group_messages) == 1 + assert "带模式的侧写" in sender.group_messages[0][1] + cognitive_service.get_profile.assert_called_once_with("user", "12345") + + +@pytest.mark.asyncio +async def test_profile_superadmin_private_group_with_target() -> None: + """Superadmin in private chat can query a 
group with /p g <群号>.""" + sender = _DummySender() + cognitive_service = AsyncMock() + cognitive_service.get_profile = AsyncMock(return_value="远程群侧写") + + context = _build_context( + sender=sender, + cognitive_service=cognitive_service, + scope="private", + group_id=0, + sender_id=10001, + user_id=10001, + superadmin_qq=10001, + ) + + await profile_execute(["g", "654321"], context) + + # Private + group + target → still works for superadmin + assert len(sender.private_messages) == 1 + assert "远程群侧写" in sender.private_messages[0][1] + cognitive_service.get_profile.assert_called_once_with("group", "654321") + + +@pytest.mark.asyncio +async def test_profile_superadmin_target_not_found() -> None: + """Superadmin queries non-existent target → '暂无侧写数据'.""" + sender = _DummySender() + cognitive_service = AsyncMock() + cognitive_service.get_profile = AsyncMock(return_value=None) + + context = _build_context( + sender=sender, + cognitive_service=cognitive_service, + scope="group", + group_id=123456, + sender_id=10001, + superadmin_qq=10001, + ) + + await profile_execute(["11111"], context) + + assert len(sender.group_messages) == 1 + assert "📭 暂无侧写数据" in sender.group_messages[0][1] From 1ec1919428efff6d8a4fc0f96e671e9b6869b2af Mon Sep 17 00:00:00 2001 From: Null <1708213363@qq.com> Date: Sun, 19 Apr 2026 11:31:26 +0800 Subject: [PATCH 26/57] =?UTF-8?q?feat(webui):=20=E9=AA=A8=E6=9E=B6?= =?UTF-8?q?=E5=B1=8FCSS=20+=20=E6=97=A5=E5=BF=97=E6=97=B6=E9=97=B4?= =?UTF-8?q?=E8=BF=87=E6=BB=A4=20+=20=E8=B5=84=E6=BA=90=E8=B6=8B=E5=8A=BF?= =?UTF-8?q?=E5=9B=BE=20+=20TOML=E5=8E=9F=E5=A7=8B=E8=A7=86=E5=9B=BE?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 添加 .skeleton shimmer 动画 CSS 类 (components.css) - 日志页新增 datetime-local 时间范围过滤 (log-view.js, state.js) - 概览页新增 Canvas CPU/内存实时趋势图 (bot.js, 120点历史) - 配置页新增「查看 TOML」原始文本切换 (main.js, config toggle) - 新增 i18n 翻译项 (overview.chart, config.view_toml/form) Co-authored-by: Claude Opus 4.6 Co-authored-by: 
Copilot <223556219+Copilot@users.noreply.github.com> --- src/Undefined/webui/static/css/components.css | 17 +++ src/Undefined/webui/static/js/bot.js | 123 ++++++++++++++++++ src/Undefined/webui/static/js/i18n.js | 6 + src/Undefined/webui/static/js/log-view.js | 25 +++- src/Undefined/webui/static/js/main.js | 41 ++++++ src/Undefined/webui/static/js/state.js | 2 + src/Undefined/webui/templates/index.html | 17 +++ 7 files changed, 230 insertions(+), 1 deletion(-) diff --git a/src/Undefined/webui/static/css/components.css b/src/Undefined/webui/static/css/components.css index 4d4c7f4..08044e1 100644 --- a/src/Undefined/webui/static/css/components.css +++ b/src/Undefined/webui/static/css/components.css @@ -779,3 +779,20 @@ white-space: nowrap; border: 0; } + +/* Skeleton loading */ +@keyframes shimmer { + 0% { background-position: -400px 0; } + 100% { background-position: 400px 0; } +} +.skeleton { + background: linear-gradient(90deg, var(--bg-card) 25%, var(--bg-app) 50%, var(--bg-card) 75%); + background-size: 800px 100%; + animation: shimmer 1.5s infinite linear; + border-radius: var(--radius-sm); +} +.skeleton-text { height: 14px; margin-bottom: 10px; } +.skeleton-text.short { width: 60%; } +.skeleton-text.medium { width: 80%; } +.skeleton-block { height: 48px; margin-bottom: 12px; } +.skeleton-bar { height: 8px; border-radius: 999px; } diff --git a/src/Undefined/webui/static/js/bot.js b/src/Undefined/webui/static/js/bot.js index 4d07972..28f7808 100644 --- a/src/Undefined/webui/static/js/bot.js +++ b/src/Undefined/webui/static/js/bot.js @@ -1,3 +1,124 @@ +// Metrics history for time series chart +const METRICS_HISTORY_SIZE = 120; +const _metricsHistory = { cpu: [], memory: [], timestamps: [] }; + +function pushMetrics(cpuPercent, memPercent) { + const now = new Date(); + _metricsHistory.cpu.push(cpuPercent); + _metricsHistory.memory.push(memPercent); + _metricsHistory.timestamps.push(now); + if (_metricsHistory.cpu.length > METRICS_HISTORY_SIZE) { + 
_metricsHistory.cpu.shift(); + _metricsHistory.memory.shift(); + _metricsHistory.timestamps.shift(); + } +} + +function drawMetricsChart() { + const canvas = get("metricsChart"); + if (!canvas) return; + const ctx = canvas.getContext("2d"); + if (!ctx) return; + + const dpr = window.devicePixelRatio || 1; + const rect = canvas.getBoundingClientRect(); + canvas.width = rect.width * dpr; + canvas.height = rect.height * dpr; + ctx.scale(dpr, dpr); + + const w = rect.width; + const h = rect.height; + const pad = { top: 10, right: 12, bottom: 24, left: 36 }; + const plotW = w - pad.left - pad.right; + const plotH = h - pad.top - pad.bottom; + + ctx.clearRect(0, 0, w, h); + + const len = _metricsHistory.cpu.length; + if (len < 2) { + ctx.fillStyle = + getComputedStyle(document.documentElement) + .getPropertyValue("--text-tertiary") + .trim() || "#999"; + ctx.font = "12px sans-serif"; + ctx.textAlign = "center"; + ctx.fillText("Collecting data...", w / 2, h / 2); + return; + } + + const textColor = + getComputedStyle(document.documentElement) + .getPropertyValue("--text-tertiary") + .trim() || "#999"; + const gridColor = + getComputedStyle(document.documentElement) + .getPropertyValue("--border-color") + .trim() || "#333"; + const cpuColor = + getComputedStyle(document.documentElement) + .getPropertyValue("--accent-color") + .trim() || "#d97757"; + const memColor = + getComputedStyle(document.documentElement) + .getPropertyValue("--success") + .trim() || "#4a7c59"; + + // Y axis gridlines + ctx.strokeStyle = gridColor; + ctx.lineWidth = 0.5; + ctx.fillStyle = textColor; + ctx.font = "10px sans-serif"; + ctx.textAlign = "right"; + for (let pct = 0; pct <= 100; pct += 25) { + const y = pad.top + plotH - (pct / 100) * plotH; + ctx.beginPath(); + ctx.moveTo(pad.left, y); + ctx.lineTo(pad.left + plotW, y); + ctx.stroke(); + ctx.fillText(`${pct}%`, pad.left - 4, y + 3); + } + + // X axis time labels + ctx.textAlign = "center"; + const timestamps = _metricsHistory.timestamps; + 
const labelCount = Math.min(4, len); + for (let i = 0; i < labelCount; i++) { + const idx = Math.round((i / (labelCount - 1)) * (len - 1)); + const x = pad.left + (idx / (len - 1)) * plotW; + const t = timestamps[idx]; + const label = `${String(t.getMinutes()).padStart(2, "0")}:${String(t.getSeconds()).padStart(2, "0")}`; + ctx.fillText(label, x, h - 4); + } + + function drawLine(data, color) { + ctx.strokeStyle = color; + ctx.lineWidth = 1.5; + ctx.lineJoin = "round"; + ctx.beginPath(); + for (let i = 0; i < data.length; i++) { + const x = pad.left + (i / (len - 1)) * plotW; + const y = + pad.top + + plotH - + (Math.min(100, Math.max(0, data[i])) / 100) * plotH; + if (i === 0) ctx.moveTo(x, y); + else ctx.lineTo(x, y); + } + ctx.stroke(); + + ctx.globalAlpha = 0.08; + ctx.fillStyle = color; + ctx.lineTo(pad.left + plotW, pad.top + plotH); + ctx.lineTo(pad.left, pad.top + plotH); + ctx.closePath(); + ctx.fill(); + ctx.globalAlpha = 1; + } + + drawLine(_metricsHistory.cpu, cpuColor); + drawLine(_metricsHistory.memory, memColor); +} + async function fetchStatus() { if (!shouldFetch("status")) return; try { @@ -136,6 +257,8 @@ async function fetchSystemInfo() { get("systemMemoryBar").style.width = `${Math.min(100, Math.max(0, memUsage))}%`; recordFetchSuccess("system"); + pushMetrics(cpuUsage, memUsage); + drawMetricsChart(); } catch (e) { recordFetchError("system"); } diff --git a/src/Undefined/webui/static/js/i18n.js b/src/Undefined/webui/static/js/i18n.js index c7b1ab5..1de6b78 100644 --- a/src/Undefined/webui/static/js/i18n.js +++ b/src/Undefined/webui/static/js/i18n.js @@ -36,6 +36,7 @@ const I18N = { "overview.refresh": "刷新", "overview.system": "系统信息", "overview.resources": "资源使用", + "overview.chart": "资源趋势", "overview.runtime": "运行环境", "overview.cpu_model": "CPU 型号", "overview.cpu_usage": "CPU 占用率", @@ -91,6 +92,8 @@ const I18N = { "config.clear_search": "清除搜索", "config.expand_all": "全部展开", "config.collapse_all": "全部折叠", + "config.view_toml": "查看 TOML", + 
"config.view_form": "表单视图", "config.expand_section": "展开", "config.collapse_section": "折叠", "config.loading": "正在加载配置...", @@ -312,6 +315,7 @@ const I18N = { "overview.refresh": "Refresh", "overview.system": "System", "overview.resources": "Resources", + "overview.chart": "Resource Trends", "overview.runtime": "Runtime", "overview.cpu_model": "CPU Model", "overview.cpu_usage": "CPU Usage", @@ -372,6 +376,8 @@ const I18N = { "config.clear_search": "Clear search", "config.expand_all": "Expand all", "config.collapse_all": "Collapse all", + "config.view_toml": "View TOML", + "config.view_form": "Form View", "config.expand_section": "Expand", "config.collapse_section": "Collapse", "config.loading": "Loading configuration...", diff --git a/src/Undefined/webui/static/js/log-view.js b/src/Undefined/webui/static/js/log-view.js index 23d6ff8..9959474 100644 --- a/src/Undefined/webui/static/js/log-view.js +++ b/src/Undefined/webui/static/js/log-view.js @@ -49,6 +49,27 @@ function filterLogLines(raw) { line.toLowerCase().includes(query), ); + // Time range filtering + const timeFrom = state.logTimeFrom + ? new Date(state.logTimeFrom).getTime() + : 0; + const timeTo = state.logTimeTo ? new Date(state.logTimeTo).getTime() : 0; + if (timeFrom || timeTo) { + const tsRe = /^(\d{4}-\d{2}-\d{2}[ T]\d{2}:\d{2}:\d{2})/; + const result = []; + let include = true; + for (const line of filtered) { + const m = line.match(tsRe); + if (m) { + const ts = new Date(m[1].replace(" ", "T")).getTime(); + include = + (!timeFrom || ts >= timeFrom) && (!timeTo || ts <= timeTo); + } + if (include) result.push(line); + } + filtered = result; + } + const total = base.total ?? 
rawLines.length; const matched = filtered.filter((line) => line.length > 0).length; return { filtered, total, matched }; @@ -115,7 +136,9 @@ function updateLogMeta(total, matched) { if ( state.logLevel !== "all" || state.logSearch.trim() || - state.logLevelGte + state.logLevelGte || + state.logTimeFrom || + state.logTimeTo ) { parts.push( `${t("logs.filtered")}: ${total > 0 ? `${matched}/${total}` : "0/0"}`, diff --git a/src/Undefined/webui/static/js/main.js b/src/Undefined/webui/static/js/main.js index fd724c9..949ae83 100644 --- a/src/Undefined/webui/static/js/main.js +++ b/src/Undefined/webui/static/js/main.js @@ -462,6 +462,21 @@ async function init() { }; } + const logTimeFrom = get("logTimeFrom"); + if (logTimeFrom) { + logTimeFrom.addEventListener("change", () => { + state.logTimeFrom = logTimeFrom.value; + renderLogs(); + }); + } + const logTimeTo = get("logTimeTo"); + if (logTimeTo) { + logTimeTo.addEventListener("change", () => { + state.logTimeTo = logTimeTo.value; + renderLogs(); + }); + } + const logSearchInput = get("logSearchInput"); if (logSearchInput) { logSearchInput.addEventListener("input", () => { @@ -540,6 +555,32 @@ async function init() { if (collapseAllBtn) collapseAllBtn.onclick = () => setAllSectionsCollapsed(true); + get("btnToggleToml").onclick = async function () { + const formGrid = get("formSections"); + const tomlViewer = get("tomlViewer"); + const btn = get("btnToggleToml"); + if (!formGrid || !tomlViewer || !btn) return; + + const isShowingToml = tomlViewer.style.display !== "none"; + if (isShowingToml) { + tomlViewer.style.display = "none"; + formGrid.style.display = ""; + btn.innerText = t("config.view_toml"); + } else { + try { + const res = await api("/api/config"); + const data = await res.json(); + const content = data.content || ""; + get("tomlContent").textContent = content; + formGrid.style.display = "none"; + tomlViewer.style.display = "block"; + btn.innerText = t("config.view_form"); + } catch (e) { + 
showToast(`${t("common.error")}: ${e.message}`, "error", 5000); + } + } + }; + const logout = async () => { try { await api(authEndpointCandidates("logout"), { method: "POST" }); diff --git a/src/Undefined/webui/static/js/state.js b/src/Undefined/webui/static/js/state.js index 2b1561b..284f0da 100644 --- a/src/Undefined/webui/static/js/state.js +++ b/src/Undefined/webui/static/js/state.js @@ -190,6 +190,8 @@ const state = { bot: { running: false, pid: null, uptime: 0 }, logsRaw: "", logSearch: "", + logTimeFrom: "", + logTimeTo: "", logLevel: "all", logLevelGte: false, logType: "bot", diff --git a/src/Undefined/webui/templates/index.html b/src/Undefined/webui/templates/index.html index 3c9886f..a4953a5 100644 --- a/src/Undefined/webui/templates/index.html +++ b/src/Undefined/webui/templates/index.html @@ -299,6 +299,15 @@

运行概览

+
+
资源趋势
+ +
+ CPU + Memory +
+
+
运行环境
@@ -347,6 +356,7 @@

配置修改

更多操作
+
+ @@ -403,6 +416,10 @@

系统日志

+ + From 0e034fc05a7f5485217db08d6cbec76cfcfa8483 Mon Sep 17 00:00:00 2001 From: Null <1708213363@qq.com> Date: Sun, 19 Apr 2026 11:44:49 +0800 Subject: [PATCH 27/57] feat(webui): add Cmd/Ctrl+K command palette - CSS: overlay, palette card, input, list items, keyboard hint styles - HTML: modal overlay with input and list container - i18n: zh/en strings for all palette commands - JS: command list (tab nav, refresh, logout), open/close, keyboard navigation (arrows + enter), fuzzy filtering, Ctrl/Cmd+K toggle, Escape to close, click-outside dismiss Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- src/Undefined/webui/static/css/components.css | 11 + src/Undefined/webui/static/js/i18n.js | 34 +++ src/Undefined/webui/static/js/main.js | 247 ++++++++++++++++++ src/Undefined/webui/templates/index.html | 16 ++ 4 files changed, 308 insertions(+) diff --git a/src/Undefined/webui/static/css/components.css b/src/Undefined/webui/static/css/components.css index 08044e1..bd91f6d 100644 --- a/src/Undefined/webui/static/css/components.css +++ b/src/Undefined/webui/static/css/components.css @@ -796,3 +796,14 @@ .skeleton-text.medium { width: 80%; } .skeleton-block { height: 48px; margin-bottom: 12px; } .skeleton-bar { height: 8px; border-radius: 999px; } + +/* Command Palette */ +.cmd-palette-overlay { position: fixed; inset: 0; background: rgba(0,0,0,0.5); z-index: 9998; display: flex; align-items: flex-start; justify-content: center; padding-top: 20vh; } +.cmd-palette { width: 480px; max-width: 90vw; background: var(--bg-card); border: 1px solid var(--border-color); border-radius: var(--radius-lg); box-shadow: var(--shadow-lg); overflow: hidden; } +.cmd-palette-input { width: 100%; padding: 14px 18px; border: none; border-bottom: 1px solid var(--border-color); background: transparent; color: var(--text-primary); font-size: 15px; outline: none; box-sizing: border-box; } +.cmd-palette-input::placeholder { color: var(--text-tertiary); } +.cmd-palette-list { 
max-height: 320px; overflow-y: auto; padding: 6px; } +.cmd-palette-item { padding: 10px 14px; border-radius: var(--radius-sm); cursor: pointer; display: flex; align-items: center; gap: 10px; color: var(--text-secondary); font-size: 14px; } +.cmd-palette-item:hover, .cmd-palette-item.active { background: var(--accent-subtle); color: var(--text-primary); } +.cmd-palette-item .cmd-key { font-size: 11px; color: var(--text-tertiary); margin-left: auto; font-family: var(--font-mono); } +.cmd-palette-empty { padding: 20px; text-align: center; color: var(--text-tertiary); font-size: 13px; } diff --git a/src/Undefined/webui/static/js/i18n.js b/src/Undefined/webui/static/js/i18n.js index 1de6b78..1d60d0b 100644 --- a/src/Undefined/webui/static/js/i18n.js +++ b/src/Undefined/webui/static/js/i18n.js @@ -108,6 +108,12 @@ const I18N = { "config.reload_error": "配置重载失败。", "config.bootstrap_created": "检测到缺少 config.toml,已从示例生成;请在此页完善配置并保存。", + "config.history": "版本历史", + "config.history_empty": "暂无备份", + "config.history_restore": "恢复", + "config.history_restore_confirm": + "确定恢复到此版本?当前配置将自动备份。", + "config.history_restored": "已恢复配置", "logs.title": "运行日志", "logs.subtitle": "实时查看日志尾部输出。", "logs.auto": "自动刷新", @@ -276,6 +282,17 @@ const I18N = { "update.not_eligible": "未满足更新条件(仅支持官方 origin/main)", "update.failed": "更新失败", "update.no_restart": "更新已完成但未重启(请检查 uv sync 输出)", + "cmd.placeholder": "输入命令...", + "cmd.empty": "没有匹配的命令", + "cmd.tab_overview": "跳转到 概览", + "cmd.tab_config": "跳转到 配置", + "cmd.tab_logs": "跳转到 日志", + "cmd.tab_probes": "跳转到 探针", + "cmd.tab_memory": "跳转到 备忘录", + "cmd.tab_memes": "跳转到 表情包", + "cmd.tab_cognitive": "跳转到 认知记忆", + "cmd.refresh": "刷新当前页面", + "cmd.logout": "退出登录", }, en: { "landing.title": "Undefined Console", @@ -392,6 +409,12 @@ const I18N = { "config.reload_error": "Failed to reload configuration.", "config.bootstrap_created": "config.toml was missing and has been generated from the example. 
// ---------------------------------------------------------------------------
// Command palette
// ---------------------------------------------------------------------------
const _cmdCommands = [
  {
    id: "overview",
    label: () => t("cmd.tab_overview"),
    action: () => switchTab("overview"),
    keys: "1",
  },
  {
    id: "config",
    label: () => t("cmd.tab_config"),
    action: () => switchTab("config"),
    keys: "2",
  },
  {
    id: "logs",
    label: () => t("cmd.tab_logs"),
    action: () => switchTab("logs"),
    keys: "3",
  },
  {
    id: "probes",
    label: () => t("cmd.tab_probes"),
    action: () => switchTab("probes"),
    keys: "4",
  },
  {
    id: "memory",
    label: () => t("cmd.tab_memory"),
    action: () => switchTab("memory"),
    keys: "5",
  },
  {
    id: "memes",
    label: () => t("cmd.tab_memes"),
    action: () => switchTab("memes"),
    keys: "6",
  },
  {
    id: "cognitive",
    label: () => t("cmd.tab_cognitive"),
    action: () => switchTab("cognitive"),
    keys: "7",
  },
  {
    id: "refresh",
    label: () => t("cmd.refresh"),
    action: () => location.reload(),
    keys: "R",
  },
  {
    id: "logout",
    label: () => t("cmd.logout"),
    action: () => {
      get("btnLogout")?.click();
    },
    keys: "",
  },
];

let _cmdActiveIndex = 0;

// Single source of truth for filtering. Previously this logic was duplicated
// in _renderCmdList and _handleCmdKey; a future edit to one but not the other
// would make keyboard navigation act on a different list than the one shown.
function _filterCmdCommands(query) {
  const q = query.trim().toLowerCase();
  if (!q) return _cmdCommands;
  return _cmdCommands.filter(
    (c) => c.label().toLowerCase().includes(q) || c.id.includes(q),
  );
}

// Show the palette overlay, reset the query/selection and render all commands.
function openCmdPalette() {
  const overlay = get("cmdPaletteOverlay");
  const input = get("cmdPaletteInput");
  if (!overlay || !input) return;
  input.placeholder = t("cmd.placeholder");
  overlay.style.display = "flex";
  input.value = "";
  _cmdActiveIndex = 0;
  _renderCmdList("");
  input.focus();
}

function closeCmdPalette() {
  const overlay = get("cmdPaletteOverlay");
  if (overlay) overlay.style.display = "none";
}

// Re-render the command list for the given query and (re)attach item handlers.
function _renderCmdList(query) {
  const list = get("cmdPaletteList");
  if (!list) return;
  const filtered = _filterCmdCommands(query);
  if (filtered.length === 0) {
    // NOTE(review): markup below reconstructed — the original template markup
    // was lost in transit; confirm class names against the stylesheet.
    list.innerHTML = `<div class="cmd-palette-empty">${escapeHtml(t("cmd.empty"))}</div>`;
    return;
  }
  // Clamp the selection when the filtered list shrank below it.
  _cmdActiveIndex = Math.min(_cmdActiveIndex, filtered.length - 1);
  list.innerHTML = filtered
    .map(
      (c, i) =>
        `<div class="cmd-palette-item${i === _cmdActiveIndex ? " active" : ""}" data-index="${i}">${escapeHtml(c.label())}${c.keys ? `<span class="cmd-palette-keys">${escapeHtml(c.keys)}</span>` : ""}</div>`,
    )
    .join("");
  list.querySelectorAll(".cmd-palette-item").forEach((el, i) => {
    el.addEventListener("click", () => {
      closeCmdPalette();
      filtered[i].action();
    });
    el.addEventListener("mouseenter", () => {
      _cmdActiveIndex = i;
      _renderCmdList(query);
    });
  });
}

// Keyboard navigation for the palette input: arrows move the selection,
// Enter runs the active command.
function _handleCmdKey(e, query) {
  const list = get("cmdPaletteList");
  if (!list) return;
  const filtered = _filterCmdCommands(query);
  // Guard: with zero matches the modulo arithmetic below would produce NaN
  // (x % 0) and permanently poison _cmdActiveIndex.
  if (filtered.length === 0) return;
  if (e.key === "ArrowDown") {
    e.preventDefault();
    _cmdActiveIndex = (_cmdActiveIndex + 1) % filtered.length;
    _renderCmdList(query);
  } else if (e.key === "ArrowUp") {
    e.preventDefault();
    _cmdActiveIndex =
      (_cmdActiveIndex - 1 + filtered.length) % filtered.length;
    _renderCmdList(query);
  } else if (e.key === "Enter") {
    e.preventDefault();
    closeCmdPalette();
    filtered[_cmdActiveIndex].action();
  }
}
formGrid.style.display = ""; + return; + } + + try { + const res = await api("/api/config/history"); + const data = await res.json(); + const backups = data.backups || []; + const list = get("configHistoryList"); + if (!list) return; + + if (backups.length === 0) { + list.innerHTML = `

${escapeHtml(t("config.history_empty"))}

`; + } else { + list.innerHTML = backups + .map((b) => { + const date = new Date(b.mtime * 1000); + const sizeKB = (b.size / 1024).toFixed(1); + return `
+
+
+ ${escapeHtml(b.name)} + ${sizeKB} KB · ${date.toLocaleString()} +
+ +
+
`; + }) + .join(""); + + list.querySelectorAll("[data-restore-name]").forEach((btn) => { + btn.addEventListener("click", async () => { + const name = btn.getAttribute("data-restore-name"); + if (!confirm(t("config.history_restore_confirm"))) + return; + try { + await api("/api/config/history/restore", { + method: "POST", + headers: { + "Content-Type": "application/json", + }, + body: JSON.stringify({ name }), + }); + showToast(t("config.history_restored"), "success"); + panel.style.display = "none"; + if (formGrid) formGrid.style.display = ""; + loadConfig(); + } catch (e) { + showToast( + `${t("common.error")}: ${e.message}`, + "error", + ); + } + }); + }); + } + + if (formGrid) formGrid.style.display = "none"; + if (tomlViewer) tomlViewer.style.display = "none"; + panel.style.display = "block"; + } catch (e) { + showToast(`${t("common.error")}: ${e.message}`, "error"); + } + }); + const logout = async () => { try { await api(authEndpointCandidates("logout"), { method: "POST" }); @@ -656,6 +869,40 @@ async function init() { switchTab("config"); } + // Command Palette keybinding + document.addEventListener("keydown", (e) => { + if ((e.metaKey || e.ctrlKey) && e.key === "k") { + e.preventDefault(); + const overlay = get("cmdPaletteOverlay"); + if (overlay && overlay.style.display !== "none") { + closeCmdPalette(); + } else { + openCmdPalette(); + } + } + if (e.key === "Escape") { + closeCmdPalette(); + } + }); + + const cmdInput = get("cmdPaletteInput"); + if (cmdInput) { + cmdInput.addEventListener("input", () => { + _cmdActiveIndex = 0; + _renderCmdList(cmdInput.value); + }); + cmdInput.addEventListener("keydown", (e) => { + _handleCmdKey(e, cmdInput.value); + }); + } + + const cmdOverlay = get("cmdPaletteOverlay"); + if (cmdOverlay) { + cmdOverlay.addEventListener("click", (e) => { + if (e.target === cmdOverlay) closeCmdPalette(); + }); + } + document.addEventListener("visibilitychange", () => { if (document.hidden) { stopStatusTimer(); diff --git 
a/src/Undefined/webui/templates/index.html b/src/Undefined/webui/templates/index.html index a4953a5..214ee63 100644 --- a/src/Undefined/webui/templates/index.html +++ b/src/Undefined/webui/templates/index.html @@ -363,6 +363,8 @@

配置修改

data-i18n="config.collapse_all">全部折叠 +
@@ -372,6 +374,12 @@

配置修改

+ @@ -808,6 +816,14 @@

MIT License

_BACKUP_DIR = Path("data") / "config_backups"
_MAX_BACKUPS = 50


def _backup_config() -> str | None:
    """Create a timestamped backup of the current config.toml.

    Returns the backup filename, or *None* if the source file does not exist.
    Backups live under ``data/config_backups/`` and are capped at
    ``_MAX_BACKUPS``; the oldest files are removed first.
    """
    if not CONFIG_PATH.exists():
        return None
    _BACKUP_DIR.mkdir(parents=True, exist_ok=True)
    ts = datetime.now(timezone.utc).strftime("%Y%m%d_%H%M%S")
    backup_name = f"config_{ts}.toml"
    # Two saves within the same second would otherwise collide and copy2
    # would silently overwrite the earlier backup — disambiguate with a
    # counter suffix ("." sorts before "_", so chronological sort order of
    # the glob below is preserved).
    counter = 1
    while (_BACKUP_DIR / backup_name).exists():
        backup_name = f"config_{ts}_{counter}.toml"
        counter += 1
    shutil.copy2(CONFIG_PATH, _BACKUP_DIR / backup_name)
    # Trim old backups beyond the cap (filenames sort chronologically).
    backups = sorted(_BACKUP_DIR.glob("config_*.toml"))
    while len(backups) > _MAX_BACKUPS:
        backups.pop(0).unlink(missing_ok=True)
    return backup_name


# ---------------------------------------------------------------------------
# Config version history
# ---------------------------------------------------------------------------


@routes.get("/api/v1/management/config/history")
@routes.get("/api/config/history")
async def config_history_handler(request: web.Request) -> Response:
    """Return the list of config backups, newest first."""
    if not check_auth(request):
        return web.json_response({"error": "Unauthorized"}, status=401)
    if not _BACKUP_DIR.exists():
        return web.json_response({"backups": []})
    backups: list[dict[str, object]] = []
    for f in sorted(_BACKUP_DIR.glob("config_*.toml"), reverse=True):
        stat = f.stat()
        backups.append({"name": f.name, "size": stat.st_size, "mtime": stat.st_mtime})
    return web.json_response({"backups": backups})


@routes.post("/api/v1/management/config/history/restore")
@routes.post("/api/config/history/restore")
async def config_restore_handler(request: web.Request) -> Response:
    """Restore a config backup by name (auto-backs-up current config first)."""
    if not check_auth(request):
        return web.json_response({"error": "Unauthorized"}, status=401)
    try:
        data = await request.json()
    except Exception:
        return web.json_response({"error": "Invalid JSON"}, status=400)
    name = str(data.get("name", ""))
    # Reject anything that is not a bare filename. The original check missed
    # the Windows "\\" separator; Path(name).name != name additionally catches
    # drive-relative forms on any platform.
    if (
        not name
        or ".." in name
        or "/" in name
        or "\\" in name
        or Path(name).name != name
    ):
        return web.json_response({"error": "Invalid backup name"}, status=400)
    backup_path = _BACKUP_DIR / name
    if not backup_path.exists():
        return web.json_response({"error": "Backup not found"}, status=404)
    content = backup_path.read_text(encoding="utf-8")
    # Never restore a file that would leave the app with unparsable config.
    valid, msg = validate_toml(content)
    if not valid:
        return web.json_response({"error": f"Backup TOML invalid: {msg}"}, status=400)
    # Snapshot the current config so a restore is itself reversible.
    _backup_config()
    CONFIG_PATH.write_text(content, encoding="utf-8")
    get_config_manager().reload()
    return web.json_response({"success": True, "message": "Restored"})
// Focus trap for modals
const _focusTrapStack = [];

const _FOCUSABLE_SELECTOR =
  'a[href], button:not([disabled]), input:not([disabled]), select:not([disabled]), textarea:not([disabled]), [tabindex]:not([tabindex="-1"])';

// Confine Tab/Shift+Tab cycling to `container` until releaseFocus is called.
// Traps nest via a stack so stacked modals each keep their own handler.
function trapFocus(container) {
  if (!container) return;
  const initial = container.querySelectorAll(_FOCUSABLE_SELECTOR);
  if (initial.length === 0) return;

  function handler(e) {
    if (e.key !== "Tab") return;
    // Re-query on every keypress: modal content may have been re-rendered
    // since the trap was installed (e.g. the command palette list updates on
    // each keystroke), so a list captured at trap time would go stale.
    const focusable = container.querySelectorAll(_FOCUSABLE_SELECTOR);
    if (focusable.length === 0) return;
    const first = focusable[0];
    const last = focusable[focusable.length - 1];
    if (e.shiftKey) {
      if (document.activeElement === first) {
        e.preventDefault();
        last.focus();
      }
    } else if (document.activeElement === last) {
      e.preventDefault();
      first.focus();
    }
  }

  container.addEventListener("keydown", handler);
  _focusTrapStack.push({ container, handler });
  initial[0].focus();
}

// Remove the trap installed for `container`, if any.
function releaseFocus(container) {
  const idx = _focusTrapStack.findIndex((t) => t.container === container);
  if (idx === -1) return;
  const entry = _focusTrapStack.splice(idx, 1)[0];
  entry.container.removeEventListener("keydown", entry.handler);
}
get("configState"); const grid = get("formSections"); From a410707dcc6c49242e7b0d3c6a71d4c461ac1404 Mon Sep 17 00:00:00 2001 From: Null <1708213363@qq.com> Date: Sun, 19 Apr 2026 12:04:23 +0800 Subject: [PATCH 30/57] refactor(config): split loader.py into sub-modules Extract 6 sub-modules from the 3365-line monolith: - coercers.py: type coercion/normalization helpers (~150 lines) - resolvers.py: config value resolution (~106 lines) - admin.py: local admin management (~44 lines) - webui_settings.py: WebUI settings class (~61 lines) - model_parsers.py: all model config parsers (~1250 lines) - domain_parsers.py: domain config parsers + _update_dataclass (~284 lines) loader.py retains Config class, TOML loading, and all re-exports for backward compatibility. All 1429 tests pass, mypy strict clean. Co-authored-by: Claude Opus 4.6 Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- src/Undefined/config/admin.py | 44 + src/Undefined/config/coercers.py | 150 ++ src/Undefined/config/domain_parsers.py | 284 +++ src/Undefined/config/loader.py | 1957 +---------------- src/Undefined/config/model_parsers.py | 1250 +++++++++++ src/Undefined/config/resolvers.py | 106 + src/Undefined/config/webui_settings.py | 61 + .../test_config_cognitive_historian_limits.py | 6 +- 8 files changed, 1992 insertions(+), 1866 deletions(-) create mode 100644 src/Undefined/config/admin.py create mode 100644 src/Undefined/config/coercers.py create mode 100644 src/Undefined/config/domain_parsers.py create mode 100644 src/Undefined/config/model_parsers.py create mode 100644 src/Undefined/config/resolvers.py create mode 100644 src/Undefined/config/webui_settings.py diff --git a/src/Undefined/config/admin.py b/src/Undefined/config/admin.py new file mode 100644 index 0000000..16dcfe6 --- /dev/null +++ b/src/Undefined/config/admin.py @@ -0,0 +1,44 @@ +"""Local admin management (config.local.json).""" + +from __future__ import annotations + +import json +import logging +from pathlib 
logger = logging.getLogger(__name__)

LOCAL_CONFIG_PATH = Path("config.local.json")


def load_local_admins() -> list[int]:
    """Load the dynamic administrator QQ list from the local config file.

    Returns an empty list when the file is missing or unreadable.
    """
    if not LOCAL_CONFIG_PATH.exists():
        return []
    try:
        raw = json.loads(LOCAL_CONFIG_PATH.read_text(encoding="utf-8"))
        admins: list[int] = raw.get("admin_qqs", [])
        return admins
    except Exception as exc:
        logger.warning("读取本地配置失败: %s", exc)
        return []


def save_local_admins(admin_qqs: list[int]) -> None:
    """Persist the dynamic administrator list to the local config file.

    Existing keys in the file are preserved; only ``admin_qqs`` is replaced.
    Re-raises after logging so callers can surface the failure.
    """
    try:
        payload: dict[str, list[int]] = {}
        if LOCAL_CONFIG_PATH.exists():
            payload = json.loads(LOCAL_CONFIG_PATH.read_text(encoding="utf-8"))

        payload["admin_qqs"] = admin_qqs

        LOCAL_CONFIG_PATH.write_text(
            json.dumps(payload, ensure_ascii=False, indent=2),
            encoding="utf-8",
        )
        logger.info("已保存管理员列表到 %s", LOCAL_CONFIG_PATH)
    except Exception as exc:
        logger.error("保存本地配置失败: %s", exc)
        raise
value.strip() + return stripped if stripped else None + return str(value).strip() + + +def _coerce_int(value: Any, default: int) -> int: + if value is None: + return default + try: + return int(value) + except (TypeError, ValueError): + return default + + +def _coerce_float(value: Any, default: float) -> float: + if value is None: + return default + try: + return float(value) + except (TypeError, ValueError): + return default + + +def _normalize_queue_interval(value: float, default: float = 1.0) -> float: + """规范化队列发车间隔。 + + `0` 表示立即发车,负数回退到默认值。 + """ + + return default if value < 0 else value + + +def _coerce_bool(value: Any, default: bool) -> bool: + if isinstance(value, bool): + return value + if isinstance(value, (int, float)): + return bool(value) + if isinstance(value, str): + return value.strip().lower() in {"1", "true", "yes", "on"} + return default + + +def _coerce_str(value: Any, default: str) -> str: + normalized = _normalize_str(value) + return normalized if normalized is not None else default + + +def _normalize_base_url(value: str, default: str) -> str: + normalized = value.strip().rstrip("/") + if normalized: + return normalized + return default.rstrip("/") + + +def _coerce_int_list(value: Any) -> list[int]: + if value is None: + return [] + if isinstance(value, list): + items: list[int] = [] + for item in value: + try: + items.append(int(item)) + except (TypeError, ValueError): + continue + return items + if isinstance(value, str): + parts = [part.strip() for part in value.split(",") if part.strip()] + items = [] + for part in parts: + try: + items.append(int(part)) + except ValueError: + continue + return items + return [] + + +def _coerce_str_list(value: Any) -> list[str]: + if value is None: + return [] + if isinstance(value, list): + return [str(item).strip() for item in value if str(item).strip()] + if isinstance(value, str): + return [part.strip() for part in value.split(",") if part.strip()] + return [] + + +def _coerce_request_params(value: 
def _coerce_request_params(value: Any) -> dict[str, Any]:
    """Normalize arbitrary request-parameter input into a plain dict."""
    return normalize_request_params(value)


def _get_model_request_params(data: dict[str, Any], model_name: str) -> dict[str, Any]:
    """Fetch and normalize ``[models.<model_name>.request_params]`` from *data*."""
    raw = _get_nested(data, ("models", model_name, "request_params"))
    return _coerce_request_params(raw)


def _get_value(
    data: dict[str, Any],
    path: tuple[str, ...],
    env_key: Optional[str],
) -> Any:
    """Read *path* from *data*, falling back to the *env_key* env variable.

    The TOML value wins when present. An environment fallback is only used
    when non-empty, and triggers a one-time migration warning per key.
    Returns ``None`` when neither source provides a value.
    """
    found = _get_nested(data, path)
    if found is not None:
        return found
    if not env_key:
        return None
    env_value = os.getenv(env_key)
    if env_value is None or not str(env_value).strip():
        return None
    _warn_env_fallback(env_key)
    return env_value


_VALID_API_MODES = {"chat_completions", "responses"}
_VALID_REASONING_EFFORT_STYLES = {"openai", "anthropic"}
cog.get("bot_name") if isinstance(cog, dict) else None, + "Undefined", + ), + vector_store_path=_coerce_str( + vs.get("path") if isinstance(vs, dict) else None, + "data/cognitive/chromadb", + ), + queue_path=_coerce_str( + que.get("path") if isinstance(que, dict) else None, + "data/cognitive/queues", + ), + profiles_path=_coerce_str( + prof.get("path") if isinstance(prof, dict) else None, + "data/cognitive/profiles", + ), + auto_top_k=_coerce_int(q.get("auto_top_k") if isinstance(q, dict) else None, 3), + auto_scope_candidate_multiplier=_coerce_int( + q.get("auto_scope_candidate_multiplier") if isinstance(q, dict) else None, + 2, + ), + auto_current_group_boost=_coerce_float( + q.get("auto_current_group_boost") if isinstance(q, dict) else None, + 1.15, + ), + auto_current_private_boost=_coerce_float( + q.get("auto_current_private_boost") if isinstance(q, dict) else None, + 1.25, + ), + enable_rerank=_coerce_bool( + q.get("enable_rerank") if isinstance(q, dict) else None, True + ), + recent_end_summaries_inject_k=_coerce_int( + q.get("recent_end_summaries_inject_k") if isinstance(q, dict) else None, + 30, + ), + time_decay_enabled=_coerce_bool( + q.get("time_decay_enabled") if isinstance(q, dict) else None, True + ), + time_decay_half_life_days_auto=_coerce_float( + q.get("time_decay_half_life_days_auto") if isinstance(q, dict) else None, + 14.0, + ), + time_decay_half_life_days_tool=_coerce_float( + q.get("time_decay_half_life_days_tool") if isinstance(q, dict) else None, + 60.0, + ), + time_decay_boost=_coerce_float( + q.get("time_decay_boost") if isinstance(q, dict) else None, 0.2 + ), + time_decay_min_similarity=_coerce_float( + q.get("time_decay_min_similarity") if isinstance(q, dict) else None, + 0.35, + ), + tool_default_top_k=_coerce_int( + q.get("tool_default_top_k") if isinstance(q, dict) else None, 12 + ), + profile_top_k=_coerce_int( + q.get("profile_top_k") if isinstance(q, dict) else None, 8 + ), + rerank_candidate_multiplier=_coerce_int( + 
q.get("rerank_candidate_multiplier") if isinstance(q, dict) else None, 3 + ), + rewrite_max_retry=_coerce_int( + hist.get("rewrite_max_retry") if isinstance(hist, dict) else None, 2 + ), + historian_recent_messages_inject_k=_coerce_int( + hist.get("recent_messages_inject_k") if isinstance(hist, dict) else None, + 12, + ), + historian_recent_message_line_max_len=_coerce_int( + hist.get("recent_message_line_max_len") if isinstance(hist, dict) else None, + 240, + ), + historian_source_message_max_len=_coerce_int( + hist.get("source_message_max_len") if isinstance(hist, dict) else None, + 800, + ), + poll_interval_seconds=_coerce_float( + hist.get("poll_interval_seconds") if isinstance(hist, dict) else None, + 1.0, + ), + stale_job_timeout_seconds=_coerce_float( + hist.get("stale_job_timeout_seconds") if isinstance(hist, dict) else None, + 300.0, + ), + profile_revision_keep=_coerce_int( + prof.get("revision_keep") if isinstance(prof, dict) else None, 5 + ), + failed_max_age_days=_coerce_int( + que.get("failed_max_age_days") if isinstance(que, dict) else None, 30 + ), + failed_max_files=_coerce_int( + que.get("failed_max_files") if isinstance(que, dict) else None, 500 + ), + failed_cleanup_interval=_coerce_int( + que.get("failed_cleanup_interval") if isinstance(que, dict) else None, + 100, + ), + job_max_retries=_coerce_int( + que.get("job_max_retries") if isinstance(que, dict) else None, 3 + ), + ) + + +def _parse_memes_config(data: dict[str, Any]) -> MemeConfig: + section_raw = data.get("memes", {}) + section = section_raw if isinstance(section_raw, dict) else {} + return MemeConfig( + enabled=_coerce_bool(section.get("enabled"), True), + query_default_mode=_coerce_str(section.get("query_default_mode"), "hybrid"), + max_source_image_bytes=max( + 1, + _coerce_int(section.get("max_source_image_bytes"), 500 * 1024), + ), + blob_dir=_coerce_str(section.get("blob_dir"), "data/memes/blobs"), + preview_dir=_coerce_str(section.get("preview_dir"), "data/memes/previews"), + 
def _parse_memes_config(data: dict[str, Any]) -> MemeConfig:
    """Parse the ``[memes]`` section into a :class:`MemeConfig`.

    Count/size caps are clamped to at least 1 so a bad config value can never
    disable storage or retrieval entirely.
    """
    section_raw = data.get("memes", {})
    section = section_raw if isinstance(section_raw, dict) else {}
    return MemeConfig(
        enabled=_coerce_bool(section.get("enabled"), True),
        query_default_mode=_coerce_str(section.get("query_default_mode"), "hybrid"),
        max_source_image_bytes=max(
            1,
            _coerce_int(section.get("max_source_image_bytes"), 500 * 1024),
        ),
        # Storage locations.
        blob_dir=_coerce_str(section.get("blob_dir"), "data/memes/blobs"),
        preview_dir=_coerce_str(section.get("preview_dir"), "data/memes/previews"),
        db_path=_coerce_str(section.get("db_path"), "data/memes/memes.sqlite3"),
        vector_store_path=_coerce_str(
            section.get("vector_store_path"), "data/memes/chromadb"
        ),
        queue_path=_coerce_str(section.get("queue_path"), "data/memes/queues"),
        max_items=max(1, _coerce_int(section.get("max_items"), 10000)),
        max_total_bytes=max(
            1,
            _coerce_int(section.get("max_total_bytes"), 5 * 1024 * 1024 * 1024),
        ),
        allow_gif=_coerce_bool(section.get("allow_gif"), True),
        auto_ingest_group=_coerce_bool(section.get("auto_ingest_group"), True),
        auto_ingest_private=_coerce_bool(section.get("auto_ingest_private"), True),
        # Retrieval top-k settings.
        keyword_top_k=max(1, _coerce_int(section.get("keyword_top_k"), 30)),
        semantic_top_k=max(1, _coerce_int(section.get("semantic_top_k"), 30)),
        rerank_top_k=max(1, _coerce_int(section.get("rerank_top_k"), 20)),
        worker_max_concurrency=max(
            1, _coerce_int(section.get("worker_max_concurrency"), 4)
        ),
    )


def _parse_api_config(data: dict[str, Any]) -> APIConfig:
    """Parse the ``[api]`` section into an :class:`APIConfig`.

    Out-of-range ports, empty auth keys, unknown expose modes and
    non-positive timeouts are silently replaced with safe defaults.
    """
    section_raw = data.get("api", {})
    section = section_raw if isinstance(section_raw, dict) else {}

    enabled = _coerce_bool(section.get("enabled"), True)
    host = _coerce_str(section.get("host"), DEFAULT_API_HOST)
    port = _coerce_int(section.get("port"), DEFAULT_API_PORT)
    if port <= 0 or port > 65535:
        port = DEFAULT_API_PORT

    auth_key = _coerce_str(section.get("auth_key"), DEFAULT_API_AUTH_KEY)
    if not auth_key:
        auth_key = DEFAULT_API_AUTH_KEY

    openapi_enabled = _coerce_bool(section.get("openapi_enabled"), True)

    tool_invoke_enabled = _coerce_bool(section.get("tool_invoke_enabled"), False)
    tool_invoke_expose = _coerce_str(
        section.get("tool_invoke_expose"), "tools+toolsets"
    )
    # Only a closed set of expose modes is accepted.
    valid_expose = {"tools", "toolsets", "tools+toolsets", "agents", "all"}
    if tool_invoke_expose not in valid_expose:
        tool_invoke_expose = "tools+toolsets"
    tool_invoke_allowlist = _coerce_str_list(section.get("tool_invoke_allowlist"))
    tool_invoke_denylist = _coerce_str_list(section.get("tool_invoke_denylist"))
    tool_invoke_timeout = _coerce_int(section.get("tool_invoke_timeout"), 120)
    if tool_invoke_timeout <= 0:
        tool_invoke_timeout = 120
    tool_invoke_callback_timeout = _coerce_int(
        section.get("tool_invoke_callback_timeout"), 10
    )
    if tool_invoke_callback_timeout <= 0:
        tool_invoke_callback_timeout = 10

    return APIConfig(
        enabled=enabled,
        host=host,
        port=port,
        auth_key=auth_key,
        openapi_enabled=openapi_enabled,
        tool_invoke_enabled=tool_invoke_enabled,
        tool_invoke_expose=tool_invoke_expose,
        tool_invoke_allowlist=tool_invoke_allowlist,
        tool_invoke_denylist=tool_invoke_denylist,
        tool_invoke_timeout=tool_invoke_timeout,
        tool_invoke_callback_timeout=tool_invoke_callback_timeout,
    )
_coerce_str_list(section.get("tool_invoke_denylist")) + tool_invoke_timeout = _coerce_int(section.get("tool_invoke_timeout"), 120) + if tool_invoke_timeout <= 0: + tool_invoke_timeout = 120 + tool_invoke_callback_timeout = _coerce_int( + section.get("tool_invoke_callback_timeout"), 10 + ) + if tool_invoke_callback_timeout <= 0: + tool_invoke_callback_timeout = 10 + + return APIConfig( + enabled=enabled, + host=host, + port=port, + auth_key=auth_key, + openapi_enabled=openapi_enabled, + tool_invoke_enabled=tool_invoke_enabled, + tool_invoke_expose=tool_invoke_expose, + tool_invoke_allowlist=tool_invoke_allowlist, + tool_invoke_denylist=tool_invoke_denylist, + tool_invoke_timeout=tool_invoke_timeout, + tool_invoke_callback_timeout=tool_invoke_callback_timeout, + ) + + +def _parse_naga_config(data: dict[str, Any]) -> NagaConfig: + section_raw = data.get("naga", {}) + section = section_raw if isinstance(section_raw, dict) else {} + + enabled = _coerce_bool(section.get("enabled"), False) + api_url = _coerce_str(section.get("api_url"), "") + api_key = _coerce_str(section.get("api_key"), "") + moderation_enabled = _coerce_bool(section.get("moderation_enabled"), True) + allowed_groups = frozenset(_coerce_int_list(section.get("allowed_groups"))) + + return NagaConfig( + enabled=enabled, + api_url=api_url, + api_key=api_key, + moderation_enabled=moderation_enabled, + allowed_groups=allowed_groups, + ) + + +def _parse_easter_egg_call_mode(value: Any) -> str: + """解析彩蛋提示模式。 + + 兼容旧版布尔值: + - True => agent + - False => none + """ + if isinstance(value, bool): + return "agent" if value else "none" + if isinstance(value, (int, float)): + return "agent" if bool(value) else "none" + if value is None: + return "none" + + text = str(value).strip().lower() + if text in {"true", "1", "yes", "on"}: + return "agent" + if text in {"false", "0", "no", "off"}: + return "none" + if text in {"none", "agent", "tools", "all", "clean"}: + return text + return "none" + + +def _update_dataclass( + 
old_value: Any, new_value: Any, prefix: str +) -> dict[str, tuple[Any, Any]]: + changes: dict[str, tuple[Any, Any]] = {} + if not isinstance(old_value, type(new_value)): + changes[prefix] = (old_value, new_value) + return changes + for field in fields(old_value): + name = field.name + old_attr = getattr(old_value, name) + new_attr = getattr(new_value, name) + if old_attr != new_attr: + setattr(old_value, name, new_attr) + changes[f"{prefix}.{name}"] = (old_attr, new_attr) + return changes diff --git a/src/Undefined/config/loader.py b/src/Undefined/config/loader.py index ac65473..26f669c 100644 --- a/src/Undefined/config/loader.py +++ b/src/Undefined/config/loader.py @@ -2,7 +2,6 @@ from __future__ import annotations -import json import logging import os import re @@ -37,38 +36,90 @@ def load_dotenv( ImageGenConfig, ImageGenModelConfig, MemeConfig, - ModelPool, - ModelPoolEntry, NagaConfig, RerankModelConfig, SecurityModelConfig, VisionModelConfig, ) -from Undefined.utils.request_params import ( - merge_request_params, - normalize_request_params, +from .coercers import ( # noqa: F401 — re-exported for backward compat + _coerce_bool, + _coerce_float, + _coerce_int, + _coerce_int_list, + _coerce_str, + _coerce_str_list, + _get_model_request_params, + _get_value, + _normalize_base_url, + _normalize_queue_interval, + _normalize_str, + _warn_env_fallback, +) +from .resolvers import ( # noqa: F401 — re-exported for backward compat + _resolve_api_mode, + _resolve_reasoning_effort, + _resolve_reasoning_effort_style, + _resolve_responses_force_stateless_replay, + _resolve_responses_tool_choice_compat, + _resolve_thinking_compat_flags, +) +from .admin import ( # noqa: F401 — re-exported for backward compat + LOCAL_CONFIG_PATH, + load_local_admins, + save_local_admins, +) +from .webui_settings import ( # noqa: F401 — re-exported for backward compat + DEFAULT_WEBUI_PASSWORD, + DEFAULT_WEBUI_PORT, + DEFAULT_WEBUI_URL, + WebUISettings, + load_webui_settings, +) +from 
.model_parsers import ( + _log_debug_info, + _merge_admins, + _parse_agent_model_config, + _parse_chat_model_config, + _parse_embedding_model_config, + _parse_grok_model_config, + _parse_historian_model_config, + _parse_image_edit_model_config, + _parse_image_gen_config, + _parse_image_gen_model_config, + _parse_naga_model_config, + _parse_rerank_model_config, + _parse_security_model_config, + _parse_summary_model_config, + _parse_vision_model_config, + _verify_required_fields, +) +from .domain_parsers import ( + _parse_api_config, + _parse_cognitive_config, + _parse_easter_egg_call_mode, + _parse_memes_config, + _parse_naga_config, + _update_dataclass, ) logger = logging.getLogger(__name__) CONFIG_PATH = Path("config.toml") -LOCAL_CONFIG_PATH = Path("config.local.json") - -DEFAULT_WEBUI_URL = "127.0.0.1" -DEFAULT_WEBUI_PORT = 8787 -DEFAULT_WEBUI_PASSWORD = "changeme" -DEFAULT_API_HOST = "127.0.0.1" -DEFAULT_API_PORT = 8788 -DEFAULT_API_AUTH_KEY = "changeme" - -_ENV_WARNED_KEYS: set[str] = set() - -def _warn_env_fallback(name: str) -> None: - if name in _ENV_WARNED_KEYS: - return - _ENV_WARNED_KEYS.add(name) - logger.warning("检测到环境变量 %s,建议迁移到 config.toml", name) +# Re-export symbols that external modules import from this module. 
+__all__ = [ + "CONFIG_PATH", + "Config", + "DEFAULT_WEBUI_PASSWORD", + "DEFAULT_WEBUI_PORT", + "DEFAULT_WEBUI_URL", + "LOCAL_CONFIG_PATH", + "WebUISettings", + "load_local_admins", + "load_toml_data", + "load_webui_settings", + "save_local_admins", +] def _load_env() -> None: @@ -143,308 +194,6 @@ def load_toml_data( return {} -def _get_nested(data: dict[str, Any], path: tuple[str, ...]) -> Any: - node: Any = data - for key in path: - if not isinstance(node, dict) or key not in node: - return None - node = node[key] - return node - - -def _normalize_str(value: Any) -> Optional[str]: - if value is None: - return None - if isinstance(value, str): - stripped = value.strip() - return stripped if stripped else None - return str(value).strip() - - -def _coerce_int(value: Any, default: int) -> int: - if value is None: - return default - try: - return int(value) - except (TypeError, ValueError): - return default - - -def _coerce_float(value: Any, default: float) -> float: - if value is None: - return default - try: - return float(value) - except (TypeError, ValueError): - return default - - -def _normalize_queue_interval(value: float, default: float = 1.0) -> float: - """规范化队列发车间隔。 - - `0` 表示立即发车,负数回退到默认值。 - """ - - return default if value < 0 else value - - -def _coerce_bool(value: Any, default: bool) -> bool: - if isinstance(value, bool): - return value - if isinstance(value, (int, float)): - return bool(value) - if isinstance(value, str): - return value.strip().lower() in {"1", "true", "yes", "on"} - return default - - -def _coerce_str(value: Any, default: str) -> str: - normalized = _normalize_str(value) - return normalized if normalized is not None else default - - -def _normalize_base_url(value: str, default: str) -> str: - normalized = value.strip().rstrip("/") - if normalized: - return normalized - return default.rstrip("/") - - -def _coerce_int_list(value: Any) -> list[int]: - if value is None: - return [] - if isinstance(value, list): - items: list[int] = [] - 
for item in value: - try: - items.append(int(item)) - except (TypeError, ValueError): - continue - return items - if isinstance(value, str): - parts = [part.strip() for part in value.split(",") if part.strip()] - items = [] - for part in parts: - try: - items.append(int(part)) - except ValueError: - continue - return items - return [] - - -def _coerce_str_list(value: Any) -> list[str]: - if value is None: - return [] - if isinstance(value, list): - return [str(item).strip() for item in value if str(item).strip()] - if isinstance(value, str): - return [part.strip() for part in value.split(",") if part.strip()] - return [] - - -def _coerce_request_params(value: Any) -> dict[str, Any]: - return normalize_request_params(value) - - -def _get_model_request_params(data: dict[str, Any], model_name: str) -> dict[str, Any]: - return _coerce_request_params( - _get_nested(data, ("models", model_name, "request_params")) - ) - - -def _get_value( - data: dict[str, Any], - path: tuple[str, ...], - env_key: Optional[str], -) -> Any: - value = _get_nested(data, path) - if value is not None: - return value - if env_key: - env_value = os.getenv(env_key) - if env_value is not None and str(env_value).strip() != "": - _warn_env_fallback(env_key) - return env_value - return None - - -_VALID_API_MODES = {"chat_completions", "responses"} -_VALID_REASONING_EFFORT_STYLES = {"openai", "anthropic"} - - -def _resolve_reasoning_effort_style(value: Any, default: str = "openai") -> str: - style = _coerce_str(value, default).strip().lower() - if style not in _VALID_REASONING_EFFORT_STYLES: - return default - return style - - -def _resolve_thinking_compat_flags( - data: dict[str, Any], - model_name: str, - include_budget_env_key: str, - tool_call_compat_env_key: str, - legacy_env_key: str, -) -> tuple[bool, bool]: - """解析思维链兼容配置,并兼容旧字段 deepseek_new_cot_support。""" - include_budget_value = _get_value( - data, - ("models", model_name, "thinking_include_budget"), - include_budget_env_key, - ) - 
tool_call_compat_value = _get_value( - data, - ("models", model_name, "thinking_tool_call_compat"), - tool_call_compat_env_key, - ) - legacy_value = _get_value( - data, - ("models", model_name, "deepseek_new_cot_support"), - legacy_env_key, - ) - - include_budget_default = True - tool_call_compat_default = True - if legacy_value is not None: - legacy_enabled = _coerce_bool(legacy_value, False) - include_budget_default = not legacy_enabled - tool_call_compat_default = legacy_enabled - - return ( - _coerce_bool(include_budget_value, include_budget_default), - _coerce_bool(tool_call_compat_value, tool_call_compat_default), - ) - - -def _resolve_api_mode( - data: dict[str, Any], - model_name: str, - env_key: str, - default: str = "chat_completions", -) -> str: - raw_value = _get_value(data, ("models", model_name, "api_mode"), env_key) - value = _coerce_str(raw_value, default).strip().lower() - if value not in _VALID_API_MODES: - return default - return value - - -def _resolve_reasoning_effort(value: Any, default: str = "medium") -> str: - return _coerce_str(value, default).strip().lower() - - -def _resolve_responses_tool_choice_compat( - data: dict[str, Any], - model_name: str, - env_key: str, - default: bool = False, -) -> bool: - return _coerce_bool( - _get_value( - data, - ("models", model_name, "responses_tool_choice_compat"), - env_key, - ), - default, - ) - - -def _resolve_responses_force_stateless_replay( - data: dict[str, Any], - model_name: str, - env_key: str, - default: bool = False, -) -> bool: - return _coerce_bool( - _get_value( - data, - ("models", model_name, "responses_force_stateless_replay"), - env_key, - ), - default, - ) - - -def load_local_admins() -> list[int]: - """从本地配置文件加载动态管理员列表""" - if not LOCAL_CONFIG_PATH.exists(): - return [] - try: - with open(LOCAL_CONFIG_PATH, "r", encoding="utf-8") as f: - data = json.load(f) - admin_qqs: list[int] = data.get("admin_qqs", []) - return admin_qqs - except Exception as exc: - logger.warning("读取本地配置失败: 
%s", exc) - return [] - - -def save_local_admins(admin_qqs: list[int]) -> None: - """保存动态管理员列表到本地配置文件""" - try: - data: dict[str, list[int]] = {} - if LOCAL_CONFIG_PATH.exists(): - with open(LOCAL_CONFIG_PATH, "r", encoding="utf-8") as f: - data = json.load(f) - - data["admin_qqs"] = admin_qqs - - with open(LOCAL_CONFIG_PATH, "w", encoding="utf-8") as f: - json.dump(data, f, ensure_ascii=False, indent=2) - - logger.info("已保存管理员列表到 %s", LOCAL_CONFIG_PATH) - except Exception as exc: - logger.error("保存本地配置失败: %s", exc) - raise - - -@dataclass -class WebUISettings: - url: str - port: int - password: str - using_default_password: bool - config_exists: bool - - @property - def display_url(self) -> str: - """用于日志和展示的格式化 URL。""" - from Undefined.config.models import format_netloc - - return f"http://{format_netloc(self.url or '0.0.0.0', self.port)}" - - -def load_webui_settings(config_path: Optional[Path] = None) -> WebUISettings: - data = load_toml_data(config_path) - config_exists = bool(data) - url_value = _get_value(data, ("webui", "url"), None) - port_value = _get_value(data, ("webui", "port"), None) - password_value = _get_value(data, ("webui", "password"), None) - - url = _coerce_str(url_value, DEFAULT_WEBUI_URL) - port = _coerce_int(port_value, DEFAULT_WEBUI_PORT) - if port <= 0 or port > 65535: - port = DEFAULT_WEBUI_PORT - - password_normalized = _normalize_str(password_value) - if not password_normalized: - return WebUISettings( - url=url, - port=port, - password=DEFAULT_WEBUI_PASSWORD, - using_default_password=True, - config_exists=config_exists, - ) - return WebUISettings( - url=url, - port=port, - password=password_normalized, - using_default_password=False, - config_exists=config_exists, - ) - - @dataclass class Config: """应用配置""" @@ -783,8 +532,8 @@ def load(cls, config_path: Optional[Path] = None, strict: bool = True) -> "Confi _get_value(data, ("onebot", "token"), "ONEBOT_TOKEN"), "" ) - embedding_model = cls._parse_embedding_model_config(data) - 
rerank_model = cls._parse_rerank_model_config(data) + embedding_model = _parse_embedding_model_config(data) + rerank_model = _parse_rerank_model_config(data) knowledge_enabled = _coerce_bool( _get_value(data, ("knowledge", "enabled"), None), False @@ -858,8 +607,8 @@ def load(cls, config_path: Optional[Path] = None, strict: bool = True) -> "Confi ) knowledge_rerank_top_k = fallback - chat_model = cls._parse_chat_model_config(data) - vision_model = cls._parse_vision_model_config(data) + chat_model = _parse_chat_model_config(data) + vision_model = _parse_vision_model_config(data) security_model_enabled = _coerce_bool( _get_value( data, @@ -868,20 +617,20 @@ def load(cls, config_path: Optional[Path] = None, strict: bool = True) -> "Confi ), True, ) - security_model = cls._parse_security_model_config(data, chat_model) - naga_model = cls._parse_naga_model_config(data, security_model) - agent_model = cls._parse_agent_model_config(data) - historian_model = cls._parse_historian_model_config(data, agent_model) - summary_model, summary_model_configured = cls._parse_summary_model_config( + security_model = _parse_security_model_config(data, chat_model) + naga_model = _parse_naga_model_config(data, security_model) + agent_model = _parse_agent_model_config(data) + historian_model = _parse_historian_model_config(data, agent_model) + summary_model, summary_model_configured = _parse_summary_model_config( data, agent_model ) - grok_model = cls._parse_grok_model_config(data) + grok_model = _parse_grok_model_config(data) model_pool_enabled = _coerce_bool( _get_value(data, ("features", "pool_enabled"), "MODEL_POOL_ENABLED"), False ) - superadmin_qq, admin_qqs = cls._merge_admins( + superadmin_qq, admin_qqs = _merge_admins( superadmin_qq=superadmin_qq, admin_qqs=admin_qqs ) @@ -1013,7 +762,7 @@ def load(cls, config_path: Optional[Path] = None, strict: bool = True) -> "Confi if easter_egg_mode_raw is not None: _warn_env_fallback("EASTER_EGG_CALL_MESSAGE_MODE") - 
easter_egg_agent_call_message_mode = cls._parse_easter_egg_call_mode( + easter_egg_agent_call_message_mode = _parse_easter_egg_call_mode( easter_egg_mode_raw ) @@ -1443,17 +1192,17 @@ def load(cls, config_path: Optional[Path] = None, strict: bool = True) -> "Confi messages_send_url_file_max_size_mb = 100 webui_settings = load_webui_settings(config_path) - api_config = cls._parse_api_config(data) + api_config = _parse_api_config(data) - cognitive = cls._parse_cognitive_config(data) - memes = cls._parse_memes_config(data) - naga = cls._parse_naga_config(data) - models_image_gen = cls._parse_image_gen_model_config(data) - models_image_edit = cls._parse_image_edit_model_config(data) - image_gen = cls._parse_image_gen_config(data) + cognitive = _parse_cognitive_config(data) + memes = _parse_memes_config(data) + naga = _parse_naga_config(data) + models_image_gen = _parse_image_gen_model_config(data) + models_image_edit = _parse_image_edit_model_config(data) + image_gen = _parse_image_gen_config(data) if strict: - cls._verify_required_fields( + _verify_required_fields( bot_qq=bot_qq, superadmin_qq=superadmin_qq, onebot_ws_url=onebot_ws_url, @@ -1464,7 +1213,7 @@ def load(cls, config_path: Optional[Path] = None, strict: bool = True) -> "Confi embedding_model=embedding_model, ) - cls._log_debug_info( + _log_debug_info( chat_model, vision_model, security_model, @@ -1793,1056 +1542,6 @@ def security_check_enabled(self) -> bool: return bool(self.security_model_enabled) - @staticmethod - def _parse_model_pool( - data: dict[str, Any], - model_section: str, - primary_config: ChatModelConfig | AgentModelConfig, - ) -> ModelPool | None: - """解析模型池配置,缺省字段继承 primary_config""" - pool_data = data.get("models", {}).get(model_section, {}).get("pool") - if not isinstance(pool_data, dict): - return None - - enabled = _coerce_bool(pool_data.get("enabled"), False) - strategy = _coerce_str(pool_data.get("strategy"), "default").strip().lower() - if strategy not in ("default", "round_robin", 
"random"): - strategy = "default" - - raw_models = pool_data.get("models") - if not isinstance(raw_models, list): - return ModelPool(enabled=enabled, strategy=strategy) - - entries: list[ModelPoolEntry] = [] - for item in raw_models: - if not isinstance(item, dict): - continue - name = _coerce_str(item.get("model_name"), "").strip() - if not name: - continue - entries.append( - ModelPoolEntry( - api_url=_coerce_str(item.get("api_url"), primary_config.api_url), - api_key=_coerce_str(item.get("api_key"), primary_config.api_key), - model_name=name, - max_tokens=_coerce_int( - item.get("max_tokens"), primary_config.max_tokens - ), - queue_interval_seconds=_normalize_queue_interval( - _coerce_float( - item.get("queue_interval_seconds"), - primary_config.queue_interval_seconds, - ), - primary_config.queue_interval_seconds, - ), - api_mode=( - _coerce_str(item.get("api_mode"), primary_config.api_mode) - .strip() - .lower() - ) - if _coerce_str(item.get("api_mode"), primary_config.api_mode) - .strip() - .lower() - in _VALID_API_MODES - else primary_config.api_mode, - thinking_enabled=_coerce_bool( - item.get("thinking_enabled"), primary_config.thinking_enabled - ), - thinking_budget_tokens=_coerce_int( - item.get("thinking_budget_tokens"), - primary_config.thinking_budget_tokens, - ), - thinking_include_budget=_coerce_bool( - item.get("thinking_include_budget"), - primary_config.thinking_include_budget, - ), - reasoning_effort_style=_resolve_reasoning_effort_style( - item.get("reasoning_effort_style"), - primary_config.reasoning_effort_style, - ), - thinking_tool_call_compat=_coerce_bool( - item.get("thinking_tool_call_compat"), - primary_config.thinking_tool_call_compat, - ), - responses_tool_choice_compat=_coerce_bool( - item.get("responses_tool_choice_compat"), - primary_config.responses_tool_choice_compat, - ), - responses_force_stateless_replay=_coerce_bool( - item.get("responses_force_stateless_replay"), - primary_config.responses_force_stateless_replay, - ), - 
prompt_cache_enabled=_coerce_bool( - item.get("prompt_cache_enabled"), - primary_config.prompt_cache_enabled, - ), - reasoning_enabled=_coerce_bool( - item.get("reasoning_enabled"), - primary_config.reasoning_enabled, - ), - reasoning_effort=_resolve_reasoning_effort( - item.get("reasoning_effort"), - primary_config.reasoning_effort, - ), - request_params=merge_request_params( - primary_config.request_params, - item.get("request_params"), - ), - ) - ) - - return ModelPool(enabled=enabled, strategy=strategy, models=entries) - - @staticmethod - def _parse_embedding_model_config(data: dict[str, Any]) -> EmbeddingModelConfig: - return EmbeddingModelConfig( - api_url=_coerce_str( - _get_value( - data, ("models", "embedding", "api_url"), "EMBEDDING_MODEL_API_URL" - ), - "", - ), - api_key=_coerce_str( - _get_value( - data, ("models", "embedding", "api_key"), "EMBEDDING_MODEL_API_KEY" - ), - "", - ), - model_name=_coerce_str( - _get_value( - data, ("models", "embedding", "model_name"), "EMBEDDING_MODEL_NAME" - ), - "", - ), - queue_interval_seconds=_normalize_queue_interval( - _coerce_float( - _get_value( - data, ("models", "embedding", "queue_interval_seconds"), None - ), - 0.0, - ), - 0.0, - ), - dimensions=_coerce_int( - _get_value(data, ("models", "embedding", "dimensions"), None), 0 - ) - or None, - query_instruction=_coerce_str( - _get_value(data, ("models", "embedding", "query_instruction"), None), "" - ), - document_instruction=_coerce_str( - _get_value(data, ("models", "embedding", "document_instruction"), None), - "", - ), - request_params=_get_model_request_params(data, "embedding"), - ) - - @staticmethod - def _parse_rerank_model_config(data: dict[str, Any]) -> RerankModelConfig: - queue_interval_seconds = _normalize_queue_interval( - _coerce_float( - _get_value(data, ("models", "rerank", "queue_interval_seconds"), None), - 0.0, - ), - 0.0, - ) - return RerankModelConfig( - api_url=_coerce_str( - _get_value( - data, ("models", "rerank", "api_url"), 
"RERANK_MODEL_API_URL" - ), - "", - ), - api_key=_coerce_str( - _get_value( - data, ("models", "rerank", "api_key"), "RERANK_MODEL_API_KEY" - ), - "", - ), - model_name=_coerce_str( - _get_value( - data, ("models", "rerank", "model_name"), "RERANK_MODEL_NAME" - ), - "", - ), - queue_interval_seconds=queue_interval_seconds, - query_instruction=_coerce_str( - _get_value(data, ("models", "rerank", "query_instruction"), None), "" - ), - request_params=_get_model_request_params(data, "rerank"), - ) - - @staticmethod - def _parse_chat_model_config(data: dict[str, Any]) -> ChatModelConfig: - queue_interval_seconds = _normalize_queue_interval( - _coerce_float( - _get_value( - data, - ("models", "chat", "queue_interval_seconds"), - "CHAT_MODEL_QUEUE_INTERVAL", - ), - 1.0, - ) - ) - thinking_include_budget, thinking_tool_call_compat = ( - _resolve_thinking_compat_flags( - data=data, - model_name="chat", - include_budget_env_key="CHAT_MODEL_THINKING_INCLUDE_BUDGET", - tool_call_compat_env_key="CHAT_MODEL_THINKING_TOOL_CALL_COMPAT", - legacy_env_key="CHAT_MODEL_DEEPSEEK_NEW_COT_SUPPORT", - ) - ) - api_mode = _resolve_api_mode(data, "chat", "CHAT_MODEL_API_MODE") - responses_tool_choice_compat = _resolve_responses_tool_choice_compat( - data, "chat", "CHAT_MODEL_RESPONSES_TOOL_CHOICE_COMPAT" - ) - responses_force_stateless_replay = _resolve_responses_force_stateless_replay( - data, "chat", "CHAT_MODEL_RESPONSES_FORCE_STATELESS_REPLAY" - ) - prompt_cache_enabled = _coerce_bool( - _get_value( - data, - ("models", "chat", "prompt_cache_enabled"), - "CHAT_MODEL_PROMPT_CACHE_ENABLED", - ), - True, - ) - reasoning_enabled = _coerce_bool( - _get_value( - data, - ("models", "chat", "reasoning_enabled"), - "CHAT_MODEL_REASONING_ENABLED", - ), - False, - ) - reasoning_effort = _resolve_reasoning_effort( - _get_value( - data, - ("models", "chat", "reasoning_effort"), - "CHAT_MODEL_REASONING_EFFORT", - ), - "medium", - ) - config = ChatModelConfig( - api_url=_coerce_str( - _get_value(data, 
("models", "chat", "api_url"), "CHAT_MODEL_API_URL"), - "", - ), - api_key=_coerce_str( - _get_value(data, ("models", "chat", "api_key"), "CHAT_MODEL_API_KEY"), - "", - ), - model_name=_coerce_str( - _get_value(data, ("models", "chat", "model_name"), "CHAT_MODEL_NAME"), - "", - ), - max_tokens=_coerce_int( - _get_value( - data, ("models", "chat", "max_tokens"), "CHAT_MODEL_MAX_TOKENS" - ), - 8192, - ), - queue_interval_seconds=queue_interval_seconds, - api_mode=api_mode, - thinking_enabled=_coerce_bool( - _get_value( - data, - ("models", "chat", "thinking_enabled"), - "CHAT_MODEL_THINKING_ENABLED", - ), - False, - ), - thinking_budget_tokens=_coerce_int( - _get_value( - data, - ("models", "chat", "thinking_budget_tokens"), - "CHAT_MODEL_THINKING_BUDGET_TOKENS", - ), - 20000, - ), - thinking_include_budget=thinking_include_budget, - reasoning_effort_style=_resolve_reasoning_effort_style( - _get_value( - data, - ("models", "chat", "reasoning_effort_style"), - "CHAT_MODEL_REASONING_EFFORT_STYLE", - ), - ), - thinking_tool_call_compat=thinking_tool_call_compat, - responses_tool_choice_compat=responses_tool_choice_compat, - responses_force_stateless_replay=responses_force_stateless_replay, - prompt_cache_enabled=prompt_cache_enabled, - reasoning_enabled=reasoning_enabled, - reasoning_effort=reasoning_effort, - request_params=_get_model_request_params(data, "chat"), - ) - config.pool = Config._parse_model_pool(data, "chat", config) - return config - - @staticmethod - def _parse_vision_model_config(data: dict[str, Any]) -> VisionModelConfig: - queue_interval_seconds = _normalize_queue_interval( - _coerce_float( - _get_value( - data, - ("models", "vision", "queue_interval_seconds"), - "VISION_MODEL_QUEUE_INTERVAL", - ), - 1.0, - ) - ) - thinking_include_budget, thinking_tool_call_compat = ( - _resolve_thinking_compat_flags( - data=data, - model_name="vision", - include_budget_env_key="VISION_MODEL_THINKING_INCLUDE_BUDGET", - 
tool_call_compat_env_key="VISION_MODEL_THINKING_TOOL_CALL_COMPAT", - legacy_env_key="VISION_MODEL_DEEPSEEK_NEW_COT_SUPPORT", - ) - ) - api_mode = _resolve_api_mode(data, "vision", "VISION_MODEL_API_MODE") - responses_tool_choice_compat = _resolve_responses_tool_choice_compat( - data, "vision", "VISION_MODEL_RESPONSES_TOOL_CHOICE_COMPAT" - ) - responses_force_stateless_replay = _resolve_responses_force_stateless_replay( - data, "vision", "VISION_MODEL_RESPONSES_FORCE_STATELESS_REPLAY" - ) - prompt_cache_enabled = _coerce_bool( - _get_value( - data, - ("models", "vision", "prompt_cache_enabled"), - "VISION_MODEL_PROMPT_CACHE_ENABLED", - ), - True, - ) - reasoning_enabled = _coerce_bool( - _get_value( - data, - ("models", "vision", "reasoning_enabled"), - "VISION_MODEL_REASONING_ENABLED", - ), - False, - ) - reasoning_effort = _resolve_reasoning_effort( - _get_value( - data, - ("models", "vision", "reasoning_effort"), - "VISION_MODEL_REASONING_EFFORT", - ), - "medium", - ) - return VisionModelConfig( - api_url=_coerce_str( - _get_value( - data, ("models", "vision", "api_url"), "VISION_MODEL_API_URL" - ), - "", - ), - api_key=_coerce_str( - _get_value( - data, ("models", "vision", "api_key"), "VISION_MODEL_API_KEY" - ), - "", - ), - model_name=_coerce_str( - _get_value( - data, ("models", "vision", "model_name"), "VISION_MODEL_NAME" - ), - "", - ), - queue_interval_seconds=queue_interval_seconds, - api_mode=api_mode, - thinking_enabled=_coerce_bool( - _get_value( - data, - ("models", "vision", "thinking_enabled"), - "VISION_MODEL_THINKING_ENABLED", - ), - False, - ), - thinking_budget_tokens=_coerce_int( - _get_value( - data, - ("models", "vision", "thinking_budget_tokens"), - "VISION_MODEL_THINKING_BUDGET_TOKENS", - ), - 20000, - ), - thinking_include_budget=thinking_include_budget, - reasoning_effort_style=_resolve_reasoning_effort_style( - _get_value( - data, - ("models", "vision", "reasoning_effort_style"), - "VISION_MODEL_REASONING_EFFORT_STYLE", - ), - ), - 
thinking_tool_call_compat=thinking_tool_call_compat, - responses_tool_choice_compat=responses_tool_choice_compat, - responses_force_stateless_replay=responses_force_stateless_replay, - prompt_cache_enabled=prompt_cache_enabled, - reasoning_enabled=reasoning_enabled, - reasoning_effort=reasoning_effort, - request_params=_get_model_request_params(data, "vision"), - ) - - @staticmethod - def _parse_security_model_config( - data: dict[str, Any], chat_model: ChatModelConfig - ) -> SecurityModelConfig: - api_url = _coerce_str( - _get_value( - data, ("models", "security", "api_url"), "SECURITY_MODEL_API_URL" - ), - "", - ) - api_key = _coerce_str( - _get_value( - data, ("models", "security", "api_key"), "SECURITY_MODEL_API_KEY" - ), - "", - ) - model_name = _coerce_str( - _get_value( - data, ("models", "security", "model_name"), "SECURITY_MODEL_NAME" - ), - "", - ) - queue_interval_seconds = _coerce_float( - _get_value( - data, - ("models", "security", "queue_interval_seconds"), - "SECURITY_MODEL_QUEUE_INTERVAL", - ), - 1.0, - ) - queue_interval_seconds = _normalize_queue_interval(queue_interval_seconds) - - thinking_include_budget, thinking_tool_call_compat = ( - _resolve_thinking_compat_flags( - data=data, - model_name="security", - include_budget_env_key="SECURITY_MODEL_THINKING_INCLUDE_BUDGET", - tool_call_compat_env_key="SECURITY_MODEL_THINKING_TOOL_CALL_COMPAT", - legacy_env_key="SECURITY_MODEL_DEEPSEEK_NEW_COT_SUPPORT", - ) - ) - api_mode = _resolve_api_mode(data, "security", "SECURITY_MODEL_API_MODE") - responses_tool_choice_compat = _resolve_responses_tool_choice_compat( - data, "security", "SECURITY_MODEL_RESPONSES_TOOL_CHOICE_COMPAT" - ) - responses_force_stateless_replay = _resolve_responses_force_stateless_replay( - data, "security", "SECURITY_MODEL_RESPONSES_FORCE_STATELESS_REPLAY" - ) - prompt_cache_enabled = _coerce_bool( - _get_value( - data, - ("models", "security", "prompt_cache_enabled"), - "SECURITY_MODEL_PROMPT_CACHE_ENABLED", - ), - True, - ) - 
reasoning_enabled = _coerce_bool( - _get_value( - data, - ("models", "security", "reasoning_enabled"), - "SECURITY_MODEL_REASONING_ENABLED", - ), - False, - ) - reasoning_effort = _resolve_reasoning_effort( - _get_value( - data, - ("models", "security", "reasoning_effort"), - "SECURITY_MODEL_REASONING_EFFORT", - ), - "medium", - ) - - if api_url and api_key and model_name: - return SecurityModelConfig( - api_url=api_url, - api_key=api_key, - model_name=model_name, - max_tokens=_coerce_int( - _get_value( - data, - ("models", "security", "max_tokens"), - "SECURITY_MODEL_MAX_TOKENS", - ), - 100, - ), - queue_interval_seconds=queue_interval_seconds, - api_mode=api_mode, - thinking_enabled=_coerce_bool( - _get_value( - data, - ("models", "security", "thinking_enabled"), - "SECURITY_MODEL_THINKING_ENABLED", - ), - False, - ), - thinking_budget_tokens=_coerce_int( - _get_value( - data, - ("models", "security", "thinking_budget_tokens"), - "SECURITY_MODEL_THINKING_BUDGET_TOKENS", - ), - 0, - ), - thinking_include_budget=thinking_include_budget, - reasoning_effort_style=_resolve_reasoning_effort_style( - _get_value( - data, - ("models", "security", "reasoning_effort_style"), - "SECURITY_MODEL_REASONING_EFFORT_STYLE", - ), - ), - thinking_tool_call_compat=thinking_tool_call_compat, - responses_tool_choice_compat=responses_tool_choice_compat, - responses_force_stateless_replay=responses_force_stateless_replay, - prompt_cache_enabled=prompt_cache_enabled, - reasoning_enabled=reasoning_enabled, - reasoning_effort=reasoning_effort, - request_params=_get_model_request_params(data, "security"), - ) - - logger.warning("未配置安全模型,将使用对话模型作为后备") - return SecurityModelConfig( - api_url=chat_model.api_url, - api_key=chat_model.api_key, - model_name=chat_model.model_name, - max_tokens=chat_model.max_tokens, - queue_interval_seconds=chat_model.queue_interval_seconds, - api_mode=chat_model.api_mode, - thinking_enabled=False, - thinking_budget_tokens=0, - thinking_include_budget=True, - 
reasoning_effort_style="openai", - thinking_tool_call_compat=chat_model.thinking_tool_call_compat, - responses_tool_choice_compat=chat_model.responses_tool_choice_compat, - responses_force_stateless_replay=chat_model.responses_force_stateless_replay, - prompt_cache_enabled=chat_model.prompt_cache_enabled, - reasoning_enabled=chat_model.reasoning_enabled, - reasoning_effort=chat_model.reasoning_effort, - request_params=merge_request_params(chat_model.request_params), - ) - - @staticmethod - def _parse_naga_model_config( - data: dict[str, Any], security_model: SecurityModelConfig - ) -> SecurityModelConfig: - api_url = _coerce_str( - _get_value(data, ("models", "naga", "api_url"), "NAGA_MODEL_API_URL"), - "", - ) - api_key = _coerce_str( - _get_value(data, ("models", "naga", "api_key"), "NAGA_MODEL_API_KEY"), - "", - ) - model_name = _coerce_str( - _get_value(data, ("models", "naga", "model_name"), "NAGA_MODEL_NAME"), - "", - ) - queue_interval_seconds = _coerce_float( - _get_value( - data, - ("models", "naga", "queue_interval_seconds"), - "NAGA_MODEL_QUEUE_INTERVAL", - ), - security_model.queue_interval_seconds, - ) - queue_interval_seconds = _normalize_queue_interval(queue_interval_seconds) - - thinking_include_budget, thinking_tool_call_compat = ( - _resolve_thinking_compat_flags( - data=data, - model_name="naga", - include_budget_env_key="NAGA_MODEL_THINKING_INCLUDE_BUDGET", - tool_call_compat_env_key="NAGA_MODEL_THINKING_TOOL_CALL_COMPAT", - legacy_env_key="NAGA_MODEL_DEEPSEEK_NEW_COT_SUPPORT", - ) - ) - api_mode = _resolve_api_mode(data, "naga", "NAGA_MODEL_API_MODE") - responses_tool_choice_compat = _resolve_responses_tool_choice_compat( - data, "naga", "NAGA_MODEL_RESPONSES_TOOL_CHOICE_COMPAT" - ) - responses_force_stateless_replay = _resolve_responses_force_stateless_replay( - data, "naga", "NAGA_MODEL_RESPONSES_FORCE_STATELESS_REPLAY" - ) - prompt_cache_enabled = _coerce_bool( - _get_value( - data, - ("models", "naga", "prompt_cache_enabled"), - 
"NAGA_MODEL_PROMPT_CACHE_ENABLED", - ), - getattr(security_model, "prompt_cache_enabled", True), - ) - reasoning_enabled = _coerce_bool( - _get_value( - data, - ("models", "naga", "reasoning_enabled"), - "NAGA_MODEL_REASONING_ENABLED", - ), - getattr(security_model, "reasoning_enabled", False), - ) - reasoning_effort = _resolve_reasoning_effort( - _get_value( - data, - ("models", "naga", "reasoning_effort"), - "NAGA_MODEL_REASONING_EFFORT", - ), - getattr(security_model, "reasoning_effort", "medium"), - ) - - if api_url and api_key and model_name: - return SecurityModelConfig( - api_url=api_url, - api_key=api_key, - model_name=model_name, - max_tokens=_coerce_int( - _get_value( - data, - ("models", "naga", "max_tokens"), - "NAGA_MODEL_MAX_TOKENS", - ), - 160, - ), - queue_interval_seconds=queue_interval_seconds, - api_mode=api_mode, - thinking_enabled=_coerce_bool( - _get_value( - data, - ("models", "naga", "thinking_enabled"), - "NAGA_MODEL_THINKING_ENABLED", - ), - False, - ), - thinking_budget_tokens=_coerce_int( - _get_value( - data, - ("models", "naga", "thinking_budget_tokens"), - "NAGA_MODEL_THINKING_BUDGET_TOKENS", - ), - 0, - ), - thinking_include_budget=thinking_include_budget, - reasoning_effort_style=_resolve_reasoning_effort_style( - _get_value( - data, - ("models", "naga", "reasoning_effort_style"), - "NAGA_MODEL_REASONING_EFFORT_STYLE", - ), - ), - thinking_tool_call_compat=thinking_tool_call_compat, - responses_tool_choice_compat=responses_tool_choice_compat, - responses_force_stateless_replay=responses_force_stateless_replay, - prompt_cache_enabled=prompt_cache_enabled, - reasoning_enabled=reasoning_enabled, - reasoning_effort=reasoning_effort, - request_params=_get_model_request_params(data, "naga"), - ) - - logger.info( - "未配置 Naga 审核模型,将使用已解析的安全模型配置作为后备(安全模型本身可能已回退)" - ) - return SecurityModelConfig( - api_url=security_model.api_url, - api_key=security_model.api_key, - model_name=security_model.model_name, - max_tokens=security_model.max_tokens, 
- queue_interval_seconds=security_model.queue_interval_seconds, - api_mode=security_model.api_mode, - thinking_enabled=security_model.thinking_enabled, - thinking_budget_tokens=security_model.thinking_budget_tokens, - thinking_include_budget=security_model.thinking_include_budget, - reasoning_effort_style=security_model.reasoning_effort_style, - thinking_tool_call_compat=security_model.thinking_tool_call_compat, - responses_tool_choice_compat=security_model.responses_tool_choice_compat, - responses_force_stateless_replay=security_model.responses_force_stateless_replay, - prompt_cache_enabled=security_model.prompt_cache_enabled, - reasoning_enabled=security_model.reasoning_enabled, - reasoning_effort=security_model.reasoning_effort, - request_params=merge_request_params(security_model.request_params), - ) - - @staticmethod - def _parse_agent_model_config(data: dict[str, Any]) -> AgentModelConfig: - queue_interval_seconds = _normalize_queue_interval( - _coerce_float( - _get_value( - data, - ("models", "agent", "queue_interval_seconds"), - "AGENT_MODEL_QUEUE_INTERVAL", - ), - 1.0, - ) - ) - thinking_include_budget, thinking_tool_call_compat = ( - _resolve_thinking_compat_flags( - data=data, - model_name="agent", - include_budget_env_key="AGENT_MODEL_THINKING_INCLUDE_BUDGET", - tool_call_compat_env_key="AGENT_MODEL_THINKING_TOOL_CALL_COMPAT", - legacy_env_key="AGENT_MODEL_DEEPSEEK_NEW_COT_SUPPORT", - ) - ) - api_mode = _resolve_api_mode(data, "agent", "AGENT_MODEL_API_MODE") - responses_tool_choice_compat = _resolve_responses_tool_choice_compat( - data, "agent", "AGENT_MODEL_RESPONSES_TOOL_CHOICE_COMPAT" - ) - responses_force_stateless_replay = _resolve_responses_force_stateless_replay( - data, "agent", "AGENT_MODEL_RESPONSES_FORCE_STATELESS_REPLAY" - ) - prompt_cache_enabled = _coerce_bool( - _get_value( - data, - ("models", "agent", "prompt_cache_enabled"), - "AGENT_MODEL_PROMPT_CACHE_ENABLED", - ), - True, - ) - reasoning_enabled = _coerce_bool( - _get_value( - 
data, - ("models", "agent", "reasoning_enabled"), - "AGENT_MODEL_REASONING_ENABLED", - ), - False, - ) - reasoning_effort = _resolve_reasoning_effort( - _get_value( - data, - ("models", "agent", "reasoning_effort"), - "AGENT_MODEL_REASONING_EFFORT", - ), - "medium", - ) - config = AgentModelConfig( - api_url=_coerce_str( - _get_value(data, ("models", "agent", "api_url"), "AGENT_MODEL_API_URL"), - "", - ), - api_key=_coerce_str( - _get_value(data, ("models", "agent", "api_key"), "AGENT_MODEL_API_KEY"), - "", - ), - model_name=_coerce_str( - _get_value(data, ("models", "agent", "model_name"), "AGENT_MODEL_NAME"), - "", - ), - max_tokens=_coerce_int( - _get_value( - data, ("models", "agent", "max_tokens"), "AGENT_MODEL_MAX_TOKENS" - ), - 4096, - ), - queue_interval_seconds=queue_interval_seconds, - api_mode=api_mode, - thinking_enabled=_coerce_bool( - _get_value( - data, - ("models", "agent", "thinking_enabled"), - "AGENT_MODEL_THINKING_ENABLED", - ), - False, - ), - thinking_budget_tokens=_coerce_int( - _get_value( - data, - ("models", "agent", "thinking_budget_tokens"), - "AGENT_MODEL_THINKING_BUDGET_TOKENS", - ), - 0, - ), - thinking_include_budget=thinking_include_budget, - reasoning_effort_style=_resolve_reasoning_effort_style( - _get_value( - data, - ("models", "agent", "reasoning_effort_style"), - "AGENT_MODEL_REASONING_EFFORT_STYLE", - ), - ), - thinking_tool_call_compat=thinking_tool_call_compat, - responses_tool_choice_compat=responses_tool_choice_compat, - responses_force_stateless_replay=responses_force_stateless_replay, - prompt_cache_enabled=prompt_cache_enabled, - reasoning_enabled=reasoning_enabled, - reasoning_effort=reasoning_effort, - request_params=_get_model_request_params(data, "agent"), - ) - config.pool = Config._parse_model_pool(data, "agent", config) - return config - - @staticmethod - def _parse_grok_model_config(data: dict[str, Any]) -> GrokModelConfig: - queue_interval_seconds = _normalize_queue_interval( - _coerce_float( - _get_value( - 
data, - ("models", "grok", "queue_interval_seconds"), - "GROK_MODEL_QUEUE_INTERVAL", - ), - 1.0, - ) - ) - return GrokModelConfig( - api_url=_coerce_str( - _get_value(data, ("models", "grok", "api_url"), "GROK_MODEL_API_URL"), - "", - ), - api_key=_coerce_str( - _get_value(data, ("models", "grok", "api_key"), "GROK_MODEL_API_KEY"), - "", - ), - model_name=_coerce_str( - _get_value(data, ("models", "grok", "model_name"), "GROK_MODEL_NAME"), - "", - ), - max_tokens=_coerce_int( - _get_value( - data, ("models", "grok", "max_tokens"), "GROK_MODEL_MAX_TOKENS" - ), - 8192, - ), - queue_interval_seconds=queue_interval_seconds, - thinking_enabled=_coerce_bool( - _get_value( - data, - ("models", "grok", "thinking_enabled"), - "GROK_MODEL_THINKING_ENABLED", - ), - False, - ), - thinking_budget_tokens=_coerce_int( - _get_value( - data, - ("models", "grok", "thinking_budget_tokens"), - "GROK_MODEL_THINKING_BUDGET_TOKENS", - ), - 20000, - ), - thinking_include_budget=_coerce_bool( - _get_value( - data, - ("models", "grok", "thinking_include_budget"), - "GROK_MODEL_THINKING_INCLUDE_BUDGET", - ), - True, - ), - reasoning_effort_style=_resolve_reasoning_effort_style( - _get_value( - data, - ("models", "grok", "reasoning_effort_style"), - "GROK_MODEL_REASONING_EFFORT_STYLE", - ), - ), - prompt_cache_enabled=_coerce_bool( - _get_value( - data, - ("models", "grok", "prompt_cache_enabled"), - "GROK_MODEL_PROMPT_CACHE_ENABLED", - ), - True, - ), - reasoning_enabled=_coerce_bool( - _get_value( - data, - ("models", "grok", "reasoning_enabled"), - "GROK_MODEL_REASONING_ENABLED", - ), - False, - ), - reasoning_effort=_resolve_reasoning_effort( - _get_value( - data, - ("models", "grok", "reasoning_effort"), - "GROK_MODEL_REASONING_EFFORT", - ), - "medium", - ), - request_params=_get_model_request_params(data, "grok"), - ) - - @staticmethod - def _parse_image_gen_model_config(data: dict[str, Any]) -> ImageGenModelConfig: - """解析 [models.image_gen] 生图模型配置""" - return ImageGenModelConfig( - 
api_url=_coerce_str( - _get_value( - data, ("models", "image_gen", "api_url"), "IMAGE_GEN_MODEL_API_URL" - ), - "", - ), - api_key=_coerce_str( - _get_value( - data, ("models", "image_gen", "api_key"), "IMAGE_GEN_MODEL_API_KEY" - ), - "", - ), - model_name=_coerce_str( - _get_value( - data, ("models", "image_gen", "model_name"), "IMAGE_GEN_MODEL_NAME" - ), - "", - ), - request_params=_get_model_request_params(data, "image_gen"), - ) - - @staticmethod - def _parse_image_edit_model_config(data: dict[str, Any]) -> ImageGenModelConfig: - """解析 [models.image_edit] 参考图生图模型配置""" - return ImageGenModelConfig( - api_url=_coerce_str( - _get_value( - data, - ("models", "image_edit", "api_url"), - "IMAGE_EDIT_MODEL_API_URL", - ), - "", - ), - api_key=_coerce_str( - _get_value( - data, - ("models", "image_edit", "api_key"), - "IMAGE_EDIT_MODEL_API_KEY", - ), - "", - ), - model_name=_coerce_str( - _get_value( - data, - ("models", "image_edit", "model_name"), - "IMAGE_EDIT_MODEL_NAME", - ), - "", - ), - request_params=_get_model_request_params(data, "image_edit"), - ) - - @staticmethod - def _parse_image_gen_config(data: dict[str, Any]) -> ImageGenConfig: - """解析 [image_gen] 生图工具配置""" - return ImageGenConfig( - provider=_coerce_str( - _get_value(data, ("image_gen", "provider"), "IMAGE_GEN_PROVIDER"), - "xingzhige", - ), - xingzhige_size=_coerce_str( - _get_value(data, ("image_gen", "xingzhige_size"), None), "1:1" - ), - openai_size=_coerce_str( - _get_value(data, ("image_gen", "openai_size"), None), "" - ), - openai_quality=_coerce_str( - _get_value(data, ("image_gen", "openai_quality"), None), "" - ), - openai_style=_coerce_str( - _get_value(data, ("image_gen", "openai_style"), None), "" - ), - openai_timeout=_coerce_float( - _get_value(data, ("image_gen", "openai_timeout"), None), 120.0 - ), - ) - - @staticmethod - def _merge_admins( - superadmin_qq: int, admin_qqs: list[int] - ) -> tuple[int, list[int]]: - local_admins = load_local_admins() - all_admins = list(set(admin_qqs + 
local_admins)) - if superadmin_qq and superadmin_qq not in all_admins: - all_admins.append(superadmin_qq) - return superadmin_qq, all_admins - - @staticmethod - def _verify_required_fields( - bot_qq: int, - superadmin_qq: int, - onebot_ws_url: str, - chat_model: ChatModelConfig, - vision_model: VisionModelConfig, - agent_model: AgentModelConfig, - knowledge_enabled: bool, - embedding_model: EmbeddingModelConfig, - ) -> None: - missing: list[str] = [] - if bot_qq <= 0: - missing.append("core.bot_qq") - if superadmin_qq <= 0: - missing.append("core.superadmin_qq") - if not onebot_ws_url: - missing.append("onebot.ws_url") - if not chat_model.api_url: - missing.append("models.chat.api_url") - if not chat_model.api_key: - missing.append("models.chat.api_key") - if not chat_model.model_name: - missing.append("models.chat.model_name") - if not vision_model.api_url: - missing.append("models.vision.api_url") - if not vision_model.api_key: - missing.append("models.vision.api_key") - if not vision_model.model_name: - missing.append("models.vision.model_name") - if not agent_model.api_url: - missing.append("models.agent.api_url") - if not agent_model.api_key: - missing.append("models.agent.api_key") - if not agent_model.model_name: - missing.append("models.agent.model_name") - if knowledge_enabled: - if not embedding_model.api_url: - missing.append("models.embedding.api_url") - if not embedding_model.model_name: - missing.append("models.embedding.model_name") - if missing: - raise ValueError(f"缺少必需配置: {', '.join(missing)}") - - @staticmethod - def _log_debug_info( - chat_model: ChatModelConfig, - vision_model: VisionModelConfig, - security_model: SecurityModelConfig, - naga_model: SecurityModelConfig, - agent_model: AgentModelConfig, - summary_model: AgentModelConfig, - grok_model: GrokModelConfig, - ) -> None: - configs: list[ - tuple[ - str, - ChatModelConfig - | VisionModelConfig - | SecurityModelConfig - | AgentModelConfig - | GrokModelConfig, - ] - ] = [ - ("chat", 
chat_model), - ("vision", vision_model), - ("security", security_model), - ("naga", naga_model), - ("agent", agent_model), - ("summary", summary_model), - ("grok", grok_model), - ] - for name, cfg in configs: - logger.debug( - "[配置] %s_model=%s api_url=%s api_key_set=%s api_mode=%s thinking=%s reasoning=%s/%s cot_compat=%s responses_tool_choice_compat=%s responses_force_stateless_replay=%s", - name, - cfg.model_name, - cfg.api_url, - bool(cfg.api_key), - getattr(cfg, "api_mode", "chat_completions"), - cfg.thinking_enabled, - getattr(cfg, "reasoning_enabled", False), - getattr(cfg, "reasoning_effort", "medium"), - getattr(cfg, "thinking_tool_call_compat", False), - getattr(cfg, "responses_tool_choice_compat", False), - getattr(cfg, "responses_force_stateless_replay", False), - ) - def update_from(self, new_config: "Config") -> dict[str, tuple[Any, Any]]: changes: dict[str, tuple[Any, Any]] = {} for field in fields(self): @@ -2866,457 +1565,6 @@ def update_from(self, new_config: "Config") -> dict[str, tuple[Any, Any]]: changes[name] = (old_value, new_value) return changes - @staticmethod - def _parse_historian_model_config( - data: dict[str, Any], fallback: AgentModelConfig - ) -> AgentModelConfig: - h = data.get("models", {}).get("historian", {}) - if not isinstance(h, dict) or not h: - return fallback - queue_interval_seconds = _coerce_float( - h.get("queue_interval_seconds"), fallback.queue_interval_seconds - ) - queue_interval_seconds = _normalize_queue_interval( - queue_interval_seconds, fallback.queue_interval_seconds - ) - thinking_include_budget, thinking_tool_call_compat = ( - _resolve_thinking_compat_flags( - data={"models": {"historian": h}}, - model_name="historian", - include_budget_env_key="HISTORIAN_MODEL_THINKING_INCLUDE_BUDGET", - tool_call_compat_env_key="HISTORIAN_MODEL_THINKING_TOOL_CALL_COMPAT", - legacy_env_key="HISTORIAN_MODEL_DEEPSEEK_NEW_COT_SUPPORT", - ) - ) - api_mode = _resolve_api_mode( - {"models": {"historian": h}}, - "historian", - 
"HISTORIAN_MODEL_API_MODE", - fallback.api_mode, - ) - responses_tool_choice_compat = _resolve_responses_tool_choice_compat( - {"models": {"historian": h}}, - "historian", - "HISTORIAN_MODEL_RESPONSES_TOOL_CHOICE_COMPAT", - fallback.responses_tool_choice_compat, - ) - responses_force_stateless_replay = _resolve_responses_force_stateless_replay( - {"models": {"historian": h}}, - "historian", - "HISTORIAN_MODEL_RESPONSES_FORCE_STATELESS_REPLAY", - fallback.responses_force_stateless_replay, - ) - prompt_cache_enabled = _coerce_bool( - _get_value( - {"models": {"historian": h}}, - ("models", "historian", "prompt_cache_enabled"), - "HISTORIAN_MODEL_PROMPT_CACHE_ENABLED", - ), - fallback.prompt_cache_enabled, - ) - return AgentModelConfig( - api_url=_coerce_str(h.get("api_url"), fallback.api_url), - api_key=_coerce_str(h.get("api_key"), fallback.api_key), - model_name=_coerce_str(h.get("model_name"), fallback.model_name), - max_tokens=_coerce_int(h.get("max_tokens"), fallback.max_tokens), - queue_interval_seconds=queue_interval_seconds, - api_mode=api_mode, - thinking_enabled=_coerce_bool( - h.get("thinking_enabled"), fallback.thinking_enabled - ), - thinking_budget_tokens=_coerce_int( - h.get("thinking_budget_tokens"), fallback.thinking_budget_tokens - ), - thinking_include_budget=thinking_include_budget, - reasoning_effort_style=_resolve_reasoning_effort_style( - _get_value( - {"models": {"historian": h}}, - ("models", "historian", "reasoning_effort_style"), - "HISTORIAN_MODEL_REASONING_EFFORT_STYLE", - ), - fallback.reasoning_effort_style, - ), - thinking_tool_call_compat=thinking_tool_call_compat, - responses_tool_choice_compat=responses_tool_choice_compat, - responses_force_stateless_replay=responses_force_stateless_replay, - prompt_cache_enabled=prompt_cache_enabled, - reasoning_enabled=_coerce_bool( - _get_value( - {"models": {"historian": h}}, - ("models", "historian", "reasoning_enabled"), - "HISTORIAN_MODEL_REASONING_ENABLED", - ), - fallback.reasoning_enabled, 
- ), - reasoning_effort=_resolve_reasoning_effort( - _get_value( - {"models": {"historian": h}}, - ("models", "historian", "reasoning_effort"), - "HISTORIAN_MODEL_REASONING_EFFORT", - ), - fallback.reasoning_effort, - ), - request_params=merge_request_params( - fallback.request_params, - h.get("request_params"), - ), - ) - - @staticmethod - def _parse_summary_model_config( - data: dict[str, Any], fallback: AgentModelConfig - ) -> tuple[AgentModelConfig, bool]: - s = data.get("models", {}).get("summary", {}) - if not isinstance(s, dict) or not s: - return fallback, False - queue_interval_seconds = _coerce_float( - s.get("queue_interval_seconds"), fallback.queue_interval_seconds - ) - queue_interval_seconds = _normalize_queue_interval( - queue_interval_seconds, fallback.queue_interval_seconds - ) - thinking_include_budget, thinking_tool_call_compat = ( - _resolve_thinking_compat_flags( - data={"models": {"summary": s}}, - model_name="summary", - include_budget_env_key="SUMMARY_MODEL_THINKING_INCLUDE_BUDGET", - tool_call_compat_env_key="SUMMARY_MODEL_THINKING_TOOL_CALL_COMPAT", - legacy_env_key="SUMMARY_MODEL_DEEPSEEK_NEW_COT_SUPPORT", - ) - ) - api_mode = _resolve_api_mode( - {"models": {"summary": s}}, - "summary", - "SUMMARY_MODEL_API_MODE", - fallback.api_mode, - ) - responses_tool_choice_compat = _resolve_responses_tool_choice_compat( - {"models": {"summary": s}}, - "summary", - "SUMMARY_MODEL_RESPONSES_TOOL_CHOICE_COMPAT", - fallback.responses_tool_choice_compat, - ) - responses_force_stateless_replay = _resolve_responses_force_stateless_replay( - {"models": {"summary": s}}, - "summary", - "SUMMARY_MODEL_RESPONSES_FORCE_STATELESS_REPLAY", - fallback.responses_force_stateless_replay, - ) - prompt_cache_enabled = _coerce_bool( - _get_value( - {"models": {"summary": s}}, - ("models", "summary", "prompt_cache_enabled"), - "SUMMARY_MODEL_PROMPT_CACHE_ENABLED", - ), - fallback.prompt_cache_enabled, - ) - return ( - AgentModelConfig( - 
api_url=_coerce_str(s.get("api_url"), fallback.api_url), - api_key=_coerce_str(s.get("api_key"), fallback.api_key), - model_name=_coerce_str(s.get("model_name"), fallback.model_name), - max_tokens=_coerce_int(s.get("max_tokens"), fallback.max_tokens), - queue_interval_seconds=queue_interval_seconds, - api_mode=api_mode, - thinking_enabled=_coerce_bool( - s.get("thinking_enabled"), fallback.thinking_enabled - ), - thinking_budget_tokens=_coerce_int( - s.get("thinking_budget_tokens"), fallback.thinking_budget_tokens - ), - thinking_include_budget=thinking_include_budget, - reasoning_effort_style=_resolve_reasoning_effort_style( - _get_value( - {"models": {"summary": s}}, - ("models", "summary", "reasoning_effort_style"), - "SUMMARY_MODEL_REASONING_EFFORT_STYLE", - ), - fallback.reasoning_effort_style, - ), - thinking_tool_call_compat=thinking_tool_call_compat, - responses_tool_choice_compat=responses_tool_choice_compat, - responses_force_stateless_replay=responses_force_stateless_replay, - prompt_cache_enabled=prompt_cache_enabled, - reasoning_enabled=_coerce_bool( - _get_value( - {"models": {"summary": s}}, - ("models", "summary", "reasoning_enabled"), - "SUMMARY_MODEL_REASONING_ENABLED", - ), - fallback.reasoning_enabled, - ), - reasoning_effort=_resolve_reasoning_effort( - _get_value( - {"models": {"summary": s}}, - ("models", "summary", "reasoning_effort"), - "SUMMARY_MODEL_REASONING_EFFORT", - ), - fallback.reasoning_effort, - ), - request_params=merge_request_params( - fallback.request_params, - s.get("request_params"), - ), - ), - True, - ) - - @staticmethod - def _parse_cognitive_config(data: dict[str, Any]) -> CognitiveConfig: - cog = data.get("cognitive", {}) - vs = cog.get("vector_store", {}) if isinstance(cog, dict) else {} - q = cog.get("query", {}) if isinstance(cog, dict) else {} - hist = cog.get("historian", {}) if isinstance(cog, dict) else {} - prof = cog.get("profile", {}) if isinstance(cog, dict) else {} - que = cog.get("queue", {}) if 
isinstance(cog, dict) else {} - return CognitiveConfig( - enabled=_coerce_bool( - cog.get("enabled") if isinstance(cog, dict) else None, True - ), - bot_name=_coerce_str( - cog.get("bot_name") if isinstance(cog, dict) else None, - "Undefined", - ), - vector_store_path=_coerce_str( - vs.get("path") if isinstance(vs, dict) else None, - "data/cognitive/chromadb", - ), - queue_path=_coerce_str( - que.get("path") if isinstance(que, dict) else None, - "data/cognitive/queues", - ), - profiles_path=_coerce_str( - prof.get("path") if isinstance(prof, dict) else None, - "data/cognitive/profiles", - ), - auto_top_k=_coerce_int( - q.get("auto_top_k") if isinstance(q, dict) else None, 3 - ), - auto_scope_candidate_multiplier=_coerce_int( - q.get("auto_scope_candidate_multiplier") - if isinstance(q, dict) - else None, - 2, - ), - auto_current_group_boost=_coerce_float( - q.get("auto_current_group_boost") if isinstance(q, dict) else None, - 1.15, - ), - auto_current_private_boost=_coerce_float( - q.get("auto_current_private_boost") if isinstance(q, dict) else None, - 1.25, - ), - enable_rerank=_coerce_bool( - q.get("enable_rerank") if isinstance(q, dict) else None, True - ), - recent_end_summaries_inject_k=_coerce_int( - q.get("recent_end_summaries_inject_k") if isinstance(q, dict) else None, - 30, - ), - time_decay_enabled=_coerce_bool( - q.get("time_decay_enabled") if isinstance(q, dict) else None, True - ), - time_decay_half_life_days_auto=_coerce_float( - q.get("time_decay_half_life_days_auto") - if isinstance(q, dict) - else None, - 14.0, - ), - time_decay_half_life_days_tool=_coerce_float( - q.get("time_decay_half_life_days_tool") - if isinstance(q, dict) - else None, - 60.0, - ), - time_decay_boost=_coerce_float( - q.get("time_decay_boost") if isinstance(q, dict) else None, 0.2 - ), - time_decay_min_similarity=_coerce_float( - q.get("time_decay_min_similarity") if isinstance(q, dict) else None, - 0.35, - ), - tool_default_top_k=_coerce_int( - q.get("tool_default_top_k") if 
isinstance(q, dict) else None, 12 - ), - profile_top_k=_coerce_int( - q.get("profile_top_k") if isinstance(q, dict) else None, 8 - ), - rerank_candidate_multiplier=_coerce_int( - q.get("rerank_candidate_multiplier") if isinstance(q, dict) else None, 3 - ), - rewrite_max_retry=_coerce_int( - hist.get("rewrite_max_retry") if isinstance(hist, dict) else None, 2 - ), - historian_recent_messages_inject_k=_coerce_int( - hist.get("recent_messages_inject_k") - if isinstance(hist, dict) - else None, - 12, - ), - historian_recent_message_line_max_len=_coerce_int( - hist.get("recent_message_line_max_len") - if isinstance(hist, dict) - else None, - 240, - ), - historian_source_message_max_len=_coerce_int( - hist.get("source_message_max_len") if isinstance(hist, dict) else None, - 800, - ), - poll_interval_seconds=_coerce_float( - hist.get("poll_interval_seconds") if isinstance(hist, dict) else None, - 1.0, - ), - stale_job_timeout_seconds=_coerce_float( - hist.get("stale_job_timeout_seconds") - if isinstance(hist, dict) - else None, - 300.0, - ), - profile_revision_keep=_coerce_int( - prof.get("revision_keep") if isinstance(prof, dict) else None, 5 - ), - failed_max_age_days=_coerce_int( - que.get("failed_max_age_days") if isinstance(que, dict) else None, 30 - ), - failed_max_files=_coerce_int( - que.get("failed_max_files") if isinstance(que, dict) else None, 500 - ), - failed_cleanup_interval=_coerce_int( - que.get("failed_cleanup_interval") if isinstance(que, dict) else None, - 100, - ), - job_max_retries=_coerce_int( - que.get("job_max_retries") if isinstance(que, dict) else None, 3 - ), - ) - - @staticmethod - def _parse_memes_config(data: dict[str, Any]) -> MemeConfig: - section_raw = data.get("memes", {}) - section = section_raw if isinstance(section_raw, dict) else {} - return MemeConfig( - enabled=_coerce_bool(section.get("enabled"), True), - query_default_mode=_coerce_str(section.get("query_default_mode"), "hybrid"), - max_source_image_bytes=max( - 1, - 
_coerce_int(section.get("max_source_image_bytes"), 500 * 1024), - ), - blob_dir=_coerce_str(section.get("blob_dir"), "data/memes/blobs"), - preview_dir=_coerce_str(section.get("preview_dir"), "data/memes/previews"), - db_path=_coerce_str(section.get("db_path"), "data/memes/memes.sqlite3"), - vector_store_path=_coerce_str( - section.get("vector_store_path"), "data/memes/chromadb" - ), - queue_path=_coerce_str(section.get("queue_path"), "data/memes/queues"), - max_items=max(1, _coerce_int(section.get("max_items"), 10000)), - max_total_bytes=max( - 1, - _coerce_int(section.get("max_total_bytes"), 5 * 1024 * 1024 * 1024), - ), - allow_gif=_coerce_bool(section.get("allow_gif"), True), - auto_ingest_group=_coerce_bool(section.get("auto_ingest_group"), True), - auto_ingest_private=_coerce_bool(section.get("auto_ingest_private"), True), - keyword_top_k=max(1, _coerce_int(section.get("keyword_top_k"), 30)), - semantic_top_k=max(1, _coerce_int(section.get("semantic_top_k"), 30)), - rerank_top_k=max(1, _coerce_int(section.get("rerank_top_k"), 20)), - worker_max_concurrency=max( - 1, _coerce_int(section.get("worker_max_concurrency"), 4) - ), - ) - - @staticmethod - def _parse_api_config(data: dict[str, Any]) -> APIConfig: - section_raw = data.get("api", {}) - section = section_raw if isinstance(section_raw, dict) else {} - - enabled = _coerce_bool(section.get("enabled"), True) - host = _coerce_str(section.get("host"), DEFAULT_API_HOST) - port = _coerce_int(section.get("port"), DEFAULT_API_PORT) - if port <= 0 or port > 65535: - port = DEFAULT_API_PORT - - auth_key = _coerce_str(section.get("auth_key"), DEFAULT_API_AUTH_KEY) - if not auth_key: - auth_key = DEFAULT_API_AUTH_KEY - - openapi_enabled = _coerce_bool(section.get("openapi_enabled"), True) - - tool_invoke_enabled = _coerce_bool(section.get("tool_invoke_enabled"), False) - tool_invoke_expose = _coerce_str( - section.get("tool_invoke_expose"), "tools+toolsets" - ) - valid_expose = {"tools", "toolsets", "tools+toolsets", 
"agents", "all"} - if tool_invoke_expose not in valid_expose: - tool_invoke_expose = "tools+toolsets" - tool_invoke_allowlist = _coerce_str_list(section.get("tool_invoke_allowlist")) - tool_invoke_denylist = _coerce_str_list(section.get("tool_invoke_denylist")) - tool_invoke_timeout = _coerce_int(section.get("tool_invoke_timeout"), 120) - if tool_invoke_timeout <= 0: - tool_invoke_timeout = 120 - tool_invoke_callback_timeout = _coerce_int( - section.get("tool_invoke_callback_timeout"), 10 - ) - if tool_invoke_callback_timeout <= 0: - tool_invoke_callback_timeout = 10 - - return APIConfig( - enabled=enabled, - host=host, - port=port, - auth_key=auth_key, - openapi_enabled=openapi_enabled, - tool_invoke_enabled=tool_invoke_enabled, - tool_invoke_expose=tool_invoke_expose, - tool_invoke_allowlist=tool_invoke_allowlist, - tool_invoke_denylist=tool_invoke_denylist, - tool_invoke_timeout=tool_invoke_timeout, - tool_invoke_callback_timeout=tool_invoke_callback_timeout, - ) - - @staticmethod - def _parse_naga_config(data: dict[str, Any]) -> NagaConfig: - section_raw = data.get("naga", {}) - section = section_raw if isinstance(section_raw, dict) else {} - - enabled = _coerce_bool(section.get("enabled"), False) - api_url = _coerce_str(section.get("api_url"), "") - api_key = _coerce_str(section.get("api_key"), "") - moderation_enabled = _coerce_bool(section.get("moderation_enabled"), True) - allowed_groups = frozenset(_coerce_int_list(section.get("allowed_groups"))) - - return NagaConfig( - enabled=enabled, - api_url=api_url, - api_key=api_key, - moderation_enabled=moderation_enabled, - allowed_groups=allowed_groups, - ) - - @staticmethod - def _parse_easter_egg_call_mode(value: Any) -> str: - """解析彩蛋提示模式。 - - 兼容旧版布尔值: - - True => agent - - False => none - """ - if isinstance(value, bool): - return "agent" if value else "none" - if isinstance(value, (int, float)): - return "agent" if bool(value) else "none" - if value is None: - return "none" - - text = 
str(value).strip().lower() - if text in {"true", "1", "yes", "on"}: - return "agent" - if text in {"false", "0", "no", "off"}: - return "none" - if text in {"none", "agent", "tools", "all", "clean"}: - return text - return "none" - def reload(self, strict: bool = False) -> dict[str, tuple[Any, Any]]: new_config = Config.load(strict=strict) return self.update_from(new_config) @@ -3346,20 +1594,3 @@ def is_superadmin(self, qq: int) -> bool: def is_admin(self, qq: int) -> bool: return qq in self.admin_qqs - - -def _update_dataclass( - old_value: Any, new_value: Any, prefix: str -) -> dict[str, tuple[Any, Any]]: - changes: dict[str, tuple[Any, Any]] = {} - if not isinstance(old_value, type(new_value)): - changes[prefix] = (old_value, new_value) - return changes - for field in fields(old_value): - name = field.name - old_attr = getattr(old_value, name) - new_attr = getattr(new_value, name) - if old_attr != new_attr: - setattr(old_value, name, new_attr) - changes[f"{prefix}.{name}"] = (old_attr, new_attr) - return changes diff --git a/src/Undefined/config/model_parsers.py b/src/Undefined/config/model_parsers.py new file mode 100644 index 0000000..e17fc85 --- /dev/null +++ b/src/Undefined/config/model_parsers.py @@ -0,0 +1,1250 @@ +"""Model configuration parsers extracted from Config class.""" + +from __future__ import annotations + +import logging +from typing import Any + +from Undefined.utils.request_params import merge_request_params + +from .admin import load_local_admins +from .coercers import ( + _coerce_bool, + _coerce_float, + _coerce_int, + _coerce_str, + _get_model_request_params, + _get_value, + _normalize_queue_interval, + _VALID_API_MODES, +) +from .models import ( + AgentModelConfig, + ChatModelConfig, + EmbeddingModelConfig, + GrokModelConfig, + ImageGenConfig, + ImageGenModelConfig, + ModelPool, + ModelPoolEntry, + RerankModelConfig, + SecurityModelConfig, + VisionModelConfig, +) +from .resolvers import ( + _resolve_api_mode, + _resolve_reasoning_effort, 
+ _resolve_reasoning_effort_style, + _resolve_responses_force_stateless_replay, + _resolve_responses_tool_choice_compat, + _resolve_thinking_compat_flags, +) + +logger = logging.getLogger(__name__) + + +def _parse_model_pool( + data: dict[str, Any], + model_section: str, + primary_config: ChatModelConfig | AgentModelConfig, +) -> ModelPool | None: + """解析模型池配置,缺省字段继承 primary_config""" + pool_data = data.get("models", {}).get(model_section, {}).get("pool") + if not isinstance(pool_data, dict): + return None + + enabled = _coerce_bool(pool_data.get("enabled"), False) + strategy = _coerce_str(pool_data.get("strategy"), "default").strip().lower() + if strategy not in ("default", "round_robin", "random"): + strategy = "default" + + raw_models = pool_data.get("models") + if not isinstance(raw_models, list): + return ModelPool(enabled=enabled, strategy=strategy) + + entries: list[ModelPoolEntry] = [] + for item in raw_models: + if not isinstance(item, dict): + continue + name = _coerce_str(item.get("model_name"), "").strip() + if not name: + continue + entries.append( + ModelPoolEntry( + api_url=_coerce_str(item.get("api_url"), primary_config.api_url), + api_key=_coerce_str(item.get("api_key"), primary_config.api_key), + model_name=name, + max_tokens=_coerce_int( + item.get("max_tokens"), primary_config.max_tokens + ), + queue_interval_seconds=_normalize_queue_interval( + _coerce_float( + item.get("queue_interval_seconds"), + primary_config.queue_interval_seconds, + ), + primary_config.queue_interval_seconds, + ), + api_mode=( + _coerce_str(item.get("api_mode"), primary_config.api_mode) + .strip() + .lower() + ) + if _coerce_str(item.get("api_mode"), primary_config.api_mode) + .strip() + .lower() + in _VALID_API_MODES + else primary_config.api_mode, + thinking_enabled=_coerce_bool( + item.get("thinking_enabled"), primary_config.thinking_enabled + ), + thinking_budget_tokens=_coerce_int( + item.get("thinking_budget_tokens"), + primary_config.thinking_budget_tokens, + ), + 
thinking_include_budget=_coerce_bool( + item.get("thinking_include_budget"), + primary_config.thinking_include_budget, + ), + reasoning_effort_style=_resolve_reasoning_effort_style( + item.get("reasoning_effort_style"), + primary_config.reasoning_effort_style, + ), + thinking_tool_call_compat=_coerce_bool( + item.get("thinking_tool_call_compat"), + primary_config.thinking_tool_call_compat, + ), + responses_tool_choice_compat=_coerce_bool( + item.get("responses_tool_choice_compat"), + primary_config.responses_tool_choice_compat, + ), + responses_force_stateless_replay=_coerce_bool( + item.get("responses_force_stateless_replay"), + primary_config.responses_force_stateless_replay, + ), + prompt_cache_enabled=_coerce_bool( + item.get("prompt_cache_enabled"), + primary_config.prompt_cache_enabled, + ), + reasoning_enabled=_coerce_bool( + item.get("reasoning_enabled"), + primary_config.reasoning_enabled, + ), + reasoning_effort=_resolve_reasoning_effort( + item.get("reasoning_effort"), + primary_config.reasoning_effort, + ), + request_params=merge_request_params( + primary_config.request_params, + item.get("request_params"), + ), + ) + ) + + return ModelPool(enabled=enabled, strategy=strategy, models=entries) + + +def _parse_embedding_model_config(data: dict[str, Any]) -> EmbeddingModelConfig: + return EmbeddingModelConfig( + api_url=_coerce_str( + _get_value( + data, ("models", "embedding", "api_url"), "EMBEDDING_MODEL_API_URL" + ), + "", + ), + api_key=_coerce_str( + _get_value( + data, ("models", "embedding", "api_key"), "EMBEDDING_MODEL_API_KEY" + ), + "", + ), + model_name=_coerce_str( + _get_value( + data, ("models", "embedding", "model_name"), "EMBEDDING_MODEL_NAME" + ), + "", + ), + queue_interval_seconds=_normalize_queue_interval( + _coerce_float( + _get_value( + data, ("models", "embedding", "queue_interval_seconds"), None + ), + 0.0, + ), + 0.0, + ), + dimensions=_coerce_int( + _get_value(data, ("models", "embedding", "dimensions"), None), 0 + ) + or None, + 
query_instruction=_coerce_str( + _get_value(data, ("models", "embedding", "query_instruction"), None), "" + ), + document_instruction=_coerce_str( + _get_value(data, ("models", "embedding", "document_instruction"), None), + "", + ), + request_params=_get_model_request_params(data, "embedding"), + ) + + +def _parse_rerank_model_config(data: dict[str, Any]) -> RerankModelConfig: + queue_interval_seconds = _normalize_queue_interval( + _coerce_float( + _get_value(data, ("models", "rerank", "queue_interval_seconds"), None), + 0.0, + ), + 0.0, + ) + return RerankModelConfig( + api_url=_coerce_str( + _get_value(data, ("models", "rerank", "api_url"), "RERANK_MODEL_API_URL"), + "", + ), + api_key=_coerce_str( + _get_value(data, ("models", "rerank", "api_key"), "RERANK_MODEL_API_KEY"), + "", + ), + model_name=_coerce_str( + _get_value(data, ("models", "rerank", "model_name"), "RERANK_MODEL_NAME"), + "", + ), + queue_interval_seconds=queue_interval_seconds, + query_instruction=_coerce_str( + _get_value(data, ("models", "rerank", "query_instruction"), None), "" + ), + request_params=_get_model_request_params(data, "rerank"), + ) + + +def _parse_chat_model_config(data: dict[str, Any]) -> ChatModelConfig: + queue_interval_seconds = _normalize_queue_interval( + _coerce_float( + _get_value( + data, + ("models", "chat", "queue_interval_seconds"), + "CHAT_MODEL_QUEUE_INTERVAL", + ), + 1.0, + ) + ) + thinking_include_budget, thinking_tool_call_compat = _resolve_thinking_compat_flags( + data=data, + model_name="chat", + include_budget_env_key="CHAT_MODEL_THINKING_INCLUDE_BUDGET", + tool_call_compat_env_key="CHAT_MODEL_THINKING_TOOL_CALL_COMPAT", + legacy_env_key="CHAT_MODEL_DEEPSEEK_NEW_COT_SUPPORT", + ) + api_mode = _resolve_api_mode(data, "chat", "CHAT_MODEL_API_MODE") + responses_tool_choice_compat = _resolve_responses_tool_choice_compat( + data, "chat", "CHAT_MODEL_RESPONSES_TOOL_CHOICE_COMPAT" + ) + responses_force_stateless_replay = _resolve_responses_force_stateless_replay( + 
data, "chat", "CHAT_MODEL_RESPONSES_FORCE_STATELESS_REPLAY" + ) + prompt_cache_enabled = _coerce_bool( + _get_value( + data, + ("models", "chat", "prompt_cache_enabled"), + "CHAT_MODEL_PROMPT_CACHE_ENABLED", + ), + True, + ) + reasoning_enabled = _coerce_bool( + _get_value( + data, + ("models", "chat", "reasoning_enabled"), + "CHAT_MODEL_REASONING_ENABLED", + ), + False, + ) + reasoning_effort = _resolve_reasoning_effort( + _get_value( + data, + ("models", "chat", "reasoning_effort"), + "CHAT_MODEL_REASONING_EFFORT", + ), + "medium", + ) + config = ChatModelConfig( + api_url=_coerce_str( + _get_value(data, ("models", "chat", "api_url"), "CHAT_MODEL_API_URL"), + "", + ), + api_key=_coerce_str( + _get_value(data, ("models", "chat", "api_key"), "CHAT_MODEL_API_KEY"), + "", + ), + model_name=_coerce_str( + _get_value(data, ("models", "chat", "model_name"), "CHAT_MODEL_NAME"), + "", + ), + max_tokens=_coerce_int( + _get_value(data, ("models", "chat", "max_tokens"), "CHAT_MODEL_MAX_TOKENS"), + 8192, + ), + queue_interval_seconds=queue_interval_seconds, + api_mode=api_mode, + thinking_enabled=_coerce_bool( + _get_value( + data, + ("models", "chat", "thinking_enabled"), + "CHAT_MODEL_THINKING_ENABLED", + ), + False, + ), + thinking_budget_tokens=_coerce_int( + _get_value( + data, + ("models", "chat", "thinking_budget_tokens"), + "CHAT_MODEL_THINKING_BUDGET_TOKENS", + ), + 20000, + ), + thinking_include_budget=thinking_include_budget, + reasoning_effort_style=_resolve_reasoning_effort_style( + _get_value( + data, + ("models", "chat", "reasoning_effort_style"), + "CHAT_MODEL_REASONING_EFFORT_STYLE", + ), + ), + thinking_tool_call_compat=thinking_tool_call_compat, + responses_tool_choice_compat=responses_tool_choice_compat, + responses_force_stateless_replay=responses_force_stateless_replay, + prompt_cache_enabled=prompt_cache_enabled, + reasoning_enabled=reasoning_enabled, + reasoning_effort=reasoning_effort, + request_params=_get_model_request_params(data, "chat"), + ) + 
config.pool = _parse_model_pool(data, "chat", config) + return config + + +def _parse_vision_model_config(data: dict[str, Any]) -> VisionModelConfig: + queue_interval_seconds = _normalize_queue_interval( + _coerce_float( + _get_value( + data, + ("models", "vision", "queue_interval_seconds"), + "VISION_MODEL_QUEUE_INTERVAL", + ), + 1.0, + ) + ) + thinking_include_budget, thinking_tool_call_compat = _resolve_thinking_compat_flags( + data=data, + model_name="vision", + include_budget_env_key="VISION_MODEL_THINKING_INCLUDE_BUDGET", + tool_call_compat_env_key="VISION_MODEL_THINKING_TOOL_CALL_COMPAT", + legacy_env_key="VISION_MODEL_DEEPSEEK_NEW_COT_SUPPORT", + ) + api_mode = _resolve_api_mode(data, "vision", "VISION_MODEL_API_MODE") + responses_tool_choice_compat = _resolve_responses_tool_choice_compat( + data, "vision", "VISION_MODEL_RESPONSES_TOOL_CHOICE_COMPAT" + ) + responses_force_stateless_replay = _resolve_responses_force_stateless_replay( + data, "vision", "VISION_MODEL_RESPONSES_FORCE_STATELESS_REPLAY" + ) + prompt_cache_enabled = _coerce_bool( + _get_value( + data, + ("models", "vision", "prompt_cache_enabled"), + "VISION_MODEL_PROMPT_CACHE_ENABLED", + ), + True, + ) + reasoning_enabled = _coerce_bool( + _get_value( + data, + ("models", "vision", "reasoning_enabled"), + "VISION_MODEL_REASONING_ENABLED", + ), + False, + ) + reasoning_effort = _resolve_reasoning_effort( + _get_value( + data, + ("models", "vision", "reasoning_effort"), + "VISION_MODEL_REASONING_EFFORT", + ), + "medium", + ) + return VisionModelConfig( + api_url=_coerce_str( + _get_value(data, ("models", "vision", "api_url"), "VISION_MODEL_API_URL"), + "", + ), + api_key=_coerce_str( + _get_value(data, ("models", "vision", "api_key"), "VISION_MODEL_API_KEY"), + "", + ), + model_name=_coerce_str( + _get_value(data, ("models", "vision", "model_name"), "VISION_MODEL_NAME"), + "", + ), + queue_interval_seconds=queue_interval_seconds, + api_mode=api_mode, + thinking_enabled=_coerce_bool( + _get_value( + 
data, + ("models", "vision", "thinking_enabled"), + "VISION_MODEL_THINKING_ENABLED", + ), + False, + ), + thinking_budget_tokens=_coerce_int( + _get_value( + data, + ("models", "vision", "thinking_budget_tokens"), + "VISION_MODEL_THINKING_BUDGET_TOKENS", + ), + 20000, + ), + thinking_include_budget=thinking_include_budget, + reasoning_effort_style=_resolve_reasoning_effort_style( + _get_value( + data, + ("models", "vision", "reasoning_effort_style"), + "VISION_MODEL_REASONING_EFFORT_STYLE", + ), + ), + thinking_tool_call_compat=thinking_tool_call_compat, + responses_tool_choice_compat=responses_tool_choice_compat, + responses_force_stateless_replay=responses_force_stateless_replay, + prompt_cache_enabled=prompt_cache_enabled, + reasoning_enabled=reasoning_enabled, + reasoning_effort=reasoning_effort, + request_params=_get_model_request_params(data, "vision"), + ) + + +def _parse_security_model_config( + data: dict[str, Any], chat_model: ChatModelConfig +) -> SecurityModelConfig: + api_url = _coerce_str( + _get_value(data, ("models", "security", "api_url"), "SECURITY_MODEL_API_URL"), + "", + ) + api_key = _coerce_str( + _get_value(data, ("models", "security", "api_key"), "SECURITY_MODEL_API_KEY"), + "", + ) + model_name = _coerce_str( + _get_value(data, ("models", "security", "model_name"), "SECURITY_MODEL_NAME"), + "", + ) + queue_interval_seconds = _coerce_float( + _get_value( + data, + ("models", "security", "queue_interval_seconds"), + "SECURITY_MODEL_QUEUE_INTERVAL", + ), + 1.0, + ) + queue_interval_seconds = _normalize_queue_interval(queue_interval_seconds) + + thinking_include_budget, thinking_tool_call_compat = _resolve_thinking_compat_flags( + data=data, + model_name="security", + include_budget_env_key="SECURITY_MODEL_THINKING_INCLUDE_BUDGET", + tool_call_compat_env_key="SECURITY_MODEL_THINKING_TOOL_CALL_COMPAT", + legacy_env_key="SECURITY_MODEL_DEEPSEEK_NEW_COT_SUPPORT", + ) + api_mode = _resolve_api_mode(data, "security", "SECURITY_MODEL_API_MODE") + 
responses_tool_choice_compat = _resolve_responses_tool_choice_compat( + data, "security", "SECURITY_MODEL_RESPONSES_TOOL_CHOICE_COMPAT" + ) + responses_force_stateless_replay = _resolve_responses_force_stateless_replay( + data, "security", "SECURITY_MODEL_RESPONSES_FORCE_STATELESS_REPLAY" + ) + prompt_cache_enabled = _coerce_bool( + _get_value( + data, + ("models", "security", "prompt_cache_enabled"), + "SECURITY_MODEL_PROMPT_CACHE_ENABLED", + ), + True, + ) + reasoning_enabled = _coerce_bool( + _get_value( + data, + ("models", "security", "reasoning_enabled"), + "SECURITY_MODEL_REASONING_ENABLED", + ), + False, + ) + reasoning_effort = _resolve_reasoning_effort( + _get_value( + data, + ("models", "security", "reasoning_effort"), + "SECURITY_MODEL_REASONING_EFFORT", + ), + "medium", + ) + + if api_url and api_key and model_name: + return SecurityModelConfig( + api_url=api_url, + api_key=api_key, + model_name=model_name, + max_tokens=_coerce_int( + _get_value( + data, + ("models", "security", "max_tokens"), + "SECURITY_MODEL_MAX_TOKENS", + ), + 100, + ), + queue_interval_seconds=queue_interval_seconds, + api_mode=api_mode, + thinking_enabled=_coerce_bool( + _get_value( + data, + ("models", "security", "thinking_enabled"), + "SECURITY_MODEL_THINKING_ENABLED", + ), + False, + ), + thinking_budget_tokens=_coerce_int( + _get_value( + data, + ("models", "security", "thinking_budget_tokens"), + "SECURITY_MODEL_THINKING_BUDGET_TOKENS", + ), + 0, + ), + thinking_include_budget=thinking_include_budget, + reasoning_effort_style=_resolve_reasoning_effort_style( + _get_value( + data, + ("models", "security", "reasoning_effort_style"), + "SECURITY_MODEL_REASONING_EFFORT_STYLE", + ), + ), + thinking_tool_call_compat=thinking_tool_call_compat, + responses_tool_choice_compat=responses_tool_choice_compat, + responses_force_stateless_replay=responses_force_stateless_replay, + prompt_cache_enabled=prompt_cache_enabled, + reasoning_enabled=reasoning_enabled, + 
reasoning_effort=reasoning_effort, + request_params=_get_model_request_params(data, "security"), + ) + + logger.warning("未配置安全模型,将使用对话模型作为后备") + return SecurityModelConfig( + api_url=chat_model.api_url, + api_key=chat_model.api_key, + model_name=chat_model.model_name, + max_tokens=chat_model.max_tokens, + queue_interval_seconds=chat_model.queue_interval_seconds, + api_mode=chat_model.api_mode, + thinking_enabled=False, + thinking_budget_tokens=0, + thinking_include_budget=True, + reasoning_effort_style="openai", + thinking_tool_call_compat=chat_model.thinking_tool_call_compat, + responses_tool_choice_compat=chat_model.responses_tool_choice_compat, + responses_force_stateless_replay=chat_model.responses_force_stateless_replay, + prompt_cache_enabled=chat_model.prompt_cache_enabled, + reasoning_enabled=chat_model.reasoning_enabled, + reasoning_effort=chat_model.reasoning_effort, + request_params=merge_request_params(chat_model.request_params), + ) + + +def _parse_naga_model_config( + data: dict[str, Any], security_model: SecurityModelConfig +) -> SecurityModelConfig: + api_url = _coerce_str( + _get_value(data, ("models", "naga", "api_url"), "NAGA_MODEL_API_URL"), + "", + ) + api_key = _coerce_str( + _get_value(data, ("models", "naga", "api_key"), "NAGA_MODEL_API_KEY"), + "", + ) + model_name = _coerce_str( + _get_value(data, ("models", "naga", "model_name"), "NAGA_MODEL_NAME"), + "", + ) + queue_interval_seconds = _coerce_float( + _get_value( + data, + ("models", "naga", "queue_interval_seconds"), + "NAGA_MODEL_QUEUE_INTERVAL", + ), + security_model.queue_interval_seconds, + ) + queue_interval_seconds = _normalize_queue_interval(queue_interval_seconds) + + thinking_include_budget, thinking_tool_call_compat = _resolve_thinking_compat_flags( + data=data, + model_name="naga", + include_budget_env_key="NAGA_MODEL_THINKING_INCLUDE_BUDGET", + tool_call_compat_env_key="NAGA_MODEL_THINKING_TOOL_CALL_COMPAT", + legacy_env_key="NAGA_MODEL_DEEPSEEK_NEW_COT_SUPPORT", + ) + 
api_mode = _resolve_api_mode(data, "naga", "NAGA_MODEL_API_MODE") + responses_tool_choice_compat = _resolve_responses_tool_choice_compat( + data, "naga", "NAGA_MODEL_RESPONSES_TOOL_CHOICE_COMPAT" + ) + responses_force_stateless_replay = _resolve_responses_force_stateless_replay( + data, "naga", "NAGA_MODEL_RESPONSES_FORCE_STATELESS_REPLAY" + ) + prompt_cache_enabled = _coerce_bool( + _get_value( + data, + ("models", "naga", "prompt_cache_enabled"), + "NAGA_MODEL_PROMPT_CACHE_ENABLED", + ), + getattr(security_model, "prompt_cache_enabled", True), + ) + reasoning_enabled = _coerce_bool( + _get_value( + data, + ("models", "naga", "reasoning_enabled"), + "NAGA_MODEL_REASONING_ENABLED", + ), + getattr(security_model, "reasoning_enabled", False), + ) + reasoning_effort = _resolve_reasoning_effort( + _get_value( + data, + ("models", "naga", "reasoning_effort"), + "NAGA_MODEL_REASONING_EFFORT", + ), + getattr(security_model, "reasoning_effort", "medium"), + ) + + if api_url and api_key and model_name: + return SecurityModelConfig( + api_url=api_url, + api_key=api_key, + model_name=model_name, + max_tokens=_coerce_int( + _get_value( + data, + ("models", "naga", "max_tokens"), + "NAGA_MODEL_MAX_TOKENS", + ), + 160, + ), + queue_interval_seconds=queue_interval_seconds, + api_mode=api_mode, + thinking_enabled=_coerce_bool( + _get_value( + data, + ("models", "naga", "thinking_enabled"), + "NAGA_MODEL_THINKING_ENABLED", + ), + False, + ), + thinking_budget_tokens=_coerce_int( + _get_value( + data, + ("models", "naga", "thinking_budget_tokens"), + "NAGA_MODEL_THINKING_BUDGET_TOKENS", + ), + 0, + ), + thinking_include_budget=thinking_include_budget, + reasoning_effort_style=_resolve_reasoning_effort_style( + _get_value( + data, + ("models", "naga", "reasoning_effort_style"), + "NAGA_MODEL_REASONING_EFFORT_STYLE", + ), + ), + thinking_tool_call_compat=thinking_tool_call_compat, + responses_tool_choice_compat=responses_tool_choice_compat, + 
responses_force_stateless_replay=responses_force_stateless_replay, + prompt_cache_enabled=prompt_cache_enabled, + reasoning_enabled=reasoning_enabled, + reasoning_effort=reasoning_effort, + request_params=_get_model_request_params(data, "naga"), + ) + + logger.info( + "未配置 Naga 审核模型,将使用已解析的安全模型配置作为后备(安全模型本身可能已回退)" + ) + return SecurityModelConfig( + api_url=security_model.api_url, + api_key=security_model.api_key, + model_name=security_model.model_name, + max_tokens=security_model.max_tokens, + queue_interval_seconds=security_model.queue_interval_seconds, + api_mode=security_model.api_mode, + thinking_enabled=security_model.thinking_enabled, + thinking_budget_tokens=security_model.thinking_budget_tokens, + thinking_include_budget=security_model.thinking_include_budget, + reasoning_effort_style=security_model.reasoning_effort_style, + thinking_tool_call_compat=security_model.thinking_tool_call_compat, + responses_tool_choice_compat=security_model.responses_tool_choice_compat, + responses_force_stateless_replay=security_model.responses_force_stateless_replay, + prompt_cache_enabled=security_model.prompt_cache_enabled, + reasoning_enabled=security_model.reasoning_enabled, + reasoning_effort=security_model.reasoning_effort, + request_params=merge_request_params(security_model.request_params), + ) + + +def _parse_agent_model_config(data: dict[str, Any]) -> AgentModelConfig: + queue_interval_seconds = _normalize_queue_interval( + _coerce_float( + _get_value( + data, + ("models", "agent", "queue_interval_seconds"), + "AGENT_MODEL_QUEUE_INTERVAL", + ), + 1.0, + ) + ) + thinking_include_budget, thinking_tool_call_compat = _resolve_thinking_compat_flags( + data=data, + model_name="agent", + include_budget_env_key="AGENT_MODEL_THINKING_INCLUDE_BUDGET", + tool_call_compat_env_key="AGENT_MODEL_THINKING_TOOL_CALL_COMPAT", + legacy_env_key="AGENT_MODEL_DEEPSEEK_NEW_COT_SUPPORT", + ) + api_mode = _resolve_api_mode(data, "agent", "AGENT_MODEL_API_MODE") + 
responses_tool_choice_compat = _resolve_responses_tool_choice_compat( + data, "agent", "AGENT_MODEL_RESPONSES_TOOL_CHOICE_COMPAT" + ) + responses_force_stateless_replay = _resolve_responses_force_stateless_replay( + data, "agent", "AGENT_MODEL_RESPONSES_FORCE_STATELESS_REPLAY" + ) + prompt_cache_enabled = _coerce_bool( + _get_value( + data, + ("models", "agent", "prompt_cache_enabled"), + "AGENT_MODEL_PROMPT_CACHE_ENABLED", + ), + True, + ) + reasoning_enabled = _coerce_bool( + _get_value( + data, + ("models", "agent", "reasoning_enabled"), + "AGENT_MODEL_REASONING_ENABLED", + ), + False, + ) + reasoning_effort = _resolve_reasoning_effort( + _get_value( + data, + ("models", "agent", "reasoning_effort"), + "AGENT_MODEL_REASONING_EFFORT", + ), + "medium", + ) + config = AgentModelConfig( + api_url=_coerce_str( + _get_value(data, ("models", "agent", "api_url"), "AGENT_MODEL_API_URL"), + "", + ), + api_key=_coerce_str( + _get_value(data, ("models", "agent", "api_key"), "AGENT_MODEL_API_KEY"), + "", + ), + model_name=_coerce_str( + _get_value(data, ("models", "agent", "model_name"), "AGENT_MODEL_NAME"), + "", + ), + max_tokens=_coerce_int( + _get_value( + data, ("models", "agent", "max_tokens"), "AGENT_MODEL_MAX_TOKENS" + ), + 4096, + ), + queue_interval_seconds=queue_interval_seconds, + api_mode=api_mode, + thinking_enabled=_coerce_bool( + _get_value( + data, + ("models", "agent", "thinking_enabled"), + "AGENT_MODEL_THINKING_ENABLED", + ), + False, + ), + thinking_budget_tokens=_coerce_int( + _get_value( + data, + ("models", "agent", "thinking_budget_tokens"), + "AGENT_MODEL_THINKING_BUDGET_TOKENS", + ), + 0, + ), + thinking_include_budget=thinking_include_budget, + reasoning_effort_style=_resolve_reasoning_effort_style( + _get_value( + data, + ("models", "agent", "reasoning_effort_style"), + "AGENT_MODEL_REASONING_EFFORT_STYLE", + ), + ), + thinking_tool_call_compat=thinking_tool_call_compat, + responses_tool_choice_compat=responses_tool_choice_compat, + 
responses_force_stateless_replay=responses_force_stateless_replay, + prompt_cache_enabled=prompt_cache_enabled, + reasoning_enabled=reasoning_enabled, + reasoning_effort=reasoning_effort, + request_params=_get_model_request_params(data, "agent"), + ) + config.pool = _parse_model_pool(data, "agent", config) + return config + + +def _parse_grok_model_config(data: dict[str, Any]) -> GrokModelConfig: + queue_interval_seconds = _normalize_queue_interval( + _coerce_float( + _get_value( + data, + ("models", "grok", "queue_interval_seconds"), + "GROK_MODEL_QUEUE_INTERVAL", + ), + 1.0, + ) + ) + return GrokModelConfig( + api_url=_coerce_str( + _get_value(data, ("models", "grok", "api_url"), "GROK_MODEL_API_URL"), + "", + ), + api_key=_coerce_str( + _get_value(data, ("models", "grok", "api_key"), "GROK_MODEL_API_KEY"), + "", + ), + model_name=_coerce_str( + _get_value(data, ("models", "grok", "model_name"), "GROK_MODEL_NAME"), + "", + ), + max_tokens=_coerce_int( + _get_value(data, ("models", "grok", "max_tokens"), "GROK_MODEL_MAX_TOKENS"), + 8192, + ), + queue_interval_seconds=queue_interval_seconds, + thinking_enabled=_coerce_bool( + _get_value( + data, + ("models", "grok", "thinking_enabled"), + "GROK_MODEL_THINKING_ENABLED", + ), + False, + ), + thinking_budget_tokens=_coerce_int( + _get_value( + data, + ("models", "grok", "thinking_budget_tokens"), + "GROK_MODEL_THINKING_BUDGET_TOKENS", + ), + 20000, + ), + thinking_include_budget=_coerce_bool( + _get_value( + data, + ("models", "grok", "thinking_include_budget"), + "GROK_MODEL_THINKING_INCLUDE_BUDGET", + ), + True, + ), + reasoning_effort_style=_resolve_reasoning_effort_style( + _get_value( + data, + ("models", "grok", "reasoning_effort_style"), + "GROK_MODEL_REASONING_EFFORT_STYLE", + ), + ), + prompt_cache_enabled=_coerce_bool( + _get_value( + data, + ("models", "grok", "prompt_cache_enabled"), + "GROK_MODEL_PROMPT_CACHE_ENABLED", + ), + True, + ), + reasoning_enabled=_coerce_bool( + _get_value( + data, + ("models", 
"grok", "reasoning_enabled"), + "GROK_MODEL_REASONING_ENABLED", + ), + False, + ), + reasoning_effort=_resolve_reasoning_effort( + _get_value( + data, + ("models", "grok", "reasoning_effort"), + "GROK_MODEL_REASONING_EFFORT", + ), + "medium", + ), + request_params=_get_model_request_params(data, "grok"), + ) + + +def _parse_image_gen_model_config(data: dict[str, Any]) -> ImageGenModelConfig: + """解析 [models.image_gen] 生图模型配置""" + return ImageGenModelConfig( + api_url=_coerce_str( + _get_value( + data, ("models", "image_gen", "api_url"), "IMAGE_GEN_MODEL_API_URL" + ), + "", + ), + api_key=_coerce_str( + _get_value( + data, ("models", "image_gen", "api_key"), "IMAGE_GEN_MODEL_API_KEY" + ), + "", + ), + model_name=_coerce_str( + _get_value( + data, ("models", "image_gen", "model_name"), "IMAGE_GEN_MODEL_NAME" + ), + "", + ), + request_params=_get_model_request_params(data, "image_gen"), + ) + + +def _parse_image_edit_model_config(data: dict[str, Any]) -> ImageGenModelConfig: + """解析 [models.image_edit] 参考图生图模型配置""" + return ImageGenModelConfig( + api_url=_coerce_str( + _get_value( + data, + ("models", "image_edit", "api_url"), + "IMAGE_EDIT_MODEL_API_URL", + ), + "", + ), + api_key=_coerce_str( + _get_value( + data, + ("models", "image_edit", "api_key"), + "IMAGE_EDIT_MODEL_API_KEY", + ), + "", + ), + model_name=_coerce_str( + _get_value( + data, + ("models", "image_edit", "model_name"), + "IMAGE_EDIT_MODEL_NAME", + ), + "", + ), + request_params=_get_model_request_params(data, "image_edit"), + ) + + +def _parse_image_gen_config(data: dict[str, Any]) -> ImageGenConfig: + """解析 [image_gen] 生图工具配置""" + return ImageGenConfig( + provider=_coerce_str( + _get_value(data, ("image_gen", "provider"), "IMAGE_GEN_PROVIDER"), + "xingzhige", + ), + xingzhige_size=_coerce_str( + _get_value(data, ("image_gen", "xingzhige_size"), None), "1:1" + ), + openai_size=_coerce_str( + _get_value(data, ("image_gen", "openai_size"), None), "" + ), + openai_quality=_coerce_str( + 
_get_value(data, ("image_gen", "openai_quality"), None), "" + ), + openai_style=_coerce_str( + _get_value(data, ("image_gen", "openai_style"), None), "" + ), + openai_timeout=_coerce_float( + _get_value(data, ("image_gen", "openai_timeout"), None), 120.0 + ), + ) + + +def _merge_admins(superadmin_qq: int, admin_qqs: list[int]) -> tuple[int, list[int]]: + local_admins = load_local_admins() + all_admins = list(set(admin_qqs + local_admins)) + if superadmin_qq and superadmin_qq not in all_admins: + all_admins.append(superadmin_qq) + return superadmin_qq, all_admins + + +def _verify_required_fields( + bot_qq: int, + superadmin_qq: int, + onebot_ws_url: str, + chat_model: ChatModelConfig, + vision_model: VisionModelConfig, + agent_model: AgentModelConfig, + knowledge_enabled: bool, + embedding_model: EmbeddingModelConfig, +) -> None: + missing: list[str] = [] + if bot_qq <= 0: + missing.append("core.bot_qq") + if superadmin_qq <= 0: + missing.append("core.superadmin_qq") + if not onebot_ws_url: + missing.append("onebot.ws_url") + if not chat_model.api_url: + missing.append("models.chat.api_url") + if not chat_model.api_key: + missing.append("models.chat.api_key") + if not chat_model.model_name: + missing.append("models.chat.model_name") + if not vision_model.api_url: + missing.append("models.vision.api_url") + if not vision_model.api_key: + missing.append("models.vision.api_key") + if not vision_model.model_name: + missing.append("models.vision.model_name") + if not agent_model.api_url: + missing.append("models.agent.api_url") + if not agent_model.api_key: + missing.append("models.agent.api_key") + if not agent_model.model_name: + missing.append("models.agent.model_name") + if knowledge_enabled: + if not embedding_model.api_url: + missing.append("models.embedding.api_url") + if not embedding_model.model_name: + missing.append("models.embedding.model_name") + if missing: + raise ValueError(f"缺少必需配置: {', '.join(missing)}") + + +def _log_debug_info( + chat_model: 
ChatModelConfig, + vision_model: VisionModelConfig, + security_model: SecurityModelConfig, + naga_model: SecurityModelConfig, + agent_model: AgentModelConfig, + summary_model: AgentModelConfig, + grok_model: GrokModelConfig, +) -> None: + configs: list[ + tuple[ + str, + ChatModelConfig + | VisionModelConfig + | SecurityModelConfig + | AgentModelConfig + | GrokModelConfig, + ] + ] = [ + ("chat", chat_model), + ("vision", vision_model), + ("security", security_model), + ("naga", naga_model), + ("agent", agent_model), + ("summary", summary_model), + ("grok", grok_model), + ] + for name, cfg in configs: + logger.debug( + "[配置] %s_model=%s api_url=%s api_key_set=%s api_mode=%s thinking=%s reasoning=%s/%s cot_compat=%s responses_tool_choice_compat=%s responses_force_stateless_replay=%s", + name, + cfg.model_name, + cfg.api_url, + bool(cfg.api_key), + getattr(cfg, "api_mode", "chat_completions"), + cfg.thinking_enabled, + getattr(cfg, "reasoning_enabled", False), + getattr(cfg, "reasoning_effort", "medium"), + getattr(cfg, "thinking_tool_call_compat", False), + getattr(cfg, "responses_tool_choice_compat", False), + getattr(cfg, "responses_force_stateless_replay", False), + ) + + +def _parse_historian_model_config( + data: dict[str, Any], fallback: AgentModelConfig +) -> AgentModelConfig: + h = data.get("models", {}).get("historian", {}) + if not isinstance(h, dict) or not h: + return fallback + queue_interval_seconds = _coerce_float( + h.get("queue_interval_seconds"), fallback.queue_interval_seconds + ) + queue_interval_seconds = _normalize_queue_interval( + queue_interval_seconds, fallback.queue_interval_seconds + ) + thinking_include_budget, thinking_tool_call_compat = _resolve_thinking_compat_flags( + data={"models": {"historian": h}}, + model_name="historian", + include_budget_env_key="HISTORIAN_MODEL_THINKING_INCLUDE_BUDGET", + tool_call_compat_env_key="HISTORIAN_MODEL_THINKING_TOOL_CALL_COMPAT", + legacy_env_key="HISTORIAN_MODEL_DEEPSEEK_NEW_COT_SUPPORT", + ) + 
api_mode = _resolve_api_mode( + {"models": {"historian": h}}, + "historian", + "HISTORIAN_MODEL_API_MODE", + fallback.api_mode, + ) + responses_tool_choice_compat = _resolve_responses_tool_choice_compat( + {"models": {"historian": h}}, + "historian", + "HISTORIAN_MODEL_RESPONSES_TOOL_CHOICE_COMPAT", + fallback.responses_tool_choice_compat, + ) + responses_force_stateless_replay = _resolve_responses_force_stateless_replay( + {"models": {"historian": h}}, + "historian", + "HISTORIAN_MODEL_RESPONSES_FORCE_STATELESS_REPLAY", + fallback.responses_force_stateless_replay, + ) + prompt_cache_enabled = _coerce_bool( + _get_value( + {"models": {"historian": h}}, + ("models", "historian", "prompt_cache_enabled"), + "HISTORIAN_MODEL_PROMPT_CACHE_ENABLED", + ), + fallback.prompt_cache_enabled, + ) + return AgentModelConfig( + api_url=_coerce_str(h.get("api_url"), fallback.api_url), + api_key=_coerce_str(h.get("api_key"), fallback.api_key), + model_name=_coerce_str(h.get("model_name"), fallback.model_name), + max_tokens=_coerce_int(h.get("max_tokens"), fallback.max_tokens), + queue_interval_seconds=queue_interval_seconds, + api_mode=api_mode, + thinking_enabled=_coerce_bool( + h.get("thinking_enabled"), fallback.thinking_enabled + ), + thinking_budget_tokens=_coerce_int( + h.get("thinking_budget_tokens"), fallback.thinking_budget_tokens + ), + thinking_include_budget=thinking_include_budget, + reasoning_effort_style=_resolve_reasoning_effort_style( + _get_value( + {"models": {"historian": h}}, + ("models", "historian", "reasoning_effort_style"), + "HISTORIAN_MODEL_REASONING_EFFORT_STYLE", + ), + fallback.reasoning_effort_style, + ), + thinking_tool_call_compat=thinking_tool_call_compat, + responses_tool_choice_compat=responses_tool_choice_compat, + responses_force_stateless_replay=responses_force_stateless_replay, + prompt_cache_enabled=prompt_cache_enabled, + reasoning_enabled=_coerce_bool( + _get_value( + {"models": {"historian": h}}, + ("models", "historian", 
"reasoning_enabled"), + "HISTORIAN_MODEL_REASONING_ENABLED", + ), + fallback.reasoning_enabled, + ), + reasoning_effort=_resolve_reasoning_effort( + _get_value( + {"models": {"historian": h}}, + ("models", "historian", "reasoning_effort"), + "HISTORIAN_MODEL_REASONING_EFFORT", + ), + fallback.reasoning_effort, + ), + request_params=merge_request_params( + fallback.request_params, + h.get("request_params"), + ), + ) + + +def _parse_summary_model_config( + data: dict[str, Any], fallback: AgentModelConfig +) -> tuple[AgentModelConfig, bool]: + s = data.get("models", {}).get("summary", {}) + if not isinstance(s, dict) or not s: + return fallback, False + queue_interval_seconds = _coerce_float( + s.get("queue_interval_seconds"), fallback.queue_interval_seconds + ) + queue_interval_seconds = _normalize_queue_interval( + queue_interval_seconds, fallback.queue_interval_seconds + ) + thinking_include_budget, thinking_tool_call_compat = _resolve_thinking_compat_flags( + data={"models": {"summary": s}}, + model_name="summary", + include_budget_env_key="SUMMARY_MODEL_THINKING_INCLUDE_BUDGET", + tool_call_compat_env_key="SUMMARY_MODEL_THINKING_TOOL_CALL_COMPAT", + legacy_env_key="SUMMARY_MODEL_DEEPSEEK_NEW_COT_SUPPORT", + ) + api_mode = _resolve_api_mode( + {"models": {"summary": s}}, + "summary", + "SUMMARY_MODEL_API_MODE", + fallback.api_mode, + ) + responses_tool_choice_compat = _resolve_responses_tool_choice_compat( + {"models": {"summary": s}}, + "summary", + "SUMMARY_MODEL_RESPONSES_TOOL_CHOICE_COMPAT", + fallback.responses_tool_choice_compat, + ) + responses_force_stateless_replay = _resolve_responses_force_stateless_replay( + {"models": {"summary": s}}, + "summary", + "SUMMARY_MODEL_RESPONSES_FORCE_STATELESS_REPLAY", + fallback.responses_force_stateless_replay, + ) + prompt_cache_enabled = _coerce_bool( + _get_value( + {"models": {"summary": s}}, + ("models", "summary", "prompt_cache_enabled"), + "SUMMARY_MODEL_PROMPT_CACHE_ENABLED", + ), + 
fallback.prompt_cache_enabled, + ) + return ( + AgentModelConfig( + api_url=_coerce_str(s.get("api_url"), fallback.api_url), + api_key=_coerce_str(s.get("api_key"), fallback.api_key), + model_name=_coerce_str(s.get("model_name"), fallback.model_name), + max_tokens=_coerce_int(s.get("max_tokens"), fallback.max_tokens), + queue_interval_seconds=queue_interval_seconds, + api_mode=api_mode, + thinking_enabled=_coerce_bool( + s.get("thinking_enabled"), fallback.thinking_enabled + ), + thinking_budget_tokens=_coerce_int( + s.get("thinking_budget_tokens"), fallback.thinking_budget_tokens + ), + thinking_include_budget=thinking_include_budget, + reasoning_effort_style=_resolve_reasoning_effort_style( + _get_value( + {"models": {"summary": s}}, + ("models", "summary", "reasoning_effort_style"), + "SUMMARY_MODEL_REASONING_EFFORT_STYLE", + ), + fallback.reasoning_effort_style, + ), + thinking_tool_call_compat=thinking_tool_call_compat, + responses_tool_choice_compat=responses_tool_choice_compat, + responses_force_stateless_replay=responses_force_stateless_replay, + prompt_cache_enabled=prompt_cache_enabled, + reasoning_enabled=_coerce_bool( + _get_value( + {"models": {"summary": s}}, + ("models", "summary", "reasoning_enabled"), + "SUMMARY_MODEL_REASONING_ENABLED", + ), + fallback.reasoning_enabled, + ), + reasoning_effort=_resolve_reasoning_effort( + _get_value( + {"models": {"summary": s}}, + ("models", "summary", "reasoning_effort"), + "SUMMARY_MODEL_REASONING_EFFORT", + ), + fallback.reasoning_effort, + ), + request_params=merge_request_params( + fallback.request_params, + s.get("request_params"), + ), + ), + True, + ) diff --git a/src/Undefined/config/resolvers.py b/src/Undefined/config/resolvers.py new file mode 100644 index 0000000..43533c3 --- /dev/null +++ b/src/Undefined/config/resolvers.py @@ -0,0 +1,106 @@ +"""Configuration value resolution helpers.""" + +from __future__ import annotations + +from typing import Any + +from .coercers import ( + _coerce_bool, + 
_coerce_str, + _get_value, + _VALID_API_MODES, + _VALID_REASONING_EFFORT_STYLES, +) + + +def _resolve_reasoning_effort_style(value: Any, default: str = "openai") -> str: + style = _coerce_str(value, default).strip().lower() + if style not in _VALID_REASONING_EFFORT_STYLES: + return default + return style + + +def _resolve_thinking_compat_flags( + data: dict[str, Any], + model_name: str, + include_budget_env_key: str, + tool_call_compat_env_key: str, + legacy_env_key: str, +) -> tuple[bool, bool]: + """解析思维链兼容配置,并兼容旧字段 deepseek_new_cot_support。""" + include_budget_value = _get_value( + data, + ("models", model_name, "thinking_include_budget"), + include_budget_env_key, + ) + tool_call_compat_value = _get_value( + data, + ("models", model_name, "thinking_tool_call_compat"), + tool_call_compat_env_key, + ) + legacy_value = _get_value( + data, + ("models", model_name, "deepseek_new_cot_support"), + legacy_env_key, + ) + + include_budget_default = True + tool_call_compat_default = True + if legacy_value is not None: + legacy_enabled = _coerce_bool(legacy_value, False) + include_budget_default = not legacy_enabled + tool_call_compat_default = legacy_enabled + + return ( + _coerce_bool(include_budget_value, include_budget_default), + _coerce_bool(tool_call_compat_value, tool_call_compat_default), + ) + + +def _resolve_api_mode( + data: dict[str, Any], + model_name: str, + env_key: str, + default: str = "chat_completions", +) -> str: + raw_value = _get_value(data, ("models", model_name, "api_mode"), env_key) + value = _coerce_str(raw_value, default).strip().lower() + if value not in _VALID_API_MODES: + return default + return value + + +def _resolve_reasoning_effort(value: Any, default: str = "medium") -> str: + return _coerce_str(value, default).strip().lower() + + +def _resolve_responses_tool_choice_compat( + data: dict[str, Any], + model_name: str, + env_key: str, + default: bool = False, +) -> bool: + return _coerce_bool( + _get_value( + data, + ("models", model_name, 
"responses_tool_choice_compat"), + env_key, + ), + default, + ) + + +def _resolve_responses_force_stateless_replay( + data: dict[str, Any], + model_name: str, + env_key: str, + default: bool = False, +) -> bool: + return _coerce_bool( + _get_value( + data, + ("models", model_name, "responses_force_stateless_replay"), + env_key, + ), + default, + ) diff --git a/src/Undefined/config/webui_settings.py b/src/Undefined/config/webui_settings.py new file mode 100644 index 0000000..45a5392 --- /dev/null +++ b/src/Undefined/config/webui_settings.py @@ -0,0 +1,61 @@ +"""WebUI settings management.""" + +from __future__ import annotations + +from dataclasses import dataclass +from pathlib import Path +from typing import Optional + +from .coercers import _coerce_int, _coerce_str, _normalize_str, _get_value + +DEFAULT_WEBUI_URL = "127.0.0.1" +DEFAULT_WEBUI_PORT = 8787 +DEFAULT_WEBUI_PASSWORD = "changeme" + + +@dataclass +class WebUISettings: + url: str + port: int + password: str + using_default_password: bool + config_exists: bool + + @property + def display_url(self) -> str: + """用于日志和展示的格式化 URL。""" + from Undefined.config.models import format_netloc + + return f"http://{format_netloc(self.url or '0.0.0.0', self.port)}" + + +def load_webui_settings(config_path: Optional[Path] = None) -> WebUISettings: + from .loader import load_toml_data # lazy to avoid circular + + data = load_toml_data(config_path) + config_exists = bool(data) + url_value = _get_value(data, ("webui", "url"), None) + port_value = _get_value(data, ("webui", "port"), None) + password_value = _get_value(data, ("webui", "password"), None) + + url = _coerce_str(url_value, DEFAULT_WEBUI_URL) + port = _coerce_int(port_value, DEFAULT_WEBUI_PORT) + if port <= 0 or port > 65535: + port = DEFAULT_WEBUI_PORT + + password_normalized = _normalize_str(password_value) + if not password_normalized: + return WebUISettings( + url=url, + port=port, + password=DEFAULT_WEBUI_PASSWORD, + using_default_password=True, + 
config_exists=config_exists, + ) + return WebUISettings( + url=url, + port=port, + password=password_normalized, + using_default_password=False, + config_exists=config_exists, + ) diff --git a/tests/test_config_cognitive_historian_limits.py b/tests/test_config_cognitive_historian_limits.py index c7f843f..bb58f56 100644 --- a/tests/test_config_cognitive_historian_limits.py +++ b/tests/test_config_cognitive_historian_limits.py @@ -1,10 +1,10 @@ from __future__ import annotations -from Undefined.config.loader import Config +from Undefined.config.domain_parsers import _parse_cognitive_config def test_parse_cognitive_historian_reference_limits() -> None: - cfg = Config._parse_cognitive_config( + cfg = _parse_cognitive_config( { "cognitive": { "query": { @@ -32,7 +32,7 @@ def test_parse_cognitive_historian_reference_limits() -> None: def test_parse_cognitive_historian_reference_limits_defaults() -> None: - cfg = Config._parse_cognitive_config({"cognitive": {}}) + cfg = _parse_cognitive_config({"cognitive": {}}) assert cfg.historian_recent_messages_inject_k == 12 assert cfg.historian_recent_message_line_max_len == 240 From 30b9a26dd6846b9900405fdad50b9275ebe26f10 Mon Sep 17 00:00:00 2001 From: Null <1708213363@qq.com> Date: Sun, 19 Apr 2026 12:09:43 +0800 Subject: [PATCH 31/57] refactor(api): extract helpers, probes, and OpenAPI from app.py Split 3077-line monolith into focused modules: - _helpers.py: utility classes and functions (~345 lines) - _probes.py: HTTP/WS endpoint health probes (~145 lines) - _openapi.py: OpenAPI spec builder (~180 lines) app.py retains RuntimeAPIContext + RuntimeAPIServer (~2500 lines). TYPE_CHECKING guard prevents circular imports for _openapi. All 1429 tests pass, mypy strict clean. 
Co-authored-by: Claude Opus 4.6 Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- src/Undefined/api/_helpers.py | 333 +++++++++++++ src/Undefined/api/_openapi.py | 183 +++++++ src/Undefined/api/_probes.py | 147 ++++++ src/Undefined/api/app.py | 652 ++----------------------- tests/test_runtime_api_sender_proxy.py | 2 +- tests/test_runtime_api_tool_invoke.py | 2 +- tests/test_webui_management_api.py | 10 +- 7 files changed, 703 insertions(+), 626 deletions(-) create mode 100644 src/Undefined/api/_helpers.py create mode 100644 src/Undefined/api/_openapi.py create mode 100644 src/Undefined/api/_probes.py diff --git a/src/Undefined/api/_helpers.py b/src/Undefined/api/_helpers.py new file mode 100644 index 0000000..40dcfb0 --- /dev/null +++ b/src/Undefined/api/_helpers.py @@ -0,0 +1,333 @@ +"""Helper classes and utility functions for the Runtime API.""" + +from __future__ import annotations + +import asyncio +import hashlib +import json +import logging +from contextlib import suppress +from dataclasses import dataclass +from datetime import datetime +from typing import Any, Awaitable, Callable +from urllib.parse import urlsplit + +from aiohttp import web +from aiohttp.web_response import Response + +from Undefined.config import load_webui_settings +from Undefined.utils.cors import is_allowed_cors_origin, normalize_origin + +logger = logging.getLogger(__name__) + +_AUTH_HEADER = "X-Undefined-API-Key" +_VIRTUAL_USER_ID = 42 + + +class _ToolInvokeExecutionTimeoutError(asyncio.TimeoutError): + """由 Runtime API 工具调用超时包装器抛出的超时异常。""" + + +@dataclass +class _NagaRequestResult: + payload_hash: str + status: int + payload: dict[str, Any] + finished_at: float + + +class _WebUIVirtualSender: + """将工具发送行为重定向到 WebUI 会话,不触发 OneBot 实际发送。""" + + def __init__( + self, + virtual_user_id: int, + send_private_callback: Callable[[int, str], Awaitable[None]], + onebot: Any = None, + ) -> None: + self._virtual_user_id = virtual_user_id + self._send_private_callback = 
send_private_callback + # 保留 onebot 属性,兼容依赖 sender.onebot 的工具读取能力。 + self.onebot = onebot + + async def send_private_message( + self, + user_id: int, + message: str, + auto_history: bool = True, + *, + mark_sent: bool = True, + reply_to: int | None = None, + preferred_temp_group_id: int | None = None, + history_message: str | None = None, + ) -> int | None: + _ = ( + user_id, + auto_history, + mark_sent, + reply_to, + preferred_temp_group_id, + history_message, + ) + await self._send_private_callback(self._virtual_user_id, message) + return None + + async def send_group_message( + self, + group_id: int, + message: str, + auto_history: bool = True, + history_prefix: str = "", + *, + mark_sent: bool = True, + reply_to: int | None = None, + history_message: str | None = None, + ) -> int | None: + _ = ( + group_id, + auto_history, + history_prefix, + mark_sent, + reply_to, + history_message, + ) + await self._send_private_callback(self._virtual_user_id, message) + return None + + async def send_private_file( + self, + user_id: int, + file_path: str, + name: str | None = None, + auto_history: bool = True, + ) -> None: + """将文件拷贝到 WebUI 缓存并发送文件卡片消息。""" + _ = user_id, auto_history + import shutil + import uuid as _uuid + from pathlib import Path as _Path + + from Undefined.utils.paths import WEBUI_FILE_CACHE_DIR, ensure_dir + + src = _Path(file_path) + display_name = name or src.name + file_id = _uuid.uuid4().hex + dest_dir = ensure_dir(WEBUI_FILE_CACHE_DIR / file_id) + dest = dest_dir / display_name + + def _copy_and_stat() -> int: + shutil.copy2(src, dest) + return dest.stat().st_size + + try: + file_size = await asyncio.to_thread(_copy_and_stat) + except OSError: + file_size = 0 + + message = f"[CQ:file,id={file_id},name={display_name},size={file_size}]" + await self._send_private_callback(self._virtual_user_id, message) + + async def send_group_file( + self, + group_id: int, + file_path: str, + name: str | None = None, + auto_history: bool = True, + ) -> None: + 
"""群文件在虚拟会话中同样重定向为文本消息。""" + await self.send_private_file(group_id, file_path, name, auto_history) + + +def _json_error(message: str, status: int = 400) -> Response: + return web.json_response({"error": message}, status=status) + + +def _apply_cors_headers(request: web.Request, response: web.StreamResponse) -> None: + origin = normalize_origin(str(request.headers.get("Origin") or "")) + settings = load_webui_settings() + response.headers.setdefault("Vary", "Origin") + response.headers.setdefault( + "Access-Control-Allow-Methods", "GET,POST,PATCH,DELETE,OPTIONS" + ) + response.headers.setdefault( + "Access-Control-Allow-Headers", + "Authorization, Content-Type, X-Undefined-API-Key", + ) + response.headers.setdefault("Access-Control-Max-Age", "86400") + if origin and is_allowed_cors_origin( + origin, + configured_host=str(settings.url or ""), + configured_port=settings.port, + ): + response.headers.setdefault("Access-Control-Allow-Origin", origin) + response.headers.setdefault("Access-Control-Allow-Credentials", "true") + + +def _optional_query_param(request: web.Request, key: str) -> str | None: + raw = request.query.get(key) + if raw is None: + return None + text = str(raw).strip() + if not text: + return None + return text + + +def _parse_query_time(value: str | None) -> datetime | None: + text = str(value or "").strip() + if not text: + return None + candidates = [text, text.replace("Z", "+00:00")] + if "T" in text: + candidates.append(text.replace("T", " ")) + for candidate in candidates: + with suppress(ValueError): + return datetime.fromisoformat(candidate) + return None + + +def _to_bool(value: Any) -> bool: + if isinstance(value, bool): + return value + if isinstance(value, (int, float)): + return value != 0 + text = str(value or "").strip().lower() + return text in {"1", "true", "yes", "on"} + + +def _build_chat_response_payload(mode: str, outputs: list[str]) -> dict[str, Any]: + return { + "mode": mode, + "virtual_user_id": _VIRTUAL_USER_ID, + 
"permission": "superadmin", + "messages": outputs, + "reply": "\n\n".join(outputs).strip(), + } + + +def _sse_event(event: str, payload: dict[str, Any]) -> bytes: + data = json.dumps(payload, ensure_ascii=False) + return f"event: {event}\ndata: {data}\n\n".encode("utf-8") + + +def _mask_url(url: str) -> str: + """保留 scheme + host,隐藏 path 细节。""" + text = str(url or "").strip().rstrip("/") + if not text: + return "" + parsed = urlsplit(text) + host = parsed.hostname or "" + port_part = f":{parsed.port}" if parsed.port else "" + scheme = parsed.scheme or "https" + return f"{scheme}://{host}{port_part}/..." + + +def _naga_runtime_enabled(cfg: Any) -> bool: + naga_cfg = getattr(cfg, "naga", None) + return bool(getattr(cfg, "nagaagent_mode_enabled", False)) and bool( + getattr(naga_cfg, "enabled", False) + ) + + +def _naga_routes_enabled(cfg: Any, naga_store: Any) -> bool: + return _naga_runtime_enabled(cfg) and naga_store is not None + + +def _short_text_preview(text: str, limit: int = 80) -> str: + compact = " ".join(str(text or "").split()) + if len(compact) <= limit: + return compact + return compact[:limit] + "..." 
+ + +def _naga_message_digest( + *, + bind_uuid: str, + naga_id: str, + target_qq: int, + target_group: int, + mode: str, + message_format: str, + content: str, +) -> str: + raw = json.dumps( + { + "bind_uuid": bind_uuid, + "naga_id": naga_id, + "target_qq": target_qq, + "target_group": target_group, + "mode": mode, + "format": message_format, + "content": content, + }, + ensure_ascii=False, + sort_keys=True, + separators=(",", ":"), + ) + return hashlib.sha1(raw.encode("utf-8")).hexdigest()[:16] + + +def _parse_response_payload(response: Response) -> dict[str, Any]: + text = response.text or "" + if not text: + return {} + payload = json.loads(text) + return payload if isinstance(payload, dict) else {"data": payload} + + +def _registry_summary(registry: Any) -> dict[str, Any]: + """从 BaseRegistry 提取轻量摘要。""" + if registry is None: + return {"count": 0, "loaded": 0, "items": []} + items: dict[str, Any] = getattr(registry, "_items", {}) + stats: dict[str, Any] = {} + get_stats = getattr(registry, "get_stats", None) + if callable(get_stats): + stats = get_stats() + summary_items: list[dict[str, Any]] = [] + for name, item in items.items(): + st = stats.get(name) + entry: dict[str, Any] = { + "name": name, + "loaded": getattr(item, "loaded", False), + } + if st is not None: + entry["calls"] = getattr(st, "count", 0) + entry["success"] = getattr(st, "success", 0) + entry["failure"] = getattr(st, "failure", 0) + summary_items.append(entry) + return { + "count": len(items), + "loaded": sum(1 for i in items.values() if getattr(i, "loaded", False)), + "items": summary_items, + } + + +def _validate_callback_url(url: str) -> str | None: + """校验回调 URL,返回错误信息或 None 表示通过。 + + 拒绝非 HTTP(S) scheme,以及直接使用私有/回环 IP 字面量的 URL 以防止 SSRF。 + 域名形式的 URL 放行(DNS 解析阶段不适合在校验函数中做阻塞调用)。 + """ + import ipaddress + + parsed = urlsplit(url) + scheme = (parsed.scheme or "").lower() + + if scheme not in ("http", "https"): + return "callback.url must use http or https" + + hostname = parsed.hostname or 
"" + if not hostname: + return "callback.url must include a hostname" + + # 仅检查 IP 字面量(如 http://127.0.0.1/、http://[::1]/、http://10.0.0.1/) + try: + addr = ipaddress.ip_address(hostname) + except ValueError: + pass # 域名形式,放行 + else: + if addr.is_private or addr.is_loopback or addr.is_link_local: + return "callback.url must not point to a private/loopback address" + + return None diff --git a/src/Undefined/api/_openapi.py b/src/Undefined/api/_openapi.py new file mode 100644 index 0000000..6c4566c --- /dev/null +++ b/src/Undefined/api/_openapi.py @@ -0,0 +1,183 @@ +"""OpenAPI / Swagger specification builder.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +from aiohttp import web + +from Undefined import __version__ +from ._helpers import _AUTH_HEADER, _naga_routes_enabled + +if TYPE_CHECKING: + from .app import RuntimeAPIContext + + +def _build_openapi_spec(ctx: RuntimeAPIContext, request: web.Request) -> dict[str, Any]: + server_url = f"{request.scheme}://{request.host}" + cfg = ctx.config_getter() + naga_routes_enabled = _naga_routes_enabled(cfg, ctx.naga_store) + paths: dict[str, Any] = { + "/health": { + "get": { + "summary": "Health check", + "security": [], + "responses": {"200": {"description": "OK"}}, + } + }, + "/openapi.json": { + "get": { + "summary": "OpenAPI schema", + "security": [], + "responses": {"200": {"description": "Schema JSON"}}, + } + }, + "/api/v1/probes/internal": { + "get": { + "summary": "Internal runtime probes", + "description": ( + "Returns system info (version, Python, platform, uptime), " + "OneBot connection status, request queue snapshot, " + "memory count, cognitive service status, API config, " + "skill statistics (tools/agents/anthropic_skills with call counts), " + "and model configuration (names, masked URLs, thinking flags)." 
+ ), + } + }, + "/api/v1/probes/external": { + "get": { + "summary": "External dependency probes", + "description": ( + "Concurrently probes all configured model API endpoints " + "(chat, vision, security, naga, agent, embedding, rerank) " + "and OneBot WebSocket. Each result includes status, " + "model name, masked URL, HTTP status code, and latency." + ), + } + }, + "/api/v1/memory": { + "get": {"summary": "List/search manual memories"}, + "post": {"summary": "Create a manual memory"}, + }, + "/api/v1/memory/{uuid}": { + "patch": {"summary": "Update a manual memory by UUID"}, + "delete": {"summary": "Delete a manual memory by UUID"}, + }, + "/api/v1/memes": {"get": {"summary": "List/search meme library items"}}, + "/api/v1/memes/stats": {"get": {"summary": "Get meme library stats"}}, + "/api/v1/memes/{uid}": { + "get": {"summary": "Get a meme by uid"}, + "patch": {"summary": "Update a meme by uid"}, + "delete": {"summary": "Delete a meme by uid"}, + }, + "/api/v1/memes/{uid}/blob": {"get": {"summary": "Get meme blob file"}}, + "/api/v1/memes/{uid}/preview": {"get": {"summary": "Get meme preview file"}}, + "/api/v1/memes/{uid}/reanalyze": { + "post": {"summary": "Queue a meme reanalyze job"} + }, + "/api/v1/memes/{uid}/reindex": { + "post": {"summary": "Queue a meme reindex job"} + }, + "/api/v1/cognitive/events": { + "get": {"summary": "Search cognitive event memories"} + }, + "/api/v1/cognitive/profiles": {"get": {"summary": "Search cognitive profiles"}}, + "/api/v1/cognitive/profile/{entity_type}/{entity_id}": { + "get": {"summary": "Get a profile by entity type/id"} + }, + "/api/v1/chat": { + "post": { + "summary": "WebUI special private chat", + "description": ( + "POST JSON {message, stream?}. " + "When stream=true, response is SSE with keep-alive comments." 
+ ), + } + }, + "/api/v1/chat/history": { + "get": {"summary": "Get virtual private chat history for WebUI"} + }, + "/api/v1/tools": { + "get": { + "summary": "List available tools", + "description": ( + "Returns currently available tools filtered by " + "tool_invoke_expose / allowlist / denylist config. " + "Each item follows the OpenAI function calling schema." + ), + } + }, + "/api/v1/tools/invoke": { + "post": { + "summary": "Invoke a tool", + "description": ( + "Execute a specific tool by name. Supports synchronous " + "response and optional async webhook callback." + ), + } + }, + } + + if naga_routes_enabled: + paths.update( + { + "/api/v1/naga/bind/callback": { + "post": { + "summary": "Finalize a Naga bind request", + "description": ( + "Internal callback used by Naga to approve or reject " + "a bind_uuid." + ), + "security": [{"BearerAuth": []}], + } + }, + "/api/v1/naga/messages/send": { + "post": { + "summary": "Send a Naga-authenticated message", + "description": ( + "Validates bind_uuid + delivery_signature, runs " + "moderation, then delivers the message. " + "Caller may provide uuid for idempotent retry deduplication." + ), + "security": [{"BearerAuth": []}], + } + }, + "/api/v1/naga/unbind": { + "post": { + "summary": "Revoke an active Naga binding", + "description": ( + "Allows Naga to proactively revoke a binding using " + "Authorization: Bearer ." 
+ ), + "security": [{"BearerAuth": []}], + } + }, + } + ) + + return { + "openapi": "3.0.3", + "info": { + "title": "Undefined Runtime API", + "version": __version__, + "description": "API exposed by the main Undefined process.", + }, + "servers": [ + { + "url": server_url, + "description": "Runtime endpoint", + } + ], + "components": { + "securitySchemes": { + "ApiKeyAuth": { + "type": "apiKey", + "in": "header", + "name": _AUTH_HEADER, + }, + "BearerAuth": {"type": "http", "scheme": "bearer"}, + } + }, + "security": [{"ApiKeyAuth": []}], + "paths": paths, + } diff --git a/src/Undefined/api/_probes.py b/src/Undefined/api/_probes.py new file mode 100644 index 0000000..fc51bb3 --- /dev/null +++ b/src/Undefined/api/_probes.py @@ -0,0 +1,147 @@ +"""HTTP/WebSocket endpoint probe functions for health checks.""" + +from __future__ import annotations + +import asyncio +import logging +import socket +import time +from typing import Any +from urllib.parse import urlsplit + +from aiohttp import ClientSession, ClientTimeout + +from ._helpers import _mask_url + +logger = logging.getLogger(__name__) + + +async def _probe_http_endpoint( + *, + name: str, + base_url: str, + api_key: str, + model_name: str = "", + timeout_seconds: float = 5.0, +) -> dict[str, Any]: + normalized = str(base_url or "").strip().rstrip("/") + if not normalized: + return { + "name": name, + "status": "skipped", + "reason": "empty_url", + "model_name": model_name, + } + + headers: dict[str, str] = {} + if api_key: + headers["Authorization"] = f"Bearer {api_key}" + + candidates = [normalized, f"{normalized}/models"] + last_error = "" + for url in candidates: + start = time.perf_counter() + try: + timeout = ClientTimeout(total=timeout_seconds) + async with ClientSession(timeout=timeout) as session: + async with session.get(url, headers=headers) as resp: + elapsed_ms = round((time.perf_counter() - start) * 1000, 2) + return { + "name": name, + "status": "ok", + "url": _mask_url(url), + "http_status": 
resp.status, + "latency_ms": elapsed_ms, + "model_name": model_name, + } + except Exception as exc: + last_error = str(exc) + continue + + return { + "name": name, + "status": "error", + "url": _mask_url(normalized), + "error": last_error or "request_failed", + "model_name": model_name, + } + + +async def _skipped_probe( + *, name: str, reason: str, model_name: str = "" +) -> dict[str, Any]: + payload: dict[str, Any] = {"name": name, "status": "skipped", "reason": reason} + if model_name: + payload["model_name"] = model_name + return payload + + +def _build_internal_model_probe_payload(mcfg: Any) -> dict[str, Any]: + payload = { + "model_name": getattr(mcfg, "model_name", ""), + "api_url": _mask_url(getattr(mcfg, "api_url", "")), + } + if hasattr(mcfg, "api_mode"): + payload["api_mode"] = getattr(mcfg, "api_mode", "chat_completions") + if hasattr(mcfg, "thinking_enabled"): + payload["thinking_enabled"] = getattr(mcfg, "thinking_enabled", False) + if hasattr(mcfg, "thinking_tool_call_compat"): + payload["thinking_tool_call_compat"] = getattr( + mcfg, "thinking_tool_call_compat", True + ) + if hasattr(mcfg, "responses_tool_choice_compat"): + payload["responses_tool_choice_compat"] = getattr( + mcfg, "responses_tool_choice_compat", False + ) + if hasattr(mcfg, "responses_force_stateless_replay"): + payload["responses_force_stateless_replay"] = getattr( + mcfg, "responses_force_stateless_replay", False + ) + if hasattr(mcfg, "prompt_cache_enabled"): + payload["prompt_cache_enabled"] = getattr(mcfg, "prompt_cache_enabled", True) + if hasattr(mcfg, "reasoning_enabled"): + payload["reasoning_enabled"] = getattr(mcfg, "reasoning_enabled", False) + if hasattr(mcfg, "reasoning_effort"): + payload["reasoning_effort"] = getattr(mcfg, "reasoning_effort", "medium") + return payload + + +async def _probe_ws_endpoint(url: str, timeout_seconds: float = 5.0) -> dict[str, Any]: + normalized = str(url or "").strip() + if not normalized: + return {"name": "onebot_ws", "status": 
"skipped", "reason": "empty_url"} + + parsed = urlsplit(normalized) + host = parsed.hostname + if not host: + return {"name": "onebot_ws", "status": "error", "error": "invalid_url"} + + if parsed.port is not None: + port = parsed.port + elif parsed.scheme == "wss": + port = 443 + else: + port = 80 + + start = time.perf_counter() + try: + conn = asyncio.open_connection(host, port) + reader, writer = await asyncio.wait_for(conn, timeout=timeout_seconds) + writer.close() + await writer.wait_closed() + elapsed_ms = round((time.perf_counter() - start) * 1000, 2) + return { + "name": "onebot_ws", + "status": "ok", + "host": host, + "port": port, + "latency_ms": elapsed_ms, + } + except (OSError, TimeoutError, socket.gaierror, asyncio.TimeoutError) as exc: + return { + "name": "onebot_ws", + "status": "error", + "host": host, + "port": port, + "error": str(exc), + } diff --git a/src/Undefined/api/app.py b/src/Undefined/api/app.py index cbe5824..91ceaf4 100644 --- a/src/Undefined/api/app.py +++ b/src/Undefined/api/app.py @@ -1,14 +1,13 @@ +"""Runtime API server for Undefined.""" + from __future__ import annotations import asyncio from copy import deepcopy -import hashlib -import json import logging import os import platform from pathlib import Path -import socket import sys import time import uuid as _uuid @@ -17,7 +16,7 @@ from datetime import datetime from typing import Any, Awaitable, Callable from typing import cast -from urllib.parse import urlsplit + from aiohttp import ClientSession, ClientTimeout, web from aiohttp.web_response import Response @@ -28,140 +27,51 @@ register_message_attachments, render_message_with_pic_placeholders, ) -from Undefined.config import load_webui_settings from Undefined.context import RequestContext from Undefined.context_resource_registry import collect_context_resources from Undefined.render import render_html_to_image, render_markdown_to_html # noqa: F401 from Undefined.services.queue_manager import QUEUE_LANE_SUPERADMIN from 
Undefined.utils.common import message_to_segments -from Undefined.utils.cors import is_allowed_cors_origin, normalize_origin from Undefined.utils.recent_messages import get_recent_messages_prefer_local from Undefined.utils.xml import escape_xml_attr, escape_xml_text -logger = logging.getLogger(__name__) +from ._helpers import ( + _apply_cors_headers, + _build_chat_response_payload, + _json_error, + _mask_url, + _naga_message_digest, + _naga_routes_enabled, + _naga_runtime_enabled, + _NagaRequestResult, + _optional_query_param, + _parse_query_time, + _parse_response_payload, + _registry_summary, + _short_text_preview, + _sse_event, + _to_bool, + _ToolInvokeExecutionTimeoutError, + _validate_callback_url, + _WebUIVirtualSender, + _AUTH_HEADER, + _VIRTUAL_USER_ID, +) +from ._probes import ( + _build_internal_model_probe_payload, + _probe_http_endpoint, + _probe_ws_endpoint, + _skipped_probe, +) +from ._openapi import _build_openapi_spec -_VIRTUAL_USER_ID = 42 +logger = logging.getLogger(__name__) _VIRTUAL_USER_NAME = "system" -_AUTH_HEADER = "X-Undefined-API-Key" _CHAT_SSE_KEEPALIVE_SECONDS = 10.0 _PROCESS_START_TIME = time.time() _NAGA_REQUEST_UUID_TTL_SECONDS = 6 * 60 * 60 -class _ToolInvokeExecutionTimeoutError(asyncio.TimeoutError): - """由 Runtime API 工具调用超时包装器抛出的超时异常。""" - - -@dataclass -class _NagaRequestResult: - payload_hash: str - status: int - payload: dict[str, Any] - finished_at: float - - -class _WebUIVirtualSender: - """将工具发送行为重定向到 WebUI 会话,不触发 OneBot 实际发送。""" - - def __init__( - self, - virtual_user_id: int, - send_private_callback: Callable[[int, str], Awaitable[None]], - onebot: Any = None, - ) -> None: - self._virtual_user_id = virtual_user_id - self._send_private_callback = send_private_callback - # 保留 onebot 属性,兼容依赖 sender.onebot 的工具读取能力。 - self.onebot = onebot - - async def send_private_message( - self, - user_id: int, - message: str, - auto_history: bool = True, - *, - mark_sent: bool = True, - reply_to: int | None = None, - 
preferred_temp_group_id: int | None = None, - history_message: str | None = None, - ) -> int | None: - _ = ( - user_id, - auto_history, - mark_sent, - reply_to, - preferred_temp_group_id, - history_message, - ) - await self._send_private_callback(self._virtual_user_id, message) - return None - - async def send_group_message( - self, - group_id: int, - message: str, - auto_history: bool = True, - history_prefix: str = "", - *, - mark_sent: bool = True, - reply_to: int | None = None, - history_message: str | None = None, - ) -> int | None: - _ = ( - group_id, - auto_history, - history_prefix, - mark_sent, - reply_to, - history_message, - ) - await self._send_private_callback(self._virtual_user_id, message) - return None - - async def send_private_file( - self, - user_id: int, - file_path: str, - name: str | None = None, - auto_history: bool = True, - ) -> None: - """将文件拷贝到 WebUI 缓存并发送文件卡片消息。""" - _ = user_id, auto_history - import shutil - import uuid as _uuid - from pathlib import Path as _Path - - from Undefined.utils.paths import WEBUI_FILE_CACHE_DIR, ensure_dir - - src = _Path(file_path) - display_name = name or src.name - file_id = _uuid.uuid4().hex - dest_dir = ensure_dir(WEBUI_FILE_CACHE_DIR / file_id) - dest = dest_dir / display_name - - def _copy_and_stat() -> int: - shutil.copy2(src, dest) - return dest.stat().st_size - - try: - file_size = await asyncio.to_thread(_copy_and_stat) - except OSError: - file_size = 0 - - message = f"[CQ:file,id={file_id},name={display_name},size={file_size}]" - await self._send_private_callback(self._virtual_user_id, message) - - async def send_group_file( - self, - group_id: int, - file_path: str, - name: str | None = None, - auto_history: bool = True, - ) -> None: - """群文件在虚拟会话中同样重定向为文本消息。""" - await self.send_private_file(group_id, file_path, name, auto_history) - - @dataclass class RuntimeAPIContext: config_getter: Callable[[], Any] @@ -178,502 +88,6 @@ class RuntimeAPIContext: naga_store: Any = None -def 
_json_error(message: str, status: int = 400) -> Response: - return web.json_response({"error": message}, status=status) - - -def _apply_cors_headers(request: web.Request, response: web.StreamResponse) -> None: - origin = normalize_origin(str(request.headers.get("Origin") or "")) - settings = load_webui_settings() - response.headers.setdefault("Vary", "Origin") - response.headers.setdefault( - "Access-Control-Allow-Methods", "GET,POST,PATCH,DELETE,OPTIONS" - ) - response.headers.setdefault( - "Access-Control-Allow-Headers", - "Authorization, Content-Type, X-Undefined-API-Key", - ) - response.headers.setdefault("Access-Control-Max-Age", "86400") - if origin and is_allowed_cors_origin( - origin, - configured_host=str(settings.url or ""), - configured_port=settings.port, - ): - response.headers.setdefault("Access-Control-Allow-Origin", origin) - response.headers.setdefault("Access-Control-Allow-Credentials", "true") - - -def _optional_query_param(request: web.Request, key: str) -> str | None: - raw = request.query.get(key) - if raw is None: - return None - text = str(raw).strip() - if not text: - return None - return text - - -def _parse_query_time(value: str | None) -> datetime | None: - text = str(value or "").strip() - if not text: - return None - candidates = [text, text.replace("Z", "+00:00")] - if "T" in text: - candidates.append(text.replace("T", " ")) - for candidate in candidates: - with suppress(ValueError): - return datetime.fromisoformat(candidate) - return None - - -def _to_bool(value: Any) -> bool: - if isinstance(value, bool): - return value - if isinstance(value, (int, float)): - return value != 0 - text = str(value or "").strip().lower() - return text in {"1", "true", "yes", "on"} - - -def _build_chat_response_payload(mode: str, outputs: list[str]) -> dict[str, Any]: - return { - "mode": mode, - "virtual_user_id": _VIRTUAL_USER_ID, - "permission": "superadmin", - "messages": outputs, - "reply": "\n\n".join(outputs).strip(), - } - - -def 
_sse_event(event: str, payload: dict[str, Any]) -> bytes: - data = json.dumps(payload, ensure_ascii=False) - return f"event: {event}\ndata: {data}\n\n".encode("utf-8") - - -def _mask_url(url: str) -> str: - """保留 scheme + host,隐藏 path 细节。""" - text = str(url or "").strip().rstrip("/") - if not text: - return "" - parsed = urlsplit(text) - host = parsed.hostname or "" - port_part = f":{parsed.port}" if parsed.port else "" - scheme = parsed.scheme or "https" - return f"{scheme}://{host}{port_part}/..." - - -def _naga_runtime_enabled(cfg: Any) -> bool: - naga_cfg = getattr(cfg, "naga", None) - return bool(getattr(cfg, "nagaagent_mode_enabled", False)) and bool( - getattr(naga_cfg, "enabled", False) - ) - - -def _naga_routes_enabled(cfg: Any, naga_store: Any) -> bool: - return _naga_runtime_enabled(cfg) and naga_store is not None - - -def _short_text_preview(text: str, limit: int = 80) -> str: - compact = " ".join(str(text or "").split()) - if len(compact) <= limit: - return compact - return compact[:limit] + "..." 
- - -def _naga_message_digest( - *, - bind_uuid: str, - naga_id: str, - target_qq: int, - target_group: int, - mode: str, - message_format: str, - content: str, -) -> str: - raw = json.dumps( - { - "bind_uuid": bind_uuid, - "naga_id": naga_id, - "target_qq": target_qq, - "target_group": target_group, - "mode": mode, - "format": message_format, - "content": content, - }, - ensure_ascii=False, - sort_keys=True, - separators=(",", ":"), - ) - return hashlib.sha1(raw.encode("utf-8")).hexdigest()[:16] - - -def _parse_response_payload(response: Response) -> dict[str, Any]: - text = response.text or "" - if not text: - return {} - payload = json.loads(text) - return payload if isinstance(payload, dict) else {"data": payload} - - -def _registry_summary(registry: Any) -> dict[str, Any]: - """从 BaseRegistry 提取轻量摘要。""" - if registry is None: - return {"count": 0, "loaded": 0, "items": []} - items: dict[str, Any] = getattr(registry, "_items", {}) - stats: dict[str, Any] = {} - get_stats = getattr(registry, "get_stats", None) - if callable(get_stats): - stats = get_stats() - summary_items: list[dict[str, Any]] = [] - for name, item in items.items(): - st = stats.get(name) - entry: dict[str, Any] = { - "name": name, - "loaded": getattr(item, "loaded", False), - } - if st is not None: - entry["calls"] = getattr(st, "count", 0) - entry["success"] = getattr(st, "success", 0) - entry["failure"] = getattr(st, "failure", 0) - summary_items.append(entry) - return { - "count": len(items), - "loaded": sum(1 for i in items.values() if getattr(i, "loaded", False)), - "items": summary_items, - } - - -def _validate_callback_url(url: str) -> str | None: - """校验回调 URL,返回错误信息或 None 表示通过。 - - 拒绝非 HTTP(S) scheme,以及直接使用私有/回环 IP 字面量的 URL 以防止 SSRF。 - 域名形式的 URL 放行(DNS 解析阶段不适合在校验函数中做阻塞调用)。 - """ - import ipaddress - - parsed = urlsplit(url) - scheme = (parsed.scheme or "").lower() - - if scheme not in ("http", "https"): - return "callback.url must use http or https" - - hostname = parsed.hostname or 
"" - if not hostname: - return "callback.url must include a hostname" - - # 仅检查 IP 字面量(如 http://127.0.0.1/、http://[::1]/、http://10.0.0.1/) - try: - addr = ipaddress.ip_address(hostname) - except ValueError: - pass # 域名形式,放行 - else: - if addr.is_private or addr.is_loopback or addr.is_link_local: - return "callback.url must not point to a private/loopback address" - - return None - - -async def _probe_http_endpoint( - *, - name: str, - base_url: str, - api_key: str, - model_name: str = "", - timeout_seconds: float = 5.0, -) -> dict[str, Any]: - normalized = str(base_url or "").strip().rstrip("/") - if not normalized: - return { - "name": name, - "status": "skipped", - "reason": "empty_url", - "model_name": model_name, - } - - headers: dict[str, str] = {} - if api_key: - headers["Authorization"] = f"Bearer {api_key}" - - candidates = [normalized, f"{normalized}/models"] - last_error = "" - for url in candidates: - start = time.perf_counter() - try: - timeout = ClientTimeout(total=timeout_seconds) - async with ClientSession(timeout=timeout) as session: - async with session.get(url, headers=headers) as resp: - elapsed_ms = round((time.perf_counter() - start) * 1000, 2) - return { - "name": name, - "status": "ok", - "url": _mask_url(url), - "http_status": resp.status, - "latency_ms": elapsed_ms, - "model_name": model_name, - } - except Exception as exc: - last_error = str(exc) - continue - - return { - "name": name, - "status": "error", - "url": _mask_url(normalized), - "error": last_error or "request_failed", - "model_name": model_name, - } - - -async def _skipped_probe( - *, name: str, reason: str, model_name: str = "" -) -> dict[str, Any]: - payload: dict[str, Any] = {"name": name, "status": "skipped", "reason": reason} - if model_name: - payload["model_name"] = model_name - return payload - - -def _build_internal_model_probe_payload(mcfg: Any) -> dict[str, Any]: - payload = { - "model_name": getattr(mcfg, "model_name", ""), - "api_url": _mask_url(getattr(mcfg, 
"api_url", "")), - } - if hasattr(mcfg, "api_mode"): - payload["api_mode"] = getattr(mcfg, "api_mode", "chat_completions") - if hasattr(mcfg, "thinking_enabled"): - payload["thinking_enabled"] = getattr(mcfg, "thinking_enabled", False) - if hasattr(mcfg, "thinking_tool_call_compat"): - payload["thinking_tool_call_compat"] = getattr( - mcfg, "thinking_tool_call_compat", True - ) - if hasattr(mcfg, "responses_tool_choice_compat"): - payload["responses_tool_choice_compat"] = getattr( - mcfg, "responses_tool_choice_compat", False - ) - if hasattr(mcfg, "responses_force_stateless_replay"): - payload["responses_force_stateless_replay"] = getattr( - mcfg, "responses_force_stateless_replay", False - ) - if hasattr(mcfg, "prompt_cache_enabled"): - payload["prompt_cache_enabled"] = getattr(mcfg, "prompt_cache_enabled", True) - if hasattr(mcfg, "reasoning_enabled"): - payload["reasoning_enabled"] = getattr(mcfg, "reasoning_enabled", False) - if hasattr(mcfg, "reasoning_effort"): - payload["reasoning_effort"] = getattr(mcfg, "reasoning_effort", "medium") - return payload - - -async def _probe_ws_endpoint(url: str, timeout_seconds: float = 5.0) -> dict[str, Any]: - normalized = str(url or "").strip() - if not normalized: - return {"name": "onebot_ws", "status": "skipped", "reason": "empty_url"} - - parsed = urlsplit(normalized) - host = parsed.hostname - if not host: - return {"name": "onebot_ws", "status": "error", "error": "invalid_url"} - - if parsed.port is not None: - port = parsed.port - elif parsed.scheme == "wss": - port = 443 - else: - port = 80 - - start = time.perf_counter() - try: - conn = asyncio.open_connection(host, port) - reader, writer = await asyncio.wait_for(conn, timeout=timeout_seconds) - writer.close() - await writer.wait_closed() - elapsed_ms = round((time.perf_counter() - start) * 1000, 2) - return { - "name": "onebot_ws", - "status": "ok", - "host": host, - "port": port, - "latency_ms": elapsed_ms, - } - except (OSError, TimeoutError, socket.gaierror, 
asyncio.TimeoutError) as exc: - return { - "name": "onebot_ws", - "status": "error", - "host": host, - "port": port, - "error": str(exc), - } - - -def _build_openapi_spec(ctx: RuntimeAPIContext, request: web.Request) -> dict[str, Any]: - server_url = f"{request.scheme}://{request.host}" - cfg = ctx.config_getter() - naga_routes_enabled = _naga_routes_enabled(cfg, ctx.naga_store) - paths: dict[str, Any] = { - "/health": { - "get": { - "summary": "Health check", - "security": [], - "responses": {"200": {"description": "OK"}}, - } - }, - "/openapi.json": { - "get": { - "summary": "OpenAPI schema", - "security": [], - "responses": {"200": {"description": "Schema JSON"}}, - } - }, - "/api/v1/probes/internal": { - "get": { - "summary": "Internal runtime probes", - "description": ( - "Returns system info (version, Python, platform, uptime), " - "OneBot connection status, request queue snapshot, " - "memory count, cognitive service status, API config, " - "skill statistics (tools/agents/anthropic_skills with call counts), " - "and model configuration (names, masked URLs, thinking flags)." - ), - } - }, - "/api/v1/probes/external": { - "get": { - "summary": "External dependency probes", - "description": ( - "Concurrently probes all configured model API endpoints " - "(chat, vision, security, naga, agent, embedding, rerank) " - "and OneBot WebSocket. Each result includes status, " - "model name, masked URL, HTTP status code, and latency." 
- ), - } - }, - "/api/v1/memory": { - "get": {"summary": "List/search manual memories"}, - "post": {"summary": "Create a manual memory"}, - }, - "/api/v1/memory/{uuid}": { - "patch": {"summary": "Update a manual memory by UUID"}, - "delete": {"summary": "Delete a manual memory by UUID"}, - }, - "/api/v1/memes": {"get": {"summary": "List/search meme library items"}}, - "/api/v1/memes/stats": {"get": {"summary": "Get meme library stats"}}, - "/api/v1/memes/{uid}": { - "get": {"summary": "Get a meme by uid"}, - "patch": {"summary": "Update a meme by uid"}, - "delete": {"summary": "Delete a meme by uid"}, - }, - "/api/v1/memes/{uid}/blob": {"get": {"summary": "Get meme blob file"}}, - "/api/v1/memes/{uid}/preview": {"get": {"summary": "Get meme preview file"}}, - "/api/v1/memes/{uid}/reanalyze": { - "post": {"summary": "Queue a meme reanalyze job"} - }, - "/api/v1/memes/{uid}/reindex": { - "post": {"summary": "Queue a meme reindex job"} - }, - "/api/v1/cognitive/events": { - "get": {"summary": "Search cognitive event memories"} - }, - "/api/v1/cognitive/profiles": {"get": {"summary": "Search cognitive profiles"}}, - "/api/v1/cognitive/profile/{entity_type}/{entity_id}": { - "get": {"summary": "Get a profile by entity type/id"} - }, - "/api/v1/chat": { - "post": { - "summary": "WebUI special private chat", - "description": ( - "POST JSON {message, stream?}. " - "When stream=true, response is SSE with keep-alive comments." - ), - } - }, - "/api/v1/chat/history": { - "get": {"summary": "Get virtual private chat history for WebUI"} - }, - "/api/v1/tools": { - "get": { - "summary": "List available tools", - "description": ( - "Returns currently available tools filtered by " - "tool_invoke_expose / allowlist / denylist config. " - "Each item follows the OpenAI function calling schema." - ), - } - }, - "/api/v1/tools/invoke": { - "post": { - "summary": "Invoke a tool", - "description": ( - "Execute a specific tool by name. 
Supports synchronous " - "response and optional async webhook callback." - ), - } - }, - } - - if naga_routes_enabled: - paths.update( - { - "/api/v1/naga/bind/callback": { - "post": { - "summary": "Finalize a Naga bind request", - "description": ( - "Internal callback used by Naga to approve or reject " - "a bind_uuid." - ), - "security": [{"BearerAuth": []}], - } - }, - "/api/v1/naga/messages/send": { - "post": { - "summary": "Send a Naga-authenticated message", - "description": ( - "Validates bind_uuid + delivery_signature, runs " - "moderation, then delivers the message. " - "Caller may provide uuid for idempotent retry deduplication." - ), - "security": [{"BearerAuth": []}], - } - }, - "/api/v1/naga/unbind": { - "post": { - "summary": "Revoke an active Naga binding", - "description": ( - "Allows Naga to proactively revoke a binding using " - "Authorization: Bearer ." - ), - "security": [{"BearerAuth": []}], - } - }, - } - ) - - return { - "openapi": "3.0.3", - "info": { - "title": "Undefined Runtime API", - "version": __version__, - "description": "API exposed by the main Undefined process.", - }, - "servers": [ - { - "url": server_url, - "description": "Runtime endpoint", - } - ], - "components": { - "securitySchemes": { - "ApiKeyAuth": { - "type": "apiKey", - "in": "header", - "name": _AUTH_HEADER, - }, - "BearerAuth": {"type": "http", "scheme": "bearer"}, - } - }, - "security": [{"ApiKeyAuth": []}], - "paths": paths, - } - - class RuntimeAPIServer: def __init__( self, diff --git a/tests/test_runtime_api_sender_proxy.py b/tests/test_runtime_api_sender_proxy.py index be0c82c..77fc805 100644 --- a/tests/test_runtime_api_sender_proxy.py +++ b/tests/test_runtime_api_sender_proxy.py @@ -2,7 +2,7 @@ import pytest -from Undefined.api.app import _WebUIVirtualSender +from Undefined.api._helpers import _WebUIVirtualSender @pytest.mark.asyncio diff --git a/tests/test_runtime_api_tool_invoke.py b/tests/test_runtime_api_tool_invoke.py index 9db1330..16dc141 100644 --- 
a/tests/test_runtime_api_tool_invoke.py +++ b/tests/test_runtime_api_tool_invoke.py @@ -10,7 +10,7 @@ from aiohttp.web_response import Response from Undefined.api import RuntimeAPIContext, RuntimeAPIServer -from Undefined.api.app import _validate_callback_url +from Undefined.api._helpers import _validate_callback_url def _json(response: Response) -> Any: diff --git a/tests/test_webui_management_api.py b/tests/test_webui_management_api.py index 435b710..9391279 100644 --- a/tests/test_webui_management_api.py +++ b/tests/test_webui_management_api.py @@ -6,7 +6,7 @@ from aiohttp import web -from Undefined.api import app as runtime_api_app +from Undefined.api import _helpers as runtime_api_helpers from Undefined.webui import app as webui_app from Undefined.webui.app import create_app from Undefined.webui.core import SessionStore @@ -350,7 +350,7 @@ def test_webui_cors_only_allows_trusted_origins(monkeypatch: Any) -> None: def test_runtime_api_cors_only_allows_trusted_origins(monkeypatch: Any) -> None: monkeypatch.setattr( - runtime_api_app, + runtime_api_helpers, "load_webui_settings", lambda: SimpleNamespace(url="127.0.0.1", port=8787), ) @@ -359,7 +359,7 @@ def test_runtime_api_cors_only_allows_trusted_origins(monkeypatch: Any) -> None: cast(Any, DummyRequest(headers={"Origin": "tauri://localhost"})), ) trusted_response = web.Response(status=200) - runtime_api_app._apply_cors_headers(trusted_request, trusted_response) + runtime_api_helpers._apply_cors_headers(trusted_request, trusted_response) assert trusted_response.headers.get("Access-Control-Allow-Origin") == ( "tauri://localhost" ) @@ -370,7 +370,7 @@ def test_runtime_api_cors_only_allows_trusted_origins(monkeypatch: Any) -> None: cast(Any, DummyRequest(headers={"Origin": "http://localhost:1420"})), ) loopback_response = web.Response(status=200) - runtime_api_app._apply_cors_headers(loopback_request, loopback_response) + runtime_api_helpers._apply_cors_headers(loopback_request, loopback_response) assert 
loopback_response.headers.get("Access-Control-Allow-Origin") == ( "http://localhost:1420" ) @@ -380,7 +380,7 @@ def test_runtime_api_cors_only_allows_trusted_origins(monkeypatch: Any) -> None: cast(Any, DummyRequest(headers={"Origin": "https://evil.example"})), ) untrusted_response = web.Response(status=200) - runtime_api_app._apply_cors_headers(untrusted_request, untrusted_response) + runtime_api_helpers._apply_cors_headers(untrusted_request, untrusted_response) assert "Access-Control-Allow-Origin" not in untrusted_response.headers assert "Access-Control-Allow-Credentials" not in untrusted_response.headers From 72ac2d9cd97f15bbadfa62e7e05bd65d9839de78 Mon Sep 17 00:00:00 2001 From: Null <1708213363@qq.com> Date: Sun, 19 Apr 2026 12:44:21 +0800 Subject: [PATCH 32/57] fix(queue): register historian model in queue interval builder The build_model_queue_intervals() function was missing the historian model (glm-5), causing it to fall back to the default 1.0s dispatch interval instead of the configured 0s. This slowed background historian tasks unnecessarily. 
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> Co-authored-by: Claude Opus 4.6 --- src/Undefined/utils/queue_intervals.py | 4 ++++ tests/test_config_hot_reload.py | 8 ++++++++ 2 files changed, 12 insertions(+) diff --git a/src/Undefined/utils/queue_intervals.py b/src/Undefined/utils/queue_intervals.py index 3b068b8..0596d7d 100644 --- a/src/Undefined/utils/queue_intervals.py +++ b/src/Undefined/utils/queue_intervals.py @@ -24,6 +24,10 @@ def build_model_queue_intervals(config: Config) -> dict[str, float]: summary_model.queue_interval_seconds, ), (config.grok_model.model_name, config.grok_model.queue_interval_seconds), + ( + config.historian_model.model_name, + config.historian_model.queue_interval_seconds, + ), ) intervals: dict[str, float] = {} for model_name, interval in pairs: diff --git a/tests/test_config_hot_reload.py b/tests/test_config_hot_reload.py index a244162..4e0158a 100644 --- a/tests/test_config_hot_reload.py +++ b/tests/test_config_hot_reload.py @@ -62,6 +62,10 @@ def test_apply_config_updates_propagates_to_security_service() -> None: model_name="grok", queue_interval_seconds=1.0, ), + historian_model=SimpleNamespace( + model_name="historian", + queue_interval_seconds=1.0, + ), ), ) security_service = _FakeSecurityService() @@ -120,6 +124,10 @@ def test_apply_config_updates_hot_reloads_ai_request_max_retries() -> None: model_name="grok", queue_interval_seconds=1.0, ), + historian_model=SimpleNamespace( + model_name="historian", + queue_interval_seconds=1.0, + ), ), ) security_service = _FakeSecurityService() From e972c6e07d0be7a1fbac2eeb9f166c9e5caab31c Mon Sep 17 00:00:00 2001 From: Null <1708213363@qq.com> Date: Sun, 19 Apr 2026 12:57:30 +0800 Subject: [PATCH 33/57] perf(handlers): parallelize message preprocessing with asyncio.gather MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Group messages: run attachment collection, group info fetch, and history content parsing concurrently 
instead of serially. Reduces preprocessing latency from sum(A+B+C) to max(A,B,C). Private messages: run attachment collection and history content parsing concurrently. No behavioral changes — all data is still fully prepared before any feature checks (keyword, repeat, bilibili, arxiv, commands, AI). Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> Co-authored-by: Claude Opus 4.6 --- src/Undefined/handlers.py | 77 ++++++++++++++++++++++----------------- 1 file changed, 43 insertions(+), 34 deletions(-) diff --git a/src/Undefined/handlers.py b/src/Undefined/handlers.py index 4bd21a8..726fa5c 100644 --- a/src/Undefined/handlers.py +++ b/src/Undefined/handlers.py @@ -496,10 +496,19 @@ async def handle_message(self, event: dict[str, Any]) -> None: logger.warning("获取用户昵称失败: %s", exc) text = extract_text(private_message_content, self.config.bot_qq) - private_attachments = await self._collect_message_attachments( - private_message_content, - user_id=private_sender_id, - request_type="private", + # 并行执行附件收集和历史内容解析 + private_attachments, parsed_content_raw = await asyncio.gather( + self._collect_message_attachments( + private_message_content, + user_id=private_sender_id, + request_type="private", + ), + parse_message_content_for_history( + private_message_content, + self.config.bot_qq, + self.onebot.get_msg, + self.onebot.get_forward_msg, + ), ) safe_text = redact_string(text) logger.info( @@ -515,15 +524,10 @@ async def handle_message(self, event: dict[str, Any]) -> None: sender_name=resolved_private_name, ) - # 保存私聊消息到历史记录(保存处理后的内容) - # 使用新的工具函数解析内容 - parsed_content = await parse_message_content_for_history( - private_message_content, - self.config.bot_qq, - self.onebot.get_msg, - self.onebot.get_forward_msg, + # 保存私聊消息到历史记录 + parsed_content = append_attachment_text( + parsed_content_raw, private_attachments ) - parsed_content = append_attachment_text(parsed_content, private_attachments) safe_parsed = redact_string(parsed_content) logger.debug( "[历史记录] 
保存私聊: user=%s content=%s...", @@ -648,26 +652,37 @@ async def handle_message(self, event: dict[str, Any]) -> None: # 提取文本内容 text = extract_text(message_content, self.config.bot_qq) - group_attachments = await self._collect_message_attachments( - message_content, - group_id=group_id, - request_type="group", - ) safe_text = redact_string(text) logger.info( f"[群消息] group={group_id} sender={sender_id} name={sender_card or sender_nickname} " f"role={sender_role} | {safe_text[:100]}" ) - # 保存消息到历史记录 (使用处理后的内容) - # 获取群聊名 - group_name = "" - try: - group_info = await self.onebot.get_group_info(group_id) - if group_info: - group_name = group_info.get("group_name", "") - except Exception as e: - logger.warning(f"获取群聊名失败: {e}") + # 并行执行 3 个独立的异步操作:附件收集、群信息获取、历史内容解析 + async def _fetch_group_name() -> str: + try: + info = await self.onebot.get_group_info(group_id) + if info: + return str(info.get("group_name", "") or "") + except Exception as e: + logger.warning(f"获取群聊名失败: {e}") + return "" + + group_attachments, group_name, parsed_content_raw = await asyncio.gather( + self._collect_message_attachments( + message_content, + group_id=group_id, + request_type="group", + ), + _fetch_group_name(), + parse_message_content_for_history( + message_content, + self.config.bot_qq, + self.onebot.get_msg, + self.onebot.get_forward_msg, + ), + ) + resolved_group_sender_name = (sender_card or sender_nickname or "").strip() self._schedule_profile_display_name_refresh( task_name=f"profile_name_refresh_group:{group_id}:{sender_id}", @@ -677,14 +692,8 @@ async def handle_message(self, event: dict[str, Any]) -> None: group_name=str(group_name or "").strip(), ) - # 使用新的 utils - parsed_content = await parse_message_content_for_history( - message_content, - self.config.bot_qq, - self.onebot.get_msg, - self.onebot.get_forward_msg, - ) - parsed_content = append_attachment_text(parsed_content, group_attachments) + # 保存消息到历史记录 + parsed_content = append_attachment_text(parsed_content_raw, 
group_attachments) safe_parsed = redact_string(parsed_content) logger.debug( f"[历史记录] 保存群聊: group={group_id}, sender={sender_id}, content={safe_parsed[:50]}..." From a3fc737a102f6b295bda84f082949dd31379d855 Mon Sep 17 00:00:00 2001 From: Null <1708213363@qq.com> Date: Sun, 19 Apr 2026 13:55:45 +0800 Subject: [PATCH 34/57] refactor(api): split app.py into route submodules MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Extract 2491-line app.py into focused route modules under api/routes/: - health.py (23 lines) — health endpoint - system.py (228 lines) — OpenAPI + internal/external probes - memory.py (156 lines) — memory CRUD - memes.py (222 lines) — meme management (9 handlers) - cognitive.py (86 lines) — cognitive events & profiles - chat.py (359 lines) — WebUI chat with SSE streaming - tools.py (462 lines) — tool invoke with async callbacks - naga.py (897 lines) — Naga bind/send/unbind + moderation Infrastructure: - _context.py: RuntimeAPIContext dataclass (shared import root) - _naga_state.py: NagaState class (request dedup + inflight tracking) app.py reduced to 333 lines: server class + thin delegation wrappers that preserve test compatibility (no test API changes needed). Updated 4 test files to monkeypatch route-module-level symbols instead of app-module-level symbols. 
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> Co-authored-by: Claude Opus 4.6 --- src/Undefined/api/__init__.py | 3 +- src/Undefined/api/_context.py | 22 + src/Undefined/api/_naga_state.py | 115 ++ src/Undefined/api/_openapi.py | 2 +- src/Undefined/api/app.py | 2292 +------------------------ src/Undefined/api/routes/__init__.py | 1 + src/Undefined/api/routes/chat.py | 359 ++++ src/Undefined/api/routes/cognitive.py | 86 + src/Undefined/api/routes/health.py | 23 + src/Undefined/api/routes/memes.py | 222 +++ src/Undefined/api/routes/memory.py | 156 ++ src/Undefined/api/routes/naga.py | 897 ++++++++++ src/Undefined/api/routes/system.py | 228 +++ src/Undefined/api/routes/tools.py | 462 +++++ tests/test_runtime_api_chat_stream.py | 12 +- tests/test_runtime_api_naga.py | 8 +- tests/test_runtime_api_probes.py | 8 +- tests/test_runtime_api_tool_invoke.py | 4 +- 18 files changed, 2659 insertions(+), 2241 deletions(-) create mode 100644 src/Undefined/api/_context.py create mode 100644 src/Undefined/api/_naga_state.py create mode 100644 src/Undefined/api/routes/__init__.py create mode 100644 src/Undefined/api/routes/chat.py create mode 100644 src/Undefined/api/routes/cognitive.py create mode 100644 src/Undefined/api/routes/health.py create mode 100644 src/Undefined/api/routes/memes.py create mode 100644 src/Undefined/api/routes/memory.py create mode 100644 src/Undefined/api/routes/naga.py create mode 100644 src/Undefined/api/routes/system.py create mode 100644 src/Undefined/api/routes/tools.py diff --git a/src/Undefined/api/__init__.py b/src/Undefined/api/__init__.py index f71c40d..1ab331b 100644 --- a/src/Undefined/api/__init__.py +++ b/src/Undefined/api/__init__.py @@ -1,5 +1,6 @@ """Runtime API server for Undefined main process.""" -from .app import RuntimeAPIContext, RuntimeAPIServer +from ._context import RuntimeAPIContext +from .app import RuntimeAPIServer __all__ = ["RuntimeAPIContext", "RuntimeAPIServer"] diff --git 
a/src/Undefined/api/_context.py b/src/Undefined/api/_context.py new file mode 100644 index 0000000..07d8c92 --- /dev/null +++ b/src/Undefined/api/_context.py @@ -0,0 +1,22 @@ +"""Shared context types for the Runtime API.""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any, Callable + + +@dataclass +class RuntimeAPIContext: + config_getter: Callable[[], Any] + onebot: Any + ai: Any + command_dispatcher: Any + queue_manager: Any + history_manager: Any + sender: Any = None + scheduler: Any = None + cognitive_service: Any = None + cognitive_job_queue: Any = None + meme_service: Any = None + naga_store: Any = None diff --git a/src/Undefined/api/_naga_state.py b/src/Undefined/api/_naga_state.py new file mode 100644 index 0000000..26bff9f --- /dev/null +++ b/src/Undefined/api/_naga_state.py @@ -0,0 +1,115 @@ +"""Naga request deduplication and inflight tracking state.""" + +from __future__ import annotations + +import asyncio +import time +from copy import deepcopy +from typing import Any + +from ._helpers import _NagaRequestResult + +_NAGA_REQUEST_UUID_TTL_SECONDS = 6 * 60 * 60 + + +class NagaState: + """Tracks inflight Naga sends and provides request-uuid deduplication.""" + + def __init__(self) -> None: + self.send_registry_lock = asyncio.Lock() + self.send_inflight: dict[str, int] = {} + self.request_uuid_lock = asyncio.Lock() + self.request_uuid_inflight: dict[ + str, tuple[str, asyncio.Future[tuple[int, dict[str, Any]]]] + ] = {} + self.request_uuid_results: dict[str, _NagaRequestResult] = {} + + async def track_send_start(self, message_key: str) -> int: + async with self.send_registry_lock: + next_count = self.send_inflight.get(message_key, 0) + 1 + self.send_inflight[message_key] = next_count + return next_count + + async def track_send_done(self, message_key: str) -> int: + async with self.send_registry_lock: + current = self.send_inflight.get(message_key, 0) + if current <= 1: + self.send_inflight.pop(message_key, 
None) + return 0 + next_count = current - 1 + self.send_inflight[message_key] = next_count + return next_count + + def _prune_request_uuid_state_locked(self) -> None: + now = time.time() + expired = [ + request_uuid + for request_uuid, result in self.request_uuid_results.items() + if now - result.finished_at > _NAGA_REQUEST_UUID_TTL_SECONDS + ] + for request_uuid in expired: + self.request_uuid_results.pop(request_uuid, None) + + async def register_request_uuid( + self, request_uuid: str, payload_hash: str + ) -> tuple[str, Any]: + async with self.request_uuid_lock: + self._prune_request_uuid_state_locked() + + cached = self.request_uuid_results.get(request_uuid) + if cached is not None: + if cached.payload_hash != payload_hash: + return "conflict", cached.payload_hash + return "cached", (cached.status, deepcopy(cached.payload)) + + inflight = self.request_uuid_inflight.get(request_uuid) + if inflight is not None: + existing_hash, inflight_future = inflight + if existing_hash != payload_hash: + return "conflict", existing_hash + return "await", inflight_future + + owner_future: asyncio.Future[tuple[int, dict[str, Any]]] = ( + asyncio.get_running_loop().create_future() + ) + self.request_uuid_inflight[request_uuid] = ( + payload_hash, + owner_future, + ) + return "owner", owner_future + + async def finish_request_uuid( + self, + request_uuid: str, + payload_hash: str, + *, + status: int, + payload: dict[str, Any], + ) -> None: + async with self.request_uuid_lock: + inflight = self.request_uuid_inflight.pop(request_uuid, None) + future = inflight[1] if inflight is not None else None + result_payload = deepcopy(payload) + self.request_uuid_results[request_uuid] = _NagaRequestResult( + payload_hash=payload_hash, + status=status, + payload=result_payload, + finished_at=time.time(), + ) + self._prune_request_uuid_state_locked() + if future is not None and not future.done(): + future.set_result((status, deepcopy(result_payload))) + + async def fail_request_uuid( + self, + 
request_uuid: str, + payload_hash: str, + exc: BaseException, + ) -> None: + _ = payload_hash + async with self.request_uuid_lock: + inflight = self.request_uuid_inflight.pop(request_uuid, None) + future = inflight[1] if inflight is not None else None + self._prune_request_uuid_state_locked() + if future is not None and not future.done(): + future.set_exception(exc) diff --git a/src/Undefined/api/_openapi.py b/src/Undefined/api/_openapi.py index 6c4566c..1076256 100644 --- a/src/Undefined/api/_openapi.py +++ b/src/Undefined/api/_openapi.py @@ -10,7 +10,7 @@ from ._helpers import _AUTH_HEADER, _naga_routes_enabled if TYPE_CHECKING: - from .app import RuntimeAPIContext + from ._context import RuntimeAPIContext def _build_openapi_spec(ctx: RuntimeAPIContext, request: web.Request) -> dict[str, Any]: diff --git a/src/Undefined/api/app.py b/src/Undefined/api/app.py index 91ceaf4..7ad5e56 100644 --- a/src/Undefined/api/app.py +++ b/src/Undefined/api/app.py @@ -1,91 +1,32 @@ -"""Runtime API server for Undefined.""" +"""Runtime API server for Undefined. + +Route handler logic lives in ``routes/`` sub-modules. This file keeps only +the ``RuntimeAPIServer`` class (init / start / stop / middleware / routing) +and thin one-liner wrappers that delegate to the route functions so that +existing tests calling ``server._xxx_handler(request)`` keep working. 
+""" from __future__ import annotations import asyncio -from copy import deepcopy import logging -import os -import platform -from pathlib import Path -import sys -import time -import uuid as _uuid -from contextlib import suppress -from dataclasses import dataclass -from datetime import datetime from typing import Any, Awaitable, Callable -from typing import cast -from aiohttp import ClientSession, ClientTimeout, web +from aiohttp import web from aiohttp.web_response import Response -from Undefined import __version__ -from Undefined.attachments import ( - attachment_refs_to_xml, - build_attachment_scope, - register_message_attachments, - render_message_with_pic_placeholders, -) -from Undefined.context import RequestContext -from Undefined.context_resource_registry import collect_context_resources -from Undefined.render import render_html_to_image, render_markdown_to_html # noqa: F401 -from Undefined.services.queue_manager import QUEUE_LANE_SUPERADMIN -from Undefined.utils.common import message_to_segments -from Undefined.utils.recent_messages import get_recent_messages_prefer_local -from Undefined.utils.xml import escape_xml_attr, escape_xml_text - +from ._context import RuntimeAPIContext from ._helpers import ( _apply_cors_headers, - _build_chat_response_payload, _json_error, - _mask_url, - _naga_message_digest, _naga_routes_enabled, _naga_runtime_enabled, - _NagaRequestResult, - _optional_query_param, - _parse_query_time, - _parse_response_payload, - _registry_summary, - _short_text_preview, - _sse_event, - _to_bool, - _ToolInvokeExecutionTimeoutError, - _validate_callback_url, - _WebUIVirtualSender, _AUTH_HEADER, - _VIRTUAL_USER_ID, ) -from ._probes import ( - _build_internal_model_probe_payload, - _probe_http_endpoint, - _probe_ws_endpoint, - _skipped_probe, -) -from ._openapi import _build_openapi_spec +from ._naga_state import NagaState +from .routes import chat, cognitive, health, memes, memory, naga, system, tools logger = logging.getLogger(__name__) 
-_VIRTUAL_USER_NAME = "system" -_CHAT_SSE_KEEPALIVE_SECONDS = 10.0 -_PROCESS_START_TIME = time.time() -_NAGA_REQUEST_UUID_TTL_SECONDS = 6 * 60 * 60 - - -@dataclass -class RuntimeAPIContext: - config_getter: Callable[[], Any] - onebot: Any - ai: Any - command_dispatcher: Any - queue_manager: Any - history_manager: Any - sender: Any = None - scheduler: Any = None - cognitive_service: Any = None - cognitive_job_queue: Any = None - meme_service: Any = None - naga_store: Any = None class RuntimeAPIServer: @@ -101,13 +42,7 @@ def __init__( self._runner: web.AppRunner | None = None self._sites: list[web.TCPSite] = [] self._background_tasks: set[asyncio.Task[Any]] = set() - self._naga_send_registry_lock = asyncio.Lock() - self._naga_send_inflight: dict[str, int] = {} - self._naga_request_uuid_lock = asyncio.Lock() - self._naga_request_uuid_inflight: dict[ - str, tuple[str, asyncio.Future[tuple[int, dict[str, Any]]]] - ] = {} - self._naga_request_uuid_results: dict[str, _NagaRequestResult] = {} + self._naga_state = NagaState() async def start(self) -> None: from Undefined.config.models import resolve_bind_hosts @@ -123,7 +58,6 @@ async def start(self) -> None: logger.info("[RuntimeAPI] 已启动: %s", cfg.api.display_url) async def stop(self) -> None: - # 取消所有后台任务(如异步 tool invoke 回调) for task in self._background_tasks: task.cancel() if self._background_tasks: @@ -148,7 +82,6 @@ async def _auth_middleware( _apply_cors_headers(request, response) return response if request.path.startswith("/api/"): - # Naga 端点使用独立鉴权,仅在总开关+子开关均启用时跳过主 auth cfg = self._context.config_getter() is_naga_path = request.path.startswith("/api/v1/naga/") skip_auth = is_naga_path and _naga_runtime_enabled(cfg) @@ -205,7 +138,6 @@ async def _auth_middleware( web.post("/api/v1/tools/invoke", self._tools_invoke_handler), ] ) - # Naga 端点仅在总开关+子开关均启用时注册 cfg = self._context.config_getter() if _naga_routes_enabled(cfg, self._context.naga_store): app.add_routes( @@ -237,1233 +169,108 @@ async def _auth_middleware( ) 
return app - async def _track_naga_send_start(self, message_key: str) -> int: - async with self._naga_send_registry_lock: - next_count = self._naga_send_inflight.get(message_key, 0) + 1 - self._naga_send_inflight[message_key] = next_count - return next_count - - async def _track_naga_send_done(self, message_key: str) -> int: - async with self._naga_send_registry_lock: - current = self._naga_send_inflight.get(message_key, 0) - if current <= 1: - self._naga_send_inflight.pop(message_key, None) - return 0 - next_count = current - 1 - self._naga_send_inflight[message_key] = next_count - return next_count - - def _prune_naga_request_uuid_state_locked(self) -> None: - now = time.time() - expired = [ - request_uuid - for request_uuid, result in self._naga_request_uuid_results.items() - if now - result.finished_at > _NAGA_REQUEST_UUID_TTL_SECONDS - ] - for request_uuid in expired: - self._naga_request_uuid_results.pop(request_uuid, None) - - async def _register_naga_request_uuid( - self, request_uuid: str, payload_hash: str - ) -> tuple[str, Any]: - async with self._naga_request_uuid_lock: - self._prune_naga_request_uuid_state_locked() - - cached = self._naga_request_uuid_results.get(request_uuid) - if cached is not None: - if cached.payload_hash != payload_hash: - return "conflict", cached.payload_hash - return "cached", (cached.status, deepcopy(cached.payload)) - - inflight = self._naga_request_uuid_inflight.get(request_uuid) - if inflight is not None: - existing_hash, inflight_future = inflight - if existing_hash != payload_hash: - return "conflict", existing_hash - return "await", inflight_future - - owner_future: asyncio.Future[tuple[int, dict[str, Any]]] = ( - asyncio.get_running_loop().create_future() - ) - self._naga_request_uuid_inflight[request_uuid] = ( - payload_hash, - owner_future, - ) - return "owner", owner_future - - async def _finish_naga_request_uuid( - self, - request_uuid: str, - payload_hash: str, - *, - status: int, - payload: dict[str, Any], - ) -> 
None: - async with self._naga_request_uuid_lock: - inflight = self._naga_request_uuid_inflight.pop(request_uuid, None) - future = inflight[1] if inflight is not None else None - result_payload = deepcopy(payload) - self._naga_request_uuid_results[request_uuid] = _NagaRequestResult( - payload_hash=payload_hash, - status=status, - payload=result_payload, - finished_at=time.time(), - ) - self._prune_naga_request_uuid_state_locked() - if future is not None and not future.done(): - future.set_result((status, deepcopy(result_payload))) - - async def _fail_naga_request_uuid( - self, - request_uuid: str, - payload_hash: str, - exc: BaseException, - ) -> None: - async with self._naga_request_uuid_lock: - inflight = self._naga_request_uuid_inflight.pop(request_uuid, None) - future = inflight[1] if inflight is not None else None - self._prune_naga_request_uuid_state_locked() - if future is not None and not future.done(): - future.set_exception(exc) - @property def _ctx(self) -> RuntimeAPIContext: return self._context + # ------------------------------------------------------------------ + # Thin delegation wrappers — keep tests calling server._xxx(request) + # ------------------------------------------------------------------ + + # Health / System async def _health_handler(self, request: web.Request) -> Response: - _ = request - return web.json_response( - { - "ok": True, - "service": "undefined-runtime-api", - "version": __version__, - "timestamp": datetime.now().isoformat(), - } - ) + return await health.health_handler(self._ctx, request) async def _openapi_handler(self, request: web.Request) -> Response: - cfg = self._ctx.config_getter() - if not bool(getattr(cfg.api, "openapi_enabled", True)): - logger.info( - "[RuntimeAPI] OpenAPI 请求被拒绝: disabled remote=%s", request.remote - ) - return _json_error("OpenAPI disabled", status=404) - naga_routes_enabled = _naga_routes_enabled(cfg, self._ctx.naga_store) - logger.info( - "[RuntimeAPI] OpenAPI 请求: remote=%s 
naga_routes_enabled=%s", - request.remote, - naga_routes_enabled, - ) - return web.json_response(_build_openapi_spec(self._ctx, request)) + return await system.openapi_handler(self._ctx, request) async def _internal_probe_handler(self, request: web.Request) -> Response: - _ = request - cfg = self._ctx.config_getter() - queue_snapshot = ( - self._ctx.queue_manager.snapshot() if self._ctx.queue_manager else {} - ) - cognitive_queue_snapshot = ( - self._ctx.cognitive_job_queue.snapshot() - if self._ctx.cognitive_job_queue - else {} - ) - memory_storage = getattr(self._ctx.ai, "memory_storage", None) - memory_count = memory_storage.count() if memory_storage is not None else 0 - - # Skills 统计 - ai = self._ctx.ai - skills_info: dict[str, Any] = {} - if ai is not None: - tool_reg = getattr(ai, "tool_registry", None) - agent_reg = getattr(ai, "agent_registry", None) - anthropic_reg = getattr(ai, "anthropic_skill_registry", None) - skills_info["tools"] = _registry_summary(tool_reg) - skills_info["agents"] = _registry_summary(agent_reg) - skills_info["anthropic_skills"] = _registry_summary(anthropic_reg) - - # 模型配置(脱敏) - models_info: dict[str, Any] = {} - summary_model = getattr( - cfg, - "summary_model", - getattr(cfg, "agent_model", getattr(cfg, "chat_model", None)), - ) - for label in ( - "chat_model", - "vision_model", - "agent_model", - "security_model", - "naga_model", - "grok_model", - ): - mcfg = getattr(cfg, label, None) - if mcfg is not None: - models_info[label] = _build_internal_model_probe_payload(mcfg) - if summary_model is not None: - models_info["summary_model"] = _build_internal_model_probe_payload( - summary_model - ) - for label in ("embedding_model", "rerank_model"): - mcfg = getattr(cfg, label, None) - if mcfg is not None: - models_info[label] = { - "model_name": getattr(mcfg, "model_name", ""), - "api_url": _mask_url(getattr(mcfg, "api_url", "")), - } - - uptime_seconds = round(time.time() - _PROCESS_START_TIME, 1) - payload = { - "timestamp": 
datetime.now().isoformat(), - "version": __version__, - "python": f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}", - "platform": platform.system(), - "uptime_seconds": uptime_seconds, - "onebot": self._ctx.onebot.connection_status() - if self._ctx.onebot is not None - else {}, - "queues": queue_snapshot, - "memory": {"count": memory_count, "virtual_user_id": _VIRTUAL_USER_ID}, - "cognitive": { - "enabled": bool( - self._ctx.cognitive_service and self._ctx.cognitive_service.enabled - ), - "queue": cognitive_queue_snapshot, - }, - "api": { - "enabled": bool(cfg.api.enabled), - "host": cfg.api.host, - "port": cfg.api.port, - "openapi_enabled": bool(cfg.api.openapi_enabled), - }, - "skills": skills_info, - "models": models_info, - } - return web.json_response(payload) + return await system.internal_probe_handler(self._ctx, request) async def _external_probe_handler(self, request: web.Request) -> Response: - _ = request - cfg = self._ctx.config_getter() - summary_model = getattr( - cfg, - "summary_model", - getattr(cfg, "agent_model", getattr(cfg, "chat_model", None)), - ) - naga_probe = ( - _probe_http_endpoint( - name="naga_model", - base_url=cfg.naga_model.api_url, - api_key=cfg.naga_model.api_key, - model_name=cfg.naga_model.model_name, - ) - if bool(cfg.api.enabled and cfg.nagaagent_mode_enabled and cfg.naga.enabled) - else _skipped_probe( - name="naga_model", - reason="naga_integration_disabled", - model_name=cfg.naga_model.model_name, - ) - ) - checks = [ - _probe_http_endpoint( - name="chat_model", - base_url=cfg.chat_model.api_url, - api_key=cfg.chat_model.api_key, - model_name=cfg.chat_model.model_name, - ), - _probe_http_endpoint( - name="vision_model", - base_url=cfg.vision_model.api_url, - api_key=cfg.vision_model.api_key, - model_name=cfg.vision_model.model_name, - ), - _probe_http_endpoint( - name="security_model", - base_url=cfg.security_model.api_url, - api_key=cfg.security_model.api_key, - 
model_name=cfg.security_model.model_name, - ), - naga_probe, - _probe_http_endpoint( - name="agent_model", - base_url=cfg.agent_model.api_url, - api_key=cfg.agent_model.api_key, - model_name=cfg.agent_model.model_name, - ), - ] - if summary_model is not None: - checks.append( - _probe_http_endpoint( - name="summary_model", - base_url=summary_model.api_url, - api_key=summary_model.api_key, - model_name=summary_model.model_name, - ) - ) - grok_model = getattr(cfg, "grok_model", None) - if grok_model is not None: - checks.append( - _probe_http_endpoint( - name="grok_model", - base_url=getattr(grok_model, "api_url", ""), - api_key=getattr(grok_model, "api_key", ""), - model_name=getattr(grok_model, "model_name", ""), - ) - ) - checks.extend( - [ - _probe_http_endpoint( - name="embedding_model", - base_url=cfg.embedding_model.api_url, - api_key=cfg.embedding_model.api_key, - model_name=getattr(cfg.embedding_model, "model_name", ""), - ), - _probe_http_endpoint( - name="rerank_model", - base_url=cfg.rerank_model.api_url, - api_key=cfg.rerank_model.api_key, - model_name=getattr(cfg.rerank_model, "model_name", ""), - ), - _probe_ws_endpoint(cfg.onebot_ws_url), - ] - ) - results = await asyncio.gather(*checks) - ok = all(item.get("status") in {"ok", "skipped"} for item in results) - return web.json_response( - { - "ok": ok, - "timestamp": datetime.now().isoformat(), - "results": results, - } - ) + return await system.external_probe_handler(self._ctx, request) + # Memory CRUD async def _memory_handler(self, request: web.Request) -> Response: - query = str(request.query.get("q", "") or "").strip().lower() - top_k_raw = _optional_query_param(request, "top_k") - time_from_raw = _optional_query_param(request, "time_from") - time_to_raw = _optional_query_param(request, "time_to") - memory_storage = getattr(self._ctx.ai, "memory_storage", None) - if memory_storage is None: - return _json_error("Memory storage not ready", status=503) - - limit: int | None = None - if top_k_raw is 
not None: - try: - limit = int(top_k_raw) - except ValueError: - return _json_error("top_k must be an integer", status=400) - if limit <= 0: - return _json_error("top_k must be > 0", status=400) - - time_from_dt = _parse_query_time(time_from_raw) - if time_from_raw is not None and time_from_dt is None: - return _json_error("time_from must be ISO datetime", status=400) - time_to_dt = _parse_query_time(time_to_raw) - if time_to_raw is not None and time_to_dt is None: - return _json_error("time_to must be ISO datetime", status=400) - if time_from_dt and time_to_dt and time_from_dt > time_to_dt: - time_from_dt, time_to_dt = time_to_dt, time_from_dt - - records = memory_storage.get_all() - items: list[dict[str, Any]] = [] - for item in records: - created_at = str(item.created_at or "").strip() - created_dt = _parse_query_time(created_at) - if time_from_dt and created_dt and created_dt < time_from_dt: - continue - if time_to_dt and created_dt and created_dt > time_to_dt: - continue - if (time_from_dt or time_to_dt) and created_dt is None: - continue - items.append( - { - "uuid": item.uuid, - "fact": item.fact, - "created_at": created_at, - } - ) - if query: - items = [ - item - for item in items - if query in str(item.get("fact", "")).lower() - or query in str(item.get("uuid", "")).lower() - ] - - def _created_sort_key(item: dict[str, Any]) -> float: - created_dt = _parse_query_time(str(item.get("created_at") or "")) - if created_dt is None: - return float("-inf") - with suppress(OSError, OverflowError, ValueError): - return float(created_dt.timestamp()) - return float("-inf") - - items.sort(key=_created_sort_key) - if limit is not None: - items = items[:limit] - - return web.json_response( - { - "total": len(items), - "items": items, - "query": { - "q": query or "", - "top_k": limit, - "time_from": time_from_raw, - "time_to": time_to_raw, - }, - } - ) + return await memory.memory_list_handler(self._ctx, request) async def _memory_create_handler(self, request: 
web.Request) -> Response: - memory_storage = getattr(self._ctx.ai, "memory_storage", None) - if memory_storage is None: - return _json_error("Memory storage not ready", status=503) - try: - body = await request.json() - except Exception: - return _json_error("Invalid JSON", status=400) - fact = str(body.get("fact", "") or "").strip() - if not fact: - return _json_error("fact must not be empty", status=400) - new_uuid = await memory_storage.add(fact) - if new_uuid is None: - return _json_error("Failed to create memory", status=500) - # add() returns existing UUID on duplicate - existing = [m for m in memory_storage.get_all() if m.uuid == new_uuid] - item = existing[0] if existing else None - return web.json_response( - { - "uuid": new_uuid, - "fact": item.fact if item else fact, - "created_at": item.created_at if item else "", - }, - status=201, - ) + return await memory.memory_create_handler(self._ctx, request) async def _memory_update_handler(self, request: web.Request) -> Response: - memory_storage = getattr(self._ctx.ai, "memory_storage", None) - if memory_storage is None: - return _json_error("Memory storage not ready", status=503) - target_uuid = str(request.match_info.get("uuid", "")).strip() - if not target_uuid: - return _json_error("uuid is required", status=400) - try: - body = await request.json() - except Exception: - return _json_error("Invalid JSON", status=400) - fact = str(body.get("fact", "") or "").strip() - if not fact: - return _json_error("fact must not be empty", status=400) - ok = await memory_storage.update(target_uuid, fact) - if not ok: - return _json_error(f"Memory {target_uuid} not found", status=404) - return web.json_response({"uuid": target_uuid, "fact": fact, "updated": True}) + return await memory.memory_update_handler(self._ctx, request) async def _memory_delete_handler(self, request: web.Request) -> Response: - memory_storage = getattr(self._ctx.ai, "memory_storage", None) - if memory_storage is None: - return _json_error("Memory 
storage not ready", status=503) - target_uuid = str(request.match_info.get("uuid", "")).strip() - if not target_uuid: - return _json_error("uuid is required", status=400) - ok = await memory_storage.delete(target_uuid) - if not ok: - return _json_error(f"Memory {target_uuid} not found", status=404) - return web.json_response({"uuid": target_uuid, "deleted": True}) + return await memory.memory_delete_handler(self._ctx, request) + # Memes async def _meme_list_handler(self, request: web.Request) -> Response: - meme_service = self._ctx.meme_service - if meme_service is None or not meme_service.enabled: - return _json_error("Meme service disabled", status=400) - - def _parse_optional_bool(name: str) -> bool | None: - raw = request.query.get(name) - if raw is None or str(raw).strip() == "": - return None - return _to_bool(raw) - - page_raw = _optional_query_param(request, "page") - page_size_raw = _optional_query_param(request, "page_size") - top_k_raw = _optional_query_param(request, "top_k") - query = str(request.query.get("q", "") or "").strip() - query_mode = str(request.query.get("query_mode", "") or "").strip().lower() - keyword_query = str(request.query.get("keyword_query", "") or "").strip() - semantic_query = str(request.query.get("semantic_query", "") or "").strip() - try: - page = int(page_raw) if page_raw is not None else 1 - page_size = int(page_size_raw) if page_size_raw is not None else 50 - top_k = int(top_k_raw) if top_k_raw is not None else page_size - except ValueError: - return _json_error("page/page_size/top_k must be integers", status=400) - page = max(1, page) - page_size = max(1, min(200, page_size)) - top_k = max(1, top_k) - sort = str(request.query.get("sort", "updated_at") or "updated_at").strip() - - enabled_filter = _parse_optional_bool("enabled") - animated_filter = _parse_optional_bool("animated") - pinned_filter = _parse_optional_bool("pinned") - if not (query or keyword_query or semantic_query) and sort == "relevance": - sort = 
"updated_at" - - if query or keyword_query or semantic_query: - has_post_filter = any( - f is not None for f in (enabled_filter, animated_filter, pinned_filter) - ) - requested_window = max(page * page_size, top_k) - if has_post_filter or page > 1 or sort != "relevance": - fetch_k = min(500, max(requested_window * 4, top_k)) - else: - fetch_k = min(500, requested_window) - search_payload = await meme_service.search_memes( - query, - query_mode=query_mode or meme_service.default_query_mode, - keyword_query=keyword_query or None, - semantic_query=semantic_query or None, - top_k=fetch_k, - include_disabled=enabled_filter is not True, - sort=sort, - ) - filtered_items: list[dict[str, Any]] = [] - for item in list(search_payload.get("items") or []): - if ( - enabled_filter is not None - and bool(item.get("enabled")) != enabled_filter - ): - continue - if ( - animated_filter is not None - and bool(item.get("is_animated")) != animated_filter - ): - continue - if ( - pinned_filter is not None - and bool(item.get("pinned")) != pinned_filter - ): - continue - filtered_items.append(item) - offset = (page - 1) * page_size - paged_items = filtered_items[offset : offset + page_size] - window_total = len(filtered_items) - fetched_window_count = len(list(search_payload.get("items") or [])) - window_exhausted = fetched_window_count < fetch_k - has_more = bool(paged_items) and ( - offset + page_size < window_total - or (not window_exhausted and window_total >= offset + page_size) - ) - return web.json_response( - { - "ok": True, - "total": None, - "window_total": window_total, - "total_exact": False, - "page": page, - "page_size": page_size, - "has_more": has_more, - "query_mode": search_payload.get("query_mode"), - "keyword_query": search_payload.get("keyword_query"), - "semantic_query": search_payload.get("semantic_query"), - "sort": search_payload.get("sort", sort), - "items": paged_items, - } - ) - - payload = await meme_service.list_memes( - query=query, - 
enabled=enabled_filter, - animated=animated_filter, - pinned=pinned_filter, - sort=sort, - page=page, - page_size=page_size, - summary=True, - ) - return web.json_response(payload) + return await memes.meme_list_handler(self._ctx, request) async def _meme_stats_handler(self, request: web.Request) -> Response: - _ = request - meme_service = self._ctx.meme_service - if meme_service is None or not meme_service.enabled: - return _json_error("Meme service disabled", status=400) - return web.json_response(await meme_service.stats()) + return await memes.meme_stats_handler(self._ctx, request) async def _meme_detail_handler(self, request: web.Request) -> Response: - meme_service = self._ctx.meme_service - if meme_service is None or not meme_service.enabled: - return _json_error("Meme service disabled", status=400) - uid = str(request.match_info.get("uid", "")).strip() - detail = await meme_service.get_meme(uid) - if detail is None: - return _json_error("Meme not found", status=404) - return web.json_response(detail) + return await memes.meme_detail_handler(self._ctx, request) async def _meme_blob_handler(self, request: web.Request) -> Response: - meme_service = self._ctx.meme_service - if meme_service is None or not meme_service.enabled: - return _json_error("Meme service disabled", status=400) - uid = str(request.match_info.get("uid", "")).strip() - path = await meme_service.blob_path_for_uid(uid, preview=False) - if path is None: - return _json_error("Meme blob not found", status=404) - return cast(Response, web.FileResponse(path=path)) + return await memes.meme_blob_handler(self._ctx, request) async def _meme_preview_handler(self, request: web.Request) -> Response: - meme_service = self._ctx.meme_service - if meme_service is None or not meme_service.enabled: - return _json_error("Meme service disabled", status=400) - uid = str(request.match_info.get("uid", "")).strip() - path = await meme_service.blob_path_for_uid(uid, preview=True) - if path is None: - return 
_json_error("Meme preview not found", status=404) - return cast(Response, web.FileResponse(path=path)) + return await memes.meme_preview_handler(self._ctx, request) async def _meme_update_handler(self, request: web.Request) -> Response: - meme_service = self._ctx.meme_service - if meme_service is None or not meme_service.enabled: - return _json_error("Meme service disabled", status=400) - uid = str(request.match_info.get("uid", "")).strip() - try: - payload = await request.json() - except Exception: - return _json_error("Invalid JSON body", status=400) - if not isinstance(payload, dict): - return _json_error("JSON body must be an object", status=400) - updated = await meme_service.update_meme( - uid, - manual_description=payload.get("manual_description"), - tags=payload.get("tags"), - aliases=payload.get("aliases"), - enabled=payload.get("enabled") if "enabled" in payload else None, - pinned=payload.get("pinned") if "pinned" in payload else None, - ) - if updated is None: - return _json_error("Meme not found", status=404) - return web.json_response({"ok": True, "record": updated}) + return await memes.meme_update_handler(self._ctx, request) async def _meme_delete_handler(self, request: web.Request) -> Response: - meme_service = self._ctx.meme_service - if meme_service is None or not meme_service.enabled: - return _json_error("Meme service disabled", status=400) - uid = str(request.match_info.get("uid", "")).strip() - deleted = await meme_service.delete_meme(uid) - if not deleted: - return _json_error("Meme not found", status=404) - return web.json_response({"ok": True, "uid": uid}) + return await memes.meme_delete_handler(self._ctx, request) async def _meme_reanalyze_handler(self, request: web.Request) -> Response: - meme_service = self._ctx.meme_service - if meme_service is None or not meme_service.enabled: - return _json_error("Meme service disabled", status=400) - uid = str(request.match_info.get("uid", "")).strip() - job_id = await 
meme_service.enqueue_reanalyze(uid) - if not job_id: - return _json_error("Meme queue unavailable", status=503) - return web.json_response({"ok": True, "uid": uid, "job_id": job_id}) + return await memes.meme_reanalyze_handler(self._ctx, request) async def _meme_reindex_handler(self, request: web.Request) -> Response: - meme_service = self._ctx.meme_service - if meme_service is None or not meme_service.enabled: - return _json_error("Meme service disabled", status=400) - uid = str(request.match_info.get("uid", "")).strip() - job_id = await meme_service.enqueue_reindex(uid) - if not job_id: - return _json_error("Meme queue unavailable", status=503) - return web.json_response({"ok": True, "uid": uid, "job_id": job_id}) + return await memes.meme_reindex_handler(self._ctx, request) + # Cognitive async def _cognitive_events_handler(self, request: web.Request) -> Response: - cognitive_service = self._ctx.cognitive_service - if not cognitive_service or not cognitive_service.enabled: - return _json_error("Cognitive service disabled", status=400) - - query = str(request.query.get("q", "") or "").strip() - if not query: - return _json_error("q is required", status=400) - - search_kwargs: dict[str, Any] = {"query": query} - for key in ( - "target_user_id", - "target_group_id", - "sender_id", - "request_type", - "top_k", - "time_from", - "time_to", - ): - value = _optional_query_param(request, key) - if value is not None: - search_kwargs[key] = value - - results = await cognitive_service.search_events(**search_kwargs) - return web.json_response({"count": len(results), "items": results}) + return await cognitive.cognitive_events_handler(self._ctx, request) async def _cognitive_profiles_handler(self, request: web.Request) -> Response: - cognitive_service = self._ctx.cognitive_service - if not cognitive_service or not cognitive_service.enabled: - return _json_error("Cognitive service disabled", status=400) - - query = str(request.query.get("q", "") or "").strip() - if not query: - 
return _json_error("q is required", status=400) - - search_kwargs: dict[str, Any] = {"query": query} - entity_type = _optional_query_param(request, "entity_type") - if entity_type is not None: - search_kwargs["entity_type"] = entity_type - top_k = _optional_query_param(request, "top_k") - if top_k is not None: - search_kwargs["top_k"] = top_k - - results = await cognitive_service.search_profiles(**search_kwargs) - return web.json_response({"count": len(results), "items": results}) + return await cognitive.cognitive_profiles_handler(self._ctx, request) async def _cognitive_profile_handler(self, request: web.Request) -> Response: - cognitive_service = self._ctx.cognitive_service - if not cognitive_service or not cognitive_service.enabled: - return _json_error("Cognitive service disabled", status=400) - - entity_type = str(request.match_info.get("entity_type", "")).strip() - entity_id = str(request.match_info.get("entity_id", "")).strip() - if not entity_type or not entity_id: - return _json_error("entity_type/entity_id are required", status=400) - - profile = await cognitive_service.get_profile(entity_type, entity_id) - return web.json_response( - { - "entity_type": entity_type, - "entity_id": entity_id, - "profile": profile or "", - "found": bool(profile), - } - ) + return await cognitive.cognitive_profile_handler(self._ctx, request) + # Chat async def _run_webui_chat( self, *, text: str, send_output: Callable[[int, str], Awaitable[None]], ) -> str: - cfg = self._ctx.config_getter() - permission_sender_id = int(cfg.superadmin_qq) - webui_scope_key = build_attachment_scope( - user_id=_VIRTUAL_USER_ID, - request_type="private", - webui_session=True, - ) - input_segments = message_to_segments(text) - registered_input = await register_message_attachments( - registry=self._ctx.ai.attachment_registry, - segments=input_segments, - scope_key=webui_scope_key, - resolve_image_url=self._ctx.onebot.get_image, - get_forward_messages=self._ctx.onebot.get_forward_msg, - ) - 
normalized_text = registered_input.normalized_text or text - await self._ctx.history_manager.add_private_message( - user_id=_VIRTUAL_USER_ID, - text_content=normalized_text, - display_name=_VIRTUAL_USER_NAME, - user_name=_VIRTUAL_USER_NAME, - attachments=registered_input.attachments, - ) - - command = self._ctx.command_dispatcher.parse_command(normalized_text) - if command: - await self._ctx.command_dispatcher.dispatch_private( - user_id=_VIRTUAL_USER_ID, - sender_id=permission_sender_id, - command=command, - send_private_callback=send_output, - is_webui_session=True, - ) - return "command" - - current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S") - attachment_xml = ( - f"\n{attachment_refs_to_xml(registered_input.attachments)}" - if registered_input.attachments - else "" - ) - full_question = f""" - {escape_xml_text(normalized_text)}{attachment_xml} - - -【WebUI 会话】 -这是一条来自 WebUI 控制台的会话请求。 -会话身份:虚拟用户 system(42)。 -权限等级:superadmin(你可按最高管理权限处理)。 -请正常进行私聊对话;如果需要结束会话,调用 end 工具。""" - virtual_sender = _WebUIVirtualSender( - _VIRTUAL_USER_ID, send_output, onebot=self._ctx.onebot - ) - - async def _get_recent_cb( - chat_id: str, msg_type: str, start: int, end: int - ) -> list[dict[str, Any]]: - return await get_recent_messages_prefer_local( - chat_id=chat_id, - msg_type=msg_type, - start=start, - end=end, - onebot_client=self._ctx.onebot, - history_manager=self._ctx.history_manager, - bot_qq=cfg.bot_qq, - attachment_registry=getattr(self._ctx.ai, "attachment_registry", None), - ) - - async with RequestContext( - request_type="private", - user_id=_VIRTUAL_USER_ID, - sender_id=permission_sender_id, - ) as ctx: - # 与 ai_coordinator 保持一致:通过 collect_context_resources 自动注入 - ai_client = self._ctx.ai - memory_storage = self._ctx.ai.memory_storage - runtime_config = self._ctx.ai.runtime_config - sender = virtual_sender - history_manager = self._ctx.history_manager - onebot_client = self._ctx.onebot - scheduler = self._ctx.scheduler - - def send_message_callback( - msg: str, 
reply_to: int | None = None - ) -> Awaitable[None]: - _ = reply_to - return send_output(_VIRTUAL_USER_ID, msg) - - get_recent_messages_callback = _get_recent_cb - get_image_url_callback = self._ctx.onebot.get_image - get_forward_msg_callback = self._ctx.onebot.get_forward_msg - resource_vars = dict(globals()) - resource_vars.update(locals()) - resources = collect_context_resources(resource_vars) - for key, value in resources.items(): - if value is not None: - ctx.set_resource(key, value) - ctx.set_resource("queue_lane", QUEUE_LANE_SUPERADMIN) - ctx.set_resource("webui_session", True) - ctx.set_resource("webui_permission", "superadmin") - - result = await self._ctx.ai.ask( - full_question, - send_message_callback=send_message_callback, - get_recent_messages_callback=get_recent_messages_callback, - get_image_url_callback=get_image_url_callback, - get_forward_msg_callback=get_forward_msg_callback, - sender=sender, - history_manager=history_manager, - onebot_client=onebot_client, - scheduler=scheduler, - extra_context={ - "is_private_chat": True, - "request_type": "private", - "user_id": _VIRTUAL_USER_ID, - "sender_name": _VIRTUAL_USER_NAME, - "webui_session": True, - "webui_permission": "superadmin", - }, - ) - - final_reply = str(result or "").strip() - if final_reply: - await send_output(_VIRTUAL_USER_ID, final_reply) - - return "chat" + return await chat.run_webui_chat(self._ctx, text=text, send_output=send_output) async def _chat_history_handler(self, request: web.Request) -> Response: - limit_raw = str(request.query.get("limit", "200") or "200").strip() - try: - limit = int(limit_raw) - except ValueError: - limit = 200 - limit = max(1, min(limit, 500)) - - getter = getattr(self._ctx.history_manager, "get_recent_private", None) - if not callable(getter): - return _json_error("History manager not ready", status=503) - - records = getter(_VIRTUAL_USER_ID, limit) - items: list[dict[str, Any]] = [] - for item in records: - if not isinstance(item, dict): - continue - 
content = str(item.get("message", "")).strip() - if not content: - continue - display_name = str(item.get("display_name", "")).strip().lower() - role = "bot" if display_name == "bot" else "user" - items.append( - { - "role": role, - "content": content, - "timestamp": str(item.get("timestamp", "") or "").strip(), - } - ) - - return web.json_response( - { - "virtual_user_id": _VIRTUAL_USER_ID, - "permission": "superadmin", - "count": len(items), - "items": items, - } - ) + return await chat.chat_history_handler(self._ctx, request) async def _chat_handler(self, request: web.Request) -> web.StreamResponse: - try: - body = await request.json() - except Exception: - return _json_error("Invalid JSON", status=400) - - text = str(body.get("message", "") or "").strip() - if not text: - return _json_error("message is required", status=400) - - stream = _to_bool(body.get("stream")) - outputs: list[str] = [] - webui_scope_key = build_attachment_scope( - user_id=_VIRTUAL_USER_ID, - request_type="private", - webui_session=True, - ) - - async def _capture_private_message(user_id: int, message: str) -> None: - _ = user_id - content = str(message or "").strip() - if not content: - return - rendered = await render_message_with_pic_placeholders( - content, - registry=self._ctx.ai.attachment_registry, - scope_key=webui_scope_key, - strict=False, - ) - if not rendered.delivery_text.strip(): - return - outputs.append(rendered.delivery_text) - await self._ctx.history_manager.add_private_message( - user_id=_VIRTUAL_USER_ID, - text_content=rendered.history_text, - display_name="Bot", - user_name="Bot", - attachments=rendered.attachments, - ) - - if not stream: - try: - mode = await self._run_webui_chat( - text=text, send_output=_capture_private_message - ) - except Exception as exc: - logger.exception("[RuntimeAPI] chat failed: %s", exc) - return _json_error("Chat failed", status=502) - return web.json_response(_build_chat_response_payload(mode, outputs)) - - response = web.StreamResponse( 
- status=200, - reason="OK", - headers={ - "Content-Type": "text/event-stream", - "Cache-Control": "no-cache", - "Connection": "keep-alive", - }, - ) - await response.prepare(request) - - message_queue: asyncio.Queue[str] = asyncio.Queue() - - async def _capture_private_message_stream(user_id: int, message: str) -> None: - output_count = len(outputs) - await _capture_private_message(user_id, message) - if len(outputs) <= output_count: - return - content = outputs[-1].strip() - if content: - await message_queue.put(content) - - task = asyncio.create_task( - self._run_webui_chat(text=text, send_output=_capture_private_message_stream) - ) - mode = "chat" - client_disconnected = False - try: - await response.write( - _sse_event( - "meta", - { - "virtual_user_id": _VIRTUAL_USER_ID, - "permission": "superadmin", - }, - ) - ) - - while True: - if request.transport is None or request.transport.is_closing(): - client_disconnected = True - break - if task.done() and message_queue.empty(): - break - try: - message = await asyncio.wait_for( - message_queue.get(), - timeout=_CHAT_SSE_KEEPALIVE_SECONDS, - ) - await response.write(_sse_event("message", {"content": message})) - except asyncio.TimeoutError: - await response.write(b": keep-alive\n\n") - - if client_disconnected: - task.cancel() - with suppress(asyncio.CancelledError): - await task - return response - - mode = await task - await response.write( - _sse_event("done", _build_chat_response_payload(mode, outputs)) - ) - except asyncio.CancelledError: - task.cancel() - with suppress(asyncio.CancelledError): - await task - raise - except (ConnectionResetError, RuntimeError): - task.cancel() - with suppress(asyncio.CancelledError): - await task - except Exception as exc: - logger.exception("[RuntimeAPI] chat stream failed: %s", exc) - if not task.done(): - task.cancel() - with suppress(asyncio.CancelledError): - await task - with suppress(Exception): - await response.write(_sse_event("error", {"error": str(exc)})) - finally: 
- with suppress(Exception): - await response.write_eof() - - return response - - # ------------------------------------------------------------------ - # Tool Invoke API - # ------------------------------------------------------------------ + return await chat.chat_handler(self._ctx, request) + # Tools def _get_filtered_tools(self) -> list[dict[str, Any]]: - """按配置过滤可用工具,返回 OpenAI function calling schema 列表。""" - cfg = self._ctx.config_getter() - api_cfg = cfg.api - ai = self._ctx.ai - if ai is None: - return [] - - tool_reg = getattr(ai, "tool_registry", None) - agent_reg = getattr(ai, "agent_registry", None) - - all_schemas: list[dict[str, Any]] = [] - if tool_reg is not None: - all_schemas.extend(tool_reg.get_tools_schema()) - - # 收集 agent schema 并缓存名称集合(避免重复调用) - agent_names: set[str] = set() - if agent_reg is not None: - agent_schemas = agent_reg.get_agents_schema() - all_schemas.extend(agent_schemas) - for schema in agent_schemas: - func = schema.get("function", {}) - name = str(func.get("name", "")) - if name: - agent_names.add(name) - - denylist: set[str] = set(api_cfg.tool_invoke_denylist) - allowlist: set[str] = set(api_cfg.tool_invoke_allowlist) - expose = api_cfg.tool_invoke_expose - - def _get_name(schema: dict[str, Any]) -> str: - func = schema.get("function", {}) - return str(func.get("name", "")) - - # 1. 先排除黑名单 - if denylist: - all_schemas = [s for s in all_schemas if _get_name(s) not in denylist] - - # 2. 白名单非空时仅保留匹配项 - if allowlist: - return [s for s in all_schemas if _get_name(s) in allowlist] - - # 3. 按 expose 过滤 - if expose == "all": - return all_schemas - - def _is_tool(name: str) -> bool: - return "." not in name and name not in agent_names - - def _is_toolset(name: str) -> bool: - return "." 
in name and not name.startswith("mcp.") - - filtered: list[dict[str, Any]] = [] - for schema in all_schemas: - name = _get_name(schema) - if not name: - continue - if expose == "tools" and _is_tool(name): - filtered.append(schema) - elif expose == "toolsets" and _is_toolset(name): - filtered.append(schema) - elif expose == "tools+toolsets" and (_is_tool(name) or _is_toolset(name)): - filtered.append(schema) - elif expose == "agents" and name in agent_names: - filtered.append(schema) - - return filtered + return tools.get_filtered_tools(self._ctx) def _get_agent_tool_names(self) -> set[str]: - ai = self._ctx.ai - if ai is None: - return set() - - agent_reg = getattr(ai, "agent_registry", None) - if agent_reg is None: - return set() - - agent_names: set[str] = set() - for schema in agent_reg.get_agents_schema(): - func = schema.get("function", {}) - name = str(func.get("name", "")) - if name: - agent_names.add(name) - return agent_names - - def _resolve_tool_invoke_timeout( - self, tool_name: str, timeout: int - ) -> float | None: - if tool_name in self._get_agent_tool_names(): - return None - return float(timeout) - - async def _await_tool_invoke_result( - self, - awaitable: Awaitable[Any], - *, - timeout: float | None, - ) -> Any: - if timeout is None or timeout <= 0: - return await awaitable - try: - return await asyncio.wait_for(awaitable, timeout=timeout) - except asyncio.TimeoutError as exc: - raise _ToolInvokeExecutionTimeoutError from exc + return tools.get_agent_tool_names(self._ctx) async def _tools_list_handler(self, request: web.Request) -> Response: - _ = request - cfg = self._ctx.config_getter() - if not cfg.api.tool_invoke_enabled: - return _json_error("Tool invoke API is disabled", status=403) - - tools = self._get_filtered_tools() - return web.json_response({"count": len(tools), "tools": tools}) + return await tools.tools_list_handler(self._ctx, request) async def _tools_invoke_handler(self, request: web.Request) -> Response: - cfg = 
self._ctx.config_getter() - if not cfg.api.tool_invoke_enabled: - return _json_error("Tool invoke API is disabled", status=403) - - try: - body = await request.json() - except Exception: - return _json_error("Invalid JSON", status=400) - - if not isinstance(body, dict): - return _json_error("Request body must be a JSON object", status=400) - - tool_name = str(body.get("tool_name", "") or "").strip() - if not tool_name: - return _json_error("tool_name is required", status=400) - - args = body.get("args") - if not isinstance(args, dict): - return _json_error("args must be a JSON object", status=400) - - # 验证工具是否在允许列表中 - filtered_tools = self._get_filtered_tools() - available_names: set[str] = set() - for schema in filtered_tools: - func = schema.get("function", {}) - name = str(func.get("name", "")) - if name: - available_names.add(name) - - if tool_name not in available_names: - caller_ip = request.remote or "unknown" - logger.warning( - "[ToolInvoke] 请求拒绝: tool=%s reason=not_available caller_ip=%s", - tool_name, - caller_ip, - ) - return _json_error(f"Tool '{tool_name}' is not available", status=404) - - # 解析回调配置 - callback_cfg = body.get("callback") - use_callback = False - callback_url = "" - callback_headers: dict[str, str] = {} - if isinstance(callback_cfg, dict) and _to_bool(callback_cfg.get("enabled")): - callback_url = str(callback_cfg.get("url", "") or "").strip() - if not callback_url: - return _json_error( - "callback.url is required when callback is enabled", - status=400, - ) - url_error = _validate_callback_url(callback_url) - if url_error: - return _json_error(url_error, status=400) - raw_headers = callback_cfg.get("headers") - if isinstance(raw_headers, dict): - callback_headers = {str(k): str(v) for k, v in raw_headers.items()} - use_callback = True - - request_id = _uuid.uuid4().hex - caller_ip = request.remote or "unknown" - logger.info( - "[ToolInvoke] 收到请求: request_id=%s tool=%s caller_ip=%s", - request_id, - tool_name, - caller_ip, + return 
await tools.tools_invoke_handler( + self._ctx, self._background_tasks, request ) - if use_callback: - # 异步执行 + 回调 - task = asyncio.create_task( - self._execute_and_callback( - request_id=request_id, - tool_name=tool_name, - args=args, - body_context=body.get("context"), - callback_url=callback_url, - callback_headers=callback_headers, - timeout=cfg.api.tool_invoke_timeout, - callback_timeout=cfg.api.tool_invoke_callback_timeout, - ) - ) - self._background_tasks.add(task) - task.add_done_callback(self._background_tasks.discard) - return web.json_response( - { - "ok": True, - "request_id": request_id, - "tool_name": tool_name, - "status": "accepted", - } - ) - - # 同步执行 - result = await self._execute_tool_invoke( - request_id=request_id, - tool_name=tool_name, - args=args, - body_context=body.get("context"), - timeout=cfg.api.tool_invoke_timeout, - ) - return web.json_response(result) - async def _execute_tool_invoke( self, *, @@ -1473,146 +280,8 @@ async def _execute_tool_invoke( body_context: Any, timeout: int, ) -> dict[str, Any]: - """执行工具调用并返回结果字典。""" - ai = self._ctx.ai - if ai is None: - return { - "ok": False, - "request_id": request_id, - "tool_name": tool_name, - "error": "AI client not ready", - "duration_ms": 0, - } - - # 解析请求上下文 - ctx_data: dict[str, Any] = {} - if isinstance(body_context, dict): - ctx_data = body_context - request_type = str(ctx_data.get("request_type", "api") or "api") - group_id = ctx_data.get("group_id") - user_id = ctx_data.get("user_id") - sender_id = ctx_data.get("sender_id") - - args_keys = list(args.keys()) - logger.info( - "[ToolInvoke] 开始执行: request_id=%s tool=%s args_keys=%s", - request_id, - tool_name, - args_keys, - ) - - start = time.perf_counter() - effective_timeout = self._resolve_tool_invoke_timeout(tool_name, timeout) - try: - async with RequestContext( - request_type=request_type, - group_id=int(group_id) if group_id is not None else None, - user_id=int(user_id) if user_id is not None else None, - 
sender_id=int(sender_id) if sender_id is not None else None, - ) as ctx: - # 注入核心服务资源 - if self._ctx.sender is not None: - ctx.set_resource("sender", self._ctx.sender) - if self._ctx.history_manager is not None: - ctx.set_resource("history_manager", self._ctx.history_manager) - runtime_config = getattr(ai, "runtime_config", None) - if runtime_config is not None: - ctx.set_resource("runtime_config", runtime_config) - memory_storage = getattr(ai, "memory_storage", None) - if memory_storage is not None: - ctx.set_resource("memory_storage", memory_storage) - if self._ctx.onebot is not None: - ctx.set_resource("onebot_client", self._ctx.onebot) - if self._ctx.scheduler is not None: - ctx.set_resource("scheduler", self._ctx.scheduler) - if self._ctx.cognitive_service is not None: - ctx.set_resource("cognitive_service", self._ctx.cognitive_service) - if self._ctx.meme_service is not None: - ctx.set_resource("meme_service", self._ctx.meme_service) - - tool_context: dict[str, Any] = { - "request_type": request_type, - "request_id": request_id, - } - if group_id is not None: - tool_context["group_id"] = int(group_id) - if user_id is not None: - tool_context["user_id"] = int(user_id) - if sender_id is not None: - tool_context["sender_id"] = int(sender_id) - - tool_manager = getattr(ai, "tool_manager", None) - if tool_manager is None: - raise RuntimeError("ToolManager not available") - - raw_result = await self._await_tool_invoke_result( - tool_manager.execute_tool(tool_name, args, tool_context), - timeout=effective_timeout, - ) - - elapsed_ms = round((time.perf_counter() - start) * 1000, 1) - result_str = str(raw_result or "") - logger.info( - "[ToolInvoke] 执行完成: request_id=%s tool=%s ok=true " - "duration_ms=%s result_len=%d", - request_id, - tool_name, - elapsed_ms, - len(result_str), - ) - return { - "ok": True, - "request_id": request_id, - "tool_name": tool_name, - "result": result_str, - "duration_ms": elapsed_ms, - } - - except _ToolInvokeExecutionTimeoutError: - 
elapsed_ms = round((time.perf_counter() - start) * 1000, 1) - logger.warning( - "[ToolInvoke] 执行超时: request_id=%s tool=%s timeout=%ds", - request_id, - tool_name, - timeout, - ) - return { - "ok": False, - "request_id": request_id, - "tool_name": tool_name, - "error": f"Execution timed out after {timeout}s", - "duration_ms": elapsed_ms, - } - except Exception as exc: - elapsed_ms = round((time.perf_counter() - start) * 1000, 1) - logger.exception( - "[ToolInvoke] 执行失败: request_id=%s tool=%s error=%s", - request_id, - tool_name, - exc, - ) - return { - "ok": False, - "request_id": request_id, - "tool_name": tool_name, - "error": str(exc), - "duration_ms": elapsed_ms, - } - - async def _execute_and_callback( - self, - *, - request_id: str, - tool_name: str, - args: dict[str, Any], - body_context: Any, - callback_url: str, - callback_headers: dict[str, str], - timeout: int, - callback_timeout: int, - ) -> None: - """异步执行工具并发送回调。""" - result = await self._execute_tool_invoke( + return await tools.execute_tool_invoke( + self._ctx, request_id=request_id, tool_name=tool_name, args=args, @@ -1620,371 +289,17 @@ async def _execute_and_callback( timeout=timeout, ) - payload = { - "request_id": result["request_id"], - "tool_name": result["tool_name"], - "ok": result["ok"], - "result": result.get("result"), - "duration_ms": result.get("duration_ms", 0), - "error": result.get("error"), - } - - try: - cb_timeout = ClientTimeout(total=callback_timeout) - async with ClientSession(timeout=cb_timeout) as session: - # aiohttp json= 自动设置 Content-Type,无需手动指定 - async with session.post( - callback_url, - json=payload, - headers=callback_headers or None, - ) as resp: - logger.info( - "[ToolInvoke] 回调发送: request_id=%s url=%s status=%d", - request_id, - _mask_url(callback_url), - resp.status, - ) - except Exception as exc: - logger.warning( - "[ToolInvoke] 回调失败: request_id=%s url=%s error=%s", - request_id, - _mask_url(callback_url), - exc, - ) - - # 
------------------------------------------------------------------ - # Naga Bind / Send / Unbind API - # ------------------------------------------------------------------ - + # Naga def _verify_naga_api_key(self, request: web.Request) -> str | None: - """校验 Naga 共享密钥,返回错误信息或 None 表示通过。""" - cfg = self._ctx.config_getter() - expected = cfg.naga.api_key - if not expected: - return "naga api_key not configured" - auth_header = request.headers.get("Authorization", "") - if not auth_header.startswith("Bearer "): - return "missing or invalid Authorization header" - provided = auth_header[7:] - import secrets as _secrets - - if not _secrets.compare_digest(provided, expected): - return "invalid api_key" - return None + return naga.verify_naga_api_key(self._ctx, request) async def _naga_bind_callback_handler(self, request: web.Request) -> Response: - """POST /api/v1/naga/bind/callback — Naga 绑定回调。""" - trace_id = _uuid.uuid4().hex[:8] - auth_err = self._verify_naga_api_key(request) - if auth_err is not None: - logger.warning( - "[NagaBindCallback] 鉴权失败: trace=%s remote=%s err=%s", - trace_id, - getattr(request, "remote", None), - auth_err, - ) - return _json_error("Unauthorized", status=401) - - try: - body = await request.json() - except Exception: - return _json_error("Invalid JSON", status=400) - - bind_uuid = str(body.get("bind_uuid", "") or "").strip() - naga_id = str(body.get("naga_id", "") or "").strip() - status = str(body.get("status", "") or "").strip().lower() - delivery_signature = str(body.get("delivery_signature", "") or "").strip() - reason = str(body.get("reason", "") or "").strip() - if not bind_uuid or not naga_id: - return _json_error("bind_uuid and naga_id are required", status=400) - if status not in {"approved", "rejected"}: - return _json_error("status must be 'approved' or 'rejected'", status=400) - logger.info( - "[NagaBindCallback] 请求开始: trace=%s remote=%s naga_id=%s bind_uuid=%s status=%s reason=%s signature=%s", - trace_id, - getattr(request, 
"remote", None), - naga_id, - bind_uuid, - status, - _short_text_preview(reason, limit=60), - delivery_signature[:12] + "..." if delivery_signature else "", - ) - - naga_store = self._ctx.naga_store - if naga_store is None: - return _json_error("Naga integration not available", status=503) - - sender = self._ctx.sender - if status == "approved": - if not delivery_signature: - return _json_error( - "delivery_signature is required when approved", status=400 - ) - binding, created, err = await naga_store.activate_binding( - bind_uuid=bind_uuid, - naga_id=naga_id, - delivery_signature=delivery_signature, - ) - if err: - logger.warning( - "[NagaBindCallback] 激活失败: trace=%s naga_id=%s bind_uuid=%s err=%s", - trace_id, - naga_id, - bind_uuid, - err.message, - ) - return _json_error(err.message, status=err.http_status) - logger.info( - "[NagaBindCallback] 激活完成: trace=%s naga_id=%s bind_uuid=%s created=%s qq=%s", - trace_id, - naga_id, - bind_uuid, - created, - binding.qq_id if binding is not None else "", - ) - if created and binding is not None and sender is not None: - try: - await sender.send_private_message( - binding.qq_id, - f"🎉 你的 Naga 绑定已生效\nnaga_id: {naga_id}", - ) - except Exception as exc: - logger.warning("[NagaBindCallback] 通知绑定成功失败: %s", exc) - return web.json_response( - { - "ok": True, - "status": "approved", - "idempotent": not created, - "naga_id": naga_id, - "bind_uuid": bind_uuid, - } - ) - - pending, removed, err = await naga_store.reject_binding( - bind_uuid=bind_uuid, - naga_id=naga_id, - reason=reason, - ) - if err: - logger.warning( - "[NagaBindCallback] 拒绝失败: trace=%s naga_id=%s bind_uuid=%s err=%s", - trace_id, - naga_id, - bind_uuid, - err.message, - ) - return _json_error(err.message, status=err.http_status) - logger.info( - "[NagaBindCallback] 拒绝完成: trace=%s naga_id=%s bind_uuid=%s removed=%s qq=%s", - trace_id, - naga_id, - bind_uuid, - removed, - pending.qq_id if pending is not None else "", - ) - if removed and pending is not None and 
sender is not None: - try: - detail = f"\n原因: {reason}" if reason else "" - await sender.send_private_message( - pending.qq_id, - f"❌ 你的 Naga 绑定被远端拒绝\nnaga_id: {naga_id}{detail}", - ) - except Exception as exc: - logger.warning("[NagaBindCallback] 通知绑定拒绝失败: %s", exc) - return web.json_response( - { - "ok": True, - "status": "rejected", - "idempotent": not removed, - "naga_id": naga_id, - "bind_uuid": bind_uuid, - } - ) + return await naga.naga_bind_callback_handler(self._ctx, request) async def _naga_messages_send_handler(self, request: web.Request) -> Response: - """POST /api/v1/naga/messages/send — 验签后发送消息。""" - from Undefined.api.naga_store import mask_token - - trace_id = _uuid.uuid4().hex[:8] - auth_err = self._verify_naga_api_key(request) - if auth_err is not None: - logger.warning("[NagaSend] 鉴权失败: trace=%s err=%s", trace_id, auth_err) - return _json_error("Unauthorized", status=401) - - try: - body = await request.json() - except Exception: - return _json_error("Invalid JSON", status=400) - - bind_uuid = str(body.get("bind_uuid", "") or "").strip() - naga_id = str(body.get("naga_id", "") or "").strip() - delivery_signature = str(body.get("delivery_signature", "") or "").strip() - request_uuid = str(body.get("uuid", "") or "").strip() - target = body.get("target") - message = body.get("message") - if not bind_uuid or not naga_id or not delivery_signature: - return _json_error( - "bind_uuid, naga_id and delivery_signature are required", - status=400, - ) - if not isinstance(target, dict): - return _json_error("target object is required", status=400) - if not isinstance(message, dict): - return _json_error("message object is required", status=400) - - raw_target_qq = target.get("qq_id") - raw_target_group = target.get("group_id") - if raw_target_qq is None or raw_target_group is None: - return _json_error( - "target.qq_id and target.group_id are required", status=400 - ) - try: - target_qq = int(raw_target_qq) - target_group = int(raw_target_group) - except 
Exception: - return _json_error( - "target.qq_id and target.group_id must be integers", status=400 - ) - mode = str(target.get("mode", "") or "").strip().lower() - if mode not in {"private", "group", "both"}: - return _json_error( - "target.mode must be 'private', 'group', or 'both'", status=400 - ) - - fmt = str(message.get("format", "text") or "text").strip().lower() - content = str(message.get("content", "") or "").strip() - if fmt not in {"text", "markdown", "html"}: - return _json_error( - "message.format must be 'text', 'markdown', or 'html'", status=400 - ) - if not content: - return _json_error("message.content is required", status=400) - - message_key = _naga_message_digest( - bind_uuid=bind_uuid, - naga_id=naga_id, - target_qq=target_qq, - target_group=target_group, - mode=mode, - message_format=fmt, - content=content, - ) - logger.info( - "[NagaSend] 请求开始: trace=%s remote=%s naga_id=%s bind_uuid=%s request_uuid=%s mode=%s fmt=%s qq=%s group=%s key=%s content_len=%s preview=%s signature=%s", - trace_id, - getattr(request, "remote", None), - naga_id, - bind_uuid, - request_uuid, - mode, - fmt, - target_qq, - target_group, - message_key, - len(content), - _short_text_preview(content), - mask_token(delivery_signature), + return await naga.naga_messages_send_handler( + self._ctx, self._naga_state, request ) - if mode == "both": - logger.warning( - "[NagaSend] 上游请求显式要求双路投递: trace=%s naga_id=%s bind_uuid=%s request_uuid=%s key=%s", - trace_id, - naga_id, - bind_uuid, - request_uuid, - message_key, - ) - inflight_count = await self._track_naga_send_start(message_key) - if inflight_count > 1: - logger.warning( - "[NagaSend] 检测到相同 payload 并发请求: trace=%s naga_id=%s bind_uuid=%s request_uuid=%s key=%s inflight=%s", - trace_id, - naga_id, - bind_uuid, - request_uuid, - message_key, - inflight_count, - ) - try: - if request_uuid: - dedupe_action, dedupe_value = await self._register_naga_request_uuid( - request_uuid, message_key - ) - if dedupe_action == "conflict": - 
logger.warning( - "[NagaSend] uuid 与历史 payload 冲突: trace=%s naga_id=%s bind_uuid=%s uuid=%s key=%s", - trace_id, - naga_id, - bind_uuid, - request_uuid, - message_key, - ) - return _json_error("uuid reused with different payload", status=409) - if dedupe_action == "cached": - cached_status, cached_payload = dedupe_value - logger.warning( - "[NagaSend] 命中已完成幂等结果,直接复用: trace=%s naga_id=%s bind_uuid=%s request_uuid=%s key=%s", - trace_id, - naga_id, - bind_uuid, - request_uuid, - message_key, - ) - return web.json_response( - deepcopy(cached_payload), - status=int(cached_status), - ) - if dedupe_action == "await": - wait_future = dedupe_value - logger.warning( - "[NagaSend] 命中进行中幂等请求,等待首个结果: trace=%s naga_id=%s bind_uuid=%s request_uuid=%s key=%s", - trace_id, - naga_id, - bind_uuid, - request_uuid, - message_key, - ) - cached_status, cached_payload = await wait_future - return web.json_response( - deepcopy(cached_payload), - status=int(cached_status), - ) - - response = await self._naga_messages_send_impl( - naga_id=naga_id, - bind_uuid=bind_uuid, - delivery_signature=delivery_signature, - target_qq=target_qq, - target_group=target_group, - mode=mode, - message_format=fmt, - content=content, - trace_id=trace_id, - message_key=message_key, - ) - if request_uuid: - await self._finish_naga_request_uuid( - request_uuid, - message_key, - status=response.status, - payload=_parse_response_payload(response), - ) - return response - except Exception as exc: - if request_uuid: - await self._fail_naga_request_uuid(request_uuid, message_key, exc) - raise - finally: - remaining = await self._track_naga_send_done(message_key) - logger.info( - "[NagaSend] 请求退出: trace=%s naga_id=%s bind_uuid=%s request_uuid=%s key=%s inflight_remaining=%s", - trace_id, - naga_id, - bind_uuid, - request_uuid, - message_key, - remaining, - ) async def _naga_messages_send_impl( self, @@ -2000,492 +315,19 @@ async def _naga_messages_send_impl( trace_id: str, message_key: str, ) -> Response: - from 
Undefined.api.naga_store import mask_token - - naga_store = self._ctx.naga_store - if naga_store is None: - logger.warning( - "[NagaSend] NagaStore 不可用: trace=%s naga_id=%s bind_uuid=%s", - trace_id, - naga_id, - bind_uuid, - ) - return _json_error("Naga integration not available", status=503) - - binding, err_msg = await naga_store.acquire_delivery( + return await naga.naga_messages_send_impl( + self._ctx, naga_id=naga_id, bind_uuid=bind_uuid, delivery_signature=delivery_signature, + target_qq=target_qq, + target_group=target_group, + mode=mode, + message_format=message_format, + content=content, + trace_id=trace_id, + message_key=message_key, ) - if binding is None: - logger.warning( - "[NagaSend] 签名校验失败: trace=%s naga_id=%s bind_uuid=%s reason=%s signature=%s", - trace_id, - naga_id, - bind_uuid, - err_msg.message if err_msg is not None else "unknown_error", - mask_token(delivery_signature), - ) - return _json_error( - err_msg.message if err_msg is not None else "delivery not available", - status=err_msg.http_status if err_msg is not None else 403, - ) - - logger.info( - "[NagaSend] 投递凭证已占用: trace=%s naga_id=%s bind_uuid=%s key=%s qq=%s group=%s", - trace_id, - naga_id, - bind_uuid, - message_key, - binding.qq_id, - binding.group_id, - ) - try: - if target_qq != binding.qq_id or target_group != binding.group_id: - logger.warning( - "[NagaSend] 目标不匹配: trace=%s naga_id=%s bind_uuid=%s target_qq=%s target_group=%s bound_qq=%s bound_group=%s", - trace_id, - naga_id, - bind_uuid, - target_qq, - target_group, - binding.qq_id, - binding.group_id, - ) - return _json_error("target does not match bound qq/group", status=403) - - cfg = self._ctx.config_getter() - if mode == "group" and binding.group_id not in cfg.naga.allowed_groups: - logger.warning( - "[NagaSend] 群投递被策略拒绝: trace=%s naga_id=%s bind_uuid=%s group=%s", - trace_id, - naga_id, - bind_uuid, - binding.group_id, - ) - return _json_error( - "bound group is not in naga.allowed_groups", status=403 - ) - - sender = 
self._ctx.sender - if sender is None: - logger.warning( - "[NagaSend] sender 不可用: trace=%s naga_id=%s bind_uuid=%s", - trace_id, - naga_id, - bind_uuid, - ) - return _json_error("sender not available", status=503) - - moderation: dict[str, Any] - naga_cfg = getattr(cfg, "naga", None) - moderation_enabled = bool(getattr(naga_cfg, "moderation_enabled", True)) - security = getattr(self._ctx.command_dispatcher, "security", None) - if not moderation_enabled: - moderation = { - "status": "skipped_disabled", - "blocked": False, - "categories": [], - "message": "Naga moderation disabled by config; message sent without moderation block", - "model_name": "", - } - logger.warning( - "[NagaSend] 审核已禁用,直接放行: trace=%s naga_id=%s bind_uuid=%s key=%s", - trace_id, - naga_id, - bind_uuid, - message_key, - ) - elif security is None or not hasattr(security, "moderate_naga_message"): - moderation = { - "status": "error_allowed", - "blocked": False, - "categories": [], - "message": "Naga moderation service unavailable; message sent without moderation block", - "model_name": "", - } - logger.warning( - "[NagaSend] 审核服务不可用,按允许发送: trace=%s naga_id=%s bind_uuid=%s", - trace_id, - naga_id, - bind_uuid, - ) - else: - logger.info( - "[NagaSend] 审核开始: trace=%s naga_id=%s bind_uuid=%s key=%s fmt=%s content_len=%s", - trace_id, - naga_id, - bind_uuid, - message_key, - message_format, - len(content), - ) - result = await security.moderate_naga_message( - message_format=message_format, - content=content, - ) - moderation = { - "status": result.status, - "blocked": result.blocked, - "categories": result.categories, - "message": result.message, - "model_name": result.model_name, - } - logger.info( - "[NagaSend] 审核完成: trace=%s naga_id=%s bind_uuid=%s key=%s blocked=%s status=%s model=%s categories=%s", - trace_id, - naga_id, - bind_uuid, - message_key, - result.blocked, - result.status, - result.model_name, - ",".join(result.categories) or "-", - ) - if moderation["blocked"]: - logger.warning( - 
"[NagaSend] 审核拦截: trace=%s naga_id=%s bind_uuid=%s key=%s reason=%s", - trace_id, - naga_id, - bind_uuid, - message_key, - moderation["message"], - ) - return web.json_response( - { - "ok": False, - "error": "message blocked by moderation", - "moderation": moderation, - }, - status=403, - ) - - send_content: str | None = content if message_format == "text" else None - image_path: str | None = None - tmp_path: str | None = None - rendered = False - render_fallback = False - if message_format in {"markdown", "html"}: - import tempfile - - try: - html_str = content - if message_format == "markdown": - html_str = await render_markdown_to_html(content) - fd, tmp_path = tempfile.mkstemp(suffix=".png", prefix="naga_send_") - os.close(fd) - await render_html_to_image(html_str, tmp_path) - image_path = tmp_path - rendered = True - logger.info( - "[NagaSend] 富文本渲染成功: trace=%s naga_id=%s bind_uuid=%s key=%s fmt=%s image=%s", - trace_id, - naga_id, - bind_uuid, - message_key, - message_format, - Path(tmp_path).name if tmp_path is not None else "", - ) - except Exception as exc: - logger.warning( - "[NagaSend] 渲染失败,回退文本发送: trace=%s naga_id=%s bind_uuid=%s key=%s err=%s", - trace_id, - naga_id, - bind_uuid, - message_key, - exc, - ) - send_content = content - render_fallback = True - - sent_private = False - sent_group = False - group_policy_blocked = False - - async def _ensure_delivery_active() -> tuple[Any, Response | None]: - current_binding, live_err = await naga_store.ensure_delivery_active( - naga_id=naga_id, - bind_uuid=bind_uuid, - ) - if current_binding is None: - logger.warning( - "[NagaSend] 投递中止: trace=%s naga_id=%s bind_uuid=%s key=%s reason=%s", - trace_id, - naga_id, - bind_uuid, - message_key, - live_err.message - if live_err is not None - else "delivery no longer active", - ) - return None, web.json_response( - { - "ok": False, - "error": ( - live_err.message - if live_err is not None - else "delivery no longer active" - ), - "sent_private": sent_private, - 
"sent_group": sent_group, - "moderation": moderation, - }, - status=live_err.http_status if live_err is not None else 409, - ) - return current_binding, None - - try: - cq_image: str | None = None - if image_path is not None: - file_uri = Path(image_path).resolve().as_uri() - cq_image = f"[CQ:image,file={file_uri}]" - - if mode in {"private", "both"}: - current_binding, abort_response = await _ensure_delivery_active() - if abort_response is not None: - return abort_response - logger.info( - "[NagaSend] 私聊投递开始: trace=%s naga_id=%s bind_uuid=%s key=%s qq=%s", - trace_id, - naga_id, - bind_uuid, - message_key, - current_binding.qq_id, - ) - try: - if send_content is not None: - await sender.send_private_message( - current_binding.qq_id, send_content - ) - elif cq_image is not None: - await sender.send_private_message( - current_binding.qq_id, cq_image - ) - sent_private = True - logger.info( - "[NagaSend] 私聊投递成功: trace=%s naga_id=%s bind_uuid=%s key=%s qq=%s", - trace_id, - naga_id, - bind_uuid, - message_key, - current_binding.qq_id, - ) - except Exception as exc: - logger.warning( - "[NagaSend] 私聊发送失败: trace=%s naga_id=%s qq=%d key=%s err=%s", - trace_id, - naga_id, - current_binding.qq_id, - message_key, - exc, - ) - - if mode in {"group", "both"}: - current_binding, abort_response = await _ensure_delivery_active() - if abort_response is not None: - return abort_response - current_cfg = self._ctx.config_getter() - if current_binding.group_id not in current_cfg.naga.allowed_groups: - group_policy_blocked = True - logger.warning( - "[NagaSend] 群投递被策略阻止: trace=%s naga_id=%s bind_uuid=%s key=%s group=%s", - trace_id, - naga_id, - bind_uuid, - message_key, - current_binding.group_id, - ) - else: - logger.info( - "[NagaSend] 群投递开始: trace=%s naga_id=%s bind_uuid=%s key=%s group=%s", - trace_id, - naga_id, - bind_uuid, - message_key, - current_binding.group_id, - ) - try: - if send_content is not None: - await sender.send_group_message( - current_binding.group_id, 
send_content - ) - elif cq_image is not None: - await sender.send_group_message( - current_binding.group_id, cq_image - ) - sent_group = True - logger.info( - "[NagaSend] 群投递成功: trace=%s naga_id=%s bind_uuid=%s key=%s group=%s", - trace_id, - naga_id, - bind_uuid, - message_key, - current_binding.group_id, - ) - except Exception as exc: - logger.warning( - "[NagaSend] 群聊发送失败: trace=%s naga_id=%s group=%d key=%s err=%s", - trace_id, - naga_id, - current_binding.group_id, - message_key, - exc, - ) - finally: - if tmp_path is not None: - try: - os.unlink(tmp_path) - except OSError: - pass - - if mode == "private" and not sent_private: - return web.json_response( - { - "ok": False, - "error": "private delivery failed", - "sent_private": sent_private, - "sent_group": sent_group, - "moderation": moderation, - }, - status=502, - ) - if mode == "group" and not sent_group: - return web.json_response( - { - "ok": False, - "error": "group delivery failed", - "sent_private": sent_private, - "sent_group": sent_group, - "moderation": moderation, - }, - status=502, - ) - if mode == "both" and not (sent_private or sent_group): - if group_policy_blocked: - return web.json_response( - { - "ok": False, - "error": "bound group is not in naga.allowed_groups", - "sent_private": sent_private, - "sent_group": sent_group, - "moderation": moderation, - }, - status=403, - ) - return web.json_response( - { - "ok": False, - "error": "all deliveries failed", - "sent_private": sent_private, - "sent_group": sent_group, - "moderation": moderation, - }, - status=502, - ) - - await naga_store.record_usage(naga_id, bind_uuid=bind_uuid) - partial_success = mode == "both" and (sent_private != sent_group) - logger.info( - "[NagaSend] 请求完成: trace=%s naga_id=%s bind_uuid=%s key=%s sent_private=%s sent_group=%s partial=%s rendered=%s fallback=%s", - trace_id, - naga_id, - bind_uuid, - message_key, - sent_private, - sent_group, - partial_success, - rendered, - render_fallback, - ) - return 
web.json_response( - { - "ok": True, - "naga_id": naga_id, - "bind_uuid": bind_uuid, - "sent_private": sent_private, - "sent_group": sent_group, - "partial_success": partial_success, - "delivery_status": ( - "partial_success" if partial_success else "full_success" - ), - "rendered": rendered, - "render_fallback": render_fallback, - "moderation": moderation, - } - ) - finally: - await naga_store.release_delivery(bind_uuid=bind_uuid) async def _naga_unbind_handler(self, request: web.Request) -> Response: - """POST /api/v1/naga/unbind — 远端主动解绑。""" - trace_id = _uuid.uuid4().hex[:8] - auth_err = self._verify_naga_api_key(request) - if auth_err is not None: - logger.warning( - "[NagaUnbind] 鉴权失败: trace=%s remote=%s err=%s", - trace_id, - getattr(request, "remote", None), - auth_err, - ) - return _json_error("Unauthorized", status=401) - - try: - body = await request.json() - except Exception: - return _json_error("Invalid JSON", status=400) - - bind_uuid = str(body.get("bind_uuid", "") or "").strip() - naga_id = str(body.get("naga_id", "") or "").strip() - delivery_signature = str(body.get("delivery_signature", "") or "").strip() - if not bind_uuid or not naga_id or not delivery_signature: - return _json_error( - "bind_uuid, naga_id and delivery_signature are required", - status=400, - ) - logger.info( - "[NagaUnbind] 请求开始: trace=%s remote=%s naga_id=%s bind_uuid=%s signature=%s", - trace_id, - getattr(request, "remote", None), - naga_id, - bind_uuid, - delivery_signature[:12] + "...", - ) - - naga_store = self._ctx.naga_store - if naga_store is None: - return _json_error("Naga integration not available", status=503) - - binding, changed, err = await naga_store.revoke_binding( - naga_id, - expected_bind_uuid=bind_uuid, - delivery_signature=delivery_signature, - ) - if binding is None: - logger.warning( - "[NagaUnbind] 吊销失败: trace=%s naga_id=%s bind_uuid=%s err=%s", - trace_id, - naga_id, - bind_uuid, - err.message if err is not None else "binding not found", - ) - 
return _json_error( - err.message if err is not None else "binding not found", - status=err.http_status if err is not None else 404, - ) - logger.info( - "[NagaUnbind] 吊销完成: trace=%s naga_id=%s bind_uuid=%s changed=%s qq=%s group=%s", - trace_id, - naga_id, - bind_uuid, - changed, - binding.qq_id, - binding.group_id, - ) - return web.json_response( - { - "ok": True, - "idempotent": not changed, - "naga_id": naga_id, - "bind_uuid": bind_uuid, - } - ) + return await naga.naga_unbind_handler(self._ctx, request) diff --git a/src/Undefined/api/routes/__init__.py b/src/Undefined/api/routes/__init__.py new file mode 100644 index 0000000..8a999d1 --- /dev/null +++ b/src/Undefined/api/routes/__init__.py @@ -0,0 +1 @@ +"""Runtime API route modules.""" diff --git a/src/Undefined/api/routes/chat.py b/src/Undefined/api/routes/chat.py new file mode 100644 index 0000000..536435c --- /dev/null +++ b/src/Undefined/api/routes/chat.py @@ -0,0 +1,359 @@ +"""Chat route handlers extracted from the Runtime API application.""" + +from __future__ import annotations + +import asyncio +import logging +from contextlib import suppress +from datetime import datetime +from typing import Any, Awaitable, Callable + +from aiohttp import web +from aiohttp.web_response import Response + +from Undefined.api._context import RuntimeAPIContext +from Undefined.api._helpers import ( + _VIRTUAL_USER_ID, + _WebUIVirtualSender, + _build_chat_response_payload, + _json_error, + _sse_event, + _to_bool, +) +from Undefined.attachments import ( + attachment_refs_to_xml, + build_attachment_scope, + register_message_attachments, + render_message_with_pic_placeholders, +) +from Undefined.context import RequestContext +from Undefined.context_resource_registry import collect_context_resources +from Undefined.services.queue_manager import QUEUE_LANE_SUPERADMIN +from Undefined.utils.common import message_to_segments +from Undefined.utils.recent_messages import get_recent_messages_prefer_local +from Undefined.utils.xml 
import escape_xml_attr, escape_xml_text + +logger = logging.getLogger(__name__) + +_VIRTUAL_USER_NAME = "system" +_CHAT_SSE_KEEPALIVE_SECONDS = 10.0 + + +async def run_webui_chat( + ctx: RuntimeAPIContext, + *, + text: str, + send_output: Callable[[int, str], Awaitable[None]], +) -> str: + """Execute a single WebUI chat turn (command dispatch or AI ask).""" + + cfg = ctx.config_getter() + permission_sender_id = int(cfg.superadmin_qq) + webui_scope_key = build_attachment_scope( + user_id=_VIRTUAL_USER_ID, + request_type="private", + webui_session=True, + ) + input_segments = message_to_segments(text) + registered_input = await register_message_attachments( + registry=ctx.ai.attachment_registry, + segments=input_segments, + scope_key=webui_scope_key, + resolve_image_url=ctx.onebot.get_image, + get_forward_messages=ctx.onebot.get_forward_msg, + ) + normalized_text = registered_input.normalized_text or text + await ctx.history_manager.add_private_message( + user_id=_VIRTUAL_USER_ID, + text_content=normalized_text, + display_name=_VIRTUAL_USER_NAME, + user_name=_VIRTUAL_USER_NAME, + attachments=registered_input.attachments, + ) + + command = ctx.command_dispatcher.parse_command(normalized_text) + if command: + await ctx.command_dispatcher.dispatch_private( + user_id=_VIRTUAL_USER_ID, + sender_id=permission_sender_id, + command=command, + send_private_callback=send_output, + is_webui_session=True, + ) + return "command" + + current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S") + attachment_xml = ( + f"\n{attachment_refs_to_xml(registered_input.attachments)}" + if registered_input.attachments + else "" + ) + full_question = f""" + {escape_xml_text(normalized_text)}{attachment_xml} + + +【WebUI 会话】 +这是一条来自 WebUI 控制台的会话请求。 +会话身份:虚拟用户 system(42)。 +权限等级:superadmin(你可按最高管理权限处理)。 +请正常进行私聊对话;如果需要结束会话,调用 end 工具。""" + virtual_sender = _WebUIVirtualSender( + _VIRTUAL_USER_ID, send_output, onebot=ctx.onebot + ) + + async def _get_recent_cb( + chat_id: str, msg_type: str, start: 
int, end: int + ) -> list[dict[str, Any]]: + return await get_recent_messages_prefer_local( + chat_id=chat_id, + msg_type=msg_type, + start=start, + end=end, + onebot_client=ctx.onebot, + history_manager=ctx.history_manager, + bot_qq=cfg.bot_qq, + attachment_registry=getattr(ctx.ai, "attachment_registry", None), + ) + + async with RequestContext( + request_type="private", + user_id=_VIRTUAL_USER_ID, + sender_id=permission_sender_id, + ) as rctx: + ai_client = ctx.ai # noqa: F841 + memory_storage = ctx.ai.memory_storage # noqa: F841 + runtime_config = ctx.ai.runtime_config # noqa: F841 + sender = virtual_sender # noqa: F841 + history_manager = ctx.history_manager # noqa: F841 + onebot_client = ctx.onebot # noqa: F841 + scheduler = ctx.scheduler # noqa: F841 + + def send_message_callback( + msg: str, reply_to: int | None = None + ) -> Awaitable[None]: + _ = reply_to + return send_output(_VIRTUAL_USER_ID, msg) + + get_recent_messages_callback = _get_recent_cb # noqa: F841 + get_image_url_callback = ctx.onebot.get_image # noqa: F841 + get_forward_msg_callback = ctx.onebot.get_forward_msg # noqa: F841 + resource_vars = dict(globals()) + resource_vars.update(locals()) + resources = collect_context_resources(resource_vars) + for key, value in resources.items(): + if value is not None: + rctx.set_resource(key, value) + rctx.set_resource("queue_lane", QUEUE_LANE_SUPERADMIN) + rctx.set_resource("webui_session", True) + rctx.set_resource("webui_permission", "superadmin") + + result = await ctx.ai.ask( + full_question, + send_message_callback=send_message_callback, + get_recent_messages_callback=get_recent_messages_callback, + get_image_url_callback=get_image_url_callback, + get_forward_msg_callback=get_forward_msg_callback, + sender=sender, + history_manager=history_manager, + onebot_client=onebot_client, + scheduler=scheduler, + extra_context={ + "is_private_chat": True, + "request_type": "private", + "user_id": _VIRTUAL_USER_ID, + "sender_name": _VIRTUAL_USER_NAME, + 
"webui_session": True, + "webui_permission": "superadmin", + }, + ) + + final_reply = str(result or "").strip() + if final_reply: + await send_output(_VIRTUAL_USER_ID, final_reply) + + return "chat" + + +async def chat_history_handler( + ctx: RuntimeAPIContext, request: web.Request +) -> Response: + """Return recent WebUI chat history.""" + + limit_raw = str(request.query.get("limit", "200") or "200").strip() + try: + limit = int(limit_raw) + except ValueError: + limit = 200 + limit = max(1, min(limit, 500)) + + getter = getattr(ctx.history_manager, "get_recent_private", None) + if not callable(getter): + return _json_error("History manager not ready", status=503) + + records = getter(_VIRTUAL_USER_ID, limit) + items: list[dict[str, Any]] = [] + for item in records: + if not isinstance(item, dict): + continue + content = str(item.get("message", "")).strip() + if not content: + continue + display_name = str(item.get("display_name", "")).strip().lower() + role = "bot" if display_name == "bot" else "user" + items.append( + { + "role": role, + "content": content, + "timestamp": str(item.get("timestamp", "") or "").strip(), + } + ) + + return web.json_response( + { + "virtual_user_id": _VIRTUAL_USER_ID, + "permission": "superadmin", + "count": len(items), + "items": items, + } + ) + + +async def chat_handler( + ctx: RuntimeAPIContext, request: web.Request +) -> web.StreamResponse: + """Handle a WebUI chat request (non-streaming or SSE streaming).""" + + try: + body = await request.json() + except Exception: + return _json_error("Invalid JSON", status=400) + + text = str(body.get("message", "") or "").strip() + if not text: + return _json_error("message is required", status=400) + + stream = _to_bool(body.get("stream")) + outputs: list[str] = [] + webui_scope_key = build_attachment_scope( + user_id=_VIRTUAL_USER_ID, + request_type="private", + webui_session=True, + ) + + async def _capture_private_message(user_id: int, message: str) -> None: + _ = user_id + content = 
str(message or "").strip() + if not content: + return + rendered = await render_message_with_pic_placeholders( + content, + registry=ctx.ai.attachment_registry, + scope_key=webui_scope_key, + strict=False, + ) + if not rendered.delivery_text.strip(): + return + outputs.append(rendered.delivery_text) + await ctx.history_manager.add_private_message( + user_id=_VIRTUAL_USER_ID, + text_content=rendered.history_text, + display_name="Bot", + user_name="Bot", + attachments=rendered.attachments, + ) + + if not stream: + try: + mode = await run_webui_chat( + ctx, text=text, send_output=_capture_private_message + ) + except Exception as exc: + logger.exception("[RuntimeAPI] chat failed: %s", exc) + return _json_error("Chat failed", status=502) + return web.json_response(_build_chat_response_payload(mode, outputs)) + + response = web.StreamResponse( + status=200, + reason="OK", + headers={ + "Content-Type": "text/event-stream", + "Cache-Control": "no-cache", + "Connection": "keep-alive", + }, + ) + await response.prepare(request) + + message_queue: asyncio.Queue[str] = asyncio.Queue() + + async def _capture_private_message_stream(user_id: int, message: str) -> None: + output_count = len(outputs) + await _capture_private_message(user_id, message) + if len(outputs) <= output_count: + return + content = outputs[-1].strip() + if content: + await message_queue.put(content) + + task = asyncio.create_task( + run_webui_chat(ctx, text=text, send_output=_capture_private_message_stream) + ) + mode = "chat" + client_disconnected = False + try: + await response.write( + _sse_event( + "meta", + { + "virtual_user_id": _VIRTUAL_USER_ID, + "permission": "superadmin", + }, + ) + ) + + while True: + if request.transport is None or request.transport.is_closing(): + client_disconnected = True + break + if task.done() and message_queue.empty(): + break + try: + message = await asyncio.wait_for( + message_queue.get(), + timeout=_CHAT_SSE_KEEPALIVE_SECONDS, + ) + await 
response.write(_sse_event("message", {"content": message})) + except asyncio.TimeoutError: + await response.write(b": keep-alive\n\n") + + if client_disconnected: + task.cancel() + with suppress(asyncio.CancelledError): + await task + return response + + mode = await task + await response.write( + _sse_event("done", _build_chat_response_payload(mode, outputs)) + ) + except asyncio.CancelledError: + task.cancel() + with suppress(asyncio.CancelledError): + await task + raise + except (ConnectionResetError, RuntimeError): + task.cancel() + with suppress(asyncio.CancelledError): + await task + except Exception as exc: + logger.exception("[RuntimeAPI] chat stream failed: %s", exc) + if not task.done(): + task.cancel() + with suppress(asyncio.CancelledError): + await task + with suppress(Exception): + await response.write(_sse_event("error", {"error": str(exc)})) + finally: + with suppress(Exception): + await response.write_eof() + + return response diff --git a/src/Undefined/api/routes/cognitive.py b/src/Undefined/api/routes/cognitive.py new file mode 100644 index 0000000..6609699 --- /dev/null +++ b/src/Undefined/api/routes/cognitive.py @@ -0,0 +1,86 @@ +"""Cognitive event & profile routes.""" + +from __future__ import annotations + +from typing import Any + +from aiohttp import web +from aiohttp.web_response import Response + +from Undefined.api._context import RuntimeAPIContext +from Undefined.api._helpers import _json_error, _optional_query_param + + +async def cognitive_events_handler( + ctx: RuntimeAPIContext, request: web.Request +) -> Response: + cognitive_service = ctx.cognitive_service + if not cognitive_service or not cognitive_service.enabled: + return _json_error("Cognitive service disabled", status=400) + + query = str(request.query.get("q", "") or "").strip() + if not query: + return _json_error("q is required", status=400) + + search_kwargs: dict[str, Any] = {"query": query} + for key in ( + "target_user_id", + "target_group_id", + "sender_id", + 
"request_type", + "top_k", + "time_from", + "time_to", + ): + value = _optional_query_param(request, key) + if value is not None: + search_kwargs[key] = value + + results = await cognitive_service.search_events(**search_kwargs) + return web.json_response({"count": len(results), "items": results}) + + +async def cognitive_profiles_handler( + ctx: RuntimeAPIContext, request: web.Request +) -> Response: + cognitive_service = ctx.cognitive_service + if not cognitive_service or not cognitive_service.enabled: + return _json_error("Cognitive service disabled", status=400) + + query = str(request.query.get("q", "") or "").strip() + if not query: + return _json_error("q is required", status=400) + + search_kwargs: dict[str, Any] = {"query": query} + entity_type = _optional_query_param(request, "entity_type") + if entity_type is not None: + search_kwargs["entity_type"] = entity_type + top_k = _optional_query_param(request, "top_k") + if top_k is not None: + search_kwargs["top_k"] = top_k + + results = await cognitive_service.search_profiles(**search_kwargs) + return web.json_response({"count": len(results), "items": results}) + + +async def cognitive_profile_handler( + ctx: RuntimeAPIContext, request: web.Request +) -> Response: + cognitive_service = ctx.cognitive_service + if not cognitive_service or not cognitive_service.enabled: + return _json_error("Cognitive service disabled", status=400) + + entity_type = str(request.match_info.get("entity_type", "")).strip() + entity_id = str(request.match_info.get("entity_id", "")).strip() + if not entity_type or not entity_id: + return _json_error("entity_type/entity_id are required", status=400) + + profile = await cognitive_service.get_profile(entity_type, entity_id) + return web.json_response( + { + "entity_type": entity_type, + "entity_id": entity_id, + "profile": profile or "", + "found": bool(profile), + } + ) diff --git a/src/Undefined/api/routes/health.py b/src/Undefined/api/routes/health.py new file mode 100644 index 
0000000..bbdb020 --- /dev/null +++ b/src/Undefined/api/routes/health.py @@ -0,0 +1,23 @@ +"""Health check route.""" + +from __future__ import annotations + +from datetime import datetime + +from aiohttp.web_response import Response +from aiohttp import web + +from Undefined import __version__ +from Undefined.api._context import RuntimeAPIContext + + +async def health_handler(ctx: RuntimeAPIContext, request: web.Request) -> Response: + _ = ctx, request + return web.json_response( + { + "ok": True, + "service": "undefined-runtime-api", + "version": __version__, + "timestamp": datetime.now().isoformat(), + } + ) diff --git a/src/Undefined/api/routes/memes.py b/src/Undefined/api/routes/memes.py new file mode 100644 index 0000000..dcef98a --- /dev/null +++ b/src/Undefined/api/routes/memes.py @@ -0,0 +1,222 @@ +"""Meme management route handlers.""" + +from __future__ import annotations + +from typing import Any, cast + +from aiohttp import web +from aiohttp.web_response import Response + +from Undefined.api._context import RuntimeAPIContext +from Undefined.api._helpers import _json_error, _optional_query_param, _to_bool + + +async def meme_list_handler(ctx: RuntimeAPIContext, request: web.Request) -> Response: + meme_service = ctx.meme_service + if meme_service is None or not meme_service.enabled: + return _json_error("Meme service disabled", status=400) + + def _parse_optional_bool(name: str) -> bool | None: + raw = request.query.get(name) + if raw is None or str(raw).strip() == "": + return None + return _to_bool(raw) + + page_raw = _optional_query_param(request, "page") + page_size_raw = _optional_query_param(request, "page_size") + top_k_raw = _optional_query_param(request, "top_k") + query = str(request.query.get("q", "") or "").strip() + query_mode = str(request.query.get("query_mode", "") or "").strip().lower() + keyword_query = str(request.query.get("keyword_query", "") or "").strip() + semantic_query = str(request.query.get("semantic_query", "") or "").strip() + 
try: + page = int(page_raw) if page_raw is not None else 1 + page_size = int(page_size_raw) if page_size_raw is not None else 50 + top_k = int(top_k_raw) if top_k_raw is not None else page_size + except ValueError: + return _json_error("page/page_size/top_k must be integers", status=400) + page = max(1, page) + page_size = max(1, min(200, page_size)) + top_k = max(1, top_k) + sort = str(request.query.get("sort", "updated_at") or "updated_at").strip() + + enabled_filter = _parse_optional_bool("enabled") + animated_filter = _parse_optional_bool("animated") + pinned_filter = _parse_optional_bool("pinned") + if not (query or keyword_query or semantic_query) and sort == "relevance": + sort = "updated_at" + + if query or keyword_query or semantic_query: + has_post_filter = any( + f is not None for f in (enabled_filter, animated_filter, pinned_filter) + ) + requested_window = max(page * page_size, top_k) + if has_post_filter or page > 1 or sort != "relevance": + fetch_k = min(500, max(requested_window * 4, top_k)) + else: + fetch_k = min(500, requested_window) + search_payload = await meme_service.search_memes( + query, + query_mode=query_mode or meme_service.default_query_mode, + keyword_query=keyword_query or None, + semantic_query=semantic_query or None, + top_k=fetch_k, + include_disabled=enabled_filter is not True, + sort=sort, + ) + filtered_items: list[dict[str, Any]] = [] + for item in list(search_payload.get("items") or []): + if ( + enabled_filter is not None + and bool(item.get("enabled")) != enabled_filter + ): + continue + if ( + animated_filter is not None + and bool(item.get("is_animated")) != animated_filter + ): + continue + if pinned_filter is not None and bool(item.get("pinned")) != pinned_filter: + continue + filtered_items.append(item) + offset = (page - 1) * page_size + paged_items = filtered_items[offset : offset + page_size] + window_total = len(filtered_items) + fetched_window_count = len(list(search_payload.get("items") or [])) + window_exhausted 
= fetched_window_count < fetch_k + has_more = bool(paged_items) and ( + offset + page_size < window_total + or (not window_exhausted and window_total >= offset + page_size) + ) + return web.json_response( + { + "ok": True, + "total": None, + "window_total": window_total, + "total_exact": False, + "page": page, + "page_size": page_size, + "has_more": has_more, + "query_mode": search_payload.get("query_mode"), + "keyword_query": search_payload.get("keyword_query"), + "semantic_query": search_payload.get("semantic_query"), + "sort": search_payload.get("sort", sort), + "items": paged_items, + } + ) + + payload = await meme_service.list_memes( + query=query, + enabled=enabled_filter, + animated=animated_filter, + pinned=pinned_filter, + sort=sort, + page=page, + page_size=page_size, + summary=True, + ) + return web.json_response(payload) + + +async def meme_stats_handler(ctx: RuntimeAPIContext, request: web.Request) -> Response: + _ = request + meme_service = ctx.meme_service + if meme_service is None or not meme_service.enabled: + return _json_error("Meme service disabled", status=400) + return web.json_response(await meme_service.stats()) + + +async def meme_detail_handler(ctx: RuntimeAPIContext, request: web.Request) -> Response: + meme_service = ctx.meme_service + if meme_service is None or not meme_service.enabled: + return _json_error("Meme service disabled", status=400) + uid = str(request.match_info.get("uid", "")).strip() + detail = await meme_service.get_meme(uid) + if detail is None: + return _json_error("Meme not found", status=404) + return web.json_response(detail) + + +async def meme_blob_handler(ctx: RuntimeAPIContext, request: web.Request) -> Response: + meme_service = ctx.meme_service + if meme_service is None or not meme_service.enabled: + return _json_error("Meme service disabled", status=400) + uid = str(request.match_info.get("uid", "")).strip() + path = await meme_service.blob_path_for_uid(uid, preview=False) + if path is None: + return 
_json_error("Meme blob not found", status=404) + return cast(Response, web.FileResponse(path=path)) + + +async def meme_preview_handler( + ctx: RuntimeAPIContext, request: web.Request +) -> Response: + meme_service = ctx.meme_service + if meme_service is None or not meme_service.enabled: + return _json_error("Meme service disabled", status=400) + uid = str(request.match_info.get("uid", "")).strip() + path = await meme_service.blob_path_for_uid(uid, preview=True) + if path is None: + return _json_error("Meme preview not found", status=404) + return cast(Response, web.FileResponse(path=path)) + + +async def meme_update_handler(ctx: RuntimeAPIContext, request: web.Request) -> Response: + meme_service = ctx.meme_service + if meme_service is None or not meme_service.enabled: + return _json_error("Meme service disabled", status=400) + uid = str(request.match_info.get("uid", "")).strip() + try: + payload = await request.json() + except Exception: + return _json_error("Invalid JSON body", status=400) + if not isinstance(payload, dict): + return _json_error("JSON body must be an object", status=400) + updated = await meme_service.update_meme( + uid, + manual_description=payload.get("manual_description"), + tags=payload.get("tags"), + aliases=payload.get("aliases"), + enabled=payload.get("enabled") if "enabled" in payload else None, + pinned=payload.get("pinned") if "pinned" in payload else None, + ) + if updated is None: + return _json_error("Meme not found", status=404) + return web.json_response({"ok": True, "record": updated}) + + +async def meme_delete_handler(ctx: RuntimeAPIContext, request: web.Request) -> Response: + meme_service = ctx.meme_service + if meme_service is None or not meme_service.enabled: + return _json_error("Meme service disabled", status=400) + uid = str(request.match_info.get("uid", "")).strip() + deleted = await meme_service.delete_meme(uid) + if not deleted: + return _json_error("Meme not found", status=404) + return web.json_response({"ok": True, 
"uid": uid}) + + +async def meme_reanalyze_handler( + ctx: RuntimeAPIContext, request: web.Request +) -> Response: + meme_service = ctx.meme_service + if meme_service is None or not meme_service.enabled: + return _json_error("Meme service disabled", status=400) + uid = str(request.match_info.get("uid", "")).strip() + job_id = await meme_service.enqueue_reanalyze(uid) + if not job_id: + return _json_error("Meme queue unavailable", status=503) + return web.json_response({"ok": True, "uid": uid, "job_id": job_id}) + + +async def meme_reindex_handler( + ctx: RuntimeAPIContext, request: web.Request +) -> Response: + meme_service = ctx.meme_service + if meme_service is None or not meme_service.enabled: + return _json_error("Meme service disabled", status=400) + uid = str(request.match_info.get("uid", "")).strip() + job_id = await meme_service.enqueue_reindex(uid) + if not job_id: + return _json_error("Meme queue unavailable", status=503) + return web.json_response({"ok": True, "uid": uid, "job_id": job_id}) diff --git a/src/Undefined/api/routes/memory.py b/src/Undefined/api/routes/memory.py new file mode 100644 index 0000000..30c114e --- /dev/null +++ b/src/Undefined/api/routes/memory.py @@ -0,0 +1,156 @@ +"""Memory CRUD routes.""" + +from __future__ import annotations + +from contextlib import suppress +from typing import Any + +from aiohttp import web +from aiohttp.web_response import Response + +from Undefined.api._context import RuntimeAPIContext +from Undefined.api._helpers import _json_error, _optional_query_param, _parse_query_time + + +async def memory_list_handler(ctx: RuntimeAPIContext, request: web.Request) -> Response: + query = str(request.query.get("q", "") or "").strip().lower() + top_k_raw = _optional_query_param(request, "top_k") + time_from_raw = _optional_query_param(request, "time_from") + time_to_raw = _optional_query_param(request, "time_to") + memory_storage = getattr(ctx.ai, "memory_storage", None) + if memory_storage is None: + return 
_json_error("Memory storage not ready", status=503) + + limit: int | None = None + if top_k_raw is not None: + try: + limit = int(top_k_raw) + except ValueError: + return _json_error("top_k must be an integer", status=400) + if limit <= 0: + return _json_error("top_k must be > 0", status=400) + + time_from_dt = _parse_query_time(time_from_raw) + if time_from_raw is not None and time_from_dt is None: + return _json_error("time_from must be ISO datetime", status=400) + time_to_dt = _parse_query_time(time_to_raw) + if time_to_raw is not None and time_to_dt is None: + return _json_error("time_to must be ISO datetime", status=400) + if time_from_dt and time_to_dt and time_from_dt > time_to_dt: + time_from_dt, time_to_dt = time_to_dt, time_from_dt + + records = memory_storage.get_all() + items: list[dict[str, Any]] = [] + for item in records: + created_at = str(item.created_at or "").strip() + created_dt = _parse_query_time(created_at) + if time_from_dt and created_dt and created_dt < time_from_dt: + continue + if time_to_dt and created_dt and created_dt > time_to_dt: + continue + if (time_from_dt or time_to_dt) and created_dt is None: + continue + items.append( + { + "uuid": item.uuid, + "fact": item.fact, + "created_at": created_at, + } + ) + if query: + items = [ + item + for item in items + if query in str(item.get("fact", "")).lower() + or query in str(item.get("uuid", "")).lower() + ] + + def _created_sort_key(item: dict[str, Any]) -> float: + created_dt = _parse_query_time(str(item.get("created_at") or "")) + if created_dt is None: + return float("-inf") + with suppress(OSError, OverflowError, ValueError): + return float(created_dt.timestamp()) + return float("-inf") + + items.sort(key=_created_sort_key) + if limit is not None: + items = items[:limit] + + return web.json_response( + { + "total": len(items), + "items": items, + "query": { + "q": query or "", + "top_k": limit, + "time_from": time_from_raw, + "time_to": time_to_raw, + }, + } + ) + + +async def 
memory_create_handler( + ctx: RuntimeAPIContext, request: web.Request +) -> Response: + memory_storage = getattr(ctx.ai, "memory_storage", None) + if memory_storage is None: + return _json_error("Memory storage not ready", status=503) + try: + body = await request.json() + except Exception: + return _json_error("Invalid JSON", status=400) + fact = str(body.get("fact", "") or "").strip() + if not fact: + return _json_error("fact must not be empty", status=400) + new_uuid = await memory_storage.add(fact) + if new_uuid is None: + return _json_error("Failed to create memory", status=500) + existing = [m for m in memory_storage.get_all() if m.uuid == new_uuid] + item = existing[0] if existing else None + return web.json_response( + { + "uuid": new_uuid, + "fact": item.fact if item else fact, + "created_at": item.created_at if item else "", + }, + status=201, + ) + + +async def memory_update_handler( + ctx: RuntimeAPIContext, request: web.Request +) -> Response: + memory_storage = getattr(ctx.ai, "memory_storage", None) + if memory_storage is None: + return _json_error("Memory storage not ready", status=503) + target_uuid = str(request.match_info.get("uuid", "")).strip() + if not target_uuid: + return _json_error("uuid is required", status=400) + try: + body = await request.json() + except Exception: + return _json_error("Invalid JSON", status=400) + fact = str(body.get("fact", "") or "").strip() + if not fact: + return _json_error("fact must not be empty", status=400) + ok = await memory_storage.update(target_uuid, fact) + if not ok: + return _json_error(f"Memory {target_uuid} not found", status=404) + return web.json_response({"uuid": target_uuid, "fact": fact, "updated": True}) + + +async def memory_delete_handler( + ctx: RuntimeAPIContext, request: web.Request +) -> Response: + memory_storage = getattr(ctx.ai, "memory_storage", None) + if memory_storage is None: + return _json_error("Memory storage not ready", status=503) + target_uuid = 
str(request.match_info.get("uuid", "")).strip() + if not target_uuid: + return _json_error("uuid is required", status=400) + ok = await memory_storage.delete(target_uuid) + if not ok: + return _json_error(f"Memory {target_uuid} not found", status=404) + return web.json_response({"uuid": target_uuid, "deleted": True}) diff --git a/src/Undefined/api/routes/naga.py b/src/Undefined/api/routes/naga.py new file mode 100644 index 0000000..b76375d --- /dev/null +++ b/src/Undefined/api/routes/naga.py @@ -0,0 +1,897 @@ +"""Naga integration route handlers. + +Extracted from ``RuntimeAPI`` methods into free functions so they can be +registered declaratively in the route table. +""" + +from __future__ import annotations + +import logging +import os +import uuid as _uuid +from copy import deepcopy +from pathlib import Path +from typing import Any + +from aiohttp import web +from aiohttp.web_response import Response + +from Undefined.api._context import RuntimeAPIContext +from Undefined.api._helpers import ( + _json_error, + _naga_message_digest, + _parse_response_payload, + _short_text_preview, +) +from Undefined.api._naga_state import NagaState +from Undefined.render import render_html_to_image, render_markdown_to_html + +logger = logging.getLogger(__name__) + + +# ------------------------------------------------------------------ +# Auth helper +# ------------------------------------------------------------------ + + +def verify_naga_api_key(ctx: RuntimeAPIContext, request: web.Request) -> str | None: + """校验 Naga 共享密钥,返回错误信息或 ``None`` 表示通过。""" + import secrets as _secrets + + cfg = ctx.config_getter() + expected = cfg.naga.api_key + if not expected: + return "naga api_key not configured" + auth_header = request.headers.get("Authorization", "") + if not auth_header.startswith("Bearer "): + return "missing or invalid Authorization header" + provided = auth_header[7:] + if not _secrets.compare_digest(provided, expected): + return "invalid api_key" + return None + + +# 
------------------------------------------------------------------ +# POST /api/v1/naga/bind/callback +# ------------------------------------------------------------------ + + +async def naga_bind_callback_handler( + ctx: RuntimeAPIContext, request: web.Request +) -> Response: + """POST /api/v1/naga/bind/callback — Naga 绑定回调。""" + trace_id = _uuid.uuid4().hex[:8] + auth_err = verify_naga_api_key(ctx, request) + if auth_err is not None: + logger.warning( + "[NagaBindCallback] 鉴权失败: trace=%s remote=%s err=%s", + trace_id, + getattr(request, "remote", None), + auth_err, + ) + return _json_error("Unauthorized", status=401) + + try: + body = await request.json() + except Exception: + return _json_error("Invalid JSON", status=400) + + bind_uuid = str(body.get("bind_uuid", "") or "").strip() + naga_id = str(body.get("naga_id", "") or "").strip() + status = str(body.get("status", "") or "").strip().lower() + delivery_signature = str(body.get("delivery_signature", "") or "").strip() + reason = str(body.get("reason", "") or "").strip() + if not bind_uuid or not naga_id: + return _json_error("bind_uuid and naga_id are required", status=400) + if status not in {"approved", "rejected"}: + return _json_error("status must be 'approved' or 'rejected'", status=400) + logger.info( + "[NagaBindCallback] 请求开始: trace=%s remote=%s naga_id=%s bind_uuid=%s status=%s reason=%s signature=%s", + trace_id, + getattr(request, "remote", None), + naga_id, + bind_uuid, + status, + _short_text_preview(reason, limit=60), + delivery_signature[:12] + "..." 
if delivery_signature else "", + ) + + naga_store = ctx.naga_store + if naga_store is None: + return _json_error("Naga integration not available", status=503) + + sender = ctx.sender + if status == "approved": + if not delivery_signature: + return _json_error( + "delivery_signature is required when approved", status=400 + ) + binding, created, err = await naga_store.activate_binding( + bind_uuid=bind_uuid, + naga_id=naga_id, + delivery_signature=delivery_signature, + ) + if err: + logger.warning( + "[NagaBindCallback] 激活失败: trace=%s naga_id=%s bind_uuid=%s err=%s", + trace_id, + naga_id, + bind_uuid, + err.message, + ) + return _json_error(err.message, status=err.http_status) + logger.info( + "[NagaBindCallback] 激活完成: trace=%s naga_id=%s bind_uuid=%s created=%s qq=%s", + trace_id, + naga_id, + bind_uuid, + created, + binding.qq_id if binding is not None else "", + ) + if created and binding is not None and sender is not None: + try: + await sender.send_private_message( + binding.qq_id, + f"🎉 你的 Naga 绑定已生效\nnaga_id: {naga_id}", + ) + except Exception as exc: + logger.warning("[NagaBindCallback] 通知绑定成功失败: %s", exc) + return web.json_response( + { + "ok": True, + "status": "approved", + "idempotent": not created, + "naga_id": naga_id, + "bind_uuid": bind_uuid, + } + ) + + # --- rejected --- + pending, removed, err = await naga_store.reject_binding( + bind_uuid=bind_uuid, + naga_id=naga_id, + reason=reason, + ) + if err: + logger.warning( + "[NagaBindCallback] 拒绝失败: trace=%s naga_id=%s bind_uuid=%s err=%s", + trace_id, + naga_id, + bind_uuid, + err.message, + ) + return _json_error(err.message, status=err.http_status) + logger.info( + "[NagaBindCallback] 拒绝完成: trace=%s naga_id=%s bind_uuid=%s removed=%s qq=%s", + trace_id, + naga_id, + bind_uuid, + removed, + pending.qq_id if pending is not None else "", + ) + if removed and pending is not None and sender is not None: + try: + detail = f"\n原因: {reason}" if reason else "" + await sender.send_private_message( + 
pending.qq_id, + f"❌ 你的 Naga 绑定被远端拒绝\nnaga_id: {naga_id}{detail}", + ) + except Exception as exc: + logger.warning("[NagaBindCallback] 通知绑定拒绝失败: %s", exc) + return web.json_response( + { + "ok": True, + "status": "rejected", + "idempotent": not removed, + "naga_id": naga_id, + "bind_uuid": bind_uuid, + } + ) + + +# ------------------------------------------------------------------ +# POST /api/v1/naga/messages/send +# ------------------------------------------------------------------ + + +async def naga_messages_send_handler( + ctx: RuntimeAPIContext, + naga_state: NagaState, + request: web.Request, +) -> Response: + """POST /api/v1/naga/messages/send — 验签后发送消息。""" + from Undefined.api.naga_store import mask_token + + trace_id = _uuid.uuid4().hex[:8] + auth_err = verify_naga_api_key(ctx, request) + if auth_err is not None: + logger.warning("[NagaSend] 鉴权失败: trace=%s err=%s", trace_id, auth_err) + return _json_error("Unauthorized", status=401) + + try: + body = await request.json() + except Exception: + return _json_error("Invalid JSON", status=400) + + bind_uuid = str(body.get("bind_uuid", "") or "").strip() + naga_id = str(body.get("naga_id", "") or "").strip() + delivery_signature = str(body.get("delivery_signature", "") or "").strip() + request_uuid = str(body.get("uuid", "") or "").strip() + target = body.get("target") + message = body.get("message") + if not bind_uuid or not naga_id or not delivery_signature: + return _json_error( + "bind_uuid, naga_id and delivery_signature are required", + status=400, + ) + if not isinstance(target, dict): + return _json_error("target object is required", status=400) + if not isinstance(message, dict): + return _json_error("message object is required", status=400) + + raw_target_qq = target.get("qq_id") + raw_target_group = target.get("group_id") + if raw_target_qq is None or raw_target_group is None: + return _json_error("target.qq_id and target.group_id are required", status=400) + try: + target_qq = int(raw_target_qq) + 
target_group = int(raw_target_group) + except Exception: + return _json_error( + "target.qq_id and target.group_id must be integers", status=400 + ) + mode = str(target.get("mode", "") or "").strip().lower() + if mode not in {"private", "group", "both"}: + return _json_error( + "target.mode must be 'private', 'group', or 'both'", status=400 + ) + + fmt = str(message.get("format", "text") or "text").strip().lower() + content = str(message.get("content", "") or "").strip() + if fmt not in {"text", "markdown", "html"}: + return _json_error( + "message.format must be 'text', 'markdown', or 'html'", status=400 + ) + if not content: + return _json_error("message.content is required", status=400) + + message_key = _naga_message_digest( + bind_uuid=bind_uuid, + naga_id=naga_id, + target_qq=target_qq, + target_group=target_group, + mode=mode, + message_format=fmt, + content=content, + ) + logger.info( + "[NagaSend] 请求开始: trace=%s remote=%s naga_id=%s bind_uuid=%s request_uuid=%s mode=%s fmt=%s qq=%s group=%s key=%s content_len=%s preview=%s signature=%s", + trace_id, + getattr(request, "remote", None), + naga_id, + bind_uuid, + request_uuid, + mode, + fmt, + target_qq, + target_group, + message_key, + len(content), + _short_text_preview(content), + mask_token(delivery_signature), + ) + if mode == "both": + logger.warning( + "[NagaSend] 上游请求显式要求双路投递: trace=%s naga_id=%s bind_uuid=%s request_uuid=%s key=%s", + trace_id, + naga_id, + bind_uuid, + request_uuid, + message_key, + ) + inflight_count = await naga_state.track_send_start(message_key) + if inflight_count > 1: + logger.warning( + "[NagaSend] 检测到相同 payload 并发请求: trace=%s naga_id=%s bind_uuid=%s request_uuid=%s key=%s inflight=%s", + trace_id, + naga_id, + bind_uuid, + request_uuid, + message_key, + inflight_count, + ) + try: + if request_uuid: + dedupe_action, dedupe_value = await naga_state.register_request_uuid( + request_uuid, message_key + ) + if dedupe_action == "conflict": + logger.warning( + "[NagaSend] uuid 与历史 
payload 冲突: trace=%s naga_id=%s bind_uuid=%s uuid=%s key=%s", + trace_id, + naga_id, + bind_uuid, + request_uuid, + message_key, + ) + return _json_error("uuid reused with different payload", status=409) + if dedupe_action == "cached": + cached_status, cached_payload = dedupe_value + logger.warning( + "[NagaSend] 命中已完成幂等结果,直接复用: trace=%s naga_id=%s bind_uuid=%s request_uuid=%s key=%s", + trace_id, + naga_id, + bind_uuid, + request_uuid, + message_key, + ) + return web.json_response( + deepcopy(cached_payload), + status=int(cached_status), + ) + if dedupe_action == "await": + wait_future = dedupe_value + logger.warning( + "[NagaSend] 命中进行中幂等请求,等待首个结果: trace=%s naga_id=%s bind_uuid=%s request_uuid=%s key=%s", + trace_id, + naga_id, + bind_uuid, + request_uuid, + message_key, + ) + cached_status, cached_payload = await wait_future + return web.json_response( + deepcopy(cached_payload), + status=int(cached_status), + ) + + response = await naga_messages_send_impl( + ctx, + naga_id=naga_id, + bind_uuid=bind_uuid, + delivery_signature=delivery_signature, + target_qq=target_qq, + target_group=target_group, + mode=mode, + message_format=fmt, + content=content, + trace_id=trace_id, + message_key=message_key, + ) + if request_uuid: + await naga_state.finish_request_uuid( + request_uuid, + message_key, + status=response.status, + payload=_parse_response_payload(response), + ) + return response + except Exception as exc: + if request_uuid: + await naga_state.fail_request_uuid(request_uuid, message_key, exc) + raise + finally: + remaining = await naga_state.track_send_done(message_key) + logger.info( + "[NagaSend] 请求退出: trace=%s naga_id=%s bind_uuid=%s request_uuid=%s key=%s inflight_remaining=%s", + trace_id, + naga_id, + bind_uuid, + request_uuid, + message_key, + remaining, + ) + + +# ------------------------------------------------------------------ +# Core send implementation (no NagaState dependency) +# ------------------------------------------------------------------ + 
+ +async def naga_messages_send_impl( + ctx: RuntimeAPIContext, + *, + naga_id: str, + bind_uuid: str, + delivery_signature: str, + target_qq: int, + target_group: int, + mode: str, + message_format: str, + content: str, + trace_id: str, + message_key: str, +) -> Response: + from Undefined.api.naga_store import mask_token + + naga_store = ctx.naga_store + if naga_store is None: + logger.warning( + "[NagaSend] NagaStore 不可用: trace=%s naga_id=%s bind_uuid=%s", + trace_id, + naga_id, + bind_uuid, + ) + return _json_error("Naga integration not available", status=503) + + binding, err_msg = await naga_store.acquire_delivery( + naga_id=naga_id, + bind_uuid=bind_uuid, + delivery_signature=delivery_signature, + ) + if binding is None: + logger.warning( + "[NagaSend] 签名校验失败: trace=%s naga_id=%s bind_uuid=%s reason=%s signature=%s", + trace_id, + naga_id, + bind_uuid, + err_msg.message if err_msg is not None else "unknown_error", + mask_token(delivery_signature), + ) + return _json_error( + err_msg.message if err_msg is not None else "delivery not available", + status=err_msg.http_status if err_msg is not None else 403, + ) + + logger.info( + "[NagaSend] 投递凭证已占用: trace=%s naga_id=%s bind_uuid=%s key=%s qq=%s group=%s", + trace_id, + naga_id, + bind_uuid, + message_key, + binding.qq_id, + binding.group_id, + ) + try: + if target_qq != binding.qq_id or target_group != binding.group_id: + logger.warning( + "[NagaSend] 目标不匹配: trace=%s naga_id=%s bind_uuid=%s target_qq=%s target_group=%s bound_qq=%s bound_group=%s", + trace_id, + naga_id, + bind_uuid, + target_qq, + target_group, + binding.qq_id, + binding.group_id, + ) + return _json_error("target does not match bound qq/group", status=403) + + cfg = ctx.config_getter() + if mode == "group" and binding.group_id not in cfg.naga.allowed_groups: + logger.warning( + "[NagaSend] 群投递被策略拒绝: trace=%s naga_id=%s bind_uuid=%s group=%s", + trace_id, + naga_id, + bind_uuid, + binding.group_id, + ) + return _json_error("bound group is not in 
naga.allowed_groups", status=403) + + sender = ctx.sender + if sender is None: + logger.warning( + "[NagaSend] sender 不可用: trace=%s naga_id=%s bind_uuid=%s", + trace_id, + naga_id, + bind_uuid, + ) + return _json_error("sender not available", status=503) + + moderation: dict[str, Any] + naga_cfg = getattr(cfg, "naga", None) + moderation_enabled = bool(getattr(naga_cfg, "moderation_enabled", True)) + security = getattr(ctx.command_dispatcher, "security", None) + if not moderation_enabled: + moderation = { + "status": "skipped_disabled", + "blocked": False, + "categories": [], + "message": "Naga moderation disabled by config; message sent without moderation block", + "model_name": "", + } + logger.warning( + "[NagaSend] 审核已禁用,直接放行: trace=%s naga_id=%s bind_uuid=%s key=%s", + trace_id, + naga_id, + bind_uuid, + message_key, + ) + elif security is None or not hasattr(security, "moderate_naga_message"): + moderation = { + "status": "error_allowed", + "blocked": False, + "categories": [], + "message": "Naga moderation service unavailable; message sent without moderation block", + "model_name": "", + } + logger.warning( + "[NagaSend] 审核服务不可用,按允许发送: trace=%s naga_id=%s bind_uuid=%s", + trace_id, + naga_id, + bind_uuid, + ) + else: + logger.info( + "[NagaSend] 审核开始: trace=%s naga_id=%s bind_uuid=%s key=%s fmt=%s content_len=%s", + trace_id, + naga_id, + bind_uuid, + message_key, + message_format, + len(content), + ) + result = await security.moderate_naga_message( + message_format=message_format, + content=content, + ) + moderation = { + "status": result.status, + "blocked": result.blocked, + "categories": result.categories, + "message": result.message, + "model_name": result.model_name, + } + logger.info( + "[NagaSend] 审核完成: trace=%s naga_id=%s bind_uuid=%s key=%s blocked=%s status=%s model=%s categories=%s", + trace_id, + naga_id, + bind_uuid, + message_key, + result.blocked, + result.status, + result.model_name, + ",".join(result.categories) or "-", + ) + if 
moderation["blocked"]: + logger.warning( + "[NagaSend] 审核拦截: trace=%s naga_id=%s bind_uuid=%s key=%s reason=%s", + trace_id, + naga_id, + bind_uuid, + message_key, + moderation["message"], + ) + return web.json_response( + { + "ok": False, + "error": "message blocked by moderation", + "moderation": moderation, + }, + status=403, + ) + + send_content: str | None = content if message_format == "text" else None + image_path: str | None = None + tmp_path: str | None = None + rendered = False + render_fallback = False + if message_format in {"markdown", "html"}: + import tempfile + + try: + html_str = content + if message_format == "markdown": + html_str = await render_markdown_to_html(content) + fd, tmp_path = tempfile.mkstemp(suffix=".png", prefix="naga_send_") + os.close(fd) + await render_html_to_image(html_str, tmp_path) + image_path = tmp_path + rendered = True + logger.info( + "[NagaSend] 富文本渲染成功: trace=%s naga_id=%s bind_uuid=%s key=%s fmt=%s image=%s", + trace_id, + naga_id, + bind_uuid, + message_key, + message_format, + Path(tmp_path).name if tmp_path is not None else "", + ) + except Exception as exc: + logger.warning( + "[NagaSend] 渲染失败,回退文本发送: trace=%s naga_id=%s bind_uuid=%s key=%s err=%s", + trace_id, + naga_id, + bind_uuid, + message_key, + exc, + ) + send_content = content + render_fallback = True + + sent_private = False + sent_group = False + group_policy_blocked = False + + async def _ensure_delivery_active() -> tuple[Any, Response | None]: + current_binding, live_err = await naga_store.ensure_delivery_active( + naga_id=naga_id, + bind_uuid=bind_uuid, + ) + if current_binding is None: + logger.warning( + "[NagaSend] 投递中止: trace=%s naga_id=%s bind_uuid=%s key=%s reason=%s", + trace_id, + naga_id, + bind_uuid, + message_key, + live_err.message + if live_err is not None + else "delivery no longer active", + ) + return None, web.json_response( + { + "ok": False, + "error": ( + live_err.message + if live_err is not None + else "delivery no longer active" 
+ ), + "sent_private": sent_private, + "sent_group": sent_group, + "moderation": moderation, + }, + status=live_err.http_status if live_err is not None else 409, + ) + return current_binding, None + + try: + cq_image: str | None = None + if image_path is not None: + file_uri = Path(image_path).resolve().as_uri() + cq_image = f"[CQ:image,file={file_uri}]" + + if mode in {"private", "both"}: + current_binding, abort_response = await _ensure_delivery_active() + if abort_response is not None: + return abort_response + logger.info( + "[NagaSend] 私聊投递开始: trace=%s naga_id=%s bind_uuid=%s key=%s qq=%s", + trace_id, + naga_id, + bind_uuid, + message_key, + current_binding.qq_id, + ) + try: + if send_content is not None: + await sender.send_private_message( + current_binding.qq_id, send_content + ) + elif cq_image is not None: + await sender.send_private_message( + current_binding.qq_id, cq_image + ) + sent_private = True + logger.info( + "[NagaSend] 私聊投递成功: trace=%s naga_id=%s bind_uuid=%s key=%s qq=%s", + trace_id, + naga_id, + bind_uuid, + message_key, + current_binding.qq_id, + ) + except Exception as exc: + logger.warning( + "[NagaSend] 私聊发送失败: trace=%s naga_id=%s qq=%d key=%s err=%s", + trace_id, + naga_id, + current_binding.qq_id, + message_key, + exc, + ) + + if mode in {"group", "both"}: + current_binding, abort_response = await _ensure_delivery_active() + if abort_response is not None: + return abort_response + current_cfg = ctx.config_getter() + if current_binding.group_id not in current_cfg.naga.allowed_groups: + group_policy_blocked = True + logger.warning( + "[NagaSend] 群投递被策略阻止: trace=%s naga_id=%s bind_uuid=%s key=%s group=%s", + trace_id, + naga_id, + bind_uuid, + message_key, + current_binding.group_id, + ) + else: + logger.info( + "[NagaSend] 群投递开始: trace=%s naga_id=%s bind_uuid=%s key=%s group=%s", + trace_id, + naga_id, + bind_uuid, + message_key, + current_binding.group_id, + ) + try: + if send_content is not None: + await sender.send_group_message( + 
current_binding.group_id, send_content + ) + elif cq_image is not None: + await sender.send_group_message( + current_binding.group_id, cq_image + ) + sent_group = True + logger.info( + "[NagaSend] 群投递成功: trace=%s naga_id=%s bind_uuid=%s key=%s group=%s", + trace_id, + naga_id, + bind_uuid, + message_key, + current_binding.group_id, + ) + except Exception as exc: + logger.warning( + "[NagaSend] 群聊发送失败: trace=%s naga_id=%s group=%d key=%s err=%s", + trace_id, + naga_id, + current_binding.group_id, + message_key, + exc, + ) + finally: + if tmp_path is not None: + try: + os.unlink(tmp_path) + except OSError: + pass + + if mode == "private" and not sent_private: + return web.json_response( + { + "ok": False, + "error": "private delivery failed", + "sent_private": sent_private, + "sent_group": sent_group, + "moderation": moderation, + }, + status=502, + ) + if mode == "group" and not sent_group: + return web.json_response( + { + "ok": False, + "error": "group delivery failed", + "sent_private": sent_private, + "sent_group": sent_group, + "moderation": moderation, + }, + status=502, + ) + if mode == "both" and not (sent_private or sent_group): + if group_policy_blocked: + return web.json_response( + { + "ok": False, + "error": "bound group is not in naga.allowed_groups", + "sent_private": sent_private, + "sent_group": sent_group, + "moderation": moderation, + }, + status=403, + ) + return web.json_response( + { + "ok": False, + "error": "all deliveries failed", + "sent_private": sent_private, + "sent_group": sent_group, + "moderation": moderation, + }, + status=502, + ) + + await naga_store.record_usage(naga_id, bind_uuid=bind_uuid) + partial_success = mode == "both" and (sent_private != sent_group) + logger.info( + "[NagaSend] 请求完成: trace=%s naga_id=%s bind_uuid=%s key=%s sent_private=%s sent_group=%s partial=%s rendered=%s fallback=%s", + trace_id, + naga_id, + bind_uuid, + message_key, + sent_private, + sent_group, + partial_success, + rendered, + render_fallback, + ) 
+ return web.json_response( + { + "ok": True, + "naga_id": naga_id, + "bind_uuid": bind_uuid, + "sent_private": sent_private, + "sent_group": sent_group, + "partial_success": partial_success, + "delivery_status": ( + "partial_success" if partial_success else "full_success" + ), + "rendered": rendered, + "render_fallback": render_fallback, + "moderation": moderation, + } + ) + finally: + await naga_store.release_delivery(bind_uuid=bind_uuid) + + +# ------------------------------------------------------------------ +# POST /api/v1/naga/unbind +# ------------------------------------------------------------------ + + +async def naga_unbind_handler(ctx: RuntimeAPIContext, request: web.Request) -> Response: + """POST /api/v1/naga/unbind — 远端主动解绑。""" + trace_id = _uuid.uuid4().hex[:8] + auth_err = verify_naga_api_key(ctx, request) + if auth_err is not None: + logger.warning( + "[NagaUnbind] 鉴权失败: trace=%s remote=%s err=%s", + trace_id, + getattr(request, "remote", None), + auth_err, + ) + return _json_error("Unauthorized", status=401) + + try: + body = await request.json() + except Exception: + return _json_error("Invalid JSON", status=400) + + bind_uuid = str(body.get("bind_uuid", "") or "").strip() + naga_id = str(body.get("naga_id", "") or "").strip() + delivery_signature = str(body.get("delivery_signature", "") or "").strip() + if not bind_uuid or not naga_id or not delivery_signature: + return _json_error( + "bind_uuid, naga_id and delivery_signature are required", + status=400, + ) + logger.info( + "[NagaUnbind] 请求开始: trace=%s remote=%s naga_id=%s bind_uuid=%s signature=%s", + trace_id, + getattr(request, "remote", None), + naga_id, + bind_uuid, + delivery_signature[:12] + "...", + ) + + naga_store = ctx.naga_store + if naga_store is None: + return _json_error("Naga integration not available", status=503) + + binding, changed, err = await naga_store.revoke_binding( + naga_id, + expected_bind_uuid=bind_uuid, + delivery_signature=delivery_signature, + ) + if binding 
is None: + logger.warning( + "[NagaUnbind] 吊销失败: trace=%s naga_id=%s bind_uuid=%s err=%s", + trace_id, + naga_id, + bind_uuid, + err.message if err is not None else "binding not found", + ) + return _json_error( + err.message if err is not None else "binding not found", + status=err.http_status if err is not None else 404, + ) + logger.info( + "[NagaUnbind] 吊销完成: trace=%s naga_id=%s bind_uuid=%s changed=%s qq=%s group=%s", + trace_id, + naga_id, + bind_uuid, + changed, + binding.qq_id, + binding.group_id, + ) + return web.json_response( + { + "ok": True, + "idempotent": not changed, + "naga_id": naga_id, + "bind_uuid": bind_uuid, + } + ) diff --git a/src/Undefined/api/routes/system.py b/src/Undefined/api/routes/system.py new file mode 100644 index 0000000..2cf3915 --- /dev/null +++ b/src/Undefined/api/routes/system.py @@ -0,0 +1,228 @@ +"""System / probe route handlers for the Runtime API.""" + +from __future__ import annotations + +import asyncio +import logging +import platform +import sys +import time +from datetime import datetime +from typing import Any + +from aiohttp import web +from aiohttp.web_response import Response + +from Undefined import __version__ +from Undefined.api._context import RuntimeAPIContext +from Undefined.api._helpers import ( + _VIRTUAL_USER_ID, + _json_error, + _mask_url, + _naga_routes_enabled, + _registry_summary, +) +from Undefined.api._openapi import _build_openapi_spec +from Undefined.api._probes import ( + _build_internal_model_probe_payload, + _probe_http_endpoint, + _probe_ws_endpoint, + _skipped_probe, +) + +logger = logging.getLogger(__name__) + +_PROCESS_START_TIME = time.time() + + +async def openapi_handler(ctx: RuntimeAPIContext, request: web.Request) -> Response: + cfg = ctx.config_getter() + if not bool(getattr(cfg.api, "openapi_enabled", True)): + logger.info( + "[RuntimeAPI] OpenAPI 请求被拒绝: disabled remote=%s", request.remote + ) + return _json_error("OpenAPI disabled", status=404) + naga_routes_enabled = 
_naga_routes_enabled(cfg, ctx.naga_store) + logger.info( + "[RuntimeAPI] OpenAPI 请求: remote=%s naga_routes_enabled=%s", + request.remote, + naga_routes_enabled, + ) + return web.json_response(_build_openapi_spec(ctx, request)) + + +async def internal_probe_handler( + ctx: RuntimeAPIContext, request: web.Request +) -> Response: + _ = request + cfg = ctx.config_getter() + queue_snapshot = ctx.queue_manager.snapshot() if ctx.queue_manager else {} + cognitive_queue_snapshot = ( + ctx.cognitive_job_queue.snapshot() if ctx.cognitive_job_queue else {} + ) + memory_storage = getattr(ctx.ai, "memory_storage", None) + memory_count = memory_storage.count() if memory_storage is not None else 0 + + # Skills 统计 + ai = ctx.ai + skills_info: dict[str, Any] = {} + if ai is not None: + tool_reg = getattr(ai, "tool_registry", None) + agent_reg = getattr(ai, "agent_registry", None) + anthropic_reg = getattr(ai, "anthropic_skill_registry", None) + skills_info["tools"] = _registry_summary(tool_reg) + skills_info["agents"] = _registry_summary(agent_reg) + skills_info["anthropic_skills"] = _registry_summary(anthropic_reg) + + # 模型配置(脱敏) + models_info: dict[str, Any] = {} + summary_model = getattr( + cfg, + "summary_model", + getattr(cfg, "agent_model", getattr(cfg, "chat_model", None)), + ) + for label in ( + "chat_model", + "vision_model", + "agent_model", + "security_model", + "naga_model", + "grok_model", + ): + mcfg = getattr(cfg, label, None) + if mcfg is not None: + models_info[label] = _build_internal_model_probe_payload(mcfg) + if summary_model is not None: + models_info["summary_model"] = _build_internal_model_probe_payload( + summary_model + ) + for label in ("embedding_model", "rerank_model"): + mcfg = getattr(cfg, label, None) + if mcfg is not None: + models_info[label] = { + "model_name": getattr(mcfg, "model_name", ""), + "api_url": _mask_url(getattr(mcfg, "api_url", "")), + } + + uptime_seconds = round(time.time() - _PROCESS_START_TIME, 1) + payload = { + "timestamp": 
datetime.now().isoformat(), + "version": __version__, + "python": f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}", + "platform": platform.system(), + "uptime_seconds": uptime_seconds, + "onebot": ctx.onebot.connection_status() if ctx.onebot is not None else {}, + "queues": queue_snapshot, + "memory": {"count": memory_count, "virtual_user_id": _VIRTUAL_USER_ID}, + "cognitive": { + "enabled": bool(ctx.cognitive_service and ctx.cognitive_service.enabled), + "queue": cognitive_queue_snapshot, + }, + "api": { + "enabled": bool(cfg.api.enabled), + "host": cfg.api.host, + "port": cfg.api.port, + "openapi_enabled": bool(cfg.api.openapi_enabled), + }, + "skills": skills_info, + "models": models_info, + } + return web.json_response(payload) + + +async def external_probe_handler( + ctx: RuntimeAPIContext, request: web.Request +) -> Response: + _ = request + cfg = ctx.config_getter() + summary_model = getattr( + cfg, + "summary_model", + getattr(cfg, "agent_model", getattr(cfg, "chat_model", None)), + ) + naga_probe = ( + _probe_http_endpoint( + name="naga_model", + base_url=cfg.naga_model.api_url, + api_key=cfg.naga_model.api_key, + model_name=cfg.naga_model.model_name, + ) + if bool(cfg.api.enabled and cfg.nagaagent_mode_enabled and cfg.naga.enabled) + else _skipped_probe( + name="naga_model", + reason="naga_integration_disabled", + model_name=cfg.naga_model.model_name, + ) + ) + checks = [ + _probe_http_endpoint( + name="chat_model", + base_url=cfg.chat_model.api_url, + api_key=cfg.chat_model.api_key, + model_name=cfg.chat_model.model_name, + ), + _probe_http_endpoint( + name="vision_model", + base_url=cfg.vision_model.api_url, + api_key=cfg.vision_model.api_key, + model_name=cfg.vision_model.model_name, + ), + _probe_http_endpoint( + name="security_model", + base_url=cfg.security_model.api_url, + api_key=cfg.security_model.api_key, + model_name=cfg.security_model.model_name, + ), + naga_probe, + _probe_http_endpoint( + name="agent_model", + 
base_url=cfg.agent_model.api_url, + api_key=cfg.agent_model.api_key, + model_name=cfg.agent_model.model_name, + ), + ] + if summary_model is not None: + checks.append( + _probe_http_endpoint( + name="summary_model", + base_url=summary_model.api_url, + api_key=summary_model.api_key, + model_name=summary_model.model_name, + ) + ) + grok_model = getattr(cfg, "grok_model", None) + if grok_model is not None: + checks.append( + _probe_http_endpoint( + name="grok_model", + base_url=getattr(grok_model, "api_url", ""), + api_key=getattr(grok_model, "api_key", ""), + model_name=getattr(grok_model, "model_name", ""), + ) + ) + checks.extend( + [ + _probe_http_endpoint( + name="embedding_model", + base_url=cfg.embedding_model.api_url, + api_key=cfg.embedding_model.api_key, + model_name=getattr(cfg.embedding_model, "model_name", ""), + ), + _probe_http_endpoint( + name="rerank_model", + base_url=cfg.rerank_model.api_url, + api_key=cfg.rerank_model.api_key, + model_name=getattr(cfg.rerank_model, "model_name", ""), + ), + _probe_ws_endpoint(cfg.onebot_ws_url), + ] + ) + results = await asyncio.gather(*checks) + ok = all(item.get("status") in {"ok", "skipped"} for item in results) + return web.json_response( + { + "ok": ok, + "timestamp": datetime.now().isoformat(), + "results": results, + } + ) diff --git a/src/Undefined/api/routes/tools.py b/src/Undefined/api/routes/tools.py new file mode 100644 index 0000000..f305179 --- /dev/null +++ b/src/Undefined/api/routes/tools.py @@ -0,0 +1,462 @@ +"""Tools route handlers for the Runtime API.""" + +from __future__ import annotations + +import asyncio +import logging +import time +import uuid as _uuid +from typing import Any, Awaitable + +from aiohttp import ClientSession, ClientTimeout, web +from aiohttp.web_response import Response + +from Undefined.api._context import RuntimeAPIContext +from Undefined.api._helpers import ( + _ToolInvokeExecutionTimeoutError, + _json_error, + _mask_url, + _to_bool, + _validate_callback_url, +) +from 
Undefined.context import RequestContext + +logger = logging.getLogger(__name__) + + +# ------------------------------------------------------------------ +# Pure helpers +# ------------------------------------------------------------------ + + +def get_filtered_tools(ctx: RuntimeAPIContext) -> list[dict[str, Any]]: + """按配置过滤可用工具,返回 OpenAI function calling schema 列表。""" + cfg = ctx.config_getter() + api_cfg = cfg.api + ai = ctx.ai + if ai is None: + return [] + + tool_reg = getattr(ai, "tool_registry", None) + agent_reg = getattr(ai, "agent_registry", None) + + all_schemas: list[dict[str, Any]] = [] + if tool_reg is not None: + all_schemas.extend(tool_reg.get_tools_schema()) + + # 收集 agent schema 并缓存名称集合(避免重复调用) + agent_names: set[str] = set() + if agent_reg is not None: + agent_schemas = agent_reg.get_agents_schema() + all_schemas.extend(agent_schemas) + for schema in agent_schemas: + func = schema.get("function", {}) + name = str(func.get("name", "")) + if name: + agent_names.add(name) + + denylist: set[str] = set(api_cfg.tool_invoke_denylist) + allowlist: set[str] = set(api_cfg.tool_invoke_allowlist) + expose = api_cfg.tool_invoke_expose + + def _get_name(schema: dict[str, Any]) -> str: + func = schema.get("function", {}) + return str(func.get("name", "")) + + # 1. 先排除黑名单 + if denylist: + all_schemas = [s for s in all_schemas if _get_name(s) not in denylist] + + # 2. 白名单非空时仅保留匹配项 + if allowlist: + return [s for s in all_schemas if _get_name(s) in allowlist] + + # 3. 按 expose 过滤 + if expose == "all": + return all_schemas + + def _is_tool(name: str) -> bool: + return "." not in name and name not in agent_names + + def _is_toolset(name: str) -> bool: + return "." 
in name and not name.startswith("mcp.") + + filtered: list[dict[str, Any]] = [] + for schema in all_schemas: + name = _get_name(schema) + if not name: + continue + if expose == "tools" and _is_tool(name): + filtered.append(schema) + elif expose == "toolsets" and _is_toolset(name): + filtered.append(schema) + elif expose == "tools+toolsets" and (_is_tool(name) or _is_toolset(name)): + filtered.append(schema) + elif expose == "agents" and name in agent_names: + filtered.append(schema) + + return filtered + + +def get_agent_tool_names(ctx: RuntimeAPIContext) -> set[str]: + ai = ctx.ai + if ai is None: + return set() + + agent_reg = getattr(ai, "agent_registry", None) + if agent_reg is None: + return set() + + agent_names: set[str] = set() + for schema in agent_reg.get_agents_schema(): + func = schema.get("function", {}) + name = str(func.get("name", "")) + if name: + agent_names.add(name) + return agent_names + + +def resolve_tool_invoke_timeout( + ctx: RuntimeAPIContext, tool_name: str, timeout: int +) -> float | None: + if tool_name in get_agent_tool_names(ctx): + return None + return float(timeout) + + +# ------------------------------------------------------------------ +# Async helpers +# ------------------------------------------------------------------ + + +async def await_tool_invoke_result( + awaitable: Awaitable[Any], + *, + timeout: float | None, +) -> Any: + if timeout is None or timeout <= 0: + return await awaitable + try: + return await asyncio.wait_for(awaitable, timeout=timeout) + except asyncio.TimeoutError as exc: + raise _ToolInvokeExecutionTimeoutError from exc + + +# ------------------------------------------------------------------ +# Route handlers +# ------------------------------------------------------------------ + + +async def tools_list_handler(ctx: RuntimeAPIContext, request: web.Request) -> Response: + _ = request + cfg = ctx.config_getter() + if not cfg.api.tool_invoke_enabled: + return _json_error("Tool invoke API is disabled", 
status=403) + + tools = get_filtered_tools(ctx) + return web.json_response({"count": len(tools), "tools": tools}) + + +async def tools_invoke_handler( + ctx: RuntimeAPIContext, + background_tasks: set[asyncio.Task[Any]], + request: web.Request, +) -> Response: + cfg = ctx.config_getter() + if not cfg.api.tool_invoke_enabled: + return _json_error("Tool invoke API is disabled", status=403) + + try: + body = await request.json() + except Exception: + return _json_error("Invalid JSON", status=400) + + if not isinstance(body, dict): + return _json_error("Request body must be a JSON object", status=400) + + tool_name = str(body.get("tool_name", "") or "").strip() + if not tool_name: + return _json_error("tool_name is required", status=400) + + args = body.get("args") + if not isinstance(args, dict): + return _json_error("args must be a JSON object", status=400) + + # 验证工具是否在允许列表中 + filtered_tools = get_filtered_tools(ctx) + available_names: set[str] = set() + for schema in filtered_tools: + func = schema.get("function", {}) + name = str(func.get("name", "")) + if name: + available_names.add(name) + + if tool_name not in available_names: + caller_ip = request.remote or "unknown" + logger.warning( + "[ToolInvoke] 请求拒绝: tool=%s reason=not_available caller_ip=%s", + tool_name, + caller_ip, + ) + return _json_error(f"Tool '{tool_name}' is not available", status=404) + + # 解析回调配置 + callback_cfg = body.get("callback") + use_callback = False + callback_url = "" + callback_headers: dict[str, str] = {} + if isinstance(callback_cfg, dict) and _to_bool(callback_cfg.get("enabled")): + callback_url = str(callback_cfg.get("url", "") or "").strip() + if not callback_url: + return _json_error( + "callback.url is required when callback is enabled", + status=400, + ) + url_error = _validate_callback_url(callback_url) + if url_error: + return _json_error(url_error, status=400) + raw_headers = callback_cfg.get("headers") + if isinstance(raw_headers, dict): + callback_headers = {str(k): 
str(v) for k, v in raw_headers.items()} + use_callback = True + + request_id = _uuid.uuid4().hex + caller_ip = request.remote or "unknown" + logger.info( + "[ToolInvoke] 收到请求: request_id=%s tool=%s caller_ip=%s", + request_id, + tool_name, + caller_ip, + ) + + if use_callback: + # 异步执行 + 回调 + task = asyncio.create_task( + execute_and_callback( + ctx, + request_id=request_id, + tool_name=tool_name, + args=args, + body_context=body.get("context"), + callback_url=callback_url, + callback_headers=callback_headers, + timeout=cfg.api.tool_invoke_timeout, + callback_timeout=cfg.api.tool_invoke_callback_timeout, + ) + ) + background_tasks.add(task) + task.add_done_callback(background_tasks.discard) + return web.json_response( + { + "ok": True, + "request_id": request_id, + "tool_name": tool_name, + "status": "accepted", + } + ) + + # 同步执行 + result = await execute_tool_invoke( + ctx, + request_id=request_id, + tool_name=tool_name, + args=args, + body_context=body.get("context"), + timeout=cfg.api.tool_invoke_timeout, + ) + return web.json_response(result) + + +# ------------------------------------------------------------------ +# Execution core +# ------------------------------------------------------------------ + + +async def execute_tool_invoke( + ctx: RuntimeAPIContext, + *, + request_id: str, + tool_name: str, + args: dict[str, Any], + body_context: Any, + timeout: int, +) -> dict[str, Any]: + """执行工具调用并返回结果字典。""" + ai = ctx.ai + if ai is None: + return { + "ok": False, + "request_id": request_id, + "tool_name": tool_name, + "error": "AI client not ready", + "duration_ms": 0, + } + + # 解析请求上下文 + ctx_data: dict[str, Any] = {} + if isinstance(body_context, dict): + ctx_data = body_context + request_type = str(ctx_data.get("request_type", "api") or "api") + group_id = ctx_data.get("group_id") + user_id = ctx_data.get("user_id") + sender_id = ctx_data.get("sender_id") + + args_keys = list(args.keys()) + logger.info( + "[ToolInvoke] 开始执行: request_id=%s tool=%s 
args_keys=%s", + request_id, + tool_name, + args_keys, + ) + + start = time.perf_counter() + effective_timeout = resolve_tool_invoke_timeout(ctx, tool_name, timeout) + try: + async with RequestContext( + request_type=request_type, + group_id=int(group_id) if group_id is not None else None, + user_id=int(user_id) if user_id is not None else None, + sender_id=int(sender_id) if sender_id is not None else None, + ) as req_ctx: + # 注入核心服务资源 + if ctx.sender is not None: + req_ctx.set_resource("sender", ctx.sender) + if ctx.history_manager is not None: + req_ctx.set_resource("history_manager", ctx.history_manager) + runtime_config = getattr(ai, "runtime_config", None) + if runtime_config is not None: + req_ctx.set_resource("runtime_config", runtime_config) + memory_storage = getattr(ai, "memory_storage", None) + if memory_storage is not None: + req_ctx.set_resource("memory_storage", memory_storage) + if ctx.onebot is not None: + req_ctx.set_resource("onebot_client", ctx.onebot) + if ctx.scheduler is not None: + req_ctx.set_resource("scheduler", ctx.scheduler) + if ctx.cognitive_service is not None: + req_ctx.set_resource("cognitive_service", ctx.cognitive_service) + if ctx.meme_service is not None: + req_ctx.set_resource("meme_service", ctx.meme_service) + + tool_context: dict[str, Any] = { + "request_type": request_type, + "request_id": request_id, + } + if group_id is not None: + tool_context["group_id"] = int(group_id) + if user_id is not None: + tool_context["user_id"] = int(user_id) + if sender_id is not None: + tool_context["sender_id"] = int(sender_id) + + tool_manager = getattr(ai, "tool_manager", None) + if tool_manager is None: + raise RuntimeError("ToolManager not available") + + raw_result = await await_tool_invoke_result( + tool_manager.execute_tool(tool_name, args, tool_context), + timeout=effective_timeout, + ) + + elapsed_ms = round((time.perf_counter() - start) * 1000, 1) + result_str = str(raw_result or "") + logger.info( + "[ToolInvoke] 执行完成: 
request_id=%s tool=%s ok=true " + "duration_ms=%s result_len=%d", + request_id, + tool_name, + elapsed_ms, + len(result_str), + ) + return { + "ok": True, + "request_id": request_id, + "tool_name": tool_name, + "result": result_str, + "duration_ms": elapsed_ms, + } + + except _ToolInvokeExecutionTimeoutError: + elapsed_ms = round((time.perf_counter() - start) * 1000, 1) + logger.warning( + "[ToolInvoke] 执行超时: request_id=%s tool=%s timeout=%ds", + request_id, + tool_name, + timeout, + ) + return { + "ok": False, + "request_id": request_id, + "tool_name": tool_name, + "error": f"Execution timed out after {timeout}s", + "duration_ms": elapsed_ms, + } + except Exception as exc: + elapsed_ms = round((time.perf_counter() - start) * 1000, 1) + logger.exception( + "[ToolInvoke] 执行失败: request_id=%s tool=%s error=%s", + request_id, + tool_name, + exc, + ) + return { + "ok": False, + "request_id": request_id, + "tool_name": tool_name, + "error": str(exc), + "duration_ms": elapsed_ms, + } + + +async def execute_and_callback( + ctx: RuntimeAPIContext, + *, + request_id: str, + tool_name: str, + args: dict[str, Any], + body_context: Any, + callback_url: str, + callback_headers: dict[str, str], + timeout: int, + callback_timeout: int, +) -> None: + """异步执行工具并发送回调。""" + result = await execute_tool_invoke( + ctx, + request_id=request_id, + tool_name=tool_name, + args=args, + body_context=body_context, + timeout=timeout, + ) + + payload = { + "request_id": result["request_id"], + "tool_name": result["tool_name"], + "ok": result["ok"], + "result": result.get("result"), + "duration_ms": result.get("duration_ms", 0), + "error": result.get("error"), + } + + try: + cb_timeout = ClientTimeout(total=callback_timeout) + async with ClientSession(timeout=cb_timeout) as session: + async with session.post( + callback_url, + json=payload, + headers=callback_headers or None, + ) as resp: + logger.info( + "[ToolInvoke] 回调发送: request_id=%s url=%s status=%d", + request_id, + _mask_url(callback_url), 
+ resp.status, + ) + except Exception as exc: + logger.warning( + "[ToolInvoke] 回调失败: request_id=%s url=%s error=%s", + request_id, + _mask_url(callback_url), + exc, + ) diff --git a/tests/test_runtime_api_chat_stream.py b/tests/test_runtime_api_chat_stream.py index c1f6e99..09738cb 100644 --- a/tests/test_runtime_api_chat_stream.py +++ b/tests/test_runtime_api_chat_stream.py @@ -8,7 +8,7 @@ from aiohttp import web from Undefined.api import RuntimeAPIContext, RuntimeAPIServer -from Undefined.api import app as runtime_api_app +from Undefined.api.routes import chat as runtime_api_chat class _DummyTransport: @@ -90,18 +90,18 @@ async def _fake_render_message_with_pic_placeholders( ) server = RuntimeAPIServer(context, host="127.0.0.1", port=8788) - async def _fake_run_webui_chat(*, text: str, send_output: Any) -> str: + async def _fake_run_webui_chat(_ctx: Any, *, text: str, send_output: Any) -> str: assert text == "hello" await send_output(42, "bot reply with ") return "chat" monkeypatch.setattr( - runtime_api_app, + runtime_api_chat, "render_message_with_pic_placeholders", _fake_render_message_with_pic_placeholders, ) monkeypatch.setattr(web, "StreamResponse", _DummyStreamResponse) - monkeypatch.setattr(server, "_run_webui_chat", _fake_run_webui_chat) + monkeypatch.setattr(runtime_api_chat, "run_webui_chat", _fake_run_webui_chat) request = cast( web.Request, @@ -173,11 +173,11 @@ async def _fake_ask(full_question: str, **kwargs: Any) -> str: server = RuntimeAPIServer(context, host="127.0.0.1", port=8788) monkeypatch.setattr( - runtime_api_app, + runtime_api_chat, "register_message_attachments", _fake_register_message_attachments, ) - monkeypatch.setattr(runtime_api_app, "collect_context_resources", lambda _vars: {}) + monkeypatch.setattr(runtime_api_chat, "collect_context_resources", lambda _vars: {}) sent_messages: list[tuple[int, str]] = [] diff --git a/tests/test_runtime_api_naga.py b/tests/test_runtime_api_naga.py index 55a66b4..739919d 100644 --- 
a/tests/test_runtime_api_naga.py +++ b/tests/test_runtime_api_naga.py @@ -608,9 +608,11 @@ async def _render_html_to_image(_: str, __: str) -> None: raise RuntimeError("render failed") monkeypatch.setattr( - "Undefined.api.app.render_markdown_to_html", _render_markdown_to_html + "Undefined.api.routes.naga.render_markdown_to_html", _render_markdown_to_html + ) + monkeypatch.setattr( + "Undefined.api.routes.naga.render_html_to_image", _render_html_to_image ) - monkeypatch.setattr("Undefined.api.app.render_html_to_image", _render_html_to_image) response = await server._naga_messages_send_handler( _make_request( @@ -665,7 +667,7 @@ async def _render_markdown_to_html(_: str) -> str: raise RuntimeError("markdown failed") monkeypatch.setattr( - "Undefined.api.app.render_markdown_to_html", _render_markdown_to_html + "Undefined.api.routes.naga.render_markdown_to_html", _render_markdown_to_html ) response = await server._naga_messages_send_handler( diff --git a/tests/test_runtime_api_probes.py b/tests/test_runtime_api_probes.py index 3d6e0dd..940f32f 100644 --- a/tests/test_runtime_api_probes.py +++ b/tests/test_runtime_api_probes.py @@ -8,7 +8,7 @@ from aiohttp import web from Undefined.api import RuntimeAPIContext, RuntimeAPIServer -from Undefined.api import app as runtime_api_app +from Undefined.api.routes import system as runtime_api_system @pytest.mark.asyncio @@ -166,9 +166,11 @@ async def _fake_probe_ws_endpoint(_: str) -> dict[str, Any]: } monkeypatch.setattr( - runtime_api_app, "_probe_http_endpoint", _fake_probe_http_endpoint + runtime_api_system, "_probe_http_endpoint", _fake_probe_http_endpoint + ) + monkeypatch.setattr( + runtime_api_system, "_probe_ws_endpoint", _fake_probe_ws_endpoint ) - monkeypatch.setattr(runtime_api_app, "_probe_ws_endpoint", _fake_probe_ws_endpoint) context = RuntimeAPIContext( config_getter=lambda: SimpleNamespace( diff --git a/tests/test_runtime_api_tool_invoke.py b/tests/test_runtime_api_tool_invoke.py index 16dc141..e3c6615 100644 --- 
a/tests/test_runtime_api_tool_invoke.py +++ b/tests/test_runtime_api_tool_invoke.py @@ -365,7 +365,7 @@ async def _wait_for(awaitable: Any, timeout: float) -> Any: seen["timeout"] = timeout return await original_wait_for(awaitable, timeout) - monkeypatch.setattr("Undefined.api.app.asyncio.wait_for", _wait_for) + monkeypatch.setattr("Undefined.api.routes.tools.asyncio.wait_for", _wait_for) payload = await server._execute_tool_invoke( request_id="req-tool", @@ -391,7 +391,7 @@ async def _wait_for(awaitable: Any, timeout: float) -> Any: seen["timeout"] = timeout return await original_wait_for(awaitable, timeout) - monkeypatch.setattr("Undefined.api.app.asyncio.wait_for", _wait_for) + monkeypatch.setattr("Undefined.api.routes.tools.asyncio.wait_for", _wait_for) payload = await server._execute_tool_invoke( request_id="req-agent", From 0480ac7c01e1d0f75b7be364e40f2ac03ba513e2 Mon Sep 17 00:00:00 2001 From: Null <1708213363@qq.com> Date: Sun, 19 Apr 2026 14:11:32 +0800 Subject: [PATCH 35/57] feat(repeat): add cooldown to prevent re-repeating same content MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit After the bot repeats a message, the same content enters a configurable cooldown period (default 60 minutes) during which it won't be repeated again, even if the repeat chain conditions are met again. Features: - New config: easter_egg.repeat_cooldown_minutes (default 60, 0=disabled) - Question mark normalization: ? and ? 
treated as equivalent for cooldown - Per-group, per-text independent cooldown tracking - Cooldown uses monotonic clock (immune to wall-clock changes) Tests: 8 new repeat cooldown tests + 2 config parsing tests Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> Co-authored-by: Claude Opus 4.6 --- config.toml.example | 3 + src/Undefined/config/loader.py | 12 +++ src/Undefined/handlers.py | 70 ++++++++++--- tests/test_config_easter_egg_repeat.py | 18 ++++ tests/test_handlers_repeat.py | 137 ++++++++++++++++++++++++- 5 files changed, 222 insertions(+), 18 deletions(-) diff --git a/config.toml.example b/config.toml.example index 4ab0d04..c2a7fed 100644 --- a/config.toml.example +++ b/config.toml.example @@ -719,6 +719,9 @@ repeat_enabled = false # zh: 复读触发所需的连续相同消息条数(来自不同发送者),范围 2–20,默认 3。 # en: Number of consecutive identical messages (from different senders) required to trigger repeat, range 2–20, default 3. repeat_threshold = 3 +# zh: 复读冷却时间(分钟)。同一内容被复读后,在冷却时间内不再重复复读。0 = 无冷却。问号类消息(?/?)视为等价。 +# en: Repeat cooldown (minutes). After repeating the same content, won't repeat it again within this cooldown. 0 = no cooldown. Question marks (?/?) are treated as equivalent. +repeat_cooldown_minutes = 60 # zh: 是否启用倒问号(复读触发时,若消息为问号则发送倒问号 ¿)。 # en: Enable inverted question mark (when repeat triggers on "?" messages, send "¿" instead). 
inverted_question_enabled = false diff --git a/src/Undefined/config/loader.py b/src/Undefined/config/loader.py index 26f669c..992829a 100644 --- a/src/Undefined/config/loader.py +++ b/src/Undefined/config/loader.py @@ -219,6 +219,7 @@ class Config: keyword_reply_enabled: bool repeat_enabled: bool repeat_threshold: int + repeat_cooldown_minutes: int inverted_question_enabled: bool context_recent_messages_limit: int ai_request_max_retries: int @@ -493,6 +494,16 @@ def load(cls, config_path: Optional[Path] = None, strict: bool = True) -> "Confi repeat_threshold = 2 if repeat_threshold > 20: repeat_threshold = 20 + repeat_cooldown_minutes = _coerce_int( + _get_value( + data, + ("easter_egg", "repeat_cooldown_minutes"), + "EASTER_EGG_REPEAT_COOLDOWN_MINUTES", + ), + 60, + ) + if repeat_cooldown_minutes < 0: + repeat_cooldown_minutes = 0 context_recent_messages_limit = _coerce_int( _get_value( data, @@ -1241,6 +1252,7 @@ def load(cls, config_path: Optional[Path] = None, strict: bool = True) -> "Confi keyword_reply_enabled=keyword_reply_enabled, repeat_enabled=repeat_enabled, repeat_threshold=repeat_threshold, + repeat_cooldown_minutes=repeat_cooldown_minutes, inverted_question_enabled=inverted_question_enabled, context_recent_messages_limit=context_recent_messages_limit, ai_request_max_retries=ai_request_max_retries, diff --git a/src/Undefined/handlers.py b/src/Undefined/handlers.py index 726fa5c..357e0c6 100644 --- a/src/Undefined/handlers.py +++ b/src/Undefined/handlers.py @@ -8,6 +8,7 @@ import os from pathlib import Path import random +import time from typing import Any, Coroutine from Undefined.attachments import ( @@ -131,6 +132,8 @@ def __init__( # 复读功能状态(按群跟踪最近消息文本与发送者) self._repeat_counter: dict[int, list[tuple[str, int]]] = {} self._repeat_locks: dict[int, asyncio.Lock] = {} + # 复读冷却:group_id → {normalized_text → monotonic_timestamp} + self._repeat_cooldown: dict[int, dict[str, float]] = {} # 启动队列 
self.ai_coordinator.queue_manager.start(self.ai_coordinator.execute_reply) @@ -143,6 +146,31 @@ def _get_repeat_lock(self, group_id: int) -> asyncio.Lock: self._repeat_locks[group_id] = lock return lock + @staticmethod + def _normalize_repeat_text(text: str) -> str: + """规范化复读文本用于冷却比较(?→?)。""" + return text.replace("?", "?") + + def _is_repeat_on_cooldown(self, group_id: int, text: str) -> bool: + """检查指定群的文本是否在复读冷却期内。""" + cooldown_minutes = self.config.repeat_cooldown_minutes + if cooldown_minutes <= 0: + return False + group_cd = self._repeat_cooldown.get(group_id) + if not group_cd: + return False + key = self._normalize_repeat_text(text) + last_time = group_cd.get(key) + if last_time is None: + return False + return (time.monotonic() - last_time) < cooldown_minutes * 60 + + def _record_repeat_cooldown(self, group_id: int, text: str) -> None: + """记录复读冷却时间戳。""" + key = self._normalize_repeat_text(text) + group_cd = self._repeat_cooldown.setdefault(group_id, {}) + group_cd[key] = time.monotonic() + async def _annotate_meme_descriptions( self, attachments: list[dict[str, str]], @@ -829,22 +857,32 @@ async def _fetch_group_name() -> str: and self.config.bot_qq not in senders ): reply_text = texts[0] - if self.config.inverted_question_enabled: - stripped = reply_text.strip() - if set(stripped) <= {"?", "?"}: - reply_text = "¿" * len(stripped) - # 清空计数器防止重复触发 - self._repeat_counter[group_id] = [] - logger.info( - "[复读] 触发复读: group=%s text=%s", - group_id, - redact_string(reply_text)[:50], - ) - await self.sender.send_group_message( - group_id, - reply_text, - history_prefix=REPEAT_REPLY_HISTORY_PREFIX, - ) + # 冷却检查:同一内容在冷却期内不再复读 + if self._is_repeat_on_cooldown(group_id, reply_text): + self._repeat_counter[group_id] = [] + logger.debug( + "[复读] 冷却中跳过: group=%s text=%s", + group_id, + redact_string(reply_text)[:50], + ) + else: + if self.config.inverted_question_enabled: + stripped = reply_text.strip() + if set(stripped) <= {"?", "?"}: + reply_text = "¿" * 
len(stripped) + # 清空计数器防止重复触发 + self._repeat_counter[group_id] = [] + self._record_repeat_cooldown(group_id, texts[0]) + logger.info( + "[复读] 触发复读: group=%s text=%s", + group_id, + redact_string(reply_text)[:50], + ) + await self.sender.send_group_message( + group_id, + reply_text, + history_prefix=REPEAT_REPLY_HISTORY_PREFIX, + ) return # Bilibili 视频自动提取 diff --git a/tests/test_config_easter_egg_repeat.py b/tests/test_config_easter_egg_repeat.py index 89872e5..e27ae05 100644 --- a/tests/test_config_easter_egg_repeat.py +++ b/tests/test_config_easter_egg_repeat.py @@ -27,12 +27,14 @@ def test_repeat_defaults_to_false(tmp_path: Path) -> None: cfg = _load(tmp_path, _MINIMAL) assert cfg.repeat_enabled is False assert cfg.inverted_question_enabled is False + assert cfg.repeat_cooldown_minutes == 60 def test_repeat_enabled_explicit(tmp_path: Path) -> None: cfg = _load(tmp_path, _MINIMAL + "\n[easter_egg]\nrepeat_enabled = true\n") assert cfg.repeat_enabled is True assert cfg.inverted_question_enabled is False + assert cfg.repeat_cooldown_minutes == 60 def test_inverted_question_enabled_explicit(tmp_path: Path) -> None: @@ -60,3 +62,19 @@ def test_keyword_reply_still_parsed_from_easter_egg(tmp_path: Path) -> None: _MINIMAL + "\n[easter_egg]\nkeyword_reply_enabled = true\n", ) assert cfg.keyword_reply_enabled is True + + +def test_repeat_cooldown_custom_value(tmp_path: Path) -> None: + cfg = _load( + tmp_path, + _MINIMAL + "\n[easter_egg]\nrepeat_cooldown_minutes = 30\n", + ) + assert cfg.repeat_cooldown_minutes == 30 + + +def test_repeat_cooldown_negative_clamped_to_zero(tmp_path: Path) -> None: + cfg = _load( + tmp_path, + _MINIMAL + "\n[easter_egg]\nrepeat_cooldown_minutes = -5\n", + ) + assert cfg.repeat_cooldown_minutes == 0 diff --git a/tests/test_handlers_repeat.py b/tests/test_handlers_repeat.py index 8d508de..4f9a830 100644 --- a/tests/test_handlers_repeat.py +++ b/tests/test_handlers_repeat.py @@ -18,6 +18,7 @@ def _build_handler( *, repeat_enabled: bool = 
False, repeat_threshold: int = 3, + repeat_cooldown_minutes: int = 60, inverted_question_enabled: bool = False, keyword_reply_enabled: bool = False, ) -> Any: @@ -26,6 +27,7 @@ def _build_handler( bot_qq=10000, repeat_enabled=repeat_enabled, repeat_threshold=repeat_threshold, + repeat_cooldown_minutes=repeat_cooldown_minutes, inverted_question_enabled=inverted_question_enabled, keyword_reply_enabled=keyword_reply_enabled, bilibili_auto_extract_enabled=False, @@ -69,6 +71,7 @@ def _build_handler( handler._background_tasks = set() handler._repeat_counter = {} handler._repeat_locks = {} + handler._repeat_cooldown = {} handler._profile_name_refresh_cache = {} handler._bot_nickname_cache = SimpleNamespace( get_nicknames=AsyncMock(return_value=frozenset()), @@ -155,14 +158,14 @@ async def test_repeat_does_not_trigger_for_different_texts() -> None: @pytest.mark.asyncio async def test_repeat_clears_counter_after_trigger() -> None: - handler = _build_handler(repeat_enabled=True) + handler = _build_handler(repeat_enabled=True, repeat_cooldown_minutes=0) # 第一轮:3条相同触发复读 for uid in [20001, 20002, 20003]: await handler.handle_message(_group_event(sender_id=uid, text="hello")) assert handler.sender.send_group_message.call_count == 1 - # 第二轮:再来3条相同应再次触发 + # 第二轮:再来3条相同应再次触发(无冷却) for uid in [20004, 20005, 20006]: await handler.handle_message(_group_event(sender_id=uid, text="hello")) @@ -343,3 +346,133 @@ async def test_repeat_custom_threshold_4() -> None: await handler.handle_message(_group_event(sender_id=20004, text="hey")) handler.sender.send_group_message.assert_called_once() assert handler.sender.send_group_message.call_args.args[1] == "hey" + + +# ── 冷却机制:复读后同一内容在冷却期内不再触发 ── + + +@pytest.mark.asyncio +async def test_repeat_cooldown_suppresses_same_text() -> None: + """复读触发后,同一内容在冷却期内再次满足条件也不触发。""" + handler = _build_handler(repeat_enabled=True, repeat_cooldown_minutes=60) + # 第一轮:触发复读 + for uid in [20001, 20002, 20003]: + await 
handler.handle_message(_group_event(sender_id=uid, text="草")) + assert handler.sender.send_group_message.call_count == 1 + + # 第二轮:同一内容,应被冷却抑制 + for uid in [20004, 20005, 20006]: + await handler.handle_message(_group_event(sender_id=uid, text="草")) + assert handler.sender.send_group_message.call_count == 1 # 不增加 + + +@pytest.mark.asyncio +async def test_repeat_cooldown_allows_different_text() -> None: + """复读 "草" 后,不同内容 "lol" 仍可正常复读。""" + handler = _build_handler(repeat_enabled=True, repeat_cooldown_minutes=60) + for uid in [20001, 20002, 20003]: + await handler.handle_message(_group_event(sender_id=uid, text="草")) + assert handler.sender.send_group_message.call_count == 1 + + for uid in [20004, 20005, 20006]: + await handler.handle_message(_group_event(sender_id=uid, text="lol")) + assert handler.sender.send_group_message.call_count == 2 + assert handler.sender.send_group_message.call_args.args[1] == "lol" + + +@pytest.mark.asyncio +async def test_repeat_cooldown_expired_allows_retrigger( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """冷却过期后,相同内容可以再次触发。""" + import time as _time + + handler = _build_handler(repeat_enabled=True, repeat_cooldown_minutes=60) + # 第一轮:触发 + for uid in [20001, 20002, 20003]: + await handler.handle_message(_group_event(sender_id=uid, text="草")) + assert handler.sender.send_group_message.call_count == 1 + + # 模拟时间流逝 61 分钟 + original_monotonic = _time.monotonic + monkeypatch.setattr(_time, "monotonic", lambda: original_monotonic() + 3660) + + # 第二轮:冷却已过期,应触发 + for uid in [20004, 20005, 20006]: + await handler.handle_message(_group_event(sender_id=uid, text="草")) + assert handler.sender.send_group_message.call_count == 2 + + +@pytest.mark.asyncio +async def test_repeat_cooldown_zero_disables() -> None: + """cooldown=0 时不启用冷却,连续复读同一内容仍可触发。""" + handler = _build_handler(repeat_enabled=True, repeat_cooldown_minutes=0) + for uid in [20001, 20002, 20003]: + await handler.handle_message(_group_event(sender_id=uid, text="草")) + assert 
handler.sender.send_group_message.call_count == 1 + + for uid in [20004, 20005, 20006]: + await handler.handle_message(_group_event(sender_id=uid, text="草")) + assert handler.sender.send_group_message.call_count == 2 + + +@pytest.mark.asyncio +async def test_repeat_cooldown_question_mark_normalization() -> None: + """全角问号 ? 和半角问号 ? 视为等价,复读后互相抑制。""" + handler = _build_handler( + repeat_enabled=True, + repeat_cooldown_minutes=60, + inverted_question_enabled=True, + ) + # 全角问号触发复读 + for uid in [20001, 20002, 20003]: + await handler.handle_message(_group_event(sender_id=uid, text="???")) + assert handler.sender.send_group_message.call_count == 1 + assert handler.sender.send_group_message.call_args.args[1] == "¿¿¿" + + # 半角问号——应被冷却抑制(?和 ? 等价) + for uid in [20004, 20005, 20006]: + await handler.handle_message(_group_event(sender_id=uid, text="???")) + assert handler.sender.send_group_message.call_count == 1 # 不增加 + + +@pytest.mark.asyncio +async def test_repeat_cooldown_multiple_texts_tracked() -> None: + """多种不同内容各自独立冷却。""" + handler = _build_handler(repeat_enabled=True, repeat_cooldown_minutes=60) + # 复读 "草" + for uid in [20001, 20002, 20003]: + await handler.handle_message(_group_event(sender_id=uid, text="草")) + # 复读 "lol" + for uid in [20004, 20005, 20006]: + await handler.handle_message(_group_event(sender_id=uid, text="lol")) + assert handler.sender.send_group_message.call_count == 2 + + # "草" 再次满足条件 → 冷却中,不触发 + for uid in [20007, 20008, 20009]: + await handler.handle_message(_group_event(sender_id=uid, text="草")) + assert handler.sender.send_group_message.call_count == 2 + + # "lol" 再次满足条件 → 冷却中,不触发 + for uid in [20010, 20011, 20012]: + await handler.handle_message(_group_event(sender_id=uid, text="lol")) + assert handler.sender.send_group_message.call_count == 2 + + +@pytest.mark.asyncio +async def test_repeat_cooldown_groups_independent() -> None: + """不同群的冷却互不影响。""" + handler = _build_handler(repeat_enabled=True, repeat_cooldown_minutes=60) + # 群A 复读 "草" + for 
uid in [20001, 20002, 20003]: + await handler.handle_message( + _group_event(group_id=30001, sender_id=uid, text="草") + ) + assert handler.sender.send_group_message.call_count == 1 + + # 群B 复读 "草" — 不同群,不受群A冷却影响 + for uid in [20004, 20005, 20006]: + await handler.handle_message( + _group_event(group_id=30002, sender_id=uid, text="草") + ) + assert handler.sender.send_group_message.call_count == 2 From b024427d4a8f9e785969b9163ceba106ef1c709f Mon Sep 17 00:00:00 2001 From: Null <1708213363@qq.com> Date: Sun, 19 Apr 2026 14:19:00 +0800 Subject: [PATCH 36/57] =?UTF-8?q?docs:=20=E6=9B=B4=E6=96=B0=E6=96=87?= =?UTF-8?q?=E6=A1=A3=E8=A6=86=E7=9B=96=E5=A4=8D=E8=AF=BB=E5=86=B7=E5=8D=B4?= =?UTF-8?q?=E3=80=81=E5=81=87@=E6=A3=80=E6=B5=8B=E3=80=81profile=E8=BE=93?= =?UTF-8?q?=E5=87=BA=E6=A8=A1=E5=BC=8F=E4=B8=8EAPI=E6=8B=86=E5=88=86?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - configuration.md: 新增 repeat_threshold、repeat_cooldown_minutes 字段 - slash-commands.md: 补充 /profile 输出模式(-f/-r/-t)与超管定向查看 - ARCHITECTURE.md: 补充 Runtime API 路由子模块层级 - development.md: 补充 config/、api/routes/、utils/ 目录说明 - CHANGELOG.md: 新增 v3.3.2 版本条目 Co-authored-by: Claude Opus 4.6 --- ARCHITECTURE.md | 2 +- CHANGELOG.md | 15 +++++++++++++++ docs/configuration.md | 6 ++++-- docs/development.md | 8 +++++++- docs/slash-commands.md | 24 ++++++++++++++++++------ 5 files changed, 45 insertions(+), 10 deletions(-) diff --git a/ARCHITECTURE.md b/ARCHITECTURE.md index e2d4176..53d35be 100644 --- a/ARCHITECTURE.md +++ b/ARCHITECTURE.md @@ -819,7 +819,7 @@ description: 从 PDF 文件中提取文本和表格,填写表单。当用户 ### 8层架构分层 1. **外部实体层**:用户、管理员、OneBot 协议端 (NapCat/Lagrange.Core)、大模型 API 服务商 -2. **核心入口层**:main.py 启动入口、配置管理器 (config/loader.py)、热更新应用器 (config/hot_reload.py)、OneBotClient (onebot.py)、RequestContext (context.py) +2. 
**核心入口层**:main.py 启动入口、配置管理器 (config/loader.py)、热更新应用器 (config/hot_reload.py)、OneBotClient (onebot.py)、RequestContext (context.py)、Runtime API Server (api/app.py → api/routes/ 路由子模块) 3. **消息处理层**:MessageHandler (handlers.py)、SecurityService (security.py)、CommandDispatcher (services/command.py)、AICoordinator (ai_coordinator.py)、QueueManager (queue_manager.py)、Bilibili 自动提取 (bilibili/) 4. **AI 核心能力层**:AIClient (client.py)、PromptBuilder (prompts.py)、ModelRequester (llm.py)、ToolManager (tooling.py)、MultimodalAnalyzer (multimodal.py)、SummaryService (summaries.py)、TokenCounter (tokens.py) 5. **存储与上下文层**:MessageHistoryManager (utils/history.py, 10000条限制)、MemoryStorage (memory.py, 置顶备忘录, 500条上限)、EndSummaryStorage、CognitiveService + JobQueue + HistorianWorker + VectorStore + ProfileStorage、MemeService + MemeWorker + MemeStore + MemeVectorStore (表情包库)、FAQStorage、ScheduledTaskStorage、TokenUsageStorage (自动归档) diff --git a/CHANGELOG.md b/CHANGELOG.md index eb1da18..e7b4f5d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,18 @@ +## v3.3.2 架构拆分、假@检测与多模式侧写 + +Runtime API 大规模拆分、假@检测、/profile 多输出模式、超管跨目标侧写查看、复读冷却机制,以及配置子模块化与大量工具函数增强。 + +- **api/app.py 拆分**:将 2491 行的巨型路由文件拆分为 8 个独立路由子模块 (`api/routes/`),app.py 仅保留薄包装委派层(333 行)。 +- **假@检测**:群聊中 `@+Bot昵称` 的文本形式也被识别为@消息(自动获取群内昵称,防竞态)。支持 `@昵称 /命令` 正常触发斜杠指令。 +- **/profile 多输出模式**:新增 `-f`(合并转发,默认)、`-r`(渲染图片)、`-t`(直接文本)三种侧写输出模式。合并转发分元数据与内容两条消息;渲染模式优化了元数据与正文的视觉排版。 +- **超管跨目标侧写**:超级管理员可通过 `/p ` 和 `/p g <群号>` 查看任意用户/群的侧写。 +- **复读冷却**:同一内容被复读后,在可配置的冷却期(默认 60 分钟)内不再重复复读。?和 ? 
视为等价。 +- **配置子模块化**:config/ 拆分为 loader.py、models.py、hot_reload.py,新增 `repeat_threshold`(2–20)与 `repeat_cooldown_minutes` 配置项。 +- **utils 增强**:新增 `coerce.py`(安全类型强转)、`fake_at.py`(假@文本检测与解析)。 +- 大量测试覆盖新增(1438+ 测试),ruff + mypy 零错误。 + +--- + ## v3.3.1 /version 命令的添加 添加了 /version 命令以查看版本号和更改内容。 diff --git a/docs/configuration.md b/docs/configuration.md index eceb9ea..7764a2d 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -426,8 +426,10 @@ Prompt caching 补充: |---|---:|---|---| | `agent_call_message_enabled` | `"none"` | 调用提示模式 | `none` / `agent` / `tools` / `all` / `clean` | | `keyword_reply_enabled` | `false` | 群聊关键词自动回复 | 布尔 | -| `repeat_enabled` | `false` | 群聊复读(连续3条相同消息时复读) | 布尔 | -| `inverted_question_enabled` | `false` | 倒问号(复读触发时若消息为问号则发送¿) | 布尔 | +| `repeat_enabled` | `false` | 群聊复读(连续 N 条相同消息时复读) | 布尔 | +| `repeat_threshold` | `3` | 触发复读所需的连续相同消息条数(来自不同发送者) | 整数,2–20 | +| `repeat_cooldown_minutes` | `60` | 复读冷却时间(分钟)。同一内容被复读后,在冷却期内不再重复复读。?和 ? 视为等价。0 = 无冷却 | 整数,≥ 0 | +| `inverted_question_enabled` | `false` | 倒问号(复读触发时若消息为问号则发送 ¿) | 布尔 | 兼容:历史字段 `[core].keyword_reply_enabled` 仍可读取,建议迁移到 `[easter_egg]`。 diff --git a/docs/development.md b/docs/development.md index 6dab818..522c1f6 100644 --- a/docs/development.md +++ b/docs/development.md @@ -22,8 +22,14 @@ src/Undefined/ │ ├── agents/ # 智能体 (独立自主的子 AI,负责处理诸如 Web 搜索、文件分析的具体长时任务) │ ├── commands/ # 中心化斜杠指令系统 (实现如 /help, /stats, /addadmin 等平台功能) │ └── anthropic_skills/# Anthropic 协议集成的外部 Skills (兼容 SKILL.md 格式) +├── config/ # 配置系统 (loader.py TOML 解析, models.py 数据模型, hot_reload.py 热更新) +├── api/ # Management API + Runtime API +│ ├── routes/ # 路由子模块 (chat, tools, naga, system, memes, memory, cognitive, health) +│ ├── app.py # aiohttp 服务主入口 (薄包装委派到 routes/) +│ └── _openapi.py # OpenAPI 文档生成 +├── memes/ # 表情包库 (两阶段 AI 管线, SQLite + ChromaDB) ├── services/ # 核心运行服务 (Queue 任务队列, Command 命令分发, Security 安全防护拦截) -├── utils/ # 通用支持工具组 (包含历史处理、JSON原子读写加锁 IO 操作等) +├── utils/ # 通用支持工具组 (io.py 异步原子读写, history.py, 
coerce.py 类型强转, fake_at.py 假@检测) ├── handlers.py # 最外层 OneBot 消息分流处理层 └── onebot.py # OneBot WebSocket 客户端核心连接 ``` diff --git a/docs/slash-commands.md b/docs/slash-commands.md index 6acf72b..f0baa17 100644 --- a/docs/slash-commands.md +++ b/docs/slash-commands.md @@ -86,25 +86,37 @@ Undefined 提供了一套强大的斜杠指令(Slash Commands)系统。管 #### 2. 消息总结与侧写查看 -- **/profile [group]** +- **/profile [group] [-f|-r|-t] [目标ID]** - **说明**:查看用户或群聊的认知侧写。侧写由系统根据聊天历史自动生成和更新。 - **别名**:`/me`、`/p` - **参数**: | 参数 | 是否必填 | 说明 | |------|----------|------| - | `group` | 可选 | 传入 `group` 时查看当前群聊的侧写(仅群聊可用) | + | `group` / `g` | 可选 | 查看群聊侧写(仅群聊可用) | + | `-f` / `--forward` | 可选 | 合并转发模式输出(默认) | + | `-r` / `--render` | 可选 | 渲染为图片发送 | + | `-t` / `--text` | 可选 | 直接文本消息发送 | + | `` | 可选 | 🔒 超管专用:查看指定用户的侧写 | + | `g <群号>` | 可选 | 🔒 超管专用:查看指定群聊的侧写 | - **行为**: - - **私聊**:只能查看自己的用户侧写,不支持 `group` 参数。 - - **群聊**:不带参数查看自己的用户侧写,带 `group` 查看当前群聊侧写。 - - 过长侧写会自动截断(3000 字符上限)。 + - **私聊**:查看自己的用户侧写,不支持 `group` 参数。 + - **群聊**:不带参数查看自己的用户侧写,带 `group` / `g` 查看当前群聊侧写。 + - **超管指定目标**:超级管理员可传入 QQ 号或群号查看任意用户/群的侧写,非超管使用时提示无权限。 + - **输出模式**:默认合并转发;`-r` 渲染为图片;`-t` 直接文本发送。 - **限流**:普通用户 60 秒,管理员 10 秒,超管无限制。 - **示例**: ``` - /profile → 查看自己的侧写 + /profile → 查看自己的侧写(合并转发) + /p -r → 查看自己的侧写(渲染图片) + /p -t → 查看自己的侧写(直接文本) /me → 同上(别名) /profile group → 查看当前群聊的侧写 + /p g → 同上 + /p 123456 → 🔒 超管:查看QQ号123456的侧写 + /p g 789012 → 🔒 超管:查看群号789012的侧写 + /p 123456 -r → 🔒 超管:查看指定用户侧写(渲染图片) ``` - **/summary [条数|时间范围] [自定义描述]** From 33d715859c59f07a6dda85187cd39ff7ec543aa9 Mon Sep 17 00:00:00 2001 From: Null <1708213363@qq.com> Date: Sun, 19 Apr 2026 14:21:59 +0800 Subject: [PATCH 37/57] =?UTF-8?q?docs(changelog):=20=E9=87=8D=E5=86=99=20v?= =?UTF-8?q?3.3.2=20=E6=9D=A1=E7=9B=AE=E8=A6=86=E7=9B=96=E6=89=80=E6=9C=89?= =?UTF-8?q?=20feature/u-guess=20=E5=8F=98=E6=9B=B4?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Claude Opus 4.6 --- CHANGELOG.md | 31 +++++++++++++++++++------------ 1 file changed, 
19 insertions(+), 12 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index e7b4f5d..83f2f07 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,15 +1,22 @@ -## v3.3.2 架构拆分、假@检测与多模式侧写 - -Runtime API 大规模拆分、假@检测、/profile 多输出模式、超管跨目标侧写查看、复读冷却机制,以及配置子模块化与大量工具函数增强。 - -- **api/app.py 拆分**:将 2491 行的巨型路由文件拆分为 8 个独立路由子模块 (`api/routes/`),app.py 仅保留薄包装委派层(333 行)。 -- **假@检测**:群聊中 `@+Bot昵称` 的文本形式也被识别为@消息(自动获取群内昵称,防竞态)。支持 `@昵称 /命令` 正常触发斜杠指令。 -- **/profile 多输出模式**:新增 `-f`(合并转发,默认)、`-r`(渲染图片)、`-t`(直接文本)三种侧写输出模式。合并转发分元数据与内容两条消息;渲染模式优化了元数据与正文的视觉排版。 -- **超管跨目标侧写**:超级管理员可通过 `/p ` 和 `/p g <群号>` 查看任意用户/群的侧写。 -- **复读冷却**:同一内容被复读后,在可配置的冷却期(默认 60 分钟)内不再重复复读。?和 ? 视为等价。 -- **配置子模块化**:config/ 拆分为 loader.py、models.py、hot_reload.py,新增 `repeat_threshold`(2–20)与 `repeat_cooldown_minutes` 配置项。 -- **utils 增强**:新增 `coerce.py`(安全类型强转)、`fake_at.py`(假@文本检测与解析)。 -- 大量测试覆盖新增(1438+ 测试),ruff + mypy 零错误。 +## v3.3.2 架构重构、假@检测与认知侧写增强 + +围绕核心架构进行了大规模重构与功能增强:Runtime API 拆分为路由子模块、配置系统模块化拆分、新增假@检测机制与 /profile 多输出模式。同步引入复读机制全面升级(可配置阈值与冷却)、消息预处理并行化、WebUI 多项交互功能,以及 arXiv 论文分析 Agent 和安全计算器工具。测试覆盖从约 800 提升至 1438+。 + +- 新增假@检测:群聊中 `@+Bot昵称` 的文本形式也被识别为@消息,自动从群上下文获取昵称(防竞态),`@昵称 /命令` 可正常触发斜杠指令。 +- `/profile` 命令支持三种输出模式:`-f` 合并转发(默认,分元数据与内容两条消息)、`-r` 渲染为图片、`-t` 直接文本发送。 +- 超级管理员可通过 `/p ` 和 `/p g <群号>` 跨目标查看任意用户或群聊的认知侧写。 +- 复读系统全面升级:触发阈值可配置(`repeat_threshold`,2–20)、Bot 发言不计入复读链、新增复读冷却机制(`repeat_cooldown_minutes`,默认 60 分钟,?与 ? 
等价)。 +- Runtime API 从 2491 行的单体 `app.py` 拆分为 8 个路由子模块 (`api/routes/`),主文件仅保留薄包装委派层。 +- 配置系统模块化拆分:`config/` 拆为 `loader.py`、`models.py`、`hot_reload.py`,`sync_config_template` 脚本支持报告注释变更路径。 +- 消息预处理流程并行化:使用 `asyncio.gather` 并行执行安全检查、认知检索和假@检测,降低消息处理延迟。 +- 新增 arXiv 论文深度分析 Agent,提供论文搜索、摘要提取与关键信息分析能力。 +- 新增 `calculator` 多功能安全计算器工具。 +- 新增消息历史限制全面可配置化(`[history].max_records`)。 +- 新增 `utils/coerce.py`(安全类型强转)与 `utils/fake_at.py`(假@文本检测与解析)公共模块。 +- WebUI 新增功能:Cmd/Ctrl+K 命令面板、骨架屏加载态、日志时间过滤、资源趋势图、TOML 原始视图、配置版本历史与回滚、长期记忆完整 CRUD 管理、Modal 焦点陷阱。 +- 修复队列系统 historian 模型未注册的调度问题。 +- 修复 `/profile` 渲染留白和字体过小问题,使用 WebUI 配色方案并提高截断上限至 5000 字符。 +- 测试覆盖大幅补齐(804 → 1438+),ruff + mypy 零错误。 --- From dcf02ab40c584554f26309eaaa65528c60dabf8d Mon Sep 17 00:00:00 2001 From: Null <1708213363@qq.com> Date: Sun, 19 Apr 2026 14:34:40 +0800 Subject: [PATCH 38/57] chore(version): bump version to 3.3.2 --- apps/undefined-console/package-lock.json | 4 ++-- apps/undefined-console/package.json | 2 +- apps/undefined-console/src-tauri/Cargo.lock | 2 +- apps/undefined-console/src-tauri/Cargo.toml | 2 +- apps/undefined-console/src-tauri/tauri.conf.json | 2 +- pyproject.toml | 2 +- src/Undefined/__init__.py | 2 +- uv.lock | 2 +- 8 files changed, 9 insertions(+), 9 deletions(-) diff --git a/apps/undefined-console/package-lock.json b/apps/undefined-console/package-lock.json index 4c32b00..cb66706 100644 --- a/apps/undefined-console/package-lock.json +++ b/apps/undefined-console/package-lock.json @@ -1,12 +1,12 @@ { "name": "undefined-console", - "version": "3.3.1", + "version": "3.3.2", "lockfileVersion": 3, "requires": true, "packages": { "": { "name": "undefined-console", - "version": "3.3.1", + "version": "3.3.2", "dependencies": { "@tauri-apps/api": "^2.3.0", "@tauri-apps/plugin-http": "^2.3.0" diff --git a/apps/undefined-console/package.json b/apps/undefined-console/package.json index 1a85c4c..d7b6e21 100644 --- a/apps/undefined-console/package.json +++ b/apps/undefined-console/package.json @@ -1,7 +1,7 @@ { 
"name": "undefined-console", "private": true, - "version": "3.3.1", + "version": "3.3.2", "type": "module", "scripts": { "tauri": "tauri", diff --git a/apps/undefined-console/src-tauri/Cargo.lock b/apps/undefined-console/src-tauri/Cargo.lock index 190f9f8..f9b61ac 100644 --- a/apps/undefined-console/src-tauri/Cargo.lock +++ b/apps/undefined-console/src-tauri/Cargo.lock @@ -4063,7 +4063,7 @@ checksum = "562d481066bde0658276a35467c4af00bdc6ee726305698a55b86e61d7ad82bb" [[package]] name = "undefined_console" -version = "3.3.1" +version = "3.3.2" dependencies = [ "serde", "serde_json", diff --git a/apps/undefined-console/src-tauri/Cargo.toml b/apps/undefined-console/src-tauri/Cargo.toml index c3a9356..3acc528 100644 --- a/apps/undefined-console/src-tauri/Cargo.toml +++ b/apps/undefined-console/src-tauri/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "undefined_console" -version = "3.3.1" +version = "3.3.2" description = "Undefined cross-platform management console" authors = ["Undefined contributors"] license = "MIT" diff --git a/apps/undefined-console/src-tauri/tauri.conf.json b/apps/undefined-console/src-tauri/tauri.conf.json index a8a61f1..7fe5107 100644 --- a/apps/undefined-console/src-tauri/tauri.conf.json +++ b/apps/undefined-console/src-tauri/tauri.conf.json @@ -1,7 +1,7 @@ { "$schema": "https://schema.tauri.app/config/2", "productName": "Undefined Console", - "version": "3.3.1", + "version": "3.3.2", "identifier": "com.undefined.console", "build": { "beforeDevCommand": "npm run dev", diff --git a/pyproject.toml b/pyproject.toml index 9128920..6dcc6a3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "Undefined-bot" -version = "3.3.1" +version = "3.3.2" description = "QQ bot platform with cognitive memory architecture and multi-agent Skills, via OneBot V11." 
readme = "README.md" authors = [ diff --git a/src/Undefined/__init__.py b/src/Undefined/__init__.py index 5722e02..c2c8768 100644 --- a/src/Undefined/__init__.py +++ b/src/Undefined/__init__.py @@ -1,3 +1,3 @@ """Undefined - A high-performance, highly scalable QQ group and private chat robot based on a self-developed architecture.""" -__version__ = "3.3.1" +__version__ = "3.3.2" diff --git a/uv.lock b/uv.lock index 810f4a0..d1d2053 100644 --- a/uv.lock +++ b/uv.lock @@ -4638,7 +4638,7 @@ wheels = [ [[package]] name = "undefined-bot" -version = "3.3.1" +version = "3.3.2" source = { editable = "." } dependencies = [ { name = "aiofiles" }, From 64af9b3b803268018d826f93d44194fb092efea7 Mon Sep 17 00:00:00 2001 From: Null <1708213363@qq.com> Date: Sun, 19 Apr 2026 14:39:02 +0800 Subject: [PATCH 39/57] fix(test): skip LaTeX render tests when Playwright browser binary missing The tests only caught ImportError but CI has Playwright installed without browser binaries, causing a runtime Error. Now checks the error message returned by execute() for 'Executable doesn't exist' and skips. 
Co-authored-by: Claude Opus 4.6 --- tests/test_render_latex_tool.py | 51 ++++++++++++++------------------- 1 file changed, 21 insertions(+), 30 deletions(-) diff --git a/tests/test_render_latex_tool.py b/tests/test_render_latex_tool.py index 0595c99..d4d3e80 100644 --- a/tests/test_render_latex_tool.py +++ b/tests/test_render_latex_tool.py @@ -58,17 +58,14 @@ async def test_render_simple_equation() -> None: args = {"content": "E = mc^2", "output_format": "png"} - try: - result = await execute(args, context) - assert result == '' - assert len(mock_registry.registered_items) == 1 - assert mock_registry.registered_items[0]["kind"] == "image" - assert mock_registry.registered_items[0]["mime_type"] == "image/png" - assert mock_registry.registered_items[0]["size"] > 0 - except ImportError as e: - if "playwright" in str(e).lower(): - pytest.skip("Playwright 未安装,跳过测试") - raise + result = await execute(args, context) + if "渲染失败" in result and "Executable doesn't exist" in result: + pytest.skip("Playwright 浏览器未安装,跳过测试") + assert result == '' + assert len(mock_registry.registered_items) == 1 + assert mock_registry.registered_items[0]["kind"] == "image" + assert mock_registry.registered_items[0]["mime_type"] == "image/png" + assert mock_registry.registered_items[0]["size"] > 0 @pytest.mark.asyncio @@ -85,14 +82,11 @@ async def test_render_with_delimiters() -> None: args = {"content": r"\[ \int_0^\infty e^{-x^2} dx = \frac{\sqrt{\pi}}{2} \]"} - try: - result = await execute(args, context) - assert result == '' - assert len(mock_registry.registered_items) == 1 - except ImportError as e: - if "playwright" in str(e).lower(): - pytest.skip("Playwright 未安装,跳过测试") - raise + result = await execute(args, context) + if "渲染失败" in result and "Executable doesn't exist" in result: + pytest.skip("Playwright 浏览器未安装,跳过测试") + assert result == '' + assert len(mock_registry.registered_items) == 1 @pytest.mark.asyncio @@ -109,17 +103,14 @@ async def test_render_pdf_output() -> None: args = 
{"content": r"\frac{a}{b} + \sqrt{c}", "output_format": "pdf"} - try: - result = await execute(args, context) - assert result == '' - assert len(mock_registry.registered_items) == 1 - assert mock_registry.registered_items[0]["kind"] == "file" - assert mock_registry.registered_items[0]["mime_type"] == "application/pdf" - assert mock_registry.registered_items[0]["display_name"] == "latex.pdf" - except ImportError as e: - if "playwright" in str(e).lower(): - pytest.skip("Playwright 未安装,跳过测试") - raise + result = await execute(args, context) + if "渲染失败" in result and "Executable doesn't exist" in result: + pytest.skip("Playwright 浏览器未安装,跳过测试") + assert result == '' + assert len(mock_registry.registered_items) == 1 + assert mock_registry.registered_items[0]["kind"] == "file" + assert mock_registry.registered_items[0]["mime_type"] == "application/pdf" + assert mock_registry.registered_items[0]["display_name"] == "latex.pdf" @pytest.mark.asyncio From 73954722efa2f6028b7419c8fd3ed21441e92f8a Mon Sep 17 00:00:00 2001 From: Null <1708213363@qq.com> Date: Sun, 19 Apr 2026 14:47:23 +0800 Subject: [PATCH 40/57] fix: address PR review findings from Devin and Codex MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - fix(prompt): split concatenated bullet points in judge_meme_image.txt - fix(prompts): use configurable repeat_threshold instead of hardcoded 3 - fix(memes): guard safe_int(0) → None for group_id sentinel value - fix(fake_at): normalize cached nicknames with NFKC (not just casefold) - fix(handlers): skip repeat counting for empty text messages Co-authored-by: Claude Opus 4.6 --- res/prompts/judge_meme_image.txt | 3 ++- src/Undefined/ai/prompts.py | 7 +++++-- src/Undefined/handlers.py | 2 +- src/Undefined/memes/service.py | 2 +- src/Undefined/utils/fake_at.py | 2 +- 5 files changed, 10 insertions(+), 6 deletions(-) diff --git a/res/prompts/judge_meme_image.txt b/res/prompts/judge_meme_image.txt index 5a3d02e..fdb480b 100644 --- 
a/res/prompts/judge_meme_image.txt +++ b/res/prompts/judge_meme_image.txt @@ -21,4 +21,5 @@ - 必须且只能调用 `submit_meme_judgement` - `is_meme` 仅表示“适不适合放进聊天表情包库”,不是“图里有没有梗” - `reason` 用一句简短中文说明依据 -- 只有当“整张图整体上就是一张可直接发送的表情包”时,才能给 `is_meme=true`- 如果收到的是一张网格图(多帧拼接)或多张图片,说明原图是动图/GIF:请综合所有帧判断这个动图整体是否适合作为表情包,不要只看单帧 \ No newline at end of file +- 只有当“整张图整体上就是一张可直接发送的表情包”时,才能给 `is_meme=true` +- 如果收到的是一张网格图(多帧拼接)或多张图片,说明原图是动图/GIF:请综合所有帧判断这个动图整体是否适合作为表情包,不要只看单帧 \ No newline at end of file diff --git a/src/Undefined/ai/prompts.py b/src/Undefined/ai/prompts.py index dda6b04..b1dd63b 100644 --- a/src/Undefined/ai/prompts.py +++ b/src/Undefined/ai/prompts.py @@ -222,7 +222,8 @@ def _build_model_config_info(self, runtime_config: Any) -> str: '关键词自动回复(触发词"心理委员"等,系统自动发送固定回复)' ) if repeat_enabled: - desc = "复读(群聊连续3条相同消息时自动复读)" + threshold = int(getattr(runtime_config, "repeat_threshold", 3)) + desc = f"复读(群聊连续{threshold}条相同消息时自动复读)" if inverted_question_enabled: desc += ",倒问号(复读触发时若消息为问号则发送¿)" easter_egg_parts.append(desc) @@ -342,6 +343,7 @@ async def build_messages( keyword_reply_enabled = False repeat_enabled = False + repeat_threshold = 3 inverted_question_enabled = False if self._runtime_config_getter is not None: try: @@ -350,6 +352,7 @@ async def build_messages( getattr(runtime_config, "keyword_reply_enabled", False) ) repeat_enabled = bool(getattr(runtime_config, "repeat_enabled", False)) + repeat_threshold = int(getattr(runtime_config, "repeat_threshold", 3)) inverted_question_enabled = bool( getattr(runtime_config, "inverted_question_enabled", False) ) @@ -375,7 +378,7 @@ async def build_messages( if is_group_context and repeat_enabled: repeat_desc = ( "【系统行为说明】\n" - "当前群聊已开启复读彩蛋:当群聊中连续出现3条内容相同且来自不同人的消息时," + f"当前群聊已开启复读彩蛋:当群聊中连续出现{repeat_threshold}条内容相同且来自不同人的消息时," "系统会自动复读一条相同的消息,并在历史中写入" '以"[系统复读] "开头的消息。' ) diff --git a/src/Undefined/handlers.py b/src/Undefined/handlers.py index 357e0c6..37d7e8f 100644 --- a/src/Undefined/handlers.py +++ b/src/Undefined/handlers.py @@ -837,7 
+837,7 @@ async def _fetch_group_name() -> str: return # 复读功能:连续 N 条相同消息(来自不同发送者)时复读,N = repeat_threshold - if self.config.repeat_enabled: + if self.config.repeat_enabled and text: n = self.config.repeat_threshold async with self._get_repeat_lock(group_id): counter = self._repeat_counter.setdefault(group_id, []) diff --git a/src/Undefined/memes/service.py b/src/Undefined/memes/service.py index 7712232..1159d0c 100644 --- a/src/Undefined/memes/service.py +++ b/src/Undefined/memes/service.py @@ -739,7 +739,7 @@ async def send_meme_by_uid(self, uid: str, context: dict[str, Any]) -> str: attachments=history_attachments, ) else: - preferred_temp_group_id = safe_int(context.get("group_id")) + preferred_temp_group_id = safe_int(context.get("group_id")) or None sent_message_id = await sender.send_private_message( int(target_id), cq_message, diff --git a/src/Undefined/utils/fake_at.py b/src/Undefined/utils/fake_at.py index 0ac0200..c3df1ae 100644 --- a/src/Undefined/utils/fake_at.py +++ b/src/Undefined/utils/fake_at.py @@ -163,7 +163,7 @@ async def _fetch(self, group_id: int) -> frozenset[str]: for key in ("card", "nickname"): val = str(info.get(key, "") or "").strip() if val: - names.add(val.casefold()) + names.add(_normalize(val)) except Exception as exc: logger.debug( "[假@] 获取 bot 群成员信息失败: group=%s err=%s", From 5f659a83eac39b9f8af46d3e513ded5972932f5a Mon Sep 17 00:00:00 2001 From: Null <1708213363@qq.com> Date: Sun, 19 Apr 2026 14:57:45 +0800 Subject: [PATCH 41/57] fix: address additional Devin review flags - Add 'pic' key to _MEDIA_LABELS for correct tag-based fallback label - Harden config backup path validation with backslash check and resolve()-based containment verification Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> Co-authored-by: Claude Opus 4.6 --- src/Undefined/attachments.py | 1 + src/Undefined/webui/routes/_config.py | 6 ++++-- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/src/Undefined/attachments.py 
b/src/Undefined/attachments.py index e8fa8fc..951a75b 100644 --- a/src/Undefined/attachments.py +++ b/src/Undefined/attachments.py @@ -47,6 +47,7 @@ "audio": "音频", "video": "视频", "record": "语音", + "pic": "图片", } _WINDOWS_ABS_PATH_RE = re.compile(r"^[A-Za-z]:[\\/]") _DEFAULT_REMOTE_TIMEOUT_SECONDS = 120.0 diff --git a/src/Undefined/webui/routes/_config.py b/src/Undefined/webui/routes/_config.py index c6c65b3..e3f44e4 100644 --- a/src/Undefined/webui/routes/_config.py +++ b/src/Undefined/webui/routes/_config.py @@ -239,9 +239,11 @@ async def config_restore_handler(request: web.Request) -> Response: except Exception: return web.json_response({"error": "Invalid JSON"}, status=400) name = str(data.get("name", "")) - if not name or ".." in name or "/" in name: + if not name or ".." in name or "/" in name or "\\" in name: + return web.json_response({"error": "Invalid backup name"}, status=400) + backup_path = (_BACKUP_DIR / name).resolve() + if not str(backup_path).startswith(str(_BACKUP_DIR.resolve())): return web.json_response({"error": "Invalid backup name"}, status=400) - backup_path = _BACKUP_DIR / name if not backup_path.exists(): return web.json_response({"error": "Backup not found"}, status=404) content = backup_path.read_text(encoding="utf-8") From 4c973b934ee679f5c1b987df303c2d6f6ce356e5 Mon Sep 17 00:00:00 2001 From: Null <1708213363@qq.com> Date: Sun, 19 Apr 2026 15:06:40 +0800 Subject: [PATCH 42/57] fix(config): parse gif_analysis_mode/gif_analysis_frames from TOML MemeConfig had these fields with defaults but _parse_memes_config never read them from the [memes] section, silently ignoring user configuration. 
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> Co-authored-by: Claude Opus 4.6 --- src/Undefined/config/domain_parsers.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/Undefined/config/domain_parsers.py b/src/Undefined/config/domain_parsers.py index ebb2222..07f807b 100644 --- a/src/Undefined/config/domain_parsers.py +++ b/src/Undefined/config/domain_parsers.py @@ -172,6 +172,8 @@ def _parse_memes_config(data: dict[str, Any]) -> MemeConfig: worker_max_concurrency=max( 1, _coerce_int(section.get("worker_max_concurrency"), 4) ), + gif_analysis_mode=_coerce_str(section.get("gif_analysis_mode"), "grid"), + gif_analysis_frames=max(1, _coerce_int(section.get("gif_analysis_frames"), 6)), ) From 589f949a84ad5ec67ac4a72f9fc9d86a22f9ac0b Mon Sep 17 00:00:00 2001 From: Null <1708213363@qq.com> Date: Sun, 19 Apr 2026 15:26:11 +0800 Subject: [PATCH 43/57] fix(memes): clean up GIF multi-frame analysis temp files _prepare_gif_multi_frames creates per-frame PNG files ({uid}_f{i}.png) for LLM analysis but they were never cleaned up. 
Add _cleanup_gif_frame_files helper and call it: - In _cleanup_meme_artifacts when uid is provided - In delete_meme - After judge/describe AI calls complete (frames no longer needed) Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> Co-authored-by: Claude Opus 4.6 --- src/Undefined/memes/service.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/src/Undefined/memes/service.py b/src/Undefined/memes/service.py index 1159d0c..9b02d41 100644 --- a/src/Undefined/memes/service.py +++ b/src/Undefined/memes/service.py @@ -438,6 +438,7 @@ async def delete_meme(self, uid: str) -> bool: self._delete_file_if_exists, Path(record.preview_path), ) + await asyncio.to_thread(self._cleanup_gif_frame_files, uid) return True def _delete_file_if_exists(self, path: Path) -> None: @@ -446,6 +447,17 @@ def _delete_file_if_exists(self, path: Path) -> None: except OSError: logger.debug("[memes] 删除文件失败: path=%s", path, exc_info=True) + def _cleanup_gif_frame_files(self, uid: str) -> None: + """清理 GIF 多帧分析产生的临时帧文件 ({uid}_f{i}.png)。""" + preview_dir = self._preview_dir() + for frame_file in preview_dir.glob(f"{uid}_f*.png"): + try: + frame_file.unlink(missing_ok=True) + except OSError: + logger.debug( + "[memes] 删除帧文件失败: path=%s", frame_file, exc_info=True + ) + async def _cleanup_meme_artifacts( self, *, @@ -472,6 +484,8 @@ async def _cleanup_meme_artifacts( await asyncio.to_thread(self._delete_file_if_exists, blob_path) if preview_path is not None and preview_path != blob_path: await asyncio.to_thread(self._delete_file_if_exists, preview_path) + if uid: + await asyncio.to_thread(self._cleanup_gif_frame_files, uid) async def search_memes( self, @@ -1025,6 +1039,8 @@ async def _process_ingest_job(self, job: Mapping[str, Any]) -> None: ) judgement = {"is_meme": False} if not bool(judgement.get("is_meme", False)): + if isinstance(analyze_path, list): + await asyncio.to_thread(self._cleanup_gif_frame_files, uid) await self._cleanup_meme_artifacts( 
uid=None, blob_path=blob_path, @@ -1041,6 +1057,9 @@ async def _process_ingest_job(self, job: Mapping[str, Any]) -> None: "[memes] describe stage failed, drop uid=%s err=%s", uid, exc ) described = {"description": "", "tags": []} + # GIF 多帧文件用完即清理 + if isinstance(analyze_path, list): + await asyncio.to_thread(self._cleanup_gif_frame_files, uid) tags = _normalize_tags(described.get("tags")) auto_description = str(described.get("description") or "").strip() if not auto_description and not tags: From b80277cc7fede94db543473a7995be75126e59f4 Mon Sep 17 00:00:00 2001 From: Null <1708213363@qq.com> Date: Sun, 19 Apr 2026 15:38:11 +0800 Subject: [PATCH 44/57] fix(attachments): skip prompt_ref append on render error _render_image_tag and _render_file_tag error paths returned early but attachments.append(record.prompt_ref()) still executed unconditionally. Change both helpers to return bool; only append on success. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> Co-authored-by: Claude Opus 4.6 --- src/Undefined/attachments.py | 21 ++++++++++++--------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/src/Undefined/attachments.py b/src/Undefined/attachments.py index 951a75b..403154d 100644 --- a/src/Undefined/attachments.py +++ b/src/Undefined/attachments.py @@ -1154,9 +1154,9 @@ async def render_message_with_attachments( # Route by media type if record.media_type == "image": - _render_image_tag(record, uid, strict, delivery_parts, history_parts) + ok = _render_image_tag(record, uid, strict, delivery_parts, history_parts) else: - _render_file_tag( + ok = _render_file_tag( record, uid, strict, @@ -1165,7 +1165,8 @@ async def render_message_with_attachments( pending_files, ) - attachments.append(record.prompt_ref()) + if ok: + attachments.append(record.prompt_ref()) delivery_parts.append(message[last_index:]) history_parts.append(message[last_index:]) @@ -1183,8 +1184,8 @@ def _render_image_tag( strict: bool, delivery_parts: list[str], 
history_parts: list[str], -) -> None: - """Render an image attachment as an inline CQ:image.""" +) -> bool: + """Render an image attachment as an inline CQ:image. Returns True on success.""" image_source = record.source_ref if record.local_path: image_source = Path(record.local_path).resolve().as_uri() @@ -1194,7 +1195,7 @@ def _render_image_tag( raise AttachmentRenderError(f"图片 UID 缺少可发送的文件:{uid}") delivery_parts.append(replacement) history_parts.append(replacement) - return + return False cq_args = [f"file={image_source}"] for key, value in dict(getattr(record, "segment_data", {}) or {}).items(): @@ -1210,6 +1211,7 @@ def _render_image_tag( history_parts.append(f"[图片 uid={uid} name={record.display_name}]") else: history_parts.append(f"[图片 uid={uid}]") + return True def _render_file_tag( @@ -1219,21 +1221,22 @@ def _render_file_tag( delivery_parts: list[str], history_parts: list[str], pending_files: list[AttachmentRecord], -) -> None: - """Render a non-image attachment as a pending file send.""" +) -> bool: + """Render a non-image attachment as a pending file send. Returns True on success.""" if not record.local_path or not Path(record.local_path).is_file(): replacement = f"[文件 uid={uid} 缺少本地文件]" if strict: raise AttachmentRenderError(f"文件 UID 缺少本地文件,无法发送:{uid}") delivery_parts.append(replacement) history_parts.append(replacement) - return + return False # Remove from delivery text (file sent separately) # Keep a readable placeholder in history name_part = f" name={record.display_name}" if record.display_name else "" history_parts.append(f"[文件 uid={uid}{name_part}]") pending_files.append(record) + return True # Backward-compatible alias From 58df7907771774f42ee08b1b9d6bcb3f478eead4 Mon Sep 17 00:00:00 2001 From: Null <1708213363@qq.com> Date: Sun, 19 Apr 2026 15:38:42 +0800 Subject: [PATCH 45/57] docs(prompts): strengthen naga_code_analysis_agent call guidance Reinforce that NagaAgent technical questions must call the agent before replying. 
Add explicit when_to_call scenarios and emphasize not relying on memory for frequently-updated project details. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> Co-authored-by: Claude Opus 4.6 --- res/prompts/undefined_nagaagent.xml | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/res/prompts/undefined_nagaagent.xml b/res/prompts/undefined_nagaagent.xml index acd3cc3..205b99f 100644 --- a/res/prompts/undefined_nagaagent.xml +++ b/res/prompts/undefined_nagaagent.xml @@ -305,8 +305,8 @@ 明确的 NagaAgent 技术问题或讨论 - 直接调用 naga_code_analysis_agent,确认相关性后再回复 - 如果只是泛泛提到naga但不是技术讨论,不要回复 + **必须**先调用 naga_code_analysis_agent 获取信息,再基于返回结果回复 + 如果只是泛泛提到naga但不是技术讨论,不要回复;但只要涉及技术细节,一定要先调 agent @@ -437,9 +437,20 @@ - 对于任何涉及 NagaAgent 的技术问题,直接调用 naga_code_analysis_agent 处理。 + 对于任何涉及 NagaAgent 的技术问题,**必须先调用 naga_code_analysis_agent 获取准确信息后再回复**。 + 不要依赖自身记忆或猜测来回答 NagaAgent 相关问题——该项目代码频繁更新,只有通过 agent 实时查阅才能保证准确。 该 Agent 内部拥有自己的工具集(read_naga_intro、read_file、search_file_content 等), 这些内部工具你无法直接调用,你只需要调用 naga_code_analysis_agent 即可。 + + + 以下场景必须调用 naga_code_analysis_agent: + - 用户询问 NagaAgent 的功能、配置、部署、构建方式 + - 用户遇到 NagaAgent 相关的报错或问题 + - 用户想了解 NagaAgent 的架构、代码逻辑、技能系统等 + - 用户提到 NagaAgent 的任何技术细节(API、openclaw、干员、技能等) + - 讨论涉及 NagaAgent 与其他系统的集成或对比 + 只有纯闲聊式提及(如"naga好用吗"这类不需要技术细节的对话)才可以不调用。 + From e02e0665d7a3b17d6d499cb3b3a95ca8e7f6c6f5 Mon Sep 17 00:00:00 2001 From: Null <1708213363@qq.com> Date: Sun, 19 Apr 2026 15:57:14 +0800 Subject: [PATCH 46/57] fix(prompt): clarify keyword auto-reply is code-path-only to prevent AI mimicry MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The AI model (kimi-k2.5) read the system prompt about [系统关键词自动回复] and mimicked the format via send_message tool, fabricating a reply that didn't exist in the codebase. 
Reworded the prompt to explicitly state these messages are generated by a separate code path (handlers.py), use fixed responses, and never go through the AI's tool calls. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> Co-authored-by: Claude Opus 4.6 --- src/Undefined/ai/prompts.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/Undefined/ai/prompts.py b/src/Undefined/ai/prompts.py index b1dd63b..e938181 100644 --- a/src/Undefined/ai/prompts.py +++ b/src/Undefined/ai/prompts.py @@ -364,11 +364,13 @@ async def build_messages( { "role": "system", "content": ( - "【系统行为说明】\n" + "【系统行为说明 — 关键词自动回复】\n" '当前群聊已开启关键词自动回复彩蛋(例如触发词"心理委员")。' - "命中时,系统可能直接发送固定回复,并在历史中写入" - '以"[系统关键词自动回复] "开头的消息。\n\n' - "这类消息属于系统预设机制,不代表你在该轮主动决策。" + "该功能由 handlers.py 中的独立代码路径处理," + "在消息到达你之前就已完成发送。\n\n" + '发送后,历史中会出现以"[系统关键词自动回复] "开头的消息。' + "这些消息完全由系统代码生成(固定文案如'受着''那咋了'等)," + "不经过你的工具调用,与你的决策无关。\n\n" "阅读历史时请识别该前缀,避免误判为人格漂移或上下文异常。" "除非用户主动询问,否则不要主动解释此机制。" ), From d7a294f69686c257a4adea161e3e5dbd59ede514 Mon Sep 17 00:00:00 2001 From: Null <1708213363@qq.com> Date: Sun, 19 Apr 2026 15:59:12 +0800 Subject: [PATCH 47/57] fix(repeat): don't silently drop messages when cooldown suppresses repeat Move the early 'return' inside the else branch (actual repeat sent) so that when cooldown suppresses a repeat, the message continues through downstream handlers (bilibili/arxiv/command/AI auto-reply) instead of being silently dropped. Update test_repeat_cooldown_suppresses_same_text to verify that ai_coordinator.handle_auto_reply is still called after suppression. 
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> Co-authored-by: Claude Opus 4.6 --- src/Undefined/handlers.py | 2 +- tests/test_handlers_repeat.py | 7 +++++-- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/src/Undefined/handlers.py b/src/Undefined/handlers.py index 37d7e8f..33d6a07 100644 --- a/src/Undefined/handlers.py +++ b/src/Undefined/handlers.py @@ -883,7 +883,7 @@ async def _fetch_group_name() -> str: reply_text, history_prefix=REPEAT_REPLY_HISTORY_PREFIX, ) - return + return # Bilibili 视频自动提取 if self.config.bilibili_auto_extract_enabled: diff --git a/tests/test_handlers_repeat.py b/tests/test_handlers_repeat.py index 4f9a830..c29219d 100644 --- a/tests/test_handlers_repeat.py +++ b/tests/test_handlers_repeat.py @@ -353,18 +353,21 @@ async def test_repeat_custom_threshold_4() -> None: @pytest.mark.asyncio async def test_repeat_cooldown_suppresses_same_text() -> None: - """复读触发后,同一内容在冷却期内再次满足条件也不触发。""" + """复读触发后,同一内容在冷却期内再次满足条件也不触发,但消息继续处理。""" handler = _build_handler(repeat_enabled=True, repeat_cooldown_minutes=60) # 第一轮:触发复读 for uid in [20001, 20002, 20003]: await handler.handle_message(_group_event(sender_id=uid, text="草")) assert handler.sender.send_group_message.call_count == 1 - # 第二轮:同一内容,应被冷却抑制 + # 第二轮:同一内容,应被冷却抑制——复读不发送 for uid in [20004, 20005, 20006]: await handler.handle_message(_group_event(sender_id=uid, text="草")) assert handler.sender.send_group_message.call_count == 1 # 不增加 + # 冷却抑制后消息应继续处理(到达 AI auto-reply),不应被 silently dropped + assert handler.ai_coordinator.handle_auto_reply.call_count > 0 + @pytest.mark.asyncio async def test_repeat_cooldown_allows_different_text() -> None: From 1daa53fbac985e97085a5f0c472ab9b5b39fbf44 Mon Sep 17 00:00:00 2001 From: Null <1708213363@qq.com> Date: Sun, 19 Apr 2026 16:24:31 +0800 Subject: [PATCH 48/57] fix(memes,repeat): address 3 Devin review bugs 1. 
Meme reanalyze now uses GIF multi-frame analysis (Flag #1): _process_reanalyze_job checks record.is_animated + gif_analysis_mode and calls _prepare_gif_multi_frames, matching the ingest code path. Frame files are cleaned up in all exit paths. 2. Repeat cooldown dict no longer grows unboundedly (Flag #17): _record_repeat_cooldown now prunes expired entries on each insert, preventing slow memory leak from accumulated unique texts per group. 3. GIF frame files cleaned up on retryable LLM errors (Flag #25): Both judge and describe stages in the ingest path now clean up multi-frame temp files before re-raising retryable exceptions. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> Co-authored-by: Claude Opus 4.6 --- src/Undefined/handlers.py | 13 +++++++++++-- src/Undefined/memes/service.py | 28 +++++++++++++++++++++++++++- 2 files changed, 38 insertions(+), 3 deletions(-) diff --git a/src/Undefined/handlers.py b/src/Undefined/handlers.py index 33d6a07..34fa9f6 100644 --- a/src/Undefined/handlers.py +++ b/src/Undefined/handlers.py @@ -166,10 +166,19 @@ def _is_repeat_on_cooldown(self, group_id: int, text: str) -> bool: return (time.monotonic() - last_time) < cooldown_minutes * 60 def _record_repeat_cooldown(self, group_id: int, text: str) -> None: - """记录复读冷却时间戳。""" + """记录复读冷却时间戳,同时清理已过期条目防止内存泄漏。""" key = self._normalize_repeat_text(text) group_cd = self._repeat_cooldown.setdefault(group_id, {}) - group_cd[key] = time.monotonic() + now = time.monotonic() + cooldown_seconds = self.config.repeat_cooldown_minutes * 60 + # 清理已过期条目 + if cooldown_seconds > 0: + expired = [ + k for k, ts in group_cd.items() if (now - ts) >= cooldown_seconds + ] + for k in expired: + del group_cd[k] + group_cd[key] = now async def _annotate_meme_descriptions( self, diff --git a/src/Undefined/memes/service.py b/src/Undefined/memes/service.py index 9b02d41..96f80ab 100644 --- a/src/Undefined/memes/service.py +++ b/src/Undefined/memes/service.py @@ -872,30 +872,52 @@ async def 
_process_reanalyze_job(self, job: Mapping[str, Any]) -> None: return if self._ai_client is None: raise RuntimeError("reanalyze requires ai_client") - analyze_path = record.preview_path if record.preview_path else record.blob_path + analyze_path: str | list[str] = ( + record.preview_path if record.preview_path else record.blob_path + ) + # GIF 多帧模式:与 ingest 路径保持一致 + if record.is_animated: + cfg = self._cfg() + if str(getattr(cfg, "gif_analysis_mode", "grid")).lower() == "multi": + analyze_path = await self._prepare_gif_multi_frames( + Path(record.blob_path), uid + ) try: judgement = await self._ai_client.judge_meme_image(analyze_path) except Exception as exc: if _is_retryable_llm_error(exc): + if isinstance(analyze_path, list): + await asyncio.to_thread(self._cleanup_gif_frame_files, uid) raise logger.exception( "[memes] judge stage failed during reanalyze: uid=%s err=%s", uid, exc ) + if isinstance(analyze_path, list): + await asyncio.to_thread(self._cleanup_gif_frame_files, uid) return if not bool(judgement.get("is_meme", False)): + if isinstance(analyze_path, list): + await asyncio.to_thread(self._cleanup_gif_frame_files, uid) await self.delete_meme(uid) return try: described = await self._ai_client.describe_meme_image(analyze_path) except Exception as exc: if _is_retryable_llm_error(exc): + if isinstance(analyze_path, list): + await asyncio.to_thread(self._cleanup_gif_frame_files, uid) raise logger.exception( "[memes] describe stage failed during reanalyze: uid=%s err=%s", uid, exc, ) + if isinstance(analyze_path, list): + await asyncio.to_thread(self._cleanup_gif_frame_files, uid) return + # GIF 多帧文件用完即清理 + if isinstance(analyze_path, list): + await asyncio.to_thread(self._cleanup_gif_frame_files, uid) auto_description = str(described.get("description") or "").strip() next_tags = _normalize_tags(described.get("tags")) if not auto_description and not next_tags: @@ -1031,6 +1053,8 @@ async def _process_ingest_job(self, job: Mapping[str, Any]) -> None: judgement = 
await self._ai_client.judge_meme_image(analyze_path) except Exception as exc: if _is_retryable_llm_error(exc): + if isinstance(analyze_path, list): + await asyncio.to_thread(self._cleanup_gif_frame_files, uid) raise logger.exception( "[memes] judge stage failed, treat as non-meme: uid=%s err=%s", @@ -1052,6 +1076,8 @@ async def _process_ingest_job(self, job: Mapping[str, Any]) -> None: described = await self._ai_client.describe_meme_image(analyze_path) except Exception as exc: if _is_retryable_llm_error(exc): + if isinstance(analyze_path, list): + await asyncio.to_thread(self._cleanup_gif_frame_files, uid) raise logger.exception( "[memes] describe stage failed, drop uid=%s err=%s", uid, exc From 60e723bad8d24c0780f68c1b1b3f1072a596c8e7 Mon Sep 17 00:00:00 2001 From: Null <1708213363@qq.com> Date: Sun, 19 Apr 2026 16:37:43 +0800 Subject: [PATCH 49/57] fix(latex): preserve LaTeX commands like \nu \nabla \neq during \n replacement The aggressive content.replace('\\n', '\n') destroyed any LaTeX command starting with \n (\nu, \nabla, \neq, \neg, etc). Use regex with negative converted to a real newline. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> Co-authored-by: Claude Opus 4.6 --- .../skills/toolsets/render/render_latex/handler.py | 4 ++-- tests/test_render_latex_tool.py | 13 ++++++++++--- 2 files changed, 12 insertions(+), 5 deletions(-) diff --git a/src/Undefined/skills/toolsets/render/render_latex/handler.py b/src/Undefined/skills/toolsets/render/render_latex/handler.py index 0637b9f..5da83f9 100644 --- a/src/Undefined/skills/toolsets/render/render_latex/handler.py +++ b/src/Undefined/skills/toolsets/render/render_latex/handler.py @@ -42,8 +42,8 @@ def _prepare_content(raw_content: str) -> str: 3. 如果没有数学分隔符,自动用 \\[ ... 
\\] 包装 """ content = _strip_document_wrappers(raw_content) - # 替换字面量 \\n 为真实换行符 - content = content.replace("\\n", "\n") + # 替换字面量 \\n 为真实换行符,但保留 LaTeX 命令如 \nu \nabla \neq 等 + content = re.sub(r"\\n(?![a-zA-Z])", "\n", content) if not _has_math_delimiters(content): # 没有分隔符,自动包装为块级数学环境 diff --git a/tests/test_render_latex_tool.py b/tests/test_render_latex_tool.py index d4d3e80..130991a 100644 --- a/tests/test_render_latex_tool.py +++ b/tests/test_render_latex_tool.py @@ -182,10 +182,17 @@ def test_prepare_content() -> None: result_with_delim = _prepare_content(r"\[ E = mc^2 \]") assert result_with_delim == r"\[ E = mc^2 \]" - # 字面量 \\n 处理 - result_newline = _prepare_content(r"x = 1\\ny = 2") + # 字面量 \\n 处理(后面不跟字母时替换为换行) + result_newline = _prepare_content("x = 1\\n2 = y") assert "\n" in result_newline - assert "\\n" not in result_newline.replace(r"\[", "").replace(r"\]", "") + assert "x = 1" in result_newline + assert "2 = y" in result_newline + + # LaTeX 命令不被破坏:\nu \nabla \neq 保持不变 + result_latex = _prepare_content(r"\nu + \nabla \neq 0") + assert r"\nu" in result_latex + assert r"\nabla" in result_latex + assert r"\neq" in result_latex def test_build_html_contains_mathjax_ready_flag() -> None: From 9f1c756e474995e15cead887730b4533a51a179a Mon Sep 17 00:00:00 2001 From: Null <1708213363@qq.com> Date: Sun, 19 Apr 2026 16:57:31 +0800 Subject: [PATCH 50/57] feat(vision): add configurable max_tokens to VisionModelConfig Previously vision model max_tokens was hardcoded (256/512/8192) at call sites. With thinking-enabled models like kimi-k2.5, the small budgets were entirely consumed by the thinking chain, leaving no room for tool-call output. 
- Add max_tokens field to VisionModelConfig (default 8192) - Parse from config.toml [models.vision] and VISION_MODEL_MAX_TOKENS env - Replace all hardcoded max_tokens in multimodal.py with config value - Update config.toml.example with documentation Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- config.toml.example | 3 +++ src/Undefined/ai/multimodal.py | 6 +++--- src/Undefined/config/model_parsers.py | 8 ++++++++ src/Undefined/config/models.py | 1 + 4 files changed, 15 insertions(+), 3 deletions(-) diff --git a/config.toml.example b/config.toml.example index c2a7fed..4448a04 100644 --- a/config.toml.example +++ b/config.toml.example @@ -152,6 +152,9 @@ api_key = "" # zh: Vision 模型名称。 # en: Vision model name. model_name = "" +# zh: Vision 模型最大输出 tokens。启用 thinking 时建议设大(如 8192),确保思维链消耗后仍有余量输出工具调用。 +# en: Vision model max output tokens. When thinking is enabled, use a larger value (e.g. 8192) so there is still room for tool-call output after thinking. +max_tokens = 8192 # zh: 队列发车间隔(秒,0 表示立即发车)。 # en: Queue interval (seconds; 0 dispatches immediately). 
queue_interval_seconds = 1.0 diff --git a/src/Undefined/ai/multimodal.py b/src/Undefined/ai/multimodal.py index f5f1ab2..e4dda36 100644 --- a/src/Undefined/ai/multimodal.py +++ b/src/Undefined/ai/multimodal.py @@ -696,7 +696,7 @@ async def analyze( result = await self._requester.request( model_config=self._vision_config, messages=[{"role": "user", "content": content_items}], - max_tokens=8192, + max_tokens=self._vision_config.max_tokens, call_type=f"vision_{detected_type}", ) content = extract_choices_content(result) @@ -861,7 +861,7 @@ async def judge_meme_image(self, image_url: str | list[str]) -> dict[str, Any]: tool_schema=_MEME_JUDGE_TOOL, tool_name="submit_meme_judgement", call_type="vision_meme_judge", - max_tokens=256, + max_tokens=self._vision_config.max_tokens, ) except Exception as exc: logger.exception("[媒体分析] 表情包判定失败,按非表情包处理: %s", exc) @@ -899,7 +899,7 @@ async def describe_meme_image(self, image_url: str | list[str]) -> dict[str, Any tool_schema=_MEME_DESCRIBE_TOOL, tool_name="submit_meme_description", call_type="vision_meme_describe", - max_tokens=512, + max_tokens=self._vision_config.max_tokens, ) except Exception as exc: logger.exception("[媒体分析] 表情包描述失败: %s", exc) diff --git a/src/Undefined/config/model_parsers.py b/src/Undefined/config/model_parsers.py index e17fc85..13e16c2 100644 --- a/src/Undefined/config/model_parsers.py +++ b/src/Undefined/config/model_parsers.py @@ -382,6 +382,14 @@ def _parse_vision_model_config(data: dict[str, Any]) -> VisionModelConfig: _get_value(data, ("models", "vision", "model_name"), "VISION_MODEL_NAME"), "", ), + max_tokens=_coerce_int( + _get_value( + data, + ("models", "vision", "max_tokens"), + "VISION_MODEL_MAX_TOKENS", + ), + 8192, + ), queue_interval_seconds=queue_interval_seconds, api_mode=api_mode, thinking_enabled=_coerce_bool( diff --git a/src/Undefined/config/models.py b/src/Undefined/config/models.py index 3190b38..118fb9a 100644 --- a/src/Undefined/config/models.py +++ b/src/Undefined/config/models.py 
@@ -92,6 +92,7 @@ class VisionModelConfig: api_url: str api_key: str model_name: str + max_tokens: int = 8192 # 最大输出 tokens queue_interval_seconds: float = 1.0 api_mode: str = "chat_completions" # 请求 API 模式 thinking_enabled: bool = False # 是否启用 thinking From 4a6c57c464dd10f3acda0338c8b372a13d8204f7 Mon Sep 17 00:00:00 2001 From: Null <1708213363@qq.com> Date: Sun, 19 Apr 2026 17:21:42 +0800 Subject: [PATCH 51/57] refactor(historian): use context_recent_messages_limit and XML format The historian now sees the same message count and XML format as the main AI, improving disambiguation quality: - Extract shared format_message_xml() / format_messages_xml() into utils/xml.py; deduplicate prompts.py and fetch_messages handler - Historian recent messages use context_recent_messages_limit (default 20) instead of historian_recent_messages_inject_k (was 12) - Messages formatted as XML (matching main AI) instead of plain text bullet list, including attachments and full metadata - Update historian_rewrite.md prompt to note XML format - Update tests for new format and import paths Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- res/prompts/historian_rewrite.md | 2 +- src/Undefined/ai/prompts.py | 56 +------------ src/Undefined/cognitive/historian.py | 2 +- .../tools/fetch_messages/handler.py | 68 +--------------- src/Undefined/skills/tools/end/handler.py | 61 +++++++------- src/Undefined/utils/xml.py | 81 ++++++++++++++++++- tests/test_end_tool.py | 12 +-- tests/test_fetch_messages_tool.py | 25 +++--- 8 files changed, 132 insertions(+), 175 deletions(-) diff --git a/res/prompts/historian_rewrite.md b/res/prompts/historian_rewrite.md index 72a8a2a..ac706d2 100644 --- a/res/prompts/historian_rewrite.md +++ b/res/prompts/historian_rewrite.md @@ -35,7 +35,7 @@ observations: {observations} 当前消息原文(触发本轮): {source_message} -最近消息参考(用于消歧,不要求逐字复述): +最近消息参考(XML 格式,与主对话一致,消息间以 `---` 分隔,用于消歧,不要求逐字复述): {recent_messages} 必须通过 `submit_rewrite` 工具提交结果,禁止输出普通文本内容。 diff 
--git a/src/Undefined/ai/prompts.py b/src/Undefined/ai/prompts.py index e938181..5a87c8d 100644 --- a/src/Undefined/ai/prompts.py +++ b/src/Undefined/ai/prompts.py @@ -11,7 +11,6 @@ import aiofiles -from Undefined.attachments import attachment_refs_to_xml from Undefined.utils.coerce import safe_int from Undefined.context import RequestContext from Undefined.end_summary_storage import ( @@ -23,7 +22,7 @@ from Undefined.skills.anthropic_skills import AnthropicSkillRegistry from Undefined.utils.logging import log_debug_json from Undefined.utils.resources import read_text_resource -from Undefined.utils.xml import escape_xml_attr, escape_xml_text +from Undefined.utils.xml import format_message_xml logger = logging.getLogger(__name__) @@ -725,58 +724,7 @@ async def _inject_recent_messages( recent_msgs = self._drop_current_message_if_duplicated( recent_msgs, question ) - context_lines: list[str] = [] - for msg in recent_msgs: - msg_type_val = msg.get("type", "group") - sender_name = msg.get("display_name", "未知用户") - sender_id = msg.get("user_id", "") - chat_id = msg.get("chat_id", "") - chat_name = msg.get("chat_name", "未知群聊") - timestamp = msg.get("timestamp", "") - text = msg.get("message", "") - attachments = msg.get("attachments", []) - role = msg.get("role", "member") - title = msg.get("title", "") - level = msg.get("level", "") - message_id = msg.get("message_id") - - safe_sender = escape_xml_attr(sender_name) - safe_sender_id = escape_xml_attr(sender_id) - safe_chat_id = escape_xml_attr(chat_id) - safe_chat_name = escape_xml_attr(chat_name) - safe_role = escape_xml_attr(role) - safe_title = escape_xml_attr(title) - safe_time = escape_xml_attr(timestamp) - safe_text = escape_xml_text(str(text)) - - msg_id_attr = "" - if message_id is not None: - msg_id_attr = f' message_id="{escape_xml_attr(str(message_id))}"' - attachment_xml = ( - f"\n{attachment_refs_to_xml(attachments)}" - if isinstance(attachments, list) and attachments - else "" - ) - - if msg_type_val == 
"group": - location = ( - chat_name if chat_name.endswith("群") else f"{chat_name}群" - ) - safe_location = escape_xml_attr(location) - level_attr = f' level="{escape_xml_attr(level)}"' if level else "" - xml_msg = ( - f'\n{safe_text}{attachment_xml}\n' - ) - else: - location = "私聊" - safe_location = escape_xml_attr(location) - xml_msg = ( - f'\n{safe_text}{attachment_xml}\n' - ) - context_lines.append(xml_msg) + context_lines: list[str] = [format_message_xml(msg) for msg in recent_msgs] formatted_context = "\n---\n".join(context_lines) diff --git a/src/Undefined/cognitive/historian.py b/src/Undefined/cognitive/historian.py index fde5d24..fda333c 100644 --- a/src/Undefined/cognitive/historian.py +++ b/src/Undefined/cognitive/historian.py @@ -461,7 +461,7 @@ async def _rewrite( recent_messages = [ str(item).strip() for item in recent_messages_raw if str(item).strip() ] - recent_messages_text = "\n".join(f"- {line}" for line in recent_messages) + recent_messages_text = "\n---\n".join(recent_messages) prompt = template.format( request_id=job.get("request_id", ""), end_seq=job.get("end_seq", 0), diff --git a/src/Undefined/skills/agents/summary_agent/tools/fetch_messages/handler.py b/src/Undefined/skills/agents/summary_agent/tools/fetch_messages/handler.py index 4426752..56be9c3 100644 --- a/src/Undefined/skills/agents/summary_agent/tools/fetch_messages/handler.py +++ b/src/Undefined/skills/agents/summary_agent/tools/fetch_messages/handler.py @@ -5,8 +5,7 @@ from datetime import datetime, timedelta from typing import Any -from Undefined.attachments import attachment_refs_to_xml -from Undefined.utils.xml import escape_xml_attr, escape_xml_text +from Undefined.utils.xml import format_messages_xml logger = logging.getLogger(__name__) @@ -58,12 +57,6 @@ def _filter_by_time( return result -def _format_message_location(msg_type_val: str, chat_name: str) -> str: - if msg_type_val == "group": - return chat_name if chat_name.endswith("群") else f"{chat_name}群" - return "私聊" - - def 
_normalize_messages_for_chat( messages: list[dict[str, Any]], *, @@ -83,63 +76,6 @@ def _normalize_messages_for_chat( return normalized -def _format_message_xml(msg: dict[str, Any]) -> str: - msg_type_val = str(msg.get("type", "group") or "group") - sender_name = str(msg.get("display_name", "未知用户") or "未知用户") - sender_id = str(msg.get("user_id", "") or "") - chat_id = str(msg.get("chat_id", "") or "") - chat_name = str(msg.get("chat_name", "未知群聊") or "未知群聊") - timestamp = str(msg.get("timestamp", "") or "") - text = str(msg.get("message", "") or "") - message_id = msg.get("message_id") - role = str(msg.get("role", "member") or "member") - title = str(msg.get("title", "") or "") - level = str(msg.get("level", "") or "") - attachments = msg.get("attachments", []) - - safe_sender = escape_xml_attr(sender_name) - safe_sender_id = escape_xml_attr(sender_id) - safe_chat_id = escape_xml_attr(chat_id) - safe_chat_name = escape_xml_attr(chat_name) - safe_role = escape_xml_attr(role) - safe_title = escape_xml_attr(title) - safe_time = escape_xml_attr(timestamp) - safe_text = escape_xml_text(text) - safe_location = escape_xml_attr(_format_message_location(msg_type_val, chat_name)) - - msg_id_attr = "" - if message_id is not None: - msg_id_attr = f' message_id="{escape_xml_attr(str(message_id))}"' - - attachment_xml = ( - f"\n{attachment_refs_to_xml(attachments)}" - if isinstance(attachments, list) and attachments - else "" - ) - - if msg_type_val == "group": - level_attr = f' level="{escape_xml_attr(level)}"' if level else "" - return ( - f'\n' - f"{safe_text}{attachment_xml}\n" - f"" - ) - - return ( - f'\n' - f"{safe_text}{attachment_xml}\n" - f"" - ) - - -def _format_messages(messages: list[dict[str, Any]]) -> str: - """Format messages into main-AI-compatible XML for the summary agent.""" - return "\n---\n".join(_format_message_xml(msg) for msg in messages) - - async def execute(args: dict[str, Any], context: dict[str, Any]) -> str: """拉取当前会话的聊天消息。""" history_manager = 
context.get("history_manager") @@ -189,7 +125,7 @@ async def execute(args: dict[str, Any], context: dict[str, Any]) -> str: messages, chat_type=chat_type, chat_id=chat_id ) - formatted = _format_messages(messages) + formatted = format_messages_xml(messages) total = len(messages) header = f"共获取 {total} 条消息" if time_range_str: diff --git a/src/Undefined/skills/tools/end/handler.py b/src/Undefined/skills/tools/end/handler.py index 8afb076..42a9883 100644 --- a/src/Undefined/skills/tools/end/handler.py +++ b/src/Undefined/skills/tools/end/handler.py @@ -7,6 +7,7 @@ from Undefined.context import RequestContext from Undefined.utils.coerce import safe_int +from Undefined.utils.xml import format_message_xml from Undefined.end_summary_storage import ( EndSummaryLocation, @@ -92,37 +93,38 @@ def _clamp_int(value: int, min_value: int, max_value: int) -> int: return value -def _resolve_historian_limits(context: Dict[str, Any]) -> tuple[int, int, int]: +def _resolve_historian_limits(context: Dict[str, Any]) -> tuple[int, int]: + """Return (max_source_len, recent_k) for historian context injection. + + ``recent_k`` is derived from ``context_recent_messages_limit`` (same as + main AI) so the historian sees the same message window. 
+ """ max_source_len = _DEFAULT_HISTORIAN_TEXT_LEN recent_k = _DEFAULT_HISTORIAN_LINES - max_recent_line_len = _DEFAULT_HISTORIAN_LINE_LEN runtime_config = context.get("runtime_config") + + # 消息数量:优先使用 context_recent_messages_limit(与主 AI 一致) + if runtime_config is not None and hasattr( + runtime_config, "get_context_recent_messages_limit" + ): + try: + recent_k = int(runtime_config.get_context_recent_messages_limit()) + except Exception: + pass + cognitive = getattr(runtime_config, "cognitive", None) if runtime_config else None if cognitive is not None: max_source_len = safe_int( getattr(cognitive, "historian_source_message_max_len", max_source_len), max_source_len, ) - recent_k = safe_int( - getattr(cognitive, "historian_recent_messages_inject_k", recent_k), - recent_k, - ) - max_recent_line_len = safe_int( - getattr( - cognitive, "historian_recent_message_line_max_len", max_recent_line_len - ), - max_recent_line_len, - ) max_source_len = _clamp_int( max_source_len, _MIN_HISTORIAN_TEXT_LEN, _MAX_HISTORIAN_TEXT_LEN ) recent_k = _clamp_int(recent_k, _MIN_HISTORIAN_LINES, _MAX_HISTORIAN_LINES) - max_recent_line_len = _clamp_int( - max_recent_line_len, _MIN_HISTORIAN_LINE_LEN, _MAX_HISTORIAN_LINE_LEN - ) - return max_source_len, recent_k, max_recent_line_len + return max_source_len, recent_k def _extract_current_content_from_question(question: str, *, max_len: int) -> str: @@ -136,8 +138,13 @@ def _extract_current_content_from_question(question: str, *, max_len: int) -> st def _build_historian_recent_messages( - context: Dict[str, Any], *, recent_k: int, max_line_len: int + context: Dict[str, Any], *, recent_k: int ) -> list[str]: + """Build XML-formatted recent messages for historian context. + + Uses the same XML schema as the main AI prompt so the historian LLM + sees identical message structure for better disambiguation. 
+ """ if recent_k <= 0: return [] @@ -173,24 +180,14 @@ def _build_historian_recent_messages( for msg in recent: if not isinstance(msg, dict): continue - timestamp = str(msg.get("timestamp", "")).strip() - display_name = str(msg.get("display_name", "")).strip() - user_id = str(msg.get("user_id", "")).strip() - message_text = _clip_text(msg.get("message", ""), max_line_len) - if not message_text: + if not str(msg.get("message", "") or "").strip(): continue - who = display_name or (f"UID:{user_id}" if user_id else "未知用户") - if user_id: - who = f"{who}({user_id})" - if timestamp: - lines.append(f"[{timestamp}] {who}: {message_text}") - else: - lines.append(f"{who}: {message_text}") + lines.append(format_message_xml(msg)) return lines[-recent_k:] def _inject_historian_reference_context(context: Dict[str, Any]) -> None: - max_source_len, recent_k, max_recent_line_len = _resolve_historian_limits(context) + max_source_len, recent_k = _resolve_historian_limits(context) current_question = str(context.get("current_question") or "").strip() source_message = _extract_current_content_from_question( current_question, max_len=max_source_len @@ -202,9 +199,7 @@ def _inject_historian_reference_context(context: Dict[str, Any]) -> None: current_question, max_source_len ) - recent_lines = _build_historian_recent_messages( - context, recent_k=recent_k, max_line_len=max_recent_line_len - ) + recent_lines = _build_historian_recent_messages(context, recent_k=recent_k) if recent_lines: context["historian_recent_messages"] = recent_lines diff --git a/src/Undefined/utils/xml.py b/src/Undefined/utils/xml.py index 865ac24..16d2415 100644 --- a/src/Undefined/utils/xml.py +++ b/src/Undefined/utils/xml.py @@ -1,7 +1,9 @@ -"""Minimal XML escaping helpers.""" +"""Minimal XML escaping helpers and message formatting.""" from __future__ import annotations +from typing import Any, Callable, Sequence, Mapping + from xml.sax.saxutils import escape @@ -12,3 +14,80 @@ def escape_xml_text(value: str) -> 
str: def escape_xml_attr(value: object) -> str: text = "" if value is None else str(value) return escape(text, {'"': """, "'": "'"}) + + +def _message_location(msg_type: str, chat_name: str) -> str: + """Derive the human-readable location label from message type.""" + if msg_type == "group": + return chat_name if chat_name.endswith("群") else f"{chat_name}群" + return "私聊" + + +def format_message_xml( + msg: dict[str, Any], + *, + attachment_formatter: (Callable[[Sequence[Mapping[str, str]]], str] | None) = None, +) -> str: + """Format a single history record dict into main-AI-compatible XML. + + ``attachment_formatter`` is an optional callable that turns the attachments + list into an XML fragment. When *None* (the default) a lazy import of + :func:`Undefined.attachments.attachment_refs_to_xml` is used so that + lightweight callers do not pay the import cost. + """ + msg_type_val = str(msg.get("type", "group") or "group") + sender_name = str(msg.get("display_name", "未知用户") or "未知用户") + sender_id = str(msg.get("user_id", "") or "") + chat_id = str(msg.get("chat_id", "") or "") + chat_name = str(msg.get("chat_name", "未知群聊") or "未知群聊") + timestamp = str(msg.get("timestamp", "") or "") + text = str(msg.get("message", "") or "") + message_id = msg.get("message_id") + role = str(msg.get("role", "member") or "member") + title = str(msg.get("title", "") or "") + level = str(msg.get("level", "") or "") + attachments = msg.get("attachments", []) + + safe_sender = escape_xml_attr(sender_name) + safe_sender_id = escape_xml_attr(sender_id) + safe_chat_id = escape_xml_attr(chat_id) + safe_chat_name = escape_xml_attr(chat_name) + safe_role = escape_xml_attr(role) + safe_title = escape_xml_attr(title) + safe_time = escape_xml_attr(timestamp) + safe_text = escape_xml_text(text) + safe_location = escape_xml_attr(_message_location(msg_type_val, chat_name)) + + msg_id_attr = "" + if message_id is not None: + msg_id_attr = f' message_id="{escape_xml_attr(str(message_id))}"' + + 
attachment_xml = "" + if isinstance(attachments, list) and attachments: + if attachment_formatter is None: + from Undefined.attachments import attachment_refs_to_xml + + attachment_formatter = attachment_refs_to_xml + attachment_xml = f"\n{attachment_formatter(attachments)}" + + if msg_type_val == "group": + level_attr = f' level="{escape_xml_attr(level)}"' if level else "" + return ( + f'\n' + f"{safe_text}{attachment_xml}\n" + f"" + ) + + return ( + f'\n' + f"{safe_text}{attachment_xml}\n" + f"" + ) + + +def format_messages_xml(messages: list[dict[str, Any]]) -> str: + """Format a list of history records into ``\\n---\\n``-separated XML.""" + return "\n---\n".join(format_message_xml(msg) for msg in messages) diff --git a/tests/test_end_tool.py b/tests/test_end_tool.py index 45e4106..4b65584 100644 --- a/tests/test_end_tool.py +++ b/tests/test_end_tool.py @@ -178,10 +178,9 @@ async def test_end_uses_runtime_config_for_historian_reference_limits() -> None: cognitive_service = _FakeCognitiveService() runtime_config = SimpleNamespace( cognitive=SimpleNamespace( - historian_recent_messages_inject_k=2, - historian_recent_message_line_max_len=60, historian_source_message_max_len=40, - ) + ), + get_context_recent_messages_limit=lambda: 2, ) long_content = "A" * 300 context: dict[str, Any] = { @@ -207,6 +206,7 @@ async def test_end_uses_runtime_config_for_historian_reference_limits() -> None: assert len(source) <= 40 assert isinstance(recent, list) assert len(recent) == 2 - assert all( - len(str(line).split(": ", 1)[1]) <= 60 for line in recent if ": " in str(line) - ) + # Recent messages now use XML format (same as main AI) + for line in recent: + assert "" in str(line) diff --git a/tests/test_fetch_messages_tool.py b/tests/test_fetch_messages_tool.py index 1804c72..65fff3d 100644 --- a/tests/test_fetch_messages_tool.py +++ b/tests/test_fetch_messages_tool.py @@ -8,11 +8,10 @@ from Undefined.skills.agents.summary_agent.tools.fetch_messages.handler import ( 
_filter_by_time, - _format_messages, - _format_message_xml, _parse_time_range, execute as fetch_messages_execute, ) +from Undefined.utils.xml import format_message_xml, format_messages_xml # -- _parse_time_range unit tests -- @@ -124,10 +123,10 @@ def test_filter_by_time_invalid_timestamp() -> None: assert len(result) == 0 -# -- _format_messages unit tests -- +# -- format_messages_xml unit tests -- -def test_format_message_xml_group_basic() -> None: +def testformat_message_xml_group_basic() -> None: """Group message is formatted into main-AI-compatible XML.""" messages = [ { @@ -145,7 +144,7 @@ def test_format_message_xml_group_basic() -> None: }, ] - result = _format_message_xml(messages[0]) + result = format_message_xml(messages[0]) assert 'message_id="123"' in result assert 'sender="Alice"' in result assert 'sender_id="10001"' in result @@ -158,7 +157,7 @@ def test_format_message_xml_group_basic() -> None: assert "Hello" in result -def test_format_message_xml_private_basic() -> None: +def testformat_message_xml_private_basic() -> None: """Private message uses the private XML shape.""" msg = { "type": "private", @@ -169,7 +168,7 @@ def test_format_message_xml_private_basic() -> None: "message_id": 456, } - result = _format_message_xml(msg) + result = format_message_xml(msg) assert 'message_id="456"' in result assert 'sender="Bob"' in result assert 'sender_id="10002"' in result @@ -179,7 +178,7 @@ def test_format_message_xml_private_basic() -> None: assert "Hi" in result -def test_format_message_xml_includes_attachments() -> None: +def testformat_message_xml_includes_attachments() -> None: """Attachment refs are rendered as XML below content.""" msg = { "type": "group", @@ -200,14 +199,14 @@ def test_format_message_xml_includes_attachments() -> None: ], } - result = _format_message_xml(msg) + result = format_message_xml(msg) assert "" in result assert 'uid="pic_abcd1234"' in result assert 'type="image"' in result assert 'description="截图"' in result -def 
test_format_messages_multiple() -> None: +def testformat_messages_xml_multiple() -> None: """Multiple messages are separated by main-AI-style delimiters.""" messages = [ { @@ -228,12 +227,12 @@ def test_format_messages_multiple() -> None: }, ] - result = _format_messages(messages) + result = format_messages_xml(messages) assert "\n---\n" in result assert result.count(" None: +def testformat_messages_xml_missing_fields() -> None: """Missing fields still produce valid XML.""" messages = [ { @@ -242,7 +241,7 @@ def test_format_messages_missing_fields() -> None: }, ] - result = _format_messages(messages) + result = format_messages_xml(messages) assert "未知用户" in result assert "No timestamp" in result assert " Date: Sun, 19 Apr 2026 17:44:49 +0800 Subject: [PATCH 52/57] fix(webui): cap top_k overflow, debounce meme search, reorder dashboard MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Cap top_k to 500 in frontend (appendPositiveIntParam) and backend (cognitive service + vector_store _safe_positive_int) to prevent ChromaDB OverflowError when large integers are passed - Add max=500 to all top_k HTML inputs - Cap fetch_k to 10000 in vector_store._query() as safety net - Add debounced auto-search (350ms) for meme text inputs with pending-refresh pattern to avoid stale results - Enter key flushes debounce timer for instant search - Move 运行环境 card before 资源趋势 chart in dashboard layout Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- src/Undefined/cognitive/service.py | 3 ++ src/Undefined/cognitive/vector_store.py | 7 +++- src/Undefined/webui/static/js/memes.js | 47 +++++++++++++++++++----- src/Undefined/webui/static/js/runtime.js | 4 +- src/Undefined/webui/templates/index.html | 24 ++++++------ 5 files changed, 60 insertions(+), 25 deletions(-) diff --git a/src/Undefined/cognitive/service.py b/src/Undefined/cognitive/service.py index 1febe85..13cf3f1 100644 --- a/src/Undefined/cognitive/service.py +++ 
b/src/Undefined/cognitive/service.py @@ -725,6 +725,7 @@ async def build_context( top_k = default_top_k if top_k <= 0: top_k = default_top_k + top_k = min(top_k, 500) try: events = await self._query_events_for_auto_context( query=query, @@ -819,6 +820,7 @@ async def search_events(self, query: str, **kwargs: Any) -> list[dict[str, Any]] top_k = default_top_k if top_k <= 0: top_k = default_top_k + top_k = min(top_k, 500) logger.info( "[认知服务] 搜索事件: query_len=%s top_k=%s where=%s time_from=%s time_to=%s", len(query or ""), @@ -872,6 +874,7 @@ async def search_profiles(self, query: str, **kwargs: Any) -> list[dict[str, Any top_k = default_top_k if top_k <= 0: top_k = default_top_k + top_k = min(top_k, 500) where: dict[str, Any] | None = None entity_type_raw = kwargs.get("entity_type") diff --git a/src/Undefined/cognitive/vector_store.py b/src/Undefined/cognitive/vector_store.py index 0c84f5a..48b80eb 100644 --- a/src/Undefined/cognitive/vector_store.py +++ b/src/Undefined/cognitive/vector_store.py @@ -34,13 +34,15 @@ def _clamp(value: float, lower: float, upper: float) -> float: return value -def _safe_positive_int(value: Any, default: int) -> int: +def _safe_positive_int(value: Any, default: int, maximum: int = 0) -> int: try: parsed = int(value) except Exception: return max(1, int(default)) if parsed <= 0: return max(1, int(default)) + if maximum > 0 and parsed > maximum: + return maximum return parsed @@ -427,7 +429,7 @@ async def _query( query_embedding: list[float] | None = None, ) -> list[dict[str, Any]]: col_name = getattr(col, "name", "unknown") - safe_top_k = _safe_positive_int(top_k, default=1) + safe_top_k = _safe_positive_int(top_k, default=1, maximum=500) safe_multiplier = _safe_positive_int(candidate_multiplier, default=1) total_started = time.perf_counter() logger.debug( @@ -455,6 +457,7 @@ async def _query( use_reranker or apply_time_decay or apply_mmr ) fetch_k = safe_top_k * safe_multiplier if use_extra_candidates else safe_top_k + fetch_k = 
min(fetch_k, 10000) include: list[str] = ["documents", "metadatas", "distances"] if apply_mmr: include.append("embeddings") diff --git a/src/Undefined/webui/static/js/memes.js b/src/Undefined/webui/static/js/memes.js index 0afef67..465c3cd 100644 --- a/src/Undefined/webui/static/js/memes.js +++ b/src/Undefined/webui/static/js/memes.js @@ -314,6 +314,27 @@ return payload; } + let _debounceTimer = null; + let _pendingRefresh = false; + + function debouncedFetchList(delayMs = 350) { + if (_debounceTimer !== null) { + clearTimeout(_debounceTimer); + } + _debounceTimer = setTimeout(() => { + _debounceTimer = null; + fetchList().catch(showError); + }, delayMs); + } + + function flushDebouncedFetchList() { + if (_debounceTimer !== null) { + clearTimeout(_debounceTimer); + _debounceTimer = null; + } + fetchList().catch(showError); + } + async function fetchList(options = {}) { const append = !!options.append; if (append) { @@ -321,6 +342,7 @@ return null; } } else if (state.loading) { + _pendingRefresh = true; return null; } @@ -415,6 +437,10 @@ } renderLoadMore(); } + if (!append && _pendingRefresh) { + _pendingRefresh = false; + fetchList().catch(showError); + } } } @@ -558,7 +584,7 @@ state.initialized = true; get("btnMemesRefresh")?.addEventListener("click", refreshAll); get("btnMemesSearch")?.addEventListener("click", () => { - fetchList().catch(showError); + flushDebouncedFetchList(); }); get("btnMemesLoadMore")?.addEventListener("click", () => { fetchList({ append: true }).catch(showError); @@ -579,14 +605,17 @@ get("btnMemesDelete")?.addEventListener("click", () => { deleteSelected().catch(showError); }); - bindEnter("memesSearchInput", () => { - fetchList().catch(showError); - }); - bindEnter("memesKeywordQuery", () => { - fetchList().catch(showError); - }); - bindEnter("memesSemanticQuery", () => { - fetchList().catch(showError); + bindEnter("memesSearchInput", flushDebouncedFetchList); + bindEnter("memesKeywordQuery", flushDebouncedFetchList); + 
bindEnter("memesSemanticQuery", flushDebouncedFetchList); + [ + "memesSearchInput", + "memesKeywordQuery", + "memesSemanticQuery", + ].forEach((id) => { + const el = get(id); + if (el) + el.addEventListener("input", () => debouncedFetchList()); }); [ "memesQueryMode", diff --git a/src/Undefined/webui/static/js/runtime.js b/src/Undefined/webui/static/js/runtime.js index 42655a5..c1e1282 100644 --- a/src/Undefined/webui/static/js/runtime.js +++ b/src/Undefined/webui/static/js/runtime.js @@ -303,12 +303,12 @@ params.set(key, text); } - function appendPositiveIntParam(params, key, value) { + function appendPositiveIntParam(params, key, value, max = 500) { const text = String(value || "").trim(); if (!text) return; const num = Number.parseInt(text, 10); if (!Number.isFinite(num) || num <= 0) return; - params.set(key, String(num)); + params.set(key, String(Math.min(num, max))); } function formatNumeric(value, digits = 4) { diff --git a/src/Undefined/webui/templates/index.html b/src/Undefined/webui/templates/index.html index 214ee63..a306720 100644 --- a/src/Undefined/webui/templates/index.html +++ b/src/Undefined/webui/templates/index.html @@ -299,15 +299,6 @@

运行概览

-
-
资源趋势
- -
- CPU - Memory -
-
-
运行环境
@@ -323,6 +314,15 @@

运行概览

--
+ +
+
资源趋势
+ +
+ CPU + Memory +
+
@@ -492,7 +492,7 @@

记忆检索

- @@ -571,7 +571,7 @@

记忆检索

- From c347f6d744e6a200bc4a16befc445dacff86bad0 Mon Sep 17 00:00:00 2001 From: Null <1708213363@qq.com> Date: Sun, 19 Apr 2026 17:56:05 +0800 Subject: [PATCH 53/57] fix(hot_reload,repeat): address 2 Devin review bugs - Add historian_model to hot_reload tracking sets (_QUEUE_INTERVAL_KEYS and _MODEL_NAME_KEYS) so config changes take effect without restart - Fix memory leak in _record_repeat_cooldown: skip recording entirely when cooldown_minutes=0 instead of accumulating never-evicted entries - Add test assertion verifying no cooldown entries when disabled Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- src/Undefined/config/hot_reload.py | 2 ++ src/Undefined/handlers.py | 13 ++++++------- tests/test_handlers_repeat.py | 2 ++ 3 files changed, 10 insertions(+), 7 deletions(-) diff --git a/src/Undefined/config/hot_reload.py b/src/Undefined/config/hot_reload.py index f603524..f0d83c9 100644 --- a/src/Undefined/config/hot_reload.py +++ b/src/Undefined/config/hot_reload.py @@ -43,6 +43,7 @@ "naga_model.queue_interval_seconds", "agent_model.queue_interval_seconds", "summary_model.queue_interval_seconds", + "historian_model.queue_interval_seconds", "grok_model.queue_interval_seconds", "chat_model.pool", "agent_model.pool", @@ -55,6 +56,7 @@ "naga_model.model_name", "agent_model.model_name", "summary_model.model_name", + "historian_model.model_name", "grok_model.model_name", } diff --git a/src/Undefined/handlers.py b/src/Undefined/handlers.py index 34fa9f6..9d5bfcb 100644 --- a/src/Undefined/handlers.py +++ b/src/Undefined/handlers.py @@ -167,17 +167,16 @@ def _is_repeat_on_cooldown(self, group_id: int, text: str) -> bool: def _record_repeat_cooldown(self, group_id: int, text: str) -> None: """记录复读冷却时间戳,同时清理已过期条目防止内存泄漏。""" + cooldown_seconds = self.config.repeat_cooldown_minutes * 60 + if cooldown_seconds <= 0: + return key = self._normalize_repeat_text(text) group_cd = self._repeat_cooldown.setdefault(group_id, {}) now = time.monotonic() - 
cooldown_seconds = self.config.repeat_cooldown_minutes * 60 # 清理已过期条目 - if cooldown_seconds > 0: - expired = [ - k for k, ts in group_cd.items() if (now - ts) >= cooldown_seconds - ] - for k in expired: - del group_cd[k] + expired = [k for k, ts in group_cd.items() if (now - ts) >= cooldown_seconds] + for k in expired: + del group_cd[k] group_cd[key] = now async def _annotate_meme_descriptions( diff --git a/tests/test_handlers_repeat.py b/tests/test_handlers_repeat.py index c29219d..c305b6b 100644 --- a/tests/test_handlers_repeat.py +++ b/tests/test_handlers_repeat.py @@ -417,6 +417,8 @@ async def test_repeat_cooldown_zero_disables() -> None: for uid in [20004, 20005, 20006]: await handler.handle_message(_group_event(sender_id=uid, text="草")) assert handler.sender.send_group_message.call_count == 2 + # cooldown=0 不应写入任何冷却记录(防止内存泄漏) + assert len(handler._repeat_cooldown) == 0 @pytest.mark.asyncio From ce5db0f46a3d70cfa37539be57935746929f0404 Mon Sep 17 00:00:00 2001 From: Null <1708213363@qq.com> Date: Sun, 19 Apr 2026 18:03:46 +0800 Subject: [PATCH 54/57] docs(changelog): update v3.3.2 with all session fixes and improvements Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- CHANGELOG.md | 35 ++++++++++++++++++++++++++++++----- 1 file changed, 30 insertions(+), 5 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 83f2f07..ebe9f48 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,19 +2,44 @@ 围绕核心架构进行了大规模重构与功能增强:Runtime API 拆分为路由子模块、配置系统模块化拆分、新增假@检测机制与 /profile 多输出模式。同步引入复读机制全面升级(可配置阈值与冷却)、消息预处理并行化、WebUI 多项交互功能,以及 arXiv 论文分析 Agent 和安全计算器工具。测试覆盖从约 800 提升至 1438+。 +### 新功能 + - 新增假@检测:群聊中 `@+Bot昵称` 的文本形式也被识别为@消息,自动从群上下文获取昵称(防竞态),`@昵称 /命令` 可正常触发斜杠指令。 - `/profile` 命令支持三种输出模式:`-f` 合并转发(默认,分元数据与内容两条消息)、`-r` 渲染为图片、`-t` 直接文本发送。 - 超级管理员可通过 `/p ` 和 `/p g <群号>` 跨目标查看任意用户或群聊的认知侧写。 - 复读系统全面升级:触发阈值可配置(`repeat_threshold`,2–20)、Bot 发言不计入复读链、新增复读冷却机制(`repeat_cooldown_minutes`,默认 60 分钟,?与 ? 
等价)。 -- Runtime API 从 2491 行的单体 `app.py` 拆分为 8 个路由子模块 (`api/routes/`),主文件仅保留薄包装委派层。 -- 配置系统模块化拆分:`config/` 拆为 `loader.py`、`models.py`、`hot_reload.py`,`sync_config_template` 脚本支持报告注释变更路径。 -- 消息预处理流程并行化:使用 `asyncio.gather` 并行执行安全检查、认知检索和假@检测,降低消息处理延迟。 +- Vision 模型 `max_tokens` 可配置(`[models.vision].max_tokens`,默认 8192),解决 thinking 模型消耗全部 token 导致工具调用截断的问题。 - 新增 arXiv 论文深度分析 Agent,提供论文搜索、摘要提取与关键信息分析能力。 - 新增 `calculator` 多功能安全计算器工具。 - 新增消息历史限制全面可配置化(`[history].max_records`)。 -- 新增 `utils/coerce.py`(安全类型强转)与 `utils/fake_at.py`(假@文本检测与解析)公共模块。 -- WebUI 新增功能:Cmd/Ctrl+K 命令面板、骨架屏加载态、日志时间过滤、资源趋势图、TOML 原始视图、配置版本历史与回滚、长期记忆完整 CRUD 管理、Modal 焦点陷阱。 +- 新增 `utils/coerce.py`(安全类型强转)、`utils/fake_at.py`(假@检测与解析)与 `utils/xml.py`(统一 XML 消息格式化)公共模块。 + +### 架构重构 + +- Runtime API 从 2491 行的单体 `app.py` 拆分为 8 个路由子模块 (`api/routes/`),主文件仅保留薄包装委派层。 +- 配置系统模块化拆分:`config/` 拆为 `loader.py`、`models.py`、`hot_reload.py`,`sync_config_template` 脚本支持报告注释变更路径。 +- 消息预处理流程并行化:使用 `asyncio.gather` 并行执行安全检查、认知检索和假@检测,降低消息处理延迟。 +- 认知史官(historian)消息格式从纯文本改为 XML,与主 AI 上下文格式统一;消息数量改用 `context_recent_messages_limit`(默认 20)。 +- 统一 XML 消息格式化逻辑至 `utils/xml.py`,消除 `prompts.py`、`fetch_messages` 和 `end` tool 之间的重复代码。 + +### WebUI + +- Cmd/Ctrl+K 命令面板、骨架屏加载态、日志时间过滤、资源趋势图、TOML 原始视图、配置版本历史与回滚、长期记忆完整 CRUD 管理、Modal 焦点陷阱。 +- 修复 top_k 参数溢出导致 ChromaDB Rust 整数越界崩溃的问题,前端/后端四层防护(HTML max 属性、JS 参数上限、Service 层 clamp、向量存储 fetch_k 硬上限)。 +- 表情包搜索页新增 350ms 防抖自动搜索,Enter 立即刷新,`_pendingRefresh` 模式防止并发请求返回过期结果。 +- 仪表盘布局优化:三张信息卡共占一行,资源趋势图全宽置底。 + +### Bug 修复 + +- 修复 AI 模仿系统关键词自动回复前缀生成伪系统消息的问题,在提示词中明确标注该前缀仅由代码路径使用。 +- 修复 LaTeX 渲染中 `\n` 替换破坏 `\nu`、`\nabla`、`\neq` 等命令的问题,改用负向前瞻正则 `\\n(?![a-zA-Z])`。 +- 修复复读冷却抑制后错误丢弃消息(不再触发自动回复)的问题。 +- 修复 `repeat_cooldown_minutes=0` 时 `_repeat_cooldown` 字典持续积累不被清理的内存泄漏。 +- 修复 `historian_model` 未加入热重载追踪集导致配置热更新后队列间隔和模型名失效的问题。 - 修复队列系统 historian 模型未注册的调度问题。 +- 修复表情包 GIF 多帧分析 `_process_reanalyze_job` 未检查 `gif_analysis_mode` 的问题。 +- 修复表情包 GIF 帧临时文件在可重试 LLM 错误时未被清理的资源泄漏。 +- 修复附件渲染错误时仍追加 `prompt_ref` 导致后续异常的问题。 - 修复 
`/profile` 渲染留白和字体过小问题,使用 WebUI 配色方案并提高截断上限至 5000 字符。 - 测试覆盖大幅补齐(804 → 1438+),ruff + mypy 零错误。 From 0b009db57d79b3d146138f46213631945dda5578 Mon Sep 17 00:00:00 2001 From: Null <1708213363@qq.com> Date: Sun, 19 Apr 2026 18:07:01 +0800 Subject: [PATCH 55/57] docs(changelog): simplify v3.3.2 format to match prior versions Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- CHANGELOG.md | 51 +++++++++++---------------------------------------- 1 file changed, 11 insertions(+), 40 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index ebe9f48..5db78b0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,46 +2,17 @@ 围绕核心架构进行了大规模重构与功能增强:Runtime API 拆分为路由子模块、配置系统模块化拆分、新增假@检测机制与 /profile 多输出模式。同步引入复读机制全面升级(可配置阈值与冷却)、消息预处理并行化、WebUI 多项交互功能,以及 arXiv 论文分析 Agent 和安全计算器工具。测试覆盖从约 800 提升至 1438+。 -### 新功能 - -- 新增假@检测:群聊中 `@+Bot昵称` 的文本形式也被识别为@消息,自动从群上下文获取昵称(防竞态),`@昵称 /命令` 可正常触发斜杠指令。 -- `/profile` 命令支持三种输出模式:`-f` 合并转发(默认,分元数据与内容两条消息)、`-r` 渲染为图片、`-t` 直接文本发送。 -- 超级管理员可通过 `/p ` 和 `/p g <群号>` 跨目标查看任意用户或群聊的认知侧写。 -- 复读系统全面升级:触发阈值可配置(`repeat_threshold`,2–20)、Bot 发言不计入复读链、新增复读冷却机制(`repeat_cooldown_minutes`,默认 60 分钟,?与 ? 
等价)。 -- Vision 模型 `max_tokens` 可配置(`[models.vision].max_tokens`,默认 8192),解决 thinking 模型消耗全部 token 导致工具调用截断的问题。 -- 新增 arXiv 论文深度分析 Agent,提供论文搜索、摘要提取与关键信息分析能力。 -- 新增 `calculator` 多功能安全计算器工具。 -- 新增消息历史限制全面可配置化(`[history].max_records`)。 -- 新增 `utils/coerce.py`(安全类型强转)、`utils/fake_at.py`(假@检测与解析)与 `utils/xml.py`(统一 XML 消息格式化)公共模块。 - -### 架构重构 - -- Runtime API 从 2491 行的单体 `app.py` 拆分为 8 个路由子模块 (`api/routes/`),主文件仅保留薄包装委派层。 -- 配置系统模块化拆分:`config/` 拆为 `loader.py`、`models.py`、`hot_reload.py`,`sync_config_template` 脚本支持报告注释变更路径。 -- 消息预处理流程并行化:使用 `asyncio.gather` 并行执行安全检查、认知检索和假@检测,降低消息处理延迟。 -- 认知史官(historian)消息格式从纯文本改为 XML,与主 AI 上下文格式统一;消息数量改用 `context_recent_messages_limit`(默认 20)。 -- 统一 XML 消息格式化逻辑至 `utils/xml.py`,消除 `prompts.py`、`fetch_messages` 和 `end` tool 之间的重复代码。 - -### WebUI - -- Cmd/Ctrl+K 命令面板、骨架屏加载态、日志时间过滤、资源趋势图、TOML 原始视图、配置版本历史与回滚、长期记忆完整 CRUD 管理、Modal 焦点陷阱。 -- 修复 top_k 参数溢出导致 ChromaDB Rust 整数越界崩溃的问题,前端/后端四层防护(HTML max 属性、JS 参数上限、Service 层 clamp、向量存储 fetch_k 硬上限)。 -- 表情包搜索页新增 350ms 防抖自动搜索,Enter 立即刷新,`_pendingRefresh` 模式防止并发请求返回过期结果。 -- 仪表盘布局优化:三张信息卡共占一行,资源趋势图全宽置底。 - -### Bug 修复 - -- 修复 AI 模仿系统关键词自动回复前缀生成伪系统消息的问题,在提示词中明确标注该前缀仅由代码路径使用。 -- 修复 LaTeX 渲染中 `\n` 替换破坏 `\nu`、`\nabla`、`\neq` 等命令的问题,改用负向前瞻正则 `\\n(?![a-zA-Z])`。 -- 修复复读冷却抑制后错误丢弃消息(不再触发自动回复)的问题。 -- 修复 `repeat_cooldown_minutes=0` 时 `_repeat_cooldown` 字典持续积累不被清理的内存泄漏。 -- 修复 `historian_model` 未加入热重载追踪集导致配置热更新后队列间隔和模型名失效的问题。 -- 修复队列系统 historian 模型未注册的调度问题。 -- 修复表情包 GIF 多帧分析 `_process_reanalyze_job` 未检查 `gif_analysis_mode` 的问题。 -- 修复表情包 GIF 帧临时文件在可重试 LLM 错误时未被清理的资源泄漏。 -- 修复附件渲染错误时仍追加 `prompt_ref` 导致后续异常的问题。 -- 修复 `/profile` 渲染留白和字体过小问题,使用 WebUI 配色方案并提高截断上限至 5000 字符。 -- 测试覆盖大幅补齐(804 → 1438+),ruff + mypy 零错误。 +- 新增假@检测:群聊中 `@+Bot昵称` 的文本形式也被识别为@消息,自动获取昵称(防竞态),`@昵称 /命令` 可正常触发斜杠指令。 +- `/profile` 命令支持三种输出模式:`-f` 合并转发(默认)、`-r` 渲染为图片、`-t` 直接文本发送;超管可通过 `/p ` 和 `/p g <群号>` 跨目标查看。 +- 复读系统全面升级:触发阈值可配置(`repeat_threshold`)、Bot 发言不计入复读链、新增冷却机制(`repeat_cooldown_minutes`,?与 ? 
等价)。 +- Vision 模型 `max_tokens` 可配置(`[models.vision].max_tokens`,默认 8192),解决 thinking 模型 token 耗尽导致工具调用截断。 +- 新增 arXiv 论文深度分析 Agent 与 `calculator` 安全计算器工具;新增消息历史限制可配置化(`[history].max_records`)。 +- Runtime API 拆分为 8 个路由子模块(`api/routes/`);配置系统拆为 `loader.py`、`models.py`、`hot_reload.py`。 +- 消息预处理并行化(`asyncio.gather`),认知史官改用 XML 格式并统一至 `utils/xml.py`。 +- WebUI:Cmd/Ctrl+K 命令面板、骨架屏、日志时间过滤、资源趋势图、TOML 原始视图、配置版本历史与回滚、长期记忆 CRUD、Modal 焦点陷阱。 +- WebUI 修复 top_k 溢出崩溃(四层防护)、表情包搜索防抖、仪表盘布局优化。 +- 修复 AI 模仿系统关键词自动回复前缀、LaTeX `\n` 替换破坏数学命令、复读冷却消息丢弃与内存泄漏、historian 热重载追踪缺失、GIF 多帧分析与临时文件泄漏、附件渲染异常等多项问题。 +- 修复 `/profile` 渲染留白和字体过小问题,使用 WebUI 配色并提高截断上限至 5000 字符。 --- From a2b8b2669e0e9d4bbf92d59e4d0ab82b0e246a3d Mon Sep 17 00:00:00 2001 From: Null <1708213363@qq.com> Date: Sun, 19 Apr 2026 18:14:14 +0800 Subject: [PATCH 56/57] docs(webui): add WebUI usage guide and cross-reference links - Create docs/webui-guide.md covering all 8 tabs, config, shortcuts, FAQ - Add link in README.md documentation navigation section - Add reference in docs/deployment.md startup section - Add reference in docs/management-api.md recommended entry section Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- README.md | 1 + docs/deployment.md | 2 + docs/management-api.md | 2 + docs/webui-guide.md | 181 +++++++++++++++++++++++++++++++++++++++++ 4 files changed, 186 insertions(+) create mode 100644 docs/webui-guide.md diff --git a/README.md b/README.md index 48f231c..0cc94f0 100644 --- a/README.md +++ b/README.md @@ -72,6 +72,7 @@ Undefined 的功能极为丰富,为了让本页面不过于臃肿,我们将各个模块的深入解析与高阶玩法整理成了专题游览图。这里是开启探索的钥匙: - ⚙️ **[安装与部署指南](docs/deployment.md)**:不管你是需要 `pip` 无脑一键安装,还是源码二次开发,这里的排坑指南应有尽有。 +- 🖥️ **[WebUI 使用指南](docs/webui-guide.md)**:管理控制台功能一览——配置编辑、日志查看、认知记忆管理、表情包库、AI 对话与系统监控。 - 🧭 **[Management API 与远程管理](docs/management-api.md)**:WebUI / App 共用的管理接口、认证、配置/日志/Bot 控制与引导探针说明。 - 🛠️ **[配置与热更新说明](docs/configuration.md)**:从模型切换到 MCP 库挂载,全方位掌握 `config.toml` 的高阶配置。 - 😶 **[表情包系统 (Memes)](docs/memes.md)**:查看表情包两阶段判定管线、统一图片 
`uid` 发送机制、检索模式及库存管理说明。 diff --git a/docs/deployment.md b/docs/deployment.md index 335b51d..bc5eb1b 100644 --- a/docs/deployment.md +++ b/docs/deployment.md @@ -113,6 +113,8 @@ uv run Undefined-webui ``` > **重要**:两种方式 **二选一即可**,不要同时运行。若你选择 `Undefined-webui`,请在 WebUI 中管理机器人进程的启停。 +> +> WebUI 功能详见 [WebUI 使用指南](webui-guide.md)。 ### 6. 跨平台与资源路径(重要) diff --git a/docs/management-api.md b/docs/management-api.md index 4daec18..346ae86 100644 --- a/docs/management-api.md +++ b/docs/management-api.md @@ -25,6 +25,8 @@ uv run Undefined-webui 4. 直接在控制台启动 Bot 5. 如需远程管理,再让桌面端或 Android App 连接同一个 Management API 地址 +> WebUI 各页面的功能和操作详见 [WebUI 使用指南](webui-guide.md)。 + ## 2. 鉴权模型 Management API 兼容两套鉴权: diff --git a/docs/webui-guide.md b/docs/webui-guide.md new file mode 100644 index 0000000..9a77ff1 --- /dev/null +++ b/docs/webui-guide.md @@ -0,0 +1,181 @@ +# WebUI 使用指南 + +WebUI 是 Undefined 的主要管理入口,提供配置编辑、日志查看、认知记忆管理、表情包库、AI 对话和系统监控等一站式功能。即使 `config.toml` 尚未创建,也可以通过 WebUI 补全配置并启动 Bot。 + +--- + +## 快速开始 + +### 启动 + +```bash +uv run Undefined-webui +``` + +默认监听 `http://127.0.0.1:8787`。相关配置位于 `config.toml` 的 `[webui]` 节: + +```toml +[webui] +url = "127.0.0.1" # 监听地址 +port = 8787 # 端口 +password = "changeme" # 密码(必须在首次登录时修改) +``` + +> 如需远程访问,将 `url` 改为 `0.0.0.0` 或实际 IP。 + +### 首次登录 + +1. 浏览器打开 `http://127.0.0.1:8787` +2. 输入初始密码 `changeme` +3. 
系统强制要求修改默认密码后才能进入管理界面 + +修改后的密码会写入 `config.toml`。桌面端 / Android 客户端也使用同一密码连接。 + +--- + +## 功能概览 + +WebUI 共有 8 个主要页签(Tab),下面逐一介绍。 + +### 概览(Overview) + +仪表盘页面,展示系统运行状态: + +| 指标 | 说明 | +|------|------| +| Bot 状态 | 运行中 / 已停止 | +| 系统运行时间 | 自启动以来的持续时间 | +| CPU 使用率 | 实时百分比 | +| 内存占用 | 已用 / 总量 / 百分比 | +| 运行环境 | CPU 型号、操作系统、Python 版本 | +| 资源趋势图 | CPU / 内存随时间的变化曲线 | + +页面会自动刷新,也可手动触发。 + +### 配置管理(Config) + +提供两种编辑方式: + +- **表单模式**:按分组展示所有配置项,带字段说明和类型校验,适合日常修改。 +- **TOML 原始编辑器**:直接编辑 `config.toml` 文本,带语法高亮,适合批量变更。 + +其他能力: + +- **验证**:保存前自动进行 TOML 语法检查和严格配置校验,错误项会标红提示。 +- **配置历史**:每次保存自动生成带时间戳的备份(最多 50 版本),可随时回滚到任意历史版本。 +- **模板同步**:一键将 `config.toml.example` 中新增的配置项合并到当前配置,不覆盖已有值。 + +> 配置支持热更新——大多数配置项修改后即时生效,无需重启 Bot。需要重启的项(如 `onebot_ws_url`、`webui_port`)会在保存时提示。 + +### 日志查看(Logs) + +- 支持 **Bot 日志** 和 **WebUI 日志** 切换。 +- **实时流式推送**(SSE):日志实时滚动到最新。 +- 可暂停 / 恢复流式推送,调整显示行数(1–2000)。 +- 左侧文件列表展示所有日志文件(含归档),可选择查看历史日志。 +- 支持下载日志文件。 + +### 探针与诊断(Probes) + +三类探针帮助排查问题: + +- **内部探针**:版本号、Python 版本、平台、运行时间、OneBot 连接状态及 WebSocket 地址。 +- **外部探针**:Runtime API 的可用端点和能力列表。 +- **引导探针**:检查 `config.toml` 是否存在、TOML 语法是否合法、配置值是否有效,并给出修复建议。 + +### 认知记忆(Memory) + +分为三个子面板,对应 [认知记忆系统](cognitive-memory.md) 的不同层次: + +**认知事件搜索** + +输入关键词进行语义搜索,查看 AI 提取并存储的用户 / 群聊事实记录。支持调整返回条数和排序方式。 + +**认知侧写查看** + +- 按关键词搜索侧写。 +- 按 QQ 号或群号精确查看对应的完整侧写内容。 + +**长期记忆管理** + +AI 的置顶备忘录(自我约束、待办事项等),支持完整 CRUD: + +- 新建记忆条目 +- 编辑现有内容 +- 删除条目 +- 关键词搜索 + +### 表情包库(Memes) + +管理 [全局表情包库](memes.md) 的 Web 界面: + +- **浏览与搜索**:分页列表展示,支持关键词搜索和语义搜索(以及混合模式)。输入时自动防抖搜索,Enter 立即触发。 +- **筛选与排序**:按启用 / 禁用、静态 / 动态、置顶等条件筛选,按创建或更新时间排序。 +- **详情与操作**:查看元数据(UID、描述、标签)和预览图;支持编辑描述 / 标签、启用 / 禁用、置顶 / 取消、删除。 +- **重分析 / 重索引**:对单张表情包重新触发 AI 描述生成或搜索索引更新。 +- **统计概览**:总数、启用 / 禁用数、静态 / 动态数等。 + +### AI 对话(Chat) + +WebUI 内置的对话界面,直接与 Bot 的 AI 进行交互: + +- 支持文本和图片消息。 +- AI 回复支持 Markdown 渲染。 +- 消息历史分页浏览。 +- 发出的消息会经过与 QQ 侧相同的处理流程(安全检查、工具调用等)。 + +### 关于(About) + +显示当前版本号和 MIT 许可证文本。 + +--- + +## Bot 控制 + +WebUI 首页(Landing Page)提供 Bot 的启停控制: + +- **启动 Bot**:点击启动按钮,Bot 
进程在后台运行。 +- **停止 Bot**:安全停止当前 Bot 进程。 +- **状态指示**:实时显示 Bot 运行状态。 + +首页还会检测是否有可用更新(基于 Git),并提供更新 + 重启功能。 + +--- + +## 键盘快捷键 + +| 快捷键 | 功能 | +|--------|------| +| `Cmd/Ctrl + K` | 打开命令面板,可快速跳转到任意页签或执行操作 | + +--- + +## 远程访问 + +WebUI 和桌面端 / Android 客户端共享同一 Management API: + +1. 将 `[webui].url` 设为 `0.0.0.0`(或你的 LAN/公网 IP)。 +2. 确保防火墙放行 `[webui].port`(默认 8787)。 +3. 桌面端 / Android 客户端输入 `http://:8787` 和密码即可连接。 + +如果启用了 Runtime API(`[api].enabled = true`),WebUI 会自动代理 Runtime API 的功能(探针、记忆查询、AI Chat 等),无需单独暴露 Runtime API 端口。 + +--- + +## 常见问题 + +**Q: 忘记密码怎么办?** + +直接编辑 `config.toml` 中的 `[webui].password` 字段,重启 WebUI 即可。 + +**Q: 配置保存后 Bot 没反应?** + +大多数配置项支持热更新。但少数关键配置(如 WebSocket 地址、WebUI 端口、API 端口)需要重启才能生效,保存时会有提示。 + +**Q: 日志不滚动了?** + +检查是否意外暂停了日志流。点击日志页面的播放按钮恢复实时推送。 + +**Q: 探针显示 Runtime API 不可达?** + +确认 `[api].enabled = true` 且 Bot 正在运行。Runtime API 由 Bot 主进程提供,Bot 未启动时自然不可达。 From 06a831149448d3fb7fbb6c1163bec4cd8acef17c Mon Sep 17 00:00:00 2001 From: Null <1708213363@qq.com> Date: Sun, 19 Apr 2026 21:07:17 +0800 Subject: [PATCH 57/57] fix(calculator): cap combinatorial function args to prevent CPU exhaustion Add _MAX_COMBINATORIAL_ARG=1000 limit for factorial/perm/comb to prevent adversarial inputs like factorial(99999) from consuming excessive CPU. 
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .../skills/tools/calculator/handler.py | 10 ++++++++++ tests/test_calculator.py | 17 +++++++++++++++++ 2 files changed, 27 insertions(+) diff --git a/src/Undefined/skills/tools/calculator/handler.py b/src/Undefined/skills/tools/calculator/handler.py index cd26438..fd73024 100644 --- a/src/Undefined/skills/tools/calculator/handler.py +++ b/src/Undefined/skills/tools/calculator/handler.py @@ -134,6 +134,9 @@ def _stat_fn(fn_name: str, args: list[float | int]) -> float | int: _MAX_POWER = 10000 _MAX_EXPRESSION_LENGTH = 500 +_MAX_COMBINATORIAL_ARG = 1000 + +_COMBINATORIAL_FUNCS = frozenset({"factorial", "perm", "comb"}) class _SafeEvaluator(ast.NodeVisitor): @@ -198,6 +201,13 @@ def visit_Call(self, node: ast.Call) -> Any: args = [self.visit(arg) for arg in node.args] + if fn_name in _COMBINATORIAL_FUNCS: + for a in args: + if isinstance(a, (int, float)) and abs(a) > _MAX_COMBINATORIAL_ARG: + raise ValueError( + f"{fn_name}() 参数过大: {a}(上限 {_MAX_COMBINATORIAL_ARG})" + ) + if fn_name in _STAT_FUNCS: return _stat_fn(fn_name, args) diff --git a/tests/test_calculator.py b/tests/test_calculator.py index 118759a..d0e6d2f 100644 --- a/tests/test_calculator.py +++ b/tests/test_calculator.py @@ -150,6 +150,23 @@ def test_comb(self) -> None: def test_perm(self) -> None: assert safe_eval("perm(5, 3)") == "60" + def test_factorial_too_large(self) -> None: + with pytest.raises(ValueError, match="参数过大"): + safe_eval("factorial(9999)") + + def test_comb_too_large(self) -> None: + with pytest.raises(ValueError, match="参数过大"): + safe_eval("comb(9999, 5000)") + + def test_perm_too_large(self) -> None: + with pytest.raises(ValueError, match="参数过大"): + safe_eval("perm(9999, 5000)") + + def test_factorial_at_limit(self) -> None: + """factorial(1000) should succeed (within limit).""" + result = safe_eval("factorial(1000)") + assert int(result) > 0 + def test_hypot(self) -> None: assert safe_eval("hypot(3, 4)") == "5"