class TavilySearchRequestGenerator(SearchRequestGenerator):
    """Builds LLM-facing schemas and prompts for Tavily search requests.

    Supplies the example arguments template, the structured-output JSON
    schema and the query-rewrite prompt used to make an LLM emit
    well-formed Tavily search arguments, and converts the resulting
    argument dicts into ``TavilySearchRequest`` objects.
    """

    def get_args_template(self) -> str:
        """Return an example arguments JSON string shown to the LLM.

        ``research_goal`` is included because :meth:`get_json_schema`
        declares it as a required field; omitting it left the template
        inconsistent with the schema the LLM must satisfy.
        """
        return ('{"query": "xxx", "num_results": 5, '
                '"search_depth": "advanced", "topic": "general", '
                '"research_goal": "xxx"}')

    def get_json_schema(self,
                        num_queries: int,
                        is_strict: bool = True) -> Dict[str, Any]:
        """Return the JSON schema describing a list of Tavily queries.

        Args:
            num_queries: Maximum number of queries the LLM may generate
                (only mentioned in the description; not enforced here).
            is_strict: Whether strict structured-output mode is requested.
        """
        return {
            'name': 'search_requests',
            'strict': is_strict,
            'schema': {
                'type': 'array',
                'items': {
                    'type': 'object',
                    'properties': {
                        'query': {
                            'type': 'string',
                            'description': (
                                'Write a concise, keyword-rich search query optimized for web search. '
                                'Keep under 400 characters for best results. Follow these rules:\n'
                                '1) Use exact-match quotes for key phrases (e.g., "contrastive learning").\n'
                                '2) Combine terms naturally; avoid overly complex Boolean syntax.\n'
                                '3) Keep it concise and focused on the core research intent.\n'
                                '4) Chinese queries are supported directly.\n\n'
                                'Examples:\n'
                                '- "retrieval augmented generation" evaluation benchmarks\n'
                                '- large language model medical applications 2024\n'
                                '- 大模型 函数调用 工具调用\n')
                        },
                        'num_results': {
                            'type': 'integer',
                            'description': 'The number of results to return (1-10). '
                            'Choose a value appropriate to the query complexity (e.g., 5)',
                        },
                        'search_depth': {
                            'type': 'string',
                            'enum': ['basic', 'advanced'],
                            'description': 'Search depth. "basic" for quick results, '
                            '"advanced" for higher relevance. Default is "advanced".',
                        },
                        'topic': {
                            'type': 'string',
                            'enum': ['general', 'news', 'finance'],
                            'description': 'Topic category for the search. Default is "general".',
                        },
                        'research_goal': {
                            'type': 'string',
                            'description': 'The goal of the research and additional research directions'
                        }
                    },
                    'required': ['query', 'num_results', 'research_goal']
                },
                'description': f'List of Tavily search queries, max of {num_queries}'
            }
        }

    def get_rewrite_prompt(self) -> str:
        """Return the (Chinese) prompt that instructs the LLM to rewrite
        the user's question into Tavily search arguments matching
        :meth:`get_args_template`."""
        return (
            f'生成search request,具体要求为: '
            f'\n1. 必须符合以下arguments格式:{self.get_args_template()}'
            f'\n2. 其中,query参数的值通过分析用户原始输入中的有效问题部分生成,即{self.user_prompt},要求为精简的关键词查询,'
            f'例如,用户输入"请帮我查找2023年发表的关于大语言模型在医疗领域应用的最新研究",则query参数的值应为"large language model medical applications 2023";'
            f'\n3. 参数需要符合搜索引擎的要求,num_results需要根据实际问题的复杂程度来估算,最大10,最小1;'
            f'\n4. search_depth参数默认为"advanced",如需快速结果可设为"basic";'
            f'\n5. topic参数默认为"general",如搜索新闻可设为"news",搜索金融相关可设为"finance"')

    def create_request(self,
                       search_request_d: Dict[str, Any]) -> TavilySearchRequest:
        """Build a ``TavilySearchRequest`` from an LLM-produced dict.

        Keys that are not ``TavilySearchRequest`` fields are dropped first —
        notably ``research_goal``, which the schema requires but which is not
        a Tavily API parameter (presumably consumed by upstream planning
        logic; confirm against callers).
        """
        valid_keys = {'query', 'num_results', 'search_depth', 'topic',
                      'include_domains', 'exclude_domains'}
        filtered = {k: v for k, v in search_request_d.items() if k in valid_keys}
        return TavilySearchRequest(**filtered)
@dataclass
class TavilySearchRequest:
    """Parameters for a single Tavily API search call."""

    # The search query string
    query: str

    # Number of results to return, default is 5
    num_results: Optional[int] = 5

    # Search depth: 'basic' or 'advanced'
    search_depth: Optional[str] = 'advanced'

    # Topic category: 'general', 'news', or 'finance'
    topic: Optional[str] = 'general'

    # Domains to include in search
    include_domains: Optional[List[str]] = None

    # Domains to exclude from search
    exclude_domains: Optional[List[str]] = None

    def to_dict(self) -> Dict[str, Any]:
        """Convert the request to keyword arguments for the Tavily client.

        ``num_results`` is mapped to the API's ``max_results`` name; the
        domain filters are emitted only when non-empty.
        """
        params: Dict[str, Any] = {
            'query': self.query,
            'max_results': self.num_results,
            'search_depth': self.search_depth,
            'topic': self.topic,
        }
        for name, value in (('include_domains', self.include_domains),
                            ('exclude_domains', self.exclude_domains)):
            if value:
                params[name] = value
        return params

    def to_json(self) -> str:
        """Serialize the request parameters as a JSON string."""
        return json.dumps(self.to_dict(), ensure_ascii=False)


@dataclass
class TavilySearchResult:
    """Wrapper around a raw Tavily API search response."""

    # The original search query string
    query: str

    # Optional arguments used for the search request
    arguments: Dict[str, Any] = field(default_factory=dict)

    # The raw response from the Tavily search API
    response: Any = None

    def to_list(self) -> List[Dict[str, Any]]:
        """Normalize the raw response into a list of result dicts."""
        if not self.response or not self.response.get('results'):
            print('***Warning: No search results found.')
            return []

        if not self.query:
            print('***Warning: No query provided for search results.')
            return []

        return [{
            'url': item.get('url', ''),
            'id': item.get('url', ''),  # the URL doubles as the result id
            'title': item.get('title', ''),
            'summary': item.get('content', ''),
            'published_date': item.get('published_date', ''),
        } for item in self.response['results']]

    @staticmethod
    def load_from_disk(file_path: str) -> List[Dict[str, Any]]:
        """Load previously saved search results from a local JSON file.

        Returns an empty list when the file does not exist.
        """
        import os
        if not os.path.exists(file_path):
            return []

        with open(file_path, 'r', encoding='utf-8') as f:
            data = json.load(f)
        print(f'Search results loaded from {file_path}')

        return data
class TavilySearch(SearchEngine):
    """
    Search engine using Tavily API.

    Best for: AI-optimized web search with high relevance,
    real-time data retrieval for LLM applications.
    """

    engine_type = SearchEngineType.TAVILY

    def __init__(self, api_key: str = None):
        """Create a Tavily client.

        Args:
            api_key: Tavily API key; falls back to the ``TAVILY_API_KEY``
                environment variable.

        Raises:
            ValueError: If no API key is available. (Previously enforced
                with ``assert``, which is silently stripped under
                ``python -O``; an explicit raise keeps the check active.)
        """
        api_key = api_key or os.getenv('TAVILY_API_KEY')
        if not api_key:
            raise ValueError(
                'TAVILY_API_KEY must be set either as an argument or as an environment variable'
            )

        self.client = TavilyClient(api_key=api_key)

    def search(self, search_request: TavilySearchRequest) -> TavilySearchResult:
        """
        Perform a search using the Tavily API with the provided search request parameters.

        :param search_request: An instance of TavilySearchRequest containing search parameters.
        :return: An instance of TavilySearchResult containing the search results.
        :raises RuntimeError: If the underlying Tavily API call fails.
        """
        search_args: dict = search_request.to_dict()
        # The result object is created first so query/arguments are recorded
        # even though a failed API call raises before `response` is filled.
        search_result = TavilySearchResult(
            query=search_request.query,
            arguments=search_args,
        )
        try:
            search_result.response = self.client.search(**search_args)
        except Exception as e:
            raise RuntimeError(f'Failed to perform Tavily search: {e}') from e

        return search_result

    @classmethod
    def get_tool_definition(cls, server_name: str = 'web_search') -> 'Tool':
        """Return the tool definition for Tavily search engine."""
        # Imported lazily to avoid a hard dependency at module import time.
        from ms_agent.llm.utils import Tool
        return Tool(
            tool_name=cls.get_tool_name(),
            server_name=server_name,
            description=(
                'Search the web using Tavily AI search engine. '
                'Best for: AI-optimized web search with high relevance, '
                'real-time data retrieval for LLM applications.'),
            parameters={
                'type': 'object',
                'properties': {
                    'query': {
                        'type': 'string',
                        'description':
                        'The search query. Keep under 400 characters for best results.',
                    },
                    'num_results': {
                        'type': 'integer',
                        'minimum': 1,
                        'maximum': 10,
                        'description':
                        'Number of results to return. Default is 5.',
                    },
                    'search_depth': {
                        'type': 'string',
                        'enum': ['basic', 'advanced'],
                        'description':
                        ('Search depth. "basic" for quick results, '
                         '"advanced" for higher relevance. '
                         'Default is "advanced".'),
                    },
                    'topic': {
                        'type': 'string',
                        'enum': ['general', 'news', 'finance'],
                        'description':
                        ('Topic category for the search. '
                         'Default is "general".'),
                    },
                },
                'required': ['query'],
            },
        )

    @classmethod
    def build_request_from_args(cls, **kwargs) -> TavilySearchRequest:
        """Build TavilySearchRequest from tool call arguments.

        Applies the defaults documented in the tool definition for any
        omitted optional parameter; ``query`` is mandatory.
        """
        return TavilySearchRequest(
            query=kwargs['query'],
            num_results=kwargs.get('num_results', 5),
            search_depth=kwargs.get('search_depth', 'advanced'),
            topic=kwargs.get('topic', 'general'),
        )
Args: - engine_type: One of 'exa', 'serpapi', 'arxiv' + engine_type: One of 'exa', 'serpapi', 'arxiv', 'tavily' Returns: SearchEngine class (not instance) @@ -206,6 +206,9 @@ def get_search_engine_class(engine_type: str) -> Type[SearchEngine]: elif engine_type == 'arxiv': from ms_agent.tools.search.arxiv import ArxivSearch return ArxivSearch + elif engine_type == 'tavily': + from ms_agent.tools.search.tavily import TavilySearch + return TavilySearch else: logger.warning( f"Unknown search engine '{engine_type}', falling back to arxiv") @@ -220,7 +223,7 @@ def get_search_engine(engine_type: str, Get search engine instance by type. Args: - engine_type: One of 'exa', 'serpapi', 'arxiv' + engine_type: One of 'exa', 'serpapi', 'arxiv', 'tavily' api_key: API key for the search engine (if required) **kwargs: Additional arguments passed to engine constructor """ @@ -241,6 +244,9 @@ def get_search_engine(engine_type: str, elif engine_type == 'arxiv': from ms_agent.tools.search.arxiv import ArxivSearch return ArxivSearch() + elif engine_type == 'tavily': + from ms_agent.tools.search.tavily import TavilySearch + return TavilySearch(api_key=api_key or os.getenv('TAVILY_API_KEY')) else: logger.warning( f"Unknown search engine '{engine_type}', falling back to arxiv") @@ -265,7 +271,7 @@ def build_search_request(engine_type: str, class WebSearchTool(ToolBase): """ Unified web search tool for agents. It can search the web and fetch page content. - - Search via multiple engines (Exa, SerpAPI, Arxiv) + - Search via multiple engines (Exa, SerpAPI, Arxiv, Tavily) - Dynamic tool definitions based on configured engines - Auto-fetch and parse page content - Configurable content fetcher (jina_reader, docling, etc.) 
@@ -296,7 +302,7 @@ class WebSearchTool(ToolBase): SERVER_NAME = 'web_search' # Registry of supported search engines - SUPPORTED_ENGINES = ('exa', 'serpapi', 'arxiv') + SUPPORTED_ENGINES = ('exa', 'serpapi', 'arxiv', 'tavily') # Process-wide (class-level) usage tracking for summarization calls. # This is intentionally separate from LLMAgent usage totals. @@ -404,6 +410,9 @@ def __init__(self, config, **kwargs): 'serpapi': (getattr(tool_cfg, 'serpapi_api_key', None) or os.getenv('SERPAPI_API_KEY')) if tool_cfg else os.getenv('SERPAPI_API_KEY'), + 'tavily': (getattr(tool_cfg, 'tavily_api_key', None) + or os.getenv('TAVILY_API_KEY')) + if tool_cfg else os.getenv('TAVILY_API_KEY'), } # SerpApi provider (google, bing, baidu) @@ -508,6 +517,9 @@ async def connect(self) -> None: api_key=self._api_keys.get('serpapi'), provider=self._serpapi_provider, ) + elif engine_type == 'tavily': + self._engines[engine_type] = engine_cls( + api_key=self._api_keys.get('tavily')) else: # arxiv self._engines[engine_type] = engine_cls() diff --git a/ms_agent/tools/search_engine.py b/ms_agent/tools/search_engine.py index 6bb0be4d6..127d9f181 100644 --- a/ms_agent/tools/search_engine.py +++ b/ms_agent/tools/search_engine.py @@ -8,6 +8,7 @@ from ms_agent.tools.search.exa import ExaSearch from ms_agent.tools.search.search_base import SearchEngineType from ms_agent.tools.search.serpapi import SerpApiSearch +from ms_agent.tools.search.tavily import TavilySearch from ms_agent.utils.logger import get_logger logger = get_logger() @@ -23,7 +24,8 @@ def set_search_env_overrides(env_overrides: Optional[Dict[str, str]]) -> None: Expected keys (all optional): - 'EXA_API_KEY' - 'SERPAPI_API_KEY' - - SEARCH_ENGINE_OVERRIDE_ENV (e.g. 'exa' / 'serpapi' / 'arxiv') + - 'TAVILY_API_KEY' + - SEARCH_ENGINE_OVERRIDE_ENV (e.g. 
'exa' / 'serpapi' / 'arxiv' / 'tavily') """ if not env_overrides: if hasattr(_search_env_local, 'overrides'): @@ -135,7 +137,8 @@ def get_web_search_tool(config_file: str): or '')).strip().lower() if engine_override and engine_override in (SearchEngineType.EXA.value, SearchEngineType.SERPAPI.value, - SearchEngineType.ARXIV.value): + SearchEngineType.ARXIV.value, + SearchEngineType.TAVILY.value): search_config['engine'] = engine_override engine_name = (search_config.get('engine', '') or '').lower() @@ -143,6 +146,7 @@ def get_web_search_tool(config_file: str): # Per-request API key overrides (thread-local) take precedence override_exa_key = local_env.get('EXA_API_KEY') override_serp_key = local_env.get('SERPAPI_API_KEY') + override_tavily_key = local_env.get('TAVILY_API_KEY') if engine_name == SearchEngineType.EXA.value: return ExaSearch( @@ -153,7 +157,14 @@ def get_web_search_tool(config_file: str): api_key=override_serp_key or search_config.get( 'serpapi_api_key', os.getenv('SERPAPI_API_KEY', None)), provider=search_config.get('provider', 'google').lower()) + elif engine_name == SearchEngineType.TAVILY.value: + return TavilySearch( + api_key=override_tavily_key or search_config.get( + 'tavily_api_key', os.getenv('TAVILY_API_KEY', None))) elif engine_name == SearchEngineType.ARXIV.value: return ArxivSearch() else: + logger.warning( + f'Unknown search engine "{engine_name}", falling back to ArxivSearch. 
' + f'Valid engines: {[e.value for e in SearchEngineType]}') return ArxivSearch() diff --git a/projects/deep_research/.env.example b/projects/deep_research/.env.example index 59ec0d03b..2fbd9d820 100644 --- a/projects/deep_research/.env.example +++ b/projects/deep_research/.env.example @@ -1,5 +1,6 @@ EXA_API_KEY=xxx SERPAPI_API_KEY=xxx +TAVILY_API_KEY=xxx OPENAI_API_KEY=xxx OPENAI_BASE_URL=https://your-openai-compatible-endpoint/v1 diff --git a/projects/deep_research/conf.yaml b/projects/deep_research/conf.yaml index 9f2236cf9..18d2207d4 100644 --- a/projects/deep_research/conf.yaml +++ b/projects/deep_research/conf.yaml @@ -1,4 +1,4 @@ -## Search Engine, Supported values: EXA, SERPAPI, ARXIV ## +## Search Engine, Supported values: EXA, SERPAPI, ARXIV, TAVILY ## #SEARCH_ENGINE: # engine: exa @@ -9,5 +9,11 @@ # serpapi_api_key: $SERPAPI_API_KEY # provider: google +# NOTE: Tavily requires tavily-python package (pip install tavily-python) +# and a valid TAVILY_API_KEY in your .env file. +#SEARCH_ENGINE: +# engine: tavily +# tavily_api_key: $TAVILY_API_KEY + SEARCH_ENGINE: engine: arxiv diff --git a/requirements/research.txt b/requirements/research.txt index 67ce2d3fd..5863522d5 100644 --- a/requirements/research.txt +++ b/requirements/research.txt @@ -3,6 +3,7 @@ docling<=2.38.1 docling-core<=2.38.2 exa-py google-search-results +tavily-python gradio>=5.0.0 json5 markdown