diff --git a/core/src/main/java/com/google/adk/plugins/agentanalytics/BigQueryAgentAnalyticsPlugin.java b/core/src/main/java/com/google/adk/plugins/agentanalytics/BigQueryAgentAnalyticsPlugin.java index 22c772e31..5f8222e70 100644 --- a/core/src/main/java/com/google/adk/plugins/agentanalytics/BigQueryAgentAnalyticsPlugin.java +++ b/core/src/main/java/com/google/adk/plugins/agentanalytics/BigQueryAgentAnalyticsPlugin.java @@ -210,6 +210,18 @@ private void logEvent( if (state.isProcessed(invocationContext.invocationId())) { return; } + if (config.contentFormatter() != null && content != null) { + try { + content = config.contentFormatter().apply(content, eventType); + } catch (RuntimeException e) { + + logger.log( + Level.WARNING, + "Failed to format content for invocation ID: " + invocationContext.invocationId(), + e); + content = null; // Fail-closed to avoid leaking unmasked sensitive data + } + } String invocationId = invocationContext.invocationId(); BatchProcessor processor = state.getBatchProcessor(invocationId); // Ensure table exists before logging. @@ -223,10 +235,12 @@ private void logEvent( row.put("invocation_id", invocationContext.invocationId()); row.put("user_id", invocationContext.userId()); // Parse and log content - ParsedContent parsedContent = JsonFormatter.parse(content, config.maxContentLength()); - row.put("content_parts", parsedContent.parts()); - row.put("content", parsedContent.content()); - row.put("is_truncated", isContentTruncated || parsedContent.isTruncated()); + if (content != null) { + ParsedContent parsedContent = JsonFormatter.parse(content, config.maxContentLength()); + row.put("content_parts", parsedContent.parts()); + row.put("content", parsedContent.content()); + row.put("is_truncated", isContentTruncated || parsedContent.isTruncated()); + } EventData data = eventData.orElse(EventData.builder().build()); row.put("status", data.status()); @@ -301,7 +315,10 @@ private Map getAttributes( } attributes.put("session_metadata", sessionMeta); } catch (RuntimeException e) { - // Ignore session enrichment errors as in Python. + logger.log( + Level.WARNING, + "Failed to log session metadata for invocation ID: " + invocationContext.invocationId(), + e); } } diff --git a/core/src/main/java/com/google/adk/plugins/agentanalytics/BigQueryLoggerConfig.java b/core/src/main/java/com/google/adk/plugins/agentanalytics/BigQueryLoggerConfig.java index ccce8c3bc..b35e7c51d 100644 --- a/core/src/main/java/com/google/adk/plugins/agentanalytics/BigQueryLoggerConfig.java +++ b/core/src/main/java/com/google/adk/plugins/agentanalytics/BigQueryLoggerConfig.java @@ -77,20 +77,32 @@ public abstract class BigQueryLoggerConfig { // Max size of the batch processor queue. public abstract int queueMaxSize(); - // Optional custom formatter for content. - // TODO(b/491852782): Implement content formatter. - @Nullable - public abstract BiFunction contentFormatter(); + /** + * Optional custom formatter for content. + * + *

Allow plugins to modify the content before logging. This is useful for masking sensitive + * data, formatting content, etc. + * + *

The contentFormatter must be thread-safe as it may be called concurrently across + * different agent invocations and fast/non-blocking to avoid adding latency to the agent's + * event processing pipeline. + * + *

Important: To avoid corruption of the logs, the incoming content object should + * not be mutated. Modifying code should return a new copy of the object with + * desired changes. + */ + public abstract @Nullable BiFunction contentFormatter(); + + // GCS bucket name to store multi-modal content. + public abstract String gcsBucketName(); // TODO(b/491852782): Implement connection id. public abstract Optional connectionId(); // Toggle for session metadata (e.g. gchat thread-id). - // TODO(b/491852782): Implement logging of session metadata. public abstract boolean logSessionMetadata(); // Static custom tags (e.g. {"agent_role": "sales"}). - // TODO(b/491852782): Implement custom tags. public abstract ImmutableMap customTags(); // Automatically add new columns to existing tables when the plugin @@ -120,6 +132,7 @@ public static Builder builder() { .tableName("events") .clusteringFields(ImmutableList.of("event_type", "agent", "user_id")) .logMultiModalContent(true) + .gcsBucketName("") .retryConfig(RetryConfig.builder().build()) .batchSize(1) .batchFlushInterval(Duration.ofSeconds(1)) @@ -205,6 +218,9 @@ public abstract Builder contentFormatter( @CanIgnoreReturnValue public abstract Builder viewPrefix(String viewPrefix); + @CanIgnoreReturnValue + public abstract Builder gcsBucketName(String gcsBucketName); + @CanIgnoreReturnValue public abstract Builder credentials(Credentials credentials); diff --git a/core/src/test/java/com/google/adk/plugins/agentanalytics/BigQueryAgentAnalyticsPluginE2ETest.java b/core/src/test/java/com/google/adk/plugins/agentanalytics/BigQueryAgentAnalyticsPluginE2ETest.java index ef721e432..7e6107e9d 100644 --- a/core/src/test/java/com/google/adk/plugins/agentanalytics/BigQueryAgentAnalyticsPluginE2ETest.java +++ b/core/src/test/java/com/google/adk/plugins/agentanalytics/BigQueryAgentAnalyticsPluginE2ETest.java @@ -200,7 +200,8 @@ public void runAgent_logsAgentStartingAndCompleted() throws Exception { assertEquals("user", agentStartingRow.get("user_id")); assertNotNull("invocation_id should be populated", agentStartingRow.get("invocation_id")); assertTrue("timestamp should be positive", (Long) agentStartingRow.get("timestamp") > 0); - assertEquals(false, agentStartingRow.get("is_truncated")); + // AGENT_STARTING is not a content-bearing event, so is_truncated is not set and should be null. + assertEquals(null, agentStartingRow.get("is_truncated")); // Verify content for USER_MESSAGE_RECEIVED Map userMessageRow = diff --git a/core/src/test/java/com/google/adk/plugins/agentanalytics/BigQueryAgentAnalyticsPluginTest.java b/core/src/test/java/com/google/adk/plugins/agentanalytics/BigQueryAgentAnalyticsPluginTest.java index fed1d81f1..5a149d3e2 100644 --- a/core/src/test/java/com/google/adk/plugins/agentanalytics/BigQueryAgentAnalyticsPluginTest.java +++ b/core/src/test/java/com/google/adk/plugins/agentanalytics/BigQueryAgentAnalyticsPluginTest.java @@ -79,6 +79,7 @@ import java.util.concurrent.CountDownLatch; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; +import java.util.function.BiFunction; import java.util.logging.Handler; import java.util.logging.Level; import java.util.logging.LogRecord; @@ -493,7 +494,12 @@ public void onModelErrorCallback_populatesCorrectFields() throws Exception { assertEquals("ERROR", row.get("status")); assertEquals("model error message", row.get("error_message")); assertNotNull(row.get("latency_ms")); - assertEquals(false, row.get("is_truncated")); + assertFalse("Row should not contain content when it is null", row.containsKey("content")); + assertFalse( + "Row should not contain content_parts when it is null", row.containsKey("content_parts")); + assertFalse( + "Row should not contain is_truncated when content is null", + row.containsKey("is_truncated")); } @Test @@ -649,6 +655,108 @@ protected StreamWriter createWriter() { "attributes should not contain session_metadata", attributes.has("session_metadata")); } + @Test + public void logEvent_usesContentFormatter_whenConfigured() throws Exception { + BiFunction formatter = + (content, eventType) -> { + if (Objects.equals(eventType, "USER_MESSAGE_RECEIVED") && content instanceof Content) { + return "Formatted: " + content; + } + return content; + }; + + BigQueryLoggerConfig formattedConfig = config.toBuilder().contentFormatter(formatter).build(); + PluginState formattedState = + new PluginState(formattedConfig) { + @Override + protected BigQueryWriteClient createWriteClient(BigQueryLoggerConfig config) { + return mockWriteClient; + } + + @Override + protected StreamWriter createWriter() { + return mockWriter; + } + }; + BigQueryAgentAnalyticsPlugin formattedPlugin = + new BigQueryAgentAnalyticsPlugin(formattedConfig, mockBigQuery, formattedState); + + Content content = Content.fromParts(Part.fromText("test message")); + formattedPlugin.onUserMessageCallback(mockInvocationContext, content).blockingSubscribe(); + + Map row = formattedState.getBatchProcessor("invocation_id").queue.poll(); + assertNotNull(row); + assertTrue(row.get("content").toString().contains("Formatted: ")); + } + + @Test + public void logEvent_handlesNullContentFromFormatter() throws Exception { + BiFunction formatter = (content, eventType) -> null; + + BigQueryLoggerConfig formattedConfig = config.toBuilder().contentFormatter(formatter).build(); + PluginState formattedState = + new PluginState(formattedConfig) { + @Override + protected BigQueryWriteClient createWriteClient(BigQueryLoggerConfig config) { + return mockWriteClient; + } + + @Override + protected StreamWriter createWriter() { + return mockWriter; + } + }; + BigQueryAgentAnalyticsPlugin formattedPlugin = + new BigQueryAgentAnalyticsPlugin(formattedConfig, mockBigQuery, formattedState); + + Content content = Content.fromParts(Part.fromText("test message")); + formattedPlugin.onUserMessageCallback(mockInvocationContext, content).blockingSubscribe(); + + Map row = formattedState.getBatchProcessor("invocation_id").queue.poll(); + assertNotNull(row); + assertFalse( + "Row should not contain content when formatter returns null", row.containsKey("content")); + assertFalse( + "Row should not contain content_parts when formatter returns null", + row.containsKey("content_parts")); + } + + @Test + public void logEvent_handlesExceptionFromFormatter() throws Exception { + BiFunction formatter = + (content, eventType) -> { + throw new RuntimeException("Formatter error"); + }; + + BigQueryLoggerConfig formattedConfig = config.toBuilder().contentFormatter(formatter).build(); + PluginState formattedState = + new PluginState(formattedConfig) { + @Override + protected BigQueryWriteClient createWriteClient(BigQueryLoggerConfig config) { + return mockWriteClient; + } + + @Override + protected StreamWriter createWriter() { + return mockWriter; + } + }; + BigQueryAgentAnalyticsPlugin formattedPlugin = + new BigQueryAgentAnalyticsPlugin(formattedConfig, mockBigQuery, formattedState); + + Content content = Content.fromParts(Part.fromText("test message")); + formattedPlugin.onUserMessageCallback(mockInvocationContext, content).blockingSubscribe(); + + Map row = formattedState.getBatchProcessor("invocation_id").queue.poll(); + assertNotNull(row); + assertFalse( + "Row should not contain content when formatter throws exception", + row.containsKey("content")); + assertFalse( + "Row should not contain content_parts when formatter throws exception", + row.containsKey("content_parts")); + } + @Test public void maybeUpgradeSchema_addsNewTopLevelField() throws Exception { Table mockTable = mock(Table.class);