diff --git a/src/cuga/backend/cuga_graph/nodes/cuga_lite/bind_tools/__init__.py b/src/cuga/backend/cuga_graph/nodes/cuga_lite/bind_tools/__init__.py new file mode 100644 index 00000000..036cf0e6 --- /dev/null +++ b/src/cuga/backend/cuga_graph/nodes/cuga_lite/bind_tools/__init__.py @@ -0,0 +1,17 @@ +"""bind_tools cap + shortlist machinery for cuga_lite. + +Keeps cuga_lite_graph.py focused on orchestration. See :mod:`.cap` for the +provider-safe cap and shortlister flow. +""" + +from cuga.backend.cuga_graph.nodes.cuga_lite.bind_tools.cap import ( + apply_bind_tools_cap_and_merge, + bind_tools_max_count_from_settings, + bind_tools_pad_to_cap_from_settings, +) + +__all__ = [ + "apply_bind_tools_cap_and_merge", + "bind_tools_max_count_from_settings", + "bind_tools_pad_to_cap_from_settings", +] diff --git a/src/cuga/backend/cuga_graph/nodes/cuga_lite/bind_tools/cap.py b/src/cuga/backend/cuga_graph/nodes/cuga_lite/bind_tools/cap.py new file mode 100644 index 00000000..0365dc6b --- /dev/null +++ b/src/cuga/backend/cuga_graph/nodes/cuga_lite/bind_tools/cap.py @@ -0,0 +1,350 @@ +"""Provider-safe cap and LLM shortlisting for ``bind_tools`` candidate lists. + +Strict providers (Groq, OpenAI) reject ``bind_tools`` calls with more than ~128 +tools per request. This module reads the cap from settings, and — when the +candidate list exceeds it — defers to the same LLM shortlister that runtime +tool-discovery uses (:meth:`PromptUtils.shortlist_tool_names`) to pick the +top-K most relevant tools for the user query. + +Note: when the cap is exceeded, applying it incurs a single shortlister LLM +round-trip per ``call_model`` invocation. This is intentional — silent +truncation would corrupt benchmark results comparing native tool-calling vs +text-mode. Permissive backends (WatsonX, Anthropic via LiteLLM) can disable +the cap with ``cuga_lite_bind_tools_max_count=0``. 
def bind_tools_max_count_from_settings() -> int:
    """Return the provider-safe cap on the number of tools passed to ``LLM.bind_tools``.

    Reads ``settings.advanced_features.cuga_lite_bind_tools_max_count``. The default
    of 128 matches the strictest common provider limit (Groq, OpenAI). Set
    ``DYNACONF_ADVANCED_FEATURES__CUGA_LITE_BIND_TOOLS_MAX_COUNT=0`` (or negative)
    to disable the cap entirely — useful for permissive backends like WatsonX or
    LiteLLM routing to Anthropic.

    Returns:
        The configured value converted with ``int()``. Falls back to 128 when the
        setting cannot be read or cannot be converted to an integer.
    """
    default_cap = 128
    try:
        raw = getattr(settings.advanced_features, "cuga_lite_bind_tools_max_count", default_cap)
    except Exception:
        # Deliberate best-effort: any failure while reading settings must not
        # break model resolution, so fall back to the default cap.
        return default_cap
    try:
        return int(raw)
    except (TypeError, ValueError, OverflowError):
        # OverflowError covers int(float("inf")), which TypeError/ValueError miss.
        # Warn instead of degrading silently so a mistyped cap is visible in logs.
        logger.warning(
            "cuga_lite_bind_tools_max_count={!r} is not a valid integer; using default {}",
            raw,
            default_cap,
        )
        return default_cap
def _resolve_find_tools_overlay(
    bound: List[StructuredTool],
    *,
    include_find_tools: bool,
    tools_context_ref: Optional[Dict[str, Any]],
) -> Tuple[Optional[StructuredTool], str, bool, List[StructuredTool]]:
    """Look up the ``find_tools`` overlay tool and reconcile it with ``bound``.

    Returns ``(find_tools_tool, find_tools_name, find_tools_already_in_bound, bound)``.

    The overlay tool is read from ``tools_context_ref["_lc_bind_tools_find_tools"]``.
    Its presence in ``bound`` is detected by name regardless of
    ``include_find_tools``. When the caller opted out but the tool is present,
    a filtered copy of ``bound`` without it is returned and the presence flag
    is reported as ``False``; otherwise ``bound`` is returned unchanged.
    """
    overlay_tool: Optional[StructuredTool] = (
        tools_context_ref.get("_lc_bind_tools_find_tools") if tools_context_ref else None
    )
    overlay_name = getattr(overlay_tool, "name", "") or ""

    def _is_overlay(tool: Any) -> bool:
        return getattr(tool, "name", "") == overlay_name

    # An overlay without a usable name can never be matched against ``bound``.
    present = False
    if overlay_name:
        present = any(_is_overlay(tool) for tool in bound)

    if present and not include_find_tools:
        # Explicit opt-out: drop the injected tool so it cannot take a capped
        # slot or reach the shortlister input.
        remaining = [tool for tool in bound if not _is_overlay(tool)]
        return overlay_tool, overlay_name, False, remaining
    return overlay_tool, overlay_name, present, bound
+ + When ``include_find_tools=True`` the LLM ranker is free to drop any tool from + the ranking pool — pulling find_tools out and appending it back is the only + safe way to guarantee it. + """ + if keep_find_tools and find_tools_already_in_bound: + return [t for t in bound if getattr(t, "name", "") != find_tools_name] + return bound + + +async def _run_shortlister( + query_text: str, + *, + ranking_pool: List[StructuredTool], + tool_provider: Optional[ToolProviderInterface], + llm: Optional[BaseChatModel], + top_k: int, + mode: str, + max_count: int, +) -> List[str]: + """Run :meth:`PromptUtils.shortlist_tool_names` and validate the result. + + Raises ``RuntimeError`` on shortlister failure or empty ranking — silent + truncation would corrupt benchmark results comparing native vs text mode. + """ + all_apps: List[Any] = [] + if tool_provider is not None: + try: + all_apps = await tool_provider.get_apps() + except Exception as e: + logger.warning("bind_tools cap: tool_provider.get_apps() failed: {}", e) + + logger.info( + "bind_tools cap exceeded: mode={} candidates={} cap={} → LLM shortlister to top {}", + mode, + len(ranking_pool), + max_count, + top_k, + ) + try: + ranked_names = await PromptUtils.shortlist_tool_names( + query=query_text, + all_tools=ranking_pool, + all_apps=all_apps, + llm=llm, + top_k=top_k, + ) + except Exception as e: + raise RuntimeError( + f"cuga_lite_bind_tools shortlister failed reducing {len(ranking_pool)} tools to " + f"top {top_k} (cap={max_count}): {e!r}. Raise the cap or fix the shortlister LLM." + ) from e + + if not ranked_names: + raise RuntimeError( + f"cuga_lite_bind_tools shortlister returned 0 tools for {len(ranking_pool)} " + f"candidates (cap={max_count}, query={query_text!r}). Cannot proceed safely; " + f"raise the cap or refine the query." 
def _materialize_shortlist(
    ranked_names: List[str],
    *,
    ranking_pool: List[StructuredTool],
    target_k: int,
    query_text: str,
    max_count: int,
) -> Tuple[List[StructuredTool], Set[str]]:
    """Map ranker output back to ``StructuredTool`` objects, clamped to ``target_k``.

    Defense-in-depth clamp: ``target_k`` is enforced here as well, in case the
    shortlister returns more names than requested. Without it the bound list
    could exceed ``max_count`` and re-trigger the provider error the cap exists
    to prevent.

    Args:
        ranked_names: Tool names from the shortlister, best first. Duplicates
            and names not present in ``ranking_pool`` are skipped.
        ranking_pool: Candidate tools; matched by their ``name`` attribute.
        target_k: Maximum number of tools to return. A non-positive value
            yields an empty result.
        query_text: User query, used only in the error message.
        max_count: Configured cap, used only in the error message.

    Returns:
        ``(shortlisted, seen_names)`` — the matched tools in ranked order and
        the set of their names.

    Raises:
        RuntimeError: ``target_k`` is positive but none of ``ranked_names``
            matches a candidate.
    """
    if target_k <= 0:
        # No slots available: never return more tools than the clamp allows.
        return [], set()

    by_name = {getattr(t, "name", ""): t for t in ranking_pool}
    shortlisted: List[StructuredTool] = []
    seen_short: Set[str] = set()
    for name in ranked_names:
        # Check the clamp before appending so the result never exceeds target_k.
        if len(shortlisted) >= target_k:
            break
        tool = by_name.get(name)
        if tool is None or name in seen_short:
            continue
        seen_short.add(name)
        shortlisted.append(tool)

    if not shortlisted:
        raise RuntimeError(
            f"cuga_lite_bind_tools shortlister returned {len(ranked_names)} names but none "
            f"matched the {len(ranking_pool)} candidates (cap={max_count}, "
            f"query={query_text!r}, sample_ranked={ranked_names[:5]}). Shortlister LLM "
            f"hallucinated tool names — raise the cap, fix the shortlister prompt, or "
            f"refine the query."
        )
    return shortlisted, seen_short
async def apply_bind_tools_cap_and_merge(
    bound: List[StructuredTool],
    *,
    query: Optional[str],
    tool_provider: Optional[ToolProviderInterface],
    llm: Optional[BaseChatModel],
    max_count: int,
    include_find_tools: bool,
    tools_context_ref: Optional[Dict[str, Any]],
    mode: str,
) -> List[StructuredTool]:
    """Enforce the provider-safe ``max_count`` and optionally merge ``find_tools``.

    Under cap → merge ``find_tools`` (when ``include_find_tools``) and return. Over cap →
    run the LLM shortlister (see :meth:`PromptUtils.shortlist_tool_names`) against
    ``query``, take top-K (reserving 1 slot for ``find_tools`` when applicable), and return
    the ranked subset.

    Args:
        bound: Candidate tools for ``bind_tools``.
        query: User query used by the shortlister; required only when over cap.
        tool_provider: Optional provider whose apps are passed to the shortlister.
        llm: Optional model handed to the shortlister.
        max_count: Maximum number of tools to return; ``<= 0`` disables the cap.
        include_find_tools: Whether ``find_tools`` should be part of the result.
        tools_context_ref: Context dict that may hold the ``find_tools`` overlay tool.
        mode: Bind-tools mode name, used in log and error messages.

    Returns:
        The tools to bind, with ``find_tools`` appended last when it is kept.

    Raises:
        RuntimeError: The cap is exceeded but shortlisting is impossible — no user
            query, shortlister failure, or empty ranking. Failing loudly is
            intentional: silent truncation would corrupt research/benchmark results
            that compare native tool-calling against text-mode.
    """
    bound_in_len = len(bound)
    (
        find_tools_tool,
        find_tools_name,
        find_tools_already_in_bound,
        bound,
    ) = _resolve_find_tools_overlay(
        bound,
        include_find_tools=include_find_tools,
        tools_context_ref=tools_context_ref,
    )

    # find_tools is kept only when requested AND an overlay tool actually exists.
    keep_find_tools = include_find_tools and find_tools_tool is not None
    ranking_pool = _build_ranking_pool(
        bound,
        keep_find_tools=keep_find_tools,
        find_tools_name=find_tools_name,
        find_tools_already_in_bound=find_tools_already_in_bound,
    )
    # One slot is reserved for find_tools so appending it cannot break the cap.
    reserve = 1 if keep_find_tools else 0

    def _append_find_tools(tools: List[StructuredTool]) -> List[StructuredTool]:
        # keep_find_tools already implies find_tools_tool is not None.
        if not keep_find_tools:
            return tools
        if any(getattr(t, "name", "") == find_tools_name for t in tools):
            return tools
        return [*tools, find_tools_tool]

    cap_disabled = max_count <= 0
    effective_count = len(ranking_pool) + reserve
    if cap_disabled or effective_count <= max_count:
        return _append_find_tools(ranking_pool)

    query_text = (query or "").strip()
    if not query_text:
        raise RuntimeError(
            f"cuga_lite_bind_tools_mode={mode!r} produced {bound_in_len} tools but the "
            f"provider-safe cap (cuga_lite_bind_tools_max_count) is {max_count}. "
            f"Shortlisting requires a non-empty user query, but none was provided. Options: "
            f"(a) ensure the first user message is non-empty so the shortlister can run, "
            f"(b) raise the cap via DYNACONF_ADVANCED_FEATURES__CUGA_LITE_BIND_TOOLS_MAX_COUNT "
            f"for permissive backends (WatsonX, Anthropic via LiteLLM), or "
            f"(c) set the cap to 0 to disable (Groq/OpenAI will reject)."
        )

    target_k = max_count - reserve
    if target_k <= 0:
        # The cap leaves room for find_tools only.
        return _append_find_tools([])

    ranked_names = await _run_shortlister(
        query_text,
        ranking_pool=ranking_pool,
        tool_provider=tool_provider,
        llm=llm,
        top_k=target_k,
        mode=mode,
        max_count=max_count,
    )
    shortlisted, seen_short = _materialize_shortlist(
        ranked_names,
        ranking_pool=ranking_pool,
        target_k=target_k,
        query_text=query_text,
        max_count=max_count,
    )
    padded_count = _maybe_pad_to_cap(
        shortlisted,
        ranking_pool=ranking_pool,
        seen_short=seen_short,
        target_k=target_k,
    )
    shortlisted = _append_find_tools(shortlisted)
    logger.info(
        "bind_tools cap: shortlisted to {} tools (mode={}, cap={}, ranked={}, padded={}, "
        "include_find_tools={}, top_ranked={})",
        len(shortlisted),
        mode,
        max_count,
        len(ranked_names),
        padded_count,
        # Report whether find_tools was actually kept, not merely whether an
        # overlay tool exists (it may have been stripped on explicit opt-out).
        keep_find_tools,
        ranked_names[:5],
    )
    return shortlisted
@@ -299,6 +304,23 @@ async def resolve_model_with_bind_tools( - ``cuga_lite_bind_tools_apps``: list of app names (``mode=apps`` or ``apps_and_tools``) - ``cuga_lite_bind_tools_tool_names``: StructuredTool ``name`` values (``mode=tools`` or ``apps_and_tools``) - ``cuga_lite_bind_tools_include_find_tools``: merge ``find_tools`` into ``all`` / ``apps`` / ``tools`` / ``apps_and_tools`` + - ``cuga_lite_bind_tools_max_count``: provider-safe cap on the number of tools sent to + ``bind_tools``. Default 128 (matches Groq/OpenAI). Set 0 to disable. When the + candidate list exceeds the cap, the LLM shortlister picks the top-K most relevant + tools for ``query`` (typically the first user message). + - ``cuga_lite_bind_tools_pad_to_cap``: opt-in padding (default ``False``). When the + shortlister returns fewer than the cap allows, pad with remaining candidates to fill + the cap. Off by default because padding pushes the model toward native ``tool_calls`` + mode, which the code-mode flow doesn't fully exercise (measured: 0 tool calls vs 5-7 + without padding on the m3 hockey benchmark). + + Operational cost: when the cap is exceeded, applying it incurs **one extra LLM + round-trip** (the shortlister) per ``call_model`` invocation. Permissive backends + (WatsonX, Anthropic via LiteLLM) can avoid this round-trip entirely by setting + ``cuga_lite_bind_tools_max_count=0``. Silent truncation is **not** an option — when + shortlisting cannot run safely (no user query, shortlister failure, or hallucinated + names that don't match any candidate), a ``RuntimeError`` is raised so research/ + benchmark runs comparing native tool-calling vs text-mode don't silently degrade. Profile ``gpt-oss-20b``: see ``model_runtime_profile.GPT_OSS_20B_RUNTIME_DEFAULTS``. 
""" @@ -317,6 +339,21 @@ async def resolve_model_with_bind_tools( settings_tool_names_fn=_bind_tools_tool_names_from_settings, settings_include_fn=lambda: _bind_include_find_tools_from_config({}), ) + max_count = bind_tools_max_count_from_settings() + + async def _cap_merge_bound(bound: List[StructuredTool]) -> List[StructuredTool]: + # Closes over query, tool_provider, llm, max_count, include_find_tools, + # tools_context_ref, mode — the four mode branches all pass the same kwargs. + return await apply_bind_tools_cap_and_merge( + bound, + query=query, + tool_provider=tool_provider, + llm=active_model, + max_count=max_count, + include_find_tools=include_find_tools, + tools_context_ref=tools_context_ref, + mode=mode, + ) if mode in ("", "none", "false", "0", "off"): if include_find_tools: @@ -341,11 +378,7 @@ async def resolve_model_with_bind_tools( logger.warning("cuga_lite_bind_tools_mode=all but tool_provider is missing") return active_model by_name = await _indexed_tools_for_native_bind(tool_provider, tools_context_ref) - bound = list(by_name.values()) - seen: Set[str] = {n for n in by_name} - _merge_find_tools_into_bound( - bound, seen, include_find_tools=include_find_tools, tools_context_ref=tools_context_ref - ) + bound = await _cap_merge_bound(list(by_name.values())) if not bound: return active_model return active_model.bind_tools(bound) @@ -399,9 +432,7 @@ async def resolve_model_with_bind_tools( missing, ) - _merge_find_tools_into_bound( - bound, seen_names, include_find_tools=include_find_tools, tools_context_ref=tools_context_ref - ) + bound = await _cap_merge_bound(bound) if not bound: return active_model return active_model.bind_tools(bound) @@ -432,9 +463,7 @@ async def resolve_model_with_bind_tools( bound.append(t) except Exception as e: logger.warning("bind_tools apps: get_tools(%s) failed: %s", app_name, e) - _merge_find_tools_into_bound( - bound, seen, include_find_tools=include_find_tools, tools_context_ref=tools_context_ref - ) + bound = await 
@staticmethod
def _build_shortlister_payload(
    all_tools: List[StructuredTool],
    all_apps: List[AppDefinition],
) -> tuple[Dict[str, Any], Dict[str, Any]]:
    """Serialize ``all_tools`` and ``all_apps`` for the shortlister LLM prompt.

    Shared by :meth:`find_tools` (runtime tool discovery) and
    :meth:`shortlist_tool_names` (bind-time cap reduction) so both callers
    send the same payload: ``args_schema``, ``_response_schemas`` and
    ``_param_constraints`` are included for every tool that has them.

    Args:
        all_tools: Tools to serialize; each must provide ``model_dump()`` and ``name``.
        all_apps: Apps to serialize; each must provide ``model_dump()`` and ``name``.

    Returns:
        ``(tools_as_dict, apps_as_dict)``, both keyed by ``name``.
    """
    tools_as_dict: Dict[str, Any] = {}
    for tool in all_tools:
        tool_dict = tool.model_dump()

        schema_source = getattr(tool, 'args_schema', None)
        args_schema: Any = {}
        if isinstance(schema_source, dict):
            # args_schema may already be a JSON-schema dict; keep it instead of
            # replacing it with {} and hiding the parameters from the ranker.
            args_schema = schema_source
        elif schema_source:
            try:
                # Prefer the Pydantic v2 API; ``schema()`` is the v1 spelling
                # and is deprecated on v2 models.
                if hasattr(schema_source, 'model_json_schema'):
                    args_schema = schema_source.model_json_schema()
                elif hasattr(schema_source, 'schema'):
                    args_schema = schema_source.schema()
            except (AttributeError, TypeError, ValueError) as e:
                # Narrow to expected serialization failures so unexpected bugs
                # propagate instead of silently stripping the schema.
                logger.debug(f"Failed to serialize args_schema for tool {tool.name}: {e}")
                args_schema = {}
        tool_dict['args_schema'] = args_schema

        if hasattr(tool, 'func'):
            if hasattr(tool.func, '_response_schemas'):
                tool_dict['_response_schemas'] = tool.func._response_schemas
            if hasattr(tool.func, '_param_constraints'):
                tool_dict['_param_constraints'] = tool.func._param_constraints

        tools_as_dict[tool.name] = tool_dict

    apps_as_dict = {app.name: app.model_dump() for app in all_apps}
    return tools_as_dict, apps_as_dict
hasattr(tool.args_schema, 'schema'): - tool_dict['args_schema'] = tool.args_schema.schema() - # Try model_json_schema() method (Pydantic v2) - elif hasattr(tool.args_schema, 'model_json_schema'): - tool_dict['args_schema'] = tool.args_schema.model_json_schema() - else: - tool_dict['args_schema'] = {} - except Exception as e: - logger.debug(f"Failed to serialize args_schema for tool {tool.name}: {e}") - tool_dict['args_schema'] = {} - else: - tool_dict['args_schema'] = {} - - # Also ensure response_schemas and param_constraints are included if they exist - if hasattr(tool, 'func'): - if hasattr(tool.func, '_response_schemas'): - tool_dict['_response_schemas'] = tool.func._response_schemas - if hasattr(tool.func, '_param_constraints'): - tool_dict['_param_constraints'] = tool.func._param_constraints - - tools_as_dict[tool.name] = tool_dict - - apps_as_dict = {app.name: app.model_dump() for app in all_apps} + tools_as_dict, apps_as_dict = PromptUtils._build_shortlister_payload(all_tools, all_apps) from cuga.backend.llm.models import LLMManager from cuga.backend.cuga_graph.nodes.api.shortlister_agent.prompts.load_prompt import ( ShortListerOutputLite, @@ -397,6 +410,85 @@ async def find_tools( return "\n".join(markdown_lines) + @staticmethod + async def shortlist_tool_names( + query: str, + all_tools: List[StructuredTool], + all_apps: List[AppDefinition], + llm: Optional[Any] = None, + top_k: int = 4, + instructions: Optional[str] = None, + ) -> List[str]: + """Rank tools by relevance to ``query`` and return up to ``top_k`` names (best-first). + + Wraps the same shortlister LLM chain as :meth:`find_tools` but exposes the + ranked ``APIDetails.name`` list directly. Used by bind-time shortlisting in + ``resolve_model_with_bind_tools`` when the candidate tool count exceeds the + configured provider cap. 
+ """ + if top_k <= 0 or not all_tools: + return [] + # A whitespace-only query would otherwise invoke the LLM and produce arbitrary + # rankings, defeating the "no query" failure path in the caller (coderabbit on #203). + if not query or not query.strip(): + return [] + + from cuga.backend.llm.models import LLMManager + from cuga.backend.cuga_graph.nodes.api.shortlister_agent.prompts.load_prompt import ( + ShortListerOutputLite, + ) + from cuga.backend.cuga_graph.nodes.shared.base_agent import BaseAgent + + effective_instructions = ( + instructions + if instructions is not None + else ( + f"Return the {top_k} most relevant tools (or fewer if not enough are relevant), " + "ordered best-first by relevance. Do not exceed this count." + ) + ) + + prompt = create_chat_prompt_from_templates( + system_path='./prompts/shortlister/system.jinja2', + message_templates=[ + ( + 'human', + """ + Current Apps: {all_apps} + Current Available Tools: {all_tools} + """, + ), + ('ai', 'Sure, now give me the intent'), + ('human', '{input}'), + ], + ) + tools_as_dict, apps_as_dict = PromptUtils._build_shortlister_payload(all_tools, all_apps) + + llm_manager = LLMManager() + model = llm or llm_manager.get_model(settings.agent.code.model) + chain = BaseAgent.get_chain(prompt, model, ShortListerOutputLite) + response = await chain.ainvoke( + { + "input": query, + "all_apps": apps_as_dict, + "all_tools": tools_as_dict, + "instructions": effective_instructions, + } + ) + + valid_names = {t.name for t in all_tools} + ranked: List[str] = [] + seen: set = set() + for api_detail in getattr(response, "result", None) or []: + name = getattr(api_detail, "name", None) + if not name or name in seen or name not in valid_names: + continue + seen.add(name) + ranked.append(name) + if len(ranked) >= top_k: + break + return ranked + @staticmethod def create_find_tools_bound(all_tools: List[StructuredTool], all_apps: List[AppDefinition]): """Create a bound version of find_tools with all_tools and all_apps 
pre-bound. diff --git a/src/cuga/config.py b/src/cuga/config.py index 83f2877b..c15e512d 100644 --- a/src/cuga/config.py +++ b/src/cuga/config.py @@ -181,6 +181,8 @@ def get_all_paths(config, parent_key=""): Validator("skills.enabled", default=False), Validator("advanced_features.builtin_tools", default=["knowledge"]), Validator("advanced_features.cuga_lite_bind_tools_tool_names", default=[]), + Validator("advanced_features.cuga_lite_bind_tools_max_count", default=128), + Validator("advanced_features.cuga_lite_bind_tools_pad_to_cap", default=False), # Evolve integration Validator("evolve.enabled", default=False), Validator("evolve.url", default="http://127.0.0.1:8201/sse"), diff --git a/src/cuga/settings.toml b/src/cuga/settings.toml index 15d330a9..a9676f9a 100644 --- a/src/cuga/settings.toml +++ b/src/cuga/settings.toml @@ -54,6 +54,10 @@ cuga_lite_bind_tools_mode = "none" cuga_lite_bind_tools_apps = [] # mode=apps or apps_and_tools cuga_lite_bind_tools_tool_names = [] # mode=tools or apps_and_tools (StructuredTool.name) cuga_lite_bind_tools_include_find_tools = false # Also bind find_tools alongside all/apps/tools/apps_and_tools +# Cap on tools sent to LLM.bind_tools (0 disables). Over-cap triggers an LLM shortlister; failure raises. +cuga_lite_bind_tools_max_count = 128 +# Pad shortlister output to fill the cap. Off by default for code-execution mode. 
+cuga_lite_bind_tools_pad_to_cap = false cuga_lite_nl_auto_continue = false # When model returns NL with no code, LLM-classify interim vs final; if interim, simulate user "continue" and re-call model path_segment_index = 1 # Which path segment to use for operation naming (1 = first, 2 = second, 3 = third) force_autonomous_mode = false diff --git a/src/scripts/run_tests.sh b/src/scripts/run_tests.sh index 78997716..32ff9f9a 100755 --- a/src/scripts/run_tests.sh +++ b/src/scripts/run_tests.sh @@ -80,6 +80,8 @@ run_pytest \ tests/unit/test_chat_knowledge_mode.py \ tests/unit/test_chat_agent_knowledge_toggle.py \ tests/integration/test_knowledge_integration.py +echo "Running cuga_lite bind_tools tests..." +run_pytest tests/unit/test_cuga_lite_bind_tools.py echo "✅ All unit tests passed!" # Check for test type flag diff --git a/tests/unit/test_cuga_lite_bind_tools.py b/tests/unit/test_cuga_lite_bind_tools.py index a85c2183..271dc035 100644 --- a/tests/unit/test_cuga_lite_bind_tools.py +++ b/tests/unit/test_cuga_lite_bind_tools.py @@ -1,6 +1,6 @@ """CugaLite native bind_tools resolution (mode=tools by tool name).""" -from unittest.mock import AsyncMock, MagicMock +from unittest.mock import AsyncMock, MagicMock, patch import pytest from langchain_core.tools import StructuredTool @@ -109,3 +109,552 @@ async def test_bind_tools_overlay_includes_shell_tools_not_on_registry(): model.bind_tools.assert_called_once() (bound,), _kwargs = model.bind_tools.call_args assert [t.name for t in bound] == ["run_command"] + + +@pytest.mark.asyncio +async def test_bind_tools_mode_all_shortlists_when_over_cap(): + """When mode=all candidate count > max_count, run the LLM shortlister and bind top-K.""" + tools = [_stub_tool(f"tool_{i:03d}") for i in range(10)] + provider = AsyncMock() + provider.get_all_tools = AsyncMock(return_value=tools) + provider.get_apps = AsyncMock(return_value=[]) + model = MagicMock() + + async def fake_shortlist(query, all_tools, all_apps, llm=None, top_k=4, 
@pytest.mark.asyncio
async def test_bind_tools_mode_all_raises_when_over_cap_without_query():
    """Over cap with no query must raise instead of silently truncating the tool list."""
    candidate_tools = [_stub_tool(f"tool_{i:03d}") for i in range(10)]
    provider = AsyncMock()
    provider.get_all_tools = AsyncMock(return_value=candidate_tools)
    provider.get_apps = AsyncMock(return_value=[])
    model = MagicMock()

    # Ten candidates against a cap of three forces the shortlisting path.
    cap_of_three = patch(
        "cuga.backend.cuga_graph.nodes.cuga_lite.cuga_lite_graph.bind_tools_max_count_from_settings",
        return_value=3,
    )
    with cap_of_three, pytest.raises(RuntimeError, match="provider-safe cap"):
        await resolve_model_with_bind_tools(
            model,
            configurable={"cuga_lite_bind_tools_mode": "all"},
            tools_context_ref={},
            tool_provider=provider,
            query=None,
        )

    model.bind_tools.assert_not_called()
"cuga.backend.cuga_graph.nodes.cuga_lite.cuga_lite_graph.bind_tools_max_count_from_settings", + return_value=128, + ), + patch( + "cuga.backend.cuga_graph.nodes.cuga_lite.cuga_lite_graph.PromptUtils.shortlist_tool_names", + side_effect=shortlist_calls, + ), + ): + await resolve_model_with_bind_tools( + model, + configurable={"cuga_lite_bind_tools_mode": "all"}, + tools_context_ref={}, + tool_provider=provider, + query="anything", + ) + + shortlist_calls.assert_not_called() + model.bind_tools.assert_called_once() + (bound,), _kwargs = model.bind_tools.call_args + assert sorted(t.name for t in bound) == ["tool_0", "tool_1", "tool_2"] + + +@pytest.mark.asyncio +async def test_bind_tools_mode_all_disabled_cap_binds_everything(): + """max_count <= 0 disables the cap entirely (WatsonX/permissive backends).""" + tools = [_stub_tool(f"tool_{i}") for i in range(50)] + provider = AsyncMock() + provider.get_all_tools = AsyncMock(return_value=tools) + provider.get_apps = AsyncMock(return_value=[]) + model = MagicMock() + + with patch( + "cuga.backend.cuga_graph.nodes.cuga_lite.cuga_lite_graph.bind_tools_max_count_from_settings", + return_value=0, + ): + await resolve_model_with_bind_tools( + model, + configurable={"cuga_lite_bind_tools_mode": "all"}, + tools_context_ref={}, + tool_provider=provider, + query=None, + ) + + model.bind_tools.assert_called_once() + (bound,), _kwargs = model.bind_tools.call_args + assert len(bound) == 50 + + +@pytest.mark.asyncio +async def test_bind_tools_cap_does_not_pad_by_default(): + """By default, only the shortlister's ranked tools are bound — cuga_lite is a code-agent + and binding many native tools regresses code-emission (see commit context).""" + tools = [_stub_tool(f"tool_{i:03d}") for i in range(10)] + provider = AsyncMock() + provider.get_all_tools = AsyncMock(return_value=tools) + provider.get_apps = AsyncMock(return_value=[]) + model = MagicMock() + + async def stingy_shortlist(query, all_tools, all_apps, llm=None, top_k=4, 
instructions=None): + return ["tool_007"] + + with ( + patch( + "cuga.backend.cuga_graph.nodes.cuga_lite.cuga_lite_graph.bind_tools_max_count_from_settings", + return_value=5, + ), + patch( + "cuga.backend.cuga_graph.nodes.cuga_lite.bind_tools.cap.bind_tools_pad_to_cap_from_settings", + return_value=False, + ), + patch( + "cuga.backend.cuga_graph.nodes.cuga_lite.cuga_lite_graph.PromptUtils.shortlist_tool_names", + side_effect=stingy_shortlist, + ), + ): + await resolve_model_with_bind_tools( + model, + configurable={"cuga_lite_bind_tools_mode": "all"}, + tools_context_ref={}, + tool_provider=provider, + query="anything", + ) + + model.bind_tools.assert_called_once() + (bound,), _kwargs = model.bind_tools.call_args + assert [t.name for t in bound] == ["tool_007"] + + +@pytest.mark.asyncio +async def test_bind_tools_cap_pads_when_opt_in(): + """When pad_to_cap=True, the shortlist is filled with the remaining tools up to target_k.""" + tools = [_stub_tool(f"tool_{i:03d}") for i in range(10)] + provider = AsyncMock() + provider.get_all_tools = AsyncMock(return_value=tools) + provider.get_apps = AsyncMock(return_value=[]) + model = MagicMock() + + async def stingy_shortlist(query, all_tools, all_apps, llm=None, top_k=4, instructions=None): + return ["tool_007"] + + with ( + patch( + "cuga.backend.cuga_graph.nodes.cuga_lite.cuga_lite_graph.bind_tools_max_count_from_settings", + return_value=5, + ), + patch( + "cuga.backend.cuga_graph.nodes.cuga_lite.bind_tools.cap.bind_tools_pad_to_cap_from_settings", + return_value=True, + ), + patch( + "cuga.backend.cuga_graph.nodes.cuga_lite.cuga_lite_graph.PromptUtils.shortlist_tool_names", + side_effect=stingy_shortlist, + ), + ): + await resolve_model_with_bind_tools( + model, + configurable={"cuga_lite_bind_tools_mode": "all"}, + tools_context_ref={}, + tool_provider=provider, + query="anything", + ) + + model.bind_tools.assert_called_once() + (bound,), _kwargs = model.bind_tools.call_args + names = [t.name for t in bound] + 
assert len(names) == 5 + assert names[0] == "tool_007" + assert names[1:] == ["tool_000", "tool_001", "tool_002", "tool_003"] + + +@pytest.mark.asyncio +async def test_bind_tools_cap_not_violated_when_at_boundary_with_find_tools(): + """Boundary case from coderabbit: len(bound) == max_count and include_find_tools=True. + + Without the effective-count check, the under-cap fast path would append find_tools and + return max_count+1 tools — provider rejects. + """ + tools = [_stub_tool(f"tool_{i:03d}") for i in range(5)] + find_tools_tool = _stub_tool("find_tools") + provider = AsyncMock() + provider.get_all_tools = AsyncMock(return_value=tools) + provider.get_apps = AsyncMock(return_value=[]) + model = MagicMock() + + captured_top_k = {} + + async def fake_shortlist(query, all_tools, all_apps, llm=None, top_k=4, instructions=None): + captured_top_k["value"] = top_k + return [t.name for t in all_tools[:top_k]] + + with ( + patch( + "cuga.backend.cuga_graph.nodes.cuga_lite.cuga_lite_graph.bind_tools_max_count_from_settings", + return_value=5, + ), + patch( + "cuga.backend.cuga_graph.nodes.cuga_lite.cuga_lite_graph.PromptUtils.shortlist_tool_names", + side_effect=fake_shortlist, + ), + ): + await resolve_model_with_bind_tools( + model, + configurable={ + "cuga_lite_bind_tools_mode": "all", + "cuga_lite_bind_tools_include_find_tools": True, + }, + tools_context_ref={"_lc_bind_tools_find_tools": find_tools_tool}, + tool_provider=provider, + query="hockey", + ) + + model.bind_tools.assert_called_once() + (bound,), _kwargs = model.bind_tools.call_args + assert len(bound) == 5, f"cap violated: bound has {len(bound)} tools (max=5)" + assert captured_top_k["value"] == 4, "expected 1 slot reserved for find_tools" + assert bound[-1].name == "find_tools" + + +@pytest.mark.asyncio +async def test_bind_tools_cap_guarantees_find_tools_when_in_overlay_bound(): + """If `find_tools` is in `bound` via the overlay and `include_find_tools=True`, it must + survive shortlisting — the LLM 
ranker is allowed to drop any tool from its ranking input, + so we pull find_tools out of the ranking pool and reserve a cap slot for it. + + Regression test for coderabbit comment on #203 (`include_find_tools` contract). + """ + tools = [_stub_tool(f"tool_{i:03d}") for i in range(5)] + find_tools_tool = _stub_tool("find_tools") + provider = AsyncMock() + provider.get_all_tools = AsyncMock(return_value=tools + [find_tools_tool]) + provider.get_apps = AsyncMock(return_value=[]) + model = MagicMock() + + captured = {} + + async def fake_shortlist(query, all_tools, all_apps, llm=None, top_k=4, instructions=None): + captured["top_k"] = top_k + captured["pool_names"] = [t.name for t in all_tools] + # Adversarial: rank find_tools out if it appears in the input. With the fix it shouldn't. + return [t.name for t in all_tools[:top_k]] + + with ( + patch( + "cuga.backend.cuga_graph.nodes.cuga_lite.cuga_lite_graph.bind_tools_max_count_from_settings", + return_value=4, + ), + patch( + "cuga.backend.cuga_graph.nodes.cuga_lite.cuga_lite_graph.PromptUtils.shortlist_tool_names", + side_effect=fake_shortlist, + ), + ): + await resolve_model_with_bind_tools( + model, + configurable={ + "cuga_lite_bind_tools_mode": "all", + "cuga_lite_bind_tools_include_find_tools": True, + }, + tools_context_ref={"_lc_bind_tools_find_tools": find_tools_tool}, + tool_provider=provider, + query="hockey", + ) + + model.bind_tools.assert_called_once() + (bound,), _kwargs = model.bind_tools.call_args + assert len(bound) == 4, f"cap violated: bound has {len(bound)} (cap=4)" + # Reserve 1 slot for find_tools, so the shortlister sees top_k = cap - 1. + assert captured["top_k"] == 3 + # find_tools must not be exposed to the ranker (it can't be evicted that way). + assert "find_tools" not in captured["pool_names"] + # find_tools must end up bound regardless of how the ranker votes. 
+ assert "find_tools" in {t.name for t in bound} + assert bound[-1].name == "find_tools" + + +@pytest.mark.asyncio +async def test_bind_tools_include_find_tools_false_strips_overlay_find_tools(): + """Overlay can inject `find_tools` into `bound` regardless of `include_find_tools`. + + When the user sets `include_find_tools=False`, the overlay-injected tool must be stripped + so it can't consume a cap slot or be ranked. Regression test for coderabbit on #203. + """ + tools = [_stub_tool(f"tool_{i:03d}") for i in range(5)] + find_tools_tool = _stub_tool("find_tools") + provider = AsyncMock() + # Overlay path puts find_tools into bound directly. + provider.get_all_tools = AsyncMock(return_value=tools + [find_tools_tool]) + provider.get_apps = AsyncMock(return_value=[]) + model = MagicMock() + + captured = {} + + async def fake_shortlist(query, all_tools, all_apps, llm=None, top_k=4, instructions=None): + captured["pool_names"] = [t.name for t in all_tools] + return [t.name for t in all_tools[:top_k]] + + with ( + patch( + "cuga.backend.cuga_graph.nodes.cuga_lite.cuga_lite_graph.bind_tools_max_count_from_settings", + return_value=3, + ), + patch( + "cuga.backend.cuga_graph.nodes.cuga_lite.cuga_lite_graph.PromptUtils.shortlist_tool_names", + side_effect=fake_shortlist, + ), + ): + await resolve_model_with_bind_tools( + model, + configurable={ + "cuga_lite_bind_tools_mode": "all", + # NB: include_find_tools is False (default in configurable here) + }, + tools_context_ref={"_lc_bind_tools_find_tools": find_tools_tool}, + tool_provider=provider, + query="hockey", + ) + + model.bind_tools.assert_called_once() + (bound,), _kwargs = model.bind_tools.call_args + bound_names = {t.name for t in bound} + assert "find_tools" not in bound_names, ( + f"find_tools leaked through overlay despite include_find_tools=False: {bound_names}" + ) + # The shortlister must not have seen find_tools either. 
+ assert "find_tools" not in captured["pool_names"] + + +@pytest.mark.asyncio +async def test_bind_tools_cap_raises_when_shortlist_names_dont_match_pool(): + """LLM-hallucinated shortlist names (non-empty list, zero matches in pool) must fail + loudly. Without the guard we'd silently pad or bind just find_tools, recreating the + silent degradation the cap path exists to prevent. Coderabbit on #203.""" + tools = [_stub_tool(f"tool_{i:03d}") for i in range(10)] + provider = AsyncMock() + provider.get_all_tools = AsyncMock(return_value=tools) + provider.get_apps = AsyncMock(return_value=[]) + model = MagicMock() + + async def hallucinating_shortlist(query, all_tools, all_apps, llm=None, top_k=4, instructions=None): + return ["nonexistent_tool_a", "nonexistent_tool_b"] + + with ( + patch( + "cuga.backend.cuga_graph.nodes.cuga_lite.cuga_lite_graph.bind_tools_max_count_from_settings", + return_value=3, + ), + patch( + "cuga.backend.cuga_graph.nodes.cuga_lite.cuga_lite_graph.PromptUtils.shortlist_tool_names", + side_effect=hallucinating_shortlist, + ), + ): + with pytest.raises(RuntimeError, match="hallucinated"): + await resolve_model_with_bind_tools( + model, + configurable={"cuga_lite_bind_tools_mode": "all"}, + tools_context_ref={}, + tool_provider=provider, + query="hockey", + ) + model.bind_tools.assert_not_called() + + +@pytest.mark.asyncio +async def test_bind_tools_cap_clamps_shortlist_when_llm_returns_too_many(): + """If the shortlister (e.g. a non-compliant custom impl or future refactor) returns more + valid names than ``top_k``, the call site must still enforce ``target_k`` so the bound + list never exceeds the provider-safe cap. 
Regression test for coderabbit on #203.""" + tools = [_stub_tool(f"tool_{i:03d}") for i in range(20)] + provider = AsyncMock() + provider.get_all_tools = AsyncMock(return_value=tools) + provider.get_apps = AsyncMock(return_value=[]) + model = MagicMock() + + async def overlong_shortlist(query, all_tools, all_apps, llm=None, top_k=4, instructions=None): + # Deliberately ignore top_k — return every pool name. + return [t.name for t in all_tools] + + with ( + patch( + "cuga.backend.cuga_graph.nodes.cuga_lite.cuga_lite_graph.bind_tools_max_count_from_settings", + return_value=4, + ), + patch( + "cuga.backend.cuga_graph.nodes.cuga_lite.cuga_lite_graph.PromptUtils.shortlist_tool_names", + side_effect=overlong_shortlist, + ), + ): + await resolve_model_with_bind_tools( + model, + configurable={"cuga_lite_bind_tools_mode": "all"}, + tools_context_ref={}, + tool_provider=provider, + query="hockey", + ) + + model.bind_tools.assert_called_once() + (bound,), _kwargs = model.bind_tools.call_args + assert len(bound) == 4, f"cap violated: bound has {len(bound)} (max=4)" + # First four ranked names — earlier names win via the in-order break. 
+ assert [t.name for t in bound] == ["tool_000", "tool_001", "tool_002", "tool_003"] + + +@pytest.mark.asyncio +async def test_bind_tools_cap_clamps_shortlist_with_find_tools_slot(): + """Same clamp must hold when ``include_find_tools=True`` reserves a slot: + ``target_k = max_count - 1`` is the upper bound on shortlisted entries.""" + tools = [_stub_tool(f"tool_{i:03d}") for i in range(20)] + find_tools_tool = _stub_tool("find_tools") + provider = AsyncMock() + provider.get_all_tools = AsyncMock(return_value=tools) + provider.get_apps = AsyncMock(return_value=[]) + model = MagicMock() + + async def overlong_shortlist(query, all_tools, all_apps, llm=None, top_k=4, instructions=None): + return [t.name for t in all_tools] + + with ( + patch( + "cuga.backend.cuga_graph.nodes.cuga_lite.cuga_lite_graph.bind_tools_max_count_from_settings", + return_value=4, + ), + patch( + "cuga.backend.cuga_graph.nodes.cuga_lite.cuga_lite_graph.PromptUtils.shortlist_tool_names", + side_effect=overlong_shortlist, + ), + ): + await resolve_model_with_bind_tools( + model, + configurable={ + "cuga_lite_bind_tools_mode": "all", + "cuga_lite_bind_tools_include_find_tools": True, + }, + tools_context_ref={"_lc_bind_tools_find_tools": find_tools_tool}, + tool_provider=provider, + query="hockey", + ) + + model.bind_tools.assert_called_once() + (bound,), _kwargs = model.bind_tools.call_args + assert len(bound) == 4, f"cap violated: bound has {len(bound)} (max=4)" + # 3 shortlisted + 1 find_tools at the end. 
+ assert [t.name for t in bound] == ["tool_000", "tool_001", "tool_002", "find_tools"] + + +@pytest.mark.asyncio +async def test_bind_tools_cap_binds_only_find_tools_when_max_count_is_one(): + """`max_count=1` + `include_find_tools=True` should still succeed by binding only + find_tools, instead of raising as "cap too small to fit even find_tools".""" + tools = [_stub_tool(f"tool_{i:03d}") for i in range(5)] + find_tools_tool = _stub_tool("find_tools") + provider = AsyncMock() + provider.get_all_tools = AsyncMock(return_value=tools) + provider.get_apps = AsyncMock(return_value=[]) + model = MagicMock() + + with patch( + "cuga.backend.cuga_graph.nodes.cuga_lite.cuga_lite_graph.bind_tools_max_count_from_settings", + return_value=1, + ): + await resolve_model_with_bind_tools( + model, + configurable={ + "cuga_lite_bind_tools_mode": "all", + "cuga_lite_bind_tools_include_find_tools": True, + }, + tools_context_ref={"_lc_bind_tools_find_tools": find_tools_tool}, + tool_provider=provider, + query="hockey", + ) + + model.bind_tools.assert_called_once() + (bound,), _kwargs = model.bind_tools.call_args + assert [t.name for t in bound] == ["find_tools"] + + +@pytest.mark.asyncio +async def test_bind_tools_cap_reserves_slot_for_find_tools(): + """When include_find_tools is on, the cap reserves 1 slot for find_tools.""" + tools = [_stub_tool(f"tool_{i:03d}") for i in range(10)] + find_tools_tool = _stub_tool("find_tools") + provider = AsyncMock() + provider.get_all_tools = AsyncMock(return_value=tools) + provider.get_apps = AsyncMock(return_value=[]) + model = MagicMock() + + captured_top_k = {} + + async def fake_shortlist(query, all_tools, all_apps, llm=None, top_k=4, instructions=None): + captured_top_k["value"] = top_k + return [t.name for t in all_tools[:top_k]] + + with ( + patch( + "cuga.backend.cuga_graph.nodes.cuga_lite.cuga_lite_graph.bind_tools_max_count_from_settings", + return_value=4, + ), + patch( + 
"cuga.backend.cuga_graph.nodes.cuga_lite.cuga_lite_graph.PromptUtils.shortlist_tool_names", + side_effect=fake_shortlist, + ), + ): + await resolve_model_with_bind_tools( + model, + configurable={ + "cuga_lite_bind_tools_mode": "all", + "cuga_lite_bind_tools_include_find_tools": True, + }, + tools_context_ref={"_lc_bind_tools_find_tools": find_tools_tool}, + tool_provider=provider, + query="hockey", + ) + + assert captured_top_k["value"] == 3 + model.bind_tools.assert_called_once() + (bound,), _kwargs = model.bind_tools.call_args + assert [t.name for t in bound] == ["tool_000", "tool_001", "tool_002", "find_tools"]