From 6ef2b8c78eee4d2fdf0efa203cddeed302c32f6c Mon Sep 17 00:00:00 2001 From: pavan Date: Sat, 15 Nov 2025 22:39:13 +0530 Subject: [PATCH 1/3] feat: enforce message history starts with user message Add validation in UserPromptNode to raise UserError if message history starts with ModelResponse, ensuring conversations begin with a user message (ModelRequest) Include comprehensive tests for invalid history, valid history, empty history, multiple messages, and validation after message cleaning to prevent issues with malformed conversation logs --- pydantic_ai_slim/pydantic_ai/_agent_graph.py | 6 ++ tests/models/test_outlines.py | 4 +- tests/test_agent.py | 105 +++++++++++++++++++ 3 files changed, 114 insertions(+), 1 deletion(-) diff --git a/pydantic_ai_slim/pydantic_ai/_agent_graph.py b/pydantic_ai_slim/pydantic_ai/_agent_graph.py index 91cda373a5..d96f182085 100644 --- a/pydantic_ai_slim/pydantic_ai/_agent_graph.py +++ b/pydantic_ai_slim/pydantic_ai/_agent_graph.py @@ -216,6 +216,12 @@ async def run( # noqa: C901 ctx.state.message_history = messages ctx.deps.new_message_index = len(messages) + # Validate that message history starts with a user message + if messages and isinstance(messages[0], _messages.ModelResponse): + raise exceptions.UserError( + 'Message history cannot start with a `ModelResponse`. Conversations must begin with a user message.' + ) + if self.deferred_tool_results is not None: return await self._handle_deferred_tool_results(self.deferred_tool_results, messages, ctx) diff --git a/tests/models/test_outlines.py b/tests/models/test_outlines.py index 73adc28853..3758e3c404 100644 --- a/tests/models/test_outlines.py +++ b/tests/models/test_outlines.py @@ -573,6 +573,7 @@ def test_input_format(transformers_multimodal_model: OutlinesModel, binary_image # unsupported: tool calls tool_call_message_history: list[ModelMessage] = [ + ModelRequest(parts=[UserPromptPart(content='some user prompt')]), ModelResponse(parts=[ToolCallPart(tool_call_id='1', tool_name='get_location')]), ModelRequest(parts=[ToolReturnPart(tool_name='get_location', content='London', tool_call_id='1')]), ] @@ -588,7 +589,8 @@ def test_input_format(transformers_multimodal_model: OutlinesModel, binary_image # unsupported: non-image file parts file_part_message_history: list[ModelMessage] = [ - ModelResponse(parts=[FilePart(content=BinaryContent(data=b'test', media_type='text/plain'))]) + ModelRequest(parts=[UserPromptPart(content='some user prompt')]), + ModelResponse(parts=[FilePart(content=BinaryContent(data=b'test', media_type='text/plain'))]), ] with pytest.raises( UserError, match='File parts other than `BinaryImage` are not supported for Outlines models yet.' diff --git a/tests/test_agent.py b/tests/test_agent.py index d8323f3c98..fbe91d0efa 100644 --- a/tests/test_agent.py +++ b/tests/test_agent.py @@ -6125,3 +6125,108 @@ def llm(messages: list[ModelMessage], _info: AgentInfo) -> ModelResponse: ] ) assert run.all_messages_json().startswith(b'[{"parts":[{"content":"Hello",') + + +def test_message_history_cannot_start_with_model_response(): + """Test that message history starting with ModelResponse raises UserError.""" + + def simple_response(_messages: list[ModelMessage], _info: AgentInfo) -> ModelResponse: + return ModelResponse(parts=[TextPart(content='Final response')]) # pragma: no cover + + agent = Agent(FunctionModel(simple_response)) + + invalid_history = [ + ModelResponse(parts=[TextPart(content='ai response')]), + ] + + with pytest.raises( + UserError, + match='Message history cannot start with a `ModelResponse`.', + ): + agent.run_sync('hello', message_history=invalid_history) + + +async def test_message_history_starts_with_model_request(): + """Test that valid history starting with ModelRequest works correctly.""" + + def llm(messages: list[ModelMessage], _info: AgentInfo) -> ModelResponse: + return ModelResponse(parts=[TextPart('ok here is text')]) + + agent = Agent(FunctionModel(llm)) + + valid_history = [ + ModelRequest(parts=[UserPromptPart(content='Hello')]), + ModelResponse(parts=[TextPart(content='Hi there!')]), + ] + + # Should not raise error - valid history starting with ModelRequest + async with agent.iter('How are you?', message_history=valid_history) as run: + async for _ in run: + pass + # Verify messages are processed correctly + all_messages = run.all_messages() + assert len(all_messages) >= 3 # History + new request + response + assert isinstance(all_messages[0], ModelRequest) # First message is ModelRequest + + +async def test_empty_message_history_is_valid(): + """Test that empty message history works fine.""" + + def llm(messages: list[ModelMessage], _info: AgentInfo) -> ModelResponse: + return ModelResponse(parts=[TextPart('response text')]) + + agent = Agent(FunctionModel(llm)) + + # Empty history should work - should not raise error + async with agent.iter('hello', message_history=[]) as run: + async for _ in run: + pass + all_messages = run.all_messages() + assert len(all_messages) >= 2 # Request + response + assert isinstance(all_messages[0], ModelRequest) + + +async def test_message_history_with_multiple_messages(): + """Test that history with multiple messages starting correctly works.""" + + def llm(messages: list[ModelMessage], _info: AgentInfo) -> ModelResponse: + return ModelResponse(parts=[TextPart('final response')]) + + agent = Agent(FunctionModel(llm)) + + valid_history = [ + ModelRequest(parts=[UserPromptPart(content='First')]), + ModelResponse(parts=[TextPart(content='Response 1')]), + ModelRequest(parts=[UserPromptPart(content='Second')]), + ModelResponse(parts=[TextPart(content='Response 2')]), + ] + + async with agent.iter('Third message', message_history=valid_history) as run: + async for _ in run: + pass + # Verify the history is preserved and new message is added + all_messages = run.all_messages() + assert len(all_messages) >= 5 # 4 from history + at least 1 new + assert isinstance(all_messages[0], ModelRequest) + assert isinstance(all_messages[-1], ModelResponse) + + +def test_validation_happens_after_cleaning(): + """Test that validation catches issues even after message cleaning.""" + + def simple_response(_messages: list[ModelMessage], _info: AgentInfo) -> ModelResponse: + return ModelResponse(parts=[TextPart(content='Final response')]) # pragma: no cover + + agent = Agent(FunctionModel(simple_response)) + + # Even if cleaning merges messages, first should still be checked + invalid_history = [ + ModelResponse(parts=[TextPart(content='response 1')]), + ModelResponse(parts=[TextPart(content='response 2')]), # Would be merged + ] + + with pytest.raises( + UserError, + match='Message history cannot start with a `ModelResponse`.', + ): + agent.run_sync('hello', message_history=invalid_history) From 685ac073951e5a0410e0d64bcbe57fde6de87918 Mon Sep 17 00:00:00 2001 From: pavan Date: Tue, 18 Nov 2025 21:23:41 +0530 Subject: [PATCH 2/3] test: remove redundant tests per maintainer feedback Only keep the test that verifies the exception is raised. The non-exception behavior is already tested by existing tests. Addresses review feedback from @DouweM on PR #3440 --- tests/test_agent.py | 86 --------------------------------------------- 1 file changed, 86 deletions(-) diff --git a/tests/test_agent.py b/tests/test_agent.py index fbe91d0efa..8a974dd165 100644 --- a/tests/test_agent.py +++ b/tests/test_agent.py @@ -6144,89 +6144,3 @@ def simple_response(_messages: list[ModelMessage], _info: AgentInfo) -> ModelRes match='Message history cannot start with a `ModelResponse`.', ): agent.run_sync('hello', message_history=invalid_history) - - -async def test_message_history_starts_with_model_request(): - """Test that valid history starting with ModelRequest works correctly.""" - - def llm(messages: list[ModelMessage], _info: AgentInfo) -> ModelResponse: - return ModelResponse(parts=[TextPart('ok here is text')]) - - agent = Agent(FunctionModel(llm)) - - valid_history = [ - ModelRequest(parts=[UserPromptPart(content='Hello')]), - ModelResponse(parts=[TextPart(content='Hi there!')]), - ] - - # Should not raise error - valid history starting with ModelRequest - async with agent.iter('How are you?', message_history=valid_history) as run: - async for _ in run: - pass - # Verify messages are processed correctly - all_messages = run.all_messages() - assert len(all_messages) >= 3 # History + new request + response - assert isinstance(all_messages[0], ModelRequest) # First message is ModelRequest - - -async def test_empty_message_history_is_valid(): - """Test that empty message history works fine.""" - - def llm(messages: list[ModelMessage], _info: AgentInfo) -> ModelResponse: - return ModelResponse(parts=[TextPart('response text')]) - - agent = Agent(FunctionModel(llm)) - - # Empty history should work - should not raise error - async with agent.iter('hello', message_history=[]) as run: - async for _ in run: - pass - all_messages = run.all_messages() - assert len(all_messages) >= 2 # Request + response - assert isinstance(all_messages[0], ModelRequest) - - -async def test_message_history_with_multiple_messages(): - """Test that history with multiple messages starting correctly works.""" - - def llm(messages: list[ModelMessage], _info: AgentInfo) -> ModelResponse: - return ModelResponse(parts=[TextPart('final response')]) - - agent = Agent(FunctionModel(llm)) - - valid_history = [ - ModelRequest(parts=[UserPromptPart(content='First')]), - ModelResponse(parts=[TextPart(content='Response 1')]), - ModelRequest(parts=[UserPromptPart(content='Second')]), - ModelResponse(parts=[TextPart(content='Response 2')]), - ] - - async with agent.iter('Third message', message_history=valid_history) as run: - async for _ in run: - pass - # Verify the history is preserved and new message is added - all_messages = run.all_messages() - assert len(all_messages) >= 5 # 4 from history + at least 1 new - assert isinstance(all_messages[0], ModelRequest) - assert isinstance(all_messages[-1], ModelResponse) - - -def test_validation_happens_after_cleaning(): - """Test that validation catches issues even after message cleaning.""" - - def simple_response(_messages: list[ModelMessage], _info: AgentInfo) -> ModelResponse: - return ModelResponse(parts=[TextPart(content='Final response')]) # pragma: no cover - - agent = Agent(FunctionModel(simple_response)) - - # Even if cleaning merges messages, first should still be checked - invalid_history = [ - ModelResponse(parts=[TextPart(content='response 1')]), - ModelResponse(parts=[TextPart(content='response 2')]), # Would be merged - ] - - with pytest.raises( - UserError, - match='Message history cannot start with a `ModelResponse`.', - ): - agent.run_sync('hello', message_history=invalid_history) From 40d8862e2d971b8cfd7af53bea8f92ce8e6c4ca9 Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Tue, 18 Nov 2025 17:29:52 -0600 Subject: [PATCH 3/3] Update tests/test_agent.py --- tests/test_agent.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/tests/test_agent.py b/tests/test_agent.py index 289a2ef8c0..c2a513af47 100644 --- a/tests/test_agent.py +++ b/tests/test_agent.py @@ -6137,10 +6137,7 @@ def llm(messages: list[ModelMessage], _info: AgentInfo) -> ModelResponse: def test_message_history_cannot_start_with_model_response(): """Test that message history starting with ModelResponse raises UserError.""" - def simple_response(_messages: list[ModelMessage], _info: AgentInfo) -> ModelResponse: - return ModelResponse(parts=[TextPart(content='Final response')]) # pragma: no cover - - agent = Agent(FunctionModel(simple_response)) + agent = Agent('test') invalid_history = [ ModelResponse(parts=[TextPart(content='ai response')]),