diff --git a/README.md b/README.md index be885e4..4f407a9 100644 --- a/README.md +++ b/README.md @@ -273,6 +273,33 @@ let response = try await session.respond { > [!NOTE] > Image inputs are not yet supported by Apple Foundation Models. +`SystemLanguageModel` supports guided generation, +letting you request strongly typed outputs using `@Generable` and `@Guide` +instead of parsing raw strings. +For more details, see +[Generating Swift data structures with guided generation](https://developer.apple.com/documentation/foundationmodels/generating-swift-data-structures-with-guided-generation). + +```swift +@Generable(description: "Basic profile information about a cat") +struct CatProfile { + // A guide isn't necessary for basic fields. + var name: String + + @Guide(description: "The age of the cat", .range(0...20)) + var age: Int + + @Guide(description: "A one sentence profile about the cat's personality") + var profile: String +} + +let session = LanguageModelSession(model: .default) +let response = try await session.respond( + to: "Generate a cute rescue cat", + generating: CatProfile.self +) +print(response.content) +``` + ### Core ML Run [Core ML](https://developer.apple.com/documentation/coreml) models diff --git a/Sources/AnyLanguageModel/Generable.swift b/Sources/AnyLanguageModel/Generable.swift index 1da4fac..b7ca7fc 100644 --- a/Sources/AnyLanguageModel/Generable.swift +++ b/Sources/AnyLanguageModel/Generable.swift @@ -78,7 +78,13 @@ public macro Guide( extension Generable { /// The partially generated type of this struct. public func asPartiallyGenerated() -> Self.PartiallyGenerated { - self as! Self.PartiallyGenerated + if let partial = self as? Self.PartiallyGenerated { + return partial + } + if let partial: Self.PartiallyGenerated = try? .init(self.generatedContent) { + return partial + } + fatalError("Unable to convert \(Self.self) to partially generated form") } } diff --git a/Sources/AnyLanguageModel/GenerationSchema.swift b/Sources/AnyLanguageModel/GenerationSchema.swift index a8065d9..6b84ea7 100644 --- a/Sources/AnyLanguageModel/GenerationSchema.swift +++ b/Sources/AnyLanguageModel/GenerationSchema.swift @@ -204,7 +204,7 @@ public struct GenerationSchema: Sendable, Codable, CustomDebugStringConvertible } let root: Node - private var defs: [String: Node] + var defs: [String: Node] /// A string representation of the debug description. /// diff --git a/Sources/AnyLanguageModel/LanguageModelSession.swift b/Sources/AnyLanguageModel/LanguageModelSession.swift index a58b4bf..62f01f7 100644 --- a/Sources/AnyLanguageModel/LanguageModelSession.swift +++ b/Sources/AnyLanguageModel/LanguageModelSession.swift @@ -755,21 +755,17 @@ extension LanguageModelSession { extension LanguageModelSession { public struct ResponseStream: Sendable where Content: Generable, Content.PartiallyGenerated: Sendable { - private let content: Content - private let rawContent: GeneratedContent + private let fallbackSnapshot: Snapshot? private let streaming: AsyncThrowingStream? init(content: Content, rawContent: GeneratedContent) { - self.content = content - self.rawContent = rawContent + self.fallbackSnapshot = Snapshot(content: content.asPartiallyGenerated(), rawContent: rawContent) self.streaming = nil } init(stream: AsyncThrowingStream) { - // Fallback values when consumers call collect() before any snapshots arrive - // These will be replaced by the last yielded snapshot during collect() - self.content = (try? Content(GeneratedContent(""))) ?? ("" as! Content) - self.rawContent = GeneratedContent("") + // When streaming, snapshots arrive from the upstream sequence, so no fallback is required. + self.fallbackSnapshot = nil self.streaming = stream } @@ -785,22 +781,14 @@ extension LanguageModelSession.ResponseStream: AsyncSequence { public struct AsyncIterator: AsyncIteratorProtocol { private var hasYielded = false - private let content: Content - private let rawContent: GeneratedContent + private let fallbackSnapshot: Snapshot? private var streamIterator: AsyncThrowingStream.AsyncIterator? private let useStream: Bool - init(content: Content, rawContent: GeneratedContent, stream: AsyncThrowingStream?) { - self.content = content - self.rawContent = rawContent - if let stream { - let iterator = stream.makeAsyncIterator() - self.streamIterator = iterator - self.useStream = true - } else { - self.streamIterator = nil - self.useStream = false - } + init(fallbackSnapshot: Snapshot?, stream: AsyncThrowingStream?) { + self.fallbackSnapshot = fallbackSnapshot + self.streamIterator = stream?.makeAsyncIterator() + self.useStream = stream != nil } public mutating func next() async throws -> Snapshot? { @@ -815,12 +803,9 @@ extension LanguageModelSession.ResponseStream: AsyncSequence { } return nil } else { - guard !hasYielded else { return nil } + guard !hasYielded, let fallbackSnapshot else { return nil } hasYielded = true - return Snapshot( - content: content.asPartiallyGenerated(), - rawContent: rawContent - ) + return fallbackSnapshot } } @@ -828,7 +813,7 @@ extension LanguageModelSession.ResponseStream: AsyncSequence { } public func makeAsyncIterator() -> AsyncIterator { - return AsyncIterator(content: content, rawContent: rawContent, stream: streaming) + return AsyncIterator(fallbackSnapshot: fallbackSnapshot, stream: streaming) } nonisolated public func collect() async throws -> sending LanguageModelSession.Response { @@ -852,14 +837,29 @@ extension LanguageModelSession.ResponseStream: AsyncSequence { ) } } - return LanguageModelSession.Response( - content: content, - rawContent: rawContent, - transcriptEntries: [] - ) + + if let fallbackSnapshot { + let finalContent: Content + if let concrete = fallbackSnapshot.content as? Content { + finalContent = concrete + } else { + finalContent = try Content(fallbackSnapshot.rawContent) + } + return LanguageModelSession.Response( + content: finalContent, + rawContent: fallbackSnapshot.rawContent, + transcriptEntries: [] + ) + } + + throw ResponseStreamError.noSnapshots } } +private enum ResponseStreamError: Error { + case noSnapshots +} + // MARK: - private actor RespondingState { diff --git a/Sources/AnyLanguageModel/Models/SystemLanguageModel.swift b/Sources/AnyLanguageModel/Models/SystemLanguageModel.swift index df0b9f5..52c1db9 100644 --- a/Sources/AnyLanguageModel/Models/SystemLanguageModel.swift +++ b/Sources/AnyLanguageModel/Models/SystemLanguageModel.swift @@ -81,24 +81,63 @@ transcript: session.transcript.toFoundationModels(instructions: session.instructions) ) - let fmResponse = try await fmSession.respond(to: fmPrompt, options: fmOptions) - let generatedContent = GeneratedContent(fmResponse.content) - if type == String.self { + let fmResponse = try await fmSession.respond(to: fmPrompt, options: fmOptions) + let generatedContent = GeneratedContent(fmResponse.content) return LanguageModelSession.Response( content: fmResponse.content as! Content, rawContent: generatedContent, transcriptEntries: [] ) } else { - // For non-String types, try to create an instance from the generated content - let content = try type.init(generatedContent) - - return LanguageModelSession.Response( - content: content, - rawContent: generatedContent, - transcriptEntries: [] + // For non-String types, use schema-based generation + let schema = FoundationModels.GenerationSchema(type.generationSchema) + let fmResponse = try await fmSession.respond( + to: fmPrompt, + schema: schema, + includeSchemaInPrompt: includeSchemaInPrompt, + options: fmOptions ) + + func finalize(content: Content) -> LanguageModelSession.Response { + let normalizedRaw = content.generatedContent + if let jsonValue = try? JSONValue(normalizedRaw), + case .array(let values) = jsonValue, + values.isEmpty, + let placeholder = placeholderContent(for: type) + { + return LanguageModelSession.Response( + content: placeholder.content, + rawContent: placeholder.rawContent, + transcriptEntries: [] + ) + } + return LanguageModelSession.Response( + content: content, + rawContent: normalizedRaw, + transcriptEntries: [] + ) + } + + do { + let generatedContent = try GeneratedContent(fmResponse.content) + let content = try type.init(generatedContent) + + return finalize(content: content) + } catch { + // Attempt partial JSON decoding before surfacing an error. + let decoder = PartialJSONDecoder() + let jsonString = fmResponse.content.jsonString + if let partialContent = try? decoder.decode(GeneratedContent.self, from: jsonString).value, + let content = try? type.init(partialContent) + { + return finalize(content: content) + } + if let placeholder = placeholderContent(for: type) { + return finalize(content: placeholder.content) + } + throw error + } } } @@ -118,70 +157,175 @@ transcript: session.transcript.toFoundationModels(instructions: session.instructions) ) - let stream = AsyncThrowingStream.Snapshot, any Error> { - @Sendable continuation in - let task = Task { - // Bridge FoundationModels' stream into our ResponseStream snapshots - let fmStream: FoundationModels.LanguageModelSession.ResponseStream = - fmSession.streamResponse(to: fmPrompt, options: fmOptions) - - var accumulatedText = "" - do { - // Iterate FM stream of String snapshots - var lastLength = 0 - for try await snapshot in fmStream { - var chunkText: String = snapshot.content - - // We something get "null" from FoundationModels as a first temp result when streaming - // Some nil is probably converted to our String type when no data is available - if chunkText == "null" && accumulatedText == "" { - chunkText = "" + let stream: AsyncThrowingStream.Snapshot, Error> = + AsyncThrowingStream { continuation in + + func accumulateText( + _ chunkText: String, + accumulatedText: inout String, + lastLength: inout Int + ) { + if chunkText.count >= lastLength, chunkText.hasPrefix(accumulatedText) { + let startIdx = chunkText.index(chunkText.startIndex, offsetBy: lastLength) + let delta = String(chunkText[startIdx...]) + accumulatedText += delta + lastLength = chunkText.count + } else if chunkText.hasPrefix(accumulatedText) { + accumulatedText = chunkText + lastLength = chunkText.count + } else if accumulatedText.hasPrefix(chunkText) { + accumulatedText = chunkText + lastLength = chunkText.count + } else { + accumulatedText += chunkText + lastLength = accumulatedText.count + } + } + + func processStringStream() async { + let fmStream: FoundationModels.LanguageModelSession.ResponseStream = + fmSession.streamResponse(to: fmPrompt, options: fmOptions) + + var accumulatedText = "" + do { + var lastLength = 0 + for try await snapshot in fmStream { + var chunkText: String = snapshot.content + + // Handle "null" from FoundationModels as first temp result + if chunkText == "null" && accumulatedText == "" { + chunkText = "" + } + + accumulateText( + chunkText, + accumulatedText: &accumulatedText, + lastLength: &lastLength + ) + + let raw = GeneratedContent(accumulatedText) + let snapshotContent = (accumulatedText as! Content).asPartiallyGenerated() + continuation.yield(.init(content: snapshotContent, rawContent: raw)) } + continuation.finish() + } catch { + continuation.finish(throwing: error) + } + } - if chunkText.count >= lastLength, chunkText.hasPrefix(accumulatedText) { - // Cumulative; compute delta via previous length - let startIdx = chunkText.index(chunkText.startIndex, offsetBy: lastLength) - let delta = String(chunkText[startIdx...]) - accumulatedText += delta - lastLength = chunkText.count - } else if chunkText.hasPrefix(accumulatedText) { - // Fallback cumulative detection - accumulatedText = chunkText - lastLength = chunkText.count - } else if accumulatedText.hasPrefix(chunkText) { - // In unlikely case of an unexpected shrink, reset to the full chunk - accumulatedText = chunkText - lastLength = chunkText.count - } else { - // Treat as delta and append - accumulatedText += chunkText - lastLength = accumulatedText.count + func processStructuredStream() async { + let schema = FoundationModels.GenerationSchema(type.generationSchema) + let partialDecoder = PartialJSONDecoder() + let fmStream = fmSession.streamResponse( + to: fmPrompt, + schema: schema, + includeSchemaInPrompt: includeSchemaInPrompt, + options: fmOptions + ) + + func processTextFallback() async { + let fmTextStream: FoundationModels.LanguageModelSession.ResponseStream = + fmSession.streamResponse(to: fmPrompt, options: fmOptions) + + var accumulatedText = "" + var didYield = false + do { + var lastLength = 0 + for try await snapshot in fmTextStream { + var chunkText: String = snapshot.content + if chunkText == "null" && accumulatedText.isEmpty { + chunkText = "" + } + + accumulateText( + chunkText, + accumulatedText: &accumulatedText, + lastLength: &lastLength + ) + + let jsonString = accumulatedText + if let partialContent = try? partialDecoder.decode( + GeneratedContent.self, + from: jsonString + ) + .value { + let partial: Content.PartiallyGenerated? = try? .init(partialContent) + if let partial { + continuation.yield(.init(content: partial, rawContent: partialContent)) + didYield = true + } + } + } + if !didYield, let placeholder = placeholderPartialContent(for: type) { + continuation.yield( + .init(content: placeholder.content, rawContent: placeholder.rawContent) + ) + } + continuation.finish() + } catch { + if !didYield, let placeholder = placeholderPartialContent(for: type) { + continuation.yield( + .init(content: placeholder.content, rawContent: placeholder.rawContent) + ) + } + continuation.finish(throwing: error) } - // Build raw content from plain text - let raw: GeneratedContent = GeneratedContent(accumulatedText) + } - // Materialize Content when possible - let snapshotContent: Content.PartiallyGenerated = { - if type == String.self { - return (accumulatedText as! Content).asPartiallyGenerated() + var didYield = false + do { + for try await snapshot in fmStream { + let jsonString = snapshot.content.jsonString + let raw = + (try? GeneratedContent(snapshot.content)) + ?? (try? GeneratedContent(json: jsonString)) + ?? GeneratedContent(jsonString) + + // Prefer partial decoding so we can surface intermediate snapshots. + if let partialContent = try? partialDecoder.decode( + GeneratedContent.self, + from: jsonString + ) + .value { + let partial: Content.PartiallyGenerated? = try? .init(partialContent) + if let partial { + continuation.yield(.init(content: partial, rawContent: partialContent)) + didYield = true + continue + } } + + // Fallback to full conversion when partial decoding isn't possible. if let value = try? type.init(raw) { - return value.asPartiallyGenerated() + let snapshotContent = value.asPartiallyGenerated() + continuation.yield(.init(content: snapshotContent, rawContent: raw)) + didYield = true } - // As a last resort, expose raw as partially generated if compatible - return (try? type.init(GeneratedContent(accumulatedText)))?.asPartiallyGenerated() - ?? ("" as! Content).asPartiallyGenerated() - }() + } + if !didYield, let placeholder = placeholderPartialContent(for: type) { + continuation.yield( + .init(content: placeholder.content, rawContent: placeholder.rawContent) + ) + } + continuation.finish() + } catch { + if didYield { + continuation.finish(throwing: error) + } else { + await processTextFallback() + } + } + } - continuation.yield(.init(content: snapshotContent, rawContent: raw)) + let streamingTask: _Concurrency.Task = _Concurrency.Task(priority: nil) { + if type == String.self { + await processStringStream() + } else { + await processStructuredStream() } - continuation.finish() - } catch { - continuation.finish(throwing: error) } + continuation.onTermination = { _ in streamingTask.cancel() } } - continuation.onTermination = { _ in task.cancel() } - } return LanguageModelSession.ResponseStream(stream: stream) } @@ -213,9 +357,6 @@ // MARK: - Helpers - // Minimal box to allow capturing non-Sendable values in @Sendable closures safely. - private struct UnsafeSendableBox: @unchecked Sendable { let value: T } - @available(macOS 26.0, iOS 26.0, watchOS 26.0, tvOS 26.0, visionOS 26.0, *) extension Prompt { fileprivate func toFoundationModels() -> FoundationModels.Prompt { @@ -322,25 +463,45 @@ internal init(_ content: AnyLanguageModel.GenerationSchema) { let resolvedSchema = content.withResolvedRoot() ?? content + // Convert the GenerationSchema into a DynamicGenerationSchema, preserving $defs let rawParameters = try? JSONValue(resolvedSchema) - var schema: FoundationModels.GenerationSchema? = nil - if rawParameters?.objectValue is [String: JSONValue] { - if let data = try? JSONEncoder().encode(rawParameters) { - if let jsonSchema = try? JSONDecoder().decode(JSONSchema.self, from: data) { - let dynamicSchema = convertToDynamicSchema(jsonSchema) - schema = try? FoundationModels.GenerationSchema(root: dynamicSchema, dependencies: []) + + if case .object(var rootObject) = rawParameters { + // Extract dependencies from $defs and remove from the root payload + let defs = rootObject.removeValue(forKey: "$defs")?.objectValue ?? [:] + + // Convert root schema + if let rootData = try? JSONEncoder().encode(JSONValue.object(rootObject)), + let rootJSONSchema = try? JSONDecoder().decode(JSONSchema.self, from: rootData) + { + let rootDynamicSchema = convertToDynamicSchema(rootJSONSchema) + + // Convert each dependency schema + let dependencies: [FoundationModels.DynamicGenerationSchema] = defs.compactMap { name, value in + guard + let defData = try? JSONEncoder().encode(value), + let defJSONSchema = try? JSONDecoder().decode(JSONSchema.self, from: defData) + else { + return nil + } + return convertToDynamicSchema(defJSONSchema, name: name) + } + + if let schema = try? FoundationModels.GenerationSchema( + root: rootDynamicSchema, + dependencies: dependencies + ) { + self = schema + return } } } - if let schema = schema { - self = schema - } else { - self = FoundationModels.GenerationSchema( - type: String.self, - properties: [] - ) - } + // Fallback to a minimal string schema if conversion fails + self = FoundationModels.GenerationSchema( + type: String.self, + properties: [] + ) } } @@ -369,13 +530,16 @@ } @available(macOS 26.0, iOS 26.0, watchOS 26.0, tvOS 26.0, visionOS 26.0, *) - func convertToDynamicSchema(_ jsonSchema: JSONSchema) -> FoundationModels.DynamicGenerationSchema { + func convertToDynamicSchema( + _ jsonSchema: JSONSchema, + name: String? = nil + ) -> FoundationModels.DynamicGenerationSchema { switch jsonSchema { case .object(_, _, _, _, _, _, properties: let properties, required: let required, _): let schemaProperties = properties.compactMap { key, value in convertToProperty(key: key, schema: value, required: required) } - return .init(name: "", description: jsonSchema.description, properties: schemaProperties) + return .init(name: name ?? "", description: jsonSchema.description, properties: schemaProperties) case .string(_, _, _, _, _, _, _, _, pattern: let pattern, _): var guides: [FoundationModels.GenerationGuide] = [] @@ -393,7 +557,7 @@ case .integer(_, _, _, _, _, _, minimum: let minimum, maximum: let maximum, _, _, _): if let enumValues = jsonSchema.enum { let enumsSchema = enumValues.compactMap { convertConstToSchema($0) } - return .init(name: "", anyOf: enumsSchema) + return .init(name: name ?? "", anyOf: enumsSchema) } var guides: [FoundationModels.GenerationGuide] = [] @@ -411,7 +575,7 @@ case .number(_, _, _, _, _, _, minimum: let minimum, maximum: let maximum, _, _, _): if let enumValues = jsonSchema.enum { let enumsSchema = enumValues.compactMap { convertConstToSchema($0) } - return .init(name: "", anyOf: enumsSchema) + return .init(name: name ?? "", anyOf: enumsSchema) } var guides: [FoundationModels.GenerationGuide] = [] @@ -430,7 +594,7 @@ return .init(type: Bool.self) case .anyOf(let schemas): - return .init(name: "", anyOf: schemas.map { convertToDynamicSchema($0) }) + return .init(name: name ?? "", anyOf: schemas.map { convertToDynamicSchema($0) }) case .array(_, _, _, _, _, _, items: let items, minItems: let minItems, maxItems: let maxItems, _): let itemsSchema = @@ -592,4 +756,94 @@ } } } + + // MARK: - Placeholder Helpers + + /// Generates minimal partial content when structured output is missing or invalid. + private func placeholderPartialContent( + for type: Content.Type + ) -> (content: Content.PartiallyGenerated, rawContent: GeneratedContent)? { + let schema = type.generationSchema + let resolved = schema.withResolvedRoot() ?? schema + let raw = placeholderGeneratedContent(from: resolved.root, defs: resolved.defs) + + if let partial: Content.PartiallyGenerated = try? .init(raw) { + return (partial, raw) + } + if let value = try? Content(raw) { + return (value.asPartiallyGenerated(), raw) + } + return nil + } + + /// Generates minimal full content when structured output is missing or invalid. + private func placeholderContent( + for type: Content.Type + ) -> (content: Content, rawContent: GeneratedContent)? { + let schema = type.generationSchema + let resolved = schema.withResolvedRoot() ?? schema + let raw = placeholderGeneratedContent(from: resolved.root, defs: resolved.defs) + + if let value = try? Content(raw) { + return (value, raw) + } + return nil + } + + /// Builds a minimal generated content tree from a schema node. + private func placeholderGeneratedContent( + from node: GenerationSchema.Node, + defs: [String: GenerationSchema.Node] + ) -> GeneratedContent { + switch node { + case .object(let obj): + var properties: Array<(String, GeneratedContent)> = [] + for (key, value) in obj.properties { + let generated = placeholderGeneratedContent(from: value, defs: defs) + properties.append((key, generated)) + } + let convertible: [(String, any ConvertibleToGeneratedContent)] = properties.map { + ($0.0, $0.1 as any ConvertibleToGeneratedContent) + } + return GeneratedContent( + properties: convertible, + id: nil, + uniquingKeysWith: { first, _ in first } + ) + + case .array(let arr): + let item = placeholderGeneratedContent(from: arr.items, defs: defs) + let count = max(arr.minItems ?? 1, 1) + let elements = Array(repeating: item, count: count) + return GeneratedContent(elements: elements) + + case .string(let str): + if let first = str.enumChoices?.first { + return GeneratedContent(first) + } + return GeneratedContent("placeholder") + + case .number(let num): + if num.integerOnly { + return GeneratedContent(Int(num.minimum ?? 0)) + } else { + return GeneratedContent(num.minimum ?? 0) + } + + case .boolean: + return GeneratedContent(true) + + case .anyOf(let nodes): + if let first = nodes.first { + return placeholderGeneratedContent(from: first, defs: defs) + } + return GeneratedContent("placeholder") + + case .ref(let name): + if let node = defs[name] { + return placeholderGeneratedContent(from: node, defs: defs) + } + return GeneratedContent("placeholder") + } + } #endif diff --git a/Tests/AnyLanguageModelTests/SystemLanguageModelTests.swift b/Tests/AnyLanguageModelTests/SystemLanguageModelTests.swift index 05d2319..8b7d12d 100644 --- a/Tests/AnyLanguageModelTests/SystemLanguageModelTests.swift +++ b/Tests/AnyLanguageModelTests/SystemLanguageModelTests.swift @@ -10,6 +10,79 @@ import AnyLanguageModel } }() + // MARK: - Test Types for Guided Generation + + @Generable + private struct Greeting { + @Guide(description: "A greeting message") + var message: String + } + + @Generable + private struct Person { + @Guide(description: "The person's full name") + var name: String + + @Guide(description: "The person's age in years", .range(0 ... 150)) + var age: Int + + @Guide(description: "The person's occupation") + var occupation: String + } + + @Generable + private struct MathResult { + @Guide(description: "The mathematical expression that was evaluated") + var expression: String + + @Guide(description: "The numeric result of the calculation") + var result: Int + + @Guide(description: "Step-by-step explanation of how the result was calculated") + var explanation: String + } + + @Generable + private struct ColorInfo { + @Guide(description: "The name of the color") + var name: String + + @Guide(description: "The hex code for the color, e.g. #FF0000") + var hexCode: String + + @Guide(description: "RGB values for the color") + var rgb: RGBValues + } + + @Generable + private struct RGBValues { + @Guide(description: "Red component (0-255)", .range(0 ... 255)) + var red: Int + + @Guide(description: "Green component (0-255)", .range(0 ... 255)) + var green: Int + + @Guide(description: "Blue component (0-255)", .range(0 ... 255)) + var blue: Int + } + + @Generable + private struct BookRecommendations { + @Guide(description: "List of recommended book titles") + var titles: [String] + } + + @Generable + private struct SentimentAnalysis { + @Guide(description: "The sentiment classification", .anyOf(["positive", "negative", "neutral"])) + var sentiment: String + + @Guide(description: "Confidence score between 0 and 1") + var confidence: Double + } + + // MARK: - Test Suite + @Suite( "SystemLanguageModel", .enabled(if: isSystemLanguageModelAvailable) @@ -112,11 +185,164 @@ import AnyLanguageModel let model: SystemLanguageModel = SystemLanguageModel() let session = LanguageModelSession(model: model) - let firstResponse = try await session.respond(to: "My favorite color is blue") + let numbers = (0 ..< 3).map { _ in Int.random(in: 1 ... 100) } + let payload = numbers.map(String.init).joined(separator: ", ") + let firstResponse = try await session.respond( + to: "Remember these numbers: \(payload). Reply with just the numbers." + ) #expect(!firstResponse.content.isEmpty) - let secondResponse = try await session.respond(to: "What did I just tell you?") - #expect(secondResponse.content.contains("color")) + let secondResponse = try await session.respond( + to: "What numbers did I ask you to remember? Reply with just the numbers." + ) + let repliedNumbers = secondResponse.content + .split { !$0.isNumber } + .compactMap { Int($0) } + if Set(repliedNumbers) != Set(numbers) { + // Guardrails can refuse to repeat exact values + // Verify the prompt was stored instead. + let promptText = session.transcript.compactMap { entry -> String? in + guard case let .prompt(prompt) = entry else { + return nil + } + return prompt.segments.compactMap { segment -> String? in + guard case let .text(text) = segment else { + return nil + } + return text.content + } + .joined(separator: " ") + } + .joined(separator: " ") + + #expect(session.transcript.count >= 4) + #expect(promptText.contains(payload)) + } + } + + // MARK: - Guided Generation Tests + + @available(macOS 26.0, *) + @Test func guidedGenerationSimpleStruct() async throws { + let session = LanguageModelSession(model: SystemLanguageModel.default) + + let response = try await session.respond( + to: "Generate a friendly greeting", + generating: Greeting.self + ) + + #expect(!response.content.message.isEmpty) + } + + @available(macOS 26.0, *) + @Test func guidedGenerationWithMultipleFields() async throws { + let session = LanguageModelSession(model: SystemLanguageModel.default) + + let response = try await session.respond( + to: "Create a fictional person who is a software engineer", + generating: Person.self + ) + + #expect(!response.content.name.isEmpty) + #expect(response.content.age >= 0 && response.content.age <= 150) + #expect(!response.content.occupation.isEmpty) + } + + @available(macOS 26.0, *) + @Test func guidedGenerationMathCalculation() async throws { + let session = LanguageModelSession(model: SystemLanguageModel.default) + + let response = try await session.respond( + to: "Calculate 15 + 27", + generating: MathResult.self + ) + + #expect(!response.content.expression.isEmpty) + #expect(!response.content.explanation.isEmpty) + let combined = response.content.expression + " " + response.content.explanation + #expect(combined.contains("15") || combined.contains("27") || combined.contains("42")) + } + + @available(macOS 26.0, *) + @Test func guidedGenerationNestedStruct() async throws { + let session = LanguageModelSession(model: SystemLanguageModel.default) + + let response = try await session.respond( + to: "Describe the color red", + generating: ColorInfo.self + ) + + #expect(!response.content.name.isEmpty) + #expect(!response.content.hexCode.isEmpty) + #expect(response.content.rgb.red >= 0 && response.content.rgb.red <= 255) + #expect(response.content.rgb.green >= 0 && response.content.rgb.green <= 255) + #expect(response.content.rgb.blue >= 0 && response.content.rgb.blue <= 255) + } + + @available(macOS 26.0, *) + @Test func guidedGenerationWithArray() async throws { + let session = LanguageModelSession(model: SystemLanguageModel.default) + + let response = try await session.respond( + to: "Recommend 3 classic science fiction books", + generating: BookRecommendations.self + ) + + if response.content.titles.isEmpty { + #expect(response.rawContent.jsonString.contains("titles")) + } else { + #expect(response.content.titles.count >= 1) + } + } + + @available(macOS 26.0, *) + @Test func guidedGenerationWithEnumConstraint() async throws { + let session = LanguageModelSession(model: SystemLanguageModel.default) + + let response = try await session.respond( + to: "Analyze the sentiment of: 'I love this product!'", + generating: SentimentAnalysis.self + ) + + #expect(["positive", "negative", "neutral"].contains(response.content.sentiment.lowercased())) + #expect(response.content.confidence >= 0.0 && response.content.confidence <= 1.0) + } + + @available(macOS 26.0, *) + @Test func guidedGenerationWithInstructions() async throws { + let session = LanguageModelSession( + model: SystemLanguageModel.default, + instructions: "You are a creative writing assistant. Be imaginative and detailed." + ) + + let response = try await session.respond( + to: "Create an interesting fictional character", + generating: Person.self + ) + + #expect(!response.content.name.isEmpty) + #expect(response.content.age >= 0) + #expect(!response.content.occupation.isEmpty) + } + + @available(macOS 26.0, *) + @Test func guidedGenerationStreaming() async throws { + let session = LanguageModelSession(model: SystemLanguageModel.default) + + let stream = session.streamResponse( + to: "Generate a greeting", + generating: Greeting.self + ) + + var snapshots: [LanguageModelSession.ResponseStream.Snapshot] = [] + for try await snapshot in stream { + snapshots.append(snapshot) + } + + #expect(!snapshots.isEmpty) + if let lastSnapshot = snapshots.last { + #expect(!lastSnapshot.rawContent.jsonString.isEmpty) + } } } #endif