From 140269602124e816e03058aa9260e89cbe5abd6b Mon Sep 17 00:00:00 2001 From: Link Dupont Date: Wed, 25 Mar 2026 21:46:48 -0400 Subject: [PATCH] feat: Add Vertex AI support for Google and Anthropic models Building on the GeminiClient, add support for models hosted through GCP Vertex, both Google models and Anthropic models. --- .../AgentRunKit/LLM/GoogleAuthService.swift | 169 ++++++++++++++ .../LLM/VertexAnthropicClient.swift | 214 ++++++++++++++++++ .../AgentRunKit/LLM/VertexGoogleClient.swift | 190 ++++++++++++++++ .../GoogleAuthServiceTests.swift | 125 ++++++++++ .../VertexAnthropicClientTests.swift | 211 +++++++++++++++++ .../VertexGoogleClientTests.swift | 149 ++++++++++++ 6 files changed, 1058 insertions(+) create mode 100644 Sources/AgentRunKit/LLM/GoogleAuthService.swift create mode 100644 Sources/AgentRunKit/LLM/VertexAnthropicClient.swift create mode 100644 Sources/AgentRunKit/LLM/VertexGoogleClient.swift create mode 100644 Tests/AgentRunKitTests/GoogleAuthServiceTests.swift create mode 100644 Tests/AgentRunKitTests/VertexAnthropicClientTests.swift create mode 100644 Tests/AgentRunKitTests/VertexGoogleClientTests.swift diff --git a/Sources/AgentRunKit/LLM/GoogleAuthService.swift b/Sources/AgentRunKit/LLM/GoogleAuthService.swift new file mode 100644 index 0000000..1e36af5 --- /dev/null +++ b/Sources/AgentRunKit/LLM/GoogleAuthService.swift @@ -0,0 +1,169 @@ +import Foundation + +// MARK: - ADC Credential File + +private struct ADCCredentials: Decodable { + let type: String + let clientId: String + let clientSecret: String + let refreshToken: String + + enum CodingKeys: String, CodingKey { + case type + case clientId = "client_id" + case clientSecret = "client_secret" + case refreshToken = "refresh_token" + } +} + +// MARK: - Token Response + +private struct TokenResponse: Decodable { + let accessToken: String + let expiresIn: Int + let tokenType: String + + enum CodingKeys: String, CodingKey { + case accessToken = "access_token" + case expiresIn = "expires_in" + case tokenType = "token_type" + } +} + +/// Manages Google OAuth2 tokens from Application Default Credentials (ADC). +/// +/// Reads `~/.config/gcloud/application_default_credentials.json` (created by +/// `gcloud auth application-default login`) and transparently refreshes access +/// tokens as needed. +/// +/// Thread-safe via `actor` isolation — only one refresh request can be in +/// flight at a time. +public actor GoogleAuthService { + // MARK: - Errors + + public enum GoogleAuthError: Error, LocalizedError, Sendable { + case credentialsFileNotFound(path: String) + case unsupportedCredentialType(String) + case refreshFailed(statusCode: Int, body: String) + case decodingFailed(String) + + public var errorDescription: String? { + switch self { + case let .credentialsFileNotFound(path): + "Google ADC credentials not found at \(path). Run `gcloud auth application-default login`." + case let .unsupportedCredentialType(type): + "Unsupported ADC credential type: \(type). Only 'authorized_user' is supported." + case let .refreshFailed(code, body): + "Token refresh failed (HTTP \(code)): \(body)" + case let .decodingFailed(message): + "Failed to decode ADC credentials: \(message)" + } + } + } + + // MARK: - State + + private let clientID: String + private let clientSecret: String + private let refreshToken: String + private let session: URLSession + + private var cachedAccessToken: String? + private var tokenExpiry: Date? + + /// Refresh the token when it has fewer than this many seconds remaining. + private let refreshMargin: TimeInterval = 300 // 5 minutes + + private static let tokenEndpoint = URL(string: "https://oauth2.googleapis.com/token")! + + // MARK: - Init + + /// Creates an auth service by reading the ADC file at the default path. + public init(session: URLSession = .shared) throws { + try self.init(credentialsPath: Self.defaultCredentialsPath(), session: session) + } + + /// Creates an auth service by reading the ADC file at a custom path. + public init(credentialsPath: String, session: URLSession = .shared) throws { + guard FileManager.default.fileExists(atPath: credentialsPath) else { + throw GoogleAuthError.credentialsFileNotFound(path: credentialsPath) + } + let data: Data + do { + data = try Data(contentsOf: URL(fileURLWithPath: credentialsPath)) + } catch { + throw GoogleAuthError.decodingFailed("Failed to read file: \(error.localizedDescription)") + } + let credentials: ADCCredentials + do { + credentials = try JSONDecoder().decode(ADCCredentials.self, from: data) + } catch { + throw GoogleAuthError.decodingFailed(error.localizedDescription) + } + guard credentials.type == "authorized_user" else { + throw GoogleAuthError.unsupportedCredentialType(credentials.type) + } + clientID = credentials.clientId + clientSecret = credentials.clientSecret + refreshToken = credentials.refreshToken + self.session = session + } + + // MARK: - Public API + + /// Returns a valid access token, refreshing if necessary. + public func accessToken() async throws -> String { + if let token = cachedAccessToken, + let expiry = tokenExpiry, + Date() < expiry.addingTimeInterval(-refreshMargin) { + return token + } + return try await refreshAccessToken() + } + + // MARK: - Private + + private func refreshAccessToken() async throws -> String { + var request = URLRequest(url: Self.tokenEndpoint) + request.httpMethod = "POST" + request.setValue("application/x-www-form-urlencoded", forHTTPHeaderField: "Content-Type") + + let body = [ + "client_id=\(urlEncode(clientID))", + "client_secret=\(urlEncode(clientSecret))", + "refresh_token=\(urlEncode(refreshToken))", + "grant_type=refresh_token", + ].joined(separator: "&") + request.httpBody = Data(body.utf8) + + let (data, response) = try await session.data(for: request) + + guard let httpResponse = response as? HTTPURLResponse else { + throw GoogleAuthError.refreshFailed(statusCode: 0, body: "Invalid response") + } + guard httpResponse.statusCode == 200 else { + let responseBody = String(data: data, encoding: .utf8) ?? "" + throw GoogleAuthError.refreshFailed(statusCode: httpResponse.statusCode, body: responseBody) + } + + let tokenResponse = try JSONDecoder().decode(TokenResponse.self, from: data) + cachedAccessToken = tokenResponse.accessToken + tokenExpiry = Date().addingTimeInterval(TimeInterval(tokenResponse.expiresIn)) + return tokenResponse.accessToken + } + + private func urlEncode(_ string: String) -> String { + string.addingPercentEncoding(withAllowedCharacters: .urlQueryAllowed) ?? string + } + + /// The default path to the ADC credentials file. + public static func defaultCredentialsPath() -> String { + let home = FileManager.default.homeDirectoryForCurrentUser.path + return "\(home)/.config/gcloud/application_default_credentials.json" + } + + /// Whether an ADC credentials file exists at the default path. + public static func credentialsAvailable() -> Bool { + FileManager.default.fileExists(atPath: defaultCredentialsPath()) + } +} diff --git a/Sources/AgentRunKit/LLM/VertexAnthropicClient.swift b/Sources/AgentRunKit/LLM/VertexAnthropicClient.swift new file mode 100644 index 0000000..1849104 --- /dev/null +++ b/Sources/AgentRunKit/LLM/VertexAnthropicClient.swift @@ -0,0 +1,214 @@ +import Foundation + +/// An LLM client for Anthropic Claude models served via Vertex AI. +/// +/// Uses OAuth2 Bearer token authentication (via ``GoogleAuthService`` or a custom +/// token provider closure) instead of Anthropic API key authentication. +/// +/// The wire format is the standard Anthropic Messages API with a +/// `"anthropic_version": "vertex-2023-10-16"` field injected into the request body. +/// Response parsing and SSE streaming are delegated to an internal ``AnthropicClient``. +/// +/// ```swift +/// let auth = try GoogleAuthService() +/// let client = VertexAnthropicClient( +/// projectID: "my-project", +/// location: "us-east5", +/// model: "claude-sonnet-4-6", +/// authService: auth +/// ) +/// ``` +public struct VertexAnthropicClient: LLMClient, Sendable { + public let contextWindowSize: Int? + + let anthropic: AnthropicClient + private let projectID: String + private let location: String + private let model: String + private let tokenProvider: @Sendable () async throws -> String + private let session: URLSession + private let retryPolicy: RetryPolicy + + public init( + projectID: String, + location: String, + model: String, + tokenProvider: @Sendable @escaping () async throws -> String, + maxTokens: Int = 8192, + contextWindowSize: Int? = nil, + session: URLSession = .shared, + retryPolicy: RetryPolicy = .default, + reasoningConfig: ReasoningConfig? = nil, + interleavedThinking: Bool = true, + cachingEnabled: Bool = false + ) { + self.projectID = projectID + self.location = location + self.model = model + self.tokenProvider = tokenProvider + self.session = session + self.retryPolicy = retryPolicy + self.contextWindowSize = contextWindowSize + anthropic = AnthropicClient( + apiKey: "", + model: model, + maxTokens: maxTokens, + contextWindowSize: contextWindowSize, + session: session, + retryPolicy: retryPolicy, + reasoningConfig: reasoningConfig, + interleavedThinking: interleavedThinking, + cachingEnabled: cachingEnabled + ) + } + + /// Convenience initializer that uses a ``GoogleAuthService`` for authentication. + public init( + projectID: String, + location: String, + model: String, + authService: GoogleAuthService, + maxTokens: Int = 8192, + contextWindowSize: Int? = nil, + session: URLSession = .shared, + retryPolicy: RetryPolicy = .default, + reasoningConfig: ReasoningConfig? = nil, + interleavedThinking: Bool = true, + cachingEnabled: Bool = false + ) { + self.init( + projectID: projectID, + location: location, + model: model, + tokenProvider: { try await authService.accessToken() }, + maxTokens: maxTokens, + contextWindowSize: contextWindowSize, + session: session, + retryPolicy: retryPolicy, + reasoningConfig: reasoningConfig, + interleavedThinking: interleavedThinking, + cachingEnabled: cachingEnabled + ) + } + + // MARK: - LLMClient + + public func generate( + messages: [ChatMessage], + tools: [ToolDefinition], + responseFormat: ResponseFormat?, + requestContext: RequestContext? + ) async throws -> AssistantMessage { + if responseFormat != nil { + throw AgentError.llmError(.other("VertexAnthropicClient does not support responseFormat")) + } + let request = try anthropic.buildRequest( + messages: messages, + tools: tools, + extraFields: requestContext?.extraFields ?? [:] + ) + let token = try await tokenProvider() + let urlRequest = try buildVertexURLRequest( + VertexAnthropicRequest(inner: request), stream: false, token: token + ) + let (data, httpResponse) = try await HTTPRetry.performData( + urlRequest: urlRequest, session: session, retryPolicy: retryPolicy + ) + requestContext?.onResponse?(httpResponse) + return try anthropic.parseResponse(data) + } + + public func stream( + messages: [ChatMessage], + tools: [ToolDefinition], + requestContext: RequestContext? + ) -> AsyncThrowingStream { + AsyncThrowingStream { continuation in + let task = Task { + do { + try await performStreamRequest( + messages: messages, + tools: tools, + extraFields: requestContext?.extraFields ?? [:], + onResponse: requestContext?.onResponse, + continuation: continuation + ) + } catch { + continuation.finish(throwing: error) + } + } + continuation.onTermination = { _ in task.cancel() } + } + } + + // MARK: - Streaming + + private func performStreamRequest( + messages: [ChatMessage], + tools: [ToolDefinition], + extraFields: [String: JSONValue], + onResponse: (@Sendable (HTTPURLResponse) -> Void)?, + continuation: AsyncThrowingStream.Continuation + ) async throws { + let request = try anthropic.buildRequest( + messages: messages, tools: tools, + stream: true, extraFields: extraFields + ) + let token = try await tokenProvider() + let urlRequest = try buildVertexURLRequest( + VertexAnthropicRequest(inner: request), stream: true, token: token + ) + let (bytes, httpResponse) = try await HTTPRetry.performStream( + urlRequest: urlRequest, session: session, retryPolicy: retryPolicy + ) + onResponse?(httpResponse) + + let state = AnthropicStreamState() + + try await processSSEStream( + bytes: bytes, + stallTimeout: retryPolicy.streamStallTimeout + ) { line in + try await anthropic.handleSSELine( + line, state: state, continuation: continuation + ) + } + continuation.finish() + } + + // MARK: - URL Construction + + func buildVertexURLRequest( + _ request: VertexAnthropicRequest, + stream: Bool, + token: String + ) throws -> URLRequest { + let action = stream ? "streamRawPredict" : "rawPredict" + let basePath = "v1/projects/\(projectID)/locations/\(location)" + + "/publishers/anthropic/models/\(model):\(action)" + let baseURL = URL(string: "https://\(location)-aiplatform.googleapis.com")! + let url = baseURL.appendingPathComponent(basePath) + + let headers = ["Authorization": "Bearer \(token)"] + return try buildJSONPostRequest(url: url, body: request, headers: headers) + } +} + +// MARK: - Vertex Anthropic Request Wrapper + +/// Wraps an ``AnthropicRequest`` and injects `"anthropic_version": "vertex-2023-10-16"` +/// into the encoded JSON body for Vertex AI compatibility. +struct VertexAnthropicRequest: Encodable { + static let vertexAnthropicVersion = "vertex-2023-10-16" + + let inner: AnthropicRequest + + func encode(to encoder: any Encoder) throws { + try inner.encode(to: encoder) + var container = encoder.container(keyedBy: DynamicCodingKey.self) + try container.encode( + Self.vertexAnthropicVersion, + forKey: DynamicCodingKey("anthropic_version") + ) + } +} diff --git a/Sources/AgentRunKit/LLM/VertexGoogleClient.swift b/Sources/AgentRunKit/LLM/VertexGoogleClient.swift new file mode 100644 index 0000000..d57b2dc --- /dev/null +++ b/Sources/AgentRunKit/LLM/VertexGoogleClient.swift @@ -0,0 +1,190 @@ +import Foundation + +/// An LLM client for Google Gemini models served via Vertex AI. +/// +/// Uses OAuth2 Bearer token authentication (via ``GoogleAuthService`` or a custom +/// token provider closure) instead of API key authentication. +/// +/// The wire format is identical to the Gemini API — this client delegates request +/// building, response parsing, and SSE handling to an internal ``GeminiClient``. +/// +/// ```swift +/// let auth = try GoogleAuthService() +/// let client = VertexGoogleClient( +/// projectID: "my-project", +/// location: "us-central1", +/// model: "gemini-2.5-pro", +/// authService: auth +/// ) +/// ``` +public struct VertexGoogleClient: LLMClient, Sendable { + public let contextWindowSize: Int? + + let gemini: GeminiClient + private let projectID: String + private let location: String + private let model: String + private let apiVersion: String + private let tokenProvider: @Sendable () async throws -> String + private let session: URLSession + private let retryPolicy: RetryPolicy + + public init( + projectID: String, + location: String, + model: String, + tokenProvider: @Sendable @escaping () async throws -> String, + maxOutputTokens: Int = 8192, + contextWindowSize: Int? = nil, + apiVersion: String = "v1beta1", + session: URLSession = .shared, + retryPolicy: RetryPolicy = .default, + reasoningConfig: ReasoningConfig? = nil + ) { + self.projectID = projectID + self.location = location + self.model = model + self.apiVersion = apiVersion + self.tokenProvider = tokenProvider + self.session = session + self.retryPolicy = retryPolicy + self.contextWindowSize = contextWindowSize + gemini = GeminiClient( + apiKey: "", + model: model, + maxOutputTokens: maxOutputTokens, + contextWindowSize: contextWindowSize, + session: session, + retryPolicy: retryPolicy, + reasoningConfig: reasoningConfig + ) + } + + /// Convenience initializer that uses a ``GoogleAuthService`` for authentication. + public init( + projectID: String, + location: String, + model: String, + authService: GoogleAuthService, + maxOutputTokens: Int = 8192, + contextWindowSize: Int? = nil, + apiVersion: String = "v1beta1", + session: URLSession = .shared, + retryPolicy: RetryPolicy = .default, + reasoningConfig: ReasoningConfig? = nil + ) { + self.init( + projectID: projectID, + location: location, + model: model, + tokenProvider: { try await authService.accessToken() }, + maxOutputTokens: maxOutputTokens, + contextWindowSize: contextWindowSize, + apiVersion: apiVersion, + session: session, + retryPolicy: retryPolicy, + reasoningConfig: reasoningConfig + ) + } + + // MARK: - LLMClient + + public func generate( + messages: [ChatMessage], + tools: [ToolDefinition], + responseFormat: ResponseFormat?, + requestContext: RequestContext? + ) async throws -> AssistantMessage { + let request = try gemini.buildRequest( + messages: messages, + tools: tools, + responseFormat: responseFormat, + extraFields: requestContext?.extraFields ?? [:] + ) + let token = try await tokenProvider() + let urlRequest = try buildVertexURLRequest(request, stream: false, token: token) + let (data, httpResponse) = try await HTTPRetry.performData( + urlRequest: urlRequest, session: session, retryPolicy: retryPolicy + ) + requestContext?.onResponse?(httpResponse) + return try gemini.parseResponse(data) + } + + public func stream( + messages: [ChatMessage], + tools: [ToolDefinition], + requestContext: RequestContext? + ) -> AsyncThrowingStream { + AsyncThrowingStream { continuation in + let task = Task { + do { + try await performStreamRequest( + messages: messages, + tools: tools, + extraFields: requestContext?.extraFields ?? [:], + onResponse: requestContext?.onResponse, + continuation: continuation + ) + } catch { + continuation.finish(throwing: error) + } + } + continuation.onTermination = { _ in task.cancel() } + } + } + + // MARK: - Streaming + + private func performStreamRequest( + messages: [ChatMessage], + tools: [ToolDefinition], + extraFields: [String: JSONValue], + onResponse: (@Sendable (HTTPURLResponse) -> Void)?, + continuation: AsyncThrowingStream.Continuation + ) async throws { + let request = try gemini.buildRequest( + messages: messages, tools: tools, extraFields: extraFields + ) + let token = try await tokenProvider() + let urlRequest = try buildVertexURLRequest(request, stream: true, token: token) + let (bytes, httpResponse) = try await HTTPRetry.performStream( + urlRequest: urlRequest, session: session, retryPolicy: retryPolicy + ) + onResponse?(httpResponse) + + let state = GeminiStreamState() + + try await processSSEStream( + bytes: bytes, + stallTimeout: retryPolicy.streamStallTimeout + ) { line in + try await gemini.handleSSELine( + line, state: state, continuation: continuation + ) + } + continuation.finish() + } + + // MARK: - URL Construction + + func buildVertexURLRequest( + _ request: GeminiRequest, + stream: Bool, + token: String + ) throws -> URLRequest { + let action = stream ? "streamGenerateContent" : "generateContent" + let basePath = "\(apiVersion)/projects/\(projectID)/locations/\(location)" + + "/publishers/google/models/\(model):\(action)" + let baseURL = URL(string: "https://\(location)-aiplatform.googleapis.com")! + var url = baseURL.appendingPathComponent(basePath) + + if stream { + var components = URLComponents(url: url, resolvingAgainstBaseURL: false)! + components.queryItems = [URLQueryItem(name: "alt", value: "sse")] + url = components.url! + } + + let headers = ["Authorization": "Bearer \(token)"] + return try buildJSONPostRequest(url: url, body: request, headers: headers) + } +} diff --git a/Tests/AgentRunKitTests/GoogleAuthServiceTests.swift b/Tests/AgentRunKitTests/GoogleAuthServiceTests.swift new file mode 100644 index 0000000..4fd6758 --- /dev/null +++ b/Tests/AgentRunKitTests/GoogleAuthServiceTests.swift @@ -0,0 +1,125 @@ +@testable import AgentRunKit +import Foundation +import Testing + +struct GoogleAuthServiceTests { + @Test + func defaultCredentialsPathFormat() { + let path = GoogleAuthService.defaultCredentialsPath() + #expect(path.hasSuffix("/.config/gcloud/application_default_credentials.json")) + #expect(path.hasPrefix("/")) + } + + @Test + func initWithMissingFileThrows() { + #expect(throws: GoogleAuthService.GoogleAuthError.self) { + _ = try GoogleAuthService(credentialsPath: "/nonexistent/path/credentials.json") + } + } + + @Test + func initWithMissingFileThrowsCorrectError() { + do { + _ = try GoogleAuthService(credentialsPath: "/tmp/does_not_exist_adc.json") + Issue.record("Expected error") + } catch let error as GoogleAuthService.GoogleAuthError { + if case let .credentialsFileNotFound(path) = error { + #expect(path == "/tmp/does_not_exist_adc.json") + } else { + Issue.record("Expected credentialsFileNotFound, got \(error)") + } + } catch { + Issue.record("Expected GoogleAuthError, got \(error)") + } + } + + @Test + func initWithInvalidJSONThrows() throws { + let tempDir = FileManager.default.temporaryDirectory + let tempFile = tempDir.appendingPathComponent("invalid_adc_\(UUID().uuidString).json") + try Data("not json".utf8).write(to: tempFile) + defer { try? FileManager.default.removeItem(at: tempFile) } + + do { + _ = try GoogleAuthService(credentialsPath: tempFile.path) + Issue.record("Expected error") + } catch let error as GoogleAuthService.GoogleAuthError { + guard case .decodingFailed = error else { + Issue.record("Expected decodingFailed, got \(error)") + return + } + } catch { + Issue.record("Expected GoogleAuthError, got \(error)") + } + } + + @Test + func initWithUnsupportedTypeThrows() throws { + let tempDir = FileManager.default.temporaryDirectory + let tempFile = tempDir.appendingPathComponent("sa_adc_\(UUID().uuidString).json") + let json = """ + { + "type": "service_account", + "client_id": "123", + "client_secret": "secret", + "refresh_token": "token" + } + """ + try Data(json.utf8).write(to: tempFile) + defer { try? FileManager.default.removeItem(at: tempFile) } + + do { + _ = try GoogleAuthService(credentialsPath: tempFile.path) + Issue.record("Expected error") + } catch let error as GoogleAuthService.GoogleAuthError { + guard case let .unsupportedCredentialType(type) = error else { + Issue.record("Expected unsupportedCredentialType, got \(error)") + return + } + #expect(type == "service_account") + } catch { + Issue.record("Expected GoogleAuthError, got \(error)") + } + } + + @Test + func initWithValidCredentialsSucceeds() throws { + let tempDir = FileManager.default.temporaryDirectory + let tempFile = tempDir.appendingPathComponent("valid_adc_\(UUID().uuidString).json") + let json = """ + { + "type": "authorized_user", + "client_id": "test-client-id", + "client_secret": "test-client-secret", + "refresh_token": "test-refresh-token" + } + """ + try Data(json.utf8).write(to: tempFile) + defer { try? FileManager.default.removeItem(at: tempFile) } + + // Should not throw + _ = try GoogleAuthService(credentialsPath: tempFile.path) + } + + @Test + func credentialsAvailableReturnsConsistentResult() { + let available = GoogleAuthService.credentialsAvailable() + let path = GoogleAuthService.defaultCredentialsPath() + let fileExists = FileManager.default.fileExists(atPath: path) + #expect(available == fileExists) + } + + @Test + func errorDescriptionsAreNonEmpty() throws { + let errors: [GoogleAuthService.GoogleAuthError] = [ + .credentialsFileNotFound(path: "/test"), + .unsupportedCredentialType("service_account"), + .refreshFailed(statusCode: 401, body: "Unauthorized"), + .decodingFailed("test error") + ] + for error in errors { + let description = try #require(error.errorDescription) + #expect(!description.isEmpty) + } + } +} diff --git a/Tests/AgentRunKitTests/VertexAnthropicClientTests.swift b/Tests/AgentRunKitTests/VertexAnthropicClientTests.swift new file mode 100644 index 0000000..2091d49 --- /dev/null +++ b/Tests/AgentRunKitTests/VertexAnthropicClientTests.swift @@ -0,0 +1,211 @@ +@testable import AgentRunKit +import Foundation +import Testing + +struct VertexAnthropicURLTests { + private func makeClient( + projectID: String = "test-project", + location: String = "us-east5", + model: String = "claude-sonnet-4-6" + ) -> VertexAnthropicClient { + VertexAnthropicClient( + projectID: projectID, + location: location, + model: model, + tokenProvider: { "test-token-123" } + ) + } + + @Test + func vertexURLHasCorrectPath() throws { + let client = makeClient() + let request = try client.anthropic.buildRequest(messages: [.user("Hi")], tools: []) + let wrapped = VertexAnthropicRequest(inner: request) + let urlRequest = try client.buildVertexURLRequest(wrapped, stream: false, token: "tok") + + let url = try #require(urlRequest.url) + #expect(url.absoluteString.contains("/projects/test-project/")) + #expect(url.absoluteString.contains("/locations/us-east5/")) + #expect(url.absoluteString.contains("/publishers/anthropic/models/claude-sonnet-4-6:rawPredict")) + #expect(url.host == "us-east5-aiplatform.googleapis.com") + } + + @Test + func vertexStreamURLUsesStreamRawPredict() throws { + let client = makeClient() + let request = try client.anthropic.buildRequest( + messages: [.user("Hi")], tools: [], stream: true + ) + let wrapped = VertexAnthropicRequest(inner: request) + let urlRequest = try client.buildVertexURLRequest(wrapped, stream: true, token: "tok") + + #expect(urlRequest.url?.absoluteString.contains(":streamRawPredict") == true) + } + + @Test + func bearerTokenInAuthHeader() throws { + let client = makeClient() + let request = try client.anthropic.buildRequest(messages: [.user("Hi")], tools: []) + let wrapped = VertexAnthropicRequest(inner: request) + let urlRequest = try client.buildVertexURLRequest(wrapped, stream: false, token: "my-oauth-token") + + #expect(urlRequest.value(forHTTPHeaderField: "Authorization") == "Bearer my-oauth-token") + } + + @Test + func noApiKeyHeader() throws { + let client = makeClient() + let request = try client.anthropic.buildRequest(messages: [.user("Hi")], tools: []) + let wrapped = VertexAnthropicRequest(inner: request) + let urlRequest = try client.buildVertexURLRequest(wrapped, stream: false, token: "tok") + + #expect(urlRequest.value(forHTTPHeaderField: "x-api-key") == nil) + #expect(urlRequest.value(forHTTPHeaderField: "anthropic-version") == nil) + } + + @Test + func httpMethodIsPost() throws { + let client = makeClient() + let request = try client.anthropic.buildRequest(messages: [.user("Hi")], tools: []) + let wrapped = VertexAnthropicRequest(inner: request) + let urlRequest = try client.buildVertexURLRequest(wrapped, stream: false, token: "tok") + + #expect(urlRequest.httpMethod == "POST") + #expect(urlRequest.value(forHTTPHeaderField: "Content-Type") == "application/json") + } + + @Test + func differentLocationsChangeHost() throws { + let client = makeClient(location: "europe-west4") + let request = try client.anthropic.buildRequest(messages: [.user("Hi")], tools: []) + let wrapped = VertexAnthropicRequest(inner: request) + let urlRequest = try client.buildVertexURLRequest(wrapped, stream: false, token: "tok") + + #expect(urlRequest.url?.host == "europe-west4-aiplatform.googleapis.com") + #expect(urlRequest.url?.absoluteString.contains("/locations/europe-west4/") == true) + } +} + +struct VertexAnthropicRequestTests { + @Test + func requestBodyContainsAnthropicVersion() throws { + let client = VertexAnthropicClient( + projectID: "p", location: "l", model: "m", + tokenProvider: { "tok" } + ) + let request = try client.anthropic.buildRequest(messages: [.user("Hi")], tools: []) + let wrapped = VertexAnthropicRequest(inner: request) + let data = try JSONEncoder().encode(wrapped) + let json = try #require(JSONSerialization.jsonObject(with: data) as? [String: Any]) + + #expect(json["anthropic_version"] as? String == "vertex-2023-10-16") + } + + @Test + func requestBodyPreservesAnthropicFields() throws { + let client = VertexAnthropicClient( + projectID: "p", location: "l", model: "claude-sonnet-4-6", + tokenProvider: { "tok" }, + maxTokens: 4096 + ) + let tools = [ + ToolDefinition( + name: "search", description: "Search", + parametersSchema: .object(properties: ["q": .string()], required: ["q"]) + ) + ] + let request = try client.anthropic.buildRequest( + messages: [.system("Be helpful"), .user("Hello")], tools: tools + ) + let wrapped = VertexAnthropicRequest(inner: request) + let data = try JSONEncoder().encode(wrapped) + let json = try #require(JSONSerialization.jsonObject(with: data) as? [String: Any]) + + #expect(json["max_tokens"] as? Int == 4096) + #expect(json["model"] as? String == "claude-sonnet-4-6") + + let messages = json["messages"] as? [[String: Any]] + #expect(messages?.count == 1) + #expect(messages?[0]["role"] as? String == "user") + + let system = json["system"] as? [[String: Any]] + #expect(system?.count == 1) + #expect(system?[0]["text"] as? String == "Be helpful") + + let jsonTools = json["tools"] as? [[String: Any]] + #expect(jsonTools?.count == 1) + #expect(jsonTools?[0]["name"] as? String == "search") + + #expect(json["anthropic_version"] as? String == "vertex-2023-10-16") + } + + @Test + func streamFieldEncodesInBody() throws { + let client = VertexAnthropicClient( + projectID: "p", location: "l", model: "m", + tokenProvider: { "tok" } + ) + let request = try client.anthropic.buildRequest( + messages: [.user("Hi")], tools: [], stream: true + ) + let wrapped = VertexAnthropicRequest(inner: request) + let data = try JSONEncoder().encode(wrapped) + let json = try #require(JSONSerialization.jsonObject(with: data) as? [String: Any]) + + #expect(json["stream"] as? Bool == true) + #expect(json["anthropic_version"] as? String == "vertex-2023-10-16") + } +} + +struct VertexAnthropicResponseTests { + @Test + func responseParsingDelegatedToAnthropic() throws { + let client = VertexAnthropicClient( + projectID: "p", location: "l", model: "m", + tokenProvider: { "tok" } + ) + let json = """ + { + "id": "msg_001", + "type": "message", + "role": "assistant", + "content": [{"type": "text", "text": "Hello from Vertex!"}], + "stop_reason": "end_turn", + "usage": {"input_tokens": 100, "output_tokens": 50} + } + """ + let msg = try client.anthropic.parseResponse(Data(json.utf8)) + #expect(msg.content == "Hello from Vertex!") + #expect(msg.tokenUsage?.input == 100) + #expect(msg.tokenUsage?.output == 50) + } + + @Test + func responseFormatThrows() async { + let client = VertexAnthropicClient( + projectID: "p", location: "l", model: "m", + tokenProvider: { "tok" } + ) + let format = ResponseFormat.jsonSchema(TestVertexAnthropicOutput.self) + await #expect(throws: AgentError.self) { + _ = try await client.generate( + messages: [.user("Hi")], + tools: [], + responseFormat: format + ) + } + } +} + +struct VertexAnthropicRequestWrapperTests { + @Test + func vertexAnthropicVersionConstant() { + #expect(VertexAnthropicRequest.vertexAnthropicVersion == "vertex-2023-10-16") + } +} + +private enum TestVertexAnthropicOutput: SchemaProviding { + static var jsonSchema: JSONSchema { + .object(properties: ["value": .string()], required: ["value"]) + } +} diff --git a/Tests/AgentRunKitTests/VertexGoogleClientTests.swift b/Tests/AgentRunKitTests/VertexGoogleClientTests.swift new file mode 100644 index 0000000..fc1abff --- /dev/null +++ b/Tests/AgentRunKitTests/VertexGoogleClientTests.swift @@ -0,0 +1,149 @@ +@testable import AgentRunKit +import Foundation +import Testing + +struct VertexGoogleURLTests { + private func makeClient( + projectID: String = "test-project", + location: String = "us-central1", + model: String = "gemini-2.5-pro", + apiVersion: String = "v1beta1", + reasoningConfig: ReasoningConfig? = nil + ) -> VertexGoogleClient { + VertexGoogleClient( + projectID: projectID, + location: location, + model: model, + tokenProvider: { "test-token-123" }, + apiVersion: apiVersion, + reasoningConfig: reasoningConfig + ) + } + + @Test + func vertexURLHasCorrectPath() throws { + let client = makeClient() + let request = try client.gemini.buildRequest(messages: [.user("Hi")], tools: []) + let urlRequest = try client.buildVertexURLRequest(request, stream: false, token: "tok") + + let url = try #require(urlRequest.url) + #expect(url.absoluteString.contains("/projects/test-project/")) + #expect(url.absoluteString.contains("/locations/us-central1/")) + #expect(url.absoluteString.contains("/publishers/google/models/gemini-2.5-pro:generateContent")) + #expect(url.host == "us-central1-aiplatform.googleapis.com") + } + + @Test + func vertexStreamURLHasStreamAction() throws { + let client = makeClient() + let request = try client.gemini.buildRequest(messages: [.user("Hi")], tools: []) + let urlRequest = try client.buildVertexURLRequest(request, stream: true, token: "tok") + + let url = try #require(urlRequest.url) + #expect(url.absoluteString.contains(":streamGenerateContent")) + #expect(url.query?.contains("alt=sse") == true) + } + + @Test + func vertexURLUsesCorrectApiVersion() throws { + let client = makeClient(apiVersion: "v1") + let request = try client.gemini.buildRequest(messages: [.user("Hi")], tools: []) + let urlRequest = try client.buildVertexURLRequest(request, stream: false, token: "tok") + + #expect(urlRequest.url?.absoluteString.contains("/v1/projects/") == true) + } + + @Test + func noApiKeyInQueryParams() throws { + let client = makeClient() + let request = try client.gemini.buildRequest(messages: [.user("Hi")], tools: []) + let urlRequest = try client.buildVertexURLRequest(request, stream: false, token: "tok") + + #expect(urlRequest.url?.query?.contains("key=") != true) + } + + @Test + func bearerTokenInAuthHeader() throws { + let client = makeClient() + let request = try client.gemini.buildRequest(messages: [.user("Hi")], tools: []) + let urlRequest = try client.buildVertexURLRequest(request, stream: false, token: "my-oauth-token") + + #expect(urlRequest.value(forHTTPHeaderField: "Authorization") == "Bearer my-oauth-token") + } + + @Test + func requestBodyHasContents() throws { + let client = makeClient() + let request = try client.gemini.buildRequest(messages: [.user("Hello")], tools: []) + let urlRequest = try client.buildVertexURLRequest(request, stream: false, token: "tok") + + let body = try #require(urlRequest.httpBody) + let json = try #require(JSONSerialization.jsonObject(with: body) as? [String: Any]) + + let contents = json["contents"] as? [[String: Any]] + #expect(contents?.count == 1) + let parts = contents?[0]["parts"] as? [[String: Any]] + #expect(parts?[0]["text"] as? String == "Hello") + } + + @Test + func differentLocationsChangeHost() throws { + let client = makeClient(location: "europe-west1") + let request = try client.gemini.buildRequest(messages: [.user("Hi")], tools: []) + let urlRequest = try client.buildVertexURLRequest(request, stream: false, token: "tok") + + #expect(urlRequest.url?.host == "europe-west1-aiplatform.googleapis.com") + #expect(urlRequest.url?.absoluteString.contains("/locations/europe-west1/") == true) + } + + @Test + func httpMethodIsPost() throws { + let client = makeClient() + let request = try client.gemini.buildRequest(messages: [.user("Hi")], tools: []) + let urlRequest = try client.buildVertexURLRequest(request, stream: false, token: "tok") + + #expect(urlRequest.httpMethod == "POST") + #expect(urlRequest.value(forHTTPHeaderField: "Content-Type") == "application/json") + } +} + +struct VertexGoogleResponseTests { + @Test + func responseParsingDelegatedToGemini() throws { + let client = VertexGoogleClient( + projectID: "p", location: "l", model: "m", + tokenProvider: { "tok" } + ) + let json = """ + { + "candidates": [{ + "content": { + "role": "model", + "parts": [{"text": "Hello from Vertex!"}] + }, + "finishReason": "STOP" + }], + "usageMetadata": { + "promptTokenCount": 10, + "candidatesTokenCount": 5 + } + } + """ + let msg = try client.gemini.parseResponse(Data(json.utf8)) + #expect(msg.content == "Hello from Vertex!") + #expect(msg.tokenUsage?.input == 10) + #expect(msg.tokenUsage?.output == 5) + } + + @Test + func thinkingConfigPassedThrough() { + let client = VertexGoogleClient( + projectID: "p", location: "l", model: "m", + tokenProvider: { "tok" }, + reasoningConfig: .high + ) + let config = client.gemini.buildThinkingConfig() + #expect(config?.thinkingLevel == "HIGH") + #expect(config?.includeThoughts == true) + } +}