Skip to content

Commit 3c3310f

Browse files
fcogidi and amrit110 authored
Add web search tool (#32)
* Add web search functionality with Gemini integration - Updated .env.example to include WEB_SEARCH_BASE_URL and WEB_SEARCH_API_KEY. - Introduced web_search.py to implement a tool for fetching Google Search grounded responses from the Gemini model. * Update docstring in GeminiGroundingWithGoogleSearch to include reference link * Refactor web search error handling, update POST method, fix web search query extraction - Introduced RETRYABLE_STATUS for managing retryable HTTP status codes. - Updated the _post method to _post_payload for clarity and added a docstring. - Enhanced error handling in the backoff decorator to include specific exceptions and retry logic based on RETRYABLE_STATUS. * Add integration tests for web search functionality with Gemini grounding * Enhance error handling in GeminiGroundingWithGoogleSearch by raising ValueError for missing environment variables. * Rename tool module and test file * Rename Response class to GroundedResponse for clarity * Refactor citation handling in Gemini grounding - Updated the add_citations function to return a tuple of synthesized text and a mapping of citation IDs to source labels. - Enhanced the GroundedResponse class to include a citations field. - Improved error handling and type checks throughout the citation processing logic. - Introduced a new _collect_citations function to streamline citation ID collection from candidates. --------- Co-authored-by: Amrit Krishnan <amrit110@gmail.com>
1 parent ea2ee6a commit 3c3310f

File tree

3 files changed

+305
-0
lines changed

3 files changed

+305
-0
lines changed

.env.example

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,3 +23,6 @@ WEAVIATE_GRPC_SECURE="true" # set to false for localhost
2323

2424
# Optionally, specify E2B.dev API key for Python Code Interpreter
2525
E2B_API_KEY="e2b_..."
26+
27+
WEB_SEARCH_BASE_URL="..."
28+
WEB_SEARCH_API_KEY="..."
Lines changed: 277 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,277 @@
1+
"""Implements a tool to fetch Google Search grounded responses from Gemini."""
2+
3+
import asyncio
4+
import os
5+
from typing import Any, Literal
6+
from urllib.parse import urlparse
7+
8+
import backoff
9+
import httpx
10+
from pydantic import BaseModel
11+
from pydantic.fields import Field
12+
13+
14+
# HTTP status codes that are worth retrying. Immutable so the shared
# module-level constant cannot be modified by accident.
RETRYABLE_STATUS = frozenset({429, 500, 502, 503, 504})
15+
16+
17+
class ModelSettings(BaseModel):
    """Generation settings for the Gemini model that performs the web search.

    Attributes
    ----------
    model : str
        One of the supported Gemini 2.5 model names.
    temperature : float or None
        Sampling temperature, constrained to the range [0, 2].
    max_output_tokens : int or None
        Upper bound on generated tokens; must be at least 1 when given.
    seed : int or None
        Optional seed value.
    thinking_budget : int or None
        Thinking budget; must be at least -1.
        NOTE(review): -1 presumably requests dynamic thinking -- confirm
        against the proxy / Gemini documentation.
    """

    model: Literal[
        "gemini-2.5-pro",
        "gemini-2.5-flash",
        "gemini-2.5-flash-lite",
    ] = "gemini-2.5-flash"
    temperature: float | None = Field(0.2, ge=0, le=2)
    max_output_tokens: int | None = Field(None, ge=1)
    seed: int | None = None
    thinking_budget: int | None = Field(-1, ge=-1)
27+
28+
29+
class GroundedResponse(BaseModel):
    """Grounded answer produced by Gemini.

    Attributes
    ----------
    text_with_citations : str
        Response text with citations added.
    web_search_queries : list of str
        Web search queries that were executed for the input query.
    citations : dict of int to str
        Mapping of citation ids to the source each citation comes from.
    """

    text_with_citations: str
    web_search_queries: list[str]
    citations: dict[int, str]
35+
36+
37+
class GeminiGroundingWithGoogleSearch:
    """Tool for fetching Google Search grounded responses from Gemini via a proxy.

    The instance owns an ``httpx.AsyncClient``. Call :meth:`aclose` (or use the
    instance as an async context manager) to release its connections.

    Parameters
    ----------
    base_url : str, optional, default=None
        Base URL for the Gemini proxy. Defaults to the value of the
        ``WEB_SEARCH_BASE_URL`` environment variable.
    api_key : str, optional, default=None
        API key for the Gemini proxy. Defaults to the value of the
        ``WEB_SEARCH_API_KEY`` environment variable.
    model_settings : ModelSettings, optional, default=None
        Settings for the Gemini model used for web search.
    max_concurrency : int, optional, default=5
        Maximum number of concurrent Gemini requests.
    timeout : int, optional, default=300
        Timeout for requests to the server.

    Raises
    ------
    ValueError
        If no non-empty API key or base URL is available from the arguments
        or from the ``WEB_SEARCH_API_KEY`` / ``WEB_SEARCH_BASE_URL``
        environment variables.
    """

    def __init__(
        self,
        base_url: str | None = None,
        api_key: str | None = None,
        *,
        model_settings: ModelSettings | None = None,
        max_concurrency: int = 5,
        timeout: int = 300,
    ) -> None:
        self.base_url = base_url or os.getenv("WEB_SEARCH_BASE_URL")
        self.api_key = api_key or os.getenv("WEB_SEARCH_API_KEY")
        self.model_settings = model_settings or ModelSettings()

        # An empty string is as unusable as a missing value, so reject both.
        if not self.api_key:
            raise ValueError("WEB_SEARCH_API_KEY environment variable is not set.")
        if not self.base_url:
            raise ValueError("WEB_SEARCH_BASE_URL environment variable is not set.")

        # Caps the number of in-flight requests issued through this instance.
        self._semaphore = asyncio.Semaphore(max_concurrency)

        self._client = httpx.AsyncClient(
            timeout=timeout, headers={"X-API-Key": self.api_key}
        )
        self._endpoint = f"{self.base_url.strip('/')}/api/v1/grounding_with_search"

    async def __aenter__(self) -> "GeminiGroundingWithGoogleSearch":
        """Return the tool itself so it can be used with ``async with``."""
        return self

    async def __aexit__(self, *exc_info: object) -> None:
        """Close the HTTP client when leaving the ``async with`` block."""
        await self.aclose()

    async def aclose(self) -> None:
        """Close the underlying HTTP client."""
        await self._client.aclose()

    async def get_web_search_grounded_response(self, query: str) -> GroundedResponse:
        """Get Google Search grounded response to query from Gemini model.

        This function calls a Gemini model with Google Search tool enabled. How
        it works [1]_:
            - The model analyzes the input query and determines if a Google Search
              can improve the answer.
            - If needed, the model automatically generates one or multiple search
              queries and executes them.
            - The model processes the search results, synthesizes the information,
              and formulates a response.
            - The API returns a final, user-friendly response that is grounded in
              the search results.

        Parameters
        ----------
        query : str
            Query to pass to Gemini.

        Returns
        -------
        GroundedResponse
            Response returned by Gemini. This includes the text with citations added,
            the web search queries executed (expanded from the input query), and a
            mapping of the citation ids to the website where the citation is from.

        Raises
        ------
        httpx.HTTPStatusError
            If the proxy answers with an error status that is not retryable, or
            is still failing after the retries are exhausted.

        References
        ----------
        .. [1] https://ai.google.dev/gemini-api/docs/google-search#how_grounding_with_google_search_works
        """
        # Only explicitly-set settings are sent (``exclude_unset=True``).
        # NOTE(review): defaults declared on ModelSettings are therefore not
        # included in the payload -- confirm the proxy applies its own.
        payload = self.model_settings.model_dump(exclude_unset=True)
        payload["query"] = query

        # Status errors are raised (and retried when appropriate) inside
        # ``_post_payload``, so the response here is always a success.
        response = await self._post_payload(payload)
        response_json = response.json()

        candidates = (
            response_json.get("candidates")
            if isinstance(response_json, dict)
            else None
        )
        first_candidate = candidates[0] if candidates else None
        grounding_metadata: dict[str, Any] | None = (
            first_candidate.get("grounding_metadata")
            if isinstance(first_candidate, dict)
            else None
        )
        # The queries are absent when the model answered without searching.
        web_search_queries: list[str] = (
            grounding_metadata.get("web_search_queries") or []
            if isinstance(grounding_metadata, dict)
            else []
        )

        text_with_citations, citations = add_citations(response_json)

        return GroundedResponse(
            text_with_citations=text_with_citations,
            web_search_queries=web_search_queries,
            citations=citations,
        )

    @backoff.on_exception(
        backoff.expo,
        (
            httpx.TimeoutException,
            httpx.NetworkError,
            httpx.HTTPStatusError,  # only retry codes in RETRYABLE_STATUS
        ),
        giveup=lambda exc: (
            isinstance(exc, httpx.HTTPStatusError)
            and exc.response.status_code not in RETRYABLE_STATUS
        ),
        jitter=backoff.full_jitter,
        max_tries=5,
    )
    async def _post_payload(self, payload: dict[str, object]) -> httpx.Response:
        """Send a POST request to the endpoint with the given payload.

        The status check happens here, inside the retried function, so that
        responses with a status in ``RETRYABLE_STATUS`` trigger the backoff
        decorator. Other error statuses are raised immediately.
        """
        async with self._semaphore:
            response = await self._client.post(self._endpoint, json=payload)
        response.raise_for_status()
        return response
165+
166+
167+
def add_citations(response: dict[str, object]) -> tuple[str, dict[int, str]]:
    """Add citations to the Gemini response.

    Code based on example in [1]_. The text of the first candidate's first
    text part is used; for every grounding support with a usable segment, a
    comma-separated list of ``[id](uri)`` links is inserted at the segment's
    ``end_index``. Malformed or missing pieces of the response are skipped
    rather than raising.

    Parameters
    ----------
    response : dict of str to object
        JSON response returned by Gemini.

    Returns
    -------
    tuple[str, dict[int, str]]
        The synthesized text and a mapping of citation ids to source labels.
        ``("", {})`` is returned when the response has no candidate text.

    References
    ----------
    .. [1] https://ai.google.dev/gemini-api/docs/google-search#attributing_sources_with_inline_citations
    """
    candidates = response.get("candidates") if isinstance(response, dict) else None
    if not isinstance(candidates, (list, tuple)) or not candidates:
        return "", {}

    candidate = candidates[0]
    if not isinstance(candidate, dict):
        return "", {}

    content = candidate.get("content")
    parts = content.get("parts") if isinstance(content, dict) else None
    # Use the first part that carries a string ``text`` field.
    text = next(
        (
            part["text"]
            for part in (parts if isinstance(parts, list) else [])
            if isinstance(part, dict) and isinstance(part.get("text"), str)
        ),
        "",
    )
    if not text:
        return "", {}

    meta = candidate.get("grounding_metadata")
    supports = meta.get("grounding_supports") if isinstance(meta, dict) else None
    chunks = meta.get("grounding_chunks") if isinstance(meta, dict) else None
    if not (isinstance(supports, list) and supports):
        return text, {}
    if not (isinstance(chunks, list) and chunks):
        return text, {}

    citations, chunk_to_id = _collect_citations(candidate)

    # Keep only supports whose segment has a usable insertion point. Validating
    # before sorting means a ``None`` or non-integer ``end_index`` cannot break
    # the sort.
    # NOTE(review): indices are applied as character offsets, as in the
    # referenced example; the Gemini API reference describes segment indices
    # as byte offsets, which would differ for non-ASCII text -- confirm.
    original_length = len(text)
    placements: list[tuple[int, dict]] = []
    for support in supports:
        segment = support.get("segment") if isinstance(support, dict) else None
        end_index = segment.get("end_index") if isinstance(segment, dict) else None
        if isinstance(end_index, int) and 0 <= end_index <= original_length:
            placements.append((end_index, support))

    # Insert from the end of the text backwards so earlier offsets stay valid.
    placements.sort(key=lambda item: item[0], reverse=True)

    for end_index, support in placements:
        indices = support.get("grounding_chunk_indices")
        if not isinstance(indices, list):
            continue

        citation_links: list[str] = []
        for idx in indices:
            if not isinstance(idx, int) or not 0 <= idx < len(chunks):
                continue
            citation_id = chunk_to_id.get(idx)
            if citation_id is None:
                continue
            chunk = chunks[idx]
            web = chunk.get("web") if isinstance(chunk, dict) else None
            uri = web.get("uri") if isinstance(web, dict) else None
            if uri:
                citation_links.append(f"[{citation_id}]({uri})")

        if citation_links:
            citation_string = ", ".join(citation_links)
            text = text[:end_index] + citation_string + text[end_index:]

    return text, citations
245+
246+
247+
def _collect_citations(candidate: dict) -> tuple[dict[int, str], dict[int, int]]:
    """Collect citation ids from a candidate dict.

    Citation ids start at 1 and are assigned in the order grounding chunks are
    first referenced by the grounding supports. Malformed entries (non-dict
    supports, missing or non-list index lists, non-integer or out-of-range
    indices) are skipped instead of raising.

    Parameters
    ----------
    candidate : dict
        A candidate from the Gemini response; its ``grounding_metadata`` is
        read for ``grounding_supports`` and ``grounding_chunks``.

    Returns
    -------
    tuple[dict[int, str], dict[int, int]]
        A mapping of citation id to source label, and a mapping of grounding
        chunk index to citation id. Both are empty when the metadata is
        missing or malformed.
    """
    metadata = (
        candidate.get("grounding_metadata") if isinstance(candidate, dict) else None
    )
    if not isinstance(metadata, dict):
        return {}, {}
    supports = metadata.get("grounding_supports")
    chunks = metadata.get("grounding_chunks")
    if not isinstance(supports, list) or not isinstance(chunks, list):
        return {}, {}

    def label_for(chunk: object) -> str:
        """Pick a label: page title, else the URI's host, else a placeholder."""
        web = chunk.get("web") if isinstance(chunk, dict) else None
        if not isinstance(web, dict):
            return "unknown source"
        title = web.get("title")
        if isinstance(title, str) and title:
            return title
        uri = web.get("uri")
        if isinstance(uri, str) and uri:
            parsed = urlparse(uri)
            return parsed.hostname or parsed.netloc or uri
        return "unknown source"

    citations: dict[int, str] = {}
    chunk_to_id: dict[int, int] = {}

    for support in supports:
        if not isinstance(support, dict):
            continue
        indices = support.get("grounding_chunk_indices")
        if not isinstance(indices, list):
            continue
        for chunk_idx in indices:
            if not isinstance(chunk_idx, int):
                continue
            if chunk_idx in chunk_to_id or not 0 <= chunk_idx < len(chunks):
                continue
            citation_id = len(chunk_to_id) + 1
            chunk_to_id[chunk_idx] = citation_id
            citations[citation_id] = label_for(chunks[chunk_idx])

    return citations, chunk_to_id
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
"""Test web search integration."""
2+
3+
import os
4+
5+
import pytest
6+
7+
from src.utils import pretty_print
8+
from src.utils.tools.gemini_grounding import GeminiGroundingWithGoogleSearch
9+
10+
11+
@pytest.mark.asyncio
async def test_web_search_with_gemini_grounding():
    """Run one query through the Gemini grounding tool and expect some text."""
    # The tool reads both of these from the environment; check them up front.
    for required_var in ("WEB_SEARCH_BASE_URL", "WEB_SEARCH_API_KEY"):
        assert os.getenv(required_var)

    question = (
        "How does the annual growth in the 50th-percentile income "
        "in the US compare with that in Canada?"
    )
    search_tool = GeminiGroundingWithGoogleSearch()
    grounded = await search_tool.get_web_search_grounded_response(question)

    pretty_print(grounded.text_with_citations)
    assert grounded.text_with_citations

0 commit comments

Comments
 (0)