feat: Implement exponential backoff to GeminiLLM, and enable it by default #2006

Open · wants to merge 1 commit into base: main
53 changes: 51 additions & 2 deletions src/google/adk/flows/llm_flows/base_llm_flow.py
@@ -534,7 +534,13 @@ async def _call_llm_async(
with tracer.start_as_current_span('call_llm'):
if invocation_context.run_config.support_cfc:
invocation_context.live_request_queue = LiveRequestQueue()
async for llm_response in self.run_live(invocation_context):
responses_generator = self.run_live(invocation_context)
async for llm_response in self._run_and_handle_error(
responses_generator,
invocation_context,
llm_request,
model_response_event,
):
# Runs after_model_callback if it exists.
if altered_llm_response := await self._handle_after_model_callback(
invocation_context, llm_response, model_response_event
@@ -553,10 +559,16 @@ async def _call_llm_async(
# the counter beyond the max set value, then the execution is stopped
# right here, and exception is thrown.
invocation_context.increment_llm_call_count()
async for llm_response in llm.generate_content_async(
responses_generator = llm.generate_content_async(
llm_request,
stream=invocation_context.run_config.streaming_mode
== StreamingMode.SSE,
)
async for llm_response in self._run_and_handle_error(
responses_generator,
invocation_context,
llm_request,
model_response_event,
):
trace_call_llm(
invocation_context,
@@ -673,6 +685,43 @@ def _finalize_model_response_event(

return model_response_event

async def _run_and_handle_error(
self,
response_generator: AsyncGenerator[LlmResponse, None],
invocation_context: InvocationContext,
llm_request: LlmRequest,
model_response_event: Event,
) -> AsyncGenerator[LlmResponse, None]:
"""Runs the response generator and processes the error with plugins.
Args:
response_generator: The response generator to run.
invocation_context: The invocation context.
llm_request: The LLM request.
model_response_event: The model response event.
Yields:
A generator of LlmResponse.
"""
try:
async for response in response_generator:
yield response
except Exception as model_error:
callback_context = CallbackContext(
invocation_context, event_actions=model_response_event.actions
)
error_response = (
await invocation_context.plugin_manager.run_on_model_error_callback(
callback_context=callback_context,
llm_request=llm_request,
error=model_error,
)
)
if error_response is not None:
yield error_response
else:
raise model_error

def __get_llm(self, invocation_context: InvocationContext) -> BaseLlm:
from ...agents.llm_agent import LlmAgent

18 changes: 15 additions & 3 deletions src/google/adk/flows/llm_flows/functions.py
@@ -176,9 +176,21 @@ async def handle_function_calls_async(

# Step 3: Otherwise, proceed calling the tool normally.
if function_response is None:
function_response = await __call_tool_async(
tool, args=function_args, tool_context=tool_context
)
try:
function_response = await __call_tool_async(
tool, args=function_args, tool_context=tool_context
)
except Exception as tool_error:
error_response = await invocation_context.plugin_manager.run_on_tool_error_callback(
tool=tool,
tool_args=function_args,
tool_context=tool_context,
error=tool_error,
)
if error_response is not None:
function_response = error_response
else:
raise tool_error

# Step 4: Check if plugin after_tool_callback overrides the function
# response.
91 changes: 89 additions & 2 deletions src/google/adk/models/google_llm.py
@@ -21,13 +21,22 @@
import os
import sys
from typing import AsyncGenerator
from typing import Callable
from typing import cast
from typing import Optional
from typing import TYPE_CHECKING
from typing import Union

from google.genai import Client
from google.genai import types
from google.genai.errors import ClientError
from google.genai.errors import ServerError
from google.genai.types import FinishReason
from pydantic import BaseModel
from tenacity import retry
from tenacity import retry_if_exception
from tenacity import stop_after_attempt
from tenacity import wait_exponential
from typing_extensions import override

from .. import version
@@ -48,6 +57,48 @@
_AGENT_ENGINE_TELEMETRY_ENV_VARIABLE_NAME = 'GOOGLE_CLOUD_AGENT_ENGINE_ID'


class RetryConfig(BaseModel):
"""Config for controlling retry behavior during model execution.

Use this config in agent.model. Example:
```
agent = Agent(
model=Gemini(
retry_config=RetryConfig(initial_delay_sec=10, max_retries=3)
),
...
)
```
"""

initial_delay_sec: int = 5
"""The initial delay before the first retry, in seconds."""

expo_base: int = 2
"""The exponential base to add to the retry delay."""

max_delay_sec: int = 60
"""The maximum delay before the next retry, in seconds."""

max_retries: int = 5
"""The maximum number of retries."""

retry_predicate: Optional[Callable[[Exception], bool]] = None
"""The predicate function to determine if the error is retryable."""


def retry_on_resumable_error(error: Exception) -> bool:
"""Returns True if the error is non-retryable."""
# Retry on Resource exhausted error
if isinstance(error, ClientError) and error.code == 429:
return True

# Retry on service-unavailable errors (HTTP 503).
if isinstance(error, ServerError) and error.code == 503:
return True
return False
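
The default predicate above retries only quota-exhaustion (429) and service-unavailable (503) responses. As a hedged sketch of how a caller could widen that set, the helper below can be handed to `RetryConfig` via `retry_predicate`; the name `retry_on_transient_error` and the extra 500 handling are illustrative, not part of this change.

```python
from google.genai.errors import ClientError, ServerError


def retry_on_transient_error(error: Exception) -> bool:
  """Treats quota (429), internal (500) and unavailable (503) errors as retryable."""
  if isinstance(error, ClientError) and error.code == 429:
    return True
  if isinstance(error, ServerError) and error.code in (500, 503):
    return True
  return False


# Hypothetical wiring, mirroring the RetryConfig docstring example above:
# agent = Agent(
#     model=Gemini(
#         retry_config=RetryConfig(
#             max_retries=3,
#             retry_predicate=retry_on_transient_error,
#         )
#     ),
#     ...
# )
```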


class Gemini(BaseLlm):
"""Integration for Gemini models.

@@ -57,6 +108,9 @@ class Gemini(BaseLlm):

model: str = 'gemini-1.5-flash'

retry_config: Optional[RetryConfig] = RetryConfig()
"""Use default retry config to retry on resumable model errors."""

@staticmethod
@override
def supported_models() -> list[str]:
@@ -106,7 +160,11 @@ async def generate_content_async(
llm_request.config.http_options.headers.update(self._tracking_headers)

if stream:
responses = await self.api_client.aio.models.generate_content_stream(
retry_annotation = self._build_retry_wrapper()
retryable_generate = retry_annotation(
self.api_client.aio.models.generate_content_stream
)
responses = await retryable_generate(
model=llm_request.model,
contents=llm_request.contents,
config=llm_request.config,
@@ -174,7 +232,11 @@ async def generate_content_async(
)

else:
response = await self.api_client.aio.models.generate_content(
retry_annotation = self._build_retry_wrapper()
retryable_generate = retry_annotation(
self.api_client.aio.models.generate_content
)
response = await retryable_generate(
model=llm_request.model,
contents=llm_request.contents,
config=llm_request.config,
@@ -284,6 +346,31 @@ def _preprocess_request(self, llm_request: LlmRequest) -> None:
_remove_display_name_if_present(part.inline_data)
_remove_display_name_if_present(part.file_data)

def _build_retry_wrapper(self) -> Callable:
"""Builds a retry wrapper for calls to the Gemini API client.

Under the hood this returns a tenacity retry decorator that can be
applied to any function, including async functions.

Returns:
A tenacity retry decorator that can be applied to any function.
"""
# Use the default retry config if none is specified.
config = self.retry_config or RetryConfig()
# Wrap the plain exception predicate so tenacity receives a retry strategy.
predicate = config.retry_predicate or retry_on_resumable_error
return retry(
stop=stop_after_attempt(config.max_retries),
wait=wait_exponential(
multiplier=config.initial_delay_sec,
exp_base=config.expo_base,
min=config.initial_delay_sec,
max=config.max_delay_sec,
),
retry=retry_if_exception(predicate),
reraise=True,
)
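
For reference, the snippet below is a self-contained sketch of the same tenacity combination the wrapper builds, applied to a plain coroutine so the retry behavior can be observed without calling the Gemini API. The coroutine `flaky` and its attempt counter are illustrative only.

```python
import asyncio

from tenacity import retry, retry_if_exception, stop_after_attempt, wait_exponential

attempts = 0


@retry(
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=1, min=1, max=5),
    retry=retry_if_exception(lambda e: isinstance(e, TimeoutError)),
    reraise=True,
)
async def flaky() -> str:
  """Fails twice with a retryable error, then succeeds on the third attempt."""
  global attempts
  attempts += 1
  if attempts < 3:
    raise TimeoutError('transient failure')  # matches the predicate, so it is retried
  return 'ok'


print(asyncio.run(flaky()), f'after {attempts} attempts')  # -> ok after 3 attempts
```

Because the decorated function is a coroutine, tenacity sleeps with `asyncio.sleep` between attempts, so retries do not block the event loop.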


def _build_function_declaration_log(
func_decl: types.FunctionDeclaration,
51 changes: 51 additions & 0 deletions src/google/adk/plugins/base_plugin.py
@@ -265,6 +265,31 @@ async def after_model_callback(
"""
pass

async def on_model_error_callback(
self,
*,
callback_context: CallbackContext,
llm_request: LlmRequest,
error: Exception,
) -> Optional[LlmResponse]:
"""Callback executed when a model call encounters an error.
This callback provides an opportunity to handle model errors gracefully,
potentially providing alternative responses or recovery mechanisms.
Args:
callback_context: The context for the current agent call.
llm_request: The request that was sent to the model when the error
occurred.
error: The exception that was raised during model execution.
Returns:
An optional LlmResponse. If an LlmResponse is returned, it will be used
instead of propagating the error. Returning `None` allows the original
error to be raised.
"""
pass

async def before_tool_callback(
self,
*,
@@ -315,3 +340,29 @@ async def after_tool_callback(
result.
"""
pass

async def on_tool_error_callback(
self,
*,
tool: BaseTool,
tool_args: dict[str, Any],
tool_context: ToolContext,
error: Exception,
) -> Optional[dict]:
"""Callback executed when a tool call encounters an error.
This callback provides an opportunity to handle tool errors gracefully,
potentially providing alternative responses or recovery mechanisms.
Args:
tool: The tool instance that encountered an error.
tool_args: The arguments that were passed to the tool.
tool_context: The context specific to the tool execution.
error: The exception that was raised during tool execution.
Returns:
An optional dictionary. If a dictionary is returned, it will be used as
the tool response instead of propagating the error. Returning `None`
allows the original error to be raised.
"""
pass
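
To show how these hooks are consumed, here is a minimal sketch of a plugin that overrides both error callbacks and substitutes fallback responses instead of letting the exceptions propagate. The class name `FallbackPlugin`, the canned response texts, and the import paths (inferred from the file layout in this diff, not verified against a released package) are assumptions.

```python
from typing import Any, Optional

from google.adk.agents.callback_context import CallbackContext
from google.adk.models.llm_request import LlmRequest
from google.adk.models.llm_response import LlmResponse
from google.adk.plugins.base_plugin import BasePlugin
from google.adk.tools.base_tool import BaseTool
from google.adk.tools.tool_context import ToolContext
from google.genai import types


class FallbackPlugin(BasePlugin):
  """Swallows model and tool errors and returns canned responses instead."""

  async def on_model_error_callback(
      self,
      *,
      callback_context: CallbackContext,
      llm_request: LlmRequest,
      error: Exception,
  ) -> Optional[LlmResponse]:
    # Returning an LlmResponse suppresses the error; returning None re-raises it.
    return LlmResponse(
        content=types.Content(
            role='model',
            parts=[types.Part(text='The model is temporarily unavailable.')],
        )
    )

  async def on_tool_error_callback(
      self,
      *,
      tool: BaseTool,
      tool_args: dict[str, Any],
      tool_context: ToolContext,
      error: Exception,
  ) -> Optional[dict]:
    # Returning a dict is used as the tool response; returning None re-raises.
    return {'error': f'{tool.name} failed: {error}'}
```

A plugin along these lines would then be registered with the runner so that the `plugin_manager` calls added in `base_llm_flow.py` and `functions.py` above can reach it.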