-
Notifications
You must be signed in to change notification settings - Fork 478
feat: add Tavily search engine option to deep_research configuration #892
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,3 @@ | ||
| # flake8: noqa | ||
| from ms_agent.tools.search.tavily.schema import TavilySearchRequest, TavilySearchResult | ||
| from ms_agent.tools.search.tavily.search import TavilySearch |
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
| @@ -0,0 +1,93 @@ | ||||||
| # flake8: noqa | ||||||
| from dataclasses import dataclass, field | ||||||
| from typing import Any, Dict, List, Optional | ||||||
|
|
||||||
| import json | ||||||
|
|
||||||
|
|
||||||
| @dataclass | ||||||
| class TavilySearchRequest: | ||||||
|
|
||||||
| # The search query string | ||||||
| query: str | ||||||
|
|
||||||
| # Number of results to return, default is 5 | ||||||
| num_results: Optional[int] = 5 | ||||||
|
|
||||||
| # Search depth: 'basic' or 'advanced' | ||||||
| search_depth: Optional[str] = 'advanced' | ||||||
|
|
||||||
| # Topic category: 'general', 'news', or 'finance' | ||||||
| topic: Optional[str] = 'general' | ||||||
|
|
||||||
| # Domains to include in search | ||||||
| include_domains: Optional[List[str]] = None | ||||||
|
|
||||||
| # Domains to exclude from search | ||||||
| exclude_domains: Optional[List[str]] = None | ||||||
|
|
||||||
| def to_dict(self) -> Dict[str, Any]: | ||||||
| """Convert the request parameters to a dictionary.""" | ||||||
| d = { | ||||||
| 'query': self.query, | ||||||
| 'max_results': self.num_results, | ||||||
| 'search_depth': self.search_depth, | ||||||
| 'topic': self.topic, | ||||||
| } | ||||||
| if self.include_domains: | ||||||
| d['include_domains'] = self.include_domains | ||||||
| if self.exclude_domains: | ||||||
| d['exclude_domains'] = self.exclude_domains | ||||||
| return d | ||||||
|
|
||||||
| def to_json(self) -> str: | ||||||
| """Convert the request parameters to a JSON string.""" | ||||||
| return json.dumps(self.to_dict(), ensure_ascii=False) | ||||||
|
|
||||||
|
|
||||||
| @dataclass | ||||||
| class TavilySearchResult: | ||||||
|
|
||||||
| # The original search query string | ||||||
| query: str | ||||||
|
|
||||||
| # Optional arguments for the search request | ||||||
| arguments: Dict[str, Any] = field(default_factory=dict) | ||||||
|
|
||||||
| # The raw response from the Tavily search API | ||||||
| response: Any = None | ||||||
|
|
||||||
| def to_list(self) -> List[Dict[str, Any]]: | ||||||
| """Convert the search results to a list of dictionaries.""" | ||||||
| if not self.response or not self.response.get('results'): | ||||||
| print('***Warning: No search results found.') | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Using You can add
Suggested change
|
||||||
| return [] | ||||||
|
|
||||||
| if not self.query: | ||||||
| print('***Warning: No query provided for search results.') | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||||||
| return [] | ||||||
|
|
||||||
| res_list: List[Dict[str, Any]] = [] | ||||||
| for res in self.response['results']: | ||||||
| res_list.append({ | ||||||
| 'url': res.get('url', ''), | ||||||
| 'id': res.get('url', ''), | ||||||
| 'title': res.get('title', ''), | ||||||
| 'summary': res.get('content', ''), | ||||||
| 'published_date': res.get('published_date', ''), | ||||||
| }) | ||||||
|
|
||||||
| return res_list | ||||||
|
|
||||||
| @staticmethod | ||||||
| def load_from_disk(file_path: str) -> List[Dict[str, Any]]: | ||||||
| """Load search results from a local file.""" | ||||||
| import os | ||||||
| if not os.path.exists(file_path): | ||||||
| return [] | ||||||
|
|
||||||
| with open(file_path, 'r', encoding='utf-8') as f: | ||||||
| data = json.load(f) | ||||||
| print(f'Search results loaded from {file_path}') | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||||||
|
|
||||||
| return data | ||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,108 @@ | ||
| # flake8: noqa | ||
| import os | ||
| from typing import TYPE_CHECKING | ||
|
|
||
| from tavily import TavilyClient | ||
| from ms_agent.tools.search.tavily.schema import TavilySearchRequest, TavilySearchResult | ||
| from ms_agent.tools.search.search_base import SearchEngine, SearchEngineType | ||
|
|
||
| if TYPE_CHECKING: | ||
| from ms_agent.llm.utils import Tool | ||
|
|
||
|
|
||
| class TavilySearch(SearchEngine): | ||
| """ | ||
| Search engine using Tavily API. | ||
|
|
||
| Best for: AI-optimized web search with high relevance, | ||
| real-time data retrieval for LLM applications. | ||
| """ | ||
|
|
||
| engine_type = SearchEngineType.TAVILY | ||
|
|
||
| def __init__(self, api_key: str = None): | ||
| api_key = api_key or os.getenv('TAVILY_API_KEY') | ||
| assert api_key, 'TAVILY_API_KEY must be set either as an argument or as an environment variable' | ||
|
|
||
| self.client = TavilyClient(api_key=api_key) | ||
|
|
||
| def search(self, search_request: TavilySearchRequest) -> TavilySearchResult: | ||
| """ | ||
| Perform a search using the Tavily API with the provided search request parameters. | ||
|
|
||
| :param search_request: An instance of TavilySearchRequest containing search parameters. | ||
| :return: An instance of TavilySearchResult containing the search results. | ||
| """ | ||
| search_args: dict = search_request.to_dict() | ||
| search_result = TavilySearchResult( | ||
| query=search_request.query, | ||
| arguments=search_args, | ||
| ) | ||
| try: | ||
| search_result.response = self.client.search(**search_args) | ||
| except Exception as e: | ||
| raise RuntimeError(f'Failed to perform Tavily search: {e}') from e | ||
|
|
||
| return search_result | ||
|
|
||
| @classmethod | ||
| def get_tool_definition(cls, server_name: str = 'web_search') -> 'Tool': | ||
| """Return the tool definition for Tavily search engine.""" | ||
| from ms_agent.llm.utils import Tool | ||
| return Tool( | ||
| tool_name=cls.get_tool_name(), | ||
| server_name=server_name, | ||
| description=( | ||
| 'Search the web using Tavily AI search engine. ' | ||
| 'Best for: AI-optimized web search with high relevance, ' | ||
| 'real-time data retrieval for LLM applications.'), | ||
| parameters={ | ||
| 'type': 'object', | ||
| 'properties': { | ||
| 'query': { | ||
| 'type': | ||
| 'string', | ||
| 'description': | ||
| 'The search query. Keep under 400 characters for best results.', | ||
| }, | ||
| 'num_results': { | ||
| 'type': | ||
| 'integer', | ||
| 'minimum': | ||
| 1, | ||
| 'maximum': | ||
| 10, | ||
| 'description': | ||
| 'Number of results to return. Default is 5.', | ||
| }, | ||
| 'search_depth': { | ||
| 'type': | ||
| 'string', | ||
| 'enum': ['basic', 'advanced'], | ||
| 'description': | ||
| ('Search depth. "basic" for quick results, ' | ||
| '"advanced" for higher relevance. ' | ||
| 'Default is "advanced".'), | ||
| }, | ||
| 'topic': { | ||
| 'type': | ||
| 'string', | ||
| 'enum': ['general', 'news', 'finance'], | ||
| 'description': | ||
| ('Topic category for the search. ' | ||
| 'Default is "general".'), | ||
| }, | ||
| }, | ||
| 'required': ['query'], | ||
| }, | ||
| ) | ||
|
|
||
| @classmethod | ||
| def build_request_from_args(cls, **kwargs) -> TavilySearchRequest: | ||
| """Build TavilySearchRequest from tool call arguments.""" | ||
| return TavilySearchRequest( | ||
| query=kwargs['query'], | ||
| num_results=kwargs.get('num_results', 5), | ||
| search_depth=kwargs.get('search_depth', 'advanced'), | ||
| topic=kwargs.get('topic', 'general'), | ||
| ) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The current implementation filters out the
research_goalfield, which is defined as a required field inget_json_schema. This is inconsistent with other request generators and likely causes a bug, asresearch_goalis probably used by downstream components.To fix this, you can remove this filtering logic. This will also require adding
research_goal: Optional[str] = Noneto theTavilySearchRequestdataclass inms_agent/tools/search/tavily/schema.pyto avoid an error when instantiating the dataclass.