Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 13 additions & 1 deletion packages/toolbox-adk/integration.cloudbuild.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,14 @@
# limitations under the License.

steps:
- id: Build toolbox
name: 'golang:1.26'
entrypoint: /bin/bash
args:
- '-c'
- |
cd packages/toolbox-adk/genai-toolbox
go build -o ../toolbox main.go
- id: Install requirements
name: 'python:${_VERSION}'
dir: 'packages/toolbox-adk'
Expand All @@ -39,16 +47,20 @@ steps:
- TOOLBOX_VERSION=$_TOOLBOX_VERSION
- GOOGLE_CLOUD_PROJECT=$PROJECT_ID
- TOOLBOX_MANIFEST_VERSION=${_TOOLBOX_MANIFEST_VERSION}
- TEST_MOCK_GCP=false
args:
- '-c'
- |
chmod +x toolbox
source /workspace/venv/bin/activate
python -m pytest --cov=src/toolbox_adk --cov-report=term --cov-fail-under=90 tests/
python -m pytest -s --cov=src/toolbox_adk --cov-report=term --cov-fail-under=90 tests/
entrypoint: /bin/bash
options:
machineType: 'E2_HIGHCPU_8'
logging: CLOUD_LOGGING_ONLY
substitutions:
_VERSION: '3.13'
# Default values (can be overridden by triggers)
_TOOLBOX_VERSION: '1.1.0'
_TOOLBOX_MANIFEST_VERSION: '34'
_TOOLBOX_URL: 'http://localhost:5000'
66 changes: 22 additions & 44 deletions packages/toolbox-adk/src/toolbox_adk/tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

import inspect
import logging
from typing import Any, Awaitable, Callable, Dict, Mapping, Optional
from typing import Any, Dict, Mapping, Optional

import toolbox_core
from fastapi.openapi.models import (
Expand All @@ -31,6 +31,7 @@
from google.adk.tools.base_tool import BaseTool
from google.adk.tools.tool_context import ToolContext
from google.genai.types import FunctionDeclaration, Schema, Type
from pydantic import ValidationError
from toolbox_core.protocol import AdditionalPropertiesSchema, ParameterSchema
from toolbox_core.tool import ToolboxTool as CoreToolboxTool
from typing_extensions import override
Expand All @@ -40,25 +41,19 @@


class ToolboxTool(BaseTool):
"""
A tool that delegates to a remote Toolbox tool, integrated with ADK.
"""
"""A tool that delegates to a remote Toolbox tool, integrated with ADK."""

def __init__(
self,
core_tool: CoreToolboxTool,
auth_config: Optional[CredentialConfig] = None,
adk_token_getters: Optional[Mapping[str, Any]] = None,
):
"""Args:
core_tool: The underlying toolbox_core.py tool instance.
auth_config: Credential configuration to handle interactive flows.
adk_token_getters: Tool-specific auth token getters.
"""
Args:
core_tool: The underlying toolbox_core.py tool instance.
auth_config: Credential configuration to handle interactive flows.
adk_token_getters: Tool-specific auth token getters.
"""
# We act as a proxy.
# We need to extract metadata from the core tool to satisfy BaseTool's contract.

name = getattr(core_tool, "__name__", None)
if not name:
raise ValueError(f"Core tool {core_tool} must have a valid __name__")
Expand All @@ -72,7 +67,6 @@ def __init__(
super().__init__(
name=name,
description=description,
# Pass empty custom_metadata as it is not currently used
custom_metadata={},
)
self._core_tool = core_tool
Expand All @@ -95,11 +89,9 @@ def _build_schema(self, param: Any) -> Schema:
"""Builds a Schema from a parameter."""
param_type = getattr(param, "type", "string")
schema_type = self._param_type_to_schema_type(param_type)

properties = {}
required = []
schema_items = None
schema_additional_properties = None

if schema_type == Type.ARRAY:
if hasattr(param, "items") and param.items:
Expand All @@ -111,6 +103,7 @@ def _build_schema(self, param: Any) -> Schema:
properties[k] = self._build_schema(v)
if getattr(v, "required", False):
required.append(k)

return Schema(
type=schema_type,
description=getattr(param, "description", "") or "",
Expand All @@ -124,10 +117,6 @@ def _get_declaration(self) -> Optional[FunctionDeclaration]:
"""Gets the function declaration for the tool."""
properties = {}
required = []

# We do not use `google.genai.types.FunctionDeclaration.from_callable`
# here because it explicitly drops argument descriptions from the schema
# properties, lumping them all into the root description instead.
if hasattr(self._core_tool, "_params") and self._core_tool._params:
for param in self._core_tool._params:
properties[param.name] = self._build_schema(param)
Expand All @@ -143,7 +132,6 @@ def _get_declaration(self) -> Optional[FunctionDeclaration]:
if properties
else None
)

return FunctionDeclaration(
name=self.name, description=self.description, parameters=parameters
)
Expand All @@ -154,7 +142,6 @@ async def run_async(
args: Dict[str, Any],
tool_context: ToolContext,
) -> Any:
# Check if USER_IDENTITY is configured
reset_token = None

if self._auth_config and self._auth_config.type == CredentialType.USER_IDENTITY:
Expand All @@ -172,7 +159,6 @@ async def run_async(
"USER_IDENTITY requires client_id and client_secret"
)

# Construct ADK AuthConfig
scopes = self._auth_config.scopes or ["openid", "profile", "email"]
scope_dict = {s: "" for s in scopes}

Expand All @@ -195,9 +181,7 @@ async def run_async(
),
)

# Check if we already have credentials from a previous exchange
try:
# Try to load credential from credential service first (persists across sessions)
creds = None
try:
if tool_context._invocation_context.credential_service:
Expand All @@ -206,19 +190,16 @@ async def run_async(
callback_context=tool_context,
)
except ValueError:
# Credential service might not be initialized
pass

if not creds:
# Fallback to session state (get_auth_response returns AuthCredential if found)
creds = tool_context.get_auth_response(auth_config_adk)

if creds and creds.oauth2 and creds.oauth2.access_token:
reset_token = USER_TOKEN_CONTEXT_VAR.set(
creds.oauth2.access_token
)

# Bind the token to the underlying core_tool so it constructs headers properly
needed_services = set()
for requested_service in list(
self._core_tool._required_authn_params.values()
Expand All @@ -229,7 +210,6 @@ async def run_async(
needed_services.add(requested_service)

for s in needed_services:
# Only add if not already registered (prevents ValueError on duplicate params or subsequent runs)
if (
not hasattr(self._core_tool, "_auth_token_getters")
or s not in self._core_tool._auth_token_getters
Expand All @@ -238,7 +218,6 @@ async def run_async(
s,
lambda t=creds.oauth2.id_token or creds.oauth2.access_token: t,
)
# Once we use it from get_auth_response, save it to the auth service for future use
Comment thread
Deeven-Seru marked this conversation as resolved.
try:
if tool_context._invocation_context.credential_service:
auth_config_adk.exchanged_auth_credential = creds
Expand All @@ -256,7 +235,6 @@ async def run_async(
except Exception as e:
if "credential" in str(e).lower() or isinstance(e, ValueError):
raise e

logging.warning(
f"Unexpected error in get_auth_response during User Identity (OAuth2) retrieval: {e}. "
"Falling back to request_credential.",
Expand All @@ -272,41 +250,41 @@ async def run_async(
# This deferred loop also enables dynamic 1-arity `tool_context` injection.
needed_services = set()
for reqs in self._core_tool._required_authn_params.values():
needed_services.update(reqs)
if isinstance(reqs, list):
needed_services.update(reqs)
else:
needed_services.add(reqs)
needed_services.update(self._core_tool._required_authz_tokens)

for service, getter in self._adk_token_getters.items():
if service in needed_services:
sig = inspect.signature(getter)

if len(sig.parameters) == 1:
bound_getter = lambda t=getter, ctx=tool_context: t(ctx)
else:
bound_getter = getter

self._core_tool = self._core_tool.add_auth_token_getter(
service, bound_getter
)

result: Optional[Any] = None
error: Optional[Exception] = None

try:
# Execute the core tool
result = await self._core_tool(**args)
return result

except Exception as e:
error = e
return await self._core_tool(**args)
except (TypeError, PermissionError) as e:
# Propagate system-level errors
raise e
except (ValueError, ValidationError, Exception) as e:
# Catch tool execution and validation errors and return as a structured dictionary
# This handles cases like invalid parameters, tool crashes, or server errors
logging.warning(
"Toolbox tool '%s' execution failed: %s", self.name, e, exc_info=True
)
return {"error": f"{type(e).__name__}: {e}", "is_error": True}
finally:
if reset_token:
USER_TOKEN_CONTEXT_VAR.reset(reset_token)

def bind_params(self, bounded_params: Dict[str, Any]) -> "ToolboxTool":
"""Allows runtime binding of parameters, delegating to core tool."""
new_core_tool = self._core_tool.bind_params(bounded_params)
# Return a new wrapper
return ToolboxTool(
core_tool=new_core_tool,
auth_config=self._auth_config,
Expand Down
22 changes: 14 additions & 8 deletions packages/toolbox-adk/tests/integration/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,11 @@ def toolbox_version() -> str:

@pytest_asyncio.fixture(scope="session")
def tools_file_path(project_id: str) -> Generator[str]:
"""Provides a temporary file path containing the tools manifest."""
if os.path.exists("tools.yaml"):
print("Using local tools.yaml at root")
yield os.path.abspath("tools.yaml")
return

if os.environ.get("TEST_MOCK_GCP"):
content = "tools: []" # Dummy manifest
path = create_tmpfile(content)
Expand Down Expand Up @@ -144,16 +148,18 @@ def auth_token2(project_id: str) -> str:

@pytest_asyncio.fixture(scope="session")
def toolbox_server(toolbox_version: str, tools_file_path: str) -> Generator[None]:
"""Starts the toolbox server as a subprocess."""
if os.environ.get("TEST_MOCK_GCP"):
# Still allow mocked runs if no binary is found, but if it exists, let's use it
if os.environ.get("TEST_MOCK_GCP") and not os.path.exists("./toolbox"):
yield
return

print("Downloading toolbox binary from gcs bucket...")
source_blob_name = get_toolbox_binary_url(toolbox_version)
download_blob("mcp-toolbox-for-databases", source_blob_name, "toolbox")

print("Toolbox binary downloaded successfully.")
if os.path.exists("./toolbox"):
print("Using existing toolbox binary.")
else:
print("Downloading toolbox binary from gcs bucket...")
source_blob_name = get_toolbox_binary_url(toolbox_version)
download_blob("mcp-toolbox-for-databases", source_blob_name, "toolbox")
print("Toolbox binary downloaded successfully.")
try:
print("Opening toolbox server process...")
# Make toolbox executable
Expand Down
46 changes: 46 additions & 0 deletions packages/toolbox-adk/tests/unit/test_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

import pytest
from google.genai.types import Type
from pydantic import ValidationError

from toolbox_adk.credentials import CredentialConfig, CredentialType
from toolbox_adk.tool import ToolboxTool
Expand Down Expand Up @@ -56,6 +57,51 @@ async def test_auth_check_no_token(self):
# Should proceed to execute (auth not forced)
mock_core.assert_awaited()

@pytest.mark.asyncio
async def test_run_async_returns_error_on_exception(self):
mock_core = AsyncMock(side_effect=RuntimeError("boom"))
mock_core.__name__ = "my_tool"
mock_core.__doc__ = "my description"

tool = ToolboxTool(mock_core)
ctx = MagicMock()

result = await tool.run_async({"arg": 1}, ctx)

assert isinstance(result, dict) and "error" in result
assert result.get("is_error") is True
assert "RuntimeError" in result["error"]

@pytest.mark.asyncio
async def test_run_async_returns_error_on_validation_error(self):
# Setup mock to raise a ValidationError
mock_core = AsyncMock()
mock_core.__name__ = "my_tool"
mock_core.__doc__ = "my description"

# We simulate a ValidationError by raising it from the mock
mock_core.side_effect = ValidationError.from_exception_data(
"MyModel",
[
{
"type": "missing",
"loc": ("param",),
"msg": "field required",
"input": {},
}
],
)

tool = ToolboxTool(mock_core)
ctx = MagicMock()

result = await tool.run_async({"arg": 1}, ctx)

assert isinstance(result, dict) and "error" in result
assert result.get("is_error") is True
assert "ValidationError" in result["error"]
assert "Field required" in result["error"]

@pytest.mark.asyncio
async def test_bind_params(self):
mock_core = MagicMock()
Expand Down
Loading