
Commit 785ca6f

pwwpche authored and copybara-github committed
feat: Implement exponential backoff to GeminiLLM, and enable it by default
Exponential backoff is enabled by default. It starts with a 5-second delay and doubles (base-2 exponential backoff) on each subsequent retry, with a maximum delay of 60 seconds. As a result, retries fire at 5s, 10s, 20s, 40s, 60s.

Usage: It is enabled by default, but is configurable during agent declaration:

```python
root_agent = Agent(
    model=Gemini(
        model='gemini-2.0-flash',
        retry_config=RetryConfig(initial_delay_sec=60, max_retries=3),
    ),
    ...
)
```

Note: This config cannot be added to RunConfig. Although RunConfig holds similar settings, it is only available through invocation_context, which is not accessible to BaseLLM or any derived LLM classes.

Tested locally:

```bash
The description about you is "Checks if input is valid using predefined tools"'
[logging_plugin] Available Tools: ['check_valid_input', 'check_valid_input2']
2025-07-21 18:00:07.767904
2025-07-21 18:00:12.776910
2025-07-21 18:00:22.792078
2025-07-21 18:00:42.817873
2025-07-21 18:01:22.856147
[logging_plugin] 🧠 LLM ERROR
[logging_plugin] Agent: check_input
[logging_plugin] Error: 503 None. {}
```

PiperOrigin-RevId: 784029405
1 parent 18f5bea commit 785ca6f
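
For reference, the delay schedule described above follows from tenacity's `wait_exponential` with the `RetryConfig` defaults (5s initial delay, exponential base 2, 60s cap). A minimal, ADK-independent sketch that reproduces the computed delays (the variable names mirror the config fields; this is illustrative, not code from the commit):

```python
# Illustrative only: mirrors wait_exponential(multiplier=5, min=5, max=60),
# i.e. delay = clamp(initial_delay_sec * expo_base ** (attempt - 1), 5, 60).
initial_delay_sec = 5   # RetryConfig.initial_delay_sec default
expo_base = 2           # RetryConfig.expo_base default
max_delay_sec = 60      # RetryConfig.max_delay_sec default
max_retries = 5         # RetryConfig.max_retries default

delays = [
    min(max(initial_delay_sec * expo_base ** (attempt - 1), initial_delay_sec),
        max_delay_sec)
    for attempt in range(1, max_retries + 1)
]
print(delays)  # [5, 10, 20, 40, 60]
```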

File tree

12 files changed: 906 additions, 7 deletions

src/google/adk/flows/llm_flows/base_llm_flow.py

Lines changed: 51 additions & 2 deletions
@@ -534,7 +534,13 @@ async def _call_llm_async(
     with tracer.start_as_current_span('call_llm'):
       if invocation_context.run_config.support_cfc:
         invocation_context.live_request_queue = LiveRequestQueue()
-        async for llm_response in self.run_live(invocation_context):
+        responses_generator = self.run_live(invocation_context)
+        async for llm_response in self._run_and_handle_error(
+            responses_generator,
+            invocation_context,
+            llm_request,
+            model_response_event,
+        ):
           # Runs after_model_callback if it exists.
           if altered_llm_response := await self._handle_after_model_callback(
               invocation_context, llm_response, model_response_event
@@ -553,10 +559,16 @@ async def _call_llm_async(
         # the counter beyond the max set value, then the execution is stopped
         # right here, and exception is thrown.
         invocation_context.increment_llm_call_count()
-        async for llm_response in llm.generate_content_async(
+        responses_generator = llm.generate_content_async(
             llm_request,
             stream=invocation_context.run_config.streaming_mode
             == StreamingMode.SSE,
+        )
+        async for llm_response in self._run_and_handle_error(
+            responses_generator,
+            invocation_context,
+            llm_request,
+            model_response_event,
         ):
           trace_call_llm(
               invocation_context,
@@ -673,6 +685,43 @@ def _finalize_model_response_event(

     return model_response_event

+  async def _run_and_handle_error(
+      self,
+      response_generator: AsyncGenerator[LlmResponse, None],
+      invocation_context: InvocationContext,
+      llm_request: LlmRequest,
+      model_response_event: Event,
+  ) -> AsyncGenerator[LlmResponse, None]:
+    """Runs the response generator and processes the error with plugins.
+
+    Args:
+      response_generator: The response generator to run.
+      invocation_context: The invocation context.
+      llm_request: The LLM request.
+      model_response_event: The model response event.
+
+    Yields:
+      A generator of LlmResponse.
+    """
+    try:
+      async for response in response_generator:
+        yield response
+    except Exception as model_error:
+      callback_context = CallbackContext(
+          invocation_context, event_actions=model_response_event.actions
+      )
+      error_response = (
+          await invocation_context.plugin_manager.run_on_model_error_callback(
+              callback_context=callback_context,
+              llm_request=llm_request,
+              error=model_error,
+          )
+      )
+      if error_response is not None:
+        yield error_response
+      else:
+        raise model_error
+
   def __get_llm(self, invocation_context: InvocationContext) -> BaseLlm:
     from ...agents.llm_agent import LlmAgent

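To see the shape of the new `_run_and_handle_error` helper in isolation, here is a self-contained sketch of the same pattern with toy names (`run_and_handle_error`, `fake_model_stream`, and `on_error` are illustrative stand-ins, not ADK code): the wrapped async generator is drained inside `try`, and on failure a handler either supplies a fallback item or the original exception is re-raised.

```python
import asyncio
from typing import AsyncGenerator, Callable, Optional


async def run_and_handle_error(
    gen: AsyncGenerator[str, None],
    on_error: Callable[[Exception], Optional[str]],
) -> AsyncGenerator[str, None]:
    """Yields from `gen`; routes failures to `on_error` for a fallback value."""
    try:
        async for item in gen:
            yield item
    except Exception as err:
        fallback = on_error(err)
        if fallback is not None:
            yield fallback  # handled: emit the recovery value instead
        else:
            raise  # unhandled: propagate the original error


async def fake_model_stream() -> AsyncGenerator[str, None]:
    yield 'partial response'
    raise RuntimeError('503 Service Unavailable')


async def main() -> None:
    handler = lambda err: f'fallback after: {err}'
    async for item in run_and_handle_error(fake_model_stream(), handler):
        print(item)


asyncio.run(main())
```
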
src/google/adk/flows/llm_flows/functions.py

Lines changed: 15 additions & 3 deletions
@@ -176,9 +176,21 @@ async def handle_function_calls_async(

     # Step 3: Otherwise, proceed calling the tool normally.
     if function_response is None:
-      function_response = await __call_tool_async(
-          tool, args=function_args, tool_context=tool_context
-      )
+      try:
+        function_response = await __call_tool_async(
+            tool, args=function_args, tool_context=tool_context
+        )
+      except Exception as tool_error:
+        error_response = await invocation_context.plugin_manager.run_on_tool_error_callback(
+            tool=tool,
+            tool_args=function_args,
+            tool_context=tool_context,
+            error=tool_error,
+        )
+        if error_response is not None:
+          function_response = error_response
+        else:
+          raise tool_error

     # Step 4: Check if plugin after_tool_callback overrides the function
     # response.

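For context, this is how a plugin might use the new hook. A hedged sketch, assuming the import paths below (inferred from the ADK source layout in this commit) and that `BasePlugin` takes a plugin name in its constructor; the class name and returned payload are illustrative:

```python
from typing import Any, Optional

from google.adk.plugins.base_plugin import BasePlugin
from google.adk.tools.base_tool import BaseTool
from google.adk.tools.tool_context import ToolContext


class ToolErrorToResponsePlugin(BasePlugin):
    """Hypothetical plugin: converts tool exceptions into structured responses."""

    def __init__(self) -> None:
        super().__init__(name='tool_error_to_response')

    async def on_tool_error_callback(
        self,
        *,
        tool: BaseTool,
        tool_args: dict[str, Any],
        tool_context: ToolContext,
        error: Exception,
    ) -> Optional[dict]:
        # Returning a dict makes handle_function_calls_async use it as the
        # function response; returning None would re-raise the original error.
        return {
            'status': 'error',
            'tool': tool.name,
            'message': str(error),
        }
```

Registered alongside other plugins, a handler like this keeps a single failing tool call from aborting the whole agent turn.
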
src/google/adk/models/google_llm.py

Lines changed: 89 additions & 2 deletions
@@ -21,13 +21,22 @@
 import os
 import sys
 from typing import AsyncGenerator
+from typing import Callable
 from typing import cast
+from typing import Optional
 from typing import TYPE_CHECKING
 from typing import Union

 from google.genai import Client
 from google.genai import types
+from google.genai.errors import ClientError
+from google.genai.errors import ServerError
 from google.genai.types import FinishReason
+from pydantic import BaseModel
+from tenacity import retry
+from tenacity import retry_if_exception
+from tenacity import stop_after_attempt
+from tenacity import wait_exponential
 from typing_extensions import override

 from .. import version
@@ -48,6 +57,48 @@
 _AGENT_ENGINE_TELEMETRY_ENV_VARIABLE_NAME = 'GOOGLE_CLOUD_AGENT_ENGINE_ID'


+class RetryConfig(BaseModel):
+  """Config for controlling retry behavior during model execution.
+
+  Use this config in agent.model. Example:
+  ```
+  agent = Agent(
+      model=Gemini(
+          retry_config=RetryConfig(initial_delay_sec=10, max_retries=3)
+      ),
+      ...
+  )
+  ```
+  """
+
+  initial_delay_sec: int = 5
+  """The initial delay before the first retry, in seconds."""
+
+  expo_base: int = 2
+  """The exponential base applied to the retry delay."""
+
+  max_delay_sec: int = 60
+  """The maximum delay before the next retry, in seconds."""
+
+  max_retries: int = 5
+  """The maximum number of retries."""
+
+  retry_predicate: Callable[[Exception], bool] = None
+  """The predicate function to determine if the error is retryable."""
+
+
+def retry_on_resumable_error(error: Exception) -> bool:
+  """Returns True if the error is retryable (resumable)."""
+  # Retry on Resource exhausted error
+  if isinstance(error, ClientError) and error.code == 429:
+    return True
+
+  # Retry on Service unavailable error
+  if isinstance(error, ServerError) and error.code == 503:
+    return True
+  return False
+
+
 class Gemini(BaseLlm):
   """Integration for Gemini models.

@@ -57,6 +108,9 @@ class Gemini(BaseLlm):

   model: str = 'gemini-1.5-flash'

+  retry_config: Optional[RetryConfig] = RetryConfig()
+  """Use default retry config to retry on resumable model errors."""
+
   @staticmethod
   @override
   def supported_models() -> list[str]:
@@ -106,7 +160,11 @@ async def generate_content_async(
     llm_request.config.http_options.headers.update(self._tracking_headers)

     if stream:
-      responses = await self.api_client.aio.models.generate_content_stream(
+      retry_annotation = self._build_retry_wrapper()
+      retryable_generate = retry_annotation(
+          self.api_client.aio.models.generate_content_stream
+      )
+      responses = await retryable_generate(
           model=llm_request.model,
           contents=llm_request.contents,
           config=llm_request.config,
@@ -174,7 +232,11 @@
       )

     else:
-      response = await self.api_client.aio.models.generate_content(
+      retry_annotation = self._build_retry_wrapper()
+      retryable_generate = retry_annotation(
+          self.api_client.aio.models.generate_content
+      )
+      response = await retryable_generate(
           model=llm_request.model,
           contents=llm_request.contents,
           config=llm_request.config,
@@ -284,6 +346,31 @@ def _preprocess_request(self, llm_request: LlmRequest) -> None:
             _remove_display_name_if_present(part.inline_data)
             _remove_display_name_if_present(part.file_data)

+  def _build_retry_wrapper(self) -> retry:
+    """Apply retry logic to the Gemini API client.
+
+    Under the hood, this returns a tenacity.retry annotation that can be
+    applied to any function. Works for async functions as well.
+
+    Returns:
+      A tenacity.retry annotation that can be applied to any function.
+    """
+    # Use default retry config if not specified.
+    config = self.retry_config or RetryConfig()
+    retry_predicate = config.retry_predicate
+    if not retry_predicate:
+      retry_predicate = retry_if_exception(retry_on_resumable_error)
+    return retry(
+        stop=stop_after_attempt(config.max_retries),
+        wait=wait_exponential(
+            multiplier=config.initial_delay_sec,
+            min=config.initial_delay_sec,
+            max=config.max_delay_sec,
+        ),
+        retry=retry_predicate,
+        reraise=True,
+    )
+

 def _build_function_declaration_log(
     func_decl: types.FunctionDeclaration,

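`_build_retry_wrapper` returns a plain `tenacity.retry` decorator, so its behavior can be previewed on any async callable outside the ADK. A small sketch with the same knobs (the flaky `fetch` coroutine is illustrative, and the delays are shortened so the example runs quickly; comments note the commit's defaults):

```python
import asyncio

from tenacity import retry, retry_if_exception, stop_after_attempt, wait_exponential


class TransientError(Exception):
    """Stand-in for a retryable failure such as a 429 or 503 from the API."""


attempts = {'count': 0}


@retry(
    stop=stop_after_attempt(5),  # RetryConfig.max_retries default
    wait=wait_exponential(multiplier=0.1, min=0.1, max=1),  # commit defaults: 5, 5, 60
    retry=retry_if_exception(lambda e: isinstance(e, TransientError)),
    reraise=True,  # surface the last error unchanged once retries are exhausted
)
async def fetch() -> str:
    attempts['count'] += 1
    if attempts['count'] < 3:
        raise TransientError('service unavailable')
    return 'ok'


print(asyncio.run(fetch()))  # succeeds on the third attempt
```
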
src/google/adk/plugins/base_plugin.py

Lines changed: 51 additions & 0 deletions
@@ -265,6 +265,31 @@ async def after_model_callback(
     """
     pass

+  async def on_model_error_callback(
+      self,
+      *,
+      callback_context: CallbackContext,
+      llm_request: LlmRequest,
+      error: Exception,
+  ) -> Optional[LlmResponse]:
+    """Callback executed when a model call encounters an error.
+
+    This callback provides an opportunity to handle model errors gracefully,
+    potentially providing alternative responses or recovery mechanisms.
+
+    Args:
+      callback_context: The context for the current agent call.
+      llm_request: The request that was sent to the model when the error
+        occurred.
+      error: The exception that was raised during model execution.
+
+    Returns:
+      An optional LlmResponse. If an LlmResponse is returned, it will be used
+      instead of propagating the error. Returning `None` allows the original
+      error to be raised.
+    """
+    pass
+
   async def before_tool_callback(
       self,
       *,
@@ -315,3 +340,29 @@ async def after_tool_callback(
        result.
     """
     pass
+
+  async def on_tool_error_callback(
+      self,
+      *,
+      tool: BaseTool,
+      tool_args: dict[str, Any],
+      tool_context: ToolContext,
+      error: Exception,
+  ) -> Optional[dict]:
+    """Callback executed when a tool call encounters an error.
+
+    This callback provides an opportunity to handle tool errors gracefully,
+    potentially providing alternative responses or recovery mechanisms.
+
+    Args:
+      tool: The tool instance that encountered an error.
+      tool_args: The arguments that were passed to the tool.
+      tool_context: The context specific to the tool execution.
+      error: The exception that was raised during tool execution.
+
+    Returns:
+      An optional dictionary. If a dictionary is returned, it will be used as
+      the tool response instead of propagating the error. Returning `None`
+      allows the original error to be raised.
+    """
+    pass

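The model-side hook is symmetric. A hedged sketch of a plugin that substitutes a plain-text reply when a model call fails (the class name and message are illustrative; the import paths and the `BasePlugin` constructor signature are assumptions, as in the tool example above):

```python
from typing import Optional

from google.adk.agents.callback_context import CallbackContext
from google.adk.models.llm_request import LlmRequest
from google.adk.models.llm_response import LlmResponse
from google.adk.plugins.base_plugin import BasePlugin
from google.genai import types


class ModelErrorFallbackPlugin(BasePlugin):
    """Hypothetical plugin: replaces model failures with a fallback response."""

    def __init__(self) -> None:
        super().__init__(name='model_error_fallback')

    async def on_model_error_callback(
        self,
        *,
        callback_context: CallbackContext,
        llm_request: LlmRequest,
        error: Exception,
    ) -> Optional[LlmResponse]:
        # Returning an LlmResponse makes _run_and_handle_error yield it instead
        # of re-raising; returning None would propagate the original error.
        return LlmResponse(
            content=types.Content(
                role='model',
                parts=[types.Part(text=f'Model call failed and was handled: {error}')],
            )
        )
```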