From 837c8ea05479ae43847d2e0f9e7d80385f43ba0e Mon Sep 17 00:00:00 2001 From: A Vertex SDK engineer Date: Fri, 21 Nov 2025 09:31:01 -0800 Subject: [PATCH] feat: GenAI SDK client(sessions): Add label to Sessions PiperOrigin-RevId: 835253647 --- .../test_create_agent_engine_session.py | 102 +++++++++--------- vertexai/_genai/sessions.py | 6 ++ vertexai/_genai/types/common.py | 67 ++++++++---- 3 files changed, 104 insertions(+), 71 deletions(-) diff --git a/tests/unit/vertexai/genai/replays/test_create_agent_engine_session.py b/tests/unit/vertexai/genai/replays/test_create_agent_engine_session.py index 5658aaa1fd..b055a9b67c 100644 --- a/tests/unit/vertexai/genai/replays/test_create_agent_engine_session.py +++ b/tests/unit/vertexai/genai/replays/test_create_agent_engine_session.py @@ -22,63 +22,69 @@ def test_create_session_with_ttl(client): agent_engine = client.agent_engines.create() - assert isinstance(agent_engine, types.AgentEngine) - assert isinstance(agent_engine.api_resource, types.ReasoningEngine) + try: + assert isinstance(agent_engine, types.AgentEngine) + assert isinstance(agent_engine.api_resource, types.ReasoningEngine) - operation = client.agent_engines.create_session( - name=agent_engine.api_resource.name, - user_id="test-user-123", - config=types.CreateAgentEngineSessionConfig( - display_name="my_session", - session_state={"foo": "bar"}, - ttl="120s", - ), - ) - assert isinstance(operation, types.AgentEngineSessionOperation) - assert operation.response.display_name == "my_session" - assert operation.response.session_state == {"foo": "bar"} - assert operation.response.user_id == "test-user-123" - assert operation.response.name.startswith(agent_engine.api_resource.name) - # Expire time is calculated by the server, so we only check that it is - # within a reasonable range to avoid flakiness. - assert ( - operation.response.create_time + datetime.timedelta(seconds=119.5) - <= operation.response.expire_time - <= operation.response.create_time + datetime.timedelta(seconds=120.5) - ) - # Clean up resources. - client.agent_engines.delete(name=agent_engine.api_resource.name, force=True) + operation = client.agent_engines.create_session( + name=agent_engine.api_resource.name, + user_id="test-user-123", + config=types.CreateAgentEngineSessionConfig( + display_name="my_session", + session_state={"foo": "bar"}, + ttl="120s", + labels={"label_key": "label_value"}, + ), + ) + assert isinstance(operation, types.AgentEngineSessionOperation) + assert operation.response.display_name == "my_session" + assert operation.response.session_state == {"foo": "bar"} + assert operation.response.user_id == "test-user-123" + assert operation.response.labels == {"label_key": "label_value"} + assert operation.response.name.startswith(agent_engine.api_resource.name) + # Expire time is calculated by the server, so we only check that it is + # within a reasonable range to avoid flakiness. + assert ( + operation.response.create_time + datetime.timedelta(seconds=119.5) + <= operation.response.expire_time + <= operation.response.create_time + datetime.timedelta(seconds=120.5) + ) + finally: + # Clean up resources. + client.agent_engines.delete(name=agent_engine.api_resource.name, force=True) def test_create_session_with_expire_time(client): agent_engine = client.agent_engines.create() - assert isinstance(agent_engine, types.AgentEngine) - assert isinstance(agent_engine.api_resource, types.ReasoningEngine) - expire_time = datetime.datetime( - 2026, 1, 1, 12, 30, 00, tzinfo=datetime.timezone.utc - ) + try: + assert isinstance(agent_engine, types.AgentEngine) + assert isinstance(agent_engine.api_resource, types.ReasoningEngine) + expire_time = datetime.datetime( + 2026, 1, 1, 12, 30, 00, tzinfo=datetime.timezone.utc + ) - operation = client.agent_engines.sessions.create( - name=agent_engine.api_resource.name, - user_id="test-user-123", - config=types.CreateAgentEngineSessionConfig( - display_name="my_session", - session_state={"foo": "bar"}, - expire_time=expire_time, - ), - ) - assert isinstance(operation, types.AgentEngineSessionOperation) - assert operation.response.display_name == "my_session" - assert operation.response.session_state == {"foo": "bar"} - assert operation.response.user_id == "test-user-123" - assert operation.response.name.startswith(agent_engine.api_resource.name) - assert operation.response.expire_time == expire_time - # Clean up resources. - client.agent_engines.delete(name=agent_engine.api_resource.name, force=True) + operation = client.agent_engines.sessions.create( + name=agent_engine.api_resource.name, + user_id="test-user-123", + config=types.CreateAgentEngineSessionConfig( + display_name="my_session", + session_state={"foo": "bar"}, + expire_time=expire_time, + ), + ) + assert isinstance(operation, types.AgentEngineSessionOperation) + assert operation.response.display_name == "my_session" + assert operation.response.session_state == {"foo": "bar"} + assert operation.response.user_id == "test-user-123" + assert operation.response.name.startswith(agent_engine.api_resource.name) + assert operation.response.expire_time == expire_time + finally: + # Clean up resources. + client.agent_engines.delete(name=agent_engine.api_resource.name, force=True) pytestmark = pytest_helper.setup( file=__file__, globals_for_file=globals(), - test_method="agent_engines.create_session", + test_method="agent_engines.sessions.create", ) diff --git a/vertexai/_genai/sessions.py b/vertexai/_genai/sessions.py index c82550e59e..31b83ec6dc 100644 --- a/vertexai/_genai/sessions.py +++ b/vertexai/_genai/sessions.py @@ -55,6 +55,9 @@ def _CreateAgentEngineSessionConfig_to_vertex( if getv(from_object, ["expire_time"]) is not None: setv(parent_object, ["expireTime"], getv(from_object, ["expire_time"])) + if getv(from_object, ["labels"]) is not None: + setv(parent_object, ["labels"], getv(from_object, ["labels"])) + return to_object @@ -181,6 +184,9 @@ def _UpdateAgentEngineSessionConfig_to_vertex( if getv(from_object, ["expire_time"]) is not None: setv(parent_object, ["expireTime"], getv(from_object, ["expire_time"])) + if getv(from_object, ["labels"]) is not None: + setv(parent_object, ["labels"], getv(from_object, ["labels"])) + if getv(from_object, ["update_mask"]) is not None: setv( parent_object, ["_query", "updateMask"], getv(from_object, ["update_mask"]) diff --git a/vertexai/_genai/types/common.py b/vertexai/_genai/types/common.py index 1fefe3652c..2b6e9f0fc9 100644 --- a/vertexai/_genai/types/common.py +++ b/vertexai/_genai/types/common.py @@ -8778,6 +8778,10 @@ class CreateAgentEngineSessionConfig(_common.BaseModel): default=None, description="""Optional. Timestamp of when this resource is considered expired. This is *always* provided on output, regardless of what `expiration` was sent on input.""", ) + labels: Optional[dict[str, str]] = Field( + default=None, + description="""Optional. The labels with user-defined metadata to organize your Sessions. Label keys and values can be no longer than 64 characters (Unicode codepoints), can only contain lowercase letters, numeric characters, underscores and dashes. International characters are allowed. See https://goo.gl/xmQnxf for more information and examples of labels.""", + ) class CreateAgentEngineSessionConfigDict(TypedDict, total=False): @@ -8803,6 +8807,9 @@ class CreateAgentEngineSessionConfigDict(TypedDict, total=False): expire_time: Optional[datetime.datetime] """Optional. Timestamp of when this resource is considered expired. This is *always* provided on output, regardless of what `expiration` was sent on input.""" + labels: Optional[dict[str, str]] + """Optional. The labels with user-defined metadata to organize your Sessions. Label keys and values can be no longer than 64 characters (Unicode codepoints), can only contain lowercase letters, numeric characters, underscores and dashes. International characters are allowed. See https://goo.gl/xmQnxf for more information and examples of labels.""" + CreateAgentEngineSessionConfigOrDict = Union[ CreateAgentEngineSessionConfig, CreateAgentEngineSessionConfigDict @@ -8846,32 +8853,36 @@ class _CreateAgentEngineSessionRequestParametersDict(TypedDict, total=False): class Session(_common.BaseModel): """A session.""" - create_time: Optional[datetime.datetime] = Field( - default=None, - description="""Output only. Timestamp when the session was created.""", - ) - display_name: Optional[str] = Field( - default=None, description="""Optional. The display name of the session.""" - ) expire_time: Optional[datetime.datetime] = Field( default=None, description="""Optional. Timestamp of when this session is considered expired. This is *always* provided on output, regardless of what was sent on input.""", ) + ttl: Optional[str] = Field( + default=None, description="""Optional. Input only. The TTL for this session.""" + ) name: Optional[str] = Field( default=None, description="""Identifier. The resource name of the session. Format: 'projects/{project}/locations/{location}/reasoningEngines/{reasoning_engine}/sessions/{session}'.""", ) - session_state: Optional[dict[str, Any]] = Field( + create_time: Optional[datetime.datetime] = Field( default=None, - description="""Optional. Session specific memory which stores key conversation points.""", - ) - ttl: Optional[str] = Field( - default=None, description="""Optional. Input only. The TTL for this session.""" + description="""Output only. Timestamp when the session was created.""", ) update_time: Optional[datetime.datetime] = Field( default=None, description="""Output only. Timestamp when the session was updated.""", ) + display_name: Optional[str] = Field( + default=None, description="""Optional. The display name of the session.""" + ) + labels: Optional[dict[str, str]] = Field( + default=None, + description="""The labels with user-defined metadata to organize your Sessions. Label keys and values can be no longer than 64 characters (Unicode codepoints), can only contain lowercase letters, numeric characters, underscores and dashes. International characters are allowed. See https://goo.gl/xmQnxf for more information and examples of labels.""", + ) + session_state: Optional[dict[str, Any]] = Field( + default=None, + description="""Optional. Session specific memory which stores key conversation points.""", + ) user_id: Optional[str] = Field( default=None, description="""Required. Immutable. String id provided by the user""", @@ -8881,27 +8892,30 @@ class Session(_common.BaseModel): class SessionDict(TypedDict, total=False): """A session.""" - create_time: Optional[datetime.datetime] - """Output only. Timestamp when the session was created.""" - - display_name: Optional[str] - """Optional. The display name of the session.""" - expire_time: Optional[datetime.datetime] """Optional. Timestamp of when this session is considered expired. This is *always* provided on output, regardless of what was sent on input.""" + ttl: Optional[str] + """Optional. Input only. The TTL for this session.""" + name: Optional[str] """Identifier. The resource name of the session. Format: 'projects/{project}/locations/{location}/reasoningEngines/{reasoning_engine}/sessions/{session}'.""" - session_state: Optional[dict[str, Any]] - """Optional. Session specific memory which stores key conversation points.""" - - ttl: Optional[str] - """Optional. Input only. The TTL for this session.""" + create_time: Optional[datetime.datetime] + """Output only. Timestamp when the session was created.""" update_time: Optional[datetime.datetime] """Output only. Timestamp when the session was updated.""" + display_name: Optional[str] + """Optional. The display name of the session.""" + + labels: Optional[dict[str, str]] + """The labels with user-defined metadata to organize your Sessions. Label keys and values can be no longer than 64 characters (Unicode codepoints), can only contain lowercase letters, numeric characters, underscores and dashes. International characters are allowed. See https://goo.gl/xmQnxf for more information and examples of labels.""" + + session_state: Optional[dict[str, Any]] + """Optional. Session specific memory which stores key conversation points.""" + user_id: Optional[str] """Required. Immutable. String id provided by the user""" @@ -9240,6 +9254,10 @@ class UpdateAgentEngineSessionConfig(_common.BaseModel): default=None, description="""Optional. Timestamp of when this resource is considered expired. This is *always* provided on output, regardless of what `expiration` was sent on input.""", ) + labels: Optional[dict[str, str]] = Field( + default=None, + description="""Optional. The labels with user-defined metadata to organize your Sessions. Label keys and values can be no longer than 64 characters (Unicode codepoints), can only contain lowercase letters, numeric characters, underscores and dashes. International characters are allowed. See https://goo.gl/xmQnxf for more information and examples of labels.""", + ) update_mask: Optional[str] = Field( default=None, description="""The update mask to apply. For the `FieldMask` definition, see @@ -9273,6 +9291,9 @@ class UpdateAgentEngineSessionConfigDict(TypedDict, total=False): expire_time: Optional[datetime.datetime] """Optional. Timestamp of when this resource is considered expired. This is *always* provided on output, regardless of what `expiration` was sent on input.""" + labels: Optional[dict[str, str]] + """Optional. The labels with user-defined metadata to organize your Sessions. Label keys and values can be no longer than 64 characters (Unicode codepoints), can only contain lowercase letters, numeric characters, underscores and dashes. International characters are allowed. See https://goo.gl/xmQnxf for more information and examples of labels.""" + update_mask: Optional[str] """The update mask to apply. For the `FieldMask` definition, see https://protobuf.dev/reference/protobuf/google.protobuf/#field-mask."""