|
15 | 15 |
|
16 | 16 | import pytest |
17 | 17 | from langchain_core.tools import tool as langchain_tool |
| 18 | +from pydantic import BaseModel |
18 | 19 |
|
19 | 20 | from litai import LitTool, tool |
20 | 21 |
|
21 | 22 |
|
| 23 | +@pytest.fixture |
| 24 | +def weather_tool_model(): |
| 25 | + class WeatherRequest(BaseModel): |
| 26 | + """Get weather for location.""" |
| 27 | + |
| 28 | + location: str |
| 29 | + |
| 30 | + return WeatherRequest |
| 31 | + |
| 32 | + |
22 | 33 | @pytest.fixture |
23 | 34 | def basic_tool_class(): |
24 | 35 | class TestTool(LitTool): |
@@ -226,3 +237,45 @@ def get_weather(city: str) -> str: |
226 | 237 |
|
227 | 238 | with pytest.raises(TypeError, match="Unsupported tool type: <class 'function'>"): |
228 | 239 | LitTool.convert_tools([get_weather]) |
| 240 | + |
| 241 | + |
| 242 | +def test_tool_from_model_with_no_description(weather_tool_model): |
| 243 | + weather_tool_model.__doc__ = None |
| 244 | + |
| 245 | + lit_tool = LitTool.from_model(weather_tool_model) |
| 246 | + |
| 247 | + assert isinstance(lit_tool, LitTool) |
| 248 | + assert lit_tool.name == "WeatherRequest" |
| 249 | + assert lit_tool.description == "" |
| 250 | + |
| 251 | + assert lit_tool.as_tool() == { |
| 252 | + "type": "function", |
| 253 | + "function": { |
| 254 | + "name": "WeatherRequest", |
| 255 | + "description": "", |
| 256 | + "parameters": weather_tool_model.model_json_schema(), |
| 257 | + }, |
| 258 | + } |
| 259 | + |
| 260 | + |
| 261 | +def test_tool_run_from_model(weather_tool_model): |
| 262 | + lit_tool = LitTool.from_model(weather_tool_model) |
| 263 | + |
| 264 | + assert lit_tool.run(location="NYC") == weather_tool_model(location="NYC").model_dump() |
| 265 | + |
| 266 | + |
| 267 | +def test_tool_from_model_with_description(weather_tool_model): |
| 268 | + lit_tool = LitTool.from_model(weather_tool_model) |
| 269 | + |
| 270 | + assert isinstance(lit_tool, LitTool) |
| 271 | + assert lit_tool.name == "WeatherRequest" |
| 272 | + assert lit_tool.description == "Get weather for location." |
| 273 | + |
| 274 | + assert lit_tool.as_tool() == { |
| 275 | + "type": "function", |
| 276 | + "function": { |
| 277 | + "name": "WeatherRequest", |
| 278 | + "description": "Get weather for location.", |
| 279 | + "parameters": weather_tool_model.model_json_schema(), |
| 280 | + }, |
| 281 | + } |
0 commit comments