From 806c08ed5b5ce7605b7da832691217793ecef7dc Mon Sep 17 00:00:00 2001 From: hcadioli Date: Mon, 27 Oct 2025 09:55:25 -0300 Subject: [PATCH 1/3] fix: cache canonical tools to avoid multiple calls when streaming --- src/google/adk/agents/invocation_context.py | 4 ++++ src/google/adk/flows/llm_flows/base_llm_flow.py | 5 ++++- 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/src/google/adk/agents/invocation_context.py b/src/google/adk/agents/invocation_context.py index e972a65eda..24fdce9d59 100644 --- a/src/google/adk/agents/invocation_context.py +++ b/src/google/adk/agents/invocation_context.py @@ -32,6 +32,7 @@ from ..plugins.plugin_manager import PluginManager from ..sessions.base_session_service import BaseSessionService from ..sessions.session import Session +from ..tools.base_tool import BaseTool from .active_streaming_tool import ActiveStreamingTool from .base_agent import BaseAgent from .base_agent import BaseAgentState @@ -202,6 +203,9 @@ class InvocationContext(BaseModel): plugin_manager: PluginManager = Field(default_factory=PluginManager) """The manager for keeping track of plugins in this invocation.""" + canonical_tools_cache: Optional[list[BaseTool]] = None + """The cache of canonical tools for this invocation.""" + _invocation_cost_manager: _InvocationCostManager = PrivateAttr( default_factory=_InvocationCostManager ) diff --git a/src/google/adk/flows/llm_flows/base_llm_flow.py b/src/google/adk/flows/llm_flows/base_llm_flow.py index 531a5034c8..246ba2974e 100644 --- a/src/google/adk/flows/llm_flows/base_llm_flow.py +++ b/src/google/adk/flows/llm_flows/base_llm_flow.py @@ -842,7 +842,10 @@ async def _maybe_add_grounding_metadata( response: Optional[LlmResponse] = None, ) -> Optional[LlmResponse]: readonly_context = ReadonlyContext(invocation_context) - tools = await agent.canonical_tools(readonly_context) + if (tools := invocation_context.canonical_tools_cache) is None: + tools = await agent.canonical_tools(readonly_context) + invocation_context.canonical_tools_cache = tools + if not any(tool.name == 'google_search_agent' for tool in tools): return response ground_metadata = invocation_context.session.state.get( From 12012c0221702cfc98ba6085f05620de9236ede1 Mon Sep 17 00:00:00 2001 From: hcadioli Date: Mon, 27 Oct 2025 10:54:21 -0300 Subject: [PATCH 2/3] chore: add tests to verify canonical tools cache --- .../flows/llm_flows/test_base_llm_flow.py | 64 +++++++++++++++++++ 1 file changed, 64 insertions(+) diff --git a/tests/unittests/flows/llm_flows/test_base_llm_flow.py b/tests/unittests/flows/llm_flows/test_base_llm_flow.py index 81ef925a39..348d0ce932 100644 --- a/tests/unittests/flows/llm_flows/test_base_llm_flow.py +++ b/tests/unittests/flows/llm_flows/test_base_llm_flow.py @@ -413,3 +413,67 @@ def __init__(self): assert result == plugin_response plugin.after_model_callback.assert_called_once() + +@pytest.mark.asyncio +async def test_handle_after_model_callback_caches_canonical_tools(): + """Test that canonical_tools is only called once per invocation_context.""" + canonical_tools_call_count = 0 + + async def mock_canonical_tools(self, readonly_context=None): + nonlocal canonical_tools_call_count + canonical_tools_call_count += 1 + from google.adk.tools.base_tool import BaseTool + + class MockGoogleSearchTool(BaseTool): + def __init__(self): + super().__init__(name="google_search_agent", description="Mock search") + + async def call(self, **kwargs): + return "mock result" + + return [MockGoogleSearchTool()] + + agent = Agent(name="test_agent", tools=[google_search, dummy_tool]) + + with mock.patch.object(type(agent), "canonical_tools", new=mock_canonical_tools): + invocation_context = await testing_utils.create_invocation_context(agent=agent) + + assert invocation_context.canonical_tools_cache is None + + invocation_context.session.state["temp:_adk_grounding_metadata"] = { + "foo": "bar" + } + + llm_response = LlmResponse( + content=types.Content(parts=[types.Part.from_text(text="response")]) + ) + event = Event( + id=Event.new_id(), + invocation_id=invocation_context.invocation_id, + author=agent.name, + ) + flow = BaseLlmFlowForTesting() + + # Call _handle_after_model_callback multiple times with the same context + result1 = await flow._handle_after_model_callback( + invocation_context, llm_response, event + ) + result2 = await flow._handle_after_model_callback( + invocation_context, llm_response, event + ) + result3 = await flow._handle_after_model_callback( + invocation_context, llm_response, event + ) + + assert canonical_tools_call_count == 1, ( + f"canonical_tools should be called once, but was called " + f"{canonical_tools_call_count} times" + ) + + assert invocation_context.canonical_tools_cache is not None + assert len(invocation_context.canonical_tools_cache) == 1 + assert invocation_context.canonical_tools_cache[0].name == "google_search_agent" + + assert result1.grounding_metadata == {"foo": "bar"} + assert result2.grounding_metadata == {"foo": "bar"} + assert result3.grounding_metadata == {"foo": "bar"} \ No newline at end of file From de02bd3e4533c3741edf05788a5e8b2d3d38bae4 Mon Sep 17 00:00:00 2001 From: hcadioli Date: Thu, 6 Nov 2025 16:38:30 -0300 Subject: [PATCH 3/3] chore: format tests for canonical tools cache --- .../flows/llm_flows/test_base_llm_flow.py | 123 ++++++++++-------- 1 file changed, 66 insertions(+), 57 deletions(-) diff --git a/tests/unittests/flows/llm_flows/test_base_llm_flow.py b/tests/unittests/flows/llm_flows/test_base_llm_flow.py index 348d0ce932..d3cc210e2b 100644 --- a/tests/unittests/flows/llm_flows/test_base_llm_flow.py +++ b/tests/unittests/flows/llm_flows/test_base_llm_flow.py @@ -414,66 +414,75 @@ def __init__(self): assert result == plugin_response plugin.after_model_callback.assert_called_once() + @pytest.mark.asyncio async def test_handle_after_model_callback_caches_canonical_tools(): """Test that canonical_tools is only called once per invocation_context.""" canonical_tools_call_count = 0 async def mock_canonical_tools(self, readonly_context=None): - nonlocal canonical_tools_call_count - canonical_tools_call_count += 1 - from google.adk.tools.base_tool import BaseTool - - class MockGoogleSearchTool(BaseTool): - def __init__(self): - super().__init__(name="google_search_agent", description="Mock search") - - async def call(self, **kwargs): - return "mock result" - - return [MockGoogleSearchTool()] - - agent = Agent(name="test_agent", tools=[google_search, dummy_tool]) - - with mock.patch.object(type(agent), "canonical_tools", new=mock_canonical_tools): - invocation_context = await testing_utils.create_invocation_context(agent=agent) - - assert invocation_context.canonical_tools_cache is None - - invocation_context.session.state["temp:_adk_grounding_metadata"] = { - "foo": "bar" - } - - llm_response = LlmResponse( - content=types.Content(parts=[types.Part.from_text(text="response")]) - ) - event = Event( - id=Event.new_id(), - invocation_id=invocation_context.invocation_id, - author=agent.name, - ) - flow = BaseLlmFlowForTesting() - - # Call _handle_after_model_callback multiple times with the same context - result1 = await flow._handle_after_model_callback( - invocation_context, llm_response, event - ) - result2 = await flow._handle_after_model_callback( - invocation_context, llm_response, event - ) - result3 = await flow._handle_after_model_callback( - invocation_context, llm_response, event - ) - - assert canonical_tools_call_count == 1, ( - f"canonical_tools should be called once, but was called " - f"{canonical_tools_call_count} times" - ) - - assert invocation_context.canonical_tools_cache is not None - assert len(invocation_context.canonical_tools_cache) == 1 - assert invocation_context.canonical_tools_cache[0].name == "google_search_agent" - - assert result1.grounding_metadata == {"foo": "bar"} - assert result2.grounding_metadata == {"foo": "bar"} - assert result3.grounding_metadata == {"foo": "bar"} \ No newline at end of file + nonlocal canonical_tools_call_count + canonical_tools_call_count += 1 + from google.adk.tools.base_tool import BaseTool + + class MockGoogleSearchTool(BaseTool): + + def __init__(self): + super().__init__(name='google_search_agent', description='Mock search') + + async def call(self, **kwargs): + return 'mock result' + + return [MockGoogleSearchTool()] + + agent = Agent(name='test_agent', tools=[google_search, dummy_tool]) + + with mock.patch.object( + type(agent), 'canonical_tools', new=mock_canonical_tools + ): + invocation_context = await testing_utils.create_invocation_context( + agent=agent + ) + + assert invocation_context.canonical_tools_cache is None + + invocation_context.session.state['temp:_adk_grounding_metadata'] = { + 'foo': 'bar' + } + + llm_response = LlmResponse( + content=types.Content(parts=[types.Part.from_text(text='response')]) + ) + event = Event( + id=Event.new_id(), + invocation_id=invocation_context.invocation_id, + author=agent.name, + ) + flow = BaseLlmFlowForTesting() + + # Call _handle_after_model_callback multiple times with the same context + result1 = await flow._handle_after_model_callback( + invocation_context, llm_response, event + ) + result2 = await flow._handle_after_model_callback( + invocation_context, llm_response, event + ) + result3 = await flow._handle_after_model_callback( + invocation_context, llm_response, event + ) + + assert canonical_tools_call_count == 1, ( + 'canonical_tools should be called once, but was called ' + f'{canonical_tools_call_count} times' + ) + + assert invocation_context.canonical_tools_cache is not None + assert len(invocation_context.canonical_tools_cache) == 1 + assert ( + invocation_context.canonical_tools_cache[0].name + == 'google_search_agent' + ) + + assert result1.grounding_metadata == {'foo': 'bar'} + assert result2.grounding_metadata == {'foo': 'bar'} + assert result3.grounding_metadata == {'foo': 'bar'}