diff --git a/core/src/main/java/com/google/adk/models/chat/ChatCompletionsRequest.java b/core/src/main/java/com/google/adk/models/chat/ChatCompletionsRequest.java index 4b6747fb1..462af468f 100644 --- a/core/src/main/java/com/google/adk/models/chat/ChatCompletionsRequest.java +++ b/core/src/main/java/com/google/adk/models/chat/ChatCompletionsRequest.java @@ -21,18 +21,36 @@ import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonProperty; import com.fasterxml.jackson.annotation.JsonValue; +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.google.adk.JsonBaseModel; +import com.google.adk.models.LlmRequest; +import com.google.common.collect.ImmutableList; +import com.google.genai.types.Content; +import com.google.genai.types.FunctionDeclaration; +import com.google.genai.types.FunctionResponse; +import com.google.genai.types.GenerateContentConfig; +import com.google.genai.types.Part; +import java.util.ArrayList; +import java.util.Base64; import java.util.List; import java.util.Map; +import java.util.Objects; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; /** * Data Transfer Objects for Chat Completion API requests. * + *

Can be used to translate from a {@link LlmRequest} into a {@link ChatCompletionsRequest} using + * {@link #fromLlmRequest(LlmRequest, boolean)}. + * *

See * https://developers.openai.com/api/reference/resources/chat/subresources/completions/methods/create */ @JsonIgnoreProperties(ignoreUnknown = true) @JsonInclude(JsonInclude.Include.NON_NULL) -final class ChatCompletionsRequest { +public final class ChatCompletionsRequest { /** * See @@ -249,6 +267,314 @@ final class ChatCompletionsRequest { @JsonProperty("extra_body") public Map extraBody; + private static final Logger logger = LoggerFactory.getLogger(ChatCompletionsRequest.class); + private static final ObjectMapper objectMapper = JsonBaseModel.getMapper(); + + /** + * Converts a standard {@link LlmRequest} into a {@link ChatCompletionsRequest} for + * /chat/completions compatible endpoints. + * + * @param llmRequest The internal source request containing contents, configuration, and tool + * definitions. + * @param responseStreaming True if the request asks for a streaming response. + * @return A populated ChatCompletionsRequest ready for JSON serialization. + */ + public static ChatCompletionsRequest fromLlmRequest( + LlmRequest llmRequest, boolean responseStreaming) { + ChatCompletionsRequest request = new ChatCompletionsRequest(); + request.model = llmRequest.model().orElse(""); + request.stream = responseStreaming; + if (responseStreaming) { + StreamOptions options = new StreamOptions(); + options.includeUsage = true; + request.streamOptions = options; + } + + boolean isOSeries = request.model.matches("^o\\d+(?:-.*)?$"); + + List messages = new ArrayList<>(); + + llmRequest.config().ifPresent(config -> handleSystemInstruction(config, isOSeries, messages)); + + for (Content content : llmRequest.contents()) { + handleContent(content, messages); + } + + request.messages = ImmutableList.copyOf(messages); + + llmRequest + .config() + .ifPresent( + config -> { + handleConfigOptions(config, request); + handleTools(config, request); + }); + + return request; + } + + /** + * Updates the messages list based on the provided system instruction configuration. + * + * @param config The content generation configuration that may contain a system instruction. + * @param isOSeries True if the target model belongs to the OpenAI o-series (e.g., o1, o3), which + * requires the "developer" role instead of the standard "system" role. + * @param messages The list of messages to append the mapped instruction to. + */ + private static void handleSystemInstruction( + GenerateContentConfig config, boolean isOSeries, List messages) { + if (config.systemInstruction().isPresent()) { + Message systemMsg = new Message(); + systemMsg.role = isOSeries ? "developer" : "system"; + systemMsg.content = new MessageContent(config.systemInstruction().get().text()); + messages.add(systemMsg); + } + } + + /** + * Updates the messages list based on the provided content. + * + * @param content The incoming content containing parts to map. + * @param messages The list of messages to append the mapped content to. + */ + private static void handleContent(Content content, List messages) { + Message msg = new Message(); + String role = content.role().orElse("user"); + msg.role = role.equals("model") ? "assistant" : role; + + List contentParts = new ArrayList<>(); + List toolCalls = new ArrayList<>(); + List toolResponses = new ArrayList<>(); + + content + .parts() + .ifPresent( + parts -> { + for (Part part : parts) { + if (part.text().isPresent()) { + handleTextPart(part, contentParts); + } else if (part.inlineData().isPresent()) { + handleInlineDataPart(part, contentParts); + } else if (part.fileData().isPresent()) { + handleFileDataPart(part, contentParts); + } else if (part.functionCall().isPresent()) { + handleFunctionCallPart(part, toolCalls); + } else if (part.functionResponse().isPresent()) { + handleFunctionResponsePart(part, toolResponses); + } else if (part.executableCode().isPresent()) { + logger.warn("Executable code is not supported in Chat Completion conversion"); + } else if (part.codeExecutionResult().isPresent()) { + logger.warn( + "Code execution result is not supported in Chat Completion conversion"); + } + } + }); + + if (!toolResponses.isEmpty()) { + messages.addAll(toolResponses); + } else { + if (!toolCalls.isEmpty()) { + msg.toolCalls = ImmutableList.copyOf(toolCalls); + } + if (!contentParts.isEmpty()) { + if (contentParts.size() == 1 && Objects.equals(contentParts.get(0).type, "text")) { + msg.content = new MessageContent(contentParts.get(0).text); + } else { + msg.content = new MessageContent(ImmutableList.copyOf(contentParts)); + } + } + messages.add(msg); + } + } + + /** + * Updates the contentParts list based on the provided text part. + * + * @param part The input part containing simple text. + * @param contentParts The list of content parts to append the mapped text to. + */ + private static void handleTextPart(Part part, List contentParts) { + ContentPart textPart = new ContentPart(); + textPart.type = "text"; + textPart.text = part.text().get(); + contentParts.add(textPart); + } + + /** + * Updates the contentParts list based on the provided inline data part. + * + * @param part The input part containing base64 inline data. + * @param contentParts The list of content parts to append the mapped image URL to. + */ + private static void handleInlineDataPart(Part part, List contentParts) { + ContentPart imgPart = new ContentPart(); + imgPart.type = "image_url"; + ImageUrl imageUrl = new ImageUrl(); + imageUrl.url = + "data:" + + part.inlineData().get().mimeType().orElse("image/jpeg") + + ";base64," + + Base64.getEncoder().encodeToString(part.inlineData().get().data().get()); + imgPart.imageUrl = imageUrl; + contentParts.add(imgPart); + } + + /** + * Updates the contentParts list based on the provided file data part. + * + * @param part The input part referencing a stored file via URI. + * @param contentParts The list of content parts to append the mapped image URL to. + */ + private static void handleFileDataPart(Part part, List contentParts) { + ContentPart imgPart = new ContentPart(); + imgPart.type = "image_url"; + ImageUrl imageUrl = new ImageUrl(); + imageUrl.url = part.fileData().get().fileUri().orElse(""); + imgPart.imageUrl = imageUrl; + contentParts.add(imgPart); + } + + /** + * Updates the toolCalls list based on the provided function call part. + * + * @param part The input part containing a requested function call or invocation. + * @param toolCalls The list of tool calls to append the mapped function call to. + */ + private static void handleFunctionCallPart( + Part part, List toolCalls) { + com.google.genai.types.FunctionCall fc = part.functionCall().get(); + ChatCompletionsCommon.ToolCall toolCall = new ChatCompletionsCommon.ToolCall(); + toolCall.id = fc.id().orElse("call_" + fc.name().orElse("unknown")); + toolCall.type = "function"; + ChatCompletionsCommon.Function function = new ChatCompletionsCommon.Function(); + function.name = fc.name().orElse(""); + if (fc.args().isPresent()) { + try { + function.arguments = objectMapper.writeValueAsString(fc.args().get()); + } catch (Exception e) { + logger.warn("Failed to serialize function arguments", e); + } + } + toolCall.function = function; + toolCalls.add(toolCall); + } + + /** + * Updates the toolResponses list based on the provided function response part. + * + * @param part The input part containing the execution results of a function. + * @param toolResponses The list of tool responses to append the mapped output to. + */ + private static void handleFunctionResponsePart(Part part, List toolResponses) { + FunctionResponse fr = part.functionResponse().get(); + Message toolResp = new Message(); + toolResp.role = "tool"; + toolResp.toolCallId = fr.id().orElse(""); + if (fr.response().isPresent()) { + try { + toolResp.content = new MessageContent(objectMapper.writeValueAsString(fr.response().get())); + } catch (Exception e) { + logger.warn("Failed to serialize tool response", e); + } + } + toolResponses.add(toolResp); + } + + /** + * Updates the request based on the provided configuration options. + * + * @param config The content generation configuration containing parameters such as temperature. + * @param request The chat completions request to populate with matching options. + */ + private static void handleConfigOptions( + GenerateContentConfig config, ChatCompletionsRequest request) { + config.temperature().ifPresent(v -> request.temperature = v.doubleValue()); + config.topP().ifPresent(v -> request.topP = v.doubleValue()); + config + .maxOutputTokens() + .ifPresent( + v -> { + request.maxCompletionTokens = Math.toIntExact(v); + }); + config.stopSequences().ifPresent(v -> request.stop = new StopCondition(v)); + config.candidateCount().ifPresent(v -> request.n = Math.toIntExact(v)); + config.presencePenalty().ifPresent(v -> request.presencePenalty = v.doubleValue()); + config.frequencyPenalty().ifPresent(v -> request.frequencyPenalty = v.doubleValue()); + config.seed().ifPresent(v -> request.seed = v.longValue()); + + if (config.responseJsonSchema().isPresent()) { + ResponseFormatJsonSchema format = new ResponseFormatJsonSchema(); + ResponseFormatJsonSchema.JsonSchema schema = new ResponseFormatJsonSchema.JsonSchema(); + schema.name = "response_schema"; + schema.schema = + objectMapper.convertValue( + config.responseJsonSchema().get(), new TypeReference>() {}); + schema.strict = true; + format.jsonSchema = schema; + request.responseFormat = format; + } else if (config.responseMimeType().isPresent() + && config.responseMimeType().get().equals("application/json")) { + request.responseFormat = new ResponseFormatJsonObject(); + } + + if (config.responseLogprobs().isPresent() && config.responseLogprobs().get()) { + request.logprobs = true; + config.logprobs().ifPresent(v -> request.topLogprobs = Math.toIntExact(v)); + } + } + + /** + * Updates the request tools list based on the provided tools configuration. + * + * @param config The content generation configuration defining available tools. + * @param request The chat completions request to populate with mapped tool definitions. + */ + private static void handleTools(GenerateContentConfig config, ChatCompletionsRequest request) { + if (config.tools().isPresent()) { + List tools = new ArrayList<>(); + for (com.google.genai.types.Tool t : config.tools().get()) { + if (t.functionDeclarations().isPresent()) { + for (FunctionDeclaration fd : t.functionDeclarations().get()) { + Tool tool = new Tool(); + tool.type = "function"; + FunctionDefinition def = new FunctionDefinition(); + def.name = fd.name().orElse(""); + def.description = fd.description().orElse(""); + fd.parameters() + .ifPresent( + params -> + def.parameters = + objectMapper.convertValue( + params, new TypeReference>() {})); + tool.function = def; + tools.add(tool); + } + } + } + if (!tools.isEmpty()) { + request.tools = ImmutableList.copyOf(tools); + if (config.toolConfig().isPresent() + && config.toolConfig().get().functionCallingConfig().isPresent()) { + config + .toolConfig() + .get() + .functionCallingConfig() + .get() + .mode() + .ifPresent( + mode -> { + switch (mode.knownEnum()) { + case ANY -> request.toolChoice = new ToolChoiceMode("required"); + case NONE -> request.toolChoice = new ToolChoiceMode("none"); + case AUTO -> request.toolChoice = new ToolChoiceMode("auto"); + default -> {} + } + }); + } + } + } + } + /** * A catch-all class for message parameters. See * https://developers.openai.com/api/reference/resources/chat/subresources/completions/methods/create#(resource)%20chat.completions%20%3E%20(method)%20create%20%3E%20(params)%200.non_streaming%20%3E%20(param)%20messages%20%3E%20(schema) diff --git a/core/src/main/java/com/google/adk/models/chat/ChatCompletionsResponse.java b/core/src/main/java/com/google/adk/models/chat/ChatCompletionsResponse.java index 9645016a9..a718f9a43 100644 --- a/core/src/main/java/com/google/adk/models/chat/ChatCompletionsResponse.java +++ b/core/src/main/java/com/google/adk/models/chat/ChatCompletionsResponse.java @@ -50,7 +50,7 @@ public final class ChatCompletionsResponse { private ChatCompletionsResponse() {} - static @Nullable FinishReason mapFinishReason(String reason) { + static @Nullable FinishReason mapFinishReason(@Nullable String reason) { if (reason == null) { return null; } @@ -62,7 +62,7 @@ private ChatCompletionsResponse() {} }; } - static @Nullable GenerateContentResponseUsageMetadata mapUsage(Usage usage) { + static @Nullable GenerateContentResponseUsageMetadata mapUsage(@Nullable Usage usage) { if (usage == null) { return null; } @@ -188,8 +188,15 @@ private ImmutableList mapMessageToParts(Message message) { return parts.build(); } + /** + * Maps a list of tool calls to a list of {@link Part} objects. + * + * @param toolCalls the list of tool calls to map (non-null). + * @return a list of parts containing converted tool calls. + */ private ImmutableList mapToolCallsToParts( List toolCalls) { + ImmutableList.Builder parts = ImmutableList.builder(); for (ChatCompletionsCommon.ToolCall toolCall : toolCalls) { Part part = toolCall.toPart(); diff --git a/core/src/test/java/com/google/adk/models/chat/ChatCompletionsRequestTest.java b/core/src/test/java/com/google/adk/models/chat/ChatCompletionsRequestTest.java index 9dc63c5d6..8fe9c7b0b 100644 --- a/core/src/test/java/com/google/adk/models/chat/ChatCompletionsRequestTest.java +++ b/core/src/test/java/com/google/adk/models/chat/ChatCompletionsRequestTest.java @@ -17,11 +17,22 @@ package com.google.adk.models.chat; import static com.google.common.truth.Truth.assertThat; +import static java.nio.charset.StandardCharsets.UTF_8; import com.fasterxml.jackson.databind.ObjectMapper; +import com.google.adk.JsonBaseModel; +import com.google.adk.models.LlmRequest; import com.google.common.collect.ImmutableList; -import java.util.HashMap; -import java.util.Map; +import com.google.common.collect.ImmutableMap; +import com.google.genai.types.Blob; +import com.google.genai.types.Content; +import com.google.genai.types.FileData; +import com.google.genai.types.FunctionCall; +import com.google.genai.types.FunctionDeclaration; +import com.google.genai.types.GenerateContentConfig; +import com.google.genai.types.Part; +import com.google.genai.types.Tool; +import java.util.List; import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; @@ -34,17 +45,17 @@ public final class ChatCompletionsRequestTest { @Before public void setUp() { - objectMapper = new ObjectMapper(); + objectMapper = JsonBaseModel.getMapper(); } @Test public void testSerializeChatCompletionRequest_standard() throws Exception { - ChatCompletionsRequest request = new ChatCompletionsRequest(); - request.model = "gemini-3-flash-preview"; - ChatCompletionsRequest.Message message = new ChatCompletionsRequest.Message(); message.role = "user"; message.content = new ChatCompletionsRequest.MessageContent("Hello"); + + ChatCompletionsRequest request = new ChatCompletionsRequest(); + request.model = "gemini-3-flash-preview"; request.messages = ImmutableList.of(message); String json = objectMapper.writeValueAsString(request); @@ -56,24 +67,20 @@ public void testSerializeChatCompletionRequest_standard() throws Exception { @Test public void testSerializeChatCompletionRequest_withExtraBody() throws Exception { - ChatCompletionsRequest request = new ChatCompletionsRequest(); - request.model = "gemini-3-flash-preview"; - ChatCompletionsRequest.Message message = new ChatCompletionsRequest.Message(); message.role = "user"; message.content = new ChatCompletionsRequest.MessageContent("Explain to me how AI works"); - request.messages = ImmutableList.of(message); - - Map thinkingConfig = new HashMap<>(); - thinkingConfig.put("thinking_level", "low"); - thinkingConfig.put("include_thoughts", true); - - Map google = new HashMap<>(); - google.put("thinking_config", thinkingConfig); - Map extraBody = new HashMap<>(); - extraBody.put("google", google); + ImmutableMap extraBody = + ImmutableMap.of( + "google", + ImmutableMap.of( + "thinking_config", + ImmutableMap.of("thinking_level", "low", "include_thoughts", true))); + ChatCompletionsRequest request = new ChatCompletionsRequest(); + request.model = "gemini-3-flash-preview"; + request.messages = ImmutableList.of(message); request.extraBody = extraBody; String json = objectMapper.writeValueAsString(request); @@ -85,9 +92,6 @@ public void testSerializeChatCompletionRequest_withExtraBody() throws Exception @Test public void testSerializeChatCompletionRequest_withToolCallsAndExtraContent() throws Exception { - ChatCompletionsRequest request = new ChatCompletionsRequest(); - request.model = "gemini-3-flash-preview"; - ChatCompletionsRequest.Message userMessage = new ChatCompletionsRequest.Message(); userMessage.role = "user"; userMessage.content = new ChatCompletionsRequest.MessageContent("Check flight status"); @@ -104,11 +108,8 @@ public void testSerializeChatCompletionRequest_withToolCallsAndExtraContent() th function.arguments = "{\"flight\":\"AA100\"}"; toolCall.function = function; - Map google = new HashMap<>(); - google.put("thought_signature", ""); - - Map extraContent = new HashMap<>(); - extraContent.put("google", google); + ImmutableMap extraContent = + ImmutableMap.of("google", ImmutableMap.of("thought_signature", "")); toolCall.extraContent = extraContent; @@ -120,6 +121,8 @@ public void testSerializeChatCompletionRequest_withToolCallsAndExtraContent() th toolMessage.toolCallId = "function-call-1"; toolMessage.content = new ChatCompletionsRequest.MessageContent("{\"status\":\"delayed\"}"); + ChatCompletionsRequest request = new ChatCompletionsRequest(); + request.model = "gemini-3-flash-preview"; request.messages = ImmutableList.of(userMessage, modelMessage, toolMessage); String json = objectMapper.writeValueAsString(request); @@ -134,45 +137,38 @@ public void testSerializeChatCompletionRequest_withToolCallsAndExtraContent() th @Test public void testSerializeChatCompletionRequest_comprehensive() throws Exception { - ChatCompletionsRequest request = new ChatCompletionsRequest(); - request.model = "gemini-3-flash-preview"; - - // Developer message with name ChatCompletionsRequest.Message devMsg = new ChatCompletionsRequest.Message(); devMsg.role = "developer"; devMsg.content = new ChatCompletionsRequest.MessageContent("System instruction"); devMsg.name = "system-bot"; - request.messages = ImmutableList.of(devMsg); - - // Response Format JSON Schema ChatCompletionsRequest.ResponseFormatJsonSchema format = new ChatCompletionsRequest.ResponseFormatJsonSchema(); format.jsonSchema = new ChatCompletionsRequest.ResponseFormatJsonSchema.JsonSchema(); format.jsonSchema.name = "MySchema"; format.jsonSchema.strict = true; - request.responseFormat = format; - // Tool Choice Named ChatCompletionsRequest.NamedToolChoice choice = new ChatCompletionsRequest.NamedToolChoice(); choice.function = new ChatCompletionsRequest.NamedToolChoice.FunctionName(); choice.function.name = "my_function"; + + ChatCompletionsRequest request = new ChatCompletionsRequest(); + request.model = "gemini-3-flash-preview"; + request.messages = ImmutableList.of(devMsg); + request.responseFormat = format; request.toolChoice = choice; String json = objectMapper.writeValueAsString(request); - // Assert Developer Message assertThat(json).contains("\"role\":\"developer\""); assertThat(json).contains("\"name\":\"system-bot\""); assertThat(json).contains("\"content\":\"System instruction\""); - // Assert Response Format assertThat(json).contains("\"response_format\":{"); assertThat(json).contains("\"type\":\"json_schema\""); assertThat(json).contains("\"name\":\"MySchema\""); assertThat(json).contains("\"strict\":true"); - // Assert Tool Choice assertThat(json).contains("\"tool_choice\":{"); assertThat(json).contains("\"type\":\"function\""); assertThat(json).contains("\"name\":\"my_function\""); @@ -182,7 +178,7 @@ public void testSerializeChatCompletionRequest_comprehensive() throws Exception public void testSerializeChatCompletionRequest_withToolChoiceMode() throws Exception { ChatCompletionsRequest request = new ChatCompletionsRequest(); request.model = "gemini-3-flash-preview"; - + request.messages = ImmutableList.of(); request.toolChoice = new ChatCompletionsRequest.ToolChoiceMode("none"); String json = objectMapper.writeValueAsString(request); @@ -192,13 +188,15 @@ public void testSerializeChatCompletionRequest_withToolChoiceMode() throws Excep @Test public void testSerializeChatCompletionRequest_withStopAndVoice() throws Exception { - ChatCompletionsRequest request = new ChatCompletionsRequest(); - request.model = "gemini-3-flash-preview"; - - request.stop = new ChatCompletionsRequest.StopCondition("STOP"); + ChatCompletionsRequest.StopCondition stop = new ChatCompletionsRequest.StopCondition("STOP"); ChatCompletionsRequest.AudioParam audio = new ChatCompletionsRequest.AudioParam(); audio.voice = new ChatCompletionsRequest.VoiceConfig("alloy"); + + ChatCompletionsRequest request = new ChatCompletionsRequest(); + request.model = "gemini-3-flash-preview"; + request.messages = ImmutableList.of(); + request.stop = stop; request.audio = audio; String json = objectMapper.writeValueAsString(request); @@ -211,11 +209,199 @@ public void testSerializeChatCompletionRequest_withStopAndVoice() throws Excepti public void testSerializeChatCompletionRequest_withStopList() throws Exception { ChatCompletionsRequest request = new ChatCompletionsRequest(); request.model = "gemini-3-flash-preview"; - + request.messages = ImmutableList.of(); request.stop = new ChatCompletionsRequest.StopCondition(ImmutableList.of("STOP1", "STOP2")); String json = objectMapper.writeValueAsString(request); assertThat(json).contains("\"stop\":[\"STOP1\",\"STOP2\"]"); } + + @Test + public void testFromLlmRequest_basic() throws Exception { + LlmRequest llmRequest = + LlmRequest.builder() + .model("gemini-1.5-pro") + .contents( + ImmutableList.of( + Content.builder() + .role("user") + .parts(ImmutableList.of(Part.fromText("Hello"))) + .build())) + .build(); + + ChatCompletionsRequest request = ChatCompletionsRequest.fromLlmRequest(llmRequest, false); + + assertThat(request.model).isEqualTo("gemini-1.5-pro"); + assertThat(request.stream).isFalse(); + assertThat(request.messages).hasSize(1); + assertThat(request.messages.get(0).role).isEqualTo("user"); + assertThat(request.messages.get(0).content.getValue()).isEqualTo("Hello"); + } + + @Test + public void testFromLlmRequest_withSystemInstruction() throws Exception { + LlmRequest llmRequest = + LlmRequest.builder() + .model("gpt-4") + .config( + GenerateContentConfig.builder() + .systemInstruction( + Content.builder() + .parts(ImmutableList.of(Part.fromText("Be helpful"))) + .build()) + .temperature(0.7f) + .topP(0.9f) + .maxOutputTokens(100) + .stopSequences(ImmutableList.of("END")) + .tools( + ImmutableList.of( + Tool.builder() + .functionDeclarations( + ImmutableList.of( + FunctionDeclaration.builder() + .name("get_weather") + .description("Get current weather") + .build())) + .build())) + .build()) + .contents( + ImmutableList.of( + Content.builder() + .role("user") + .parts(ImmutableList.of(Part.fromText("Hello"))) + .build())) + .build(); + + ChatCompletionsRequest request = ChatCompletionsRequest.fromLlmRequest(llmRequest, false); + + assertThat(request.messages).hasSize(2); + assertThat(request.messages.get(0).role).isEqualTo("system"); + assertThat(request.messages.get(0).content.getValue()).isEqualTo("Be helpful"); + assertThat(request.temperature).isWithin(0.001).of(0.7); + assertThat(request.topP).isWithin(0.001).of(0.9); + assertThat(request.maxCompletionTokens).isEqualTo(100); + assertThat((List) request.stop.getValue()).containsExactly("END"); + assertThat(request.tools).hasSize(1); + assertThat(request.tools.get(0).function.name).isEqualTo("get_weather"); + assertThat(request.tools.get(0).function.description).isEqualTo("Get current weather"); + } + + @Test + public void testFromLlmRequest_withInlineData() throws Exception { + LlmRequest llmRequest = + LlmRequest.builder() + .model("gemini-1.5-pro") + .contents( + ImmutableList.of( + Content.builder() + .role("user") + .parts( + ImmutableList.of( + Part.builder() + .inlineData( + Blob.builder() + .mimeType("image/jpeg") + .data("base64data".getBytes(UTF_8)) + .build()) + .build())) + .build())) + .build(); + + ChatCompletionsRequest request = ChatCompletionsRequest.fromLlmRequest(llmRequest, false); + + assertThat(request.messages).hasSize(1); + ChatCompletionsRequest.Message msg = request.messages.get(0); + + @SuppressWarnings( + "unchecked") // Safe in unit tests and this is the expected type from msg.content + List parts = + (List) msg.content.getValue(); + assertThat(parts).hasSize(1); + assertThat(parts.get(0).type).isEqualTo("image_url"); + assertThat(parts.get(0).imageUrl.url).contains("base64,"); + } + + @Test + public void testFromLlmRequest_withFileData() throws Exception { + LlmRequest llmRequest = + LlmRequest.builder() + .model("gemini-1.5-pro") + .contents( + ImmutableList.of( + Content.builder() + .role("user") + .parts( + ImmutableList.of( + Part.builder() + .fileData( + FileData.builder() + .fileUri("gs://bucket/file.jpg") + .mimeType("image/jpeg") + .build()) + .build())) + .build())) + .build(); + + ChatCompletionsRequest request = ChatCompletionsRequest.fromLlmRequest(llmRequest, false); + + assertThat(request.messages).hasSize(1); + ChatCompletionsRequest.Message msg = request.messages.get(0); + + @SuppressWarnings( + "unchecked") // Safe in unit tests and this is the expected type from msg.content + List parts = + (List) msg.content.getValue(); + assertThat(parts).hasSize(1); + assertThat(parts.get(0).type).isEqualTo("image_url"); + assertThat(parts.get(0).imageUrl.url).isEqualTo("gs://bucket/file.jpg"); + } + + @Test + public void testFromLlmRequest_withFunctionCall() throws Exception { + ImmutableMap args = ImmutableMap.of("location", "Paris"); + + LlmRequest llmRequest = + LlmRequest.builder() + .model("gemini-1.5-pro") + .contents( + ImmutableList.of( + Content.builder() + .role("model") + .parts( + ImmutableList.of( + Part.builder() + .functionCall( + FunctionCall.builder() + .id("call_123") + .name("get_weather") + .args(args) + .build()) + .build())) + .build())) + .build(); + + ChatCompletionsRequest request = ChatCompletionsRequest.fromLlmRequest(llmRequest, false); + + assertThat(request.messages).hasSize(1); + ChatCompletionsRequest.Message msg = request.messages.get(0); + assertThat(msg.role).isEqualTo("assistant"); + assertThat(msg.toolCalls).hasSize(1); + assertThat(msg.toolCalls.get(0).id).isEqualTo("call_123"); + assertThat(msg.toolCalls.get(0).type).isEqualTo("function"); + assertThat(msg.toolCalls.get(0).function.name).isEqualTo("get_weather"); + assertThat(msg.toolCalls.get(0).function.arguments).isEqualTo("{\"location\":\"Paris\"}"); + } + + @Test + public void testFromLlmRequest_withStreamOptions() throws Exception { + LlmRequest llmRequest = + LlmRequest.builder().model("gemini-1.5-pro").contents(ImmutableList.of()).build(); + + ChatCompletionsRequest request = ChatCompletionsRequest.fromLlmRequest(llmRequest, true); + + assertThat(request.stream).isTrue(); + assertThat(request.streamOptions).isNotNull(); + assertThat(request.streamOptions.includeUsage).isTrue(); + } }