diff --git a/README.md b/README.md index b8f39bd..374b1cc 100644 --- a/README.md +++ b/README.md @@ -185,6 +185,15 @@ Suggestions generated: 3. chore: refresh package metadata ``` +### Stream suggestions as they are generated + +Use `--stream` to print LLM output incrementally instead of waiting behind a spinner. Supported for OpenAI-compatible and Anthropic providers; use non-streaming mode for Cohere. Pair with `--yes` for a non-interactive workflow that streams output and auto-commits the first suggestion. + +```bash +commit-echo suggest --stream +commit-echo suggest --stream --yes +``` + ### Inspect suggestion diagnostics with `--verbose` Use verbose mode when you want to confirm which model handled the request, how much commit history was folded into the style profile, or whether the diff had to be truncated before sending it to the provider. diff --git a/src/commands/batch.ts b/src/commands/batch.ts new file mode 100644 index 0000000..a94f134 --- /dev/null +++ b/src/commands/batch.ts @@ -0,0 +1,481 @@ +import { existsSync, readdirSync, statSync, writeFileSync, unlinkSync } from 'node:fs'; +import { basename, join } from 'node:path'; +import { execSync, spawnSync } from 'node:child_process'; +import { tmpdir } from 'node:os'; +import { intro, outro, confirm, select, text, isCancel } from '@clack/prompts'; +import pc from 'picocolors'; +import { loadOrPromptConfig } from '../config/store.js'; +import { assertApiKeyAvailable, generateSuggestions } from '../llm/client.js'; +import { buildProfile, appendEntry } from '../history/store.js'; +import type { Config, Suggestion } from '../types.js'; + +export interface BatchResult { + repo: string; + repoName: string; + status: 'success' | 'skipped' | 'failed'; + message?: string; +} + +/** + * Scan a directory for git repositories (directories containing a `.git` folder). + * When `recursive` is true, descends into subdirectories to find nested repos. + */ +export function findGitRepositories(rootDir: string, recursive: boolean): string[] { + const repos: string[] = []; + + if (!existsSync(rootDir)) return repos; + + // If rootDir itself is a git repo, return it directly + if (existsSync(join(rootDir, '.git'))) { + repos.push(rootDir); + return repos.sort(); + } + + let entries; + try { + entries = readdirSync(rootDir, { withFileTypes: true }); + } catch { + return repos; // skip unreadable directories + } + + for (const entry of entries) { + if (!entry.isDirectory()) continue; + if (entry.name.startsWith('.')) continue; + + const fullPath = join(rootDir, entry.name); + + if (existsSync(join(fullPath, '.git'))) { + repos.push(fullPath); + } else if (recursive) { + repos.push(...findGitRepositories(fullPath, true)); + } + } + + return repos.sort(); +} + +/** + * Check whether a git repository at `cwd` has staged or unstaged changes. + */ +export function gitHasChanges(cwd: string): { staged: boolean; unstaged: boolean } { + let staged = false; + let unstaged = false; + + try { + execSync('git diff --cached --quiet', { cwd, stdio: 'pipe' }); + } catch { + staged = true; + } + + try { + execSync('git diff --quiet', { cwd, stdio: 'pipe' }); + } catch { + unstaged = true; + } + + return { staged, unstaged }; +} + +/** + * Get the git diff for a repository at `cwd`. + */ +export function getGitDiff(cwd: string, staged: boolean): string { + const cmd = staged ? 'git diff --cached' : 'git diff'; + try { + return execSync(cmd, { cwd, encoding: 'utf-8', maxBuffer: 100 * 1024 * 1024 }).trim(); + } catch (err) { + throw new Error( + `Failed to get diff: ${err instanceof Error ? err.message : String(err)}`, + ); + } +} + +/** + * Run `git commit` inside a specific repository directory. + */ +export function gitCommit( + cwd: string, + message: string, + body?: string, +): { hash: string; summary: string } { + const fullMessage = body ? `${message}\n\n${body}` : message; + const tmpFile = join( + tmpdir(), + `commit-echo-batch-${process.pid}-${Date.now()}.txt`, + ); + + try { + writeFileSync(tmpFile, fullMessage, 'utf-8'); + const result = spawnSync('git', ['commit', '-F', tmpFile], { + cwd, + encoding: 'utf-8', + shell: false, + }); + + if (result.error) throw result.error; + if (result.status !== 0) { + const detail = [result.stderr, result.stdout] + .filter(Boolean) + .join('\n') + .trim(); + throw new Error(detail || `git commit exited with code ${result.status}`); + } + + const summary = result.stdout.trim().split('\n').find(Boolean) ?? ''; + const match = summary.match( + /\[.*?([a-f0-9]{7,})\]\s+(.+)$/i, + ); + + return { + hash: match?.[1] ?? '', + summary: match?.[2] ?? summary, + }; + } finally { + try { + unlinkSync(tmpFile); + } catch { + /* ignore */ + } + } +} + +/** + * Display a set of suggestions to the user. + */ +function displaySuggestions(suggestions: Suggestion[]): void { + for (const s of suggestions) { + const full = s.body ? `${s.message}\n ${pc.dim(s.body)}` : s.message; + console.log(` ${pc.cyan(`${s.index}.`)} ${full}`); + } +} + +export async function batchCommand( + options: { + directory?: string; + recursive?: boolean; + yes?: boolean; + } = {}, +): Promise { + intro(pc.bold(pc.cyan('commit-echo batch'))); + + const dir = options.directory ?? process.cwd(); + + if (!existsSync(dir) || !statSync(dir).isDirectory()) { + outro(pc.red(`Directory not found: ${dir}`)); + return; + } + + // Discover git repositories in the target directory + const repos = findGitRepositories(dir, options.recursive ?? false); + + if (repos.length === 0) { + outro(pc.yellow(`No git repositories found in ${dir}`)); + return; + } + + console.log( + `\n Found ${pc.bold(String(repos.length))} repo(s) — checking for changes...\n`, + ); + + // Load configuration once (shared across all repos) + let config: Config; + try { + config = await loadOrPromptConfig(); + } catch (err) { + outro(pc.red(err instanceof Error ? err.message : 'Configuration error')); + return; + } + + // Verify API key once + let apiKey: string; + try { + apiKey = assertApiKeyAvailable(config); + } catch (err) { + outro(pc.red(err instanceof Error ? err.message : 'Missing API key')); + return; + } + + // Build style profile once (shared across all repos) + const profile = await buildProfile(config.historySize); + + const results: BatchResult[] = []; + + for (const repoPath of repos) { + const repoName = basename(repoPath); + console.log(` ${pc.bold(pc.cyan(`▶ ${repoName}`))} ${pc.dim(repoPath)}`); + + // Check what kind of changes exist + const { staged, unstaged } = gitHasChanges(repoPath); + + if (!staged) { + if (!unstaged) { + console.log(` ${pc.yellow('↻ No changes found, skipping')}\n`); + results.push({ + repo: repoPath, + repoName, + status: 'skipped', + message: 'No changes', + }); + continue; + } + console.log( + ` ${pc.yellow('ℹ Unstaged changes only (stage with `git add` first), skipping')}\n`, + ); + results.push({ + repo: repoPath, + repoName, + status: 'skipped', + message: 'Unstaged only', + }); + continue; + } + + // Get the staged diff + let diff: string; + try { + diff = getGitDiff(repoPath, true); + } catch (err) { + const msg = err instanceof Error ? err.message : String(err); + console.log(` ${pc.red(`✖ ${msg}`)}\n`); + results.push({ repo: repoPath, repoName, status: 'failed', message: msg }); + continue; + } + + if (!diff) { + console.log(` ${pc.yellow('↻ Empty diff, skipping')}\n`); + results.push({ + repo: repoPath, + repoName, + status: 'skipped', + message: 'Empty diff', + }); + continue; + } + + // Generate suggestions using the shared profile + let suggestions: Suggestion[]; + try { + const result = await generateSuggestions(config, diff, profile, apiKey); + suggestions = result.suggestions; + } catch (err) { + const msg = err instanceof Error ? err.message : String(err); + console.log( + ` ${pc.red(`✖ Failed to generate suggestions: ${msg}`)}\n`, + ); + results.push({ + repo: repoPath, + repoName, + status: 'failed', + message: msg, + }); + continue; + } + + // Display suggestions + console.log(''); + displaySuggestions(suggestions); + + if (options.yes) { + // Unattended mode: auto-select first suggestion and commit + const first = suggestions[0]; + if (!first) { + console.log( + ` ${pc.yellow('↻ No suggestions generated, skipping')}`, + ); + results.push({ + repo: repoPath, + repoName, + status: 'skipped', + message: 'No suggestions', + }); + console.log(''); + continue; + } + + try { + const commitResult = gitCommit(repoPath, first.message, first.body); + await appendEntry({ + timestamp: new Date().toISOString(), + message: first.body + ? `${first.message}\n\n${first.body}` + : first.message, + diff, + model: config.model, + provider: config.provider, + }); + console.log( + ` ${pc.green(`✓ ${pc.bold(commitResult.hash)} ${commitResult.summary}`)}`, + ); + results.push({ + repo: repoPath, + repoName, + status: 'success', + message: first.message, + }); + } catch (err) { + const msg = err instanceof Error ? err.message : String(err); + console.log(` ${pc.red(`✖ Commit failed: ${msg}`)}`); + results.push({ + repo: repoPath, + repoName, + status: 'failed', + message: msg, + }); + } + } else if (suggestions.length > 0) { + // Interactive mode: prompt per repo + const proceed = await confirm({ + message: `Commit changes in ${repoName}?`, + initialValue: true, + }); + + if (isCancel(proceed)) { + console.log(` ${pc.dim('– Cancelled, skipping')}`); + results.push({ + repo: repoPath, + repoName, + status: 'skipped', + message: 'Cancelled', + }); + console.log(''); + continue; + } + + if (!proceed) { + console.log(` ${pc.dim('– Skipped')}`); + results.push({ + repo: repoPath, + repoName, + status: 'skipped', + message: 'User skipped', + }); + console.log(''); + continue; + } + + // Let user select which suggestion to use + const suggestionOptions = suggestions.map((s) => ({ + value: s.index, + label: + s.message.length > 60 + ? s.message.slice(0, 57) + '...' + : s.message, + })); + + const selectedIndex = await select({ + message: `Select message for ${repoName}:`, + options: suggestionOptions, + }); + + if (isCancel(selectedIndex)) { + console.log(` ${pc.dim('– Cancelled, skipping')}`); + results.push({ + repo: repoPath, + repoName, + status: 'skipped', + message: 'Cancelled', + }); + console.log(''); + continue; + } + + const selected = suggestions.find( + (s) => s.index === selectedIndex, + ); + if (!selected) { + console.log(` ${pc.red('✖ Invalid selection')}`); + results.push({ + repo: repoPath, + repoName, + status: 'failed', + message: 'Invalid selection', + }); + console.log(''); + continue; + } + + // Prompt for an optional commit body (consistent with `suggest` UX) + const customBody = await text({ + message: `Optional body for ${repoName}:`, + initialValue: selected.body ?? '', + }); + const finalBody = + isCancel(customBody) || !customBody + ? selected.body + : customBody; + + try { + const commitResult = gitCommit( + repoPath, + selected.message, + finalBody, + ); + await appendEntry({ + timestamp: new Date().toISOString(), + message: finalBody + ? `${selected.message}\n\n${finalBody}` + : selected.message, + diff, + model: config.model, + provider: config.provider, + }); + console.log( + ` ${pc.green(`✓ ${pc.bold(commitResult.hash)} ${commitResult.summary}`)}`, + ); + results.push({ + repo: repoPath, + repoName, + status: 'success', + message: selected.message, + }); + } catch (err) { + const msg = err instanceof Error ? err.message : String(err); + console.log(` ${pc.red(`✖ Commit failed: ${msg}`)}`); + results.push({ + repo: repoPath, + repoName, + status: 'failed', + message: msg, + }); + } + } else { + console.log( + ` ${pc.yellow('↻ No suggestions generated, skipping')}`, + ); + results.push({ + repo: repoPath, + repoName, + status: 'skipped', + message: 'No suggestions', + }); + } + + console.log(''); + } + + // Print summary report + const succeeded = results.filter((r) => r.status === 'success'); + const failed = results.filter((r) => r.status === 'failed'); + const skipped = results.filter((r) => r.status === 'skipped'); + + console.log(pc.bold('━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━')); + console.log(pc.bold('📋 Batch Summary\n')); + + for (const r of results) { + const icon = + r.status === 'success' + ? pc.green('✓') + : r.status === 'failed' + ? pc.red('✖') + : pc.yellow('–'); + const msg = r.message + ? ` — ${r.message.length > 60 ? r.message.slice(0, 57) + '...' : r.message}` + : ''; + console.log(` ${icon} ${r.repoName}${pc.dim(msg)}`); + } + + console.log( + `\n ${pc.green(String(succeeded.length))} succeeded, ${pc.yellow(String(skipped.length))} skipped, ${pc.red(String(failed.length))} failed`, + ); + outro('Batch processing complete.'); +} diff --git a/src/commands/suggest.ts b/src/commands/suggest.ts index bdc9530..f6306e9 100644 --- a/src/commands/suggest.ts +++ b/src/commands/suggest.ts @@ -10,6 +10,7 @@ import { import pc from "picocolors"; import type { Config, + Provider, StyleProfile, Suggestion, TruncationInfo, @@ -21,8 +22,14 @@ import { getUnstagedDiff, commit, } from "../git/diff.js"; -import { assertApiKeyAvailable, generateSuggestions } from "../llm/client.js"; +import { + assertApiKeyAvailable, + generateSuggestions, + generateSuggestionsStream, +} from "../llm/client.js"; import { appendEntry, buildProfile } from "../history/store.js"; +import { parseSuggestions } from "../llm/prompt.js"; +import { getStreamingProvider } from "../providers/index.js"; function showTruncationWarning(info: TruncationInfo): void { const pct = ((info.truncatedSize / info.originalSize) * 100).toFixed(1); @@ -74,6 +81,7 @@ export async function suggestCommand( autoCommit?: boolean; verbose?: boolean; model?: string; + stream?: boolean; } = {}, ): Promise { intro(pc.bold(pc.cyan("commit-echo"))); @@ -121,49 +129,126 @@ export async function suggestCommand( const profile = await buildProfile(config.historySize); - const genSpinner = spinner(); - genSpinner.start("Generating commit suggestions..."); + let suggestions: Suggestion[]; + let truncation: TruncationInfo | undefined; + let model: string; - try { - const { suggestions, truncation, model } = await generateSuggestions( - config, - diffResult.diff, - profile, - apiKey, - ); - genSpinner.stop(pc.green("Suggestions generated:")); + if (options.stream) { + let streamProvider: Provider; + try { + streamProvider = getStreamingProvider(config.provider); + } catch (err) { + outro(pc.red(err instanceof Error ? err.message : "Streaming not supported")); + return; + } + + // Streaming mode: show text as it arrives + console.log(pc.dim("Streaming suggestions...\n")); + + model = config.model; + let accumulated = ""; + try { + for await (const event of generateSuggestionsStream( + config, + diffResult.diff, + profile, + apiKey, + streamProvider, + )) { + if (event.kind === "meta") { + truncation = event.truncation; + continue; + } - if (options.verbose) { - showVerboseInfo(model, profile, truncation); + if (event.kind === "model") { + model = event.model; + continue; + } + + accumulated += event.text; + process.stdout.write(event.text); + } + } catch (err) { + process.stdout.write("\n"); + const message = err instanceof Error ? err.message : "Unknown error"; + outro(pc.red(`Streaming failed: ${message}`)); + return; + } + process.stdout.write("\n\n"); + + const parsed = parseSuggestions(accumulated); + suggestions = parsed.map((p, i) => ({ + index: i + 1, + message: p.message, + body: p.body, + })); + + if (suggestions.length === 0) { + outro( + pc.red( + "Could not parse any suggestions from LLM response. The model may need a different prompt format.", + ), + ); + return; } + } else { + // Non-streaming mode: use spinner and wait for full response + const genSpinner = spinner(); + genSpinner.start("Generating commit suggestions..."); - if (truncation) { - showTruncationWarning(truncation); + try { + const result = await generateSuggestions( + config, + diffResult.diff, + profile, + apiKey, + ); + suggestions = result.suggestions; + truncation = result.truncation; + model = result.model; + genSpinner.stop(pc.green("Suggestions generated:")); + } catch (err) { + genSpinner.stop(pc.red("Failed to generate suggestions.")); + const message = err instanceof Error ? err.message : "Unknown error"; + outro(pc.red(message)); + return; } + } + if (options.verbose) { + showVerboseInfo(model, profile, truncation); + } + + if (truncation) { + showTruncationWarning(truncation); + } + + if (!options.stream) { await displaySuggestions(suggestions); + } - if (options.autoCommit && suggestions.length > 0) { - const first = suggestions[0]!; - if (options.commit !== false) { - if (!diffResult.staged) { - outro( - pc.red( - "Auto-commit requires staged changes. Stage your changes with `git add` and try again.", - ), - ); - process.exit(1); - } - await acceptAndCommit(first, config, diffResult.diff, true); - } else { - console.log(`\n ${pc.green("Selected:")} ${pc.bold(first.message)}`); - if (first.body) { - console.log(` ${pc.dim(first.body)}`); - } + if (options.autoCommit && suggestions.length > 0) { + const first = suggestions[0]!; + if (options.commit !== false) { + if (!diffResult.staged) { + outro( + pc.red( + "Auto-commit requires staged changes. Stage your changes with `git add` and try again.", + ), + ); + process.exit(1); + } + await acceptAndCommit(first, config, diffResult.diff, true); + } else { + console.log(`\n ${pc.green("Selected:")} ${pc.bold(first.message)}`); + if (first.body) { + console.log(` ${pc.dim(first.body)}`); } - return; } + return; + } + try { const action = await select({ message: "Choose an action:", options: [ @@ -208,7 +293,6 @@ export async function suggestCommand( await acceptAndCommit(selected, config, diffResult.diff); } } catch (err) { - genSpinner.stop(pc.red("Failed to generate suggestions.")); const message = err instanceof Error ? err.message : "Unknown error"; outro(pc.red(message)); } diff --git a/src/config/store.ts b/src/config/store.ts index c1891da..983b4a7 100644 --- a/src/config/store.ts +++ b/src/config/store.ts @@ -35,7 +35,19 @@ export function getHistoryPath(): string { export async function loadConfig(): Promise { const configPath = getConfigPath(); const raw = await readFile(configPath, 'utf-8'); - const parsed = JSON.parse(raw) as Partial; + let parsed: Partial; + + try { + parsed = JSON.parse(raw) as Partial; + } catch (error) { + if (error instanceof SyntaxError) { + throw new Error( + `Invalid JSON in config file: ${configPath}. Fix the JSON syntax or run \`commit-echo init\` to recreate it.`, + { cause: error }, + ); + } + throw error; + } return { provider: parsed.provider ?? '', diff --git a/src/index.ts b/src/index.ts index 8fe4330..76b94f8 100644 --- a/src/index.ts +++ b/src/index.ts @@ -8,6 +8,7 @@ import { fileURLToPath } from 'node:url'; import { initCommand } from './commands/init.js'; import { suggestCommand } from './commands/suggest.js'; import { historyCommand } from './commands/history.js'; +import { batchCommand } from './commands/batch.js'; import { getAvailableTemplateVars } from './llm/prompt.js'; import { runPostCommitHook, runPrepareCommitMsgHook } from './git/hook.js'; @@ -39,6 +40,9 @@ ${pc.dim('Examples:')} ${pc.cyan('commit-echo suggest')} Generate suggestions without committing ${pc.cyan('commit-echo suggest --yes')} Auto-select first suggestion (no commit) ${pc.cyan('commit-echo history')} View learned style profile and history + ${pc.cyan('commit-echo batch')} Process all git repos in current directory + ${pc.cyan('commit-echo batch --recursive')} Search subdirectories for repos + ${pc.cyan('commit-echo batch --yes')} Auto-commit repos with staged changes ${pc.dim('Custom prompt template variables:')} ${getAvailableTemplateVars() @@ -64,18 +68,39 @@ program .option('-y, --yes', 'Automatically select the first suggestion and skip prompts') .option('-v, --verbose', 'Print diagnostic information about the suggestion request') .option('-m, --model ', 'Override the configured LLM model for this invocation') + .option('--stream', 'Stream suggestions as they are generated (progressive output)') .option('--auto', 'Alias for --yes') .action(async (options) => { + const globalOpts = program.opts<{ yes?: boolean; auto?: boolean }>(); await suggestCommand({ commit: options.commit, - autoCommit: Boolean(options.yes || options.auto), + autoCommit: Boolean( + options.yes || options.auto || globalOpts.yes || globalOpts.auto, + ), verbose: Boolean(options.verbose), model: options.model, + stream: Boolean(options.stream), }); }); program.command('history').description('View learned style profile and recent commit history').action(historyCommand); +program + .command('batch') + .description('Process multiple git repositories in batch mode') + .argument('[directory]', 'Directory to scan for git repositories') + .option('-r, --recursive', 'Recursively search subdirectories for git repos') + .option('-y, --yes', 'Automatically accept the first suggestion and commit without prompts') + .option('--auto', 'Alias for --yes') + .action(async (directory, options) => { + const globalOpts = program.opts<{ yes?: boolean; auto?: boolean }>(); + await batchCommand({ + directory: directory || undefined, + recursive: Boolean(options.recursive), + yes: Boolean(options.yes || options.auto || globalOpts.yes || globalOpts.auto), + }); + }); + const hookCommand = new Command('hook') .description('Internal Git hook entry point') .argument('', 'Git hook name') diff --git a/src/llm/client.ts b/src/llm/client.ts index 5ae4d60..fda4f2c 100644 --- a/src/llm/client.ts +++ b/src/llm/client.ts @@ -1,11 +1,12 @@ import type { Config, + Provider, Suggestion, StyleProfile, TruncationInfo, } from "../types.js"; import { getProviderInfo } from "../providers/index.js"; -import { complete } from "../providers/index.js"; +import { complete, completeStream } from "../providers/index.js"; import { resolveSystemPrompt, resolveUserPrompt, @@ -111,6 +112,75 @@ export async function generateSuggestions( }; } +export type SuggestionStreamEvent = + | { kind: "meta"; truncation?: TruncationInfo } + | { kind: "model"; model: string } + | { kind: "text"; text: string }; + +/** + * Stream commit suggestions from the LLM provider. + * Yields a meta event first (including truncation info), then text chunks. + * After iteration completes, the caller can parse accumulated text with + * `parseSuggestions()`. + */ +export async function* generateSuggestionsStream( + config: Config, + diff: string, + profileParam?: StyleProfile, + apiKeyParam?: string, + provider?: Provider, +): AsyncGenerator { + const profile = profileParam ?? (await buildProfile(config.historySize)); + + const { diff: truncatedDiff, info: truncation } = truncateDiff( + diff, + config.maxDiffSize, + ); + + const branch = getBranchName(); + const profileStr = formatProfile(profile); + + const vars = { + diff: truncatedDiff, + profile: profileStr, + branch, + }; + + const systemPrompt = resolveSystemPrompt(profile, vars, config); + const userPrompt = resolveUserPrompt(vars, config); + + const apiKey = apiKeyParam ?? assertApiKeyAvailable(config); + + yield { + kind: "meta", + truncation: truncation.wasTruncated ? truncation : undefined, + }; + + const stream = completeStream( + config.provider, + config.baseUrl, + { + model: config.model, + messages: [ + { role: "system", content: systemPrompt }, + { role: "user", content: userPrompt }, + ], + temperature: 0.7, + maxTokens: 1024, + apiKey, + }, + provider, + ); + + for await (const chunk of stream) { + if (chunk.kind === "model") { + yield { kind: "model", model: chunk.model }; + continue; + } + yield { kind: "text", text: chunk.text }; + } +} + export async function testConnection(config: Config): Promise { const apiKey = resolveApiKey(config); diff --git a/src/providers/anthropic.ts b/src/providers/anthropic.ts index 19a7c10..f623527 100644 --- a/src/providers/anthropic.ts +++ b/src/providers/anthropic.ts @@ -1,28 +1,43 @@ -import type { ChatParams, ChatResult, Provider } from '../types.js'; +import type { ChatParams, ChatResult, Provider, ProviderStreamChunk } from '../types.js'; import { fetchWithTimeout } from './request.js'; +import { parseAnthropicSseLine, streamSseResponse } from './sse.js'; + +function buildAnthropicRequestBody( + params: ChatParams, + options: { stream?: boolean } = {}, +): Record { + const { model, messages, temperature = 0.7, maxTokens = 1024 } = params; + + const systemMessages = messages.filter((m) => m.role === 'system'); + const nonSystemMessages = messages.filter((m) => m.role !== 'system'); + + const body: Record = { + model, + messages: nonSystemMessages.map((m) => ({ + role: m.role, + content: m.content, + })), + max_tokens: maxTokens, + temperature, + }; + + if (options.stream) { + body.stream = true; + } + + if (systemMessages.length > 0) { + body.system = systemMessages.map((m) => m.content).join('\n'); + } + + return body; +} export class AnthropicProvider implements Provider { async complete(params: ChatParams): Promise { - const { model, messages, temperature = 0.7, maxTokens = 1024, apiKey, baseUrl } = params; + const { model, apiKey, baseUrl } = params; const url = `${baseUrl.replace(/\/+$/, '')}/messages`; - - const systemMessages = messages.filter((m) => m.role === 'system'); - const nonSystemMessages = messages.filter((m) => m.role !== 'system'); - - const body: Record = { - model, - messages: nonSystemMessages.map((m) => ({ - role: m.role, - content: m.content, - })), - max_tokens: maxTokens, - temperature, - }; - - if (systemMessages.length > 0) { - body['system'] = systemMessages.map((m) => m.content).join('\n'); - } + const body = buildAnthropicRequestBody(params); const response = await fetchWithTimeout( url, @@ -59,6 +74,37 @@ export class AnthropicProvider implements Provider { }; } + async *completeStream(params: ChatParams): AsyncIterable { + const { apiKey, baseUrl } = params; + + const url = `${baseUrl.replace(/\/+$/, '')}/messages`; + const body = buildAnthropicRequestBody(params, { stream: true }); + + const response = await fetchWithTimeout( + url, + { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + 'x-api-key': apiKey, + 'anthropic-version': '2023-06-01', + }, + body: JSON.stringify(body), + }, + 'Anthropic streaming request', + ); + + if (!response.ok) { + const errorBody = await response.text().catch(() => ''); + throw new Error(`Anthropic API error (${response.status}): ${errorBody || response.statusText}`); + } + + const sseState = { currentEvent: '' }; + yield* streamSseResponse(response, (line) => + parseAnthropicSseLine(line, sseState), + ); + } + async fetchModels(_baseUrl: string, _apiKey: string): Promise { return [ 'claude-sonnet-4-20250514', diff --git a/src/providers/index.ts b/src/providers/index.ts index 15d976d..606726d 100644 --- a/src/providers/index.ts +++ b/src/providers/index.ts @@ -1,4 +1,4 @@ -import type { Provider, ChatParams, ChatResult } from '../types.js'; +import type { Provider, ChatParams, ChatResult, ProviderStreamChunk } from '../types.js'; import { OpenAICompatibleProvider } from './openai-compatible.js'; import { AnthropicProvider } from './anthropic.js'; import { CohereProvider } from './cohere.js'; @@ -39,6 +39,16 @@ export function createProvider(configProvider: string): Provider { return new OpenAICompatibleProvider(); } +export function getStreamingProvider(configProvider: string): Provider { + const provider = createProvider(configProvider); + if (!provider.completeStream) { + throw new Error( + `Streaming is not supported for the '${configProvider}' provider. Use non-streaming mode.`, + ); + } + return provider; +} + export async function complete( configProvider: string, baseUrlOverride: string | undefined, @@ -49,6 +59,17 @@ export async function complete( return provider.complete({ ...params, baseUrl }); } +export async function* completeStream( + configProvider: string, + baseUrlOverride: string | undefined, + params: Omit, + provider?: Provider, +): AsyncIterable { + const resolvedProvider = provider ?? getStreamingProvider(configProvider); + const baseUrl = getBaseUrl(configProvider, baseUrlOverride); + yield* resolvedProvider.completeStream!({ ...params, baseUrl }); +} + export async function fetchModels( configProvider: string, baseUrlOverride: string | undefined, diff --git a/src/providers/openai-compatible.ts b/src/providers/openai-compatible.ts index 841d464..d1bc6f8 100644 --- a/src/providers/openai-compatible.ts +++ b/src/providers/openai-compatible.ts @@ -1,9 +1,30 @@ -import type { ChatParams, ChatResult, Provider } from '../types.js'; +import type { ChatParams, ChatResult, Provider, ProviderStreamChunk } from '../types.js'; import { fetchWithTimeout } from './request.js'; +import { parseOpenAiSseLine, streamSseResponse, SSE_STREAM_END } from './sse.js'; + +function buildOpenAiRequestBody( + params: ChatParams, + options: { stream?: boolean } = {}, +): Record { + const { model, messages, temperature = 0.7, maxTokens = 1024 } = params; + + const body: Record = { + model, + messages, + temperature, + max_tokens: maxTokens, + }; + + if (options.stream) { + body.stream = true; + } + + return body; +} export class OpenAICompatibleProvider implements Provider { async complete(params: ChatParams): Promise { - const { model, messages, temperature = 0.7, maxTokens = 1024, apiKey, baseUrl } = params; + const { model, apiKey, baseUrl } = params; const url = `${baseUrl.replace(/\/+$/, '')}/chat/completions`; @@ -15,12 +36,7 @@ export class OpenAICompatibleProvider implements Provider { 'Content-Type': 'application/json', Authorization: `Bearer ${apiKey}`, }, - body: JSON.stringify({ - model, - messages, - temperature, - max_tokens: maxTokens, - }), + body: JSON.stringify(buildOpenAiRequestBody(params)), }, 'OpenAI-compatible API request', ); @@ -46,6 +62,39 @@ export class OpenAICompatibleProvider implements Provider { }; } + async *completeStream(params: ChatParams): AsyncIterable { + const { apiKey, baseUrl } = params; + + const url = `${baseUrl.replace(/\/+$/, '')}/chat/completions`; + + const response = await fetchWithTimeout( + url, + { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + Authorization: `Bearer ${apiKey}`, + }, + body: JSON.stringify(buildOpenAiRequestBody(params, { stream: true })), + }, + 'OpenAI-compatible streaming request', + ); + + if (!response.ok) { + const errorBody = await response.text().catch(() => ''); + throw new Error(`OpenAI-compatible API error (${response.status}): ${errorBody || response.statusText}`); + } + + yield* streamSseResponse(response, (line) => { + const parsed = parseOpenAiSseLine(line); + if (parsed.error) throw new Error(`OpenAI-compatible streaming error: ${parsed.error}`); + if (parsed.done) return SSE_STREAM_END; + if (parsed.model) return { kind: 'model', model: parsed.model }; + if (parsed.text) return { kind: 'text', text: parsed.text }; + return null; + }); + } + async fetchModels(baseUrl: string, apiKey: string): Promise { const url = `${baseUrl.replace(/\/+$/, '')}/models`; diff --git a/src/providers/sse.ts b/src/providers/sse.ts new file mode 100644 index 0000000..1b82b05 --- /dev/null +++ b/src/providers/sse.ts @@ -0,0 +1,156 @@ +import type { ProviderStreamChunk } from '../types.js'; + +export type AnthropicSseState = { + currentEvent: string; +}; + +export const SSE_STREAM_END = Symbol('SSE_STREAM_END'); + +export type SseLineParser = (line: string) => ProviderStreamChunk | typeof SSE_STREAM_END | null; + +/** + * Read an SSE response body, split into lines, and yield parsed chunks. + * Handles buffering for partial lines and ensures the reader is released + * safely even after `reader.cancel()` has been called. + */ +export async function* streamSseResponse( + response: Response, + parseLine: SseLineParser, +): AsyncIterable { + const reader = response.body?.getReader(); + if (!reader) throw new Error('No response body'); + + const decoder = new TextDecoder(); + let buffer = ''; + let cancelled = false; + + try { + while (true) { + const { done, value } = await reader.read(); + + if (value) { + buffer += decoder.decode(value, { stream: !done }); + } + + const lines = buffer.split('\n'); + // When not done, the last element is an incomplete line — put it back. + // When done, keep all elements so the final line is processed below. + buffer = done ? '' : (lines.pop() ?? ''); + + for (const line of lines) { + const result = parseLine(line); + if (result === SSE_STREAM_END) { + await reader.cancel(); + cancelled = true; + return; + } + if (result) yield result; + } + + if (done) break; + } + } finally { + if (!cancelled) reader.releaseLock(); + } +} + +export function parseOpenAiSseLine(line: string): { + text?: string; + model?: string; + done?: boolean; + error?: string; +} { + const trimmed = line.trim(); + if (!trimmed || !trimmed.startsWith('data:')) return {}; + + const payload = trimmed.slice(5).trim(); + if (payload === '[DONE]') return { done: true }; + + try { + const parsed = JSON.parse(payload) as { + error?: { message?: string }; + model?: string; + choices?: { delta?: { content?: string } }[]; + }; + + if (parsed.error?.message) { + return { error: parsed.error.message }; + } + + const result: { text?: string; model?: string } = {}; + if (parsed.model) result.model = parsed.model; + + const content = parsed.choices?.[0]?.delta?.content; + if (content) result.text = content; + + return result; + } catch { + // Skip malformed JSON chunks + } + + return {}; +} + +/** + * Parse a single Anthropic SSE line. Call repeatedly for each line in a batch, + * passing shared `state` to track event types across event/data line pairs. + */ +export function parseAnthropicSseLine( + line: string, + state: AnthropicSseState, +): ProviderStreamChunk | typeof SSE_STREAM_END | null { + const trimmed = line.trim(); + if (!trimmed) return null; + + if (trimmed.startsWith('event:')) { + state.currentEvent = trimmed.slice(6).trim(); + return null; + } + + if (!trimmed.startsWith('data:')) return null; + + const payload = trimmed.slice(5).trim(); + + if (state.currentEvent === 'message_start') { + try { + const parsed = JSON.parse(payload) as { + message?: { model?: string }; + }; + if (parsed.message?.model) { + return { kind: 'model', model: parsed.message.model }; + } + } catch { + // Skip malformed JSON + } + return null; + } + + if (state.currentEvent === 'content_block_delta') { + try { + const parsed = JSON.parse(payload) as { delta?: { text?: string } }; + if (parsed.delta?.text) { + return { kind: 'text', text: parsed.delta.text }; + } + } catch { + // Skip malformed JSON + } + return null; + } + + if (state.currentEvent === 'error') { + let message = 'Anthropic streaming error'; + try { + const parsed = JSON.parse(payload) as { error?: { message?: string } }; + if (parsed.error?.message) message = parsed.error.message; + } catch { + // Use default message + } + throw new Error(message); + } + + if (state.currentEvent === 'message_stop') { + return SSE_STREAM_END; + } + + return null; +} diff --git a/src/types.ts b/src/types.ts index 14d459a..eaa3b22 100644 --- a/src/types.ts +++ b/src/types.ts @@ -70,7 +70,12 @@ export interface ChatResult { model: string; } +export type ProviderStreamChunk = + | { kind: 'text'; text: string } + | { kind: 'model'; model: string }; + export interface Provider { complete(params: ChatParams): Promise; + completeStream?(params: ChatParams): AsyncIterable; fetchModels(baseUrl: string, apiKey: string): Promise; } diff --git a/tests/batch.test.mjs b/tests/batch.test.mjs new file mode 100644 index 0000000..78d3884 --- /dev/null +++ b/tests/batch.test.mjs @@ -0,0 +1,323 @@ +import assert from 'node:assert/strict'; +import test from 'node:test'; +import { execFileSync } from 'node:child_process'; +import { + existsSync, + mkdirSync, + mkdtempSync, + realpathSync, + rmSync, + writeFileSync, +} from 'node:fs'; +import { tmpdir } from 'node:os'; +import { join } from 'node:path'; + +import { + findGitRepositories, + gitHasChanges, + getGitDiff, + gitCommit, +} from '../dist/commands/batch.js'; + +function createTempDir() { + return realpathSync.native( + mkdtempSync(join(tmpdir(), 'commit-echo-batch-test-')), + ); +} + +function git(args, cwd) { + return execFileSync('git', args, { + cwd, + encoding: 'utf-8', + stdio: 'pipe', + }); +} + +function initRepo(root, name) { + const repoDir = join(root, name); + mkdirSync(repoDir, { recursive: true }); + git(['init'], repoDir); + git(['config', 'core.fsmonitor', 'false'], repoDir); + git(['config', 'user.name', 'Test User'], repoDir); + git(['config', 'user.email', 'test@example.com'], repoDir); + + return repoDir; +} + +// ─── findGitRepositories ──────────────────────────────────────────────────── + +test('findGitRepositories returns repos in a flat directory', () => { + const root = createTempDir(); + try { + const repoA = initRepo(root, 'repo-a'); + const repoB = initRepo(root, 'repo-b'); + const repos = findGitRepositories(root, false); + + assert.equal(repos.length, 2); + assert.ok(repos.includes(repoA)); + assert.ok(repos.includes(repoB)); + } finally { + rmSync(root, { recursive: true, force: true }); + } +}); + +test('findGitRepositories ignores hidden directories', () => { + const root = createTempDir(); + try { + const repo = initRepo(root, 'visible-repo'); + mkdirSync(join(root, '.hidden'), { recursive: true }); + const repos = findGitRepositories(root, false); + + assert.equal(repos.length, 1); + assert.equal(repos[0], repo); + } finally { + rmSync(root, { recursive: true, force: true }); + } +}); + +test('findGitRepositories non-recursive does not descend into subdirectories', () => { + const root = createTempDir(); + try { + const topRepo = initRepo(root, 'top-repo'); + const nestedDir = join(root, 'nested'); + mkdirSync(nestedDir, { recursive: true }); + const nestedRepo = initRepo(nestedDir, 'inner-repo'); + + const flat = findGitRepositories(root, false); + assert.equal(flat.length, 1); + assert.equal(flat[0], topRepo); + + const recursive = findGitRepositories(root, true); + assert.equal(recursive.length, 2); + assert.ok(recursive.includes(topRepo)); + assert.ok(recursive.includes(nestedRepo)); + } finally { + rmSync(root, { recursive: true, force: true }); + } +}); + +test('findGitRepositories returns empty array for non-existent directory', () => { + const repos = findGitRepositories('/path/does/not/exist', false); + assert.deepEqual(repos, []); +}); + +test('findGitRepositories returns empty array for directory with no repos', () => { + const root = createTempDir(); + try { + mkdirSync(join(root, 'plain-dir'), { recursive: true }); + const repos = findGitRepositories(root, false); + assert.deepEqual(repos, []); + } finally { + rmSync(root, { recursive: true, force: true }); + } +}); + +test('findGitRepositories sorts results alphabetically', () => { + const root = createTempDir(); + try { + const repoB = initRepo(root, 'b-repo'); + const repoA = initRepo(root, 'a-repo'); + const repoC = initRepo(root, 'c-repo'); + + const repos = findGitRepositories(root, false); + assert.equal(repos.length, 3); + assert.equal(repos[0], repoA); + assert.equal(repos[1], repoB); + assert.equal(repos[2], repoC); + } finally { + rmSync(root, { recursive: true, force: true }); + } +}); + +test('findGitRepositories returns rootDir when it is itself a git repo', () => { + const root = createTempDir(); + try { + const repo = initRepo(root, 'inner'); + // Point directly at the repo itself, not its parent + const repos = findGitRepositories(repo, false); + + assert.equal(repos.length, 1); + assert.equal(repos[0], repo); + } finally { + rmSync(root, { recursive: true, force: true }); + } +}); + +test('findGitRepositories returns rootDir even with recursive flag', () => { + const root = createTempDir(); + try { + const repo = initRepo(root, 'inner'); + const repos = findGitRepositories(repo, true); + + assert.equal(repos.length, 1); + assert.equal(repos[0], repo); + } finally { + rmSync(root, { recursive: true, force: true }); + } +}); + +// ─── gitHasChanges ────────────────────────────────────────────────────────── + +test('gitHasChanges detects staged changes', () => { + const root = createTempDir(); + try { + const repo = initRepo(root, 'repo'); + writeFileSync(join(repo, 'file.txt'), 'content\n', 'utf-8'); + git(['add', 'file.txt'], repo); + + const { staged, unstaged } = gitHasChanges(repo); + assert.equal(staged, true); + assert.equal(unstaged, false); + } finally { + rmSync(root, { recursive: true, force: true }); + } +}); + +test('gitHasChanges detects unstaged changes', () => { + const root = createTempDir(); + try { + const repo = initRepo(root, 'repo'); + // Create a tracked file first, then modify it + writeFileSync(join(repo, 'file.txt'), 'initial\n', 'utf-8'); + git(['add', 'file.txt'], repo); + git(['commit', '-m', 'feat: initial'], repo); + writeFileSync(join(repo, 'file.txt'), 'modified\n', 'utf-8'); + + const { staged, unstaged } = gitHasChanges(repo); + assert.equal(staged, false); + assert.equal(unstaged, true); + } finally { + rmSync(root, { recursive: true, force: true }); + } +}); + +test('gitHasChanges returns false for clean repo', () => { + const root = createTempDir(); + try { + const repo = initRepo(root, 'repo'); + writeFileSync(join(repo, 'file.txt'), 'content\n', 'utf-8'); + git(['add', 'file.txt'], repo); + git(['commit', '-m', 'feat: initial'], repo); + + const { staged, unstaged } = gitHasChanges(repo); + assert.equal(staged, false); + assert.equal(unstaged, false); + } finally { + rmSync(root, { recursive: true, force: true }); + } +}); + +test('gitHasChanges detects both staged and unstaged', () => { + const root = createTempDir(); + try { + const repo = initRepo(root, 'repo'); + // Create a tracked file + writeFileSync(join(repo, 'tracked.txt'), 'base\n', 'utf-8'); + git(['add', 'tracked.txt'], repo); + git(['commit', '-m', 'feat: initial'], repo); + // Modify and stage it + writeFileSync(join(repo, 'tracked.txt'), 'staged change\n', 'utf-8'); + git(['add', 'tracked.txt'], repo); + // Modify again (unstaged) + writeFileSync(join(repo, 'tracked.txt'), 'staged + unstaged\n', 'utf-8'); + + const { staged, unstaged } = gitHasChanges(repo); + assert.equal(staged, true); + assert.equal(unstaged, true); + } finally { + rmSync(root, { recursive: true, force: true }); + } +}); + +// ─── getGitDiff ───────────────────────────────────────────────────────────── + +test('getGitDiff returns the staged diff', () => { + const root = createTempDir(); + try { + const repo = initRepo(root, 'repo'); + writeFileSync(join(repo, 'file.txt'), 'hello\n', 'utf-8'); + git(['add', 'file.txt'], repo); + + const diff = getGitDiff(repo, true); + assert.match(diff, /diff --git/); + assert.match(diff, /\+hello/); + } finally { + rmSync(root, { recursive: true, force: true }); + } +}); + +test('getGitDiff returns the unstaged diff', () => { + const root = createTempDir(); + try { + const repo = initRepo(root, 'repo'); + writeFileSync(join(repo, 'file.txt'), 'initial\n', 'utf-8'); + git(['add', 'file.txt'], repo); + git(['commit', '-m', 'feat: initial'], repo); + writeFileSync(join(repo, 'file.txt'), 'modified\n', 'utf-8'); + + const diff = getGitDiff(repo, false); + assert.match(diff, /diff --git/); + assert.match(diff, /-initial/); + assert.match(diff, /\+modified/); + } finally { + rmSync(root, { recursive: true, force: true }); + } +}); + +test('getGitDiff throws when not in a git repo', () => { + const root = createTempDir(); + try { + assert.throws(() => getGitDiff(root, true), /Failed to get diff/); + } finally { + rmSync(root, { recursive: true, force: true }); + } +}); + +// ─── gitCommit ──────────────────────────────────────────────────────────── + +test('gitCommit creates a commit and returns hash and summary', () => { + const root = createTempDir(); + try { + const repo = initRepo(root, 'repo'); + writeFileSync(join(repo, 'file.txt'), 'content\n', 'utf-8'); + git(['add', 'file.txt'], repo); + + const result = gitCommit(repo, 'feat: initial commit'); + + assert.ok(result.hash, 'expected a commit hash'); + assert.ok(/^[0-9a-f]+$/.test(result.hash), 'hash should be hex'); + assert.ok(result.hash.length >= 7, 'hash should be at least 7 chars'); + assert.ok(result.summary.includes('feat: initial commit')); + } finally { + rmSync(root, { recursive: true, force: true }); + } +}); + +test('gitCommit includes body in the commit message', () => { + const root = createTempDir(); + try { + const repo = initRepo(root, 'repo'); + writeFileSync(join(repo, 'file.txt'), 'content\n', 'utf-8'); + git(['add', 'file.txt'], repo); + + const result = gitCommit(repo, 'feat: with body', 'Optional body text here'); + + assert.ok(result.hash, 'expected a commit hash'); + // Verify body is in the full commit message + const log = git(['log', '--format=%B', '-1'], repo); + assert.match(log, /feat: with body/); + assert.match(log, /Optional body text here/); + } finally { + rmSync(root, { recursive: true, force: true }); + } +}); + +test('gitCommit throws on empty commit (nothing to commit)', () => { + const root = createTempDir(); + try { + const repo = initRepo(root, 'repo'); + assert.throws(() => gitCommit(repo, 'message'), /nothing to commit/i); + } finally { + rmSync(root, { recursive: true, force: true }); + } +}); diff --git a/tests/client-stream.test.mjs b/tests/client-stream.test.mjs new file mode 100644 index 0000000..3508659 --- /dev/null +++ b/tests/client-stream.test.mjs @@ -0,0 +1,131 @@ +import assert from 'node:assert/strict'; +import test from 'node:test'; + +import { generateSuggestionsStream } from '../dist/llm/client.js'; +import { streamFromChunks } from './helpers/stream-from-chunks.mjs'; + +const emptyProfile = { + avgLength: 0, + commonPrefixes: [], + prefixRates: {}, + imperativeRate: 0, + sentenceCaseRate: 0, + usesScopeRate: 0, + usesBodyRate: 0, + totalCommits: 0, +}; + +test('generateSuggestionsStream yields meta then text chunks', async () => { + const originalFetch = globalThis.fetch; + + globalThis.fetch = async () => + new Response( + streamFromChunks([ + 'data: {"choices":[{"delta":{"content":"1. feat: stream test"}}]}\n', + 'data: [DONE]\n', + ]), + { status: 200 }, + ); + + try { + const events = []; + for await (const event of generateSuggestionsStream( + { + provider: '__custom__', + model: 'test-model', + baseUrl: 'http://127.0.0.1/v1', + apiKey: 'test-key', + historySize: 5, + maxDiffSize: 100_000, + }, + 'diff --git a/file.txt b/file.txt\n', + emptyProfile, + 'test-key', + )) { + events.push(event); + } + + assert.equal(events[0]?.kind, 'meta'); + assert.equal(events[0]?.truncation, undefined); + + const chunks = events + .filter((event) => event.kind === 'text') + .map((event) => event.text); + assert.equal(chunks.join(''), '1. feat: stream test'); + } finally { + globalThis.fetch = originalFetch; + } +}); + +test('generateSuggestionsStream yields model from provider stream', async () => { + const originalFetch = globalThis.fetch; + + globalThis.fetch = async () => + new Response( + streamFromChunks([ + 'data: {"model":"gpt-4o-mini","choices":[{"delta":{"content":"1. feat: stream test"}}]}\n', + 'data: [DONE]\n', + ]), + { status: 200 }, + ); + + try { + const events = []; + for await (const event of generateSuggestionsStream( + { + provider: '__custom__', + model: 'test-model', + baseUrl: 'http://127.0.0.1/v1', + apiKey: 'test-key', + historySize: 5, + maxDiffSize: 100_000, + }, + 'diff --git a/file.txt b/file.txt\n', + emptyProfile, + 'test-key', + )) { + events.push(event); + } + + const modelEvent = events.find((event) => event.kind === 'model'); + assert.equal(modelEvent?.model, 'gpt-4o-mini'); + } finally { + globalThis.fetch = originalFetch; + } +}); + +test('generateSuggestionsStream meta includes truncation info', async () => { + const originalFetch = globalThis.fetch; + const largeDiff = `diff --git a/big.txt b/big.txt\n${'x'.repeat(200)}`; + + globalThis.fetch = async () => + new Response( + streamFromChunks(['data: [DONE]\n']), + { status: 200 }, + ); + + try { + const events = []; + for await (const event of generateSuggestionsStream( + { + provider: '__custom__', + model: 'test-model', + baseUrl: 'http://127.0.0.1/v1', + apiKey: 'test-key', + historySize: 5, + maxDiffSize: 50, + }, + largeDiff, + emptyProfile, + 'test-key', + )) { + events.push(event); + } + + assert.equal(events[0]?.kind, 'meta'); + assert.equal(events[0]?.truncation?.wasTruncated, true); + assert.ok((events[0]?.truncation?.originalSize ?? 0) > 50); + } finally { + globalThis.fetch = originalFetch; + } +}); diff --git a/tests/config-store.test.mjs b/tests/config-store.test.mjs new file mode 100644 index 0000000..ce335c8 --- /dev/null +++ b/tests/config-store.test.mjs @@ -0,0 +1,54 @@ +import assert from "node:assert/strict"; +import { mkdir, mkdtemp, writeFile } from "node:fs/promises"; +import { tmpdir } from "node:os"; +import { dirname } from "node:path"; +import test from "node:test"; + +import { getConfigPath, loadConfig } from "../dist/config/store.js"; + +function escapeRegExp(value) { + return value.replace(/[.*+?^${}()|[\]\\]/g, "\\$&"); +} + +test("loadConfig reports invalid JSON with the config path and fix hint", async () => { + const originalHome = process.env.HOME; + const originalXdgConfigHome = process.env.XDG_CONFIG_HOME; + const originalAppData = process.env.APPDATA; + const home = await mkdtemp(`${tmpdir()}/commit-echo-config-`); + + process.env.HOME = home; + process.env.APPDATA = home; + delete process.env.XDG_CONFIG_HOME; + + try { + const configPath = getConfigPath(); + await mkdir(dirname(configPath), { recursive: true }); + await writeFile(configPath, "{ invalid json", "utf-8"); + + await assert.rejects(loadConfig(), (error) => { + assert.equal(error instanceof Error, true); + assert.match(error.message, /Invalid JSON in config file:/); + assert.match(error.message, new RegExp(escapeRegExp(configPath))); + assert.match(error.message, /Fix the JSON syntax or run `commit-echo init` to recreate it\./); + return true; + }); + } finally { + if (originalHome === undefined) { + delete process.env.HOME; + } else { + process.env.HOME = originalHome; + } + + if (originalAppData === undefined) { + delete process.env.APPDATA; + } else { + process.env.APPDATA = originalAppData; + } + + if (originalXdgConfigHome === undefined) { + delete process.env.XDG_CONFIG_HOME; + } else { + process.env.XDG_CONFIG_HOME = originalXdgConfigHome; + } + } +}); diff --git a/tests/e2e/suggest-smoke.test.mjs b/tests/e2e/suggest-smoke.test.mjs index 629bc02..b0b8054 100644 --- a/tests/e2e/suggest-smoke.test.mjs +++ b/tests/e2e/suggest-smoke.test.mjs @@ -19,6 +19,10 @@ function onceExit(child) { }); } +function stripAnsi(text) { + return text.replace(/\x1B\[[0-?]*[ -/]*[@-~]/g, ''); +} + function runSuggestUntil(args, { cwd, env, text }) { return new Promise((resolve, reject) => { const child = spawn(process.execPath, [join(process.cwd(), 'dist/index.js'), ...args], { @@ -64,6 +68,36 @@ function runSuggestUntil(args, { cwd, env, text }) { }); } +function runCli(args, { cwd, env }) { + return new Promise((resolve, reject) => { + const child = spawn(process.execPath, [join(process.cwd(), 'dist/index.js'), ...args], { + cwd, + env, + stdio: ['ignore', 'pipe', 'pipe'], + }); + let stdout = ''; + let stderr = ''; + const timeout = setTimeout(() => { + child.kill('SIGINT'); + reject(new Error(`Timed out running ${args.join(' ')}. stdout: ${stdout} stderr: ${stderr}`)); + }, 8000); + child.stdout.on('data', (chunk) => { + stdout += chunk.toString(); + }); + child.stderr.on('data', (chunk) => { + stderr += chunk.toString(); + }); + child.on('error', (err) => { + clearTimeout(timeout); + reject(err); + }); + child.on('exit', (code, signal) => { + clearTimeout(timeout); + resolve({ code, signal, stdout, stderr }); + }); + }); +} + function configDirFor(home) { return platform() === 'darwin' ? join(home, 'Library', 'Application Support', 'commit-echo') @@ -79,6 +113,7 @@ async function setupRepo(root) { await mkdir(repo, { recursive: true }); execFileSync('git', ['init'], { cwd: repo }); + execFileSync('git', ['config', 'core.fsmonitor', 'false'], { cwd: repo }); execFileSync('git', ['config', 'user.name', 'E2E Tester'], { cwd: repo }); execFileSync('git', ['config', 'user.email', 'e2e@example.com'], { cwd: repo }); await writeFile(join(repo, 'README.md'), '# fixture\n', 'utf8'); @@ -93,6 +128,24 @@ async function setupRepo(root) { return { home, repo, configDir }; } +async function writeCustomProviderConfig(configDir, port) { + await writeFile( + join(configDir, 'config.json'), + JSON.stringify( + { + provider: '__custom__', + model: 'fixture-model', + baseUrl: `http://127.0.0.1:${port}`, + apiKey: 'test-key', + historySize: 5, + }, + null, + 2, + ), + 'utf8', + ); +} + test('suggest smoke test boots the CLI, loads config, and prints suggestions', async (t) => { const root = await mkdtemp(join(tmpdir(), 'commit-echo-e2e-')); const { home, repo, configDir } = await setupRepo(root); @@ -177,6 +230,116 @@ test('suggest smoke test boots the CLI, loads config, and prints suggestions', a assert.ok(result.code === 0 || result.signal === 'SIGINT'); }); +test('suggest --auto selects the first suggestion like --yes without committing', async (t) => { + const root = await mkdtemp(join(tmpdir(), 'commit-echo-auto-suggest-')); + const { home, repo, configDir } = await setupRepo(root); + + const server = createServer(async (req, res) => { + if (req.url === '/chat/completions' && req.method === 'POST') { + res.writeHead(200, { 'Content-Type': 'application/json' }); + res.end( + JSON.stringify({ + model: 'fixture-model', + choices: [{ message: { content: '1. feat: choose first alias\n2. docs: should not select' } }], + }), + ); + return; + } + + res.writeHead(404, { 'Content-Type': 'application/json' }); + res.end(JSON.stringify({ error: 'not found' })); + }); + const port = await listen(server); + t.after(async () => { + server.close(); + await rm(root, { recursive: true, force: true }); + }); + + await writeCustomProviderConfig(configDir, port); + + const env = { + ...process.env, + HOME: home, + XDG_CONFIG_HOME: join(home, '.config'), + APPDATA: join(home, 'AppData', 'Roaming'), + FORCE_COLOR: '0', + }; + + const yes = await runCli(['suggest', '--yes'], { cwd: repo, env }); + const auto = await runCli(['suggest', '--auto'], { cwd: repo, env }); + + for (const result of [yes, auto]) { + const stdout = stripAnsi(result.stdout); + assert.equal(result.code, 0); + assert.equal(result.signal, null); + assert.equal(result.stderr, ''); + assert.match(stdout, /Suggestions generated:/); + assert.match(stdout, /Selected:\s+feat: choose first alias/); + assert.doesNotMatch(stdout, /Choose an action/); + } + + assert.equal( + execFileSync('git', ['log', '-1', '--pretty=%s'], { cwd: repo, encoding: 'utf8' }).trim(), + 'feat: initial fixture', + ); +}); + +test('top-level --auto commits the first suggestion like --yes', async (t) => { + const root = await mkdtemp(join(tmpdir(), 'commit-echo-auto-commit-')); + const yesFixture = await setupRepo(join(root, 'yes')); + const autoFixture = await setupRepo(join(root, 'auto')); + + const server = createServer(async (req, res) => { + if (req.url === '/chat/completions' && req.method === 'POST') { + res.writeHead(200, { 'Content-Type': 'application/json' }); + res.end( + JSON.stringify({ + model: 'fixture-model', + choices: [{ message: { content: '1. feat: auto alias parity\n2. docs: should not commit' } }], + }), + ); + return; + } + + res.writeHead(404, { 'Content-Type': 'application/json' }); + res.end(JSON.stringify({ error: 'not found' })); + }); + const port = await listen(server); + t.after(async () => { + server.close(); + await rm(root, { recursive: true, force: true }); + }); + + await writeCustomProviderConfig(yesFixture.configDir, port); + await writeCustomProviderConfig(autoFixture.configDir, port); + + const envFor = (home) => ({ + ...process.env, + HOME: home, + XDG_CONFIG_HOME: join(home, '.config'), + APPDATA: join(home, 'AppData', 'Roaming'), + FORCE_COLOR: '0', + }); + + const yes = await runCli(['--yes'], { cwd: yesFixture.repo, env: envFor(yesFixture.home) }); + const auto = await runCli(['--auto'], { cwd: autoFixture.repo, env: envFor(autoFixture.home) }); + + for (const [result, repo] of [ + [yes, yesFixture.repo], + [auto, autoFixture.repo], + ]) { + const stdout = stripAnsi(result.stdout); + assert.equal(result.code, 0); + assert.equal(result.signal, null); + assert.equal(result.stderr, ''); + assert.match(stdout, /Selected:\s+feat: auto alias parity/); + assert.equal( + execFileSync('git', ['log', '-1', '--pretty=%s'], { cwd: repo, encoding: 'utf8' }).trim(), + 'feat: auto alias parity', + ); + } +}); + test('suggest reports no changes before checking for an API key', async (t) => { const root = await mkdtemp(join(tmpdir(), 'commit-echo-no-changes-')); const { home, repo, configDir } = await setupRepo(root); @@ -281,3 +444,329 @@ test('suggest --model overrides configured model for one invocation and -m is an await runSuggestUntil(['suggest', '-m', 'claude-3-5-sonnet'], { cwd: repo, env, text: 'Suggestions generated:' }); assert.equal(requests.at(-1).model, 'claude-3-5-sonnet'); }); + +test('suggest --stream prints incremental SSE output', async (t) => { + const root = await mkdtemp(join(tmpdir(), 'commit-echo-e2e-stream-')); + const { home, repo, configDir } = await setupRepo(root); + + const requests = []; + const server = createServer(async (req, res) => { + if (req.url === '/chat/completions' && req.method === 'POST') { + let body = ''; + req.setEncoding('utf8'); + for await (const chunk of req) body += chunk; + const parsed = JSON.parse(body); + requests.push(parsed); + + if (parsed.stream) { + res.writeHead(200, { 'Content-Type': 'text/event-stream' }); + res.write('data: {"choices":[{"delta":{"content":"1. feat: streamed suggestion"}}]}\n\n'); + res.write('data: [DONE]\n\n'); + res.end(); + return; + } + + res.writeHead(200, { 'Content-Type': 'application/json' }); + res.end(JSON.stringify({ + model: parsed.model, + choices: [{ message: { content: '1. feat: fallback suggestion' } }], + })); + return; + } + + res.writeHead(404, { 'Content-Type': 'application/json' }); + res.end(JSON.stringify({ error: 'not found' })); + }); + const port = await listen(server); + t.after(async () => { + server.close(); + await rm(root, { recursive: true, force: true }); + }); + + await writeFile( + join(configDir, 'config.json'), + JSON.stringify({ + provider: '__custom__', + model: 'fixture-model', + baseUrl: `http://127.0.0.1:${port}`, + apiKey: 'test-key', + historySize: 5, + }, null, 2), + 'utf8' + ); + + const { stdout } = await runSuggestUntil( + ['suggest', '--stream'], + { + cwd: repo, + env: { + ...process.env, + HOME: home, + XDG_CONFIG_HOME: join(home, '.config'), + APPDATA: join(home, 'AppData', 'Roaming'), + FORCE_COLOR: '0', + }, + text: 'feat: streamed suggestion', + }, + ); + + assert.match(stdout, /Streaming suggestions/); + assert.match(stdout, /feat: streamed suggestion/); + assert.equal(requests.at(-1)?.stream, true); + assert.equal( + (stdout.match(/feat: streamed suggestion/g) ?? []).length, + 1, + 'streamed suggestion text should not be printed twice', + ); +}); + +test('suggest --stream prints incremental Anthropic SSE output', async (t) => { + const root = await mkdtemp(join(tmpdir(), 'commit-echo-e2e-stream-anthropic-')); + const { home, repo, configDir } = await setupRepo(root); + + const requests = []; + const server = createServer(async (req, res) => { + if (req.url === '/v1/messages' && req.method === 'POST') { + let body = ''; + req.setEncoding('utf8'); + for await (const chunk of req) body += chunk; + const parsed = JSON.parse(body); + requests.push(parsed); + + if (parsed.stream) { + res.writeHead(200, { 'Content-Type': 'text/event-stream' }); + res.write('event: content_block_delta\n'); + res.write('data: {"delta":{"text":"1. feat: anthropic streamed suggestion"}}\n\n'); + res.write('event: message_stop\n'); + res.write('data: {}\n\n'); + res.end(); + return; + } + + res.writeHead(200, { 'Content-Type': 'application/json' }); + res.end(JSON.stringify({ + model: parsed.model, + content: [{ type: 'text', text: '1. feat: anthropic fallback suggestion' }], + })); + return; + } + + res.writeHead(404, { 'Content-Type': 'application/json' }); + res.end(JSON.stringify({ error: 'not found' })); + }); + const port = await listen(server); + t.after(async () => { + server.close(); + await rm(root, { recursive: true, force: true }); + }); + + await writeFile( + join(configDir, 'config.json'), + JSON.stringify({ + provider: 'anthropic', + model: 'claude-sonnet-4-20250514', + baseUrl: `http://127.0.0.1:${port}/v1`, + apiKey: 'test-key', + historySize: 5, + }, null, 2), + 'utf8' + ); + + const { stdout } = await runSuggestUntil( + ['suggest', '--stream'], + { + cwd: repo, + env: { + ...process.env, + HOME: home, + XDG_CONFIG_HOME: join(home, '.config'), + APPDATA: join(home, 'AppData', 'Roaming'), + FORCE_COLOR: '0', + }, + text: 'feat: anthropic streamed suggestion', + }, + ); + + assert.match(stdout, /Streaming suggestions/); + assert.match(stdout, /feat: anthropic streamed suggestion/); + assert.equal(requests.at(-1)?.stream, true); +}); + +test('suggest --stream --yes streams output and auto-commits the first suggestion', async (t) => { + const root = await mkdtemp(join(tmpdir(), 'commit-echo-e2e-stream-yes-')); + const { home, repo, configDir } = await setupRepo(root); + + const requests = []; + const server = createServer(async (req, res) => { + if (req.url === '/chat/completions' && req.method === 'POST') { + let body = ''; + req.setEncoding('utf8'); + for await (const chunk of req) body += chunk; + const parsed = JSON.parse(body); + requests.push(parsed); + + if (parsed.stream) { + res.writeHead(200, { 'Content-Type': 'text/event-stream' }); + res.write('data: {"choices":[{"delta":{"content":"1. feat: stream auto commit"}}]}\n\n'); + res.write('data: [DONE]\n\n'); + res.end(); + return; + } + + res.writeHead(200, { 'Content-Type': 'application/json' }); + res.end(JSON.stringify({ + model: parsed.model, + choices: [{ message: { content: '1. feat: fallback suggestion' } }], + })); + return; + } + + res.writeHead(404, { 'Content-Type': 'application/json' }); + res.end(JSON.stringify({ error: 'not found' })); + }); + const port = await listen(server); + t.after(async () => { + server.close(); + await rm(root, { recursive: true, force: true }); + }); + + await writeFile( + join(configDir, 'config.json'), + JSON.stringify({ + provider: '__custom__', + model: 'fixture-model', + baseUrl: `http://127.0.0.1:${port}`, + apiKey: 'test-key', + historySize: 5, + }, null, 2), + 'utf8' + ); + + const child = spawn(process.execPath, [join(process.cwd(), 'dist/index.js'), 'suggest', '--stream', '--yes'], { + cwd: repo, + env: { + ...process.env, + HOME: home, + XDG_CONFIG_HOME: join(home, '.config'), + APPDATA: join(home, 'AppData', 'Roaming'), + FORCE_COLOR: '0', + }, + stdio: ['ignore', 'pipe', 'pipe'], + }); + + let stdout = ''; + child.stdout.on('data', (chunk) => { + stdout += chunk.toString(); + }); + + const result = await onceExit(child); + assert.equal(result.code, 0); + assert.match(stdout, /Streaming suggestions/); + assert.match(stdout, /feat: stream auto commit/); + assert.match(stdout, /Selected:/); + assert.doesNotMatch(stdout, /Choose an action/); + assert.equal(requests.at(-1)?.stream, true); +}); + +test('suggest --stream fails fast for unsupported providers', async (t) => { + const root = await mkdtemp(join(tmpdir(), 'commit-echo-e2e-stream-cohere-')); + const { home, repo, configDir } = await setupRepo(root); + + await writeFile( + join(configDir, 'config.json'), + JSON.stringify({ + provider: 'cohere', + model: 'command-r', + apiKey: 'test-key', + historySize: 5, + }, null, 2), + 'utf8' + ); + + t.after(async () => { + await rm(root, { recursive: true, force: true }); + }); + + const child = spawn(process.execPath, [join(process.cwd(), 'dist/index.js'), 'suggest', '--stream'], { + cwd: repo, + env: { + ...process.env, + HOME: home, + XDG_CONFIG_HOME: join(home, '.config'), + APPDATA: join(home, 'AppData', 'Roaming'), + FORCE_COLOR: '0', + }, + stdio: ['ignore', 'pipe', 'pipe'], + }); + + let stdout = ''; + child.stdout.on('data', (chunk) => { + stdout += chunk.toString(); + }); + + const result = await onceExit(child); + assert.equal(result.code, 0); + assert.match(stdout, /Streaming is not supported for the 'cohere' provider/); + assert.doesNotMatch(stdout, /Streaming suggestions/); +}); + +test('suggest --stream reports parse failure for unparseable streamed output', async (t) => { + const root = await mkdtemp(join(tmpdir(), 'commit-echo-e2e-stream-parse-')); + const { home, repo, configDir } = await setupRepo(root); + + const server = createServer(async (req, res) => { + if (req.url === '/chat/completions' && req.method === 'POST') { + let body = ''; + req.setEncoding('utf8'); + for await (const chunk of req) body += chunk; + + res.writeHead(200, { 'Content-Type': 'text/event-stream' }); + res.write('data: {"choices":[{"delta":{"content":"not a numbered suggestion list"}}]}\n\n'); + res.write('data: [DONE]\n\n'); + res.end(); + return; + } + + res.writeHead(404, { 'Content-Type': 'application/json' }); + res.end(JSON.stringify({ error: 'not found' })); + }); + const port = await listen(server); + t.after(async () => { + server.close(); + await rm(root, { recursive: true, force: true }); + }); + + await writeFile( + join(configDir, 'config.json'), + JSON.stringify({ + provider: '__custom__', + model: 'fixture-model', + baseUrl: `http://127.0.0.1:${port}`, + apiKey: 'test-key', + historySize: 5, + }, null, 2), + 'utf8' + ); + + const child = spawn(process.execPath, [join(process.cwd(), 'dist/index.js'), 'suggest', '--stream'], { + cwd: repo, + env: { + ...process.env, + HOME: home, + XDG_CONFIG_HOME: join(home, '.config'), + APPDATA: join(home, 'AppData', 'Roaming'), + FORCE_COLOR: '0', + }, + stdio: ['ignore', 'pipe', 'pipe'], + }); + + let stdout = ''; + child.stdout.on('data', (chunk) => { + stdout += chunk.toString(); + }); + + const result = await onceExit(child); + assert.equal(result.code, 0); + assert.match(stdout, /not a numbered suggestion list/); + assert.match(stdout, /Could not parse any suggestions from LLM response/); +}); diff --git a/tests/git-diff.test.mjs b/tests/git-diff.test.mjs index 9c9c346..dd5ad8a 100644 --- a/tests/git-diff.test.mjs +++ b/tests/git-diff.test.mjs @@ -4,6 +4,7 @@ import { existsSync, mkdirSync, mkdtempSync, + realpathSync, rmSync, writeFileSync, } from "node:fs"; @@ -20,7 +21,7 @@ import { } from "../dist/git/diff.js"; function createTempDir() { - return mkdtempSync(join(tmpdir(), "commit-echo-git-diff-test-")); + return realpathSync.native(mkdtempSync(join(tmpdir(), "commit-echo-git-diff-test-"))); } function git(args, cwd) { @@ -35,6 +36,7 @@ function initRepo() { const repoDir = createTempDir(); git(["init"], repoDir); + git(["config", "core.fsmonitor", "false"], repoDir); git(["config", "user.name", "Test User"], repoDir); git(["config", "user.email", "test@example.com"], repoDir); diff --git a/tests/helpers/stream-from-chunks.mjs b/tests/helpers/stream-from-chunks.mjs new file mode 100644 index 0000000..a60dc5b --- /dev/null +++ b/tests/helpers/stream-from-chunks.mjs @@ -0,0 +1,13 @@ +export function streamFromChunks(chunks) { + let index = 0; + return new ReadableStream({ + pull(controller) { + if (index >= chunks.length) { + controller.close(); + return; + } + controller.enqueue(new TextEncoder().encode(chunks[index])); + index += 1; + }, + }); +} diff --git a/tests/history-profile.test.mjs b/tests/history-profile.test.mjs index 19f660d..55f09d2 100644 --- a/tests/history-profile.test.mjs +++ b/tests/history-profile.test.mjs @@ -2,12 +2,20 @@ import assert from 'node:assert/strict'; import test from 'node:test'; import { mkdtempSync, writeFileSync, rmSync, mkdirSync } from 'node:fs'; import { join } from 'node:path'; -import { tmpdir } from 'node:os'; +import { platform, tmpdir } from 'node:os'; import { buildProfile } from '../dist/history/store.js'; +function configDirFor(homeDir) { + return platform() === 'darwin' + ? join(homeDir, 'Library', 'Application Support', 'commit-echo') + : platform() === 'win32' + ? join(homeDir, 'AppData', 'Roaming', 'commit-echo') + : join(homeDir, '.config', 'commit-echo'); +} + function writeHistory(homeDir, messages) { - const configDir = join(homeDir, 'Library', 'Application Support', 'commit-echo'); + const configDir = configDirFor(homeDir); const historyPath = join(configDir, 'history.jsonl'); mkdirSync(configDir, { recursive: true }); writeFileSync( @@ -19,16 +27,28 @@ function writeHistory(homeDir, messages) { model: 'test-model', provider: 'openai', })).join('\n') + '\n', - 'utf-8' + 'utf-8', ); } +function restoreEnv(name, value) { + if (value === undefined) { + delete process.env[name]; + } else { + process.env[name] = value; + } +} + test('buildProfile excludes descriptive verb forms from the imperative-rate denominator', async () => { const originalHome = process.env.HOME; + const originalAppData = process.env.APPDATA; + const originalXdgConfigHome = process.env.XDG_CONFIG_HOME; const tempHome = mkdtempSync(join(tmpdir(), 'commit-echo-home-')); try { process.env.HOME = tempHome; + process.env.APPDATA = join(tempHome, 'AppData', 'Roaming'); + process.env.XDG_CONFIG_HOME = join(tempHome, '.config'); writeHistory(tempHome, [ 'fix: add retries', 'fix: added retries', @@ -40,7 +60,9 @@ test('buildProfile excludes descriptive verb forms from the imperative-rate deno assert.equal(profile.totalCommits, 3); assert.equal(profile.imperativeRate, 1); } finally { - process.env.HOME = originalHome; + restoreEnv('HOME', originalHome); + restoreEnv('APPDATA', originalAppData); + restoreEnv('XDG_CONFIG_HOME', originalXdgConfigHome); rmSync(tempHome, { recursive: true, force: true }); } }); diff --git a/tests/providers-index.test.mjs b/tests/providers-index.test.mjs index 544be22..bc47273 100644 --- a/tests/providers-index.test.mjs +++ b/tests/providers-index.test.mjs @@ -1,7 +1,12 @@ import assert from 'node:assert/strict'; import test from 'node:test'; -import { createProvider, fetchModels } from '../dist/providers/index.js'; +import { + getStreamingProvider, + completeStream, + createProvider, + fetchModels, +} from '../dist/providers/index.js'; test('createProvider returns the Anthropic adapter shape', () => { const provider = createProvider('anthropic'); @@ -16,3 +21,30 @@ test('fetchModels returns Anthropic model ids', async () => { assert.ok(models.length > 0); assert.ok(models.includes('claude-sonnet-4')); }); + +test('getStreamingProvider rejects providers without streaming', () => { + assert.throws( + () => getStreamingProvider('cohere'), + /Streaming is not supported for the 'cohere' provider/, + ); +}); + +test('getStreamingProvider accepts streaming providers', () => { + assert.doesNotThrow(() => getStreamingProvider('anthropic')); + assert.doesNotThrow(() => getStreamingProvider('openai')); +}); + +test('completeStream rejects Cohere before making a request', async () => { + await assert.rejects( + async () => { + for await (const _chunk of completeStream('cohere', undefined, { + model: 'command', + messages: [{ role: 'user', content: 'test' }], + apiKey: 'test-key', + })) { + // no-op + } + }, + /Streaming is not supported for the 'cohere' provider/, + ); +}); diff --git a/tests/stream-sse.test.mjs b/tests/stream-sse.test.mjs new file mode 100644 index 0000000..f8d296c --- /dev/null +++ b/tests/stream-sse.test.mjs @@ -0,0 +1,176 @@ +import assert from 'node:assert/strict'; +import test from 'node:test'; + +import { + parseAnthropicSseLine, + parseOpenAiSseLine, + SSE_STREAM_END, +} from '../dist/providers/sse.js'; +import { AnthropicProvider } from '../dist/providers/anthropic.js'; +import { OpenAICompatibleProvider } from '../dist/providers/openai-compatible.js'; +import { streamFromChunks } from './helpers/stream-from-chunks.mjs'; + +test('parseOpenAiSseLine extracts delta content', () => { + const result = parseOpenAiSseLine( + 'data: {"choices":[{"delta":{"content":"hello"}}]}', + ); + + assert.equal(result.text, 'hello'); +}); + +test('parseOpenAiSseLine extracts model from stream chunk', () => { + const result = parseOpenAiSseLine( + 'data: {"model":"gpt-4o","choices":[{"delta":{"content":"hello"}}]}', + ); + + assert.equal(result.model, 'gpt-4o'); + assert.equal(result.text, 'hello'); +}); + +test('parseOpenAiSseLine detects stream completion', () => { + assert.deepEqual(parseOpenAiSseLine('data: [DONE]'), { done: true }); +}); + +test('parseOpenAiSseLine surfaces API errors', () => { + const result = parseOpenAiSseLine( + 'data: {"error":{"message":"rate limited"}}', + ); + + assert.equal(result.error, 'rate limited'); +}); + +test('parseAnthropicSseLine handles event and data split across batches', () => { + const state = { currentEvent: '' }; + + const eventResult = parseAnthropicSseLine('event: content_block_delta', state); + assert.equal(eventResult, null); + + assert.equal(state.currentEvent, 'content_block_delta'); + + const dataResult = parseAnthropicSseLine( + 'data: {"delta":{"text":"hello"}}', + state, + ); + assert.deepEqual(dataResult, { kind: 'text', text: 'hello' }); +}); + +test('parseAnthropicSseLine extracts model from message_start', () => { + const state = { currentEvent: '' }; + parseAnthropicSseLine('event: message_start', state); + const result = parseAnthropicSseLine( + 'data: {"type":"message_start","message":{"model":"claude-sonnet-4"}}', + state, + ); + assert.deepEqual(result, { kind: 'model', model: 'claude-sonnet-4' }); +}); + +test('parseAnthropicSseLine returns SSE_STREAM_END on message_stop', () => { + const state = { currentEvent: '' }; + parseAnthropicSseLine('event: message_stop', state); + const result = parseAnthropicSseLine('data: {}', state); + assert.equal(result, SSE_STREAM_END); +}); + +test('parseAnthropicSseLine throws on error events', () => { + const state = { currentEvent: '' }; + parseAnthropicSseLine('event: error', state); + assert.throws( + () => parseAnthropicSseLine('data: {"error":{"message":"overloaded"}}', state), + /overloaded/, + ); +}); + +test('Anthropic completeStream reassembles event/data split across network chunks', async () => { + const originalFetch = globalThis.fetch; + const provider = new AnthropicProvider(); + + globalThis.fetch = async () => + new Response( + streamFromChunks([ + 'event: content_block_delta\n', + 'data: {"delta":{"text":"hi"}}\n', + 'event: message_stop\n', + 'data: {}\n', + ]), + { status: 200 }, + ); + + try { + const chunks = []; + for await (const chunk of provider.completeStream({ + model: 'claude-sonnet-4', + messages: [{ role: 'user', content: 'test' }], + apiKey: 'test-key', + baseUrl: 'https://api.anthropic.com/v1', + })) { + chunks.push(chunk); + } + + assert.deepEqual(chunks, [{ kind: 'text', text: 'hi' }]); + } finally { + globalThis.fetch = originalFetch; + } +}); + +test('OpenAI completeStream processes final line without trailing newline', async () => { + const originalFetch = globalThis.fetch; + const provider = new OpenAICompatibleProvider(); + + globalThis.fetch = async () => + new Response( + streamFromChunks([ + 'data: {"choices":[{"delta":{"content":"hel"}}]}\n', + 'data: {"choices":[{"delta":{"content":"lo"}}]}', + ]), + { status: 200 }, + ); + + try { + const chunks = []; + for await (const chunk of provider.completeStream({ + model: 'gpt-4o', + messages: [{ role: 'user', content: 'test' }], + apiKey: 'test-key', + baseUrl: 'https://api.openai.com/v1', + })) { + chunks.push(chunk); + } + + assert.deepEqual(chunks, [ + { kind: 'text', text: 'hel' }, + { kind: 'text', text: 'lo' }, + ]); + } finally { + globalThis.fetch = originalFetch; + } +}); + +test('OpenAI completeStream handles [DONE] in final buffer without trailing newline', async () => { + const originalFetch = globalThis.fetch; + const provider = new OpenAICompatibleProvider(); + + globalThis.fetch = async () => + new Response( + streamFromChunks([ + 'data: {"choices":[{"delta":{"content":"done"}}]}\n', + 'data: [DONE]', + ]), + { status: 200 }, + ); + + try { + const chunks = []; + for await (const chunk of provider.completeStream({ + model: 'gpt-4o', + messages: [{ role: 'user', content: 'test' }], + apiKey: 'test-key', + baseUrl: 'https://api.openai.com/v1', + })) { + chunks.push(chunk); + } + + assert.deepEqual(chunks, [{ kind: 'text', text: 'done' }]); + } finally { + globalThis.fetch = originalFetch; + } +});