diff --git a/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java b/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java index fffeab698..6c1dd52bd 100644 --- a/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java +++ b/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java @@ -38,6 +38,8 @@ import com.google.adk.models.LlmRequest; import com.google.adk.models.LlmResponse; import com.google.adk.telemetry.Tracing; +import com.google.adk.tools.BaseTool; +import com.google.adk.tools.BaseToolset; import com.google.adk.tools.ToolContext; import com.google.common.collect.ImmutableList; import com.google.common.collect.Iterables; @@ -99,10 +101,24 @@ private Flowable preprocess( RequestProcessor toolsProcessor = (ctx, req) -> { LlmRequest.Builder builder = req.toBuilder(); - return agent - .canonicalTools(new ReadonlyContext(ctx)) + ReadonlyContext readonlyContext = new ReadonlyContext(ctx); + return Flowable.fromIterable(agent.toolsUnion()) .concatMapCompletable( - tool -> tool.processLlmRequest(builder, ToolContext.builder(ctx).build())) + toolOrToolset -> { + ToolContext toolContext = ToolContext.builder(ctx).build(); + if (toolOrToolset instanceof BaseToolset toolset) { + return toolset + .processLlmRequest(builder, toolContext) + .andThen( + toolset + .getTools(readonlyContext) + .concatMapCompletable( + tool -> tool.processLlmRequest(builder, toolContext))); + } else if (toolOrToolset instanceof BaseTool tool) { + return tool.processLlmRequest(builder, toolContext); + } + return Completable.complete(); + }) .andThen( Single.fromCallable( () -> RequestProcessingResult.create(builder.build(), ImmutableList.of()))); diff --git a/core/src/main/java/com/google/adk/tools/BaseToolset.java b/core/src/main/java/com/google/adk/tools/BaseToolset.java index 76369e5b9..21df14ae1 100644 --- a/core/src/main/java/com/google/adk/tools/BaseToolset.java +++ b/core/src/main/java/com/google/adk/tools/BaseToolset.java @@ -17,6 +17,8 @@ package com.google.adk.tools; import com.google.adk.agents.ReadonlyContext; +import com.google.adk.models.LlmRequest; +import io.reactivex.rxjava3.core.Completable; import io.reactivex.rxjava3.core.Flowable; import java.util.List; import org.jspecify.annotations.Nullable; @@ -32,6 +34,21 @@ public interface BaseToolset extends AutoCloseable { */ Flowable getTools(ReadonlyContext readonlyContext); + /** + * Called during LLM request preprocessing, before the toolset's individual tools are processed. + * Allows the toolset to modify the LLM request, e.g. to append system instructions. + * + *

The default implementation is a no-op. + * + * @param llmRequestBuilder The builder for the LLM request to be modified. + * @param toolContext The tool context providing access to session state and agent information. + * @return A {@link Completable} that completes when processing is done. + */ + default Completable processLlmRequest( + LlmRequest.Builder llmRequestBuilder, ToolContext toolContext) { + return Completable.complete(); + } + /** * Performs cleanup and releases resources held by the toolset. * diff --git a/core/src/test/java/com/google/adk/flows/llmflows/BaseLlmFlowTest.java b/core/src/test/java/com/google/adk/flows/llmflows/BaseLlmFlowTest.java index 2a06c1f0a..fd50df8cc 100644 --- a/core/src/test/java/com/google/adk/flows/llmflows/BaseLlmFlowTest.java +++ b/core/src/test/java/com/google/adk/flows/llmflows/BaseLlmFlowTest.java @@ -16,28 +16,29 @@ package com.google.adk.flows.llmflows; -import static com.google.adk.testing.TestUtils.assertEqualIgnoringFunctionIds; -import static com.google.adk.testing.TestUtils.createGenerateContentResponseUsageMetadata; -import static com.google.adk.testing.TestUtils.createInvocationContext; -import static com.google.adk.testing.TestUtils.createLlmResponse; -import static com.google.adk.testing.TestUtils.createTestAgent; -import static com.google.adk.testing.TestUtils.createTestAgentBuilder; -import static com.google.adk.testing.TestUtils.createTestLlm; -import static com.google.common.collect.Iterables.getOnlyElement; -import static com.google.common.truth.Truth.assertThat; - import com.google.adk.agents.Callbacks; import com.google.adk.agents.InvocationContext; +import com.google.adk.agents.ReadonlyContext; import com.google.adk.events.Event; import com.google.adk.flows.llmflows.RequestProcessor.RequestProcessingResult; import com.google.adk.flows.llmflows.ResponseProcessor.ResponseProcessingResult; import com.google.adk.models.LlmRequest; import com.google.adk.models.LlmResponse; import com.google.adk.testing.TestLlm; +import static com.google.adk.testing.TestUtils.assertEqualIgnoringFunctionIds; +import static com.google.adk.testing.TestUtils.createGenerateContentResponseUsageMetadata; +import static com.google.adk.testing.TestUtils.createInvocationContext; +import static com.google.adk.testing.TestUtils.createLlmResponse; +import static com.google.adk.testing.TestUtils.createTestAgent; +import static com.google.adk.testing.TestUtils.createTestAgentBuilder; +import static com.google.adk.testing.TestUtils.createTestLlm; import com.google.adk.tools.BaseTool; +import com.google.adk.tools.BaseToolset; import com.google.adk.tools.ToolContext; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; +import static com.google.common.collect.Iterables.getOnlyElement; +import static com.google.common.truth.Truth.assertThat; import com.google.genai.types.Content; import com.google.genai.types.FinishReason; import com.google.genai.types.FunctionDeclaration; @@ -47,18 +48,22 @@ import io.opentelemetry.context.Context; import io.opentelemetry.context.ContextKey; import io.opentelemetry.context.Scope; +import io.reactivex.rxjava3.core.Completable; import io.reactivex.rxjava3.core.Flowable; import io.reactivex.rxjava3.core.Maybe; import io.reactivex.rxjava3.core.Single; import io.reactivex.rxjava3.schedulers.Schedulers; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.util.ArrayList; import java.util.List; import java.util.Map; import java.util.Optional; import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicReference; import java.util.function.Function; -import org.junit.Test; -import org.junit.runner.RunWith; -import org.junit.runners.JUnit4; /** Unit tests for {@link BaseLlmFlow}. */ @RunWith(JUnit4.class) @@ -577,6 +582,350 @@ public Single> runAsync(Map args, ToolContex } } + @Test + public void run_withToolset_processLlmRequestIsCalled() { + Content content = Content.fromParts(Part.fromText("LLM response")); + TestLlm testLlm = createTestLlm(createLlmResponse(content)); + + AtomicInteger processLlmRequestCallCount = new AtomicInteger(); + + BaseToolset toolset = + new BaseToolset() { + @Override + public Flowable getTools(ReadonlyContext readonlyContext) { + return Flowable.empty(); + } + + @Override + public Completable processLlmRequest( + LlmRequest.Builder llmRequestBuilder, ToolContext toolContext) { + return Completable.fromAction( + () -> { + processLlmRequestCallCount.incrementAndGet(); + llmRequestBuilder.appendInstructions(List.of("instruction from toolset")); + }); + } + + @Override + public void close() {} + }; + + InvocationContext invocationContext = + createInvocationContext( + createTestAgentBuilder(testLlm).tools(ImmutableList.of(toolset)).build()); + BaseLlmFlow baseLlmFlow = createBaseLlmFlowWithoutProcessors(); + + List unused = baseLlmFlow.run(invocationContext).toList().blockingGet(); + + assertThat(processLlmRequestCallCount.get()).isEqualTo(1); + assertThat(testLlm.getLastRequest().config().orElseThrow().systemInstruction().orElseThrow()) + .isEqualTo(Content.fromParts(Part.fromText("instruction from toolset"))); + } + + @Test + public void run_withToolsetAndTools_processLlmRequestIsCalledBeforeTools() { + Content content = Content.fromParts(Part.fromText("LLM response")); + TestLlm testLlm = createTestLlm(createLlmResponse(content)); + + BaseToolset toolset = + new BaseToolset() { + @Override + public Flowable getTools(ReadonlyContext readonlyContext) { + return Flowable.just(new TestTool("my_function", ImmutableMap.of())); + } + + @Override + public Completable processLlmRequest( + LlmRequest.Builder llmRequestBuilder, ToolContext toolContext) { + return Completable.fromAction( + () -> llmRequestBuilder.appendInstructions(List.of("toolset instruction"))); + } + + @Override + public void close() {} + }; + + InvocationContext invocationContext = + createInvocationContext( + createTestAgentBuilder(testLlm).tools(ImmutableList.of(toolset)).build()); + BaseLlmFlow baseLlmFlow = createBaseLlmFlowWithoutProcessors(); + + List unused = baseLlmFlow.run(invocationContext).toList().blockingGet(); + + // Both the tool and the toolset instruction should be present + assertThat(testLlm.getLastRequest().tools()).containsKey("my_function"); + assertThat(testLlm.getLastRequest().config().orElseThrow().systemInstruction().orElseThrow()) + .isEqualTo(Content.fromParts(Part.fromText("toolset instruction"))); + } + + @Test + public void run_withToolset_toolsetProcessedBeforeItsTools() { + Content content = Content.fromParts(Part.fromText("LLM response")); + TestLlm testLlm = createTestLlm(createLlmResponse(content)); + List callOrder = new ArrayList<>(); + + BaseToolset toolset = + new BaseToolset() { + @Override + public Flowable getTools(ReadonlyContext readonlyContext) { + return Flowable.just( + new BaseTool("ordering_tool", "test tool") { + @Override + public Optional declaration() { + return Optional.of(FunctionDeclaration.builder().name("ordering_tool").build()); + } + + @Override + public Completable processLlmRequest( + LlmRequest.Builder llmRequestBuilder, ToolContext toolContext) { + return Completable.fromAction(() -> callOrder.add("tool")); + } + + @Override + public Single> runAsync( + Map args, ToolContext toolContext) { + return Single.just(ImmutableMap.of()); + } + }); + } + + @Override + public Completable processLlmRequest( + LlmRequest.Builder llmRequestBuilder, ToolContext toolContext) { + return Completable.fromAction(() -> callOrder.add("toolset")); + } + + @Override + public void close() {} + }; + + InvocationContext invocationContext = + createInvocationContext( + createTestAgentBuilder(testLlm).tools(ImmutableList.of(toolset)).build()); + BaseLlmFlow baseLlmFlow = createBaseLlmFlowWithoutProcessors(); + + List unused = baseLlmFlow.run(invocationContext).toList().blockingGet(); + + assertThat(callOrder).containsExactly("toolset", "tool").inOrder(); + } + + @Test + public void run_withToolset_toolsetAndToolsShareSameToolContext() { + Content content = Content.fromParts(Part.fromText("LLM response")); + TestLlm testLlm = createTestLlm(createLlmResponse(content)); + + AtomicReference toolsetContext = new AtomicReference<>(); + AtomicReference toolToolContext = new AtomicReference<>(); + + BaseToolset toolset = + new BaseToolset() { + @Override + public Flowable getTools(ReadonlyContext readonlyContext) { + return Flowable.just( + new BaseTool("ctx_tool", "test tool") { + @Override + public Optional declaration() { + return Optional.of(FunctionDeclaration.builder().name("ctx_tool").build()); + } + + @Override + public Completable processLlmRequest( + LlmRequest.Builder llmRequestBuilder, ToolContext toolContext) { + return Completable.fromAction(() -> toolToolContext.set(toolContext)); + } + + @Override + public Single> runAsync( + Map args, ToolContext toolContext) { + return Single.just(ImmutableMap.of()); + } + }); + } + + @Override + public Completable processLlmRequest( + LlmRequest.Builder llmRequestBuilder, ToolContext toolContext) { + return Completable.fromAction(() -> toolsetContext.set(toolContext)); + } + + @Override + public void close() {} + }; + + InvocationContext invocationContext = + createInvocationContext( + createTestAgentBuilder(testLlm).tools(ImmutableList.of(toolset)).build()); + BaseLlmFlow baseLlmFlow = createBaseLlmFlowWithoutProcessors(); + + List unused = baseLlmFlow.run(invocationContext).toList().blockingGet(); + + assertThat(toolsetContext.get()).isNotNull(); + assertThat(toolToolContext.get()).isNotNull(); + assertThat(toolsetContext.get()).isSameInstanceAs(toolToolContext.get()); + } + + @Test + public void run_multipleToolsets_processedInDeclarationOrder() { + Content content = Content.fromParts(Part.fromText("LLM response")); + TestLlm testLlm = createTestLlm(createLlmResponse(content)); + List callOrder = new ArrayList<>(); + + BaseToolset toolsetA = + new BaseToolset() { + @Override + public Flowable getTools(ReadonlyContext readonlyContext) { + return Flowable.just( + new BaseTool("tool_a", "tool a") { + @Override + public Optional declaration() { + return Optional.of(FunctionDeclaration.builder().name("tool_a").build()); + } + + @Override + public Completable processLlmRequest( + LlmRequest.Builder llmRequestBuilder, ToolContext toolContext) { + return Completable.fromAction(() -> callOrder.add("tool_a")); + } + + @Override + public Single> runAsync( + Map args, ToolContext toolContext) { + return Single.just(ImmutableMap.of()); + } + }); + } + + @Override + public Completable processLlmRequest( + LlmRequest.Builder llmRequestBuilder, ToolContext toolContext) { + return Completable.fromAction(() -> callOrder.add("toolset_a")); + } + + @Override + public void close() {} + }; + + BaseToolset toolsetB = + new BaseToolset() { + @Override + public Flowable getTools(ReadonlyContext readonlyContext) { + return Flowable.just( + new BaseTool("tool_b", "tool b") { + @Override + public Optional declaration() { + return Optional.of(FunctionDeclaration.builder().name("tool_b").build()); + } + + @Override + public Completable processLlmRequest( + LlmRequest.Builder llmRequestBuilder, ToolContext toolContext) { + return Completable.fromAction(() -> callOrder.add("tool_b")); + } + + @Override + public Single> runAsync( + Map args, ToolContext toolContext) { + return Single.just(ImmutableMap.of()); + } + }); + } + + @Override + public Completable processLlmRequest( + LlmRequest.Builder llmRequestBuilder, ToolContext toolContext) { + return Completable.fromAction(() -> callOrder.add("toolset_b")); + } + + @Override + public void close() {} + }; + + InvocationContext invocationContext = + createInvocationContext( + createTestAgentBuilder(testLlm).tools(ImmutableList.of(toolsetA, toolsetB)).build()); + BaseLlmFlow baseLlmFlow = createBaseLlmFlowWithoutProcessors(); + + List unused = baseLlmFlow.run(invocationContext).toList().blockingGet(); + + // Toolset A and its tools should be processed before toolset B and its tools + assertThat(callOrder).containsExactly("toolset_a", "tool_a", "toolset_b", "tool_b").inOrder(); + } + + @Test + public void run_mixedToolsAndToolsets_processedInDeclarationOrder() { + Content content = Content.fromParts(Part.fromText("LLM response")); + TestLlm testLlm = createTestLlm(createLlmResponse(content)); + List callOrder = new ArrayList<>(); + + BaseTool standaloneTool = + new BaseTool("standalone", "standalone tool") { + @Override + public Optional declaration() { + return Optional.of(FunctionDeclaration.builder().name("standalone").build()); + } + + @Override + public Completable processLlmRequest( + LlmRequest.Builder llmRequestBuilder, ToolContext toolContext) { + return Completable.fromAction(() -> callOrder.add("standalone_tool")); + } + + @Override + public Single> runAsync( + Map args, ToolContext toolContext) { + return Single.just(ImmutableMap.of()); + } + }; + + BaseToolset toolset = + new BaseToolset() { + @Override + public Flowable getTools(ReadonlyContext readonlyContext) { + return Flowable.just( + new BaseTool("toolset_tool", "from toolset") { + @Override + public Optional declaration() { + return Optional.of(FunctionDeclaration.builder().name("toolset_tool").build()); + } + + @Override + public Completable processLlmRequest( + LlmRequest.Builder llmRequestBuilder, ToolContext toolContext) { + return Completable.fromAction(() -> callOrder.add("toolset_tool")); + } + + @Override + public Single> runAsync( + Map args, ToolContext toolContext) { + return Single.just(ImmutableMap.of()); + } + }); + } + + @Override + public Completable processLlmRequest( + LlmRequest.Builder llmRequestBuilder, ToolContext toolContext) { + return Completable.fromAction(() -> callOrder.add("toolset")); + } + + @Override + public void close() {} + }; + + // Standalone tool declared BEFORE toolset + InvocationContext invocationContext = + createInvocationContext( + createTestAgentBuilder(testLlm) + .tools(ImmutableList.of(standaloneTool, toolset)) + .build()); + BaseLlmFlow baseLlmFlow = createBaseLlmFlowWithoutProcessors(); + + List unused = baseLlmFlow.run(invocationContext).toList().blockingGet(); + + // Standalone tool processes first, then toolset + its tool + assertThat(callOrder).containsExactly("standalone_tool", "toolset", "toolset_tool").inOrder(); + } + @Test public void run_contextPropagation() { ContextKey testKey = ContextKey.named("test-key"); diff --git a/core/src/test/java/com/google/adk/tools/BaseToolsetTest.java b/core/src/test/java/com/google/adk/tools/BaseToolsetTest.java index bbdf9dd94..fdde5c7f3 100644 --- a/core/src/test/java/com/google/adk/tools/BaseToolsetTest.java +++ b/core/src/test/java/com/google/adk/tools/BaseToolsetTest.java @@ -4,8 +4,14 @@ import static org.mockito.Mockito.mock; import com.google.adk.agents.ReadonlyContext; +import com.google.adk.models.LlmRequest; +import com.google.genai.types.Content; +import com.google.genai.types.GenerateContentConfig; +import com.google.genai.types.Part; +import io.reactivex.rxjava3.core.Completable; import io.reactivex.rxjava3.core.Flowable; import java.util.List; +import java.util.concurrent.atomic.AtomicBoolean; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -33,4 +39,73 @@ public void close() throws Exception {} List tools = toolset.getTools(mockContext).toList().blockingGet(); assertThat(tools).containsExactly(mockTool1, mockTool2); } + + @Test + public void testProcessLlmRequest_defaultIsNoOp() { + BaseToolset toolset = + new BaseToolset() { + @Override + public Flowable getTools(ReadonlyContext readonlyContext) { + return Flowable.empty(); + } + + @Override + public void close() throws Exception {} + }; + + LlmRequest.Builder builder = + LlmRequest.builder().model("test-model").config(GenerateContentConfig.builder().build()); + ToolContext toolContext = mock(ToolContext.class); + + // Default implementation should complete without error + toolset.processLlmRequest(builder, toolContext).blockingAwait(); + + // Request should be unchanged + LlmRequest request = builder.build(); + assertThat(request.model()).hasValue("test-model"); + } + + @Test + public void testProcessLlmRequest_canBeOverridden() { + AtomicBoolean called = new AtomicBoolean(false); + + BaseToolset toolset = + new BaseToolset() { + @Override + public Flowable getTools(ReadonlyContext readonlyContext) { + return Flowable.empty(); + } + + @Override + public Completable processLlmRequest( + LlmRequest.Builder llmRequestBuilder, ToolContext toolContext) { + return Completable.fromAction( + () -> { + called.set(true); + llmRequestBuilder.appendInstructions(List.of("Custom toolset instruction")); + }); + } + + @Override + public void close() throws Exception {} + }; + + LlmRequest.Builder builder = + LlmRequest.builder() + .model("test-model") + .config( + GenerateContentConfig.builder() + .systemInstruction( + Content.builder().parts(List.of(Part.fromText("original"))).build()) + .build()); + ToolContext toolContext = mock(ToolContext.class); + + toolset.processLlmRequest(builder, toolContext).blockingAwait(); + + assertThat(called.get()).isTrue(); + LlmRequest request = builder.build(); + List instructions = request.getSystemInstructions(); + assertThat(instructions.stream().anyMatch(i -> i.contains("Custom toolset instruction"))) + .isTrue(); + } }