|
| 1 | +"""Implements a tool to fetch Google Search grounded responses from Gemini.""" |
| 2 | + |
| 3 | +import asyncio |
| 4 | +import os |
| 5 | +from typing import Any, Literal |
| 6 | +from urllib.parse import urlparse |
| 7 | + |
| 8 | +import backoff |
| 9 | +import httpx |
| 10 | +from pydantic import BaseModel |
| 11 | +from pydantic.fields import Field |
| 12 | + |
| 13 | + |
| 14 | +RETRYABLE_STATUS = {429, 500, 502, 503, 504} |
| 15 | + |
| 16 | + |
| 17 | +class ModelSettings(BaseModel): |
| 18 | + """Configuration for the Gemini model used for web search.""" |
| 19 | + |
| 20 | + model: Literal["gemini-2.5-pro", "gemini-2.5-flash", "gemini-2.5-flash-lite"] = ( |
| 21 | + "gemini-2.5-flash" |
| 22 | + ) |
| 23 | + temperature: float | None = Field(default=0.2, ge=0, le=2) |
| 24 | + max_output_tokens: int | None = Field(default=None, ge=1) |
| 25 | + seed: int | None = None |
| 26 | + thinking_budget: int | None = Field(default=-1, ge=-1) |
| 27 | + |
| 28 | + |
| 29 | +class GroundedResponse(BaseModel): |
| 30 | + """Response returned by Gemini.""" |
| 31 | + |
| 32 | + text_with_citations: str |
| 33 | + web_search_queries: list[str] |
| 34 | + citations: dict[int, str] |
| 35 | + |
| 36 | + |
| 37 | +class GeminiGroundingWithGoogleSearch: |
| 38 | + """Tool for fetching Google Search grounded responses from Gemini via a proxy. |
| 39 | +
|
| 40 | + Parameters |
| 41 | + ---------- |
| 42 | + base_url : str, optional, default=None |
| 43 | + Base URL for the Gemini proxy. Defaults to the value of the |
| 44 | + ``WEB_SEARCH_BASE_URL`` environment variable. |
| 45 | + api_key : str, optional, default=None |
| 46 | + API key for the Gemini proxy. Defaults to the value of the |
| 47 | + ``WEB_SEARCH_API_KEY`` environment variable. |
| 48 | + model_settings : ModelSettings, optional, default=None |
| 49 | + Settings for the Gemini model used for web search. |
| 50 | + max_concurrency : int, optional, default=5 |
| 51 | + Maximum number of concurrent Gemini requests. |
| 52 | + timeout : int, optional, default=300 |
| 53 | + Timeout for requests to the server. |
| 54 | +
|
| 55 | + Raises |
| 56 | + ------ |
| 57 | + ValueError |
| 58 | + If the ``WEB_SEARCH_API_KEY`` environment variable is not set or the |
| 59 | + ``WEB_SEARCH_BASE_URL`` environment variable is not set. |
| 60 | + """ |
| 61 | + |
| 62 | + def __init__( |
| 63 | + self, |
| 64 | + base_url: str | None = None, |
| 65 | + api_key: str | None = None, |
| 66 | + *, |
| 67 | + model_settings: ModelSettings | None = None, |
| 68 | + max_concurrency: int = 5, |
| 69 | + timeout: int = 300, |
| 70 | + ) -> None: |
| 71 | + self.base_url = base_url or os.getenv("WEB_SEARCH_BASE_URL") |
| 72 | + self.api_key = api_key or os.getenv("WEB_SEARCH_API_KEY") |
| 73 | + self.model_settings = model_settings or ModelSettings() |
| 74 | + |
| 75 | + if self.api_key is None: |
| 76 | + raise ValueError("WEB_SEARCH_API_KEY environment variable is not set.") |
| 77 | + if self.base_url is None: |
| 78 | + raise ValueError("WEB_SEARCH_BASE_URL environment variable is not set.") |
| 79 | + |
| 80 | + self._semaphore = asyncio.Semaphore(max_concurrency) |
| 81 | + |
| 82 | + self._client = httpx.AsyncClient( |
| 83 | + timeout=timeout, headers={"X-API-Key": self.api_key} |
| 84 | + ) |
| 85 | + self._endpoint = f"{self.base_url.strip('/')}/api/v1/grounding_with_search" |
| 86 | + |
| 87 | + async def get_web_search_grounded_response(self, query: str) -> GroundedResponse: |
| 88 | + """Get Google Search grounded response to query from Gemini model. |
| 89 | +
|
| 90 | + This function calls a Gemini model with Google Search tool enabled. How |
| 91 | + it works [1]_: |
| 92 | + - The model analyzes the input query and determines if a Google Search |
| 93 | + can improve the answer. |
| 94 | + - If needed, the model automatically generates one or multiple search |
| 95 | + queries and executes them. |
| 96 | + - The model processes the search results, synthesizes the information, |
| 97 | + and formulates a response. |
| 98 | + - The API returns a final, user-friendly response that is grounded in |
| 99 | + the search results. |
| 100 | +
|
| 101 | + Parameters |
| 102 | + ---------- |
| 103 | + query : str |
| 104 | + Query to pass to Gemini. |
| 105 | +
|
| 106 | + Returns |
| 107 | + ------- |
| 108 | + GroundedResponse |
| 109 | + Response returned by Gemini. This includes the text with citations added, |
| 110 | + the web search queries executed (expanded from the input query), and a |
| 111 | + mapping of the citation ids to the website where the citation is from. |
| 112 | +
|
| 113 | + References |
| 114 | + ---------- |
| 115 | + .. [1] https://ai.google.dev/gemini-api/docs/google-search#how_grounding_with_google_search_works |
| 116 | + """ |
| 117 | + # Payload |
| 118 | + payload = self.model_settings.model_dump(exclude_unset=True) |
| 119 | + payload["query"] = query |
| 120 | + |
| 121 | + # Call Gemini |
| 122 | + response = await self._post_payload(payload) |
| 123 | + |
| 124 | + try: |
| 125 | + response.raise_for_status() |
| 126 | + except httpx.HTTPStatusError as exc: |
| 127 | + raise exc from exc |
| 128 | + |
| 129 | + response_json = response.json() |
| 130 | + |
| 131 | + candidates: list[dict[str, Any]] | None = response_json.get("candidates") |
| 132 | + grounding_metadata: dict[str, Any] | None = ( |
| 133 | + candidates[0].get("grounding_metadata") if candidates else None |
| 134 | + ) |
| 135 | + web_search_queries: list[str] = ( |
| 136 | + grounding_metadata["web_search_queries"] if grounding_metadata else [] |
| 137 | + ) |
| 138 | + |
| 139 | + text_with_citations, citations = add_citations(response_json) |
| 140 | + |
| 141 | + return GroundedResponse( |
| 142 | + text_with_citations=text_with_citations, |
| 143 | + web_search_queries=web_search_queries, |
| 144 | + citations=citations, |
| 145 | + ) |
| 146 | + |
| 147 | + @backoff.on_exception( |
| 148 | + backoff.expo, |
| 149 | + ( |
| 150 | + httpx.TimeoutException, |
| 151 | + httpx.NetworkError, |
| 152 | + httpx.HTTPStatusError, # only retry codes in RETRYABLE_STATUS |
| 153 | + ), |
| 154 | + giveup=lambda exc: ( |
| 155 | + isinstance(exc, httpx.HTTPStatusError) |
| 156 | + and exc.response.status_code not in RETRYABLE_STATUS |
| 157 | + ), |
| 158 | + jitter=backoff.full_jitter, |
| 159 | + max_tries=5, |
| 160 | + ) |
| 161 | + async def _post_payload(self, payload: dict[str, object]) -> httpx.Response: |
| 162 | + """Send a POST request to the endpoint with the given payload.""" |
| 163 | + async with self._semaphore: |
| 164 | + return await self._client.post(self._endpoint, json=payload) |
| 165 | + |
| 166 | + |
| 167 | +def add_citations(response: dict[str, object]) -> tuple[str, dict[int, str]]: |
| 168 | + """Add citations to the Gemini response. |
| 169 | +
|
| 170 | + Code based on example in [1]_. |
| 171 | +
|
| 172 | + Parameters |
| 173 | + ---------- |
| 174 | + response : dict of str to object |
| 175 | + JSON response returned by Gemini. |
| 176 | +
|
| 177 | + Returns |
| 178 | + ------- |
| 179 | + tuple[str, dict[int, str]] |
| 180 | + The synthesized text and a mapping of citation ids to source labels. |
| 181 | +
|
| 182 | + References |
| 183 | + ---------- |
| 184 | + .. [1] https://ai.google.dev/gemini-api/docs/google-search#attributing_sources_with_inline_citations |
| 185 | + """ |
| 186 | + candidates = response.get("candidates") if isinstance(response, dict) else None |
| 187 | + if not candidates: |
| 188 | + return "", {} |
| 189 | + |
| 190 | + candidate = candidates[0] or {} |
| 191 | + content = candidate.get("content") if isinstance(candidate, dict) else {} |
| 192 | + parts = content.get("parts") if isinstance(content, dict) else [] |
| 193 | + |
| 194 | + text = "" |
| 195 | + for part in parts if isinstance(parts, list) else []: |
| 196 | + if isinstance(part, dict) and isinstance(part.get("text"), str): |
| 197 | + text = part["text"] |
| 198 | + break |
| 199 | + if not text: |
| 200 | + return "", {} |
| 201 | + |
| 202 | + meta = candidate.get("grounding_metadata") if isinstance(candidate, dict) else {} |
| 203 | + raw_supports = meta.get("grounding_supports") if isinstance(meta, dict) else [] |
| 204 | + supports = raw_supports if isinstance(raw_supports, list) else [] |
| 205 | + raw_chunks = meta.get("grounding_chunks") if isinstance(meta, dict) else [] |
| 206 | + chunks = raw_chunks if isinstance(raw_chunks, list) else [] |
| 207 | + |
| 208 | + citations: dict[int, str] = {} |
| 209 | + chunk_to_id: dict[int, int] = {} |
| 210 | + |
| 211 | + if supports and chunks: |
| 212 | + citations, chunk_to_id = _collect_citations(candidate) |
| 213 | + |
| 214 | + # Sort supports by end_index in descending order to avoid shifting issues |
| 215 | + # when inserting. |
| 216 | + sorted_supports = sorted( |
| 217 | + (s for s in supports if isinstance(s, dict) and s.get("segment")), |
| 218 | + key=lambda s: s["segment"].get("end_index", 0), |
| 219 | + reverse=True, |
| 220 | + ) |
| 221 | + |
| 222 | + for support in sorted_supports: |
| 223 | + segment = support.get("segment") or {} |
| 224 | + end_index = segment.get("end_index") |
| 225 | + if not isinstance(end_index, int) or end_index < 0 or end_index > len(text): |
| 226 | + continue |
| 227 | + indices = support.get("grounding_chunk_indices") or [] |
| 228 | + citation_links: list[str] = [] |
| 229 | + for idx in indices: |
| 230 | + if not isinstance(idx, int): |
| 231 | + continue |
| 232 | + citation_id = chunk_to_id.get(idx) |
| 233 | + if citation_id is None or idx >= len(chunks): |
| 234 | + continue |
| 235 | + web = chunks[idx].get("web") if isinstance(chunks[idx], dict) else {} |
| 236 | + uri = web.get("uri") if isinstance(web, dict) else None |
| 237 | + if uri: |
| 238 | + citation_links.append(f"[{citation_id}]({uri})") |
| 239 | + |
| 240 | + if citation_links: |
| 241 | + citation_string = ", ".join(citation_links) |
| 242 | + text = text[:end_index] + citation_string + text[end_index:] |
| 243 | + |
| 244 | + return text, citations |
| 245 | + |
| 246 | + |
| 247 | +def _collect_citations(candidate: dict) -> tuple[dict[int, str], dict[int, int]]: |
| 248 | + """Collect citation ids from a candidate dict.""" |
| 249 | + supports = candidate["grounding_metadata"]["grounding_supports"] |
| 250 | + chunks = candidate["grounding_metadata"]["grounding_chunks"] |
| 251 | + |
| 252 | + citations: dict[int, str] = {} |
| 253 | + chunk_to_id: dict[int, int] = {} |
| 254 | + next_id = 1 |
| 255 | + |
| 256 | + def label_for(chunk: dict) -> str: |
| 257 | + web = chunk.get("web") or {} |
| 258 | + title = web.get("title") |
| 259 | + uri = web.get("uri") |
| 260 | + if title: |
| 261 | + return title |
| 262 | + if uri: |
| 263 | + parsed = urlparse(uri) |
| 264 | + return parsed.hostname or parsed.netloc or uri |
| 265 | + return "unknown source" |
| 266 | + |
| 267 | + for support in supports: |
| 268 | + if not isinstance(support, dict): |
| 269 | + continue |
| 270 | + for chunk_idx in support.get("grounding_chunk_indices", []): |
| 271 | + if chunk_idx not in chunk_to_id and 0 <= chunk_idx < len(chunks): |
| 272 | + label = label_for(chunks[chunk_idx]) |
| 273 | + chunk_to_id[chunk_idx] = next_id |
| 274 | + citations[next_id] = label |
| 275 | + next_id += 1 |
| 276 | + |
| 277 | + return citations, chunk_to_id |
0 commit comments