Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 26 additions & 21 deletions pydantic_ai_slim/pydantic_ai/_tool_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,24 +108,21 @@ async def handle_call(
raise ValueError('ToolManager has not been prepared for a run step yet') # pragma: no cover

if (tool := self.tools.get(call.tool_name)) and tool.tool_def.kind == 'output':
# Output tool calls are not traced and not counted
return await self._call_tool(
call,
allow_partial=allow_partial,
wrap_validation_errors=wrap_validation_errors,
approved=approved,
)
output_tool_flag = True
else:
return await self._call_function_tool(
call,
allow_partial=allow_partial,
wrap_validation_errors=wrap_validation_errors,
approved=approved,
tracer=self.ctx.tracer,
include_content=self.ctx.trace_include_content,
instrumentation_version=self.ctx.instrumentation_version,
usage=self.ctx.usage,
)
output_tool_flag = False
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This could be simplified as is_output_tool = ..., without the if/else, but I think we don't need this check here anymore since we can also do it inside the call method, now that we always use that. So please move this check into the method.


return await self._call_function_tool(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This name is no longer accurate, because we also use it for output tools now. Can we either merge the 2 methods (as there's no need for a separate _call_tool anymore), or call this method _call_tool_traced? The latter may be better to keep the concerns separated.

call,
allow_partial=allow_partial,
wrap_validation_errors=wrap_validation_errors,
approved=approved,
tracer=self.ctx.tracer,
include_content=self.ctx.trace_include_content,
instrumentation_version=self.ctx.instrumentation_version,
usage=self.ctx.usage,
output_tool_flag=output_tool_flag,
)

async def _call_tool(
self,
Expand Down Expand Up @@ -213,16 +210,22 @@ async def _call_function_tool(
include_content: bool,
instrumentation_version: int,
usage: RunUsage,
output_tool_flag: bool = False,
) -> Any:
"""See <https://opentelemetry.io/docs/specs/semconv/gen-ai/gen-ai-spans/#execute-tool-span>."""
instrumentation_names = InstrumentationNames.for_version(instrumentation_version)

if output_tool_flag:
tool_name = 'output tool'
else:
tool_name = call.tool_name

span_attributes = {
'gen_ai.tool.name': call.tool_name,
'gen_ai.tool.name': tool_name,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This arg should always have the actual tool name, so please change this back

# NOTE: this means `gen_ai.tool.call.id` will be included even if it was generated by pydantic-ai
'gen_ai.tool.call.id': call.tool_call_id,
**({instrumentation_names.tool_arguments_attr: call.args_as_json_str()} if include_content else {}),
'logfire.msg': f'running tool: {call.tool_name}',
'logfire.msg': f'running tool: {tool_name}',
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is where I think we should say validating output: {tool_name}

# add the JSON schema so these attributes are formatted nicely in Logfire
'logfire.json_schema': json.dumps(
{
Expand All @@ -243,7 +246,7 @@ async def _call_function_tool(
),
}
with tracer.start_as_current_span(
instrumentation_names.get_tool_span_name(call.tool_name),
instrumentation_names.get_tool_span_name(tool_name),
attributes=span_attributes,
) as span:
try:
Expand All @@ -253,7 +256,9 @@ async def _call_function_tool(
wrap_validation_errors=wrap_validation_errors,
approved=approved,
)
usage.tool_calls += 1
if not output_tool_flag:
# Output tool calls are not counted
usage.tool_calls += 1

except ToolRetryError as e:
part = e.tool_retry
Expand Down
1 change: 1 addition & 0 deletions tests/test_dbos.py
Original file line number Diff line number Diff line change
Expand Up @@ -594,6 +594,7 @@ async def test_complex_agent_run_in_workflow(allow_model_requests: None, dbos: D
)
],
),
BasicSpan(content='running tool: output tool'),
],
)
],
Expand Down
64 changes: 52 additions & 12 deletions tests/test_logfire.py
Original file line number Diff line number Diff line change
Expand Up @@ -692,7 +692,10 @@ class MyOutput:
'id': 0,
'name': 'agent run',
'message': 'my_agent run',
'children': [{'id': 1, 'name': 'chat test', 'message': 'chat test'}],
'children': [
{'id': 1, 'name': 'chat test', 'message': 'chat test'},
{'id': 2, 'name': 'running tool', 'message': 'running tool: output tool'},
],
}
]
)
Expand All @@ -703,7 +706,10 @@ class MyOutput:
'id': 0,
'name': 'invoke_agent my_agent',
'message': 'my_agent run',
'children': [{'id': 1, 'name': 'chat test', 'message': 'chat test'}],
'children': [
{'id': 1, 'name': 'chat test', 'message': 'chat test'},
{'id': 2, 'name': 'execute_tool output tool', 'message': 'running tool: output tool'},
],
}
]
)
Expand Down Expand Up @@ -900,7 +906,10 @@ class MyOutput:
'id': 0,
'name': 'agent run',
'message': 'my_agent run',
'children': [{'id': 1, 'name': 'chat test', 'message': 'chat test'}],
'children': [
{'id': 1, 'name': 'chat test', 'message': 'chat test'},
{'id': 2, 'name': 'running tool', 'message': 'running tool: output tool'},
],
}
]
)
Expand All @@ -911,7 +920,10 @@ class MyOutput:
'id': 0,
'name': 'invoke_agent my_agent',
'message': 'my_agent run',
'children': [{'id': 1, 'name': 'chat test', 'message': 'chat test'}],
'children': [
{'id': 1, 'name': 'chat test', 'message': 'chat test'},
{'id': 2, 'name': 'execute_tool output tool', 'message': 'running tool: output tool'},
],
}
]
)
Expand Down Expand Up @@ -1381,8 +1393,15 @@ def call_tool(_: list[ModelMessage], info: AgentInfo) -> ModelResponse:
{'id': 1, 'name': 'chat function:call_tool:', 'message': 'chat function:call_tool:'},
{
'id': 2,
'name': 'running output function',
'message': 'running output function: final_result',
'name': 'running tool',
'message': 'running tool: output tool',
'children': [
{
'id': 3,
'name': 'running output function',
'message': 'running output function: final_result',
}
],
},
],
}
Expand Down Expand Up @@ -1428,8 +1447,15 @@ def call_tool(_: list[ModelMessage], info: AgentInfo) -> ModelResponse:
{'id': 1, 'name': 'chat function:call_tool:', 'message': 'chat function:call_tool:'},
{
'id': 2,
'name': 'execute_tool final_result',
'message': 'running output function: final_result',
'name': 'execute_tool output tool',
'message': 'running tool: output tool',
'children': [
{
'id': 3,
'name': 'execute_tool final_result',
'message': 'running output function: final_result',
}
],
},
],
}
Expand Down Expand Up @@ -2336,7 +2362,10 @@ def instructions():
'id': 0,
'name': 'agent run',
'message': 'my_agent run',
'children': [{'id': 1, 'name': 'chat test', 'message': 'chat test'}],
'children': [
{'id': 1, 'name': 'chat test', 'message': 'chat test'},
{'id': 2, 'name': 'running tool', 'message': 'running tool: output tool'},
],
}
]
)
Expand All @@ -2347,7 +2376,10 @@ def instructions():
'id': 0,
'name': 'invoke_agent my_agent',
'message': 'my_agent run',
'children': [{'id': 1, 'name': 'chat test', 'message': 'chat test'}],
'children': [
{'id': 1, 'name': 'chat test', 'message': 'chat test'},
{'id': 2, 'name': 'execute_tool output tool', 'message': 'running tool: output tool'},
],
}
]
)
Expand Down Expand Up @@ -2589,6 +2621,7 @@ def my_tool() -> str:
'children': [{'id': 3, 'name': 'running tool', 'message': 'running tool: my_tool'}],
},
{'id': 4, 'name': 'chat test', 'message': 'chat test'},
{'id': 5, 'name': 'running tool', 'message': 'running tool: output tool'},
],
}
]
Expand All @@ -2611,6 +2644,7 @@ def my_tool() -> str:
],
},
{'id': 4, 'name': 'chat test', 'message': 'chat test'},
{'id': 5, 'name': 'execute_tool output tool', 'message': 'running tool: output tool'},
],
}
]
Expand Down Expand Up @@ -2877,7 +2911,10 @@ def instructions(ctx: RunContext[None]):
'id': 0,
'name': 'agent run',
'message': 'my_agent run',
'children': [{'id': 1, 'name': 'chat test', 'message': 'chat test'}],
'children': [
{'id': 1, 'name': 'chat test', 'message': 'chat test'},
{'id': 2, 'name': 'running tool', 'message': 'running tool: output tool'},
],
}
]
)
Expand All @@ -2888,7 +2925,10 @@ def instructions(ctx: RunContext[None]):
'id': 0,
'name': 'invoke_agent my_agent',
'message': 'my_agent run',
'children': [{'id': 1, 'name': 'chat test', 'message': 'chat test'}],
'children': [
{'id': 1, 'name': 'chat test', 'message': 'chat test'},
{'id': 2, 'name': 'execute_tool output tool', 'message': 'running tool: output tool'},
],
}
]
)
Expand Down
1 change: 1 addition & 0 deletions tests/test_prefect.py
Original file line number Diff line number Diff line change
Expand Up @@ -553,6 +553,7 @@ async def run_complex_agent() -> Response:
)
],
),
BasicSpan(content='running tool: output tool'),
],
)
],
Expand Down
1 change: 1 addition & 0 deletions tests/test_temporal.py
Original file line number Diff line number Diff line change
Expand Up @@ -718,6 +718,7 @@ async def test_complex_agent_run_in_workflow(
)
],
),
BasicSpan(content='running tool: output tool'),
],
),
BasicSpan(content='CompleteWorkflow:ComplexAgentWorkflow'),
Expand Down