Skip to content

Commit 25ad3b1

Browse files
LitTool.from_model method to create LitTool from Pydantic (#57)
* from_model method to create LitTool from Pydantic Added a class method 'from_model' to create a LitTool from a Pydantic model, including setup and run methods for validation. * add test for run method * mypy silence * ruff format * lint * typehint change * revert to Optional[str] * Modify run method to return serialized instance Updated the run method to return a serialized instance instead of a model instance. * Update assertion in test_tool_run_from_model
1 parent fba9e18 commit 25ad3b1

File tree

2 files changed

+72
-0
lines changed

2 files changed

+72
-0
lines changed

src/litai/tools.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,25 @@ def _extract_parameters(self) -> Dict[str, Any]:
117117

118118
return LangchainTool()
119119

120+
@classmethod
121+
def from_model(cls, model: type[BaseModel]) -> "LitTool":
122+
"""Create a LitTool that exposes a Pydantic model as a structured schema."""
123+
124+
class ModelTool(LitTool):
125+
def setup(self) -> None:
126+
super().setup()
127+
self.name = model.__name__
128+
self.description = model.__doc__ or ""
129+
130+
def run(self, *args, **kwargs) -> Any: # type: ignore
131+
# Default implementation: validate & return a serialized instance
132+
return model(*args, **kwargs).model_dump()
133+
134+
def _extract_parameters(self) -> Dict[str, Any]:
135+
return model.model_json_schema()
136+
137+
return ModelTool()
138+
120139
@classmethod
121140
def convert_tools(cls, tools: Optional[Sequence[Union["LitTool", "StructuredTool"]]]) -> List["LitTool"]:
122141
"""Convert a list of tools into LitTool instances.

tests/test_tools.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,21 @@
1515

1616
import pytest
1717
from langchain_core.tools import tool as langchain_tool
18+
from pydantic import BaseModel
1819

1920
from litai import LitTool, tool
2021

2122

23+
@pytest.fixture
24+
def weather_tool_model():
25+
class WeatherRequest(BaseModel):
26+
"""Get weather for location."""
27+
28+
location: str
29+
30+
return WeatherRequest
31+
32+
2233
@pytest.fixture
2334
def basic_tool_class():
2435
class TestTool(LitTool):
@@ -226,3 +237,45 @@ def get_weather(city: str) -> str:
226237

227238
with pytest.raises(TypeError, match="Unsupported tool type: <class 'function'>"):
228239
LitTool.convert_tools([get_weather])
240+
241+
242+
def test_tool_from_model_with_no_description(weather_tool_model):
243+
weather_tool_model.__doc__ = None
244+
245+
lit_tool = LitTool.from_model(weather_tool_model)
246+
247+
assert isinstance(lit_tool, LitTool)
248+
assert lit_tool.name == "WeatherRequest"
249+
assert lit_tool.description == ""
250+
251+
assert lit_tool.as_tool() == {
252+
"type": "function",
253+
"function": {
254+
"name": "WeatherRequest",
255+
"description": "",
256+
"parameters": weather_tool_model.model_json_schema(),
257+
},
258+
}
259+
260+
261+
def test_tool_run_from_model(weather_tool_model):
262+
lit_tool = LitTool.from_model(weather_tool_model)
263+
264+
assert lit_tool.run(location="NYC") == weather_tool_model(location="NYC").model_dump()
265+
266+
267+
def test_tool_from_model_with_description(weather_tool_model):
268+
lit_tool = LitTool.from_model(weather_tool_model)
269+
270+
assert isinstance(lit_tool, LitTool)
271+
assert lit_tool.name == "WeatherRequest"
272+
assert lit_tool.description == "Get weather for location."
273+
274+
assert lit_tool.as_tool() == {
275+
"type": "function",
276+
"function": {
277+
"name": "WeatherRequest",
278+
"description": "Get weather for location.",
279+
"parameters": weather_tool_model.model_json_schema(),
280+
},
281+
}

0 commit comments

Comments
 (0)