|
40 | 40 | AgentMessageChunk, |
41 | 41 | AllowedOutcome, |
42 | 42 | DeniedOutcome, |
| 43 | + PermissionOption, |
43 | 44 | TextContentBlock, |
| 45 | + ToolCallLocation, |
| 46 | + ToolCallProgress, |
| 47 | + ToolCallStart, |
| 48 | + ToolCallUpdate, |
44 | 49 | UserMessageChunk, |
45 | 50 | ) |
46 | 51 |
|
@@ -416,6 +421,169 @@ async def test_ignore_invalid_messages(): |
416 | 421 | await asyncio.wait_for(s.client_reader.readline(), timeout=0.1) |
417 | 422 |
|
418 | 423 |
|
| 424 | +class _ExampleAgent(Agent): |
| 425 | + __test__ = False |
| 426 | + |
| 427 | + def __init__(self) -> None: |
| 428 | + self._conn: AgentSideConnection | None = None |
| 429 | + self.permission_response: RequestPermissionResponse | None = None |
| 430 | + self.prompt_requests: list[PromptRequest] = [] |
| 431 | + |
| 432 | + def bind(self, conn: AgentSideConnection) -> "_ExampleAgent": |
| 433 | + self._conn = conn |
| 434 | + return self |
| 435 | + |
| 436 | + async def initialize(self, params: InitializeRequest) -> InitializeResponse: |
| 437 | + return InitializeResponse(protocolVersion=params.protocolVersion) |
| 438 | + |
| 439 | + async def newSession(self, params: NewSessionRequest) -> NewSessionResponse: |
| 440 | + return NewSessionResponse(sessionId="sess_demo") |
| 441 | + |
| 442 | + async def prompt(self, params: PromptRequest) -> PromptResponse: |
| 443 | + assert self._conn is not None |
| 444 | + self.prompt_requests.append(params) |
| 445 | + |
| 446 | + await self._conn.sessionUpdate( |
| 447 | + SessionNotification( |
| 448 | + sessionId=params.sessionId, |
| 449 | + update=AgentMessageChunk( |
| 450 | + sessionUpdate="agent_message_chunk", |
| 451 | + content=TextContentBlock(type="text", text="I'll help you with that."), |
| 452 | + ), |
| 453 | + ) |
| 454 | + ) |
| 455 | + |
| 456 | + await self._conn.sessionUpdate( |
| 457 | + SessionNotification( |
| 458 | + sessionId=params.sessionId, |
| 459 | + update=ToolCallStart( |
| 460 | + sessionUpdate="tool_call", |
| 461 | + toolCallId="call_1", |
| 462 | + title="Modifying configuration", |
| 463 | + kind="edit", |
| 464 | + status="pending", |
| 465 | + locations=[ToolCallLocation(path="/project/config.json")], |
| 466 | + rawInput={"path": "/project/config.json"}, |
| 467 | + ), |
| 468 | + ) |
| 469 | + ) |
| 470 | + |
| 471 | + permission_request = RequestPermissionRequest( |
| 472 | + sessionId=params.sessionId, |
| 473 | + toolCall=ToolCallUpdate( |
| 474 | + toolCallId="call_1", |
| 475 | + title="Modifying configuration", |
| 476 | + kind="edit", |
| 477 | + status="pending", |
| 478 | + locations=[ToolCallLocation(path="/project/config.json")], |
| 479 | + rawInput={"path": "/project/config.json"}, |
| 480 | + ), |
| 481 | + options=[ |
| 482 | + PermissionOption(kind="allow_once", name="Allow", optionId="allow"), |
| 483 | + PermissionOption(kind="reject_once", name="Reject", optionId="reject"), |
| 484 | + ], |
| 485 | + ) |
| 486 | + response = await self._conn.requestPermission(permission_request) |
| 487 | + self.permission_response = response |
| 488 | + |
| 489 | + if isinstance(response.outcome, AllowedOutcome) and response.outcome.optionId == "allow": |
| 490 | + await self._conn.sessionUpdate( |
| 491 | + SessionNotification( |
| 492 | + sessionId=params.sessionId, |
| 493 | + update=ToolCallProgress( |
| 494 | + sessionUpdate="tool_call_update", |
| 495 | + toolCallId="call_1", |
| 496 | + status="completed", |
| 497 | + rawOutput={"success": True}, |
| 498 | + ), |
| 499 | + ) |
| 500 | + ) |
| 501 | + await self._conn.sessionUpdate( |
| 502 | + SessionNotification( |
| 503 | + sessionId=params.sessionId, |
| 504 | + update=AgentMessageChunk( |
| 505 | + sessionUpdate="agent_message_chunk", |
| 506 | + content=TextContentBlock(type="text", text="Done."), |
| 507 | + ), |
| 508 | + ) |
| 509 | + ) |
| 510 | + |
| 511 | + return PromptResponse(stopReason="end_turn") |
| 512 | + |
| 513 | + |
| 514 | +class _ExampleClient(TestClient): |
| 515 | + __test__ = False |
| 516 | + |
| 517 | + def __init__(self) -> None: |
| 518 | + super().__init__() |
| 519 | + self.permission_requests: list[RequestPermissionRequest] = [] |
| 520 | + |
| 521 | + async def requestPermission(self, params: RequestPermissionRequest) -> RequestPermissionResponse: |
| 522 | + self.permission_requests.append(params) |
| 523 | + if not params.options: |
| 524 | + return RequestPermissionResponse(outcome=DeniedOutcome(outcome="cancelled")) |
| 525 | + option = params.options[0] |
| 526 | + return RequestPermissionResponse(outcome=AllowedOutcome(optionId=option.optionId, outcome="selected")) |
| 527 | + |
| 528 | + |
| 529 | +@pytest.mark.asyncio |
| 530 | +async def test_example_agent_permission_flow(): |
| 531 | + async with _Server() as s: |
| 532 | + agent = _ExampleAgent() |
| 533 | + client = _ExampleClient() |
| 534 | + |
| 535 | + agent_conn = ClientSideConnection(lambda _conn: client, s.client_writer, s.client_reader) |
| 536 | + AgentSideConnection(lambda conn: agent.bind(conn), s.server_writer, s.server_reader) |
| 537 | + |
| 538 | + init = await agent_conn.initialize(InitializeRequest(protocolVersion=1)) |
| 539 | + assert init.protocolVersion == 1 |
| 540 | + |
| 541 | + session = await agent_conn.newSession(NewSessionRequest(mcpServers=[], cwd="/workspace")) |
| 542 | + assert session.sessionId == "sess_demo" |
| 543 | + |
| 544 | + prompt = PromptRequest( |
| 545 | + sessionId=session.sessionId, |
| 546 | + prompt=[TextContentBlock(type="text", text="Please edit config")], |
| 547 | + ) |
| 548 | + resp = await agent_conn.prompt(prompt) |
| 549 | + assert resp.stopReason == "end_turn" |
| 550 | + |
| 551 | + for _ in range(50): |
| 552 | + if len(client.notifications) >= 4: |
| 553 | + break |
| 554 | + await asyncio.sleep(0.02) |
| 555 | + |
| 556 | + assert len(client.notifications) >= 4 |
| 557 | + session_updates = [getattr(note.update, "sessionUpdate", None) for note in client.notifications] |
| 558 | + assert session_updates[:4] == ["agent_message_chunk", "tool_call", "tool_call_update", "agent_message_chunk"] |
| 559 | + |
| 560 | + first_message = client.notifications[0].update |
| 561 | + assert isinstance(first_message, AgentMessageChunk) |
| 562 | + assert first_message.content.text == "I'll help you with that." |
| 563 | + |
| 564 | + tool_call = client.notifications[1].update |
| 565 | + assert isinstance(tool_call, ToolCallStart) |
| 566 | + assert tool_call.title == "Modifying configuration" |
| 567 | + assert tool_call.status == "pending" |
| 568 | + |
| 569 | + tool_update = client.notifications[2].update |
| 570 | + assert isinstance(tool_update, ToolCallProgress) |
| 571 | + assert tool_update.status == "completed" |
| 572 | + assert tool_update.rawOutput == {"success": True} |
| 573 | + |
| 574 | + final_message = client.notifications[3].update |
| 575 | + assert isinstance(final_message, AgentMessageChunk) |
| 576 | + assert final_message.content.text == "Done." |
| 577 | + |
| 578 | + assert len(client.permission_requests) == 1 |
| 579 | + options = client.permission_requests[0].options |
| 580 | + assert [opt.optionId for opt in options] == ["allow", "reject"] |
| 581 | + |
| 582 | + assert agent.permission_response is not None |
| 583 | + assert isinstance(agent.permission_response.outcome, AllowedOutcome) |
| 584 | + assert agent.permission_response.outcome.optionId == "allow" |
| 585 | + |
| 586 | + |
419 | 587 | @pytest.mark.asyncio |
420 | 588 | async def test_spawn_agent_process_roundtrip(tmp_path): |
421 | 589 | script = Path(__file__).parents[1] / "examples" / "echo_agent.py" |
|
0 commit comments