Skip to content

Commit bc811e8

Browse files
committed
base_class as ImportString + validator for base_class in AgentDefinition
1 parent fd51bac commit bc811e8

File tree

1 file changed

+11
-3
lines changed

1 file changed

+11
-3
lines changed

sgr_deep_research/core/agent_definition.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import inspect
12
import logging
23
import os
34
from functools import cached_property
@@ -6,7 +7,7 @@
67

78
import yaml
89
from fastmcp.mcp_config import MCPConfig
9-
from pydantic import BaseModel, Field, FilePath, computed_field, model_validator
10+
from pydantic import BaseModel, Field, FilePath, ImportString, computed_field, field_validator, model_validator
1011

1112
logger = logging.getLogger(__name__)
1213

@@ -128,12 +129,11 @@ class AgentDefinition(AgentConfig):
128129

129130
name: str = Field(description="Unique agent name/ID")
130131
# ToDo: not sure how to type this properly and avoid circular imports
131-
base_class: type[Any] | str = Field(description="Agent class name")
132+
base_class: type[Any] | ImportString | str = Field(description="Agent class name")
132133
tools: list[type[Any] | str] = Field(default_factory=list, description="List of tool names to include")
133134

134135
@model_validator(mode="before")
135136
def default_config_override_validator(cls, data):
136-
print(data)
137137
from sgr_deep_research.core.agent_config import GlobalConfig
138138

139139
data["llm"] = GlobalConfig().llm.model_copy(update=data.get("llm", {})).model_dump()
@@ -157,6 +157,14 @@ def necessary_fields_validator(self) -> Self:
157157
raise ValueError(f"Tools are not provided for agent '{self.name}'")
158158
return self
159159

160+
@field_validator("base_class", mode="after")
161+
def base_class_is_agent(cls, v: Any) -> type[Any]:
162+
from sgr_deep_research.core.base_agent import BaseAgent
163+
164+
if inspect.isclass(v) and not issubclass(v, BaseAgent):
165+
raise TypeError("Imported base_class must be a subclass of BaseAgent")
166+
return v
167+
160168
def __str__(self) -> str:
161169
base_class_name = self.base_class.__name__ if isinstance(self.base_class, type) else self.base_class
162170
tool_names = [t.__name__ if isinstance(t, type) else t for t in self.tools]

0 commit comments

Comments
 (0)