Skip to content

Commit a64dcc2

Browse files
committed
Add span_attributes to InstrumentationSettings.
This allows users to add extra attributes to the agent's span, either as a direct dict, or a callable that takes the RunContext. The attributes are added/computed after the agent finishes.
1 parent 5768447 commit a64dcc2

File tree

3 files changed

+65
-8
lines changed

3 files changed

+65
-8
lines changed

pydantic_ai_slim/pydantic_ai/agent/__init__.py

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from typing_extensions import Self, TypeVar, deprecated
1616

1717
from pydantic_ai._instrumentation import DEFAULT_INSTRUMENTATION_VERSION, InstrumentationNames
18+
from pydantic_graph import GraphRunContext
1819

1920
from .. import (
2021
_agent_graph,
@@ -667,21 +668,22 @@ async def get_instructions(run_context: RunContext[AgentDepsT]) -> str | None:
667668
finally:
668669
try:
669670
if instrumentation_settings and run_span.is_recording():
670-
run_span.set_attributes(
671-
self._run_span_end_attributes(
672-
instrumentation_settings, usage, state.message_history, graph_deps.new_message_index
673-
)
674-
)
671+
graph_ctx = GraphRunContext(state=state, deps=graph_deps)
672+
span_attributes = self._run_span_end_attributes(instrumentation_settings, graph_ctx)
673+
run_span.set_attributes(span_attributes)
675674
finally:
676675
run_span.end()
677676

678677
def _run_span_end_attributes(
679678
self,
680679
settings: InstrumentationSettings,
681-
usage: _usage.RunUsage,
682-
message_history: list[_messages.ModelMessage],
683-
new_message_index: int,
680+
graph_ctx: GraphRunContext[_agent_graph.GraphAgentState, _agent_graph.GraphAgentDeps[AgentDepsT, OutputDataT]],
684681
):
682+
run_ctx = _agent_graph.build_run_context(graph_ctx)
683+
usage = run_ctx.usage
684+
message_history = run_ctx.messages
685+
new_message_index = graph_ctx.deps.new_message_index
686+
685687
if settings.version == 1:
686688
attrs = {
687689
'all_messages_events': json.dumps(
@@ -714,9 +716,20 @@ def _run_span_end_attributes(
714716
):
715717
attrs['pydantic_ai.variable_instructions'] = True
716718

719+
extra_attributes: dict[str, str] = {}
720+
span_attrs_setting = settings.span_attributes
721+
if span_attrs_setting:
722+
if callable(span_attrs_setting):
723+
resolved_attrs = span_attrs_setting(run_ctx)
724+
if resolved_attrs:
725+
extra_attributes = dict(resolved_attrs)
726+
else:
727+
extra_attributes = dict(span_attrs_setting)
728+
717729
return {
718730
**usage.opentelemetry_attributes(),
719731
**attrs,
732+
**extra_attributes,
720733
'logfire.json_schema': json.dumps(
721734
{
722735
'type': 'object',

pydantic_ai_slim/pydantic_ai/models/instrumented.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,7 @@ class InstrumentationSettings:
9393
include_binary_content: bool = True
9494
include_content: bool = True
9595
version: Literal[1, 2, 3] = DEFAULT_INSTRUMENTATION_VERSION
96+
span_attributes: dict[str, str] | Callable[[RunContext[Any]], dict[str, str]] = field(default_factory=dict)
9697

9798
def __init__(
9899
self,
@@ -104,6 +105,7 @@ def __init__(
104105
version: Literal[1, 2, 3] = DEFAULT_INSTRUMENTATION_VERSION,
105106
event_mode: Literal['attributes', 'logs'] = 'attributes',
106107
event_logger_provider: EventLoggerProvider | None = None,
108+
span_attributes: dict[str, str] | Callable[[RunContext[Any]], dict[str, str]] | None = None,
107109
):
108110
"""Create instrumentation options.
109111
@@ -132,6 +134,9 @@ def __init__(
132134
If not provided, the global event logger provider is used.
133135
Calling `logfire.configure()` sets the global event logger provider, so most users don't need this.
134136
This is only used if `event_mode='logs'` and `version=1`.
137+
span_attributes: Additional attributes to add to the agent's span.
138+
Can either be a dict of attributes, or a callable that accepts the RunContext to compute them.
139+
Attributes are added or computed after the agent finishes running.
135140
"""
136141
from pydantic_ai import __version__
137142

@@ -145,6 +150,7 @@ def __init__(
145150
self.event_mode = event_mode
146151
self.include_binary_content = include_binary_content
147152
self.include_content = include_content
153+
self.span_attributes = span_attributes or dict()
148154

149155
if event_mode == 'logs' and version != 1:
150156
warnings.warn(

tests/test_logfire.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1331,6 +1331,44 @@ async def add_numbers(x: int, y: int) -> int:
13311331
)
13321332

13331333

1334+
def _callable_span_attributes(ctx: RunContext[Any]) -> dict[str, str]:
1335+
"""Helper callable that reflects the prompt in a span attribute."""
1336+
return {'callable_attr': str(ctx.prompt)}
1337+
1338+
1339+
@pytest.mark.skipif(not logfire_installed, reason='logfire not installed')
1340+
@pytest.mark.parametrize(
1341+
('span_attrs', 'expected'),
1342+
[
1343+
pytest.param({'custom_attr': 'value'}, {'custom_attr': 'value'}, id='dict'),
1344+
pytest.param(_callable_span_attributes, {'callable_attr': 'span attributes prompt'}, id='callable'),
1345+
],
1346+
)
1347+
def test_agent_span_attributes_extra(
1348+
get_logfire_summary: Callable[[], LogfireSummary],
1349+
span_attrs: dict[str, str] | Callable[[RunContext[Any]], dict[str, str]],
1350+
expected: dict[str, str],
1351+
) -> None:
1352+
instrumentation_settings = InstrumentationSettings(span_attributes=span_attrs)
1353+
agent = Agent(
1354+
model=TestModel(),
1355+
instrument=instrumentation_settings,
1356+
name='span_attributes_agent',
1357+
)
1358+
1359+
agent.run_sync('span attributes prompt')
1360+
1361+
summary = get_logfire_summary()
1362+
agent_span_attributes = next(
1363+
attributes
1364+
for attributes in summary.attributes.values()
1365+
if attributes.get('gen_ai.agent.name') == 'span_attributes_agent'
1366+
)
1367+
1368+
for key, value in expected.items():
1369+
assert agent_span_attributes[key] == value
1370+
1371+
13341372
class WeatherInfo(BaseModel):
13351373
temperature: float
13361374
description: str

0 commit comments

Comments
 (0)