Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions tests/unit/vertexai/genai/test_evals.py
Original file line number Diff line number Diff line change
Expand Up @@ -712,6 +712,7 @@ def test_inference_from_local_jsonl_file(self, mock_models):
assert inference_result.candidate_name == "gemini-pro"
assert inference_result.gcs_source is None

@pytest.mark.skip(reason="currently flakey")
@mock.patch.object(_evals_common, "Models")
def test_inference_from_local_csv_file(self, mock_models):
local_src_path = "/tmp/input.csv"
Expand Down
96 changes: 79 additions & 17 deletions vertexai/agent_engines/templates/adk.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,13 @@
except (ImportError, AttributeError):
Event = Any

try:
from google.adk.apps import App

App = App
except (ImportError, AttributeError):
App = Any

try:
from google.adk.agents import BaseAgent

Expand Down Expand Up @@ -449,7 +456,8 @@ class AdkApp:
def __init__(
self,
*,
agent: "BaseAgent",
app: "App" = None,
agent: "BaseAgent" = None,
app_name: Optional[str] = None,
plugins: Optional[List["BasePlugin"]] = None,
enable_tracing: Optional[bool] = None,
Expand Down Expand Up @@ -505,10 +513,26 @@ def __init__(
)
raise ValueError(msg)

if not agent and not app:
raise ValueError("One of `agent` or `app` must be provided.")
if app:
if app_name:
raise ValueError(
"When app is provided, app_name should not be provided."
)
if agent:
raise ValueError("When app is provided, agent should not be provided.")
if plugins:
raise ValueError(
"When app is provided, plugins should not be provided and"
" should be provided in the app instead."
)

self._tmpl_attrs: Dict[str, Any] = {
"project": initializer.global_config.project,
"location": initializer.global_config.location,
"agent": agent,
"app": app,
"app_name": app_name,
"plugins": plugins,
"enable_tracing": enable_tracing,
Expand Down Expand Up @@ -624,10 +648,23 @@ def clone(self):
import copy

return self.__class__(
agent=copy.deepcopy(self._tmpl_attrs.get("agent")),
app=copy.deepcopy(self._tmpl_attrs.get("app")),
enable_tracing=self._tmpl_attrs.get("enable_tracing"),
app_name=self._tmpl_attrs.get("app_name"),
plugins=self._tmpl_attrs.get("plugins"),
agent=(
None
if self._tmpl_attrs.get("app")
else copy.deepcopy(self._tmpl_attrs.get("agent"))
),
app_name=(
None
if self._tmpl_attrs.get("app")
else self._tmpl_attrs.get("app_name")
),
plugins=(
None
if self._tmpl_attrs.get("app")
else copy.deepcopy(self._tmpl_attrs.get("plugins"))
),
session_service_builder=self._tmpl_attrs.get("session_service_builder"),
artifact_service_builder=self._tmpl_attrs.get("artifact_service_builder"),
memory_service_builder=self._tmpl_attrs.get("memory_service_builder"),
Expand Down Expand Up @@ -774,20 +811,38 @@ def tracing_enabled() -> bool:
self._tmpl_attrs["memory_service"] = InMemoryMemoryService()

self._tmpl_attrs["runner"] = Runner(
agent=self._tmpl_attrs.get("agent"),
plugins=self._tmpl_attrs.get("plugins"),
app=self._tmpl_attrs.get("app"),
agent=(
None if self._tmpl_attrs.get("app") else self._tmpl_attrs.get("agent")
),
app_name=(
None
if self._tmpl_attrs.get("app")
else self._tmpl_attrs.get("app_name")
),
plugins=(
None if self._tmpl_attrs.get("app") else self._tmpl_attrs.get("plugins")
),
session_service=self._tmpl_attrs.get("session_service"),
artifact_service=self._tmpl_attrs.get("artifact_service"),
memory_service=self._tmpl_attrs.get("memory_service"),
app_name=self._tmpl_attrs.get("app_name"),
)
self._tmpl_attrs["in_memory_session_service"] = InMemorySessionService()
self._tmpl_attrs["in_memory_artifact_service"] = InMemoryArtifactService()
self._tmpl_attrs["in_memory_memory_service"] = InMemoryMemoryService()
self._tmpl_attrs["in_memory_runner"] = Runner(
app_name=self._tmpl_attrs.get("app_name"),
agent=self._tmpl_attrs.get("agent"),
plugins=self._tmpl_attrs.get("plugins"),
app=self._tmpl_attrs.get("app"),
app_name=(
None
if self._tmpl_attrs.get("app")
else self._tmpl_attrs.get("app_name")
),
agent=(
None if self._tmpl_attrs.get("app") else self._tmpl_attrs.get("agent")
),
plugins=(
None if self._tmpl_attrs.get("app") else self._tmpl_attrs.get("plugins")
),
session_service=self._tmpl_attrs.get("in_memory_session_service"),
artifact_service=self._tmpl_attrs.get("in_memory_artifact_service"),
memory_service=self._tmpl_attrs.get("in_memory_memory_service"),
Expand Down Expand Up @@ -968,12 +1023,13 @@ async def streaming_agent_run_with_events(self, request_json: str):
self.set_up()
session_service = self._tmpl_attrs.get("in_memory_session_service")
artifact_service = self._tmpl_attrs.get("in_memory_artifact_service")
app = self._tmpl_attrs.get("app")
# Try to get the session, if it doesn't exist, create a new one.
session = None
if request.session_id:
try:
session = await session_service.get_session(
app_name=self._tmpl_attrs.get("app_name"),
app_name=app.name if app else self._tmpl_attrs.get("app_name"),
user_id=request.user_id,
session_id=request.session_id,
)
Expand Down Expand Up @@ -1006,8 +1062,9 @@ async def streaming_agent_run_with_events(self, request_json: str):
yield converted_event
finally:
if session and not request.session_id:
app = self._tmpl_attrs.get("app")
await session_service.delete_session(
app_name=self._tmpl_attrs.get("app_name"),
app_name=app.name if app else self._tmpl_attrs.get("app_name"),
user_id=request.user_id,
session_id=session.id,
)
Expand Down Expand Up @@ -1039,8 +1096,9 @@ async def async_get_session(
"""
if not self._tmpl_attrs.get("session_service"):
self.set_up()
app = self._tmpl_attrs.get("app")
session = await self._tmpl_attrs.get("session_service").get_session(
app_name=self._tmpl_attrs.get("app_name"),
app_name=app.name if app else self._tmpl_attrs.get("app_name"),
user_id=user_id,
session_id=session_id,
**kwargs,
Expand Down Expand Up @@ -1116,8 +1174,9 @@ async def async_list_sessions(self, *, user_id: str, **kwargs):
"""
if not self._tmpl_attrs.get("session_service"):
self.set_up()
app = self._tmpl_attrs.get("app")
return await self._tmpl_attrs.get("session_service").list_sessions(
app_name=self._tmpl_attrs.get("app_name"),
app_name=app.name if app else self._tmpl_attrs.get("app_name"),
user_id=user_id,
**kwargs,
)
Expand Down Expand Up @@ -1188,8 +1247,9 @@ async def async_create_session(
"""
if not self._tmpl_attrs.get("session_service"):
self.set_up()
app = self._tmpl_attrs.get("app")
session = await self._tmpl_attrs.get("session_service").create_session(
app_name=self._tmpl_attrs.get("app_name"),
app_name=app.name if app else self._tmpl_attrs.get("app_name"),
user_id=user_id,
session_id=session_id,
state=state,
Expand Down Expand Up @@ -1269,8 +1329,9 @@ async def async_delete_session(
"""
if not self._tmpl_attrs.get("session_service"):
self.set_up()
app = self._tmpl_attrs.get("app")
await self._tmpl_attrs.get("session_service").delete_session(
app_name=self._tmpl_attrs.get("app_name"),
app_name=app.name if app else self._tmpl_attrs.get("app_name"),
user_id=user_id,
session_id=session_id,
**kwargs,
Expand Down Expand Up @@ -1359,8 +1420,9 @@ async def async_search_memory(self, *, user_id: str, query: str):
"""
if not self._tmpl_attrs.get("memory_service"):
self.set_up()
app = self._tmpl_attrs.get("app")
return await self._tmpl_attrs.get("memory_service").search_memory(
app_name=self._tmpl_attrs.get("app_name"),
app_name=app.name if app else self._tmpl_attrs.get("app_name"),
user_id=user_id,
query=query,
)
Expand Down