Commit 157483f

pwwpche authored and copybara-github committed
feat: Implement exponential backoff to GeminiLLM
Output:

```bash
Retry attempt: 4, after 32.799424
[logging_plugin] 🧠 LLM ERROR
[logging_plugin] Agent: check_input
[logging_plugin] Error: 429 None. {} <bound method AsyncModels.generate_content of <google.genai.models.AsyncModels object at 0x7f5c905c38d0>> True
Retrying...
Retry attempt: 1, after 1.6e-05
Retry attempt: 2, after 3.999309
Retry attempt: 3, after 5.987343
Retry attempt: 4, after 14.075334
```

PiperOrigin-RevId: 784029405
1 parent 31fa5d9 commit 157483f
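
The retry waits in the log above are consistent with the policy added in `google_llm.py`: exponential ceilings of 5 * 2**n seconds per attempt, with the backoff library's default jitter keeping each actual wait below its ceiling. A quick, illustrative way to print the raw schedule (not part of the commit):

```python
import itertools

import backoff

# Illustrative only: the raw exponential wait ceilings used by the new retry
# policy (wait_gen=backoff.expo with factor=5). backoff.on_exception applies
# jitter on top of these by default, which is why the logged waits are smaller.
for attempt, ceiling in enumerate(itertools.islice(backoff.expo(factor=5), 4), 1):
  print(f'Retry attempt: {attempt}, ceiling ~{ceiling}s')  # 5, 10, 20, 40
```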

14 files changed (+976, -8 lines)


pyproject.toml

Lines changed: 1 addition & 0 deletions
```diff
@@ -28,6 +28,7 @@ dependencies = [
   "PyYAML>=6.0.2", # For APIHubToolset.
   "anyio>=4.9.0;python_version>='3.10'", # For MCP Session Manager
   "authlib>=1.5.1", # For RestAPI Tool
+  "backoff>=2.2.1", # For LLM Retries
   "click>=8.1.8", # For CLI tools
   "fastapi>=0.115.0", # FastAPI framework
   "google-api-python-client>=2.157.0", # Google API client discovery
```

src/google/adk/flows/llm_flows/base_llm_flow.py

Lines changed: 65 additions & 2 deletions
```diff
@@ -17,9 +17,11 @@
 from abc import ABC
 import asyncio
 import datetime
+from enum import Enum
 import inspect
 import logging
 from typing import AsyncGenerator
+from typing import Callable
 from typing import cast
 from typing import Optional
 from typing import TYPE_CHECKING
@@ -36,6 +38,7 @@
 from ...agents.run_config import StreamingMode
 from ...agents.transcription_entry import TranscriptionEntry
 from ...events.event import Event
+from ...models.base_llm import ModelErrorStrategy
 from ...models.base_llm_connection import BaseLlmConnection
 from ...models.llm_request import LlmRequest
 from ...models.llm_response import LlmResponse
@@ -521,7 +524,13 @@ async def _call_llm_async(
     with tracer.start_as_current_span('call_llm'):
       if invocation_context.run_config.support_cfc:
         invocation_context.live_request_queue = LiveRequestQueue()
-        async for llm_response in self.run_live(invocation_context):
+        responses_generator = lambda: self.run_live(invocation_context)
+        async for llm_response in self._run_and_handle_error(
+            responses_generator,
+            invocation_context,
+            llm_request,
+            model_response_event,
+        ):
           # Runs after_model_callback if it exists.
           if altered_llm_response := await self._handle_after_model_callback(
               invocation_context, llm_response, model_response_event
@@ -540,10 +549,16 @@ async def _call_llm_async(
         # the counter beyond the max set value, then the execution is stopped
         # right here, and exception is thrown.
         invocation_context.increment_llm_call_count()
-        async for llm_response in llm.generate_content_async(
+        responses_generator = lambda: llm.generate_content_async(
             llm_request,
             stream=invocation_context.run_config.streaming_mode
            == StreamingMode.SSE,
+        )
+        async for llm_response in self._run_and_handle_error(
+            responses_generator,
+            invocation_context,
+            llm_request,
+            model_response_event,
         ):
           trace_call_llm(
               invocation_context,
@@ -660,6 +675,54 @@ def _finalize_model_response_event(
 
     return model_response_event
 
+  async def _run_and_handle_error(
+      self,
+      response_generator: Callable[..., AsyncGenerator[LlmResponse, None]],
+      invocation_context: InvocationContext,
+      llm_request: LlmRequest,
+      model_response_event: Event,
+  ) -> AsyncGenerator[LlmResponse, None]:
+    """Runs the response generator and processes the error with plugins.
+
+    Args:
+      response_generator: The response generator to run.
+      invocation_context: The invocation context.
+      llm_request: The LLM request.
+      model_response_event: The model response event.
+
+    Yields:
+      A generator of LlmResponse.
+    """
+    while True:
+      try:
+        responses_generator_instance = response_generator()
+        async for response in responses_generator_instance:
+          yield response
+        break
+      except Exception as model_error:
+        callback_context = CallbackContext(
+            invocation_context, event_actions=model_response_event.actions
+        )
+        outcome = (
+            await invocation_context.plugin_manager.run_on_model_error_callback(
+                callback_context=callback_context,
+                llm_request=llm_request,
+                error=model_error,
+            )
+        )
+        # Retry the LLM call if the plugin outcome is RETRY.
+        if outcome == ModelErrorStrategy.RETRY:
+          continue
+
+        # If the plugin outcome is PASS, we can break the loop.
+        if outcome == ModelErrorStrategy.PASS:
+          break
+        if outcome is not None:
+          yield outcome
+          break
+        else:
+          raise model_error
+
   def __get_llm(self, invocation_context: InvocationContext) -> BaseLlm:
     from ...agents.llm_agent import LlmAgent
 
```

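The heart of the change is `_run_and_handle_error`: it re-creates the response generator and replays it whenever a plugin asks for a retry, yields a fallback response if one is returned, and otherwise re-raises. A stripped-down, runnable sketch of the same control flow, with the plugin machinery replaced by a plain `on_error` callable and the string constants mirroring `ModelErrorStrategy`:

```python
import asyncio
from typing import AsyncGenerator, Callable, Optional

RETRY, PASS = 'RETRY', 'PASS'  # mirrors ModelErrorStrategy


async def run_and_handle_error(
    make_stream: Callable[[], AsyncGenerator[str, None]],
    on_error: Callable[[Exception], Optional[str]],
) -> AsyncGenerator[str, None]:
  """Replays make_stream until it completes, consulting on_error on failure."""
  while True:
    try:
      async for item in make_stream():
        yield item
      break  # stream finished normally
    except Exception as err:
      outcome = on_error(err)
      if outcome == RETRY:  # re-create the generator and try again
        continue
      if outcome == PASS:  # swallow the error and stop
        break
      if outcome is not None:  # use the outcome as a fallback response
        yield outcome
        break
      raise  # no handler decision: propagate the original error


async def main():
  attempts = 0

  async def flaky() -> AsyncGenerator[str, None]:
    nonlocal attempts
    attempts += 1
    if attempts < 3:
      raise RuntimeError('transient failure')
    yield 'ok'

  async for item in run_and_handle_error(flaky, lambda _: RETRY):
    print(item)  # prints 'ok' after two retries


asyncio.run(main())
```
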
src/google/adk/flows/llm_flows/functions.py

Lines changed: 15 additions & 3 deletions
```diff
@@ -176,9 +176,21 @@ async def handle_function_calls_async(
 
     # Step 3: Otherwise, proceed calling the tool normally.
     if function_response is None:
-      function_response = await __call_tool_async(
-          tool, args=function_args, tool_context=tool_context
-      )
+      try:
+        function_response = await __call_tool_async(
+            tool, args=function_args, tool_context=tool_context
+        )
+      except Exception as tool_error:
+        error_response = await invocation_context.plugin_manager.run_on_tool_error_callback(
+            tool=tool,
+            tool_args=function_args,
+            tool_context=tool_context,
+            error=tool_error,
+        )
+        if error_response is not None:
+          function_response = error_response
+        else:
+          raise tool_error
 
     # Step 4: Check if plugin after_tool_callback overrides the function
     # response.
```
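
With this hook in place, a plugin can convert a failed tool call into a normal function response. A sketch of such a plugin is below; the class name and error payload are illustrative, the absolute import paths are inferred from the relative imports in this diff, and the `name=` constructor argument is assumed:

```python
from __future__ import annotations

from typing import Any, Optional

from google.adk.plugins.base_plugin import BasePlugin
from google.adk.tools.base_tool import BaseTool
from google.adk.tools.tool_context import ToolContext


class ToolErrorFallbackPlugin(BasePlugin):
  """Illustrative plugin: turn tool exceptions into an error payload."""

  def __init__(self):
    super().__init__(name='tool_error_fallback')  # assumes BasePlugin takes a name

  async def on_tool_error_callback(
      self,
      *,
      tool: BaseTool,
      tool_args: dict[str, Any],
      tool_context: ToolContext,
      error: Exception,
  ) -> Optional[dict]:
    # Returning a dict replaces the tool response; returning None re-raises.
    return {'error': f'{tool.name} failed: {error}'}
```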

src/google/adk/models/base_llm.py

Lines changed: 5 additions & 0 deletions
```diff
@@ -28,6 +28,11 @@
 from .llm_response import LlmResponse
 
 
+class ModelErrorStrategy:
+  RETRY = 'RETRY'
+  PASS = 'PASS'
+
+
 class BaseLlm(BaseModel):
   """The BaseLLM class.
 
```

src/google/adk/models/google_llm.py

Lines changed: 43 additions & 2 deletions
```diff
@@ -17,6 +17,7 @@
 
 import contextlib
 from functools import cached_property
+from functools import partial
 import logging
 import os
 import sys
@@ -25,8 +26,11 @@
 from typing import TYPE_CHECKING
 from typing import Union
 
+import backoff
 from google.genai import Client
 from google.genai import types
+from google.genai.errors import ClientError
+from google.genai.errors import ServerError
 from google.genai.types import FinishReason
 from typing_extensions import override
 
@@ -46,6 +50,19 @@
 _EXCLUDED_PART_FIELD = {'inline_data': {'data'}}
 _AGENT_ENGINE_TELEMETRY_TAG = 'remote_reasoning_engine'
 _AGENT_ENGINE_TELEMETRY_ENV_VARIABLE_NAME = 'GOOGLE_CLOUD_AGENT_ENGINE_ID'
+DEFAULT_NETWORK_RETRIES = 4
+
+
+def give_up_non_retryable_error(error: Exception) -> bool:
+  """Returns True if the error is non-retryable."""
+  # Retry on Resource exhausted error
+  if isinstance(error, ClientError) and error.code == 429:
+    return False
+
+  # Retry on Service unavailable error
+  if isinstance(error, ServerError) and error.code == 503:
+    return False
+  return True
 
 
 class Gemini(BaseLlm):
@@ -57,6 +74,22 @@ class Gemini(BaseLlm):
 
   model: str = 'gemini-1.5-flash'
 
+  # Implementation for exponential backoff.
+  # Note: backoff.on_exception is a decorator, but it cannot properly decorate
+  # the generate_content_async function.
+  # This is because generate_content_async is an AsyncGenerator.
+  # The backoff library uses asyncio.iscoroutinefunction to detect sync/async
+  # functions, which cannot recognize an AsyncGenerator. Consequently, backoff
+  # fails to catch the exception raised in generate_content_async.
+  # Therefore, it requires applying the decorator manually to the inner calls
+  # to genai.client.
+  exponential_backoff = backoff.on_exception(
+      wait_gen=partial(backoff.expo, factor=5),  # 5 seconds * 2 ^ retries
+      exception=(ClientError, ServerError),
+      max_tries=DEFAULT_NETWORK_RETRIES,
+      giveup=give_up_non_retryable_error,
+  )
+
   @staticmethod
   @override
   def supported_models() -> list[str]:
@@ -104,7 +137,13 @@ async def generate_content_async(
     llm_request.config.http_options.headers.update(self._tracking_headers)
 
     if stream:
-      responses = await self.api_client.aio.models.generate_content_stream(
+      generate_content_stream_func = (
+          self.api_client.aio.models.generate_content_stream
+      )
+      response_stream_gen = Gemini.exponential_backoff(
+          generate_content_stream_func
+      )
+      responses = await response_stream_gen(
           model=llm_request.model,
           contents=llm_request.contents,
           config=llm_request.config,
@@ -172,7 +211,9 @@ async def generate_content_async(
       )
 
     else:
-      response = await self.api_client.aio.models.generate_content(
+      generate_content_func = self.api_client.aio.models.generate_content
+      responses = Gemini.exponential_backoff(generate_content_func)
+      response = await responses(
           model=llm_request.model,
           contents=llm_request.contents,
           config=llm_request.config,
```
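
As the comment in the diff explains, `backoff.on_exception` cannot wrap `generate_content_async` itself because it is an async generator, so the decorator is applied at call time to the inner client coroutine. A small self-contained sketch of that call-time pattern under the same policy shape; `fake_generate_content` is a stand-in for the genai client call, not a real API:

```python
import asyncio
from functools import partial

import backoff


class FakeServerError(Exception):
  """Stands in for google.genai.errors.ServerError in this sketch."""


# Same shape as Gemini.exponential_backoff: wait ceiling of 5 * 2**n seconds,
# at most 4 tries (expect a few seconds of sleep when the call fails).
retry_policy = backoff.on_exception(
    wait_gen=partial(backoff.expo, factor=5),
    exception=(FakeServerError,),
    max_tries=4,
)

calls = 0


async def fake_generate_content(prompt: str) -> str:
  """Toy stand-in for self.api_client.aio.models.generate_content."""
  global calls
  calls += 1
  if calls < 3:
    raise FakeServerError('503 service unavailable')
  return f'response to {prompt!r}'


async def main():
  # Apply the decorator to the plain coroutine at call time, as the diff does
  # for the inner genai client methods.
  wrapped = retry_policy(fake_generate_content)
  print(await wrapped('hello'))


asyncio.run(main())
```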

src/google/adk/plugins/base_plugin.py

Lines changed: 55 additions & 1 deletion
```diff
@@ -25,10 +25,10 @@
 from ..agents.base_agent import BaseAgent
 from ..agents.callback_context import CallbackContext
 from ..events.event import Event
+from ..models.base_llm import ModelErrorStrategy
 from ..models.llm_request import LlmRequest
 from ..models.llm_response import LlmResponse
 from ..tools.base_tool import BaseTool
-from ..utils.feature_decorator import working_in_progress
 
 if TYPE_CHECKING:
   from ..agents.invocation_context import InvocationContext
@@ -265,6 +265,34 @@ async def after_model_callback(
     """
     pass
 
+  async def on_model_error_callback(
+      self,
+      *,
+      callback_context: CallbackContext,
+      llm_request: LlmRequest,
+      error: Exception,
+  ) -> Optional[LlmResponse | ModelErrorStrategy]:
+    """Callback executed when a model call encounters an error.
+
+    This callback provides an opportunity to handle model errors gracefully,
+    potentially providing alternative responses or recovery mechanisms.
+
+    Args:
+      callback_context: The context for the current agent call.
+      llm_request: The request that was sent to the model when the error
+        occurred.
+      error: The exception that was raised during model execution.
+
+    Returns:
+      An optional LlmResponse. If an LlmResponse is returned, it will be used
+      instead of propagating the error.
+      Returning `ModelErrorStrategy.RETRY` will retry the LLM call.
+      Returning `ModelErrorStrategy.PASS` will allow the LLM call to
+      proceed normally and ignore the error.
+      Returning `None` allows the original error to be raised.
+    """
+    pass
+
   async def before_tool_callback(
       self,
       *,
@@ -315,3 +343,29 @@ async def after_tool_callback(
       result.
     """
     pass
+
+  async def on_tool_error_callback(
+      self,
+      *,
+      tool: BaseTool,
+      tool_args: dict[str, Any],
+      tool_context: ToolContext,
+      error: Exception,
+  ) -> Optional[dict]:
+    """Callback executed when a tool call encounters an error.
+
+    This callback provides an opportunity to handle tool errors gracefully,
+    potentially providing alternative responses or recovery mechanisms.
+
+    Args:
+      tool: The tool instance that encountered an error.
+      tool_args: The arguments that were passed to the tool.
+      tool_context: The context specific to the tool execution.
+      error: The exception that was raised during tool execution.
+
+    Returns:
+      An optional dictionary. If a dictionary is returned, it will be used as
+      the tool response instead of propagating the error. Returning `None`
+      allows the original error to be raised.
+    """
+    pass
```
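
Putting the pieces together, a plugin can ask the flow to retry on quota errors by returning `ModelErrorStrategy.RETRY` from `on_model_error_callback`. A sketch follows; the plugin name is illustrative, the absolute import paths are inferred from the relative imports in this diff, and the `name=` constructor argument is assumed:

```python
from __future__ import annotations

from typing import Optional

from google.adk.agents.callback_context import CallbackContext
from google.adk.models.base_llm import ModelErrorStrategy
from google.adk.models.llm_request import LlmRequest
from google.adk.models.llm_response import LlmResponse
from google.adk.plugins.base_plugin import BasePlugin
from google.genai.errors import ClientError


class RetryOnQuotaPlugin(BasePlugin):
  """Illustrative plugin: retry the model call when quota is exhausted."""

  def __init__(self):
    super().__init__(name='retry_on_quota')  # assumes BasePlugin takes a name

  async def on_model_error_callback(
      self,
      *,
      callback_context: CallbackContext,
      llm_request: LlmRequest,
      error: Exception,
  ) -> Optional[LlmResponse | ModelErrorStrategy]:
    if isinstance(error, ClientError) and error.code == 429:
      return ModelErrorStrategy.RETRY  # the flow re-runs the model call
    return None  # any other error propagates
```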
