1 ? "repeat(auto-fit, minmax(min(100%, 22rem), 1fr))" : "minmax(0, 1fr)",
+ }}
+ >
{renderMetaCollectionSection(
t("settings.debug.prompt.tools"),
toolItems,
diff --git a/frontend/src/features/settings/debug/types.ts b/frontend/src/features/settings/debug/types.ts
index 04a6a47..edce7b7 100644
--- a/frontend/src/features/settings/debug/types.ts
+++ b/frontend/src/features/settings/debug/types.ts
@@ -164,6 +164,7 @@ export type FrameworkTabProps = {
showToastPreview: () => void;
showNotificationPreview: () => void;
showDialogPreview: () => void;
+ showWhatsNewPreview: () => void;
sendOsNotification: () => void;
publishBackendDebug: () => void;
};
diff --git a/frontend/src/features/update/WhatsNewDialog.tsx b/frontend/src/features/update/WhatsNewDialog.tsx
new file mode 100644
index 0000000..4d7afb6
--- /dev/null
+++ b/frontend/src/features/update/WhatsNewDialog.tsx
@@ -0,0 +1,233 @@
+import * as React from "react";
+import { CheckCircle2 } from "lucide-react";
+
+import { useDismissWhatsNew, useWhatsNew } from "@/shared/query/update";
+import { useI18n } from "@/shared/i18n";
+import { useUpdateStore } from "@/shared/store/update";
+import { Button } from "@/shared/ui/button";
+import { Dialog, DialogContent, DialogDescription, DialogHeader, DialogTitle } from "@/shared/ui/dialog";
+import { ProductModeGlyph } from "@/shared/ui/product-mode-glyph";
+import { DialogMarkdown } from "@/shared/markdown/dialog-markdown";
+import { useSetupCenter } from "@/features/setup";
+
+/**
+ * Tracks whether the document is currently visible.
+ *
+ * The state is seeded from `document.visibilityState` (or `false` when no
+ * `document` exists) and is re-read on `visibilitychange`, window `focus`,
+ * and window `blur`.
+ *
+ * @returns `true` while `document.visibilityState` is `"visible"`.
+ */
+function useWindowVisible() {
+  const [visible, setVisible] = React.useState(() => {
+    if (typeof document === "undefined") {
+      return false;
+    }
+    return document.visibilityState === "visible";
+  });
+
+  React.useEffect(() => {
+    const syncVisibility = () => {
+      setVisible(document.visibilityState === "visible");
+    };
+    // Every event below funnels into the same re-read of visibilityState.
+    const subscriptions: Array<[EventTarget, string]> = [
+      [document, "visibilitychange"],
+      [window, "focus"],
+      [window, "blur"],
+    ];
+
+    // Re-read once on mount, before any listener is attached.
+    syncVisibility();
+    for (const [target, eventName] of subscriptions) {
+      target.addEventListener(eventName, syncVisibility);
+    }
+    return () => {
+      for (const [target, eventName] of subscriptions) {
+        target.removeEventListener(eventName, syncVisibility);
+      }
+    };
+  }, []);
+
+  return visible;
+}
+
+/**
+ * Reports whether the document contains a `[role='dialog']` element other
+ * than one flagged with `data-whats-new-dialog="true"`.
+ *
+ * The answer is recomputed on mount and after every mutation reported by the
+ * observer on `document.body`: child-list changes anywhere in the subtree,
+ * plus changes to the `data-state`, `open`, `style`, or `hidden` attributes.
+ *
+ * @returns `true` while at least one such dialog element is in the DOM.
+ */
+function useBlockingDialogPresent() {
+  const [present, setPresent] = React.useState(false);
+
+  React.useEffect(() => {
+    // A dialog counts unless it carries the what's-new marker attribute.
+    const isOtherDialog = (node: Element) =>
+      (node as HTMLElement).dataset.whatsNewDialog !== "true";
+    const syncPresence = () => {
+      const dialogNodes = document.querySelectorAll("[role='dialog']");
+      setPresent(Array.from(dialogNodes).some(isOtherDialog));
+    };
+
+    syncPresence();
+    const watcher = new MutationObserver(() => {
+      syncPresence();
+    });
+    watcher.observe(document.body, {
+      childList: true,
+      subtree: true,
+      attributes: true,
+      attributeFilter: ["data-state", "open", "style", "hidden"],
+    });
+
+    return () => {
+      watcher.disconnect();
+    };
+  }, []);
+
+  return present;
+}
+
+export interface WhatsNewDialogProps { // Props accepted by WhatsNewDialog.
+ activeWindow: "main" | "settings"; // Window hosting this instance; a store preview is shown only when its targetWindow matches.
+ autoOpen?: boolean; // Allows unattended opening; when omitted, WhatsNewDialog defaults it to activeWindow === "main".
+}
+
+export function WhatsNewDialog({ // Decides when the what's-new notice opens and records how it was closed.
+ activeWindow,
+ autoOpen = activeWindow === "main", // Auto-open defaults to on only for the "main" window.
+}: WhatsNewDialogProps) {
+ const { t } = useI18n();
+ const { data: backendNotice } = useWhatsNew();
+ const dismissWhatsNew = useDismissWhatsNew();
+ const { open: isSetupCenterOpen } = useSetupCenter();
+ const whatsNewPreview = useUpdateStore((state) => state.whatsNewPreview);
+ const clearWhatsNewPreview = useUpdateStore((state) => state.clearWhatsNewPreview);
+ const isWindowVisible = useWindowVisible();
+ const hasBlockingDialog = useBlockingDialogPresent();
+ const [open, setOpen] = React.useState(false);
+ const [presentationReady, setPresentationReady] = React.useState(false); // Turns true 900 ms after a notice version appears.
+ const dismissedVersionRef = React.useRef(""); // Version dismissed through this component; "" when none.
+ const previewNotice = // A store preview counts only when it targets this window.
+ whatsNewPreview?.targetWindow === activeWindow ? whatsNewPreview.notice : null;
+ const notice = previewNotice ?? backendNotice; // The preview takes precedence over the backend notice.
+
+ React.useEffect(() => { // On every notice-version change: close the dialog, then re-arm the readiness delay.
+ setOpen(false);
+ setPresentationReady(false);
+ if (!notice?.version) {
+ dismissedVersionRef.current = ""; // No versioned notice: forget the remembered dismissal.
+ return;
+ }
+ const timer = window.setTimeout(() => {
+ setPresentationReady(true);
+ }, 900);
+ return () => window.clearTimeout(timer); // Cancel the pending delay when the version changes or on unmount.
+ }, [notice?.version]);
+
+ React.useEffect(() => { // Opens the dialog once the conditions below allow it.
+ if (!notice?.version) {
+ return;
+ }
+ if (previewNotice) { // Previews open right away, skipping the dismissal check and the auto-open gates.
+ setOpen(true);
+ return;
+ }
+ if (dismissedVersionRef.current === notice.version) { // Already dismissed through this component; do not reopen.
+ return;
+ }
+ if ( // Hold back unless auto-open is on, the delay elapsed, the window is visible, and nothing else is showing.
+ !autoOpen ||
+ !presentationReady ||
+ !isWindowVisible ||
+ isSetupCenterOpen ||
+ hasBlockingDialog ||
+ open
+ ) {
+ return;
+ }
+ setOpen(true);
+ }, [
+ activeWindow,
+ autoOpen,
+ hasBlockingDialog,
+ isSetupCenterOpen,
+ isWindowVisible,
+ notice?.version,
+ open,
+ presentationReady,
+ previewNotice,
+ ]);
+
+ const handleOpenChange = React.useCallback( // Mirrors the requested open state and performs the close follow-up.
+ (nextOpen: boolean) => {
+ setOpen(nextOpen);
+ if (nextOpen || !notice?.version) { // Only closing a versioned notice needs follow-up.
+ return;
+ }
+ if (previewNotice) { // Closing a preview only clears it from the update store.
+ clearWhatsNewPreview();
+ return;
+ }
+ dismissedVersionRef.current = notice.version; // Remember locally so the opening effect skips this version.
+ dismissWhatsNew.mutate(notice.version); // Pass the dismissed version to the useDismissWhatsNew mutation.
+ },
+ [clearWhatsNewPreview, dismissWhatsNew, notice?.version, previewNotice]
+ );
+
+ const title = notice?.version // Fills the "{version}" placeholder; an empty string is used when there is no version.
+ ? t("whatsNew.title").replace("{version}", notice.version)
+ : t("whatsNew.title").replace("{version}", "");
+ const description = t("whatsNew.description").trim();
+
+ return ( // NOTE(review): no markup is present between these parentheses in this excerpt — confirm against the full file.
+
+ );
+}
diff --git a/frontend/src/shared/contracts/connectors.ts b/frontend/src/shared/contracts/connectors.ts
index fcf559f..760ca6d 100644
--- a/frontend/src/shared/contracts/connectors.ts
+++ b/frontend/src/shared/contracts/connectors.ts
@@ -19,6 +19,9 @@ export interface Connector {
status: ConnectorStatus | string;
cookiesCount?: number;
cookies?: ConnectorCookie[];
+ domains?: string[];
+ policyKey?: string;
+ capabilities?: string[];
lastVerifiedAt?: string;
}
@@ -33,10 +36,51 @@ export interface ClearConnectorRequest {
id: string;
}
-export interface ConnectConnectorRequest {
+export interface StartConnectorConnectRequest {
id: string;
}
+export interface StartConnectorConnectResult { // NOTE(review): presumably what starting a connect flow returns — confirm against the backend DTO.
+ sessionId: string; // Same field name as in the finish/cancel/get request interfaces in this file.
+ connector: Connector;
+}
+
+export interface FinishConnectorConnectRequest { // Request carrying only the session identifier.
+ sessionId: string;
+}
+
+export interface FinishConnectorConnectResult { // NOTE(review): presumably the outcome of finishing a connect session — confirm against the backend DTO.
+ sessionId: string;
+ saved: boolean; // NOTE(review): presumably whether cookies were persisted — confirm.
+ rawCookiesCount: number; // NOTE(review): presumably cookies read before filtering — confirm.
+ filteredCookiesCount: number; // NOTE(review): presumably cookies kept after filtering — confirm.
+ domains?: string[]; // Optional.
+ reason?: string; // Optional; its meaning is not visible in this file.
+ connector: Connector;
+}
+
+export interface CancelConnectorConnectRequest { // Request carrying only the session identifier.
+ sessionId: string;
+}
+
+export interface ConnectorConnectSession { // NOTE(review): presumably a snapshot of one connect session — confirm against the backend DTO.
+ sessionId: string;
+ connectorId: string;
+ state: string; // Typed as a plain string; the set of possible values is not visible in this file.
+ saved: boolean; // NOTE(review): presumably whether cookies were persisted — confirm.
+ rawCookiesCount: number; // NOTE(review): presumably cookies read before filtering — confirm.
+ filteredCookiesCount: number; // NOTE(review): presumably cookies kept after filtering — confirm.
+ domains?: string[]; // Optional.
+ reason?: string; // Optional; its meaning is not visible in this file.
+ error?: string; // Optional.
+ lastCookiesAt?: string; // Optional; NOTE(review): looks like a timestamp string — format not visible here.
+ connector: Connector;
+}
+
+export interface GetConnectorConnectSessionRequest { // Request carrying only the session identifier.
+ sessionId: string;
+}
+
export interface OpenConnectorSiteRequest {
id: string;
}
diff --git a/frontend/src/shared/i18n/locales/en.json b/frontend/src/shared/i18n/locales/en.json
index ba82402..38c1ecf 100644
--- a/frontend/src/shared/i18n/locales/en.json
+++ b/frontend/src/shared/i18n/locales/en.json
@@ -932,11 +932,11 @@
"groups": {
"read": {
"label": "Read",
- "description": "Skills.Status/skills.Bins/skill_manage.Search/skill_manage.List"
+ "description": "Read-only skills status, bins, and package discovery actions."
},
"package_write": {
"label": "Package write",
- "description": "Skill_manage.Install/update/remove/sync"
+ "description": "Skill package install, update, remove, and sync actions."
},
"deps_write": {
"label": "Dependencies write",
@@ -1046,7 +1046,22 @@
"updateFailed": "Failed to update tool status"
},
"reason": {
- "unavailable": "Required runtime dependencies are unavailable."
+ "unavailable": "Required runtime dependencies are unavailable.",
+ "browserProcessExited": "Browser process exited.",
+ "remoteNodeRuntimeUnavailable": "Temporarily unavailable: remote node runtime is not implemented yet.",
+ "imageModelNotConfigured": "Image model is not configured.",
+ "providerRepositoriesUnavailable": "Provider repositories are unavailable.",
+ "gatewayControlPlaneDisabled": "Gateway control plane is disabled.",
+ "webSearchModeUnsupported": "Search mode is not supported.",
+ "webSearchProviderUnsupported": "{provider} is not supported in API mode.",
+ "webSearchProviderApiKeyMissing": "{provider} API key is missing.",
+ "webSearchExternalToolsUnavailable": "External tools mode is not implemented yet.",
+ "voiceDisabled": "Gateway voice feature is disabled.",
+ "voiceServiceUnavailable": "Voice service is unavailable.",
+ "ttsProviderApiKeyMissing": "TTS provider API key is missing.",
+ "ttsVoiceIdMissing": "TTS voice ID is not configured.",
+ "ttsEdgeProviderUnavailable": "Temporarily unavailable: Edge-TTS provider is not implemented yet.",
+ "ttsProviderUnsupported": "TTS provider is not supported."
},
"builtin": {
"read": {
@@ -1083,11 +1098,11 @@
},
"browser": {
"name": "Browser control",
- "description": "Control browser page loading with Playwright."
+ "description": "Control browser page loading through a local CDP browser."
},
"canvas": {
"name": "Canvas rendering",
- "description": "Control node canvases (present/hide/navigate/eval/snapshot/a2ui)."
+ "description": "Control node canvases (present/hide/navigate/eval/snapshot/a2ui). Temporarily unavailable until remote node runtime support is implemented."
},
"image": {
"name": "Image processing",
@@ -1095,7 +1110,7 @@
},
"message": {
"name": "Message sending",
- "description": "Send a channel message (not implemented)."
+ "description": "Send, delete, and manage messages through configured channel plugins."
},
"gateway": {
"name": "Gateway control",
@@ -1107,27 +1122,27 @@
},
"agents_list": {
"name": "Agents list",
- "description": "List agents."
+ "description": "List available agent profiles for subagent spawning."
},
"sessions_list": {
"name": "Sessions list",
- "description": "List sessions."
+ "description": "List current sessions."
},
"sessions_history": {
"name": "Session history",
- "description": "Fetch session history."
+ "description": "Read message history for a session."
},
"sessions_send": {
"name": "Session send",
- "description": "Send a message to a session."
+ "description": "Append a message to an existing session."
},
"sessions_spawn": {
"name": "Session spawn",
- "description": "Spawn a new session."
+ "description": "Spawn an isolated subagent run."
},
"session_status": {
"name": "Session status",
- "description": "Get session status."
+ "description": "Get session metadata and status."
},
"external_tools_query": {
"name": "External tools query",
@@ -1139,7 +1154,7 @@
},
"skills": {
"name": "Skills",
- "description": "Skills runtime tool (status/bins/install/update)."
+ "description": "Inspect skills runtime status and update per-skill runtime dependencies or configuration."
},
"skill_manage": {
"name": "Skill manage",
@@ -1151,7 +1166,7 @@
},
"skills_manage": {
"name": "Skills manage",
- "description": "Manage skills lifecycle actions."
+ "description": "Search, install, update, remove, and sync skill packages via ClawHub."
},
"skills_policy": {
"name": "Skills policy",
@@ -1159,15 +1174,15 @@
},
"subagents": {
"name": "Subagents",
- "description": "Manage subagent runs (list/info/log/kill/steer/spawn)."
+ "description": "Manage existing subagent runs (list/info/log/kill/steer/send)."
},
"nodes": {
"name": "Nodes",
- "description": "Invoke node capability."
+ "description": "Experimental low-level RPC to a registered node. Temporarily unavailable until remote node runtime support is implemented."
},
"tts": {
"name": "Text to speech",
- "description": "Convert text to speech."
+ "description": "Synthesize speech audio from text with the configured voice provider."
},
"render_chart": {
"name": "Render chart",
@@ -1221,6 +1236,14 @@
"name": "Memory list",
"description": "List recent memory entries by filter."
},
+ "memory_query": {
+ "name": "Memory query",
+ "description": "Query long-term memory with recall, list, and stats actions."
+ },
+ "memory_manage": {
+ "name": "Memory manage",
+ "description": "Create, update, or delete long-term memory entries."
+ },
"subagent_run": {
"name": "Subagent run",
"description": "Run a sub-agent task in an isolated workspace."
@@ -1260,55 +1283,68 @@
"externalToolsHint": "External tools mode is reserved. No options are available yet."
},
"webFetch": {
- "type": "Mode",
- "typeDesc": "Switch between Playwright headless browser mode and builtin HTTP fetch mode.",
- "typeValue": {
- "playwright": "Playwright",
- "builtin": "Builtin"
- },
- "playwrightMarkdown": "Convert to Markdown",
- "playwrightMarkdownDesc": "Convert rendered HTML output into Markdown.",
- "playwrightHint": "Playwright mode runs in headless browser and only reads cookies from connectors. It will not persist new cookies.",
- "acceptMarkdown": "Prefer Markdown",
- "acceptMarkdownDesc": "When enabled, requests prefer text/Markdown with HTML fallback.",
- "enableUserAgent": "Send User-Agent",
- "enableUserAgentDesc": "When disabled, no explicit User-Agent will be sent.",
- "userAgent": "User-Agent",
- "acceptLanguage": "Accept-Language",
+ "preferredBrowser": "Preferred browser",
+ "preferredBrowserDesc": "Show only browsers detected on this machine.",
+ "headless": "Headless",
+ "headlessDesc": "Run web fetch in headless mode.",
"timeoutSeconds": "Timeout (sec)",
- "maxChars": "Max chars",
- "maxRedirects": "Max redirects",
- "maxRedirectsDesc": "Set 0 to disable redirect following.",
- "retryMax": "Max retries",
- "retryMaxDesc": "Applies only to safely retryable methods (E.G. GET/HEAD/OPTIONS).",
- "headers": "Extra headers (JSON)",
- "headersDesc": "Provide a JSON object, E.G. {\"X-Test\":\"1\"}. It merges with tool-call headers.",
- "headersInvalid": "Invalid headers format",
- "headersInvalidDesc": "Use a JSON object, E.G. {\"Accept-Language\":\"En-US\"}."
+ "maxChars": "Max chars"
},
"browserControl": {
"enabled": "Enabled",
- "evaluateEnabled": "Allow evaluate",
- "evaluateEnabledDesc": "Disable browser evaluate actions.",
"ssrfSection": "SSRF Security",
"ssrfSectionDesc": "Private network and hostname allowlist options.",
- "executablePath": "Playwright chromium path",
- "executablePathDesc": "Managed by Playwright and read-only.",
- "executablePathPending": "Path will appear after Playwright runtime check.",
+ "preferredBrowser": "Preferred browser",
+ "preferredBrowserDesc": "Show only browsers detected on this machine.",
"headless": "Headless",
"headlessDesc": "Start browser in headless mode.",
- "noSandbox": "No sandbox",
- "noSandboxDesc": "Pass --no-sandbox to chrome.",
- "snapshotDefaultMode": "Snapshot default mode",
- "snapshotDefaultModeAuto": "Auto",
"ssrfDangerouslyAllowPrivateNetwork": "Allow private network",
"ssrfDangerouslyAllowPrivateNetworkDesc": "Allow browser navigation to private/internal networks.",
"ssrfAllowedHostnames": "Allowed hostnames (JSON)",
"ssrfHostnameAllowlist": "Hostname allowlist (JSON)",
- "extraArgs": "Extra args (JSON)",
"arrayInvalid": "Array format is invalid",
"arrayInvalidDesc": "Use a JSON string array, for example [\"localhost\",\"metadata.Internal\"]."
},
+ "runtimeDetection": {
+ "detected": "Detected",
+ "notDetected": "Not detected",
+ "notInstalled": "Not installed",
+ "noneDetected": "No supported browser detected"
+ },
+ "requirements": {
+ "gatewayControlPlane": "Gateway control plane",
+ "localCDPBrowser": "Local CDP browser",
+ "remoteNodeRuntime": "Remote node runtime",
+ "webSearchMode": "Search mode",
+ "webSearchProvider": "Provider",
+ "webSearchProviderApiKey": "Provider API key",
+ "externalToolsRuntime": "External tools runtime",
+ "imageModel": "Image model",
+ "voiceService": "Voice service",
+ "voiceFeature": "Voice feature",
+ "ttsProvider": "Provider",
+ "ttsProviderApiKey": "Provider API key",
+ "ttsVoiceId": "Voice ID",
+ "values": {
+ "available": "Available",
+ "unavailable": "Unavailable",
+ "configured": "Configured",
+ "notConfigured": "Not configured",
+ "enabled": "Enabled",
+ "disabled": "Disabled",
+ "detected": "Detected",
+ "missing": "Missing"
+ },
+ "providers": {
+ "brave": "Brave",
+ "tavily": "Tavily",
+ "perplexity": "Perplexity",
+ "grok": "Grok",
+ "openai": "OpenAI",
+ "elevenlabs": "ElevenLabs",
+ "edge": "Edge-TTS"
+ }
+ },
"detail": {
"descriptionLabel": "Description",
"enabled": "Enabled",
@@ -1351,31 +1387,46 @@
"searchPlaceholder": "Search connectors",
"searchEmpty": "No connectors match your search.",
"group.searchEngine": "Search engines",
+ "group.community": "Communities",
"group.video": "Video",
+ "group.developer": "Developer",
"group.other": "Other",
"item.google": "Google",
+ "item.github": "GitHub",
+ "item.reddit": "Reddit",
+ "item.zhihu": "Zhihu",
+ "item.x": "X",
"item.xiaohongshu": "Xiaohongshu",
"item.bilibili": "Bilibili",
"connect": "Connect",
"reconnect": "Reconnect",
"clear": "Clear",
"noCookies": "No cookies configured",
- "loginHint": "A browser window will open. Sign in, then close the window to save cookies.",
+ "noCookiesRead": "No connector cookies were read yet. Stay signed in and click finish again.",
+ "loginHint": "A browser window will open with a temporary profile. Sign in there, then click finish to read and save cookies.",
"loginTitle": "Browser login",
- "loginDescription": "Finish login in the browser window, then close it.",
- "installDescription": "Playwright is required to open the browser.",
+ "loginDescription": "Finish login in the browser window, then click finish in this dialog.",
"loginTarget": "Connector",
- "loginRunning": "Waiting for browser to close...",
- "loginDone": "Login completed.",
+ "loginLaunching": "Launching browser...",
+ "loginReady": "Browser opened. Complete sign-in, then click finish.",
+ "loginReadingCookies": "Reading cookies from the browser session...",
+ "loginClosingBrowser": "Closing the browser session...",
+ "loginCompleted": "Connection processing completed.",
+ "loginIdle": "Waiting to start login.",
+ "loginFinish": "Finish",
"loginError": "Login failed",
- "playwrightMissing": "Playwright is not installed. Install it to continue.",
- "openExternalTools": "Open external tools",
- "installRequiredStatus": "Install required to continue.",
+ "loginSessionMissing": "The login session was not found. Start the connection again.",
+ "browserSessionEnded": "The browser session has already ended. Start the connection again.",
+ "browserMissing": "No supported local browser was detected.",
+ "cookiesRead": "Cookies read",
+ "cookiesSaved": "Cookies kept",
+ "cookiesDomains": "Matched domains",
"status.connected": "Connected",
"status.expired": "Expired",
"status.disconnected": "Disconnected",
"detail.status": "Status",
"detail.data": "Data",
+ "detail.scope": "Cookie scope",
"detail.actions": "Actions",
"viewCookies": "View",
"openSite": "Open site",
@@ -1720,8 +1771,9 @@
"current": "Current version",
"changelog": "Release notes",
"check": "Check for updates",
- "install": "Install updates",
- "restart": "Restart to apply",
+ "recheck": "Check again",
+ "downloadAndInstall": "Download and install",
+ "restartAfterUpdate": "Restart to update",
"command": "Update actions",
"status": "Status",
"downloading": "Downloading",
@@ -3524,6 +3576,14 @@
"notAvailable": "Not available",
"justNow": "Just now"
},
+ "whatsNew": {
+ "eyebrow": "Release notes",
+ "currentVersion": "Current version",
+ "title": "What's new in {version}",
+ "description": "",
+ "emptyState": "This version has already been applied, but no release notes are available yet.",
+ "versionLabel": "Current version: {version}"
+ },
"gateway": {
"page": {
"subtitle": "Control plane overview and diagnostics"
@@ -3804,6 +3864,7 @@
"notification": "Notification",
"sendOS": "Send OS message",
"dialog": "Dialog",
+ "whatsNew": "What's new",
"publish": "Publish debug event"
}
},
@@ -3841,7 +3902,8 @@
"action": "Acknowledge",
"dialogTitle": "Danger Dialog",
"dialogDesc": "Use this to verify modal styles.",
- "dialogConfirm": "Confirm"
+ "dialogConfirm": "Confirm",
+ "whatsNewMarkdown": "## Debug preview\n\n### Layout verification\n- This line checks the title-to-body spacing.\n- This line checks paragraph rhythm in Markdown.\n- This line checks the rounded container edge.\n- This line checks long-content readability.\n- This line checks the scroll threshold.\n- This line checks whether the dialog height stays stable.\n\n### Interaction verification\n- This line checks scrolling after the content grows.\n- This line checks that the footer stays visible.\n- This line checks that the close button does not move.\n- This line checks Markdown list rendering.\n- This line checks the spacing near the bottom.\n- This line confirms this preview is debug-only."
},
"realtime": {
"websocket": "WebSocket",
@@ -4041,6 +4103,7 @@
"productMode": "Product mode",
"notifications": "Notifications",
"appUpdate": "App update",
+ "restartAndUpdate": "Restart and update",
"externalToolsUpdate": "Tool updates",
"settings": "Settings",
"open": "Open profile menu"
diff --git a/frontend/src/shared/i18n/locales/zh-CN.json b/frontend/src/shared/i18n/locales/zh-CN.json
index e56fd66..7e9f8a5 100644
--- a/frontend/src/shared/i18n/locales/zh-CN.json
+++ b/frontend/src/shared/i18n/locales/zh-CN.json
@@ -932,11 +932,11 @@
"groups": {
"read": {
"label": "读取",
- "description": "skills.status/skills.bins/skill_manage.search/skill_manage.list"
+ "description": "skills.status/skills.bins/skills_manage.search/skills_manage.list"
},
"package_write": {
"label": "包管理写入",
- "description": "skill_manage.install/update/remove/sync"
+ "description": "skills_manage.install/update/remove/sync"
},
"deps_write": {
"label": "依赖写入",
@@ -1046,7 +1046,22 @@
"updateFailed": "更新工具状态失败"
},
"reason": {
- "unavailable": "缺少运行时依赖,当前不可用。"
+ "unavailable": "缺少运行时依赖,当前不可用。",
+ "browserProcessExited": "浏览器进程已退出。",
+ "remoteNodeRuntimeUnavailable": "暂时不可用:远端 Node 运行时还未实现。",
+ "imageModelNotConfigured": "图像模型尚未配置。",
+ "providerRepositoriesUnavailable": "Provider 仓库当前不可用。",
+ "gatewayControlPlaneDisabled": "Gateway 控制平面已关闭。",
+ "webSearchModeUnsupported": "当前搜索模式不受支持。",
+ "webSearchProviderUnsupported": "{provider} 当前不支持 API 模式。",
+ "webSearchProviderApiKeyMissing": "{provider} 的 API key 缺失。",
+ "webSearchExternalToolsUnavailable": "External tools 模式暂未实现。",
+ "voiceDisabled": "Gateway 语音功能已关闭。",
+ "voiceServiceUnavailable": "语音服务当前不可用。",
+ "ttsProviderApiKeyMissing": "TTS provider 的 API key 缺失。",
+ "ttsVoiceIdMissing": "TTS 的 voice ID 尚未配置。",
+ "ttsEdgeProviderUnavailable": "暂时不可用:Edge-TTS provider 还未实现。",
+ "ttsProviderUnsupported": "当前 TTS provider 不受支持。"
},
"builtin": {
"read": {
@@ -1083,11 +1098,11 @@
},
"browser": {
"name": "浏览器控制",
- "description": "通过 Playwright 控制页面加载。"
+ "description": "通过本地 CDP 浏览器控制页面加载。"
},
"canvas": {
"name": "画布渲染",
- "description": "控制节点画布能力(present/hide/navigate/eval/snapshot/a2ui)。"
+ "description": "控制节点画布能力(present/hide/navigate/eval/snapshot/a2ui)。远端 Node 运行时落地前暂不可用。"
},
"image": {
"name": "图像处理",
@@ -1095,7 +1110,7 @@
},
"message": {
"name": "消息发送",
- "description": "发送频道消息(当前未实现)。"
+ "description": "通过已配置的频道插件发送、删除和管理消息。"
},
"gateway": {
"name": "Gateway 控制",
@@ -1107,27 +1122,27 @@
},
"agents_list": {
"name": "代理列表",
- "description": "列出代理。"
+ "description": "列出可用于子代理启动的代理配置。"
},
"sessions_list": {
"name": "会话列表",
- "description": "列出会话。"
+ "description": "列出当前会话。"
},
"sessions_history": {
"name": "会话历史",
- "description": "获取会话历史。"
+ "description": "读取某个会话的消息历史。"
},
"sessions_send": {
"name": "会话发送",
- "description": "向会话发送消息。"
+ "description": "向已有会话追加一条消息。"
},
"sessions_spawn": {
"name": "新建会话",
- "description": "创建新的会话。"
+ "description": "启动一次隔离的子代理运行。"
},
"session_status": {
"name": "会话状态",
- "description": "获取会话状态。"
+ "description": "获取会话元数据与状态。"
},
"external_tools_query": {
"name": "外部工具查询",
@@ -1139,7 +1154,7 @@
},
"skills": {
"name": "技能",
- "description": "技能运行工具(状态/命令依赖/依赖安装/配置更新)。"
+ "description": "查看技能运行状态,并更新单个技能的运行依赖或配置。"
},
"skill_manage": {
"name": "技能管理",
@@ -1151,7 +1166,7 @@
},
"skills_manage": {
"name": "技能管理",
- "description": "管理技能生命周期操作。"
+ "description": "通过 ClawHub 搜索、安装、更新、移除和同步技能包。"
},
"skills_policy": {
"name": "技能策略",
@@ -1159,15 +1174,15 @@
},
"subagents": {
"name": "子代理管理",
- "description": "管理子代理运行(列表/详情/日志/终止/转向/启动)。"
+ "description": "管理已有子代理运行(列表/详情/日志/终止/转向/发送)。"
},
"nodes": {
"name": "节点调用",
- "description": "调用节点能力。"
+ "description": "实验性的低层节点 RPC。远端 Node 运行时落地前暂不可用。"
},
"tts": {
"name": "文本转语音",
- "description": "将文本转换为语音。"
+ "description": "使用当前语音提供方把文本合成为语音音频。"
},
"render_chart": {
"name": "图表渲染",
@@ -1221,6 +1236,14 @@
"name": "记忆列表",
"description": "按条件列出近期记忆条目。"
},
+ "memory_query": {
+ "name": "记忆查询",
+ "description": "通过 recall、list、stats 动作查询长期记忆。"
+ },
+ "memory_manage": {
+ "name": "记忆管理",
+ "description": "创建、更新或删除长期记忆条目。"
+ },
"subagent_run": {
"name": "子代理运行",
"description": "在隔离工作区运行子代理任务。"
@@ -1260,55 +1283,68 @@
"externalToolsHint": "外部工具模式仍为预留能力,暂时没有可配置项。"
},
"webFetch": {
- "type": "模式",
- "typeDesc": "在 Playwright 无头浏览器与内置 HTTP 抓取之间切换。",
- "typeValue": {
- "playwright": "Playwright",
- "builtin": "内置"
- },
- "playwrightMarkdown": "转换为 Markdown",
- "playwrightMarkdownDesc": "开启后会将渲染后的 HTML 内容转换为 Markdown。",
- "playwrightHint": "Playwright 模式会以无头浏览器访问页面,并只读取 Connectors 中已有 Cookies,不会写回或新增 Cookies。",
- "acceptMarkdown": "优先 Markdown",
- "acceptMarkdownDesc": "开启后会优先请求 text/markdown,并在服务端回退到 HTML。",
- "enableUserAgent": "发送 User-Agent",
- "enableUserAgentDesc": "关闭后将不主动发送 User-Agent。",
- "userAgent": "User-Agent",
- "acceptLanguage": "Accept-Language",
+ "preferredBrowser": "默认浏览器",
+ "preferredBrowserDesc": "仅展示本机已检测到的浏览器。",
+ "headless": "无头模式",
+ "headlessDesc": "以无头模式进行网页抓取。",
"timeoutSeconds": "超时(秒)",
- "maxChars": "最大字符数",
- "maxRedirects": "最大重定向次数",
- "maxRedirectsDesc": "0 表示不跟随重定向。",
- "retryMax": "最大重试次数",
- "retryMaxDesc": "仅对可安全重试的方法生效(如 GET/HEAD/OPTIONS)。",
- "headers": "附加 Headers(JSON)",
- "headersDesc": "填写 JSON 对象,例如 {\"X-Test\":\"1\"}。会与工具调用参数 headers 合并。",
- "headersInvalid": "Headers 格式不正确",
- "headersInvalidDesc": "请填写 JSON 对象,例如 {\"Accept-Language\":\"en-US\"}。"
+ "maxChars": "最大字符数"
},
"browserControl": {
"enabled": "启用",
- "evaluateEnabled": "允许 Evaluate",
- "evaluateEnabledDesc": "关闭后禁用 browser evaluate 动作。",
"ssrfSection": "SSRF 安全",
"ssrfSectionDesc": "私有网络与主机名白名单配置。",
- "executablePath": "Playwright Chromium 路径",
- "executablePathDesc": "由 Playwright 管理,只读不可修改。",
- "executablePathPending": "等待 Playwright 运行时检查后显示路径。",
+ "preferredBrowser": "默认浏览器",
+ "preferredBrowserDesc": "仅展示本机已检测到的浏览器。",
"headless": "无头模式",
"headlessDesc": "以无头模式启动浏览器。",
- "noSandbox": "关闭沙箱",
- "noSandboxDesc": "启动 Chrome 时添加 --no-sandbox。",
- "snapshotDefaultMode": "快照默认模式",
- "snapshotDefaultModeAuto": "自动",
"ssrfDangerouslyAllowPrivateNetwork": "允许私有网络",
"ssrfDangerouslyAllowPrivateNetworkDesc": "允许浏览器访问私有/内网地址。",
"ssrfAllowedHostnames": "允许的主机名(JSON)",
"ssrfHostnameAllowlist": "主机名白名单(JSON)",
- "extraArgs": "额外参数(JSON)",
"arrayInvalid": "数组格式不正确",
"arrayInvalidDesc": "请填写 JSON 字符串数组,例如 [\"localhost\",\"metadata.internal\"]。"
},
+ "runtimeDetection": {
+ "detected": "已检测到",
+ "notDetected": "未检测到",
+ "notInstalled": "未安装",
+ "noneDetected": "未检测到可用浏览器"
+ },
+ "requirements": {
+ "gatewayControlPlane": "Gateway 控制平面",
+ "localCDPBrowser": "本地 CDP 浏览器",
+ "remoteNodeRuntime": "远端 Node 运行时",
+ "webSearchMode": "搜索模式",
+ "webSearchProvider": "提供方",
+ "webSearchProviderApiKey": "提供方 API Key",
+ "externalToolsRuntime": "外部工具运行时",
+ "imageModel": "图像模型",
+ "voiceService": "语音服务",
+ "voiceFeature": "语音功能",
+ "ttsProvider": "提供方",
+ "ttsProviderApiKey": "提供方 API Key",
+ "ttsVoiceId": "语音 ID",
+ "values": {
+ "available": "可用",
+ "unavailable": "不可用",
+ "configured": "已配置",
+ "notConfigured": "未配置",
+ "enabled": "已启用",
+ "disabled": "已关闭",
+ "detected": "已检测到",
+ "missing": "缺失"
+ },
+ "providers": {
+ "brave": "Brave",
+ "tavily": "Tavily",
+ "perplexity": "Perplexity",
+ "grok": "Grok",
+ "openai": "OpenAI",
+ "elevenlabs": "ElevenLabs",
+ "edge": "Edge-TTS"
+ }
+ },
"detail": {
"descriptionLabel": "说明",
"enabled": "启用",
@@ -1351,31 +1387,46 @@
"searchPlaceholder": "搜索连接",
"searchEmpty": "没有匹配的连接。",
"group.searchEngine": "搜索引擎",
+ "group.community": "社区",
"group.video": "视频",
+ "group.developer": "开发者",
"group.other": "其他",
"item.google": "Google",
+ "item.github": "GitHub",
+ "item.reddit": "Reddit",
+ "item.zhihu": "知乎",
+ "item.x": "X",
"item.xiaohongshu": "小红书",
"item.bilibili": "Bilibili",
"connect": "连接",
"reconnect": "重新连接",
"clear": "清除",
"noCookies": "未配置 cookies",
- "loginHint": "将打开浏览器窗口。完成登录后关闭窗口即可保存 cookies。",
+ "noCookiesRead": "还没有读取到当前连接的 cookies。请保持登录状态后再次点击完成连接。",
+ "loginHint": "将以临时 profile 打开浏览器窗口。请先在浏览器里完成登录,然后回到这里点击完成连接。",
"loginTitle": "浏览器登录",
- "loginDescription": "请在浏览器窗口完成登录,然后关闭窗口。",
- "installDescription": "需要安装 Playwright 才能打开浏览器。",
+ "loginDescription": "请在浏览器窗口完成登录,然后回到此对话框点击完成连接。",
"loginTarget": "连接",
- "loginRunning": "等待浏览器关闭中...",
- "loginDone": "登录完成。",
+ "loginLaunching": "正在启动浏览器...",
+ "loginReady": "浏览器已打开。完成登录后点击完成连接。",
+ "loginReadingCookies": "正在从浏览器会话读取 cookies...",
+ "loginClosingBrowser": "正在关闭浏览器会话...",
+ "loginCompleted": "连接处理已完成。",
+ "loginIdle": "等待开始连接。",
+ "loginFinish": "完成连接",
"loginError": "登录失败",
- "playwrightMissing": "未安装 Playwright,无法打开浏览器。",
- "openExternalTools": "打开外部工具",
- "installRequiredStatus": "需要安装 Playwright 才能继续。",
+ "loginSessionMissing": "未找到当前连接会话,请重新开始连接。",
+ "browserSessionEnded": "浏览器会话已结束,请重新开始连接。",
+ "browserMissing": "未检测到可用的本机浏览器。",
+ "cookiesRead": "读取到的 cookies",
+ "cookiesSaved": "保留的 cookies",
+ "cookiesDomains": "命中的域名",
"status.connected": "已连接",
"status.expired": "已过期",
"status.disconnected": "未连接",
"detail.status": "状态",
"detail.data": "数据",
+ "detail.scope": "Cookie 范围",
"detail.actions": "操作",
"viewCookies": "查看",
"openSite": "打开站点",
@@ -1720,8 +1771,9 @@
"current": "当前版本",
"changelog": "更新日志",
"check": "检查更新",
- "install": "安装更新",
- "restart": "重启生效",
+ "recheck": "重新检查",
+ "downloadAndInstall": "下载并安装",
+ "restartAfterUpdate": "重启后更新",
"command": "更新操作",
"status": "状态",
"downloading": "下载中",
@@ -2390,6 +2442,7 @@
"notification": "通知",
"sendOS": "系统消息",
"dialog": "对话框",
+ "whatsNew": "What's New",
"publish": "调试事件"
}
},
@@ -2427,7 +2480,8 @@
"action": "知道了",
"dialogTitle": "危险弹窗",
"dialogDesc": "用于确认样式的弹窗示例。",
- "dialogConfirm": "确认"
+ "dialogConfirm": "确认",
+ "whatsNewMarkdown": "## 调试预览\n\n### 布局验证\n- 这一行用于检查标题与正文之间的留白。\n- 这一行用于检查 Markdown 段落节奏。\n- 这一行用于检查内容卡片的圆角边界。\n- 这一行用于检查长内容时的可读性。\n- 这一行用于检查滚动阈值是否生效。\n- 这一行用于检查 dialog 高度是否保持稳定。\n\n### 交互验证\n- 这一行用于检查内容超过一屏后的滚动行为。\n- 这一行用于检查底部按钮区域是否始终可见。\n- 这一行用于检查关闭按钮位置不会被挤压。\n- 这一行用于检查 Markdown 列表渲染效果。\n- 这一行用于检查底部间距是否自然。\n- 这一行确认这是仅供调试使用的预览内容。"
},
"realtime": {
"websocket": "WebSocket",
@@ -4001,6 +4055,14 @@
"notAvailable": "不可用",
"justNow": "刚刚"
},
+ "whatsNew": {
+ "eyebrow": "版本已更新",
+ "currentVersion": "当前版本",
+ "title": "{version} 有这些新内容",
+ "description": "",
+ "emptyState": "当前版本已经更新完成,但暂时没有可展示的更新日志。",
+ "versionLabel": "当前版本:{version}"
+ },
"gateway": {
"page": {
"subtitle": "控制平面概览与诊断信息"
@@ -4053,6 +4115,7 @@
"productMode": "产品形态",
"notifications": "通知中心",
"appUpdate": "应用升级",
+ "restartAndUpdate": "重启后更新",
"externalToolsUpdate": "工具升级",
"settings": "设置",
"open": "打开个人菜单"
diff --git a/frontend/src/shared/query/connectors.ts b/frontend/src/shared/query/connectors.ts
index 169aa75..641929c 100644
--- a/frontend/src/shared/query/connectors.ts
+++ b/frontend/src/shared/query/connectors.ts
@@ -1,29 +1,44 @@
import { useMutation, useQuery, useQueryClient } from "@tanstack/react-query";
import type {
+ CancelConnectorConnectRequest,
ClearConnectorRequest,
- ConnectConnectorRequest,
+ ConnectorConnectSession,
Connector,
+ FinishConnectorConnectRequest,
+ FinishConnectorConnectResult,
+ GetConnectorConnectSessionRequest,
OpenConnectorSiteRequest,
+ StartConnectorConnectRequest,
+ StartConnectorConnectResult,
UpsertConnectorRequest,
} from "@/shared/contracts/connectors";
import {
+ CancelConnectorConnect as CancelConnectorConnectBinding,
ClearConnector as ClearConnectorBinding,
- ConnectConnector as ConnectConnectorBinding,
- InstallPlaywright,
+ FinishConnectorConnect as FinishConnectorConnectBinding,
+ GetConnectorConnectSession as GetConnectorConnectSessionBinding,
ListConnectors,
OpenConnectorSite as OpenConnectorSiteBinding,
+ StartConnectorConnect as StartConnectorConnectBinding,
UpsertConnector as UpsertConnectorBinding,
} from "../../../bindings/dreamcreator/internal/presentation/wails/connectorshandler";
import {
+ CancelConnectorConnectRequest as BindingsCancelConnectorConnectRequest,
ClearConnectorRequest as BindingsClearConnectorRequest,
- ConnectConnectorRequest as BindingsConnectConnectorRequest,
+ ConnectorConnectSession as BindingsConnectorConnectSession,
Connector as BindingsConnector,
+ FinishConnectorConnectRequest as BindingsFinishConnectorConnectRequest,
+ FinishConnectorConnectResult as BindingsFinishConnectorConnectResult,
+ GetConnectorConnectSessionRequest as BindingsGetConnectorConnectSessionRequest,
OpenConnectorSiteRequest as BindingsOpenConnectorSiteRequest,
+ StartConnectorConnectRequest as BindingsStartConnectorConnectRequest,
+ StartConnectorConnectResult as BindingsStartConnectorConnectResult,
UpsertConnectorRequest as BindingsUpsertConnectorRequest,
} from "../../../bindings/dreamcreator/internal/application/connectors/dto/models";
export const CONNECTORS_QUERY_KEY = ["connectors"];
+export const CONNECTOR_CONNECT_SESSION_QUERY_KEY = ["connector-connect-session"];
export function useConnectors() {
return useQuery({
@@ -59,11 +74,23 @@ export function useClearConnector() {
});
}
-export function useConnectConnector() {
+/**
+ * Mutation hook that starts a connector connect flow.
+ *
+ * The request is converted with `BindingsStartConnectorConnectRequest.createFrom`
+ * before the binding is called, and the binding's reply is mapped through
+ * `toStartConnectorConnectResult` before it is handed to callers.
+ */
+export function useStartConnectorConnect() {
+  return useMutation({
+    mutationFn: async (request: StartConnectorConnectRequest): Promise<StartConnectorConnectResult> => {
+      return toStartConnectorConnectResult(
+        await StartConnectorConnectBinding(BindingsStartConnectorConnectRequest.createFrom(request))
+      );
+    },
+  });
+}
+
+export function useFinishConnectorConnect() {
const queryClient = useQueryClient();
return useMutation({
- mutationFn: async (request: ConnectConnectorRequest): Promise<Connector> => {
- return toConnector(await ConnectConnectorBinding(BindingsConnectConnectorRequest.createFrom(request)));
+ // Convert the request for the binding, then map the reply into the contract type.
+ mutationFn: async (request: FinishConnectorConnectRequest): Promise<FinishConnectorConnectResult> => {
+ return toFinishConnectorConnectResult(
+ await FinishConnectorConnectBinding(BindingsFinishConnectorConnectRequest.createFrom(request))
+ );
},
onSuccess: () => {
queryClient.invalidateQueries({ queryKey: CONNECTORS_QUERY_KEY });
@@ -71,6 +98,14 @@ export function useConnectConnector() {
});
}
+/**
+ * Mutation hook that forwards a cancel request for a connector connect
+ * session to the backend binding.
+ *
+ * Resolves with no value once the binding call completes; this hook does not
+ * touch the query cache.
+ */
+export function useCancelConnectorConnect() {
+  return useMutation({
+    mutationFn: async (request: CancelConnectorConnectRequest): Promise<void> => {
+      await CancelConnectorConnectBinding(BindingsCancelConnectorConnectRequest.createFrom(request));
+    },
+  });
+}
+
export function useOpenConnectorSite() {
return useMutation({
mutationFn: async (request: OpenConnectorSiteRequest): Promise => {
@@ -79,11 +114,17 @@ export function useOpenConnectorSite() {
});
}
-export function useInstallPlaywright() {
- return useMutation({
- mutationFn: async (): Promise<void> => {
- await InstallPlaywright();
+// Query hook that reads a connector connect session through the backend binding.
+// Disabled while `enabled` is false or `request.sessionId` is blank; while enabled it
+// refetches every 1000 ms and treats cached data as immediately stale (staleTime 0).
+export function useConnectorConnectSession(request: GetConnectorConnectSessionRequest, enabled: boolean) {
+ return useQuery({
+ queryKey: [...CONNECTOR_CONNECT_SESSION_QUERY_KEY, request.sessionId],
+ enabled: enabled && request.sessionId.trim().length > 0,
+ queryFn: async (): Promise<ConnectorConnectSession> => {
+ return toConnectorConnectSession(
+ await GetConnectorConnectSessionBinding(BindingsGetConnectorConnectSessionRequest.createFrom(request))
+ );
},
+ refetchInterval: 1000,
+ staleTime: 0,
});
}
@@ -93,3 +134,24 @@ function toConnector(raw: BindingsConnector): Connector {
cookies: raw.cookies.map((item) => ({ ...item })),
};
}
+
+// Copies the binding result's fields and replaces `connector` with the value
+// produced by `toConnector`.
+function toStartConnectorConnectResult(raw: BindingsStartConnectorConnectResult): StartConnectorConnectResult {
+  const mapped = { ...raw, connector: toConnector(raw.connector) };
+  return mapped;
+}
+
+// Copies the binding result's fields and replaces `connector` with the value
+// produced by `toConnector`.
+function toFinishConnectorConnectResult(raw: BindingsFinishConnectorConnectResult): FinishConnectorConnectResult {
+  const mapped = { ...raw, connector: toConnector(raw.connector) };
+  return mapped;
+}
+
+// Copies the binding session's fields and replaces `connector` with the value
+// produced by `toConnector`.
+function toConnectorConnectSession(raw: BindingsConnectorConnectSession): ConnectorConnectSession {
+  const mapped = { ...raw, connector: toConnector(raw.connector) };
+  return mapped;
+}
diff --git a/frontend/src/shared/query/update.ts b/frontend/src/shared/query/update.ts
index 9ef287c..02279b2 100644
--- a/frontend/src/shared/query/update.ts
+++ b/frontend/src/shared/query/update.ts
@@ -1,9 +1,15 @@
import { useMutation, useQuery, useQueryClient } from "@tanstack/react-query";
import { Call } from "@wailsio/runtime";
-import { normalizeUpdateInfo, type UpdateInfo } from "@/shared/store/update";
+import {
+ normalizeUpdateInfo,
+ normalizeWhatsNewInfo,
+ type UpdateInfo,
+ type WhatsNewInfo,
+} from "@/shared/store/update";
const UPDATE_QUERY_KEY = ["update-state"];
+const WHATS_NEW_QUERY_KEY = ["whats-new"];
export function useUpdateState() {
return useQuery({
@@ -58,4 +64,30 @@ export function useRestartToApply() {
});
}
-export { UPDATE_QUERY_KEY };
+/**
+ * Query hook that loads the "what's new" notice from the backend update handler.
+ *
+ * The raw reply goes through `normalizeWhatsNewInfo`, so the query data is
+ * either a normalized `WhatsNewInfo` or `null`. `staleTime: Infinity` means the
+ * cached value is never considered stale.
+ */
+export function useWhatsNew() {
+  return useQuery({
+    queryKey: WHATS_NEW_QUERY_KEY,
+    queryFn: async (): Promise<WhatsNewInfo | null> => {
+      const result = await Call.ByName("dreamcreator/internal/presentation/wails.UpdateHandler.GetWhatsNew");
+      return normalizeWhatsNewInfo(result as Partial<WhatsNewInfo>);
+    },
+    staleTime: Infinity,
+  });
+}
+
+/**
+ * Mutation hook that tells the backend update handler to dismiss the
+ * "what's new" notice for the given version.
+ *
+ * On success the cached `WHATS_NEW_QUERY_KEY` data is set to `null`, so
+ * consumers of `useWhatsNew` see no notice without a refetch.
+ */
+export function useDismissWhatsNew() {
+  const queryClient = useQueryClient();
+  return useMutation({
+    mutationFn: async (version: string): Promise<void> => {
+      await Call.ByName(
+        "dreamcreator/internal/presentation/wails.UpdateHandler.DismissWhatsNew",
+        version
+      );
+    },
+    onSuccess: () => {
+      queryClient.setQueryData(WHATS_NEW_QUERY_KEY, null);
+    },
+  });
+}
+
+export { UPDATE_QUERY_KEY, WHATS_NEW_QUERY_KEY };
diff --git a/frontend/src/shared/store/gatewayTools.ts b/frontend/src/shared/store/gatewayTools.ts
index f5d024e..9978646 100644
--- a/frontend/src/shared/store/gatewayTools.ts
+++ b/frontend/src/shared/store/gatewayTools.ts
@@ -11,6 +11,7 @@ export type GatewayToolRequirement = {
name?: string;
available: boolean;
reason?: string;
+ data?: unknown;
};
export type GatewayToolSpec = {
diff --git a/frontend/src/shared/store/update.ts b/frontend/src/shared/store/update.ts
index 3eab3e3..5ae5648 100644
--- a/frontend/src/shared/store/update.ts
+++ b/frontend/src/shared/store/update.ts
@@ -16,6 +16,8 @@ export interface UpdateInfo {
currentVersion: string;
latestVersion: string;
changelog: string;
+ preparedVersion: string;
+ preparedChangelog: string;
downloadURL: string;
checkedAt?: string;
status: UpdateStatus;
@@ -23,9 +25,23 @@ export interface UpdateInfo {
message?: string;
}
+export interface WhatsNewInfo { // normalized "what's new" notice, as produced by normalizeWhatsNewInfo
+ version: string; // normalizeWhatsNewInfo returns null instead of a notice when this is blank
+ currentVersion: string; // normalizeWhatsNewInfo uses `version` here when the payload omits it
+ changelog: string; // "" when the payload omits it; presumably Markdown text — TODO confirm
+}
+
+export interface WhatsNewPreview { // value stored in the update store by openWhatsNewPreview
+ notice: WhatsNewInfo; // notice content to preview
+ targetWindow: "main" | "settings"; // presumably selects which window renders the preview — TODO confirm
+}
+
export interface UpdateStore {
info: UpdateInfo;
+ whatsNewPreview: WhatsNewPreview | null; // null initially and after clearWhatsNewPreview
setInfo: (info: UpdateInfo) => void;
+ openWhatsNewPreview: (notice: WhatsNewInfo, targetWindow: WhatsNewPreview["targetWindow"]) => void; // stores { notice, targetWindow } as whatsNewPreview
+ clearWhatsNewPreview: () => void; // resets whatsNewPreview to null
}
const defaultInfo: UpdateInfo = {
@@ -33,6 +49,8 @@ const defaultInfo: UpdateInfo = {
currentVersion: "",
latestVersion: "",
changelog: "",
+ preparedVersion: "",
+ preparedChangelog: "",
downloadURL: "",
status: "idle",
progress: 0,
@@ -41,7 +59,10 @@ const defaultInfo: UpdateInfo = {
export const useUpdateStore = create<UpdateStore>((set) => ({
info: defaultInfo,
+ whatsNewPreview: null, // no preview is active until openWhatsNewPreview is called
setInfo: (info) => set({ info }),
+ openWhatsNewPreview: (notice, targetWindow) => set({ whatsNewPreview: { notice, targetWindow } }), // replaces any earlier preview
+ clearWhatsNewPreview: () => set({ whatsNewPreview: null }),
}));
export function normalizeUpdateInfo(raw: Partial | null | undefined): UpdateInfo {
@@ -54,6 +75,8 @@ export function normalizeUpdateInfo(raw: Partial | null | undefined)
currentVersion: raw.currentVersion ?? anyRaw.CurrentVersion ?? "",
latestVersion: raw.latestVersion ?? anyRaw.LatestVersion ?? "",
changelog: raw.changelog ?? anyRaw.Changelog ?? "",
+ preparedVersion: raw.preparedVersion ?? anyRaw.PreparedVersion ?? "",
+ preparedChangelog: raw.preparedChangelog ?? anyRaw.PreparedChangelog ?? "",
downloadURL: raw.downloadURL ?? anyRaw.DownloadURL ?? "",
checkedAt: raw.checkedAt ?? anyRaw.CheckedAt,
status: (raw.status as UpdateStatus) ?? (anyRaw.Status as UpdateStatus) ?? "idle",
@@ -61,3 +84,70 @@ export function normalizeUpdateInfo(raw: Partial | null | undefined)
message: raw.message ?? anyRaw.Message ?? "",
};
}
+
+/**
+ * Normalizes a raw "what's new" payload into a `WhatsNewInfo`.
+ *
+ * Both camelCase and PascalCase field names are accepted (`version` /
+ * `Version`, `currentVersion` / `CurrentVersion`, `changelog` / `Changelog`).
+ *
+ * @param raw - Payload to normalize; may be null or undefined.
+ * @returns The normalized notice, or `null` when `raw` is missing or its
+ *   version is blank after trimming.
+ */
+export function normalizeWhatsNewInfo(
+  raw: Partial<WhatsNewInfo> | null | undefined
+): WhatsNewInfo | null {
+  if (!raw) {
+    return null;
+  }
+  const anyRaw = raw as any;
+  const version = (raw.version ?? anyRaw.Version ?? "").trim();
+  if (!version) {
+    return null;
+  }
+  // A blank currentVersion (for example an empty string in the payload) falls
+  // back to `version` just like a missing one; `??` alone would keep the "".
+  const currentVersion = (raw.currentVersion ?? anyRaw.CurrentVersion ?? "").trim() || version;
+  return {
+    version,
+    currentVersion,
+    changelog: raw.changelog ?? anyRaw.Changelog ?? "",
+  };
+}
+
+/**
+ * Compares two version strings segment by segment.
+ *
+ * Both inputs are split by `normalizeVersionParts`; a segment missing on the
+ * shorter side counts as 0, so "1.2" and "1.2.0" compare equal.
+ *
+ * @returns -1 when `left` is lower, 1 when it is higher, 0 when they are equal.
+ */
+export function compareUpdateVersion(left: string, right: string): number {
+  const lhs = normalizeVersionParts(left);
+  const rhs = normalizeVersionParts(right);
+  const segmentCount = lhs.length > rhs.length ? lhs.length : rhs.length;
+  let position = 0;
+  while (position < segmentCount) {
+    const a = lhs[position] ?? 0;
+    const b = rhs[position] ?? 0;
+    if (a !== b) {
+      return a < b ? -1 : 1;
+    }
+    position += 1;
+  }
+  return 0;
+}
+
+/**
+ * Reports whether `info.preparedVersion` is non-blank and compares higher
+ * than `info.currentVersion`.
+ */
+export function hasPreparedUpdate(info: UpdateInfo): boolean {
+  const candidate = info.preparedVersion.trim();
+  return candidate !== "" && compareUpdateVersion(info.currentVersion, candidate) < 0;
+}
+
+/**
+ * Reports whether `info.latestVersion` is non-blank and compares higher than
+ * `info.currentVersion`.
+ */
+export function hasRemoteUpdate(info: UpdateInfo): boolean {
+  const candidate = info.latestVersion.trim();
+  return candidate !== "" && compareUpdateVersion(info.currentVersion, candidate) < 0;
+}
+
+/**
+ * Picks the version string to display: the trimmed prepared version when
+ * `hasPreparedUpdate(info)` is true, otherwise the trimmed latest version.
+ */
+export function displayUpdateVersion(info: UpdateInfo): string {
+  const source = hasPreparedUpdate(info) ? info.preparedVersion : info.latestVersion;
+  return source.trim();
+}
+
+/**
+ * Splits a version string into its numeric segments.
+ *
+ * Surrounding whitespace and one leading "v"/"V" are removed, the rest is
+ * split on "." and each segment is parsed with `Number.parseInt(…, 10)`.
+ * Segments that do not parse to a finite number are dropped, not zeroed.
+ */
+function normalizeVersionParts(version: string): number[] {
+  const withoutPrefix = version.trim().replace(/^v/i, "");
+  const parts: number[] = [];
+  for (const segment of withoutPrefix.split(".")) {
+    const parsed = Number.parseInt(segment, 10);
+    if (Number.isFinite(parsed)) {
+      parts.push(parsed);
+    }
+  }
+  return parts;
+}
diff --git a/go.mod b/go.mod
index 38eaa33..1b08cea 100644
--- a/go.mod
+++ b/go.mod
@@ -4,8 +4,11 @@ go 1.25.5
require (
github.com/JohannesKaufmann/html-to-markdown v1.6.0
+ github.com/PuerkitoBio/goquery v1.9.2
github.com/asg017/sqlite-vec-go-bindings v0.1.6
github.com/bep/debounce v1.2.1
+ github.com/chromedp/cdproto v0.0.0-20250803210736-d308e07a266d
+ github.com/chromedp/chromedp v0.14.2
github.com/cloudwego/eino v0.8.6
github.com/eino-contrib/jsonschema v1.0.3
github.com/fsnotify/fsnotify v1.8.0
@@ -13,7 +16,6 @@ require (
github.com/hashicorp/go-retryablehttp v0.7.8
github.com/mymmrac/telego v1.6.0
github.com/ncruces/go-sqlite3 v0.23.3
- github.com/playwright-community/playwright-go v0.5200.1
github.com/tetratelabs/wazero v1.11.0
github.com/uptrace/bun v1.2.16
github.com/uptrace/bun/dialect/sqlitedialect v1.2.16
@@ -34,7 +36,6 @@ require (
git.sr.ht/~jackmordaunt/go-toast/v2 v2.0.3 // indirect
github.com/Microsoft/go-winio v0.6.2 // indirect
github.com/ProtonMail/go-crypto v1.3.0 // indirect
- github.com/PuerkitoBio/goquery v1.9.2 // indirect
github.com/adrg/xdg v0.5.3 // indirect
github.com/andybalholm/brotli v1.2.0 // indirect
github.com/andybalholm/cascadia v1.3.2 // indirect
@@ -43,11 +44,11 @@ require (
github.com/bytedance/gopkg v0.1.3 // indirect
github.com/bytedance/sonic v1.15.0 // indirect
github.com/bytedance/sonic/loader v0.5.0 // indirect
+ github.com/chromedp/sysutil v1.1.0 // indirect
github.com/cloudflare/circl v1.6.3 // indirect
github.com/cloudwego/base64x v0.1.6 // indirect
github.com/coder/websocket v1.8.14 // indirect
github.com/cyphar/filepath-securejoin v0.6.1 // indirect
- github.com/deckarep/golang-set/v2 v2.7.0 // indirect
github.com/dustin/go-humanize v1.0.1 // indirect
github.com/ebitengine/purego v0.9.1 // indirect
github.com/emirpasic/gods v1.18.1 // indirect
@@ -55,9 +56,11 @@ require (
github.com/go-git/gcfg v1.5.1-0.20230307220236-3a3c6141e376 // indirect
github.com/go-git/go-billy/v5 v5.8.0 // indirect
github.com/go-git/go-git/v5 v5.17.2 // indirect
- github.com/go-jose/go-jose/v3 v3.0.5 // indirect
+ github.com/go-json-experiment/json v0.0.0-20251027170946-4849db3c2f7e // indirect
github.com/go-ole/go-ole v1.3.0 // indirect
- github.com/go-stack/stack v1.8.1 // indirect
+ github.com/gobwas/httphead v0.1.0 // indirect
+ github.com/gobwas/pool v0.2.1 // indirect
+ github.com/gobwas/ws v1.4.0 // indirect
github.com/godbus/dbus/v5 v5.2.2 // indirect
github.com/golang/groupcache v0.0.0-20241129210726-2c02b8208cf8 // indirect
github.com/goph/emperror v0.17.2 // indirect
diff --git a/go.sum b/go.sum
index 442a444..d849e81 100644
--- a/go.sum
+++ b/go.sum
@@ -41,6 +41,12 @@ github.com/bytedance/sonic v1.15.0/go.mod h1:tFkWrPz0/CUCLEF4ri4UkHekCIcdnkqXw9V
github.com/bytedance/sonic/loader v0.5.0 h1:gXH3KVnatgY7loH5/TkeVyXPfESoqSBSBEiDd5VjlgE=
github.com/bytedance/sonic/loader v0.5.0/go.mod h1:AR4NYCk5DdzZizZ5djGqQ92eEhCCcdf5x77udYiSJRo=
github.com/certifi/gocertifi v0.0.0-20190105021004-abcd57078448/go.mod h1:GJKEexRPVJrBSOjoqN5VNOIKJ5Q3RViH6eu3puDRwx4=
+github.com/chromedp/cdproto v0.0.0-20250803210736-d308e07a266d h1:ZtA1sedVbEW7EW80Iz2GR3Ye6PwbJAJXjv7D74xG6HU=
+github.com/chromedp/cdproto v0.0.0-20250803210736-d308e07a266d/go.mod h1:NItd7aLkcfOA/dcMXvl8p1u+lQqioRMq/SqDp71Pb/k=
+github.com/chromedp/chromedp v0.14.2 h1:r3b/WtwM50RsBZHMUm9fsNhhzRStTHrKdr2zmwbZSzM=
+github.com/chromedp/chromedp v0.14.2/go.mod h1:rHzAv60xDE7VNy/MYtTUrYreSc0ujt2O1/C3bzctYBo=
+github.com/chromedp/sysutil v1.1.0 h1:PUFNv5EcprjqXZD9nJb9b/c9ibAbxiYo4exNWZyipwM=
+github.com/chromedp/sysutil v1.1.0/go.mod h1:WiThHUdltqCNKGc4gaU50XgYjwjYIhKWoHGPTUfWTJ8=
github.com/cloudflare/circl v1.6.3 h1:9GPOhQGF9MCYUeXyMYlqTR6a5gTrgR/fBLXvUgtVcg8=
github.com/cloudflare/circl v1.6.3/go.mod h1:2eXP6Qfat4O/Yhh8BznvKnJ+uzEoTQ6jVKJRn81BiS4=
github.com/cloudwego/base64x v0.1.6 h1:t11wG9AECkCDk5fMSoxmufanudBtJ+/HemLstXDLI2M=
@@ -54,8 +60,6 @@ github.com/cyphar/filepath-securejoin v0.6.1/go.mod h1:A8hd4EnAeyujCJRrICiOWqjS1
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
-github.com/deckarep/golang-set/v2 v2.7.0 h1:gIloKvD7yH2oip4VLhsv3JyLLFnC0Y2mlusgcvJYW5k=
-github.com/deckarep/golang-set/v2 v2.7.0/go.mod h1:VAky9rY/yGXJOLEDv3OMci+7wtDpOF4IN+y82NBOac4=
github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
github.com/ebitengine/purego v0.9.1 h1:a/k2f2HQU3Pi399RPW1MOaZyhKJL9w/xFpKAg4q1s0A=
@@ -84,21 +88,22 @@ github.com/go-git/go-git-fixtures/v4 v4.3.2-0.20231010084843-55a94097c399 h1:eMj
github.com/go-git/go-git-fixtures/v4 v4.3.2-0.20231010084843-55a94097c399/go.mod h1:1OCfN199q1Jm3HZlxleg+Dw/mwps2Wbk9frAWm+4FII=
github.com/go-git/go-git/v5 v5.17.2 h1:B+nkdlxdYrvyFK4GPXVU8w1U+YkbsgciIR7f2sZJ104=
github.com/go-git/go-git/v5 v5.17.2/go.mod h1:pW/VmeqkanRFqR6AljLcs7EA7FbZaN5MQqO7oZADXpo=
-github.com/go-jose/go-jose/v3 v3.0.5 h1:BLLJWbC4nMZOfuPVxoZIxeYsn6Nl2r1fITaJ78UQlVQ=
-github.com/go-jose/go-jose/v3 v3.0.5/go.mod h1:5b+7YgP7ZICgJDBdfjZaIt+H/9L9T/YQrVfLAMboGkQ=
github.com/go-json-experiment/json v0.0.0-20251027170946-4849db3c2f7e h1:Lf/gRkoycfOBPa42vU2bbgPurFong6zXeFtPoxholzU=
github.com/go-json-experiment/json v0.0.0-20251027170946-4849db3c2f7e/go.mod h1:uNVvRXArCGbZ508SxYYTC5v1JWoz2voff5pm25jU1Ok=
github.com/go-ole/go-ole v1.3.0 h1:Dt6ye7+vXGIKZ7Xtk4s6/xVdGDQynvom7xCFEdWr6uE=
github.com/go-ole/go-ole v1.3.0/go.mod h1:5LS6F96DhAwUc7C+1HLexzMXY1xGRSryjyPPKW6zv78=
-github.com/go-stack/stack v1.8.1 h1:ntEHSVwIt7PNXNpgPmVfMrNhLtgjlmnZha2kOpuRiDw=
-github.com/go-stack/stack v1.8.1/go.mod h1:dcoOX6HbPZSZptuspn9bctJ+N/CnF5gGygcUP3XYfe4=
+github.com/gobwas/httphead v0.1.0 h1:exrUm0f4YX0L7EBwZHuCF4GDp8aJfVeBrlLQrs6NqWU=
+github.com/gobwas/httphead v0.1.0/go.mod h1:O/RXo79gxV8G+RqlR/otEwx4Q36zl9rqC5u12GKvMCM=
+github.com/gobwas/pool v0.2.1 h1:xfeeEhW7pwmX8nuLVlqbzVc7udMDrwetjEv+TZIz1og=
+github.com/gobwas/pool v0.2.1/go.mod h1:q8bcK0KcYlCgd9e7WYLm9LpyS+YeLd8JVDW6WezmKEw=
+github.com/gobwas/ws v1.4.0 h1:CTaoG1tojrh4ucGPcoJFiAQUAsEWekEWvLy7GsVNqGs=
+github.com/gobwas/ws v1.4.0/go.mod h1:G3gNqMNtPppf5XUz7O4shetPpcZ1VJ7zt18dlUeakrc=
github.com/godbus/dbus/v5 v5.2.2 h1:TUR3TgtSVDmjiXOgAAyaZbYmIeP3DPkld3jgKGV8mXQ=
github.com/godbus/dbus/v5 v5.2.2/go.mod h1:3AAv2+hPq5rdnr5txxxRwiGjPXamgoIHgz9FPBfOp3c=
github.com/gofrs/uuid v3.2.0+incompatible/go.mod h1:b2aQJv3Z4Fp6yNu3cdSllBxTCLRxnplIgP/c0N/04lM=
github.com/golang/groupcache v0.0.0-20241129210726-2c02b8208cf8 h1:f+oWsMOmNPc8JmEHVZIycC7hBoQxHH9pNKQORJNozsQ=
github.com/golang/groupcache v0.0.0-20241129210726-2c02b8208cf8/go.mod h1:wcDNUvekVysuuOpQKo3191zZyTpiI6se1N1ULghS0sw=
github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
-github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
@@ -147,6 +152,8 @@ github.com/leaanthony/go-ansi-parser v1.6.1 h1:xd8bzARK3dErqkPFtoF9F3/HgN8UQk0ed
github.com/leaanthony/go-ansi-parser v1.6.1/go.mod h1:+vva/2y4alzVmmIEpk9QDhA7vLC5zKDTRwfZGOp3IWU=
github.com/leaanthony/u v1.1.1 h1:TUFjwDGlNX+WuwVEzDqQwC2lOv0P4uhTQw7CMFdiK7M=
github.com/leaanthony/u v1.1.1/go.mod h1:9+o6hejoRljvZ3BzdYlVL0JYCwtnAsVuN9pVTQcaRfI=
+github.com/ledongthuc/pdf v0.0.0-20220302134840-0c2507a12d80 h1:6Yzfa6GP0rIo/kULo2bwGEkFvCePZ3qHDDTC3/J9Swo=
+github.com/ledongthuc/pdf v0.0.0-20220302134840-0c2507a12d80/go.mod h1:imJHygn/1yfhB7XSJJKlFZKl/J+dCPAknuiaGOshXAs=
github.com/lmittmann/tint v1.1.2 h1:2CQzrL6rslrsyjqLDwD11bZ5OpLBPU+g3G/r5LSfS8w=
github.com/lmittmann/tint v1.1.2/go.mod h1:HIS3gSy7qNwGCj+5oRjAutErFBl4BzdQP6cJZ0NfMwE=
github.com/mailru/easyjson v0.7.7 h1:UGYAvKxe3sBsEDzO8ZeWOSlIQfWFlxbzLZe7hwFURr0=
@@ -160,8 +167,6 @@ github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWE
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mgutz/ansi v0.0.0-20170206155736-9520e82c474b h1:j7+1HpAFS1zy5+Q4qx1fWh90gTKwiN4QCGoY9TWyyO4=
github.com/mgutz/ansi v0.0.0-20170206155736-9520e82c474b/go.mod h1:01TrycV0kFyexm33Z7vhZRXopbI8J3TDReVlkTgMUxE=
-github.com/mitchellh/go-ps v1.0.0 h1:i6ampVEEF4wQFF+bkYfwYgY+F/uYJDktmvLPf7qIgjc=
-github.com/mitchellh/go-ps v1.0.0/go.mod h1:J4lOc8z8yJs6vUwklHw2XEIiT4z4C40KtWVN3nvg8Pg=
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg=
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
@@ -180,6 +185,8 @@ github.com/onsi/ginkgo v1.8.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+W
github.com/onsi/gomega v1.5.0/go.mod h1:ex+gbHU/CVuBBDIJjb2X0qEXbFg53c61hWP/1CpauHY=
github.com/onsi/gomega v1.34.1 h1:EUMJIKUjM8sKjYbtxQI9A4z2o+rruxnzNvpknOXie6k=
github.com/onsi/gomega v1.34.1/go.mod h1:kU1QgUvBDLXBJq618Xvm2LUX6rSAfRaFRTcdOeDLwwY=
+github.com/orisano/pixelmatch v0.0.0-20220722002657-fb0b55479cde h1:x0TT0RDC7UhAVbbWWBzr41ElhJx5tXPWkIHA2HWPRuw=
+github.com/orisano/pixelmatch v0.0.0-20220722002657-fb0b55479cde/go.mod h1:nZgzbfBr3hhjoZnS66nKrHmduYNpc34ny7RK4z5/HM0=
github.com/pelletier/go-toml/v2 v2.0.9 h1:uH2qQXheeefCCkuBBSLi7jCiSmj3VRh2+Goq2N7Xxu0=
github.com/pelletier/go-toml/v2 v2.0.9/go.mod h1:tJU2Z3ZkXwnxa4DPO899bsyIoywizdUvyaeZurnPPDc=
github.com/pjbgf/sha1cd v0.5.0 h1:a+UkboSi1znleCDUNT3M5YxjOnN1fz2FhN48FlwCxs0=
@@ -190,8 +197,6 @@ github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINE
github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
-github.com/playwright-community/playwright-go v0.5200.1 h1:Sm2oOuhqt0M5Y4kUi/Qh9w4cyyi3ZIWTBeGKImc2UVo=
-github.com/playwright-community/playwright-go v0.5200.1/go.mod h1:UnnyQZaqUOO5ywAZu60+N4EiWReUqX1MQBBA3Oofvf8=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/puzpuzpuz/xsync/v3 v3.5.1 h1:GJYJZwO6IdxN/IKbneznS6yPkVC+c3zyY/j19c++5Fg=
diff --git a/internal/app/bootstrap.go b/internal/app/bootstrap.go
index d3c0339..608f796 100644
--- a/internal/app/bootstrap.go
+++ b/internal/app/bootstrap.go
@@ -13,7 +13,6 @@ import (
"path/filepath"
"strconv"
"strings"
- "sync"
"time"
agentservice "dreamcreator/internal/application/agent/service"
@@ -125,63 +124,6 @@ var (
AppDescription = "An AI assistant for content creators."
)
-type providersUpdatedWindowNotifier struct {
- manager *wails.WindowManager
-}
-
-func (notifier providersUpdatedWindowNotifier) ProvidersUpdated() {
- if notifier.manager == nil {
- return
- }
- notifier.manager.EmitProvidersUpdated()
-}
-
-type settingsBroadcastAdapter struct {
- service *service.SettingsService
-
- mu sync.RWMutex
- applier func(settingsdto.Settings)
-}
-
-func newSettingsBroadcastAdapter(settingsService *service.SettingsService) *settingsBroadcastAdapter {
- return &settingsBroadcastAdapter{service: settingsService}
-}
-
-func (adapter *settingsBroadcastAdapter) SetApplier(applier func(settingsdto.Settings)) {
- if adapter == nil {
- return
- }
- adapter.mu.Lock()
- adapter.applier = applier
- adapter.mu.Unlock()
-}
-
-func (adapter *settingsBroadcastAdapter) GetSettings(ctx context.Context) (settingsdto.Settings, error) {
- if adapter == nil || adapter.service == nil {
- return settingsdto.Settings{}, errors.New("settings service unavailable")
- }
- return adapter.service.GetSettings(ctx)
-}
-
-func (adapter *settingsBroadcastAdapter) UpdateSettings(ctx context.Context, request settingsdto.UpdateSettingsRequest) (settingsdto.Settings, error) {
- if adapter == nil || adapter.service == nil {
- return settingsdto.Settings{}, errors.New("settings service unavailable")
- }
- return adapter.service.UpdateSettings(ctx, request)
-}
-
-func (adapter *settingsBroadcastAdapter) ApplySettings(updated settingsdto.Settings) {
- if adapter == nil {
- return
- }
- adapter.mu.RLock()
- applier := adapter.applier
- adapter.mu.RUnlock()
- if applier != nil {
- applier(updated)
- }
-}
-
func CreateApplication(assets fs.FS) (*application.App, error) {
appVersion := resolveVersion(os.Getenv("APP_ENV"))
startup := currentStartupContext(os.Args[1:])
@@ -464,7 +406,7 @@ func CreateApplication(assets fs.FS) (*application.App, error) {
startAccentColorWatcher(accentCtx, settingsService, windowManager)
updateCatalog := buildSoftwareUpdateService(proxyManager)
- updateService, err := buildUpdateService(proxyManager, eventBus, windowManager, updateCatalog, appVersion)
+ updateService, err := buildUpdateService(ctx, proxyManager, eventBus, windowManager, updateCatalog, appVersion)
if err != nil {
return nil, err
}
@@ -497,7 +439,7 @@ func CreateApplication(assets fs.FS) (*application.App, error) {
startModelsDevCatalogSyncWorker(ctx, modelsDevCatalog)
connectorsRepo := connectorsrepo.NewSQLiteConnectorRepository(database.Bun)
- connectorsService := connectorsservice.NewConnectorsService(connectorsRepo)
+ connectorsService := connectorsservice.NewConnectorsService(connectorsRepo, settingsService)
if err := connectorsService.EnsureDefaults(ctx); err != nil {
return nil, err
}
@@ -561,7 +503,8 @@ func CreateApplication(assets fs.FS) (*application.App, error) {
llmCallRecordRepo := llmrecordrepo.NewSQLiteRepository(database.Bun)
llmCallRecordService := llmrecord.NewService(llmCallRecordRepo, settingsService)
toolService := toolsservice.NewToolService()
- toolService.SetPolicy(gatewaytools.NewPolicyPipeline(settingsService))
+ toolPolicy := gatewaytools.NewPolicyPipeline(settingsService)
+ toolService.SetPolicy(toolPolicy)
toolExecutor := gatewaytools.NewRegistryExecutor()
toolService.SetExecutor(toolExecutor)
policyAuditStore := toolpolicyrepo.NewSQLitePolicyAuditStore(database.Bun)
@@ -643,6 +586,16 @@ func CreateApplication(assets fs.FS) (*application.App, error) {
voiceConfigRepo := voicerepo.NewSQLiteVoiceConfigRepository(database.Bun)
ttsJobRepo := voicerepo.NewSQLiteTTSJobRepository(database.Bun)
voiceService := gatewayvoice.NewService(voiceConfigRepo, ttsJobRepo, usageService, settingsService, gatewayServer)
+ builtinRequirementResolver := gatewaytools.NewBuiltinRequirementResolver(gatewaytools.BuiltinRequirementDeps{
+ Settings: settingsNotifier,
+ Assistants: assistantService,
+ Providers: providerRepo,
+ Models: modelRepo,
+ Secrets: secretRepo,
+ Voice: voiceService,
+ })
+ toolPolicy.SetRequirementsResolver(builtinRequirementResolver)
+ gatewayToolService.SetRequirementsResolver(builtinRequirementResolver)
gatewaymethods.RegisterUsage(gatewayRouter, usageService)
gatewaymethods.RegisterVoice(gatewayRouter, voiceService)
gatewaytools.RegisterBuiltinTools(ctx, toolService, toolExecutor, gatewaytools.BuiltinToolDeps{
@@ -1313,7 +1266,7 @@ func buildSoftwareUpdateService(proxyManager *proxy.Manager) *softwareupdate.Ser
})
}
-func buildUpdateService(proxyManager *proxy.Manager, bus appevents.Bus, notifier applicationupdate.Notifier, catalog *softwareupdate.Service, currentVersion string) (*applicationupdate.Service, error) {
+func buildUpdateService(ctx context.Context, proxyManager *proxy.Manager, bus appevents.Bus, notifier applicationupdate.Notifier, catalog *softwareupdate.Service, currentVersion string) (*applicationupdate.Service, error) {
httpClient := proxyManager.HTTPClient()
downloader := infrastructureupdate.NewHTTPDownloader(httpClient)
installer, err := infrastructureupdate.NewInstaller("")
@@ -1329,6 +1282,9 @@ func buildUpdateService(proxyManager *proxy.Manager, bus appevents.Bus, notifier
Notifier: notifier,
})
service.SetCurrentVersion(currentVersion)
+ if _, err := service.RestorePreparedUpdate(ctx); err != nil {
+ zap.L().Warn("update: restore prepared update failed", zap.Error(err))
+ }
return service, nil
}
diff --git a/internal/app/bootstrap_support.go b/internal/app/bootstrap_support.go
new file mode 100644
index 0000000..3f44252
--- /dev/null
+++ b/internal/app/bootstrap_support.go
@@ -0,0 +1,68 @@
+package app
+
+import (
+ "context"
+ "errors"
+ "sync"
+
+ settingsdto "dreamcreator/internal/application/settings/dto"
+ "dreamcreator/internal/application/settings/service"
+ "dreamcreator/internal/presentation/wails"
+)
+
+// providersUpdatedWindowNotifier forwards provider-update notifications to a
+// window manager.
+type providersUpdatedWindowNotifier struct {
+	manager *wails.WindowManager
+}
+
+// ProvidersUpdated calls EmitProvidersUpdated on the window manager; it does
+// nothing when no manager is set.
+func (notifier providersUpdatedWindowNotifier) ProvidersUpdated() {
+	if manager := notifier.manager; manager != nil {
+		manager.EmitProvidersUpdated()
+	}
+}
+
+// settingsBroadcastAdapter wraps the settings service and holds an optional
+// callback that ApplySettings invokes with updated settings.
+type settingsBroadcastAdapter struct {
+	service *service.SettingsService
+
+	// mu guards applier.
+	mu      sync.RWMutex
+	applier func(settingsdto.Settings)
+}
+
+// newSettingsBroadcastAdapter returns an adapter bound to settingsService with
+// no applier registered.
+func newSettingsBroadcastAdapter(settingsService *service.SettingsService) *settingsBroadcastAdapter {
+	adapter := new(settingsBroadcastAdapter)
+	adapter.service = settingsService
+	return adapter
+}
+
+// SetApplier replaces the callback used by ApplySettings. Passing nil clears
+// it. Calling this on a nil adapter is a no-op.
+func (adapter *settingsBroadcastAdapter) SetApplier(applier func(settingsdto.Settings)) {
+	if adapter == nil {
+		return
+	}
+	adapter.mu.Lock()
+	defer adapter.mu.Unlock()
+	adapter.applier = applier
+}
+
+// GetSettings delegates to the wrapped settings service. It returns an error
+// when the adapter or its service is nil.
+func (adapter *settingsBroadcastAdapter) GetSettings(ctx context.Context) (settingsdto.Settings, error) {
+	if adapter != nil && adapter.service != nil {
+		return adapter.service.GetSettings(ctx)
+	}
+	return settingsdto.Settings{}, errors.New("settings service unavailable")
+}
+
+// UpdateSettings delegates to the wrapped settings service. It returns an
+// error when the adapter or its service is nil.
+func (adapter *settingsBroadcastAdapter) UpdateSettings(ctx context.Context, request settingsdto.UpdateSettingsRequest) (settingsdto.Settings, error) {
+	if adapter != nil && adapter.service != nil {
+		return adapter.service.UpdateSettings(ctx, request)
+	}
+	return settingsdto.Settings{}, errors.New("settings service unavailable")
+}
+
+// ApplySettings passes updated to the registered applier, if any. It is a
+// no-op on a nil adapter or when no applier has been set.
+func (adapter *settingsBroadcastAdapter) ApplySettings(updated settingsdto.Settings) {
+	if adapter == nil {
+		return
+	}
+	adapter.mu.RLock()
+	callback := adapter.applier
+	adapter.mu.RUnlock()
+	if callback == nil {
+		return
+	}
+	// The callback runs after the read lock has been released.
+	callback(updated)
+}
diff --git a/internal/app/startup_context.go b/internal/app/startup_context.go
index 9902dda..1168291 100644
--- a/internal/app/startup_context.go
+++ b/internal/app/startup_context.go
@@ -3,16 +3,22 @@ package app
import "strings"
const autoStartLaunchArgument = "--autostart"
+const skipPreparedUpdateLaunchArgument = "--skip-prepared-update-once"
type startupContext struct {
launchedByAutoStart bool
+ skipPreparedUpdate bool
}
func currentStartupContext(args []string) startupContext {
+ // Scan every argument so that several markers can be combined in one launch.
+ detected := startupContext{}
for _, arg := range args {
- if strings.EqualFold(strings.TrimSpace(arg), autoStartLaunchArgument) {
- return startupContext{launchedByAutoStart: true}
+ // Markers are matched after trimming whitespace and lowercasing the argument.
+ switch marker := strings.ToLower(strings.TrimSpace(arg)); marker {
+ case autoStartLaunchArgument:
+ detected.launchedByAutoStart = true
+ case skipPreparedUpdateLaunchArgument:
+ detected.skipPreparedUpdate = true
}
}
- return startupContext{}
+ return detected
}
diff --git a/internal/app/startup_context_test.go b/internal/app/startup_context_test.go
index 5841149..f82720a 100644
--- a/internal/app/startup_context_test.go
+++ b/internal/app/startup_context_test.go
@@ -4,20 +4,25 @@ import "testing"
func TestCurrentStartupContext(t *testing.T) {
tests := []struct {
- name string
- args []string
- expected bool
+ name string
+ args []string
+ expectedAutoStart bool
+ expectedSkip bool
}{
- {name: "no marker", args: []string{"--verbose"}, expected: false},
- {name: "exact marker", args: []string{"--autostart"}, expected: true},
- {name: "marker with spaces and mixed case", args: []string{" --AutoStart "}, expected: true},
+ {name: "no marker", args: []string{"--verbose"}},
+ {name: "exact marker", args: []string{"--autostart"}, expectedAutoStart: true},
+ {name: "marker with spaces and mixed case", args: []string{" --AutoStart "}, expectedAutoStart: true},
+ {name: "skip prepared update marker", args: []string{"--skip-prepared-update-once"}, expectedSkip: true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := currentStartupContext(tt.args)
- if got.launchedByAutoStart != tt.expected {
- t.Fatalf("launchedByAutoStart = %v, want %v", got.launchedByAutoStart, tt.expected)
+ if got.launchedByAutoStart != tt.expectedAutoStart {
+ t.Fatalf("launchedByAutoStart = %v, want %v", got.launchedByAutoStart, tt.expectedAutoStart)
+ }
+ if got.skipPreparedUpdate != tt.expectedSkip {
+ t.Fatalf("skipPreparedUpdate = %v, want %v", got.skipPreparedUpdate, tt.expectedSkip)
}
})
}
diff --git a/internal/app/update_startup.go b/internal/app/update_startup.go
new file mode 100644
index 0000000..0dcfa62
--- /dev/null
+++ b/internal/app/update_startup.go
@@ -0,0 +1,83 @@
+package app
+
+import (
+ "context"
+ "os"
+ "strconv"
+ "strings"
+
+ domainupdate "dreamcreator/internal/domain/update"
+ infrastructureupdate "dreamcreator/internal/infrastructure/update"
+)
+
+// preparedUpdateStartupRunner is the narrow installer surface used by
+// maybeApplyPreparedUpdateOnLaunch, which lets tests substitute a stub.
+type preparedUpdateStartupRunner interface {
+ // PreparedUpdate returns the prepared update info and whether one was found.
+ PreparedUpdate(ctx context.Context) (domainupdate.Info, bool, error)
+ // ClearPreparedUpdate is called when the prepared version is not newer than the running one.
+ ClearPreparedUpdate(ctx context.Context) error
+ // RestartToApply is called when the prepared version is newer than the running one.
+ RestartToApply(ctx context.Context) error
+}
+
+// TryApplyPreparedUpdateOnLaunch decides at process start whether a previously
+// prepared update should be applied.
+//
+// args are the launch arguments. When they set skipPreparedUpdate (as parsed by
+// currentStartupContext) nothing is attempted. Otherwise an installer is
+// created and the decision is delegated to maybeApplyPreparedUpdateOnLaunch
+// using the version resolved from the APP_ENV environment variable.
+//
+// It returns true only when the restart-to-apply step was started without error.
+func TryApplyPreparedUpdateOnLaunch(ctx context.Context, args []string) (bool, error) {
+	if currentStartupContext(args).skipPreparedUpdate {
+		return false, nil
+	}
+
+	installer, installerErr := infrastructureupdate.NewInstaller("")
+	if installerErr != nil {
+		return false, installerErr
+	}
+
+	currentVersion := resolveVersion(os.Getenv("APP_ENV"))
+	return maybeApplyPreparedUpdateOnLaunch(ctx, currentVersion, installer)
+}
+
+// maybeApplyPreparedUpdateOnLaunch compares the running version with the
+// prepared update and either starts the restart-to-apply step, clears a stale
+// plan, or does nothing.
+//
+// It returns (true, nil) only when RestartToApply succeeded. Nothing happens
+// when installer is nil, when either version is not a comparable release
+// version (for example a dev build), or when no prepared update is found.
+func maybeApplyPreparedUpdateOnLaunch(ctx context.Context, currentVersion string, installer preparedUpdateStartupRunner) (bool, error) {
+	if installer == nil {
+		return false, nil
+	}
+
+	running := domainupdate.NormalizeVersion(currentVersion)
+	if !isComparableReleaseVersion(running) {
+		return false, nil
+	}
+
+	plan, hasPlan, lookupErr := installer.PreparedUpdate(ctx)
+	if lookupErr != nil {
+		return false, lookupErr
+	}
+	if !hasPlan {
+		return false, nil
+	}
+
+	// Prefer the explicitly prepared version; fall back to the latest known one.
+	target := domainupdate.NormalizeVersion(plan.PreparedVersion)
+	if target == "" {
+		target = domainupdate.NormalizeVersion(plan.LatestVersion)
+	}
+	if !isComparableReleaseVersion(target) {
+		return false, nil
+	}
+
+	if domainupdate.CompareVersion(running, target) < 0 {
+		if restartErr := installer.RestartToApply(ctx); restartErr != nil {
+			return false, restartErr
+		}
+		return true, nil
+	}
+
+	// The running build is already at or past the prepared version: drop the plan.
+	if clearErr := installer.ClearPreparedUpdate(ctx); clearErr != nil {
+		return false, clearErr
+	}
+	return false, nil
+}
+
+// isComparableReleaseVersion reports whether version, after normalization, is
+// a dot-separated list of unsigned decimal numbers (for example "2.0.7").
+//
+// Empty input, empty segments and non-numeric segments such as "dev" are
+// rejected. Signed segments ("+1", "-2") are rejected as well: strconv.Atoi
+// would accept them, but they are not valid release version components.
+func isComparableReleaseVersion(version string) bool {
+	normalized := domainupdate.NormalizeVersion(version)
+	if normalized == "" {
+		return false
+	}
+	for _, part := range strings.Split(normalized, ".") {
+		// ParseUint fails on empty strings, signs and any non-digit character.
+		if _, err := strconv.ParseUint(part, 10, 64); err != nil {
+			return false
+		}
+	}
+	return true
+}
diff --git a/internal/app/update_startup_test.go b/internal/app/update_startup_test.go
new file mode 100644
index 0000000..2181b69
--- /dev/null
+++ b/internal/app/update_startup_test.go
@@ -0,0 +1,119 @@
+package app
+
+import (
+ "context"
+ "errors"
+ "testing"
+
+ domainupdate "dreamcreator/internal/domain/update"
+)
+
+// preparedUpdateStartupStub is a test double for preparedUpdateStartupRunner.
+// It returns the configured values and records which methods were invoked.
+type preparedUpdateStartupStub struct {
+ prepared domainupdate.Info
+ found bool
+ preparedErr error
+ clearInvoked bool
+ restartInvoked bool
+ restartErr error
+}
+
+// PreparedUpdate returns the configured prepared info, found flag and error.
+func (stub *preparedUpdateStartupStub) PreparedUpdate(_ context.Context) (domainupdate.Info, bool, error) {
+ return stub.prepared, stub.found, stub.preparedErr
+}
+
+// ClearPreparedUpdate records the call and always succeeds.
+func (stub *preparedUpdateStartupStub) ClearPreparedUpdate(_ context.Context) error {
+ stub.clearInvoked = true
+ return nil
+}
+
+// RestartToApply records the call and returns the configured restartErr.
+func (stub *preparedUpdateStartupStub) RestartToApply(_ context.Context) error {
+ stub.restartInvoked = true
+ return stub.restartErr
+}
+
+// Running 2.0.6 with a prepared 2.0.7 must invoke RestartToApply and report
+// that the update was applied.
+func TestMaybeApplyPreparedUpdateOnLaunchStartsHelperForOlderCurrentVersion(t *testing.T) {
+ t.Parallel()
+
+ installer := &preparedUpdateStartupStub{
+ found: true,
+ prepared: domainupdate.Info{
+ PreparedVersion: "2.0.7",
+ },
+ }
+
+ applied, err := maybeApplyPreparedUpdateOnLaunch(context.Background(), "2.0.6", installer)
+ if err != nil {
+ t.Fatalf("maybeApplyPreparedUpdateOnLaunch failed: %v", err)
+ }
+ if !applied {
+ t.Fatal("expected prepared update to be applied on launch")
+ }
+ if !installer.restartInvoked {
+ t.Fatal("expected restart helper to be invoked")
+ }
+}
+
+// When the running version equals the prepared version the plan is stale: it
+// must be cleared and no update applied.
+func TestMaybeApplyPreparedUpdateOnLaunchClearsStalePreparedPlan(t *testing.T) {
+ t.Parallel()
+
+ installer := &preparedUpdateStartupStub{
+ found: true,
+ prepared: domainupdate.Info{
+ PreparedVersion: "2.0.7",
+ },
+ }
+
+ applied, err := maybeApplyPreparedUpdateOnLaunch(context.Background(), "2.0.7", installer)
+ if err != nil {
+ t.Fatalf("maybeApplyPreparedUpdateOnLaunch failed: %v", err)
+ }
+ if applied {
+ t.Fatal("did not expect prepared update to apply when current version is already latest")
+ }
+ if !installer.clearInvoked {
+ t.Fatal("expected stale prepared plan to be cleared")
+ }
+}
+
+// A non-numeric current version ("dev") must never trigger the restart helper,
+// even when a prepared update exists.
+func TestMaybeApplyPreparedUpdateOnLaunchSkipsNonReleaseVersion(t *testing.T) {
+ t.Parallel()
+
+ installer := &preparedUpdateStartupStub{
+ found: true,
+ prepared: domainupdate.Info{
+ PreparedVersion: "2.0.7",
+ },
+ }
+
+ applied, err := maybeApplyPreparedUpdateOnLaunch(context.Background(), "dev", installer)
+ if err != nil {
+ t.Fatalf("maybeApplyPreparedUpdateOnLaunch failed: %v", err)
+ }
+ if applied {
+ t.Fatal("did not expect dev build to auto-apply prepared update")
+ }
+ if installer.restartInvoked {
+ t.Fatal("did not expect restart helper to run")
+ }
+}
+
+// A failure from RestartToApply must be returned to the caller and the update
+// must not be reported as applied.
+func TestMaybeApplyPreparedUpdateOnLaunchReturnsRestartError(t *testing.T) {
+ t.Parallel()
+
+ restartErr := errors.New("helper launch failed")
+ installer := &preparedUpdateStartupStub{
+ found: true,
+ prepared: domainupdate.Info{
+ PreparedVersion: "2.0.7",
+ },
+ restartErr: restartErr,
+ }
+
+ applied, err := maybeApplyPreparedUpdateOnLaunch(context.Background(), "2.0.6", installer)
+ if !errors.Is(err, restartErr) {
+ t.Fatalf("expected restart error, got %v", err)
+ }
+ if applied {
+ t.Fatal("did not expect prepared update to report success")
+ }
+}
diff --git a/internal/application/agentruntime/tool_executor.go b/internal/application/agentruntime/tool_executor.go
index a216fff..7776313 100644
--- a/internal/application/agentruntime/tool_executor.go
+++ b/internal/application/agentruntime/tool_executor.go
@@ -14,9 +14,10 @@ import (
var ErrToolNotFound = errors.New("tool not found")
type ToolDefinition struct {
- Name string
- Type string
- Invoke func(ctx context.Context, args string) (string, error)
+ Name string
+ Type string
+ SchemaJSON string
+ Invoke func(ctx context.Context, args string) (string, error)
}
type ToolExecutor struct {
diff --git a/internal/application/agentruntime/tool_validator.go b/internal/application/agentruntime/tool_validator.go
index 93b1c0f..a5aa2d3 100644
--- a/internal/application/agentruntime/tool_validator.go
+++ b/internal/application/agentruntime/tool_validator.go
@@ -4,6 +4,7 @@ import (
"encoding/json"
"errors"
"fmt"
+ "reflect"
"strings"
"github.com/cloudwego/eino/schema"
@@ -12,15 +13,18 @@ import (
var (
ErrToolNameRequired = errors.New("tool name is required")
ErrToolArgsInvalid = errors.New("tool args must be a json object")
+ ErrToolArgsSchema = errors.New("tool args do not match schema")
)
type ToolValidator interface {
Validate(call schema.ToolCall) error
}
-type JSONToolValidator struct{}
+type JSONToolValidator struct {
+ Tools map[string]ToolDefinition
+}
-func (JSONToolValidator) Validate(call schema.ToolCall) error {
+func (validator JSONToolValidator) Validate(call schema.ToolCall) error {
name := strings.TrimSpace(call.Function.Name)
if name == "" {
return ErrToolNameRequired
@@ -36,5 +40,249 @@ func (JSONToolValidator) Validate(call schema.ToolCall) error {
if _, ok := decoded.(map[string]any); !ok {
return ErrToolArgsInvalid
}
+ if definition, ok := validator.Tools[name]; ok {
+ if err := validateToolArgsAgainstSchema(decoded, strings.TrimSpace(definition.SchemaJSON)); err != nil {
+ return err
+ }
+ }
+ return nil
+}
+
+// validateToolArgsAgainstSchema validates decoded tool arguments against a JSON
+// Schema document given as text.
+//
+// A blank schema, or one that does not parse as a JSON object, is treated as
+// "no schema" and validation passes. Otherwise the result of
+// validateJSONSchemaValue, rooted at path "$", is returned.
+func validateToolArgsAgainstSchema(value any, schemaJSON string) error {
+	if strings.TrimSpace(schemaJSON) == "" {
+		return nil
+	}
+	var parsed map[string]any
+	if json.Unmarshal([]byte(schemaJSON), &parsed) != nil {
+		// An unparsable schema is ignored rather than reported as an argument error.
+		return nil
+	}
+	return validateJSONSchemaValue(value, parsed, "$")
+}
+
+// validateJSONSchemaValue checks value against the supported subset of JSON
+// Schema keywords in schema: allOf, anyOf, oneOf, const, enum and type, plus
+// the object/array keywords handled by validateJSONSchemaType.
+//
+// path is the JSON-path-like location used in error messages. An empty schema
+// accepts everything. Failures wrap ErrToolArgsSchema.
+func validateJSONSchemaValue(value any, schema map[string]any, path string) error {
+ if len(schema) == 0 {
+ return nil
+ }
+ // allOf: every sub-schema must accept the value. A non-object entry becomes a
+ // nil map, which the length check above treats as "accept".
+ if allOf, ok := schema["allOf"].([]any); ok {
+ for _, item := range allOf {
+ child, _ := item.(map[string]any)
+ if err := validateJSONSchemaValue(value, child, path); err != nil {
+ return err
+ }
+ }
+ }
+ // anyOf: the first accepting sub-schema wins; otherwise the last failure is reported.
+ if anyOf, ok := schema["anyOf"].([]any); ok {
+ matched := false
+ var lastErr error
+ for _, item := range anyOf {
+ child, _ := item.(map[string]any)
+ err := validateJSONSchemaValue(value, child, path)
+ if err == nil {
+ matched = true
+ break
+ }
+ lastErr = err
+ }
+ if !matched {
+ if lastErr != nil {
+ return lastErr
+ }
+ return fmt.Errorf("%w: %s must match at least one schema", ErrToolArgsSchema, path)
+ }
+ }
+ // oneOf: exactly one sub-schema must accept. With zero matches the last
+ // sub-schema error is surfaced; with several matches a generic error is used.
+ if oneOf, ok := schema["oneOf"].([]any); ok {
+ matches := 0
+ var lastErr error
+ for _, item := range oneOf {
+ child, _ := item.(map[string]any)
+ err := validateJSONSchemaValue(value, child, path)
+ if err == nil {
+ matches++
+ continue
+ }
+ lastErr = err
+ }
+ if matches != 1 {
+ if lastErr != nil && matches == 0 {
+ return lastErr
+ }
+ return fmt.Errorf("%w: %s must match exactly one schema", ErrToolArgsSchema, path)
+ }
+ }
+ if constValue, ok := schema["const"]; ok && !jsonSchemaValuesEqual(value, constValue) {
+ return fmt.Errorf("%w: %s must equal %v", ErrToolArgsSchema, path, constValue)
+ }
+ // An empty enum list is ignored rather than rejecting every value.
+ if enumValues, ok := schema["enum"].([]any); ok && len(enumValues) > 0 {
+ matched := false
+ for _, enumValue := range enumValues {
+ if jsonSchemaValuesEqual(value, enumValue) {
+ matched = true
+ break
+ }
+ }
+ if !matched {
+ return fmt.Errorf("%w: %s must be one of %v", ErrToolArgsSchema, path, enumValues)
+ }
+ }
+ // "type" may be a single name or a list of alternatives.
+ switch schemaType := schema["type"].(type) {
+ case string:
+ if err := validateJSONSchemaType(value, schema, path, schemaType); err != nil {
+ return err
+ }
+ case []any:
+ var lastErr error
+ for _, raw := range schemaType {
+ typeName, _ := raw.(string)
+ if typeName == "" {
+ continue
+ }
+ // The first alternative that accepts the value ends validation of this
+ // schema, so the implicit object/array checks below are skipped.
+ if err := validateJSONSchemaType(value, schema, path, typeName); err == nil {
+ return nil
+ } else {
+ lastErr = err
+ }
+ }
+ if lastErr != nil {
+ return lastErr
+ }
+ }
+ // Object or array keywords imply the corresponding type check. This runs even
+ // when an explicit "type" was already validated above.
+ if _, hasProperties := schema["properties"]; hasProperties || schema["required"] != nil || schema["additionalProperties"] != nil {
+ return validateJSONSchemaType(value, schema, path, "object")
+ }
+ if _, hasItems := schema["items"]; hasItems {
+ return validateJSONSchemaType(value, schema, path, "array")
+ }
return nil
}
+
+// validateJSONSchemaType checks value against one JSON Schema type name and
+// the keywords of schema that belong to that type.
+//
+//   - "object": required fields, per-property schemas and additionalProperties
+//     (either a boolean or a schema).
+//   - "array": every element is validated against "items" when it is a schema.
+//   - "string", "boolean": Go type check only.
+//   - "integer", "number": numeric check plus minimum/maximum.
+//   - "null": value must be JSON null (Go nil).
+//
+// Unknown type names accept any value. Failures wrap ErrToolArgsSchema and
+// mention path.
+func validateJSONSchemaType(value any, schema map[string]any, path string, schemaType string) error {
+	switch schemaType {
+	case "object":
+		obj, ok := value.(map[string]any)
+		if !ok {
+			return fmt.Errorf("%w: %s must be an object", ErrToolArgsSchema, path)
+		}
+		properties, _ := schema["properties"].(map[string]any)
+		for _, field := range jsonSchemaStringSlice(schema["required"]) {
+			if _, exists := obj[field]; !exists {
+				return fmt.Errorf("%w: %s.%s is required", ErrToolArgsSchema, path, field)
+			}
+		}
+		for key, raw := range obj {
+			childPath := path + "." + key
+			if propertySchema, hasProperty := properties[key]; hasProperty {
+				// Declared properties are validated against their own schema and
+				// never fall through to additionalProperties.
+				if child, ok := propertySchema.(map[string]any); ok {
+					if err := validateJSONSchemaValue(raw, child, childPath); err != nil {
+						return err
+					}
+				}
+				continue
+			}
+			switch additional := schema["additionalProperties"].(type) {
+			case bool:
+				if !additional {
+					return fmt.Errorf("%w: %s is not allowed", ErrToolArgsSchema, childPath)
+				}
+			case map[string]any:
+				if err := validateJSONSchemaValue(raw, additional, childPath); err != nil {
+					return err
+				}
+			}
+		}
+		return nil
+	case "array":
+		items, ok := value.([]any)
+		if !ok {
+			return fmt.Errorf("%w: %s must be an array", ErrToolArgsSchema, path)
+		}
+		if itemSchema, ok := schema["items"].(map[string]any); ok {
+			for index, item := range items {
+				if err := validateJSONSchemaValue(item, itemSchema, fmt.Sprintf("%s[%d]", path, index)); err != nil {
+					return err
+				}
+			}
+		}
+		return nil
+	case "string":
+		if _, ok := value.(string); !ok {
+			return fmt.Errorf("%w: %s must be a string", ErrToolArgsSchema, path)
+		}
+		return nil
+	case "boolean":
+		if _, ok := value.(bool); !ok {
+			return fmt.Errorf("%w: %s must be a boolean", ErrToolArgsSchema, path)
+		}
+		return nil
+	case "null":
+		// Without this case "null" fell into the default branch and accepted any
+		// value, so a union such as ["string","null"] validated everything.
+		if value != nil {
+			return fmt.Errorf("%w: %s must be null", ErrToolArgsSchema, path)
+		}
+		return nil
+	case "integer":
+		number, ok := jsonSchemaAsFloat64(value)
+		if !ok || float64(int64(number)) != number {
+			return fmt.Errorf("%w: %s must be an integer", ErrToolArgsSchema, path)
+		}
+		if minimum, ok := jsonSchemaAsFloat64(schema["minimum"]); ok && number < minimum {
+			return fmt.Errorf("%w: %s must be >= %v", ErrToolArgsSchema, path, minimum)
+		}
+		if maximum, ok := jsonSchemaAsFloat64(schema["maximum"]); ok && number > maximum {
+			return fmt.Errorf("%w: %s must be <= %v", ErrToolArgsSchema, path, maximum)
+		}
+		return nil
+	case "number":
+		number, ok := jsonSchemaAsFloat64(value)
+		if !ok {
+			return fmt.Errorf("%w: %s must be a number", ErrToolArgsSchema, path)
+		}
+		if minimum, ok := jsonSchemaAsFloat64(schema["minimum"]); ok && number < minimum {
+			return fmt.Errorf("%w: %s must be >= %v", ErrToolArgsSchema, path, minimum)
+		}
+		if maximum, ok := jsonSchemaAsFloat64(schema["maximum"]); ok && number > maximum {
+			return fmt.Errorf("%w: %s must be <= %v", ErrToolArgsSchema, path, maximum)
+		}
+		return nil
+	default:
+		// Unknown type names are not enforced.
+		return nil
+	}
+}
+
+// jsonSchemaStringSlice converts a decoded JSON value into a list of strings.
+//
+// A []string is copied as-is. For a []any, only string elements that are not
+// blank after trimming are kept (the untrimmed text is returned). Any other
+// input yields nil.
+func jsonSchemaStringSlice(value any) []string {
+	if direct, ok := value.([]string); ok {
+		return append([]string(nil), direct...)
+	}
+	raw, ok := value.([]any)
+	if !ok {
+		return nil
+	}
+	names := make([]string, 0, len(raw))
+	for index := range raw {
+		text, isString := raw[index].(string)
+		if !isString || strings.TrimSpace(text) == "" {
+			continue
+		}
+		names = append(names, text)
+	}
+	return names
+}
+
+// jsonSchemaAsFloat64 widens the supported Go numeric types (float64, float32,
+// int, int64, int32) to float64. The boolean result is false for every other
+// type, including nil.
+func jsonSchemaAsFloat64(value any) (float64, bool) {
+	var number float64
+	switch concrete := value.(type) {
+	case int:
+		number = float64(concrete)
+	case int32:
+		number = float64(concrete)
+	case int64:
+		number = float64(concrete)
+	case float32:
+		number = float64(concrete)
+	case float64:
+		number = concrete
+	default:
+		return 0, false
+	}
+	return number, true
+}
+
+// jsonSchemaValuesEqual compares two decoded JSON values. When both sides are
+// numeric they are compared as float64; otherwise reflect.DeepEqual decides.
+func jsonSchemaValuesEqual(left any, right any) bool {
+	if leftNumber, leftOK := jsonSchemaAsFloat64(left); leftOK {
+		if rightNumber, rightOK := jsonSchemaAsFloat64(right); rightOK {
+			return leftNumber == rightNumber
+		}
+	}
+	return reflect.DeepEqual(left, right)
+}
diff --git a/internal/application/agentruntime/tool_validator_test.go b/internal/application/agentruntime/tool_validator_test.go
index 1faa9d3..a26da28 100644
--- a/internal/application/agentruntime/tool_validator_test.go
+++ b/internal/application/agentruntime/tool_validator_test.go
@@ -33,3 +33,76 @@ func TestJSONToolValidatorAcceptsObjectArgs(t *testing.T) {
t.Fatalf("expected valid args, got %v", err)
}
}
+
+// An "open" action without "url" violates the anyOf branch nested in allOf and
+// must be rejected with ErrToolArgsSchema.
+func TestJSONToolValidatorRejectsSchemaViolation(t *testing.T) {
+ validator := JSONToolValidator{
+ Tools: map[string]ToolDefinition{
+ "browser": {
+ Name: "browser",
+ SchemaJSON: `{
+ "type":"object",
+ "properties":{
+ "action":{"type":"string","enum":["open","act"]},
+ "url":{"type":"string"},
+ "request":{
+ "type":"object",
+ "properties":{"kind":{"type":"string"}},
+ "required":["kind"]
+ }
+ },
+ "required":["action"],
+ "allOf":[
+ {
+ "anyOf":[
+ {"properties":{"action":{"const":"open"}},"required":["action","url"]},
+ {"properties":{"action":{"const":"act"}},"required":["action","request"]}
+ ]
+ }
+ ]
+ }`,
+ },
+ },
+ }
+
+ err := validator.Validate(schema.ToolCall{
+ Function: schema.FunctionCall{
+ Name: "browser",
+ Arguments: `{"action":"open"}`,
+ },
+ })
+ if err == nil || !errors.Is(err, ErrToolArgsSchema) {
+ t.Fatalf("expected schema error for missing url, got %v", err)
+ }
+}
+
+// A nested object that is missing its own required field ("request.kind") must
+// be rejected with ErrToolArgsSchema.
+func TestJSONToolValidatorRejectsNestedRequiredSchemaViolation(t *testing.T) {
+ validator := JSONToolValidator{
+ Tools: map[string]ToolDefinition{
+ "browser": {
+ Name: "browser",
+ SchemaJSON: `{
+ "type":"object",
+ "properties":{
+ "action":{"type":"string","enum":["act"]},
+ "request":{
+ "type":"object",
+ "properties":{"kind":{"type":"string"}},
+ "required":["kind"]
+ }
+ },
+ "required":["action","request"]
+ }`,
+ },
+ },
+ }
+
+ err := validator.Validate(schema.ToolCall{
+ Function: schema.FunctionCall{
+ Name: "browser",
+ Arguments: `{"action":"act","request":{}}`,
+ },
+ })
+ if err == nil || !errors.Is(err, ErrToolArgsSchema) {
+ t.Fatalf("expected schema error for missing request.kind, got %v", err)
+ }
+}
diff --git a/internal/application/browsercdp/connectors.go b/internal/application/browsercdp/connectors.go
new file mode 100644
index 0000000..1cdc2d5
--- /dev/null
+++ b/internal/application/browsercdp/connectors.go
@@ -0,0 +1,79 @@
+package browsercdp
+
+import (
+ "context"
+ "sort"
+ "strings"
+
+ connectorsdto "dreamcreator/internal/application/connectors/dto"
+ appcookies "dreamcreator/internal/application/cookies"
+ "dreamcreator/internal/application/sitepolicy"
+)
+
+// ConnectorsReader lists the configured connectors; it is the only dependency
+// ResolveConnectorCookiesForURL needs.
+type ConnectorsReader interface {
+ ListConnectors(ctx context.Context) ([]connectorsdto.Connector, error)
+}
+
+// ConnectorCookieProvider resolves the cookie records to use for a given URL.
+type ConnectorCookieProvider interface {
+ ResolveCookiesForURL(ctx context.Context, rawURL string) ([]appcookies.Record, error)
+}
+
+// ConnectorCookieProviderFunc adapts a plain function to ConnectorCookieProvider.
+type ConnectorCookieProviderFunc func(ctx context.Context, rawURL string) ([]appcookies.Record, error)
+
+// ResolveCookiesForURL calls fn. A nil function yields no cookies and no error.
+func (fn ConnectorCookieProviderFunc) ResolveCookiesForURL(ctx context.Context, rawURL string) ([]appcookies.Record, error) {
+ if fn == nil {
+ return nil, nil
+ }
+ return fn(ctx, rawURL)
+}
+
+// ResolveConnectorCookiesForURL returns the cookies of the first connector
+// whose site policy matches rawURL.
+//
+// Connectors are examined in the order returned by ListConnectors. String
+// fields other than Value are trimmed, and the records are sorted by domain,
+// then path, then name. It returns (nil, nil) when connectors is nil or no
+// connector matches, and forwards any ListConnectors error.
+func ResolveConnectorCookiesForURL(ctx context.Context, connectors ConnectorsReader, rawURL string) ([]appcookies.Record, error) {
+	if connectors == nil {
+		return nil, nil
+	}
+	available, listErr := connectors.ListConnectors(ctx)
+	if listErr != nil {
+		return nil, listErr
+	}
+	for _, connector := range available {
+		policy, known := sitepolicy.ForConnectorType(connector.Type)
+		if !known {
+			continue
+		}
+		if !sitepolicy.MatchDomains(rawURL, policy.Domains) {
+			continue
+		}
+
+		records := make([]appcookies.Record, 0, len(connector.Cookies))
+		for _, source := range connector.Cookies {
+			record := appcookies.Record{
+				Name:     strings.TrimSpace(source.Name),
+				Value:    source.Value,
+				Domain:   strings.TrimSpace(source.Domain),
+				Path:     strings.TrimSpace(source.Path),
+				Expires:  source.Expires,
+				HttpOnly: source.HttpOnly,
+				Secure:   source.Secure,
+				SameSite: strings.TrimSpace(source.SameSite),
+			}
+			records = append(records, record)
+		}
+		sort.Slice(records, func(a, b int) bool {
+			if records[a].Domain != records[b].Domain {
+				return records[a].Domain < records[b].Domain
+			}
+			if records[a].Path != records[b].Path {
+				return records[a].Path < records[b].Path
+			}
+			return records[a].Name < records[b].Name
+		})
+		return records, nil
+	}
+	return nil, nil
+}
+
+// ConnectorTypeForURL returns the connector type of the site policy that
+// matches rawURL, or "" when no policy matches.
+func ConnectorTypeForURL(rawURL string) string {
+	if policy, matched := sitepolicy.ForURL(rawURL); matched {
+		return policy.ConnectorType
+	}
+	return ""
+}
diff --git a/internal/application/browsercdp/connectors_test.go b/internal/application/browsercdp/connectors_test.go
new file mode 100644
index 0000000..7a7bb73
--- /dev/null
+++ b/internal/application/browsercdp/connectors_test.go
@@ -0,0 +1,84 @@
+package browsercdp
+
+import (
+ "context"
+ "errors"
+ "testing"
+
+ connectorsdto "dreamcreator/internal/application/connectors/dto"
+)
+
+// connectorsReaderStub is a ConnectorsReader test double that returns either a
+// fixed error or a copy of a fixed connector list.
+type connectorsReaderStub struct {
+ items []connectorsdto.Connector
+ err error
+}
+
+// ListConnectors returns stub.err when set, otherwise a copy of stub.items.
+func (stub connectorsReaderStub) ListConnectors(context.Context) ([]connectorsdto.Connector, error) {
+ if stub.err != nil {
+ return nil, stub.err
+ }
+ return append([]connectorsdto.Connector(nil), stub.items...), nil
+}
+
+// A YouTube URL must resolve to the "google" connector's cookies and not to
+// the unrelated "github" connector.
+// NOTE(review): relies on sitepolicy mapping youtube.com to the google
+// connector type — confirm against the sitepolicy package.
+func TestResolveConnectorCookiesForURL_MatchesConnectorPolicy(t *testing.T) {
+ t.Parallel()
+
+ cookies, err := ResolveConnectorCookiesForURL(context.Background(), connectorsReaderStub{
+ items: []connectorsdto.Connector{
+ {
+ Type: "google",
+ Cookies: []connectorsdto.ConnectorCookie{
+ {Name: "SID", Value: "google-cookie", Domain: ".google.com", Path: "/"},
+ },
+ },
+ {
+ Type: "github",
+ Cookies: []connectorsdto.ConnectorCookie{
+ {Name: "logged_in", Value: "yes", Domain: ".github.com", Path: "/"},
+ },
+ },
+ },
+ }, "https://www.youtube.com/watch?v=test")
+ if err != nil {
+ t.Fatalf("resolve cookies: %v", err)
+ }
+ if len(cookies) != 1 {
+ t.Fatalf("expected 1 cookie, got %d", len(cookies))
+ }
+ if cookies[0].Name != "SID" || cookies[0].Value != "google-cookie" {
+ t.Fatalf("unexpected cookie: %#v", cookies[0])
+ }
+}
+
+// A URL that matches no connector policy must yield no cookies and no error.
+func TestResolveConnectorCookiesForURL_ReturnsNilWhenNoMatch(t *testing.T) {
+ t.Parallel()
+
+ cookies, err := ResolveConnectorCookiesForURL(context.Background(), connectorsReaderStub{
+ items: []connectorsdto.Connector{
+ {
+ Type: "github",
+ Cookies: []connectorsdto.ConnectorCookie{
+ {Name: "logged_in", Value: "yes", Domain: ".github.com", Path: "/"},
+ },
+ },
+ },
+ }, "https://example.com/")
+ if err != nil {
+ t.Fatalf("resolve cookies: %v", err)
+ }
+ if len(cookies) != 0 {
+ t.Fatalf("expected no cookies, got %#v", cookies)
+ }
+}
+
+// An error from the connectors reader must be returned to the caller unchanged.
+func TestResolveConnectorCookiesForURL_PropagatesReaderError(t *testing.T) {
+ t.Parallel()
+
+ expectedErr := errors.New("connectors unavailable")
+ _, err := ResolveConnectorCookiesForURL(context.Background(), connectorsReaderStub{
+ err: expectedErr,
+ }, "https://github.com/openai")
+ if !errors.Is(err, expectedErr) {
+ t.Fatalf("expected %v, got %v", expectedErr, err)
+ }
+}
diff --git a/internal/application/browsercdp/cookies.go b/internal/application/browsercdp/cookies.go
new file mode 100644
index 0000000..26c5335
--- /dev/null
+++ b/internal/application/browsercdp/cookies.go
@@ -0,0 +1,117 @@
+package browsercdp
+
+import (
+ "context"
+ "net/url"
+ "strings"
+ "time"
+
+ "github.com/chromedp/cdproto/cdp"
+ "github.com/chromedp/cdproto/network"
+ "github.com/chromedp/cdproto/storage"
+
+ appcookies "dreamcreator/internal/application/cookies"
+)
+
+// SetCookies converts records to CDP cookie parameters and sends them with
+// network.SetCookies on ctx. It is a no-op when no usable record remains.
+func SetCookies(ctx context.Context, targetURL string, records []appcookies.Record) error {
+	if params := buildCookieParams(targetURL, records); len(params) > 0 {
+		return network.SetCookies(params).Do(ctx)
+	}
+	return nil
+}
+
+// SetCookiesOnBrowser converts records to CDP cookie parameters and sends them
+// with storage.SetCookies on ctx. It is a no-op when no usable record remains.
+func SetCookiesOnBrowser(ctx context.Context, targetURL string, records []appcookies.Record) error {
+	if params := buildCookieParams(targetURL, records); len(params) > 0 {
+		return storage.SetCookies(params).Do(ctx)
+	}
+	return nil
+}
+
+// GetAllCookies reads cookies with network.GetCookies on ctx and maps them to
+// application records.
+// NOTE(review): network.GetCookies may be scoped to the current page's URLs
+// rather than the whole browser — confirm the name matches the intent.
+func GetAllCookies(ctx context.Context) ([]appcookies.Record, error) {
+	cdpCookies, err := network.GetCookies().Do(ctx)
+	if err == nil {
+		return mapCDPCookies(cdpCookies), nil
+	}
+	return nil, err
+}
+
+// GetStorageCookies reads cookies with storage.GetCookies on ctx and maps them
+// to application records.
+func GetStorageCookies(ctx context.Context) ([]appcookies.Record, error) {
+	cdpCookies, err := storage.GetCookies().Do(ctx)
+	if err == nil {
+		return mapCDPCookies(cdpCookies), nil
+	}
+	return nil, err
+}
+
+// buildCookieParams converts application cookie records into CDP cookie
+// parameters.
+//
+// Records with a blank name are dropped. An empty path defaults to "/". A
+// positive Expires is treated as Unix seconds. When a record has no domain,
+// the scheme and hostname of targetURL are used as the cookie URL instead (if
+// targetURL parses and has a hostname). SameSite is matched case-insensitively
+// against "lax", "strict" and "none"; other values leave it unset.
+//
+// It returns nil for an empty input slice.
+func buildCookieParams(targetURL string, records []appcookies.Record) []*network.CookieParam {
+	if len(records) == 0 {
+		return nil
+	}
+	sameSiteByName := map[string]network.CookieSameSite{
+		"lax":    network.CookieSameSiteLax,
+		"strict": network.CookieSameSiteStrict,
+		"none":   network.CookieSameSiteNone,
+	}
+	params := make([]*network.CookieParam, 0, len(records))
+	for _, record := range records {
+		name := strings.TrimSpace(record.Name)
+		if name == "" {
+			continue
+		}
+		cookiePath := strings.TrimSpace(record.Path)
+		if cookiePath == "" {
+			cookiePath = "/"
+		}
+		param := &network.CookieParam{
+			Name:     name,
+			Value:    record.Value,
+			Domain:   strings.TrimSpace(record.Domain),
+			Path:     cookiePath,
+			HTTPOnly: record.HttpOnly,
+			Secure:   record.Secure,
+		}
+		if record.Expires > 0 {
+			expires := cdp.TimeSinceEpoch(time.Unix(record.Expires, 0))
+			param.Expires = &expires
+		}
+		if param.Domain == "" {
+			parsed, parseErr := url.Parse(strings.TrimSpace(targetURL))
+			if parseErr == nil && parsed.Hostname() != "" {
+				param.URL = strings.TrimSpace(parsed.Scheme) + "://" + parsed.Hostname()
+			}
+		}
+		if sameSite, known := sameSiteByName[strings.ToLower(strings.TrimSpace(record.SameSite))]; known {
+			param.SameSite = sameSite
+		}
+		params = append(params, param)
+	}
+	return params
+}
+
+// mapCDPCookies converts CDP cookies into application records, skipping nil
+// entries. SameSite is rendered as "lax", "strict", "none" or "" for anything
+// else, and Expires is truncated to an int64.
+func mapCDPCookies(items []*network.Cookie) []appcookies.Record {
+	records := make([]appcookies.Record, 0, len(items))
+	for _, cookie := range items {
+		if cookie == nil {
+			continue
+		}
+		record := appcookies.Record{
+			Name:     cookie.Name,
+			Value:    cookie.Value,
+			Domain:   cookie.Domain,
+			Path:     cookie.Path,
+			Expires:  int64(cookie.Expires),
+			HttpOnly: cookie.HTTPOnly,
+			Secure:   cookie.Secure,
+		}
+		if cookie.SameSite == network.CookieSameSiteLax {
+			record.SameSite = "lax"
+		} else if cookie.SameSite == network.CookieSameSiteStrict {
+			record.SameSite = "strict"
+		} else if cookie.SameSite == network.CookieSameSiteNone {
+			record.SameSite = "none"
+		}
+		records = append(records, record)
+	}
+	return records
+}
diff --git a/internal/application/browsercdp/detect.go b/internal/application/browsercdp/detect.go
new file mode 100644
index 0000000..bf8333f
--- /dev/null
+++ b/internal/application/browsercdp/detect.go
@@ -0,0 +1,295 @@
+package browsercdp
+
+import (
+ "context"
+ "fmt"
+ "net/http"
+ "os"
+ "os/exec"
+ "path/filepath"
+ "runtime"
+ "strings"
+ "sync"
+ "time"
+)
+
+// BrowserID identifies a supported Chromium-family browser.
+type BrowserID string
+
+// Browsers that scanCandidates probes for, listed in probe order.
+const (
+ BrowserChrome BrowserID = "chrome"
+ BrowserChromium BrowserID = "chromium"
+ BrowserEdge BrowserID = "edge"
+ BrowserBrave BrowserID = "brave"
+)
+
+// Candidate is the detection result for one browser: either Available with a
+// resolved ExecPath, or unavailable with an Error message.
+type Candidate struct {
+ ID BrowserID `json:"id"`
+ Label string `json:"label"`
+ ExecPath string `json:"execPath,omitempty"`
+ Available bool `json:"available"`
+ Error string `json:"error,omitempty"`
+}
+
+// Cache state for DetectCandidates, guarded by detectCandidatesCacheMu.
+// detectCandidatesNow and detectCandidatesScan are function variables so tests
+// can replace the clock and the scan.
+var (
+ detectCandidatesCacheMu sync.RWMutex
+ detectCandidatesCache []Candidate
+ detectCandidatesCacheExpiresAt time.Time
+ detectCandidatesCacheLoaded bool
+ detectCandidatesCacheTTL = 5 * time.Second
+ detectCandidatesNow = time.Now
+ detectCandidatesScan = scanCandidates
+)
+
+// DetectCandidates returns the detection result for every known browser,
+// reusing a cached scan until detectCandidatesCacheTTL has elapsed.
+// The returned slice is always a copy, so callers may modify it freely.
+func DetectCandidates() []Candidate {
+ now := detectCandidatesNow()
+ // Fast path: serve a still-valid cache under the read lock.
+ detectCandidatesCacheMu.RLock()
+ if detectCandidatesCacheLoaded && now.Before(detectCandidatesCacheExpiresAt) {
+ cached := cloneCandidates(detectCandidatesCache)
+ detectCandidatesCacheMu.RUnlock()
+ return cached
+ }
+ detectCandidatesCacheMu.RUnlock()
+
+ detectCandidatesCacheMu.Lock()
+ defer detectCandidatesCacheMu.Unlock()
+
+ // Re-check under the write lock: another goroutine may have refreshed the
+ // cache between releasing the read lock and acquiring the write lock.
+ now = detectCandidatesNow()
+ if detectCandidatesCacheLoaded && now.Before(detectCandidatesCacheExpiresAt) {
+ return cloneCandidates(detectCandidatesCache)
+ }
+
+ result := detectCandidatesScan()
+ detectCandidatesCache = cloneCandidates(result)
+ detectCandidatesCacheExpiresAt = now.Add(detectCandidatesCacheTTL)
+ detectCandidatesCacheLoaded = true
+ return cloneCandidates(result)
+}
+
+// ChooseCandidate picks the browser to use from candidates.
+//
+// preferred is matched case-insensitively after trimming. The first available
+// candidate with that ID wins; otherwise the first available candidate of any
+// kind is returned. The boolean is false when nothing is available.
+func ChooseCandidate(candidates []Candidate, preferred string) (Candidate, bool) {
+	wanted := BrowserID(strings.ToLower(strings.TrimSpace(preferred)))
+	fallback := -1
+	for index, candidate := range candidates {
+		if !candidate.Available {
+			continue
+		}
+		if wanted != "" && candidate.ID == wanted {
+			return candidate, true
+		}
+		if fallback < 0 {
+			// Remember the first available browser in case the preferred one is absent.
+			fallback = index
+		}
+	}
+	if fallback >= 0 {
+		return candidates[fallback], true
+	}
+	return Candidate{}, false
+}
+
+// CheckCDPReady probes the DevTools HTTP endpoint at
+// http://host:port/json/version and returns nil when it answers 200 OK.
+//
+// host is trimmed and defaults to 127.0.0.1 when blank. A bare IPv6 literal
+// such as "::1" is wrapped in brackets so the URL authority stays valid. port
+// must be positive. The request is bound to ctx; transport errors and non-200
+// statuses are returned as errors.
+func CheckCDPReady(ctx context.Context, host string, port int) error {
+	if port <= 0 {
+		return fmt.Errorf("invalid cdp port")
+	}
+	// Use the trimmed host in the URL too: surrounding whitespace would
+	// otherwise make the URL unparsable.
+	hostPart := strings.TrimSpace(host)
+	if hostPart == "" {
+		hostPart = "127.0.0.1"
+	}
+	if strings.Contains(hostPart, ":") && !strings.HasPrefix(hostPart, "[") {
+		hostPart = "[" + hostPart + "]"
+	}
+	endpoint := fmt.Sprintf("http://%s:%d/json/version", hostPart, port)
+	req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil)
+	if err != nil {
+		return err
+	}
+	resp, err := http.DefaultClient.Do(req)
+	if err != nil {
+		return err
+	}
+	defer resp.Body.Close()
+	if resp.StatusCode != http.StatusOK {
+		return fmt.Errorf("unexpected cdp status %d", resp.StatusCode)
+	}
+	return nil
+}
+
+// detectCandidate probes the known install locations of one browser and
+// returns the first that resolves to an executable. When none resolves the
+// candidate is returned unavailable with an explanatory Error.
+func detectCandidate(id BrowserID) Candidate {
+	label := labelForID(id)
+	for _, location := range candidatesForID(id) {
+		if execPath := resolveExecutable(location); strings.TrimSpace(execPath) != "" {
+			return Candidate{ID: id, Label: label, ExecPath: execPath, Available: true}
+		}
+	}
+	return Candidate{ID: id, Label: label, Error: "browser executable not found"}
+}
+
+// scanCandidates runs detection for every supported browser in a fixed order:
+// Chrome, Chromium, Edge, Brave.
+func scanCandidates() []Candidate {
+	probeOrder := [...]BrowserID{BrowserChrome, BrowserChromium, BrowserEdge, BrowserBrave}
+	found := make([]Candidate, len(probeOrder))
+	for index, id := range probeOrder {
+		found[index] = detectCandidate(id)
+	}
+	return found
+}
+
+// cloneCandidates returns an independent copy of source. Empty or nil input
+// yields a non-nil empty slice.
+func cloneCandidates(source []Candidate) []Candidate {
+	if len(source) == 0 {
+		return []Candidate{}
+	}
+	return append(make([]Candidate, 0, len(source)), source...)
+}
+
+// resetDetectCandidatesCache drops the cached scan so the next
+// DetectCandidates call rescans.
+func resetDetectCandidatesCache() {
+	detectCandidatesCacheMu.Lock()
+	defer detectCandidatesCacheMu.Unlock()
+
+	detectCandidatesCacheLoaded = false
+	detectCandidatesCacheExpiresAt = time.Time{}
+	detectCandidatesCache = nil
+}
+
+// labelForID returns the display name for a browser ID. Unknown IDs fall back
+// to a title-cased form of the ID itself.
+func labelForID(id BrowserID) string {
+	knownLabels := map[BrowserID]string{
+		BrowserChrome:   "Chrome",
+		BrowserChromium: "Chromium",
+		BrowserEdge:     "Edge",
+		BrowserBrave:    "Brave",
+	}
+	if label, known := knownLabels[id]; known {
+		return label
+	}
+	return strings.Title(string(id))
+}
+
+// candidatesForID lists the locations to probe for the given browser on the
+// current operating system.
+//
+// On macOS these are absolute application bundle paths. On Windows they are
+// built under ProgramFiles, ProgramFiles(x86) and LOCALAPPDATA, in that order.
+// On other systems they are executable names for a PATH lookup. Unknown IDs
+// yield nil.
+func candidatesForID(id BrowserID) []string {
+	switch runtime.GOOS {
+	case "darwin":
+		switch id {
+		case BrowserChrome:
+			return []string{
+				"/Applications/Google Chrome.app/Contents/MacOS/Google Chrome",
+				"/Applications/Google Chrome Beta.app/Contents/MacOS/Google Chrome Beta",
+				"/Applications/Google Chrome Dev.app/Contents/MacOS/Google Chrome Dev",
+				"/Applications/Google Chrome Canary.app/Contents/MacOS/Google Chrome Canary",
+			}
+		case BrowserChromium:
+			return []string{
+				"/Applications/Chromium.app/Contents/MacOS/Chromium",
+			}
+		case BrowserEdge:
+			return []string{
+				"/Applications/Microsoft Edge.app/Contents/MacOS/Microsoft Edge",
+				"/Applications/Microsoft Edge Beta.app/Contents/MacOS/Microsoft Edge Beta",
+				"/Applications/Microsoft Edge Dev.app/Contents/MacOS/Microsoft Edge Dev",
+			}
+		case BrowserBrave:
+			return []string{
+				"/Applications/Brave Browser.app/Contents/MacOS/Brave Browser",
+				"/Applications/Brave Browser Beta.app/Contents/MacOS/Brave Browser Beta",
+			}
+		}
+	case "windows":
+		roots := []string{
+			strings.TrimSpace(os.Getenv("ProgramFiles")),
+			strings.TrimSpace(os.Getenv("ProgramFiles(x86)")),
+			strings.TrimSpace(os.Getenv("LOCALAPPDATA")),
+		}
+		// underRoots joins elems below every install root that is actually set.
+		// Unset roots must be skipped here: filepath.Join("", ...) yields a
+		// relative path that compact cannot filter out and that would be
+		// probed relative to the current working directory.
+		underRoots := func(elems ...string) []string {
+			paths := make([]string, 0, len(roots))
+			for _, root := range roots {
+				if root == "" {
+					continue
+				}
+				paths = append(paths, filepath.Join(append([]string{root}, elems...)...))
+			}
+			return compact(paths)
+		}
+		switch id {
+		case BrowserChrome:
+			return underRoots("Google", "Chrome", "Application", "chrome.exe")
+		case BrowserChromium:
+			return underRoots("Chromium", "Application", "chrome.exe")
+		case BrowserEdge:
+			return underRoots("Microsoft", "Edge", "Application", "msedge.exe")
+		case BrowserBrave:
+			return underRoots("BraveSoftware", "Brave-Browser", "Application", "brave.exe")
+		}
+	default:
+		switch id {
+		case BrowserChrome:
+			return []string{"google-chrome", "google-chrome-stable"}
+		case BrowserChromium:
+			return []string{"chromium-browser", "chromium"}
+		case BrowserEdge:
+			return []string{"microsoft-edge", "microsoft-edge-stable", "msedge"}
+		case BrowserBrave:
+			return []string{"brave-browser", "brave-browser-stable", "brave"}
+		}
+	}
+	return nil
+}
+
+// resolveExecutable turns a probe location into a usable executable path.
+//
+// Absolute paths are accepted when they exist and are not directories.
+// Anything else is looked up with exec.LookPath. It returns "" when the
+// candidate is blank or cannot be resolved.
+func resolveExecutable(candidate string) string {
+	name := strings.TrimSpace(candidate)
+	switch {
+	case name == "":
+		return ""
+	case !filepath.IsAbs(name):
+		located, lookErr := exec.LookPath(name)
+		if lookErr != nil {
+			return ""
+		}
+		return located
+	}
+	info, statErr := os.Stat(name)
+	if statErr != nil || info.IsDir() {
+		return ""
+	}
+	return name
+}
+
+// compact trims every value and drops the ones that end up empty, preserving
+// order. The result is never nil.
+func compact(values []string) []string {
+	kept := make([]string, 0, len(values))
+	for index := range values {
+		if trimmed := strings.TrimSpace(values[index]); trimmed != "" {
+			kept = append(kept, trimmed)
+		}
+	}
+	return kept
+}
+
+// WaitForCDP polls CheckCDPReady until the endpoint answers, ctx is done, or
+// timeout elapses.
+//
+// A nil ctx is replaced by context.Background and a non-positive timeout by
+// 8 seconds. On timeout the last probe error is returned, unless ctx has
+// ended, in which case the context error is returned.
+func WaitForCDP(ctx context.Context, host string, port int, timeout time.Duration) error {
+ if ctx == nil {
+ ctx = context.Background()
+ }
+ if timeout <= 0 {
+ timeout = 8 * time.Second
+ }
+ deadline := time.Now().Add(timeout)
+ for {
+ if err := ctx.Err(); err != nil {
+ return err
+ }
+ // Each probe gets its own short timeout so one hung request cannot use up
+ // the whole wait.
+ checkCtx, cancel := context.WithTimeout(ctx, 1200*time.Millisecond)
+ err := CheckCDPReady(checkCtx, host, port)
+ cancel()
+ if err == nil {
+ return nil
+ }
+ // The deadline is checked only after a probe, so at least one probe always
+ // runs even with a very small timeout.
+ if time.Now().After(deadline) {
+ if ctxErr := ctx.Err(); ctxErr != nil {
+ return ctxErr
+ }
+ return err
+ }
+ // Pause between probes, waking early if ctx is cancelled.
+ select {
+ case <-ctx.Done():
+ return ctx.Err()
+ case <-time.After(250 * time.Millisecond):
+ }
+ }
+}
diff --git a/internal/application/browsercdp/detect_test.go b/internal/application/browsercdp/detect_test.go
new file mode 100644
index 0000000..2506b76
--- /dev/null
+++ b/internal/application/browsercdp/detect_test.go
@@ -0,0 +1,71 @@
+package browsercdp
+
+import (
+ "context"
+ "errors"
+ "testing"
+ "time"
+)
+
+// TestDetectCandidatesCachesScanUntilTTLExpires verifies that DetectCandidates
+// scans once, serves cloned results from its cache while the TTL is valid, and
+// rescans after the TTL has elapsed.
+//
+// The test deliberately does NOT call t.Parallel(): it replaces the
+// package-level hooks (detectCandidatesNow, detectCandidatesScan,
+// detectCandidatesCacheTTL) and resets the shared cache, so running it
+// alongside other parallel tests in this package would race on that state.
+func TestDetectCandidatesCachesScanUntilTTLExpires(t *testing.T) {
+	originalNow := detectCandidatesNow
+	originalScan := detectCandidatesScan
+	originalTTL := detectCandidatesCacheTTL
+	current := time.Unix(1_700_000_000, 0)
+	scanCalls := 0
+
+	// Fake clock and scanner so the cache behaviour is fully deterministic.
+	detectCandidatesNow = func() time.Time { return current }
+	detectCandidatesScan = func() []Candidate {
+		scanCalls++
+		return []Candidate{
+			{
+				ID:        BrowserChrome,
+				Label:     "Chrome",
+				ExecPath:  "/tmp/chrome",
+				Available: true,
+			},
+		}
+	}
+	detectCandidatesCacheTTL = time.Minute
+	resetDetectCandidatesCache()
+	t.Cleanup(func() {
+		detectCandidatesNow = originalNow
+		detectCandidatesScan = originalScan
+		detectCandidatesCacheTTL = originalTTL
+		resetDetectCandidatesCache()
+	})
+
+	first := DetectCandidates()
+	if scanCalls != 1 {
+		t.Fatalf("expected first detect to scan once, got %d", scanCalls)
+	}
+	// Mutating the returned slice must not leak into the cached copy.
+	first[0].ExecPath = "/tmp/changed"
+
+	second := DetectCandidates()
+	if scanCalls != 1 {
+		t.Fatalf("expected second detect to use cache, got %d scans", scanCalls)
+	}
+	if got := second[0].ExecPath; got != "/tmp/chrome" {
+		t.Fatalf("expected cached detect result to be cloned, got %q", got)
+	}
+
+	// Advance the fake clock past the TTL to force a rescan.
+	current = current.Add(time.Minute + time.Second)
+	_ = DetectCandidates()
+	if scanCalls != 2 {
+		t.Fatalf("expected detect to rescan after ttl, got %d scans", scanCalls)
+	}
+}
+
+// TestWaitForCDPHonorsCancelledContext checks that WaitForCDP returns
+// context.Canceled when it is handed an already-cancelled context.
+func TestWaitForCDPHonorsCancelledContext(t *testing.T) {
+	t.Parallel()
+
+	cancelled, cancel := context.WithCancel(context.Background())
+	cancel()
+
+	if got := WaitForCDP(cancelled, "127.0.0.1", 1, time.Second); !errors.Is(got, context.Canceled) {
+		t.Fatalf("expected context canceled, got %v", got)
+	}
+}
diff --git a/internal/application/browsercdp/runtime.go b/internal/application/browsercdp/runtime.go
new file mode 100644
index 0000000..9ebd632
--- /dev/null
+++ b/internal/application/browsercdp/runtime.go
@@ -0,0 +1,352 @@
+package browsercdp
+
+import (
+ "context"
+ "encoding/json"
+ "fmt"
+ "io"
+ "net"
+ "net/http"
+ "os"
+ "os/exec"
+ "path/filepath"
+ "strings"
+ "sync"
+ "syscall"
+ "time"
+
+ "github.com/chromedp/chromedp"
+ "go.uber.org/zap"
+)
+
+// LaunchOptions configures how Start launches the browser process.
+type LaunchOptions struct {
+ // PreferredBrowser is handed to ChooseCandidate to pick among detected browsers.
+ PreferredBrowser string
+ // Headless adds --headless=new, --hide-scrollbars and --mute-audio; when
+ // false Start passes --no-startup-window instead.
+ Headless bool
+ // NoSandbox adds --no-sandbox to the command line.
+ NoSandbox bool
+ // ExtraArgs are appended after the built-in flags; blank entries are skipped.
+ ExtraArgs []string
+ // UserDataDir is the browser profile directory; when blank Start uses a
+ // per-browser directory under os.TempDir().
+ UserDataDir string
+}
+
+// Status is the JSON-serialisable description of browser detection and, for a
+// running Runtime, of the launched browser.
+type Status struct {
+ // Ready is the chosen candidate's availability in ResolveStatus; Start sets
+ // it to true and it is cleared when the process exits or Stop is called.
+ Ready bool `json:"ready"`
+ // Candidates is the list returned by DetectCandidates.
+ Candidates []Candidate `json:"candidates,omitempty"`
+ // SelectedBrowser is the trimmed caller preference.
+ SelectedBrowser string `json:"selectedBrowser,omitempty"`
+ // ChosenBrowser is the ID of the candidate picked by ChooseCandidate.
+ ChosenBrowser string `json:"chosenBrowser,omitempty"`
+ DetectedExecutablePath string `json:"detectedExecutablePath,omitempty"`
+ // DetectError carries the detection failure, or "browser process exited"
+ // when a running browser ends without Stop having been called.
+ DetectError string `json:"detectError,omitempty"`
+ // CDPURL is the websocket debugger URL resolved by Start.
+ CDPURL string `json:"cdpUrl,omitempty"`
+ CDPPort int `json:"cdpPort,omitempty"`
+ Headless bool `json:"headless"`
+}
+
+// Runtime owns one launched browser process together with the chromedp
+// contexts attached to it. The accessor methods and Stop take mu before
+// touching the fields.
+type Runtime struct {
+ mu sync.Mutex
+
+ options LaunchOptions
+ candidate Candidate
+ status Status
+
+ // cmd is the browser process started by Start.
+ cmd *exec.Cmd
+ userDataDir string
+ // allocCtx/allocCancel come from chromedp.NewRemoteAllocator and
+ // browserCtx/browserCancel from chromedp.NewContext on top of it.
+ allocCtx context.Context
+ allocCancel context.CancelFunc
+ browserCtx context.Context
+ browserCancel context.CancelFunc
+ // stopping is set by Stop so the exit watcher does not report an
+ // unexpected exit.
+ stopping bool
+ // stopped is closed by the exit watcher once cmd.Wait returns.
+ stopped chan struct{}
+}
+
+// versionResponse is the subset of the /json/version payload that
+// fetchWebSocketURL decodes.
+type versionResponse struct {
+ WebSocketDebuggerURL string `json:"webSocketDebuggerUrl"`
+}
+
+// ResolveStatus reports which browser would be used for the given preference
+// without launching anything.
+//
+// The returned Status always carries the detected candidates, the trimmed
+// preference and the headless flag. When no candidate can be chosen
+// DetectError is "no supported browser detected"; when the chosen candidate is
+// unavailable DetectError is that candidate's own error.
+func ResolveStatus(preferred string, headless bool) Status {
+	detected := DetectCandidates()
+	result := Status{
+		Candidates:      detected,
+		SelectedBrowser: strings.TrimSpace(preferred),
+		Headless:        headless,
+	}
+
+	chosen, found := ChooseCandidate(detected, preferred)
+	if !found {
+		result.DetectError = "no supported browser detected"
+		return result
+	}
+
+	result.ChosenBrowser = string(chosen.ID)
+	result.DetectedExecutablePath = chosen.ExecPath
+	if result.Ready = chosen.Available; !result.Ready {
+		result.DetectError = chosen.Error
+	}
+	return result
+}
+
+// Start launches the chosen browser with remote debugging enabled, waits for
+// its CDP endpoint, attaches chromedp to it and returns the resulting Runtime.
+//
+// ctx bounds the CDP readiness wait and the websocket URL lookup; a nil ctx is
+// treated as context.Background(). The chromedp contexts are deliberately
+// rooted in context.Background() so they outlive ctx and are only torn down by
+// Stop. On any failure after the process has been started the browser is
+// killed and reaped before the error is returned.
+func Start(ctx context.Context, options LaunchOptions) (*Runtime, error) {
+	// Mirror WaitForCDP's nil handling; the HTTP request built by
+	// fetchWebSocketURL rejects a nil context.
+	if ctx == nil {
+		ctx = context.Background()
+	}
+	candidates := DetectCandidates()
+	candidate, ok := ChooseCandidate(candidates, options.PreferredBrowser)
+	if !ok {
+		return nil, fmt.Errorf("no supported browser detected")
+	}
+
+	port, err := reservePort()
+	if err != nil {
+		return nil, err
+	}
+	userDataDir := strings.TrimSpace(options.UserDataDir)
+	if userDataDir == "" {
+		userDataDir = filepath.Join(os.TempDir(), "dreamcreator", "browsercdp", string(candidate.ID))
+	}
+	if err := os.MkdirAll(userDataDir, 0o755); err != nil {
+		return nil, err
+	}
+
+	args := []string{
+		fmt.Sprintf("--remote-debugging-port=%d", port),
+		fmt.Sprintf("--user-data-dir=%s", userDataDir),
+		"--no-first-run",
+		"--no-default-browser-check",
+		"--disable-background-networking",
+		"--disable-background-timer-throttling",
+		"--disable-backgrounding-occluded-windows",
+		"--disable-breakpad",
+		"--disable-client-side-phishing-detection",
+		"--disable-default-apps",
+		"--disable-features=Translate,OptimizationHints,MediaRouter,AutomationControlled",
+		"--disable-hang-monitor",
+		"--disable-popup-blocking",
+		"--disable-prompt-on-repost",
+		"--disable-sync",
+		"--metrics-recording-only",
+		"--password-store=basic",
+		"--use-mock-keychain",
+	}
+	if options.Headless {
+		args = append([]string{"--headless=new", "--hide-scrollbars", "--mute-audio"}, args...)
+	} else {
+		args = append([]string{"--no-startup-window"}, args...)
+	}
+	if options.NoSandbox {
+		args = append([]string{"--no-sandbox"}, args...)
+	}
+	for _, extra := range options.ExtraArgs {
+		if trimmed := strings.TrimSpace(extra); trimmed != "" {
+			args = append(args, trimmed)
+		}
+	}
+
+	cmd := exec.Command(candidate.ExecPath, args...)
+	cmd.Stdout = io.Discard
+	cmd.Stderr = io.Discard
+	cmd.SysProcAttr = &syscall.SysProcAttr{}
+	zap.L().Info(
+		"browser runtime launch started",
+		zap.String("preferredBrowser", strings.TrimSpace(options.PreferredBrowser)),
+		zap.String("chosenBrowser", string(candidate.ID)),
+		zap.String("execPath", candidate.ExecPath),
+		zap.Bool("headless", options.Headless),
+		zap.String("userDataDir", userDataDir),
+		zap.Int("cdpPort", port),
+	)
+	if err := cmd.Start(); err != nil {
+		zap.L().Warn(
+			"browser runtime launch failed",
+			zap.String("preferredBrowser", strings.TrimSpace(options.PreferredBrowser)),
+			zap.String("chosenBrowser", string(candidate.ID)),
+			zap.String("execPath", candidate.ExecPath),
+			zap.Error(err),
+		)
+		return nil, err
+	}
+
+	// abandon kills the launched browser on a failed startup and reaps it so
+	// no zombie process or output-copy goroutine is left behind. Reaping runs
+	// in the background because Wait also waits for the stdout/stderr pipes,
+	// which browser helper processes may keep open for a while.
+	abandon := func() {
+		_ = cmd.Process.Kill()
+		go func() { _ = cmd.Wait() }()
+	}
+
+	if err := WaitForCDP(ctx, "127.0.0.1", port, 10*time.Second); err != nil {
+		zap.L().Warn(
+			"browser runtime cdp wait failed",
+			zap.String("preferredBrowser", strings.TrimSpace(options.PreferredBrowser)),
+			zap.String("chosenBrowser", string(candidate.ID)),
+			zap.String("execPath", candidate.ExecPath),
+			zap.Int("cdpPort", port),
+			zap.Error(err),
+		)
+		abandon()
+		return nil, err
+	}
+	zap.L().Info(
+		"browser runtime cdp ready",
+		zap.String("preferredBrowser", strings.TrimSpace(options.PreferredBrowser)),
+		zap.String("chosenBrowser", string(candidate.ID)),
+		zap.String("execPath", candidate.ExecPath),
+		zap.Int("cdpPort", port),
+	)
+	wsURL, err := fetchWebSocketURL(ctx, port)
+	if err != nil {
+		zap.L().Warn(
+			"browser runtime websocket resolve failed",
+			zap.String("preferredBrowser", strings.TrimSpace(options.PreferredBrowser)),
+			zap.String("chosenBrowser", string(candidate.ID)),
+			zap.String("execPath", candidate.ExecPath),
+			zap.Int("cdpPort", port),
+			zap.Error(err),
+		)
+		abandon()
+		return nil, err
+	}
+	// Rooted in Background so the attachment survives the caller's ctx.
+	allocCtx, allocCancel := chromedp.NewRemoteAllocator(context.Background(), wsURL)
+	browserCtx, browserCancel := chromedp.NewContext(allocCtx)
+	// Listing targets forces chromedp to actually connect to the browser.
+	if _, err := chromedp.Targets(browserCtx); err != nil {
+		zap.L().Warn(
+			"browser runtime chromedp attach failed",
+			zap.String("preferredBrowser", strings.TrimSpace(options.PreferredBrowser)),
+			zap.String("chosenBrowser", string(candidate.ID)),
+			zap.String("execPath", candidate.ExecPath),
+			zap.String("cdpUrl", wsURL),
+			zap.Error(err),
+		)
+		browserCancel()
+		allocCancel()
+		abandon()
+		return nil, err
+	}
+	zap.L().Info(
+		"browser runtime ready",
+		zap.String("preferredBrowser", strings.TrimSpace(options.PreferredBrowser)),
+		zap.String("chosenBrowser", string(candidate.ID)),
+		zap.String("execPath", candidate.ExecPath),
+		zap.String("cdpUrl", wsURL),
+		zap.Int("cdpPort", port),
+	)
+
+	runtime := &Runtime{
+		options:       options,
+		candidate:     candidate,
+		cmd:           cmd,
+		userDataDir:   userDataDir,
+		allocCtx:      allocCtx,
+		allocCancel:   allocCancel,
+		browserCtx:    browserCtx,
+		browserCancel: browserCancel,
+		stopped:       make(chan struct{}),
+		status: Status{
+			Ready:                  true,
+			Candidates:             candidates,
+			SelectedBrowser:        strings.TrimSpace(options.PreferredBrowser),
+			ChosenBrowser:          string(candidate.ID),
+			DetectedExecutablePath: candidate.ExecPath,
+			CDPURL:                 wsURL,
+			CDPPort:                port,
+			Headless:               options.Headless,
+		},
+	}
+
+	// Exit watcher: reaps the process, marks the runtime not ready and
+	// signals Stop (via stopped) that the process is gone.
+	go func() {
+		_ = cmd.Wait()
+		runtime.mu.Lock()
+		runtime.status.Ready = false
+		stopping := runtime.stopping
+		if !stopping && runtime.status.DetectError == "" {
+			runtime.status.DetectError = "browser process exited"
+		}
+		runtime.mu.Unlock()
+		if !stopping {
+			zap.L().Warn(
+				"browser runtime exited unexpectedly",
+				zap.String("preferredBrowser", strings.TrimSpace(options.PreferredBrowser)),
+				zap.String("chosenBrowser", string(candidate.ID)),
+				zap.String("execPath", candidate.ExecPath),
+				zap.Int("cdpPort", port),
+			)
+		}
+		close(runtime.stopped)
+	}()
+
+	return runtime, nil
+}
+
+// BrowserContext returns the chromedp browser context, read under the
+// runtime mutex.
+func (runtime *Runtime) BrowserContext() context.Context {
+	runtime.mu.Lock()
+	current := runtime.browserCtx
+	runtime.mu.Unlock()
+	return current
+}
+
+// UserDataDir returns the profile directory the browser was launched with,
+// read under the runtime mutex.
+func (runtime *Runtime) UserDataDir() string {
+	runtime.mu.Lock()
+	dir := runtime.userDataDir
+	runtime.mu.Unlock()
+	return dir
+}
+
+// Candidate returns the browser candidate this runtime was started from,
+// read under the runtime mutex.
+func (runtime *Runtime) Candidate() Candidate {
+	runtime.mu.Lock()
+	chosen := runtime.candidate
+	runtime.mu.Unlock()
+	return chosen
+}
+
+// Status returns a copy of the runtime's current status struct, read under
+// the runtime mutex.
+func (runtime *Runtime) Status() Status {
+	runtime.mu.Lock()
+	snapshot := runtime.status
+	runtime.mu.Unlock()
+	return snapshot
+}
+
+// Stop tears the runtime down: it marks it as stopping and not ready, cancels
+// the chromedp contexts, kills the browser process and then waits up to two
+// seconds for the exit watcher to signal that the process is gone.
+// Calling Stop on a nil runtime is a no-op.
+func (runtime *Runtime) Stop() {
+	if runtime == nil {
+		return
+	}
+
+	// Snapshot everything under the lock; the teardown itself runs unlocked.
+	runtime.mu.Lock()
+	runtime.stopping = true
+	runtime.status.Ready = false
+	process := runtime.cmd
+	cancelBrowser, cancelAlloc := runtime.browserCancel, runtime.allocCancel
+	exited := runtime.stopped
+	runtime.mu.Unlock()
+
+	if cancelBrowser != nil {
+		cancelBrowser()
+	}
+	if cancelAlloc != nil {
+		cancelAlloc()
+	}
+	if process != nil && process.Process != nil {
+		_ = process.Process.Kill()
+	}
+	if exited == nil {
+		return
+	}
+	select {
+	case <-exited:
+	case <-time.After(2 * time.Second):
+	}
+}
+
+// fetchWebSocketURL asks the local CDP endpoint on the given port for
+// /json/version and returns its trimmed webSocketDebuggerUrl.
+//
+// The request is bounded to five seconds in addition to any deadline on ctx,
+// so a browser that accepts the connection but never answers cannot block the
+// caller forever. A nil ctx is treated as context.Background(). An error is
+// returned for transport failures, non-200 responses, undecodable bodies and
+// a missing URL.
+func fetchWebSocketURL(ctx context.Context, port int) (string, error) {
+	if ctx == nil {
+		ctx = context.Background()
+	}
+	requestCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
+	defer cancel()
+
+	endpoint := fmt.Sprintf("http://127.0.0.1:%d/json/version", port)
+	req, err := http.NewRequestWithContext(requestCtx, http.MethodGet, endpoint, nil)
+	if err != nil {
+		return "", err
+	}
+	resp, err := http.DefaultClient.Do(req)
+	if err != nil {
+		return "", err
+	}
+	defer resp.Body.Close()
+	if resp.StatusCode != http.StatusOK {
+		return "", fmt.Errorf("unexpected cdp status %d", resp.StatusCode)
+	}
+	var payload versionResponse
+	if err := json.NewDecoder(resp.Body).Decode(&payload); err != nil {
+		return "", err
+	}
+	wsURL := strings.TrimSpace(payload.WebSocketDebuggerURL)
+	if wsURL == "" {
+		return "", fmt.Errorf("webSocketDebuggerUrl missing")
+	}
+	return wsURL, nil
+}
+
+// reservePort asks the OS for a free loopback TCP port by listening on
+// 127.0.0.1:0 and returns the assigned port number. The listener is closed
+// before returning, so the port is only free at the moment of the call.
+func reservePort() (int, error) {
+	probe, err := net.Listen("tcp", "127.0.0.1:0")
+	if err != nil {
+		return 0, err
+	}
+	defer probe.Close()
+
+	if tcpAddr, isTCP := probe.Addr().(*net.TCPAddr); isTCP {
+		return tcpAddr.Port, nil
+	}
+	return 0, fmt.Errorf("failed to reserve tcp port")
+}
diff --git a/internal/application/browsercdp/session.go b/internal/application/browsercdp/session.go
new file mode 100644
index 0000000..c9eff3c
--- /dev/null
+++ b/internal/application/browsercdp/session.go
@@ -0,0 +1,3300 @@
+package browsercdp
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "net"
+ "net/url"
+ "path/filepath"
+ "sort"
+ "strings"
+ "sync"
+ "time"
+
+ "github.com/chromedp/cdproto/cdp"
+ "github.com/chromedp/cdproto/emulation"
+ "github.com/chromedp/cdproto/fetch"
+ "github.com/chromedp/cdproto/network"
+ pagepkg "github.com/chromedp/cdproto/page"
+ targetpkg "github.com/chromedp/cdproto/target"
+ "github.com/chromedp/chromedp"
+ "go.uber.org/zap"
+
+ appcookies "dreamcreator/internal/application/cookies"
+)
+
+// Package-level defaults. Their consumers are outside this part of the file;
+// NOTE(review): presumably the fallback snapshot item limit and the DNS
+// lookup budget for SSRF checks — confirm against the call sites.
+const (
+ defaultSnapshotLimit = 200
+ defaultSSRFValidationTimeout = 2 * time.Second
+)
+
+// lookupIPAddrsForHost resolves host through net.DefaultResolver. It is a
+// variable rather than a function, presumably so tests can stub DNS — confirm.
+var lookupIPAddrsForHost = func(ctx context.Context, host string) ([]net.IPAddr, error) {
+ return net.DefaultResolver.LookupIPAddr(ctx, host)
+}
+
+// SSRFPolicy carries the URL/host restrictions attached to a session through
+// SessionOptions.SSRFRules. The enforcement code is not in this part of the
+// file; NOTE(review): field semantics below are inferred from names — confirm.
+type SSRFPolicy struct {
+ // Presumably disables private-network blocking when true.
+ DangerouslyAllowPrivateNetwork bool
+ // Presumably a set of exact hostnames that are always permitted.
+ AllowedHostnames map[string]struct{}
+ // Presumably hostname patterns that are permitted.
+ HostnameAllowlist []string
+}
+
+// SessionOptions configures a Session. SessionRegistry.GetOrCreate passes the
+// options through normalizeSessionOptions before storing them.
+type SessionOptions struct {
+ // SessionKey and ProfileName are set by GetOrCreate to the registry keys
+ // when a session is first created.
+ SessionKey string
+ ProfileName string
+ PreferredBrowser string
+ Headless bool
+ UserDataDir string
+ SSRFRules SSRFPolicy
+ // Cookies supplies connector cookies; its use is outside this part of the file.
+ Cookies ConnectorCookieProvider
+}
+
+// SessionRegistry holds browser sessions indexed first by session key and
+// then by profile name. mu guards the sessions map.
+type SessionRegistry struct {
+ mu sync.Mutex
+ sessions map[string]map[string]*Session
+}
+
+// Session is one browser session for a (session key, profile) pair. It is
+// created by SessionRegistry.GetOrCreate with empty tabs, pendingDialogs and
+// cookieSync maps.
+type Session struct {
+ mu sync.Mutex
+
+ options SessionOptions
+ // runtime is the browser Runtime backing the session; it is managed by
+ // code outside this part of the file.
+ runtime *Runtime
+
+ // tabs holds the session's tabs; NOTE(review): presumably keyed by target
+ // ID — confirm.
+ tabs map[string]*sessionTab
+ activeTarget string
+
+ // NOTE(review): key meaning of the two maps below is not visible here.
+ pendingDialogs map[string]PendingDialog
+ cookieSync map[string]string
+}
+
+// sessionTab is the per-tab state tracked by a Session.
+type sessionTab struct {
+ TargetID string
+ ctx context.Context
+ cancel context.CancelFunc
+
+ // mu guards blockedRequestErr (see the *BlockedRequestError helpers);
+ // NOTE(review): presumably it guards the other fields below too — confirm.
+ mu sync.RWMutex
+ cleanupCancels []context.CancelFunc
+ refs map[string]snapshotRef
+ evaluateResult any
+ lastURL string
+ title string
+ lastState *PageState
+ stateVersion uint64
+ nextRefID uint64
+ // blockedRequestErr keeps the first recorded blocked-request message until
+ // it is cleared or consumed.
+ blockedRequestErr string
+ fetchEnabled bool
+}
+
+// newTabWaiter bundles the pieces needed to wait for a newly opened tab: a
+// context, a channel of target IDs and the hooks that release the waiter.
+type newTabWaiter struct {
+	ctx    context.Context
+	ids    <-chan targetpkg.ID
+	cancel context.CancelFunc
+	stop   func() bool
+}
+
+// close releases the waiter by invoking stop and then cancel, skipping
+// whichever hook is unset. It is safe to call on a nil waiter.
+func (waiter *newTabWaiter) close() {
+	if waiter == nil {
+		return
+	}
+	if stop := waiter.stop; stop != nil {
+		stop()
+	}
+	if cancel := waiter.cancel; cancel != nil {
+		cancel()
+	}
+}
+
+// snapshotRef describes how a snapshot ref maps back to a page element. It is
+// stored per ref string in sessionTab.refs and snapshotCapture.Refs.
+// NOTE(review): Nth is presumably the index among elements sharing the same
+// role and name — confirm against the snapshot code.
+type snapshotRef struct {
+ Selector string
+ Role string
+ Name string
+ Nth int
+}
+
+// clearBlockedRequestError discards any blocked-request message recorded on
+// tab. A nil tab is ignored.
+func clearBlockedRequestError(tab *sessionTab) {
+	if tab != nil {
+		tab.mu.Lock()
+		defer tab.mu.Unlock()
+		tab.blockedRequestErr = ""
+	}
+}
+
+// setBlockedRequestError records err's trimmed message on tab unless a message
+// is already stored, so the first blocked request wins. Nil tabs, nil errors
+// and blank messages are ignored.
+func setBlockedRequestError(tab *sessionTab, err error) {
+	if tab == nil || err == nil {
+		return
+	}
+	reason := strings.TrimSpace(err.Error())
+	if reason == "" {
+		return
+	}
+	tab.mu.Lock()
+	defer tab.mu.Unlock()
+	if tab.blockedRequestErr != "" {
+		return
+	}
+	tab.blockedRequestErr = reason
+}
+
+// peekBlockedRequestError returns the blocked-request message recorded on tab
+// as an error without clearing it, or nil when the tab is nil or nothing is
+// recorded.
+func peekBlockedRequestError(tab *sessionTab) error {
+	if tab == nil {
+		return nil
+	}
+	tab.mu.RLock()
+	recorded := tab.blockedRequestErr
+	tab.mu.RUnlock()
+	if recorded = strings.TrimSpace(recorded); recorded != "" {
+		return errors.New(recorded)
+	}
+	return nil
+}
+
+// consumeBlockedRequestError returns the blocked-request message recorded on
+// tab as an error and clears it in the same critical section. It returns nil
+// when the tab is nil or nothing is recorded.
+func consumeBlockedRequestError(tab *sessionTab) error {
+	if tab == nil {
+		return nil
+	}
+	tab.mu.Lock()
+	recorded := tab.blockedRequestErr
+	tab.blockedRequestErr = ""
+	tab.mu.Unlock()
+	if recorded = strings.TrimSpace(recorded); recorded != "" {
+		return errors.New(recorded)
+	}
+	return nil
+}
+
+// snapshotCapture is the internal result of capturing a page snapshot: the
+// page identity, the captured items and the ref table that goes with them.
+// The producing code is outside this part of the file.
+type snapshotCapture struct {
+ Version uint64
+ URL string
+ Title string
+ Items []SnapshotItem
+ Refs map[string]snapshotRef
+ Truncated bool
+ ViewportOnly bool
+}
+
+// SnapshotItem is one entry of a page snapshot as serialised to JSON. Every
+// field is omitted from the output when it holds its zero value.
+type SnapshotItem struct {
+ Ref string `json:"ref,omitempty"`
+ Role string `json:"role,omitempty"`
+ Name string `json:"name,omitempty"`
+ Text string `json:"text,omitempty"`
+ Depth int `json:"depth,omitempty"`
+ Nth int `json:"nth,omitempty"`
+}
+
+// PageState is the JSON form of a captured page snapshot, returned to callers
+// through ActionResult.State.
+type PageState struct {
+ Version uint64 `json:"version"`
+ URL string `json:"url"`
+ Title string `json:"title,omitempty"`
+ Items []SnapshotItem `json:"items"`
+ ItemCount int `json:"itemCount"`
+ Truncated bool `json:"truncated"`
+ ViewportOnly bool `json:"viewportOnly"`
+ // CapturedAt is a timestamp string; NOTE(review): format not visible here.
+ CapturedAt string `json:"capturedAt"`
+}
+
+// PendingDialog describes a dialog tracked by a session (see
+// Session.pendingDialogs) and is reported through ActionResult.Pending.
+type PendingDialog struct {
+ Message string `json:"message,omitempty"`
+ Type string `json:"type,omitempty"`
+ ExpiresAt time.Time `json:"expiresAt"`
+}
+
+// ScrollDelta is the horizontal (X) and vertical (Y) scroll amount reported
+// through ActionResult.Scroll.
+type ScrollDelta struct {
+ X int `json:"x"`
+ Y int `json:"y"`
+}
+
+// ActionResult is the JSON result shared by all session commands. Most fields
+// are optional and omitted when empty; only OK and StateAvailable are always
+// serialised.
+type ActionResult struct {
+ OK bool `json:"ok"`
+ TargetID string `json:"targetId,omitempty"`
+ URL string `json:"url,omitempty"`
+ Title string `json:"title,omitempty"`
+ StateVersion uint64 `json:"stateVersion,omitempty"`
+ State *PageState `json:"state,omitempty"`
+ Items []SnapshotItem `json:"items,omitempty"`
+ Action string `json:"action,omitempty"`
+ // OpenedNewTab is set by Open, and by Navigate when it falls back to Open.
+ OpenedNewTab bool `json:"openedNewTab,omitempty"`
+ PreviousTargetID string `json:"previousTargetId,omitempty"`
+ PreviousURL string `json:"previousURL,omitempty"`
+ Navigated bool `json:"navigated,omitempty"`
+ // Waited is set by Wait.
+ Waited bool `json:"waited,omitempty"`
+ Scroll *ScrollDelta `json:"scroll,omitempty"`
+ Paths []string `json:"paths,omitempty"`
+ Result any `json:"result,omitempty"`
+ Pending *PendingDialog `json:"pending,omitempty"`
+ StateAvailable bool `json:"stateAvailable"`
+ StateError string `json:"stateError,omitempty"`
+ // Reset, Restarted and Ready are set by Reset; Restarted and Ready mirror
+ // its restart argument.
+ Reset bool `json:"reset,omitempty"`
+ Restarted bool `json:"restarted,omitempty"`
+ Ready bool `json:"ready,omitempty"`
+ Closed bool `json:"closed,omitempty"`
+}
+
+// CommandOptions carries the common knobs for Open, Navigate and Wait.
+type CommandOptions struct {
+ // Limit is forwarded to state capture (collectActionResult).
+ Limit int
+ // Timeout bounds navigation and is the fallback for WaitFor.Timeout.
+ Timeout time.Duration
+ // WaitFor, when non-nil, is an extra wait performed after navigation.
+ WaitFor *WaitRequest
+}
+
+// WaitRequest describes the condition passed to waitOnTab. How the individual
+// fields are evaluated is outside this part of the file; NOTE(review): each
+// non-empty field is presumably one condition to wait for — confirm.
+type WaitRequest struct {
+ Time time.Duration
+ Selector string
+ Text string
+ TextGone string
+ URL string
+ Fn string
+ // Timeout overrides the command-level timeout when set (see Open, Navigate
+ // and Wait, which pass it through normalizeTimeout).
+ Timeout time.Duration
+}
+
+// ScrollRequest is the input of Session.Scroll. TargetID selects the tab;
+// NOTE(review): Ref presumably names a snapshot ref to scroll — confirm.
+type ScrollRequest struct {
+ TargetID string
+ Ref string
+ DeltaX int
+ DeltaY int
+ Limit int
+ Timeout time.Duration
+}
+
+// UploadRequest is the input of a file-upload command whose handler is
+// outside this part of the file.
+// NOTE(review): Paths are presumably local files to attach to the element
+// identified by Ref — confirm.
+type UploadRequest struct {
+ TargetID string
+ Ref string
+ Paths []string
+ Limit int
+ Timeout time.Duration
+}
+
+// DialogRequest is the input of a dialog-handling command whose handler is
+// outside this part of the file.
+type DialogRequest struct {
+ TargetID string
+ // Accept is a pointer so that "not specified" can be told apart from false.
+ // NOTE(review): how nil is treated is not visible here.
+ Accept *bool
+ PromptText string
+ Limit int
+ Timeout time.Duration
+}
+
+// ActRequest is the input of a generic page action whose handler is outside
+// this part of the file. Kind selects the action.
+// NOTE(review): which of the remaining fields apply to which Kind is not
+// visible here.
+type ActRequest struct {
+ Kind string
+ TargetID string
+ Ref string
+ Text string
+ Key string
+ Value string
+ Expression string
+ Width int
+ Height int
+ Wait WaitRequest
+ WaitFor *WaitRequest
+ Limit int
+ Timeout time.Duration
+}
+
+// FatalError wraps an error that IsFatalError always classifies as fatal.
+type FatalError struct {
+	Err error
+}
+
+// Error returns the wrapped error's message, or "browser runtime unavailable"
+// when the receiver or the wrapped error is nil.
+func (err *FatalError) Error() string {
+	if err != nil && err.Err != nil {
+		return err.Err.Error()
+	}
+	return "browser runtime unavailable"
+}
+
+// Unwrap exposes the wrapped error to errors.Is / errors.As; it returns nil
+// for a nil receiver.
+func (err *FatalError) Unwrap() error {
+	if err != nil {
+		return err.Err
+	}
+	return nil
+}
+
+// InvalidRefError reports that a snapshot ref could not be found.
+type InvalidRefError struct {
+	Ref string
+}
+
+// Error names the trimmed ref when one is set and always tells the caller to
+// take a fresh snapshot.
+func (err *InvalidRefError) Error() string {
+	ref := ""
+	if err != nil {
+		ref = strings.TrimSpace(err.Ref)
+	}
+	if ref == "" {
+		return "ref not found; run snapshot again to get fresh refs"
+	}
+	return fmt.Sprintf("ref %q not found; run snapshot again to get fresh refs", ref)
+}
+
+// ConnectorCookieError reports a failed connector cookie sync for a URL.
+type ConnectorCookieError struct {
+	URL string
+	Err error
+}
+
+// Error includes the trimmed URL and the cause when a cause is present, and a
+// generic message otherwise.
+func (err *ConnectorCookieError) Error() string {
+	if err != nil && err.Err != nil {
+		return fmt.Sprintf("connector cookie sync failed for %s: %v", strings.TrimSpace(err.URL), err.Err)
+	}
+	return "connector cookie sync failed"
+}
+
+// Unwrap exposes the cause to errors.Is / errors.As; it returns nil for a nil
+// receiver.
+func (err *ConnectorCookieError) Unwrap() error {
+	if err != nil {
+		return err.Err
+	}
+	return nil
+}
+
+// WaitTimeoutError reports that a wait condition did not become true in time.
+type WaitTimeoutError struct {
+	Condition string
+}
+
+// Error mentions the trimmed condition when one is set and falls back to
+// "wait timeout" otherwise.
+func (err *WaitTimeoutError) Error() string {
+	condition := ""
+	if err != nil {
+		condition = strings.TrimSpace(err.Condition)
+	}
+	if condition == "" {
+		return "wait timeout"
+	}
+	return fmt.Sprintf("wait %s timeout", condition)
+}
+
+// NewSessionRegistry returns an empty registry that is ready for GetOrCreate.
+func NewSessionRegistry() *SessionRegistry {
+	registry := new(SessionRegistry)
+	registry.sessions = make(map[string]map[string]*Session)
+	return registry
+}
+
+// GetOrCreate returns the session registered under (sessionKey, profileName),
+// creating it when it does not exist yet.
+//
+// Blank keys are normalised to "default" and "dreamcreator" respectively. The
+// supplied options always replace the session's stored options, and their
+// SessionKey/ProfileName are pinned to the normalised registry keys on both
+// the create and the update path, so a session's identity can never drift
+// from the bucket it lives in. A zero-value registry is initialised lazily.
+func (registry *SessionRegistry) GetOrCreate(sessionKey string, profileName string, options SessionOptions) *Session {
+	registry.mu.Lock()
+	defer registry.mu.Unlock()
+
+	sessionKey = strings.TrimSpace(sessionKey)
+	if sessionKey == "" {
+		sessionKey = "default"
+	}
+	profileName = strings.TrimSpace(profileName)
+	if profileName == "" {
+		profileName = "dreamcreator"
+	}
+	// Pin the identity before normalising, for new and existing sessions alike.
+	options.SessionKey = sessionKey
+	options.ProfileName = profileName
+
+	if registry.sessions == nil {
+		registry.sessions = map[string]map[string]*Session{}
+	}
+	bucket, ok := registry.sessions[sessionKey]
+	if !ok {
+		bucket = map[string]*Session{}
+		registry.sessions[sessionKey] = bucket
+	}
+	session, ok := bucket[profileName]
+	if !ok {
+		session = &Session{
+			options:        normalizeSessionOptions(options),
+			tabs:           map[string]*sessionTab{},
+			pendingDialogs: map[string]PendingDialog{},
+			cookieSync:     map[string]string{},
+		}
+		bucket[profileName] = session
+		return session
+	}
+	session.mu.Lock()
+	session.options = normalizeSessionOptions(options)
+	session.mu.Unlock()
+	return session
+}
+
+func (registry *SessionRegistry) CloseSessionKey(sessionKey string) {
+ if registry == nil {
+ return
+ }
+ sessionKey = strings.TrimSpace(sessionKey)
+ if sessionKey == "" {
+ sessionKey = "default"
+ }
+ registry.mu.Lock()
+ bucket := registry.sessions[sessionKey]
+ delete(registry.sessions, sessionKey)
+ sessions := make([]*Session, 0, len(bucket))
+ for _, session := range bucket {
+ if session != nil {
+ sessions = append(sessions, session)
+ }
+ }
+ registry.mu.Unlock()
+ for _, session := range sessions {
+ session.stop()
+ }
+}
+
+// CloseAll empties the registry and stops every session it held. The sessions
+// are detached while holding the registry lock and stopped after it has been
+// released. A nil registry is a no-op.
+func (registry *SessionRegistry) CloseAll() {
+	if registry == nil {
+		return
+	}
+
+	doomed := make([]*Session, 0)
+	func() {
+		registry.mu.Lock()
+		defer registry.mu.Unlock()
+		for key, profiles := range registry.sessions {
+			delete(registry.sessions, key)
+			for _, candidate := range profiles {
+				if candidate == nil {
+					continue
+				}
+				doomed = append(doomed, candidate)
+			}
+		}
+	}()
+
+	for i := range doomed {
+		doomed[i].stop()
+	}
+}
+
+// IsFatalError reports whether err should be treated as fatal. That is the
+// case when the chain contains a *FatalError, or when the lower-cased,
+// trimmed message contains one of the known failure markers.
+func IsFatalError(err error) bool {
+	if err == nil {
+		return false
+	}
+	var fatal *FatalError
+	if errors.As(err, &fatal) {
+		return true
+	}
+	lowered := strings.ToLower(strings.TrimSpace(err.Error()))
+	markers := [...]string{
+		"browser runtime unavailable",
+		"context canceled",
+		"target closed",
+		"connection closed",
+		"websocket",
+		"session closed",
+		"browser session reset",
+	}
+	for _, marker := range markers {
+		if strings.Contains(lowered, marker) {
+			return true
+		}
+	}
+	return false
+}
+
+// Reset stops the session and, when restart is true, starts it again. The
+// result's Restarted and Ready flags mirror restart. A failed restart is
+// returned through session.wrapError.
+func (session *Session) Reset(restart bool) (ActionResult, error) {
+	session.stop()
+	if restart {
+		if startErr := session.ensureStarted(); startErr != nil {
+			return ActionResult{}, session.wrapError(startErr)
+		}
+	}
+
+	outcome := ActionResult{OK: true, Reset: true}
+	outcome.Restarted = restart
+	outcome.Ready = restart
+	return outcome, nil
+}
+
+// Open loads targetURL in a new tab and returns the captured page state.
+//
+// Steps, in order: URL policy check, runtime start, cookie sync for the URL,
+// tab creation, navigation, optional WaitFor, state capture. A recoverable
+// cookie sync failure is retried once on a fresh runtime; when that retry
+// succeeds Open carries on normally. A tab whose navigation fails is detached
+// and cancelled before returning. Policy and blocked-request errors are
+// returned as-is; most other failures go through session.wrapError.
+func (session *Session) Open(ctx context.Context, targetURL string, options CommandOptions) (ActionResult, error) {
+	openStartedAt := time.Now()
+	if err := session.assertURLAllowed(targetURL); err != nil {
+		zap.L().Warn(
+			"browser open blocked by url policy",
+			append(session.logFields(),
+				sanitizedURLField("url", targetURL),
+				zap.Error(err),
+			)...,
+		)
+		return ActionResult{}, err
+	}
+	if err := session.ensureStarted(); err != nil {
+		zap.L().Warn(
+			"browser open start runtime failed",
+			append(session.logFields(),
+				sanitizedURLField("url", targetURL),
+				zap.Error(err),
+			)...,
+		)
+		return ActionResult{}, session.wrapError(err)
+	}
+	cookiesStartedAt := time.Now()
+	if err := session.ensureCookiesForURLOnBrowser(ctx, targetURL); err != nil {
+		if isRecoverableCookieSyncError(err) {
+			zap.L().Warn(
+				"browser open cookie sync failed; retrying on fresh runtime",
+				append(session.logFields(),
+					sanitizedURLField("url", targetURL),
+					zap.Error(err),
+				)...,
+			)
+			session.stop()
+			if retryErr := session.ensureStarted(); retryErr == nil {
+				err = session.ensureCookiesForURLOnBrowser(ctx, targetURL)
+			} else {
+				err = retryErr
+			}
+		}
+		// A successful retry clears err. Only give up while it is still set;
+		// otherwise we would return an empty result with a nil error.
+		if err != nil {
+			zap.L().Warn(
+				"browser open cookie sync failed",
+				append(session.logFields(),
+					sanitizedURLField("url", targetURL),
+					zap.Duration("elapsed", time.Since(cookiesStartedAt).Round(time.Millisecond)),
+					zap.Error(err),
+				)...,
+			)
+			return ActionResult{}, err
+		}
+	}
+	createTabStartedAt := time.Now()
+	tab, err := session.createTab()
+	if err != nil {
+		zap.L().Warn(
+			"browser open create tab failed",
+			append(session.logFields(),
+				sanitizedURLField("url", targetURL),
+				zap.Duration("elapsed", time.Since(createTabStartedAt).Round(time.Millisecond)),
+				zap.Error(err),
+			)...,
+		)
+		return ActionResult{}, session.wrapError(err)
+	}
+	// The tab is torn down on return until navigation has succeeded.
+	cleanup := true
+	defer func() {
+		if cleanup {
+			zap.L().Warn(
+				"browser open cleaning up failed tab",
+				append(session.logFields(),
+					sanitizedURLField("url", targetURL),
+					zap.String("targetId", tab.TargetID),
+				)...,
+			)
+			session.detachTab(tab.TargetID)
+			cancelTabContexts(tab)
+		}
+	}()
+	session.setActiveTarget(tab.TargetID)
+	clearBlockedRequestError(tab)
+	navigateTimeout := normalizeTimeout(options.Timeout, 30*time.Second)
+	navigateStartedAt := time.Now()
+	if err := session.openTab(tab, targetURL, navigateTimeout); err != nil {
+		zap.L().Warn(
+			"browser open navigation failed",
+			append(session.logFields(),
+				sanitizedURLField("url", targetURL),
+				zap.String("targetId", tab.TargetID),
+				zap.Duration("elapsed", time.Since(navigateStartedAt).Round(time.Millisecond)),
+				zap.Error(err),
+			)...,
+		)
+		return ActionResult{}, session.wrapError(err)
+	}
+	// From here on the tab is kept even if a later step fails.
+	cleanup = false
+	if options.WaitFor != nil {
+		waitTimeout := normalizeTimeout(options.WaitFor.Timeout, options.Timeout)
+		waitStartedAt := time.Now()
+		if err := session.waitOnTab(ctx, tab, *options.WaitFor, waitTimeout); err != nil {
+			// A blocked request explains the failure better than the wait error.
+			if blockedErr := consumeBlockedRequestError(tab); blockedErr != nil {
+				return ActionResult{}, blockedErr
+			}
+			zap.L().Warn(
+				"browser open wait failed",
+				append(session.logFields(),
+					sanitizedURLField("url", targetURL),
+					zap.String("targetId", tab.TargetID),
+					zap.Duration("elapsed", time.Since(waitStartedAt).Round(time.Millisecond)),
+					zap.Error(err),
+				)...,
+			)
+			return ActionResult{}, session.wrapError(err)
+		}
+	}
+	if blockedErr := consumeBlockedRequestError(tab); blockedErr != nil {
+		return ActionResult{}, blockedErr
+	}
+	stateTimeout := captureTimeout(options.Timeout)
+	stateStartedAt := time.Now()
+	result, err := session.collectActionResult(tab, options.Limit, stateTimeout, false)
+	if err != nil {
+		zap.L().Warn(
+			"browser open state capture failed",
+			append(session.logFields(),
+				sanitizedURLField("url", targetURL),
+				zap.String("targetId", tab.TargetID),
+				zap.Duration("elapsed", time.Since(stateStartedAt).Round(time.Millisecond)),
+				zap.Error(err),
+			)...,
+		)
+		return ActionResult{}, session.wrapError(err)
+	}
+	result.OpenedNewTab = true
+	result.URL = preferredPageURL(result.URL, tabURL(tab), targetURL)
+	if result.State != nil {
+		result.State.URL = preferredPageURL(result.State.URL, result.URL)
+	}
+	zap.L().Info(
+		"browser open completed",
+		append(session.logFields(),
+			sanitizedURLField("requestedURL", targetURL),
+			zap.String("targetId", result.TargetID),
+			sanitizedURLField("finalURL", result.URL),
+			zap.Bool("stateAvailable", result.State != nil || result.StateAvailable),
+			zap.Uint64("stateVersion", result.StateVersion),
+			zap.Int("itemCount", resultItemCount(result)),
+			zap.String("stateError", strings.TrimSpace(result.StateError)),
+			zap.Duration("elapsed", time.Since(openStartedAt).Round(time.Millisecond)),
+		)...,
+	)
+	return result, nil
+}
+
+// Navigate loads targetURL in an existing tab (targetID, or the tab chosen by
+// resolveTab when targetID is empty) and returns the captured page state.
+//
+// It delegates to Open when newTab is set or when no tab is open and no
+// targetID was given. Several failure modes also fall back to Open, in which
+// case the result has OpenedNewTab set: a recoverable cookie sync error, a
+// target lookup error during state capture, and a state that is still
+// unstable after stabilisation while the page is at targetURL. Policy and
+// blocked-request errors are returned as-is; most other failures go through
+// session.wrapError.
+func (session *Session) Navigate(ctx context.Context, targetID string, targetURL string, newTab bool, options CommandOptions) (ActionResult, error) {
+ navigateStartedAt := time.Now()
+ if err := session.assertURLAllowed(targetURL); err != nil {
+ return ActionResult{}, err
+ }
+ if newTab {
+ return session.Open(ctx, targetURL, options)
+ }
+ if err := session.ensureStarted(); err != nil {
+ return ActionResult{}, session.wrapError(err)
+ }
+ tab, err := session.resolveTab(targetID, true)
+ if err != nil {
+ // No explicit target and nothing open yet: behave like Open.
+ if targetID == "" && errors.Is(err, errNoOpenTab) {
+ return session.Open(ctx, targetURL, options)
+ }
+ return ActionResult{}, session.wrapError(err)
+ }
+ clearBlockedRequestError(tab)
+ if err := session.ensureCookiesForURL(ctx, tab, targetURL); err != nil {
+ if isRecoverableCookieSyncError(err) {
+ zap.L().Warn(
+ "browser navigate cookie sync failed; restarting runtime and falling back to open",
+ append(session.logFields(),
+ zap.String("targetId", tab.TargetID),
+ sanitizedURLField("url", targetURL),
+ zap.Error(err),
+ )...,
+ )
+ // Drop the whole runtime; Open starts a fresh one.
+ session.stop()
+ recovered, recoveredErr := session.Open(ctx, targetURL, options)
+ if recoveredErr == nil {
+ recovered.OpenedNewTab = true
+ return recovered, nil
+ }
+ return ActionResult{}, recoveredErr
+ }
+ return ActionResult{}, err
+ }
+ // navigateTab returns the tab to keep using, so tab is reassigned here.
+ tab, err = session.navigateTab(tab, targetURL, normalizeTimeout(options.Timeout, 30*time.Second))
+ if err != nil {
+ // tab may be nil after a failed navigation; guard before logging its ID.
+ currentTargetID := ""
+ if tab != nil {
+ currentTargetID = tab.TargetID
+ }
+ zap.L().Warn(
+ "browser navigate failed before state capture",
+ append(session.logFields(),
+ zap.String("targetId", currentTargetID),
+ sanitizedURLField("url", targetURL),
+ zap.Duration("elapsed", time.Since(navigateStartedAt).Round(time.Millisecond)),
+ zap.Error(err),
+ )...,
+ )
+ return ActionResult{}, session.wrapError(wrapRuntimeHangError(err))
+ }
+ if options.WaitFor != nil {
+ if err := session.waitOnTab(ctx, tab, *options.WaitFor, normalizeTimeout(options.WaitFor.Timeout, options.Timeout)); err != nil {
+ // A blocked request explains the failure better than the wait error.
+ if blockedErr := consumeBlockedRequestError(tab); blockedErr != nil {
+ return ActionResult{}, blockedErr
+ }
+ zap.L().Warn(
+ "browser navigate wait failed",
+ append(session.logFields(),
+ zap.String("targetId", tab.TargetID),
+ sanitizedURLField("url", targetURL),
+ zap.Duration("elapsed", time.Since(navigateStartedAt).Round(time.Millisecond)),
+ zap.Error(err),
+ )...,
+ )
+ return ActionResult{}, session.wrapError(err)
+ }
+ }
+ if blockedErr := consumeBlockedRequestError(tab); blockedErr != nil {
+ return ActionResult{}, blockedErr
+ }
+ stateTimeout := captureTimeout(options.Timeout)
+ result, err := session.collectActionResult(tab, options.Limit, stateTimeout, false)
+ if err != nil {
+ zap.L().Warn(
+ "browser navigate state capture failed",
+ append(session.logFields(),
+ zap.String("targetId", tab.TargetID),
+ sanitizedURLField("url", targetURL),
+ zap.Duration("elapsed", time.Since(navigateStartedAt).Round(time.Millisecond)),
+ zap.Error(err),
+ )...,
+ )
+ // The target vanished: restart the runtime and try a plain Open. If that
+ // also fails, the original capture error is returned below.
+ if isTargetLookupError(err) {
+ zap.L().Warn(
+ "browser navigate falling back to open after target loss",
+ append(session.logFields(),
+ zap.String("targetId", tab.TargetID),
+ sanitizedURLField("url", targetURL),
+ )...,
+ )
+ session.stop()
+ recovered, recoveredErr := session.Open(ctx, targetURL, options)
+ if recoveredErr == nil {
+ recovered.OpenedNewTab = true
+ return recovered, nil
+ }
+ zap.L().Warn(
+ "browser navigate recovery open failed after runtime restart",
+ append(session.logFields(),
+ zap.String("targetId", tab.TargetID),
+ sanitizedURLField("url", targetURL),
+ zap.Error(recoveredErr),
+ )...,
+ )
+ }
+ return ActionResult{}, session.wrapError(err)
+ }
+ result = session.stabilizeActionResult(tab, result, options.Limit, stateTimeout, minDuration(2*time.Second, normalizeTimeout(options.Timeout, 30*time.Second)/2))
+ if shouldRetryActionState(result) && strings.TrimSpace(result.URL) == strings.TrimSpace(targetURL) {
+ zap.L().Warn(
+ "browser navigate state remained unstable after stabilization; falling back to open",
+ append(session.logFields(),
+ zap.String("targetId", result.TargetID),
+ sanitizedURLField("url", targetURL),
+ zap.String("reason", actionStateReason(result)),
+ )...,
+ )
+ // First try a new tab on the current runtime, then once more on a fresh
+ // runtime.
+ fallback, openErr := session.Open(ctx, targetURL, options)
+ if openErr == nil {
+ fallback.OpenedNewTab = true
+ return fallback, nil
+ }
+ session.stop()
+ recovered, recoveredErr := session.Open(ctx, targetURL, options)
+ if recoveredErr == nil {
+ recovered.OpenedNewTab = true
+ return recovered, nil
+ }
+ // NOTE(review): when both fallback opens fail, execution falls through and
+ // the pre-fallback result is returned with a nil error even though
+ // session.stop() has already run — confirm this is intended.
+ }
+ result.OpenedNewTab = false
+ zap.L().Info(
+ "browser navigate completed",
+ append(session.logFields(),
+ zap.String("targetId", result.TargetID),
+ sanitizedURLField("requestedURL", targetURL),
+ sanitizedURLField("finalURL", result.URL),
+ zap.Bool("stateAvailable", result.StateAvailable),
+ zap.Uint64("stateVersion", result.StateVersion),
+ zap.Int("itemCount", resultItemCount(result)),
+ zap.String("stateError", strings.TrimSpace(result.StateError)),
+ zap.Duration("elapsed", time.Since(navigateStartedAt).Round(time.Millisecond)),
+ )...,
+ )
+ return result, nil
+}
+
+func (session *Session) State(targetID string, limit int) (ActionResult, error) {
+ if err := session.ensureStarted(); err != nil {
+ return ActionResult{}, session.wrapError(err)
+ }
+ tab, err := session.resolveTab(targetID, true)
+ if err != nil {
+ return ActionResult{}, session.wrapError(err)
+ }
+ if blockedErr := consumeBlockedRequestError(tab); blockedErr != nil {
+ return ActionResult{}, blockedErr
+ }
+ result, err := session.collectActionResult(tab, limit, 10*time.Second, false)
+ if err != nil {
+ return ActionResult{}, session.wrapError(err)
+ }
+ return result, nil
+}
+
+ // Wait blocks until request is satisfied on the selected tab, then captures
+ // and stabilizes the page state. The returned result has Waited set to true.
+ func (session *Session) Wait(ctx context.Context, targetID string, request WaitRequest, options CommandOptions) (ActionResult, error) {
+ fail := func(cause error) (ActionResult, error) {
+ return ActionResult{}, session.wrapError(cause)
+ }
+ if err := session.ensureStarted(); err != nil {
+ return fail(err)
+ }
+ tab, err := session.resolveTab(targetID, true)
+ if err != nil {
+ return fail(err)
+ }
+ if err = session.ensureRequestInterception(tab); err != nil {
+ return fail(err)
+ }
+ waitErr := session.waitOnTab(ctx, tab, request, normalizeTimeout(request.Timeout, options.Timeout))
+ // Whether or not the wait succeeded, a blocked-request error takes
+ // precedence and is returned unwrapped.
+ if blockedErr := consumeBlockedRequestError(tab); blockedErr != nil {
+ return ActionResult{}, blockedErr
+ }
+ if waitErr != nil {
+ return fail(waitErr)
+ }
+ stateTimeout := captureTimeout(options.Timeout)
+ result, err := session.collectActionResult(tab, options.Limit, stateTimeout, false)
+ if err != nil {
+ return fail(err)
+ }
+ settleBudget := minDuration(1500*time.Millisecond, normalizeTimeout(options.Timeout, 15*time.Second)/2)
+ result = session.stabilizeActionResult(tab, result, options.Limit, stateTimeout, settleBudget)
+ result.Waited = true
+ return result, nil
+ }
+
+func (session *Session) Scroll(request ScrollRequest) (ActionResult, error) {
+ if err := session.ensureStarted(); err != nil {
+ return ActionResult{}, session.wrapError(err)
+ }
+ tab, err := session.resolveTab(request.TargetID, true)
+ if err != nil {
+ return ActionResult{}, session.wrapError(err)
+ }
+ runCtx, cancel := context.WithTimeout(tab.ctx, normalizeTimeout(request.Timeout, 15*time.Second))
+ defer cancel()
+ if strings.TrimSpace(request.Ref) != "" {
+ selector, err := resolveRefSelector(tab, request.Ref)
+ if err != nil {
+ return ActionResult{}, err
+ }
+ script := fmt.Sprintf(`(() => { const el = document.querySelector(%q); if (!el) throw new Error("element not found"); el.scrollIntoView({block: "center", inline: "center"}); if (%d !== 0 || %d !== 0) { el.scrollBy(%d, %d); } })()`, selector, request.DeltaX, request.DeltaY, request.DeltaX, request.DeltaY)
+ if err := chromedp.Run(runCtx, chromedp.EvaluateAsDevTools(script, nil)); err != nil {
+ return ActionResult{}, session.wrapError(err)
+ }
+ } else {
+ script := fmt.Sprintf(`window.scrollBy(%d, %d)`, request.DeltaX, request.DeltaY)
+ if err := chromedp.Run(runCtx, chromedp.EvaluateAsDevTools(script, nil)); err != nil {
+ return ActionResult{}, session.wrapError(err)
+ }
+ }
+ time.Sleep(200 * time.Millisecond)
+ result, err := session.collectActionResult(tab, request.Limit, captureTimeout(request.Timeout), false)
+ if err != nil {
+ return ActionResult{}, session.wrapError(err)
+ }
+ result = session.stabilizeActionResult(tab, result, request.Limit, captureTimeout(request.Timeout), minDuration(1500*time.Millisecond, normalizeTimeout(request.Timeout, 15*time.Second)/2))
+ result.Scroll = &ScrollDelta{X: request.DeltaX, Y: request.DeltaY}
+ return result, nil
+}
+
+func (session *Session) Upload(request UploadRequest) (ActionResult, error) {
+ if err := session.ensureStarted(); err != nil {
+ return ActionResult{}, session.wrapError(err)
+ }
+ tab, err := session.resolveTab(request.TargetID, true)
+ if err != nil {
+ return ActionResult{}, session.wrapError(err)
+ }
+ selector, err := resolveRefSelector(tab, request.Ref)
+ if err != nil {
+ return ActionResult{}, err
+ }
+ if len(request.Paths) == 0 {
+ return ActionResult{}, errors.New("paths are required")
+ }
+ runCtx, cancel := context.WithTimeout(tab.ctx, normalizeTimeout(request.Timeout, 15*time.Second))
+ defer cancel()
+ if err := chromedp.Run(runCtx, chromedp.SetUploadFiles(selector, request.Paths, chromedp.ByQuery)); err != nil {
+ return ActionResult{}, session.wrapError(err)
+ }
+ clearTabState(tab)
+ result, err := session.collectActionResult(tab, request.Limit, captureTimeout(request.Timeout), false)
+ if err != nil {
+ return ActionResult{}, session.wrapError(err)
+ }
+ result = session.stabilizeActionResult(tab, result, request.Limit, captureTimeout(request.Timeout), minDuration(1500*time.Millisecond, normalizeTimeout(request.Timeout, 15*time.Second)/2))
+ result.Paths = append([]string(nil), request.Paths...)
+ return result, nil
+}
+
+ // Dialog reports or resolves the JavaScript dialog recorded for a tab.
+ //
+ // Outcomes, in order:
+ //   - no dialog recorded for the tab: an OK result carrying only the target ID;
+ //   - request.Accept is nil: an OK result whose Pending field is a copy of the
+ //     recorded dialog, which is left unhandled;
+ //   - otherwise: the dialog is accepted or dismissed per *request.Accept (with
+ //     request.PromptText), the pending entry is removed, clearTabState is
+ //     called, and a freshly captured, stabilized page state is returned.
+ func (session *Session) Dialog(request DialogRequest) (ActionResult, error) {
+ if err := session.ensureStarted(); err != nil {
+ return ActionResult{}, session.wrapError(err)
+ }
+ tab, err := session.resolveTab(request.TargetID, true)
+ if err != nil {
+ return ActionResult{}, session.wrapError(err)
+ }
+ targetID := tab.TargetID
+ // Read the pending entry under the session lock; the lock is not held while
+ // talking to the browser below.
+ session.mu.Lock()
+ pending, exists := session.pendingDialogs[targetID]
+ session.mu.Unlock()
+ if !exists {
+ return ActionResult{
+ OK: true,
+ TargetID: targetID,
+ }, nil
+ }
+ if request.Accept == nil {
+ // Query-only mode: hand back a copy so the caller cannot alias the map value.
+ dialog := pending
+ return ActionResult{
+ OK: true,
+ TargetID: targetID,
+ Pending: &dialog,
+ }, nil
+ }
+ runCtx, cancel := context.WithTimeout(tab.ctx, normalizeTimeout(request.Timeout, 15*time.Second))
+ defer cancel()
+ if err := chromedp.Run(runCtx, chromedp.ActionFunc(func(ctx context.Context) error {
+ return pagepkg.HandleJavaScriptDialog(*request.Accept).WithPromptText(request.PromptText).Do(ctx)
+ })); err != nil {
+ return ActionResult{}, session.wrapError(err)
+ }
+ // The dialog was handled successfully, so forget it.
+ session.mu.Lock()
+ delete(session.pendingDialogs, targetID)
+ session.mu.Unlock()
+ clearTabState(tab)
+ result, err := session.collectActionResult(tab, request.Limit, captureTimeout(request.Timeout), false)
+ if err != nil {
+ return ActionResult{}, session.wrapError(err)
+ }
+ result = session.stabilizeActionResult(tab, result, request.Limit, captureTimeout(request.Timeout), minDuration(1500*time.Millisecond, normalizeTimeout(request.Timeout, 15*time.Second)/2))
+ return result, nil
+ }
+
+ // Act performs a single interaction (request.Kind) on a tab and returns the
+ // resulting page state.
+ //
+ // Flow:
+ //  1. Ensure the runtime is started, resolve the tab and enable request
+ //     interception on it.
+ //  2. Remember the tab's target ID and URL, clear any blocked-request error,
+ //     snapshot the browser's page targets, and (for kinds where
+ //     actMayOpenNewTab is true) arm a new-tab waiter.
+ //  3. Dispatch on request.Kind. On failure a pending blocked-request error is
+ //     returned in preference to the action error.
+ //  4. "close" returns immediately with Closed set.
+ //  5. Otherwise: optionally switch to a newly opened tab, honour
+ //     request.WaitFor (or a short settle delay), observe navigation for
+ //     click/press, then capture and stabilize state.
+ //
+ // The result records the action kind, the previous target ID and URL, and
+ // whether a new tab was opened or the URL changed.
+ func (session *Session) Act(ctx context.Context, request ActRequest) (ActionResult, error) {
+ actStartedAt := time.Now()
+ if err := session.ensureStarted(); err != nil {
+ return ActionResult{}, session.wrapError(err)
+ }
+ tab, err := session.resolveTab(request.TargetID, true)
+ if err != nil {
+ return ActionResult{}, session.wrapError(err)
+ }
+ if err := session.ensureRequestInterception(tab); err != nil {
+ return ActionResult{}, session.wrapError(err)
+ }
+ previousTargetID := tab.TargetID
+ previousURL := tabURL(tab)
+ clearBlockedRequestError(tab)
+ // Best effort: the snapshot error is deliberately ignored here.
+ beforeTargets, _ := session.snapshotPageTargets()
+ // NOTE: this local variable shadows the newTabWaiter type for the rest of the function.
+ var newTabWaiter *newTabWaiter
+ if actMayOpenNewTab(request.Kind) {
+ newTabWaiter, err = session.prepareNewTabWaiter(ctx, tab)
+ if err != nil {
+ return ActionResult{}, session.wrapError(err)
+ }
+ defer newTabWaiter.close()
+ }
+ switch request.Kind {
+ case "click":
+ err = session.actClick(tab, request)
+ case "type":
+ err = session.actType(tab, request)
+ case "press":
+ err = session.actPress(tab, request)
+ case "hover":
+ err = session.actHover(tab, request)
+ case "select":
+ err = session.actSelect(tab, request)
+ case "fill":
+ err = session.actFill(tab, request)
+ case "resize":
+ err = session.actResize(tab, request)
+ case "wait":
+ err = session.waitOnTab(ctx, tab, request.Wait, normalizeTimeout(request.Wait.Timeout, request.Timeout))
+ case "evaluate":
+ err = session.actEvaluate(tab, request)
+ case "close":
+ _, err = session.closeTab(tab.TargetID, normalizeTimeout(request.Timeout, 15*time.Second))
+ default:
+ err = fmt.Errorf("act kind not supported: %s", strings.TrimSpace(request.Kind))
+ }
+ if err != nil {
+ // A blocked request takes precedence over the action error.
+ if blockedErr := consumeBlockedRequestError(tab); blockedErr != nil {
+ return ActionResult{}, blockedErr
+ }
+ zap.L().Warn(
+ "browser act command failed",
+ append(session.logFields(),
+ zap.String("kind", strings.TrimSpace(request.Kind)),
+ zap.String("targetId", tab.TargetID),
+ zap.String("ref", strings.TrimSpace(request.Ref)),
+ zap.Duration("elapsed", time.Since(actStartedAt).Round(time.Millisecond)),
+ zap.Error(err),
+ )...,
+ )
+ // Errors from every kind except "wait" pass through wrapRuntimeHangError first.
+ if request.Kind != "wait" {
+ err = wrapRuntimeHangError(err)
+ }
+ return ActionResult{}, session.wrapError(err)
+ }
+ if request.Kind == "close" {
+ // The tab is gone, so no state is captured for it.
+ return ActionResult{
+ OK: true,
+ TargetID: previousTargetID,
+ Closed: true,
+ }, nil
+ }
+ if actInvalidatesState(request.Kind) {
+ clearTabState(tab)
+ }
+ currentTab := tab
+ openedNewTab := false
+ if newTabWaiter != nil {
+ // Give a newly opened tab up to 1.5s to appear; if it does, continue on it.
+ if detectedTab, ok := session.waitForNewTab(newTabWaiter, 1500*time.Millisecond); ok && detectedTab != nil {
+ currentTab = detectedTab
+ openedNewTab = detectedTab.TargetID != previousTargetID
+ }
+ }
+ if request.WaitFor != nil {
+ if err := session.waitOnTab(ctx, currentTab, *request.WaitFor, normalizeTimeout(request.WaitFor.Timeout, request.Timeout)); err != nil {
+ zap.L().Warn(
+ "browser act wait failed",
+ append(session.logFields(),
+ zap.String("kind", strings.TrimSpace(request.Kind)),
+ zap.String("targetId", currentTab.TargetID),
+ zap.String("ref", strings.TrimSpace(request.Ref)),
+ zap.Duration("elapsed", time.Since(actStartedAt).Round(time.Millisecond)),
+ zap.Error(err),
+ )...,
+ )
+ // NOTE(review): unlike Wait, this path does not check for a blocked-request
+ // error before returning — confirm whether that is intentional.
+ return ActionResult{}, session.wrapError(err)
+ }
+ } else if actNeedsSettle(request.Kind) {
+ // No explicit wait condition: pause 250ms (cancellable via ctx).
+ if err := sleepWithContext(ctx, 250*time.Millisecond); err != nil {
+ return ActionResult{}, session.wrapError(err)
+ }
+ }
+ if request.WaitFor == nil {
+ switch request.Kind {
+ case "click", "press":
+ // For click/press, watch up to 2s for a navigation and follow the tab it reports.
+ if navigatedTab, observed, observeErr := session.observeActionNavigation(currentTab, beforeTargets, previousURL, 2*time.Second); observeErr != nil {
+ return ActionResult{}, session.wrapError(observeErr)
+ } else if observed && navigatedTab != nil {
+ currentTab = navigatedTab
+ }
+ }
+ }
+ if blockedErr := consumeBlockedRequestError(currentTab); blockedErr != nil {
+ return ActionResult{}, blockedErr
+ }
+ stateTimeout := captureTimeout(request.Timeout)
+ result, err := session.collectActionResult(currentTab, request.Limit, stateTimeout, false)
+ if err != nil {
+ zap.L().Warn(
+ "browser act state capture failed",
+ append(session.logFields(),
+ zap.String("kind", strings.TrimSpace(request.Kind)),
+ zap.String("targetId", currentTab.TargetID),
+ zap.Duration("elapsed", time.Since(actStartedAt).Round(time.Millisecond)),
+ zap.Error(err),
+ )...,
+ )
+ return ActionResult{}, session.wrapError(err)
+ }
+ result = session.stabilizeActionResult(currentTab, result, request.Limit, stateTimeout, minDuration(2*time.Second, normalizeTimeout(request.Timeout, 20*time.Second)/2))
+ result.Action = request.Kind
+ result.OpenedNewTab = openedNewTab
+ result.PreviousTargetID = previousTargetID
+ result.PreviousURL = previousURL
+ result.Navigated = openedNewTab || !urlsEqual(previousURL, result.URL)
+ if request.Kind == "evaluate" {
+ // The evaluate result is read from the original tab, not currentTab.
+ result.Result = evaluateResult(tab)
+ }
+ zap.L().Info(
+ "browser act completed",
+ append(session.logFields(),
+ zap.String("kind", strings.TrimSpace(request.Kind)),
+ zap.String("targetId", result.TargetID),
+ sanitizedURLField("previousURL", previousURL),
+ sanitizedURLField("finalURL", result.URL),
+ zap.Bool("openedNewTab", openedNewTab),
+ zap.Bool("navigated", result.Navigated),
+ zap.Bool("stateAvailable", result.StateAvailable),
+ zap.Uint64("stateVersion", result.StateVersion),
+ zap.Int("itemCount", resultItemCount(result)),
+ zap.String("stateError", strings.TrimSpace(result.StateError)),
+ zap.Duration("elapsed", time.Since(actStartedAt).Round(time.Millisecond)),
+ )...,
+ )
+ return result, nil
+ }
+
+ // errNoOpenTab is returned by resolveTab when no target ID was requested and
+ // the session tracks no tabs at all.
+ var errNoOpenTab = errors.New("no browser tab is open")
+
+ // normalizeSessionOptions returns options with blank fields replaced by
+ // defaults.
+ //
+ //   - SessionKey is trimmed and defaults to "default".
+ //   - ProfileName is trimmed and defaults to "dreamcreator".
+ //   - PreferredBrowser is trimmed and lower-cased.
+ //   - A blank UserDataDir (empty or whitespace-only) is resolved from the
+ //     session key and profile name; a non-blank value is kept unchanged.
+ //   - A nil SSRFRules.AllowedHostnames becomes an empty, non-nil map.
+ func normalizeSessionOptions(options SessionOptions) SessionOptions {
+ options.SessionKey = strings.TrimSpace(options.SessionKey)
+ if options.SessionKey == "" {
+ options.SessionKey = "default"
+ }
+ options.ProfileName = strings.TrimSpace(options.ProfileName)
+ if options.ProfileName == "" {
+ options.ProfileName = "dreamcreator"
+ }
+ options.PreferredBrowser = strings.ToLower(strings.TrimSpace(options.PreferredBrowser))
+ // Treat a whitespace-only directory like an empty one, matching the other
+ // string options; otherwise it would be used verbatim as the profile path.
+ if strings.TrimSpace(options.UserDataDir) == "" {
+ options.UserDataDir = ResolveProfileUserDataDir(options.SessionKey, options.ProfileName)
+ }
+ if options.SSRFRules.AllowedHostnames == nil {
+ options.SSRFRules.AllowedHostnames = map[string]struct{}{}
+ }
+ return options
+ }
+
+ // optionsSnapshot returns a copy of the session options struct, read under
+ // the session lock.
+ func (session *Session) optionsSnapshot() SessionOptions {
+ session.mu.Lock()
+ snapshot := session.options
+ session.mu.Unlock()
+ return snapshot
+ }
+
+ // ensureStarted makes sure the session has a ready browser runtime, launching
+ // one when needed.
+ //
+ // The launch runs without holding the session lock and is bounded by a 15s
+ // timeout. If another caller installed a ready runtime in the meantime, the
+ // runtime launched here is stopped and the existing one is kept. If the
+ // session still holds a runtime that is not ready, it is replaced and then
+ // stopped so it is not leaked.
+ //
+ // Returns the error from Start when the launch fails; in that case the
+ // session's runtime field is left untouched.
+ func (session *Session) ensureStarted() error {
+ session.mu.Lock()
+ if session.runtime != nil && session.runtime.Status().Ready {
+ session.mu.Unlock()
+ return nil
+ }
+ options := session.options
+ session.mu.Unlock()
+
+ startCtx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
+ defer cancel()
+ runtime, err := Start(startCtx, LaunchOptions{
+ PreferredBrowser: options.PreferredBrowser,
+ Headless: options.Headless,
+ UserDataDir: options.UserDataDir,
+ })
+ if err != nil {
+ return err
+ }
+ session.mu.Lock()
+ if session.runtime != nil && session.runtime.Status().Ready {
+ // Lost the race against a concurrent start: discard our runtime.
+ session.mu.Unlock()
+ runtime.Stop()
+ return nil
+ }
+ stale := session.runtime
+ session.runtime = runtime
+ session.mu.Unlock()
+ // Release the replaced, not-ready runtime outside the lock.
+ if stale != nil {
+ stale.Stop()
+ }
+ return nil
+ }
+
+ // stop resets the session: it detaches the runtime and every tracked tab
+ // under the lock, clears the pending-dialog and cookie-sync maps, and then,
+ // outside the lock, cancels the tab contexts and stops the runtime.
+ func (session *Session) stop() {
+ session.mu.Lock()
+ runtime := session.runtime
+ session.runtime = nil
+ orphaned := make([]*sessionTab, 0, len(session.tabs))
+ for _, entry := range session.tabs {
+ orphaned = append(orphaned, entry)
+ }
+ session.tabs = map[string]*sessionTab{}
+ session.pendingDialogs = map[string]PendingDialog{}
+ session.cookieSync = map[string]string{}
+ session.activeTarget = ""
+ session.mu.Unlock()
+
+ for index := range orphaned {
+ cancelTabContexts(orphaned[index])
+ }
+ if runtime == nil {
+ return
+ }
+ runtime.Stop()
+ }
+
+func (session *Session) wrapError(err error) error {
+ if err == nil {
+ return nil
+ }
+ if IsFatalError(err) {
+ session.stop()
+ if _, ok := err.(*FatalError); ok {
+ return err
+ }
+ return &FatalError{Err: err}
+ }
+ return err
+}
+
+ // createTab returns a tab attached to a reusable browser target when one is
+ // reported, otherwise to a newly created blank target.
+ func (session *Session) createTab() (*sessionTab, error) {
+ if startErr := session.ensureStarted(); startErr != nil {
+ return nil, startErr
+ }
+ targetID, err := session.resolveReusableTargetID()
+ if err != nil {
+ return nil, err
+ }
+ if strings.TrimSpace(targetID) == "" {
+ // Nothing to reuse: open a fresh blank target instead.
+ if targetID, err = session.createBlankTarget(); err != nil {
+ return nil, err
+ }
+ }
+ return session.attachTargetAsTab(targetID, false)
+ }
+
+ // createBlankTarget opens an about:blank target and returns its trimmed ID.
+ // A new window is requested when the session is not headless. An empty ID
+ // from the browser is reported as an error.
+ func (session *Session) createBlankTarget() (string, error) {
+ headless := session.optionsSnapshot().Headless
+ execCtx, release, ctxErr := session.newBrowserExecutorContext(10 * time.Second)
+ if ctxErr != nil {
+ return "", ctxErr
+ }
+ defer release()
+
+ params := targetpkg.CreateTarget("about:blank")
+ if !headless {
+ params = params.WithNewWindow(true)
+ }
+ rawID, createErr := params.Do(execCtx)
+ if createErr != nil {
+ return "", wrapRuntimeHangError(createErr)
+ }
+ if id := strings.TrimSpace(string(rawID)); id != "" {
+ return id, nil
+ }
+ return "", errors.New("create target returned empty target id")
+ }
+
+ // openTab navigates tab to targetURL and records the resulting URL and title
+ // on the tab.
+ //
+ // Request interception is enabled before navigating. The navigation runs on
+ // the tab's context with a deadline from normalizeTimeout(timeout, 30s).
+ // After navigation the observed URL is checked against the session's URL
+ // policy; a disallowed URL is returned as an error and nothing is stored.
+ //
+ // Returns an error when session or tab is nil, when interception cannot be
+ // enabled, when the navigation fails (passed through wrapRuntimeHangError),
+ // or when the final URL is not allowed.
+ func (session *Session) openTab(tab *sessionTab, targetURL string, timeout time.Duration) error {
+ if session == nil || tab == nil {
+ return errors.New("tab unavailable")
+ }
+ if err := session.ensureRequestInterception(tab); err != nil {
+ return err
+ }
+ runCtx, cancel := context.WithTimeout(tab.ctx, normalizeTimeout(timeout, 30*time.Second))
+ defer cancel()
+
+ var currentURL string
+ var title string
+ if err := chromedp.Run(runCtx,
+ chromedp.Navigate(strings.TrimSpace(targetURL)),
+ chromedp.Location(&currentURL),
+ chromedp.Title(&title),
+ ); err != nil {
+ return wrapRuntimeHangError(err)
+ }
+
+ // Combine the observed location with the requested URL, then validate it.
+ finalURL := preferredPageURL(currentURL, targetURL)
+ if err := session.assertObservedURLAllowed(finalURL); err != nil {
+ return err
+ }
+ storeNavigation(tab, finalURL)
+ tab.mu.Lock()
+ // Keep the previous title when the page reports a blank one.
+ if title = strings.TrimSpace(title); title != "" {
+ tab.title = title
+ }
+ tab.mu.Unlock()
+ return nil
+ }
+
+ // attachTargetAsTab returns the session tab for targetID, creating and
+ // registering one when the session does not track it yet.
+ //
+ // If targetID is already tracked, that tab becomes the active target and is
+ // returned. Otherwise a chromedp context is created — bound to targetID when
+ // createNew is false and targetID is non-blank, or unbound (new target)
+ // otherwise — and the target ID the context actually resolved to is verified
+ // against the requested one. The target is then activated (best effort),
+ // listeners are attached, and the tab is stored as the active target.
+ func (session *Session) attachTargetAsTab(targetID string, createNew bool) (*sessionTab, error) {
+ if err := session.ensureStarted(); err != nil {
+ return nil, err
+ }
+ session.mu.Lock()
+ if trimmed := strings.TrimSpace(targetID); trimmed != "" {
+ if existing, ok := session.tabs[trimmed]; ok {
+ session.activeTarget = existing.TargetID
+ session.mu.Unlock()
+ return existing, nil
+ }
+ }
+ runtime := session.runtime
+ session.mu.Unlock()
+ if runtime == nil {
+ return nil, errors.New("browser runtime unavailable")
+ }
+ var (
+ tabCtx context.Context
+ cancel context.CancelFunc
+ )
+ trimmedTargetID := strings.TrimSpace(targetID)
+ if !createNew && trimmedTargetID != "" {
+ tabCtx, cancel = chromedp.NewContext(runtime.BrowserContext(), chromedp.WithTargetID(targetpkg.ID(trimmedTargetID)))
+ } else {
+ tabCtx, cancel = chromedp.NewContext(runtime.BrowserContext())
+ }
+ tab := &sessionTab{
+ TargetID: trimmedTargetID,
+ ctx: tabCtx,
+ cancel: cancel,
+ refs: map[string]snapshotRef{},
+ }
+ // Run a no-op action on the new context and read back the target ID it resolved to.
+ if err := chromedp.Run(tabCtx, chromedp.ActionFunc(func(ctx context.Context) error {
+ chromeCtx := chromedp.FromContext(ctx)
+ if chromeCtx == nil || chromeCtx.Target == nil {
+ return errors.New("tab target unavailable")
+ }
+ resolvedTargetID := string(chromeCtx.Target.TargetID)
+ if strings.TrimSpace(resolvedTargetID) == "" {
+ return errors.New("tab target unavailable")
+ }
+ if strings.TrimSpace(tab.TargetID) == "" {
+ // No ID was requested: adopt whatever target the context created.
+ tab.TargetID = resolvedTargetID
+ return nil
+ }
+ if !strings.EqualFold(strings.TrimSpace(tab.TargetID), strings.TrimSpace(resolvedTargetID)) {
+ return fmt.Errorf("tab target mismatch: expected %s got %s", strings.TrimSpace(tab.TargetID), strings.TrimSpace(resolvedTargetID))
+ }
+ return nil
+ })); err != nil {
+ cancel()
+ return nil, wrapRuntimeHangError(err)
+ }
+ // Best effort: activation failures are ignored.
+ activateCtx, activateCancel, activateErr := session.newBrowserExecutorContext(5 * time.Second)
+ if activateErr == nil {
+ _ = targetpkg.ActivateTarget(targetpkg.ID(tab.TargetID)).Do(activateCtx)
+ activateCancel()
+ }
+ if err := session.attachTab(tab); err != nil {
+ cancel()
+ return nil, err
+ }
+ // NOTE(review): the map entry is written without re-checking for a tab
+ // registered concurrently under the same ID — confirm callers serialize attaches.
+ session.mu.Lock()
+ session.tabs[tab.TargetID] = tab
+ session.activeTarget = tab.TargetID
+ session.mu.Unlock()
+ return tab, nil
+ }
+
+func (session *Session) attachTab(tab *sessionTab) error {
+ chromedp.ListenTarget(tab.ctx, func(ev any) {
+ switch event := ev.(type) {
+ case *pagepkg.EventJavascriptDialogOpening:
+ session.mu.Lock()
+ session.pendingDialogs[tab.TargetID] = PendingDialog{
+ Message: strings.TrimSpace(event.Message),
+ Type: string(event.Type),
+ ExpiresAt: time.Now().Add(5 * time.Minute),
+ }
+ session.mu.Unlock()
+ case *fetch.EventRequestPaused:
+ go session.handlePausedRequest(tab, event)
+ }
+ })
+ return nil
+}
+
+ // enableRequestInterception turns on the Fetch domain for the tab with a
+ // single pattern matching every URL at the request stage, using a 5s timeout.
+ func (session *Session) enableRequestInterception(tab *sessionTab) error {
+ enable := chromedp.ActionFunc(func(ctx context.Context) error {
+ everyRequest := &fetch.RequestPattern{
+ URLPattern: "*",
+ RequestStage: fetch.RequestStageRequest,
+ }
+ return fetch.Enable().WithPatterns([]*fetch.RequestPattern{everyRequest}).Do(ctx)
+ })
+ return session.runOnTab(tab, 5*time.Second, enable)
+ }
+
+ // ensureRequestInterception enables request interception on the tab unless
+ // its fetchEnabled flag is already set; the flag is set only after a
+ // successful enable. A nil tab is an error.
+ func (session *Session) ensureRequestInterception(tab *sessionTab) error {
+ if tab == nil {
+ return errors.New("tab unavailable")
+ }
+ tab.mu.RLock()
+ alreadyEnabled := tab.fetchEnabled
+ tab.mu.RUnlock()
+ if !alreadyEnabled {
+ if enableErr := session.enableRequestInterception(tab); enableErr != nil {
+ return enableErr
+ }
+ tab.mu.Lock()
+ tab.fetchEnabled = true
+ tab.mu.Unlock()
+ }
+ return nil
+ }
+
+ // handlePausedRequest decides the fate of one request paused by the Fetch
+ // domain. It is started on its own goroutine by the listener in attachTab.
+ //
+ // The request URL is validated against the session's SSRF rules. Allowed
+ // requests (and events without a Request) are continued; disallowed requests
+ // are failed with BlockedByClient, and the reason is stored on the tab via
+ // setBlockedRequestError. The whole decision runs on the tab with a 5s
+ // timeout; a failure to deliver it is only logged.
+ func (session *Session) handlePausedRequest(tab *sessionTab, event *fetch.EventRequestPaused) {
+ if tab == nil || event == nil {
+ return
+ }
+ requestURL := ""
+ if event.Request != nil {
+ requestURL = strings.TrimSpace(event.Request.URL)
+ }
+ err := session.runOnTab(tab, 5*time.Second, chromedp.ActionFunc(func(ctx context.Context) error {
+ if event.Request == nil {
+ // Nothing to validate: let the request through.
+ return fetch.ContinueRequest(event.RequestID).Do(ctx)
+ }
+ options := session.optionsSnapshot()
+ if err := assertRequestURLAllowed(ctx, requestURL, options.SSRFRules); err != nil {
+ blockedErr := fmt.Errorf("blocked by SSRF policy for %s: %w", requestURL, err)
+ // Record the block on the tab before failing the request.
+ setBlockedRequestError(tab, blockedErr)
+ zap.L().Warn(
+ "browser request blocked by ssrf policy",
+ append(session.logFields(),
+ zap.String("targetId", tab.TargetID),
+ sanitizedURLField("url", requestURL),
+ zap.Error(err),
+ )...,
+ )
+ return fetch.FailRequest(event.RequestID, network.ErrorReasonBlockedByClient).Do(ctx)
+ }
+ return fetch.ContinueRequest(event.RequestID).Do(ctx)
+ }))
+ if err != nil {
+ // NOTE(review): when this fails the request has been neither continued nor
+ // failed — presumably it stays paused until the tab goes away; confirm.
+ zap.L().Warn(
+ "browser request interception handling failed",
+ append(session.logFields(),
+ zap.String("targetId", tab.TargetID),
+ sanitizedURLField("url", requestURL),
+ zap.Error(err),
+ )...,
+ )
+ }
+ }
+
+ // resolveTab looks up a tracked tab.
+ //
+ // A non-blank targetID must match a tracked tab exactly, otherwise
+ // "tab not found" is returned. With a blank targetID the active tab is used
+ // when allowActive is true and it is still tracked; failing that, any tracked
+ // tab is returned (map iteration order, so which one is unspecified).
+ // errNoOpenTab is returned when no tab is tracked.
+ func (session *Session) resolveTab(targetID string, allowActive bool) (*sessionTab, error) {
+ session.mu.Lock()
+ defer session.mu.Unlock()
+ if wanted := strings.TrimSpace(targetID); wanted != "" {
+ if match, found := session.tabs[wanted]; found {
+ return match, nil
+ }
+ return nil, errors.New("tab not found")
+ }
+ if allowActive && session.activeTarget != "" {
+ if active, found := session.tabs[session.activeTarget]; found {
+ return active, nil
+ }
+ }
+ for _, fallback := range session.tabs {
+ return fallback, nil
+ }
+ return nil, errNoOpenTab
+ }
+
+ // closeTab cancels the tab's contexts, asks the browser to close its target
+ // (best effort — failures are ignored), and removes the tab from the session.
+ // The only error returned is a failed tab lookup.
+ func (session *Session) closeTab(targetID string, timeout time.Duration) (ActionResult, error) {
+ tab, lookupErr := session.resolveTab(targetID, false)
+ if lookupErr != nil {
+ return ActionResult{}, lookupErr
+ }
+ cancelTabContexts(tab)
+ session.mu.Lock()
+ hasRuntime := session.runtime != nil
+ session.mu.Unlock()
+ if hasRuntime {
+ if closeCtx, release, ctxErr := session.newBrowserExecutorContext(normalizeTimeout(timeout, 15*time.Second)); ctxErr == nil {
+ _ = targetpkg.CloseTarget(targetpkg.ID(tab.TargetID)).Do(closeCtx)
+ release()
+ }
+ }
+ session.detachTab(tab.TargetID)
+ closed := ActionResult{OK: true, Closed: true}
+ closed.TargetID = tab.TargetID
+ return closed, nil
+ }
+
+ // detachTab forgets a tab and its pending dialog. If it was the active
+ // target, any remaining tab becomes active (or none when no tabs are left).
+ // A blank targetID is a no-op.
+ func (session *Session) detachTab(targetID string) {
+ session.mu.Lock()
+ defer session.mu.Unlock()
+ id := strings.TrimSpace(targetID)
+ if id == "" {
+ return
+ }
+ delete(session.tabs, id)
+ delete(session.pendingDialogs, id)
+ if session.activeTarget != id {
+ return
+ }
+ successor := ""
+ for _, remaining := range session.tabs {
+ successor = remaining.TargetID
+ break
+ }
+ session.activeTarget = successor
+ }
+
+ // browserPageTargets asks the browser for its targets and returns those of
+ // type "page", keyed by target ID.
+ func (session *Session) browserPageTargets(timeout time.Duration) (map[string]*targetpkg.Info, error) {
+ execCtx, release, ctxErr := session.newBrowserExecutorContext(timeout)
+ if ctxErr != nil {
+ return nil, ctxErr
+ }
+ defer release()
+ all, listErr := targetpkg.GetTargets().Do(execCtx)
+ if listErr != nil {
+ return nil, listErr
+ }
+ pages := make(map[string]*targetpkg.Info)
+ for _, candidate := range all {
+ if candidate != nil && candidate.Type == "page" {
+ pages[string(candidate.TargetID)] = candidate
+ }
+ }
+ return pages, nil
+ }
+
+ // setActiveTarget records the trimmed targetID as the session's active
+ // target. No check is made that the target is tracked.
+ func (session *Session) setActiveTarget(targetID string) {
+ active := strings.TrimSpace(targetID)
+ session.mu.Lock()
+ session.activeTarget = active
+ session.mu.Unlock()
+ }
+
+func (session *Session) snapshotPageTargets() (map[string]*targetpkg.Info, error) {
+ return session.browserPageTargets(3 * time.Second)
+}
+
+ // prepareNewTabWaiter arms a watcher for page targets created from now on.
+ //
+ // The waiter's context derives from the tab's context and is additionally
+ // cancelled when parent is done (if parent is non-nil). It fails when tab is
+ // nil, when the tab has no context, or when parent is already done.
+ // The caller owns the returned waiter (Act releases it with close()).
+ func (session *Session) prepareNewTabWaiter(parent context.Context, tab *sessionTab) (*newTabWaiter, error) {
+ if tab == nil {
+ return nil, errors.New("tab unavailable")
+ }
+ tab.mu.RLock()
+ baseCtx := tab.ctx
+ tab.mu.RUnlock()
+ if baseCtx == nil {
+ return nil, errors.New("tab context unavailable")
+ }
+ waitCtx, cancel := context.WithCancel(baseCtx)
+ // stop stays nil when there is no parent context to follow.
+ var stop func() bool
+ if parent != nil {
+ if err := parent.Err(); err != nil {
+ cancel()
+ return nil, err
+ }
+ // Propagate the caller's cancellation to the waiter context.
+ stop = context.AfterFunc(parent, cancel)
+ }
+ return &newTabWaiter{
+ ctx: waitCtx,
+ ids: chromedp.WaitNewTarget(waitCtx, func(info *targetpkg.Info) bool {
+ return info != nil && info.Type == "page"
+ }),
+ cancel: cancel,
+ stop: stop,
+ }, nil
+ }
+
+ // waitForNewTab waits up to timeout (with 1.5s passed to normalizeTimeout as
+ // the fallback) for the waiter to report a new page target, and attaches it
+ // as a session tab.
+ //
+ // It returns (nil, false) when the waiter is nil, its context ends, the timer
+ // fires, the ID channel is closed, the ID is blank, or attaching fails.
+ func (session *Session) waitForNewTab(waiter *newTabWaiter, timeout time.Duration) (*sessionTab, bool) {
+ if waiter == nil {
+ return nil, false
+ }
+ timeout = normalizeTimeout(timeout, 1500*time.Millisecond)
+ timer := time.NewTimer(timeout)
+ defer timer.Stop()
+ select {
+ case <-waiter.ctx.Done():
+ return nil, false
+ case <-timer.C:
+ return nil, false
+ case targetID, ok := <-waiter.ids:
+ if !ok {
+ return nil, false
+ }
+ trimmedTargetID := strings.TrimSpace(string(targetID))
+ if trimmedTargetID == "" {
+ return nil, false
+ }
+ // Attach errors are swallowed: the caller simply stays on its current tab.
+ tab, err := session.attachTargetAsTab(trimmedTargetID, false)
+ if err != nil {
+ return nil, false
+ }
+ return tab, true
+ }
+ }
+
+ // resolveReusableTargetID returns the ID of an existing browser page target
+ // that can be adopted, as chosen by pickReusableTargetID. It returns "" when
+ // there is no runtime or the session already tracks at least one tab.
+ func (session *Session) resolveReusableTargetID() (string, error) {
+ session.mu.Lock()
+ hasRuntime := session.runtime != nil
+ trackedTabs := len(session.tabs)
+ session.mu.Unlock()
+ if !hasRuntime || trackedTabs > 0 {
+ return "", nil
+ }
+ pages, listErr := session.browserPageTargets(3 * time.Second)
+ if listErr != nil {
+ return "", listErr
+ }
+ candidates := mapTargetInfos(pages)
+ return pickReusableTargetID(candidates), nil
+ }
+
+ // assertURLAllowed validates rawURL as a navigation target against the
+ // session's current SSRF rules.
+ func (session *Session) assertURLAllowed(rawURL string) error {
+ rules := session.optionsSnapshot().SSRFRules
+ return assertNavigationURLAllowed(rawURL, rules)
+ }
+
+ // AssertURLAllowed is the exported entry point for validating a navigation
+ // URL against an SSRF policy. It returns nil when the URL is allowed.
+ func AssertURLAllowed(rawURL string, policy SSRFPolicy) error {
+ if err := assertNavigationURLAllowed(rawURL, policy); err != nil {
+ return err
+ }
+ return nil
+ }
+
+ // assertNavigationURLAllowed validates a top-level navigation URL. Only http
+ // and https are accepted; any other scheme is rejected. Validation is
+ // bounded by defaultSSRFValidationTimeout.
+ func assertNavigationURLAllowed(rawURL string, policy SSRFPolicy) error {
+ navigableSchemes := map[string]struct{}{
+ "http": {},
+ "https": {},
+ }
+ validationCtx, release := context.WithTimeout(context.Background(), defaultSSRFValidationTimeout)
+ defer release()
+ return assertURLAllowedWithSchemes(validationCtx, rawURL, policy, navigableSchemes, true)
+ }
+
+ // assertRequestURLAllowed validates the URL of an intercepted request.
+ // http, https, ws and wss URLs are checked against the policy; URLs with any
+ // other scheme are allowed without a check (rejectUnsupportedScheme is false).
+ func assertRequestURLAllowed(ctx context.Context, rawURL string, policy SSRFPolicy) error {
+ checkedSchemes := make(map[string]struct{}, 4)
+ for _, scheme := range []string{"http", "https", "ws", "wss"} {
+ checkedSchemes[scheme] = struct{}{}
+ }
+ return assertURLAllowedWithSchemes(ctx, rawURL, policy, checkedSchemes, false)
+ }
+
+ // assertURLAllowedWithSchemes parses rawURL and, when its lower-cased scheme
+ // is in allowedSchemes, validates the host against policy. For any other
+ // scheme the URL is rejected if rejectUnsupportedScheme is true and allowed
+ // unchecked otherwise. A blank or unparsable URL is always an error.
+ func assertURLAllowedWithSchemes(ctx context.Context, rawURL string, policy SSRFPolicy, allowedSchemes map[string]struct{}, rejectUnsupportedScheme bool) error {
+ candidate := strings.TrimSpace(rawURL)
+ if candidate == "" {
+ return errors.New("targetUrl is required")
+ }
+ parsed, parseErr := url.Parse(candidate)
+ if parseErr != nil {
+ return fmt.Errorf("invalid url: %w", parseErr)
+ }
+ _, checked := allowedSchemes[strings.ToLower(strings.TrimSpace(parsed.Scheme))]
+ switch {
+ case checked:
+ return assertParsedURLAllowed(ctx, parsed, policy)
+ case rejectUnsupportedScheme:
+ return errors.New("only http(s) urls are supported")
+ default:
+ return nil
+ }
+ }
+
+ // assertParsedURLAllowed applies the SSRF policy to an already parsed URL.
+ //
+ // Checks, in order:
+ //  1. the URL must be non-nil and have a hostname;
+ //  2. a hostname accepted by isHostnameAllowed passes immediately;
+ //  3. policy.DangerouslyAllowPrivateNetwork disables all remaining checks;
+ //  4. "localhost" and names ending in ".local" or ".internal" are blocked;
+ //  5. a literal IP is blocked only if isPrivateOrLocalIP reports it;
+ //  6. any other hostname is resolved, and blocked if resolution fails, yields
+ //     no addresses, or yields any private/local address.
+ //
+ // A nil ctx is replaced by one bounded by defaultSSRFValidationTimeout.
+ func assertParsedURLAllowed(ctx context.Context, parsed *url.URL, policy SSRFPolicy) error {
+ if parsed == nil {
+ return errors.New("invalid url")
+ }
+ hostname := strings.ToLower(strings.TrimSpace(parsed.Hostname()))
+ if hostname == "" {
+ return errors.New("url hostname is required")
+ }
+ // The explicit allow-list is consulted before any private-network check.
+ if isHostnameAllowed(hostname, policy) {
+ return nil
+ }
+ if policy.DangerouslyAllowPrivateNetwork {
+ return nil
+ }
+ if hostname == "localhost" || strings.HasSuffix(hostname, ".local") || strings.HasSuffix(hostname, ".internal") {
+ return fmt.Errorf("blocked private hostname: %s", hostname)
+ }
+ if ip := net.ParseIP(hostname); ip != nil {
+ if isPrivateOrLocalIP(ip) {
+ return fmt.Errorf("blocked private IP: %s", hostname)
+ }
+ return nil
+ }
+ if ctx == nil {
+ var cancel context.CancelFunc
+ ctx, cancel = context.WithTimeout(context.Background(), defaultSSRFValidationTimeout)
+ defer cancel()
+ }
+ // Fail closed: a hostname that cannot be resolved is rejected.
+ records, err := lookupIPAddrsForHost(ctx, hostname)
+ if err != nil {
+ return fmt.Errorf("could not validate hostname %s: %w", hostname, err)
+ }
+ if len(records) == 0 {
+ return fmt.Errorf("could not validate hostname %s: no IPs returned", hostname)
+ }
+ // One private address among the results is enough to block the hostname.
+ for _, record := range records {
+ if isPrivateOrLocalIP(record.IP) {
+ return fmt.Errorf("blocked hostname resolving to private IP: %s -> %s", hostname, record.IP.String())
+ }
+ }
+ return nil
+ }
+
+ // assertObservedURLAllowed validates a URL observed after navigation. A blank
+ // URL is accepted; anything else goes through assertURLAllowed.
+ func (session *Session) assertObservedURLAllowed(rawURL string) error {
+ if observed := strings.TrimSpace(rawURL); observed != "" {
+ return session.assertURLAllowed(observed)
+ }
+ return nil
+ }
+
+ // ensureCookiesForURL installs the cookies that the session's cookie source
+ // resolves for targetURL into the browser, unless the same cookie fingerprint
+ // was already recorded for the URL's sync keys.
+ //
+ // It is a no-op when session or tab is nil, when no cookie source is
+ // configured, or when no cookies are resolved. Every failure, including
+ // failing to obtain a browser executor context, is returned as a
+ // *ConnectorCookieError.
+ func (session *Session) ensureCookiesForURL(ctx context.Context, tab *sessionTab, targetURL string) error {
+ // tab is only nil-checked; the cookies are set through the browser executor.
+ if session == nil || tab == nil {
+ return nil
+ }
+ options := session.optionsSnapshot()
+ if options.Cookies == nil {
+ return nil
+ }
+ cookies, err := options.Cookies.ResolveCookiesForURL(ctx, targetURL)
+ if err != nil {
+ return &ConnectorCookieError{URL: targetURL, Err: err}
+ }
+ if len(cookies) == 0 {
+ return nil
+ }
+ syncKeys := cookieSyncKeys(targetURL, cookies)
+ fingerprint := cookieFingerprint(cookies)
+ // Skip the browser round trip when this exact cookie set is already recorded.
+ session.mu.Lock()
+ if hasCookieSyncFingerprint(session.cookieSync, syncKeys, fingerprint) {
+ session.mu.Unlock()
+ return nil
+ }
+ session.mu.Unlock()
+ runCtx, cancel, err := session.newBrowserExecutorContext(10 * time.Second)
+ if err != nil {
+ return &ConnectorCookieError{URL: targetURL, Err: err}
+ }
+ defer cancel()
+ if err := SetCookiesOnBrowser(runCtx, targetURL, cookies); err != nil {
+ return &ConnectorCookieError{URL: targetURL, Err: err}
+ }
+ // Record the fingerprint only after the cookies were set successfully.
+ session.mu.Lock()
+ rememberCookieSyncFingerprint(session.cookieSync, syncKeys, fingerprint)
+ session.mu.Unlock()
+ return nil
+ }
+
+ // ensureCookiesForURLOnBrowser is the tab-less variant of ensureCookiesForURL:
+ // it installs the cookies resolved for targetURL into the browser unless the
+ // same fingerprint was already recorded.
+ //
+ // Differences visible here: it requires a running runtime ("browser runtime
+ // unavailable" otherwise) and returns a failure to obtain the executor
+ // context as a plain error rather than a *ConnectorCookieError.
+ func (session *Session) ensureCookiesForURLOnBrowser(ctx context.Context, targetURL string) error {
+ options := session.optionsSnapshot()
+ if options.Cookies == nil {
+ return nil
+ }
+ cookies, err := options.Cookies.ResolveCookiesForURL(ctx, targetURL)
+ if err != nil {
+ return &ConnectorCookieError{URL: targetURL, Err: err}
+ }
+ if len(cookies) == 0 {
+ return nil
+ }
+ syncKeys := cookieSyncKeys(targetURL, cookies)
+ fingerprint := cookieFingerprint(cookies)
+ session.mu.Lock()
+ if hasCookieSyncFingerprint(session.cookieSync, syncKeys, fingerprint) {
+ session.mu.Unlock()
+ return nil
+ }
+ runtime := session.runtime
+ session.mu.Unlock()
+ if runtime == nil {
+ return errors.New("browser runtime unavailable")
+ }
+ runCtx, cancel, err := session.newBrowserExecutorContext(10 * time.Second)
+ if err != nil {
+ return err
+ }
+ defer cancel()
+ if err := SetCookiesOnBrowser(runCtx, targetURL, cookies); err != nil {
+ return &ConnectorCookieError{URL: targetURL, Err: err}
+ }
+ // Record the fingerprint only after the cookies were set successfully.
+ session.mu.Lock()
+ rememberCookieSyncFingerprint(session.cookieSync, syncKeys, fingerprint)
+ session.mu.Unlock()
+ return nil
+ }
+
+ // collectActionResult captures the tab's page state and converts it into an
+ // ActionResult.
+ //
+ // If the first capture fails with a target-lookup error, or with an error for
+ // which shouldDeferStateCapture is true, the tab is rebound and the capture
+ // retried once. If capture still fails and stateRequired is false and the
+ // error is deferrable, an OK result is returned with StateAvailable=false,
+ // the error text in StateError, and URL/title read from page metadata.
+ // Otherwise the capture error is returned.
+ //
+ // A blocked-request error recorded on the tab is returned in preference to
+ // any result, at entry and again before each successful return.
+ func (session *Session) collectActionResult(tab *sessionTab, limit int, timeout time.Duration, stateRequired bool) (ActionResult, error) {
+ if err := consumeBlockedRequestError(tab); err != nil {
+ return ActionResult{}, err
+ }
+ pageState, err := session.collectPageState(tab, limit, timeout)
+ if err != nil && isTargetLookupError(err) {
+ rebound := session.rebindTabAfterNavigation(tab, tabURL(tab))
+ if rebound != nil {
+ tab = rebound
+ pageState, err = session.collectPageState(tab, limit, timeout)
+ }
+ } else if err != nil && shouldDeferStateCapture(err) {
+ // Retry only when rebinding produced a different, non-empty target ID.
+ rebound := session.rebindTabFromBrowserTargets(tab, tabURL(tab))
+ if rebound != nil && rebound.TargetID != "" && rebound.TargetID != tab.TargetID {
+ tab = rebound
+ pageState, err = session.collectPageState(tab, limit, timeout)
+ }
+ }
+ if err == nil {
+ if blockedErr := consumeBlockedRequestError(tab); blockedErr != nil {
+ return ActionResult{}, blockedErr
+ }
+ return resultFromPageState(tab, pageState), nil
+ }
+ if stateRequired || !shouldDeferStateCapture(err) {
+ return ActionResult{}, err
+ }
+ // Deferred path: report success without state, using metadata only.
+ session.refreshTabMetadataFromBrowser(tab)
+ storedURL := tabURL(tab)
+ // NOTE: the local name url shadows the net/url package for the rest of this function.
+ url, title := session.readPageMetadata(tab, 1200*time.Millisecond)
+ url = preferredPageURL(url, storedURL)
+ zap.L().Warn(
+ "browser state capture deferred",
+ append(session.logFields(),
+ zap.String("targetId", tab.TargetID),
+ sanitizedURLField("url", url),
+ zap.Duration("timeout", timeout),
+ zap.Error(err),
+ )...,
+ )
+ if blockedErr := consumeBlockedRequestError(tab); blockedErr != nil {
+ return ActionResult{}, blockedErr
+ }
+ return ActionResult{
+ OK: true,
+ TargetID: tab.TargetID,
+ URL: url,
+ Title: title,
+ StateAvailable: false,
+ StateError: strings.TrimSpace(err.Error()),
+ }, nil
+ }
+
+// collectPageState snapshots the interactive elements of tab and caches the
+// outcome on the tab.
+//
+// The snapshot runs under a context derived from the tab with the given
+// timeout. On success the tab's refs, last URL, title, last state and state
+// version are overwritten with the captured values and the new PageState is
+// returned. If the context cannot be created or the snapshot fails, the error
+// is returned and nothing is cached.
+func (session *Session) collectPageState(tab *sessionTab, limit int, timeout time.Duration) (*PageState, error) {
+	ctx, release, ctxErr := newTabRunContext(tab, timeout)
+	if ctxErr != nil {
+		return nil, ctxErr
+	}
+	defer release()
+
+	snap, snapErr := collectSnapshot(tab, ctx, limit, timeout)
+	if snapErr != nil {
+		return nil, snapErr
+	}
+
+	state := new(PageState)
+	state.Version = snap.Version
+	state.URL = snap.URL
+	state.Title = snap.Title
+	state.Items = snap.Items
+	state.ItemCount = len(snap.Items)
+	state.Truncated = snap.Truncated
+	state.ViewportOnly = snap.ViewportOnly
+	state.CapturedAt = time.Now().Format(time.RFC3339)
+
+	// Publish the capture on the tab in one critical section.
+	func() {
+		tab.mu.Lock()
+		defer tab.mu.Unlock()
+		tab.refs = snap.Refs
+		tab.lastURL = snap.URL
+		tab.title = snap.Title
+		tab.lastState = state
+		tab.stateVersion = snap.Version
+	}()
+	return state, nil
+}
+
+// resultFromPageState builds a successful ActionResult (OK and StateAvailable
+// both true) from a captured page state. A nil tab leaves TargetID empty; a
+// nil pageState leaves the URL, title, version, state and items unset.
+func resultFromPageState(tab *sessionTab, pageState *PageState) ActionResult {
+	var result ActionResult
+	result.OK = true
+	result.StateAvailable = true
+	if tab != nil {
+		result.TargetID = tab.TargetID
+	}
+	if pageState == nil {
+		return result
+	}
+	result.State = pageState
+	result.StateVersion = pageState.Version
+	result.URL = pageState.URL
+	result.Title = pageState.Title
+	result.Items = pageState.Items
+	return result
+}
+
+// collectSnapshot evaluates a script in the page that lists visible
+// interactive elements and converts the result into a snapshotCapture.
+//
+// limit caps the number of returned items (defaultSnapshotLimit when <= 0)
+// and timeout bounds the evaluation (10s when <= 0). The script is run on a
+// context derived from runBaseCtx. Each returned candidate receives a fresh
+// ref from allocateStateRefs, which also bumps the tab's state version; the
+// ref-to-selector mapping is returned in Refs. Evaluation errors are logged
+// and returned.
+func collectSnapshot(tab *sessionTab, runBaseCtx context.Context, limit int, timeout time.Duration) (*snapshotCapture, error) {
+ if limit <= 0 {
+ limit = defaultSnapshotLimit
+ }
+ if timeout <= 0 {
+ timeout = 10 * time.Second
+ }
+ // Number of DOM elements the script may inspect: 20x the item limit,
+ // clamped to [500, 3000].
+ maxScan := limit * 20
+ if maxScan < 500 {
+ maxScan = 500
+ }
+ if maxScan > 3000 {
+ maxScan = 3000
+ }
+ // In-page time budget in milliseconds: half the timeout, clamped to
+ // [250, 1500]. The script stops scanning once it is exceeded.
+ timeBudgetMs := int(timeout / time.Millisecond / 2)
+ if timeBudgetMs < 250 {
+ timeBudgetMs = 250
+ }
+ if timeBudgetMs > 1500 {
+ timeBudgetMs = 1500
+ }
+ // The script returns {url, title, truncated, candidates}; truncated is set
+ // when the scan limit, item limit or time budget cut the scan short, or
+ // when inspecting an element threw.
+ script := fmt.Sprintf(`(() => {
+ const limit = %d;
+ const maxScan = %d;
+ const timeBudgetMs = %d;
+ const startedAt = (typeof performance !== "undefined" && performance.now) ? performance.now() : Date.now();
+ let truncated = false;
+ const now = () => ((typeof performance !== "undefined" && performance.now) ? performance.now() : Date.now());
+ const inferRole = (el) => {
+ const explicit = (el.getAttribute("role") || "").trim();
+ if (explicit) return explicit.toLowerCase();
+ const tag = el.tagName.toLowerCase();
+ if (tag === "a") return "link";
+ if (tag === "button") return "button";
+ if (tag === "textarea" || tag === "input") return "textbox";
+ if (tag === "select") return "combobox";
+ return "element";
+ };
+ const cssPath = (el) => {
+ if (el.id) return "#" + CSS.escape(el.id);
+ const parts = [];
+ let node = el;
+ while (node && node.nodeType === 1 && parts.length < 6) {
+ let part = node.tagName.toLowerCase();
+ if (node.classList && node.classList.length > 0) {
+ part += "." + Array.from(node.classList).slice(0, 2).map((item) => CSS.escape(item)).join(".");
+ }
+ const parent = node.parentElement;
+ if (parent) {
+ const siblings = Array.from(parent.children).filter((candidate) => candidate.tagName === node.tagName);
+ if (siblings.length > 1) {
+ part += ":nth-of-type(" + (siblings.indexOf(node) + 1) + ")";
+ }
+ }
+ parts.unshift(part);
+ node = parent;
+ }
+ return parts.join(" > ");
+ };
+ const visible = (el) => {
+ const style = window.getComputedStyle(el);
+ const rect = el.getBoundingClientRect();
+ return style && style.visibility !== "hidden" && style.display !== "none" && rect.width > 0 && rect.height > 0;
+ };
+ const elements = document.querySelectorAll('a,button,input,textarea,select,summary,[role="button"],[role="link"],[role="menuitem"],[tabindex]:not([tabindex="-1"])');
+ const candidates = [];
+ const scanLimit = Math.min(elements.length, maxScan);
+ if (elements.length > scanLimit) {
+ truncated = true;
+ }
+ for (let index = 0; index < scanLimit; index += 1) {
+ if (candidates.length >= limit) {
+ truncated = true;
+ break;
+ }
+ if ((now() - startedAt) > timeBudgetMs) {
+ truncated = true;
+ break;
+ }
+ const el = elements[index];
+ if (!el) continue;
+ try {
+ if (!visible(el)) continue;
+ const text = ((el.value || el.textContent || "") + "").replace(/\s+/g, " ").trim();
+ const name = ((el.getAttribute("aria-label") || el.getAttribute("title") || el.placeholder || text) + "").replace(/\s+/g, " ").trim();
+ candidates.push({
+ selector: cssPath(el),
+ role: inferRole(el),
+ name,
+ text,
+ });
+ } catch (_err) {
+ truncated = true;
+ }
+ }
+ return {
+ url: document.location.toString(),
+ title: document.title,
+ truncated,
+ candidates,
+ };
+ })()`, limit, maxScan, timeBudgetMs)
+ // payload mirrors the object returned by the script above.
+ var payload struct {
+ URL string `json:"url"`
+ Title string `json:"title"`
+ Truncated bool `json:"truncated"`
+ Candidates []struct {
+ Selector string `json:"selector"`
+ Role string `json:"role"`
+ Name string `json:"name"`
+ Text string `json:"text"`
+ } `json:"candidates"`
+ }
+ runCtx, cancel := context.WithTimeout(runBaseCtx, timeout)
+ defer cancel()
+ captureStartedAt := time.Now()
+ if err := chromedp.Run(runCtx, chromedp.EvaluateAsDevTools(script, &payload)); err != nil {
+ zap.L().Warn(
+ "browser snapshot capture failed",
+ zap.String("targetId", tab.TargetID),
+ zap.Duration("elapsed", time.Since(captureStartedAt).Round(time.Millisecond)),
+ zap.Duration("timeout", timeout),
+ zap.Error(err),
+ )
+ return nil, err
+ }
+ currentURL := strings.TrimSpace(payload.URL)
+ title := strings.TrimSpace(payload.Title)
+ // One ref per candidate, allocated in candidate order.
+ version, refsAllocated := allocateStateRefs(tab, len(payload.Candidates))
+ items := make([]SnapshotItem, 0, len(payload.Candidates))
+ refs := map[string]snapshotRef{}
+ // Nth is the zero-based occurrence index among candidates that share the
+ // same role and name.
+ countByRoleName := map[string]int{}
+ for index, item := range payload.Candidates {
+ ref := refsAllocated[index]
+ key := item.Role + "\n" + item.Name
+ nth := countByRoleName[key]
+ countByRoleName[key] = nth + 1
+ items = append(items, SnapshotItem{
+ Ref: ref,
+ Role: item.Role,
+ Name: item.Name,
+ Text: item.Text,
+ Depth: 0,
+ Nth: nth,
+ })
+ refs[ref] = snapshotRef{
+ Selector: item.Selector,
+ Role: item.Role,
+ Name: item.Name,
+ Nth: nth,
+ }
+ }
+ return &snapshotCapture{
+ Version: version,
+ URL: currentURL,
+ Title: title,
+ Items: items,
+ Refs: refs,
+ // Reaching the item limit is also reported as truncation.
+ Truncated: payload.Truncated || len(payload.Candidates) >= limit,
+ ViewportOnly: false,
+ }, nil
+}
+
+// allocateStateRefs bumps the tab's state version and reserves count new
+// element refs of the form "e<N>", where N continues the tab's running ref
+// counter. It returns the new version and the refs in allocation order. The
+// tab lock is held for the whole allocation.
+func allocateStateRefs(tab *sessionTab, count int) (uint64, []string) {
+	tab.mu.Lock()
+	defer tab.mu.Unlock()
+
+	tab.stateVersion++
+	allocated := make([]string, count)
+	for slot := range allocated {
+		tab.nextRefID++
+		allocated[slot] = fmt.Sprintf("e%d", tab.nextRefID)
+	}
+	return tab.stateVersion, allocated
+}
+
+// clearTabState drops the cached page state of tab: the ref table is replaced
+// with an empty map and the last captured state is cleared. A nil tab is a
+// no-op.
+func clearTabState(tab *sessionTab) {
+	if tab != nil {
+		tab.mu.Lock()
+		defer tab.mu.Unlock()
+		tab.lastState = nil
+		tab.refs = make(map[string]snapshotRef)
+	}
+}
+
+// resolveRefSelector looks up the CSS selector recorded for ref in the tab's
+// ref table.
+//
+// The ref is trimmed before lookup. Errors: "tab unavailable" for a nil tab,
+// "ref is required" for a blank ref, *InvalidRefError when the ref is not in
+// the table, and "ref selector unavailable" when the stored selector is
+// blank.
+func resolveRefSelector(tab *sessionTab, ref string) (string, error) {
+	if tab == nil {
+		return "", errors.New("tab unavailable")
+	}
+	key := strings.TrimSpace(ref)
+	if key == "" {
+		return "", errors.New("ref is required")
+	}
+
+	tab.mu.RLock()
+	entry, found := tab.refs[key]
+	tab.mu.RUnlock()
+
+	switch {
+	case !found:
+		return "", &InvalidRefError{Ref: key}
+	case strings.TrimSpace(entry.Selector) == "":
+		return "", errors.New("ref selector unavailable")
+	default:
+		return entry.Selector, nil
+	}
+}
+
+// readPageMetadata returns the current URL and title of tab.
+//
+// It evaluates a small script in the page (timeout defaults to 1200ms when
+// <= 0). On success the live values are merged into the tab's stored URL and
+// title (non-empty values only, URL chosen via preferredPageURL) and
+// returned. If the run context cannot be created or the evaluation fails, the
+// trimmed values already stored on the tab are returned instead. A nil
+// session or tab yields two empty strings.
+func (session *Session) readPageMetadata(tab *sessionTab, timeout time.Duration) (string, string) {
+	if session == nil || tab == nil {
+		return "", ""
+	}
+	if timeout <= 0 {
+		timeout = 1200 * time.Millisecond
+	}
+	// cached reports what the tab already knows; used on every failure path.
+	cached := func() (string, string) {
+		tab.mu.RLock()
+		defer tab.mu.RUnlock()
+		return strings.TrimSpace(tab.lastURL), strings.TrimSpace(tab.title)
+	}
+
+	runCtx, release, ctxErr := newTabRunContext(tab, timeout)
+	if ctxErr != nil {
+		return cached()
+	}
+	defer release()
+
+	var page struct {
+		URL   string `json:"url"`
+		Title string `json:"title"`
+	}
+	evalErr := chromedp.Run(runCtx, chromedp.EvaluateAsDevTools(`({
+ url: document.location.toString(),
+ title: document.title,
+ })`, &page))
+	if evalErr != nil {
+		return cached()
+	}
+
+	liveTitle := strings.TrimSpace(page.Title)
+	tab.mu.Lock()
+	liveURL := preferredPageURL(strings.TrimSpace(page.URL), tab.lastURL)
+	if liveURL != "" {
+		tab.lastURL = liveURL
+	}
+	if liveTitle != "" {
+		tab.title = liveTitle
+	}
+	tab.mu.Unlock()
+	return liveURL, liveTitle
+}
+
+// newEphemeralTargetContext creates a chromedp context bound to the page
+// target with the given ID, parented on the session runtime's browser
+// context. The caller owns the returned cancel function.
+//
+// Errors: "tab target unavailable" for a blank target ID, "browser runtime
+// unavailable" when the session has no runtime, and "browser context
+// unavailable" when the runtime has no browser context.
+func (session *Session) newEphemeralTargetContext(targetID string) (context.Context, context.CancelFunc, error) {
+	targetID = strings.TrimSpace(targetID)
+	if targetID == "" {
+		return nil, nil, errors.New("tab target unavailable")
+	}
+	session.mu.Lock()
+	runtime := session.runtime
+	session.mu.Unlock()
+	if runtime == nil {
+		return nil, nil, errors.New("browser runtime unavailable")
+	}
+	// Deriving a context from a nil parent panics, so reject a missing
+	// browser context the same way newBrowserExecutorContext does.
+	baseCtx := runtime.BrowserContext()
+	if baseCtx == nil {
+		return nil, nil, errors.New("browser context unavailable")
+	}
+	ctx, cancel := chromedp.NewContext(baseCtx, chromedp.WithTargetID(targetpkg.ID(targetID)))
+	return ctx, cancel, nil
+}
+
+// newBrowserExecutorContext returns a context, limited by timeout, whose CDP
+// executor is the browser attached to the session runtime's browser context
+// rather than a single tab. The caller owns the returned cancel function.
+//
+// Errors: "browser runtime unavailable", "browser context unavailable" and
+// "browser executor unavailable" when the respective layer is missing.
+func (session *Session) newBrowserExecutorContext(timeout time.Duration) (context.Context, context.CancelFunc, error) {
+	session.mu.Lock()
+	rt := session.runtime
+	session.mu.Unlock()
+	if rt == nil {
+		return nil, nil, errors.New("browser runtime unavailable")
+	}
+
+	browserCtx := rt.BrowserContext()
+	if browserCtx == nil {
+		return nil, nil, errors.New("browser context unavailable")
+	}
+
+	if attached := chromedp.FromContext(browserCtx); attached != nil && attached.Browser != nil {
+		timed, cancel := context.WithTimeout(browserCtx, timeout)
+		return cdp.WithExecutor(timed, attached.Browser), cancel, nil
+	}
+	return nil, nil, errors.New("browser executor unavailable")
+}
+
+// browserTargetInfo returns the browser's info record for the page target
+// with the given ID, or nil when the ID is blank, the target list cannot be
+// fetched within timeout, or the target is not listed.
+func (session *Session) browserTargetInfo(targetID string, timeout time.Duration) *targetpkg.Info {
+	id := strings.TrimSpace(targetID)
+	if id == "" {
+		return nil
+	}
+	if pages, err := session.browserPageTargets(timeout); err == nil {
+		return pages[id]
+	}
+	return nil
+}
+
+// refreshTabMetadataFromBrowser updates the tab's stored URL and title from
+// the browser's target list (1500ms lookup budget). Only non-empty values are
+// written; the URL is chosen with preferredPageURL between the live and the
+// stored one. Does nothing for a nil tab or when the target is not found.
+func (session *Session) refreshTabMetadataFromBrowser(tab *sessionTab) {
+	if tab == nil {
+		return
+	}
+	info := session.browserTargetInfo(tab.TargetID, 1500*time.Millisecond)
+	if info == nil {
+		return
+	}
+	liveTitle := strings.TrimSpace(info.Title)
+
+	tab.mu.Lock()
+	defer tab.mu.Unlock()
+	if liveURL := preferredPageURL(strings.TrimSpace(info.URL), tab.lastURL); liveURL != "" {
+		tab.lastURL = liveURL
+	}
+	if liveTitle != "" {
+		tab.title = liveTitle
+	}
+}
+
+// observeActionNavigation polls the browser's page targets until a navigation
+// caused by an action is observed or the timeout (normalized with a 5s
+// fallback) has passed.
+//
+// before holds the targets that existed before the action; a target missing
+// from it counts as newly opened. previousURL is the tab's URL before the
+// action. It returns the tab to keep using (possibly a different, rebound
+// tab), whether a navigation was observed, and an error when a blocked
+// request is pending or the observed URL is rejected by
+// assertObservedURLAllowed. Running out of time without a navigation is not
+// an error: (tab, false, nil) is returned.
+func (session *Session) observeActionNavigation(tab *sessionTab, before map[string]*targetpkg.Info, previousURL string, timeout time.Duration) (*sessionTab, bool, error) {
+ if tab == nil {
+ return tab, false, nil
+ }
+ timeout = normalizeTimeout(timeout, 5*time.Second)
+ deadline := time.Now().Add(timeout)
+ for {
+ if err := consumeBlockedRequestError(tab); err != nil {
+ return tab, false, err
+ }
+ targets, err := session.browserPageTargets(500 * time.Millisecond)
+ if err == nil {
+ // Case 1: the current target itself reports a completed navigation.
+ if info := targets[tab.TargetID]; info != nil {
+ currentURL := strings.TrimSpace(info.URL)
+ currentTitle := strings.TrimSpace(info.Title)
+ if shouldTreatNavigationAsComplete(currentURL, previousURL, "") {
+ if err := session.assertObservedURLAllowed(currentURL); err != nil {
+ return tab, false, err
+ }
+ tab.mu.Lock()
+ tab.lastURL = currentURL
+ if currentTitle != "" {
+ tab.title = currentTitle
+ }
+ tab.mu.Unlock()
+ rebound := session.rebindTabAfterNavigation(tab, currentURL)
+ if rebound != nil {
+ tab = rebound
+ }
+ storeNavigation(tab, currentURL)
+ return tab, true, nil
+ }
+ }
+ // Case 2: a target that did not exist before the action reports a
+ // completed navigation; attach it and detach the old tab.
+ for targetID, info := range targets {
+ if info == nil || strings.TrimSpace(targetID) == "" || targetID == tab.TargetID {
+ continue
+ }
+ if _, existed := before[targetID]; existed {
+ continue
+ }
+ currentURL := strings.TrimSpace(info.URL)
+ currentTitle := strings.TrimSpace(info.Title)
+ if !shouldTreatNavigationAsComplete(currentURL, previousURL, "") {
+ continue
+ }
+ if err := session.assertObservedURLAllowed(currentURL); err != nil {
+ return tab, false, err
+ }
+ rebound, attachErr := session.attachTargetAsTab(targetID, false)
+ if attachErr != nil || rebound == nil {
+ // Attach failures are skipped; other new targets are still tried.
+ continue
+ }
+ rebound.mu.Lock()
+ rebound.lastURL = currentURL
+ if currentTitle != "" {
+ rebound.title = currentTitle
+ }
+ rebound.mu.Unlock()
+ if rebound.TargetID != tab.TargetID {
+ session.detachTab(tab.TargetID)
+ }
+ storeNavigation(rebound, currentURL)
+ return rebound, true, nil
+ }
+ } else {
+ // Listing the targets failed; fall back to looking up only the
+ // current target.
+ info := session.browserTargetInfo(tab.TargetID, 500*time.Millisecond)
+ if info == nil {
+ if time.Now().After(deadline) {
+ if blockedErr := consumeBlockedRequestError(tab); blockedErr != nil {
+ return tab, false, blockedErr
+ }
+ return tab, false, nil
+ }
+ time.Sleep(150 * time.Millisecond)
+ continue
+ }
+ currentURL := strings.TrimSpace(info.URL)
+ currentTitle := strings.TrimSpace(info.Title)
+ if shouldTreatNavigationAsComplete(currentURL, previousURL, "") {
+ if err := session.assertObservedURLAllowed(currentURL); err != nil {
+ return tab, false, err
+ }
+ tab.mu.Lock()
+ tab.lastURL = currentURL
+ if currentTitle != "" {
+ tab.title = currentTitle
+ }
+ tab.mu.Unlock()
+ rebound := session.rebindTabAfterNavigation(tab, currentURL)
+ if rebound != nil {
+ tab = rebound
+ }
+ storeNavigation(tab, currentURL)
+ return tab, true, nil
+ }
+ }
+ // Nothing observed this round: give up once the deadline has passed
+ // (after a final blocked-request check), otherwise poll again.
+ if time.Now().After(deadline) {
+ if blockedErr := consumeBlockedRequestError(tab); blockedErr != nil {
+ return tab, false, blockedErr
+ }
+ return tab, false, nil
+ }
+ time.Sleep(150 * time.Millisecond)
+ }
+}
+
+// mapTargetInfos flattens a target map into a slice of its non-nil values.
+// An empty or nil map yields nil. The order follows Go's map iteration order
+// and is therefore unspecified.
+func mapTargetInfos(targets map[string]*targetpkg.Info) []*targetpkg.Info {
+	if len(targets) == 0 {
+		return nil
+	}
+	infos := make([]*targetpkg.Info, 0, len(targets))
+	for key := range targets {
+		if entry := targets[key]; entry != nil {
+			infos = append(infos, entry)
+		}
+	}
+	return infos
+}
+
+// runOnTab runs the given chromedp actions against tab, bounded by timeout.
+// It is a thin wrapper over runOnTabFunc.
+func (session *Session) runOnTab(tab *sessionTab, timeout time.Duration, actions ...chromedp.Action) error {
+	execute := func(runCtx context.Context) error {
+		return chromedp.Run(runCtx, actions...)
+	}
+	return session.runOnTabFunc(tab, timeout, execute)
+}
+
+// runOnTabFunc calls fn with a context derived from the tab's context and
+// limited by timeout; the context is cancelled when fn returns. It returns
+// "tab unavailable" for a nil session or tab, the error from
+// newTabRunContext if the context cannot be created, and otherwise whatever
+// fn returns.
+func (session *Session) runOnTabFunc(tab *sessionTab, timeout time.Duration, fn func(context.Context) error) error {
+	if tab == nil || session == nil {
+		return errors.New("tab unavailable")
+	}
+	runCtx, release, err := newTabRunContext(tab, timeout)
+	if err == nil {
+		defer release()
+		return fn(runCtx)
+	}
+	return err
+}
+
+// newTabRunContextWithParent is newTabRunContext plus cancellation from
+// parent: the returned context derives from the tab's context and is also
+// cancelled as soon as parent is done.
+//
+// A nil parent behaves exactly like newTabRunContext. If parent is already
+// done, the freshly created context is cancelled and parent's error is
+// returned. The returned cancel function unregisters the parent hook and
+// cancels the context.
+func newTabRunContextWithParent(parent context.Context, tab *sessionTab, timeout time.Duration) (context.Context, context.CancelFunc, error) {
+	tabCtx, tabCancel, err := newTabRunContext(tab, timeout)
+	switch {
+	case err != nil:
+		return nil, nil, err
+	case parent == nil:
+		return tabCtx, tabCancel, nil
+	}
+	if parentErr := parent.Err(); parentErr != nil {
+		tabCancel()
+		return nil, nil, parentErr
+	}
+	detach := context.AfterFunc(parent, tabCancel)
+	release := func() {
+		detach()
+		tabCancel()
+	}
+	return tabCtx, release, nil
+}
+
+// newTabRunContext derives a timeout-limited context from the tab's context.
+// A timeout <= 0 is replaced by 10s. The tab's context is read under the
+// tab's read lock. Errors: "tab unavailable" for a nil tab and "tab context
+// unavailable" when the tab has no context. The caller owns the returned
+// cancel function.
+func newTabRunContext(tab *sessionTab, timeout time.Duration) (context.Context, context.CancelFunc, error) {
+	if tab == nil {
+		return nil, nil, errors.New("tab unavailable")
+	}
+	effective := timeout
+	if effective <= 0 {
+		effective = 10 * time.Second
+	}
+	parentCtx := func() context.Context {
+		tab.mu.RLock()
+		defer tab.mu.RUnlock()
+		return tab.ctx
+	}()
+	if parentCtx == nil {
+		return nil, nil, errors.New("tab context unavailable")
+	}
+	timed, cancel := context.WithTimeout(parentCtx, effective)
+	return timed, cancel, nil
+}
+
+// rememberTabCleanup registers cancel to be invoked when the tab's contexts
+// are torn down by cancelTabContexts. A nil tab or nil cancel is ignored.
+func rememberTabCleanup(tab *sessionTab, cancel context.CancelFunc) {
+	if tab != nil && cancel != nil {
+		tab.mu.Lock()
+		defer tab.mu.Unlock()
+		tab.cleanupCancels = append(tab.cleanupCancels, cancel)
+	}
+}
+
+// cancelTabContexts detaches every cancel function from the tab and invokes
+// them outside the tab lock.
+//
+// Under the lock it collects the registered cleanup cancels followed by the
+// tab's own cancel and clears the tab's ctx, cancel and cleanup list. The
+// collected functions then run in reverse order, so the tab's own cancel runs
+// first and cleanups run newest-first. A nil tab is a no-op.
+func cancelTabContexts(tab *sessionTab) {
+	if tab == nil {
+		return
+	}
+
+	tab.mu.Lock()
+	pending := make([]context.CancelFunc, 0, len(tab.cleanupCancels)+1)
+	pending = append(pending, tab.cleanupCancels...)
+	if primary := tab.cancel; primary != nil {
+		pending = append(pending, primary)
+	}
+	tab.ctx, tab.cancel, tab.cleanupCancels = nil, nil, nil
+	tab.mu.Unlock()
+
+	for remaining := len(pending); remaining > 0; remaining-- {
+		if stop := pending[remaining-1]; stop != nil {
+			stop()
+		}
+	}
+}
+
+// navigateTab points tab at targetURL by assigning window.location.href in
+// the page, then waits for the navigation to settle.
+//
+// Request interception is ensured first. After the script ran, the resulting
+// URL is observed for up to min(3s, timeout); if no completed navigation is
+// seen, the trimmed targetURL is assumed. The final URL must pass
+// assertObservedURLAllowed. The tab is then rebound if its target was
+// replaced and its cached navigation state is reset via storeNavigation.
+//
+// It returns the tab to keep using (possibly a different one after
+// rebinding), or nil and an error when interception setup, the script, the
+// observation, or the URL check fails.
+func (session *Session) navigateTab(tab *sessionTab, targetURL string, timeout time.Duration) (*sessionTab, error) {
+	previousURL := tabURL(tab)
+	if err := session.ensureRequestInterception(tab); err != nil {
+		return nil, err
+	}
+	commandStartedAt := time.Now()
+	// The script only schedules the navigation, so load failures surface
+	// through observeNavigationURL and blocked-request errors rather than
+	// here. The previous "page load error" branch tested a variable that was
+	// never assigned and has been removed as unreachable.
+	script := fmt.Sprintf(`(() => { window.location.href = %q; return true; })()`, strings.TrimSpace(targetURL))
+	if err := session.runOnTab(tab, timeout, chromedp.EvaluateAsDevTools(script, nil)); err != nil {
+		zap.L().Warn(
+			"browser navigate command failed",
+			zap.String("targetId", tab.TargetID),
+			sanitizedURLField("url", targetURL),
+			sanitizedURLField("previousURL", previousURL),
+			zap.Duration("elapsed", time.Since(commandStartedAt).Round(time.Millisecond)),
+			zap.Error(err),
+		)
+		return nil, err
+	}
+	settleTimeout := minDuration(3*time.Second, timeout)
+	finalURL, ok, observeErr := session.observeNavigationURL(tab, previousURL, targetURL, settleTimeout)
+	if observeErr != nil {
+		return nil, observeErr
+	}
+	if !ok {
+		// Nothing conclusive was observed in time; assume the requested URL.
+		finalURL = strings.TrimSpace(targetURL)
+	}
+	if err := session.assertObservedURLAllowed(finalURL); err != nil {
+		return nil, err
+	}
+	tab = session.rebindTabAfterNavigation(tab, finalURL)
+	storeNavigation(tab, finalURL)
+	return tab, nil
+}
+
+// rebindTabAfterNavigation makes sure tab still refers to a live page target
+// after a navigation.
+//
+// The original tab is returned unchanged when the target snapshot cannot be
+// taken, when the tab's target is still listed, when finalURL is blank, or
+// when no replacement can be attached. Otherwise the first listed target
+// whose URL shouldTreatNavigationAsComplete accepts for finalURL, and which
+// attaches successfully, is returned; the old tab is detached if the target
+// ID differs. A nil tab yields nil.
+func (session *Session) rebindTabAfterNavigation(tab *sessionTab, finalURL string) *sessionTab {
+	if tab == nil {
+		return nil
+	}
+	targets, err := session.snapshotPageTargets()
+	if err != nil {
+		return tab
+	}
+	if _, stillPresent := targets[tab.TargetID]; stillPresent {
+		return tab
+	}
+	wantURL := strings.TrimSpace(finalURL)
+	if wantURL == "" {
+		return tab
+	}
+	for id, info := range targets {
+		// Short-circuit order matters: info must be non-nil before its URL is read.
+		skip := info == nil || strings.TrimSpace(id) == "" || !shouldTreatNavigationAsComplete(info.URL, "", wantURL)
+		if skip {
+			continue
+		}
+		replacement, attachErr := session.attachTargetAsTab(id, false)
+		if attachErr != nil || replacement == nil {
+			continue
+		}
+		if replacement.TargetID != tab.TargetID {
+			session.detachTab(tab.TargetID)
+		}
+		return replacement
+	}
+	return tab
+}
+
+// rebindTabFromBrowserTargets tries to replace tab with another page target
+// currently listed by the browser (1500ms lookup budget).
+//
+// Candidates exclude the tab's own target, blank target IDs, blank URLs, URLs
+// for which isReusablePageURL is true, and URLs equal to previousURL. The
+// candidates are ordered with non-attached targets first and then by target
+// ID, and the first one is attached via attachTargetAsTab. On success the new
+// tab's URL and title are set from the target info and the old tab is
+// detached when its target ID differs. It returns nil for a nil tab, when the
+// lookup fails, when there is no candidate, or when attaching fails.
+func (session *Session) rebindTabFromBrowserTargets(tab *sessionTab, previousURL string) *sessionTab {
+ if tab == nil {
+ return nil
+ }
+ targets, err := session.browserPageTargets(1500 * time.Millisecond)
+ if err != nil || len(targets) == 0 {
+ return nil
+ }
+ currentTargetID := strings.TrimSpace(tab.TargetID)
+ previousURL = strings.TrimSpace(previousURL)
+ type candidate struct {
+ targetID string
+ info *targetpkg.Info
+ }
+ candidates := make([]candidate, 0, len(targets))
+ for targetID, info := range targets {
+ if info == nil {
+ continue
+ }
+ targetID = strings.TrimSpace(targetID)
+ if targetID == "" || targetID == currentTargetID {
+ continue
+ }
+ currentURL := strings.TrimSpace(info.URL)
+ if currentURL == "" || isReusablePageURL(currentURL) || urlsEqual(currentURL, previousURL) {
+ continue
+ }
+ candidates = append(candidates, candidate{targetID: targetID, info: info})
+ }
+ if len(candidates) == 0 {
+ return nil
+ }
+ // Map iteration order is random; sorting makes the choice deterministic.
+ // Non-attached targets come first, ties are broken by target ID.
+ sort.SliceStable(candidates, func(left int, right int) bool {
+ leftInfo := candidates[left].info
+ rightInfo := candidates[right].info
+ if leftInfo.Attached != rightInfo.Attached {
+ return !leftInfo.Attached
+ }
+ return candidates[left].targetID < candidates[right].targetID
+ })
+ // Only the best-ranked candidate is tried; there is no fallback to the next.
+ chosen := candidates[0]
+ rebound, attachErr := session.attachTargetAsTab(chosen.targetID, false)
+ if attachErr != nil || rebound == nil {
+ return nil
+ }
+ currentURL := strings.TrimSpace(chosen.info.URL)
+ currentTitle := strings.TrimSpace(chosen.info.Title)
+ rebound.mu.Lock()
+ if currentURL != "" {
+ rebound.lastURL = currentURL
+ }
+ if currentTitle != "" {
+ rebound.title = currentTitle
+ }
+ rebound.mu.Unlock()
+ if rebound.TargetID != currentTargetID {
+ session.detachTab(currentTargetID)
+ }
+ return rebound
+}
+
+// observeNavigationURL polls the tab's URL until
+// shouldTreatNavigationAsComplete accepts it for the given previous and
+// target URLs, or until timeout (1500ms when <= 0) has passed.
+//
+// It returns the trimmed observed URL and true on success, or "" and false
+// when the deadline passes first. An error is returned when a blocked request
+// is pending (checked on every round and once more after the deadline) or
+// when the observed URL is rejected by assertObservedURLAllowed.
+func (session *Session) observeNavigationURL(tab *sessionTab, previousURL string, targetURL string, timeout time.Duration) (string, bool, error) {
+	if timeout <= 0 {
+		timeout = 1500 * time.Millisecond
+	}
+	deadline := time.Now().Add(timeout)
+	for {
+		if err := consumeBlockedRequestError(tab); err != nil {
+			return "", false, err
+		}
+		// readPageMetadata reports no error; on failure it falls back to the
+		// URL stored on the tab, so the result can be tested directly.
+		currentURL, _ := session.readPageMetadata(tab, 1200*time.Millisecond)
+		if shouldTreatNavigationAsComplete(currentURL, previousURL, targetURL) {
+			currentURL = strings.TrimSpace(currentURL)
+			if err := session.assertObservedURLAllowed(currentURL); err != nil {
+				return "", false, err
+			}
+			return currentURL, true, nil
+		}
+		if time.Now().After(deadline) {
+			break
+		}
+		time.Sleep(150 * time.Millisecond)
+	}
+	if err := consumeBlockedRequestError(tab); err != nil {
+		return "", false, err
+	}
+	return "", false, nil
+}
+
+// storeNavigation records that tab now shows finalURL and discards everything
+// cached for the previous page: the title is blanked, the ref table is
+// replaced with an empty map and the last state is cleared. A nil tab is a
+// no-op.
+func storeNavigation(tab *sessionTab, finalURL string) {
+	if tab == nil {
+		return
+	}
+	landed := strings.TrimSpace(finalURL)
+
+	tab.mu.Lock()
+	defer tab.mu.Unlock()
+	tab.lastURL, tab.title = landed, ""
+	tab.lastState = nil
+	tab.refs = make(map[string]snapshotRef)
+}
+
+// waitOnTab blocks until the condition described by request holds on tab.
+//
+// Exactly one condition is used, chosen in this order: a fixed sleep
+// (request.Time > 0, interruptible through parent), a visible element
+// matching Selector, Text present in the body, TextGone absent from the body,
+// the location equal to URL, or the custom expression Fn. All page conditions
+// are polled every 200ms with a polling timeout of request.Timeout normalized
+// against fallbackTimeout, on a tab context that is also cancelled by parent.
+// An error is returned when no condition is set.
+func (session *Session) waitOnTab(parent context.Context, tab *sessionTab, request WaitRequest, fallbackTimeout time.Duration) error {
+	pollTimeout := normalizeTimeout(request.Timeout, fallbackTimeout)
+	if request.Time > 0 {
+		return sleepWithContext(parent, request.Time)
+	}
+	waitCtx, release, ctxErr := newTabRunContextWithParent(parent, tab, pollTimeout)
+	if ctxErr != nil {
+		return ctxErr
+	}
+	defer release()
+
+	// pollWith polls a one-argument page function until it returns truthy.
+	pollWith := func(pageFunction string, argument string) error {
+		return chromedp.Run(waitCtx, chromedp.PollFunction(
+			pageFunction,
+			nil,
+			chromedp.WithPollingInterval(200*time.Millisecond),
+			chromedp.WithPollingTimeout(pollTimeout),
+			chromedp.WithPollingArgs(argument),
+		))
+	}
+
+	if wantSelector := strings.TrimSpace(request.Selector); wantSelector != "" {
+		return pollWith(`(selector) => {
+ const el = document.querySelector(selector);
+ if (!el) return false;
+ const style = window.getComputedStyle(el);
+ const rect = el.getBoundingClientRect();
+ return style && style.visibility !== "hidden" && style.display !== "none" && rect.width > 0 && rect.height > 0;
+ }`, wantSelector)
+	}
+	if wantText := strings.TrimSpace(request.Text); wantText != "" {
+		return pollWith(`(expected) => {
+ const text = document.body ? document.body.innerText : "";
+ return text.includes(expected);
+ }`, wantText)
+	}
+	if goneText := strings.TrimSpace(request.TextGone); goneText != "" {
+		return pollWith(`(expected) => {
+ const text = document.body ? document.body.innerText : "";
+ return !text.includes(expected);
+ }`, goneText)
+	}
+	if wantURL := strings.TrimSpace(request.URL); wantURL != "" {
+		return pollWith(`(expected) => document.location.toString() === expected`, wantURL)
+	}
+	if expression := strings.TrimSpace(request.Fn); expression != "" {
+		return chromedp.Run(waitCtx, chromedp.Poll(
+			expression,
+			nil,
+			chromedp.WithPollingInterval(200*time.Millisecond),
+			chromedp.WithPollingTimeout(pollTimeout),
+		))
+	}
+	return errors.New("wait requires at least one of: timeMs, text, textGone, selector, url, fn")
+}
+
+// waitForTabExecutionContext waits until JavaScript can be evaluated in tab
+// and returns the page's document.readyState ("unknown" when it is blank).
+//
+// The evaluation is retried every 150ms, each attempt limited to 800ms or the
+// remaining time, until timeout (normalized with a 3s fallback) is used up.
+// On expiry the last evaluation error is returned, or
+// context.DeadlineExceeded when no attempt was made. A nil tab yields "tab
+// unavailable"; a tab whose context has been torn down yields the error from
+// newTabRunContext without further retries.
+func waitForTabExecutionContext(tab *sessionTab, timeout time.Duration) (string, error) {
+	if tab == nil {
+		return "", errors.New("tab unavailable")
+	}
+	timeout = normalizeTimeout(timeout, 3*time.Second)
+	deadline := time.Now().Add(timeout)
+	var lastErr error
+	for {
+		remaining := time.Until(deadline)
+		if remaining <= 0 {
+			if lastErr != nil {
+				return "", lastErr
+			}
+			return "", context.DeadlineExceeded
+		}
+		attemptTimeout := minDuration(remaining, 800*time.Millisecond)
+		// Go through newTabRunContext so tab.ctx is read under the tab lock
+		// and a context cleared by cancelTabContexts is reported as an error
+		// instead of panicking in context.WithTimeout.
+		runCtx, cancel, ctxErr := newTabRunContext(tab, attemptTimeout)
+		if ctxErr != nil {
+			return "", ctxErr
+		}
+		var readyState string
+		err := chromedp.Run(runCtx, chromedp.EvaluateAsDevTools(`document.readyState`, &readyState))
+		cancel()
+		if err == nil {
+			readyState = strings.TrimSpace(readyState)
+			if readyState == "" {
+				readyState = "unknown"
+			}
+			return readyState, nil
+		}
+		lastErr = err
+		time.Sleep(150 * time.Millisecond)
+	}
+}
+
+// shouldRetryActionState reports whether result looks like it was captured
+// too early and is worth collecting again: the URL is blank or satisfies
+// isReusablePageURL, no state is attached and it is flagged unavailable, or
+// the state is attached but has neither a title nor any items.
+func shouldRetryActionState(result ActionResult) bool {
+	pageURL := strings.TrimSpace(result.URL)
+	switch {
+	case pageURL == "", isReusablePageURL(pageURL):
+		return true
+	case result.State == nil:
+		// Missing state is only retried when it was reported as unavailable.
+		return !result.StateAvailable
+	}
+	return result.State.ItemCount == 0 && strings.TrimSpace(result.Title) == ""
+}
+
+// actionStateReason names why result is considered unsettled, for logging.
+// The checks mirror shouldRetryActionState and are evaluated in order:
+// "missing-url", "placeholder-url", "state-unavailable" (suffixed with
+// ":<StateError>" when one is set), "state-missing",
+// "empty-title-and-items", and finally "stable".
+func actionStateReason(result ActionResult) string {
+	pageURL := strings.TrimSpace(result.URL)
+	if pageURL == "" {
+		return "missing-url"
+	}
+	if isReusablePageURL(pageURL) {
+		return "placeholder-url"
+	}
+	if result.State == nil {
+		if result.StateAvailable {
+			return "state-missing"
+		}
+		if detail := strings.TrimSpace(result.StateError); detail != "" {
+			return "state-unavailable:" + detail
+		}
+		return "state-unavailable"
+	}
+	if strings.TrimSpace(result.Title) == "" && result.State.ItemCount == 0 {
+		return "empty-title-and-items"
+	}
+	return "stable"
+}
+
+// stabilizeActionResult re-collects the action result every 150ms for as long
+// as shouldRetryActionState says it is unsettled, for at most maxWait.
+//
+// Each retry uses collectActionResult with stateRequired=false. If a retry
+// fails, a warning is logged and the last good result is returned. A nil tab
+// or maxWait <= 0 returns result unchanged.
+func (session *Session) stabilizeActionResult(tab *sessionTab, result ActionResult, limit int, timeout time.Duration, maxWait time.Duration) ActionResult {
+	if tab == nil || maxWait <= 0 {
+		return result
+	}
+	deadline := time.Now().Add(maxWait)
+	for attempt := 1; shouldRetryActionState(result) && time.Now().Before(deadline); attempt++ {
+		time.Sleep(150 * time.Millisecond)
+		next, err := session.collectActionResult(tab, limit, timeout, false)
+		if err != nil {
+			zap.L().Warn(
+				"browser action state stabilization failed",
+				append(session.logFields(),
+					zap.String("targetId", tab.TargetID),
+					zap.Int("attempts", attempt),
+					zap.String("reason", actionStateReason(result)),
+					zap.Error(err),
+				)...,
+			)
+			return result
+		}
+		result = next
+	}
+	return result
+}
+
+// actClick clicks the element that request.Ref resolves to.
+//
+// The injected script scrolls the element into view and schedules, via
+// setTimeout, a focus followed by synthetic pointer/mouse down-up events and
+// el.click(). The script returns before the click is dispatched. Default
+// timeout is 15s. Ref resolution errors are returned unchanged.
+func (session *Session) actClick(tab *sessionTab, request ActRequest) error {
+	selector, resolveErr := resolveRefSelector(tab, request.Ref)
+	if resolveErr != nil {
+		return resolveErr
+	}
+	const clickTemplate = `(() => {
+ const el = document.querySelector(%q);
+ if (!el) throw new Error("element not found");
+ el.scrollIntoView({ block: "center", inline: "center" });
+ const invoke = () => {
+ el.focus?.();
+ if (typeof PointerEvent === "function") {
+ el.dispatchEvent(new PointerEvent("pointerdown", { bubbles: true, cancelable: true, pointerType: "mouse", isPrimary: true, button: 0, buttons: 1 }));
+ el.dispatchEvent(new PointerEvent("pointerup", { bubbles: true, cancelable: true, pointerType: "mouse", isPrimary: true, button: 0, buttons: 0 }));
+ }
+ el.dispatchEvent(new MouseEvent("mousedown", { bubbles: true, cancelable: true, button: 0, buttons: 1 }));
+ el.dispatchEvent(new MouseEvent("mouseup", { bubbles: true, cancelable: true, button: 0, buttons: 0 }));
+ el.click();
+ };
+ setTimeout(invoke, 0);
+ return true;
+ })()`
+	click := chromedp.EvaluateAsDevTools(fmt.Sprintf(clickTemplate, selector), nil)
+	return session.runOnTab(tab, normalizeTimeout(request.Timeout, 15*time.Second), click)
+}
+
+// actType appends request.Text to the value of the element that request.Ref
+// resolves to, then dispatches "input" and "change" events. Default timeout
+// is 15s. Ref resolution errors are returned unchanged.
+//
+// NOTE(review): the text is embedded with Go's %q, which is Go string syntax
+// rather than JavaScript's; unusual control characters may not round-trip.
+// TODO confirm whether callers can pass such input.
+func (session *Session) actType(tab *sessionTab, request ActRequest) error {
+	selector, resolveErr := resolveRefSelector(tab, request.Ref)
+	if resolveErr != nil {
+		return resolveErr
+	}
+	const typeTemplate = `(() => {
+ const el = document.querySelector(%q);
+ if (!el) throw new Error("element not found");
+ el.focus();
+ const nextValue = String((el.value ?? "")) + %q;
+ el.value = nextValue;
+ el.dispatchEvent(new InputEvent("input", { bubbles: true, data: %q, inputType: "insertText" }));
+ el.dispatchEvent(new Event("change", { bubbles: true }));
+ return nextValue;
+ })()`
+	typing := chromedp.EvaluateAsDevTools(fmt.Sprintf(typeTemplate, selector, request.Text, request.Text), nil)
+	return session.runOnTab(tab, normalizeTimeout(request.Timeout, 15*time.Second), typing)
+}
+
+// actPress sends request.Key to the focused element (or the body).
+//
+// The injected script schedules, via setTimeout, synthetic keydown, keypress
+// and keyup events; for "Enter" on an element that belongs to a form, the
+// form is submitted between keypress and keyup. A blank key is rejected with
+// "key is required". Default timeout is 15s.
+func (session *Session) actPress(tab *sessionTab, request ActRequest) error {
+	if strings.TrimSpace(request.Key) == "" {
+		return errors.New("key is required")
+	}
+	const pressTemplate = `(() => {
+ const target = document.activeElement || document.body;
+ const key = %q;
+ setTimeout(() => {
+ const down = new KeyboardEvent("keydown", { key, bubbles: true });
+ const press = new KeyboardEvent("keypress", { key, bubbles: true });
+ const up = new KeyboardEvent("keyup", { key, bubbles: true });
+ target.dispatchEvent(down);
+ target.dispatchEvent(press);
+ if (key === "Enter" && target && target.form) {
+ if (typeof target.form.requestSubmit === "function") {
+ target.form.requestSubmit();
+ } else {
+ target.form.submit();
+ }
+ }
+ target.dispatchEvent(up);
+ }, 0);
+ return true;
+ })()`
+	press := chromedp.EvaluateAsDevTools(fmt.Sprintf(pressTemplate, request.Key), nil)
+	return session.runOnTab(tab, normalizeTimeout(request.Timeout, 15*time.Second), press)
+}
+
+// actHover dispatches synthetic "mouseover" and "mouseenter" events on the
+// element that request.Ref resolves to. Default timeout is 15s. Ref
+// resolution errors are returned unchanged.
+func (session *Session) actHover(tab *sessionTab, request ActRequest) error {
+	selector, resolveErr := resolveRefSelector(tab, request.Ref)
+	if resolveErr != nil {
+		return resolveErr
+	}
+	const hoverTemplate = `(() => { const el = document.querySelector(%q); if (!el) throw new Error("element not found"); el.dispatchEvent(new MouseEvent("mouseover", {bubbles:true})); el.dispatchEvent(new MouseEvent("mouseenter", {bubbles:true})); })()`
+	hover := chromedp.EvaluateAsDevTools(fmt.Sprintf(hoverTemplate, selector), nil)
+	return session.runOnTab(tab, normalizeTimeout(request.Timeout, 15*time.Second), hover)
+}
+
+// actSelect sets the value of the element that request.Ref resolves to to
+// request.Value and dispatches "input" and "change" events. Default timeout
+// is 15s. Ref resolution errors are returned unchanged.
+func (session *Session) actSelect(tab *sessionTab, request ActRequest) error {
+	selector, resolveErr := resolveRefSelector(tab, request.Ref)
+	if resolveErr != nil {
+		return resolveErr
+	}
+	const selectTemplate = `(() => { const el = document.querySelector(%q); if (!el) throw new Error("element not found"); el.value = %q; el.dispatchEvent(new Event("input", {bubbles:true})); el.dispatchEvent(new Event("change", {bubbles:true})); })()`
+	choose := chromedp.EvaluateAsDevTools(fmt.Sprintf(selectTemplate, selector, request.Value), nil)
+	return session.runOnTab(tab, normalizeTimeout(request.Timeout, 15*time.Second), choose)
+}
+
+// actFill focuses the element that request.Ref resolves to, replaces its
+// value with request.Value and dispatches "input" and "change" events.
+// Default timeout is 15s. Ref resolution errors are returned unchanged.
+func (session *Session) actFill(tab *sessionTab, request ActRequest) error {
+	selector, resolveErr := resolveRefSelector(tab, request.Ref)
+	if resolveErr != nil {
+		return resolveErr
+	}
+	const fillTemplate = `(() => { const el = document.querySelector(%q); if (!el) throw new Error("element not found"); el.focus(); el.value = %q; el.dispatchEvent(new Event("input", {bubbles:true})); el.dispatchEvent(new Event("change", {bubbles:true})); })()`
+	fill := chromedp.EvaluateAsDevTools(fmt.Sprintf(fillTemplate, selector, request.Value), nil)
+	return session.runOnTab(tab, normalizeTimeout(request.Timeout, 15*time.Second), fill)
+}
+
+// actResize overrides the tab's device metrics with request.Width x
+// request.Height at device scale factor 1, non-mobile. Both dimensions must
+// be positive ("width is required" / "height is required" otherwise).
+// Default timeout is 15s.
+func (session *Session) actResize(tab *sessionTab, request ActRequest) error {
+	switch {
+	case request.Width <= 0:
+		return errors.New("width is required")
+	case request.Height <= 0:
+		return errors.New("height is required")
+	}
+	width, height := int64(request.Width), int64(request.Height)
+	applyMetrics := func(ctx context.Context) error {
+		return emulation.SetDeviceMetricsOverride(width, height, 1, false).Do(ctx)
+	}
+	return session.runOnTabFunc(tab, normalizeTimeout(request.Timeout, 15*time.Second), applyMetrics)
+}
+
+// actEvaluate evaluates request.Expression in the tab and stores the decoded
+// result on the tab, where evaluateResult can read it. A blank expression is
+// rejected with "expression is required". Default timeout is 15s. On
+// evaluation failure the stored result is left unchanged.
+func (session *Session) actEvaluate(tab *sessionTab, request ActRequest) error {
+	if strings.TrimSpace(request.Expression) == "" {
+		return errors.New("expression is required")
+	}
+	var value any
+	evalErr := session.runOnTab(tab, normalizeTimeout(request.Timeout, 15*time.Second), chromedp.Evaluate(request.Expression, &value))
+	if evalErr != nil {
+		return evalErr
+	}
+	tab.mu.Lock()
+	defer tab.mu.Unlock()
+	tab.evaluateResult = value
+	return nil
+}
+
+func evaluateResult(tab *sessionTab) any {
+ if tab == nil {
+ return nil
+ }
+ tab.mu.RLock()
+ defer tab.mu.RUnlock()
+ return tab.evaluateResult
+}
+
+// tabURL returns tab.lastURL, read under the tab's read lock. A nil tab yields
+// the empty string.
+func tabURL(tab *sessionTab) string {
+	current := ""
+	if tab != nil {
+		tab.mu.RLock()
+		current = tab.lastURL
+		tab.mu.RUnlock()
+	}
+	return current
+}
+
+// preferredPageURL returns the first candidate (after trimming) that is not a
+// placeholder page according to isReusablePageURL. If every non-empty candidate
+// is a placeholder, the first non-empty one is returned; if all are empty the
+// result is "".
+func preferredPageURL(candidates ...string) string {
+	firstNonEmpty := ""
+	for index := range candidates {
+		trimmed := strings.TrimSpace(candidates[index])
+		switch {
+		case trimmed == "":
+			continue
+		case !isReusablePageURL(trimmed):
+			return trimmed
+		case firstNonEmpty == "":
+			firstNonEmpty = trimmed
+		}
+	}
+	return firstNonEmpty
+}
+
+func waitForText(tab *sessionTab, expected string, timeout time.Duration, gone bool) error {
+ expected = strings.TrimSpace(expected)
+ runCtx, cancel := context.WithTimeout(tab.ctx, timeout)
+ defer cancel()
+ err := chromedp.Run(runCtx, chromedp.PollFunction(
+ `(expected, gone) => {
+ const text = document.body ? document.body.innerText : "";
+ const contains = text.includes(expected);
+ return gone ? !contains : contains;
+ }`,
+ nil,
+ chromedp.WithPollingInterval(200*time.Millisecond),
+ chromedp.WithPollingTimeout(timeout),
+ chromedp.WithPollingArgs(expected, gone),
+ ))
+ if errors.Is(err, chromedp.ErrPollingTimeout) {
+ if gone {
+ return &WaitTimeoutError{Condition: "textGone"}
+ }
+ return &WaitTimeoutError{Condition: "text"}
+ }
+ return err
+}
+
+func waitForURL(tab *sessionTab, expected string, timeout time.Duration) error {
+ expected = strings.TrimSpace(expected)
+ runCtx, cancel := context.WithTimeout(tab.ctx, timeout)
+ defer cancel()
+ err := chromedp.Run(runCtx, chromedp.PollFunction(
+ `(expected) => document.location.toString() === expected`,
+ nil,
+ chromedp.WithPollingInterval(200*time.Millisecond),
+ chromedp.WithPollingTimeout(timeout),
+ chromedp.WithPollingArgs(expected),
+ ))
+ if errors.Is(err, chromedp.ErrPollingTimeout) {
+ return &WaitTimeoutError{Condition: "url"}
+ }
+ return err
+}
+
+func waitForEvaluateCondition(tab *sessionTab, fn string, timeout time.Duration) error {
+ runCtx, cancel := context.WithTimeout(tab.ctx, timeout)
+ defer cancel()
+ err := chromedp.Run(runCtx, chromedp.Poll(
+ fn,
+ nil,
+ chromedp.WithPollingInterval(200*time.Millisecond),
+ chromedp.WithPollingTimeout(timeout),
+ ))
+ if errors.Is(err, chromedp.ErrPollingTimeout) {
+ return &WaitTimeoutError{Condition: "fn"}
+ }
+ return err
+}
+
+// pickReusableTargetID selects a "page" target to reuse. Preferences are tried
+// from strictest to loosest: unattached placeholder page, any unattached page,
+// attached placeholder page, any page. For each preference only the first
+// matching target is considered; if its ID is blank the next preference is
+// tried. Returns "" when nothing qualifies.
+func pickReusableTargetID(infos []*targetpkg.Info) string {
+	preferences := [...]struct {
+		unattachedOnly bool
+		placeholderOnly bool
+	}{
+		{unattachedOnly: true, placeholderOnly: true},
+		{unattachedOnly: true, placeholderOnly: false},
+		{unattachedOnly: false, placeholderOnly: true},
+		{unattachedOnly: false, placeholderOnly: false},
+	}
+	for _, preference := range preferences {
+		for _, info := range infos {
+			if info == nil || info.Type != "page" {
+				continue
+			}
+			if preference.unattachedOnly && info.Attached {
+				continue
+			}
+			if preference.placeholderOnly && !isReusablePageURL(info.URL) {
+				continue
+			}
+			if id := string(info.TargetID); strings.TrimSpace(id) != "" {
+				return id
+			}
+			// First match had a blank ID: fall through to the next preference.
+			break
+		}
+	}
+	return ""
+}
+
+// isReusablePageURL reports whether rawURL (case-insensitive, surrounding
+// whitespace ignored) is empty or one of the known blank/new-tab pages.
+func isReusablePageURL(rawURL string) bool {
+	normalized := strings.TrimSpace(strings.ToLower(rawURL))
+	placeholders := [...]string{
+		"",
+		"about:blank",
+		"chrome://newtab/",
+		"chrome-search://local-ntp/local-ntp.html",
+	}
+	for _, placeholder := range placeholders {
+		if normalized == placeholder {
+			return true
+		}
+	}
+	return false
+}
+
+// shouldTreatNavigationAsComplete decides whether observedURL indicates that a
+// navigation has landed. An empty or about:blank observation never counts. An
+// observation equal to targetURL always counts. Otherwise it counts when there
+// was no meaningful previous URL, or when the observation differs from it.
+func shouldTreatNavigationAsComplete(observedURL string, previousURL string, targetURL string) bool {
+	current := strings.TrimSpace(observedURL)
+	switch current {
+	case "", "about:blank":
+		return false
+	}
+	if urlsEqual(current, targetURL) {
+		return true
+	}
+	switch before := strings.TrimSpace(previousURL); before {
+	case "", "about:blank":
+		return true
+	default:
+		return !urlsEqual(current, before)
+	}
+}
+
+// urlsEqual compares two URLs as plain strings after trimming surrounding
+// whitespace; no URL normalization is performed.
+func urlsEqual(left string, right string) bool {
+	trimmedLeft, trimmedRight := strings.TrimSpace(left), strings.TrimSpace(right)
+	return trimmedLeft == trimmedRight
+}
+
+// cookieSyncKeys returns the sorted, de-duplicated, lower-cased keys under
+// which a cookie sync for rawURL and records is tracked: the URL's hostname (if
+// it parses and has one) plus every record domain with a single leading "."
+// removed. If that yields nothing, the lower-cased trimmed rawURL itself is
+// used when non-empty.
+func cookieSyncKeys(rawURL string, records []appcookies.Record) []string {
+	seen := map[string]struct{}{}
+	add := func(key string) {
+		if key != "" {
+			seen[key] = struct{}{}
+		}
+	}
+
+	if parsed, err := url.Parse(strings.TrimSpace(rawURL)); err == nil {
+		add(strings.ToLower(strings.TrimSpace(parsed.Hostname())))
+	}
+	for index := range records {
+		scope := strings.ToLower(strings.TrimSpace(records[index].Domain))
+		add(strings.TrimPrefix(scope, "."))
+	}
+	if len(seen) == 0 {
+		add(strings.ToLower(strings.TrimSpace(rawURL)))
+	}
+
+	keys := make([]string, 0, len(seen))
+	for key := range seen {
+		keys = append(keys, key)
+	}
+	sort.Strings(keys)
+	return keys
+}
+
+// hasCookieSyncFingerprint reports whether any of keys (trimmed, lower-cased)
+// maps in values to exactly fingerprint (stored value trimmed before the
+// comparison). Empty inputs always yield false.
+func hasCookieSyncFingerprint(values map[string]string, keys []string, fingerprint string) bool {
+	if fingerprint == "" || len(keys) == 0 || len(values) == 0 {
+		return false
+	}
+	for index := range keys {
+		lookup := strings.ToLower(strings.TrimSpace(keys[index]))
+		if strings.TrimSpace(values[lookup]) == fingerprint {
+			return true
+		}
+	}
+	return false
+}
+
+// rememberCookieSyncFingerprint records fingerprint in values under every
+// non-empty key (trimmed and lower-cased).
+//
+// It is a no-op when values is nil, keys is empty, or fingerprint is empty.
+// The nil check matters because assigning into a nil map panics, and the
+// companion lookup hasCookieSyncFingerprint already tolerates a nil map.
+func rememberCookieSyncFingerprint(values map[string]string, keys []string, fingerprint string) {
+	if values == nil || len(keys) == 0 || fingerprint == "" {
+		return
+	}
+	for _, key := range keys {
+		key = strings.ToLower(strings.TrimSpace(key))
+		if key == "" {
+			continue
+		}
+		values[key] = fingerprint
+	}
+}
+
+// isRecoverableCookieSyncError reports whether err's message (case-insensitive)
+// contains one of the substrings this package treats as a transient
+// browser/context failure during cookie sync. A nil error is not recoverable.
+func isRecoverableCookieSyncError(err error) bool {
+	if err == nil {
+		return false
+	}
+	text := strings.ToLower(strings.TrimSpace(err.Error()))
+	transientMarkers := [...]string{
+		"context deadline exceeded",
+		"browser context unavailable",
+		"browser executor unavailable",
+		"target closed",
+		"session closed",
+		"connection closed",
+		"execution context was destroyed",
+		"cannot find context with specified id",
+		"unique context id not found",
+	}
+	for _, marker := range transientMarkers {
+		if strings.Contains(text, marker) {
+			return true
+		}
+	}
+	return false
+}
+
+// wrapRuntimeHangError wraps err in *FatalError when its message
+// (case-insensitive) contains one of the listed timeout/closed-connection
+// substrings; any other error is returned unchanged, and nil stays nil.
+func wrapRuntimeHangError(err error) error {
+	if err == nil {
+		return nil
+	}
+	text := strings.ToLower(strings.TrimSpace(err.Error()))
+	fatalMarkers := [...]string{
+		"context deadline exceeded",
+		"target closed",
+		"connection closed",
+		"session closed",
+		"websocket",
+	}
+	for _, marker := range fatalMarkers {
+		if strings.Contains(text, marker) {
+			return &FatalError{Err: err}
+		}
+	}
+	return err
+}
+
+// cookieFingerprint builds a string that identifies a set of cookie records by
+// domain, path, name and value. The input slice is copied before sorting, so
+// the caller's slice is not reordered. No records yields "".
+//
+// Records are ordered by domain, path, name and finally value. The value
+// tie-break keeps the fingerprint independent of input order even when several
+// records share the same domain/path/name, which an unstable sort on the first
+// three fields alone does not guarantee.
+func cookieFingerprint(records []appcookies.Record) string {
+	if len(records) == 0 {
+		return ""
+	}
+	items := append([]appcookies.Record(nil), records...)
+	sort.Slice(items, func(i, j int) bool {
+		left := items[i]
+		right := items[j]
+		switch {
+		case left.Domain != right.Domain:
+			return left.Domain < right.Domain
+		case left.Path != right.Path:
+			return left.Path < right.Path
+		case left.Name != right.Name:
+			return left.Name < right.Name
+		default:
+			return left.Value < right.Value
+		}
+	})
+	parts := make([]string, 0, len(items))
+	for _, item := range items {
+		// Domain, path and name are trimmed; the value is kept verbatim.
+		parts = append(parts, strings.Join([]string{
+			strings.TrimSpace(item.Domain),
+			strings.TrimSpace(item.Path),
+			strings.TrimSpace(item.Name),
+			item.Value,
+		}, "\n"))
+	}
+	return strings.Join(parts, "\n---\n")
+}
+
+func normalizeTimeout(value time.Duration, fallback time.Duration) time.Duration {
+ if value <= 0 {
+ value = fallback
+ }
+ if value <= 0 {
+ value = 15 * time.Second
+ }
+ if value < 500*time.Millisecond {
+ return 500 * time.Millisecond
+ }
+ if value > 120*time.Second {
+ return 120 * time.Second
+ }
+ return value
+}
+
+func captureTimeout(timeout time.Duration) time.Duration {
+ timeout = normalizeTimeout(timeout, 15*time.Second)
+ scaled := timeout / 5
+ if scaled < time.Second {
+ scaled = time.Second
+ }
+ if scaled > 5*time.Second {
+ scaled = 5 * time.Second
+ }
+ return scaled
+}
+
+// sleepWithContext blocks for delay, returning early with ctx.Err() if ctx is
+// done first. A non-positive delay returns immediately; a nil ctx falls back
+// to an uninterruptible time.Sleep.
+func sleepWithContext(ctx context.Context, delay time.Duration) error {
+	switch {
+	case delay <= 0:
+		return nil
+	case ctx == nil:
+		time.Sleep(delay)
+		return nil
+	}
+	wake := time.NewTimer(delay)
+	defer wake.Stop()
+	select {
+	case <-wake.C:
+		return nil
+	case <-ctx.Done():
+		return ctx.Err()
+	}
+}
+
+// logFields returns the zap fields that identify this session in log output:
+// profile name, preferred browser and headless flag, all taken from a snapshot
+// of the session options.
+func (session *Session) logFields() []zap.Field {
+	snapshot := session.optionsSnapshot()
+	fields := make([]zap.Field, 0, 3)
+	fields = append(fields,
+		zap.String("profile", snapshot.ProfileName),
+		zap.String("preferredBrowser", snapshot.PreferredBrowser),
+		zap.Bool("headless", snapshot.Headless),
+	)
+	return fields
+}
+
+// sanitizedURLField builds a zap string field named key whose value is rawURL
+// after passing through sanitizeLogURL.
+func sanitizedURLField(key string, rawURL string) zap.Field {
+	safeURL := sanitizeLogURL(rawURL)
+	return zap.String(key, safeURL)
+}
+
+// sanitizeLogURL returns rawURL with userinfo, query string and fragment
+// removed so it can be written to logs. Blank input yields "".
+//
+// When the URL cannot be parsed, the raw string is NOT returned as-is (it could
+// still carry credentials or tokens); it is redacted textually instead.
+func sanitizeLogURL(rawURL string) string {
+	trimmed := strings.TrimSpace(rawURL)
+	if trimmed == "" {
+		return ""
+	}
+	parsed, err := url.Parse(trimmed)
+	if err != nil {
+		return redactUnparsedLogURL(trimmed)
+	}
+	parsed.User = nil
+	parsed.RawQuery = ""
+	parsed.Fragment = ""
+	return parsed.String()
+}
+
+// redactUnparsedLogURL strips everything from the first "?" or "#" onward and
+// drops a "userinfo@" prefix from the authority of a string that url.Parse
+// rejected. It is purely textual and best-effort.
+func redactUnparsedLogURL(raw string) string {
+	if cut := strings.IndexAny(raw, "?#"); cut >= 0 {
+		raw = raw[:cut]
+	}
+	prefix, rest := "", raw
+	if schemeEnd := strings.Index(raw, "://"); schemeEnd >= 0 {
+		prefix, rest = raw[:schemeEnd+3], raw[schemeEnd+3:]
+	}
+	authority, path := rest, ""
+	if slash := strings.IndexByte(rest, '/'); slash >= 0 {
+		authority, path = rest[:slash], rest[slash:]
+	}
+	if at := strings.LastIndexByte(authority, '@'); at >= 0 {
+		authority = authority[at+1:]
+	}
+	return prefix + authority + path
+}
+
+// resultItemCount returns result.State.ItemCount when a state is attached and
+// its count is positive; otherwise it returns the length of result.Items.
+func resultItemCount(result ActionResult) int {
+	if state := result.State; state != nil {
+		if count := state.ItemCount; count > 0 {
+			return count
+		}
+	}
+	return len(result.Items)
+}
+
+// normalizedSnapshotLimit returns limit when it is positive and
+// defaultSnapshotLimit otherwise.
+func normalizedSnapshotLimit(limit int) int {
+	if limit <= 0 {
+		return defaultSnapshotLimit
+	}
+	return limit
+}
+
+// minDuration returns the smaller of two durations, treating a non-positive
+// operand as "unset": if left is unset the result is right, and if only right
+// is unset the result is left.
+func minDuration(left time.Duration, right time.Duration) time.Duration {
+	switch {
+	case left <= 0:
+		return right
+	case right <= 0:
+		return left
+	case right < left:
+		return right
+	default:
+		return left
+	}
+}
+
+// actInvalidatesState reports whether kind is one of the listed action kinds
+// (click, type, press, hover, select, fill, resize). Matching is exact.
+func actInvalidatesState(kind string) bool {
+	invalidating := [...]string{"click", "type", "press", "hover", "select", "fill", "resize"}
+	for _, candidate := range invalidating {
+		if kind == candidate {
+			return true
+		}
+	}
+	return false
+}
+
+// actNeedsSettle reports whether kind is one of the listed action kinds
+// (click, type, press, hover, select, fill, resize, wait). Matching is exact.
+func actNeedsSettle(kind string) bool {
+	settling := [...]string{"click", "type", "press", "hover", "select", "fill", "resize", "wait"}
+	for _, candidate := range settling {
+		if kind == candidate {
+			return true
+		}
+	}
+	return false
+}
+
+// actMayOpenNewTab reports whether kind is exactly "click" or "press".
+func actMayOpenNewTab(kind string) bool {
+	return kind == "click" || kind == "press"
+}
+
+// shouldDeferStateCapture reports whether err's message (case-insensitive)
+// contains one of the listed timeout / lost-execution-context substrings. A nil
+// error yields false.
+func shouldDeferStateCapture(err error) bool {
+	if err == nil {
+		return false
+	}
+	text := strings.ToLower(strings.TrimSpace(err.Error()))
+	deferMarkers := [...]string{
+		"context deadline exceeded",
+		"execution context was destroyed",
+		"cannot find context with specified id",
+		"unique context id not found",
+	}
+	for _, marker := range deferMarkers {
+		if strings.Contains(text, marker) {
+			return true
+		}
+	}
+	return false
+}
+
+func isTargetLookupError(err error) bool {
+ if err == nil {
+ return false
+ }
+ message := strings.ToLower(strings.TrimSpace(err.Error()))
+ return strings.Contains(message, "no target with given id found")
+}
+
+// isHostnameAllowed reports whether hostname (trimmed, lower-cased) is allowed
+// by policy: either it is a key of policy.AllowedHostnames, or it matches an
+// entry of policy.HostnameAllowlist. Allowlist entries match when equal to the
+// hostname, when of the form "*.suffix" and the hostname is the suffix itself
+// or ends in ".suffix", or otherwise when filepath.Match accepts them as a glob
+// (a malformed glob simply does not match). An empty hostname is never allowed.
+func isHostnameAllowed(hostname string, policy SSRFPolicy) bool {
+	host := strings.ToLower(strings.TrimSpace(hostname))
+	if host == "" {
+		return false
+	}
+	if _, listed := policy.AllowedHostnames[host]; listed {
+		return true
+	}
+	for _, raw := range policy.HostnameAllowlist {
+		entry := strings.ToLower(strings.TrimSpace(raw))
+		switch {
+		case entry == "":
+			continue
+		case entry == host:
+			return true
+		case strings.HasPrefix(entry, "*."):
+			base := strings.TrimPrefix(entry, "*.")
+			if strings.HasSuffix(host, "."+base) || host == base {
+				return true
+			}
+		default:
+			if ok, _ := filepath.Match(entry, host); ok {
+				return true
+			}
+		}
+	}
+	return false
+}
+
+// isPrivateOrLocalIP reports whether ip is loopback, private, link-local,
+// unspecified or multicast, or an IPv4 address in 100.64.0.0/10. A nil ip
+// yields false.
+func isPrivateOrLocalIP(ip net.IP) bool {
+	if ip == nil {
+		return false
+	}
+	switch {
+	case ip.IsLoopback(),
+		ip.IsPrivate(),
+		ip.IsLinkLocalUnicast(),
+		ip.IsLinkLocalMulticast(),
+		ip.IsUnspecified(),
+		ip.IsMulticast():
+		return true
+	}
+	v4 := ip.To4()
+	// First octet 100 with the top two bits of the second octet equal to 01,
+	// i.e. 100.64.0.0 through 100.127.255.255.
+	return v4 != nil && v4[0] == 100 && (v4[1]&0xC0) == 64
+}
diff --git a/internal/application/browsercdp/session_cookie_sync_test.go b/internal/application/browsercdp/session_cookie_sync_test.go
new file mode 100644
index 0000000..2c48581
--- /dev/null
+++ b/internal/application/browsercdp/session_cookie_sync_test.go
@@ -0,0 +1,120 @@
+package browsercdp
+
+import (
+ "errors"
+ "testing"
+
+ appcookies "dreamcreator/internal/application/cookies"
+)
+
+// Fixture URLs, hosts and the shared cookie domain used by the tests in this
+// file; they all live under the reserved example.test name.
+const (
+ testCookiePrimaryURL = "https://www.example.test/"
+ testCookieProfileURL = "https://space.example.test/profile/demo"
+ testCookiePrimaryHost = "www.example.test"
+ testCookieProfileHost = "space.example.test"
+ testCookieSharedDomain = ".example.test"
+)
+
+// TestCookieSyncKeys_IncludeCookieScopeDomains checks that cookieSyncKeys
+// yields the page hostname plus every cookie domain (leading dot removed), and
+// nothing else. The comparison is order-independent.
+func TestCookieSyncKeys_IncludeCookieScopeDomains(t *testing.T) {
+	t.Parallel()
+
+	records := []appcookies.Record{
+		{Name: "session_token", Domain: testCookieSharedDomain, Path: "/"},
+		{Name: "site_session", Domain: testCookiePrimaryHost, Path: "/"},
+	}
+	keys := cookieSyncKeys(testCookieProfileURL, records)
+
+	expected := map[string]struct{}{}
+	for _, want := range []string{testCookieProfileHost, "example.test", testCookiePrimaryHost} {
+		expected[want] = struct{}{}
+	}
+	if len(keys) != len(expected) {
+		t.Fatalf("expected %d keys, got %d: %#v", len(expected), len(keys), keys)
+	}
+	for _, key := range keys {
+		_, known := expected[key]
+		if !known {
+			t.Fatalf("unexpected key %q in %#v", key, keys)
+		}
+		delete(expected, key)
+	}
+	if len(expected) != 0 {
+		t.Fatalf("missing keys: %#v", expected)
+	}
+}
+
+func TestCookieSyncFingerprint_ReusesBroadCookieDomainAcrossSubdomains(t *testing.T) {
+ t.Parallel()
+
+ fingerprint := "same-cookie-set"
+ values := map[string]string{}
+ initialKeys := cookieSyncKeys(testCookiePrimaryURL, []appcookies.Record{
+ {Name: "session_token", Domain: testCookieSharedDomain, Path: "/"},
+ })
+ rememberCookieSyncFingerprint(values, initialKeys, fingerprint)
+
+ nextKeys := cookieSyncKeys(testCookieProfileURL, []appcookies.Record{
+ {Name: "session_token", Domain: testCookieSharedDomain, Path: "/"},
+ })
+ if !hasCookieSyncFingerprint(values, nextKeys, fingerprint) {
+ t.Fatalf("expected sync fingerprint to be reused across subdomains: values=%#v keys=%#v", values, nextKeys)
+ }
+}
+
+func TestIsRecoverableCookieSyncError(t *testing.T) {
+ t.Parallel()
+
+ cases := []struct {
+ name string
+ err error
+ want bool
+ }{
+ {name: "timeout", err: errors.New("connector cookie sync failed: context deadline exceeded"), want: true},
+ {name: "destroyed", err: errors.New("execution context was destroyed"), want: true},
+ {name: "closed", err: errors.New("target closed"), want: true},
+ {name: "validation", err: errors.New("invalid cookie domain"), want: false},
+ }
+ for _, tc := range cases {
+ tc := tc
+ t.Run(tc.name, func(t *testing.T) {
+ t.Parallel()
+ if got := isRecoverableCookieSyncError(tc.err); got != tc.want {
+ t.Fatalf("expected %v, got %v for %v", tc.want, got, tc.err)
+ }
+ })
+ }
+}
+
+func TestPreferredPageURL(t *testing.T) {
+ t.Parallel()
+
+ cases := []struct {
+ name string
+ candidates []string
+ want string
+ }{
+ {
+ name: "preserve known non placeholder over about blank",
+ candidates: []string{"about:blank", testCookiePrimaryURL},
+ want: testCookiePrimaryURL,
+ },
+ {
+ name: "prefer observed real url",
+ candidates: []string{testCookieProfileURL, testCookiePrimaryURL},
+ want: testCookieProfileURL,
+ },
+ {
+ name: "fallback to requested real url",
+ candidates: []string{"about:blank", "", "https://www.example.test"},
+ want: "https://www.example.test",
+ },
+ }
+ for _, tc := range cases {
+ tc := tc
+ t.Run(tc.name, func(t *testing.T) {
+ t.Parallel()
+ if got := preferredPageURL(tc.candidates...); got != tc.want {
+ t.Fatalf("expected %q, got %q for %#v", tc.want, got, tc.candidates)
+ }
+ })
+ }
+}
diff --git a/internal/application/browsercdp/session_open_live_test.go b/internal/application/browsercdp/session_open_live_test.go
new file mode 100644
index 0000000..14608bf
--- /dev/null
+++ b/internal/application/browsercdp/session_open_live_test.go
@@ -0,0 +1,411 @@
+package browsercdp
+
+import (
+ "context"
+ "fmt"
+ "net/http"
+ "net/http/httptest"
+ "os"
+ "strings"
+ "testing"
+ "time"
+
+ targetpkg "github.com/chromedp/cdproto/target"
+ "github.com/chromedp/chromedp"
+ "go.uber.org/zap"
+)
+
+// TestSessionOpenLive is an opt-in probe that opens a URL in a real headless
+// browser and logs what came back. It is skipped unless
+// DREAMCREATOR_BROWSER_OPEN_LIVE=1 and a browser is detected. The target URL
+// comes from DREAMCREATOR_BROWSER_OPEN_URL (default https://example.com).
+//
+// After opening, the same target is probed through the tab's own context and
+// through fresh chromedp contexts derived from the browser context and from the
+// allocator context; probe failures are logged rather than failing the test.
+func TestSessionOpenLive(t *testing.T) {
+	if os.Getenv("DREAMCREATOR_BROWSER_OPEN_LIVE") != "1" {
+		t.Skip("set DREAMCREATOR_BROWSER_OPEN_LIVE=1 to run the live browser open probe")
+	}
+
+	logger, err := zap.NewDevelopment()
+	if err != nil {
+		t.Fatalf("create logger: %v", err)
+	}
+	defer func() {
+		_ = logger.Sync()
+	}()
+	restore := zap.ReplaceGlobals(logger)
+	defer restore()
+
+	status := ResolveStatus("", true)
+	if !status.Ready {
+		t.Skipf("browser not available: %s", strings.TrimSpace(status.DetectError))
+	}
+
+	targetURL := strings.TrimSpace(os.Getenv("DREAMCREATOR_BROWSER_OPEN_URL"))
+	if targetURL == "" {
+		targetURL = "https://example.com"
+	}
+
+	registry := NewSessionRegistry()
+	session := registry.GetOrCreate("live-open", "dreamcreator", SessionOptions{
+		SessionKey:       "live-open",
+		ProfileName:      "dreamcreator",
+		PreferredBrowser: strings.TrimSpace(status.ChosenBrowser),
+		Headless:         true,
+	})
+	defer session.stop()
+
+	ctx, cancel := context.WithTimeout(context.Background(), 45*time.Second)
+	defer cancel()
+
+	result, err := session.Open(ctx, targetURL, CommandOptions{
+		Timeout: 30 * time.Second,
+		Limit:   50,
+	})
+	if err != nil {
+		t.Fatalf("session open failed: %v", err)
+	}
+
+	// Use the package helper so the test reports the same item count as
+	// production code instead of re-implementing the State/Items fallback.
+	itemCount := resultItemCount(result)
+	t.Logf(
+		"open result ok=%v targetId=%s url=%s hasTitle=%v stateAvailable=%v stateVersion=%d itemCount=%d stateError=%s",
+		result.OK,
+		result.TargetID,
+		sanitizeLogURL(result.URL),
+		strings.TrimSpace(result.Title) != "",
+		result.State != nil || result.StateAvailable,
+		result.StateVersion,
+		itemCount,
+		result.StateError,
+	)
+
+	tab, resolveErr := session.resolveTab(result.TargetID, true)
+	if resolveErr != nil {
+		t.Fatalf("resolve tab: %v", resolveErr)
+	}
+	probeTargetContext(t, "tab.ctx", tab.ctx)
+
+	// Each fresh context is probed twice.
+	browserProbeCtx, browserProbeCancel := chromedp.NewContext(session.runtime.BrowserContext(), chromedp.WithTargetID(targetpkg.ID(result.TargetID)))
+	defer browserProbeCancel()
+	probeTargetContext(t, "fresh-browserCtx#1", browserProbeCtx)
+	probeTargetContext(t, "fresh-browserCtx#2", browserProbeCtx)
+
+	allocProbeCtx, allocProbeCancel := chromedp.NewContext(session.runtime.allocCtx, chromedp.WithTargetID(targetpkg.ID(result.TargetID)))
+	defer allocProbeCancel()
+	probeTargetContext(t, "fresh-allocCtx#1", allocProbeCtx)
+	probeTargetContext(t, "fresh-allocCtx#2", allocProbeCtx)
+}
+
+// TestSessionWorkflowLive is an opt-in end-to-end probe that drives a real
+// headless browser through open -> type -> click -> navigate against a local
+// httptest server. It is skipped unless DREAMCREATOR_BROWSER_WORKFLOW_LIVE=1
+// and a browser is detected. The session allows private-network targets
+// because the test server listens on a local address.
+//
+// NOTE(review): the page bodies served below contain no markup as shown here,
+// yet the test looks up a textbox and a button in the snapshot — confirm the
+// HTML literals were not stripped when this patch was exported.
+func TestSessionWorkflowLive(t *testing.T) {
+ if os.Getenv("DREAMCREATOR_BROWSER_WORKFLOW_LIVE") != "1" {
+ t.Skip("set DREAMCREATOR_BROWSER_WORKFLOW_LIVE=1 to run the live browser workflow probe")
+ }
+
+ logger, err := zap.NewDevelopment()
+ if err != nil {
+ t.Fatalf("create logger: %v", err)
+ }
+ defer func() {
+ _ = logger.Sync()
+ }()
+ restore := zap.ReplaceGlobals(logger)
+ defer restore()
+
+ status := ResolveStatus("", true)
+ if !status.Ready {
+ t.Skipf("browser not available: %s", strings.TrimSpace(status.DetectError))
+ }
+
+ // Local site with three routes: "/", "/search" (echoes the q parameter) and "/next".
+ server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.Header().Set("Content-Type", "text/html; charset=utf-8")
+ switch r.URL.Path {
+ case "/":
+ _, _ = fmt.Fprint(w, `
+
+ Workflow Home
+
+
+ Next Page
+
+`)
+ case "/search":
+ query := r.URL.Query().Get("q")
+ _, _ = fmt.Fprintf(w, `
+
+ Results %s
+
+ Result for %s
+ Next Page
+
+`, query, query)
+ case "/next":
+ _, _ = fmt.Fprint(w, `
+
+ Next Page
+
+
+
+`)
+ default:
+ http.NotFound(w, r)
+ }
+ }))
+ defer server.Close()
+
+ registry := NewSessionRegistry()
+ session := registry.GetOrCreate("live-workflow", "dreamcreator", SessionOptions{
+ SessionKey: "live-workflow",
+ ProfileName: "dreamcreator",
+ PreferredBrowser: strings.TrimSpace(status.ChosenBrowser),
+ Headless: true,
+ SSRFRules: SSRFPolicy{
+ DangerouslyAllowPrivateNetwork: true,
+ },
+ })
+ defer session.stop()
+
+ ctx, cancel := context.WithTimeout(context.Background(), 45*time.Second)
+ defer cancel()
+
+ openResult, err := session.Open(ctx, server.URL, CommandOptions{
+ Timeout: 20 * time.Second,
+ Limit: 50,
+ })
+ if err != nil {
+ t.Fatalf("open failed: %v", err)
+ }
+ // A failing post-open State call is only logged; the open result's items are used instead.
+ stateResult, stateErr := session.State(openResult.TargetID, 50)
+ if stateErr != nil {
+ t.Logf("post-open state failed: %v", stateErr)
+ } else {
+ t.Logf("post-open state url=%s hasTitle=%v items=%d", sanitizeLogURL(stateResult.URL), strings.TrimSpace(stateResult.Title) != "", len(stateResult.Items))
+ }
+ tab, resolveErr := session.resolveTab(openResult.TargetID, true)
+ if resolveErr != nil {
+ t.Fatalf("resolve workflow tab: %v", resolveErr)
+ }
+ probeTargetContext(t, "workflow-tab.ctx", tab.ctx)
+ browserProbeCtx, browserProbeCancel := chromedp.NewContext(session.runtime.BrowserContext(), chromedp.WithTargetID(targetpkg.ID(openResult.TargetID)))
+ defer browserProbeCancel()
+ probeTargetContext(t, "workflow-fresh-browserCtx", browserProbeCtx)
+
+ // Prefer the items from the explicit State call when it succeeded and returned any.
+ currentItems := openResult.Items
+ if stateErr == nil && len(stateResult.Items) > 0 {
+ currentItems = stateResult.Items
+ }
+ inputRef := findSnapshotRef(currentItems, "textbox", "search weather")
+ if inputRef == "" {
+ t.Fatalf("textbox ref not found in latest snapshot: %+v", currentItems)
+ }
+ buttonRef := findSnapshotRef(currentItems, "button", "search")
+ if buttonRef == "" {
+ t.Fatalf("button ref not found in latest snapshot: %+v", currentItems)
+ }
+
+ typeResult, err := session.Act(ctx, ActRequest{
+ Kind: "type",
+ TargetID: openResult.TargetID,
+ Ref: inputRef,
+ Text: "weather",
+ Limit: 50,
+ Timeout: 15 * time.Second,
+ })
+ if err != nil {
+ t.Fatalf("type failed: %v", err)
+ }
+ // Look the button up again in the post-type snapshot; fall back to the
+ // earlier items when the type result carries none.
+ nextItems := typeResult.Items
+ if len(nextItems) == 0 {
+ nextItems = currentItems
+ }
+ buttonRef = findSnapshotRef(nextItems, "button", "search")
+ if buttonRef == "" {
+ t.Fatalf("button ref not found after type: %+v", nextItems)
+ }
+
+ searchURL := server.URL + "/search?q=weather"
+ clickResult, err := session.Act(ctx, ActRequest{
+ Kind: "click",
+ TargetID: openResult.TargetID,
+ Ref: buttonRef,
+ WaitFor: &WaitRequest{
+ URL: searchURL,
+ Timeout: 8 * time.Second,
+ },
+ Limit: 50,
+ Timeout: 20 * time.Second,
+ })
+ if err != nil {
+ t.Fatalf("click failed: %v", err)
+ }
+ if !strings.Contains(clickResult.URL, "/search?q=weather") {
+ t.Fatalf("unexpected search url: %s", sanitizeLogURL(clickResult.URL))
+ }
+ if !strings.Contains(clickResult.Title, "Results weather") {
+ t.Fatalf("unexpected search title")
+ }
+
+ navigateResult, err := session.Navigate(ctx, clickResult.TargetID, server.URL+"/next", false, CommandOptions{
+ Timeout: 20 * time.Second,
+ Limit: 50,
+ })
+ if err != nil {
+ t.Fatalf("navigate failed: %v", err)
+ }
+ if !strings.Contains(navigateResult.URL, "/next") {
+ t.Fatalf("unexpected navigate url: %s", sanitizeLogURL(navigateResult.URL))
+ }
+ if strings.TrimSpace(navigateResult.Title) != "Next Page" {
+ t.Fatalf("unexpected navigate title")
+ }
+
+ t.Logf("workflow open/type/click/navigate succeeded targetId=%s finalURL=%s hasTitle=%v", navigateResult.TargetID, sanitizeLogURL(navigateResult.URL), strings.TrimSpace(navigateResult.Title) != "")
+}
+
+// TestSessionBaiduSearchLive is an opt-in probe that opens
+// https://www.baidu.com in a real browser, types a query into the first
+// textbox in the snapshot and clicks the search button. It is skipped unless
+// DREAMCREATOR_BROWSER_BAIDU_LIVE=1 and a browser is detected.
+//
+// DREAMCREATOR_BROWSER_BAIDU_HEADLESS=1 runs headless (any other value runs
+// headed); DREAMCREATOR_BROWSER_BAIDU_QUERY overrides the default query.
+func TestSessionBaiduSearchLive(t *testing.T) {
+ if os.Getenv("DREAMCREATOR_BROWSER_BAIDU_LIVE") != "1" {
+ t.Skip("set DREAMCREATOR_BROWSER_BAIDU_LIVE=1 to run the live Baidu browser probe")
+ }
+
+ logger, err := zap.NewDevelopment()
+ if err != nil {
+ t.Fatalf("create logger: %v", err)
+ }
+ defer func() {
+ _ = logger.Sync()
+ }()
+ restore := zap.ReplaceGlobals(logger)
+ defer restore()
+
+ status := ResolveStatus("", false)
+ if !status.Ready {
+ t.Skipf("browser not available: %s", strings.TrimSpace(status.DetectError))
+ }
+
+ headless := strings.TrimSpace(os.Getenv("DREAMCREATOR_BROWSER_BAIDU_HEADLESS")) == "1"
+ query := strings.TrimSpace(os.Getenv("DREAMCREATOR_BROWSER_BAIDU_QUERY"))
+ if query == "" {
+ query = "天气"
+ }
+
+ registry := NewSessionRegistry()
+ session := registry.GetOrCreate("live-baidu", "dreamcreator", SessionOptions{
+ SessionKey: "live-baidu",
+ ProfileName: "dreamcreator",
+ PreferredBrowser: strings.TrimSpace(status.ChosenBrowser),
+ Headless: headless,
+ })
+ defer session.stop()
+
+ ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second)
+ defer cancel()
+
+ openResult, err := session.Open(ctx, "https://www.baidu.com", CommandOptions{
+ Timeout: 30 * time.Second,
+ Limit: 50,
+ })
+ if err != nil {
+ t.Fatalf("open failed: %v", err)
+ }
+
+ // An empty needle makes findSnapshotRef return the first item with role "textbox".
+ inputRef := findSnapshotRef(openResult.Items, "textbox", "")
+ if inputRef == "" {
+ t.Fatalf("textbox ref not found in open items: %+v", openResult.Items)
+ }
+
+ typeResult, err := session.Act(ctx, ActRequest{
+ Kind: "type",
+ TargetID: openResult.TargetID,
+ Ref: inputRef,
+ Text: query,
+ Limit: 50,
+ Timeout: 15 * time.Second,
+ })
+ if err != nil {
+ t.Fatalf("type failed: %v", err)
+ }
+
+ // Prefer the post-type snapshot; fall back to the open snapshot when it is empty.
+ buttonItems := typeResult.Items
+ if len(buttonItems) == 0 {
+ buttonItems = openResult.Items
+ }
+ // Try the full button label first, then a shorter substring of it.
+ buttonRef := findSnapshotRef(buttonItems, "button", "百度一下")
+ if buttonRef == "" {
+ buttonRef = findSnapshotRef(buttonItems, "button", "百度")
+ }
+ if buttonRef == "" {
+ t.Fatalf("search button ref not found after type: %+v", buttonItems)
+ }
+
+ clickResult, err := session.Act(ctx, ActRequest{
+ Kind: "click",
+ TargetID: openResult.TargetID,
+ Ref: buttonRef,
+ Limit: 50,
+ Timeout: 20 * time.Second,
+ })
+ if err != nil {
+ t.Fatalf("click failed: %v", err)
+ }
+
+ t.Logf(
+ "baidu click result targetId=%s url=%s hasTitle=%v stateAvailable=%v itemCount=%d stateError=%s",
+ clickResult.TargetID,
+ sanitizeLogURL(clickResult.URL),
+ strings.TrimSpace(clickResult.Title) != "",
+ clickResult.State != nil || clickResult.StateAvailable,
+ len(clickResult.Items),
+ clickResult.StateError,
+ )
+ if !strings.Contains(clickResult.URL, "baidu.com") {
+ t.Fatalf("unexpected click url: %s", sanitizeLogURL(clickResult.URL))
+ }
+ if !clickResult.StateAvailable && clickResult.State == nil {
+ t.Fatalf("click state unavailable: %s", clickResult.StateError)
+ }
+}
+
+// probeTargetContext evaluates a small script through ctx (bounded to 3s) that
+// reads the document's URL, title and readyState, and logs the outcome under
+// label. A failed probe is logged, not treated as a test failure.
+func probeTargetContext(t *testing.T, label string, ctx context.Context) {
+	t.Helper()
+
+	boundedCtx, release := context.WithTimeout(ctx, 3*time.Second)
+	defer release()
+
+	type pageInfo struct {
+		URL        string `json:"url"`
+		Title      string `json:"title"`
+		ReadyState string `json:"readyState"`
+	}
+	var info pageInfo
+	probe := chromedp.EvaluateAsDevTools(`({
+ url: document.location.toString(),
+ title: document.title,
+ readyState: document.readyState,
+ })`, &info)
+	if probeErr := chromedp.Run(boundedCtx, probe); probeErr != nil {
+		t.Logf("%s probe failed: %v", label, probeErr)
+		return
+	}
+	hasTitle := strings.TrimSpace(info.Title) != ""
+	t.Logf(
+		"%s probe url=%s hasTitle=%v readyState=%s",
+		label,
+		sanitizeLogURL(info.URL),
+		hasTitle,
+		strings.TrimSpace(info.ReadyState),
+	)
+}
+
+// findSnapshotRef returns the trimmed Ref of the first item whose role equals
+// role (when role is non-empty) and whose Name or Text contains needle.
+// Comparisons are case-insensitive and ignore surrounding whitespace; an empty
+// needle matches any item. Returns "" when nothing matches.
+func findSnapshotRef(items []SnapshotItem, role string, needle string) string {
+	wantRole := strings.TrimSpace(strings.ToLower(role))
+	wantText := strings.TrimSpace(strings.ToLower(needle))
+	for index := range items {
+		candidate := items[index]
+		if wantRole != "" && strings.TrimSpace(strings.ToLower(candidate.Role)) != wantRole {
+			continue
+		}
+		if wantText != "" {
+			label := strings.ToLower(strings.TrimSpace(candidate.Name))
+			body := strings.ToLower(strings.TrimSpace(candidate.Text))
+			if !strings.Contains(label, wantText) && !strings.Contains(body, wantText) {
+				continue
+			}
+		}
+		return strings.TrimSpace(candidate.Ref)
+	}
+	return ""
+}
diff --git a/internal/application/browsercdp/session_registry_test.go b/internal/application/browsercdp/session_registry_test.go
new file mode 100644
index 0000000..e5529c0
--- /dev/null
+++ b/internal/application/browsercdp/session_registry_test.go
@@ -0,0 +1,68 @@
+package browsercdp
+
+import "testing"
+
+// TestSessionRegistryCloseSessionKeyRemovesSessions checks that CloseSessionKey
+// drops the whole bucket for one session key, leaves other keys untouched, and
+// that a later GetOrCreate for the closed key yields a new session.
+func TestSessionRegistryCloseSessionKeyRemovesSessions(t *testing.T) {
+	t.Parallel()
+
+	registry := NewSessionRegistry()
+	first := registry.GetOrCreate("session-a", "dreamcreator", SessionOptions{})
+	second := registry.GetOrCreate("session-a", "work", SessionOptions{})
+	other := registry.GetOrCreate("session-b", "dreamcreator", SessionOptions{})
+	if first == nil || second == nil || other == nil {
+		t.Fatalf("expected sessions to be created")
+	}
+
+	registry.CloseSessionKey("session-a")
+
+	// Inspect the registry's internal map under its own lock.
+	registry.mu.Lock()
+	_, closedKeyStillPresent := registry.sessions["session-a"]
+	survivingBucket := registry.sessions["session-b"]
+	registry.mu.Unlock()
+
+	if closedKeyStillPresent {
+		t.Fatalf("expected session-a bucket to be removed")
+	}
+	if len(survivingBucket) != 1 {
+		t.Fatalf("expected session-b bucket to remain intact, got %#v", survivingBucket)
+	}
+
+	recreated := registry.GetOrCreate("session-a", "dreamcreator", SessionOptions{})
+	switch {
+	case recreated == nil:
+		t.Fatalf("expected recreated session")
+	case recreated == first:
+		t.Fatalf("expected session-a to be recreated after cleanup")
+	}
+}
+
+// TestSessionRegistryCloseAllRemovesAllSessions checks that CloseAll empties
+// the registry and that a later GetOrCreate yields a new session rather than
+// the one that existed before.
+func TestSessionRegistryCloseAllRemovesAllSessions(t *testing.T) {
+	t.Parallel()
+
+	registry := NewSessionRegistry()
+	first := registry.GetOrCreate("session-a", "dreamcreator", SessionOptions{})
+	second := registry.GetOrCreate("session-b", "work", SessionOptions{})
+	if first == nil || second == nil {
+		t.Fatalf("expected sessions to be created")
+	}
+
+	registry.CloseAll()
+
+	// Inspect the registry's internal map under its own lock.
+	registry.mu.Lock()
+	remaining := len(registry.sessions)
+	registry.mu.Unlock()
+	if remaining != 0 {
+		t.Fatalf("expected all sessions to be removed, got %d buckets", remaining)
+	}
+
+	recreated := registry.GetOrCreate("session-a", "dreamcreator", SessionOptions{})
+	switch {
+	case recreated == nil:
+		t.Fatalf("expected recreated session")
+	case recreated == first:
+		t.Fatalf("expected session to be recreated after close all")
+	}
+}
diff --git a/internal/application/browsercdp/session_ssrf_test.go b/internal/application/browsercdp/session_ssrf_test.go
new file mode 100644
index 0000000..6a9c61f
--- /dev/null
+++ b/internal/application/browsercdp/session_ssrf_test.go
@@ -0,0 +1,78 @@
+package browsercdp
+
+import (
+ "context"
+ "net"
+ "strings"
+ "testing"
+)
+
+// stubLookupIPAddrsForHost swaps the package-level lookupIPAddrsForHost for
+// stub and registers a cleanup that restores the previous function when the
+// test finishes.
+func stubLookupIPAddrsForHost(t *testing.T, stub func(context.Context, string) ([]net.IPAddr, error)) {
+	t.Helper()
+	previous := lookupIPAddrsForHost
+	t.Cleanup(func() {
+		lookupIPAddrsForHost = previous
+	})
+	lookupIPAddrsForHost = stub
+}
+
+// TestAssertURLAllowedRejectsHostnameResolvingToPrivateIP stubs DNS so a
+// public-looking hostname resolves to 127.0.0.1 and expects AssertURLAllowed to
+// reject it with a "resolving to private IP" error.
+func TestAssertURLAllowedRejectsHostnameResolvingToPrivateIP(t *testing.T) {
+	resolveToLoopback := func(_ context.Context, host string) ([]net.IPAddr, error) {
+		if host != "public.example.com" {
+			t.Fatalf("unexpected host lookup %q", host)
+		}
+		return []net.IPAddr{{IP: net.ParseIP("127.0.0.1")}}, nil
+	}
+	stubLookupIPAddrsForHost(t, resolveToLoopback)
+
+	policy := SSRFPolicy{
+		DangerouslyAllowPrivateNetwork: false,
+		AllowedHostnames:               map[string]struct{}{},
+	}
+	err := AssertURLAllowed("https://public.example.com/path", policy)
+	blocked := err != nil && strings.Contains(err.Error(), "resolving to private IP")
+	if !blocked {
+		t.Fatalf("expected resolved private IP to be blocked, got %v", err)
+	}
+}
+
+// TestAssertURLAllowedAllowsHostnameResolvingToPublicIP stubs DNS so the
+// hostname resolves to a public address and expects AssertURLAllowed to accept
+// the URL.
+func TestAssertURLAllowedAllowsHostnameResolvingToPublicIP(t *testing.T) {
+	resolveToPublic := func(_ context.Context, host string) ([]net.IPAddr, error) {
+		if host != "public.example.com" {
+			t.Fatalf("unexpected host lookup %q", host)
+		}
+		return []net.IPAddr{{IP: net.ParseIP("93.184.216.34")}}, nil
+	}
+	stubLookupIPAddrsForHost(t, resolveToPublic)
+
+	policy := SSRFPolicy{
+		DangerouslyAllowPrivateNetwork: false,
+		AllowedHostnames:               map[string]struct{}{},
+	}
+	if err := AssertURLAllowed("https://public.example.com/path", policy); err != nil {
+		t.Fatalf("expected public hostname to be allowed, got %v", err)
+	}
+}
+
+// TestAssertURLAllowedAllowsExplicitHostnameAllowlist expects a hostname listed
+// in AllowedHostnames to be accepted without any DNS lookup, even though it is
+// a local host and private networks are not otherwise allowed.
+func TestAssertURLAllowedAllowsExplicitHostnameAllowlist(t *testing.T) {
+	failOnLookup := func(context.Context, string) ([]net.IPAddr, error) {
+		t.Fatalf("allowlisted hostname should not require DNS validation")
+		return nil, nil
+	}
+	stubLookupIPAddrsForHost(t, failOnLookup)
+
+	policy := SSRFPolicy{
+		DangerouslyAllowPrivateNetwork: false,
+		AllowedHostnames: map[string]struct{}{
+			"localhost": {},
+		},
+	}
+	if err := AssertURLAllowed("http://localhost:3000", policy); err != nil {
+		t.Fatalf("expected explicit hostname allowlist to bypass private host block, got %v", err)
+	}
+}
+
+// TestAssertRequestURLAllowedRejectsPrivateWebsocketTargets expects a wss://
+// request to a loopback IP literal to be rejected with a "blocked private IP"
+// error.
+func TestAssertRequestURLAllowedRejectsPrivateWebsocketTargets(t *testing.T) {
+	policy := SSRFPolicy{
+		DangerouslyAllowPrivateNetwork: false,
+		AllowedHostnames:               map[string]struct{}{},
+	}
+	err := assertRequestURLAllowed(context.Background(), "wss://127.0.0.1/socket", policy)
+	blocked := err != nil && strings.Contains(err.Error(), "blocked private IP")
+	if !blocked {
+		t.Fatalf("expected private websocket target to be blocked, got %v", err)
+	}
+}
diff --git a/internal/application/browsercdp/session_wait_test.go b/internal/application/browsercdp/session_wait_test.go
new file mode 100644
index 0000000..d038025
--- /dev/null
+++ b/internal/application/browsercdp/session_wait_test.go
@@ -0,0 +1,33 @@
+package browsercdp
+
+import (
+ "context"
+ "errors"
+ "testing"
+ "time"
+)
+
+// TestWaitOnTabTimeRespectsParentCancel checks that a time-based wait returns
+// context.Canceled soon after the parent context is cancelled instead of
+// running for the full requested duration.
+func TestWaitOnTabTimeRespectsParentCancel(t *testing.T) {
+	const requested = 2 * time.Second
+
+	session := &Session{}
+	parent, cancel := context.WithCancel(context.Background())
+	// Cancel well before the requested wait would elapse.
+	time.AfterFunc(50*time.Millisecond, cancel)
+
+	begin := time.Now()
+	err := session.waitOnTab(parent, nil, WaitRequest{Time: requested}, requested)
+	took := time.Since(begin)
+
+	if !errors.Is(err, context.Canceled) {
+		t.Fatalf("expected context canceled, got %v", err)
+	}
+	if took >= 500*time.Millisecond {
+		t.Fatalf("expected wait to stop early, elapsed=%s", took)
+	}
+}
+
+// TestSanitizeLogURLRemovesSensitiveParts checks that user info, query string
+// and fragment are all dropped by sanitizeLogURL.
+func TestSanitizeLogURLRemovesSensitiveParts(t *testing.T) {
+	const want = "https://example.com/path"
+
+	got := sanitizeLogURL("https://user:secret@example.com/path?q=token#frag")
+	if got != want {
+		t.Fatalf("unexpected sanitized url %q", got)
+	}
+}
diff --git a/internal/application/browsercdp/user_data_dir.go b/internal/application/browsercdp/user_data_dir.go
new file mode 100644
index 0000000..e25e743
--- /dev/null
+++ b/internal/application/browsercdp/user_data_dir.go
@@ -0,0 +1,33 @@
+package browsercdp
+
+import (
+ "os"
+ "path/filepath"
+ "strings"
+
+ "github.com/google/uuid"
+)
+
+// browserProfileNamespace is the fixed namespace UUID passed to uuid.NewSHA1
+// when ResolveProfileStorageKey derives a profile storage key.
+var browserProfileNamespace = uuid.MustParse("d8f7f362-4b0f-48d8-b8ac-0dcf3a6af6e2")
+
+// ResolveProfileStorageKey derives a deterministic UUID string from a session
+// key and a profile name.
+//
+// Both inputs are whitespace-trimmed; an empty session key falls back to
+// "default" and an empty profile name to "dreamcreator". The two values are
+// joined with a NUL byte and hashed with uuid.NewSHA1 under
+// browserProfileNamespace.
+func ResolveProfileStorageKey(sessionKey string, profileName string) string {
+	orFallback := func(raw string, fallback string) string {
+		if trimmed := strings.TrimSpace(raw); trimmed != "" {
+			return trimmed
+		}
+		return fallback
+	}
+
+	seed := orFallback(sessionKey, "default") + "\x00" + orFallback(profileName, "dreamcreator")
+	return uuid.NewSHA1(browserProfileNamespace, []byte(seed)).String()
+}
+
+// ResolveProfileUserDataDir returns the path
+// <os.TempDir()>/dreamcreator/browser/profiles/<storage key> for the given
+// session key and profile name. It only builds the path string.
+func ResolveProfileUserDataDir(sessionKey string, profileName string) string {
+	segments := []string{
+		os.TempDir(),
+		"dreamcreator",
+		"browser",
+		"profiles",
+		ResolveProfileStorageKey(sessionKey, profileName),
+	}
+	return filepath.Join(segments...)
+}
diff --git a/internal/application/browsercdp/user_data_dir_test.go b/internal/application/browsercdp/user_data_dir_test.go
new file mode 100644
index 0000000..d897135
--- /dev/null
+++ b/internal/application/browsercdp/user_data_dir_test.go
@@ -0,0 +1,75 @@
+package browsercdp
+
+import (
+ "os"
+ "path/filepath"
+ "strings"
+ "testing"
+
+ "github.com/google/uuid"
+)
+
+// TestResolveProfileStorageKeyStableAndDistinct checks that the storage key is
+// identical for identical inputs, differs between profile names, and parses
+// as a UUID.
+func TestResolveProfileStorageKeyStableAndDistinct(t *testing.T) {
+	t.Parallel()
+
+	const (
+		sessionKey  = "v2::-::aui::-::788e3ff2-b4ad-426d-be0f-1424156032fd::-::788e3ff2-b4ad-426d-be0f-1424156032fd"
+		profileName = "dreamcreator"
+	)
+
+	first := ResolveProfileStorageKey(sessionKey, profileName)
+	if second := ResolveProfileStorageKey(sessionKey, profileName); first != second {
+		t.Fatalf("expected stable storage key, got %q and %q", first, second)
+	}
+	if first == ResolveProfileStorageKey(sessionKey, "work") {
+		t.Fatalf("expected different profiles to use different storage keys")
+	}
+	if _, err := uuid.Parse(first); err != nil {
+		t.Fatalf("expected uuid storage key, got %q: %v", first, err)
+	}
+}
+
+// TestResolveProfileUserDataDirUsesShallowSafePath checks that the resolved
+// directory leaks neither raw input, ends in a name without reserved path
+// characters, and sits exactly four segments below os.TempDir() under
+// dreamcreator/browser/profiles.
+func TestResolveProfileUserDataDirUsesShallowSafePath(t *testing.T) {
+	t.Parallel()
+
+	const (
+		sessionKey  = "v2::-::aui::-::788e3ff2-b4ad-426d-be0f-1424156032fd::-::788e3ff2-b4ad-426d-be0f-1424156032fd"
+		profileName = "dream:creator"
+	)
+	dir := ResolveProfileUserDataDir(sessionKey, profileName)
+	leaf := filepath.Base(dir)
+
+	// t.Fatalf stops the test, so only the first failing check reports.
+	switch {
+	case strings.Contains(dir, sessionKey):
+		t.Fatalf("expected session key to stay out of path, got %q", dir)
+	case strings.Contains(dir, profileName):
+		t.Fatalf("expected profile name to stay out of path, got %q", dir)
+	case strings.ContainsAny(leaf, `<>:"/\|?*`):
+		t.Fatalf("expected final directory name to be path safe, got %q", leaf)
+	}
+
+	rel, err := filepath.Rel(os.TempDir(), dir)
+	if err != nil {
+		t.Fatalf("resolve relative path: %v", err)
+	}
+	segments := strings.Split(rel, string(os.PathSeparator))
+	if len(segments) != 4 {
+		t.Fatalf("expected shallow temp path, got %q (%d parts)", rel, len(segments))
+	}
+	layoutOK := segments[0] == "dreamcreator" && segments[1] == "browser" && segments[2] == "profiles"
+	if !layoutOK {
+		t.Fatalf("unexpected profile path layout: %q", rel)
+	}
+}
+
+// TestNormalizeSessionOptionsUsesEncodedProfileDir checks that
+// normalizeSessionOptions fills UserDataDir with the directory that
+// ResolveProfileUserDataDir computes for the same session key and profile.
+func TestNormalizeSessionOptionsUsesEncodedProfileDir(t *testing.T) {
+	t.Parallel()
+
+	const (
+		sessionKey  = "v2::-::aui::-::thread:1::-::thread:1"
+		profileName = "dream:creator"
+	)
+
+	normalized := normalizeSessionOptions(SessionOptions{
+		SessionKey:  sessionKey,
+		ProfileName: profileName,
+	})
+
+	if want := ResolveProfileUserDataDir(sessionKey, profileName); normalized.UserDataDir != want {
+		t.Fatalf("unexpected encoded user data dir: got %q want %q", normalized.UserDataDir, want)
+	}
+}
diff --git a/internal/application/channels/telegram/bot_service_approval_test.go b/internal/application/channels/telegram/bot_service_approval_test.go
index b3ad849..e678119 100644
--- a/internal/application/channels/telegram/bot_service_approval_test.go
+++ b/internal/application/channels/telegram/bot_service_approval_test.go
@@ -201,7 +201,7 @@ func TestShouldSuppressTelegramResolvedForward(t *testing.T) {
reason string
want bool
}{
- {name: "telegram with sender", reason: "telegram:5234834060", want: true},
+ {name: "telegram with sender", reason: "telegram:test-user-001", want: true},
{name: "telegram plain", reason: "telegram", want: true},
{name: "upper telegram", reason: "TeLeGrAm:1", want: true},
{name: "gateway reason", reason: "gateway:web", want: false},
diff --git a/internal/application/connectors/dto/models.go b/internal/application/connectors/dto/models.go
index 01225ea..ca819d2 100644
--- a/internal/application/connectors/dto/models.go
+++ b/internal/application/connectors/dto/models.go
@@ -8,6 +8,9 @@ type Connector struct {
Status string `json:"status"`
CookiesCount int `json:"cookiesCount"`
Cookies []ConnectorCookie `json:"cookies"`
+ Domains []string `json:"domains,omitempty"`
+ PolicyKey string `json:"policyKey,omitempty"`
+ Capabilities []string `json:"capabilities,omitempty"`
LastVerifiedAt string `json:"lastVerifiedAt"`
}
@@ -22,10 +25,51 @@ type ClearConnectorRequest struct {
ID string `json:"id"`
}
-type ConnectConnectorRequest struct {
+type StartConnectorConnectRequest struct {
ID string `json:"id"`
}
+// StartConnectorConnectResult is the JSON response for starting a connect
+// session: the new session's ID and the connector it was started for.
+type StartConnectorConnectResult struct {
+ SessionID string `json:"sessionId"`
+ Connector Connector `json:"connector"`
+}
+
+// FinishConnectorConnectRequest identifies, by session ID, the connect session
+// to finalize.
+type FinishConnectorConnectRequest struct {
+ SessionID string `json:"sessionId"`
+}
+
+// FinishConnectorConnectResult is the JSON outcome of finalizing a connect
+// session. Domains and Reason are omitted from the JSON when empty.
+// NOTE(review): presumably RawCookiesCount/FilteredCookiesCount are the cookie
+// counts before and after domain filtering and Saved reports whether cookies
+// were persisted — confirm against the service that fills this struct.
+type FinishConnectorConnectResult struct {
+ SessionID string `json:"sessionId"`
+ Saved bool `json:"saved"`
+ RawCookiesCount int `json:"rawCookiesCount"`
+ FilteredCookiesCount int `json:"filteredCookiesCount"`
+ Domains []string `json:"domains,omitempty"`
+ Reason string `json:"reason,omitempty"`
+ Connector Connector `json:"connector"`
+}
+
+// CancelConnectorConnectRequest identifies, by session ID, the connect session
+// to cancel.
+type CancelConnectorConnectRequest struct {
+ SessionID string `json:"sessionId"`
+}
+
+// ConnectorConnectSession is the JSON view of a connect session. Domains,
+// Reason, Error and LastCookiesAt are omitted from the JSON when empty.
+// NOTE(review): the code that populates this struct is not visible here; the
+// meaning of State and the timestamp format of LastCookiesAt should be
+// confirmed against the producer.
+type ConnectorConnectSession struct {
+ SessionID string `json:"sessionId"`
+ ConnectorID string `json:"connectorId"`
+ State string `json:"state"`
+ Saved bool `json:"saved"`
+ RawCookiesCount int `json:"rawCookiesCount"`
+ FilteredCookiesCount int `json:"filteredCookiesCount"`
+ Domains []string `json:"domains,omitempty"`
+ Reason string `json:"reason,omitempty"`
+ Error string `json:"error,omitempty"`
+ LastCookiesAt string `json:"lastCookiesAt,omitempty"`
+ Connector Connector `json:"connector"`
+}
+
+// GetConnectorConnectSessionRequest identifies, by session ID, the connect
+// session to look up.
+type GetConnectorConnectSessionRequest struct {
+ SessionID string `json:"sessionId"`
+}
+
type OpenConnectorSiteRequest struct {
ID string `json:"id"`
}
diff --git a/internal/application/connectors/service/cookies.go b/internal/application/connectors/service/cookies.go
index 491ab04..373a991 100644
--- a/internal/application/connectors/service/cookies.go
+++ b/internal/application/connectors/service/cookies.go
@@ -1,27 +1,14 @@
package service
import (
- "encoding/json"
- "strings"
"time"
- "github.com/playwright-community/playwright-go"
-
"dreamcreator/internal/application/connectors/dto"
+ appcookies "dreamcreator/internal/application/cookies"
+ "dreamcreator/internal/application/sitepolicy"
"dreamcreator/internal/domain/connectors"
)
-type cookieRecord struct {
- Name string `json:"name"`
- Value string `json:"value"`
- Domain string `json:"domain"`
- Path string `json:"path"`
- Expires int64 `json:"expires"`
- HttpOnly bool `json:"httpOnly"`
- Secure bool `json:"secure"`
- SameSite string `json:"sameSite,omitempty"`
-}
-
func mapConnectorDTO(item connectors.Connector) dto.Connector {
cookies := decodeCookies(item.CookiesJSON)
status := item.Status
@@ -34,6 +21,7 @@ func mapConnectorDTO(item connectors.Connector) dto.Connector {
if item.LastVerifiedAt != nil {
lastVerified = item.LastVerifiedAt.Format(time.RFC3339)
}
+ policy, _ := sitepolicy.ForConnectorType(string(item.Type))
return dto.Connector{
ID: item.ID,
Type: string(item.Type),
@@ -42,16 +30,23 @@ func mapConnectorDTO(item connectors.Connector) dto.Connector {
Status: string(status),
CookiesCount: len(cookies),
Cookies: mapCookiesDTO(cookies),
+ Domains: append([]string(nil), policy.Domains...),
+ PolicyKey: policy.Key,
+ Capabilities: append([]string(nil), policy.Capabilities...),
LastVerifiedAt: lastVerified,
}
}
func connectorGroup(connectorType connectors.ConnectorType) string {
switch connectorType {
- case connectors.ConnectorGoogle, connectors.ConnectorXiaohongshu:
+ case connectors.ConnectorGoogle, connectors.ConnectorZhihu:
return "search_engine"
+ case connectors.ConnectorXiaohongshu, connectors.ConnectorReddit, connectors.ConnectorX:
+ return "community"
case connectors.ConnectorBilibili:
return "video"
+ case connectors.ConnectorGitHub:
+ return "developer"
default:
return "other"
}
@@ -61,6 +56,14 @@ func connectorDesc(connectorType connectors.ConnectorType) string {
switch connectorType {
case connectors.ConnectorGoogle:
return "Global web search with broad multilingual coverage, suitable for general factual lookups."
+ case connectors.ConnectorGitHub:
+ return "Developer collaboration and code hosting, useful for repositories, issues, pull requests, and docs."
+ case connectors.ConnectorReddit:
+ return "Community discussions and long-tail troubleshooting threads, useful for niche product and user experience research."
+ case connectors.ConnectorZhihu:
+ return "Chinese knowledge-sharing community content, useful for answers, articles, and topic exploration."
+ case connectors.ConnectorX:
+ return "Real-time public conversation and creator timelines, useful for posts, threads, and timely updates."
case connectors.ConnectorXiaohongshu:
return "Chinese lifestyle and recommendation community content, useful for reviews and trend discovery."
case connectors.ConnectorBilibili:
@@ -70,7 +73,7 @@ func connectorDesc(connectorType connectors.ConnectorType) string {
}
}
-func mapCookiesDTO(records []cookieRecord) []dto.ConnectorCookie {
+func mapCookiesDTO(records []appcookies.Record) []dto.ConnectorCookie {
if len(records) == 0 {
return nil
}
@@ -90,99 +93,10 @@ func mapCookiesDTO(records []cookieRecord) []dto.ConnectorCookie {
return result
}
-func encodeCookies(records []cookieRecord) (string, error) {
- if len(records) == 0 {
- return "", nil
- }
- data, err := json.Marshal(records)
- if err != nil {
- return "", err
- }
- return string(data), nil
-}
-
-func decodeCookies(data string) []cookieRecord {
- trimmed := strings.TrimSpace(data)
- if trimmed == "" {
- return nil
- }
- var records []cookieRecord
- if err := json.Unmarshal([]byte(trimmed), &records); err != nil {
- return nil
- }
- return records
+// encodeCookies serializes cookie records by delegating to appcookies.EncodeJSON.
+func encodeCookies(records []appcookies.Record) (string, error) {
+ return appcookies.EncodeJSON(records)
}
-func cookiesFromPlaywright(cookies []playwright.Cookie) []cookieRecord {
- if len(cookies) == 0 {
- return nil
- }
- result := make([]cookieRecord, 0, len(cookies))
- for _, cookie := range cookies {
- sameSite := ""
- if cookie.SameSite != nil {
- sameSite = strings.ToLower(string(*cookie.SameSite))
- }
- result = append(result, cookieRecord{
- Name: cookie.Name,
- Value: cookie.Value,
- Domain: cookie.Domain,
- Path: cookie.Path,
- Expires: int64(cookie.Expires),
- HttpOnly: cookie.HttpOnly,
- Secure: cookie.Secure,
- SameSite: sameSite,
- })
- }
- return result
-}
-
-func toPlaywrightCookies(records []cookieRecord, targetURL string) []playwright.OptionalCookie {
- if len(records) == 0 {
- return nil
- }
- result := make([]playwright.OptionalCookie, 0, len(records))
- for _, record := range records {
- if strings.TrimSpace(record.Name) == "" {
- continue
- }
- cookie := playwright.OptionalCookie{
- Name: record.Name,
- Value: record.Value,
- }
- domain := strings.TrimSpace(record.Domain)
- path := strings.TrimSpace(record.Path)
- if domain != "" {
- cookie.Domain = playwright.String(domain)
- } else if strings.TrimSpace(targetURL) != "" {
- cookie.URL = playwright.String(strings.TrimSpace(targetURL))
- }
- if path != "" {
- cookie.Path = playwright.String(path)
- }
- if record.Expires > 0 {
- cookie.Expires = playwright.Float(float64(record.Expires))
- }
- cookie.HttpOnly = playwright.Bool(record.HttpOnly)
- cookie.Secure = playwright.Bool(record.Secure)
- if sameSite := mapSameSite(record.SameSite); sameSite != nil {
- cookie.SameSite = sameSite
- }
- result = append(result, cookie)
- }
- return result
-}
-
-func mapSameSite(value string) *playwright.SameSiteAttribute {
- normalized := strings.ToLower(strings.TrimSpace(value))
- switch normalized {
- case "lax":
- return playwright.SameSiteAttributeLax
- case "strict":
- return playwright.SameSiteAttributeStrict
- case "none":
- return playwright.SameSiteAttributeNone
- default:
- return nil
- }
+// decodeCookies parses stored cookie JSON by delegating to appcookies.DecodeJSON.
+func decodeCookies(data string) []appcookies.Record {
+ return appcookies.DecodeJSON(data)
}
diff --git a/internal/application/connectors/service/export.go b/internal/application/connectors/service/export.go
index 5033548..acf0e56 100644
--- a/internal/application/connectors/service/export.go
+++ b/internal/application/connectors/service/export.go
@@ -8,6 +8,7 @@ import (
"strconv"
"strings"
+ appcookies "dreamcreator/internal/application/cookies"
"dreamcreator/internal/domain/connectors"
)
@@ -58,7 +59,7 @@ func (service *ConnectorsService) ExportConnectorCookies(ctx context.Context, id
return path, nil
}
-func writeJSONCookies(path string, cookies []cookieRecord) error {
+func writeJSONCookies(path string, cookies []appcookies.Record) error {
data, err := json.MarshalIndent(cookies, "", " ")
if err != nil {
return err
@@ -66,7 +67,7 @@ func writeJSONCookies(path string, cookies []cookieRecord) error {
return writeFileAtomic(path, data, 0o600)
}
-func writeNetscapeCookies(path string, cookies []cookieRecord) error {
+func writeNetscapeCookies(path string, cookies []appcookies.Record) error {
builder := strings.Builder{}
builder.WriteString("# Netscape HTTP Cookie File\n")
builder.WriteString("# This file was generated by DreamCreator.\n")
diff --git a/internal/application/connectors/service/login.go b/internal/application/connectors/service/login.go
index 4a235ed..5f33033 100644
--- a/internal/application/connectors/service/login.go
+++ b/internal/application/connectors/service/login.go
@@ -2,66 +2,360 @@ package service
import (
"context"
+ "errors"
+ "fmt"
+ "log"
+ "os"
+ "path/filepath"
+ "sort"
"strings"
"time"
- "github.com/playwright-community/playwright-go"
+ cdptarget "github.com/chromedp/cdproto/target"
+ "github.com/chromedp/chromedp"
+ "dreamcreator/internal/application/browsercdp"
"dreamcreator/internal/application/connectors/dto"
+ appcookies "dreamcreator/internal/application/cookies"
+ "dreamcreator/internal/application/sitepolicy"
"dreamcreator/internal/domain/connectors"
)
-func (service *ConnectorsService) ConnectConnector(ctx context.Context, request dto.ConnectConnectorRequest) (dto.Connector, error) {
+// StartConnectorConnect launches a browser on the connector's home URL and
+// registers a running connect session for it.
+//
+// It rejects an empty (after trimming) request ID with ErrInvalidConnector,
+// loads the connector, resolves its home URL, starts a browser through
+// service.startBrowser with a per-session user data dir and navigates the tab.
+// If navigation fails, the tab context is cancelled, the runtime is stopped and
+// the user data dir is removed before the error is returned. On success the
+// session is stored, any session reported as replaced by putSession is cleaned
+// up, the session monitor is started, and the session ID plus the connector
+// DTO are returned.
+func (service *ConnectorsService) StartConnectorConnect(ctx context.Context, request dto.StartConnectorConnectRequest) (dto.StartConnectorConnectResult, error) {
id := strings.TrimSpace(request.ID)
if id == "" {
- return dto.Connector{}, connectors.ErrInvalidConnector
+ return dto.StartConnectorConnectResult{}, connectors.ErrInvalidConnector
}
connector, err := service.repo.Get(ctx, id)
if err != nil {
- return dto.Connector{}, err
+ return dto.StartConnectorConnectResult{}, err
}
targetURL, err := connectorHomeURL(connector.Type)
if err != nil {
- return dto.Connector{}, err
+ return dto.StartConnectorConnectResult{}, err
}
- cookies, err := runPlaywrightLogin(ctx, targetURL)
+
+ sessionID := service.newSessionID()
+ userDataDir := connectorSessionDir(connector.Type, sessionID)
+ runtime, tabCtx, cancel, err := service.startBrowser(service.preferredBrowser(ctx), false, userDataDir)
if err != nil {
- return dto.Connector{}, err
+ return dto.StartConnectorConnectResult{}, err
+ }
+ // Navigation failed: release everything acquired so far before returning.
+ if err := chromedp.Run(tabCtx, chromedp.Navigate(targetURL)); err != nil {
+ cancel()
+ runtime.Stop()
+ if service.removeAll != nil {
+ _ = service.removeAll(userDataDir)
+ }
+ return dto.StartConnectorConnectResult{}, err
+ }
+
+ session := &connectorSession{
+ ID: sessionID,
+ ConnectorID: connector.ID,
+ ConnectorType: connector.Type,
+ Runtime: runtime,
+ TabCtx: tabCtx,
+ Cancel: cancel,
+ UserDataDir: userDataDir,
+ State: connectorSessionStateRunning,
+ ConnectorSnapshot: mapConnectorDTO(connector),
+ finalizeDone: make(chan struct{}),
+ }
+ // Remember the tab's target ID (when available) so the monitor can detect its closure.
+ if current := chromedp.FromContext(tabCtx); current != nil && current.Target != nil {
+ session.TargetID = current.Target.TargetID
}
- cookiesJSON, err := encodeCookies(cookies)
+
+ replaced := service.putSession(session)
+ service.cleanupSession(replaced)
+ service.startConnectSessionMonitor(sessionID)
+ log.Printf("connectors: started connect session id=%s connector=%s target=%s userDataDir=%s", sessionID, connector.Type, session.TargetID, userDataDir)
+
+ return dto.StartConnectorConnectResult{
+ SessionID: sessionID,
+ Connector: mapConnectorDTO(connector),
+ }, nil
+}
+
+// FinishConnectorConnect finalizes the connect session named in the request
+// with reason "manual_finish" and returns the finalize result. An empty
+// (after trimming) session ID yields ErrConnectorSessionGone.
+func (service *ConnectorsService) FinishConnectorConnect(ctx context.Context, request dto.FinishConnectorConnectRequest) (dto.FinishConnectorConnectResult, error) {
+ sessionID := strings.TrimSpace(request.SessionID)
+ if sessionID == "" {
+ return dto.FinishConnectorConnectResult{}, connectors.ErrConnectorSessionGone
+ }
+ result, _, err := service.finalizeConnectSession(ctx, sessionID, "manual_finish")
if err != nil {
- return dto.Connector{}, err
+ return dto.FinishConnectorConnectResult{}, err
}
+ return result, nil
+}
- now := service.now()
- status := connectors.StatusDisconnected
- var lastVerifiedAt *time.Time
- if len(cookies) > 0 {
- status = connectors.StatusConnected
- lastVerifiedAt = &now
+// CancelConnectorConnect removes the connect session named in the request and
+// cleans it up. An empty (after trimming) session ID yields
+// ErrConnectorSessionGone; otherwise nil is returned, and the cancellation is
+// logged before the session is popped.
+func (service *ConnectorsService) CancelConnectorConnect(ctx context.Context, request dto.CancelConnectorConnectRequest) error {
+	if sessionID := strings.TrimSpace(request.SessionID); sessionID != "" {
+		log.Printf("connectors: canceled connect session id=%s", sessionID)
+		removed := service.popSession(sessionID)
+		service.cleanupSession(removed)
+		return nil
+	}
+	return connectors.ErrConnectorSessionGone
+}
+
+// finalizeConnectSession runs the finalize step for a session at most once and
+// returns its recorded outcome to every caller.
+//
+// The boolean result is true only for the call that actually executed
+// performFinalize. Returns ErrConnectorSessionGone when the session cannot be
+// found (before or after finalizing) and ErrConnectorSessionDead when
+// finalizing finished without recording a result or an error.
+func (service *ConnectorsService) finalizeConnectSession(ctx context.Context, sessionID string, reason string) (dto.FinishConnectorConnectResult, bool, error) {
+ session, ok := service.getSession(sessionID)
+ if !ok || session == nil {
+ return dto.FinishConnectorConnectResult{}, false, connectors.ErrConnectorSessionGone
+ }
+ triggered := false
+ // Only the first caller runs performFinalize (with its own ctx and reason);
+ // the outcome is stored on the session under service.mu.
+ session.finalizeOnce.Do(func() {
+ triggered = true
+ result, err := service.performFinalize(ctx, session, reason)
+ service.mu.Lock()
+ defer service.mu.Unlock()
+ if err != nil {
+ session.State = connectorSessionStateFailed
+ session.FinalError = err.Error()
+ } else {
+ session.State = connectorSessionStateCompleted
+ session.FinalError = ""
+ session.FinalResult = &result
+ session.ConnectorSnapshot = result.Connector
+ }
+ close(session.finalizeDone)
+ })
+ // Every caller waits for the finalizer to finish; this wait does not observe ctx.
+ <-session.finalizeDone
+
+ session, ok = service.getSession(sessionID)
+ if !ok || session == nil {
+ return dto.FinishConnectorConnectResult{}, triggered, connectors.ErrConnectorSessionGone
+ }
+ service.mu.Lock()
+ defer service.mu.Unlock()
+ // The stored failure is only a message, so the returned error is a fresh
+ // errors.New value rather than the original error.
+ if session.FinalError != "" {
+ return dto.FinishConnectorConnectResult{}, triggered, errors.New(session.FinalError)
+ }
+ if session.FinalResult == nil {
+ return dto.FinishConnectorConnectResult{}, triggered, connectors.ErrConnectorSessionDead
+ }
+ return *session.FinalResult, triggered, nil
+}
+
+// performFinalize captures the session's cookies, filters them by the
+// connector's site policy domains and, when any remain, saves them on the
+// connector as a connected status.
+//
+// Cookies are read live from the session runtime; if that read fails, the last
+// cookies cached on the session are used instead. The session is cleaned up on
+// every path after the cookie read (success, no matching cookies, or error).
+// The returned result reports raw and filtered counts, the filtered cookie
+// domains, the given reason and the connector DTO; Saved is true only when at
+// least one cookie survived filtering. Returns ErrConnectorSessionGone for a
+// nil session, otherwise any repository, encoding or validation error.
+func (service *ConnectorsService) performFinalize(ctx context.Context, session *connectorSession, reason string) (dto.FinishConnectorConnectResult, error) {
+ if session == nil {
+ return dto.FinishConnectorConnectResult{}, connectors.ErrConnectorSessionGone
+ }
+
+ log.Printf("connectors: finalize requested session=%s connector=%s reason=%s", session.ID, session.ConnectorType, reason)
+ records, err := readConnectorCookiesFromRuntime(session.Runtime)
+ if err != nil {
+ // Live read failed (e.g. browser already gone): fall back to the cached copy.
+ log.Printf("connectors: live cookie read failed session=%s connector=%s reason=%s err=%v", session.ID, session.ConnectorType, reason, err)
+ service.mu.Lock()
+ records = append([]appcookies.Record(nil), session.LastCookies...)
+ service.mu.Unlock()
+ } else {
+ service.updateSession(session.ID, func(current *connectorSession) {
+ current.LastCookies = append([]appcookies.Record(nil), records...)
+ current.LastCookiesAt = service.now()
+ })
+ }
+
+ policy, _ := sitepolicy.ForConnectorType(string(session.ConnectorType))
+ filtered := appcookies.FilterByDomains(records, policy.Domains)
+ log.Printf("connectors: finalize cookies session=%s connector=%s reason=%s raw=%d filtered=%d domains=%s", session.ID, session.ConnectorType, reason, len(records), len(filtered), strings.Join(cookieDomains(filtered), ","))
+
+ current, err := service.repo.Get(ctx, session.ConnectorID)
+ if err != nil {
+ service.cleanupSession(session)
+ return dto.FinishConnectorConnectResult{}, err
+ }
+
+ result := dto.FinishConnectorConnectResult{
+ SessionID: session.ID,
+ Saved: len(filtered) > 0,
+ RawCookiesCount: len(records),
+ FilteredCookiesCount: len(filtered),
+ Domains: cookieDomains(filtered),
+ Reason: reason,
+ Connector: mapConnectorDTO(current),
+ }
+ // Nothing matched the policy domains: leave the stored connector untouched.
+ if len(filtered) == 0 {
+ service.cleanupSession(session)
+ log.Printf("connectors: finalize completed without matching cookies session=%s connector=%s reason=%s", session.ID, session.ConnectorType, reason)
+ return result, nil
+ }
+
+ cookiesJSON, err := encodeCookies(filtered)
+ if err != nil {
+ service.cleanupSession(session)
+ return dto.FinishConnectorConnectResult{}, err
}
+ now := service.now()
updated, err := connectors.NewConnector(connectors.ConnectorParams{
- ID: connector.ID,
- Type: string(connector.Type),
- Status: string(status),
+ ID: current.ID,
+ Type: string(current.Type),
+ Status: string(connectors.StatusConnected),
CookiesJSON: cookiesJSON,
- LastVerifiedAt: lastVerifiedAt,
- CreatedAt: &connector.CreatedAt,
+ LastVerifiedAt: &now,
+ CreatedAt: &current.CreatedAt,
UpdatedAt: &now,
})
if err != nil {
- return dto.Connector{}, err
+ service.cleanupSession(session)
+ return dto.FinishConnectorConnectResult{}, err
}
if err := service.repo.Save(ctx, updated); err != nil {
- return dto.Connector{}, err
+ service.cleanupSession(session)
+ return dto.FinishConnectorConnectResult{}, err
+ }
+ result.Connector = mapConnectorDTO(updated)
+ service.cleanupSession(session)
+ log.Printf("connectors: finalize saved cookies session=%s connector=%s reason=%s filtered=%d", session.ID, session.ConnectorType, reason, len(filtered))
+ return result, nil
+}
+
+// startConnectSessionMonitor installs the target and browser watchers for a
+// session and starts a goroutine that polls it once per second.
+//
+// Each pass stops when the session is gone or no longer running, triggers a
+// "browser_closed" finalize when the runtime is missing or not ready, refreshes
+// the session's cached cookies on a successful read, and triggers a
+// "tab_closed" finalize when the session's target is no longer listed. It then
+// waits for the next tick or for the tab/browser context to end.
+func (service *ConnectorsService) startConnectSessionMonitor(sessionID string) {
+ session, ok := service.getSession(sessionID)
+ if !ok || session == nil {
+ return
}
- return mapConnectorDTO(updated), nil
+ service.watchConnectSessionTarget(sessionID, session)
+ service.watchConnectSessionBrowser(sessionID, session)
+
+ go func() {
+ ticker := time.NewTicker(1 * time.Second)
+ defer ticker.Stop()
+
+ for {
+ session, ok := service.getSession(sessionID)
+ if !ok || session == nil {
+ return
+ }
+ // Snapshot the fields under the lock, then work on the copies.
+ service.mu.Lock()
+ state := session.State
+ runtime := session.Runtime
+ targetID := session.TargetID
+ tabCtx := session.TabCtx
+ service.mu.Unlock()
+ if state != connectorSessionStateRunning {
+ return
+ }
+ if runtime == nil || !runtime.Status().Ready {
+ service.triggerSessionFinalize(sessionID, "browser_closed")
+ return
+ }
+ // Keep a recent cookie snapshot so finalize can fall back to it.
+ if cookies, err := readConnectorCookiesFromRuntime(runtime); err == nil {
+ service.updateSession(sessionID, func(current *connectorSession) {
+ current.LastCookies = append([]appcookies.Record(nil), cookies...)
+ current.LastCookiesAt = service.now()
+ })
+ }
+ if targetID != "" {
+ exists, err := connectorTargetExists(runtime, targetID)
+ if err == nil && !exists {
+ service.triggerSessionFinalize(sessionID, "tab_closed")
+ return
+ }
+ }
+ // A nil channel never becomes ready, so a missing context simply
+ // disables its select case.
+ var tabDone <-chan struct{}
+ if tabCtx != nil {
+ tabDone = tabCtx.Done()
+ }
+ var browserDone <-chan struct{}
+ if browserCtx := runtime.BrowserContext(); browserCtx != nil {
+ browserDone = browserCtx.Done()
+ }
+
+ select {
+ case <-ticker.C:
+ case <-tabDone:
+ service.triggerSessionFinalize(sessionID, "tab_closed")
+ return
+ case <-browserDone:
+ service.triggerSessionFinalize(sessionID, "browser_closed")
+ return
+ }
+ }
+ }()
+}
+
+// watchConnectSessionTarget subscribes to target events on the session's tab
+// context and triggers a "tab_closed" finalize when the tab is destroyed,
+// crashes or detaches. Destroyed/crashed events for a different target ID are
+// ignored when the session's target ID is known. Does nothing for a nil
+// session or a session without a tab context.
+func (service *ConnectorsService) watchConnectSessionTarget(sessionID string, session *connectorSession) {
+	if session == nil || session.TabCtx == nil {
+		return
+	}
+
+	watched := session.TargetID
+	// concerns reports whether an event's target is the one being watched;
+	// with no known target ID every event counts.
+	concerns := func(id cdptarget.ID) bool {
+		return watched == "" || id == watched
+	}
+	finalize := func() {
+		service.triggerSessionFinalize(sessionID, "tab_closed")
+	}
+
+	chromedp.ListenTarget(session.TabCtx, func(ev any) {
+		switch event := ev.(type) {
+		case *cdptarget.EventTargetDestroyed:
+			if concerns(event.TargetID) {
+				finalize()
+			}
+		case *cdptarget.EventTargetCrashed:
+			if concerns(event.TargetID) {
+				finalize()
+			}
+		case *cdptarget.EventDetachedFromTarget:
+			finalize()
+		}
+	})
+}
+
+// watchConnectSessionBrowser starts a goroutine that waits for the session's
+// browser context to end and then triggers a "browser_closed" finalize. Does
+// nothing when the session, its runtime or the browser context is nil.
+func (service *ConnectorsService) watchConnectSessionBrowser(sessionID string, session *connectorSession) {
+	if session == nil || session.Runtime == nil {
+		return
+	}
+	if session.Runtime.BrowserContext() == nil {
+		return
+	}
+
+	browserCtx := session.Runtime.BrowserContext()
+	go func() {
+		<-browserCtx.Done()
+		service.triggerSessionFinalize(sessionID, "browser_closed")
+	}()
+}
+
+// triggerSessionFinalize finalizes the session on a new goroutine using a
+// background context. Failures are logged, except ErrConnectorSessionGone,
+// which is ignored.
+func (service *ConnectorsService) triggerSessionFinalize(sessionID string, reason string) {
+	finalize := func() {
+		_, _, err := service.finalizeConnectSession(context.Background(), sessionID, reason)
+		if err == nil || errors.Is(err, connectors.ErrConnectorSessionGone) {
+			return
+		}
+		log.Printf("connectors: auto-finalize failed session=%s reason=%s err=%v", sessionID, reason, err)
+	}
+	go finalize()
+}
+
+// connectorTargetExists reports whether the browser still lists a target with
+// the given ID, asking the browser via Target.getTargets with a 3-second
+// timeout.
+//
+// A nil runtime or empty target ID returns (true, nil): there is nothing to
+// check, so the target is treated as present. A runtime without a browser
+// context returns ErrConnectorSessionDead. Any error from the CDP call is
+// returned with exists=false.
+func connectorTargetExists(runtime *browsercdp.Runtime, targetID cdptarget.ID) (bool, error) {
+	if runtime == nil || targetID == "" {
+		return true, nil
+	}
+	browserCtx := runtime.BrowserContext()
+	if browserCtx == nil {
+		// context.WithTimeout panics on a nil parent; report the dead session instead.
+		return false, connectors.ErrConnectorSessionDead
+	}
+	timeoutCtx, cancel := context.WithTimeout(browserCtx, 3*time.Second)
+	defer cancel()
+
+	var exists bool
+	if err := chromedp.Run(timeoutCtx, chromedp.ActionFunc(func(actionCtx context.Context) error {
+		targets, err := cdptarget.GetTargets().Do(actionCtx)
+		if err != nil {
+			return err
+		}
+		for _, info := range targets {
+			if info != nil && info.TargetID == targetID {
+				exists = true
+				break
+			}
+		}
+		return nil
+	})); err != nil {
+		return false, err
+	}
+	return exists, nil
+}
func connectorHomeURL(connectorType connectors.ConnectorType) (string, error) {
switch connectorType {
case connectors.ConnectorGoogle:
return "https://www.google.com/", nil
+ case connectors.ConnectorGitHub:
+ return "https://github.com/", nil
+ case connectors.ConnectorReddit:
+ return "https://www.reddit.com/", nil
+ case connectors.ConnectorZhihu:
+ return "https://www.zhihu.com/", nil
+ case connectors.ConnectorX:
+ return "https://x.com/", nil
case connectors.ConnectorXiaohongshu:
return "https://www.xiaohongshu.com/", nil
case connectors.ConnectorBilibili:
@@ -71,48 +365,167 @@ func connectorHomeURL(connectorType connectors.ConnectorType) (string, error) {
}
}
-func runPlaywrightLogin(ctx context.Context, targetURL string) ([]cookieRecord, error) {
- pw, err := playwright.Run()
+// startConnectorBrowser starts a browser runtime with the given launch options
+// and attaches a tab to it. It returns the runtime, the tab context and that
+// context's cancel function. If attaching the tab fails, the runtime is
+// stopped before the error is returned.
+func startConnectorBrowser(preferredBrowser string, headless bool, userDataDir string) (*browsercdp.Runtime, context.Context, context.CancelFunc, error) {
+ runtime, err := browsercdp.Start(context.Background(), browsercdp.LaunchOptions{
+ PreferredBrowser: preferredBrowser,
+ Headless: headless,
+ UserDataDir: userDataDir,
+ })
if err != nil {
- return nil, mapPlaywrightInstallError(err)
+ return nil, nil, nil, err
}
- defer pw.Stop()
-
- browser, err := pw.Chromium.Launch(playwright.BrowserTypeLaunchOptions{
- Headless: playwright.Bool(false),
- })
+ tabCtx, cancel, err := attachConnectorTab(runtime)
if err != nil {
- return nil, mapPlaywrightInstallError(err)
+ runtime.Stop()
+ return nil, nil, nil, err
}
- defer browser.Close()
+ return runtime, tabCtx, cancel, nil
+}
- browserCtx, err := browser.NewContext()
- if err != nil {
- return nil, err
+// attachConnectorTab creates a chromedp tab context on the runtime's browser.
+// It attaches to the target picked by selectConnectorStartupTarget when there
+// is one and otherwise creates a context without a target ID. The context is
+// initialized with an empty chromedp.Run; on failure it is cancelled and the
+// error returned. A nil runtime yields ErrConnectorSessionDead.
+func attachConnectorTab(runtime *browsercdp.Runtime) (context.Context, context.CancelFunc, error) {
+ if runtime == nil {
+ return nil, nil, connectors.ErrConnectorSessionDead
}
- page, err := browserCtx.NewPage()
+ targets, err := chromedp.Targets(runtime.BrowserContext())
if err != nil {
+ return nil, nil, err
+ }
+
+ targetID := selectConnectorStartupTarget(targets)
+ var tabCtx context.Context
+ var cancel context.CancelFunc
+ if targetID != "" {
+ tabCtx, cancel = chromedp.NewContext(runtime.BrowserContext(), chromedp.WithTargetID(targetID))
+ } else {
+ tabCtx, cancel = chromedp.NewContext(runtime.BrowserContext())
+ }
+ // Running with no actions forces the context to be set up now, so a failure surfaces here.
+ if err := chromedp.Run(tabCtx); err != nil {
+ cancel()
+ return nil, nil, err
+ }
+ return tabCtx, cancel, nil
+}
+
+// selectConnectorStartupTarget picks the target to reuse for the connector
+// tab. Only non-nil, unattached targets of type "page" are considered. The
+// first one showing a blank/new-tab URL wins; otherwise the first eligible
+// target is returned. The result is empty when nothing is eligible.
+func selectConnectorStartupTarget(targets []*cdptarget.Info) cdptarget.ID {
+	var firstEligible cdptarget.ID
+	for _, candidate := range targets {
+		eligible := candidate != nil && candidate.Type == "page" && !candidate.Attached
+		if !eligible {
+			continue
+		}
+		switch {
+		case isConnectorStartupBlank(candidate.URL):
+			return candidate.TargetID
+		case firstEligible == "":
+			firstEligible = candidate.TargetID
+		}
+	}
+	return firstEligible
+}
+
+// isConnectorStartupBlank reports whether a target URL, after trimming
+// whitespace, is empty, "about:blank", or starts with the chrome://, edge://
+// or brave:// new-tab prefix.
+func isConnectorStartupBlank(targetURL string) bool {
+	candidate := strings.TrimSpace(targetURL)
+	switch candidate {
+	case "", "about:blank":
+		return true
+	}
+	for _, prefix := range []string{"chrome://newtab", "edge://newtab", "brave://newtab"} {
+		if strings.HasPrefix(candidate, prefix) {
+			return true
+		}
+	}
+	return false
+}
+
+// readConnectorCookies returns the cookies reported by
+// browsercdp.GetAllCookies, executed through chromedp.Run on the given
+// context. Any error from the run is returned with a nil slice.
+func readConnectorCookies(ctx context.Context) ([]appcookies.Record, error) {
+ var records []appcookies.Record
+ if err := chromedp.Run(ctx, chromedp.ActionFunc(func(actionCtx context.Context) error {
+ items, err := browsercdp.GetAllCookies(actionCtx)
+ if err != nil {
+ return err
+ }
+ records = items
+ return nil
+ })); err != nil {
return nil, err
}
- if _, err := page.Goto(targetURL); err != nil {
+ return records, nil
+}
+
+// readConnectorCookiesFromRuntime returns the cookies reported by
+// browsercdp.GetStorageCookies, executed on the runtime's browser context with
+// an 8-second timeout.
+//
+// A nil runtime, or a runtime without a browser context, yields
+// ErrConnectorSessionDead. A timeout is wrapped as
+// "connector cookie read timed out"; other errors are returned unchanged.
+func readConnectorCookiesFromRuntime(runtime *browsercdp.Runtime) ([]appcookies.Record, error) {
+ if runtime == nil {
+ return nil, connectors.ErrConnectorSessionDead
+ }
+ browserCtx := runtime.BrowserContext()
+ if browserCtx == nil {
+ // context.WithTimeout panics on a nil parent; report the dead session instead.
+ return nil, connectors.ErrConnectorSessionDead
+ }
+ timeoutCtx, cancel := context.WithTimeout(browserCtx, 8*time.Second)
+ defer cancel()
+
+ var records []appcookies.Record
+ if err := chromedp.Run(timeoutCtx, chromedp.ActionFunc(func(actionCtx context.Context) error {
+ items, err := browsercdp.GetStorageCookies(actionCtx)
+ if err != nil {
+ return err
+ }
+ records = items
+ return nil
+ })); err != nil {
+ if errors.Is(err, context.DeadlineExceeded) {
+ return nil, fmt.Errorf("connector cookie read timed out: %w", err)
+ }
return nil, err
}
+ return records, nil
+}
- closeTicker := time.NewTicker(500 * time.Millisecond)
- defer closeTicker.Stop()
+// waitForConnectorTabClose polls every 500ms until the tab or browser appears
+// to be gone, and returns the most recent cookies it managed to read.
+//
+// On each tick it refreshes the cookie snapshot through readCookies(tabCtx)
+// when captureCookies is true and readCookies is non-nil (read errors keep the
+// previous snapshot). It returns (latest, nil) once the runtime is nil or not
+// ready, or once reading the tab's location fails. If ctx ends first it
+// returns (latest, ctx.Err()).
+func waitForConnectorTabClose(ctx context.Context, runtime *browsercdp.Runtime, tabCtx context.Context, captureCookies bool, readCookies func(context.Context) ([]appcookies.Record, error)) ([]appcookies.Record, error) {
+ ticker := time.NewTicker(500 * time.Millisecond)
+ defer ticker.Stop()
+ var latest []appcookies.Record
for {
select {
case <-ctx.Done():
- return nil, ctx.Err()
- case <-closeTicker.C:
- if page.IsClosed() {
- cookies, err := browserCtx.Cookies()
- if err != nil {
- return nil, err
+ return latest, ctx.Err()
+ case <-ticker.C:
+ if captureCookies && readCookies != nil {
+ if cookies, err := readCookies(tabCtx); err == nil {
+ latest = cookies
}
- return cookiesFromPlaywright(cookies), nil
}
+ if runtime == nil || !runtime.Status().Ready {
+ return latest, nil
+ }
+ // A failed location read is taken to mean the tab is no longer reachable.
+ var currentURL string
+ if err := chromedp.Run(tabCtx, chromedp.Location(&currentURL)); err != nil {
+ return latest, nil
+ }
+ }
+ }
+}
+
+// connectorSessionDir returns <session root>/<connector type>/<session ID>,
+// the browser user data dir for a connect session.
+func connectorSessionDir(connectorType connectors.ConnectorType, sessionID string) string {
+	root := connectorSessionRootDir()
+	return filepath.Join(root, string(connectorType), sessionID)
+}
+
+// connectorOpenDir returns <session root>/open/<connector type>/<session ID>,
+// the browser user data dir for an "open site" run.
+func connectorOpenDir(connectorType connectors.ConnectorType, sessionID string) string {
+	root := connectorSessionRootDir()
+	return filepath.Join(root, "open", string(connectorType), sessionID)
+}
+
+// connectorSessionRootDir returns <os.TempDir()>/dreamcreator/connectors, the
+// parent of all connector browser data dirs.
+func connectorSessionRootDir() string {
+	segments := []string{os.TempDir(), "dreamcreator", "connectors"}
+	return filepath.Join(segments...)
+}
+
+// cookieDomains returns the distinct, whitespace-trimmed, non-empty domains of
+// the given cookie records in ascending string order. It returns nil for an
+// empty input.
+func cookieDomains(records []appcookies.Record) []string {
+ if len(records) == 0 {
+ return nil
+ }
+ seen := make(map[string]struct{}, len(records))
+ result := make([]string, 0, len(records))
+ for _, record := range records {
+ domain := strings.TrimSpace(record.Domain)
+ if domain == "" {
+ continue
+ }
+ // Skip domains that were already collected.
+ if _, ok := seen[domain]; ok {
+ continue
}
+ seen[domain] = struct{}{}
+ result = append(result, domain)
}
}
+ sort.Strings(result)
+ return result
}
index 15777c7..74688bf 100644
--- a/internal/application/connectors/service/open.go
+++ b/internal/application/connectors/service/open.go
@@ -3,10 +3,10 @@ package service
import (
"context"
"strings"
- "time"
- "github.com/playwright-community/playwright-go"
+ "github.com/chromedp/chromedp"
+ "dreamcreator/internal/application/browsercdp"
"dreamcreator/internal/application/connectors/dto"
"dreamcreator/internal/domain/connectors"
)
@@ -28,52 +28,27 @@ func (service *ConnectorsService) OpenConnectorSite(ctx context.Context, request
if err != nil {
return err
}
- return runPlaywrightOpenWithCookies(ctx, targetURL, cookies)
-}
-
-func runPlaywrightOpenWithCookies(ctx context.Context, targetURL string, cookies []cookieRecord) error {
- pw, err := playwright.Run()
- if err != nil {
- return mapPlaywrightInstallError(err)
- }
- defer pw.Stop()
-
- browser, err := pw.Chromium.Launch(playwright.BrowserTypeLaunchOptions{
- Headless: playwright.Bool(false),
- })
- if err != nil {
- return mapPlaywrightInstallError(err)
- }
- defer browser.Close()
-
- browserCtx, err := browser.NewContext()
+ userDataDir := connectorOpenDir(connector.Type, service.newSessionID())
+ runtime, tabCtx, cancel, err := service.startBrowser(service.preferredBrowser(ctx), false, userDataDir)
if err != nil {
return err
}
- if len(cookies) > 0 {
- if err := browserCtx.AddCookies(toPlaywrightCookies(cookies, targetURL)); err != nil {
- return err
+ defer cancel()
+ defer runtime.Stop()
+ defer func() {
+ if service.removeAll != nil {
+ _ = service.removeAll(userDataDir)
}
- }
- page, err := browserCtx.NewPage()
- if err != nil {
+ }()
+
+ if err := chromedp.Run(tabCtx, chromedp.ActionFunc(func(ctx context.Context) error {
+ return browsercdp.SetCookies(ctx, targetURL, cookies)
+ })); err != nil {
return err
}
- if _, err := page.Goto(targetURL); err != nil {
+ if err := chromedp.Run(tabCtx, chromedp.Navigate(targetURL)); err != nil {
return err
}
-
- closeTicker := time.NewTicker(500 * time.Millisecond)
- defer closeTicker.Stop()
-
- for {
- select {
- case <-ctx.Done():
- return ctx.Err()
- case <-closeTicker.C:
- if page.IsClosed() {
- return nil
- }
- }
- }
+ _, err = waitForConnectorTabClose(ctx, runtime, tabCtx, false, service.readCookies)
+ return err
}
diff --git a/internal/application/connectors/service/playwright.go b/internal/application/connectors/service/playwright.go
deleted file mode 100644
index 414c63d..0000000
--- a/internal/application/connectors/service/playwright.go
+++ /dev/null
@@ -1,44 +0,0 @@
-package service
-
-import (
- "context"
- "errors"
- "strings"
-
- "github.com/playwright-community/playwright-go"
-)
-
-var ErrPlaywrightNotInstalled = errors.New("playwright not installed")
-
-func (service *ConnectorsService) InstallPlaywright(_ context.Context) error {
- return playwright.Install(&playwright.RunOptions{
- Browsers: []string{"chromium"},
- })
-}
-
-func mapPlaywrightInstallError(err error) error {
- if err == nil {
- return nil
- }
- if isPlaywrightInstallError(err.Error()) {
- return ErrPlaywrightNotInstalled
- }
- return err
-}
-
-func isPlaywrightInstallError(message string) bool {
- lowered := strings.ToLower(message)
- if strings.Contains(lowered, "playwright not installed") {
- return true
- }
- if strings.Contains(lowered, "please install") && (strings.Contains(lowered, "playwright") || strings.Contains(lowered, "driver")) {
- return true
- }
- if strings.Contains(lowered, "driver exists but version not") {
- return true
- }
- if strings.Contains(lowered, "executable doesn't exist") {
- return true
- }
- return false
-}
diff --git a/internal/application/connectors/service/service.go b/internal/application/connectors/service/service.go
index 60ee791..0df586f 100644
--- a/internal/application/connectors/service/service.go
+++ b/internal/application/connectors/service/service.go
@@ -2,33 +2,111 @@ package service
import (
"context"
+ "os"
"strings"
+ "sync"
"time"
+ "github.com/chromedp/cdproto/target"
"github.com/google/uuid"
+ "dreamcreator/internal/application/browsercdp"
"dreamcreator/internal/application/connectors/dto"
+ appcookies "dreamcreator/internal/application/cookies"
+ settingsdto "dreamcreator/internal/application/settings/dto"
"dreamcreator/internal/domain/connectors"
)
+type SettingsReader interface {
+ GetSettings(ctx context.Context) (settingsdto.Settings, error)
+}
+
type ConnectorsService struct {
- repo connectors.Repository
- now func() time.Time
+ repo connectors.Repository
+ settings SettingsReader
+ now func() time.Time
+
+ mu sync.Mutex
+ sessions map[string]*connectorSession
+ sessionsByConnector map[string]string
+ startBrowser func(preferredBrowser string, headless bool, userDataDir string) (*browsercdp.Runtime, context.Context, context.CancelFunc, error)
+ readCookies func(ctx context.Context) ([]appcookies.Record, error)
+ removeAll func(path string) error
+ newSessionID func() string
+}
+
+const (
+ connectorSessionStateRunning = "running"
+ connectorSessionStateCompleted = "completed"
+ connectorSessionStateFailed = "failed"
+)
+
+type connectorSession struct {
+ ID string
+ ConnectorID string
+ ConnectorType connectors.ConnectorType
+ Runtime *browsercdp.Runtime
+ TabCtx context.Context
+ Cancel context.CancelFunc
+ UserDataDir string
+ TargetID target.ID
+ State string
+ LastCookies []appcookies.Record
+ LastCookiesAt time.Time
+ FinalResult *dto.FinishConnectorConnectResult
+ FinalError string
+ ConnectorSnapshot dto.Connector
+ finalizeOnce sync.Once
+ finalizeDone chan struct{}
}
-func NewConnectorsService(repo connectors.Repository) *ConnectorsService {
+func NewConnectorsService(repo connectors.Repository, settings SettingsReader) *ConnectorsService {
return &ConnectorsService{
- repo: repo,
- now: time.Now,
+ repo: repo,
+ settings: settings,
+ now: time.Now,
+ sessions: make(map[string]*connectorSession),
+ sessionsByConnector: make(map[string]string),
+ startBrowser: startConnectorBrowser,
+ readCookies: readConnectorCookies,
+ removeAll: os.RemoveAll,
+ newSessionID: uuid.NewString,
}
}
+func (service *ConnectorsService) preferredBrowser(ctx context.Context) string {
+ if service == nil || service.settings == nil {
+ return ""
+ }
+ current, err := service.settings.GetSettings(ctx)
+ if err != nil {
+ return ""
+ }
+ if tools := current.Tools; tools != nil {
+ if browserRaw, ok := tools["browser"].(map[string]any); ok && browserRaw != nil {
+ if value, ok := browserRaw["preferredBrowser"].(string); ok && strings.TrimSpace(value) != "" {
+ return strings.ToLower(strings.TrimSpace(value))
+ }
+ }
+ if fetchRaw, ok := tools["web_fetch"].(map[string]any); ok && fetchRaw != nil {
+ if value, ok := fetchRaw["preferredBrowser"].(string); ok && strings.TrimSpace(value) != "" {
+ return strings.ToLower(strings.TrimSpace(value))
+ }
+ }
+ }
+ return ""
+}
+
func (service *ConnectorsService) EnsureDefaults(ctx context.Context) error {
defaults := []struct {
ID string
Type connectors.ConnectorType
}{
{ID: "connector-google", Type: connectors.ConnectorGoogle},
+ {ID: "connector-github", Type: connectors.ConnectorGitHub},
+ {ID: "connector-reddit", Type: connectors.ConnectorReddit},
+ {ID: "connector-zhihu", Type: connectors.ConnectorZhihu},
+ {ID: "connector-x", Type: connectors.ConnectorX},
{ID: "connector-xiaohongshu", Type: connectors.ConnectorXiaohongshu},
{ID: "connector-bilibili", Type: connectors.ConnectorBilibili},
}
@@ -163,9 +241,148 @@ func (service *ConnectorsService) ClearConnector(ctx context.Context, request dt
return service.repo.Save(ctx, updated)
}
+func (service *ConnectorsService) putSession(session *connectorSession) *connectorSession {
+ if service == nil || session == nil {
+ return nil
+ }
+ service.mu.Lock()
+ defer service.mu.Unlock()
+
+ var replaced *connectorSession
+ if currentID, ok := service.sessionsByConnector[session.ConnectorID]; ok && currentID != "" {
+ replaced = service.sessions[currentID]
+ delete(service.sessions, currentID)
+ }
+ service.sessions[session.ID] = session
+ service.sessionsByConnector[session.ConnectorID] = session.ID
+ return replaced
+}
+
+func (service *ConnectorsService) getSession(sessionID string) (*connectorSession, bool) {
+ if service == nil {
+ return nil, false
+ }
+ service.mu.Lock()
+ defer service.mu.Unlock()
+ session, ok := service.sessions[sessionID]
+ return session, ok
+}
+
+func (service *ConnectorsService) updateSession(sessionID string, update func(session *connectorSession)) (*connectorSession, bool) {
+ if service == nil {
+ return nil, false
+ }
+ service.mu.Lock()
+ defer service.mu.Unlock()
+ session, ok := service.sessions[sessionID]
+ if !ok || session == nil {
+ return nil, false
+ }
+ update(session)
+ return session, true
+}
+
+func (service *ConnectorsService) popSession(sessionID string) *connectorSession {
+ if service == nil {
+ return nil
+ }
+ service.mu.Lock()
+ defer service.mu.Unlock()
+
+ session := service.sessions[sessionID]
+ if session == nil {
+ return nil
+ }
+ delete(service.sessions, sessionID)
+ if currentID, ok := service.sessionsByConnector[session.ConnectorID]; ok && currentID == sessionID {
+ delete(service.sessionsByConnector, session.ConnectorID)
+ }
+ return session
+}
+
+func (service *ConnectorsService) cleanupSession(session *connectorSession) {
+ if session == nil {
+ return
+ }
+ if session.Cancel != nil {
+ session.Cancel()
+ }
+ if session.Runtime != nil {
+ session.Runtime.Stop()
+ }
+ if service.removeAll != nil && strings.TrimSpace(session.UserDataDir) != "" {
+ _ = service.removeAll(session.UserDataDir)
+ }
+}
+
+func (service *ConnectorsService) GetConnectorConnectSession(ctx context.Context, request dto.GetConnectorConnectSessionRequest) (dto.ConnectorConnectSession, error) {
+ sessionID := strings.TrimSpace(request.SessionID)
+ if sessionID == "" {
+ return dto.ConnectorConnectSession{}, connectors.ErrConnectorSessionGone
+ }
+ session, ok := service.getSession(sessionID)
+ if !ok {
+ return dto.ConnectorConnectSession{}, connectors.ErrConnectorSessionGone
+ }
+ return service.snapshotSession(ctx, session), nil
+}
+
+func (service *ConnectorsService) snapshotSession(ctx context.Context, session *connectorSession) dto.ConnectorConnectSession {
+ if session == nil {
+ return dto.ConnectorConnectSession{}
+ }
+ service.mu.Lock()
+ snapshotID := session.ID
+ snapshotConnectorID := session.ConnectorID
+ snapshotState := session.State
+ snapshotLastCookiesAt := session.LastCookiesAt
+ snapshotFinalError := session.FinalError
+ snapshotConnector := session.ConnectorSnapshot
+ var snapshotFinalResult *dto.FinishConnectorConnectResult
+ if session.FinalResult != nil {
+ copyResult := *session.FinalResult
+ copyResult.Domains = append([]string(nil), session.FinalResult.Domains...)
+ snapshotFinalResult = &copyResult
+ }
+ service.mu.Unlock()
+
+ connector := snapshotConnector
+ if snapshotFinalResult != nil {
+ connector = snapshotFinalResult.Connector
+ } else if current, err := service.repo.Get(ctx, snapshotConnectorID); err == nil {
+ connector = mapConnectorDTO(current)
+ }
+ lastCookiesAt := ""
+ if !snapshotLastCookiesAt.IsZero() {
+ lastCookiesAt = snapshotLastCookiesAt.Format(time.RFC3339)
+ }
+ result := dto.ConnectorConnectSession{
+ SessionID: snapshotID,
+ ConnectorID: snapshotConnectorID,
+ State: snapshotState,
+ Error: snapshotFinalError,
+ LastCookiesAt: lastCookiesAt,
+ Connector: connector,
+ }
+ if snapshotFinalResult != nil {
+ result.Saved = snapshotFinalResult.Saved
+ result.RawCookiesCount = snapshotFinalResult.RawCookiesCount
+ result.FilteredCookiesCount = snapshotFinalResult.FilteredCookiesCount
+ result.Domains = append([]string(nil), snapshotFinalResult.Domains...)
+ result.Reason = snapshotFinalResult.Reason
+ }
+ return result
+}
+
func isSupportedConnectorType(connectorType connectors.ConnectorType) bool {
switch connectorType {
- case connectors.ConnectorGoogle, connectors.ConnectorXiaohongshu, connectors.ConnectorBilibili:
+ case connectors.ConnectorGoogle,
+ connectors.ConnectorGitHub,
+ connectors.ConnectorReddit,
+ connectors.ConnectorZhihu,
+ connectors.ConnectorX,
+ connectors.ConnectorXiaohongshu,
+ connectors.ConnectorBilibili:
return true
default:
return false
diff --git a/internal/application/cookies/cookies.go b/internal/application/cookies/cookies.go
new file mode 100644
index 0000000..0a3631e
--- /dev/null
+++ b/internal/application/cookies/cookies.go
@@ -0,0 +1,105 @@
+package cookies
+
+import (
+ "encoding/json"
+ "net/url"
+ "strings"
+
+ "dreamcreator/internal/application/sitepolicy"
+)
+
+type Record struct {
+ Name string `json:"name"`
+ Value string `json:"value"`
+ Domain string `json:"domain"`
+ Path string `json:"path"`
+ Expires int64 `json:"expires"`
+ HttpOnly bool `json:"httpOnly"`
+ Secure bool `json:"secure"`
+ SameSite string `json:"sameSite,omitempty"`
+}
+
+func EncodeJSON(records []Record) (string, error) {
+ if len(records) == 0 {
+ return "", nil
+ }
+ data, err := json.Marshal(records)
+ if err != nil {
+ return "", err
+ }
+ return string(data), nil
+}
+
+func DecodeJSON(data string) []Record {
+ trimmed := strings.TrimSpace(data)
+ if trimmed == "" {
+ return nil
+ }
+ var records []Record
+ if err := json.Unmarshal([]byte(trimmed), &records); err != nil {
+ return nil
+ }
+ return records
+}
+
+func FilterByDomains(records []Record, domains []string) []Record {
+ if len(records) == 0 || len(domains) == 0 {
+ return nil
+ }
+ result := make([]Record, 0, len(records))
+ for _, record := range records {
+ domain := strings.TrimPrefix(strings.ToLower(strings.TrimSpace(record.Domain)), ".")
+ if domain == "" {
+ continue
+ }
+ for _, allowed := range domains {
+ if sitepolicy.HostMatchesDomain(domain, allowed) {
+ result = append(result, normalizeRecord(record))
+ break
+ }
+ }
+ }
+ return result
+}
+
+func MatchURL(records []Record, rawURL string) []Record {
+ if len(records) == 0 {
+ return nil
+ }
+ parsed, err := url.Parse(strings.TrimSpace(rawURL))
+ if err != nil {
+ return nil
+ }
+ host := strings.ToLower(strings.TrimSpace(parsed.Hostname()))
+ path := parsed.EscapedPath()
+ if path == "" {
+ path = "/"
+ }
+ result := make([]Record, 0, len(records))
+ for _, record := range records {
+ domain := strings.TrimSpace(record.Domain)
+ if domain != "" && !sitepolicy.HostMatchesDomain(host, domain) {
+ continue
+ }
+ cookiePath := strings.TrimSpace(record.Path)
+ if cookiePath == "" {
+ cookiePath = "/"
+ }
+ if !strings.HasPrefix(path, cookiePath) {
+ continue
+ }
+ result = append(result, normalizeRecord(record))
+ }
+ return result
+}
+
+func normalizeRecord(record Record) Record {
+ record.Name = strings.TrimSpace(record.Name)
+ record.Domain = strings.TrimSpace(record.Domain)
+ record.Path = strings.TrimSpace(record.Path)
+ record.SameSite = strings.ToLower(strings.TrimSpace(record.SameSite))
+ if record.Path == "" {
+ record.Path = "/"
+ }
+ return record
+}
diff --git a/internal/application/externaltools/service/service.go b/internal/application/externaltools/service/service.go
index 2d0e908..90e5901 100644
--- a/internal/application/externaltools/service/service.go
+++ b/internal/application/externaltools/service/service.go
@@ -20,7 +20,6 @@ import (
"time"
"github.com/google/uuid"
- "github.com/playwright-community/playwright-go"
"dreamcreator/internal/application/externaltools/dto"
"dreamcreator/internal/application/softwareupdate"
@@ -70,16 +69,9 @@ var (
SourceRef: "clawhub",
Manager: toolManagerBun,
},
- externaltools.ToolPlaywright: {
- ToolKind: string(externaltools.KindRuntime),
- Kind: sourceKindRuntime,
- SourceRef: "playwright-community/playwright-go",
- },
}
- semverTokenPattern = regexp.MustCompile(`^v?\d+\.\d+\.\d+(?:[-+][0-9A-Za-z.-]+)?$`)
- playwrightVersionTokenPattern = regexp.MustCompile(`^\d+(?:\.\d+){2,}(?:[-+][0-9A-Za-z.-]+)?$`)
- playwrightVersionScanPattern = regexp.MustCompile(`\d+(?:\.\d+){2,}(?:[-+][0-9A-Za-z.-]+)?`)
+ semverTokenPattern = regexp.MustCompile(`^v?\d+\.\d+\.\d+(?:[-+][0-9A-Za-z.-]+)?$`)
)
type ExternalToolsService struct {
@@ -163,7 +155,6 @@ func (service *ExternalToolsService) EnsureDefaults(ctx context.Context) error {
externaltools.ToolFFmpeg,
externaltools.ToolBun,
externaltools.ToolClawHub,
- externaltools.ToolPlaywright,
}
existing, err := service.repo.List(ctx)
if err != nil {
@@ -301,17 +292,8 @@ func (service *ExternalToolsService) InstallTool(ctx context.Context, request dt
case sourceKindNPMRegistry:
return service.installNPMRegistryTool(ctx, toolName, source, request.Version, manager)
case sourceKindRuntime:
- if manager != "" {
- service.setInstallState(toolName, installStageError, downloadProgressStart, "manager is unsupported for this tool")
- return dto.ExternalTool{}, fmt.Errorf("manager is unsupported for tool %s", toolName)
- }
- switch toolName {
- case externaltools.ToolPlaywright:
- return service.installPlaywrightRuntime(ctx)
- default:
- service.setInstallState(toolName, installStageError, downloadProgressStart, "invalid tool")
- return dto.ExternalTool{}, externaltools.ErrInvalidTool
- }
+ service.setInstallState(toolName, installStageError, downloadProgressStart, "runtime tools are not supported")
+ return dto.ExternalTool{}, externaltools.ErrInvalidTool
default:
service.setInstallState(toolName, installStageError, downloadProgressStart, "unsupported source")
return dto.ExternalTool{}, fmt.Errorf("unsupported source for tool %s", toolName)
@@ -408,14 +390,10 @@ func (service *ExternalToolsService) RemoveTool(ctx context.Context, request dto
if name == "" {
return externaltools.ErrInvalidTool
}
- toolName := externaltools.ToolName(name)
tool, err := service.repo.Get(ctx, name)
if err != nil && err != externaltools.ErrToolNotFound {
return err
}
- if toolName == externaltools.ToolPlaywright {
- _ = uninstallPlaywrightRuntime()
- }
if err == nil && tool.ExecPath != "" {
_ = os.RemoveAll(filepath.Dir(tool.ExecPath))
}
@@ -1192,134 +1170,6 @@ func (service *ExternalToolsService) installNPMRegistryTool(ctx context.Context,
return toExternalToolDTO(tool), nil
}
-func (service *ExternalToolsService) installPlaywrightRuntime(ctx context.Context) (dto.ExternalTool, error) {
- service.setInstallState(externaltools.ToolPlaywright, installStageDownloading, downloadProgressStart, "")
- stopProgress := make(chan struct{})
- var stopProgressOnce sync.Once
- stopProgressTicker := func() {
- stopProgressOnce.Do(func() {
- close(stopProgress)
- })
- }
- defer stopProgressTicker()
-
- go func() {
- ticker := time.NewTicker(350 * time.Millisecond)
- defer ticker.Stop()
- progress := downloadProgressStart
- ceiling := downloadProgressEnd - 2
- for {
- select {
- case <-stopProgress:
- return
- case <-ctx.Done():
- return
- case <-ticker.C:
- if progress >= ceiling {
- continue
- }
- progress += 2
- if progress > ceiling {
- progress = ceiling
- }
- service.setInstallState(externaltools.ToolPlaywright, installStageDownloading, progress, "")
- }
- }
- }()
-
- if err := playwright.Install(&playwright.RunOptions{
- Browsers: []string{"chromium"},
- Verbose: false,
- Stdout: io.Discard,
- Stderr: io.Discard,
- }); err != nil {
- stopProgressTicker()
- service.setInstallState(externaltools.ToolPlaywright, installStageError, downloadProgressStart, err.Error())
- return dto.ExternalTool{}, err
- }
- stopProgressTicker()
- service.setInstallState(externaltools.ToolPlaywright, installStageDownloading, downloadProgressEnd, "")
- service.setInstallState(externaltools.ToolPlaywright, installStageVerifying, verifyProgressStart, "")
- execPath, version, err := resolvePlaywrightRuntime(ctx)
- if err != nil {
- service.setInstallState(externaltools.ToolPlaywright, installStageError, verifyProgressStart, err.Error())
- return dto.ExternalTool{}, err
- }
- now := service.now()
- tool, err := externaltools.NewExternalTool(externaltools.ExternalToolParams{
- Name: string(externaltools.ToolPlaywright),
- ExecPath: execPath,
- Version: version,
- Status: string(externaltools.StatusInstalled),
- InstalledAt: &now,
- UpdatedAt: &now,
- })
- if err != nil {
- return dto.ExternalTool{}, err
- }
- if err := service.repo.Save(ctx, tool); err != nil {
- service.setInstallState(externaltools.ToolPlaywright, installStageError, verifyProgressEnd, err.Error())
- return dto.ExternalTool{}, err
- }
- service.setInstallState(externaltools.ToolPlaywright, installStageDone, verifyProgressEnd, "")
- return toExternalToolDTO(tool), nil
-}
-
-func resolvePlaywrightRuntime(ctx context.Context) (string, string, error) {
- pw, err := playwright.Run(&playwright.RunOptions{
- Verbose: false,
- Stdout: io.Discard,
- Stderr: io.Discard,
- SkipInstallBrowsers: true,
- })
- if err != nil {
- return "", "", err
- }
- defer pw.Stop()
- execPath := strings.TrimSpace(pw.Chromium.ExecutablePath())
- if execPath == "" {
- return "", "", fmt.Errorf("playwright chromium executable path is empty")
- }
- browser, err := pw.Chromium.Launch(playwright.BrowserTypeLaunchOptions{
- Headless: playwright.Bool(true),
- Args: []string{"--headless=new"},
- })
- if err != nil {
- return "", "", err
- }
- browserVersion := strings.TrimSpace(browser.Version())
- if closeErr := browser.Close(); closeErr != nil {
- return "", "", closeErr
- }
- version := browserVersion
- if version != "" {
- if parsedVersion, parseErr := parsePlaywrightVersion(version); parseErr == nil {
- version = parsedVersion
- }
- }
- if strings.TrimSpace(version) == "" {
- resolvedVersion, err := resolveVersion(ctx, externaltools.ToolPlaywright, execPath)
- if err != nil {
- return execPath, "", nil
- }
- version = resolvedVersion
- }
- return execPath, strings.TrimSpace(version), nil
-}
-
-func uninstallPlaywrightRuntime() error {
- driver, err := playwright.NewDriver(&playwright.RunOptions{
- Verbose: false,
- Stdout: io.Discard,
- Stderr: io.Discard,
- SkipInstallBrowsers: true,
- })
- if err != nil {
- return err
- }
- return driver.Uninstall()
-}
-
func executableNameForBinary(name string) string {
trimmed := strings.TrimSpace(name)
if trimmed == "" {
@@ -2292,8 +2142,6 @@ func resolveVersion(ctx context.Context, name externaltools.ToolName, execPath s
args = []string{"-version"}
case externaltools.ToolClawHub:
args = []string{"--cli-version"}
- case externaltools.ToolPlaywright:
- args = []string{"--version"}
default:
args = []string{"--version"}
}
@@ -2319,8 +2167,6 @@ func resolveVersion(ctx context.Context, name externaltools.ToolName, execPath s
return parseFFmpegVersion(text)
case externaltools.ToolClawHub:
return parseClawHubVersion(text)
- case externaltools.ToolPlaywright:
- return parsePlaywrightVersion(text)
default:
return strings.Fields(text)[0], nil
}
@@ -2350,20 +2196,6 @@ func parseClawHubVersion(output string) (string, error) {
return "", fmt.Errorf("clawhub version not found")
}
-func parsePlaywrightVersion(output string) (string, error) {
- for _, token := range strings.Fields(output) {
- candidate := strings.Trim(strings.TrimSpace(token), ",;:()[]{}")
- candidate = strings.TrimPrefix(strings.TrimPrefix(candidate, "v"), "V")
- if playwrightVersionTokenPattern.MatchString(candidate) {
- return candidate, nil
- }
- }
- if matched := strings.TrimSpace(playwrightVersionScanPattern.FindString(output)); matched != "" {
- return matched, nil
- }
- return "", fmt.Errorf("playwright version not found")
-}
-
func percent(written int64, total int64) int {
if total <= 0 {
return 0
diff --git a/internal/application/externaltools/service/service_test.go b/internal/application/externaltools/service/service_test.go
index 2bf06e9..1f3a48d 100644
--- a/internal/application/externaltools/service/service_test.go
+++ b/internal/application/externaltools/service/service_test.go
@@ -60,7 +60,6 @@ func TestEnsureDefaultsIncludesBunAndClawHub(t *testing.T) {
externaltools.ToolFFmpeg,
externaltools.ToolBun,
externaltools.ToolClawHub,
- externaltools.ToolPlaywright,
} {
if _, err := repo.Get(context.Background(), string(name)); err != nil {
t.Fatalf("expected default tool %s: %v", name, err)
@@ -291,14 +290,14 @@ func TestToExternalToolDTOUsesManagedVersionFromPath(t *testing.T) {
}
}
-func TestToExternalToolDTOIncludesRuntimeKind(t *testing.T) {
+func TestToExternalToolDTOIncludesSourceMetadata(t *testing.T) {
t.Parallel()
now := time.Now()
item, err := externaltools.NewExternalTool(externaltools.ExternalToolParams{
- Name: string(externaltools.ToolPlaywright),
- ExecPath: "/tmp/chromium",
- Version: "136.0.7103.25",
+ Name: string(externaltools.ToolClawHub),
+ ExecPath: "/tmp/clawhub",
+ Version: "1.2.3",
Status: string(externaltools.StatusInstalled),
UpdatedAt: &now,
})
@@ -306,30 +305,16 @@ func TestToExternalToolDTOIncludesRuntimeKind(t *testing.T) {
t.Fatalf("new external tool failed: %v", err)
}
result := toExternalToolDTO(item)
- if result.Kind != string(externaltools.KindRuntime) {
- t.Fatalf("expected runtime kind, got %q", result.Kind)
+ if result.Kind != string(externaltools.KindBin) {
+ t.Fatalf("expected bin kind, got %q", result.Kind)
}
- if result.SourceKind != sourceKindRuntime {
- t.Fatalf("expected runtime source kind, got %q", result.SourceKind)
+ if result.SourceKind != sourceKindNPMRegistry {
+ t.Fatalf("expected npm registry source kind, got %q", result.SourceKind)
}
-}
-
-func TestParsePlaywrightVersion(t *testing.T) {
- t.Parallel()
-
- version, err := parsePlaywrightVersion("Chromium 136.0.7103.25")
- if err != nil {
- t.Fatalf("parse playwright version failed: %v", err)
- }
- if version != "136.0.7103.25" {
- t.Fatalf("unexpected playwright version: %s", version)
- }
-
- scanned, err := parsePlaywrightVersion("HeadlessChrome/136.0.7103.25")
- if err != nil {
- t.Fatalf("parse playwright scanned version failed: %v", err)
+ if result.SourceRef != "clawhub" {
+ t.Fatalf("expected clawhub source ref, got %q", result.SourceRef)
}
- if scanned != "136.0.7103.25" {
- t.Fatalf("unexpected scanned playwright version: %s", scanned)
+ if result.Manager != toolManagerBun {
+ t.Fatalf("expected bun manager, got %q", result.Manager)
}
}
diff --git a/internal/application/gateway/approvals/publisher.go b/internal/application/gateway/approvals/publisher.go
index 42158ec..44de551 100644
--- a/internal/application/gateway/approvals/publisher.go
+++ b/internal/application/gateway/approvals/publisher.go
@@ -6,6 +6,7 @@ import (
"time"
gatewayevents "dreamcreator/internal/application/gateway/events"
+ domainsession "dreamcreator/internal/domain/session"
)
type GatewayEventPublisher struct {
@@ -20,11 +21,11 @@ func (publisher *GatewayEventPublisher) Publish(ctx context.Context, eventType s
if publisher == nil || publisher.events == nil {
return nil
}
- sessionKey := resolveSessionKey(payload)
+ sessionID, sessionKey := resolveSessionIdentity(payload)
envelope := gatewayevents.Envelope{
Type: eventType,
Topic: "exec.approval",
- SessionID: sessionKey,
+ SessionID: sessionID,
SessionKey: sessionKey,
Timestamp: time.Now(),
}
@@ -32,16 +33,27 @@ func (publisher *GatewayEventPublisher) Publish(ctx context.Context, eventType s
return err
}
-func resolveSessionKey(payload any) string {
+func resolveSessionIdentity(payload any) (string, string) {
+ sessionKey := ""
switch value := payload.(type) {
case Request:
- return strings.TrimSpace(value.SessionKey)
+ sessionKey = strings.TrimSpace(value.SessionKey)
case *Request:
if value == nil {
- return ""
+ return "", ""
}
- return strings.TrimSpace(value.SessionKey)
+ sessionKey = strings.TrimSpace(value.SessionKey)
default:
- return ""
+ return "", ""
}
+ if parts, _, err := domainsession.NormalizeSessionKey(sessionKey); err == nil {
+ sessionID := strings.TrimSpace(parts.ThreadRef)
+ if sessionID == "" {
+ sessionID = strings.TrimSpace(parts.PrimaryID)
+ }
+ if sessionID != "" {
+ return sessionID, sessionKey
+ }
+ }
+ return sessionKey, sessionKey
}
diff --git a/internal/application/gateway/approvals/publisher_test.go b/internal/application/gateway/approvals/publisher_test.go
index e5e0a4f..42cc0d6 100644
--- a/internal/application/gateway/approvals/publisher_test.go
+++ b/internal/application/gateway/approvals/publisher_test.go
@@ -6,6 +6,7 @@ import (
"time"
gatewayevents "dreamcreator/internal/application/gateway/events"
+ domainsession "dreamcreator/internal/domain/session"
)
func TestGatewayEventPublisherIncludesSessionInEnvelope(t *testing.T) {
@@ -18,9 +19,17 @@ func TestGatewayEventPublisherIncludesSessionInEnvelope(t *testing.T) {
})
defer unsubscribe()
+ sessionKey, err := domainsession.BuildSessionKey(domainsession.KeyParts{
+ Channel: "aui",
+ PrimaryID: "thread-123",
+ ThreadRef: "thread-123",
+ })
+ if err != nil {
+ t.Fatalf("build session key: %v", err)
+ }
request := Request{
ID: "approval-1",
- SessionKey: "thread-123",
+ SessionKey: sessionKey,
ToolName: "exec",
Action: "config.schema",
}
@@ -30,8 +39,8 @@ func TestGatewayEventPublisherIncludesSessionInEnvelope(t *testing.T) {
select {
case record := <-recordCh:
- if record.Envelope.SessionKey != "thread-123" {
- t.Fatalf("expected session key thread-123, got %q", record.Envelope.SessionKey)
+ if record.Envelope.SessionKey != sessionKey {
+ t.Fatalf("expected session key %q, got %q", sessionKey, record.Envelope.SessionKey)
}
if record.Envelope.SessionID != "thread-123" {
t.Fatalf("expected session id thread-123, got %q", record.Envelope.SessionID)
@@ -40,3 +49,36 @@ func TestGatewayEventPublisherIncludesSessionInEnvelope(t *testing.T) {
t.Fatal("expected approval event record")
}
}
+
+func TestGatewayEventPublisherFallsBackToRawSessionKey(t *testing.T) {
+ events := gatewayevents.NewBroker(nil)
+ publisher := NewGatewayEventPublisher(events)
+
+ recordCh := make(chan gatewayevents.Record, 1)
+ unsubscribe := events.Subscribe(gatewayevents.Filter{Type: "exec.approval.requested"}, func(record gatewayevents.Record) {
+ recordCh <- record
+ })
+ defer unsubscribe()
+
+ request := Request{
+ ID: "approval-2",
+ SessionKey: "raw-thread-id",
+ ToolName: "exec",
+ Action: "config.schema",
+ }
+ if err := publisher.Publish(context.Background(), "exec.approval.requested", request); err != nil {
+ t.Fatalf("publish approval event: %v", err)
+ }
+
+ select {
+ case record := <-recordCh:
+ if record.Envelope.SessionKey != "raw-thread-id" {
+ t.Fatalf("expected raw session key, got %q", record.Envelope.SessionKey)
+ }
+ if record.Envelope.SessionID != "raw-thread-id" {
+ t.Fatalf("expected raw session id, got %q", record.Envelope.SessionID)
+ }
+ case <-time.After(time.Second):
+ t.Fatal("expected approval event record")
+ }
+}
diff --git a/internal/application/gateway/runtime/prompt_build.go b/internal/application/gateway/runtime/prompt_build.go
index 1ffb13d..e5ffeab 100644
--- a/internal/application/gateway/runtime/prompt_build.go
+++ b/internal/application/gateway/runtime/prompt_build.go
@@ -305,7 +305,7 @@ func formatSkillsSection(skills []skillsdto.SkillPromptItem) string {
if len(skills) == 0 {
lines = append(lines,
"- (none currently eligible)",
- "- Use `skill_manage.search` to discover skills, `skill_manage.install` to install, then `skills.status` to refresh.",
+ "- Use `skills_manage.search` to discover skills, `skills_manage.install` to install, then `skills.status` to refresh.",
)
return joinLines(lines)
}
@@ -348,8 +348,8 @@ func skillsSectionPreambleLines() []string {
"Skills protocol:",
"- First scan the available skills list.",
"- If one skill clearly matches, read its `SKILL.md` before acting.",
- "- If no skill currently matches the task, call `skill_manage` instead of shell commands.",
- "- Recommended flow: `skills.status` -> `skill_manage.search`/`skill_manage.install` -> `skills.status`.",
+ "- If no skill currently matches the task, call `skills_manage` instead of shell commands.",
+ "- Recommended flow: `skills.status` -> `skills_manage.search`/`skills_manage.install` -> `skills.status`.",
"- If status reports missing runtime dependencies, call `skills.install`, then re-check with `skills.status`.",
"- Use `skills.update` for per-skill settings (enabled/apiKey/env/config) when required.",
"Available skills:",
diff --git a/internal/application/gateway/runtime/prompt_build_test.go b/internal/application/gateway/runtime/prompt_build_test.go
index 6c1a210..da91651 100644
--- a/internal/application/gateway/runtime/prompt_build_test.go
+++ b/internal/application/gateway/runtime/prompt_build_test.go
@@ -233,7 +233,7 @@ func TestBuildPromptDocument_FullIncludesToolingToolCallStyleAndSkillsMandatory(
if !strings.Contains(skills, "## Skills (mandatory)") {
t.Fatalf("expected mandatory skills heading, got %q", skills)
}
- if !strings.Contains(skills, "Recommended flow: `skills.status` -> `skill_manage.search`/`skill_manage.install` -> `skills.status`.") {
+ if !strings.Contains(skills, "Recommended flow: `skills.status` -> `skills_manage.search`/`skills_manage.install` -> `skills.status`.") {
t.Fatalf("expected skills tool protocol rule, got %q", skills)
}
}
diff --git a/internal/application/gateway/runtime/service.go b/internal/application/gateway/runtime/service.go
index 3fadc2f..43fe176 100644
--- a/internal/application/gateway/runtime/service.go
+++ b/internal/application/gateway/runtime/service.go
@@ -404,6 +404,13 @@ func (service *Service) runWithStream(ctx context.Context, request dto.RuntimeRu
defer service.aborts.Unregister(run.ID)
}
defer cancel()
+ if service.tools != nil && strings.TrimSpace(sessionKey) != "" {
+ defer func() {
+ cleanupCtx, cleanupCancel := context.WithTimeout(context.Background(), 5*time.Second)
+ defer cleanupCancel()
+ service.tools.CleanupRuntimeSession(cleanupCtx, sessionKey)
+ }()
+ }
if service.queue != nil && flags.UseQueue {
ticket, _, err := service.queue.Enqueue(runCtx, queue.EnqueueRequest{
diff --git a/internal/application/gateway/runtime/service_session_test.go b/internal/application/gateway/runtime/service_session_test.go
index 34e63a6..e54164a 100644
--- a/internal/application/gateway/runtime/service_session_test.go
+++ b/internal/application/gateway/runtime/service_session_test.go
@@ -9,6 +9,15 @@ import (
sessionapp "dreamcreator/internal/application/session"
)
+const (
+ testTelegramSessionPeerID = "test-user-001"
+ testTelegramSessionThreadID = "test-thread-001"
+ testTelegramSessionPeerName = "Test User"
+ testTelegramSessionPeerNameAlt = "Test User Updated"
+ testTelegramSessionUsername = "testuser"
+ testTelegramSessionAvatarURL = "https://example.com/avatars/test-user-001.jpg"
+)
+
func TestResolveSession_RebuildsCanonicalKeyForCustomChannelSessionKey(t *testing.T) {
t.Parallel()
@@ -18,16 +27,16 @@ func TestResolveSession_RebuildsCanonicalKeyForCustomChannelSessionKey(t *testin
}
request := runtimedto.RuntimeRunRequest{
- SessionID: "telegram:default:private:5234834060:conv:26291424",
- SessionKey: "telegram:default:private:5234834060:conv:26291424",
+ SessionID: "telegram:default:private:" + testTelegramSessionPeerID + ":conv:" + testTelegramSessionThreadID,
+ SessionKey: "telegram:default:private:" + testTelegramSessionPeerID + ":conv:" + testTelegramSessionThreadID,
Metadata: map[string]any{
"channel": "telegram",
"accountId": "default",
"peerKind": "direct",
- "peerId": "5234834060",
- "peerName": "Arnold",
- "peerUsername": "arnold",
- "peerAvatarUrl": "https://t.me/i/userpic/320/arnold.jpg",
+ "peerId": testTelegramSessionPeerID,
+ "peerName": testTelegramSessionPeerName,
+ "peerUsername": testTelegramSessionUsername,
+ "peerAvatarUrl": testTelegramSessionAvatarURL,
},
}
@@ -55,13 +64,13 @@ func TestResolveSession_RebuildsCanonicalKeyForCustomChannelSessionKey(t *testin
if stored.Origin.AccountID != "default" {
t.Fatalf("origin accountId mismatch: %q", stored.Origin.AccountID)
}
- if stored.Origin.PeerID != "5234834060" {
+ if stored.Origin.PeerID != testTelegramSessionPeerID {
t.Fatalf("origin peerId mismatch: %q", stored.Origin.PeerID)
}
- if stored.Origin.PeerName != "Arnold" {
+ if stored.Origin.PeerName != testTelegramSessionPeerName {
t.Fatalf("origin peerName mismatch: %q", stored.Origin.PeerName)
}
- if stored.Origin.PeerUsername != "arnold" {
+ if stored.Origin.PeerUsername != testTelegramSessionUsername {
t.Fatalf("origin peerUsername mismatch: %q", stored.Origin.PeerUsername)
}
if stored.Origin.PeerAvatarURL == "" {
@@ -78,15 +87,15 @@ func TestPersistSession_UpdatesAssistantIDForExistingSession(t *testing.T) {
}
request := runtimedto.RuntimeRunRequest{
- SessionID: "telegram:default:private:5234834060",
- SessionKey: "telegram:default:private:5234834060",
+ SessionID: "telegram:default:private:" + testTelegramSessionPeerID,
+ SessionKey: "telegram:default:private:" + testTelegramSessionPeerID,
Metadata: map[string]any{
"channel": "telegram",
"accountId": "default",
"peerKind": "direct",
- "peerId": "5234834060",
- "peerName": "Arnold HAO",
- "peerUsername": "arnold",
+ "peerId": testTelegramSessionPeerID,
+ "peerName": testTelegramSessionPeerNameAlt,
+ "peerUsername": testTelegramSessionUsername,
},
}
sessionID, sessionKey, err := runtimeService.resolveSession(request)
@@ -110,7 +119,7 @@ func TestPersistSession_UpdatesAssistantIDForExistingSession(t *testing.T) {
if after.AssistantID != "assistant-123" {
t.Fatalf("assistant id was not updated: got %q", after.AssistantID)
}
- if after.Origin.PeerName != "Arnold HAO" {
+ if after.Origin.PeerName != testTelegramSessionPeerNameAlt {
t.Fatalf("origin peer name mismatch: got %q", after.Origin.PeerName)
}
}
@@ -124,16 +133,16 @@ func TestPersistSession_PreservesExistingOriginWhenIncomingMetadataIsIncomplete(
}
request := runtimedto.RuntimeRunRequest{
- SessionID: "telegram:default:private:5234834060",
- SessionKey: "telegram:default:private:5234834060",
+ SessionID: "telegram:default:private:" + testTelegramSessionPeerID,
+ SessionKey: "telegram:default:private:" + testTelegramSessionPeerID,
Metadata: map[string]any{
"channel": "telegram",
"accountId": "default",
"peerKind": "direct",
- "peerId": "5234834060",
- "peerName": "Arnold HAO",
- "peerUsername": "arnold",
- "peerAvatarUrl": "https://t.me/i/userpic/320/arnold.jpg",
+ "peerId": testTelegramSessionPeerID,
+ "peerName": testTelegramSessionPeerNameAlt,
+ "peerUsername": testTelegramSessionUsername,
+ "peerAvatarUrl": testTelegramSessionAvatarURL,
},
}
sessionID, sessionKey, err := runtimeService.resolveSession(request)
@@ -157,10 +166,10 @@ func TestPersistSession_PreservesExistingOriginWhenIncomingMetadataIsIncomplete(
if stored.Origin.AccountID != "default" {
t.Fatalf("origin accountId should be preserved: got %q", stored.Origin.AccountID)
}
- if stored.Origin.PeerID != "5234834060" {
+ if stored.Origin.PeerID != testTelegramSessionPeerID {
t.Fatalf("origin peerId should be preserved: got %q", stored.Origin.PeerID)
}
- if stored.Origin.PeerName != "Arnold HAO" {
+ if stored.Origin.PeerName != testTelegramSessionPeerNameAlt {
t.Fatalf("origin peerName should be preserved: got %q", stored.Origin.PeerName)
}
if stored.Origin.PeerAvatarURL == "" {
diff --git a/internal/application/gateway/subagent/service.go b/internal/application/gateway/subagent/service.go
index 10a8d99..f6d0bdc 100644
--- a/internal/application/gateway/subagent/service.go
+++ b/internal/application/gateway/subagent/service.go
@@ -37,8 +37,7 @@ var subagentToolDenyAlways = []string{
"whatsapp_login",
"session_status",
"cron",
- "memory_recall",
- "memory_list",
+ "memory_query",
"sessions_send",
}
diff --git a/internal/application/gateway/tools/browser_tools.go b/internal/application/gateway/tools/browser_tools.go
index 27ee447..454ee79 100644
--- a/internal/application/gateway/tools/browser_tools.go
+++ b/internal/application/gateway/tools/browser_tools.go
@@ -2,217 +2,40 @@ package tools
import (
"context"
- "encoding/base64"
"encoding/json"
"errors"
"fmt"
- "io"
- "net"
- "net/http"
- "net/url"
- "os"
- "path/filepath"
- "reflect"
- "regexp"
- "sort"
- "strconv"
"strings"
- "sync"
- "sync/atomic"
"time"
+ "dreamcreator/internal/application/browsercdp"
+ appcookies "dreamcreator/internal/application/cookies"
gatewaynodes "dreamcreator/internal/application/gateway/nodes"
- "github.com/playwright-community/playwright-go"
)
-const (
- browserTypePlaywright = "playwright"
- defaultBrowserType = browserTypePlaywright
+var browserToolSessions = browsercdp.NewSessionRegistry()
- defaultBrowserWaitUntil = "domcontentloaded"
-
- defaultBrowserSnapshotModeEfficient = "efficient"
-
- defaultBrowserProfileDreamCreator = "dreamcreator"
- defaultBrowserColor = "#FF4500"
-
- defaultBrowserSnapshotAIMaxChars = 80000
- defaultBrowserSnapshotAIEfficientMaxChars = 10000
- defaultBrowserSnapshotDepth = 6
- defaultBrowserSnapshotLimit = 200
- defaultBrowserViewportWidth = 1366
- defaultBrowserViewportHeight = 900
-
- defaultBrowserHookTimeoutMs = 20000
-
- browserRuntimeCheckCacheTTL = 10 * time.Second
-)
-
-var browserToolActions = []string{
- "status",
- "start",
- "stop",
- "profiles",
- "tabs",
- "open",
- "focus",
- "close",
- "snapshot",
- "screenshot",
- "navigate",
- "console",
- "pdf",
- "upload",
- "dialog",
- "act",
-}
-
-var browserSelectorUnsupportedMessage = strings.Join([]string{
- "Error: 'selector' is not supported. Use 'ref' from snapshot instead.",
- "",
- "Example workflow:",
- "1. snapshot action to get page state with refs",
- `2. act with ref: "e123" to interact with element`,
- "",
- "This is more reliable for modern SPAs.",
-}, "\n")
-
-var browserWaitFnDisabledMessage = strings.Join([]string{
- "wait --fn is disabled by config (browser.evaluateEnabled=false).",
- "Docs: /gateway/configuration#browser-playwright-managed-browser",
-}, "\n")
-
-var browserEvaluateDisabledMessage = strings.Join([]string{
- "act:evaluate is disabled by config (browser.evaluateEnabled=false).",
- "Docs: /gateway/configuration#browser-playwright-managed-browser",
-}, "\n")
-
-var browserWaitRequiresConditionMessage = "wait requires at least one of: timeMs, text, textGone, selector, url, loadState, fn"
-
-var errBrowserSnapshotForAIUnavailable = errors.New("playwright snapshotForAI is unavailable")
-
-var browserActKinds = []string{
- "click",
- "type",
- "press",
- "hover",
- "drag",
- "select",
- "fill",
- "resize",
- "wait",
- "evaluate",
- "close",
+func cleanupBrowserToolSessions(sessionKey string) {
+ browserToolSessions.CloseSessionKey(strings.TrimSpace(sessionKey))
}
-var browserPlaywrightRuntimeCache = struct {
- mu sync.Mutex
- checkedAt time.Time
- available bool
- reason string
- execPath string
-}{}
-
-var browserGlobalTabCounter uint64
-
-var globalBrowserSessions = struct {
- mu sync.Mutex
- sessions map[string]*browserSessionState
-}{
- sessions: map[string]*browserSessionState{},
+func CleanupAllBrowserToolSessions() {
+ browserToolSessions.CloseAll()
}
-type browserSessionState struct {
- sessionKey string
- profiles map[string]*browserProfileState
+func BrowserToolRuntimeConfigChanged(previousTools map[string]any, currentTools map[string]any) bool {
+ previous := resolveBrowserRuntimeConfig(previousTools)
+ current := resolveBrowserRuntimeConfig(currentTools)
+ return previous.Enabled != current.Enabled ||
+ previous.Headless != current.Headless ||
+ previous.PreferredBrowser != current.PreferredBrowser
}
type browserProfileState struct {
- mu sync.Mutex
-
+ sessionKey string
profileName string
resolved browserResolvedConfig
- profile browserProfileConfig
-
- pw *playwright.Playwright
- browser playwright.Browser
- context playwright.BrowserContext
-
- tabs map[string]*browserTabState
- pageToTarget map[playwright.Page]string
- activeTarget string
-
- consoleMessages []browserConsoleMessage
- pendingUploads map[string]browserPendingUpload
- pendingDialogs map[string]browserPendingDialog
-}
-
-type browserTabState struct {
- TargetID string
- Page playwright.Page
-
- mu sync.RWMutex
- refs map[string]browserSnapshotRef
- evaluateResult any
-}
-
-type browserSnapshotRef struct {
- Selector string
- Role string
- Name string
- Nth int
- Mode string
- AriaRef string
- Frame string
-}
-
-type browserConsoleMessage struct {
- TargetID string `json:"targetId"`
- Type string `json:"type"`
- Text string `json:"text"`
- Timestamp string `json:"timestamp"`
-}
-
-type browserPendingUpload struct {
- Paths []string
- ExpiresAt time.Time
-}
-
-type browserPendingDialog struct {
- Accept bool
- PromptText string
- ExpiresAt time.Time
-}
-
-type browserResolvedConfig struct {
- Enabled bool
- EvaluateEnabled bool
- CDPURL string
- RemoteCdpTimeoutMs int
- RemoteCdpHandshakeTimeoutMs int
- Color string
- Headless bool
- NoSandbox bool
- AttachOnly bool
- DefaultProfile string
- Profiles map[string]browserProfileConfig
- SnapshotDefaultMode string
- SSRFRules browserSSRFPolicy
- ExtraArgs []string
-}
-
-type browserProfileConfig struct {
- Name string
- CDPURL string
- CDPPort int
- Color string
- Driver string
-}
-
-type browserSSRFPolicy struct {
- DangerouslyAllowPrivateNetwork bool
- AllowedHostnames map[string]struct{}
- HostnameAllowlist []string
+ session *browsercdp.Session
}
func runBrowserTool(settings SettingsReader, connectors ConnectorsReader, nodes *gatewaynodes.Service) func(ctx context.Context, args string) (string, error) {
@@ -221,7 +44,6 @@ func runBrowserTool(settings SettingsReader, connectors ConnectorsReader, nodes
if err != nil {
return "", err
}
-
action, err := resolveBrowserAction(payload)
if err != nil {
return "", err
@@ -229,2894 +51,320 @@ func runBrowserTool(settings SettingsReader, connectors ConnectorsReader, nodes
if isBrowserNodeTargetRequest(payload) {
return runBrowserActionOnNode(ctx, payload, action, nodes)
}
-
toolsConfig := resolveToolsConfig(ctx, settings)
resolved := resolveBrowserRuntimeConfig(toolsConfig)
if !resolved.Enabled {
return "", errors.New("browser disabled")
}
-
profileName := resolveBrowserProfileName(payload, resolved)
sessionKey := resolveBrowserSessionKey(ctx, payload)
- state := getOrCreateBrowserProfileState(sessionKey, profileName, resolved)
-
- result, err := runBrowserAction(ctx, payload, action, state, connectors)
+ state := getBrowserProfileState(sessionKey, profileName, resolved, connectors)
+ result, err := runBrowserAction(ctx, payload, action, state)
if err != nil {
+ if browsercdp.IsFatalError(err) {
+ return "", fmt.Errorf("browser session reset after runtime failure: %w", err)
+ }
return "", err
}
return marshalResult(result), nil
}
}
-func resolveBrowserAction(payload toolArgs) (string, error) {
- rawAction := strings.ToLower(strings.TrimSpace(getStringArg(payload, "action", "method")))
- if rawAction == "" {
- if getStringArg(payload, "targetUrl", "url") != "" {
- rawAction = "open"
- } else {
- rawAction = "status"
- }
- }
- switch rawAction {
- case "navigate", "status", "start", "stop", "profiles", "tabs", "open", "focus", "close", "snapshot", "screenshot", "console", "pdf", "upload", "dialog", "act":
- return rawAction, nil
- default:
- return "", errors.New("browser action not supported: " + rawAction)
- }
-}
-
-func isBrowserNodeTargetRequest(payload toolArgs) bool {
- target := strings.ToLower(strings.TrimSpace(getStringArg(payload, "target")))
- nodeID := strings.TrimSpace(getStringArg(payload, "node", "nodeId"))
- if target == "node" {
- return true
- }
- return nodeID != ""
-}
-
-func resolveBrowserNodeID(ctx context.Context, payload toolArgs, nodes *gatewaynodes.Service) (string, error) {
- requestedNode := strings.TrimSpace(getStringArg(payload, "node", "nodeId"))
- if requestedNode != "" {
- return requestedNode, nil
- }
- if nodes == nil {
- return "", errors.New("nodes service unavailable")
- }
- list, err := nodes.ListNodes(ctx)
- if err != nil {
- return "", err
- }
- for _, descriptor := range list {
- nodeID := strings.TrimSpace(descriptor.NodeID)
- if nodeID == "" {
- continue
- }
- for _, capability := range descriptor.Capabilities {
- if strings.EqualFold(strings.TrimSpace(capability.Name), "browser.control") {
- return nodeID, nil
- }
- }
- }
- for _, descriptor := range list {
- nodeID := strings.TrimSpace(descriptor.NodeID)
- if nodeID != "" {
- return nodeID, nil
- }
- }
- return "", errors.New("nodeId is required")
-}
-
-func runBrowserActionOnNode(ctx context.Context, payload toolArgs, action string, nodes *gatewaynodes.Service) (string, error) {
- if nodes == nil {
- return "", errors.New("nodes service unavailable")
- }
- target := strings.ToLower(strings.TrimSpace(getStringArg(payload, "target")))
- if target != "" && target != "node" {
- return "", errors.New(`node is only supported with target="node"`)
- }
- nodeID, err := resolveBrowserNodeID(ctx, payload, nodes)
- if err != nil {
- return "", err
- }
- argsJSON, err := json.Marshal(payload)
- if err != nil {
- return "", err
- }
- request := gatewaynodes.NodeInvokeRequest{
- NodeID: nodeID,
- Capability: "browser.control",
- Action: action,
- Args: string(argsJSON),
- TimeoutMs: resolveBrowserActionTimeoutMs(payload, 30000),
- }
- result, invokeErr := nodes.Invoke(ctx, request)
- if invokeErr != nil {
- return marshalResult(result), invokeErr
- }
- if !result.Ok {
- if strings.TrimSpace(result.Error) != "" {
- return marshalResult(result), errors.New(strings.TrimSpace(result.Error))
- }
- return marshalResult(result), errors.New("node browser invoke failed")
- }
- if parsed := resolveBrowserNodeOutput(result.Output); parsed != nil {
- return marshalResult(parsed), nil
- }
- return marshalResult(result), nil
-}
-
-type browserNodeProxyEnvelope struct {
- Result any `json:"result"`
- Files []browserNodeProxyFile `json:"files"`
-}
-
-type browserNodeProxyFile struct {
- Path string `json:"path"`
- Base64 string `json:"base64"`
- MimeType string `json:"mimeType"`
-}
-
-func resolveBrowserNodeOutput(output string) any {
- trimmedOutput := strings.TrimSpace(output)
- if trimmedOutput == "" {
- return nil
- }
- var parsed any
- if err := json.Unmarshal([]byte(trimmedOutput), &parsed); err != nil {
- return nil
- }
-
- envelope := browserNodeProxyEnvelope{}
- if err := json.Unmarshal([]byte(trimmedOutput), &envelope); err == nil && envelope.Result != nil {
- mapping := persistBrowserNodeProxyFiles(envelope.Files)
- applyBrowserProxyPathMapping(envelope.Result, mapping)
- return envelope.Result
- }
- return parsed
-}
-
-func persistBrowserNodeProxyFiles(files []browserNodeProxyFile) map[string]string {
- if len(files) == 0 {
- return nil
- }
- mapping := map[string]string{}
- for _, file := range files {
- remotePath := strings.TrimSpace(file.Path)
- encoded := strings.TrimSpace(file.Base64)
- if remotePath == "" || encoded == "" {
- continue
- }
- bytes, err := base64.StdEncoding.DecodeString(encoded)
- if err != nil {
- continue
- }
- localPath, err := saveBrowserArtifact(resolveBrowserProxyFileExt(file), bytes)
- if err != nil {
- continue
- }
- mapping[remotePath] = localPath
- }
- if len(mapping) == 0 {
- return nil
- }
- return mapping
-}
-
-func resolveBrowserProxyFileExt(file browserNodeProxyFile) string {
- ext := strings.TrimSpace(strings.TrimPrefix(filepath.Ext(strings.TrimSpace(file.Path)), "."))
- if ext != "" {
- return ext
- }
- mimeType := strings.ToLower(strings.TrimSpace(file.MimeType))
- switch {
- case strings.Contains(mimeType, "png"):
- return "png"
- case strings.Contains(mimeType, "jpeg"), strings.Contains(mimeType, "jpg"):
- return "jpg"
- case strings.Contains(mimeType, "pdf"):
- return "pdf"
- case strings.Contains(mimeType, "json"):
- return "json"
- case strings.Contains(mimeType, "text"), strings.Contains(mimeType, "plain"):
- return "txt"
- default:
- return "bin"
- }
-}
-
-func applyBrowserProxyPathMapping(result any, mapping map[string]string) {
- if len(mapping) == 0 || result == nil {
- return
- }
- obj, ok := result.(map[string]any)
- if !ok {
- return
- }
- if pathValue, ok := obj["path"].(string); ok {
- if mapped, exists := mapping[pathValue]; exists {
- obj["path"] = mapped
- }
- }
- if imagePathValue, ok := obj["imagePath"].(string); ok {
- if mapped, exists := mapping[imagePathValue]; exists {
- obj["imagePath"] = mapped
- }
+func getBrowserProfileState(sessionKey string, profileName string, resolved browserResolvedConfig, connectors ConnectorsReader) *browserProfileState {
+ options := browsercdp.SessionOptions{
+ SessionKey: sessionKey,
+ ProfileName: profileName,
+ PreferredBrowser: resolved.PreferredBrowser,
+ Headless: resolved.Headless,
+ UserDataDir: browsercdp.ResolveProfileUserDataDir(sessionKey, profileName),
+ SSRFRules: browsercdp.SSRFPolicy{
+ DangerouslyAllowPrivateNetwork: resolved.SSRFRules.DangerouslyAllowPrivateNetwork,
+ AllowedHostnames: cloneBrowserAllowedHostnames(resolved.SSRFRules.AllowedHostnames),
+ HostnameAllowlist: append([]string(nil), resolved.SSRFRules.HostnameAllowlist...),
+ },
+ Cookies: browsercdp.ConnectorCookieProviderFunc(func(ctx context.Context, rawURL string) ([]appcookies.Record, error) {
+ return browsercdp.ResolveConnectorCookiesForURL(ctx, connectors, rawURL)
+ }),
}
- if downloadRaw, exists := obj["download"]; exists {
- if downloadObj, ok := downloadRaw.(map[string]any); ok {
- if pathValue, ok := downloadObj["path"].(string); ok {
- if mapped, exists := mapping[pathValue]; exists {
- downloadObj["path"] = mapped
- }
- }
- }
+ return &browserProfileState{
+ sessionKey: sessionKey,
+ profileName: profileName,
+ resolved: resolved,
+ session: browserToolSessions.GetOrCreate(sessionKey, profileName, options),
}
}
-func resolveBrowserSessionKey(ctx context.Context, payload toolArgs) string {
- sessionKey, _ := RuntimeContextFromContext(ctx)
- sessionKey = strings.TrimSpace(sessionKey)
- if sessionKey == "" {
- sessionKey = strings.TrimSpace(getStringArg(payload, "sessionKey", "session_key"))
+func cloneBrowserAllowedHostnames(values map[string]struct{}) map[string]struct{} {
+ if len(values) == 0 {
+ return map[string]struct{}{}
}
- if sessionKey == "" {
- sessionKey = "default"
+ cloned := make(map[string]struct{}, len(values))
+ for key := range values {
+ cloned[key] = struct{}{}
}
- return sessionKey
+ return cloned
}
-func runBrowserAction(
- ctx context.Context,
- payload toolArgs,
- action string,
- state *browserProfileState,
- connectors ConnectorsReader,
-) (any, error) {
+func runBrowserAction(ctx context.Context, payload toolArgs, action string, state *browserProfileState) (map[string]any, error) {
switch action {
- case "status":
- return browserActionStatus(state)
- case "start":
- if err := ensureBrowserProfileStarted(state); err != nil {
- return nil, err
- }
- return browserActionStatus(state)
- case "stop":
- if err := stopBrowserProfile(state); err != nil {
- return nil, err
- }
- return browserActionStatus(state)
- case "profiles":
- return browserActionProfiles(state), nil
- case "tabs":
- return browserActionTabs(state)
case "open":
- return browserActionOpen(ctx, payload, state, connectors)
- case "focus":
- return browserActionFocus(payload, state)
- case "close":
- return browserActionClose(payload, state)
+ return browserActionOpen(ctx, payload, state)
+ case "navigate":
+ return browserActionNavigate(ctx, payload, state)
case "snapshot":
return browserActionSnapshot(payload, state)
- case "screenshot":
- return browserActionScreenshot(payload, state)
- case "navigate":
- return browserActionNavigate(ctx, payload, state, connectors)
- case "console":
- return browserActionConsole(payload, state)
- case "pdf":
- return browserActionPDF(payload, state)
+ case "wait":
+ return browserActionWait(ctx, payload, state)
+ case "scroll":
+ return browserActionScroll(payload, state)
case "upload":
return browserActionUpload(payload, state)
case "dialog":
return browserActionDialog(payload, state)
case "act":
- return browserActionAct(payload, state)
+ return browserActionAct(ctx, payload, state)
+ case "reset":
+ return browserActionReset(payload, state)
default:
return nil, errors.New("browser action not supported: " + action)
}
}
-func browserActionStatus(state *browserProfileState) (map[string]any, error) {
- available, reason, execPath := resolveBrowserPlaywrightRuntimeAvailability()
-
- state.mu.Lock()
- defer state.mu.Unlock()
- pruneClosedTabsLocked(state)
-
- running := state.browser != nil && state.context != nil
- detectedPath := any(nil)
- if strings.TrimSpace(execPath) != "" {
- detectedPath = strings.TrimSpace(execPath)
- }
- detectError := any(nil)
- if !available && strings.TrimSpace(reason) != "" {
- detectError = strings.TrimSpace(reason)
- }
- chosenBrowser := any(nil)
- if running {
- chosenBrowser = "chromium"
- }
-
- return map[string]any{
- "enabled": state.resolved.Enabled,
- "profile": state.profileName,
- "running": running,
- "cdpReady": false,
- "cdpHttp": false,
- "pid": nil,
- "cdpPort": nil,
- "cdpUrl": nil,
- "chosenBrowser": chosenBrowser,
- "detectedBrowser": "chromium",
- "detectedExecutablePath": detectedPath,
- "detectError": detectError,
- "userDataDir": nil,
- "color": state.profile.Color,
- "headless": state.resolved.Headless,
- "noSandbox": state.resolved.NoSandbox,
- "executablePath": detectedPath,
- "attachOnly": false,
- "tabCount": len(state.tabs),
- "activeTargetId": state.activeTarget,
- }, nil
-}
-
-func browserActionProfiles(state *browserProfileState) map[string]any {
- state.mu.Lock()
- defer state.mu.Unlock()
- pruneClosedTabsLocked(state)
-
- names := make([]string, 0, len(state.resolved.Profiles))
- for name := range state.resolved.Profiles {
- names = append(names, name)
- }
- sort.Strings(names)
-
- profiles := make([]map[string]any, 0, len(names))
- for _, name := range names {
- cfg := state.resolved.Profiles[name]
- isCurrent := name == state.profileName
- running := isCurrent && state.browser != nil && state.context != nil
- tabCount := 0
- if isCurrent {
- tabCount = len(state.tabs)
- }
- profiles = append(profiles, map[string]any{
- "name": name,
- "cdpPort": cfg.CDPPort,
- "cdpUrl": cfg.CDPURL,
- "color": cfg.Color,
- "running": running,
- "tabCount": tabCount,
- "isDefault": name == state.resolved.DefaultProfile,
- "isRemote": cfg.CDPURL != "",
- })
- }
-
- return map[string]any{"profiles": profiles}
-}
-
-func browserActionTabs(state *browserProfileState) (map[string]any, error) {
- state.mu.Lock()
- running := state.browser != nil && state.context != nil
- state.mu.Unlock()
- if !running {
- return map[string]any{"running": false, "tabs": []any{}}, nil
- }
-
- tabs, err := listBrowserTabs(state)
+func browserActionOpen(ctx context.Context, payload toolArgs, state *browserProfileState) (map[string]any, error) {
+ result, err := state.session.Open(ctx, strings.TrimSpace(getStringArg(payload, "targetUrl", "url")), browserCommandOptions(payload, 30000))
if err != nil {
return nil, err
}
- return map[string]any{"running": true, "tabs": tabs}, nil
+ return browserResultMap(result), nil
}
-func browserActionOpen(ctx context.Context, payload toolArgs, state *browserProfileState, connectors ConnectorsReader) (map[string]any, error) {
- targetURL := getStringArg(payload, "targetUrl", "url")
- if targetURL == "" {
- return nil, errors.New("targetUrl is required")
- }
- if err := assertBrowserURLAllowed(targetURL, state.resolved.SSRFRules); err != nil {
- return nil, err
- }
- if err := ensureBrowserProfileStarted(state); err != nil {
- return nil, err
- }
-
- state.mu.Lock()
- browserCtx := state.context
- state.mu.Unlock()
- if browserCtx == nil {
- return nil, errors.New("browser context unavailable")
- }
-
- if err := addConnectorCookiesToContext(ctx, connectors, browserCtx, targetURL); err != nil {
+func browserActionNavigate(ctx context.Context, payload toolArgs, state *browserProfileState) (map[string]any, error) {
+ newTab, _ := getBoolArg(payload, "newTab")
+ result, err := state.session.Navigate(
+ ctx,
+ strings.TrimSpace(getStringArg(payload, "targetId")),
+ strings.TrimSpace(getStringArg(payload, "targetUrl", "url")),
+ newTab,
+ browserCommandOptions(payload, 30000),
+ )
+ if err != nil {
return nil, err
}
+ return browserResultMap(result), nil
+}
- page, err := browserCtx.NewPage()
+func browserActionSnapshot(payload toolArgs, state *browserProfileState) (map[string]any, error) {
+ result, err := state.session.State(strings.TrimSpace(getStringArg(payload, "targetId")), resolveBrowserSnapshotLimit(payload))
if err != nil {
return nil, err
}
- if _, err := page.Goto(strings.TrimSpace(targetURL), playwright.PageGotoOptions{
- Timeout: playwright.Float(float64(resolveBrowserActionTimeoutMs(payload, 30000))),
- WaitUntil: resolveBrowserWaitUntilState(resolveBrowserWaitUntil(payload, defaultBrowserWaitUntil)),
- }); err != nil {
- _ = page.Close()
- return nil, err
- }
+ return browserResultMap(result), nil
+}
- tab := attachBrowserTab(state, page)
- title, _ := page.Title()
- urlValue := strings.TrimSpace(page.URL())
- if urlValue == "" {
- urlValue = strings.TrimSpace(targetURL)
- }
- if err := assertBrowserURLAllowed(urlValue, state.resolved.SSRFRules); err != nil {
+func browserActionWait(ctx context.Context, payload toolArgs, state *browserProfileState) (map[string]any, error) {
+ result, err := state.session.Wait(
+ ctx,
+ strings.TrimSpace(getStringArg(payload, "targetId")),
+ browserWaitRequestFromArgs(payload),
+ browserCommandOptions(payload, 15000),
+ )
+ if err != nil {
return nil, err
}
-
- return map[string]any{
- "targetId": tab.TargetID,
- "title": strings.TrimSpace(title),
- "url": urlValue,
- "type": "page",
- }, nil
+ return browserResultMap(result), nil
}
-func browserActionFocus(payload toolArgs, state *browserProfileState) (map[string]any, error) {
- targetID := strings.TrimSpace(getStringArg(payload, "targetId"))
- if targetID == "" {
- return nil, errors.New("targetId is required")
- }
- tab, err := resolveBrowserTab(state, targetID, false)
+func browserActionScroll(payload toolArgs, state *browserProfileState) (map[string]any, error) {
+ deltaX, deltaY := resolveBrowserScrollDelta(payload)
+ result, err := state.session.Scroll(browsercdp.ScrollRequest{
+ TargetID: strings.TrimSpace(getStringArg(payload, "targetId")),
+ Ref: strings.TrimSpace(getStringArg(payload, "ref")),
+ DeltaX: deltaX,
+ DeltaY: deltaY,
+ Limit: resolveBrowserSnapshotLimit(payload),
+ Timeout: browserTimeoutDuration(payload, 15000),
+ })
if err != nil {
return nil, err
}
- if err := tab.Page.BringToFront(); err != nil {
- return nil, err
- }
- return map[string]any{"ok": true, "targetId": tab.TargetID}, nil
+ return browserResultMap(result), nil
}
-func browserActionClose(payload toolArgs, state *browserProfileState) (map[string]any, error) {
- targetID := strings.TrimSpace(getStringArg(payload, "targetId"))
- tab, err := resolveBrowserTab(state, targetID, false)
+func browserActionUpload(payload toolArgs, state *browserProfileState) (map[string]any, error) {
+ rawPaths := getStringSliceArg(payload, "paths")
+ rootDir, err := resolveBrowserUploadRootDir()
if err != nil {
return nil, err
}
- if err := tab.Page.Close(); err != nil {
- return nil, err
+ paths := make([]string, 0, len(rawPaths))
+ for _, item := range rawPaths {
+ resolvedPath, err := resolvePathWithinRoot(rootDir, item)
+ if err != nil {
+ return nil, err
+ }
+ paths = append(paths, resolvedPath)
}
-
- state.mu.Lock()
- delete(state.tabs, tab.TargetID)
- delete(state.pageToTarget, tab.Page)
- if state.activeTarget == tab.TargetID {
- state.activeTarget = ""
+ result, err := state.session.Upload(browsercdp.UploadRequest{
+ TargetID: strings.TrimSpace(getStringArg(payload, "targetId")),
+ Ref: strings.TrimSpace(getStringArg(payload, "ref")),
+ Paths: paths,
+ Limit: resolveBrowserSnapshotLimit(payload),
+ Timeout: browserTimeoutDuration(payload, 15000),
+ })
+ if err != nil {
+ return nil, err
}
- pruneClosedTabsLocked(state)
- state.mu.Unlock()
-
- return map[string]any{"ok": true, "targetId": tab.TargetID}, nil
+ return browserResultMap(result), nil
}
-func browserActionSnapshot(payload toolArgs, state *browserProfileState) (map[string]any, error) {
- if err := ensureBrowserProfileStarted(state); err != nil {
- return nil, err
+func browserActionDialog(payload toolArgs, state *browserProfileState) (map[string]any, error) {
+ var accept *bool
+ if value, ok := getBoolArg(payload, "accept"); ok {
+ accept = &value
}
- targetID := strings.TrimSpace(getStringArg(payload, "targetId"))
- tab, err := resolveBrowserTab(state, targetID, true)
+ result, err := state.session.Dialog(browsercdp.DialogRequest{
+ TargetID: strings.TrimSpace(getStringArg(payload, "targetId")),
+ Accept: accept,
+ PromptText: strings.TrimSpace(getStringArg(payload, "promptText")),
+ Limit: resolveBrowserSnapshotLimit(payload),
+ Timeout: browserTimeoutDuration(payload, 15000),
+ })
if err != nil {
return nil, err
}
+ return browserResultMap(result), nil
+}
- format := strings.ToLower(strings.TrimSpace(getStringArg(payload, "snapshotFormat", "format")))
- if format != "aria" {
- format = "ai"
- }
- refsMode := strings.ToLower(strings.TrimSpace(getStringArg(payload, "refs")))
- if refsMode != "aria" {
- refsMode = "role"
- }
- mode := strings.ToLower(strings.TrimSpace(getStringArg(payload, "mode")))
- if mode == "" {
- mode = strings.TrimSpace(state.resolved.SnapshotDefaultMode)
- }
- if mode != defaultBrowserSnapshotModeEfficient {
- mode = ""
- }
- labels, _ := getBoolArg(payload, "labels")
- if format == "aria" && (labels || mode == defaultBrowserSnapshotModeEfficient) {
- return nil, errors.New("labels/mode=efficient require format=ai")
- }
-
- interactive, hasInteractive := getBoolArg(payload, "interactive")
- if !hasInteractive {
- interactive = mode == defaultBrowserSnapshotModeEfficient
- }
- compact, hasCompact := getBoolArg(payload, "compact")
- if !hasCompact {
- compact = mode == defaultBrowserSnapshotModeEfficient
+func browserActionAct(ctx context.Context, payload toolArgs, state *browserProfileState) (map[string]any, error) {
+ request := getMapArg(payload, "request")
+ if request == nil {
+ return nil, errors.New("request is required")
}
-
- depth, hasDepth := getIntArg(payload, "depth")
- if hasDepth && depth <= 0 {
- hasDepth = false
+ requestArgs := toolArgs(request)
+ kind := strings.ToLower(strings.TrimSpace(getStringArg(requestArgs, "kind")))
+ if kind == "" {
+ return nil, errors.New("request.kind is required")
}
- if !hasDepth && mode == defaultBrowserSnapshotModeEfficient {
- depth = defaultBrowserSnapshotDepth
- hasDepth = true
+ if !containsString(browserActKinds, kind) {
+ return nil, errors.New("act kind not supported: " + kind)
}
-
- limit, ok := getIntArg(payload, "limit")
- if !ok || limit <= 0 {
- limit = defaultBrowserSnapshotLimit
+ if _, hasSelector := request["selector"]; hasSelector && kind != "wait" {
+ return nil, errors.New(browserSelectorUnsupportedMessage)
}
-
- maxChars := 0
- _, hasMaxChars := payload["maxChars"]
- if value, ok := getIntArg(payload, "maxChars"); ok && value > 0 {
- maxChars = value
+ if state == nil || state.session == nil {
+ return nil, errors.New("browser session unavailable")
}
- if format == "ai" && !hasMaxChars {
- if mode == defaultBrowserSnapshotModeEfficient {
- maxChars = defaultBrowserSnapshotAIEfficientMaxChars
- } else {
- maxChars = defaultBrowserSnapshotAIMaxChars
- }
+ actRequest := browsercdp.ActRequest{
+ Kind: kind,
+ TargetID: firstNonEmptyString(strings.TrimSpace(getStringArg(requestArgs, "targetId")), strings.TrimSpace(getStringArg(payload, "targetId"))),
+ Ref: strings.TrimSpace(getStringArg(requestArgs, "ref")),
+ Text: getStringArg(requestArgs, "text"),
+ Key: strings.TrimSpace(getStringArg(requestArgs, "key")),
+ Value: getStringArg(requestArgs, "value"),
+ Expression: getStringArg(requestArgs, "expression"),
+ Limit: resolveBrowserSnapshotLimit(payload),
+ Timeout: browserActTimeoutDuration(requestArgs, payload, 15000),
+ Wait: browserWaitRequestFromArgs(requestArgs),
+ WaitFor: browserOptionalWaitRequest(getMapArg(requestArgs, "waitFor")),
}
-
- selector := strings.TrimSpace(getStringArg(payload, "selector"))
- frameSelector := strings.TrimSpace(getStringArg(payload, "frame"))
- if refsMode == "aria" && (selector != "" || frameSelector != "") {
- return nil, errors.New("refs=aria does not support selector/frame snapshots yet")
+ if width, ok := getIntArg(requestArgs, "width"); ok {
+ actRequest.Width = width
}
-
- maxDepth := 0
- if hasDepth {
- maxDepth = depth
+ if height, ok := getIntArg(requestArgs, "height"); ok {
+ actRequest.Height = height
}
- wantsRoleSnapshot := labels ||
- mode == defaultBrowserSnapshotModeEfficient ||
- interactive ||
- compact ||
- hasDepth ||
- selector != "" ||
- frameSelector != ""
- if format == "ai" && !wantsRoleSnapshot {
- aiResult, aiErr := collectBrowserAISnapshot(tab.Page, maxChars)
- if aiErr == nil {
- state.mu.Lock()
- tab.refs = aiResult.Refs
- state.mu.Unlock()
-
- result := map[string]any{
- "ok": true,
- "format": "ai",
- "targetId": tab.TargetID,
- "url": strings.TrimSpace(tab.Page.URL()),
- "snapshot": aiResult.Snapshot,
- "truncated": aiResult.Truncated,
- "refs": aiResult.RefsJSON,
- "stats": aiResult.Stats,
- }
- return result, nil
- }
- if !isBrowserSnapshotForAIUnavailable(aiErr) {
- return nil, aiErr
- }
+ result, err := state.session.Act(ctx, actRequest)
+ if err != nil {
+ return nil, err
}
+ return browserResultMap(result), nil
+}
- items, err := collectBrowserSnapshotItems(tab.Page, selector, frameSelector, interactive, limit, refsMode, maxDepth)
+func browserActionReset(payload toolArgs, state *browserProfileState) (map[string]any, error) {
+ restart, _ := getBoolArg(payload, "restart")
+ result, err := state.session.Reset(restart)
if err != nil {
return nil, err
}
+ output := browserResultMap(result)
+ output["profile"] = state.profileName
+ return output, nil
+}
- refs := make(map[string]browserSnapshotRef, len(items))
- refsJSON := make(map[string]map[string]any, len(items))
- lines := make([]string, 0, len(items))
- nodes := make([]map[string]any, 0, len(items))
- interactiveCount := 0
- for index, item := range items {
- if index >= limit {
- break
- }
- ref := strings.TrimSpace(item.Ref)
- if ref == "" {
- ref = fmt.Sprintf("e%d", index+1)
- }
- entryMode := refsMode
- if entryMode == "aria" && strings.TrimSpace(item.AriaRef) == "" {
- entryMode = "role"
- }
- entry := browserSnapshotRef{
- Role: item.Role,
- Name: item.Name,
- Nth: item.Nth,
- Mode: entryMode,
- AriaRef: item.AriaRef,
- Frame: frameSelector,
- }
- refs[ref] = entry
- refsJSON[ref] = map[string]any{
- "role": entry.Role,
- "name": entry.Name,
- }
- if entry.Nth > 0 {
- refsJSON[ref]["nth"] = entry.Nth
- }
- if isBrowserInteractiveRole(entry.Role) {
- interactiveCount += 1
- }
- line := fmt.Sprintf("[%s] role=%s name=%s", ref, entry.Role, entry.Name)
- if !compact && strings.TrimSpace(item.Text) != "" {
- line += " text=" + trimToMaxChars(item.Text, 120)
- }
- lines = append(lines, line)
- nodeDepth := item.Depth
- if maxDepth > 0 {
- nodeDepth = minInt(nodeDepth, maxDepth)
- }
- nodes = append(nodes, map[string]any{
- "ref": ref,
- "role": entry.Role,
- "name": entry.Name,
- "depth": nodeDepth,
- })
+func browserCommandOptions(payload toolArgs, fallbackTimeoutMs int) browsercdp.CommandOptions {
+ return browsercdp.CommandOptions{
+ Limit: resolveBrowserSnapshotLimit(payload),
+ Timeout: browserTimeoutDuration(payload, fallbackTimeoutMs),
+ WaitFor: browserOptionalWaitRequest(getMapArg(payload, "waitFor")),
}
+}
- state.mu.Lock()
- tab.refs = refs
- state.mu.Unlock()
-
- if format == "aria" {
- result := map[string]any{
- "ok": true,
- "format": "aria",
- "targetId": tab.TargetID,
- "url": strings.TrimSpace(tab.Page.URL()),
- "nodes": nodes,
- }
- return result, nil
+func browserOptionalWaitRequest(raw map[string]any) *browsercdp.WaitRequest {
+ if raw == nil {
+ return nil
}
-
- snapshot := strings.Join(lines, "\n")
- truncated := false
- if maxChars > 0 && len(snapshot) > maxChars {
- snapshot = trimToMaxChars(snapshot, maxChars)
- truncated = true
+ request := browserWaitRequestFromArgs(toolArgs(raw))
+ if browserWaitRequestEmpty(request) {
+ return nil
}
+ return &request
+}
- result := map[string]any{
- "ok": true,
- "format": "ai",
- "targetId": tab.TargetID,
- "url": strings.TrimSpace(tab.Page.URL()),
- "snapshot": snapshot,
- "truncated": truncated,
- "refs": refsJSON,
- "stats": map[string]any{
- "lines": len(lines),
- "chars": len(snapshot),
- "refs": len(refsJSON),
- "interactive": interactiveCount,
- },
+func browserWaitRequestFromArgs(args toolArgs) browsercdp.WaitRequest {
+ request := browsercdp.WaitRequest{
+ Selector: strings.TrimSpace(getStringArg(args, "selector")),
+ Text: strings.TrimSpace(getStringArg(args, "text")),
+ TextGone: strings.TrimSpace(getStringArg(args, "textGone")),
+ URL: strings.TrimSpace(getStringArg(args, "url")),
+ Fn: strings.TrimSpace(getStringArg(args, "fn")),
}
-
- if labels {
- img, err := tab.Page.Screenshot(playwright.PageScreenshotOptions{Type: playwright.ScreenshotTypePng})
- if err == nil {
- path, writeErr := saveBrowserArtifact("png", img)
- if writeErr == nil {
- result["labels"] = true
- result["imagePath"] = path
- result["imageType"] = "png"
- }
- }
+ if timeMs, ok := getIntArg(args, "timeMs"); ok && timeMs > 0 {
+ request.Time = time.Duration(timeMs) * time.Millisecond
}
+ if timeoutMs, ok := getIntArg(args, "timeoutMs"); ok && timeoutMs > 0 {
+ request.Timeout = time.Duration(timeoutMs) * time.Millisecond
+ }
+ return request
+}
- return result, nil
+func browserWaitRequestEmpty(request browsercdp.WaitRequest) bool {
+ return request.Time <= 0 &&
+ request.Selector == "" &&
+ request.Text == "" &&
+ request.TextGone == "" &&
+ request.URL == "" &&
+ request.Fn == ""
}
-func browserActionScreenshot(payload toolArgs, state *browserProfileState) (map[string]any, error) {
- if err := ensureBrowserProfileStarted(state); err != nil {
- return nil, err
- }
- targetID := strings.TrimSpace(getStringArg(payload, "targetId"))
- tab, err := resolveBrowserTab(state, targetID, true)
- if err != nil {
- return nil, err
- }
+func browserTimeoutDuration(payload toolArgs, fallbackTimeoutMs int) time.Duration {
+ return time.Duration(resolveBrowserActionTimeoutMs(payload, fallbackTimeoutMs)) * time.Millisecond
+}
- imageType := strings.ToLower(strings.TrimSpace(getStringArg(payload, "type")))
- if imageType != "jpeg" {
- imageType = "png"
- }
+func browserActTimeoutDuration(request toolArgs, payload toolArgs, fallbackTimeoutMs int) time.Duration {
+ return time.Duration(resolveBrowserActTimeoutMs(request, payload, fallbackTimeoutMs)) * time.Millisecond
+}
- fullPage, _ := getBoolArg(payload, "fullPage")
- ref := strings.TrimSpace(getStringArg(payload, "ref"))
- element := strings.TrimSpace(getStringArg(payload, "element"))
- if fullPage && (ref != "" || element != "") {
- return nil, errors.New("fullPage is not supported for element screenshots")
- }
+func browserResultMap(result browsercdp.ActionResult) map[string]any {
+ data, _ := json.Marshal(result)
+ decoded := map[string]any{}
+ _ = json.Unmarshal(data, &decoded)
+ decoded["stateAvailable"] = result.State != nil || result.StateAvailable
+ decoded["itemCount"] = browserResultItemCount(result)
+ return decoded
+}
- var bytes []byte
- if ref != "" {
- locator, err := resolveBrowserRefLocator(tab, ref)
- if err != nil {
- return nil, err
- }
- bytes, err = locator.Screenshot(playwright.LocatorScreenshotOptions{
- Type: toPlaywrightScreenshotType(imageType),
- Timeout: playwright.Float(float64(resolveBrowserActionTimeoutMs(payload, 15000))),
- })
- if err != nil {
- return nil, err
- }
- } else if element != "" {
- locator := tab.Page.Locator(strings.TrimSpace(element))
- bytes, err = locator.Screenshot(playwright.LocatorScreenshotOptions{
- Type: toPlaywrightScreenshotType(imageType),
- Timeout: playwright.Float(float64(resolveBrowserActionTimeoutMs(payload, 15000))),
- })
- if err != nil {
- return nil, err
- }
- } else {
- bytes, err = tab.Page.Screenshot(playwright.PageScreenshotOptions{
- Type: toPlaywrightScreenshotType(imageType),
- FullPage: playwright.Bool(fullPage),
- Timeout: playwright.Float(float64(resolveBrowserActionTimeoutMs(payload, 15000))),
- })
- if err != nil {
- return nil, err
+func firstNonEmptyString(values ...string) string {
+ for _, value := range values {
+ if strings.TrimSpace(value) != "" {
+ return strings.TrimSpace(value)
}
}
-
- path, err := saveBrowserArtifact(imageType, bytes)
- if err != nil {
- return nil, err
- }
- return map[string]any{
- "ok": true,
- "path": path,
- "targetId": tab.TargetID,
- "url": strings.TrimSpace(tab.Page.URL()),
- }, nil
-}
-
-func browserActionNavigate(ctx context.Context, payload toolArgs, state *browserProfileState, connectors ConnectorsReader) (map[string]any, error) {
- targetURL := getStringArg(payload, "targetUrl", "url")
- if targetURL == "" {
- return nil, errors.New("targetUrl is required")
- }
- if err := assertBrowserURLAllowed(targetURL, state.resolved.SSRFRules); err != nil {
- return nil, err
- }
- if err := ensureBrowserProfileStarted(state); err != nil {
- return nil, err
- }
-
- targetID := strings.TrimSpace(getStringArg(payload, "targetId"))
- tab, err := resolveBrowserTab(state, targetID, true)
- if err != nil {
- return nil, err
- }
-
- state.mu.Lock()
- browserCtx := state.context
- state.mu.Unlock()
- if browserCtx != nil {
- if err := addConnectorCookiesToContext(ctx, connectors, browserCtx, targetURL); err != nil {
- return nil, err
- }
- }
-
- resp, err := tab.Page.Goto(strings.TrimSpace(targetURL), playwright.PageGotoOptions{
- Timeout: playwright.Float(float64(resolveBrowserActionTimeoutMs(payload, 30000))),
- WaitUntil: resolveBrowserWaitUntilState(resolveBrowserWaitUntil(payload, defaultBrowserWaitUntil)),
- })
- if err != nil {
- return nil, err
- }
-
- finalURL := strings.TrimSpace(tab.Page.URL())
- if finalURL == "" {
- finalURL = strings.TrimSpace(targetURL)
- }
- if err := assertBrowserURLAllowed(finalURL, state.resolved.SSRFRules); err != nil {
- return nil, err
- }
-
- status := http.StatusOK
- if resp != nil {
- status = resp.Status()
- }
-
- return map[string]any{
- "ok": true,
- "targetId": tab.TargetID,
- "url": finalURL,
- "status": status,
- }, nil
-}
-
-func browserActionConsole(payload toolArgs, state *browserProfileState) (map[string]any, error) {
- if err := ensureBrowserProfileStarted(state); err != nil {
- return nil, err
- }
- targetID := strings.TrimSpace(getStringArg(payload, "targetId"))
- tab, err := resolveBrowserTab(state, targetID, false)
- if err != nil {
- return nil, err
- }
-
- level := strings.ToLower(strings.TrimSpace(getStringArg(payload, "level")))
- state.mu.Lock()
- defer state.mu.Unlock()
-
- messages := make([]browserConsoleMessage, 0, len(state.consoleMessages))
- for _, item := range state.consoleMessages {
- if item.TargetID != tab.TargetID {
- continue
- }
- if level != "" && level != "all" && item.Type != level {
- continue
- }
- messages = append(messages, item)
- }
-
- return map[string]any{
- "ok": true,
- "targetId": tab.TargetID,
- "messages": messages,
- }, nil
-}
-
-func browserActionPDF(payload toolArgs, state *browserProfileState) (map[string]any, error) {
- if err := ensureBrowserProfileStarted(state); err != nil {
- return nil, err
- }
- targetID := strings.TrimSpace(getStringArg(payload, "targetId"))
- tab, err := resolveBrowserTab(state, targetID, true)
- if err != nil {
- return nil, err
- }
-
- bytes, err := tab.Page.PDF(playwright.PagePdfOptions{
- PrintBackground: playwright.Bool(true),
- })
- if err != nil {
- return nil, err
- }
- path, err := saveBrowserArtifact("pdf", bytes)
- if err != nil {
- return nil, err
- }
- return map[string]any{
- "ok": true,
- "path": path,
- "targetId": tab.TargetID,
- "url": strings.TrimSpace(tab.Page.URL()),
- }, nil
-}
-
-func browserActionUpload(payload toolArgs, state *browserProfileState) (map[string]any, error) {
- if err := ensureBrowserProfileStarted(state); err != nil {
- return nil, err
- }
-
- paths, err := parseBrowserUploadPaths(payload)
- if err != nil {
- return nil, err
- }
-
- targetID := strings.TrimSpace(getStringArg(payload, "targetId"))
- tab, err := resolveBrowserTab(state, targetID, true)
- if err != nil {
- return nil, err
- }
-
- inputRef := strings.TrimSpace(getStringArg(payload, "inputRef"))
- ref := strings.TrimSpace(getStringArg(payload, "ref"))
- element := strings.TrimSpace(getStringArg(payload, "element"))
- if inputRef != "" || ref != "" || element != "" {
- locator, err := resolveBrowserUploadLocator(tab, inputRef, ref, element)
- if err != nil {
- return nil, err
- }
- if err := locator.SetInputFiles(paths); err != nil {
- return nil, err
- }
- return map[string]any{
- "ok": true,
- "targetId": tab.TargetID,
- "paths": paths,
- "armed": false,
- }, nil
- }
-
- timeoutMs := resolveBrowserActionTimeoutMs(payload, defaultBrowserHookTimeoutMs)
- state.mu.Lock()
- state.pendingUploads[tab.TargetID] = browserPendingUpload{
- Paths: append([]string(nil), paths...),
- ExpiresAt: time.Now().Add(time.Duration(timeoutMs) * time.Millisecond),
- }
- state.mu.Unlock()
-
- return map[string]any{
- "ok": true,
- "targetId": tab.TargetID,
- "paths": paths,
- "armed": true,
- }, nil
-}
-
-func browserActionDialog(payload toolArgs, state *browserProfileState) (map[string]any, error) {
- if err := ensureBrowserProfileStarted(state); err != nil {
- return nil, err
- }
- targetID := strings.TrimSpace(getStringArg(payload, "targetId"))
- tab, err := resolveBrowserTab(state, targetID, true)
- if err != nil {
- return nil, err
- }
- accept, _ := getBoolArg(payload, "accept")
- promptText := strings.TrimSpace(getStringArg(payload, "promptText"))
- timeoutMs := resolveBrowserActionTimeoutMs(payload, defaultBrowserHookTimeoutMs)
-
- state.mu.Lock()
- state.pendingDialogs[tab.TargetID] = browserPendingDialog{
- Accept: accept,
- PromptText: promptText,
- ExpiresAt: time.Now().Add(time.Duration(timeoutMs) * time.Millisecond),
- }
- state.mu.Unlock()
-
- return map[string]any{
- "ok": true,
- "targetId": tab.TargetID,
- "armed": true,
- "accept": accept,
- }, nil
-}
-
-func browserActionAct(payload toolArgs, state *browserProfileState) (map[string]any, error) {
- request := getMapArg(payload, "request")
- if request == nil {
- return nil, errors.New("request is required")
- }
- kind := strings.ToLower(strings.TrimSpace(getStringArg(toolArgs(request), "kind")))
- if kind == "" {
- return nil, errors.New("request.kind is required")
- }
- if !containsString(browserActKinds, kind) {
- return nil, errors.New("act kind not supported: " + kind)
- }
- if _, hasSelector := request["selector"]; hasSelector && kind != "wait" {
- return nil, errors.New(browserSelectorUnsupportedMessage)
- }
- if err := ensureBrowserProfileStarted(state); err != nil {
- return nil, err
- }
-
- targetID := strings.TrimSpace(getStringArg(toolArgs(request), "targetId"))
- if targetID == "" {
- targetID = strings.TrimSpace(getStringArg(payload, "targetId"))
- }
- tab, err := resolveBrowserTab(state, targetID, true)
- if err != nil {
- return nil, err
- }
-
- switch kind {
- case "click":
- err = browserActClick(tab, toolArgs(request), payload)
- case "type":
- err = browserActType(tab, toolArgs(request), payload)
- case "press":
- err = browserActPress(tab, toolArgs(request), payload)
- case "hover":
- err = browserActHover(tab, toolArgs(request), payload)
- case "drag":
- err = browserActDrag(tab, toolArgs(request), payload)
- case "select":
- err = browserActSelect(tab, toolArgs(request), payload)
- case "fill":
- err = browserActFill(tab, toolArgs(request), payload)
- case "resize":
- err = browserActResize(tab, toolArgs(request))
- case "wait":
- err = browserActWait(tab, toolArgs(request), state.resolved.EvaluateEnabled)
- case "evaluate":
- err = browserActEvaluate(tab, toolArgs(request), payload, state.resolved.EvaluateEnabled)
- case "close":
- err = browserActClose(tab, state)
- default:
- err = errors.New("act kind not supported: " + kind)
- }
- if err != nil {
- return nil, err
- }
-
- result := map[string]any{
- "ok": true,
- "targetId": tab.TargetID,
- "url": strings.TrimSpace(tab.Page.URL()),
- }
- if kind == "evaluate" {
- result["result"] = tabResultFromEvaluate(tab)
- }
- return result, nil
-}
-
-func tabResultFromEvaluate(tab *browserTabState) any {
- if tab == nil {
- return nil
- }
- tab.mu.RLock()
- defer tab.mu.RUnlock()
- return tab.evaluateResult
-}
-
-func browserActClick(tab *browserTabState, request toolArgs, payload toolArgs) error {
- ref := strings.TrimSpace(getStringArg(request, "ref"))
- if ref == "" {
- return errors.New("ref is required")
- }
- locator, err := resolveBrowserRefLocator(tab, ref)
- if err != nil {
- return err
- }
- clickOptions := playwright.LocatorClickOptions{}
- if timeout := resolveBrowserActTimeoutMs(request, payload, 10000); timeout > 0 {
- clickOptions.Timeout = playwright.Float(float64(timeout))
- }
- buttonRaw := strings.TrimSpace(getStringArg(request, "button"))
- if buttonRaw != "" {
- button := toPlaywrightMouseButton(buttonRaw)
- if button == nil {
- return errors.New("button must be left|right|middle")
- }
- clickOptions.Button = button
- }
- modifiers, err := toPlaywrightKeyboardModifiers(getStringSliceArg(request, "modifiers"))
- if err != nil {
- return err
- }
- if len(modifiers) > 0 {
- clickOptions.Modifiers = modifiers
- }
- if doubleClick, _ := getBoolArg(request, "doubleClick"); doubleClick {
- dblOptions := playwright.LocatorDblclickOptions{}
- if clickOptions.Timeout != nil {
- dblOptions.Timeout = clickOptions.Timeout
- }
- if clickOptions.Button != nil {
- dblOptions.Button = clickOptions.Button
- }
- if len(clickOptions.Modifiers) > 0 {
- dblOptions.Modifiers = clickOptions.Modifiers
- }
- if err := locator.Dblclick(dblOptions); err != nil {
- return toBrowserFriendlyInteractionError(err, ref)
- }
- return nil
- }
- if err := locator.Click(clickOptions); err != nil {
- return toBrowserFriendlyInteractionError(err, ref)
- }
- return nil
-}
-
-func browserActType(tab *browserTabState, request toolArgs, payload toolArgs) error {
- ref := strings.TrimSpace(getStringArg(request, "ref"))
- if ref == "" {
- return errors.New("ref is required")
- }
- textRaw, ok := request["text"]
- if !ok {
- return errors.New("text is required")
- }
- text, ok := textRaw.(string)
- if !ok {
- return errors.New("text is required")
- }
- locator, err := resolveBrowserRefLocator(tab, ref)
- if err != nil {
- return err
- }
- timeout := resolveBrowserActTimeoutMs(request, payload, 10000)
- if slowly, _ := getBoolArg(request, "slowly"); slowly {
- typeOptions := playwright.LocatorTypeOptions{}
- if timeout > 0 {
- typeOptions.Timeout = playwright.Float(float64(timeout))
- }
- typeOptions.Delay = playwright.Float(75)
- if err := locator.Click(playwright.LocatorClickOptions{Timeout: playwright.Float(float64(timeout))}); err != nil {
- return toBrowserFriendlyInteractionError(err, ref)
- }
- if err := locator.Type(text, typeOptions); err != nil {
- return toBrowserFriendlyInteractionError(err, ref)
- }
- } else {
- fillOptions := playwright.LocatorFillOptions{}
- if timeout > 0 {
- fillOptions.Timeout = playwright.Float(float64(timeout))
- }
- if err := locator.Fill(text, fillOptions); err != nil {
- return toBrowserFriendlyInteractionError(err, ref)
- }
- }
- if submit, _ := getBoolArg(request, "submit"); submit {
- pressOptions := playwright.LocatorPressOptions{}
- if timeout > 0 {
- pressOptions.Timeout = playwright.Float(float64(timeout))
- }
- if err := locator.Press("Enter", pressOptions); err != nil {
- return toBrowserFriendlyInteractionError(err, ref)
- }
- }
- return nil
-}
-
-func browserActPress(tab *browserTabState, request toolArgs, payload toolArgs) error {
- key := strings.TrimSpace(getStringArg(request, "key"))
- if key == "" {
- return errors.New("key is required")
- }
- options := playwright.KeyboardPressOptions{}
- if delayMs, ok := getIntArg(request, "delayMs"); ok && delayMs >= 0 {
- options.Delay = playwright.Float(float64(delayMs))
- }
- if timeout := resolveBrowserActTimeoutMs(request, payload, 10000); timeout > 0 {
- tab.Page.SetDefaultTimeout(float64(timeout))
- defer tab.Page.SetDefaultTimeout(30000)
- }
- return tab.Page.Keyboard().Press(key, options)
-}
-
-func browserActHover(tab *browserTabState, request toolArgs, payload toolArgs) error {
- ref := strings.TrimSpace(getStringArg(request, "ref"))
- if ref == "" {
- return errors.New("ref is required")
- }
- locator, err := resolveBrowserRefLocator(tab, ref)
- if err != nil {
- return err
- }
- hoverOptions := playwright.LocatorHoverOptions{}
- if timeout := resolveBrowserActTimeoutMs(request, payload, 10000); timeout > 0 {
- hoverOptions.Timeout = playwright.Float(float64(timeout))
- }
- if err := locator.Hover(hoverOptions); err != nil {
- return toBrowserFriendlyInteractionError(err, ref)
- }
- return nil
-}
-
-func browserActDrag(tab *browserTabState, request toolArgs, payload toolArgs) error {
- startRef := strings.TrimSpace(getStringArg(request, "startRef"))
- endRef := strings.TrimSpace(getStringArg(request, "endRef"))
- if startRef == "" || endRef == "" {
- return errors.New("startRef and endRef are required")
- }
- startLocator, err := resolveBrowserRefLocator(tab, startRef)
- if err != nil {
- return err
- }
- endLocator, err := resolveBrowserRefLocator(tab, endRef)
- if err != nil {
- return err
- }
- dragOptions := playwright.LocatorDragToOptions{}
- if timeout := resolveBrowserActTimeoutMs(request, payload, 12000); timeout > 0 {
- dragOptions.Timeout = playwright.Float(float64(timeout))
- }
- if err := startLocator.DragTo(endLocator, dragOptions); err != nil {
- return toBrowserFriendlyInteractionError(err, startRef+" -> "+endRef)
- }
- return nil
-}
-
-func browserActSelect(tab *browserTabState, request toolArgs, payload toolArgs) error {
- ref := strings.TrimSpace(getStringArg(request, "ref"))
- values := getStringSliceArg(request, "values")
- if ref == "" || len(values) == 0 {
- return errors.New("ref and values are required")
- }
- locator, err := resolveBrowserRefLocator(tab, ref)
- if err != nil {
- return err
- }
- timeout := resolveBrowserActTimeoutMs(request, payload, 10000)
- _, err = locator.SelectOption(playwright.SelectOptionValues{Values: &values}, playwright.LocatorSelectOptionOptions{
- Timeout: playwright.Float(float64(timeout)),
- })
- if err != nil {
- return toBrowserFriendlyInteractionError(err, ref)
- }
- return nil
-}
-
-func browserActFill(tab *browserTabState, request toolArgs, payload toolArgs) error {
- rawFields, ok := request["fields"].([]any)
- if !ok || len(rawFields) == 0 {
- return errors.New("fields are required")
- }
- timeout := resolveBrowserActTimeoutMs(request, payload, 10000)
- for _, raw := range rawFields {
- field, ok := raw.(map[string]any)
- if !ok {
- continue
- }
- ref := strings.TrimSpace(getStringArg(toolArgs(field), "ref"))
- if ref == "" {
- continue
- }
- typeValue := strings.ToLower(strings.TrimSpace(getStringArg(toolArgs(field), "type")))
- locator, err := resolveBrowserRefLocator(tab, ref)
- if err != nil {
- return err
- }
- switch typeValue {
- case "checkbox", "radio", "bool", "boolean":
- checked, _ := getBoolArg(toolArgs(field), "value")
- if checked {
- err = locator.Check(playwright.LocatorCheckOptions{Timeout: playwright.Float(float64(timeout))})
- } else {
- err = locator.Uncheck(playwright.LocatorUncheckOptions{Timeout: playwright.Float(float64(timeout))})
- }
- case "select", "option":
- value := strings.TrimSpace(getStringArg(toolArgs(field), "value"))
- if value == "" {
- continue
- }
- values := []string{value}
- _, err = locator.SelectOption(playwright.SelectOptionValues{Values: &values}, playwright.LocatorSelectOptionOptions{
- Timeout: playwright.Float(float64(timeout)),
- })
- default:
- value := strings.TrimSpace(getStringArg(toolArgs(field), "value"))
- err = locator.Fill(value, playwright.LocatorFillOptions{Timeout: playwright.Float(float64(timeout))})
- }
- if err != nil {
- return toBrowserFriendlyInteractionError(err, ref)
- }
- }
- return nil
-}
-
-func browserActResize(tab *browserTabState, request toolArgs) error {
- width, hasWidth := getIntArg(request, "width")
- height, hasHeight := getIntArg(request, "height")
- if !hasWidth || !hasHeight || width <= 0 || height <= 0 {
- return errors.New("width and height are required")
- }
- return tab.Page.SetViewportSize(width, height)
-}
-
-func browserActWait(tab *browserTabState, request toolArgs, evaluateEnabled bool) error {
- timeMs, hasTime := getIntArg(request, "timeMs")
- text := strings.TrimSpace(getStringArg(request, "text"))
- textGone := strings.TrimSpace(getStringArg(request, "textGone"))
- selector := strings.TrimSpace(getStringArg(request, "selector"))
- urlWait := strings.TrimSpace(getStringArg(request, "url"))
- loadState := strings.ToLower(strings.TrimSpace(getStringArg(request, "loadState")))
- fn := strings.TrimSpace(getStringArg(request, "fn"))
- timeoutMs := resolveBrowserActionTimeoutMs(request, 15000)
-
- if fn != "" && !evaluateEnabled {
- return errors.New(browserWaitFnDisabledMessage)
- }
-
- hasCondition := false
- if hasTime && timeMs > 0 {
- hasCondition = true
- tab.Page.WaitForTimeout(float64(timeMs))
- }
- if text != "" {
- hasCondition = true
- if err := tab.Page.Locator("text=" + text).First().WaitFor(playwright.LocatorWaitForOptions{
- State: playwright.WaitForSelectorStateVisible,
- Timeout: playwright.Float(float64(timeoutMs)),
- }); err != nil {
- return err
- }
- }
- if textGone != "" {
- hasCondition = true
- if err := tab.Page.Locator("text=" + textGone).First().WaitFor(playwright.LocatorWaitForOptions{
- State: playwright.WaitForSelectorStateHidden,
- Timeout: playwright.Float(float64(timeoutMs)),
- }); err != nil {
- return err
- }
- }
- if selector != "" {
- hasCondition = true
- if _, err := tab.Page.WaitForSelector(selector, playwright.PageWaitForSelectorOptions{
- State: playwright.WaitForSelectorStateVisible,
- Timeout: playwright.Float(float64(timeoutMs)),
- }); err != nil {
- return err
- }
- }
- if urlWait != "" {
- hasCondition = true
- if err := tab.Page.WaitForURL(urlWait, playwright.PageWaitForURLOptions{
- Timeout: playwright.Float(float64(timeoutMs)),
- }); err != nil {
- return err
- }
- }
- if loadState != "" {
- hasCondition = true
- state := resolveBrowserLoadState(loadState)
- if state == nil {
- return errors.New("loadState must be load|domcontentloaded|networkidle")
- }
- if err := tab.Page.WaitForLoadState(playwright.PageWaitForLoadStateOptions{
- State: state,
- Timeout: playwright.Float(float64(timeoutMs)),
- }); err != nil {
- return err
- }
- }
- if fn != "" {
- hasCondition = true
- if err := waitBrowserEvaluateCondition(tab.Page, fn, timeoutMs); err != nil {
- return err
- }
- }
- if !hasCondition {
- return errors.New(browserWaitRequiresConditionMessage)
- }
- return nil
-}
-
-func browserActEvaluate(tab *browserTabState, request toolArgs, payload toolArgs, evaluateEnabled bool) error {
- if !evaluateEnabled {
- return errors.New(browserEvaluateDisabledMessage)
- }
- fn := strings.TrimSpace(getStringArg(request, "fn"))
- if fn == "" {
- return errors.New("fn is required")
- }
- timeoutMs := resolveBrowserActTimeoutMs(request, payload, 20000)
- ref := strings.TrimSpace(getStringArg(request, "ref"))
- var (
- result any
- err error
- )
- if ref != "" {
- locator, resolveErr := resolveBrowserRefLocator(tab, ref)
- if resolveErr != nil {
- return resolveErr
- }
- result, err = locator.Evaluate(browserEvaluateElementExpression, []any{fn, timeoutMs})
- } else {
- result, err = tab.Page.Evaluate(browserEvaluatePageExpression, []any{fn, timeoutMs})
- }
- if err != nil {
- return err
- }
- tab.mu.Lock()
- tab.evaluateResult = result
- tab.mu.Unlock()
- return nil
-}
-
-func browserActClose(tab *browserTabState, state *browserProfileState) error {
- if tab == nil || tab.Page == nil {
- return errors.New("tab not found")
- }
- if err := tab.Page.Close(); err != nil {
- return err
- }
- state.mu.Lock()
- delete(state.tabs, tab.TargetID)
- delete(state.pageToTarget, tab.Page)
- if state.activeTarget == tab.TargetID {
- state.activeTarget = ""
- }
- pruneClosedTabsLocked(state)
- state.mu.Unlock()
- return nil
-}
-
-type browserSnapshotItem struct {
- Ref string
- AriaRef string
- Role string
- Name string
- Text string
- Depth int
- Nth int
-}
-
-var browserAriaSnapshotLinePattern = regexp.MustCompile(`^(\s*)-\s*([^\s":]+)(?:\s+"([^"]*)")?(.*)$`)
-var browserAriaSnapshotRefPattern = regexp.MustCompile(`\[ref=([^\]]+)\]`)
-var browserStrictModeCountPattern = regexp.MustCompile(`resolved to (\d+) elements`)
-var browserEvaluatePageExpression = `([fnBody, timeoutMs]) => {
- try {
- const candidate = eval("(" + fnBody + ")");
- const result = typeof candidate === "function" ? candidate() : candidate;
- if (result && typeof result.then === "function") {
- return Promise.race([
- result,
- new Promise((_, reject) =>
- setTimeout(() => reject(new Error("evaluate timed out after " + timeoutMs + "ms")), timeoutMs),
- ),
- ]);
- }
- return result;
- } catch (err) {
- throw new Error("Invalid evaluate function: " + (err && err.message ? err.message : String(err)));
- }
-}`
-var browserEvaluateElementExpression = `(el, [fnBody, timeoutMs]) => {
- try {
- const candidate = eval("(" + fnBody + ")");
- const result = typeof candidate === "function" ? candidate(el) : candidate;
- if (result && typeof result.then === "function") {
- return Promise.race([
- result,
- new Promise((_, reject) =>
- setTimeout(() => reject(new Error("evaluate timed out after " + timeoutMs + "ms")), timeoutMs),
- ),
- ]);
- }
- return result;
- } catch (err) {
- throw new Error("Invalid evaluate function: " + (err && err.message ? err.message : String(err)));
- }
-}`
-
-type browserAISnapshotResult struct {
- Snapshot string
- Truncated bool
- Refs map[string]browserSnapshotRef
- RefsJSON map[string]map[string]any
- Stats map[string]any
-}
-
-func collectBrowserAISnapshot(page playwright.Page, maxChars int) (*browserAISnapshotResult, error) {
- snapshot, err := captureBrowserPrivateAISnapshot(page, 5000)
- if err != nil {
- return nil, err
- }
- items := parseBrowserAriaSnapshot(snapshot, false, 2000, true, 0)
- refs := make(map[string]browserSnapshotRef, len(items))
- refsJSON := make(map[string]map[string]any, len(items))
- interactiveCount := 0
- for index, item := range items {
- ref := strings.TrimSpace(item.Ref)
- if ref == "" {
- ref = fmt.Sprintf("e%d", index+1)
- }
- entryMode := "aria"
- if strings.TrimSpace(item.AriaRef) == "" {
- entryMode = "role"
- }
- entry := browserSnapshotRef{
- Role: item.Role,
- Name: item.Name,
- Nth: item.Nth,
- Mode: entryMode,
- AriaRef: item.AriaRef,
- }
- refs[ref] = entry
- refsJSON[ref] = map[string]any{
- "role": entry.Role,
- "name": entry.Name,
- }
- if entry.Nth > 0 {
- refsJSON[ref]["nth"] = entry.Nth
- }
- if isBrowserInteractiveRole(entry.Role) {
- interactiveCount += 1
- }
- }
-
- truncated := false
- if maxChars > 0 && len(snapshot) > maxChars {
- snapshot = trimToMaxChars(snapshot, maxChars)
- truncated = true
- }
-
- lines := 0
- for _, line := range strings.Split(snapshot, "\n") {
- if strings.TrimSpace(line) == "" {
- continue
- }
- lines += 1
- }
- return &browserAISnapshotResult{
- Snapshot: snapshot,
- Truncated: truncated,
- Refs: refs,
- RefsJSON: refsJSON,
- Stats: map[string]any{
- "lines": lines,
- "chars": len(snapshot),
- "refs": len(refsJSON),
- "interactive": interactiveCount,
- },
- }, nil
-}
-
-func captureBrowserPrivateAISnapshot(page playwright.Page, timeoutMs int) (snapshot string, err error) {
- defer func() {
- if recover() != nil {
- snapshot = ""
- err = errBrowserSnapshotForAIUnavailable
- }
- }()
- if page == nil {
- return "", errBrowserSnapshotForAIUnavailable
- }
-
- value := reflect.ValueOf(page)
- if !value.IsValid() {
- return "", errBrowserSnapshotForAIUnavailable
- }
- if value.Kind() == reflect.Interface {
- value = value.Elem()
- }
- if value.Kind() == reflect.Pointer {
- value = value.Elem()
- }
- if !value.IsValid() || value.Kind() != reflect.Struct {
- return "", errBrowserSnapshotForAIUnavailable
- }
-
- channelOwnerField := value.FieldByName("channelOwner")
- if !channelOwnerField.IsValid() {
- return "", errBrowserSnapshotForAIUnavailable
- }
- channelField := channelOwnerField.FieldByName("channel")
- if !channelField.IsValid() || (channelField.Kind() == reflect.Pointer && channelField.IsNil()) {
- return "", errBrowserSnapshotForAIUnavailable
- }
- sendMethod := channelField.MethodByName("Send")
- if !sendMethod.IsValid() {
- return "", errBrowserSnapshotForAIUnavailable
- }
-
- calls := sendMethod.Call([]reflect.Value{
- reflect.ValueOf("snapshotForAI"),
- reflect.ValueOf(map[string]any{
- "timeout": normalizeBrowserTimeoutMs(timeoutMs, 5000),
- "track": "response",
- }),
- })
- if len(calls) != 2 {
- return "", errBrowserSnapshotForAIUnavailable
- }
- if errValue := calls[1]; errValue.IsValid() && !errValue.IsNil() {
- if invokeErr, ok := errValue.Interface().(error); ok {
- return "", invokeErr
- }
- return "", fmt.Errorf("snapshotForAI call failed")
- }
-
- result := calls[0].Interface()
- snapshot = extractBrowserAISnapshotText(result)
- if strings.TrimSpace(snapshot) == "" {
- return "", errBrowserSnapshotForAIUnavailable
- }
- return snapshot, nil
-}
-
-func extractBrowserAISnapshotText(raw any) string {
- switch value := raw.(type) {
- case map[string]any:
- if full, ok := value["full"]; ok {
- return fmt.Sprint(full)
- }
- if snapshot, ok := value["snapshot"]; ok {
- return fmt.Sprint(snapshot)
- }
- case string:
- return value
- }
- return ""
-}
-
-func isBrowserSnapshotForAIUnavailable(err error) bool {
- if err == nil {
- return false
- }
- if errors.Is(err, errBrowserSnapshotForAIUnavailable) {
- return true
- }
- message := strings.ToLower(strings.TrimSpace(err.Error()))
- return strings.Contains(message, "_snapshotforai") ||
- strings.Contains(message, "snapshotforai") ||
- strings.Contains(message, "playwright snapshotforai is unavailable")
-}
-
-func collectBrowserSnapshotItems(
- page playwright.Page,
- selector string,
- frameSelector string,
- interactive bool,
- limit int,
- refsMode string,
- maxDepth int,
-) ([]browserSnapshotItem, error) {
- if limit <= 0 {
- limit = defaultBrowserSnapshotLimit
- }
- locator := resolveBrowserSnapshotLocator(page, selector, frameSelector)
- snapshot, err := locator.AriaSnapshot(playwright.LocatorAriaSnapshotOptions{
- Ref: playwright.Bool(strings.EqualFold(refsMode, "aria")),
- })
- if err != nil {
- return nil, err
- }
- return parseBrowserAriaSnapshot(snapshot, interactive, limit, strings.EqualFold(refsMode, "aria"), maxDepth), nil
-}
-
-func resolveBrowserSnapshotLocator(page playwright.Page, selector string, frameSelector string) playwright.Locator {
- selector = strings.TrimSpace(selector)
- frameSelector = strings.TrimSpace(frameSelector)
- if frameSelector != "" {
- frame := page.FrameLocator(frameSelector)
- if selector != "" {
- return frame.Locator(selector)
- }
- return frame.Locator(":root")
- }
- if selector != "" {
- return page.Locator(selector)
- }
- return page.Locator(":root")
-}
-
-func parseBrowserAriaSnapshot(
- snapshot string,
- interactive bool,
- limit int,
- useAriaRefs bool,
- maxDepth int,
-) []browserSnapshotItem {
- if limit <= 0 {
- limit = defaultBrowserSnapshotLimit
- }
- lines := strings.Split(snapshot, "\n")
- items := make([]browserSnapshotItem, 0, minInt(len(lines), limit))
- countByRoleName := map[string]int{}
- nextRef := 0
-
- for _, rawLine := range lines {
- if len(items) >= limit {
- break
- }
- line := strings.TrimRight(rawLine, "\r")
- if strings.TrimSpace(line) == "" {
- continue
- }
-
- match := browserAriaSnapshotLinePattern.FindStringSubmatch(line)
- if len(match) < 5 {
- continue
- }
- role := strings.ToLower(strings.TrimSpace(match[2]))
- if role == "" || strings.HasPrefix(role, "/") {
- continue
- }
- if interactive && !isBrowserInteractiveRole(role) {
- continue
- }
-
- name := strings.TrimSpace(match[3])
- suffix := strings.TrimSpace(match[4])
- text := browserAriaSnapshotTextFromSuffix(suffix)
- depth := browserAriaSnapshotDepth(line)
- if maxDepth > 0 && depth > maxDepth {
- continue
- }
-
- ariaRef := ""
- if refMatch := browserAriaSnapshotRefPattern.FindStringSubmatch(suffix); len(refMatch) >= 2 {
- ariaRef = strings.TrimSpace(refMatch[1])
- }
-
- key := role + "\n" + name
- nth := countByRoleName[key]
- countByRoleName[key] = nth + 1
-
- nextRef += 1
- ref := fmt.Sprintf("e%d", nextRef)
- if useAriaRefs && ariaRef != "" {
- ref = ariaRef
- }
-
- item := browserSnapshotItem{
- Ref: ref,
- AriaRef: ariaRef,
- Role: role,
- Name: name,
- Text: text,
- Depth: depth,
- Nth: nth,
- }
- if item.Role == "" {
- item.Role = "element"
- }
- if item.Name == "" {
- item.Name = item.Text
- }
- if item.Name == "" {
- item.Name = "(empty)"
- }
- items = append(items, item)
- }
- return items
-}
-
-func browserAriaSnapshotDepth(line string) int {
- spaces := 0
- for _, r := range line {
- if r == ' ' {
- spaces += 1
- continue
- }
- if r == '\t' {
- spaces += 2
- continue
- }
- break
- }
- return spaces / 2
-}
-
-func browserAriaSnapshotTextFromSuffix(suffix string) string {
- suffix = strings.TrimSpace(suffix)
- if suffix == "" {
- return ""
- }
- if idx := strings.LastIndex(suffix, ":"); idx >= 0 && idx+1 < len(suffix) {
- return strings.TrimSpace(suffix[idx+1:])
- }
- return ""
-}
-
-func isBrowserInteractiveRole(role string) bool {
- switch strings.ToLower(strings.TrimSpace(role)) {
- case "button",
- "link",
- "textbox",
- "checkbox",
- "radio",
- "combobox",
- "listbox",
- "menuitem",
- "menuitemcheckbox",
- "menuitemradio",
- "option",
- "searchbox",
- "slider",
- "spinbutton",
- "switch",
- "tab",
- "treeitem":
- return true
- default:
- return false
- }
-}
-
-func resolveBrowserUploadLocator(tab *browserTabState, inputRef string, ref string, element string) (playwright.Locator, error) {
- switch {
- case strings.TrimSpace(inputRef) != "":
- return tab.Page.Locator(strings.TrimSpace(inputRef)), nil
- case strings.TrimSpace(ref) != "":
- return resolveBrowserRefLocator(tab, strings.TrimSpace(ref))
- case strings.TrimSpace(element) != "":
- return tab.Page.Locator(strings.TrimSpace(element)), nil
- default:
- return nil, errors.New("upload requires ref/inputRef/element")
- }
-}
-
-func parseBrowserUploadPaths(payload toolArgs) ([]string, error) {
- raw := getStringSliceArg(payload, "paths")
- if len(raw) == 0 {
- return nil, errors.New("paths are required")
- }
- rootDir, err := resolveBrowserUploadRootDir()
- if err != nil {
- return nil, err
- }
- paths := make([]string, 0, len(raw))
- for _, item := range raw {
- trimmed := strings.TrimSpace(item)
- if trimmed == "" {
- continue
- }
- absPath, err := resolvePathWithinRoot(rootDir, trimmed)
- if err != nil {
- return nil, err
- }
- info, err := os.Stat(absPath)
- if err != nil {
- return nil, fmt.Errorf("upload path not found: %s", absPath)
- }
- if info.IsDir() {
- return nil, fmt.Errorf("upload path must be a file: %s", absPath)
- }
- paths = append(paths, absPath)
- }
- if len(paths) == 0 {
- return nil, errors.New("paths are required")
- }
- return paths, nil
-}
-
-func resolveBrowserUploadRootDir() (string, error) {
- dir := filepath.Join(os.TempDir(), "dreamcreator", "browser", "uploads")
- if err := os.MkdirAll(dir, 0o755); err != nil {
- return "", err
- }
- return filepath.Abs(dir)
-}
-
-func resolvePathWithinRoot(rootDir string, requestedPath string) (string, error) {
- root := strings.TrimSpace(rootDir)
- if root == "" {
- return "", errors.New("upload root is required")
- }
- rootAbs, err := filepath.Abs(root)
- if err != nil {
- return "", err
- }
- raw := strings.TrimSpace(requestedPath)
- if raw == "" {
- return "", errors.New("paths are required")
- }
-
- var candidate string
- if filepath.IsAbs(raw) {
- candidate = filepath.Clean(raw)
- } else {
- candidate = filepath.Join(rootAbs, raw)
- }
- candidateAbs, err := filepath.Abs(candidate)
- if err != nil {
- return "", err
- }
- rel, err := filepath.Rel(rootAbs, candidateAbs)
- if err != nil {
- return "", err
- }
- rel = filepath.ToSlash(strings.TrimSpace(rel))
- if rel == "." || strings.HasPrefix(rel, "../") || rel == ".." {
- return "", fmt.Errorf("Invalid path: must stay within uploads directory (%s)", rootAbs)
- }
- return candidateAbs, nil
-}
-
-func resolveBrowserRefLocator(tab *browserTabState, ref string) (playwright.Locator, error) {
- if tab == nil {
- return nil, errors.New("tab unavailable")
- }
- ref = strings.TrimSpace(ref)
- if ref == "" {
- return nil, errors.New("ref is required")
- }
- if tab.refs == nil {
- return nil, errors.New("no snapshot refs available; run action=snapshot first")
- }
- entry, ok := tab.refs[ref]
- if !ok {
- return nil, fmt.Errorf("ref not found: %s (run action=snapshot again)", ref)
- }
-
- mode := strings.ToLower(strings.TrimSpace(entry.Mode))
- if mode == "aria" {
- ariaRef := strings.TrimSpace(entry.AriaRef)
- if ariaRef == "" {
- ariaRef = ref
- }
- if ariaRef == "" {
- return nil, fmt.Errorf("ref selector missing for %s", ref)
- }
- locator := resolveBrowserScopedLocator(tab.Page, entry.Frame, "aria-ref="+ariaRef)
- return locator, nil
- }
- if strings.TrimSpace(entry.Role) != "" {
- locator, err := resolveBrowserRoleLocator(tab.Page, entry.Role, entry.Name, entry.Nth, entry.Frame)
- if err == nil {
- return locator, nil
- }
- // Fall back to selector-based resolution below.
- }
-
- selector := strings.TrimSpace(entry.Selector)
- if selector == "" {
- return nil, fmt.Errorf("ref selector missing for %s", ref)
- }
- return resolveBrowserScopedLocator(tab.Page, entry.Frame, selector), nil
-}
-
-func resolveBrowserScopedLocator(page playwright.Page, frameSelector string, selector string) playwright.Locator {
- frameSelector = strings.TrimSpace(frameSelector)
- selector = strings.TrimSpace(selector)
- if frameSelector == "" {
- if strings.HasPrefix(selector, "//") || strings.HasPrefix(selector, "/") {
- return page.Locator("xpath=" + selector)
- }
- return page.Locator(selector)
- }
- frame := page.FrameLocator(frameSelector)
- if strings.HasPrefix(selector, "//") || strings.HasPrefix(selector, "/") {
- return frame.Locator("xpath=" + selector)
- }
- return frame.Locator(selector)
-}
-
-func resolveBrowserRoleLocator(page playwright.Page, role string, name string, nth int, frameSelector string) (playwright.Locator, error) {
- role = strings.TrimSpace(strings.ToLower(role))
- if role == "" {
- return nil, errors.New("role is required")
- }
-
- frameSelector = strings.TrimSpace(frameSelector)
- name = strings.TrimSpace(name)
-
- var locator playwright.Locator
- if frameSelector != "" {
- frame := page.FrameLocator(frameSelector)
- if name != "" {
- locator = frame.GetByRole(playwright.AriaRole(role), playwright.FrameLocatorGetByRoleOptions{
- Name: name,
- Exact: playwright.Bool(true),
- })
- } else {
- locator = frame.GetByRole(playwright.AriaRole(role))
- }
- } else {
- if name != "" {
- locator = page.GetByRole(playwright.AriaRole(role), playwright.PageGetByRoleOptions{
- Name: name,
- Exact: playwright.Bool(true),
- })
- } else {
- locator = page.GetByRole(playwright.AriaRole(role))
- }
- }
- if nth > 0 {
- locator = locator.Nth(nth)
- }
- return locator, nil
-}
-
-func listBrowserTabs(state *browserProfileState) ([]map[string]any, error) {
- state.mu.Lock()
- defer state.mu.Unlock()
- pruneClosedTabsLocked(state)
- items := make([]map[string]any, 0, len(state.tabs))
- ids := make([]string, 0, len(state.tabs))
- for id := range state.tabs {
- ids = append(ids, id)
- }
- sort.Strings(ids)
- for _, id := range ids {
- tab := state.tabs[id]
- title, _ := tab.Page.Title()
- items = append(items, map[string]any{
- "targetId": tab.TargetID,
- "title": strings.TrimSpace(title),
- "url": strings.TrimSpace(tab.Page.URL()),
- "type": "page",
- })
- }
- return items, nil
-}
-
-func resolveBrowserTab(state *browserProfileState, targetID string, autoCreate bool) (*browserTabState, error) {
- state.mu.Lock()
- pruneClosedTabsLocked(state)
-
- targetID = strings.TrimSpace(targetID)
- if targetID != "" {
- if tab, ok := state.tabs[targetID]; ok {
- state.activeTarget = tab.TargetID
- state.mu.Unlock()
- return tab, nil
- }
- matches := make([]*browserTabState, 0, len(state.tabs))
- targetLower := strings.ToLower(targetID)
- for id, tab := range state.tabs {
- if strings.HasPrefix(strings.ToLower(strings.TrimSpace(id)), targetLower) {
- matches = append(matches, tab)
- }
- }
- if len(matches) == 1 {
- state.activeTarget = matches[0].TargetID
- state.mu.Unlock()
- return matches[0], nil
- }
- if len(matches) > 1 {
- state.mu.Unlock()
- return nil, errors.New("ambiguous target id prefix")
- }
- if len(state.tabs) == 1 {
- for _, only := range state.tabs {
- state.activeTarget = only.TargetID
- state.mu.Unlock()
- return only, nil
- }
- }
- state.mu.Unlock()
- return nil, errors.New("tab not found")
- }
-
- if state.activeTarget != "" {
- if tab, ok := state.tabs[state.activeTarget]; ok {
- state.mu.Unlock()
- return tab, nil
- }
- }
- if len(state.tabs) == 1 {
- for _, only := range state.tabs {
- state.activeTarget = only.TargetID
- state.mu.Unlock()
- return only, nil
- }
- }
- if len(state.tabs) > 1 {
- ids := make([]string, 0, len(state.tabs))
- for id := range state.tabs {
- ids = append(ids, id)
- }
- sort.Strings(ids)
- first := state.tabs[ids[0]]
- state.activeTarget = first.TargetID
- state.mu.Unlock()
- return first, nil
- }
-
- browserCtx := state.context
- state.mu.Unlock()
- if !autoCreate || browserCtx == nil {
- return nil, errors.New("no browser tab available")
- }
- page, err := browserCtx.NewPage()
- if err != nil {
- return nil, err
- }
- tab := attachBrowserTab(state, page)
- return tab, nil
-}
-
-func attachBrowserTab(state *browserProfileState, page playwright.Page) *browserTabState {
- state.mu.Lock()
- defer state.mu.Unlock()
- if targetID, ok := state.pageToTarget[page]; ok {
- if existing, exists := state.tabs[targetID]; exists {
- state.activeTarget = existing.TargetID
- return existing
- }
- }
- targetID := fmt.Sprintf("T%d", atomic.AddUint64(&browserGlobalTabCounter, 1))
- tab := &browserTabState{
- TargetID: targetID,
- Page: page,
- refs: map[string]browserSnapshotRef{},
- }
- state.tabs[targetID] = tab
- state.pageToTarget[page] = targetID
- state.activeTarget = targetID
- attachBrowserPageObservers(state, tab)
- return tab
-}
-
-func attachBrowserPageObservers(state *browserProfileState, tab *browserTabState) {
- if tab == nil || tab.Page == nil {
- return
- }
- targetID := tab.TargetID
-
- tab.Page.OnConsole(func(message playwright.ConsoleMessage) {
- entry := browserConsoleMessage{
- TargetID: targetID,
- Type: strings.ToLower(strings.TrimSpace(message.Type())),
- Text: trimToMaxChars(strings.TrimSpace(message.Text()), 4000),
- Timestamp: time.Now().UTC().Format(time.RFC3339),
- }
- state.mu.Lock()
- state.consoleMessages = append(state.consoleMessages, entry)
- if len(state.consoleMessages) > 400 {
- state.consoleMessages = append([]browserConsoleMessage(nil), state.consoleMessages[len(state.consoleMessages)-400:]...)
- }
- state.mu.Unlock()
- })
-
- tab.Page.OnDialog(func(dialog playwright.Dialog) {
- state.mu.Lock()
- pending, ok := state.pendingDialogs[targetID]
- if ok {
- delete(state.pendingDialogs, targetID)
- }
- state.mu.Unlock()
- if ok && time.Now().Before(pending.ExpiresAt) {
- if pending.Accept {
- if pending.PromptText != "" {
- _ = dialog.Accept(pending.PromptText)
- } else {
- _ = dialog.Accept()
- }
- } else {
- _ = dialog.Dismiss()
- }
- return
- }
- _ = dialog.Dismiss()
- })
-
- tab.Page.OnFileChooser(func(chooser playwright.FileChooser) {
- state.mu.Lock()
- pending, ok := state.pendingUploads[targetID]
- if ok {
- delete(state.pendingUploads, targetID)
- }
- state.mu.Unlock()
- if ok && time.Now().Before(pending.ExpiresAt) {
- _ = chooser.SetFiles(pending.Paths)
- }
- })
-}
-
-func ensureBrowserProfileStarted(state *browserProfileState) error {
- state.mu.Lock()
- defer state.mu.Unlock()
- if state.browser != nil && state.context != nil {
- pruneClosedTabsLocked(state)
- return nil
- }
-
- var err error
- if state.pw == nil {
- state.pw, err = playwright.Run(&playwright.RunOptions{Verbose: false, Stdout: io.Discard, Stderr: io.Discard})
- if err != nil {
- return err
- }
- }
-
- launchOptions := playwright.BrowserTypeLaunchOptions{
- Headless: playwright.Bool(state.resolved.Headless),
- }
- args := append([]string(nil), state.resolved.ExtraArgs...)
- if state.resolved.NoSandbox {
- args = append(args, "--no-sandbox")
- }
- if state.resolved.Headless && !hasBrowserHeadlessArg(args) {
- args = append(args, "--headless=new")
- }
- if len(args) > 0 {
- launchOptions.Args = args
- }
- state.browser, err = state.pw.Chromium.Launch(launchOptions)
- if err != nil {
- return err
- }
- state.context, err = state.browser.NewContext(playwright.BrowserNewContextOptions{
- Viewport: &playwright.Size{Width: defaultBrowserViewportWidth, Height: defaultBrowserViewportHeight},
- })
- if err != nil {
- _ = state.browser.Close()
- state.browser = nil
- return err
- }
-
- if state.tabs == nil {
- state.tabs = map[string]*browserTabState{}
- }
- if state.pageToTarget == nil {
- state.pageToTarget = map[playwright.Page]string{}
- }
- if state.pendingUploads == nil {
- state.pendingUploads = map[string]browserPendingUpload{}
- }
- if state.pendingDialogs == nil {
- state.pendingDialogs = map[string]browserPendingDialog{}
- }
- for _, page := range state.context.Pages() {
- if page == nil || page.IsClosed() {
- continue
- }
- if targetID, ok := state.pageToTarget[page]; ok {
- if _, exists := state.tabs[targetID]; exists {
- continue
- }
- }
- targetID := fmt.Sprintf("T%d", atomic.AddUint64(&browserGlobalTabCounter, 1))
- tab := &browserTabState{TargetID: targetID, Page: page, refs: map[string]browserSnapshotRef{}}
- state.tabs[targetID] = tab
- state.pageToTarget[page] = targetID
- if state.activeTarget == "" {
- state.activeTarget = targetID
- }
- attachBrowserPageObservers(state, tab)
- }
- pruneClosedTabsLocked(state)
- return nil
-}
-
-func stopBrowserProfile(state *browserProfileState) error {
- state.mu.Lock()
- defer state.mu.Unlock()
-
- if state.context != nil {
- _ = state.context.Close()
- state.context = nil
- }
- if state.browser != nil {
- _ = state.browser.Close()
- state.browser = nil
- }
- if state.pw != nil {
- _ = state.pw.Stop()
- state.pw = nil
- }
-
- state.tabs = map[string]*browserTabState{}
- state.pageToTarget = map[playwright.Page]string{}
- state.activeTarget = ""
- state.pendingUploads = map[string]browserPendingUpload{}
- state.pendingDialogs = map[string]browserPendingDialog{}
- state.consoleMessages = nil
- return nil
-}
-
-func pruneClosedTabsLocked(state *browserProfileState) {
- if state == nil {
- return
- }
- for targetID, tab := range state.tabs {
- if tab == nil || tab.Page == nil || tab.Page.IsClosed() {
- delete(state.tabs, targetID)
- if tab != nil && tab.Page != nil {
- delete(state.pageToTarget, tab.Page)
- }
- if state.activeTarget == targetID {
- state.activeTarget = ""
- }
- }
- }
- if state.activeTarget != "" {
- if _, ok := state.tabs[state.activeTarget]; ok {
- return
- }
- }
- if len(state.tabs) == 0 {
- state.activeTarget = ""
- return
- }
- ids := make([]string, 0, len(state.tabs))
- for id := range state.tabs {
- ids = append(ids, id)
- }
- sort.Strings(ids)
- state.activeTarget = ids[0]
-}
-
-func getOrCreateBrowserProfileState(sessionKey string, profileName string, resolved browserResolvedConfig) *browserProfileState {
- globalBrowserSessions.mu.Lock()
- defer globalBrowserSessions.mu.Unlock()
-
- session, ok := globalBrowserSessions.sessions[sessionKey]
- if !ok {
- session = &browserSessionState{
- sessionKey: sessionKey,
- profiles: map[string]*browserProfileState{},
- }
- globalBrowserSessions.sessions[sessionKey] = session
- }
-
- state, ok := session.profiles[profileName]
- if !ok {
- profile := resolved.Profiles[profileName]
- state = &browserProfileState{
- profileName: profileName,
- resolved: resolved,
- profile: profile,
- tabs: map[string]*browserTabState{},
- pageToTarget: map[playwright.Page]string{},
- pendingUploads: map[string]browserPendingUpload{},
- pendingDialogs: map[string]browserPendingDialog{},
- consoleMessages: nil,
- }
- session.profiles[profileName] = state
- return state
- }
-
- state.mu.Lock()
- state.resolved = resolved
- if profile, exists := resolved.Profiles[profileName]; exists {
- state.profile = profile
- }
- state.profileName = profileName
- state.mu.Unlock()
- return state
-}
-
-func resolveBrowserRuntimeConfig(config map[string]any) browserResolvedConfig {
- browserConfig := resolveBrowserConfig(config)
- resolved := browserResolvedConfig{
- Enabled: true,
- EvaluateEnabled: true,
- Color: defaultBrowserColor,
- Headless: false,
- NoSandbox: false,
- DefaultProfile: defaultBrowserProfileDreamCreator,
- Profiles: map[string]browserProfileConfig{
- defaultBrowserProfileDreamCreator: {
- Name: defaultBrowserProfileDreamCreator,
- Color: defaultBrowserColor,
- Driver: browserTypePlaywright,
- },
- },
- SSRFRules: browserSSRFPolicy{
- DangerouslyAllowPrivateNetwork: true,
- AllowedHostnames: map[string]struct{}{},
- HostnameAllowlist: nil,
- },
- ExtraArgs: nil,
- }
- if browserConfig == nil {
- return resolved
- }
-
- if value, ok := getBoolArg(toolArgs(browserConfig), "enabled"); ok {
- resolved.Enabled = value
- }
- if value, ok := getBoolArg(toolArgs(browserConfig), "evaluateEnabled"); ok {
- resolved.EvaluateEnabled = value
- }
- if value := strings.TrimSpace(getStringArg(toolArgs(browserConfig), "color")); value != "" {
- resolved.Color = value
- }
- if value, ok := getBoolArg(toolArgs(browserConfig), "headless"); ok {
- resolved.Headless = value
- }
- if value, ok := getBoolArg(toolArgs(browserConfig), "noSandbox"); ok {
- resolved.NoSandbox = value
- }
- if values := normalizeStringSlice(getStringSliceArg(toolArgs(browserConfig), "extraArgs")); len(values) > 0 {
- resolved.ExtraArgs = values
- }
-
- if snapshotDefaults := getMapArg(toolArgs(browserConfig), "snapshotDefaults"); snapshotDefaults != nil {
- if value := strings.TrimSpace(getStringArg(toolArgs(snapshotDefaults), "mode")); value == defaultBrowserSnapshotModeEfficient {
- resolved.SnapshotDefaultMode = value
- }
- }
-
- if ssrfRaw := getMapArg(toolArgs(browserConfig), "ssrfPolicy"); ssrfRaw != nil {
- if value, ok := getBoolArg(toolArgs(ssrfRaw), "dangerouslyAllowPrivateNetwork"); ok {
- resolved.SSRFRules.DangerouslyAllowPrivateNetwork = value
- } else if value, ok := getBoolArg(toolArgs(ssrfRaw), "allowPrivateNetwork"); ok {
- resolved.SSRFRules.DangerouslyAllowPrivateNetwork = value
- }
- for _, hostname := range getStringSliceArg(toolArgs(ssrfRaw), "allowedHostnames") {
- resolved.SSRFRules.AllowedHostnames[strings.ToLower(strings.TrimSpace(hostname))] = struct{}{}
- }
- if allowlist := normalizeStringSlice(getStringSliceArg(toolArgs(ssrfRaw), "hostnameAllowlist")); len(allowlist) > 0 {
- resolved.SSRFRules.HostnameAllowlist = allowlist
- }
- }
-
- if _, ok := resolved.Profiles[resolved.DefaultProfile]; !ok {
- resolved.DefaultProfile = defaultBrowserProfileDreamCreator
- }
- for name, profile := range resolved.Profiles {
- if profile.Name == "" {
- profile.Name = name
- }
- if profile.Color == "" {
- profile.Color = resolved.Color
- }
- if profile.Driver == "" {
- profile.Driver = browserTypePlaywright
- }
- resolved.Profiles[name] = profile
- }
-
- return resolved
-}
-
-func resolveBrowserProfileName(payload toolArgs, resolved browserResolvedConfig) string {
- profile := strings.TrimSpace(getStringArg(payload, "profile"))
- if profile == "" {
- profile = strings.TrimSpace(resolved.DefaultProfile)
- }
- if profile == "" {
- profile = defaultBrowserProfileDreamCreator
- }
- if _, ok := resolved.Profiles[profile]; !ok {
- profileCfg := browserProfileConfig{
- Name: profile,
- Color: resolved.Color,
- Driver: browserTypePlaywright,
- }
- resolved.Profiles[profile] = profileCfg
- }
- return profile
-}
-
-func resolveBrowserConfig(config map[string]any) map[string]any {
- return getNestedMap(config, "browser")
-}
-
-func resolveBrowserConfigBool(config map[string]any, key string) (bool, bool) {
- return getBoolArg(toolArgs(resolveBrowserConfig(config)), key)
-}
-
-func resolveBrowserType(payload toolArgs, config map[string]any) string {
- if raw := getStringArg(payload, "type", "mode", "engine"); raw != "" {
- return normalizeBrowserType(raw)
- }
- if payloadBrowser := getMapArg(payload, "browser"); payloadBrowser != nil {
- if raw := getStringArg(toolArgs(payloadBrowser), "type", "mode", "engine"); raw != "" {
- return normalizeBrowserType(raw)
- }
- }
- if raw := getStringArg(toolArgs(resolveBrowserConfig(config)), "type", "mode", "engine"); raw != "" {
- return normalizeBrowserType(raw)
- }
- return defaultBrowserType
-}
-
-func normalizeBrowserType(value string) string {
- switch strings.ToLower(strings.TrimSpace(value)) {
- case browserTypePlaywright, "browser", "headless", "chromium":
- return browserTypePlaywright
- default:
- return ""
- }
-}
-
-func resolveBrowserWaitUntil(payload toolArgs, fallback string) string {
- if value := normalizeBrowserWaitUntil(getStringArg(payload, "waitUntil")); value != "" {
- return value
- }
- return normalizeBrowserWaitUntil(fallback)
-}
-
-func resolveBrowserWaitUntilState(value string) *playwright.WaitUntilState {
- switch normalizeBrowserWaitUntil(value) {
- case "commit":
- return playwright.WaitUntilStateCommit
- case "load":
- return playwright.WaitUntilStateLoad
- case "networkidle":
- return playwright.WaitUntilStateNetworkidle
- default:
- return playwright.WaitUntilStateDomcontentloaded
- }
-}
-
-func normalizeBrowserWaitUntil(value string) string {
- switch strings.ToLower(strings.TrimSpace(value)) {
- case "commit":
- return "commit"
- case "load":
- return "load"
- case "networkidle", "network_idle":
- return "networkidle"
- case "domcontentloaded", "dom_content_loaded", "dom-content-loaded":
- return "domcontentloaded"
- default:
- return ""
- }
-}
-
-func resolveBrowserLoadState(value string) *playwright.LoadState {
- switch strings.ToLower(strings.TrimSpace(value)) {
- case "load":
- return playwright.LoadStateLoad
- case "domcontentloaded", "dom_content_loaded", "dom-content-loaded":
- return playwright.LoadStateDomcontentloaded
- case "networkidle", "network_idle":
- return playwright.LoadStateNetworkidle
- default:
- return nil
- }
-}
-
-func resolveBrowserActionTimeoutMs(payload toolArgs, fallback int) int {
- if timeoutMs, ok := getIntArg(payload, "timeoutMs"); ok && timeoutMs > 0 {
- return normalizeBrowserTimeoutMs(timeoutMs, fallback)
- }
- if timeoutSeconds, ok := getIntArg(payload, "timeoutSeconds"); ok && timeoutSeconds > 0 {
- return normalizeBrowserTimeoutMs(timeoutSeconds*1000, fallback)
- }
- return normalizeBrowserTimeoutMs(fallback, fallback)
-}
-
-func resolveBrowserActTimeoutMs(request toolArgs, payload toolArgs, fallback int) int {
- if timeoutMs, ok := getIntArg(request, "timeoutMs"); ok && timeoutMs > 0 {
- return normalizeBrowserTimeoutMs(timeoutMs, fallback)
- }
- if timeoutSeconds, ok := getIntArg(request, "timeoutSeconds"); ok && timeoutSeconds > 0 {
- return normalizeBrowserTimeoutMs(timeoutSeconds*1000, fallback)
- }
- return resolveBrowserActionTimeoutMs(payload, fallback)
-}
-
-func normalizeBrowserTimeoutMs(value int, fallback int) int {
- if value <= 0 {
- value = fallback
- }
- if value < 500 {
- return 500
- }
- if value > 120000 {
- return 120000
- }
- return value
-}
-
-func toBrowserFriendlyInteractionError(err error, selector string) error {
- if err == nil {
- return nil
- }
- message := strings.TrimSpace(err.Error())
- if message == "" {
- return err
- }
- if strings.Contains(message, "strict mode violation") {
- count := "multiple"
- if match := browserStrictModeCountPattern.FindStringSubmatch(message); len(match) >= 2 {
- count = strings.TrimSpace(match[1])
- }
- return fmt.Errorf(`Selector "%s" matched %s elements. Run a new snapshot to get updated refs, or use a different ref.`, selector, count)
- }
- if (strings.Contains(message, "Timeout") || strings.Contains(message, "waiting for")) &&
- (strings.Contains(message, "to be visible") || strings.Contains(message, "not visible")) {
- return fmt.Errorf(`Element "%s" not found or not visible. Run a new snapshot to see current page elements.`, selector)
- }
- if strings.Contains(message, "intercepts pointer events") ||
- strings.Contains(message, "not receive pointer events") ||
- strings.Contains(message, "not visible") {
- return fmt.Errorf(`Element "%s" is not interactable (hidden or covered). Try scrolling it into view, closing overlays, or re-snapshotting.`, selector)
- }
- return err
-}
-
-func addConnectorCookiesToContext(ctx context.Context, connectors ConnectorsReader, browserCtx playwright.BrowserContext, targetURL string) error {
- if connectors == nil || browserCtx == nil {
- return nil
- }
- cookies, err := resolveConnectorCookiesForURL(ctx, connectors, targetURL)
- if err != nil {
- return err
- }
- if len(cookies) == 0 {
- return nil
- }
- return browserCtx.AddCookies(toPlaywrightCookies(cookies, targetURL))
-}
-
-func assertBrowserURLAllowed(rawURL string, policy browserSSRFPolicy) error {
- trimmed := strings.TrimSpace(rawURL)
- if trimmed == "" {
- return errors.New("targetUrl is required")
- }
- parsed, err := url.Parse(trimmed)
- if err != nil {
- return fmt.Errorf("invalid url: %w", err)
- }
- if !strings.EqualFold(parsed.Scheme, "http") && !strings.EqualFold(parsed.Scheme, "https") {
- return errors.New("only http(s) urls are supported")
- }
- hostname := strings.ToLower(strings.TrimSpace(parsed.Hostname()))
- if hostname == "" {
- return errors.New("url hostname is required")
- }
- if isHostnameExplicitlyAllowed(hostname, policy) {
- return nil
- }
- if policy.DangerouslyAllowPrivateNetwork {
- return nil
- }
- if hostname == "localhost" || strings.HasSuffix(hostname, ".local") || strings.HasSuffix(hostname, ".internal") {
- return fmt.Errorf("blocked private hostname: %s", hostname)
- }
- if ip := net.ParseIP(hostname); ip != nil {
- if isPrivateOrLocalIP(ip) {
- return fmt.Errorf("blocked private IP: %s", hostname)
- }
- }
- return nil
-}
-
-func isHostnameExplicitlyAllowed(hostname string, policy browserSSRFPolicy) bool {
- hostname = strings.ToLower(strings.TrimSpace(hostname))
- if hostname == "" {
- return false
- }
- if _, ok := policy.AllowedHostnames[hostname]; ok {
- return true
- }
- for _, pattern := range policy.HostnameAllowlist {
- pattern = strings.ToLower(strings.TrimSpace(pattern))
- if pattern == "" {
- continue
- }
- if pattern == hostname {
- return true
- }
- if strings.HasPrefix(pattern, "*.") {
- suffix := strings.TrimPrefix(pattern, "*.")
- if strings.HasSuffix(hostname, "."+suffix) || hostname == suffix {
- return true
- }
- continue
- }
- if matched, _ := filepath.Match(pattern, hostname); matched {
- return true
- }
- }
- return false
-}
-
-func isPrivateOrLocalIP(ip net.IP) bool {
- if ip == nil {
- return false
- }
- if ip.IsLoopback() || ip.IsPrivate() || ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() {
- return true
- }
- if ip.IsUnspecified() || ip.IsMulticast() {
- return true
- }
- if ip4 := ip.To4(); ip4 != nil {
- // Carrier-grade NAT: 100.64.0.0/10
- if ip4[0] == 100 && (ip4[1]&0xC0) == 64 {
- return true
- }
- }
- return false
-}
-
-func waitBrowserEvaluateCondition(page playwright.Page, fn string, timeoutMs int) error {
- if timeoutMs <= 0 {
- timeoutMs = 15000
- }
- deadline := time.Now().Add(time.Duration(timeoutMs) * time.Millisecond)
- for {
- result, err := page.Evaluate(fn)
- if err == nil {
- if value, ok := result.(bool); ok && value {
- return nil
- }
- if result != nil {
- switch typed := result.(type) {
- case string:
- if strings.TrimSpace(strings.ToLower(typed)) == "true" {
- return nil
- }
- case float64:
- if typed != 0 {
- return nil
- }
- case int:
- if typed != 0 {
- return nil
- }
- }
- }
- }
- if time.Now().After(deadline) {
- if err != nil {
- return err
- }
- return errors.New("wait fn timeout")
- }
- time.Sleep(200 * time.Millisecond)
- }
-}
-
-func saveBrowserArtifact(ext string, content []byte) (string, error) {
- ext = strings.TrimSpace(strings.TrimPrefix(ext, "."))
- if ext == "" {
- ext = "bin"
- }
- dir := filepath.Join(os.TempDir(), "dreamcreator", "browser")
- if err := os.MkdirAll(dir, 0o755); err != nil {
- return "", err
- }
- filename := fmt.Sprintf("%d-%d.%s", time.Now().UnixNano(), atomic.AddUint64(&browserGlobalTabCounter, 1), ext)
- path := filepath.Join(dir, filename)
- if err := os.WriteFile(path, content, 0o644); err != nil {
- return "", err
- }
- abs, err := filepath.Abs(path)
- if err != nil {
- return path, nil
- }
- return abs, nil
-}
-
-func resolveBrowserPlaywrightRuntimeAvailability() (bool, string, string) {
- now := time.Now()
- browserPlaywrightRuntimeCache.mu.Lock()
- if !browserPlaywrightRuntimeCache.checkedAt.IsZero() && now.Sub(browserPlaywrightRuntimeCache.checkedAt) < browserRuntimeCheckCacheTTL {
- available := browserPlaywrightRuntimeCache.available
- reason := browserPlaywrightRuntimeCache.reason
- execPath := browserPlaywrightRuntimeCache.execPath
- browserPlaywrightRuntimeCache.mu.Unlock()
- return available, reason, execPath
- }
- browserPlaywrightRuntimeCache.mu.Unlock()
-
- available := true
- reason := ""
- execPath := ""
-
- pw, err := playwright.Run(&playwright.RunOptions{
- Verbose: false,
- Stdout: io.Discard,
- Stderr: io.Discard,
- })
- if err != nil {
- available = false
- reason = trimToMaxChars(strings.TrimSpace(err.Error()), 220)
- } else {
- defer pw.Stop()
- execPath = strings.TrimSpace(pw.Chromium.ExecutablePath())
-
- browser, launchErr := pw.Chromium.Launch(playwright.BrowserTypeLaunchOptions{
- Headless: playwright.Bool(true),
- Args: []string{"--headless=new"},
- })
- if launchErr != nil {
- available = false
- reason = trimToMaxChars(strings.TrimSpace(launchErr.Error()), 220)
- } else {
- _ = browser.Close()
- }
- }
-
- browserPlaywrightRuntimeCache.mu.Lock()
- browserPlaywrightRuntimeCache.checkedAt = now
- browserPlaywrightRuntimeCache.available = available
- browserPlaywrightRuntimeCache.reason = reason
- browserPlaywrightRuntimeCache.execPath = execPath
- browserPlaywrightRuntimeCache.mu.Unlock()
-
- return available, reason, execPath
-}
-
-func hasBrowserHeadlessArg(args []string) bool {
- for _, arg := range args {
- if strings.Contains(strings.ToLower(strings.TrimSpace(arg)), "headless") {
- return true
- }
- }
- return false
-}
-
-func toPlaywrightScreenshotType(value string) *playwright.ScreenshotType {
- switch strings.ToLower(strings.TrimSpace(value)) {
- case "jpeg", "jpg":
- return playwright.ScreenshotTypeJpeg
- default:
- return playwright.ScreenshotTypePng
- }
-}
-
-func toPlaywrightMouseButton(value string) *playwright.MouseButton {
- switch strings.ToLower(strings.TrimSpace(value)) {
- case "left":
- return playwright.MouseButtonLeft
- case "right":
- return playwright.MouseButtonRight
- case "middle":
- return playwright.MouseButtonMiddle
- default:
- return nil
- }
-}
-
-func toPlaywrightKeyboardModifiers(values []string) ([]playwright.KeyboardModifier, error) {
- if len(values) == 0 {
- return nil, nil
- }
- result := make([]playwright.KeyboardModifier, 0, len(values))
- for _, value := range values {
- switch strings.ToLower(strings.TrimSpace(value)) {
- case "alt":
- if playwright.KeyboardModifierAlt != nil {
- result = append(result, *playwright.KeyboardModifierAlt)
- }
- case "control", "ctrl":
- if playwright.KeyboardModifierControl != nil {
- result = append(result, *playwright.KeyboardModifierControl)
- }
- case "controlormeta", "control_or_meta":
- if playwright.KeyboardModifierControlOrMeta != nil {
- result = append(result, *playwright.KeyboardModifierControlOrMeta)
- }
- case "meta", "command", "cmd":
- if playwright.KeyboardModifierMeta != nil {
- result = append(result, *playwright.KeyboardModifierMeta)
- }
- case "shift":
- if playwright.KeyboardModifierShift != nil {
- result = append(result, *playwright.KeyboardModifierShift)
- }
- case "":
- continue
- default:
- return nil, errors.New("modifiers must be Alt|Control|ControlOrMeta|Meta|Shift")
- }
- }
- if len(result) == 0 {
- return nil, nil
- }
- return result, nil
-}
-
-func containsString(values []string, candidate string) bool {
- for _, value := range values {
- if value == candidate {
- return true
- }
- }
- return false
-}
-
-func anyToString(value any) string {
- switch typed := value.(type) {
- case string:
- return typed
- case fmt.Stringer:
- return typed.String()
- case float64:
- return strconv.FormatFloat(typed, 'f', -1, 64)
- case int:
- return strconv.Itoa(typed)
- case int64:
- return strconv.FormatInt(typed, 10)
- default:
- return ""
- }
-}
-
-func anyToInt(value any) int {
- switch typed := value.(type) {
- case int:
- return typed
- case int64:
- return int(typed)
- case float64:
- return int(typed)
- case string:
- parsed, _ := strconv.Atoi(strings.TrimSpace(typed))
- return parsed
- default:
- return 0
- }
-}
-
-func minInt(a int, b int) int {
- if a <= b {
- return a
- }
- return b
+ return ""
}
-func maxInt(a int, b int) int {
- if a >= b {
- return a
+func browserResultItemCount(result browsercdp.ActionResult) int {
+ if result.State != nil && result.State.ItemCount > 0 {
+ return result.State.ItemCount
}
- return b
+ return len(result.Items)
}
diff --git a/internal/application/gateway/tools/browser_tools_helpers.go b/internal/application/gateway/tools/browser_tools_helpers.go
new file mode 100644
index 0000000..9b9d363
--- /dev/null
+++ b/internal/application/gateway/tools/browser_tools_helpers.go
@@ -0,0 +1,779 @@
+package tools
+
+import (
+ "context"
+ "encoding/base64"
+ "encoding/json"
+ "errors"
+ "fmt"
+ "net"
+ "os"
+ "path/filepath"
+ "regexp"
+ "strings"
+ "sync/atomic"
+ "time"
+
+ "dreamcreator/internal/application/browsercdp"
+ gatewaynodes "dreamcreator/internal/application/gateway/nodes"
+ targetpkg "github.com/chromedp/cdproto/target"
+)
+
+const (
+ browserTypeCDP = "cdp"
+ defaultBrowserType = browserTypeCDP
+
+ defaultBrowserWaitUntil = "domcontentloaded"
+
+ defaultBrowserProfileDreamCreator = "dreamcreator"
+ defaultBrowserColor = "#FF4500"
+
+ defaultBrowserSnapshotAIMaxChars = 80000
+ defaultBrowserSnapshotAIEfficientMaxChars = 10000
+ defaultBrowserSnapshotDepth = 6
+ defaultBrowserSnapshotLimit = 200
+ defaultBrowserViewportWidth = 1366
+ defaultBrowserViewportHeight = 900
+)
+
+var browserToolActions = []string{
+ "open",
+ "navigate",
+ "snapshot",
+ "act",
+ "wait",
+ "scroll",
+ "upload",
+ "dialog",
+ "reset",
+}
+
+var browserSelectorUnsupportedMessage = strings.Join([]string{
+ "Error: 'selector' is not supported. Use 'ref' from snapshot instead.",
+ "",
+ "Example workflow:",
+ "1. open or navigate to load the page",
+ "2. read returned items, or run snapshot to get fresh refs",
+ `3. act with ref: "e123" to interact with an element`,
+ "4. after the page changes, read returned items or run snapshot again",
+ "",
+ "This is more reliable for modern SPAs.",
+}, "\n")
+
+var browserWaitRequiresConditionMessage = "wait requires at least one of: timeMs, text, textGone, selector, url, fn"
+var errBrowserNoOpenTab = errors.New("no browser tab is open")
+var errBrowserSnapshotForAIUnavailable = errors.New("browser snapshotForAI is unavailable")
+var browserActKinds = []string{"click", "type", "press", "hover", "select", "fill", "resize", "wait", "evaluate", "close"}
+var browserGlobalTabCounter uint64
+
+type browserSnapshotItem struct {
+ Ref string `json:"ref,omitempty"`
+ AriaRef string `json:"ariaRef,omitempty"`
+ Role string `json:"role,omitempty"`
+ Name string `json:"name,omitempty"`
+ Text string `json:"text,omitempty"`
+ Depth int `json:"depth,omitempty"`
+ Nth int `json:"nth,omitempty"`
+}
+
+var browserAriaSnapshotLinePattern = regexp.MustCompile(`^(\s*)-\s*([^\s":]+)(?:\s+"([^"]*)")?(.*)$`)
+var browserAriaSnapshotRefPattern = regexp.MustCompile(`\[ref=([^\]]+)\]`)
+var browserStrictModeCountPattern = regexp.MustCompile(`resolved to (\d+) elements`)
+
+type browserResolvedConfig struct {
+ Enabled bool
+ Color string
+ Headless bool
+ DefaultProfile string
+ PreferredBrowser string
+ Profiles map[string]browserProfileConfig
+ SSRFRules browserSSRFPolicy
+}
+
+type browserProfileConfig struct {
+ Name string
+ Color string
+ Driver string
+}
+
+type browserSSRFPolicy struct {
+ DangerouslyAllowPrivateNetwork bool
+ AllowedHostnames map[string]struct{}
+ HostnameAllowlist []string
+}
+
+type browserNodeProxyEnvelope struct {
+ Result any `json:"result"`
+ Files []browserNodeProxyFile `json:"files"`
+}
+
+type browserNodeProxyFile struct {
+ Path string `json:"path"`
+ Base64 string `json:"base64"`
+ MimeType string `json:"mimeType"`
+}
+
+func resolveBrowserAction(payload toolArgs) (string, error) {
+ rawAction := strings.ToLower(strings.TrimSpace(getStringArg(payload, "action")))
+ if rawAction == "" {
+ return "", errors.New("browser action is required")
+ }
+ switch rawAction {
+ case "open", "navigate", "snapshot", "act", "wait", "scroll", "upload", "dialog", "reset":
+ return rawAction, nil
+ default:
+ return "", errors.New("browser action not supported: " + rawAction)
+ }
+}
+
+func pickReusableBrowserTargetID(infos []*targetpkg.Info) string {
+ choose := func(requireUnattached bool, preferBlank bool) string {
+ for _, info := range infos {
+ if info == nil || info.Type != "page" {
+ continue
+ }
+ if requireUnattached && info.Attached {
+ continue
+ }
+ if preferBlank && !isReusableBrowserPageURL(info.URL) {
+ continue
+ }
+ return string(info.TargetID)
+ }
+ return ""
+ }
+ for _, candidate := range []string{
+ choose(true, true),
+ choose(true, false),
+ choose(false, true),
+ choose(false, false),
+ } {
+ if strings.TrimSpace(candidate) != "" {
+ return candidate
+ }
+ }
+ return ""
+}
+
+func isReusableBrowserPageURL(rawURL string) bool {
+ trimmed := strings.TrimSpace(strings.ToLower(rawURL))
+ switch trimmed {
+ case "", "about:blank", "chrome://newtab/", "chrome-search://local-ntp/local-ntp.html":
+ return true
+ default:
+ return false
+ }
+}
+
+func shouldTreatBrowserNavigationAsComplete(observedURL string, previousURL string, targetURL string) bool {
+ observed := strings.TrimSpace(observedURL)
+ if observed == "" || observed == "about:blank" {
+ return false
+ }
+ if urlsEqual(observed, targetURL) {
+ return true
+ }
+ previous := strings.TrimSpace(previousURL)
+ if previous == "" || previous == "about:blank" {
+ return true
+ }
+ return !urlsEqual(observed, previous)
+}
+
+func urlsEqual(left string, right string) bool {
+ return strings.TrimSpace(left) == strings.TrimSpace(right)
+}
+
+func shouldResetBrowserProfileAfterError(err error) bool {
+ if err == nil {
+ return false
+ }
+ message := strings.ToLower(strings.TrimSpace(err.Error()))
+ switch {
+ case strings.Contains(message, "browser runtime unavailable"),
+ strings.Contains(message, "context canceled"),
+ strings.Contains(message, "target closed"),
+ strings.Contains(message, "connection closed"),
+ strings.Contains(message, "websocket"),
+ strings.Contains(message, "session closed"),
+ strings.Contains(message, "browser session reset"):
+ return true
+ default:
+ return false
+ }
+}
+
+func shouldDeferBrowserStateCaptureError(err error) bool {
+ if err == nil {
+ return false
+ }
+ message := strings.ToLower(strings.TrimSpace(err.Error()))
+ switch {
+ case strings.Contains(message, "context deadline exceeded"),
+ strings.Contains(message, "execution context was destroyed"),
+ strings.Contains(message, "cannot find context with specified id"),
+ strings.Contains(message, "unique context id not found"):
+ return true
+ default:
+ return false
+ }
+}
+
+func isBrowserNodeTargetRequest(payload toolArgs) bool {
+ target := strings.ToLower(strings.TrimSpace(getStringArg(payload, "target")))
+ nodeID := strings.TrimSpace(getStringArg(payload, "node", "nodeId"))
+ return target == "node" || nodeID != ""
+}
+
+func resolveBrowserNodeID(ctx context.Context, payload toolArgs, nodes *gatewaynodes.Service) (string, error) {
+ requestedNode := strings.TrimSpace(getStringArg(payload, "node", "nodeId"))
+ if requestedNode != "" {
+ return requestedNode, nil
+ }
+ if nodes == nil {
+ return "", errors.New("nodes service unavailable")
+ }
+ list, err := nodes.ListNodes(ctx)
+ if err != nil {
+ return "", err
+ }
+ for _, descriptor := range list {
+ nodeID := strings.TrimSpace(descriptor.NodeID)
+ if nodeID == "" {
+ continue
+ }
+ for _, capability := range descriptor.Capabilities {
+ if strings.EqualFold(strings.TrimSpace(capability.Name), "browser.control") {
+ return nodeID, nil
+ }
+ }
+ }
+ for _, descriptor := range list {
+ if nodeID := strings.TrimSpace(descriptor.NodeID); nodeID != "" {
+ return nodeID, nil
+ }
+ }
+ return "", errors.New("nodeId is required")
+}
+
+func runBrowserActionOnNode(ctx context.Context, payload toolArgs, action string, nodes *gatewaynodes.Service) (string, error) {
+ if nodes == nil {
+ return "", errors.New("nodes service unavailable")
+ }
+ target := strings.ToLower(strings.TrimSpace(getStringArg(payload, "target")))
+ if target != "" && target != "node" {
+ return "", errors.New(`node is only supported with target="node"`)
+ }
+ nodeID, err := resolveBrowserNodeID(ctx, payload, nodes)
+ if err != nil {
+ return "", err
+ }
+ argsJSON, err := json.Marshal(payload)
+ if err != nil {
+ return "", err
+ }
+ request := gatewaynodes.NodeInvokeRequest{
+ NodeID: nodeID,
+ Capability: "browser.control",
+ Action: action,
+ Args: string(argsJSON),
+ TimeoutMs: resolveBrowserActionTimeoutMs(payload, 30000),
+ }
+ result, invokeErr := nodes.Invoke(ctx, request)
+ if invokeErr != nil {
+ return marshalResult(result), invokeErr
+ }
+ if !result.Ok {
+ if strings.TrimSpace(result.Error) != "" {
+ return marshalResult(result), errors.New(strings.TrimSpace(result.Error))
+ }
+ return marshalResult(result), errors.New("node browser invoke failed")
+ }
+ if parsed := resolveBrowserNodeOutput(result.Output); parsed != nil {
+ return marshalResult(parsed), nil
+ }
+ return marshalResult(result), nil
+}
+
+func resolveBrowserNodeOutput(output string) any {
+ trimmedOutput := strings.TrimSpace(output)
+ if trimmedOutput == "" {
+ return nil
+ }
+ var parsed any
+ if err := json.Unmarshal([]byte(trimmedOutput), &parsed); err != nil {
+ return nil
+ }
+ envelope := browserNodeProxyEnvelope{}
+ if err := json.Unmarshal([]byte(trimmedOutput), &envelope); err == nil && envelope.Result != nil {
+ mapping := persistBrowserNodeProxyFiles(envelope.Files)
+ applyBrowserProxyPathMapping(envelope.Result, mapping)
+ return envelope.Result
+ }
+ return parsed
+}
+
+func persistBrowserNodeProxyFiles(files []browserNodeProxyFile) map[string]string {
+ if len(files) == 0 {
+ return nil
+ }
+ mapping := map[string]string{}
+ for _, file := range files {
+ remotePath := strings.TrimSpace(file.Path)
+ encoded := strings.TrimSpace(file.Base64)
+ if remotePath == "" || encoded == "" {
+ continue
+ }
+ bytes, err := base64.StdEncoding.DecodeString(encoded)
+ if err != nil {
+ continue
+ }
+ localPath, err := saveBrowserArtifact(resolveBrowserProxyFileExt(file), bytes)
+ if err != nil {
+ continue
+ }
+ mapping[remotePath] = localPath
+ }
+ if len(mapping) == 0 {
+ return nil
+ }
+ return mapping
+}
+
+func resolveBrowserProxyFileExt(file browserNodeProxyFile) string {
+ ext := strings.TrimSpace(strings.TrimPrefix(filepath.Ext(strings.TrimSpace(file.Path)), "."))
+ if ext != "" {
+ return ext
+ }
+ mimeType := strings.ToLower(strings.TrimSpace(file.MimeType))
+ switch {
+ case strings.Contains(mimeType, "png"):
+ return "png"
+ case strings.Contains(mimeType, "jpeg"), strings.Contains(mimeType, "jpg"):
+ return "jpg"
+ case strings.Contains(mimeType, "pdf"):
+ return "pdf"
+ case strings.Contains(mimeType, "json"):
+ return "json"
+ case strings.Contains(mimeType, "text"), strings.Contains(mimeType, "plain"):
+ return "txt"
+ default:
+ return "bin"
+ }
+}
+
+func applyBrowserProxyPathMapping(result any, mapping map[string]string) {
+ if len(mapping) == 0 || result == nil {
+ return
+ }
+ obj, ok := result.(map[string]any)
+ if !ok {
+ return
+ }
+ if pathValue, ok := obj["path"].(string); ok {
+ if mapped, exists := mapping[pathValue]; exists {
+ obj["path"] = mapped
+ }
+ }
+ if imagePathValue, ok := obj["imagePath"].(string); ok {
+ if mapped, exists := mapping[imagePathValue]; exists {
+ obj["imagePath"] = mapped
+ }
+ }
+}
+
+func resolveBrowserSessionKey(ctx context.Context, payload toolArgs) string {
+ sessionKey, _ := RuntimeContextFromContext(ctx)
+ sessionKey = strings.TrimSpace(sessionKey)
+ if sessionKey == "" {
+ sessionKey = strings.TrimSpace(getStringArg(payload, "sessionKey", "session_key"))
+ }
+ if sessionKey == "" {
+ sessionKey = "default"
+ }
+ return sessionKey
+}
+
+func resolveBrowserRuntimeConfig(config map[string]any) browserResolvedConfig {
+ browserConfig := getNestedMap(config, "browser")
+ resolved := browserResolvedConfig{
+ Enabled: true,
+ Color: defaultBrowserColor,
+ Headless: false,
+ DefaultProfile: defaultBrowserProfileDreamCreator,
+ PreferredBrowser: string(browsercdp.BrowserChrome),
+ Profiles: map[string]browserProfileConfig{
+ defaultBrowserProfileDreamCreator: {
+ Name: defaultBrowserProfileDreamCreator,
+ Color: defaultBrowserColor,
+ Driver: browserTypeCDP,
+ },
+ },
+ SSRFRules: browserSSRFPolicy{
+ DangerouslyAllowPrivateNetwork: false,
+ AllowedHostnames: map[string]struct{}{},
+ },
+ }
+ if browserConfig == nil {
+ return resolved
+ }
+ if value, ok := getBoolArg(toolArgs(browserConfig), "enabled"); ok {
+ resolved.Enabled = value
+ }
+ if value := strings.TrimSpace(getStringArg(toolArgs(browserConfig), "color")); value != "" {
+ resolved.Color = value
+ }
+ if value, ok := getBoolArg(toolArgs(browserConfig), "headless"); ok {
+ resolved.Headless = value
+ }
+ if value := strings.TrimSpace(getStringArg(toolArgs(browserConfig), "preferredBrowser")); value != "" {
+ resolved.PreferredBrowser = strings.ToLower(value)
+ }
+ if ssrfRaw := getMapArg(toolArgs(browserConfig), "ssrfPolicy"); ssrfRaw != nil {
+ if value, ok := getBoolArg(toolArgs(ssrfRaw), "dangerouslyAllowPrivateNetwork"); ok {
+ resolved.SSRFRules.DangerouslyAllowPrivateNetwork = value
+ }
+ for _, hostname := range getStringSliceArg(toolArgs(ssrfRaw), "allowedHostnames") {
+ resolved.SSRFRules.AllowedHostnames[strings.ToLower(strings.TrimSpace(hostname))] = struct{}{}
+ }
+ if allowlist := normalizeStringSlice(getStringSliceArg(toolArgs(ssrfRaw), "hostnameAllowlist")); len(allowlist) > 0 {
+ resolved.SSRFRules.HostnameAllowlist = allowlist
+ }
+ }
+ return resolved
+}
+
+func resolveBrowserProfileName(payload toolArgs, resolved browserResolvedConfig) string {
+ profile := strings.TrimSpace(getStringArg(payload, "profile"))
+ if profile == "" {
+ profile = strings.TrimSpace(resolved.DefaultProfile)
+ }
+ if profile == "" {
+ profile = defaultBrowserProfileDreamCreator
+ }
+ return profile
+}
+
+func resolveBrowserConfig(config map[string]any) map[string]any {
+ return getNestedMap(config, "browser")
+}
+
+func resolveBrowserConfigBool(config map[string]any, key string) (bool, bool) {
+ return getBoolArg(toolArgs(resolveBrowserConfig(config)), key)
+}
+
+func resolveBrowserActionTimeoutMs(payload toolArgs, fallback int) int {
+ if timeoutMs, ok := getIntArg(payload, "timeoutMs"); ok && timeoutMs > 0 {
+ return normalizeBrowserTimeoutMs(timeoutMs, fallback)
+ }
+ if timeoutSeconds, ok := getIntArg(payload, "timeoutSeconds"); ok && timeoutSeconds > 0 {
+ return normalizeBrowserTimeoutMs(timeoutSeconds*1000, fallback)
+ }
+ return normalizeBrowserTimeoutMs(fallback, fallback)
+}
+
+func resolveBrowserActTimeoutMs(request toolArgs, payload toolArgs, fallback int) int {
+ if timeoutMs, ok := getIntArg(request, "timeoutMs"); ok && timeoutMs > 0 {
+ return normalizeBrowserTimeoutMs(timeoutMs, fallback)
+ }
+ if timeoutSeconds, ok := getIntArg(request, "timeoutSeconds"); ok && timeoutSeconds > 0 {
+ return normalizeBrowserTimeoutMs(timeoutSeconds*1000, fallback)
+ }
+ return resolveBrowserActionTimeoutMs(payload, fallback)
+}
+
+func resolveBrowserSnapshotLimit(payload toolArgs) int {
+ if value, ok := getIntArg(payload, "limit"); ok && value > 0 {
+ return value
+ }
+ return defaultBrowserSnapshotLimit
+}
+
+func resolveBrowserScrollDelta(payload toolArgs) (int, int) {
+ if x, ok := getIntArg(payload, "x"); ok {
+ if y, ok := getIntArg(payload, "y"); ok {
+ return x, y
+ }
+ return x, 0
+ }
+ if y, ok := getIntArg(payload, "y"); ok {
+ return 0, y
+ }
+ amount, ok := getIntArg(payload, "amount")
+ if !ok || amount <= 0 {
+ amount = 700
+ }
+ switch strings.ToLower(strings.TrimSpace(getStringArg(payload, "direction"))) {
+ case "up":
+ return 0, -amount
+ case "left":
+ return -amount, 0
+ case "right":
+ return amount, 0
+ default:
+ return 0, amount
+ }
+}
+
+func normalizeBrowserTimeoutMs(value int, fallback int) int {
+ if value <= 0 {
+ value = fallback
+ }
+ if value < 500 {
+ return 500
+ }
+ if value > 120000 {
+ return 120000
+ }
+ return value
+}
+
+func parseBrowserAriaSnapshot(snapshot string, interactive bool, limit int, useAriaRefs bool, maxDepth int) []browserSnapshotItem {
+ if limit <= 0 {
+ limit = defaultBrowserSnapshotLimit
+ }
+ lines := strings.Split(snapshot, "\n")
+ items := make([]browserSnapshotItem, 0, minInt(len(lines), limit))
+ countByRoleName := map[string]int{}
+ nextRef := 0
+ for _, rawLine := range lines {
+ if len(items) >= limit {
+ break
+ }
+ line := strings.TrimRight(rawLine, "\r")
+ if strings.TrimSpace(line) == "" {
+ continue
+ }
+ match := browserAriaSnapshotLinePattern.FindStringSubmatch(line)
+ if len(match) < 5 {
+ continue
+ }
+ role := strings.ToLower(strings.TrimSpace(match[2]))
+ if role == "" || strings.HasPrefix(role, "/") {
+ continue
+ }
+ if interactive && !isBrowserInteractiveRole(role) {
+ continue
+ }
+ name := strings.TrimSpace(match[3])
+ suffix := strings.TrimSpace(match[4])
+ text := browserAriaSnapshotTextFromSuffix(suffix)
+ depth := browserAriaSnapshotDepth(line)
+ if maxDepth > 0 && depth > maxDepth {
+ continue
+ }
+ ariaRef := ""
+ if refMatch := browserAriaSnapshotRefPattern.FindStringSubmatch(suffix); len(refMatch) >= 2 {
+ ariaRef = strings.TrimSpace(refMatch[1])
+ }
+ key := role + "\n" + name
+ nth := countByRoleName[key]
+ countByRoleName[key] = nth + 1
+ nextRef++
+ ref := fmt.Sprintf("e%d", nextRef)
+ if useAriaRefs && ariaRef != "" {
+ ref = ariaRef
+ }
+ item := browserSnapshotItem{
+ Ref: ref,
+ AriaRef: ariaRef,
+ Role: role,
+ Name: name,
+ Text: text,
+ Depth: depth,
+ Nth: nth,
+ }
+ if item.Name == "" {
+ item.Name = item.Text
+ }
+ if item.Name == "" {
+ item.Name = "(empty)"
+ }
+ items = append(items, item)
+ }
+ return items
+}
+
+func browserAriaSnapshotDepth(line string) int {
+ spaces := 0
+ for _, r := range line {
+ if r == ' ' {
+ spaces++
+ continue
+ }
+ if r == '\t' {
+ spaces += 2
+ continue
+ }
+ break
+ }
+ return spaces / 2
+}
+
+func browserAriaSnapshotTextFromSuffix(suffix string) string {
+ suffix = strings.TrimSpace(suffix)
+ if suffix == "" {
+ return ""
+ }
+ if idx := strings.LastIndex(suffix, ":"); idx >= 0 && idx+1 < len(suffix) {
+ return strings.TrimSpace(suffix[idx+1:])
+ }
+ return ""
+}
+
+func isBrowserInteractiveRole(role string) bool {
+ switch strings.ToLower(strings.TrimSpace(role)) {
+ case "button", "link", "textbox", "checkbox", "radio", "combobox", "listbox", "menuitem", "menuitemcheckbox",
+ "menuitemradio", "option", "searchbox", "slider", "spinbutton", "switch", "tab", "treeitem":
+ return true
+ default:
+ return false
+ }
+}
+
+func resolveBrowserUploadRootDir() (string, error) {
+ dir := filepath.Join(os.TempDir(), "dreamcreator", "browser", "uploads")
+ if err := os.MkdirAll(dir, 0o755); err != nil {
+ return "", err
+ }
+ return filepath.Abs(dir)
+}
+
+func resolvePathWithinRoot(rootDir string, requestedPath string) (string, error) {
+ root := strings.TrimSpace(rootDir)
+ if root == "" {
+ return "", errors.New("upload root is required")
+ }
+ rootAbs, err := filepath.Abs(root)
+ if err != nil {
+ return "", err
+ }
+ raw := strings.TrimSpace(requestedPath)
+ if raw == "" {
+ return "", errors.New("paths are required")
+ }
+ var candidate string
+ if filepath.IsAbs(raw) {
+ candidate = filepath.Clean(raw)
+ } else {
+ candidate = filepath.Join(rootAbs, raw)
+ }
+ candidateAbs, err := filepath.Abs(candidate)
+ if err != nil {
+ return "", err
+ }
+ rel, err := filepath.Rel(rootAbs, candidateAbs)
+ if err != nil {
+ return "", err
+ }
+ rel = filepath.ToSlash(strings.TrimSpace(rel))
+ if rel == "." || strings.HasPrefix(rel, "../") || rel == ".." {
+ return "", fmt.Errorf("Invalid path: must stay within uploads directory (%s)", rootAbs)
+ }
+ return candidateAbs, nil
+}
+
+func isBrowserSnapshotForAIUnavailable(err error) bool {
+ if err == nil {
+ return false
+ }
+ message := strings.ToLower(strings.TrimSpace(err.Error()))
+ return strings.Contains(message, "_snapshotforai is not available") ||
+ strings.Contains(message, "snapshotforai is unavailable")
+}
+
+func toBrowserFriendlyInteractionError(err error, selector string) error {
+ if err == nil {
+ return nil
+ }
+ message := strings.TrimSpace(err.Error())
+ if strings.Contains(message, "strict mode violation") {
+ count := "multiple"
+ if match := browserStrictModeCountPattern.FindStringSubmatch(message); len(match) >= 2 {
+ count = strings.TrimSpace(match[1])
+ }
+ return fmt.Errorf(`Selector "%s" matched %s elements. Run snapshot again to get updated refs, or use a different ref.`, selector, count)
+ }
+ return err
+}
+
+func assertBrowserURLAllowed(rawURL string, policy browserSSRFPolicy) error {
+ return browsercdp.AssertURLAllowed(rawURL, browsercdp.SSRFPolicy{
+ DangerouslyAllowPrivateNetwork: policy.DangerouslyAllowPrivateNetwork,
+ AllowedHostnames: cloneBrowserAllowedHostnames(policy.AllowedHostnames),
+ HostnameAllowlist: append([]string(nil), policy.HostnameAllowlist...),
+ })
+}
+
+func isHostnameExplicitlyAllowed(hostname string, policy browserSSRFPolicy) bool {
+ hostname = strings.ToLower(strings.TrimSpace(hostname))
+ if hostname == "" {
+ return false
+ }
+ if _, ok := policy.AllowedHostnames[hostname]; ok {
+ return true
+ }
+ for _, pattern := range policy.HostnameAllowlist {
+ pattern = strings.ToLower(strings.TrimSpace(pattern))
+ if pattern == "" {
+ continue
+ }
+ if pattern == hostname {
+ return true
+ }
+ if strings.HasPrefix(pattern, "*.") {
+ suffix := strings.TrimPrefix(pattern, "*.")
+ if strings.HasSuffix(hostname, "."+suffix) || hostname == suffix {
+ return true
+ }
+ continue
+ }
+ if matched, _ := filepath.Match(pattern, hostname); matched {
+ return true
+ }
+ }
+ return false
+}
+
+func isPrivateOrLocalIP(ip net.IP) bool {
+ if ip == nil {
+ return false
+ }
+ if ip.IsLoopback() || ip.IsPrivate() || ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() || ip.IsUnspecified() || ip.IsMulticast() {
+ return true
+ }
+ if ip4 := ip.To4(); ip4 != nil {
+ if ip4[0] == 100 && (ip4[1]&0xC0) == 64 {
+ return true
+ }
+ }
+ return false
+}
+
+func saveBrowserArtifact(ext string, content []byte) (string, error) {
+ ext = strings.TrimSpace(strings.TrimPrefix(ext, "."))
+ if ext == "" {
+ ext = "bin"
+ }
+ dir := filepath.Join(os.TempDir(), "dreamcreator", "browser")
+ if err := os.MkdirAll(dir, 0o755); err != nil {
+ return "", err
+ }
+ filename := fmt.Sprintf("%d-%d.%s", time.Now().UnixNano(), atomic.AddUint64(&browserGlobalTabCounter, 1), ext)
+ path := filepath.Join(dir, filename)
+ if err := os.WriteFile(path, content, 0o644); err != nil {
+ return "", err
+ }
+ return filepath.Abs(path)
+}
+
+func resolveBrowserRuntimeAvailability(preferred string, headless bool) browsercdp.Status {
+ return browsercdp.ResolveStatus(preferred, headless)
+}
+
+func minInt(left int, right int) int {
+ if left < right {
+ return left
+ }
+ return right
+}
diff --git a/internal/application/gateway/tools/browser_tools_test.go b/internal/application/gateway/tools/browser_tools_test.go
index f4c13a4..be8bd52 100644
--- a/internal/application/gateway/tools/browser_tools_test.go
+++ b/internal/application/gateway/tools/browser_tools_test.go
@@ -9,29 +9,27 @@ import (
"path/filepath"
"strings"
"testing"
+
+ "dreamcreator/internal/application/agentruntime"
+ "dreamcreator/internal/application/browsercdp"
+ "github.com/cloudwego/eino/schema"
)
-func TestResolveBrowserActionDefaultsStatus(t *testing.T) {
+func TestResolveBrowserActionRequiresExplicitAction(t *testing.T) {
t.Parallel()
- action, err := resolveBrowserAction(toolArgs{})
- if err != nil {
- t.Fatalf("resolve action: %v", err)
- }
- if action != "status" {
- t.Fatalf("expected default action status, got %q", action)
+ _, err := resolveBrowserAction(toolArgs{})
+ if err == nil {
+ t.Fatalf("expected missing action error")
}
}
-func TestResolveBrowserActionDefaultsOpenWhenURLPresent(t *testing.T) {
+func TestResolveBrowserActionRequiresExplicitActionEvenWhenURLPresent(t *testing.T) {
t.Parallel()
- action, err := resolveBrowserAction(toolArgs{"url": "https://example.com"})
- if err != nil {
- t.Fatalf("resolve action: %v", err)
- }
- if action != "open" {
- t.Fatalf("expected default action open with url, got %q", action)
+ _, err := resolveBrowserAction(toolArgs{"targetUrl": "https://example.com"})
+ if err == nil {
+ t.Fatalf("expected missing action error")
}
}
@@ -44,6 +42,98 @@ func TestResolveBrowserActionRejectsUnsupportedFetchAlias(t *testing.T) {
}
}
+func TestSpecBrowserExposesWorkflowActionsOnly(t *testing.T) { // Verifies that the "action" enum in specBrowser's JSON schema is exactly the expected list, in the expected order.
+ t.Parallel()
+
+ var schema map[string]any // Local variable; inside this function it shadows the imported eino "schema" package.
+ if err := json.Unmarshal([]byte(specBrowser().SchemaJSON), &schema); err != nil {
+ t.Fatalf("decode schema: %v", err)
+ }
+ properties, _ := schema["properties"].(map[string]any) // Failed type assertions are ignored on purpose: a missing key leaves rawEnum empty and the comparison below fails.
+ actionDef, _ := properties["action"].(map[string]any)
+ rawEnum, _ := actionDef["enum"].([]any)
+ got := make([]string, 0, len(rawEnum))
+ for _, item := range rawEnum {
+ if value, ok := item.(string); ok && value != "" { // Non-string and empty enum entries are skipped.
+ got = append(got, value)
+ }
+ }
+ want := []string{"open", "navigate", "snapshot", "act", "wait", "scroll", "upload", "dialog", "reset"}
+ if strings.Join(got, ",") != strings.Join(want, ",") { // Order-sensitive comparison via the joined strings.
+ t.Fatalf("unexpected browser actions: got %v want %v", got, want)
+ }
+}
+
+func TestSpecBrowserDescriptionExplainsSnapshotWorkflow(t *testing.T) { // Verifies that specBrowser's Description contains each required substring checked below.
+ t.Parallel()
+
+ description := specBrowser().Description
+ if !strings.Contains(description, "browser-use style loop") {
+ t.Fatalf("expected browser description to mention browser-use loop, got %q", description)
+ }
+ if !strings.Contains(description, "pass `url` or `targetUrl`") {
+ t.Fatalf("expected browser description to document url alias, got %q", description)
+ }
+ if !strings.Contains(description, "return `stateAvailable`, `itemCount`, and the current page `state`/`items`") {
+ t.Fatalf("expected browser description to explain open/navigate result state, got %q", description)
+ }
+ if !strings.Contains(description, "`stateAvailable=false`") {
+ t.Fatalf("expected browser description to explain stateAvailable fallback, got %q", description)
+ }
+ if !strings.Contains(description, "call `snapshot` to refresh them") {
+ t.Fatalf("expected browser description to explain open/navigate flow, got %q", description)
+ }
+ if !strings.Contains(description, "use `ref` from the latest state") {
+ t.Fatalf("expected browser description to prefer refs, got %q", description)
+ }
+} // Each check is an exact, case-sensitive substring match, so rewording the description requires updating this test.
+
+func TestSpecBrowserSchemaAcceptsOpenURLAlias(t *testing.T) { // Verifies that an "open" call passing `url` validates against specBrowser's schema.
+ t.Parallel()
+
+ validator := agentruntime.JSONToolValidator{ // Validator is built with only the "browser" tool, using the schema under test.
+ Tools: map[string]agentruntime.ToolDefinition{
+ "browser": {
+ Name: "browser",
+ SchemaJSON: specBrowser().SchemaJSON,
+ },
+ },
+ }
+
+ err := validator.Validate(schema.ToolCall{
+ Function: schema.FunctionCall{
+ Name: "browser",
+ Arguments: `{"action":"open","url":"https://example.com"}`,
+ },
+ })
+ if err != nil { // Any validation error fails the test.
+ t.Fatalf("expected open url alias to validate, got %v", err)
+ }
+}
+
+func TestSpecBrowserSchemaAcceptsNavigateURLAlias(t *testing.T) { // Verifies that a "navigate" call passing `url` validates against specBrowser's schema.
+ t.Parallel()
+
+ validator := agentruntime.JSONToolValidator{ // Validator is built with only the "browser" tool, using the schema under test.
+ Tools: map[string]agentruntime.ToolDefinition{
+ "browser": {
+ Name: "browser",
+ SchemaJSON: specBrowser().SchemaJSON,
+ },
+ },
+ }
+
+ err := validator.Validate(schema.ToolCall{
+ Function: schema.FunctionCall{
+ Name: "browser",
+ Arguments: `{"action":"navigate","url":"https://example.com"}`,
+ },
+ })
+ if err != nil { // Any validation error fails the test.
+ t.Fatalf("expected navigate url alias to validate, got %v", err)
+ }
+}
+
func TestResolveBrowserRuntimeConfigDefaults(t *testing.T) {
t.Parallel()
@@ -51,17 +141,14 @@ func TestResolveBrowserRuntimeConfigDefaults(t *testing.T) {
if !resolved.Enabled {
t.Fatalf("expected browser enabled by default")
}
- if !resolved.EvaluateEnabled {
- t.Fatalf("expected evaluateEnabled default true")
- }
if resolved.DefaultProfile != defaultBrowserProfileDreamCreator {
t.Fatalf("expected default profile %q, got %q", defaultBrowserProfileDreamCreator, resolved.DefaultProfile)
}
if _, ok := resolved.Profiles[defaultBrowserProfileDreamCreator]; !ok {
t.Fatalf("expected dreamcreator profile default")
}
- if !resolved.SSRFRules.DangerouslyAllowPrivateNetwork {
- t.Fatalf("expected ssrf dangerous allow default true")
+ if resolved.SSRFRules.DangerouslyAllowPrivateNetwork {
+ t.Fatalf("expected ssrf dangerous allow default false")
}
}
@@ -70,11 +157,8 @@ func TestResolveBrowserRuntimeConfigReadsBrowserSettings(t *testing.T) {
resolved := resolveBrowserRuntimeConfig(map[string]any{
"browser": map[string]any{
- "enabled": false,
- "evaluateEnabled": false,
- "headless": true,
- "noSandbox": true,
- "extraArgs": []any{"--window-size=1280,900"},
+ "enabled": false,
+ "headless": true,
"ssrfPolicy": map[string]any{
"dangerouslyAllowPrivateNetwork": false,
"allowedHostnames": []any{"localhost"},
@@ -85,18 +169,9 @@ func TestResolveBrowserRuntimeConfigReadsBrowserSettings(t *testing.T) {
if resolved.Enabled {
t.Fatalf("expected enabled false")
}
- if resolved.EvaluateEnabled {
- t.Fatalf("expected evaluateEnabled false")
- }
if !resolved.Headless {
t.Fatalf("expected headless true")
}
- if !resolved.NoSandbox {
- t.Fatalf("expected noSandbox true")
- }
- if len(resolved.ExtraArgs) != 1 || resolved.ExtraArgs[0] != "--window-size=1280,900" {
- t.Fatalf("expected extraArgs from config")
- }
if resolved.SSRFRules.DangerouslyAllowPrivateNetwork {
t.Fatalf("expected ssrf dangerous allow false")
}
@@ -108,6 +183,108 @@ func TestResolveBrowserRuntimeConfigReadsBrowserSettings(t *testing.T) {
}
}
+func TestBrowserToolRuntimeConfigChangedDetectsRuntimeChanges(t *testing.T) { // Table-driven check of BrowserToolRuntimeConfigChanged over pairs of settings maps.
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ previous map[string]any
+ current map[string]any
+ want bool
+ }{
+ {
+ name: "same config",
+ previous: map[string]any{
+ "browser": map[string]any{
+ "enabled": true,
+ "headless": true,
+ "preferredBrowser": "chrome",
+ },
+ },
+ current: map[string]any{
+ "browser": map[string]any{
+ "enabled": true,
+ "headless": true,
+ "preferredBrowser": "chrome",
+ },
+ },
+ want: false,
+ },
+ {
+ name: "headless changed",
+ previous: map[string]any{
+ "browser": map[string]any{
+ "headless": false,
+ },
+ },
+ current: map[string]any{
+ "browser": map[string]any{
+ "headless": true,
+ },
+ },
+ want: true,
+ },
+ {
+ name: "preferred browser changed",
+ previous: map[string]any{
+ "browser": map[string]any{
+ "preferredBrowser": "chrome",
+ },
+ },
+ current: map[string]any{
+ "browser": map[string]any{
+ "preferredBrowser": "brave",
+ },
+ },
+ want: true,
+ },
+ {
+ name: "enabled changed",
+ previous: map[string]any{
+ "browser": map[string]any{
+ "enabled": true,
+ },
+ },
+ current: map[string]any{
+ "browser": map[string]any{
+ "enabled": false,
+ },
+ },
+ want: true,
+ },
+ {
+ name: "ssrf changed only",
+ previous: map[string]any{
+ "browser": map[string]any{
+ "headless": true,
+ "ssrfPolicy": map[string]any{
+ "dangerouslyAllowPrivateNetwork": false,
+ },
+ },
+ },
+ current: map[string]any{
+ "browser": map[string]any{
+ "headless": true,
+ "ssrfPolicy": map[string]any{
+ "dangerouslyAllowPrivateNetwork": true,
+ },
+ },
+ },
+ want: false, // A difference limited to ssrfPolicy is expected NOT to be reported as a runtime change.
+ },
+ }
+
+ for _, tt := range tests {
+ tt := tt // Per-iteration copy for the parallel subtest closure (required before Go 1.22 loop-variable semantics).
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+ if got := BrowserToolRuntimeConfigChanged(tt.previous, tt.current); got != tt.want {
+ t.Fatalf("BrowserToolRuntimeConfigChanged() = %v, want %v", got, tt.want)
+ }
+ })
+ }
+}
+
func TestAssertBrowserURLAllowedStrictPolicyBlocksPrivate(t *testing.T) {
t.Parallel()
@@ -227,7 +404,7 @@ func TestParseBrowserAriaSnapshotRespectsMaxDepth(t *testing.T) {
func TestBrowserActionActRejectsSelectorForNonWaitBeforeRuntime(t *testing.T) {
t.Parallel()
- _, err := browserActionAct(toolArgs{
+ _, err := browserActionAct(context.Background(), toolArgs{
"request": map[string]any{
"kind": "click",
"selector": "button.save",
@@ -241,6 +418,50 @@ func TestBrowserActionActRejectsSelectorForNonWaitBeforeRuntime(t *testing.T) {
}
}
+func TestResolveBrowserActionAcceptsSnapshotAction(t *testing.T) { // Verifies that resolveBrowserAction accepts an explicit "snapshot" action and returns it unchanged.
+ t.Parallel()
+
+ action, err := resolveBrowserAction(toolArgs{
+ "action": "snapshot",
+ })
+ if err != nil {
+ t.Fatalf("expected snapshot action to be accepted, got %v", err)
+ }
+ if action != "snapshot" {
+ t.Fatalf("unexpected action: got %q want snapshot", action)
+ }
+}
+
+func TestResolveBrowserActionRejectsStateAction(t *testing.T) { // Verifies that resolveBrowserAction rejects the "state" action with a "not supported" error.
+ t.Parallel()
+
+ _, err := resolveBrowserAction(toolArgs{
+ "action": "state",
+ })
+ if err == nil {
+ t.Fatalf("expected state action to be rejected")
+ }
+ if !strings.Contains(err.Error(), "not supported") { // The error text itself is part of the checked contract.
+ t.Fatalf("unexpected error: %v", err)
+ }
+}
+
+func TestBrowserActionActRejectsUnsupportedNavigateKind(t *testing.T) { // Verifies that browserActionAct rejects request kind "navigate" with an "act kind not supported" error.
+ t.Parallel()
+
+ _, err := browserActionAct(context.Background(), toolArgs{
+ "request": map[string]any{
+ "kind": "navigate",
+ },
+ }, &browserProfileState{}) // A zero-value profile state is passed; presumably the kind check fails before any runtime is needed.
+ if err == nil {
+ t.Fatalf("expected unsupported kind error")
+ }
+ if !strings.Contains(err.Error(), "act kind not supported") {
+ t.Fatalf("unexpected error: %v", err)
+ }
+}
+
func TestNormalizeBrowserTimeoutMsClampRange(t *testing.T) {
t.Parallel()
@@ -255,28 +476,21 @@ func TestNormalizeBrowserTimeoutMsClampRange(t *testing.T) {
}
}
-func TestTabResultFromEvaluateReturnsStoredValue(t *testing.T) {
+func TestResolveBrowserScrollDeltaDefaultsDown(t *testing.T) {
t.Parallel()
- tab := &browserTabState{TargetID: "tab-1"}
- tab.mu.Lock()
- tab.evaluateResult = map[string]any{"ok": true, "value": "done"}
- tab.mu.Unlock()
-
- result, ok := tabResultFromEvaluate(tab).(map[string]any)
- if !ok {
- t.Fatalf("expected evaluate result map")
- }
- if value, _ := result["value"].(string); value != "done" {
- t.Fatalf("expected stored evaluate result, got %#v", result)
+ x, y := resolveBrowserScrollDelta(toolArgs{})
+ if x != 0 || y != 700 {
+ t.Fatalf("expected default down scroll, got %d/%d", x, y)
}
}
-func TestTabResultFromEvaluateNilTab(t *testing.T) {
+func TestResolveBrowserScrollDeltaHonorsDirectionAndAmount(t *testing.T) {
t.Parallel()
- if result := tabResultFromEvaluate(nil); result != nil {
- t.Fatalf("expected nil result for nil tab, got %#v", result)
+ x, y := resolveBrowserScrollDelta(toolArgs{"direction": "left", "amount": 320})
+ if x != -320 || y != 0 {
+ t.Fatalf("expected left scroll -320/0, got %d/%d", x, y)
}
}
@@ -373,14 +587,58 @@ func TestResolveBrowserNodeOutputReturnsNilOnInvalidJSON(t *testing.T) {
func TestIsBrowserSnapshotForAIUnavailable(t *testing.T) {
t.Parallel()
- if !isBrowserSnapshotForAIUnavailable(errors.New("Playwright _snapshotForAI is not available")) {
- t.Fatalf("expected _snapshotForAI error to be treated as unavailable")
+ if !isBrowserSnapshotForAIUnavailable(errors.New("browser snapshotForAI is unavailable")) {
+ t.Fatalf("expected snapshotForAI unavailable error to be treated as unavailable")
}
if isBrowserSnapshotForAIUnavailable(errors.New("network timeout")) {
t.Fatalf("expected unrelated error to remain actionable")
}
}
+func TestBrowserResultMapPreservesSnapshotPayload(t *testing.T) { // Verifies that browserResultMap keeps the state/items/status fields of a browsercdp.ActionResult in its output map.
+ t.Parallel()
+
+ result := browserResultMap(browsercdp.ActionResult{
+ OK: true,
+ TargetID: "tab-1",
+ URL: "https://example.com",
+ StateVersion: 7,
+ Items: []browsercdp.SnapshotItem{
+ {Ref: "e1", Role: "button", Name: "Save"},
+ },
+ State: &browsercdp.PageState{
+ Version: 7,
+ URL: "https://example.com",
+ ItemCount: 1,
+ CapturedAt: "2026-04-18T00:00:00Z",
+ },
+ StateAvailable: true,
+ StateError: "warning",
+ })
+
+ if _, ok := result["state"]; !ok { // Only presence of the key is checked, not its contents.
+ t.Fatalf("expected state to be preserved")
+ }
+ if got := result["stateVersion"]; got != float64(7) { // Matches only a float64 value. NOTE(review): presumably the map is produced via a JSON round-trip; confirm.
+ t.Fatalf("expected stateVersion to remain, got %#v", got)
+ }
+ if _, ok := result["items"]; !ok { // Only presence of the key is checked, not its contents.
+ t.Fatalf("expected items to be preserved")
+ }
+ if got := result["stateAvailable"]; got != true {
+ t.Fatalf("expected stateAvailable to remain, got %#v", got)
+ }
+ if got := result["itemCount"]; got != 1 { // The constant 1 is compared as int, unlike stateVersion above. NOTE(review): confirm itemCount is set as an int and not decoded from JSON.
+ t.Fatalf("expected itemCount to remain, got %#v", got)
+ }
+ if got := result["stateError"]; got != "warning" {
+ t.Fatalf("expected stateError to remain, got %#v", got)
+ }
+ if got := result["targetId"]; got != "tab-1" {
+ t.Fatalf("expected targetId to remain, got %#v", got)
+ }
+}
+
func TestToBrowserFriendlyInteractionErrorStrictMode(t *testing.T) {
t.Parallel()
diff --git a/internal/application/gateway/tools/builtin_requirement_resolver.go b/internal/application/gateway/tools/builtin_requirement_resolver.go
new file mode 100644
index 0000000..f6f5a3e
--- /dev/null
+++ b/internal/application/gateway/tools/builtin_requirement_resolver.go
@@ -0,0 +1,203 @@
+package tools
+
+import (
+ "context"
+ "strings"
+
+ assistantservice "dreamcreator/internal/application/assistant/service"
+ gatewayvoice "dreamcreator/internal/application/gateway/voice"
+ tooldto "dreamcreator/internal/application/tools/dto"
+ "dreamcreator/internal/domain/providers"
+)
+
+const ( // Requirement IDs (dotted "tool.requirement" strings) and the Reason texts used by the builtin image/TTS requirement resolver in this file.
+ imageRequirementID = "image.model_runtime"
+ ttsServiceRequirementID = "tts.voice_service"
+ ttsVoiceEnabledRequirementID = "tts.voice_enabled"
+ ttsProviderRequirementID = "tts.provider_supported"
+ ttsProviderAPIKeyRequirementID = "tts.provider_api_key"
+ ttsVoiceIDRequirementID = "tts.voice_id"
+ imageModelNotConfiguredReason = "Image model is not configured" // Fallback reason when image candidate resolution fails with an empty error message.
+ providerRepositoriesReason = "Provider repositories are unavailable"
+ voiceDisabledReason = "Voice is disabled"
+ voiceServiceUnavailableReason = "Voice service unavailable" // Used when the voice dependency is nil, and as the fallback when its Status call fails.
+ ttsProviderAPIKeyMissingReason = "TTS provider API key is missing"
+ ttsVoiceIDMissingReason = "TTS voice ID is not configured"
+ ttsEdgeProviderUnavailableReason = "Edge-TTS provider is not implemented yet"
+ ttsUnsupportedProviderReason = "TTS provider is not supported"
+)
+
+type voiceToolStatusProvider interface { // Narrow interface over the voice service: only the Status call used by the TTS requirement checks.
+ Status(ctx context.Context) (gatewayvoice.TTSStatusResponse, error)
+}
+
+type BuiltinRequirementDeps struct { // Dependencies consumed by the builtin requirement resolver created with NewBuiltinRequirementResolver.
+ Settings SettingsReader // Passed to resolveImageToolConfiguredPrimaryRef for the image check.
+ Assistants *assistantservice.AssistantService // Passed to resolveImageToolCandidates for the image check.
+ Providers providers.ProviderRepository // If nil, the image requirement is reported unavailable.
+ Models providers.ModelRepository // If nil, the image requirement is reported unavailable.
+ Secrets providers.SecretRepository // If nil, the image requirement is reported unavailable.
+ Voice voiceToolStatusProvider // If nil, the TTS voice-service requirement is reported unavailable.
+}
+
+type builtinRequirementResolver struct { // Value type holding the dependencies; its methods resolve requirements for the "image" and "tts" tools.
+ deps BuiltinRequirementDeps
+}
+
+func NewBuiltinRequirementResolver(deps BuiltinRequirementDeps) ToolRequirementResolver { // Wraps deps in a builtinRequirementResolver value and returns it as a ToolRequirementResolver.
+ return builtinRequirementResolver{deps: deps}
+}
+
+func (resolver builtinRequirementResolver) ResolveToolRequirements(ctx context.Context, spec tooldto.ToolSpec) []tooldto.ToolRequirement { // Returns the requirement list for spec; nil for tools without builtin requirements.
+ switch resolveToolRequirementKey(spec) { // Dispatch on the key derived from the spec by resolveToolRequirementKey.
+ case "image":
+ return resolver.resolveImageRequirements(ctx)
+ case "tts":
+ return resolver.resolveTTSRequirements(ctx)
+ default:
+ return nil
+ }
+}
+
+func (resolver builtinRequirementResolver) resolveImageRequirements(ctx context.Context) []tooldto.ToolRequirement { // Always returns exactly one "image.model_runtime" requirement, marked available or not.
+ requirement := tooldto.ToolRequirement{ // Starts as available; flipped to unavailable on any failure below.
+ ID: imageRequirementID,
+ Name: "Image model",
+ Available: true,
+ }
+ if resolver.deps.Providers == nil || resolver.deps.Secrets == nil || resolver.deps.Models == nil { // Any missing repository short-circuits with a fixed reason.
+ requirement.Available = false
+ requirement.Reason = providerRepositoriesReason
+ return []tooldto.ToolRequirement{requirement}
+ }
+ configuredPrimaryRef := resolveImageToolConfiguredPrimaryRef(ctx, resolver.deps.Settings)
+ _, _, err := resolveImageToolCandidates( // Only the error is used; the other two return values are discarded.
+ ctx,
+ resolver.deps.Assistants,
+ "",
+ configuredPrimaryRef,
+ resolver.deps.Providers,
+ resolver.deps.Models,
+ resolver.deps.Secrets,
+ )
+ if err == nil {
+ return []tooldto.ToolRequirement{requirement}
+ }
+ requirement.Available = false
+ requirement.Reason = normalizeBuiltinRequirementReason(err, imageModelNotConfiguredReason) // Uses the error text, or the fallback if that text is blank.
+ return []tooldto.ToolRequirement{requirement}
+}
+
+func (resolver builtinRequirementResolver) resolveTTSRequirements(ctx context.Context) []tooldto.ToolRequirement { // Builds the TTS requirement list from the voice service status: voice enabled, provider, and provider-specific items.
+ if resolver.deps.Voice == nil { // No voice dependency: report a single unavailable "voice service" requirement.
+ return []tooldto.ToolRequirement{
+ {
+ ID: ttsServiceRequirementID,
+ Name: "Voice service",
+ Available: false,
+ Reason: voiceServiceUnavailableReason,
+ },
+ }
+ }
+ status, err := resolver.deps.Voice.Status(ctx)
+ if err != nil { // Status failure: same single requirement, with the error text as reason (fallback if blank).
+ return []tooldto.ToolRequirement{
+ {
+ ID: ttsServiceRequirementID,
+ Name: "Voice service",
+ Available: false,
+ Reason: normalizeBuiltinRequirementReason(err, voiceServiceUnavailableReason),
+ },
+ }
+ }
+ providerID := strings.ToLower(strings.TrimSpace(status.Config.ProviderID))
+ if providerID == "" {
+ providerID = "edge" // An unset provider is treated as "edge".
+ }
+ requirements := []tooldto.ToolRequirement{ // First entry always reflects whether voice is enabled.
+ {
+ ID: ttsVoiceEnabledRequirementID,
+ Name: "Voice feature",
+ Available: status.Enabled,
+ },
+ }
+ if !status.Enabled {
+ requirements[0].Reason = voiceDisabledReason
+ }
+ switch providerID {
+ case "edge": // The edge provider is always reported as unavailable here.
+ requirements = append(requirements, tooldto.ToolRequirement{
+ ID: ttsProviderRequirementID,
+ Name: "Provider",
+ Available: false,
+ Reason: ttsEdgeProviderUnavailableReason,
+ Data: map[string]any{
+ "providerId": providerID,
+ },
+ })
+ return requirements
+ case "openai", "elevenlabs":
+ providerReady := false // Stays false when status.Providers has no entry matching providerID.
+ for _, provider := range status.Providers {
+ if strings.EqualFold(strings.TrimSpace(provider.ProviderID), providerID) { // Case-insensitive match; the first match wins.
+ providerReady = provider.Available
+ break
+ }
+ }
+ requirements = append(requirements,
+ tooldto.ToolRequirement{ // The provider itself is always reported as supported for these two IDs.
+ ID: ttsProviderRequirementID,
+ Name: "Provider",
+ Available: true,
+ Data: map[string]any{
+ "providerId": providerID,
+ },
+ },
+ tooldto.ToolRequirement{ // The matched provider's Available flag is surfaced as the API-key requirement.
+ ID: ttsProviderAPIKeyRequirementID,
+ Name: "Provider API key",
+ Available: providerReady,
+ },
+ )
+ if !providerReady {
+ requirements[len(requirements)-1].Reason = ttsProviderAPIKeyMissingReason // Last element is the API-key requirement appended just above.
+ }
+ if providerID == "elevenlabs" { // Only elevenlabs additionally requires a configured voice ID.
+ voiceIDConfigured := strings.TrimSpace(status.Config.VoiceID) != ""
+ requirements = append(requirements, tooldto.ToolRequirement{
+ ID: ttsVoiceIDRequirementID,
+ Name: "Voice ID",
+ Available: voiceIDConfigured,
+ Data: map[string]any{
+ "value": strings.TrimSpace(status.Config.VoiceID),
+ },
+ })
+ if !voiceIDConfigured {
+ requirements[len(requirements)-1].Reason = ttsVoiceIDMissingReason // Last element is the voice-ID requirement appended just above.
+ }
+ }
+ return requirements
+ default: // Any other provider ID is reported as unsupported.
+ requirements = append(requirements, tooldto.ToolRequirement{
+ ID: ttsProviderRequirementID,
+ Name: "Provider",
+ Available: false,
+ Reason: ttsUnsupportedProviderReason,
+ Data: map[string]any{
+ "providerId": providerID,
+ },
+ })
+ return requirements
+ }
+}
+
+func normalizeBuiltinRequirementReason(err error, fallback string) string { // Returns err's trimmed message, or fallback when err is nil or its message is blank.
+ if err == nil {
+ return fallback
+ }
+ message := strings.TrimSpace(err.Error())
+ if message == "" { // Whitespace-only messages also fall back.
+ return fallback
+ }
+ return message
+}
diff --git a/internal/application/gateway/tools/builtin_specs.go b/internal/application/gateway/tools/builtin_specs.go
index d0d5945..2c5b190 100644
--- a/internal/application/gateway/tools/builtin_specs.go
+++ b/internal/application/gateway/tools/builtin_specs.go
@@ -170,26 +170,16 @@ func specWebFetch() toolSpec {
return toolSpec{
ID: "web_fetch",
Name: "web_fetch",
- Description: "Fetch a web page and return structured status fields (status, retryable, next_action, quality). Do not blind-retry the same call when status is not ok.",
+ Description: "Fetch a web page through a local CDP browser, extract token-efficient main content, and return structured status fields (status, retryable, next_action, quality). Do not blind-retry the same call when status is not ok.",
Category: "web",
SchemaJSON: schemaJSON(map[string]any{
"type": "object",
"properties": map[string]any{
- "url": map[string]any{"type": "string"},
- "type": map[string]any{"type": "string"},
- "method": map[string]any{"type": "string"},
- "headers": map[string]any{"type": "object"},
- "maxChars": map[string]any{"type": "integer"},
- "maxBodyBytes": map[string]any{"type": "integer"},
- "maxRedirects": map[string]any{"type": "integer"},
- "retryMax": map[string]any{"type": "integer"},
- "timeoutSeconds": map[string]any{"type": "integer"},
- "acceptMarkdown": map[string]any{"type": "boolean"},
- "markdown": map[string]any{"type": "boolean"},
- "toMarkdown": map[string]any{"type": "boolean"},
- "enableUserAgent": map[string]any{"type": "boolean"},
- "userAgent": map[string]any{"type": "string"},
- "acceptLanguage": map[string]any{"type": "string"},
+ "url": map[string]any{"type": "string"},
+ "method": map[string]any{"type": "string"},
+ "maxChars": map[string]any{"type": "integer"},
+ "maxBodyBytes": map[string]any{"type": "integer"},
+ "timeoutSeconds": map[string]any{"type": "integer"},
},
"required": []string{"url"},
}),
@@ -227,207 +217,6 @@ func specWebSearch() toolSpec {
}
}
-func specBrowser() toolSpec {
- return toolSpec{
- ID: "browser",
- Name: "browser",
- Description: "Control the browser via Playwright runtime (status/start/stop/profiles/tabs/open/focus/close/snapshot/screenshot/navigate/console/pdf/upload/dialog/act).",
- Category: "ui",
- RiskLevel: "high",
- SchemaJSON: schemaJSON(map[string]any{
- "type": "object",
- "properties": map[string]any{
- "action": map[string]any{
- "type": "string",
- "enum": []string{
- "status",
- "start",
- "stop",
- "profiles",
- "tabs",
- "open",
- "focus",
- "close",
- "snapshot",
- "screenshot",
- "navigate",
- "console",
- "pdf",
- "upload",
- "dialog",
- "act",
- },
- },
- "target": map[string]any{
- "type": "string",
- "enum": []string{"sandbox", "host", "node"},
- },
- "node": map[string]any{"type": "string"},
- "profile": map[string]any{"type": "string"},
- "targetUrl": map[string]any{"type": "string"},
- "targetId": map[string]any{"type": "string"},
- "limit": map[string]any{"type": "integer"},
- "maxChars": map[string]any{"type": "integer"},
- "mode": map[string]any{"type": "string", "enum": []string{"efficient"}},
- "timeoutMs": map[string]any{"type": "integer"},
- "snapshotFormat": map[string]any{
- "type": "string",
- "enum": []string{"aria", "ai"},
- },
- "refs": map[string]any{
- "type": "string",
- "enum": []string{"role", "aria"},
- },
- "interactive": map[string]any{"type": "boolean"},
- "compact": map[string]any{"type": "boolean"},
- "depth": map[string]any{"type": "integer"},
- "selector": map[string]any{"type": "string"},
- "frame": map[string]any{"type": "string"},
- "labels": map[string]any{"type": "boolean"},
- "fullPage": map[string]any{"type": "boolean"},
- "ref": map[string]any{"type": "string"},
- "element": map[string]any{"type": "string"},
- "type": map[string]any{
- "type": "string",
- "enum": []string{"png", "jpeg"},
- },
- "level": map[string]any{"type": "string"},
- "paths": map[string]any{
- "type": "array",
- "items": map[string]any{"type": "string"},
- },
- "inputRef": map[string]any{"type": "string"},
- "accept": map[string]any{"type": "boolean"},
- "promptText": map[string]any{"type": "string"},
- "request": map[string]any{
- "type": "object",
- "properties": map[string]any{
- "kind": map[string]any{
- "type": "string",
- "enum": []string{
- "click",
- "type",
- "press",
- "hover",
- "drag",
- "select",
- "fill",
- "resize",
- "wait",
- "evaluate",
- "close",
- },
- },
- "targetId": map[string]any{"type": "string"},
- "ref": map[string]any{"type": "string"},
- "doubleClick": map[string]any{"type": "boolean"},
- "button": map[string]any{"type": "string"},
- "modifiers": map[string]any{
- "type": "array",
- "items": map[string]any{"type": "string"},
- },
- "text": map[string]any{"type": "string"},
- "submit": map[string]any{"type": "boolean"},
- "slowly": map[string]any{"type": "boolean"},
- "key": map[string]any{"type": "string"},
- "startRef": map[string]any{"type": "string"},
- "endRef": map[string]any{"type": "string"},
- "values": map[string]any{"type": "array", "items": map[string]any{"type": "string"}},
- "fields": map[string]any{"type": "array", "items": map[string]any{"type": "object"}},
- "width": map[string]any{"type": "number"},
- "height": map[string]any{"type": "number"},
- "timeMs": map[string]any{"type": "number"},
- "textGone": map[string]any{"type": "string"},
- "selector": map[string]any{"type": "string"},
- "url": map[string]any{"type": "string"},
- "loadState": map[string]any{"type": "string"},
- "fn": map[string]any{"type": "string"},
- "timeoutMs": map[string]any{"type": "number"},
- },
- },
- },
- "required": []string{"action"},
- }),
- RequiresSandbox: true,
- RequiresApproval: true,
- Enabled: true,
- }
-}
-
-func specCanvas() toolSpec {
- return toolSpec{
- ID: "canvas",
- Name: "canvas",
- Description: "Control node canvases (present/hide/navigate/eval/snapshot/a2ui).",
- Category: "ui",
- SchemaJSON: schemaJSON(map[string]any{
- "type": "object",
- "properties": map[string]any{
- "action": map[string]any{
- "type": "string",
- "enum": []string{
- "present",
- "hide",
- "navigate",
- "eval",
- "snapshot",
- "a2ui_push",
- "a2ui_reset",
- },
- },
- "gatewayUrl": map[string]any{"type": "string"},
- "gatewayToken": map[string]any{"type": "string"},
- "timeoutMs": map[string]any{"type": "number"},
- "node": map[string]any{"type": "string"},
- "target": map[string]any{"type": "string"},
- "x": map[string]any{"type": "number"},
- "y": map[string]any{"type": "number"},
- "width": map[string]any{"type": "number"},
- "height": map[string]any{"type": "number"},
- "url": map[string]any{"type": "string"},
- "javaScript": map[string]any{"type": "string"},
- "outputFormat": map[string]any{
- "type": "string",
- "enum": []string{"png", "jpg", "jpeg"},
- },
- "maxWidth": map[string]any{"type": "number"},
- "quality": map[string]any{"type": "number"},
- "delayMs": map[string]any{"type": "number"},
- "jsonl": map[string]any{"type": "string"},
- "jsonlPath": map[string]any{
- "type": "string",
- },
- },
- "required": []string{"action"},
- }),
- Enabled: true,
- }
-}
-
-func specImage() toolSpec {
- return toolSpec{
- ID: "image",
- Name: "image",
- Description: "Analyze one or more images with the configured image model.",
- Category: "media",
- SchemaJSON: schemaJSON(map[string]any{
- "type": "object",
- "properties": map[string]any{
- "prompt": map[string]any{"type": "string"},
- "image": map[string]any{"type": "string"},
- "images": map[string]any{
- "type": "array",
- "items": map[string]any{"type": "string"},
- },
- "model": map[string]any{"type": "string"},
- "maxBytesMb": map[string]any{"type": "number"},
- "maxImages": map[string]any{"type": "number"},
- },
- }),
- Enabled: true,
- }
-}
-
func specMessage(ctx context.Context, settings SettingsReader) toolSpec {
spec := specMessageBase()
profile := resolveMessageToolSchemaProfile(ctx, settings)
@@ -1274,7 +1063,7 @@ func specAgentsList() toolSpec {
return toolSpec{
ID: "agents_list",
Name: "agents_list",
- Description: "List agents.",
+ Description: "List available agent profiles for subagent spawning.",
Category: "sessions",
Enabled: true,
}
@@ -1284,7 +1073,7 @@ func specSessionsList() toolSpec {
return toolSpec{
ID: "sessions_list",
Name: "sessions_list",
- Description: "List sessions.",
+ Description: "List current sessions.",
Category: "sessions",
Enabled: true,
}
@@ -1294,7 +1083,7 @@ func specSessionsHistory() toolSpec {
return toolSpec{
ID: "sessions_history",
Name: "sessions_history",
- Description: "Fetch session history.",
+ Description: "Read message history for a session.",
Category: "sessions",
SchemaJSON: schemaJSON(map[string]any{
"type": "object",
@@ -1312,7 +1101,7 @@ func specSessionsSend() toolSpec {
return toolSpec{
ID: "sessions_send",
Name: "sessions_send",
- Description: "Send a message to a session.",
+ Description: "Append a message to an existing session.",
Category: "sessions",
RiskLevel: "medium",
SchemaJSON: schemaJSON(map[string]any{
@@ -1357,7 +1146,7 @@ func specSessionStatus() toolSpec {
return toolSpec{
ID: "session_status",
Name: "session_status",
- Description: "Get session status.",
+ Description: "Get session metadata and status.",
Category: "sessions",
SchemaJSON: schemaJSON(map[string]any{
"type": "object",
@@ -1422,7 +1211,7 @@ func specSkills() toolSpec {
return toolSpec{
ID: "skills",
Name: "skills",
- Description: "Skills runtime tool: status/bins + dependency install + per-skill runtime config update. Not for package search/install.",
+ Description: "Inspect skills runtime status and update per-skill runtime dependencies or configuration.",
Category: "skills",
RiskLevel: "medium",
SchemaJSON: schemaJSON(map[string]any{
@@ -1506,12 +1295,12 @@ func specSkills() toolSpec {
}
}
-func specSkillManage() toolSpec {
+func specSkillsManage() toolSpec {
actions := []string{"list", "search", "install", "update", "remove", "sync"}
return toolSpec{
- ID: "skill_manage",
- Name: "skill_manage",
- Description: "Skills package management tool via ClawHub/SkillHub (list/search/install/update/remove/sync). Use for discovery and package lifecycle.",
+ ID: "skills_manage",
+ Name: "skills_manage",
+ Description: "Search, install, update, remove, and sync skill packages via ClawHub.",
Category: "skills",
RiskLevel: "medium",
SchemaJSON: schemaJSON(map[string]any{
@@ -1574,7 +1363,7 @@ func specSubagents() toolSpec {
return toolSpec{
ID: "subagents",
Name: "subagents",
- Description: "Manage subagent runs (list/info/log/kill/steer/send).",
+ Description: "Manage existing subagent runs (list/info/log/kill/steer/send).",
Category: "sessions",
SchemaJSON: schemaJSON(map[string]any{
"type": "object",
@@ -1593,7 +1382,7 @@ func specNodes() toolSpec {
return toolSpec{
ID: "nodes",
Name: "nodes",
- Description: "Invoke node capability.",
+ Description: "Experimental low-level RPC to a registered node. Temporarily unavailable until remote node runtime support is implemented; prefer specialized tools like canvas or browser when available.",
Category: "nodes",
RiskLevel: "medium",
SchemaJSON: schemaJSON(map[string]any{
@@ -1601,10 +1390,14 @@ func specNodes() toolSpec {
"properties": map[string]any{
"nodeId": map[string]any{"type": "string"},
"capability": map[string]any{"type": "string"},
+ "action": map[string]any{"type": "string"},
+ "args": map[string]any{"type": "string"},
"payload": map[string]any{"type": "object"},
+ "timeoutMs": map[string]any{"type": "integer"},
},
+ "required": []string{"nodeId", "capability"},
}),
- Enabled: true,
+ Enabled: false,
}
}
@@ -1612,7 +1405,7 @@ func specTTS() toolSpec {
return toolSpec{
ID: "tts",
Name: "tts",
- Description: "Text-to-speech conversion.",
+ Description: "Synthesize speech audio from text with the configured voice provider.",
Category: "voice",
SchemaJSON: schemaJSON(map[string]any{
"type": "object",
@@ -1628,15 +1421,20 @@ func specTTS() toolSpec {
}
}
-func specMemoryRecall() toolSpec {
+func specMemoryQuery() toolSpec {
return toolSpec{
- ID: "memory_recall",
- Name: "memory_recall",
- Description: "Search long-term memory with hybrid retrieval.",
+ ID: "memory_query",
+ Name: "memory_query",
+ Description: "Query long-term memory with recall, list, and stats actions.",
Category: "memory",
SchemaJSON: schemaJSON(map[string]any{
- "type": "object",
+ "type": "object",
+ "additionalProperties": false,
"properties": map[string]any{
+ "action": map[string]any{
+ "type": "string",
+ "enum": []string{"recall", "list", "stats"},
+ },
"query": map[string]any{"type": "string"},
"limit": map[string]any{"type": "integer"},
"topK": map[string]any{"type": "integer"},
@@ -1651,80 +1449,38 @@ func specMemoryRecall() toolSpec {
"peerKind": map[string]any{"type": "string"},
"peerId": map[string]any{"type": "string"},
},
- "required": []string{"query"},
- }),
- Enabled: true,
- }
-}
-
-func specMemoryStore() toolSpec {
- return toolSpec{
- ID: "memory_store",
- Name: "memory_store",
- Description: "Store a long-term memory entry.",
- Category: "memory",
- SchemaJSON: schemaJSON(map[string]any{
- "type": "object",
- "properties": map[string]any{
- "text": map[string]any{"type": "string"},
- "content": map[string]any{"type": "string"},
- "category": map[string]any{"type": "string"},
- "confidence": map[string]any{"type": "number"},
- "assistantId": map[string]any{"type": "string"},
- "threadId": map[string]any{"type": "string"},
- "scope": map[string]any{"type": "string"},
- "metadata": map[string]any{"type": "object"},
- "channel": map[string]any{"type": "string"},
- "accountId": map[string]any{"type": "string"},
- "userId": map[string]any{"type": "string"},
- "groupId": map[string]any{"type": "string"},
- "peerKind": map[string]any{"type": "string"},
- "peerId": map[string]any{"type": "string"},
- },
- "required": []string{"text"},
- }),
- Enabled: true,
- }
-}
-
-func specMemoryForget() toolSpec {
- return toolSpec{
- ID: "memory_forget",
- Name: "memory_forget",
- Description: "Forget memory entries by id or query.",
- Category: "memory",
- SchemaJSON: schemaJSON(map[string]any{
- "type": "object",
- "properties": map[string]any{
- "memoryId": map[string]any{"type": "string"},
- "id": map[string]any{"type": "string"},
- "query": map[string]any{"type": "string"},
- "assistantId": map[string]any{"type": "string"},
- "threadId": map[string]any{"type": "string"},
- "category": map[string]any{"type": "string"},
- "scope": map[string]any{"type": "string"},
- "limit": map[string]any{"type": "integer"},
- "channel": map[string]any{"type": "string"},
- "accountId": map[string]any{"type": "string"},
- "userId": map[string]any{"type": "string"},
- "groupId": map[string]any{"type": "string"},
- "peerKind": map[string]any{"type": "string"},
- "peerId": map[string]any{"type": "string"},
+ "required": []string{"action"},
+ "allOf": []any{
+ map[string]any{
+ "if": map[string]any{
+ "properties": map[string]any{
+ "action": map[string]any{"const": "recall"},
+ },
+ },
+ "then": map[string]any{
+ "required": []string{"query"},
+ },
+ },
},
}),
Enabled: true,
}
}
-func specMemoryUpdate() toolSpec {
+func specMemoryManage() toolSpec {
return toolSpec{
- ID: "memory_update",
- Name: "memory_update",
- Description: "Update an existing memory entry.",
+ ID: "memory_manage",
+ Name: "memory_manage",
+ Description: "Create, update, or delete long-term memory entries.",
Category: "memory",
SchemaJSON: schemaJSON(map[string]any{
- "type": "object",
+ "type": "object",
+ "additionalProperties": false,
"properties": map[string]any{
+ "action": map[string]any{
+ "type": "string",
+ "enum": []string{"store", "update", "forget"},
+ },
"memoryId": map[string]any{"type": "string"},
"id": map[string]any{"type": "string"},
"text": map[string]any{"type": "string"},
@@ -1735,57 +1491,8 @@ func specMemoryUpdate() toolSpec {
"threadId": map[string]any{"type": "string"},
"scope": map[string]any{"type": "string"},
"metadata": map[string]any{"type": "object"},
- "channel": map[string]any{"type": "string"},
- "accountId": map[string]any{"type": "string"},
- "userId": map[string]any{"type": "string"},
- "groupId": map[string]any{"type": "string"},
- "peerKind": map[string]any{"type": "string"},
- "peerId": map[string]any{"type": "string"},
- },
- }),
- Enabled: true,
- }
-}
-
-func specMemoryStats() toolSpec {
- return toolSpec{
- ID: "memory_stats",
- Name: "memory_stats",
- Description: "Show memory statistics.",
- Category: "memory",
- SchemaJSON: schemaJSON(map[string]any{
- "type": "object",
- "properties": map[string]any{
- "assistantId": map[string]any{"type": "string"},
- "threadId": map[string]any{"type": "string"},
- "scope": map[string]any{"type": "string"},
- "channel": map[string]any{"type": "string"},
- "accountId": map[string]any{"type": "string"},
- "userId": map[string]any{"type": "string"},
- "groupId": map[string]any{"type": "string"},
- "peerKind": map[string]any{"type": "string"},
- "peerId": map[string]any{"type": "string"},
- },
- }),
- Enabled: true,
- }
-}
-
-func specMemoryList() toolSpec {
- return toolSpec{
- ID: "memory_list",
- Name: "memory_list",
- Description: "List recent memory entries.",
- Category: "memory",
- SchemaJSON: schemaJSON(map[string]any{
- "type": "object",
- "properties": map[string]any{
- "assistantId": map[string]any{"type": "string"},
- "threadId": map[string]any{"type": "string"},
- "category": map[string]any{"type": "string"},
- "scope": map[string]any{"type": "string"},
+ "query": map[string]any{"type": "string"},
"limit": map[string]any{"type": "integer"},
- "offset": map[string]any{"type": "integer"},
"channel": map[string]any{"type": "string"},
"accountId": map[string]any{"type": "string"},
"userId": map[string]any{"type": "string"},
@@ -1793,6 +1500,49 @@ func specMemoryList() toolSpec {
"peerKind": map[string]any{"type": "string"},
"peerId": map[string]any{"type": "string"},
},
+ "required": []string{"action"},
+ "allOf": []any{
+ map[string]any{
+ "if": map[string]any{
+ "properties": map[string]any{
+ "action": map[string]any{"const": "store"},
+ },
+ },
+ "then": map[string]any{
+ "anyOf": []any{
+ map[string]any{"required": []string{"text"}},
+ map[string]any{"required": []string{"content"}},
+ },
+ },
+ },
+ map[string]any{
+ "if": map[string]any{
+ "properties": map[string]any{
+ "action": map[string]any{"const": "update"},
+ },
+ },
+ "then": map[string]any{
+ "anyOf": []any{
+ map[string]any{"required": []string{"memoryId"}},
+ map[string]any{"required": []string{"id"}},
+ },
+ },
+ },
+ map[string]any{
+ "if": map[string]any{
+ "properties": map[string]any{
+ "action": map[string]any{"const": "forget"},
+ },
+ },
+ "then": map[string]any{
+ "anyOf": []any{
+ map[string]any{"required": []string{"memoryId"}},
+ map[string]any{"required": []string{"id"}},
+ map[string]any{"required": []string{"query"}},
+ },
+ },
+ },
+ },
}),
Enabled: true,
}
diff --git a/internal/application/gateway/tools/builtin_specs_ui.go b/internal/application/gateway/tools/builtin_specs_ui.go
new file mode 100644
index 0000000..7367fe1
--- /dev/null
+++ b/internal/application/gateway/tools/builtin_specs_ui.go
@@ -0,0 +1,347 @@
+package tools
+
+func specBrowser() toolSpec {
+ waitConditionProperties := map[string]any{
+ "timeMs": map[string]any{"type": "number"},
+ "text": map[string]any{"type": "string"},
+ "textGone": map[string]any{"type": "string"},
+ "selector": map[string]any{"type": "string"},
+ "url": map[string]any{"type": "string"},
+ "fn": map[string]any{"type": "string"},
+ "timeoutMs": map[string]any{"type": "number"},
+ }
+ waitConditionSchema := map[string]any{
+ "type": "object",
+ "properties": waitConditionProperties,
+ "additionalProperties": false,
+ }
+ return toolSpec{
+ ID: "browser",
+ Name: "browser",
+ Description: "Control a local CDP browser (`open`/`navigate`/`snapshot`/`act`/`wait`/`scroll`/`upload`/`dialog`/`reset`) using a browser-use style loop. For `open` or `navigate`, pass `url` or `targetUrl`; these actions return `stateAvailable`, `itemCount`, and the current page `state`/`items` whenever capture succeeds, so inspect that result before deciding the next action. If `stateAvailable=false` or refs look stale after the page changes, call `snapshot` to refresh them, then continue with `act` using `ref` on the same `targetId`. Do not use raw CSS `selector` for normal interactions; use `ref` from the latest state. Matching connector cookies are injected automatically before navigation.",
+ Category: "ui",
+ RiskLevel: "high",
+ SchemaJSON: schemaJSON(map[string]any{
+ "type": "object",
+ "additionalProperties": false,
+ "properties": map[string]any{
+ "action": map[string]any{
+ "type": "string",
+ "enum": []string{
+ "open",
+ "navigate",
+ "snapshot",
+ "act",
+ "wait",
+ "scroll",
+ "upload",
+ "dialog",
+ "reset",
+ },
+ },
+ "target": map[string]any{
+ "type": "string",
+ "enum": []string{"sandbox", "host", "node"},
+ },
+ "node": map[string]any{"type": "string"},
+ "profile": map[string]any{"type": "string"},
+ "targetUrl": map[string]any{"type": "string"},
+ "targetId": map[string]any{"type": "string"},
+ "newTab": map[string]any{"type": "boolean"},
+ "restart": map[string]any{"type": "boolean"},
+ "limit": map[string]any{"type": "integer"},
+ "timeoutMs": map[string]any{"type": "integer"},
+ "selector": map[string]any{"type": "string"},
+ "fullPage": map[string]any{"type": "boolean"},
+ "ref": map[string]any{"type": "string"},
+ "x": map[string]any{"type": "integer"},
+ "y": map[string]any{"type": "integer"},
+ "amount": map[string]any{"type": "integer"},
+ "direction": map[string]any{
+ "type": "string",
+ "enum": []string{"up", "down", "left", "right"},
+ },
+ "text": map[string]any{"type": "string"},
+ "textGone": map[string]any{"type": "string"},
+ "fn": map[string]any{"type": "string"},
+ "timeMs": map[string]any{"type": "number"},
+ "url": map[string]any{"type": "string"},
+ "paths": map[string]any{"type": "array", "items": map[string]any{"type": "string"}},
+ "accept": map[string]any{"type": "boolean"},
+ "promptText": map[string]any{"type": "string"},
+ "waitFor": waitConditionSchema,
+ "request": map[string]any{
+ "type": "object",
+ "additionalProperties": false,
+ "properties": map[string]any{
+ "kind": map[string]any{
+ "type": "string",
+ "enum": []string{
+ "click",
+ "type",
+ "press",
+ "hover",
+ "select",
+ "fill",
+ "resize",
+ "wait",
+ "evaluate",
+ "close",
+ },
+ },
+ "targetId": map[string]any{"type": "string"},
+ "ref": map[string]any{"type": "string"},
+ "text": map[string]any{"type": "string"},
+ "key": map[string]any{"type": "string"},
+ "value": map[string]any{"type": "string"},
+ "width": map[string]any{"type": "number"},
+ "height": map[string]any{"type": "number"},
+ "timeMs": map[string]any{"type": "number"},
+ "textGone": map[string]any{"type": "string"},
+ "selector": map[string]any{"type": "string"},
+ "url": map[string]any{"type": "string"},
+ "fn": map[string]any{"type": "string"},
+ "expression": map[string]any{"type": "string"},
+ "timeoutMs": map[string]any{"type": "number"},
+ "waitFor": waitConditionSchema,
+ },
+ "required": []string{"kind"},
+ },
+ },
+ "required": []string{"action"},
+ "allOf": []any{
+ map[string]any{
+ "anyOf": []any{
+ map[string]any{
+ "properties": map[string]any{
+ "action": map[string]any{"const": "open"},
+ },
+ "required": []string{"action"},
+ "anyOf": []any{
+ map[string]any{"required": []string{"targetUrl"}},
+ map[string]any{"required": []string{"url"}},
+ },
+ },
+ map[string]any{
+ "properties": map[string]any{
+ "action": map[string]any{
+ "enum": []string{
+ "snapshot",
+ "navigate",
+ "wait",
+ "scroll",
+ "upload",
+ "dialog",
+ "act",
+ "reset",
+ },
+ },
+ },
+ "required": []string{"action"},
+ },
+ },
+ },
+ map[string]any{
+ "anyOf": []any{
+ map[string]any{
+ "properties": map[string]any{
+ "action": map[string]any{"const": "navigate"},
+ },
+ "required": []string{"action"},
+ "anyOf": []any{
+ map[string]any{"required": []string{"targetUrl"}},
+ map[string]any{"required": []string{"url"}},
+ },
+ },
+ map[string]any{
+ "properties": map[string]any{
+ "action": map[string]any{
+ "enum": []string{
+ "open",
+ "snapshot",
+ "wait",
+ "scroll",
+ "upload",
+ "dialog",
+ "act",
+ "reset",
+ },
+ },
+ },
+ "required": []string{"action"},
+ },
+ },
+ },
+ map[string]any{
+ "anyOf": []any{
+ map[string]any{
+ "properties": map[string]any{
+ "action": map[string]any{"const": "act"},
+ },
+ "required": []string{"action", "request"},
+ },
+ map[string]any{
+ "properties": map[string]any{
+ "action": map[string]any{
+ "enum": []string{
+ "open",
+ "snapshot",
+ "navigate",
+ "wait",
+ "scroll",
+ "upload",
+ "dialog",
+ "reset",
+ },
+ },
+ },
+ "required": []string{"action"},
+ },
+ },
+ },
+ map[string]any{
+ "anyOf": []any{
+ map[string]any{
+ "properties": map[string]any{
+ "action": map[string]any{"const": "wait"},
+ },
+ "allOf": []any{
+ map[string]any{
+ "anyOf": []any{
+ map[string]any{"required": []string{"timeMs"}},
+ map[string]any{"required": []string{"text"}},
+ map[string]any{"required": []string{"textGone"}},
+ map[string]any{"required": []string{"selector"}},
+ map[string]any{"required": []string{"url"}},
+ map[string]any{"required": []string{"fn"}},
+ },
+ },
+ },
+ },
+ map[string]any{
+ "properties": map[string]any{
+ "action": map[string]any{
+ "enum": []string{
+ "open",
+ "snapshot",
+ "navigate",
+ "scroll",
+ "upload",
+ "dialog",
+ "act",
+ "reset",
+ },
+ },
+ },
+ "required": []string{"action"},
+ },
+ },
+ },
+ map[string]any{
+ "anyOf": []any{
+ map[string]any{
+ "properties": map[string]any{
+ "action": map[string]any{"const": "upload"},
+ },
+ "required": []string{"action", "ref", "paths"},
+ },
+ map[string]any{
+ "properties": map[string]any{
+ "action": map[string]any{
+ "enum": []string{
+ "open",
+ "snapshot",
+ "navigate",
+ "wait",
+ "scroll",
+ "dialog",
+ "act",
+ "reset",
+ },
+ },
+ },
+ "required": []string{"action"},
+ },
+ },
+ },
+ },
+ }),
+ RequiresSandbox: true,
+ RequiresApproval: true,
+ Enabled: true,
+ }
+}
+
+func specCanvas() toolSpec {
+ return toolSpec{
+ ID: "canvas",
+ Name: "canvas",
+ Description: "Control node canvases (present/hide/navigate/eval/snapshot/a2ui). Temporarily unavailable until remote node runtime support is implemented.",
+ Category: "ui",
+ SchemaJSON: schemaJSON(map[string]any{
+ "type": "object",
+ "properties": map[string]any{
+ "action": map[string]any{
+ "type": "string",
+ "enum": []string{
+ "present",
+ "hide",
+ "navigate",
+ "eval",
+ "snapshot",
+ "a2ui_push",
+ "a2ui_reset",
+ },
+ },
+ "gatewayUrl": map[string]any{"type": "string"},
+ "gatewayToken": map[string]any{"type": "string"},
+ "timeoutMs": map[string]any{"type": "number"},
+ "node": map[string]any{"type": "string"},
+ "target": map[string]any{"type": "string"},
+ "x": map[string]any{"type": "number"},
+ "y": map[string]any{"type": "number"},
+ "width": map[string]any{"type": "number"},
+ "height": map[string]any{"type": "number"},
+ "url": map[string]any{"type": "string"},
+ "javaScript": map[string]any{"type": "string"},
+ "outputFormat": map[string]any{
+ "type": "string",
+ "enum": []string{"png", "jpg", "jpeg"},
+ },
+ "maxWidth": map[string]any{"type": "number"},
+ "quality": map[string]any{"type": "number"},
+ "delayMs": map[string]any{"type": "number"},
+ "jsonl": map[string]any{"type": "string"},
+ "jsonlPath": map[string]any{
+ "type": "string",
+ },
+ },
+ "required": []string{"action"},
+ }),
+ Enabled: false,
+ }
+}
+
+func specImage() toolSpec {
+ return toolSpec{
+ ID: "image",
+ Name: "image",
+ Description: "Analyze one or more images with the configured image model.",
+ Category: "media",
+ SchemaJSON: schemaJSON(map[string]any{
+ "type": "object",
+ "properties": map[string]any{
+ "prompt": map[string]any{"type": "string"},
+ "image": map[string]any{"type": "string"},
+ "images": map[string]any{
+ "type": "array",
+ "items": map[string]any{"type": "string"},
+ },
+ "model": map[string]any{"type": "string"},
+ "maxBytesMb": map[string]any{"type": "number"},
+ "maxImages": map[string]any{"type": "number"},
+ },
+ }),
+ Enabled: true,
+ }
+}
diff --git a/internal/application/gateway/tools/builtin_tools.go b/internal/application/gateway/tools/builtin_tools.go
index 266e357..ffb2f79 100644
--- a/internal/application/gateway/tools/builtin_tools.go
+++ b/internal/application/gateway/tools/builtin_tools.go
@@ -65,19 +65,15 @@ func RegisterBuiltinTools(ctx context.Context, toolSvc *toolservice.ToolService,
registerTool(ctx, toolSvc, executor, specSubagents(), runSubagentsTool(nil))
registerTool(ctx, toolSvc, executor, specNodes(), runNodesTool(deps.Nodes))
registerTool(ctx, toolSvc, executor, specTTS(), runTTSTool(deps.Voice))
- registerTool(ctx, toolSvc, executor, specMemoryRecall(), runMemoryRecallTool(deps.Memory))
- registerTool(ctx, toolSvc, executor, specMemoryStore(), runMemoryStoreTool(deps.Memory))
- registerTool(ctx, toolSvc, executor, specMemoryForget(), runMemoryForgetTool(deps.Memory))
- registerTool(ctx, toolSvc, executor, specMemoryUpdate(), runMemoryUpdateTool(deps.Memory))
- registerTool(ctx, toolSvc, executor, specMemoryStats(), runMemoryStatsTool(deps.Memory))
- registerTool(ctx, toolSvc, executor, specMemoryList(), runMemoryListTool(deps.Memory))
+ registerTool(ctx, toolSvc, executor, specMemoryQuery(), runMemoryQueryTool(deps.Memory))
+ registerTool(ctx, toolSvc, executor, specMemoryManage(), runMemoryManageTool(deps.Memory))
if deps.ExternalTools != nil {
registerTool(ctx, toolSvc, executor, specExternalToolsQuery(), runExternalToolsQueryTool(deps.ExternalTools))
registerTool(ctx, toolSvc, executor, specExternalToolsManage(), runExternalToolsManageTool(deps.ExternalTools))
}
if deps.Skills != nil {
registerTool(ctx, toolSvc, executor, specSkills(), runSkillsTool(deps.Skills, deps.Assistant, deps.Settings, deps.ExternalTools))
- registerTool(ctx, toolSvc, executor, specSkillManage(), runSkillManageTool(deps.Skills, deps.Assistant, deps.Settings, deps.ExternalTools))
+ registerTool(ctx, toolSvc, executor, specSkillsManage(), runSkillsManageTool(deps.Skills, deps.Assistant, deps.Settings, deps.ExternalTools))
}
if deps.Library != nil {
registerTool(ctx, toolSvc, executor, specLibrary(), runLibraryGroupTool(deps.Library, "library"))
diff --git a/internal/application/gateway/tools/memory_tools.go b/internal/application/gateway/tools/memory_tools.go
index cf5140b..ec51b37 100644
--- a/internal/application/gateway/tools/memory_tools.go
+++ b/internal/application/gateway/tools/memory_tools.go
@@ -11,6 +11,56 @@ import (
domainsession "dreamcreator/internal/domain/session"
)
+func runMemoryQueryTool(memory *memoryservice.MemoryService) func(ctx context.Context, args string) (string, error) {
+ recallHandler := runMemoryRecallTool(memory)
+ listHandler := runMemoryListTool(memory)
+ statsHandler := runMemoryStatsTool(memory)
+ return func(ctx context.Context, args string) (string, error) {
+ payload, err := parseToolArgs(args)
+ if err != nil {
+ return "", err
+ }
+ action := strings.ToLower(strings.TrimSpace(getStringArg(payload, "action", "type")))
+ switch action {
+ case "recall", "search":
+ return recallHandler(ctx, args)
+ case "list":
+ return listHandler(ctx, args)
+ case "stats", "status":
+ return statsHandler(ctx, args)
+ case "":
+ return "", errors.New("action is required")
+ default:
+ return "", errors.New("unsupported memory_query action")
+ }
+ }
+}
+
+func runMemoryManageTool(memory *memoryservice.MemoryService) func(ctx context.Context, args string) (string, error) {
+ storeHandler := runMemoryStoreTool(memory)
+ updateHandler := runMemoryUpdateTool(memory)
+ forgetHandler := runMemoryForgetTool(memory)
+ return func(ctx context.Context, args string) (string, error) {
+ payload, err := parseToolArgs(args)
+ if err != nil {
+ return "", err
+ }
+ action := strings.ToLower(strings.TrimSpace(getStringArg(payload, "action", "type")))
+ switch action {
+ case "store", "create":
+ return storeHandler(ctx, args)
+ case "update":
+ return updateHandler(ctx, args)
+ case "forget", "delete", "remove":
+ return forgetHandler(ctx, args)
+ case "":
+ return "", errors.New("action is required")
+ default:
+ return "", errors.New("unsupported memory_manage action")
+ }
+ }
+}
+
func runMemoryRecallTool(memory *memoryservice.MemoryService) func(ctx context.Context, args string) (string, error) {
return func(ctx context.Context, args string) (string, error) {
if memory == nil {
diff --git a/internal/application/gateway/tools/memory_tools_group_test.go b/internal/application/gateway/tools/memory_tools_group_test.go
new file mode 100644
index 0000000..af3b3d1
--- /dev/null
+++ b/internal/application/gateway/tools/memory_tools_group_test.go
@@ -0,0 +1,56 @@
+package tools
+
+import (
+ "context"
+ "testing"
+)
+
+func TestMemoryQueryToolRequiresAction(t *testing.T) {
+ t.Parallel()
+
+ handler := runMemoryQueryTool(nil)
+ _, err := handler(context.Background(), `{}`)
+ if err == nil || err.Error() != "action is required" {
+ t.Fatalf("expected action is required, got %v", err)
+ }
+}
+
+func TestMemoryManageToolRequiresAction(t *testing.T) {
+ t.Parallel()
+
+ handler := runMemoryManageTool(nil)
+ _, err := handler(context.Background(), `{}`)
+ if err == nil || err.Error() != "action is required" {
+ t.Fatalf("expected action is required, got %v", err)
+ }
+}
+
+func TestMemoryQueryToolRoutesRecallAction(t *testing.T) {
+ t.Parallel()
+
+ handler := runMemoryQueryTool(nil)
+ _, err := handler(context.Background(), `{"action":"recall","query":"demo"}`)
+ if err == nil || err.Error() != "memory service unavailable" {
+ t.Fatalf("expected memory service unavailable, got %v", err)
+ }
+}
+
+func TestMemoryManageToolRoutesCreateAlias(t *testing.T) {
+ t.Parallel()
+
+ handler := runMemoryManageTool(nil)
+ _, err := handler(context.Background(), `{"action":"create","text":"demo"}`)
+ if err == nil || err.Error() != "memory service unavailable" {
+ t.Fatalf("expected memory service unavailable, got %v", err)
+ }
+}
+
+func TestMemoryManageToolRejectsUnknownAction(t *testing.T) {
+ t.Parallel()
+
+ handler := runMemoryManageTool(nil)
+ _, err := handler(context.Background(), `{"action":"archive"}`)
+ if err == nil || err.Error() != "unsupported memory_manage action" {
+ t.Fatalf("expected unsupported memory_manage action, got %v", err)
+ }
+}
diff --git a/internal/application/gateway/tools/nodes_tools.go b/internal/application/gateway/tools/nodes_tools.go
index ccb5b95..c48b99e 100644
--- a/internal/application/gateway/tools/nodes_tools.go
+++ b/internal/application/gateway/tools/nodes_tools.go
@@ -19,10 +19,14 @@ func runNodesTool(nodes *gatewaynodes.Service) func(ctx context.Context, args st
return "", err
}
nodeID := getStringArg(payload, "nodeId", "nodeID")
- capability := getStringArg(payload, "capability", "action")
+ capability := getStringArg(payload, "capability")
if nodeID == "" {
return "", errors.New("nodeId is required")
}
+ if capability == "" {
+ return "", errors.New("capability is required")
+ }
+ timeoutMs, _ := getIntArg(payload, "timeoutMs", "timeout_ms")
argsJSON := strings.TrimSpace(getStringArg(payload, "args"))
if argsJSON == "" {
if payloadMap := getMapArg(payload, "payload"); payloadMap != nil {
@@ -36,6 +40,7 @@ func runNodesTool(nodes *gatewaynodes.Service) func(ctx context.Context, args st
Capability: capability,
Action: getStringArg(payload, "action"),
Args: argsJSON,
+ TimeoutMs: timeoutMs,
}
result, err := nodes.Invoke(ctx, request)
if err != nil {
diff --git a/internal/application/gateway/tools/nodes_tools_test.go b/internal/application/gateway/tools/nodes_tools_test.go
new file mode 100644
index 0000000..306da0b
--- /dev/null
+++ b/internal/application/gateway/tools/nodes_tools_test.go
@@ -0,0 +1,73 @@
+package tools
+
+import (
+ "context"
+ "strings"
+ "testing"
+
+ gatewaynodes "dreamcreator/internal/application/gateway/nodes"
+)
+
+func TestRunNodesToolRequiresExplicitCapability(t *testing.T) {
+ t.Parallel()
+
+ handler := runNodesTool(newNodesServiceForTest(t, gatewaynodes.InvokerFunc(func(_ context.Context, request gatewaynodes.NodeInvokeRequest) (gatewaynodes.NodeInvokeResult, error) {
+ t.Fatalf("unexpected invoke: %#v", request)
+ return gatewaynodes.NodeInvokeResult{}, nil
+ })))
+
+ _, err := handler(context.Background(), `{"nodeId":"node-1","action":"screen.capture"}`)
+ if err == nil {
+ t.Fatalf("expected capability required error")
+ }
+ if !strings.Contains(err.Error(), "capability is required") {
+ t.Fatalf("unexpected error: %v", err)
+ }
+}
+
+func TestRunNodesToolPassesActionPayloadAndTimeout(t *testing.T) {
+ t.Parallel()
+
+ var captured gatewaynodes.NodeInvokeRequest
+ handler := runNodesTool(newNodesServiceForTest(t, gatewaynodes.InvokerFunc(func(_ context.Context, request gatewaynodes.NodeInvokeRequest) (gatewaynodes.NodeInvokeResult, error) {
+ captured = request
+ return gatewaynodes.NodeInvokeResult{
+ InvokeID: request.InvokeID,
+ Ok: true,
+ Output: `{"ok":true}`,
+ }, nil
+ })))
+
+ _, err := handler(context.Background(), `{"nodeId":"node-1","capability":"screen","action":"capture","payload":{"format":"png"},"timeoutMs":4321}`)
+ if err != nil {
+ t.Fatalf("run nodes tool: %v", err)
+ }
+ if captured.NodeID != "node-1" {
+ t.Fatalf("expected node-1, got %q", captured.NodeID)
+ }
+ if captured.Capability != "screen" {
+ t.Fatalf("expected capability screen, got %q", captured.Capability)
+ }
+ if captured.Action != "capture" {
+ t.Fatalf("expected action capture, got %q", captured.Action)
+ }
+ if captured.Args != `{"format":"png"}` {
+ t.Fatalf("unexpected args payload: %q", captured.Args)
+ }
+ if captured.TimeoutMs != 4321 {
+ t.Fatalf("expected timeout 4321, got %d", captured.TimeoutMs)
+ }
+}
+
+func newNodesServiceForTest(t *testing.T, invoker gatewaynodes.Invoker) *gatewaynodes.Service {
+ t.Helper()
+ registry := gatewaynodes.NewRegistry(nil, nil)
+ _, err := registry.Register(context.Background(), "", gatewaynodes.NodeDescriptor{
+ NodeID: "node-1",
+ Status: "online",
+ })
+ if err != nil {
+ t.Fatalf("register node: %v", err)
+ }
+ return gatewaynodes.NewService(registry, nil, invoker, nil, nil)
+}
diff --git a/internal/application/gateway/tools/policy.go b/internal/application/gateway/tools/policy.go
index 4d5868a..ea293e6 100644
--- a/internal/application/gateway/tools/policy.go
+++ b/internal/application/gateway/tools/policy.go
@@ -8,13 +8,21 @@ import (
)
type PolicyPipeline struct {
- settings SettingsReader
+ settings SettingsReader
+ requirementsResolver ToolRequirementResolver
}
func NewPolicyPipeline(settings SettingsReader) *PolicyPipeline {
return &PolicyPipeline{settings: settings}
}
+func (pipeline *PolicyPipeline) SetRequirementsResolver(resolver ToolRequirementResolver) {
+ if pipeline == nil {
+ return
+ }
+ pipeline.requirementsResolver = resolver
+}
+
func (pipeline *PolicyPipeline) Decide(ctx context.Context, spec tooldto.ToolSpec, policyCtx tooldto.ToolPolicyContext) (tooldto.ToolPolicyDecision, error) {
if !spec.Enabled {
return tooldto.ToolPolicyDecision{
@@ -24,7 +32,7 @@ func (pipeline *PolicyPipeline) Decide(ctx context.Context, spec tooldto.ToolSpe
}, nil
}
snapshot := loadToolRequirementSnapshot(ctx, pipeline.settings)
- effectiveSpec := resolveEffectiveToolSpec(spec, snapshot)
+ effectiveSpec := resolveEffectiveToolSpecWithResolver(ctx, spec, snapshot, pipeline.requirementsResolver)
if !effectiveSpec.Enabled {
reason := "tool requirements unavailable"
if requirement, ok := firstUnavailableToolRequirement(effectiveSpec.Requirements); ok {
diff --git a/internal/application/gateway/tools/requirement_resolver.go b/internal/application/gateway/tools/requirement_resolver.go
new file mode 100644
index 0000000..25fed3c
--- /dev/null
+++ b/internal/application/gateway/tools/requirement_resolver.go
@@ -0,0 +1,41 @@
+package tools
+
+import (
+ "context"
+
+ tooldto "dreamcreator/internal/application/tools/dto"
+)
+
+type ToolRequirementResolver interface {
+ ResolveToolRequirements(ctx context.Context, spec tooldto.ToolSpec) []tooldto.ToolRequirement
+}
+
+type ToolRequirementResolverFunc func(ctx context.Context, spec tooldto.ToolSpec) []tooldto.ToolRequirement
+
+func (fn ToolRequirementResolverFunc) ResolveToolRequirements(ctx context.Context, spec tooldto.ToolSpec) []tooldto.ToolRequirement {
+ if fn == nil {
+ return nil
+ }
+ return fn(ctx, spec)
+}
+
+func resolveEffectiveToolSpecWithResolver(
+ ctx context.Context,
+ spec tooldto.ToolSpec,
+ snapshot toolRequirementSnapshot,
+ resolver ToolRequirementResolver,
+) tooldto.ToolSpec {
+ effective := resolveEffectiveToolSpec(spec, snapshot)
+ if resolver == nil {
+ return effective
+ }
+ additional := resolver.ResolveToolRequirements(ctx, effective)
+ if len(additional) == 0 {
+ return effective
+ }
+ effective.Requirements = append(effective.Requirements, additional...)
+ if effective.Enabled && !toolRequirementsSatisfied(effective.Requirements) {
+ effective.Enabled = false
+ }
+ return effective
+}
diff --git a/internal/application/gateway/tools/service.go b/internal/application/gateway/tools/service.go
index 6fe339f..34d0817 100644
--- a/internal/application/gateway/tools/service.go
+++ b/internal/application/gateway/tools/service.go
@@ -16,14 +16,15 @@ import (
)
type Service struct {
- tools *toolservice.ToolService
- approvals *gatewayapprovals.Service
- sandbox *gatewaysandbox.Service
- settings SettingsReader
- audit PolicyAuditStore
- events *gatewayevents.Broker
- now func() time.Time
- newID func() string
+ tools *toolservice.ToolService
+ approvals *gatewayapprovals.Service
+ sandbox *gatewaysandbox.Service
+ settings SettingsReader
+ audit PolicyAuditStore
+ events *gatewayevents.Broker
+ requirementsResolver ToolRequirementResolver
+ now func() time.Time
+ newID func() string
}
type PolicyAuditStore interface {
@@ -43,6 +44,13 @@ func NewService(tools *toolservice.ToolService, approvals *gatewayapprovals.Serv
}
}
+func (service *Service) SetRequirementsResolver(resolver ToolRequirementResolver) {
+ if service == nil {
+ return
+ }
+ service.requirementsResolver = resolver
+}
+
func (service *Service) ListTools(ctx context.Context) []tooldto.ToolSpec {
if service == nil || service.tools == nil {
return nil
@@ -54,7 +62,7 @@ func (service *Service) ListTools(ctx context.Context) []tooldto.ToolSpec {
snapshot := loadToolRequirementSnapshot(ctx, service.settings)
result := make([]tooldto.ToolSpec, 0, len(specs))
for _, spec := range specs {
- resolved := resolveEffectiveToolSpec(spec, snapshot)
+ resolved := resolveEffectiveToolSpecWithResolver(ctx, spec, snapshot, service.requirementsResolver)
resolved = resolveDynamicToolSpec(ctx, resolved, service.settings)
result = append(result, resolved)
}
@@ -98,6 +106,7 @@ func (service *Service) InvokeWithPolicy(ctx context.Context, request tooldto.To
if err != nil {
return tooldto.ToolsInvokeResponse{}, err
}
+ spec = resolveEffectiveToolSpecWithResolver(ctx, spec, loadToolRequirementSnapshot(ctx, service.settings), service.requirementsResolver)
spec = resolveDynamicToolSpec(ctx, spec, service.settings)
service.auditDecision(ctx, spec, decision, policyCtx)
response := tooldto.ToolsInvokeResponse{
@@ -171,6 +180,13 @@ func (service *Service) InvokeWithPolicy(ctx context.Context, request tooldto.To
return response, nil
}
+func (service *Service) CleanupRuntimeSession(_ context.Context, sessionKey string) {
+ if strings.TrimSpace(sessionKey) == "" {
+ return
+ }
+ cleanupBrowserToolSessions(sessionKey)
+}
+
func (service *Service) auditDecision(ctx context.Context, spec tooldto.ToolSpec, decision tooldto.ToolPolicyDecision, policyCtx tooldto.ToolPolicyContext) {
if service == nil || service.audit == nil {
return
diff --git a/internal/application/gateway/tools/skills_tools.go b/internal/application/gateway/tools/skills_tools.go
index 43f9630..dc59118 100644
--- a/internal/application/gateway/tools/skills_tools.go
+++ b/internal/application/gateway/tools/skills_tools.go
@@ -152,8 +152,12 @@ func runSkillsTool(skills *skillsservice.SkillsService, assistants *assistantser
return runSkillsToolWithName("skills", skills, assistants, settings, externalTools)
}
+func runSkillsManageTool(skills *skillsservice.SkillsService, assistants *assistantservice.AssistantService, settings SettingsReader, externalTools skillsExternalToolInstaller) func(ctx context.Context, args string) (string, error) {
+ return runSkillsToolWithName("skills_manage", skills, assistants, settings, externalTools)
+}
+
+// runSkillManageTool keeps the pre-rename function name available. It now
+// forwards to runSkillsManageTool, so the handler it returns runs under the
+// "skills_manage" tool name rather than "skill_manage".
func runSkillManageTool(skills *skillsservice.SkillsService, assistants *assistantservice.AssistantService, settings SettingsReader, externalTools skillsExternalToolInstaller) func(ctx context.Context, args string) (string, error) {
- return runSkillsToolWithName("skill_manage", skills, assistants, settings, externalTools)
+ return runSkillsManageTool(skills, assistants, settings, externalTools)
}
func runSkillsToolWithName(toolName string, skills *skillsservice.SkillsService, assistants *assistantservice.AssistantService, settings SettingsReader, externalTools skillsExternalToolInstaller) func(ctx context.Context, args string) (string, error) {
@@ -832,7 +836,7 @@ func canonicalSkillsAction(toolName string, action string) string {
default:
return ""
}
- case "skill_manage":
+ case "skill_manage", "skills_manage":
switch normalized {
case "list":
return "catalog"
@@ -873,20 +877,20 @@ func normalizeSkillsActionKey(toolName string, action string) string {
default:
return normalized
}
- case "skill_manage":
+ case "skill_manage", "skills_manage":
switch normalized {
case "catalog":
- return "skill_manage.list"
+ return "skills_manage.list"
case "search_packages":
- return "skill_manage.search"
+ return "skills_manage.search"
case "install_package":
- return "skill_manage.install"
+ return "skills_manage.install"
case "update_package":
- return "skill_manage.update"
+ return "skills_manage.update"
case "remove_package":
- return "skill_manage.remove"
+ return "skills_manage.remove"
case "sync_packages":
- return "skill_manage.sync"
+ return "skills_manage.sync"
default:
return normalized
}
@@ -900,7 +904,7 @@ func validateSkillsToolArgs(toolName string, payload toolArgs) error {
return nil
}
allowed := allowedSkillsToolArgs
- if strings.EqualFold(strings.TrimSpace(toolName), "skill_manage") {
+ if strings.EqualFold(strings.TrimSpace(toolName), "skill_manage") || strings.EqualFold(strings.TrimSpace(toolName), "skills_manage") {
allowed = allowedSkillManageToolArgs
}
unknown := make([]string, 0)
@@ -1148,9 +1152,9 @@ func marshalSkillsResult(result skillsToolResult) string {
func resolveSkillsActionGroup(action string) string {
switch strings.ToLower(strings.TrimSpace(action)) {
- case "skills.status", "skills.bins", "skill_manage.search", "skill_manage.list":
+ case "skills.status", "skills.bins", "skills_manage.search", "skills_manage.list":
return "read"
- case "skill_manage.install", "skill_manage.update", "skill_manage.remove", "skill_manage.sync":
+ case "skills_manage.install", "skills_manage.update", "skills_manage.remove", "skills_manage.sync":
return "package_write"
case "skills.install":
return "deps_write"
diff --git a/internal/application/gateway/tools/skills_tools_test.go b/internal/application/gateway/tools/skills_tools_test.go
index c3e7447..c7e2e82 100644
--- a/internal/application/gateway/tools/skills_tools_test.go
+++ b/internal/application/gateway/tools/skills_tools_test.go
@@ -201,7 +201,7 @@ func (stub *skillsSettingsStub) UpdateSettings(_ context.Context, request settin
return copied, nil
}
-func TestSkillManageToolSearchReturnsClawHubUnavailable(t *testing.T) {
+func TestSkillsManageToolSearchReturnsClawHubUnavailable(t *testing.T) {
t.Parallel()
repo := newSkillsRepoStub()
@@ -209,10 +209,10 @@ func TestSkillManageToolSearchReturnsClawHubUnavailable(t *testing.T) {
svc.SetExternalTools(&skillsExternalToolsStubForGateway{ready: false})
svc.SetWorkspaceResolver(skillsWorkspaceResolverStubForGateway{})
- handler := runSkillManageTool(svc, nil, nil, nil)
+ handler := runSkillsManageTool(svc, nil, nil, nil)
output, err := handler(context.Background(), `{"action":"search","query":"demo"}`)
if err != nil {
- t.Fatalf("skill_manage tool failed: %v", err)
+ t.Fatalf("skills_manage tool failed: %v", err)
}
var payload map[string]any
if err := json.Unmarshal([]byte(output), &payload); err != nil {
@@ -226,7 +226,7 @@ func TestSkillManageToolSearchReturnsClawHubUnavailable(t *testing.T) {
}
}
-func TestSkillManageToolInstallRequireForceAndRetrySuccess(t *testing.T) {
+func TestSkillsManageToolInstallRequireForceAndRetrySuccess(t *testing.T) {
t.Parallel()
repo := newSkillsRepoStub()
@@ -248,11 +248,11 @@ func TestSkillManageToolInstallRequireForceAndRetrySuccess(t *testing.T) {
}
},
})
- handler := runSkillManageTool(svc, nil, nil, nil)
+ handler := runSkillsManageTool(svc, nil, nil, nil)
output, err := handler(context.Background(), `{"action":"install","skill":"web-search-pro"}`)
if err != nil {
- t.Fatalf("skill_manage tool install failed: %v", err)
+ t.Fatalf("skills_manage tool install failed: %v", err)
}
var payload map[string]any
if err := json.Unmarshal([]byte(output), &payload); err != nil {
@@ -267,13 +267,13 @@ func TestSkillManageToolInstallRequireForceAndRetrySuccess(t *testing.T) {
if payload["requiresForce"] != true {
t.Fatalf("expected requiresForce=true, got %#v", payload["requiresForce"])
}
- if payload["action"] != "skill_manage.install" {
- t.Fatalf("expected action skill_manage.install, got %#v", payload["action"])
+ if payload["action"] != "skills_manage.install" {
+ t.Fatalf("expected action skills_manage.install, got %#v", payload["action"])
}
output, err = handler(context.Background(), `{"action":"install","skill":"web-search-pro","force":true}`)
if err != nil {
- t.Fatalf("skill_manage tool install(force) failed: %v", err)
+ t.Fatalf("skills_manage tool install(force) failed: %v", err)
}
payload = map[string]any{}
if err := json.Unmarshal([]byte(output), &payload); err != nil {
@@ -284,7 +284,7 @@ func TestSkillManageToolInstallRequireForceAndRetrySuccess(t *testing.T) {
}
}
-func TestSkillManageToolInstallRequireForceBlockedByScannerPolicy(t *testing.T) {
+func TestSkillsManageToolInstallRequireForceBlockedByScannerPolicy(t *testing.T) {
t.Parallel()
repo := newSkillsRepoStub()
@@ -312,11 +312,11 @@ func TestSkillManageToolInstallRequireForceBlockedByScannerPolicy(t *testing.T)
},
},
})
- handler := runSkillManageTool(svc, nil, settings, nil)
+ handler := runSkillsManageTool(svc, nil, settings, nil)
output, err := handler(context.Background(), `{"action":"install","skill":"web-search-pro"}`)
if err != nil {
- t.Fatalf("skill_manage tool install failed: %v", err)
+ t.Fatalf("skills_manage tool install failed: %v", err)
}
var payload map[string]any
if err := json.Unmarshal([]byte(output), &payload); err != nil {
@@ -330,7 +330,7 @@ func TestSkillManageToolInstallRequireForceBlockedByScannerPolicy(t *testing.T)
}
}
-func TestSkillManageToolForceInstallRequiresApprovalWhenEnabled(t *testing.T) {
+func TestSkillsManageToolForceInstallRequiresApprovalWhenEnabled(t *testing.T) {
t.Parallel()
repo := newSkillsRepoStub()
@@ -348,11 +348,11 @@ func TestSkillManageToolForceInstallRequiresApprovalWhenEnabled(t *testing.T) {
},
},
})
- handler := runSkillManageTool(svc, nil, settings, nil)
+ handler := runSkillsManageTool(svc, nil, settings, nil)
output, err := handler(context.Background(), `{"action":"install","skill":"web-search-pro","force":true}`)
if err != nil {
- t.Fatalf("skill_manage tool install failed: %v", err)
+ t.Fatalf("skills_manage tool install failed: %v", err)
}
var payload map[string]any
if err := json.Unmarshal([]byte(output), &payload); err != nil {
@@ -483,7 +483,7 @@ func TestSkillsToolActionModeDeny(t *testing.T) {
}
}
-func TestSkillManageToolActionModeAsk(t *testing.T) {
+func TestSkillsManageToolActionModeAsk(t *testing.T) {
t.Parallel()
repo := newSkillsRepoStub()
@@ -497,11 +497,11 @@ func TestSkillManageToolActionModeAsk(t *testing.T) {
},
},
})
- handler := runSkillManageTool(svc, nil, settings, nil)
+ handler := runSkillsManageTool(svc, nil, settings, nil)
output, err := handler(context.Background(), `{"action":"install","skill":"skill-a"}`)
if err != nil {
- t.Fatalf("skill_manage tool failed: %v", err)
+ t.Fatalf("skills_manage tool failed: %v", err)
}
var payload map[string]any
if err := json.Unmarshal([]byte(output), &payload); err != nil {
diff --git a/internal/application/gateway/tools/tool_helpers.go b/internal/application/gateway/tools/tool_helpers.go
index 9b7949c..0aeacc1 100644
--- a/internal/application/gateway/tools/tool_helpers.go
+++ b/internal/application/gateway/tools/tool_helpers.go
@@ -350,3 +350,21 @@ func getNestedInt(root map[string]any, path ...string) (int, bool) {
}
return 0, false
}
+
+// containsString reports whether needle matches any element of values.
+// Surrounding whitespace is stripped from both needle and each element before
+// they are compared; the comparison itself is exact (case-sensitive). A nil
+// or empty slice yields false.
+func containsString(values []string, needle string) bool {
+	target := strings.TrimSpace(needle)
+	for idx := range values {
+		if target == strings.TrimSpace(values[idx]) {
+			return true
+		}
+	}
+	return false
+}
+
+// trimToMaxChars strips surrounding whitespace from value and caps the result
+// at maxChars bytes. A non-positive maxChars disables the cap.
+//
+// The cut never lands inside a multi-byte UTF-8 sequence: if the byte at the
+// limit is a continuation byte, the cut moves back to the start of that rune.
+// The result can therefore be up to three bytes shorter than maxChars, but it
+// never ends in a partial rune. Input that is pure ASCII is truncated exactly
+// at maxChars.
+func trimToMaxChars(value string, maxChars int) string {
+	value = strings.TrimSpace(value)
+	if maxChars <= 0 || len(value) <= maxChars {
+		return value
+	}
+	cut := maxChars
+	// UTF-8 continuation bytes have the form 10xxxxxx, and a rune carries at
+	// most three of them. Bounding the back-off to three steps means malformed
+	// input cannot shrink the result by more than three bytes.
+	for steps := 0; steps < 3 && cut > 0 && value[cut]&0xC0 == 0x80; steps++ {
+		cut--
+	}
+	return value[:cut]
+}
diff --git a/internal/application/gateway/tools/tool_requirements.go b/internal/application/gateway/tools/tool_requirements.go
index 1ca3e8e..01c50a2 100644
--- a/internal/application/gateway/tools/tool_requirements.go
+++ b/internal/application/gateway/tools/tool_requirements.go
@@ -6,6 +6,7 @@ import (
"os"
"strings"
+ "dreamcreator/internal/application/browsercdp"
tooldto "dreamcreator/internal/application/tools/dto"
)
@@ -41,9 +42,16 @@ func loadToolRequirementSnapshot(ctx context.Context, settings SettingsReader) t
func resolveEffectiveToolSpec(spec tooldto.ToolSpec, snapshot toolRequirementSnapshot) tooldto.ToolSpec {
key := resolveToolRequirementKey(spec)
- if snapshot.loaded && key == "browser" {
- if enabled, ok := resolveBrowserConfigBool(snapshot.toolsConfig, "enabled"); ok && !enabled {
- spec.Enabled = false
+ if snapshot.loaded {
+ switch key {
+ case "browser":
+ if enabled, ok := resolveBrowserConfigBool(snapshot.toolsConfig, "enabled"); ok && !enabled {
+ spec.Enabled = false
+ }
+ case "web_fetch":
+ if enabled, ok := resolveWebFetchConfigBool(snapshot.toolsConfig, "enabled"); ok && !enabled {
+ spec.Enabled = false
+ }
}
}
requirements := resolveToolRequirements(spec, snapshot)
@@ -97,30 +105,56 @@ func resolveToolRequirements(spec tooldto.ToolSpec, snapshot toolRequirementSnap
return resolveWebFetchRequirements(snapshot.toolsConfig)
case "browser":
return resolveBrowserRequirements(snapshot.toolsConfig)
+ case "canvas":
+ return resolveCanvasRequirements()
+ case "nodes":
+ return resolveNodeRequirements()
default:
return nil
}
}
-func resolveBrowserRequirements(_ map[string]any) []tooldto.ToolRequirement {
- runtimeRequirement := tooldto.ToolRequirement{
- ID: "browser.playwright_runtime",
- Name: "Playwright runtime",
- Available: true,
+// resolveBrowserRequirements reports the single requirement of the browser
+// tool, "browser.cdp_runtime". The preferred browser and headless flag are
+// read from config via resolveBrowserRuntimeConfig and passed to
+// browsercdp.ResolveStatus; Available mirrors status.Ready, Reason carries the
+// trimmed detection error, and Data exposes the detection details (candidates,
+// selected/chosen browser, detected executable path, headless flag).
+func resolveBrowserRequirements(config map[string]any) []tooldto.ToolRequirement {
+ resolved := resolveBrowserRuntimeConfig(config)
+ status := browsercdp.ResolveStatus(resolved.PreferredBrowser, resolved.Headless)
+ requirements := []tooldto.ToolRequirement{
+ {
+ ID: "browser.cdp_runtime",
+ Name: "Local CDP browser",
+ Available: status.Ready,
+ Reason: strings.TrimSpace(status.DetectError),
+ Data: map[string]any{
+ "candidates": status.Candidates,
+ "selectedBrowser": status.SelectedBrowser,
+ "chosenBrowser": status.ChosenBrowser,
+ "detectedExecutablePath": status.DetectedExecutablePath,
+ "headless": status.Headless,
+ },
+ },
}
- available, reason, execPath := resolveBrowserPlaywrightRuntimeAvailability()
- runtimeRequirement.Available = available
- if !available {
- if strings.TrimSpace(reason) != "" {
- runtimeRequirement.Reason = reason
- } else {
- runtimeRequirement.Reason = "Playwright runtime is unavailable"
- }
- } else if strings.TrimSpace(execPath) != "" {
- // Surface resolved Chromium binary path for browser settings content.
- runtimeRequirement.Reason = strings.TrimSpace(execPath)
+ return requirements
+}
+
+// resolveNodeRequirements returns the one requirement of the nodes tool,
+// "nodes.remote_runtime". It is hard-coded as unavailable with a "not
+// implemented yet" reason, so the requirement can never be satisfied until
+// this function changes.
+func resolveNodeRequirements() []tooldto.ToolRequirement {
+	remoteRuntime := tooldto.ToolRequirement{
+		ID:        "nodes.remote_runtime",
+		Name:      "Remote node runtime",
+		Available: false,
+		Reason:    "Remote node runtime is not implemented yet",
+	}
+	return []tooldto.ToolRequirement{remoteRuntime}
+}
+
+// resolveCanvasRequirements returns the one requirement of the canvas tool,
+// "canvas.remote_runtime", hard-coded as unavailable.
+// NOTE(review): Name and Reason use the same "Remote node runtime" wording as
+// the nodes requirement — presumably canvas depends on that runtime; confirm
+// the label is intended rather than copied.
+func resolveCanvasRequirements() []tooldto.ToolRequirement {
+ return []tooldto.ToolRequirement{
+ {
+ ID: "canvas.remote_runtime",
+ Name: "Remote node runtime",
+ Available: false,
+ Reason: "Remote node runtime is not implemented yet",
+ },
}
- return []tooldto.ToolRequirement{runtimeRequirement}
}
func resolveToolRequirementKey(spec tooldto.ToolSpec) string {
@@ -132,19 +166,21 @@ func resolveToolRequirementKey(spec tooldto.ToolSpec) string {
}
+// resolveWebFetchRequirements reports the single requirement of web_fetch,
+// "web_fetch.local_browser". The preferred browser and headless flag are read
+// from config and passed to browsercdp.ResolveStatus; Available mirrors
+// status.Ready, Reason carries the trimmed detection error, and Data exposes
+// the detection details. The settings on/off switch is no longer reported as
+// a requirement by this function.
func resolveWebFetchRequirements(config map[string]any) []tooldto.ToolRequirement {
- enabled := true
- if value, ok := resolveWebFetchConfigBool(config, "enabled"); ok {
- enabled = value
+ status := browsercdp.ResolveStatus(resolveWebFetchPreferredBrowser(config), resolveWebFetchHeadless(config))
+ browserRequirement := tooldto.ToolRequirement{
+ ID: "web_fetch.local_browser",
+ Name: "Local CDP browser",
+ Available: status.Ready,
+ Reason: strings.TrimSpace(status.DetectError),
+ Data: map[string]any{
+ "candidates": status.Candidates,
+ "selectedBrowser": status.SelectedBrowser,
+ "chosenBrowser": status.ChosenBrowser,
+ "detectedExecutablePath": status.DetectedExecutablePath,
+ "headless": status.Headless,
+ },
}
- requirement := tooldto.ToolRequirement{
- ID: "web_fetch.config_enabled",
- Name: "Web fetch switch",
- Available: enabled,
- }
- if !enabled {
- requirement.Reason = "web_fetch is disabled in settings"
- }
- return []tooldto.ToolRequirement{requirement}
+ return []tooldto.ToolRequirement{browserRequirement}
}
func resolveWebSearchRequirements(config map[string]any) []tooldto.ToolRequirement {
@@ -264,3 +300,12 @@ func resolveWebSearchProviderAPIKey(config map[string]any, provider string) stri
}
return strings.TrimSpace(apiKey)
}
+
+// resolveWebSearchProviderString looks up the string setting stored at
+// web.search.providers.<provider>.<key> in config. provider is trimmed and
+// lower-cased and key is trimmed before the lookup; if either is empty after
+// that, no lookup is made and "" is returned.
+func resolveWebSearchProviderString(config map[string]any, provider string, key string) string {
+	providerID := strings.ToLower(strings.TrimSpace(provider))
+	settingKey := strings.TrimSpace(key)
+	if providerID != "" && settingKey != "" {
+		return getNestedString(config, "web", "search", "providers", providerID, settingKey)
+	}
+	return ""
+}
diff --git a/internal/application/gateway/tools/tool_requirements_test.go b/internal/application/gateway/tools/tool_requirements_test.go
index 2882f0e..c11442b 100644
--- a/internal/application/gateway/tools/tool_requirements_test.go
+++ b/internal/application/gateway/tools/tool_requirements_test.go
@@ -5,11 +5,25 @@ import (
"strings"
"testing"
+ gatewayvoice "dreamcreator/internal/application/gateway/voice"
settingsdto "dreamcreator/internal/application/settings/dto"
tooldto "dreamcreator/internal/application/tools/dto"
toolservice "dreamcreator/internal/application/tools/service"
+ "dreamcreator/internal/domain/providers"
)
+// voiceStatusStub is a test double whose Status method replays a canned TTS
+// status response or a canned error.
+type voiceStatusStub struct {
+	status gatewayvoice.TTSStatusResponse
+	err    error
+}
+
+// Status returns stub.status with a nil error when no error is configured;
+// otherwise it returns a zero-value response together with stub.err.
+func (stub voiceStatusStub) Status(context.Context) (gatewayvoice.TTSStatusResponse, error) {
+	if stub.err == nil {
+		return stub.status, nil
+	}
+	return gatewayvoice.TTSStatusResponse{}, stub.err
+}
+
func TestResolveEffectiveToolSpecWebSearchAPIMissingKeyDisablesTool(t *testing.T) {
t.Parallel()
@@ -111,16 +125,17 @@ func TestResolveEffectiveToolSpecWebFetchDisabledByConfig(t *testing.T) {
if spec.Enabled {
t.Fatalf("expected web_fetch disabled when config enabled flag is false")
}
- requirement, ok := findRequirement(spec.Requirements, "web_fetch.config_enabled")
+ requirement, ok := findRequirement(spec.Requirements, "web_fetch.local_browser")
if !ok {
- t.Fatalf("expected web_fetch config requirement")
+ t.Fatalf("expected web_fetch local browser requirement")
}
- if requirement.Available {
- t.Fatalf("expected web_fetch config requirement to be unavailable")
+ if _, ok := findRequirement(spec.Requirements, "web_fetch.config_enabled"); ok {
+ t.Fatalf("did not expect web_fetch config enabled requirement")
}
+ _ = requirement
}
-func TestResolveEffectiveToolSpecBrowserPlaywrightIncludesRuntimeRequirement(t *testing.T) {
+func TestResolveEffectiveToolSpecBrowserIncludesCDPRuntimeRequirement(t *testing.T) {
t.Parallel()
snapshot := loadToolRequirementSnapshot(context.Background(), gatewayToolSettingsStub{
@@ -128,17 +143,23 @@ func TestResolveEffectiveToolSpecBrowserPlaywrightIncludesRuntimeRequirement(t *
Gateway: settingsdto.GatewaySettings{ControlPlaneEnabled: true},
Tools: map[string]any{
"browser": map[string]any{
- "type": "playwright",
+ "preferredBrowser": "brave",
},
},
},
})
spec := resolveEffectiveToolSpec(specBrowser().toDTO(), snapshot)
- requirement, ok := findRequirement(spec.Requirements, "browser.playwright_runtime")
+ requirement, ok := findRequirement(spec.Requirements, "browser.cdp_runtime")
if !ok {
- t.Fatalf("expected browser playwright runtime requirement")
+ t.Fatalf("expected browser cdp runtime requirement")
+ }
+ data, ok := requirement.Data.(map[string]any)
+ if !ok || data == nil {
+ t.Fatalf("expected browser runtime requirement data")
+ }
+ if data["selectedBrowser"] != "brave" {
+ t.Fatalf("expected selected browser brave, got %#v", data["selectedBrowser"])
}
- _ = requirement
}
func TestResolveEffectiveToolSpecBrowserDisabledByConfig(t *testing.T) {
@@ -158,9 +179,9 @@ func TestResolveEffectiveToolSpecBrowserDisabledByConfig(t *testing.T) {
if spec.Enabled {
t.Fatalf("expected browser disabled when browser.enabled is false")
}
- requirement, ok := findRequirement(spec.Requirements, "browser.playwright_runtime")
+ requirement, ok := findRequirement(spec.Requirements, "browser.cdp_runtime")
if !ok {
- t.Fatalf("expected browser playwright runtime requirement")
+ t.Fatalf("expected browser cdp runtime requirement")
}
if _, ok := findRequirement(spec.Requirements, "browser.config_enabled"); ok {
t.Fatalf("did not expect browser config enabled requirement")
@@ -182,9 +203,9 @@ func TestResolveEffectiveToolSpecBrowserLegacyTypeIsIgnored(t *testing.T) {
},
})
spec := resolveEffectiveToolSpec(specBrowser().toDTO(), snapshot)
- requirement, ok := findRequirement(spec.Requirements, "browser.playwright_runtime")
+ requirement, ok := findRequirement(spec.Requirements, "browser.cdp_runtime")
if !ok {
- t.Fatalf("expected browser playwright runtime requirement")
+ t.Fatalf("expected browser cdp runtime requirement")
}
if _, ok := findRequirement(spec.Requirements, "browser.type_supported"); ok {
t.Fatalf("did not expect browser type requirement")
@@ -213,6 +234,280 @@ func TestResolveEffectiveToolSpecGatewayRequiresControlPlane(t *testing.T) {
}
}
+// Verifies that the nodes tool resolves to disabled and carries an unavailable
+// "nodes.remote_runtime" requirement whose reason mentions "not implemented".
+func TestResolveEffectiveToolSpecNodesRemainsDisabledUntilRemoteRuntimeExists(t *testing.T) {
+ t.Parallel()
+
+ snapshot := loadToolRequirementSnapshot(context.Background(), gatewayToolSettingsStub{
+ settings: settingsdto.Settings{
+ Gateway: settingsdto.GatewaySettings{ControlPlaneEnabled: true},
+ Tools: map[string]any{},
+ },
+ })
+ spec := resolveEffectiveToolSpec(specNodes().toDTO(), snapshot)
+ if spec.Enabled {
+ t.Fatalf("expected nodes tool disabled while remote node runtime is unavailable")
+ }
+ requirement, ok := findRequirement(spec.Requirements, "nodes.remote_runtime")
+ if !ok {
+ t.Fatalf("expected remote node runtime requirement")
+ }
+ if requirement.Available {
+ t.Fatalf("expected remote node runtime requirement to be unavailable")
+ }
+ if !strings.Contains(strings.ToLower(requirement.Reason), "not implemented") {
+ t.Fatalf("unexpected requirement reason: %q", requirement.Reason)
+ }
+}
+
+// Verifies that the canvas tool resolves to disabled and carries an
+// unavailable "canvas.remote_runtime" requirement whose reason mentions
+// "not implemented".
+func TestResolveEffectiveToolSpecCanvasRemainsDisabledUntilRemoteRuntimeExists(t *testing.T) {
+ t.Parallel()
+
+ snapshot := loadToolRequirementSnapshot(context.Background(), gatewayToolSettingsStub{
+ settings: settingsdto.Settings{
+ Gateway: settingsdto.GatewaySettings{ControlPlaneEnabled: true},
+ Tools: map[string]any{},
+ },
+ })
+ spec := resolveEffectiveToolSpec(specCanvas().toDTO(), snapshot)
+ if spec.Enabled {
+ t.Fatalf("expected canvas tool disabled while remote node runtime is unavailable")
+ }
+ requirement, ok := findRequirement(spec.Requirements, "canvas.remote_runtime")
+ if !ok {
+ t.Fatalf("expected canvas remote runtime requirement")
+ }
+ if requirement.Available {
+ t.Fatalf("expected canvas remote runtime requirement to be unavailable")
+ }
+ if !strings.Contains(strings.ToLower(requirement.Reason), "not implemented") {
+ t.Fatalf("unexpected requirement reason: %q", requirement.Reason)
+ }
+}
+
+// Verifies that, with empty provider/model/secret repositories behind the
+// resolver, the image tool resolves to disabled and its imageRequirementID
+// requirement is unavailable with a reason mentioning "image model".
+func TestResolveEffectiveToolSpecImageRequiresConfiguredModel(t *testing.T) {
+ t.Parallel()
+
+ snapshot := loadToolRequirementSnapshot(context.Background(), gatewayToolSettingsStub{
+ settings: settingsdto.Settings{
+ Gateway: settingsdto.GatewaySettings{ControlPlaneEnabled: true},
+ Tools: map[string]any{},
+ },
+ })
+ resolver := NewBuiltinRequirementResolver(BuiltinRequirementDeps{
+ Providers: imageProviderRepoStub{items: map[string]providers.Provider{}},
+ Models: imageModelRepoStub{items: map[string][]providers.Model{}},
+ Secrets: imageSecretRepoStub{items: map[string]providers.ProviderSecret{}},
+ })
+ spec := resolveEffectiveToolSpecWithResolver(context.Background(), specImage().toDTO(), snapshot, resolver)
+ if spec.Enabled {
+ t.Fatalf("expected image tool disabled without a configured model")
+ }
+ requirement, ok := findRequirement(spec.Requirements, imageRequirementID)
+ if !ok {
+ t.Fatalf("expected image model requirement")
+ }
+ if requirement.Available {
+ t.Fatalf("expected image model requirement to be unavailable")
+ }
+ if !strings.Contains(strings.ToLower(requirement.Reason), "image model") {
+ t.Fatalf("unexpected requirement reason: %q", requirement.Reason)
+ }
+}
+
+// Verifies that the image tool stays enabled, with an available
+// imageRequirementID requirement, when the resolver sees an enabled OpenAI
+// provider that has an API key and an enabled model flagged SupportsVision.
+func TestResolveEffectiveToolSpecImageStaysEnabledWithConfiguredModel(t *testing.T) {
+ t.Parallel()
+
+ snapshot := loadToolRequirementSnapshot(context.Background(), gatewayToolSettingsStub{
+ settings: settingsdto.Settings{
+ Gateway: settingsdto.GatewaySettings{ControlPlaneEnabled: true},
+ Tools: map[string]any{},
+ },
+ })
+ resolver := NewBuiltinRequirementResolver(BuiltinRequirementDeps{
+ Providers: imageProviderRepoStub{items: map[string]providers.Provider{
+ "openai": {
+ ID: "openai",
+ Enabled: true,
+ Type: providers.ProviderTypeOpenAI,
+ Endpoint: "https://api.openai.com/v1",
+ },
+ }},
+ Models: imageModelRepoStub{items: map[string][]providers.Model{
+ "openai": {
+ {
+ ID: "gpt-4o",
+ Name: "gpt-4o",
+ Enabled: true,
+ SupportsVision: ptrBool(true),
+ },
+ },
+ }},
+ Secrets: imageSecretRepoStub{items: map[string]providers.ProviderSecret{
+ "openai": {ProviderID: "openai", APIKey: "test-key"},
+ }},
+ })
+ spec := resolveEffectiveToolSpecWithResolver(context.Background(), specImage().toDTO(), snapshot, resolver)
+ if !spec.Enabled {
+ t.Fatalf("expected image tool to stay enabled with a configured model")
+ }
+ requirement, ok := findRequirement(spec.Requirements, imageRequirementID)
+ if !ok {
+ t.Fatalf("expected image model requirement")
+ }
+ if !requirement.Available {
+ t.Fatalf("expected image model requirement to be available")
+ }
+}
+
+// Verifies that selecting the "edge" provider leaves the tts tool disabled
+// even though the catalog marks it Available: the ttsProviderRequirementID
+// requirement must be unavailable with a reason mentioning "edge-tts".
+func TestResolveEffectiveToolSpecTTSRequiresRunnableProvider(t *testing.T) {
+ t.Parallel()
+
+ snapshot := loadToolRequirementSnapshot(context.Background(), gatewayToolSettingsStub{
+ settings: settingsdto.Settings{
+ Gateway: settingsdto.GatewaySettings{ControlPlaneEnabled: true},
+ Tools: map[string]any{},
+ },
+ })
+ resolver := NewBuiltinRequirementResolver(BuiltinRequirementDeps{
+ Voice: voiceStatusStub{
+ status: gatewayvoice.TTSStatusResponse{
+ Enabled: true,
+ Config: gatewayvoice.TTSConfig{
+ ProviderID: "edge",
+ },
+ Providers: []gatewayvoice.TTSProviderCatalogItem{
+ {ProviderID: "edge", DisplayName: "Edge-TTS", Available: true},
+ },
+ },
+ },
+ })
+ spec := resolveEffectiveToolSpecWithResolver(context.Background(), specTTS().toDTO(), snapshot, resolver)
+ if spec.Enabled {
+ t.Fatalf("expected tts tool disabled when only edge placeholder provider is selected")
+ }
+ requirement, ok := findRequirement(spec.Requirements, ttsProviderRequirementID)
+ if !ok {
+ t.Fatalf("expected tts provider requirement")
+ }
+ if requirement.Available {
+ t.Fatalf("expected tts provider requirement to be unavailable")
+ }
+ if !strings.Contains(strings.ToLower(requirement.Reason), "edge-tts") {
+ t.Fatalf("unexpected requirement reason: %q", requirement.Reason)
+ }
+}
+
+// Verifies that the tts tool stays enabled when voice is enabled and the
+// selected "openai" provider is marked Available in the catalog; the
+// ttsProviderAPIKeyRequirementID requirement must be present and available.
+func TestResolveEffectiveToolSpecTTSStaysEnabledWithConfiguredProvider(t *testing.T) {
+ t.Parallel()
+
+ snapshot := loadToolRequirementSnapshot(context.Background(), gatewayToolSettingsStub{
+ settings: settingsdto.Settings{
+ Gateway: settingsdto.GatewaySettings{ControlPlaneEnabled: true},
+ Tools: map[string]any{},
+ },
+ })
+ resolver := NewBuiltinRequirementResolver(BuiltinRequirementDeps{
+ Voice: voiceStatusStub{
+ status: gatewayvoice.TTSStatusResponse{
+ Enabled: true,
+ Config: gatewayvoice.TTSConfig{
+ ProviderID: "openai",
+ },
+ Providers: []gatewayvoice.TTSProviderCatalogItem{
+ {ProviderID: "openai", DisplayName: "OpenAI", Available: true},
+ },
+ },
+ },
+ })
+ spec := resolveEffectiveToolSpecWithResolver(context.Background(), specTTS().toDTO(), snapshot, resolver)
+ if !spec.Enabled {
+ t.Fatalf("expected tts tool to stay enabled with a configured provider")
+ }
+ requirement, ok := findRequirement(spec.Requirements, ttsProviderAPIKeyRequirementID)
+ if !ok {
+ t.Fatalf("expected tts provider api key requirement")
+ }
+ if !requirement.Available {
+ t.Fatalf("expected tts provider api key requirement to be available")
+ }
+}
+
+// Verifies that a voice status with Enabled=false disables the tts tool even
+// when the selected provider is available: the ttsVoiceEnabledRequirementID
+// requirement must be unavailable with a reason mentioning "voice is disabled".
+func TestResolveEffectiveToolSpecTTSVoiceFeatureDisabled(t *testing.T) {
+ t.Parallel()
+
+ snapshot := loadToolRequirementSnapshot(context.Background(), gatewayToolSettingsStub{
+ settings: settingsdto.Settings{
+ Gateway: settingsdto.GatewaySettings{ControlPlaneEnabled: true},
+ Tools: map[string]any{},
+ },
+ })
+ resolver := NewBuiltinRequirementResolver(BuiltinRequirementDeps{
+ Voice: voiceStatusStub{
+ status: gatewayvoice.TTSStatusResponse{
+ Enabled: false,
+ Config: gatewayvoice.TTSConfig{
+ ProviderID: "openai",
+ },
+ Providers: []gatewayvoice.TTSProviderCatalogItem{
+ {ProviderID: "openai", DisplayName: "OpenAI", Available: true},
+ },
+ },
+ },
+ })
+ spec := resolveEffectiveToolSpecWithResolver(context.Background(), specTTS().toDTO(), snapshot, resolver)
+ if spec.Enabled {
+ t.Fatalf("expected tts tool disabled when voice feature is disabled")
+ }
+ requirement, ok := findRequirement(spec.Requirements, ttsVoiceEnabledRequirementID)
+ if !ok {
+ t.Fatalf("expected tts voice feature requirement")
+ }
+ if requirement.Available {
+ t.Fatalf("expected tts voice feature requirement to be unavailable")
+ }
+ if !strings.Contains(strings.ToLower(requirement.Reason), "voice is disabled") {
+ t.Fatalf("unexpected requirement reason: %q", requirement.Reason)
+ }
+}
+
+// Verifies that selecting the "elevenlabs" provider without a voice id in the
+// TTS config disables the tts tool: the ttsVoiceIDRequirementID requirement
+// must be unavailable with a reason mentioning "voice id".
+func TestResolveEffectiveToolSpecTTSElevenLabsRequiresVoiceID(t *testing.T) {
+ t.Parallel()
+
+ snapshot := loadToolRequirementSnapshot(context.Background(), gatewayToolSettingsStub{
+ settings: settingsdto.Settings{
+ Gateway: settingsdto.GatewaySettings{ControlPlaneEnabled: true},
+ Tools: map[string]any{},
+ },
+ })
+ resolver := NewBuiltinRequirementResolver(BuiltinRequirementDeps{
+ Voice: voiceStatusStub{
+ status: gatewayvoice.TTSStatusResponse{
+ Enabled: true,
+ Config: gatewayvoice.TTSConfig{
+ ProviderID: "elevenlabs",
+ },
+ Providers: []gatewayvoice.TTSProviderCatalogItem{
+ {ProviderID: "elevenlabs", DisplayName: "ElevenLabs", Available: true},
+ },
+ },
+ },
+ })
+ spec := resolveEffectiveToolSpecWithResolver(context.Background(), specTTS().toDTO(), snapshot, resolver)
+ if spec.Enabled {
+ t.Fatalf("expected tts tool disabled when elevenlabs voice id is missing")
+ }
+ requirement, ok := findRequirement(spec.Requirements, ttsVoiceIDRequirementID)
+ if !ok {
+ t.Fatalf("expected tts voice id requirement")
+ }
+ if requirement.Available {
+ t.Fatalf("expected tts voice id requirement to be unavailable")
+ }
+ if !strings.Contains(strings.ToLower(requirement.Reason), "voice id") {
+ t.Fatalf("unexpected requirement reason: %q", requirement.Reason)
+ }
+}
+
func TestPolicyPipelineDeniesUnavailableWebSearch(t *testing.T) {
t.Parallel()
diff --git a/internal/application/gateway/tools/web_tools.go b/internal/application/gateway/tools/web_tools.go
index 5e0ebf4..e83b69f 100644
--- a/internal/application/gateway/tools/web_tools.go
+++ b/internal/application/gateway/tools/web_tools.go
@@ -10,7 +10,6 @@ import (
"io"
"net/http"
"net/url"
- "os"
"regexp"
"strconv"
"strings"
@@ -18,32 +17,33 @@ import (
"time"
md "github.com/JohannesKaufmann/html-to-markdown"
+ "github.com/PuerkitoBio/goquery"
+ "github.com/chromedp/cdproto/network"
+ "github.com/chromedp/chromedp"
"github.com/hashicorp/go-retryablehttp"
- "github.com/playwright-community/playwright-go"
+ "dreamcreator/internal/application/browsercdp"
connectorsdto "dreamcreator/internal/application/connectors/dto"
- domainweb "dreamcreator/internal/domain/web"
+ appcookies "dreamcreator/internal/application/cookies"
+ "dreamcreator/internal/application/sitepolicy"
)
-const webFetchTypeBuiltin = "builtin"
-const webFetchTypePlaywright = "playwright"
-
-const defaultWebFetchType = webFetchTypeBuiltin
-const defaultWebFetchTimeoutSeconds = 20
-const defaultWebFetchMaxChars = 50000
-const defaultWebFetchMaxBodyBytes = 2 << 20
-const defaultWebFetchMaxRedirects = 3
-const defaultWebFetchRetryMax = 2
-const defaultWebFetchAcceptMarkdown = true
-const defaultWebFetchPlaywrightMarkdown = true
-const defaultWebFetchEnableUserAgent = true
-const defaultWebFetchUserAgent = domainweb.DefaultBrowserRequestUserAgent
-const defaultWebFetchAcceptLanguage = domainweb.DefaultBrowserRequestAcceptLanguage
-const defaultWebSearchTimeoutSeconds = 30
-const defaultWebSearchCacheTtlMinutes = 15
-const defaultWebSearchCount = 5
-const maxWebSearchCount = 10
-const defaultWebSearchType = "api"
+// webFetchTypeCDP is the fetch-type identifier for the CDP-based web_fetch
+// path, which is the only type this file still defines.
+const webFetchTypeCDP = "cdp"
+
+// webFetchTypeBuiltin is kept as an alias of webFetchTypeCDP.
+// NOTE(review): presumably retained so existing references to the old name
+// keep compiling — confirm whether any callers remain.
+const webFetchTypeBuiltin = webFetchTypeCDP
+
+// Defaults and limits for the web_fetch and web_search tools. Units follow
+// the constant names (seconds, characters, bytes, minutes, result count).
+const (
+ defaultWebFetchType = webFetchTypeCDP
+ defaultWebFetchTimeoutSeconds = 20
+ defaultWebFetchMaxChars = 50000
+ defaultWebFetchMaxBodyBytes = 2 << 20
+ defaultWebSearchTimeoutSeconds = 30
+ defaultWebSearchCacheTtlMinutes = 15
+ defaultWebSearchCount = 5
+ maxWebSearchCount = 10
+ defaultWebSearchType = "api"
+ defaultWebFetchContentSignalMain = "main_heuristic"
+)
const (
webStatusOK = "ok"
@@ -74,6 +74,38 @@ var markdownListMarkerPattern = regexp.MustCompile(`(?m)^\s*[-*+]\s+`)
var markdownBlockQuotePattern = regexp.MustCompile(`(?m)^\s*>\s*`)
var markdownTableSepPattern = regexp.MustCompile(`(?m)^\s*\|?[-:\s|]+\|?\s*$`)
+// defaultExtractorSelectors lists CSS selectors for elements that commonly
+// wrap a page's primary content.
+// NOTE(review): presumably tried in order by the content extractor; the
+// consumer is not visible here — confirm ordering matters.
+var defaultExtractorSelectors = []string{
+ "article",
+ "main",
+ "[role=main]",
+ ".article",
+ ".post-content",
+ ".entry-content",
+ ".content",
+}
+
+// defaultRemoveSelectors lists CSS selectors for non-content elements:
+// scripts and styles, embedded media, forms, navigation chrome, sidebars,
+// comments, recommendations and ads.
+// NOTE(review): presumably stripped from the document before extraction; the
+// consumer is not visible here — confirm.
+var defaultRemoveSelectors = []string{
+ "script",
+ "style",
+ "noscript",
+ "svg",
+ "canvas",
+ "iframe",
+ "form",
+ "nav",
+ "aside",
+ "footer",
+ "header",
+ "[role=navigation]",
+ ".sidebar",
+ ".comments",
+ ".recommend",
+ ".recommendations",
+ ".related",
+ ".advertisement",
+ ".ads",
+}
+
type webSearchResult struct {
Title string `json:"title,omitempty"`
URL string `json:"url,omitempty"`
@@ -111,10 +143,9 @@ type tavilySearchResponse struct {
Query string `json:"query"`
Answer string `json:"answer"`
Results []struct {
- Title string `json:"title"`
- URL string `json:"url"`
- Content string `json:"content"`
- RawContent interface{} `json:"raw_content"`
+ Title string `json:"title"`
+ URL string `json:"url"`
+ Content string `json:"content"`
} `json:"results"`
}
@@ -151,20 +182,9 @@ type webFetchResult struct {
}
+// webFetchOptions holds the per-request limits for web_fetch. After this
+// change only the timeout, the character cap and the body-size cap remain;
+// the redirect, retry, header and user-agent options are gone, as is the
+// separate webFetchPlaywrightOptions type.
type webFetchOptions struct {
- TimeoutSeconds int
- MaxChars int
- MaxBodyBytes int
- MaxRedirects int
- RetryMax int
- AcceptMarkdown bool
- EnableUserAgent bool
- UserAgent string
- AcceptLanguage string
- Headers map[string]any
-}
-
-type webFetchPlaywrightOptions struct {
- Markdown bool
+ TimeoutSeconds int
+ MaxChars int
+ MaxBodyBytes int
}
type webFetchResponse struct {
@@ -184,40 +204,46 @@ type ConnectorsReader interface {
ListConnectors(ctx context.Context) ([]connectorsdto.Connector, error)
}
+// extractorResult pairs extracted page content with a label describing how
+// it was obtained.
+// NOTE(review): ContentSignal presumably carries values such as
+// defaultWebFetchContentSignalMain ("main_heuristic") — confirm against the
+// extractor that fills it.
+type extractorResult struct {
+ Content string
+ ContentSignal string
+}
+
+// runWebFetchTool builds the web_fetch tool handler.
+//
+// The handler parses the tool arguments, requires a "url" (or "href") argument,
+// honours the web_fetch "enabled" setting, resolves connector cookies for the
+// URL and fetches the page with fetchWithCDP. Argument, configuration and
+// cookie-resolution failures are returned as Go errors; non-GET methods and
+// fetch failures are reported inside the marshalled result instead.
func runWebFetchTool(settings SettingsReader, connectors ConnectorsReader) func(ctx context.Context, args string) (string, error) {
return func(ctx context.Context, args string) (string, error) {
payload, err := parseToolArgs(args)
if err != nil {
return "", err
}
- url := getStringArg(payload, "url", "href")
- if url == "" {
+ targetURL := getStringArg(payload, "url", "href")
+ if targetURL == "" {
return "", errors.New("url is required")
}
+
config := resolveToolsConfig(ctx, settings)
- enabled := true
- if value, ok := resolveWebFetchConfigBool(config, "enabled"); ok {
- enabled = value
- }
- if !enabled {
+ // An absent setting means enabled; only an explicit false disables the tool.
+ if enabled, ok := resolveWebFetchConfigBool(config, "enabled"); ok && !enabled {
return "", errors.New("web_fetch disabled")
}
- method := strings.ToUpper(getStringArg(payload, "method"))
+
+ method := strings.ToUpper(strings.TrimSpace(getStringArg(payload, "method")))
if method == "" {
method = http.MethodGet
}
- fetchType, err := resolveWebFetchType(payload, config)
- if err != nil {
- return "", err
+ // Non-GET requests are rejected in-band (as a tool result, not a Go error).
+ if method != http.MethodGet {
+ result := buildWebFetchToolResult(targetURL, webFetchResponse{}, errors.New("web_fetch only supports GET"))
+ return marshalResult(result), nil
}
+
options := resolveWebFetchOptions(payload, config, defaultWebFetchTimeoutSeconds)
- playwrightOptions := resolveWebFetchPlaywrightOptions(payload, config)
- cookies, err := resolveConnectorCookiesForURL(ctx, connectors, url)
+ cookies, err := browsercdp.ResolveConnectorCookiesForURL(ctx, connectors, targetURL)
if err != nil {
return "", err
}
- response, err := fetchByWebFetchType(ctx, fetchType, method, url, cookies, options, playwrightOptions)
- result := buildWebFetchToolResult(url, response, err)
+ preferredBrowser := resolveWebFetchPreferredBrowser(config)
+ headless := resolveWebFetchHeadless(config)
+
+ response, err := fetchWithCDP(ctx, targetURL, cookies, options, preferredBrowser, headless)
+ result := buildWebFetchToolResult(targetURL, response, err)
return marshalResult(result), nil
}
}
@@ -239,6 +265,99 @@ func runWebSearchTool(settings SettingsReader, connectors ConnectorsReader) func
}
}
+// fetchWithCDP loads targetURL in a browser started through browsercdp and
+// returns the main content extracted from the rendered page.
+//
+// The browser is started under a timeout of options.TimeoutSeconds and stopped
+// when the function returns. Cookies, when present, are installed before
+// navigation. If the site policy names a ready selector, the function waits
+// for it on a best-effort basis: a failed wait does not abort the fetch but is
+// recorded as TimeoutStage "ready" in the response.
+//
+// The extracted content is limited first by options.MaxBodyBytes (bytes) and
+// then by options.MaxChars (Unicode code points). Both cuts land on rune
+// boundaries, so the returned content is never left with a split multi-byte
+// character. Truncated reports whether either limit was applied.
+//
+// An error is returned when the browser cannot be started, the setup tasks or
+// the navigation fail, or the page location/HTML cannot be read.
+func fetchWithCDP(
+	ctx context.Context,
+	targetURL string,
+	cookies []appcookies.Record,
+	options webFetchOptions,
+	preferredBrowser string,
+	headless bool,
+) (webFetchResponse, error) {
+	runtimeCtx, cancel := context.WithTimeout(ctx, time.Duration(options.TimeoutSeconds)*time.Second)
+	defer cancel()
+
+	runtime, err := browsercdp.Start(runtimeCtx, browsercdp.LaunchOptions{
+		PreferredBrowser: preferredBrowser,
+		Headless:         headless,
+	})
+	if err != nil {
+		return webFetchResponse{}, err
+	}
+	defer runtime.Stop()
+
+	tabCtx, tabCancel := chromedp.NewContext(runtime.BrowserContext())
+	defer tabCancel()
+
+	var navResponse *network.Response
+	var finalURL string
+	var contentType string
+	var htmlContent string
+
+	tasks := chromedp.Tasks{
+		network.Enable(),
+	}
+	if len(cookies) > 0 {
+		tasks = append(tasks, chromedp.ActionFunc(func(ctx context.Context) error {
+			return browsercdp.SetCookies(ctx, targetURL, cookies)
+		}))
+	}
+	if err := chromedp.Run(tabCtx, tasks); err != nil {
+		return webFetchResponse{}, err
+	}
+
+	navResponse, err = chromedp.RunResponse(tabCtx, chromedp.Navigate(strings.TrimSpace(targetURL)))
+	if err != nil {
+		return webFetchResponse{}, err
+	}
+	if navResponse != nil {
+		contentType = strings.TrimSpace(navResponse.MimeType)
+	}
+
+	timeoutStage := ""
+	if selector := sitepolicy.ReadySelectorForURL(targetURL); selector != "" {
+		waitCtx, waitCancel := context.WithTimeout(tabCtx, time.Duration(options.TimeoutSeconds)*time.Second)
+		// Best effort: whatever has rendered is still extracted below, but the
+		// failed wait is surfaced to the caller instead of being dropped.
+		if waitErr := chromedp.Run(waitCtx, chromedp.WaitVisible(selector, chromedp.ByQuery)); waitErr != nil {
+			timeoutStage = "ready"
+		}
+		waitCancel()
+	}
+	if err := chromedp.Run(tabCtx,
+		chromedp.Location(&finalURL),
+		chromedp.OuterHTML("html", &htmlContent, chromedp.ByQuery),
+	); err != nil {
+		return webFetchResponse{}, err
+	}
+	if finalURL == "" {
+		finalURL = targetURL
+	}
+	extracted := extractMainContent(htmlContent, finalURL)
+	content := strings.TrimSpace(extracted.Content)
+	truncated := false
+	if options.MaxBodyBytes > 0 && len(content) > options.MaxBodyBytes {
+		// Back up to a rune boundary so the cut never splits a multi-byte
+		// UTF-8 sequence (continuation bytes have the bit pattern 10xxxxxx).
+		cut := options.MaxBodyBytes
+		for cut > 0 && content[cut]&0xC0 == 0x80 {
+			cut--
+		}
+		content = content[:cut]
+		truncated = true
+	}
+	if options.MaxChars > 0 && len(content) > options.MaxChars {
+		// len() counts bytes and is only an upper bound on the number of
+		// characters, so count runes before deciding to truncate.
+		if runes := []rune(content); len(runes) > options.MaxChars {
+			content = string(runes[:options.MaxChars])
+			truncated = true
+		}
+	}
+
+	headersMap := map[string]string{}
+	if navResponse != nil {
+		for key, value := range navResponse.Headers {
+			headersMap[key] = fmt.Sprint(value)
+		}
+	}
+
+	return webFetchResponse{
+		URL:            targetURL,
+		FinalURL:       finalURL,
+		Status:         statusCodeFromResponse(navResponse),
+		Headers:        headersMap,
+		ContentType:    contentType,
+		Content:        content,
+		MarkdownTokens: estimateMarkdownTokens(content),
+		ContentSignal:  extracted.ContentSignal,
+		TimeoutStage:   timeoutStage,
+		Truncated:      truncated,
+	}, nil
+}
+
func runWebSearchWithFallback(
ctx context.Context,
payload toolArgs,
@@ -250,87 +369,224 @@ func runWebSearchWithFallback(
_ = connectors
searchType := resolveWebSearchType(config)
switch searchType {
- case "api":
- provider := strings.ToLower(strings.TrimSpace(getNestedString(config, "web", "search", "provider")))
- if provider == "" {
- provider = "brave"
- }
- result, err := runWebSearchByAPI(ctx, payload, config, query, count)
- if err != nil {
- return webSearchResponse{
- Status: webStatusError,
- Retryable: isTimeoutError(err),
- NextAction: nextActionInspectErrorThenSwitch,
- Message: trimToMaxChars(strings.TrimSpace(err.Error()), 260),
- Provider: provider,
- Quality: webQualityEmpty,
- WebSearchAvailable: true,
- Query: query,
- }
- }
- quality := webQualitySufficient
- if len(result.Results) == 0 {
- quality = webQualityEmpty
- }
- return webSearchResponse{
- Status: webStatusOK,
- Retryable: false,
- NextAction: nextActionContinue,
- Message: "search_completed",
- Provider: provider,
- Quality: quality,
- WebSearchAvailable: true,
- Query: query,
- Results: result.Results,
- Cached: result.Cached,
- Data: map[string]any{
- "results_count": len(result.Results),
- },
- }
case "external_tools":
return webSearchResponse{
Status: webStatusError,
Retryable: false,
NextAction: nextActionUseOtherToolsOrSkills,
Message: "web_search_external_tools_not_configured",
- Provider: "external_tools",
+ Provider: "web_search",
Quality: webQualityEmpty,
WebSearchAvailable: false,
Query: query,
}
default:
- return webSearchResponse{
- Status: webStatusError,
- Retryable: false,
- NextAction: nextActionUseOtherToolsOrSkills,
- Message: "web_search_type_not_supported",
- Provider: "web_search",
- Quality: webQualityEmpty,
- WebSearchAvailable: false,
- Query: query,
- Data: map[string]any{
- "type": searchType,
- },
+ result, err := runWebSearchByAPI(ctx, payload, config, query, count)
+ if err != nil {
+ return webSearchResponse{
+ Status: webStatusError,
+ Retryable: false,
+ NextAction: nextActionUseOtherToolsOrSkills,
+ Message: err.Error(),
+ Provider: "web_search",
+ Quality: webQualityEmpty,
+ WebSearchAvailable: true,
+ Query: query,
+ }
}
+ return result
}
}
-func buildWebFetchToolResult(requestURL string, response webFetchResponse, err error) webFetchResult {
- finalURL := strings.TrimSpace(response.FinalURL)
- if finalURL == "" {
- finalURL = strings.TrimSpace(requestURL)
+// runWebSearchByAPI runs a web search through the configured API provider
+// ("tavily", otherwise Brave) and returns a populated webSearchResponse.
+//
+// Responses are cached. The cache key covers the provider, the result count,
+// the query and every tool argument the providers forward upstream (country,
+// search language, UI language, freshness, search depth), so a cached answer
+// is never reused for a request that would have produced different results.
+// A cache hit is returned with Cached set to true.
+//
+// Quality is reported as empty when the provider returned no results.
+// Provider errors are returned unchanged and nothing is cached for them.
+func runWebSearchByAPI(ctx context.Context, payload toolArgs, config map[string]any, query string, count int) (webSearchResponse, error) {
+	provider := strings.ToLower(strings.TrimSpace(getNestedString(config, "web", "search", "provider")))
+	if provider == "" {
+		provider = "brave"
+	}
+	// %q keeps the free-form parts unambiguous even when they contain ':'.
+	cacheKey := fmt.Sprintf(
+		"%s:%d:%q:%q:%q:%q:%q:%q",
+		provider,
+		count,
+		strings.TrimSpace(query),
+		getStringArg(payload, "country"),
+		getStringArg(payload, "search_lang", "searchLang"),
+		getStringArg(payload, "ui_lang", "uiLang"),
+		getStringArg(payload, "freshness"),
+		getStringArg(payload, "search_depth", "searchDepth"),
+	)
+	if cached, ok := loadWebSearchCache(cacheKey); ok {
+		cached.Cached = true
+		return cached, nil
+	}
+
+	var (
+		results []webSearchResult
+		err     error
+	)
+	switch provider {
+	case "tavily":
+		results, err = runTavilySearch(ctx, payload, config, query, count)
+	default:
+		results, err = runBraveSearch(ctx, payload, config, query, count)
+	}
+	if err != nil {
+		return webSearchResponse{}, err
+	}
+
+	quality := webQualitySufficient
+	if len(results) == 0 {
+		quality = webQualityEmpty
+	}
+	response := webSearchResponse{
+		Status:             webStatusOK,
+		Retryable:          false,
+		NextAction:         nextActionContinue,
+		Message:            "ok",
+		Provider:           provider,
+		Quality:            quality,
+		WebSearchAvailable: true,
+		Query:              query,
+		Results:            results,
+	}
+	storeWebSearchCache(cacheKey, response, resolveWebSearchCacheTTL(config))
+	return response, nil
+}
+
+// runBraveSearch queries the Brave search endpoint and maps the web results to
+// webSearchResult values.
+//
+// The optional tool arguments country, search_lang/searchLang, ui_lang/uiLang
+// and freshness are forwarded as query parameters when set. The request is
+// retried at most once and is bounded by the resolved web-search timeout.
+//
+// An error is returned when the API key is missing, the request cannot be
+// built or sent, the endpoint answers with a status >= 400 (the error carries
+// the status code and up to 2 KiB of the response body), or the response body
+// is not valid JSON.
+func runBraveSearch(ctx context.Context, payload toolArgs, config map[string]any, query string, count int) ([]webSearchResult, error) {
+	apiKey := strings.TrimSpace(resolveWebSearchProviderAPIKey(config, "brave"))
+	if apiKey == "" {
+		return nil, fmt.Errorf("Brave API key is missing")
+	}
+	reqURL, err := url.Parse(braveSearchEndpoint)
+	if err != nil {
+		return nil, err
+	}
+	values := reqURL.Query()
+	values.Set("q", query)
+	values.Set("count", strconv.Itoa(count))
+	if value := getStringArg(payload, "country"); value != "" {
+		values.Set("country", value)
 	}
+	if value := getStringArg(payload, "search_lang", "searchLang"); value != "" {
+		values.Set("search_lang", value)
+	}
+	if value := getStringArg(payload, "ui_lang", "uiLang"); value != "" {
+		values.Set("ui_lang", value)
+	}
+	if value := getStringArg(payload, "freshness"); value != "" {
+		values.Set("freshness", value)
+	}
+	reqURL.RawQuery = values.Encode()
+
+	request, err := retryablehttp.NewRequestWithContext(ctx, http.MethodGet, reqURL.String(), nil)
+	if err != nil {
+		return nil, err
+	}
+	request.Header.Set("Accept", "application/json")
+	request.Header.Set("X-Subscription-Token", apiKey)
+
+	client := retryablehttp.NewClient()
+	client.RetryMax = 1
+	// The default retryablehttp logger prints every request URL, which here
+	// contains the user's search query; keep it out of the process output.
+	client.Logger = nil
+	client.HTTPClient.Timeout = time.Duration(resolveWebSearchTimeoutSeconds(payload, config)) * time.Second
+	response, err := client.Do(request)
+	if err != nil {
+		return nil, err
+	}
+	defer response.Body.Close()
+	if response.StatusCode >= 400 {
+		// Only a bounded prefix of the body is read for the error message.
+		body, _ := io.ReadAll(io.LimitReader(response.Body, 2048))
+		return nil, fmt.Errorf("Brave search failed (HTTP %d): %s", response.StatusCode, strings.TrimSpace(string(body)))
+	}
+	var payloadJSON braveSearchResponse
+	if err := json.NewDecoder(response.Body).Decode(&payloadJSON); err != nil {
+		return nil, err
+	}
+	results := make([]webSearchResult, 0, len(payloadJSON.Web.Results))
+	for _, item := range payloadJSON.Web.Results {
+		results = append(results, webSearchResult{
+			Title:       strings.TrimSpace(item.Title),
+			URL:         strings.TrimSpace(item.URL),
+			Description: strings.TrimSpace(item.Description),
+			Age:         strings.TrimSpace(item.Age),
+		})
+	}
+	return results, nil
+}
+
+// runTavilySearch posts a search request to the Tavily endpoint and maps the
+// returned results to webSearchResult values (the result's content becomes the
+// description).
+//
+// The optional tool arguments country and search_depth/searchDepth are
+// forwarded in the request body when set. The request is retried at most once
+// and is bounded by the resolved web-search timeout.
+//
+// An error is returned when the API key is missing, the request cannot be
+// built or sent, the endpoint answers with a status >= 400 (the error carries
+// the status code and up to 2 KiB of the response body), or the response body
+// is not valid JSON.
+func runTavilySearch(ctx context.Context, payload toolArgs, config map[string]any, query string, count int) ([]webSearchResult, error) {
+	apiKey := strings.TrimSpace(resolveWebSearchProviderAPIKey(config, "tavily"))
+	if apiKey == "" {
+		return nil, fmt.Errorf("Tavily API key is missing")
+	}
+	requestBody := map[string]any{
+		"api_key":     apiKey,
+		"query":       query,
+		"max_results": count,
+	}
+	if value := getStringArg(payload, "country"); value != "" {
+		requestBody["country"] = value
+	}
+	if value := getStringArg(payload, "search_depth", "searchDepth"); value != "" {
+		requestBody["search_depth"] = value
+	}
+	bodyBytes, err := json.Marshal(requestBody)
+	if err != nil {
+		return nil, err
+	}
+	request, err := retryablehttp.NewRequestWithContext(ctx, http.MethodPost, tavilySearchEndpoint, bytes.NewReader(bodyBytes))
+	if err != nil {
+		return nil, err
+	}
+	request.Header.Set("Accept", "application/json")
+	request.Header.Set("Content-Type", "application/json")
+
+	client := retryablehttp.NewClient()
+	client.RetryMax = 1
+	// Silence the default retryablehttp logger so request details are not
+	// written to the process output on every search.
+	client.Logger = nil
+	client.HTTPClient.Timeout = time.Duration(resolveWebSearchTimeoutSeconds(payload, config)) * time.Second
+	response, err := client.Do(request)
+	if err != nil {
+		return nil, err
+	}
+	defer response.Body.Close()
+	if response.StatusCode >= 400 {
+		// Only a bounded prefix of the body is read for the error message.
+		body, _ := io.ReadAll(io.LimitReader(response.Body, 2048))
+		return nil, fmt.Errorf("Tavily search failed (HTTP %d): %s", response.StatusCode, strings.TrimSpace(string(body)))
+	}
+	var payloadJSON tavilySearchResponse
+	if err := json.NewDecoder(response.Body).Decode(&payloadJSON); err != nil {
+		return nil, err
+	}
+	results := make([]webSearchResult, 0, len(payloadJSON.Results))
+	for _, item := range payloadJSON.Results {
+		description := strings.TrimSpace(item.Content)
+		results = append(results, webSearchResult{
+			Title:       strings.TrimSpace(item.Title),
+			URL:         strings.TrimSpace(item.URL),
+			Description: description,
+		})
+	}
+	return results, nil
+}
+
+// loadWebSearchCache returns the cached response stored under key and true,
+// or a zero response and false when there is no entry or the entry has
+// expired. An expired entry is removed on the way out so that stale keys do
+// not accumulate in the cache.
+func loadWebSearchCache(key string) (webSearchResponse, bool) {
+	webSearchCache.mu.RLock()
+	entry, ok := webSearchCache.entries[key]
+	webSearchCache.mu.RUnlock()
+	if !ok {
+		return webSearchResponse{}, false
+	}
+	if !time.Now().After(entry.expiresAt) {
+		return entry.value, true
+	}
+
+	// Expired. Re-check under the write lock and only delete the entry that
+	// was observed above, so a fresh value stored concurrently is kept.
+	webSearchCache.mu.Lock()
+	if current, exists := webSearchCache.entries[key]; exists && current.expiresAt.Equal(entry.expiresAt) {
+		delete(webSearchCache.entries, key)
+	}
+	webSearchCache.mu.Unlock()
+	return webSearchResponse{}, false
+}
+
+// storeWebSearchCache records value under key with a lifetime of ttl.
+// A non-positive ttl disables caching: nothing is stored.
+func storeWebSearchCache(key string, value webSearchResponse, ttl time.Duration) {
+	if ttl > 0 {
+		webSearchCache.mu.Lock()
+		defer webSearchCache.mu.Unlock()
+
+		fresh := webSearchCacheEntry{
+			value:     value,
+			expiresAt: time.Now().Add(ttl),
+		}
+		webSearchCache.entries[key] = fresh
+	}
+}
+
+func buildWebFetchToolResult(rawURL string, response webFetchResponse, err error) webFetchResult {
if err != nil {
return webFetchResult{
Status: webStatusError,
- Retryable: isTimeoutError(err),
- NextAction: nextActionInspectErrorThenSwitch,
- Message: trimToMaxChars(strings.TrimSpace(err.Error()), 260),
+ Retryable: false,
+ NextAction: nextActionUseOtherToolsOrSkills,
+ Message: err.Error(),
Provider: "web_fetch",
Quality: webQualityEmpty,
WebSearchAvailable: true,
- URL: strings.TrimSpace(requestURL),
- FinalURL: finalURL,
+ URL: rawURL,
+ FinalURL: response.FinalURL,
HTTPStatus: response.Status,
ContentType: response.ContentType,
Content: response.Content,
@@ -338,14 +594,11 @@ func buildWebFetchToolResult(requestURL string, response webFetchResponse, err e
ContentSignal: response.ContentSignal,
Truncated: response.Truncated,
Data: map[string]any{
- "error": trimToMaxChars(strings.TrimSpace(err.Error()), 260),
- "httpStatus": response.Status,
- "finalURL": finalURL,
"timeoutStage": response.TimeoutStage,
- "truncated": response.Truncated,
},
}
}
+
quality := webQualitySufficient
if strings.TrimSpace(response.Content) == "" {
quality = webQualityEmpty
@@ -358,8 +611,8 @@ func buildWebFetchToolResult(requestURL string, response webFetchResponse, err e
Provider: "web_fetch",
Quality: quality,
WebSearchAvailable: true,
- URL: strings.TrimSpace(requestURL),
- FinalURL: finalURL,
+ URL: rawURL,
+ FinalURL: response.FinalURL,
HTTPStatus: response.Status,
ContentType: response.ContentType,
Content: response.Content,
@@ -367,1326 +620,367 @@ func buildWebFetchToolResult(requestURL string, response webFetchResponse, err e
ContentSignal: response.ContentSignal,
Truncated: response.Truncated,
Data: map[string]any{
- "httpStatus": response.Status,
- "finalURL": finalURL,
- "timeoutStage": response.TimeoutStage,
- "truncated": response.Truncated,
+ "extractor": response.ContentSignal,
+ "browserSource": webFetchTypeCDP,
+ "timeoutStage": response.TimeoutStage,
},
}
}
-func runWebSearchByAPI(ctx context.Context, payload toolArgs, config map[string]any, query string, count int) (webSearchResponse, error) {
- provider := strings.ToLower(getNestedString(config, "web", "search", "provider"))
- if provider == "" {
- provider = "brave"
+// resolveWebFetchType picks the web_fetch backend type. A "type" (or "mode")
+// tool argument wins over the same key in the "web_fetch" config section; when
+// neither is set, defaultWebFetchType is returned. A value that
+// normalizeWebFetchType does not recognise yields an "unsupported web_fetch
+// type" error.
+// NOTE(review): no caller is visible in this chunk (runWebFetchTool no longer
+// calls it) - confirm it is still needed.
+func resolveWebFetchType(payload toolArgs, config map[string]any) (string, error) {
+ if raw := strings.TrimSpace(getStringArg(payload, "type", "mode")); raw != "" {
+ normalized := normalizeWebFetchType(raw)
+ if normalized == "" {
+ return "", fmt.Errorf("unsupported web_fetch type: %s", raw)
+ }
+ return normalized, nil
}
- country := resolveWebSearchString(payload, config, "country", "US")
- searchLang := resolveWebSearchString(payload, config, "search_lang", "")
- uiLang := resolveWebSearchString(payload, config, "ui_lang", "")
- freshness := resolveWebSearchString(payload, config, "freshness", "")
- cacheTtlMinutes := resolveWebSearchInt(config, "cacheTtlMinutes", defaultWebSearchCacheTtlMinutes)
- cacheKey := normalizeWebSearchCacheKey(provider, query, count, country, searchLang, uiLang, freshness)
- if cacheTtlMinutes > 0 {
- if cached, ok := readWebSearchCache(cacheKey); ok {
- cached.Cached = true
- return cached, nil
+ if fetchConfig := getNestedMap(config, "web_fetch"); fetchConfig != nil {
+ if raw := strings.TrimSpace(getStringArg(toolArgs(fetchConfig), "type", "mode")); raw != "" {
+ normalized := normalizeWebFetchType(raw)
+ if normalized == "" {
+ return "", fmt.Errorf("unsupported web_fetch type: %s", raw)
+ }
+ return normalized, nil
}
}
- timeoutSeconds := resolveWebSearchInt(config, "timeoutSeconds", defaultWebSearchTimeoutSeconds)
- timeoutCtx, cancel := context.WithTimeout(ctx, time.Duration(timeoutSeconds)*time.Second)
- defer cancel()
+ return defaultWebFetchType, nil
+}
- var (
- response webSearchResponse
- err error
- )
- switch provider {
- case "brave":
- response, err = runBraveSearch(timeoutCtx, config, query, count, country, searchLang, uiLang, freshness, timeoutSeconds)
- case "tavily":
- response, err = runTavilySearch(timeoutCtx, config, payload, query, count, timeoutSeconds)
+// normalizeWebFetchType maps a user-supplied type name to webFetchTypeCDP,
+// ignoring case and surrounding whitespace. The empty string and the aliases
+// "cdp", "chrome", "chromium", "browser" and "builtin" are accepted; any other
+// value returns "" to signal an unsupported type.
+func normalizeWebFetchType(value string) string {
+ switch strings.ToLower(strings.TrimSpace(value)) {
+ case "", "cdp", "chrome", "chromium", "browser", "builtin":
+ return webFetchTypeCDP
default:
- return webSearchResponse{}, errors.New("web_search provider not implemented: " + provider)
- }
- if err != nil {
- return webSearchResponse{}, err
- }
- response.Query = query
- response.Provider = provider
- if cacheTtlMinutes > 0 {
- writeWebSearchCache(cacheKey, response, time.Duration(cacheTtlMinutes)*time.Minute)
+ return ""
}
- return response, nil
}
-func resolveWebSearchCount(payload toolArgs, config map[string]any) int {
- count, ok := getIntArg(payload, "count", "maxResults", "max_results")
- if !ok || count <= 0 {
- count = resolveWebSearchInt(config, "maxResults", defaultWebSearchCount)
+// resolveWebFetchOptions computes the effective web_fetch limits.
+//
+// Each field starts from its default (fallbackTimeoutSeconds,
+// defaultWebFetchMaxChars, defaultWebFetchMaxBodyBytes), is overridden by the
+// matching key in the "web_fetch" config section, and finally by the matching
+// tool argument. Only positive integers override; zero, negative or missing
+// values leave the previous value in place.
+func resolveWebFetchOptions(payload toolArgs, config map[string]any, fallbackTimeoutSeconds int) webFetchOptions {
+ fetchConfig := getNestedMap(config, "web_fetch")
+ timeoutSeconds := fallbackTimeoutSeconds
+ if value, ok := getIntArg(toolArgs(fetchConfig), "timeoutSeconds"); ok && value > 0 {
+ timeoutSeconds = value
}
- if count <= 0 {
- count = defaultWebSearchCount
+ if value, ok := getIntArg(payload, "timeoutSeconds"); ok && value > 0 {
+ timeoutSeconds = value
}
- if count > maxWebSearchCount {
- count = maxWebSearchCount
+ maxChars := defaultWebFetchMaxChars
+ if value, ok := getIntArg(toolArgs(fetchConfig), "maxChars"); ok && value > 0 {
+ maxChars = value
}
- if count < 1 {
- count = 1
+ if value, ok := getIntArg(payload, "maxChars"); ok && value > 0 {
+ maxChars = value
}
- return count
-}
-
-func resolveWebSearchString(payload toolArgs, config map[string]any, key string, fallback string) string {
- value := getStringArg(payload, key)
- if value == "" {
- value = getNestedString(config, "web", "search", key)
+ maxBodyBytes := defaultWebFetchMaxBodyBytes
+ if value, ok := getIntArg(toolArgs(fetchConfig), "maxBodyBytes"); ok && value > 0 {
+ maxBodyBytes = value
}
- if value == "" {
- return fallback
+ if value, ok := getIntArg(payload, "maxBodyBytes"); ok && value > 0 {
+ maxBodyBytes = value
}
- return value
-}
-func normalizeWebSearchType(value string) string {
- normalized := strings.ToLower(strings.TrimSpace(value))
- switch normalized {
- case "api":
- return "api"
- case "external_tools", "external-tools", "external tools":
- return "external_tools"
- default:
- return ""
+ return webFetchOptions{
+ TimeoutSeconds: timeoutSeconds,
+ MaxChars: maxChars,
+ MaxBodyBytes: maxBodyBytes,
}
}
-func resolveWebSearchType(config map[string]any) string {
- if value := normalizeWebSearchType(getNestedString(config, "web", "search", "type")); value != "" {
+// connectorTypeForURL reports the connector type associated with rawURL.
+// It is a thin local alias for browsercdp.ConnectorTypeForURL.
+func connectorTypeForURL(rawURL string) string {
+	connectorType := browsercdp.ConnectorTypeForURL(rawURL)
+	return connectorType
+}
+
+// resolveWebFetchConfigBool reads the boolean stored under key in the
+// "web_fetch" section of config. The second result is the presence flag
+// reported by getBoolArg, which lets callers tell "unset" from false.
+func resolveWebFetchConfigBool(config map[string]any, key string) (bool, bool) {
+	value, found := getBoolArg(toolArgs(getNestedMap(config, "web_fetch")), key)
+	return value, found
+}
+
+// resolveWebFetchPreferredBrowser returns the browser to launch for web_fetch:
+// the trimmed "preferredBrowser" value of the "web_fetch" config section when
+// it is non-empty, otherwise the trimmed "preferredBrowser" value of the
+// "browser" section (which may be empty).
+func resolveWebFetchPreferredBrowser(config map[string]any) string {
+ fetchConfig := getNestedMap(config, "web_fetch")
+ if value := strings.TrimSpace(getStringArg(toolArgs(fetchConfig), "preferredBrowser")); value != "" {
return value
}
- if enabled, ok := getNestedBool(config, "web", "search", "enabled"); ok && enabled {
- return "api"
- }
- return defaultWebSearchType
+ browserConfig := getNestedMap(config, "browser")
+ return strings.TrimSpace(getStringArg(toolArgs(browserConfig), "preferredBrowser"))
}
-func normalizeConnectorType(value string) string {
- normalized := strings.ToLower(strings.TrimSpace(value))
- switch normalized {
- case "google":
- return "google"
- case "xiaohongshu", "xhs":
- return "xiaohongshu"
- case "bilibili", "b23":
- return "bilibili"
- default:
- return ""
+// resolveWebFetchHeadless reports whether the web_fetch browser should run
+// headless: the "headless" value of the "web_fetch" config section when it is
+// set, true otherwise.
+func resolveWebFetchHeadless(config map[string]any) bool {
+ fetchConfig := getNestedMap(config, "web_fetch")
+ if value, ok := getBoolArg(toolArgs(fetchConfig), "headless"); ok {
+ return value
}
+ return true
}
-func normalizeWebFetchType(value string) string {
- switch strings.ToLower(strings.TrimSpace(value)) {
- case webFetchTypePlaywright:
- return webFetchTypePlaywright
- case webFetchTypeBuiltin:
- return webFetchTypeBuiltin
- default:
- return ""
- }
-}
-
-func resolveWebFetchType(payload toolArgs, config map[string]any) (string, error) {
- if raw := strings.TrimSpace(getStringArg(payload, "type", "mode")); raw != "" {
- if value := normalizeWebFetchType(raw); value != "" {
- return value, nil
- }
- return "", fmt.Errorf("unsupported web_fetch type: %s", raw)
- }
- fetchConfig := toolArgs(resolveWebFetchConfig(config))
- if raw := strings.TrimSpace(getStringArg(fetchConfig, "type", "mode")); raw != "" {
- if value := normalizeWebFetchType(raw); value != "" {
- return value, nil
- }
- return "", fmt.Errorf("unsupported web_fetch type: %s", raw)
- }
- return defaultWebFetchType, nil
-}
-
-func resolveWebFetchPlaywrightOptions(payload toolArgs, config map[string]any) webFetchPlaywrightOptions {
- markdown := defaultWebFetchPlaywrightMarkdown
- fetchConfig := toolArgs(resolveWebFetchConfig(config))
- if playwrightConfig := getMapArg(fetchConfig, "playwright"); playwrightConfig != nil {
- if value, ok := getBoolArg(toolArgs(playwrightConfig), "markdown", "toMarkdown"); ok {
- markdown = value
- }
- }
- if payloadPlaywright := getMapArg(payload, "playwright"); payloadPlaywright != nil {
- if value, ok := getBoolArg(toolArgs(payloadPlaywright), "markdown", "toMarkdown"); ok {
- markdown = value
- }
- }
- if value, ok := getBoolArg(payload, "markdown", "toMarkdown"); ok {
- markdown = value
- }
- return webFetchPlaywrightOptions{
- Markdown: markdown,
- }
-}
-
-func fetchByWebFetchType(
- ctx context.Context,
- fetchType string,
- method string,
- targetURL string,
- cookies []connectorsdto.ConnectorCookie,
- options webFetchOptions,
- playwrightOptions webFetchPlaywrightOptions,
-) (webFetchResponse, error) {
- switch normalizeWebFetchType(fetchType) {
- case webFetchTypePlaywright:
- if strings.ToUpper(strings.TrimSpace(method)) != http.MethodGet {
- return webFetchResponse{}, errors.New("web_fetch playwright mode only supports GET")
- }
- return fetchWithPlaywrightOptions(ctx, targetURL, cookies, options, playwrightOptions)
- case webFetchTypeBuiltin:
- return fetchWithBuiltinOptions(ctx, method, targetURL, cookies, options)
- default:
- return fetchWithBuiltinOptions(ctx, method, targetURL, cookies, options)
- }
-}
-
-func resolveConnectorCookiesForURL(
- ctx context.Context,
- connectors ConnectorsReader,
- targetURL string,
-) ([]connectorsdto.ConnectorCookie, error) {
- if connectors == nil {
- return nil, nil
- }
- connectorType := connectorTypeForURL(targetURL)
- if connectorType == "" {
- return nil, nil
- }
- items, err := connectors.ListConnectors(ctx)
+func extractMainContent(rawHTML string, pageURL string) extractorResult {
+ policy, _ := sitepolicy.ForURL(pageURL)
+ reader := strings.NewReader(rawHTML)
+ doc, err := goquery.NewDocumentFromReader(reader)
if err != nil {
- return nil, err
- }
- for _, item := range items {
- if normalizeConnectorType(item.Type) != connectorType {
- continue
+ return extractorResult{
+ Content: compactText(rawHTML),
+ ContentSignal: "body_fallback",
}
- if len(item.Cookies) == 0 {
- continue
- }
- return item.Cookies, nil
- }
- return nil, nil
-}
-
-func connectorTypeForURL(targetURL string) string {
- parsed, err := url.Parse(strings.TrimSpace(targetURL))
- if err != nil {
- return ""
- }
- host := strings.ToLower(strings.TrimSpace(parsed.Hostname()))
- if host == "" {
- return ""
- }
- switch {
- case hostMatchesDomain(host, "google.com"), hostMatchesDomain(host, "youtube.com"), hostMatchesDomain(host, "youtu.be"):
- return "google"
- case hostMatchesDomain(host, "xiaohongshu.com"), hostMatchesDomain(host, "xhslink.com"), hostMatchesDomain(host, "redbook.com"):
- return "xiaohongshu"
- case hostMatchesDomain(host, "bilibili.com"), hostMatchesDomain(host, "b23.tv"):
- return "bilibili"
- default:
- return ""
- }
-}
-
-func hostMatchesDomain(host string, domain string) bool {
- normalizedHost := strings.TrimPrefix(strings.ToLower(strings.TrimSpace(host)), ".")
- normalizedDomain := strings.TrimPrefix(strings.ToLower(strings.TrimSpace(domain)), ".")
- if normalizedHost == "" || normalizedDomain == "" {
- return false
}
- return normalizedHost == normalizedDomain || strings.HasSuffix(normalizedHost, "."+normalizedDomain)
-}
-
-func fetchWithPlaywrightOptions(
- ctx context.Context,
- targetURL string,
- cookies []connectorsdto.ConnectorCookie,
- options webFetchOptions,
- playwrightOptions webFetchPlaywrightOptions,
-) (webFetchResponse, error) {
- if strings.TrimSpace(targetURL) == "" {
- return webFetchResponse{}, errors.New("target url is required")
- }
- pw, err := playwright.Run()
- if err != nil {
- return webFetchResponse{}, err
- }
- defer pw.Stop()
-
- browser, err := pw.Chromium.Launch(playwright.BrowserTypeLaunchOptions{
- Headless: playwright.Bool(true),
- Args: []string{
- "--headless=new",
- },
- })
- if err != nil {
- return webFetchResponse{}, err
- }
- defer browser.Close()
-
- contextOptions := playwright.BrowserNewContextOptions{
- Viewport: &playwright.Size{
- Width: 1366,
- Height: 900,
- },
- TimezoneId: playwright.String("UTC"),
- }
- if options.EnableUserAgent {
- userAgent := strings.TrimSpace(options.UserAgent)
- if userAgent == "" {
- userAgent = defaultWebFetchUserAgent
- }
- contextOptions.UserAgent = playwright.String(userAgent)
- }
- if locale := localeFromAcceptLanguage(options.AcceptLanguage); locale != "" {
- contextOptions.Locale = playwright.String(locale)
- }
- if headers := webFetchExtraHTTPHeaders(options); len(headers) > 0 {
- contextOptions.ExtraHttpHeaders = headers
- }
- browserCtx, err := browser.NewContext(contextOptions)
- if err != nil {
- return webFetchResponse{}, err
+ for _, selector := range append(defaultRemoveSelectors, policy.RemoveSelectors...) {
+ doc.Find(selector).Each(func(_ int, selection *goquery.Selection) {
+ selection.Remove()
+ })
}
- defer browserCtx.Close()
- if len(cookies) > 0 {
- if err := browserCtx.AddCookies(toPlaywrightCookies(cookies, targetURL)); err != nil {
- return webFetchResponse{}, err
+ root := pickBestContentRoot(doc, policy)
+ if root == nil || root.Length() == 0 {
+ body := strings.TrimSpace(doc.Find("body").First().Text())
+ return extractorResult{
+ Content: compactMarkdown(body),
+ ContentSignal: "body_fallback",
}
}
- page, err := browserCtx.NewPage()
- if err != nil {
- return webFetchResponse{}, err
- }
-
- timeoutMs := float64(resolveWebFetchTimeoutSeconds(options.TimeoutSeconds) * 1000)
- navTimeoutMs := timeoutMs * 0.6
- if navTimeoutMs < 1000 {
- navTimeoutMs = timeoutMs
- }
- readyTimeoutMs := timeoutMs * 0.25
- if readyTimeoutMs < 600 {
- readyTimeoutMs = 600
- }
- extractTimeoutMs := timeoutMs * 0.15
- if extractTimeoutMs < 500 {
- extractTimeoutMs = 500
- }
- response, err := page.Goto(strings.TrimSpace(targetURL), playwright.PageGotoOptions{
- Timeout: playwright.Float(navTimeoutMs),
- WaitUntil: playwright.WaitUntilStateDomcontentloaded,
- })
+ htmlFragment, err := root.Html()
if err != nil {
- return webFetchResponse{}, err
- }
-
- timeoutStage := ""
- finalURL := strings.TrimSpace(page.URL())
- if finalURL == "" {
- finalURL = strings.TrimSpace(targetURL)
- }
- if selector := resolvePlaywrightReadySelector(finalURL, targetURL); selector != "" {
- if readyErr := waitForPlaywrightReady(page, selector, readyTimeoutMs); readyErr != nil {
- timeoutStage = "ready"
+ text := compactMarkdown(root.Text())
+ if text == "" {
+ return extractorResult{ContentSignal: "body_fallback"}
}
- }
-
- status := http.StatusOK
- contentType := ""
- headers := map[string]string{}
- if response != nil {
- status = response.Status()
- headers = normalizeHTTPHeaderMap(response.Headers())
- contentType = strings.TrimSpace(headers["content-type"])
- if contentType == "" {
- allHeaders, err := response.AllHeaders()
- if err == nil {
- for key, value := range allHeaders {
- if _, exists := headers[strings.ToLower(strings.TrimSpace(key))]; !exists {
- headers[strings.ToLower(strings.TrimSpace(key))] = strings.TrimSpace(value)
- }
- }
- contentType = strings.TrimSpace(headers["content-type"])
- }
- }
- }
-
- content, err := contentWithTimeout(ctx, page, time.Duration(extractTimeoutMs)*time.Millisecond)
- if err != nil {
- return webFetchResponse{}, err
- }
- if playwrightOptions.Markdown {
- if markdown, err := convertHTMLToMarkdown(content, finalURL); err == nil && strings.TrimSpace(markdown) != "" {
- content = markdown
- contentType = "text/markdown; charset=utf-8"
+ return extractorResult{
+ Content: text,
+ ContentSignal: defaultWebFetchContentSignalMain,
}
}
- content, truncated := truncateWebFetchContent([]byte(content), options.MaxChars)
-
- return webFetchResponse{
- URL: strings.TrimSpace(targetURL),
- FinalURL: finalURL,
- Status: status,
- Headers: headers,
- ContentType: contentType,
- Content: content,
- TimeoutStage: timeoutStage,
- Truncated: truncated,
- }, nil
-}
-func fetchWithBuiltinOptions(
- ctx context.Context,
- method string,
- targetURL string,
- cookies []connectorsdto.ConnectorCookie,
- options webFetchOptions,
-) (webFetchResponse, error) {
- if strings.TrimSpace(targetURL) == "" {
- return webFetchResponse{}, errors.New("target url is required")
- }
- httpMethod := strings.ToUpper(strings.TrimSpace(method))
- if httpMethod == "" {
- httpMethod = http.MethodGet
- }
- request, err := retryablehttp.NewRequestWithContext(ctx, httpMethod, targetURL, nil)
+ converter := md.NewConverter("", true, nil)
+ markdown, err := converter.ConvertString(htmlFragment)
if err != nil {
- return webFetchResponse{}, err
- }
- applyWebFetchHeaders(request.Request, options)
- if cookieHeader := buildCookieHeader(targetURL, cookies); cookieHeader != "" {
- if existing := strings.TrimSpace(request.Header.Get("Cookie")); existing != "" {
- request.Header.Set("Cookie", existing+"; "+cookieHeader)
- } else {
- request.Header.Set("Cookie", cookieHeader)
- }
+ markdown = compactMarkdown(root.Text())
}
- timeoutSeconds := resolveWebFetchTimeoutSeconds(options.TimeoutSeconds)
- maxRedirects := resolveWebFetchMaxRedirects(options.MaxRedirects)
- retryMax := resolveWebFetchRetryMax(options.RetryMax)
- client := retryablehttp.NewClient()
- client.RetryMax = retryMax
- client.RetryWaitMin = 200 * time.Millisecond
- client.RetryWaitMax = 2 * time.Second
- client.Logger = nil
- client.CheckRetry = func(ctx context.Context, resp *http.Response, err error) (bool, error) {
- if !isWebFetchRetryableMethod(httpMethod) {
- return false, err
- }
- return retryablehttp.DefaultRetryPolicy(ctx, resp, err)
+ markdown = compactMarkdown(markdown)
+ if markdown == "" {
+ markdown = compactMarkdown(root.Text())
}
- client.HTTPClient.Timeout = time.Duration(timeoutSeconds) * time.Second
- client.HTTPClient.CheckRedirect = func(_ *http.Request, via []*http.Request) error {
- if len(via) > maxRedirects {
- return errors.New("too many redirects")
+ if markdown == "" {
+ return extractorResult{
+ Content: compactMarkdown(doc.Text()),
+ ContentSignal: "body_fallback",
}
- return nil
- }
- resp, err := client.Do(request)
- if err != nil {
- return webFetchResponse{}, err
}
- defer resp.Body.Close()
- body, bodyTruncated, err := readWebFetchBodyLimited(resp.Body, options.MaxBodyBytes)
- if err != nil {
- return webFetchResponse{}, err
- }
- content, charTruncated := truncateWebFetchContent(body, options.MaxChars)
- markdownTokens, _ := strconv.Atoi(strings.TrimSpace(resp.Header.Get("x-markdown-tokens")))
- finalURL := strings.TrimSpace(targetURL)
- if resp.Request != nil && resp.Request.URL != nil {
- if resolved := strings.TrimSpace(resp.Request.URL.String()); resolved != "" {
- finalURL = resolved
- }
- }
- return webFetchResponse{
- URL: strings.TrimSpace(targetURL),
- FinalURL: finalURL,
- Status: resp.StatusCode,
- Headers: normalizeHTTPHeaderValues(resp.Header),
- ContentType: strings.TrimSpace(resp.Header.Get("Content-Type")),
- Content: content,
- MarkdownTokens: markdownTokens,
- ContentSignal: strings.TrimSpace(resp.Header.Get("content-signal")),
- Truncated: bodyTruncated || charTruncated,
- }, nil
-}
-
-func isWebFetchRetryableMethod(method string) bool {
- switch strings.ToUpper(strings.TrimSpace(method)) {
- case http.MethodGet, http.MethodHead, http.MethodOptions:
- return true
- default:
- return false
- }
-}
-
-func resolveWebFetchTimeoutSeconds(value int) int {
- timeoutSeconds := value
- if timeoutSeconds <= 0 {
- timeoutSeconds = defaultWebFetchTimeoutSeconds
- }
- return timeoutSeconds
-}
-
-func resolveWebFetchMaxRedirects(value int) int {
- maxRedirects := value
- if maxRedirects < 0 {
- maxRedirects = 0
- }
- return maxRedirects
-}
-
-func resolveWebFetchRetryMax(value int) int {
- retryMax := value
- if retryMax < 0 {
- retryMax = 0
+ return extractorResult{
+ Content: markdown,
+ ContentSignal: resolveContentSignal(policy, root),
}
- return retryMax
}
-func webFetchExtraHTTPHeaders(options webFetchOptions) map[string]string {
- result := make(map[string]string)
- for key, value := range options.Headers {
- trimmedKey := strings.TrimSpace(key)
- if trimmedKey == "" {
- continue
- }
- result[trimmedKey] = toString(value)
- }
- if _, exists := result["Accept"]; !exists {
- if options.AcceptMarkdown {
- result["Accept"] = "text/markdown, text/html;q=0.9, application/xhtml+xml;q=0.8"
- } else {
- result["Accept"] = "text/html,application/xhtml+xml"
+func pickBestContentRoot(doc *goquery.Document, policy sitepolicy.Policy) *goquery.Selection {
+ for _, selector := range policy.ExtractorSelectors {
+ selection := doc.Find(selector).First()
+ if selection.Length() > 0 && len(compactText(selection.Text())) >= 160 {
+ return selection
}
}
- if _, exists := result["Accept-Language"]; !exists && strings.TrimSpace(options.AcceptLanguage) != "" {
- result["Accept-Language"] = strings.TrimSpace(options.AcceptLanguage)
- }
- return result
-}
-
-func localeFromAcceptLanguage(value string) string {
- segments := strings.Split(strings.TrimSpace(value), ",")
- for _, segment := range segments {
- base := strings.TrimSpace(segment)
- if base == "" {
- continue
+ for _, selector := range defaultExtractorSelectors {
+ selection := doc.Find(selector).First()
+ if selection.Length() > 0 && len(compactText(selection.Text())) >= 160 {
+ return selection
}
- if index := strings.Index(base, ";"); index >= 0 {
- base = strings.TrimSpace(base[:index])
- }
- if base != "" {
- return base
- }
- }
- return ""
-}
-
-func convertHTMLToMarkdown(content string, targetURL string) (string, error) {
- converter := md.NewConverter(md.DomainFromURL(strings.TrimSpace(targetURL)), true, nil)
- markdown, err := converter.ConvertString(content)
- if err != nil {
- return "", err
- }
- return strings.TrimSpace(markdown), nil
-}
-
-func resolveWebFetchOptions(payload toolArgs, config map[string]any, timeoutFallback int) webFetchOptions {
- fetchConfig := toolArgs(resolveWebFetchConfig(config))
- timeoutSeconds, ok := getIntArg(fetchConfig, "timeoutSeconds")
- if !ok || timeoutSeconds <= 0 {
- timeoutSeconds, _ = getIntArg(payload, "timeoutSeconds")
- }
- if timeoutSeconds <= 0 {
- timeoutSeconds = timeoutFallback
- }
- if timeoutSeconds <= 0 {
- timeoutSeconds = defaultWebFetchTimeoutSeconds
- }
-
- maxChars, ok := getIntArg(fetchConfig, "maxChars")
- if !ok || maxChars <= 0 {
- maxChars, _ = getIntArg(payload, "maxChars")
- }
- if maxChars <= 0 {
- maxChars = defaultWebFetchMaxChars
- }
-
- maxBodyBytes, ok := getIntArg(fetchConfig, "maxBodyBytes")
- if !ok || maxBodyBytes <= 0 {
- maxBodyBytes, _ = getIntArg(payload, "maxBodyBytes")
- }
- if maxBodyBytes <= 0 {
- maxBodyBytes = defaultWebFetchMaxBodyBytes
}
- maxRedirects := defaultWebFetchMaxRedirects
- if value, present := getIntArg(fetchConfig, "maxRedirects"); present {
- maxRedirects = value
- }
- if value, present := getIntArg(payload, "maxRedirects"); present {
- maxRedirects = value
- }
- if maxRedirects < 0 {
- maxRedirects = 0
- }
-
- retryMax := defaultWebFetchRetryMax
- if value, present := getIntArg(fetchConfig, "retryMax"); present {
- retryMax = value
- }
- if value, present := getIntArg(payload, "retryMax"); present {
- retryMax = value
- }
- if retryMax < 0 {
- retryMax = 0
- }
-
- acceptMarkdown, ok := getBoolArg(fetchConfig, "acceptMarkdown")
- if !ok {
- acceptMarkdown = defaultWebFetchAcceptMarkdown
- }
- if value, present := getBoolArg(payload, "acceptMarkdown", "preferMarkdown", "markdown"); present {
- acceptMarkdown = value
- }
-
- enableUserAgent, ok := getBoolArg(fetchConfig, "enableUserAgent")
- if !ok {
- enableUserAgent = defaultWebFetchEnableUserAgent
- }
- if value, present := getBoolArg(payload, "enableUserAgent", "useUserAgent"); present {
- enableUserAgent = value
- }
-
- userAgent := getStringArg(payload, "userAgent")
- if userAgent == "" {
- userAgent = getStringArg(fetchConfig, "userAgent")
- }
- if userAgent == "" {
- userAgent = defaultWebFetchUserAgent
- }
-
- acceptLanguage := getStringArg(payload, "acceptLanguage")
- if acceptLanguage == "" {
- acceptLanguage = getStringArg(fetchConfig, "acceptLanguage")
- }
- if acceptLanguage == "" {
- acceptLanguage = defaultWebFetchAcceptLanguage
- }
- headers := mergeAnyMap(getMapArg(fetchConfig, "headers"), getMapArg(payload, "headers"))
-
- return webFetchOptions{
- TimeoutSeconds: timeoutSeconds,
- MaxChars: maxChars,
- MaxBodyBytes: maxBodyBytes,
- MaxRedirects: maxRedirects,
- RetryMax: retryMax,
- AcceptMarkdown: acceptMarkdown,
- EnableUserAgent: enableUserAgent,
- UserAgent: strings.TrimSpace(userAgent),
- AcceptLanguage: strings.TrimSpace(acceptLanguage),
- Headers: headers,
- }
-}
-
-func resolveWebFetchConfig(config map[string]any) map[string]any {
- current := getNestedMap(config, "web_fetch")
- if current == nil {
- return nil
- }
- return current
-}
-
-func resolveWebFetchConfigBool(config map[string]any, key string) (bool, bool) {
- return getBoolArg(toolArgs(resolveWebFetchConfig(config)), key)
-}
-
-func applyWebFetchHeaders(request *http.Request, options webFetchOptions) {
- if request == nil {
- return
- }
- for key, value := range options.Headers {
- trimmedKey := strings.TrimSpace(key)
- if trimmedKey == "" {
- continue
- }
- request.Header.Set(trimmedKey, toString(value))
- }
- if strings.TrimSpace(request.Header.Get("Accept")) == "" {
- if options.AcceptMarkdown {
- request.Header.Set("Accept", "text/markdown, text/html;q=0.9, application/xhtml+xml;q=0.8")
- } else {
- request.Header.Set("Accept", "text/html,application/xhtml+xml")
- }
- }
- if strings.TrimSpace(request.Header.Get("User-Agent")) == "" && options.EnableUserAgent {
- userAgent := strings.TrimSpace(options.UserAgent)
- if userAgent == "" {
- userAgent = defaultWebFetchUserAgent
+ var (
+ best *goquery.Selection
+ bestScore float64
+ )
+ doc.Find("article,main,section,div").Each(func(_ int, selection *goquery.Selection) {
+ text := compactText(selection.Text())
+ textLen := len(text)
+ if textLen < 160 {
+ return
+ }
+ linkTextLen := len(compactText(selection.Find("a").Text()))
+ linkDensity := 0.0
+ if textLen > 0 {
+ linkDensity = float64(linkTextLen) / float64(textLen)
+ }
+ paragraphs := selection.Find("p").Length()
+ score := float64(textLen) + float64(paragraphs*80) - (linkDensity * 800)
+ if score > bestScore {
+ bestScore = score
+ best = selection
}
- request.Header.Set("User-Agent", userAgent)
- } else if strings.TrimSpace(request.Header.Get("User-Agent")) == "" && !options.EnableUserAgent {
- // Prevent net/http from injecting the default Go user-agent when disabled.
- request.Header.Set("User-Agent", "")
+ })
+ if best != nil {
+ return best
}
- if strings.TrimSpace(request.Header.Get("Accept-Language")) == "" && strings.TrimSpace(options.AcceptLanguage) != "" {
- request.Header.Set("Accept-Language", strings.TrimSpace(options.AcceptLanguage))
+ body := doc.Find("body").First()
+ if body.Length() > 0 {
+ return body
}
+ return nil
}
-func truncateWebFetchContent(body []byte, maxChars int) (string, bool) {
- content := string(body)
- if maxChars <= 0 {
- return content, false
+func resolveContentSignal(policy sitepolicy.Policy, root *goquery.Selection) string {
+ if root == nil || root.Length() == 0 {
+ return "body_fallback"
}
- runeCount := 0
- for index := range content {
- if runeCount >= maxChars {
- return content[:index], true
+ nodeName := goquery.NodeName(root)
+ switch nodeName {
+ case "article":
+ return "article_readability"
+ case "main":
+ return "main_heuristic"
+ default:
+ if policy.Key != "" {
+ return defaultWebFetchContentSignalMain
}
- runeCount++
- }
- if runeCount <= maxChars {
- return content, false
- }
- return content, false
-}
-
-func readWebFetchBodyLimited(reader io.Reader, maxBytes int) ([]byte, bool, error) {
- if reader == nil {
- return nil, false, nil
+ return defaultWebFetchContentSignalMain
}
- if maxBytes <= 0 {
- maxBytes = defaultWebFetchMaxBodyBytes
- }
- limited := io.LimitReader(reader, int64(maxBytes)+1)
- body, err := io.ReadAll(limited)
- if err != nil {
- return nil, false, err
- }
- if len(body) > maxBytes {
- return body[:maxBytes], true, nil
- }
- return body, false, nil
}
-func buildCookieHeader(targetURL string, cookies []connectorsdto.ConnectorCookie) string {
- if len(cookies) == 0 {
- return ""
- }
- parsed, err := url.Parse(targetURL)
- if err != nil {
- return ""
- }
- host := strings.ToLower(strings.TrimSpace(parsed.Hostname()))
- path := strings.TrimSpace(parsed.Path)
- if path == "" {
- path = "/"
- }
- now := time.Now().Unix()
- pairs := make([]string, 0, len(cookies))
- for _, cookie := range cookies {
- name := strings.TrimSpace(cookie.Name)
- if name == "" {
+func compactMarkdown(input string) string {
+ text := strings.ReplaceAll(input, "\r\n", "\n")
+ text = markdownCodeFencePattern.ReplaceAllStringFunc(text, func(code string) string {
+ return strings.TrimSpace(code)
+ })
+ lines := strings.Split(text, "\n")
+ result := make([]string, 0, len(lines))
+ previousBlank := false
+ seen := map[string]struct{}{}
+ for _, line := range lines {
+ trimmed := strings.TrimSpace(line)
+ if trimmed == "" {
+ if previousBlank {
+ continue
+ }
+ previousBlank = true
+ result = append(result, "")
continue
}
- if cookie.Expires > 0 && cookie.Expires < now {
+ previousBlank = false
+ normalized := htmlSpacePattern.ReplaceAllString(trimmed, " ")
+ if strings.HasPrefix(normalized, "Recommended") || strings.HasPrefix(normalized, "Related") {
continue
}
- if !cookieDomainMatches(host, cookie.Domain) {
- continue
+ if len(normalized) > 4000 {
+ normalized = normalized[:4000]
}
- if !cookiePathMatches(path, cookie.Path) {
+ key := strings.ToLower(normalized)
+ if _, exists := seen[key]; exists && len(normalized) > 48 {
continue
}
- pairs = append(pairs, name+"="+cookie.Value)
- }
- return strings.Join(pairs, "; ")
-}
-
-func toPlaywrightCookies(cookies []connectorsdto.ConnectorCookie, targetURL string) []playwright.OptionalCookie {
- if len(cookies) == 0 {
- return nil
- }
- result := make([]playwright.OptionalCookie, 0, len(cookies))
- for _, item := range cookies {
- name := strings.TrimSpace(item.Name)
- if name == "" {
- continue
- }
- cookie := playwright.OptionalCookie{
- Name: name,
- Value: item.Value,
- }
- domain := strings.TrimSpace(item.Domain)
- path := strings.TrimSpace(item.Path)
- if domain != "" {
- cookie.Domain = playwright.String(domain)
- } else if strings.TrimSpace(targetURL) != "" {
- cookie.URL = playwright.String(strings.TrimSpace(targetURL))
- }
- if path != "" {
- cookie.Path = playwright.String(path)
- }
- if item.Expires > 0 {
- cookie.Expires = playwright.Float(float64(item.Expires))
+ if len(normalized) > 48 {
+ seen[key] = struct{}{}
}
- cookie.HttpOnly = playwright.Bool(item.HttpOnly)
- cookie.Secure = playwright.Bool(item.Secure)
- if sameSite := toPlaywrightSameSite(item.SameSite); sameSite != nil {
- cookie.SameSite = sameSite
- }
- result = append(result, cookie)
+ result = append(result, normalized)
}
- return result
+ return strings.TrimSpace(strings.Join(result, "\n"))
}
-func toPlaywrightSameSite(value string) *playwright.SameSiteAttribute {
- switch strings.ToLower(strings.TrimSpace(value)) {
- case "lax":
- return playwright.SameSiteAttributeLax
- case "strict":
- return playwright.SameSiteAttributeStrict
- case "none":
- return playwright.SameSiteAttributeNone
- default:
- return nil
- }
-}
-
-func cookieDomainMatches(host string, cookieDomain string) bool {
- domain := strings.ToLower(strings.TrimSpace(cookieDomain))
- if domain == "" {
- return true
- }
- domain = strings.TrimPrefix(domain, ".")
- if domain == "" {
- return true
- }
- return host == domain || strings.HasSuffix(host, "."+domain)
+func compactText(input string) string {
+ text := html.UnescapeString(input)
+ text = htmlSpacePattern.ReplaceAllString(strings.TrimSpace(text), " ")
+ return text
}
-func cookiePathMatches(requestPath string, cookiePath string) bool {
- path := strings.TrimSpace(cookiePath)
- if path == "" || path == "/" {
- return true
+func estimateMarkdownTokens(content string) int {
+ trimmed := strings.TrimSpace(content)
+ if trimmed == "" {
+ return 0
}
- return strings.HasPrefix(requestPath, path)
+ return (len(trimmed) + 3) / 4
}
-func extractHTMLTitle(content string) string {
- if strings.TrimSpace(content) == "" {
- return ""
+func statusCodeFromResponse(response *network.Response) int {
+ if response == nil {
+ return 0
}
- matches := htmlTitlePattern.FindStringSubmatch(content)
- if len(matches) < 2 {
- return ""
- }
- return compactPlainText(matches[1], 120)
+ return int(response.Status)
}
-func extractHTMLSnippet(content string, limit int) string {
- if strings.TrimSpace(content) == "" {
+func extractWebPageTitle(content string, contentType string) string {
+ trimmed := strings.TrimSpace(content)
+ if trimmed == "" {
return ""
}
- sanitized := htmlScriptStylePattern.ReplaceAllString(content, " ")
- sanitized = htmlTagPattern.ReplaceAllString(sanitized, " ")
- return compactPlainText(sanitized, limit)
-}
-
-func extractWebPageTitle(content string, contentType string) string {
if strings.Contains(strings.ToLower(contentType), "markdown") {
- if title := extractMarkdownTitle(content); title != "" {
- return title
+ if matches := markdownFrontMatterTitlePattern.FindStringSubmatch(trimmed); len(matches) > 1 {
+ return strings.TrimSpace(matches[1])
}
- }
- if title := extractHTMLTitle(content); title != "" {
- return title
- }
- return extractMarkdownTitle(content)
-}
-
-func extractWebPageSnippet(content string, contentType string, limit int) string {
- if strings.Contains(strings.ToLower(contentType), "markdown") {
- if snippet := extractMarkdownSnippet(content, limit); snippet != "" {
- return snippet
+ if matches := markdownHeadingPattern.FindStringSubmatch(trimmed); len(matches) > 1 {
+ return strings.TrimSpace(matches[1])
}
- }
- if snippet := extractHTMLSnippet(content, limit); snippet != "" {
- return snippet
- }
- return extractMarkdownSnippet(content, limit)
-}
-
-func extractMarkdownTitle(content string) string {
- if strings.TrimSpace(content) == "" {
return ""
}
- if strings.HasPrefix(strings.TrimSpace(content), "---") {
- matches := markdownFrontMatterTitlePattern.FindStringSubmatch(content)
- if len(matches) >= 2 {
- return compactPlainText(matches[1], 120)
- }
- }
- matches := markdownHeadingPattern.FindStringSubmatch(content)
- if len(matches) >= 2 {
- return compactPlainText(matches[1], 120)
+ if matches := htmlTitlePattern.FindStringSubmatch(trimmed); len(matches) > 1 {
+ return compactText(matches[1])
}
return ""
}
-func extractMarkdownSnippet(content string, limit int) string {
- if strings.TrimSpace(content) == "" {
- return ""
- }
- normalized := markdownCodeFencePattern.ReplaceAllString(content, " ")
- normalized = markdownImagePattern.ReplaceAllString(normalized, "$1")
- normalized = markdownLinkPattern.ReplaceAllString(normalized, "$1")
- normalized = markdownHeadingMarkerPattern.ReplaceAllString(normalized, "")
- normalized = markdownListMarkerPattern.ReplaceAllString(normalized, "")
- normalized = markdownBlockQuotePattern.ReplaceAllString(normalized, "")
- normalized = markdownTableSepPattern.ReplaceAllString(normalized, " ")
- normalized = strings.NewReplacer("*", " ", "_", " ", "`", " ", "~", " ").Replace(normalized)
- return compactPlainText(normalized, limit)
-}
-
-func compactPlainText(value string, limit int) string {
- if limit <= 0 {
- limit = 320
- }
- text := html.UnescapeString(value)
- text = htmlSpacePattern.ReplaceAllString(strings.TrimSpace(text), " ")
- if len(text) > limit {
- return strings.TrimSpace(text[:limit]) + "..."
- }
- return text
-}
-
-func resolveWebSearchProviderMap(config map[string]any, provider string) map[string]any {
- if provider == "" {
- return nil
- }
- return getNestedMap(config, "web", "search", "providers", provider)
-}
-
-func resolveWebSearchProviderString(config map[string]any, provider string, key string) string {
- if key == "" {
+func extractWebPageSnippet(content string, contentType string, maxChars int) string {
+ trimmed := strings.TrimSpace(content)
+ if trimmed == "" {
return ""
}
- providerConfig := resolveWebSearchProviderMap(config, provider)
- if providerConfig == nil {
- return ""
- }
- value, ok := providerConfig[key]
- if !ok {
- return ""
- }
- if str, ok := value.(string); ok {
- return strings.TrimSpace(str)
- }
- return ""
-}
-
-func resolveWebSearchInt(config map[string]any, key string, fallback int) int {
- value, ok := getNestedInt(config, "web", "search", key)
- if !ok || value <= 0 {
- return fallback
- }
- return value
-}
-
-func normalizeWebSearchCacheKey(provider string, query string, count int, country string, searchLang string, uiLang string, freshness string) string {
- parts := []string{
- strings.ToLower(strings.TrimSpace(provider)),
- strings.ToLower(strings.TrimSpace(query)),
- strings.ToLower(strings.TrimSpace(country)),
- strings.ToLower(strings.TrimSpace(searchLang)),
- strings.ToLower(strings.TrimSpace(uiLang)),
- strings.ToLower(strings.TrimSpace(freshness)),
- strconv.Itoa(count),
- }
- return strings.Join(parts, "|")
-}
-
-func readWebSearchCache(key string) (webSearchResponse, bool) {
- webSearchCache.mu.RLock()
- entry, ok := webSearchCache.entries[key]
- webSearchCache.mu.RUnlock()
- if !ok {
- return webSearchResponse{}, false
- }
- if time.Now().After(entry.expiresAt) {
- webSearchCache.mu.Lock()
- delete(webSearchCache.entries, key)
- webSearchCache.mu.Unlock()
- return webSearchResponse{}, false
- }
- return entry.value, true
-}
-
-func writeWebSearchCache(key string, value webSearchResponse, ttl time.Duration) {
- if ttl <= 0 {
- return
- }
- webSearchCache.mu.Lock()
- webSearchCache.entries[key] = webSearchCacheEntry{
- value: value,
- expiresAt: time.Now().Add(ttl),
- }
- webSearchCache.mu.Unlock()
-}
-
-func runBraveSearch(ctx context.Context, config map[string]any, query string, count int, country string, searchLang string, uiLang string, freshness string, timeoutSeconds int) (webSearchResponse, error) {
- apiKey := resolveWebSearchProviderString(config, "brave", "apiKey")
- if apiKey == "" {
- apiKey = getNestedString(config, "web", "search", "brave", "apiKey")
- }
- if apiKey == "" {
- apiKey = getNestedString(config, "web", "search", "apiKey")
- }
- if apiKey == "" {
- apiKey = strings.TrimSpace(os.Getenv("BRAVE_API_KEY"))
- }
- if apiKey == "" {
- return webSearchResponse{}, errors.New("web_search needs a Brave API key")
- }
- endpoint := resolveWebSearchProviderString(config, "brave", "apiBaseUrl")
- if endpoint == "" {
- endpoint = getNestedString(config, "web", "search", "brave", "baseUrl")
- }
- if endpoint == "" {
- endpoint = getNestedString(config, "web", "search", "baseUrl")
- }
- if endpoint == "" {
- endpoint = braveSearchEndpoint
- }
- values := url.Values{}
- values.Set("q", query)
- if count > 0 {
- values.Set("count", strconv.Itoa(count))
- }
- if country != "" {
- values.Set("country", country)
- }
- if searchLang != "" {
- values.Set("search_lang", searchLang)
- }
- if uiLang != "" {
- values.Set("ui_lang", uiLang)
- }
- if freshness != "" {
- values.Set("freshness", freshness)
- }
- request, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint+"?"+values.Encode(), nil)
- if err != nil {
- return webSearchResponse{}, err
- }
- request.Header.Set("Accept", "application/json")
- request.Header.Set("X-Subscription-Token", apiKey)
- timeout := time.Duration(timeoutSeconds) * time.Second
- if timeout <= 0 {
- timeout = time.Duration(defaultWebSearchTimeoutSeconds) * time.Second
- }
- client := &http.Client{Timeout: timeout}
- resp, err := client.Do(request)
- if err != nil {
- return webSearchResponse{}, err
- }
- defer resp.Body.Close()
- body, err := io.ReadAll(resp.Body)
- if err != nil {
- return webSearchResponse{}, err
- }
- if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusBadRequest {
- message := strings.TrimSpace(string(body))
- if message == "" {
- message = resp.Status
- }
- return webSearchResponse{}, fmt.Errorf("HTTP %d: %s", resp.StatusCode, message)
- }
- var parsed braveSearchResponse
- if err := json.Unmarshal(body, &parsed); err != nil {
- return webSearchResponse{}, err
- }
- results := make([]webSearchResult, 0, len(parsed.Web.Results))
- for _, item := range parsed.Web.Results {
- results = append(results, webSearchResult{
- Title: strings.TrimSpace(item.Title),
- URL: strings.TrimSpace(item.URL),
- Description: strings.TrimSpace(item.Description),
- Age: strings.TrimSpace(item.Age),
- })
- }
- if count > 0 && len(results) > count {
- results = results[:count]
- }
- return webSearchResponse{
- Query: query,
- Provider: "brave",
- Results: results,
- }, nil
-}
-
-func runTavilySearch(ctx context.Context, config map[string]any, payload toolArgs, query string, count int, timeoutSeconds int) (webSearchResponse, error) {
- apiKey := resolveWebSearchProviderString(config, "tavily", "apiKey")
- if apiKey == "" {
- apiKey = getNestedString(config, "web", "search", "tavily", "apiKey")
- }
- if apiKey == "" {
- apiKey = getNestedString(config, "web", "search", "apiKey")
- }
- if apiKey == "" {
- apiKey = strings.TrimSpace(os.Getenv("TAVILY_API_KEY"))
- }
- if apiKey == "" {
- return webSearchResponse{}, errors.New("web_search needs a Tavily API key")
- }
- endpoint := resolveWebSearchProviderString(config, "tavily", "apiBaseUrl")
- if endpoint == "" {
- endpoint = getNestedString(config, "web", "search", "tavily", "baseUrl")
- }
- if endpoint == "" {
- endpoint = getNestedString(config, "web", "search", "baseUrl")
- }
- if endpoint == "" {
- endpoint = tavilySearchEndpoint
- }
-
- requestBody := map[string]any{
- "query": query,
- }
- if count > 0 {
- requestBody["max_results"] = count
- }
- if value := getStringArg(payload, "search_depth", "searchDepth"); value != "" {
- requestBody["search_depth"] = value
- }
- if value := getStringArg(payload, "topic"); value != "" {
- requestBody["topic"] = value
- }
- if value := getStringArg(payload, "time_range", "timeRange"); value != "" {
- requestBody["time_range"] = value
- }
- if value := getStringArg(payload, "start_date", "startDate"); value != "" {
- requestBody["start_date"] = value
- }
- if value := getStringArg(payload, "end_date", "endDate"); value != "" {
- requestBody["end_date"] = value
- }
- if value := getStringArg(payload, "country"); value != "" {
- requestBody["country"] = value
- }
- if value := getStringArg(payload, "include_answer", "includeAnswer"); value != "" {
- requestBody["include_answer"] = value
- } else if includeAnswer, ok := getBoolArg(payload, "include_answer", "includeAnswer"); ok {
- requestBody["include_answer"] = includeAnswer
- }
- if value := getStringArg(payload, "include_raw_content", "includeRawContent"); value != "" {
- requestBody["include_raw_content"] = value
- } else if includeRaw, ok := getBoolArg(payload, "include_raw_content", "includeRawContent"); ok {
- requestBody["include_raw_content"] = includeRaw
- }
- if includeImages, ok := getBoolArg(payload, "include_images", "includeImages"); ok {
- requestBody["include_images"] = includeImages
- }
- if includeDescriptions, ok := getBoolArg(payload, "include_image_descriptions", "includeImageDescriptions"); ok {
- requestBody["include_image_descriptions"] = includeDescriptions
- }
- if includeFavicon, ok := getBoolArg(payload, "include_favicon", "includeFavicon"); ok {
- requestBody["include_favicon"] = includeFavicon
- }
- if autoParams, ok := getBoolArg(payload, "auto_parameters", "autoParameters"); ok {
- requestBody["auto_parameters"] = autoParams
- }
- if chunks, ok := getIntArg(payload, "chunks_per_source", "chunksPerSource"); ok && chunks > 0 {
- requestBody["chunks_per_source"] = chunks
- }
- if includeDomains := getStringSliceArg(payload, "include_domains", "includeDomains"); len(includeDomains) > 0 {
- requestBody["include_domains"] = includeDomains
- }
- if excludeDomains := getStringSliceArg(payload, "exclude_domains", "excludeDomains"); len(excludeDomains) > 0 {
- requestBody["exclude_domains"] = excludeDomains
- }
-
- encoded, err := json.Marshal(requestBody)
- if err != nil {
- return webSearchResponse{}, err
- }
- request, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(encoded))
- if err != nil {
- return webSearchResponse{}, err
- }
- request.Header.Set("Accept", "application/json")
- request.Header.Set("Content-Type", "application/json")
- request.Header.Set("Authorization", "Bearer "+apiKey)
-
- timeout := time.Duration(timeoutSeconds) * time.Second
- if timeout <= 0 {
- timeout = time.Duration(defaultWebSearchTimeoutSeconds) * time.Second
- }
- client := &http.Client{Timeout: timeout}
- resp, err := client.Do(request)
- if err != nil {
- return webSearchResponse{}, err
- }
- defer resp.Body.Close()
- body, err := io.ReadAll(resp.Body)
- if err != nil {
- return webSearchResponse{}, err
- }
- if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusBadRequest {
- message := strings.TrimSpace(string(body))
- if message == "" {
- message = resp.Status
- }
- return webSearchResponse{}, fmt.Errorf("HTTP %d: %s", resp.StatusCode, message)
- }
- var parsed tavilySearchResponse
- if err := json.Unmarshal(body, &parsed); err != nil {
- return webSearchResponse{}, err
- }
- results := make([]webSearchResult, 0, len(parsed.Results))
- for _, item := range parsed.Results {
- description := strings.TrimSpace(item.Content)
- if description == "" {
- if raw, ok := item.RawContent.(string); ok {
- description = strings.TrimSpace(raw)
- }
- }
- results = append(results, webSearchResult{
- Title: strings.TrimSpace(item.Title),
- URL: strings.TrimSpace(item.URL),
- Description: description,
- })
- }
- if count > 0 && len(results) > count {
- results = results[:count]
- }
- return webSearchResponse{
- Query: query,
- Provider: "tavily",
- Results: results,
- }, nil
-}
-
-func isTimeoutError(err error) bool {
- if err == nil {
- return false
+ var snippet string
+ if strings.Contains(strings.ToLower(contentType), "markdown") {
+ snippet = markdownCodeFencePattern.ReplaceAllString(trimmed, "")
+ snippet = markdownImagePattern.ReplaceAllString(snippet, "$1")
+ snippet = markdownLinkPattern.ReplaceAllString(snippet, "$1")
+ snippet = markdownHeadingMarkerPattern.ReplaceAllString(snippet, "")
+ snippet = markdownListMarkerPattern.ReplaceAllString(snippet, "")
+ snippet = markdownBlockQuotePattern.ReplaceAllString(snippet, "")
+ snippet = markdownTableSepPattern.ReplaceAllString(snippet, "")
+ } else {
+ snippet = htmlScriptStylePattern.ReplaceAllString(trimmed, "")
+ snippet = htmlTagPattern.ReplaceAllString(snippet, " ")
}
- if errors.Is(err, context.DeadlineExceeded) {
- return true
+ snippet = compactText(snippet)
+ if maxChars > 0 && len(snippet) > maxChars {
+ return snippet[:maxChars]
}
- lower := strings.ToLower(strings.TrimSpace(err.Error()))
- return strings.Contains(lower, "timeout") || strings.Contains(lower, "timed out")
+ return snippet
}
-func trimToMaxChars(value string, max int) string {
- trimmed := strings.TrimSpace(value)
- if max <= 0 {
- return trimmed
- }
- runes := []rune(trimmed)
- if len(runes) <= max {
- return trimmed
+func resolveWebSearchType(config map[string]any) string {
+ value := strings.ToLower(strings.TrimSpace(getNestedString(config, "web", "search", "type")))
+ switch value {
+ case "external_tools":
+ return "external_tools"
+ case "api":
+ return "api"
+ default:
+ return defaultWebSearchType
}
- return string(runes[:max])
}
-func normalizeHTTPHeaderValues(header http.Header) map[string]string {
- result := make(map[string]string, len(header))
- for key, values := range header {
- trimmedKey := strings.ToLower(strings.TrimSpace(key))
- if trimmedKey == "" {
- continue
+func resolveWebSearchCount(payload toolArgs, config map[string]any) int {
+ if value, ok := getIntArg(payload, "count", "maxResults"); ok && value > 0 {
+ if value > maxWebSearchCount {
+ return maxWebSearchCount
}
- result[trimmedKey] = strings.TrimSpace(strings.Join(values, ", "))
+ return value
}
- return result
-}
-
-func normalizeHTTPHeaderMap(headers map[string]string) map[string]string {
- result := make(map[string]string, len(headers))
- for key, value := range headers {
- trimmedKey := strings.ToLower(strings.TrimSpace(key))
- if trimmedKey == "" {
- continue
+ if value, ok := getIntArg(toolArgs(getNestedMap(config, "web", "search")), "maxResults"); ok && value > 0 {
+ if value > maxWebSearchCount {
+ return maxWebSearchCount
}
- result[trimmedKey] = strings.TrimSpace(value)
- }
- return result
-}
-
-func resolvePlaywrightReadySelector(finalURL string, targetURL string) string {
- host := extractHostname(finalURL)
- if host == "" {
- host = extractHostname(targetURL)
- }
- switch {
- case hostMatchesDomain(host, "google.com"):
- return "#search div.g, #search a h3"
- case hostMatchesDomain(host, "xiaohongshu.com"):
- return ".note-item, .search-result, section"
- default:
- return "main, article, body"
- }
-}
-
-func waitForPlaywrightReady(page playwright.Page, selector string, timeoutMs float64) error {
- if strings.TrimSpace(selector) == "" {
- return nil
+ return value
}
- _, err := page.WaitForSelector(selector, playwright.PageWaitForSelectorOptions{
- Timeout: playwright.Float(timeoutMs),
- })
- return err
+ return defaultWebSearchCount
}
-func contentWithTimeout(ctx context.Context, page playwright.Page, timeout time.Duration) (string, error) {
- type result struct {
- content string
- err error
+func resolveWebSearchTimeoutSeconds(payload toolArgs, config map[string]any) int {
+ if value, ok := getIntArg(payload, "timeoutSeconds"); ok && value > 0 {
+ return value
}
- done := make(chan result, 1)
- go func() {
- content, err := page.Content()
- done <- result{content: content, err: err}
- }()
-
- timer := time.NewTimer(timeout)
- defer timer.Stop()
- select {
- case <-ctx.Done():
- return "", ctx.Err()
- case <-timer.C:
- return "", errors.New("playwright extract timeout")
- case output := <-done:
- return output.content, output.err
+ if value, ok := getIntArg(toolArgs(getNestedMap(config, "web", "search")), "timeoutSeconds"); ok && value > 0 {
+ return value
}
+ return defaultWebSearchTimeoutSeconds
}
-func extractHostname(rawURL string) string {
- parsed, err := url.Parse(strings.TrimSpace(rawURL))
- if err != nil {
- return ""
+func resolveWebSearchCacheTTL(config map[string]any) time.Duration {
+ if value, ok := getIntArg(toolArgs(getNestedMap(config, "web", "search")), "cacheTtlMinutes"); ok && value > 0 {
+ return time.Duration(value) * time.Minute
}
- return strings.ToLower(strings.TrimSpace(parsed.Hostname()))
+ return time.Duration(defaultWebSearchCacheTtlMinutes) * time.Minute
}
diff --git a/internal/application/gateway/tools/web_tools_test.go b/internal/application/gateway/tools/web_tools_test.go
index fee764e..83b66d8 100644
--- a/internal/application/gateway/tools/web_tools_test.go
+++ b/internal/application/gateway/tools/web_tools_test.go
@@ -3,8 +3,6 @@ package tools
import (
"context"
"encoding/json"
- "net/http"
- "net/http/httptest"
"strings"
"testing"
@@ -28,156 +26,59 @@ func builtinWebFetchSettingsStub() *settingsReaderStub {
settings: settingsdto.Settings{
Tools: map[string]any{
"web_fetch": map[string]any{
- "type": "builtin",
+ "preferredBrowser": "chrome",
},
},
},
}
}
-func TestRunWebFetchTool_DefaultLLMHeaders(t *testing.T) {
+func TestRunWebFetchTool_RejectsNonGETMethods(t *testing.T) {
t.Parallel()
- var acceptHeader string
- var userAgentHeader string
- var acceptLanguageHeader string
- server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- acceptHeader = r.Header.Get("Accept")
- userAgentHeader = r.Header.Get("User-Agent")
- acceptLanguageHeader = r.Header.Get("Accept-Language")
- w.Header().Set("Content-Type", "text/markdown; charset=utf-8")
- w.Header().Set("x-markdown-tokens", "256")
- w.Header().Set("content-signal", "ai-input=yes")
- _, _ = w.Write([]byte("# Hello\n\nworld"))
- }))
- defer server.Close()
-
handler := runWebFetchTool(builtinWebFetchSettingsStub(), nil)
- output, err := handler(context.Background(), `{"url":"`+server.URL+`"}`)
+ output, err := handler(context.Background(), `{"url":"https://example.com","method":"POST"}`)
if err != nil {
t.Fatalf("run web_fetch: %v", err)
}
- if !strings.Contains(acceptHeader, "text/markdown") {
- t.Fatalf("expected markdown accept header, got %q", acceptHeader)
- }
- if !strings.Contains(userAgentHeader, "Version/17.0") || !strings.Contains(userAgentHeader, "Safari/605.1.15") {
- t.Fatalf("expected default user-agent, got %q", userAgentHeader)
- }
- if acceptLanguageHeader != "en-US,en;q=0.9" {
- t.Fatalf("expected default accept-language, got %q", acceptLanguageHeader)
- }
-
var payload webFetchResult
if err := json.Unmarshal([]byte(output), &payload); err != nil {
t.Fatalf("decode result: %v", err)
}
- if payload.MarkdownTokens != 256 {
- t.Fatalf("unexpected markdownTokens: %d", payload.MarkdownTokens)
- }
- if payload.ContentSignal != "ai-input=yes" {
- t.Fatalf("unexpected content signal: %q", payload.ContentSignal)
- }
- if !strings.Contains(strings.ToLower(payload.ContentType), "markdown") {
- t.Fatalf("expected markdown content type, got %q", payload.ContentType)
- }
-}
-
-func TestRunWebFetchTool_HeaderSwitches(t *testing.T) {
- t.Parallel()
-
- var acceptHeader string
- var userAgentHeader string
- server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- acceptHeader = r.Header.Get("Accept")
- userAgentHeader = r.Header.Get("User-Agent")
- w.WriteHeader(http.StatusOK)
- _, _ = w.Write([]byte("ok"))
- }))
- defer server.Close()
-
- handler := runWebFetchTool(builtinWebFetchSettingsStub(), nil)
- output, err := handler(context.Background(), `{"url":"`+server.URL+`","acceptMarkdown":false,"enableUserAgent":false}`)
- if err != nil {
- t.Fatalf("run web_fetch: %v", err)
- }
- if strings.Contains(strings.ToLower(acceptHeader), "text/markdown") {
- t.Fatalf("expected html accept header, got %q", acceptHeader)
- }
- if userAgentHeader != "" {
- t.Fatalf("expected empty user-agent when disabled, got %q", userAgentHeader)
- }
-
- var payload webFetchResult
- if err := json.Unmarshal([]byte(output), &payload); err != nil {
- t.Fatalf("decode result: %v", err)
- }
- if payload.Status != webStatusOK {
+ if payload.Status != webStatusError {
t.Fatalf("unexpected status: %q", payload.Status)
}
- if payload.HTTPStatus != http.StatusOK {
- t.Fatalf("unexpected http status: %d", payload.HTTPStatus)
+ if payload.Message != "web_fetch only supports GET" {
+ t.Fatalf("unexpected message: %q", payload.Message)
+ }
+ if payload.NextAction != nextActionUseOtherToolsOrSkills {
+ t.Fatalf("unexpected next action: %q", payload.NextAction)
}
}
-func TestRunWebFetchTool_ReadsTopLevelWebFetchSettings(t *testing.T) {
+func TestResolveWebFetchPreferredBrowser_ReadsTopLevelConfig(t *testing.T) {
t.Parallel()
- var acceptHeader string
- var userAgentHeader string
- var acceptLanguageHeader string
- var customHeader string
- server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- acceptHeader = r.Header.Get("Accept")
- userAgentHeader = r.Header.Get("User-Agent")
- acceptLanguageHeader = r.Header.Get("Accept-Language")
- customHeader = r.Header.Get("X-Test")
- w.WriteHeader(http.StatusOK)
- _, _ = w.Write([]byte("ok"))
- }))
- defer server.Close()
-
- handler := runWebFetchTool(&settingsReaderStub{
- settings: settingsdto.Settings{
- Tools: map[string]any{
- "web_fetch": map[string]any{
- "type": "builtin",
- "acceptMarkdown": false,
- "enableUserAgent": false,
- "acceptLanguage": "en-US,en;q=0.9",
- "headers": map[string]any{
- "X-Test": "enabled",
- },
- },
- },
+ preferred := resolveWebFetchPreferredBrowser(map[string]any{
+ "web_fetch": map[string]any{
+ "preferredBrowser": "brave",
},
- }, nil)
- output, err := handler(context.Background(), `{"url":"`+server.URL+`"}`)
- if err != nil {
- t.Fatalf("run web_fetch: %v", err)
- }
- if strings.Contains(strings.ToLower(acceptHeader), "text/markdown") {
- t.Fatalf("expected html accept header from settings, got %q", acceptHeader)
- }
- if userAgentHeader != "" {
- t.Fatalf("expected empty user-agent from settings, got %q", userAgentHeader)
- }
- if acceptLanguageHeader != "en-US,en;q=0.9" {
- t.Fatalf("expected accept-language from settings, got %q", acceptLanguageHeader)
- }
- if customHeader != "enabled" {
- t.Fatalf("expected custom header from settings, got %q", customHeader)
+ "browser": map[string]any{
+ "preferredBrowser": "chrome",
+ },
+ })
+ if preferred != "brave" {
+ t.Fatalf("expected preferred browser brave, got %q", preferred)
}
- var payload webFetchResult
- if err := json.Unmarshal([]byte(output), &payload); err != nil {
- t.Fatalf("decode result: %v", err)
- }
- if payload.Status != webStatusOK {
- t.Fatalf("unexpected status: %q", payload.Status)
- }
- if payload.HTTPStatus != http.StatusOK {
- t.Fatalf("unexpected http status: %d", payload.HTTPStatus)
+ fallback := resolveWebFetchPreferredBrowser(map[string]any{
+ "browser": map[string]any{
+ "preferredBrowser": "edge",
+ },
+ })
+ if fallback != "edge" {
+ t.Fatalf("expected fallback preferred browser edge, got %q", fallback)
}
}
@@ -186,15 +87,9 @@ func TestResolveWebFetchOptions_ReadsTopLevelWebFetchConfig(t *testing.T) {
options := resolveWebFetchOptions(toolArgs{}, map[string]any{
"web_fetch": map[string]any{
- "timeoutSeconds": 21,
- "maxChars": 1234,
- "maxBodyBytes": 4321,
- "acceptMarkdown": false,
- "enableUserAgent": false,
- "userAgent": "top-level-agent",
- "headers": map[string]any{
- "X-One": "1",
- },
+ "timeoutSeconds": 21,
+ "maxChars": 1234,
+ "maxBodyBytes": 4321,
},
}, defaultWebFetchTimeoutSeconds)
@@ -207,56 +102,30 @@ func TestResolveWebFetchOptions_ReadsTopLevelWebFetchConfig(t *testing.T) {
if options.MaxBodyBytes != 4321 {
t.Fatalf("expected maxBodyBytes from web_fetch, got %d", options.MaxBodyBytes)
}
- if options.MaxRedirects != defaultWebFetchMaxRedirects {
- t.Fatalf("expected default maxRedirects, got %d", options.MaxRedirects)
- }
- if options.RetryMax != defaultWebFetchRetryMax {
- t.Fatalf("expected default retryMax, got %d", options.RetryMax)
- }
- if options.AcceptMarkdown {
- t.Fatalf("expected acceptMarkdown=false from web_fetch")
- }
- if options.EnableUserAgent {
- t.Fatalf("expected enableUserAgent=false from web_fetch")
- }
- if options.UserAgent != "top-level-agent" {
- t.Fatalf("expected userAgent from web_fetch, got %q", options.UserAgent)
- }
- if options.Headers["X-One"] != "1" {
- t.Fatalf("expected headers from web_fetch, got %#v", options.Headers)
- }
}
-func TestRunWebFetchTool_TruncatesLargeBodiesByMaxBodyBytes(t *testing.T) {
+func TestBuildWebFetchToolResult_AnnotatesCDPSource(t *testing.T) {
t.Parallel()
- server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- w.Header().Set("Content-Type", "text/plain; charset=utf-8")
- _, _ = w.Write([]byte(strings.Repeat("a", 512)))
- }))
- defer server.Close()
-
- handler := runWebFetchTool(builtinWebFetchSettingsStub(), nil)
- output, err := handler(context.Background(), `{"url":"`+server.URL+`","maxBodyBytes":64,"maxChars":200}`)
- if err != nil {
- t.Fatalf("run web_fetch: %v", err)
- }
-
- var payload webFetchResult
- if err := json.Unmarshal([]byte(output), &payload); err != nil {
- t.Fatalf("decode result: %v", err)
- }
+ payload := buildWebFetchToolResult("https://example.com", webFetchResponse{
+ FinalURL: "https://example.com/article",
+ Status: 200,
+ ContentType: "text/markdown",
+ Content: "# Hello\n\nworld",
+ MarkdownTokens: 4,
+ ContentSignal: "article_readability",
+ }, nil)
if payload.Status != webStatusOK {
t.Fatalf("unexpected status: %q", payload.Status)
}
- if !payload.Truncated {
- t.Fatalf("expected payload to be truncated")
+ if payload.Data["browserSource"] != webFetchTypeCDP {
+ t.Fatalf("expected browser source %q, got %#v", webFetchTypeCDP, payload.Data["browserSource"])
}
- if len(payload.Content) != 64 {
- t.Fatalf("expected content to be capped at 64 bytes, got %d", len(payload.Content))
+ if payload.ContentSignal != "article_readability" {
+ t.Fatalf("unexpected content signal: %q", payload.ContentSignal)
}
- if payload.Content != strings.Repeat("a", 64) {
- t.Fatalf("unexpected truncated content length=%d", len(payload.Content))
+ if payload.MarkdownTokens != 4 {
+ t.Fatalf("unexpected markdown tokens: %d", payload.MarkdownTokens)
}
}
@@ -281,33 +150,24 @@ Markdown has quickly become the lingua franca for agents.`
}
}
-func TestResolveWebFetchTypeDefaultsToBuiltin(t *testing.T) {
+func TestResolveWebFetchTypeDefaultsToCDP(t *testing.T) {
t.Parallel()
fetchType, err := resolveWebFetchType(toolArgs{}, map[string]any{})
if err != nil {
t.Fatalf("resolve web_fetch type: %v", err)
}
- if fetchType != webFetchTypeBuiltin {
- t.Fatalf("expected default fetch type builtin, got %q", fetchType)
+ if fetchType != webFetchTypeCDP {
+ t.Fatalf("expected default fetch type cdp, got %q", fetchType)
}
}
-func TestResolveWebFetchPlaywrightOptionsDefaultsMarkdownEnabled(t *testing.T) {
- t.Parallel()
-
- options := resolveWebFetchPlaywrightOptions(toolArgs{}, map[string]any{})
- if !options.Markdown {
- t.Fatalf("expected markdown conversion enabled by default")
- }
-}
-
-func TestResolveWebFetchTypeOnlyAcceptsBuiltinAndPlaywright(t *testing.T) {
+func TestResolveWebFetchTypeNormalizesLegacyLabelsToCDP(t *testing.T) {
t.Parallel()
fetchType := normalizeWebFetchType("builtin")
- if fetchType != webFetchTypeBuiltin {
- t.Fatalf("expected builtin fetch type, got %q", fetchType)
+ if fetchType != webFetchTypeCDP {
+ t.Fatalf("expected builtin label to normalize to cdp, got %q", fetchType)
}
if _, err := resolveWebFetchType(toolArgs{"type": "terminal"}, map[string]any{}); err == nil {
@@ -327,6 +187,21 @@ func TestConnectorTypeForURL(t *testing.T) {
if connectorType := connectorTypeForURL("https://www.google.com/search?q=test"); connectorType != "google" {
t.Fatalf("expected google connector, got %q", connectorType)
}
+ if connectorType := connectorTypeForURL("https://www.youtube.com/watch?v=test"); connectorType != "google" {
+ t.Fatalf("expected youtube URLs to use google connector cookies, got %q", connectorType)
+ }
+ if connectorType := connectorTypeForURL("https://github.com/owner/repo"); connectorType != "github" {
+ t.Fatalf("expected github connector, got %q", connectorType)
+ }
+ if connectorType := connectorTypeForURL("https://www.reddit.com/r/golang/comments/test"); connectorType != "reddit" {
+ t.Fatalf("expected reddit connector, got %q", connectorType)
+ }
+ if connectorType := connectorTypeForURL("https://www.zhihu.com/question/123456"); connectorType != "zhihu" {
+ t.Fatalf("expected zhihu connector, got %q", connectorType)
+ }
+ if connectorType := connectorTypeForURL("https://x.com/example/status/1"); connectorType != "x" {
+ t.Fatalf("expected x connector, got %q", connectorType)
+ }
if connectorType := connectorTypeForURL("https://www.xiaohongshu.com/explore"); connectorType != "xiaohongshu" {
t.Fatalf("expected xiaohongshu connector, got %q", connectorType)
}
diff --git a/internal/application/sitepolicy/policy.go b/internal/application/sitepolicy/policy.go
new file mode 100644
index 0000000..636e9c9
--- /dev/null
+++ b/internal/application/sitepolicy/policy.go
@@ -0,0 +1,339 @@
+package sitepolicy
+
+import (
+ "net/url"
+ "strings"
+)
+
+// Policy describes the built-in handling rules for one supported site.
+type Policy struct {
+ // Key identifies the policy; every built-in entry uses it as its key in builtinPolicies.
+ Key string
+ // ConnectorType names the connector the site is associated with. It may differ from Key:
+ // the "youtube" policy uses the "google" connector.
+ ConnectorType string
+ // Domains lists the hosts the policy applies to; HostMatchesDomain also accepts their subdomains.
+ Domains []string
+ // ReadySelectors are candidate selectors in priority order; ReadySelectorForURL returns the
+ // first non-blank one.
+ ReadySelectors []string
+ // ExtractorSelectors is not read in this file — presumably selectors for the content to keep
+ // during extraction; confirm against callers.
+ ExtractorSelectors []string
+ // RemoveSelectors is not read in this file — presumably selectors for elements to strip
+ // before extraction; confirm against callers.
+ RemoveSelectors []string
+ // Capabilities is not read in this file — presumably the feature names the site supports
+ // (the built-in values are "cookies", "web_fetch", "browser", "download"); confirm against callers.
+ Capabilities []string
+}
+
+// builtinPolicyOrder fixes the order in which List returns policies and ForURL tries them.
+// "youtube" must stay ahead of "google": the google policy's Domains also contain the YouTube
+// hosts, and ForURL returns the first policy whose domains match.
+var builtinPolicyOrder = []string{
+ "youtube",
+ "google",
+ "github",
+ "reddit",
+ "zhihu",
+ "x",
+ "xiaohongshu",
+ "bilibili",
+}
+
+// builtinPolicies is the table of built-in site policies, keyed by policy key (each entry's
+// Key field repeats its map key). ForConnectorType indexes this map directly, while List and
+// ForURL only visit the keys named in builtinPolicyOrder. The "google" entry lists the YouTube
+// domains as well; ForURL still resolves YouTube URLs to the "youtube" entry because that key
+// is ordered first.
+var builtinPolicies = map[string]Policy{
+ "youtube": {
+ Key: "youtube",
+ ConnectorType: "google",
+ Domains: []string{
+ "youtube.com",
+ "youtu.be",
+ },
+ ReadySelectors: []string{
+ "ytd-watch-flexy",
+ "#content",
+ "main",
+ "body",
+ },
+ ExtractorSelectors: []string{
+ "#description",
+ "#description-inline-expander",
+ "ytd-watch-metadata",
+ "main",
+ },
+ RemoveSelectors: []string{
+ "#related",
+ "ytd-comments",
+ "ytd-merch-shelf-renderer",
+ "ytd-rich-grid-renderer",
+ },
+ Capabilities: []string{"cookies", "web_fetch", "browser", "download"},
+ },
+ "google": {
+ Key: "google",
+ ConnectorType: "google",
+ Domains: []string{
+ "google.com",
+ "youtube.com",
+ "youtu.be",
+ },
+ ReadySelectors: []string{
+ "#search",
+ "main",
+ "body",
+ },
+ ExtractorSelectors: []string{
+ "main",
+ "article",
+ "#content",
+ },
+ RemoveSelectors: []string{
+ "#related",
+ "#secondary",
+ "ytd-comments",
+ },
+ Capabilities: []string{"cookies", "web_fetch", "browser", "download"},
+ },
+ "github": {
+ Key: "github",
+ ConnectorType: "github",
+ Domains: []string{
+ "github.com",
+ "raw.githubusercontent.com",
+ },
+ ReadySelectors: []string{
+ "main",
+ "#repo-content-pjax-container",
+ "body",
+ },
+ ExtractorSelectors: []string{
+ "main",
+ "article",
+ ".markdown-body",
+ "[data-testid=\"issue-body\"]",
+ "[data-testid=\"pull-request-comment\"]",
+ },
+ RemoveSelectors: []string{
+ "header",
+ "footer",
+ ".Layout-sidebar",
+ ".js-header-wrapper",
+ "#repos-sticky-header",
+ },
+ Capabilities: []string{"cookies", "web_fetch", "browser"},
+ },
+ "reddit": {
+ Key: "reddit",
+ ConnectorType: "reddit",
+ Domains: []string{
+ "reddit.com",
+ "redd.it",
+ },
+ ReadySelectors: []string{
+ "main",
+ "[data-testid=\"post-container\"]",
+ "body",
+ },
+ ExtractorSelectors: []string{
+ "main",
+ "article",
+ "[data-testid=\"post-container\"]",
+ ".md",
+ },
+ RemoveSelectors: []string{
+ "nav",
+ "[data-testid=\"frontpage-sidebar\"]",
+ "shreddit-comments-page-ad",
+ "shreddit-experience-tree",
+ },
+ Capabilities: []string{"cookies", "web_fetch", "browser"},
+ },
+ "zhihu": {
+ Key: "zhihu",
+ ConnectorType: "zhihu",
+ Domains: []string{
+ "zhihu.com",
+ },
+ ReadySelectors: []string{
+ "main",
+ ".Question-main",
+ ".Post-content",
+ "body",
+ },
+ ExtractorSelectors: []string{
+ "main",
+ ".Question-mainColumn",
+ ".Post-RichTextContainer",
+ "article",
+ },
+ RemoveSelectors: []string{
+ ".Question-sideColumn",
+ ".CornerButtons",
+ ".Recommendations-Main",
+ ".Comment-container",
+ },
+ Capabilities: []string{"cookies", "web_fetch", "browser"},
+ },
+ "x": {
+ Key: "x",
+ ConnectorType: "x",
+ Domains: []string{
+ "x.com",
+ "twitter.com",
+ },
+ ReadySelectors: []string{
+ "main",
+ "[data-testid=\"primaryColumn\"]",
+ "body",
+ },
+ ExtractorSelectors: []string{
+ "main",
+ "article",
+ "[data-testid=\"tweet\"]",
+ },
+ RemoveSelectors: []string{
+ "nav",
+ "[data-testid=\"sidebarColumn\"]",
+ "[aria-label=\"Timeline: Trending now\"]",
+ "[aria-label=\"Who to follow\"]",
+ },
+ Capabilities: []string{"cookies", "web_fetch", "browser"},
+ },
+ "xiaohongshu": {
+ Key: "xiaohongshu",
+ ConnectorType: "xiaohongshu",
+ Domains: []string{
+ "xiaohongshu.com",
+ "xhslink.com",
+ "redbook.com",
+ },
+ ReadySelectors: []string{
+ "#app",
+ "main",
+ "body",
+ },
+ ExtractorSelectors: []string{
+ "main",
+ "article",
+ "#noteContainer",
+ },
+ RemoveSelectors: []string{
+ ".note-side-bar",
+ ".recommend-container",
+ ".comment-container",
+ },
+ Capabilities: []string{"cookies", "web_fetch", "browser"},
+ },
+ "bilibili": {
+ Key: "bilibili",
+ ConnectorType: "bilibili",
+ Domains: []string{
+ "bilibili.com",
+ "b23.tv",
+ },
+ ReadySelectors: []string{
+ "#app",
+ "#arc_toolbar_report",
+ "main",
+ "body",
+ },
+ ExtractorSelectors: []string{
+ "main",
+ "article",
+ "#app",
+ },
+ RemoveSelectors: []string{
+ ".video-toolbar-v1",
+ ".right-container",
+ ".comment-container",
+ },
+ Capabilities: []string{"cookies", "web_fetch", "browser", "download"},
+ },
+}
+
+// List returns the built-in policies in builtinPolicyOrder order, skipping any ordered key
+// that has no entry in builtinPolicies. The returned Policy values are struct copies, so
+// their slice fields still share storage with the package-level table.
+func List() []Policy {
+	ordered := make([]Policy, 0, len(builtinPolicyOrder))
+	for i := range builtinPolicyOrder {
+		if entry, found := builtinPolicies[builtinPolicyOrder[i]]; found {
+			ordered = append(ordered, entry)
+		}
+	}
+	return ordered
+}
+
+// ForConnectorType looks up a built-in policy by key after trimming surrounding whitespace
+// from connectorType and lower-casing it. The boolean reports whether an entry exists; on a
+// miss the returned Policy is the zero value.
+func ForConnectorType(connectorType string) (Policy, bool) {
+	lookupKey := strings.ToLower(strings.TrimSpace(connectorType))
+	if entry, exists := builtinPolicies[lookupKey]; exists {
+		return entry, true
+	}
+	return Policy{}, false
+}
+
+// ForURL returns the first built-in policy, in builtinPolicyOrder order, that has a domain
+// matching rawURL's host. It returns the zero Policy and false when rawURL yields no host
+// or no policy matches.
+func ForURL(rawURL string) (Policy, bool) {
+	target := hostname(rawURL)
+	if target == "" {
+		return Policy{}, false
+	}
+	for i := range builtinPolicyOrder {
+		// An ordered key without a table entry yields the zero Policy, whose nil Domains
+		// make the inner loop a no-op.
+		candidate := builtinPolicies[builtinPolicyOrder[i]]
+		for j := range candidate.Domains {
+			if HostMatchesDomain(target, candidate.Domains[j]) {
+				return candidate, true
+			}
+		}
+	}
+	return Policy{}, false
+}
+
+// DomainsForConnector returns a fresh copy of the domains of the policy found by
+// ForConnectorType, or nil when no such policy exists.
+func DomainsForConnector(connectorType string) []string {
+	if matched, found := ForConnectorType(connectorType); found {
+		return cloneStrings(matched.Domains)
+	}
+	return nil
+}
+
+// ReadySelectorForURL returns the first non-blank ready selector, trimmed, of the policy
+// matching rawURL. It returns "" when no policy matches or every selector is blank.
+func ReadySelectorForURL(rawURL string) string {
+	matched, found := ForURL(rawURL)
+	if !found {
+		return ""
+	}
+	for i := range matched.ReadySelectors {
+		if candidate := strings.TrimSpace(matched.ReadySelectors[i]); candidate != "" {
+			return candidate
+		}
+	}
+	return ""
+}
+
+// HostMatchesDomain reports whether host is domain itself or one of its subdomains.
+// Both arguments are trimmed of surrounding whitespace, lower-cased, and stripped of one
+// leading and one trailing dot before comparison, so a fully-qualified host such as
+// "www.youtube.com." matches "youtube.com". An empty host or domain never matches.
+func HostMatchesDomain(host string, domain string) bool {
+	normalize := func(name string) string {
+		lowered := strings.ToLower(strings.TrimSpace(name))
+		// A trailing dot marks an absolute DNS name and names the same host as the
+		// dot-less form; a leading dot is the cookie-style ".example.com" spelling.
+		return strings.TrimSuffix(strings.TrimPrefix(lowered, "."), ".")
+	}
+	normalizedHost := normalize(host)
+	normalizedDomain := normalize(domain)
+	if normalizedHost == "" || normalizedDomain == "" {
+		return false
+	}
+	// Require a label boundary so "notyoutube.com" does not match "youtube.com".
+	return normalizedHost == normalizedDomain || strings.HasSuffix(normalizedHost, "."+normalizedDomain)
+}
+
+// MatchDomains reports whether rawURL's host equals, or is a subdomain of, any entry in
+// domains. It returns false when rawURL yields no host.
+func MatchDomains(rawURL string, domains []string) bool {
+	target := hostname(rawURL)
+	if target != "" {
+		for i := range domains {
+			if HostMatchesDomain(target, domains[i]) {
+				return true
+			}
+		}
+	}
+	return false
+}
+
+// cloneStrings returns a new slice holding the whitespace-trimmed, non-blank entries of
+// values in their original order. It returns nil for empty input and a non-nil empty
+// slice when every entry is blank.
+func cloneStrings(values []string) []string {
+	if len(values) == 0 {
+		return nil
+	}
+	kept := make([]string, 0, len(values))
+	for i := range values {
+		if entry := strings.TrimSpace(values[i]); entry != "" {
+			kept = append(kept, entry)
+		}
+	}
+	return kept
+}
+
+// hostname parses rawURL (after trimming surrounding whitespace) and returns its host,
+// without any port, lower-cased. It returns "" when the URL cannot be parsed or carries
+// no host.
+func hostname(rawURL string) string {
+	parsed, parseErr := url.Parse(strings.TrimSpace(rawURL))
+	if parseErr == nil {
+		host := strings.TrimSpace(parsed.Hostname())
+		return strings.ToLower(host)
+	}
+	return ""
+}
diff --git a/internal/application/sitepolicy/policy_test.go b/internal/application/sitepolicy/policy_test.go
new file mode 100644
index 0000000..c0d5578
--- /dev/null
+++ b/internal/application/sitepolicy/policy_test.go
@@ -0,0 +1,53 @@
+package sitepolicy
+
+import "testing"
+
+// TestForURLPrefersYouTubePolicyBeforeGoogle checks that a YouTube URL resolves to the
+// "youtube" policy (not "google", which also lists the YouTube domains) and that this
+// policy uses the google connector.
+func TestForURLPrefersYouTubePolicyBeforeGoogle(t *testing.T) {
+	t.Parallel()
+
+	matched, found := ForURL("https://www.youtube.com/watch?v=dQw4w9WgXcQ")
+	// Fatalf stops the test, so only the first failing condition is ever reported.
+	switch {
+	case !found:
+		t.Fatalf("expected youtube policy match")
+	case matched.Key != "youtube":
+		t.Fatalf("expected youtube policy key, got %q", matched.Key)
+	case matched.ConnectorType != "google":
+		t.Fatalf("expected youtube URLs to reuse google connector cookies, got %q", matched.ConnectorType)
+	}
+}
+
+// TestForConnectorTypeGoogleDomainsIncludeYouTube checks that the policy looked up as
+// "google" lists domains that cover YouTube URLs.
+func TestForConnectorTypeGoogleDomainsIncludeYouTube(t *testing.T) {
+	t.Parallel()
+
+	google, found := ForConnectorType("google")
+	if !found {
+		t.Fatalf("expected google connector policy")
+	}
+	if covered := MatchDomains("https://www.youtube.com/watch?v=test", google.Domains); !covered {
+		t.Fatalf("expected google connector domains to cover youtube URLs")
+	}
+}
+
+// TestForURLMatchesNewBuiltinSites checks that one representative URL per built-in site
+// resolves to the policy with the expected key.
+func TestForURLMatchesNewBuiltinSites(t *testing.T) {
+	t.Parallel()
+
+	expectedKeyByURL := map[string]string{
+		"https://github.com/owner/repo":                 "github",
+		"https://www.reddit.com/r/golang/comments/test": "reddit",
+		"https://www.zhihu.com/question/123456":         "zhihu",
+		"https://x.com/example/status/1":                "x",
+		"https://www.xiaohongshu.com/explore/abc":       "xiaohongshu",
+		"https://www.bilibili.com/video/BV1xx411c7mD/":  "bilibili",
+	}
+
+	for target, wantKey := range expectedKeyByURL {
+		matched, found := ForURL(target)
+		if !found {
+			t.Fatalf("expected policy match for %s", target)
+		}
+		if got := matched.Key; got != wantKey {
+			t.Fatalf("expected policy %q for %s, got %q", wantKey, target, got)
+		}
+	}
+}
diff --git a/internal/application/skills/service/audit.go b/internal/application/skills/service/audit.go
index a3ae365..37aa51f 100644
--- a/internal/application/skills/service/audit.go
+++ b/internal/application/skills/service/audit.go
@@ -48,9 +48,9 @@ func resolveSkillsAuditSourceFromContext(ctx context.Context) string {
func resolveSkillsAuditGroup(action string) string {
switch strings.ToLower(strings.TrimSpace(action)) {
- case "skills.status", "skills.bins", "skill_manage.search", "skill_manage.list":
+ case "skills.status", "skills.bins", "skills_manage.search", "skills_manage.list":
return "read"
- case "skill_manage.install", "skill_manage.update", "skill_manage.remove", "skill_manage.sync":
+ case "skills_manage.install", "skills_manage.update", "skills_manage.remove", "skills_manage.sync":
return "package_write"
case "skills.install":
return "deps_write"
@@ -67,8 +67,8 @@ func resolveSkillsAuditTool(action string) string {
switch {
case strings.HasPrefix(strings.ToLower(strings.TrimSpace(action)), "skills."):
return "skills"
- case strings.HasPrefix(strings.ToLower(strings.TrimSpace(action)), "skill_manage."):
- return "skill_manage"
+ case strings.HasPrefix(strings.ToLower(strings.TrimSpace(action)), "skills_manage."):
+ return "skills_manage"
default:
return ""
}
diff --git a/internal/application/skills/service/audit_test.go b/internal/application/skills/service/audit_test.go
index c52dd54..c260483 100644
--- a/internal/application/skills/service/audit_test.go
+++ b/internal/application/skills/service/audit_test.go
@@ -47,7 +47,7 @@ func TestPruneSkillsAuditEntriesByRetention(t *testing.T) {
"timestamp": now.AddDate(0, 0, -1).Format(time.RFC3339),
},
map[string]any{
- "action": "skill_manage.search",
+ "action": "skills_manage.search",
"timestamp": now.AddDate(0, 0, -30).Format(time.RFC3339),
},
}
diff --git a/internal/application/skills/service/search.go b/internal/application/skills/service/search.go
index cc4d33d..4f67a80 100644
--- a/internal/application/skills/service/search.go
+++ b/internal/application/skills/service/search.go
@@ -106,20 +106,20 @@ func (service *SkillsService) SearchSkills(ctx context.Context, request dto.Sear
providerID := "clawhub"
workspaceRoot, err := service.resolveWorkspaceRoot(ctx, request.AssistantID, request.WorkspaceRoot)
if err != nil {
- service.appendSkillsAuditRecord(ctx, "skill_manage.search", query, assistantID, providerID, err)
+ service.appendSkillsAuditRecord(ctx, "skills_manage.search", query, assistantID, providerID, err)
return nil, err
}
output, err := service.runClawHubCommand(ctx, workspaceRoot, skillsSearchTimeout, "search", "--limit", fmt.Sprintf("%d", limit), query)
if err != nil {
- service.appendSkillsAuditRecord(ctx, "skill_manage.search", query, assistantID, providerID, err)
+ service.appendSkillsAuditRecord(ctx, "skills_manage.search", query, assistantID, providerID, err)
return nil, err
}
results := parseClawHubSearchOutput(output)
if len(results) > limit {
- service.appendSkillsAuditRecord(ctx, "skill_manage.search", query, assistantID, providerID, nil)
+ service.appendSkillsAuditRecord(ctx, "skills_manage.search", query, assistantID, providerID, nil)
return results[:limit], nil
}
- service.appendSkillsAuditRecord(ctx, "skill_manage.search", query, assistantID, providerID, nil)
+ service.appendSkillsAuditRecord(ctx, "skills_manage.search", query, assistantID, providerID, nil)
return results, nil
}
@@ -132,17 +132,17 @@ func (service *SkillsService) InspectSkill(ctx context.Context, request dto.Insp
providerID := "clawhub"
workspaceRoot, err := service.resolveWorkspaceRoot(ctx, request.AssistantID, request.WorkspaceRoot)
if err != nil {
- service.appendSkillsAuditRecord(ctx, "skill_manage.inspect", skill, assistantID, providerID, err)
+ service.appendSkillsAuditRecord(ctx, "skills_manage.inspect", skill, assistantID, providerID, err)
return dto.SkillDetail{}, err
}
output, err := service.runClawHubCommand(ctx, workspaceRoot, skillsSearchTimeout, "inspect", skill, "--json", "--files")
if err != nil {
- service.appendSkillsAuditRecord(ctx, "skill_manage.inspect", skill, assistantID, providerID, err)
+ service.appendSkillsAuditRecord(ctx, "skills_manage.inspect", skill, assistantID, providerID, err)
return dto.SkillDetail{}, err
}
detail, err := parseClawHubInspectOutput(output)
if err != nil {
- service.appendSkillsAuditRecord(ctx, "skill_manage.inspect", skill, assistantID, providerID, err)
+ service.appendSkillsAuditRecord(ctx, "skills_manage.inspect", skill, assistantID, providerID, err)
return dto.SkillDetail{}, err
}
if version := service.resolveInstalledSkillVersion(ctx, workspaceRoot, skill); version != "" {
@@ -161,7 +161,7 @@ func (service *SkillsService) InspectSkill(ctx context.Context, request dto.Insp
if detail.Name == "" {
detail.Name = detail.ID
}
- service.appendSkillsAuditRecord(ctx, "skill_manage.inspect", skill, assistantID, providerID, nil)
+ service.appendSkillsAuditRecord(ctx, "skills_manage.inspect", skill, assistantID, providerID, nil)
return detail, nil
}
@@ -174,7 +174,7 @@ func (service *SkillsService) InstallSkill(ctx context.Context, request dto.Inst
providerID := "clawhub"
workspaceRoot, err := service.resolveWorkspaceRoot(ctx, request.AssistantID, request.WorkspaceRoot)
if err != nil {
- service.appendSkillsAuditRecord(ctx, "skill_manage.install", skill, assistantID, providerID, err)
+ service.appendSkillsAuditRecord(ctx, "skills_manage.install", skill, assistantID, providerID, err)
return err
}
service.emitRealtimeEvent(ctx, SkillsRealtimeEvent{
@@ -206,7 +206,7 @@ func (service *SkillsService) InstallSkill(ctx context.Context, request dto.Inst
Force: request.Force,
Error: err.Error(),
})
- service.appendSkillsAuditRecord(ctx, "skill_manage.install", skill, assistantID, providerID, err)
+ service.appendSkillsAuditRecord(ctx, "skills_manage.install", skill, assistantID, providerID, err)
return err
}
service.recordInstallAttempt(true)
@@ -219,7 +219,7 @@ func (service *SkillsService) InstallSkill(ctx context.Context, request dto.Inst
WorkspaceRoot: workspaceRoot,
Force: request.Force,
})
- service.appendSkillsAuditRecord(ctx, "skill_manage.install", skill, assistantID, providerID, nil)
+ service.appendSkillsAuditRecord(ctx, "skills_manage.install", skill, assistantID, providerID, nil)
return err
}
@@ -232,7 +232,7 @@ func (service *SkillsService) UpdateSkill(ctx context.Context, request dto.Updat
providerID := "clawhub"
workspaceRoot, err := service.resolveWorkspaceRoot(ctx, request.AssistantID, request.WorkspaceRoot)
if err != nil {
- service.appendSkillsAuditRecord(ctx, "skill_manage.update", skill, assistantID, providerID, err)
+ service.appendSkillsAuditRecord(ctx, "skills_manage.update", skill, assistantID, providerID, err)
return err
}
service.emitRealtimeEvent(ctx, SkillsRealtimeEvent{
@@ -264,7 +264,7 @@ func (service *SkillsService) UpdateSkill(ctx context.Context, request dto.Updat
Force: request.Force,
Error: err.Error(),
})
- service.appendSkillsAuditRecord(ctx, "skill_manage.update", skill, assistantID, providerID, err)
+ service.appendSkillsAuditRecord(ctx, "skills_manage.update", skill, assistantID, providerID, err)
return err
}
service.recordInstallAttempt(true)
@@ -277,7 +277,7 @@ func (service *SkillsService) UpdateSkill(ctx context.Context, request dto.Updat
WorkspaceRoot: workspaceRoot,
Force: request.Force,
})
- service.appendSkillsAuditRecord(ctx, "skill_manage.update", skill, assistantID, providerID, nil)
+ service.appendSkillsAuditRecord(ctx, "skills_manage.update", skill, assistantID, providerID, nil)
return err
}
@@ -290,7 +290,7 @@ func (service *SkillsService) RemoveSkill(ctx context.Context, request dto.Remov
providerID := "clawhub"
workspaceRoot, err := service.resolveWorkspaceRoot(ctx, request.AssistantID, request.WorkspaceRoot)
if err != nil {
- service.appendSkillsAuditRecord(ctx, "skill_manage.remove", skill, assistantID, providerID, err)
+ service.appendSkillsAuditRecord(ctx, "skills_manage.remove", skill, assistantID, providerID, err)
return err
}
service.emitRealtimeEvent(ctx, SkillsRealtimeEvent{
@@ -312,7 +312,7 @@ func (service *SkillsService) RemoveSkill(ctx context.Context, request dto.Remov
WorkspaceRoot: workspaceRoot,
Error: cleanupErr.Error(),
})
- service.appendSkillsAuditRecord(ctx, "skill_manage.remove", skill, assistantID, providerID, cleanupErr)
+ service.appendSkillsAuditRecord(ctx, "skills_manage.remove", skill, assistantID, providerID, cleanupErr)
return cleanupErr
}
service.emitRealtimeEvent(ctx, SkillsRealtimeEvent{
@@ -322,7 +322,7 @@ func (service *SkillsService) RemoveSkill(ctx context.Context, request dto.Remov
AssistantID: request.AssistantID,
WorkspaceRoot: workspaceRoot,
})
- service.appendSkillsAuditRecord(ctx, "skill_manage.remove", skill, assistantID, providerID, nil)
+ service.appendSkillsAuditRecord(ctx, "skills_manage.remove", skill, assistantID, providerID, nil)
return nil
}
service.emitRealtimeEvent(ctx, SkillsRealtimeEvent{
@@ -333,7 +333,7 @@ func (service *SkillsService) RemoveSkill(ctx context.Context, request dto.Remov
WorkspaceRoot: workspaceRoot,
Error: err.Error(),
})
- service.appendSkillsAuditRecord(ctx, "skill_manage.remove", skill, assistantID, providerID, err)
+ service.appendSkillsAuditRecord(ctx, "skills_manage.remove", skill, assistantID, providerID, err)
return err
}
if cleanupErr := service.removeWorkspaceSkillDirectories(ctx, workspaceRoot, skill); cleanupErr != nil {
@@ -345,7 +345,7 @@ func (service *SkillsService) RemoveSkill(ctx context.Context, request dto.Remov
WorkspaceRoot: workspaceRoot,
Error: cleanupErr.Error(),
})
- service.appendSkillsAuditRecord(ctx, "skill_manage.remove", skill, assistantID, providerID, cleanupErr)
+ service.appendSkillsAuditRecord(ctx, "skills_manage.remove", skill, assistantID, providerID, cleanupErr)
return cleanupErr
}
service.emitRealtimeEvent(ctx, SkillsRealtimeEvent{
@@ -355,7 +355,7 @@ func (service *SkillsService) RemoveSkill(ctx context.Context, request dto.Remov
AssistantID: request.AssistantID,
WorkspaceRoot: workspaceRoot,
})
- service.appendSkillsAuditRecord(ctx, "skill_manage.remove", skill, assistantID, providerID, nil)
+ service.appendSkillsAuditRecord(ctx, "skills_manage.remove", skill, assistantID, providerID, nil)
return nil
}
@@ -449,7 +449,7 @@ func (service *SkillsService) SyncSkills(ctx context.Context, request dto.SyncSk
}
workspaceRoot, err := service.resolveWorkspaceRoot(ctx, request.AssistantID, request.WorkspaceRoot)
if err != nil {
- service.appendSkillsAuditRecord(ctx, "skill_manage.sync", "", assistantID, providerID, err)
+ service.appendSkillsAuditRecord(ctx, "skills_manage.sync", "", assistantID, providerID, err)
return nil, err
}
service.emitRealtimeEvent(ctx, SkillsRealtimeEvent{
@@ -468,7 +468,7 @@ func (service *SkillsService) SyncSkills(ctx context.Context, request dto.SyncSk
WorkspaceRoot: workspaceRoot,
Error: err.Error(),
})
- service.appendSkillsAuditRecord(ctx, "skill_manage.sync", "", assistantID, providerID, err)
+ service.appendSkillsAuditRecord(ctx, "skills_manage.sync", "", assistantID, providerID, err)
return nil, err
}
result, err := service.ResolveSkillsForProviderInWorkspace(ctx, dto.ResolveSkillsRequest{
@@ -483,7 +483,7 @@ func (service *SkillsService) SyncSkills(ctx context.Context, request dto.SyncSk
WorkspaceRoot: workspaceRoot,
Error: err.Error(),
})
- service.appendSkillsAuditRecord(ctx, "skill_manage.sync", "", assistantID, providerID, err)
+ service.appendSkillsAuditRecord(ctx, "skills_manage.sync", "", assistantID, providerID, err)
return nil, err
}
service.emitRealtimeEvent(ctx, SkillsRealtimeEvent{
@@ -494,7 +494,7 @@ func (service *SkillsService) SyncSkills(ctx context.Context, request dto.SyncSk
WorkspaceRoot: workspaceRoot,
CatalogCount: len(result),
})
- service.appendSkillsAuditRecord(ctx, "skill_manage.sync", "", assistantID, providerID, nil)
+ service.appendSkillsAuditRecord(ctx, "skills_manage.sync", "", assistantID, providerID, nil)
return result, nil
}
diff --git a/internal/application/skills/service/search_test.go b/internal/application/skills/service/search_test.go
index eb3f707..38a20c2 100644
--- a/internal/application/skills/service/search_test.go
+++ b/internal/application/skills/service/search_test.go
@@ -495,14 +495,8 @@ func TestClassifyClawHubCommandError(t *testing.T) {
func TestInstallSkillEmitsRealtimeStartedAndCompletedEvents(t *testing.T) {
t.Parallel()
- scriptPath := filepath.Join(t.TempDir(), "clawhub")
- script := "#!/bin/sh\nexit 0\n"
- if err := os.WriteFile(scriptPath, []byte(script), 0o755); err != nil {
- t.Fatalf("write script failed: %v", err)
- }
-
svc := NewSkillsService(newMemorySkillsRepo(), nil)
- svc.SetExternalTools(&skillsExternalToolsStub{ready: true, execPath: scriptPath})
+ svc.SetPackageAdapter(&skillsPackageAdapterStub{})
fixedNow := time.Date(2026, time.March, 17, 12, 0, 0, 0, time.UTC)
svc.now = func() time.Time { return fixedNow }
@@ -556,14 +550,12 @@ func TestInstallSkillEmitsRealtimeStartedAndCompletedEvents(t *testing.T) {
func TestInstallSkillEmitsRealtimeFailedEvent(t *testing.T) {
t.Parallel()
- scriptPath := filepath.Join(t.TempDir(), "clawhub")
- script := "#!/bin/sh\necho \"install failed\" 1>&2\nexit 1\n"
- if err := os.WriteFile(scriptPath, []byte(script), 0o755); err != nil {
- t.Fatalf("write script failed: %v", err)
- }
-
svc := NewSkillsService(newMemorySkillsRepo(), nil)
- svc.SetExternalTools(&skillsExternalToolsStub{ready: true, execPath: scriptPath})
+ svc.SetPackageAdapter(&skillsPackageAdapterStub{
+ run: func(_ context.Context, _ string, _ time.Duration, _ ...string) ([]byte, error) {
+ return nil, errors.New("install failed")
+ },
+ })
fixedNow := time.Date(2026, time.March, 17, 12, 0, 1, 0, time.UTC)
svc.now = func() time.Time { return fixedNow }
diff --git a/internal/application/softwareupdate/service.go b/internal/application/softwareupdate/service.go
index 49dabb7..f903433 100644
--- a/internal/application/softwareupdate/service.go
+++ b/internal/application/softwareupdate/service.go
@@ -2,6 +2,7 @@ package softwareupdate
import (
"context"
+ "strings"
"sync"
"time"
@@ -21,6 +22,10 @@ type AppFallbackProvider interface {
FetchAppRelease(ctx context.Context, request AppRequest) (AppRelease, error)
}
+type appVersionFallbackProvider interface {
+ FetchAppReleaseByVersion(ctx context.Context, version string) (AppRelease, error)
+}
+
type ToolFallbackProvider interface {
FetchToolRelease(ctx context.Context, request ToolRequest) (ToolRelease, error)
}
@@ -164,6 +169,38 @@ func (service *Service) ResolveAppRelease(ctx context.Context, request AppReques
return release, nil
}
+func (service *Service) ResolveAppReleaseByVersion(ctx context.Context, version string) (AppRelease, error) {
+ normalizedVersion := normalizeAppReleaseVersion(version)
+ if service == nil || normalizedVersion == "" {
+ return AppRelease{}, ErrReleaseNotFound
+ }
+
+ snapshot, err := service.EnsureCatalog(ctx, time.Minute, Request{AppVersion: normalizedVersion})
+ if err == nil && snapshot.Catalog.App != nil && sameAppReleaseVersion(snapshot.Catalog.App.Version, normalizedVersion) {
+ release := *snapshot.Catalog.App
+ release.ResolvedBy = SourceManifest
+ return release, nil
+ }
+
+ resolver, ok := service.appFallbackProvider.(appVersionFallbackProvider)
+ if !ok || resolver == nil {
+ if err != nil {
+ return AppRelease{}, err
+ }
+ return AppRelease{}, ErrReleaseNotFound
+ }
+
+ release, fallbackErr := resolver.FetchAppReleaseByVersion(ctx, normalizedVersion)
+ if fallbackErr != nil {
+ if err != nil {
+ return AppRelease{}, err
+ }
+ return AppRelease{}, fallbackErr
+ }
+ release.ResolvedBy = SourceFallback
+ return release, nil
+}
+
func (service *Service) ResolveToolRelease(ctx context.Context, request ToolRequest) (ToolRelease, error) {
if service == nil {
return ToolRelease{}, ErrReleaseNotFound
@@ -240,6 +277,18 @@ func (service *Service) StartSchedule(ctx context.Context, initialDelay time.Dur
}()
}
+func normalizeAppReleaseVersion(version string) string {
+ trimmed := strings.TrimSpace(version)
+ trimmed = strings.TrimPrefix(strings.TrimPrefix(trimmed, "v"), "V")
+ return trimmed
+}
+
+func sameAppReleaseVersion(left string, right string) bool {
+ normalizedLeft := normalizeAppReleaseVersion(left)
+ normalizedRight := normalizeAppReleaseVersion(right)
+ return normalizedLeft != "" && normalizedLeft == normalizedRight
+}
+
func (service *Service) StopSchedule() {
service.mu.Lock()
defer service.mu.Unlock()
diff --git a/internal/application/tools/dto/models.go b/internal/application/tools/dto/models.go
index b6f382a..4af55da 100644
--- a/internal/application/tools/dto/models.go
+++ b/internal/application/tools/dto/models.go
@@ -21,6 +21,7 @@ type ToolRequirement struct {
Name string `json:"name,omitempty"`
Available bool `json:"available"`
Reason string `json:"reason,omitempty"`
+ Data any `json:"data,omitempty"`
}
type ToolMethodSpec struct {
diff --git a/internal/application/update/service.go b/internal/application/update/service.go
index 46215a1..7ce1471 100644
--- a/internal/application/update/service.go
+++ b/internal/application/update/service.go
@@ -8,6 +8,7 @@ import (
"io"
"os"
"slices"
+ "strconv"
"strings"
"sync"
"time"
@@ -23,7 +24,7 @@ type Downloader interface {
}
type Installer interface {
- Install(ctx context.Context, artifactPath string) error
+ Install(ctx context.Context, artifactPath string, prepared update.Info) error
RestartToApply(ctx context.Context) error
}
@@ -31,24 +32,36 @@ type downloadURLSelector interface {
SelectDownloadURLs(ctx context.Context, urls []string) []string
}
+type preparedUpdateInspector interface {
+ PreparedUpdate(ctx context.Context) (update.Info, bool, error)
+ ClearPreparedUpdate(ctx context.Context) error
+}
+
+type whatsNewStore interface {
+ PendingWhatsNew(ctx context.Context) (update.WhatsNew, bool, error)
+ SeenWhatsNewVersion(ctx context.Context) (string, error)
+ MarkWhatsNewSeen(ctx context.Context, version string) error
+}
+
type Notifier interface {
SetUpdateAvailable(available bool)
NotifyUpdateState(info update.Info)
}
type Service struct {
- mu sync.Mutex
- state update.Info
- catalog *softwareupdate.Service
- downloader Downloader
- installer Installer
- bus events.Bus
- notifier Notifier
- now func() time.Time
- scheduleTicker *time.Ticker
- cancelSchedule context.CancelFunc
- downloadURLs []string
- downloadSHA256 string
+ mu sync.Mutex
+ state update.Info
+ catalog *softwareupdate.Service
+ downloader Downloader
+ installer Installer
+ bus events.Bus
+ notifier Notifier
+ now func() time.Time
+ scheduleTicker *time.Ticker
+ cancelSchedule context.CancelFunc
+ downloadURLs []string
+ downloadSHA256 string
+ autoPrepareInFlight bool
}
type ServiceParams struct {
@@ -89,9 +102,15 @@ func (service *Service) PublishCurrentState() {
func (service *Service) SetCurrentVersion(version string) {
service.mu.Lock()
service.state.CurrentVersion = update.NormalizeVersion(version)
+ if service.state.CurrentVersion != "" &&
+ service.state.PreparedVersion != "" &&
+ update.CompareVersion(service.state.CurrentVersion, service.state.PreparedVersion) >= 0 {
+ service.clearPreparedStateLocked()
+ }
if service.state.CurrentVersion != "" &&
service.state.LatestVersion != "" &&
- update.CompareVersion(service.state.CurrentVersion, service.state.LatestVersion) >= 0 {
+ update.CompareVersion(service.state.CurrentVersion, service.state.LatestVersion) >= 0 &&
+ !service.state.HasPreparedUpdate() {
service.state.Status = update.StatusIdle
service.state.Progress = 0
service.state.DownloadURL = ""
@@ -102,11 +121,117 @@ func (service *Service) SetCurrentVersion(version string) {
service.mu.Unlock()
}
+func (service *Service) RestorePreparedUpdate(ctx context.Context) (update.Info, error) {
+ inspector, ok := service.installer.(preparedUpdateInspector)
+ if !ok || inspector == nil {
+ return service.State(), nil
+ }
+
+ prepared, found, err := inspector.PreparedUpdate(ctx)
+ if err != nil {
+ return service.State(), err
+ }
+ if !found {
+ return service.State(), nil
+ }
+
+ preparedVersion := update.NormalizeVersion(prepared.PreparedVersion)
+ currentVersion := update.NormalizeVersion(service.State().CurrentVersion)
+ if preparedVersion != "" && currentVersion != "" && update.CompareVersion(currentVersion, preparedVersion) >= 0 {
+ if clearErr := inspector.ClearPreparedUpdate(ctx); clearErr != nil {
+ zap.L().Warn("update: clear stale prepared update failed", zap.Error(clearErr))
+ }
+ service.mu.Lock()
+ service.clearPreparedStateLocked()
+ state := service.state
+ service.mu.Unlock()
+ return state, nil
+ }
+
+ service.mu.Lock()
+ service.state.Kind = update.KindApp
+ if strings.TrimSpace(service.state.LatestVersion) == "" ||
+ update.CompareVersion(preparedVersion, service.state.LatestVersion) >= 0 {
+ service.state.LatestVersion = preparedVersion
+ service.state.Changelog = prepared.PreparedChangelog
+ }
+ service.state.PreparedVersion = preparedVersion
+ service.state.PreparedChangelog = prepared.PreparedChangelog
+ service.setPreparedReadyLocked()
+ state := service.state
+ service.mu.Unlock()
+ service.notifyAvailability(true)
+ return state, nil
+}
+
+func (service *Service) GetWhatsNew(ctx context.Context) (update.WhatsNew, error) {
+ store, ok := service.installer.(whatsNewStore)
+ if !ok || store == nil {
+ return update.WhatsNew{}, nil
+ }
+
+ currentVersion := update.NormalizeVersion(service.State().CurrentVersion)
+ if !isReleaseVersion(currentVersion) {
+ return update.WhatsNew{}, nil
+ }
+
+ seenVersion, err := store.SeenWhatsNewVersion(ctx)
+ if err != nil {
+ return update.WhatsNew{}, err
+ }
+ seenVersion = update.NormalizeVersion(seenVersion)
+
+ pending, found, err := store.PendingWhatsNew(ctx)
+ if err != nil {
+ return update.WhatsNew{}, err
+ }
+ if found {
+ pendingVersion := update.NormalizeVersion(pending.Version)
+ switch {
+ case pendingVersion != "" &&
+ update.CompareVersion(currentVersion, pendingVersion) == 0 &&
+ (seenVersion == "" || update.CompareVersion(pendingVersion, seenVersion) > 0):
+ pending.CurrentVersion = currentVersion
+ if strings.TrimSpace(pending.Changelog) == "" {
+ pending.Changelog = service.resolveReleaseNotes(ctx, pendingVersion)
+ }
+ return pending, nil
+ case pendingVersion != "" && update.CompareVersion(currentVersion, pendingVersion) > 0:
+ // A newer version is already running; ignore the older pending notice.
+ case pendingVersion != "" && update.CompareVersion(currentVersion, pendingVersion) < 0:
+ return update.WhatsNew{}, nil
+ }
+ }
+
+ if seenVersion != "" && update.CompareVersion(currentVersion, seenVersion) <= 0 {
+ return update.WhatsNew{}, nil
+ }
+
+ return update.WhatsNew{
+ Version: currentVersion,
+ CurrentVersion: currentVersion,
+ Changelog: service.resolveReleaseNotes(ctx, currentVersion),
+ }, nil
+}
+
+func (service *Service) DismissWhatsNew(ctx context.Context, version string) error {
+ store, ok := service.installer.(whatsNewStore)
+ if !ok || store == nil {
+ return nil
+ }
+ return store.MarkWhatsNewSeen(ctx, update.NormalizeVersion(version))
+}
+
func (service *Service) CheckForUpdate(ctx context.Context, currentVersion string) (update.Info, error) {
service.mu.Lock()
if currentVersion != "" {
service.state.CurrentVersion = update.NormalizeVersion(currentVersion)
}
+ if service.state.Status == update.StatusDownloading || service.state.Status == update.StatusInstalling {
+ state := service.state
+ service.mu.Unlock()
+ return state, nil
+ }
service.setStatusLocked(update.StatusChecking, 0, "")
state := service.state
service.mu.Unlock()
@@ -125,18 +250,12 @@ func (service *Service) CheckForUpdate(ctx context.Context, currentVersion strin
})
if err != nil {
- service.mu.Lock()
- service.setStatusLocked(update.StatusError, service.state.Progress, err.Error())
- state := service.state
- service.mu.Unlock()
- go service.publishSnapshot(state)
- return state, err
+ return service.publishCheckError(err)
}
downloadURLs := service.selectDownloadURLs(ctx, release.Asset.DownloadURLs())
service.mu.Lock()
- defer service.mu.Unlock()
latest := update.NormalizeVersion(release.Version)
current := update.NormalizeVersion(service.state.CurrentVersion)
service.state.LatestVersion = latest
@@ -155,43 +274,71 @@ func (service *Service) CheckForUpdate(ctx context.Context, currentVersion strin
)
if current != "" && latest != "" && update.CompareVersion(current, latest) >= 0 {
+ if service.state.HasPreparedUpdate() {
+ service.setPreparedReadyLocked()
+ state := service.state
+ service.mu.Unlock()
+ service.notifyAvailability(true)
+ service.publishSnapshot(state)
+ return state, nil
+ }
service.setStatusLocked(update.StatusNoUpdate, 0, "")
state := service.state
service.downloadURLs = nil
service.downloadSHA256 = ""
+ service.mu.Unlock()
service.notifyAvailability(false)
- go service.publishSnapshot(state)
+ service.publishSnapshot(state)
return state, nil
}
if service.state.DownloadURL == "" {
- service.setStatusLocked(update.StatusError, service.state.Progress, "no downloadable asset for update")
+ service.mu.Unlock()
+ return service.publishPrepareError(fmt.Errorf("no downloadable asset for update"), service.capturePreparedFallback())
+ }
+
+ if service.state.HasPreparedUpdate() &&
+ update.CompareVersion(service.state.PreparedVersion, latest) == 0 {
+ service.setPreparedReadyLocked()
state := service.state
- go service.publishSnapshot(state)
- return state, fmt.Errorf("no downloadable asset for update")
+ service.mu.Unlock()
+ service.notifyAvailability(true)
+ service.publishSnapshot(state)
+ return state, nil
}
service.setStatusLocked(update.StatusAvailable, 0, "")
state = service.state
+ shouldAutoPrepare := service.shouldAutoPrepareLocked()
+ service.mu.Unlock()
service.notifyAvailability(true)
- go service.publishSnapshot(state)
+ service.publishSnapshot(state)
+ if shouldAutoPrepare {
+ service.scheduleAutoPrepare()
+ }
return state, nil
}
func (service *Service) DownloadUpdate(ctx context.Context) (update.Info, error) {
service.mu.Lock()
+ if service.state.Status == update.StatusDownloading || service.state.Status == update.StatusInstalling {
+ state := service.state
+ service.mu.Unlock()
+ return state, nil
+ }
downloadURLs := service.resolveDownloadURLsLocked()
expectedSHA256 := service.downloadSHA256
+ fallback := service.capturePreparedFallbackLocked()
service.setStatusLocked(update.StatusDownloading, 0, "")
state := service.state
service.mu.Unlock()
service.publishSnapshot(state)
if len(downloadURLs) == 0 {
- return service.publishError(fmt.Errorf("missing download url"))
+ return service.publishPrepareError(fmt.Errorf("missing download url"), fallback)
}
if service.downloader == nil {
- return service.publishError(fmt.Errorf("downloader not configured"))
+ return service.publishPrepareError(fmt.Errorf("downloader not configured"), fallback)
}
var path string
@@ -214,7 +361,7 @@ func (service *Service) DownloadUpdate(ctx context.Context) (update.Info, error)
zap.L().Warn("update: download source failed", zap.String("url", downloadURL), zap.Error(err))
}
if err != nil {
- return service.publishError(err)
+ return service.publishPrepareError(err, fallback)
}
service.mu.Lock()
@@ -226,12 +373,14 @@ func (service *Service) DownloadUpdate(ctx context.Context) (update.Info, error)
service.publishSnapshot(installingState)
if service.installer != nil {
- if err := service.installer.Install(ctx, path); err != nil {
- return service.publishError(err)
+ if err := service.installer.Install(ctx, path, installingState); err != nil {
+ return service.publishPrepareError(err, fallback)
}
}
service.mu.Lock()
+ service.state.PreparedVersion = update.NormalizeVersion(service.state.LatestVersion)
+ service.state.PreparedChangelog = service.state.Changelog
service.setStatusLocked(update.StatusReadyToRestart, 100, "")
finalState := service.state
service.mu.Unlock()
@@ -249,6 +398,7 @@ func (service *Service) RestartToApply(ctx context.Context) (update.Info, error)
return service.publishError(err)
}
service.mu.Lock()
+ service.clearPreparedStateLocked()
service.setStatusLocked(update.StatusIdle, 0, "")
service.downloadURLs = nil
service.downloadSHA256 = ""
@@ -314,6 +464,67 @@ func (service *Service) safeCheck(ctx context.Context, currentVersion string) {
_, _ = service.CheckForUpdate(ctx, currentVersion) // errors already published
}
+func (service *Service) scheduleAutoPrepare() {
+ service.mu.Lock()
+ if !service.shouldAutoPrepareLocked() {
+ service.mu.Unlock()
+ return
+ }
+ latestVersion := service.state.LatestVersion
+ service.autoPrepareInFlight = true
+ service.mu.Unlock()
+
+ go func() {
+ defer service.finishAutoPrepare()
+ if _, err := service.DownloadUpdate(context.Background()); err != nil {
+ zap.L().Warn("update: auto-prepare failed", zap.String("latestVersion", latestVersion), zap.Error(err))
+ }
+ }()
+}
+
+func (service *Service) finishAutoPrepare() {
+ service.mu.Lock()
+ service.autoPrepareInFlight = false
+ service.mu.Unlock()
+}
+
+func (service *Service) publishCheckError(err error) (update.Info, error) {
+ service.mu.Lock()
+ service.state.CheckedAt = service.now()
+ if service.state.HasPreparedUpdate() {
+ service.setPreparedReadyLocked()
+ state := service.state
+ service.mu.Unlock()
+ service.notifyAvailability(true)
+ service.publishSnapshot(state)
+ return state, err
+ }
+ service.setStatusLocked(update.StatusError, service.state.Progress, err.Error())
+ state := service.state
+ service.mu.Unlock()
+ service.publishSnapshot(state)
+ return state, err
+}
+
+func (service *Service) publishPrepareError(err error, fallback preparedFallback) (update.Info, error) {
+ service.mu.Lock()
+ if fallback.HasPreparedUpdate(service.state.CurrentVersion) {
+ service.state.PreparedVersion = fallback.Version
+ service.state.PreparedChangelog = fallback.Changelog
+ service.setPreparedReadyLocked()
+ state := service.state
+ service.mu.Unlock()
+ service.notifyAvailability(true)
+ service.publishSnapshot(state)
+ return state, err
+ }
+ service.setStatusLocked(update.StatusError, service.state.Progress, err.Error())
+ state := service.state
+ service.mu.Unlock()
+ service.publishSnapshot(state)
+ return state, err
+}
+
func (service *Service) publishError(err error) (update.Info, error) {
service.mu.Lock()
service.setStatusLocked(update.StatusError, service.state.Progress, err.Error())
@@ -331,6 +542,35 @@ func (service *Service) setStatusLocked(status update.Status, progress int, mess
service.state.Message = message
}
+func (service *Service) setPreparedReadyLocked() {
+ service.state.Status = update.StatusReadyToRestart
+ service.state.Progress = 100
+ service.state.Message = ""
+}
+
+func (service *Service) clearPreparedStateLocked() {
+ service.state.PreparedVersion = ""
+ service.state.PreparedChangelog = ""
+}
+
+func (service *Service) shouldAutoPrepareLocked() bool {
+ if service.autoPrepareInFlight {
+ return false
+ }
+ if service.state.Status != update.StatusAvailable {
+ return false
+ }
+ if strings.TrimSpace(service.state.DownloadURL) == "" {
+ return false
+ }
+ latestVersion := update.NormalizeVersion(service.state.LatestVersion)
+ if latestVersion == "" {
+ return false
+ }
+ preparedVersion := update.NormalizeVersion(service.state.PreparedVersion)
+ return preparedVersion == "" || update.CompareVersion(latestVersion, preparedVersion) > 0
+}
+
func (service *Service) publishState() {
service.mu.Lock()
state := service.state
@@ -368,6 +608,62 @@ func (service *Service) resolveDownloadURLsLocked() []string {
return []string{service.state.DownloadURL}
}
+type preparedFallback struct {
+ Version string
+ Changelog string
+}
+
+func (fallback preparedFallback) HasPreparedUpdate(currentVersion string) bool {
+ preparedVersion := update.NormalizeVersion(fallback.Version)
+ current := update.NormalizeVersion(currentVersion)
+ return preparedVersion != "" && update.CompareVersion(preparedVersion, current) > 0
+}
+
+func (service *Service) capturePreparedFallback() preparedFallback {
+ service.mu.Lock()
+ defer service.mu.Unlock()
+ return service.capturePreparedFallbackLocked()
+}
+
+func (service *Service) capturePreparedFallbackLocked() preparedFallback {
+ return preparedFallback{
+ Version: service.state.PreparedVersion,
+ Changelog: service.state.PreparedChangelog,
+ }
+}
+
+func (service *Service) resolveReleaseNotes(ctx context.Context, version string) string {
+ if service.catalog == nil || !isReleaseVersion(version) {
+ return ""
+ }
+ release, err := service.catalog.ResolveAppReleaseByVersion(ctx, version)
+ if err != nil {
+ zap.L().Warn("update: resolve release notes failed",
+ zap.String("version", version),
+ zap.Error(err),
+ )
+ return ""
+ }
+ return strings.TrimSpace(release.Notes)
+}
+
+func isReleaseVersion(version string) bool {
+ normalized := update.NormalizeVersion(version)
+ if normalized == "" {
+ return false
+ }
+ parts := strings.Split(normalized, ".")
+ for _, part := range parts {
+ if part == "" {
+ return false
+ }
+ if _, err := strconv.Atoi(part); err != nil {
+ return false
+ }
+ }
+ return true
+}
+
func normalizeSHA256(raw string) string {
value := strings.ToLower(strings.TrimSpace(raw))
value = strings.TrimPrefix(value, "sha256:")
diff --git a/internal/application/update/service_test.go b/internal/application/update/service_test.go
index 9d96103..0cfd4dd 100644
--- a/internal/application/update/service_test.go
+++ b/internal/application/update/service_test.go
@@ -38,9 +38,16 @@ type installerStub struct {
restarted bool
selectedDownloadURLs []string
selectDownloadInvoked bool
+ preparedInfo domainupdate.Info
+ hasPreparedInfo bool
+ clearPreparedInvoked bool
+ pendingWhatsNew domainupdate.WhatsNew
+ hasPendingWhatsNew bool
+ seenWhatsNewVersion string
+ markSeenVersion string
}
-func (stub installerStub) Install(_ context.Context, _ string) error {
+func (stub installerStub) Install(_ context.Context, _ string, _ domainupdate.Info) error {
return stub.installErr
}
@@ -57,6 +64,28 @@ func (stub *installerStub) SelectDownloadURLs(_ context.Context, urls []string)
return urls
}
+func (stub *installerStub) PreparedUpdate(_ context.Context) (domainupdate.Info, bool, error) {
+ return stub.preparedInfo, stub.hasPreparedInfo, nil
+}
+
+func (stub *installerStub) ClearPreparedUpdate(_ context.Context) error {
+ stub.clearPreparedInvoked = true
+ return nil
+}
+
+func (stub *installerStub) PendingWhatsNew(_ context.Context) (domainupdate.WhatsNew, bool, error) {
+ return stub.pendingWhatsNew, stub.hasPendingWhatsNew, nil
+}
+
+func (stub *installerStub) SeenWhatsNewVersion(_ context.Context) (string, error) {
+ return stub.seenWhatsNewVersion, nil
+}
+
+func (stub *installerStub) MarkWhatsNewSeen(_ context.Context, version string) error {
+ stub.markSeenVersion = version
+ return nil
+}
+
func (stub *catalogProviderStub) FetchCatalog(_ context.Context, _ softwareupdate.Request) (softwareupdate.Catalog, error) {
stub.fetchCount++
if stub.err != nil {
@@ -71,6 +100,25 @@ func newCatalogService(provider *catalogProviderStub) *softwareupdate.Service {
})
}
+type appVersionProviderStub struct {
+ release softwareupdate.AppRelease
+ err error
+}
+
+func (stub appVersionProviderStub) FetchAppRelease(_ context.Context, _ softwareupdate.AppRequest) (softwareupdate.AppRelease, error) {
+ if stub.err != nil {
+ return softwareupdate.AppRelease{}, stub.err
+ }
+ return stub.release, nil
+}
+
+func (stub appVersionProviderStub) FetchAppReleaseByVersion(_ context.Context, _ string) (softwareupdate.AppRelease, error) {
+ if stub.err != nil {
+ return softwareupdate.AppRelease{}, stub.err
+ }
+ return stub.release, nil
+}
+
func buildCatalog(version string, downloadURL string) softwareupdate.Catalog {
return softwareupdate.Catalog{
App: &softwareupdate.AppRelease{
@@ -309,3 +357,224 @@ func TestRestartToApplyPublishesErrorWhenInstallerFails(t *testing.T) {
t.Fatalf("expected error message %q, got %q", restartErr.Error(), info.Message)
}
}
+
+func TestRestorePreparedUpdateRestoresReadyState(t *testing.T) {
+ t.Parallel()
+
+ installer := &installerStub{
+ hasPreparedInfo: true,
+ preparedInfo: domainupdate.Info{
+ PreparedVersion: "1.2.4",
+ PreparedChangelog: "Bug fixes",
+ },
+ }
+ service := NewService(ServiceParams{Installer: installer})
+ service.SetCurrentVersion("1.2.3")
+
+ info, err := service.RestorePreparedUpdate(context.Background())
+ if err != nil {
+ t.Fatalf("restore prepared update failed: %v", err)
+ }
+ if info.Status != domainupdate.StatusReadyToRestart {
+ t.Fatalf("expected ready_to_restart status, got %q", info.Status)
+ }
+ if info.PreparedVersion != "1.2.4" {
+ t.Fatalf("expected prepared version 1.2.4, got %q", info.PreparedVersion)
+ }
+ if info.PreparedChangelog != "Bug fixes" {
+ t.Fatalf("expected prepared changelog to be restored, got %q", info.PreparedChangelog)
+ }
+}
+
+func TestRestorePreparedUpdateClearsStalePreparedPlan(t *testing.T) {
+ t.Parallel()
+
+ installer := &installerStub{
+ hasPreparedInfo: true,
+ preparedInfo: domainupdate.Info{
+ PreparedVersion: "1.2.3",
+ },
+ }
+ service := NewService(ServiceParams{Installer: installer})
+ service.SetCurrentVersion("1.2.3")
+
+ info, err := service.RestorePreparedUpdate(context.Background())
+ if err != nil {
+ t.Fatalf("restore prepared update failed: %v", err)
+ }
+ if !installer.clearPreparedInvoked {
+ t.Fatal("expected stale prepared update to be cleared")
+ }
+ if info.PreparedVersion != "" {
+ t.Fatalf("expected prepared version to be cleared, got %q", info.PreparedVersion)
+ }
+}
+
+func TestCheckForUpdateAutoPreparesLatestVersion(t *testing.T) {
+ t.Parallel()
+
+ file, err := os.CreateTemp(t.TempDir(), "update-*.zip")
+ if err != nil {
+ t.Fatalf("create temp file failed: %v", err)
+ }
+ if err := file.Close(); err != nil {
+ t.Fatalf("close temp file failed: %v", err)
+ }
+
+ provider := &catalogProviderStub{
+ catalog: buildCatalog("1.2.4", "https://example.com/download.zip"),
+ }
+ service := NewService(ServiceParams{
+ Catalog: newCatalogService(provider),
+ Downloader: &downloaderStub{path: file.Name()},
+ Installer: &installerStub{},
+ })
+
+ if _, err := service.CheckForUpdate(context.Background(), "1.2.3"); err != nil {
+ t.Fatalf("check for update failed: %v", err)
+ }
+
+ deadline := time.Now().Add(2 * time.Second)
+ for time.Now().Before(deadline) {
+ info := service.State()
+ if info.Status == domainupdate.StatusReadyToRestart {
+ if info.PreparedVersion != "1.2.4" {
+ t.Fatalf("expected prepared version 1.2.4, got %q", info.PreparedVersion)
+ }
+ return
+ }
+ time.Sleep(20 * time.Millisecond)
+ }
+
+ t.Fatalf("expected auto prepare to reach ready_to_restart, got %q", service.State().Status)
+}
+
+func TestCheckForUpdatePreservesPreparedStateWhenRefreshFails(t *testing.T) {
+ t.Parallel()
+
+ provider := &catalogProviderStub{err: errors.New("manifest unavailable")}
+ service := NewService(ServiceParams{Catalog: newCatalogService(provider)})
+ service.state = domainupdate.Info{
+ Kind: domainupdate.KindApp,
+ CurrentVersion: "1.2.3",
+ LatestVersion: "1.2.4",
+ PreparedVersion: "1.2.4",
+ PreparedChangelog: "Bug fixes",
+ Status: domainupdate.StatusReadyToRestart,
+ Progress: 100,
+ }
+
+ info, err := service.CheckForUpdate(context.Background(), "1.2.3")
+ if err == nil {
+ t.Fatal("expected check to fail")
+ }
+ if info.Status != domainupdate.StatusReadyToRestart {
+ t.Fatalf("expected ready_to_restart status, got %q", info.Status)
+ }
+ if info.PreparedVersion != "1.2.4" {
+ t.Fatalf("expected prepared version to be preserved, got %q", info.PreparedVersion)
+ }
+}
+
+func TestDownloadUpdateRestoresPreviousPreparedVersionWhenNewerPrepareFails(t *testing.T) {
+ t.Parallel()
+
+ installerErr := errors.New("prepare latest failed")
+ service := NewService(ServiceParams{
+ Downloader: &downloaderStub{path: "/tmp/dreamcreator-update.zip"},
+ Installer: &installerStub{installErr: installerErr},
+ })
+ service.state = domainupdate.Info{
+ Kind: domainupdate.KindApp,
+ CurrentVersion: "1.2.3",
+ LatestVersion: "1.2.5",
+ PreparedVersion: "1.2.4",
+ PreparedChangelog: "Prepared 1.2.4",
+ DownloadURL: "https://example.com/dreamcreator-update.zip",
+ Status: domainupdate.StatusAvailable,
+ }
+
+ info, err := service.DownloadUpdate(context.Background())
+ if !errors.Is(err, installerErr) {
+ t.Fatalf("expected installer error, got %v", err)
+ }
+ if info.Status != domainupdate.StatusReadyToRestart {
+ t.Fatalf("expected ready_to_restart status, got %q", info.Status)
+ }
+ if info.PreparedVersion != "1.2.4" {
+ t.Fatalf("expected prepared version 1.2.4 to be preserved, got %q", info.PreparedVersion)
+ }
+ if info.LatestVersion != "1.2.5" {
+ t.Fatalf("expected latest version 1.2.5 to stay visible, got %q", info.LatestVersion)
+ }
+}
+
+func TestGetWhatsNewReturnsPendingPreparedNoticeForCurrentVersion(t *testing.T) {
+ t.Parallel()
+
+ installer := &installerStub{
+ hasPendingWhatsNew: true,
+ pendingWhatsNew: domainupdate.WhatsNew{
+ Version: "2.0.7",
+ Changelog: "## Prepared update",
+ },
+ seenWhatsNewVersion: "2.0.6",
+ }
+ service := NewService(ServiceParams{Installer: installer})
+ service.SetCurrentVersion("2.0.7")
+
+ notice, err := service.GetWhatsNew(context.Background())
+ if err != nil {
+ t.Fatalf("GetWhatsNew failed: %v", err)
+ }
+ if notice.Version != "2.0.7" {
+ t.Fatalf("expected version 2.0.7, got %q", notice.Version)
+ }
+ if notice.Changelog != "## Prepared update" {
+ t.Fatalf("expected prepared changelog, got %q", notice.Changelog)
+ }
+}
+
+func TestGetWhatsNewFallsBackToCurrentVersionReleaseNotes(t *testing.T) {
+ t.Parallel()
+
+ installer := &installerStub{seenWhatsNewVersion: "2.0.6"}
+ catalog := softwareupdate.NewService(softwareupdate.ServiceParams{
+ AppFallbackProvider: appVersionProviderStub{
+ release: softwareupdate.AppRelease{
+ Version: "2.0.7",
+ Notes: "## Current release notes",
+ },
+ },
+ })
+ service := NewService(ServiceParams{
+ Catalog: catalog,
+ Installer: installer,
+ })
+ service.SetCurrentVersion("2.0.7")
+
+ notice, err := service.GetWhatsNew(context.Background())
+ if err != nil {
+ t.Fatalf("GetWhatsNew failed: %v", err)
+ }
+ if notice.Version != "2.0.7" {
+ t.Fatalf("expected version 2.0.7, got %q", notice.Version)
+ }
+ if notice.Changelog != "## Current release notes" {
+ t.Fatalf("expected current release notes, got %q", notice.Changelog)
+ }
+}
+
+func TestDismissWhatsNewMarksSeenVersion(t *testing.T) {
+ t.Parallel()
+
+ installer := &installerStub{}
+ service := NewService(ServiceParams{Installer: installer})
+
+ if err := service.DismissWhatsNew(context.Background(), "2.0.7"); err != nil {
+ t.Fatalf("DismissWhatsNew failed: %v", err)
+ }
+ if installer.markSeenVersion != "2.0.7" {
+ t.Fatalf("expected seen version 2.0.7, got %q", installer.markSeenVersion)
+ }
+}
diff --git a/internal/domain/connectors/connector.go b/internal/domain/connectors/connector.go
index fe60756..9cc8907 100644
--- a/internal/domain/connectors/connector.go
+++ b/internal/domain/connectors/connector.go
@@ -9,6 +9,10 @@ type ConnectorType string
const (
ConnectorGoogle ConnectorType = "google"
+ ConnectorGitHub ConnectorType = "github"
+ ConnectorReddit ConnectorType = "reddit"
+ ConnectorZhihu ConnectorType = "zhihu"
+ ConnectorX ConnectorType = "x"
ConnectorXiaohongshu ConnectorType = "xiaohongshu"
ConnectorBilibili ConnectorType = "bilibili"
)
diff --git a/internal/domain/connectors/errors.go b/internal/domain/connectors/errors.go
index 96a7122..6f214cc 100644
--- a/internal/domain/connectors/errors.go
+++ b/internal/domain/connectors/errors.go
@@ -3,7 +3,9 @@ package connectors
import "errors"
var (
- ErrConnectorNotFound = errors.New("connector not found")
- ErrInvalidConnector = errors.New("invalid connector")
- ErrNoCookies = errors.New("no cookies stored")
+ ErrConnectorNotFound = errors.New("connector not found")
+ ErrInvalidConnector = errors.New("invalid connector")
+ ErrNoCookies = errors.New("no cookies stored")
+ ErrConnectorSessionDead = errors.New("connector browser session ended")
+ ErrConnectorSessionGone = errors.New("connector session not found")
)
diff --git a/internal/domain/externaltools/tool.go b/internal/domain/externaltools/tool.go
index a54bd89..e77bfc0 100644
--- a/internal/domain/externaltools/tool.go
+++ b/internal/domain/externaltools/tool.go
@@ -8,11 +8,10 @@ import (
type ToolName string
const (
- ToolYTDLP ToolName = "yt-dlp"
- ToolFFmpeg ToolName = "ffmpeg"
- ToolBun ToolName = "bun"
- ToolClawHub ToolName = "clawhub"
- ToolPlaywright ToolName = "playwright"
+ ToolYTDLP ToolName = "yt-dlp"
+ ToolFFmpeg ToolName = "ffmpeg"
+ ToolBun ToolName = "bun"
+ ToolClawHub ToolName = "clawhub"
)
type ToolKind string
diff --git a/internal/domain/settings/calls.go b/internal/domain/settings/calls.go
index 0033ebe..1bcef84 100644
--- a/internal/domain/settings/calls.go
+++ b/internal/domain/settings/calls.go
@@ -1,27 +1,18 @@
package settings
-import domainweb "dreamcreator/internal/domain/web"
-
const DefaultBrowserColor = "#FF4500"
func DefaultCallsToolsConfig() map[string]any {
defaultWebFetch := map[string]any{
- "type": "builtin",
- "acceptMarkdown": true,
- "enableUserAgent": true,
- "userAgent": domainweb.DefaultBrowserRequestUserAgent,
- "acceptLanguage": domainweb.DefaultBrowserRequestAcceptLanguage,
- "playwright": map[string]any{
- "markdown": true,
- },
+ "headless": true,
+ "preferredBrowser": "chrome",
}
defaultBrowser := map[string]any{
- "enabled": true,
- "evaluateEnabled": true,
- "headless": false,
- "noSandbox": false,
+ "enabled": true,
+ "headless": true,
+ "preferredBrowser": "chrome",
"ssrfPolicy": map[string]any{
- "dangerouslyAllowPrivateNetwork": true,
+ "dangerouslyAllowPrivateNetwork": false,
},
}
return map[string]any{
diff --git a/internal/domain/settings/calls_test.go b/internal/domain/settings/calls_test.go
index 6c388fe..7454105 100644
--- a/internal/domain/settings/calls_test.go
+++ b/internal/domain/settings/calls_test.go
@@ -10,21 +10,14 @@ func TestDefaultCallsToolsConfigIncludesWebFetchDefaults(t *testing.T) {
if !ok || fetchRaw == nil {
t.Fatalf("expected web_fetch defaults")
}
- if fetchRaw["acceptMarkdown"] != true {
- t.Fatalf("expected acceptMarkdown default true, got %#v", fetchRaw["acceptMarkdown"])
+ if fetchRaw["preferredBrowser"] != "chrome" {
+ t.Fatalf("expected preferredBrowser default chrome, got %#v", fetchRaw["preferredBrowser"])
}
- if fetchRaw["acceptLanguage"] != "en-US,en;q=0.9" {
- t.Fatalf("expected acceptLanguage default en-US,en;q=0.9, got %#v", fetchRaw["acceptLanguage"])
+ if fetchRaw["headless"] != true {
+ t.Fatalf("expected web_fetch headless default true, got %#v", fetchRaw["headless"])
}
- if fetchRaw["type"] != "builtin" {
- t.Fatalf("expected type default builtin, got %#v", fetchRaw["type"])
- }
- playwrightRaw, ok := fetchRaw["playwright"].(map[string]any)
- if !ok || playwrightRaw == nil {
- t.Fatalf("expected playwright defaults")
- }
- if playwrightRaw["markdown"] != true {
- t.Fatalf("expected playwright markdown default true, got %#v", playwrightRaw["markdown"])
+ if _, exists := fetchRaw["type"]; exists {
+ t.Fatalf("expected legacy web_fetch type to be removed")
}
webRaw, ok := defaults["web"].(map[string]any)
if !ok || webRaw == nil {
@@ -46,19 +39,19 @@ func TestDefaultCallsToolsConfigIncludesBrowserDefaults(t *testing.T) {
if browserRaw["enabled"] != true {
t.Fatalf("expected browser enabled default true, got %#v", browserRaw["enabled"])
}
- if browserRaw["evaluateEnabled"] != true {
- t.Fatalf("expected browser evaluateEnabled default true, got %#v", browserRaw["evaluateEnabled"])
+ if browserRaw["headless"] != true {
+ t.Fatalf("expected browser headless default true, got %#v", browserRaw["headless"])
}
- if browserRaw["headless"] != false {
- t.Fatalf("expected browser headless default false, got %#v", browserRaw["headless"])
+ if browserRaw["preferredBrowser"] != "chrome" {
+ t.Fatalf("expected browser preferredBrowser default chrome, got %#v", browserRaw["preferredBrowser"])
}
ssrfRaw, ok := browserRaw["ssrfPolicy"].(map[string]any)
if !ok || ssrfRaw == nil {
t.Fatalf("expected browser ssrfPolicy defaults")
}
- if ssrfRaw["dangerouslyAllowPrivateNetwork"] != true {
+ if ssrfRaw["dangerouslyAllowPrivateNetwork"] != false {
t.Fatalf(
- "expected browser ssrfPolicy.dangerouslyAllowPrivateNetwork true, got %#v",
+ "expected browser ssrfPolicy.dangerouslyAllowPrivateNetwork false, got %#v",
ssrfRaw["dangerouslyAllowPrivateNetwork"],
)
}
@@ -72,11 +65,11 @@ func TestNormalizeToolsConfigAddsWebFetchDefaultsWhenMissing(t *testing.T) {
if !ok || fetchRaw == nil {
t.Fatalf("expected web_fetch config")
}
- if fetchRaw["acceptMarkdown"] != true {
- t.Fatalf("expected acceptMarkdown default true, got %#v", fetchRaw["acceptMarkdown"])
+ if fetchRaw["preferredBrowser"] != "chrome" {
+ t.Fatalf("expected preferredBrowser default, got %#v", fetchRaw["preferredBrowser"])
}
- if fetchRaw["userAgent"] != "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/605.1.15 (KHTML, like Gecko) Version/17.0 Safari/605.1.15" {
- t.Fatalf("expected userAgent default, got %#v", fetchRaw["userAgent"])
+ if fetchRaw["headless"] != true {
+ t.Fatalf("expected headless default true, got %#v", fetchRaw["headless"])
}
}
diff --git a/internal/domain/update/update.go b/internal/domain/update/update.go
index 7aebdae..c24125b 100644
--- a/internal/domain/update/update.go
+++ b/internal/domain/update/update.go
@@ -28,15 +28,23 @@ const (
)
type Info struct {
- Kind Kind
+ Kind Kind
+ CurrentVersion string
+ LatestVersion string
+ Changelog string
+ DownloadURL string
+ CheckedAt time.Time
+ Status Status
+ Progress int // 0-100
+ Message string
+ PreparedVersion string
+ PreparedChangelog string
+}
+
+type WhatsNew struct {
+ Version string
CurrentVersion string
- LatestVersion string
Changelog string
- DownloadURL string
- CheckedAt time.Time
- Status Status
- Progress int // 0-100
- Message string
}
func (info Info) IsUpdateAvailable() bool {
@@ -51,6 +59,18 @@ func (info Info) IsError() bool {
return info.Status == StatusError
}
+// HasPreparedUpdate reports whether PreparedVersion, after NormalizeVersion,
+// is non-empty and compares greater than the normalized CurrentVersion
+// according to CompareVersion.
+func (info Info) HasPreparedUpdate() bool {
+ current := NormalizeVersion(info.CurrentVersion)
+ prepared := NormalizeVersion(info.PreparedVersion)
+ return prepared != "" && CompareVersion(prepared, current) > 0
+}
+
+// HasRemoteUpdate reports whether LatestVersion, after NormalizeVersion, is
+// non-empty and compares greater than the normalized CurrentVersion according
+// to CompareVersion.
+func (info Info) HasRemoteUpdate() bool {
+ current := NormalizeVersion(info.CurrentVersion)
+ latest := NormalizeVersion(info.LatestVersion)
+ return latest != "" && CompareVersion(latest, current) > 0
+}
+
func NormalizeVersion(version string) string {
trimmed := strings.TrimSpace(version)
if trimmed == "" {
diff --git a/internal/infrastructure/update/app_fallback_provider.go b/internal/infrastructure/update/app_fallback_provider.go
index 713f056..1849472 100644
--- a/internal/infrastructure/update/app_fallback_provider.go
+++ b/internal/infrastructure/update/app_fallback_provider.go
@@ -12,6 +12,18 @@ func (client *GithubReleaseClient) FetchAppRelease(ctx context.Context, request
if err != nil {
return softwareupdate.AppRelease{}, err
}
+ return client.toAppRelease(release), nil
+}
+
+// FetchAppReleaseByVersion looks up the release for the given version with
+// FetchReleaseByVersion and converts it with toAppRelease. When the lookup
+// fails it returns a zero AppRelease and the lookup error unchanged.
+func (client *GithubReleaseClient) FetchAppReleaseByVersion(ctx context.Context, version string) (softwareupdate.AppRelease, error) {
+ release, err := client.FetchReleaseByVersion(ctx, version)
+ if err != nil {
+ return softwareupdate.AppRelease{}, err
+ }
+ return client.toAppRelease(release), nil
+}
+
+func (client *GithubReleaseClient) toAppRelease(release githubRelease) softwareupdate.AppRelease {
assetURL := selectAsset(release.Assets)
sources := make([]softwareupdate.DownloadSource, 0, 2)
if strings.TrimSpace(assetURL) != "" {
@@ -42,5 +54,5 @@ func (client *GithubReleaseClient) FetchAppRelease(ctx context.Context, request
Asset: softwareupdate.Asset{
Sources: sources,
},
- }, nil
+ }
}
diff --git a/internal/infrastructure/update/github_client.go b/internal/infrastructure/update/github_client.go
index 20229b7..75bc8d1 100644
--- a/internal/infrastructure/update/github_client.go
+++ b/internal/infrastructure/update/github_client.go
@@ -5,6 +5,7 @@ import (
"encoding/json"
"fmt"
"net/http"
+ "net/url"
"runtime"
"strings"
"time"
@@ -16,11 +17,12 @@ const (
)
type GithubReleaseClient struct {
- client *http.Client
+ client *http.Client
+ releasesURL string
}
func NewGithubReleaseClient(client *http.Client) *GithubReleaseClient {
- return &GithubReleaseClient{client: client}
+ return &GithubReleaseClient{client: client, releasesURL: defaultReleasesURL}
}
type githubRelease struct {
@@ -44,28 +46,82 @@ func (client *GithubReleaseClient) FetchLatestRelease(ctx context.Context) (gith
if client == nil || client.client == nil {
return githubRelease{}, fmt.Errorf("http client not configured")
}
- req, err := http.NewRequestWithContext(ctx, http.MethodGet, defaultReleasesURL+"/latest", nil)
+ req, err := http.NewRequestWithContext(ctx, http.MethodGet, client.baseReleasesURL()+"/latest", nil)
if err != nil {
return githubRelease{}, err
}
req.Header.Set("Accept", "application/vnd.github+json")
+ release, err := client.fetchRelease(req, "github latest release")
+ if err != nil {
+ return githubRelease{}, err
+ }
+ if strings.TrimSpace(release.TagName) == "" {
+ return githubRelease{}, fmt.Errorf("no latest release found")
+ }
+ return release, nil
+}
+
+// FetchReleaseByVersion fetches a single release by tag from
+// <baseReleasesURL>/tags/<tag>.
+//
+// The trimmed version is tried as the tag first. When it does not already
+// start with "v" or "V", the "v"-prefixed form is tried as a second candidate.
+// The first candidate that yields a release with a non-empty tag name is
+// returned; otherwise the error of the last attempt is returned.
+//
+// It returns an error when the client or its HTTP client is nil, when version
+// is blank, when a request cannot be built, or when every candidate fails.
+func (client *GithubReleaseClient) FetchReleaseByVersion(ctx context.Context, version string) (githubRelease, error) {
+	if client == nil || client.client == nil {
+		return githubRelease{}, fmt.Errorf("http client not configured")
+	}
+	normalized := strings.TrimSpace(version)
+	if normalized == "" {
+		return githubRelease{}, fmt.Errorf("release version is empty")
+	}
+
+	candidates := []string{normalized}
+	if !strings.HasPrefix(strings.ToLower(normalized), "v") {
+		candidates = append(candidates, "v"+normalized)
+	}
+
+	// candidates always holds at least one entry, so lastErr is guaranteed to
+	// be set whenever the loop finishes without returning a release.
+	var lastErr error
+	for _, candidate := range candidates {
+		req, err := http.NewRequestWithContext(
+			ctx,
+			http.MethodGet,
+			client.baseReleasesURL()+"/tags/"+url.PathEscape(candidate),
+			nil,
+		)
+		if err != nil {
+			return githubRelease{}, err
+		}
+		req.Header.Set("Accept", "application/vnd.github+json")
+
+		release, err := client.fetchRelease(req, fmt.Sprintf("github release %q", candidate))
+		// Mirror FetchLatestRelease: a 200 response without a tag name is not
+		// a usable release.
+		if err == nil && strings.TrimSpace(release.TagName) == "" {
+			err = fmt.Errorf("github release %q has no tag name", candidate)
+		}
+		if err == nil {
+			return release, nil
+		}
+		lastErr = err
+		// A cancelled or expired context would fail the next attempt as well,
+		// so stop instead of issuing another request.
+		if ctx.Err() != nil {
+			break
+		}
+	}
+	return githubRelease{}, lastErr
+}
+
+// baseReleasesURL returns the client's releasesURL with surrounding whitespace
+// trimmed. It falls back to defaultReleasesURL when the receiver is nil or the
+// configured value is blank.
+func (client *GithubReleaseClient) baseReleasesURL() string {
+ if client == nil || strings.TrimSpace(client.releasesURL) == "" {
+ return defaultReleasesURL
+ }
+ return strings.TrimSpace(client.releasesURL)
+}
+
+func (client *GithubReleaseClient) fetchRelease(req *http.Request, label string) (githubRelease, error) {
resp, err := client.client.Do(req)
if err != nil {
return githubRelease{}, err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
- return githubRelease{}, fmt.Errorf("github latest release http %d", resp.StatusCode)
+ return githubRelease{}, fmt.Errorf("%s http %d", label, resp.StatusCode)
}
var release githubRelease
if err := json.NewDecoder(resp.Body).Decode(&release); err != nil {
return githubRelease{}, err
}
- if strings.TrimSpace(release.TagName) == "" {
- return githubRelease{}, fmt.Errorf("no latest release found")
- }
return release, nil
}
diff --git a/internal/infrastructure/update/installer.go b/internal/infrastructure/update/installer.go
index 84755d9..bd58811 100644
--- a/internal/infrastructure/update/installer.go
+++ b/internal/infrastructure/update/installer.go
@@ -15,6 +15,9 @@ import (
"slices"
"strconv"
"strings"
+ "time"
+
+ domainupdate "dreamcreator/internal/domain/update"
)
var ErrPreparedUpdateNotFound = fmt.Errorf("prepared update not found")
@@ -28,12 +31,14 @@ const (
)
type PlatformInstaller struct {
- stateDir string
- planPath string
- goos string
- goarch string
- executablePath func() (string, error)
- startDetached func(name string, args []string) error
+ stateDir string
+ planPath string
+ whatsNewPendingPath string
+ whatsNewSeenPath string
+ goos string
+ goarch string
+ executablePath func() (string, error)
+ startDetached func(name string, args []string) error
}
type stagedPlan struct {
@@ -45,6 +50,13 @@ type stagedPlan struct {
RelaunchPath string `json:"relaunchPath"`
FallbackPath string `json:"fallbackPath,omitempty"`
InstallDir string `json:"installDir,omitempty"`
+ Version string `json:"version,omitempty"`
+ Changelog string `json:"changelog,omitempty"`
+}
+
+// whatsNewSeenState is the JSON document stored at whatsNewSeenPath. It is
+// written by MarkWhatsNewSeen and read back by SeenWhatsNewVersion.
+type whatsNewSeenState struct {
+ // Version is the version whose notice was marked as seen.
+ Version string `json:"version"`
+ // SeenAt is written by MarkWhatsNewSeen as a UTC RFC 3339 timestamp; it is omitted from the JSON when empty.
+ SeenAt string `json:"seenAt,omitempty"`
+}
func NewInstaller(statePath string) (*PlatformInstaller, error) {
@@ -61,12 +73,14 @@ func NewInstaller(statePath string) (*PlatformInstaller, error) {
return nil, err
}
return &PlatformInstaller{
- stateDir: stateDir,
- planPath: filepath.Join(stateDir, "update_install_plan.json"),
- goos: runtime.GOOS,
- goarch: runtime.GOARCH,
- executablePath: os.Executable,
- startDetached: startDetachedCommand,
+ stateDir: stateDir,
+ planPath: filepath.Join(stateDir, "update_install_plan.json"),
+ whatsNewPendingPath: filepath.Join(stateDir, "pending_whats_new.json"),
+ whatsNewSeenPath: filepath.Join(stateDir, "whats_new_seen.json"),
+ goos: runtime.GOOS,
+ goarch: runtime.GOARCH,
+ executablePath: os.Executable,
+ startDetached: startDetachedCommand,
}, nil
}
@@ -84,7 +98,7 @@ func (installer *PlatformInstaller) SelectDownloadURLs(_ context.Context, urls [
return preferWindowsPortableDownloadURLs(urls)
}
-func (installer *PlatformInstaller) Install(ctx context.Context, artifactPath string) error {
+func (installer *PlatformInstaller) Install(ctx context.Context, artifactPath string, prepared domainupdate.Info) error {
if installer == nil {
return fmt.Errorf("installer not configured")
}
@@ -92,18 +106,29 @@ func (installer *PlatformInstaller) Install(ctx context.Context, artifactPath st
if normalizedArtifact == "" {
return fmt.Errorf("artifact path is empty")
}
- if err := installer.cleanupStagedUpdate(); err != nil {
+
+ previousPlan, err := installer.loadPlan()
+ hasPreviousPlan := err == nil
+ if err != nil && !errors.Is(err, ErrPreparedUpdateNotFound) {
return err
}
+ var installErr error
switch installer.goos {
case "windows":
- return installer.prepareWindowsUpdate(normalizedArtifact)
+ installErr = installer.prepareWindowsUpdate(normalizedArtifact, prepared)
case "darwin":
- return installer.prepareMacUpdate(ctx, normalizedArtifact)
+ installErr = installer.prepareMacUpdate(ctx, normalizedArtifact, prepared)
default:
- return fmt.Errorf("update install is not supported on %s", installer.goos)
+ installErr = fmt.Errorf("update install is not supported on %s", installer.goos)
+ }
+ if installErr != nil {
+ return installErr
}
+ if hasPreviousPlan && strings.TrimSpace(previousPlan.StageDir) != "" {
+ _ = os.RemoveAll(previousPlan.StageDir)
+ }
+ return nil
}
func (installer *PlatformInstaller) RestartToApply(_ context.Context) error {
@@ -125,7 +150,7 @@ func (installer *PlatformInstaller) RestartToApply(_ context.Context) error {
}
}
-func (installer *PlatformInstaller) prepareWindowsUpdate(artifactPath string) error {
+func (installer *PlatformInstaller) prepareWindowsUpdate(artifactPath string, prepared domainupdate.Info) error {
currentExe, err := installer.currentExecutable()
if err != nil {
return err
@@ -143,6 +168,8 @@ func (installer *PlatformInstaller) prepareWindowsUpdate(artifactPath string) er
TargetPath: currentExe,
RelaunchPath: currentExe,
InstallDir: filepath.Dir(currentExe),
+ Version: strings.TrimSpace(prepared.LatestVersion),
+ Changelog: prepared.Changelog,
}
switch strings.ToLower(filepath.Ext(artifactName)) {
@@ -168,7 +195,7 @@ func (installer *PlatformInstaller) prepareWindowsUpdate(artifactPath string) er
return installer.savePlan(plan)
}
-func (installer *PlatformInstaller) prepareMacUpdate(ctx context.Context, artifactPath string) error {
+func (installer *PlatformInstaller) prepareMacUpdate(ctx context.Context, artifactPath string, prepared domainupdate.Info) error {
currentExe, err := installer.currentExecutable()
if err != nil {
return err
@@ -204,9 +231,112 @@ func (installer *PlatformInstaller) prepareMacUpdate(ctx context.Context, artifa
TargetPath: targetBundle,
RelaunchPath: targetBundle,
FallbackPath: currentBundle,
+ Version: strings.TrimSpace(prepared.LatestVersion),
+ Changelog: prepared.Changelog,
})
}
+// PreparedUpdate describes the update recorded in the stored install plan.
+//
+// It returns (zero Info, false, nil) when loadPlan reports
+// ErrPreparedUpdateNotFound, and (zero Info, false, err) for a nil installer
+// or any other load failure. On success the bool is true and the Info carries
+// KindApp, StatusReadyToRestart and Progress 100, with the plan's trimmed
+// version in both LatestVersion and PreparedVersion and its changelog in both
+// Changelog and PreparedChangelog. CurrentVersion is not populated here.
+func (installer *PlatformInstaller) PreparedUpdate(_ context.Context) (domainupdate.Info, bool, error) {
+ if installer == nil {
+ return domainupdate.Info{}, false, fmt.Errorf("installer not configured")
+ }
+ plan, err := installer.loadPlan()
+ if err != nil {
+ // A missing plan is the normal "nothing prepared" case, not an error.
+ if errors.Is(err, ErrPreparedUpdateNotFound) {
+ return domainupdate.Info{}, false, nil
+ }
+ return domainupdate.Info{}, false, err
+ }
+ return domainupdate.Info{
+ Kind: domainupdate.KindApp,
+ Status: domainupdate.StatusReadyToRestart,
+ LatestVersion: strings.TrimSpace(plan.Version),
+ PreparedVersion: strings.TrimSpace(plan.Version),
+ Changelog: plan.Changelog,
+ PreparedChangelog: plan.Changelog,
+ Progress: 100,
+ }, true, nil
+}
+
+// ClearPreparedUpdate delegates to cleanupStagedUpdate and returns its result;
+// a nil installer yields an "installer not configured" error. The context
+// argument is unused.
+// NOTE(review): cleanupStagedUpdate is defined outside this hunk; presumably
+// it removes the staged files and the stored plan. Confirm there.
+func (installer *PlatformInstaller) ClearPreparedUpdate(_ context.Context) error {
+ if installer == nil {
+ return fmt.Errorf("installer not configured")
+ }
+ return installer.cleanupStagedUpdate()
+}
+
+// PendingWhatsNew reads the pending "what's new" notice stored at
+// whatsNewPendingPath. The file is decoded as a stagedPlan because the restart
+// helper scripts create it by copying the install plan to that path.
+//
+// It returns (notice, true, nil) when the file exists and carries a non-blank
+// version, and (zero, false, nil) when the file does not exist or its version
+// is blank. A nil installer, any other read failure, or invalid JSON is
+// returned as an error. The context argument is unused.
+func (installer *PlatformInstaller) PendingWhatsNew(_ context.Context) (domainupdate.WhatsNew, bool, error) {
+	if installer == nil {
+		return domainupdate.WhatsNew{}, false, fmt.Errorf("installer not configured")
+	}
+	data, err := os.ReadFile(installer.whatsNewPendingPath)
+	if err != nil {
+		// errors.Is also matches wrapped errors and is the form the rest of
+		// this file uses; a missing file simply means nothing is pending.
+		if errors.Is(err, os.ErrNotExist) {
+			return domainupdate.WhatsNew{}, false, nil
+		}
+		return domainupdate.WhatsNew{}, false, err
+	}
+
+	var plan stagedPlan
+	if err := json.Unmarshal(data, &plan); err != nil {
+		return domainupdate.WhatsNew{}, false, err
+	}
+	version := strings.TrimSpace(plan.Version)
+	if version == "" {
+		// A plan without a version cannot be matched against a seen version.
+		return domainupdate.WhatsNew{}, false, nil
+	}
+	return domainupdate.WhatsNew{
+		Version:   version,
+		Changelog: plan.Changelog,
+	}, true, nil
+}
+
+// SeenWhatsNewVersion returns the trimmed version stored in the JSON file at
+// whatsNewSeenPath.
+//
+// It returns ("", nil) when the file does not exist. A nil installer, any
+// other read failure, or invalid JSON is returned as an error. The context
+// argument is unused.
+func (installer *PlatformInstaller) SeenWhatsNewVersion(_ context.Context) (string, error) {
+	if installer == nil {
+		return "", fmt.Errorf("installer not configured")
+	}
+	data, err := os.ReadFile(installer.whatsNewSeenPath)
+	if err != nil {
+		// errors.Is also matches wrapped errors and is the form the rest of
+		// this file uses; a missing file means no version was recorded yet.
+		if errors.Is(err, os.ErrNotExist) {
+			return "", nil
+		}
+		return "", err
+	}
+	var seen whatsNewSeenState
+	if err := json.Unmarshal(data, &seen); err != nil {
+		return "", err
+	}
+	return strings.TrimSpace(seen.Version), nil
+}
+
+// MarkWhatsNewSeen records version as the last acknowledged "what's new"
+// notice and drops the pending notice when it is covered by that version.
+//
+// A blank version is a no-op. Otherwise the version and the current UTC time
+// (RFC 3339) are written as JSON to whatsNewSeenPath with mode 0600. The
+// pending notice file is then removed when its version compares less than or
+// equal to the recorded one; that removal is best effort and its error is
+// ignored.
+//
+// It returns an error for a nil installer, when the seen state cannot be
+// encoded or written, or when the pending notice cannot be read or decoded.
+// In the last case the seen state has already been written.
+func (installer *PlatformInstaller) MarkWhatsNewSeen(ctx context.Context, version string) error {
+	if installer == nil {
+		return fmt.Errorf("installer not configured")
+	}
+	normalized := strings.TrimSpace(version)
+	if normalized == "" {
+		return nil
+	}
+	data, err := json.MarshalIndent(whatsNewSeenState{
+		Version: normalized,
+		SeenAt:  time.Now().UTC().Format(time.RFC3339),
+	}, "", " ")
+	if err != nil {
+		return err
+	}
+	if err := os.WriteFile(installer.whatsNewSeenPath, data, 0o600); err != nil {
+		return err
+	}
+	// Pass the caller's context along instead of substituting a fresh
+	// background context.
+	pending, found, err := installer.PendingWhatsNew(ctx)
+	if err != nil {
+		return err
+	}
+	if found && domainupdate.CompareVersion(pending.Version, normalized) <= 0 {
+		// Best effort: a leftover pending file is covered by the seen version.
+		_ = os.Remove(installer.whatsNewPendingPath)
+	}
+	return nil
+}
+
func (installer *PlatformInstaller) restartWindows(plan stagedPlan) error {
scriptPath := filepath.Join(installer.stateDir, "apply_update.ps1")
if err := os.WriteFile(scriptPath, []byte(windowsApplyScript), 0o600); err != nil {
@@ -224,6 +354,7 @@ func (installer *PlatformInstaller) restartWindows(plan stagedPlan) error {
plan.InstallDir,
plan.StageDir,
installer.planPath,
+ installer.whatsNewPendingPath,
}
return installer.startDetached("powershell.exe", args)
}
@@ -252,6 +383,7 @@ func (installer *PlatformInstaller) restartDarwin(plan stagedPlan) error {
fallbackPath,
plan.StageDir,
installer.planPath,
+ installer.whatsNewPendingPath,
}
return installer.startDetached("/bin/sh", args)
}
@@ -563,7 +695,8 @@ const windowsApplyScript = `param(
[Parameter(Mandatory = $true)][string]$TargetPath,
[Parameter(Mandatory = $true)][string]$InstallDir,
[Parameter(Mandatory = $true)][string]$StageDir,
- [Parameter(Mandatory = $true)][string]$PlanPath
+ [Parameter(Mandatory = $true)][string]$PlanPath,
+ [Parameter(Mandatory = $true)][string]$PendingWhatsNewPath
)
$ErrorActionPreference = "Stop"
@@ -667,13 +800,18 @@ try {
}
}
+ try {
+ Copy-Item -LiteralPath $PlanPath -Destination $PendingWhatsNewPath -Force -ErrorAction Stop
+ } catch {
+ }
+
Start-Process -FilePath $TargetPath -WorkingDirectory $InstallDir | Out-Null
Remove-Item -LiteralPath $PlanPath -Force -ErrorAction SilentlyContinue
Remove-Item -LiteralPath $StageDir -Recurse -Force -ErrorAction SilentlyContinue
} catch {
try {
if (Test-Path -LiteralPath $TargetPath) {
- Start-Process -FilePath $TargetPath -WorkingDirectory $InstallDir | Out-Null
+ Start-Process -FilePath $TargetPath -ArgumentList @("--skip-prepared-update-once") -WorkingDirectory $InstallDir | Out-Null
}
} catch {
}
@@ -691,6 +829,7 @@ RELAUNCH_APP="$4"
FALLBACK_APP="$5"
STAGE_DIR="$6"
PLAN_PATH="$7"
+PENDING_WHATS_NEW_PATH="$8"
BACKUP_APP="${TARGET_APP}.old"
while kill -0 "$PARENT_PID" 2>/dev/null; do
@@ -699,8 +838,13 @@ done
relaunch_app() {
APP_PATH="$1"
+ shift || true
if [ -n "$APP_PATH" ] && [ -d "$APP_PATH" ]; then
- open "$APP_PATH" >/dev/null 2>&1 || true
+ if [ "$#" -gt 0 ]; then
+ open -a "$APP_PATH" --args "$@" >/dev/null 2>&1 || true
+ else
+ open "$APP_PATH" >/dev/null 2>&1 || true
+ fi
fi
}
@@ -712,9 +856,9 @@ restore_backup() {
}
relaunch_fallback() {
- relaunch_app "$FALLBACK_APP"
+ relaunch_app "$FALLBACK_APP" "--skip-prepared-update-once"
if [ "$FALLBACK_APP" != "$TARGET_APP" ]; then
- relaunch_app "$TARGET_APP"
+ relaunch_app "$TARGET_APP" "--skip-prepared-update-once"
fi
}
@@ -760,6 +904,7 @@ if ! install_direct; then
fi
/usr/bin/xattr -dr com.apple.quarantine "$TARGET_APP" >/dev/null 2>&1 || true
+cp "$PLAN_PATH" "$PENDING_WHATS_NEW_PATH" >/dev/null 2>&1 || true
if ! open "$RELAUNCH_APP"; then
relaunch_fallback
exit 1
@@ -770,7 +915,12 @@ rm -rf "$STAGE_DIR"
`
var _ interface {
- Install(context.Context, string) error
+ Install(context.Context, string, domainupdate.Info) error
RestartToApply(context.Context) error
SelectDownloadURLs(context.Context, []string) []string
+ PreparedUpdate(context.Context) (domainupdate.Info, bool, error)
+ ClearPreparedUpdate(context.Context) error
+ PendingWhatsNew(context.Context) (domainupdate.WhatsNew, bool, error)
+ SeenWhatsNewVersion(context.Context) (string, error)
+ MarkWhatsNewSeen(context.Context, string) error
} = (*PlatformInstaller)(nil)
diff --git a/internal/infrastructure/update/installer_test.go b/internal/infrastructure/update/installer_test.go
index d60232d..172486f 100644
--- a/internal/infrastructure/update/installer_test.go
+++ b/internal/infrastructure/update/installer_test.go
@@ -1,6 +1,8 @@
package update
import (
+ "context"
+ "os"
"path/filepath"
"testing"
)
@@ -34,8 +36,10 @@ func TestRestartDarwinUsesExplicitRelaunchAndFallbackPaths(t *testing.T) {
)
stateDir := t.TempDir()
installer := &PlatformInstaller{
- stateDir: stateDir,
- planPath: filepath.Join(stateDir, "update_install_plan.json"),
+ stateDir: stateDir,
+ planPath: filepath.Join(stateDir, "update_install_plan.json"),
+ whatsNewPendingPath: filepath.Join(stateDir, "pending_whats_new.json"),
+ whatsNewSeenPath: filepath.Join(stateDir, "whats_new_seen.json"),
startDetached: func(name string, args []string) error {
capturedName = name
capturedArgs = append([]string(nil), args...)
@@ -57,7 +61,7 @@ func TestRestartDarwinUsesExplicitRelaunchAndFallbackPaths(t *testing.T) {
if capturedName != "/bin/sh" {
t.Fatalf("unexpected restart helper: %q", capturedName)
}
- if len(capturedArgs) != 8 {
+ if len(capturedArgs) != 9 {
t.Fatalf("unexpected helper args: %#v", capturedArgs)
}
if capturedArgs[4] != plan.RelaunchPath {
@@ -66,6 +70,9 @@ func TestRestartDarwinUsesExplicitRelaunchAndFallbackPaths(t *testing.T) {
if capturedArgs[5] != plan.FallbackPath {
t.Fatalf("expected fallback path %q, got %q", plan.FallbackPath, capturedArgs[5])
}
+ if capturedArgs[8] != installer.whatsNewPendingPath {
+ t.Fatalf("expected pending what's new path %q, got %q", installer.whatsNewPendingPath, capturedArgs[8])
+ }
}
func TestRestartDarwinDefaultsRelaunchAndFallbackToTarget(t *testing.T) {
@@ -74,8 +81,10 @@ func TestRestartDarwinDefaultsRelaunchAndFallbackToTarget(t *testing.T) {
var capturedArgs []string
stateDir := t.TempDir()
installer := &PlatformInstaller{
- stateDir: stateDir,
- planPath: filepath.Join(stateDir, "update_install_plan.json"),
+ stateDir: stateDir,
+ planPath: filepath.Join(stateDir, "update_install_plan.json"),
+ whatsNewPendingPath: filepath.Join(stateDir, "pending_whats_new.json"),
+ whatsNewSeenPath: filepath.Join(stateDir, "whats_new_seen.json"),
startDetached: func(_ string, args []string) error {
capturedArgs = append([]string(nil), args...)
return nil
@@ -91,7 +100,7 @@ func TestRestartDarwinDefaultsRelaunchAndFallbackToTarget(t *testing.T) {
t.Fatalf("restartDarwin failed: %v", err)
}
- if len(capturedArgs) != 8 {
+ if len(capturedArgs) != 9 {
t.Fatalf("unexpected helper args: %#v", capturedArgs)
}
if capturedArgs[4] != plan.TargetPath {
@@ -100,4 +109,76 @@ func TestRestartDarwinDefaultsRelaunchAndFallbackToTarget(t *testing.T) {
if capturedArgs[5] != plan.TargetPath {
t.Fatalf("expected default fallback path %q, got %q", plan.TargetPath, capturedArgs[5])
}
+ if capturedArgs[8] != installer.whatsNewPendingPath {
+ t.Fatalf("expected pending what's new path %q, got %q", installer.whatsNewPendingPath, capturedArgs[8])
+ }
+}
+
+// TestPendingWhatsNewReadsCopiedPlan checks that PendingWhatsNew decodes a
+// pending notice file whose bytes are a copy of a saved install plan, keeping
+// the plan's version and changelog.
+func TestPendingWhatsNewReadsCopiedPlan(t *testing.T) {
+ t.Parallel()
+
+ stateDir := t.TempDir()
+ installer := &PlatformInstaller{
+ stateDir: stateDir,
+ planPath: filepath.Join(stateDir, "update_install_plan.json"),
+ whatsNewPendingPath: filepath.Join(stateDir, "pending_whats_new.json"),
+ whatsNewSeenPath: filepath.Join(stateDir, "whats_new_seen.json"),
+ }
+ if err := installer.savePlan(stagedPlan{
+ Version: "2.0.7",
+ Changelog: "## Updated",
+ }); err != nil {
+ t.Fatalf("savePlan failed: %v", err)
+ }
+ // Copy the plan file byte-for-byte to the pending path.
+ data, err := os.ReadFile(installer.planPath)
+ if err != nil {
+ t.Fatalf("read plan failed: %v", err)
+ }
+ if err := os.WriteFile(installer.whatsNewPendingPath, data, 0o600); err != nil {
+ t.Fatalf("write pending file failed: %v", err)
+ }
+
+ notice, found, err := installer.PendingWhatsNew(context.Background())
+ if err != nil {
+ t.Fatalf("PendingWhatsNew failed: %v", err)
+ }
+ if !found {
+ t.Fatal("expected pending what's new notice")
+ }
+ if notice.Version != "2.0.7" {
+ t.Fatalf("expected version 2.0.7, got %q", notice.Version)
+ }
+ if notice.Changelog != "## Updated" {
+ t.Fatalf("expected changelog to be preserved, got %q", notice.Changelog)
+ }
+}
+
+// TestMarkWhatsNewSeenPersistsVersionAndClearsCoveredPendingNotice checks that
+// marking 2.0.7 as seen makes SeenWhatsNewVersion return 2.0.7 and removes a
+// pending notice file that carries the same version.
+func TestMarkWhatsNewSeenPersistsVersionAndClearsCoveredPendingNotice(t *testing.T) {
+ t.Parallel()
+
+ stateDir := t.TempDir()
+ installer := &PlatformInstaller{
+ stateDir: stateDir,
+ planPath: filepath.Join(stateDir, "update_install_plan.json"),
+ whatsNewPendingPath: filepath.Join(stateDir, "pending_whats_new.json"),
+ whatsNewSeenPath: filepath.Join(stateDir, "whats_new_seen.json"),
+ }
+ // Seed a pending notice for the same version that is about to be marked seen.
+ if err := os.WriteFile(installer.whatsNewPendingPath, []byte(`{"version":"2.0.7","changelog":"hi"}`), 0o600); err != nil {
+ t.Fatalf("seed pending file failed: %v", err)
+ }
+
+ if err := installer.MarkWhatsNewSeen(context.Background(), "2.0.7"); err != nil {
+ t.Fatalf("MarkWhatsNewSeen failed: %v", err)
+ }
+
+ seenVersion, err := installer.SeenWhatsNewVersion(context.Background())
+ if err != nil {
+ t.Fatalf("SeenWhatsNewVersion failed: %v", err)
+ }
+ if seenVersion != "2.0.7" {
+ t.Fatalf("expected seen version 2.0.7, got %q", seenVersion)
+ }
+ // The pending file must be gone once its version is covered by the seen version.
+ if _, err := os.Stat(installer.whatsNewPendingPath); !os.IsNotExist(err) {
+ t.Fatalf("expected pending file to be removed, got err=%v", err)
+ }
}
diff --git a/internal/infrastructure/update/manifest_provider_test.go b/internal/infrastructure/update/manifest_provider_test.go
index b00905f..51f1750 100644
--- a/internal/infrastructure/update/manifest_provider_test.go
+++ b/internal/infrastructure/update/manifest_provider_test.go
@@ -22,7 +22,7 @@ func TestManifestCatalogProviderSelectsCurrentPlatformAssets(t *testing.T) {
"channels":{
"stable":{
"app":{
- "source":{"provider":"github-release","owner":"arnoldhao","repo":"dreamcreator"},
+ "source":{"provider":"github-release","owner":"example-owner","repo":"dreamcreator"},
"version":"1.3.0",
"publishedAt":"2026-04-06T00:00:00Z",
"platforms":{
diff --git a/internal/infrastructure/update/tool_fallback_provider.go b/internal/infrastructure/update/tool_fallback_provider.go
index e4f291c..831ed80 100644
--- a/internal/infrastructure/update/tool_fallback_provider.go
+++ b/internal/infrastructure/update/tool_fallback_provider.go
@@ -61,8 +61,6 @@ func (provider *ToolFallbackProvider) FetchToolRelease(ctx context.Context, requ
return provider.fetchGitHubToolRelease(ctx, request.Name, "oven-sh", "bun")
case externaltools.ToolClawHub:
return provider.fetchNPMPackageRelease(ctx, request.Name, "clawhub")
- case externaltools.ToolPlaywright:
- return softwareupdate.ToolRelease{}, softwareupdate.ErrReleaseNotFound
default:
return softwareupdate.ToolRelease{}, externaltools.ErrInvalidTool
}
diff --git a/internal/presentation/wails/connectors_handler.go b/internal/presentation/wails/connectors_handler.go
index 164f1d4..5451c57 100644
--- a/internal/presentation/wails/connectors_handler.go
+++ b/internal/presentation/wails/connectors_handler.go
@@ -36,21 +36,29 @@ func (handler *ConnectorsHandler) ClearConnector(ctx context.Context, request dt
return handler.service.ClearConnector(ctx, request)
}
-func (handler *ConnectorsHandler) ConnectConnector(ctx context.Context, request dto.ConnectConnectorRequest) (dto.Connector, error) {
- result, err := handler.service.ConnectConnector(ctx, request)
+func (handler *ConnectorsHandler) StartConnectorConnect(ctx context.Context, request dto.StartConnectorConnectRequest) (dto.StartConnectorConnectResult, error) {
+ return handler.service.StartConnectorConnect(ctx, request)
+}
+
+func (handler *ConnectorsHandler) FinishConnectorConnect(ctx context.Context, request dto.FinishConnectorConnectRequest) (dto.FinishConnectorConnectResult, error) {
+ result, err := handler.service.FinishConnectorConnect(ctx, request)
if err != nil {
- return dto.Connector{}, err
+ return dto.FinishConnectorConnectResult{}, err
}
- if handler.telemetry != nil && result.Status == "connected" {
- handler.telemetry.TrackConnectorConnected(ctx, result.Type)
+ if handler.telemetry != nil && result.Saved && result.Connector.Status == "connected" {
+ handler.telemetry.TrackConnectorConnected(ctx, result.Connector.Type)
}
return result, nil
}
-func (handler *ConnectorsHandler) OpenConnectorSite(ctx context.Context, request dto.OpenConnectorSiteRequest) error {
- return handler.service.OpenConnectorSite(ctx, request)
+func (handler *ConnectorsHandler) CancelConnectorConnect(ctx context.Context, request dto.CancelConnectorConnectRequest) error {
+ return handler.service.CancelConnectorConnect(ctx, request)
}
-func (handler *ConnectorsHandler) InstallPlaywright(ctx context.Context) error {
- return handler.service.InstallPlaywright(ctx)
+func (handler *ConnectorsHandler) GetConnectorConnectSession(ctx context.Context, request dto.GetConnectorConnectSessionRequest) (dto.ConnectorConnectSession, error) {
+ return handler.service.GetConnectorConnectSession(ctx, request)
+}
+
+func (handler *ConnectorsHandler) OpenConnectorSite(ctx context.Context, request dto.OpenConnectorSiteRequest) error {
+ return handler.service.OpenConnectorSite(ctx, request)
}
diff --git a/internal/presentation/wails/settings_handler.go b/internal/presentation/wails/settings_handler.go
index 3c15de9..7e08a92 100644
--- a/internal/presentation/wails/settings_handler.go
+++ b/internal/presentation/wails/settings_handler.go
@@ -6,6 +6,7 @@ import (
"strings"
"time"
+ gatewaytools "dreamcreator/internal/application/gateway/tools"
"dreamcreator/internal/application/settings/dto"
"dreamcreator/internal/application/settings/service"
"dreamcreator/internal/domain/settings"
@@ -84,6 +85,19 @@ func (handler *SettingsHandler) UpdateSettings(ctx context.Context, request dto.
zap.L().Info("proxy applied", proxyFields(updated.Proxy)...)
}
+ if hasPrevious && gatewaytools.BrowserToolRuntimeConfigChanged(previousSettings.Tools, updated.Tools) {
+ gatewaytools.CleanupAllBrowserToolSessions()
+ zap.L().Info(
+ "browser tool sessions reset after browser runtime settings change",
+ zap.Bool("previousEnabled", resolveSettingsBrowserBool(previousSettings.Tools, "enabled")),
+ zap.Bool("currentEnabled", resolveSettingsBrowserBool(updated.Tools, "enabled")),
+ zap.Bool("previousHeadless", resolveSettingsBrowserBool(previousSettings.Tools, "headless")),
+ zap.Bool("currentHeadless", resolveSettingsBrowserBool(updated.Tools, "headless")),
+ zap.String("previousPreferredBrowser", resolveSettingsBrowserString(previousSettings.Tools, "preferredBrowser")),
+ zap.String("currentPreferredBrowser", resolveSettingsBrowserString(updated.Tools, "preferredBrowser")),
+ )
+ }
+
handler.windows.ApplySettings(updated)
return updated, nil
}
@@ -250,6 +264,24 @@ func proxyFields(proxyDTO dto.Proxy) []zap.Field {
}
}
+func resolveSettingsBrowserBool(config map[string]any, key string) bool {
+ browser, ok := config["browser"].(map[string]any)
+ if !ok || browser == nil {
+ return false
+ }
+ value, _ := browser[key].(bool)
+ return value
+}
+
+func resolveSettingsBrowserString(config map[string]any, key string) string {
+ browser, ok := config["browser"].(map[string]any)
+ if !ok || browser == nil {
+ return ""
+ }
+ value, _ := browser[key].(string)
+ return strings.TrimSpace(value)
+}
+
func (handler *SettingsHandler) rollbackSettings(ctx context.Context, previous dto.Settings) {
_, err := handler.service.UpdateSettings(ctx, dto.UpdateSettingsRequest{
Appearance: &previous.Appearance,
diff --git a/internal/presentation/wails/update_handler.go b/internal/presentation/wails/update_handler.go
index 6470e40..01ed03b 100644
--- a/internal/presentation/wails/update_handler.go
+++ b/internal/presentation/wails/update_handler.go
@@ -34,6 +34,10 @@ func (handler *UpdateHandler) GetState(_ context.Context) update.Info {
return handler.service.State()
}
+func (handler *UpdateHandler) GetWhatsNew(ctx context.Context) (update.WhatsNew, error) {
+ return handler.service.GetWhatsNew(ctx)
+}
+
func (handler *UpdateHandler) CheckForUpdate(ctx context.Context, currentVersion string) (update.Info, error) {
return handler.service.CheckForUpdate(ctx, currentVersion)
}
@@ -41,7 +45,11 @@ func (handler *UpdateHandler) CheckForUpdate(ctx context.Context, currentVersion
func (handler *UpdateHandler) DownloadUpdate(ctx context.Context) (update.Info, error) {
info, err := handler.service.DownloadUpdate(ctx)
if err == nil && handler.telemetry != nil && info.Status == update.StatusReadyToRestart {
- handler.telemetry.TrackUpdateReadyToRestart(ctx, info.LatestVersion)
+ latestVersion := info.PreparedVersion
+ if latestVersion == "" {
+ latestVersion = info.LatestVersion
+ }
+ handler.telemetry.TrackUpdateReadyToRestart(ctx, latestVersion)
}
return info, err
}
@@ -56,3 +64,7 @@ func (handler *UpdateHandler) RestartToApply(ctx context.Context) (update.Info,
}
return info, err
}
+
+func (handler *UpdateHandler) DismissWhatsNew(ctx context.Context, version string) error {
+ return handler.service.DismissWhatsNew(ctx, version)
+}
diff --git a/main.go b/main.go
index 860dcac..b2a18bf 100644
--- a/main.go
+++ b/main.go
@@ -1,6 +1,7 @@
package main
import (
+ "context"
"embed"
"log"
"os"
@@ -21,6 +22,14 @@ func main() {
}
}()
+ appliedPreparedUpdate, err := app.TryApplyPreparedUpdateOnLaunch(context.Background(), os.Args[1:])
+ if err != nil {
+ log.Fatal(err)
+ }
+ if appliedPreparedUpdate {
+ return
+ }
+
application, err := app.CreateApplication(assets)
if err != nil {
log.Fatal(err)