Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,18 @@ private void logEvent(
if (state.isProcessed(invocationContext.invocationId())) {
return;
}
if (config.contentFormatter() != null && content != null) {
try {
content = config.contentFormatter().apply(content, eventType);
} catch (RuntimeException e) {

logger.log(
Level.WARNING,
"Failed to format content for invocation ID: " + invocationContext.invocationId(),
e);
content = null; // Fail-closed to avoid leaking unmasked sensitive data
}
}
String invocationId = invocationContext.invocationId();
BatchProcessor processor = state.getBatchProcessor(invocationId);
// Ensure table exists before logging.
Expand All @@ -223,10 +235,12 @@ private void logEvent(
row.put("invocation_id", invocationContext.invocationId());
row.put("user_id", invocationContext.userId());
// Parse and log content
ParsedContent parsedContent = JsonFormatter.parse(content, config.maxContentLength());
row.put("content_parts", parsedContent.parts());
row.put("content", parsedContent.content());
row.put("is_truncated", isContentTruncated || parsedContent.isTruncated());
if (content != null) {
ParsedContent parsedContent = JsonFormatter.parse(content, config.maxContentLength());
row.put("content_parts", parsedContent.parts());
row.put("content", parsedContent.content());
row.put("is_truncated", isContentTruncated || parsedContent.isTruncated());
}

EventData data = eventData.orElse(EventData.builder().build());
row.put("status", data.status());
Expand Down Expand Up @@ -301,7 +315,10 @@ private Map<String, Object> getAttributes(
}
attributes.put("session_metadata", sessionMeta);
} catch (RuntimeException e) {
// Ignore session enrichment errors as in Python.
logger.log(
Level.WARNING,
"Failed to log session metadata for invocation ID: " + invocationContext.invocationId(),
e);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,20 +77,32 @@ public abstract class BigQueryLoggerConfig {
// Max size of the batch processor queue.
public abstract int queueMaxSize();

// Optional custom formatter for content.
// TODO(b/491852782): Implement content formatter.
@Nullable
public abstract BiFunction<Object, String, Object> contentFormatter();
/**
 * Optional custom formatter for content.
 *
 * <p>Allows plugins to modify the content before logging. This is useful for masking sensitive
 * data, formatting content, etc. Returning {@code null} suppresses content logging for that
 * event (fail-closed).
 *
 * <p>The contentFormatter must be <b>thread-safe</b> as it may be called concurrently across
 * different agent invocations and <b>fast/non-blocking</b> to avoid adding latency to the agent's
 * event processing pipeline.
 *
 * <p><b>Important:</b> To avoid corruption of the logs, the incoming content object should
 * <b>not</b> be mutated. Modifying code should return a <b>new copy</b> of the object with
 * desired changes.
 */
public abstract @Nullable BiFunction<Object, String, Object> contentFormatter();

// GCS bucket name to store multi-modal content.
public abstract String gcsBucketName();

// TODO(b/491852782): Implement connection id.
public abstract Optional<String> connectionId();

// Toggle for session metadata (e.g. gchat thread-id).
// TODO(b/491852782): Implement logging of session metadata.
public abstract boolean logSessionMetadata();

// Static custom tags (e.g. {"agent_role": "sales"}).
// TODO(b/491852782): Implement custom tags.
public abstract ImmutableMap<String, Object> customTags();

// Automatically add new columns to existing tables when the plugin
Expand Down Expand Up @@ -120,6 +132,7 @@ public static Builder builder() {
.tableName("events")
.clusteringFields(ImmutableList.of("event_type", "agent", "user_id"))
.logMultiModalContent(true)
.gcsBucketName("")
.retryConfig(RetryConfig.builder().build())
.batchSize(1)
.batchFlushInterval(Duration.ofSeconds(1))
Expand Down Expand Up @@ -205,6 +218,9 @@ public abstract Builder contentFormatter(
@CanIgnoreReturnValue
public abstract Builder viewPrefix(String viewPrefix);

@CanIgnoreReturnValue
public abstract Builder gcsBucketName(String gcsBucketName);

@CanIgnoreReturnValue
public abstract Builder credentials(Credentials credentials);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,8 @@ public void runAgent_logsAgentStartingAndCompleted() throws Exception {
assertEquals("user", agentStartingRow.get("user_id"));
assertNotNull("invocation_id should be populated", agentStartingRow.get("invocation_id"));
assertTrue("timestamp should be positive", (Long) agentStartingRow.get("timestamp") > 0);
assertEquals(false, agentStartingRow.get("is_truncated"));
// AGENT_STARTING is not a content-bearing event, so is_truncated is not set and should be null.
assertEquals(null, agentStartingRow.get("is_truncated"));

// Verify content for USER_MESSAGE_RECEIVED
Map<String, Object> userMessageRow =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.function.BiFunction;
import java.util.logging.Handler;
import java.util.logging.Level;
import java.util.logging.LogRecord;
Expand Down Expand Up @@ -493,7 +494,12 @@ public void onModelErrorCallback_populatesCorrectFields() throws Exception {
assertEquals("ERROR", row.get("status"));
assertEquals("model error message", row.get("error_message"));
assertNotNull(row.get("latency_ms"));
assertEquals(false, row.get("is_truncated"));
assertFalse("Row should not contain content when it is null", row.containsKey("content"));
assertFalse(
"Row should not contain content_parts when it is null", row.containsKey("content_parts"));
assertFalse(
"Row should not contain is_truncated when content is null",
row.containsKey("is_truncated"));
}

@Test
Expand Down Expand Up @@ -649,6 +655,108 @@ protected StreamWriter createWriter() {
"attributes should not contain session_metadata", attributes.has("session_metadata"));
}

@Test
public void logEvent_usesContentFormatter_whenConfigured() throws Exception {
  // Formatter that rewrites user-message Content and passes all other content through untouched.
  BiFunction<Object, String, Object> prefixingFormatter =
      (incoming, type) -> {
        boolean isUserMessageContent =
            Objects.equals(type, "USER_MESSAGE_RECEIVED") && incoming instanceof Content;
        return isUserMessageContent ? "Formatted: " + incoming : incoming;
      };

  BigQueryLoggerConfig cfgWithFormatter =
      config.toBuilder().contentFormatter(prefixingFormatter).build();
  // Plugin state wired to the mocked BigQuery write path so nothing leaves the test.
  PluginState stateWithFormatter =
      new PluginState(cfgWithFormatter) {
        @Override
        protected BigQueryWriteClient createWriteClient(BigQueryLoggerConfig config) {
          return mockWriteClient;
        }

        @Override
        protected StreamWriter createWriter() {
          return mockWriter;
        }
      };
  BigQueryAgentAnalyticsPlugin pluginWithFormatter =
      new BigQueryAgentAnalyticsPlugin(cfgWithFormatter, mockBigQuery, stateWithFormatter);

  Content userContent = Content.fromParts(Part.fromText("test message"));
  pluginWithFormatter.onUserMessageCallback(mockInvocationContext, userContent).blockingSubscribe();

  // The queued row should carry the formatter's output, not the raw content.
  Map<String, Object> loggedRow =
      stateWithFormatter.getBatchProcessor("invocation_id").queue.poll();
  assertNotNull(loggedRow);
  assertTrue(loggedRow.get("content").toString().contains("Formatted: "));
}

@Test
public void logEvent_handlesNullContentFromFormatter() throws Exception {
  // A formatter that always suppresses content by returning null.
  BiFunction<Object, String, Object> nullingFormatter = (incoming, type) -> null;

  BigQueryLoggerConfig cfgWithFormatter =
      config.toBuilder().contentFormatter(nullingFormatter).build();
  // Plugin state wired to the mocked BigQuery write path so nothing leaves the test.
  PluginState stateWithFormatter =
      new PluginState(cfgWithFormatter) {
        @Override
        protected BigQueryWriteClient createWriteClient(BigQueryLoggerConfig config) {
          return mockWriteClient;
        }

        @Override
        protected StreamWriter createWriter() {
          return mockWriter;
        }
      };
  BigQueryAgentAnalyticsPlugin pluginWithFormatter =
      new BigQueryAgentAnalyticsPlugin(cfgWithFormatter, mockBigQuery, stateWithFormatter);

  Content userContent = Content.fromParts(Part.fromText("test message"));
  pluginWithFormatter.onUserMessageCallback(mockInvocationContext, userContent).blockingSubscribe();

  // The row is still queued, but all content-derived columns must be absent.
  Map<String, Object> loggedRow =
      stateWithFormatter.getBatchProcessor("invocation_id").queue.poll();
  assertNotNull(loggedRow);
  assertFalse(
      "Row should not contain content when formatter returns null",
      loggedRow.containsKey("content"));
  assertFalse(
      "Row should not contain content_parts when formatter returns null",
      loggedRow.containsKey("content_parts"));
}

@Test
public void logEvent_handlesExceptionFromFormatter() throws Exception {
  // A formatter that always fails; the plugin should fail closed rather than log raw content.
  BiFunction<Object, String, Object> throwingFormatter =
      (incoming, type) -> {
        throw new RuntimeException("Formatter error");
      };

  BigQueryLoggerConfig cfgWithFormatter =
      config.toBuilder().contentFormatter(throwingFormatter).build();
  // Plugin state wired to the mocked BigQuery write path so nothing leaves the test.
  PluginState stateWithFormatter =
      new PluginState(cfgWithFormatter) {
        @Override
        protected BigQueryWriteClient createWriteClient(BigQueryLoggerConfig config) {
          return mockWriteClient;
        }

        @Override
        protected StreamWriter createWriter() {
          return mockWriter;
        }
      };
  BigQueryAgentAnalyticsPlugin pluginWithFormatter =
      new BigQueryAgentAnalyticsPlugin(cfgWithFormatter, mockBigQuery, stateWithFormatter);

  Content userContent = Content.fromParts(Part.fromText("test message"));
  pluginWithFormatter.onUserMessageCallback(mockInvocationContext, userContent).blockingSubscribe();

  // The row is still queued, but content-derived columns must be dropped after the failure.
  Map<String, Object> loggedRow =
      stateWithFormatter.getBatchProcessor("invocation_id").queue.poll();
  assertNotNull(loggedRow);
  assertFalse(
      "Row should not contain content when formatter throws exception",
      loggedRow.containsKey("content"));
  assertFalse(
      "Row should not contain content_parts when formatter throws exception",
      loggedRow.containsKey("content_parts"));
}

@Test
public void maybeUpgradeSchema_addsNewTopLevelField() throws Exception {
Table mockTable = mock(Table.class);
Expand Down