Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 54 additions & 48 deletions tests/unit/vertexai/genai/replays/test_create_agent_engine_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,63 +22,69 @@

def test_create_session_with_ttl(client):
agent_engine = client.agent_engines.create()
assert isinstance(agent_engine, types.AgentEngine)
assert isinstance(agent_engine.api_resource, types.ReasoningEngine)
try:
assert isinstance(agent_engine, types.AgentEngine)
assert isinstance(agent_engine.api_resource, types.ReasoningEngine)

operation = client.agent_engines.create_session(
name=agent_engine.api_resource.name,
user_id="test-user-123",
config=types.CreateAgentEngineSessionConfig(
display_name="my_session",
session_state={"foo": "bar"},
ttl="120s",
),
)
assert isinstance(operation, types.AgentEngineSessionOperation)
assert operation.response.display_name == "my_session"
assert operation.response.session_state == {"foo": "bar"}
assert operation.response.user_id == "test-user-123"
assert operation.response.name.startswith(agent_engine.api_resource.name)
# Expire time is calculated by the server, so we only check that it is
# within a reasonable range to avoid flakiness.
assert (
operation.response.create_time + datetime.timedelta(seconds=119.5)
<= operation.response.expire_time
<= operation.response.create_time + datetime.timedelta(seconds=120.5)
)
# Clean up resources.
client.agent_engines.delete(name=agent_engine.api_resource.name, force=True)
operation = client.agent_engines.create_session(
name=agent_engine.api_resource.name,
user_id="test-user-123",
config=types.CreateAgentEngineSessionConfig(
display_name="my_session",
session_state={"foo": "bar"},
ttl="120s",
labels={"label_key": "label_value"},
),
)
assert isinstance(operation, types.AgentEngineSessionOperation)
assert operation.response.display_name == "my_session"
assert operation.response.session_state == {"foo": "bar"}
assert operation.response.user_id == "test-user-123"
assert operation.response.labels == {"label_key": "label_value"}
assert operation.response.name.startswith(agent_engine.api_resource.name)
# Expire time is calculated by the server, so we only check that it is
# within a reasonable range to avoid flakiness.
assert (
operation.response.create_time + datetime.timedelta(seconds=119.5)
<= operation.response.expire_time
<= operation.response.create_time + datetime.timedelta(seconds=120.5)
)
finally:
# Clean up resources.
client.agent_engines.delete(name=agent_engine.api_resource.name, force=True)


def test_create_session_with_expire_time(client):
agent_engine = client.agent_engines.create()
assert isinstance(agent_engine, types.AgentEngine)
assert isinstance(agent_engine.api_resource, types.ReasoningEngine)
expire_time = datetime.datetime(
2026, 1, 1, 12, 30, 00, tzinfo=datetime.timezone.utc
)
try:
assert isinstance(agent_engine, types.AgentEngine)
assert isinstance(agent_engine.api_resource, types.ReasoningEngine)
expire_time = datetime.datetime(
2026, 1, 1, 12, 30, 00, tzinfo=datetime.timezone.utc
)

operation = client.agent_engines.sessions.create(
name=agent_engine.api_resource.name,
user_id="test-user-123",
config=types.CreateAgentEngineSessionConfig(
display_name="my_session",
session_state={"foo": "bar"},
expire_time=expire_time,
),
)
assert isinstance(operation, types.AgentEngineSessionOperation)
assert operation.response.display_name == "my_session"
assert operation.response.session_state == {"foo": "bar"}
assert operation.response.user_id == "test-user-123"
assert operation.response.name.startswith(agent_engine.api_resource.name)
assert operation.response.expire_time == expire_time
# Clean up resources.
client.agent_engines.delete(name=agent_engine.api_resource.name, force=True)
operation = client.agent_engines.sessions.create(
name=agent_engine.api_resource.name,
user_id="test-user-123",
config=types.CreateAgentEngineSessionConfig(
display_name="my_session",
session_state={"foo": "bar"},
expire_time=expire_time,
),
)
assert isinstance(operation, types.AgentEngineSessionOperation)
assert operation.response.display_name == "my_session"
assert operation.response.session_state == {"foo": "bar"}
assert operation.response.user_id == "test-user-123"
assert operation.response.name.startswith(agent_engine.api_resource.name)
assert operation.response.expire_time == expire_time
finally:
# Clean up resources.
client.agent_engines.delete(name=agent_engine.api_resource.name, force=True)


pytestmark = pytest_helper.setup(
file=__file__,
globals_for_file=globals(),
test_method="agent_engines.create_session",
test_method="agent_engines.sessions.create",
)
6 changes: 6 additions & 0 deletions vertexai/_genai/sessions.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,9 @@ def _CreateAgentEngineSessionConfig_to_vertex(
if getv(from_object, ["expire_time"]) is not None:
setv(parent_object, ["expireTime"], getv(from_object, ["expire_time"]))

if getv(from_object, ["labels"]) is not None:
setv(parent_object, ["labels"], getv(from_object, ["labels"]))

return to_object


Expand Down Expand Up @@ -181,6 +184,9 @@ def _UpdateAgentEngineSessionConfig_to_vertex(
if getv(from_object, ["expire_time"]) is not None:
setv(parent_object, ["expireTime"], getv(from_object, ["expire_time"]))

if getv(from_object, ["labels"]) is not None:
setv(parent_object, ["labels"], getv(from_object, ["labels"]))

if getv(from_object, ["update_mask"]) is not None:
setv(
parent_object, ["_query", "updateMask"], getv(from_object, ["update_mask"])
Expand Down
67 changes: 44 additions & 23 deletions vertexai/_genai/types/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -8778,6 +8778,10 @@ class CreateAgentEngineSessionConfig(_common.BaseModel):
default=None,
description="""Optional. Timestamp of when this resource is considered expired. This is *always* provided on output, regardless of what `expiration` was sent on input.""",
)
labels: Optional[dict[str, str]] = Field(
default=None,
description="""Optional. The labels with user-defined metadata to organize your Sessions. Label keys and values can be no longer than 64 characters (Unicode codepoints), can only contain lowercase letters, numeric characters, underscores and dashes. International characters are allowed. See https://goo.gl/xmQnxf for more information and examples of labels.""",
)


class CreateAgentEngineSessionConfigDict(TypedDict, total=False):
Expand All @@ -8803,6 +8807,9 @@ class CreateAgentEngineSessionConfigDict(TypedDict, total=False):
expire_time: Optional[datetime.datetime]
"""Optional. Timestamp of when this resource is considered expired. This is *always* provided on output, regardless of what `expiration` was sent on input."""

labels: Optional[dict[str, str]]
"""Optional. The labels with user-defined metadata to organize your Sessions. Label keys and values can be no longer than 64 characters (Unicode codepoints), can only contain lowercase letters, numeric characters, underscores and dashes. International characters are allowed. See https://goo.gl/xmQnxf for more information and examples of labels."""


CreateAgentEngineSessionConfigOrDict = Union[
CreateAgentEngineSessionConfig, CreateAgentEngineSessionConfigDict
Expand Down Expand Up @@ -8846,32 +8853,36 @@ class _CreateAgentEngineSessionRequestParametersDict(TypedDict, total=False):
class Session(_common.BaseModel):
"""A session."""

create_time: Optional[datetime.datetime] = Field(
default=None,
description="""Output only. Timestamp when the session was created.""",
)
display_name: Optional[str] = Field(
default=None, description="""Optional. The display name of the session."""
)
expire_time: Optional[datetime.datetime] = Field(
default=None,
description="""Optional. Timestamp of when this session is considered expired. This is *always* provided on output, regardless of what was sent on input.""",
)
ttl: Optional[str] = Field(
default=None, description="""Optional. Input only. The TTL for this session."""
)
name: Optional[str] = Field(
default=None,
description="""Identifier. The resource name of the session. Format: 'projects/{project}/locations/{location}/reasoningEngines/{reasoning_engine}/sessions/{session}'.""",
)
session_state: Optional[dict[str, Any]] = Field(
create_time: Optional[datetime.datetime] = Field(
default=None,
description="""Optional. Session specific memory which stores key conversation points.""",
)
ttl: Optional[str] = Field(
default=None, description="""Optional. Input only. The TTL for this session."""
description="""Output only. Timestamp when the session was created.""",
)
update_time: Optional[datetime.datetime] = Field(
default=None,
description="""Output only. Timestamp when the session was updated.""",
)
display_name: Optional[str] = Field(
default=None, description="""Optional. The display name of the session."""
)
labels: Optional[dict[str, str]] = Field(
default=None,
description="""The labels with user-defined metadata to organize your Sessions. Label keys and values can be no longer than 64 characters (Unicode codepoints), can only contain lowercase letters, numeric characters, underscores and dashes. International characters are allowed. See https://goo.gl/xmQnxf for more information and examples of labels.""",
)
session_state: Optional[dict[str, Any]] = Field(
default=None,
description="""Optional. Session specific memory which stores key conversation points.""",
)
user_id: Optional[str] = Field(
default=None,
description="""Required. Immutable. String id provided by the user""",
Expand All @@ -8881,27 +8892,30 @@ class Session(_common.BaseModel):
class SessionDict(TypedDict, total=False):
"""A session."""

create_time: Optional[datetime.datetime]
"""Output only. Timestamp when the session was created."""

display_name: Optional[str]
"""Optional. The display name of the session."""

expire_time: Optional[datetime.datetime]
"""Optional. Timestamp of when this session is considered expired. This is *always* provided on output, regardless of what was sent on input."""

ttl: Optional[str]
"""Optional. Input only. The TTL for this session."""

name: Optional[str]
"""Identifier. The resource name of the session. Format: 'projects/{project}/locations/{location}/reasoningEngines/{reasoning_engine}/sessions/{session}'."""

session_state: Optional[dict[str, Any]]
"""Optional. Session specific memory which stores key conversation points."""

ttl: Optional[str]
"""Optional. Input only. The TTL for this session."""
create_time: Optional[datetime.datetime]
"""Output only. Timestamp when the session was created."""

update_time: Optional[datetime.datetime]
"""Output only. Timestamp when the session was updated."""

display_name: Optional[str]
"""Optional. The display name of the session."""

labels: Optional[dict[str, str]]
"""The labels with user-defined metadata to organize your Sessions. Label keys and values can be no longer than 64 characters (Unicode codepoints), can only contain lowercase letters, numeric characters, underscores and dashes. International characters are allowed. See https://goo.gl/xmQnxf for more information and examples of labels."""

session_state: Optional[dict[str, Any]]
"""Optional. Session specific memory which stores key conversation points."""

user_id: Optional[str]
"""Required. Immutable. String id provided by the user"""

Expand Down Expand Up @@ -9240,6 +9254,10 @@ class UpdateAgentEngineSessionConfig(_common.BaseModel):
default=None,
description="""Optional. Timestamp of when this resource is considered expired. This is *always* provided on output, regardless of what `expiration` was sent on input.""",
)
labels: Optional[dict[str, str]] = Field(
default=None,
description="""Optional. The labels with user-defined metadata to organize your Sessions. Label keys and values can be no longer than 64 characters (Unicode codepoints), can only contain lowercase letters, numeric characters, underscores and dashes. International characters are allowed. See https://goo.gl/xmQnxf for more information and examples of labels.""",
)
update_mask: Optional[str] = Field(
default=None,
description="""The update mask to apply. For the `FieldMask` definition, see
Expand Down Expand Up @@ -9273,6 +9291,9 @@ class UpdateAgentEngineSessionConfigDict(TypedDict, total=False):
expire_time: Optional[datetime.datetime]
"""Optional. Timestamp of when this resource is considered expired. This is *always* provided on output, regardless of what `expiration` was sent on input."""

labels: Optional[dict[str, str]]
"""Optional. The labels with user-defined metadata to organize your Sessions. Label keys and values can be no longer than 64 characters (Unicode codepoints), can only contain lowercase letters, numeric characters, underscores and dashes. International characters are allowed. See https://goo.gl/xmQnxf for more information and examples of labels."""

update_mask: Optional[str]
"""The update mask to apply. For the `FieldMask` definition, see
https://protobuf.dev/reference/protobuf/google.protobuf/#field-mask."""
Expand Down
Loading