From 01b421ec41560df5bf10005f09fb49b23ec76ced Mon Sep 17 00:00:00 2001 From: sokoliva Date: Thu, 23 Oct 2025 14:36:16 +0000 Subject: [PATCH 01/26] fix: change "client/test_client.py" to "client/test_client_factory.py" in Running the tests instructions. "client/test_client_factory.py" no longer exists. --- tests/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/README.md b/tests/README.md index bab99450c..d89f3bec7 100644 --- a/tests/README.md +++ b/tests/README.md @@ -2,7 +2,7 @@ 1. Run the tests ```bash - uv run pytest -v -s client/test_client.py + uv run pytest -v -s client/test_client_factory.py ``` In case of failures, you can cleanup the cache: From 697438f76dba044db5c9f14fb678127572e930ba Mon Sep 17 00:00:00 2001 From: sokoliva Date: Tue, 28 Oct 2025 16:16:02 +0000 Subject: [PATCH 02/26] feat: Add client-side extension support This commit introduces support for clients to declare the extensions they support. - Adds an `extensions` list to `ClientConfig`. - Updates `ClientFactory` to pass `client_extensions` to `JsonRpcTransport` and `RestTransport`. - Adds `_update_extension_header` method to both transports to update the `X-A2A-Extensions` header. - Modifies `send_message` and `send_message_streaming` in `JsonRpcTransport` to include the extension headers. - Modifies `_prepare_send_message` in `RestTransport` to include the extension headers. - Adds tests for the extension header logic in both JSON-RPC and REST transports, including a new test file `test_rest_client.py`. --- src/a2a/client/client.py | 3 + src/a2a/client/client_factory.py | 2 + src/a2a/client/transports/jsonrpc.py | 19 +++ src/a2a/client/transports/rest.py | 18 +++ tests/client/test_jsonrpc_client.py | 179 ++++++++++++++++++++++++++ tests/client/test_rest_client.py | 185 +++++++++++++++++++++++++++ 6 files changed, 406 insertions(+) create mode 100644 tests/client/test_rest_client.py diff --git a/src/a2a/client/client.py b/src/a2a/client/client.py index 7cc10423d..c65e8204b 100644 --- a/src/a2a/client/client.py +++ b/src/a2a/client/client.py @@ -67,6 +67,9 @@ class ClientConfig: ) """Push notification callbacks to use for every request.""" + extensions: list[str] = dataclasses.field(default_factory=list) + """A list of extension URIs the client supports.""" + UpdateEvent = TaskStatusUpdateEvent | TaskArtifactUpdateEvent | None # Alias for emitted events from client diff --git a/src/a2a/client/client_factory.py b/src/a2a/client/client_factory.py index 65b3fb5f0..7d2f40f4c 100644 --- a/src/a2a/client/client_factory.py +++ b/src/a2a/client/client_factory.py @@ -77,6 +77,7 @@ def _register_defaults( TransportProtocol.jsonrpc, lambda card, url, config, interceptors: JsonRpcTransport( config.httpx_client or httpx.AsyncClient(), + config.extensions or None, card, url, interceptors, @@ -87,6 +88,7 @@ def _register_defaults( TransportProtocol.http_json, lambda card, url, config, interceptors: RestTransport( config.httpx_client or httpx.AsyncClient(), + config.extensions or None, card, url, interceptors, diff --git a/src/a2a/client/transports/jsonrpc.py b/src/a2a/client/transports/jsonrpc.py index bfba09d71..ec360bf12 100644 --- a/src/a2a/client/transports/jsonrpc.py +++ b/src/a2a/client/transports/jsonrpc.py @@ -18,6 +18,7 @@ ) from a2a.client.middleware import ClientCallContext, ClientCallInterceptor from a2a.client.transports.base import ClientTransport +from a2a.extensions.common import HTTP_EXTENSION_HEADER from a2a.types import ( AgentCard, CancelTaskRequest, @@ -59,6 +60,7 @@ class JsonRpcTransport(ClientTransport): def __init__( self, httpx_client: httpx.AsyncClient, + client_extensions: list[str] | None = None, agent_card: AgentCard | None = None, url: str | None = None, interceptors: list[ClientCallInterceptor] | None = None, @@ -72,6 +74,7 @@ def __init__( raise ValueError('Must provide either agent_card or url') self.httpx_client = httpx_client + self.client_extensions = client_extensions self.agent_card = agent_card self.interceptors = interceptors or [] self._needs_extended_card = ( @@ -80,6 +83,20 @@ def __init__( else True ) + def _update_extension_header( + self, http_kwargs: dict[str, Any] + ) -> dict[str, Any]: + if self.client_extensions: + headers = http_kwargs.get('headers', {}) + existing_extensions = headers.get(HTTP_EXTENSION_HEADER, '') + split = ( + existing_extensions.split(', ') if existing_extensions else [] + ) + updated_extensions = list(set(self.client_extensions + split)) + headers[HTTP_EXTENSION_HEADER] = ', '.join(updated_extensions) + http_kwargs['headers'] = headers + return http_kwargs + async def _apply_interceptors( self, method_name: str, @@ -122,6 +139,7 @@ async def send_message( self._get_http_args(context), context, ) + modified_kwargs = self._update_extension_header(modified_kwargs) response_data = await self._send_request(payload, modified_kwargs) response = SendMessageResponse.model_validate(response_data) if isinstance(response.root, JSONRPCErrorResponse): @@ -147,6 +165,7 @@ async def send_message_streaming( context, ) + modified_kwargs = self._update_extension_header(modified_kwargs) modified_kwargs.setdefault( 'timeout', self.httpx_client.timeout.as_dict().get('read', None) ) diff --git a/src/a2a/client/transports/rest.py b/src/a2a/client/transports/rest.py index eef7b0f2e..45bbd4c48 100644 --- a/src/a2a/client/transports/rest.py +++ b/src/a2a/client/transports/rest.py @@ -13,6 +13,7 @@ from a2a.client.errors import A2AClientHTTPError, A2AClientJSONError from a2a.client.middleware import ClientCallContext, ClientCallInterceptor from a2a.client.transports.base import ClientTransport +from a2a.extensions.common import HTTP_EXTENSION_HEADER from a2a.grpc import a2a_pb2 from a2a.types import ( AgentCard, @@ -40,6 +41,7 @@ class RestTransport(ClientTransport): def __init__( self, httpx_client: httpx.AsyncClient, + client_extensions: list[str] | None = None, agent_card: AgentCard | None = None, url: str | None = None, interceptors: list[ClientCallInterceptor] | None = None, @@ -54,6 +56,7 @@ def __init__( if self.url.endswith('/'): self.url = self.url[:-1] self.httpx_client = httpx_client + self.client_extensions = client_extensions self.agent_card = agent_card self.interceptors = interceptors or [] self._needs_extended_card = ( @@ -62,6 +65,20 @@ def __init__( else True ) + def _update_extension_header( + self, http_kwargs: dict[str, Any] + ) -> dict[str, Any]: + if self.client_extensions: + headers = http_kwargs.get('headers', {}) + existing_extensions = headers.get(HTTP_EXTENSION_HEADER, '') + split = ( + existing_extensions.split(', ') if existing_extensions else [] + ) + updated_extensions = list(set(self.client_extensions + split)) + headers[HTTP_EXTENSION_HEADER] = ', '.join(updated_extensions) + http_kwargs['headers'] = headers + return http_kwargs + async def _apply_interceptors( self, request_payload: dict[str, Any], @@ -98,6 +115,7 @@ async def _prepare_send_message( self._get_http_args(context), context, ) + modified_kwargs = self._update_extension_header(modified_kwargs) return payload, modified_kwargs async def send_message( diff --git a/tests/client/test_jsonrpc_client.py b/tests/client/test_jsonrpc_client.py index 58feec25d..94561ab92 100644 --- a/tests/client/test_jsonrpc_client.py +++ b/tests/client/test_jsonrpc_client.py @@ -17,6 +17,7 @@ create_text_message_object, ) from a2a.client.transports.jsonrpc import JsonRpcTransport +from a2a.extensions.common import HTTP_EXTENSION_HEADER from a2a.types import ( AgentCapabilities, AgentCard, @@ -785,3 +786,181 @@ async def test_close(self, mock_httpx_client: AsyncMock): ) await client.close() mock_httpx_client.aclose.assert_called_once() + + +class TestJsonRpcTransportExtensions: + def test_update_extension_header_no_initial_headers( + self, mock_httpx_client: AsyncMock, mock_agent_card: MagicMock + ): + extensions = ['test_extension_1', 'test_extension_2'] + client = JsonRpcTransport( + mock_httpx_client, extensions, mock_agent_card + ) + http_kwargs = {} + result_kwargs = client._update_extension_header(http_kwargs) + actual_extensions = set( + result_kwargs['headers'][HTTP_EXTENSION_HEADER].split(', ') + ) + expected_extensions = {'test_extension_1', 'test_extension_2'} + assert actual_extensions == expected_extensions + + def test_update_extension_header_with_existing_other_headers( + self, mock_httpx_client: AsyncMock, mock_agent_card: MagicMock + ): + extensions = ['test_extension_1'] + client = JsonRpcTransport( + mock_httpx_client, extensions, mock_agent_card + ) + http_kwargs = {'headers': {'X_Other': 'Test'}} + result_kwargs = client._update_extension_header(http_kwargs) + assert ( + result_kwargs['headers'][HTTP_EXTENSION_HEADER] + == 'test_extension_1' + ) + assert result_kwargs['headers']['X_Other'] == 'Test' + + def test_update_extension_header_merge_with_existing_extensions( + self, mock_httpx_client: AsyncMock, mock_agent_card: MagicMock + ): + extensions = ['test_extension_1', 'test_extension_2'] + client = JsonRpcTransport( + mock_httpx_client, extensions, mock_agent_card + ) + http_kwargs = { + 'headers': { + HTTP_EXTENSION_HEADER: 'test_extension_2, test_extension_3' + } + } + result_kwargs = client._update_extension_header(http_kwargs) + actual_extensions_list = result_kwargs['headers'][ + HTTP_EXTENSION_HEADER + ].split(', ') + actual_extensions = set(actual_extensions_list) + expected_extensions = { + 'test_extension_1', + 'test_extension_2', + 'test_extension_3', + } + assert len(actual_extensions_list) == 3 + assert actual_extensions == expected_extensions + + def test_update_extension_header_no_client_extensions( + self, mock_httpx_client: AsyncMock, mock_agent_card: MagicMock + ): + client = JsonRpcTransport(mock_httpx_client, None, mock_agent_card) + http_kwargs = {'headers': {'X_Other': 'Test'}} + result_kwargs = client._update_extension_header(http_kwargs) + assert HTTP_EXTENSION_HEADER not in result_kwargs['headers'] + assert result_kwargs['headers']['X_Other'] == 'Test' + + def test_update_extension_header_empty_client_extensions( + self, mock_httpx_client: AsyncMock, mock_agent_card: MagicMock + ): + client = JsonRpcTransport(mock_httpx_client, [], mock_agent_card) + http_kwargs = {'headers': {'X_Other': 'Test'}} + result_kwargs = client._update_extension_header(http_kwargs) + assert HTTP_EXTENSION_HEADER not in result_kwargs['headers'] + assert result_kwargs['headers']['X_Other'] == 'Test' + + @pytest.mark.asyncio + async def test_send_message_with_extensions( + self, mock_httpx_client: AsyncMock, mock_agent_card: MagicMock + ): + """Test that send_message adds extension headers when client_extensions are provided.""" + extensions = ['test_extension_1', 'test_extension_2'] + client = JsonRpcTransport( + httpx_client=mock_httpx_client, + client_extensions=extensions, + agent_card=mock_agent_card, + ) + params = MessageSendParams( + message=create_text_message_object(content='Hello') + ) + success_response = create_text_message_object( + role=Role.agent, content='Hi there!' + ) + rpc_response = SendMessageSuccessResponse( + id='123', jsonrpc='2.0', result=success_response + ) + # Mock the response from httpx_client.post + mock_response = AsyncMock(spec=httpx.Response) + mock_response.status_code = 200 + mock_response.json.return_value = rpc_response.model_dump(mode='json') + mock_httpx_client.post.return_value = mock_response + + await client.send_message(request=params) + + mock_httpx_client.post.assert_called_once() + _, mock_kwargs = mock_httpx_client.post.call_args + headers = mock_kwargs.get('headers', {}) + assert HTTP_EXTENSION_HEADER in headers + actual_extensions = set(headers[HTTP_EXTENSION_HEADER].split(', ')) + expected_extensions = {'test_extension_1', 'test_extension_2'} + assert actual_extensions == expected_extensions + + @pytest.mark.asyncio + async def test_send_message_no_extensions( + self, mock_httpx_client: AsyncMock, mock_agent_card: MagicMock + ): + """Test that send_message does not add extension headers when client_extensions is None.""" + client = JsonRpcTransport( + httpx_client=mock_httpx_client, + client_extensions=None, + agent_card=mock_agent_card, + ) + params = MessageSendParams( + message=create_text_message_object(content='Hello') + ) + success_response = create_text_message_object( + role=Role.agent, content='Hi there!' + ) + rpc_response = SendMessageSuccessResponse( + id='123', jsonrpc='2.0', result=success_response + ) + # Mock the response from httpx_client.post + mock_response = AsyncMock(spec=httpx.Response) + mock_response.status_code = 200 + mock_response.json.return_value = rpc_response.model_dump(mode='json') + mock_httpx_client.post.return_value = mock_response + + await client.send_message(request=params) + + mock_httpx_client.post.assert_called_once() + _, mock_kwargs = mock_httpx_client.post.call_args + headers = mock_kwargs.get('headers', {}) + assert HTTP_EXTENSION_HEADER not in headers + + @pytest.mark.asyncio + @patch('a2a.client.transports.jsonrpc.aconnect_sse') + async def test_send_message_streaming_with_extensions( + self, + mock_aconnect_sse: AsyncMock, + mock_httpx_client: AsyncMock, + mock_agent_card: MagicMock, + ): + """Test X-A2A-Extensions header in send_message_streaming.""" + extensions = ['test_extension'] + client = JsonRpcTransport( + httpx_client=mock_httpx_client, + client_extensions=extensions, + agent_card=mock_agent_card, + ) + params = MessageSendParams( + message=create_text_message_object(content='Hello stream') + ) + + mock_event_source = AsyncMock(spec=EventSource) + mock_event_source.aiter_sse.return_value = async_iterable_from_list([]) + mock_aconnect_sse.return_value.__aenter__.return_value = ( + mock_event_source + ) + + async for _ in client.send_message_streaming(request=params): + pass + + mock_aconnect_sse.assert_called_once() + _, kwargs = mock_aconnect_sse.call_args + + headers = kwargs.get('headers', {}) + assert HTTP_EXTENSION_HEADER in headers + assert headers[HTTP_EXTENSION_HEADER] == 'test_extension' diff --git a/tests/client/test_rest_client.py b/tests/client/test_rest_client.py new file mode 100644 index 000000000..7fef166fd --- /dev/null +++ b/tests/client/test_rest_client.py @@ -0,0 +1,185 @@ +from collections.abc import AsyncGenerator +from unittest.mock import AsyncMock, MagicMock, patch + +import httpx +import pytest + +from httpx_sse import EventSource, ServerSentEvent + +from a2a.client import create_text_message_object +from a2a.client.transports.rest import RestTransport +from a2a.extensions.common import HTTP_EXTENSION_HEADER +from a2a.types import AgentCard, MessageSendParams, Role + + +@pytest.fixture +def mock_httpx_client() -> AsyncMock: + return AsyncMock(spec=httpx.AsyncClient) + + +@pytest.fixture +def mock_agent_card() -> MagicMock: + mock = MagicMock(spec=AgentCard, url='http://agent.example.com/api') + mock.supports_authenticated_extended_card = False + return mock + + +async def async_iterable_from_list( + items: list[ServerSentEvent], +) -> AsyncGenerator[ServerSentEvent, None]: + """Helper to create an async iterable from a list.""" + for item in items: + yield item + + +class TestRestTransportExtensions: + def test_update_extension_header_no_initial_headers( + self, mock_httpx_client: AsyncMock, mock_agent_card: MagicMock + ): + extensions = ['test_extension_1', 'test_extension_2'] + client = RestTransport(mock_httpx_client, extensions, mock_agent_card) + http_kwargs = {} + result_kwargs = client._update_extension_header(http_kwargs) + actual_extensions = set( + result_kwargs['headers'][HTTP_EXTENSION_HEADER].split(', ') + ) + expected_extensions = {'test_extension_1', 'test_extension_2'} + assert actual_extensions == expected_extensions + + def test_update_extension_header_merge_with_existing_extensions( + self, mock_httpx_client: AsyncMock, mock_agent_card: MagicMock + ): + extensions = ['test_extension_2', 'test_extension_3'] + client = RestTransport(mock_httpx_client, extensions, mock_agent_card) + http_kwargs = { + 'headers': { + HTTP_EXTENSION_HEADER: 'test_extension_1, test_extension_2' + } + } + result_kwargs = client._update_extension_header(http_kwargs) + actual_extensions = set( + result_kwargs['headers'][HTTP_EXTENSION_HEADER].split(', ') + ) + expected_extensions = { + 'test_extension_1', + 'test_extension_2', + 'test_extension_3', + } + assert actual_extensions == expected_extensions + + def test_update_extension_header_with_other_headers( + self, mock_httpx_client: AsyncMock, mock_agent_card: MagicMock + ): + extensions = ['test_extension_1'] + client = RestTransport(mock_httpx_client, extensions, mock_agent_card) + http_kwargs = {'headers': {'X_Other': 'Test'}} + result_kwargs = client._update_extension_header(http_kwargs) + headers = result_kwargs.get('headers', {}) + assert HTTP_EXTENSION_HEADER in headers + assert headers[HTTP_EXTENSION_HEADER] == 'test_extension_1' + assert headers['X_Other'] == 'Test' + + @pytest.mark.asyncio + async def test_send_message_with_extensions( + self, mock_httpx_client: AsyncMock, mock_agent_card: MagicMock + ): + """Test that send_message adds client_extensions to headers.""" + extensions = ['test_extension_1', 'test_extension_2'] + client = RestTransport( + httpx_client=mock_httpx_client, + client_extensions=extensions, + agent_card=mock_agent_card, + ) + params = MessageSendParams( + message=create_text_message_object(content='Hello') + ) + + # Mock the build_request method to capture its inputs + mock_build_request = MagicMock( + return_value=AsyncMock(spec=httpx.Request) + ) + mock_httpx_client.build_request = mock_build_request + + # Mock the send method + mock_response = AsyncMock(spec=httpx.Response) + mock_response.status_code = 200 + mock_httpx_client.send.return_value = mock_response + + await client.send_message(request=params) + + mock_build_request.assert_called_once() + _, kwargs = mock_build_request.call_args + + headers = kwargs.get('headers', {}) + assert HTTP_EXTENSION_HEADER in headers + actual_extensions = set(headers[HTTP_EXTENSION_HEADER].split(', ')) + expected_extensions = {'test_extension_1', 'test_extension_2'} + assert actual_extensions == expected_extensions + + @pytest.mark.asyncio + async def test_send_message_no_extensions( + self, mock_httpx_client: AsyncMock, mock_agent_card: MagicMock + ): + """Test that send_message does not add extension headers when client_extensions is None.""" + client = RestTransport( + httpx_client=mock_httpx_client, + client_extensions=None, + agent_card=mock_agent_card, + ) + params = MessageSendParams( + message=create_text_message_object(content='Hello') + ) + + # Mock the build_request method to capture its inputs + mock_build_request = MagicMock( + return_value=AsyncMock(spec=httpx.Request) + ) + mock_httpx_client.build_request = mock_build_request + + # Mock the send method + mock_response = AsyncMock(spec=httpx.Response) + mock_response.status_code = 200 + mock_httpx_client.send.return_value = mock_response + + await client.send_message(request=params) + + mock_build_request.assert_called_once() + _, kwargs = mock_build_request.call_args + + headers = kwargs.get('headers', {}) + assert HTTP_EXTENSION_HEADER not in headers + + @pytest.mark.asyncio + @patch('a2a.client.transports.rest.aconnect_sse') + async def test_send_message_streaming_with_extensions( + self, + mock_aconnect_sse: AsyncMock, + mock_httpx_client: AsyncMock, + mock_agent_card: MagicMock, + ): + """Test X-A2A-Extensions header in send_message_streaming.""" + extensions = ['test_extension'] + client = RestTransport( + httpx_client=mock_httpx_client, + client_extensions=extensions, + agent_card=mock_agent_card, + ) + params = MessageSendParams( + message=create_text_message_object(content='Hello stream') + ) + + mock_event_source = AsyncMock(spec=EventSource) + mock_event_source.aiter_sse.return_value = async_iterable_from_list([]) + mock_aconnect_sse.return_value.__aenter__.return_value = ( + mock_event_source + ) + + async for _ in client.send_message_streaming(request=params): + pass + + mock_aconnect_sse.assert_called_once() + _, kwargs = mock_aconnect_sse.call_args + + headers = kwargs.get('headers', {}) + assert HTTP_EXTENSION_HEADER in headers + assert headers[HTTP_EXTENSION_HEADER] == 'test_extension' From 860f2d51f64859282adeb9eca0ceae4b488ad843 Mon Sep 17 00:00:00 2001 From: sokoliva Date: Wed, 29 Oct 2025 09:44:16 +0000 Subject: [PATCH 03/26] refactor: remove redundant tests for send_message without extensions in JsonRpc and Rest transports --- tests/client/test_jsonrpc_client.py | 32 ---------------------------- tests/client/test_rest_client.py | 33 ----------------------------- 2 files changed, 65 deletions(-) diff --git a/tests/client/test_jsonrpc_client.py b/tests/client/test_jsonrpc_client.py index 94561ab92..be6dc16cf 100644 --- a/tests/client/test_jsonrpc_client.py +++ b/tests/client/test_jsonrpc_client.py @@ -898,38 +898,6 @@ async def test_send_message_with_extensions( expected_extensions = {'test_extension_1', 'test_extension_2'} assert actual_extensions == expected_extensions - @pytest.mark.asyncio - async def test_send_message_no_extensions( - self, mock_httpx_client: AsyncMock, mock_agent_card: MagicMock - ): - """Test that send_message does not add extension headers when client_extensions is None.""" - client = JsonRpcTransport( - httpx_client=mock_httpx_client, - client_extensions=None, - agent_card=mock_agent_card, - ) - params = MessageSendParams( - message=create_text_message_object(content='Hello') - ) - success_response = create_text_message_object( - role=Role.agent, content='Hi there!' - ) - rpc_response = SendMessageSuccessResponse( - id='123', jsonrpc='2.0', result=success_response - ) - # Mock the response from httpx_client.post - mock_response = AsyncMock(spec=httpx.Response) - mock_response.status_code = 200 - mock_response.json.return_value = rpc_response.model_dump(mode='json') - mock_httpx_client.post.return_value = mock_response - - await client.send_message(request=params) - - mock_httpx_client.post.assert_called_once() - _, mock_kwargs = mock_httpx_client.post.call_args - headers = mock_kwargs.get('headers', {}) - assert HTTP_EXTENSION_HEADER not in headers - @pytest.mark.asyncio @patch('a2a.client.transports.jsonrpc.aconnect_sse') async def test_send_message_streaming_with_extensions( diff --git a/tests/client/test_rest_client.py b/tests/client/test_rest_client.py index 7fef166fd..835c2b4cb 100644 --- a/tests/client/test_rest_client.py +++ b/tests/client/test_rest_client.py @@ -116,39 +116,6 @@ async def test_send_message_with_extensions( expected_extensions = {'test_extension_1', 'test_extension_2'} assert actual_extensions == expected_extensions - @pytest.mark.asyncio - async def test_send_message_no_extensions( - self, mock_httpx_client: AsyncMock, mock_agent_card: MagicMock - ): - """Test that send_message does not add extension headers when client_extensions is None.""" - client = RestTransport( - httpx_client=mock_httpx_client, - client_extensions=None, - agent_card=mock_agent_card, - ) - params = MessageSendParams( - message=create_text_message_object(content='Hello') - ) - - # Mock the build_request method to capture its inputs - mock_build_request = MagicMock( - return_value=AsyncMock(spec=httpx.Request) - ) - mock_httpx_client.build_request = mock_build_request - - # Mock the send method - mock_response = AsyncMock(spec=httpx.Response) - mock_response.status_code = 200 - mock_httpx_client.send.return_value = mock_response - - await client.send_message(request=params) - - mock_build_request.assert_called_once() - _, kwargs = mock_build_request.call_args - - headers = kwargs.get('headers', {}) - assert HTTP_EXTENSION_HEADER not in headers - @pytest.mark.asyncio @patch('a2a.client.transports.rest.aconnect_sse') async def test_send_message_streaming_with_extensions( From 511de3887da8fff7dbc58ff3b1c9b846818076fd Mon Sep 17 00:00:00 2001 From: sokoliva Date: Wed, 29 Oct 2025 10:12:25 +0000 Subject: [PATCH 04/26] refactor: reorder parameters in JsonRpcTransport and RestTransport constructors for compatability with legacy.py --- src/a2a/client/transports/jsonrpc.py | 4 ++-- src/a2a/client/transports/rest.py | 4 ++-- tests/client/test_jsonrpc_client.py | 12 +++++++++--- tests/client/test_rest_client.py | 18 +++++++++++++++--- 4 files changed, 28 insertions(+), 10 deletions(-) diff --git a/src/a2a/client/transports/jsonrpc.py b/src/a2a/client/transports/jsonrpc.py index ec360bf12..f26ce24e9 100644 --- a/src/a2a/client/transports/jsonrpc.py +++ b/src/a2a/client/transports/jsonrpc.py @@ -60,10 +60,10 @@ class JsonRpcTransport(ClientTransport): def __init__( self, httpx_client: httpx.AsyncClient, - client_extensions: list[str] | None = None, agent_card: AgentCard | None = None, url: str | None = None, interceptors: list[ClientCallInterceptor] | None = None, + client_extensions: list[str] | None = None, ): """Initializes the JsonRpcTransport.""" if url: @@ -74,7 +74,6 @@ def __init__( raise ValueError('Must provide either agent_card or url') self.httpx_client = httpx_client - self.client_extensions = client_extensions self.agent_card = agent_card self.interceptors = interceptors or [] self._needs_extended_card = ( @@ -82,6 +81,7 @@ def __init__( if agent_card else True ) + self.client_extensions = client_extensions def _update_extension_header( self, http_kwargs: dict[str, Any] diff --git a/src/a2a/client/transports/rest.py b/src/a2a/client/transports/rest.py index 45bbd4c48..769dfbcf8 100644 --- a/src/a2a/client/transports/rest.py +++ b/src/a2a/client/transports/rest.py @@ -41,10 +41,10 @@ class RestTransport(ClientTransport): def __init__( self, httpx_client: httpx.AsyncClient, - client_extensions: list[str] | None = None, agent_card: AgentCard | None = None, url: str | None = None, interceptors: list[ClientCallInterceptor] | None = None, + client_extensions: list[str] | None = None, ): """Initializes the RestTransport.""" if url: @@ -56,7 +56,6 @@ def __init__( if self.url.endswith('/'): self.url = self.url[:-1] self.httpx_client = httpx_client - self.client_extensions = client_extensions self.agent_card = agent_card self.interceptors = interceptors or [] self._needs_extended_card = ( @@ -64,6 +63,7 @@ def __init__( if agent_card else True ) + self.client_extensions = client_extensions def _update_extension_header( self, http_kwargs: dict[str, Any] diff --git a/tests/client/test_jsonrpc_client.py b/tests/client/test_jsonrpc_client.py index be6dc16cf..8b6743dec 100644 --- a/tests/client/test_jsonrpc_client.py +++ b/tests/client/test_jsonrpc_client.py @@ -794,7 +794,9 @@ def test_update_extension_header_no_initial_headers( ): extensions = ['test_extension_1', 'test_extension_2'] client = JsonRpcTransport( - mock_httpx_client, extensions, mock_agent_card + httpx_client=mock_httpx_client, + agent_card=mock_agent_card, + client_extensions=extensions, ) http_kwargs = {} result_kwargs = client._update_extension_header(http_kwargs) @@ -809,7 +811,9 @@ def test_update_extension_header_with_existing_other_headers( ): extensions = ['test_extension_1'] client = JsonRpcTransport( - mock_httpx_client, extensions, mock_agent_card + httpx_client=mock_httpx_client, + agent_card=mock_agent_card, + client_extensions=extensions, ) http_kwargs = {'headers': {'X_Other': 'Test'}} result_kwargs = client._update_extension_header(http_kwargs) @@ -824,7 +828,9 @@ def test_update_extension_header_merge_with_existing_extensions( ): extensions = ['test_extension_1', 'test_extension_2'] client = JsonRpcTransport( - mock_httpx_client, extensions, mock_agent_card + httpx_client=mock_httpx_client, + agent_card=mock_agent_card, + client_extensions=extensions, ) http_kwargs = { 'headers': { diff --git a/tests/client/test_rest_client.py b/tests/client/test_rest_client.py index 835c2b4cb..20815dde7 100644 --- a/tests/client/test_rest_client.py +++ b/tests/client/test_rest_client.py @@ -37,7 +37,11 @@ def test_update_extension_header_no_initial_headers( self, mock_httpx_client: AsyncMock, mock_agent_card: MagicMock ): extensions = ['test_extension_1', 'test_extension_2'] - client = RestTransport(mock_httpx_client, extensions, mock_agent_card) + client = RestTransport( + httpx_client=mock_httpx_client, + agent_card=mock_agent_card, + client_extensions=extensions, + ) http_kwargs = {} result_kwargs = client._update_extension_header(http_kwargs) actual_extensions = set( @@ -50,7 +54,11 @@ def test_update_extension_header_merge_with_existing_extensions( self, mock_httpx_client: AsyncMock, mock_agent_card: MagicMock ): extensions = ['test_extension_2', 'test_extension_3'] - client = RestTransport(mock_httpx_client, extensions, mock_agent_card) + client = RestTransport( + httpx_client=mock_httpx_client, + agent_card=mock_agent_card, + client_extensions=extensions, + ) http_kwargs = { 'headers': { HTTP_EXTENSION_HEADER: 'test_extension_1, test_extension_2' @@ -71,7 +79,11 @@ def test_update_extension_header_with_other_headers( self, mock_httpx_client: AsyncMock, mock_agent_card: MagicMock ): extensions = ['test_extension_1'] - client = RestTransport(mock_httpx_client, extensions, mock_agent_card) + client = RestTransport( + httpx_client=mock_httpx_client, + agent_card=mock_agent_card, + client_extensions=extensions, + ) http_kwargs = {'headers': {'X_Other': 'Test'}} result_kwargs = client._update_extension_header(http_kwargs) headers = result_kwargs.get('headers', {}) From 6e80123b7776ded30ed36a00dbb31bf0c28d40e7 Mon Sep 17 00:00:00 2001 From: sokoliva Date: Wed, 29 Oct 2025 10:12:25 +0000 Subject: [PATCH 05/26] refactor: reorder parameters in JsonRpcTransport and RestTransport constructors for compatability with legacy.py --- src/a2a/client/client_factory.py | 12 ++++++------ src/a2a/client/transports/jsonrpc.py | 4 ++-- src/a2a/client/transports/rest.py | 4 ++-- tests/client/test_jsonrpc_client.py | 12 +++++++++--- tests/client/test_rest_client.py | 18 +++++++++++++++--- 5 files changed, 34 insertions(+), 16 deletions(-) diff --git a/src/a2a/client/client_factory.py b/src/a2a/client/client_factory.py index 7d2f40f4c..f03008b59 100644 --- a/src/a2a/client/client_factory.py +++ b/src/a2a/client/client_factory.py @@ -76,11 +76,11 @@ def _register_defaults( self.register( TransportProtocol.jsonrpc, lambda card, url, config, interceptors: JsonRpcTransport( - config.httpx_client or httpx.AsyncClient(), - config.extensions or None, - card, - url, - interceptors, + httpx_client=config.httpx_client or httpx.AsyncClient(), + agent_card=card, + url=url, + interceptors=interceptors, + client_extensions=config.extensions or None, ), ) if TransportProtocol.http_json in supported: @@ -88,10 +88,10 @@ def _register_defaults( TransportProtocol.http_json, lambda card, url, config, interceptors: RestTransport( config.httpx_client or httpx.AsyncClient(), - config.extensions or None, card, url, interceptors, + config.extensions or None, ), ) if TransportProtocol.grpc in supported: diff --git a/src/a2a/client/transports/jsonrpc.py b/src/a2a/client/transports/jsonrpc.py index ec360bf12..f26ce24e9 100644 --- a/src/a2a/client/transports/jsonrpc.py +++ b/src/a2a/client/transports/jsonrpc.py @@ -60,10 +60,10 @@ class JsonRpcTransport(ClientTransport): def __init__( self, httpx_client: httpx.AsyncClient, - client_extensions: list[str] | None = None, agent_card: AgentCard | None = None, url: str | None = None, interceptors: list[ClientCallInterceptor] | None = None, + client_extensions: list[str] | None = None, ): """Initializes the JsonRpcTransport.""" if url: @@ -74,7 +74,6 @@ def __init__( raise ValueError('Must provide either agent_card or url') self.httpx_client = httpx_client - self.client_extensions = client_extensions self.agent_card = agent_card self.interceptors = interceptors or [] self._needs_extended_card = ( @@ -82,6 +81,7 @@ def __init__( if agent_card else True ) + self.client_extensions = client_extensions def _update_extension_header( self, http_kwargs: dict[str, Any] diff --git a/src/a2a/client/transports/rest.py b/src/a2a/client/transports/rest.py index 45bbd4c48..769dfbcf8 100644 --- a/src/a2a/client/transports/rest.py +++ b/src/a2a/client/transports/rest.py @@ -41,10 +41,10 @@ class RestTransport(ClientTransport): def __init__( self, httpx_client: httpx.AsyncClient, - client_extensions: list[str] | None = None, agent_card: AgentCard | None = None, url: str | None = None, interceptors: list[ClientCallInterceptor] | None = None, + client_extensions: list[str] | None = None, ): """Initializes the RestTransport.""" if url: @@ -56,7 +56,6 @@ def __init__( if self.url.endswith('/'): self.url = self.url[:-1] self.httpx_client = httpx_client - self.client_extensions = client_extensions self.agent_card = agent_card self.interceptors = interceptors or [] self._needs_extended_card = ( @@ -64,6 +63,7 @@ def __init__( if agent_card else True ) + self.client_extensions = client_extensions def _update_extension_header( self, http_kwargs: dict[str, Any] diff --git a/tests/client/test_jsonrpc_client.py b/tests/client/test_jsonrpc_client.py index be6dc16cf..8b6743dec 100644 --- a/tests/client/test_jsonrpc_client.py +++ b/tests/client/test_jsonrpc_client.py @@ -794,7 +794,9 @@ def test_update_extension_header_no_initial_headers( ): extensions = ['test_extension_1', 'test_extension_2'] client = JsonRpcTransport( - mock_httpx_client, extensions, mock_agent_card + httpx_client=mock_httpx_client, + agent_card=mock_agent_card, + client_extensions=extensions, ) http_kwargs = {} result_kwargs = client._update_extension_header(http_kwargs) @@ -809,7 +811,9 @@ def test_update_extension_header_with_existing_other_headers( ): extensions = ['test_extension_1'] client = JsonRpcTransport( - mock_httpx_client, extensions, mock_agent_card + httpx_client=mock_httpx_client, + agent_card=mock_agent_card, + client_extensions=extensions, ) http_kwargs = {'headers': {'X_Other': 'Test'}} result_kwargs = client._update_extension_header(http_kwargs) @@ -824,7 +828,9 @@ def test_update_extension_header_merge_with_existing_extensions( ): extensions = ['test_extension_1', 'test_extension_2'] client = JsonRpcTransport( - mock_httpx_client, extensions, mock_agent_card + httpx_client=mock_httpx_client, + agent_card=mock_agent_card, + client_extensions=extensions, ) http_kwargs = { 'headers': { diff --git a/tests/client/test_rest_client.py b/tests/client/test_rest_client.py index 835c2b4cb..20815dde7 100644 --- a/tests/client/test_rest_client.py +++ b/tests/client/test_rest_client.py @@ -37,7 +37,11 @@ def test_update_extension_header_no_initial_headers( self, mock_httpx_client: AsyncMock, mock_agent_card: MagicMock ): extensions = ['test_extension_1', 'test_extension_2'] - client = RestTransport(mock_httpx_client, extensions, mock_agent_card) + client = RestTransport( + httpx_client=mock_httpx_client, + agent_card=mock_agent_card, + client_extensions=extensions, + ) http_kwargs = {} result_kwargs = client._update_extension_header(http_kwargs) actual_extensions = set( @@ -50,7 +54,11 @@ def test_update_extension_header_merge_with_existing_extensions( self, mock_httpx_client: AsyncMock, mock_agent_card: MagicMock ): extensions = ['test_extension_2', 'test_extension_3'] - client = RestTransport(mock_httpx_client, extensions, mock_agent_card) + client = RestTransport( + httpx_client=mock_httpx_client, + agent_card=mock_agent_card, + client_extensions=extensions, + ) http_kwargs = { 'headers': { HTTP_EXTENSION_HEADER: 'test_extension_1, test_extension_2' @@ -71,7 +79,11 @@ def test_update_extension_header_with_other_headers( self, mock_httpx_client: AsyncMock, mock_agent_card: MagicMock ): extensions = ['test_extension_1'] - client = RestTransport(mock_httpx_client, extensions, mock_agent_card) + client = RestTransport( + httpx_client=mock_httpx_client, + agent_card=mock_agent_card, + client_extensions=extensions, + ) http_kwargs = {'headers': {'X_Other': 'Test'}} result_kwargs = client._update_extension_header(http_kwargs) headers = result_kwargs.get('headers', {}) From 31a4581bb0eb84a154e0cf6995bf0e012c40f6c0 Mon Sep 17 00:00:00 2001 From: sokoliva Date: Wed, 29 Oct 2025 12:23:55 +0000 Subject: [PATCH 06/26] Fix Parsing Bug in _update_extension_header method --- src/a2a/client/transports/jsonrpc.py | 23 ++++++++++++++--------- src/a2a/client/transports/rest.py | 23 ++++++++++++++--------- 2 files changed, 28 insertions(+), 18 deletions(-) diff --git a/src/a2a/client/transports/jsonrpc.py b/src/a2a/client/transports/jsonrpc.py index f26ce24e9..523be0fbc 100644 --- a/src/a2a/client/transports/jsonrpc.py +++ b/src/a2a/client/transports/jsonrpc.py @@ -86,15 +86,20 @@ def __init__( def _update_extension_header( self, http_kwargs: dict[str, Any] ) -> dict[str, Any]: - if self.client_extensions: - headers = http_kwargs.get('headers', {}) - existing_extensions = headers.get(HTTP_EXTENSION_HEADER, '') - split = ( - existing_extensions.split(', ') if existing_extensions else [] - ) - updated_extensions = list(set(self.client_extensions + split)) - headers[HTTP_EXTENSION_HEADER] = ', '.join(updated_extensions) - http_kwargs['headers'] = headers + if not self.client_extensions: + return http_kwargs + + headers = http_kwargs.setdefault('headers', {}) + existing_extensions_str = headers.get(HTTP_EXTENSION_HEADER, '') + + existing_extensions = [ + e.strip() for e in existing_extensions_str.split(',') if e.strip() + ] + + all_extensions = set(self.client_extensions) + all_extensions.update(existing_extensions) + + headers[HTTP_EXTENSION_HEADER] = ', '.join(list(all_extensions)) return http_kwargs async def _apply_interceptors( diff --git a/src/a2a/client/transports/rest.py b/src/a2a/client/transports/rest.py index 769dfbcf8..696e0a6d9 100644 --- a/src/a2a/client/transports/rest.py +++ b/src/a2a/client/transports/rest.py @@ -68,15 +68,20 @@ def __init__( def _update_extension_header( self, http_kwargs: dict[str, Any] ) -> dict[str, Any]: - if self.client_extensions: - headers = http_kwargs.get('headers', {}) - existing_extensions = headers.get(HTTP_EXTENSION_HEADER, '') - split = ( - existing_extensions.split(', ') if existing_extensions else [] - ) - updated_extensions = list(set(self.client_extensions + split)) - headers[HTTP_EXTENSION_HEADER] = ', '.join(updated_extensions) - http_kwargs['headers'] = headers + if not self.client_extensions: + return http_kwargs + + headers = http_kwargs.setdefault('headers', {}) + existing_extensions_str = headers.get(HTTP_EXTENSION_HEADER, '') + + existing_extensions = [ + e.strip() for e in existing_extensions_str.split(',') if e.strip() + ] + + all_extensions = set(self.client_extensions) + all_extensions.update(existing_extensions) + + headers[HTTP_EXTENSION_HEADER] = ', '.join(list(all_extensions)) return http_kwargs async def _apply_interceptors( From 5fc530e7639a99d3b92a715e08b9d41bd235026a Mon Sep 17 00:00:00 2001 From: sokoliva Date: Wed, 29 Oct 2025 12:23:55 +0000 Subject: [PATCH 07/26] Fix Parsing Bug in _update_extension_header method --- src/a2a/client/transports/jsonrpc.py | 23 ++++++++++++++--------- src/a2a/client/transports/rest.py | 23 ++++++++++++++--------- 2 files changed, 28 insertions(+), 18 deletions(-) diff --git a/src/a2a/client/transports/jsonrpc.py b/src/a2a/client/transports/jsonrpc.py index f26ce24e9..523be0fbc 100644 --- a/src/a2a/client/transports/jsonrpc.py +++ b/src/a2a/client/transports/jsonrpc.py @@ -86,15 +86,20 @@ def __init__( def _update_extension_header( self, http_kwargs: dict[str, Any] ) -> dict[str, Any]: - if self.client_extensions: - headers = http_kwargs.get('headers', {}) - existing_extensions = headers.get(HTTP_EXTENSION_HEADER, '') - split = ( - existing_extensions.split(', ') if existing_extensions else [] - ) - updated_extensions = list(set(self.client_extensions + split)) - headers[HTTP_EXTENSION_HEADER] = ', '.join(updated_extensions) - http_kwargs['headers'] = headers + if not self.client_extensions: + return http_kwargs + + headers = http_kwargs.setdefault('headers', {}) + existing_extensions_str = headers.get(HTTP_EXTENSION_HEADER, '') + + existing_extensions = [ + e.strip() for e in existing_extensions_str.split(',') if e.strip() + ] + + all_extensions = set(self.client_extensions) + all_extensions.update(existing_extensions) + + headers[HTTP_EXTENSION_HEADER] = ', '.join(list(all_extensions)) return http_kwargs async def _apply_interceptors( diff --git a/src/a2a/client/transports/rest.py b/src/a2a/client/transports/rest.py index 769dfbcf8..696e0a6d9 100644 --- a/src/a2a/client/transports/rest.py +++ b/src/a2a/client/transports/rest.py @@ -68,15 +68,20 @@ def __init__( def _update_extension_header( self, http_kwargs: dict[str, Any] ) -> dict[str, Any]: - if self.client_extensions: - headers = http_kwargs.get('headers', {}) - existing_extensions = headers.get(HTTP_EXTENSION_HEADER, '') - split = ( - existing_extensions.split(', ') if existing_extensions else [] - ) - updated_extensions = list(set(self.client_extensions + split)) - headers[HTTP_EXTENSION_HEADER] = ', '.join(updated_extensions) - http_kwargs['headers'] = headers + if not self.client_extensions: + return http_kwargs + + headers = http_kwargs.setdefault('headers', {}) + existing_extensions_str = headers.get(HTTP_EXTENSION_HEADER, '') + + existing_extensions = [ + e.strip() for e in existing_extensions_str.split(',') if e.strip() + ] + + all_extensions = set(self.client_extensions) + all_extensions.update(existing_extensions) + + headers[HTTP_EXTENSION_HEADER] = ', '.join(list(all_extensions)) return http_kwargs async def _apply_interceptors( From 97eec52b7ab9531bf65340de6a876a76ba559171 Mon Sep 17 00:00:00 2001 From: sokoliva Date: Wed, 29 Oct 2025 13:29:13 +0000 Subject: [PATCH 08/26] refactor: streamline extension header handling in JsonRpcTransport and RestTransport tests. Remove redundant code from client_factory.py --- src/a2a/client/client_factory.py | 10 ++--- tests/client/test_jsonrpc_client.py | 68 +++++++++++++++++++++-------- tests/client/test_rest_client.py | 55 ++++++++++++++++------- 3 files changed, 93 insertions(+), 40 deletions(-) diff --git a/src/a2a/client/client_factory.py b/src/a2a/client/client_factory.py index f03008b59..1507f23a4 100644 --- a/src/a2a/client/client_factory.py +++ b/src/a2a/client/client_factory.py @@ -76,11 +76,11 @@ def _register_defaults( self.register( TransportProtocol.jsonrpc, lambda card, url, config, interceptors: JsonRpcTransport( - httpx_client=config.httpx_client or httpx.AsyncClient(), - agent_card=card, - url=url, - interceptors=interceptors, - client_extensions=config.extensions or None, + config.httpx_client or httpx.AsyncClient(), + card, + url, + interceptors, + config.extensions or None, ), ) if TransportProtocol.http_json in supported: diff --git a/tests/client/test_jsonrpc_client.py b/tests/client/test_jsonrpc_client.py index 8b6743dec..8081fb589 100644 --- a/tests/client/test_jsonrpc_client.py +++ b/tests/client/test_jsonrpc_client.py @@ -800,10 +800,15 @@ def test_update_extension_header_no_initial_headers( ) http_kwargs = {} result_kwargs = client._update_extension_header(http_kwargs) - actual_extensions = set( - result_kwargs['headers'][HTTP_EXTENSION_HEADER].split(', ') - ) - expected_extensions = {'test_extension_1', 'test_extension_2'} + header_value = result_kwargs['headers'][HTTP_EXTENSION_HEADER] + actual_extensions_list = [e.strip() for e in header_value.split(',')] + actual_extensions = set(actual_extensions_list) + + expected_extensions = { + 'test_extension_1', + 'test_extension_2', + } + assert len(actual_extensions_list) == 2 assert actual_extensions == expected_extensions def test_update_extension_header_with_existing_other_headers( @@ -823,8 +828,20 @@ def test_update_extension_header_with_existing_other_headers( ) assert result_kwargs['headers']['X_Other'] == 'Test' + @pytest.mark.parametrize( + 'existing_header, expected_count', + [ + ('test_extension_2, test_extension_3', 3), + ('test_extension_2,test_extension_3', 3), + ('test_extension_3', 3), + ], + ) def test_update_extension_header_merge_with_existing_extensions( - self, mock_httpx_client: AsyncMock, mock_agent_card: MagicMock + self, + mock_httpx_client: AsyncMock, + mock_agent_card: MagicMock, + existing_header: str, + expected_count: int, ): extensions = ['test_extension_1', 'test_extension_2'] client = JsonRpcTransport( @@ -832,16 +849,13 @@ def test_update_extension_header_merge_with_existing_extensions( agent_card=mock_agent_card, client_extensions=extensions, ) - http_kwargs = { - 'headers': { - HTTP_EXTENSION_HEADER: 'test_extension_2, test_extension_3' - } - } + http_kwargs = {'headers': {HTTP_EXTENSION_HEADER: existing_header}} result_kwargs = client._update_extension_header(http_kwargs) - actual_extensions_list = result_kwargs['headers'][ - HTTP_EXTENSION_HEADER - ].split(', ') + + header_value = result_kwargs['headers'][HTTP_EXTENSION_HEADER] + actual_extensions_list = [e.strip() for e in header_value.split(',')] actual_extensions = set(actual_extensions_list) + expected_extensions = { 'test_extension_1', 'test_extension_2', @@ -853,7 +867,11 @@ def test_update_extension_header_merge_with_existing_extensions( def test_update_extension_header_no_client_extensions( self, mock_httpx_client: AsyncMock, mock_agent_card: MagicMock ): - client = JsonRpcTransport(mock_httpx_client, None, mock_agent_card) + client = JsonRpcTransport( + httpx_client=mock_httpx_client, + agent_card=mock_agent_card, + client_extensions=None, + ) http_kwargs = {'headers': {'X_Other': 'Test'}} result_kwargs = client._update_extension_header(http_kwargs) assert HTTP_EXTENSION_HEADER not in result_kwargs['headers'] @@ -862,7 +880,11 @@ def test_update_extension_header_no_client_extensions( def test_update_extension_header_empty_client_extensions( self, mock_httpx_client: AsyncMock, mock_agent_card: MagicMock ): - client = JsonRpcTransport(mock_httpx_client, [], mock_agent_card) + client = JsonRpcTransport( + httpx_client=mock_httpx_client, + agent_card=mock_agent_card, + client_extensions=[], + ) http_kwargs = {'headers': {'X_Other': 'Test'}} result_kwargs = client._update_extension_header(http_kwargs) assert HTTP_EXTENSION_HEADER not in result_kwargs['headers'] @@ -876,8 +898,8 @@ async def test_send_message_with_extensions( extensions = ['test_extension_1', 'test_extension_2'] client = JsonRpcTransport( httpx_client=mock_httpx_client, - client_extensions=extensions, agent_card=mock_agent_card, + client_extensions=extensions, ) params = MessageSendParams( message=create_text_message_object(content='Hello') @@ -898,10 +920,18 @@ async def test_send_message_with_extensions( mock_httpx_client.post.assert_called_once() _, mock_kwargs = mock_httpx_client.post.call_args + headers = mock_kwargs.get('headers', {}) assert HTTP_EXTENSION_HEADER in headers - actual_extensions = set(headers[HTTP_EXTENSION_HEADER].split(', ')) - expected_extensions = {'test_extension_1', 'test_extension_2'} + header_value = headers[HTTP_EXTENSION_HEADER] + actual_extensions_list = [e.strip() for e in header_value.split(',')] + actual_extensions = set(actual_extensions_list) + + expected_extensions = { + 'test_extension_1', + 'test_extension_2', + } + assert len(actual_extensions_list) == 2 assert actual_extensions == expected_extensions @pytest.mark.asyncio @@ -916,8 +946,8 @@ async def test_send_message_streaming_with_extensions( extensions = ['test_extension'] client = JsonRpcTransport( httpx_client=mock_httpx_client, - client_extensions=extensions, agent_card=mock_agent_card, + client_extensions=extensions, ) params = MessageSendParams( message=create_text_message_object(content='Hello stream') diff --git a/tests/client/test_rest_client.py b/tests/client/test_rest_client.py index 20815dde7..02d7cbf0c 100644 --- a/tests/client/test_rest_client.py +++ b/tests/client/test_rest_client.py @@ -44,14 +44,31 @@ def test_update_extension_header_no_initial_headers( ) http_kwargs = {} result_kwargs = client._update_extension_header(http_kwargs) - actual_extensions = set( - result_kwargs['headers'][HTTP_EXTENSION_HEADER].split(', ') - ) - expected_extensions = {'test_extension_1', 'test_extension_2'} + header_value = result_kwargs['headers'][HTTP_EXTENSION_HEADER] + actual_extensions_list = [e.strip() for e in header_value.split(',')] + actual_extensions = set(actual_extensions_list) + + expected_extensions = { + 'test_extension_1', + 'test_extension_2', + } + assert len(actual_extensions_list) == 2 assert actual_extensions == expected_extensions + @pytest.mark.parametrize( + 'existing_header, expected_count', + [ + ('test_extension_1, test_extension_2', 3), + ('test_extension_1,test_extension_2', 3), + ('test_extension_1', 3), + ], + ) def test_update_extension_header_merge_with_existing_extensions( - self, mock_httpx_client: AsyncMock, mock_agent_card: MagicMock + self, + mock_httpx_client: AsyncMock, + mock_agent_card: MagicMock, + existing_header: str, + expected_count: int, ): extensions = ['test_extension_2', 'test_extension_3'] client = RestTransport( @@ -59,20 +76,19 @@ def test_update_extension_header_merge_with_existing_extensions( agent_card=mock_agent_card, client_extensions=extensions, ) - http_kwargs = { - 'headers': { - HTTP_EXTENSION_HEADER: 'test_extension_1, test_extension_2' - } - } + http_kwargs = {'headers': {HTTP_EXTENSION_HEADER: existing_header}} result_kwargs = client._update_extension_header(http_kwargs) - actual_extensions = set( - result_kwargs['headers'][HTTP_EXTENSION_HEADER].split(', ') - ) + + header_value = result_kwargs['headers'][HTTP_EXTENSION_HEADER] + actual_extensions_list = [e.strip() for e in header_value.split(',')] + actual_extensions = set(actual_extensions_list) + expected_extensions = { 'test_extension_1', 'test_extension_2', 'test_extension_3', } + assert len(actual_extensions_list) == expected_count assert actual_extensions == expected_extensions def test_update_extension_header_with_other_headers( @@ -124,8 +140,15 @@ async def test_send_message_with_extensions( headers = kwargs.get('headers', {}) assert HTTP_EXTENSION_HEADER in headers - actual_extensions = set(headers[HTTP_EXTENSION_HEADER].split(', ')) - expected_extensions = {'test_extension_1', 'test_extension_2'} + header_value = kwargs['headers'][HTTP_EXTENSION_HEADER] + actual_extensions_list = [e.strip() for e in header_value.split(',')] + actual_extensions = set(actual_extensions_list) + + expected_extensions = { + 'test_extension_1', + 'test_extension_2', + } + assert len(actual_extensions_list) == 2 assert actual_extensions == expected_extensions @pytest.mark.asyncio @@ -140,8 +163,8 @@ async def test_send_message_streaming_with_extensions( extensions = ['test_extension'] client = RestTransport( httpx_client=mock_httpx_client, - client_extensions=extensions, agent_card=mock_agent_card, + client_extensions=extensions, ) params = MessageSendParams( message=create_text_message_object(content='Hello stream') From caba0a237fcc648ed04506bb7f36a8ad7565fdee Mon Sep 17 00:00:00 2001 From: sokoliva Date: Thu, 30 Oct 2025 15:01:13 +0000 Subject: [PATCH 09/26] refactor: rename client_extensions to extensions in JsonRpcTransport and RestTransport --- src/a2a/client/transports/jsonrpc.py | 10 +++++----- src/a2a/client/transports/rest.py | 10 +++++----- tests/client/test_jsonrpc_client.py | 20 ++++++++++---------- tests/client/test_rest_client.py | 12 ++++++------ 4 files changed, 26 insertions(+), 26 deletions(-) diff --git a/src/a2a/client/transports/jsonrpc.py b/src/a2a/client/transports/jsonrpc.py index 523be0fbc..08fea2c02 100644 --- a/src/a2a/client/transports/jsonrpc.py +++ b/src/a2a/client/transports/jsonrpc.py @@ -63,7 +63,7 @@ def __init__( agent_card: AgentCard | None = None, url: str | None = None, interceptors: list[ClientCallInterceptor] | None = None, - client_extensions: list[str] | None = None, + extensions: list[str] | None = None, ): """Initializes the JsonRpcTransport.""" if url: @@ -81,12 +81,12 @@ def __init__( if agent_card else True ) - self.client_extensions = client_extensions + self.extensions = extensions def _update_extension_header( self, http_kwargs: dict[str, Any] ) -> dict[str, Any]: - if not self.client_extensions: + if not self.extensions: return http_kwargs headers = http_kwargs.setdefault('headers', {}) @@ -96,10 +96,10 @@ def _update_extension_header( e.strip() for e in existing_extensions_str.split(',') if e.strip() ] - all_extensions = set(self.client_extensions) + all_extensions = set(self.extensions) all_extensions.update(existing_extensions) - headers[HTTP_EXTENSION_HEADER] = ', '.join(list(all_extensions)) + headers[HTTP_EXTENSION_HEADER] = ','.join(list(all_extensions)) return http_kwargs async def _apply_interceptors( diff --git a/src/a2a/client/transports/rest.py b/src/a2a/client/transports/rest.py index 696e0a6d9..94e50ce86 100644 --- a/src/a2a/client/transports/rest.py +++ b/src/a2a/client/transports/rest.py @@ -44,7 +44,7 @@ def __init__( agent_card: AgentCard | None = None, url: str | None = None, interceptors: list[ClientCallInterceptor] | None = None, - client_extensions: list[str] | None = None, + extensions: list[str] | None = None, ): """Initializes the RestTransport.""" if url: @@ -63,12 +63,12 @@ def __init__( if agent_card else True ) - self.client_extensions = client_extensions + self.extensions = extensions def _update_extension_header( self, http_kwargs: dict[str, Any] ) -> dict[str, Any]: - if not self.client_extensions: + if not self.extensions: return http_kwargs headers = http_kwargs.setdefault('headers', {}) @@ -78,10 +78,10 @@ def _update_extension_header( e.strip() for e in existing_extensions_str.split(',') if e.strip() ] - all_extensions = set(self.client_extensions) + all_extensions = set(self.extensions) all_extensions.update(existing_extensions) - headers[HTTP_EXTENSION_HEADER] = ', '.join(list(all_extensions)) + headers[HTTP_EXTENSION_HEADER] = ','.join(list(all_extensions)) return http_kwargs async def _apply_interceptors( diff --git a/tests/client/test_jsonrpc_client.py b/tests/client/test_jsonrpc_client.py index 8081fb589..6ea288293 100644 --- a/tests/client/test_jsonrpc_client.py +++ b/tests/client/test_jsonrpc_client.py @@ -796,7 +796,7 @@ def test_update_extension_header_no_initial_headers( client = JsonRpcTransport( httpx_client=mock_httpx_client, agent_card=mock_agent_card, - client_extensions=extensions, + extensions=extensions, ) http_kwargs = {} result_kwargs = client._update_extension_header(http_kwargs) @@ -818,7 +818,7 @@ def test_update_extension_header_with_existing_other_headers( client = JsonRpcTransport( httpx_client=mock_httpx_client, agent_card=mock_agent_card, - client_extensions=extensions, + extensions=extensions, ) http_kwargs = {'headers': {'X_Other': 'Test'}} result_kwargs = client._update_extension_header(http_kwargs) @@ -847,7 +847,7 @@ def test_update_extension_header_merge_with_existing_extensions( client = JsonRpcTransport( httpx_client=mock_httpx_client, agent_card=mock_agent_card, - client_extensions=extensions, + extensions=extensions, ) http_kwargs = {'headers': {HTTP_EXTENSION_HEADER: existing_header}} result_kwargs = client._update_extension_header(http_kwargs) @@ -864,26 +864,26 @@ def test_update_extension_header_merge_with_existing_extensions( assert len(actual_extensions_list) == 3 assert actual_extensions == expected_extensions - def test_update_extension_header_no_client_extensions( + def test_update_extension_header_no_extensions( self, mock_httpx_client: AsyncMock, mock_agent_card: MagicMock ): client = JsonRpcTransport( httpx_client=mock_httpx_client, agent_card=mock_agent_card, - client_extensions=None, + extensions=None, ) http_kwargs = {'headers': {'X_Other': 'Test'}} result_kwargs = client._update_extension_header(http_kwargs) assert HTTP_EXTENSION_HEADER not in result_kwargs['headers'] assert result_kwargs['headers']['X_Other'] == 'Test' - def test_update_extension_header_empty_client_extensions( + def test_update_extension_header_empty_extensions( self, mock_httpx_client: AsyncMock, mock_agent_card: MagicMock ): client = JsonRpcTransport( httpx_client=mock_httpx_client, agent_card=mock_agent_card, - client_extensions=[], + extensions=[], ) http_kwargs = {'headers': {'X_Other': 'Test'}} result_kwargs = client._update_extension_header(http_kwargs) @@ -894,12 +894,12 @@ def test_update_extension_header_empty_client_extensions( async def test_send_message_with_extensions( self, mock_httpx_client: AsyncMock, mock_agent_card: MagicMock ): - """Test that send_message adds extension headers when client_extensions are provided.""" + """Test that send_message adds extension headers when extensions are provided.""" extensions = ['test_extension_1', 'test_extension_2'] client = JsonRpcTransport( httpx_client=mock_httpx_client, agent_card=mock_agent_card, - client_extensions=extensions, + extensions=extensions, ) params = MessageSendParams( message=create_text_message_object(content='Hello') @@ -947,7 +947,7 @@ async def test_send_message_streaming_with_extensions( client = JsonRpcTransport( httpx_client=mock_httpx_client, agent_card=mock_agent_card, - client_extensions=extensions, + extensions=extensions, ) params = MessageSendParams( message=create_text_message_object(content='Hello stream') diff --git a/tests/client/test_rest_client.py b/tests/client/test_rest_client.py index 02d7cbf0c..ea534e87f 100644 --- a/tests/client/test_rest_client.py +++ b/tests/client/test_rest_client.py @@ -40,7 +40,7 @@ def test_update_extension_header_no_initial_headers( client = RestTransport( httpx_client=mock_httpx_client, agent_card=mock_agent_card, - client_extensions=extensions, + extensions=extensions, ) http_kwargs = {} result_kwargs = client._update_extension_header(http_kwargs) @@ -74,7 +74,7 @@ def test_update_extension_header_merge_with_existing_extensions( client = RestTransport( httpx_client=mock_httpx_client, agent_card=mock_agent_card, - client_extensions=extensions, + extensions=extensions, ) http_kwargs = {'headers': {HTTP_EXTENSION_HEADER: existing_header}} result_kwargs = client._update_extension_header(http_kwargs) @@ -98,7 +98,7 @@ def test_update_extension_header_with_other_headers( client = RestTransport( httpx_client=mock_httpx_client, agent_card=mock_agent_card, - client_extensions=extensions, + extensions=extensions, ) http_kwargs = {'headers': {'X_Other': 'Test'}} result_kwargs = client._update_extension_header(http_kwargs) @@ -111,11 +111,11 @@ def test_update_extension_header_with_other_headers( async def test_send_message_with_extensions( self, mock_httpx_client: AsyncMock, mock_agent_card: MagicMock ): - """Test that send_message adds client_extensions to headers.""" + """Test that send_message adds extensions to headers.""" extensions = ['test_extension_1', 'test_extension_2'] client = RestTransport( httpx_client=mock_httpx_client, - client_extensions=extensions, + extensions=extensions, agent_card=mock_agent_card, ) params = MessageSendParams( @@ -164,7 +164,7 @@ async def test_send_message_streaming_with_extensions( client = RestTransport( httpx_client=mock_httpx_client, agent_card=mock_agent_card, - client_extensions=extensions, + extensions=extensions, ) params = MessageSendParams( message=create_text_message_object(content='Hello stream') From 28b1d5353f7875a03d051c5670c688efc81511ee Mon Sep 17 00:00:00 2001 From: sokoliva Date: Mon, 3 Nov 2025 10:19:57 +0000 Subject: [PATCH 10/26] feat: move common functions for managing HTTP extension headers to utility.py Add extensions feature to grpc. --- src/a2a/client/client.py | 3 + src/a2a/client/transports/grpc.py | 49 +++- src/a2a/client/transports/jsonrpc.py | 52 ++--- src/a2a/client/transports/rest.py | 46 +--- src/a2a/client/transports/utils.py | 24 ++ tests/client/test_grpc_client.py | 320 +++++++++++++++++++++++++- tests/client/test_jsonrpc_client.py | 101 -------- tests/client/test_rest_client.py | 74 ------ tests/client/transports/test_utils.py | 82 +++++++ 9 files changed, 489 insertions(+), 262 deletions(-) create mode 100644 src/a2a/client/transports/utils.py create mode 100644 tests/client/transports/test_utils.py diff --git a/src/a2a/client/client.py b/src/a2a/client/client.py index c65e8204b..eca59be39 100644 --- a/src/a2a/client/client.py +++ b/src/a2a/client/client.py @@ -93,6 +93,7 @@ def __init__( self, consumers: list[Consumer] | None = None, middleware: list[ClientCallInterceptor] | None = None, + # iva todo add optional extensions- it can override value from the config, if it is provided ): """Initializes the client with consumers and middleware. @@ -113,6 +114,8 @@ async def send_message( request: Message, *, context: ClientCallContext | None = None, + # iva todo add optional extensions- it can override value from the config, if it is provided + # and to the other ones as well ) -> AsyncIterator[ClientEvent | Message]: """Sends a message to the server. diff --git a/src/a2a/client/transports/grpc.py b/src/a2a/client/transports/grpc.py index e50b0ea81..1a5532264 100644 --- a/src/a2a/client/transports/grpc.py +++ b/src/a2a/client/transports/grpc.py @@ -1,6 +1,7 @@ import logging from collections.abc import AsyncGenerator +from typing import Any try: @@ -16,6 +17,7 @@ from a2a.client.middleware import ClientCallContext, ClientCallInterceptor from a2a.client.optionals import Channel from a2a.client.transports.base import ClientTransport +from a2a.client.transports.utils import update_extension_header from a2a.grpc import a2a_pb2, a2a_pb2_grpc from a2a.types import ( AgentCard, @@ -44,6 +46,7 @@ def __init__( self, channel: Channel, agent_card: AgentCard | None, + extensions: list[str] | None = None, ): """Initializes the GrpcTransport.""" self.agent_card = agent_card @@ -54,6 +57,25 @@ def __init__( if agent_card else True ) + self.extensions = extensions + + def _get_metadata( + self, context: ClientCallContext | None + ) -> list[tuple[str, str]]: + http_kwargs: dict[str, Any] = {} + if context and context.state.get('grpc_metadata'): + # Convert existing metadata to headers format for update_extension_header + http_kwargs['headers'] = { + k: v for k, v in context.state['grpc_metadata'] + } + + updated_kwargs = update_extension_header(http_kwargs, self.extensions) + + metadata = [] + if 'headers' in updated_kwargs: + metadata.extend(updated_kwargs['headers'].items()) + + return metadata @classmethod def create( @@ -66,10 +88,7 @@ def create( """Creates a gRPC transport for the A2A client.""" if config.grpc_channel_factory is None: raise ValueError('grpc_channel_factory is required when using gRPC') - return cls( - config.grpc_channel_factory(url), - card, - ) + return cls(config.grpc_channel_factory(url), card, config.extensions) async def send_message( self, @@ -85,7 +104,8 @@ async def send_message( request.configuration ), metadata=proto_utils.ToProto.metadata(request.metadata), - ) + ), + metadata=self._get_metadata(context), ) if response.HasField('task'): return proto_utils.FromProto.task(response.task) @@ -107,7 +127,8 @@ async def send_message_streaming( request.configuration ), metadata=proto_utils.ToProto.metadata(request.metadata), - ) + ), + metadata=self._get_metadata(context), ) while True: response = await stream.read() @@ -122,7 +143,8 @@ async def resubscribe( ]: """Reconnects to get task updates.""" stream = self.stub.TaskSubscription( - a2a_pb2.TaskSubscriptionRequest(name=f'tasks/{request.id}') + a2a_pb2.TaskSubscriptionRequest(name=f'tasks/{request.id}'), + metadata=self._get_metadata(context), ) while True: response = await stream.read() @@ -141,7 +163,8 @@ async def get_task( a2a_pb2.GetTaskRequest( name=f'tasks/{request.id}', history_length=request.history_length, - ) + ), + metadata=self._get_metadata(context), ) return proto_utils.FromProto.task(task) @@ -153,7 +176,8 @@ async def cancel_task( ) -> Task: """Requests the agent to cancel a specific task.""" task = await self.stub.CancelTask( - a2a_pb2.CancelTaskRequest(name=f'tasks/{request.id}') + a2a_pb2.CancelTaskRequest(name=f'tasks/{request.id}'), + metadata=self._get_metadata(context), ) return proto_utils.FromProto.task(task) @@ -171,7 +195,8 @@ async def set_task_callback( config=proto_utils.ToProto.task_push_notification_config( request ), - ) + ), + metadata=self._get_metadata(context), ) return proto_utils.FromProto.task_push_notification_config(config) @@ -185,7 +210,8 @@ async def get_task_callback( config = await self.stub.GetTaskPushNotificationConfig( a2a_pb2.GetTaskPushNotificationConfigRequest( name=f'tasks/{request.id}/pushNotificationConfigs/{request.push_notification_config_id}', - ) + ), + metadata=self._get_metadata(context), ) return proto_utils.FromProto.task_push_notification_config(config) @@ -203,6 +229,7 @@ async def get_card( card_pb = await self.stub.GetAgentCard( a2a_pb2.GetAgentCardRequest(), + metadata=self._get_metadata(context), # probaby not needed ) card = proto_utils.FromProto.agent_card(card_pb) self.agent_card = card diff --git a/src/a2a/client/transports/jsonrpc.py b/src/a2a/client/transports/jsonrpc.py index 08fea2c02..b407e959b 100644 --- a/src/a2a/client/transports/jsonrpc.py +++ b/src/a2a/client/transports/jsonrpc.py @@ -18,7 +18,7 @@ ) from a2a.client.middleware import ClientCallContext, ClientCallInterceptor from a2a.client.transports.base import ClientTransport -from a2a.extensions.common import HTTP_EXTENSION_HEADER +from a2a.client.transports.utils import get_http_args, update_extension_header from a2a.types import ( AgentCard, CancelTaskRequest, @@ -83,25 +83,6 @@ def __init__( ) self.extensions = extensions - def _update_extension_header( - self, http_kwargs: dict[str, Any] - ) -> dict[str, Any]: - if not self.extensions: - return http_kwargs - - headers = http_kwargs.setdefault('headers', {}) - existing_extensions_str = headers.get(HTTP_EXTENSION_HEADER, '') - - existing_extensions = [ - e.strip() for e in existing_extensions_str.split(',') if e.strip() - ] - - all_extensions = set(self.extensions) - all_extensions.update(existing_extensions) - - headers[HTTP_EXTENSION_HEADER] = ','.join(list(all_extensions)) - return http_kwargs - async def _apply_interceptors( self, method_name: str, @@ -125,11 +106,6 @@ async def _apply_interceptors( ) return final_request_payload, final_http_kwargs - def _get_http_args( - self, context: ClientCallContext | None - ) -> dict[str, Any] | None: - return context.state.get('http_kwargs') if context else None - async def send_message( self, request: MessageSendParams, @@ -141,10 +117,12 @@ async def send_message( payload, modified_kwargs = await self._apply_interceptors( 'message/send', rpc_request.model_dump(mode='json', exclude_none=True), - self._get_http_args(context), + get_http_args(context), context, ) - modified_kwargs = self._update_extension_header(modified_kwargs) + modified_kwargs = update_extension_header( + modified_kwargs, self.extensions + ) response_data = await self._send_request(payload, modified_kwargs) response = SendMessageResponse.model_validate(response_data) if isinstance(response.root, JSONRPCErrorResponse): @@ -166,11 +144,13 @@ async def send_message_streaming( payload, modified_kwargs = await self._apply_interceptors( 'message/stream', rpc_request.model_dump(mode='json', exclude_none=True), - self._get_http_args(context), + get_http_args(context), context, ) - modified_kwargs = self._update_extension_header(modified_kwargs) + modified_kwargs = update_extension_header( + modified_kwargs, self.extensions + ) modified_kwargs.setdefault( 'timeout', self.httpx_client.timeout.as_dict().get('read', None) ) @@ -237,7 +217,7 @@ async def get_task( payload, modified_kwargs = await self._apply_interceptors( 'tasks/get', rpc_request.model_dump(mode='json', exclude_none=True), - self._get_http_args(context), + get_http_args(context), context, ) response_data = await self._send_request(payload, modified_kwargs) @@ -257,7 +237,7 @@ async def cancel_task( payload, modified_kwargs = await self._apply_interceptors( 'tasks/cancel', rpc_request.model_dump(mode='json', exclude_none=True), - self._get_http_args(context), + get_http_args(context), context, ) response_data = await self._send_request(payload, modified_kwargs) @@ -279,7 +259,7 @@ async def set_task_callback( payload, modified_kwargs = await self._apply_interceptors( 'tasks/pushNotificationConfig/set', rpc_request.model_dump(mode='json', exclude_none=True), - self._get_http_args(context), + get_http_args(context), context, ) response_data = await self._send_request(payload, modified_kwargs) @@ -303,7 +283,7 @@ async def get_task_callback( payload, modified_kwargs = await self._apply_interceptors( 'tasks/pushNotificationConfig/get', rpc_request.model_dump(mode='json', exclude_none=True), - self._get_http_args(context), + get_http_args(context), context, ) response_data = await self._send_request(payload, modified_kwargs) @@ -327,7 +307,7 @@ async def resubscribe( payload, modified_kwargs = await self._apply_interceptors( 'tasks/resubscribe', rpc_request.model_dump(mode='json', exclude_none=True), - self._get_http_args(context), + get_http_args(context), context, ) @@ -369,7 +349,7 @@ async def get_card( if not card: resolver = A2ACardResolver(self.httpx_client, self.url) card = await resolver.get_agent_card( - http_kwargs=self._get_http_args(context) + http_kwargs=get_http_args(context) ) self._needs_extended_card = ( card.supports_authenticated_extended_card @@ -383,7 +363,7 @@ async def get_card( payload, modified_kwargs = await self._apply_interceptors( request.method, request.model_dump(mode='json', exclude_none=True), - self._get_http_args(context), + get_http_args(context), context, ) diff --git a/src/a2a/client/transports/rest.py b/src/a2a/client/transports/rest.py index 94e50ce86..6b838e455 100644 --- a/src/a2a/client/transports/rest.py +++ b/src/a2a/client/transports/rest.py @@ -13,7 +13,7 @@ from a2a.client.errors import A2AClientHTTPError, A2AClientJSONError from a2a.client.middleware import ClientCallContext, ClientCallInterceptor from a2a.client.transports.base import ClientTransport -from a2a.extensions.common import HTTP_EXTENSION_HEADER +from a2a.client.transports.utils import get_http_args, update_extension_header from a2a.grpc import a2a_pb2 from a2a.types import ( AgentCard, @@ -65,25 +65,6 @@ def __init__( ) self.extensions = extensions - def _update_extension_header( - self, http_kwargs: dict[str, Any] - ) -> dict[str, Any]: - if not self.extensions: - return http_kwargs - - headers = http_kwargs.setdefault('headers', {}) - existing_extensions_str = headers.get(HTTP_EXTENSION_HEADER, '') - - existing_extensions = [ - e.strip() for e in existing_extensions_str.split(',') if e.strip() - ] - - all_extensions = set(self.extensions) - all_extensions.update(existing_extensions) - - headers[HTTP_EXTENSION_HEADER] = ','.join(list(all_extensions)) - return http_kwargs - async def _apply_interceptors( self, request_payload: dict[str, Any], @@ -95,11 +76,6 @@ async def _apply_interceptors( # TODO: Implement interceptors for other transports return final_request_payload, final_http_kwargs - def _get_http_args( - self, context: ClientCallContext | None - ) -> dict[str, Any] | None: - return context.state.get('http_kwargs') if context else None - async def _prepare_send_message( self, request: MessageSendParams, context: ClientCallContext | None ) -> tuple[dict[str, Any], dict[str, Any]]: @@ -117,10 +93,12 @@ async def _prepare_send_message( payload = MessageToDict(pb) payload, modified_kwargs = await self._apply_interceptors( payload, - self._get_http_args(context), + get_http_args(context), context, ) - modified_kwargs = self._update_extension_header(modified_kwargs) + modified_kwargs = update_extension_header( + modified_kwargs, self.extensions + ) return payload, modified_kwargs async def send_message( @@ -231,7 +209,7 @@ async def get_task( """Retrieves the current state and history of a specific task.""" _payload, modified_kwargs = await self._apply_interceptors( request.model_dump(mode='json', exclude_none=True), - self._get_http_args(context), + get_http_args(context), context, ) response_data = await self._send_get_request( @@ -256,7 +234,7 @@ async def cancel_task( payload = MessageToDict(pb) payload, modified_kwargs = await self._apply_interceptors( payload, - self._get_http_args(context), + get_http_args(context), context, ) response_data = await self._send_post_request( @@ -280,7 +258,7 @@ async def set_task_callback( ) payload = MessageToDict(pb) payload, modified_kwargs = await self._apply_interceptors( - payload, self._get_http_args(context), context + payload, get_http_args(context), context ) response_data = await self._send_post_request( f'/v1/tasks/{request.task_id}/pushNotificationConfigs', @@ -304,7 +282,7 @@ async def get_task_callback( payload = MessageToDict(pb) payload, modified_kwargs = await self._apply_interceptors( payload, - self._get_http_args(context), + get_http_args(context), context, ) response_data = await self._send_get_request( @@ -325,7 +303,7 @@ async def resubscribe( Task | TaskStatusUpdateEvent | TaskArtifactUpdateEvent | Message ]: """Reconnects to get task updates.""" - http_kwargs = self._get_http_args(context) or {} + http_kwargs = get_http_args(context) or {} http_kwargs.setdefault('timeout', None) async with aconnect_sse( @@ -360,7 +338,7 @@ async def get_card( if not card: resolver = A2ACardResolver(self.httpx_client, self.url) card = await resolver.get_agent_card( - http_kwargs=self._get_http_args(context) + http_kwargs=get_http_args(context) ) self._needs_extended_card = ( card.supports_authenticated_extended_card @@ -372,7 +350,7 @@ async def get_card( _, modified_kwargs = await self._apply_interceptors( {}, - self._get_http_args(context), + get_http_args(context), context, ) response_data = await self._send_get_request( diff --git a/src/a2a/client/transports/utils.py b/src/a2a/client/transports/utils.py new file mode 100644 index 000000000..27f45d7af --- /dev/null +++ b/src/a2a/client/transports/utils.py @@ -0,0 +1,24 @@ +from typing import Any + +from a2a.client.middleware import ClientCallContext +from a2a.extensions.common import HTTP_EXTENSION_HEADER + + +def get_http_args(context: ClientCallContext | None) -> dict[str, Any] | None: + return context.state.get('http_kwargs') if context else None + + +def update_extension_header( + http_kwargs: dict[str, Any], extensions: list[str] | None +) -> dict[str, Any]: + if not extensions: + return http_kwargs + headers = http_kwargs.setdefault('headers', {}) + existing_extensions_str = headers.get(HTTP_EXTENSION_HEADER, '') + existing_extensions = [ + e.strip() for e in existing_extensions_str.split(',') if e.strip() + ] + all_extensions = set(extensions) + all_extensions.update(existing_extensions) + headers[HTTP_EXTENSION_HEADER] = ','.join(list(all_extensions)) + return http_kwargs diff --git a/tests/client/test_grpc_client.py b/tests/client/test_grpc_client.py index 19f5abc16..c3e2bbb6f 100644 --- a/tests/client/test_grpc_client.py +++ b/tests/client/test_grpc_client.py @@ -1,9 +1,11 @@ -from unittest.mock import AsyncMock, MagicMock +from unittest.mock import AsyncMock, MagicMock, patch import grpc import pytest +from a2a.client.middleware import ClientCallContext from a2a.client.transports.grpc import GrpcTransport +from a2a.extensions.common import HTTP_EXTENSION_HEADER from a2a.grpc import a2a_pb2, a2a_pb2_grpc from a2a.types import ( AgentCapabilities, @@ -40,6 +42,8 @@ def mock_grpc_stub() -> AsyncMock: stub.CancelTask = AsyncMock() stub.CreateTaskPushNotificationConfig = AsyncMock() stub.GetTaskPushNotificationConfig = AsyncMock() + stub.TaskSubscription = MagicMock() + stub.GetAgentCard = AsyncMock() return stub @@ -278,7 +282,8 @@ async def test_get_task( mock_grpc_stub.GetTask.assert_awaited_once_with( a2a_pb2.GetTaskRequest( name=f'tasks/{sample_task.id}', history_length=None - ) + ), + metadata=[], ) assert response.id == sample_task.id @@ -297,7 +302,8 @@ async def test_get_task_with_history( mock_grpc_stub.GetTask.assert_awaited_once_with( a2a_pb2.GetTaskRequest( name=f'tasks/{sample_task.id}', history_length=history_len - ) + ), + metadata=[], ) @@ -316,7 +322,8 @@ async def test_cancel_task( response = await grpc_transport.cancel_task(params) mock_grpc_stub.CancelTask.assert_awaited_once_with( - a2a_pb2.CancelTaskRequest(name=f'tasks/{sample_task.id}') + a2a_pb2.CancelTaskRequest(name=f'tasks/{sample_task.id}'), + metadata=[], ) assert response.status.state == TaskState.canceled @@ -345,7 +352,8 @@ async def test_set_task_callback_with_valid_task( config=proto_utils.ToProto.task_push_notification_config( sample_task_push_notification_config ), - ) + ), + metadata=[], ) assert response.task_id == sample_task_push_notification_config.task_id @@ -402,7 +410,8 @@ async def test_get_task_callback_with_valid_task( f'tasks/{params.id}/' f'pushNotificationConfigs/{params.push_notification_config_id}' ), - ) + ), + metadata=[], ) assert response.task_id == sample_task_push_notification_config.task_id @@ -434,3 +443,302 @@ async def test_get_task_callback_with_invalid_task( 'Bad TaskPushNotificationConfig resource name' in exc_info.value.error.message ) + + +class TestGrpcTransportExtensions: + def test_get_metadata_no_initial(self, sample_agent_card: AgentCard): + extensions = ['test_extension_1', 'test_extension_2'] + transport = GrpcTransport( + channel=AsyncMock(), + agent_card=sample_agent_card, + extensions=extensions, + ) + metadata = transport._get_metadata(None) + metadata_dict = dict(metadata) + assert HTTP_EXTENSION_HEADER in metadata_dict + actual_extensions = set(metadata_dict[HTTP_EXTENSION_HEADER].split(',')) + assert actual_extensions == set(extensions) + + def test_get_metadata_with_existing(self, sample_agent_card: AgentCard): + extensions = ['test_extension'] + transport = GrpcTransport( + channel=AsyncMock(), + agent_card=sample_agent_card, + extensions=extensions, + ) + context = ClientCallContext( + state={'grpc_metadata': [('x-other', 'Test')]} + ) + metadata = transport._get_metadata(context) + metadata_dict = dict(metadata) + assert metadata_dict[HTTP_EXTENSION_HEADER] == 'test_extension' + assert metadata_dict['x-other'] == 'Test' + + @pytest.mark.parametrize( + 'existing_header, expected_extensions', + [ + ( + 'test_extension_2, test_extension_3', + {'test_extension_1', 'test_extension_2', 'test_extension_3'}, + ), + ( + 'test_extension_3', + {'test_extension_1', 'test_extension_2', 'test_extension_3'}, + ), + ], + ) + def test_get_metadata_merge_with_existing( + self, + sample_agent_card: AgentCard, + existing_header: str, + expected_extensions: set, + ): + extensions = ['test_extension_1', 'test_extension_2'] + transport = GrpcTransport( + channel=AsyncMock(), + agent_card=sample_agent_card, + extensions=extensions, + ) + context = ClientCallContext( + state={'grpc_metadata': [(HTTP_EXTENSION_HEADER, existing_header)]} + ) + metadata = transport._get_metadata(context) + metadata_dict = dict(metadata) + assert HTTP_EXTENSION_HEADER in metadata_dict + actual_extensions = set(metadata_dict[HTTP_EXTENSION_HEADER].split(',')) + assert actual_extensions == expected_extensions + + def test_get_metadata_no_extensions(self, sample_agent_card: AgentCard): + transport = GrpcTransport( + channel=AsyncMock(), + agent_card=sample_agent_card, + extensions=None, + ) + context = ClientCallContext( + state={'grpc_metadata': [('x-other', 'Test')]} + ) + metadata = transport._get_metadata(context) + metadata_dict = dict(metadata) + assert HTTP_EXTENSION_HEADER not in metadata_dict + assert metadata_dict['x-other'] == 'Test' + + def test_get_metadata_empty_extensions(self, sample_agent_card: AgentCard): + transport = GrpcTransport( + channel=AsyncMock(), + agent_card=sample_agent_card, + extensions=[], + ) + context = ClientCallContext( + state={'grpc_metadata': [('x-other', 'Test')]} + ) + metadata = transport._get_metadata(context) + metadata_dict = dict(metadata) + assert HTTP_EXTENSION_HEADER not in metadata_dict + assert metadata_dict['x-other'] == 'Test' + + @pytest.mark.asyncio + async def test_send_message_with_extensions( + self, + mock_grpc_stub: AsyncMock, + sample_agent_card: AgentCard, + sample_message_send_params: MessageSendParams, + ): + extensions = ['test_extension_1', 'test_extension_2'] + transport = GrpcTransport( + channel=AsyncMock(), + agent_card=sample_agent_card, + extensions=extensions, + ) + transport.stub = mock_grpc_stub + mock_grpc_stub.SendMessage.return_value = a2a_pb2.SendMessageResponse( + msg=proto_utils.ToProto.message(sample_message_send_params.message) + ) + + await transport.send_message(sample_message_send_params) + + mock_grpc_stub.SendMessage.assert_awaited_once() + _, kwargs = mock_grpc_stub.SendMessage.call_args + metadata_dict = dict(kwargs['metadata']) + assert HTTP_EXTENSION_HEADER in metadata_dict + assert set(metadata_dict[HTTP_EXTENSION_HEADER].split(',')) == set( + extensions + ) + + @pytest.mark.asyncio + async def test_send_message_streaming_with_extensions( + self, + mock_grpc_stub: AsyncMock, + sample_agent_card: AgentCard, + sample_message_send_params: MessageSendParams, + ): + extensions = ['test_extension'] + transport = GrpcTransport( + channel=AsyncMock(), + agent_card=sample_agent_card, + extensions=extensions, + ) + transport.stub = mock_grpc_stub + stream = MagicMock() + stream.read = AsyncMock(side_effect=[grpc.aio.EOF]) + mock_grpc_stub.SendStreamingMessage.return_value = stream + + async for _ in transport.send_message_streaming( + sample_message_send_params + ): + pass + + mock_grpc_stub.SendStreamingMessage.assert_called_once() + _, kwargs = mock_grpc_stub.SendStreamingMessage.call_args + metadata_dict = dict(kwargs['metadata']) + assert HTTP_EXTENSION_HEADER in metadata_dict + assert metadata_dict[HTTP_EXTENSION_HEADER] == 'test_extension' + + @pytest.mark.asyncio + async def test_resubscribe_with_extensions( + self, mock_grpc_stub: AsyncMock, sample_agent_card: AgentCard + ): + extensions = ['test_extension'] + transport = GrpcTransport( + channel=AsyncMock(), + agent_card=sample_agent_card, + extensions=extensions, + ) + transport.stub = mock_grpc_stub + stream = MagicMock() + stream.read = AsyncMock(side_effect=[grpc.aio.EOF]) + mock_grpc_stub.TaskSubscription.return_value = stream + + async for _ in transport.resubscribe(TaskIdParams(id='task-1')): + pass + + mock_grpc_stub.TaskSubscription.assert_called_once() + _, kwargs = mock_grpc_stub.TaskSubscription.call_args + metadata_dict = dict(kwargs['metadata']) + assert HTTP_EXTENSION_HEADER in metadata_dict + assert metadata_dict[HTTP_EXTENSION_HEADER] == 'test_extension' + + @pytest.mark.asyncio + async def test_get_task_with_extensions( + self, mock_grpc_stub: AsyncMock, sample_agent_card: AgentCard + ): + extensions = ['test_extension'] + transport = GrpcTransport( + channel=AsyncMock(), + agent_card=sample_agent_card, + extensions=extensions, + ) + transport.stub = mock_grpc_stub + mock_grpc_stub.GetTask.return_value = a2a_pb2.Task() + + await transport.get_task(TaskQueryParams(id='task-1')) + + mock_grpc_stub.GetTask.assert_awaited_once() + _, kwargs = mock_grpc_stub.GetTask.call_args + metadata_dict = dict(kwargs['metadata']) + assert HTTP_EXTENSION_HEADER in metadata_dict + assert metadata_dict[HTTP_EXTENSION_HEADER] == 'test_extension' + + @pytest.mark.asyncio + async def test_cancel_task_with_extensions( + self, mock_grpc_stub: AsyncMock, sample_agent_card: AgentCard + ): + extensions = ['test_extension'] + transport = GrpcTransport( + channel=AsyncMock(), + agent_card=sample_agent_card, + extensions=extensions, + ) + transport.stub = mock_grpc_stub + mock_grpc_stub.CancelTask.return_value = a2a_pb2.Task() + + await transport.cancel_task(TaskIdParams(id='task-1')) + + mock_grpc_stub.CancelTask.assert_awaited_once() + _, kwargs = mock_grpc_stub.CancelTask.call_args + metadata_dict = dict(kwargs['metadata']) + assert HTTP_EXTENSION_HEADER in metadata_dict + assert metadata_dict[HTTP_EXTENSION_HEADER] == 'test_extension' + + @pytest.mark.asyncio + async def test_set_task_callback_with_extensions( + self, + mock_grpc_stub: AsyncMock, + sample_agent_card: AgentCard, + sample_task_push_notification_config: TaskPushNotificationConfig, + ): + extensions = ['test_extension'] + transport = GrpcTransport( + channel=AsyncMock(), + agent_card=sample_agent_card, + extensions=extensions, + ) + transport.stub = mock_grpc_stub + mock_grpc_stub.CreateTaskPushNotificationConfig.return_value = ( + proto_utils.ToProto.task_push_notification_config( + sample_task_push_notification_config + ) + ) + + await transport.set_task_callback(sample_task_push_notification_config) + + mock_grpc_stub.CreateTaskPushNotificationConfig.assert_awaited_once() + _, kwargs = mock_grpc_stub.CreateTaskPushNotificationConfig.call_args + metadata_dict = dict(kwargs['metadata']) + assert HTTP_EXTENSION_HEADER in metadata_dict + assert metadata_dict[HTTP_EXTENSION_HEADER] == 'test_extension' + + @pytest.mark.asyncio + async def test_get_task_callback_with_extensions( + self, + mock_grpc_stub: AsyncMock, + sample_agent_card: AgentCard, + sample_task_push_notification_config: TaskPushNotificationConfig, + ): + extensions = ['test_extension'] + transport = GrpcTransport( + channel=AsyncMock(), + agent_card=sample_agent_card, + extensions=extensions, + ) + transport.stub = mock_grpc_stub + mock_grpc_stub.GetTaskPushNotificationConfig.return_value = ( + proto_utils.ToProto.task_push_notification_config( + sample_task_push_notification_config + ) + ) + + await transport.get_task_callback( + GetTaskPushNotificationConfigParams( + id=sample_task_push_notification_config.task_id, + push_notification_config_id=sample_task_push_notification_config.push_notification_config.id, + ) + ) + + mock_grpc_stub.GetTaskPushNotificationConfig.assert_awaited_once() + _, kwargs = mock_grpc_stub.GetTaskPushNotificationConfig.call_args + metadata_dict = dict(kwargs['metadata']) + assert HTTP_EXTENSION_HEADER in metadata_dict + assert metadata_dict[HTTP_EXTENSION_HEADER] == 'test_extension' + + @pytest.mark.asyncio + async def test_get_card_with_extensions( + self, mock_grpc_stub: AsyncMock, sample_agent_card: AgentCard + ): + extensions = ['test_extension'] + transport = GrpcTransport( + channel=AsyncMock(), + agent_card=sample_agent_card, + extensions=extensions, + ) + transport.stub = mock_grpc_stub + mock_grpc_stub.GetAgentCard.return_value = ( + proto_utils.ToProto.agent_card(sample_agent_card) + ) + + await transport.get_card() + + mock_grpc_stub.GetAgentCard.assert_awaited_once() + _, kwargs = mock_grpc_stub.GetAgentCard.call_args + metadata_dict = dict(kwargs['metadata']) + assert HTTP_EXTENSION_HEADER in metadata_dict + assert metadata_dict[HTTP_EXTENSION_HEADER] == 'test_extension' diff --git a/tests/client/test_jsonrpc_client.py b/tests/client/test_jsonrpc_client.py index 6ea288293..3ebd50173 100644 --- a/tests/client/test_jsonrpc_client.py +++ b/tests/client/test_jsonrpc_client.py @@ -789,107 +789,6 @@ async def test_close(self, mock_httpx_client: AsyncMock): class TestJsonRpcTransportExtensions: - def test_update_extension_header_no_initial_headers( - self, mock_httpx_client: AsyncMock, mock_agent_card: MagicMock - ): - extensions = ['test_extension_1', 'test_extension_2'] - client = JsonRpcTransport( - httpx_client=mock_httpx_client, - agent_card=mock_agent_card, - extensions=extensions, - ) - http_kwargs = {} - result_kwargs = client._update_extension_header(http_kwargs) - header_value = result_kwargs['headers'][HTTP_EXTENSION_HEADER] - actual_extensions_list = [e.strip() for e in header_value.split(',')] - actual_extensions = set(actual_extensions_list) - - expected_extensions = { - 'test_extension_1', - 'test_extension_2', - } - assert len(actual_extensions_list) == 2 - assert actual_extensions == expected_extensions - - def test_update_extension_header_with_existing_other_headers( - self, mock_httpx_client: AsyncMock, mock_agent_card: MagicMock - ): - extensions = ['test_extension_1'] - client = JsonRpcTransport( - httpx_client=mock_httpx_client, - agent_card=mock_agent_card, - extensions=extensions, - ) - http_kwargs = {'headers': {'X_Other': 'Test'}} - result_kwargs = client._update_extension_header(http_kwargs) - assert ( - result_kwargs['headers'][HTTP_EXTENSION_HEADER] - == 'test_extension_1' - ) - assert result_kwargs['headers']['X_Other'] == 'Test' - - @pytest.mark.parametrize( - 'existing_header, expected_count', - [ - ('test_extension_2, test_extension_3', 3), - ('test_extension_2,test_extension_3', 3), - ('test_extension_3', 3), - ], - ) - def test_update_extension_header_merge_with_existing_extensions( - self, - mock_httpx_client: AsyncMock, - mock_agent_card: MagicMock, - existing_header: str, - expected_count: int, - ): - extensions = ['test_extension_1', 'test_extension_2'] - client = JsonRpcTransport( - httpx_client=mock_httpx_client, - agent_card=mock_agent_card, - extensions=extensions, - ) - http_kwargs = {'headers': {HTTP_EXTENSION_HEADER: existing_header}} - result_kwargs = client._update_extension_header(http_kwargs) - - header_value = result_kwargs['headers'][HTTP_EXTENSION_HEADER] - actual_extensions_list = [e.strip() for e in header_value.split(',')] - actual_extensions = set(actual_extensions_list) - - expected_extensions = { - 'test_extension_1', - 'test_extension_2', - 'test_extension_3', - } - assert len(actual_extensions_list) == 3 - assert actual_extensions == expected_extensions - - def test_update_extension_header_no_extensions( - self, mock_httpx_client: AsyncMock, mock_agent_card: MagicMock - ): - client = JsonRpcTransport( - httpx_client=mock_httpx_client, - agent_card=mock_agent_card, - extensions=None, - ) - http_kwargs = {'headers': {'X_Other': 'Test'}} - result_kwargs = client._update_extension_header(http_kwargs) - assert HTTP_EXTENSION_HEADER not in result_kwargs['headers'] - assert result_kwargs['headers']['X_Other'] == 'Test' - - def test_update_extension_header_empty_extensions( - self, mock_httpx_client: AsyncMock, mock_agent_card: MagicMock - ): - client = JsonRpcTransport( - httpx_client=mock_httpx_client, - agent_card=mock_agent_card, - extensions=[], - ) - http_kwargs = {'headers': {'X_Other': 'Test'}} - result_kwargs = client._update_extension_header(http_kwargs) - assert HTTP_EXTENSION_HEADER not in result_kwargs['headers'] - assert result_kwargs['headers']['X_Other'] == 'Test' - @pytest.mark.asyncio async def test_send_message_with_extensions( self, mock_httpx_client: AsyncMock, mock_agent_card: MagicMock diff --git a/tests/client/test_rest_client.py b/tests/client/test_rest_client.py index ea534e87f..d236fad24 100644 --- a/tests/client/test_rest_client.py +++ b/tests/client/test_rest_client.py @@ -33,80 +33,6 @@ async def async_iterable_from_list( class TestRestTransportExtensions: - def test_update_extension_header_no_initial_headers( - self, mock_httpx_client: AsyncMock, mock_agent_card: MagicMock - ): - extensions = ['test_extension_1', 'test_extension_2'] - client = RestTransport( - httpx_client=mock_httpx_client, - agent_card=mock_agent_card, - extensions=extensions, - ) - http_kwargs = {} - result_kwargs = client._update_extension_header(http_kwargs) - header_value = result_kwargs['headers'][HTTP_EXTENSION_HEADER] - actual_extensions_list = [e.strip() for e in header_value.split(',')] - actual_extensions = set(actual_extensions_list) - - expected_extensions = { - 'test_extension_1', - 'test_extension_2', - } - assert len(actual_extensions_list) == 2 - assert actual_extensions == expected_extensions - - @pytest.mark.parametrize( - 'existing_header, expected_count', - [ - ('test_extension_1, test_extension_2', 3), - ('test_extension_1,test_extension_2', 3), - ('test_extension_1', 3), - ], - ) - def test_update_extension_header_merge_with_existing_extensions( - self, - mock_httpx_client: AsyncMock, - mock_agent_card: MagicMock, - existing_header: str, - expected_count: int, - ): - extensions = ['test_extension_2', 'test_extension_3'] - client = RestTransport( - httpx_client=mock_httpx_client, - agent_card=mock_agent_card, - extensions=extensions, - ) - http_kwargs = {'headers': {HTTP_EXTENSION_HEADER: existing_header}} - result_kwargs = client._update_extension_header(http_kwargs) - - header_value = result_kwargs['headers'][HTTP_EXTENSION_HEADER] - actual_extensions_list = [e.strip() for e in header_value.split(',')] - actual_extensions = set(actual_extensions_list) - - expected_extensions = { - 'test_extension_1', - 'test_extension_2', - 'test_extension_3', - } - assert len(actual_extensions_list) == expected_count - assert actual_extensions == expected_extensions - - def test_update_extension_header_with_other_headers( - self, mock_httpx_client: AsyncMock, mock_agent_card: MagicMock - ): - extensions = ['test_extension_1'] - client = RestTransport( - httpx_client=mock_httpx_client, - agent_card=mock_agent_card, - extensions=extensions, - ) - http_kwargs = {'headers': {'X_Other': 'Test'}} - result_kwargs = client._update_extension_header(http_kwargs) - headers = result_kwargs.get('headers', {}) - assert HTTP_EXTENSION_HEADER in headers - assert headers[HTTP_EXTENSION_HEADER] == 'test_extension_1' - assert headers['X_Other'] == 'Test' - @pytest.mark.asyncio async def test_send_message_with_extensions( self, mock_httpx_client: AsyncMock, mock_agent_card: MagicMock diff --git a/tests/client/transports/test_utils.py b/tests/client/transports/test_utils.py new file mode 100644 index 000000000..b7f2ff62d --- /dev/null +++ b/tests/client/transports/test_utils.py @@ -0,0 +1,82 @@ +import pytest + +from a2a.extensions.common import HTTP_EXTENSION_HEADER +from a2a.client.transports.utils import update_extension_header + + +class TestUtils: + def test_update_extension_header_no_initial_headers(self): + extensions = ['test_extension_1', 'test_extension_2'] + + http_kwargs = {} + result_kwargs = update_extension_header(http_kwargs, extensions) + header_value = result_kwargs['headers'][HTTP_EXTENSION_HEADER] + actual_extensions_list = [e.strip() for e in header_value.split(',')] + actual_extensions = set(actual_extensions_list) + + expected_extensions = { + 'test_extension_1', + 'test_extension_2', + } + assert len(actual_extensions_list) == 2 + assert actual_extensions == expected_extensions + + @pytest.mark.parametrize( + 'existing_header, expected_count', + [ + ('test_extension_1, test_extension_2', 3), + ('test_extension_1,test_extension_2', 3), + ('test_extension_1', 3), + ], + ) + def test_update_extension_header_merge_with_existing_extensions( + self, + existing_header: str, + expected_count: int, + ): + extensions = ['test_extension_2', 'test_extension_3'] + http_kwargs = {'headers': {HTTP_EXTENSION_HEADER: existing_header}} + result_kwargs = update_extension_header(http_kwargs, extensions) + + header_value = result_kwargs['headers'][HTTP_EXTENSION_HEADER] + actual_extensions_list = [e.strip() for e in header_value.split(',')] + actual_extensions = set(actual_extensions_list) + + expected_extensions = { + 'test_extension_1', + 'test_extension_2', + 'test_extension_3', + } + assert len(actual_extensions_list) == expected_count + assert actual_extensions == expected_extensions + + def test_update_extension_header_with_other_headers(self): + extensions = ['test_extension_1'] + http_kwargs = {'headers': {'X_Other': 'Test'}} + result_kwargs = update_extension_header(http_kwargs, extensions) + headers = result_kwargs.get('headers', {}) + assert HTTP_EXTENSION_HEADER in headers + assert headers[HTTP_EXTENSION_HEADER] == 'test_extension_1' + assert headers['X_Other'] == 'Test' + + def test_update_extension_header_with_existing_other_headers(self): + extensions = ['test_extension_1'] + http_kwargs = {'headers': {'X_Other': 'Test'}} + result_kwargs = update_extension_header(http_kwargs, extensions) + assert ( + result_kwargs['headers'][HTTP_EXTENSION_HEADER] + == 'test_extension_1' + ) + assert result_kwargs['headers']['X_Other'] == 'Test' + + def test_update_extension_header_no_extensions(self): + http_kwargs = {'headers': {'X_Other': 'Test'}} + result_kwargs = update_extension_header(http_kwargs, None) + assert HTTP_EXTENSION_HEADER not in result_kwargs['headers'] + assert result_kwargs['headers']['X_Other'] == 'Test' + + def test_update_extension_header_empty_extensions(self): + http_kwargs = {'headers': {'X_Other': 'Test'}} + result_kwargs = update_extension_header(http_kwargs, []) + assert HTTP_EXTENSION_HEADER not in result_kwargs['headers'] + assert result_kwargs['headers']['X_Other'] == 'Test' From 270d6e7dc00cb566538110dc25d558304b7ea2e9 Mon Sep 17 00:00:00 2001 From: sokoliva Date: Mon, 3 Nov 2025 10:58:20 +0000 Subject: [PATCH 11/26] Remove extensions from grpc methog get_card --- src/a2a/client/transports/grpc.py | 1 - tests/client/test_grpc_client.py | 23 ----------------------- 2 files changed, 24 deletions(-) diff --git a/src/a2a/client/transports/grpc.py b/src/a2a/client/transports/grpc.py index 1a5532264..2ba7ac4ae 100644 --- a/src/a2a/client/transports/grpc.py +++ b/src/a2a/client/transports/grpc.py @@ -229,7 +229,6 @@ async def get_card( card_pb = await self.stub.GetAgentCard( a2a_pb2.GetAgentCardRequest(), - metadata=self._get_metadata(context), # probaby not needed ) card = proto_utils.FromProto.agent_card(card_pb) self.agent_card = card diff --git a/tests/client/test_grpc_client.py b/tests/client/test_grpc_client.py index c3e2bbb6f..d111b219a 100644 --- a/tests/client/test_grpc_client.py +++ b/tests/client/test_grpc_client.py @@ -719,26 +719,3 @@ async def test_get_task_callback_with_extensions( metadata_dict = dict(kwargs['metadata']) assert HTTP_EXTENSION_HEADER in metadata_dict assert metadata_dict[HTTP_EXTENSION_HEADER] == 'test_extension' - - @pytest.mark.asyncio - async def test_get_card_with_extensions( - self, mock_grpc_stub: AsyncMock, sample_agent_card: AgentCard - ): - extensions = ['test_extension'] - transport = GrpcTransport( - channel=AsyncMock(), - agent_card=sample_agent_card, - extensions=extensions, - ) - transport.stub = mock_grpc_stub - mock_grpc_stub.GetAgentCard.return_value = ( - proto_utils.ToProto.agent_card(sample_agent_card) - ) - - await transport.get_card() - - mock_grpc_stub.GetAgentCard.assert_awaited_once() - _, kwargs = mock_grpc_stub.GetAgentCard.call_args - metadata_dict = dict(kwargs['metadata']) - assert HTTP_EXTENSION_HEADER in metadata_dict - assert metadata_dict[HTTP_EXTENSION_HEADER] == 'test_extension' From a9aa9ee846e6119f286af784a498a8d2bbe61d22 Mon Sep 17 00:00:00 2001 From: sokoliva Date: Mon, 3 Nov 2025 15:42:15 +0000 Subject: [PATCH 12/26] feat: add support for extensions in Client and BaseClient, update transport methods to handle extensions --- src/a2a/client/base_client.py | 5 +- src/a2a/client/client.py | 6 ++- src/a2a/client/client_factory.py | 16 +++++- src/a2a/client/transports/jsonrpc.py | 19 ++++++- src/a2a/client/transports/rest.py | 18 ++++++- tests/client/test_base_client.py | 1 + tests/client/transports/test_utils.py | 72 +++++++++++++-------------- 7 files changed, 93 insertions(+), 44 deletions(-) diff --git a/src/a2a/client/base_client.py b/src/a2a/client/base_client.py index f4a8d03de..8f79dea97 100644 --- a/src/a2a/client/base_client.py +++ b/src/a2a/client/base_client.py @@ -36,11 +36,14 @@ def __init__( transport: ClientTransport, consumers: list[Consumer], middleware: list[ClientCallInterceptor], + extensions: list[str], ): - super().__init__(consumers, middleware) + super().__init__(consumers, middleware, extensions) self._card = card self._config = config self._transport = transport + if self._extensions: + self._config.extensions = self._extensions async def send_message( self, diff --git a/src/a2a/client/client.py b/src/a2a/client/client.py index eca59be39..d9ba6580b 100644 --- a/src/a2a/client/client.py +++ b/src/a2a/client/client.py @@ -93,7 +93,8 @@ def __init__( self, consumers: list[Consumer] | None = None, middleware: list[ClientCallInterceptor] | None = None, - # iva todo add optional extensions- it can override value from the config, if it is provided + # iva todo - it can override value from the config, if it is provided + extensions: list[str] | None = None, ): """Initializes the client with consumers and middleware. @@ -105,8 +106,11 @@ def __init__( middleware = [] if consumers is None: consumers = [] + if extensions is None: + extensions = [] self._consumers = consumers self._middleware = middleware + self._extensions = extensions @abstractmethod async def send_message( diff --git a/src/a2a/client/client_factory.py b/src/a2a/client/client_factory.py index 1507f23a4..b6f534d6f 100644 --- a/src/a2a/client/client_factory.py +++ b/src/a2a/client/client_factory.py @@ -115,6 +115,7 @@ async def connect( # noqa: PLR0913 relative_card_path: str | None = None, resolver_http_kwargs: dict[str, Any] | None = None, extra_transports: dict[str, TransportProducer] | None = None, + extensions: list[str] | None = None, ) -> Client: """Convenience method for constructing a client. @@ -168,7 +169,7 @@ async def connect( # noqa: PLR0913 factory = cls(client_config) for label, generator in (extra_transports or {}).items(): factory.register(label, generator) - return factory.create(card, consumers, interceptors) + return factory.create(card, consumers, interceptors, extensions) def register(self, label: str, generator: TransportProducer) -> None: """Register a new transport producer for a given transport label.""" @@ -179,6 +180,7 @@ def create( card: AgentCard, consumers: list[Consumer] | None = None, interceptors: list[ClientCallInterceptor] | None = None, + extensions: list[str] | None = None, ) -> Client: """Create a new `Client` for the provided `AgentCard`. @@ -228,12 +230,22 @@ def create( if consumers: all_consumers.extend(consumers) + all_extensions = self._config.extensions.copy() + if extensions: + all_extensions.extend(extensions) + self._config.extensions = all_extensions + transport = self._registry[transport_protocol]( card, transport_url, self._config, interceptors or [] ) return BaseClient( - card, self._config, transport, all_consumers, interceptors or [] + card, + self._config, + transport, + all_consumers, + interceptors or [], + all_extensions, ) diff --git a/src/a2a/client/transports/jsonrpc.py b/src/a2a/client/transports/jsonrpc.py index b407e959b..f8396b7c4 100644 --- a/src/a2a/client/transports/jsonrpc.py +++ b/src/a2a/client/transports/jsonrpc.py @@ -220,6 +220,9 @@ async def get_task( get_http_args(context), context, ) + modified_kwargs = update_extension_header( + modified_kwargs, self.extensions + ) response_data = await self._send_request(payload, modified_kwargs) response = GetTaskResponse.model_validate(response_data) if isinstance(response.root, JSONRPCErrorResponse): @@ -240,6 +243,9 @@ async def cancel_task( get_http_args(context), context, ) + modified_kwargs = update_extension_header( + modified_kwargs, self.extensions + ) response_data = await self._send_request(payload, modified_kwargs) response = CancelTaskResponse.model_validate(response_data) if isinstance(response.root, JSONRPCErrorResponse): @@ -262,6 +268,9 @@ async def set_task_callback( get_http_args(context), context, ) + modified_kwargs = update_extension_header( + modified_kwargs, self.extensions + ) response_data = await self._send_request(payload, modified_kwargs) response = SetTaskPushNotificationConfigResponse.model_validate( response_data @@ -286,6 +295,9 @@ async def get_task_callback( get_http_args(context), context, ) + modified_kwargs = update_extension_header( + modified_kwargs, self.extensions + ) response_data = await self._send_request(payload, modified_kwargs) response = GetTaskPushNotificationConfigResponse.model_validate( response_data @@ -310,7 +322,9 @@ async def resubscribe( get_http_args(context), context, ) - + modified_kwargs = update_extension_header( + modified_kwargs, self.extensions + ) modified_kwargs.setdefault('timeout', None) async with aconnect_sse( @@ -366,6 +380,9 @@ async def get_card( get_http_args(context), context, ) + modified_kwargs = update_extension_header( + modified_kwargs, self.extensions + ) response_data = await self._send_request( payload, diff --git a/src/a2a/client/transports/rest.py b/src/a2a/client/transports/rest.py index 6b838e455..2e9373646 100644 --- a/src/a2a/client/transports/rest.py +++ b/src/a2a/client/transports/rest.py @@ -212,6 +212,9 @@ async def get_task( get_http_args(context), context, ) + modified_kwargs = update_extension_header( + modified_kwargs, self.extensions + ) response_data = await self._send_get_request( f'/v1/tasks/{request.id}', {'historyLength': str(request.history_length)} @@ -237,6 +240,9 @@ async def cancel_task( get_http_args(context), context, ) + modified_kwargs = update_extension_header( + modified_kwargs, self.extensions + ) response_data = await self._send_post_request( f'/v1/tasks/{request.id}:cancel', payload, modified_kwargs ) @@ -260,6 +266,9 @@ async def set_task_callback( payload, modified_kwargs = await self._apply_interceptors( payload, get_http_args(context), context ) + modified_kwargs = update_extension_header( + modified_kwargs, self.extensions + ) response_data = await self._send_post_request( f'/v1/tasks/{request.task_id}/pushNotificationConfigs', payload, @@ -285,6 +294,9 @@ async def get_task_callback( get_http_args(context), context, ) + modified_kwargs = update_extension_header( + modified_kwargs, self.extensions + ) response_data = await self._send_get_request( f'/v1/tasks/{request.id}/pushNotificationConfigs/{request.push_notification_config_id}', {}, @@ -305,12 +317,13 @@ async def resubscribe( """Reconnects to get task updates.""" http_kwargs = get_http_args(context) or {} http_kwargs.setdefault('timeout', None) + modified_kwargs = update_extension_header(http_kwargs, self.extensions) async with aconnect_sse( self.httpx_client, 'GET', f'{self.url}/v1/tasks/{request.id}:subscribe', - **http_kwargs, + **modified_kwargs, ) as event_source: try: async for sse in event_source.aiter_sse(): @@ -353,6 +366,9 @@ async def get_card( get_http_args(context), context, ) + modified_kwargs = update_extension_header( + modified_kwargs, self.extensions + ) response_data = await self._send_get_request( '/v1/card', {}, modified_kwargs ) diff --git a/tests/client/test_base_client.py b/tests/client/test_base_client.py index c1251f1c4..a0688a0b3 100644 --- a/tests/client/test_base_client.py +++ b/tests/client/test_base_client.py @@ -55,6 +55,7 @@ def base_client(sample_agent_card, mock_transport): transport=mock_transport, consumers=[], middleware=[], + extensions=[], ) diff --git a/tests/client/transports/test_utils.py b/tests/client/transports/test_utils.py index b7f2ff62d..f47c0465f 100644 --- a/tests/client/transports/test_utils.py +++ b/tests/client/transports/test_utils.py @@ -5,36 +5,47 @@ class TestUtils: - def test_update_extension_header_no_initial_headers(self): - extensions = ['test_extension_1', 'test_extension_2'] - - http_kwargs = {} - result_kwargs = update_extension_header(http_kwargs, extensions) - header_value = result_kwargs['headers'][HTTP_EXTENSION_HEADER] - actual_extensions_list = [e.strip() for e in header_value.split(',')] - actual_extensions = set(actual_extensions_list) - - expected_extensions = { - 'test_extension_1', - 'test_extension_2', - } - assert len(actual_extensions_list) == 2 - assert actual_extensions == expected_extensions - @pytest.mark.parametrize( - 'existing_header, expected_count', + 'extensions, existing_header, expected_extensions, expected_count', [ - ('test_extension_1, test_extension_2', 3), - ('test_extension_1,test_extension_2', 3), - ('test_extension_1', 3), + ( + ['test_extension_1', 'test_extension_2'], + '', + { + 'test_extension_1', + 'test_extension_2', + }, + 2, + ), + ( + ['test_extension_1', 'test_extension_2'], + 'test_extension_2, test_extension_3', + { + 'test_extension_1', + 'test_extension_2', + 'test_extension_3', + }, + 3, + ), + ( + ['test_extension_1', 'test_extension_2'], + 'test_extension_3', + { + 'test_extension_1', + 'test_extension_2', + 'test_extension_3', + }, + 3, + ), ], ) def test_update_extension_header_merge_with_existing_extensions( self, + extensions: list[str], existing_header: str, + expected_extensions: set[str], expected_count: int, ): - extensions = ['test_extension_2', 'test_extension_3'] http_kwargs = {'headers': {HTTP_EXTENSION_HEADER: existing_header}} result_kwargs = update_extension_header(http_kwargs, extensions) @@ -42,33 +53,18 @@ def test_update_extension_header_merge_with_existing_extensions( actual_extensions_list = [e.strip() for e in header_value.split(',')] actual_extensions = set(actual_extensions_list) - expected_extensions = { - 'test_extension_1', - 'test_extension_2', - 'test_extension_3', - } assert len(actual_extensions_list) == expected_count assert actual_extensions == expected_extensions def test_update_extension_header_with_other_headers(self): - extensions = ['test_extension_1'] + extensions = ['test_extension'] http_kwargs = {'headers': {'X_Other': 'Test'}} result_kwargs = update_extension_header(http_kwargs, extensions) headers = result_kwargs.get('headers', {}) assert HTTP_EXTENSION_HEADER in headers - assert headers[HTTP_EXTENSION_HEADER] == 'test_extension_1' + assert headers[HTTP_EXTENSION_HEADER] == 'test_extension' assert headers['X_Other'] == 'Test' - def test_update_extension_header_with_existing_other_headers(self): - extensions = ['test_extension_1'] - http_kwargs = {'headers': {'X_Other': 'Test'}} - result_kwargs = update_extension_header(http_kwargs, extensions) - assert ( - result_kwargs['headers'][HTTP_EXTENSION_HEADER] - == 'test_extension_1' - ) - assert result_kwargs['headers']['X_Other'] == 'Test' - def test_update_extension_header_no_extensions(self): http_kwargs = {'headers': {'X_Other': 'Test'}} result_kwargs = update_extension_header(http_kwargs, None) From 948d3f3d390309d72dc648c3b0239b72e3f8ab0c Mon Sep 17 00:00:00 2001 From: sokoliva Date: Mon, 3 Nov 2025 16:10:41 +0000 Subject: [PATCH 13/26] fix: correct order of extension header updates in update_extension_header function and remove set to list transformation. --- src/a2a/client/transports/utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/a2a/client/transports/utils.py b/src/a2a/client/transports/utils.py index 27f45d7af..9fd3bf7e9 100644 --- a/src/a2a/client/transports/utils.py +++ b/src/a2a/client/transports/utils.py @@ -18,7 +18,7 @@ def update_extension_header( existing_extensions = [ e.strip() for e in existing_extensions_str.split(',') if e.strip() ] - all_extensions = set(extensions) - all_extensions.update(existing_extensions) - headers[HTTP_EXTENSION_HEADER] = ','.join(list(all_extensions)) + all_extensions = set(existing_extensions) + all_extensions.update(extensions) + headers[HTTP_EXTENSION_HEADER] = ','.join(all_extensions) return http_kwargs From 4073c0bcb6b52a52eb995b3287ac6d5fe0315c54 Mon Sep 17 00:00:00 2001 From: sokoliva Date: Tue, 4 Nov 2025 12:56:43 +0000 Subject: [PATCH 14/26] refactor: streamline extension handling in BaseClient and GrpcTransport. Add helper __merge_extensions method in utils.py --- src/a2a/client/base_client.py | 3 +- src/a2a/client/transports/grpc.py | 66 ++++--- src/a2a/client/transports/utils.py | 22 ++- tests/client/test_grpc_client.py | 307 ----------------------------- 4 files changed, 60 insertions(+), 338 deletions(-) diff --git a/src/a2a/client/base_client.py b/src/a2a/client/base_client.py index 8f79dea97..8a62f9392 100644 --- a/src/a2a/client/base_client.py +++ b/src/a2a/client/base_client.py @@ -42,8 +42,7 @@ def __init__( self._card = card self._config = config self._transport = transport - if self._extensions: - self._config.extensions = self._extensions + self._config.extensions = self._extensions async def send_message( self, diff --git a/src/a2a/client/transports/grpc.py b/src/a2a/client/transports/grpc.py index 2ba7ac4ae..f84fe3117 100644 --- a/src/a2a/client/transports/grpc.py +++ b/src/a2a/client/transports/grpc.py @@ -13,11 +13,13 @@ "'pip install a2a-sdk[grpc]'" ) from e +from google.protobuf import struct_pb2 + from a2a.client.client import ClientConfig from a2a.client.middleware import ClientCallContext, ClientCallInterceptor from a2a.client.optionals import Channel from a2a.client.transports.base import ClientTransport -from a2a.client.transports.utils import update_extension_header +from a2a.extensions.common import HTTP_EXTENSION_HEADER from a2a.grpc import a2a_pb2, a2a_pb2_grpc from a2a.types import ( AgentCard, @@ -59,24 +61,6 @@ def __init__( ) self.extensions = extensions - def _get_metadata( - self, context: ClientCallContext | None - ) -> list[tuple[str, str]]: - http_kwargs: dict[str, Any] = {} - if context and context.state.get('grpc_metadata'): - # Convert existing metadata to headers format for update_extension_header - http_kwargs['headers'] = { - k: v for k, v in context.state['grpc_metadata'] - } - - updated_kwargs = update_extension_header(http_kwargs, self.extensions) - - metadata = [] - if 'headers' in updated_kwargs: - metadata.extend(updated_kwargs['headers'].items()) - - return metadata - @classmethod def create( cls, @@ -105,7 +89,7 @@ async def send_message( ), metadata=proto_utils.ToProto.metadata(request.metadata), ), - metadata=self._get_metadata(context), + metadata=self._update_extension_metadata(), ) if response.HasField('task'): return proto_utils.FromProto.task(response.task) @@ -128,7 +112,7 @@ async def send_message_streaming( ), metadata=proto_utils.ToProto.metadata(request.metadata), ), - metadata=self._get_metadata(context), + metadata=self._update_extension_metadata(request.metadata), ) while True: response = await stream.read() @@ -136,6 +120,7 @@ async def send_message_streaming( break yield proto_utils.FromProto.stream_response(response) + # iva todo TaskIdParams has metadata async def resubscribe( self, request: TaskIdParams, *, context: ClientCallContext | None = None ) -> AsyncGenerator[ @@ -144,7 +129,7 @@ async def resubscribe( """Reconnects to get task updates.""" stream = self.stub.TaskSubscription( a2a_pb2.TaskSubscriptionRequest(name=f'tasks/{request.id}'), - metadata=self._get_metadata(context), + metadata=self._update_extension_metadata(), ) while True: response = await stream.read() @@ -152,6 +137,7 @@ async def resubscribe( break yield proto_utils.FromProto.stream_response(response) + # iva todo TaskQueryParams has metadata async def get_task( self, request: TaskQueryParams, @@ -164,7 +150,7 @@ async def get_task( name=f'tasks/{request.id}', history_length=request.history_length, ), - metadata=self._get_metadata(context), + metadata=self._update_extension_metadata(), ) return proto_utils.FromProto.task(task) @@ -177,7 +163,7 @@ async def cancel_task( """Requests the agent to cancel a specific task.""" task = await self.stub.CancelTask( a2a_pb2.CancelTaskRequest(name=f'tasks/{request.id}'), - metadata=self._get_metadata(context), + metadata=self._update_extension_metadata(), ) return proto_utils.FromProto.task(task) @@ -196,10 +182,11 @@ async def set_task_callback( request ), ), - metadata=self._get_metadata(context), + metadata=self._update_extension_metadata(), ) return proto_utils.FromProto.task_push_notification_config(config) + # iva todo GetTaskPushNotificationConfigParams has metadata async def get_task_callback( self, request: GetTaskPushNotificationConfigParams, @@ -211,7 +198,7 @@ async def get_task_callback( a2a_pb2.GetTaskPushNotificationConfigRequest( name=f'tasks/{request.id}/pushNotificationConfigs/{request.push_notification_config_id}', ), - metadata=self._get_metadata(context), + metadata=self._update_extension_metadata(), ) return proto_utils.FromProto.task_push_notification_config(config) @@ -235,6 +222,33 @@ async def get_card( self._needs_extended_card = False return card + def _update_extension_metadata( + self, metadata: dict[str, Any] | None = None + ) -> struct_pb2.Struct | None: + """Gets the metadata for the gRPC call.""" + if metadata is None: + metadata = {} + + if self.extensions: + existing_extensions_str = str( + metadata.get(HTTP_EXTENSION_HEADER, '') + ) + existing_extensions = { + e.strip() + for e in existing_extensions_str.split(',') + if e.strip() + } + + all_extensions = set(existing_extensions) + all_extensions.update(self.extensions) + + if all_extensions: + metadata[HTTP_EXTENSION_HEADER] = ','.join(all_extensions) + elif HTTP_EXTENSION_HEADER in metadata: + del metadata[HTTP_EXTENSION_HEADER] + + return proto_utils.ToProto.metadata(metadata if metadata else None) + async def close(self) -> None: """Closes the gRPC channel.""" await self.channel.close() diff --git a/src/a2a/client/transports/utils.py b/src/a2a/client/transports/utils.py index 9fd3bf7e9..fd409001d 100644 --- a/src/a2a/client/transports/utils.py +++ b/src/a2a/client/transports/utils.py @@ -8,6 +8,20 @@ def get_http_args(context: ClientCallContext | None) -> dict[str, Any] | None: return context.state.get('http_kwargs') if context else None +def __merge_extensions( + existing_extensions: str, new_extensions: list[str] +) -> str: + existing_extensions_list = [ + e.strip() for e in existing_extensions.split(',') if e.strip() + ] + existing_extensions_set = set(existing_extensions_list) + new_extensions = [ + ext for ext in new_extensions if ext not in existing_extensions_set + ] + + return ','.join(existing_extensions_list + new_extensions) + + def update_extension_header( http_kwargs: dict[str, Any], extensions: list[str] | None ) -> dict[str, Any]: @@ -15,10 +29,12 @@ def update_extension_header( return http_kwargs headers = http_kwargs.setdefault('headers', {}) existing_extensions_str = headers.get(HTTP_EXTENSION_HEADER, '') - existing_extensions = [ + """existing_extensions = [ e.strip() for e in existing_extensions_str.split(',') if e.strip() ] all_extensions = set(existing_extensions) - all_extensions.update(extensions) - headers[HTTP_EXTENSION_HEADER] = ','.join(all_extensions) + all_extensions.update(extensions)""" + headers[HTTP_EXTENSION_HEADER] = __merge_extensions( + existing_extensions_str, extensions + ) return http_kwargs diff --git a/tests/client/test_grpc_client.py b/tests/client/test_grpc_client.py index d111b219a..3ed672285 100644 --- a/tests/client/test_grpc_client.py +++ b/tests/client/test_grpc_client.py @@ -385,37 +385,6 @@ async def test_set_task_callback_with_invalid_task( ) -@pytest.mark.asyncio -async def test_get_task_callback_with_valid_task( - grpc_transport: GrpcTransport, - mock_grpc_stub: AsyncMock, - sample_task_push_notification_config: TaskPushNotificationConfig, -): - """Test retrieving a task push notification config with a valid task id.""" - mock_grpc_stub.GetTaskPushNotificationConfig.return_value = ( - proto_utils.ToProto.task_push_notification_config( - sample_task_push_notification_config - ) - ) - params = GetTaskPushNotificationConfigParams( - id=sample_task_push_notification_config.task_id, - push_notification_config_id=sample_task_push_notification_config.push_notification_config.id, - ) - - response = await grpc_transport.get_task_callback(params) - - mock_grpc_stub.GetTaskPushNotificationConfig.assert_awaited_once_with( - a2a_pb2.GetTaskPushNotificationConfigRequest( - name=( - f'tasks/{params.id}/' - f'pushNotificationConfigs/{params.push_notification_config_id}' - ), - ), - metadata=[], - ) - assert response.task_id == sample_task_push_notification_config.task_id - - @pytest.mark.asyncio async def test_get_task_callback_with_invalid_task( grpc_transport: GrpcTransport, @@ -443,279 +412,3 @@ async def test_get_task_callback_with_invalid_task( 'Bad TaskPushNotificationConfig resource name' in exc_info.value.error.message ) - - -class TestGrpcTransportExtensions: - def test_get_metadata_no_initial(self, sample_agent_card: AgentCard): - extensions = ['test_extension_1', 'test_extension_2'] - transport = GrpcTransport( - channel=AsyncMock(), - agent_card=sample_agent_card, - extensions=extensions, - ) - metadata = transport._get_metadata(None) - metadata_dict = dict(metadata) - assert HTTP_EXTENSION_HEADER in metadata_dict - actual_extensions = set(metadata_dict[HTTP_EXTENSION_HEADER].split(',')) - assert actual_extensions == set(extensions) - - def test_get_metadata_with_existing(self, sample_agent_card: AgentCard): - extensions = ['test_extension'] - transport = GrpcTransport( - channel=AsyncMock(), - agent_card=sample_agent_card, - extensions=extensions, - ) - context = ClientCallContext( - state={'grpc_metadata': [('x-other', 'Test')]} - ) - metadata = transport._get_metadata(context) - metadata_dict = dict(metadata) - assert metadata_dict[HTTP_EXTENSION_HEADER] == 'test_extension' - assert metadata_dict['x-other'] == 'Test' - - @pytest.mark.parametrize( - 'existing_header, expected_extensions', - [ - ( - 'test_extension_2, test_extension_3', - {'test_extension_1', 'test_extension_2', 'test_extension_3'}, - ), - ( - 'test_extension_3', - {'test_extension_1', 'test_extension_2', 'test_extension_3'}, - ), - ], - ) - def test_get_metadata_merge_with_existing( - self, - sample_agent_card: AgentCard, - existing_header: str, - expected_extensions: set, - ): - extensions = ['test_extension_1', 'test_extension_2'] - transport = GrpcTransport( - channel=AsyncMock(), - agent_card=sample_agent_card, - extensions=extensions, - ) - context = ClientCallContext( - state={'grpc_metadata': [(HTTP_EXTENSION_HEADER, existing_header)]} - ) - metadata = transport._get_metadata(context) - metadata_dict = dict(metadata) - assert HTTP_EXTENSION_HEADER in metadata_dict - actual_extensions = set(metadata_dict[HTTP_EXTENSION_HEADER].split(',')) - assert actual_extensions == expected_extensions - - def test_get_metadata_no_extensions(self, sample_agent_card: AgentCard): - transport = GrpcTransport( - channel=AsyncMock(), - agent_card=sample_agent_card, - extensions=None, - ) - context = ClientCallContext( - state={'grpc_metadata': [('x-other', 'Test')]} - ) - metadata = transport._get_metadata(context) - metadata_dict = dict(metadata) - assert HTTP_EXTENSION_HEADER not in metadata_dict - assert metadata_dict['x-other'] == 'Test' - - def test_get_metadata_empty_extensions(self, sample_agent_card: AgentCard): - transport = GrpcTransport( - channel=AsyncMock(), - agent_card=sample_agent_card, - extensions=[], - ) - context = ClientCallContext( - state={'grpc_metadata': [('x-other', 'Test')]} - ) - metadata = transport._get_metadata(context) - metadata_dict = dict(metadata) - assert HTTP_EXTENSION_HEADER not in metadata_dict - assert metadata_dict['x-other'] == 'Test' - - @pytest.mark.asyncio - async def test_send_message_with_extensions( - self, - mock_grpc_stub: AsyncMock, - sample_agent_card: AgentCard, - sample_message_send_params: MessageSendParams, - ): - extensions = ['test_extension_1', 'test_extension_2'] - transport = GrpcTransport( - channel=AsyncMock(), - agent_card=sample_agent_card, - extensions=extensions, - ) - transport.stub = mock_grpc_stub - mock_grpc_stub.SendMessage.return_value = a2a_pb2.SendMessageResponse( - msg=proto_utils.ToProto.message(sample_message_send_params.message) - ) - - await transport.send_message(sample_message_send_params) - - mock_grpc_stub.SendMessage.assert_awaited_once() - _, kwargs = mock_grpc_stub.SendMessage.call_args - metadata_dict = dict(kwargs['metadata']) - assert HTTP_EXTENSION_HEADER in metadata_dict - assert set(metadata_dict[HTTP_EXTENSION_HEADER].split(',')) == set( - extensions - ) - - @pytest.mark.asyncio - async def test_send_message_streaming_with_extensions( - self, - mock_grpc_stub: AsyncMock, - sample_agent_card: AgentCard, - sample_message_send_params: MessageSendParams, - ): - extensions = ['test_extension'] - transport = GrpcTransport( - channel=AsyncMock(), - agent_card=sample_agent_card, - extensions=extensions, - ) - transport.stub = mock_grpc_stub - stream = MagicMock() - stream.read = AsyncMock(side_effect=[grpc.aio.EOF]) - mock_grpc_stub.SendStreamingMessage.return_value = stream - - async for _ in transport.send_message_streaming( - sample_message_send_params - ): - pass - - mock_grpc_stub.SendStreamingMessage.assert_called_once() - _, kwargs = mock_grpc_stub.SendStreamingMessage.call_args - metadata_dict = dict(kwargs['metadata']) - assert HTTP_EXTENSION_HEADER in metadata_dict - assert metadata_dict[HTTP_EXTENSION_HEADER] == 'test_extension' - - @pytest.mark.asyncio - async def test_resubscribe_with_extensions( - self, mock_grpc_stub: AsyncMock, sample_agent_card: AgentCard - ): - extensions = ['test_extension'] - transport = GrpcTransport( - channel=AsyncMock(), - agent_card=sample_agent_card, - extensions=extensions, - ) - transport.stub = mock_grpc_stub - stream = MagicMock() - stream.read = AsyncMock(side_effect=[grpc.aio.EOF]) - mock_grpc_stub.TaskSubscription.return_value = stream - - async for _ in transport.resubscribe(TaskIdParams(id='task-1')): - pass - - mock_grpc_stub.TaskSubscription.assert_called_once() - _, kwargs = mock_grpc_stub.TaskSubscription.call_args - metadata_dict = dict(kwargs['metadata']) - assert HTTP_EXTENSION_HEADER in metadata_dict - assert metadata_dict[HTTP_EXTENSION_HEADER] == 'test_extension' - - @pytest.mark.asyncio - async def test_get_task_with_extensions( - self, mock_grpc_stub: AsyncMock, sample_agent_card: AgentCard - ): - extensions = ['test_extension'] - transport = GrpcTransport( - channel=AsyncMock(), - agent_card=sample_agent_card, - extensions=extensions, - ) - transport.stub = mock_grpc_stub - mock_grpc_stub.GetTask.return_value = a2a_pb2.Task() - - await transport.get_task(TaskQueryParams(id='task-1')) - - mock_grpc_stub.GetTask.assert_awaited_once() - _, kwargs = mock_grpc_stub.GetTask.call_args - metadata_dict = dict(kwargs['metadata']) - assert HTTP_EXTENSION_HEADER in metadata_dict - assert metadata_dict[HTTP_EXTENSION_HEADER] == 'test_extension' - - @pytest.mark.asyncio - async def test_cancel_task_with_extensions( - self, mock_grpc_stub: AsyncMock, sample_agent_card: AgentCard - ): - extensions = ['test_extension'] - transport = GrpcTransport( - channel=AsyncMock(), - agent_card=sample_agent_card, - extensions=extensions, - ) - transport.stub = mock_grpc_stub - mock_grpc_stub.CancelTask.return_value = a2a_pb2.Task() - - await transport.cancel_task(TaskIdParams(id='task-1')) - - mock_grpc_stub.CancelTask.assert_awaited_once() - _, kwargs = mock_grpc_stub.CancelTask.call_args - metadata_dict = dict(kwargs['metadata']) - assert HTTP_EXTENSION_HEADER in metadata_dict - assert metadata_dict[HTTP_EXTENSION_HEADER] == 'test_extension' - - @pytest.mark.asyncio - async def test_set_task_callback_with_extensions( - self, - mock_grpc_stub: AsyncMock, - sample_agent_card: AgentCard, - sample_task_push_notification_config: TaskPushNotificationConfig, - ): - extensions = ['test_extension'] - transport = GrpcTransport( - channel=AsyncMock(), - agent_card=sample_agent_card, - extensions=extensions, - ) - transport.stub = mock_grpc_stub - mock_grpc_stub.CreateTaskPushNotificationConfig.return_value = ( - proto_utils.ToProto.task_push_notification_config( - sample_task_push_notification_config - ) - ) - - await transport.set_task_callback(sample_task_push_notification_config) - - mock_grpc_stub.CreateTaskPushNotificationConfig.assert_awaited_once() - _, kwargs = mock_grpc_stub.CreateTaskPushNotificationConfig.call_args - metadata_dict = dict(kwargs['metadata']) - assert HTTP_EXTENSION_HEADER in metadata_dict - assert metadata_dict[HTTP_EXTENSION_HEADER] == 'test_extension' - - @pytest.mark.asyncio - async def test_get_task_callback_with_extensions( - self, - mock_grpc_stub: AsyncMock, - sample_agent_card: AgentCard, - sample_task_push_notification_config: TaskPushNotificationConfig, - ): - extensions = ['test_extension'] - transport = GrpcTransport( - channel=AsyncMock(), - agent_card=sample_agent_card, - extensions=extensions, - ) - transport.stub = mock_grpc_stub - mock_grpc_stub.GetTaskPushNotificationConfig.return_value = ( - proto_utils.ToProto.task_push_notification_config( - sample_task_push_notification_config - ) - ) - - await transport.get_task_callback( - GetTaskPushNotificationConfigParams( - id=sample_task_push_notification_config.task_id, - push_notification_config_id=sample_task_push_notification_config.push_notification_config.id, - ) - ) - - mock_grpc_stub.GetTaskPushNotificationConfig.assert_awaited_once() - _, kwargs = mock_grpc_stub.GetTaskPushNotificationConfig.call_args - metadata_dict = dict(kwargs['metadata']) - assert HTTP_EXTENSION_HEADER in metadata_dict - assert metadata_dict[HTTP_EXTENSION_HEADER] == 'test_extension' From c5cea2cd7ee3b0d1aa5637f2cd0fe966da04975c Mon Sep 17 00:00:00 2001 From: sokoliva Date: Wed, 5 Nov 2025 14:03:29 +0000 Subject: [PATCH 15/26] Move transport tests from tests/client to tests/client/transport. Add method "update_extension_metadata" to tansports/utils.py. --- src/a2a/client/base_client.py | 2 +- src/a2a/client/client.py | 3 - src/a2a/client/transports/grpc.py | 61 +++------ src/a2a/client/transports/utils.py | 39 +++--- .../{ => transports}/test_grpc_client.py | 125 +++++++++++++++--- .../{ => transports}/test_jsonrpc_client.py | 0 .../{ => transports}/test_rest_client.py | 0 tests/client/transports/test_utils.py | 76 +++++++++-- 8 files changed, 215 insertions(+), 91 deletions(-) rename tests/client/{ => transports}/test_grpc_client.py (79%) rename tests/client/{ => transports}/test_jsonrpc_client.py (100%) rename tests/client/{ => transports}/test_rest_client.py (100%) diff --git a/src/a2a/client/base_client.py b/src/a2a/client/base_client.py index 8a62f9392..18e85d50c 100644 --- a/src/a2a/client/base_client.py +++ b/src/a2a/client/base_client.py @@ -40,9 +40,9 @@ def __init__( ): super().__init__(consumers, middleware, extensions) self._card = card + config.extensions = extensions self._config = config self._transport = transport - self._config.extensions = self._extensions async def send_message( self, diff --git a/src/a2a/client/client.py b/src/a2a/client/client.py index d9ba6580b..c19f508ca 100644 --- a/src/a2a/client/client.py +++ b/src/a2a/client/client.py @@ -93,7 +93,6 @@ def __init__( self, consumers: list[Consumer] | None = None, middleware: list[ClientCallInterceptor] | None = None, - # iva todo - it can override value from the config, if it is provided extensions: list[str] | None = None, ): """Initializes the client with consumers and middleware. @@ -118,8 +117,6 @@ async def send_message( request: Message, *, context: ClientCallContext | None = None, - # iva todo add optional extensions- it can override value from the config, if it is provided - # and to the other ones as well ) -> AsyncIterator[ClientEvent | Message]: """Sends a message to the server. diff --git a/src/a2a/client/transports/grpc.py b/src/a2a/client/transports/grpc.py index f84fe3117..eb37b5e68 100644 --- a/src/a2a/client/transports/grpc.py +++ b/src/a2a/client/transports/grpc.py @@ -1,7 +1,6 @@ import logging from collections.abc import AsyncGenerator -from typing import Any try: @@ -13,13 +12,12 @@ "'pip install a2a-sdk[grpc]'" ) from e -from google.protobuf import struct_pb2 from a2a.client.client import ClientConfig from a2a.client.middleware import ClientCallContext, ClientCallInterceptor from a2a.client.optionals import Channel from a2a.client.transports.base import ClientTransport -from a2a.extensions.common import HTTP_EXTENSION_HEADER +from a2a.client.transports.utils import update_extension_metadata from a2a.grpc import a2a_pb2, a2a_pb2_grpc from a2a.types import ( AgentCard, @@ -87,9 +85,10 @@ async def send_message( configuration=proto_utils.ToProto.message_send_configuration( request.configuration ), - metadata=proto_utils.ToProto.metadata(request.metadata), + metadata=update_extension_metadata( + request.metadata, self.extensions + ), ), - metadata=self._update_extension_metadata(), ) if response.HasField('task'): return proto_utils.FromProto.task(response.task) @@ -110,9 +109,10 @@ async def send_message_streaming( configuration=proto_utils.ToProto.message_send_configuration( request.configuration ), - metadata=proto_utils.ToProto.metadata(request.metadata), + metadata=update_extension_metadata( + request.metadata, self.extensions + ), ), - metadata=self._update_extension_metadata(request.metadata), ) while True: response = await stream.read() @@ -120,7 +120,6 @@ async def send_message_streaming( break yield proto_utils.FromProto.stream_response(response) - # iva todo TaskIdParams has metadata async def resubscribe( self, request: TaskIdParams, *, context: ClientCallContext | None = None ) -> AsyncGenerator[ @@ -129,7 +128,9 @@ async def resubscribe( """Reconnects to get task updates.""" stream = self.stub.TaskSubscription( a2a_pb2.TaskSubscriptionRequest(name=f'tasks/{request.id}'), - metadata=self._update_extension_metadata(), + metadata=update_extension_metadata( + request.metadata, self.extensions + ), ) while True: response = await stream.read() @@ -137,7 +138,6 @@ async def resubscribe( break yield proto_utils.FromProto.stream_response(response) - # iva todo TaskQueryParams has metadata async def get_task( self, request: TaskQueryParams, @@ -150,7 +150,9 @@ async def get_task( name=f'tasks/{request.id}', history_length=request.history_length, ), - metadata=self._update_extension_metadata(), + metadata=update_extension_metadata( + request.metadata, self.extensions + ), ) return proto_utils.FromProto.task(task) @@ -163,7 +165,6 @@ async def cancel_task( """Requests the agent to cancel a specific task.""" task = await self.stub.CancelTask( a2a_pb2.CancelTaskRequest(name=f'tasks/{request.id}'), - metadata=self._update_extension_metadata(), ) return proto_utils.FromProto.task(task) @@ -182,11 +183,12 @@ async def set_task_callback( request ), ), - metadata=self._update_extension_metadata(), + metadata=update_extension_metadata( + request.metadata, self.extensions + ), ) return proto_utils.FromProto.task_push_notification_config(config) - # iva todo GetTaskPushNotificationConfigParams has metadata async def get_task_callback( self, request: GetTaskPushNotificationConfigParams, @@ -198,7 +200,9 @@ async def get_task_callback( a2a_pb2.GetTaskPushNotificationConfigRequest( name=f'tasks/{request.id}/pushNotificationConfigs/{request.push_notification_config_id}', ), - metadata=self._update_extension_metadata(), + metadata=update_extension_metadata( + request.metadata, self.extensions + ), ) return proto_utils.FromProto.task_push_notification_config(config) @@ -222,33 +226,6 @@ async def get_card( self._needs_extended_card = False return card - def _update_extension_metadata( - self, metadata: dict[str, Any] | None = None - ) -> struct_pb2.Struct | None: - """Gets the metadata for the gRPC call.""" - if metadata is None: - metadata = {} - - if self.extensions: - existing_extensions_str = str( - metadata.get(HTTP_EXTENSION_HEADER, '') - ) - existing_extensions = { - e.strip() - for e in existing_extensions_str.split(',') - if e.strip() - } - - all_extensions = set(existing_extensions) - all_extensions.update(self.extensions) - - if all_extensions: - metadata[HTTP_EXTENSION_HEADER] = ','.join(all_extensions) - elif HTTP_EXTENSION_HEADER in metadata: - del metadata[HTTP_EXTENSION_HEADER] - - return proto_utils.ToProto.metadata(metadata if metadata else None) - async def close(self) -> None: """Closes the gRPC channel.""" await self.channel.close() diff --git a/src/a2a/client/transports/utils.py b/src/a2a/client/transports/utils.py index fd409001d..b86dacdcc 100644 --- a/src/a2a/client/transports/utils.py +++ b/src/a2a/client/transports/utils.py @@ -1,7 +1,10 @@ from typing import Any +from google.protobuf import struct_pb2 + from a2a.client.middleware import ClientCallContext from a2a.extensions.common import HTTP_EXTENSION_HEADER +from a2a.utils import proto_utils def get_http_args(context: ClientCallContext | None) -> dict[str, Any] | None: @@ -14,27 +17,33 @@ def __merge_extensions( existing_extensions_list = [ e.strip() for e in existing_extensions.split(',') if e.strip() ] - existing_extensions_set = set(existing_extensions_list) new_extensions = [ - ext for ext in new_extensions if ext not in existing_extensions_set + ext for ext in new_extensions if ext not in existing_extensions_list ] - return ','.join(existing_extensions_list + new_extensions) def update_extension_header( http_kwargs: dict[str, Any], extensions: list[str] | None ) -> dict[str, Any]: - if not extensions: - return http_kwargs - headers = http_kwargs.setdefault('headers', {}) - existing_extensions_str = headers.get(HTTP_EXTENSION_HEADER, '') - """existing_extensions = [ - e.strip() for e in existing_extensions_str.split(',') if e.strip() - ] - all_extensions = set(existing_extensions) - all_extensions.update(extensions)""" - headers[HTTP_EXTENSION_HEADER] = __merge_extensions( - existing_extensions_str, extensions - ) + if extensions: + headers = http_kwargs.setdefault('headers', {}) + existing_extensions_str = headers.get(HTTP_EXTENSION_HEADER, '') + + headers[HTTP_EXTENSION_HEADER] = __merge_extensions( + existing_extensions_str, extensions + ) return http_kwargs + + +def update_extension_metadata( + metadata: dict[str, Any] | None, extensions: list[str] | None +) -> struct_pb2.Struct | None: + if metadata is None: + metadata = {} + if extensions: + existing_extensions_str = str(metadata.get(HTTP_EXTENSION_HEADER, '')) + metadata[HTTP_EXTENSION_HEADER] = __merge_extensions( + existing_extensions_str, extensions + ) + return proto_utils.ToProto.metadata(metadata if metadata else None) diff --git a/tests/client/test_grpc_client.py b/tests/client/transports/test_grpc_client.py similarity index 79% rename from tests/client/test_grpc_client.py rename to tests/client/transports/test_grpc_client.py index 3ed672285..b12430f68 100644 --- a/tests/client/test_grpc_client.py +++ b/tests/client/transports/test_grpc_client.py @@ -1,9 +1,8 @@ -from unittest.mock import AsyncMock, MagicMock, patch +from unittest.mock import AsyncMock, MagicMock import grpc import pytest -from a2a.client.middleware import ClientCallContext from a2a.client.transports.grpc import GrpcTransport from a2a.extensions.common import HTTP_EXTENSION_HEADER from a2a.grpc import a2a_pb2, a2a_pb2_grpc @@ -42,8 +41,6 @@ def mock_grpc_stub() -> AsyncMock: stub.CancelTask = AsyncMock() stub.CreateTaskPushNotificationConfig = AsyncMock() stub.GetTaskPushNotificationConfig = AsyncMock() - stub.TaskSubscription = MagicMock() - stub.GetAgentCard = AsyncMock() return stub @@ -132,7 +129,7 @@ def sample_task_status_update_event() -> TaskStatusUpdateEvent: @pytest.fixture def sample_task_artifact_update_event( - sample_artifact, + sample_artifact: Artifact, ) -> TaskArtifactUpdateEvent: """Provides a sample TaskArtifactUpdateEvent.""" return TaskArtifactUpdateEvent( @@ -183,7 +180,7 @@ async def test_send_message_task_response( mock_grpc_stub: AsyncMock, sample_message_send_params: MessageSendParams, sample_task: Task, -): +) -> None: """Test send_message that returns a Task.""" mock_grpc_stub.SendMessage.return_value = a2a_pb2.SendMessageResponse( task=proto_utils.ToProto.task(sample_task) @@ -202,7 +199,7 @@ async def test_send_message_message_response( mock_grpc_stub: AsyncMock, sample_message_send_params: MessageSendParams, sample_message: Message, -): +) -> None: """Test send_message that returns a Message.""" mock_grpc_stub.SendMessage.return_value = a2a_pb2.SendMessageResponse( msg=proto_utils.ToProto.message(sample_message) @@ -227,7 +224,7 @@ async def test_send_message_streaming( # noqa: PLR0913 sample_task: Task, sample_task_status_update_event: TaskStatusUpdateEvent, sample_task_artifact_update_event: TaskArtifactUpdateEvent, -): +) -> None: """Test send_message_streaming that yields responses.""" stream = MagicMock() stream.read = AsyncMock( @@ -272,7 +269,7 @@ async def test_send_message_streaming( # noqa: PLR0913 @pytest.mark.asyncio async def test_get_task( grpc_transport: GrpcTransport, mock_grpc_stub: AsyncMock, sample_task: Task -): +) -> None: """Test retrieving a task.""" mock_grpc_stub.GetTask.return_value = proto_utils.ToProto.task(sample_task) params = TaskQueryParams(id=sample_task.id) @@ -282,8 +279,7 @@ async def test_get_task( mock_grpc_stub.GetTask.assert_awaited_once_with( a2a_pb2.GetTaskRequest( name=f'tasks/{sample_task.id}', history_length=None - ), - metadata=[], + ) ) assert response.id == sample_task.id @@ -291,7 +287,7 @@ async def test_get_task( @pytest.mark.asyncio async def test_get_task_with_history( grpc_transport: GrpcTransport, mock_grpc_stub: AsyncMock, sample_task: Task -): +) -> None: """Test retrieving a task with history.""" mock_grpc_stub.GetTask.return_value = proto_utils.ToProto.task(sample_task) history_len = 10 @@ -302,15 +298,14 @@ async def test_get_task_with_history( mock_grpc_stub.GetTask.assert_awaited_once_with( a2a_pb2.GetTaskRequest( name=f'tasks/{sample_task.id}', history_length=history_len - ), - metadata=[], + ) ) @pytest.mark.asyncio async def test_cancel_task( grpc_transport: GrpcTransport, mock_grpc_stub: AsyncMock, sample_task: Task -): +) -> None: """Test cancelling a task.""" cancelled_task = sample_task.model_copy() cancelled_task.status.state = TaskState.canceled @@ -322,8 +317,7 @@ async def test_cancel_task( response = await grpc_transport.cancel_task(params) mock_grpc_stub.CancelTask.assert_awaited_once_with( - a2a_pb2.CancelTaskRequest(name=f'tasks/{sample_task.id}'), - metadata=[], + a2a_pb2.CancelTaskRequest(name=f'tasks/{sample_task.id}') ) assert response.status.state == TaskState.canceled @@ -333,7 +327,7 @@ async def test_set_task_callback_with_valid_task( grpc_transport: GrpcTransport, mock_grpc_stub: AsyncMock, sample_task_push_notification_config: TaskPushNotificationConfig, -): +) -> None: """Test setting a task push notification config with a valid task id.""" mock_grpc_stub.CreateTaskPushNotificationConfig.return_value = ( proto_utils.ToProto.task_push_notification_config( @@ -352,8 +346,7 @@ async def test_set_task_callback_with_valid_task( config=proto_utils.ToProto.task_push_notification_config( sample_task_push_notification_config ), - ), - metadata=[], + ) ) assert response.task_id == sample_task_push_notification_config.task_id @@ -363,7 +356,7 @@ async def test_set_task_callback_with_invalid_task( grpc_transport: GrpcTransport, mock_grpc_stub: AsyncMock, sample_task_push_notification_config: TaskPushNotificationConfig, -): +) -> None: """Test setting a task push notification config with an invalid task id.""" mock_grpc_stub.CreateTaskPushNotificationConfig.return_value = a2a_pb2.TaskPushNotificationConfig( name=( @@ -385,12 +378,42 @@ async def test_set_task_callback_with_invalid_task( ) +@pytest.mark.asyncio +async def test_get_task_callback_with_valid_task( + grpc_transport: GrpcTransport, + mock_grpc_stub: AsyncMock, + sample_task_push_notification_config: TaskPushNotificationConfig, +) -> None: + """Test retrieving a task push notification config with a valid task id.""" + mock_grpc_stub.GetTaskPushNotificationConfig.return_value = ( + proto_utils.ToProto.task_push_notification_config( + sample_task_push_notification_config + ) + ) + params = GetTaskPushNotificationConfigParams( + id=sample_task_push_notification_config.task_id, + push_notification_config_id=sample_task_push_notification_config.push_notification_config.id, + ) + + response = await grpc_transport.get_task_callback(params) + + mock_grpc_stub.GetTaskPushNotificationConfig.assert_awaited_once_with( + a2a_pb2.GetTaskPushNotificationConfigRequest( + name=( + f'tasks/{params.id}/' + f'pushNotificationConfigs/{params.push_notification_config_id}' + ), + ) + ) + assert response.task_id == sample_task_push_notification_config.task_id + + @pytest.mark.asyncio async def test_get_task_callback_with_invalid_task( grpc_transport: GrpcTransport, mock_grpc_stub: AsyncMock, sample_task_push_notification_config: TaskPushNotificationConfig, -): +) -> None: """Test retrieving a task push notification config with an invalid task id.""" mock_grpc_stub.GetTaskPushNotificationConfig.return_value = a2a_pb2.TaskPushNotificationConfig( name=( @@ -412,3 +435,61 @@ async def test_get_task_callback_with_invalid_task( 'Bad TaskPushNotificationConfig resource name' in exc_info.value.error.message ) + + +@pytest.mark.asyncio +async def test_send_message_with_extensions( + mock_grpc_stub: AsyncMock, + sample_agent_card: AgentCard, + sample_message_send_params: MessageSendParams, + sample_task: Task, +) -> None: + """Test send_message with extensions.""" + extensions = ['test_extension_1', 'test_extension_2'] + channel = AsyncMock() + transport = GrpcTransport( + channel=channel, agent_card=sample_agent_card, extensions=extensions + ) + transport.stub = mock_grpc_stub + + mock_grpc_stub.SendMessage.return_value = a2a_pb2.SendMessageResponse( + task=proto_utils.ToProto.task(sample_task) + ) + + await transport.send_message(sample_message_send_params) + + mock_grpc_stub.SendMessage.assert_awaited_once() + args, _ = mock_grpc_stub.SendMessage.call_args + request = args[0] + metadata = proto_utils.FromProto.metadata(request.metadata) + assert HTTP_EXTENSION_HEADER in metadata + assert metadata[HTTP_EXTENSION_HEADER] == 'test_extension_1,test_extension_2' + + +@pytest.mark.asyncio +async def test_send_message_streaming_with_extensions( + mock_grpc_stub: AsyncMock, + sample_agent_card: AgentCard, + sample_message_send_params: MessageSendParams, +) -> None: + """Test send_message_streaming with extensions.""" + extensions = ['test_extension_1', 'test_extension_2'] + channel = AsyncMock() + transport = GrpcTransport( + channel=channel, agent_card=sample_agent_card, extensions=extensions + ) + transport.stub = mock_grpc_stub + + stream = MagicMock() + stream.read = AsyncMock(side_effect=[grpc.aio.EOF]) + mock_grpc_stub.SendStreamingMessage.return_value = stream + + async for _ in transport.send_message_streaming(sample_message_send_params): + pass + + mock_grpc_stub.SendStreamingMessage.assert_called_once() + args, _ = mock_grpc_stub.SendStreamingMessage.call_args + request = args[0] + metadata = proto_utils.FromProto.metadata(request.metadata) + assert HTTP_EXTENSION_HEADER in metadata + assert metadata[HTTP_EXTENSION_HEADER] == 'test_extension_1,test_extension_2' diff --git a/tests/client/test_jsonrpc_client.py b/tests/client/transports/test_jsonrpc_client.py similarity index 100% rename from tests/client/test_jsonrpc_client.py rename to tests/client/transports/test_jsonrpc_client.py diff --git a/tests/client/test_rest_client.py b/tests/client/transports/test_rest_client.py similarity index 100% rename from tests/client/test_rest_client.py rename to tests/client/transports/test_rest_client.py diff --git a/tests/client/transports/test_utils.py b/tests/client/transports/test_utils.py index f47c0465f..d691befb3 100644 --- a/tests/client/transports/test_utils.py +++ b/tests/client/transports/test_utils.py @@ -1,7 +1,11 @@ import pytest from a2a.extensions.common import HTTP_EXTENSION_HEADER -from a2a.client.transports.utils import update_extension_header +from a2a.client.transports.utils import ( + update_extension_header, + update_extension_metadata, +) +from a2a.utils import proto_utils class TestUtils: @@ -65,14 +69,70 @@ def test_update_extension_header_with_other_headers(self): assert headers[HTTP_EXTENSION_HEADER] == 'test_extension' assert headers['X_Other'] == 'Test' - def test_update_extension_header_no_extensions(self): + @pytest.mark.parametrize('extensions', [(None), ([])]) + def test_update_extension_header_no_or_empty_extensions(self, extensions): http_kwargs = {'headers': {'X_Other': 'Test'}} - result_kwargs = update_extension_header(http_kwargs, None) + result_kwargs = update_extension_header(http_kwargs, extensions) assert HTTP_EXTENSION_HEADER not in result_kwargs['headers'] assert result_kwargs['headers']['X_Other'] == 'Test' - def test_update_extension_header_empty_extensions(self): - http_kwargs = {'headers': {'X_Other': 'Test'}} - result_kwargs = update_extension_header(http_kwargs, []) - assert HTTP_EXTENSION_HEADER not in result_kwargs['headers'] - assert result_kwargs['headers']['X_Other'] == 'Test' + @pytest.mark.parametrize( + 'extensions, existing_metadata, expected_extensions, expected_count', + [ + ( + ['test_extension_1', 'test_extension_2'], + None, + {'test_extension_1', 'test_extension_2'}, + 2, + ), + ( + ['test_extension_1', 'test_extension_2'], + {HTTP_EXTENSION_HEADER: 'test_extension_2, test_extension_3'}, + {'test_extension_1', 'test_extension_2', 'test_extension_3'}, + 3, + ), + ( + ['test_extension_1', 'test_extension_2'], + {HTTP_EXTENSION_HEADER: 'test_extension_3'}, + {'test_extension_1', 'test_extension_2', 'test_extension_3'}, + 3, + ), + ( + ['test_extension_1'], + {'X_Other': 'Test'}, + {'test_extension_1'}, + 1, + ), + ], + ) + def test_update_extension_metadata( + self, + extensions: list[str], + existing_metadata: dict[str, str], + expected_extensions: set[str], + expected_count: int, + ): + result_metadata = update_extension_metadata( + existing_metadata, extensions + ) + assert result_metadata is not None + metadata_dict = proto_utils.FromProto.metadata(result_metadata) + header_value = metadata_dict.get(HTTP_EXTENSION_HEADER, '') + actual_extensions_list = [ + e.strip() for e in header_value.split(',') if e.strip() + ] + actual_extensions = set(actual_extensions_list) + + assert len(actual_extensions_list) == expected_count + assert actual_extensions == expected_extensions + if existing_metadata and 'X_Other' in existing_metadata: + assert metadata_dict['X_Other'] == existing_metadata['X_Other'] + + @pytest.mark.parametrize('extensions', [(None), ([])]) + def test_update_extension_metadata_no_or_empty_extensions(self, extensions): + metadata = {'X_Other': 'Test'} + result_metadata = update_extension_metadata(metadata, extensions) + assert result_metadata is not None + metadata_dict = proto_utils.FromProto.metadata(result_metadata) + assert HTTP_EXTENSION_HEADER not in metadata_dict + assert metadata_dict['X_Other'] == 'Test' From 6e856d5842822c78f73934c32aa2514358d89dfc Mon Sep 17 00:00:00 2001 From: sokoliva Date: Thu, 6 Nov 2025 12:15:26 +0000 Subject: [PATCH 16/26] feat: enhance GrpcTransport to manage extensions in metadata and update related tests --- src/a2a/client/transports/grpc.py | 35 +++--- src/a2a/client/transports/utils.py | 27 ++--- tests/client/transports/test_grpc_client.py | 128 ++++++++++---------- tests/client/transports/test_utils.py | 66 +--------- 4 files changed, 90 insertions(+), 166 deletions(-) diff --git a/src/a2a/client/transports/grpc.py b/src/a2a/client/transports/grpc.py index eb37b5e68..ee81da99b 100644 --- a/src/a2a/client/transports/grpc.py +++ b/src/a2a/client/transports/grpc.py @@ -17,7 +17,7 @@ from a2a.client.middleware import ClientCallContext, ClientCallInterceptor from a2a.client.optionals import Channel from a2a.client.transports.base import ClientTransport -from a2a.client.transports.utils import update_extension_metadata +from a2a.extensions.common import HTTP_EXTENSION_HEADER from a2a.grpc import a2a_pb2, a2a_pb2_grpc from a2a.types import ( AgentCard, @@ -59,6 +59,12 @@ def __init__( ) self.extensions = extensions + def _get_grpc_metadata(self) -> list[tuple[str, str]] | None: + """Creates gRPC metadata for extensions.""" + if not self.extensions: + return None + return [(HTTP_EXTENSION_HEADER, ', '.join(self.extensions))] + @classmethod def create( cls, @@ -85,10 +91,9 @@ async def send_message( configuration=proto_utils.ToProto.message_send_configuration( request.configuration ), - metadata=update_extension_metadata( - request.metadata, self.extensions - ), + metadata=proto_utils.ToProto.metadata(request.metadata), ), + metadata=self._get_grpc_metadata(), ) if response.HasField('task'): return proto_utils.FromProto.task(response.task) @@ -109,10 +114,9 @@ async def send_message_streaming( configuration=proto_utils.ToProto.message_send_configuration( request.configuration ), - metadata=update_extension_metadata( - request.metadata, self.extensions - ), + metadata=proto_utils.ToProto.metadata(request.metadata), ), + metadata=self._get_grpc_metadata(), ) while True: response = await stream.read() @@ -128,9 +132,7 @@ async def resubscribe( """Reconnects to get task updates.""" stream = self.stub.TaskSubscription( a2a_pb2.TaskSubscriptionRequest(name=f'tasks/{request.id}'), - metadata=update_extension_metadata( - request.metadata, self.extensions - ), + metadata=self._get_grpc_metadata(), ) while True: response = await stream.read() @@ -150,9 +152,7 @@ async def get_task( name=f'tasks/{request.id}', history_length=request.history_length, ), - metadata=update_extension_metadata( - request.metadata, self.extensions - ), + metadata=self._get_grpc_metadata(), ) return proto_utils.FromProto.task(task) @@ -165,6 +165,7 @@ async def cancel_task( """Requests the agent to cancel a specific task.""" task = await self.stub.CancelTask( a2a_pb2.CancelTaskRequest(name=f'tasks/{request.id}'), + metadata=self._get_grpc_metadata(), ) return proto_utils.FromProto.task(task) @@ -183,9 +184,7 @@ async def set_task_callback( request ), ), - metadata=update_extension_metadata( - request.metadata, self.extensions - ), + metadata=self._get_grpc_metadata(), ) return proto_utils.FromProto.task_push_notification_config(config) @@ -200,9 +199,7 @@ async def get_task_callback( a2a_pb2.GetTaskPushNotificationConfigRequest( name=f'tasks/{request.id}/pushNotificationConfigs/{request.push_notification_config_id}', ), - metadata=update_extension_metadata( - request.metadata, self.extensions - ), + metadata=self._get_grpc_metadata(), ) return proto_utils.FromProto.task_push_notification_config(config) diff --git a/src/a2a/client/transports/utils.py b/src/a2a/client/transports/utils.py index b86dacdcc..467c3ba3a 100644 --- a/src/a2a/client/transports/utils.py +++ b/src/a2a/client/transports/utils.py @@ -1,10 +1,7 @@ from typing import Any -from google.protobuf import struct_pb2 - from a2a.client.middleware import ClientCallContext from a2a.extensions.common import HTTP_EXTENSION_HEADER -from a2a.utils import proto_utils def get_http_args(context: ClientCallContext | None) -> dict[str, Any] | None: @@ -30,20 +27,14 @@ def update_extension_header( headers = http_kwargs.setdefault('headers', {}) existing_extensions_str = headers.get(HTTP_EXTENSION_HEADER, '') - headers[HTTP_EXTENSION_HEADER] = __merge_extensions( - existing_extensions_str, extensions - ) - return http_kwargs - + existing_extensions_list = [ + e.strip() for e in existing_extensions_str.split(',') if e.strip() + ] + new_extensions = [ + ext for ext in extensions if ext not in existing_extensions_list + ] -def update_extension_metadata( - metadata: dict[str, Any] | None, extensions: list[str] | None -) -> struct_pb2.Struct | None: - if metadata is None: - metadata = {} - if extensions: - existing_extensions_str = str(metadata.get(HTTP_EXTENSION_HEADER, '')) - metadata[HTTP_EXTENSION_HEADER] = __merge_extensions( - existing_extensions_str, extensions + headers[HTTP_EXTENSION_HEADER] = ','.join( + existing_extensions_list + new_extensions ) - return proto_utils.ToProto.metadata(metadata if metadata else None) + return http_kwargs diff --git a/tests/client/transports/test_grpc_client.py b/tests/client/transports/test_grpc_client.py index b12430f68..dce554f69 100644 --- a/tests/client/transports/test_grpc_client.py +++ b/tests/client/transports/test_grpc_client.py @@ -65,7 +65,14 @@ def grpc_transport( ) -> GrpcTransport: """Provides a GrpcTransport instance.""" channel = AsyncMock() - transport = GrpcTransport(channel=channel, agent_card=sample_agent_card) + transport = GrpcTransport( + channel=channel, + agent_card=sample_agent_card, + extensions=[ + 'https://example.com/test-ext/v1', + 'https://example.com/test-ext/v2', + ], + ) transport.stub = mock_grpc_stub return transport @@ -189,6 +196,13 @@ async def test_send_message_task_response( response = await grpc_transport.send_message(sample_message_send_params) mock_grpc_stub.SendMessage.assert_awaited_once() + _, kwargs = mock_grpc_stub.SendMessage.call_args + assert kwargs['metadata'] == [ + ( + HTTP_EXTENSION_HEADER, + 'https://example.com/test-ext/v1, https://example.com/test-ext/v2', + ) + ] assert isinstance(response, Task) assert response.id == sample_task.id @@ -208,6 +222,13 @@ async def test_send_message_message_response( response = await grpc_transport.send_message(sample_message_send_params) mock_grpc_stub.SendMessage.assert_awaited_once() + _, kwargs = mock_grpc_stub.SendMessage.call_args + assert kwargs['metadata'] == [ + ( + HTTP_EXTENSION_HEADER, + 'https://example.com/test-ext/v1, https://example.com/test-ext/v2', + ) + ] assert isinstance(response, Message) assert response.message_id == sample_message.message_id assert get_text_parts(response.parts) == get_text_parts( @@ -256,6 +277,13 @@ async def test_send_message_streaming( # noqa: PLR0913 ] mock_grpc_stub.SendStreamingMessage.assert_called_once() + _, kwargs = mock_grpc_stub.SendStreamingMessage.call_args + assert kwargs['metadata'] == [ + ( + HTTP_EXTENSION_HEADER, + 'https://example.com/test-ext/v1, https://example.com/test-ext/v2', + ) + ] assert isinstance(responses[0], Message) assert responses[0].message_id == sample_message.message_id assert isinstance(responses[1], Task) @@ -279,7 +307,13 @@ async def test_get_task( mock_grpc_stub.GetTask.assert_awaited_once_with( a2a_pb2.GetTaskRequest( name=f'tasks/{sample_task.id}', history_length=None - ) + ), + metadata=[ + ( + HTTP_EXTENSION_HEADER, + 'https://example.com/test-ext/v1, https://example.com/test-ext/v2', + ) + ], ) assert response.id == sample_task.id @@ -298,7 +332,13 @@ async def test_get_task_with_history( mock_grpc_stub.GetTask.assert_awaited_once_with( a2a_pb2.GetTaskRequest( name=f'tasks/{sample_task.id}', history_length=history_len - ) + ), + metadata=[ + ( + HTTP_EXTENSION_HEADER, + 'https://example.com/test-ext/v1, https://example.com/test-ext/v2', + ) + ], ) @@ -317,7 +357,13 @@ async def test_cancel_task( response = await grpc_transport.cancel_task(params) mock_grpc_stub.CancelTask.assert_awaited_once_with( - a2a_pb2.CancelTaskRequest(name=f'tasks/{sample_task.id}') + a2a_pb2.CancelTaskRequest(name=f'tasks/{sample_task.id}'), + metadata=[ + ( + HTTP_EXTENSION_HEADER, + 'https://example.com/test-ext/v1, https://example.com/test-ext/v2', + ) + ], ) assert response.status.state == TaskState.canceled @@ -346,7 +392,13 @@ async def test_set_task_callback_with_valid_task( config=proto_utils.ToProto.task_push_notification_config( sample_task_push_notification_config ), - ) + ), + metadata=[ + ( + HTTP_EXTENSION_HEADER, + 'https://example.com/test-ext/v1, https://example.com/test-ext/v2', + ) + ], ) assert response.task_id == sample_task_push_notification_config.task_id @@ -403,7 +455,13 @@ async def test_get_task_callback_with_valid_task( f'tasks/{params.id}/' f'pushNotificationConfigs/{params.push_notification_config_id}' ), - ) + ), + metadata=[ + ( + HTTP_EXTENSION_HEADER, + 'https://example.com/test-ext/v1, https://example.com/test-ext/v2', + ) + ], ) assert response.task_id == sample_task_push_notification_config.task_id @@ -435,61 +493,3 @@ async def test_get_task_callback_with_invalid_task( 'Bad TaskPushNotificationConfig resource name' in exc_info.value.error.message ) - - -@pytest.mark.asyncio -async def test_send_message_with_extensions( - mock_grpc_stub: AsyncMock, - sample_agent_card: AgentCard, - sample_message_send_params: MessageSendParams, - sample_task: Task, -) -> None: - """Test send_message with extensions.""" - extensions = ['test_extension_1', 'test_extension_2'] - channel = AsyncMock() - transport = GrpcTransport( - channel=channel, agent_card=sample_agent_card, extensions=extensions - ) - transport.stub = mock_grpc_stub - - mock_grpc_stub.SendMessage.return_value = a2a_pb2.SendMessageResponse( - task=proto_utils.ToProto.task(sample_task) - ) - - await transport.send_message(sample_message_send_params) - - mock_grpc_stub.SendMessage.assert_awaited_once() - args, _ = mock_grpc_stub.SendMessage.call_args - request = args[0] - metadata = proto_utils.FromProto.metadata(request.metadata) - assert HTTP_EXTENSION_HEADER in metadata - assert metadata[HTTP_EXTENSION_HEADER] == 'test_extension_1,test_extension_2' - - -@pytest.mark.asyncio -async def test_send_message_streaming_with_extensions( - mock_grpc_stub: AsyncMock, - sample_agent_card: AgentCard, - sample_message_send_params: MessageSendParams, -) -> None: - """Test send_message_streaming with extensions.""" - extensions = ['test_extension_1', 'test_extension_2'] - channel = AsyncMock() - transport = GrpcTransport( - channel=channel, agent_card=sample_agent_card, extensions=extensions - ) - transport.stub = mock_grpc_stub - - stream = MagicMock() - stream.read = AsyncMock(side_effect=[grpc.aio.EOF]) - mock_grpc_stub.SendStreamingMessage.return_value = stream - - async for _ in transport.send_message_streaming(sample_message_send_params): - pass - - mock_grpc_stub.SendStreamingMessage.assert_called_once() - args, _ = mock_grpc_stub.SendStreamingMessage.call_args - request = args[0] - metadata = proto_utils.FromProto.metadata(request.metadata) - assert HTTP_EXTENSION_HEADER in metadata - assert metadata[HTTP_EXTENSION_HEADER] == 'test_extension_1,test_extension_2' diff --git a/tests/client/transports/test_utils.py b/tests/client/transports/test_utils.py index d691befb3..e08cff97c 100644 --- a/tests/client/transports/test_utils.py +++ b/tests/client/transports/test_utils.py @@ -1,10 +1,7 @@ import pytest from a2a.extensions.common import HTTP_EXTENSION_HEADER -from a2a.client.transports.utils import ( - update_extension_header, - update_extension_metadata, -) +from a2a.client.transports.utils import update_extension_header from a2a.utils import proto_utils @@ -75,64 +72,3 @@ def test_update_extension_header_no_or_empty_extensions(self, extensions): result_kwargs = update_extension_header(http_kwargs, extensions) assert HTTP_EXTENSION_HEADER not in result_kwargs['headers'] assert result_kwargs['headers']['X_Other'] == 'Test' - - @pytest.mark.parametrize( - 'extensions, existing_metadata, expected_extensions, expected_count', - [ - ( - ['test_extension_1', 'test_extension_2'], - None, - {'test_extension_1', 'test_extension_2'}, - 2, - ), - ( - ['test_extension_1', 'test_extension_2'], - {HTTP_EXTENSION_HEADER: 'test_extension_2, test_extension_3'}, - {'test_extension_1', 'test_extension_2', 'test_extension_3'}, - 3, - ), - ( - ['test_extension_1', 'test_extension_2'], - {HTTP_EXTENSION_HEADER: 'test_extension_3'}, - {'test_extension_1', 'test_extension_2', 'test_extension_3'}, - 3, - ), - ( - ['test_extension_1'], - {'X_Other': 'Test'}, - {'test_extension_1'}, - 1, - ), - ], - ) - def test_update_extension_metadata( - self, - extensions: list[str], - existing_metadata: dict[str, str], - expected_extensions: set[str], - expected_count: int, - ): - result_metadata = update_extension_metadata( - existing_metadata, extensions - ) - assert result_metadata is not None - metadata_dict = proto_utils.FromProto.metadata(result_metadata) - header_value = metadata_dict.get(HTTP_EXTENSION_HEADER, '') - actual_extensions_list = [ - e.strip() for e in header_value.split(',') if e.strip() - ] - actual_extensions = set(actual_extensions_list) - - assert len(actual_extensions_list) == expected_count - assert actual_extensions == expected_extensions - if existing_metadata and 'X_Other' in existing_metadata: - assert metadata_dict['X_Other'] == existing_metadata['X_Other'] - - @pytest.mark.parametrize('extensions', [(None), ([])]) - def test_update_extension_metadata_no_or_empty_extensions(self, extensions): - metadata = {'X_Other': 'Test'} - result_metadata = update_extension_metadata(metadata, extensions) - assert result_metadata is not None - metadata_dict = proto_utils.FromProto.metadata(result_metadata) - assert HTTP_EXTENSION_HEADER not in metadata_dict - assert metadata_dict['X_Other'] == 'Test' From edd7982339beec22f3aa451c6c5d3186441edd12 Mon Sep 17 00:00:00 2001 From: sokoliva Date: Thu, 6 Nov 2025 14:49:51 +0000 Subject: [PATCH 17/26] refactor: remove unused __merge_extensions function from utils.py --- src/a2a/client/transports/utils.py | 13 ------------- 1 file changed, 13 deletions(-) diff --git a/src/a2a/client/transports/utils.py b/src/a2a/client/transports/utils.py index 467c3ba3a..540c5d3ae 100644 --- a/src/a2a/client/transports/utils.py +++ b/src/a2a/client/transports/utils.py @@ -7,19 +7,6 @@ def get_http_args(context: ClientCallContext | None) -> dict[str, Any] | None: return context.state.get('http_kwargs') if context else None - -def __merge_extensions( - existing_extensions: str, new_extensions: list[str] -) -> str: - existing_extensions_list = [ - e.strip() for e in existing_extensions.split(',') if e.strip() - ] - new_extensions = [ - ext for ext in new_extensions if ext not in existing_extensions_list - ] - return ','.join(existing_extensions_list + new_extensions) - - def update_extension_header( http_kwargs: dict[str, Any], extensions: list[str] | None ) -> dict[str, Any]: From 48ea2ae277fde1ee4a9de4cc133d76bc24aff0e7 Mon Sep 17 00:00:00 2001 From: sokoliva Date: Wed, 12 Nov 2025 14:50:49 +0000 Subject: [PATCH 18/26] feat: update extension handling in transports and tests, migrate utility functions to common module --- src/a2a/client/base_client.py | 1 - src/a2a/client/transports/jsonrpc.py | 25 ++++--- src/a2a/client/transports/rest.py | 23 +++--- src/a2a/client/transports/utils.py | 27 ------- src/a2a/extensions/common.py | 17 +++++ .../client/transports/test_jsonrpc_client.py | 15 ++-- tests/client/transports/test_rest_client.py | 15 ++-- tests/client/transports/test_utils.py | 74 ------------------- tests/extensions/test_common.py | 70 ++++++++++++++++++ 9 files changed, 136 insertions(+), 131 deletions(-) delete mode 100644 src/a2a/client/transports/utils.py delete mode 100644 tests/client/transports/test_utils.py diff --git a/src/a2a/client/base_client.py b/src/a2a/client/base_client.py index 18e85d50c..c2af4bf2a 100644 --- a/src/a2a/client/base_client.py +++ b/src/a2a/client/base_client.py @@ -40,7 +40,6 @@ def __init__( ): super().__init__(consumers, middleware, extensions) self._card = card - config.extensions = extensions self._config = config self._transport = transport diff --git a/src/a2a/client/transports/jsonrpc.py b/src/a2a/client/transports/jsonrpc.py index f8396b7c4..fcd41b7cc 100644 --- a/src/a2a/client/transports/jsonrpc.py +++ b/src/a2a/client/transports/jsonrpc.py @@ -18,7 +18,7 @@ ) from a2a.client.middleware import ClientCallContext, ClientCallInterceptor from a2a.client.transports.base import ClientTransport -from a2a.client.transports.utils import get_http_args, update_extension_header +from a2a.extensions.common import update_extension_header from a2a.types import ( AgentCard, CancelTaskRequest, @@ -106,6 +106,11 @@ async def _apply_interceptors( ) return final_request_payload, final_http_kwargs + def _get_http_args( + self, context: ClientCallContext | None + ) -> dict[str, Any] | None: + return context.state.get('http_kwargs') if context else None + async def send_message( self, request: MessageSendParams, @@ -117,7 +122,7 @@ async def send_message( payload, modified_kwargs = await self._apply_interceptors( 'message/send', rpc_request.model_dump(mode='json', exclude_none=True), - get_http_args(context), + self._get_http_args(context), context, ) modified_kwargs = update_extension_header( @@ -144,7 +149,7 @@ async def send_message_streaming( payload, modified_kwargs = await self._apply_interceptors( 'message/stream', rpc_request.model_dump(mode='json', exclude_none=True), - get_http_args(context), + self._get_http_args(context), context, ) @@ -217,7 +222,7 @@ async def get_task( payload, modified_kwargs = await self._apply_interceptors( 'tasks/get', rpc_request.model_dump(mode='json', exclude_none=True), - get_http_args(context), + self._get_http_args(context), context, ) modified_kwargs = update_extension_header( @@ -240,7 +245,7 @@ async def cancel_task( payload, modified_kwargs = await self._apply_interceptors( 'tasks/cancel', rpc_request.model_dump(mode='json', exclude_none=True), - get_http_args(context), + self._get_http_args(context), context, ) modified_kwargs = update_extension_header( @@ -265,7 +270,7 @@ async def set_task_callback( payload, modified_kwargs = await self._apply_interceptors( 'tasks/pushNotificationConfig/set', rpc_request.model_dump(mode='json', exclude_none=True), - get_http_args(context), + self._get_http_args(context), context, ) modified_kwargs = update_extension_header( @@ -292,7 +297,7 @@ async def get_task_callback( payload, modified_kwargs = await self._apply_interceptors( 'tasks/pushNotificationConfig/get', rpc_request.model_dump(mode='json', exclude_none=True), - get_http_args(context), + self._get_http_args(context), context, ) modified_kwargs = update_extension_header( @@ -319,7 +324,7 @@ async def resubscribe( payload, modified_kwargs = await self._apply_interceptors( 'tasks/resubscribe', rpc_request.model_dump(mode='json', exclude_none=True), - get_http_args(context), + self._get_http_args(context), context, ) modified_kwargs = update_extension_header( @@ -363,7 +368,7 @@ async def get_card( if not card: resolver = A2ACardResolver(self.httpx_client, self.url) card = await resolver.get_agent_card( - http_kwargs=get_http_args(context) + http_kwargs=self._get_http_args(context) ) self._needs_extended_card = ( card.supports_authenticated_extended_card @@ -377,7 +382,7 @@ async def get_card( payload, modified_kwargs = await self._apply_interceptors( request.method, request.model_dump(mode='json', exclude_none=True), - get_http_args(context), + self._get_http_args(context), context, ) modified_kwargs = update_extension_header( diff --git a/src/a2a/client/transports/rest.py b/src/a2a/client/transports/rest.py index 2e9373646..e36afffef 100644 --- a/src/a2a/client/transports/rest.py +++ b/src/a2a/client/transports/rest.py @@ -13,7 +13,7 @@ from a2a.client.errors import A2AClientHTTPError, A2AClientJSONError from a2a.client.middleware import ClientCallContext, ClientCallInterceptor from a2a.client.transports.base import ClientTransport -from a2a.client.transports.utils import get_http_args, update_extension_header +from a2a.extensions.common import update_extension_header from a2a.grpc import a2a_pb2 from a2a.types import ( AgentCard, @@ -76,6 +76,11 @@ async def _apply_interceptors( # TODO: Implement interceptors for other transports return final_request_payload, final_http_kwargs + def _get_http_args( + self, context: ClientCallContext | None + ) -> dict[str, Any] | None: + return context.state.get('http_kwargs') if context else None + async def _prepare_send_message( self, request: MessageSendParams, context: ClientCallContext | None ) -> tuple[dict[str, Any], dict[str, Any]]: @@ -93,7 +98,7 @@ async def _prepare_send_message( payload = MessageToDict(pb) payload, modified_kwargs = await self._apply_interceptors( payload, - get_http_args(context), + self._get_http_args(context), context, ) modified_kwargs = update_extension_header( @@ -209,7 +214,7 @@ async def get_task( """Retrieves the current state and history of a specific task.""" _payload, modified_kwargs = await self._apply_interceptors( request.model_dump(mode='json', exclude_none=True), - get_http_args(context), + self._get_http_args(context), context, ) modified_kwargs = update_extension_header( @@ -237,7 +242,7 @@ async def cancel_task( payload = MessageToDict(pb) payload, modified_kwargs = await self._apply_interceptors( payload, - get_http_args(context), + self._get_http_args(context), context, ) modified_kwargs = update_extension_header( @@ -264,7 +269,7 @@ async def set_task_callback( ) payload = MessageToDict(pb) payload, modified_kwargs = await self._apply_interceptors( - payload, get_http_args(context), context + payload, self._get_http_args(context), context ) modified_kwargs = update_extension_header( modified_kwargs, self.extensions @@ -291,7 +296,7 @@ async def get_task_callback( payload = MessageToDict(pb) payload, modified_kwargs = await self._apply_interceptors( payload, - get_http_args(context), + self._get_http_args(context), context, ) modified_kwargs = update_extension_header( @@ -315,7 +320,7 @@ async def resubscribe( Task | TaskStatusUpdateEvent | TaskArtifactUpdateEvent | Message ]: """Reconnects to get task updates.""" - http_kwargs = get_http_args(context) or {} + http_kwargs = self._get_http_args(context) or {} http_kwargs.setdefault('timeout', None) modified_kwargs = update_extension_header(http_kwargs, self.extensions) @@ -351,7 +356,7 @@ async def get_card( if not card: resolver = A2ACardResolver(self.httpx_client, self.url) card = await resolver.get_agent_card( - http_kwargs=get_http_args(context) + http_kwargs=self._get_http_args(context) ) self._needs_extended_card = ( card.supports_authenticated_extended_card @@ -363,7 +368,7 @@ async def get_card( _, modified_kwargs = await self._apply_interceptors( {}, - get_http_args(context), + self._get_http_args(context), context, ) modified_kwargs = update_extension_header( diff --git a/src/a2a/client/transports/utils.py b/src/a2a/client/transports/utils.py deleted file mode 100644 index 540c5d3ae..000000000 --- a/src/a2a/client/transports/utils.py +++ /dev/null @@ -1,27 +0,0 @@ -from typing import Any - -from a2a.client.middleware import ClientCallContext -from a2a.extensions.common import HTTP_EXTENSION_HEADER - - -def get_http_args(context: ClientCallContext | None) -> dict[str, Any] | None: - return context.state.get('http_kwargs') if context else None - -def update_extension_header( - http_kwargs: dict[str, Any], extensions: list[str] | None -) -> dict[str, Any]: - if extensions: - headers = http_kwargs.setdefault('headers', {}) - existing_extensions_str = headers.get(HTTP_EXTENSION_HEADER, '') - - existing_extensions_list = [ - e.strip() for e in existing_extensions_str.split(',') if e.strip() - ] - new_extensions = [ - ext for ext in extensions if ext not in existing_extensions_list - ] - - headers[HTTP_EXTENSION_HEADER] = ','.join( - existing_extensions_list + new_extensions - ) - return http_kwargs diff --git a/src/a2a/extensions/common.py b/src/a2a/extensions/common.py index 2f752caa7..8f8749235 100644 --- a/src/a2a/extensions/common.py +++ b/src/a2a/extensions/common.py @@ -1,3 +1,5 @@ +from typing import Any + from a2a.types import AgentCard, AgentExtension @@ -25,3 +27,18 @@ def find_extension_by_uri(card: AgentCard, uri: str) -> AgentExtension | None: return ext return None + + +def update_extension_header( + http_kwargs: dict[str, Any], extensions: list[str] | None +) -> dict[str, Any]: + if extensions: + headers = http_kwargs.setdefault('headers', {}) + existing_extensions_str = headers.get(HTTP_EXTENSION_HEADER, '') + + existing_extensions = get_requested_extensions( + [existing_extensions_str] + ) + all_extensions = existing_extensions.union(extensions) + headers[HTTP_EXTENSION_HEADER] = ','.join(all_extensions) + return http_kwargs diff --git a/tests/client/transports/test_jsonrpc_client.py b/tests/client/transports/test_jsonrpc_client.py index 3ebd50173..f1cec6eb5 100644 --- a/tests/client/transports/test_jsonrpc_client.py +++ b/tests/client/transports/test_jsonrpc_client.py @@ -794,7 +794,10 @@ async def test_send_message_with_extensions( self, mock_httpx_client: AsyncMock, mock_agent_card: MagicMock ): """Test that send_message adds extension headers when extensions are provided.""" - extensions = ['test_extension_1', 'test_extension_2'] + extensions = [ + 'https://example.com/test-ext/v1', + 'https://example.com/test-ext/v2', + ] client = JsonRpcTransport( httpx_client=mock_httpx_client, agent_card=mock_agent_card, @@ -827,8 +830,8 @@ async def test_send_message_with_extensions( actual_extensions = set(actual_extensions_list) expected_extensions = { - 'test_extension_1', - 'test_extension_2', + 'https://example.com/test-ext/v1', + 'https://example.com/test-ext/v2', } assert len(actual_extensions_list) == 2 assert actual_extensions == expected_extensions @@ -842,7 +845,7 @@ async def test_send_message_streaming_with_extensions( mock_agent_card: MagicMock, ): """Test X-A2A-Extensions header in send_message_streaming.""" - extensions = ['test_extension'] + extensions = ['https://example.com/test-ext/v1'] client = JsonRpcTransport( httpx_client=mock_httpx_client, agent_card=mock_agent_card, @@ -866,4 +869,6 @@ async def test_send_message_streaming_with_extensions( headers = kwargs.get('headers', {}) assert HTTP_EXTENSION_HEADER in headers - assert headers[HTTP_EXTENSION_HEADER] == 'test_extension' + assert ( + headers[HTTP_EXTENSION_HEADER] == 'https://example.com/test-ext/v1' + ) diff --git a/tests/client/transports/test_rest_client.py b/tests/client/transports/test_rest_client.py index d236fad24..ed8a25e20 100644 --- a/tests/client/transports/test_rest_client.py +++ b/tests/client/transports/test_rest_client.py @@ -38,7 +38,10 @@ async def test_send_message_with_extensions( self, mock_httpx_client: AsyncMock, mock_agent_card: MagicMock ): """Test that send_message adds extensions to headers.""" - extensions = ['test_extension_1', 'test_extension_2'] + extensions = [ + 'https://example.com/test-ext/v1', + 'https://example.com/test-ext/v2', + ] client = RestTransport( httpx_client=mock_httpx_client, extensions=extensions, @@ -71,8 +74,8 @@ async def test_send_message_with_extensions( actual_extensions = set(actual_extensions_list) expected_extensions = { - 'test_extension_1', - 'test_extension_2', + 'https://example.com/test-ext/v1', + 'https://example.com/test-ext/v2', } assert len(actual_extensions_list) == 2 assert actual_extensions == expected_extensions @@ -86,7 +89,7 @@ async def test_send_message_streaming_with_extensions( mock_agent_card: MagicMock, ): """Test X-A2A-Extensions header in send_message_streaming.""" - extensions = ['test_extension'] + extensions = ['https://example.com/test-ext/v1'] client = RestTransport( httpx_client=mock_httpx_client, agent_card=mock_agent_card, @@ -110,4 +113,6 @@ async def test_send_message_streaming_with_extensions( headers = kwargs.get('headers', {}) assert HTTP_EXTENSION_HEADER in headers - assert headers[HTTP_EXTENSION_HEADER] == 'test_extension' + assert ( + headers[HTTP_EXTENSION_HEADER] == 'https://example.com/test-ext/v1' + ) diff --git a/tests/client/transports/test_utils.py b/tests/client/transports/test_utils.py deleted file mode 100644 index e08cff97c..000000000 --- a/tests/client/transports/test_utils.py +++ /dev/null @@ -1,74 +0,0 @@ -import pytest - -from a2a.extensions.common import HTTP_EXTENSION_HEADER -from a2a.client.transports.utils import update_extension_header -from a2a.utils import proto_utils - - -class TestUtils: - @pytest.mark.parametrize( - 'extensions, existing_header, expected_extensions, expected_count', - [ - ( - ['test_extension_1', 'test_extension_2'], - '', - { - 'test_extension_1', - 'test_extension_2', - }, - 2, - ), - ( - ['test_extension_1', 'test_extension_2'], - 'test_extension_2, test_extension_3', - { - 'test_extension_1', - 'test_extension_2', - 'test_extension_3', - }, - 3, - ), - ( - ['test_extension_1', 'test_extension_2'], - 'test_extension_3', - { - 'test_extension_1', - 'test_extension_2', - 'test_extension_3', - }, - 3, - ), - ], - ) - def test_update_extension_header_merge_with_existing_extensions( - self, - extensions: list[str], - existing_header: str, - expected_extensions: set[str], - expected_count: int, - ): - http_kwargs = {'headers': {HTTP_EXTENSION_HEADER: existing_header}} - result_kwargs = update_extension_header(http_kwargs, extensions) - - header_value = result_kwargs['headers'][HTTP_EXTENSION_HEADER] - actual_extensions_list = [e.strip() for e in header_value.split(',')] - actual_extensions = set(actual_extensions_list) - - assert len(actual_extensions_list) == expected_count - assert actual_extensions == expected_extensions - - def test_update_extension_header_with_other_headers(self): - extensions = ['test_extension'] - http_kwargs = {'headers': {'X_Other': 'Test'}} - result_kwargs = update_extension_header(http_kwargs, extensions) - headers = result_kwargs.get('headers', {}) - assert HTTP_EXTENSION_HEADER in headers - assert headers[HTTP_EXTENSION_HEADER] == 'test_extension' - assert headers['X_Other'] == 'Test' - - @pytest.mark.parametrize('extensions', [(None), ([])]) - def test_update_extension_header_no_or_empty_extensions(self, extensions): - http_kwargs = {'headers': {'X_Other': 'Test'}} - result_kwargs = update_extension_header(http_kwargs, extensions) - assert HTTP_EXTENSION_HEADER not in result_kwargs['headers'] - assert result_kwargs['headers']['X_Other'] == 'Test' diff --git a/tests/extensions/test_common.py b/tests/extensions/test_common.py index 137e64c9a..3b3489f97 100644 --- a/tests/extensions/test_common.py +++ b/tests/extensions/test_common.py @@ -1,6 +1,9 @@ +import pytest from a2a.extensions.common import ( + HTTP_EXTENSION_HEADER, find_extension_by_uri, get_requested_extensions, + update_extension_header, ) from a2a.types import AgentCapabilities, AgentCard, AgentExtension @@ -56,3 +59,70 @@ def test_find_extension_by_uri_no_extensions(): ) assert find_extension_by_uri(card, 'foo') is None + + +@pytest.mark.parametrize( + 'extensions, existing_header, expected_extensions, expected_count', + [ + ( + ['test_extension_1', 'test_extension_2'], + '', + { + 'test_extension_1', + 'test_extension_2', + }, + 2, + ), + ( + ['test_extension_1', 'test_extension_2'], + 'test_extension_2, test_extension_3', + { + 'test_extension_1', + 'test_extension_2', + 'test_extension_3', + }, + 3, + ), + ( + ['test_extension_1', 'test_extension_2'], + 'test_extension_3', + { + 'test_extension_1', + 'test_extension_2', + 'test_extension_3', + }, + 3, + ), + ], +) +def test_update_extension_header_merge_with_existing_extensions( + extensions: list[str], + existing_header: str, + expected_extensions: set[str], + expected_count: int, +): + http_kwargs = {'headers': {HTTP_EXTENSION_HEADER: existing_header}} + result_kwargs = update_extension_header(http_kwargs, extensions) + header_value = result_kwargs['headers'][HTTP_EXTENSION_HEADER] + actual_extensions_list = [e.strip() for e in header_value.split(',')] + actual_extensions = set(actual_extensions_list) + assert len(actual_extensions_list) == expected_count + assert actual_extensions == expected_extensions + + +def test_update_extension_header_with_other_headers(): + extensions = ['test_extension'] + http_kwargs = {'headers': {'X_Other': 'Test'}} + result_kwargs = update_extension_header(http_kwargs, extensions) + headers = result_kwargs.get('headers', {}) + assert HTTP_EXTENSION_HEADER in headers + assert headers[HTTP_EXTENSION_HEADER] == 'test_extension' + assert headers['X_Other'] == 'Test' + + +@pytest.mark.parametrize('extensions', [(None), ([])]) +def test_update_extension_header_no_or_empty_extensions(extensions): + http_kwargs = {'headers': {'X_Other': 'Test'}} + result_kwargs = update_extension_header(http_kwargs, extensions) + assert HTTP_EXTENSION_HEADER not in result_kwargs['headers'] + assert result_kwargs['headers']['X_Other'] == 'Test' From 5b475622369be36e3b53e2729d288c4d0d2f6766 Mon Sep 17 00:00:00 2001 From: sokoliva Date: Wed, 12 Nov 2025 16:27:27 +0000 Subject: [PATCH 19/26] fix(client): clarify the purpose of the extensions parameter in Client constructor --- src/a2a/client/client.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/a2a/client/client.py b/src/a2a/client/client.py index 90a620396..0ecdfea5c 100644 --- a/src/a2a/client/client.py +++ b/src/a2a/client/client.py @@ -100,6 +100,7 @@ def __init__( Args: consumers: A list of callables to process events from the agent. middleware: A list of interceptors to process requests and responses. + extensions: A list of extension URIs the client supports. """ if middleware is None: middleware = [] From a2eeb7b8b42b631f4bf96b6176968053b7427681 Mon Sep 17 00:00:00 2001 From: sokoliva Date: Thu, 13 Nov 2025 11:00:27 +0000 Subject: [PATCH 20/26] feat: enhance extension handling across client and transport layers --- src/a2a/client/base_client.py | 49 +++++++++--- src/a2a/client/client.py | 16 ++-- src/a2a/client/client_factory.py | 1 - src/a2a/client/transports/base.py | 8 ++ src/a2a/client/transports/grpc.py | 34 ++++++--- src/a2a/client/transports/jsonrpc.py | 40 ++++++---- src/a2a/client/transports/rest.py | 45 +++++++---- src/a2a/extensions/common.py | 14 ++-- tests/client/test_base_client.py | 1 - tests/client/test_client_factory.py | 4 + tests/client/transports/test_grpc_client.py | 70 ++++++++++++++++-- tests/extensions/test_common.py | 82 +++++++++++++-------- 12 files changed, 261 insertions(+), 103 deletions(-) diff --git a/src/a2a/client/base_client.py b/src/a2a/client/base_client.py index 77feea889..5719bc1b0 100644 --- a/src/a2a/client/base_client.py +++ b/src/a2a/client/base_client.py @@ -37,9 +37,8 @@ def __init__( transport: ClientTransport, consumers: list[Consumer], middleware: list[ClientCallInterceptor], - extensions: list[str], ): - super().__init__(consumers, middleware, extensions) + super().__init__(consumers, middleware) self._card = card self._config = config self._transport = transport @@ -50,6 +49,7 @@ async def send_message( *, context: ClientCallContext | None = None, request_metadata: dict[str, Any] | None = None, + extensions: list[str] | None = None, ) -> AsyncIterator[ClientEvent | Message]: """Sends a message to the agent. @@ -61,6 +61,7 @@ async def send_message( request: The message to send to the agent. context: The client call context. request_metadata: Extensions Metadata attached to the request. + extensions: List of extensions to be activated. Yields: An async iterator of `ClientEvent` or a final `Message` response. @@ -80,7 +81,7 @@ async def send_message( if not self._config.streaming or not self._card.capabilities.streaming: response = await self._transport.send_message( - params, context=context + params, context=context, extensions=extensions ) result = ( (response, None) if isinstance(response, Task) else response @@ -90,7 +91,9 @@ async def send_message( return tracker = ClientTaskManager() - stream = self._transport.send_message_streaming(params, context=context) + stream = self._transport.send_message_streaming( + params, context=context, extensions=extensions + ) first_event = await anext(stream) # The response from a server may be either exactly one Message or a @@ -127,74 +130,91 @@ async def get_task( request: TaskQueryParams, *, context: ClientCallContext | None = None, + extensions: list[str] | None = None, ) -> Task: """Retrieves the current state and history of a specific task. Args: request: The `TaskQueryParams` object specifying the task ID. context: The client call context. + extensions: List of extensions to be activated. Returns: A `Task` object representing the current state of the task. """ - return await self._transport.get_task(request, context=context) + return await self._transport.get_task( + request, context=context, extensions=extensions + ) async def cancel_task( self, request: TaskIdParams, *, context: ClientCallContext | None = None, + extensions: list[str] | None = None, ) -> Task: """Requests the agent to cancel a specific task. Args: request: The `TaskIdParams` object specifying the task ID. context: The client call context. + extensions: List of extensions to be activated. Returns: A `Task` object containing the updated task status. """ - return await self._transport.cancel_task(request, context=context) + return await self._transport.cancel_task( + request, context=context, extensions=extensions + ) async def set_task_callback( self, request: TaskPushNotificationConfig, *, context: ClientCallContext | None = None, + extensions: list[str] | None = None, ) -> TaskPushNotificationConfig: """Sets or updates the push notification configuration for a specific task. Args: request: The `TaskPushNotificationConfig` object with the new configuration. context: The client call context. + extensions: List of extensions to be activated. Returns: The created or updated `TaskPushNotificationConfig` object. """ - return await self._transport.set_task_callback(request, context=context) + return await self._transport.set_task_callback( + request, context=context, extensions=extensions + ) async def get_task_callback( self, request: GetTaskPushNotificationConfigParams, *, context: ClientCallContext | None = None, + extensions: list[str] | None = None, ) -> TaskPushNotificationConfig: """Retrieves the push notification configuration for a specific task. Args: request: The `GetTaskPushNotificationConfigParams` object specifying the task. context: The client call context. + extensions: List of extensions to be activated. Returns: A `TaskPushNotificationConfig` object containing the configuration. """ - return await self._transport.get_task_callback(request, context=context) + return await self._transport.get_task_callback( + request, context=context, extensions=extensions + ) async def resubscribe( self, request: TaskIdParams, *, context: ClientCallContext | None = None, + extensions: list[str] | None = None, ) -> AsyncIterator[ClientEvent]: """Resubscribes to a task's event stream. @@ -203,6 +223,7 @@ async def resubscribe( Args: request: Parameters to identify the task to resubscribe to. context: The client call context. + extensions: List of extensions to be activated. Yields: An async iterator of `ClientEvent` objects. @@ -220,12 +241,15 @@ async def resubscribe( # we should never see Message updates, despite the typing of the service # definition indicating it may be possible. async for event in self._transport.resubscribe( - request, context=context + request, context=context, extensions=extensions ): yield await self._process_response(tracker, event) async def get_card( - self, *, context: ClientCallContext | None = None + self, + *, + context: ClientCallContext | None = None, + extensions: list[str] | None = None, ) -> AgentCard: """Retrieves the agent's card. @@ -234,11 +258,14 @@ async def get_card( Args: context: The client call context. + extensions: List of extensions to be activated. Returns: The `AgentCard` for the agent. """ - card = await self._transport.get_card(context=context) + card = await self._transport.get_card( + context=context, extensions=extensions + ) self._card = card return card diff --git a/src/a2a/client/client.py b/src/a2a/client/client.py index 0ecdfea5c..fd97b4d14 100644 --- a/src/a2a/client/client.py +++ b/src/a2a/client/client.py @@ -93,24 +93,19 @@ def __init__( self, consumers: list[Consumer] | None = None, middleware: list[ClientCallInterceptor] | None = None, - extensions: list[str] | None = None, ): """Initializes the client with consumers and middleware. Args: consumers: A list of callables to process events from the agent. middleware: A list of interceptors to process requests and responses. - extensions: A list of extension URIs the client supports. """ if middleware is None: middleware = [] if consumers is None: consumers = [] - if extensions is None: - extensions = [] self._consumers = consumers self._middleware = middleware - self._extensions = extensions @abstractmethod async def send_message( @@ -119,6 +114,7 @@ async def send_message( *, context: ClientCallContext | None = None, request_metadata: dict[str, Any] | None = None, + extensions: list[str] | None = None, ) -> AsyncIterator[ClientEvent | Message]: """Sends a message to the server. @@ -137,6 +133,7 @@ async def get_task( request: TaskQueryParams, *, context: ClientCallContext | None = None, + extensions: list[str] | None = None, ) -> Task: """Retrieves the current state and history of a specific task.""" @@ -146,6 +143,7 @@ async def cancel_task( request: TaskIdParams, *, context: ClientCallContext | None = None, + extensions: list[str] | None = None, ) -> Task: """Requests the agent to cancel a specific task.""" @@ -155,6 +153,7 @@ async def set_task_callback( request: TaskPushNotificationConfig, *, context: ClientCallContext | None = None, + extensions: list[str] | None = None, ) -> TaskPushNotificationConfig: """Sets or updates the push notification configuration for a specific task.""" @@ -164,6 +163,7 @@ async def get_task_callback( request: GetTaskPushNotificationConfigParams, *, context: ClientCallContext | None = None, + extensions: list[str] | None = None, ) -> TaskPushNotificationConfig: """Retrieves the push notification configuration for a specific task.""" @@ -173,6 +173,7 @@ async def resubscribe( request: TaskIdParams, *, context: ClientCallContext | None = None, + extensions: list[str] | None = None, ) -> AsyncIterator[ClientEvent]: """Resubscribes to a task's event stream.""" return @@ -180,7 +181,10 @@ async def resubscribe( @abstractmethod async def get_card( - self, *, context: ClientCallContext | None = None + self, + *, + context: ClientCallContext | None = None, + extensions: list[str] | None = None, ) -> AgentCard: """Retrieves the agent's card.""" diff --git a/src/a2a/client/client_factory.py b/src/a2a/client/client_factory.py index b6f534d6f..ead31b5ec 100644 --- a/src/a2a/client/client_factory.py +++ b/src/a2a/client/client_factory.py @@ -245,7 +245,6 @@ def create( transport, all_consumers, interceptors or [], - all_extensions, ) diff --git a/src/a2a/client/transports/base.py b/src/a2a/client/transports/base.py index 3573cb7ca..8f114d95d 100644 --- a/src/a2a/client/transports/base.py +++ b/src/a2a/client/transports/base.py @@ -25,6 +25,7 @@ async def send_message( request: MessageSendParams, *, context: ClientCallContext | None = None, + extensions: list[str] | None = None, ) -> Task | Message: """Sends a non-streaming message request to the agent.""" @@ -34,6 +35,7 @@ async def send_message_streaming( request: MessageSendParams, *, context: ClientCallContext | None = None, + extensions: list[str] | None = None, ) -> AsyncGenerator[ Message | Task | TaskStatusUpdateEvent | TaskArtifactUpdateEvent ]: @@ -47,6 +49,7 @@ async def get_task( request: TaskQueryParams, *, context: ClientCallContext | None = None, + extensions: list[str] | None = None, ) -> Task: """Retrieves the current state and history of a specific task.""" @@ -56,6 +59,7 @@ async def cancel_task( request: TaskIdParams, *, context: ClientCallContext | None = None, + extensions: list[str] | None = None, ) -> Task: """Requests the agent to cancel a specific task.""" @@ -65,6 +69,7 @@ async def set_task_callback( request: TaskPushNotificationConfig, *, context: ClientCallContext | None = None, + extensions: list[str] | None = None, ) -> TaskPushNotificationConfig: """Sets or updates the push notification configuration for a specific task.""" @@ -74,6 +79,7 @@ async def get_task_callback( request: GetTaskPushNotificationConfigParams, *, context: ClientCallContext | None = None, + extensions: list[str] | None = None, ) -> TaskPushNotificationConfig: """Retrieves the push notification configuration for a specific task.""" @@ -83,6 +89,7 @@ async def resubscribe( request: TaskIdParams, *, context: ClientCallContext | None = None, + extensions: list[str] | None = None, ) -> AsyncGenerator[ Task | Message | TaskStatusUpdateEvent | TaskArtifactUpdateEvent ]: @@ -95,6 +102,7 @@ async def get_card( self, *, context: ClientCallContext | None = None, + extensions: list[str] | None = None, ) -> AgentCard: """Retrieves the AgentCard.""" diff --git a/src/a2a/client/transports/grpc.py b/src/a2a/client/transports/grpc.py index ee81da99b..b23e379cd 100644 --- a/src/a2a/client/transports/grpc.py +++ b/src/a2a/client/transports/grpc.py @@ -59,8 +59,13 @@ def __init__( ) self.extensions = extensions - def _get_grpc_metadata(self) -> list[tuple[str, str]] | None: + def _get_grpc_metadata( + self, + extensions: list[str] | None = None, + ) -> list[tuple[str, str]] | None: """Creates gRPC metadata for extensions.""" + if extensions: + self.extensions = extensions if not self.extensions: return None return [(HTTP_EXTENSION_HEADER, ', '.join(self.extensions))] @@ -83,6 +88,7 @@ async def send_message( request: MessageSendParams, *, context: ClientCallContext | None = None, + extensions: list[str] | None = None, ) -> Task | Message: """Sends a non-streaming message request to the agent.""" response = await self.stub.SendMessage( @@ -93,7 +99,7 @@ async def send_message( ), metadata=proto_utils.ToProto.metadata(request.metadata), ), - metadata=self._get_grpc_metadata(), + metadata=self._get_grpc_metadata(extensions), ) if response.HasField('task'): return proto_utils.FromProto.task(response.task) @@ -104,6 +110,7 @@ async def send_message_streaming( request: MessageSendParams, *, context: ClientCallContext | None = None, + extensions: list[str] | None = None, ) -> AsyncGenerator[ Message | Task | TaskStatusUpdateEvent | TaskArtifactUpdateEvent ]: @@ -116,7 +123,7 @@ async def send_message_streaming( ), metadata=proto_utils.ToProto.metadata(request.metadata), ), - metadata=self._get_grpc_metadata(), + metadata=self._get_grpc_metadata(extensions), ) while True: response = await stream.read() @@ -125,14 +132,18 @@ async def send_message_streaming( yield proto_utils.FromProto.stream_response(response) async def resubscribe( - self, request: TaskIdParams, *, context: ClientCallContext | None = None + self, + request: TaskIdParams, + *, + context: ClientCallContext | None = None, + extensions: list[str] | None = None, ) -> AsyncGenerator[ Task | Message | TaskStatusUpdateEvent | TaskArtifactUpdateEvent ]: """Reconnects to get task updates.""" stream = self.stub.TaskSubscription( a2a_pb2.TaskSubscriptionRequest(name=f'tasks/{request.id}'), - metadata=self._get_grpc_metadata(), + metadata=self._get_grpc_metadata(extensions), ) while True: response = await stream.read() @@ -145,6 +156,7 @@ async def get_task( request: TaskQueryParams, *, context: ClientCallContext | None = None, + extensions: list[str] | None = None, ) -> Task: """Retrieves the current state and history of a specific task.""" task = await self.stub.GetTask( @@ -152,7 +164,7 @@ async def get_task( name=f'tasks/{request.id}', history_length=request.history_length, ), - metadata=self._get_grpc_metadata(), + metadata=self._get_grpc_metadata(extensions), ) return proto_utils.FromProto.task(task) @@ -161,11 +173,12 @@ async def cancel_task( request: TaskIdParams, *, context: ClientCallContext | None = None, + extensions: list[str] | None = None, ) -> Task: """Requests the agent to cancel a specific task.""" task = await self.stub.CancelTask( a2a_pb2.CancelTaskRequest(name=f'tasks/{request.id}'), - metadata=self._get_grpc_metadata(), + metadata=self._get_grpc_metadata(extensions), ) return proto_utils.FromProto.task(task) @@ -174,6 +187,7 @@ async def set_task_callback( request: TaskPushNotificationConfig, *, context: ClientCallContext | None = None, + extensions: list[str] | None = None, ) -> TaskPushNotificationConfig: """Sets or updates the push notification configuration for a specific task.""" config = await self.stub.CreateTaskPushNotificationConfig( @@ -184,7 +198,7 @@ async def set_task_callback( request ), ), - metadata=self._get_grpc_metadata(), + metadata=self._get_grpc_metadata(extensions), ) return proto_utils.FromProto.task_push_notification_config(config) @@ -193,13 +207,14 @@ async def get_task_callback( request: GetTaskPushNotificationConfigParams, *, context: ClientCallContext | None = None, + extensions: list[str] | None = None, ) -> TaskPushNotificationConfig: """Retrieves the push notification configuration for a specific task.""" config = await self.stub.GetTaskPushNotificationConfig( a2a_pb2.GetTaskPushNotificationConfigRequest( name=f'tasks/{request.id}/pushNotificationConfigs/{request.push_notification_config_id}', ), - metadata=self._get_grpc_metadata(), + metadata=self._get_grpc_metadata(extensions), ) return proto_utils.FromProto.task_push_notification_config(config) @@ -207,6 +222,7 @@ async def get_card( self, *, context: ClientCallContext | None = None, + extensions: list[str] | None = None, ) -> AgentCard: """Retrieves the agent's card.""" card = self.agent_card diff --git a/src/a2a/client/transports/jsonrpc.py b/src/a2a/client/transports/jsonrpc.py index fcd41b7cc..f135eb441 100644 --- a/src/a2a/client/transports/jsonrpc.py +++ b/src/a2a/client/transports/jsonrpc.py @@ -116,6 +116,7 @@ async def send_message( request: MessageSendParams, *, context: ClientCallContext | None = None, + extensions: list[str] | None = None, ) -> Task | Message: """Sends a non-streaming message request to the agent.""" rpc_request = SendMessageRequest(params=request, id=str(uuid4())) @@ -125,8 +126,8 @@ async def send_message( self._get_http_args(context), context, ) - modified_kwargs = update_extension_header( - modified_kwargs, self.extensions + modified_kwargs, self.extensions = update_extension_header( + modified_kwargs, self.extensions, extensions ) response_data = await self._send_request(payload, modified_kwargs) response = SendMessageResponse.model_validate(response_data) @@ -139,6 +140,7 @@ async def send_message_streaming( request: MessageSendParams, *, context: ClientCallContext | None = None, + extensions: list[str] | None = None, ) -> AsyncGenerator[ Message | Task | TaskStatusUpdateEvent | TaskArtifactUpdateEvent ]: @@ -153,8 +155,8 @@ async def send_message_streaming( context, ) - modified_kwargs = update_extension_header( - modified_kwargs, self.extensions + modified_kwargs, self.extensions = update_extension_header( + modified_kwargs, self.extensions, extensions ) modified_kwargs.setdefault( 'timeout', self.httpx_client.timeout.as_dict().get('read', None) @@ -216,6 +218,7 @@ async def get_task( request: TaskQueryParams, *, context: ClientCallContext | None = None, + extensions: list[str] | None = None, ) -> Task: """Retrieves the current state and history of a specific task.""" rpc_request = GetTaskRequest(params=request, id=str(uuid4())) @@ -225,8 +228,8 @@ async def get_task( self._get_http_args(context), context, ) - modified_kwargs = update_extension_header( - modified_kwargs, self.extensions + modified_kwargs, self.extensions = update_extension_header( + modified_kwargs, self.extensions, extensions ) response_data = await self._send_request(payload, modified_kwargs) response = GetTaskResponse.model_validate(response_data) @@ -239,6 +242,7 @@ async def cancel_task( request: TaskIdParams, *, context: ClientCallContext | None = None, + extensions: list[str] | None = None, ) -> Task: """Requests the agent to cancel a specific task.""" rpc_request = CancelTaskRequest(params=request, id=str(uuid4())) @@ -248,8 +252,8 @@ async def cancel_task( self._get_http_args(context), context, ) - modified_kwargs = update_extension_header( - modified_kwargs, self.extensions + modified_kwargs, self.extensions = update_extension_header( + modified_kwargs, self.extensions, extensions ) response_data = await self._send_request(payload, modified_kwargs) response = CancelTaskResponse.model_validate(response_data) @@ -262,6 +266,7 @@ async def set_task_callback( request: TaskPushNotificationConfig, *, context: ClientCallContext | None = None, + extensions: list[str] | None = None, ) -> TaskPushNotificationConfig: """Sets or updates the push notification configuration for a specific task.""" rpc_request = SetTaskPushNotificationConfigRequest( @@ -273,8 +278,8 @@ async def set_task_callback( self._get_http_args(context), context, ) - modified_kwargs = update_extension_header( - modified_kwargs, self.extensions + modified_kwargs, self.extensions = update_extension_header( + modified_kwargs, self.extensions, extensions ) response_data = await self._send_request(payload, modified_kwargs) response = SetTaskPushNotificationConfigResponse.model_validate( @@ -289,6 +294,7 @@ async def get_task_callback( request: GetTaskPushNotificationConfigParams, *, context: ClientCallContext | None = None, + extensions: list[str] | None = None, ) -> TaskPushNotificationConfig: """Retrieves the push notification configuration for a specific task.""" rpc_request = GetTaskPushNotificationConfigRequest( @@ -300,8 +306,8 @@ async def get_task_callback( self._get_http_args(context), context, ) - modified_kwargs = update_extension_header( - modified_kwargs, self.extensions + modified_kwargs, self.extensions = update_extension_header( + modified_kwargs, self.extensions, extensions ) response_data = await self._send_request(payload, modified_kwargs) response = GetTaskPushNotificationConfigResponse.model_validate( @@ -316,6 +322,7 @@ async def resubscribe( request: TaskIdParams, *, context: ClientCallContext | None = None, + extensions: list[str] | None = None, ) -> AsyncGenerator[ Task | Message | TaskStatusUpdateEvent | TaskArtifactUpdateEvent ]: @@ -327,8 +334,8 @@ async def resubscribe( self._get_http_args(context), context, ) - modified_kwargs = update_extension_header( - modified_kwargs, self.extensions + modified_kwargs, self.extensions = update_extension_header( + modified_kwargs, self.extensions, extensions ) modified_kwargs.setdefault('timeout', None) @@ -362,6 +369,7 @@ async def get_card( self, *, context: ClientCallContext | None = None, + extensions: list[str] | None = None, ) -> AgentCard: """Retrieves the agent's card.""" card = self.agent_card @@ -385,8 +393,8 @@ async def get_card( self._get_http_args(context), context, ) - modified_kwargs = update_extension_header( - modified_kwargs, self.extensions + modified_kwargs, self.extensions = update_extension_header( + modified_kwargs, self.extensions, extensions ) response_data = await self._send_request( diff --git a/src/a2a/client/transports/rest.py b/src/a2a/client/transports/rest.py index e36afffef..7fc41e459 100644 --- a/src/a2a/client/transports/rest.py +++ b/src/a2a/client/transports/rest.py @@ -82,7 +82,10 @@ def _get_http_args( return context.state.get('http_kwargs') if context else None async def _prepare_send_message( - self, request: MessageSendParams, context: ClientCallContext | None + self, + request: MessageSendParams, + context: ClientCallContext | None, + extensions: list[str] | None = None, ) -> tuple[dict[str, Any], dict[str, Any]]: pb = a2a_pb2.SendMessageRequest( request=proto_utils.ToProto.message(request.message), @@ -101,8 +104,8 @@ async def _prepare_send_message( self._get_http_args(context), context, ) - modified_kwargs = update_extension_header( - modified_kwargs, self.extensions + modified_kwargs, self.extensions = update_extension_header( + modified_kwargs, self.extensions, extensions ) return payload, modified_kwargs @@ -111,10 +114,11 @@ async def send_message( request: MessageSendParams, *, context: ClientCallContext | None = None, + extensions: list[str] | None = None, ) -> Task | Message: """Sends a non-streaming message request to the agent.""" payload, modified_kwargs = await self._prepare_send_message( - request, context + request, context, extensions ) response_data = await self._send_post_request( '/v1/message:send', payload, modified_kwargs @@ -128,12 +132,13 @@ async def send_message_streaming( request: MessageSendParams, *, context: ClientCallContext | None = None, + extensions: list[str] | None = None, ) -> AsyncGenerator[ Task | TaskStatusUpdateEvent | TaskArtifactUpdateEvent | Message ]: """Sends a streaming message request to the agent and yields responses as they arrive.""" payload, modified_kwargs = await self._prepare_send_message( - request, context + request, context, extensions ) modified_kwargs.setdefault('timeout', None) @@ -210,6 +215,7 @@ async def get_task( request: TaskQueryParams, *, context: ClientCallContext | None = None, + extensions: list[str] | None = None, ) -> Task: """Retrieves the current state and history of a specific task.""" _payload, modified_kwargs = await self._apply_interceptors( @@ -217,8 +223,8 @@ async def get_task( self._get_http_args(context), context, ) - modified_kwargs = update_extension_header( - modified_kwargs, self.extensions + modified_kwargs, self.extensions = update_extension_header( + modified_kwargs, self.extensions, extensions ) response_data = await self._send_get_request( f'/v1/tasks/{request.id}', @@ -236,6 +242,7 @@ async def cancel_task( request: TaskIdParams, *, context: ClientCallContext | None = None, + extensions: list[str] | None = None, ) -> Task: """Requests the agent to cancel a specific task.""" pb = a2a_pb2.CancelTaskRequest(name=f'tasks/{request.id}') @@ -245,8 +252,8 @@ async def cancel_task( self._get_http_args(context), context, ) - modified_kwargs = update_extension_header( - modified_kwargs, self.extensions + modified_kwargs, self.extensions = update_extension_header( + modified_kwargs, self.extensions, extensions ) response_data = await self._send_post_request( f'/v1/tasks/{request.id}:cancel', payload, modified_kwargs @@ -260,6 +267,7 @@ async def set_task_callback( request: TaskPushNotificationConfig, *, context: ClientCallContext | None = None, + extensions: list[str] | None = None, ) -> TaskPushNotificationConfig: """Sets or updates the push notification configuration for a specific task.""" pb = a2a_pb2.CreateTaskPushNotificationConfigRequest( @@ -271,8 +279,8 @@ async def set_task_callback( payload, modified_kwargs = await self._apply_interceptors( payload, self._get_http_args(context), context ) - modified_kwargs = update_extension_header( - modified_kwargs, self.extensions + modified_kwargs, self.extensions = update_extension_header( + modified_kwargs, self.extensions, extensions ) response_data = await self._send_post_request( f'/v1/tasks/{request.task_id}/pushNotificationConfigs', @@ -288,6 +296,7 @@ async def get_task_callback( request: GetTaskPushNotificationConfigParams, *, context: ClientCallContext | None = None, + extensions: list[str] | None = None, ) -> TaskPushNotificationConfig: """Retrieves the push notification configuration for a specific task.""" pb = a2a_pb2.GetTaskPushNotificationConfigRequest( @@ -299,8 +308,8 @@ async def get_task_callback( self._get_http_args(context), context, ) - modified_kwargs = update_extension_header( - modified_kwargs, self.extensions + modified_kwargs, self.extensions = update_extension_header( + modified_kwargs, self.extensions, extensions ) response_data = await self._send_get_request( f'/v1/tasks/{request.id}/pushNotificationConfigs/{request.push_notification_config_id}', @@ -316,13 +325,16 @@ async def resubscribe( request: TaskIdParams, *, context: ClientCallContext | None = None, + extensions: list[str] | None = None, ) -> AsyncGenerator[ Task | TaskStatusUpdateEvent | TaskArtifactUpdateEvent | Message ]: """Reconnects to get task updates.""" http_kwargs = self._get_http_args(context) or {} http_kwargs.setdefault('timeout', None) - modified_kwargs = update_extension_header(http_kwargs, self.extensions) + modified_kwargs, self.extensions = update_extension_header( + http_kwargs, self.extensions, extensions + ) async with aconnect_sse( self.httpx_client, @@ -350,6 +362,7 @@ async def get_card( self, *, context: ClientCallContext | None = None, + extensions: list[str] | None = None, ) -> AgentCard: """Retrieves the agent's card.""" card = self.agent_card @@ -371,8 +384,8 @@ async def get_card( self._get_http_args(context), context, ) - modified_kwargs = update_extension_header( - modified_kwargs, self.extensions + modified_kwargs, self.extensions = update_extension_header( + modified_kwargs, self.extensions, extensions ) response_data = await self._send_get_request( '/v1/card', {}, modified_kwargs diff --git a/src/a2a/extensions/common.py b/src/a2a/extensions/common.py index 8f8749235..d665c66fa 100644 --- a/src/a2a/extensions/common.py +++ b/src/a2a/extensions/common.py @@ -30,15 +30,19 @@ def find_extension_by_uri(card: AgentCard, uri: str) -> AgentExtension | None: def update_extension_header( - http_kwargs: dict[str, Any], extensions: list[str] | None -) -> dict[str, Any]: - if extensions: + http_kwargs: dict[str, Any], + active_extensions: list[str] | None, + new_extensions: list[str] | None, +) -> tuple[dict[str, Any], list[str] | None]: + if new_extensions: + active_extensions = new_extensions + if active_extensions: headers = http_kwargs.setdefault('headers', {}) existing_extensions_str = headers.get(HTTP_EXTENSION_HEADER, '') existing_extensions = get_requested_extensions( [existing_extensions_str] ) - all_extensions = existing_extensions.union(extensions) + all_extensions = existing_extensions.union(active_extensions) headers[HTTP_EXTENSION_HEADER] = ','.join(all_extensions) - return http_kwargs + return http_kwargs, active_extensions diff --git a/tests/client/test_base_client.py b/tests/client/test_base_client.py index 19ef049e4..f5ab25432 100644 --- a/tests/client/test_base_client.py +++ b/tests/client/test_base_client.py @@ -57,7 +57,6 @@ def base_client( transport=mock_transport, consumers=[], middleware=[], - extensions=[], ) diff --git a/tests/client/test_client_factory.py b/tests/client/test_client_factory.py index 847b256fa..16a1433fb 100644 --- a/tests/client/test_client_factory.py +++ b/tests/client/test_client_factory.py @@ -39,12 +39,14 @@ def test_client_factory_selects_preferred_transport(base_agent_card: AgentCard): TransportProtocol.jsonrpc, TransportProtocol.http_json, ], + extensions=['https://example.com/test-ext/v0'], ) factory = ClientFactory(config) client = factory.create(base_agent_card) assert isinstance(client._transport, JsonRpcTransport) assert client._transport.url == 'http://primary-url.com' + assert ['https://example.com/test-ext/v0'] == client._transport.extensions def test_client_factory_selects_secondary_transport_url( @@ -65,12 +67,14 @@ def test_client_factory_selects_secondary_transport_url( TransportProtocol.jsonrpc, ], use_client_preference=True, + extensions=['https://example.com/test-ext/v0'], ) factory = ClientFactory(config) client = factory.create(base_agent_card) assert isinstance(client._transport, RestTransport) assert client._transport.url == 'http://secondary-url.com' + assert ['https://example.com/test-ext/v0'] == client._transport.extensions def test_client_factory_server_preference(base_agent_card: AgentCard): diff --git a/tests/client/transports/test_grpc_client.py b/tests/client/transports/test_grpc_client.py index dce554f69..0e4fba94c 100644 --- a/tests/client/transports/test_grpc_client.py +++ b/tests/client/transports/test_grpc_client.py @@ -353,17 +353,14 @@ async def test_cancel_task( cancelled_task ) params = TaskIdParams(id=sample_task.id) - - response = await grpc_transport.cancel_task(params) + extensions = [ + 'https://example.com/test-ext/v3', + ] + response = await grpc_transport.cancel_task(params, extensions=extensions) mock_grpc_stub.CancelTask.assert_awaited_once_with( a2a_pb2.CancelTaskRequest(name=f'tasks/{sample_task.id}'), - metadata=[ - ( - HTTP_EXTENSION_HEADER, - 'https://example.com/test-ext/v1, https://example.com/test-ext/v2', - ) - ], + metadata=[(HTTP_EXTENSION_HEADER, 'https://example.com/test-ext/v3')], ) assert response.status.state == TaskState.canceled @@ -493,3 +490,60 @@ async def test_get_task_callback_with_invalid_task( 'Bad TaskPushNotificationConfig resource name' in exc_info.value.error.message ) + + +@pytest.mark.parametrize( + 'initial_extensions, input_extensions, expected_metadata, expected_extensions', + [ + ( + None, + None, + None, + None, + ), # Case 1: No initial, No input + ( + ['ext1'], + None, + [(HTTP_EXTENSION_HEADER, 'ext1')], + ['ext1'], + ), # Case 2: Initial, No input + ( + None, + ['ext2'], + [(HTTP_EXTENSION_HEADER, 'ext2')], + ['ext2'], + ), # Case 3: No initial, Input + ( + ['ext1'], + ['ext2'], + [(HTTP_EXTENSION_HEADER, 'ext2')], + ['ext2'], + ), # Case 4: Initial, Input (override) + ( + ['ext1'], + ['ext2', 'ext3'], + [(HTTP_EXTENSION_HEADER, 'ext2, ext3')], + ['ext2', 'ext3'], + ), # Case 5: Initial, Multiple inputs (override) + ( + ['ext1', 'ext2'], + ['ext3'], + [(HTTP_EXTENSION_HEADER, 'ext3')], + ['ext3'], + ), # Case 6: Multiple initial, Single input (override) + ], +) +def test_get_grpc_metadata( + grpc_transport: GrpcTransport, + initial_extensions: list[str] | None, + input_extensions: list[str] | None, + expected_metadata: list[tuple[str, str]] | None, + expected_extensions: list[str] | None, +) -> None: + """Tests _get_grpc_metadata for correct metadata generation and self.extensions update.""" + grpc_transport.extensions = initial_extensions + + metadata = grpc_transport._get_grpc_metadata(extensions=input_extensions) + + assert metadata == expected_metadata + assert grpc_transport.extensions == expected_extensions diff --git a/tests/extensions/test_common.py b/tests/extensions/test_common.py index 3b3489f97..753386bcb 100644 --- a/tests/extensions/test_common.py +++ b/tests/extensions/test_common.py @@ -62,67 +62,89 @@ def test_find_extension_by_uri_no_extensions(): @pytest.mark.parametrize( - 'extensions, existing_header, expected_extensions, expected_count', + 'active_extensions, new_extensions, existing_header, expected_extensions, expected_count, expected_returned_extensions', [ ( - ['test_extension_1', 'test_extension_2'], - '', + ['ext1', 'ext2'], # active_extensions + None, # new_extensions + '', # existing_header { - 'test_extension_1', - 'test_extension_2', - }, - 2, - ), + 'ext1', + 'ext2', + }, # expected_extensions + 2, # expected_count + ['ext1', 'ext2'], # expected_returned_extensions + ), # Case 1: Active extensions, no new extensions, and no existing header. ( - ['test_extension_1', 'test_extension_2'], - 'test_extension_2, test_extension_3', + ['ext1', 'ext2'], # active_extensions + None, # new_extensions + 'ext2, ext3', # existing_header { - 'test_extension_1', - 'test_extension_2', - 'test_extension_3', - }, - 3, - ), + 'ext1', + 'ext2', + 'ext3', + }, # expected_extensions + 3, # expected_count + ['ext1', 'ext2'], # expected_returned_extensions + ), # Case 2: Active extensions, no new extensions, with an existing header containing overlapping and new extensions. ( - ['test_extension_1', 'test_extension_2'], - 'test_extension_3', + ['ext1', 'ext2'], # active_extensions + None, # new_extensions + 'ext3', # existing_header { - 'test_extension_1', - 'test_extension_2', - 'test_extension_3', - }, - 3, - ), + 'ext1', + 'ext2', + 'ext3', + }, # expected_extensions + 3, # expected_count + ['ext1', 'ext2'], # expected_returned_extensions + ), # Case 3: Active extensions, no new extensions, with an existing header containing different extensions. + ( + ['ext1', 'ext2'], # active_extensions + ['ext3'], # new_extensions + 'ext4', # existing_header + { + 'ext3', + 'ext4', + }, # expected_extensions + 2, # expected_count + ['ext3'], # expected_returned_extensions + ), # Case 4: Active extensions, new extensions provided, and an existing header. New extensions should override active and merge with existing. ], ) def test_update_extension_header_merge_with_existing_extensions( - extensions: list[str], + active_extensions: list[str], + new_extensions: list[str], existing_header: str, expected_extensions: set[str], expected_count: int, + expected_returned_extensions: list[str], ): http_kwargs = {'headers': {HTTP_EXTENSION_HEADER: existing_header}} - result_kwargs = update_extension_header(http_kwargs, extensions) + result_kwargs, actual_returned_extensions = update_extension_header( + http_kwargs, active_extensions, new_extensions + ) header_value = result_kwargs['headers'][HTTP_EXTENSION_HEADER] actual_extensions_list = [e.strip() for e in header_value.split(',')] actual_extensions = set(actual_extensions_list) assert len(actual_extensions_list) == expected_count assert actual_extensions == expected_extensions + assert actual_returned_extensions == expected_returned_extensions def test_update_extension_header_with_other_headers(): - extensions = ['test_extension'] + extensions = ['ext'] http_kwargs = {'headers': {'X_Other': 'Test'}} - result_kwargs = update_extension_header(http_kwargs, extensions) + result_kwargs, _ = update_extension_header(http_kwargs, extensions, None) headers = result_kwargs.get('headers', {}) assert HTTP_EXTENSION_HEADER in headers - assert headers[HTTP_EXTENSION_HEADER] == 'test_extension' + assert headers[HTTP_EXTENSION_HEADER] == 'ext' assert headers['X_Other'] == 'Test' @pytest.mark.parametrize('extensions', [(None), ([])]) def test_update_extension_header_no_or_empty_extensions(extensions): http_kwargs = {'headers': {'X_Other': 'Test'}} - result_kwargs = update_extension_header(http_kwargs, extensions) + result_kwargs, _ = update_extension_header(http_kwargs, extensions, None) assert HTTP_EXTENSION_HEADER not in result_kwargs['headers'] assert result_kwargs['headers']['X_Other'] == 'Test' From 07465416aed9a166aa4a774c547399b41f9132ac Mon Sep 17 00:00:00 2001 From: sokoliva Date: Thu, 13 Nov 2025 11:08:27 +0000 Subject: [PATCH 21/26] feat: add extensions parameter documentation in ClientFactory and update header function docstring --- src/a2a/client/client_factory.py | 2 ++ src/a2a/extensions/common.py | 1 + 2 files changed, 3 insertions(+) diff --git a/src/a2a/client/client_factory.py b/src/a2a/client/client_factory.py index ead31b5ec..fabd7270f 100644 --- a/src/a2a/client/client_factory.py +++ b/src/a2a/client/client_factory.py @@ -145,6 +145,7 @@ async def connect( # noqa: PLR0913 A2AAgentCardResolver.get_agent_card as the http_kwargs parameter. extra_transports: Additional transport protocols to enable when constructing the client. + extensions: List of extensions to be activated. Returns: A `Client` object. @@ -190,6 +191,7 @@ def create( interceptors: A list of interceptors to use for each request. These are used for things like attaching credentials or http headers to all outbound requests. + extensions: List of extensions to be activated. Returns: A `Client` object. diff --git a/src/a2a/extensions/common.py b/src/a2a/extensions/common.py index d665c66fa..4471e8654 100644 --- a/src/a2a/extensions/common.py +++ b/src/a2a/extensions/common.py @@ -34,6 +34,7 @@ def update_extension_header( active_extensions: list[str] | None, new_extensions: list[str] | None, ) -> tuple[dict[str, Any], list[str] | None]: + """Update the X-A2A-Extensions header and update active extensions.""" if new_extensions: active_extensions = new_extensions if active_extensions: From 1337dcf38137e9aa1f3c8581605a4056499bedef Mon Sep 17 00:00:00 2001 From: sokoliva Date: Fri, 14 Nov 2025 11:01:48 +0000 Subject: [PATCH 22/26] refactor: streamline extension handling in transport classes and update related tests --- src/a2a/client/transports/grpc.py | 11 +-- src/a2a/client/transports/jsonrpc.py | 40 ++++++---- src/a2a/client/transports/rest.py | 35 +++++---- src/a2a/extensions/common.py | 19 ++--- tests/client/transports/test_grpc_client.py | 35 ++++----- .../client/transports/test_jsonrpc_client.py | 11 ++- tests/client/transports/test_rest_client.py | 11 ++- tests/extensions/test_common.py | 74 +++++++------------ 8 files changed, 110 insertions(+), 126 deletions(-) diff --git a/src/a2a/client/transports/grpc.py b/src/a2a/client/transports/grpc.py index b23e379cd..4e27953af 100644 --- a/src/a2a/client/transports/grpc.py +++ b/src/a2a/client/transports/grpc.py @@ -64,11 +64,11 @@ def _get_grpc_metadata( extensions: list[str] | None = None, ) -> list[tuple[str, str]] | None: """Creates gRPC metadata for extensions.""" - if extensions: - self.extensions = extensions - if not self.extensions: - return None - return [(HTTP_EXTENSION_HEADER, ', '.join(self.extensions))] + if extensions is not None: + return [(HTTP_EXTENSION_HEADER, ','.join(extensions))] + if self.extensions is not None: + return [(HTTP_EXTENSION_HEADER, ','.join(self.extensions))] + return None @classmethod def create( @@ -233,6 +233,7 @@ async def get_card( card_pb = await self.stub.GetAgentCard( a2a_pb2.GetAgentCardRequest(), + metadata=self._get_grpc_metadata(extensions), ) card = proto_utils.FromProto.agent_card(card_pb) self.agent_card = card diff --git a/src/a2a/client/transports/jsonrpc.py b/src/a2a/client/transports/jsonrpc.py index f135eb441..71a8fd8c7 100644 --- a/src/a2a/client/transports/jsonrpc.py +++ b/src/a2a/client/transports/jsonrpc.py @@ -126,8 +126,9 @@ async def send_message( self._get_http_args(context), context, ) - modified_kwargs, self.extensions = update_extension_header( - modified_kwargs, self.extensions, extensions + modified_kwargs = update_extension_header( + modified_kwargs, + extensions if extensions is not None else self.extensions, ) response_data = await self._send_request(payload, modified_kwargs) response = SendMessageResponse.model_validate(response_data) @@ -155,8 +156,9 @@ async def send_message_streaming( context, ) - modified_kwargs, self.extensions = update_extension_header( - modified_kwargs, self.extensions, extensions + modified_kwargs = update_extension_header( + modified_kwargs, + extensions if extensions is not None else self.extensions, ) modified_kwargs.setdefault( 'timeout', self.httpx_client.timeout.as_dict().get('read', None) @@ -228,8 +230,9 @@ async def get_task( self._get_http_args(context), context, ) - modified_kwargs, self.extensions = update_extension_header( - modified_kwargs, self.extensions, extensions + modified_kwargs = update_extension_header( + modified_kwargs, + extensions if extensions is not None else self.extensions, ) response_data = await self._send_request(payload, modified_kwargs) response = GetTaskResponse.model_validate(response_data) @@ -252,8 +255,9 @@ async def cancel_task( self._get_http_args(context), context, ) - modified_kwargs, self.extensions = update_extension_header( - modified_kwargs, self.extensions, extensions + modified_kwargs = update_extension_header( + modified_kwargs, + extensions if extensions is not None else self.extensions, ) response_data = await self._send_request(payload, modified_kwargs) response = CancelTaskResponse.model_validate(response_data) @@ -278,8 +282,9 @@ async def set_task_callback( self._get_http_args(context), context, ) - modified_kwargs, self.extensions = update_extension_header( - modified_kwargs, self.extensions, extensions + modified_kwargs = update_extension_header( + modified_kwargs, + extensions if extensions is not None else self.extensions, ) response_data = await self._send_request(payload, modified_kwargs) response = SetTaskPushNotificationConfigResponse.model_validate( @@ -306,8 +311,9 @@ async def get_task_callback( self._get_http_args(context), context, ) - modified_kwargs, self.extensions = update_extension_header( - modified_kwargs, self.extensions, extensions + modified_kwargs = update_extension_header( + modified_kwargs, + extensions if extensions is not None else self.extensions, ) response_data = await self._send_request(payload, modified_kwargs) response = GetTaskPushNotificationConfigResponse.model_validate( @@ -334,8 +340,9 @@ async def resubscribe( self._get_http_args(context), context, ) - modified_kwargs, self.extensions = update_extension_header( - modified_kwargs, self.extensions, extensions + modified_kwargs = update_extension_header( + modified_kwargs, + extensions if extensions is not None else self.extensions, ) modified_kwargs.setdefault('timeout', None) @@ -393,8 +400,9 @@ async def get_card( self._get_http_args(context), context, ) - modified_kwargs, self.extensions = update_extension_header( - modified_kwargs, self.extensions, extensions + modified_kwargs = update_extension_header( + modified_kwargs, + extensions if extensions is not None else self.extensions, ) response_data = await self._send_request( diff --git a/src/a2a/client/transports/rest.py b/src/a2a/client/transports/rest.py index 7fc41e459..2c859763a 100644 --- a/src/a2a/client/transports/rest.py +++ b/src/a2a/client/transports/rest.py @@ -104,8 +104,9 @@ async def _prepare_send_message( self._get_http_args(context), context, ) - modified_kwargs, self.extensions = update_extension_header( - modified_kwargs, self.extensions, extensions + modified_kwargs = update_extension_header( + modified_kwargs, + extensions if extensions is not None else self.extensions, ) return payload, modified_kwargs @@ -223,8 +224,9 @@ async def get_task( self._get_http_args(context), context, ) - modified_kwargs, self.extensions = update_extension_header( - modified_kwargs, self.extensions, extensions + modified_kwargs = update_extension_header( + modified_kwargs, + extensions if extensions is not None else self.extensions, ) response_data = await self._send_get_request( f'/v1/tasks/{request.id}', @@ -252,8 +254,9 @@ async def cancel_task( self._get_http_args(context), context, ) - modified_kwargs, self.extensions = update_extension_header( - modified_kwargs, self.extensions, extensions + modified_kwargs = update_extension_header( + modified_kwargs, + extensions if extensions is not None else self.extensions, ) response_data = await self._send_post_request( f'/v1/tasks/{request.id}:cancel', payload, modified_kwargs @@ -279,8 +282,9 @@ async def set_task_callback( payload, modified_kwargs = await self._apply_interceptors( payload, self._get_http_args(context), context ) - modified_kwargs, self.extensions = update_extension_header( - modified_kwargs, self.extensions, extensions + modified_kwargs = update_extension_header( + modified_kwargs, + extensions if extensions is not None else self.extensions, ) response_data = await self._send_post_request( f'/v1/tasks/{request.task_id}/pushNotificationConfigs', @@ -308,8 +312,9 @@ async def get_task_callback( self._get_http_args(context), context, ) - modified_kwargs, self.extensions = update_extension_header( - modified_kwargs, self.extensions, extensions + modified_kwargs = update_extension_header( + modified_kwargs, + extensions if extensions is not None else self.extensions, ) response_data = await self._send_get_request( f'/v1/tasks/{request.id}/pushNotificationConfigs/{request.push_notification_config_id}', @@ -332,8 +337,9 @@ async def resubscribe( """Reconnects to get task updates.""" http_kwargs = self._get_http_args(context) or {} http_kwargs.setdefault('timeout', None) - modified_kwargs, self.extensions = update_extension_header( - http_kwargs, self.extensions, extensions + modified_kwargs = update_extension_header( + http_kwargs, + extensions if extensions is not None else self.extensions, ) async with aconnect_sse( @@ -384,8 +390,9 @@ async def get_card( self._get_http_args(context), context, ) - modified_kwargs, self.extensions = update_extension_header( - modified_kwargs, self.extensions, extensions + modified_kwargs = update_extension_header( + modified_kwargs, + extensions if extensions is not None else self.extensions, ) response_data = await self._send_get_request( '/v1/card', {}, modified_kwargs diff --git a/src/a2a/extensions/common.py b/src/a2a/extensions/common.py index 4471e8654..2b3a5cfaa 100644 --- a/src/a2a/extensions/common.py +++ b/src/a2a/extensions/common.py @@ -31,19 +31,10 @@ def find_extension_by_uri(card: AgentCard, uri: str) -> AgentExtension | None: def update_extension_header( http_kwargs: dict[str, Any], - active_extensions: list[str] | None, - new_extensions: list[str] | None, -) -> tuple[dict[str, Any], list[str] | None]: + extensions: list[str] | None, +) -> dict[str, Any]: """Update the X-A2A-Extensions header and update active extensions.""" - if new_extensions: - active_extensions = new_extensions - if active_extensions: + if extensions is not None: headers = http_kwargs.setdefault('headers', {}) - existing_extensions_str = headers.get(HTTP_EXTENSION_HEADER, '') - - existing_extensions = get_requested_extensions( - [existing_extensions_str] - ) - all_extensions = existing_extensions.union(active_extensions) - headers[HTTP_EXTENSION_HEADER] = ','.join(all_extensions) - return http_kwargs, active_extensions + headers[HTTP_EXTENSION_HEADER] = ','.join(extensions) + return http_kwargs diff --git a/tests/client/transports/test_grpc_client.py b/tests/client/transports/test_grpc_client.py index 0e4fba94c..111e44ba6 100644 --- a/tests/client/transports/test_grpc_client.py +++ b/tests/client/transports/test_grpc_client.py @@ -193,14 +193,17 @@ async def test_send_message_task_response( task=proto_utils.ToProto.task(sample_task) ) - response = await grpc_transport.send_message(sample_message_send_params) + response = await grpc_transport.send_message( + sample_message_send_params, + extensions=['https://example.com/test-ext/v3'], + ) mock_grpc_stub.SendMessage.assert_awaited_once() _, kwargs = mock_grpc_stub.SendMessage.call_args assert kwargs['metadata'] == [ ( HTTP_EXTENSION_HEADER, - 'https://example.com/test-ext/v1, https://example.com/test-ext/v2', + 'https://example.com/test-ext/v3', ) ] assert isinstance(response, Task) @@ -226,7 +229,7 @@ async def test_send_message_message_response( assert kwargs['metadata'] == [ ( HTTP_EXTENSION_HEADER, - 'https://example.com/test-ext/v1, https://example.com/test-ext/v2', + 'https://example.com/test-ext/v1,https://example.com/test-ext/v2', ) ] assert isinstance(response, Message) @@ -281,7 +284,7 @@ async def test_send_message_streaming( # noqa: PLR0913 assert kwargs['metadata'] == [ ( HTTP_EXTENSION_HEADER, - 'https://example.com/test-ext/v1, https://example.com/test-ext/v2', + 'https://example.com/test-ext/v1,https://example.com/test-ext/v2', ) ] assert isinstance(responses[0], Message) @@ -311,7 +314,7 @@ async def test_get_task( metadata=[ ( HTTP_EXTENSION_HEADER, - 'https://example.com/test-ext/v1, https://example.com/test-ext/v2', + 'https://example.com/test-ext/v1,https://example.com/test-ext/v2', ) ], ) @@ -336,7 +339,7 @@ async def test_get_task_with_history( metadata=[ ( HTTP_EXTENSION_HEADER, - 'https://example.com/test-ext/v1, https://example.com/test-ext/v2', + 'https://example.com/test-ext/v1,https://example.com/test-ext/v2', ) ], ) @@ -393,7 +396,7 @@ async def test_set_task_callback_with_valid_task( metadata=[ ( HTTP_EXTENSION_HEADER, - 'https://example.com/test-ext/v1, https://example.com/test-ext/v2', + 'https://example.com/test-ext/v1,https://example.com/test-ext/v2', ) ], ) @@ -456,7 +459,7 @@ async def test_get_task_callback_with_valid_task( metadata=[ ( HTTP_EXTENSION_HEADER, - 'https://example.com/test-ext/v1, https://example.com/test-ext/v2', + 'https://example.com/test-ext/v1,https://example.com/test-ext/v2', ) ], ) @@ -493,43 +496,37 @@ async def test_get_task_callback_with_invalid_task( @pytest.mark.parametrize( - 'initial_extensions, input_extensions, expected_metadata, expected_extensions', + 'initial_extensions, input_extensions, expected_metadata', [ ( None, None, None, - None, ), # Case 1: No initial, No input ( ['ext1'], None, [(HTTP_EXTENSION_HEADER, 'ext1')], - ['ext1'], ), # Case 2: Initial, No input ( None, ['ext2'], [(HTTP_EXTENSION_HEADER, 'ext2')], - ['ext2'], ), # Case 3: No initial, Input ( ['ext1'], ['ext2'], [(HTTP_EXTENSION_HEADER, 'ext2')], - ['ext2'], ), # Case 4: Initial, Input (override) ( ['ext1'], ['ext2', 'ext3'], - [(HTTP_EXTENSION_HEADER, 'ext2, ext3')], - ['ext2', 'ext3'], + [(HTTP_EXTENSION_HEADER, 'ext2,ext3')], ), # Case 5: Initial, Multiple inputs (override) ( ['ext1', 'ext2'], ['ext3'], [(HTTP_EXTENSION_HEADER, 'ext3')], - ['ext3'], ), # Case 6: Multiple initial, Single input (override) ], ) @@ -538,12 +535,8 @@ def test_get_grpc_metadata( initial_extensions: list[str] | None, input_extensions: list[str] | None, expected_metadata: list[tuple[str, str]] | None, - expected_extensions: list[str] | None, ) -> None: """Tests _get_grpc_metadata for correct metadata generation and self.extensions update.""" grpc_transport.extensions = initial_extensions - - metadata = grpc_transport._get_grpc_metadata(extensions=input_extensions) - + metadata = grpc_transport._get_grpc_metadata(input_extensions) assert metadata == expected_metadata - assert grpc_transport.extensions == expected_extensions diff --git a/tests/client/transports/test_jsonrpc_client.py b/tests/client/transports/test_jsonrpc_client.py index f1cec6eb5..bd705d93c 100644 --- a/tests/client/transports/test_jsonrpc_client.py +++ b/tests/client/transports/test_jsonrpc_client.py @@ -790,7 +790,7 @@ async def test_close(self, mock_httpx_client: AsyncMock): class TestJsonRpcTransportExtensions: @pytest.mark.asyncio - async def test_send_message_with_extensions( + async def test_send_message_with_default_extensions( self, mock_httpx_client: AsyncMock, mock_agent_card: MagicMock ): """Test that send_message adds extension headers when extensions are provided.""" @@ -838,13 +838,14 @@ async def test_send_message_with_extensions( @pytest.mark.asyncio @patch('a2a.client.transports.jsonrpc.aconnect_sse') - async def test_send_message_streaming_with_extensions( + async def test_send_message_streaming_with_new_extensions( self, mock_aconnect_sse: AsyncMock, mock_httpx_client: AsyncMock, mock_agent_card: MagicMock, ): """Test X-A2A-Extensions header in send_message_streaming.""" + new_extensions = ['https://example.com/test-ext/v2'] extensions = ['https://example.com/test-ext/v1'] client = JsonRpcTransport( httpx_client=mock_httpx_client, @@ -861,7 +862,9 @@ async def test_send_message_streaming_with_extensions( mock_event_source ) - async for _ in client.send_message_streaming(request=params): + async for _ in client.send_message_streaming( + request=params, extensions=new_extensions + ): pass mock_aconnect_sse.assert_called_once() @@ -870,5 +873,5 @@ async def test_send_message_streaming_with_extensions( headers = kwargs.get('headers', {}) assert HTTP_EXTENSION_HEADER in headers assert ( - headers[HTTP_EXTENSION_HEADER] == 'https://example.com/test-ext/v1' + headers[HTTP_EXTENSION_HEADER] == 'https://example.com/test-ext/v2' ) diff --git a/tests/client/transports/test_rest_client.py b/tests/client/transports/test_rest_client.py index ed8a25e20..04bd10361 100644 --- a/tests/client/transports/test_rest_client.py +++ b/tests/client/transports/test_rest_client.py @@ -34,7 +34,7 @@ async def async_iterable_from_list( class TestRestTransportExtensions: @pytest.mark.asyncio - async def test_send_message_with_extensions( + async def test_send_message_with_default_extensions( self, mock_httpx_client: AsyncMock, mock_agent_card: MagicMock ): """Test that send_message adds extensions to headers.""" @@ -82,13 +82,14 @@ async def test_send_message_with_extensions( @pytest.mark.asyncio @patch('a2a.client.transports.rest.aconnect_sse') - async def test_send_message_streaming_with_extensions( + async def test_send_message_streaming_with_new_extensions( self, mock_aconnect_sse: AsyncMock, mock_httpx_client: AsyncMock, mock_agent_card: MagicMock, ): """Test X-A2A-Extensions header in send_message_streaming.""" + new_extensions = ['https://example.com/test-ext/v2'] extensions = ['https://example.com/test-ext/v1'] client = RestTransport( httpx_client=mock_httpx_client, @@ -105,7 +106,9 @@ async def test_send_message_streaming_with_extensions( mock_event_source ) - async for _ in client.send_message_streaming(request=params): + async for _ in client.send_message_streaming( + request=params, extensions=new_extensions + ): pass mock_aconnect_sse.assert_called_once() @@ -114,5 +117,5 @@ async def test_send_message_streaming_with_extensions( headers = kwargs.get('headers', {}) assert HTTP_EXTENSION_HEADER in headers assert ( - headers[HTTP_EXTENSION_HEADER] == 'https://example.com/test-ext/v1' + headers[HTTP_EXTENSION_HEADER] == 'https://example.com/test-ext/v2' ) diff --git a/tests/extensions/test_common.py b/tests/extensions/test_common.py index 753386bcb..5e425059e 100644 --- a/tests/extensions/test_common.py +++ b/tests/extensions/test_common.py @@ -62,89 +62,67 @@ def test_find_extension_by_uri_no_extensions(): @pytest.mark.parametrize( - 'active_extensions, new_extensions, existing_header, expected_extensions, expected_count, expected_returned_extensions', + 'extensions, header, expected_extensions', [ ( - ['ext1', 'ext2'], # active_extensions - None, # new_extensions - '', # existing_header + ['ext1', 'ext2'], # extensions + '', # header { 'ext1', 'ext2', }, # expected_extensions - 2, # expected_count - ['ext1', 'ext2'], # expected_returned_extensions - ), # Case 1: Active extensions, no new extensions, and no existing header. + ), # Case 1: New extensions provided, empty header. ( - ['ext1', 'ext2'], # active_extensions - None, # new_extensions - 'ext2, ext3', # existing_header + None, # extensions + 'ext1, ext2', # existing_header { 'ext1', 'ext2', - 'ext3', }, # expected_extensions - 3, # expected_count - ['ext1', 'ext2'], # expected_returned_extensions - ), # Case 2: Active extensions, no new extensions, with an existing header containing overlapping and new extensions. + ), # Case 2: Extensions is None, existing header extensions. ( - ['ext1', 'ext2'], # active_extensions - None, # new_extensions + [], # extensions + 'ext1', # existing_header + {}, # expected_extensions + ), # Case 3: New extensions is empty list, existing header extensions. + ( + ['ext1', 'ext2'], # extensions 'ext3', # existing_header { 'ext1', 'ext2', - 'ext3', }, # expected_extensions - 3, # expected_count - ['ext1', 'ext2'], # expected_returned_extensions - ), # Case 3: Active extensions, no new extensions, with an existing header containing different extensions. - ( - ['ext1', 'ext2'], # active_extensions - ['ext3'], # new_extensions - 'ext4', # existing_header - { - 'ext3', - 'ext4', - }, # expected_extensions - 2, # expected_count - ['ext3'], # expected_returned_extensions - ), # Case 4: Active extensions, new extensions provided, and an existing header. New extensions should override active and merge with existing. + ), # Case 4: New extensions provided, and an existing header. New extensions should override active extensions. ], ) def test_update_extension_header_merge_with_existing_extensions( - active_extensions: list[str], - new_extensions: list[str], - existing_header: str, + extensions: list[str], + header: str, expected_extensions: set[str], - expected_count: int, - expected_returned_extensions: list[str], ): - http_kwargs = {'headers': {HTTP_EXTENSION_HEADER: existing_header}} - result_kwargs, actual_returned_extensions = update_extension_header( - http_kwargs, active_extensions, new_extensions - ) + http_kwargs = {'headers': {HTTP_EXTENSION_HEADER: header}} + result_kwargs = update_extension_header(http_kwargs, extensions) header_value = result_kwargs['headers'][HTTP_EXTENSION_HEADER] - actual_extensions_list = [e.strip() for e in header_value.split(',')] - actual_extensions = set(actual_extensions_list) - assert len(actual_extensions_list) == expected_count + if not header_value: + actual_extensions = {} + else: + actual_extensions_list = [e.strip() for e in header_value.split(',')] + actual_extensions = set(actual_extensions_list) assert actual_extensions == expected_extensions - assert actual_returned_extensions == expected_returned_extensions def test_update_extension_header_with_other_headers(): extensions = ['ext'] http_kwargs = {'headers': {'X_Other': 'Test'}} - result_kwargs, _ = update_extension_header(http_kwargs, extensions, None) + result_kwargs = update_extension_header(http_kwargs, extensions) headers = result_kwargs.get('headers', {}) assert HTTP_EXTENSION_HEADER in headers assert headers[HTTP_EXTENSION_HEADER] == 'ext' assert headers['X_Other'] == 'Test' -@pytest.mark.parametrize('extensions', [(None), ([])]) -def test_update_extension_header_no_or_empty_extensions(extensions): +def test_update_extension_header_with_other_headers_extensions_none(): http_kwargs = {'headers': {'X_Other': 'Test'}} - result_kwargs, _ = update_extension_header(http_kwargs, extensions, None) + result_kwargs = update_extension_header(http_kwargs, None) assert HTTP_EXTENSION_HEADER not in result_kwargs['headers'] assert result_kwargs['headers']['X_Other'] == 'Test' From 16ee453eb2c14344364a5df165614b876a956277 Mon Sep 17 00:00:00 2001 From: sokoliva Date: Mon, 17 Nov 2025 09:52:55 +0000 Subject: [PATCH 23/26] add integration test for extensions. Add a test case to test_common.py. Change desription of update_extension_header --- src/a2a/extensions/common.py | 2 +- tests/extensions/test_common.py | 9 +++ .../test_client_server_integration.py | 59 ++++++++++++++++++- 3 files changed, 68 insertions(+), 2 deletions(-) diff --git a/src/a2a/extensions/common.py b/src/a2a/extensions/common.py index 2b3a5cfaa..8c2c65c7c 100644 --- a/src/a2a/extensions/common.py +++ b/src/a2a/extensions/common.py @@ -33,7 +33,7 @@ def update_extension_header( http_kwargs: dict[str, Any], extensions: list[str] | None, ) -> dict[str, Any]: - """Update the X-A2A-Extensions header and update active extensions.""" + """Update the X-A2A-Extensions header with active extensions.""" if extensions is not None: headers = http_kwargs.setdefault('headers', {}) headers[HTTP_EXTENSION_HEADER] = ','.join(extensions) diff --git a/tests/extensions/test_common.py b/tests/extensions/test_common.py index 5e425059e..9a71b3c04 100644 --- a/tests/extensions/test_common.py +++ b/tests/extensions/test_common.py @@ -126,3 +126,12 @@ def test_update_extension_header_with_other_headers_extensions_none(): result_kwargs = update_extension_header(http_kwargs, None) assert HTTP_EXTENSION_HEADER not in result_kwargs['headers'] assert result_kwargs['headers']['X_Other'] == 'Test' + + +def test_update_extension_header_empty_header(): + extensions = ['ext'] + http_kwargs = {} + result_kwargs = update_extension_header(http_kwargs, extensions) + headers = result_kwargs.get('headers', {}) + assert HTTP_EXTENSION_HEADER in headers + assert headers[HTTP_EXTENSION_HEADER] == 'ext' diff --git a/tests/integration/test_client_server_integration.py b/tests/integration/test_client_server_integration.py index 88d4d3d11..7dfbee8e2 100644 --- a/tests/integration/test_client_server_integration.py +++ b/tests/integration/test_client_server_integration.py @@ -1,7 +1,7 @@ import asyncio from collections.abc import AsyncGenerator from typing import NamedTuple -from unittest.mock import ANY, AsyncMock +from unittest.mock import ANY, AsyncMock, patch import grpc import httpx @@ -9,6 +9,8 @@ import pytest_asyncio from grpc.aio import Channel +from a2a.client import ClientConfig +from a2a.client.base_client import BaseClient from a2a.client.transports import JsonRpcTransport, RestTransport from a2a.client.transports.base import ClientTransport from a2a.client.transports.grpc import GrpcTransport @@ -767,3 +769,58 @@ def channel_factory(address: str) -> Channel: assert transport._needs_extended_card is False await transport.close() + + +@pytest.mark.asyncio +async def test_base_client_sends_message_with_extensions( + jsonrpc_setup: TransportSetup, agent_card: AgentCard +) -> None: + """ + Integration test for BaseClient with JSON-RPC transport to ensure extensions are included in headers. + """ + transport = jsonrpc_setup.transport + agent_card.capabilities.streaming = False + + # Create a BaseClient instance + client = BaseClient( + card=agent_card, + config=ClientConfig(streaming=False), + transport=transport, + consumers=[], + middleware=[], + ) + + message_to_send = Message( + role=Role.user, + message_id='msg-integration-test-extensions', + parts=[Part(root=TextPart(text='Hello, extensions test!'))], + ) + extensions = [ + 'https://example.com/test-ext/v1', + 'https://example.com/test-ext/v2', + ] + + with patch.object( + transport, '_send_request', new_callable=AsyncMock + ) as mock_send_request: + mock_send_request.return_value = { + 'id': '123', + 'jsonrpc': '2.0', + 'result': TASK_FROM_BLOCKING.model_dump(mode='json'), + } + + # Call send_message on the BaseClient + async for _ in client.send_message( + request=message_to_send, extensions=extensions + ): + pass + + mock_send_request.assert_called_once() + call_args, _ = mock_send_request.call_args + kwargs = call_args[1] + headers = kwargs.get('headers', {}) + assert 'X-A2A-Extensions' in headers + assert headers['X-A2A-Extensions'] == ','.join(extensions) + + if hasattr(transport, 'close'): + await transport.close() From 80be4bf8385961d2a9f690e204842daad13d2f8b Mon Sep 17 00:00:00 2001 From: sokoliva Date: Mon, 17 Nov 2025 12:02:59 +0000 Subject: [PATCH 24/26] change test case name in tests/extensions/test_common.py --- tests/extensions/test_common.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/extensions/test_common.py b/tests/extensions/test_common.py index 9a71b3c04..a8cd4304c 100644 --- a/tests/extensions/test_common.py +++ b/tests/extensions/test_common.py @@ -128,7 +128,7 @@ def test_update_extension_header_with_other_headers_extensions_none(): assert result_kwargs['headers']['X_Other'] == 'Test' -def test_update_extension_header_empty_header(): +def test_update_extension_header_headers_not_in_kwargs(): extensions = ['ext'] http_kwargs = {} result_kwargs = update_extension_header(http_kwargs, extensions) From 4a423ef9a21fc2593a16206fa85840fc9f68a91d Mon Sep 17 00:00:00 2001 From: sokoliva Date: Tue, 18 Nov 2025 09:46:49 +0000 Subject: [PATCH 25/26] Change the order of update_extension_header and _apply_interceptors function calls inside rest and jsonrpc methods --- src/a2a/client/transports/jsonrpc.py | 82 ++++++++++++++-------------- src/a2a/client/transports/rest.py | 63 +++++++++++---------- src/a2a/extensions/common.py | 3 +- tests/extensions/test_common.py | 31 +++++++---- 4 files changed, 93 insertions(+), 86 deletions(-) diff --git a/src/a2a/client/transports/jsonrpc.py b/src/a2a/client/transports/jsonrpc.py index 71a8fd8c7..d8011cf4d 100644 --- a/src/a2a/client/transports/jsonrpc.py +++ b/src/a2a/client/transports/jsonrpc.py @@ -120,15 +120,15 @@ async def send_message( ) -> Task | Message: """Sends a non-streaming message request to the agent.""" rpc_request = SendMessageRequest(params=request, id=str(uuid4())) + modified_kwargs = update_extension_header( + self._get_http_args(context), + extensions if extensions is not None else self.extensions, + ) payload, modified_kwargs = await self._apply_interceptors( 'message/send', rpc_request.model_dump(mode='json', exclude_none=True), - self._get_http_args(context), - context, - ) - modified_kwargs = update_extension_header( modified_kwargs, - extensions if extensions is not None else self.extensions, + context, ) response_data = await self._send_request(payload, modified_kwargs) response = SendMessageResponse.model_validate(response_data) @@ -149,16 +149,15 @@ async def send_message_streaming( rpc_request = SendStreamingMessageRequest( params=request, id=str(uuid4()) ) + modified_kwargs = update_extension_header( + self._get_http_args(context), + extensions if extensions is not None else self.extensions, + ) payload, modified_kwargs = await self._apply_interceptors( 'message/stream', rpc_request.model_dump(mode='json', exclude_none=True), - self._get_http_args(context), - context, - ) - - modified_kwargs = update_extension_header( modified_kwargs, - extensions if extensions is not None else self.extensions, + context, ) modified_kwargs.setdefault( 'timeout', self.httpx_client.timeout.as_dict().get('read', None) @@ -224,15 +223,15 @@ async def get_task( ) -> Task: """Retrieves the current state and history of a specific task.""" rpc_request = GetTaskRequest(params=request, id=str(uuid4())) + modified_kwargs = update_extension_header( + self._get_http_args(context), + extensions if extensions is not None else self.extensions, + ) payload, modified_kwargs = await self._apply_interceptors( 'tasks/get', rpc_request.model_dump(mode='json', exclude_none=True), - self._get_http_args(context), - context, - ) - modified_kwargs = update_extension_header( modified_kwargs, - extensions if extensions is not None else self.extensions, + context, ) response_data = await self._send_request(payload, modified_kwargs) response = GetTaskResponse.model_validate(response_data) @@ -249,15 +248,15 @@ async def cancel_task( ) -> Task: """Requests the agent to cancel a specific task.""" rpc_request = CancelTaskRequest(params=request, id=str(uuid4())) + modified_kwargs = update_extension_header( + self._get_http_args(context), + extensions if extensions is not None else self.extensions, + ) payload, modified_kwargs = await self._apply_interceptors( 'tasks/cancel', rpc_request.model_dump(mode='json', exclude_none=True), - self._get_http_args(context), - context, - ) - modified_kwargs = update_extension_header( modified_kwargs, - extensions if extensions is not None else self.extensions, + context, ) response_data = await self._send_request(payload, modified_kwargs) response = CancelTaskResponse.model_validate(response_data) @@ -276,15 +275,15 @@ async def set_task_callback( rpc_request = SetTaskPushNotificationConfigRequest( params=request, id=str(uuid4()) ) + modified_kwargs = update_extension_header( + self._get_http_args(context), + extensions if extensions is not None else self.extensions, + ) payload, modified_kwargs = await self._apply_interceptors( 'tasks/pushNotificationConfig/set', rpc_request.model_dump(mode='json', exclude_none=True), - self._get_http_args(context), - context, - ) - modified_kwargs = update_extension_header( modified_kwargs, - extensions if extensions is not None else self.extensions, + context, ) response_data = await self._send_request(payload, modified_kwargs) response = SetTaskPushNotificationConfigResponse.model_validate( @@ -305,15 +304,15 @@ async def get_task_callback( rpc_request = GetTaskPushNotificationConfigRequest( params=request, id=str(uuid4()) ) + modified_kwargs = update_extension_header( + self._get_http_args(context), + extensions if extensions is not None else self.extensions, + ) payload, modified_kwargs = await self._apply_interceptors( 'tasks/pushNotificationConfig/get', rpc_request.model_dump(mode='json', exclude_none=True), - self._get_http_args(context), - context, - ) - modified_kwargs = update_extension_header( modified_kwargs, - extensions if extensions is not None else self.extensions, + context, ) response_data = await self._send_request(payload, modified_kwargs) response = GetTaskPushNotificationConfigResponse.model_validate( @@ -334,15 +333,15 @@ async def resubscribe( ]: """Reconnects to get task updates.""" rpc_request = TaskResubscriptionRequest(params=request, id=str(uuid4())) + modified_kwargs = update_extension_header( + self._get_http_args(context), + extensions if extensions is not None else self.extensions, + ) payload, modified_kwargs = await self._apply_interceptors( 'tasks/resubscribe', rpc_request.model_dump(mode='json', exclude_none=True), - self._get_http_args(context), - context, - ) - modified_kwargs = update_extension_header( modified_kwargs, - extensions if extensions is not None else self.extensions, + context, ) modified_kwargs.setdefault('timeout', None) @@ -394,17 +393,16 @@ async def get_card( return card request = GetAuthenticatedExtendedCardRequest(id=str(uuid4())) + modified_kwargs = update_extension_header( + self._get_http_args(context), + extensions if extensions is not None else self.extensions, + ) payload, modified_kwargs = await self._apply_interceptors( request.method, request.model_dump(mode='json', exclude_none=True), - self._get_http_args(context), - context, - ) - modified_kwargs = update_extension_header( modified_kwargs, - extensions if extensions is not None else self.extensions, + context, ) - response_data = await self._send_request( payload, modified_kwargs, diff --git a/src/a2a/client/transports/rest.py b/src/a2a/client/transports/rest.py index 2c859763a..83c267873 100644 --- a/src/a2a/client/transports/rest.py +++ b/src/a2a/client/transports/rest.py @@ -99,14 +99,14 @@ async def _prepare_send_message( ), ) payload = MessageToDict(pb) - payload, modified_kwargs = await self._apply_interceptors( - payload, + modified_kwargs = update_extension_header( self._get_http_args(context), - context, + extensions if extensions is not None else self.extensions, ) - modified_kwargs = update_extension_header( + payload, modified_kwargs = await self._apply_interceptors( + payload, modified_kwargs, - extensions if extensions is not None else self.extensions, + context, ) return payload, modified_kwargs @@ -219,14 +219,14 @@ async def get_task( extensions: list[str] | None = None, ) -> Task: """Retrieves the current state and history of a specific task.""" - _payload, modified_kwargs = await self._apply_interceptors( - request.model_dump(mode='json', exclude_none=True), + modified_kwargs = update_extension_header( self._get_http_args(context), - context, + extensions if extensions is not None else self.extensions, ) - modified_kwargs = update_extension_header( + _payload, modified_kwargs = await self._apply_interceptors( + request.model_dump(mode='json', exclude_none=True), modified_kwargs, - extensions if extensions is not None else self.extensions, + context, ) response_data = await self._send_get_request( f'/v1/tasks/{request.id}', @@ -249,14 +249,14 @@ async def cancel_task( """Requests the agent to cancel a specific task.""" pb = a2a_pb2.CancelTaskRequest(name=f'tasks/{request.id}') payload = MessageToDict(pb) - payload, modified_kwargs = await self._apply_interceptors( - payload, + modified_kwargs = update_extension_header( self._get_http_args(context), - context, + extensions if extensions is not None else self.extensions, ) - modified_kwargs = update_extension_header( + payload, modified_kwargs = await self._apply_interceptors( + payload, modified_kwargs, - extensions if extensions is not None else self.extensions, + context, ) response_data = await self._send_post_request( f'/v1/tasks/{request.id}:cancel', payload, modified_kwargs @@ -279,13 +279,13 @@ async def set_task_callback( config=proto_utils.ToProto.task_push_notification_config(request), ) payload = MessageToDict(pb) - payload, modified_kwargs = await self._apply_interceptors( - payload, self._get_http_args(context), context - ) modified_kwargs = update_extension_header( - modified_kwargs, + self._get_http_args(context), extensions if extensions is not None else self.extensions, ) + payload, modified_kwargs = await self._apply_interceptors( + payload, modified_kwargs, context + ) response_data = await self._send_post_request( f'/v1/tasks/{request.task_id}/pushNotificationConfigs', payload, @@ -307,14 +307,14 @@ async def get_task_callback( name=f'tasks/{request.id}/pushNotificationConfigs/{request.push_notification_config_id}', ) payload = MessageToDict(pb) - payload, modified_kwargs = await self._apply_interceptors( - payload, + modified_kwargs = update_extension_header( self._get_http_args(context), - context, + extensions if extensions is not None else self.extensions, ) - modified_kwargs = update_extension_header( + payload, modified_kwargs = await self._apply_interceptors( + payload, modified_kwargs, - extensions if extensions is not None else self.extensions, + context, ) response_data = await self._send_get_request( f'/v1/tasks/{request.id}/pushNotificationConfigs/{request.push_notification_config_id}', @@ -335,12 +335,11 @@ async def resubscribe( Task | TaskStatusUpdateEvent | TaskArtifactUpdateEvent | Message ]: """Reconnects to get task updates.""" - http_kwargs = self._get_http_args(context) or {} - http_kwargs.setdefault('timeout', None) modified_kwargs = update_extension_header( - http_kwargs, + self._get_http_args(context), extensions if extensions is not None else self.extensions, ) + modified_kwargs.setdefault('timeout', None) async with aconnect_sse( self.httpx_client, @@ -385,14 +384,14 @@ async def get_card( if not self._needs_extended_card: return card - _, modified_kwargs = await self._apply_interceptors( - {}, + modified_kwargs = update_extension_header( self._get_http_args(context), - context, + extensions if extensions is not None else self.extensions, ) - modified_kwargs = update_extension_header( + _, modified_kwargs = await self._apply_interceptors( + {}, modified_kwargs, - extensions if extensions is not None else self.extensions, + context, ) response_data = await self._send_get_request( '/v1/card', {}, modified_kwargs diff --git a/src/a2a/extensions/common.py b/src/a2a/extensions/common.py index 8c2c65c7c..cba3517e4 100644 --- a/src/a2a/extensions/common.py +++ b/src/a2a/extensions/common.py @@ -30,10 +30,11 @@ def find_extension_by_uri(card: AgentCard, uri: str) -> AgentExtension | None: def update_extension_header( - http_kwargs: dict[str, Any], + http_kwargs: dict[str, Any] | None, extensions: list[str] | None, ) -> dict[str, Any]: """Update the X-A2A-Extensions header with active extensions.""" + http_kwargs = http_kwargs or {} if extensions is not None: headers = http_kwargs.setdefault('headers', {}) headers[HTTP_EXTENSION_HEADER] = ','.join(extensions) diff --git a/tests/extensions/test_common.py b/tests/extensions/test_common.py index a8cd4304c..b3123028a 100644 --- a/tests/extensions/test_common.py +++ b/tests/extensions/test_common.py @@ -74,7 +74,7 @@ def test_find_extension_by_uri_no_extensions(): ), # Case 1: New extensions provided, empty header. ( None, # extensions - 'ext1, ext2', # existing_header + 'ext1, ext2', # header { 'ext1', 'ext2', @@ -82,12 +82,12 @@ def test_find_extension_by_uri_no_extensions(): ), # Case 2: Extensions is None, existing header extensions. ( [], # extensions - 'ext1', # existing_header + 'ext1', # header {}, # expected_extensions ), # Case 3: New extensions is empty list, existing header extensions. ( ['ext1', 'ext2'], # extensions - 'ext3', # existing_header + 'ext3', # header { 'ext1', 'ext2', @@ -121,17 +121,26 @@ def test_update_extension_header_with_other_headers(): assert headers['X_Other'] == 'Test' -def test_update_extension_header_with_other_headers_extensions_none(): - http_kwargs = {'headers': {'X_Other': 'Test'}} - result_kwargs = update_extension_header(http_kwargs, None) - assert HTTP_EXTENSION_HEADER not in result_kwargs['headers'] - assert result_kwargs['headers']['X_Other'] == 'Test' - - -def test_update_extension_header_headers_not_in_kwargs(): +@pytest.mark.parametrize( + 'http_kwargs', + [ + None, + {}, + ], +) +def test_update_extension_header_headers_not_in_kwargs( + http_kwargs: dict[str, str] | None, +): extensions = ['ext'] http_kwargs = {} result_kwargs = update_extension_header(http_kwargs, extensions) headers = result_kwargs.get('headers', {}) assert HTTP_EXTENSION_HEADER in headers assert headers[HTTP_EXTENSION_HEADER] == 'ext' + + +def test_update_extension_header_with_other_headers_extensions_none(): + http_kwargs = {'headers': {'X_Other': 'Test'}} + result_kwargs = update_extension_header(http_kwargs, None) + assert HTTP_EXTENSION_HEADER not in result_kwargs['headers'] + assert result_kwargs['headers']['X_Other'] == 'Test' From 125406dc71d2e66760db98851c1675f4c7e42c48 Mon Sep 17 00:00:00 2001 From: sokoliva Date: Tue, 18 Nov 2025 16:35:07 +0000 Subject: [PATCH 26/26] Change assertion in test_client_server_integration --- tests/integration/test_client_server_integration.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/integration/test_client_server_integration.py b/tests/integration/test_client_server_integration.py index 7dfbee8e2..e0a564eee 100644 --- a/tests/integration/test_client_server_integration.py +++ b/tests/integration/test_client_server_integration.py @@ -820,7 +820,10 @@ async def test_base_client_sends_message_with_extensions( kwargs = call_args[1] headers = kwargs.get('headers', {}) assert 'X-A2A-Extensions' in headers - assert headers['X-A2A-Extensions'] == ','.join(extensions) + assert ( + headers['X-A2A-Extensions'] + == 'https://example.com/test-ext/v1,https://example.com/test-ext/v2' + ) if hasattr(transport, 'close'): await transport.close()