Skip to content

Commit 123c2f4

Browse files
Jacksunweicopybara-github
authored andcommitted
feat: Adds CustomAgentConfig to support user-defined agents in config
PiperOrigin-RevId: 786456046
1 parent 884c201 commit 123c2f4

File tree

2 files changed

+188
-1
lines changed

2 files changed

+188
-1
lines changed

src/google/adk/agents/agent_config.py

Lines changed: 64 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,25 +14,88 @@
1414

1515
from __future__ import annotations
1616

17+
from typing import Any
18+
from typing import Literal
19+
from typing import Type
20+
from typing import TypeVar
1721
from typing import Union
1822

23+
from pydantic import ConfigDict
24+
from pydantic import Discriminator
1925
from pydantic import RootModel
2026

2127
from ..utils.feature_decorator import working_in_progress
28+
from .base_agent import BaseAgentConfig
2229
from .llm_agent import LlmAgentConfig
2330
from .loop_agent import LoopAgentConfig
2431
from .parallel_agent import ParallelAgentConfig
2532
from .sequential_agent import SequentialAgentConfig
2633

34+
TBaseAgentConfig = TypeVar("TBaseAgentConfig", bound=BaseAgentConfig)
35+
36+
37+
@working_in_progress("AgentConfig is not ready for use.")
38+
class CustomAgentConfig(BaseAgentConfig):
39+
"""Used for configs for user-defined custom agents."""
40+
41+
model_config = ConfigDict(
42+
extra="allow",
43+
)
44+
agent_class: Union[Literal["CustomAgent"], str] = "CustomAgent"
45+
46+
def to_agent_config(
47+
self, custom_agent_config_cls: Type[TBaseAgentConfig]
48+
) -> TBaseAgentConfig:
49+
"""Converts the this config to the concrete agent config type.
50+
51+
```
52+
# In my_custom_agent.py
53+
class MyCustomAgentConfig(BaseAgentConfig):
54+
agent_class: Literal["mylib.agents.MyCustomAgent"] = "mylib.agents.MyCustomAgent"
55+
other_field: str
56+
57+
class MyCustomAgent(BaseAgent):
58+
...
59+
60+
def from_config(
61+
cls: Type[MyCustomAgent],
62+
config: CustomAgentConfig,
63+
config_abs_path: str,
64+
) -> MyCustomAgent:
65+
my_custom_agent_config = config.to_agent_config(MyCustomAgentConfig)
66+
# use my_custom_agent_config for the remaining logic...
67+
68+
```
69+
"""
70+
return custom_agent_config_cls.model_validate(self.model_dump())
71+
72+
2773
# A discriminated union of all possible agent configurations.
2874
ConfigsUnion = Union[
2975
LlmAgentConfig,
3076
LoopAgentConfig,
3177
ParallelAgentConfig,
3278
SequentialAgentConfig,
79+
CustomAgentConfig,
3380
]
3481

3582

83+
def agent_config_discriminator(v: Any):
84+
if isinstance(v, dict):
85+
agent_class = v.get("agent_class", "LlmAgentConfig")
86+
if agent_class in [
87+
"LlmAgent",
88+
"LoopAgent",
89+
"ParallelAgent",
90+
"SequentialAgent",
91+
]:
92+
return agent_class
93+
else:
94+
return "CustomAgent"
95+
96+
raise ValueError(f"Invalid agent config: {v}")
97+
98+
3699
# Use a RootModel to represent the agent directly at the top level.
37100
# The `discriminator` is applied to the union within the RootModel.
38101
@working_in_progress("AgentConfig is not ready for use.")
@@ -43,4 +106,4 @@ class Config:
43106
# Pydantic v2 requires this for discriminated unions on RootModel
44107
# This tells the model to look at the 'agent_class' field of the input
45108
# data to decide which model from the `ConfigsUnion` to use.
46-
discriminator = "agent_class"
109+
discriminator = Discriminator(agent_config_discriminator)
Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,124 @@
1+
from typing import Literal
2+
3+
from google.adk.agents.agent_config import AgentConfig
4+
from google.adk.agents.agent_config import CustomAgentConfig
5+
from google.adk.agents.agent_config import LlmAgentConfig
6+
from google.adk.agents.agent_config import LoopAgentConfig
7+
from google.adk.agents.agent_config import ParallelAgentConfig
8+
from google.adk.agents.agent_config import SequentialAgentConfig
9+
from google.adk.agents.base_agent import BaseAgentConfig
10+
import yaml
11+
12+
13+
def test_agent_config_discriminator_default_is_llm_agent():
14+
yaml_content = """\
15+
name: search_agent
16+
model: gemini-2.0-flash
17+
description: a sample description
18+
instruction: a fake instruction
19+
tools:
20+
- name: google_search
21+
"""
22+
config_data = yaml.safe_load(yaml_content)
23+
24+
config = AgentConfig.model_validate(config_data)
25+
26+
assert isinstance(config.root, LlmAgentConfig)
27+
assert config.root.agent_class == "LlmAgent"
28+
29+
30+
def test_agent_config_discriminator_llm_agent():
31+
yaml_content = """\
32+
agent_class: LlmAgent
33+
name: search_agent
34+
model: gemini-2.0-flash
35+
description: a sample description
36+
instruction: a fake instruction
37+
tools:
38+
- name: google_search
39+
"""
40+
config_data = yaml.safe_load(yaml_content)
41+
42+
config = AgentConfig.model_validate(config_data)
43+
44+
assert isinstance(config.root, LlmAgentConfig)
45+
assert config.root.agent_class == "LlmAgent"
46+
47+
48+
def test_agent_config_discriminator_loop_agent():
49+
yaml_content = """\
50+
agent_class: LoopAgent
51+
name: CodePipelineAgent
52+
description: Executes a sequence of code writing, reviewing, and refactoring.
53+
sub_agents:
54+
- config: sub_agents/code_writer_agent.yaml
55+
- config: sub_agents/code_reviewer_agent.yaml
56+
- config: sub_agents/code_refactorer_agent.yaml
57+
"""
58+
config_data = yaml.safe_load(yaml_content)
59+
60+
config = AgentConfig.model_validate(config_data)
61+
62+
assert isinstance(config.root, LoopAgentConfig)
63+
assert config.root.agent_class == "LoopAgent"
64+
65+
66+
def test_agent_config_discriminator_parallel_agent():
67+
yaml_content = """\
68+
agent_class: ParallelAgent
69+
name: CodePipelineAgent
70+
description: Executes a sequence of code writing, reviewing, and refactoring.
71+
sub_agents:
72+
- config: sub_agents/code_writer_agent.yaml
73+
- config: sub_agents/code_reviewer_agent.yaml
74+
- config: sub_agents/code_refactorer_agent.yaml
75+
"""
76+
config_data = yaml.safe_load(yaml_content)
77+
78+
config = AgentConfig.model_validate(config_data)
79+
80+
assert isinstance(config.root, ParallelAgentConfig)
81+
assert config.root.agent_class == "ParallelAgent"
82+
83+
84+
def test_agent_config_discriminator_sequential_agent():
85+
yaml_content = """\
86+
agent_class: SequentialAgent
87+
name: CodePipelineAgent
88+
description: Executes a sequence of code writing, reviewing, and refactoring.
89+
sub_agents:
90+
- config: sub_agents/code_writer_agent.yaml
91+
- config: sub_agents/code_reviewer_agent.yaml
92+
- config: sub_agents/code_refactorer_agent.yaml
93+
"""
94+
config_data = yaml.safe_load(yaml_content)
95+
96+
config = AgentConfig.model_validate(config_data)
97+
98+
assert isinstance(config.root, SequentialAgentConfig)
99+
assert config.root.agent_class == "SequentialAgent"
100+
101+
102+
def test_agent_config_discriminator_custom_agent():
103+
class MyCustomAgentConfig(BaseAgentConfig):
104+
agent_class: Literal["mylib.agents.MyCustomAgent"] = (
105+
"mylib.agents.MyCustomAgent"
106+
)
107+
other_field: str
108+
109+
yaml_content = """\
110+
agent_class: mylib.agents.MyCustomAgent
111+
name: CodePipelineAgent
112+
description: Executes a sequence of code writing, reviewing, and refactoring.
113+
other_field: other valud
114+
"""
115+
config_data = yaml.safe_load(yaml_content)
116+
117+
config = AgentConfig.model_validate(config_data)
118+
119+
assert isinstance(config.root, CustomAgentConfig)
120+
assert config.root.agent_class == "mylib.agents.MyCustomAgent"
121+
assert config.root.model_extra == {"other_field": "other valud"}
122+
123+
my_custom_config = config.root.to_agent_config(MyCustomAgentConfig)
124+
assert my_custom_config.other_field == "other valud"

0 commit comments

Comments
 (0)