From b8ba9177558f4fd47be218e477eff483764a6687 Mon Sep 17 00:00:00 2001 From: fcogidi <41602287+fcogidi@users.noreply.github.com> Date: Fri, 31 Oct 2025 12:33:27 -0400 Subject: [PATCH 1/9] Add web search functionality with Gemini integration - Updated .env.example to include WEB_SEARCH_BASE_URL and WEB_SEARCH_API_KEY. - Introduced web_search.py to implement a tool for fetching Google Search grounded responses from the Gemini model. --- .env.example | 3 + src/utils/tools/web_search.py | 173 ++++++++++++++++++++++++++++++++++ 2 files changed, 176 insertions(+) create mode 100644 src/utils/tools/web_search.py diff --git a/.env.example b/.env.example index 88bd87b..449e696 100644 --- a/.env.example +++ b/.env.example @@ -23,3 +23,6 @@ WEAVIATE_GRPC_SECURE="true" # set to false for localhost # Optionally, specify E2B.dev API key for Python Code Interpreter E2B_API_KEY="e2b_..." + +WEB_SEARCH_BASE_URL="..." +WEB_SEARCH_API_KEY="..." diff --git a/src/utils/tools/web_search.py b/src/utils/tools/web_search.py new file mode 100644 index 0000000..c6c04a3 --- /dev/null +++ b/src/utils/tools/web_search.py @@ -0,0 +1,173 @@ +"""Implements a tool for fetch Google Search grounded responses from Gemini.""" + +import asyncio +import os +from collections.abc import Mapping +from typing import Literal + +import backoff +import httpx +from pydantic import BaseModel +from pydantic.fields import Field + + +class ModelSettings(BaseModel): + """Configuration for the Gemini model used for web search.""" + + model: Literal["gemini-2.5-pro", "gemini-2.5-flash", "gemini-2.5-flash-lite"] = ( + "gemini-2.5-flash" + ) + temperature: float | None = Field(default=0.2, ge=0, le=2) + max_output_tokens: int | None = Field(default=None, ge=1) + seed: int | None = None + thinking_budget: int | None = Field(default=-1, ge=-1) + + +class Response(BaseModel): + """Response returned by Gemini.""" + + text_with_citations: str + web_search_queries: list[str] + raw_response: Mapping[str, object] + + +class GeminiGroundingWithGoogleSearch: + """Tool for fetching Google Search grounded responses from Gemini via a proxy. + + Parameters + ---------- + base_url : str, optional, default=None + Base URL for the Gemini proxy. Defaults to the value of the + ``WEB_SEARCH_BASE_URL`` environment variable. + api_key : str, optional, default=None + API key for the Gemini proxy. Defaults to the value of the + ``WEB_SEARCH_API_KEY`` environment variable. + model_settings : ModelSettings, optional, default=None + Settings for the Gemini model used for web search. + max_concurrency : int, optional, default=5 + Maximum number of concurrent Gemini requests. + timeout : int, optional, default=300 + Timeout for requests to the server. + """ + + def __init__( + self, + base_url: str | None = None, + api_key: str | None = None, + *, + model_settings: ModelSettings | None = None, + max_concurrency: int = 5, + timeout: int = 300, + ) -> None: + self.base_url = base_url or os.getenv("WEB_SEARCH_BASE_URL") + self.api_key = api_key or os.getenv("WEB_SEARCH_API_KEY") + self.model_settings = model_settings or ModelSettings() + + self._semaphore = asyncio.Semaphore(max_concurrency) + + self._client = httpx.AsyncClient( + timeout=timeout, headers={"X-API-Key": self.api_key} + ) + self._endpoint = f"{self.base_url}/api/v1/grounding_with_search" + + async def get_web_search_grounded_response(self, query: str) -> Response: + """Get Google Search grounded response to query from Gemini model. + + This function calls a Gemini model with Google Search tool enabled. 
How + it works: + - The model analyzes the input query and determines if a Google Search + can improve the answer + - If needed, the model automatically generates one or multiple search + queries and executes them + - The model processes the search results, synthesizes the information, + and formulates a response. + - The API returns a final, user-friendly response that is grounded in + the search results + + Parameters + ---------- + query : str + Query to pass to Gemini. + + Returns + ------- + Response + Response returned by Gemini. This includes the text with citations added, + the web search queries executed (expanded from the input query), and the + raw response object from the API. + """ + # Payload + payload = self.model_settings.model_dump(exclude_unset=True) + payload["query"] = query + + # Call Gemini + response = await self._post(payload) + + try: + response.raise_for_status() + except httpx.HTTPStatusError as exc: + raise RuntimeError( + f"Gemini call failed with status {exc.response.status_code}" + ) from exc + + response_json = response.json() + text_with_citations = add_citations(response_json) + return Response( + text_with_citations=text_with_citations, + web_search_queries=response_json["web_search_queries"], + raw_response=response_json, + ) + + @backoff.on_exception( + backoff.expo, + (httpx.HTTPError, httpx.Timeout, httpx.RequestError), + jitter=backoff.full_jitter, + ) + async def _post(self, payload: dict[str, object]) -> httpx.Response: + async with self._semaphore: + return await self._client.post(self._endpoint, json=payload) + + +def add_citations(response: dict[str, object]) -> str: + """Add citations to the Gemini response. + + Code based on example in [1]_. + + Parameters + ---------- + response : dict of str to object + JSON response returned by Gemini. + + Returns + ------- + str + Text with citations added. + + References + ---------- + .. [1] https://ai.google.dev/gemini-api/docs/google-search#attributing_sources_with_inline_citations + """ + text = response["candidates"][0]["content"]["parts"][0]["text"] + supports = response["candidates"][0]["grounding_metadata"]["grounding_supports"] + chunks = response["candidates"][0]["grounding_metadata"]["grounding_chunks"] + + # Sort supports by end_index in descending order to avoid shifting issues + # when inserting. 
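    # (Handling the largest end_index first means each insertion happens
    # after all offsets still to be processed, so those offsets stay valid.)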
+ sorted_supports = sorted( + supports, key=lambda s: s["segment"]["end_index"], reverse=True + ) + + for support in sorted_supports: + end_index = support["segment"]["end_index"] + if support["grounding_chunk_indices"]: + # Create citation string like [1](link1)[2](link2) + citation_links = [] + for i in support["grounding_chunk_indices"]: + if i < len(chunks): + uri = chunks[i]["web"]["uri"] + citation_links.append(f"[{i + 1}]({uri})") + + citation_string = ", ".join(citation_links) + text = text[:end_index] + citation_string + text[end_index:] + + return text From 40df1665b9a169250103e4b367de26df03a41ea6 Mon Sep 17 00:00:00 2001 From: fcogidi <41602287+fcogidi@users.noreply.github.com> Date: Fri, 31 Oct 2025 13:08:04 -0400 Subject: [PATCH 2/9] Update docstring in GeminiGroundingWithGoogleSearch to include reference link --- src/utils/tools/web_search.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/src/utils/tools/web_search.py b/src/utils/tools/web_search.py index c6c04a3..6b3951d 100644 --- a/src/utils/tools/web_search.py +++ b/src/utils/tools/web_search.py @@ -74,15 +74,15 @@ async def get_web_search_grounded_response(self, query: str) -> Response: """Get Google Search grounded response to query from Gemini model. This function calls a Gemini model with Google Search tool enabled. How - it works: + it works [1]_: - The model analyzes the input query and determines if a Google Search - can improve the answer + can improve the answer. - If needed, the model automatically generates one or multiple search - queries and executes them + queries and executes them. - The model processes the search results, synthesizes the information, and formulates a response. - The API returns a final, user-friendly response that is grounded in - the search results + the search results. Parameters ---------- @@ -95,6 +95,10 @@ async def get_web_search_grounded_response(self, query: str) -> Response: Response returned by Gemini. This includes the text with citations added, the web search queries executed (expanded from the input query), and the raw response object from the API. + + References + ---------- + .. [1] https://ai.google.dev/gemini-api/docs/google-search#how_grounding_with_google_search_works """ # Payload payload = self.model_settings.model_dump(exclude_unset=True) From 80127c846c25ee7f20cb250029a6f58d216cbf11 Mon Sep 17 00:00:00 2001 From: fcogidi <41602287+fcogidi@users.noreply.github.com> Date: Fri, 31 Oct 2025 15:00:20 -0400 Subject: [PATCH 3/9] Refactor web search error handling, update POST method, fix web search query extraction - Introduced RETRYABLE_STATUS for managing retryable HTTP status codes. - Updated the _post method to _post_payload for clarity and added a docstring. - Enhanced error handling in the backoff decorator to include specific exceptions and retry logic based on RETRYABLE_STATUS. 
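For reference, a minimal standalone sketch of this retry pattern (illustrative
only, not the code in this diff; note that raise_for_status() must run inside
the retried coroutine for an HTTPStatusError to reach backoff):

    import backoff
    import httpx

    RETRYABLE_STATUS = {429, 500, 502, 503, 504}

    @backoff.on_exception(
        backoff.expo,
        (httpx.TimeoutException, httpx.NetworkError, httpx.HTTPStatusError),
        giveup=lambda exc: (
            isinstance(exc, httpx.HTTPStatusError)
            and exc.response.status_code not in RETRYABLE_STATUS
        ),
        jitter=backoff.full_jitter,
        max_tries=5,
    )
    async def post_with_retry(
        client: httpx.AsyncClient, url: str, payload: dict
    ) -> httpx.Response:
        response = await client.post(url, json=payload)
        # Raise on 4xx/5xx so retryable statuses propagate to backoff; the
        # giveup predicate filters out non-retryable ones (e.g. 400, 404).
        response.raise_for_status()
        return response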
--- src/utils/tools/web_search.py | 24 ++++++++++++++++++++---- 1 file changed, 20 insertions(+), 4 deletions(-) diff --git a/src/utils/tools/web_search.py b/src/utils/tools/web_search.py index 6b3951d..e95c908 100644 --- a/src/utils/tools/web_search.py +++ b/src/utils/tools/web_search.py @@ -11,6 +11,9 @@ from pydantic.fields import Field +RETRYABLE_STATUS = {429, 500, 502, 503, 504} + + class ModelSettings(BaseModel): """Configuration for the Gemini model used for web search.""" @@ -105,7 +108,7 @@ async def get_web_search_grounded_response(self, query: str) -> Response: payload["query"] = query # Call Gemini - response = await self._post(payload) + response = await self._post_payload(payload) try: response.raise_for_status() @@ -116,18 +119,31 @@ async def get_web_search_grounded_response(self, query: str) -> Response: response_json = response.json() text_with_citations = add_citations(response_json) + return Response( text_with_citations=text_with_citations, - web_search_queries=response_json["web_search_queries"], + web_search_queries=response_json["candidates"][0]["grounding_metadata"][ + "web_search_queries" + ], raw_response=response_json, ) @backoff.on_exception( backoff.expo, - (httpx.HTTPError, httpx.Timeout, httpx.RequestError), + ( + httpx.TimeoutException, + httpx.NetworkError, + httpx.HTTPStatusError, # only retry codes in RETRYABLE_STATUS + ), + giveup=lambda exc: ( + isinstance(exc, httpx.HTTPStatusError) + and exc.response.status_code not in RETRYABLE_STATUS + ), jitter=backoff.full_jitter, + max_tries=5, ) - async def _post(self, payload: dict[str, object]) -> httpx.Response: + async def _post_payload(self, payload: dict[str, object]) -> httpx.Response: + """Send a POST request to the endpoint with the given payload.""" async with self._semaphore: return await self._client.post(self._endpoint, json=payload) From db64d24722e4a8d94e444fd9ee88030d6f1c3435 Mon Sep 17 00:00:00 2001 From: fcogidi <41602287+fcogidi@users.noreply.github.com> Date: Fri, 31 Oct 2025 15:00:47 -0400 Subject: [PATCH 4/9] Add integration tests for web search functionality with Gemini grounding --- tests/tool_tests/test_web_search.py | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) create mode 100644 tests/tool_tests/test_web_search.py diff --git a/tests/tool_tests/test_web_search.py b/tests/tool_tests/test_web_search.py new file mode 100644 index 0000000..dbd9918 --- /dev/null +++ b/tests/tool_tests/test_web_search.py @@ -0,0 +1,25 @@ +"""Test web search integration.""" + +import os + +import pytest + +from src.utils import pretty_print +from src.utils.tools.web_search import GeminiGroundingWithGoogleSearch + + +@pytest.mark.asyncio +async def test_web_search_with_gemini_grounding(): + """Test Gemini grounding with Google Search integration.""" + # Check if the environment variable is set + assert os.getenv("WEB_SEARCH_BASE_URL") + assert os.getenv("WEB_SEARCH_API_KEY") + + tool_cls = GeminiGroundingWithGoogleSearch() + response = await tool_cls.get_web_search_grounded_response( + "How does the annual growth in the 50th-percentile income " + "in the US compare with that in Canada?" + ) + + pretty_print(response.text_with_citations) + assert response.text_with_citations From 0e3f94c4c9de39f33c3719ebd2086751f4179ada Mon Sep 17 00:00:00 2001 From: fcogidi <41602287+fcogidi@users.noreply.github.com> Date: Fri, 31 Oct 2025 15:06:49 -0400 Subject: [PATCH 5/9] Enhance error handling in GeminiGroundingWithGoogleSearch by raising ValueError for missing environment variables. 
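For example, constructing the tool with neither constructor arguments nor the
corresponding environment variables set now fails fast (hypothetical
doctest-style session):

    >>> GeminiGroundingWithGoogleSearch()
    Traceback (most recent call last):
        ...
    ValueError: WEB_SEARCH_API_KEY environment variable is not set.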
--- src/utils/tools/web_search.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/src/utils/tools/web_search.py b/src/utils/tools/web_search.py index e95c908..a522cc3 100644 --- a/src/utils/tools/web_search.py +++ b/src/utils/tools/web_search.py @@ -1,4 +1,4 @@ -"""Implements a tool for fetch Google Search grounded responses from Gemini.""" +"""Implements a tool to fetch Google Search grounded responses from Gemini.""" import asyncio import os @@ -51,6 +51,12 @@ class GeminiGroundingWithGoogleSearch: Maximum number of concurrent Gemini requests. timeout : int, optional, default=300 Timeout for requests to the server. + + Raises + ------ + ValueError + If the ``WEB_SEARCH_API_KEY`` environment variable is not set or the + ``WEB_SEARCH_BASE_URL`` environment variable is not set. """ def __init__( @@ -66,6 +72,11 @@ def __init__( self.api_key = api_key or os.getenv("WEB_SEARCH_API_KEY") self.model_settings = model_settings or ModelSettings() + if self.api_key is None: + raise ValueError("WEB_SEARCH_API_KEY environment variable is not set.") + if self.base_url is None: + raise ValueError("WEB_SEARCH_BASE_URL environment variable is not set.") + self._semaphore = asyncio.Semaphore(max_concurrency) self._client = httpx.AsyncClient( From e54899c84eb5d8208819cc9da5074db694f58517 Mon Sep 17 00:00:00 2001 From: fcogidi <41602287+fcogidi@users.noreply.github.com> Date: Fri, 31 Oct 2025 15:12:50 -0400 Subject: [PATCH 6/9] Rename tool module and test file --- src/utils/tools/{web_search.py => gemini_grounding.py} | 0 .../tool_tests/{test_web_search.py => test_gemini_grounding.py} | 2 +- 2 files changed, 1 insertion(+), 1 deletion(-) rename src/utils/tools/{web_search.py => gemini_grounding.py} (100%) rename tests/tool_tests/{test_web_search.py => test_gemini_grounding.py} (90%) diff --git a/src/utils/tools/web_search.py b/src/utils/tools/gemini_grounding.py similarity index 100% rename from src/utils/tools/web_search.py rename to src/utils/tools/gemini_grounding.py diff --git a/tests/tool_tests/test_web_search.py b/tests/tool_tests/test_gemini_grounding.py similarity index 90% rename from tests/tool_tests/test_web_search.py rename to tests/tool_tests/test_gemini_grounding.py index dbd9918..2f9a713 100644 --- a/tests/tool_tests/test_web_search.py +++ b/tests/tool_tests/test_gemini_grounding.py @@ -5,7 +5,7 @@ import pytest from src.utils import pretty_print -from src.utils.tools.web_search import GeminiGroundingWithGoogleSearch +from src.utils.tools.gemini_grounding import GeminiGroundingWithGoogleSearch @pytest.mark.asyncio From fff4a52ef1d4b5414d10e6b147904f8b1304abad Mon Sep 17 00:00:00 2001 From: fcogidi <41602287+fcogidi@users.noreply.github.com> Date: Mon, 3 Nov 2025 10:17:53 -0500 Subject: [PATCH 7/9] Rename Response class to GroundedResponse for clarity --- src/utils/tools/gemini_grounding.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/utils/tools/gemini_grounding.py b/src/utils/tools/gemini_grounding.py index a522cc3..62e083f 100644 --- a/src/utils/tools/gemini_grounding.py +++ b/src/utils/tools/gemini_grounding.py @@ -26,7 +26,7 @@ class ModelSettings(BaseModel): thinking_budget: int | None = Field(default=-1, ge=-1) -class Response(BaseModel): +class GroundedResponse(BaseModel): """Response returned by Gemini.""" text_with_citations: str @@ -82,9 +82,9 @@ def __init__( self._client = httpx.AsyncClient( timeout=timeout, headers={"X-API-Key": self.api_key} ) - self._endpoint = f"{self.base_url}/api/v1/grounding_with_search" + 
self._endpoint = f"{self.base_url.strip('/')}/api/v1/grounding_with_search" - async def get_web_search_grounded_response(self, query: str) -> Response: + async def get_web_search_grounded_response(self, query: str) -> GroundedResponse: """Get Google Search grounded response to query from Gemini model. This function calls a Gemini model with Google Search tool enabled. How @@ -105,7 +105,7 @@ async def get_web_search_grounded_response(self, query: str) -> Response: Returns ------- - Response + GroundedResponse Response returned by Gemini. This includes the text with citations added, the web search queries executed (expanded from the input query), and the raw response object from the API. @@ -131,7 +131,7 @@ async def get_web_search_grounded_response(self, query: str) -> Response: response_json = response.json() text_with_citations = add_citations(response_json) - return Response( + return GroundedResponse( text_with_citations=text_with_citations, web_search_queries=response_json["candidates"][0]["grounding_metadata"][ "web_search_queries" From ff7157cdd6ea540afe60c18b2cedb661a81a7f60 Mon Sep 17 00:00:00 2001 From: fcogidi <41602287+fcogidi@users.noreply.github.com> Date: Tue, 4 Nov 2025 13:03:05 -0500 Subject: [PATCH 8/9] Refactor citation handling in Gemini grounding - Updated the add_citations function to return a tuple of synthesized text and a mapping of citation IDs to source labels. - Enhanced the GroundedResponse class to include a citations field. - Improved error handling and type checks throughout the citation processing logic. - Introduced a new _collect_citations function to streamline citation ID collection from candidates. --- src/utils/tools/gemini_grounding.py | 133 +++++++++++++++++++++------- 1 file changed, 103 insertions(+), 30 deletions(-) diff --git a/src/utils/tools/gemini_grounding.py b/src/utils/tools/gemini_grounding.py index 62e083f..5bc95be 100644 --- a/src/utils/tools/gemini_grounding.py +++ b/src/utils/tools/gemini_grounding.py @@ -2,8 +2,8 @@ import asyncio import os -from collections.abc import Mapping -from typing import Literal +from typing import Any, Literal +from urllib.parse import urlparse import backoff import httpx @@ -31,7 +31,7 @@ class GroundedResponse(BaseModel): text_with_citations: str web_search_queries: list[str] - raw_response: Mapping[str, object] + citations: dict[int, str] class GeminiGroundingWithGoogleSearch: @@ -107,8 +107,8 @@ async def get_web_search_grounded_response(self, query: str) -> GroundedResponse ------- GroundedResponse Response returned by Gemini. This includes the text with citations added, - the web search queries executed (expanded from the input query), and the - raw response object from the API. + the web search queries executed (expanded from the input query), and a + mapping of the citation ids to the website where the citation is from. 
References ---------- @@ -124,19 +124,24 @@ async def get_web_search_grounded_response(self, query: str) -> GroundedResponse try: response.raise_for_status() except httpx.HTTPStatusError as exc: - raise RuntimeError( - f"Gemini call failed with status {exc.response.status_code}" - ) from exc + raise exc from exc response_json = response.json() - text_with_citations = add_citations(response_json) + + candidates: list[dict[str, Any]] | None = response_json.get("candidates") + grounding_metadata: dict[str, Any] | None = ( + candidates[0].get("grounding_metadata") if candidates else None + ) + web_search_queries: list[str] = ( + grounding_metadata["web_search_queries"] if grounding_metadata else [] + ) + + text_with_citations, citations = add_citations(response_json) return GroundedResponse( text_with_citations=text_with_citations, - web_search_queries=response_json["candidates"][0]["grounding_metadata"][ - "web_search_queries" - ], - raw_response=response_json, + web_search_queries=web_search_queries, + citations=citations, ) @backoff.on_exception( @@ -159,7 +164,7 @@ async def _post_payload(self, payload: dict[str, object]) -> httpx.Response: return await self._client.post(self._endpoint, json=payload) -def add_citations(response: dict[str, object]) -> str: +def add_citations(response: dict[str, object]) -> tuple[str, dict[int, str]]: """Add citations to the Gemini response. Code based on example in [1]_. @@ -171,34 +176,102 @@ def add_citations(response: dict[str, object]) -> str: Returns ------- - str - Text with citations added. + tuple[str, dict[int, str]] + The synthesized text and a mapping of citation ids to source labels. References ---------- .. [1] https://ai.google.dev/gemini-api/docs/google-search#attributing_sources_with_inline_citations """ - text = response["candidates"][0]["content"]["parts"][0]["text"] - supports = response["candidates"][0]["grounding_metadata"]["grounding_supports"] - chunks = response["candidates"][0]["grounding_metadata"]["grounding_chunks"] + candidates = response.get("candidates") if isinstance(response, dict) else None + if not candidates: + return "", {} + + candidate = candidates[0] or {} + content = candidate.get("content") if isinstance(candidate, dict) else {} + parts = content.get("parts") if isinstance(content, dict) else [] + + text = "" + for part in parts if isinstance(parts, list) else []: + if isinstance(part, dict) and isinstance(part.get("text"), str): + text = part["text"] + break + if not text: + return "", {} + + meta = candidate.get("grounding_metadata") if isinstance(candidate, dict) else {} + raw_supports = meta.get("grounding_supports") if isinstance(meta, dict) else [] + supports = raw_supports if isinstance(raw_supports, list) else [] + raw_chunks = meta.get("grounding_chunks") if isinstance(meta, dict) else [] + chunks = raw_chunks if isinstance(raw_chunks, list) else [] + + citations: dict[int, str] = {} + chunk_to_id: dict[int, int] = {} + + if supports and chunks: + citations, chunk_to_id = _collect_citations(candidate) # Sort supports by end_index in descending order to avoid shifting issues # when inserting. 
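    # (Supports that are not dicts or lack a "segment" mapping are dropped
    # by the generator below, consistent with the defensive parsing above.)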
sorted_supports = sorted( - supports, key=lambda s: s["segment"]["end_index"], reverse=True + (s for s in supports if isinstance(s, dict) and s.get("segment")), + key=lambda s: s["segment"].get("end_index", 0), + reverse=True, ) for support in sorted_supports: - end_index = support["segment"]["end_index"] - if support["grounding_chunk_indices"]: - # Create citation string like [1](link1)[2](link2) - citation_links = [] - for i in support["grounding_chunk_indices"]: - if i < len(chunks): - uri = chunks[i]["web"]["uri"] - citation_links.append(f"[{i + 1}]({uri})") - + segment = support.get("segment") or {} + end_index = segment.get("end_index") + if not isinstance(end_index, int) or end_index < 0 or end_index > len(text): + continue + indices = support.get("grounding_chunk_indices") or [] + citation_links: list[str] = [] + for idx in indices: + if not isinstance(idx, int): + continue + citation_id = chunk_to_id.get(idx) + if citation_id is None or idx >= len(chunks): + continue + web = chunks[idx].get("web") if isinstance(chunks[idx], dict) else {} + uri = web.get("uri") if isinstance(web, dict) else None + if uri: + citation_links.append(f"[{citation_id}]({uri})") + + if citation_links: citation_string = ", ".join(citation_links) text = text[:end_index] + citation_string + text[end_index:] - return text + return text, citations + + +def _collect_citations(candidate: dict) -> tuple[dict[int, str], dict[int, int]]: + """Collect citation ids from a candidate dict.""" + supports = candidate["grounding_metadata"]["grounding_supports"] + chunks = candidate["grounding_metadata"]["grounding_chunks"] + + citations: dict[int, str] = {} + chunk_to_id: dict[int, int] = {} + next_id = 1 + + def label_for(chunk: dict) -> str: + web = chunk.get("web") or {} + title = web.get("title") + uri = web.get("uri") + if title: + return title + if uri: + parsed = urlparse(uri) + return parsed.hostname or parsed.netloc or uri + return "unknown source" + + for support in supports: + if not isinstance(support, dict): + continue + for chunk_idx in support.get("grounding_chunk_indices", []): + if chunk_idx not in chunk_to_id and 0 <= chunk_idx < len(chunks): + label = label_for(chunks[chunk_idx]) + chunk_to_id[chunk_idx] = next_id + citations[next_id] = label + next_id += 1 + + return citations, chunk_to_id From 19b8b435a1db11132caf92fc5823fc37f4cb6afa Mon Sep 17 00:00:00 2001 From: fcogidi <41602287+fcogidi@users.noreply.github.com> Date: Wed, 5 Nov 2025 12:41:52 -0500 Subject: [PATCH 9/9] Add new multi-agent implementation with knowledge base and web search tools --- .../2_multi_agent/efficient_multiple_kbs.py | 202 ++++++++++++++++++ 1 file changed, 202 insertions(+) create mode 100644 src/2_frameworks/2_multi_agent/efficient_multiple_kbs.py diff --git a/src/2_frameworks/2_multi_agent/efficient_multiple_kbs.py b/src/2_frameworks/2_multi_agent/efficient_multiple_kbs.py new file mode 100644 index 0000000..42d9247 --- /dev/null +++ b/src/2_frameworks/2_multi_agent/efficient_multiple_kbs.py @@ -0,0 +1,202 @@ +"""Example code for planner-worker agent collaboration with multiple tools.""" + +import asyncio +import contextlib +import signal +import sys + +import agents +import gradio as gr +from dotenv import load_dotenv +from gradio.components.chatbot import ChatMessage +from openai import AsyncOpenAI + +from src.utils import ( + AsyncWeaviateKnowledgeBase, + Configs, + get_weaviate_async_client, + oai_agent_stream_to_gradio_messages, + set_up_logging, + setup_langfuse_tracer, +) +from src.utils.langfuse.shared_client 
import langfuse_client
from src.utils.tools.gemini_grounding import (
    GeminiGroundingWithGoogleSearch,
    ModelSettings,
)


load_dotenv(verbose=True)

set_up_logging()

AGENT_LLM_NAMES = {
    "worker": "gemini-2.5-flash",  # less expensive, faster
    "planner": "gemini-2.5-pro",  # more expensive, better at reasoning and planning
}

configs = Configs.from_env_var()
async_weaviate_client = get_weaviate_async_client(
    http_host=configs.weaviate_http_host,
    http_port=configs.weaviate_http_port,
    http_secure=configs.weaviate_http_secure,
    grpc_host=configs.weaviate_grpc_host,
    grpc_port=configs.weaviate_grpc_port,
    grpc_secure=configs.weaviate_grpc_secure,
    api_key=configs.weaviate_api_key,
)
async_openai_client = AsyncOpenAI()
async_knowledgebase = AsyncWeaviateKnowledgeBase(
    async_weaviate_client,
    collection_name="enwiki_20250520",
)

gemini_grounding_tool = GeminiGroundingWithGoogleSearch(
    model_settings=ModelSettings(model=AGENT_LLM_NAMES["worker"])
)


async def _cleanup_clients() -> None:
    """Close async clients."""
    await async_weaviate_client.close()
    await async_openai_client.close()


def _handle_sigint(signum: int, frame: object) -> None:
    """Handle SIGINT and shut down gracefully."""
    with contextlib.suppress(Exception):
        asyncio.get_event_loop().run_until_complete(_cleanup_clients())
    sys.exit(0)


# Worker Agent: handles long context efficiently
kb_agent = agents.Agent(
    name="KnowledgeBaseAgent",
    instructions="""
    You are an agent specialized in searching a knowledge base.
    You will receive a single search query as input.
    Use the 'search_knowledgebase' tool to perform a search, then return a
    JSON object with:
    - 'summary': a concise synthesis of the retrieved information in your own words
    - 'sources': a list of citations with {type: "kb", title: "...", section: "..."}
    - 'no_results': true/false

    If the tool returns no matches, set "no_results": true and keep "sources" empty.
    Do NOT make up information. Do NOT return raw search results or long quotes.
    """,
    tools=[
        agents.function_tool(async_knowledgebase.search_knowledgebase),
    ],
    # a faster, smaller model for quick searches
    model=agents.OpenAIChatCompletionsModel(
        model=AGENT_LLM_NAMES["worker"], openai_client=async_openai_client
    ),
)

# Main Agent: more expensive and slower, but better at complex planning
main_agent = agents.Agent(
    name="MainAgent",
    instructions="""
    You are a deep research agent. Your goal is to conduct in-depth, multi-turn
    research by breaking down complex queries, using the provided tools, and
    synthesizing the information into a comprehensive report.

    You have access to the following tools:
    1. 'search_knowledgebase' - use this tool to search for information in a
       knowledge base. The knowledge base reflects a subset of Wikipedia as
       of May 2025.
    2. 'get_web_search_grounded_response' - use this tool for current events,
       news, fact-checking, or when the information in the knowledge base is
       not sufficient to answer the question.

    Neither tool returns raw search results or the sources themselves.
    Instead, each returns a concise summary of the key findings, along
    with the sources used to generate the summary.

    For best performance, divide complex queries into simpler sub-queries.
    Before calling either tool, always explain your reasoning for doing so.

    Note that the 'get_web_search_grounded_response' tool will expand the query
    into multiple search queries and execute them. It will also return the
It will also return the + queries it executed. Do not repeat them. + + **Routing Guidelines:** + - When answering a question, you should first try to use the 'search_knowledgebase' + tool, unless the question requires recent information after May 2025 or + has explicit recency cues. + - If either tool returns insufficient information for a given query, try + reformulating or using the other tool. You can call either tool multiple + times to get the information you need to answer the user's question. + + **Guidelines for synthesis** + - After collecting results, write the final answer from your own synthesis. + - Add a "Sources" section listing unique sources, formatted as: + [1] Publisher - URL + [2] Wikipedia: (Section:
      Order by first mention in your text. Every factual sentence in your final
      response must map to at least one source.
    - If the web and the knowledge base disagree, surface the disagreement and
      prefer sources with newer publication dates.
    - Do not invent URLs or sources.
    - If both tools fail, say so and suggest 2–3 refined queries.

    Be sure to mention the sources in your response, including the URL if available,
    and do not make up information.
    """,
    # Allow the planner agent to invoke the worker agent.
    # The long context provided to the worker agent is hidden from the main agent.
    tools=[
        kb_agent.as_tool(
            tool_name="search_knowledgebase",
            tool_description=(
                "Search the knowledge base for a query and return a concise summary "
                "of the key findings, along with the sources used to generate "
                "the summary"
            ),
        ),
        agents.function_tool(gemini_grounding_tool.get_web_search_grounded_response),
    ],
    # a larger, more capable model for planning and reasoning over summaries
    model=agents.OpenAIChatCompletionsModel(
        model=AGENT_LLM_NAMES["planner"], openai_client=async_openai_client
    ),
)


async def _main(question: str, gr_messages: list[ChatMessage]):
    setup_langfuse_tracer()

    # Use the main agent as the entry point, not the worker agent.
    with langfuse_client.start_as_current_span(name="Agents-SDK-Trace") as span:
        span.update(input=question)

        result_stream = agents.Runner.run_streamed(main_agent, input=question)
        async for _item in result_stream.stream_events():
            gr_messages += oai_agent_stream_to_gradio_messages(_item)
            if len(gr_messages) > 0:
                yield gr_messages

        span.update(output=result_stream.final_output)


demo = gr.ChatInterface(
    _main,
    title="2.3 Multi-Agent with Multiple Search Tools",
    type="messages",
    examples=[
        "At which university did the SVP Software Engineering"
        " at Apple (as of June 2025) earn their engineering degree?",
        "How does the annual growth in the 50th-percentile income "
        "in the US compare with that in Canada?",
    ],
)

if __name__ == "__main__":
    signal.signal(signal.SIGINT, _handle_sigint)

    try:
        demo.launch(share=True)
    finally:
        asyncio.run(_cleanup_clients())
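# A quick smoke test (hypothetical usage sketch; drives the planner agent
# directly, bypassing the Gradio UI):
#
#   result = asyncio.run(
#       agents.Runner.run(
#           main_agent,
#           input="How does the annual growth in the 50th-percentile income "
#           "in the US compare with that in Canada?",
#       )
#   )
#   print(result.final_output)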