diff --git a/docs/output.md b/docs/output.md
index 182a753944..0b61334e1c 100644
--- a/docs/output.md
+++ b/docs/output.md
@@ -385,6 +385,62 @@ print(repr(result.output))
 
 _(This example is complete, it can be run "as is")_
 
+### Validation context {#validation-context}
+
+Some validators rely on an extra Pydantic [context](https://docs.pydantic.dev/latest/concepts/validators/#validation-context) object. You can pass such an object to an `Agent` at definition time via its [`validation_context`][pydantic_ai.Agent.__init__] parameter.
+
+This validation context is used to validate _all_ structured outputs. It can be either:
+
+- the context object itself (`Any`), used as-is to validate outputs, or
+- a function that takes the [`RunContext`][pydantic_ai.tools.RunContext] and returns a context object (`Any`). This function is called automatically before each validation, allowing you to build a dynamic validation context.
+
+!!! warning "Don't confuse this _validation_ context with the _LLM_ context"
+    This Pydantic [context](https://docs.pydantic.dev/latest/concepts/validators/#validation-data) object is only used internally by Pydantic AI for output validation. In particular, it is **not** included in the prompts or messages sent to the language model.
+
+```python {title="validation_context.py"}
+from dataclasses import dataclass
+
+from pydantic import BaseModel, ValidationInfo, field_validator
+
+from pydantic_ai import Agent
+
+
+class Value(BaseModel):
+    x: int
+
+    @field_validator('x')
+    def increment_value(cls, value: int, info: ValidationInfo):
+        return value + (info.context or 0)
+
+
+agent = Agent(
+    'google-gla:gemini-2.5-flash',
+    output_type=Value,
+    validation_context=10,
+)
+result = agent.run_sync('Give me a value of 5.')
+print(repr(result.output))  # 5 from the model + 10 from the validation context
+#> Value(x=15)
+
+
+@dataclass
+class Deps:
+    increment: int
+
+
+agent = Agent(
+    'google-gla:gemini-2.5-flash',
+    output_type=Value,
+    deps_type=Deps,
+    validation_context=lambda ctx: ctx.deps.increment,
+)
+result = agent.run_sync('Give me a value of 5.', deps=Deps(increment=10))
+print(repr(result.output))  # 5 from the model + 10 from the validation context
+#> Value(x=15)
+```
+
+_(This example is complete, it can be run "as is")_
+
 ### Custom JSON schema {#structured-dict}
 
 If it's not feasible to define your desired structured output object using a Pydantic `BaseModel`, dataclass, or `TypedDict`, for example when you get a JSON schema from an external source or generate it dynamically, you can use the [`StructuredDict()`][pydantic_ai.output.StructuredDict] helper function to generate a `dict[str, Any]` subclass with a JSON schema attached that Pydantic AI will pass to the model.
diff --git a/pydantic_ai_slim/pydantic_ai/_agent_graph.py b/pydantic_ai_slim/pydantic_ai/_agent_graph.py
index c973befc70..a6d493679b 100644
--- a/pydantic_ai_slim/pydantic_ai/_agent_graph.py
+++ b/pydantic_ai_slim/pydantic_ai/_agent_graph.py
@@ -144,6 +144,7 @@ class GraphAgentDeps(Generic[DepsT, OutputDataT]):
 
     output_schema: _output.OutputSchema[OutputDataT]
     output_validators: list[_output.OutputValidator[DepsT, OutputDataT]]
+    validation_context: Any | Callable[[RunContext[DepsT]], Any]
 
     history_processors: Sequence[HistoryProcessor[DepsT]]
 
@@ -736,7 +737,7 @@ async def _handle_text_response(
 ) -> ModelRequestNode[DepsT, NodeRunEndT] | End[result.FinalResult[NodeRunEndT]]:
     run_context = build_run_context(ctx)
 
-    result_data = await text_processor.process(text, run_context)
+    result_data = await text_processor.process(text, run_context=run_context)
 
     for validator in ctx.deps.output_validators:
         result_data = await validator.validate(result_data, run_context)
@@ -781,12 +782,13 @@ async def run(
 
 def build_run_context(ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, Any]]) -> RunContext[DepsT]:
     """Build a `RunContext` object from the current agent graph run context."""
-    return RunContext[DepsT](
+    run_context = RunContext[DepsT](
         deps=ctx.deps.user_deps,
         model=ctx.deps.model,
         usage=ctx.state.usage,
         prompt=ctx.deps.prompt,
         messages=ctx.state.message_history,
+        validation_context=None,
         tracer=ctx.deps.tracer,
         trace_include_content=ctx.deps.instrumentation_settings is not None
         and ctx.deps.instrumentation_settings.include_content,
@@ -796,6 +798,21 @@ def build_run_context(ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT
         run_step=ctx.state.run_step,
         run_id=ctx.state.run_id,
     )
+    validation_context = build_validation_context(ctx.deps.validation_context, run_context)
+    run_context = replace(run_context, validation_context=validation_context)
+    return run_context
+
+
+def build_validation_context(
+    validation_ctx: Any | Callable[[RunContext[DepsT]], Any],
+    run_context: RunContext[DepsT],
+) -> Any:
+    """Build a Pydantic validation context, potentially from the current agent run context."""
+    if callable(validation_ctx):
+        fn = cast(Callable[[RunContext[DepsT]], Any], validation_ctx)
+        return fn(run_context)
+    else:
+        return validation_ctx
 
 
 async def process_tool_calls(  # noqa: C901
diff --git a/pydantic_ai_slim/pydantic_ai/_output.py b/pydantic_ai_slim/pydantic_ai/_output.py
index 053d3a71a8..24df1025bd 100644
--- a/pydantic_ai_slim/pydantic_ai/_output.py
+++ b/pydantic_ai_slim/pydantic_ai/_output.py
@@ -522,6 +522,7 @@ class BaseOutputProcessor(ABC, Generic[OutputDataT]):
     async def process(
         self,
         data: str,
+        *,
         run_context: RunContext[AgentDepsT],
         allow_partial: bool = False,
         wrap_validation_errors: bool = True,
@@ -609,6 +610,7 @@ def __init__(
     async def process(
         self,
         data: str | dict[str, Any] | None,
+        *,
         run_context: RunContext[AgentDepsT],
         allow_partial: bool = False,
         wrap_validation_errors: bool = True,
@@ -628,7 +630,7 @@ async def process(
             data = _utils.strip_markdown_fences(data)
 
         try:
-            output = self.validate(data, allow_partial)
+            output = self.validate(data, allow_partial, run_context.validation_context)
         except ValidationError as e:
             if wrap_validation_errors:
                 m = _messages.RetryPromptPart(
@@ -646,12 +648,17 @@ def validate(
         self,
         data: str | dict[str, Any] | None,
         allow_partial: bool = False,
+        validation_context: Any | None = None,
     ) -> dict[str, Any]:
         pyd_allow_partial: Literal['off', 'trailing-strings'] = 'trailing-strings' if allow_partial else 'off'
         if isinstance(data, str):
-            return self.validator.validate_json(data or '{}', allow_partial=pyd_allow_partial)
+            return self.validator.validate_json(
+                data or '{}', allow_partial=pyd_allow_partial, context=validation_context
+            )
         else:
-            return self.validator.validate_python(data or {}, allow_partial=pyd_allow_partial)
+            return self.validator.validate_python(
+                data or {}, allow_partial=pyd_allow_partial, context=validation_context
+            )
 
     async def call(
         self,
@@ -770,12 +777,16 @@ def __init__(
     async def process(
         self,
         data: str,
+        *,
         run_context: RunContext[AgentDepsT],
         allow_partial: bool = False,
         wrap_validation_errors: bool = True,
     ) -> OutputDataT:
         union_object = await self._union_processor.process(
-            data, run_context, allow_partial=allow_partial, wrap_validation_errors=wrap_validation_errors
+            data,
+            run_context=run_context,
+            allow_partial=allow_partial,
+            wrap_validation_errors=wrap_validation_errors,
         )
 
         result = union_object.result
@@ -791,7 +802,10 @@ async def process(
                 raise
 
         return await processor.process(
-            inner_data, run_context, allow_partial=allow_partial, wrap_validation_errors=wrap_validation_errors
+            inner_data,
+            run_context=run_context,
+            allow_partial=allow_partial,
+            wrap_validation_errors=wrap_validation_errors,
         )
 
 
@@ -799,7 +813,9 @@ class TextOutputProcessor(BaseOutputProcessor[OutputDataT]):
     async def process(
         self,
         data: str,
+        *,
         run_context: RunContext[AgentDepsT],
+        validation_context: Any | None = None,
         allow_partial: bool = False,
         wrap_validation_errors: bool = True,
     ) -> OutputDataT:
@@ -830,14 +846,22 @@ def __init__(
     async def process(
         self,
         data: str,
+        *,
         run_context: RunContext[AgentDepsT],
+        validation_context: Any | None = None,
         allow_partial: bool = False,
         wrap_validation_errors: bool = True,
     ) -> OutputDataT:
         args = {self._str_argument_name: data}
         data = await execute_traced_output_function(self._function_schema, run_context, args, wrap_validation_errors)
 
-        return await super().process(data, run_context, allow_partial, wrap_validation_errors)
+        return await super().process(
+            data,
+            run_context=run_context,
+            validation_context=validation_context,
+            allow_partial=allow_partial,
+            wrap_validation_errors=wrap_validation_errors,
+        )
 
 
 @dataclass(init=False)
diff --git a/pydantic_ai_slim/pydantic_ai/_run_context.py b/pydantic_ai_slim/pydantic_ai/_run_context.py
index 4f9b253767..6acf05ba18 100644
--- a/pydantic_ai_slim/pydantic_ai/_run_context.py
+++ b/pydantic_ai_slim/pydantic_ai/_run_context.py
@@ -3,7 +3,7 @@
 import dataclasses
 from collections.abc import Sequence
 from dataclasses import field
-from typing import TYPE_CHECKING, Generic
+from typing import TYPE_CHECKING, Any, Generic
 
 from opentelemetry.trace import NoOpTracer, Tracer
 from typing_extensions import TypeVar
@@ -38,6 +38,8 @@ class RunContext(Generic[RunContextAgentDepsT]):
     """The original user prompt passed to the run."""
     messages: list[_messages.ModelMessage] = field(default_factory=list)
     """Messages exchanged in the conversation so far."""
+    validation_context: Any = None
+    """Pydantic [validation context](https://docs.pydantic.dev/latest/concepts/validators/#validation-context) for the run outputs."""
     tracer: Tracer = field(default_factory=NoOpTracer)
     """The tracer to use for tracing the run."""
     trace_include_content: bool = False
diff --git a/pydantic_ai_slim/pydantic_ai/_tool_manager.py b/pydantic_ai_slim/pydantic_ai/_tool_manager.py
index fb7039e2cc..9a9f93e1ff 100644
--- a/pydantic_ai_slim/pydantic_ai/_tool_manager.py
+++ b/pydantic_ai_slim/pydantic_ai/_tool_manager.py
@@ -164,9 +164,13 @@ async def _call_tool(
         pyd_allow_partial = 'trailing-strings' if allow_partial else 'off'
         validator = tool.args_validator
         if isinstance(call.args, str):
-            args_dict = validator.validate_json(call.args or '{}', allow_partial=pyd_allow_partial)
+            args_dict = validator.validate_json(
+                call.args or '{}', allow_partial=pyd_allow_partial, context=ctx.validation_context
+            )
         else:
-            args_dict = validator.validate_python(call.args or {}, allow_partial=pyd_allow_partial)
+            args_dict = validator.validate_python(
+                call.args or {}, allow_partial=pyd_allow_partial, context=ctx.validation_context
+            )
 
         result = await self.toolset.call_tool(name, args_dict, ctx, tool)
 
diff --git a/pydantic_ai_slim/pydantic_ai/agent/__init__.py b/pydantic_ai_slim/pydantic_ai/agent/__init__.py
index 4cd353b44a..b1ef5a71d0 100644
--- a/pydantic_ai_slim/pydantic_ai/agent/__init__.py
+++ b/pydantic_ai_slim/pydantic_ai/agent/__init__.py
@@ -147,6 +147,7 @@ class Agent(AbstractAgent[AgentDepsT, OutputDataT]):
     _prepare_output_tools: ToolsPrepareFunc[AgentDepsT] | None = dataclasses.field(repr=False)
     _max_result_retries: int = dataclasses.field(repr=False)
     _max_tool_retries: int = dataclasses.field(repr=False)
+    _validation_context: Any | Callable[[RunContext[AgentDepsT]], Any] = dataclasses.field(repr=False)
 
     _event_stream_handler: EventStreamHandler[AgentDepsT] | None = dataclasses.field(repr=False)
 
@@ -166,6 +167,7 @@ def __init__(
         name: str | None = None,
         model_settings: ModelSettings | None = None,
         retries: int = 1,
+        validation_context: Any | Callable[[RunContext[AgentDepsT]], Any] = None,
         output_retries: int | None = None,
         tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] = (),
         builtin_tools: Sequence[AbstractBuiltinTool] = (),
@@ -192,6 +194,7 @@ def __init__(
         name: str | None = None,
         model_settings: ModelSettings | None = None,
         retries: int = 1,
+        validation_context: Any | Callable[[RunContext[AgentDepsT]], Any] = None,
         output_retries: int | None = None,
         tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] = (),
         builtin_tools: Sequence[AbstractBuiltinTool] = (),
@@ -216,6 +219,7 @@ def __init__(
         name: str | None = None,
         model_settings: ModelSettings | None = None,
         retries: int = 1,
+        validation_context: Any | Callable[[RunContext[AgentDepsT]], Any] = None,
         output_retries: int | None = None,
         tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] = (),
         builtin_tools: Sequence[AbstractBuiltinTool] = (),
@@ -249,6 +253,7 @@ def __init__(
             model_settings: Optional model request settings to use for this agent's runs, by default.
             retries: The default number of retries to allow for tool calls and output validation, before raising an error.
                 For model request retries, see the [HTTP Request Retries](../retries.md) documentation.
+            validation_context: Pydantic [validation context](https://docs.pydantic.dev/latest/concepts/validators/#validation-context) used to validate all outputs.
             output_retries: The maximum number of retries to allow for output validation, defaults to `retries`.
             tools: Tools to register with the agent, you can also register tools via the decorators
                 [`@agent.tool`][pydantic_ai.Agent.tool] and [`@agent.tool_plain`][pydantic_ai.Agent.tool_plain].
@@ -314,6 +319,8 @@ def __init__(
         self._max_result_retries = output_retries if output_retries is not None else retries
         self._max_tool_retries = retries
 
+        self._validation_context = validation_context
+
         self._builtin_tools = builtin_tools
 
         self._prepare_tools = prepare_tools
@@ -612,6 +619,7 @@ async def get_instructions(run_context: RunContext[AgentDepsT]) -> str | None:
             end_strategy=self.end_strategy,
             output_schema=output_schema,
             output_validators=output_validators,
+            validation_context=self._validation_context,
             history_processors=self.history_processors,
             builtin_tools=[*self._builtin_tools, *(builtin_tools or [])],
             tool_manager=tool_manager,
diff --git a/pydantic_ai_slim/pydantic_ai/result.py b/pydantic_ai_slim/pydantic_ai/result.py
index c6b59ec796..88bfe407fa 100644
--- a/pydantic_ai_slim/pydantic_ai/result.py
+++ b/pydantic_ai_slim/pydantic_ai/result.py
@@ -198,7 +198,10 @@ async def validate_response_output(
             text = ''
 
         result_data = await text_processor.process(
-            text, self._run_ctx, allow_partial=allow_partial, wrap_validation_errors=False
+            text,
+            run_context=self._run_ctx,
+            allow_partial=allow_partial,
+            wrap_validation_errors=False,
         )
         for validator in self._output_validators:
             result_data = await validator.validate(
diff --git a/tests/test_examples.py b/tests/test_examples.py
index 407816b60a..b4208a045c 100644
--- a/tests/test_examples.py
+++ b/tests/test_examples.py
@@ -506,6 +506,7 @@ async def call_tool(
     'What is a banana?': ToolCallPart(tool_name='return_fruit', args={'name': 'banana', 'color': 'yellow'}),
     'What is a Ford Explorer?': '{"result": {"kind": "Vehicle", "data": {"name": "Ford Explorer", "wheels": 4}}}',
     'What is a MacBook?': '{"result": {"kind": "Device", "data": {"name": "MacBook", "kind": "laptop"}}}',
+    'Give me a value of 5.': ToolCallPart(tool_name='final_result', args={'x': 5}),
     'Write a creative story about space exploration': 'In the year 2157, Captain Maya Chen piloted her spacecraft through the vast expanse of the Andromeda Galaxy. As she discovered a planet with crystalline mountains that sang in harmony with the cosmic winds, she realized that space exploration was not just about finding new worlds, but about finding new ways to understand the universe and our place within it.',
     'Create a person': ToolCallPart(
         tool_name='final_result',
diff --git a/tests/test_validation_context.py b/tests/test_validation_context.py
new file mode 100644
index 0000000000..ae475859ee
--- /dev/null
+++ b/tests/test_validation_context.py
@@ -0,0 +1,147 @@
+from dataclasses import dataclass
+
+import pytest
+from inline_snapshot import snapshot
+from pydantic import BaseModel, ValidationInfo, field_validator
+
+from pydantic_ai import (
+    Agent,
+    ModelMessage,
+    ModelResponse,
+    NativeOutput,
+    PromptedOutput,
+    RunContext,
+    TextPart,
+    ToolCallPart,
+    ToolOutput,
+)
+from pydantic_ai._output import OutputSpec
+from pydantic_ai.models.function import AgentInfo, FunctionModel
+
+
+class Value(BaseModel):
+    x: int
+
+    @field_validator('x')
+    def increment_value(cls, value: int, info: ValidationInfo):
+        return value + (info.context or 0)
+
+
+@dataclass
+class Deps:
+    increment: int
+
+
+@pytest.mark.parametrize(
+    'output_type',
+    [
+        Value,
+        ToolOutput(Value),
+        NativeOutput(Value),
+        PromptedOutput(Value),
+    ],
+    ids=[
+        'Value',
+        'ToolOutput(Value)',
+        'NativeOutput(Value)',
+        'PromptedOutput(Value)',
+    ],
+)
+def test_agent_output_with_validation_context(output_type: OutputSpec[Value]):
+    """Test that the output is validated using the validation context"""
+
+    def mock_llm(_: list[ModelMessage], _info: AgentInfo) -> ModelResponse:
+        if isinstance(output_type, ToolOutput):
+            return ModelResponse(parts=[ToolCallPart(tool_name='final_result', args={'x': 0})])
+        else:
+            text = Value(x=0).model_dump_json()
+            return ModelResponse(parts=[TextPart(content=text)])
+
+    agent = Agent(
+        FunctionModel(mock_llm),
+        output_type=output_type,
+        deps_type=Deps,
+        validation_context=lambda ctx: ctx.deps.increment,
+    )
+
+    result = agent.run_sync('', deps=Deps(increment=10))
+    assert result.output.x == snapshot(10)
+
+
+def test_agent_tool_call_with_validation_context():
+    """Test that the argument passed to the tool call is validated using the validation context."""
+
+    agent = Agent(
+        'test',
+        deps_type=Deps,
+        validation_context=lambda ctx: ctx.deps.increment,
+    )
+
+    @agent.tool
+    def get_value(ctx: RunContext[Deps], v: Value) -> int:
+        # NOTE: The test agent calls this tool with Value(x=0) which should then have been influenced by the validation context through the `increment_value` field validator
+        assert v.x == ctx.deps.increment
+        return v.x
+
+    result = agent.run_sync('', deps=Deps(increment=10))
+    assert result.output == snapshot('{"get_value":10}')
+
+
+def test_agent_output_function_with_validation_context():
+    """Test that the argument passed to the output function is validated using the validation context."""
+
+    def get_value(v: Value) -> int:
+        return v.x
+
+    agent = Agent(
+        'test',
+        output_type=get_value,
+        deps_type=Deps,
+        validation_context=lambda ctx: ctx.deps.increment,
+    )
+
+    result = agent.run_sync('', deps=Deps(increment=10))
+    assert result.output == snapshot(10)
+
+
+def test_agent_output_validator_with_validation_context():
+    """Test that the argument passed to the output validator is validated using the validation context."""
+
+    agent = Agent(
+        'test',
+        output_type=Value,
+        deps_type=Deps,
+        validation_context=lambda ctx: ctx.deps.increment,
+    )
+
+    @agent.output_validator
+    def identity(ctx: RunContext[Deps], v: Value) -> Value:
+        return v
+
+    result = agent.run_sync('', deps=Deps(increment=10))
+    assert result.output.x == snapshot(10)
+
+
+def test_agent_output_validator_with_intermediary_deps_change_and_validation_context():
+    """Test that the validation context is updated as run dependencies are mutated."""
+
+    agent = Agent(
+        'test',
+        output_type=Value,
+        deps_type=Deps,
+        validation_context=lambda ctx: ctx.deps.increment,
+    )
+
+    @agent.tool
+    def bump_increment(ctx: RunContext[Deps]):
+        assert ctx.validation_context == snapshot(10)  # validation ctx was first computed using the original deps
+        ctx.deps.increment += 5  # update the deps
+
+    @agent.output_validator
+    def identity(ctx: RunContext[Deps], v: Value) -> Value:
+        assert ctx.validation_context == snapshot(15)  # validation ctx was re-computed after deps update from tool call
+
+        return v
+
+    result = agent.run_sync('', deps=Deps(increment=10))
+    assert result.output.x == snapshot(15)