Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -21,18 +21,36 @@
import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.annotation.JsonValue;
import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.adk.JsonBaseModel;
import com.google.adk.models.LlmRequest;
import com.google.common.collect.ImmutableList;
import com.google.genai.types.Content;
import com.google.genai.types.FunctionDeclaration;
import com.google.genai.types.FunctionResponse;
import com.google.genai.types.GenerateContentConfig;
import com.google.genai.types.Part;
import java.util.ArrayList;
import java.util.Base64;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
* Data Transfer Objects for Chat Completion API requests.
*
* <p>Can be used to translate from a {@link LlmRequest} into a {@link ChatCompletionsRequest} using
* {@link #fromLlmRequest(LlmRequest, boolean)}.
*
* <p>See
* https://developers.openai.com/api/reference/resources/chat/subresources/completions/methods/create
*/
@JsonIgnoreProperties(ignoreUnknown = true)
@JsonInclude(JsonInclude.Include.NON_NULL)
final class ChatCompletionsRequest {
public final class ChatCompletionsRequest {

/**
* See
Expand Down Expand Up @@ -249,6 +267,314 @@ final class ChatCompletionsRequest {
@JsonProperty("extra_body")
public Map<String, Object> extraBody;

private static final Logger logger = LoggerFactory.getLogger(ChatCompletionsRequest.class);
private static final ObjectMapper objectMapper = JsonBaseModel.getMapper();

/**
* Converts a standard {@link LlmRequest} into a {@link ChatCompletionsRequest} for
* /chat/completions compatible endpoints.
*
* @param llmRequest The internal source request containing contents, configuration, and tool
* definitions.
* @param responseStreaming True if the request asks for a streaming response.
* @return A populated ChatCompletionsRequest ready for JSON serialization.
*/
public static ChatCompletionsRequest fromLlmRequest(
LlmRequest llmRequest, boolean responseStreaming) {
ChatCompletionsRequest request = new ChatCompletionsRequest();
request.model = llmRequest.model().orElse("");
request.stream = responseStreaming;
if (responseStreaming) {
StreamOptions options = new StreamOptions();
options.includeUsage = true;
request.streamOptions = options;
}

boolean isOSeries = request.model.matches("^o\\d+(?:-.*)?$");

List<Message> messages = new ArrayList<>();

llmRequest.config().ifPresent(config -> handleSystemInstruction(config, isOSeries, messages));

for (Content content : llmRequest.contents()) {
handleContent(content, messages);
}

request.messages = ImmutableList.copyOf(messages);

llmRequest
.config()
.ifPresent(
config -> {
handleConfigOptions(config, request);
handleTools(config, request);
});

return request;
}

/**
* Updates the messages list based on the provided system instruction configuration.
*
* @param config The content generation configuration that may contain a system instruction.
* @param isOSeries True if the target model belongs to the OpenAI o-series (e.g., o1, o3), which
* requires the "developer" role instead of the standard "system" role.
* @param messages The list of messages to append the mapped instruction to.
*/
private static void handleSystemInstruction(
GenerateContentConfig config, boolean isOSeries, List<Message> messages) {
if (config.systemInstruction().isPresent()) {
Message systemMsg = new Message();
systemMsg.role = isOSeries ? "developer" : "system";
systemMsg.content = new MessageContent(config.systemInstruction().get().text());
messages.add(systemMsg);
}
}

/**
* Updates the messages list based on the provided content.
*
* @param content The incoming content containing parts to map.
* @param messages The list of messages to append the mapped content to.
*/
private static void handleContent(Content content, List<Message> messages) {
Message msg = new Message();
String role = content.role().orElse("user");
msg.role = role.equals("model") ? "assistant" : role;

List<ContentPart> contentParts = new ArrayList<>();
List<ChatCompletionsCommon.ToolCall> toolCalls = new ArrayList<>();
List<Message> toolResponses = new ArrayList<>();

content
.parts()
.ifPresent(
parts -> {
for (Part part : parts) {
if (part.text().isPresent()) {
handleTextPart(part, contentParts);
} else if (part.inlineData().isPresent()) {
handleInlineDataPart(part, contentParts);
} else if (part.fileData().isPresent()) {
handleFileDataPart(part, contentParts);
} else if (part.functionCall().isPresent()) {
handleFunctionCallPart(part, toolCalls);
} else if (part.functionResponse().isPresent()) {
handleFunctionResponsePart(part, toolResponses);
} else if (part.executableCode().isPresent()) {
logger.warn("Executable code is not supported in Chat Completion conversion");
} else if (part.codeExecutionResult().isPresent()) {
logger.warn(
"Code execution result is not supported in Chat Completion conversion");
}
}
});

if (!toolResponses.isEmpty()) {
messages.addAll(toolResponses);
} else {
if (!toolCalls.isEmpty()) {
msg.toolCalls = ImmutableList.copyOf(toolCalls);
}
if (!contentParts.isEmpty()) {
if (contentParts.size() == 1 && Objects.equals(contentParts.get(0).type, "text")) {
msg.content = new MessageContent(contentParts.get(0).text);
} else {
msg.content = new MessageContent(ImmutableList.copyOf(contentParts));
}
}
messages.add(msg);
}
}

/**
* Updates the contentParts list based on the provided text part.
*
* @param part The input part containing simple text.
* @param contentParts The list of content parts to append the mapped text to.
*/
private static void handleTextPart(Part part, List<ContentPart> contentParts) {
ContentPart textPart = new ContentPart();
textPart.type = "text";
textPart.text = part.text().get();
contentParts.add(textPart);
}

/**
* Updates the contentParts list based on the provided inline data part.
*
* @param part The input part containing base64 inline data.
* @param contentParts The list of content parts to append the mapped image URL to.
*/
private static void handleInlineDataPart(Part part, List<ContentPart> contentParts) {
ContentPart imgPart = new ContentPart();
imgPart.type = "image_url";
ImageUrl imageUrl = new ImageUrl();
imageUrl.url =
"data:"
+ part.inlineData().get().mimeType().orElse("image/jpeg")
+ ";base64,"
+ Base64.getEncoder().encodeToString(part.inlineData().get().data().get());
imgPart.imageUrl = imageUrl;
contentParts.add(imgPart);
}

/**
* Updates the contentParts list based on the provided file data part.
*
* @param part The input part referencing a stored file via URI.
* @param contentParts The list of content parts to append the mapped image URL to.
*/
private static void handleFileDataPart(Part part, List<ContentPart> contentParts) {
ContentPart imgPart = new ContentPart();
imgPart.type = "image_url";
ImageUrl imageUrl = new ImageUrl();
imageUrl.url = part.fileData().get().fileUri().orElse("");
imgPart.imageUrl = imageUrl;
contentParts.add(imgPart);
}

/**
* Updates the toolCalls list based on the provided function call part.
*
* @param part The input part containing a requested function call or invocation.
* @param toolCalls The list of tool calls to append the mapped function call to.
*/
private static void handleFunctionCallPart(
Part part, List<ChatCompletionsCommon.ToolCall> toolCalls) {
com.google.genai.types.FunctionCall fc = part.functionCall().get();
ChatCompletionsCommon.ToolCall toolCall = new ChatCompletionsCommon.ToolCall();
toolCall.id = fc.id().orElse("call_" + fc.name().orElse("unknown"));
toolCall.type = "function";
ChatCompletionsCommon.Function function = new ChatCompletionsCommon.Function();
function.name = fc.name().orElse("");
if (fc.args().isPresent()) {
try {
function.arguments = objectMapper.writeValueAsString(fc.args().get());
} catch (Exception e) {
logger.warn("Failed to serialize function arguments", e);
}
}
toolCall.function = function;
toolCalls.add(toolCall);
}

/**
* Updates the toolResponses list based on the provided function response part.
*
* @param part The input part containing the execution results of a function.
* @param toolResponses The list of tool responses to append the mapped output to.
*/
private static void handleFunctionResponsePart(Part part, List<Message> toolResponses) {
FunctionResponse fr = part.functionResponse().get();
Message toolResp = new Message();
toolResp.role = "tool";
toolResp.toolCallId = fr.id().orElse("");
if (fr.response().isPresent()) {
try {
toolResp.content = new MessageContent(objectMapper.writeValueAsString(fr.response().get()));
} catch (Exception e) {
logger.warn("Failed to serialize tool response", e);
}
}
toolResponses.add(toolResp);
}

/**
* Updates the request based on the provided configuration options.
*
* @param config The content generation configuration containing parameters such as temperature.
* @param request The chat completions request to populate with matching options.
*/
private static void handleConfigOptions(
GenerateContentConfig config, ChatCompletionsRequest request) {
config.temperature().ifPresent(v -> request.temperature = v.doubleValue());
config.topP().ifPresent(v -> request.topP = v.doubleValue());
config
.maxOutputTokens()
.ifPresent(
v -> {
request.maxCompletionTokens = Math.toIntExact(v);
});
config.stopSequences().ifPresent(v -> request.stop = new StopCondition(v));
config.candidateCount().ifPresent(v -> request.n = Math.toIntExact(v));
config.presencePenalty().ifPresent(v -> request.presencePenalty = v.doubleValue());
config.frequencyPenalty().ifPresent(v -> request.frequencyPenalty = v.doubleValue());
config.seed().ifPresent(v -> request.seed = v.longValue());

if (config.responseJsonSchema().isPresent()) {
ResponseFormatJsonSchema format = new ResponseFormatJsonSchema();
ResponseFormatJsonSchema.JsonSchema schema = new ResponseFormatJsonSchema.JsonSchema();
schema.name = "response_schema";
schema.schema =
objectMapper.convertValue(
config.responseJsonSchema().get(), new TypeReference<Map<String, Object>>() {});
schema.strict = true;
format.jsonSchema = schema;
request.responseFormat = format;
} else if (config.responseMimeType().isPresent()
&& config.responseMimeType().get().equals("application/json")) {
request.responseFormat = new ResponseFormatJsonObject();
}

if (config.responseLogprobs().isPresent() && config.responseLogprobs().get()) {
request.logprobs = true;
config.logprobs().ifPresent(v -> request.topLogprobs = Math.toIntExact(v));
}
}

/**
* Updates the request tools list based on the provided tools configuration.
*
* @param config The content generation configuration defining available tools.
* @param request The chat completions request to populate with mapped tool definitions.
*/
private static void handleTools(GenerateContentConfig config, ChatCompletionsRequest request) {
if (config.tools().isPresent()) {
List<Tool> tools = new ArrayList<>();
for (com.google.genai.types.Tool t : config.tools().get()) {
if (t.functionDeclarations().isPresent()) {
for (FunctionDeclaration fd : t.functionDeclarations().get()) {
Tool tool = new Tool();
tool.type = "function";
FunctionDefinition def = new FunctionDefinition();
def.name = fd.name().orElse("");
def.description = fd.description().orElse("");
fd.parameters()
.ifPresent(
params ->
def.parameters =
objectMapper.convertValue(
params, new TypeReference<Map<String, Object>>() {}));
tool.function = def;
tools.add(tool);
}
}
}
if (!tools.isEmpty()) {
request.tools = ImmutableList.copyOf(tools);
if (config.toolConfig().isPresent()
&& config.toolConfig().get().functionCallingConfig().isPresent()) {
config
.toolConfig()
.get()
.functionCallingConfig()
.get()
.mode()
.ifPresent(
mode -> {
switch (mode.knownEnum()) {
case ANY -> request.toolChoice = new ToolChoiceMode("required");
case NONE -> request.toolChoice = new ToolChoiceMode("none");
case AUTO -> request.toolChoice = new ToolChoiceMode("auto");
default -> {}
}
});
}
}
}
}

/**
* A catch-all class for message parameters. See
* https://developers.openai.com/api/reference/resources/chat/subresources/completions/methods/create#(resource)%20chat.completions%20%3E%20(method)%20create%20%3E%20(params)%200.non_streaming%20%3E%20(param)%20messages%20%3E%20(schema)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ public final class ChatCompletionsResponse {

private ChatCompletionsResponse() {}

static @Nullable FinishReason mapFinishReason(String reason) {
static @Nullable FinishReason mapFinishReason(@Nullable String reason) {
if (reason == null) {
return null;
}
Expand All @@ -62,7 +62,7 @@ private ChatCompletionsResponse() {}
};
}

static @Nullable GenerateContentResponseUsageMetadata mapUsage(Usage usage) {
static @Nullable GenerateContentResponseUsageMetadata mapUsage(@Nullable Usage usage) {
if (usage == null) {
return null;
}
Expand Down Expand Up @@ -188,8 +188,15 @@ private ImmutableList<Part> mapMessageToParts(Message message) {
return parts.build();
}

/**
* Maps a list of tool calls to a list of {@link Part} objects.
*
* @param toolCalls the list of tool calls to map (non-null).
* @return a list of parts containing converted tool calls.
*/
private ImmutableList<Part> mapToolCallsToParts(
List<ChatCompletionsCommon.ToolCall> toolCalls) {

ImmutableList.Builder<Part> parts = ImmutableList.builder();
for (ChatCompletionsCommon.ToolCall toolCall : toolCalls) {
Part part = toolCall.toPart();
Expand Down
Loading