diff --git a/README.md b/README.md index 90d7ef8..c87e448 100644 --- a/README.md +++ b/README.md @@ -78,7 +78,7 @@ | 桌面框架 | Go / Wails 3 / React | | 本地存储 | SQLite / bun / sqlite-vec | | 媒体处理 | yt-dlp / FFmpeg | -| 浏览器自动化 | Playwright | +| 浏览器自动化 | Chrome DevTools Protocol / chromedp | | 渠道接入 | Telegram / telego | 正是这些项目与它们背后的维护者,让追创作能够在桌面、媒体处理、自动化与渠道接入之间建立起一条持续演进的工作链路。 diff --git a/README_en.md b/README_en.md index a62cf51..a5f3f75 100644 --- a/README_en.md +++ b/README_en.md @@ -80,7 +80,7 @@ DreamCreator is built on top of a number of excellent open-source projects and s | Desktop Framework | Go / Wails 3 / React | | Local Storage | SQLite / bun / sqlite-vec | | Media Processing | yt-dlp / FFmpeg | -| Browser Automation | Playwright | +| Browser Automation | Chrome DevTools Protocol / chromedp | | Channel Access | Telegram / telego | These projects, and the maintainers behind them, make it possible for DreamCreator to connect desktop workflows, media processing, automation, and channel access into one evolving system. 
diff --git a/frontend/scripts/audit-i18n.mjs b/frontend/scripts/audit-i18n.mjs index 749c8d9..c5e2983 100644 --- a/frontend/scripts/audit-i18n.mjs +++ b/frontend/scripts/audit-i18n.mjs @@ -23,13 +23,16 @@ const englishStylePreservePhrases = [ "Apple", "macOS", "iOS", + "Chrome", + "Chromium", + "Edge", + "Brave", "Chat Completions", "OpenAI", "OpenRouter", "Bun", "ClawHub", "DreamCreator", - "Playwright", "FFmpeg", "FontGet", "LibASS", diff --git a/frontend/scripts/audit-structure.mjs b/frontend/scripts/audit-structure.mjs index bd8b6a3..f3ac7d0 100644 --- a/frontend/scripts/audit-structure.mjs +++ b/frontend/scripts/audit-structure.mjs @@ -35,6 +35,7 @@ const oversizeBaseline = { "frontend/src/shared/contracts/library.ts": 1300, "internal/application/channels/telegram/bot_service.go": 5000, "internal/application/gateway/tools/browser_tools.go": 3150, + "internal/application/browsercdp/session.go": 3362, "internal/application/library/service/service.go": 2750, "internal/application/gateway/cron/scheduler.go": 2650, "internal/application/memory/service/service.go": 2300, diff --git a/frontend/src/app/main/MainApp.tsx b/frontend/src/app/main/MainApp.tsx index 9bcbcba..cc1949a 100644 --- a/frontend/src/app/main/MainApp.tsx +++ b/frontend/src/app/main/MainApp.tsx @@ -393,12 +393,6 @@ export function MainApp() { ) : undefined; - const isAppUpdateAvailable = - updateInfo.status === "available" || - updateInfo.status === "downloading" || - updateInfo.status === "installing" || - updateInfo.status === "ready_to_restart"; - const isExternalToolsUpdateAvailable = useMemo(() => { const tools = externalTools.data ?? []; const updates = externalToolUpdates.data ?? []; @@ -444,7 +438,7 @@ export function MainApp() { onSelectThread={assistantUiEnabled ? 
handleSelectThread : undefined} highlightThreadActive={assistantUiEnabled && visibleActiveTarget.type === "thread"} showThreadList={assistantUiEnabled} - isAppUpdateAvailable={isAppUpdateAvailable} + appUpdateInfo={updateInfo} isExternalToolsUpdateAvailable={isExternalToolsUpdateAvailable} showTitleBarBorder={showTitleBarBorder} contentScrollable={!isMainFamilyPage} diff --git a/frontend/src/app/settings/SettingsApp.tsx b/frontend/src/app/settings/SettingsApp.tsx index 93778c8..6a84081 100644 --- a/frontend/src/app/settings/SettingsApp.tsx +++ b/frontend/src/app/settings/SettingsApp.tsx @@ -37,6 +37,7 @@ import { useTestProxy, useUpdateSettings, } from "@/shared/query/settings"; +import { useGatewayTools } from "@/shared/query/tools"; import { useI18n } from "@/shared/i18n"; import { useAssistantUiMode } from "@/shared/store/assistantUi"; import { useDebugMode } from "@/shared/store/debug"; @@ -167,6 +168,7 @@ export function SettingsApp() { const { enabled: debugEnabled } = useDebugMode(); const { enabled: assistantUiEnabled } = useAssistantUiMode(); const isWindows = System.IsWindows(); + useGatewayTools(assistantUiEnabled); const [activeSection, setActiveSection] = React.useState("gateway"); const [memoryTab, setMemoryTab] = React.useState("summary"); diff --git a/frontend/src/components/layout/AppShell.tsx b/frontend/src/components/layout/AppShell.tsx index e6e3a5f..56335cc 100644 --- a/frontend/src/components/layout/AppShell.tsx +++ b/frontend/src/components/layout/AppShell.tsx @@ -14,6 +14,7 @@ import { import { Button } from "@/shared/ui/button"; import { AppSidebar } from "@/components/layout/AppSidebar"; import { TitleBar } from "@/components/layout/TitleBar"; +import { WhatsNewDialog } from "@/features/update/WhatsNewDialog"; import { useI18n } from "@/shared/i18n"; import { MessageHost } from "@/shared/message/MessageHost"; import { cn } from "@/lib/utils"; @@ -43,7 +44,7 @@ export interface AppShellProps { onSelectThread?: (threadId: string) => void; 
highlightThreadActive?: boolean; showThreadList?: boolean; - isAppUpdateAvailable?: boolean; + appUpdateInfo?: import("@/shared/store/update").UpdateInfo; isExternalToolsUpdateAvailable?: boolean; noticeUnreadCount?: number; isNoticePanelOpen?: boolean; @@ -75,7 +76,7 @@ function AppShellLayout({ onSelectThread, highlightThreadActive, showThreadList, - isAppUpdateAvailable, + appUpdateInfo, isExternalToolsUpdateAvailable, noticeUnreadCount, isNoticePanelOpen, @@ -184,7 +185,7 @@ function AppShellLayout({ onSelectThread={onSelectThread} highlightThreadActive={highlightThreadActive} showThreadList={showThreadList} - isAppUpdateAvailable={isAppUpdateAvailable} + appUpdateInfo={appUpdateInfo} isExternalToolsUpdateAvailable={isExternalToolsUpdateAvailable} noticeUnreadCount={noticeUnreadCount} isNoticePanelOpen={isNoticePanelOpen} @@ -216,6 +217,7 @@ function AppShellLayout({ /> )} +
void; highlightThreadActive?: boolean; showThreadList?: boolean; - isAppUpdateAvailable?: boolean; + appUpdateInfo?: UpdateInfo; isExternalToolsUpdateAvailable?: boolean; noticeUnreadCount?: number; isNoticePanelOpen?: boolean; @@ -482,7 +484,7 @@ export function AppSidebar({ onSelectThread, highlightThreadActive = true, showThreadList = true, - isAppUpdateAvailable, + appUpdateInfo, isExternalToolsUpdateAvailable, noticeUnreadCount = 0, isNoticePanelOpen = false, @@ -495,6 +497,7 @@ export function AppSidebar({ const settingsLoading = useSettingsStore((state) => state.isLoading); const { open: isSetupCenterOpen, setOpen: setSetupCenterOpen } = useSetupCenter(); const currentUserProfileQuery = useCurrentUserProfile(); + const restartToApply = useRestartToApply(); const [search, setSearch] = React.useState(""); const deferredSearch = React.useDeferredValue(search); const [isProductModeOpen, setIsProductModeOpen] = React.useState(false); @@ -502,7 +505,17 @@ export function AppSidebar({ const setupAutoOpenInitializedRef = React.useRef(false); const [renderLimit, setRenderLimit] = React.useState(200); - const hasUpdateMenu = Boolean(isAppUpdateAvailable || isExternalToolsUpdateAvailable); + const isAppUpdateReadyToRestart = + Boolean(appUpdateInfo) && + appUpdateInfo?.status === "ready_to_restart" && + hasPreparedUpdate(appUpdateInfo); + const isAppUpdateMenuVisible = + Boolean(appUpdateInfo) && + (isAppUpdateReadyToRestart || + (appUpdateInfo ? 
hasRemoteUpdate(appUpdateInfo) : false) || + appUpdateInfo?.status === "downloading" || + appUpdateInfo?.status === "installing"); + const hasUpdateMenu = Boolean(isAppUpdateMenuVisible || isExternalToolsUpdateAvailable); const normalizedSearch = deferredSearch.trim().toLowerCase(); const currentUserProfile = currentUserProfileQuery.data; const currentUserName = resolveUserDisplayName(currentUserProfile); @@ -676,16 +689,36 @@ export function AppSidebar({ onOpenSettings?.(); }; + const handleRestartPreparedUpdate = React.useCallback(() => { + void restartToApply.mutateAsync().catch((error) => { + const message = error instanceof Error ? error.message : String(error); + messageBus.publishToast({ + intent: "warning", + title: t("sidebar.footer.menu.restartAndUpdate"), + description: message, + }); + }); + }, [restartToApply, t]); + const updateMenuItems = React.useMemo( () => [ - isAppUpdateAvailable + isAppUpdateMenuVisible ? { key: "app-update", - label: t("sidebar.footer.menu.appUpdate"), + label: isAppUpdateReadyToRestart + ? 
t("sidebar.footer.menu.restartAndUpdate") + : t("sidebar.footer.menu.appUpdate"), Icon: ArrowUpCircle, iconClassName: "text-primary", - onSelect: () => handleSelectSettings("about"), + onSelect: () => { + if (isAppUpdateReadyToRestart) { + handleRestartPreparedUpdate(); + return; + } + handleSelectSettings("about"); + }, + disabled: restartToApply.isPending, } : null, isExternalToolsUpdateAvailable @@ -695,10 +728,18 @@ export function AppSidebar({ Icon: Wrench, iconClassName: "text-primary", onSelect: () => handleSelectSettings("external-tools"), + disabled: false, } : null, ].filter((item): item is NonNullable => Boolean(item)), - [isAppUpdateAvailable, isExternalToolsUpdateAvailable, t] + [ + handleRestartPreparedUpdate, + isAppUpdateMenuVisible, + isAppUpdateReadyToRestart, + isExternalToolsUpdateAvailable, + restartToApply.isPending, + t, + ] ); const handleSelectProductMode = (nextEnabled: boolean) => { @@ -921,6 +962,7 @@ export function AppSidebar({ key={item.key} className={footerDropdownItemClassName} onSelect={item.onSelect} + disabled={item.disabled} >
diff --git a/frontend/src/features/chat/tool-ui/fallback.tsx b/frontend/src/features/chat/tool-ui/fallback.tsx index 2742157..fc97ab4 100644 --- a/frontend/src/features/chat/tool-ui/fallback.tsx +++ b/frontend/src/features/chat/tool-ui/fallback.tsx @@ -2,11 +2,11 @@ import * as React from "react"; import type { ToolCallMessagePart, ToolCallMessagePartStatus } from "@assistant-ui/react"; import { ToolFallback } from "@/components/assistant-ui/tool-fallback"; +import { ApprovalCard } from "@/components/tool-ui/approval-card"; import { cn } from "@/lib/utils"; import { useI18n } from "@/shared/i18n"; import { messageBus } from "@/shared/message"; import { requestGateway } from "@/shared/realtime"; -import { Button } from "@/shared/ui/button"; import { DASHBOARD_PANEL_SURFACE_CLASS } from "@/shared/ui/dashboard"; export type ToolUIFallbackCardProps = { @@ -136,22 +136,17 @@ export function ToolUIFallbackCard({ resolvedStatus.type === "requires-action" && approvalID !== "" && approvalDecision === ""; const isCancelled = resolvedStatus.type === "incomplete" && resolvedStatus.reason === "cancelled"; + const approvalTitle = approval?.action?.trim() || t("chat.tools.approvalTool.title"); + const approvalToolName = approval?.toolName?.trim() || resolvedToolName; + const approvalDescription = `${t("chat.tools.approvalTool.tool")}: ${approvalToolName}`; const [open, setOpen] = React.useState(false); - const [pendingDecision, setPendingDecision] = React.useState<"approve" | "deny" | "">(""); - - React.useEffect(() => { - if (!requiresApprovalAction) { - setPendingDecision(""); - } - }, [requiresApprovalAction, approvalID]); const resolveApproval = React.useCallback( async (decision: "approve" | "deny") => { - if (!approvalID || pendingDecision) { + if (!approvalID) { return; } - setPendingDecision(decision); try { await requestGateway("exec.approval.resolve", { id: approvalID, @@ -159,7 +154,6 @@ export function ToolUIFallbackCard({ reason: decision === "approve" ? 
"approved by aui fallback" : "denied by aui fallback", }); } catch (error) { - setPendingDecision(""); messageBus.publishToast({ intent: "danger", title: t("chat.tools.approvalTool.resolveError"), @@ -168,7 +162,7 @@ export function ToolUIFallbackCard({ }); } }, - [approvalID, pendingDecision, t] + [approvalID, t] ); return ( @@ -182,29 +176,17 @@ export function ToolUIFallbackCard({ > {requiresApprovalAction ? ( -
- - +
+ resolveApproval("approve")} + onCancel={() => resolveApproval("deny")} + />
) : null} diff --git a/frontend/src/features/settings/about/index.tsx b/frontend/src/features/settings/about/index.tsx index f07c897..490f1c8 100644 --- a/frontend/src/features/settings/about/index.tsx +++ b/frontend/src/features/settings/about/index.tsx @@ -24,7 +24,7 @@ import { SettingsCompactListCard, SettingsCompactRow, SettingsCompactSeparator } import { Tooltip, TooltipContent, TooltipProvider, TooltipTrigger } from "@/shared/ui/tooltip"; import { cn } from "@/lib/utils"; import { messageBus } from "@/shared/message"; -import { useUpdateStore } from "@/shared/store/update"; +import { displayUpdateVersion, hasPreparedUpdate, hasRemoteUpdate, useUpdateStore } from "@/shared/store/update"; import { useCheckForUpdate, useDownloadUpdate, useRestartToApply, useUpdateState } from "@/shared/query/update"; import { Browser } from "@wailsio/runtime"; @@ -44,23 +44,27 @@ export function AboutSection() { const isChecking = updateInfo.status === "checking" || checkUpdate.isPending; const isError = updateInfo.status === "error"; - const hasUpdate = - updateInfo.status === "available" || - updateInfo.status === "downloading" || - updateInfo.status === "installing" || - updateInfo.status === "ready_to_restart"; + const hasPrepared = hasPreparedUpdate(updateInfo); + const hasRemote = hasRemoteUpdate(updateInfo); + const hasKnownPendingUpdate = hasPrepared || hasRemote; const isDownloading = updateInfo.status === "downloading" || updateInfo.status === "installing"; - const isReadyToRestart = updateInfo.status === "ready_to_restart"; - const releaseNotes = (updateInfo.changelog ?? "").trim(); + const isReadyToRestart = updateInfo.status === "ready_to_restart" && hasPrepared; + const releaseNotes = ((isReadyToRestart ? updateInfo.preparedChangelog : updateInfo.changelog) ?? "").trim(); const hasReleaseNotes = releaseNotes.length > 0; const errorMessage = (updateInfo.message ?? 
"").trim(); - const hasKnownLatestVersion = - updateInfo.latestVersion.trim().length > 0 && updateInfo.latestVersion.trim() !== updateInfo.currentVersion.trim(); - const showLatestUpdate = hasUpdate || (hasKnownLatestVersion && (isError || updateInfo.status === "available")); + const showLatestUpdate = hasKnownPendingUpdate || isDownloading || isReadyToRestart; const showStatusRow = isDownloading || (isError && errorMessage.length > 0); + const checkLabel = hasKnownPendingUpdate ? t("settings.about.update.recheck") : t("settings.about.update.check"); + const installLabel = t("settings.about.update.downloadAndInstall"); + const restartLabel = t("settings.about.update.restartAfterUpdate"); + const showCheckAction = !isReadyToRestart && !isDownloading; + const showInstallAction = + !isReadyToRestart && + !isDownloading && + (updateInfo.status === "available" || (isError && hasRemote && !hasPrepared)); const latestLabel = (() => { - if (showLatestUpdate) return updateInfo.latestVersion || t("settings.about.update.latestAvailable"); + if (showLatestUpdate) return displayUpdateVersion(updateInfo) || t("settings.about.update.latestAvailable"); if (isError) return t("settings.about.update.latestFailed"); return t("settings.about.update.latestOk"); })(); @@ -237,25 +241,27 @@ export function AboutSection() {
- - - - - - {t("settings.about.update.check")} - - + {showCheckAction ? ( + + + + + + {checkLabel} + + + ) : null} - {hasUpdate && !isReadyToRestart ? ( + {showInstallAction ? ( @@ -264,17 +270,17 @@ export function AboutSection() { size="compact" onClick={handleInstall} disabled={downloadUpdate.isPending || isDownloading || restartToApply.isPending} - aria-label={t("settings.about.update.install")} + aria-label={installLabel} > {downloadUpdate.isPending || isDownloading ? ( ) : ( )} - {t("settings.about.update.install")} + {installLabel} - {t("settings.about.update.install")} + {installLabel} ) : null} @@ -288,17 +294,17 @@ export function AboutSection() { size="compact" onClick={handleRestart} disabled={restartToApply.isPending} - aria-label={t("settings.about.update.restart")} + aria-label={restartLabel} > {restartToApply.isPending ? ( ) : ( )} - {t("settings.about.update.restart")} + {restartLabel} - {t("settings.about.update.restart")} + {restartLabel} ) : null} diff --git a/frontend/src/features/settings/calls/components/CallsToolsTab.tsx b/frontend/src/features/settings/calls/components/CallsToolsTab.tsx index 8eac49c..f85efd5 100644 --- a/frontend/src/features/settings/calls/components/CallsToolsTab.tsx +++ b/frontend/src/features/settings/calls/components/CallsToolsTab.tsx @@ -23,14 +23,19 @@ import { useGatewayTools } from "@/shared/query/tools"; import type { GatewayToolMethodSpec } from "@/shared/store/gatewayTools"; import { CallsCard } from "./CallsCard"; import { ToolMethodIOPanel } from "./ToolMethodIOPanel"; +import { + normalizeRuntimeBrowserCandidates, + type RuntimeDetectionRow, +} from "./calls-tools-runtime-detection"; import { ToolConfigCard, ToolConfigEmptyState, ToolConfigTabPanel, ToolContentTabs, - type ToolDependencyStatus, ToolDetailLayout, ToolIOTabPanel, + type ToolPermissionBadge, + type ToolRequirementItem, ToolOverviewCard, } from "./tool-detail-layout"; import type { ToolItem } from "../types"; @@ -52,11 +57,9 @@ import { } from 
"../utils/gateway-tool-utils"; import { DEFAULT_WEB_SEARCH_PROVIDERS, - normalizeWebFetchType, + normalizePreferredBrowser, normalizeWebSearchType, - parseNonNegativeNumberInput, parseNumberInput, - parseObjectJSON, parseStringArrayJSON, readBoolValue, readNumberValue, @@ -66,7 +69,6 @@ import { readWebSearchProviderApiKeys, resolveWebSearchAPIKeyPlaceholder, serializeWebSearchProviderApiKeys, - stringifyObjectValue, stringifyStringArrayValue, type BrowserControlFormState, type WebFetchFormState, @@ -109,42 +111,153 @@ export function CallsToolsTab() { (category: string) => t(`settings.tools.category.${category}`), [t] ); - const resolveToolDependencies = React.useCallback( + const resolveWebSearchProviderLabel = React.useCallback( + (providerID: string) => { + switch (providerID.trim().toLowerCase()) { + case "brave": + return t("settings.tools.requirements.providers.brave"); + case "tavily": + return t("settings.tools.requirements.providers.tavily"); + case "perplexity": + return t("settings.tools.requirements.providers.perplexity"); + case "grok": + return t("settings.tools.requirements.providers.grok"); + default: + return providerID; + } + }, + [t] + ); + const resolveWebSearchModeLabel = React.useCallback( + (mode: string) => { + return normalizeWebSearchType(mode) === "external_tools" + ? 
t("settings.tools.webSearch.typeValue.externalTools") + : t("settings.tools.webSearch.typeValue.api"); + }, + [t] + ); + const resolveRequirementName = React.useCallback( + (requirementID: string, fallbackName: string) => { + switch (requirementID) { + case "gateway.control_plane_enabled": + return t("settings.tools.requirements.gatewayControlPlane"); + case "browser.cdp_runtime": + case "web_fetch.local_browser": + return t("settings.tools.requirements.localCDPBrowser"); + case "web_search.mode_supported": + return t("settings.tools.requirements.webSearchMode"); + case "web_search.provider_supported": + return t("settings.tools.requirements.webSearchProvider"); + case "web_search.provider_api_key": + return t("settings.tools.requirements.webSearchProviderApiKey"); + case "web_search.external_tools_supported": + return t("settings.tools.requirements.externalToolsRuntime"); + case "image.model_runtime": + return t("settings.tools.requirements.imageModel"); + case "tts.voice_service": + return t("settings.tools.requirements.voiceService"); + case "tts.voice_enabled": + return t("settings.tools.requirements.voiceFeature"); + case "tts.provider_supported": + return t("settings.tools.requirements.ttsProvider"); + case "tts.provider_api_key": + return t("settings.tools.requirements.ttsProviderApiKey"); + case "tts.voice_id": + return t("settings.tools.requirements.ttsVoiceId"); + case "canvas.remote_runtime": + case "nodes.remote_runtime": + return t("settings.tools.requirements.remoteNodeRuntime"); + default: + return fallbackName; + } + }, + [t] + ); + const resolveRequirementReason = React.useCallback( + (reason: string) => { + const trimmed = reason.trim(); + const normalized = trimmed.toLowerCase(); + const providerUnsupportedMatch = trimmed.match(/^(.+?)\s+is not supported in api mode$/i); + if (providerUnsupportedMatch) { + return t("settings.tools.reason.webSearchProviderUnsupported").replace( + "{provider}", + 
resolveWebSearchProviderLabel(providerUnsupportedMatch[1]?.trim() || "") + ); + } + const providerAPIKeyMissingMatch = normalized === "tts provider api key is missing" + ? null + : trimmed.match(/^(.+?)\s+api key is missing$/i); + if (providerAPIKeyMissingMatch) { + return t("settings.tools.reason.webSearchProviderApiKeyMissing").replace( + "{provider}", + resolveWebSearchProviderLabel(providerAPIKeyMissingMatch[1]?.trim() || "") + ); + } + switch (normalized) { + case "browser executable not found": + return t("settings.tools.runtimeDetection.notInstalled"); + case "no supported browser detected": + return t("settings.tools.runtimeDetection.noneDetected"); + case "browser process exited": + return t("settings.tools.reason.browserProcessExited"); + case "image model is not configured": + return t("settings.tools.reason.imageModelNotConfigured"); + case "provider repositories are unavailable": + return t("settings.tools.reason.providerRepositoriesUnavailable"); + case "control plane is disabled": + return t("settings.tools.reason.gatewayControlPlaneDisabled"); + case "search mode is not supported": + return t("settings.tools.reason.webSearchModeUnsupported"); + case "external tools mode is not implemented": + return t("settings.tools.reason.webSearchExternalToolsUnavailable"); + case "remote node runtime is not implemented yet": + return t("settings.tools.reason.remoteNodeRuntimeUnavailable"); + case "voice is disabled": + return t("settings.tools.reason.voiceDisabled"); + case "voice service unavailable": + return t("settings.tools.reason.voiceServiceUnavailable"); + case "tts provider api key is missing": + return t("settings.tools.reason.ttsProviderApiKeyMissing"); + case "tts voice id is not configured": + return t("settings.tools.reason.ttsVoiceIdMissing"); + case "edge-tts provider is not implemented yet": + return t("settings.tools.reason.ttsEdgeProviderUnavailable"); + case "tts provider is not supported": + return 
t("settings.tools.reason.ttsProviderUnsupported"); + default: + return reason || t("settings.tools.reason.unavailable"); + } + }, + [resolveWebSearchProviderLabel, t] + ); + const resolveToolPermissionBadges = React.useCallback( (tool: ToolItem) => { - const dependencies: ToolDependencyStatus[] = - !tool.requirements || tool.requirements.length === 0 - ? [] - : tool.requirements.map((requirement) => { - const fallbackName = requirement.name || requirement.id; - const fallbackReason = requirement.reason || t("settings.tools.reason.unavailable"); - return { - id: requirement.id, - name: fallbackName, - ok: requirement.available, - reason: fallbackReason, - }; - }); const riskLevel = (tool.riskLevel ?? "").trim().toLowerCase(); const needsApproval = tool.requiresApproval === true || riskLevel === "high"; const needsSandbox = tool.requiresSandbox === true || riskLevel === "high"; - const permissionBadges: string[] = []; + const permissionBadges: ToolPermissionBadge[] = []; if (needsApproval) { - permissionBadges.push(t("settings.tools.detail.permissions.badges.approval")); + permissionBadges.push({ + id: "approval", + label: t("settings.tools.detail.permissions.badges.approval"), + tone: "warning", + }); } if (needsSandbox) { - permissionBadges.push(t("settings.tools.detail.permissions.badges.sandbox")); + permissionBadges.push({ + id: "sandbox", + label: t("settings.tools.detail.permissions.badges.sandbox"), + tone: "info", + }); } if (permissionBadges.length === 0) { - permissionBadges.push(t("settings.tools.detail.permissions.badges.none")); + permissionBadges.push({ + id: "none", + label: t("settings.tools.detail.permissions.badges.none"), + tone: "neutral", + }); } - dependencies.push({ - id: "__permission__", - name: t("settings.tools.detail.permissions.label"), - ok: true, - reason: "", - badges: permissionBadges, - }); - return dependencies; + return permissionBadges; }, [t] ); @@ -176,6 +289,17 @@ export function CallsToolsTab() { const base = 
toolItems.filter((tool) => baseToolIds.includes(normalizeToolId(tool.id))); return base.length > 0 ? base : toolItems; }, [toolItems]); + const isInitialToolsLoad = gatewayToolsQuery.isLoading && filteredTools.length === 0; + const toolsLoadError = React.useMemo(() => { + if (!gatewayToolsQuery.isError || filteredTools.length > 0) { + return ""; + } + const error = gatewayToolsQuery.error; + if (error instanceof Error) { + return error.message; + } + return String(error ?? "").trim(); + }, [filteredTools.length, gatewayToolsQuery.error, gatewayToolsQuery.isError]); const groupedTools = React.useMemo(() => { const groups = new Map(); filteredTools.forEach((tool) => { @@ -367,47 +491,143 @@ export function CallsToolsTab() { ); const skipNextWebSearchBlurSaveRef = React.useRef(false); - const initialWebFetchForm = React.useMemo(() => { - const headers = readObjectValue(webFetchConfig, "headers"); - const playwright = readObjectValue(webFetchConfig, "playwright"); - return { - type: normalizeWebFetchType(readStringValue(webFetchConfig, "type", "builtin")), - playwrightMarkdown: readBoolValue(playwright, "markdown", true), - acceptMarkdown: readBoolValue(webFetchConfig, "acceptMarkdown", true), - enableUserAgent: readBoolValue(webFetchConfig, "enableUserAgent", true), - userAgent: readStringValue(webFetchConfig, "userAgent", ""), - acceptLanguage: readStringValue(webFetchConfig, "acceptLanguage", ""), - timeoutSeconds: readNumberValue(webFetchConfig, "timeoutSeconds"), - maxChars: readNumberValue(webFetchConfig, "maxChars"), - maxRedirects: readNumberValue(webFetchConfig, "maxRedirects"), - retryMax: readNumberValue(webFetchConfig, "retryMax"), - headersJson: stringifyObjectValue(headers), - }; - }, [webFetchConfig]); + const initialWebFetchForm = React.useMemo(() => ({ + headless: readBoolValue(webFetchConfig, "headless", true), + preferredBrowser: normalizePreferredBrowser(readStringValue(webFetchConfig, "preferredBrowser", "chrome")), + timeoutSeconds: 
readNumberValue(webFetchConfig, "timeoutSeconds"), + maxChars: readNumberValue(webFetchConfig, "maxChars"), + }), [webFetchConfig]); const [webFetchForm, setWebFetchForm] = React.useState(initialWebFetchForm); const skipNextWebFetchBlurSaveRef = React.useRef(false); const initialBrowserForm = React.useMemo(() => { - const snapshotDefaults = readObjectValue(browserConfig, "snapshotDefaults"); const ssrfPolicy = readObjectValue(browserConfig, "ssrfPolicy"); return { enabled: readBoolValue(browserConfig, "enabled", true), - evaluateEnabled: readBoolValue(browserConfig, "evaluateEnabled", true), - headless: readBoolValue(browserConfig, "headless", false), - noSandbox: readBoolValue(browserConfig, "noSandbox", false), - snapshotDefaultMode: readStringValue(snapshotDefaults, "mode", ""), + headless: readBoolValue(browserConfig, "headless", true), + preferredBrowser: normalizePreferredBrowser(readStringValue(browserConfig, "preferredBrowser", "chrome")), ssrfDangerouslyAllowPrivateNetwork: readBoolValue( ssrfPolicy, "dangerouslyAllowPrivateNetwork", - true + false ), ssrfAllowedHostnamesJson: stringifyStringArrayValue(ssrfPolicy?.allowedHostnames), ssrfHostnameAllowlistJson: stringifyStringArrayValue(ssrfPolicy?.hostnameAllowlist), - extraArgsJson: stringifyStringArrayValue(browserConfig?.extraArgs), }; }, [browserConfig]); const [browserForm, setBrowserForm] = React.useState(initialBrowserForm); const skipNextBrowserBlurSaveRef = React.useRef(false); + const resolveRequirementValue = React.useCallback( + (requirement: ToolItem["requirements"][number]) => { + const fallbackReason = resolveRequirementReason(requirement.reason || ""); + const requirementData = isRecord(requirement.data) ? 
requirement.data : undefined; + const resolveProviderLabel = (providerID: string) => { + switch (providerID) { + case "openai": + return t("settings.tools.requirements.providers.openai"); + case "elevenlabs": + return t("settings.tools.requirements.providers.elevenlabs"); + case "edge": + return t("settings.tools.requirements.providers.edge"); + default: + return providerID; + } + }; + switch (requirement.id) { + case "gateway.control_plane_enabled": + return requirement.available + ? t("settings.tools.requirements.values.enabled") + : t("settings.tools.requirements.values.disabled"); + case "browser.cdp_runtime": + case "web_fetch.local_browser": + return requirement.available + ? t("settings.tools.requirements.values.detected") + : fallbackReason; + case "web_search.mode_supported": { + const mode = readStringValue(requirementData, "mode", readStringValue(webSearchConfig, "type", webSearchForm.type)); + return requirement.available + ? resolveWebSearchModeLabel(mode) + : (fallbackReason || t("settings.tools.requirements.values.unavailable")); + } + case "web_search.provider_supported": { + const providerID = readStringValue(requirementData, "providerId", webSearchForm.provider).trim().toLowerCase(); + const providerLabel = providerID ? resolveWebSearchProviderLabel(providerID) : ""; + return requirement.available + ? (providerLabel || t("settings.tools.requirements.values.available")) + : (fallbackReason || providerLabel || t("settings.tools.requirements.values.unavailable")); + } + case "web_search.provider_api_key": + return requirement.available + ? t("settings.tools.requirements.values.configured") + : t("settings.tools.requirements.values.missing"); + case "web_search.external_tools_supported": + return requirement.available + ? t("settings.tools.requirements.values.available") + : (fallbackReason || t("settings.tools.requirements.values.unavailable")); + case "image.model_runtime": + return requirement.available + ? 
t("settings.tools.requirements.values.configured") + : t("settings.tools.requirements.values.notConfigured"); + case "tts.voice_enabled": + return requirement.available + ? t("settings.tools.requirements.values.enabled") + : t("settings.tools.requirements.values.disabled"); + case "tts.provider_supported": { + const providerID = readStringValue(requirementData, "providerId", "").trim().toLowerCase(); + const providerLabel = providerID ? resolveProviderLabel(providerID) : ""; + return requirement.available + ? (providerLabel || t("settings.tools.requirements.values.available")) + : (fallbackReason || providerLabel || t("settings.tools.requirements.values.unavailable")); + } + case "tts.provider_api_key": + return requirement.available + ? t("settings.tools.requirements.values.configured") + : t("settings.tools.requirements.values.missing"); + case "tts.voice_id": { + const voiceID = readStringValue(requirementData, "value", "").trim(); + return requirement.available + ? (voiceID || t("settings.tools.requirements.values.configured")) + : t("settings.tools.requirements.values.notConfigured"); + } + case "tts.voice_service": + return requirement.available + ? t("settings.tools.requirements.values.available") + : fallbackReason; + default: + return requirement.available + ? t("settings.tools.requirements.values.available") + : (fallbackReason || t("settings.tools.requirements.values.unavailable")); + } + }, + [ + resolveRequirementReason, + resolveWebSearchModeLabel, + resolveWebSearchProviderLabel, + t, + webSearchConfig, + webSearchForm.provider, + webSearchForm.type, + ] + ); + const resolveToolRequirements = React.useCallback( + (tool: ToolItem) => { + const requirements: ToolRequirementItem[] = + !tool.requirements || tool.requirements.length === 0 + ? 
[] + : tool.requirements.map((requirement) => { + const fallbackName = requirement.name || requirement.id; + return { + id: requirement.id, + name: resolveRequirementName(requirement.id, fallbackName), + value: resolveRequirementValue(requirement), + tone: requirement.available ? "success" : "danger", + }; + }); + return requirements; + }, + [resolveRequirementName, resolveRequirementValue] + ); + React.useEffect(() => { setWebSearchForm(initialWebSearchForm); }, [initialWebSearchForm]); @@ -648,15 +868,6 @@ export function CallsToolsTab() { const handleSaveWebFetch = React.useCallback((formState?: WebFetchFormState) => { const currentForm = formState ?? webFetchForm; - const parsedHeaders = parseObjectJSON(currentForm.headersJson); - if (parsedHeaders.error) { - messageBus.publishToast({ - intent: "danger", - title: t("settings.tools.webFetch.headersInvalid"), - description: t("settings.tools.webFetch.headersInvalidDesc"), - }); - return false; - } const nextToolsConfig: Record = { ...toolsConfig }; const nextWeb = isRecord(nextToolsConfig.web) ? 
{ ...(nextToolsConfig.web as Record) } @@ -672,23 +883,17 @@ export function CallsToolsTab() { } target[key] = value; }; - target.type = normalizeWebFetchType(currentForm.type); - target.playwright = { - markdown: currentForm.playwrightMarkdown, - }; - target.acceptMarkdown = currentForm.acceptMarkdown; - target.enableUserAgent = currentForm.enableUserAgent; - setOrDelete("userAgent", currentForm.userAgent.trim()); - setOrDelete("acceptLanguage", currentForm.acceptLanguage.trim()); + target.headless = currentForm.headless; + target.preferredBrowser = normalizePreferredBrowser(currentForm.preferredBrowser); setOrDelete("timeoutSeconds", parseNumberInput(currentForm.timeoutSeconds)); setOrDelete("maxChars", parseNumberInput(currentForm.maxChars)); - setOrDelete("maxRedirects", parseNonNegativeNumberInput(currentForm.maxRedirects)); - setOrDelete("retryMax", parseNonNegativeNumberInput(currentForm.retryMax)); - if (parsedHeaders.value && Object.keys(parsedHeaders.value).length > 0) { - target.headers = parsedHeaders.value; - } else { - delete target.headers; - } + delete target.acceptMarkdown; + delete target.enableUserAgent; + delete target.userAgent; + delete target.acceptLanguage; + delete target.headers; + delete target.maxRedirects; + delete target.retryMax; delete target.enabled; }; applyWebFetchValues(nextTopLevelFetch); @@ -709,7 +914,7 @@ export function CallsToolsTab() { } ); return true; - }, [gatewayToolsQuery, toolsConfig, updateSettings, webFetchForm, t]); + }, [gatewayToolsQuery, toolsConfig, updateSettings, webFetchForm]); const handleSaveBrowser = React.useCallback((formState?: BrowserControlFormState) => { const currentForm = formState ?? 
browserForm; @@ -731,32 +936,19 @@ export function CallsToolsTab() { }); return false; } - const parsedExtraArgs = parseStringArrayJSON(currentForm.extraArgsJson); - if (parsedExtraArgs.error) { - messageBus.publishToast({ - intent: "danger", - title: t("settings.tools.browserControl.arrayInvalid"), - description: t("settings.tools.browserControl.arrayInvalidDesc"), - }); - return false; - } const nextToolsConfig: Record = { ...toolsConfig }; const nextBrowser = isRecord(nextToolsConfig.browser) ? { ...(nextToolsConfig.browser as Record) } : {}; nextBrowser.enabled = currentForm.enabled; - nextBrowser.evaluateEnabled = currentForm.evaluateEnabled; nextBrowser.headless = currentForm.headless; - nextBrowser.noSandbox = currentForm.noSandbox; + nextBrowser.preferredBrowser = normalizePreferredBrowser(currentForm.preferredBrowser); delete nextBrowser.executablePath; - - const snapshotMode = currentForm.snapshotDefaultMode.trim(); - if (snapshotMode) { - nextBrowser.snapshotDefaults = { mode: snapshotMode }; - } else { - delete nextBrowser.snapshotDefaults; - } + delete nextBrowser.evaluateEnabled; + delete nextBrowser.noSandbox; + delete nextBrowser.snapshotDefaults; + delete nextBrowser.extraArgs; const nextSSRFRules: Record = { dangerouslyAllowPrivateNetwork: currentForm.ssrfDangerouslyAllowPrivateNetwork, @@ -769,12 +961,6 @@ export function CallsToolsTab() { } nextBrowser.ssrfPolicy = nextSSRFRules; - if (parsedExtraArgs.value && parsedExtraArgs.value.length > 0) { - nextBrowser.extraArgs = parsedExtraArgs.value; - } else { - delete nextBrowser.extraArgs; - } - nextToolsConfig.browser = nextBrowser; const payload = nextToolsConfig; updateSettings.mutate( @@ -861,13 +1047,90 @@ export function CallsToolsTab() { }, [handleSaveWebSearch, webSearchDisabled, webSearchForm, webSearchProviderAPIKeys] ); - const handleWebFetchTypeChange = React.useCallback( - (nextType: string) => { - const normalizedType = normalizeWebFetchType(nextType); - if (webFetchForm.type === 
normalizedType) { + const handleWebSearchProviderChange = React.useCallback( + (nextProvider: string) => { + const nextProviderApiKey = resolveWebSearchDraftApiKey(nextProvider); + const nextForm: WebSearchFormState = { + ...webSearchForm, + provider: nextProvider, + apiKey: nextProviderApiKey, + }; + setWebSearchForm(nextForm); + if (webSearchDisabled) { + return; + } + handleSaveWebSearch(nextForm, webSearchProviderAPIKeys); + }, + [ + handleSaveWebSearch, + resolveWebSearchDraftApiKey, + webSearchDisabled, + webSearchForm, + webSearchProviderAPIKeys, + ] + ); + const handleBrowserPreferredBrowserChange = React.useCallback( + (value: string) => { + const nextForm: BrowserControlFormState = { + ...browserForm, + preferredBrowser: normalizePreferredBrowser(value), + }; + setBrowserForm(nextForm); + if (browserDisabled) { + return; + } + handleSaveBrowser(nextForm); + }, + [browserDisabled, browserForm, handleSaveBrowser] + ); + const handleBrowserHeadlessChange = React.useCallback( + (checked: boolean) => { + const nextForm: BrowserControlFormState = { + ...browserForm, + headless: Boolean(checked), + }; + setBrowserForm(nextForm); + if (browserDisabled) { + return; + } + handleSaveBrowser(nextForm); + }, + [browserDisabled, browserForm, handleSaveBrowser] + ); + const handleBrowserPrivateNetworkChange = React.useCallback( + (checked: boolean) => { + const nextForm: BrowserControlFormState = { + ...browserForm, + ssrfDangerouslyAllowPrivateNetwork: Boolean(checked), + }; + setBrowserForm(nextForm); + if (browserDisabled) { + return; + } + handleSaveBrowser(nextForm); + }, + [browserDisabled, browserForm, handleSaveBrowser] + ); + const handleWebFetchPreferredBrowserChange = React.useCallback( + (value: string) => { + const nextForm: WebFetchFormState = { + ...webFetchForm, + preferredBrowser: normalizePreferredBrowser(value), + }; + setWebFetchForm(nextForm); + if (webFetchDisabled) { return; } - const nextForm: WebFetchFormState = { ...webFetchForm, type: 
normalizedType }; + handleSaveWebFetch(nextForm); + }, + [handleSaveWebFetch, webFetchDisabled, webFetchForm] + ); + const handleWebFetchHeadlessChange = React.useCallback( + (checked: boolean) => { + const nextForm: WebFetchFormState = { + ...webFetchForm, + headless: Boolean(checked), + }; setWebFetchForm(nextForm); if (webFetchDisabled) { return; @@ -899,10 +1162,60 @@ export function CallsToolsTab() {
); }, []); + const renderRuntimeDetectionCard = React.useCallback((rows: RuntimeDetectionRow[]) => { + return ( +
+
+ {rows.map((item, index) => { + const rowSpacingClass = rows.length === 1 + ? "" + : index === 0 + ? "pb-2" + : index === rows.length - 1 + ? "pt-2" + : "py-2"; + return ( +
+ {item.label} + {item.badge ? ( + + + {item.value} + + ) : ( + + {item.value} + + )} +
+ ); + })} +
+
+ ); + }, []); return ( + {t("settings.gateway.tools.loading")} +
+ ) : toolsLoadError ? ( +
+ {toolsLoadError} +
+ ) : filteredTools.length === 0 ? (
{t("settings.tools.list.empty")}
@@ -941,7 +1254,15 @@ export function CallsToolsTab() { ) } rightContent={ - !selectedTool ? ( + isInitialToolsLoad ? ( +
+ {t("settings.gateway.tools.loading")} +
+ ) : toolsLoadError ? ( +
+ {toolsLoadError} +
+ ) : !selectedTool ? (
{t("settings.tools.list.empty")}
@@ -952,19 +1273,61 @@ export function CallsToolsTab() { const isWebFetchTool = selectedTool.id === "web_fetch"; const isGatewayTool = selectedTool.id === "gateway"; const isBrowserTool = selectedTool.id === "browser"; + const isCanvasTool = selectedTool.id === "canvas"; + const isNodesTool = selectedTool.id === "nodes"; const gatewayConfig = isGatewayTool && isRecord(toolsConfig.gateway) ? (toolsConfig.gateway as Record) : undefined; const hasGatewayConfig = Boolean(gatewayConfig && Object.keys(gatewayConfig).length > 0); - const toolDependencies = resolveToolDependencies(selectedTool); - const browserPlaywrightExecutablePath = isBrowserTool - ? (selectedTool.requirements ?? []).find( - (requirement) => - requirement.id === "browser.playwright_runtime" && - requirement.available && - requirement.reason.trim() !== "" - )?.reason ?? "" - : ""; + const toolRequirements = resolveToolRequirements(selectedTool); + const toolPermissions = resolveToolPermissionBadges(selectedTool); + const toolToggleDisabledReason = (isCanvasTool || isNodesTool) + ? t("settings.tools.reason.remoteNodeRuntimeUnavailable") + : undefined; + const toolToggleDisabled = gatewayToolEnablePending || isCanvasTool || isNodesTool; + const browserRuntimeRequirement = + isBrowserTool || isWebFetchTool + ? (selectedTool.requirements ?? []).find((requirement) => + requirement.id === (isBrowserTool ? "browser.cdp_runtime" : "web_fetch.local_browser") + ) + : undefined; + const browserRuntimeData = browserRuntimeRequirement && isRecord(browserRuntimeRequirement.data) + ? 
(browserRuntimeRequirement.data as Record) + : undefined; + const browserCandidates = normalizeRuntimeBrowserCandidates(browserRuntimeData?.candidates); + const availableBrowserCandidates = browserCandidates.filter((candidate) => candidate.available); + const browserSelectOptions = availableBrowserCandidates; + const webFetchPreferredBrowserValue = browserSelectOptions.some( + (candidate) => candidate.id === webFetchForm.preferredBrowser + ) + ? webFetchForm.preferredBrowser + : (browserSelectOptions[0]?.id ?? ""); + const browserPreferredBrowserValue = browserSelectOptions.some( + (candidate) => candidate.id === browserForm.preferredBrowser + ) + ? browserForm.preferredBrowser + : (browserSelectOptions[0]?.id ?? ""); + const runtimeDetectionRows: RuntimeDetectionRow[] = browserCandidates.map((candidate) => { + const normalizedError = candidate.error.trim().toLowerCase(); + if (candidate.available) { + return { + label: candidate.label, + value: candidate.execPath || t("settings.tools.runtimeDetection.detected"), + }; + } + if (normalizedError.includes("browser executable not found")) { + return { + label: candidate.label, + value: t("settings.tools.runtimeDetection.notInstalled"), + badge: "not_installed", + }; + } + return { + label: candidate.label, + value: t("settings.tools.runtimeDetection.notDetected"), + badge: "not_detected", + }; + }); const webSearchProviderMeta = [ selectedWebSearchProvider?.apiBaseUrl ? 
`${t("settings.tools.webSearch.apiBase")}: ${selectedWebSearchProvider.apiBaseUrl}` @@ -987,11 +1350,14 @@ export function CallsToolsTab() { statusBadge={renderToolStatusBadge(status)} enabledLabel={t("settings.tools.detail.enabled")} enabled={selectedTool.available} - enabledDisabled={gatewayToolEnablePending} + enabledDisabled={toolToggleDisabled} + enabledDisabledReason={toolToggleDisabledReason} onEnabledChange={(enabled) => { void handleToggleToolEnabled(selectedTool, enabled); }} - dependencies={toolDependencies} + permissionsLabel={t("settings.tools.detail.permissions.label")} + permissions={toolPermissions} + requirements={toolRequirements} /> } content={ @@ -1043,11 +1409,14 @@ export function CallsToolsTab() { statusBadge={renderToolStatusBadge(status)} enabledLabel={t("settings.tools.detail.enabled")} enabled={selectedTool.available} - enabledDisabled={gatewayToolEnablePending} + enabledDisabled={toolToggleDisabled} + enabledDisabledReason={toolToggleDisabledReason} onEnabledChange={(enabled) => { void handleToggleToolEnabled(selectedTool, enabled); }} - dependencies={toolDependencies} + permissionsLabel={t("settings.tools.detail.permissions.label")} + permissions={toolPermissions} + requirements={toolRequirements} /> } content={ @@ -1091,15 +1460,8 @@ export function CallsToolsTab() { - setBrowserForm((prev) => ({ - ...prev, - snapshotDefaultMode: event.target.value, - })) - } - onBlur={handleBrowserFieldBlur} + value={browserPreferredBrowserValue} + onChange={(event) => { + handleBrowserPreferredBrowserChange(event.target.value); + }} className={webSearchControlClassName} - disabled={browserDisabled} + disabled={browserDisabled || browserSelectOptions.length === 0} > - - + {browserSelectOptions.length === 0 ? ( + + ) : browserSelectOptions.map((candidate) => ( + + ))}
-
- {renderWebSearchFieldLabel( - t("settings.tools.browserControl.extraArgs") - )} - - setBrowserForm((prev) => ({ ...prev, extraArgsJson: event.target.value })) - } - onBlur={handleBrowserFieldBlur} - placeholder='["--window-size=1920,1080","--disable-infobars"]' - className={webSearchControlClassName} - size="compact" - disabled={browserDisabled} - /> -
+ {renderRuntimeDetectionCard(runtimeDetectionRows)}
{renderWebSearchFieldLabel( - t("settings.tools.browserControl.executablePath"), - t("settings.tools.browserControl.executablePathDesc") + t("settings.tools.browserControl.headless"), + t("settings.tools.browserControl.headlessDesc") )} -
@@ -1540,13 +1827,7 @@ export function CallsToolsTab() { )} - setBrowserForm((prev) => ({ - ...prev, - ssrfDangerouslyAllowPrivateNetwork: Boolean(checked), - })) - } - onBlur={handleBrowserFieldBlur} + onCheckedChange={handleBrowserPrivateNetworkChange} disabled={browserDisabled} />
@@ -1644,11 +1925,14 @@ export function CallsToolsTab() { statusBadge={renderToolStatusBadge(status)} enabledLabel={t("settings.tools.detail.enabled")} enabled={selectedTool.available} - enabledDisabled={gatewayToolEnablePending} + enabledDisabled={toolToggleDisabled} + enabledDisabledReason={toolToggleDisabledReason} onEnabledChange={(enabled) => { void handleToggleToolEnabled(selectedTool, enabled); }} - dependencies={toolDependencies} + permissionsLabel={t("settings.tools.detail.permissions.label")} + permissions={toolPermissions} + requirements={toolRequirements} /> } content={ @@ -1663,219 +1947,87 @@ export function CallsToolsTab() {
{renderWebSearchFieldLabel( - t("settings.tools.webFetch.type"), - t("settings.tools.webFetch.typeDesc") + t("settings.tools.webFetch.preferredBrowser"), + t("settings.tools.webFetch.preferredBrowserDesc") )} - { + handleWebFetchPreferredBrowserChange(event.target.value); + }} + className={webSearchControlClassName} + disabled={webFetchDisabled || browserSelectOptions.length === 0} > - - - {t("settings.tools.webFetch.typeValue.playwright")} - - - {t("settings.tools.webFetch.typeValue.builtin")} - - - + {browserSelectOptions.length === 0 ? ( + + ) : browserSelectOptions.map((candidate) => ( + + ))} + +
+ + {renderRuntimeDetectionCard(runtimeDetectionRows)} + +
+ {renderWebSearchFieldLabel( + t("settings.tools.webFetch.headless"), + t("settings.tools.webFetch.headlessDesc") + )} + +
+ +
+ {renderWebSearchFieldLabel(t("settings.tools.webFetch.timeoutSeconds"))} + + setWebFetchForm((prev) => ({ ...prev, timeoutSeconds: event.target.value })) + } + onBlur={handleWebFetchFieldBlur} + placeholder="20" + className={webSearchControlClassName} + size="compact" + disabled={webFetchDisabled} + /> +
+ +
+ {renderWebSearchFieldLabel(t("settings.tools.webFetch.maxChars"))} + + setWebFetchForm((prev) => ({ ...prev, maxChars: event.target.value })) + } + onBlur={handleWebFetchFieldBlur} + placeholder="50000" + className={webSearchControlClassName} + size="compact" + disabled={webFetchDisabled} + /> +
+
+
- {webFetchForm.type === "playwright" ? ( - <> - -
- {renderWebSearchFieldLabel( - t("settings.tools.webFetch.playwrightMarkdown"), - t("settings.tools.webFetch.playwrightMarkdownDesc") - )} - - setWebFetchForm((prev) => ({ ...prev, playwrightMarkdown: Boolean(checked) })) - } - onBlur={handleWebFetchFieldBlur} - disabled={webFetchDisabled} - /> -
- -
-

- {t("settings.tools.webFetch.playwrightHint")} -

-
- - ) : ( - <> - -
- {renderWebSearchFieldLabel( - t("settings.tools.webFetch.acceptMarkdown"), - t("settings.tools.webFetch.acceptMarkdownDesc") - )} - - setWebFetchForm((prev) => ({ ...prev, acceptMarkdown: Boolean(checked) })) - } - onBlur={handleWebFetchFieldBlur} - disabled={webFetchDisabled} - /> -
- -
- {renderWebSearchFieldLabel( - t("settings.tools.webFetch.enableUserAgent"), - t("settings.tools.webFetch.enableUserAgentDesc") - )} - - setWebFetchForm((prev) => ({ ...prev, enableUserAgent: Boolean(checked) })) - } - onBlur={handleWebFetchFieldBlur} - disabled={webFetchDisabled} - /> -
- -
- {renderWebSearchFieldLabel(t("settings.tools.webFetch.userAgent"))} - - setWebFetchForm((prev) => ({ ...prev, userAgent: event.target.value })) - } - onBlur={handleWebFetchFieldBlur} - placeholder="Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/605.1.15 (KHTML, like Gecko) Version/17.0 Safari/605.1.15" - className={webSearchControlClassName} - size="compact" - disabled={webFetchDisabled || !webFetchForm.enableUserAgent} - /> -
- -
- {renderWebSearchFieldLabel( - t("settings.tools.webFetch.acceptLanguage") - )} - - setWebFetchForm((prev) => ({ ...prev, acceptLanguage: event.target.value })) - } - onBlur={handleWebFetchFieldBlur} - placeholder="en-US,en;q=0.9" - className={webSearchControlClassName} - size="compact" - disabled={webFetchDisabled} - /> -
- -
- {renderWebSearchFieldLabel(t("settings.tools.webFetch.timeoutSeconds"))} - - setWebFetchForm((prev) => ({ ...prev, timeoutSeconds: event.target.value })) - } - onBlur={handleWebFetchFieldBlur} - placeholder="20" - className={webSearchControlClassName} - size="compact" - disabled={webFetchDisabled} - /> -
- -
- {renderWebSearchFieldLabel(t("settings.tools.webFetch.maxChars"))} - - setWebFetchForm((prev) => ({ ...prev, maxChars: event.target.value })) - } - onBlur={handleWebFetchFieldBlur} - placeholder="50000" - className={webSearchControlClassName} - size="compact" - disabled={webFetchDisabled} - /> -
- -
- {renderWebSearchFieldLabel( - t("settings.tools.webFetch.maxRedirects"), - t("settings.tools.webFetch.maxRedirectsDesc") - )} - - setWebFetchForm((prev) => ({ ...prev, maxRedirects: event.target.value })) - } - onBlur={handleWebFetchFieldBlur} - placeholder="3" - className={webSearchControlClassName} - size="compact" - disabled={webFetchDisabled} - /> -
- -
- {renderWebSearchFieldLabel( - t("settings.tools.webFetch.retryMax"), - t("settings.tools.webFetch.retryMaxDesc") - )} - - setWebFetchForm((prev) => ({ ...prev, retryMax: event.target.value })) - } - onBlur={handleWebFetchFieldBlur} - placeholder="2" - className={webSearchControlClassName} - size="compact" - disabled={webFetchDisabled} - /> -
- -
- {renderWebSearchFieldLabel( - t("settings.tools.webFetch.headers"), - t("settings.tools.webFetch.headersDesc") - )} - - setWebFetchForm((prev) => ({ ...prev, headersJson: event.target.value })) - } - onBlur={handleWebFetchFieldBlur} - placeholder='{"X-Test":"1"}' - className={webSearchControlClassName} - size="compact" - disabled={webFetchDisabled} - /> -
-
- -
- - )}
diff --git a/frontend/src/features/settings/calls/components/calls-tools-runtime-detection.ts b/frontend/src/features/settings/calls/components/calls-tools-runtime-detection.ts new file mode 100644 index 0000000..ff9b89d --- /dev/null +++ b/frontend/src/features/settings/calls/components/calls-tools-runtime-detection.ts @@ -0,0 +1,43 @@ +import { isRecord } from "../utils/calls-utils"; +import { readBoolValue, readStringValue } from "../utils/web-tool-settings-utils"; + +export type RuntimeBrowserCandidate = { + id: string; + label: string; + available: boolean; + execPath: string; + error: string; +}; + +export type RuntimeDetectionRow = { + label: string; + value: string; + badge?: "not_installed" | "not_detected"; +}; + +const BROWSER_LABELS: Record = { + chrome: "Chrome", + chromium: "Chromium", + edge: "Edge", + brave: "Brave", +}; + +export const normalizeRuntimeBrowserCandidates = (value: unknown): RuntimeBrowserCandidate[] => { + if (!Array.isArray(value)) { + return []; + } + return value.flatMap((item) => { + if (!isRecord(item)) { + return []; + } + const id = readStringValue(item, "id", "").trim().toLowerCase(); + const fallbackLabel = id ? (BROWSER_LABELS[id] ?? 
id) : "Browser"; + return [{ + id, + label: readStringValue(item, "label", fallbackLabel).trim() || fallbackLabel, + available: readBoolValue(item, "available", false), + execPath: readStringValue(item, "execPath", "").trim(), + error: readStringValue(item, "error", "").trim(), + }]; + }); +}; diff --git a/frontend/src/features/settings/calls/components/tool-detail-layout.tsx b/frontend/src/features/settings/calls/components/tool-detail-layout.tsx index a946a8b..80a0d1c 100644 --- a/frontend/src/features/settings/calls/components/tool-detail-layout.tsx +++ b/frontend/src/features/settings/calls/components/tool-detail-layout.tsx @@ -1,10 +1,18 @@ import * as React from "react"; -import { Check, HelpCircle } from "lucide-react"; +import { HelpCircle } from "lucide-react"; import { Badge } from "@/shared/ui/badge"; import { Empty, EmptyDescription, EmptyHeader, EmptyMedia, EmptyTitle } from "@/shared/ui/empty"; import { Card, CardContent, CardHeader, CardTitle } from "@/shared/ui/card"; -import { Separator } from "@/shared/ui/separator"; +import { + SETTINGS_ROW_CONTENT_BASE_CLASS, + SETTINGS_ROW_CLASS, + SETTINGS_ROW_LABEL_CLASS, + SettingsCompactListCard, + SettingsCompactRow, + SettingsCompactSeparator, + SettingsSeparator, +} from "@/shared/ui/settings-layout"; import { Switch } from "@/shared/ui/switch"; import { Tabs, TabsContent, TabsList, TabsTrigger } from "@/shared/ui/tabs"; import { Tooltip, TooltipContent, TooltipProvider, TooltipTrigger } from "@/shared/ui/tooltip"; @@ -60,12 +68,17 @@ function ToolConfigEmptyIllustration() { ); } -export type ToolDependencyStatus = { +export type ToolRequirementItem = { id: string; name: string; - ok: boolean; - reason: string; - badges?: string[]; + value: string; + tone?: "neutral" | "success" | "warning" | "danger"; +}; + +export type ToolPermissionBadge = { + id: string; + label: string; + tone?: "neutral" | "info" | "warning"; }; export function ToolDetailLayout({ @@ -76,7 +89,7 @@ export function ToolDetailLayout({ 
content: React.ReactNode; }) { return ( -
+
{overview} {content}
@@ -95,7 +108,7 @@ function ToolIdentityBlock({ return (
-
{title}
+
{title}
@@ -117,66 +130,121 @@ function ToolIdentityBlock({ ); } +function ToolHeaderMetaRow({ + label, + value, +}: { + label: React.ReactNode; + value: React.ReactNode; +}) { + return ( +
+
{label}
+
+ {value} +
+
+ ); +} + function ToolStatusControlBlock({ statusBadge, enabledLabel, enabled, disabled, + disabledReason, onEnabledChange, }: { statusBadge: React.ReactNode; enabledLabel: string; enabled: boolean; disabled?: boolean; + disabledReason?: string; onEnabledChange: (enabled: boolean) => void; }) { + const switchControl = ; return ( -
+
{statusBadge}
- {enabledLabel} - + {enabledLabel} + {disabled && disabledReason ? ( + + + + {switchControl} + + {disabledReason} + + + ) : switchControl}
); } -function ToolDependencyItem({ item }: { item: ToolDependencyStatus }) { - const badges = (item.badges ?? []).map((entry) => entry.trim()).filter((entry) => entry !== ""); +function permissionBadgeClassName(tone: ToolPermissionBadge["tone"]) { + switch (tone) { + case "info": + return "border-sky-500/25 bg-sky-500/10 text-sky-700 dark:border-sky-400/25 dark:bg-sky-400/10 dark:text-sky-200"; + case "warning": + return "border-amber-500/25 bg-amber-500/10 text-amber-800 dark:border-amber-400/25 dark:bg-amber-400/10 dark:text-amber-100"; + case "neutral": + default: + return "border-border/70 bg-muted text-muted-foreground"; + } +} + +function requirementValueClassName(tone: ToolRequirementItem["tone"]) { + switch (tone) { + case "success": + return "text-emerald-700 dark:text-emerald-300"; + case "warning": + return "text-amber-800 dark:text-amber-100"; + case "danger": + return "text-destructive"; + case "neutral": + default: + return "text-muted-foreground"; + } +} + +function ToolPermissionsValue({ badges }: { badges: ToolPermissionBadge[] }) { return ( -
-
{item.name}
- {badges.length > 0 ? ( -
- {badges.map((badge, index) => ( - - {badge} - - ))} -
- ) : item.ok ? ( - - ) : ( - {item.reason} - )} -
+ <> + {badges.map((badge) => ( + + {badge.label} + + ))} + ); } -function ToolDependenciesBlock({ items }: { items: ToolDependencyStatus[] }) { +function ToolRequirementCard({ items }: { items: ToolRequirementItem[] }) { if (items.length === 0) { return null; } return ( - <> - + {items.map((item, index) => ( - - {index < items.length - 1 ? : null} + {index > 0 ? : null} + + + {item.value} + + ))} - + ); } @@ -188,8 +256,11 @@ export function ToolOverviewCard({ enabledLabel, enabled, enabledDisabled, + enabledDisabledReason, onEnabledChange, - dependencies, + permissionsLabel, + permissions, + requirements, }: { title: string; description: string; @@ -198,25 +269,32 @@ export function ToolOverviewCard({ enabledLabel: string; enabled: boolean; enabledDisabled?: boolean; + enabledDisabledReason?: string; onEnabledChange: (enabled: boolean) => void; - dependencies: ToolDependencyStatus[]; + permissionsLabel: string; + permissions: ToolPermissionBadge[]; + requirements: ToolRequirementItem[]; }) { return ( - - -
- - -
- -
-
+
+
+ + +
+ + } + /> + {requirements.length > 0 ? : null} +
); } diff --git a/frontend/src/features/settings/calls/types.ts b/frontend/src/features/settings/calls/types.ts index 233f193..fb098ff 100644 --- a/frontend/src/features/settings/calls/types.ts +++ b/frontend/src/features/settings/calls/types.ts @@ -10,6 +10,7 @@ export type ToolRequirementStatus = { name: string; available: boolean; reason: string; + data?: unknown; }; export type ToolItem = { diff --git a/frontend/src/features/settings/calls/utils/calls-utils.ts b/frontend/src/features/settings/calls/utils/calls-utils.ts index 330e96f..0326740 100644 --- a/frontend/src/features/settings/calls/utils/calls-utils.ts +++ b/frontend/src/features/settings/calls/utils/calls-utils.ts @@ -248,6 +248,7 @@ const mapGatewayToolRequirement = (requirement: unknown): ToolRequirementStatus name: typeof source.name === "string" && source.name.trim() !== "" ? source.name.trim() : id, available: source.available !== false, reason: typeof source.reason === "string" ? source.reason.trim() : "", + data: source.data, }; }; diff --git a/frontend/src/features/settings/calls/utils/web-tool-settings-utils.ts b/frontend/src/features/settings/calls/utils/web-tool-settings-utils.ts index 1df92c2..6cedf50 100644 --- a/frontend/src/features/settings/calls/utils/web-tool-settings-utils.ts +++ b/frontend/src/features/settings/calls/utils/web-tool-settings-utils.ts @@ -1,5 +1,5 @@ export type WebSearchType = "api" | "external_tools"; -export type WebFetchType = "playwright" | "builtin"; +export type SupportedBrowserID = "chrome" | "chromium" | "edge" | "brave"; export type WebSearchFormState = { type: WebSearchType; @@ -15,29 +15,19 @@ export type WebSearchFormState = { }; export type WebFetchFormState = { - type: WebFetchType; - playwrightMarkdown: boolean; - acceptMarkdown: boolean; - enableUserAgent: boolean; - userAgent: string; - acceptLanguage: string; + headless: boolean; + preferredBrowser: SupportedBrowserID; timeoutSeconds: string; maxChars: string; - maxRedirects: string; - retryMax: 
string; - headersJson: string; }; export type BrowserControlFormState = { enabled: boolean; - evaluateEnabled: boolean; headless: boolean; - noSandbox: boolean; - snapshotDefaultMode: string; + preferredBrowser: SupportedBrowserID; ssrfDangerouslyAllowPrivateNetwork: boolean; ssrfAllowedHostnamesJson: string; ssrfHostnameAllowlistJson: string; - extraArgsJson: string; }; export type WebSearchProviderOption = { @@ -92,15 +82,12 @@ export const normalizeWebSearchType = (value: string): WebSearchType => { return "api"; }; -export const normalizeWebFetchType = (value: string): WebFetchType => { +export const normalizePreferredBrowser = (value: string): SupportedBrowserID => { const normalized = value.trim().toLowerCase(); - if (normalized === "playwright") { - return "playwright"; - } - if (normalized === "builtin") { - return "builtin"; + if (normalized === "chromium" || normalized === "edge" || normalized === "brave") { + return normalized; } - return "builtin"; + return "chrome"; }; export const parseNumberInput = (value: string) => { @@ -176,17 +163,6 @@ export const readObjectValue = (source: Record | undefined, key return undefined; }; -export const stringifyObjectValue = (source: Record | undefined) => { - if (!source) { - return ""; - } - const entries = Object.entries(source).sort(([left], [right]) => left.localeCompare(right)); - if (entries.length === 0) { - return ""; - } - return JSON.stringify(Object.fromEntries(entries)); -}; - export const stringifyStringArrayValue = (source: unknown) => { if (!Array.isArray(source)) { return ""; @@ -200,22 +176,6 @@ export const stringifyStringArrayValue = (source: unknown) => { return JSON.stringify(values); }; -export const parseObjectJSON = (raw: string) => { - const trimmed = raw.trim(); - if (!trimmed) { - return { value: undefined as Record | undefined, error: null as string | null }; - } - try { - const parsed = JSON.parse(trimmed); - if (!isRecord(parsed)) { - return { value: undefined, error: "invalid-object" 
}; - } - return { value: parsed as Record, error: null }; - } catch { - return { value: undefined, error: "invalid-json" }; - } -}; - export const parseStringArrayJSON = (raw: string) => { const trimmed = raw.trim(); if (!trimmed) { diff --git a/frontend/src/features/settings/connectors/index.tsx b/frontend/src/features/settings/connectors/index.tsx index 97665c1..e3e8bb9 100644 --- a/frontend/src/features/settings/connectors/index.tsx +++ b/frontend/src/features/settings/connectors/index.tsx @@ -1,5 +1,4 @@ import * as React from "react"; -import { Events } from "@wailsio/runtime"; import { CircleOff, ExternalLink, Eye, Link2, Loader2, Plug2, RefreshCw, Search, Trash2 } from "lucide-react"; import { Button } from "@/shared/ui/button"; @@ -15,13 +14,16 @@ import { } from "@/shared/ui/settings-layout"; import { useI18n } from "@/shared/i18n"; import { + useCancelConnectorConnect, useClearConnector, - useConnectConnector, + useConnectorConnectSession, + useFinishConnectorConnect, useConnectors, useOpenConnectorSite, + useStartConnectorConnect, } from "@/shared/query/connectors"; import { messageBus } from "@/shared/message"; -import type { Connector } from "@/shared/contracts/connectors"; +import type { Connector, ConnectorConnectSession, FinishConnectorConnectResult } from "@/shared/contracts/connectors"; import { cn } from "@/lib/utils"; const STATUS_META: Record }> = { @@ -42,7 +44,7 @@ const STATUS_META: Record = { google: { group: "search_engine", labelKey: "settings.connectors.item.google", fallbackLabel: "Google" }, - xiaohongshu: { group: "search_engine", labelKey: "settings.connectors.item.xiaohongshu", fallbackLabel: "Xiaohongshu" }, + github: { group: "developer", labelKey: "settings.connectors.item.github", fallbackLabel: "GitHub" }, + reddit: { group: "community", labelKey: "settings.connectors.item.reddit", fallbackLabel: "Reddit" }, + zhihu: { group: "search_engine", labelKey: "settings.connectors.item.zhihu", fallbackLabel: "Zhihu" }, + x: { group: 
"community", labelKey: "settings.connectors.item.x", fallbackLabel: "X" }, + xiaohongshu: { group: "community", labelKey: "settings.connectors.item.xiaohongshu", fallbackLabel: "Xiaohongshu" }, bilibili: { group: "video", labelKey: "settings.connectors.item.bilibili", fallbackLabel: "Bilibili" }, }; @@ -81,7 +87,7 @@ const resolveConnectorMeta = (connectorType: string): ConnectorMeta | null => { const resolveConnectorGroup = (connector: Connector): ConnectorGroup => { const rawGroup = connector.group?.trim().toLowerCase(); - if (rawGroup === "search_engine" || rawGroup === "video" || rawGroup === "other") { + if (rawGroup === "search_engine" || rawGroup === "community" || rawGroup === "video" || rawGroup === "developer" || rawGroup === "other") { return rawGroup; } const meta = resolveConnectorMeta(connector.type); @@ -91,7 +97,9 @@ const resolveConnectorGroup = (connector: Connector): ConnectorGroup => { export function ConnectorsSection() { const { t } = useI18n(); const connectors = useConnectors(); - const connectConnector = useConnectConnector(); + const startConnectorConnect = useStartConnectorConnect(); + const finishConnectorConnect = useFinishConnectorConnect(); + const cancelConnectorConnect = useCancelConnectorConnect(); const clearConnector = useClearConnector(); const openConnectorSite = useOpenConnectorSite(); @@ -99,9 +107,11 @@ export function ConnectorsSection() { const [query, setQuery] = React.useState(""); const [loginDialogOpen, setLoginDialogOpen] = React.useState(false); const [loginTarget, setLoginTarget] = React.useState(null); + const [loginSessionId, setLoginSessionId] = React.useState(""); + const [loginResult, setLoginResult] = React.useState(null); const [loginError, setLoginError] = React.useState(""); - const [needsInstall, setNeedsInstall] = React.useState(false); const [cookiesDialogOpen, setCookiesDialogOpen] = React.useState(false); + const loginSession = useConnectorConnectSession({ sessionId: loginSessionId }, loginDialogOpen 
&& loginSessionId.trim().length > 0); const items = connectors.data ?? []; const resolveConnectorLabel = React.useCallback( @@ -120,8 +130,12 @@ export function ConnectorsSection() { switch (group) { case "search_engine": return t("settings.connectors.group.searchEngine"); + case "community": + return t("settings.connectors.group.community"); case "video": return t("settings.connectors.group.video"); + case "developer": + return t("settings.connectors.group.developer"); default: return t("settings.connectors.group.other"); } @@ -198,46 +212,155 @@ export function ConnectorsSection() { const selected = items.find((item) => item.id === selectedId) ?? null; const status = STATUS_META[selected?.status ?? "disconnected"] ?? STATUS_META.disconnected; - const isBusy = connectConnector.isPending || openConnectorSite.isPending; - const isLoginRunning = connectConnector.isPending; + const isBusy = + startConnectorConnect.isPending || + finishConnectorConnect.isPending || + cancelConnectorConnect.isPending || + openConnectorSite.isPending || + clearConnector.isPending; + const isLoginRunning = + startConnectorConnect.isPending || + finishConnectorConnect.isPending || + cancelConnectorConnect.isPending; const isOpenRunning = openConnectorSite.isPending; - const resolveLoginError = (error: unknown) => { + const resolveLoginError = React.useCallback((error: unknown) => { const message = error instanceof Error ? 
error.message : String(error); - if (message.toLowerCase().includes("playwright not installed")) { - setNeedsInstall(true); - return t("settings.connectors.playwrightMissing"); + if (message.toLowerCase().includes("no supported browser detected")) { + return t("settings.connectors.browserMissing"); + } + if (message.toLowerCase().includes("connector browser session ended")) { + return t("settings.connectors.browserSessionEnded"); + } + if (message.toLowerCase().includes("connector session not found")) { + return t("settings.connectors.loginSessionMissing"); } return error instanceof Error ? error.message : t("settings.connectors.loginError"); - }; + }, [t]); + + const toLoginResult = React.useCallback((session: ConnectorConnectSession): FinishConnectorConnectResult => { + return { + sessionId: session.sessionId, + saved: session.saved, + rawCookiesCount: session.rawCookiesCount, + filteredCookiesCount: session.filteredCookiesCount, + domains: session.domains, + reason: session.reason, + connector: session.connector, + }; + }, []); + + const disposeLoginSession = React.useCallback(async (sessionId: string) => { + const trimmed = sessionId.trim(); + if (!trimmed) { + return; + } + try { + await cancelConnectorConnect.mutateAsync({ sessionId: trimmed }); + } catch { + // ignore disposal failures; a fresh connect attempt will replace stale sessions + } + }, [cancelConnectorConnect]); - const runConnect = async (connector: Connector) => { + const resetLoginState = React.useCallback(() => { + setLoginDialogOpen(false); + setLoginTarget(null); + setLoginSessionId(""); + setLoginResult(null); + setLoginError(""); + }, []); + + const handleCancelLogin = React.useCallback(async () => { + const sessionId = loginSessionId.trim(); + if (sessionId) { + try { + await disposeLoginSession(sessionId); + } catch (error) { + messageBus.publishToast({ + intent: "danger", + title: t("settings.connectors.loginTitle"), + description: resolveLoginError(error), + }); + } + } + 
resetLoginState(); + }, [disposeLoginSession, loginSessionId, resetLoginState, resolveLoginError, t]); + + const handleConnect = async (connector: Connector) => { + setLoginTarget(connector); + setLoginDialogOpen(true); + setLoginSessionId(""); + setLoginResult(null); setLoginError(""); - setNeedsInstall(false); try { - await connectConnector.mutateAsync({ id: connector.id }); - setLoginDialogOpen(false); - setLoginTarget(null); + const result = await startConnectorConnect.mutateAsync({ id: connector.id }); + setLoginSessionId(result.sessionId); } catch (error) { setLoginError(resolveLoginError(error)); } }; - const handleConnect = async (connector: Connector) => { - setLoginTarget(connector); + const handleFinishLogin = async () => { + const sessionId = loginSessionId.trim(); + if (!sessionId) { + setLoginError(t("settings.connectors.loginSessionMissing")); + return; + } setLoginError(""); - setNeedsInstall(false); - setLoginDialogOpen(true); - await runConnect(connector); + try { + const result = await finishConnectorConnect.mutateAsync({ sessionId }); + setLoginResult(result); + await disposeLoginSession(sessionId); + setLoginSessionId(""); + if (!result.saved) { + setLoginError(t("settings.connectors.noCookiesRead")); + return; + } + resetLoginState(); + } catch (error) { + setLoginError(resolveLoginError(error)); + } }; + React.useEffect(() => { + const session = loginSession.data; + if (!session || loginSessionId.trim().length === 0 || isLoginRunning) { + return; + } + if (session.state === "running") { + return; + } + + const sessionId = session.sessionId; + setLoginResult(toLoginResult(session)); + void disposeLoginSession(sessionId); + setLoginSessionId(""); + + if (session.state === "completed" && session.saved) { + resetLoginState(); + return; + } + + if (session.state === "completed") { + setLoginError(t("settings.connectors.noCookiesRead")); + return; + } + + if (session.error) { + setLoginError(session.error); + return; + } + + 
setLoginError(t("settings.connectors.loginError")); + }, [disposeLoginSession, isLoginRunning, loginSession.data, loginSessionId, resetLoginState, t, toLoginResult]); + const resolveOpenError = (error: unknown) => { const message = error instanceof Error ? error.message : String(error); if (message.toLowerCase().includes("no cookies")) { return t("settings.connectors.noCookies"); } - if (message.toLowerCase().includes("playwright not installed")) { - return t("settings.connectors.playwrightMissing"); + if (message.toLowerCase().includes("no supported browser detected")) { + return t("settings.connectors.browserMissing"); } return error instanceof Error ? error.message : t("settings.connectors.openSiteError"); }; @@ -254,20 +377,23 @@ export function ConnectorsSection() { } }; - const handleOpenExternalTools = () => { - Events.Emit("settings:navigate", "external-tools"); - }; - const rowClassName = SETTINGS_ROW_CLASS; - const dialogStatus = needsInstall - ? t("settings.connectors.installRequiredStatus") - : isLoginRunning - ? t("settings.connectors.loginRunning") - : t("settings.connectors.loginDone"); + const dialogStatus = startConnectorConnect.isPending + ? t("settings.connectors.loginLaunching") + : finishConnectorConnect.isPending + ? t("settings.connectors.loginReadingCookies") + : cancelConnectorConnect.isPending + ? t("settings.connectors.loginClosingBrowser") + : loginResult + ? t("settings.connectors.loginCompleted") + : loginSessionId + ? t("settings.connectors.loginReady") + : t("settings.connectors.loginIdle"); const selectedLabel = selected ? resolveConnectorLabel(selected) : ""; const cookiesCount = selected?.cookiesCount ?? selected?.cookies?.length ?? 0; const cookiesList = selected?.cookies ?? []; const isConnected = (selected?.status ?? "disconnected") === "connected"; + const loginDomainsLabel = loginResult?.domains && loginResult.domains.length > 0 ? loginResult.domains.join(", ") : "-"; return (
@@ -385,6 +511,17 @@ export function ConnectorsSection() { +
+
+ {t("settings.connectors.detail.scope")} +
+
+ {selected.domains && selected.domains.length > 0 ? selected.domains.join(", ") : "-"} +
+
+ + +
{t("settings.connectors.detail.actions")} @@ -443,8 +580,12 @@ export function ConnectorsSection() { { - if (!isBusy) { - setLoginDialogOpen(open); + if (open) { + setLoginDialogOpen(true); + return; + } + if (!isLoginRunning) { + void handleCancelLogin(); } }} > @@ -454,9 +595,7 @@ export function ConnectorsSection() { {t("settings.connectors.loginTitle")} - {needsInstall - ? t("settings.connectors.installDescription") - : t("settings.connectors.loginDescription")} + {t("settings.connectors.loginDescription")}
@@ -464,6 +603,19 @@ export function ConnectorsSection() { {t("settings.connectors.loginTarget")}: {loginTarget ? resolveConnectorLabel(loginTarget) : "-"}
{dialogStatus}
+ {loginResult ? ( +
+
+ {t("settings.connectors.cookiesRead")}: {loginResult.rawCookiesCount} +
+
+ {t("settings.connectors.cookiesSaved")}: {loginResult.filteredCookiesCount} +
+
+ {t("settings.connectors.cookiesDomains")}: {loginDomainsLabel} +
+
+ ) : null} {loginError ? (
{loginError} @@ -474,13 +626,12 @@ export function ConnectorsSection() {
- {needsInstall ? ( - - ) : null} - + diff --git a/frontend/src/features/settings/debug/DebugSection.tsx b/frontend/src/features/settings/debug/DebugSection.tsx index 2861551..47a1435 100644 --- a/frontend/src/features/settings/debug/DebugSection.tsx +++ b/frontend/src/features/settings/debug/DebugSection.tsx @@ -25,6 +25,7 @@ import { useRealtimeStore, } from "@/shared/realtime"; import { messageBus } from "@/shared/message"; +import { useUpdateStore } from "@/shared/store/update"; import { ChannelsTab } from "./tabs/ChannelsTab"; import { CallRecordsTab } from "./tabs/CallRecordsTab"; import { EventsTab } from "./tabs/EventsTab"; @@ -270,6 +271,9 @@ export function DebugSection() { clearMessages: state.clearMessages, })) ); + const currentAppVersion = useUpdateStore((state) => state.info.currentVersion); + const latestAppVersion = useUpdateStore((state) => state.info.latestVersion); + const openWhatsNewPreview = useUpdateStore((state) => state.openWhatsNewPreview); const [selectedTopic, setSelectedTopic] = useState(DEFAULT_DEBUG_TOPICS[0]); const [selectedThreadId, setSelectedThreadId] = useState(""); const [selectedCallThreadId, setSelectedCallThreadId] = useState(""); @@ -492,6 +496,18 @@ export function DebugSection() { }); }, [t]); + const showWhatsNewPreview = useCallback(() => { + const version = currentAppVersion.trim() || latestAppVersion.trim() || "2.0.7"; + openWhatsNewPreview( + { + version, + currentVersion: version, + changelog: t("settings.debug.message.frontend.whatsNewMarkdown"), + }, + "settings" + ); + }, [currentAppVersion, latestAppVersion, openWhatsNewPreview, t]); + const statusLabel = useMemo(() => { if (status === "connected") { return t("settings.debug.message.realtime.status.connected"); @@ -1046,6 +1062,7 @@ export function DebugSection() { showToastPreview={showToastPreview} showNotificationPreview={showNotificationPreview} showDialogPreview={showDialogPreview} + showWhatsNewPreview={showWhatsNewPreview} sendOsNotification={() => { 
void sendOsNotification(); }} diff --git a/frontend/src/features/settings/debug/tabs/FrameworkTab.tsx b/frontend/src/features/settings/debug/tabs/FrameworkTab.tsx index 534f1a2..95897cb 100644 --- a/frontend/src/features/settings/debug/tabs/FrameworkTab.tsx +++ b/frontend/src/features/settings/debug/tabs/FrameworkTab.tsx @@ -18,6 +18,7 @@ export function FrameworkTab({ showToastPreview, showNotificationPreview, showDialogPreview, + showWhatsNewPreview, sendOsNotification, publishBackendDebug, }: FrameworkTabProps) { @@ -59,6 +60,10 @@ export function FrameworkTab({ {t("settings.debug.framework.actions.dialog")} +
diff --git a/frontend/src/features/settings/debug/tabs/PromptTab.tsx b/frontend/src/features/settings/debug/tabs/PromptTab.tsx index 59c4907..3fc24f5 100644 --- a/frontend/src/features/settings/debug/tabs/PromptTab.tsx +++ b/frontend/src/features/settings/debug/tabs/PromptTab.tsx @@ -31,6 +31,10 @@ export function PromptTab({ const toolItems = useMemo(() => selectedPromptRun?.payload.tools ?? [], [selectedPromptRun]); const skillItems = useMemo(() => selectedPromptRun?.payload.skills ?? [], [selectedPromptRun]); + const metaSectionCount = useMemo( + () => Number(toolItems.length > 0) + Number(skillItems.length > 0), + [skillItems.length, toolItems.length] + ); const promptMessages = useMemo(() => { const source = Array.isArray(selectedPromptRun?.payload.messages) ? selectedPromptRun?.payload.messages ?? [] : []; return source @@ -60,7 +64,10 @@ export function PromptTab({ } return ( -
+
{items.map((item) => (
{t("settings.debug.prompt.meta.emptyCollections")}
) : ( -
+
1 ? "repeat(auto-fit, minmax(min(100%, 22rem), 1fr))" : "minmax(0, 1fr)", + }} + > {renderMetaCollectionSection( t("settings.debug.prompt.tools"), toolItems, diff --git a/frontend/src/features/settings/debug/types.ts b/frontend/src/features/settings/debug/types.ts index 04a6a47..edce7b7 100644 --- a/frontend/src/features/settings/debug/types.ts +++ b/frontend/src/features/settings/debug/types.ts @@ -164,6 +164,7 @@ export type FrameworkTabProps = { showToastPreview: () => void; showNotificationPreview: () => void; showDialogPreview: () => void; + showWhatsNewPreview: () => void; sendOsNotification: () => void; publishBackendDebug: () => void; }; diff --git a/frontend/src/features/update/WhatsNewDialog.tsx b/frontend/src/features/update/WhatsNewDialog.tsx new file mode 100644 index 0000000..4d7afb6 --- /dev/null +++ b/frontend/src/features/update/WhatsNewDialog.tsx @@ -0,0 +1,233 @@ +import * as React from "react"; +import { CheckCircle2 } from "lucide-react"; + +import { useDismissWhatsNew, useWhatsNew } from "@/shared/query/update"; +import { useI18n } from "@/shared/i18n"; +import { useUpdateStore } from "@/shared/store/update"; +import { Button } from "@/shared/ui/button"; +import { Dialog, DialogContent, DialogDescription, DialogHeader, DialogTitle } from "@/shared/ui/dialog"; +import { ProductModeGlyph } from "@/shared/ui/product-mode-glyph"; +import { DialogMarkdown } from "@/shared/markdown/dialog-markdown"; +import { useSetupCenter } from "@/features/setup"; + +function useWindowVisible() { + const [visible, setVisible] = React.useState(() => + typeof document !== "undefined" ? 
document.visibilityState === "visible" : false + ); + + React.useEffect(() => { + const update = () => { + setVisible(document.visibilityState === "visible"); + }; + update(); + document.addEventListener("visibilitychange", update); + window.addEventListener("focus", update); + window.addEventListener("blur", update); + return () => { + document.removeEventListener("visibilitychange", update); + window.removeEventListener("focus", update); + window.removeEventListener("blur", update); + }; + }, []); + + return visible; +} + +function useBlockingDialogPresent() { + const [present, setPresent] = React.useState(false); + + React.useEffect(() => { + const resolve = () => { + const dialogs = Array.from(document.querySelectorAll("[role='dialog']")); + setPresent( + dialogs.some((node) => { + const element = node as HTMLElement; + return element.dataset.whatsNewDialog !== "true"; + }) + ); + }; + + resolve(); + const observer = new MutationObserver(() => resolve()); + observer.observe(document.body, { + childList: true, + subtree: true, + attributes: true, + attributeFilter: ["data-state", "open", "style", "hidden"], + }); + + return () => observer.disconnect(); + }, []); + + return present; +} + +export interface WhatsNewDialogProps { + activeWindow: "main" | "settings"; + autoOpen?: boolean; +} + +export function WhatsNewDialog({ + activeWindow, + autoOpen = activeWindow === "main", +}: WhatsNewDialogProps) { + const { t } = useI18n(); + const { data: backendNotice } = useWhatsNew(); + const dismissWhatsNew = useDismissWhatsNew(); + const { open: isSetupCenterOpen } = useSetupCenter(); + const whatsNewPreview = useUpdateStore((state) => state.whatsNewPreview); + const clearWhatsNewPreview = useUpdateStore((state) => state.clearWhatsNewPreview); + const isWindowVisible = useWindowVisible(); + const hasBlockingDialog = useBlockingDialogPresent(); + const [open, setOpen] = React.useState(false); + const [presentationReady, setPresentationReady] = React.useState(false); + 
const dismissedVersionRef = React.useRef(""); + const previewNotice = + whatsNewPreview?.targetWindow === activeWindow ? whatsNewPreview.notice : null; + const notice = previewNotice ?? backendNotice; + + React.useEffect(() => { + setOpen(false); + setPresentationReady(false); + if (!notice?.version) { + dismissedVersionRef.current = ""; + return; + } + const timer = window.setTimeout(() => { + setPresentationReady(true); + }, 900); + return () => window.clearTimeout(timer); + }, [notice?.version]); + + React.useEffect(() => { + if (!notice?.version) { + return; + } + if (previewNotice) { + setOpen(true); + return; + } + if (dismissedVersionRef.current === notice.version) { + return; + } + if ( + !autoOpen || + !presentationReady || + !isWindowVisible || + isSetupCenterOpen || + hasBlockingDialog || + open + ) { + return; + } + setOpen(true); + }, [ + activeWindow, + autoOpen, + hasBlockingDialog, + isSetupCenterOpen, + isWindowVisible, + notice?.version, + open, + presentationReady, + previewNotice, + ]); + + const handleOpenChange = React.useCallback( + (nextOpen: boolean) => { + setOpen(nextOpen); + if (nextOpen || !notice?.version) { + return; + } + if (previewNotice) { + clearWhatsNewPreview(); + return; + } + dismissedVersionRef.current = notice.version; + dismissWhatsNew.mutate(notice.version); + }, + [clearWhatsNewPreview, dismissWhatsNew, notice?.version, previewNotice] + ); + + const title = notice?.version + ? t("whatsNew.title").replace("{version}", notice.version) + : t("whatsNew.title").replace("{version}", ""); + const description = t("whatsNew.description").trim(); + + return ( + + +
+
+
+
+
+
+ +
+ +
+
+
+ + {t("whatsNew.eyebrow")} +
+
+ + {title} + + {description ? ( + + {description} + + ) : null} +
+
+ +
+
+ +
+
+
+ {t("whatsNew.currentVersion")} +
+
+ {notice?.version} +
+
+
+
+
+ +
+ {notice?.changelog?.trim() ? ( + + ) : ( +
+

{t("whatsNew.emptyState")}

+

+ {t("whatsNew.versionLabel").replace("{version}", notice?.version ?? "")} +

+
+ )} +
+ +
+ +
+
+
+ +
+ ); +} diff --git a/frontend/src/shared/contracts/connectors.ts b/frontend/src/shared/contracts/connectors.ts index fcf559f..760ca6d 100644 --- a/frontend/src/shared/contracts/connectors.ts +++ b/frontend/src/shared/contracts/connectors.ts @@ -19,6 +19,9 @@ export interface Connector { status: ConnectorStatus | string; cookiesCount?: number; cookies?: ConnectorCookie[]; + domains?: string[]; + policyKey?: string; + capabilities?: string[]; lastVerifiedAt?: string; } @@ -33,10 +36,51 @@ export interface ClearConnectorRequest { id: string; } -export interface ConnectConnectorRequest { +export interface StartConnectorConnectRequest { id: string; } +export interface StartConnectorConnectResult { + sessionId: string; + connector: Connector; +} + +export interface FinishConnectorConnectRequest { + sessionId: string; +} + +export interface FinishConnectorConnectResult { + sessionId: string; + saved: boolean; + rawCookiesCount: number; + filteredCookiesCount: number; + domains?: string[]; + reason?: string; + connector: Connector; +} + +export interface CancelConnectorConnectRequest { + sessionId: string; +} + +export interface ConnectorConnectSession { + sessionId: string; + connectorId: string; + state: string; + saved: boolean; + rawCookiesCount: number; + filteredCookiesCount: number; + domains?: string[]; + reason?: string; + error?: string; + lastCookiesAt?: string; + connector: Connector; +} + +export interface GetConnectorConnectSessionRequest { + sessionId: string; +} + export interface OpenConnectorSiteRequest { id: string; } diff --git a/frontend/src/shared/i18n/locales/en.json b/frontend/src/shared/i18n/locales/en.json index ba82402..38c1ecf 100644 --- a/frontend/src/shared/i18n/locales/en.json +++ b/frontend/src/shared/i18n/locales/en.json @@ -932,11 +932,11 @@ "groups": { "read": { "label": "Read", - "description": "Skills.Status/skills.Bins/skill_manage.Search/skill_manage.List" + "description": "Read-only skills status, bins, and package discovery 
actions." }, "package_write": { "label": "Package write", - "description": "Skill_manage.Install/update/remove/sync" + "description": "Skill package install, update, remove, and sync actions." }, "deps_write": { "label": "Dependencies write", @@ -1046,7 +1046,22 @@ "updateFailed": "Failed to update tool status" }, "reason": { - "unavailable": "Required runtime dependencies are unavailable." + "unavailable": "Required runtime dependencies are unavailable.", + "browserProcessExited": "Browser process exited.", + "remoteNodeRuntimeUnavailable": "Temporarily unavailable: remote node runtime is not implemented yet.", + "imageModelNotConfigured": "Image model is not configured.", + "providerRepositoriesUnavailable": "Provider repositories are unavailable.", + "gatewayControlPlaneDisabled": "Gateway control plane is disabled.", + "webSearchModeUnsupported": "Search mode is not supported.", + "webSearchProviderUnsupported": "{provider} is not supported in API mode.", + "webSearchProviderApiKeyMissing": "{provider} API key is missing.", + "webSearchExternalToolsUnavailable": "External tools mode is not implemented yet.", + "voiceDisabled": "Gateway voice feature is disabled.", + "voiceServiceUnavailable": "Voice service is unavailable.", + "ttsProviderApiKeyMissing": "TTS provider API key is missing.", + "ttsVoiceIdMissing": "TTS voice ID is not configured.", + "ttsEdgeProviderUnavailable": "Temporarily unavailable: Edge-TTS provider is not implemented yet.", + "ttsProviderUnsupported": "TTS provider is not supported." }, "builtin": { "read": { @@ -1083,11 +1098,11 @@ }, "browser": { "name": "Browser control", - "description": "Control browser page loading with Playwright." + "description": "Control browser page loading through a local CDP browser." }, "canvas": { "name": "Canvas rendering", - "description": "Control node canvases (present/hide/navigate/eval/snapshot/a2ui)." + "description": "Control node canvases (present/hide/navigate/eval/snapshot/a2ui). 
Temporarily unavailable until remote node runtime support is implemented." }, "image": { "name": "Image processing", @@ -1095,7 +1110,7 @@ }, "message": { "name": "Message sending", - "description": "Send a channel message (not implemented)." + "description": "Send, delete, and manage messages through configured channel plugins." }, "gateway": { "name": "Gateway control", @@ -1107,27 +1122,27 @@ }, "agents_list": { "name": "Agents list", - "description": "List agents." + "description": "List available agent profiles for subagent spawning." }, "sessions_list": { "name": "Sessions list", - "description": "List sessions." + "description": "List current sessions." }, "sessions_history": { "name": "Session history", - "description": "Fetch session history." + "description": "Read message history for a session." }, "sessions_send": { "name": "Session send", - "description": "Send a message to a session." + "description": "Append a message to an existing session." }, "sessions_spawn": { "name": "Session spawn", - "description": "Spawn a new session." + "description": "Spawn an isolated subagent run." }, "session_status": { "name": "Session status", - "description": "Get session status." + "description": "Get session metadata and status." }, "external_tools_query": { "name": "External tools query", @@ -1139,7 +1154,7 @@ }, "skills": { "name": "Skills", - "description": "Skills runtime tool (status/bins/install/update)." + "description": "Inspect skills runtime status and update per-skill runtime dependencies or configuration." }, "skill_manage": { "name": "Skill manage", @@ -1151,7 +1166,7 @@ }, "skills_manage": { "name": "Skills manage", - "description": "Manage skills lifecycle actions." + "description": "Search, install, update, remove, and sync skill packages via ClawHub." }, "skills_policy": { "name": "Skills policy", @@ -1159,15 +1174,15 @@ }, "subagents": { "name": "Subagents", - "description": "Manage subagent runs (list/info/log/kill/steer/spawn)." 
+ "description": "Manage existing subagent runs (list/info/log/kill/steer/send)." }, "nodes": { "name": "Nodes", - "description": "Invoke node capability." + "description": "Experimental low-level RPC to a registered node. Temporarily unavailable until remote node runtime support is implemented." }, "tts": { "name": "Text to speech", - "description": "Convert text to speech." + "description": "Synthesize speech audio from text with the configured voice provider." }, "render_chart": { "name": "Render chart", @@ -1221,6 +1236,14 @@ "name": "Memory list", "description": "List recent memory entries by filter." }, + "memory_query": { + "name": "Memory query", + "description": "Query long-term memory with recall, list, and stats actions." + }, + "memory_manage": { + "name": "Memory manage", + "description": "Create, update, or delete long-term memory entries." + }, "subagent_run": { "name": "Subagent run", "description": "Run a sub-agent task in an isolated workspace." @@ -1260,55 +1283,68 @@ "externalToolsHint": "External tools mode is reserved. No options are available yet." }, "webFetch": { - "type": "Mode", - "typeDesc": "Switch between Playwright headless browser mode and builtin HTTP fetch mode.", - "typeValue": { - "playwright": "Playwright", - "builtin": "Builtin" - }, - "playwrightMarkdown": "Convert to Markdown", - "playwrightMarkdownDesc": "Convert rendered HTML output into Markdown.", - "playwrightHint": "Playwright mode runs in headless browser and only reads cookies from connectors. 
It will not persist new cookies.", - "acceptMarkdown": "Prefer Markdown", - "acceptMarkdownDesc": "When enabled, requests prefer text/Markdown with HTML fallback.", - "enableUserAgent": "Send User-Agent", - "enableUserAgentDesc": "When disabled, no explicit User-Agent will be sent.", - "userAgent": "User-Agent", - "acceptLanguage": "Accept-Language", + "preferredBrowser": "Preferred browser", + "preferredBrowserDesc": "Show only browsers detected on this machine.", + "headless": "Headless", + "headlessDesc": "Run web fetch in headless mode.", "timeoutSeconds": "Timeout (sec)", - "maxChars": "Max chars", - "maxRedirects": "Max redirects", - "maxRedirectsDesc": "Set 0 to disable redirect following.", - "retryMax": "Max retries", - "retryMaxDesc": "Applies only to safely retryable methods (E.G. GET/HEAD/OPTIONS).", - "headers": "Extra headers (JSON)", - "headersDesc": "Provide a JSON object, E.G. {\"X-Test\":\"1\"}. It merges with tool-call headers.", - "headersInvalid": "Invalid headers format", - "headersInvalidDesc": "Use a JSON object, E.G. {\"Accept-Language\":\"En-US\"}." 
+ "maxChars": "Max chars" }, "browserControl": { "enabled": "Enabled", - "evaluateEnabled": "Allow evaluate", - "evaluateEnabledDesc": "Disable browser evaluate actions.", "ssrfSection": "SSRF Security", "ssrfSectionDesc": "Private network and hostname allowlist options.", - "executablePath": "Playwright chromium path", - "executablePathDesc": "Managed by Playwright and read-only.", - "executablePathPending": "Path will appear after Playwright runtime check.", + "preferredBrowser": "Preferred browser", + "preferredBrowserDesc": "Show only browsers detected on this machine.", "headless": "Headless", "headlessDesc": "Start browser in headless mode.", - "noSandbox": "No sandbox", - "noSandboxDesc": "Pass --no-sandbox to chrome.", - "snapshotDefaultMode": "Snapshot default mode", - "snapshotDefaultModeAuto": "Auto", "ssrfDangerouslyAllowPrivateNetwork": "Allow private network", "ssrfDangerouslyAllowPrivateNetworkDesc": "Allow browser navigation to private/internal networks.", "ssrfAllowedHostnames": "Allowed hostnames (JSON)", "ssrfHostnameAllowlist": "Hostname allowlist (JSON)", - "extraArgs": "Extra args (JSON)", "arrayInvalid": "Array format is invalid", "arrayInvalidDesc": "Use a JSON string array, for example [\"localhost\",\"metadata.Internal\"]." 
}, + "runtimeDetection": { + "detected": "Detected", + "notDetected": "Not detected", + "notInstalled": "Not installed", + "noneDetected": "No supported browser detected" + }, + "requirements": { + "gatewayControlPlane": "Gateway control plane", + "localCDPBrowser": "Local CDP browser", + "remoteNodeRuntime": "Remote node runtime", + "webSearchMode": "Search mode", + "webSearchProvider": "Provider", + "webSearchProviderApiKey": "Provider API key", + "externalToolsRuntime": "External tools runtime", + "imageModel": "Image model", + "voiceService": "Voice service", + "voiceFeature": "Voice feature", + "ttsProvider": "Provider", + "ttsProviderApiKey": "Provider API key", + "ttsVoiceId": "Voice ID", + "values": { + "available": "Available", + "unavailable": "Unavailable", + "configured": "Configured", + "notConfigured": "Not configured", + "enabled": "Enabled", + "disabled": "Disabled", + "detected": "Detected", + "missing": "Missing" + }, + "providers": { + "brave": "Brave", + "tavily": "Tavily", + "perplexity": "Perplexity", + "grok": "Grok", + "openai": "OpenAI", + "elevenlabs": "ElevenLabs", + "edge": "Edge-TTS" + } + }, "detail": { "descriptionLabel": "Description", "enabled": "Enabled", @@ -1351,31 +1387,46 @@ "searchPlaceholder": "Search connectors", "searchEmpty": "No connectors match your search.", "group.searchEngine": "Search engines", + "group.community": "Communities", "group.video": "Video", + "group.developer": "Developer", "group.other": "Other", "item.google": "Google", + "item.github": "GitHub", + "item.reddit": "Reddit", + "item.zhihu": "Zhihu", + "item.x": "X", "item.xiaohongshu": "Xiaohongshu", "item.bilibili": "Bilibili", "connect": "Connect", "reconnect": "Reconnect", "clear": "Clear", "noCookies": "No cookies configured", - "loginHint": "A browser window will open. Sign in, then close the window to save cookies.", + "noCookiesRead": "No connector cookies were read yet. 
Stay signed in and click finish again.", + "loginHint": "A browser window will open with a temporary profile. Sign in there, then click finish to read and save cookies.", "loginTitle": "Browser login", - "loginDescription": "Finish login in the browser window, then close it.", - "installDescription": "Playwright is required to open the browser.", + "loginDescription": "Finish login in the browser window, then click finish in this dialog.", "loginTarget": "Connector", - "loginRunning": "Waiting for browser to close...", - "loginDone": "Login completed.", + "loginLaunching": "Launching browser...", + "loginReady": "Browser opened. Complete sign-in, then click finish.", + "loginReadingCookies": "Reading cookies from the browser session...", + "loginClosingBrowser": "Closing the browser session...", + "loginCompleted": "Connection processing completed.", + "loginIdle": "Waiting to start login.", + "loginFinish": "Finish", "loginError": "Login failed", - "playwrightMissing": "Playwright is not installed. Install it to continue.", - "openExternalTools": "Open external tools", - "installRequiredStatus": "Install required to continue.", + "loginSessionMissing": "The login session was not found. Start the connection again.", + "browserSessionEnded": "The browser session has already ended. 
Start the connection again.", + "browserMissing": "No supported local browser was detected.", + "cookiesRead": "Cookies read", + "cookiesSaved": "Cookies kept", + "cookiesDomains": "Matched domains", "status.connected": "Connected", "status.expired": "Expired", "status.disconnected": "Disconnected", "detail.status": "Status", "detail.data": "Data", + "detail.scope": "Cookie scope", "detail.actions": "Actions", "viewCookies": "View", "openSite": "Open site", @@ -1720,8 +1771,9 @@ "current": "Current version", "changelog": "Release notes", "check": "Check for updates", - "install": "Install updates", - "restart": "Restart to apply", + "recheck": "Check again", + "downloadAndInstall": "Download and install", + "restartAfterUpdate": "Restart to update", "command": "Update actions", "status": "Status", "downloading": "Downloading", @@ -3524,6 +3576,14 @@ "notAvailable": "Not available", "justNow": "Just now" }, + "whatsNew": { + "eyebrow": "Release notes", + "currentVersion": "Current version", + "title": "What's new in {version}", + "description": "", + "emptyState": "This version has already been applied, but no release notes are available yet.", + "versionLabel": "Current version: {version}" + }, "gateway": { "page": { "subtitle": "Control plane overview and diagnostics" @@ -3804,6 +3864,7 @@ "notification": "Notification", "sendOS": "Send OS message", "dialog": "Dialog", + "whatsNew": "What's new", "publish": "Publish debug event" } }, @@ -3841,7 +3902,8 @@ "action": "Acknowledge", "dialogTitle": "Danger Dialog", "dialogDesc": "Use this to verify modal styles.", - "dialogConfirm": "Confirm" + "dialogConfirm": "Confirm", + "whatsNewMarkdown": "## Debug preview\n\n### layout verification\n- this line checks the title-to-body spacing.\n- this line checks paragraph rhythm in markdown.\n- this line checks the rounded container edge.\n- this line checks long-content readability.\n- this line checks the scroll threshold.\n- this line checks whether the dialog height stays 
stable.\n\n### interaction verification\n- this line checks scrolling after the content grows.\n- this line checks that the footer stays visible.\n- this line checks that the close button does not move.\n- this line checks markdown list rendering.\n- this line checks the spacing near the bottom.\n- this line confirms this preview is debug-only." }, "realtime": { "websocket": "WebSocket", @@ -4041,6 +4103,7 @@ "productMode": "Product mode", "notifications": "Notifications", "appUpdate": "App update", + "restartAndUpdate": "Restart and update", "externalToolsUpdate": "Tool updates", "settings": "Settings", "open": "Open profile menu" diff --git a/frontend/src/shared/i18n/locales/zh-CN.json b/frontend/src/shared/i18n/locales/zh-CN.json index e56fd66..7e9f8a5 100644 --- a/frontend/src/shared/i18n/locales/zh-CN.json +++ b/frontend/src/shared/i18n/locales/zh-CN.json @@ -932,11 +932,11 @@ "groups": { "read": { "label": "读取", - "description": "skills.status/skills.bins/skill_manage.search/skill_manage.list" + "description": "skills.status/skills.bins/skills_manage.search/skills_manage.list" }, "package_write": { "label": "包管理写入", - "description": "skill_manage.install/update/remove/sync" + "description": "skills_manage.install/update/remove/sync" }, "deps_write": { "label": "依赖写入", @@ -1046,7 +1046,22 @@ "updateFailed": "更新工具状态失败" }, "reason": { - "unavailable": "缺少运行时依赖,当前不可用。" + "unavailable": "缺少运行时依赖,当前不可用。", + "browserProcessExited": "浏览器进程已退出。", + "remoteNodeRuntimeUnavailable": "暂时不可用:远端 Node 运行时还未实现。", + "imageModelNotConfigured": "图像模型尚未配置。", + "providerRepositoriesUnavailable": "Provider 仓库当前不可用。", + "gatewayControlPlaneDisabled": "Gateway 控制平面已关闭。", + "webSearchModeUnsupported": "当前搜索模式不受支持。", + "webSearchProviderUnsupported": "{provider} 当前不支持 API 模式。", + "webSearchProviderApiKeyMissing": "{provider} 的 API key 缺失。", + "webSearchExternalToolsUnavailable": "External tools 模式暂未实现。", + "voiceDisabled": "Gateway 语音功能已关闭。", + "voiceServiceUnavailable": "语音服务当前不可用。", 
+ "ttsProviderApiKeyMissing": "TTS provider 的 API key 缺失。", + "ttsVoiceIdMissing": "TTS 的 voice ID 尚未配置。", + "ttsEdgeProviderUnavailable": "暂时不可用:Edge-TTS provider 还未实现。", + "ttsProviderUnsupported": "当前 TTS provider 不受支持。" }, "builtin": { "read": { @@ -1083,11 +1098,11 @@ }, "browser": { "name": "浏览器控制", - "description": "通过 Playwright 控制页面加载。" + "description": "通过本地 CDP 浏览器控制页面加载。" }, "canvas": { "name": "画布渲染", - "description": "控制节点画布能力(present/hide/navigate/eval/snapshot/a2ui)。" + "description": "控制节点画布能力(present/hide/navigate/eval/snapshot/a2ui)。远端 Node 运行时落地前暂不可用。" }, "image": { "name": "图像处理", @@ -1095,7 +1110,7 @@ }, "message": { "name": "消息发送", - "description": "发送频道消息(当前未实现)。" + "description": "通过已配置的频道插件发送、删除和管理消息。" }, "gateway": { "name": "Gateway 控制", @@ -1107,27 +1122,27 @@ }, "agents_list": { "name": "代理列表", - "description": "列出代理。" + "description": "列出可用于子代理启动的代理配置。" }, "sessions_list": { "name": "会话列表", - "description": "列出会话。" + "description": "列出当前会话。" }, "sessions_history": { "name": "会话历史", - "description": "获取会话历史。" + "description": "读取某个会话的消息历史。" }, "sessions_send": { "name": "会话发送", - "description": "向会话发送消息。" + "description": "向已有会话追加一条消息。" }, "sessions_spawn": { "name": "新建会话", - "description": "创建新的会话。" + "description": "启动一次隔离的子代理运行。" }, "session_status": { "name": "会话状态", - "description": "获取会话状态。" + "description": "获取会话元数据与状态。" }, "external_tools_query": { "name": "外部工具查询", @@ -1139,7 +1154,7 @@ }, "skills": { "name": "技能", - "description": "技能运行工具(状态/命令依赖/依赖安装/配置更新)。" + "description": "查看技能运行状态,并更新单个技能的运行依赖或配置。" }, "skill_manage": { "name": "技能管理", @@ -1151,7 +1166,7 @@ }, "skills_manage": { "name": "技能管理", - "description": "管理技能生命周期操作。" + "description": "通过 ClawHub 搜索、安装、更新、移除和同步技能包。" }, "skills_policy": { "name": "技能策略", @@ -1159,15 +1174,15 @@ }, "subagents": { "name": "子代理管理", - "description": "管理子代理运行(列表/详情/日志/终止/转向/启动)。" + "description": "管理已有子代理运行(列表/详情/日志/终止/转向/发送)。" }, "nodes": { "name": "节点调用", - "description": "调用节点能力。" + 
"description": "实验性的低层节点 RPC。远端 Node 运行时落地前暂不可用。" }, "tts": { "name": "文本转语音", - "description": "将文本转换为语音。" + "description": "使用当前语音提供方把文本合成为语音音频。" }, "render_chart": { "name": "图表渲染", @@ -1221,6 +1236,14 @@ "name": "记忆列表", "description": "按条件列出近期记忆条目。" }, + "memory_query": { + "name": "记忆查询", + "description": "通过 recall、list、stats 动作查询长期记忆。" + }, + "memory_manage": { + "name": "记忆管理", + "description": "创建、更新或删除长期记忆条目。" + }, "subagent_run": { "name": "子代理运行", "description": "在隔离工作区运行子代理任务。" @@ -1260,55 +1283,68 @@ "externalToolsHint": "外部工具模式仍为预留能力,暂时没有可配置项。" }, "webFetch": { - "type": "模式", - "typeDesc": "在 Playwright 无头浏览器与内置 HTTP 抓取之间切换。", - "typeValue": { - "playwright": "Playwright", - "builtin": "内置" - }, - "playwrightMarkdown": "转换为 Markdown", - "playwrightMarkdownDesc": "开启后会将渲染后的 HTML 内容转换为 Markdown。", - "playwrightHint": "Playwright 模式会以无头浏览器访问页面,并只读取 Connectors 中已有 Cookies,不会写回或新增 Cookies。", - "acceptMarkdown": "优先 Markdown", - "acceptMarkdownDesc": "开启后会优先请求 text/markdown,并在服务端回退到 HTML。", - "enableUserAgent": "发送 User-Agent", - "enableUserAgentDesc": "关闭后将不主动发送 User-Agent。", - "userAgent": "User-Agent", - "acceptLanguage": "Accept-Language", + "preferredBrowser": "默认浏览器", + "preferredBrowserDesc": "仅展示本机已检测到的浏览器。", + "headless": "无头模式", + "headlessDesc": "以无头模式进行网页抓取。", "timeoutSeconds": "超时(秒)", - "maxChars": "最大字符数", - "maxRedirects": "最大重定向次数", - "maxRedirectsDesc": "0 表示不跟随重定向。", - "retryMax": "最大重试次数", - "retryMaxDesc": "仅对可安全重试的方法生效(如 GET/HEAD/OPTIONS)。", - "headers": "附加 Headers(JSON)", - "headersDesc": "填写 JSON 对象,例如 {\"X-Test\":\"1\"}。会与工具调用参数 headers 合并。", - "headersInvalid": "Headers 格式不正确", - "headersInvalidDesc": "请填写 JSON 对象,例如 {\"Accept-Language\":\"en-US\"}。" + "maxChars": "最大字符数" }, "browserControl": { "enabled": "启用", - "evaluateEnabled": "允许 Evaluate", - "evaluateEnabledDesc": "关闭后禁用 browser evaluate 动作。", "ssrfSection": "SSRF 安全", "ssrfSectionDesc": "私有网络与主机名白名单配置。", - "executablePath": "Playwright Chromium 路径", - 
"executablePathDesc": "由 Playwright 管理,只读不可修改。", - "executablePathPending": "等待 Playwright 运行时检查后显示路径。", + "preferredBrowser": "默认浏览器", + "preferredBrowserDesc": "仅展示本机已检测到的浏览器。", "headless": "无头模式", "headlessDesc": "以无头模式启动浏览器。", - "noSandbox": "关闭沙箱", - "noSandboxDesc": "启动 Chrome 时添加 --no-sandbox。", - "snapshotDefaultMode": "快照默认模式", - "snapshotDefaultModeAuto": "自动", "ssrfDangerouslyAllowPrivateNetwork": "允许私有网络", "ssrfDangerouslyAllowPrivateNetworkDesc": "允许浏览器访问私有/内网地址。", "ssrfAllowedHostnames": "允许的主机名(JSON)", "ssrfHostnameAllowlist": "主机名白名单(JSON)", - "extraArgs": "额外参数(JSON)", "arrayInvalid": "数组格式不正确", "arrayInvalidDesc": "请填写 JSON 字符串数组,例如 [\"localhost\",\"metadata.internal\"]。" }, + "runtimeDetection": { + "detected": "已检测到", + "notDetected": "未检测到", + "notInstalled": "未安装", + "noneDetected": "未检测到可用浏览器" + }, + "requirements": { + "gatewayControlPlane": "Gateway 控制平面", + "localCDPBrowser": "本地 CDP 浏览器", + "remoteNodeRuntime": "远端 Node 运行时", + "webSearchMode": "搜索模式", + "webSearchProvider": "提供方", + "webSearchProviderApiKey": "提供方 API Key", + "externalToolsRuntime": "外部工具运行时", + "imageModel": "图像模型", + "voiceService": "语音服务", + "voiceFeature": "语音功能", + "ttsProvider": "提供方", + "ttsProviderApiKey": "提供方 API Key", + "ttsVoiceId": "语音 ID", + "values": { + "available": "可用", + "unavailable": "不可用", + "configured": "已配置", + "notConfigured": "未配置", + "enabled": "已启用", + "disabled": "已关闭", + "detected": "已检测到", + "missing": "缺失" + }, + "providers": { + "brave": "Brave", + "tavily": "Tavily", + "perplexity": "Perplexity", + "grok": "Grok", + "openai": "OpenAI", + "elevenlabs": "ElevenLabs", + "edge": "Edge-TTS" + } + }, "detail": { "descriptionLabel": "说明", "enabled": "启用", @@ -1351,31 +1387,46 @@ "searchPlaceholder": "搜索连接", "searchEmpty": "没有匹配的连接。", "group.searchEngine": "搜索引擎", + "group.community": "社区", "group.video": "视频", + "group.developer": "开发者", "group.other": "其他", "item.google": "Google", + "item.github": "GitHub", + "item.reddit": "Reddit", + 
"item.zhihu": "知乎", + "item.x": "X", "item.xiaohongshu": "小红书", "item.bilibili": "Bilibili", "connect": "连接", "reconnect": "重新连接", "clear": "清除", "noCookies": "未配置 cookies", - "loginHint": "将打开浏览器窗口。完成登录后关闭窗口即可保存 cookies。", + "noCookiesRead": "还没有读取到当前连接的 cookies。请保持登录状态后再次点击完成连接。", + "loginHint": "将以临时 profile 打开浏览器窗口。请先在浏览器里完成登录,然后回到这里点击完成连接。", "loginTitle": "浏览器登录", - "loginDescription": "请在浏览器窗口完成登录,然后关闭窗口。", - "installDescription": "需要安装 Playwright 才能打开浏览器。", + "loginDescription": "请在浏览器窗口完成登录,然后回到此对话框点击完成连接。", "loginTarget": "连接", - "loginRunning": "等待浏览器关闭中...", - "loginDone": "登录完成。", + "loginLaunching": "正在启动浏览器...", + "loginReady": "浏览器已打开。完成登录后点击完成连接。", + "loginReadingCookies": "正在从浏览器会话读取 cookies...", + "loginClosingBrowser": "正在关闭浏览器会话...", + "loginCompleted": "连接处理已完成。", + "loginIdle": "等待开始连接。", + "loginFinish": "完成连接", "loginError": "登录失败", - "playwrightMissing": "未安装 Playwright,无法打开浏览器。", - "openExternalTools": "打开外部工具", - "installRequiredStatus": "需要安装 Playwright 才能继续。", + "loginSessionMissing": "未找到当前连接会话,请重新开始连接。", + "browserSessionEnded": "浏览器会话已结束,请重新开始连接。", + "browserMissing": "未检测到可用的本机浏览器。", + "cookiesRead": "读取到的 cookies", + "cookiesSaved": "保留的 cookies", + "cookiesDomains": "命中的域名", "status.connected": "已连接", "status.expired": "已过期", "status.disconnected": "未连接", "detail.status": "状态", "detail.data": "数据", + "detail.scope": "Cookie 范围", "detail.actions": "操作", "viewCookies": "查看", "openSite": "打开站点", @@ -1720,8 +1771,9 @@ "current": "当前版本", "changelog": "更新日志", "check": "检查更新", - "install": "安装更新", - "restart": "重启生效", + "recheck": "重新检查", + "downloadAndInstall": "下载并安装", + "restartAfterUpdate": "重启后更新", "command": "更新操作", "status": "状态", "downloading": "下载中", @@ -2390,6 +2442,7 @@ "notification": "通知", "sendOS": "系统消息", "dialog": "对话框", + "whatsNew": "What's New", "publish": "调试事件" } }, @@ -2427,7 +2480,8 @@ "action": "知道了", "dialogTitle": "危险弹窗", "dialogDesc": "用于确认样式的弹窗示例。", - "dialogConfirm": "确认" + "dialogConfirm": "确认", + 
"whatsNewMarkdown": "## 调试预览\n\n### 布局验证\n- 这一行用于检查标题与正文之间的留白。\n- 这一行用于检查 Markdown 段落节奏。\n- 这一行用于检查内容卡片的圆角边界。\n- 这一行用于检查长内容时的可读性。\n- 这一行用于检查滚动阈值是否生效。\n- 这一行用于检查 dialog 高度是否保持稳定。\n\n### 交互验证\n- 这一行用于检查内容超过一屏后的滚动行为。\n- 这一行用于检查底部按钮区域是否始终可见。\n- 这一行用于检查关闭按钮位置不会被挤压。\n- 这一行用于检查 Markdown 列表渲染效果。\n- 这一行用于检查底部间距是否自然。\n- 这一行确认这是仅供调试使用的预览内容。" }, "realtime": { "websocket": "WebSocket", @@ -4001,6 +4055,14 @@ "notAvailable": "不可用", "justNow": "刚刚" }, + "whatsNew": { + "eyebrow": "版本已更新", + "currentVersion": "当前版本", + "title": "{version} 有这些新内容", + "description": "", + "emptyState": "当前版本已经更新完成,但暂时没有可展示的更新日志。", + "versionLabel": "当前版本:{version}" + }, "gateway": { "page": { "subtitle": "控制平面概览与诊断信息" @@ -4053,6 +4115,7 @@ "productMode": "产品形态", "notifications": "通知中心", "appUpdate": "应用升级", + "restartAndUpdate": "重启后更新", "externalToolsUpdate": "工具升级", "settings": "设置", "open": "打开个人菜单" diff --git a/frontend/src/shared/query/connectors.ts b/frontend/src/shared/query/connectors.ts index 169aa75..641929c 100644 --- a/frontend/src/shared/query/connectors.ts +++ b/frontend/src/shared/query/connectors.ts @@ -1,29 +1,44 @@ import { useMutation, useQuery, useQueryClient } from "@tanstack/react-query"; import type { + CancelConnectorConnectRequest, ClearConnectorRequest, - ConnectConnectorRequest, + ConnectorConnectSession, Connector, + FinishConnectorConnectRequest, + FinishConnectorConnectResult, + GetConnectorConnectSessionRequest, OpenConnectorSiteRequest, + StartConnectorConnectRequest, + StartConnectorConnectResult, UpsertConnectorRequest, } from "@/shared/contracts/connectors"; import { + CancelConnectorConnect as CancelConnectorConnectBinding, ClearConnector as ClearConnectorBinding, - ConnectConnector as ConnectConnectorBinding, - InstallPlaywright, + FinishConnectorConnect as FinishConnectorConnectBinding, + GetConnectorConnectSession as GetConnectorConnectSessionBinding, ListConnectors, OpenConnectorSite as OpenConnectorSiteBinding, + StartConnectorConnect as 
StartConnectorConnectBinding, UpsertConnector as UpsertConnectorBinding, } from "../../../bindings/dreamcreator/internal/presentation/wails/connectorshandler"; import { + CancelConnectorConnectRequest as BindingsCancelConnectorConnectRequest, ClearConnectorRequest as BindingsClearConnectorRequest, - ConnectConnectorRequest as BindingsConnectConnectorRequest, + ConnectorConnectSession as BindingsConnectorConnectSession, Connector as BindingsConnector, + FinishConnectorConnectRequest as BindingsFinishConnectorConnectRequest, + FinishConnectorConnectResult as BindingsFinishConnectorConnectResult, + GetConnectorConnectSessionRequest as BindingsGetConnectorConnectSessionRequest, OpenConnectorSiteRequest as BindingsOpenConnectorSiteRequest, + StartConnectorConnectRequest as BindingsStartConnectorConnectRequest, + StartConnectorConnectResult as BindingsStartConnectorConnectResult, UpsertConnectorRequest as BindingsUpsertConnectorRequest, } from "../../../bindings/dreamcreator/internal/application/connectors/dto/models"; export const CONNECTORS_QUERY_KEY = ["connectors"]; +export const CONNECTOR_CONNECT_SESSION_QUERY_KEY = ["connector-connect-session"]; export function useConnectors() { return useQuery({ @@ -59,11 +74,23 @@ export function useClearConnector() { }); } -export function useConnectConnector() { +export function useStartConnectorConnect() { + return useMutation({ + mutationFn: async (request: StartConnectorConnectRequest): Promise => { + return toStartConnectorConnectResult( + await StartConnectorConnectBinding(BindingsStartConnectorConnectRequest.createFrom(request)) + ); + }, + }); +} + +export function useFinishConnectorConnect() { const queryClient = useQueryClient(); return useMutation({ - mutationFn: async (request: ConnectConnectorRequest): Promise => { - return toConnector(await ConnectConnectorBinding(BindingsConnectConnectorRequest.createFrom(request))); + mutationFn: async (request: FinishConnectorConnectRequest): Promise => { + return 
toFinishConnectorConnectResult( + await FinishConnectorConnectBinding(BindingsFinishConnectorConnectRequest.createFrom(request)) + ); }, onSuccess: () => { queryClient.invalidateQueries({ queryKey: CONNECTORS_QUERY_KEY }); @@ -71,6 +98,14 @@ export function useConnectConnector() { }); } +export function useCancelConnectorConnect() { + return useMutation({ + mutationFn: async (request: CancelConnectorConnectRequest): Promise => { + await CancelConnectorConnectBinding(BindingsCancelConnectorConnectRequest.createFrom(request)); + }, + }); +} + export function useOpenConnectorSite() { return useMutation({ mutationFn: async (request: OpenConnectorSiteRequest): Promise => { @@ -79,11 +114,17 @@ export function useOpenConnectorSite() { }); } -export function useInstallPlaywright() { - return useMutation({ - mutationFn: async (): Promise => { - await InstallPlaywright(); +export function useConnectorConnectSession(request: GetConnectorConnectSessionRequest, enabled: boolean) { + return useQuery({ + queryKey: [...CONNECTOR_CONNECT_SESSION_QUERY_KEY, request.sessionId], + enabled: enabled && request.sessionId.trim().length > 0, + queryFn: async (): Promise => { + return toConnectorConnectSession( + await GetConnectorConnectSessionBinding(BindingsGetConnectorConnectSessionRequest.createFrom(request)) + ); }, + refetchInterval: 1000, + staleTime: 0, }); } @@ -93,3 +134,24 @@ function toConnector(raw: BindingsConnector): Connector { cookies: raw.cookies.map((item) => ({ ...item })), }; } + +function toStartConnectorConnectResult(raw: BindingsStartConnectorConnectResult): StartConnectorConnectResult { + return { + ...raw, + connector: toConnector(raw.connector), + }; +} + +function toFinishConnectorConnectResult(raw: BindingsFinishConnectorConnectResult): FinishConnectorConnectResult { + return { + ...raw, + connector: toConnector(raw.connector), + }; +} + +function toConnectorConnectSession(raw: BindingsConnectorConnectSession): ConnectorConnectSession { + return { + ...raw, + 
connector: toConnector(raw.connector), + }; +} diff --git a/frontend/src/shared/query/update.ts b/frontend/src/shared/query/update.ts index 9ef287c..02279b2 100644 --- a/frontend/src/shared/query/update.ts +++ b/frontend/src/shared/query/update.ts @@ -1,9 +1,15 @@ import { useMutation, useQuery, useQueryClient } from "@tanstack/react-query"; import { Call } from "@wailsio/runtime"; -import { normalizeUpdateInfo, type UpdateInfo } from "@/shared/store/update"; +import { + normalizeUpdateInfo, + normalizeWhatsNewInfo, + type UpdateInfo, + type WhatsNewInfo, +} from "@/shared/store/update"; const UPDATE_QUERY_KEY = ["update-state"]; +const WHATS_NEW_QUERY_KEY = ["whats-new"]; export function useUpdateState() { return useQuery({ @@ -58,4 +64,30 @@ export function useRestartToApply() { }); } -export { UPDATE_QUERY_KEY }; +export function useWhatsNew() { + return useQuery({ + queryKey: WHATS_NEW_QUERY_KEY, + queryFn: async (): Promise => { + const result = await Call.ByName("dreamcreator/internal/presentation/wails.UpdateHandler.GetWhatsNew"); + return normalizeWhatsNewInfo(result as Partial); + }, + staleTime: Infinity, + }); +} + +export function useDismissWhatsNew() { + const queryClient = useQueryClient(); + return useMutation({ + mutationFn: async (version: string): Promise => { + await Call.ByName( + "dreamcreator/internal/presentation/wails.UpdateHandler.DismissWhatsNew", + version + ); + }, + onSuccess: () => { + queryClient.setQueryData(WHATS_NEW_QUERY_KEY, null); + }, + }); +} + +export { UPDATE_QUERY_KEY, WHATS_NEW_QUERY_KEY }; diff --git a/frontend/src/shared/store/gatewayTools.ts b/frontend/src/shared/store/gatewayTools.ts index f5d024e..9978646 100644 --- a/frontend/src/shared/store/gatewayTools.ts +++ b/frontend/src/shared/store/gatewayTools.ts @@ -11,6 +11,7 @@ export type GatewayToolRequirement = { name?: string; available: boolean; reason?: string; + data?: unknown; }; export type GatewayToolSpec = { diff --git a/frontend/src/shared/store/update.ts 
b/frontend/src/shared/store/update.ts index 3eab3e3..5ae5648 100644 --- a/frontend/src/shared/store/update.ts +++ b/frontend/src/shared/store/update.ts @@ -16,6 +16,8 @@ export interface UpdateInfo { currentVersion: string; latestVersion: string; changelog: string; + preparedVersion: string; + preparedChangelog: string; downloadURL: string; checkedAt?: string; status: UpdateStatus; @@ -23,9 +25,23 @@ export interface UpdateInfo { message?: string; } +export interface WhatsNewInfo { + version: string; + currentVersion: string; + changelog: string; +} + +export interface WhatsNewPreview { + notice: WhatsNewInfo; + targetWindow: "main" | "settings"; +} + export interface UpdateStore { info: UpdateInfo; + whatsNewPreview: WhatsNewPreview | null; setInfo: (info: UpdateInfo) => void; + openWhatsNewPreview: (notice: WhatsNewInfo, targetWindow: WhatsNewPreview["targetWindow"]) => void; + clearWhatsNewPreview: () => void; } const defaultInfo: UpdateInfo = { @@ -33,6 +49,8 @@ const defaultInfo: UpdateInfo = { currentVersion: "", latestVersion: "", changelog: "", + preparedVersion: "", + preparedChangelog: "", downloadURL: "", status: "idle", progress: 0, @@ -41,7 +59,10 @@ const defaultInfo: UpdateInfo = { export const useUpdateStore = create((set) => ({ info: defaultInfo, + whatsNewPreview: null, setInfo: (info) => set({ info }), + openWhatsNewPreview: (notice, targetWindow) => set({ whatsNewPreview: { notice, targetWindow } }), + clearWhatsNewPreview: () => set({ whatsNewPreview: null }), })); export function normalizeUpdateInfo(raw: Partial | null | undefined): UpdateInfo { @@ -54,6 +75,8 @@ export function normalizeUpdateInfo(raw: Partial | null | undefined) currentVersion: raw.currentVersion ?? anyRaw.CurrentVersion ?? "", latestVersion: raw.latestVersion ?? anyRaw.LatestVersion ?? "", changelog: raw.changelog ?? anyRaw.Changelog ?? "", + preparedVersion: raw.preparedVersion ?? anyRaw.PreparedVersion ?? "", + preparedChangelog: raw.preparedChangelog ?? 
anyRaw.PreparedChangelog ?? "", downloadURL: raw.downloadURL ?? anyRaw.DownloadURL ?? "", checkedAt: raw.checkedAt ?? anyRaw.CheckedAt, status: (raw.status as UpdateStatus) ?? (anyRaw.Status as UpdateStatus) ?? "idle", @@ -61,3 +84,70 @@ export function normalizeUpdateInfo(raw: Partial | null | undefined) message: raw.message ?? anyRaw.Message ?? "", }; } + +export function normalizeWhatsNewInfo( + raw: Partial | null | undefined +): WhatsNewInfo | null { + if (!raw) { + return null; + } + const anyRaw = raw as any; + const version = (raw.version ?? anyRaw.Version ?? "").trim(); + if (!version) { + return null; + } + return { + version, + currentVersion: (raw.currentVersion ?? anyRaw.CurrentVersion ?? version).trim(), + changelog: raw.changelog ?? anyRaw.Changelog ?? "", + }; +} + +export function compareUpdateVersion(left: string, right: string): number { + const leftParts = normalizeVersionParts(left); + const rightParts = normalizeVersionParts(right); + const maxLength = Math.max(leftParts.length, rightParts.length); + for (let index = 0; index < maxLength; index += 1) { + const leftValue = leftParts[index] ?? 0; + const rightValue = rightParts[index] ?? 
0; + if (leftValue < rightValue) { + return -1; + } + if (leftValue > rightValue) { + return 1; + } + } + return 0; +} + +export function hasPreparedUpdate(info: UpdateInfo): boolean { + const preparedVersion = info.preparedVersion.trim(); + if (!preparedVersion) { + return false; + } + return compareUpdateVersion(info.currentVersion, preparedVersion) < 0; +} + +export function hasRemoteUpdate(info: UpdateInfo): boolean { + const latestVersion = info.latestVersion.trim(); + if (!latestVersion) { + return false; + } + return compareUpdateVersion(info.currentVersion, latestVersion) < 0; +} + +export function displayUpdateVersion(info: UpdateInfo): string { + if (hasPreparedUpdate(info)) { + return info.preparedVersion.trim(); + } + return info.latestVersion.trim(); +} + +function normalizeVersionParts(version: string): number[] { + return version + .trim() + .replace(/^v/i, "") + .split(".") + .map((part) => Number.parseInt(part, 10)) + .filter((part) => Number.isFinite(part)); +} diff --git a/go.mod b/go.mod index 38eaa33..1b08cea 100644 --- a/go.mod +++ b/go.mod @@ -4,8 +4,11 @@ go 1.25.5 require ( github.com/JohannesKaufmann/html-to-markdown v1.6.0 + github.com/PuerkitoBio/goquery v1.9.2 github.com/asg017/sqlite-vec-go-bindings v0.1.6 github.com/bep/debounce v1.2.1 + github.com/chromedp/cdproto v0.0.0-20250803210736-d308e07a266d + github.com/chromedp/chromedp v0.14.2 github.com/cloudwego/eino v0.8.6 github.com/eino-contrib/jsonschema v1.0.3 github.com/fsnotify/fsnotify v1.8.0 @@ -13,7 +16,6 @@ require ( github.com/hashicorp/go-retryablehttp v0.7.8 github.com/mymmrac/telego v1.6.0 github.com/ncruces/go-sqlite3 v0.23.3 - github.com/playwright-community/playwright-go v0.5200.1 github.com/tetratelabs/wazero v1.11.0 github.com/uptrace/bun v1.2.16 github.com/uptrace/bun/dialect/sqlitedialect v1.2.16 @@ -34,7 +36,6 @@ require ( git.sr.ht/~jackmordaunt/go-toast/v2 v2.0.3 // indirect github.com/Microsoft/go-winio v0.6.2 // indirect github.com/ProtonMail/go-crypto v1.3.0 // 
indirect - github.com/PuerkitoBio/goquery v1.9.2 // indirect github.com/adrg/xdg v0.5.3 // indirect github.com/andybalholm/brotli v1.2.0 // indirect github.com/andybalholm/cascadia v1.3.2 // indirect @@ -43,11 +44,11 @@ require ( github.com/bytedance/gopkg v0.1.3 // indirect github.com/bytedance/sonic v1.15.0 // indirect github.com/bytedance/sonic/loader v0.5.0 // indirect + github.com/chromedp/sysutil v1.1.0 // indirect github.com/cloudflare/circl v1.6.3 // indirect github.com/cloudwego/base64x v0.1.6 // indirect github.com/coder/websocket v1.8.14 // indirect github.com/cyphar/filepath-securejoin v0.6.1 // indirect - github.com/deckarep/golang-set/v2 v2.7.0 // indirect github.com/dustin/go-humanize v1.0.1 // indirect github.com/ebitengine/purego v0.9.1 // indirect github.com/emirpasic/gods v1.18.1 // indirect @@ -55,9 +56,11 @@ require ( github.com/go-git/gcfg v1.5.1-0.20230307220236-3a3c6141e376 // indirect github.com/go-git/go-billy/v5 v5.8.0 // indirect github.com/go-git/go-git/v5 v5.17.2 // indirect - github.com/go-jose/go-jose/v3 v3.0.5 // indirect + github.com/go-json-experiment/json v0.0.0-20251027170946-4849db3c2f7e // indirect github.com/go-ole/go-ole v1.3.0 // indirect - github.com/go-stack/stack v1.8.1 // indirect + github.com/gobwas/httphead v0.1.0 // indirect + github.com/gobwas/pool v0.2.1 // indirect + github.com/gobwas/ws v1.4.0 // indirect github.com/godbus/dbus/v5 v5.2.2 // indirect github.com/golang/groupcache v0.0.0-20241129210726-2c02b8208cf8 // indirect github.com/goph/emperror v0.17.2 // indirect diff --git a/go.sum b/go.sum index 442a444..d849e81 100644 --- a/go.sum +++ b/go.sum @@ -41,6 +41,12 @@ github.com/bytedance/sonic v1.15.0/go.mod h1:tFkWrPz0/CUCLEF4ri4UkHekCIcdnkqXw9V github.com/bytedance/sonic/loader v0.5.0 h1:gXH3KVnatgY7loH5/TkeVyXPfESoqSBSBEiDd5VjlgE= github.com/bytedance/sonic/loader v0.5.0/go.mod h1:AR4NYCk5DdzZizZ5djGqQ92eEhCCcdf5x77udYiSJRo= github.com/certifi/gocertifi v0.0.0-20190105021004-abcd57078448/go.mod 
h1:GJKEexRPVJrBSOjoqN5VNOIKJ5Q3RViH6eu3puDRwx4= +github.com/chromedp/cdproto v0.0.0-20250803210736-d308e07a266d h1:ZtA1sedVbEW7EW80Iz2GR3Ye6PwbJAJXjv7D74xG6HU= +github.com/chromedp/cdproto v0.0.0-20250803210736-d308e07a266d/go.mod h1:NItd7aLkcfOA/dcMXvl8p1u+lQqioRMq/SqDp71Pb/k= +github.com/chromedp/chromedp v0.14.2 h1:r3b/WtwM50RsBZHMUm9fsNhhzRStTHrKdr2zmwbZSzM= +github.com/chromedp/chromedp v0.14.2/go.mod h1:rHzAv60xDE7VNy/MYtTUrYreSc0ujt2O1/C3bzctYBo= +github.com/chromedp/sysutil v1.1.0 h1:PUFNv5EcprjqXZD9nJb9b/c9ibAbxiYo4exNWZyipwM= +github.com/chromedp/sysutil v1.1.0/go.mod h1:WiThHUdltqCNKGc4gaU50XgYjwjYIhKWoHGPTUfWTJ8= github.com/cloudflare/circl v1.6.3 h1:9GPOhQGF9MCYUeXyMYlqTR6a5gTrgR/fBLXvUgtVcg8= github.com/cloudflare/circl v1.6.3/go.mod h1:2eXP6Qfat4O/Yhh8BznvKnJ+uzEoTQ6jVKJRn81BiS4= github.com/cloudwego/base64x v0.1.6 h1:t11wG9AECkCDk5fMSoxmufanudBtJ+/HemLstXDLI2M= @@ -54,8 +60,6 @@ github.com/cyphar/filepath-securejoin v0.6.1/go.mod h1:A8hd4EnAeyujCJRrICiOWqjS1 github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/deckarep/golang-set/v2 v2.7.0 h1:gIloKvD7yH2oip4VLhsv3JyLLFnC0Y2mlusgcvJYW5k= -github.com/deckarep/golang-set/v2 v2.7.0/go.mod h1:VAky9rY/yGXJOLEDv3OMci+7wtDpOF4IN+y82NBOac4= github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= github.com/ebitengine/purego v0.9.1 h1:a/k2f2HQU3Pi399RPW1MOaZyhKJL9w/xFpKAg4q1s0A= @@ -84,21 +88,22 @@ github.com/go-git/go-git-fixtures/v4 v4.3.2-0.20231010084843-55a94097c399 h1:eMj github.com/go-git/go-git-fixtures/v4 v4.3.2-0.20231010084843-55a94097c399/go.mod h1:1OCfN199q1Jm3HZlxleg+Dw/mwps2Wbk9frAWm+4FII= github.com/go-git/go-git/v5 v5.17.2 
h1:B+nkdlxdYrvyFK4GPXVU8w1U+YkbsgciIR7f2sZJ104= github.com/go-git/go-git/v5 v5.17.2/go.mod h1:pW/VmeqkanRFqR6AljLcs7EA7FbZaN5MQqO7oZADXpo= -github.com/go-jose/go-jose/v3 v3.0.5 h1:BLLJWbC4nMZOfuPVxoZIxeYsn6Nl2r1fITaJ78UQlVQ= -github.com/go-jose/go-jose/v3 v3.0.5/go.mod h1:5b+7YgP7ZICgJDBdfjZaIt+H/9L9T/YQrVfLAMboGkQ= github.com/go-json-experiment/json v0.0.0-20251027170946-4849db3c2f7e h1:Lf/gRkoycfOBPa42vU2bbgPurFong6zXeFtPoxholzU= github.com/go-json-experiment/json v0.0.0-20251027170946-4849db3c2f7e/go.mod h1:uNVvRXArCGbZ508SxYYTC5v1JWoz2voff5pm25jU1Ok= github.com/go-ole/go-ole v1.3.0 h1:Dt6ye7+vXGIKZ7Xtk4s6/xVdGDQynvom7xCFEdWr6uE= github.com/go-ole/go-ole v1.3.0/go.mod h1:5LS6F96DhAwUc7C+1HLexzMXY1xGRSryjyPPKW6zv78= -github.com/go-stack/stack v1.8.1 h1:ntEHSVwIt7PNXNpgPmVfMrNhLtgjlmnZha2kOpuRiDw= -github.com/go-stack/stack v1.8.1/go.mod h1:dcoOX6HbPZSZptuspn9bctJ+N/CnF5gGygcUP3XYfe4= +github.com/gobwas/httphead v0.1.0 h1:exrUm0f4YX0L7EBwZHuCF4GDp8aJfVeBrlLQrs6NqWU= +github.com/gobwas/httphead v0.1.0/go.mod h1:O/RXo79gxV8G+RqlR/otEwx4Q36zl9rqC5u12GKvMCM= +github.com/gobwas/pool v0.2.1 h1:xfeeEhW7pwmX8nuLVlqbzVc7udMDrwetjEv+TZIz1og= +github.com/gobwas/pool v0.2.1/go.mod h1:q8bcK0KcYlCgd9e7WYLm9LpyS+YeLd8JVDW6WezmKEw= +github.com/gobwas/ws v1.4.0 h1:CTaoG1tojrh4ucGPcoJFiAQUAsEWekEWvLy7GsVNqGs= +github.com/gobwas/ws v1.4.0/go.mod h1:G3gNqMNtPppf5XUz7O4shetPpcZ1VJ7zt18dlUeakrc= github.com/godbus/dbus/v5 v5.2.2 h1:TUR3TgtSVDmjiXOgAAyaZbYmIeP3DPkld3jgKGV8mXQ= github.com/godbus/dbus/v5 v5.2.2/go.mod h1:3AAv2+hPq5rdnr5txxxRwiGjPXamgoIHgz9FPBfOp3c= github.com/gofrs/uuid v3.2.0+incompatible/go.mod h1:b2aQJv3Z4Fp6yNu3cdSllBxTCLRxnplIgP/c0N/04lM= github.com/golang/groupcache v0.0.0-20241129210726-2c02b8208cf8 h1:f+oWsMOmNPc8JmEHVZIycC7hBoQxHH9pNKQORJNozsQ= github.com/golang/groupcache v0.0.0-20241129210726-2c02b8208cf8/go.mod h1:wcDNUvekVysuuOpQKo3191zZyTpiI6se1N1ULghS0sw= github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= 
-github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= @@ -147,6 +152,8 @@ github.com/leaanthony/go-ansi-parser v1.6.1 h1:xd8bzARK3dErqkPFtoF9F3/HgN8UQk0ed github.com/leaanthony/go-ansi-parser v1.6.1/go.mod h1:+vva/2y4alzVmmIEpk9QDhA7vLC5zKDTRwfZGOp3IWU= github.com/leaanthony/u v1.1.1 h1:TUFjwDGlNX+WuwVEzDqQwC2lOv0P4uhTQw7CMFdiK7M= github.com/leaanthony/u v1.1.1/go.mod h1:9+o6hejoRljvZ3BzdYlVL0JYCwtnAsVuN9pVTQcaRfI= +github.com/ledongthuc/pdf v0.0.0-20220302134840-0c2507a12d80 h1:6Yzfa6GP0rIo/kULo2bwGEkFvCePZ3qHDDTC3/J9Swo= +github.com/ledongthuc/pdf v0.0.0-20220302134840-0c2507a12d80/go.mod h1:imJHygn/1yfhB7XSJJKlFZKl/J+dCPAknuiaGOshXAs= github.com/lmittmann/tint v1.1.2 h1:2CQzrL6rslrsyjqLDwD11bZ5OpLBPU+g3G/r5LSfS8w= github.com/lmittmann/tint v1.1.2/go.mod h1:HIS3gSy7qNwGCj+5oRjAutErFBl4BzdQP6cJZ0NfMwE= github.com/mailru/easyjson v0.7.7 h1:UGYAvKxe3sBsEDzO8ZeWOSlIQfWFlxbzLZe7hwFURr0= @@ -160,8 +167,6 @@ github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWE github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mgutz/ansi v0.0.0-20170206155736-9520e82c474b h1:j7+1HpAFS1zy5+Q4qx1fWh90gTKwiN4QCGoY9TWyyO4= github.com/mgutz/ansi v0.0.0-20170206155736-9520e82c474b/go.mod h1:01TrycV0kFyexm33Z7vhZRXopbI8J3TDReVlkTgMUxE= -github.com/mitchellh/go-ps v1.0.0 h1:i6ampVEEF4wQFF+bkYfwYgY+F/uYJDktmvLPf7qIgjc= -github.com/mitchellh/go-ps v1.0.0/go.mod h1:J4lOc8z8yJs6vUwklHw2XEIiT4z4C40KtWVN3nvg8Pg= github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd 
h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg= github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= @@ -180,6 +185,8 @@ github.com/onsi/ginkgo v1.8.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+W github.com/onsi/gomega v1.5.0/go.mod h1:ex+gbHU/CVuBBDIJjb2X0qEXbFg53c61hWP/1CpauHY= github.com/onsi/gomega v1.34.1 h1:EUMJIKUjM8sKjYbtxQI9A4z2o+rruxnzNvpknOXie6k= github.com/onsi/gomega v1.34.1/go.mod h1:kU1QgUvBDLXBJq618Xvm2LUX6rSAfRaFRTcdOeDLwwY= +github.com/orisano/pixelmatch v0.0.0-20220722002657-fb0b55479cde h1:x0TT0RDC7UhAVbbWWBzr41ElhJx5tXPWkIHA2HWPRuw= +github.com/orisano/pixelmatch v0.0.0-20220722002657-fb0b55479cde/go.mod h1:nZgzbfBr3hhjoZnS66nKrHmduYNpc34ny7RK4z5/HM0= github.com/pelletier/go-toml/v2 v2.0.9 h1:uH2qQXheeefCCkuBBSLi7jCiSmj3VRh2+Goq2N7Xxu0= github.com/pelletier/go-toml/v2 v2.0.9/go.mod h1:tJU2Z3ZkXwnxa4DPO899bsyIoywizdUvyaeZurnPPDc= github.com/pjbgf/sha1cd v0.5.0 h1:a+UkboSi1znleCDUNT3M5YxjOnN1fz2FhN48FlwCxs0= @@ -190,8 +197,6 @@ github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINE github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= -github.com/playwright-community/playwright-go v0.5200.1 h1:Sm2oOuhqt0M5Y4kUi/Qh9w4cyyi3ZIWTBeGKImc2UVo= -github.com/playwright-community/playwright-go v0.5200.1/go.mod h1:UnnyQZaqUOO5ywAZu60+N4EiWReUqX1MQBBA3Oofvf8= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/puzpuzpuz/xsync/v3 v3.5.1 h1:GJYJZwO6IdxN/IKbneznS6yPkVC+c3zyY/j19c++5Fg= diff --git a/internal/app/bootstrap.go b/internal/app/bootstrap.go index d3c0339..608f796 100644 --- a/internal/app/bootstrap.go +++ 
b/internal/app/bootstrap.go @@ -13,7 +13,6 @@ import ( "path/filepath" "strconv" "strings" - "sync" "time" agentservice "dreamcreator/internal/application/agent/service" @@ -125,63 +124,6 @@ var ( AppDescription = "An AI assistant for content creators." ) -type providersUpdatedWindowNotifier struct { - manager *wails.WindowManager -} - -func (notifier providersUpdatedWindowNotifier) ProvidersUpdated() { - if notifier.manager == nil { - return - } - notifier.manager.EmitProvidersUpdated() -} - -type settingsBroadcastAdapter struct { - service *service.SettingsService - - mu sync.RWMutex - applier func(settingsdto.Settings) -} - -func newSettingsBroadcastAdapter(settingsService *service.SettingsService) *settingsBroadcastAdapter { - return &settingsBroadcastAdapter{service: settingsService} -} - -func (adapter *settingsBroadcastAdapter) SetApplier(applier func(settingsdto.Settings)) { - if adapter == nil { - return - } - adapter.mu.Lock() - adapter.applier = applier - adapter.mu.Unlock() -} - -func (adapter *settingsBroadcastAdapter) GetSettings(ctx context.Context) (settingsdto.Settings, error) { - if adapter == nil || adapter.service == nil { - return settingsdto.Settings{}, errors.New("settings service unavailable") - } - return adapter.service.GetSettings(ctx) -} - -func (adapter *settingsBroadcastAdapter) UpdateSettings(ctx context.Context, request settingsdto.UpdateSettingsRequest) (settingsdto.Settings, error) { - if adapter == nil || adapter.service == nil { - return settingsdto.Settings{}, errors.New("settings service unavailable") - } - return adapter.service.UpdateSettings(ctx, request) -} - -func (adapter *settingsBroadcastAdapter) ApplySettings(updated settingsdto.Settings) { - if adapter == nil { - return - } - adapter.mu.RLock() - applier := adapter.applier - adapter.mu.RUnlock() - if applier != nil { - applier(updated) - } -} - func CreateApplication(assets fs.FS) (*application.App, error) { appVersion := resolveVersion(os.Getenv("APP_ENV")) startup 
:= currentStartupContext(os.Args[1:]) @@ -464,7 +406,7 @@ func CreateApplication(assets fs.FS) (*application.App, error) { startAccentColorWatcher(accentCtx, settingsService, windowManager) updateCatalog := buildSoftwareUpdateService(proxyManager) - updateService, err := buildUpdateService(proxyManager, eventBus, windowManager, updateCatalog, appVersion) + updateService, err := buildUpdateService(ctx, proxyManager, eventBus, windowManager, updateCatalog, appVersion) if err != nil { return nil, err } @@ -497,7 +439,7 @@ func CreateApplication(assets fs.FS) (*application.App, error) { startModelsDevCatalogSyncWorker(ctx, modelsDevCatalog) connectorsRepo := connectorsrepo.NewSQLiteConnectorRepository(database.Bun) - connectorsService := connectorsservice.NewConnectorsService(connectorsRepo) + connectorsService := connectorsservice.NewConnectorsService(connectorsRepo, settingsService) if err := connectorsService.EnsureDefaults(ctx); err != nil { return nil, err } @@ -561,7 +503,8 @@ func CreateApplication(assets fs.FS) (*application.App, error) { llmCallRecordRepo := llmrecordrepo.NewSQLiteRepository(database.Bun) llmCallRecordService := llmrecord.NewService(llmCallRecordRepo, settingsService) toolService := toolsservice.NewToolService() - toolService.SetPolicy(gatewaytools.NewPolicyPipeline(settingsService)) + toolPolicy := gatewaytools.NewPolicyPipeline(settingsService) + toolService.SetPolicy(toolPolicy) toolExecutor := gatewaytools.NewRegistryExecutor() toolService.SetExecutor(toolExecutor) policyAuditStore := toolpolicyrepo.NewSQLitePolicyAuditStore(database.Bun) @@ -643,6 +586,16 @@ func CreateApplication(assets fs.FS) (*application.App, error) { voiceConfigRepo := voicerepo.NewSQLiteVoiceConfigRepository(database.Bun) ttsJobRepo := voicerepo.NewSQLiteTTSJobRepository(database.Bun) voiceService := gatewayvoice.NewService(voiceConfigRepo, ttsJobRepo, usageService, settingsService, gatewayServer) + builtinRequirementResolver := 
gatewaytools.NewBuiltinRequirementResolver(gatewaytools.BuiltinRequirementDeps{ + Settings: settingsNotifier, + Assistants: assistantService, + Providers: providerRepo, + Models: modelRepo, + Secrets: secretRepo, + Voice: voiceService, + }) + toolPolicy.SetRequirementsResolver(builtinRequirementResolver) + gatewayToolService.SetRequirementsResolver(builtinRequirementResolver) gatewaymethods.RegisterUsage(gatewayRouter, usageService) gatewaymethods.RegisterVoice(gatewayRouter, voiceService) gatewaytools.RegisterBuiltinTools(ctx, toolService, toolExecutor, gatewaytools.BuiltinToolDeps{ @@ -1313,7 +1266,7 @@ func buildSoftwareUpdateService(proxyManager *proxy.Manager) *softwareupdate.Ser }) } -func buildUpdateService(proxyManager *proxy.Manager, bus appevents.Bus, notifier applicationupdate.Notifier, catalog *softwareupdate.Service, currentVersion string) (*applicationupdate.Service, error) { +func buildUpdateService(ctx context.Context, proxyManager *proxy.Manager, bus appevents.Bus, notifier applicationupdate.Notifier, catalog *softwareupdate.Service, currentVersion string) (*applicationupdate.Service, error) { httpClient := proxyManager.HTTPClient() downloader := infrastructureupdate.NewHTTPDownloader(httpClient) installer, err := infrastructureupdate.NewInstaller("") @@ -1329,6 +1282,9 @@ func buildUpdateService(proxyManager *proxy.Manager, bus appevents.Bus, notifier Notifier: notifier, }) service.SetCurrentVersion(currentVersion) + if _, err := service.RestorePreparedUpdate(ctx); err != nil { + zap.L().Warn("update: restore prepared update failed", zap.Error(err)) + } return service, nil } diff --git a/internal/app/bootstrap_support.go b/internal/app/bootstrap_support.go new file mode 100644 index 0000000..3f44252 --- /dev/null +++ b/internal/app/bootstrap_support.go @@ -0,0 +1,68 @@ +package app + +import ( + "context" + "errors" + "sync" + + settingsdto "dreamcreator/internal/application/settings/dto" + "dreamcreator/internal/application/settings/service" + 
"dreamcreator/internal/presentation/wails" +) + +type providersUpdatedWindowNotifier struct { + manager *wails.WindowManager +} + +func (notifier providersUpdatedWindowNotifier) ProvidersUpdated() { + if notifier.manager == nil { + return + } + notifier.manager.EmitProvidersUpdated() +} + +type settingsBroadcastAdapter struct { + service *service.SettingsService + + mu sync.RWMutex + applier func(settingsdto.Settings) +} + +func newSettingsBroadcastAdapter(settingsService *service.SettingsService) *settingsBroadcastAdapter { + return &settingsBroadcastAdapter{service: settingsService} +} + +func (adapter *settingsBroadcastAdapter) SetApplier(applier func(settingsdto.Settings)) { + if adapter == nil { + return + } + adapter.mu.Lock() + adapter.applier = applier + adapter.mu.Unlock() +} + +func (adapter *settingsBroadcastAdapter) GetSettings(ctx context.Context) (settingsdto.Settings, error) { + if adapter == nil || adapter.service == nil { + return settingsdto.Settings{}, errors.New("settings service unavailable") + } + return adapter.service.GetSettings(ctx) +} + +func (adapter *settingsBroadcastAdapter) UpdateSettings(ctx context.Context, request settingsdto.UpdateSettingsRequest) (settingsdto.Settings, error) { + if adapter == nil || adapter.service == nil { + return settingsdto.Settings{}, errors.New("settings service unavailable") + } + return adapter.service.UpdateSettings(ctx, request) +} + +func (adapter *settingsBroadcastAdapter) ApplySettings(updated settingsdto.Settings) { + if adapter == nil { + return + } + adapter.mu.RLock() + applier := adapter.applier + adapter.mu.RUnlock() + if applier != nil { + applier(updated) + } +} diff --git a/internal/app/startup_context.go b/internal/app/startup_context.go index 9902dda..1168291 100644 --- a/internal/app/startup_context.go +++ b/internal/app/startup_context.go @@ -3,16 +3,22 @@ package app import "strings" const autoStartLaunchArgument = "--autostart" +const skipPreparedUpdateLaunchArgument = 
"--skip-prepared-update-once" type startupContext struct { launchedByAutoStart bool + skipPreparedUpdate bool } func currentStartupContext(args []string) startupContext { + context := startupContext{} for _, arg := range args { - if strings.EqualFold(strings.TrimSpace(arg), autoStartLaunchArgument) { - return startupContext{launchedByAutoStart: true} + switch strings.ToLower(strings.TrimSpace(arg)) { + case autoStartLaunchArgument: + context.launchedByAutoStart = true + case skipPreparedUpdateLaunchArgument: + context.skipPreparedUpdate = true } } - return startupContext{} + return context } diff --git a/internal/app/startup_context_test.go b/internal/app/startup_context_test.go index 5841149..f82720a 100644 --- a/internal/app/startup_context_test.go +++ b/internal/app/startup_context_test.go @@ -4,20 +4,25 @@ import "testing" func TestCurrentStartupContext(t *testing.T) { tests := []struct { - name string - args []string - expected bool + name string + args []string + expectedAutoStart bool + expectedSkip bool }{ - {name: "no marker", args: []string{"--verbose"}, expected: false}, - {name: "exact marker", args: []string{"--autostart"}, expected: true}, - {name: "marker with spaces and mixed case", args: []string{" --AutoStart "}, expected: true}, + {name: "no marker", args: []string{"--verbose"}}, + {name: "exact marker", args: []string{"--autostart"}, expectedAutoStart: true}, + {name: "marker with spaces and mixed case", args: []string{" --AutoStart "}, expectedAutoStart: true}, + {name: "skip prepared update marker", args: []string{"--skip-prepared-update-once"}, expectedSkip: true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { got := currentStartupContext(tt.args) - if got.launchedByAutoStart != tt.expected { - t.Fatalf("launchedByAutoStart = %v, want %v", got.launchedByAutoStart, tt.expected) + if got.launchedByAutoStart != tt.expectedAutoStart { + t.Fatalf("launchedByAutoStart = %v, want %v", got.launchedByAutoStart, tt.expectedAutoStart) 
+ } + if got.skipPreparedUpdate != tt.expectedSkip { + t.Fatalf("skipPreparedUpdate = %v, want %v", got.skipPreparedUpdate, tt.expectedSkip) } }) } diff --git a/internal/app/update_startup.go b/internal/app/update_startup.go new file mode 100644 index 0000000..0dcfa62 --- /dev/null +++ b/internal/app/update_startup.go @@ -0,0 +1,83 @@ +package app + +import ( + "context" + "os" + "strconv" + "strings" + + domainupdate "dreamcreator/internal/domain/update" + infrastructureupdate "dreamcreator/internal/infrastructure/update" +) + +type preparedUpdateStartupRunner interface { + PreparedUpdate(ctx context.Context) (domainupdate.Info, bool, error) + ClearPreparedUpdate(ctx context.Context) error + RestartToApply(ctx context.Context) error +} + +func TryApplyPreparedUpdateOnLaunch(ctx context.Context, args []string) (bool, error) { + startup := currentStartupContext(args) + if startup.skipPreparedUpdate { + return false, nil + } + + installer, err := infrastructureupdate.NewInstaller("") + if err != nil { + return false, err + } + return maybeApplyPreparedUpdateOnLaunch(ctx, resolveVersion(os.Getenv("APP_ENV")), installer) +} + +func maybeApplyPreparedUpdateOnLaunch(ctx context.Context, currentVersion string, installer preparedUpdateStartupRunner) (bool, error) { + if installer == nil { + return false, nil + } + + normalizedCurrent := domainupdate.NormalizeVersion(currentVersion) + if !isComparableReleaseVersion(normalizedCurrent) { + return false, nil + } + + prepared, found, err := installer.PreparedUpdate(ctx) + if err != nil || !found { + return false, err + } + + preparedVersion := domainupdate.NormalizeVersion(prepared.PreparedVersion) + if preparedVersion == "" { + preparedVersion = domainupdate.NormalizeVersion(prepared.LatestVersion) + } + if !isComparableReleaseVersion(preparedVersion) { + return false, nil + } + + if domainupdate.CompareVersion(normalizedCurrent, preparedVersion) >= 0 { + if err := installer.ClearPreparedUpdate(ctx); err != nil { + return 
false, err + } + return false, nil + } + + if err := installer.RestartToApply(ctx); err != nil { + return false, err + } + return true, nil +} + +func isComparableReleaseVersion(version string) bool { + normalized := domainupdate.NormalizeVersion(version) + if normalized == "" { + return false + } + parts := strings.Split(normalized, ".") + for _, part := range parts { + if part == "" { + return false + } + if _, err := strconv.Atoi(part); err != nil { + return false + } + } + return true +} diff --git a/internal/app/update_startup_test.go b/internal/app/update_startup_test.go new file mode 100644 index 0000000..2181b69 --- /dev/null +++ b/internal/app/update_startup_test.go @@ -0,0 +1,119 @@ +package app + +import ( + "context" + "errors" + "testing" + + domainupdate "dreamcreator/internal/domain/update" +) + +type preparedUpdateStartupStub struct { + prepared domainupdate.Info + found bool + preparedErr error + clearInvoked bool + restartInvoked bool + restartErr error +} + +func (stub *preparedUpdateStartupStub) PreparedUpdate(_ context.Context) (domainupdate.Info, bool, error) { + return stub.prepared, stub.found, stub.preparedErr +} + +func (stub *preparedUpdateStartupStub) ClearPreparedUpdate(_ context.Context) error { + stub.clearInvoked = true + return nil +} + +func (stub *preparedUpdateStartupStub) RestartToApply(_ context.Context) error { + stub.restartInvoked = true + return stub.restartErr +} + +func TestMaybeApplyPreparedUpdateOnLaunchStartsHelperForOlderCurrentVersion(t *testing.T) { + t.Parallel() + + installer := &preparedUpdateStartupStub{ + found: true, + prepared: domainupdate.Info{ + PreparedVersion: "2.0.7", + }, + } + + applied, err := maybeApplyPreparedUpdateOnLaunch(context.Background(), "2.0.6", installer) + if err != nil { + t.Fatalf("maybeApplyPreparedUpdateOnLaunch failed: %v", err) + } + if !applied { + t.Fatal("expected prepared update to be applied on launch") + } + if !installer.restartInvoked { + t.Fatal("expected restart helper to 
be invoked") + } +} + +func TestMaybeApplyPreparedUpdateOnLaunchClearsStalePreparedPlan(t *testing.T) { + t.Parallel() + + installer := &preparedUpdateStartupStub{ + found: true, + prepared: domainupdate.Info{ + PreparedVersion: "2.0.7", + }, + } + + applied, err := maybeApplyPreparedUpdateOnLaunch(context.Background(), "2.0.7", installer) + if err != nil { + t.Fatalf("maybeApplyPreparedUpdateOnLaunch failed: %v", err) + } + if applied { + t.Fatal("did not expect prepared update to apply when current version is already latest") + } + if !installer.clearInvoked { + t.Fatal("expected stale prepared plan to be cleared") + } +} + +func TestMaybeApplyPreparedUpdateOnLaunchSkipsNonReleaseVersion(t *testing.T) { + t.Parallel() + + installer := &preparedUpdateStartupStub{ + found: true, + prepared: domainupdate.Info{ + PreparedVersion: "2.0.7", + }, + } + + applied, err := maybeApplyPreparedUpdateOnLaunch(context.Background(), "dev", installer) + if err != nil { + t.Fatalf("maybeApplyPreparedUpdateOnLaunch failed: %v", err) + } + if applied { + t.Fatal("did not expect dev build to auto-apply prepared update") + } + if installer.restartInvoked { + t.Fatal("did not expect restart helper to run") + } +} + +func TestMaybeApplyPreparedUpdateOnLaunchReturnsRestartError(t *testing.T) { + t.Parallel() + + restartErr := errors.New("helper launch failed") + installer := &preparedUpdateStartupStub{ + found: true, + prepared: domainupdate.Info{ + PreparedVersion: "2.0.7", + }, + restartErr: restartErr, + } + + applied, err := maybeApplyPreparedUpdateOnLaunch(context.Background(), "2.0.6", installer) + if !errors.Is(err, restartErr) { + t.Fatalf("expected restart error, got %v", err) + } + if applied { + t.Fatal("did not expect prepared update to report success") + } +} diff --git a/internal/application/agentruntime/tool_executor.go b/internal/application/agentruntime/tool_executor.go index a216fff..7776313 100644 --- a/internal/application/agentruntime/tool_executor.go +++ 
b/internal/application/agentruntime/tool_executor.go @@ -14,9 +14,10 @@ import ( var ErrToolNotFound = errors.New("tool not found") type ToolDefinition struct { - Name string - Type string - Invoke func(ctx context.Context, args string) (string, error) + Name string + Type string + SchemaJSON string + Invoke func(ctx context.Context, args string) (string, error) } type ToolExecutor struct { diff --git a/internal/application/agentruntime/tool_validator.go b/internal/application/agentruntime/tool_validator.go index 93b1c0f..a5aa2d3 100644 --- a/internal/application/agentruntime/tool_validator.go +++ b/internal/application/agentruntime/tool_validator.go @@ -4,6 +4,7 @@ import ( "encoding/json" "errors" "fmt" + "reflect" "strings" "github.com/cloudwego/eino/schema" @@ -12,15 +13,18 @@ import ( var ( ErrToolNameRequired = errors.New("tool name is required") ErrToolArgsInvalid = errors.New("tool args must be a json object") + ErrToolArgsSchema = errors.New("tool args do not match schema") ) type ToolValidator interface { Validate(call schema.ToolCall) error } -type JSONToolValidator struct{} +type JSONToolValidator struct { + Tools map[string]ToolDefinition +} -func (JSONToolValidator) Validate(call schema.ToolCall) error { +func (validator JSONToolValidator) Validate(call schema.ToolCall) error { name := strings.TrimSpace(call.Function.Name) if name == "" { return ErrToolNameRequired @@ -36,5 +40,249 @@ func (JSONToolValidator) Validate(call schema.ToolCall) error { if _, ok := decoded.(map[string]any); !ok { return ErrToolArgsInvalid } + if definition, ok := validator.Tools[name]; ok { + if err := validateToolArgsAgainstSchema(decoded, strings.TrimSpace(definition.SchemaJSON)); err != nil { + return err + } + } + return nil +} + +func validateToolArgsAgainstSchema(value any, schemaJSON string) error { + if strings.TrimSpace(schemaJSON) == "" { + return nil + } + var schema map[string]any + if err := json.Unmarshal([]byte(schemaJSON), &schema); err != nil { + return nil 
+ } + return validateJSONSchemaValue(value, schema, "$") +} + +func validateJSONSchemaValue(value any, schema map[string]any, path string) error { + if len(schema) == 0 { + return nil + } + if allOf, ok := schema["allOf"].([]any); ok { + for _, item := range allOf { + child, _ := item.(map[string]any) + if err := validateJSONSchemaValue(value, child, path); err != nil { + return err + } + } + } + if anyOf, ok := schema["anyOf"].([]any); ok { + matched := false + var lastErr error + for _, item := range anyOf { + child, _ := item.(map[string]any) + err := validateJSONSchemaValue(value, child, path) + if err == nil { + matched = true + break + } + lastErr = err + } + if !matched { + if lastErr != nil { + return lastErr + } + return fmt.Errorf("%w: %s must match at least one schema", ErrToolArgsSchema, path) + } + } + if oneOf, ok := schema["oneOf"].([]any); ok { + matches := 0 + var lastErr error + for _, item := range oneOf { + child, _ := item.(map[string]any) + err := validateJSONSchemaValue(value, child, path) + if err == nil { + matches++ + continue + } + lastErr = err + } + if matches != 1 { + if lastErr != nil && matches == 0 { + return lastErr + } + return fmt.Errorf("%w: %s must match exactly one schema", ErrToolArgsSchema, path) + } + } + if constValue, ok := schema["const"]; ok && !jsonSchemaValuesEqual(value, constValue) { + return fmt.Errorf("%w: %s must equal %v", ErrToolArgsSchema, path, constValue) + } + if enumValues, ok := schema["enum"].([]any); ok && len(enumValues) > 0 { + matched := false + for _, enumValue := range enumValues { + if jsonSchemaValuesEqual(value, enumValue) { + matched = true + break + } + } + if !matched { + return fmt.Errorf("%w: %s must be one of %v", ErrToolArgsSchema, path, enumValues) + } + } + switch schemaType := schema["type"].(type) { + case string: + if err := validateJSONSchemaType(value, schema, path, schemaType); err != nil { + return err + } + case []any: + var lastErr error + for _, raw := range schemaType { + 
typeName, _ := raw.(string) + if typeName == "" { + continue + } + if err := validateJSONSchemaType(value, schema, path, typeName); err == nil { + return nil + } else { + lastErr = err + } + } + if lastErr != nil { + return lastErr + } + } + if _, hasProperties := schema["properties"]; hasProperties || schema["required"] != nil || schema["additionalProperties"] != nil { + return validateJSONSchemaType(value, schema, path, "object") + } + if _, hasItems := schema["items"]; hasItems { + return validateJSONSchemaType(value, schema, path, "array") + } return nil } + +func validateJSONSchemaType(value any, schema map[string]any, path string, schemaType string) error { + switch schemaType { + case "object": + obj, ok := value.(map[string]any) + if !ok { + return fmt.Errorf("%w: %s must be an object", ErrToolArgsSchema, path) + } + properties, _ := schema["properties"].(map[string]any) + requiredFields := jsonSchemaStringSlice(schema["required"]) + for _, field := range requiredFields { + if _, exists := obj[field]; !exists { + return fmt.Errorf("%w: %s.%s is required", ErrToolArgsSchema, path, field) + } + } + for key, raw := range obj { + childPath := path + "." 
+ key + propertySchema, hasProperty := properties[key] + if hasProperty { + if child, ok := propertySchema.(map[string]any); ok { + if err := validateJSONSchemaValue(raw, child, childPath); err != nil { + return err + } + } + continue + } + switch additional := schema["additionalProperties"].(type) { + case bool: + if !additional { + return fmt.Errorf("%w: %s is not allowed", ErrToolArgsSchema, childPath) + } + case map[string]any: + if err := validateJSONSchemaValue(raw, additional, childPath); err != nil { + return err + } + } + } + return nil + case "array": + items, ok := value.([]any) + if !ok { + return fmt.Errorf("%w: %s must be an array", ErrToolArgsSchema, path) + } + if itemSchema, ok := schema["items"].(map[string]any); ok { + for index, item := range items { + if err := validateJSONSchemaValue(item, itemSchema, fmt.Sprintf("%s[%d]", path, index)); err != nil { + return err + } + } + } + return nil + case "string": + if _, ok := value.(string); !ok { + return fmt.Errorf("%w: %s must be a string", ErrToolArgsSchema, path) + } + return nil + case "boolean": + if _, ok := value.(bool); !ok { + return fmt.Errorf("%w: %s must be a boolean", ErrToolArgsSchema, path) + } + return nil + case "integer": + number, ok := jsonSchemaAsFloat64(value) + if !ok || float64(int64(number)) != number { + return fmt.Errorf("%w: %s must be an integer", ErrToolArgsSchema, path) + } + if minimum, ok := jsonSchemaAsFloat64(schema["minimum"]); ok && number < minimum { + return fmt.Errorf("%w: %s must be >= %v", ErrToolArgsSchema, path, minimum) + } + if maximum, ok := jsonSchemaAsFloat64(schema["maximum"]); ok && number > maximum { + return fmt.Errorf("%w: %s must be <= %v", ErrToolArgsSchema, path, maximum) + } + return nil + case "number": + number, ok := jsonSchemaAsFloat64(value) + if !ok { + return fmt.Errorf("%w: %s must be a number", ErrToolArgsSchema, path) + } + if minimum, ok := jsonSchemaAsFloat64(schema["minimum"]); ok && number < minimum { + return fmt.Errorf("%w: %s 
must be >= %v", ErrToolArgsSchema, path, minimum) + } + if maximum, ok := jsonSchemaAsFloat64(schema["maximum"]); ok && number > maximum { + return fmt.Errorf("%w: %s must be <= %v", ErrToolArgsSchema, path, maximum) + } + return nil + default: + return nil + } +} + +func jsonSchemaStringSlice(value any) []string { + switch typed := value.(type) { + case []string: + return append([]string(nil), typed...) + case []any: + result := make([]string, 0, len(typed)) + for _, item := range typed { + if text, ok := item.(string); ok && strings.TrimSpace(text) != "" { + result = append(result, text) + } + } + return result + default: + return nil + } +} + +func jsonSchemaAsFloat64(value any) (float64, bool) { + switch typed := value.(type) { + case float64: + return typed, true + case float32: + return float64(typed), true + case int: + return float64(typed), true + case int64: + return float64(typed), true + case int32: + return float64(typed), true + default: + return 0, false + } +} + +func jsonSchemaValuesEqual(left any, right any) bool { + leftNumber, leftIsNumber := jsonSchemaAsFloat64(left) + rightNumber, rightIsNumber := jsonSchemaAsFloat64(right) + if leftIsNumber && rightIsNumber { + return leftNumber == rightNumber + } + return reflect.DeepEqual(left, right) +} diff --git a/internal/application/agentruntime/tool_validator_test.go b/internal/application/agentruntime/tool_validator_test.go index 1faa9d3..a26da28 100644 --- a/internal/application/agentruntime/tool_validator_test.go +++ b/internal/application/agentruntime/tool_validator_test.go @@ -33,3 +33,76 @@ func TestJSONToolValidatorAcceptsObjectArgs(t *testing.T) { t.Fatalf("expected valid args, got %v", err) } } + +func TestJSONToolValidatorRejectsSchemaViolation(t *testing.T) { + validator := JSONToolValidator{ + Tools: map[string]ToolDefinition{ + "browser": { + Name: "browser", + SchemaJSON: `{ + "type":"object", + "properties":{ + "action":{"type":"string","enum":["open","act"]}, + "url":{"type":"string"}, 
+ "request":{ + "type":"object", + "properties":{"kind":{"type":"string"}}, + "required":["kind"] + } + }, + "required":["action"], + "allOf":[ + { + "anyOf":[ + {"properties":{"action":{"const":"open"}},"required":["action","url"]}, + {"properties":{"action":{"const":"act"}},"required":["action","request"]} + ] + } + ] + }`, + }, + }, + } + + err := validator.Validate(schema.ToolCall{ + Function: schema.FunctionCall{ + Name: "browser", + Arguments: `{"action":"open"}`, + }, + }) + if err == nil || !errors.Is(err, ErrToolArgsSchema) { + t.Fatalf("expected schema error for missing url, got %v", err) + } +} + +func TestJSONToolValidatorRejectsNestedRequiredSchemaViolation(t *testing.T) { + validator := JSONToolValidator{ + Tools: map[string]ToolDefinition{ + "browser": { + Name: "browser", + SchemaJSON: `{ + "type":"object", + "properties":{ + "action":{"type":"string","enum":["act"]}, + "request":{ + "type":"object", + "properties":{"kind":{"type":"string"}}, + "required":["kind"] + } + }, + "required":["action","request"] + }`, + }, + }, + } + + err := validator.Validate(schema.ToolCall{ + Function: schema.FunctionCall{ + Name: "browser", + Arguments: `{"action":"act","request":{}}`, + }, + }) + if err == nil || !errors.Is(err, ErrToolArgsSchema) { + t.Fatalf("expected schema error for missing request.kind, got %v", err) + } +} diff --git a/internal/application/browsercdp/connectors.go b/internal/application/browsercdp/connectors.go new file mode 100644 index 0000000..1cdc2d5 --- /dev/null +++ b/internal/application/browsercdp/connectors.go @@ -0,0 +1,79 @@ +package browsercdp + +import ( + "context" + "sort" + "strings" + + connectorsdto "dreamcreator/internal/application/connectors/dto" + appcookies "dreamcreator/internal/application/cookies" + "dreamcreator/internal/application/sitepolicy" +) + +type ConnectorsReader interface { + ListConnectors(ctx context.Context) ([]connectorsdto.Connector, error) +} + +type ConnectorCookieProvider interface { + 
ResolveCookiesForURL(ctx context.Context, rawURL string) ([]appcookies.Record, error) +} + +type ConnectorCookieProviderFunc func(ctx context.Context, rawURL string) ([]appcookies.Record, error) + +func (fn ConnectorCookieProviderFunc) ResolveCookiesForURL(ctx context.Context, rawURL string) ([]appcookies.Record, error) { + if fn == nil { + return nil, nil + } + return fn(ctx, rawURL) +} + +func ResolveConnectorCookiesForURL(ctx context.Context, connectors ConnectorsReader, rawURL string) ([]appcookies.Record, error) { + if connectors == nil { + return nil, nil + } + items, err := connectors.ListConnectors(ctx) + if err != nil { + return nil, err + } + for _, item := range items { + policy, ok := sitepolicy.ForConnectorType(item.Type) + if !ok || !sitepolicy.MatchDomains(rawURL, policy.Domains) { + continue + } + records := make([]appcookies.Record, 0, len(item.Cookies)) + for _, cookie := range item.Cookies { + records = append(records, appcookies.Record{ + Name: strings.TrimSpace(cookie.Name), + Value: cookie.Value, + Domain: strings.TrimSpace(cookie.Domain), + Path: strings.TrimSpace(cookie.Path), + Expires: cookie.Expires, + HttpOnly: cookie.HttpOnly, + Secure: cookie.Secure, + SameSite: strings.TrimSpace(cookie.SameSite), + }) + } + sort.Slice(records, func(i, j int) bool { + left := records[i] + right := records[j] + switch { + case left.Domain != right.Domain: + return left.Domain < right.Domain + case left.Path != right.Path: + return left.Path < right.Path + default: + return left.Name < right.Name + } + }) + return records, nil + } + return nil, nil +} + +func ConnectorTypeForURL(rawURL string) string { + policy, ok := sitepolicy.ForURL(rawURL) + if !ok { + return "" + } + return policy.ConnectorType +} diff --git a/internal/application/browsercdp/connectors_test.go b/internal/application/browsercdp/connectors_test.go new file mode 100644 index 0000000..7a7bb73 --- /dev/null +++ b/internal/application/browsercdp/connectors_test.go @@ -0,0 +1,84 @@ 
+package browsercdp + +import ( + "context" + "errors" + "testing" + + connectorsdto "dreamcreator/internal/application/connectors/dto" +) + +type connectorsReaderStub struct { + items []connectorsdto.Connector + err error +} + +func (stub connectorsReaderStub) ListConnectors(context.Context) ([]connectorsdto.Connector, error) { + if stub.err != nil { + return nil, stub.err + } + return append([]connectorsdto.Connector(nil), stub.items...), nil +} + +func TestResolveConnectorCookiesForURL_MatchesConnectorPolicy(t *testing.T) { + t.Parallel() + + cookies, err := ResolveConnectorCookiesForURL(context.Background(), connectorsReaderStub{ + items: []connectorsdto.Connector{ + { + Type: "google", + Cookies: []connectorsdto.ConnectorCookie{ + {Name: "SID", Value: "google-cookie", Domain: ".google.com", Path: "/"}, + }, + }, + { + Type: "github", + Cookies: []connectorsdto.ConnectorCookie{ + {Name: "logged_in", Value: "yes", Domain: ".github.com", Path: "/"}, + }, + }, + }, + }, "https://www.youtube.com/watch?v=test") + if err != nil { + t.Fatalf("resolve cookies: %v", err) + } + if len(cookies) != 1 { + t.Fatalf("expected 1 cookie, got %d", len(cookies)) + } + if cookies[0].Name != "SID" || cookies[0].Value != "google-cookie" { + t.Fatalf("unexpected cookie: %#v", cookies[0]) + } +} + +func TestResolveConnectorCookiesForURL_ReturnsNilWhenNoMatch(t *testing.T) { + t.Parallel() + + cookies, err := ResolveConnectorCookiesForURL(context.Background(), connectorsReaderStub{ + items: []connectorsdto.Connector{ + { + Type: "github", + Cookies: []connectorsdto.ConnectorCookie{ + {Name: "logged_in", Value: "yes", Domain: ".github.com", Path: "/"}, + }, + }, + }, + }, "https://example.com/") + if err != nil { + t.Fatalf("resolve cookies: %v", err) + } + if len(cookies) != 0 { + t.Fatalf("expected no cookies, got %#v", cookies) + } +} + +func TestResolveConnectorCookiesForURL_PropagatesReaderError(t *testing.T) { + t.Parallel() + + expectedErr := errors.New("connectors unavailable") 
+ _, err := ResolveConnectorCookiesForURL(context.Background(), connectorsReaderStub{ + err: expectedErr, + }, "https://github.com/openai") + if !errors.Is(err, expectedErr) { + t.Fatalf("expected %v, got %v", expectedErr, err) + } +} diff --git a/internal/application/browsercdp/cookies.go b/internal/application/browsercdp/cookies.go new file mode 100644 index 0000000..26c5335 --- /dev/null +++ b/internal/application/browsercdp/cookies.go @@ -0,0 +1,117 @@ +package browsercdp + +import ( + "context" + "net/url" + "strings" + "time" + + "github.com/chromedp/cdproto/cdp" + "github.com/chromedp/cdproto/network" + "github.com/chromedp/cdproto/storage" + + appcookies "dreamcreator/internal/application/cookies" +) + +func SetCookies(ctx context.Context, targetURL string, records []appcookies.Record) error { + params := buildCookieParams(targetURL, records) + if len(params) == 0 { + return nil + } + return network.SetCookies(params).Do(ctx) +} + +func SetCookiesOnBrowser(ctx context.Context, targetURL string, records []appcookies.Record) error { + params := buildCookieParams(targetURL, records) + if len(params) == 0 { + return nil + } + return storage.SetCookies(params).Do(ctx) +} + +func GetAllCookies(ctx context.Context) ([]appcookies.Record, error) { + items, err := network.GetCookies().Do(ctx) + if err != nil { + return nil, err + } + return mapCDPCookies(items), nil +} + +func GetStorageCookies(ctx context.Context) ([]appcookies.Record, error) { + items, err := storage.GetCookies().Do(ctx) + if err != nil { + return nil, err + } + return mapCDPCookies(items), nil +} + +func buildCookieParams(targetURL string, records []appcookies.Record) []*network.CookieParam { + if len(records) == 0 { + return nil + } + params := make([]*network.CookieParam, 0, len(records)) + for _, record := range records { + if strings.TrimSpace(record.Name) == "" { + continue + } + param := &network.CookieParam{ + Name: strings.TrimSpace(record.Name), + Value: record.Value, + Domain: 
strings.TrimSpace(record.Domain), + Path: strings.TrimSpace(record.Path), + HTTPOnly: record.HttpOnly, + Secure: record.Secure, + } + if param.Path == "" { + param.Path = "/" + } + if record.Expires > 0 { + expires := cdp.TimeSinceEpoch(time.Unix(record.Expires, 0)) + param.Expires = &expires + } + if param.Domain == "" { + if parsed, err := url.Parse(strings.TrimSpace(targetURL)); err == nil && parsed.Hostname() != "" { + param.URL = strings.TrimSpace(parsed.Scheme) + "://" + parsed.Hostname() + } + } + switch strings.ToLower(strings.TrimSpace(record.SameSite)) { + case "lax": + param.SameSite = network.CookieSameSiteLax + case "strict": + param.SameSite = network.CookieSameSiteStrict + case "none": + param.SameSite = network.CookieSameSiteNone + } + params = append(params, param) + } + return params +} + +func mapCDPCookies(items []*network.Cookie) []appcookies.Record { + result := make([]appcookies.Record, 0, len(items)) + for _, item := range items { + if item == nil { + continue + } + sameSite := "" + switch item.SameSite { + case network.CookieSameSiteLax: + sameSite = "lax" + case network.CookieSameSiteStrict: + sameSite = "strict" + case network.CookieSameSiteNone: + sameSite = "none" + } + result = append(result, appcookies.Record{ + Name: item.Name, + Value: item.Value, + Domain: item.Domain, + Path: item.Path, + Expires: int64(item.Expires), + HttpOnly: item.HTTPOnly, + Secure: item.Secure, + SameSite: sameSite, + }) + } + return result +} diff --git a/internal/application/browsercdp/detect.go b/internal/application/browsercdp/detect.go new file mode 100644 index 0000000..bf8333f --- /dev/null +++ b/internal/application/browsercdp/detect.go @@ -0,0 +1,295 @@ +package browsercdp + +import ( + "context" + "fmt" + "net/http" + "os" + "os/exec" + "path/filepath" + "runtime" + "strings" + "sync" + "time" +) + +type BrowserID string + +const ( + BrowserChrome BrowserID = "chrome" + BrowserChromium BrowserID = "chromium" + BrowserEdge BrowserID = "edge" + 
BrowserBrave BrowserID = "brave" +) + +type Candidate struct { + ID BrowserID `json:"id"` + Label string `json:"label"` + ExecPath string `json:"execPath,omitempty"` + Available bool `json:"available"` + Error string `json:"error,omitempty"` +} + +var ( + detectCandidatesCacheMu sync.RWMutex + detectCandidatesCache []Candidate + detectCandidatesCacheExpiresAt time.Time + detectCandidatesCacheLoaded bool + detectCandidatesCacheTTL = 5 * time.Second + detectCandidatesNow = time.Now + detectCandidatesScan = scanCandidates +) + +func DetectCandidates() []Candidate { + now := detectCandidatesNow() + detectCandidatesCacheMu.RLock() + if detectCandidatesCacheLoaded && now.Before(detectCandidatesCacheExpiresAt) { + cached := cloneCandidates(detectCandidatesCache) + detectCandidatesCacheMu.RUnlock() + return cached + } + detectCandidatesCacheMu.RUnlock() + + detectCandidatesCacheMu.Lock() + defer detectCandidatesCacheMu.Unlock() + + now = detectCandidatesNow() + if detectCandidatesCacheLoaded && now.Before(detectCandidatesCacheExpiresAt) { + return cloneCandidates(detectCandidatesCache) + } + + result := detectCandidatesScan() + detectCandidatesCache = cloneCandidates(result) + detectCandidatesCacheExpiresAt = now.Add(detectCandidatesCacheTTL) + detectCandidatesCacheLoaded = true + return cloneCandidates(result) +} + +func ChooseCandidate(candidates []Candidate, preferred string) (Candidate, bool) { + preferredID := BrowserID(strings.ToLower(strings.TrimSpace(preferred))) + if preferredID != "" { + for _, candidate := range candidates { + if candidate.ID == preferredID && candidate.Available { + return candidate, true + } + } + } + for _, candidate := range candidates { + if candidate.Available { + return candidate, true + } + } + return Candidate{}, false +} + +func CheckCDPReady(ctx context.Context, host string, port int) error { + if port <= 0 { + return fmt.Errorf("invalid cdp port") + } + if strings.TrimSpace(host) == "" { + host = "127.0.0.1" + } + req, err := 
http.NewRequestWithContext(ctx, http.MethodGet, fmt.Sprintf("http://%s:%d/json/version", host, port), nil) + if err != nil { + return err + } + resp, err := http.DefaultClient.Do(req) + if err != nil { + return err + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + return fmt.Errorf("unexpected cdp status %d", resp.StatusCode) + } + return nil +} + +func detectCandidate(id BrowserID) Candidate { + candidate := Candidate{ID: id, Label: labelForID(id)} + for _, path := range candidatesForID(id) { + resolved := resolveExecutable(path) + if strings.TrimSpace(resolved) == "" { + continue + } + candidate.ExecPath = resolved + candidate.Available = true + candidate.Error = "" + return candidate + } + candidate.Error = "browser executable not found" + return candidate +} + +func scanCandidates() []Candidate { + order := []BrowserID{BrowserChrome, BrowserChromium, BrowserEdge, BrowserBrave} + result := make([]Candidate, 0, len(order)) + for _, id := range order { + result = append(result, detectCandidate(id)) + } + return result +} + +func cloneCandidates(source []Candidate) []Candidate { + if len(source) == 0 { + return []Candidate{} + } + result := make([]Candidate, len(source)) + copy(result, source) + return result +} + +func resetDetectCandidatesCache() { + detectCandidatesCacheMu.Lock() + defer detectCandidatesCacheMu.Unlock() + detectCandidatesCache = nil + detectCandidatesCacheExpiresAt = time.Time{} + detectCandidatesCacheLoaded = false +} + +func labelForID(id BrowserID) string { + switch id { + case BrowserChrome: + return "Chrome" + case BrowserChromium: + return "Chromium" + case BrowserEdge: + return "Edge" + case BrowserBrave: + return "Brave" + default: + return strings.Title(string(id)) + } +} + +func candidatesForID(id BrowserID) []string { + switch runtime.GOOS { + case "darwin": + switch id { + case BrowserChrome: + return []string{ + "/Applications/Google Chrome.app/Contents/MacOS/Google Chrome", + "/Applications/Google Chrome 
Beta.app/Contents/MacOS/Google Chrome Beta", + "/Applications/Google Chrome Dev.app/Contents/MacOS/Google Chrome Dev", + "/Applications/Google Chrome Canary.app/Contents/MacOS/Google Chrome Canary", + } + case BrowserChromium: + return []string{ + "/Applications/Chromium.app/Contents/MacOS/Chromium", + } + case BrowserEdge: + return []string{ + "/Applications/Microsoft Edge.app/Contents/MacOS/Microsoft Edge", + "/Applications/Microsoft Edge Beta.app/Contents/MacOS/Microsoft Edge Beta", + "/Applications/Microsoft Edge Dev.app/Contents/MacOS/Microsoft Edge Dev", + } + case BrowserBrave: + return []string{ + "/Applications/Brave Browser.app/Contents/MacOS/Brave Browser", + "/Applications/Brave Browser Beta.app/Contents/MacOS/Brave Browser Beta", + } + } + case "windows": + localAppData := strings.TrimSpace(os.Getenv("LOCALAPPDATA")) + programFiles := strings.TrimSpace(os.Getenv("ProgramFiles")) + programFilesX86 := strings.TrimSpace(os.Getenv("ProgramFiles(x86)")) + switch id { + case BrowserChrome: + return compact([]string{ + filepath.Join(programFiles, "Google", "Chrome", "Application", "chrome.exe"), + filepath.Join(programFilesX86, "Google", "Chrome", "Application", "chrome.exe"), + filepath.Join(localAppData, "Google", "Chrome", "Application", "chrome.exe"), + }) + case BrowserChromium: + return compact([]string{ + filepath.Join(programFiles, "Chromium", "Application", "chrome.exe"), + filepath.Join(programFilesX86, "Chromium", "Application", "chrome.exe"), + filepath.Join(localAppData, "Chromium", "Application", "chrome.exe"), + }) + case BrowserEdge: + return compact([]string{ + filepath.Join(programFiles, "Microsoft", "Edge", "Application", "msedge.exe"), + filepath.Join(programFilesX86, "Microsoft", "Edge", "Application", "msedge.exe"), + filepath.Join(localAppData, "Microsoft", "Edge", "Application", "msedge.exe"), + }) + case BrowserBrave: + return compact([]string{ + filepath.Join(programFiles, "BraveSoftware", "Brave-Browser", "Application", 
"brave.exe"), + filepath.Join(programFilesX86, "BraveSoftware", "Brave-Browser", "Application", "brave.exe"), + filepath.Join(localAppData, "BraveSoftware", "Brave-Browser", "Application", "brave.exe"), + }) + } + default: + switch id { + case BrowserChrome: + return []string{"google-chrome", "google-chrome-stable"} + case BrowserChromium: + return []string{"chromium-browser", "chromium"} + case BrowserEdge: + return []string{"microsoft-edge", "microsoft-edge-stable", "msedge"} + case BrowserBrave: + return []string{"brave-browser", "brave-browser-stable", "brave"} + } + } + return nil +} + +func resolveExecutable(candidate string) string { + trimmed := strings.TrimSpace(candidate) + if trimmed == "" { + return "" + } + if filepath.IsAbs(trimmed) { + if fileInfo, err := os.Stat(trimmed); err == nil && !fileInfo.IsDir() { + return trimmed + } + return "" + } + resolved, err := exec.LookPath(trimmed) + if err != nil { + return "" + } + return resolved +} + +func compact(values []string) []string { + result := make([]string, 0, len(values)) + for _, value := range values { + trimmed := strings.TrimSpace(value) + if trimmed != "" { + result = append(result, trimmed) + } + } + return result +} + +func WaitForCDP(ctx context.Context, host string, port int, timeout time.Duration) error { + if ctx == nil { + ctx = context.Background() + } + if timeout <= 0 { + timeout = 8 * time.Second + } + deadline := time.Now().Add(timeout) + for { + if err := ctx.Err(); err != nil { + return err + } + checkCtx, cancel := context.WithTimeout(ctx, 1200*time.Millisecond) + err := CheckCDPReady(checkCtx, host, port) + cancel() + if err == nil { + return nil + } + if time.Now().After(deadline) { + if ctxErr := ctx.Err(); ctxErr != nil { + return ctxErr + } + return err + } + select { + case <-ctx.Done(): + return ctx.Err() + case <-time.After(250 * time.Millisecond): + } + } +} diff --git a/internal/application/browsercdp/detect_test.go b/internal/application/browsercdp/detect_test.go new 
file mode 100644 index 0000000..2506b76 --- /dev/null +++ b/internal/application/browsercdp/detect_test.go @@ -0,0 +1,71 @@ +package browsercdp + +import ( + "context" + "errors" + "testing" + "time" +) + +func TestDetectCandidatesCachesScanUntilTTLExpires(t *testing.T) { + t.Parallel() + + originalNow := detectCandidatesNow + originalScan := detectCandidatesScan + originalTTL := detectCandidatesCacheTTL + current := time.Unix(1_700_000_000, 0) + scanCalls := 0 + + detectCandidatesNow = func() time.Time { return current } + detectCandidatesScan = func() []Candidate { + scanCalls += 1 + return []Candidate{ + { + ID: BrowserChrome, + Label: "Chrome", + ExecPath: "/tmp/chrome", + Available: true, + }, + } + } + detectCandidatesCacheTTL = time.Minute + resetDetectCandidatesCache() + t.Cleanup(func() { + detectCandidatesNow = originalNow + detectCandidatesScan = originalScan + detectCandidatesCacheTTL = originalTTL + resetDetectCandidatesCache() + }) + + first := DetectCandidates() + if scanCalls != 1 { + t.Fatalf("expected first detect to scan once, got %d", scanCalls) + } + first[0].ExecPath = "/tmp/changed" + + second := DetectCandidates() + if scanCalls != 1 { + t.Fatalf("expected second detect to use cache, got %d scans", scanCalls) + } + if got := second[0].ExecPath; got != "/tmp/chrome" { + t.Fatalf("expected cached detect result to be cloned, got %q", got) + } + + current = current.Add(time.Minute + time.Second) + _ = DetectCandidates() + if scanCalls != 2 { + t.Fatalf("expected detect to rescan after ttl, got %d scans", scanCalls) + } +} + +func TestWaitForCDPHonorsCancelledContext(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + err := WaitForCDP(ctx, "127.0.0.1", 1, time.Second) + if !errors.Is(err, context.Canceled) { + t.Fatalf("expected context canceled, got %v", err) + } +} diff --git a/internal/application/browsercdp/runtime.go b/internal/application/browsercdp/runtime.go new file mode 100644 index 
0000000..9ebd632 --- /dev/null +++ b/internal/application/browsercdp/runtime.go @@ -0,0 +1,352 @@ +package browsercdp + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net" + "net/http" + "os" + "os/exec" + "path/filepath" + "strings" + "sync" + "syscall" + "time" + + "github.com/chromedp/chromedp" + "go.uber.org/zap" +) + +type LaunchOptions struct { + PreferredBrowser string + Headless bool + NoSandbox bool + ExtraArgs []string + UserDataDir string +} + +type Status struct { + Ready bool `json:"ready"` + Candidates []Candidate `json:"candidates,omitempty"` + SelectedBrowser string `json:"selectedBrowser,omitempty"` + ChosenBrowser string `json:"chosenBrowser,omitempty"` + DetectedExecutablePath string `json:"detectedExecutablePath,omitempty"` + DetectError string `json:"detectError,omitempty"` + CDPURL string `json:"cdpUrl,omitempty"` + CDPPort int `json:"cdpPort,omitempty"` + Headless bool `json:"headless"` +} + +type Runtime struct { + mu sync.Mutex + + options LaunchOptions + candidate Candidate + status Status + + cmd *exec.Cmd + userDataDir string + allocCtx context.Context + allocCancel context.CancelFunc + browserCtx context.Context + browserCancel context.CancelFunc + stopping bool + stopped chan struct{} +} + +type versionResponse struct { + WebSocketDebuggerURL string `json:"webSocketDebuggerUrl"` +} + +func ResolveStatus(preferred string, headless bool) Status { + candidates := DetectCandidates() + status := Status{ + Candidates: candidates, + SelectedBrowser: strings.TrimSpace(preferred), + Headless: headless, + } + candidate, ok := ChooseCandidate(candidates, preferred) + if !ok { + status.DetectError = "no supported browser detected" + return status + } + status.ChosenBrowser = string(candidate.ID) + status.DetectedExecutablePath = candidate.ExecPath + status.Ready = candidate.Available + if !candidate.Available { + status.DetectError = candidate.Error + } + return status +} + +func Start(ctx context.Context, options LaunchOptions) 
(*Runtime, error) { + candidates := DetectCandidates() + candidate, ok := ChooseCandidate(candidates, options.PreferredBrowser) + if !ok { + return nil, fmt.Errorf("no supported browser detected") + } + + port, err := reservePort() + if err != nil { + return nil, err + } + userDataDir := strings.TrimSpace(options.UserDataDir) + if userDataDir == "" { + userDataDir = filepath.Join(os.TempDir(), "dreamcreator", "browsercdp", string(candidate.ID)) + } + if err := os.MkdirAll(userDataDir, 0o755); err != nil { + return nil, err + } + + args := []string{ + fmt.Sprintf("--remote-debugging-port=%d", port), + fmt.Sprintf("--user-data-dir=%s", userDataDir), + "--no-first-run", + "--no-default-browser-check", + "--disable-background-networking", + "--disable-background-timer-throttling", + "--disable-backgrounding-occluded-windows", + "--disable-breakpad", + "--disable-client-side-phishing-detection", + "--disable-default-apps", + "--disable-features=Translate,OptimizationHints,MediaRouter,AutomationControlled", + "--disable-hang-monitor", + "--disable-popup-blocking", + "--disable-prompt-on-repost", + "--disable-sync", + "--metrics-recording-only", + "--password-store=basic", + "--use-mock-keychain", + } + if options.Headless { + args = append([]string{"--headless=new", "--hide-scrollbars", "--mute-audio"}, args...) + } else { + args = append([]string{"--no-startup-window"}, args...) + } + if options.NoSandbox { + args = append([]string{"--no-sandbox"}, args...) + } + for _, extra := range options.ExtraArgs { + if trimmed := strings.TrimSpace(extra); trimmed != "" { + args = append(args, trimmed) + } + } + + cmd := exec.Command(candidate.ExecPath, args...) 
+ cmd.Stdout = io.Discard + cmd.Stderr = io.Discard + cmd.SysProcAttr = &syscall.SysProcAttr{} + zap.L().Info( + "browser runtime launch started", + zap.String("preferredBrowser", strings.TrimSpace(options.PreferredBrowser)), + zap.String("chosenBrowser", string(candidate.ID)), + zap.String("execPath", candidate.ExecPath), + zap.Bool("headless", options.Headless), + zap.String("userDataDir", userDataDir), + zap.Int("cdpPort", port), + ) + if err := cmd.Start(); err != nil { + zap.L().Warn( + "browser runtime launch failed", + zap.String("preferredBrowser", strings.TrimSpace(options.PreferredBrowser)), + zap.String("chosenBrowser", string(candidate.ID)), + zap.String("execPath", candidate.ExecPath), + zap.Error(err), + ) + return nil, err + } + + if err := WaitForCDP(ctx, "127.0.0.1", port, 10*time.Second); err != nil { + zap.L().Warn( + "browser runtime cdp wait failed", + zap.String("preferredBrowser", strings.TrimSpace(options.PreferredBrowser)), + zap.String("chosenBrowser", string(candidate.ID)), + zap.String("execPath", candidate.ExecPath), + zap.Int("cdpPort", port), + zap.Error(err), + ) + _ = cmd.Process.Kill() + return nil, err + } + zap.L().Info( + "browser runtime cdp ready", + zap.String("preferredBrowser", strings.TrimSpace(options.PreferredBrowser)), + zap.String("chosenBrowser", string(candidate.ID)), + zap.String("execPath", candidate.ExecPath), + zap.Int("cdpPort", port), + ) + wsURL, err := fetchWebSocketURL(ctx, port) + if err != nil { + zap.L().Warn( + "browser runtime websocket resolve failed", + zap.String("preferredBrowser", strings.TrimSpace(options.PreferredBrowser)), + zap.String("chosenBrowser", string(candidate.ID)), + zap.String("execPath", candidate.ExecPath), + zap.Int("cdpPort", port), + zap.Error(err), + ) + _ = cmd.Process.Kill() + return nil, err + } + allocCtx, allocCancel := chromedp.NewRemoteAllocator(context.Background(), wsURL) + browserCtx, browserCancel := chromedp.NewContext(allocCtx) + if _, err := 
chromedp.Targets(browserCtx); err != nil { + zap.L().Warn( + "browser runtime chromedp attach failed", + zap.String("preferredBrowser", strings.TrimSpace(options.PreferredBrowser)), + zap.String("chosenBrowser", string(candidate.ID)), + zap.String("execPath", candidate.ExecPath), + zap.String("cdpUrl", wsURL), + zap.Error(err), + ) + browserCancel() + allocCancel() + _ = cmd.Process.Kill() + return nil, err + } + zap.L().Info( + "browser runtime ready", + zap.String("preferredBrowser", strings.TrimSpace(options.PreferredBrowser)), + zap.String("chosenBrowser", string(candidate.ID)), + zap.String("execPath", candidate.ExecPath), + zap.String("cdpUrl", wsURL), + zap.Int("cdpPort", port), + ) + + runtime := &Runtime{ + options: options, + candidate: candidate, + cmd: cmd, + userDataDir: userDataDir, + allocCtx: allocCtx, + allocCancel: allocCancel, + browserCtx: browserCtx, + browserCancel: browserCancel, + stopped: make(chan struct{}), + status: Status{ + Ready: true, + Candidates: candidates, + SelectedBrowser: strings.TrimSpace(options.PreferredBrowser), + ChosenBrowser: string(candidate.ID), + DetectedExecutablePath: candidate.ExecPath, + CDPURL: wsURL, + CDPPort: port, + Headless: options.Headless, + }, + } + + go func() { + _ = cmd.Wait() + runtime.mu.Lock() + runtime.status.Ready = false + stopping := runtime.stopping + if !stopping && runtime.status.DetectError == "" { + runtime.status.DetectError = "browser process exited" + } + runtime.mu.Unlock() + if !stopping { + zap.L().Warn( + "browser runtime exited unexpectedly", + zap.String("preferredBrowser", strings.TrimSpace(options.PreferredBrowser)), + zap.String("chosenBrowser", string(candidate.ID)), + zap.String("execPath", candidate.ExecPath), + zap.Int("cdpPort", port), + ) + } + close(runtime.stopped) + }() + + return runtime, nil +} + +func (runtime *Runtime) BrowserContext() context.Context { + runtime.mu.Lock() + defer runtime.mu.Unlock() + return runtime.browserCtx +} + +func (runtime *Runtime) 
UserDataDir() string { + runtime.mu.Lock() + defer runtime.mu.Unlock() + return runtime.userDataDir +} + +func (runtime *Runtime) Candidate() Candidate { + runtime.mu.Lock() + defer runtime.mu.Unlock() + return runtime.candidate +} + +func (runtime *Runtime) Status() Status { + runtime.mu.Lock() + defer runtime.mu.Unlock() + return runtime.status +} + +func (runtime *Runtime) Stop() { + if runtime == nil { + return + } + runtime.mu.Lock() + cmd := runtime.cmd + browserCancel := runtime.browserCancel + allocCancel := runtime.allocCancel + stopped := runtime.stopped + runtime.stopping = true + runtime.status.Ready = false + runtime.mu.Unlock() + + if browserCancel != nil { + browserCancel() + } + if allocCancel != nil { + allocCancel() + } + if cmd != nil && cmd.Process != nil { + _ = cmd.Process.Kill() + } + if stopped != nil { + select { + case <-stopped: + case <-time.After(2 * time.Second): + } + } +} + +func fetchWebSocketURL(ctx context.Context, port int) (string, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, fmt.Sprintf("http://127.0.0.1:%d/json/version", port), nil) + if err != nil { + return "", err + } + resp, err := http.DefaultClient.Do(req) + if err != nil { + return "", err + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + return "", fmt.Errorf("unexpected cdp status %d", resp.StatusCode) + } + var payload versionResponse + if err := json.NewDecoder(resp.Body).Decode(&payload); err != nil { + return "", err + } + if strings.TrimSpace(payload.WebSocketDebuggerURL) == "" { + return "", fmt.Errorf("webSocketDebuggerUrl missing") + } + return strings.TrimSpace(payload.WebSocketDebuggerURL), nil +} + +func reservePort() (int, error) { + listener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + return 0, err + } + defer listener.Close() + addr, ok := listener.Addr().(*net.TCPAddr) + if !ok { + return 0, fmt.Errorf("failed to reserve tcp port") + } + return addr.Port, nil +} diff --git 
a/internal/application/browsercdp/session.go b/internal/application/browsercdp/session.go new file mode 100644 index 0000000..c9eff3c --- /dev/null +++ b/internal/application/browsercdp/session.go @@ -0,0 +1,3300 @@ +package browsercdp + +import ( + "context" + "errors" + "fmt" + "net" + "net/url" + "path/filepath" + "sort" + "strings" + "sync" + "time" + + "github.com/chromedp/cdproto/cdp" + "github.com/chromedp/cdproto/emulation" + "github.com/chromedp/cdproto/fetch" + "github.com/chromedp/cdproto/network" + pagepkg "github.com/chromedp/cdproto/page" + targetpkg "github.com/chromedp/cdproto/target" + "github.com/chromedp/chromedp" + "go.uber.org/zap" + + appcookies "dreamcreator/internal/application/cookies" +) + +const ( + defaultSnapshotLimit = 200 + defaultSSRFValidationTimeout = 2 * time.Second +) + +var lookupIPAddrsForHost = func(ctx context.Context, host string) ([]net.IPAddr, error) { + return net.DefaultResolver.LookupIPAddr(ctx, host) +} + +type SSRFPolicy struct { + DangerouslyAllowPrivateNetwork bool + AllowedHostnames map[string]struct{} + HostnameAllowlist []string +} + +type SessionOptions struct { + SessionKey string + ProfileName string + PreferredBrowser string + Headless bool + UserDataDir string + SSRFRules SSRFPolicy + Cookies ConnectorCookieProvider +} + +type SessionRegistry struct { + mu sync.Mutex + sessions map[string]map[string]*Session +} + +type Session struct { + mu sync.Mutex + + options SessionOptions + runtime *Runtime + + tabs map[string]*sessionTab + activeTarget string + + pendingDialogs map[string]PendingDialog + cookieSync map[string]string +} + +type sessionTab struct { + TargetID string + ctx context.Context + cancel context.CancelFunc + + mu sync.RWMutex + cleanupCancels []context.CancelFunc + refs map[string]snapshotRef + evaluateResult any + lastURL string + title string + lastState *PageState + stateVersion uint64 + nextRefID uint64 + blockedRequestErr string + fetchEnabled bool +} + +type newTabWaiter struct { + ctx 
context.Context + ids <-chan targetpkg.ID + cancel context.CancelFunc + stop func() bool +} + +func (waiter *newTabWaiter) close() { + if waiter == nil { + return + } + if waiter.stop != nil { + waiter.stop() + } + if waiter.cancel != nil { + waiter.cancel() + } +} + +type snapshotRef struct { + Selector string + Role string + Name string + Nth int +} + +func clearBlockedRequestError(tab *sessionTab) { + if tab == nil { + return + } + tab.mu.Lock() + tab.blockedRequestErr = "" + tab.mu.Unlock() +} + +func setBlockedRequestError(tab *sessionTab, err error) { + if tab == nil || err == nil { + return + } + message := strings.TrimSpace(err.Error()) + if message == "" { + return + } + tab.mu.Lock() + if tab.blockedRequestErr == "" { + tab.blockedRequestErr = message + } + tab.mu.Unlock() +} + +func peekBlockedRequestError(tab *sessionTab) error { + if tab == nil { + return nil + } + tab.mu.RLock() + message := strings.TrimSpace(tab.blockedRequestErr) + tab.mu.RUnlock() + if message == "" { + return nil + } + return errors.New(message) +} + +func consumeBlockedRequestError(tab *sessionTab) error { + if tab == nil { + return nil + } + tab.mu.Lock() + message := strings.TrimSpace(tab.blockedRequestErr) + tab.blockedRequestErr = "" + tab.mu.Unlock() + if message == "" { + return nil + } + return errors.New(message) +} + +type snapshotCapture struct { + Version uint64 + URL string + Title string + Items []SnapshotItem + Refs map[string]snapshotRef + Truncated bool + ViewportOnly bool +} + +type SnapshotItem struct { + Ref string `json:"ref,omitempty"` + Role string `json:"role,omitempty"` + Name string `json:"name,omitempty"` + Text string `json:"text,omitempty"` + Depth int `json:"depth,omitempty"` + Nth int `json:"nth,omitempty"` +} + +type PageState struct { + Version uint64 `json:"version"` + URL string `json:"url"` + Title string `json:"title,omitempty"` + Items []SnapshotItem `json:"items"` + ItemCount int `json:"itemCount"` + Truncated bool `json:"truncated"` + 
ViewportOnly bool `json:"viewportOnly"` + CapturedAt string `json:"capturedAt"` +} + +type PendingDialog struct { + Message string `json:"message,omitempty"` + Type string `json:"type,omitempty"` + ExpiresAt time.Time `json:"expiresAt"` +} + +type ScrollDelta struct { + X int `json:"x"` + Y int `json:"y"` +} + +type ActionResult struct { + OK bool `json:"ok"` + TargetID string `json:"targetId,omitempty"` + URL string `json:"url,omitempty"` + Title string `json:"title,omitempty"` + StateVersion uint64 `json:"stateVersion,omitempty"` + State *PageState `json:"state,omitempty"` + Items []SnapshotItem `json:"items,omitempty"` + Action string `json:"action,omitempty"` + OpenedNewTab bool `json:"openedNewTab,omitempty"` + PreviousTargetID string `json:"previousTargetId,omitempty"` + PreviousURL string `json:"previousURL,omitempty"` + Navigated bool `json:"navigated,omitempty"` + Waited bool `json:"waited,omitempty"` + Scroll *ScrollDelta `json:"scroll,omitempty"` + Paths []string `json:"paths,omitempty"` + Result any `json:"result,omitempty"` + Pending *PendingDialog `json:"pending,omitempty"` + StateAvailable bool `json:"stateAvailable"` + StateError string `json:"stateError,omitempty"` + Reset bool `json:"reset,omitempty"` + Restarted bool `json:"restarted,omitempty"` + Ready bool `json:"ready,omitempty"` + Closed bool `json:"closed,omitempty"` +} + +type CommandOptions struct { + Limit int + Timeout time.Duration + WaitFor *WaitRequest +} + +type WaitRequest struct { + Time time.Duration + Selector string + Text string + TextGone string + URL string + Fn string + Timeout time.Duration +} + +type ScrollRequest struct { + TargetID string + Ref string + DeltaX int + DeltaY int + Limit int + Timeout time.Duration +} + +type UploadRequest struct { + TargetID string + Ref string + Paths []string + Limit int + Timeout time.Duration +} + +type DialogRequest struct { + TargetID string + Accept *bool + PromptText string + Limit int + Timeout time.Duration +} + +type ActRequest 
struct { + Kind string + TargetID string + Ref string + Text string + Key string + Value string + Expression string + Width int + Height int + Wait WaitRequest + WaitFor *WaitRequest + Limit int + Timeout time.Duration +} + +type FatalError struct { + Err error +} + +func (err *FatalError) Error() string { + if err == nil || err.Err == nil { + return "browser runtime unavailable" + } + return err.Err.Error() +} + +func (err *FatalError) Unwrap() error { + if err == nil { + return nil + } + return err.Err +} + +type InvalidRefError struct { + Ref string +} + +func (err *InvalidRefError) Error() string { + if err == nil || strings.TrimSpace(err.Ref) == "" { + return "ref not found; run snapshot again to get fresh refs" + } + return fmt.Sprintf("ref %q not found; run snapshot again to get fresh refs", strings.TrimSpace(err.Ref)) +} + +type ConnectorCookieError struct { + URL string + Err error +} + +func (err *ConnectorCookieError) Error() string { + if err == nil || err.Err == nil { + return "connector cookie sync failed" + } + return fmt.Sprintf("connector cookie sync failed for %s: %v", strings.TrimSpace(err.URL), err.Err) +} + +func (err *ConnectorCookieError) Unwrap() error { + if err == nil { + return nil + } + return err.Err +} + +type WaitTimeoutError struct { + Condition string +} + +func (err *WaitTimeoutError) Error() string { + if err == nil || strings.TrimSpace(err.Condition) == "" { + return "wait timeout" + } + return fmt.Sprintf("wait %s timeout", strings.TrimSpace(err.Condition)) +} + +func NewSessionRegistry() *SessionRegistry { + return &SessionRegistry{ + sessions: map[string]map[string]*Session{}, + } +} + +func (registry *SessionRegistry) GetOrCreate(sessionKey string, profileName string, options SessionOptions) *Session { + registry.mu.Lock() + defer registry.mu.Unlock() + + sessionKey = strings.TrimSpace(sessionKey) + if sessionKey == "" { + sessionKey = "default" + } + profileName = strings.TrimSpace(profileName) + if profileName == "" { + 
profileName = "dreamcreator" + } + bucket, ok := registry.sessions[sessionKey] + if !ok { + bucket = map[string]*Session{} + registry.sessions[sessionKey] = bucket + } + session, ok := bucket[profileName] + if !ok { + options.SessionKey = sessionKey + options.ProfileName = profileName + session = &Session{ + options: normalizeSessionOptions(options), + tabs: map[string]*sessionTab{}, + pendingDialogs: map[string]PendingDialog{}, + cookieSync: map[string]string{}, + } + bucket[profileName] = session + return session + } + session.mu.Lock() + session.options = normalizeSessionOptions(options) + session.mu.Unlock() + return session +} + +func (registry *SessionRegistry) CloseSessionKey(sessionKey string) { + if registry == nil { + return + } + sessionKey = strings.TrimSpace(sessionKey) + if sessionKey == "" { + sessionKey = "default" + } + registry.mu.Lock() + bucket := registry.sessions[sessionKey] + delete(registry.sessions, sessionKey) + sessions := make([]*Session, 0, len(bucket)) + for _, session := range bucket { + if session != nil { + sessions = append(sessions, session) + } + } + registry.mu.Unlock() + for _, session := range sessions { + session.stop() + } +} + +func (registry *SessionRegistry) CloseAll() { + if registry == nil { + return + } + registry.mu.Lock() + sessions := make([]*Session, 0) + for sessionKey, bucket := range registry.sessions { + delete(registry.sessions, sessionKey) + for _, session := range bucket { + if session != nil { + sessions = append(sessions, session) + } + } + } + registry.mu.Unlock() + for _, session := range sessions { + session.stop() + } +} + +func IsFatalError(err error) bool { + if err == nil { + return false + } + var fatal *FatalError + if errors.As(err, &fatal) { + return true + } + message := strings.ToLower(strings.TrimSpace(err.Error())) + switch { + case strings.Contains(message, "browser runtime unavailable"), + strings.Contains(message, "context canceled"), + strings.Contains(message, "target closed"), + 
strings.Contains(message, "connection closed"), + strings.Contains(message, "websocket"), + strings.Contains(message, "session closed"), + strings.Contains(message, "browser session reset"): + return true + default: + return false + } +} + +func (session *Session) Reset(restart bool) (ActionResult, error) { + session.stop() + if restart { + if err := session.ensureStarted(); err != nil { + return ActionResult{}, session.wrapError(err) + } + } + return ActionResult{ + OK: true, + Reset: true, + Restarted: restart, + Ready: restart, + }, nil +} + +func (session *Session) Open(ctx context.Context, targetURL string, options CommandOptions) (ActionResult, error) { + openStartedAt := time.Now() + if err := session.assertURLAllowed(targetURL); err != nil { + zap.L().Warn( + "browser open blocked by url policy", + append(session.logFields(), + sanitizedURLField("url", targetURL), + zap.Error(err), + )..., + ) + return ActionResult{}, err + } + if err := session.ensureStarted(); err != nil { + zap.L().Warn( + "browser open start runtime failed", + append(session.logFields(), + sanitizedURLField("url", targetURL), + zap.Error(err), + )..., + ) + return ActionResult{}, session.wrapError(err) + } + cookiesStartedAt := time.Now() + if err := session.ensureCookiesForURLOnBrowser(ctx, targetURL); err != nil { + if isRecoverableCookieSyncError(err) { + zap.L().Warn( + "browser open cookie sync failed; retrying on fresh runtime", + append(session.logFields(), + sanitizedURLField("url", targetURL), + zap.Error(err), + )..., + ) + session.stop() + if retryErr := session.ensureStarted(); retryErr == nil { + err = session.ensureCookiesForURLOnBrowser(ctx, targetURL) + } else { + err = retryErr + } + } + zap.L().Warn( + "browser open cookie sync failed", + append(session.logFields(), + sanitizedURLField("url", targetURL), + zap.Duration("elapsed", time.Since(cookiesStartedAt).Round(time.Millisecond)), + zap.Error(err), + )..., + ) + return ActionResult{}, err + } + createTabStartedAt := 
time.Now() + tab, err := session.createTab() + if err != nil { + zap.L().Warn( + "browser open create tab failed", + append(session.logFields(), + sanitizedURLField("url", targetURL), + zap.Duration("elapsed", time.Since(createTabStartedAt).Round(time.Millisecond)), + zap.Error(err), + )..., + ) + return ActionResult{}, session.wrapError(err) + } + cleanup := true + defer func() { + if cleanup { + zap.L().Warn( + "browser open cleaning up failed tab", + append(session.logFields(), + sanitizedURLField("url", targetURL), + zap.String("targetId", tab.TargetID), + )..., + ) + session.detachTab(tab.TargetID) + cancelTabContexts(tab) + } + }() + session.setActiveTarget(tab.TargetID) + clearBlockedRequestError(tab) + navigateTimeout := normalizeTimeout(options.Timeout, 30*time.Second) + navigateStartedAt := time.Now() + if err := session.openTab(tab, targetURL, navigateTimeout); err != nil { + zap.L().Warn( + "browser open navigation failed", + append(session.logFields(), + sanitizedURLField("url", targetURL), + zap.String("targetId", tab.TargetID), + zap.Duration("elapsed", time.Since(navigateStartedAt).Round(time.Millisecond)), + zap.Error(err), + )..., + ) + return ActionResult{}, session.wrapError(err) + } + cleanup = false + if options.WaitFor != nil { + waitTimeout := normalizeTimeout(options.WaitFor.Timeout, options.Timeout) + waitStartedAt := time.Now() + if err := session.waitOnTab(ctx, tab, *options.WaitFor, waitTimeout); err != nil { + if blockedErr := consumeBlockedRequestError(tab); blockedErr != nil { + return ActionResult{}, blockedErr + } + zap.L().Warn( + "browser open wait failed", + append(session.logFields(), + sanitizedURLField("url", targetURL), + zap.String("targetId", tab.TargetID), + zap.Duration("elapsed", time.Since(waitStartedAt).Round(time.Millisecond)), + zap.Error(err), + )..., + ) + return ActionResult{}, session.wrapError(err) + } + } + if blockedErr := consumeBlockedRequestError(tab); blockedErr != nil { + return ActionResult{}, 
blockedErr + } + stateTimeout := captureTimeout(options.Timeout) + stateStartedAt := time.Now() + result, err := session.collectActionResult(tab, options.Limit, stateTimeout, false) + if err != nil { + zap.L().Warn( + "browser open state capture failed", + append(session.logFields(), + sanitizedURLField("url", targetURL), + zap.String("targetId", tab.TargetID), + zap.Duration("elapsed", time.Since(stateStartedAt).Round(time.Millisecond)), + zap.Error(err), + )..., + ) + return ActionResult{}, session.wrapError(err) + } + result.OpenedNewTab = true + result.URL = preferredPageURL(result.URL, tabURL(tab), targetURL) + if result.State != nil { + result.State.URL = preferredPageURL(result.State.URL, result.URL) + } + zap.L().Info( + "browser open completed", + append(session.logFields(), + sanitizedURLField("requestedURL", targetURL), + zap.String("targetId", result.TargetID), + sanitizedURLField("finalURL", result.URL), + zap.Bool("stateAvailable", result.State != nil || result.StateAvailable), + zap.Uint64("stateVersion", result.StateVersion), + zap.Int("itemCount", resultItemCount(result)), + zap.String("stateError", strings.TrimSpace(result.StateError)), + zap.Duration("elapsed", time.Since(openStartedAt).Round(time.Millisecond)), + )..., + ) + return result, nil +} + +func (session *Session) Navigate(ctx context.Context, targetID string, targetURL string, newTab bool, options CommandOptions) (ActionResult, error) { + navigateStartedAt := time.Now() + if err := session.assertURLAllowed(targetURL); err != nil { + return ActionResult{}, err + } + if newTab { + return session.Open(ctx, targetURL, options) + } + if err := session.ensureStarted(); err != nil { + return ActionResult{}, session.wrapError(err) + } + tab, err := session.resolveTab(targetID, true) + if err != nil { + if targetID == "" && errors.Is(err, errNoOpenTab) { + return session.Open(ctx, targetURL, options) + } + return ActionResult{}, session.wrapError(err) + } + clearBlockedRequestError(tab) + if 
err := session.ensureCookiesForURL(ctx, tab, targetURL); err != nil { + if isRecoverableCookieSyncError(err) { + zap.L().Warn( + "browser navigate cookie sync failed; restarting runtime and falling back to open", + append(session.logFields(), + zap.String("targetId", tab.TargetID), + sanitizedURLField("url", targetURL), + zap.Error(err), + )..., + ) + session.stop() + recovered, recoveredErr := session.Open(ctx, targetURL, options) + if recoveredErr == nil { + recovered.OpenedNewTab = true + return recovered, nil + } + return ActionResult{}, recoveredErr + } + return ActionResult{}, err + } + tab, err = session.navigateTab(tab, targetURL, normalizeTimeout(options.Timeout, 30*time.Second)) + if err != nil { + currentTargetID := "" + if tab != nil { + currentTargetID = tab.TargetID + } + zap.L().Warn( + "browser navigate failed before state capture", + append(session.logFields(), + zap.String("targetId", currentTargetID), + sanitizedURLField("url", targetURL), + zap.Duration("elapsed", time.Since(navigateStartedAt).Round(time.Millisecond)), + zap.Error(err), + )..., + ) + return ActionResult{}, session.wrapError(wrapRuntimeHangError(err)) + } + if options.WaitFor != nil { + if err := session.waitOnTab(ctx, tab, *options.WaitFor, normalizeTimeout(options.WaitFor.Timeout, options.Timeout)); err != nil { + if blockedErr := consumeBlockedRequestError(tab); blockedErr != nil { + return ActionResult{}, blockedErr + } + zap.L().Warn( + "browser navigate wait failed", + append(session.logFields(), + zap.String("targetId", tab.TargetID), + sanitizedURLField("url", targetURL), + zap.Duration("elapsed", time.Since(navigateStartedAt).Round(time.Millisecond)), + zap.Error(err), + )..., + ) + return ActionResult{}, session.wrapError(err) + } + } + if blockedErr := consumeBlockedRequestError(tab); blockedErr != nil { + return ActionResult{}, blockedErr + } + stateTimeout := captureTimeout(options.Timeout) + result, err := session.collectActionResult(tab, options.Limit, 
stateTimeout, false) + if err != nil { + zap.L().Warn( + "browser navigate state capture failed", + append(session.logFields(), + zap.String("targetId", tab.TargetID), + sanitizedURLField("url", targetURL), + zap.Duration("elapsed", time.Since(navigateStartedAt).Round(time.Millisecond)), + zap.Error(err), + )..., + ) + if isTargetLookupError(err) { + zap.L().Warn( + "browser navigate falling back to open after target loss", + append(session.logFields(), + zap.String("targetId", tab.TargetID), + sanitizedURLField("url", targetURL), + )..., + ) + session.stop() + recovered, recoveredErr := session.Open(ctx, targetURL, options) + if recoveredErr == nil { + recovered.OpenedNewTab = true + return recovered, nil + } + zap.L().Warn( + "browser navigate recovery open failed after runtime restart", + append(session.logFields(), + zap.String("targetId", tab.TargetID), + sanitizedURLField("url", targetURL), + zap.Error(recoveredErr), + )..., + ) + } + return ActionResult{}, session.wrapError(err) + } + result = session.stabilizeActionResult(tab, result, options.Limit, stateTimeout, minDuration(2*time.Second, normalizeTimeout(options.Timeout, 30*time.Second)/2)) + if shouldRetryActionState(result) && strings.TrimSpace(result.URL) == strings.TrimSpace(targetURL) { + zap.L().Warn( + "browser navigate state remained unstable after stabilization; falling back to open", + append(session.logFields(), + zap.String("targetId", result.TargetID), + sanitizedURLField("url", targetURL), + zap.String("reason", actionStateReason(result)), + )..., + ) + fallback, openErr := session.Open(ctx, targetURL, options) + if openErr == nil { + fallback.OpenedNewTab = true + return fallback, nil + } + session.stop() + recovered, recoveredErr := session.Open(ctx, targetURL, options) + if recoveredErr == nil { + recovered.OpenedNewTab = true + return recovered, nil + } + } + result.OpenedNewTab = false + zap.L().Info( + "browser navigate completed", + append(session.logFields(), + zap.String("targetId", 
result.TargetID), + sanitizedURLField("requestedURL", targetURL), + sanitizedURLField("finalURL", result.URL), + zap.Bool("stateAvailable", result.StateAvailable), + zap.Uint64("stateVersion", result.StateVersion), + zap.Int("itemCount", resultItemCount(result)), + zap.String("stateError", strings.TrimSpace(result.StateError)), + zap.Duration("elapsed", time.Since(navigateStartedAt).Round(time.Millisecond)), + )..., + ) + return result, nil +} + +func (session *Session) State(targetID string, limit int) (ActionResult, error) { + if err := session.ensureStarted(); err != nil { + return ActionResult{}, session.wrapError(err) + } + tab, err := session.resolveTab(targetID, true) + if err != nil { + return ActionResult{}, session.wrapError(err) + } + if blockedErr := consumeBlockedRequestError(tab); blockedErr != nil { + return ActionResult{}, blockedErr + } + result, err := session.collectActionResult(tab, limit, 10*time.Second, false) + if err != nil { + return ActionResult{}, session.wrapError(err) + } + return result, nil +} + +func (session *Session) Wait(ctx context.Context, targetID string, request WaitRequest, options CommandOptions) (ActionResult, error) { + if err := session.ensureStarted(); err != nil { + return ActionResult{}, session.wrapError(err) + } + tab, err := session.resolveTab(targetID, true) + if err != nil { + return ActionResult{}, session.wrapError(err) + } + if err := session.ensureRequestInterception(tab); err != nil { + return ActionResult{}, session.wrapError(err) + } + if err := session.waitOnTab(ctx, tab, request, normalizeTimeout(request.Timeout, options.Timeout)); err != nil { + if blockedErr := consumeBlockedRequestError(tab); blockedErr != nil { + return ActionResult{}, blockedErr + } + return ActionResult{}, session.wrapError(err) + } + if blockedErr := consumeBlockedRequestError(tab); blockedErr != nil { + return ActionResult{}, blockedErr + } + result, err := session.collectActionResult(tab, options.Limit, 
captureTimeout(options.Timeout), false) + if err != nil { + return ActionResult{}, session.wrapError(err) + } + result = session.stabilizeActionResult(tab, result, options.Limit, captureTimeout(options.Timeout), minDuration(1500*time.Millisecond, normalizeTimeout(options.Timeout, 15*time.Second)/2)) + result.Waited = true + return result, nil +} + +func (session *Session) Scroll(request ScrollRequest) (ActionResult, error) { + if err := session.ensureStarted(); err != nil { + return ActionResult{}, session.wrapError(err) + } + tab, err := session.resolveTab(request.TargetID, true) + if err != nil { + return ActionResult{}, session.wrapError(err) + } + runCtx, cancel := context.WithTimeout(tab.ctx, normalizeTimeout(request.Timeout, 15*time.Second)) + defer cancel() + if strings.TrimSpace(request.Ref) != "" { + selector, err := resolveRefSelector(tab, request.Ref) + if err != nil { + return ActionResult{}, err + } + script := fmt.Sprintf(`(() => { const el = document.querySelector(%q); if (!el) throw new Error("element not found"); el.scrollIntoView({block: "center", inline: "center"}); if (%d !== 0 || %d !== 0) { el.scrollBy(%d, %d); } })()`, selector, request.DeltaX, request.DeltaY, request.DeltaX, request.DeltaY) + if err := chromedp.Run(runCtx, chromedp.EvaluateAsDevTools(script, nil)); err != nil { + return ActionResult{}, session.wrapError(err) + } + } else { + script := fmt.Sprintf(`window.scrollBy(%d, %d)`, request.DeltaX, request.DeltaY) + if err := chromedp.Run(runCtx, chromedp.EvaluateAsDevTools(script, nil)); err != nil { + return ActionResult{}, session.wrapError(err) + } + } + time.Sleep(200 * time.Millisecond) + result, err := session.collectActionResult(tab, request.Limit, captureTimeout(request.Timeout), false) + if err != nil { + return ActionResult{}, session.wrapError(err) + } + result = session.stabilizeActionResult(tab, result, request.Limit, captureTimeout(request.Timeout), minDuration(1500*time.Millisecond, normalizeTimeout(request.Timeout, 
15*time.Second)/2)) + result.Scroll = &ScrollDelta{X: request.DeltaX, Y: request.DeltaY} + return result, nil +} + +func (session *Session) Upload(request UploadRequest) (ActionResult, error) { + if err := session.ensureStarted(); err != nil { + return ActionResult{}, session.wrapError(err) + } + tab, err := session.resolveTab(request.TargetID, true) + if err != nil { + return ActionResult{}, session.wrapError(err) + } + selector, err := resolveRefSelector(tab, request.Ref) + if err != nil { + return ActionResult{}, err + } + if len(request.Paths) == 0 { + return ActionResult{}, errors.New("paths are required") + } + runCtx, cancel := context.WithTimeout(tab.ctx, normalizeTimeout(request.Timeout, 15*time.Second)) + defer cancel() + if err := chromedp.Run(runCtx, chromedp.SetUploadFiles(selector, request.Paths, chromedp.ByQuery)); err != nil { + return ActionResult{}, session.wrapError(err) + } + clearTabState(tab) + result, err := session.collectActionResult(tab, request.Limit, captureTimeout(request.Timeout), false) + if err != nil { + return ActionResult{}, session.wrapError(err) + } + result = session.stabilizeActionResult(tab, result, request.Limit, captureTimeout(request.Timeout), minDuration(1500*time.Millisecond, normalizeTimeout(request.Timeout, 15*time.Second)/2)) + result.Paths = append([]string(nil), request.Paths...) 
+ return result, nil +} + +func (session *Session) Dialog(request DialogRequest) (ActionResult, error) { + if err := session.ensureStarted(); err != nil { + return ActionResult{}, session.wrapError(err) + } + tab, err := session.resolveTab(request.TargetID, true) + if err != nil { + return ActionResult{}, session.wrapError(err) + } + targetID := tab.TargetID + session.mu.Lock() + pending, exists := session.pendingDialogs[targetID] + session.mu.Unlock() + if !exists { + return ActionResult{ + OK: true, + TargetID: targetID, + }, nil + } + if request.Accept == nil { + dialog := pending + return ActionResult{ + OK: true, + TargetID: targetID, + Pending: &dialog, + }, nil + } + runCtx, cancel := context.WithTimeout(tab.ctx, normalizeTimeout(request.Timeout, 15*time.Second)) + defer cancel() + if err := chromedp.Run(runCtx, chromedp.ActionFunc(func(ctx context.Context) error { + return pagepkg.HandleJavaScriptDialog(*request.Accept).WithPromptText(request.PromptText).Do(ctx) + })); err != nil { + return ActionResult{}, session.wrapError(err) + } + session.mu.Lock() + delete(session.pendingDialogs, targetID) + session.mu.Unlock() + clearTabState(tab) + result, err := session.collectActionResult(tab, request.Limit, captureTimeout(request.Timeout), false) + if err != nil { + return ActionResult{}, session.wrapError(err) + } + result = session.stabilizeActionResult(tab, result, request.Limit, captureTimeout(request.Timeout), minDuration(1500*time.Millisecond, normalizeTimeout(request.Timeout, 15*time.Second)/2)) + return result, nil +} + +func (session *Session) Act(ctx context.Context, request ActRequest) (ActionResult, error) { + actStartedAt := time.Now() + if err := session.ensureStarted(); err != nil { + return ActionResult{}, session.wrapError(err) + } + tab, err := session.resolveTab(request.TargetID, true) + if err != nil { + return ActionResult{}, session.wrapError(err) + } + if err := session.ensureRequestInterception(tab); err != nil { + return ActionResult{}, 
session.wrapError(err) + } + previousTargetID := tab.TargetID + previousURL := tabURL(tab) + clearBlockedRequestError(tab) + beforeTargets, _ := session.snapshotPageTargets() + var newTabWaiter *newTabWaiter + if actMayOpenNewTab(request.Kind) { + newTabWaiter, err = session.prepareNewTabWaiter(ctx, tab) + if err != nil { + return ActionResult{}, session.wrapError(err) + } + defer newTabWaiter.close() + } + switch request.Kind { + case "click": + err = session.actClick(tab, request) + case "type": + err = session.actType(tab, request) + case "press": + err = session.actPress(tab, request) + case "hover": + err = session.actHover(tab, request) + case "select": + err = session.actSelect(tab, request) + case "fill": + err = session.actFill(tab, request) + case "resize": + err = session.actResize(tab, request) + case "wait": + err = session.waitOnTab(ctx, tab, request.Wait, normalizeTimeout(request.Wait.Timeout, request.Timeout)) + case "evaluate": + err = session.actEvaluate(tab, request) + case "close": + _, err = session.closeTab(tab.TargetID, normalizeTimeout(request.Timeout, 15*time.Second)) + default: + err = fmt.Errorf("act kind not supported: %s", strings.TrimSpace(request.Kind)) + } + if err != nil { + if blockedErr := consumeBlockedRequestError(tab); blockedErr != nil { + return ActionResult{}, blockedErr + } + zap.L().Warn( + "browser act command failed", + append(session.logFields(), + zap.String("kind", strings.TrimSpace(request.Kind)), + zap.String("targetId", tab.TargetID), + zap.String("ref", strings.TrimSpace(request.Ref)), + zap.Duration("elapsed", time.Since(actStartedAt).Round(time.Millisecond)), + zap.Error(err), + )..., + ) + if request.Kind != "wait" { + err = wrapRuntimeHangError(err) + } + return ActionResult{}, session.wrapError(err) + } + if request.Kind == "close" { + return ActionResult{ + OK: true, + TargetID: previousTargetID, + Closed: true, + }, nil + } + if actInvalidatesState(request.Kind) { + clearTabState(tab) + } + currentTab := 
tab + openedNewTab := false + if newTabWaiter != nil { + if detectedTab, ok := session.waitForNewTab(newTabWaiter, 1500*time.Millisecond); ok && detectedTab != nil { + currentTab = detectedTab + openedNewTab = detectedTab.TargetID != previousTargetID + } + } + if request.WaitFor != nil { + if err := session.waitOnTab(ctx, currentTab, *request.WaitFor, normalizeTimeout(request.WaitFor.Timeout, request.Timeout)); err != nil { + zap.L().Warn( + "browser act wait failed", + append(session.logFields(), + zap.String("kind", strings.TrimSpace(request.Kind)), + zap.String("targetId", currentTab.TargetID), + zap.String("ref", strings.TrimSpace(request.Ref)), + zap.Duration("elapsed", time.Since(actStartedAt).Round(time.Millisecond)), + zap.Error(err), + )..., + ) + return ActionResult{}, session.wrapError(err) + } + } else if actNeedsSettle(request.Kind) { + if err := sleepWithContext(ctx, 250*time.Millisecond); err != nil { + return ActionResult{}, session.wrapError(err) + } + } + if request.WaitFor == nil { + switch request.Kind { + case "click", "press": + if navigatedTab, observed, observeErr := session.observeActionNavigation(currentTab, beforeTargets, previousURL, 2*time.Second); observeErr != nil { + return ActionResult{}, session.wrapError(observeErr) + } else if observed && navigatedTab != nil { + currentTab = navigatedTab + } + } + } + if blockedErr := consumeBlockedRequestError(currentTab); blockedErr != nil { + return ActionResult{}, blockedErr + } + stateTimeout := captureTimeout(request.Timeout) + result, err := session.collectActionResult(currentTab, request.Limit, stateTimeout, false) + if err != nil { + zap.L().Warn( + "browser act state capture failed", + append(session.logFields(), + zap.String("kind", strings.TrimSpace(request.Kind)), + zap.String("targetId", currentTab.TargetID), + zap.Duration("elapsed", time.Since(actStartedAt).Round(time.Millisecond)), + zap.Error(err), + )..., + ) + return ActionResult{}, session.wrapError(err) + } + result = 
session.stabilizeActionResult(currentTab, result, request.Limit, stateTimeout, minDuration(2*time.Second, normalizeTimeout(request.Timeout, 20*time.Second)/2)) + result.Action = request.Kind + result.OpenedNewTab = openedNewTab + result.PreviousTargetID = previousTargetID + result.PreviousURL = previousURL + result.Navigated = openedNewTab || !urlsEqual(previousURL, result.URL) + if request.Kind == "evaluate" { + result.Result = evaluateResult(tab) + } + zap.L().Info( + "browser act completed", + append(session.logFields(), + zap.String("kind", strings.TrimSpace(request.Kind)), + zap.String("targetId", result.TargetID), + sanitizedURLField("previousURL", previousURL), + sanitizedURLField("finalURL", result.URL), + zap.Bool("openedNewTab", openedNewTab), + zap.Bool("navigated", result.Navigated), + zap.Bool("stateAvailable", result.StateAvailable), + zap.Uint64("stateVersion", result.StateVersion), + zap.Int("itemCount", resultItemCount(result)), + zap.String("stateError", strings.TrimSpace(result.StateError)), + zap.Duration("elapsed", time.Since(actStartedAt).Round(time.Millisecond)), + )..., + ) + return result, nil +} + +var errNoOpenTab = errors.New("no browser tab is open") + +func normalizeSessionOptions(options SessionOptions) SessionOptions { + options.SessionKey = strings.TrimSpace(options.SessionKey) + if options.SessionKey == "" { + options.SessionKey = "default" + } + options.ProfileName = strings.TrimSpace(options.ProfileName) + if options.ProfileName == "" { + options.ProfileName = "dreamcreator" + } + options.PreferredBrowser = strings.ToLower(strings.TrimSpace(options.PreferredBrowser)) + if options.UserDataDir == "" { + options.UserDataDir = ResolveProfileUserDataDir(options.SessionKey, options.ProfileName) + } + if options.SSRFRules.AllowedHostnames == nil { + options.SSRFRules.AllowedHostnames = map[string]struct{}{} + } + return options +} + +func (session *Session) optionsSnapshot() SessionOptions { + session.mu.Lock() + defer 
session.mu.Unlock() + return session.options +} + +func (session *Session) ensureStarted() error { + session.mu.Lock() + if session.runtime != nil && session.runtime.Status().Ready { + session.mu.Unlock() + return nil + } + options := session.options + session.mu.Unlock() + + startCtx, cancel := context.WithTimeout(context.Background(), 15*time.Second) + defer cancel() + runtime, err := Start(startCtx, LaunchOptions{ + PreferredBrowser: options.PreferredBrowser, + Headless: options.Headless, + UserDataDir: options.UserDataDir, + }) + if err != nil { + return err + } + session.mu.Lock() + if session.runtime != nil && session.runtime.Status().Ready { + session.mu.Unlock() + runtime.Stop() + return nil + } + session.runtime = runtime + session.mu.Unlock() + return nil +} + +func (session *Session) stop() { + session.mu.Lock() + runtime := session.runtime + tabs := make([]*sessionTab, 0, len(session.tabs)) + for _, tab := range session.tabs { + tabs = append(tabs, tab) + } + session.runtime = nil + session.tabs = map[string]*sessionTab{} + session.activeTarget = "" + session.pendingDialogs = map[string]PendingDialog{} + session.cookieSync = map[string]string{} + session.mu.Unlock() + + for _, tab := range tabs { + cancelTabContexts(tab) + } + if runtime != nil { + runtime.Stop() + } +} + +func (session *Session) wrapError(err error) error { + if err == nil { + return nil + } + if IsFatalError(err) { + session.stop() + if _, ok := err.(*FatalError); ok { + return err + } + return &FatalError{Err: err} + } + return err +} + +func (session *Session) createTab() (*sessionTab, error) { + if err := session.ensureStarted(); err != nil { + return nil, err + } + reusableTargetID, err := session.resolveReusableTargetID() + if err != nil { + return nil, err + } + if strings.TrimSpace(reusableTargetID) != "" { + return session.attachTargetAsTab(reusableTargetID, false) + } + targetID, err := session.createBlankTarget() + if err != nil { + return nil, err + } + return 
session.attachTargetAsTab(targetID, false) +} + +func (session *Session) createBlankTarget() (string, error) { + options := session.optionsSnapshot() + runCtx, cancel, err := session.newBrowserExecutorContext(10 * time.Second) + if err != nil { + return "", err + } + defer cancel() + + createTarget := targetpkg.CreateTarget("about:blank") + if !options.Headless { + createTarget = createTarget.WithNewWindow(true) + } + createdTargetID, err := createTarget.Do(runCtx) + if err != nil { + return "", wrapRuntimeHangError(err) + } + targetID := strings.TrimSpace(string(createdTargetID)) + if targetID == "" { + return "", errors.New("create target returned empty target id") + } + return targetID, nil +} + +func (session *Session) openTab(tab *sessionTab, targetURL string, timeout time.Duration) error { + if session == nil || tab == nil { + return errors.New("tab unavailable") + } + if err := session.ensureRequestInterception(tab); err != nil { + return err + } + runCtx, cancel := context.WithTimeout(tab.ctx, normalizeTimeout(timeout, 30*time.Second)) + defer cancel() + + var currentURL string + var title string + if err := chromedp.Run(runCtx, + chromedp.Navigate(strings.TrimSpace(targetURL)), + chromedp.Location(¤tURL), + chromedp.Title(&title), + ); err != nil { + return wrapRuntimeHangError(err) + } + + finalURL := preferredPageURL(currentURL, targetURL) + if err := session.assertObservedURLAllowed(finalURL); err != nil { + return err + } + storeNavigation(tab, finalURL) + tab.mu.Lock() + if title = strings.TrimSpace(title); title != "" { + tab.title = title + } + tab.mu.Unlock() + return nil +} + +func (session *Session) attachTargetAsTab(targetID string, createNew bool) (*sessionTab, error) { + if err := session.ensureStarted(); err != nil { + return nil, err + } + session.mu.Lock() + if trimmed := strings.TrimSpace(targetID); trimmed != "" { + if existing, ok := session.tabs[trimmed]; ok { + session.activeTarget = existing.TargetID + session.mu.Unlock() + return 
existing, nil + } + } + runtime := session.runtime + session.mu.Unlock() + if runtime == nil { + return nil, errors.New("browser runtime unavailable") + } + var ( + tabCtx context.Context + cancel context.CancelFunc + ) + trimmedTargetID := strings.TrimSpace(targetID) + if !createNew && trimmedTargetID != "" { + tabCtx, cancel = chromedp.NewContext(runtime.BrowserContext(), chromedp.WithTargetID(targetpkg.ID(trimmedTargetID))) + } else { + tabCtx, cancel = chromedp.NewContext(runtime.BrowserContext()) + } + tab := &sessionTab{ + TargetID: trimmedTargetID, + ctx: tabCtx, + cancel: cancel, + refs: map[string]snapshotRef{}, + } + if err := chromedp.Run(tabCtx, chromedp.ActionFunc(func(ctx context.Context) error { + chromeCtx := chromedp.FromContext(ctx) + if chromeCtx == nil || chromeCtx.Target == nil { + return errors.New("tab target unavailable") + } + resolvedTargetID := string(chromeCtx.Target.TargetID) + if strings.TrimSpace(resolvedTargetID) == "" { + return errors.New("tab target unavailable") + } + if strings.TrimSpace(tab.TargetID) == "" { + tab.TargetID = resolvedTargetID + return nil + } + if !strings.EqualFold(strings.TrimSpace(tab.TargetID), strings.TrimSpace(resolvedTargetID)) { + return fmt.Errorf("tab target mismatch: expected %s got %s", strings.TrimSpace(tab.TargetID), strings.TrimSpace(resolvedTargetID)) + } + return nil + })); err != nil { + cancel() + return nil, wrapRuntimeHangError(err) + } + activateCtx, activateCancel, activateErr := session.newBrowserExecutorContext(5 * time.Second) + if activateErr == nil { + _ = targetpkg.ActivateTarget(targetpkg.ID(tab.TargetID)).Do(activateCtx) + activateCancel() + } + if err := session.attachTab(tab); err != nil { + cancel() + return nil, err + } + session.mu.Lock() + session.tabs[tab.TargetID] = tab + session.activeTarget = tab.TargetID + session.mu.Unlock() + return tab, nil +} + +func (session *Session) attachTab(tab *sessionTab) error { + chromedp.ListenTarget(tab.ctx, func(ev any) { + switch event 
:= ev.(type) { + case *pagepkg.EventJavascriptDialogOpening: + session.mu.Lock() + session.pendingDialogs[tab.TargetID] = PendingDialog{ + Message: strings.TrimSpace(event.Message), + Type: string(event.Type), + ExpiresAt: time.Now().Add(5 * time.Minute), + } + session.mu.Unlock() + case *fetch.EventRequestPaused: + go session.handlePausedRequest(tab, event) + } + }) + return nil +} + +func (session *Session) enableRequestInterception(tab *sessionTab) error { + return session.runOnTab(tab, 5*time.Second, chromedp.ActionFunc(func(ctx context.Context) error { + return fetch.Enable().WithPatterns([]*fetch.RequestPattern{ + { + URLPattern: "*", + RequestStage: fetch.RequestStageRequest, + }, + }).Do(ctx) + })) +} + +func (session *Session) ensureRequestInterception(tab *sessionTab) error { + if tab == nil { + return errors.New("tab unavailable") + } + tab.mu.RLock() + enabled := tab.fetchEnabled + tab.mu.RUnlock() + if enabled { + return nil + } + if err := session.enableRequestInterception(tab); err != nil { + return err + } + tab.mu.Lock() + tab.fetchEnabled = true + tab.mu.Unlock() + return nil +} + +func (session *Session) handlePausedRequest(tab *sessionTab, event *fetch.EventRequestPaused) { + if tab == nil || event == nil { + return + } + requestURL := "" + if event.Request != nil { + requestURL = strings.TrimSpace(event.Request.URL) + } + err := session.runOnTab(tab, 5*time.Second, chromedp.ActionFunc(func(ctx context.Context) error { + if event.Request == nil { + return fetch.ContinueRequest(event.RequestID).Do(ctx) + } + options := session.optionsSnapshot() + if err := assertRequestURLAllowed(ctx, requestURL, options.SSRFRules); err != nil { + blockedErr := fmt.Errorf("blocked by SSRF policy for %s: %w", requestURL, err) + setBlockedRequestError(tab, blockedErr) + zap.L().Warn( + "browser request blocked by ssrf policy", + append(session.logFields(), + zap.String("targetId", tab.TargetID), + sanitizedURLField("url", requestURL), + zap.Error(err), + )..., + ) 
+ return fetch.FailRequest(event.RequestID, network.ErrorReasonBlockedByClient).Do(ctx) + } + return fetch.ContinueRequest(event.RequestID).Do(ctx) + })) + if err != nil { + zap.L().Warn( + "browser request interception handling failed", + append(session.logFields(), + zap.String("targetId", tab.TargetID), + sanitizedURLField("url", requestURL), + zap.Error(err), + )..., + ) + } +} + +func (session *Session) resolveTab(targetID string, allowActive bool) (*sessionTab, error) { + session.mu.Lock() + defer session.mu.Unlock() + targetID = strings.TrimSpace(targetID) + if targetID != "" { + tab, ok := session.tabs[targetID] + if !ok { + return nil, errors.New("tab not found") + } + return tab, nil + } + if allowActive && session.activeTarget != "" { + if tab, ok := session.tabs[session.activeTarget]; ok { + return tab, nil + } + } + for _, tab := range session.tabs { + return tab, nil + } + return nil, errNoOpenTab +} + +func (session *Session) closeTab(targetID string, timeout time.Duration) (ActionResult, error) { + tab, err := session.resolveTab(targetID, false) + if err != nil { + return ActionResult{}, err + } + cancelTabContexts(tab) + session.mu.Lock() + runtime := session.runtime + session.mu.Unlock() + if runtime != nil { + runCtx, cancel, err := session.newBrowserExecutorContext(normalizeTimeout(timeout, 15*time.Second)) + if err == nil { + _ = targetpkg.CloseTarget(targetpkg.ID(tab.TargetID)).Do(runCtx) + cancel() + } + } + session.detachTab(tab.TargetID) + return ActionResult{ + OK: true, + TargetID: tab.TargetID, + Closed: true, + }, nil +} + +func (session *Session) detachTab(targetID string) { + session.mu.Lock() + defer session.mu.Unlock() + targetID = strings.TrimSpace(targetID) + if targetID == "" { + return + } + delete(session.tabs, targetID) + delete(session.pendingDialogs, targetID) + if session.activeTarget == targetID { + session.activeTarget = "" + for _, tab := range session.tabs { + session.activeTarget = tab.TargetID + break + } + } +} + 
+func (session *Session) browserPageTargets(timeout time.Duration) (map[string]*targetpkg.Info, error) { + execCtx, cancel, err := session.newBrowserExecutorContext(timeout) + if err != nil { + return nil, err + } + defer cancel() + infos, err := targetpkg.GetTargets().Do(execCtx) + if err != nil { + return nil, err + } + result := map[string]*targetpkg.Info{} + for _, info := range infos { + if info == nil || info.Type != "page" { + continue + } + result[string(info.TargetID)] = info + } + return result, nil +} + +func (session *Session) setActiveTarget(targetID string) { + session.mu.Lock() + defer session.mu.Unlock() + session.activeTarget = strings.TrimSpace(targetID) +} + +func (session *Session) snapshotPageTargets() (map[string]*targetpkg.Info, error) { + return session.browserPageTargets(3 * time.Second) +} + +func (session *Session) prepareNewTabWaiter(parent context.Context, tab *sessionTab) (*newTabWaiter, error) { + if tab == nil { + return nil, errors.New("tab unavailable") + } + tab.mu.RLock() + baseCtx := tab.ctx + tab.mu.RUnlock() + if baseCtx == nil { + return nil, errors.New("tab context unavailable") + } + waitCtx, cancel := context.WithCancel(baseCtx) + var stop func() bool + if parent != nil { + if err := parent.Err(); err != nil { + cancel() + return nil, err + } + stop = context.AfterFunc(parent, cancel) + } + return &newTabWaiter{ + ctx: waitCtx, + ids: chromedp.WaitNewTarget(waitCtx, func(info *targetpkg.Info) bool { + return info != nil && info.Type == "page" + }), + cancel: cancel, + stop: stop, + }, nil +} + +func (session *Session) waitForNewTab(waiter *newTabWaiter, timeout time.Duration) (*sessionTab, bool) { + if waiter == nil { + return nil, false + } + timeout = normalizeTimeout(timeout, 1500*time.Millisecond) + timer := time.NewTimer(timeout) + defer timer.Stop() + select { + case <-waiter.ctx.Done(): + return nil, false + case <-timer.C: + return nil, false + case targetID, ok := <-waiter.ids: + if !ok { + return nil, false + } + 
trimmedTargetID := strings.TrimSpace(string(targetID)) + if trimmedTargetID == "" { + return nil, false + } + tab, err := session.attachTargetAsTab(trimmedTargetID, false) + if err != nil { + return nil, false + } + return tab, true + } +} + +func (session *Session) resolveReusableTargetID() (string, error) { + session.mu.Lock() + runtime := session.runtime + managedTabs := len(session.tabs) + session.mu.Unlock() + if runtime == nil || managedTabs > 0 { + return "", nil + } + infos, err := session.browserPageTargets(3 * time.Second) + if err != nil { + return "", err + } + return pickReusableTargetID(mapTargetInfos(infos)), nil +} + +func (session *Session) assertURLAllowed(rawURL string) error { + options := session.optionsSnapshot() + return assertNavigationURLAllowed(rawURL, options.SSRFRules) +} + +func AssertURLAllowed(rawURL string, policy SSRFPolicy) error { + return assertNavigationURLAllowed(rawURL, policy) +} + +func assertNavigationURLAllowed(rawURL string, policy SSRFPolicy) error { + ctx, cancel := context.WithTimeout(context.Background(), defaultSSRFValidationTimeout) + defer cancel() + return assertURLAllowedWithSchemes(ctx, rawURL, policy, map[string]struct{}{ + "http": {}, + "https": {}, + }, true) +} + +func assertRequestURLAllowed(ctx context.Context, rawURL string, policy SSRFPolicy) error { + return assertURLAllowedWithSchemes(ctx, rawURL, policy, map[string]struct{}{ + "http": {}, + "https": {}, + "ws": {}, + "wss": {}, + }, false) +} + +func assertURLAllowedWithSchemes(ctx context.Context, rawURL string, policy SSRFPolicy, allowedSchemes map[string]struct{}, rejectUnsupportedScheme bool) error { + trimmed := strings.TrimSpace(rawURL) + if trimmed == "" { + return errors.New("targetUrl is required") + } + parsed, err := url.Parse(trimmed) + if err != nil { + return fmt.Errorf("invalid url: %w", err) + } + scheme := strings.ToLower(strings.TrimSpace(parsed.Scheme)) + if _, ok := allowedSchemes[scheme]; !ok { + if rejectUnsupportedScheme { + 
return errors.New("only http(s) urls are supported") + } + return nil + } + return assertParsedURLAllowed(ctx, parsed, policy) +} + +func assertParsedURLAllowed(ctx context.Context, parsed *url.URL, policy SSRFPolicy) error { + if parsed == nil { + return errors.New("invalid url") + } + hostname := strings.ToLower(strings.TrimSpace(parsed.Hostname())) + if hostname == "" { + return errors.New("url hostname is required") + } + if isHostnameAllowed(hostname, policy) { + return nil + } + if policy.DangerouslyAllowPrivateNetwork { + return nil + } + if hostname == "localhost" || strings.HasSuffix(hostname, ".local") || strings.HasSuffix(hostname, ".internal") { + return fmt.Errorf("blocked private hostname: %s", hostname) + } + if ip := net.ParseIP(hostname); ip != nil { + if isPrivateOrLocalIP(ip) { + return fmt.Errorf("blocked private IP: %s", hostname) + } + return nil + } + if ctx == nil { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(context.Background(), defaultSSRFValidationTimeout) + defer cancel() + } + records, err := lookupIPAddrsForHost(ctx, hostname) + if err != nil { + return fmt.Errorf("could not validate hostname %s: %w", hostname, err) + } + if len(records) == 0 { + return fmt.Errorf("could not validate hostname %s: no IPs returned", hostname) + } + for _, record := range records { + if isPrivateOrLocalIP(record.IP) { + return fmt.Errorf("blocked hostname resolving to private IP: %s -> %s", hostname, record.IP.String()) + } + } + return nil +} + +func (session *Session) assertObservedURLAllowed(rawURL string) error { + trimmed := strings.TrimSpace(rawURL) + if trimmed == "" { + return nil + } + return session.assertURLAllowed(trimmed) +} + +func (session *Session) ensureCookiesForURL(ctx context.Context, tab *sessionTab, targetURL string) error { + if session == nil || tab == nil { + return nil + } + options := session.optionsSnapshot() + if options.Cookies == nil { + return nil + } + cookies, err := 
options.Cookies.ResolveCookiesForURL(ctx, targetURL) + if err != nil { + return &ConnectorCookieError{URL: targetURL, Err: err} + } + if len(cookies) == 0 { + return nil + } + syncKeys := cookieSyncKeys(targetURL, cookies) + fingerprint := cookieFingerprint(cookies) + session.mu.Lock() + if hasCookieSyncFingerprint(session.cookieSync, syncKeys, fingerprint) { + session.mu.Unlock() + return nil + } + session.mu.Unlock() + runCtx, cancel, err := session.newBrowserExecutorContext(10 * time.Second) + if err != nil { + return &ConnectorCookieError{URL: targetURL, Err: err} + } + defer cancel() + if err := SetCookiesOnBrowser(runCtx, targetURL, cookies); err != nil { + return &ConnectorCookieError{URL: targetURL, Err: err} + } + session.mu.Lock() + rememberCookieSyncFingerprint(session.cookieSync, syncKeys, fingerprint) + session.mu.Unlock() + return nil +} + +func (session *Session) ensureCookiesForURLOnBrowser(ctx context.Context, targetURL string) error { + options := session.optionsSnapshot() + if options.Cookies == nil { + return nil + } + cookies, err := options.Cookies.ResolveCookiesForURL(ctx, targetURL) + if err != nil { + return &ConnectorCookieError{URL: targetURL, Err: err} + } + if len(cookies) == 0 { + return nil + } + syncKeys := cookieSyncKeys(targetURL, cookies) + fingerprint := cookieFingerprint(cookies) + session.mu.Lock() + if hasCookieSyncFingerprint(session.cookieSync, syncKeys, fingerprint) { + session.mu.Unlock() + return nil + } + runtime := session.runtime + session.mu.Unlock() + if runtime == nil { + return errors.New("browser runtime unavailable") + } + runCtx, cancel, err := session.newBrowserExecutorContext(10 * time.Second) + if err != nil { + return err + } + defer cancel() + if err := SetCookiesOnBrowser(runCtx, targetURL, cookies); err != nil { + return &ConnectorCookieError{URL: targetURL, Err: err} + } + session.mu.Lock() + rememberCookieSyncFingerprint(session.cookieSync, syncKeys, fingerprint) + session.mu.Unlock() + return nil +} + 
+func (session *Session) collectActionResult(tab *sessionTab, limit int, timeout time.Duration, stateRequired bool) (ActionResult, error) { + if err := consumeBlockedRequestError(tab); err != nil { + return ActionResult{}, err + } + pageState, err := session.collectPageState(tab, limit, timeout) + if err != nil && isTargetLookupError(err) { + rebound := session.rebindTabAfterNavigation(tab, tabURL(tab)) + if rebound != nil { + tab = rebound + pageState, err = session.collectPageState(tab, limit, timeout) + } + } else if err != nil && shouldDeferStateCapture(err) { + rebound := session.rebindTabFromBrowserTargets(tab, tabURL(tab)) + if rebound != nil && rebound.TargetID != "" && rebound.TargetID != tab.TargetID { + tab = rebound + pageState, err = session.collectPageState(tab, limit, timeout) + } + } + if err == nil { + if blockedErr := consumeBlockedRequestError(tab); blockedErr != nil { + return ActionResult{}, blockedErr + } + return resultFromPageState(tab, pageState), nil + } + if stateRequired || !shouldDeferStateCapture(err) { + return ActionResult{}, err + } + session.refreshTabMetadataFromBrowser(tab) + storedURL := tabURL(tab) + url, title := session.readPageMetadata(tab, 1200*time.Millisecond) + url = preferredPageURL(url, storedURL) + zap.L().Warn( + "browser state capture deferred", + append(session.logFields(), + zap.String("targetId", tab.TargetID), + sanitizedURLField("url", url), + zap.Duration("timeout", timeout), + zap.Error(err), + )..., + ) + if blockedErr := consumeBlockedRequestError(tab); blockedErr != nil { + return ActionResult{}, blockedErr + } + return ActionResult{ + OK: true, + TargetID: tab.TargetID, + URL: url, + Title: title, + StateAvailable: false, + StateError: strings.TrimSpace(err.Error()), + }, nil +} + +func (session *Session) collectPageState(tab *sessionTab, limit int, timeout time.Duration) (*PageState, error) { + snapshotCtx, snapshotCancel, err := newTabRunContext(tab, timeout) + if err != nil { + return nil, err + } + 
defer snapshotCancel() + capture, err := collectSnapshot(tab, snapshotCtx, limit, timeout) + if err != nil { + return nil, err + } + pageState := &PageState{ + Version: capture.Version, + URL: capture.URL, + Title: capture.Title, + Items: capture.Items, + ItemCount: len(capture.Items), + Truncated: capture.Truncated, + ViewportOnly: capture.ViewportOnly, + CapturedAt: time.Now().Format(time.RFC3339), + } + tab.mu.Lock() + tab.refs = capture.Refs + tab.lastURL = capture.URL + tab.title = capture.Title + tab.lastState = pageState + tab.stateVersion = capture.Version + tab.mu.Unlock() + return pageState, nil +} + +func resultFromPageState(tab *sessionTab, pageState *PageState) ActionResult { + result := ActionResult{ + OK: true, + StateAvailable: true, + } + if tab != nil { + result.TargetID = tab.TargetID + } + if pageState != nil { + result.URL = pageState.URL + result.Title = pageState.Title + result.StateVersion = pageState.Version + result.State = pageState + result.Items = pageState.Items + } + return result +} + +func collectSnapshot(tab *sessionTab, runBaseCtx context.Context, limit int, timeout time.Duration) (*snapshotCapture, error) { + if limit <= 0 { + limit = defaultSnapshotLimit + } + if timeout <= 0 { + timeout = 10 * time.Second + } + maxScan := limit * 20 + if maxScan < 500 { + maxScan = 500 + } + if maxScan > 3000 { + maxScan = 3000 + } + timeBudgetMs := int(timeout / time.Millisecond / 2) + if timeBudgetMs < 250 { + timeBudgetMs = 250 + } + if timeBudgetMs > 1500 { + timeBudgetMs = 1500 + } + script := fmt.Sprintf(`(() => { + const limit = %d; + const maxScan = %d; + const timeBudgetMs = %d; + const startedAt = (typeof performance !== "undefined" && performance.now) ? performance.now() : Date.now(); + let truncated = false; + const now = () => ((typeof performance !== "undefined" && performance.now) ? 
performance.now() : Date.now()); + const inferRole = (el) => { + const explicit = (el.getAttribute("role") || "").trim(); + if (explicit) return explicit.toLowerCase(); + const tag = el.tagName.toLowerCase(); + if (tag === "a") return "link"; + if (tag === "button") return "button"; + if (tag === "textarea" || tag === "input") return "textbox"; + if (tag === "select") return "combobox"; + return "element"; + }; + const cssPath = (el) => { + if (el.id) return "#" + CSS.escape(el.id); + const parts = []; + let node = el; + while (node && node.nodeType === 1 && parts.length < 6) { + let part = node.tagName.toLowerCase(); + if (node.classList && node.classList.length > 0) { + part += "." + Array.from(node.classList).slice(0, 2).map((item) => CSS.escape(item)).join("."); + } + const parent = node.parentElement; + if (parent) { + const siblings = Array.from(parent.children).filter((candidate) => candidate.tagName === node.tagName); + if (siblings.length > 1) { + part += ":nth-of-type(" + (siblings.indexOf(node) + 1) + ")"; + } + } + parts.unshift(part); + node = parent; + } + return parts.join(" > "); + }; + const visible = (el) => { + const style = window.getComputedStyle(el); + const rect = el.getBoundingClientRect(); + return style && style.visibility !== "hidden" && style.display !== "none" && rect.width > 0 && rect.height > 0; + }; + const elements = document.querySelectorAll('a,button,input,textarea,select,summary,[role="button"],[role="link"],[role="menuitem"],[tabindex]:not([tabindex="-1"])'); + const candidates = []; + const scanLimit = Math.min(elements.length, maxScan); + if (elements.length > scanLimit) { + truncated = true; + } + for (let index = 0; index < scanLimit; index += 1) { + if (candidates.length >= limit) { + truncated = true; + break; + } + if ((now() - startedAt) > timeBudgetMs) { + truncated = true; + break; + } + const el = elements[index]; + if (!el) continue; + try { + if (!visible(el)) continue; + const text = ((el.value || el.textContent || 
"") + "").replace(/\s+/g, " ").trim(); + const name = ((el.getAttribute("aria-label") || el.getAttribute("title") || el.placeholder || text) + "").replace(/\s+/g, " ").trim(); + candidates.push({ + selector: cssPath(el), + role: inferRole(el), + name, + text, + }); + } catch (_err) { + truncated = true; + } + } + return { + url: document.location.toString(), + title: document.title, + truncated, + candidates, + }; + })()`, limit, maxScan, timeBudgetMs) + var payload struct { + URL string `json:"url"` + Title string `json:"title"` + Truncated bool `json:"truncated"` + Candidates []struct { + Selector string `json:"selector"` + Role string `json:"role"` + Name string `json:"name"` + Text string `json:"text"` + } `json:"candidates"` + } + runCtx, cancel := context.WithTimeout(runBaseCtx, timeout) + defer cancel() + captureStartedAt := time.Now() + if err := chromedp.Run(runCtx, chromedp.EvaluateAsDevTools(script, &payload)); err != nil { + zap.L().Warn( + "browser snapshot capture failed", + zap.String("targetId", tab.TargetID), + zap.Duration("elapsed", time.Since(captureStartedAt).Round(time.Millisecond)), + zap.Duration("timeout", timeout), + zap.Error(err), + ) + return nil, err + } + currentURL := strings.TrimSpace(payload.URL) + title := strings.TrimSpace(payload.Title) + version, refsAllocated := allocateStateRefs(tab, len(payload.Candidates)) + items := make([]SnapshotItem, 0, len(payload.Candidates)) + refs := map[string]snapshotRef{} + countByRoleName := map[string]int{} + for index, item := range payload.Candidates { + ref := refsAllocated[index] + key := item.Role + "\n" + item.Name + nth := countByRoleName[key] + countByRoleName[key] = nth + 1 + items = append(items, SnapshotItem{ + Ref: ref, + Role: item.Role, + Name: item.Name, + Text: item.Text, + Depth: 0, + Nth: nth, + }) + refs[ref] = snapshotRef{ + Selector: item.Selector, + Role: item.Role, + Name: item.Name, + Nth: nth, + } + } + return &snapshotCapture{ + Version: version, + URL: currentURL, + 
Title: title, + Items: items, + Refs: refs, + Truncated: payload.Truncated || len(payload.Candidates) >= limit, + ViewportOnly: false, + }, nil +} + +func allocateStateRefs(tab *sessionTab, count int) (uint64, []string) { + tab.mu.Lock() + defer tab.mu.Unlock() + tab.stateVersion++ + version := tab.stateVersion + refs := make([]string, count) + for index := 0; index < count; index++ { + tab.nextRefID++ + refs[index] = fmt.Sprintf("e%d", tab.nextRefID) + } + return version, refs +} + +func clearTabState(tab *sessionTab) { + if tab == nil { + return + } + tab.mu.Lock() + tab.refs = map[string]snapshotRef{} + tab.lastState = nil + tab.mu.Unlock() +} + +func resolveRefSelector(tab *sessionTab, ref string) (string, error) { + if tab == nil { + return "", errors.New("tab unavailable") + } + ref = strings.TrimSpace(ref) + if ref == "" { + return "", errors.New("ref is required") + } + tab.mu.RLock() + defer tab.mu.RUnlock() + item, ok := tab.refs[ref] + if !ok { + return "", &InvalidRefError{Ref: ref} + } + if strings.TrimSpace(item.Selector) == "" { + return "", errors.New("ref selector unavailable") + } + return item.Selector, nil +} + +func (session *Session) readPageMetadata(tab *sessionTab, timeout time.Duration) (string, string) { + if session == nil || tab == nil { + return "", "" + } + if timeout <= 0 { + timeout = 1200 * time.Millisecond + } + baseCtx, cancel, err := newTabRunContext(tab, timeout) + if err != nil { + tab.mu.RLock() + defer tab.mu.RUnlock() + return strings.TrimSpace(tab.lastURL), strings.TrimSpace(tab.title) + } + defer cancel() + var metadata struct { + URL string `json:"url"` + Title string `json:"title"` + } + if err := chromedp.Run(baseCtx, chromedp.EvaluateAsDevTools(`({ + url: document.location.toString(), + title: document.title, + })`, &metadata)); err == nil { + currentURL := strings.TrimSpace(metadata.URL) + title := strings.TrimSpace(metadata.Title) + tab.mu.Lock() + currentURL = preferredPageURL(currentURL, tab.lastURL) + if 
currentURL != "" { + tab.lastURL = currentURL + } + if title != "" { + tab.title = title + } + tab.mu.Unlock() + return currentURL, title + } + tab.mu.RLock() + defer tab.mu.RUnlock() + return strings.TrimSpace(tab.lastURL), strings.TrimSpace(tab.title) +} + +func (session *Session) newEphemeralTargetContext(targetID string) (context.Context, context.CancelFunc, error) { + targetID = strings.TrimSpace(targetID) + if targetID == "" { + return nil, nil, errors.New("tab target unavailable") + } + session.mu.Lock() + runtime := session.runtime + session.mu.Unlock() + if runtime == nil { + return nil, nil, errors.New("browser runtime unavailable") + } + ctx, cancel := chromedp.NewContext(runtime.BrowserContext(), chromedp.WithTargetID(targetpkg.ID(targetID))) + return ctx, cancel, nil +} + +func (session *Session) newBrowserExecutorContext(timeout time.Duration) (context.Context, context.CancelFunc, error) { + session.mu.Lock() + runtime := session.runtime + session.mu.Unlock() + if runtime == nil { + return nil, nil, errors.New("browser runtime unavailable") + } + baseCtx := runtime.BrowserContext() + if baseCtx == nil { + return nil, nil, errors.New("browser context unavailable") + } + chromeCtx := chromedp.FromContext(baseCtx) + if chromeCtx == nil || chromeCtx.Browser == nil { + return nil, nil, errors.New("browser executor unavailable") + } + runCtx, cancel := context.WithTimeout(baseCtx, timeout) + return cdp.WithExecutor(runCtx, chromeCtx.Browser), cancel, nil +} + +func (session *Session) browserTargetInfo(targetID string, timeout time.Duration) *targetpkg.Info { + targetID = strings.TrimSpace(targetID) + if targetID == "" { + return nil + } + targets, err := session.browserPageTargets(timeout) + if err != nil { + return nil + } + return targets[targetID] +} + +func (session *Session) refreshTabMetadataFromBrowser(tab *sessionTab) { + if tab == nil { + return + } + info := session.browserTargetInfo(tab.TargetID, 1500*time.Millisecond) + if info == nil { + return 
+ } + currentURL := strings.TrimSpace(info.URL) + title := strings.TrimSpace(info.Title) + tab.mu.Lock() + currentURL = preferredPageURL(currentURL, tab.lastURL) + if currentURL != "" { + tab.lastURL = currentURL + } + if title != "" { + tab.title = title + } + tab.mu.Unlock() +} + +func (session *Session) observeActionNavigation(tab *sessionTab, before map[string]*targetpkg.Info, previousURL string, timeout time.Duration) (*sessionTab, bool, error) { + if tab == nil { + return tab, false, nil + } + timeout = normalizeTimeout(timeout, 5*time.Second) + deadline := time.Now().Add(timeout) + for { + if err := consumeBlockedRequestError(tab); err != nil { + return tab, false, err + } + targets, err := session.browserPageTargets(500 * time.Millisecond) + if err == nil { + if info := targets[tab.TargetID]; info != nil { + currentURL := strings.TrimSpace(info.URL) + currentTitle := strings.TrimSpace(info.Title) + if shouldTreatNavigationAsComplete(currentURL, previousURL, "") { + if err := session.assertObservedURLAllowed(currentURL); err != nil { + return tab, false, err + } + tab.mu.Lock() + tab.lastURL = currentURL + if currentTitle != "" { + tab.title = currentTitle + } + tab.mu.Unlock() + rebound := session.rebindTabAfterNavigation(tab, currentURL) + if rebound != nil { + tab = rebound + } + storeNavigation(tab, currentURL) + return tab, true, nil + } + } + for targetID, info := range targets { + if info == nil || strings.TrimSpace(targetID) == "" || targetID == tab.TargetID { + continue + } + if _, existed := before[targetID]; existed { + continue + } + currentURL := strings.TrimSpace(info.URL) + currentTitle := strings.TrimSpace(info.Title) + if !shouldTreatNavigationAsComplete(currentURL, previousURL, "") { + continue + } + if err := session.assertObservedURLAllowed(currentURL); err != nil { + return tab, false, err + } + rebound, attachErr := session.attachTargetAsTab(targetID, false) + if attachErr != nil || rebound == nil { + continue + } + rebound.mu.Lock() + 
rebound.lastURL = currentURL + if currentTitle != "" { + rebound.title = currentTitle + } + rebound.mu.Unlock() + if rebound.TargetID != tab.TargetID { + session.detachTab(tab.TargetID) + } + storeNavigation(rebound, currentURL) + return rebound, true, nil + } + } else { + info := session.browserTargetInfo(tab.TargetID, 500*time.Millisecond) + if info == nil { + if time.Now().After(deadline) { + if blockedErr := consumeBlockedRequestError(tab); blockedErr != nil { + return tab, false, blockedErr + } + return tab, false, nil + } + time.Sleep(150 * time.Millisecond) + continue + } + currentURL := strings.TrimSpace(info.URL) + currentTitle := strings.TrimSpace(info.Title) + if shouldTreatNavigationAsComplete(currentURL, previousURL, "") { + if err := session.assertObservedURLAllowed(currentURL); err != nil { + return tab, false, err + } + tab.mu.Lock() + tab.lastURL = currentURL + if currentTitle != "" { + tab.title = currentTitle + } + tab.mu.Unlock() + rebound := session.rebindTabAfterNavigation(tab, currentURL) + if rebound != nil { + tab = rebound + } + storeNavigation(tab, currentURL) + return tab, true, nil + } + } + if time.Now().After(deadline) { + if blockedErr := consumeBlockedRequestError(tab); blockedErr != nil { + return tab, false, blockedErr + } + return tab, false, nil + } + time.Sleep(150 * time.Millisecond) + } +} + +func mapTargetInfos(targets map[string]*targetpkg.Info) []*targetpkg.Info { + if len(targets) == 0 { + return nil + } + result := make([]*targetpkg.Info, 0, len(targets)) + for _, info := range targets { + if info != nil { + result = append(result, info) + } + } + return result +} + +func (session *Session) runOnTab(tab *sessionTab, timeout time.Duration, actions ...chromedp.Action) error { + return session.runOnTabFunc(tab, timeout, func(ctx context.Context) error { + return chromedp.Run(ctx, actions...) 
+ }) +} + +func (session *Session) runOnTabFunc(tab *sessionTab, timeout time.Duration, fn func(context.Context) error) error { + if session == nil || tab == nil { + return errors.New("tab unavailable") + } + baseCtx, cancel, err := newTabRunContext(tab, timeout) + if err != nil { + return err + } + defer cancel() + return fn(baseCtx) +} + +func newTabRunContextWithParent(parent context.Context, tab *sessionTab, timeout time.Duration) (context.Context, context.CancelFunc, error) { + runCtx, cancel, err := newTabRunContext(tab, timeout) + if err != nil { + return nil, nil, err + } + if parent == nil { + return runCtx, cancel, nil + } + if err := parent.Err(); err != nil { + cancel() + return nil, nil, err + } + stop := context.AfterFunc(parent, cancel) + return runCtx, func() { + stop() + cancel() + }, nil +} + +func newTabRunContext(tab *sessionTab, timeout time.Duration) (context.Context, context.CancelFunc, error) { + if tab == nil { + return nil, nil, errors.New("tab unavailable") + } + if timeout <= 0 { + timeout = 10 * time.Second + } + tab.mu.RLock() + baseCtx := tab.ctx + tab.mu.RUnlock() + if baseCtx == nil { + return nil, nil, errors.New("tab context unavailable") + } + runCtx, cancel := context.WithTimeout(baseCtx, timeout) + return runCtx, cancel, nil +} + +func rememberTabCleanup(tab *sessionTab, cancel context.CancelFunc) { + if tab == nil || cancel == nil { + return + } + tab.mu.Lock() + tab.cleanupCancels = append(tab.cleanupCancels, cancel) + tab.mu.Unlock() +} + +func cancelTabContexts(tab *sessionTab) { + if tab == nil { + return + } + tab.mu.Lock() + cancels := append([]context.CancelFunc(nil), tab.cleanupCancels...) 
+ if tab.cancel != nil { + cancels = append(cancels, tab.cancel) + } + tab.cleanupCancels = nil + tab.cancel = nil + tab.ctx = nil + tab.mu.Unlock() + for index := len(cancels) - 1; index >= 0; index-- { + if cancels[index] != nil { + cancels[index]() + } + } +} + +func (session *Session) navigateTab(tab *sessionTab, targetURL string, timeout time.Duration) (*sessionTab, error) { + previousURL := tabURL(tab) + if err := session.ensureRequestInterception(tab); err != nil { + return nil, err + } + commandStartedAt := time.Now() + var errorText string + script := fmt.Sprintf(`(() => { window.location.href = %q; return true; })()`, strings.TrimSpace(targetURL)) + err := session.runOnTab(tab, timeout, chromedp.EvaluateAsDevTools(script, nil)) + if err != nil { + zap.L().Warn( + "browser navigate command failed", + zap.String("targetId", tab.TargetID), + sanitizedURLField("url", targetURL), + sanitizedURLField("previousURL", previousURL), + zap.Duration("elapsed", time.Since(commandStartedAt).Round(time.Millisecond)), + zap.Error(err), + ) + return nil, err + } + if errorText != "" { + navigationErr := fmt.Errorf("page load error %s", errorText) + zap.L().Warn( + "browser navigate command returned load error", + zap.String("targetId", tab.TargetID), + sanitizedURLField("url", targetURL), + sanitizedURLField("previousURL", previousURL), + zap.Duration("elapsed", time.Since(commandStartedAt).Round(time.Millisecond)), + zap.Error(navigationErr), + ) + return nil, navigationErr + } + settleTimeout := minDuration(3*time.Second, timeout) + finalURL, ok, observeErr := session.observeNavigationURL(tab, previousURL, targetURL, settleTimeout) + if observeErr != nil { + return nil, observeErr + } + if !ok { + finalURL = strings.TrimSpace(targetURL) + } + if err := session.assertObservedURLAllowed(finalURL); err != nil { + return nil, err + } + tab = session.rebindTabAfterNavigation(tab, finalURL) + storeNavigation(tab, finalURL) + return tab, nil +} + +func (session *Session) 
rebindTabAfterNavigation(tab *sessionTab, finalURL string) *sessionTab { + if tab == nil { + return nil + } + targets, err := session.snapshotPageTargets() + if err != nil { + return tab + } + if _, ok := targets[tab.TargetID]; ok { + return tab + } + expectedURL := strings.TrimSpace(finalURL) + if expectedURL == "" { + return tab + } + for targetID, info := range targets { + if info == nil || strings.TrimSpace(targetID) == "" { + continue + } + if !shouldTreatNavigationAsComplete(info.URL, "", expectedURL) { + continue + } + rebound, attachErr := session.attachTargetAsTab(targetID, false) + if attachErr != nil || rebound == nil { + continue + } + if rebound.TargetID != tab.TargetID { + session.detachTab(tab.TargetID) + } + return rebound + } + return tab +} + +func (session *Session) rebindTabFromBrowserTargets(tab *sessionTab, previousURL string) *sessionTab { + if tab == nil { + return nil + } + targets, err := session.browserPageTargets(1500 * time.Millisecond) + if err != nil || len(targets) == 0 { + return nil + } + currentTargetID := strings.TrimSpace(tab.TargetID) + previousURL = strings.TrimSpace(previousURL) + type candidate struct { + targetID string + info *targetpkg.Info + } + candidates := make([]candidate, 0, len(targets)) + for targetID, info := range targets { + if info == nil { + continue + } + targetID = strings.TrimSpace(targetID) + if targetID == "" || targetID == currentTargetID { + continue + } + currentURL := strings.TrimSpace(info.URL) + if currentURL == "" || isReusablePageURL(currentURL) || urlsEqual(currentURL, previousURL) { + continue + } + candidates = append(candidates, candidate{targetID: targetID, info: info}) + } + if len(candidates) == 0 { + return nil + } + sort.SliceStable(candidates, func(left int, right int) bool { + leftInfo := candidates[left].info + rightInfo := candidates[right].info + if leftInfo.Attached != rightInfo.Attached { + return !leftInfo.Attached + } + return candidates[left].targetID < 
candidates[right].targetID + }) + chosen := candidates[0] + rebound, attachErr := session.attachTargetAsTab(chosen.targetID, false) + if attachErr != nil || rebound == nil { + return nil + } + currentURL := strings.TrimSpace(chosen.info.URL) + currentTitle := strings.TrimSpace(chosen.info.Title) + rebound.mu.Lock() + if currentURL != "" { + rebound.lastURL = currentURL + } + if currentTitle != "" { + rebound.title = currentTitle + } + rebound.mu.Unlock() + if rebound.TargetID != currentTargetID { + session.detachTab(currentTargetID) + } + return rebound +} + +func (session *Session) observeNavigationURL(tab *sessionTab, previousURL string, targetURL string, timeout time.Duration) (string, bool, error) { + if timeout <= 0 { + timeout = 1500 * time.Millisecond + } + deadline := time.Now().Add(timeout) + for { + if err := consumeBlockedRequestError(tab); err != nil { + return "", false, err + } + currentURL, _ := session.readPageMetadata(tab, 1200*time.Millisecond) + err := error(nil) + if err == nil && shouldTreatNavigationAsComplete(currentURL, previousURL, targetURL) { + currentURL = strings.TrimSpace(currentURL) + if err := session.assertObservedURLAllowed(currentURL); err != nil { + return "", false, err + } + return currentURL, true, nil + } + if time.Now().After(deadline) { + break + } + time.Sleep(150 * time.Millisecond) + } + if err := consumeBlockedRequestError(tab); err != nil { + return "", false, err + } + return "", false, nil +} + +func storeNavigation(tab *sessionTab, finalURL string) { + if tab == nil { + return + } + tab.mu.Lock() + tab.lastURL = strings.TrimSpace(finalURL) + tab.title = "" + tab.refs = map[string]snapshotRef{} + tab.lastState = nil + tab.mu.Unlock() +} + +func (session *Session) waitOnTab(parent context.Context, tab *sessionTab, request WaitRequest, fallbackTimeout time.Duration) error { + timeout := normalizeTimeout(request.Timeout, fallbackTimeout) + if request.Time > 0 { + return sleepWithContext(parent, request.Time) + } + 
runCtx, cancel, err := newTabRunContextWithParent(parent, tab, timeout) + if err != nil { + return err + } + defer cancel() + if selector := strings.TrimSpace(request.Selector); selector != "" { + return chromedp.Run(runCtx, chromedp.PollFunction( + `(selector) => { + const el = document.querySelector(selector); + if (!el) return false; + const style = window.getComputedStyle(el); + const rect = el.getBoundingClientRect(); + return style && style.visibility !== "hidden" && style.display !== "none" && rect.width > 0 && rect.height > 0; + }`, + nil, + chromedp.WithPollingInterval(200*time.Millisecond), + chromedp.WithPollingTimeout(timeout), + chromedp.WithPollingArgs(selector), + )) + } + if text := strings.TrimSpace(request.Text); text != "" { + return chromedp.Run(runCtx, chromedp.PollFunction( + `(expected) => { + const text = document.body ? document.body.innerText : ""; + return text.includes(expected); + }`, + nil, + chromedp.WithPollingInterval(200*time.Millisecond), + chromedp.WithPollingTimeout(timeout), + chromedp.WithPollingArgs(text), + )) + } + if textGone := strings.TrimSpace(request.TextGone); textGone != "" { + return chromedp.Run(runCtx, chromedp.PollFunction( + `(expected) => { + const text = document.body ? 
document.body.innerText : ""; + return !text.includes(expected); + }`, + nil, + chromedp.WithPollingInterval(200*time.Millisecond), + chromedp.WithPollingTimeout(timeout), + chromedp.WithPollingArgs(textGone), + )) + } + if urlWait := strings.TrimSpace(request.URL); urlWait != "" { + return chromedp.Run(runCtx, chromedp.PollFunction( + `(expected) => document.location.toString() === expected`, + nil, + chromedp.WithPollingInterval(200*time.Millisecond), + chromedp.WithPollingTimeout(timeout), + chromedp.WithPollingArgs(urlWait), + )) + } + if fn := strings.TrimSpace(request.Fn); fn != "" { + return chromedp.Run(runCtx, chromedp.Poll( + fn, + nil, + chromedp.WithPollingInterval(200*time.Millisecond), + chromedp.WithPollingTimeout(timeout), + )) + } + return errors.New("wait requires at least one of: timeMs, text, textGone, selector, url, fn") +} + +func waitForTabExecutionContext(tab *sessionTab, timeout time.Duration) (string, error) { + if tab == nil { + return "", errors.New("tab unavailable") + } + timeout = normalizeTimeout(timeout, 3*time.Second) + deadline := time.Now().Add(timeout) + var lastErr error + for { + remaining := time.Until(deadline) + if remaining <= 0 { + if lastErr != nil { + return "", lastErr + } + return "", context.DeadlineExceeded + } + attemptTimeout := minDuration(remaining, 800*time.Millisecond) + runCtx, cancel := context.WithTimeout(tab.ctx, attemptTimeout) + var readyState string + err := chromedp.Run(runCtx, chromedp.EvaluateAsDevTools(`document.readyState`, &readyState)) + cancel() + if err == nil { + readyState = strings.TrimSpace(readyState) + if readyState == "" { + readyState = "unknown" + } + return readyState, nil + } + lastErr = err + time.Sleep(150 * time.Millisecond) + } +} + +func shouldRetryActionState(result ActionResult) bool { + currentURL := strings.TrimSpace(result.URL) + if currentURL == "" || isReusablePageURL(currentURL) { + return true + } + if !result.StateAvailable && result.State == nil { + return true + } + 
if result.State == nil { + return false + } + return strings.TrimSpace(result.Title) == "" && result.State.ItemCount == 0 +} + +func actionStateReason(result ActionResult) string { + currentURL := strings.TrimSpace(result.URL) + switch { + case currentURL == "": + return "missing-url" + case isReusablePageURL(currentURL): + return "placeholder-url" + case !result.StateAvailable && result.State == nil: + if strings.TrimSpace(result.StateError) != "" { + return "state-unavailable:" + strings.TrimSpace(result.StateError) + } + return "state-unavailable" + case result.State == nil: + return "state-missing" + case strings.TrimSpace(result.Title) == "" && result.State.ItemCount == 0: + return "empty-title-and-items" + default: + return "stable" + } +} + +func (session *Session) stabilizeActionResult(tab *sessionTab, result ActionResult, limit int, timeout time.Duration, maxWait time.Duration) ActionResult { + if tab == nil || maxWait <= 0 { + return result + } + deadline := time.Now().Add(maxWait) + attempts := 0 + for shouldRetryActionState(result) && time.Now().Before(deadline) { + time.Sleep(150 * time.Millisecond) + attempts++ + retried, err := session.collectActionResult(tab, limit, timeout, false) + if err != nil { + zap.L().Warn( + "browser action state stabilization failed", + append(session.logFields(), + zap.String("targetId", tab.TargetID), + zap.Int("attempts", attempts), + zap.String("reason", actionStateReason(result)), + zap.Error(err), + )..., + ) + break + } + result = retried + } + return result +} + +func (session *Session) actClick(tab *sessionTab, request ActRequest) error { + selector, err := resolveRefSelector(tab, request.Ref) + if err != nil { + return err + } + script := fmt.Sprintf(`(() => { + const el = document.querySelector(%q); + if (!el) throw new Error("element not found"); + el.scrollIntoView({ block: "center", inline: "center" }); + const invoke = () => { + el.focus?.(); + if (typeof PointerEvent === "function") { + el.dispatchEvent(new 
PointerEvent("pointerdown", { bubbles: true, cancelable: true, pointerType: "mouse", isPrimary: true, button: 0, buttons: 1 })); + el.dispatchEvent(new PointerEvent("pointerup", { bubbles: true, cancelable: true, pointerType: "mouse", isPrimary: true, button: 0, buttons: 0 })); + } + el.dispatchEvent(new MouseEvent("mousedown", { bubbles: true, cancelable: true, button: 0, buttons: 1 })); + el.dispatchEvent(new MouseEvent("mouseup", { bubbles: true, cancelable: true, button: 0, buttons: 0 })); + el.click(); + }; + setTimeout(invoke, 0); + return true; + })()`, selector) + return session.runOnTab(tab, normalizeTimeout(request.Timeout, 15*time.Second), chromedp.EvaluateAsDevTools(script, nil)) +} + +func (session *Session) actType(tab *sessionTab, request ActRequest) error { + selector, err := resolveRefSelector(tab, request.Ref) + if err != nil { + return err + } + script := fmt.Sprintf(`(() => { + const el = document.querySelector(%q); + if (!el) throw new Error("element not found"); + el.focus(); + const nextValue = String((el.value ?? 
"")) + %q; + el.value = nextValue; + el.dispatchEvent(new InputEvent("input", { bubbles: true, data: %q, inputType: "insertText" })); + el.dispatchEvent(new Event("change", { bubbles: true })); + return nextValue; + })()`, selector, request.Text, request.Text) + return session.runOnTab(tab, normalizeTimeout(request.Timeout, 15*time.Second), chromedp.EvaluateAsDevTools(script, nil)) +} + +func (session *Session) actPress(tab *sessionTab, request ActRequest) error { + if strings.TrimSpace(request.Key) == "" { + return errors.New("key is required") + } + script := fmt.Sprintf(`(() => { + const target = document.activeElement || document.body; + const key = %q; + setTimeout(() => { + const down = new KeyboardEvent("keydown", { key, bubbles: true }); + const press = new KeyboardEvent("keypress", { key, bubbles: true }); + const up = new KeyboardEvent("keyup", { key, bubbles: true }); + target.dispatchEvent(down); + target.dispatchEvent(press); + if (key === "Enter" && target && target.form) { + if (typeof target.form.requestSubmit === "function") { + target.form.requestSubmit(); + } else { + target.form.submit(); + } + } + target.dispatchEvent(up); + }, 0); + return true; + })()`, request.Key) + return session.runOnTab(tab, normalizeTimeout(request.Timeout, 15*time.Second), chromedp.EvaluateAsDevTools(script, nil)) +} + +func (session *Session) actHover(tab *sessionTab, request ActRequest) error { + selector, err := resolveRefSelector(tab, request.Ref) + if err != nil { + return err + } + return session.runOnTab(tab, normalizeTimeout(request.Timeout, 15*time.Second), chromedp.EvaluateAsDevTools(fmt.Sprintf(`(() => { const el = document.querySelector(%q); if (!el) throw new Error("element not found"); el.dispatchEvent(new MouseEvent("mouseover", {bubbles:true})); el.dispatchEvent(new MouseEvent("mouseenter", {bubbles:true})); })()`, selector), nil)) +} + +func (session *Session) actSelect(tab *sessionTab, request ActRequest) error { + selector, err := 
resolveRefSelector(tab, request.Ref) + if err != nil { + return err + } + return session.runOnTab(tab, normalizeTimeout(request.Timeout, 15*time.Second), chromedp.EvaluateAsDevTools(fmt.Sprintf(`(() => { const el = document.querySelector(%q); if (!el) throw new Error("element not found"); el.value = %q; el.dispatchEvent(new Event("input", {bubbles:true})); el.dispatchEvent(new Event("change", {bubbles:true})); })()`, selector, request.Value), nil)) +} + +func (session *Session) actFill(tab *sessionTab, request ActRequest) error { + selector, err := resolveRefSelector(tab, request.Ref) + if err != nil { + return err + } + return session.runOnTab(tab, normalizeTimeout(request.Timeout, 15*time.Second), chromedp.EvaluateAsDevTools(fmt.Sprintf(`(() => { const el = document.querySelector(%q); if (!el) throw new Error("element not found"); el.focus(); el.value = %q; el.dispatchEvent(new Event("input", {bubbles:true})); el.dispatchEvent(new Event("change", {bubbles:true})); })()`, selector, request.Value), nil)) +} + +func (session *Session) actResize(tab *sessionTab, request ActRequest) error { + if request.Width <= 0 { + return errors.New("width is required") + } + if request.Height <= 0 { + return errors.New("height is required") + } + return session.runOnTabFunc(tab, normalizeTimeout(request.Timeout, 15*time.Second), func(ctx context.Context) error { + return emulation.SetDeviceMetricsOverride(int64(request.Width), int64(request.Height), 1, false).Do(ctx) + }) +} + +func (session *Session) actEvaluate(tab *sessionTab, request ActRequest) error { + if strings.TrimSpace(request.Expression) == "" { + return errors.New("expression is required") + } + var result any + if err := session.runOnTab(tab, normalizeTimeout(request.Timeout, 15*time.Second), chromedp.Evaluate(request.Expression, &result)); err != nil { + return err + } + tab.mu.Lock() + tab.evaluateResult = result + tab.mu.Unlock() + return nil +} + +func evaluateResult(tab *sessionTab) any { + if tab == nil { + 
return nil + } + tab.mu.RLock() + defer tab.mu.RUnlock() + return tab.evaluateResult +} + +func tabURL(tab *sessionTab) string { + if tab == nil { + return "" + } + tab.mu.RLock() + defer tab.mu.RUnlock() + return tab.lastURL +} + +func preferredPageURL(candidates ...string) string { + fallback := "" + for _, candidate := range candidates { + candidate = strings.TrimSpace(candidate) + if candidate == "" { + continue + } + if fallback == "" { + fallback = candidate + } + if !isReusablePageURL(candidate) { + return candidate + } + } + return fallback +} + +func waitForText(tab *sessionTab, expected string, timeout time.Duration, gone bool) error { + expected = strings.TrimSpace(expected) + runCtx, cancel := context.WithTimeout(tab.ctx, timeout) + defer cancel() + err := chromedp.Run(runCtx, chromedp.PollFunction( + `(expected, gone) => { + const text = document.body ? document.body.innerText : ""; + const contains = text.includes(expected); + return gone ? !contains : contains; + }`, + nil, + chromedp.WithPollingInterval(200*time.Millisecond), + chromedp.WithPollingTimeout(timeout), + chromedp.WithPollingArgs(expected, gone), + )) + if errors.Is(err, chromedp.ErrPollingTimeout) { + if gone { + return &WaitTimeoutError{Condition: "textGone"} + } + return &WaitTimeoutError{Condition: "text"} + } + return err +} + +func waitForURL(tab *sessionTab, expected string, timeout time.Duration) error { + expected = strings.TrimSpace(expected) + runCtx, cancel := context.WithTimeout(tab.ctx, timeout) + defer cancel() + err := chromedp.Run(runCtx, chromedp.PollFunction( + `(expected) => document.location.toString() === expected`, + nil, + chromedp.WithPollingInterval(200*time.Millisecond), + chromedp.WithPollingTimeout(timeout), + chromedp.WithPollingArgs(expected), + )) + if errors.Is(err, chromedp.ErrPollingTimeout) { + return &WaitTimeoutError{Condition: "url"} + } + return err +} + +func waitForEvaluateCondition(tab *sessionTab, fn string, timeout time.Duration) error { + 
runCtx, cancel := context.WithTimeout(tab.ctx, timeout) + defer cancel() + err := chromedp.Run(runCtx, chromedp.Poll( + fn, + nil, + chromedp.WithPollingInterval(200*time.Millisecond), + chromedp.WithPollingTimeout(timeout), + )) + if errors.Is(err, chromedp.ErrPollingTimeout) { + return &WaitTimeoutError{Condition: "fn"} + } + return err +} + +func pickReusableTargetID(infos []*targetpkg.Info) string { + choose := func(requireUnattached bool, preferBlank bool) string { + for _, info := range infos { + if info == nil || info.Type != "page" { + continue + } + if requireUnattached && info.Attached { + continue + } + if preferBlank && !isReusablePageURL(info.URL) { + continue + } + return string(info.TargetID) + } + return "" + } + for _, candidate := range []string{ + choose(true, true), + choose(true, false), + choose(false, true), + choose(false, false), + } { + if strings.TrimSpace(candidate) != "" { + return candidate + } + } + return "" +} + +func isReusablePageURL(rawURL string) bool { + switch strings.TrimSpace(strings.ToLower(rawURL)) { + case "", "about:blank", "chrome://newtab/", "chrome-search://local-ntp/local-ntp.html": + return true + default: + return false + } +} + +func shouldTreatNavigationAsComplete(observedURL string, previousURL string, targetURL string) bool { + observed := strings.TrimSpace(observedURL) + if observed == "" || observed == "about:blank" { + return false + } + if urlsEqual(observed, targetURL) { + return true + } + previous := strings.TrimSpace(previousURL) + if previous == "" || previous == "about:blank" { + return true + } + return !urlsEqual(observed, previous) +} + +func urlsEqual(left string, right string) bool { + return strings.TrimSpace(left) == strings.TrimSpace(right) +} + +func cookieSyncKeys(rawURL string, records []appcookies.Record) []string { + keys := map[string]struct{}{} + if parsed, err := url.Parse(strings.TrimSpace(rawURL)); err == nil { + if hostname := strings.ToLower(strings.TrimSpace(parsed.Hostname())); 
hostname != "" { + keys[hostname] = struct{}{} + } + } + for _, record := range records { + domain := strings.ToLower(strings.TrimSpace(record.Domain)) + domain = strings.TrimPrefix(domain, ".") + if domain == "" { + continue + } + keys[domain] = struct{}{} + } + if len(keys) == 0 { + trimmed := strings.ToLower(strings.TrimSpace(rawURL)) + if trimmed != "" { + keys[trimmed] = struct{}{} + } + } + result := make([]string, 0, len(keys)) + for key := range keys { + result = append(result, key) + } + sort.Strings(result) + return result +} + +func hasCookieSyncFingerprint(values map[string]string, keys []string, fingerprint string) bool { + if len(values) == 0 || len(keys) == 0 || fingerprint == "" { + return false + } + for _, key := range keys { + if current := strings.TrimSpace(values[strings.ToLower(strings.TrimSpace(key))]); current == fingerprint { + return true + } + } + return false +} + +func rememberCookieSyncFingerprint(values map[string]string, keys []string, fingerprint string) { + if len(keys) == 0 || fingerprint == "" { + return + } + for _, key := range keys { + key = strings.ToLower(strings.TrimSpace(key)) + if key == "" { + continue + } + values[key] = fingerprint + } +} + +func isRecoverableCookieSyncError(err error) bool { + if err == nil { + return false + } + message := strings.ToLower(strings.TrimSpace(err.Error())) + switch { + case strings.Contains(message, "context deadline exceeded"), + strings.Contains(message, "browser context unavailable"), + strings.Contains(message, "browser executor unavailable"), + strings.Contains(message, "target closed"), + strings.Contains(message, "session closed"), + strings.Contains(message, "connection closed"), + strings.Contains(message, "execution context was destroyed"), + strings.Contains(message, "cannot find context with specified id"), + strings.Contains(message, "unique context id not found"): + return true + default: + return false + } +} + +func wrapRuntimeHangError(err error) error { + if err == nil 
{ + return nil + } + message := strings.ToLower(strings.TrimSpace(err.Error())) + switch { + case strings.Contains(message, "context deadline exceeded"), + strings.Contains(message, "target closed"), + strings.Contains(message, "connection closed"), + strings.Contains(message, "session closed"), + strings.Contains(message, "websocket"): + return &FatalError{Err: err} + default: + return err + } +} + +func cookieFingerprint(records []appcookies.Record) string { + if len(records) == 0 { + return "" + } + items := append([]appcookies.Record(nil), records...) + sort.Slice(items, func(i, j int) bool { + left := items[i] + right := items[j] + switch { + case left.Domain != right.Domain: + return left.Domain < right.Domain + case left.Path != right.Path: + return left.Path < right.Path + default: + return left.Name < right.Name + } + }) + parts := make([]string, 0, len(items)) + for _, item := range items { + parts = append(parts, strings.Join([]string{ + strings.TrimSpace(item.Domain), + strings.TrimSpace(item.Path), + strings.TrimSpace(item.Name), + item.Value, + }, "\n")) + } + return strings.Join(parts, "\n---\n") +} + +func normalizeTimeout(value time.Duration, fallback time.Duration) time.Duration { + if value <= 0 { + value = fallback + } + if value <= 0 { + value = 15 * time.Second + } + if value < 500*time.Millisecond { + return 500 * time.Millisecond + } + if value > 120*time.Second { + return 120 * time.Second + } + return value +} + +func captureTimeout(timeout time.Duration) time.Duration { + timeout = normalizeTimeout(timeout, 15*time.Second) + scaled := timeout / 5 + if scaled < time.Second { + scaled = time.Second + } + if scaled > 5*time.Second { + scaled = 5 * time.Second + } + return scaled +} + +func sleepWithContext(ctx context.Context, delay time.Duration) error { + if delay <= 0 { + return nil + } + if ctx == nil { + time.Sleep(delay) + return nil + } + timer := time.NewTimer(delay) + defer timer.Stop() + select { + case <-ctx.Done(): + return 
ctx.Err() + case <-timer.C: + return nil + } +} + +func (session *Session) logFields() []zap.Field { + options := session.optionsSnapshot() + return []zap.Field{ + zap.String("profile", options.ProfileName), + zap.String("preferredBrowser", options.PreferredBrowser), + zap.Bool("headless", options.Headless), + } +} + +func sanitizedURLField(key string, rawURL string) zap.Field { + return zap.String(key, sanitizeLogURL(rawURL)) +} + +func sanitizeLogURL(rawURL string) string { + trimmed := strings.TrimSpace(rawURL) + if trimmed == "" { + return "" + } + parsed, err := url.Parse(trimmed) + if err != nil { + return trimmed + } + parsed.User = nil + parsed.RawQuery = "" + parsed.Fragment = "" + return parsed.String() +} + +func resultItemCount(result ActionResult) int { + if result.State != nil && result.State.ItemCount > 0 { + return result.State.ItemCount + } + return len(result.Items) +} + +func normalizedSnapshotLimit(limit int) int { + if limit > 0 { + return limit + } + return defaultSnapshotLimit +} + +func minDuration(left time.Duration, right time.Duration) time.Duration { + if left <= 0 { + return right + } + if right <= 0 { + return left + } + if left < right { + return left + } + return right +} + +func actInvalidatesState(kind string) bool { + switch kind { + case "click", "type", "press", "hover", "select", "fill", "resize": + return true + default: + return false + } +} + +func actNeedsSettle(kind string) bool { + switch kind { + case "click", "type", "press", "hover", "select", "fill", "resize", "wait": + return true + default: + return false + } +} + +func actMayOpenNewTab(kind string) bool { + switch kind { + case "click", "press": + return true + default: + return false + } +} + +func shouldDeferStateCapture(err error) bool { + if err == nil { + return false + } + message := strings.ToLower(strings.TrimSpace(err.Error())) + switch { + case strings.Contains(message, "context deadline exceeded"), + strings.Contains(message, "execution context was 
destroyed"), + strings.Contains(message, "cannot find context with specified id"), + strings.Contains(message, "unique context id not found"): + return true + default: + return false + } +} + +func isTargetLookupError(err error) bool { + if err == nil { + return false + } + message := strings.ToLower(strings.TrimSpace(err.Error())) + return strings.Contains(message, "no target with given id found") +} + +func isHostnameAllowed(hostname string, policy SSRFPolicy) bool { + hostname = strings.ToLower(strings.TrimSpace(hostname)) + if hostname == "" { + return false + } + if _, ok := policy.AllowedHostnames[hostname]; ok { + return true + } + for _, pattern := range policy.HostnameAllowlist { + pattern = strings.ToLower(strings.TrimSpace(pattern)) + if pattern == "" { + continue + } + if pattern == hostname { + return true + } + if strings.HasPrefix(pattern, "*.") { + suffix := strings.TrimPrefix(pattern, "*.") + if strings.HasSuffix(hostname, "."+suffix) || hostname == suffix { + return true + } + continue + } + if matched, _ := filepath.Match(pattern, hostname); matched { + return true + } + } + return false +} + +func isPrivateOrLocalIP(ip net.IP) bool { + if ip == nil { + return false + } + if ip.IsLoopback() || ip.IsPrivate() || ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() || ip.IsUnspecified() || ip.IsMulticast() { + return true + } + if ip4 := ip.To4(); ip4 != nil { + if ip4[0] == 100 && (ip4[1]&0xC0) == 64 { + return true + } + } + return false +} diff --git a/internal/application/browsercdp/session_cookie_sync_test.go b/internal/application/browsercdp/session_cookie_sync_test.go new file mode 100644 index 0000000..2c48581 --- /dev/null +++ b/internal/application/browsercdp/session_cookie_sync_test.go @@ -0,0 +1,120 @@ +package browsercdp + +import ( + "errors" + "testing" + + appcookies "dreamcreator/internal/application/cookies" +) + +const ( + testCookiePrimaryURL = "https://www.example.test/" + testCookieProfileURL = 
"https://space.example.test/profile/demo" + testCookiePrimaryHost = "www.example.test" + testCookieProfileHost = "space.example.test" + testCookieSharedDomain = ".example.test" +) + +func TestCookieSyncKeys_IncludeCookieScopeDomains(t *testing.T) { + t.Parallel() + + keys := cookieSyncKeys(testCookieProfileURL, []appcookies.Record{ + {Name: "session_token", Domain: testCookieSharedDomain, Path: "/"}, + {Name: "site_session", Domain: testCookiePrimaryHost, Path: "/"}, + }) + + expected := map[string]struct{}{ + testCookieProfileHost: {}, + "example.test": {}, + testCookiePrimaryHost: {}, + } + if len(keys) != len(expected) { + t.Fatalf("expected %d keys, got %d: %#v", len(expected), len(keys), keys) + } + for _, key := range keys { + if _, ok := expected[key]; !ok { + t.Fatalf("unexpected key %q in %#v", key, keys) + } + delete(expected, key) + } + if len(expected) != 0 { + t.Fatalf("missing keys: %#v", expected) + } +} + +func TestCookieSyncFingerprint_ReusesBroadCookieDomainAcrossSubdomains(t *testing.T) { + t.Parallel() + + fingerprint := "same-cookie-set" + values := map[string]string{} + initialKeys := cookieSyncKeys(testCookiePrimaryURL, []appcookies.Record{ + {Name: "session_token", Domain: testCookieSharedDomain, Path: "/"}, + }) + rememberCookieSyncFingerprint(values, initialKeys, fingerprint) + + nextKeys := cookieSyncKeys(testCookieProfileURL, []appcookies.Record{ + {Name: "session_token", Domain: testCookieSharedDomain, Path: "/"}, + }) + if !hasCookieSyncFingerprint(values, nextKeys, fingerprint) { + t.Fatalf("expected sync fingerprint to be reused across subdomains: values=%#v keys=%#v", values, nextKeys) + } +} + +func TestIsRecoverableCookieSyncError(t *testing.T) { + t.Parallel() + + cases := []struct { + name string + err error + want bool + }{ + {name: "timeout", err: errors.New("connector cookie sync failed: context deadline exceeded"), want: true}, + {name: "destroyed", err: errors.New("execution context was destroyed"), want: true}, + {name: 
"closed", err: errors.New("target closed"), want: true}, + {name: "validation", err: errors.New("invalid cookie domain"), want: false}, + } + for _, tc := range cases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + if got := isRecoverableCookieSyncError(tc.err); got != tc.want { + t.Fatalf("expected %v, got %v for %v", tc.want, got, tc.err) + } + }) + } +} + +func TestPreferredPageURL(t *testing.T) { + t.Parallel() + + cases := []struct { + name string + candidates []string + want string + }{ + { + name: "preserve known non placeholder over about blank", + candidates: []string{"about:blank", testCookiePrimaryURL}, + want: testCookiePrimaryURL, + }, + { + name: "prefer observed real url", + candidates: []string{testCookieProfileURL, testCookiePrimaryURL}, + want: testCookieProfileURL, + }, + { + name: "fallback to requested real url", + candidates: []string{"about:blank", "", "https://www.example.test"}, + want: "https://www.example.test", + }, + } + for _, tc := range cases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + if got := preferredPageURL(tc.candidates...); got != tc.want { + t.Fatalf("expected %q, got %q for %#v", tc.want, got, tc.candidates) + } + }) + } +} diff --git a/internal/application/browsercdp/session_open_live_test.go b/internal/application/browsercdp/session_open_live_test.go new file mode 100644 index 0000000..14608bf --- /dev/null +++ b/internal/application/browsercdp/session_open_live_test.go @@ -0,0 +1,411 @@ +package browsercdp + +import ( + "context" + "fmt" + "net/http" + "net/http/httptest" + "os" + "strings" + "testing" + "time" + + targetpkg "github.com/chromedp/cdproto/target" + "github.com/chromedp/chromedp" + "go.uber.org/zap" +) + +func TestSessionOpenLive(t *testing.T) { + if os.Getenv("DREAMCREATOR_BROWSER_OPEN_LIVE") != "1" { + t.Skip("set DREAMCREATOR_BROWSER_OPEN_LIVE=1 to run the live browser open probe") + } + + logger, err := zap.NewDevelopment() + if err != nil { + t.Fatalf("create 
logger: %v", err) + } + defer func() { + _ = logger.Sync() + }() + restore := zap.ReplaceGlobals(logger) + defer restore() + + status := ResolveStatus("", true) + if !status.Ready { + t.Skipf("browser not available: %s", strings.TrimSpace(status.DetectError)) + } + + targetURL := strings.TrimSpace(os.Getenv("DREAMCREATOR_BROWSER_OPEN_URL")) + if targetURL == "" { + targetURL = "https://example.com" + } + + registry := NewSessionRegistry() + session := registry.GetOrCreate("live-open", "dreamcreator", SessionOptions{ + SessionKey: "live-open", + ProfileName: "dreamcreator", + PreferredBrowser: strings.TrimSpace(status.ChosenBrowser), + Headless: true, + }) + defer session.stop() + + ctx, cancel := context.WithTimeout(context.Background(), 45*time.Second) + defer cancel() + + result, err := session.Open(ctx, targetURL, CommandOptions{ + Timeout: 30 * time.Second, + Limit: 50, + }) + if err != nil { + t.Fatalf("session open failed: %v", err) + } + + itemCount := len(result.Items) + if result.State != nil && result.State.ItemCount > 0 { + itemCount = result.State.ItemCount + } + t.Logf( + "open result ok=%v targetId=%s url=%s hasTitle=%v stateAvailable=%v stateVersion=%d itemCount=%d stateError=%s", + result.OK, + result.TargetID, + sanitizeLogURL(result.URL), + strings.TrimSpace(result.Title) != "", + result.State != nil || result.StateAvailable, + result.StateVersion, + itemCount, + result.StateError, + ) + + tab, resolveErr := session.resolveTab(result.TargetID, true) + if resolveErr != nil { + t.Fatalf("resolve tab: %v", resolveErr) + } + probeTargetContext(t, "tab.ctx", tab.ctx) + + browserProbeCtx, browserProbeCancel := chromedp.NewContext(session.runtime.BrowserContext(), chromedp.WithTargetID(targetpkg.ID(result.TargetID))) + defer browserProbeCancel() + probeTargetContext(t, "fresh-browserCtx#1", browserProbeCtx) + probeTargetContext(t, "fresh-browserCtx#2", browserProbeCtx) + + allocProbeCtx, allocProbeCancel := chromedp.NewContext(session.runtime.allocCtx, 
chromedp.WithTargetID(targetpkg.ID(result.TargetID))) + defer allocProbeCancel() + probeTargetContext(t, "fresh-allocCtx#1", allocProbeCtx) + probeTargetContext(t, "fresh-allocCtx#2", allocProbeCtx) +} + +func TestSessionWorkflowLive(t *testing.T) { + if os.Getenv("DREAMCREATOR_BROWSER_WORKFLOW_LIVE") != "1" { + t.Skip("set DREAMCREATOR_BROWSER_WORKFLOW_LIVE=1 to run the live browser workflow probe") + } + + logger, err := zap.NewDevelopment() + if err != nil { + t.Fatalf("create logger: %v", err) + } + defer func() { + _ = logger.Sync() + }() + restore := zap.ReplaceGlobals(logger) + defer restore() + + status := ResolveStatus("", true) + if !status.Ready { + t.Skipf("browser not available: %s", strings.TrimSpace(status.DetectError)) + } + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/html; charset=utf-8") + switch r.URL.Path { + case "/": + _, _ = fmt.Fprint(w, ` + + Workflow Home + +
+ + + +
+ Next Page + +`) + case "/search": + query := r.URL.Query().Get("q") + _, _ = fmt.Fprintf(w, ` + + Results %s + +

Result for %s

+ Next Page + +`, query, query) + case "/next": + _, _ = fmt.Fprint(w, ` + + Next Page + + + +`) + default: + http.NotFound(w, r) + } + })) + defer server.Close() + + registry := NewSessionRegistry() + session := registry.GetOrCreate("live-workflow", "dreamcreator", SessionOptions{ + SessionKey: "live-workflow", + ProfileName: "dreamcreator", + PreferredBrowser: strings.TrimSpace(status.ChosenBrowser), + Headless: true, + SSRFRules: SSRFPolicy{ + DangerouslyAllowPrivateNetwork: true, + }, + }) + defer session.stop() + + ctx, cancel := context.WithTimeout(context.Background(), 45*time.Second) + defer cancel() + + openResult, err := session.Open(ctx, server.URL, CommandOptions{ + Timeout: 20 * time.Second, + Limit: 50, + }) + if err != nil { + t.Fatalf("open failed: %v", err) + } + stateResult, stateErr := session.State(openResult.TargetID, 50) + if stateErr != nil { + t.Logf("post-open state failed: %v", stateErr) + } else { + t.Logf("post-open state url=%s hasTitle=%v items=%d", sanitizeLogURL(stateResult.URL), strings.TrimSpace(stateResult.Title) != "", len(stateResult.Items)) + } + tab, resolveErr := session.resolveTab(openResult.TargetID, true) + if resolveErr != nil { + t.Fatalf("resolve workflow tab: %v", resolveErr) + } + probeTargetContext(t, "workflow-tab.ctx", tab.ctx) + browserProbeCtx, browserProbeCancel := chromedp.NewContext(session.runtime.BrowserContext(), chromedp.WithTargetID(targetpkg.ID(openResult.TargetID))) + defer browserProbeCancel() + probeTargetContext(t, "workflow-fresh-browserCtx", browserProbeCtx) + + currentItems := openResult.Items + if stateErr == nil && len(stateResult.Items) > 0 { + currentItems = stateResult.Items + } + inputRef := findSnapshotRef(currentItems, "textbox", "search weather") + if inputRef == "" { + t.Fatalf("textbox ref not found in latest snapshot: %+v", currentItems) + } + buttonRef := findSnapshotRef(currentItems, "button", "search") + if buttonRef == "" { + t.Fatalf("button ref not found in latest snapshot: %+v", 
currentItems) + } + + typeResult, err := session.Act(ctx, ActRequest{ + Kind: "type", + TargetID: openResult.TargetID, + Ref: inputRef, + Text: "weather", + Limit: 50, + Timeout: 15 * time.Second, + }) + if err != nil { + t.Fatalf("type failed: %v", err) + } + nextItems := typeResult.Items + if len(nextItems) == 0 { + nextItems = currentItems + } + buttonRef = findSnapshotRef(nextItems, "button", "search") + if buttonRef == "" { + t.Fatalf("button ref not found after type: %+v", nextItems) + } + + searchURL := server.URL + "/search?q=weather" + clickResult, err := session.Act(ctx, ActRequest{ + Kind: "click", + TargetID: openResult.TargetID, + Ref: buttonRef, + WaitFor: &WaitRequest{ + URL: searchURL, + Timeout: 8 * time.Second, + }, + Limit: 50, + Timeout: 20 * time.Second, + }) + if err != nil { + t.Fatalf("click failed: %v", err) + } + if !strings.Contains(clickResult.URL, "/search?q=weather") { + t.Fatalf("unexpected search url: %s", sanitizeLogURL(clickResult.URL)) + } + if !strings.Contains(clickResult.Title, "Results weather") { + t.Fatalf("unexpected search title") + } + + navigateResult, err := session.Navigate(ctx, clickResult.TargetID, server.URL+"/next", false, CommandOptions{ + Timeout: 20 * time.Second, + Limit: 50, + }) + if err != nil { + t.Fatalf("navigate failed: %v", err) + } + if !strings.Contains(navigateResult.URL, "/next") { + t.Fatalf("unexpected navigate url: %s", sanitizeLogURL(navigateResult.URL)) + } + if strings.TrimSpace(navigateResult.Title) != "Next Page" { + t.Fatalf("unexpected navigate title") + } + + t.Logf("workflow open/type/click/navigate succeeded targetId=%s finalURL=%s hasTitle=%v", navigateResult.TargetID, sanitizeLogURL(navigateResult.URL), strings.TrimSpace(navigateResult.Title) != "") +} + +func TestSessionBaiduSearchLive(t *testing.T) { + if os.Getenv("DREAMCREATOR_BROWSER_BAIDU_LIVE") != "1" { + t.Skip("set DREAMCREATOR_BROWSER_BAIDU_LIVE=1 to run the live Baidu browser probe") + } + + logger, err := 
zap.NewDevelopment() + if err != nil { + t.Fatalf("create logger: %v", err) + } + defer func() { + _ = logger.Sync() + }() + restore := zap.ReplaceGlobals(logger) + defer restore() + + status := ResolveStatus("", false) + if !status.Ready { + t.Skipf("browser not available: %s", strings.TrimSpace(status.DetectError)) + } + + headless := strings.TrimSpace(os.Getenv("DREAMCREATOR_BROWSER_BAIDU_HEADLESS")) == "1" + query := strings.TrimSpace(os.Getenv("DREAMCREATOR_BROWSER_BAIDU_QUERY")) + if query == "" { + query = "天气" + } + + registry := NewSessionRegistry() + session := registry.GetOrCreate("live-baidu", "dreamcreator", SessionOptions{ + SessionKey: "live-baidu", + ProfileName: "dreamcreator", + PreferredBrowser: strings.TrimSpace(status.ChosenBrowser), + Headless: headless, + }) + defer session.stop() + + ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second) + defer cancel() + + openResult, err := session.Open(ctx, "https://www.baidu.com", CommandOptions{ + Timeout: 30 * time.Second, + Limit: 50, + }) + if err != nil { + t.Fatalf("open failed: %v", err) + } + + inputRef := findSnapshotRef(openResult.Items, "textbox", "") + if inputRef == "" { + t.Fatalf("textbox ref not found in open items: %+v", openResult.Items) + } + + typeResult, err := session.Act(ctx, ActRequest{ + Kind: "type", + TargetID: openResult.TargetID, + Ref: inputRef, + Text: query, + Limit: 50, + Timeout: 15 * time.Second, + }) + if err != nil { + t.Fatalf("type failed: %v", err) + } + + buttonItems := typeResult.Items + if len(buttonItems) == 0 { + buttonItems = openResult.Items + } + buttonRef := findSnapshotRef(buttonItems, "button", "百度一下") + if buttonRef == "" { + buttonRef = findSnapshotRef(buttonItems, "button", "百度") + } + if buttonRef == "" { + t.Fatalf("search button ref not found after type: %+v", buttonItems) + } + + clickResult, err := session.Act(ctx, ActRequest{ + Kind: "click", + TargetID: openResult.TargetID, + Ref: buttonRef, + Limit: 50, + Timeout: 20 * 
time.Second, + }) + if err != nil { + t.Fatalf("click failed: %v", err) + } + + t.Logf( + "baidu click result targetId=%s url=%s hasTitle=%v stateAvailable=%v itemCount=%d stateError=%s", + clickResult.TargetID, + sanitizeLogURL(clickResult.URL), + strings.TrimSpace(clickResult.Title) != "", + clickResult.State != nil || clickResult.StateAvailable, + len(clickResult.Items), + clickResult.StateError, + ) + if !strings.Contains(clickResult.URL, "baidu.com") { + t.Fatalf("unexpected click url: %s", sanitizeLogURL(clickResult.URL)) + } + if !clickResult.StateAvailable && clickResult.State == nil { + t.Fatalf("click state unavailable: %s", clickResult.StateError) + } +} + +func probeTargetContext(t *testing.T, label string, ctx context.Context) { + t.Helper() + + probeCtx, cancel := context.WithTimeout(ctx, 3*time.Second) + defer cancel() + + var payload struct { + URL string `json:"url"` + Title string `json:"title"` + ReadyState string `json:"readyState"` + } + err := chromedp.Run(probeCtx, chromedp.EvaluateAsDevTools(`({ + url: document.location.toString(), + title: document.title, + readyState: document.readyState, + })`, &payload)) + if err != nil { + t.Logf("%s probe failed: %v", label, err) + return + } + t.Logf( + "%s probe url=%s hasTitle=%v readyState=%s", + label, + sanitizeLogURL(payload.URL), + strings.TrimSpace(payload.Title) != "", + strings.TrimSpace(payload.ReadyState), + ) +} + +func findSnapshotRef(items []SnapshotItem, role string, needle string) string { + role = strings.TrimSpace(strings.ToLower(role)) + needle = strings.TrimSpace(strings.ToLower(needle)) + for _, item := range items { + if role != "" && strings.TrimSpace(strings.ToLower(item.Role)) != role { + continue + } + name := strings.ToLower(strings.TrimSpace(item.Name)) + text := strings.ToLower(strings.TrimSpace(item.Text)) + if needle == "" || strings.Contains(name, needle) || strings.Contains(text, needle) { + return strings.TrimSpace(item.Ref) + } + } + return "" +} diff --git 
a/internal/application/browsercdp/session_registry_test.go b/internal/application/browsercdp/session_registry_test.go new file mode 100644 index 0000000..e5529c0 --- /dev/null +++ b/internal/application/browsercdp/session_registry_test.go @@ -0,0 +1,68 @@ +package browsercdp + +import "testing" + +func TestSessionRegistryCloseSessionKeyRemovesSessions(t *testing.T) { + t.Parallel() + + registry := NewSessionRegistry() + first := registry.GetOrCreate("session-a", "dreamcreator", SessionOptions{}) + second := registry.GetOrCreate("session-a", "work", SessionOptions{}) + other := registry.GetOrCreate("session-b", "dreamcreator", SessionOptions{}) + + if first == nil || second == nil || other == nil { + t.Fatalf("expected sessions to be created") + } + + registry.CloseSessionKey("session-a") + + registry.mu.Lock() + _, existsA := registry.sessions["session-a"] + bucketB := registry.sessions["session-b"] + registry.mu.Unlock() + + if existsA { + t.Fatalf("expected session-a bucket to be removed") + } + if len(bucketB) != 1 { + t.Fatalf("expected session-b bucket to remain intact, got %#v", bucketB) + } + + recreated := registry.GetOrCreate("session-a", "dreamcreator", SessionOptions{}) + if recreated == nil { + t.Fatalf("expected recreated session") + } + if recreated == first { + t.Fatalf("expected session-a to be recreated after cleanup") + } +} + +func TestSessionRegistryCloseAllRemovesAllSessions(t *testing.T) { + t.Parallel() + + registry := NewSessionRegistry() + first := registry.GetOrCreate("session-a", "dreamcreator", SessionOptions{}) + second := registry.GetOrCreate("session-b", "work", SessionOptions{}) + + if first == nil || second == nil { + t.Fatalf("expected sessions to be created") + } + + registry.CloseAll() + + registry.mu.Lock() + sessionCount := len(registry.sessions) + registry.mu.Unlock() + + if sessionCount != 0 { + t.Fatalf("expected all sessions to be removed, got %d buckets", sessionCount) + } + + recreated := registry.GetOrCreate("session-a", 
"dreamcreator", SessionOptions{}) + if recreated == nil { + t.Fatalf("expected recreated session") + } + if recreated == first { + t.Fatalf("expected session to be recreated after close all") + } +} diff --git a/internal/application/browsercdp/session_ssrf_test.go b/internal/application/browsercdp/session_ssrf_test.go new file mode 100644 index 0000000..6a9c61f --- /dev/null +++ b/internal/application/browsercdp/session_ssrf_test.go @@ -0,0 +1,78 @@ +package browsercdp + +import ( + "context" + "net" + "strings" + "testing" +) + +func stubLookupIPAddrsForHost(t *testing.T, stub func(context.Context, string) ([]net.IPAddr, error)) { + t.Helper() + original := lookupIPAddrsForHost + lookupIPAddrsForHost = stub + t.Cleanup(func() { + lookupIPAddrsForHost = original + }) +} + +func TestAssertURLAllowedRejectsHostnameResolvingToPrivateIP(t *testing.T) { + stubLookupIPAddrsForHost(t, func(_ context.Context, host string) ([]net.IPAddr, error) { + if host != "public.example.com" { + t.Fatalf("unexpected host lookup %q", host) + } + return []net.IPAddr{{IP: net.ParseIP("127.0.0.1")}}, nil + }) + + err := AssertURLAllowed("https://public.example.com/path", SSRFPolicy{ + DangerouslyAllowPrivateNetwork: false, + AllowedHostnames: map[string]struct{}{}, + }) + if err == nil || !strings.Contains(err.Error(), "resolving to private IP") { + t.Fatalf("expected resolved private IP to be blocked, got %v", err) + } +} + +func TestAssertURLAllowedAllowsHostnameResolvingToPublicIP(t *testing.T) { + stubLookupIPAddrsForHost(t, func(_ context.Context, host string) ([]net.IPAddr, error) { + if host != "public.example.com" { + t.Fatalf("unexpected host lookup %q", host) + } + return []net.IPAddr{{IP: net.ParseIP("93.184.216.34")}}, nil + }) + + err := AssertURLAllowed("https://public.example.com/path", SSRFPolicy{ + DangerouslyAllowPrivateNetwork: false, + AllowedHostnames: map[string]struct{}{}, + }) + if err != nil { + t.Fatalf("expected public hostname to be allowed, got %v", err) + } +} 
+ +func TestAssertURLAllowedAllowsExplicitHostnameAllowlist(t *testing.T) { + stubLookupIPAddrsForHost(t, func(context.Context, string) ([]net.IPAddr, error) { + t.Fatalf("allowlisted hostname should not require DNS validation") + return nil, nil + }) + + err := AssertURLAllowed("http://localhost:3000", SSRFPolicy{ + DangerouslyAllowPrivateNetwork: false, + AllowedHostnames: map[string]struct{}{ + "localhost": {}, + }, + }) + if err != nil { + t.Fatalf("expected explicit hostname allowlist to bypass private host block, got %v", err) + } +} + +func TestAssertRequestURLAllowedRejectsPrivateWebsocketTargets(t *testing.T) { + err := assertRequestURLAllowed(context.Background(), "wss://127.0.0.1/socket", SSRFPolicy{ + DangerouslyAllowPrivateNetwork: false, + AllowedHostnames: map[string]struct{}{}, + }) + if err == nil || !strings.Contains(err.Error(), "blocked private IP") { + t.Fatalf("expected private websocket target to be blocked, got %v", err) + } +} diff --git a/internal/application/browsercdp/session_wait_test.go b/internal/application/browsercdp/session_wait_test.go new file mode 100644 index 0000000..d038025 --- /dev/null +++ b/internal/application/browsercdp/session_wait_test.go @@ -0,0 +1,33 @@ +package browsercdp + +import ( + "context" + "errors" + "testing" + "time" +) + +func TestWaitOnTabTimeRespectsParentCancel(t *testing.T) { + session := &Session{} + parent, cancel := context.WithCancel(context.Background()) + time.AfterFunc(50*time.Millisecond, cancel) + + startedAt := time.Now() + err := session.waitOnTab(parent, nil, WaitRequest{Time: 2 * time.Second}, 2*time.Second) + elapsed := time.Since(startedAt) + + if !errors.Is(err, context.Canceled) { + t.Fatalf("expected context canceled, got %v", err) + } + if elapsed >= 500*time.Millisecond { + t.Fatalf("expected wait to stop early, elapsed=%s", elapsed) + } +} + +func TestSanitizeLogURLRemovesSensitiveParts(t *testing.T) { + rawURL := "https://user:secret@example.com/path?q=token#frag" + sanitized := 
sanitizeLogURL(rawURL) + if sanitized != "https://example.com/path" { + t.Fatalf("unexpected sanitized url %q", sanitized) + } +} diff --git a/internal/application/browsercdp/user_data_dir.go b/internal/application/browsercdp/user_data_dir.go new file mode 100644 index 0000000..e25e743 --- /dev/null +++ b/internal/application/browsercdp/user_data_dir.go @@ -0,0 +1,33 @@ +package browsercdp + +import ( + "os" + "path/filepath" + "strings" + + "github.com/google/uuid" +) + +var browserProfileNamespace = uuid.MustParse("d8f7f362-4b0f-48d8-b8ac-0dcf3a6af6e2") + +func ResolveProfileStorageKey(sessionKey string, profileName string) string { + sessionKey = strings.TrimSpace(sessionKey) + if sessionKey == "" { + sessionKey = "default" + } + profileName = strings.TrimSpace(profileName) + if profileName == "" { + profileName = "dreamcreator" + } + return uuid.NewSHA1(browserProfileNamespace, []byte(sessionKey+"\x00"+profileName)).String() +} + +func ResolveProfileUserDataDir(sessionKey string, profileName string) string { + return filepath.Join( + os.TempDir(), + "dreamcreator", + "browser", + "profiles", + ResolveProfileStorageKey(sessionKey, profileName), + ) +} diff --git a/internal/application/browsercdp/user_data_dir_test.go b/internal/application/browsercdp/user_data_dir_test.go new file mode 100644 index 0000000..d897135 --- /dev/null +++ b/internal/application/browsercdp/user_data_dir_test.go @@ -0,0 +1,75 @@ +package browsercdp + +import ( + "os" + "path/filepath" + "strings" + "testing" + + "github.com/google/uuid" +) + +func TestResolveProfileStorageKeyStableAndDistinct(t *testing.T) { + t.Parallel() + + sessionKey := "v2::-::aui::-::788e3ff2-b4ad-426d-be0f-1424156032fd::-::788e3ff2-b4ad-426d-be0f-1424156032fd" + profileName := "dreamcreator" + + first := ResolveProfileStorageKey(sessionKey, profileName) + second := ResolveProfileStorageKey(sessionKey, profileName) + otherProfile := ResolveProfileStorageKey(sessionKey, "work") + + if first != second { + 
t.Fatalf("expected stable storage key, got %q and %q", first, second) + } + if first == otherProfile { + t.Fatalf("expected different profiles to use different storage keys") + } + if _, err := uuid.Parse(first); err != nil { + t.Fatalf("expected uuid storage key, got %q: %v", first, err) + } +} + +func TestResolveProfileUserDataDirUsesShallowSafePath(t *testing.T) { + t.Parallel() + + sessionKey := "v2::-::aui::-::788e3ff2-b4ad-426d-be0f-1424156032fd::-::788e3ff2-b4ad-426d-be0f-1424156032fd" + profileName := "dream:creator" + dir := ResolveProfileUserDataDir(sessionKey, profileName) + + if strings.Contains(dir, sessionKey) { + t.Fatalf("expected session key to stay out of path, got %q", dir) + } + if strings.Contains(dir, profileName) { + t.Fatalf("expected profile name to stay out of path, got %q", dir) + } + if strings.ContainsAny(filepath.Base(dir), `<>:"/\|?*`) { + t.Fatalf("expected final directory name to be path safe, got %q", filepath.Base(dir)) + } + + rel, err := filepath.Rel(os.TempDir(), dir) + if err != nil { + t.Fatalf("resolve relative path: %v", err) + } + parts := strings.Split(rel, string(os.PathSeparator)) + if len(parts) != 4 { + t.Fatalf("expected shallow temp path, got %q (%d parts)", rel, len(parts)) + } + if parts[0] != "dreamcreator" || parts[1] != "browser" || parts[2] != "profiles" { + t.Fatalf("unexpected profile path layout: %q", rel) + } +} + +func TestNormalizeSessionOptionsUsesEncodedProfileDir(t *testing.T) { + t.Parallel() + + options := normalizeSessionOptions(SessionOptions{ + SessionKey: "v2::-::aui::-::thread:1::-::thread:1", + ProfileName: "dream:creator", + }) + + want := ResolveProfileUserDataDir("v2::-::aui::-::thread:1::-::thread:1", "dream:creator") + if options.UserDataDir != want { + t.Fatalf("unexpected encoded user data dir: got %q want %q", options.UserDataDir, want) + } +} diff --git a/internal/application/channels/telegram/bot_service_approval_test.go 
b/internal/application/channels/telegram/bot_service_approval_test.go index b3ad849..e678119 100644 --- a/internal/application/channels/telegram/bot_service_approval_test.go +++ b/internal/application/channels/telegram/bot_service_approval_test.go @@ -201,7 +201,7 @@ func TestShouldSuppressTelegramResolvedForward(t *testing.T) { reason string want bool }{ - {name: "telegram with sender", reason: "telegram:5234834060", want: true}, + {name: "telegram with sender", reason: "telegram:test-user-001", want: true}, {name: "telegram plain", reason: "telegram", want: true}, {name: "upper telegram", reason: "TeLeGrAm:1", want: true}, {name: "gateway reason", reason: "gateway:web", want: false}, diff --git a/internal/application/connectors/dto/models.go b/internal/application/connectors/dto/models.go index 01225ea..ca819d2 100644 --- a/internal/application/connectors/dto/models.go +++ b/internal/application/connectors/dto/models.go @@ -8,6 +8,9 @@ type Connector struct { Status string `json:"status"` CookiesCount int `json:"cookiesCount"` Cookies []ConnectorCookie `json:"cookies"` + Domains []string `json:"domains,omitempty"` + PolicyKey string `json:"policyKey,omitempty"` + Capabilities []string `json:"capabilities,omitempty"` LastVerifiedAt string `json:"lastVerifiedAt"` } @@ -22,10 +25,51 @@ type ClearConnectorRequest struct { ID string `json:"id"` } -type ConnectConnectorRequest struct { +type StartConnectorConnectRequest struct { ID string `json:"id"` } +type StartConnectorConnectResult struct { + SessionID string `json:"sessionId"` + Connector Connector `json:"connector"` +} + +type FinishConnectorConnectRequest struct { + SessionID string `json:"sessionId"` +} + +type FinishConnectorConnectResult struct { + SessionID string `json:"sessionId"` + Saved bool `json:"saved"` + RawCookiesCount int `json:"rawCookiesCount"` + FilteredCookiesCount int `json:"filteredCookiesCount"` + Domains []string `json:"domains,omitempty"` + Reason string `json:"reason,omitempty"` + 
Connector Connector `json:"connector"` +} + +type CancelConnectorConnectRequest struct { + SessionID string `json:"sessionId"` +} + +type ConnectorConnectSession struct { + SessionID string `json:"sessionId"` + ConnectorID string `json:"connectorId"` + State string `json:"state"` + Saved bool `json:"saved"` + RawCookiesCount int `json:"rawCookiesCount"` + FilteredCookiesCount int `json:"filteredCookiesCount"` + Domains []string `json:"domains,omitempty"` + Reason string `json:"reason,omitempty"` + Error string `json:"error,omitempty"` + LastCookiesAt string `json:"lastCookiesAt,omitempty"` + Connector Connector `json:"connector"` +} + +type GetConnectorConnectSessionRequest struct { + SessionID string `json:"sessionId"` +} + type OpenConnectorSiteRequest struct { ID string `json:"id"` } diff --git a/internal/application/connectors/service/cookies.go b/internal/application/connectors/service/cookies.go index 491ab04..373a991 100644 --- a/internal/application/connectors/service/cookies.go +++ b/internal/application/connectors/service/cookies.go @@ -1,27 +1,14 @@ package service import ( - "encoding/json" - "strings" "time" - "github.com/playwright-community/playwright-go" - "dreamcreator/internal/application/connectors/dto" + appcookies "dreamcreator/internal/application/cookies" + "dreamcreator/internal/application/sitepolicy" "dreamcreator/internal/domain/connectors" ) -type cookieRecord struct { - Name string `json:"name"` - Value string `json:"value"` - Domain string `json:"domain"` - Path string `json:"path"` - Expires int64 `json:"expires"` - HttpOnly bool `json:"httpOnly"` - Secure bool `json:"secure"` - SameSite string `json:"sameSite,omitempty"` -} - func mapConnectorDTO(item connectors.Connector) dto.Connector { cookies := decodeCookies(item.CookiesJSON) status := item.Status @@ -34,6 +21,7 @@ func mapConnectorDTO(item connectors.Connector) dto.Connector { if item.LastVerifiedAt != nil { lastVerified = item.LastVerifiedAt.Format(time.RFC3339) } + policy, _ 
:= sitepolicy.ForConnectorType(string(item.Type)) return dto.Connector{ ID: item.ID, Type: string(item.Type), @@ -42,16 +30,23 @@ func mapConnectorDTO(item connectors.Connector) dto.Connector { Status: string(status), CookiesCount: len(cookies), Cookies: mapCookiesDTO(cookies), + Domains: append([]string(nil), policy.Domains...), + PolicyKey: policy.Key, + Capabilities: append([]string(nil), policy.Capabilities...), LastVerifiedAt: lastVerified, } } func connectorGroup(connectorType connectors.ConnectorType) string { switch connectorType { - case connectors.ConnectorGoogle, connectors.ConnectorXiaohongshu: + case connectors.ConnectorGoogle, connectors.ConnectorZhihu: return "search_engine" + case connectors.ConnectorXiaohongshu, connectors.ConnectorReddit, connectors.ConnectorX: + return "community" case connectors.ConnectorBilibili: return "video" + case connectors.ConnectorGitHub: + return "developer" default: return "other" } @@ -61,6 +56,14 @@ func connectorDesc(connectorType connectors.ConnectorType) string { switch connectorType { case connectors.ConnectorGoogle: return "Global web search with broad multilingual coverage, suitable for general factual lookups." + case connectors.ConnectorGitHub: + return "Developer collaboration and code hosting, useful for repositories, issues, pull requests, and docs." + case connectors.ConnectorReddit: + return "Community discussions and long-tail troubleshooting threads, useful for niche product and user experience research." + case connectors.ConnectorZhihu: + return "Chinese knowledge-sharing community content, useful for answers, articles, and topic exploration." + case connectors.ConnectorX: + return "Real-time public conversation and creator timelines, useful for posts, threads, and timely updates." case connectors.ConnectorXiaohongshu: return "Chinese lifestyle and recommendation community content, useful for reviews and trend discovery." 
case connectors.ConnectorBilibili: @@ -70,7 +73,7 @@ func connectorDesc(connectorType connectors.ConnectorType) string { } } -func mapCookiesDTO(records []cookieRecord) []dto.ConnectorCookie { +func mapCookiesDTO(records []appcookies.Record) []dto.ConnectorCookie { if len(records) == 0 { return nil } @@ -90,99 +93,10 @@ func mapCookiesDTO(records []cookieRecord) []dto.ConnectorCookie { return result } -func encodeCookies(records []cookieRecord) (string, error) { - if len(records) == 0 { - return "", nil - } - data, err := json.Marshal(records) - if err != nil { - return "", err - } - return string(data), nil -} - -func decodeCookies(data string) []cookieRecord { - trimmed := strings.TrimSpace(data) - if trimmed == "" { - return nil - } - var records []cookieRecord - if err := json.Unmarshal([]byte(trimmed), &records); err != nil { - return nil - } - return records +func encodeCookies(records []appcookies.Record) (string, error) { + return appcookies.EncodeJSON(records) } -func cookiesFromPlaywright(cookies []playwright.Cookie) []cookieRecord { - if len(cookies) == 0 { - return nil - } - result := make([]cookieRecord, 0, len(cookies)) - for _, cookie := range cookies { - sameSite := "" - if cookie.SameSite != nil { - sameSite = strings.ToLower(string(*cookie.SameSite)) - } - result = append(result, cookieRecord{ - Name: cookie.Name, - Value: cookie.Value, - Domain: cookie.Domain, - Path: cookie.Path, - Expires: int64(cookie.Expires), - HttpOnly: cookie.HttpOnly, - Secure: cookie.Secure, - SameSite: sameSite, - }) - } - return result -} - -func toPlaywrightCookies(records []cookieRecord, targetURL string) []playwright.OptionalCookie { - if len(records) == 0 { - return nil - } - result := make([]playwright.OptionalCookie, 0, len(records)) - for _, record := range records { - if strings.TrimSpace(record.Name) == "" { - continue - } - cookie := playwright.OptionalCookie{ - Name: record.Name, - Value: record.Value, - } - domain := strings.TrimSpace(record.Domain) - path 
:= strings.TrimSpace(record.Path) - if domain != "" { - cookie.Domain = playwright.String(domain) - } else if strings.TrimSpace(targetURL) != "" { - cookie.URL = playwright.String(strings.TrimSpace(targetURL)) - } - if path != "" { - cookie.Path = playwright.String(path) - } - if record.Expires > 0 { - cookie.Expires = playwright.Float(float64(record.Expires)) - } - cookie.HttpOnly = playwright.Bool(record.HttpOnly) - cookie.Secure = playwright.Bool(record.Secure) - if sameSite := mapSameSite(record.SameSite); sameSite != nil { - cookie.SameSite = sameSite - } - result = append(result, cookie) - } - return result -} - -func mapSameSite(value string) *playwright.SameSiteAttribute { - normalized := strings.ToLower(strings.TrimSpace(value)) - switch normalized { - case "lax": - return playwright.SameSiteAttributeLax - case "strict": - return playwright.SameSiteAttributeStrict - case "none": - return playwright.SameSiteAttributeNone - default: - return nil - } +func decodeCookies(data string) []appcookies.Record { + return appcookies.DecodeJSON(data) } diff --git a/internal/application/connectors/service/export.go b/internal/application/connectors/service/export.go index 5033548..acf0e56 100644 --- a/internal/application/connectors/service/export.go +++ b/internal/application/connectors/service/export.go @@ -8,6 +8,7 @@ import ( "strconv" "strings" + appcookies "dreamcreator/internal/application/cookies" "dreamcreator/internal/domain/connectors" ) @@ -58,7 +59,7 @@ func (service *ConnectorsService) ExportConnectorCookies(ctx context.Context, id return path, nil } -func writeJSONCookies(path string, cookies []cookieRecord) error { +func writeJSONCookies(path string, cookies []appcookies.Record) error { data, err := json.MarshalIndent(cookies, "", " ") if err != nil { return err @@ -66,7 +67,7 @@ func writeJSONCookies(path string, cookies []cookieRecord) error { return writeFileAtomic(path, data, 0o600) } -func writeNetscapeCookies(path string, cookies []cookieRecord) 
error { +func writeNetscapeCookies(path string, cookies []appcookies.Record) error { builder := strings.Builder{} builder.WriteString("# Netscape HTTP Cookie File\n") builder.WriteString("# This file was generated by DreamCreator.\n") diff --git a/internal/application/connectors/service/login.go b/internal/application/connectors/service/login.go index 4a235ed..5f33033 100644 --- a/internal/application/connectors/service/login.go +++ b/internal/application/connectors/service/login.go @@ -2,66 +2,360 @@ package service import ( "context" + "errors" + "fmt" + "log" + "os" + "path/filepath" + "sort" "strings" "time" - "github.com/playwright-community/playwright-go" + cdptarget "github.com/chromedp/cdproto/target" + "github.com/chromedp/chromedp" + "dreamcreator/internal/application/browsercdp" "dreamcreator/internal/application/connectors/dto" + appcookies "dreamcreator/internal/application/cookies" + "dreamcreator/internal/application/sitepolicy" "dreamcreator/internal/domain/connectors" ) -func (service *ConnectorsService) ConnectConnector(ctx context.Context, request dto.ConnectConnectorRequest) (dto.Connector, error) { +func (service *ConnectorsService) StartConnectorConnect(ctx context.Context, request dto.StartConnectorConnectRequest) (dto.StartConnectorConnectResult, error) { id := strings.TrimSpace(request.ID) if id == "" { - return dto.Connector{}, connectors.ErrInvalidConnector + return dto.StartConnectorConnectResult{}, connectors.ErrInvalidConnector } connector, err := service.repo.Get(ctx, id) if err != nil { - return dto.Connector{}, err + return dto.StartConnectorConnectResult{}, err } targetURL, err := connectorHomeURL(connector.Type) if err != nil { - return dto.Connector{}, err + return dto.StartConnectorConnectResult{}, err } - cookies, err := runPlaywrightLogin(ctx, targetURL) + + sessionID := service.newSessionID() + userDataDir := connectorSessionDir(connector.Type, sessionID) + runtime, tabCtx, cancel, err := 
service.startBrowser(service.preferredBrowser(ctx), false, userDataDir) if err != nil { - return dto.Connector{}, err + return dto.StartConnectorConnectResult{}, err + } + if err := chromedp.Run(tabCtx, chromedp.Navigate(targetURL)); err != nil { + cancel() + runtime.Stop() + if service.removeAll != nil { + _ = service.removeAll(userDataDir) + } + return dto.StartConnectorConnectResult{}, err + } + + session := &connectorSession{ + ID: sessionID, + ConnectorID: connector.ID, + ConnectorType: connector.Type, + Runtime: runtime, + TabCtx: tabCtx, + Cancel: cancel, + UserDataDir: userDataDir, + State: connectorSessionStateRunning, + ConnectorSnapshot: mapConnectorDTO(connector), + finalizeDone: make(chan struct{}), + } + if current := chromedp.FromContext(tabCtx); current != nil && current.Target != nil { + session.TargetID = current.Target.TargetID } - cookiesJSON, err := encodeCookies(cookies) + + replaced := service.putSession(session) + service.cleanupSession(replaced) + service.startConnectSessionMonitor(sessionID) + log.Printf("connectors: started connect session id=%s connector=%s target=%s userDataDir=%s", sessionID, connector.Type, session.TargetID, userDataDir) + + return dto.StartConnectorConnectResult{ + SessionID: sessionID, + Connector: mapConnectorDTO(connector), + }, nil +} + +func (service *ConnectorsService) FinishConnectorConnect(ctx context.Context, request dto.FinishConnectorConnectRequest) (dto.FinishConnectorConnectResult, error) { + sessionID := strings.TrimSpace(request.SessionID) + if sessionID == "" { + return dto.FinishConnectorConnectResult{}, connectors.ErrConnectorSessionGone + } + result, _, err := service.finalizeConnectSession(ctx, sessionID, "manual_finish") if err != nil { - return dto.Connector{}, err + return dto.FinishConnectorConnectResult{}, err } + return result, nil +} - now := service.now() - status := connectors.StatusDisconnected - var lastVerifiedAt *time.Time - if len(cookies) > 0 { - status = connectors.StatusConnected 
- lastVerifiedAt = &now +func (service *ConnectorsService) CancelConnectorConnect(ctx context.Context, request dto.CancelConnectorConnectRequest) error { + sessionID := strings.TrimSpace(request.SessionID) + if sessionID == "" { + return connectors.ErrConnectorSessionGone + } + log.Printf("connectors: canceled connect session id=%s", sessionID) + service.cleanupSession(service.popSession(sessionID)) + return nil +} + +func (service *ConnectorsService) finalizeConnectSession(ctx context.Context, sessionID string, reason string) (dto.FinishConnectorConnectResult, bool, error) { + session, ok := service.getSession(sessionID) + if !ok || session == nil { + return dto.FinishConnectorConnectResult{}, false, connectors.ErrConnectorSessionGone + } + triggered := false + session.finalizeOnce.Do(func() { + triggered = true + result, err := service.performFinalize(ctx, session, reason) + service.mu.Lock() + defer service.mu.Unlock() + if err != nil { + session.State = connectorSessionStateFailed + session.FinalError = err.Error() + } else { + session.State = connectorSessionStateCompleted + session.FinalError = "" + session.FinalResult = &result + session.ConnectorSnapshot = result.Connector + } + close(session.finalizeDone) + }) + <-session.finalizeDone + + session, ok = service.getSession(sessionID) + if !ok || session == nil { + return dto.FinishConnectorConnectResult{}, triggered, connectors.ErrConnectorSessionGone + } + service.mu.Lock() + defer service.mu.Unlock() + if session.FinalError != "" { + return dto.FinishConnectorConnectResult{}, triggered, errors.New(session.FinalError) + } + if session.FinalResult == nil { + return dto.FinishConnectorConnectResult{}, triggered, connectors.ErrConnectorSessionDead + } + return *session.FinalResult, triggered, nil +} + +func (service *ConnectorsService) performFinalize(ctx context.Context, session *connectorSession, reason string) (dto.FinishConnectorConnectResult, error) { + if session == nil { + return 
dto.FinishConnectorConnectResult{}, connectors.ErrConnectorSessionGone + } + + log.Printf("connectors: finalize requested session=%s connector=%s reason=%s", session.ID, session.ConnectorType, reason) + records, err := readConnectorCookiesFromRuntime(session.Runtime) + if err != nil { + log.Printf("connectors: live cookie read failed session=%s connector=%s reason=%s err=%v", session.ID, session.ConnectorType, reason, err) + service.mu.Lock() + records = append([]appcookies.Record(nil), session.LastCookies...) + service.mu.Unlock() + } else { + service.updateSession(session.ID, func(current *connectorSession) { + current.LastCookies = append([]appcookies.Record(nil), records...) + current.LastCookiesAt = service.now() + }) + } + + policy, _ := sitepolicy.ForConnectorType(string(session.ConnectorType)) + filtered := appcookies.FilterByDomains(records, policy.Domains) + log.Printf("connectors: finalize cookies session=%s connector=%s reason=%s raw=%d filtered=%d domains=%s", session.ID, session.ConnectorType, reason, len(records), len(filtered), strings.Join(cookieDomains(filtered), ",")) + + current, err := service.repo.Get(ctx, session.ConnectorID) + if err != nil { + service.cleanupSession(session) + return dto.FinishConnectorConnectResult{}, err + } + + result := dto.FinishConnectorConnectResult{ + SessionID: session.ID, + Saved: len(filtered) > 0, + RawCookiesCount: len(records), + FilteredCookiesCount: len(filtered), + Domains: cookieDomains(filtered), + Reason: reason, + Connector: mapConnectorDTO(current), + } + if len(filtered) == 0 { + service.cleanupSession(session) + log.Printf("connectors: finalize completed without matching cookies session=%s connector=%s reason=%s", session.ID, session.ConnectorType, reason) + return result, nil + } + + cookiesJSON, err := encodeCookies(filtered) + if err != nil { + service.cleanupSession(session) + return dto.FinishConnectorConnectResult{}, err } + now := service.now() updated, err := 
connectors.NewConnector(connectors.ConnectorParams{ - ID: connector.ID, - Type: string(connector.Type), - Status: string(status), + ID: current.ID, + Type: string(current.Type), + Status: string(connectors.StatusConnected), CookiesJSON: cookiesJSON, - LastVerifiedAt: lastVerifiedAt, - CreatedAt: &connector.CreatedAt, + LastVerifiedAt: &now, + CreatedAt: &current.CreatedAt, UpdatedAt: &now, }) if err != nil { - return dto.Connector{}, err + service.cleanupSession(session) + return dto.FinishConnectorConnectResult{}, err } if err := service.repo.Save(ctx, updated); err != nil { - return dto.Connector{}, err + service.cleanupSession(session) + return dto.FinishConnectorConnectResult{}, err + } + result.Connector = mapConnectorDTO(updated) + service.cleanupSession(session) + log.Printf("connectors: finalize saved cookies session=%s connector=%s reason=%s filtered=%d", session.ID, session.ConnectorType, reason, len(filtered)) + return result, nil +} + +func (service *ConnectorsService) startConnectSessionMonitor(sessionID string) { + session, ok := service.getSession(sessionID) + if !ok || session == nil { + return } - return mapConnectorDTO(updated), nil + service.watchConnectSessionTarget(sessionID, session) + service.watchConnectSessionBrowser(sessionID, session) + + go func() { + ticker := time.NewTicker(1 * time.Second) + defer ticker.Stop() + + for { + session, ok := service.getSession(sessionID) + if !ok || session == nil { + return + } + service.mu.Lock() + state := session.State + runtime := session.Runtime + targetID := session.TargetID + tabCtx := session.TabCtx + service.mu.Unlock() + if state != connectorSessionStateRunning { + return + } + if runtime == nil || !runtime.Status().Ready { + service.triggerSessionFinalize(sessionID, "browser_closed") + return + } + if cookies, err := readConnectorCookiesFromRuntime(runtime); err == nil { + service.updateSession(sessionID, func(current *connectorSession) { + current.LastCookies = append([]appcookies.Record(nil),
cookies...) + current.LastCookiesAt = service.now() + }) + } + if targetID != "" { + exists, err := connectorTargetExists(runtime, targetID) + if err == nil && !exists { + service.triggerSessionFinalize(sessionID, "tab_closed") + return + } + } + var tabDone <-chan struct{} + if tabCtx != nil { + tabDone = tabCtx.Done() + } + var browserDone <-chan struct{} + if browserCtx := runtime.BrowserContext(); browserCtx != nil { + browserDone = browserCtx.Done() + } + + select { + case <-ticker.C: + case <-tabDone: + service.triggerSessionFinalize(sessionID, "tab_closed") + return + case <-browserDone: + service.triggerSessionFinalize(sessionID, "browser_closed") + return + } + } + }() +} + +func (service *ConnectorsService) watchConnectSessionTarget(sessionID string, session *connectorSession) { + if session == nil || session.TabCtx == nil { + return + } + targetID := session.TargetID + chromedp.ListenTarget(session.TabCtx, func(ev any) { + switch current := ev.(type) { + case *cdptarget.EventTargetDestroyed: + if targetID != "" && current.TargetID != targetID { + return + } + service.triggerSessionFinalize(sessionID, "tab_closed") + case *cdptarget.EventTargetCrashed: + if targetID != "" && current.TargetID != targetID { + return + } + service.triggerSessionFinalize(sessionID, "tab_closed") + case *cdptarget.EventDetachedFromTarget: + service.triggerSessionFinalize(sessionID, "tab_closed") + } + }) +} + +func (service *ConnectorsService) watchConnectSessionBrowser(sessionID string, session *connectorSession) { + if session == nil || session.Runtime == nil || session.Runtime.BrowserContext() == nil { + return + } + go func(browserCtx context.Context) { + <-browserCtx.Done() + service.triggerSessionFinalize(sessionID, "browser_closed") + }(session.Runtime.BrowserContext()) +} + +func (service *ConnectorsService) triggerSessionFinalize(sessionID string, reason string) { + go func() { + _, _, err := service.finalizeConnectSession(context.Background(), sessionID, reason) + if 
err != nil && !errors.Is(err, connectors.ErrConnectorSessionGone) { + log.Printf("connectors: auto-finalize failed session=%s reason=%s err=%v", sessionID, reason, err) + } + }() +} + +func connectorTargetExists(runtime *browsercdp.Runtime, targetID cdptarget.ID) (bool, error) { + if runtime == nil || targetID == "" { + return true, nil + } + timeoutCtx, cancel := context.WithTimeout(runtime.BrowserContext(), 3*time.Second) + defer cancel() + + var exists bool + if err := chromedp.Run(timeoutCtx, chromedp.ActionFunc(func(actionCtx context.Context) error { + targets, err := cdptarget.GetTargets().Do(actionCtx) + if err != nil { + return err + } + for _, info := range targets { + if info != nil && info.TargetID == targetID { + exists = true + break + } + } + return nil + })); err != nil { + return false, err + } + return exists, nil } func connectorHomeURL(connectorType connectors.ConnectorType) (string, error) { switch connectorType { case connectors.ConnectorGoogle: return "https://www.google.com/", nil + case connectors.ConnectorGitHub: + return "https://github.com/", nil + case connectors.ConnectorReddit: + return "https://www.reddit.com/", nil + case connectors.ConnectorZhihu: + return "https://www.zhihu.com/", nil + case connectors.ConnectorX: + return "https://x.com/", nil case connectors.ConnectorXiaohongshu: return "https://www.xiaohongshu.com/", nil case connectors.ConnectorBilibili: @@ -71,48 +365,167 @@ func connectorHomeURL(connectorType connectors.ConnectorType) (string, error) { } } -func runPlaywrightLogin(ctx context.Context, targetURL string) ([]cookieRecord, error) { - pw, err := playwright.Run() +func startConnectorBrowser(preferredBrowser string, headless bool, userDataDir string) (*browsercdp.Runtime, context.Context, context.CancelFunc, error) { + runtime, err := browsercdp.Start(context.Background(), browsercdp.LaunchOptions{ + PreferredBrowser: preferredBrowser, + Headless: headless, + UserDataDir: userDataDir, + }) if err != nil { - return 
nil, mapPlaywrightInstallError(err) + return nil, nil, nil, err } - defer pw.Stop() - - browser, err := pw.Chromium.Launch(playwright.BrowserTypeLaunchOptions{ - Headless: playwright.Bool(false), - }) + tabCtx, cancel, err := attachConnectorTab(runtime) if err != nil { - return nil, mapPlaywrightInstallError(err) + runtime.Stop() + return nil, nil, nil, err } - defer browser.Close() + return runtime, tabCtx, cancel, nil +} - browserCtx, err := browser.NewContext() - if err != nil { - return nil, err +func attachConnectorTab(runtime *browsercdp.Runtime) (context.Context, context.CancelFunc, error) { + if runtime == nil { + return nil, nil, connectors.ErrConnectorSessionDead } - page, err := browserCtx.NewPage() + targets, err := chromedp.Targets(runtime.BrowserContext()) if err != nil { + return nil, nil, err + } + + targetID := selectConnectorStartupTarget(targets) + var tabCtx context.Context + var cancel context.CancelFunc + if targetID != "" { + tabCtx, cancel = chromedp.NewContext(runtime.BrowserContext(), chromedp.WithTargetID(targetID)) + } else { + tabCtx, cancel = chromedp.NewContext(runtime.BrowserContext()) + } + if err := chromedp.Run(tabCtx); err != nil { + cancel() + return nil, nil, err + } + return tabCtx, cancel, nil +} + +func selectConnectorStartupTarget(targets []*cdptarget.Info) cdptarget.ID { + var fallback cdptarget.ID + for _, item := range targets { + if item == nil || item.Type != "page" || item.Attached { + continue + } + if isConnectorStartupBlank(item.URL) { + return item.TargetID + } + if fallback == "" { + fallback = item.TargetID + } + } + return fallback +} + +func isConnectorStartupBlank(targetURL string) bool { + trimmed := strings.TrimSpace(targetURL) + if trimmed == "" || trimmed == "about:blank" { + return true + } + return strings.HasPrefix(trimmed, "chrome://newtab") || + strings.HasPrefix(trimmed, "edge://newtab") || + strings.HasPrefix(trimmed, "brave://newtab") +} + +func readConnectorCookies(ctx context.Context) 
([]appcookies.Record, error) { + var records []appcookies.Record + if err := chromedp.Run(ctx, chromedp.ActionFunc(func(actionCtx context.Context) error { + items, err := browsercdp.GetAllCookies(actionCtx) + if err != nil { + return err + } + records = items + return nil + })); err != nil { return nil, err } - if _, err := page.Goto(targetURL); err != nil { + return records, nil +} + +func readConnectorCookiesFromRuntime(runtime *browsercdp.Runtime) ([]appcookies.Record, error) { + if runtime == nil { + return nil, connectors.ErrConnectorSessionDead + } + timeoutCtx, cancel := context.WithTimeout(runtime.BrowserContext(), 8*time.Second) + defer cancel() + + var records []appcookies.Record + if err := chromedp.Run(timeoutCtx, chromedp.ActionFunc(func(actionCtx context.Context) error { + items, err := browsercdp.GetStorageCookies(actionCtx) + if err != nil { + return err + } + records = items + return nil + })); err != nil { + if errors.Is(err, context.DeadlineExceeded) { + return nil, fmt.Errorf("connector cookie read timed out: %w", err) + } return nil, err } + return records, nil +} - closeTicker := time.NewTicker(500 * time.Millisecond) - defer closeTicker.Stop() +func waitForConnectorTabClose(ctx context.Context, runtime *browsercdp.Runtime, tabCtx context.Context, captureCookies bool, readCookies func(context.Context) ([]appcookies.Record, error)) ([]appcookies.Record, error) { + ticker := time.NewTicker(500 * time.Millisecond) + defer ticker.Stop() + var latest []appcookies.Record for { select { case <-ctx.Done(): - return nil, ctx.Err() - case <-closeTicker.C: - if page.IsClosed() { - cookies, err := browserCtx.Cookies() - if err != nil { - return nil, err + return latest, ctx.Err() + case <-ticker.C: + if captureCookies && readCookies != nil { + if cookies, err := readCookies(tabCtx); err == nil { + latest = cookies } - return cookiesFromPlaywright(cookies), nil } + if runtime == nil || !runtime.Status().Ready { + return latest, nil + } + var currentURL 
string + if err := chromedp.Run(tabCtx, chromedp.Location(&currentURL)); err != nil { + return latest, nil + } + } + } +} + +func connectorSessionDir(connectorType connectors.ConnectorType, sessionID string) string { + return filepath.Join(connectorSessionRootDir(), string(connectorType), sessionID) +} + +func connectorOpenDir(connectorType connectors.ConnectorType, sessionID string) string { + return filepath.Join(connectorSessionRootDir(), "open", string(connectorType), sessionID) +} + +func connectorSessionRootDir() string { + return filepath.Join(os.TempDir(), "dreamcreator", "connectors") +} + +func cookieDomains(records []appcookies.Record) []string { + if len(records) == 0 { + return nil + } + seen := make(map[string]struct{}, len(records)) + result := make([]string, 0, len(records)) + for _, record := range records { + domain := strings.TrimSpace(record.Domain) + if domain == "" { + continue + } + if _, ok := seen[domain]; ok { + continue } + seen[domain] = struct{}{} + result = append(result, domain) } + sort.Strings(result) + return result } diff --git a/internal/application/connectors/service/open.go b/internal/application/connectors/service/open.go index 15777c7..74688bf 100644 --- a/internal/application/connectors/service/open.go +++ b/internal/application/connectors/service/open.go @@ -3,10 +3,10 @@ package service import ( "context" "strings" - "time" - "github.com/playwright-community/playwright-go" + "github.com/chromedp/chromedp" + "dreamcreator/internal/application/browsercdp" "dreamcreator/internal/application/connectors/dto" "dreamcreator/internal/domain/connectors" ) @@ -28,52 +28,27 @@ func (service *ConnectorsService) OpenConnectorSite(ctx context.Context, request if err != nil { return err } - return runPlaywrightOpenWithCookies(ctx, targetURL, cookies) -} - -func runPlaywrightOpenWithCookies(ctx context.Context, targetURL string, cookies []cookieRecord) error { - pw, err := playwright.Run() - if err != nil { - return
mapPlaywrightInstallError(err) - } - defer pw.Stop() - - browser, err := pw.Chromium.Launch(playwright.BrowserTypeLaunchOptions{ - Headless: playwright.Bool(false), - }) - if err != nil { - return mapPlaywrightInstallError(err) - } - defer browser.Close() - - browserCtx, err := browser.NewContext() + userDataDir := connectorOpenDir(connector.Type, service.newSessionID()) + runtime, tabCtx, cancel, err := service.startBrowser(service.preferredBrowser(ctx), false, userDataDir) if err != nil { return err } - if len(cookies) > 0 { - if err := browserCtx.AddCookies(toPlaywrightCookies(cookies, targetURL)); err != nil { - return err + defer cancel() + defer runtime.Stop() + defer func() { + if service.removeAll != nil { + _ = service.removeAll(userDataDir) } - } - page, err := browserCtx.NewPage() - if err != nil { + }() + + if err := chromedp.Run(tabCtx, chromedp.ActionFunc(func(ctx context.Context) error { + return browsercdp.SetCookies(ctx, targetURL, cookies) + })); err != nil { return err } - if _, err := page.Goto(targetURL); err != nil { + if err := chromedp.Run(tabCtx, chromedp.Navigate(targetURL)); err != nil { return err } - - closeTicker := time.NewTicker(500 * time.Millisecond) - defer closeTicker.Stop() - - for { - select { - case <-ctx.Done(): - return ctx.Err() - case <-closeTicker.C: - if page.IsClosed() { - return nil - } - } - } + _, err = waitForConnectorTabClose(ctx, runtime, tabCtx, false, service.readCookies) + return err } diff --git a/internal/application/connectors/service/playwright.go b/internal/application/connectors/service/playwright.go deleted file mode 100644 index 414c63d..0000000 --- a/internal/application/connectors/service/playwright.go +++ /dev/null @@ -1,44 +0,0 @@ -package service - -import ( - "context" - "errors" - "strings" - - "github.com/playwright-community/playwright-go" -) - -var ErrPlaywrightNotInstalled = errors.New("playwright not installed") - -func (service *ConnectorsService) InstallPlaywright(_ context.Context) error 
{ - return playwright.Install(&playwright.RunOptions{ - Browsers: []string{"chromium"}, - }) -} - -func mapPlaywrightInstallError(err error) error { - if err == nil { - return nil - } - if isPlaywrightInstallError(err.Error()) { - return ErrPlaywrightNotInstalled - } - return err -} - -func isPlaywrightInstallError(message string) bool { - lowered := strings.ToLower(message) - if strings.Contains(lowered, "playwright not installed") { - return true - } - if strings.Contains(lowered, "please install") && (strings.Contains(lowered, "playwright") || strings.Contains(lowered, "driver")) { - return true - } - if strings.Contains(lowered, "driver exists but version not") { - return true - } - if strings.Contains(lowered, "executable doesn't exist") { - return true - } - return false -} diff --git a/internal/application/connectors/service/service.go b/internal/application/connectors/service/service.go index 60ee791..0df586f 100644 --- a/internal/application/connectors/service/service.go +++ b/internal/application/connectors/service/service.go @@ -2,33 +2,111 @@ package service import ( "context" + "os" "strings" + "sync" "time" + "github.com/chromedp/cdproto/target" "github.com/google/uuid" + "dreamcreator/internal/application/browsercdp" "dreamcreator/internal/application/connectors/dto" + appcookies "dreamcreator/internal/application/cookies" + settingsdto "dreamcreator/internal/application/settings/dto" "dreamcreator/internal/domain/connectors" ) +type SettingsReader interface { + GetSettings(ctx context.Context) (settingsdto.Settings, error) +} + type ConnectorsService struct { - repo connectors.Repository - now func() time.Time + repo connectors.Repository + settings SettingsReader + now func() time.Time + + mu sync.Mutex + sessions map[string]*connectorSession + sessionsByConnector map[string]string + startBrowser func(preferredBrowser string, headless bool, userDataDir string) (*browsercdp.Runtime, context.Context, context.CancelFunc, error) + readCookies func(ctx 
context.Context) ([]appcookies.Record, error) + removeAll func(path string) error + newSessionID func() string +} + +const ( + connectorSessionStateRunning = "running" + connectorSessionStateCompleted = "completed" + connectorSessionStateFailed = "failed" +) + +type connectorSession struct { + ID string + ConnectorID string + ConnectorType connectors.ConnectorType + Runtime *browsercdp.Runtime + TabCtx context.Context + Cancel context.CancelFunc + UserDataDir string + TargetID target.ID + State string + LastCookies []appcookies.Record + LastCookiesAt time.Time + FinalResult *dto.FinishConnectorConnectResult + FinalError string + ConnectorSnapshot dto.Connector + finalizeOnce sync.Once + finalizeDone chan struct{} } -func NewConnectorsService(repo connectors.Repository) *ConnectorsService { +func NewConnectorsService(repo connectors.Repository, settings SettingsReader) *ConnectorsService { return &ConnectorsService{ - repo: repo, - now: time.Now, + repo: repo, + settings: settings, + now: time.Now, + sessions: make(map[string]*connectorSession), + sessionsByConnector: make(map[string]string), + startBrowser: startConnectorBrowser, + readCookies: readConnectorCookies, + removeAll: os.RemoveAll, + newSessionID: uuid.NewString, } } +func (service *ConnectorsService) preferredBrowser(ctx context.Context) string { + if service == nil || service.settings == nil { + return "" + } + current, err := service.settings.GetSettings(ctx) + if err != nil { + return "" + } + if tools := current.Tools; tools != nil { + if browserRaw, ok := tools["browser"].(map[string]any); ok && browserRaw != nil { + if value, ok := browserRaw["preferredBrowser"].(string); ok && strings.TrimSpace(value) != "" { + return strings.ToLower(strings.TrimSpace(value)) + } + } + if fetchRaw, ok := tools["web_fetch"].(map[string]any); ok && fetchRaw != nil { + if value, ok := fetchRaw["preferredBrowser"].(string); ok && strings.TrimSpace(value) != "" { + return strings.ToLower(strings.TrimSpace(value)) + } 
+ } + } + return "" +} + func (service *ConnectorsService) EnsureDefaults(ctx context.Context) error { defaults := []struct { ID string Type connectors.ConnectorType }{ {ID: "connector-google", Type: connectors.ConnectorGoogle}, + {ID: "connector-github", Type: connectors.ConnectorGitHub}, + {ID: "connector-reddit", Type: connectors.ConnectorReddit}, + {ID: "connector-zhihu", Type: connectors.ConnectorZhihu}, + {ID: "connector-x", Type: connectors.ConnectorX}, {ID: "connector-xiaohongshu", Type: connectors.ConnectorXiaohongshu}, {ID: "connector-bilibili", Type: connectors.ConnectorBilibili}, } @@ -163,9 +241,148 @@ func (service *ConnectorsService) ClearConnector(ctx context.Context, request dt return service.repo.Save(ctx, updated) } +func (service *ConnectorsService) putSession(session *connectorSession) *connectorSession { + if service == nil || session == nil { + return nil + } + service.mu.Lock() + defer service.mu.Unlock() + + var replaced *connectorSession + if currentID, ok := service.sessionsByConnector[session.ConnectorID]; ok && currentID != "" { + replaced = service.sessions[currentID] + delete(service.sessions, currentID) + } + service.sessions[session.ID] = session + service.sessionsByConnector[session.ConnectorID] = session.ID + return replaced +} + +func (service *ConnectorsService) getSession(sessionID string) (*connectorSession, bool) { + if service == nil { + return nil, false + } + service.mu.Lock() + defer service.mu.Unlock() + session, ok := service.sessions[sessionID] + return session, ok +} + +func (service *ConnectorsService) updateSession(sessionID string, update func(session *connectorSession)) (*connectorSession, bool) { + if service == nil { + return nil, false + } + service.mu.Lock() + defer service.mu.Unlock() + session, ok := service.sessions[sessionID] + if !ok || session == nil { + return nil, false + } + update(session) + return session, true +} + +func (service *ConnectorsService) popSession(sessionID string) *connectorSession { 
+ if service == nil { + return nil + } + service.mu.Lock() + defer service.mu.Unlock() + + session := service.sessions[sessionID] + if session == nil { + return nil + } + delete(service.sessions, sessionID) + if currentID, ok := service.sessionsByConnector[session.ConnectorID]; ok && currentID == sessionID { + delete(service.sessionsByConnector, session.ConnectorID) + } + return session +} + +func (service *ConnectorsService) cleanupSession(session *connectorSession) { + if session == nil { + return + } + if session.Cancel != nil { + session.Cancel() + } + if session.Runtime != nil { + session.Runtime.Stop() + } + if service.removeAll != nil && strings.TrimSpace(session.UserDataDir) != "" { + _ = service.removeAll(session.UserDataDir) + } +} + +func (service *ConnectorsService) GetConnectorConnectSession(ctx context.Context, request dto.GetConnectorConnectSessionRequest) (dto.ConnectorConnectSession, error) { + sessionID := strings.TrimSpace(request.SessionID) + if sessionID == "" { + return dto.ConnectorConnectSession{}, connectors.ErrConnectorSessionGone + } + session, ok := service.getSession(sessionID) + if !ok { + return dto.ConnectorConnectSession{}, connectors.ErrConnectorSessionGone + } + return service.snapshotSession(ctx, session), nil +} + +func (service *ConnectorsService) snapshotSession(ctx context.Context, session *connectorSession) dto.ConnectorConnectSession { + if session == nil { + return dto.ConnectorConnectSession{} + } + service.mu.Lock() + snapshotID := session.ID + snapshotConnectorID := session.ConnectorID + snapshotState := session.State + snapshotLastCookiesAt := session.LastCookiesAt + snapshotFinalError := session.FinalError + snapshotConnector := session.ConnectorSnapshot + var snapshotFinalResult *dto.FinishConnectorConnectResult + if session.FinalResult != nil { + copyResult := *session.FinalResult + copyResult.Domains = append([]string(nil), session.FinalResult.Domains...) 
+ snapshotFinalResult = &copyResult + } + service.mu.Unlock() + + connector := snapshotConnector + if snapshotFinalResult != nil { + connector = snapshotFinalResult.Connector + } else if current, err := service.repo.Get(ctx, snapshotConnectorID); err == nil { + connector = mapConnectorDTO(current) + } + lastCookiesAt := "" + if !snapshotLastCookiesAt.IsZero() { + lastCookiesAt = snapshotLastCookiesAt.Format(time.RFC3339) + } + result := dto.ConnectorConnectSession{ + SessionID: snapshotID, + ConnectorID: snapshotConnectorID, + State: snapshotState, + Error: snapshotFinalError, + LastCookiesAt: lastCookiesAt, + Connector: connector, + } + if snapshotFinalResult != nil { + result.Saved = snapshotFinalResult.Saved + result.RawCookiesCount = snapshotFinalResult.RawCookiesCount + result.FilteredCookiesCount = snapshotFinalResult.FilteredCookiesCount + result.Domains = append([]string(nil), snapshotFinalResult.Domains...) + result.Reason = snapshotFinalResult.Reason + } + return result +} + func isSupportedConnectorType(connectorType connectors.ConnectorType) bool { switch connectorType { - case connectors.ConnectorGoogle, connectors.ConnectorXiaohongshu, connectors.ConnectorBilibili: + case connectors.ConnectorGoogle, + connectors.ConnectorGitHub, + connectors.ConnectorReddit, + connectors.ConnectorZhihu, + connectors.ConnectorX, + connectors.ConnectorXiaohongshu, + connectors.ConnectorBilibili: return true default: return false diff --git a/internal/application/cookies/cookies.go b/internal/application/cookies/cookies.go new file mode 100644 index 0000000..0a3631e --- /dev/null +++ b/internal/application/cookies/cookies.go @@ -0,0 +1,105 @@ +package cookies + +import ( + "encoding/json" + "net/url" + "strings" + + "dreamcreator/internal/application/sitepolicy" +) + +type Record struct { + Name string `json:"name"` + Value string `json:"value"` + Domain string `json:"domain"` + Path string `json:"path"` + Expires int64 `json:"expires"` + HttpOnly bool `json:"httpOnly"` +
Secure bool `json:"secure"` + SameSite string `json:"sameSite,omitempty"` +} + +func EncodeJSON(records []Record) (string, error) { + if len(records) == 0 { + return "", nil + } + data, err := json.Marshal(records) + if err != nil { + return "", err + } + return string(data), nil +} + +func DecodeJSON(data string) []Record { + trimmed := strings.TrimSpace(data) + if trimmed == "" { + return nil + } + var records []Record + if err := json.Unmarshal([]byte(trimmed), &records); err != nil { + return nil + } + return records +} + +func FilterByDomains(records []Record, domains []string) []Record { + if len(records) == 0 || len(domains) == 0 { + return nil + } + result := make([]Record, 0, len(records)) + for _, record := range records { + domain := strings.TrimPrefix(strings.ToLower(strings.TrimSpace(record.Domain)), ".") + if domain == "" { + continue + } + for _, allowed := range domains { + if sitepolicy.HostMatchesDomain(domain, allowed) { + result = append(result, normalizeRecord(record)) + break + } + } + } + return result +} + +func MatchURL(records []Record, rawURL string) []Record { + if len(records) == 0 { + return nil + } + parsed, err := url.Parse(strings.TrimSpace(rawURL)) + if err != nil { + return nil + } + host := strings.ToLower(strings.TrimSpace(parsed.Hostname())) + path := parsed.EscapedPath() + if path == "" { + path = "/" + } + result := make([]Record, 0, len(records)) + for _, record := range records { + domain := strings.TrimSpace(record.Domain) + if domain != "" && !sitepolicy.HostMatchesDomain(host, domain) { + continue + } + cookiePath := strings.TrimSpace(record.Path) + if cookiePath == "" { + cookiePath = "/" + } + if !strings.HasPrefix(path, cookiePath) { + continue + } + result = append(result, normalizeRecord(record)) + } + return result +} + +func normalizeRecord(record Record) Record { + record.Name = strings.TrimSpace(record.Name) + record.Domain = strings.TrimSpace(record.Domain) + record.Path = strings.TrimSpace(record.Path) + 
record.SameSite = strings.ToLower(strings.TrimSpace(record.SameSite)) + if record.Path == "" { + record.Path = "/" + } + return record +} diff --git a/internal/application/externaltools/service/service.go b/internal/application/externaltools/service/service.go index 2d0e908..90e5901 100644 --- a/internal/application/externaltools/service/service.go +++ b/internal/application/externaltools/service/service.go @@ -20,7 +20,6 @@ import ( "time" "github.com/google/uuid" - "github.com/playwright-community/playwright-go" "dreamcreator/internal/application/externaltools/dto" "dreamcreator/internal/application/softwareupdate" @@ -70,16 +69,9 @@ var ( SourceRef: "clawhub", Manager: toolManagerBun, }, - externaltools.ToolPlaywright: { - ToolKind: string(externaltools.KindRuntime), - Kind: sourceKindRuntime, - SourceRef: "playwright-community/playwright-go", - }, } - semverTokenPattern = regexp.MustCompile(`^v?\d+\.\d+\.\d+(?:[-+][0-9A-Za-z.-]+)?$`) - playwrightVersionTokenPattern = regexp.MustCompile(`^\d+(?:\.\d+){2,}(?:[-+][0-9A-Za-z.-]+)?$`) - playwrightVersionScanPattern = regexp.MustCompile(`\d+(?:\.\d+){2,}(?:[-+][0-9A-Za-z.-]+)?`) + semverTokenPattern = regexp.MustCompile(`^v?\d+\.\d+\.\d+(?:[-+][0-9A-Za-z.-]+)?$`) ) type ExternalToolsService struct { @@ -163,7 +155,6 @@ func (service *ExternalToolsService) EnsureDefaults(ctx context.Context) error { externaltools.ToolFFmpeg, externaltools.ToolBun, externaltools.ToolClawHub, - externaltools.ToolPlaywright, } existing, err := service.repo.List(ctx) if err != nil { @@ -301,17 +292,8 @@ func (service *ExternalToolsService) InstallTool(ctx context.Context, request dt case sourceKindNPMRegistry: return service.installNPMRegistryTool(ctx, toolName, source, request.Version, manager) case sourceKindRuntime: - if manager != "" { - service.setInstallState(toolName, installStageError, downloadProgressStart, "manager is unsupported for this tool") - return dto.ExternalTool{}, fmt.Errorf("manager is unsupported for tool %s", 
toolName) - } - switch toolName { - case externaltools.ToolPlaywright: - return service.installPlaywrightRuntime(ctx) - default: - service.setInstallState(toolName, installStageError, downloadProgressStart, "invalid tool") - return dto.ExternalTool{}, externaltools.ErrInvalidTool - } + service.setInstallState(toolName, installStageError, downloadProgressStart, "runtime tools are not supported") + return dto.ExternalTool{}, externaltools.ErrInvalidTool default: service.setInstallState(toolName, installStageError, downloadProgressStart, "unsupported source") return dto.ExternalTool{}, fmt.Errorf("unsupported source for tool %s", toolName) @@ -408,14 +390,10 @@ func (service *ExternalToolsService) RemoveTool(ctx context.Context, request dto if name == "" { return externaltools.ErrInvalidTool } - toolName := externaltools.ToolName(name) tool, err := service.repo.Get(ctx, name) if err != nil && err != externaltools.ErrToolNotFound { return err } - if toolName == externaltools.ToolPlaywright { - _ = uninstallPlaywrightRuntime() - } if err == nil && tool.ExecPath != "" { _ = os.RemoveAll(filepath.Dir(tool.ExecPath)) } @@ -1192,134 +1170,6 @@ func (service *ExternalToolsService) installNPMRegistryTool(ctx context.Context, return toExternalToolDTO(tool), nil } -func (service *ExternalToolsService) installPlaywrightRuntime(ctx context.Context) (dto.ExternalTool, error) { - service.setInstallState(externaltools.ToolPlaywright, installStageDownloading, downloadProgressStart, "") - stopProgress := make(chan struct{}) - var stopProgressOnce sync.Once - stopProgressTicker := func() { - stopProgressOnce.Do(func() { - close(stopProgress) - }) - } - defer stopProgressTicker() - - go func() { - ticker := time.NewTicker(350 * time.Millisecond) - defer ticker.Stop() - progress := downloadProgressStart - ceiling := downloadProgressEnd - 2 - for { - select { - case <-stopProgress: - return - case <-ctx.Done(): - return - case <-ticker.C: - if progress >= ceiling { - continue - } - 
progress += 2 - if progress > ceiling { - progress = ceiling - } - service.setInstallState(externaltools.ToolPlaywright, installStageDownloading, progress, "") - } - } - }() - - if err := playwright.Install(&playwright.RunOptions{ - Browsers: []string{"chromium"}, - Verbose: false, - Stdout: io.Discard, - Stderr: io.Discard, - }); err != nil { - stopProgressTicker() - service.setInstallState(externaltools.ToolPlaywright, installStageError, downloadProgressStart, err.Error()) - return dto.ExternalTool{}, err - } - stopProgressTicker() - service.setInstallState(externaltools.ToolPlaywright, installStageDownloading, downloadProgressEnd, "") - service.setInstallState(externaltools.ToolPlaywright, installStageVerifying, verifyProgressStart, "") - execPath, version, err := resolvePlaywrightRuntime(ctx) - if err != nil { - service.setInstallState(externaltools.ToolPlaywright, installStageError, verifyProgressStart, err.Error()) - return dto.ExternalTool{}, err - } - now := service.now() - tool, err := externaltools.NewExternalTool(externaltools.ExternalToolParams{ - Name: string(externaltools.ToolPlaywright), - ExecPath: execPath, - Version: version, - Status: string(externaltools.StatusInstalled), - InstalledAt: &now, - UpdatedAt: &now, - }) - if err != nil { - return dto.ExternalTool{}, err - } - if err := service.repo.Save(ctx, tool); err != nil { - service.setInstallState(externaltools.ToolPlaywright, installStageError, verifyProgressEnd, err.Error()) - return dto.ExternalTool{}, err - } - service.setInstallState(externaltools.ToolPlaywright, installStageDone, verifyProgressEnd, "") - return toExternalToolDTO(tool), nil -} - -func resolvePlaywrightRuntime(ctx context.Context) (string, string, error) { - pw, err := playwright.Run(&playwright.RunOptions{ - Verbose: false, - Stdout: io.Discard, - Stderr: io.Discard, - SkipInstallBrowsers: true, - }) - if err != nil { - return "", "", err - } - defer pw.Stop() - execPath := strings.TrimSpace(pw.Chromium.ExecutablePath()) 
- if execPath == "" { - return "", "", fmt.Errorf("playwright chromium executable path is empty") - } - browser, err := pw.Chromium.Launch(playwright.BrowserTypeLaunchOptions{ - Headless: playwright.Bool(true), - Args: []string{"--headless=new"}, - }) - if err != nil { - return "", "", err - } - browserVersion := strings.TrimSpace(browser.Version()) - if closeErr := browser.Close(); closeErr != nil { - return "", "", closeErr - } - version := browserVersion - if version != "" { - if parsedVersion, parseErr := parsePlaywrightVersion(version); parseErr == nil { - version = parsedVersion - } - } - if strings.TrimSpace(version) == "" { - resolvedVersion, err := resolveVersion(ctx, externaltools.ToolPlaywright, execPath) - if err != nil { - return execPath, "", nil - } - version = resolvedVersion - } - return execPath, strings.TrimSpace(version), nil -} - -func uninstallPlaywrightRuntime() error { - driver, err := playwright.NewDriver(&playwright.RunOptions{ - Verbose: false, - Stdout: io.Discard, - Stderr: io.Discard, - SkipInstallBrowsers: true, - }) - if err != nil { - return err - } - return driver.Uninstall() -} - func executableNameForBinary(name string) string { trimmed := strings.TrimSpace(name) if trimmed == "" { @@ -2292,8 +2142,6 @@ func resolveVersion(ctx context.Context, name externaltools.ToolName, execPath s args = []string{"-version"} case externaltools.ToolClawHub: args = []string{"--cli-version"} - case externaltools.ToolPlaywright: - args = []string{"--version"} default: args = []string{"--version"} } @@ -2319,8 +2167,6 @@ func resolveVersion(ctx context.Context, name externaltools.ToolName, execPath s return parseFFmpegVersion(text) case externaltools.ToolClawHub: return parseClawHubVersion(text) - case externaltools.ToolPlaywright: - return parsePlaywrightVersion(text) default: return strings.Fields(text)[0], nil } @@ -2350,20 +2196,6 @@ func parseClawHubVersion(output string) (string, error) { return "", fmt.Errorf("clawhub version not found") } 
-func parsePlaywrightVersion(output string) (string, error) { - for _, token := range strings.Fields(output) { - candidate := strings.Trim(strings.TrimSpace(token), ",;:()[]{}") - candidate = strings.TrimPrefix(strings.TrimPrefix(candidate, "v"), "V") - if playwrightVersionTokenPattern.MatchString(candidate) { - return candidate, nil - } - } - if matched := strings.TrimSpace(playwrightVersionScanPattern.FindString(output)); matched != "" { - return matched, nil - } - return "", fmt.Errorf("playwright version not found") -} - func percent(written int64, total int64) int { if total <= 0 { return 0 diff --git a/internal/application/externaltools/service/service_test.go b/internal/application/externaltools/service/service_test.go index 2bf06e9..1f3a48d 100644 --- a/internal/application/externaltools/service/service_test.go +++ b/internal/application/externaltools/service/service_test.go @@ -60,7 +60,6 @@ func TestEnsureDefaultsIncludesBunAndClawHub(t *testing.T) { externaltools.ToolFFmpeg, externaltools.ToolBun, externaltools.ToolClawHub, - externaltools.ToolPlaywright, } { if _, err := repo.Get(context.Background(), string(name)); err != nil { t.Fatalf("expected default tool %s: %v", name, err) @@ -291,14 +290,14 @@ func TestToExternalToolDTOUsesManagedVersionFromPath(t *testing.T) { } } -func TestToExternalToolDTOIncludesRuntimeKind(t *testing.T) { +func TestToExternalToolDTOIncludesSourceMetadata(t *testing.T) { t.Parallel() now := time.Now() item, err := externaltools.NewExternalTool(externaltools.ExternalToolParams{ - Name: string(externaltools.ToolPlaywright), - ExecPath: "/tmp/chromium", - Version: "136.0.7103.25", + Name: string(externaltools.ToolClawHub), + ExecPath: "/tmp/clawhub", + Version: "1.2.3", Status: string(externaltools.StatusInstalled), UpdatedAt: &now, }) @@ -306,30 +305,16 @@ func TestToExternalToolDTOIncludesRuntimeKind(t *testing.T) { t.Fatalf("new external tool failed: %v", err) } result := toExternalToolDTO(item) - if result.Kind != 
string(externaltools.KindRuntime) { - t.Fatalf("expected runtime kind, got %q", result.Kind) + if result.Kind != string(externaltools.KindBin) { + t.Fatalf("expected bin kind, got %q", result.Kind) } - if result.SourceKind != sourceKindRuntime { - t.Fatalf("expected runtime source kind, got %q", result.SourceKind) + if result.SourceKind != sourceKindNPMRegistry { + t.Fatalf("expected npm registry source kind, got %q", result.SourceKind) } -} - -func TestParsePlaywrightVersion(t *testing.T) { - t.Parallel() - - version, err := parsePlaywrightVersion("Chromium 136.0.7103.25") - if err != nil { - t.Fatalf("parse playwright version failed: %v", err) - } - if version != "136.0.7103.25" { - t.Fatalf("unexpected playwright version: %s", version) - } - - scanned, err := parsePlaywrightVersion("HeadlessChrome/136.0.7103.25") - if err != nil { - t.Fatalf("parse playwright scanned version failed: %v", err) + if result.SourceRef != "clawhub" { + t.Fatalf("expected clawhub source ref, got %q", result.SourceRef) } - if scanned != "136.0.7103.25" { - t.Fatalf("unexpected scanned playwright version: %s", scanned) + if result.Manager != toolManagerBun { + t.Fatalf("expected bun manager, got %q", result.Manager) } } diff --git a/internal/application/gateway/approvals/publisher.go b/internal/application/gateway/approvals/publisher.go index 42158ec..44de551 100644 --- a/internal/application/gateway/approvals/publisher.go +++ b/internal/application/gateway/approvals/publisher.go @@ -6,6 +6,7 @@ import ( "time" gatewayevents "dreamcreator/internal/application/gateway/events" + domainsession "dreamcreator/internal/domain/session" ) type GatewayEventPublisher struct { @@ -20,11 +21,11 @@ func (publisher *GatewayEventPublisher) Publish(ctx context.Context, eventType s if publisher == nil || publisher.events == nil { return nil } - sessionKey := resolveSessionKey(payload) + sessionID, sessionKey := resolveSessionIdentity(payload) envelope := gatewayevents.Envelope{ Type: eventType, Topic: 
"exec.approval", - SessionID: sessionKey, + SessionID: sessionID, SessionKey: sessionKey, Timestamp: time.Now(), } @@ -32,16 +33,27 @@ func (publisher *GatewayEventPublisher) Publish(ctx context.Context, eventType s return err } -func resolveSessionKey(payload any) string { +func resolveSessionIdentity(payload any) (string, string) { + sessionKey := "" switch value := payload.(type) { case Request: - return strings.TrimSpace(value.SessionKey) + sessionKey = strings.TrimSpace(value.SessionKey) case *Request: if value == nil { - return "" + return "", "" } - return strings.TrimSpace(value.SessionKey) + sessionKey = strings.TrimSpace(value.SessionKey) default: - return "" + return "", "" } + if parts, _, err := domainsession.NormalizeSessionKey(sessionKey); err == nil { + sessionID := strings.TrimSpace(parts.ThreadRef) + if sessionID == "" { + sessionID = strings.TrimSpace(parts.PrimaryID) + } + if sessionID != "" { + return sessionID, sessionKey + } + } + return sessionKey, sessionKey } diff --git a/internal/application/gateway/approvals/publisher_test.go b/internal/application/gateway/approvals/publisher_test.go index e5e0a4f..42cc0d6 100644 --- a/internal/application/gateway/approvals/publisher_test.go +++ b/internal/application/gateway/approvals/publisher_test.go @@ -6,6 +6,7 @@ import ( "time" gatewayevents "dreamcreator/internal/application/gateway/events" + domainsession "dreamcreator/internal/domain/session" ) func TestGatewayEventPublisherIncludesSessionInEnvelope(t *testing.T) { @@ -18,9 +19,17 @@ func TestGatewayEventPublisherIncludesSessionInEnvelope(t *testing.T) { }) defer unsubscribe() + sessionKey, err := domainsession.BuildSessionKey(domainsession.KeyParts{ + Channel: "aui", + PrimaryID: "thread-123", + ThreadRef: "thread-123", + }) + if err != nil { + t.Fatalf("build session key: %v", err) + } request := Request{ ID: "approval-1", - SessionKey: "thread-123", + SessionKey: sessionKey, ToolName: "exec", Action: "config.schema", } @@ -30,8 +39,8 @@ func 
TestGatewayEventPublisherIncludesSessionInEnvelope(t *testing.T) { select { case record := <-recordCh: - if record.Envelope.SessionKey != "thread-123" { - t.Fatalf("expected session key thread-123, got %q", record.Envelope.SessionKey) + if record.Envelope.SessionKey != sessionKey { + t.Fatalf("expected session key %q, got %q", sessionKey, record.Envelope.SessionKey) } if record.Envelope.SessionID != "thread-123" { t.Fatalf("expected session id thread-123, got %q", record.Envelope.SessionID) @@ -40,3 +49,36 @@ func TestGatewayEventPublisherIncludesSessionInEnvelope(t *testing.T) { t.Fatal("expected approval event record") } } + +func TestGatewayEventPublisherFallsBackToRawSessionKey(t *testing.T) { + events := gatewayevents.NewBroker(nil) + publisher := NewGatewayEventPublisher(events) + + recordCh := make(chan gatewayevents.Record, 1) + unsubscribe := events.Subscribe(gatewayevents.Filter{Type: "exec.approval.requested"}, func(record gatewayevents.Record) { + recordCh <- record + }) + defer unsubscribe() + + request := Request{ + ID: "approval-2", + SessionKey: "raw-thread-id", + ToolName: "exec", + Action: "config.schema", + } + if err := publisher.Publish(context.Background(), "exec.approval.requested", request); err != nil { + t.Fatalf("publish approval event: %v", err) + } + + select { + case record := <-recordCh: + if record.Envelope.SessionKey != "raw-thread-id" { + t.Fatalf("expected raw session key, got %q", record.Envelope.SessionKey) + } + if record.Envelope.SessionID != "raw-thread-id" { + t.Fatalf("expected raw session id, got %q", record.Envelope.SessionID) + } + case <-time.After(time.Second): + t.Fatal("expected approval event record") + } +} diff --git a/internal/application/gateway/runtime/prompt_build.go b/internal/application/gateway/runtime/prompt_build.go index 1ffb13d..e5ffeab 100644 --- a/internal/application/gateway/runtime/prompt_build.go +++ b/internal/application/gateway/runtime/prompt_build.go @@ -305,7 +305,7 @@ func 
formatSkillsSection(skills []skillsdto.SkillPromptItem) string { if len(skills) == 0 { lines = append(lines, "- (none currently eligible)", - "- Use `skill_manage.search` to discover skills, `skill_manage.install` to install, then `skills.status` to refresh.", + "- Use `skills_manage.search` to discover skills, `skills_manage.install` to install, then `skills.status` to refresh.", ) return joinLines(lines) } @@ -348,8 +348,8 @@ func skillsSectionPreambleLines() []string { "Skills protocol:", "- First scan the available skills list.", "- If one skill clearly matches, read its `SKILL.md` before acting.", - "- If no skill currently matches the task, call `skill_manage` instead of shell commands.", - "- Recommended flow: `skills.status` -> `skill_manage.search`/`skill_manage.install` -> `skills.status`.", + "- If no skill currently matches the task, call `skills_manage` instead of shell commands.", + "- Recommended flow: `skills.status` -> `skills_manage.search`/`skills_manage.install` -> `skills.status`.", "- If status reports missing runtime dependencies, call `skills.install`, then re-check with `skills.status`.", "- Use `skills.update` for per-skill settings (enabled/apiKey/env/config) when required.", "Available skills:", diff --git a/internal/application/gateway/runtime/prompt_build_test.go b/internal/application/gateway/runtime/prompt_build_test.go index 6c1a210..da91651 100644 --- a/internal/application/gateway/runtime/prompt_build_test.go +++ b/internal/application/gateway/runtime/prompt_build_test.go @@ -233,7 +233,7 @@ func TestBuildPromptDocument_FullIncludesToolingToolCallStyleAndSkillsMandatory( if !strings.Contains(skills, "## Skills (mandatory)") { t.Fatalf("expected mandatory skills heading, got %q", skills) } - if !strings.Contains(skills, "Recommended flow: `skills.status` -> `skill_manage.search`/`skill_manage.install` -> `skills.status`.") { + if !strings.Contains(skills, "Recommended flow: `skills.status` -> 
`skills_manage.search`/`skills_manage.install` -> `skills.status`.") { t.Fatalf("expected skills tool protocol rule, got %q", skills) } } diff --git a/internal/application/gateway/runtime/service.go b/internal/application/gateway/runtime/service.go index 3fadc2f..43fe176 100644 --- a/internal/application/gateway/runtime/service.go +++ b/internal/application/gateway/runtime/service.go @@ -404,6 +404,13 @@ func (service *Service) runWithStream(ctx context.Context, request dto.RuntimeRu defer service.aborts.Unregister(run.ID) } defer cancel() + if service.tools != nil && strings.TrimSpace(sessionKey) != "" { + defer func() { + cleanupCtx, cleanupCancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cleanupCancel() + service.tools.CleanupRuntimeSession(cleanupCtx, sessionKey) + }() + } if service.queue != nil && flags.UseQueue { ticket, _, err := service.queue.Enqueue(runCtx, queue.EnqueueRequest{ diff --git a/internal/application/gateway/runtime/service_session_test.go b/internal/application/gateway/runtime/service_session_test.go index 34e63a6..e54164a 100644 --- a/internal/application/gateway/runtime/service_session_test.go +++ b/internal/application/gateway/runtime/service_session_test.go @@ -9,6 +9,15 @@ import ( sessionapp "dreamcreator/internal/application/session" ) +const ( + testTelegramSessionPeerID = "test-user-001" + testTelegramSessionThreadID = "test-thread-001" + testTelegramSessionPeerName = "Test User" + testTelegramSessionPeerNameAlt = "Test User Updated" + testTelegramSessionUsername = "testuser" + testTelegramSessionAvatarURL = "https://example.com/avatars/test-user-001.jpg" +) + func TestResolveSession_RebuildsCanonicalKeyForCustomChannelSessionKey(t *testing.T) { t.Parallel() @@ -18,16 +27,16 @@ func TestResolveSession_RebuildsCanonicalKeyForCustomChannelSessionKey(t *testin } request := runtimedto.RuntimeRunRequest{ - SessionID: "telegram:default:private:5234834060:conv:26291424", - SessionKey: 
"telegram:default:private:5234834060:conv:26291424", + SessionID: "telegram:default:private:" + testTelegramSessionPeerID + ":conv:" + testTelegramSessionThreadID, + SessionKey: "telegram:default:private:" + testTelegramSessionPeerID + ":conv:" + testTelegramSessionThreadID, Metadata: map[string]any{ "channel": "telegram", "accountId": "default", "peerKind": "direct", - "peerId": "5234834060", - "peerName": "Arnold", - "peerUsername": "arnold", - "peerAvatarUrl": "https://t.me/i/userpic/320/arnold.jpg", + "peerId": testTelegramSessionPeerID, + "peerName": testTelegramSessionPeerName, + "peerUsername": testTelegramSessionUsername, + "peerAvatarUrl": testTelegramSessionAvatarURL, }, } @@ -55,13 +64,13 @@ func TestResolveSession_RebuildsCanonicalKeyForCustomChannelSessionKey(t *testin if stored.Origin.AccountID != "default" { t.Fatalf("origin accountId mismatch: %q", stored.Origin.AccountID) } - if stored.Origin.PeerID != "5234834060" { + if stored.Origin.PeerID != testTelegramSessionPeerID { t.Fatalf("origin peerId mismatch: %q", stored.Origin.PeerID) } - if stored.Origin.PeerName != "Arnold" { + if stored.Origin.PeerName != testTelegramSessionPeerName { t.Fatalf("origin peerName mismatch: %q", stored.Origin.PeerName) } - if stored.Origin.PeerUsername != "arnold" { + if stored.Origin.PeerUsername != testTelegramSessionUsername { t.Fatalf("origin peerUsername mismatch: %q", stored.Origin.PeerUsername) } if stored.Origin.PeerAvatarURL == "" { @@ -78,15 +87,15 @@ func TestPersistSession_UpdatesAssistantIDForExistingSession(t *testing.T) { } request := runtimedto.RuntimeRunRequest{ - SessionID: "telegram:default:private:5234834060", - SessionKey: "telegram:default:private:5234834060", + SessionID: "telegram:default:private:" + testTelegramSessionPeerID, + SessionKey: "telegram:default:private:" + testTelegramSessionPeerID, Metadata: map[string]any{ "channel": "telegram", "accountId": "default", "peerKind": "direct", - "peerId": "5234834060", - "peerName": "Arnold HAO", - 
"peerUsername": "arnold", + "peerId": testTelegramSessionPeerID, + "peerName": testTelegramSessionPeerNameAlt, + "peerUsername": testTelegramSessionUsername, }, } sessionID, sessionKey, err := runtimeService.resolveSession(request) @@ -110,7 +119,7 @@ func TestPersistSession_UpdatesAssistantIDForExistingSession(t *testing.T) { if after.AssistantID != "assistant-123" { t.Fatalf("assistant id was not updated: got %q", after.AssistantID) } - if after.Origin.PeerName != "Arnold HAO" { + if after.Origin.PeerName != testTelegramSessionPeerNameAlt { t.Fatalf("origin peer name mismatch: got %q", after.Origin.PeerName) } } @@ -124,16 +133,16 @@ func TestPersistSession_PreservesExistingOriginWhenIncomingMetadataIsIncomplete( } request := runtimedto.RuntimeRunRequest{ - SessionID: "telegram:default:private:5234834060", - SessionKey: "telegram:default:private:5234834060", + SessionID: "telegram:default:private:" + testTelegramSessionPeerID, + SessionKey: "telegram:default:private:" + testTelegramSessionPeerID, Metadata: map[string]any{ "channel": "telegram", "accountId": "default", "peerKind": "direct", - "peerId": "5234834060", - "peerName": "Arnold HAO", - "peerUsername": "arnold", - "peerAvatarUrl": "https://t.me/i/userpic/320/arnold.jpg", + "peerId": testTelegramSessionPeerID, + "peerName": testTelegramSessionPeerNameAlt, + "peerUsername": testTelegramSessionUsername, + "peerAvatarUrl": testTelegramSessionAvatarURL, }, } sessionID, sessionKey, err := runtimeService.resolveSession(request) @@ -157,10 +166,10 @@ func TestPersistSession_PreservesExistingOriginWhenIncomingMetadataIsIncomplete( if stored.Origin.AccountID != "default" { t.Fatalf("origin accountId should be preserved: got %q", stored.Origin.AccountID) } - if stored.Origin.PeerID != "5234834060" { + if stored.Origin.PeerID != testTelegramSessionPeerID { t.Fatalf("origin peerId should be preserved: got %q", stored.Origin.PeerID) } - if stored.Origin.PeerName != "Arnold HAO" { + if stored.Origin.PeerName != 
testTelegramSessionPeerNameAlt { t.Fatalf("origin peerName should be preserved: got %q", stored.Origin.PeerName) } if stored.Origin.PeerAvatarURL == "" { diff --git a/internal/application/gateway/subagent/service.go b/internal/application/gateway/subagent/service.go index 10a8d99..f6d0bdc 100644 --- a/internal/application/gateway/subagent/service.go +++ b/internal/application/gateway/subagent/service.go @@ -37,8 +37,7 @@ var subagentToolDenyAlways = []string{ "whatsapp_login", "session_status", "cron", - "memory_recall", - "memory_list", + "memory_query", "sessions_send", } diff --git a/internal/application/gateway/tools/browser_tools.go b/internal/application/gateway/tools/browser_tools.go index 27ee447..454ee79 100644 --- a/internal/application/gateway/tools/browser_tools.go +++ b/internal/application/gateway/tools/browser_tools.go @@ -2,217 +2,40 @@ package tools import ( "context" - "encoding/base64" "encoding/json" "errors" "fmt" - "io" - "net" - "net/http" - "net/url" - "os" - "path/filepath" - "reflect" - "regexp" - "sort" - "strconv" "strings" - "sync" - "sync/atomic" "time" + "dreamcreator/internal/application/browsercdp" + appcookies "dreamcreator/internal/application/cookies" gatewaynodes "dreamcreator/internal/application/gateway/nodes" - "github.com/playwright-community/playwright-go" ) -const ( - browserTypePlaywright = "playwright" - defaultBrowserType = browserTypePlaywright +var browserToolSessions = browsercdp.NewSessionRegistry() - defaultBrowserWaitUntil = "domcontentloaded" - - defaultBrowserSnapshotModeEfficient = "efficient" - - defaultBrowserProfileDreamCreator = "dreamcreator" - defaultBrowserColor = "#FF4500" - - defaultBrowserSnapshotAIMaxChars = 80000 - defaultBrowserSnapshotAIEfficientMaxChars = 10000 - defaultBrowserSnapshotDepth = 6 - defaultBrowserSnapshotLimit = 200 - defaultBrowserViewportWidth = 1366 - defaultBrowserViewportHeight = 900 - - defaultBrowserHookTimeoutMs = 20000 - - browserRuntimeCheckCacheTTL = 10 * time.Second -) - 
-var browserToolActions = []string{ - "status", - "start", - "stop", - "profiles", - "tabs", - "open", - "focus", - "close", - "snapshot", - "screenshot", - "navigate", - "console", - "pdf", - "upload", - "dialog", - "act", -} - -var browserSelectorUnsupportedMessage = strings.Join([]string{ - "Error: 'selector' is not supported. Use 'ref' from snapshot instead.", - "", - "Example workflow:", - "1. snapshot action to get page state with refs", - `2. act with ref: "e123" to interact with element`, - "", - "This is more reliable for modern SPAs.", -}, "\n") - -var browserWaitFnDisabledMessage = strings.Join([]string{ - "wait --fn is disabled by config (browser.evaluateEnabled=false).", - "Docs: /gateway/configuration#browser-playwright-managed-browser", -}, "\n") - -var browserEvaluateDisabledMessage = strings.Join([]string{ - "act:evaluate is disabled by config (browser.evaluateEnabled=false).", - "Docs: /gateway/configuration#browser-playwright-managed-browser", -}, "\n") - -var browserWaitRequiresConditionMessage = "wait requires at least one of: timeMs, text, textGone, selector, url, loadState, fn" - -var errBrowserSnapshotForAIUnavailable = errors.New("playwright snapshotForAI is unavailable") - -var browserActKinds = []string{ - "click", - "type", - "press", - "hover", - "drag", - "select", - "fill", - "resize", - "wait", - "evaluate", - "close", +func cleanupBrowserToolSessions(sessionKey string) { + browserToolSessions.CloseSessionKey(strings.TrimSpace(sessionKey)) } -var browserPlaywrightRuntimeCache = struct { - mu sync.Mutex - checkedAt time.Time - available bool - reason string - execPath string -}{} - -var browserGlobalTabCounter uint64 - -var globalBrowserSessions = struct { - mu sync.Mutex - sessions map[string]*browserSessionState -}{ - sessions: map[string]*browserSessionState{}, +func CleanupAllBrowserToolSessions() { + browserToolSessions.CloseAll() } -type browserSessionState struct { - sessionKey string - profiles map[string]*browserProfileState 
+func BrowserToolRuntimeConfigChanged(previousTools map[string]any, currentTools map[string]any) bool { + previous := resolveBrowserRuntimeConfig(previousTools) + current := resolveBrowserRuntimeConfig(currentTools) + return previous.Enabled != current.Enabled || + previous.Headless != current.Headless || + previous.PreferredBrowser != current.PreferredBrowser } type browserProfileState struct { - mu sync.Mutex - + sessionKey string profileName string resolved browserResolvedConfig - profile browserProfileConfig - - pw *playwright.Playwright - browser playwright.Browser - context playwright.BrowserContext - - tabs map[string]*browserTabState - pageToTarget map[playwright.Page]string - activeTarget string - - consoleMessages []browserConsoleMessage - pendingUploads map[string]browserPendingUpload - pendingDialogs map[string]browserPendingDialog -} - -type browserTabState struct { - TargetID string - Page playwright.Page - - mu sync.RWMutex - refs map[string]browserSnapshotRef - evaluateResult any -} - -type browserSnapshotRef struct { - Selector string - Role string - Name string - Nth int - Mode string - AriaRef string - Frame string -} - -type browserConsoleMessage struct { - TargetID string `json:"targetId"` - Type string `json:"type"` - Text string `json:"text"` - Timestamp string `json:"timestamp"` -} - -type browserPendingUpload struct { - Paths []string - ExpiresAt time.Time -} - -type browserPendingDialog struct { - Accept bool - PromptText string - ExpiresAt time.Time -} - -type browserResolvedConfig struct { - Enabled bool - EvaluateEnabled bool - CDPURL string - RemoteCdpTimeoutMs int - RemoteCdpHandshakeTimeoutMs int - Color string - Headless bool - NoSandbox bool - AttachOnly bool - DefaultProfile string - Profiles map[string]browserProfileConfig - SnapshotDefaultMode string - SSRFRules browserSSRFPolicy - ExtraArgs []string -} - -type browserProfileConfig struct { - Name string - CDPURL string - CDPPort int - Color string - Driver string -} - -type 
browserSSRFPolicy struct { - DangerouslyAllowPrivateNetwork bool - AllowedHostnames map[string]struct{} - HostnameAllowlist []string + session *browsercdp.Session } func runBrowserTool(settings SettingsReader, connectors ConnectorsReader, nodes *gatewaynodes.Service) func(ctx context.Context, args string) (string, error) { @@ -221,7 +44,6 @@ func runBrowserTool(settings SettingsReader, connectors ConnectorsReader, nodes if err != nil { return "", err } - action, err := resolveBrowserAction(payload) if err != nil { return "", err @@ -229,2894 +51,320 @@ func runBrowserTool(settings SettingsReader, connectors ConnectorsReader, nodes if isBrowserNodeTargetRequest(payload) { return runBrowserActionOnNode(ctx, payload, action, nodes) } - toolsConfig := resolveToolsConfig(ctx, settings) resolved := resolveBrowserRuntimeConfig(toolsConfig) if !resolved.Enabled { return "", errors.New("browser disabled") } - profileName := resolveBrowserProfileName(payload, resolved) sessionKey := resolveBrowserSessionKey(ctx, payload) - state := getOrCreateBrowserProfileState(sessionKey, profileName, resolved) - - result, err := runBrowserAction(ctx, payload, action, state, connectors) + state := getBrowserProfileState(sessionKey, profileName, resolved, connectors) + result, err := runBrowserAction(ctx, payload, action, state) if err != nil { + if browsercdp.IsFatalError(err) { + return "", fmt.Errorf("browser session reset after runtime failure: %w", err) + } return "", err } return marshalResult(result), nil } } -func resolveBrowserAction(payload toolArgs) (string, error) { - rawAction := strings.ToLower(strings.TrimSpace(getStringArg(payload, "action", "method"))) - if rawAction == "" { - if getStringArg(payload, "targetUrl", "url") != "" { - rawAction = "open" - } else { - rawAction = "status" - } - } - switch rawAction { - case "navigate", "status", "start", "stop", "profiles", "tabs", "open", "focus", "close", "snapshot", "screenshot", "console", "pdf", "upload", "dialog", "act": - 
return rawAction, nil - default: - return "", errors.New("browser action not supported: " + rawAction) - } -} - -func isBrowserNodeTargetRequest(payload toolArgs) bool { - target := strings.ToLower(strings.TrimSpace(getStringArg(payload, "target"))) - nodeID := strings.TrimSpace(getStringArg(payload, "node", "nodeId")) - if target == "node" { - return true - } - return nodeID != "" -} - -func resolveBrowserNodeID(ctx context.Context, payload toolArgs, nodes *gatewaynodes.Service) (string, error) { - requestedNode := strings.TrimSpace(getStringArg(payload, "node", "nodeId")) - if requestedNode != "" { - return requestedNode, nil - } - if nodes == nil { - return "", errors.New("nodes service unavailable") - } - list, err := nodes.ListNodes(ctx) - if err != nil { - return "", err - } - for _, descriptor := range list { - nodeID := strings.TrimSpace(descriptor.NodeID) - if nodeID == "" { - continue - } - for _, capability := range descriptor.Capabilities { - if strings.EqualFold(strings.TrimSpace(capability.Name), "browser.control") { - return nodeID, nil - } - } - } - for _, descriptor := range list { - nodeID := strings.TrimSpace(descriptor.NodeID) - if nodeID != "" { - return nodeID, nil - } - } - return "", errors.New("nodeId is required") -} - -func runBrowserActionOnNode(ctx context.Context, payload toolArgs, action string, nodes *gatewaynodes.Service) (string, error) { - if nodes == nil { - return "", errors.New("nodes service unavailable") - } - target := strings.ToLower(strings.TrimSpace(getStringArg(payload, "target"))) - if target != "" && target != "node" { - return "", errors.New(`node is only supported with target="node"`) - } - nodeID, err := resolveBrowserNodeID(ctx, payload, nodes) - if err != nil { - return "", err - } - argsJSON, err := json.Marshal(payload) - if err != nil { - return "", err - } - request := gatewaynodes.NodeInvokeRequest{ - NodeID: nodeID, - Capability: "browser.control", - Action: action, - Args: string(argsJSON), - TimeoutMs: 
resolveBrowserActionTimeoutMs(payload, 30000), - } - result, invokeErr := nodes.Invoke(ctx, request) - if invokeErr != nil { - return marshalResult(result), invokeErr - } - if !result.Ok { - if strings.TrimSpace(result.Error) != "" { - return marshalResult(result), errors.New(strings.TrimSpace(result.Error)) - } - return marshalResult(result), errors.New("node browser invoke failed") - } - if parsed := resolveBrowserNodeOutput(result.Output); parsed != nil { - return marshalResult(parsed), nil - } - return marshalResult(result), nil -} - -type browserNodeProxyEnvelope struct { - Result any `json:"result"` - Files []browserNodeProxyFile `json:"files"` -} - -type browserNodeProxyFile struct { - Path string `json:"path"` - Base64 string `json:"base64"` - MimeType string `json:"mimeType"` -} - -func resolveBrowserNodeOutput(output string) any { - trimmedOutput := strings.TrimSpace(output) - if trimmedOutput == "" { - return nil - } - var parsed any - if err := json.Unmarshal([]byte(trimmedOutput), &parsed); err != nil { - return nil - } - - envelope := browserNodeProxyEnvelope{} - if err := json.Unmarshal([]byte(trimmedOutput), &envelope); err == nil && envelope.Result != nil { - mapping := persistBrowserNodeProxyFiles(envelope.Files) - applyBrowserProxyPathMapping(envelope.Result, mapping) - return envelope.Result - } - return parsed -} - -func persistBrowserNodeProxyFiles(files []browserNodeProxyFile) map[string]string { - if len(files) == 0 { - return nil - } - mapping := map[string]string{} - for _, file := range files { - remotePath := strings.TrimSpace(file.Path) - encoded := strings.TrimSpace(file.Base64) - if remotePath == "" || encoded == "" { - continue - } - bytes, err := base64.StdEncoding.DecodeString(encoded) - if err != nil { - continue - } - localPath, err := saveBrowserArtifact(resolveBrowserProxyFileExt(file), bytes) - if err != nil { - continue - } - mapping[remotePath] = localPath - } - if len(mapping) == 0 { - return nil - } - return mapping -} - 
-func resolveBrowserProxyFileExt(file browserNodeProxyFile) string { - ext := strings.TrimSpace(strings.TrimPrefix(filepath.Ext(strings.TrimSpace(file.Path)), ".")) - if ext != "" { - return ext - } - mimeType := strings.ToLower(strings.TrimSpace(file.MimeType)) - switch { - case strings.Contains(mimeType, "png"): - return "png" - case strings.Contains(mimeType, "jpeg"), strings.Contains(mimeType, "jpg"): - return "jpg" - case strings.Contains(mimeType, "pdf"): - return "pdf" - case strings.Contains(mimeType, "json"): - return "json" - case strings.Contains(mimeType, "text"), strings.Contains(mimeType, "plain"): - return "txt" - default: - return "bin" - } -} - -func applyBrowserProxyPathMapping(result any, mapping map[string]string) { - if len(mapping) == 0 || result == nil { - return - } - obj, ok := result.(map[string]any) - if !ok { - return - } - if pathValue, ok := obj["path"].(string); ok { - if mapped, exists := mapping[pathValue]; exists { - obj["path"] = mapped - } - } - if imagePathValue, ok := obj["imagePath"].(string); ok { - if mapped, exists := mapping[imagePathValue]; exists { - obj["imagePath"] = mapped - } +func getBrowserProfileState(sessionKey string, profileName string, resolved browserResolvedConfig, connectors ConnectorsReader) *browserProfileState { + options := browsercdp.SessionOptions{ + SessionKey: sessionKey, + ProfileName: profileName, + PreferredBrowser: resolved.PreferredBrowser, + Headless: resolved.Headless, + UserDataDir: browsercdp.ResolveProfileUserDataDir(sessionKey, profileName), + SSRFRules: browsercdp.SSRFPolicy{ + DangerouslyAllowPrivateNetwork: resolved.SSRFRules.DangerouslyAllowPrivateNetwork, + AllowedHostnames: cloneBrowserAllowedHostnames(resolved.SSRFRules.AllowedHostnames), + HostnameAllowlist: append([]string(nil), resolved.SSRFRules.HostnameAllowlist...), + }, + Cookies: browsercdp.ConnectorCookieProviderFunc(func(ctx context.Context, rawURL string) ([]appcookies.Record, error) { + return 
browsercdp.ResolveConnectorCookiesForURL(ctx, connectors, rawURL) + }), } - if downloadRaw, exists := obj["download"]; exists { - if downloadObj, ok := downloadRaw.(map[string]any); ok { - if pathValue, ok := downloadObj["path"].(string); ok { - if mapped, exists := mapping[pathValue]; exists { - downloadObj["path"] = mapped - } - } - } + return &browserProfileState{ + sessionKey: sessionKey, + profileName: profileName, + resolved: resolved, + session: browserToolSessions.GetOrCreate(sessionKey, profileName, options), } } -func resolveBrowserSessionKey(ctx context.Context, payload toolArgs) string { - sessionKey, _ := RuntimeContextFromContext(ctx) - sessionKey = strings.TrimSpace(sessionKey) - if sessionKey == "" { - sessionKey = strings.TrimSpace(getStringArg(payload, "sessionKey", "session_key")) +func cloneBrowserAllowedHostnames(values map[string]struct{}) map[string]struct{} { + if len(values) == 0 { + return map[string]struct{}{} } - if sessionKey == "" { - sessionKey = "default" + cloned := make(map[string]struct{}, len(values)) + for key := range values { + cloned[key] = struct{}{} } - return sessionKey + return cloned } -func runBrowserAction( - ctx context.Context, - payload toolArgs, - action string, - state *browserProfileState, - connectors ConnectorsReader, -) (any, error) { +func runBrowserAction(ctx context.Context, payload toolArgs, action string, state *browserProfileState) (map[string]any, error) { switch action { - case "status": - return browserActionStatus(state) - case "start": - if err := ensureBrowserProfileStarted(state); err != nil { - return nil, err - } - return browserActionStatus(state) - case "stop": - if err := stopBrowserProfile(state); err != nil { - return nil, err - } - return browserActionStatus(state) - case "profiles": - return browserActionProfiles(state), nil - case "tabs": - return browserActionTabs(state) case "open": - return browserActionOpen(ctx, payload, state, connectors) - case "focus": - return 
browserActionFocus(payload, state) - case "close": - return browserActionClose(payload, state) + return browserActionOpen(ctx, payload, state) + case "navigate": + return browserActionNavigate(ctx, payload, state) case "snapshot": return browserActionSnapshot(payload, state) - case "screenshot": - return browserActionScreenshot(payload, state) - case "navigate": - return browserActionNavigate(ctx, payload, state, connectors) - case "console": - return browserActionConsole(payload, state) - case "pdf": - return browserActionPDF(payload, state) + case "wait": + return browserActionWait(ctx, payload, state) + case "scroll": + return browserActionScroll(payload, state) case "upload": return browserActionUpload(payload, state) case "dialog": return browserActionDialog(payload, state) case "act": - return browserActionAct(payload, state) + return browserActionAct(ctx, payload, state) + case "reset": + return browserActionReset(payload, state) default: return nil, errors.New("browser action not supported: " + action) } } -func browserActionStatus(state *browserProfileState) (map[string]any, error) { - available, reason, execPath := resolveBrowserPlaywrightRuntimeAvailability() - - state.mu.Lock() - defer state.mu.Unlock() - pruneClosedTabsLocked(state) - - running := state.browser != nil && state.context != nil - detectedPath := any(nil) - if strings.TrimSpace(execPath) != "" { - detectedPath = strings.TrimSpace(execPath) - } - detectError := any(nil) - if !available && strings.TrimSpace(reason) != "" { - detectError = strings.TrimSpace(reason) - } - chosenBrowser := any(nil) - if running { - chosenBrowser = "chromium" - } - - return map[string]any{ - "enabled": state.resolved.Enabled, - "profile": state.profileName, - "running": running, - "cdpReady": false, - "cdpHttp": false, - "pid": nil, - "cdpPort": nil, - "cdpUrl": nil, - "chosenBrowser": chosenBrowser, - "detectedBrowser": "chromium", - "detectedExecutablePath": detectedPath, - "detectError": detectError, - 
"userDataDir": nil, - "color": state.profile.Color, - "headless": state.resolved.Headless, - "noSandbox": state.resolved.NoSandbox, - "executablePath": detectedPath, - "attachOnly": false, - "tabCount": len(state.tabs), - "activeTargetId": state.activeTarget, - }, nil -} - -func browserActionProfiles(state *browserProfileState) map[string]any { - state.mu.Lock() - defer state.mu.Unlock() - pruneClosedTabsLocked(state) - - names := make([]string, 0, len(state.resolved.Profiles)) - for name := range state.resolved.Profiles { - names = append(names, name) - } - sort.Strings(names) - - profiles := make([]map[string]any, 0, len(names)) - for _, name := range names { - cfg := state.resolved.Profiles[name] - isCurrent := name == state.profileName - running := isCurrent && state.browser != nil && state.context != nil - tabCount := 0 - if isCurrent { - tabCount = len(state.tabs) - } - profiles = append(profiles, map[string]any{ - "name": name, - "cdpPort": cfg.CDPPort, - "cdpUrl": cfg.CDPURL, - "color": cfg.Color, - "running": running, - "tabCount": tabCount, - "isDefault": name == state.resolved.DefaultProfile, - "isRemote": cfg.CDPURL != "", - }) - } - - return map[string]any{"profiles": profiles} -} - -func browserActionTabs(state *browserProfileState) (map[string]any, error) { - state.mu.Lock() - running := state.browser != nil && state.context != nil - state.mu.Unlock() - if !running { - return map[string]any{"running": false, "tabs": []any{}}, nil - } - - tabs, err := listBrowserTabs(state) +func browserActionOpen(ctx context.Context, payload toolArgs, state *browserProfileState) (map[string]any, error) { + result, err := state.session.Open(ctx, strings.TrimSpace(getStringArg(payload, "targetUrl", "url")), browserCommandOptions(payload, 30000)) if err != nil { return nil, err } - return map[string]any{"running": true, "tabs": tabs}, nil + return browserResultMap(result), nil } -func browserActionOpen(ctx context.Context, payload toolArgs, state *browserProfileState, 
connectors ConnectorsReader) (map[string]any, error) { - targetURL := getStringArg(payload, "targetUrl", "url") - if targetURL == "" { - return nil, errors.New("targetUrl is required") - } - if err := assertBrowserURLAllowed(targetURL, state.resolved.SSRFRules); err != nil { - return nil, err - } - if err := ensureBrowserProfileStarted(state); err != nil { - return nil, err - } - - state.mu.Lock() - browserCtx := state.context - state.mu.Unlock() - if browserCtx == nil { - return nil, errors.New("browser context unavailable") - } - - if err := addConnectorCookiesToContext(ctx, connectors, browserCtx, targetURL); err != nil { +func browserActionNavigate(ctx context.Context, payload toolArgs, state *browserProfileState) (map[string]any, error) { + newTab, _ := getBoolArg(payload, "newTab") + result, err := state.session.Navigate( + ctx, + strings.TrimSpace(getStringArg(payload, "targetId")), + strings.TrimSpace(getStringArg(payload, "targetUrl", "url")), + newTab, + browserCommandOptions(payload, 30000), + ) + if err != nil { return nil, err } + return browserResultMap(result), nil +} - page, err := browserCtx.NewPage() +func browserActionSnapshot(payload toolArgs, state *browserProfileState) (map[string]any, error) { + result, err := state.session.State(strings.TrimSpace(getStringArg(payload, "targetId")), resolveBrowserSnapshotLimit(payload)) if err != nil { return nil, err } - if _, err := page.Goto(strings.TrimSpace(targetURL), playwright.PageGotoOptions{ - Timeout: playwright.Float(float64(resolveBrowserActionTimeoutMs(payload, 30000))), - WaitUntil: resolveBrowserWaitUntilState(resolveBrowserWaitUntil(payload, defaultBrowserWaitUntil)), - }); err != nil { - _ = page.Close() - return nil, err - } + return browserResultMap(result), nil +} - tab := attachBrowserTab(state, page) - title, _ := page.Title() - urlValue := strings.TrimSpace(page.URL()) - if urlValue == "" { - urlValue = strings.TrimSpace(targetURL) - } - if err := assertBrowserURLAllowed(urlValue, 
state.resolved.SSRFRules); err != nil { +func browserActionWait(ctx context.Context, payload toolArgs, state *browserProfileState) (map[string]any, error) { + result, err := state.session.Wait( + ctx, + strings.TrimSpace(getStringArg(payload, "targetId")), + browserWaitRequestFromArgs(payload), + browserCommandOptions(payload, 15000), + ) + if err != nil { return nil, err } - - return map[string]any{ - "targetId": tab.TargetID, - "title": strings.TrimSpace(title), - "url": urlValue, - "type": "page", - }, nil + return browserResultMap(result), nil } -func browserActionFocus(payload toolArgs, state *browserProfileState) (map[string]any, error) { - targetID := strings.TrimSpace(getStringArg(payload, "targetId")) - if targetID == "" { - return nil, errors.New("targetId is required") - } - tab, err := resolveBrowserTab(state, targetID, false) +func browserActionScroll(payload toolArgs, state *browserProfileState) (map[string]any, error) { + deltaX, deltaY := resolveBrowserScrollDelta(payload) + result, err := state.session.Scroll(browsercdp.ScrollRequest{ + TargetID: strings.TrimSpace(getStringArg(payload, "targetId")), + Ref: strings.TrimSpace(getStringArg(payload, "ref")), + DeltaX: deltaX, + DeltaY: deltaY, + Limit: resolveBrowserSnapshotLimit(payload), + Timeout: browserTimeoutDuration(payload, 15000), + }) if err != nil { return nil, err } - if err := tab.Page.BringToFront(); err != nil { - return nil, err - } - return map[string]any{"ok": true, "targetId": tab.TargetID}, nil + return browserResultMap(result), nil } -func browserActionClose(payload toolArgs, state *browserProfileState) (map[string]any, error) { - targetID := strings.TrimSpace(getStringArg(payload, "targetId")) - tab, err := resolveBrowserTab(state, targetID, false) +func browserActionUpload(payload toolArgs, state *browserProfileState) (map[string]any, error) { + rawPaths := getStringSliceArg(payload, "paths") + rootDir, err := resolveBrowserUploadRootDir() if err != nil { return nil, err } - if 
err := tab.Page.Close(); err != nil { - return nil, err + paths := make([]string, 0, len(rawPaths)) + for _, item := range rawPaths { + resolvedPath, err := resolvePathWithinRoot(rootDir, item) + if err != nil { + return nil, err + } + paths = append(paths, resolvedPath) } - - state.mu.Lock() - delete(state.tabs, tab.TargetID) - delete(state.pageToTarget, tab.Page) - if state.activeTarget == tab.TargetID { - state.activeTarget = "" + result, err := state.session.Upload(browsercdp.UploadRequest{ + TargetID: strings.TrimSpace(getStringArg(payload, "targetId")), + Ref: strings.TrimSpace(getStringArg(payload, "ref")), + Paths: paths, + Limit: resolveBrowserSnapshotLimit(payload), + Timeout: browserTimeoutDuration(payload, 15000), + }) + if err != nil { + return nil, err } - pruneClosedTabsLocked(state) - state.mu.Unlock() - - return map[string]any{"ok": true, "targetId": tab.TargetID}, nil + return browserResultMap(result), nil } -func browserActionSnapshot(payload toolArgs, state *browserProfileState) (map[string]any, error) { - if err := ensureBrowserProfileStarted(state); err != nil { - return nil, err +func browserActionDialog(payload toolArgs, state *browserProfileState) (map[string]any, error) { + var accept *bool + if value, ok := getBoolArg(payload, "accept"); ok { + accept = &value } - targetID := strings.TrimSpace(getStringArg(payload, "targetId")) - tab, err := resolveBrowserTab(state, targetID, true) + result, err := state.session.Dialog(browsercdp.DialogRequest{ + TargetID: strings.TrimSpace(getStringArg(payload, "targetId")), + Accept: accept, + PromptText: strings.TrimSpace(getStringArg(payload, "promptText")), + Limit: resolveBrowserSnapshotLimit(payload), + Timeout: browserTimeoutDuration(payload, 15000), + }) if err != nil { return nil, err } + return browserResultMap(result), nil +} - format := strings.ToLower(strings.TrimSpace(getStringArg(payload, "snapshotFormat", "format"))) - if format != "aria" { - format = "ai" - } - refsMode := 
strings.ToLower(strings.TrimSpace(getStringArg(payload, "refs"))) - if refsMode != "aria" { - refsMode = "role" - } - mode := strings.ToLower(strings.TrimSpace(getStringArg(payload, "mode"))) - if mode == "" { - mode = strings.TrimSpace(state.resolved.SnapshotDefaultMode) - } - if mode != defaultBrowserSnapshotModeEfficient { - mode = "" - } - labels, _ := getBoolArg(payload, "labels") - if format == "aria" && (labels || mode == defaultBrowserSnapshotModeEfficient) { - return nil, errors.New("labels/mode=efficient require format=ai") - } - - interactive, hasInteractive := getBoolArg(payload, "interactive") - if !hasInteractive { - interactive = mode == defaultBrowserSnapshotModeEfficient - } - compact, hasCompact := getBoolArg(payload, "compact") - if !hasCompact { - compact = mode == defaultBrowserSnapshotModeEfficient +func browserActionAct(ctx context.Context, payload toolArgs, state *browserProfileState) (map[string]any, error) { + request := getMapArg(payload, "request") + if request == nil { + return nil, errors.New("request is required") } - - depth, hasDepth := getIntArg(payload, "depth") - if hasDepth && depth <= 0 { - hasDepth = false + requestArgs := toolArgs(request) + kind := strings.ToLower(strings.TrimSpace(getStringArg(requestArgs, "kind"))) + if kind == "" { + return nil, errors.New("request.kind is required") } - if !hasDepth && mode == defaultBrowserSnapshotModeEfficient { - depth = defaultBrowserSnapshotDepth - hasDepth = true + if !containsString(browserActKinds, kind) { + return nil, errors.New("act kind not supported: " + kind) } - - limit, ok := getIntArg(payload, "limit") - if !ok || limit <= 0 { - limit = defaultBrowserSnapshotLimit + if _, hasSelector := request["selector"]; hasSelector && kind != "wait" { + return nil, errors.New(browserSelectorUnsupportedMessage) } - - maxChars := 0 - _, hasMaxChars := payload["maxChars"] - if value, ok := getIntArg(payload, "maxChars"); ok && value > 0 { - maxChars = value + if state == nil || 
state.session == nil { + return nil, errors.New("browser session unavailable") } - if format == "ai" && !hasMaxChars { - if mode == defaultBrowserSnapshotModeEfficient { - maxChars = defaultBrowserSnapshotAIEfficientMaxChars - } else { - maxChars = defaultBrowserSnapshotAIMaxChars - } + actRequest := browsercdp.ActRequest{ + Kind: kind, + TargetID: firstNonEmptyString(strings.TrimSpace(getStringArg(requestArgs, "targetId")), strings.TrimSpace(getStringArg(payload, "targetId"))), + Ref: strings.TrimSpace(getStringArg(requestArgs, "ref")), + Text: getStringArg(requestArgs, "text"), + Key: strings.TrimSpace(getStringArg(requestArgs, "key")), + Value: getStringArg(requestArgs, "value"), + Expression: getStringArg(requestArgs, "expression"), + Limit: resolveBrowserSnapshotLimit(payload), + Timeout: browserActTimeoutDuration(requestArgs, payload, 15000), + Wait: browserWaitRequestFromArgs(requestArgs), + WaitFor: browserOptionalWaitRequest(getMapArg(requestArgs, "waitFor")), } - - selector := strings.TrimSpace(getStringArg(payload, "selector")) - frameSelector := strings.TrimSpace(getStringArg(payload, "frame")) - if refsMode == "aria" && (selector != "" || frameSelector != "") { - return nil, errors.New("refs=aria does not support selector/frame snapshots yet") + if width, ok := getIntArg(requestArgs, "width"); ok { + actRequest.Width = width } - - maxDepth := 0 - if hasDepth { - maxDepth = depth + if height, ok := getIntArg(requestArgs, "height"); ok { + actRequest.Height = height } - wantsRoleSnapshot := labels || - mode == defaultBrowserSnapshotModeEfficient || - interactive || - compact || - hasDepth || - selector != "" || - frameSelector != "" - if format == "ai" && !wantsRoleSnapshot { - aiResult, aiErr := collectBrowserAISnapshot(tab.Page, maxChars) - if aiErr == nil { - state.mu.Lock() - tab.refs = aiResult.Refs - state.mu.Unlock() - - result := map[string]any{ - "ok": true, - "format": "ai", - "targetId": tab.TargetID, - "url": 
strings.TrimSpace(tab.Page.URL()), - "snapshot": aiResult.Snapshot, - "truncated": aiResult.Truncated, - "refs": aiResult.RefsJSON, - "stats": aiResult.Stats, - } - return result, nil - } - if !isBrowserSnapshotForAIUnavailable(aiErr) { - return nil, aiErr - } + result, err := state.session.Act(ctx, actRequest) + if err != nil { + return nil, err } + return browserResultMap(result), nil +} - items, err := collectBrowserSnapshotItems(tab.Page, selector, frameSelector, interactive, limit, refsMode, maxDepth) +func browserActionReset(payload toolArgs, state *browserProfileState) (map[string]any, error) { + restart, _ := getBoolArg(payload, "restart") + result, err := state.session.Reset(restart) if err != nil { return nil, err } + output := browserResultMap(result) + output["profile"] = state.profileName + return output, nil +} - refs := make(map[string]browserSnapshotRef, len(items)) - refsJSON := make(map[string]map[string]any, len(items)) - lines := make([]string, 0, len(items)) - nodes := make([]map[string]any, 0, len(items)) - interactiveCount := 0 - for index, item := range items { - if index >= limit { - break - } - ref := strings.TrimSpace(item.Ref) - if ref == "" { - ref = fmt.Sprintf("e%d", index+1) - } - entryMode := refsMode - if entryMode == "aria" && strings.TrimSpace(item.AriaRef) == "" { - entryMode = "role" - } - entry := browserSnapshotRef{ - Role: item.Role, - Name: item.Name, - Nth: item.Nth, - Mode: entryMode, - AriaRef: item.AriaRef, - Frame: frameSelector, - } - refs[ref] = entry - refsJSON[ref] = map[string]any{ - "role": entry.Role, - "name": entry.Name, - } - if entry.Nth > 0 { - refsJSON[ref]["nth"] = entry.Nth - } - if isBrowserInteractiveRole(entry.Role) { - interactiveCount += 1 - } - line := fmt.Sprintf("[%s] role=%s name=%s", ref, entry.Role, entry.Name) - if !compact && strings.TrimSpace(item.Text) != "" { - line += " text=" + trimToMaxChars(item.Text, 120) - } - lines = append(lines, line) - nodeDepth := item.Depth - if maxDepth > 0 { 
- nodeDepth = minInt(nodeDepth, maxDepth) - } - nodes = append(nodes, map[string]any{ - "ref": ref, - "role": entry.Role, - "name": entry.Name, - "depth": nodeDepth, - }) +func browserCommandOptions(payload toolArgs, fallbackTimeoutMs int) browsercdp.CommandOptions { + return browsercdp.CommandOptions{ + Limit: resolveBrowserSnapshotLimit(payload), + Timeout: browserTimeoutDuration(payload, fallbackTimeoutMs), + WaitFor: browserOptionalWaitRequest(getMapArg(payload, "waitFor")), } +} - state.mu.Lock() - tab.refs = refs - state.mu.Unlock() - - if format == "aria" { - result := map[string]any{ - "ok": true, - "format": "aria", - "targetId": tab.TargetID, - "url": strings.TrimSpace(tab.Page.URL()), - "nodes": nodes, - } - return result, nil +func browserOptionalWaitRequest(raw map[string]any) *browsercdp.WaitRequest { + if raw == nil { + return nil } - - snapshot := strings.Join(lines, "\n") - truncated := false - if maxChars > 0 && len(snapshot) > maxChars { - snapshot = trimToMaxChars(snapshot, maxChars) - truncated = true + request := browserWaitRequestFromArgs(toolArgs(raw)) + if browserWaitRequestEmpty(request) { + return nil } + return &request +} - result := map[string]any{ - "ok": true, - "format": "ai", - "targetId": tab.TargetID, - "url": strings.TrimSpace(tab.Page.URL()), - "snapshot": snapshot, - "truncated": truncated, - "refs": refsJSON, - "stats": map[string]any{ - "lines": len(lines), - "chars": len(snapshot), - "refs": len(refsJSON), - "interactive": interactiveCount, - }, +func browserWaitRequestFromArgs(args toolArgs) browsercdp.WaitRequest { + request := browsercdp.WaitRequest{ + Selector: strings.TrimSpace(getStringArg(args, "selector")), + Text: strings.TrimSpace(getStringArg(args, "text")), + TextGone: strings.TrimSpace(getStringArg(args, "textGone")), + URL: strings.TrimSpace(getStringArg(args, "url")), + Fn: strings.TrimSpace(getStringArg(args, "fn")), } - - if labels { - img, err := tab.Page.Screenshot(playwright.PageScreenshotOptions{Type: 
playwright.ScreenshotTypePng}) - if err == nil { - path, writeErr := saveBrowserArtifact("png", img) - if writeErr == nil { - result["labels"] = true - result["imagePath"] = path - result["imageType"] = "png" - } - } + if timeMs, ok := getIntArg(args, "timeMs"); ok && timeMs > 0 { + request.Time = time.Duration(timeMs) * time.Millisecond } + if timeoutMs, ok := getIntArg(args, "timeoutMs"); ok && timeoutMs > 0 { + request.Timeout = time.Duration(timeoutMs) * time.Millisecond + } + return request +} - return result, nil +func browserWaitRequestEmpty(request browsercdp.WaitRequest) bool { + return request.Time <= 0 && + request.Selector == "" && + request.Text == "" && + request.TextGone == "" && + request.URL == "" && + request.Fn == "" } -func browserActionScreenshot(payload toolArgs, state *browserProfileState) (map[string]any, error) { - if err := ensureBrowserProfileStarted(state); err != nil { - return nil, err - } - targetID := strings.TrimSpace(getStringArg(payload, "targetId")) - tab, err := resolveBrowserTab(state, targetID, true) - if err != nil { - return nil, err - } +func browserTimeoutDuration(payload toolArgs, fallbackTimeoutMs int) time.Duration { + return time.Duration(resolveBrowserActionTimeoutMs(payload, fallbackTimeoutMs)) * time.Millisecond +} - imageType := strings.ToLower(strings.TrimSpace(getStringArg(payload, "type"))) - if imageType != "jpeg" { - imageType = "png" - } +func browserActTimeoutDuration(request toolArgs, payload toolArgs, fallbackTimeoutMs int) time.Duration { + return time.Duration(resolveBrowserActTimeoutMs(request, payload, fallbackTimeoutMs)) * time.Millisecond +} - fullPage, _ := getBoolArg(payload, "fullPage") - ref := strings.TrimSpace(getStringArg(payload, "ref")) - element := strings.TrimSpace(getStringArg(payload, "element")) - if fullPage && (ref != "" || element != "") { - return nil, errors.New("fullPage is not supported for element screenshots") - } +func browserResultMap(result browsercdp.ActionResult) 
map[string]any { + data, _ := json.Marshal(result) + decoded := map[string]any{} + _ = json.Unmarshal(data, &decoded) + decoded["stateAvailable"] = result.State != nil || result.StateAvailable + decoded["itemCount"] = browserResultItemCount(result) + return decoded +} - var bytes []byte - if ref != "" { - locator, err := resolveBrowserRefLocator(tab, ref) - if err != nil { - return nil, err - } - bytes, err = locator.Screenshot(playwright.LocatorScreenshotOptions{ - Type: toPlaywrightScreenshotType(imageType), - Timeout: playwright.Float(float64(resolveBrowserActionTimeoutMs(payload, 15000))), - }) - if err != nil { - return nil, err - } - } else if element != "" { - locator := tab.Page.Locator(strings.TrimSpace(element)) - bytes, err = locator.Screenshot(playwright.LocatorScreenshotOptions{ - Type: toPlaywrightScreenshotType(imageType), - Timeout: playwright.Float(float64(resolveBrowserActionTimeoutMs(payload, 15000))), - }) - if err != nil { - return nil, err - } - } else { - bytes, err = tab.Page.Screenshot(playwright.PageScreenshotOptions{ - Type: toPlaywrightScreenshotType(imageType), - FullPage: playwright.Bool(fullPage), - Timeout: playwright.Float(float64(resolveBrowserActionTimeoutMs(payload, 15000))), - }) - if err != nil { - return nil, err +func firstNonEmptyString(values ...string) string { + for _, value := range values { + if strings.TrimSpace(value) != "" { + return strings.TrimSpace(value) } } - - path, err := saveBrowserArtifact(imageType, bytes) - if err != nil { - return nil, err - } - return map[string]any{ - "ok": true, - "path": path, - "targetId": tab.TargetID, - "url": strings.TrimSpace(tab.Page.URL()), - }, nil -} - -func browserActionNavigate(ctx context.Context, payload toolArgs, state *browserProfileState, connectors ConnectorsReader) (map[string]any, error) { - targetURL := getStringArg(payload, "targetUrl", "url") - if targetURL == "" { - return nil, errors.New("targetUrl is required") - } - if err := 
assertBrowserURLAllowed(targetURL, state.resolved.SSRFRules); err != nil { - return nil, err - } - if err := ensureBrowserProfileStarted(state); err != nil { - return nil, err - } - - targetID := strings.TrimSpace(getStringArg(payload, "targetId")) - tab, err := resolveBrowserTab(state, targetID, true) - if err != nil { - return nil, err - } - - state.mu.Lock() - browserCtx := state.context - state.mu.Unlock() - if browserCtx != nil { - if err := addConnectorCookiesToContext(ctx, connectors, browserCtx, targetURL); err != nil { - return nil, err - } - } - - resp, err := tab.Page.Goto(strings.TrimSpace(targetURL), playwright.PageGotoOptions{ - Timeout: playwright.Float(float64(resolveBrowserActionTimeoutMs(payload, 30000))), - WaitUntil: resolveBrowserWaitUntilState(resolveBrowserWaitUntil(payload, defaultBrowserWaitUntil)), - }) - if err != nil { - return nil, err - } - - finalURL := strings.TrimSpace(tab.Page.URL()) - if finalURL == "" { - finalURL = strings.TrimSpace(targetURL) - } - if err := assertBrowserURLAllowed(finalURL, state.resolved.SSRFRules); err != nil { - return nil, err - } - - status := http.StatusOK - if resp != nil { - status = resp.Status() - } - - return map[string]any{ - "ok": true, - "targetId": tab.TargetID, - "url": finalURL, - "status": status, - }, nil -} - -func browserActionConsole(payload toolArgs, state *browserProfileState) (map[string]any, error) { - if err := ensureBrowserProfileStarted(state); err != nil { - return nil, err - } - targetID := strings.TrimSpace(getStringArg(payload, "targetId")) - tab, err := resolveBrowserTab(state, targetID, false) - if err != nil { - return nil, err - } - - level := strings.ToLower(strings.TrimSpace(getStringArg(payload, "level"))) - state.mu.Lock() - defer state.mu.Unlock() - - messages := make([]browserConsoleMessage, 0, len(state.consoleMessages)) - for _, item := range state.consoleMessages { - if item.TargetID != tab.TargetID { - continue - } - if level != "" && level != "all" && item.Type 
!= level { - continue - } - messages = append(messages, item) - } - - return map[string]any{ - "ok": true, - "targetId": tab.TargetID, - "messages": messages, - }, nil -} - -func browserActionPDF(payload toolArgs, state *browserProfileState) (map[string]any, error) { - if err := ensureBrowserProfileStarted(state); err != nil { - return nil, err - } - targetID := strings.TrimSpace(getStringArg(payload, "targetId")) - tab, err := resolveBrowserTab(state, targetID, true) - if err != nil { - return nil, err - } - - bytes, err := tab.Page.PDF(playwright.PagePdfOptions{ - PrintBackground: playwright.Bool(true), - }) - if err != nil { - return nil, err - } - path, err := saveBrowserArtifact("pdf", bytes) - if err != nil { - return nil, err - } - return map[string]any{ - "ok": true, - "path": path, - "targetId": tab.TargetID, - "url": strings.TrimSpace(tab.Page.URL()), - }, nil -} - -func browserActionUpload(payload toolArgs, state *browserProfileState) (map[string]any, error) { - if err := ensureBrowserProfileStarted(state); err != nil { - return nil, err - } - - paths, err := parseBrowserUploadPaths(payload) - if err != nil { - return nil, err - } - - targetID := strings.TrimSpace(getStringArg(payload, "targetId")) - tab, err := resolveBrowserTab(state, targetID, true) - if err != nil { - return nil, err - } - - inputRef := strings.TrimSpace(getStringArg(payload, "inputRef")) - ref := strings.TrimSpace(getStringArg(payload, "ref")) - element := strings.TrimSpace(getStringArg(payload, "element")) - if inputRef != "" || ref != "" || element != "" { - locator, err := resolveBrowserUploadLocator(tab, inputRef, ref, element) - if err != nil { - return nil, err - } - if err := locator.SetInputFiles(paths); err != nil { - return nil, err - } - return map[string]any{ - "ok": true, - "targetId": tab.TargetID, - "paths": paths, - "armed": false, - }, nil - } - - timeoutMs := resolveBrowserActionTimeoutMs(payload, defaultBrowserHookTimeoutMs) - state.mu.Lock() - 
state.pendingUploads[tab.TargetID] = browserPendingUpload{ - Paths: append([]string(nil), paths...), - ExpiresAt: time.Now().Add(time.Duration(timeoutMs) * time.Millisecond), - } - state.mu.Unlock() - - return map[string]any{ - "ok": true, - "targetId": tab.TargetID, - "paths": paths, - "armed": true, - }, nil -} - -func browserActionDialog(payload toolArgs, state *browserProfileState) (map[string]any, error) { - if err := ensureBrowserProfileStarted(state); err != nil { - return nil, err - } - targetID := strings.TrimSpace(getStringArg(payload, "targetId")) - tab, err := resolveBrowserTab(state, targetID, true) - if err != nil { - return nil, err - } - accept, _ := getBoolArg(payload, "accept") - promptText := strings.TrimSpace(getStringArg(payload, "promptText")) - timeoutMs := resolveBrowserActionTimeoutMs(payload, defaultBrowserHookTimeoutMs) - - state.mu.Lock() - state.pendingDialogs[tab.TargetID] = browserPendingDialog{ - Accept: accept, - PromptText: promptText, - ExpiresAt: time.Now().Add(time.Duration(timeoutMs) * time.Millisecond), - } - state.mu.Unlock() - - return map[string]any{ - "ok": true, - "targetId": tab.TargetID, - "armed": true, - "accept": accept, - }, nil -} - -func browserActionAct(payload toolArgs, state *browserProfileState) (map[string]any, error) { - request := getMapArg(payload, "request") - if request == nil { - return nil, errors.New("request is required") - } - kind := strings.ToLower(strings.TrimSpace(getStringArg(toolArgs(request), "kind"))) - if kind == "" { - return nil, errors.New("request.kind is required") - } - if !containsString(browserActKinds, kind) { - return nil, errors.New("act kind not supported: " + kind) - } - if _, hasSelector := request["selector"]; hasSelector && kind != "wait" { - return nil, errors.New(browserSelectorUnsupportedMessage) - } - if err := ensureBrowserProfileStarted(state); err != nil { - return nil, err - } - - targetID := strings.TrimSpace(getStringArg(toolArgs(request), "targetId")) - if 
targetID == "" { - targetID = strings.TrimSpace(getStringArg(payload, "targetId")) - } - tab, err := resolveBrowserTab(state, targetID, true) - if err != nil { - return nil, err - } - - switch kind { - case "click": - err = browserActClick(tab, toolArgs(request), payload) - case "type": - err = browserActType(tab, toolArgs(request), payload) - case "press": - err = browserActPress(tab, toolArgs(request), payload) - case "hover": - err = browserActHover(tab, toolArgs(request), payload) - case "drag": - err = browserActDrag(tab, toolArgs(request), payload) - case "select": - err = browserActSelect(tab, toolArgs(request), payload) - case "fill": - err = browserActFill(tab, toolArgs(request), payload) - case "resize": - err = browserActResize(tab, toolArgs(request)) - case "wait": - err = browserActWait(tab, toolArgs(request), state.resolved.EvaluateEnabled) - case "evaluate": - err = browserActEvaluate(tab, toolArgs(request), payload, state.resolved.EvaluateEnabled) - case "close": - err = browserActClose(tab, state) - default: - err = errors.New("act kind not supported: " + kind) - } - if err != nil { - return nil, err - } - - result := map[string]any{ - "ok": true, - "targetId": tab.TargetID, - "url": strings.TrimSpace(tab.Page.URL()), - } - if kind == "evaluate" { - result["result"] = tabResultFromEvaluate(tab) - } - return result, nil -} - -func tabResultFromEvaluate(tab *browserTabState) any { - if tab == nil { - return nil - } - tab.mu.RLock() - defer tab.mu.RUnlock() - return tab.evaluateResult -} - -func browserActClick(tab *browserTabState, request toolArgs, payload toolArgs) error { - ref := strings.TrimSpace(getStringArg(request, "ref")) - if ref == "" { - return errors.New("ref is required") - } - locator, err := resolveBrowserRefLocator(tab, ref) - if err != nil { - return err - } - clickOptions := playwright.LocatorClickOptions{} - if timeout := resolveBrowserActTimeoutMs(request, payload, 10000); timeout > 0 { - clickOptions.Timeout = 
playwright.Float(float64(timeout)) - } - buttonRaw := strings.TrimSpace(getStringArg(request, "button")) - if buttonRaw != "" { - button := toPlaywrightMouseButton(buttonRaw) - if button == nil { - return errors.New("button must be left|right|middle") - } - clickOptions.Button = button - } - modifiers, err := toPlaywrightKeyboardModifiers(getStringSliceArg(request, "modifiers")) - if err != nil { - return err - } - if len(modifiers) > 0 { - clickOptions.Modifiers = modifiers - } - if doubleClick, _ := getBoolArg(request, "doubleClick"); doubleClick { - dblOptions := playwright.LocatorDblclickOptions{} - if clickOptions.Timeout != nil { - dblOptions.Timeout = clickOptions.Timeout - } - if clickOptions.Button != nil { - dblOptions.Button = clickOptions.Button - } - if len(clickOptions.Modifiers) > 0 { - dblOptions.Modifiers = clickOptions.Modifiers - } - if err := locator.Dblclick(dblOptions); err != nil { - return toBrowserFriendlyInteractionError(err, ref) - } - return nil - } - if err := locator.Click(clickOptions); err != nil { - return toBrowserFriendlyInteractionError(err, ref) - } - return nil -} - -func browserActType(tab *browserTabState, request toolArgs, payload toolArgs) error { - ref := strings.TrimSpace(getStringArg(request, "ref")) - if ref == "" { - return errors.New("ref is required") - } - textRaw, ok := request["text"] - if !ok { - return errors.New("text is required") - } - text, ok := textRaw.(string) - if !ok { - return errors.New("text is required") - } - locator, err := resolveBrowserRefLocator(tab, ref) - if err != nil { - return err - } - timeout := resolveBrowserActTimeoutMs(request, payload, 10000) - if slowly, _ := getBoolArg(request, "slowly"); slowly { - typeOptions := playwright.LocatorTypeOptions{} - if timeout > 0 { - typeOptions.Timeout = playwright.Float(float64(timeout)) - } - typeOptions.Delay = playwright.Float(75) - if err := locator.Click(playwright.LocatorClickOptions{Timeout: playwright.Float(float64(timeout))}); err != nil 
{ - return toBrowserFriendlyInteractionError(err, ref) - } - if err := locator.Type(text, typeOptions); err != nil { - return toBrowserFriendlyInteractionError(err, ref) - } - } else { - fillOptions := playwright.LocatorFillOptions{} - if timeout > 0 { - fillOptions.Timeout = playwright.Float(float64(timeout)) - } - if err := locator.Fill(text, fillOptions); err != nil { - return toBrowserFriendlyInteractionError(err, ref) - } - } - if submit, _ := getBoolArg(request, "submit"); submit { - pressOptions := playwright.LocatorPressOptions{} - if timeout > 0 { - pressOptions.Timeout = playwright.Float(float64(timeout)) - } - if err := locator.Press("Enter", pressOptions); err != nil { - return toBrowserFriendlyInteractionError(err, ref) - } - } - return nil -} - -func browserActPress(tab *browserTabState, request toolArgs, payload toolArgs) error { - key := strings.TrimSpace(getStringArg(request, "key")) - if key == "" { - return errors.New("key is required") - } - options := playwright.KeyboardPressOptions{} - if delayMs, ok := getIntArg(request, "delayMs"); ok && delayMs >= 0 { - options.Delay = playwright.Float(float64(delayMs)) - } - if timeout := resolveBrowserActTimeoutMs(request, payload, 10000); timeout > 0 { - tab.Page.SetDefaultTimeout(float64(timeout)) - defer tab.Page.SetDefaultTimeout(30000) - } - return tab.Page.Keyboard().Press(key, options) -} - -func browserActHover(tab *browserTabState, request toolArgs, payload toolArgs) error { - ref := strings.TrimSpace(getStringArg(request, "ref")) - if ref == "" { - return errors.New("ref is required") - } - locator, err := resolveBrowserRefLocator(tab, ref) - if err != nil { - return err - } - hoverOptions := playwright.LocatorHoverOptions{} - if timeout := resolveBrowserActTimeoutMs(request, payload, 10000); timeout > 0 { - hoverOptions.Timeout = playwright.Float(float64(timeout)) - } - if err := locator.Hover(hoverOptions); err != nil { - return toBrowserFriendlyInteractionError(err, ref) - } - return nil -} - 
-func browserActDrag(tab *browserTabState, request toolArgs, payload toolArgs) error { - startRef := strings.TrimSpace(getStringArg(request, "startRef")) - endRef := strings.TrimSpace(getStringArg(request, "endRef")) - if startRef == "" || endRef == "" { - return errors.New("startRef and endRef are required") - } - startLocator, err := resolveBrowserRefLocator(tab, startRef) - if err != nil { - return err - } - endLocator, err := resolveBrowserRefLocator(tab, endRef) - if err != nil { - return err - } - dragOptions := playwright.LocatorDragToOptions{} - if timeout := resolveBrowserActTimeoutMs(request, payload, 12000); timeout > 0 { - dragOptions.Timeout = playwright.Float(float64(timeout)) - } - if err := startLocator.DragTo(endLocator, dragOptions); err != nil { - return toBrowserFriendlyInteractionError(err, startRef+" -> "+endRef) - } - return nil -} - -func browserActSelect(tab *browserTabState, request toolArgs, payload toolArgs) error { - ref := strings.TrimSpace(getStringArg(request, "ref")) - values := getStringSliceArg(request, "values") - if ref == "" || len(values) == 0 { - return errors.New("ref and values are required") - } - locator, err := resolveBrowserRefLocator(tab, ref) - if err != nil { - return err - } - timeout := resolveBrowserActTimeoutMs(request, payload, 10000) - _, err = locator.SelectOption(playwright.SelectOptionValues{Values: &values}, playwright.LocatorSelectOptionOptions{ - Timeout: playwright.Float(float64(timeout)), - }) - if err != nil { - return toBrowserFriendlyInteractionError(err, ref) - } - return nil -} - -func browserActFill(tab *browserTabState, request toolArgs, payload toolArgs) error { - rawFields, ok := request["fields"].([]any) - if !ok || len(rawFields) == 0 { - return errors.New("fields are required") - } - timeout := resolveBrowserActTimeoutMs(request, payload, 10000) - for _, raw := range rawFields { - field, ok := raw.(map[string]any) - if !ok { - continue - } - ref := 
strings.TrimSpace(getStringArg(toolArgs(field), "ref")) - if ref == "" { - continue - } - typeValue := strings.ToLower(strings.TrimSpace(getStringArg(toolArgs(field), "type"))) - locator, err := resolveBrowserRefLocator(tab, ref) - if err != nil { - return err - } - switch typeValue { - case "checkbox", "radio", "bool", "boolean": - checked, _ := getBoolArg(toolArgs(field), "value") - if checked { - err = locator.Check(playwright.LocatorCheckOptions{Timeout: playwright.Float(float64(timeout))}) - } else { - err = locator.Uncheck(playwright.LocatorUncheckOptions{Timeout: playwright.Float(float64(timeout))}) - } - case "select", "option": - value := strings.TrimSpace(getStringArg(toolArgs(field), "value")) - if value == "" { - continue - } - values := []string{value} - _, err = locator.SelectOption(playwright.SelectOptionValues{Values: &values}, playwright.LocatorSelectOptionOptions{ - Timeout: playwright.Float(float64(timeout)), - }) - default: - value := strings.TrimSpace(getStringArg(toolArgs(field), "value")) - err = locator.Fill(value, playwright.LocatorFillOptions{Timeout: playwright.Float(float64(timeout))}) - } - if err != nil { - return toBrowserFriendlyInteractionError(err, ref) - } - } - return nil -} - -func browserActResize(tab *browserTabState, request toolArgs) error { - width, hasWidth := getIntArg(request, "width") - height, hasHeight := getIntArg(request, "height") - if !hasWidth || !hasHeight || width <= 0 || height <= 0 { - return errors.New("width and height are required") - } - return tab.Page.SetViewportSize(width, height) -} - -func browserActWait(tab *browserTabState, request toolArgs, evaluateEnabled bool) error { - timeMs, hasTime := getIntArg(request, "timeMs") - text := strings.TrimSpace(getStringArg(request, "text")) - textGone := strings.TrimSpace(getStringArg(request, "textGone")) - selector := strings.TrimSpace(getStringArg(request, "selector")) - urlWait := strings.TrimSpace(getStringArg(request, "url")) - loadState := 
strings.ToLower(strings.TrimSpace(getStringArg(request, "loadState"))) - fn := strings.TrimSpace(getStringArg(request, "fn")) - timeoutMs := resolveBrowserActionTimeoutMs(request, 15000) - - if fn != "" && !evaluateEnabled { - return errors.New(browserWaitFnDisabledMessage) - } - - hasCondition := false - if hasTime && timeMs > 0 { - hasCondition = true - tab.Page.WaitForTimeout(float64(timeMs)) - } - if text != "" { - hasCondition = true - if err := tab.Page.Locator("text=" + text).First().WaitFor(playwright.LocatorWaitForOptions{ - State: playwright.WaitForSelectorStateVisible, - Timeout: playwright.Float(float64(timeoutMs)), - }); err != nil { - return err - } - } - if textGone != "" { - hasCondition = true - if err := tab.Page.Locator("text=" + textGone).First().WaitFor(playwright.LocatorWaitForOptions{ - State: playwright.WaitForSelectorStateHidden, - Timeout: playwright.Float(float64(timeoutMs)), - }); err != nil { - return err - } - } - if selector != "" { - hasCondition = true - if _, err := tab.Page.WaitForSelector(selector, playwright.PageWaitForSelectorOptions{ - State: playwright.WaitForSelectorStateVisible, - Timeout: playwright.Float(float64(timeoutMs)), - }); err != nil { - return err - } - } - if urlWait != "" { - hasCondition = true - if err := tab.Page.WaitForURL(urlWait, playwright.PageWaitForURLOptions{ - Timeout: playwright.Float(float64(timeoutMs)), - }); err != nil { - return err - } - } - if loadState != "" { - hasCondition = true - state := resolveBrowserLoadState(loadState) - if state == nil { - return errors.New("loadState must be load|domcontentloaded|networkidle") - } - if err := tab.Page.WaitForLoadState(playwright.PageWaitForLoadStateOptions{ - State: state, - Timeout: playwright.Float(float64(timeoutMs)), - }); err != nil { - return err - } - } - if fn != "" { - hasCondition = true - if err := waitBrowserEvaluateCondition(tab.Page, fn, timeoutMs); err != nil { - return err - } - } - if !hasCondition { - return 
errors.New(browserWaitRequiresConditionMessage) - } - return nil -} - -func browserActEvaluate(tab *browserTabState, request toolArgs, payload toolArgs, evaluateEnabled bool) error { - if !evaluateEnabled { - return errors.New(browserEvaluateDisabledMessage) - } - fn := strings.TrimSpace(getStringArg(request, "fn")) - if fn == "" { - return errors.New("fn is required") - } - timeoutMs := resolveBrowserActTimeoutMs(request, payload, 20000) - ref := strings.TrimSpace(getStringArg(request, "ref")) - var ( - result any - err error - ) - if ref != "" { - locator, resolveErr := resolveBrowserRefLocator(tab, ref) - if resolveErr != nil { - return resolveErr - } - result, err = locator.Evaluate(browserEvaluateElementExpression, []any{fn, timeoutMs}) - } else { - result, err = tab.Page.Evaluate(browserEvaluatePageExpression, []any{fn, timeoutMs}) - } - if err != nil { - return err - } - tab.mu.Lock() - tab.evaluateResult = result - tab.mu.Unlock() - return nil -} - -func browserActClose(tab *browserTabState, state *browserProfileState) error { - if tab == nil || tab.Page == nil { - return errors.New("tab not found") - } - if err := tab.Page.Close(); err != nil { - return err - } - state.mu.Lock() - delete(state.tabs, tab.TargetID) - delete(state.pageToTarget, tab.Page) - if state.activeTarget == tab.TargetID { - state.activeTarget = "" - } - pruneClosedTabsLocked(state) - state.mu.Unlock() - return nil -} - -type browserSnapshotItem struct { - Ref string - AriaRef string - Role string - Name string - Text string - Depth int - Nth int -} - -var browserAriaSnapshotLinePattern = regexp.MustCompile(`^(\s*)-\s*([^\s":]+)(?:\s+"([^"]*)")?(.*)$`) -var browserAriaSnapshotRefPattern = regexp.MustCompile(`\[ref=([^\]]+)\]`) -var browserStrictModeCountPattern = regexp.MustCompile(`resolved to (\d+) elements`) -var browserEvaluatePageExpression = `([fnBody, timeoutMs]) => { - try { - const candidate = eval("(" + fnBody + ")"); - const result = typeof candidate === "function" ? 
candidate() : candidate; - if (result && typeof result.then === "function") { - return Promise.race([ - result, - new Promise((_, reject) => - setTimeout(() => reject(new Error("evaluate timed out after " + timeoutMs + "ms")), timeoutMs), - ), - ]); - } - return result; - } catch (err) { - throw new Error("Invalid evaluate function: " + (err && err.message ? err.message : String(err))); - } -}` -var browserEvaluateElementExpression = `(el, [fnBody, timeoutMs]) => { - try { - const candidate = eval("(" + fnBody + ")"); - const result = typeof candidate === "function" ? candidate(el) : candidate; - if (result && typeof result.then === "function") { - return Promise.race([ - result, - new Promise((_, reject) => - setTimeout(() => reject(new Error("evaluate timed out after " + timeoutMs + "ms")), timeoutMs), - ), - ]); - } - return result; - } catch (err) { - throw new Error("Invalid evaluate function: " + (err && err.message ? err.message : String(err))); - } -}` - -type browserAISnapshotResult struct { - Snapshot string - Truncated bool - Refs map[string]browserSnapshotRef - RefsJSON map[string]map[string]any - Stats map[string]any -} - -func collectBrowserAISnapshot(page playwright.Page, maxChars int) (*browserAISnapshotResult, error) { - snapshot, err := captureBrowserPrivateAISnapshot(page, 5000) - if err != nil { - return nil, err - } - items := parseBrowserAriaSnapshot(snapshot, false, 2000, true, 0) - refs := make(map[string]browserSnapshotRef, len(items)) - refsJSON := make(map[string]map[string]any, len(items)) - interactiveCount := 0 - for index, item := range items { - ref := strings.TrimSpace(item.Ref) - if ref == "" { - ref = fmt.Sprintf("e%d", index+1) - } - entryMode := "aria" - if strings.TrimSpace(item.AriaRef) == "" { - entryMode = "role" - } - entry := browserSnapshotRef{ - Role: item.Role, - Name: item.Name, - Nth: item.Nth, - Mode: entryMode, - AriaRef: item.AriaRef, - } - refs[ref] = entry - refsJSON[ref] = map[string]any{ - "role": entry.Role, - 
"name": entry.Name, - } - if entry.Nth > 0 { - refsJSON[ref]["nth"] = entry.Nth - } - if isBrowserInteractiveRole(entry.Role) { - interactiveCount += 1 - } - } - - truncated := false - if maxChars > 0 && len(snapshot) > maxChars { - snapshot = trimToMaxChars(snapshot, maxChars) - truncated = true - } - - lines := 0 - for _, line := range strings.Split(snapshot, "\n") { - if strings.TrimSpace(line) == "" { - continue - } - lines += 1 - } - return &browserAISnapshotResult{ - Snapshot: snapshot, - Truncated: truncated, - Refs: refs, - RefsJSON: refsJSON, - Stats: map[string]any{ - "lines": lines, - "chars": len(snapshot), - "refs": len(refsJSON), - "interactive": interactiveCount, - }, - }, nil -} - -func captureBrowserPrivateAISnapshot(page playwright.Page, timeoutMs int) (snapshot string, err error) { - defer func() { - if recover() != nil { - snapshot = "" - err = errBrowserSnapshotForAIUnavailable - } - }() - if page == nil { - return "", errBrowserSnapshotForAIUnavailable - } - - value := reflect.ValueOf(page) - if !value.IsValid() { - return "", errBrowserSnapshotForAIUnavailable - } - if value.Kind() == reflect.Interface { - value = value.Elem() - } - if value.Kind() == reflect.Pointer { - value = value.Elem() - } - if !value.IsValid() || value.Kind() != reflect.Struct { - return "", errBrowserSnapshotForAIUnavailable - } - - channelOwnerField := value.FieldByName("channelOwner") - if !channelOwnerField.IsValid() { - return "", errBrowserSnapshotForAIUnavailable - } - channelField := channelOwnerField.FieldByName("channel") - if !channelField.IsValid() || (channelField.Kind() == reflect.Pointer && channelField.IsNil()) { - return "", errBrowserSnapshotForAIUnavailable - } - sendMethod := channelField.MethodByName("Send") - if !sendMethod.IsValid() { - return "", errBrowserSnapshotForAIUnavailable - } - - calls := sendMethod.Call([]reflect.Value{ - reflect.ValueOf("snapshotForAI"), - reflect.ValueOf(map[string]any{ - "timeout": 
normalizeBrowserTimeoutMs(timeoutMs, 5000), - "track": "response", - }), - }) - if len(calls) != 2 { - return "", errBrowserSnapshotForAIUnavailable - } - if errValue := calls[1]; errValue.IsValid() && !errValue.IsNil() { - if invokeErr, ok := errValue.Interface().(error); ok { - return "", invokeErr - } - return "", fmt.Errorf("snapshotForAI call failed") - } - - result := calls[0].Interface() - snapshot = extractBrowserAISnapshotText(result) - if strings.TrimSpace(snapshot) == "" { - return "", errBrowserSnapshotForAIUnavailable - } - return snapshot, nil -} - -func extractBrowserAISnapshotText(raw any) string { - switch value := raw.(type) { - case map[string]any: - if full, ok := value["full"]; ok { - return fmt.Sprint(full) - } - if snapshot, ok := value["snapshot"]; ok { - return fmt.Sprint(snapshot) - } - case string: - return value - } - return "" -} - -func isBrowserSnapshotForAIUnavailable(err error) bool { - if err == nil { - return false - } - if errors.Is(err, errBrowserSnapshotForAIUnavailable) { - return true - } - message := strings.ToLower(strings.TrimSpace(err.Error())) - return strings.Contains(message, "_snapshotforai") || - strings.Contains(message, "snapshotforai") || - strings.Contains(message, "playwright snapshotforai is unavailable") -} - -func collectBrowserSnapshotItems( - page playwright.Page, - selector string, - frameSelector string, - interactive bool, - limit int, - refsMode string, - maxDepth int, -) ([]browserSnapshotItem, error) { - if limit <= 0 { - limit = defaultBrowserSnapshotLimit - } - locator := resolveBrowserSnapshotLocator(page, selector, frameSelector) - snapshot, err := locator.AriaSnapshot(playwright.LocatorAriaSnapshotOptions{ - Ref: playwright.Bool(strings.EqualFold(refsMode, "aria")), - }) - if err != nil { - return nil, err - } - return parseBrowserAriaSnapshot(snapshot, interactive, limit, strings.EqualFold(refsMode, "aria"), maxDepth), nil -} - -func resolveBrowserSnapshotLocator(page playwright.Page, selector 
string, frameSelector string) playwright.Locator { - selector = strings.TrimSpace(selector) - frameSelector = strings.TrimSpace(frameSelector) - if frameSelector != "" { - frame := page.FrameLocator(frameSelector) - if selector != "" { - return frame.Locator(selector) - } - return frame.Locator(":root") - } - if selector != "" { - return page.Locator(selector) - } - return page.Locator(":root") -} - -func parseBrowserAriaSnapshot( - snapshot string, - interactive bool, - limit int, - useAriaRefs bool, - maxDepth int, -) []browserSnapshotItem { - if limit <= 0 { - limit = defaultBrowserSnapshotLimit - } - lines := strings.Split(snapshot, "\n") - items := make([]browserSnapshotItem, 0, minInt(len(lines), limit)) - countByRoleName := map[string]int{} - nextRef := 0 - - for _, rawLine := range lines { - if len(items) >= limit { - break - } - line := strings.TrimRight(rawLine, "\r") - if strings.TrimSpace(line) == "" { - continue - } - - match := browserAriaSnapshotLinePattern.FindStringSubmatch(line) - if len(match) < 5 { - continue - } - role := strings.ToLower(strings.TrimSpace(match[2])) - if role == "" || strings.HasPrefix(role, "/") { - continue - } - if interactive && !isBrowserInteractiveRole(role) { - continue - } - - name := strings.TrimSpace(match[3]) - suffix := strings.TrimSpace(match[4]) - text := browserAriaSnapshotTextFromSuffix(suffix) - depth := browserAriaSnapshotDepth(line) - if maxDepth > 0 && depth > maxDepth { - continue - } - - ariaRef := "" - if refMatch := browserAriaSnapshotRefPattern.FindStringSubmatch(suffix); len(refMatch) >= 2 { - ariaRef = strings.TrimSpace(refMatch[1]) - } - - key := role + "\n" + name - nth := countByRoleName[key] - countByRoleName[key] = nth + 1 - - nextRef += 1 - ref := fmt.Sprintf("e%d", nextRef) - if useAriaRefs && ariaRef != "" { - ref = ariaRef - } - - item := browserSnapshotItem{ - Ref: ref, - AriaRef: ariaRef, - Role: role, - Name: name, - Text: text, - Depth: depth, - Nth: nth, - } - if item.Role == "" { - 
item.Role = "element" - } - if item.Name == "" { - item.Name = item.Text - } - if item.Name == "" { - item.Name = "(empty)" - } - items = append(items, item) - } - return items -} - -func browserAriaSnapshotDepth(line string) int { - spaces := 0 - for _, r := range line { - if r == ' ' { - spaces += 1 - continue - } - if r == '\t' { - spaces += 2 - continue - } - break - } - return spaces / 2 -} - -func browserAriaSnapshotTextFromSuffix(suffix string) string { - suffix = strings.TrimSpace(suffix) - if suffix == "" { - return "" - } - if idx := strings.LastIndex(suffix, ":"); idx >= 0 && idx+1 < len(suffix) { - return strings.TrimSpace(suffix[idx+1:]) - } - return "" -} - -func isBrowserInteractiveRole(role string) bool { - switch strings.ToLower(strings.TrimSpace(role)) { - case "button", - "link", - "textbox", - "checkbox", - "radio", - "combobox", - "listbox", - "menuitem", - "menuitemcheckbox", - "menuitemradio", - "option", - "searchbox", - "slider", - "spinbutton", - "switch", - "tab", - "treeitem": - return true - default: - return false - } -} - -func resolveBrowserUploadLocator(tab *browserTabState, inputRef string, ref string, element string) (playwright.Locator, error) { - switch { - case strings.TrimSpace(inputRef) != "": - return tab.Page.Locator(strings.TrimSpace(inputRef)), nil - case strings.TrimSpace(ref) != "": - return resolveBrowserRefLocator(tab, strings.TrimSpace(ref)) - case strings.TrimSpace(element) != "": - return tab.Page.Locator(strings.TrimSpace(element)), nil - default: - return nil, errors.New("upload requires ref/inputRef/element") - } -} - -func parseBrowserUploadPaths(payload toolArgs) ([]string, error) { - raw := getStringSliceArg(payload, "paths") - if len(raw) == 0 { - return nil, errors.New("paths are required") - } - rootDir, err := resolveBrowserUploadRootDir() - if err != nil { - return nil, err - } - paths := make([]string, 0, len(raw)) - for _, item := range raw { - trimmed := strings.TrimSpace(item) - if trimmed == "" { - 
continue - } - absPath, err := resolvePathWithinRoot(rootDir, trimmed) - if err != nil { - return nil, err - } - info, err := os.Stat(absPath) - if err != nil { - return nil, fmt.Errorf("upload path not found: %s", absPath) - } - if info.IsDir() { - return nil, fmt.Errorf("upload path must be a file: %s", absPath) - } - paths = append(paths, absPath) - } - if len(paths) == 0 { - return nil, errors.New("paths are required") - } - return paths, nil -} - -func resolveBrowserUploadRootDir() (string, error) { - dir := filepath.Join(os.TempDir(), "dreamcreator", "browser", "uploads") - if err := os.MkdirAll(dir, 0o755); err != nil { - return "", err - } - return filepath.Abs(dir) -} - -func resolvePathWithinRoot(rootDir string, requestedPath string) (string, error) { - root := strings.TrimSpace(rootDir) - if root == "" { - return "", errors.New("upload root is required") - } - rootAbs, err := filepath.Abs(root) - if err != nil { - return "", err - } - raw := strings.TrimSpace(requestedPath) - if raw == "" { - return "", errors.New("paths are required") - } - - var candidate string - if filepath.IsAbs(raw) { - candidate = filepath.Clean(raw) - } else { - candidate = filepath.Join(rootAbs, raw) - } - candidateAbs, err := filepath.Abs(candidate) - if err != nil { - return "", err - } - rel, err := filepath.Rel(rootAbs, candidateAbs) - if err != nil { - return "", err - } - rel = filepath.ToSlash(strings.TrimSpace(rel)) - if rel == "." || strings.HasPrefix(rel, "../") || rel == ".." 
{ - return "", fmt.Errorf("Invalid path: must stay within uploads directory (%s)", rootAbs) - } - return candidateAbs, nil -} - -func resolveBrowserRefLocator(tab *browserTabState, ref string) (playwright.Locator, error) { - if tab == nil { - return nil, errors.New("tab unavailable") - } - ref = strings.TrimSpace(ref) - if ref == "" { - return nil, errors.New("ref is required") - } - if tab.refs == nil { - return nil, errors.New("no snapshot refs available; run action=snapshot first") - } - entry, ok := tab.refs[ref] - if !ok { - return nil, fmt.Errorf("ref not found: %s (run action=snapshot again)", ref) - } - - mode := strings.ToLower(strings.TrimSpace(entry.Mode)) - if mode == "aria" { - ariaRef := strings.TrimSpace(entry.AriaRef) - if ariaRef == "" { - ariaRef = ref - } - if ariaRef == "" { - return nil, fmt.Errorf("ref selector missing for %s", ref) - } - locator := resolveBrowserScopedLocator(tab.Page, entry.Frame, "aria-ref="+ariaRef) - return locator, nil - } - if strings.TrimSpace(entry.Role) != "" { - locator, err := resolveBrowserRoleLocator(tab.Page, entry.Role, entry.Name, entry.Nth, entry.Frame) - if err == nil { - return locator, nil - } - // Fall back to selector-based resolution below. 
- } - - selector := strings.TrimSpace(entry.Selector) - if selector == "" { - return nil, fmt.Errorf("ref selector missing for %s", ref) - } - return resolveBrowserScopedLocator(tab.Page, entry.Frame, selector), nil -} - -func resolveBrowserScopedLocator(page playwright.Page, frameSelector string, selector string) playwright.Locator { - frameSelector = strings.TrimSpace(frameSelector) - selector = strings.TrimSpace(selector) - if frameSelector == "" { - if strings.HasPrefix(selector, "//") || strings.HasPrefix(selector, "/") { - return page.Locator("xpath=" + selector) - } - return page.Locator(selector) - } - frame := page.FrameLocator(frameSelector) - if strings.HasPrefix(selector, "//") || strings.HasPrefix(selector, "/") { - return frame.Locator("xpath=" + selector) - } - return frame.Locator(selector) -} - -func resolveBrowserRoleLocator(page playwright.Page, role string, name string, nth int, frameSelector string) (playwright.Locator, error) { - role = strings.TrimSpace(strings.ToLower(role)) - if role == "" { - return nil, errors.New("role is required") - } - - frameSelector = strings.TrimSpace(frameSelector) - name = strings.TrimSpace(name) - - var locator playwright.Locator - if frameSelector != "" { - frame := page.FrameLocator(frameSelector) - if name != "" { - locator = frame.GetByRole(playwright.AriaRole(role), playwright.FrameLocatorGetByRoleOptions{ - Name: name, - Exact: playwright.Bool(true), - }) - } else { - locator = frame.GetByRole(playwright.AriaRole(role)) - } - } else { - if name != "" { - locator = page.GetByRole(playwright.AriaRole(role), playwright.PageGetByRoleOptions{ - Name: name, - Exact: playwright.Bool(true), - }) - } else { - locator = page.GetByRole(playwright.AriaRole(role)) - } - } - if nth > 0 { - locator = locator.Nth(nth) - } - return locator, nil -} - -func listBrowserTabs(state *browserProfileState) ([]map[string]any, error) { - state.mu.Lock() - defer state.mu.Unlock() - pruneClosedTabsLocked(state) - items := 
make([]map[string]any, 0, len(state.tabs)) - ids := make([]string, 0, len(state.tabs)) - for id := range state.tabs { - ids = append(ids, id) - } - sort.Strings(ids) - for _, id := range ids { - tab := state.tabs[id] - title, _ := tab.Page.Title() - items = append(items, map[string]any{ - "targetId": tab.TargetID, - "title": strings.TrimSpace(title), - "url": strings.TrimSpace(tab.Page.URL()), - "type": "page", - }) - } - return items, nil -} - -func resolveBrowserTab(state *browserProfileState, targetID string, autoCreate bool) (*browserTabState, error) { - state.mu.Lock() - pruneClosedTabsLocked(state) - - targetID = strings.TrimSpace(targetID) - if targetID != "" { - if tab, ok := state.tabs[targetID]; ok { - state.activeTarget = tab.TargetID - state.mu.Unlock() - return tab, nil - } - matches := make([]*browserTabState, 0, len(state.tabs)) - targetLower := strings.ToLower(targetID) - for id, tab := range state.tabs { - if strings.HasPrefix(strings.ToLower(strings.TrimSpace(id)), targetLower) { - matches = append(matches, tab) - } - } - if len(matches) == 1 { - state.activeTarget = matches[0].TargetID - state.mu.Unlock() - return matches[0], nil - } - if len(matches) > 1 { - state.mu.Unlock() - return nil, errors.New("ambiguous target id prefix") - } - if len(state.tabs) == 1 { - for _, only := range state.tabs { - state.activeTarget = only.TargetID - state.mu.Unlock() - return only, nil - } - } - state.mu.Unlock() - return nil, errors.New("tab not found") - } - - if state.activeTarget != "" { - if tab, ok := state.tabs[state.activeTarget]; ok { - state.mu.Unlock() - return tab, nil - } - } - if len(state.tabs) == 1 { - for _, only := range state.tabs { - state.activeTarget = only.TargetID - state.mu.Unlock() - return only, nil - } - } - if len(state.tabs) > 1 { - ids := make([]string, 0, len(state.tabs)) - for id := range state.tabs { - ids = append(ids, id) - } - sort.Strings(ids) - first := state.tabs[ids[0]] - state.activeTarget = first.TargetID - 
state.mu.Unlock() - return first, nil - } - - browserCtx := state.context - state.mu.Unlock() - if !autoCreate || browserCtx == nil { - return nil, errors.New("no browser tab available") - } - page, err := browserCtx.NewPage() - if err != nil { - return nil, err - } - tab := attachBrowserTab(state, page) - return tab, nil -} - -func attachBrowserTab(state *browserProfileState, page playwright.Page) *browserTabState { - state.mu.Lock() - defer state.mu.Unlock() - if targetID, ok := state.pageToTarget[page]; ok { - if existing, exists := state.tabs[targetID]; exists { - state.activeTarget = existing.TargetID - return existing - } - } - targetID := fmt.Sprintf("T%d", atomic.AddUint64(&browserGlobalTabCounter, 1)) - tab := &browserTabState{ - TargetID: targetID, - Page: page, - refs: map[string]browserSnapshotRef{}, - } - state.tabs[targetID] = tab - state.pageToTarget[page] = targetID - state.activeTarget = targetID - attachBrowserPageObservers(state, tab) - return tab -} - -func attachBrowserPageObservers(state *browserProfileState, tab *browserTabState) { - if tab == nil || tab.Page == nil { - return - } - targetID := tab.TargetID - - tab.Page.OnConsole(func(message playwright.ConsoleMessage) { - entry := browserConsoleMessage{ - TargetID: targetID, - Type: strings.ToLower(strings.TrimSpace(message.Type())), - Text: trimToMaxChars(strings.TrimSpace(message.Text()), 4000), - Timestamp: time.Now().UTC().Format(time.RFC3339), - } - state.mu.Lock() - state.consoleMessages = append(state.consoleMessages, entry) - if len(state.consoleMessages) > 400 { - state.consoleMessages = append([]browserConsoleMessage(nil), state.consoleMessages[len(state.consoleMessages)-400:]...) 
- } - state.mu.Unlock() - }) - - tab.Page.OnDialog(func(dialog playwright.Dialog) { - state.mu.Lock() - pending, ok := state.pendingDialogs[targetID] - if ok { - delete(state.pendingDialogs, targetID) - } - state.mu.Unlock() - if ok && time.Now().Before(pending.ExpiresAt) { - if pending.Accept { - if pending.PromptText != "" { - _ = dialog.Accept(pending.PromptText) - } else { - _ = dialog.Accept() - } - } else { - _ = dialog.Dismiss() - } - return - } - _ = dialog.Dismiss() - }) - - tab.Page.OnFileChooser(func(chooser playwright.FileChooser) { - state.mu.Lock() - pending, ok := state.pendingUploads[targetID] - if ok { - delete(state.pendingUploads, targetID) - } - state.mu.Unlock() - if ok && time.Now().Before(pending.ExpiresAt) { - _ = chooser.SetFiles(pending.Paths) - } - }) -} - -func ensureBrowserProfileStarted(state *browserProfileState) error { - state.mu.Lock() - defer state.mu.Unlock() - if state.browser != nil && state.context != nil { - pruneClosedTabsLocked(state) - return nil - } - - var err error - if state.pw == nil { - state.pw, err = playwright.Run(&playwright.RunOptions{Verbose: false, Stdout: io.Discard, Stderr: io.Discard}) - if err != nil { - return err - } - } - - launchOptions := playwright.BrowserTypeLaunchOptions{ - Headless: playwright.Bool(state.resolved.Headless), - } - args := append([]string(nil), state.resolved.ExtraArgs...) 
- if state.resolved.NoSandbox { - args = append(args, "--no-sandbox") - } - if state.resolved.Headless && !hasBrowserHeadlessArg(args) { - args = append(args, "--headless=new") - } - if len(args) > 0 { - launchOptions.Args = args - } - state.browser, err = state.pw.Chromium.Launch(launchOptions) - if err != nil { - return err - } - state.context, err = state.browser.NewContext(playwright.BrowserNewContextOptions{ - Viewport: &playwright.Size{Width: defaultBrowserViewportWidth, Height: defaultBrowserViewportHeight}, - }) - if err != nil { - _ = state.browser.Close() - state.browser = nil - return err - } - - if state.tabs == nil { - state.tabs = map[string]*browserTabState{} - } - if state.pageToTarget == nil { - state.pageToTarget = map[playwright.Page]string{} - } - if state.pendingUploads == nil { - state.pendingUploads = map[string]browserPendingUpload{} - } - if state.pendingDialogs == nil { - state.pendingDialogs = map[string]browserPendingDialog{} - } - for _, page := range state.context.Pages() { - if page == nil || page.IsClosed() { - continue - } - if targetID, ok := state.pageToTarget[page]; ok { - if _, exists := state.tabs[targetID]; exists { - continue - } - } - targetID := fmt.Sprintf("T%d", atomic.AddUint64(&browserGlobalTabCounter, 1)) - tab := &browserTabState{TargetID: targetID, Page: page, refs: map[string]browserSnapshotRef{}} - state.tabs[targetID] = tab - state.pageToTarget[page] = targetID - if state.activeTarget == "" { - state.activeTarget = targetID - } - attachBrowserPageObservers(state, tab) - } - pruneClosedTabsLocked(state) - return nil -} - -func stopBrowserProfile(state *browserProfileState) error { - state.mu.Lock() - defer state.mu.Unlock() - - if state.context != nil { - _ = state.context.Close() - state.context = nil - } - if state.browser != nil { - _ = state.browser.Close() - state.browser = nil - } - if state.pw != nil { - _ = state.pw.Stop() - state.pw = nil - } - - state.tabs = map[string]*browserTabState{} - 
state.pageToTarget = map[playwright.Page]string{} - state.activeTarget = "" - state.pendingUploads = map[string]browserPendingUpload{} - state.pendingDialogs = map[string]browserPendingDialog{} - state.consoleMessages = nil - return nil -} - -func pruneClosedTabsLocked(state *browserProfileState) { - if state == nil { - return - } - for targetID, tab := range state.tabs { - if tab == nil || tab.Page == nil || tab.Page.IsClosed() { - delete(state.tabs, targetID) - if tab != nil && tab.Page != nil { - delete(state.pageToTarget, tab.Page) - } - if state.activeTarget == targetID { - state.activeTarget = "" - } - } - } - if state.activeTarget != "" { - if _, ok := state.tabs[state.activeTarget]; ok { - return - } - } - if len(state.tabs) == 0 { - state.activeTarget = "" - return - } - ids := make([]string, 0, len(state.tabs)) - for id := range state.tabs { - ids = append(ids, id) - } - sort.Strings(ids) - state.activeTarget = ids[0] -} - -func getOrCreateBrowserProfileState(sessionKey string, profileName string, resolved browserResolvedConfig) *browserProfileState { - globalBrowserSessions.mu.Lock() - defer globalBrowserSessions.mu.Unlock() - - session, ok := globalBrowserSessions.sessions[sessionKey] - if !ok { - session = &browserSessionState{ - sessionKey: sessionKey, - profiles: map[string]*browserProfileState{}, - } - globalBrowserSessions.sessions[sessionKey] = session - } - - state, ok := session.profiles[profileName] - if !ok { - profile := resolved.Profiles[profileName] - state = &browserProfileState{ - profileName: profileName, - resolved: resolved, - profile: profile, - tabs: map[string]*browserTabState{}, - pageToTarget: map[playwright.Page]string{}, - pendingUploads: map[string]browserPendingUpload{}, - pendingDialogs: map[string]browserPendingDialog{}, - consoleMessages: nil, - } - session.profiles[profileName] = state - return state - } - - state.mu.Lock() - state.resolved = resolved - if profile, exists := resolved.Profiles[profileName]; exists { - 
state.profile = profile - } - state.profileName = profileName - state.mu.Unlock() - return state -} - -func resolveBrowserRuntimeConfig(config map[string]any) browserResolvedConfig { - browserConfig := resolveBrowserConfig(config) - resolved := browserResolvedConfig{ - Enabled: true, - EvaluateEnabled: true, - Color: defaultBrowserColor, - Headless: false, - NoSandbox: false, - DefaultProfile: defaultBrowserProfileDreamCreator, - Profiles: map[string]browserProfileConfig{ - defaultBrowserProfileDreamCreator: { - Name: defaultBrowserProfileDreamCreator, - Color: defaultBrowserColor, - Driver: browserTypePlaywright, - }, - }, - SSRFRules: browserSSRFPolicy{ - DangerouslyAllowPrivateNetwork: true, - AllowedHostnames: map[string]struct{}{}, - HostnameAllowlist: nil, - }, - ExtraArgs: nil, - } - if browserConfig == nil { - return resolved - } - - if value, ok := getBoolArg(toolArgs(browserConfig), "enabled"); ok { - resolved.Enabled = value - } - if value, ok := getBoolArg(toolArgs(browserConfig), "evaluateEnabled"); ok { - resolved.EvaluateEnabled = value - } - if value := strings.TrimSpace(getStringArg(toolArgs(browserConfig), "color")); value != "" { - resolved.Color = value - } - if value, ok := getBoolArg(toolArgs(browserConfig), "headless"); ok { - resolved.Headless = value - } - if value, ok := getBoolArg(toolArgs(browserConfig), "noSandbox"); ok { - resolved.NoSandbox = value - } - if values := normalizeStringSlice(getStringSliceArg(toolArgs(browserConfig), "extraArgs")); len(values) > 0 { - resolved.ExtraArgs = values - } - - if snapshotDefaults := getMapArg(toolArgs(browserConfig), "snapshotDefaults"); snapshotDefaults != nil { - if value := strings.TrimSpace(getStringArg(toolArgs(snapshotDefaults), "mode")); value == defaultBrowserSnapshotModeEfficient { - resolved.SnapshotDefaultMode = value - } - } - - if ssrfRaw := getMapArg(toolArgs(browserConfig), "ssrfPolicy"); ssrfRaw != nil { - if value, ok := getBoolArg(toolArgs(ssrfRaw), 
"dangerouslyAllowPrivateNetwork"); ok { - resolved.SSRFRules.DangerouslyAllowPrivateNetwork = value - } else if value, ok := getBoolArg(toolArgs(ssrfRaw), "allowPrivateNetwork"); ok { - resolved.SSRFRules.DangerouslyAllowPrivateNetwork = value - } - for _, hostname := range getStringSliceArg(toolArgs(ssrfRaw), "allowedHostnames") { - resolved.SSRFRules.AllowedHostnames[strings.ToLower(strings.TrimSpace(hostname))] = struct{}{} - } - if allowlist := normalizeStringSlice(getStringSliceArg(toolArgs(ssrfRaw), "hostnameAllowlist")); len(allowlist) > 0 { - resolved.SSRFRules.HostnameAllowlist = allowlist - } - } - - if _, ok := resolved.Profiles[resolved.DefaultProfile]; !ok { - resolved.DefaultProfile = defaultBrowserProfileDreamCreator - } - for name, profile := range resolved.Profiles { - if profile.Name == "" { - profile.Name = name - } - if profile.Color == "" { - profile.Color = resolved.Color - } - if profile.Driver == "" { - profile.Driver = browserTypePlaywright - } - resolved.Profiles[name] = profile - } - - return resolved -} - -func resolveBrowserProfileName(payload toolArgs, resolved browserResolvedConfig) string { - profile := strings.TrimSpace(getStringArg(payload, "profile")) - if profile == "" { - profile = strings.TrimSpace(resolved.DefaultProfile) - } - if profile == "" { - profile = defaultBrowserProfileDreamCreator - } - if _, ok := resolved.Profiles[profile]; !ok { - profileCfg := browserProfileConfig{ - Name: profile, - Color: resolved.Color, - Driver: browserTypePlaywright, - } - resolved.Profiles[profile] = profileCfg - } - return profile -} - -func resolveBrowserConfig(config map[string]any) map[string]any { - return getNestedMap(config, "browser") -} - -func resolveBrowserConfigBool(config map[string]any, key string) (bool, bool) { - return getBoolArg(toolArgs(resolveBrowserConfig(config)), key) -} - -func resolveBrowserType(payload toolArgs, config map[string]any) string { - if raw := getStringArg(payload, "type", "mode", "engine"); raw != "" 
{ - return normalizeBrowserType(raw) - } - if payloadBrowser := getMapArg(payload, "browser"); payloadBrowser != nil { - if raw := getStringArg(toolArgs(payloadBrowser), "type", "mode", "engine"); raw != "" { - return normalizeBrowserType(raw) - } - } - if raw := getStringArg(toolArgs(resolveBrowserConfig(config)), "type", "mode", "engine"); raw != "" { - return normalizeBrowserType(raw) - } - return defaultBrowserType -} - -func normalizeBrowserType(value string) string { - switch strings.ToLower(strings.TrimSpace(value)) { - case browserTypePlaywright, "browser", "headless", "chromium": - return browserTypePlaywright - default: - return "" - } -} - -func resolveBrowserWaitUntil(payload toolArgs, fallback string) string { - if value := normalizeBrowserWaitUntil(getStringArg(payload, "waitUntil")); value != "" { - return value - } - return normalizeBrowserWaitUntil(fallback) -} - -func resolveBrowserWaitUntilState(value string) *playwright.WaitUntilState { - switch normalizeBrowserWaitUntil(value) { - case "commit": - return playwright.WaitUntilStateCommit - case "load": - return playwright.WaitUntilStateLoad - case "networkidle": - return playwright.WaitUntilStateNetworkidle - default: - return playwright.WaitUntilStateDomcontentloaded - } -} - -func normalizeBrowserWaitUntil(value string) string { - switch strings.ToLower(strings.TrimSpace(value)) { - case "commit": - return "commit" - case "load": - return "load" - case "networkidle", "network_idle": - return "networkidle" - case "domcontentloaded", "dom_content_loaded", "dom-content-loaded": - return "domcontentloaded" - default: - return "" - } -} - -func resolveBrowserLoadState(value string) *playwright.LoadState { - switch strings.ToLower(strings.TrimSpace(value)) { - case "load": - return playwright.LoadStateLoad - case "domcontentloaded", "dom_content_loaded", "dom-content-loaded": - return playwright.LoadStateDomcontentloaded - case "networkidle", "network_idle": - return playwright.LoadStateNetworkidle - 
default: - return nil - } -} - -func resolveBrowserActionTimeoutMs(payload toolArgs, fallback int) int { - if timeoutMs, ok := getIntArg(payload, "timeoutMs"); ok && timeoutMs > 0 { - return normalizeBrowserTimeoutMs(timeoutMs, fallback) - } - if timeoutSeconds, ok := getIntArg(payload, "timeoutSeconds"); ok && timeoutSeconds > 0 { - return normalizeBrowserTimeoutMs(timeoutSeconds*1000, fallback) - } - return normalizeBrowserTimeoutMs(fallback, fallback) -} - -func resolveBrowserActTimeoutMs(request toolArgs, payload toolArgs, fallback int) int { - if timeoutMs, ok := getIntArg(request, "timeoutMs"); ok && timeoutMs > 0 { - return normalizeBrowserTimeoutMs(timeoutMs, fallback) - } - if timeoutSeconds, ok := getIntArg(request, "timeoutSeconds"); ok && timeoutSeconds > 0 { - return normalizeBrowserTimeoutMs(timeoutSeconds*1000, fallback) - } - return resolveBrowserActionTimeoutMs(payload, fallback) -} - -func normalizeBrowserTimeoutMs(value int, fallback int) int { - if value <= 0 { - value = fallback - } - if value < 500 { - return 500 - } - if value > 120000 { - return 120000 - } - return value -} - -func toBrowserFriendlyInteractionError(err error, selector string) error { - if err == nil { - return nil - } - message := strings.TrimSpace(err.Error()) - if message == "" { - return err - } - if strings.Contains(message, "strict mode violation") { - count := "multiple" - if match := browserStrictModeCountPattern.FindStringSubmatch(message); len(match) >= 2 { - count = strings.TrimSpace(match[1]) - } - return fmt.Errorf(`Selector "%s" matched %s elements. Run a new snapshot to get updated refs, or use a different ref.`, selector, count) - } - if (strings.Contains(message, "Timeout") || strings.Contains(message, "waiting for")) && - (strings.Contains(message, "to be visible") || strings.Contains(message, "not visible")) { - return fmt.Errorf(`Element "%s" not found or not visible. 
Run a new snapshot to see current page elements.`, selector) - } - if strings.Contains(message, "intercepts pointer events") || - strings.Contains(message, "not receive pointer events") || - strings.Contains(message, "not visible") { - return fmt.Errorf(`Element "%s" is not interactable (hidden or covered). Try scrolling it into view, closing overlays, or re-snapshotting.`, selector) - } - return err -} - -func addConnectorCookiesToContext(ctx context.Context, connectors ConnectorsReader, browserCtx playwright.BrowserContext, targetURL string) error { - if connectors == nil || browserCtx == nil { - return nil - } - cookies, err := resolveConnectorCookiesForURL(ctx, connectors, targetURL) - if err != nil { - return err - } - if len(cookies) == 0 { - return nil - } - return browserCtx.AddCookies(toPlaywrightCookies(cookies, targetURL)) -} - -func assertBrowserURLAllowed(rawURL string, policy browserSSRFPolicy) error { - trimmed := strings.TrimSpace(rawURL) - if trimmed == "" { - return errors.New("targetUrl is required") - } - parsed, err := url.Parse(trimmed) - if err != nil { - return fmt.Errorf("invalid url: %w", err) - } - if !strings.EqualFold(parsed.Scheme, "http") && !strings.EqualFold(parsed.Scheme, "https") { - return errors.New("only http(s) urls are supported") - } - hostname := strings.ToLower(strings.TrimSpace(parsed.Hostname())) - if hostname == "" { - return errors.New("url hostname is required") - } - if isHostnameExplicitlyAllowed(hostname, policy) { - return nil - } - if policy.DangerouslyAllowPrivateNetwork { - return nil - } - if hostname == "localhost" || strings.HasSuffix(hostname, ".local") || strings.HasSuffix(hostname, ".internal") { - return fmt.Errorf("blocked private hostname: %s", hostname) - } - if ip := net.ParseIP(hostname); ip != nil { - if isPrivateOrLocalIP(ip) { - return fmt.Errorf("blocked private IP: %s", hostname) - } - } - return nil -} - -func isHostnameExplicitlyAllowed(hostname string, policy browserSSRFPolicy) bool { - 
hostname = strings.ToLower(strings.TrimSpace(hostname)) - if hostname == "" { - return false - } - if _, ok := policy.AllowedHostnames[hostname]; ok { - return true - } - for _, pattern := range policy.HostnameAllowlist { - pattern = strings.ToLower(strings.TrimSpace(pattern)) - if pattern == "" { - continue - } - if pattern == hostname { - return true - } - if strings.HasPrefix(pattern, "*.") { - suffix := strings.TrimPrefix(pattern, "*.") - if strings.HasSuffix(hostname, "."+suffix) || hostname == suffix { - return true - } - continue - } - if matched, _ := filepath.Match(pattern, hostname); matched { - return true - } - } - return false -} - -func isPrivateOrLocalIP(ip net.IP) bool { - if ip == nil { - return false - } - if ip.IsLoopback() || ip.IsPrivate() || ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() { - return true - } - if ip.IsUnspecified() || ip.IsMulticast() { - return true - } - if ip4 := ip.To4(); ip4 != nil { - // Carrier-grade NAT: 100.64.0.0/10 - if ip4[0] == 100 && (ip4[1]&0xC0) == 64 { - return true - } - } - return false -} - -func waitBrowserEvaluateCondition(page playwright.Page, fn string, timeoutMs int) error { - if timeoutMs <= 0 { - timeoutMs = 15000 - } - deadline := time.Now().Add(time.Duration(timeoutMs) * time.Millisecond) - for { - result, err := page.Evaluate(fn) - if err == nil { - if value, ok := result.(bool); ok && value { - return nil - } - if result != nil { - switch typed := result.(type) { - case string: - if strings.TrimSpace(strings.ToLower(typed)) == "true" { - return nil - } - case float64: - if typed != 0 { - return nil - } - case int: - if typed != 0 { - return nil - } - } - } - } - if time.Now().After(deadline) { - if err != nil { - return err - } - return errors.New("wait fn timeout") - } - time.Sleep(200 * time.Millisecond) - } -} - -func saveBrowserArtifact(ext string, content []byte) (string, error) { - ext = strings.TrimSpace(strings.TrimPrefix(ext, ".")) - if ext == "" { - ext = "bin" - } - dir := 
filepath.Join(os.TempDir(), "dreamcreator", "browser") - if err := os.MkdirAll(dir, 0o755); err != nil { - return "", err - } - filename := fmt.Sprintf("%d-%d.%s", time.Now().UnixNano(), atomic.AddUint64(&browserGlobalTabCounter, 1), ext) - path := filepath.Join(dir, filename) - if err := os.WriteFile(path, content, 0o644); err != nil { - return "", err - } - abs, err := filepath.Abs(path) - if err != nil { - return path, nil - } - return abs, nil -} - -func resolveBrowserPlaywrightRuntimeAvailability() (bool, string, string) { - now := time.Now() - browserPlaywrightRuntimeCache.mu.Lock() - if !browserPlaywrightRuntimeCache.checkedAt.IsZero() && now.Sub(browserPlaywrightRuntimeCache.checkedAt) < browserRuntimeCheckCacheTTL { - available := browserPlaywrightRuntimeCache.available - reason := browserPlaywrightRuntimeCache.reason - execPath := browserPlaywrightRuntimeCache.execPath - browserPlaywrightRuntimeCache.mu.Unlock() - return available, reason, execPath - } - browserPlaywrightRuntimeCache.mu.Unlock() - - available := true - reason := "" - execPath := "" - - pw, err := playwright.Run(&playwright.RunOptions{ - Verbose: false, - Stdout: io.Discard, - Stderr: io.Discard, - }) - if err != nil { - available = false - reason = trimToMaxChars(strings.TrimSpace(err.Error()), 220) - } else { - defer pw.Stop() - execPath = strings.TrimSpace(pw.Chromium.ExecutablePath()) - - browser, launchErr := pw.Chromium.Launch(playwright.BrowserTypeLaunchOptions{ - Headless: playwright.Bool(true), - Args: []string{"--headless=new"}, - }) - if launchErr != nil { - available = false - reason = trimToMaxChars(strings.TrimSpace(launchErr.Error()), 220) - } else { - _ = browser.Close() - } - } - - browserPlaywrightRuntimeCache.mu.Lock() - browserPlaywrightRuntimeCache.checkedAt = now - browserPlaywrightRuntimeCache.available = available - browserPlaywrightRuntimeCache.reason = reason - browserPlaywrightRuntimeCache.execPath = execPath - browserPlaywrightRuntimeCache.mu.Unlock() - - return 
available, reason, execPath -} - -func hasBrowserHeadlessArg(args []string) bool { - for _, arg := range args { - if strings.Contains(strings.ToLower(strings.TrimSpace(arg)), "headless") { - return true - } - } - return false -} - -func toPlaywrightScreenshotType(value string) *playwright.ScreenshotType { - switch strings.ToLower(strings.TrimSpace(value)) { - case "jpeg", "jpg": - return playwright.ScreenshotTypeJpeg - default: - return playwright.ScreenshotTypePng - } -} - -func toPlaywrightMouseButton(value string) *playwright.MouseButton { - switch strings.ToLower(strings.TrimSpace(value)) { - case "left": - return playwright.MouseButtonLeft - case "right": - return playwright.MouseButtonRight - case "middle": - return playwright.MouseButtonMiddle - default: - return nil - } -} - -func toPlaywrightKeyboardModifiers(values []string) ([]playwright.KeyboardModifier, error) { - if len(values) == 0 { - return nil, nil - } - result := make([]playwright.KeyboardModifier, 0, len(values)) - for _, value := range values { - switch strings.ToLower(strings.TrimSpace(value)) { - case "alt": - if playwright.KeyboardModifierAlt != nil { - result = append(result, *playwright.KeyboardModifierAlt) - } - case "control", "ctrl": - if playwright.KeyboardModifierControl != nil { - result = append(result, *playwright.KeyboardModifierControl) - } - case "controlormeta", "control_or_meta": - if playwright.KeyboardModifierControlOrMeta != nil { - result = append(result, *playwright.KeyboardModifierControlOrMeta) - } - case "meta", "command", "cmd": - if playwright.KeyboardModifierMeta != nil { - result = append(result, *playwright.KeyboardModifierMeta) - } - case "shift": - if playwright.KeyboardModifierShift != nil { - result = append(result, *playwright.KeyboardModifierShift) - } - case "": - continue - default: - return nil, errors.New("modifiers must be Alt|Control|ControlOrMeta|Meta|Shift") - } - } - if len(result) == 0 { - return nil, nil - } - return result, nil -} - -func 
containsString(values []string, candidate string) bool { - for _, value := range values { - if value == candidate { - return true - } - } - return false -} - -func anyToString(value any) string { - switch typed := value.(type) { - case string: - return typed - case fmt.Stringer: - return typed.String() - case float64: - return strconv.FormatFloat(typed, 'f', -1, 64) - case int: - return strconv.Itoa(typed) - case int64: - return strconv.FormatInt(typed, 10) - default: - return "" - } -} - -func anyToInt(value any) int { - switch typed := value.(type) { - case int: - return typed - case int64: - return int(typed) - case float64: - return int(typed) - case string: - parsed, _ := strconv.Atoi(strings.TrimSpace(typed)) - return parsed - default: - return 0 - } -} - -func minInt(a int, b int) int { - if a <= b { - return a - } - return b + return "" } -func maxInt(a int, b int) int { - if a >= b { - return a +func browserResultItemCount(result browsercdp.ActionResult) int { + if result.State != nil && result.State.ItemCount > 0 { + return result.State.ItemCount } - return b + return len(result.Items) } diff --git a/internal/application/gateway/tools/browser_tools_helpers.go b/internal/application/gateway/tools/browser_tools_helpers.go new file mode 100644 index 0000000..9b9d363 --- /dev/null +++ b/internal/application/gateway/tools/browser_tools_helpers.go @@ -0,0 +1,779 @@ +package tools + +import ( + "context" + "encoding/base64" + "encoding/json" + "errors" + "fmt" + "net" + "os" + "path/filepath" + "regexp" + "strings" + "sync/atomic" + "time" + + "dreamcreator/internal/application/browsercdp" + gatewaynodes "dreamcreator/internal/application/gateway/nodes" + targetpkg "github.com/chromedp/cdproto/target" +) + +const ( + browserTypeCDP = "cdp" + defaultBrowserType = browserTypeCDP + + defaultBrowserWaitUntil = "domcontentloaded" + + defaultBrowserProfileDreamCreator = "dreamcreator" + defaultBrowserColor = "#FF4500" + + defaultBrowserSnapshotAIMaxChars = 80000 + 
defaultBrowserSnapshotAIEfficientMaxChars = 10000 + defaultBrowserSnapshotDepth = 6 + defaultBrowserSnapshotLimit = 200 + defaultBrowserViewportWidth = 1366 + defaultBrowserViewportHeight = 900 +) + +var browserToolActions = []string{ + "open", + "navigate", + "snapshot", + "act", + "wait", + "scroll", + "upload", + "dialog", + "reset", +} + +var browserSelectorUnsupportedMessage = strings.Join([]string{ + "Error: 'selector' is not supported. Use 'ref' from snapshot instead.", + "", + "Example workflow:", + "1. open or navigate to load the page", + "2. read returned items, or run snapshot to get fresh refs", + `3. act with ref: "e123" to interact with an element`, + "4. after the page changes, read returned items or run snapshot again", + "", + "This is more reliable for modern SPAs.", +}, "\n") + +var browserWaitRequiresConditionMessage = "wait requires at least one of: timeMs, text, textGone, selector, url, fn" +var errBrowserNoOpenTab = errors.New("no browser tab is open") +var errBrowserSnapshotForAIUnavailable = errors.New("browser snapshotForAI is unavailable") +var browserActKinds = []string{"click", "type", "press", "hover", "select", "fill", "resize", "wait", "evaluate", "close"} +var browserGlobalTabCounter uint64 + +type browserSnapshotItem struct { + Ref string `json:"ref,omitempty"` + AriaRef string `json:"ariaRef,omitempty"` + Role string `json:"role,omitempty"` + Name string `json:"name,omitempty"` + Text string `json:"text,omitempty"` + Depth int `json:"depth,omitempty"` + Nth int `json:"nth,omitempty"` +} + +var browserAriaSnapshotLinePattern = regexp.MustCompile(`^(\s*)-\s*([^\s":]+)(?:\s+"([^"]*)")?(.*)$`) +var browserAriaSnapshotRefPattern = regexp.MustCompile(`\[ref=([^\]]+)\]`) +var browserStrictModeCountPattern = regexp.MustCompile(`resolved to (\d+) elements`) + +type browserResolvedConfig struct { + Enabled bool + Color string + Headless bool + DefaultProfile string + PreferredBrowser string + Profiles map[string]browserProfileConfig + 
SSRFRules browserSSRFPolicy +} + +type browserProfileConfig struct { + Name string + Color string + Driver string +} + +type browserSSRFPolicy struct { + DangerouslyAllowPrivateNetwork bool + AllowedHostnames map[string]struct{} + HostnameAllowlist []string +} + +type browserNodeProxyEnvelope struct { + Result any `json:"result"` + Files []browserNodeProxyFile `json:"files"` +} + +type browserNodeProxyFile struct { + Path string `json:"path"` + Base64 string `json:"base64"` + MimeType string `json:"mimeType"` +} + +func resolveBrowserAction(payload toolArgs) (string, error) { + rawAction := strings.ToLower(strings.TrimSpace(getStringArg(payload, "action"))) + if rawAction == "" { + return "", errors.New("browser action is required") + } + switch rawAction { + case "open", "navigate", "snapshot", "act", "wait", "scroll", "upload", "dialog", "reset": + return rawAction, nil + default: + return "", errors.New("browser action not supported: " + rawAction) + } +} + +func pickReusableBrowserTargetID(infos []*targetpkg.Info) string { + choose := func(requireUnattached bool, preferBlank bool) string { + for _, info := range infos { + if info == nil || info.Type != "page" { + continue + } + if requireUnattached && info.Attached { + continue + } + if preferBlank && !isReusableBrowserPageURL(info.URL) { + continue + } + return string(info.TargetID) + } + return "" + } + for _, candidate := range []string{ + choose(true, true), + choose(true, false), + choose(false, true), + choose(false, false), + } { + if strings.TrimSpace(candidate) != "" { + return candidate + } + } + return "" +} + +func isReusableBrowserPageURL(rawURL string) bool { + trimmed := strings.TrimSpace(strings.ToLower(rawURL)) + switch trimmed { + case "", "about:blank", "chrome://newtab/", "chrome-search://local-ntp/local-ntp.html": + return true + default: + return false + } +} + +func shouldTreatBrowserNavigationAsComplete(observedURL string, previousURL string, targetURL string) bool { + observed := 
strings.TrimSpace(observedURL) + if observed == "" || observed == "about:blank" { + return false + } + if urlsEqual(observed, targetURL) { + return true + } + previous := strings.TrimSpace(previousURL) + if previous == "" || previous == "about:blank" { + return true + } + return !urlsEqual(observed, previous) +} + +func urlsEqual(left string, right string) bool { + return strings.TrimSpace(left) == strings.TrimSpace(right) +} + +func shouldResetBrowserProfileAfterError(err error) bool { + if err == nil { + return false + } + message := strings.ToLower(strings.TrimSpace(err.Error())) + switch { + case strings.Contains(message, "browser runtime unavailable"), + strings.Contains(message, "context canceled"), + strings.Contains(message, "target closed"), + strings.Contains(message, "connection closed"), + strings.Contains(message, "websocket"), + strings.Contains(message, "session closed"), + strings.Contains(message, "browser session reset"): + return true + default: + return false + } +} + +func shouldDeferBrowserStateCaptureError(err error) bool { + if err == nil { + return false + } + message := strings.ToLower(strings.TrimSpace(err.Error())) + switch { + case strings.Contains(message, "context deadline exceeded"), + strings.Contains(message, "execution context was destroyed"), + strings.Contains(message, "cannot find context with specified id"), + strings.Contains(message, "unique context id not found"): + return true + default: + return false + } +} + +func isBrowserNodeTargetRequest(payload toolArgs) bool { + target := strings.ToLower(strings.TrimSpace(getStringArg(payload, "target"))) + nodeID := strings.TrimSpace(getStringArg(payload, "node", "nodeId")) + return target == "node" || nodeID != "" +} + +func resolveBrowserNodeID(ctx context.Context, payload toolArgs, nodes *gatewaynodes.Service) (string, error) { + requestedNode := strings.TrimSpace(getStringArg(payload, "node", "nodeId")) + if requestedNode != "" { + return requestedNode, nil + } + if nodes == 
nil { + return "", errors.New("nodes service unavailable") + } + list, err := nodes.ListNodes(ctx) + if err != nil { + return "", err + } + for _, descriptor := range list { + nodeID := strings.TrimSpace(descriptor.NodeID) + if nodeID == "" { + continue + } + for _, capability := range descriptor.Capabilities { + if strings.EqualFold(strings.TrimSpace(capability.Name), "browser.control") { + return nodeID, nil + } + } + } + for _, descriptor := range list { + if nodeID := strings.TrimSpace(descriptor.NodeID); nodeID != "" { + return nodeID, nil + } + } + return "", errors.New("nodeId is required") +} + +func runBrowserActionOnNode(ctx context.Context, payload toolArgs, action string, nodes *gatewaynodes.Service) (string, error) { + if nodes == nil { + return "", errors.New("nodes service unavailable") + } + target := strings.ToLower(strings.TrimSpace(getStringArg(payload, "target"))) + if target != "" && target != "node" { + return "", errors.New(`node is only supported with target="node"`) + } + nodeID, err := resolveBrowserNodeID(ctx, payload, nodes) + if err != nil { + return "", err + } + argsJSON, err := json.Marshal(payload) + if err != nil { + return "", err + } + request := gatewaynodes.NodeInvokeRequest{ + NodeID: nodeID, + Capability: "browser.control", + Action: action, + Args: string(argsJSON), + TimeoutMs: resolveBrowserActionTimeoutMs(payload, 30000), + } + result, invokeErr := nodes.Invoke(ctx, request) + if invokeErr != nil { + return marshalResult(result), invokeErr + } + if !result.Ok { + if strings.TrimSpace(result.Error) != "" { + return marshalResult(result), errors.New(strings.TrimSpace(result.Error)) + } + return marshalResult(result), errors.New("node browser invoke failed") + } + if parsed := resolveBrowserNodeOutput(result.Output); parsed != nil { + return marshalResult(parsed), nil + } + return marshalResult(result), nil +} + +func resolveBrowserNodeOutput(output string) any { + trimmedOutput := strings.TrimSpace(output) + if 
trimmedOutput == "" { + return nil + } + var parsed any + if err := json.Unmarshal([]byte(trimmedOutput), &parsed); err != nil { + return nil + } + envelope := browserNodeProxyEnvelope{} + if err := json.Unmarshal([]byte(trimmedOutput), &envelope); err == nil && envelope.Result != nil { + mapping := persistBrowserNodeProxyFiles(envelope.Files) + applyBrowserProxyPathMapping(envelope.Result, mapping) + return envelope.Result + } + return parsed +} + +func persistBrowserNodeProxyFiles(files []browserNodeProxyFile) map[string]string { + if len(files) == 0 { + return nil + } + mapping := map[string]string{} + for _, file := range files { + remotePath := strings.TrimSpace(file.Path) + encoded := strings.TrimSpace(file.Base64) + if remotePath == "" || encoded == "" { + continue + } + bytes, err := base64.StdEncoding.DecodeString(encoded) + if err != nil { + continue + } + localPath, err := saveBrowserArtifact(resolveBrowserProxyFileExt(file), bytes) + if err != nil { + continue + } + mapping[remotePath] = localPath + } + if len(mapping) == 0 { + return nil + } + return mapping +} + +func resolveBrowserProxyFileExt(file browserNodeProxyFile) string { + ext := strings.TrimSpace(strings.TrimPrefix(filepath.Ext(strings.TrimSpace(file.Path)), ".")) + if ext != "" { + return ext + } + mimeType := strings.ToLower(strings.TrimSpace(file.MimeType)) + switch { + case strings.Contains(mimeType, "png"): + return "png" + case strings.Contains(mimeType, "jpeg"), strings.Contains(mimeType, "jpg"): + return "jpg" + case strings.Contains(mimeType, "pdf"): + return "pdf" + case strings.Contains(mimeType, "json"): + return "json" + case strings.Contains(mimeType, "text"), strings.Contains(mimeType, "plain"): + return "txt" + default: + return "bin" + } +} + +func applyBrowserProxyPathMapping(result any, mapping map[string]string) { + if len(mapping) == 0 || result == nil { + return + } + obj, ok := result.(map[string]any) + if !ok { + return + } + if pathValue, ok := obj["path"].(string); 
ok { + if mapped, exists := mapping[pathValue]; exists { + obj["path"] = mapped + } + } + if imagePathValue, ok := obj["imagePath"].(string); ok { + if mapped, exists := mapping[imagePathValue]; exists { + obj["imagePath"] = mapped + } + } +} + +func resolveBrowserSessionKey(ctx context.Context, payload toolArgs) string { + sessionKey, _ := RuntimeContextFromContext(ctx) + sessionKey = strings.TrimSpace(sessionKey) + if sessionKey == "" { + sessionKey = strings.TrimSpace(getStringArg(payload, "sessionKey", "session_key")) + } + if sessionKey == "" { + sessionKey = "default" + } + return sessionKey +} + +func resolveBrowserRuntimeConfig(config map[string]any) browserResolvedConfig { + browserConfig := getNestedMap(config, "browser") + resolved := browserResolvedConfig{ + Enabled: true, + Color: defaultBrowserColor, + Headless: false, + DefaultProfile: defaultBrowserProfileDreamCreator, + PreferredBrowser: string(browsercdp.BrowserChrome), + Profiles: map[string]browserProfileConfig{ + defaultBrowserProfileDreamCreator: { + Name: defaultBrowserProfileDreamCreator, + Color: defaultBrowserColor, + Driver: browserTypeCDP, + }, + }, + SSRFRules: browserSSRFPolicy{ + DangerouslyAllowPrivateNetwork: false, + AllowedHostnames: map[string]struct{}{}, + }, + } + if browserConfig == nil { + return resolved + } + if value, ok := getBoolArg(toolArgs(browserConfig), "enabled"); ok { + resolved.Enabled = value + } + if value := strings.TrimSpace(getStringArg(toolArgs(browserConfig), "color")); value != "" { + resolved.Color = value + } + if value, ok := getBoolArg(toolArgs(browserConfig), "headless"); ok { + resolved.Headless = value + } + if value := strings.TrimSpace(getStringArg(toolArgs(browserConfig), "preferredBrowser")); value != "" { + resolved.PreferredBrowser = strings.ToLower(value) + } + if ssrfRaw := getMapArg(toolArgs(browserConfig), "ssrfPolicy"); ssrfRaw != nil { + if value, ok := getBoolArg(toolArgs(ssrfRaw), "dangerouslyAllowPrivateNetwork"); ok { + 
resolved.SSRFRules.DangerouslyAllowPrivateNetwork = value + } + for _, hostname := range getStringSliceArg(toolArgs(ssrfRaw), "allowedHostnames") { + resolved.SSRFRules.AllowedHostnames[strings.ToLower(strings.TrimSpace(hostname))] = struct{}{} + } + if allowlist := normalizeStringSlice(getStringSliceArg(toolArgs(ssrfRaw), "hostnameAllowlist")); len(allowlist) > 0 { + resolved.SSRFRules.HostnameAllowlist = allowlist + } + } + return resolved +} + +func resolveBrowserProfileName(payload toolArgs, resolved browserResolvedConfig) string { + profile := strings.TrimSpace(getStringArg(payload, "profile")) + if profile == "" { + profile = strings.TrimSpace(resolved.DefaultProfile) + } + if profile == "" { + profile = defaultBrowserProfileDreamCreator + } + return profile +} + +func resolveBrowserConfig(config map[string]any) map[string]any { + return getNestedMap(config, "browser") +} + +func resolveBrowserConfigBool(config map[string]any, key string) (bool, bool) { + return getBoolArg(toolArgs(resolveBrowserConfig(config)), key) +} + +func resolveBrowserActionTimeoutMs(payload toolArgs, fallback int) int { + if timeoutMs, ok := getIntArg(payload, "timeoutMs"); ok && timeoutMs > 0 { + return normalizeBrowserTimeoutMs(timeoutMs, fallback) + } + if timeoutSeconds, ok := getIntArg(payload, "timeoutSeconds"); ok && timeoutSeconds > 0 { + return normalizeBrowserTimeoutMs(timeoutSeconds*1000, fallback) + } + return normalizeBrowserTimeoutMs(fallback, fallback) +} + +func resolveBrowserActTimeoutMs(request toolArgs, payload toolArgs, fallback int) int { + if timeoutMs, ok := getIntArg(request, "timeoutMs"); ok && timeoutMs > 0 { + return normalizeBrowserTimeoutMs(timeoutMs, fallback) + } + if timeoutSeconds, ok := getIntArg(request, "timeoutSeconds"); ok && timeoutSeconds > 0 { + return normalizeBrowserTimeoutMs(timeoutSeconds*1000, fallback) + } + return resolveBrowserActionTimeoutMs(payload, fallback) +} + +func resolveBrowserSnapshotLimit(payload toolArgs) int { + if value, 
ok := getIntArg(payload, "limit"); ok && value > 0 { + return value + } + return defaultBrowserSnapshotLimit +} + +func resolveBrowserScrollDelta(payload toolArgs) (int, int) { + if x, ok := getIntArg(payload, "x"); ok { + if y, ok := getIntArg(payload, "y"); ok { + return x, y + } + return x, 0 + } + if y, ok := getIntArg(payload, "y"); ok { + return 0, y + } + amount, ok := getIntArg(payload, "amount") + if !ok || amount <= 0 { + amount = 700 + } + switch strings.ToLower(strings.TrimSpace(getStringArg(payload, "direction"))) { + case "up": + return 0, -amount + case "left": + return -amount, 0 + case "right": + return amount, 0 + default: + return 0, amount + } +} + +func normalizeBrowserTimeoutMs(value int, fallback int) int { + if value <= 0 { + value = fallback + } + if value < 500 { + return 500 + } + if value > 120000 { + return 120000 + } + return value +} + +func parseBrowserAriaSnapshot(snapshot string, interactive bool, limit int, useAriaRefs bool, maxDepth int) []browserSnapshotItem { + if limit <= 0 { + limit = defaultBrowserSnapshotLimit + } + lines := strings.Split(snapshot, "\n") + items := make([]browserSnapshotItem, 0, minInt(len(lines), limit)) + countByRoleName := map[string]int{} + nextRef := 0 + for _, rawLine := range lines { + if len(items) >= limit { + break + } + line := strings.TrimRight(rawLine, "\r") + if strings.TrimSpace(line) == "" { + continue + } + match := browserAriaSnapshotLinePattern.FindStringSubmatch(line) + if len(match) < 5 { + continue + } + role := strings.ToLower(strings.TrimSpace(match[2])) + if role == "" || strings.HasPrefix(role, "/") { + continue + } + if interactive && !isBrowserInteractiveRole(role) { + continue + } + name := strings.TrimSpace(match[3]) + suffix := strings.TrimSpace(match[4]) + text := browserAriaSnapshotTextFromSuffix(suffix) + depth := browserAriaSnapshotDepth(line) + if maxDepth > 0 && depth > maxDepth { + continue + } + ariaRef := "" + if refMatch := 
browserAriaSnapshotRefPattern.FindStringSubmatch(suffix); len(refMatch) >= 2 { + ariaRef = strings.TrimSpace(refMatch[1]) + } + key := role + "\n" + name + nth := countByRoleName[key] + countByRoleName[key] = nth + 1 + nextRef++ + ref := fmt.Sprintf("e%d", nextRef) + if useAriaRefs && ariaRef != "" { + ref = ariaRef + } + item := browserSnapshotItem{ + Ref: ref, + AriaRef: ariaRef, + Role: role, + Name: name, + Text: text, + Depth: depth, + Nth: nth, + } + if item.Name == "" { + item.Name = item.Text + } + if item.Name == "" { + item.Name = "(empty)" + } + items = append(items, item) + } + return items +} + +func browserAriaSnapshotDepth(line string) int { + spaces := 0 + for _, r := range line { + if r == ' ' { + spaces++ + continue + } + if r == '\t' { + spaces += 2 + continue + } + break + } + return spaces / 2 +} + +func browserAriaSnapshotTextFromSuffix(suffix string) string { + suffix = strings.TrimSpace(suffix) + if suffix == "" { + return "" + } + if idx := strings.LastIndex(suffix, ":"); idx >= 0 && idx+1 < len(suffix) { + return strings.TrimSpace(suffix[idx+1:]) + } + return "" +} + +func isBrowserInteractiveRole(role string) bool { + switch strings.ToLower(strings.TrimSpace(role)) { + case "button", "link", "textbox", "checkbox", "radio", "combobox", "listbox", "menuitem", "menuitemcheckbox", + "menuitemradio", "option", "searchbox", "slider", "spinbutton", "switch", "tab", "treeitem": + return true + default: + return false + } +} + +func resolveBrowserUploadRootDir() (string, error) { + dir := filepath.Join(os.TempDir(), "dreamcreator", "browser", "uploads") + if err := os.MkdirAll(dir, 0o755); err != nil { + return "", err + } + return filepath.Abs(dir) +} + +func resolvePathWithinRoot(rootDir string, requestedPath string) (string, error) { + root := strings.TrimSpace(rootDir) + if root == "" { + return "", errors.New("upload root is required") + } + rootAbs, err := filepath.Abs(root) + if err != nil { + return "", err + } + raw := 
strings.TrimSpace(requestedPath) + if raw == "" { + return "", errors.New("paths are required") + } + var candidate string + if filepath.IsAbs(raw) { + candidate = filepath.Clean(raw) + } else { + candidate = filepath.Join(rootAbs, raw) + } + candidateAbs, err := filepath.Abs(candidate) + if err != nil { + return "", err + } + rel, err := filepath.Rel(rootAbs, candidateAbs) + if err != nil { + return "", err + } + rel = filepath.ToSlash(strings.TrimSpace(rel)) + if rel == "." || strings.HasPrefix(rel, "../") || rel == ".." { + return "", fmt.Errorf("Invalid path: must stay within uploads directory (%s)", rootAbs) + } + return candidateAbs, nil +} + +func isBrowserSnapshotForAIUnavailable(err error) bool { + if err == nil { + return false + } + message := strings.ToLower(strings.TrimSpace(err.Error())) + return strings.Contains(message, "_snapshotforai is not available") || + strings.Contains(message, "snapshotforai is unavailable") +} + +func toBrowserFriendlyInteractionError(err error, selector string) error { + if err == nil { + return nil + } + message := strings.TrimSpace(err.Error()) + if strings.Contains(message, "strict mode violation") { + count := "multiple" + if match := browserStrictModeCountPattern.FindStringSubmatch(message); len(match) >= 2 { + count = strings.TrimSpace(match[1]) + } + return fmt.Errorf(`Selector "%s" matched %s elements. 
Run snapshot again to get updated refs, or use a different ref.`, selector, count) + } + return err +} + +func assertBrowserURLAllowed(rawURL string, policy browserSSRFPolicy) error { + return browsercdp.AssertURLAllowed(rawURL, browsercdp.SSRFPolicy{ + DangerouslyAllowPrivateNetwork: policy.DangerouslyAllowPrivateNetwork, + AllowedHostnames: cloneBrowserAllowedHostnames(policy.AllowedHostnames), + HostnameAllowlist: append([]string(nil), policy.HostnameAllowlist...), + }) +} + +func isHostnameExplicitlyAllowed(hostname string, policy browserSSRFPolicy) bool { + hostname = strings.ToLower(strings.TrimSpace(hostname)) + if hostname == "" { + return false + } + if _, ok := policy.AllowedHostnames[hostname]; ok { + return true + } + for _, pattern := range policy.HostnameAllowlist { + pattern = strings.ToLower(strings.TrimSpace(pattern)) + if pattern == "" { + continue + } + if pattern == hostname { + return true + } + if strings.HasPrefix(pattern, "*.") { + suffix := strings.TrimPrefix(pattern, "*.") + if strings.HasSuffix(hostname, "."+suffix) || hostname == suffix { + return true + } + continue + } + if matched, _ := filepath.Match(pattern, hostname); matched { + return true + } + } + return false +} + +func isPrivateOrLocalIP(ip net.IP) bool { + if ip == nil { + return false + } + if ip.IsLoopback() || ip.IsPrivate() || ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() || ip.IsUnspecified() || ip.IsMulticast() { + return true + } + if ip4 := ip.To4(); ip4 != nil { + if ip4[0] == 100 && (ip4[1]&0xC0) == 64 { + return true + } + } + return false +} + +func saveBrowserArtifact(ext string, content []byte) (string, error) { + ext = strings.TrimSpace(strings.TrimPrefix(ext, ".")) + if ext == "" { + ext = "bin" + } + dir := filepath.Join(os.TempDir(), "dreamcreator", "browser") + if err := os.MkdirAll(dir, 0o755); err != nil { + return "", err + } + filename := fmt.Sprintf("%d-%d.%s", time.Now().UnixNano(), atomic.AddUint64(&browserGlobalTabCounter, 1), ext) + path := 
filepath.Join(dir, filename) + if err := os.WriteFile(path, content, 0o644); err != nil { + return "", err + } + return filepath.Abs(path) +} + +func resolveBrowserRuntimeAvailability(preferred string, headless bool) browsercdp.Status { + return browsercdp.ResolveStatus(preferred, headless) +} + +func minInt(left int, right int) int { + if left < right { + return left + } + return right +} diff --git a/internal/application/gateway/tools/browser_tools_test.go b/internal/application/gateway/tools/browser_tools_test.go index f4c13a4..be8bd52 100644 --- a/internal/application/gateway/tools/browser_tools_test.go +++ b/internal/application/gateway/tools/browser_tools_test.go @@ -9,29 +9,27 @@ import ( "path/filepath" "strings" "testing" + + "dreamcreator/internal/application/agentruntime" + "dreamcreator/internal/application/browsercdp" + "github.com/cloudwego/eino/schema" ) -func TestResolveBrowserActionDefaultsStatus(t *testing.T) { +func TestResolveBrowserActionRequiresExplicitAction(t *testing.T) { t.Parallel() - action, err := resolveBrowserAction(toolArgs{}) - if err != nil { - t.Fatalf("resolve action: %v", err) - } - if action != "status" { - t.Fatalf("expected default action status, got %q", action) + _, err := resolveBrowserAction(toolArgs{}) + if err == nil { + t.Fatalf("expected missing action error") } } -func TestResolveBrowserActionDefaultsOpenWhenURLPresent(t *testing.T) { +func TestResolveBrowserActionRequiresExplicitActionEvenWhenURLPresent(t *testing.T) { t.Parallel() - action, err := resolveBrowserAction(toolArgs{"url": "https://example.com"}) - if err != nil { - t.Fatalf("resolve action: %v", err) - } - if action != "open" { - t.Fatalf("expected default action open with url, got %q", action) + _, err := resolveBrowserAction(toolArgs{"targetUrl": "https://example.com"}) + if err == nil { + t.Fatalf("expected missing action error") } } @@ -44,6 +42,98 @@ func TestResolveBrowserActionRejectsUnsupportedFetchAlias(t *testing.T) { } } +func 
TestSpecBrowserExposesWorkflowActionsOnly(t *testing.T) { + t.Parallel() + + var schema map[string]any + if err := json.Unmarshal([]byte(specBrowser().SchemaJSON), &schema); err != nil { + t.Fatalf("decode schema: %v", err) + } + properties, _ := schema["properties"].(map[string]any) + actionDef, _ := properties["action"].(map[string]any) + rawEnum, _ := actionDef["enum"].([]any) + got := make([]string, 0, len(rawEnum)) + for _, item := range rawEnum { + if value, ok := item.(string); ok && value != "" { + got = append(got, value) + } + } + want := []string{"open", "navigate", "snapshot", "act", "wait", "scroll", "upload", "dialog", "reset"} + if strings.Join(got, ",") != strings.Join(want, ",") { + t.Fatalf("unexpected browser actions: got %v want %v", got, want) + } +} + +func TestSpecBrowserDescriptionExplainsSnapshotWorkflow(t *testing.T) { + t.Parallel() + + description := specBrowser().Description + if !strings.Contains(description, "browser-use style loop") { + t.Fatalf("expected browser description to mention browser-use loop, got %q", description) + } + if !strings.Contains(description, "pass `url` or `targetUrl`") { + t.Fatalf("expected browser description to document url alias, got %q", description) + } + if !strings.Contains(description, "return `stateAvailable`, `itemCount`, and the current page `state`/`items`") { + t.Fatalf("expected browser description to explain open/navigate result state, got %q", description) + } + if !strings.Contains(description, "`stateAvailable=false`") { + t.Fatalf("expected browser description to explain stateAvailable fallback, got %q", description) + } + if !strings.Contains(description, "call `snapshot` to refresh them") { + t.Fatalf("expected browser description to explain open/navigate flow, got %q", description) + } + if !strings.Contains(description, "use `ref` from the latest state") { + t.Fatalf("expected browser description to prefer refs, got %q", description) + } +} + +func 
TestSpecBrowserSchemaAcceptsOpenURLAlias(t *testing.T) { + t.Parallel() + + validator := agentruntime.JSONToolValidator{ + Tools: map[string]agentruntime.ToolDefinition{ + "browser": { + Name: "browser", + SchemaJSON: specBrowser().SchemaJSON, + }, + }, + } + + err := validator.Validate(schema.ToolCall{ + Function: schema.FunctionCall{ + Name: "browser", + Arguments: `{"action":"open","url":"https://example.com"}`, + }, + }) + if err != nil { + t.Fatalf("expected open url alias to validate, got %v", err) + } +} + +func TestSpecBrowserSchemaAcceptsNavigateURLAlias(t *testing.T) { + t.Parallel() + + validator := agentruntime.JSONToolValidator{ + Tools: map[string]agentruntime.ToolDefinition{ + "browser": { + Name: "browser", + SchemaJSON: specBrowser().SchemaJSON, + }, + }, + } + + err := validator.Validate(schema.ToolCall{ + Function: schema.FunctionCall{ + Name: "browser", + Arguments: `{"action":"navigate","url":"https://example.com"}`, + }, + }) + if err != nil { + t.Fatalf("expected navigate url alias to validate, got %v", err) + } +} + func TestResolveBrowserRuntimeConfigDefaults(t *testing.T) { t.Parallel() @@ -51,17 +141,14 @@ func TestResolveBrowserRuntimeConfigDefaults(t *testing.T) { if !resolved.Enabled { t.Fatalf("expected browser enabled by default") } - if !resolved.EvaluateEnabled { - t.Fatalf("expected evaluateEnabled default true") - } if resolved.DefaultProfile != defaultBrowserProfileDreamCreator { t.Fatalf("expected default profile %q, got %q", defaultBrowserProfileDreamCreator, resolved.DefaultProfile) } if _, ok := resolved.Profiles[defaultBrowserProfileDreamCreator]; !ok { t.Fatalf("expected dreamcreator profile default") } - if !resolved.SSRFRules.DangerouslyAllowPrivateNetwork { - t.Fatalf("expected ssrf dangerous allow default true") + if resolved.SSRFRules.DangerouslyAllowPrivateNetwork { + t.Fatalf("expected ssrf dangerous allow default false") } } @@ -70,11 +157,8 @@ func TestResolveBrowserRuntimeConfigReadsBrowserSettings(t *testing.T) 
{ resolved := resolveBrowserRuntimeConfig(map[string]any{ "browser": map[string]any{ - "enabled": false, - "evaluateEnabled": false, - "headless": true, - "noSandbox": true, - "extraArgs": []any{"--window-size=1280,900"}, + "enabled": false, + "headless": true, "ssrfPolicy": map[string]any{ "dangerouslyAllowPrivateNetwork": false, "allowedHostnames": []any{"localhost"}, @@ -85,18 +169,9 @@ func TestResolveBrowserRuntimeConfigReadsBrowserSettings(t *testing.T) { if resolved.Enabled { t.Fatalf("expected enabled false") } - if resolved.EvaluateEnabled { - t.Fatalf("expected evaluateEnabled false") - } if !resolved.Headless { t.Fatalf("expected headless true") } - if !resolved.NoSandbox { - t.Fatalf("expected noSandbox true") - } - if len(resolved.ExtraArgs) != 1 || resolved.ExtraArgs[0] != "--window-size=1280,900" { - t.Fatalf("expected extraArgs from config") - } if resolved.SSRFRules.DangerouslyAllowPrivateNetwork { t.Fatalf("expected ssrf dangerous allow false") } @@ -108,6 +183,108 @@ func TestResolveBrowserRuntimeConfigReadsBrowserSettings(t *testing.T) { } } +func TestBrowserToolRuntimeConfigChangedDetectsRuntimeChanges(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + previous map[string]any + current map[string]any + want bool + }{ + { + name: "same config", + previous: map[string]any{ + "browser": map[string]any{ + "enabled": true, + "headless": true, + "preferredBrowser": "chrome", + }, + }, + current: map[string]any{ + "browser": map[string]any{ + "enabled": true, + "headless": true, + "preferredBrowser": "chrome", + }, + }, + want: false, + }, + { + name: "headless changed", + previous: map[string]any{ + "browser": map[string]any{ + "headless": false, + }, + }, + current: map[string]any{ + "browser": map[string]any{ + "headless": true, + }, + }, + want: true, + }, + { + name: "preferred browser changed", + previous: map[string]any{ + "browser": map[string]any{ + "preferredBrowser": "chrome", + }, + }, + current: map[string]any{ + 
"browser": map[string]any{ + "preferredBrowser": "brave", + }, + }, + want: true, + }, + { + name: "enabled changed", + previous: map[string]any{ + "browser": map[string]any{ + "enabled": true, + }, + }, + current: map[string]any{ + "browser": map[string]any{ + "enabled": false, + }, + }, + want: true, + }, + { + name: "ssrf changed only", + previous: map[string]any{ + "browser": map[string]any{ + "headless": true, + "ssrfPolicy": map[string]any{ + "dangerouslyAllowPrivateNetwork": false, + }, + }, + }, + current: map[string]any{ + "browser": map[string]any{ + "headless": true, + "ssrfPolicy": map[string]any{ + "dangerouslyAllowPrivateNetwork": true, + }, + }, + }, + want: false, + }, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + if got := BrowserToolRuntimeConfigChanged(tt.previous, tt.current); got != tt.want { + t.Fatalf("BrowserToolRuntimeConfigChanged() = %v, want %v", got, tt.want) + } + }) + } +} + func TestAssertBrowserURLAllowedStrictPolicyBlocksPrivate(t *testing.T) { t.Parallel() @@ -227,7 +404,7 @@ func TestParseBrowserAriaSnapshotRespectsMaxDepth(t *testing.T) { func TestBrowserActionActRejectsSelectorForNonWaitBeforeRuntime(t *testing.T) { t.Parallel() - _, err := browserActionAct(toolArgs{ + _, err := browserActionAct(context.Background(), toolArgs{ "request": map[string]any{ "kind": "click", "selector": "button.save", @@ -241,6 +418,50 @@ func TestBrowserActionActRejectsSelectorForNonWaitBeforeRuntime(t *testing.T) { } } +func TestResolveBrowserActionAcceptsSnapshotAction(t *testing.T) { + t.Parallel() + + action, err := resolveBrowserAction(toolArgs{ + "action": "snapshot", + }) + if err != nil { + t.Fatalf("expected snapshot action to be accepted, got %v", err) + } + if action != "snapshot" { + t.Fatalf("unexpected action: got %q want snapshot", action) + } +} + +func TestResolveBrowserActionRejectsStateAction(t *testing.T) { + t.Parallel() + + _, err := resolveBrowserAction(toolArgs{ + 
"action": "state", + }) + if err == nil { + t.Fatalf("expected state action to be rejected") + } + if !strings.Contains(err.Error(), "not supported") { + t.Fatalf("unexpected error: %v", err) + } +} + +func TestBrowserActionActRejectsUnsupportedNavigateKind(t *testing.T) { + t.Parallel() + + _, err := browserActionAct(context.Background(), toolArgs{ + "request": map[string]any{ + "kind": "navigate", + }, + }, &browserProfileState{}) + if err == nil { + t.Fatalf("expected unsupported kind error") + } + if !strings.Contains(err.Error(), "act kind not supported") { + t.Fatalf("unexpected error: %v", err) + } +} + func TestNormalizeBrowserTimeoutMsClampRange(t *testing.T) { t.Parallel() @@ -255,28 +476,21 @@ func TestNormalizeBrowserTimeoutMsClampRange(t *testing.T) { } } -func TestTabResultFromEvaluateReturnsStoredValue(t *testing.T) { +func TestResolveBrowserScrollDeltaDefaultsDown(t *testing.T) { t.Parallel() - tab := &browserTabState{TargetID: "tab-1"} - tab.mu.Lock() - tab.evaluateResult = map[string]any{"ok": true, "value": "done"} - tab.mu.Unlock() - - result, ok := tabResultFromEvaluate(tab).(map[string]any) - if !ok { - t.Fatalf("expected evaluate result map") - } - if value, _ := result["value"].(string); value != "done" { - t.Fatalf("expected stored evaluate result, got %#v", result) + x, y := resolveBrowserScrollDelta(toolArgs{}) + if x != 0 || y != 700 { + t.Fatalf("expected default down scroll, got %d/%d", x, y) } } -func TestTabResultFromEvaluateNilTab(t *testing.T) { +func TestResolveBrowserScrollDeltaHonorsDirectionAndAmount(t *testing.T) { t.Parallel() - if result := tabResultFromEvaluate(nil); result != nil { - t.Fatalf("expected nil result for nil tab, got %#v", result) + x, y := resolveBrowserScrollDelta(toolArgs{"direction": "left", "amount": 320}) + if x != -320 || y != 0 { + t.Fatalf("expected left scroll -320/0, got %d/%d", x, y) } } @@ -373,14 +587,58 @@ func TestResolveBrowserNodeOutputReturnsNilOnInvalidJSON(t *testing.T) { func 
TestIsBrowserSnapshotForAIUnavailable(t *testing.T) { t.Parallel() - if !isBrowserSnapshotForAIUnavailable(errors.New("Playwright _snapshotForAI is not available")) { - t.Fatalf("expected _snapshotForAI error to be treated as unavailable") + if !isBrowserSnapshotForAIUnavailable(errors.New("browser snapshotForAI is unavailable")) { + t.Fatalf("expected snapshotForAI unavailable error to be treated as unavailable") } if isBrowserSnapshotForAIUnavailable(errors.New("network timeout")) { t.Fatalf("expected unrelated error to remain actionable") } } +func TestBrowserResultMapPreservesSnapshotPayload(t *testing.T) { + t.Parallel() + + result := browserResultMap(browsercdp.ActionResult{ + OK: true, + TargetID: "tab-1", + URL: "https://example.com", + StateVersion: 7, + Items: []browsercdp.SnapshotItem{ + {Ref: "e1", Role: "button", Name: "Save"}, + }, + State: &browsercdp.PageState{ + Version: 7, + URL: "https://example.com", + ItemCount: 1, + CapturedAt: "2026-04-18T00:00:00Z", + }, + StateAvailable: true, + StateError: "warning", + }) + + if _, ok := result["state"]; !ok { + t.Fatalf("expected state to be preserved") + } + if got := result["stateVersion"]; got != float64(7) { + t.Fatalf("expected stateVersion to remain, got %#v", got) + } + if _, ok := result["items"]; !ok { + t.Fatalf("expected items to be preserved") + } + if got := result["stateAvailable"]; got != true { + t.Fatalf("expected stateAvailable to remain, got %#v", got) + } + if got := result["itemCount"]; got != 1 { + t.Fatalf("expected itemCount to remain, got %#v", got) + } + if got := result["stateError"]; got != "warning" { + t.Fatalf("expected stateError to remain, got %#v", got) + } + if got := result["targetId"]; got != "tab-1" { + t.Fatalf("expected targetId to remain, got %#v", got) + } +} + func TestToBrowserFriendlyInteractionErrorStrictMode(t *testing.T) { t.Parallel() diff --git a/internal/application/gateway/tools/builtin_requirement_resolver.go 
b/internal/application/gateway/tools/builtin_requirement_resolver.go new file mode 100644 index 0000000..f6f5a3e --- /dev/null +++ b/internal/application/gateway/tools/builtin_requirement_resolver.go @@ -0,0 +1,203 @@ +package tools + +import ( + "context" + "strings" + + assistantservice "dreamcreator/internal/application/assistant/service" + gatewayvoice "dreamcreator/internal/application/gateway/voice" + tooldto "dreamcreator/internal/application/tools/dto" + "dreamcreator/internal/domain/providers" +) + +const ( + imageRequirementID = "image.model_runtime" + ttsServiceRequirementID = "tts.voice_service" + ttsVoiceEnabledRequirementID = "tts.voice_enabled" + ttsProviderRequirementID = "tts.provider_supported" + ttsProviderAPIKeyRequirementID = "tts.provider_api_key" + ttsVoiceIDRequirementID = "tts.voice_id" + imageModelNotConfiguredReason = "Image model is not configured" + providerRepositoriesReason = "Provider repositories are unavailable" + voiceDisabledReason = "Voice is disabled" + voiceServiceUnavailableReason = "Voice service unavailable" + ttsProviderAPIKeyMissingReason = "TTS provider API key is missing" + ttsVoiceIDMissingReason = "TTS voice ID is not configured" + ttsEdgeProviderUnavailableReason = "Edge-TTS provider is not implemented yet" + ttsUnsupportedProviderReason = "TTS provider is not supported" +) + +type voiceToolStatusProvider interface { + Status(ctx context.Context) (gatewayvoice.TTSStatusResponse, error) +} + +type BuiltinRequirementDeps struct { + Settings SettingsReader + Assistants *assistantservice.AssistantService + Providers providers.ProviderRepository + Models providers.ModelRepository + Secrets providers.SecretRepository + Voice voiceToolStatusProvider +} + +type builtinRequirementResolver struct { + deps BuiltinRequirementDeps +} + +func NewBuiltinRequirementResolver(deps BuiltinRequirementDeps) ToolRequirementResolver { + return builtinRequirementResolver{deps: deps} +} + +func (resolver builtinRequirementResolver) 
ResolveToolRequirements(ctx context.Context, spec tooldto.ToolSpec) []tooldto.ToolRequirement { + switch resolveToolRequirementKey(spec) { + case "image": + return resolver.resolveImageRequirements(ctx) + case "tts": + return resolver.resolveTTSRequirements(ctx) + default: + return nil + } +} + +func (resolver builtinRequirementResolver) resolveImageRequirements(ctx context.Context) []tooldto.ToolRequirement { + requirement := tooldto.ToolRequirement{ + ID: imageRequirementID, + Name: "Image model", + Available: true, + } + if resolver.deps.Providers == nil || resolver.deps.Secrets == nil || resolver.deps.Models == nil { + requirement.Available = false + requirement.Reason = providerRepositoriesReason + return []tooldto.ToolRequirement{requirement} + } + configuredPrimaryRef := resolveImageToolConfiguredPrimaryRef(ctx, resolver.deps.Settings) + _, _, err := resolveImageToolCandidates( + ctx, + resolver.deps.Assistants, + "", + configuredPrimaryRef, + resolver.deps.Providers, + resolver.deps.Models, + resolver.deps.Secrets, + ) + if err == nil { + return []tooldto.ToolRequirement{requirement} + } + requirement.Available = false + requirement.Reason = normalizeBuiltinRequirementReason(err, imageModelNotConfiguredReason) + return []tooldto.ToolRequirement{requirement} +} + +func (resolver builtinRequirementResolver) resolveTTSRequirements(ctx context.Context) []tooldto.ToolRequirement { + if resolver.deps.Voice == nil { + return []tooldto.ToolRequirement{ + { + ID: ttsServiceRequirementID, + Name: "Voice service", + Available: false, + Reason: voiceServiceUnavailableReason, + }, + } + } + status, err := resolver.deps.Voice.Status(ctx) + if err != nil { + return []tooldto.ToolRequirement{ + { + ID: ttsServiceRequirementID, + Name: "Voice service", + Available: false, + Reason: normalizeBuiltinRequirementReason(err, voiceServiceUnavailableReason), + }, + } + } + providerID := strings.ToLower(strings.TrimSpace(status.Config.ProviderID)) + if providerID == "" { + 
providerID = "edge" + } + requirements := []tooldto.ToolRequirement{ + { + ID: ttsVoiceEnabledRequirementID, + Name: "Voice feature", + Available: status.Enabled, + }, + } + if !status.Enabled { + requirements[0].Reason = voiceDisabledReason + } + switch providerID { + case "edge": + requirements = append(requirements, tooldto.ToolRequirement{ + ID: ttsProviderRequirementID, + Name: "Provider", + Available: false, + Reason: ttsEdgeProviderUnavailableReason, + Data: map[string]any{ + "providerId": providerID, + }, + }) + return requirements + case "openai", "elevenlabs": + providerReady := false + for _, provider := range status.Providers { + if strings.EqualFold(strings.TrimSpace(provider.ProviderID), providerID) { + providerReady = provider.Available + break + } + } + requirements = append(requirements, + tooldto.ToolRequirement{ + ID: ttsProviderRequirementID, + Name: "Provider", + Available: true, + Data: map[string]any{ + "providerId": providerID, + }, + }, + tooldto.ToolRequirement{ + ID: ttsProviderAPIKeyRequirementID, + Name: "Provider API key", + Available: providerReady, + }, + ) + if !providerReady { + requirements[len(requirements)-1].Reason = ttsProviderAPIKeyMissingReason + } + if providerID == "elevenlabs" { + voiceIDConfigured := strings.TrimSpace(status.Config.VoiceID) != "" + requirements = append(requirements, tooldto.ToolRequirement{ + ID: ttsVoiceIDRequirementID, + Name: "Voice ID", + Available: voiceIDConfigured, + Data: map[string]any{ + "value": strings.TrimSpace(status.Config.VoiceID), + }, + }) + if !voiceIDConfigured { + requirements[len(requirements)-1].Reason = ttsVoiceIDMissingReason + } + } + return requirements + default: + requirements = append(requirements, tooldto.ToolRequirement{ + ID: ttsProviderRequirementID, + Name: "Provider", + Available: false, + Reason: ttsUnsupportedProviderReason, + Data: map[string]any{ + "providerId": providerID, + }, + }) + return requirements + } +} + +func normalizeBuiltinRequirementReason(err error, 
fallback string) string { + if err == nil { + return fallback + } + message := strings.TrimSpace(err.Error()) + if message == "" { + return fallback + } + return message +} diff --git a/internal/application/gateway/tools/builtin_specs.go b/internal/application/gateway/tools/builtin_specs.go index d0d5945..2c5b190 100644 --- a/internal/application/gateway/tools/builtin_specs.go +++ b/internal/application/gateway/tools/builtin_specs.go @@ -170,26 +170,16 @@ func specWebFetch() toolSpec { return toolSpec{ ID: "web_fetch", Name: "web_fetch", - Description: "Fetch a web page and return structured status fields (status, retryable, next_action, quality). Do not blind-retry the same call when status is not ok.", + Description: "Fetch a web page through a local CDP browser, extract token-efficient main content, and return structured status fields (status, retryable, next_action, quality). Do not blind-retry the same call when status is not ok.", Category: "web", SchemaJSON: schemaJSON(map[string]any{ "type": "object", "properties": map[string]any{ - "url": map[string]any{"type": "string"}, - "type": map[string]any{"type": "string"}, - "method": map[string]any{"type": "string"}, - "headers": map[string]any{"type": "object"}, - "maxChars": map[string]any{"type": "integer"}, - "maxBodyBytes": map[string]any{"type": "integer"}, - "maxRedirects": map[string]any{"type": "integer"}, - "retryMax": map[string]any{"type": "integer"}, - "timeoutSeconds": map[string]any{"type": "integer"}, - "acceptMarkdown": map[string]any{"type": "boolean"}, - "markdown": map[string]any{"type": "boolean"}, - "toMarkdown": map[string]any{"type": "boolean"}, - "enableUserAgent": map[string]any{"type": "boolean"}, - "userAgent": map[string]any{"type": "string"}, - "acceptLanguage": map[string]any{"type": "string"}, + "url": map[string]any{"type": "string"}, + "method": map[string]any{"type": "string"}, + "maxChars": map[string]any{"type": "integer"}, + "maxBodyBytes": map[string]any{"type": "integer"}, 
+ "timeoutSeconds": map[string]any{"type": "integer"}, }, "required": []string{"url"}, }), @@ -227,207 +217,6 @@ func specWebSearch() toolSpec { } } -func specBrowser() toolSpec { - return toolSpec{ - ID: "browser", - Name: "browser", - Description: "Control the browser via Playwright runtime (status/start/stop/profiles/tabs/open/focus/close/snapshot/screenshot/navigate/console/pdf/upload/dialog/act).", - Category: "ui", - RiskLevel: "high", - SchemaJSON: schemaJSON(map[string]any{ - "type": "object", - "properties": map[string]any{ - "action": map[string]any{ - "type": "string", - "enum": []string{ - "status", - "start", - "stop", - "profiles", - "tabs", - "open", - "focus", - "close", - "snapshot", - "screenshot", - "navigate", - "console", - "pdf", - "upload", - "dialog", - "act", - }, - }, - "target": map[string]any{ - "type": "string", - "enum": []string{"sandbox", "host", "node"}, - }, - "node": map[string]any{"type": "string"}, - "profile": map[string]any{"type": "string"}, - "targetUrl": map[string]any{"type": "string"}, - "targetId": map[string]any{"type": "string"}, - "limit": map[string]any{"type": "integer"}, - "maxChars": map[string]any{"type": "integer"}, - "mode": map[string]any{"type": "string", "enum": []string{"efficient"}}, - "timeoutMs": map[string]any{"type": "integer"}, - "snapshotFormat": map[string]any{ - "type": "string", - "enum": []string{"aria", "ai"}, - }, - "refs": map[string]any{ - "type": "string", - "enum": []string{"role", "aria"}, - }, - "interactive": map[string]any{"type": "boolean"}, - "compact": map[string]any{"type": "boolean"}, - "depth": map[string]any{"type": "integer"}, - "selector": map[string]any{"type": "string"}, - "frame": map[string]any{"type": "string"}, - "labels": map[string]any{"type": "boolean"}, - "fullPage": map[string]any{"type": "boolean"}, - "ref": map[string]any{"type": "string"}, - "element": map[string]any{"type": "string"}, - "type": map[string]any{ - "type": "string", - "enum": []string{"png", 
"jpeg"}, - }, - "level": map[string]any{"type": "string"}, - "paths": map[string]any{ - "type": "array", - "items": map[string]any{"type": "string"}, - }, - "inputRef": map[string]any{"type": "string"}, - "accept": map[string]any{"type": "boolean"}, - "promptText": map[string]any{"type": "string"}, - "request": map[string]any{ - "type": "object", - "properties": map[string]any{ - "kind": map[string]any{ - "type": "string", - "enum": []string{ - "click", - "type", - "press", - "hover", - "drag", - "select", - "fill", - "resize", - "wait", - "evaluate", - "close", - }, - }, - "targetId": map[string]any{"type": "string"}, - "ref": map[string]any{"type": "string"}, - "doubleClick": map[string]any{"type": "boolean"}, - "button": map[string]any{"type": "string"}, - "modifiers": map[string]any{ - "type": "array", - "items": map[string]any{"type": "string"}, - }, - "text": map[string]any{"type": "string"}, - "submit": map[string]any{"type": "boolean"}, - "slowly": map[string]any{"type": "boolean"}, - "key": map[string]any{"type": "string"}, - "startRef": map[string]any{"type": "string"}, - "endRef": map[string]any{"type": "string"}, - "values": map[string]any{"type": "array", "items": map[string]any{"type": "string"}}, - "fields": map[string]any{"type": "array", "items": map[string]any{"type": "object"}}, - "width": map[string]any{"type": "number"}, - "height": map[string]any{"type": "number"}, - "timeMs": map[string]any{"type": "number"}, - "textGone": map[string]any{"type": "string"}, - "selector": map[string]any{"type": "string"}, - "url": map[string]any{"type": "string"}, - "loadState": map[string]any{"type": "string"}, - "fn": map[string]any{"type": "string"}, - "timeoutMs": map[string]any{"type": "number"}, - }, - }, - }, - "required": []string{"action"}, - }), - RequiresSandbox: true, - RequiresApproval: true, - Enabled: true, - } -} - -func specCanvas() toolSpec { - return toolSpec{ - ID: "canvas", - Name: "canvas", - Description: "Control node canvases 
(present/hide/navigate/eval/snapshot/a2ui).", - Category: "ui", - SchemaJSON: schemaJSON(map[string]any{ - "type": "object", - "properties": map[string]any{ - "action": map[string]any{ - "type": "string", - "enum": []string{ - "present", - "hide", - "navigate", - "eval", - "snapshot", - "a2ui_push", - "a2ui_reset", - }, - }, - "gatewayUrl": map[string]any{"type": "string"}, - "gatewayToken": map[string]any{"type": "string"}, - "timeoutMs": map[string]any{"type": "number"}, - "node": map[string]any{"type": "string"}, - "target": map[string]any{"type": "string"}, - "x": map[string]any{"type": "number"}, - "y": map[string]any{"type": "number"}, - "width": map[string]any{"type": "number"}, - "height": map[string]any{"type": "number"}, - "url": map[string]any{"type": "string"}, - "javaScript": map[string]any{"type": "string"}, - "outputFormat": map[string]any{ - "type": "string", - "enum": []string{"png", "jpg", "jpeg"}, - }, - "maxWidth": map[string]any{"type": "number"}, - "quality": map[string]any{"type": "number"}, - "delayMs": map[string]any{"type": "number"}, - "jsonl": map[string]any{"type": "string"}, - "jsonlPath": map[string]any{ - "type": "string", - }, - }, - "required": []string{"action"}, - }), - Enabled: true, - } -} - -func specImage() toolSpec { - return toolSpec{ - ID: "image", - Name: "image", - Description: "Analyze one or more images with the configured image model.", - Category: "media", - SchemaJSON: schemaJSON(map[string]any{ - "type": "object", - "properties": map[string]any{ - "prompt": map[string]any{"type": "string"}, - "image": map[string]any{"type": "string"}, - "images": map[string]any{ - "type": "array", - "items": map[string]any{"type": "string"}, - }, - "model": map[string]any{"type": "string"}, - "maxBytesMb": map[string]any{"type": "number"}, - "maxImages": map[string]any{"type": "number"}, - }, - }), - Enabled: true, - } -} - func specMessage(ctx context.Context, settings SettingsReader) toolSpec { spec := specMessageBase() profile 
:= resolveMessageToolSchemaProfile(ctx, settings) @@ -1274,7 +1063,7 @@ func specAgentsList() toolSpec { return toolSpec{ ID: "agents_list", Name: "agents_list", - Description: "List agents.", + Description: "List available agent profiles for subagent spawning.", Category: "sessions", Enabled: true, } @@ -1284,7 +1073,7 @@ func specSessionsList() toolSpec { return toolSpec{ ID: "sessions_list", Name: "sessions_list", - Description: "List sessions.", + Description: "List current sessions.", Category: "sessions", Enabled: true, } @@ -1294,7 +1083,7 @@ func specSessionsHistory() toolSpec { return toolSpec{ ID: "sessions_history", Name: "sessions_history", - Description: "Fetch session history.", + Description: "Read message history for a session.", Category: "sessions", SchemaJSON: schemaJSON(map[string]any{ "type": "object", @@ -1312,7 +1101,7 @@ func specSessionsSend() toolSpec { return toolSpec{ ID: "sessions_send", Name: "sessions_send", - Description: "Send a message to a session.", + Description: "Append a message to an existing session.", Category: "sessions", RiskLevel: "medium", SchemaJSON: schemaJSON(map[string]any{ @@ -1357,7 +1146,7 @@ func specSessionStatus() toolSpec { return toolSpec{ ID: "session_status", Name: "session_status", - Description: "Get session status.", + Description: "Get session metadata and status.", Category: "sessions", SchemaJSON: schemaJSON(map[string]any{ "type": "object", @@ -1422,7 +1211,7 @@ func specSkills() toolSpec { return toolSpec{ ID: "skills", Name: "skills", - Description: "Skills runtime tool: status/bins + dependency install + per-skill runtime config update. 
Not for package search/install.", + Description: "Inspect skills runtime status and update per-skill runtime dependencies or configuration.", Category: "skills", RiskLevel: "medium", SchemaJSON: schemaJSON(map[string]any{ @@ -1506,12 +1295,12 @@ func specSkills() toolSpec { } } -func specSkillManage() toolSpec { +func specSkillsManage() toolSpec { actions := []string{"list", "search", "install", "update", "remove", "sync"} return toolSpec{ - ID: "skill_manage", - Name: "skill_manage", - Description: "Skills package management tool via ClawHub/SkillHub (list/search/install/update/remove/sync). Use for discovery and package lifecycle.", + ID: "skills_manage", + Name: "skills_manage", + Description: "Search, install, update, remove, and sync skill packages via ClawHub.", Category: "skills", RiskLevel: "medium", SchemaJSON: schemaJSON(map[string]any{ @@ -1574,7 +1363,7 @@ func specSubagents() toolSpec { return toolSpec{ ID: "subagents", Name: "subagents", - Description: "Manage subagent runs (list/info/log/kill/steer/send).", + Description: "Manage existing subagent runs (list/info/log/kill/steer/send).", Category: "sessions", SchemaJSON: schemaJSON(map[string]any{ "type": "object", @@ -1593,7 +1382,7 @@ func specNodes() toolSpec { return toolSpec{ ID: "nodes", Name: "nodes", - Description: "Invoke node capability.", + Description: "Experimental low-level RPC to a registered node. 
Temporarily unavailable until remote node runtime support is implemented; prefer specialized tools like canvas or browser when available.", Category: "nodes", RiskLevel: "medium", SchemaJSON: schemaJSON(map[string]any{ @@ -1601,10 +1390,14 @@ func specNodes() toolSpec { "properties": map[string]any{ "nodeId": map[string]any{"type": "string"}, "capability": map[string]any{"type": "string"}, + "action": map[string]any{"type": "string"}, + "args": map[string]any{"type": "string"}, "payload": map[string]any{"type": "object"}, + "timeoutMs": map[string]any{"type": "integer"}, }, + "required": []string{"nodeId", "capability"}, }), - Enabled: true, + Enabled: false, } } @@ -1612,7 +1405,7 @@ func specTTS() toolSpec { return toolSpec{ ID: "tts", Name: "tts", - Description: "Text-to-speech conversion.", + Description: "Synthesize speech audio from text with the configured voice provider.", Category: "voice", SchemaJSON: schemaJSON(map[string]any{ "type": "object", @@ -1628,15 +1421,20 @@ func specTTS() toolSpec { } } -func specMemoryRecall() toolSpec { +func specMemoryQuery() toolSpec { return toolSpec{ - ID: "memory_recall", - Name: "memory_recall", - Description: "Search long-term memory with hybrid retrieval.", + ID: "memory_query", + Name: "memory_query", + Description: "Query long-term memory with recall, list, and stats actions.", Category: "memory", SchemaJSON: schemaJSON(map[string]any{ - "type": "object", + "type": "object", + "additionalProperties": false, "properties": map[string]any{ + "action": map[string]any{ + "type": "string", + "enum": []string{"recall", "list", "stats"}, + }, "query": map[string]any{"type": "string"}, "limit": map[string]any{"type": "integer"}, "topK": map[string]any{"type": "integer"}, @@ -1651,80 +1449,38 @@ func specMemoryRecall() toolSpec { "peerKind": map[string]any{"type": "string"}, "peerId": map[string]any{"type": "string"}, }, - "required": []string{"query"}, - }), - Enabled: true, - } -} - -func specMemoryStore() toolSpec { - 
return toolSpec{ - ID: "memory_store", - Name: "memory_store", - Description: "Store a long-term memory entry.", - Category: "memory", - SchemaJSON: schemaJSON(map[string]any{ - "type": "object", - "properties": map[string]any{ - "text": map[string]any{"type": "string"}, - "content": map[string]any{"type": "string"}, - "category": map[string]any{"type": "string"}, - "confidence": map[string]any{"type": "number"}, - "assistantId": map[string]any{"type": "string"}, - "threadId": map[string]any{"type": "string"}, - "scope": map[string]any{"type": "string"}, - "metadata": map[string]any{"type": "object"}, - "channel": map[string]any{"type": "string"}, - "accountId": map[string]any{"type": "string"}, - "userId": map[string]any{"type": "string"}, - "groupId": map[string]any{"type": "string"}, - "peerKind": map[string]any{"type": "string"}, - "peerId": map[string]any{"type": "string"}, - }, - "required": []string{"text"}, - }), - Enabled: true, - } -} - -func specMemoryForget() toolSpec { - return toolSpec{ - ID: "memory_forget", - Name: "memory_forget", - Description: "Forget memory entries by id or query.", - Category: "memory", - SchemaJSON: schemaJSON(map[string]any{ - "type": "object", - "properties": map[string]any{ - "memoryId": map[string]any{"type": "string"}, - "id": map[string]any{"type": "string"}, - "query": map[string]any{"type": "string"}, - "assistantId": map[string]any{"type": "string"}, - "threadId": map[string]any{"type": "string"}, - "category": map[string]any{"type": "string"}, - "scope": map[string]any{"type": "string"}, - "limit": map[string]any{"type": "integer"}, - "channel": map[string]any{"type": "string"}, - "accountId": map[string]any{"type": "string"}, - "userId": map[string]any{"type": "string"}, - "groupId": map[string]any{"type": "string"}, - "peerKind": map[string]any{"type": "string"}, - "peerId": map[string]any{"type": "string"}, + "required": []string{"action"}, + "allOf": []any{ + map[string]any{ + "if": map[string]any{ + 
"properties": map[string]any{ + "action": map[string]any{"const": "recall"}, + }, + }, + "then": map[string]any{ + "required": []string{"query"}, + }, + }, }, }), Enabled: true, } } -func specMemoryUpdate() toolSpec { +func specMemoryManage() toolSpec { return toolSpec{ - ID: "memory_update", - Name: "memory_update", - Description: "Update an existing memory entry.", + ID: "memory_manage", + Name: "memory_manage", + Description: "Create, update, or delete long-term memory entries.", Category: "memory", SchemaJSON: schemaJSON(map[string]any{ - "type": "object", + "type": "object", + "additionalProperties": false, "properties": map[string]any{ + "action": map[string]any{ + "type": "string", + "enum": []string{"store", "update", "forget"}, + }, "memoryId": map[string]any{"type": "string"}, "id": map[string]any{"type": "string"}, "text": map[string]any{"type": "string"}, @@ -1735,57 +1491,8 @@ func specMemoryUpdate() toolSpec { "threadId": map[string]any{"type": "string"}, "scope": map[string]any{"type": "string"}, "metadata": map[string]any{"type": "object"}, - "channel": map[string]any{"type": "string"}, - "accountId": map[string]any{"type": "string"}, - "userId": map[string]any{"type": "string"}, - "groupId": map[string]any{"type": "string"}, - "peerKind": map[string]any{"type": "string"}, - "peerId": map[string]any{"type": "string"}, - }, - }), - Enabled: true, - } -} - -func specMemoryStats() toolSpec { - return toolSpec{ - ID: "memory_stats", - Name: "memory_stats", - Description: "Show memory statistics.", - Category: "memory", - SchemaJSON: schemaJSON(map[string]any{ - "type": "object", - "properties": map[string]any{ - "assistantId": map[string]any{"type": "string"}, - "threadId": map[string]any{"type": "string"}, - "scope": map[string]any{"type": "string"}, - "channel": map[string]any{"type": "string"}, - "accountId": map[string]any{"type": "string"}, - "userId": map[string]any{"type": "string"}, - "groupId": map[string]any{"type": "string"}, - "peerKind": 
map[string]any{"type": "string"}, - "peerId": map[string]any{"type": "string"}, - }, - }), - Enabled: true, - } -} - -func specMemoryList() toolSpec { - return toolSpec{ - ID: "memory_list", - Name: "memory_list", - Description: "List recent memory entries.", - Category: "memory", - SchemaJSON: schemaJSON(map[string]any{ - "type": "object", - "properties": map[string]any{ - "assistantId": map[string]any{"type": "string"}, - "threadId": map[string]any{"type": "string"}, - "category": map[string]any{"type": "string"}, - "scope": map[string]any{"type": "string"}, + "query": map[string]any{"type": "string"}, "limit": map[string]any{"type": "integer"}, - "offset": map[string]any{"type": "integer"}, "channel": map[string]any{"type": "string"}, "accountId": map[string]any{"type": "string"}, "userId": map[string]any{"type": "string"}, @@ -1793,6 +1500,49 @@ func specMemoryList() toolSpec { "peerKind": map[string]any{"type": "string"}, "peerId": map[string]any{"type": "string"}, }, + "required": []string{"action"}, + "allOf": []any{ + map[string]any{ + "if": map[string]any{ + "properties": map[string]any{ + "action": map[string]any{"const": "store"}, + }, + }, + "then": map[string]any{ + "anyOf": []any{ + map[string]any{"required": []string{"text"}}, + map[string]any{"required": []string{"content"}}, + }, + }, + }, + map[string]any{ + "if": map[string]any{ + "properties": map[string]any{ + "action": map[string]any{"const": "update"}, + }, + }, + "then": map[string]any{ + "anyOf": []any{ + map[string]any{"required": []string{"memoryId"}}, + map[string]any{"required": []string{"id"}}, + }, + }, + }, + map[string]any{ + "if": map[string]any{ + "properties": map[string]any{ + "action": map[string]any{"const": "forget"}, + }, + }, + "then": map[string]any{ + "anyOf": []any{ + map[string]any{"required": []string{"memoryId"}}, + map[string]any{"required": []string{"id"}}, + map[string]any{"required": []string{"query"}}, + }, + }, + }, + }, }), Enabled: true, } diff --git 
a/internal/application/gateway/tools/builtin_specs_ui.go b/internal/application/gateway/tools/builtin_specs_ui.go new file mode 100644 index 0000000..7367fe1 --- /dev/null +++ b/internal/application/gateway/tools/builtin_specs_ui.go @@ -0,0 +1,347 @@ +package tools + +func specBrowser() toolSpec { + waitConditionProperties := map[string]any{ + "timeMs": map[string]any{"type": "number"}, + "text": map[string]any{"type": "string"}, + "textGone": map[string]any{"type": "string"}, + "selector": map[string]any{"type": "string"}, + "url": map[string]any{"type": "string"}, + "fn": map[string]any{"type": "string"}, + "timeoutMs": map[string]any{"type": "number"}, + } + waitConditionSchema := map[string]any{ + "type": "object", + "properties": waitConditionProperties, + "additionalProperties": false, + } + return toolSpec{ + ID: "browser", + Name: "browser", + Description: "Control a local CDP browser (`open`/`navigate`/`snapshot`/`act`/`wait`/`scroll`/`upload`/`dialog`/`reset`) using a browser-use style loop. For `open` or `navigate`, pass `url` or `targetUrl`; these actions return `stateAvailable`, `itemCount`, and the current page `state`/`items` whenever capture succeeds, so inspect that result before deciding the next action. If `stateAvailable=false` or refs look stale after the page changes, call `snapshot` to refresh them, then continue with `act` using `ref` on the same `targetId`. Do not use raw CSS `selector` for normal interactions; use `ref` from the latest state. 
Matching connector cookies are injected automatically before navigation.", + Category: "ui", + RiskLevel: "high", + SchemaJSON: schemaJSON(map[string]any{ + "type": "object", + "additionalProperties": false, + "properties": map[string]any{ + "action": map[string]any{ + "type": "string", + "enum": []string{ + "open", + "navigate", + "snapshot", + "act", + "wait", + "scroll", + "upload", + "dialog", + "reset", + }, + }, + "target": map[string]any{ + "type": "string", + "enum": []string{"sandbox", "host", "node"}, + }, + "node": map[string]any{"type": "string"}, + "profile": map[string]any{"type": "string"}, + "targetUrl": map[string]any{"type": "string"}, + "targetId": map[string]any{"type": "string"}, + "newTab": map[string]any{"type": "boolean"}, + "restart": map[string]any{"type": "boolean"}, + "limit": map[string]any{"type": "integer"}, + "timeoutMs": map[string]any{"type": "integer"}, + "selector": map[string]any{"type": "string"}, + "fullPage": map[string]any{"type": "boolean"}, + "ref": map[string]any{"type": "string"}, + "x": map[string]any{"type": "integer"}, + "y": map[string]any{"type": "integer"}, + "amount": map[string]any{"type": "integer"}, + "direction": map[string]any{ + "type": "string", + "enum": []string{"up", "down", "left", "right"}, + }, + "text": map[string]any{"type": "string"}, + "textGone": map[string]any{"type": "string"}, + "fn": map[string]any{"type": "string"}, + "timeMs": map[string]any{"type": "number"}, + "url": map[string]any{"type": "string"}, + "paths": map[string]any{"type": "array", "items": map[string]any{"type": "string"}}, + "accept": map[string]any{"type": "boolean"}, + "promptText": map[string]any{"type": "string"}, + "waitFor": waitConditionSchema, + "request": map[string]any{ + "type": "object", + "additionalProperties": false, + "properties": map[string]any{ + "kind": map[string]any{ + "type": "string", + "enum": []string{ + "click", + "type", + "press", + "hover", + "select", + "fill", + "resize", + "wait", + 
"evaluate", + "close", + }, + }, + "targetId": map[string]any{"type": "string"}, + "ref": map[string]any{"type": "string"}, + "text": map[string]any{"type": "string"}, + "key": map[string]any{"type": "string"}, + "value": map[string]any{"type": "string"}, + "width": map[string]any{"type": "number"}, + "height": map[string]any{"type": "number"}, + "timeMs": map[string]any{"type": "number"}, + "textGone": map[string]any{"type": "string"}, + "selector": map[string]any{"type": "string"}, + "url": map[string]any{"type": "string"}, + "fn": map[string]any{"type": "string"}, + "expression": map[string]any{"type": "string"}, + "timeoutMs": map[string]any{"type": "number"}, + "waitFor": waitConditionSchema, + }, + "required": []string{"kind"}, + }, + }, + "required": []string{"action"}, + "allOf": []any{ + map[string]any{ + "anyOf": []any{ + map[string]any{ + "properties": map[string]any{ + "action": map[string]any{"const": "open"}, + }, + "required": []string{"action"}, + "anyOf": []any{ + map[string]any{"required": []string{"targetUrl"}}, + map[string]any{"required": []string{"url"}}, + }, + }, + map[string]any{ + "properties": map[string]any{ + "action": map[string]any{ + "enum": []string{ + "snapshot", + "navigate", + "wait", + "scroll", + "upload", + "dialog", + "act", + "reset", + }, + }, + }, + "required": []string{"action"}, + }, + }, + }, + map[string]any{ + "anyOf": []any{ + map[string]any{ + "properties": map[string]any{ + "action": map[string]any{"const": "navigate"}, + }, + "required": []string{"action"}, + "anyOf": []any{ + map[string]any{"required": []string{"targetUrl"}}, + map[string]any{"required": []string{"url"}}, + }, + }, + map[string]any{ + "properties": map[string]any{ + "action": map[string]any{ + "enum": []string{ + "open", + "snapshot", + "wait", + "scroll", + "upload", + "dialog", + "act", + "reset", + }, + }, + }, + "required": []string{"action"}, + }, + }, + }, + map[string]any{ + "anyOf": []any{ + map[string]any{ + "properties": map[string]any{ 
+ "action": map[string]any{"const": "act"}, + }, + "required": []string{"action", "request"}, + }, + map[string]any{ + "properties": map[string]any{ + "action": map[string]any{ + "enum": []string{ + "open", + "snapshot", + "navigate", + "wait", + "scroll", + "upload", + "dialog", + "reset", + }, + }, + }, + "required": []string{"action"}, + }, + }, + }, + map[string]any{ + "anyOf": []any{ + map[string]any{ + "properties": map[string]any{ + "action": map[string]any{"const": "wait"}, + }, + "allOf": []any{ + map[string]any{ + "anyOf": []any{ + map[string]any{"required": []string{"timeMs"}}, + map[string]any{"required": []string{"text"}}, + map[string]any{"required": []string{"textGone"}}, + map[string]any{"required": []string{"selector"}}, + map[string]any{"required": []string{"url"}}, + map[string]any{"required": []string{"fn"}}, + }, + }, + }, + }, + map[string]any{ + "properties": map[string]any{ + "action": map[string]any{ + "enum": []string{ + "open", + "snapshot", + "navigate", + "scroll", + "upload", + "dialog", + "act", + "reset", + }, + }, + }, + "required": []string{"action"}, + }, + }, + }, + map[string]any{ + "anyOf": []any{ + map[string]any{ + "properties": map[string]any{ + "action": map[string]any{"const": "upload"}, + }, + "required": []string{"action", "ref", "paths"}, + }, + map[string]any{ + "properties": map[string]any{ + "action": map[string]any{ + "enum": []string{ + "open", + "snapshot", + "navigate", + "wait", + "scroll", + "dialog", + "act", + "reset", + }, + }, + }, + "required": []string{"action"}, + }, + }, + }, + }, + }), + RequiresSandbox: true, + RequiresApproval: true, + Enabled: true, + } +} + +func specCanvas() toolSpec { + return toolSpec{ + ID: "canvas", + Name: "canvas", + Description: "Control node canvases (present/hide/navigate/eval/snapshot/a2ui). 
Temporarily unavailable until remote node runtime support is implemented.", + Category: "ui", + SchemaJSON: schemaJSON(map[string]any{ + "type": "object", + "properties": map[string]any{ + "action": map[string]any{ + "type": "string", + "enum": []string{ + "present", + "hide", + "navigate", + "eval", + "snapshot", + "a2ui_push", + "a2ui_reset", + }, + }, + "gatewayUrl": map[string]any{"type": "string"}, + "gatewayToken": map[string]any{"type": "string"}, + "timeoutMs": map[string]any{"type": "number"}, + "node": map[string]any{"type": "string"}, + "target": map[string]any{"type": "string"}, + "x": map[string]any{"type": "number"}, + "y": map[string]any{"type": "number"}, + "width": map[string]any{"type": "number"}, + "height": map[string]any{"type": "number"}, + "url": map[string]any{"type": "string"}, + "javaScript": map[string]any{"type": "string"}, + "outputFormat": map[string]any{ + "type": "string", + "enum": []string{"png", "jpg", "jpeg"}, + }, + "maxWidth": map[string]any{"type": "number"}, + "quality": map[string]any{"type": "number"}, + "delayMs": map[string]any{"type": "number"}, + "jsonl": map[string]any{"type": "string"}, + "jsonlPath": map[string]any{ + "type": "string", + }, + }, + "required": []string{"action"}, + }), + Enabled: false, + } +} + +func specImage() toolSpec { + return toolSpec{ + ID: "image", + Name: "image", + Description: "Analyze one or more images with the configured image model.", + Category: "media", + SchemaJSON: schemaJSON(map[string]any{ + "type": "object", + "properties": map[string]any{ + "prompt": map[string]any{"type": "string"}, + "image": map[string]any{"type": "string"}, + "images": map[string]any{ + "type": "array", + "items": map[string]any{"type": "string"}, + }, + "model": map[string]any{"type": "string"}, + "maxBytesMb": map[string]any{"type": "number"}, + "maxImages": map[string]any{"type": "number"}, + }, + }), + Enabled: true, + } +} diff --git a/internal/application/gateway/tools/builtin_tools.go 
b/internal/application/gateway/tools/builtin_tools.go index 266e357..ffb2f79 100644 --- a/internal/application/gateway/tools/builtin_tools.go +++ b/internal/application/gateway/tools/builtin_tools.go @@ -65,19 +65,15 @@ func RegisterBuiltinTools(ctx context.Context, toolSvc *toolservice.ToolService, registerTool(ctx, toolSvc, executor, specSubagents(), runSubagentsTool(nil)) registerTool(ctx, toolSvc, executor, specNodes(), runNodesTool(deps.Nodes)) registerTool(ctx, toolSvc, executor, specTTS(), runTTSTool(deps.Voice)) - registerTool(ctx, toolSvc, executor, specMemoryRecall(), runMemoryRecallTool(deps.Memory)) - registerTool(ctx, toolSvc, executor, specMemoryStore(), runMemoryStoreTool(deps.Memory)) - registerTool(ctx, toolSvc, executor, specMemoryForget(), runMemoryForgetTool(deps.Memory)) - registerTool(ctx, toolSvc, executor, specMemoryUpdate(), runMemoryUpdateTool(deps.Memory)) - registerTool(ctx, toolSvc, executor, specMemoryStats(), runMemoryStatsTool(deps.Memory)) - registerTool(ctx, toolSvc, executor, specMemoryList(), runMemoryListTool(deps.Memory)) + registerTool(ctx, toolSvc, executor, specMemoryQuery(), runMemoryQueryTool(deps.Memory)) + registerTool(ctx, toolSvc, executor, specMemoryManage(), runMemoryManageTool(deps.Memory)) if deps.ExternalTools != nil { registerTool(ctx, toolSvc, executor, specExternalToolsQuery(), runExternalToolsQueryTool(deps.ExternalTools)) registerTool(ctx, toolSvc, executor, specExternalToolsManage(), runExternalToolsManageTool(deps.ExternalTools)) } if deps.Skills != nil { registerTool(ctx, toolSvc, executor, specSkills(), runSkillsTool(deps.Skills, deps.Assistant, deps.Settings, deps.ExternalTools)) - registerTool(ctx, toolSvc, executor, specSkillManage(), runSkillManageTool(deps.Skills, deps.Assistant, deps.Settings, deps.ExternalTools)) + registerTool(ctx, toolSvc, executor, specSkillsManage(), runSkillsManageTool(deps.Skills, deps.Assistant, deps.Settings, deps.ExternalTools)) } if deps.Library != nil { registerTool(ctx, 
toolSvc, executor, specLibrary(), runLibraryGroupTool(deps.Library, "library")) diff --git a/internal/application/gateway/tools/memory_tools.go b/internal/application/gateway/tools/memory_tools.go index cf5140b..ec51b37 100644 --- a/internal/application/gateway/tools/memory_tools.go +++ b/internal/application/gateway/tools/memory_tools.go @@ -11,6 +11,56 @@ import ( domainsession "dreamcreator/internal/domain/session" ) +func runMemoryQueryTool(memory *memoryservice.MemoryService) func(ctx context.Context, args string) (string, error) { + recallHandler := runMemoryRecallTool(memory) + listHandler := runMemoryListTool(memory) + statsHandler := runMemoryStatsTool(memory) + return func(ctx context.Context, args string) (string, error) { + payload, err := parseToolArgs(args) + if err != nil { + return "", err + } + action := strings.ToLower(strings.TrimSpace(getStringArg(payload, "action", "type"))) + switch action { + case "recall", "search": + return recallHandler(ctx, args) + case "list": + return listHandler(ctx, args) + case "stats", "status": + return statsHandler(ctx, args) + case "": + return "", errors.New("action is required") + default: + return "", errors.New("unsupported memory_query action") + } + } +} + +func runMemoryManageTool(memory *memoryservice.MemoryService) func(ctx context.Context, args string) (string, error) { + storeHandler := runMemoryStoreTool(memory) + updateHandler := runMemoryUpdateTool(memory) + forgetHandler := runMemoryForgetTool(memory) + return func(ctx context.Context, args string) (string, error) { + payload, err := parseToolArgs(args) + if err != nil { + return "", err + } + action := strings.ToLower(strings.TrimSpace(getStringArg(payload, "action", "type"))) + switch action { + case "store", "create": + return storeHandler(ctx, args) + case "update": + return updateHandler(ctx, args) + case "forget", "delete", "remove": + return forgetHandler(ctx, args) + case "": + return "", errors.New("action is required") + default: + return 
"", errors.New("unsupported memory_manage action") + } + } +} + func runMemoryRecallTool(memory *memoryservice.MemoryService) func(ctx context.Context, args string) (string, error) { return func(ctx context.Context, args string) (string, error) { if memory == nil { diff --git a/internal/application/gateway/tools/memory_tools_group_test.go b/internal/application/gateway/tools/memory_tools_group_test.go new file mode 100644 index 0000000..af3b3d1 --- /dev/null +++ b/internal/application/gateway/tools/memory_tools_group_test.go @@ -0,0 +1,56 @@ +package tools + +import ( + "context" + "testing" +) + +func TestMemoryQueryToolRequiresAction(t *testing.T) { + t.Parallel() + + handler := runMemoryQueryTool(nil) + _, err := handler(context.Background(), `{}`) + if err == nil || err.Error() != "action is required" { + t.Fatalf("expected action is required, got %v", err) + } +} + +func TestMemoryManageToolRequiresAction(t *testing.T) { + t.Parallel() + + handler := runMemoryManageTool(nil) + _, err := handler(context.Background(), `{}`) + if err == nil || err.Error() != "action is required" { + t.Fatalf("expected action is required, got %v", err) + } +} + +func TestMemoryQueryToolRoutesRecallAction(t *testing.T) { + t.Parallel() + + handler := runMemoryQueryTool(nil) + _, err := handler(context.Background(), `{"action":"recall","query":"demo"}`) + if err == nil || err.Error() != "memory service unavailable" { + t.Fatalf("expected memory service unavailable, got %v", err) + } +} + +func TestMemoryManageToolRoutesCreateAlias(t *testing.T) { + t.Parallel() + + handler := runMemoryManageTool(nil) + _, err := handler(context.Background(), `{"action":"create","text":"demo"}`) + if err == nil || err.Error() != "memory service unavailable" { + t.Fatalf("expected memory service unavailable, got %v", err) + } +} + +func TestMemoryManageToolRejectsUnknownAction(t *testing.T) { + t.Parallel() + + handler := runMemoryManageTool(nil) + _, err := handler(context.Background(), 
`{"action":"archive"}`) + if err == nil || err.Error() != "unsupported memory_manage action" { + t.Fatalf("expected unsupported memory_manage action, got %v", err) + } +} diff --git a/internal/application/gateway/tools/nodes_tools.go b/internal/application/gateway/tools/nodes_tools.go index ccb5b95..c48b99e 100644 --- a/internal/application/gateway/tools/nodes_tools.go +++ b/internal/application/gateway/tools/nodes_tools.go @@ -19,10 +19,14 @@ func runNodesTool(nodes *gatewaynodes.Service) func(ctx context.Context, args st return "", err } nodeID := getStringArg(payload, "nodeId", "nodeID") - capability := getStringArg(payload, "capability", "action") + capability := getStringArg(payload, "capability") if nodeID == "" { return "", errors.New("nodeId is required") } + if capability == "" { + return "", errors.New("capability is required") + } + timeoutMs, _ := getIntArg(payload, "timeoutMs", "timeout_ms") argsJSON := strings.TrimSpace(getStringArg(payload, "args")) if argsJSON == "" { if payloadMap := getMapArg(payload, "payload"); payloadMap != nil { @@ -36,6 +40,7 @@ func runNodesTool(nodes *gatewaynodes.Service) func(ctx context.Context, args st Capability: capability, Action: getStringArg(payload, "action"), Args: argsJSON, + TimeoutMs: timeoutMs, } result, err := nodes.Invoke(ctx, request) if err != nil { diff --git a/internal/application/gateway/tools/nodes_tools_test.go b/internal/application/gateway/tools/nodes_tools_test.go new file mode 100644 index 0000000..306da0b --- /dev/null +++ b/internal/application/gateway/tools/nodes_tools_test.go @@ -0,0 +1,73 @@ +package tools + +import ( + "context" + "strings" + "testing" + + gatewaynodes "dreamcreator/internal/application/gateway/nodes" +) + +func TestRunNodesToolRequiresExplicitCapability(t *testing.T) { + t.Parallel() + + handler := runNodesTool(newNodesServiceForTest(t, gatewaynodes.InvokerFunc(func(_ context.Context, request gatewaynodes.NodeInvokeRequest) (gatewaynodes.NodeInvokeResult, error) { + 
t.Fatalf("unexpected invoke: %#v", request) + return gatewaynodes.NodeInvokeResult{}, nil + }))) + + _, err := handler(context.Background(), `{"nodeId":"node-1","action":"screen.capture"}`) + if err == nil { + t.Fatalf("expected capability required error") + } + if !strings.Contains(err.Error(), "capability is required") { + t.Fatalf("unexpected error: %v", err) + } +} + +func TestRunNodesToolPassesActionPayloadAndTimeout(t *testing.T) { + t.Parallel() + + var captured gatewaynodes.NodeInvokeRequest + handler := runNodesTool(newNodesServiceForTest(t, gatewaynodes.InvokerFunc(func(_ context.Context, request gatewaynodes.NodeInvokeRequest) (gatewaynodes.NodeInvokeResult, error) { + captured = request + return gatewaynodes.NodeInvokeResult{ + InvokeID: request.InvokeID, + Ok: true, + Output: `{"ok":true}`, + }, nil + }))) + + _, err := handler(context.Background(), `{"nodeId":"node-1","capability":"screen","action":"capture","payload":{"format":"png"},"timeoutMs":4321}`) + if err != nil { + t.Fatalf("run nodes tool: %v", err) + } + if captured.NodeID != "node-1" { + t.Fatalf("expected node-1, got %q", captured.NodeID) + } + if captured.Capability != "screen" { + t.Fatalf("expected capability screen, got %q", captured.Capability) + } + if captured.Action != "capture" { + t.Fatalf("expected action capture, got %q", captured.Action) + } + if captured.Args != `{"format":"png"}` { + t.Fatalf("unexpected args payload: %q", captured.Args) + } + if captured.TimeoutMs != 4321 { + t.Fatalf("expected timeout 4321, got %d", captured.TimeoutMs) + } +} + +func newNodesServiceForTest(t *testing.T, invoker gatewaynodes.Invoker) *gatewaynodes.Service { + t.Helper() + registry := gatewaynodes.NewRegistry(nil, nil) + _, err := registry.Register(context.Background(), "", gatewaynodes.NodeDescriptor{ + NodeID: "node-1", + Status: "online", + }) + if err != nil { + t.Fatalf("register node: %v", err) + } + return gatewaynodes.NewService(registry, nil, invoker, nil, nil) +} diff --git 
a/internal/application/gateway/tools/policy.go b/internal/application/gateway/tools/policy.go index 4d5868a..ea293e6 100644 --- a/internal/application/gateway/tools/policy.go +++ b/internal/application/gateway/tools/policy.go @@ -8,13 +8,21 @@ import ( ) type PolicyPipeline struct { - settings SettingsReader + settings SettingsReader + requirementsResolver ToolRequirementResolver } func NewPolicyPipeline(settings SettingsReader) *PolicyPipeline { return &PolicyPipeline{settings: settings} } +func (pipeline *PolicyPipeline) SetRequirementsResolver(resolver ToolRequirementResolver) { + if pipeline == nil { + return + } + pipeline.requirementsResolver = resolver +} + func (pipeline *PolicyPipeline) Decide(ctx context.Context, spec tooldto.ToolSpec, policyCtx tooldto.ToolPolicyContext) (tooldto.ToolPolicyDecision, error) { if !spec.Enabled { return tooldto.ToolPolicyDecision{ @@ -24,7 +32,7 @@ func (pipeline *PolicyPipeline) Decide(ctx context.Context, spec tooldto.ToolSpe }, nil } snapshot := loadToolRequirementSnapshot(ctx, pipeline.settings) - effectiveSpec := resolveEffectiveToolSpec(spec, snapshot) + effectiveSpec := resolveEffectiveToolSpecWithResolver(ctx, spec, snapshot, pipeline.requirementsResolver) if !effectiveSpec.Enabled { reason := "tool requirements unavailable" if requirement, ok := firstUnavailableToolRequirement(effectiveSpec.Requirements); ok { diff --git a/internal/application/gateway/tools/requirement_resolver.go b/internal/application/gateway/tools/requirement_resolver.go new file mode 100644 index 0000000..25fed3c --- /dev/null +++ b/internal/application/gateway/tools/requirement_resolver.go @@ -0,0 +1,41 @@ +package tools + +import ( + "context" + + tooldto "dreamcreator/internal/application/tools/dto" +) + +type ToolRequirementResolver interface { + ResolveToolRequirements(ctx context.Context, spec tooldto.ToolSpec) []tooldto.ToolRequirement +} + +type ToolRequirementResolverFunc func(ctx context.Context, spec tooldto.ToolSpec) 
[]tooldto.ToolRequirement + +func (fn ToolRequirementResolverFunc) ResolveToolRequirements(ctx context.Context, spec tooldto.ToolSpec) []tooldto.ToolRequirement { + if fn == nil { + return nil + } + return fn(ctx, spec) +} + +func resolveEffectiveToolSpecWithResolver( + ctx context.Context, + spec tooldto.ToolSpec, + snapshot toolRequirementSnapshot, + resolver ToolRequirementResolver, +) tooldto.ToolSpec { + effective := resolveEffectiveToolSpec(spec, snapshot) + if resolver == nil { + return effective + } + additional := resolver.ResolveToolRequirements(ctx, effective) + if len(additional) == 0 { + return effective + } + effective.Requirements = append(effective.Requirements, additional...) + if effective.Enabled && !toolRequirementsSatisfied(effective.Requirements) { + effective.Enabled = false + } + return effective +} diff --git a/internal/application/gateway/tools/service.go b/internal/application/gateway/tools/service.go index 6fe339f..34d0817 100644 --- a/internal/application/gateway/tools/service.go +++ b/internal/application/gateway/tools/service.go @@ -16,14 +16,15 @@ import ( ) type Service struct { - tools *toolservice.ToolService - approvals *gatewayapprovals.Service - sandbox *gatewaysandbox.Service - settings SettingsReader - audit PolicyAuditStore - events *gatewayevents.Broker - now func() time.Time - newID func() string + tools *toolservice.ToolService + approvals *gatewayapprovals.Service + sandbox *gatewaysandbox.Service + settings SettingsReader + audit PolicyAuditStore + events *gatewayevents.Broker + requirementsResolver ToolRequirementResolver + now func() time.Time + newID func() string } type PolicyAuditStore interface { @@ -43,6 +44,13 @@ func NewService(tools *toolservice.ToolService, approvals *gatewayapprovals.Serv } } +func (service *Service) SetRequirementsResolver(resolver ToolRequirementResolver) { + if service == nil { + return + } + service.requirementsResolver = resolver +} + func (service *Service) ListTools(ctx 
context.Context) []tooldto.ToolSpec { if service == nil || service.tools == nil { return nil @@ -54,7 +62,7 @@ func (service *Service) ListTools(ctx context.Context) []tooldto.ToolSpec { snapshot := loadToolRequirementSnapshot(ctx, service.settings) result := make([]tooldto.ToolSpec, 0, len(specs)) for _, spec := range specs { - resolved := resolveEffectiveToolSpec(spec, snapshot) + resolved := resolveEffectiveToolSpecWithResolver(ctx, spec, snapshot, service.requirementsResolver) resolved = resolveDynamicToolSpec(ctx, resolved, service.settings) result = append(result, resolved) } @@ -98,6 +106,7 @@ func (service *Service) InvokeWithPolicy(ctx context.Context, request tooldto.To if err != nil { return tooldto.ToolsInvokeResponse{}, err } + spec = resolveEffectiveToolSpecWithResolver(ctx, spec, loadToolRequirementSnapshot(ctx, service.settings), service.requirementsResolver) spec = resolveDynamicToolSpec(ctx, spec, service.settings) service.auditDecision(ctx, spec, decision, policyCtx) response := tooldto.ToolsInvokeResponse{ @@ -171,6 +180,13 @@ func (service *Service) InvokeWithPolicy(ctx context.Context, request tooldto.To return response, nil } +func (service *Service) CleanupRuntimeSession(_ context.Context, sessionKey string) { + if strings.TrimSpace(sessionKey) == "" { + return + } + cleanupBrowserToolSessions(sessionKey) +} + func (service *Service) auditDecision(ctx context.Context, spec tooldto.ToolSpec, decision tooldto.ToolPolicyDecision, policyCtx tooldto.ToolPolicyContext) { if service == nil || service.audit == nil { return diff --git a/internal/application/gateway/tools/skills_tools.go b/internal/application/gateway/tools/skills_tools.go index 43f9630..dc59118 100644 --- a/internal/application/gateway/tools/skills_tools.go +++ b/internal/application/gateway/tools/skills_tools.go @@ -152,8 +152,12 @@ func runSkillsTool(skills *skillsservice.SkillsService, assistants *assistantser return runSkillsToolWithName("skills", skills, assistants, settings, 
externalTools) } +func runSkillsManageTool(skills *skillsservice.SkillsService, assistants *assistantservice.AssistantService, settings SettingsReader, externalTools skillsExternalToolInstaller) func(ctx context.Context, args string) (string, error) { + return runSkillsToolWithName("skills_manage", skills, assistants, settings, externalTools) +} + func runSkillManageTool(skills *skillsservice.SkillsService, assistants *assistantservice.AssistantService, settings SettingsReader, externalTools skillsExternalToolInstaller) func(ctx context.Context, args string) (string, error) { - return runSkillsToolWithName("skill_manage", skills, assistants, settings, externalTools) + return runSkillsManageTool(skills, assistants, settings, externalTools) } func runSkillsToolWithName(toolName string, skills *skillsservice.SkillsService, assistants *assistantservice.AssistantService, settings SettingsReader, externalTools skillsExternalToolInstaller) func(ctx context.Context, args string) (string, error) { @@ -832,7 +836,7 @@ func canonicalSkillsAction(toolName string, action string) string { default: return "" } - case "skill_manage": + case "skill_manage", "skills_manage": switch normalized { case "list": return "catalog" @@ -873,20 +877,20 @@ func normalizeSkillsActionKey(toolName string, action string) string { default: return normalized } - case "skill_manage": + case "skill_manage", "skills_manage": switch normalized { case "catalog": - return "skill_manage.list" + return "skills_manage.list" case "search_packages": - return "skill_manage.search" + return "skills_manage.search" case "install_package": - return "skill_manage.install" + return "skills_manage.install" case "update_package": - return "skill_manage.update" + return "skills_manage.update" case "remove_package": - return "skill_manage.remove" + return "skills_manage.remove" case "sync_packages": - return "skill_manage.sync" + return "skills_manage.sync" default: return normalized } @@ -900,7 +904,7 @@ func 
validateSkillsToolArgs(toolName string, payload toolArgs) error { return nil } allowed := allowedSkillsToolArgs - if strings.EqualFold(strings.TrimSpace(toolName), "skill_manage") { + if strings.EqualFold(strings.TrimSpace(toolName), "skill_manage") || strings.EqualFold(strings.TrimSpace(toolName), "skills_manage") { allowed = allowedSkillManageToolArgs } unknown := make([]string, 0) @@ -1148,9 +1152,9 @@ func marshalSkillsResult(result skillsToolResult) string { func resolveSkillsActionGroup(action string) string { switch strings.ToLower(strings.TrimSpace(action)) { - case "skills.status", "skills.bins", "skill_manage.search", "skill_manage.list": + case "skills.status", "skills.bins", "skills_manage.search", "skills_manage.list": return "read" - case "skill_manage.install", "skill_manage.update", "skill_manage.remove", "skill_manage.sync": + case "skills_manage.install", "skills_manage.update", "skills_manage.remove", "skills_manage.sync": return "package_write" case "skills.install": return "deps_write" diff --git a/internal/application/gateway/tools/skills_tools_test.go b/internal/application/gateway/tools/skills_tools_test.go index c3e7447..c7e2e82 100644 --- a/internal/application/gateway/tools/skills_tools_test.go +++ b/internal/application/gateway/tools/skills_tools_test.go @@ -201,7 +201,7 @@ func (stub *skillsSettingsStub) UpdateSettings(_ context.Context, request settin return copied, nil } -func TestSkillManageToolSearchReturnsClawHubUnavailable(t *testing.T) { +func TestSkillsManageToolSearchReturnsClawHubUnavailable(t *testing.T) { t.Parallel() repo := newSkillsRepoStub() @@ -209,10 +209,10 @@ func TestSkillManageToolSearchReturnsClawHubUnavailable(t *testing.T) { svc.SetExternalTools(&skillsExternalToolsStubForGateway{ready: false}) svc.SetWorkspaceResolver(skillsWorkspaceResolverStubForGateway{}) - handler := runSkillManageTool(svc, nil, nil, nil) + handler := runSkillsManageTool(svc, nil, nil, nil) output, err := handler(context.Background(), 
`{"action":"search","query":"demo"}`) if err != nil { - t.Fatalf("skill_manage tool failed: %v", err) + t.Fatalf("skills_manage tool failed: %v", err) } var payload map[string]any if err := json.Unmarshal([]byte(output), &payload); err != nil { @@ -226,7 +226,7 @@ func TestSkillManageToolSearchReturnsClawHubUnavailable(t *testing.T) { } } -func TestSkillManageToolInstallRequireForceAndRetrySuccess(t *testing.T) { +func TestSkillsManageToolInstallRequireForceAndRetrySuccess(t *testing.T) { t.Parallel() repo := newSkillsRepoStub() @@ -248,11 +248,11 @@ func TestSkillManageToolInstallRequireForceAndRetrySuccess(t *testing.T) { } }, }) - handler := runSkillManageTool(svc, nil, nil, nil) + handler := runSkillsManageTool(svc, nil, nil, nil) output, err := handler(context.Background(), `{"action":"install","skill":"web-search-pro"}`) if err != nil { - t.Fatalf("skill_manage tool install failed: %v", err) + t.Fatalf("skills_manage tool install failed: %v", err) } var payload map[string]any if err := json.Unmarshal([]byte(output), &payload); err != nil { @@ -267,13 +267,13 @@ func TestSkillManageToolInstallRequireForceAndRetrySuccess(t *testing.T) { if payload["requiresForce"] != true { t.Fatalf("expected requiresForce=true, got %#v", payload["requiresForce"]) } - if payload["action"] != "skill_manage.install" { - t.Fatalf("expected action skill_manage.install, got %#v", payload["action"]) + if payload["action"] != "skills_manage.install" { + t.Fatalf("expected action skills_manage.install, got %#v", payload["action"]) } output, err = handler(context.Background(), `{"action":"install","skill":"web-search-pro","force":true}`) if err != nil { - t.Fatalf("skill_manage tool install(force) failed: %v", err) + t.Fatalf("skills_manage tool install(force) failed: %v", err) } payload = map[string]any{} if err := json.Unmarshal([]byte(output), &payload); err != nil { @@ -284,7 +284,7 @@ func TestSkillManageToolInstallRequireForceAndRetrySuccess(t *testing.T) { } } -func 
TestSkillManageToolInstallRequireForceBlockedByScannerPolicy(t *testing.T) { +func TestSkillsManageToolInstallRequireForceBlockedByScannerPolicy(t *testing.T) { t.Parallel() repo := newSkillsRepoStub() @@ -312,11 +312,11 @@ func TestSkillManageToolInstallRequireForceBlockedByScannerPolicy(t *testing.T) }, }, }) - handler := runSkillManageTool(svc, nil, settings, nil) + handler := runSkillsManageTool(svc, nil, settings, nil) output, err := handler(context.Background(), `{"action":"install","skill":"web-search-pro"}`) if err != nil { - t.Fatalf("skill_manage tool install failed: %v", err) + t.Fatalf("skills_manage tool install failed: %v", err) } var payload map[string]any if err := json.Unmarshal([]byte(output), &payload); err != nil { @@ -330,7 +330,7 @@ func TestSkillManageToolInstallRequireForceBlockedByScannerPolicy(t *testing.T) } } -func TestSkillManageToolForceInstallRequiresApprovalWhenEnabled(t *testing.T) { +func TestSkillsManageToolForceInstallRequiresApprovalWhenEnabled(t *testing.T) { t.Parallel() repo := newSkillsRepoStub() @@ -348,11 +348,11 @@ func TestSkillManageToolForceInstallRequiresApprovalWhenEnabled(t *testing.T) { }, }, }) - handler := runSkillManageTool(svc, nil, settings, nil) + handler := runSkillsManageTool(svc, nil, settings, nil) output, err := handler(context.Background(), `{"action":"install","skill":"web-search-pro","force":true}`) if err != nil { - t.Fatalf("skill_manage tool install failed: %v", err) + t.Fatalf("skills_manage tool install failed: %v", err) } var payload map[string]any if err := json.Unmarshal([]byte(output), &payload); err != nil { @@ -483,7 +483,7 @@ func TestSkillsToolActionModeDeny(t *testing.T) { } } -func TestSkillManageToolActionModeAsk(t *testing.T) { +func TestSkillsManageToolActionModeAsk(t *testing.T) { t.Parallel() repo := newSkillsRepoStub() @@ -497,11 +497,11 @@ func TestSkillManageToolActionModeAsk(t *testing.T) { }, }, }) - handler := runSkillManageTool(svc, nil, settings, nil) + handler := 
runSkillsManageTool(svc, nil, settings, nil) output, err := handler(context.Background(), `{"action":"install","skill":"skill-a"}`) if err != nil { - t.Fatalf("skill_manage tool failed: %v", err) + t.Fatalf("skills_manage tool failed: %v", err) } var payload map[string]any if err := json.Unmarshal([]byte(output), &payload); err != nil { diff --git a/internal/application/gateway/tools/tool_helpers.go b/internal/application/gateway/tools/tool_helpers.go index 9b7949c..0aeacc1 100644 --- a/internal/application/gateway/tools/tool_helpers.go +++ b/internal/application/gateway/tools/tool_helpers.go @@ -350,3 +350,21 @@ func getNestedInt(root map[string]any, path ...string) (int, bool) { } return 0, false } + +func containsString(values []string, needle string) bool { + needle = strings.TrimSpace(needle) + for _, value := range values { + if strings.TrimSpace(value) == needle { + return true + } + } + return false +} + +func trimToMaxChars(value string, maxChars int) string { + value = strings.TrimSpace(value) + if maxChars <= 0 || len(value) <= maxChars { + return value + } + return value[:maxChars] +} diff --git a/internal/application/gateway/tools/tool_requirements.go b/internal/application/gateway/tools/tool_requirements.go index 1ca3e8e..01c50a2 100644 --- a/internal/application/gateway/tools/tool_requirements.go +++ b/internal/application/gateway/tools/tool_requirements.go @@ -6,6 +6,7 @@ import ( "os" "strings" + "dreamcreator/internal/application/browsercdp" tooldto "dreamcreator/internal/application/tools/dto" ) @@ -41,9 +42,16 @@ func loadToolRequirementSnapshot(ctx context.Context, settings SettingsReader) t func resolveEffectiveToolSpec(spec tooldto.ToolSpec, snapshot toolRequirementSnapshot) tooldto.ToolSpec { key := resolveToolRequirementKey(spec) - if snapshot.loaded && key == "browser" { - if enabled, ok := resolveBrowserConfigBool(snapshot.toolsConfig, "enabled"); ok && !enabled { - spec.Enabled = false + if snapshot.loaded { + switch key { + case 
"browser": + if enabled, ok := resolveBrowserConfigBool(snapshot.toolsConfig, "enabled"); ok && !enabled { + spec.Enabled = false + } + case "web_fetch": + if enabled, ok := resolveWebFetchConfigBool(snapshot.toolsConfig, "enabled"); ok && !enabled { + spec.Enabled = false + } } } requirements := resolveToolRequirements(spec, snapshot) @@ -97,30 +105,56 @@ func resolveToolRequirements(spec tooldto.ToolSpec, snapshot toolRequirementSnap return resolveWebFetchRequirements(snapshot.toolsConfig) case "browser": return resolveBrowserRequirements(snapshot.toolsConfig) + case "canvas": + return resolveCanvasRequirements() + case "nodes": + return resolveNodeRequirements() default: return nil } } -func resolveBrowserRequirements(_ map[string]any) []tooldto.ToolRequirement { - runtimeRequirement := tooldto.ToolRequirement{ - ID: "browser.playwright_runtime", - Name: "Playwright runtime", - Available: true, +func resolveBrowserRequirements(config map[string]any) []tooldto.ToolRequirement { + resolved := resolveBrowserRuntimeConfig(config) + status := browsercdp.ResolveStatus(resolved.PreferredBrowser, resolved.Headless) + requirements := []tooldto.ToolRequirement{ + { + ID: "browser.cdp_runtime", + Name: "Local CDP browser", + Available: status.Ready, + Reason: strings.TrimSpace(status.DetectError), + Data: map[string]any{ + "candidates": status.Candidates, + "selectedBrowser": status.SelectedBrowser, + "chosenBrowser": status.ChosenBrowser, + "detectedExecutablePath": status.DetectedExecutablePath, + "headless": status.Headless, + }, + }, } - available, reason, execPath := resolveBrowserPlaywrightRuntimeAvailability() - runtimeRequirement.Available = available - if !available { - if strings.TrimSpace(reason) != "" { - runtimeRequirement.Reason = reason - } else { - runtimeRequirement.Reason = "Playwright runtime is unavailable" - } - } else if strings.TrimSpace(execPath) != "" { - // Surface resolved Chromium binary path for browser settings content. 
- runtimeRequirement.Reason = strings.TrimSpace(execPath) + return requirements +} + +func resolveNodeRequirements() []tooldto.ToolRequirement { + return []tooldto.ToolRequirement{ + { + ID: "nodes.remote_runtime", + Name: "Remote node runtime", + Available: false, + Reason: "Remote node runtime is not implemented yet", + }, + } +} + +func resolveCanvasRequirements() []tooldto.ToolRequirement { + return []tooldto.ToolRequirement{ + { + ID: "canvas.remote_runtime", + Name: "Remote node runtime", + Available: false, + Reason: "Remote node runtime is not implemented yet", + }, } - return []tooldto.ToolRequirement{runtimeRequirement} } func resolveToolRequirementKey(spec tooldto.ToolSpec) string { @@ -132,19 +166,21 @@ func resolveToolRequirementKey(spec tooldto.ToolSpec) string { } func resolveWebFetchRequirements(config map[string]any) []tooldto.ToolRequirement { - enabled := true - if value, ok := resolveWebFetchConfigBool(config, "enabled"); ok { - enabled = value + status := browsercdp.ResolveStatus(resolveWebFetchPreferredBrowser(config), resolveWebFetchHeadless(config)) + browserRequirement := tooldto.ToolRequirement{ + ID: "web_fetch.local_browser", + Name: "Local CDP browser", + Available: status.Ready, + Reason: strings.TrimSpace(status.DetectError), + Data: map[string]any{ + "candidates": status.Candidates, + "selectedBrowser": status.SelectedBrowser, + "chosenBrowser": status.ChosenBrowser, + "detectedExecutablePath": status.DetectedExecutablePath, + "headless": status.Headless, + }, } - requirement := tooldto.ToolRequirement{ - ID: "web_fetch.config_enabled", - Name: "Web fetch switch", - Available: enabled, - } - if !enabled { - requirement.Reason = "web_fetch is disabled in settings" - } - return []tooldto.ToolRequirement{requirement} + return []tooldto.ToolRequirement{browserRequirement} } func resolveWebSearchRequirements(config map[string]any) []tooldto.ToolRequirement { @@ -264,3 +300,12 @@ func resolveWebSearchProviderAPIKey(config map[string]any, 
provider string) stri } return strings.TrimSpace(apiKey) } + +func resolveWebSearchProviderString(config map[string]any, provider string, key string) string { + provider = strings.ToLower(strings.TrimSpace(provider)) + key = strings.TrimSpace(key) + if provider == "" || key == "" { + return "" + } + return getNestedString(config, "web", "search", "providers", provider, key) +} diff --git a/internal/application/gateway/tools/tool_requirements_test.go b/internal/application/gateway/tools/tool_requirements_test.go index 2882f0e..c11442b 100644 --- a/internal/application/gateway/tools/tool_requirements_test.go +++ b/internal/application/gateway/tools/tool_requirements_test.go @@ -5,11 +5,25 @@ import ( "strings" "testing" + gatewayvoice "dreamcreator/internal/application/gateway/voice" settingsdto "dreamcreator/internal/application/settings/dto" tooldto "dreamcreator/internal/application/tools/dto" toolservice "dreamcreator/internal/application/tools/service" + "dreamcreator/internal/domain/providers" ) +type voiceStatusStub struct { + status gatewayvoice.TTSStatusResponse + err error +} + +func (stub voiceStatusStub) Status(context.Context) (gatewayvoice.TTSStatusResponse, error) { + if stub.err != nil { + return gatewayvoice.TTSStatusResponse{}, stub.err + } + return stub.status, nil +} + func TestResolveEffectiveToolSpecWebSearchAPIMissingKeyDisablesTool(t *testing.T) { t.Parallel() @@ -111,16 +125,17 @@ func TestResolveEffectiveToolSpecWebFetchDisabledByConfig(t *testing.T) { if spec.Enabled { t.Fatalf("expected web_fetch disabled when config enabled flag is false") } - requirement, ok := findRequirement(spec.Requirements, "web_fetch.config_enabled") + requirement, ok := findRequirement(spec.Requirements, "web_fetch.local_browser") if !ok { - t.Fatalf("expected web_fetch config requirement") + t.Fatalf("expected web_fetch local browser requirement") } - if requirement.Available { - t.Fatalf("expected web_fetch config requirement to be unavailable") + if _, ok := 
findRequirement(spec.Requirements, "web_fetch.config_enabled"); ok { + t.Fatalf("did not expect web_fetch config enabled requirement") } + _ = requirement } -func TestResolveEffectiveToolSpecBrowserPlaywrightIncludesRuntimeRequirement(t *testing.T) { +func TestResolveEffectiveToolSpecBrowserIncludesCDPRuntimeRequirement(t *testing.T) { t.Parallel() snapshot := loadToolRequirementSnapshot(context.Background(), gatewayToolSettingsStub{ @@ -128,17 +143,23 @@ func TestResolveEffectiveToolSpecBrowserPlaywrightIncludesRuntimeRequirement(t * Gateway: settingsdto.GatewaySettings{ControlPlaneEnabled: true}, Tools: map[string]any{ "browser": map[string]any{ - "type": "playwright", + "preferredBrowser": "brave", }, }, }, }) spec := resolveEffectiveToolSpec(specBrowser().toDTO(), snapshot) - requirement, ok := findRequirement(spec.Requirements, "browser.playwright_runtime") + requirement, ok := findRequirement(spec.Requirements, "browser.cdp_runtime") if !ok { - t.Fatalf("expected browser playwright runtime requirement") + t.Fatalf("expected browser cdp runtime requirement") + } + data, ok := requirement.Data.(map[string]any) + if !ok || data == nil { + t.Fatalf("expected browser runtime requirement data") + } + if data["selectedBrowser"] != "brave" { + t.Fatalf("expected selected browser brave, got %#v", data["selectedBrowser"]) } - _ = requirement } func TestResolveEffectiveToolSpecBrowserDisabledByConfig(t *testing.T) { @@ -158,9 +179,9 @@ func TestResolveEffectiveToolSpecBrowserDisabledByConfig(t *testing.T) { if spec.Enabled { t.Fatalf("expected browser disabled when browser.enabled is false") } - requirement, ok := findRequirement(spec.Requirements, "browser.playwright_runtime") + requirement, ok := findRequirement(spec.Requirements, "browser.cdp_runtime") if !ok { - t.Fatalf("expected browser playwright runtime requirement") + t.Fatalf("expected browser cdp runtime requirement") } if _, ok := findRequirement(spec.Requirements, "browser.config_enabled"); ok { 
t.Fatalf("did not expect browser config enabled requirement") @@ -182,9 +203,9 @@ func TestResolveEffectiveToolSpecBrowserLegacyTypeIsIgnored(t *testing.T) { }, }) spec := resolveEffectiveToolSpec(specBrowser().toDTO(), snapshot) - requirement, ok := findRequirement(spec.Requirements, "browser.playwright_runtime") + requirement, ok := findRequirement(spec.Requirements, "browser.cdp_runtime") if !ok { - t.Fatalf("expected browser playwright runtime requirement") + t.Fatalf("expected browser cdp runtime requirement") } if _, ok := findRequirement(spec.Requirements, "browser.type_supported"); ok { t.Fatalf("did not expect browser type requirement") @@ -213,6 +234,280 @@ func TestResolveEffectiveToolSpecGatewayRequiresControlPlane(t *testing.T) { } } +func TestResolveEffectiveToolSpecNodesRemainsDisabledUntilRemoteRuntimeExists(t *testing.T) { + t.Parallel() + + snapshot := loadToolRequirementSnapshot(context.Background(), gatewayToolSettingsStub{ + settings: settingsdto.Settings{ + Gateway: settingsdto.GatewaySettings{ControlPlaneEnabled: true}, + Tools: map[string]any{}, + }, + }) + spec := resolveEffectiveToolSpec(specNodes().toDTO(), snapshot) + if spec.Enabled { + t.Fatalf("expected nodes tool disabled while remote node runtime is unavailable") + } + requirement, ok := findRequirement(spec.Requirements, "nodes.remote_runtime") + if !ok { + t.Fatalf("expected remote node runtime requirement") + } + if requirement.Available { + t.Fatalf("expected remote node runtime requirement to be unavailable") + } + if !strings.Contains(strings.ToLower(requirement.Reason), "not implemented") { + t.Fatalf("unexpected requirement reason: %q", requirement.Reason) + } +} + +func TestResolveEffectiveToolSpecCanvasRemainsDisabledUntilRemoteRuntimeExists(t *testing.T) { + t.Parallel() + + snapshot := loadToolRequirementSnapshot(context.Background(), gatewayToolSettingsStub{ + settings: settingsdto.Settings{ + Gateway: settingsdto.GatewaySettings{ControlPlaneEnabled: true}, + Tools: 
map[string]any{}, + }, + }) + spec := resolveEffectiveToolSpec(specCanvas().toDTO(), snapshot) + if spec.Enabled { + t.Fatalf("expected canvas tool disabled while remote node runtime is unavailable") + } + requirement, ok := findRequirement(spec.Requirements, "canvas.remote_runtime") + if !ok { + t.Fatalf("expected canvas remote runtime requirement") + } + if requirement.Available { + t.Fatalf("expected canvas remote runtime requirement to be unavailable") + } + if !strings.Contains(strings.ToLower(requirement.Reason), "not implemented") { + t.Fatalf("unexpected requirement reason: %q", requirement.Reason) + } +} + +func TestResolveEffectiveToolSpecImageRequiresConfiguredModel(t *testing.T) { + t.Parallel() + + snapshot := loadToolRequirementSnapshot(context.Background(), gatewayToolSettingsStub{ + settings: settingsdto.Settings{ + Gateway: settingsdto.GatewaySettings{ControlPlaneEnabled: true}, + Tools: map[string]any{}, + }, + }) + resolver := NewBuiltinRequirementResolver(BuiltinRequirementDeps{ + Providers: imageProviderRepoStub{items: map[string]providers.Provider{}}, + Models: imageModelRepoStub{items: map[string][]providers.Model{}}, + Secrets: imageSecretRepoStub{items: map[string]providers.ProviderSecret{}}, + }) + spec := resolveEffectiveToolSpecWithResolver(context.Background(), specImage().toDTO(), snapshot, resolver) + if spec.Enabled { + t.Fatalf("expected image tool disabled without a configured model") + } + requirement, ok := findRequirement(spec.Requirements, imageRequirementID) + if !ok { + t.Fatalf("expected image model requirement") + } + if requirement.Available { + t.Fatalf("expected image model requirement to be unavailable") + } + if !strings.Contains(strings.ToLower(requirement.Reason), "image model") { + t.Fatalf("unexpected requirement reason: %q", requirement.Reason) + } +} + +func TestResolveEffectiveToolSpecImageStaysEnabledWithConfiguredModel(t *testing.T) { + t.Parallel() + + snapshot := 
loadToolRequirementSnapshot(context.Background(), gatewayToolSettingsStub{ + settings: settingsdto.Settings{ + Gateway: settingsdto.GatewaySettings{ControlPlaneEnabled: true}, + Tools: map[string]any{}, + }, + }) + resolver := NewBuiltinRequirementResolver(BuiltinRequirementDeps{ + Providers: imageProviderRepoStub{items: map[string]providers.Provider{ + "openai": { + ID: "openai", + Enabled: true, + Type: providers.ProviderTypeOpenAI, + Endpoint: "https://api.openai.com/v1", + }, + }}, + Models: imageModelRepoStub{items: map[string][]providers.Model{ + "openai": { + { + ID: "gpt-4o", + Name: "gpt-4o", + Enabled: true, + SupportsVision: ptrBool(true), + }, + }, + }}, + Secrets: imageSecretRepoStub{items: map[string]providers.ProviderSecret{ + "openai": {ProviderID: "openai", APIKey: "test-key"}, + }}, + }) + spec := resolveEffectiveToolSpecWithResolver(context.Background(), specImage().toDTO(), snapshot, resolver) + if !spec.Enabled { + t.Fatalf("expected image tool to stay enabled with a configured model") + } + requirement, ok := findRequirement(spec.Requirements, imageRequirementID) + if !ok { + t.Fatalf("expected image model requirement") + } + if !requirement.Available { + t.Fatalf("expected image model requirement to be available") + } +} + +func TestResolveEffectiveToolSpecTTSRequiresRunnableProvider(t *testing.T) { + t.Parallel() + + snapshot := loadToolRequirementSnapshot(context.Background(), gatewayToolSettingsStub{ + settings: settingsdto.Settings{ + Gateway: settingsdto.GatewaySettings{ControlPlaneEnabled: true}, + Tools: map[string]any{}, + }, + }) + resolver := NewBuiltinRequirementResolver(BuiltinRequirementDeps{ + Voice: voiceStatusStub{ + status: gatewayvoice.TTSStatusResponse{ + Enabled: true, + Config: gatewayvoice.TTSConfig{ + ProviderID: "edge", + }, + Providers: []gatewayvoice.TTSProviderCatalogItem{ + {ProviderID: "edge", DisplayName: "Edge-TTS", Available: true}, + }, + }, + }, + }) + spec := 
resolveEffectiveToolSpecWithResolver(context.Background(), specTTS().toDTO(), snapshot, resolver) + if spec.Enabled { + t.Fatalf("expected tts tool disabled when only edge placeholder provider is selected") + } + requirement, ok := findRequirement(spec.Requirements, ttsProviderRequirementID) + if !ok { + t.Fatalf("expected tts provider requirement") + } + if requirement.Available { + t.Fatalf("expected tts provider requirement to be unavailable") + } + if !strings.Contains(strings.ToLower(requirement.Reason), "edge-tts") { + t.Fatalf("unexpected requirement reason: %q", requirement.Reason) + } +} + +func TestResolveEffectiveToolSpecTTSStaysEnabledWithConfiguredProvider(t *testing.T) { + t.Parallel() + + snapshot := loadToolRequirementSnapshot(context.Background(), gatewayToolSettingsStub{ + settings: settingsdto.Settings{ + Gateway: settingsdto.GatewaySettings{ControlPlaneEnabled: true}, + Tools: map[string]any{}, + }, + }) + resolver := NewBuiltinRequirementResolver(BuiltinRequirementDeps{ + Voice: voiceStatusStub{ + status: gatewayvoice.TTSStatusResponse{ + Enabled: true, + Config: gatewayvoice.TTSConfig{ + ProviderID: "openai", + }, + Providers: []gatewayvoice.TTSProviderCatalogItem{ + {ProviderID: "openai", DisplayName: "OpenAI", Available: true}, + }, + }, + }, + }) + spec := resolveEffectiveToolSpecWithResolver(context.Background(), specTTS().toDTO(), snapshot, resolver) + if !spec.Enabled { + t.Fatalf("expected tts tool to stay enabled with a configured provider") + } + requirement, ok := findRequirement(spec.Requirements, ttsProviderAPIKeyRequirementID) + if !ok { + t.Fatalf("expected tts provider api key requirement") + } + if !requirement.Available { + t.Fatalf("expected tts provider api key requirement to be available") + } +} + +func TestResolveEffectiveToolSpecTTSVoiceFeatureDisabled(t *testing.T) { + t.Parallel() + + snapshot := loadToolRequirementSnapshot(context.Background(), gatewayToolSettingsStub{ + settings: settingsdto.Settings{ + Gateway: 
settingsdto.GatewaySettings{ControlPlaneEnabled: true}, + Tools: map[string]any{}, + }, + }) + resolver := NewBuiltinRequirementResolver(BuiltinRequirementDeps{ + Voice: voiceStatusStub{ + status: gatewayvoice.TTSStatusResponse{ + Enabled: false, + Config: gatewayvoice.TTSConfig{ + ProviderID: "openai", + }, + Providers: []gatewayvoice.TTSProviderCatalogItem{ + {ProviderID: "openai", DisplayName: "OpenAI", Available: true}, + }, + }, + }, + }) + spec := resolveEffectiveToolSpecWithResolver(context.Background(), specTTS().toDTO(), snapshot, resolver) + if spec.Enabled { + t.Fatalf("expected tts tool disabled when voice feature is disabled") + } + requirement, ok := findRequirement(spec.Requirements, ttsVoiceEnabledRequirementID) + if !ok { + t.Fatalf("expected tts voice feature requirement") + } + if requirement.Available { + t.Fatalf("expected tts voice feature requirement to be unavailable") + } + if !strings.Contains(strings.ToLower(requirement.Reason), "voice is disabled") { + t.Fatalf("unexpected requirement reason: %q", requirement.Reason) + } +} + +func TestResolveEffectiveToolSpecTTSElevenLabsRequiresVoiceID(t *testing.T) { + t.Parallel() + + snapshot := loadToolRequirementSnapshot(context.Background(), gatewayToolSettingsStub{ + settings: settingsdto.Settings{ + Gateway: settingsdto.GatewaySettings{ControlPlaneEnabled: true}, + Tools: map[string]any{}, + }, + }) + resolver := NewBuiltinRequirementResolver(BuiltinRequirementDeps{ + Voice: voiceStatusStub{ + status: gatewayvoice.TTSStatusResponse{ + Enabled: true, + Config: gatewayvoice.TTSConfig{ + ProviderID: "elevenlabs", + }, + Providers: []gatewayvoice.TTSProviderCatalogItem{ + {ProviderID: "elevenlabs", DisplayName: "ElevenLabs", Available: true}, + }, + }, + }, + }) + spec := resolveEffectiveToolSpecWithResolver(context.Background(), specTTS().toDTO(), snapshot, resolver) + if spec.Enabled { + t.Fatalf("expected tts tool disabled when elevenlabs voice id is missing") + } + requirement, ok := 
findRequirement(spec.Requirements, ttsVoiceIDRequirementID) + if !ok { + t.Fatalf("expected tts voice id requirement") + } + if requirement.Available { + t.Fatalf("expected tts voice id requirement to be unavailable") + } + if !strings.Contains(strings.ToLower(requirement.Reason), "voice id") { + t.Fatalf("unexpected requirement reason: %q", requirement.Reason) + } +} + func TestPolicyPipelineDeniesUnavailableWebSearch(t *testing.T) { t.Parallel() diff --git a/internal/application/gateway/tools/web_tools.go b/internal/application/gateway/tools/web_tools.go index 5e0ebf4..e83b69f 100644 --- a/internal/application/gateway/tools/web_tools.go +++ b/internal/application/gateway/tools/web_tools.go @@ -10,7 +10,6 @@ import ( "io" "net/http" "net/url" - "os" "regexp" "strconv" "strings" @@ -18,32 +17,33 @@ import ( "time" md "github.com/JohannesKaufmann/html-to-markdown" + "github.com/PuerkitoBio/goquery" + "github.com/chromedp/cdproto/network" + "github.com/chromedp/chromedp" "github.com/hashicorp/go-retryablehttp" - "github.com/playwright-community/playwright-go" + "dreamcreator/internal/application/browsercdp" connectorsdto "dreamcreator/internal/application/connectors/dto" - domainweb "dreamcreator/internal/domain/web" + appcookies "dreamcreator/internal/application/cookies" + "dreamcreator/internal/application/sitepolicy" ) -const webFetchTypeBuiltin = "builtin" -const webFetchTypePlaywright = "playwright" - -const defaultWebFetchType = webFetchTypeBuiltin -const defaultWebFetchTimeoutSeconds = 20 -const defaultWebFetchMaxChars = 50000 -const defaultWebFetchMaxBodyBytes = 2 << 20 -const defaultWebFetchMaxRedirects = 3 -const defaultWebFetchRetryMax = 2 -const defaultWebFetchAcceptMarkdown = true -const defaultWebFetchPlaywrightMarkdown = true -const defaultWebFetchEnableUserAgent = true -const defaultWebFetchUserAgent = domainweb.DefaultBrowserRequestUserAgent -const defaultWebFetchAcceptLanguage = domainweb.DefaultBrowserRequestAcceptLanguage -const 
defaultWebSearchTimeoutSeconds = 30 -const defaultWebSearchCacheTtlMinutes = 15 -const defaultWebSearchCount = 5 -const maxWebSearchCount = 10 -const defaultWebSearchType = "api" +const webFetchTypeCDP = "cdp" + +const webFetchTypeBuiltin = webFetchTypeCDP + +const ( + defaultWebFetchType = webFetchTypeCDP + defaultWebFetchTimeoutSeconds = 20 + defaultWebFetchMaxChars = 50000 + defaultWebFetchMaxBodyBytes = 2 << 20 + defaultWebSearchTimeoutSeconds = 30 + defaultWebSearchCacheTtlMinutes = 15 + defaultWebSearchCount = 5 + maxWebSearchCount = 10 + defaultWebSearchType = "api" + defaultWebFetchContentSignalMain = "main_heuristic" +) const ( webStatusOK = "ok" @@ -74,6 +74,38 @@ var markdownListMarkerPattern = regexp.MustCompile(`(?m)^\s*[-*+]\s+`) var markdownBlockQuotePattern = regexp.MustCompile(`(?m)^\s*>\s*`) var markdownTableSepPattern = regexp.MustCompile(`(?m)^\s*\|?[-:\s|]+\|?\s*$`) +var defaultExtractorSelectors = []string{ + "article", + "main", + "[role=main]", + ".article", + ".post-content", + ".entry-content", + ".content", +} + +var defaultRemoveSelectors = []string{ + "script", + "style", + "noscript", + "svg", + "canvas", + "iframe", + "form", + "nav", + "aside", + "footer", + "header", + "[role=navigation]", + ".sidebar", + ".comments", + ".recommend", + ".recommendations", + ".related", + ".advertisement", + ".ads", +} + type webSearchResult struct { Title string `json:"title,omitempty"` URL string `json:"url,omitempty"` @@ -111,10 +143,9 @@ type tavilySearchResponse struct { Query string `json:"query"` Answer string `json:"answer"` Results []struct { - Title string `json:"title"` - URL string `json:"url"` - Content string `json:"content"` - RawContent interface{} `json:"raw_content"` + Title string `json:"title"` + URL string `json:"url"` + Content string `json:"content"` } `json:"results"` } @@ -151,20 +182,9 @@ type webFetchResult struct { } type webFetchOptions struct { - TimeoutSeconds int - MaxChars int - MaxBodyBytes int - MaxRedirects int - 
RetryMax int - AcceptMarkdown bool - EnableUserAgent bool - UserAgent string - AcceptLanguage string - Headers map[string]any -} - -type webFetchPlaywrightOptions struct { - Markdown bool + TimeoutSeconds int + MaxChars int + MaxBodyBytes int } type webFetchResponse struct { @@ -184,40 +204,46 @@ type ConnectorsReader interface { ListConnectors(ctx context.Context) ([]connectorsdto.Connector, error) } +type extractorResult struct { + Content string + ContentSignal string +} + func runWebFetchTool(settings SettingsReader, connectors ConnectorsReader) func(ctx context.Context, args string) (string, error) { return func(ctx context.Context, args string) (string, error) { payload, err := parseToolArgs(args) if err != nil { return "", err } - url := getStringArg(payload, "url", "href") - if url == "" { + targetURL := getStringArg(payload, "url", "href") + if targetURL == "" { return "", errors.New("url is required") } + config := resolveToolsConfig(ctx, settings) - enabled := true - if value, ok := resolveWebFetchConfigBool(config, "enabled"); ok { - enabled = value - } - if !enabled { + if enabled, ok := resolveWebFetchConfigBool(config, "enabled"); ok && !enabled { return "", errors.New("web_fetch disabled") } - method := strings.ToUpper(getStringArg(payload, "method")) + + method := strings.ToUpper(strings.TrimSpace(getStringArg(payload, "method"))) if method == "" { method = http.MethodGet } - fetchType, err := resolveWebFetchType(payload, config) - if err != nil { - return "", err + if method != http.MethodGet { + result := buildWebFetchToolResult(targetURL, webFetchResponse{}, errors.New("web_fetch only supports GET")) + return marshalResult(result), nil } + options := resolveWebFetchOptions(payload, config, defaultWebFetchTimeoutSeconds) - playwrightOptions := resolveWebFetchPlaywrightOptions(payload, config) - cookies, err := resolveConnectorCookiesForURL(ctx, connectors, url) + cookies, err := browsercdp.ResolveConnectorCookiesForURL(ctx, connectors, targetURL) 
if err != nil { return "", err } - response, err := fetchByWebFetchType(ctx, fetchType, method, url, cookies, options, playwrightOptions) - result := buildWebFetchToolResult(url, response, err) + preferredBrowser := resolveWebFetchPreferredBrowser(config) + headless := resolveWebFetchHeadless(config) + + response, err := fetchWithCDP(ctx, targetURL, cookies, options, preferredBrowser, headless) + result := buildWebFetchToolResult(targetURL, response, err) return marshalResult(result), nil } } @@ -239,6 +265,99 @@ func runWebSearchTool(settings SettingsReader, connectors ConnectorsReader) func } } +func fetchWithCDP( + ctx context.Context, + targetURL string, + cookies []appcookies.Record, + options webFetchOptions, + preferredBrowser string, + headless bool, +) (webFetchResponse, error) { + runtimeCtx, cancel := context.WithTimeout(ctx, time.Duration(options.TimeoutSeconds)*time.Second) + defer cancel() + + runtime, err := browsercdp.Start(runtimeCtx, browsercdp.LaunchOptions{ + PreferredBrowser: preferredBrowser, + Headless: headless, + }) + if err != nil { + return webFetchResponse{}, err + } + defer runtime.Stop() + + tabCtx, tabCancel := chromedp.NewContext(runtime.BrowserContext()) + defer tabCancel() + + var navResponse *network.Response + var finalURL string + var contentType string + var htmlContent string + + tasks := chromedp.Tasks{ + network.Enable(), + } + if len(cookies) > 0 { + tasks = append(tasks, chromedp.ActionFunc(func(ctx context.Context) error { + return browsercdp.SetCookies(ctx, targetURL, cookies) + })) + } + if err := chromedp.Run(tabCtx, tasks); err != nil { + return webFetchResponse{}, err + } + + navResponse, err = chromedp.RunResponse(tabCtx, chromedp.Navigate(strings.TrimSpace(targetURL))) + if err != nil { + return webFetchResponse{}, err + } + if navResponse != nil { + contentType = strings.TrimSpace(navResponse.MimeType) + } + if selector := sitepolicy.ReadySelectorForURL(targetURL); selector != "" { + waitCtx, waitCancel := 
context.WithTimeout(tabCtx, time.Duration(options.TimeoutSeconds)*time.Second) + _ = chromedp.Run(waitCtx, chromedp.WaitVisible(selector, chromedp.ByQuery)) + waitCancel() + } + if err := chromedp.Run(tabCtx, + chromedp.Location(&finalURL), + chromedp.OuterHTML("html", &htmlContent, chromedp.ByQuery), + ); err != nil { + return webFetchResponse{}, err + } + if finalURL == "" { + finalURL = targetURL + } + extracted := extractMainContent(htmlContent, finalURL) + content := strings.TrimSpace(extracted.Content) + truncated := false + if options.MaxBodyBytes > 0 && len(content) > options.MaxBodyBytes { + content = content[:options.MaxBodyBytes] + truncated = true + } + if options.MaxChars > 0 && len(content) > options.MaxChars { + content = content[:options.MaxChars] + truncated = true + } + + headersMap := map[string]string{} + if navResponse != nil { + for key, value := range navResponse.Headers { + headersMap[key] = fmt.Sprint(value) + } + } + + return webFetchResponse{ + URL: targetURL, + FinalURL: finalURL, + Status: statusCodeFromResponse(navResponse), + Headers: headersMap, + ContentType: contentType, + Content: content, + MarkdownTokens: estimateMarkdownTokens(content), + ContentSignal: extracted.ContentSignal, + Truncated: truncated, + }, nil +} + func runWebSearchWithFallback( ctx context.Context, payload toolArgs, @@ -250,87 +369,224 @@ func runWebSearchWithFallback( _ = connectors searchType := resolveWebSearchType(config) switch searchType { - case "api": - provider := strings.ToLower(strings.TrimSpace(getNestedString(config, "web", "search", "provider"))) - if provider == "" { - provider = "brave" - } - result, err := runWebSearchByAPI(ctx, payload, config, query, count) - if err != nil { - return webSearchResponse{ - Status: webStatusError, - Retryable: isTimeoutError(err), - NextAction: nextActionInspectErrorThenSwitch, - Message: trimToMaxChars(strings.TrimSpace(err.Error()), 260), - Provider: provider, - Quality: webQualityEmpty, - WebSearchAvailable: 
true, - Query: query, - } - } - quality := webQualitySufficient - if len(result.Results) == 0 { - quality = webQualityEmpty - } - return webSearchResponse{ - Status: webStatusOK, - Retryable: false, - NextAction: nextActionContinue, - Message: "search_completed", - Provider: provider, - Quality: quality, - WebSearchAvailable: true, - Query: query, - Results: result.Results, - Cached: result.Cached, - Data: map[string]any{ - "results_count": len(result.Results), - }, - } case "external_tools": return webSearchResponse{ Status: webStatusError, Retryable: false, NextAction: nextActionUseOtherToolsOrSkills, Message: "web_search_external_tools_not_configured", - Provider: "external_tools", + Provider: "web_search", Quality: webQualityEmpty, WebSearchAvailable: false, Query: query, } default: - return webSearchResponse{ - Status: webStatusError, - Retryable: false, - NextAction: nextActionUseOtherToolsOrSkills, - Message: "web_search_type_not_supported", - Provider: "web_search", - Quality: webQualityEmpty, - WebSearchAvailable: false, - Query: query, - Data: map[string]any{ - "type": searchType, - }, + result, err := runWebSearchByAPI(ctx, payload, config, query, count) + if err != nil { + return webSearchResponse{ + Status: webStatusError, + Retryable: false, + NextAction: nextActionUseOtherToolsOrSkills, + Message: err.Error(), + Provider: "web_search", + Quality: webQualityEmpty, + WebSearchAvailable: true, + Query: query, + } } + return result } } -func buildWebFetchToolResult(requestURL string, response webFetchResponse, err error) webFetchResult { - finalURL := strings.TrimSpace(response.FinalURL) - if finalURL == "" { - finalURL = strings.TrimSpace(requestURL) +func runWebSearchByAPI(ctx context.Context, payload toolArgs, config map[string]any, query string, count int) (webSearchResponse, error) { + provider := strings.ToLower(strings.TrimSpace(getNestedString(config, "web", "search", "provider"))) + if provider == "" { + provider = "brave" + } + cacheKey := 
fmt.Sprintf("%s:%s:%d", provider, strings.TrimSpace(query), count) + if cached, ok := loadWebSearchCache(cacheKey); ok { + cached.Cached = true + return cached, nil + } + + var ( + results []webSearchResult + err error + ) + switch provider { + case "tavily": + results, err = runTavilySearch(ctx, payload, config, query, count) + default: + results, err = runBraveSearch(ctx, payload, config, query, count) + } + if err != nil { + return webSearchResponse{}, err + } + response := webSearchResponse{ + Status: webStatusOK, + Retryable: false, + NextAction: nextActionContinue, + Message: "ok", + Provider: provider, + Quality: webQualitySufficient, + WebSearchAvailable: true, + Query: query, + Results: results, + } + storeWebSearchCache(cacheKey, response, resolveWebSearchCacheTTL(config)) + return response, nil +} + +func runBraveSearch(ctx context.Context, payload toolArgs, config map[string]any, query string, count int) ([]webSearchResult, error) { + apiKey := strings.TrimSpace(resolveWebSearchProviderAPIKey(config, "brave")) + if apiKey == "" { + return nil, fmt.Errorf("Brave API key is missing") + } + reqURL, err := url.Parse(braveSearchEndpoint) + if err != nil { + return nil, err + } + values := reqURL.Query() + values.Set("q", query) + values.Set("count", strconv.Itoa(count)) + if value := getStringArg(payload, "country"); value != "" { + values.Set("country", value) } + if value := getStringArg(payload, "search_lang", "searchLang"); value != "" { + values.Set("search_lang", value) + } + if value := getStringArg(payload, "ui_lang", "uiLang"); value != "" { + values.Set("ui_lang", value) + } + if value := getStringArg(payload, "freshness"); value != "" { + values.Set("freshness", value) + } + reqURL.RawQuery = values.Encode() + + request, err := retryablehttp.NewRequestWithContext(ctx, http.MethodGet, reqURL.String(), nil) + if err != nil { + return nil, err + } + request.Header.Set("Accept", "application/json") + request.Header.Set("X-Subscription-Token", apiKey) 
+ + client := retryablehttp.NewClient() + client.RetryMax = 1 + client.HTTPClient.Timeout = time.Duration(resolveWebSearchTimeoutSeconds(payload, config)) * time.Second + response, err := client.Do(request) + if err != nil { + return nil, err + } + defer response.Body.Close() + if response.StatusCode >= 400 { + body, _ := io.ReadAll(io.LimitReader(response.Body, 2048)) + return nil, fmt.Errorf("Brave search failed: %s", strings.TrimSpace(string(body))) + } + var payloadJSON braveSearchResponse + if err := json.NewDecoder(response.Body).Decode(&payloadJSON); err != nil { + return nil, err + } + results := make([]webSearchResult, 0, len(payloadJSON.Web.Results)) + for _, item := range payloadJSON.Web.Results { + results = append(results, webSearchResult{ + Title: strings.TrimSpace(item.Title), + URL: strings.TrimSpace(item.URL), + Description: strings.TrimSpace(item.Description), + Age: strings.TrimSpace(item.Age), + }) + } + return results, nil +} + +func runTavilySearch(ctx context.Context, payload toolArgs, config map[string]any, query string, count int) ([]webSearchResult, error) { + apiKey := strings.TrimSpace(resolveWebSearchProviderAPIKey(config, "tavily")) + if apiKey == "" { + return nil, fmt.Errorf("Tavily API key is missing") + } + requestBody := map[string]any{ + "api_key": apiKey, + "query": query, + "max_results": count, + } + if value := getStringArg(payload, "country"); value != "" { + requestBody["country"] = value + } + if value := getStringArg(payload, "search_depth", "searchDepth"); value != "" { + requestBody["search_depth"] = value + } + bodyBytes, err := json.Marshal(requestBody) + if err != nil { + return nil, err + } + request, err := retryablehttp.NewRequestWithContext(ctx, http.MethodPost, tavilySearchEndpoint, bytes.NewReader(bodyBytes)) + if err != nil { + return nil, err + } + request.Header.Set("Accept", "application/json") + request.Header.Set("Content-Type", "application/json") + + client := retryablehttp.NewClient() + client.RetryMax 
= 1 + client.HTTPClient.Timeout = time.Duration(resolveWebSearchTimeoutSeconds(payload, config)) * time.Second + response, err := client.Do(request) + if err != nil { + return nil, err + } + defer response.Body.Close() + if response.StatusCode >= 400 { + body, _ := io.ReadAll(io.LimitReader(response.Body, 2048)) + return nil, fmt.Errorf("Tavily search failed: %s", strings.TrimSpace(string(body))) + } + var payloadJSON tavilySearchResponse + if err := json.NewDecoder(response.Body).Decode(&payloadJSON); err != nil { + return nil, err + } + results := make([]webSearchResult, 0, len(payloadJSON.Results)) + for _, item := range payloadJSON.Results { + description := strings.TrimSpace(item.Content) + results = append(results, webSearchResult{ + Title: strings.TrimSpace(item.Title), + URL: strings.TrimSpace(item.URL), + Description: description, + }) + } + return results, nil +} + +func loadWebSearchCache(key string) (webSearchResponse, bool) { + webSearchCache.mu.RLock() + entry, ok := webSearchCache.entries[key] + webSearchCache.mu.RUnlock() + if !ok || time.Now().After(entry.expiresAt) { + return webSearchResponse{}, false + } + return entry.value, true +} + +func storeWebSearchCache(key string, value webSearchResponse, ttl time.Duration) { + if ttl <= 0 { + return + } + webSearchCache.mu.Lock() + webSearchCache.entries[key] = webSearchCacheEntry{ + value: value, + expiresAt: time.Now().Add(ttl), + } + webSearchCache.mu.Unlock() +} + +func buildWebFetchToolResult(rawURL string, response webFetchResponse, err error) webFetchResult { if err != nil { return webFetchResult{ Status: webStatusError, - Retryable: isTimeoutError(err), - NextAction: nextActionInspectErrorThenSwitch, - Message: trimToMaxChars(strings.TrimSpace(err.Error()), 260), + Retryable: false, + NextAction: nextActionUseOtherToolsOrSkills, + Message: err.Error(), Provider: "web_fetch", Quality: webQualityEmpty, WebSearchAvailable: true, - URL: strings.TrimSpace(requestURL), - FinalURL: finalURL, + URL: 
rawURL, + FinalURL: response.FinalURL, HTTPStatus: response.Status, ContentType: response.ContentType, Content: response.Content, @@ -338,14 +594,11 @@ func buildWebFetchToolResult(requestURL string, response webFetchResponse, err e ContentSignal: response.ContentSignal, Truncated: response.Truncated, Data: map[string]any{ - "error": trimToMaxChars(strings.TrimSpace(err.Error()), 260), - "httpStatus": response.Status, - "finalURL": finalURL, "timeoutStage": response.TimeoutStage, - "truncated": response.Truncated, }, } } + quality := webQualitySufficient if strings.TrimSpace(response.Content) == "" { quality = webQualityEmpty @@ -358,8 +611,8 @@ func buildWebFetchToolResult(requestURL string, response webFetchResponse, err e Provider: "web_fetch", Quality: quality, WebSearchAvailable: true, - URL: strings.TrimSpace(requestURL), - FinalURL: finalURL, + URL: rawURL, + FinalURL: response.FinalURL, HTTPStatus: response.Status, ContentType: response.ContentType, Content: response.Content, @@ -367,1326 +620,367 @@ func buildWebFetchToolResult(requestURL string, response webFetchResponse, err e ContentSignal: response.ContentSignal, Truncated: response.Truncated, Data: map[string]any{ - "httpStatus": response.Status, - "finalURL": finalURL, - "timeoutStage": response.TimeoutStage, - "truncated": response.Truncated, + "extractor": response.ContentSignal, + "browserSource": webFetchTypeCDP, + "timeoutStage": response.TimeoutStage, }, } } -func runWebSearchByAPI(ctx context.Context, payload toolArgs, config map[string]any, query string, count int) (webSearchResponse, error) { - provider := strings.ToLower(getNestedString(config, "web", "search", "provider")) - if provider == "" { - provider = "brave" +func resolveWebFetchType(payload toolArgs, config map[string]any) (string, error) { + if raw := strings.TrimSpace(getStringArg(payload, "type", "mode")); raw != "" { + normalized := normalizeWebFetchType(raw) + if normalized == "" { + return "", fmt.Errorf("unsupported 
web_fetch type: %s", raw) + } + return normalized, nil } - country := resolveWebSearchString(payload, config, "country", "US") - searchLang := resolveWebSearchString(payload, config, "search_lang", "") - uiLang := resolveWebSearchString(payload, config, "ui_lang", "") - freshness := resolveWebSearchString(payload, config, "freshness", "") - cacheTtlMinutes := resolveWebSearchInt(config, "cacheTtlMinutes", defaultWebSearchCacheTtlMinutes) - cacheKey := normalizeWebSearchCacheKey(provider, query, count, country, searchLang, uiLang, freshness) - if cacheTtlMinutes > 0 { - if cached, ok := readWebSearchCache(cacheKey); ok { - cached.Cached = true - return cached, nil + if fetchConfig := getNestedMap(config, "web_fetch"); fetchConfig != nil { + if raw := strings.TrimSpace(getStringArg(toolArgs(fetchConfig), "type", "mode")); raw != "" { + normalized := normalizeWebFetchType(raw) + if normalized == "" { + return "", fmt.Errorf("unsupported web_fetch type: %s", raw) + } + return normalized, nil } } - timeoutSeconds := resolveWebSearchInt(config, "timeoutSeconds", defaultWebSearchTimeoutSeconds) - timeoutCtx, cancel := context.WithTimeout(ctx, time.Duration(timeoutSeconds)*time.Second) - defer cancel() + return defaultWebFetchType, nil +} - var ( - response webSearchResponse - err error - ) - switch provider { - case "brave": - response, err = runBraveSearch(timeoutCtx, config, query, count, country, searchLang, uiLang, freshness, timeoutSeconds) - case "tavily": - response, err = runTavilySearch(timeoutCtx, config, payload, query, count, timeoutSeconds) +func normalizeWebFetchType(value string) string { + switch strings.ToLower(strings.TrimSpace(value)) { + case "", "cdp", "chrome", "chromium", "browser", "builtin": + return webFetchTypeCDP default: - return webSearchResponse{}, errors.New("web_search provider not implemented: " + provider) - } - if err != nil { - return webSearchResponse{}, err - } - response.Query = query - response.Provider = provider - if 
cacheTtlMinutes > 0 { - writeWebSearchCache(cacheKey, response, time.Duration(cacheTtlMinutes)*time.Minute) + return "" } - return response, nil } -func resolveWebSearchCount(payload toolArgs, config map[string]any) int { - count, ok := getIntArg(payload, "count", "maxResults", "max_results") - if !ok || count <= 0 { - count = resolveWebSearchInt(config, "maxResults", defaultWebSearchCount) +func resolveWebFetchOptions(payload toolArgs, config map[string]any, fallbackTimeoutSeconds int) webFetchOptions { + fetchConfig := getNestedMap(config, "web_fetch") + timeoutSeconds := fallbackTimeoutSeconds + if value, ok := getIntArg(toolArgs(fetchConfig), "timeoutSeconds"); ok && value > 0 { + timeoutSeconds = value } - if count <= 0 { - count = defaultWebSearchCount + if value, ok := getIntArg(payload, "timeoutSeconds"); ok && value > 0 { + timeoutSeconds = value } - if count > maxWebSearchCount { - count = maxWebSearchCount + maxChars := defaultWebFetchMaxChars + if value, ok := getIntArg(toolArgs(fetchConfig), "maxChars"); ok && value > 0 { + maxChars = value } - if count < 1 { - count = 1 + if value, ok := getIntArg(payload, "maxChars"); ok && value > 0 { + maxChars = value } - return count -} - -func resolveWebSearchString(payload toolArgs, config map[string]any, key string, fallback string) string { - value := getStringArg(payload, key) - if value == "" { - value = getNestedString(config, "web", "search", key) + maxBodyBytes := defaultWebFetchMaxBodyBytes + if value, ok := getIntArg(toolArgs(fetchConfig), "maxBodyBytes"); ok && value > 0 { + maxBodyBytes = value } - if value == "" { - return fallback + if value, ok := getIntArg(payload, "maxBodyBytes"); ok && value > 0 { + maxBodyBytes = value } - return value -} -func normalizeWebSearchType(value string) string { - normalized := strings.ToLower(strings.TrimSpace(value)) - switch normalized { - case "api": - return "api" - case "external_tools", "external-tools", "external tools": - return "external_tools" - default: 
- return "" + return webFetchOptions{ + TimeoutSeconds: timeoutSeconds, + MaxChars: maxChars, + MaxBodyBytes: maxBodyBytes, } } -func resolveWebSearchType(config map[string]any) string { - if value := normalizeWebSearchType(getNestedString(config, "web", "search", "type")); value != "" { +func connectorTypeForURL(rawURL string) string { + return browsercdp.ConnectorTypeForURL(rawURL) +} + +func resolveWebFetchConfigBool(config map[string]any, key string) (bool, bool) { + fetchConfig := getNestedMap(config, "web_fetch") + return getBoolArg(toolArgs(fetchConfig), key) +} + +func resolveWebFetchPreferredBrowser(config map[string]any) string { + fetchConfig := getNestedMap(config, "web_fetch") + if value := strings.TrimSpace(getStringArg(toolArgs(fetchConfig), "preferredBrowser")); value != "" { return value } - if enabled, ok := getNestedBool(config, "web", "search", "enabled"); ok && enabled { - return "api" - } - return defaultWebSearchType + browserConfig := getNestedMap(config, "browser") + return strings.TrimSpace(getStringArg(toolArgs(browserConfig), "preferredBrowser")) } -func normalizeConnectorType(value string) string { - normalized := strings.ToLower(strings.TrimSpace(value)) - switch normalized { - case "google": - return "google" - case "xiaohongshu", "xhs": - return "xiaohongshu" - case "bilibili", "b23": - return "bilibili" - default: - return "" +func resolveWebFetchHeadless(config map[string]any) bool { + fetchConfig := getNestedMap(config, "web_fetch") + if value, ok := getBoolArg(toolArgs(fetchConfig), "headless"); ok { + return value } + return true } -func normalizeWebFetchType(value string) string { - switch strings.ToLower(strings.TrimSpace(value)) { - case webFetchTypePlaywright: - return webFetchTypePlaywright - case webFetchTypeBuiltin: - return webFetchTypeBuiltin - default: - return "" - } -} - -func resolveWebFetchType(payload toolArgs, config map[string]any) (string, error) { - if raw := strings.TrimSpace(getStringArg(payload, "type", 
"mode")); raw != "" { - if value := normalizeWebFetchType(raw); value != "" { - return value, nil - } - return "", fmt.Errorf("unsupported web_fetch type: %s", raw) - } - fetchConfig := toolArgs(resolveWebFetchConfig(config)) - if raw := strings.TrimSpace(getStringArg(fetchConfig, "type", "mode")); raw != "" { - if value := normalizeWebFetchType(raw); value != "" { - return value, nil - } - return "", fmt.Errorf("unsupported web_fetch type: %s", raw) - } - return defaultWebFetchType, nil -} - -func resolveWebFetchPlaywrightOptions(payload toolArgs, config map[string]any) webFetchPlaywrightOptions { - markdown := defaultWebFetchPlaywrightMarkdown - fetchConfig := toolArgs(resolveWebFetchConfig(config)) - if playwrightConfig := getMapArg(fetchConfig, "playwright"); playwrightConfig != nil { - if value, ok := getBoolArg(toolArgs(playwrightConfig), "markdown", "toMarkdown"); ok { - markdown = value - } - } - if payloadPlaywright := getMapArg(payload, "playwright"); payloadPlaywright != nil { - if value, ok := getBoolArg(toolArgs(payloadPlaywright), "markdown", "toMarkdown"); ok { - markdown = value - } - } - if value, ok := getBoolArg(payload, "markdown", "toMarkdown"); ok { - markdown = value - } - return webFetchPlaywrightOptions{ - Markdown: markdown, - } -} - -func fetchByWebFetchType( - ctx context.Context, - fetchType string, - method string, - targetURL string, - cookies []connectorsdto.ConnectorCookie, - options webFetchOptions, - playwrightOptions webFetchPlaywrightOptions, -) (webFetchResponse, error) { - switch normalizeWebFetchType(fetchType) { - case webFetchTypePlaywright: - if strings.ToUpper(strings.TrimSpace(method)) != http.MethodGet { - return webFetchResponse{}, errors.New("web_fetch playwright mode only supports GET") - } - return fetchWithPlaywrightOptions(ctx, targetURL, cookies, options, playwrightOptions) - case webFetchTypeBuiltin: - return fetchWithBuiltinOptions(ctx, method, targetURL, cookies, options) - default: - return 
fetchWithBuiltinOptions(ctx, method, targetURL, cookies, options) - } -} - -func resolveConnectorCookiesForURL( - ctx context.Context, - connectors ConnectorsReader, - targetURL string, -) ([]connectorsdto.ConnectorCookie, error) { - if connectors == nil { - return nil, nil - } - connectorType := connectorTypeForURL(targetURL) - if connectorType == "" { - return nil, nil - } - items, err := connectors.ListConnectors(ctx) +func extractMainContent(rawHTML string, pageURL string) extractorResult { + policy, _ := sitepolicy.ForURL(pageURL) + reader := strings.NewReader(rawHTML) + doc, err := goquery.NewDocumentFromReader(reader) if err != nil { - return nil, err - } - for _, item := range items { - if normalizeConnectorType(item.Type) != connectorType { - continue + return extractorResult{ + Content: compactText(rawHTML), + ContentSignal: "body_fallback", } - if len(item.Cookies) == 0 { - continue - } - return item.Cookies, nil - } - return nil, nil -} - -func connectorTypeForURL(targetURL string) string { - parsed, err := url.Parse(strings.TrimSpace(targetURL)) - if err != nil { - return "" - } - host := strings.ToLower(strings.TrimSpace(parsed.Hostname())) - if host == "" { - return "" - } - switch { - case hostMatchesDomain(host, "google.com"), hostMatchesDomain(host, "youtube.com"), hostMatchesDomain(host, "youtu.be"): - return "google" - case hostMatchesDomain(host, "xiaohongshu.com"), hostMatchesDomain(host, "xhslink.com"), hostMatchesDomain(host, "redbook.com"): - return "xiaohongshu" - case hostMatchesDomain(host, "bilibili.com"), hostMatchesDomain(host, "b23.tv"): - return "bilibili" - default: - return "" - } -} - -func hostMatchesDomain(host string, domain string) bool { - normalizedHost := strings.TrimPrefix(strings.ToLower(strings.TrimSpace(host)), ".") - normalizedDomain := strings.TrimPrefix(strings.ToLower(strings.TrimSpace(domain)), ".") - if normalizedHost == "" || normalizedDomain == "" { - return false } - return normalizedHost == normalizedDomain 
|| strings.HasSuffix(normalizedHost, "."+normalizedDomain) -} - -func fetchWithPlaywrightOptions( - ctx context.Context, - targetURL string, - cookies []connectorsdto.ConnectorCookie, - options webFetchOptions, - playwrightOptions webFetchPlaywrightOptions, -) (webFetchResponse, error) { - if strings.TrimSpace(targetURL) == "" { - return webFetchResponse{}, errors.New("target url is required") - } - pw, err := playwright.Run() - if err != nil { - return webFetchResponse{}, err - } - defer pw.Stop() - - browser, err := pw.Chromium.Launch(playwright.BrowserTypeLaunchOptions{ - Headless: playwright.Bool(true), - Args: []string{ - "--headless=new", - }, - }) - if err != nil { - return webFetchResponse{}, err - } - defer browser.Close() - - contextOptions := playwright.BrowserNewContextOptions{ - Viewport: &playwright.Size{ - Width: 1366, - Height: 900, - }, - TimezoneId: playwright.String("UTC"), - } - if options.EnableUserAgent { - userAgent := strings.TrimSpace(options.UserAgent) - if userAgent == "" { - userAgent = defaultWebFetchUserAgent - } - contextOptions.UserAgent = playwright.String(userAgent) - } - if locale := localeFromAcceptLanguage(options.AcceptLanguage); locale != "" { - contextOptions.Locale = playwright.String(locale) - } - if headers := webFetchExtraHTTPHeaders(options); len(headers) > 0 { - contextOptions.ExtraHttpHeaders = headers - } - browserCtx, err := browser.NewContext(contextOptions) - if err != nil { - return webFetchResponse{}, err + for _, selector := range append(defaultRemoveSelectors, policy.RemoveSelectors...) 
{ + doc.Find(selector).Each(func(_ int, selection *goquery.Selection) { + selection.Remove() + }) } - defer browserCtx.Close() - if len(cookies) > 0 { - if err := browserCtx.AddCookies(toPlaywrightCookies(cookies, targetURL)); err != nil { - return webFetchResponse{}, err + root := pickBestContentRoot(doc, policy) + if root == nil || root.Length() == 0 { + body := strings.TrimSpace(doc.Find("body").First().Text()) + return extractorResult{ + Content: compactMarkdown(body), + ContentSignal: "body_fallback", } } - page, err := browserCtx.NewPage() - if err != nil { - return webFetchResponse{}, err - } - - timeoutMs := float64(resolveWebFetchTimeoutSeconds(options.TimeoutSeconds) * 1000) - navTimeoutMs := timeoutMs * 0.6 - if navTimeoutMs < 1000 { - navTimeoutMs = timeoutMs - } - readyTimeoutMs := timeoutMs * 0.25 - if readyTimeoutMs < 600 { - readyTimeoutMs = 600 - } - extractTimeoutMs := timeoutMs * 0.15 - if extractTimeoutMs < 500 { - extractTimeoutMs = 500 - } - response, err := page.Goto(strings.TrimSpace(targetURL), playwright.PageGotoOptions{ - Timeout: playwright.Float(navTimeoutMs), - WaitUntil: playwright.WaitUntilStateDomcontentloaded, - }) + htmlFragment, err := root.Html() if err != nil { - return webFetchResponse{}, err - } - - timeoutStage := "" - finalURL := strings.TrimSpace(page.URL()) - if finalURL == "" { - finalURL = strings.TrimSpace(targetURL) - } - if selector := resolvePlaywrightReadySelector(finalURL, targetURL); selector != "" { - if readyErr := waitForPlaywrightReady(page, selector, readyTimeoutMs); readyErr != nil { - timeoutStage = "ready" + text := compactMarkdown(root.Text()) + if text == "" { + return extractorResult{ContentSignal: "body_fallback"} } - } - - status := http.StatusOK - contentType := "" - headers := map[string]string{} - if response != nil { - status = response.Status() - headers = normalizeHTTPHeaderMap(response.Headers()) - contentType = strings.TrimSpace(headers["content-type"]) - if contentType == "" { - allHeaders, 
err := response.AllHeaders() - if err == nil { - for key, value := range allHeaders { - if _, exists := headers[strings.ToLower(strings.TrimSpace(key))]; !exists { - headers[strings.ToLower(strings.TrimSpace(key))] = strings.TrimSpace(value) - } - } - contentType = strings.TrimSpace(headers["content-type"]) - } - } - } - - content, err := contentWithTimeout(ctx, page, time.Duration(extractTimeoutMs)*time.Millisecond) - if err != nil { - return webFetchResponse{}, err - } - if playwrightOptions.Markdown { - if markdown, err := convertHTMLToMarkdown(content, finalURL); err == nil && strings.TrimSpace(markdown) != "" { - content = markdown - contentType = "text/markdown; charset=utf-8" + return extractorResult{ + Content: text, + ContentSignal: defaultWebFetchContentSignalMain, } } - content, truncated := truncateWebFetchContent([]byte(content), options.MaxChars) - - return webFetchResponse{ - URL: strings.TrimSpace(targetURL), - FinalURL: finalURL, - Status: status, - Headers: headers, - ContentType: contentType, - Content: content, - TimeoutStage: timeoutStage, - Truncated: truncated, - }, nil -} -func fetchWithBuiltinOptions( - ctx context.Context, - method string, - targetURL string, - cookies []connectorsdto.ConnectorCookie, - options webFetchOptions, -) (webFetchResponse, error) { - if strings.TrimSpace(targetURL) == "" { - return webFetchResponse{}, errors.New("target url is required") - } - httpMethod := strings.ToUpper(strings.TrimSpace(method)) - if httpMethod == "" { - httpMethod = http.MethodGet - } - request, err := retryablehttp.NewRequestWithContext(ctx, httpMethod, targetURL, nil) + converter := md.NewConverter("", true, nil) + markdown, err := converter.ConvertString(htmlFragment) if err != nil { - return webFetchResponse{}, err - } - applyWebFetchHeaders(request.Request, options) - if cookieHeader := buildCookieHeader(targetURL, cookies); cookieHeader != "" { - if existing := strings.TrimSpace(request.Header.Get("Cookie")); existing != "" { - 
request.Header.Set("Cookie", existing+"; "+cookieHeader) - } else { - request.Header.Set("Cookie", cookieHeader) - } + markdown = compactMarkdown(root.Text()) } - timeoutSeconds := resolveWebFetchTimeoutSeconds(options.TimeoutSeconds) - maxRedirects := resolveWebFetchMaxRedirects(options.MaxRedirects) - retryMax := resolveWebFetchRetryMax(options.RetryMax) - client := retryablehttp.NewClient() - client.RetryMax = retryMax - client.RetryWaitMin = 200 * time.Millisecond - client.RetryWaitMax = 2 * time.Second - client.Logger = nil - client.CheckRetry = func(ctx context.Context, resp *http.Response, err error) (bool, error) { - if !isWebFetchRetryableMethod(httpMethod) { - return false, err - } - return retryablehttp.DefaultRetryPolicy(ctx, resp, err) + markdown = compactMarkdown(markdown) + if markdown == "" { + markdown = compactMarkdown(root.Text()) } - client.HTTPClient.Timeout = time.Duration(timeoutSeconds) * time.Second - client.HTTPClient.CheckRedirect = func(_ *http.Request, via []*http.Request) error { - if len(via) > maxRedirects { - return errors.New("too many redirects") + if markdown == "" { + return extractorResult{ + Content: compactMarkdown(doc.Text()), + ContentSignal: "body_fallback", } - return nil - } - resp, err := client.Do(request) - if err != nil { - return webFetchResponse{}, err } - defer resp.Body.Close() - body, bodyTruncated, err := readWebFetchBodyLimited(resp.Body, options.MaxBodyBytes) - if err != nil { - return webFetchResponse{}, err - } - content, charTruncated := truncateWebFetchContent(body, options.MaxChars) - markdownTokens, _ := strconv.Atoi(strings.TrimSpace(resp.Header.Get("x-markdown-tokens"))) - finalURL := strings.TrimSpace(targetURL) - if resp.Request != nil && resp.Request.URL != nil { - if resolved := strings.TrimSpace(resp.Request.URL.String()); resolved != "" { - finalURL = resolved - } - } - return webFetchResponse{ - URL: strings.TrimSpace(targetURL), - FinalURL: finalURL, - Status: resp.StatusCode, - Headers: 
normalizeHTTPHeaderValues(resp.Header), - ContentType: strings.TrimSpace(resp.Header.Get("Content-Type")), - Content: content, - MarkdownTokens: markdownTokens, - ContentSignal: strings.TrimSpace(resp.Header.Get("content-signal")), - Truncated: bodyTruncated || charTruncated, - }, nil -} - -func isWebFetchRetryableMethod(method string) bool { - switch strings.ToUpper(strings.TrimSpace(method)) { - case http.MethodGet, http.MethodHead, http.MethodOptions: - return true - default: - return false - } -} - -func resolveWebFetchTimeoutSeconds(value int) int { - timeoutSeconds := value - if timeoutSeconds <= 0 { - timeoutSeconds = defaultWebFetchTimeoutSeconds - } - return timeoutSeconds -} - -func resolveWebFetchMaxRedirects(value int) int { - maxRedirects := value - if maxRedirects < 0 { - maxRedirects = 0 - } - return maxRedirects -} - -func resolveWebFetchRetryMax(value int) int { - retryMax := value - if retryMax < 0 { - retryMax = 0 + return extractorResult{ + Content: markdown, + ContentSignal: resolveContentSignal(policy, root), } - return retryMax } -func webFetchExtraHTTPHeaders(options webFetchOptions) map[string]string { - result := make(map[string]string) - for key, value := range options.Headers { - trimmedKey := strings.TrimSpace(key) - if trimmedKey == "" { - continue - } - result[trimmedKey] = toString(value) - } - if _, exists := result["Accept"]; !exists { - if options.AcceptMarkdown { - result["Accept"] = "text/markdown, text/html;q=0.9, application/xhtml+xml;q=0.8" - } else { - result["Accept"] = "text/html,application/xhtml+xml" +func pickBestContentRoot(doc *goquery.Document, policy sitepolicy.Policy) *goquery.Selection { + for _, selector := range policy.ExtractorSelectors { + selection := doc.Find(selector).First() + if selection.Length() > 0 && len(compactText(selection.Text())) >= 160 { + return selection } } - if _, exists := result["Accept-Language"]; !exists && strings.TrimSpace(options.AcceptLanguage) != "" { - result["Accept-Language"] = 
strings.TrimSpace(options.AcceptLanguage) - } - return result -} - -func localeFromAcceptLanguage(value string) string { - segments := strings.Split(strings.TrimSpace(value), ",") - for _, segment := range segments { - base := strings.TrimSpace(segment) - if base == "" { - continue + for _, selector := range defaultExtractorSelectors { + selection := doc.Find(selector).First() + if selection.Length() > 0 && len(compactText(selection.Text())) >= 160 { + return selection } - if index := strings.Index(base, ";"); index >= 0 { - base = strings.TrimSpace(base[:index]) - } - if base != "" { - return base - } - } - return "" -} - -func convertHTMLToMarkdown(content string, targetURL string) (string, error) { - converter := md.NewConverter(md.DomainFromURL(strings.TrimSpace(targetURL)), true, nil) - markdown, err := converter.ConvertString(content) - if err != nil { - return "", err - } - return strings.TrimSpace(markdown), nil -} - -func resolveWebFetchOptions(payload toolArgs, config map[string]any, timeoutFallback int) webFetchOptions { - fetchConfig := toolArgs(resolveWebFetchConfig(config)) - timeoutSeconds, ok := getIntArg(fetchConfig, "timeoutSeconds") - if !ok || timeoutSeconds <= 0 { - timeoutSeconds, _ = getIntArg(payload, "timeoutSeconds") - } - if timeoutSeconds <= 0 { - timeoutSeconds = timeoutFallback - } - if timeoutSeconds <= 0 { - timeoutSeconds = defaultWebFetchTimeoutSeconds - } - - maxChars, ok := getIntArg(fetchConfig, "maxChars") - if !ok || maxChars <= 0 { - maxChars, _ = getIntArg(payload, "maxChars") - } - if maxChars <= 0 { - maxChars = defaultWebFetchMaxChars - } - - maxBodyBytes, ok := getIntArg(fetchConfig, "maxBodyBytes") - if !ok || maxBodyBytes <= 0 { - maxBodyBytes, _ = getIntArg(payload, "maxBodyBytes") - } - if maxBodyBytes <= 0 { - maxBodyBytes = defaultWebFetchMaxBodyBytes } - maxRedirects := defaultWebFetchMaxRedirects - if value, present := getIntArg(fetchConfig, "maxRedirects"); present { - maxRedirects = value - } - if value, 
present := getIntArg(payload, "maxRedirects"); present { - maxRedirects = value - } - if maxRedirects < 0 { - maxRedirects = 0 - } - - retryMax := defaultWebFetchRetryMax - if value, present := getIntArg(fetchConfig, "retryMax"); present { - retryMax = value - } - if value, present := getIntArg(payload, "retryMax"); present { - retryMax = value - } - if retryMax < 0 { - retryMax = 0 - } - - acceptMarkdown, ok := getBoolArg(fetchConfig, "acceptMarkdown") - if !ok { - acceptMarkdown = defaultWebFetchAcceptMarkdown - } - if value, present := getBoolArg(payload, "acceptMarkdown", "preferMarkdown", "markdown"); present { - acceptMarkdown = value - } - - enableUserAgent, ok := getBoolArg(fetchConfig, "enableUserAgent") - if !ok { - enableUserAgent = defaultWebFetchEnableUserAgent - } - if value, present := getBoolArg(payload, "enableUserAgent", "useUserAgent"); present { - enableUserAgent = value - } - - userAgent := getStringArg(payload, "userAgent") - if userAgent == "" { - userAgent = getStringArg(fetchConfig, "userAgent") - } - if userAgent == "" { - userAgent = defaultWebFetchUserAgent - } - - acceptLanguage := getStringArg(payload, "acceptLanguage") - if acceptLanguage == "" { - acceptLanguage = getStringArg(fetchConfig, "acceptLanguage") - } - if acceptLanguage == "" { - acceptLanguage = defaultWebFetchAcceptLanguage - } - headers := mergeAnyMap(getMapArg(fetchConfig, "headers"), getMapArg(payload, "headers")) - - return webFetchOptions{ - TimeoutSeconds: timeoutSeconds, - MaxChars: maxChars, - MaxBodyBytes: maxBodyBytes, - MaxRedirects: maxRedirects, - RetryMax: retryMax, - AcceptMarkdown: acceptMarkdown, - EnableUserAgent: enableUserAgent, - UserAgent: strings.TrimSpace(userAgent), - AcceptLanguage: strings.TrimSpace(acceptLanguage), - Headers: headers, - } -} - -func resolveWebFetchConfig(config map[string]any) map[string]any { - current := getNestedMap(config, "web_fetch") - if current == nil { - return nil - } - return current -} - -func 
resolveWebFetchConfigBool(config map[string]any, key string) (bool, bool) { - return getBoolArg(toolArgs(resolveWebFetchConfig(config)), key) -} - -func applyWebFetchHeaders(request *http.Request, options webFetchOptions) { - if request == nil { - return - } - for key, value := range options.Headers { - trimmedKey := strings.TrimSpace(key) - if trimmedKey == "" { - continue - } - request.Header.Set(trimmedKey, toString(value)) - } - if strings.TrimSpace(request.Header.Get("Accept")) == "" { - if options.AcceptMarkdown { - request.Header.Set("Accept", "text/markdown, text/html;q=0.9, application/xhtml+xml;q=0.8") - } else { - request.Header.Set("Accept", "text/html,application/xhtml+xml") - } - } - if strings.TrimSpace(request.Header.Get("User-Agent")) == "" && options.EnableUserAgent { - userAgent := strings.TrimSpace(options.UserAgent) - if userAgent == "" { - userAgent = defaultWebFetchUserAgent + var ( + best *goquery.Selection + bestScore float64 + ) + doc.Find("article,main,section,div").Each(func(_ int, selection *goquery.Selection) { + text := compactText(selection.Text()) + textLen := len(text) + if textLen < 160 { + return + } + linkTextLen := len(compactText(selection.Find("a").Text())) + linkDensity := 0.0 + if textLen > 0 { + linkDensity = float64(linkTextLen) / float64(textLen) + } + paragraphs := selection.Find("p").Length() + score := float64(textLen) + float64(paragraphs*80) - (linkDensity * 800) + if score > bestScore { + bestScore = score + best = selection } - request.Header.Set("User-Agent", userAgent) - } else if strings.TrimSpace(request.Header.Get("User-Agent")) == "" && !options.EnableUserAgent { - // Prevent net/http from injecting the default Go user-agent when disabled. 
- request.Header.Set("User-Agent", "") + }) + if best != nil { + return best } - if strings.TrimSpace(request.Header.Get("Accept-Language")) == "" && strings.TrimSpace(options.AcceptLanguage) != "" { - request.Header.Set("Accept-Language", strings.TrimSpace(options.AcceptLanguage)) + body := doc.Find("body").First() + if body.Length() > 0 { + return body } + return nil } -func truncateWebFetchContent(body []byte, maxChars int) (string, bool) { - content := string(body) - if maxChars <= 0 { - return content, false +func resolveContentSignal(policy sitepolicy.Policy, root *goquery.Selection) string { + if root == nil || root.Length() == 0 { + return "body_fallback" } - runeCount := 0 - for index := range content { - if runeCount >= maxChars { - return content[:index], true + nodeName := goquery.NodeName(root) + switch nodeName { + case "article": + return "article_readability" + case "main": + return "main_heuristic" + default: + if policy.Key != "" { + return defaultWebFetchContentSignalMain } - runeCount++ - } - if runeCount <= maxChars { - return content, false - } - return content, false -} - -func readWebFetchBodyLimited(reader io.Reader, maxBytes int) ([]byte, bool, error) { - if reader == nil { - return nil, false, nil + return defaultWebFetchContentSignalMain } - if maxBytes <= 0 { - maxBytes = defaultWebFetchMaxBodyBytes - } - limited := io.LimitReader(reader, int64(maxBytes)+1) - body, err := io.ReadAll(limited) - if err != nil { - return nil, false, err - } - if len(body) > maxBytes { - return body[:maxBytes], true, nil - } - return body, false, nil } -func buildCookieHeader(targetURL string, cookies []connectorsdto.ConnectorCookie) string { - if len(cookies) == 0 { - return "" - } - parsed, err := url.Parse(targetURL) - if err != nil { - return "" - } - host := strings.ToLower(strings.TrimSpace(parsed.Hostname())) - path := strings.TrimSpace(parsed.Path) - if path == "" { - path = "/" - } - now := time.Now().Unix() - pairs := make([]string, 0, 
len(cookies)) - for _, cookie := range cookies { - name := strings.TrimSpace(cookie.Name) - if name == "" { +func compactMarkdown(input string) string { + text := strings.ReplaceAll(input, "\r\n", "\n") + text = markdownCodeFencePattern.ReplaceAllStringFunc(text, func(code string) string { + return strings.TrimSpace(code) + }) + lines := strings.Split(text, "\n") + result := make([]string, 0, len(lines)) + previousBlank := false + seen := map[string]struct{}{} + for _, line := range lines { + trimmed := strings.TrimSpace(line) + if trimmed == "" { + if previousBlank { + continue + } + previousBlank = true + result = append(result, "") continue } - if cookie.Expires > 0 && cookie.Expires < now { + previousBlank = false + normalized := htmlSpacePattern.ReplaceAllString(trimmed, " ") + if strings.HasPrefix(normalized, "Recommended") || strings.HasPrefix(normalized, "Related") { continue } - if !cookieDomainMatches(host, cookie.Domain) { - continue + if len(normalized) > 4000 { + normalized = normalized[:4000] } - if !cookiePathMatches(path, cookie.Path) { + key := strings.ToLower(normalized) + if _, exists := seen[key]; exists && len(normalized) > 48 { continue } - pairs = append(pairs, name+"="+cookie.Value) - } - return strings.Join(pairs, "; ") -} - -func toPlaywrightCookies(cookies []connectorsdto.ConnectorCookie, targetURL string) []playwright.OptionalCookie { - if len(cookies) == 0 { - return nil - } - result := make([]playwright.OptionalCookie, 0, len(cookies)) - for _, item := range cookies { - name := strings.TrimSpace(item.Name) - if name == "" { - continue - } - cookie := playwright.OptionalCookie{ - Name: name, - Value: item.Value, - } - domain := strings.TrimSpace(item.Domain) - path := strings.TrimSpace(item.Path) - if domain != "" { - cookie.Domain = playwright.String(domain) - } else if strings.TrimSpace(targetURL) != "" { - cookie.URL = playwright.String(strings.TrimSpace(targetURL)) - } - if path != "" { - cookie.Path = playwright.String(path) - } - 
if item.Expires > 0 { - cookie.Expires = playwright.Float(float64(item.Expires)) + if len(normalized) > 48 { + seen[key] = struct{}{} } - cookie.HttpOnly = playwright.Bool(item.HttpOnly) - cookie.Secure = playwright.Bool(item.Secure) - if sameSite := toPlaywrightSameSite(item.SameSite); sameSite != nil { - cookie.SameSite = sameSite - } - result = append(result, cookie) + result = append(result, normalized) } - return result + return strings.TrimSpace(strings.Join(result, "\n")) } -func toPlaywrightSameSite(value string) *playwright.SameSiteAttribute { - switch strings.ToLower(strings.TrimSpace(value)) { - case "lax": - return playwright.SameSiteAttributeLax - case "strict": - return playwright.SameSiteAttributeStrict - case "none": - return playwright.SameSiteAttributeNone - default: - return nil - } -} - -func cookieDomainMatches(host string, cookieDomain string) bool { - domain := strings.ToLower(strings.TrimSpace(cookieDomain)) - if domain == "" { - return true - } - domain = strings.TrimPrefix(domain, ".") - if domain == "" { - return true - } - return host == domain || strings.HasSuffix(host, "."+domain) +func compactText(input string) string { + text := html.UnescapeString(input) + text = htmlSpacePattern.ReplaceAllString(strings.TrimSpace(text), " ") + return text } -func cookiePathMatches(requestPath string, cookiePath string) bool { - path := strings.TrimSpace(cookiePath) - if path == "" || path == "/" { - return true +func estimateMarkdownTokens(content string) int { + trimmed := strings.TrimSpace(content) + if trimmed == "" { + return 0 } - return strings.HasPrefix(requestPath, path) + return (len(trimmed) + 3) / 4 } -func extractHTMLTitle(content string) string { - if strings.TrimSpace(content) == "" { - return "" +func statusCodeFromResponse(response *network.Response) int { + if response == nil { + return 0 } - matches := htmlTitlePattern.FindStringSubmatch(content) - if len(matches) < 2 { - return "" - } - return compactPlainText(matches[1], 120) + 
return int(response.Status) } -func extractHTMLSnippet(content string, limit int) string { - if strings.TrimSpace(content) == "" { +func extractWebPageTitle(content string, contentType string) string { + trimmed := strings.TrimSpace(content) + if trimmed == "" { return "" } - sanitized := htmlScriptStylePattern.ReplaceAllString(content, " ") - sanitized = htmlTagPattern.ReplaceAllString(sanitized, " ") - return compactPlainText(sanitized, limit) -} - -func extractWebPageTitle(content string, contentType string) string { if strings.Contains(strings.ToLower(contentType), "markdown") { - if title := extractMarkdownTitle(content); title != "" { - return title + if matches := markdownFrontMatterTitlePattern.FindStringSubmatch(trimmed); len(matches) > 1 { + return strings.TrimSpace(matches[1]) } - } - if title := extractHTMLTitle(content); title != "" { - return title - } - return extractMarkdownTitle(content) -} - -func extractWebPageSnippet(content string, contentType string, limit int) string { - if strings.Contains(strings.ToLower(contentType), "markdown") { - if snippet := extractMarkdownSnippet(content, limit); snippet != "" { - return snippet + if matches := markdownHeadingPattern.FindStringSubmatch(trimmed); len(matches) > 1 { + return strings.TrimSpace(matches[1]) } - } - if snippet := extractHTMLSnippet(content, limit); snippet != "" { - return snippet - } - return extractMarkdownSnippet(content, limit) -} - -func extractMarkdownTitle(content string) string { - if strings.TrimSpace(content) == "" { return "" } - if strings.HasPrefix(strings.TrimSpace(content), "---") { - matches := markdownFrontMatterTitlePattern.FindStringSubmatch(content) - if len(matches) >= 2 { - return compactPlainText(matches[1], 120) - } - } - matches := markdownHeadingPattern.FindStringSubmatch(content) - if len(matches) >= 2 { - return compactPlainText(matches[1], 120) + if matches := htmlTitlePattern.FindStringSubmatch(trimmed); len(matches) > 1 { + return compactText(matches[1]) } 
return "" } -func extractMarkdownSnippet(content string, limit int) string { - if strings.TrimSpace(content) == "" { - return "" - } - normalized := markdownCodeFencePattern.ReplaceAllString(content, " ") - normalized = markdownImagePattern.ReplaceAllString(normalized, "$1") - normalized = markdownLinkPattern.ReplaceAllString(normalized, "$1") - normalized = markdownHeadingMarkerPattern.ReplaceAllString(normalized, "") - normalized = markdownListMarkerPattern.ReplaceAllString(normalized, "") - normalized = markdownBlockQuotePattern.ReplaceAllString(normalized, "") - normalized = markdownTableSepPattern.ReplaceAllString(normalized, " ") - normalized = strings.NewReplacer("*", " ", "_", " ", "`", " ", "~", " ").Replace(normalized) - return compactPlainText(normalized, limit) -} - -func compactPlainText(value string, limit int) string { - if limit <= 0 { - limit = 320 - } - text := html.UnescapeString(value) - text = htmlSpacePattern.ReplaceAllString(strings.TrimSpace(text), " ") - if len(text) > limit { - return strings.TrimSpace(text[:limit]) + "..." 
- } - return text -} - -func resolveWebSearchProviderMap(config map[string]any, provider string) map[string]any { - if provider == "" { - return nil - } - return getNestedMap(config, "web", "search", "providers", provider) -} - -func resolveWebSearchProviderString(config map[string]any, provider string, key string) string { - if key == "" { +func extractWebPageSnippet(content string, contentType string, maxChars int) string { + trimmed := strings.TrimSpace(content) + if trimmed == "" { return "" } - providerConfig := resolveWebSearchProviderMap(config, provider) - if providerConfig == nil { - return "" - } - value, ok := providerConfig[key] - if !ok { - return "" - } - if str, ok := value.(string); ok { - return strings.TrimSpace(str) - } - return "" -} - -func resolveWebSearchInt(config map[string]any, key string, fallback int) int { - value, ok := getNestedInt(config, "web", "search", key) - if !ok || value <= 0 { - return fallback - } - return value -} - -func normalizeWebSearchCacheKey(provider string, query string, count int, country string, searchLang string, uiLang string, freshness string) string { - parts := []string{ - strings.ToLower(strings.TrimSpace(provider)), - strings.ToLower(strings.TrimSpace(query)), - strings.ToLower(strings.TrimSpace(country)), - strings.ToLower(strings.TrimSpace(searchLang)), - strings.ToLower(strings.TrimSpace(uiLang)), - strings.ToLower(strings.TrimSpace(freshness)), - strconv.Itoa(count), - } - return strings.Join(parts, "|") -} - -func readWebSearchCache(key string) (webSearchResponse, bool) { - webSearchCache.mu.RLock() - entry, ok := webSearchCache.entries[key] - webSearchCache.mu.RUnlock() - if !ok { - return webSearchResponse{}, false - } - if time.Now().After(entry.expiresAt) { - webSearchCache.mu.Lock() - delete(webSearchCache.entries, key) - webSearchCache.mu.Unlock() - return webSearchResponse{}, false - } - return entry.value, true -} - -func writeWebSearchCache(key string, value webSearchResponse, ttl 
time.Duration) { - if ttl <= 0 { - return - } - webSearchCache.mu.Lock() - webSearchCache.entries[key] = webSearchCacheEntry{ - value: value, - expiresAt: time.Now().Add(ttl), - } - webSearchCache.mu.Unlock() -} - -func runBraveSearch(ctx context.Context, config map[string]any, query string, count int, country string, searchLang string, uiLang string, freshness string, timeoutSeconds int) (webSearchResponse, error) { - apiKey := resolveWebSearchProviderString(config, "brave", "apiKey") - if apiKey == "" { - apiKey = getNestedString(config, "web", "search", "brave", "apiKey") - } - if apiKey == "" { - apiKey = getNestedString(config, "web", "search", "apiKey") - } - if apiKey == "" { - apiKey = strings.TrimSpace(os.Getenv("BRAVE_API_KEY")) - } - if apiKey == "" { - return webSearchResponse{}, errors.New("web_search needs a Brave API key") - } - endpoint := resolveWebSearchProviderString(config, "brave", "apiBaseUrl") - if endpoint == "" { - endpoint = getNestedString(config, "web", "search", "brave", "baseUrl") - } - if endpoint == "" { - endpoint = getNestedString(config, "web", "search", "baseUrl") - } - if endpoint == "" { - endpoint = braveSearchEndpoint - } - values := url.Values{} - values.Set("q", query) - if count > 0 { - values.Set("count", strconv.Itoa(count)) - } - if country != "" { - values.Set("country", country) - } - if searchLang != "" { - values.Set("search_lang", searchLang) - } - if uiLang != "" { - values.Set("ui_lang", uiLang) - } - if freshness != "" { - values.Set("freshness", freshness) - } - request, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint+"?"+values.Encode(), nil) - if err != nil { - return webSearchResponse{}, err - } - request.Header.Set("Accept", "application/json") - request.Header.Set("X-Subscription-Token", apiKey) - timeout := time.Duration(timeoutSeconds) * time.Second - if timeout <= 0 { - timeout = time.Duration(defaultWebSearchTimeoutSeconds) * time.Second - } - client := &http.Client{Timeout: timeout} - 
resp, err := client.Do(request) - if err != nil { - return webSearchResponse{}, err - } - defer resp.Body.Close() - body, err := io.ReadAll(resp.Body) - if err != nil { - return webSearchResponse{}, err - } - if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusBadRequest { - message := strings.TrimSpace(string(body)) - if message == "" { - message = resp.Status - } - return webSearchResponse{}, fmt.Errorf("HTTP %d: %s", resp.StatusCode, message) - } - var parsed braveSearchResponse - if err := json.Unmarshal(body, &parsed); err != nil { - return webSearchResponse{}, err - } - results := make([]webSearchResult, 0, len(parsed.Web.Results)) - for _, item := range parsed.Web.Results { - results = append(results, webSearchResult{ - Title: strings.TrimSpace(item.Title), - URL: strings.TrimSpace(item.URL), - Description: strings.TrimSpace(item.Description), - Age: strings.TrimSpace(item.Age), - }) - } - if count > 0 && len(results) > count { - results = results[:count] - } - return webSearchResponse{ - Query: query, - Provider: "brave", - Results: results, - }, nil -} - -func runTavilySearch(ctx context.Context, config map[string]any, payload toolArgs, query string, count int, timeoutSeconds int) (webSearchResponse, error) { - apiKey := resolveWebSearchProviderString(config, "tavily", "apiKey") - if apiKey == "" { - apiKey = getNestedString(config, "web", "search", "tavily", "apiKey") - } - if apiKey == "" { - apiKey = getNestedString(config, "web", "search", "apiKey") - } - if apiKey == "" { - apiKey = strings.TrimSpace(os.Getenv("TAVILY_API_KEY")) - } - if apiKey == "" { - return webSearchResponse{}, errors.New("web_search needs a Tavily API key") - } - endpoint := resolveWebSearchProviderString(config, "tavily", "apiBaseUrl") - if endpoint == "" { - endpoint = getNestedString(config, "web", "search", "tavily", "baseUrl") - } - if endpoint == "" { - endpoint = getNestedString(config, "web", "search", "baseUrl") - } - if endpoint == "" { - endpoint = 
tavilySearchEndpoint - } - - requestBody := map[string]any{ - "query": query, - } - if count > 0 { - requestBody["max_results"] = count - } - if value := getStringArg(payload, "search_depth", "searchDepth"); value != "" { - requestBody["search_depth"] = value - } - if value := getStringArg(payload, "topic"); value != "" { - requestBody["topic"] = value - } - if value := getStringArg(payload, "time_range", "timeRange"); value != "" { - requestBody["time_range"] = value - } - if value := getStringArg(payload, "start_date", "startDate"); value != "" { - requestBody["start_date"] = value - } - if value := getStringArg(payload, "end_date", "endDate"); value != "" { - requestBody["end_date"] = value - } - if value := getStringArg(payload, "country"); value != "" { - requestBody["country"] = value - } - if value := getStringArg(payload, "include_answer", "includeAnswer"); value != "" { - requestBody["include_answer"] = value - } else if includeAnswer, ok := getBoolArg(payload, "include_answer", "includeAnswer"); ok { - requestBody["include_answer"] = includeAnswer - } - if value := getStringArg(payload, "include_raw_content", "includeRawContent"); value != "" { - requestBody["include_raw_content"] = value - } else if includeRaw, ok := getBoolArg(payload, "include_raw_content", "includeRawContent"); ok { - requestBody["include_raw_content"] = includeRaw - } - if includeImages, ok := getBoolArg(payload, "include_images", "includeImages"); ok { - requestBody["include_images"] = includeImages - } - if includeDescriptions, ok := getBoolArg(payload, "include_image_descriptions", "includeImageDescriptions"); ok { - requestBody["include_image_descriptions"] = includeDescriptions - } - if includeFavicon, ok := getBoolArg(payload, "include_favicon", "includeFavicon"); ok { - requestBody["include_favicon"] = includeFavicon - } - if autoParams, ok := getBoolArg(payload, "auto_parameters", "autoParameters"); ok { - requestBody["auto_parameters"] = autoParams - } - if chunks, ok := 
getIntArg(payload, "chunks_per_source", "chunksPerSource"); ok && chunks > 0 { - requestBody["chunks_per_source"] = chunks - } - if includeDomains := getStringSliceArg(payload, "include_domains", "includeDomains"); len(includeDomains) > 0 { - requestBody["include_domains"] = includeDomains - } - if excludeDomains := getStringSliceArg(payload, "exclude_domains", "excludeDomains"); len(excludeDomains) > 0 { - requestBody["exclude_domains"] = excludeDomains - } - - encoded, err := json.Marshal(requestBody) - if err != nil { - return webSearchResponse{}, err - } - request, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(encoded)) - if err != nil { - return webSearchResponse{}, err - } - request.Header.Set("Accept", "application/json") - request.Header.Set("Content-Type", "application/json") - request.Header.Set("Authorization", "Bearer "+apiKey) - - timeout := time.Duration(timeoutSeconds) * time.Second - if timeout <= 0 { - timeout = time.Duration(defaultWebSearchTimeoutSeconds) * time.Second - } - client := &http.Client{Timeout: timeout} - resp, err := client.Do(request) - if err != nil { - return webSearchResponse{}, err - } - defer resp.Body.Close() - body, err := io.ReadAll(resp.Body) - if err != nil { - return webSearchResponse{}, err - } - if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusBadRequest { - message := strings.TrimSpace(string(body)) - if message == "" { - message = resp.Status - } - return webSearchResponse{}, fmt.Errorf("HTTP %d: %s", resp.StatusCode, message) - } - var parsed tavilySearchResponse - if err := json.Unmarshal(body, &parsed); err != nil { - return webSearchResponse{}, err - } - results := make([]webSearchResult, 0, len(parsed.Results)) - for _, item := range parsed.Results { - description := strings.TrimSpace(item.Content) - if description == "" { - if raw, ok := item.RawContent.(string); ok { - description = strings.TrimSpace(raw) - } - } - results = append(results, webSearchResult{ 
- Title: strings.TrimSpace(item.Title), - URL: strings.TrimSpace(item.URL), - Description: description, - }) - } - if count > 0 && len(results) > count { - results = results[:count] - } - return webSearchResponse{ - Query: query, - Provider: "tavily", - Results: results, - }, nil -} - -func isTimeoutError(err error) bool { - if err == nil { - return false + var snippet string + if strings.Contains(strings.ToLower(contentType), "markdown") { + snippet = markdownCodeFencePattern.ReplaceAllString(trimmed, "") + snippet = markdownImagePattern.ReplaceAllString(snippet, "$1") + snippet = markdownLinkPattern.ReplaceAllString(snippet, "$1") + snippet = markdownHeadingMarkerPattern.ReplaceAllString(snippet, "") + snippet = markdownListMarkerPattern.ReplaceAllString(snippet, "") + snippet = markdownBlockQuotePattern.ReplaceAllString(snippet, "") + snippet = markdownTableSepPattern.ReplaceAllString(snippet, "") + } else { + snippet = htmlScriptStylePattern.ReplaceAllString(trimmed, "") + snippet = htmlTagPattern.ReplaceAllString(snippet, " ") } - if errors.Is(err, context.DeadlineExceeded) { - return true + snippet = compactText(snippet) + if maxChars > 0 && len(snippet) > maxChars { + return snippet[:maxChars] } - lower := strings.ToLower(strings.TrimSpace(err.Error())) - return strings.Contains(lower, "timeout") || strings.Contains(lower, "timed out") + return snippet } -func trimToMaxChars(value string, max int) string { - trimmed := strings.TrimSpace(value) - if max <= 0 { - return trimmed - } - runes := []rune(trimmed) - if len(runes) <= max { - return trimmed +func resolveWebSearchType(config map[string]any) string { + value := strings.ToLower(strings.TrimSpace(getNestedString(config, "web", "search", "type"))) + switch value { + case "external_tools": + return "external_tools" + case "api": + return "api" + default: + return defaultWebSearchType } - return string(runes[:max]) } -func normalizeHTTPHeaderValues(header http.Header) map[string]string { - result := 
make(map[string]string, len(header)) - for key, values := range header { - trimmedKey := strings.ToLower(strings.TrimSpace(key)) - if trimmedKey == "" { - continue +func resolveWebSearchCount(payload toolArgs, config map[string]any) int { + if value, ok := getIntArg(payload, "count", "maxResults"); ok && value > 0 { + if value > maxWebSearchCount { + return maxWebSearchCount } - result[trimmedKey] = strings.TrimSpace(strings.Join(values, ", ")) + return value } - return result -} - -func normalizeHTTPHeaderMap(headers map[string]string) map[string]string { - result := make(map[string]string, len(headers)) - for key, value := range headers { - trimmedKey := strings.ToLower(strings.TrimSpace(key)) - if trimmedKey == "" { - continue + if value, ok := getIntArg(toolArgs(getNestedMap(config, "web", "search")), "maxResults"); ok && value > 0 { + if value > maxWebSearchCount { + return maxWebSearchCount } - result[trimmedKey] = strings.TrimSpace(value) - } - return result -} - -func resolvePlaywrightReadySelector(finalURL string, targetURL string) string { - host := extractHostname(finalURL) - if host == "" { - host = extractHostname(targetURL) - } - switch { - case hostMatchesDomain(host, "google.com"): - return "#search div.g, #search a h3" - case hostMatchesDomain(host, "xiaohongshu.com"): - return ".note-item, .search-result, section" - default: - return "main, article, body" - } -} - -func waitForPlaywrightReady(page playwright.Page, selector string, timeoutMs float64) error { - if strings.TrimSpace(selector) == "" { - return nil + return value } - _, err := page.WaitForSelector(selector, playwright.PageWaitForSelectorOptions{ - Timeout: playwright.Float(timeoutMs), - }) - return err + return defaultWebSearchCount } -func contentWithTimeout(ctx context.Context, page playwright.Page, timeout time.Duration) (string, error) { - type result struct { - content string - err error +func resolveWebSearchTimeoutSeconds(payload toolArgs, config map[string]any) int { + if 
value, ok := getIntArg(payload, "timeoutSeconds"); ok && value > 0 { + return value } - done := make(chan result, 1) - go func() { - content, err := page.Content() - done <- result{content: content, err: err} - }() - - timer := time.NewTimer(timeout) - defer timer.Stop() - select { - case <-ctx.Done(): - return "", ctx.Err() - case <-timer.C: - return "", errors.New("playwright extract timeout") - case output := <-done: - return output.content, output.err + if value, ok := getIntArg(toolArgs(getNestedMap(config, "web", "search")), "timeoutSeconds"); ok && value > 0 { + return value } + return defaultWebSearchTimeoutSeconds } -func extractHostname(rawURL string) string { - parsed, err := url.Parse(strings.TrimSpace(rawURL)) - if err != nil { - return "" +func resolveWebSearchCacheTTL(config map[string]any) time.Duration { + if value, ok := getIntArg(toolArgs(getNestedMap(config, "web", "search")), "cacheTtlMinutes"); ok && value > 0 { + return time.Duration(value) * time.Minute } - return strings.ToLower(strings.TrimSpace(parsed.Hostname())) + return time.Duration(defaultWebSearchCacheTtlMinutes) * time.Minute } diff --git a/internal/application/gateway/tools/web_tools_test.go b/internal/application/gateway/tools/web_tools_test.go index fee764e..83b66d8 100644 --- a/internal/application/gateway/tools/web_tools_test.go +++ b/internal/application/gateway/tools/web_tools_test.go @@ -3,8 +3,6 @@ package tools import ( "context" "encoding/json" - "net/http" - "net/http/httptest" "strings" "testing" @@ -28,156 +26,59 @@ func builtinWebFetchSettingsStub() *settingsReaderStub { settings: settingsdto.Settings{ Tools: map[string]any{ "web_fetch": map[string]any{ - "type": "builtin", + "preferredBrowser": "chrome", }, }, }, } } -func TestRunWebFetchTool_DefaultLLMHeaders(t *testing.T) { +func TestRunWebFetchTool_RejectsNonGETMethods(t *testing.T) { t.Parallel() - var acceptHeader string - var userAgentHeader string - var acceptLanguageHeader string - server := 
httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - acceptHeader = r.Header.Get("Accept") - userAgentHeader = r.Header.Get("User-Agent") - acceptLanguageHeader = r.Header.Get("Accept-Language") - w.Header().Set("Content-Type", "text/markdown; charset=utf-8") - w.Header().Set("x-markdown-tokens", "256") - w.Header().Set("content-signal", "ai-input=yes") - _, _ = w.Write([]byte("# Hello\n\nworld")) - })) - defer server.Close() - handler := runWebFetchTool(builtinWebFetchSettingsStub(), nil) - output, err := handler(context.Background(), `{"url":"`+server.URL+`"}`) + output, err := handler(context.Background(), `{"url":"https://example.com","method":"POST"}`) if err != nil { t.Fatalf("run web_fetch: %v", err) } - if !strings.Contains(acceptHeader, "text/markdown") { - t.Fatalf("expected markdown accept header, got %q", acceptHeader) - } - if !strings.Contains(userAgentHeader, "Version/17.0") || !strings.Contains(userAgentHeader, "Safari/605.1.15") { - t.Fatalf("expected default user-agent, got %q", userAgentHeader) - } - if acceptLanguageHeader != "en-US,en;q=0.9" { - t.Fatalf("expected default accept-language, got %q", acceptLanguageHeader) - } - var payload webFetchResult if err := json.Unmarshal([]byte(output), &payload); err != nil { t.Fatalf("decode result: %v", err) } - if payload.MarkdownTokens != 256 { - t.Fatalf("unexpected markdownTokens: %d", payload.MarkdownTokens) - } - if payload.ContentSignal != "ai-input=yes" { - t.Fatalf("unexpected content signal: %q", payload.ContentSignal) - } - if !strings.Contains(strings.ToLower(payload.ContentType), "markdown") { - t.Fatalf("expected markdown content type, got %q", payload.ContentType) - } -} - -func TestRunWebFetchTool_HeaderSwitches(t *testing.T) { - t.Parallel() - - var acceptHeader string - var userAgentHeader string - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - acceptHeader = r.Header.Get("Accept") - userAgentHeader = 
r.Header.Get("User-Agent") - w.WriteHeader(http.StatusOK) - _, _ = w.Write([]byte("ok")) - })) - defer server.Close() - - handler := runWebFetchTool(builtinWebFetchSettingsStub(), nil) - output, err := handler(context.Background(), `{"url":"`+server.URL+`","acceptMarkdown":false,"enableUserAgent":false}`) - if err != nil { - t.Fatalf("run web_fetch: %v", err) - } - if strings.Contains(strings.ToLower(acceptHeader), "text/markdown") { - t.Fatalf("expected html accept header, got %q", acceptHeader) - } - if userAgentHeader != "" { - t.Fatalf("expected empty user-agent when disabled, got %q", userAgentHeader) - } - - var payload webFetchResult - if err := json.Unmarshal([]byte(output), &payload); err != nil { - t.Fatalf("decode result: %v", err) - } - if payload.Status != webStatusOK { + if payload.Status != webStatusError { t.Fatalf("unexpected status: %q", payload.Status) } - if payload.HTTPStatus != http.StatusOK { - t.Fatalf("unexpected http status: %d", payload.HTTPStatus) + if payload.Message != "web_fetch only supports GET" { + t.Fatalf("unexpected message: %q", payload.Message) + } + if payload.NextAction != nextActionUseOtherToolsOrSkills { + t.Fatalf("unexpected next action: %q", payload.NextAction) } } -func TestRunWebFetchTool_ReadsTopLevelWebFetchSettings(t *testing.T) { +func TestResolveWebFetchPreferredBrowser_ReadsTopLevelConfig(t *testing.T) { t.Parallel() - var acceptHeader string - var userAgentHeader string - var acceptLanguageHeader string - var customHeader string - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - acceptHeader = r.Header.Get("Accept") - userAgentHeader = r.Header.Get("User-Agent") - acceptLanguageHeader = r.Header.Get("Accept-Language") - customHeader = r.Header.Get("X-Test") - w.WriteHeader(http.StatusOK) - _, _ = w.Write([]byte("ok")) - })) - defer server.Close() - - handler := runWebFetchTool(&settingsReaderStub{ - settings: settingsdto.Settings{ - Tools: map[string]any{ - 
"web_fetch": map[string]any{ - "type": "builtin", - "acceptMarkdown": false, - "enableUserAgent": false, - "acceptLanguage": "en-US,en;q=0.9", - "headers": map[string]any{ - "X-Test": "enabled", - }, - }, - }, + preferred := resolveWebFetchPreferredBrowser(map[string]any{ + "web_fetch": map[string]any{ + "preferredBrowser": "brave", }, - }, nil) - output, err := handler(context.Background(), `{"url":"`+server.URL+`"}`) - if err != nil { - t.Fatalf("run web_fetch: %v", err) - } - if strings.Contains(strings.ToLower(acceptHeader), "text/markdown") { - t.Fatalf("expected html accept header from settings, got %q", acceptHeader) - } - if userAgentHeader != "" { - t.Fatalf("expected empty user-agent from settings, got %q", userAgentHeader) - } - if acceptLanguageHeader != "en-US,en;q=0.9" { - t.Fatalf("expected accept-language from settings, got %q", acceptLanguageHeader) - } - if customHeader != "enabled" { - t.Fatalf("expected custom header from settings, got %q", customHeader) + "browser": map[string]any{ + "preferredBrowser": "chrome", + }, + }) + if preferred != "brave" { + t.Fatalf("expected preferred browser brave, got %q", preferred) } - var payload webFetchResult - if err := json.Unmarshal([]byte(output), &payload); err != nil { - t.Fatalf("decode result: %v", err) - } - if payload.Status != webStatusOK { - t.Fatalf("unexpected status: %q", payload.Status) - } - if payload.HTTPStatus != http.StatusOK { - t.Fatalf("unexpected http status: %d", payload.HTTPStatus) + fallback := resolveWebFetchPreferredBrowser(map[string]any{ + "browser": map[string]any{ + "preferredBrowser": "edge", + }, + }) + if fallback != "edge" { + t.Fatalf("expected fallback preferred browser edge, got %q", fallback) } } @@ -186,15 +87,9 @@ func TestResolveWebFetchOptions_ReadsTopLevelWebFetchConfig(t *testing.T) { options := resolveWebFetchOptions(toolArgs{}, map[string]any{ "web_fetch": map[string]any{ - "timeoutSeconds": 21, - "maxChars": 1234, - "maxBodyBytes": 4321, - "acceptMarkdown": 
false, - "enableUserAgent": false, - "userAgent": "top-level-agent", - "headers": map[string]any{ - "X-One": "1", - }, + "timeoutSeconds": 21, + "maxChars": 1234, + "maxBodyBytes": 4321, }, }, defaultWebFetchTimeoutSeconds) @@ -207,56 +102,30 @@ func TestResolveWebFetchOptions_ReadsTopLevelWebFetchConfig(t *testing.T) { if options.MaxBodyBytes != 4321 { t.Fatalf("expected maxBodyBytes from web_fetch, got %d", options.MaxBodyBytes) } - if options.MaxRedirects != defaultWebFetchMaxRedirects { - t.Fatalf("expected default maxRedirects, got %d", options.MaxRedirects) - } - if options.RetryMax != defaultWebFetchRetryMax { - t.Fatalf("expected default retryMax, got %d", options.RetryMax) - } - if options.AcceptMarkdown { - t.Fatalf("expected acceptMarkdown=false from web_fetch") - } - if options.EnableUserAgent { - t.Fatalf("expected enableUserAgent=false from web_fetch") - } - if options.UserAgent != "top-level-agent" { - t.Fatalf("expected userAgent from web_fetch, got %q", options.UserAgent) - } - if options.Headers["X-One"] != "1" { - t.Fatalf("expected headers from web_fetch, got %#v", options.Headers) - } } -func TestRunWebFetchTool_TruncatesLargeBodiesByMaxBodyBytes(t *testing.T) { +func TestBuildWebFetchToolResult_AnnotatesCDPSource(t *testing.T) { t.Parallel() - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "text/plain; charset=utf-8") - _, _ = w.Write([]byte(strings.Repeat("a", 512))) - })) - defer server.Close() - - handler := runWebFetchTool(builtinWebFetchSettingsStub(), nil) - output, err := handler(context.Background(), `{"url":"`+server.URL+`","maxBodyBytes":64,"maxChars":200}`) - if err != nil { - t.Fatalf("run web_fetch: %v", err) - } - - var payload webFetchResult - if err := json.Unmarshal([]byte(output), &payload); err != nil { - t.Fatalf("decode result: %v", err) - } + payload := buildWebFetchToolResult("https://example.com", webFetchResponse{ + FinalURL: 
"https://example.com/article", + Status: 200, + ContentType: "text/markdown", + Content: "# Hello\n\nworld", + MarkdownTokens: 4, + ContentSignal: "article_readability", + }, nil) if payload.Status != webStatusOK { t.Fatalf("unexpected status: %q", payload.Status) } - if !payload.Truncated { - t.Fatalf("expected payload to be truncated") + if payload.Data["browserSource"] != webFetchTypeCDP { + t.Fatalf("expected browser source %q, got %#v", webFetchTypeCDP, payload.Data["browserSource"]) } - if len(payload.Content) != 64 { - t.Fatalf("expected content to be capped at 64 bytes, got %d", len(payload.Content)) + if payload.ContentSignal != "article_readability" { + t.Fatalf("unexpected content signal: %q", payload.ContentSignal) } - if payload.Content != strings.Repeat("a", 64) { - t.Fatalf("unexpected truncated content length=%d", len(payload.Content)) + if payload.MarkdownTokens != 4 { + t.Fatalf("unexpected markdown tokens: %d", payload.MarkdownTokens) } } @@ -281,33 +150,24 @@ Markdown has quickly become the lingua franca for agents.` } } -func TestResolveWebFetchTypeDefaultsToBuiltin(t *testing.T) { +func TestResolveWebFetchTypeDefaultsToCDP(t *testing.T) { t.Parallel() fetchType, err := resolveWebFetchType(toolArgs{}, map[string]any{}) if err != nil { t.Fatalf("resolve web_fetch type: %v", err) } - if fetchType != webFetchTypeBuiltin { - t.Fatalf("expected default fetch type builtin, got %q", fetchType) + if fetchType != webFetchTypeCDP { + t.Fatalf("expected default fetch type cdp, got %q", fetchType) } } -func TestResolveWebFetchPlaywrightOptionsDefaultsMarkdownEnabled(t *testing.T) { - t.Parallel() - - options := resolveWebFetchPlaywrightOptions(toolArgs{}, map[string]any{}) - if !options.Markdown { - t.Fatalf("expected markdown conversion enabled by default") - } -} - -func TestResolveWebFetchTypeOnlyAcceptsBuiltinAndPlaywright(t *testing.T) { +func TestResolveWebFetchTypeNormalizesLegacyLabelsToCDP(t *testing.T) { t.Parallel() fetchType := 
normalizeWebFetchType("builtin") - if fetchType != webFetchTypeBuiltin { - t.Fatalf("expected builtin fetch type, got %q", fetchType) + if fetchType != webFetchTypeCDP { + t.Fatalf("expected builtin label to normalize to cdp, got %q", fetchType) } if _, err := resolveWebFetchType(toolArgs{"type": "terminal"}, map[string]any{}); err == nil { @@ -327,6 +187,21 @@ func TestConnectorTypeForURL(t *testing.T) { if connectorType := connectorTypeForURL("https://www.google.com/search?q=test"); connectorType != "google" { t.Fatalf("expected google connector, got %q", connectorType) } + if connectorType := connectorTypeForURL("https://www.youtube.com/watch?v=test"); connectorType != "google" { + t.Fatalf("expected youtube URLs to use google connector cookies, got %q", connectorType) + } + if connectorType := connectorTypeForURL("https://github.com/owner/repo"); connectorType != "github" { + t.Fatalf("expected github connector, got %q", connectorType) + } + if connectorType := connectorTypeForURL("https://www.reddit.com/r/golang/comments/test"); connectorType != "reddit" { + t.Fatalf("expected reddit connector, got %q", connectorType) + } + if connectorType := connectorTypeForURL("https://www.zhihu.com/question/123456"); connectorType != "zhihu" { + t.Fatalf("expected zhihu connector, got %q", connectorType) + } + if connectorType := connectorTypeForURL("https://x.com/example/status/1"); connectorType != "x" { + t.Fatalf("expected x connector, got %q", connectorType) + } if connectorType := connectorTypeForURL("https://www.xiaohongshu.com/explore"); connectorType != "xiaohongshu" { t.Fatalf("expected xiaohongshu connector, got %q", connectorType) } diff --git a/internal/application/sitepolicy/policy.go b/internal/application/sitepolicy/policy.go new file mode 100644 index 0000000..636e9c9 --- /dev/null +++ b/internal/application/sitepolicy/policy.go @@ -0,0 +1,339 @@ +package sitepolicy + +import ( + "net/url" + "strings" +) + +type Policy struct { + Key string + ConnectorType 
string + Domains []string + ReadySelectors []string + ExtractorSelectors []string + RemoveSelectors []string + Capabilities []string +} + +var builtinPolicyOrder = []string{ + "youtube", + "google", + "github", + "reddit", + "zhihu", + "x", + "xiaohongshu", + "bilibili", +} + +var builtinPolicies = map[string]Policy{ + "youtube": { + Key: "youtube", + ConnectorType: "google", + Domains: []string{ + "youtube.com", + "youtu.be", + }, + ReadySelectors: []string{ + "ytd-watch-flexy", + "#content", + "main", + "body", + }, + ExtractorSelectors: []string{ + "#description", + "#description-inline-expander", + "ytd-watch-metadata", + "main", + }, + RemoveSelectors: []string{ + "#related", + "ytd-comments", + "ytd-merch-shelf-renderer", + "ytd-rich-grid-renderer", + }, + Capabilities: []string{"cookies", "web_fetch", "browser", "download"}, + }, + "google": { + Key: "google", + ConnectorType: "google", + Domains: []string{ + "google.com", + "youtube.com", + "youtu.be", + }, + ReadySelectors: []string{ + "#search", + "main", + "body", + }, + ExtractorSelectors: []string{ + "main", + "article", + "#content", + }, + RemoveSelectors: []string{ + "#related", + "#secondary", + "ytd-comments", + }, + Capabilities: []string{"cookies", "web_fetch", "browser", "download"}, + }, + "github": { + Key: "github", + ConnectorType: "github", + Domains: []string{ + "github.com", + "raw.githubusercontent.com", + }, + ReadySelectors: []string{ + "main", + "#repo-content-pjax-container", + "body", + }, + ExtractorSelectors: []string{ + "main", + "article", + ".markdown-body", + "[data-testid=\"issue-body\"]", + "[data-testid=\"pull-request-comment\"]", + }, + RemoveSelectors: []string{ + "header", + "footer", + ".Layout-sidebar", + ".js-header-wrapper", + "#repos-sticky-header", + }, + Capabilities: []string{"cookies", "web_fetch", "browser"}, + }, + "reddit": { + Key: "reddit", + ConnectorType: "reddit", + Domains: []string{ + "reddit.com", + "redd.it", + }, + ReadySelectors: []string{ + 
"main", + "[data-testid=\"post-container\"]", + "body", + }, + ExtractorSelectors: []string{ + "main", + "article", + "[data-testid=\"post-container\"]", + ".md", + }, + RemoveSelectors: []string{ + "nav", + "[data-testid=\"frontpage-sidebar\"]", + "shreddit-comments-page-ad", + "shreddit-experience-tree", + }, + Capabilities: []string{"cookies", "web_fetch", "browser"}, + }, + "zhihu": { + Key: "zhihu", + ConnectorType: "zhihu", + Domains: []string{ + "zhihu.com", + }, + ReadySelectors: []string{ + "main", + ".Question-main", + ".Post-content", + "body", + }, + ExtractorSelectors: []string{ + "main", + ".Question-mainColumn", + ".Post-RichTextContainer", + "article", + }, + RemoveSelectors: []string{ + ".Question-sideColumn", + ".CornerButtons", + ".Recommendations-Main", + ".Comment-container", + }, + Capabilities: []string{"cookies", "web_fetch", "browser"}, + }, + "x": { + Key: "x", + ConnectorType: "x", + Domains: []string{ + "x.com", + "twitter.com", + }, + ReadySelectors: []string{ + "main", + "[data-testid=\"primaryColumn\"]", + "body", + }, + ExtractorSelectors: []string{ + "main", + "article", + "[data-testid=\"tweet\"]", + }, + RemoveSelectors: []string{ + "nav", + "[data-testid=\"sidebarColumn\"]", + "[aria-label=\"Timeline: Trending now\"]", + "[aria-label=\"Who to follow\"]", + }, + Capabilities: []string{"cookies", "web_fetch", "browser"}, + }, + "xiaohongshu": { + Key: "xiaohongshu", + ConnectorType: "xiaohongshu", + Domains: []string{ + "xiaohongshu.com", + "xhslink.com", + "redbook.com", + }, + ReadySelectors: []string{ + "#app", + "main", + "body", + }, + ExtractorSelectors: []string{ + "main", + "article", + "#noteContainer", + }, + RemoveSelectors: []string{ + ".note-side-bar", + ".recommend-container", + ".comment-container", + }, + Capabilities: []string{"cookies", "web_fetch", "browser"}, + }, + "bilibili": { + Key: "bilibili", + ConnectorType: "bilibili", + Domains: []string{ + "bilibili.com", + "b23.tv", + }, + ReadySelectors: []string{ + 
"#app", + "#arc_toolbar_report", + "main", + "body", + }, + ExtractorSelectors: []string{ + "main", + "article", + "#app", + }, + RemoveSelectors: []string{ + ".video-toolbar-v1", + ".right-container", + ".comment-container", + }, + Capabilities: []string{"cookies", "web_fetch", "browser", "download"}, + }, +} + +func List() []Policy { + result := make([]Policy, 0, len(builtinPolicyOrder)) + for _, key := range builtinPolicyOrder { + policy, ok := builtinPolicies[key] + if !ok { + continue + } + result = append(result, policy) + } + return result +} + +func ForConnectorType(connectorType string) (Policy, bool) { + policy, ok := builtinPolicies[strings.ToLower(strings.TrimSpace(connectorType))] + return policy, ok +} + +func ForURL(rawURL string) (Policy, bool) { + host := hostname(rawURL) + if host == "" { + return Policy{}, false + } + for _, key := range builtinPolicyOrder { + policy, ok := builtinPolicies[key] + if !ok { + continue + } + for _, domain := range policy.Domains { + if HostMatchesDomain(host, domain) { + return policy, true + } + } + } + return Policy{}, false +} + +func DomainsForConnector(connectorType string) []string { + policy, ok := ForConnectorType(connectorType) + if !ok { + return nil + } + return cloneStrings(policy.Domains) +} + +func ReadySelectorForURL(rawURL string) string { + policy, ok := ForURL(rawURL) + if !ok { + return "" + } + for _, selector := range policy.ReadySelectors { + if strings.TrimSpace(selector) != "" { + return strings.TrimSpace(selector) + } + } + return "" +} + +func HostMatchesDomain(host string, domain string) bool { + normalizedHost := strings.TrimPrefix(strings.ToLower(strings.TrimSpace(host)), ".") + normalizedDomain := strings.TrimPrefix(strings.ToLower(strings.TrimSpace(domain)), ".") + if normalizedHost == "" || normalizedDomain == "" { + return false + } + return normalizedHost == normalizedDomain || strings.HasSuffix(normalizedHost, "."+normalizedDomain) +} + +func MatchDomains(rawURL string, domains 
[]string) bool { + host := hostname(rawURL) + if host == "" { + return false + } + for _, domain := range domains { + if HostMatchesDomain(host, domain) { + return true + } + } + return false +} + +func cloneStrings(values []string) []string { + if len(values) == 0 { + return nil + } + result := make([]string, 0, len(values)) + for _, value := range values { + trimmed := strings.TrimSpace(value) + if trimmed == "" { + continue + } + result = append(result, trimmed) + } + return result +} + +func hostname(rawURL string) string { + parsed, err := url.Parse(strings.TrimSpace(rawURL)) + if err != nil { + return "" + } + return strings.ToLower(strings.TrimSpace(parsed.Hostname())) +} diff --git a/internal/application/sitepolicy/policy_test.go b/internal/application/sitepolicy/policy_test.go new file mode 100644 index 0000000..c0d5578 --- /dev/null +++ b/internal/application/sitepolicy/policy_test.go @@ -0,0 +1,53 @@ +package sitepolicy + +import "testing" + +func TestForURLPrefersYouTubePolicyBeforeGoogle(t *testing.T) { + t.Parallel() + + policy, ok := ForURL("https://www.youtube.com/watch?v=dQw4w9WgXcQ") + if !ok { + t.Fatalf("expected youtube policy match") + } + if policy.Key != "youtube" { + t.Fatalf("expected youtube policy key, got %q", policy.Key) + } + if policy.ConnectorType != "google" { + t.Fatalf("expected youtube URLs to reuse google connector cookies, got %q", policy.ConnectorType) + } +} + +func TestForConnectorTypeGoogleDomainsIncludeYouTube(t *testing.T) { + t.Parallel() + + policy, ok := ForConnectorType("google") + if !ok { + t.Fatalf("expected google connector policy") + } + if !MatchDomains("https://www.youtube.com/watch?v=test", policy.Domains) { + t.Fatalf("expected google connector domains to cover youtube URLs") + } +} + +func TestForURLMatchesNewBuiltinSites(t *testing.T) { + t.Parallel() + + cases := map[string]string{ + "https://github.com/owner/repo": "github", + "https://www.reddit.com/r/golang/comments/test": "reddit", + 
"https://www.zhihu.com/question/123456": "zhihu", + "https://x.com/example/status/1": "x", + "https://www.xiaohongshu.com/explore/abc": "xiaohongshu", + "https://www.bilibili.com/video/BV1xx411c7mD/": "bilibili", + } + + for rawURL, expected := range cases { + policy, ok := ForURL(rawURL) + if !ok { + t.Fatalf("expected policy match for %s", rawURL) + } + if policy.Key != expected { + t.Fatalf("expected policy %q for %s, got %q", expected, rawURL, policy.Key) + } + } +} diff --git a/internal/application/skills/service/audit.go b/internal/application/skills/service/audit.go index a3ae365..37aa51f 100644 --- a/internal/application/skills/service/audit.go +++ b/internal/application/skills/service/audit.go @@ -48,9 +48,9 @@ func resolveSkillsAuditSourceFromContext(ctx context.Context) string { func resolveSkillsAuditGroup(action string) string { switch strings.ToLower(strings.TrimSpace(action)) { - case "skills.status", "skills.bins", "skill_manage.search", "skill_manage.list": + case "skills.status", "skills.bins", "skills_manage.search", "skills_manage.list": return "read" - case "skill_manage.install", "skill_manage.update", "skill_manage.remove", "skill_manage.sync": + case "skills_manage.install", "skills_manage.update", "skills_manage.remove", "skills_manage.sync": return "package_write" case "skills.install": return "deps_write" @@ -67,8 +67,8 @@ func resolveSkillsAuditTool(action string) string { switch { case strings.HasPrefix(strings.ToLower(strings.TrimSpace(action)), "skills."): return "skills" - case strings.HasPrefix(strings.ToLower(strings.TrimSpace(action)), "skill_manage."): - return "skill_manage" + case strings.HasPrefix(strings.ToLower(strings.TrimSpace(action)), "skills_manage."): + return "skills_manage" default: return "" } diff --git a/internal/application/skills/service/audit_test.go b/internal/application/skills/service/audit_test.go index c52dd54..c260483 100644 --- a/internal/application/skills/service/audit_test.go +++ 
b/internal/application/skills/service/audit_test.go @@ -47,7 +47,7 @@ func TestPruneSkillsAuditEntriesByRetention(t *testing.T) { "timestamp": now.AddDate(0, 0, -1).Format(time.RFC3339), }, map[string]any{ - "action": "skill_manage.search", + "action": "skills_manage.search", "timestamp": now.AddDate(0, 0, -30).Format(time.RFC3339), }, } diff --git a/internal/application/skills/service/search.go b/internal/application/skills/service/search.go index cc4d33d..4f67a80 100644 --- a/internal/application/skills/service/search.go +++ b/internal/application/skills/service/search.go @@ -106,20 +106,20 @@ func (service *SkillsService) SearchSkills(ctx context.Context, request dto.Sear providerID := "clawhub" workspaceRoot, err := service.resolveWorkspaceRoot(ctx, request.AssistantID, request.WorkspaceRoot) if err != nil { - service.appendSkillsAuditRecord(ctx, "skill_manage.search", query, assistantID, providerID, err) + service.appendSkillsAuditRecord(ctx, "skills_manage.search", query, assistantID, providerID, err) return nil, err } output, err := service.runClawHubCommand(ctx, workspaceRoot, skillsSearchTimeout, "search", "--limit", fmt.Sprintf("%d", limit), query) if err != nil { - service.appendSkillsAuditRecord(ctx, "skill_manage.search", query, assistantID, providerID, err) + service.appendSkillsAuditRecord(ctx, "skills_manage.search", query, assistantID, providerID, err) return nil, err } results := parseClawHubSearchOutput(output) if len(results) > limit { - service.appendSkillsAuditRecord(ctx, "skill_manage.search", query, assistantID, providerID, nil) + service.appendSkillsAuditRecord(ctx, "skills_manage.search", query, assistantID, providerID, nil) return results[:limit], nil } - service.appendSkillsAuditRecord(ctx, "skill_manage.search", query, assistantID, providerID, nil) + service.appendSkillsAuditRecord(ctx, "skills_manage.search", query, assistantID, providerID, nil) return results, nil } @@ -132,17 +132,17 @@ func (service *SkillsService) InspectSkill(ctx 
context.Context, request dto.Insp providerID := "clawhub" workspaceRoot, err := service.resolveWorkspaceRoot(ctx, request.AssistantID, request.WorkspaceRoot) if err != nil { - service.appendSkillsAuditRecord(ctx, "skill_manage.inspect", skill, assistantID, providerID, err) + service.appendSkillsAuditRecord(ctx, "skills_manage.inspect", skill, assistantID, providerID, err) return dto.SkillDetail{}, err } output, err := service.runClawHubCommand(ctx, workspaceRoot, skillsSearchTimeout, "inspect", skill, "--json", "--files") if err != nil { - service.appendSkillsAuditRecord(ctx, "skill_manage.inspect", skill, assistantID, providerID, err) + service.appendSkillsAuditRecord(ctx, "skills_manage.inspect", skill, assistantID, providerID, err) return dto.SkillDetail{}, err } detail, err := parseClawHubInspectOutput(output) if err != nil { - service.appendSkillsAuditRecord(ctx, "skill_manage.inspect", skill, assistantID, providerID, err) + service.appendSkillsAuditRecord(ctx, "skills_manage.inspect", skill, assistantID, providerID, err) return dto.SkillDetail{}, err } if version := service.resolveInstalledSkillVersion(ctx, workspaceRoot, skill); version != "" { @@ -161,7 +161,7 @@ func (service *SkillsService) InspectSkill(ctx context.Context, request dto.Insp if detail.Name == "" { detail.Name = detail.ID } - service.appendSkillsAuditRecord(ctx, "skill_manage.inspect", skill, assistantID, providerID, nil) + service.appendSkillsAuditRecord(ctx, "skills_manage.inspect", skill, assistantID, providerID, nil) return detail, nil } @@ -174,7 +174,7 @@ func (service *SkillsService) InstallSkill(ctx context.Context, request dto.Inst providerID := "clawhub" workspaceRoot, err := service.resolveWorkspaceRoot(ctx, request.AssistantID, request.WorkspaceRoot) if err != nil { - service.appendSkillsAuditRecord(ctx, "skill_manage.install", skill, assistantID, providerID, err) + service.appendSkillsAuditRecord(ctx, "skills_manage.install", skill, assistantID, providerID, err) return err } 
service.emitRealtimeEvent(ctx, SkillsRealtimeEvent{ @@ -206,7 +206,7 @@ func (service *SkillsService) InstallSkill(ctx context.Context, request dto.Inst Force: request.Force, Error: err.Error(), }) - service.appendSkillsAuditRecord(ctx, "skill_manage.install", skill, assistantID, providerID, err) + service.appendSkillsAuditRecord(ctx, "skills_manage.install", skill, assistantID, providerID, err) return err } service.recordInstallAttempt(true) @@ -219,7 +219,7 @@ func (service *SkillsService) InstallSkill(ctx context.Context, request dto.Inst WorkspaceRoot: workspaceRoot, Force: request.Force, }) - service.appendSkillsAuditRecord(ctx, "skill_manage.install", skill, assistantID, providerID, nil) + service.appendSkillsAuditRecord(ctx, "skills_manage.install", skill, assistantID, providerID, nil) return err } @@ -232,7 +232,7 @@ func (service *SkillsService) UpdateSkill(ctx context.Context, request dto.Updat providerID := "clawhub" workspaceRoot, err := service.resolveWorkspaceRoot(ctx, request.AssistantID, request.WorkspaceRoot) if err != nil { - service.appendSkillsAuditRecord(ctx, "skill_manage.update", skill, assistantID, providerID, err) + service.appendSkillsAuditRecord(ctx, "skills_manage.update", skill, assistantID, providerID, err) return err } service.emitRealtimeEvent(ctx, SkillsRealtimeEvent{ @@ -264,7 +264,7 @@ func (service *SkillsService) UpdateSkill(ctx context.Context, request dto.Updat Force: request.Force, Error: err.Error(), }) - service.appendSkillsAuditRecord(ctx, "skill_manage.update", skill, assistantID, providerID, err) + service.appendSkillsAuditRecord(ctx, "skills_manage.update", skill, assistantID, providerID, err) return err } service.recordInstallAttempt(true) @@ -277,7 +277,7 @@ func (service *SkillsService) UpdateSkill(ctx context.Context, request dto.Updat WorkspaceRoot: workspaceRoot, Force: request.Force, }) - service.appendSkillsAuditRecord(ctx, "skill_manage.update", skill, assistantID, providerID, nil) + 
service.appendSkillsAuditRecord(ctx, "skills_manage.update", skill, assistantID, providerID, nil) return err } @@ -290,7 +290,7 @@ func (service *SkillsService) RemoveSkill(ctx context.Context, request dto.Remov providerID := "clawhub" workspaceRoot, err := service.resolveWorkspaceRoot(ctx, request.AssistantID, request.WorkspaceRoot) if err != nil { - service.appendSkillsAuditRecord(ctx, "skill_manage.remove", skill, assistantID, providerID, err) + service.appendSkillsAuditRecord(ctx, "skills_manage.remove", skill, assistantID, providerID, err) return err } service.emitRealtimeEvent(ctx, SkillsRealtimeEvent{ @@ -312,7 +312,7 @@ func (service *SkillsService) RemoveSkill(ctx context.Context, request dto.Remov WorkspaceRoot: workspaceRoot, Error: cleanupErr.Error(), }) - service.appendSkillsAuditRecord(ctx, "skill_manage.remove", skill, assistantID, providerID, cleanupErr) + service.appendSkillsAuditRecord(ctx, "skills_manage.remove", skill, assistantID, providerID, cleanupErr) return cleanupErr } service.emitRealtimeEvent(ctx, SkillsRealtimeEvent{ @@ -322,7 +322,7 @@ func (service *SkillsService) RemoveSkill(ctx context.Context, request dto.Remov AssistantID: request.AssistantID, WorkspaceRoot: workspaceRoot, }) - service.appendSkillsAuditRecord(ctx, "skill_manage.remove", skill, assistantID, providerID, nil) + service.appendSkillsAuditRecord(ctx, "skills_manage.remove", skill, assistantID, providerID, nil) return nil } service.emitRealtimeEvent(ctx, SkillsRealtimeEvent{ @@ -333,7 +333,7 @@ func (service *SkillsService) RemoveSkill(ctx context.Context, request dto.Remov WorkspaceRoot: workspaceRoot, Error: err.Error(), }) - service.appendSkillsAuditRecord(ctx, "skill_manage.remove", skill, assistantID, providerID, err) + service.appendSkillsAuditRecord(ctx, "skills_manage.remove", skill, assistantID, providerID, err) return err } if cleanupErr := service.removeWorkspaceSkillDirectories(ctx, workspaceRoot, skill); cleanupErr != nil { @@ -345,7 +345,7 @@ func (service 
*SkillsService) RemoveSkill(ctx context.Context, request dto.Remov WorkspaceRoot: workspaceRoot, Error: cleanupErr.Error(), }) - service.appendSkillsAuditRecord(ctx, "skill_manage.remove", skill, assistantID, providerID, cleanupErr) + service.appendSkillsAuditRecord(ctx, "skills_manage.remove", skill, assistantID, providerID, cleanupErr) return cleanupErr } service.emitRealtimeEvent(ctx, SkillsRealtimeEvent{ @@ -355,7 +355,7 @@ func (service *SkillsService) RemoveSkill(ctx context.Context, request dto.Remov AssistantID: request.AssistantID, WorkspaceRoot: workspaceRoot, }) - service.appendSkillsAuditRecord(ctx, "skill_manage.remove", skill, assistantID, providerID, nil) + service.appendSkillsAuditRecord(ctx, "skills_manage.remove", skill, assistantID, providerID, nil) return nil } @@ -449,7 +449,7 @@ func (service *SkillsService) SyncSkills(ctx context.Context, request dto.SyncSk } workspaceRoot, err := service.resolveWorkspaceRoot(ctx, request.AssistantID, request.WorkspaceRoot) if err != nil { - service.appendSkillsAuditRecord(ctx, "skill_manage.sync", "", assistantID, providerID, err) + service.appendSkillsAuditRecord(ctx, "skills_manage.sync", "", assistantID, providerID, err) return nil, err } service.emitRealtimeEvent(ctx, SkillsRealtimeEvent{ @@ -468,7 +468,7 @@ func (service *SkillsService) SyncSkills(ctx context.Context, request dto.SyncSk WorkspaceRoot: workspaceRoot, Error: err.Error(), }) - service.appendSkillsAuditRecord(ctx, "skill_manage.sync", "", assistantID, providerID, err) + service.appendSkillsAuditRecord(ctx, "skills_manage.sync", "", assistantID, providerID, err) return nil, err } result, err := service.ResolveSkillsForProviderInWorkspace(ctx, dto.ResolveSkillsRequest{ @@ -483,7 +483,7 @@ func (service *SkillsService) SyncSkills(ctx context.Context, request dto.SyncSk WorkspaceRoot: workspaceRoot, Error: err.Error(), }) - service.appendSkillsAuditRecord(ctx, "skill_manage.sync", "", assistantID, providerID, err) + 
service.appendSkillsAuditRecord(ctx, "skills_manage.sync", "", assistantID, providerID, err) return nil, err } service.emitRealtimeEvent(ctx, SkillsRealtimeEvent{ @@ -494,7 +494,7 @@ func (service *SkillsService) SyncSkills(ctx context.Context, request dto.SyncSk WorkspaceRoot: workspaceRoot, CatalogCount: len(result), }) - service.appendSkillsAuditRecord(ctx, "skill_manage.sync", "", assistantID, providerID, nil) + service.appendSkillsAuditRecord(ctx, "skills_manage.sync", "", assistantID, providerID, nil) return result, nil } diff --git a/internal/application/skills/service/search_test.go b/internal/application/skills/service/search_test.go index eb3f707..38a20c2 100644 --- a/internal/application/skills/service/search_test.go +++ b/internal/application/skills/service/search_test.go @@ -495,14 +495,8 @@ func TestClassifyClawHubCommandError(t *testing.T) { func TestInstallSkillEmitsRealtimeStartedAndCompletedEvents(t *testing.T) { t.Parallel() - scriptPath := filepath.Join(t.TempDir(), "clawhub") - script := "#!/bin/sh\nexit 0\n" - if err := os.WriteFile(scriptPath, []byte(script), 0o755); err != nil { - t.Fatalf("write script failed: %v", err) - } - svc := NewSkillsService(newMemorySkillsRepo(), nil) - svc.SetExternalTools(&skillsExternalToolsStub{ready: true, execPath: scriptPath}) + svc.SetPackageAdapter(&skillsPackageAdapterStub{}) fixedNow := time.Date(2026, time.March, 17, 12, 0, 0, 0, time.UTC) svc.now = func() time.Time { return fixedNow } @@ -556,14 +550,12 @@ func TestInstallSkillEmitsRealtimeStartedAndCompletedEvents(t *testing.T) { func TestInstallSkillEmitsRealtimeFailedEvent(t *testing.T) { t.Parallel() - scriptPath := filepath.Join(t.TempDir(), "clawhub") - script := "#!/bin/sh\necho \"install failed\" 1>&2\nexit 1\n" - if err := os.WriteFile(scriptPath, []byte(script), 0o755); err != nil { - t.Fatalf("write script failed: %v", err) - } - svc := NewSkillsService(newMemorySkillsRepo(), nil) - svc.SetExternalTools(&skillsExternalToolsStub{ready: true, 
execPath: scriptPath}) + svc.SetPackageAdapter(&skillsPackageAdapterStub{ + run: func(_ context.Context, _ string, _ time.Duration, _ ...string) ([]byte, error) { + return nil, errors.New("install failed") + }, + }) fixedNow := time.Date(2026, time.March, 17, 12, 0, 1, 0, time.UTC) svc.now = func() time.Time { return fixedNow } diff --git a/internal/application/softwareupdate/service.go b/internal/application/softwareupdate/service.go index 49dabb7..f903433 100644 --- a/internal/application/softwareupdate/service.go +++ b/internal/application/softwareupdate/service.go @@ -2,6 +2,7 @@ package softwareupdate import ( "context" + "strings" "sync" "time" @@ -21,6 +22,10 @@ type AppFallbackProvider interface { FetchAppRelease(ctx context.Context, request AppRequest) (AppRelease, error) } +type appVersionFallbackProvider interface { + FetchAppReleaseByVersion(ctx context.Context, version string) (AppRelease, error) +} + type ToolFallbackProvider interface { FetchToolRelease(ctx context.Context, request ToolRequest) (ToolRelease, error) } @@ -164,6 +169,38 @@ func (service *Service) ResolveAppRelease(ctx context.Context, request AppReques return release, nil } +func (service *Service) ResolveAppReleaseByVersion(ctx context.Context, version string) (AppRelease, error) { + normalizedVersion := normalizeAppReleaseVersion(version) + if service == nil || normalizedVersion == "" { + return AppRelease{}, ErrReleaseNotFound + } + + snapshot, err := service.EnsureCatalog(ctx, time.Minute, Request{AppVersion: normalizedVersion}) + if err == nil && snapshot.Catalog.App != nil && sameAppReleaseVersion(snapshot.Catalog.App.Version, normalizedVersion) { + release := *snapshot.Catalog.App + release.ResolvedBy = SourceManifest + return release, nil + } + + resolver, ok := service.appFallbackProvider.(appVersionFallbackProvider) + if !ok || resolver == nil { + if err != nil { + return AppRelease{}, err + } + return AppRelease{}, ErrReleaseNotFound + } + + release, fallbackErr := 
resolver.FetchAppReleaseByVersion(ctx, normalizedVersion) + if fallbackErr != nil { + if err != nil { + return AppRelease{}, err + } + return AppRelease{}, fallbackErr + } + release.ResolvedBy = SourceFallback + return release, nil +} + func (service *Service) ResolveToolRelease(ctx context.Context, request ToolRequest) (ToolRelease, error) { if service == nil { return ToolRelease{}, ErrReleaseNotFound @@ -240,6 +277,18 @@ func (service *Service) StartSchedule(ctx context.Context, initialDelay time.Dur }() } +func normalizeAppReleaseVersion(version string) string { + trimmed := strings.TrimSpace(version) + trimmed = strings.TrimPrefix(strings.TrimPrefix(trimmed, "v"), "V") + return trimmed +} + +func sameAppReleaseVersion(left string, right string) bool { + normalizedLeft := normalizeAppReleaseVersion(left) + normalizedRight := normalizeAppReleaseVersion(right) + return normalizedLeft != "" && normalizedLeft == normalizedRight +} + func (service *Service) StopSchedule() { service.mu.Lock() defer service.mu.Unlock() diff --git a/internal/application/tools/dto/models.go b/internal/application/tools/dto/models.go index b6f382a..4af55da 100644 --- a/internal/application/tools/dto/models.go +++ b/internal/application/tools/dto/models.go @@ -21,6 +21,7 @@ type ToolRequirement struct { Name string `json:"name,omitempty"` Available bool `json:"available"` Reason string `json:"reason,omitempty"` + Data any `json:"data,omitempty"` } type ToolMethodSpec struct { diff --git a/internal/application/update/service.go b/internal/application/update/service.go index 46215a1..7ce1471 100644 --- a/internal/application/update/service.go +++ b/internal/application/update/service.go @@ -8,6 +8,7 @@ import ( "io" "os" "slices" + "strconv" "strings" "sync" "time" @@ -23,7 +24,7 @@ type Downloader interface { } type Installer interface { - Install(ctx context.Context, artifactPath string) error + Install(ctx context.Context, artifactPath string, prepared update.Info) error 
RestartToApply(ctx context.Context) error } @@ -31,24 +32,36 @@ type downloadURLSelector interface { SelectDownloadURLs(ctx context.Context, urls []string) []string } +type preparedUpdateInspector interface { + PreparedUpdate(ctx context.Context) (update.Info, bool, error) + ClearPreparedUpdate(ctx context.Context) error +} + +type whatsNewStore interface { + PendingWhatsNew(ctx context.Context) (update.WhatsNew, bool, error) + SeenWhatsNewVersion(ctx context.Context) (string, error) + MarkWhatsNewSeen(ctx context.Context, version string) error +} + type Notifier interface { SetUpdateAvailable(available bool) NotifyUpdateState(info update.Info) } type Service struct { - mu sync.Mutex - state update.Info - catalog *softwareupdate.Service - downloader Downloader - installer Installer - bus events.Bus - notifier Notifier - now func() time.Time - scheduleTicker *time.Ticker - cancelSchedule context.CancelFunc - downloadURLs []string - downloadSHA256 string + mu sync.Mutex + state update.Info + catalog *softwareupdate.Service + downloader Downloader + installer Installer + bus events.Bus + notifier Notifier + now func() time.Time + scheduleTicker *time.Ticker + cancelSchedule context.CancelFunc + downloadURLs []string + downloadSHA256 string + autoPrepareInFlight bool } type ServiceParams struct { @@ -89,9 +102,15 @@ func (service *Service) PublishCurrentState() { func (service *Service) SetCurrentVersion(version string) { service.mu.Lock() service.state.CurrentVersion = update.NormalizeVersion(version) + if service.state.CurrentVersion != "" && + service.state.PreparedVersion != "" && + update.CompareVersion(service.state.CurrentVersion, service.state.PreparedVersion) >= 0 { + service.clearPreparedStateLocked() + } if service.state.CurrentVersion != "" && service.state.LatestVersion != "" && - update.CompareVersion(service.state.CurrentVersion, service.state.LatestVersion) >= 0 { + update.CompareVersion(service.state.CurrentVersion, service.state.LatestVersion) >= 0 && 
+ !service.state.HasPreparedUpdate() { service.state.Status = update.StatusIdle service.state.Progress = 0 service.state.DownloadURL = "" @@ -102,11 +121,117 @@ func (service *Service) SetCurrentVersion(version string) { service.mu.Unlock() } +func (service *Service) RestorePreparedUpdate(ctx context.Context) (update.Info, error) { + inspector, ok := service.installer.(preparedUpdateInspector) + if !ok || inspector == nil { + return service.State(), nil + } + + prepared, found, err := inspector.PreparedUpdate(ctx) + if err != nil { + return service.State(), err + } + if !found { + return service.State(), nil + } + + preparedVersion := update.NormalizeVersion(prepared.PreparedVersion) + currentVersion := update.NormalizeVersion(service.State().CurrentVersion) + if preparedVersion != "" && currentVersion != "" && update.CompareVersion(currentVersion, preparedVersion) >= 0 { + if clearErr := inspector.ClearPreparedUpdate(ctx); clearErr != nil { + zap.L().Warn("update: clear stale prepared update failed", zap.Error(clearErr)) + } + service.mu.Lock() + service.clearPreparedStateLocked() + state := service.state + service.mu.Unlock() + return state, nil + } + + service.mu.Lock() + service.state.Kind = update.KindApp + if strings.TrimSpace(service.state.LatestVersion) == "" || + update.CompareVersion(preparedVersion, service.state.LatestVersion) >= 0 { + service.state.LatestVersion = preparedVersion + service.state.Changelog = prepared.PreparedChangelog + } + service.state.PreparedVersion = preparedVersion + service.state.PreparedChangelog = prepared.PreparedChangelog + service.setPreparedReadyLocked() + state := service.state + service.mu.Unlock() + service.notifyAvailability(true) + return state, nil +} + +func (service *Service) GetWhatsNew(ctx context.Context) (update.WhatsNew, error) { + store, ok := service.installer.(whatsNewStore) + if !ok || store == nil { + return update.WhatsNew{}, nil + } + + currentVersion := 
update.NormalizeVersion(service.State().CurrentVersion) + if !isReleaseVersion(currentVersion) { + return update.WhatsNew{}, nil + } + + seenVersion, err := store.SeenWhatsNewVersion(ctx) + if err != nil { + return update.WhatsNew{}, err + } + seenVersion = update.NormalizeVersion(seenVersion) + + pending, found, err := store.PendingWhatsNew(ctx) + if err != nil { + return update.WhatsNew{}, err + } + if found { + pendingVersion := update.NormalizeVersion(pending.Version) + switch { + case pendingVersion != "" && + update.CompareVersion(currentVersion, pendingVersion) == 0 && + (seenVersion == "" || update.CompareVersion(pendingVersion, seenVersion) > 0): + pending.CurrentVersion = currentVersion + if strings.TrimSpace(pending.Changelog) == "" { + pending.Changelog = service.resolveReleaseNotes(ctx, pendingVersion) + } + return pending, nil + case pendingVersion != "" && update.CompareVersion(currentVersion, pendingVersion) > 0: + // A newer version is already running; ignore the older pending notice. 
+ case pendingVersion != "" && update.CompareVersion(currentVersion, pendingVersion) < 0: + return update.WhatsNew{}, nil + } + } + + if seenVersion != "" && update.CompareVersion(currentVersion, seenVersion) <= 0 { + return update.WhatsNew{}, nil + } + + return update.WhatsNew{ + Version: currentVersion, + CurrentVersion: currentVersion, + Changelog: service.resolveReleaseNotes(ctx, currentVersion), + }, nil +} + +func (service *Service) DismissWhatsNew(ctx context.Context, version string) error { + store, ok := service.installer.(whatsNewStore) + if !ok || store == nil { + return nil + } + return store.MarkWhatsNewSeen(ctx, update.NormalizeVersion(version)) +} + func (service *Service) CheckForUpdate(ctx context.Context, currentVersion string) (update.Info, error) { service.mu.Lock() if currentVersion != "" { service.state.CurrentVersion = update.NormalizeVersion(currentVersion) } + if service.state.Status == update.StatusDownloading || service.state.Status == update.StatusInstalling { + state := service.state + service.mu.Unlock() + return state, nil + } service.setStatusLocked(update.StatusChecking, 0, "") state := service.state service.mu.Unlock() @@ -125,18 +250,12 @@ func (service *Service) CheckForUpdate(ctx context.Context, currentVersion strin }) if err != nil { - service.mu.Lock() - service.setStatusLocked(update.StatusError, service.state.Progress, err.Error()) - state := service.state - service.mu.Unlock() - go service.publishSnapshot(state) - return state, err + return service.publishCheckError(err) } downloadURLs := service.selectDownloadURLs(ctx, release.Asset.DownloadURLs()) service.mu.Lock() - defer service.mu.Unlock() latest := update.NormalizeVersion(release.Version) current := update.NormalizeVersion(service.state.CurrentVersion) service.state.LatestVersion = latest @@ -155,43 +274,71 @@ func (service *Service) CheckForUpdate(ctx context.Context, currentVersion strin ) if current != "" && latest != "" && update.CompareVersion(current, latest) 
>= 0 { + if service.state.HasPreparedUpdate() { + service.setPreparedReadyLocked() + state := service.state + service.mu.Unlock() + service.notifyAvailability(true) + service.publishSnapshot(state) + return state, nil + } service.setStatusLocked(update.StatusNoUpdate, 0, "") state := service.state service.downloadURLs = nil service.downloadSHA256 = "" + service.mu.Unlock() service.notifyAvailability(false) - go service.publishSnapshot(state) + service.publishSnapshot(state) return state, nil } if service.state.DownloadURL == "" { - service.setStatusLocked(update.StatusError, service.state.Progress, "no downloadable asset for update") + service.mu.Unlock() + return service.publishPrepareError(fmt.Errorf("no downloadable asset for update"), service.capturePreparedFallback()) + } + + if service.state.HasPreparedUpdate() && + update.CompareVersion(service.state.PreparedVersion, latest) == 0 { + service.setPreparedReadyLocked() state := service.state - go service.publishSnapshot(state) - return state, fmt.Errorf("no downloadable asset for update") + service.mu.Unlock() + service.notifyAvailability(true) + service.publishSnapshot(state) + return state, nil } service.setStatusLocked(update.StatusAvailable, 0, "") state = service.state + shouldAutoPrepare := service.shouldAutoPrepareLocked() + service.mu.Unlock() service.notifyAvailability(true) - go service.publishSnapshot(state) + service.publishSnapshot(state) + if shouldAutoPrepare { + service.scheduleAutoPrepare() + } return state, nil } func (service *Service) DownloadUpdate(ctx context.Context) (update.Info, error) { service.mu.Lock() + if service.state.Status == update.StatusDownloading || service.state.Status == update.StatusInstalling { + state := service.state + service.mu.Unlock() + return state, nil + } downloadURLs := service.resolveDownloadURLsLocked() expectedSHA256 := service.downloadSHA256 + fallback := service.capturePreparedFallbackLocked() service.setStatusLocked(update.StatusDownloading, 0, "") state 
:= service.state service.mu.Unlock() service.publishSnapshot(state) if len(downloadURLs) == 0 { - return service.publishError(fmt.Errorf("missing download url")) + return service.publishPrepareError(fmt.Errorf("missing download url"), fallback) } if service.downloader == nil { - return service.publishError(fmt.Errorf("downloader not configured")) + return service.publishPrepareError(fmt.Errorf("downloader not configured"), fallback) } var path string @@ -214,7 +361,7 @@ func (service *Service) DownloadUpdate(ctx context.Context) (update.Info, error) zap.L().Warn("update: download source failed", zap.String("url", downloadURL), zap.Error(err)) } if err != nil { - return service.publishError(err) + return service.publishPrepareError(err, fallback) } service.mu.Lock() @@ -226,12 +373,14 @@ func (service *Service) DownloadUpdate(ctx context.Context) (update.Info, error) service.publishSnapshot(installingState) if service.installer != nil { - if err := service.installer.Install(ctx, path); err != nil { - return service.publishError(err) + if err := service.installer.Install(ctx, path, installingState); err != nil { + return service.publishPrepareError(err, fallback) } } service.mu.Lock() + service.state.PreparedVersion = update.NormalizeVersion(service.state.LatestVersion) + service.state.PreparedChangelog = service.state.Changelog service.setStatusLocked(update.StatusReadyToRestart, 100, "") finalState := service.state service.mu.Unlock() @@ -249,6 +398,7 @@ func (service *Service) RestartToApply(ctx context.Context) (update.Info, error) return service.publishError(err) } service.mu.Lock() + service.clearPreparedStateLocked() service.setStatusLocked(update.StatusIdle, 0, "") service.downloadURLs = nil service.downloadSHA256 = "" @@ -314,6 +464,67 @@ func (service *Service) safeCheck(ctx context.Context, currentVersion string) { _, _ = service.CheckForUpdate(ctx, currentVersion) // errors already published } +func (service *Service) scheduleAutoPrepare() { + 
service.mu.Lock() + if !service.shouldAutoPrepareLocked() { + service.mu.Unlock() + return + } + latestVersion := service.state.LatestVersion + service.autoPrepareInFlight = true + service.mu.Unlock() + + go func() { + defer service.finishAutoPrepare() + if _, err := service.DownloadUpdate(context.Background()); err != nil { + zap.L().Warn("update: auto-prepare failed", zap.String("latestVersion", latestVersion), zap.Error(err)) + } + }() +} + +func (service *Service) finishAutoPrepare() { + service.mu.Lock() + service.autoPrepareInFlight = false + service.mu.Unlock() +} + +func (service *Service) publishCheckError(err error) (update.Info, error) { + service.mu.Lock() + service.state.CheckedAt = service.now() + if service.state.HasPreparedUpdate() { + service.setPreparedReadyLocked() + state := service.state + service.mu.Unlock() + service.notifyAvailability(true) + service.publishSnapshot(state) + return state, err + } + service.setStatusLocked(update.StatusError, service.state.Progress, err.Error()) + state := service.state + service.mu.Unlock() + service.publishSnapshot(state) + return state, err +} + +func (service *Service) publishPrepareError(err error, fallback preparedFallback) (update.Info, error) { + service.mu.Lock() + if fallback.HasPreparedUpdate(service.state.CurrentVersion) { + service.state.PreparedVersion = fallback.Version + service.state.PreparedChangelog = fallback.Changelog + service.setPreparedReadyLocked() + state := service.state + service.mu.Unlock() + service.notifyAvailability(true) + service.publishSnapshot(state) + return state, err + } + service.setStatusLocked(update.StatusError, service.state.Progress, err.Error()) + state := service.state + service.mu.Unlock() + service.publishSnapshot(state) + return state, err +} + func (service *Service) publishError(err error) (update.Info, error) { service.mu.Lock() service.setStatusLocked(update.StatusError, service.state.Progress, err.Error()) @@ -331,6 +542,35 @@ func (service *Service) 
setStatusLocked(status update.Status, progress int, mess service.state.Message = message } +func (service *Service) setPreparedReadyLocked() { + service.state.Status = update.StatusReadyToRestart + service.state.Progress = 100 + service.state.Message = "" +} + +func (service *Service) clearPreparedStateLocked() { + service.state.PreparedVersion = "" + service.state.PreparedChangelog = "" +} + +func (service *Service) shouldAutoPrepareLocked() bool { + if service.autoPrepareInFlight { + return false + } + if service.state.Status != update.StatusAvailable { + return false + } + if strings.TrimSpace(service.state.DownloadURL) == "" { + return false + } + latestVersion := update.NormalizeVersion(service.state.LatestVersion) + if latestVersion == "" { + return false + } + preparedVersion := update.NormalizeVersion(service.state.PreparedVersion) + return preparedVersion == "" || update.CompareVersion(latestVersion, preparedVersion) > 0 +} + func (service *Service) publishState() { service.mu.Lock() state := service.state @@ -368,6 +608,62 @@ func (service *Service) resolveDownloadURLsLocked() []string { return []string{service.state.DownloadURL} } +type preparedFallback struct { + Version string + Changelog string +} + +func (fallback preparedFallback) HasPreparedUpdate(currentVersion string) bool { + preparedVersion := update.NormalizeVersion(fallback.Version) + current := update.NormalizeVersion(currentVersion) + return preparedVersion != "" && update.CompareVersion(preparedVersion, current) > 0 +} + +func (service *Service) capturePreparedFallback() preparedFallback { + service.mu.Lock() + defer service.mu.Unlock() + return service.capturePreparedFallbackLocked() +} + +func (service *Service) capturePreparedFallbackLocked() preparedFallback { + return preparedFallback{ + Version: service.state.PreparedVersion, + Changelog: service.state.PreparedChangelog, + } +} + +func (service *Service) resolveReleaseNotes(ctx context.Context, version string) string { + if 
service.catalog == nil || !isReleaseVersion(version) { + return "" + } + release, err := service.catalog.ResolveAppReleaseByVersion(ctx, version) + if err != nil { + zap.L().Warn("update: resolve release notes failed", + zap.String("version", version), + zap.Error(err), + ) + return "" + } + return strings.TrimSpace(release.Notes) +} + +func isReleaseVersion(version string) bool { + normalized := update.NormalizeVersion(version) + if normalized == "" { + return false + } + parts := strings.Split(normalized, ".") + for _, part := range parts { + if part == "" { + return false + } + if _, err := strconv.Atoi(part); err != nil { + return false + } + } + return true +} + func normalizeSHA256(raw string) string { value := strings.ToLower(strings.TrimSpace(raw)) value = strings.TrimPrefix(value, "sha256:") diff --git a/internal/application/update/service_test.go b/internal/application/update/service_test.go index 9d96103..0cfd4dd 100644 --- a/internal/application/update/service_test.go +++ b/internal/application/update/service_test.go @@ -38,9 +38,16 @@ type installerStub struct { restarted bool selectedDownloadURLs []string selectDownloadInvoked bool + preparedInfo domainupdate.Info + hasPreparedInfo bool + clearPreparedInvoked bool + pendingWhatsNew domainupdate.WhatsNew + hasPendingWhatsNew bool + seenWhatsNewVersion string + markSeenVersion string } -func (stub installerStub) Install(_ context.Context, _ string) error { +func (stub installerStub) Install(_ context.Context, _ string, _ domainupdate.Info) error { return stub.installErr } @@ -57,6 +64,28 @@ func (stub *installerStub) SelectDownloadURLs(_ context.Context, urls []string) return urls } +func (stub *installerStub) PreparedUpdate(_ context.Context) (domainupdate.Info, bool, error) { + return stub.preparedInfo, stub.hasPreparedInfo, nil +} + +func (stub *installerStub) ClearPreparedUpdate(_ context.Context) error { + stub.clearPreparedInvoked = true + return nil +} + +func (stub *installerStub) 
PendingWhatsNew(_ context.Context) (domainupdate.WhatsNew, bool, error) { + return stub.pendingWhatsNew, stub.hasPendingWhatsNew, nil +} + +func (stub *installerStub) SeenWhatsNewVersion(_ context.Context) (string, error) { + return stub.seenWhatsNewVersion, nil +} + +func (stub *installerStub) MarkWhatsNewSeen(_ context.Context, version string) error { + stub.markSeenVersion = version + return nil +} + func (stub *catalogProviderStub) FetchCatalog(_ context.Context, _ softwareupdate.Request) (softwareupdate.Catalog, error) { stub.fetchCount++ if stub.err != nil { @@ -71,6 +100,25 @@ func newCatalogService(provider *catalogProviderStub) *softwareupdate.Service { }) } +type appVersionProviderStub struct { + release softwareupdate.AppRelease + err error +} + +func (stub appVersionProviderStub) FetchAppRelease(_ context.Context, _ softwareupdate.AppRequest) (softwareupdate.AppRelease, error) { + if stub.err != nil { + return softwareupdate.AppRelease{}, stub.err + } + return stub.release, nil +} + +func (stub appVersionProviderStub) FetchAppReleaseByVersion(_ context.Context, _ string) (softwareupdate.AppRelease, error) { + if stub.err != nil { + return softwareupdate.AppRelease{}, stub.err + } + return stub.release, nil +} + func buildCatalog(version string, downloadURL string) softwareupdate.Catalog { return softwareupdate.Catalog{ App: &softwareupdate.AppRelease{ @@ -309,3 +357,224 @@ func TestRestartToApplyPublishesErrorWhenInstallerFails(t *testing.T) { t.Fatalf("expected error message %q, got %q", restartErr.Error(), info.Message) } } + +func TestRestorePreparedUpdateRestoresReadyState(t *testing.T) { + t.Parallel() + + installer := &installerStub{ + hasPreparedInfo: true, + preparedInfo: domainupdate.Info{ + PreparedVersion: "1.2.4", + PreparedChangelog: "Bug fixes", + }, + } + service := NewService(ServiceParams{Installer: installer}) + service.SetCurrentVersion("1.2.3") + + info, err := service.RestorePreparedUpdate(context.Background()) + if err != nil { + 
t.Fatalf("restore prepared update failed: %v", err) + } + if info.Status != domainupdate.StatusReadyToRestart { + t.Fatalf("expected ready_to_restart status, got %q", info.Status) + } + if info.PreparedVersion != "1.2.4" { + t.Fatalf("expected prepared version 1.2.4, got %q", info.PreparedVersion) + } + if info.PreparedChangelog != "Bug fixes" { + t.Fatalf("expected prepared changelog to be restored, got %q", info.PreparedChangelog) + } +} + +func TestRestorePreparedUpdateClearsStalePreparedPlan(t *testing.T) { + t.Parallel() + + installer := &installerStub{ + hasPreparedInfo: true, + preparedInfo: domainupdate.Info{ + PreparedVersion: "1.2.3", + }, + } + service := NewService(ServiceParams{Installer: installer}) + service.SetCurrentVersion("1.2.3") + + info, err := service.RestorePreparedUpdate(context.Background()) + if err != nil { + t.Fatalf("restore prepared update failed: %v", err) + } + if !installer.clearPreparedInvoked { + t.Fatal("expected stale prepared update to be cleared") + } + if info.PreparedVersion != "" { + t.Fatalf("expected prepared version to be cleared, got %q", info.PreparedVersion) + } +} + +func TestCheckForUpdateAutoPreparesLatestVersion(t *testing.T) { + t.Parallel() + + file, err := os.CreateTemp(t.TempDir(), "update-*.zip") + if err != nil { + t.Fatalf("create temp file failed: %v", err) + } + if err := file.Close(); err != nil { + t.Fatalf("close temp file failed: %v", err) + } + + provider := &catalogProviderStub{ + catalog: buildCatalog("1.2.4", "https://example.com/download.zip"), + } + service := NewService(ServiceParams{ + Catalog: newCatalogService(provider), + Downloader: &downloaderStub{path: file.Name()}, + Installer: &installerStub{}, + }) + + if _, err := service.CheckForUpdate(context.Background(), "1.2.3"); err != nil { + t.Fatalf("check for update failed: %v", err) + } + + deadline := time.Now().Add(2 * time.Second) + for time.Now().Before(deadline) { + info := service.State() + if info.Status == 
domainupdate.StatusReadyToRestart { + if info.PreparedVersion != "1.2.4" { + t.Fatalf("expected prepared version 1.2.4, got %q", info.PreparedVersion) + } + return + } + time.Sleep(20 * time.Millisecond) + } + + t.Fatalf("expected auto prepare to reach ready_to_restart, got %q", service.State().Status) +} + +func TestCheckForUpdatePreservesPreparedStateWhenRefreshFails(t *testing.T) { + t.Parallel() + + provider := &catalogProviderStub{err: errors.New("manifest unavailable")} + service := NewService(ServiceParams{Catalog: newCatalogService(provider)}) + service.state = domainupdate.Info{ + Kind: domainupdate.KindApp, + CurrentVersion: "1.2.3", + LatestVersion: "1.2.4", + PreparedVersion: "1.2.4", + PreparedChangelog: "Bug fixes", + Status: domainupdate.StatusReadyToRestart, + Progress: 100, + } + + info, err := service.CheckForUpdate(context.Background(), "1.2.3") + if err == nil { + t.Fatal("expected check to fail") + } + if info.Status != domainupdate.StatusReadyToRestart { + t.Fatalf("expected ready_to_restart status, got %q", info.Status) + } + if info.PreparedVersion != "1.2.4" { + t.Fatalf("expected prepared version to be preserved, got %q", info.PreparedVersion) + } +} + +func TestDownloadUpdateRestoresPreviousPreparedVersionWhenNewerPrepareFails(t *testing.T) { + t.Parallel() + + installerErr := errors.New("prepare latest failed") + service := NewService(ServiceParams{ + Downloader: &downloaderStub{path: "/tmp/dreamcreator-update.zip"}, + Installer: &installerStub{installErr: installerErr}, + }) + service.state = domainupdate.Info{ + Kind: domainupdate.KindApp, + CurrentVersion: "1.2.3", + LatestVersion: "1.2.5", + PreparedVersion: "1.2.4", + PreparedChangelog: "Prepared 1.2.4", + DownloadURL: "https://example.com/dreamcreator-update.zip", + Status: domainupdate.StatusAvailable, + } + + info, err := service.DownloadUpdate(context.Background()) + if !errors.Is(err, installerErr) { + t.Fatalf("expected installer error, got %v", err) + } + if info.Status != 
domainupdate.StatusReadyToRestart { + t.Fatalf("expected ready_to_restart status, got %q", info.Status) + } + if info.PreparedVersion != "1.2.4" { + t.Fatalf("expected prepared version 1.2.4 to be preserved, got %q", info.PreparedVersion) + } + if info.LatestVersion != "1.2.5" { + t.Fatalf("expected latest version 1.2.5 to stay visible, got %q", info.LatestVersion) + } +} + +func TestGetWhatsNewReturnsPendingPreparedNoticeForCurrentVersion(t *testing.T) { + t.Parallel() + + installer := &installerStub{ + hasPendingWhatsNew: true, + pendingWhatsNew: domainupdate.WhatsNew{ + Version: "2.0.7", + Changelog: "## Prepared update", + }, + seenWhatsNewVersion: "2.0.6", + } + service := NewService(ServiceParams{Installer: installer}) + service.SetCurrentVersion("2.0.7") + + notice, err := service.GetWhatsNew(context.Background()) + if err != nil { + t.Fatalf("GetWhatsNew failed: %v", err) + } + if notice.Version != "2.0.7" { + t.Fatalf("expected version 2.0.7, got %q", notice.Version) + } + if notice.Changelog != "## Prepared update" { + t.Fatalf("expected prepared changelog, got %q", notice.Changelog) + } +} + +func TestGetWhatsNewFallsBackToCurrentVersionReleaseNotes(t *testing.T) { + t.Parallel() + + installer := &installerStub{seenWhatsNewVersion: "2.0.6"} + catalog := softwareupdate.NewService(softwareupdate.ServiceParams{ + AppFallbackProvider: appVersionProviderStub{ + release: softwareupdate.AppRelease{ + Version: "2.0.7", + Notes: "## Current release notes", + }, + }, + }) + service := NewService(ServiceParams{ + Catalog: catalog, + Installer: installer, + }) + service.SetCurrentVersion("2.0.7") + + notice, err := service.GetWhatsNew(context.Background()) + if err != nil { + t.Fatalf("GetWhatsNew failed: %v", err) + } + if notice.Version != "2.0.7" { + t.Fatalf("expected version 2.0.7, got %q", notice.Version) + } + if notice.Changelog != "## Current release notes" { + t.Fatalf("expected current release notes, got %q", notice.Changelog) + } +} + +func 
TestDismissWhatsNewMarksSeenVersion(t *testing.T) { + t.Parallel() + + installer := &installerStub{} + service := NewService(ServiceParams{Installer: installer}) + + if err := service.DismissWhatsNew(context.Background(), "2.0.7"); err != nil { + t.Fatalf("DismissWhatsNew failed: %v", err) + } + if installer.markSeenVersion != "2.0.7" { + t.Fatalf("expected seen version 2.0.7, got %q", installer.markSeenVersion) + } +} diff --git a/internal/domain/connectors/connector.go b/internal/domain/connectors/connector.go index fe60756..9cc8907 100644 --- a/internal/domain/connectors/connector.go +++ b/internal/domain/connectors/connector.go @@ -9,6 +9,10 @@ type ConnectorType string const ( ConnectorGoogle ConnectorType = "google" + ConnectorGitHub ConnectorType = "github" + ConnectorReddit ConnectorType = "reddit" + ConnectorZhihu ConnectorType = "zhihu" + ConnectorX ConnectorType = "x" ConnectorXiaohongshu ConnectorType = "xiaohongshu" ConnectorBilibili ConnectorType = "bilibili" ) diff --git a/internal/domain/connectors/errors.go b/internal/domain/connectors/errors.go index 96a7122..6f214cc 100644 --- a/internal/domain/connectors/errors.go +++ b/internal/domain/connectors/errors.go @@ -3,7 +3,9 @@ package connectors import "errors" var ( - ErrConnectorNotFound = errors.New("connector not found") - ErrInvalidConnector = errors.New("invalid connector") - ErrNoCookies = errors.New("no cookies stored") + ErrConnectorNotFound = errors.New("connector not found") + ErrInvalidConnector = errors.New("invalid connector") + ErrNoCookies = errors.New("no cookies stored") + ErrConnectorSessionDead = errors.New("connector browser session ended") + ErrConnectorSessionGone = errors.New("connector session not found") ) diff --git a/internal/domain/externaltools/tool.go b/internal/domain/externaltools/tool.go index a54bd89..e77bfc0 100644 --- a/internal/domain/externaltools/tool.go +++ b/internal/domain/externaltools/tool.go @@ -8,11 +8,10 @@ import ( type ToolName string const ( - 
ToolYTDLP ToolName = "yt-dlp" - ToolFFmpeg ToolName = "ffmpeg" - ToolBun ToolName = "bun" - ToolClawHub ToolName = "clawhub" - ToolPlaywright ToolName = "playwright" + ToolYTDLP ToolName = "yt-dlp" + ToolFFmpeg ToolName = "ffmpeg" + ToolBun ToolName = "bun" + ToolClawHub ToolName = "clawhub" ) type ToolKind string diff --git a/internal/domain/settings/calls.go b/internal/domain/settings/calls.go index 0033ebe..1bcef84 100644 --- a/internal/domain/settings/calls.go +++ b/internal/domain/settings/calls.go @@ -1,27 +1,18 @@ package settings -import domainweb "dreamcreator/internal/domain/web" - const DefaultBrowserColor = "#FF4500" func DefaultCallsToolsConfig() map[string]any { defaultWebFetch := map[string]any{ - "type": "builtin", - "acceptMarkdown": true, - "enableUserAgent": true, - "userAgent": domainweb.DefaultBrowserRequestUserAgent, - "acceptLanguage": domainweb.DefaultBrowserRequestAcceptLanguage, - "playwright": map[string]any{ - "markdown": true, - }, + "headless": true, + "preferredBrowser": "chrome", } defaultBrowser := map[string]any{ - "enabled": true, - "evaluateEnabled": true, - "headless": false, - "noSandbox": false, + "enabled": true, + "headless": true, + "preferredBrowser": "chrome", "ssrfPolicy": map[string]any{ - "dangerouslyAllowPrivateNetwork": true, + "dangerouslyAllowPrivateNetwork": false, }, } return map[string]any{ diff --git a/internal/domain/settings/calls_test.go b/internal/domain/settings/calls_test.go index 6c388fe..7454105 100644 --- a/internal/domain/settings/calls_test.go +++ b/internal/domain/settings/calls_test.go @@ -10,21 +10,14 @@ func TestDefaultCallsToolsConfigIncludesWebFetchDefaults(t *testing.T) { if !ok || fetchRaw == nil { t.Fatalf("expected web_fetch defaults") } - if fetchRaw["acceptMarkdown"] != true { - t.Fatalf("expected acceptMarkdown default true, got %#v", fetchRaw["acceptMarkdown"]) + if fetchRaw["preferredBrowser"] != "chrome" { + t.Fatalf("expected preferredBrowser default chrome, got %#v", 
fetchRaw["preferredBrowser"]) } - if fetchRaw["acceptLanguage"] != "en-US,en;q=0.9" { - t.Fatalf("expected acceptLanguage default en-US,en;q=0.9, got %#v", fetchRaw["acceptLanguage"]) + if fetchRaw["headless"] != true { + t.Fatalf("expected web_fetch headless default true, got %#v", fetchRaw["headless"]) } - if fetchRaw["type"] != "builtin" { - t.Fatalf("expected type default builtin, got %#v", fetchRaw["type"]) - } - playwrightRaw, ok := fetchRaw["playwright"].(map[string]any) - if !ok || playwrightRaw == nil { - t.Fatalf("expected playwright defaults") - } - if playwrightRaw["markdown"] != true { - t.Fatalf("expected playwright markdown default true, got %#v", playwrightRaw["markdown"]) + if _, exists := fetchRaw["type"]; exists { + t.Fatalf("expected legacy web_fetch type to be removed") } webRaw, ok := defaults["web"].(map[string]any) if !ok || webRaw == nil { @@ -46,19 +39,19 @@ func TestDefaultCallsToolsConfigIncludesBrowserDefaults(t *testing.T) { if browserRaw["enabled"] != true { t.Fatalf("expected browser enabled default true, got %#v", browserRaw["enabled"]) } - if browserRaw["evaluateEnabled"] != true { - t.Fatalf("expected browser evaluateEnabled default true, got %#v", browserRaw["evaluateEnabled"]) + if browserRaw["headless"] != true { + t.Fatalf("expected browser headless default true, got %#v", browserRaw["headless"]) } - if browserRaw["headless"] != false { - t.Fatalf("expected browser headless default false, got %#v", browserRaw["headless"]) + if browserRaw["preferredBrowser"] != "chrome" { + t.Fatalf("expected browser preferredBrowser default chrome, got %#v", browserRaw["preferredBrowser"]) } ssrfRaw, ok := browserRaw["ssrfPolicy"].(map[string]any) if !ok || ssrfRaw == nil { t.Fatalf("expected browser ssrfPolicy defaults") } - if ssrfRaw["dangerouslyAllowPrivateNetwork"] != true { + if ssrfRaw["dangerouslyAllowPrivateNetwork"] != false { t.Fatalf( - "expected browser ssrfPolicy.dangerouslyAllowPrivateNetwork true, got %#v", + "expected browser 
ssrfPolicy.dangerouslyAllowPrivateNetwork false, got %#v", ssrfRaw["dangerouslyAllowPrivateNetwork"], ) } @@ -72,11 +65,11 @@ func TestNormalizeToolsConfigAddsWebFetchDefaultsWhenMissing(t *testing.T) { if !ok || fetchRaw == nil { t.Fatalf("expected web_fetch config") } - if fetchRaw["acceptMarkdown"] != true { - t.Fatalf("expected acceptMarkdown default true, got %#v", fetchRaw["acceptMarkdown"]) + if fetchRaw["preferredBrowser"] != "chrome" { + t.Fatalf("expected preferredBrowser default, got %#v", fetchRaw["preferredBrowser"]) } - if fetchRaw["userAgent"] != "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/605.1.15 (KHTML, like Gecko) Version/17.0 Safari/605.1.15" { - t.Fatalf("expected userAgent default, got %#v", fetchRaw["userAgent"]) + if fetchRaw["headless"] != true { + t.Fatalf("expected headless default true, got %#v", fetchRaw["headless"]) } } diff --git a/internal/domain/update/update.go b/internal/domain/update/update.go index 7aebdae..c24125b 100644 --- a/internal/domain/update/update.go +++ b/internal/domain/update/update.go @@ -28,15 +28,23 @@ const ( ) type Info struct { - Kind Kind + Kind Kind + CurrentVersion string + LatestVersion string + Changelog string + DownloadURL string + CheckedAt time.Time + Status Status + Progress int // 0-100 + Message string + PreparedVersion string + PreparedChangelog string +} + +type WhatsNew struct { + Version string CurrentVersion string - LatestVersion string Changelog string - DownloadURL string - CheckedAt time.Time - Status Status - Progress int // 0-100 - Message string } func (info Info) IsUpdateAvailable() bool { @@ -51,6 +59,18 @@ func (info Info) IsError() bool { return info.Status == StatusError } +func (info Info) HasPreparedUpdate() bool { + current := NormalizeVersion(info.CurrentVersion) + prepared := NormalizeVersion(info.PreparedVersion) + return prepared != "" && CompareVersion(prepared, current) > 0 +} + +func (info Info) HasRemoteUpdate() bool { + current := 
NormalizeVersion(info.CurrentVersion) + latest := NormalizeVersion(info.LatestVersion) + return latest != "" && CompareVersion(latest, current) > 0 +} + func NormalizeVersion(version string) string { trimmed := strings.TrimSpace(version) if trimmed == "" { diff --git a/internal/infrastructure/update/app_fallback_provider.go b/internal/infrastructure/update/app_fallback_provider.go index 713f056..1849472 100644 --- a/internal/infrastructure/update/app_fallback_provider.go +++ b/internal/infrastructure/update/app_fallback_provider.go @@ -12,6 +12,18 @@ func (client *GithubReleaseClient) FetchAppRelease(ctx context.Context, request if err != nil { return softwareupdate.AppRelease{}, err } + return client.toAppRelease(release), nil +} + +func (client *GithubReleaseClient) FetchAppReleaseByVersion(ctx context.Context, version string) (softwareupdate.AppRelease, error) { + release, err := client.FetchReleaseByVersion(ctx, version) + if err != nil { + return softwareupdate.AppRelease{}, err + } + return client.toAppRelease(release), nil +} + +func (client *GithubReleaseClient) toAppRelease(release githubRelease) softwareupdate.AppRelease { assetURL := selectAsset(release.Assets) sources := make([]softwareupdate.DownloadSource, 0, 2) if strings.TrimSpace(assetURL) != "" { @@ -42,5 +54,5 @@ func (client *GithubReleaseClient) FetchAppRelease(ctx context.Context, request Asset: softwareupdate.Asset{ Sources: sources, }, - }, nil + } } diff --git a/internal/infrastructure/update/github_client.go b/internal/infrastructure/update/github_client.go index 20229b7..75bc8d1 100644 --- a/internal/infrastructure/update/github_client.go +++ b/internal/infrastructure/update/github_client.go @@ -5,6 +5,7 @@ import ( "encoding/json" "fmt" "net/http" + "net/url" "runtime" "strings" "time" @@ -16,11 +17,12 @@ const ( ) type GithubReleaseClient struct { - client *http.Client + client *http.Client + releasesURL string } func NewGithubReleaseClient(client *http.Client) *GithubReleaseClient { - 
return &GithubReleaseClient{client: client} + return &GithubReleaseClient{client: client, releasesURL: defaultReleasesURL} } type githubRelease struct { @@ -44,28 +46,82 @@ func (client *GithubReleaseClient) FetchLatestRelease(ctx context.Context) (gith if client == nil || client.client == nil { return githubRelease{}, fmt.Errorf("http client not configured") } - req, err := http.NewRequestWithContext(ctx, http.MethodGet, defaultReleasesURL+"/latest", nil) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, client.baseReleasesURL()+"/latest", nil) if err != nil { return githubRelease{}, err } req.Header.Set("Accept", "application/vnd.github+json") + release, err := client.fetchRelease(req, "github latest release") + if err != nil { + return githubRelease{}, err + } + if strings.TrimSpace(release.TagName) == "" { + return githubRelease{}, fmt.Errorf("no latest release found") + } + return release, nil +} + +func (client *GithubReleaseClient) FetchReleaseByVersion(ctx context.Context, version string) (githubRelease, error) { + if client == nil || client.client == nil { + return githubRelease{}, fmt.Errorf("http client not configured") + } + normalized := strings.TrimSpace(version) + if normalized == "" { + return githubRelease{}, fmt.Errorf("release version is empty") + } + + candidates := []string{normalized} + if !strings.HasPrefix(strings.ToLower(normalized), "v") { + candidates = append(candidates, "v"+normalized) + } + + var lastErr error + for _, candidate := range candidates { + req, err := http.NewRequestWithContext( + ctx, + http.MethodGet, + client.baseReleasesURL()+"/tags/"+url.PathEscape(candidate), + nil, + ) + if err != nil { + return githubRelease{}, err + } + req.Header.Set("Accept", "application/vnd.github+json") + + release, err := client.fetchRelease(req, fmt.Sprintf("github release %q", candidate)) + if err == nil { + return release, nil + } + lastErr = err + } + if lastErr == nil { + lastErr = fmt.Errorf("release %q not found", 
normalized) + } + return githubRelease{}, lastErr +} + +func (client *GithubReleaseClient) baseReleasesURL() string { + if client == nil || strings.TrimSpace(client.releasesURL) == "" { + return defaultReleasesURL + } + return strings.TrimSpace(client.releasesURL) +} + +func (client *GithubReleaseClient) fetchRelease(req *http.Request, label string) (githubRelease, error) { resp, err := client.client.Do(req) if err != nil { return githubRelease{}, err } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { - return githubRelease{}, fmt.Errorf("github latest release http %d", resp.StatusCode) + return githubRelease{}, fmt.Errorf("%s http %d", label, resp.StatusCode) } var release githubRelease if err := json.NewDecoder(resp.Body).Decode(&release); err != nil { return githubRelease{}, err } - if strings.TrimSpace(release.TagName) == "" { - return githubRelease{}, fmt.Errorf("no latest release found") - } return release, nil } diff --git a/internal/infrastructure/update/installer.go b/internal/infrastructure/update/installer.go index 84755d9..bd58811 100644 --- a/internal/infrastructure/update/installer.go +++ b/internal/infrastructure/update/installer.go @@ -15,6 +15,9 @@ import ( "slices" "strconv" "strings" + "time" + + domainupdate "dreamcreator/internal/domain/update" ) var ErrPreparedUpdateNotFound = fmt.Errorf("prepared update not found") @@ -28,12 +31,14 @@ const ( ) type PlatformInstaller struct { - stateDir string - planPath string - goos string - goarch string - executablePath func() (string, error) - startDetached func(name string, args []string) error + stateDir string + planPath string + whatsNewPendingPath string + whatsNewSeenPath string + goos string + goarch string + executablePath func() (string, error) + startDetached func(name string, args []string) error } type stagedPlan struct { @@ -45,6 +50,13 @@ type stagedPlan struct { RelaunchPath string `json:"relaunchPath"` FallbackPath string `json:"fallbackPath,omitempty"` InstallDir string 
`json:"installDir,omitempty"` + Version string `json:"version,omitempty"` + Changelog string `json:"changelog,omitempty"` +} + +type whatsNewSeenState struct { + Version string `json:"version"` + SeenAt string `json:"seenAt,omitempty"` } func NewInstaller(statePath string) (*PlatformInstaller, error) { @@ -61,12 +73,14 @@ func NewInstaller(statePath string) (*PlatformInstaller, error) { return nil, err } return &PlatformInstaller{ - stateDir: stateDir, - planPath: filepath.Join(stateDir, "update_install_plan.json"), - goos: runtime.GOOS, - goarch: runtime.GOARCH, - executablePath: os.Executable, - startDetached: startDetachedCommand, + stateDir: stateDir, + planPath: filepath.Join(stateDir, "update_install_plan.json"), + whatsNewPendingPath: filepath.Join(stateDir, "pending_whats_new.json"), + whatsNewSeenPath: filepath.Join(stateDir, "whats_new_seen.json"), + goos: runtime.GOOS, + goarch: runtime.GOARCH, + executablePath: os.Executable, + startDetached: startDetachedCommand, }, nil } @@ -84,7 +98,7 @@ func (installer *PlatformInstaller) SelectDownloadURLs(_ context.Context, urls [ return preferWindowsPortableDownloadURLs(urls) } -func (installer *PlatformInstaller) Install(ctx context.Context, artifactPath string) error { +func (installer *PlatformInstaller) Install(ctx context.Context, artifactPath string, prepared domainupdate.Info) error { if installer == nil { return fmt.Errorf("installer not configured") } @@ -92,18 +106,29 @@ func (installer *PlatformInstaller) Install(ctx context.Context, artifactPath st if normalizedArtifact == "" { return fmt.Errorf("artifact path is empty") } - if err := installer.cleanupStagedUpdate(); err != nil { + + previousPlan, err := installer.loadPlan() + hasPreviousPlan := err == nil + if err != nil && !errors.Is(err, ErrPreparedUpdateNotFound) { return err } + var installErr error switch installer.goos { case "windows": - return installer.prepareWindowsUpdate(normalizedArtifact) + installErr = 
installer.prepareWindowsUpdate(normalizedArtifact, prepared) case "darwin": - return installer.prepareMacUpdate(ctx, normalizedArtifact) + installErr = installer.prepareMacUpdate(ctx, normalizedArtifact, prepared) default: - return fmt.Errorf("update install is not supported on %s", installer.goos) + installErr = fmt.Errorf("update install is not supported on %s", installer.goos) + } + if installErr != nil { + return installErr } + if hasPreviousPlan && strings.TrimSpace(previousPlan.StageDir) != "" { + _ = os.RemoveAll(previousPlan.StageDir) + } + return nil } func (installer *PlatformInstaller) RestartToApply(_ context.Context) error { @@ -125,7 +150,7 @@ func (installer *PlatformInstaller) RestartToApply(_ context.Context) error { } } -func (installer *PlatformInstaller) prepareWindowsUpdate(artifactPath string) error { +func (installer *PlatformInstaller) prepareWindowsUpdate(artifactPath string, prepared domainupdate.Info) error { currentExe, err := installer.currentExecutable() if err != nil { return err @@ -143,6 +168,8 @@ func (installer *PlatformInstaller) prepareWindowsUpdate(artifactPath string) er TargetPath: currentExe, RelaunchPath: currentExe, InstallDir: filepath.Dir(currentExe), + Version: strings.TrimSpace(prepared.LatestVersion), + Changelog: prepared.Changelog, } switch strings.ToLower(filepath.Ext(artifactName)) { @@ -168,7 +195,7 @@ func (installer *PlatformInstaller) prepareWindowsUpdate(artifactPath string) er return installer.savePlan(plan) } -func (installer *PlatformInstaller) prepareMacUpdate(ctx context.Context, artifactPath string) error { +func (installer *PlatformInstaller) prepareMacUpdate(ctx context.Context, artifactPath string, prepared domainupdate.Info) error { currentExe, err := installer.currentExecutable() if err != nil { return err @@ -204,9 +231,112 @@ func (installer *PlatformInstaller) prepareMacUpdate(ctx context.Context, artifa TargetPath: targetBundle, RelaunchPath: targetBundle, FallbackPath: currentBundle, + 
Version: strings.TrimSpace(prepared.LatestVersion), + Changelog: prepared.Changelog, }) } +func (installer *PlatformInstaller) PreparedUpdate(_ context.Context) (domainupdate.Info, bool, error) { + if installer == nil { + return domainupdate.Info{}, false, fmt.Errorf("installer not configured") + } + plan, err := installer.loadPlan() + if err != nil { + if errors.Is(err, ErrPreparedUpdateNotFound) { + return domainupdate.Info{}, false, nil + } + return domainupdate.Info{}, false, err + } + return domainupdate.Info{ + Kind: domainupdate.KindApp, + Status: domainupdate.StatusReadyToRestart, + LatestVersion: strings.TrimSpace(plan.Version), + PreparedVersion: strings.TrimSpace(plan.Version), + Changelog: plan.Changelog, + PreparedChangelog: plan.Changelog, + Progress: 100, + }, true, nil +} + +func (installer *PlatformInstaller) ClearPreparedUpdate(_ context.Context) error { + if installer == nil { + return fmt.Errorf("installer not configured") + } + return installer.cleanupStagedUpdate() +} + +func (installer *PlatformInstaller) PendingWhatsNew(_ context.Context) (domainupdate.WhatsNew, bool, error) { + if installer == nil { + return domainupdate.WhatsNew{}, false, fmt.Errorf("installer not configured") + } + data, err := os.ReadFile(installer.whatsNewPendingPath) + if err != nil { + if os.IsNotExist(err) { + return domainupdate.WhatsNew{}, false, nil + } + return domainupdate.WhatsNew{}, false, err + } + + var plan stagedPlan + if err := json.Unmarshal(data, &plan); err != nil { + return domainupdate.WhatsNew{}, false, err + } + version := strings.TrimSpace(plan.Version) + if version == "" { + return domainupdate.WhatsNew{}, false, nil + } + return domainupdate.WhatsNew{ + Version: version, + Changelog: plan.Changelog, + }, true, nil +} + +func (installer *PlatformInstaller) SeenWhatsNewVersion(_ context.Context) (string, error) { + if installer == nil { + return "", fmt.Errorf("installer not configured") + } + data, err := os.ReadFile(installer.whatsNewSeenPath) + 
if err != nil { + if os.IsNotExist(err) { + return "", nil + } + return "", err + } + var seen whatsNewSeenState + if err := json.Unmarshal(data, &seen); err != nil { + return "", err + } + return strings.TrimSpace(seen.Version), nil +} + +func (installer *PlatformInstaller) MarkWhatsNewSeen(_ context.Context, version string) error { + if installer == nil { + return fmt.Errorf("installer not configured") + } + normalized := strings.TrimSpace(version) + if normalized == "" { + return nil + } + data, err := json.MarshalIndent(whatsNewSeenState{ + Version: normalized, + SeenAt: time.Now().UTC().Format(time.RFC3339), + }, "", " ") + if err != nil { + return err + } + if err := os.WriteFile(installer.whatsNewSeenPath, data, 0o600); err != nil { + return err + } + pending, found, err := installer.PendingWhatsNew(context.Background()) + if err != nil { + return err + } + if found && domainupdate.CompareVersion(pending.Version, normalized) <= 0 { + _ = os.Remove(installer.whatsNewPendingPath) + } + return nil +} + func (installer *PlatformInstaller) restartWindows(plan stagedPlan) error { scriptPath := filepath.Join(installer.stateDir, "apply_update.ps1") if err := os.WriteFile(scriptPath, []byte(windowsApplyScript), 0o600); err != nil { @@ -224,6 +354,7 @@ func (installer *PlatformInstaller) restartWindows(plan stagedPlan) error { plan.InstallDir, plan.StageDir, installer.planPath, + installer.whatsNewPendingPath, } return installer.startDetached("powershell.exe", args) } @@ -252,6 +383,7 @@ func (installer *PlatformInstaller) restartDarwin(plan stagedPlan) error { fallbackPath, plan.StageDir, installer.planPath, + installer.whatsNewPendingPath, } return installer.startDetached("/bin/sh", args) } @@ -563,7 +695,8 @@ const windowsApplyScript = `param( [Parameter(Mandatory = $true)][string]$TargetPath, [Parameter(Mandatory = $true)][string]$InstallDir, [Parameter(Mandatory = $true)][string]$StageDir, - [Parameter(Mandatory = $true)][string]$PlanPath + [Parameter(Mandatory = 
$true)][string]$PlanPath, + [Parameter(Mandatory = $true)][string]$PendingWhatsNewPath ) $ErrorActionPreference = "Stop" @@ -667,13 +800,18 @@ try { } } + try { + Copy-Item -LiteralPath $PlanPath -Destination $PendingWhatsNewPath -Force -ErrorAction Stop + } catch { + } + Start-Process -FilePath $TargetPath -WorkingDirectory $InstallDir | Out-Null Remove-Item -LiteralPath $PlanPath -Force -ErrorAction SilentlyContinue Remove-Item -LiteralPath $StageDir -Recurse -Force -ErrorAction SilentlyContinue } catch { try { if (Test-Path -LiteralPath $TargetPath) { - Start-Process -FilePath $TargetPath -WorkingDirectory $InstallDir | Out-Null + Start-Process -FilePath $TargetPath -ArgumentList @("--skip-prepared-update-once") -WorkingDirectory $InstallDir | Out-Null } } catch { } @@ -691,6 +829,7 @@ RELAUNCH_APP="$4" FALLBACK_APP="$5" STAGE_DIR="$6" PLAN_PATH="$7" +PENDING_WHATS_NEW_PATH="$8" BACKUP_APP="${TARGET_APP}.old" while kill -0 "$PARENT_PID" 2>/dev/null; do @@ -699,8 +838,13 @@ done relaunch_app() { APP_PATH="$1" + shift || true if [ -n "$APP_PATH" ] && [ -d "$APP_PATH" ]; then - open "$APP_PATH" >/dev/null 2>&1 || true + if [ "$#" -gt 0 ]; then + open -a "$APP_PATH" --args "$@" >/dev/null 2>&1 || true + else + open "$APP_PATH" >/dev/null 2>&1 || true + fi fi } @@ -712,9 +856,9 @@ restore_backup() { } relaunch_fallback() { - relaunch_app "$FALLBACK_APP" + relaunch_app "$FALLBACK_APP" "--skip-prepared-update-once" if [ "$FALLBACK_APP" != "$TARGET_APP" ]; then - relaunch_app "$TARGET_APP" + relaunch_app "$TARGET_APP" "--skip-prepared-update-once" fi } @@ -760,6 +904,7 @@ if ! install_direct; then fi /usr/bin/xattr -dr com.apple.quarantine "$TARGET_APP" >/dev/null 2>&1 || true +cp "$PLAN_PATH" "$PENDING_WHATS_NEW_PATH" >/dev/null 2>&1 || true if ! 
open "$RELAUNCH_APP"; then relaunch_fallback exit 1 @@ -770,7 +915,12 @@ rm -rf "$STAGE_DIR" ` var _ interface { - Install(context.Context, string) error + Install(context.Context, string, domainupdate.Info) error RestartToApply(context.Context) error SelectDownloadURLs(context.Context, []string) []string + PreparedUpdate(context.Context) (domainupdate.Info, bool, error) + ClearPreparedUpdate(context.Context) error + PendingWhatsNew(context.Context) (domainupdate.WhatsNew, bool, error) + SeenWhatsNewVersion(context.Context) (string, error) + MarkWhatsNewSeen(context.Context, string) error } = (*PlatformInstaller)(nil) diff --git a/internal/infrastructure/update/installer_test.go b/internal/infrastructure/update/installer_test.go index d60232d..172486f 100644 --- a/internal/infrastructure/update/installer_test.go +++ b/internal/infrastructure/update/installer_test.go @@ -1,6 +1,8 @@ package update import ( + "context" + "os" "path/filepath" "testing" ) @@ -34,8 +36,10 @@ func TestRestartDarwinUsesExplicitRelaunchAndFallbackPaths(t *testing.T) { ) stateDir := t.TempDir() installer := &PlatformInstaller{ - stateDir: stateDir, - planPath: filepath.Join(stateDir, "update_install_plan.json"), + stateDir: stateDir, + planPath: filepath.Join(stateDir, "update_install_plan.json"), + whatsNewPendingPath: filepath.Join(stateDir, "pending_whats_new.json"), + whatsNewSeenPath: filepath.Join(stateDir, "whats_new_seen.json"), startDetached: func(name string, args []string) error { capturedName = name capturedArgs = append([]string(nil), args...) 
@@ -57,7 +61,7 @@ func TestRestartDarwinUsesExplicitRelaunchAndFallbackPaths(t *testing.T) { if capturedName != "/bin/sh" { t.Fatalf("unexpected restart helper: %q", capturedName) } - if len(capturedArgs) != 8 { + if len(capturedArgs) != 9 { t.Fatalf("unexpected helper args: %#v", capturedArgs) } if capturedArgs[4] != plan.RelaunchPath { @@ -66,6 +70,9 @@ func TestRestartDarwinUsesExplicitRelaunchAndFallbackPaths(t *testing.T) { if capturedArgs[5] != plan.FallbackPath { t.Fatalf("expected fallback path %q, got %q", plan.FallbackPath, capturedArgs[5]) } + if capturedArgs[8] != installer.whatsNewPendingPath { + t.Fatalf("expected pending what's new path %q, got %q", installer.whatsNewPendingPath, capturedArgs[8]) + } } func TestRestartDarwinDefaultsRelaunchAndFallbackToTarget(t *testing.T) { @@ -74,8 +81,10 @@ func TestRestartDarwinDefaultsRelaunchAndFallbackToTarget(t *testing.T) { var capturedArgs []string stateDir := t.TempDir() installer := &PlatformInstaller{ - stateDir: stateDir, - planPath: filepath.Join(stateDir, "update_install_plan.json"), + stateDir: stateDir, + planPath: filepath.Join(stateDir, "update_install_plan.json"), + whatsNewPendingPath: filepath.Join(stateDir, "pending_whats_new.json"), + whatsNewSeenPath: filepath.Join(stateDir, "whats_new_seen.json"), startDetached: func(_ string, args []string) error { capturedArgs = append([]string(nil), args...) 
return nil @@ -91,7 +100,7 @@ func TestRestartDarwinDefaultsRelaunchAndFallbackToTarget(t *testing.T) { t.Fatalf("restartDarwin failed: %v", err) } - if len(capturedArgs) != 8 { + if len(capturedArgs) != 9 { t.Fatalf("unexpected helper args: %#v", capturedArgs) } if capturedArgs[4] != plan.TargetPath { @@ -100,4 +109,76 @@ func TestRestartDarwinDefaultsRelaunchAndFallbackToTarget(t *testing.T) { if capturedArgs[5] != plan.TargetPath { t.Fatalf("expected default fallback path %q, got %q", plan.TargetPath, capturedArgs[5]) } + if capturedArgs[8] != installer.whatsNewPendingPath { + t.Fatalf("expected pending what's new path %q, got %q", installer.whatsNewPendingPath, capturedArgs[8]) + } +} + +func TestPendingWhatsNewReadsCopiedPlan(t *testing.T) { + t.Parallel() + + stateDir := t.TempDir() + installer := &PlatformInstaller{ + stateDir: stateDir, + planPath: filepath.Join(stateDir, "update_install_plan.json"), + whatsNewPendingPath: filepath.Join(stateDir, "pending_whats_new.json"), + whatsNewSeenPath: filepath.Join(stateDir, "whats_new_seen.json"), + } + if err := installer.savePlan(stagedPlan{ + Version: "2.0.7", + Changelog: "## Updated", + }); err != nil { + t.Fatalf("savePlan failed: %v", err) + } + data, err := os.ReadFile(installer.planPath) + if err != nil { + t.Fatalf("read plan failed: %v", err) + } + if err := os.WriteFile(installer.whatsNewPendingPath, data, 0o600); err != nil { + t.Fatalf("write pending file failed: %v", err) + } + + notice, found, err := installer.PendingWhatsNew(context.Background()) + if err != nil { + t.Fatalf("PendingWhatsNew failed: %v", err) + } + if !found { + t.Fatal("expected pending what's new notice") + } + if notice.Version != "2.0.7" { + t.Fatalf("expected version 2.0.7, got %q", notice.Version) + } + if notice.Changelog != "## Updated" { + t.Fatalf("expected changelog to be preserved, got %q", notice.Changelog) + } +} + +func TestMarkWhatsNewSeenPersistsVersionAndClearsCoveredPendingNotice(t *testing.T) { + t.Parallel() + 
+ stateDir := t.TempDir() + installer := &PlatformInstaller{ + stateDir: stateDir, + planPath: filepath.Join(stateDir, "update_install_plan.json"), + whatsNewPendingPath: filepath.Join(stateDir, "pending_whats_new.json"), + whatsNewSeenPath: filepath.Join(stateDir, "whats_new_seen.json"), + } + if err := os.WriteFile(installer.whatsNewPendingPath, []byte(`{"version":"2.0.7","changelog":"hi"}`), 0o600); err != nil { + t.Fatalf("seed pending file failed: %v", err) + } + + if err := installer.MarkWhatsNewSeen(context.Background(), "2.0.7"); err != nil { + t.Fatalf("MarkWhatsNewSeen failed: %v", err) + } + + seenVersion, err := installer.SeenWhatsNewVersion(context.Background()) + if err != nil { + t.Fatalf("SeenWhatsNewVersion failed: %v", err) + } + if seenVersion != "2.0.7" { + t.Fatalf("expected seen version 2.0.7, got %q", seenVersion) + } + if _, err := os.Stat(installer.whatsNewPendingPath); !os.IsNotExist(err) { + t.Fatalf("expected pending file to be removed, got err=%v", err) + } } diff --git a/internal/infrastructure/update/manifest_provider_test.go b/internal/infrastructure/update/manifest_provider_test.go index b00905f..51f1750 100644 --- a/internal/infrastructure/update/manifest_provider_test.go +++ b/internal/infrastructure/update/manifest_provider_test.go @@ -22,7 +22,7 @@ func TestManifestCatalogProviderSelectsCurrentPlatformAssets(t *testing.T) { "channels":{ "stable":{ "app":{ - "source":{"provider":"github-release","owner":"arnoldhao","repo":"dreamcreator"}, + "source":{"provider":"github-release","owner":"example-owner","repo":"dreamcreator"}, "version":"1.3.0", "publishedAt":"2026-04-06T00:00:00Z", "platforms":{ diff --git a/internal/infrastructure/update/tool_fallback_provider.go b/internal/infrastructure/update/tool_fallback_provider.go index e4f291c..831ed80 100644 --- a/internal/infrastructure/update/tool_fallback_provider.go +++ b/internal/infrastructure/update/tool_fallback_provider.go @@ -61,8 +61,6 @@ func (provider *ToolFallbackProvider) 
FetchToolRelease(ctx context.Context, requ return provider.fetchGitHubToolRelease(ctx, request.Name, "oven-sh", "bun") case externaltools.ToolClawHub: return provider.fetchNPMPackageRelease(ctx, request.Name, "clawhub") - case externaltools.ToolPlaywright: - return softwareupdate.ToolRelease{}, softwareupdate.ErrReleaseNotFound default: return softwareupdate.ToolRelease{}, externaltools.ErrInvalidTool } diff --git a/internal/presentation/wails/connectors_handler.go b/internal/presentation/wails/connectors_handler.go index 164f1d4..5451c57 100644 --- a/internal/presentation/wails/connectors_handler.go +++ b/internal/presentation/wails/connectors_handler.go @@ -36,21 +36,29 @@ func (handler *ConnectorsHandler) ClearConnector(ctx context.Context, request dt return handler.service.ClearConnector(ctx, request) } -func (handler *ConnectorsHandler) ConnectConnector(ctx context.Context, request dto.ConnectConnectorRequest) (dto.Connector, error) { - result, err := handler.service.ConnectConnector(ctx, request) +func (handler *ConnectorsHandler) StartConnectorConnect(ctx context.Context, request dto.StartConnectorConnectRequest) (dto.StartConnectorConnectResult, error) { + return handler.service.StartConnectorConnect(ctx, request) +} + +func (handler *ConnectorsHandler) FinishConnectorConnect(ctx context.Context, request dto.FinishConnectorConnectRequest) (dto.FinishConnectorConnectResult, error) { + result, err := handler.service.FinishConnectorConnect(ctx, request) if err != nil { - return dto.Connector{}, err + return dto.FinishConnectorConnectResult{}, err } - if handler.telemetry != nil && result.Status == "connected" { - handler.telemetry.TrackConnectorConnected(ctx, result.Type) + if handler.telemetry != nil && result.Saved && result.Connector.Status == "connected" { + handler.telemetry.TrackConnectorConnected(ctx, result.Connector.Type) } return result, nil } -func (handler *ConnectorsHandler) OpenConnectorSite(ctx context.Context, request 
dto.OpenConnectorSiteRequest) error { - return handler.service.OpenConnectorSite(ctx, request) +func (handler *ConnectorsHandler) CancelConnectorConnect(ctx context.Context, request dto.CancelConnectorConnectRequest) error { + return handler.service.CancelConnectorConnect(ctx, request) } -func (handler *ConnectorsHandler) InstallPlaywright(ctx context.Context) error { - return handler.service.InstallPlaywright(ctx) +func (handler *ConnectorsHandler) GetConnectorConnectSession(ctx context.Context, request dto.GetConnectorConnectSessionRequest) (dto.ConnectorConnectSession, error) { + return handler.service.GetConnectorConnectSession(ctx, request) +} + +func (handler *ConnectorsHandler) OpenConnectorSite(ctx context.Context, request dto.OpenConnectorSiteRequest) error { + return handler.service.OpenConnectorSite(ctx, request) } diff --git a/internal/presentation/wails/settings_handler.go b/internal/presentation/wails/settings_handler.go index 3c15de9..7e08a92 100644 --- a/internal/presentation/wails/settings_handler.go +++ b/internal/presentation/wails/settings_handler.go @@ -6,6 +6,7 @@ import ( "strings" "time" + gatewaytools "dreamcreator/internal/application/gateway/tools" "dreamcreator/internal/application/settings/dto" "dreamcreator/internal/application/settings/service" "dreamcreator/internal/domain/settings" @@ -84,6 +85,19 @@ func (handler *SettingsHandler) UpdateSettings(ctx context.Context, request dto. zap.L().Info("proxy applied", proxyFields(updated.Proxy)...) 
} + if hasPrevious && gatewaytools.BrowserToolRuntimeConfigChanged(previousSettings.Tools, updated.Tools) { + gatewaytools.CleanupAllBrowserToolSessions() + zap.L().Info( + "browser tool sessions reset after browser runtime settings change", + zap.Bool("previousEnabled", resolveSettingsBrowserBool(previousSettings.Tools, "enabled")), + zap.Bool("currentEnabled", resolveSettingsBrowserBool(updated.Tools, "enabled")), + zap.Bool("previousHeadless", resolveSettingsBrowserBool(previousSettings.Tools, "headless")), + zap.Bool("currentHeadless", resolveSettingsBrowserBool(updated.Tools, "headless")), + zap.String("previousPreferredBrowser", resolveSettingsBrowserString(previousSettings.Tools, "preferredBrowser")), + zap.String("currentPreferredBrowser", resolveSettingsBrowserString(updated.Tools, "preferredBrowser")), + ) + } + handler.windows.ApplySettings(updated) return updated, nil } @@ -250,6 +264,24 @@ func proxyFields(proxyDTO dto.Proxy) []zap.Field { } } +func resolveSettingsBrowserBool(config map[string]any, key string) bool { + browser, ok := config["browser"].(map[string]any) + if !ok || browser == nil { + return false + } + value, _ := browser[key].(bool) + return value +} + +func resolveSettingsBrowserString(config map[string]any, key string) string { + browser, ok := config["browser"].(map[string]any) + if !ok || browser == nil { + return "" + } + value, _ := browser[key].(string) + return strings.TrimSpace(value) +} + func (handler *SettingsHandler) rollbackSettings(ctx context.Context, previous dto.Settings) { _, err := handler.service.UpdateSettings(ctx, dto.UpdateSettingsRequest{ Appearance: &previous.Appearance, diff --git a/internal/presentation/wails/update_handler.go b/internal/presentation/wails/update_handler.go index 6470e40..01ed03b 100644 --- a/internal/presentation/wails/update_handler.go +++ b/internal/presentation/wails/update_handler.go @@ -34,6 +34,10 @@ func (handler *UpdateHandler) GetState(_ context.Context) update.Info { return 
handler.service.State() } +func (handler *UpdateHandler) GetWhatsNew(ctx context.Context) (update.WhatsNew, error) { + return handler.service.GetWhatsNew(ctx) +} + func (handler *UpdateHandler) CheckForUpdate(ctx context.Context, currentVersion string) (update.Info, error) { return handler.service.CheckForUpdate(ctx, currentVersion) } @@ -41,7 +45,11 @@ func (handler *UpdateHandler) CheckForUpdate(ctx context.Context, currentVersion func (handler *UpdateHandler) DownloadUpdate(ctx context.Context) (update.Info, error) { info, err := handler.service.DownloadUpdate(ctx) if err == nil && handler.telemetry != nil && info.Status == update.StatusReadyToRestart { - handler.telemetry.TrackUpdateReadyToRestart(ctx, info.LatestVersion) + latestVersion := info.PreparedVersion + if latestVersion == "" { + latestVersion = info.LatestVersion + } + handler.telemetry.TrackUpdateReadyToRestart(ctx, latestVersion) } return info, err } @@ -56,3 +64,7 @@ func (handler *UpdateHandler) RestartToApply(ctx context.Context) (update.Info, } return info, err } + +func (handler *UpdateHandler) DismissWhatsNew(ctx context.Context, version string) error { + return handler.service.DismissWhatsNew(ctx, version) +} diff --git a/main.go b/main.go index 860dcac..b2a18bf 100644 --- a/main.go +++ b/main.go @@ -1,6 +1,7 @@ package main import ( + "context" "embed" "log" "os" @@ -21,6 +22,14 @@ func main() { } }() + appliedPreparedUpdate, err := app.TryApplyPreparedUpdateOnLaunch(context.Background(), os.Args[1:]) + if err != nil { + log.Fatal(err) + } + if appliedPreparedUpdate { + return + } + application, err := app.CreateApplication(assets) if err != nil { log.Fatal(err)