-
Notifications
You must be signed in to change notification settings - Fork 1.4k
Add FileSearchTool with support for OpenAI and Google
#3396
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
4376b96
6cec96f
4c3fe56
3c8decf
2343679
666a1bb
7365e20
2ee21c9
deef1ec
18b4b86
11654ed
1542f5c
7d683b7
d8ef07d
6acbd76
380e25c
c83f125
8eba82d
b3a8930
19f32f9
00ea1ed
c6ed56c
9b5bb54
c2765ac
8286cd7
3011e05
bc278e8
5f694c9
8dc7c17
8216f31
ffcb21f
bc3ac7a
977ab53
68bafb6
29f8da0
eef4526
8cc3d60
b77d857
db475c6
fd62e29
13abf31
129dacd
8d3f359
065c711
50ad873
ff22a6d
1f249aa
c7798d1
9696c01
2fe8791
1c0589b
122be13
925a909
5cc9fb0
8d6a6e7
68738fd
99af405
0c7c582
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -17,6 +17,7 @@ | |
| 'ImageGenerationTool', | ||
| 'MemoryTool', | ||
| 'MCPServerTool', | ||
| 'FileSearchTool', | ||
| ) | ||
|
|
||
| _BUILTIN_TOOL_TYPES: dict[str, type[AbstractBuiltinTool]] = {} | ||
|
|
@@ -334,6 +335,30 @@ def unique_id(self) -> str: | |
| return ':'.join([self.kind, self.id]) | ||
|
|
||
|
|
||
| @dataclass(kw_only=True) | ||
| class FileSearchTool(AbstractBuiltinTool): | ||
| """A builtin tool that allows your agent to search through uploaded files using vector search. | ||
|
|
||
| This tool provides a fully managed Retrieval-Augmented Generation (RAG) system that handles | ||
| file storage, chunking, embedding generation, and context injection into prompts. | ||
|
|
||
| Supported by: | ||
|
|
||
| * OpenAI Responses | ||
| * Google (Gemini) | ||
DouweM marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| """ | ||
|
|
||
| file_store_ids: list[str] | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Let's make this a set |
||
| """List of file store IDs to search through. | ||
|
|
||
| For OpenAI, these are the IDs of vector stores created via the OpenAI API. | ||
| For Google, these are file search store names that have been uploaded and processed via the Gemini Files API. | ||
| """ | ||
|
|
||
| kind: str = 'file_search' | ||
| """The kind of tool.""" | ||
|
|
||
|
|
||
| def _tool_discriminator(tool_data: dict[str, Any] | AbstractBuiltinTool) -> str: | ||
| if isinstance(tool_data, dict): | ||
| return tool_data.get('kind', AbstractBuiltinTool.kind) | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -13,7 +13,7 @@ | |
| from .. import UnexpectedModelBehavior, _utils, usage | ||
| from .._output import OutputObjectDefinition | ||
| from .._run_context import RunContext | ||
| from ..builtin_tools import CodeExecutionTool, ImageGenerationTool, UrlContextTool, WebSearchTool | ||
| from ..builtin_tools import CodeExecutionTool, FileSearchTool, ImageGenerationTool, UrlContextTool, WebSearchTool | ||
| from ..exceptions import ModelAPIError, ModelHTTPError, UserError | ||
| from ..messages import ( | ||
| BinaryContent, | ||
|
|
@@ -63,6 +63,7 @@ | |
| ExecutableCode, | ||
| ExecutableCodeDict, | ||
| FileDataDict, | ||
| FileSearchDict, | ||
| FinishReason as GoogleFinishReason, | ||
| FunctionCallDict, | ||
| FunctionCallingConfigDict, | ||
|
|
@@ -93,6 +94,7 @@ | |
| 'you can use the `google` optional group — `pip install "pydantic-ai-slim[google]"`' | ||
| ) from _import_error | ||
|
|
||
|
|
||
| LatestGoogleModelNames = Literal[ | ||
| 'gemini-flash-latest', | ||
| 'gemini-flash-lite-latest', | ||
|
|
@@ -350,6 +352,9 @@ def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[T | |
| tools.append(ToolDict(url_context=UrlContextDict())) | ||
| elif isinstance(tool, CodeExecutionTool): | ||
| tools.append(ToolDict(code_execution=ToolCodeExecutionDict())) | ||
| elif isinstance(tool, FileSearchTool): | ||
| file_search_config = FileSearchDict(file_search_store_names=tool.file_store_ids) | ||
| tools.append(ToolDict(file_search=file_search_config)) | ||
| elif isinstance(tool, ImageGenerationTool): # pragma: no branch | ||
| if not self.profile.supports_image_output: | ||
| raise UserError( | ||
|
|
@@ -652,6 +657,7 @@ class GeminiStreamedResponse(StreamedResponse): | |
| _timestamp: datetime | ||
| _provider_name: str | ||
| _provider_url: str | ||
| _file_search_tool_call_id: str | None = field(default=None, init=False) | ||
|
|
||
| async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: # noqa: C901 | ||
| code_execution_tool_call_id: str | None = None | ||
|
|
@@ -697,6 +703,26 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: | |
| continue # pragma: no cover | ||
|
|
||
| for part in parts: | ||
| if self._file_search_tool_call_id and candidate.grounding_metadata: | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
| grounding_chunks = candidate.grounding_metadata.grounding_chunks | ||
| if grounding_chunks: | ||
| retrieved_contexts = [ | ||
| chunk.retrieved_context.model_dump(mode='json') | ||
| for chunk in grounding_chunks | ||
| if chunk.retrieved_context | ||
| ] | ||
| if retrieved_contexts: | ||
| yield self._parts_manager.handle_part( | ||
| vendor_part_id=uuid4(), | ||
| part=BuiltinToolReturnPart( | ||
| provider_name=self.provider_name, | ||
| tool_name=FileSearchTool.kind, | ||
| tool_call_id=self._file_search_tool_call_id, | ||
| content={'retrieved_contexts': retrieved_contexts}, | ||
| ), | ||
| ) | ||
| self._file_search_tool_call_id = None | ||
|
|
||
| provider_details: dict[str, Any] | None = None | ||
| if part.thought_signature: | ||
| # Per https://ai.google.dev/gemini-api/docs/function-calling?example=meeting#thought-signatures: | ||
|
|
@@ -739,10 +765,27 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: | |
| part=FilePart(content=BinaryContent.narrow_type(content), provider_details=provider_details), | ||
| ) | ||
| elif part.executable_code is not None: | ||
| code_execution_tool_call_id = _utils.generate_tool_call_id() | ||
| part = _map_executable_code(part.executable_code, self.provider_name, code_execution_tool_call_id) | ||
| part.provider_details = provider_details | ||
| yield self._parts_manager.handle_part(vendor_part_id=uuid4(), part=part) | ||
| code = part.executable_code.code | ||
| if code and (file_search_query := _extract_file_search_query(code)): | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Let's check if the file search builtin tool was included before we do this
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. And please move this to a method |
||
| self._file_search_tool_call_id = _utils.generate_tool_call_id() | ||
| part_obj = BuiltinToolCallPart( | ||
| provider_name=self.provider_name, | ||
| tool_name=FileSearchTool.kind, | ||
| tool_call_id=self._file_search_tool_call_id, | ||
| args={'query': file_search_query}, | ||
| ) | ||
| part_obj.provider_details = provider_details | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This line and the next one can stay out of the new method |
||
| yield self._parts_manager.handle_part( | ||
| vendor_part_id=uuid4(), | ||
| part=part_obj, | ||
| ) | ||
| else: | ||
| code_execution_tool_call_id = _utils.generate_tool_call_id() | ||
| part_obj = _map_executable_code( | ||
| part.executable_code, self.provider_name, code_execution_tool_call_id | ||
| ) | ||
| part_obj.provider_details = provider_details | ||
| yield self._parts_manager.handle_part(vendor_part_id=uuid4(), part=part_obj) | ||
| elif part.code_execution_result is not None: | ||
| assert code_execution_tool_call_id is not None | ||
| part = _map_code_execution_result( | ||
|
|
@@ -856,6 +899,11 @@ def _process_response_from_parts( | |
| items.append(web_search_call) | ||
| items.append(web_search_return) | ||
|
|
||
| file_search_call, file_search_return = _map_file_search_grounding_metadata(grounding_metadata, provider_name) | ||
| if file_search_call and file_search_return: | ||
| items.append(file_search_call) | ||
| items.append(file_search_return) | ||
|
|
||
| item: ModelResponsePart | None = None | ||
| code_execution_tool_call_id: str | None = None | ||
| for part in parts: | ||
|
|
@@ -1007,3 +1055,47 @@ def _map_grounding_metadata( | |
| ) | ||
| else: | ||
| return None, None | ||
|
|
||
|
|
||
| def _map_file_search_grounding_metadata( | ||
| grounding_metadata: GroundingMetadata | None, provider_name: str | ||
| ) -> tuple[BuiltinToolCallPart, BuiltinToolReturnPart] | tuple[None, None]: | ||
| if not grounding_metadata or not (grounding_chunks := grounding_metadata.grounding_chunks): | ||
| return None, None | ||
|
|
||
| retrieved_contexts = [ | ||
| chunk.retrieved_context.model_dump(mode='json') for chunk in grounding_chunks if chunk.retrieved_context | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This seems duplicated with the stuff above |
||
| ] | ||
|
|
||
| if not retrieved_contexts: | ||
| return None, None | ||
|
|
||
| tool_call_id = _utils.generate_tool_call_id() | ||
| return ( | ||
| BuiltinToolCallPart( | ||
| provider_name=provider_name, | ||
| tool_name=FileSearchTool.kind, | ||
| tool_call_id=tool_call_id, | ||
| args={}, | ||
| ), | ||
| BuiltinToolReturnPart( | ||
| provider_name=provider_name, | ||
| tool_name=FileSearchTool.kind, | ||
| tool_call_id=tool_call_id, | ||
| content={'retrieved_contexts': retrieved_contexts}, | ||
| ), | ||
| ) | ||
|
|
||
|
|
||
| def _extract_file_search_query(code: str) -> str | None: | ||
| """Extract the query from file_search.query() executable code. | ||
|
|
||
| Example: 'print(file_search.query(query="what is the capital of France?"))' | ||
| Returns: 'what is the capital of France?' | ||
| """ | ||
| import re | ||
|
|
||
| match = re.search(r'file_search\.query\(query=(["\'])(.+?)\1\)', code) | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This will break on slash-escaped quotes
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We should compile this regex outside of the method |
||
| if match: | ||
| return match.group(2) | ||
| return None | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This was actually an incorrect find/replace, because here we're passing OpenAI's own types via
OpenAIResponsesModelSettings.openai_builtin_tools. That's a good reminder that lines 137 and 139 in this file should also be updated now that File search is natively supported. That means this example should be changed to the
ComputerToolParam