Skip to content

Commit 3432b22

Browse files
seanzhougooglecopybara-github
authored andcommitted
fix: Copy the original function call args before passing it to callback or tools to avoid being modified
PiperOrigin-RevId: 788462897
1 parent af35e26 commit 3432b22

File tree

2 files changed

+297
-6
lines changed

2 files changed

+297
-6
lines changed

src/google/adk/flows/llm_flows/functions.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from __future__ import annotations
1818

1919
import asyncio
20+
import copy
2021
import inspect
2122
import logging
2223
from typing import Any
@@ -150,9 +151,12 @@ async def handle_function_calls_async(
150151
)
151152

152153
with tracer.start_as_current_span(f'execute_tool {tool.name}'):
153-
# do not use "args" as the variable name, because it is a reserved keyword
154+
# Do not use "args" as the variable name, because it is a reserved keyword
154155
# in python debugger.
155-
function_args = function_call.args or {}
156+
# Make a deep copy to avoid being modified.
157+
function_args = (
158+
copy.deepcopy(function_call.args) if function_call.args else {}
159+
)
156160

157161
# Step 1: Check if plugin before_tool_callback overrides the function
158162
# response.
@@ -275,9 +279,12 @@ async def handle_function_calls_live(
275279
invocation_context, function_call_event, function_call, tools_dict
276280
)
277281
with tracer.start_as_current_span(f'execute_tool {tool.name}'):
278-
# do not use "args" as the variable name, because it is a reserved keyword
282+
# Do not use "args" as the variable name, because it is a reserved keyword
279283
# in python debugger.
280-
function_args = function_call.args or {}
284+
# Make a deep copy to avoid being modified.
285+
function_args = (
286+
copy.deepcopy(function_call.args) if function_call.args else {}
287+
)
281288
function_response = None
282289

283290
# Handle before_tool_callbacks - iterate through the canonical callback

tests/unittests/flows/llm_flows/test_functions_simple.py

Lines changed: 286 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,11 @@
1313
# limitations under the License.
1414

1515
from typing import Any
16-
from typing import AsyncGenerator
1716
from typing import Callable
1817

1918
from google.adk.agents.llm_agent import Agent
2019
from google.adk.events.event import Event
2120
from google.adk.flows.llm_flows.functions import find_matching_function_call
22-
from google.adk.sessions.session import Session
2321
from google.adk.tools.function_tool import FunctionTool
2422
from google.adk.tools.tool_context import ToolContext
2523
from google.genai import types
@@ -392,3 +390,289 @@ def test_find_function_call_event_multiple_function_responses():
392390
# Should return the first matching function call event found
393391
result = find_matching_function_call(events)
394392
assert result == call_event1 # First match (func_123)
393+
394+
395+
@pytest.mark.asyncio
396+
async def test_function_call_args_not_modified():
397+
"""Test that function_call.args is not modified when making a copy."""
398+
from google.adk.flows.llm_flows.functions import handle_function_calls_async
399+
from google.adk.flows.llm_flows.functions import handle_function_calls_live
400+
401+
def simple_fn(**kwargs) -> dict:
402+
return {'result': 'test'}
403+
404+
tool = FunctionTool(simple_fn)
405+
model = testing_utils.MockModel.create(responses=[])
406+
agent = Agent(
407+
name='test_agent',
408+
model=model,
409+
tools=[tool],
410+
)
411+
invocation_context = await testing_utils.create_invocation_context(
412+
agent=agent, user_content=''
413+
)
414+
415+
# Create original args that we want to ensure are not modified
416+
original_args = {'param1': 'value1', 'param2': 42}
417+
function_call = types.FunctionCall(name=tool.name, args=original_args)
418+
content = types.Content(parts=[types.Part(function_call=function_call)])
419+
event = Event(
420+
invocation_id=invocation_context.invocation_id,
421+
author=agent.name,
422+
content=content,
423+
)
424+
tools_dict = {tool.name: tool}
425+
426+
# Test handle_function_calls_async
427+
result_async = await handle_function_calls_async(
428+
invocation_context,
429+
event,
430+
tools_dict,
431+
)
432+
433+
# Verify original args are not modified
434+
assert function_call.args == original_args
435+
assert function_call.args is not original_args # Should be a copy
436+
437+
# Test handle_function_calls_live
438+
result_live = await handle_function_calls_live(
439+
invocation_context,
440+
event,
441+
tools_dict,
442+
)
443+
444+
# Verify original args are still not modified
445+
assert function_call.args == original_args
446+
assert function_call.args is not original_args # Should be a copy
447+
448+
# Both should return valid results
449+
assert result_async is not None
450+
assert result_live is not None
451+
452+
453+
@pytest.mark.asyncio
454+
async def test_function_call_args_none_handling():
455+
"""Test that function_call.args=None is handled correctly."""
456+
from google.adk.flows.llm_flows.functions import handle_function_calls_async
457+
from google.adk.flows.llm_flows.functions import handle_function_calls_live
458+
459+
def simple_fn(**kwargs) -> dict:
460+
return {'result': 'test'}
461+
462+
tool = FunctionTool(simple_fn)
463+
model = testing_utils.MockModel.create(responses=[])
464+
agent = Agent(
465+
name='test_agent',
466+
model=model,
467+
tools=[tool],
468+
)
469+
invocation_context = await testing_utils.create_invocation_context(
470+
agent=agent, user_content=''
471+
)
472+
473+
# Create function call with None args
474+
function_call = types.FunctionCall(name=tool.name, args=None)
475+
content = types.Content(parts=[types.Part(function_call=function_call)])
476+
event = Event(
477+
invocation_id=invocation_context.invocation_id,
478+
author=agent.name,
479+
content=content,
480+
)
481+
tools_dict = {tool.name: tool}
482+
483+
# Test handle_function_calls_async
484+
result_async = await handle_function_calls_async(
485+
invocation_context,
486+
event,
487+
tools_dict,
488+
)
489+
490+
# Test handle_function_calls_live
491+
result_live = await handle_function_calls_live(
492+
invocation_context,
493+
event,
494+
tools_dict,
495+
)
496+
497+
# Both should return valid results even with None args
498+
assert result_async is not None
499+
assert result_live is not None
500+
501+
502+
@pytest.mark.asyncio
503+
async def test_function_call_args_copy_behavior():
504+
"""Test that modifying the copied args doesn't affect the original."""
505+
from google.adk.flows.llm_flows.functions import handle_function_calls_async
506+
from google.adk.flows.llm_flows.functions import handle_function_calls_live
507+
508+
def simple_fn(test_param: str, other_param: int) -> dict:
509+
# Modify the args to test that the copy prevents affecting the original
510+
return {
511+
'result': 'test',
512+
'received_args': {'test_param': test_param, 'other_param': other_param},
513+
}
514+
515+
tool = FunctionTool(simple_fn)
516+
model = testing_utils.MockModel.create(responses=[])
517+
agent = Agent(
518+
name='test_agent',
519+
model=model,
520+
tools=[tool],
521+
)
522+
invocation_context = await testing_utils.create_invocation_context(
523+
agent=agent, user_content=''
524+
)
525+
526+
# Create original args
527+
original_args = {'test_param': 'original_value', 'other_param': 123}
528+
function_call = types.FunctionCall(name=tool.name, args=original_args)
529+
content = types.Content(parts=[types.Part(function_call=function_call)])
530+
event = Event(
531+
invocation_id=invocation_context.invocation_id,
532+
author=agent.name,
533+
content=content,
534+
)
535+
tools_dict = {tool.name: tool}
536+
537+
# Test handle_function_calls_async
538+
result_async = await handle_function_calls_async(
539+
invocation_context,
540+
event,
541+
tools_dict,
542+
)
543+
544+
# Verify original args are unchanged
545+
assert function_call.args == original_args
546+
assert function_call.args['test_param'] == 'original_value'
547+
548+
# Verify the tool received the args correctly
549+
assert result_async is not None
550+
response = result_async.content.parts[0].function_response.response
551+
552+
# Check if the response has the expected structure
553+
assert 'received_args' in response
554+
received_args = response['received_args']
555+
assert 'test_param' in received_args
556+
assert received_args['test_param'] == 'original_value'
557+
assert received_args['other_param'] == 123
558+
assert (
559+
function_call.args['test_param'] == 'original_value'
560+
) # Original unchanged
561+
562+
563+
@pytest.mark.asyncio
564+
async def test_function_call_args_deep_copy_behavior():
565+
"""Test that deep copy behavior works correctly with nested structures."""
566+
from google.adk.flows.llm_flows.functions import handle_function_calls_async
567+
from google.adk.flows.llm_flows.functions import handle_function_calls_live
568+
569+
def simple_fn(nested_dict: dict, list_param: list) -> dict:
570+
# Modify the nested structures to test deep copy
571+
nested_dict['inner']['value'] = 'modified'
572+
list_param.append('new_item')
573+
return {
574+
'result': 'test',
575+
'received_nested': nested_dict,
576+
'received_list': list_param,
577+
}
578+
579+
tool = FunctionTool(simple_fn)
580+
model = testing_utils.MockModel.create(responses=[])
581+
agent = Agent(
582+
name='test_agent',
583+
model=model,
584+
tools=[tool],
585+
)
586+
invocation_context = await testing_utils.create_invocation_context(
587+
agent=agent, user_content=''
588+
)
589+
590+
# Create original args with nested structures
591+
original_nested_dict = {'inner': {'value': 'original'}}
592+
original_list = ['item1', 'item2']
593+
original_args = {
594+
'nested_dict': original_nested_dict,
595+
'list_param': original_list,
596+
}
597+
598+
function_call = types.FunctionCall(name=tool.name, args=original_args)
599+
content = types.Content(parts=[types.Part(function_call=function_call)])
600+
event = Event(
601+
invocation_id=invocation_context.invocation_id,
602+
author=agent.name,
603+
content=content,
604+
)
605+
tools_dict = {tool.name: tool}
606+
607+
# Test handle_function_calls_async
608+
result_async = await handle_function_calls_async(
609+
invocation_context,
610+
event,
611+
tools_dict,
612+
)
613+
614+
# Verify original args are completely unchanged
615+
assert function_call.args == original_args
616+
assert function_call.args['nested_dict']['inner']['value'] == 'original'
617+
assert function_call.args['list_param'] == ['item1', 'item2']
618+
619+
# Verify the tool received the modified nested structures
620+
assert result_async is not None
621+
response = result_async.content.parts[0].function_response.response
622+
623+
# Check that the tool received modified versions
624+
assert 'received_nested' in response
625+
assert 'received_list' in response
626+
assert response['received_nested']['inner']['value'] == 'modified'
627+
assert 'new_item' in response['received_list']
628+
629+
# Verify original is still unchanged
630+
assert function_call.args['nested_dict']['inner']['value'] == 'original'
631+
assert function_call.args['list_param'] == ['item1', 'item2']
632+
633+
634+
def test_shallow_vs_deep_copy_demonstration():
635+
"""Demonstrate why deep copy is necessary vs shallow copy."""
636+
import copy
637+
638+
# Original nested structure
639+
original = {
640+
'nested_dict': {'inner': {'value': 'original'}},
641+
'list_param': ['item1', 'item2'],
642+
}
643+
644+
# Shallow copy (what dict() does)
645+
shallow_copy = dict(original)
646+
647+
# Deep copy (what copy.deepcopy() does)
648+
deep_copy = copy.deepcopy(original)
649+
650+
# Modify the shallow copy
651+
shallow_copy['nested_dict']['inner']['value'] = 'modified'
652+
shallow_copy['list_param'].append('new_item')
653+
654+
# Check that shallow copy affects the original
655+
assert (
656+
original['nested_dict']['inner']['value'] == 'modified'
657+
) # Original is affected!
658+
assert 'new_item' in original['list_param'] # Original is affected!
659+
660+
# Reset original for deep copy test
661+
original = {
662+
'nested_dict': {'inner': {'value': 'original'}},
663+
'list_param': ['item1', 'item2'],
664+
}
665+
666+
# Modify the deep copy
667+
deep_copy['nested_dict']['inner']['value'] = 'modified'
668+
deep_copy['list_param'].append('new_item')
669+
670+
# Check that deep copy does NOT affect the original
671+
assert (
672+
original['nested_dict']['inner']['value'] == 'original'
673+
) # Original unchanged
674+
assert 'new_item' not in original['list_param'] # Original unchanged
675+
assert (
676+
deep_copy['nested_dict']['inner']['value'] == 'modified'
677+
) # Copy is modified
678+
assert 'new_item' in deep_copy['list_param'] # Copy is modified

0 commit comments

Comments
 (0)