Skip to content

Commit aaee6e4

Browse files
author
pavan
committed
feat: enforce message history starts with user message
Add validation in UserPromptNode to raise UserError if message history starts with ModelResponse, ensuring conversations begin with a user message (ModelRequest) Include comprehensive tests for invalid history, valid history, empty history, multiple messages, and validation after message cleaning to prevent issues with malformed conversation logs
1 parent dec2611 commit aaee6e4

File tree

2 files changed

+111
-0
lines changed

2 files changed

+111
-0
lines changed

pydantic_ai_slim/pydantic_ai/_agent_graph.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -216,6 +216,12 @@ async def run( # noqa: C901
216216
ctx.state.message_history = messages
217217
ctx.deps.new_message_index = len(messages)
218218

219+
# Validate that message history starts with a user message
220+
if messages and isinstance(messages[0], _messages.ModelResponse):
221+
raise exceptions.UserError(
222+
'Message history cannot start with a `ModelResponse`. Conversations must begin with a user message.'
223+
)
224+
219225
if self.deferred_tool_results is not None:
220226
return await self._handle_deferred_tool_results(self.deferred_tool_results, messages, ctx)
221227

tests/test_agent.py

Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6125,3 +6125,108 @@ def llm(messages: list[ModelMessage], _info: AgentInfo) -> ModelResponse:
61256125
]
61266126
)
61276127
assert run.all_messages_json().startswith(b'[{"parts":[{"content":"Hello",')
6128+
6129+
6130+
def test_message_history_cannot_start_with_model_response():
6131+
"""Test that message history starting with ModelResponse raises UserError."""
6132+
6133+
def simple_response(_messages: list[ModelMessage], _info: AgentInfo) -> ModelResponse:
6134+
return ModelResponse(parts=[TextPart(content='Final response')]) # pragma: no cover
6135+
6136+
agent = Agent(FunctionModel(simple_response))
6137+
6138+
invalid_history = [
6139+
ModelResponse(parts=[TextPart(content='ai response')]),
6140+
]
6141+
6142+
with pytest.raises(
6143+
UserError,
6144+
match='Message history cannot start with a `ModelResponse`.',
6145+
):
6146+
agent.run_sync('hello', message_history=invalid_history)
6147+
6148+
6149+
async def test_message_history_starts_with_model_request():
6150+
"""Test that valid history starting with ModelRequest works correctly."""
6151+
6152+
def llm(messages: list[ModelMessage], _info: AgentInfo) -> ModelResponse:
6153+
return ModelResponse(parts=[TextPart('ok here is text')])
6154+
6155+
agent = Agent(FunctionModel(llm))
6156+
6157+
valid_history = [
6158+
ModelRequest(parts=[UserPromptPart(content='Hello')]),
6159+
ModelResponse(parts=[TextPart(content='Hi there!')]),
6160+
]
6161+
6162+
# Should not raise error - valid history starting with ModelRequest
6163+
async with agent.iter('How are you?', message_history=valid_history) as run:
6164+
async for _ in run:
6165+
pass
6166+
# Verify messages are processed correctly
6167+
all_messages = run.all_messages()
6168+
assert len(all_messages) >= 3 # History + new request + response
6169+
assert isinstance(all_messages[0], ModelRequest) # First message is ModelRequest
6170+
6171+
6172+
async def test_empty_message_history_is_valid():
6173+
"""Test that empty message history works fine."""
6174+
6175+
def llm(messages: list[ModelMessage], _info: AgentInfo) -> ModelResponse:
6176+
return ModelResponse(parts=[TextPart('response text')])
6177+
6178+
agent = Agent(FunctionModel(llm))
6179+
6180+
# Empty history should work - should not raise error
6181+
async with agent.iter('hello', message_history=[]) as run:
6182+
async for _ in run:
6183+
pass
6184+
all_messages = run.all_messages()
6185+
assert len(all_messages) >= 2 # Request + response
6186+
assert isinstance(all_messages[0], ModelRequest)
6187+
6188+
6189+
async def test_message_history_with_multiple_messages():
6190+
"""Test that history with multiple messages starting correctly works."""
6191+
6192+
def llm(messages: list[ModelMessage], _info: AgentInfo) -> ModelResponse:
6193+
return ModelResponse(parts=[TextPart('final response')])
6194+
6195+
agent = Agent(FunctionModel(llm))
6196+
6197+
valid_history = [
6198+
ModelRequest(parts=[UserPromptPart(content='First')]),
6199+
ModelResponse(parts=[TextPart(content='Response 1')]),
6200+
ModelRequest(parts=[UserPromptPart(content='Second')]),
6201+
ModelResponse(parts=[TextPart(content='Response 2')]),
6202+
]
6203+
6204+
async with agent.iter('Third message', message_history=valid_history) as run:
6205+
async for _ in run:
6206+
pass
6207+
# Verify the history is preserved and new message is added
6208+
all_messages = run.all_messages()
6209+
assert len(all_messages) >= 5 # 4 from history + at least 1 new
6210+
assert isinstance(all_messages[0], ModelRequest)
6211+
assert isinstance(all_messages[-1], ModelResponse)
6212+
6213+
6214+
def test_validation_happens_after_cleaning():
6215+
"""Test that validation catches issues even after message cleaning."""
6216+
6217+
def simple_response(_messages: list[ModelMessage], _info: AgentInfo) -> ModelResponse:
6218+
return ModelResponse(parts=[TextPart(content='Final response')]) # pragma: no cover
6219+
6220+
agent = Agent(FunctionModel(simple_response))
6221+
6222+
# Even if cleaning merges messages, first should still be checked
6223+
invalid_history = [
6224+
ModelResponse(parts=[TextPart(content='response 1')]),
6225+
ModelResponse(parts=[TextPart(content='response 2')]), # Would be merged
6226+
]
6227+
6228+
with pytest.raises(
6229+
UserError,
6230+
match='Message history cannot start with a `ModelResponse`.',
6231+
):
6232+
agent.run_sync('hello', message_history=invalid_history)

0 commit comments

Comments
 (0)