Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions ms_agent/tools/search/search_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,15 @@ class SearchEngineType(enum.Enum):
EXA = 'exa'
SERPAPI = 'serpapi'
ARXIV = 'arxiv'
TAVILY = 'tavily'


# Mapping from engine type to tool name.
# Keys mirror the SearchEngineType enum values; values are the tool names
# exposed to the agent (e.g. used when registering/dispatching tool calls).
ENGINE_TOOL_NAMES: Dict[str, str] = {
    'exa': 'exa_search',
    'serpapi': 'serpapi_search',
    'arxiv': 'arxiv_search',
    'tavily': 'tavily_search',
}


Expand Down
79 changes: 79 additions & 0 deletions ms_agent/tools/search/search_request.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from ms_agent.tools.search.exa import ExaSearchRequest
from ms_agent.tools.search.search_base import SearchEngineType, SearchRequest
from ms_agent.tools.search.serpapi.schema import SerpApiSearchRequest
from ms_agent.tools.search.tavily.schema import TavilySearchRequest


class SearchRequestGenerator:
Expand Down Expand Up @@ -256,6 +257,82 @@ def create_request(self,
return ArxivSearchRequest(**search_request_d)


class TavilySearchRequestGenerator(SearchRequestGenerator):
    """Generate Tavily-specific search requests.

    Supplies the argument template, JSON schema and rewrite prompt used to
    ask an LLM for Tavily search arguments, and converts the LLM's argument
    dict into a ``TavilySearchRequest``.
    """

    def get_args_template(self) -> str:
        """Return an example JSON arguments string for a Tavily search."""
        return '{"query": "xxx", "num_results": 5, "search_depth": "advanced", "topic": "general"}'

    def get_json_schema(self,
                        num_queries: int,
                        is_strict: bool = True) -> Dict[str, Any]:
        """Return the JSON schema describing a list of Tavily search queries.

        Args:
            num_queries: Maximum number of queries the LLM may generate.
            is_strict: Whether the schema should be enforced strictly.

        Returns:
            A JSON-schema dict consumable by structured-output LLM APIs.
        """
        return {
            'name': 'search_requests',
            'strict': is_strict,
            'schema': {
                'type': 'array',
                'items': {
                    'type': 'object',
                    'properties': {
                        'query': {
                            'type': 'string',
                            'description': (
                                'Write a concise, keyword-rich search query optimized for web search. '
                                'Keep under 400 characters for best results. Follow these rules:\n'
                                '1) Use exact-match quotes for key phrases (e.g., "contrastive learning").\n'
                                '2) Combine terms naturally; avoid overly complex Boolean syntax.\n'
                                '3) Keep it concise and focused on the core research intent.\n'
                                '4) Chinese queries are supported directly.\n\n'
                                'Examples:\n'
                                '- "retrieval augmented generation" evaluation benchmarks\n'
                                '- large language model medical applications 2024\n'
                                '- 大模型 函数调用 工具调用\n')
                        },
                        'num_results': {
                            'type': 'integer',
                            'description': 'The number of results to return (1-10). '
                            'Choose a value appropriate to the query complexity (e.g., 5)',
                        },
                        'search_depth': {
                            'type': 'string',
                            'enum': ['basic', 'advanced'],
                            'description': 'Search depth. "basic" for quick results, '
                            '"advanced" for higher relevance. Default is "advanced".',
                        },
                        'topic': {
                            'type': 'string',
                            'enum': ['general', 'news', 'finance'],
                            'description': 'Topic category for the search. Default is "general".',
                        },
                        'research_goal': {
                            'type': 'string',
                            'description': 'The goal of the research and additional research directions'
                        }
                    },
                    'required': ['query', 'num_results', 'research_goal']
                },
                'description': f'List of Tavily search queries, max of {num_queries}'
            }
        }

    def get_rewrite_prompt(self) -> str:
        """Return the (Chinese) prompt that instructs the LLM to rewrite the
        user's input into Tavily search arguments matching the template."""
        return (
            f'生成search request,具体要求为: '
            f'\n1. 必须符合以下arguments格式:{self.get_args_template()}'
            f'\n2. 其中,query参数的值通过分析用户原始输入中的有效问题部分生成,即{self.user_prompt},要求为精简的关键词查询,'
            f'例如,用户输入"请帮我查找2023年发表的关于大语言模型在医疗领域应用的最新研究",则query参数的值应为"large language model medical applications 2023";'
            f'\n3. 参数需要符合搜索引擎的要求,num_results需要根据实际问题的复杂程度来估算,最大10,最小1;'
            f'\n4. search_depth参数默认为"advanced",如需快速结果可设为"basic";'
            f'\n5. topic参数默认为"general",如搜索新闻可设为"news",搜索金融相关可设为"finance"')

    def create_request(self,
                       search_request_d: Dict[str, Any]) -> TavilySearchRequest:
        """Build a ``TavilySearchRequest`` from an LLM-produced argument dict.

        The LLM output may contain keys (e.g. ``research_goal``) that are not
        fields of ``TavilySearchRequest``. Instead of a hardcoded key list
        (which silently drifts out of sync when the dataclass changes), the
        accepted keys are derived from the dataclass fields themselves — so
        if ``research_goal`` is later added to the dataclass it will flow
        through automatically.
        """
        from dataclasses import fields as dataclass_fields
        valid_keys = {f.name for f in dataclass_fields(TavilySearchRequest)}
        filtered = {
            k: v
            for k, v in search_request_d.items() if k in valid_keys
        }
        return TavilySearchRequest(**filtered)
Comment on lines +329 to +333
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The current implementation filters out the research_goal field, which is defined as a required field in get_json_schema. This is inconsistent with other request generators and likely causes a bug, as research_goal is probably used by downstream components.

To fix this, you can remove this filtering logic. This will also require adding research_goal: Optional[str] = None to the TavilySearchRequest dataclass in ms_agent/tools/search/tavily/schema.py to avoid an error when instantiating the dataclass.

Suggested change
# Filter out keys not in TavilySearchRequest fields
valid_keys = {'query', 'num_results', 'search_depth', 'topic',
'include_domains', 'exclude_domains'}
filtered = {k: v for k, v in search_request_d.items() if k in valid_keys}
return TavilySearchRequest(**filtered)
return TavilySearchRequest(**search_request_d)



def get_search_request_generator(engine_type: SearchEngineType,
user_prompt: str) -> SearchRequestGenerator:
"""
Expand All @@ -277,5 +354,7 @@ def get_search_request_generator(engine_type: SearchEngineType,
return SerpApiSearchRequestGenerator(user_prompt)
elif engine_type == SearchEngineType.ARXIV:
return ArxivSearchRequestGenerator(user_prompt)
elif engine_type == SearchEngineType.TAVILY:
return TavilySearchRequestGenerator(user_prompt)
else:
raise ValueError(f'Unsupported search engine type: {engine_type}')
3 changes: 3 additions & 0 deletions ms_agent/tools/search/tavily/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# flake8: noqa
from ms_agent.tools.search.tavily.schema import TavilySearchRequest, TavilySearchResult
from ms_agent.tools.search.tavily.search import TavilySearch
93 changes: 93 additions & 0 deletions ms_agent/tools/search/tavily/schema.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
# flake8: noqa
import json
import logging
import os
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional


@dataclass
class TavilySearchRequest:
    """Parameters for a single Tavily API search call."""

    # The search query string
    query: str

    # Number of results to return, default is 5
    num_results: Optional[int] = 5

    # Search depth: 'basic' or 'advanced'
    search_depth: Optional[str] = 'advanced'

    # Topic category: 'general', 'news', or 'finance'
    topic: Optional[str] = 'general'

    # Domains to include in search
    include_domains: Optional[List[str]] = None

    # Domains to exclude from search
    exclude_domains: Optional[List[str]] = None

    # Research goal / extra research directions produced by the request
    # generator. Deliberately NOT sent to the Tavily API (see to_dict);
    # kept here so downstream components can read it from the request.
    research_goal: Optional[str] = None

    def to_dict(self) -> Dict[str, Any]:
        """Convert the request parameters to a dictionary.

        Returns the keyword arguments expected by the Tavily client; note
        that Tavily's API names the result count ``max_results``, while this
        dataclass exposes it as ``num_results``. Domain filters are only
        included when non-empty.
        """
        d: Dict[str, Any] = {
            'query': self.query,
            'max_results': self.num_results,
            'search_depth': self.search_depth,
            'topic': self.topic,
        }
        if self.include_domains:
            d['include_domains'] = self.include_domains
        if self.exclude_domains:
            d['exclude_domains'] = self.exclude_domains
        return d

    def to_json(self) -> str:
        """Convert the request parameters to a JSON string."""
        return json.dumps(self.to_dict(), ensure_ascii=False)


@dataclass
class TavilySearchResult:
    """Holds the raw Tavily API response for one query, with converters."""

    # The original search query string
    query: str

    # Optional arguments for the search request
    arguments: Dict[str, Any] = field(default_factory=dict)

    # The raw response from the Tavily search API
    response: Any = None

    def to_list(self) -> List[Dict[str, Any]]:
        """Convert the search results to a list of dictionaries.

        Returns an empty list (with a logged warning) when the response has
        no results or when no query was recorded. Each entry normalizes
        Tavily's fields to the common result schema; 'id' mirrors 'url'.
        """
        # Use the module logger instead of print() so consumers of the
        # library can control, redirect, or silence these messages.
        logger = logging.getLogger(__name__)

        if not self.response or not self.response.get('results'):
            logger.warning('No search results found.')
            return []

        if not self.query:
            logger.warning('No query provided for search results.')
            return []

        res_list: List[Dict[str, Any]] = []
        for res in self.response['results']:
            res_list.append({
                'url': res.get('url', ''),
                'id': res.get('url', ''),
                'title': res.get('title', ''),
                'summary': res.get('content', ''),
                'published_date': res.get('published_date', ''),
            })

        return res_list

    @staticmethod
    def load_from_disk(file_path: str) -> List[Dict[str, Any]]:
        """Load search results from a local JSON file.

        Returns an empty list when the file does not exist.
        """
        if not os.path.exists(file_path):
            return []

        with open(file_path, 'r', encoding='utf-8') as f:
            data = json.load(f)
        logging.getLogger(__name__).info('Search results loaded from %s',
                                         file_path)

        return data
108 changes: 108 additions & 0 deletions ms_agent/tools/search/tavily/search.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
# flake8: noqa
import os
from typing import TYPE_CHECKING

from tavily import TavilyClient
from ms_agent.tools.search.tavily.schema import TavilySearchRequest, TavilySearchResult
from ms_agent.tools.search.search_base import SearchEngine, SearchEngineType

if TYPE_CHECKING:
from ms_agent.llm.utils import Tool


class TavilySearch(SearchEngine):
    """
    Search engine using Tavily API.

    Best for: AI-optimized web search with high relevance,
    real-time data retrieval for LLM applications.
    """

    engine_type = SearchEngineType.TAVILY

    def __init__(self, api_key: str = None):
        # Fall back to the environment variable when no key is passed in.
        resolved_key = api_key or os.getenv('TAVILY_API_KEY')
        assert resolved_key, 'TAVILY_API_KEY must be set either as an argument or as an environment variable'

        self.client = TavilyClient(api_key=resolved_key)

    def search(self, search_request: TavilySearchRequest) -> TavilySearchResult:
        """
        Perform a search using the Tavily API with the provided search request parameters.

        :param search_request: An instance of TavilySearchRequest containing search parameters.
        :return: An instance of TavilySearchResult containing the search results.
        """
        request_kwargs: dict = search_request.to_dict()
        result = TavilySearchResult(
            query=search_request.query,
            arguments=request_kwargs,
        )
        try:
            result.response = self.client.search(**request_kwargs)
        except Exception as e:
            # Wrap any client error so callers see a uniform failure type.
            raise RuntimeError(f'Failed to perform Tavily search: {e}') from e

        return result

    @classmethod
    def get_tool_definition(cls, server_name: str = 'web_search') -> 'Tool':
        """Return the tool definition for Tavily search engine."""
        from ms_agent.llm.utils import Tool

        # Parameter schema for the tool call, mirroring TavilySearchRequest.
        properties = {
            'query': {
                'type': 'string',
                'description':
                'The search query. Keep under 400 characters for best results.',
            },
            'num_results': {
                'type': 'integer',
                'minimum': 1,
                'maximum': 10,
                'description': 'Number of results to return. Default is 5.',
            },
            'search_depth': {
                'type': 'string',
                'enum': ['basic', 'advanced'],
                'description': ('Search depth. "basic" for quick results, '
                                '"advanced" for higher relevance. '
                                'Default is "advanced".'),
            },
            'topic': {
                'type': 'string',
                'enum': ['general', 'news', 'finance'],
                'description': ('Topic category for the search. '
                                'Default is "general".'),
            },
        }

        return Tool(
            tool_name=cls.get_tool_name(),
            server_name=server_name,
            description=(
                'Search the web using Tavily AI search engine. '
                'Best for: AI-optimized web search with high relevance, '
                'real-time data retrieval for LLM applications.'),
            parameters={
                'type': 'object',
                'properties': properties,
                'required': ['query'],
            },
        )

    @classmethod
    def build_request_from_args(cls, **kwargs) -> TavilySearchRequest:
        """Build TavilySearchRequest from tool call arguments."""
        # Defaults applied when the tool call omits optional arguments.
        defaults = {
            'num_results': 5,
            'search_depth': 'advanced',
            'topic': 'general',
        }
        chosen = {
            name: kwargs.get(name, fallback)
            for name, fallback in defaults.items()
        }
        return TavilySearchRequest(query=kwargs['query'], **chosen)
20 changes: 16 additions & 4 deletions ms_agent/tools/search/websearch_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ def get_search_engine_class(engine_type: str) -> Type[SearchEngine]:
Get search engine class by type.

Args:
engine_type: One of 'exa', 'serpapi', 'arxiv'
engine_type: One of 'exa', 'serpapi', 'arxiv', 'tavily'

Returns:
SearchEngine class (not instance)
Expand All @@ -206,6 +206,9 @@ def get_search_engine_class(engine_type: str) -> Type[SearchEngine]:
elif engine_type == 'arxiv':
from ms_agent.tools.search.arxiv import ArxivSearch
return ArxivSearch
elif engine_type == 'tavily':
from ms_agent.tools.search.tavily import TavilySearch
return TavilySearch
else:
logger.warning(
f"Unknown search engine '{engine_type}', falling back to arxiv")
Expand All @@ -220,7 +223,7 @@ def get_search_engine(engine_type: str,
Get search engine instance by type.

Args:
engine_type: One of 'exa', 'serpapi', 'arxiv'
engine_type: One of 'exa', 'serpapi', 'arxiv', 'tavily'
api_key: API key for the search engine (if required)
**kwargs: Additional arguments passed to engine constructor
"""
Expand All @@ -241,6 +244,9 @@ def get_search_engine(engine_type: str,
elif engine_type == 'arxiv':
from ms_agent.tools.search.arxiv import ArxivSearch
return ArxivSearch()
elif engine_type == 'tavily':
from ms_agent.tools.search.tavily import TavilySearch
return TavilySearch(api_key=api_key or os.getenv('TAVILY_API_KEY'))
else:
logger.warning(
f"Unknown search engine '{engine_type}', falling back to arxiv")
Expand All @@ -265,7 +271,7 @@ def build_search_request(engine_type: str,
class WebSearchTool(ToolBase):
"""
Unified web search tool for agents. It can search the web and fetch page content.
- Search via multiple engines (Exa, SerpAPI, Arxiv)
- Search via multiple engines (Exa, SerpAPI, Arxiv, Tavily)
- Dynamic tool definitions based on configured engines
- Auto-fetch and parse page content
- Configurable content fetcher (jina_reader, docling, etc.)
Expand Down Expand Up @@ -296,7 +302,7 @@ class WebSearchTool(ToolBase):
SERVER_NAME = 'web_search'

# Registry of supported search engines
SUPPORTED_ENGINES = ('exa', 'serpapi', 'arxiv')
SUPPORTED_ENGINES = ('exa', 'serpapi', 'arxiv', 'tavily')

# Process-wide (class-level) usage tracking for summarization calls.
# This is intentionally separate from LLMAgent usage totals.
Expand Down Expand Up @@ -404,6 +410,9 @@ def __init__(self, config, **kwargs):
'serpapi': (getattr(tool_cfg, 'serpapi_api_key', None)
or os.getenv('SERPAPI_API_KEY'))
if tool_cfg else os.getenv('SERPAPI_API_KEY'),
'tavily': (getattr(tool_cfg, 'tavily_api_key', None)
or os.getenv('TAVILY_API_KEY'))
if tool_cfg else os.getenv('TAVILY_API_KEY'),
}

# SerpApi provider (google, bing, baidu)
Expand Down Expand Up @@ -508,6 +517,9 @@ async def connect(self) -> None:
api_key=self._api_keys.get('serpapi'),
provider=self._serpapi_provider,
)
elif engine_type == 'tavily':
self._engines[engine_type] = engine_cls(
api_key=self._api_keys.get('tavily'))
else: # arxiv
self._engines[engine_type] = engine_cls()

Expand Down
Loading