From 311395df85c2895b8e52a038a10f1cb1124029db Mon Sep 17 00:00:00 2001
From: Julian Bright
Date: Tue, 11 Nov 2025 17:03:36 -0800
Subject: [PATCH 01/16] Initial commit for xAI Grok model

---
 docs/api/models/grok.md                        |   7 +
 docs/models/grok.md                            |  77 ++++
 docs/models/overview.md                        |   1 +
 pydantic_ai_slim/pydantic_ai/models/grok.py    | 340 ++++++++++++++++++
 pydantic_ai_slim/pydantic_ai/profiles/grok.py  |  18 +-
 pydantic_ai_slim/pydantic_ai/providers/grok.py |   2 +
 pydantic_ai_slim/pyproject.toml                |   1 +
 pyproject.toml                                 |   2 +-
 uv.lock                                        | 232 ++++++++----
 9 files changed, 600 insertions(+), 80 deletions(-)
 create mode 100644 docs/api/models/grok.md
 create mode 100644 docs/models/grok.md
 create mode 100644 pydantic_ai_slim/pydantic_ai/models/grok.py

diff --git a/docs/api/models/grok.md b/docs/api/models/grok.md
new file mode 100644
index 0000000000..699c1e95f9
--- /dev/null
+++ b/docs/api/models/grok.md
@@ -0,0 +1,7 @@
+# `pydantic_ai.models.grok`
+
+## Setup
+
+For details on how to set up authentication with this model, see [model configuration for Grok](../../models/grok.md).
+
+::: pydantic_ai.models.grok
diff --git a/docs/models/grok.md b/docs/models/grok.md
new file mode 100644
index 0000000000..db04423493
--- /dev/null
+++ b/docs/models/grok.md
@@ -0,0 +1,77 @@
+# Grok
+
+## Install
+
+To use `GrokModel`, you need to either install `pydantic-ai`, or install `pydantic-ai-slim` with the `grok` optional group:
+
+```bash
+pip/uv-add "pydantic-ai-slim[grok]"
+```
+
+## Configuration
+
+To use [Grok](https://x.ai/) through the xAI API, go to [console.x.ai](https://console.x.ai/) and generate an API key.
+
+`GrokModelName` contains a list of available Grok models.
+
+## Environment variable
+
+Once you have the API key, you can set it as an environment variable:
+
+```bash
+export XAI_API_KEY='your-api-key'
+```
+
+You can then use `GrokModel` with the model name:
+
+```python
+from pydantic_ai import Agent
+from pydantic_ai.models.grok import GrokModel
+
+model = GrokModel('grok-4')
+agent = Agent(model)
+...
+```
+
+## `api_key` argument
+
+If you don't want to set the API key in the environment, you can pass it directly when initialising the model:
+
+```python
+from pydantic_ai import Agent
+from pydantic_ai.models.grok import GrokModel
+
+model = GrokModel('grok-4', api_key='your-api-key')
+agent = Agent(model)
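+
+# A quick usage sketch: assumes the API key above is valid; `run_sync` and
+# `result.output` are the standard pydantic_ai Agent APIs.
+result = agent.run_sync('What is the capital of France?')
+print(result.output)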
+```
diff --git a/docs/models/overview.md b/docs/models/overview.md
index 75cf954b11..8e7b4aadc9 100644
--- a/docs/models/overview.md
+++ b/docs/models/overview.md
@@ -5,6 +5,7 @@ Pydantic AI is model-agnostic and has built-in support for multiple model provid
 * [OpenAI](openai.md)
 * [Anthropic](anthropic.md)
 * [Gemini](google.md) (via two different APIs: Generative Language API and VertexAI API)
+* [Grok](grok.md)
 * [Groq](groq.md)
 * [Mistral](mistral.md)
 * [Cohere](cohere.md)
diff --git a/pydantic_ai_slim/pydantic_ai/models/grok.py b/pydantic_ai_slim/pydantic_ai/models/grok.py
new file mode 100644
index 0000000000..5ec8c0ebf6
--- /dev/null
+++ b/pydantic_ai_slim/pydantic_ai/models/grok.py
@@ -0,0 +1,340 @@
+"""Grok model implementation using the xAI SDK."""
+
+import os
+from collections.abc import AsyncIterator
+from contextlib import asynccontextmanager
+from typing import Any, cast
+
+# Import xai_sdk components (installed via the `grok` optional group)
+import xai_sdk.chat as chat_types
+from xai_sdk import AsyncClient
+from xai_sdk.chat import assistant, system, tool, tool_result, user
+
+from .._run_context import RunContext
+from .._utils import now_utc
+from ..messages import (
+    FinishReason,
+    ModelMessage,
+    ModelRequest,
+    ModelResponse,
+    SystemPromptPart,
+    TextPart,
+    ToolCallPart,
+    ToolReturnPart,
+    UserPromptPart,
+)
+from ..models import (
+    Model,
+    ModelRequestParameters,
+    ModelSettings,
+    StreamedResponse,
+)
+from ..usage import RequestUsage
+
+
+class GrokModel(Model):
+    """A model that uses the xAI SDK to interact with Grok."""
+
+    _model_name: str
+    _api_key: str
+
+    def __init__(
+        self,
+        model_name: str,
+        *,
+        api_key: str | None = None,
+        settings: ModelSettings | None = None,
+    ):
+        """Initialize the Grok model.
+
+        Args:
+            model_name: The name of the Grok model to use (e.g. "grok-3", "grok-4-fast-non-reasoning").
+            api_key: The xAI API key. If not provided, the XAI_API_KEY environment variable is used.
+            settings: Optional model settings.
+        """
+        super().__init__(settings=settings)
+        self._model_name = model_name
+        self._api_key = api_key or os.getenv("XAI_API_KEY") or ""
+        if not self._api_key:
+            raise ValueError("XAI API key is required")
+
+    @property
+    def model_name(self) -> str:
+        """The model name."""
+        return self._model_name
+
+    @property
+    def system(self) -> str:
+        """The model provider."""
+        return "xai"
+
+    def _map_messages(self, messages: list[ModelMessage]) -> list[chat_types.chat_pb2.Message]:
+        """Convert pydantic_ai messages to xAI SDK messages."""
+        xai_messages: list[chat_types.chat_pb2.Message] = []
+
+        for message in messages:
+            if isinstance(message, ModelRequest):
+                for part in message.parts:
+                    if isinstance(part, SystemPromptPart):
+                        xai_messages.append(system(part.content))
+                    elif isinstance(part, UserPromptPart):
+                        # Handle user prompt content
+                        if isinstance(part.content, str):
+                            xai_messages.append(user(part.content))
+                        else:
+                            # Handle complex content (images, etc.)
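+                            # Only plain-string items are forwarded below; non-text
+                            # items (images, etc.) are dropped rather than sent to
+                            # the API.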
+                            # For now, just concatenate text content
+                            text_parts = []
+                            for item in part.content:
+                                if isinstance(item, str):
+                                    text_parts.append(item)
+                            if text_parts:
+                                xai_messages.append(user(" ".join(text_parts)))
+                    elif isinstance(part, ToolReturnPart):
+                        xai_messages.append(tool_result(part.model_response_str()))
+            elif isinstance(message, ModelResponse):
+                content_parts = []
+                for part in message.parts:
+                    if isinstance(part, TextPart):
+                        content_parts.append(part.content)
+                    elif isinstance(part, ToolCallPart):
+                        # Assistant tool calls from history are not replayed to the
+                        # API; only the text content of the response is sent back.
+                        pass
+
+                if content_parts:
+                    xai_messages.append(assistant(" ".join(content_parts)))
+
+        return xai_messages
+
+    def _map_tools(
+        self, model_request_parameters: ModelRequestParameters
+    ) -> list[chat_types.chat_pb2.Tool]:
+        """Convert pydantic_ai tool definitions to xAI SDK tools."""
+        tools = []
+        for tool_def in model_request_parameters.tool_defs.values():
+            xai_tool = tool(
+                name=tool_def.name,
+                description=tool_def.description or "",
+                parameters=tool_def.parameters_json_schema,
+            )
+            tools.append(xai_tool)
+        return tools
+
+    async def request(
+        self,
+        messages: list[ModelMessage],
+        model_settings: ModelSettings | None,
+        model_request_parameters: ModelRequestParameters,
+    ) -> ModelResponse:
+        """Make a request to the Grok model."""
+        # Create the client in the current async context to avoid event loop issues
+        client = AsyncClient(api_key=self._api_key)
+
+        # Convert messages to xAI format
+        xai_messages = self._map_messages(messages)
+
+        # Convert tools if any
+        tools = (
+            self._map_tools(model_request_parameters)
+            if model_request_parameters.tool_defs
+            else None
+        )
+
+        # Filter model settings to only include xAI SDK compatible parameters
+        xai_settings: dict[str, Any] = {}
+        if model_settings:
+            # Map pydantic_ai settings to xAI SDK parameters
+            if "temperature" in model_settings:
+                xai_settings["temperature"] = model_settings["temperature"]
+            if "top_p" in model_settings:
+                xai_settings["top_p"] = model_settings["top_p"]
+            if "max_tokens" in model_settings:
+                xai_settings["max_tokens"] = model_settings["max_tokens"]
+            if "stop_sequences" in model_settings:
+                xai_settings["stop"] = model_settings["stop_sequences"]
+            if "seed" in model_settings:
+                xai_settings["seed"] = model_settings["seed"]
+
+        # Create chat instance
+        chat = client.chat.create(
+            model=self._model_name, messages=xai_messages, tools=tools, **xai_settings
+        )
+
+        # Sample the response
+        response = await chat.sample()
+
+        # Convert response to pydantic_ai format
+        return self._process_response(response)
+
+    @asynccontextmanager
+    async def request_stream(
+        self,
+        messages: list[ModelMessage],
+        model_settings: ModelSettings | None,
+        model_request_parameters: ModelRequestParameters,
+        run_context: RunContext[Any] | None = None,
+    ) -> AsyncIterator[StreamedResponse]:
+        """Make a streaming request to the Grok model."""
+        # Create the client in the current async context to avoid event loop issues
+        client = AsyncClient(api_key=self._api_key)
+
+        # Convert messages to xAI format
+        xai_messages = self._map_messages(messages)
+
+        # Convert tools if any
+        tools = (
+            self._map_tools(model_request_parameters)
+            if model_request_parameters.tool_defs
+            else None
+        )
+
+        # Filter model settings to only include xAI SDK compatible parameters
+        xai_settings: dict[str, Any] = {}
+        if model_settings:
+            # Map pydantic_ai settings to xAI SDK parameters
+            if "temperature" in model_settings:
+                xai_settings["temperature"] = model_settings["temperature"]
+            if "top_p" in model_settings:
xai_settings["top_p"] = model_settings["top_p"] + if "max_tokens" in model_settings: + xai_settings["max_tokens"] = model_settings["max_tokens"] + if "stop_sequences" in model_settings: + xai_settings["stop"] = model_settings["stop_sequences"] + if "seed" in model_settings: + xai_settings["seed"] = model_settings["seed"] + + # Create chat instance + chat = client.chat.create( + model=self._model_name, messages=xai_messages, tools=tools, **xai_settings + ) + + # Stream the response + response_stream = chat.stream() + streamed_response = GrokStreamedResponse(model_request_parameters) + streamed_response._model_name = self._model_name + streamed_response._response = response_stream + streamed_response._timestamp = now_utc() + streamed_response._provider_name = "xai" + yield streamed_response + + def _process_response(self, response: chat_types.Response) -> ModelResponse: + """Convert xAI SDK response to pydantic_ai ModelResponse.""" + from typing import cast + + parts = [] + + # Add text content + if response.content: + parts.append(TextPart(content=response.content)) + + # Add tool calls + for tool_call in response.tool_calls: + parts.append( + ToolCallPart( + tool_name=tool_call.function.name, + args=tool_call.function.arguments, + tool_call_id=tool_call.id, + ) + ) + + # Convert usage - try to access attributes, default to 0 if not available + input_tokens = getattr(response.usage, "input_tokens", 0) + output_tokens = getattr(response.usage, "output_tokens", 0) + usage = RequestUsage(input_tokens=input_tokens, output_tokens=output_tokens) + + # Map finish reason + finish_reason_map = { + "stop": "stop", + "length": "length", + "content_filter": "content_filter", + "max_output_tokens": "length", + "cancelled": "error", + "failed": "error", + } + raw_finish_reason = response.finish_reason + mapped_reason = ( + finish_reason_map.get(raw_finish_reason, "stop") + if isinstance(raw_finish_reason, str) + else "stop" + ) + finish_reason = cast(FinishReason, mapped_reason) + + return ModelResponse( + parts=parts, + usage=usage, + model_name=self._model_name, + timestamp=now_utc(), + provider_name="xai", + finish_reason=finish_reason, + ) + + +class GrokStreamedResponse(StreamedResponse): + """Implementation of `StreamedResponse` for xAI SDK.""" + + _model_name: str + _response: Any # xai_sdk chat stream + _timestamp: Any + _provider_name: str + _usage: RequestUsage + provider_response_id: str | None + finish_reason: Any + + async def _get_event_iterator(self): + """Iterate over streaming events from xAI SDK.""" + from typing import cast + + async for response, chunk in self._response: + # Update usage if available + if hasattr(response, "usage"): + input_tokens = getattr(response.usage, "input_tokens", 0) + output_tokens = getattr(response.usage, "output_tokens", 0) + self._usage = RequestUsage(input_tokens=input_tokens, output_tokens=output_tokens) + + # Set provider response ID + if hasattr(response, "id") and self.provider_response_id is None: + self.provider_response_id = response.id + + # Handle finish reason + if hasattr(response, "finish_reason") and response.finish_reason: + finish_reason_map = { + "stop": "stop", + "length": "length", + "content_filter": "content_filter", + "max_output_tokens": "length", + "cancelled": "error", + "failed": "error", + } + mapped_reason = finish_reason_map.get(response.finish_reason, "stop") + self.finish_reason = cast(FinishReason, mapped_reason) + + # Handle text content + if hasattr(chunk, "content") and chunk.content: + yield 
+                yield self._parts_manager.handle_text_delta(
+                    vendor_part_id="content",
+                    content=chunk.content,
+                )
+
+            # Handle tool calls
+            if hasattr(chunk, "tool_calls"):
+                for tool_call in chunk.tool_calls:
+                    yield self._parts_manager.handle_tool_call_part(
+                        vendor_part_id=tool_call.id,
+                        tool_name=tool_call.function.name,
+                        args=tool_call.function.arguments,
+                        tool_call_id=tool_call.id,
+                    )
+
+    @property
+    def model_name(self) -> str:
+        """Get the model name of the response."""
+        return self._model_name
+
+    @property
+    def provider_name(self) -> str:
+        """Get the provider name."""
+        return self._provider_name
+
+    @property
+    def timestamp(self):
+        """Get the timestamp of the response."""
+        return self._timestamp
diff --git a/pydantic_ai_slim/pydantic_ai/profiles/grok.py b/pydantic_ai_slim/pydantic_ai/profiles/grok.py
index 3b7c4a3746..9a1a9317c4 100644
--- a/pydantic_ai_slim/pydantic_ai/profiles/grok.py
+++ b/pydantic_ai_slim/pydantic_ai/profiles/grok.py
@@ -1,8 +1,24 @@
 from __future__ import annotations as _annotations
 
+from dataclasses import dataclass
+
 from . import ModelProfile
 
 
+@dataclass(kw_only=True)
+class GrokModelProfile(ModelProfile):
+    """Profile for models used with GrokModel.
+
+    ALL FIELDS MUST BE `grok_` PREFIXED SO YOU CAN MERGE THEM WITH OTHER MODELS.
+    """
+
+    grok_builtin_tool: bool = False
+    """Whether the model always has the web search built-in tool available."""
+
+
 def grok_model_profile(model_name: str) -> ModelProfile | None:
     """Get the model profile for a Grok model."""
-    return None
+    return GrokModelProfile(
+        # grok-4 models always have the web search built-in tool available
+        grok_builtin_tool=model_name.startswith('grok-4'),
+    )
diff --git a/pydantic_ai_slim/pydantic_ai/providers/grok.py b/pydantic_ai_slim/pydantic_ai/providers/grok.py
index 604a38abbf..65ae8946d7 100644
--- a/pydantic_ai_slim/pydantic_ai/providers/grok.py
+++ b/pydantic_ai_slim/pydantic_ai/providers/grok.py
@@ -25,6 +25,8 @@
 GrokModelName = Literal[
     'grok-4',
     'grok-4-0709',
+    'grok-4-fast-reasoning',
+    'grok-4-fast-non-reasoning',
     'grok-3',
     'grok-3-mini',
     'grok-3-fast',
diff --git a/pydantic_ai_slim/pyproject.toml b/pydantic_ai_slim/pyproject.toml
index 79a3982c02..85fb2a7e0c 100644
--- a/pydantic_ai_slim/pyproject.toml
+++ b/pydantic_ai_slim/pyproject.toml
@@ -72,6 +72,7 @@ cohere = ["cohere>=5.18.0; platform_system != 'Emscripten'"]
 vertexai = ["google-auth>=2.36.0", "requests>=2.32.2"]
 google = ["google-genai>=1.46.0"]
 anthropic = ["anthropic>=0.70.0"]
+grok = ["xai-sdk>=1.4.0"]
 groq = ["groq>=0.25.0"]
 mistral = ["mistralai>=1.9.10"]
 bedrock = ["boto3>=1.39.0"]
diff --git a/pyproject.toml b/pyproject.toml
index 3c13afdece..b8aa47e7f3 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -46,7 +46,7 @@ requires-python = ">=3.10"
 [tool.hatch.metadata.hooks.uv-dynamic-versioning]
 dependencies = [
-    "pydantic-ai-slim[openai,vertexai,google,groq,anthropic,mistral,cohere,bedrock,huggingface,cli,mcp,fastmcp,evals,ag-ui,retries,temporal,logfire,ui]=={{ version }}",
+    "pydantic-ai-slim[openai,vertexai,google,grok,groq,anthropic,mistral,cohere,bedrock,huggingface,cli,mcp,fastmcp,evals,ag-ui,retries,temporal,logfire,ui]=={{ version }}",
 ]
 
 [tool.hatch.metadata.hooks.uv-dynamic-versioning.optional-dependencies]
diff --git a/uv.lock b/uv.lock
index 7d94519296..6ae5258d6f 100644
--- a/uv.lock
+++ b/uv.lock
@@ -1,5 +1,5 @@
 version = 1
-revision = 3
+revision = 2
 requires-python = ">=3.10"
 resolution-markers = [
     "python_full_version >= '3.13' and platform_python_implementation == 'PyPy'",
@@ -1410,18 +1410,6 @@
 wheels = [
     { url =
"https://files.pythonhosted.org/packages/07/6c/aa3f2f849e01cb6a001cd8554a88d4c77c5c1a31c95bdf1cf9301e6d9ef4/defusedxml-0.7.1-py2.py3-none-any.whl", hash = "sha256:a352e7e428770286cc899e2542b6cdaedb2b4953ff269a210103ec58f6198a61", size = 25604, upload-time = "2021-03-08T10:59:24.45Z" }, ] -[[package]] -name = "deprecated" -version = "1.2.18" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "wrapt" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/98/97/06afe62762c9a8a86af0cfb7bfdab22a43ad17138b07af5b1a58442690a2/deprecated-1.2.18.tar.gz", hash = "sha256:422b6f6d859da6f2ef57857761bfb392480502a64c3028ca9bbe86085d72115d", size = 2928744, upload-time = "2025-01-27T10:46:25.7Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/6e/c6/ac0b6c1e2d138f1002bcf799d330bd6d85084fece321e662a14223794041/Deprecated-1.2.18-py2.py3-none-any.whl", hash = "sha256:bd5011788200372a32418f888e326a09ff80d0214bd961147cfed01b5c018eec", size = 9998, upload-time = "2025-01-27T10:46:09.186Z" }, -] - [[package]] name = "depyf" version = "0.19.0" @@ -2015,14 +2003,14 @@ wheels = [ [[package]] name = "googleapis-common-protos" -version = "1.68.0" +version = "1.72.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "protobuf" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/54/d2/c08f0d9f94b45faca68e355771329cba2411c777c8713924dd1baee0e09c/googleapis_common_protos-1.68.0.tar.gz", hash = "sha256:95d38161f4f9af0d9423eed8fb7b64ffd2568c3464eb542ff02c5bfa1953ab3c", size = 57367, upload-time = "2025-02-20T19:08:28.426Z" } +sdist = { url = "https://files.pythonhosted.org/packages/e5/7b/adfd75544c415c487b33061fe7ae526165241c1ea133f9a9125a56b39fd8/googleapis_common_protos-1.72.0.tar.gz", hash = "sha256:e55a601c1b32b52d7a3e65f43563e2aa61bcd737998ee672ac9b951cd49319f5", size = 147433, upload-time = "2025-11-06T18:29:24.087Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/3f/85/c99a157ee99d67cc6c9ad123abb8b1bfb476fab32d2f3511c59314548e4f/googleapis_common_protos-1.68.0-py2.py3-none-any.whl", hash = "sha256:aaf179b2f81df26dfadac95def3b16a95064c76a5f45f07e4c68a21bb371c4ac", size = 164985, upload-time = "2025-02-20T19:08:26.964Z" }, + { url = "https://files.pythonhosted.org/packages/c4/ab/09169d5a4612a5f92490806649ac8d41e3ec9129c636754575b3553f4ea4/googleapis_common_protos-1.72.0-py3-none-any.whl", hash = "sha256:4299c5a82d5ae1a9702ada957347726b167f9f8d1fc352477702a1e851ff4038", size = 297515, upload-time = "2025-11-06T18:29:13.14Z" }, ] [[package]] @@ -2193,6 +2181,67 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/4d/11/1019a6cfdb2e520cb461cf70d859216be8ca122ddf5ad301fc3b0ee45fd4/groq-0.25.0-py3-none-any.whl", hash = "sha256:aadc78b40b1809cdb196b1aa8c7f7293108767df1508cafa3e0d5045d9328e7a", size = 129371, upload-time = "2025-05-16T19:57:41.786Z" }, ] +[[package]] +name = "grpcio" +version = "1.76.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/b6/e0/318c1ce3ae5a17894d5791e87aea147587c9e702f24122cc7a5c8bbaeeb1/grpcio-1.76.0.tar.gz", hash = "sha256:7be78388d6da1a25c0d5ec506523db58b18be22d9c37d8d3a32c08be4987bd73", size = 12785182, upload-time = "2025-10-21T16:23:12.106Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/88/17/ff4795dc9a34b6aee6ec379f1b66438a3789cd1315aac0cbab60d92f74b3/grpcio-1.76.0-cp310-cp310-linux_armv7l.whl", hash = 
"sha256:65a20de41e85648e00305c1bb09a3598f840422e522277641145a32d42dcefcc", size = 5840037, upload-time = "2025-10-21T16:20:25.069Z" }, + { url = "https://files.pythonhosted.org/packages/4e/ff/35f9b96e3fa2f12e1dcd58a4513a2e2294a001d64dec81677361b7040c9a/grpcio-1.76.0-cp310-cp310-macosx_11_0_universal2.whl", hash = "sha256:40ad3afe81676fd9ec6d9d406eda00933f218038433980aa19d401490e46ecde", size = 11836482, upload-time = "2025-10-21T16:20:30.113Z" }, + { url = "https://files.pythonhosted.org/packages/3e/1c/8374990f9545e99462caacea5413ed783014b3b66ace49e35c533f07507b/grpcio-1.76.0-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:035d90bc79eaa4bed83f524331d55e35820725c9fbb00ffa1904d5550ed7ede3", size = 6407178, upload-time = "2025-10-21T16:20:32.733Z" }, + { url = "https://files.pythonhosted.org/packages/1e/77/36fd7d7c75a6c12542c90a6d647a27935a1ecaad03e0ffdb7c42db6b04d2/grpcio-1.76.0-cp310-cp310-manylinux2014_i686.manylinux_2_17_i686.whl", hash = "sha256:4215d3a102bd95e2e11b5395c78562967959824156af11fa93d18fdd18050990", size = 7075684, upload-time = "2025-10-21T16:20:35.435Z" }, + { url = "https://files.pythonhosted.org/packages/38/f7/e3cdb252492278e004722306c5a8935eae91e64ea11f0af3437a7de2e2b7/grpcio-1.76.0-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:49ce47231818806067aea3324d4bf13825b658ad662d3b25fada0bdad9b8a6af", size = 6611133, upload-time = "2025-10-21T16:20:37.541Z" }, + { url = "https://files.pythonhosted.org/packages/7e/20/340db7af162ccd20a0893b5f3c4a5d676af7b71105517e62279b5b61d95a/grpcio-1.76.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:8cc3309d8e08fd79089e13ed4819d0af72aa935dd8f435a195fd152796752ff2", size = 7195507, upload-time = "2025-10-21T16:20:39.643Z" }, + { url = "https://files.pythonhosted.org/packages/10/f0/b2160addc1487bd8fa4810857a27132fb4ce35c1b330c2f3ac45d697b106/grpcio-1.76.0-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:971fd5a1d6e62e00d945423a567e42eb1fa678ba89072832185ca836a94daaa6", size = 8160651, upload-time = "2025-10-21T16:20:42.492Z" }, + { url = "https://files.pythonhosted.org/packages/2c/2c/ac6f98aa113c6ef111b3f347854e99ebb7fb9d8f7bb3af1491d438f62af4/grpcio-1.76.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:9d9adda641db7207e800a7f089068f6f645959f2df27e870ee81d44701dd9db3", size = 7620568, upload-time = "2025-10-21T16:20:45.995Z" }, + { url = "https://files.pythonhosted.org/packages/90/84/7852f7e087285e3ac17a2703bc4129fafee52d77c6c82af97d905566857e/grpcio-1.76.0-cp310-cp310-win32.whl", hash = "sha256:063065249d9e7e0782d03d2bca50787f53bd0fb89a67de9a7b521c4a01f1989b", size = 3998879, upload-time = "2025-10-21T16:20:48.592Z" }, + { url = "https://files.pythonhosted.org/packages/10/30/d3d2adcbb6dd3ff59d6ac3df6ef830e02b437fb5c90990429fd180e52f30/grpcio-1.76.0-cp310-cp310-win_amd64.whl", hash = "sha256:a6ae758eb08088d36812dd5d9af7a9859c05b1e0f714470ea243694b49278e7b", size = 4706892, upload-time = "2025-10-21T16:20:50.697Z" }, + { url = "https://files.pythonhosted.org/packages/a0/00/8163a1beeb6971f66b4bbe6ac9457b97948beba8dd2fc8e1281dce7f79ec/grpcio-1.76.0-cp311-cp311-linux_armv7l.whl", hash = "sha256:2e1743fbd7f5fa713a1b0a8ac8ebabf0ec980b5d8809ec358d488e273b9cf02a", size = 5843567, upload-time = "2025-10-21T16:20:52.829Z" }, + { url = "https://files.pythonhosted.org/packages/10/c1/934202f5cf335e6d852530ce14ddb0fef21be612ba9ecbbcbd4d748ca32d/grpcio-1.76.0-cp311-cp311-macosx_11_0_universal2.whl", hash = "sha256:a8c2cf1209497cf659a667d7dea88985e834c24b7c3b605e6254cbb5076d985c", 
size = 11848017, upload-time = "2025-10-21T16:20:56.705Z" }, + { url = "https://files.pythonhosted.org/packages/11/0b/8dec16b1863d74af6eb3543928600ec2195af49ca58b16334972f6775663/grpcio-1.76.0-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:08caea849a9d3c71a542827d6df9d5a69067b0a1efbea8a855633ff5d9571465", size = 6412027, upload-time = "2025-10-21T16:20:59.3Z" }, + { url = "https://files.pythonhosted.org/packages/d7/64/7b9e6e7ab910bea9d46f2c090380bab274a0b91fb0a2fe9b0cd399fffa12/grpcio-1.76.0-cp311-cp311-manylinux2014_i686.manylinux_2_17_i686.whl", hash = "sha256:f0e34c2079d47ae9f6188211db9e777c619a21d4faba6977774e8fa43b085e48", size = 7075913, upload-time = "2025-10-21T16:21:01.645Z" }, + { url = "https://files.pythonhosted.org/packages/68/86/093c46e9546073cefa789bd76d44c5cb2abc824ca62af0c18be590ff13ba/grpcio-1.76.0-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:8843114c0cfce61b40ad48df65abcfc00d4dba82eae8718fab5352390848c5da", size = 6615417, upload-time = "2025-10-21T16:21:03.844Z" }, + { url = "https://files.pythonhosted.org/packages/f7/b6/5709a3a68500a9c03da6fb71740dcdd5ef245e39266461a03f31a57036d8/grpcio-1.76.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:8eddfb4d203a237da6f3cc8a540dad0517d274b5a1e9e636fd8d2c79b5c1d397", size = 7199683, upload-time = "2025-10-21T16:21:06.195Z" }, + { url = "https://files.pythonhosted.org/packages/91/d3/4b1f2bf16ed52ce0b508161df3a2d186e4935379a159a834cb4a7d687429/grpcio-1.76.0-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:32483fe2aab2c3794101c2a159070584e5db11d0aa091b2c0ea9c4fc43d0d749", size = 8163109, upload-time = "2025-10-21T16:21:08.498Z" }, + { url = "https://files.pythonhosted.org/packages/5c/61/d9043f95f5f4cf085ac5dd6137b469d41befb04bd80280952ffa2a4c3f12/grpcio-1.76.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:dcfe41187da8992c5f40aa8c5ec086fa3672834d2be57a32384c08d5a05b4c00", size = 7626676, upload-time = "2025-10-21T16:21:10.693Z" }, + { url = "https://files.pythonhosted.org/packages/36/95/fd9a5152ca02d8881e4dd419cdd790e11805979f499a2e5b96488b85cf27/grpcio-1.76.0-cp311-cp311-win32.whl", hash = "sha256:2107b0c024d1b35f4083f11245c0e23846ae64d02f40b2b226684840260ed054", size = 3997688, upload-time = "2025-10-21T16:21:12.746Z" }, + { url = "https://files.pythonhosted.org/packages/60/9c/5c359c8d4c9176cfa3c61ecd4efe5affe1f38d9bae81e81ac7186b4c9cc8/grpcio-1.76.0-cp311-cp311-win_amd64.whl", hash = "sha256:522175aba7af9113c48ec10cc471b9b9bd4f6ceb36aeb4544a8e2c80ed9d252d", size = 4709315, upload-time = "2025-10-21T16:21:15.26Z" }, + { url = "https://files.pythonhosted.org/packages/bf/05/8e29121994b8d959ffa0afd28996d452f291b48cfc0875619de0bde2c50c/grpcio-1.76.0-cp312-cp312-linux_armv7l.whl", hash = "sha256:81fd9652b37b36f16138611c7e884eb82e0cec137c40d3ef7c3f9b3ed00f6ed8", size = 5799718, upload-time = "2025-10-21T16:21:17.939Z" }, + { url = "https://files.pythonhosted.org/packages/d9/75/11d0e66b3cdf998c996489581bdad8900db79ebd83513e45c19548f1cba4/grpcio-1.76.0-cp312-cp312-macosx_11_0_universal2.whl", hash = "sha256:04bbe1bfe3a68bbfd4e52402ab7d4eb59d72d02647ae2042204326cf4bbad280", size = 11825627, upload-time = "2025-10-21T16:21:20.466Z" }, + { url = "https://files.pythonhosted.org/packages/28/50/2f0aa0498bc188048f5d9504dcc5c2c24f2eb1a9337cd0fa09a61a2e75f0/grpcio-1.76.0-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:d388087771c837cdb6515539f43b9d4bf0b0f23593a24054ac16f7a960be16f4", size = 6359167, upload-time = "2025-10-21T16:21:23.122Z" 
}, + { url = "https://files.pythonhosted.org/packages/66/e5/bbf0bb97d29ede1d59d6588af40018cfc345b17ce979b7b45424628dc8bb/grpcio-1.76.0-cp312-cp312-manylinux2014_i686.manylinux_2_17_i686.whl", hash = "sha256:9f8f757bebaaea112c00dba718fc0d3260052ce714e25804a03f93f5d1c6cc11", size = 7044267, upload-time = "2025-10-21T16:21:25.995Z" }, + { url = "https://files.pythonhosted.org/packages/f5/86/f6ec2164f743d9609691115ae8ece098c76b894ebe4f7c94a655c6b03e98/grpcio-1.76.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:980a846182ce88c4f2f7e2c22c56aefd515daeb36149d1c897f83cf57999e0b6", size = 6573963, upload-time = "2025-10-21T16:21:28.631Z" }, + { url = "https://files.pythonhosted.org/packages/60/bc/8d9d0d8505feccfdf38a766d262c71e73639c165b311c9457208b56d92ae/grpcio-1.76.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:f92f88e6c033db65a5ae3d97905c8fea9c725b63e28d5a75cb73b49bda5024d8", size = 7164484, upload-time = "2025-10-21T16:21:30.837Z" }, + { url = "https://files.pythonhosted.org/packages/67/e6/5d6c2fc10b95edf6df9b8f19cf10a34263b7fd48493936fffd5085521292/grpcio-1.76.0-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:4baf3cbe2f0be3289eb68ac8ae771156971848bb8aaff60bad42005539431980", size = 8127777, upload-time = "2025-10-21T16:21:33.577Z" }, + { url = "https://files.pythonhosted.org/packages/3f/c8/dce8ff21c86abe025efe304d9e31fdb0deaaa3b502b6a78141080f206da0/grpcio-1.76.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:615ba64c208aaceb5ec83bfdce7728b80bfeb8be97562944836a7a0a9647d882", size = 7594014, upload-time = "2025-10-21T16:21:41.882Z" }, + { url = "https://files.pythonhosted.org/packages/e0/42/ad28191ebf983a5d0ecef90bab66baa5a6b18f2bfdef9d0a63b1973d9f75/grpcio-1.76.0-cp312-cp312-win32.whl", hash = "sha256:45d59a649a82df5718fd9527ce775fd66d1af35e6d31abdcdc906a49c6822958", size = 3984750, upload-time = "2025-10-21T16:21:44.006Z" }, + { url = "https://files.pythonhosted.org/packages/9e/00/7bd478cbb851c04a48baccaa49b75abaa8e4122f7d86da797500cccdd771/grpcio-1.76.0-cp312-cp312-win_amd64.whl", hash = "sha256:c088e7a90b6017307f423efbb9d1ba97a22aa2170876223f9709e9d1de0b5347", size = 4704003, upload-time = "2025-10-21T16:21:46.244Z" }, + { url = "https://files.pythonhosted.org/packages/fc/ed/71467ab770effc9e8cef5f2e7388beb2be26ed642d567697bb103a790c72/grpcio-1.76.0-cp313-cp313-linux_armv7l.whl", hash = "sha256:26ef06c73eb53267c2b319f43e6634c7556ea37672029241a056629af27c10e2", size = 5807716, upload-time = "2025-10-21T16:21:48.475Z" }, + { url = "https://files.pythonhosted.org/packages/2c/85/c6ed56f9817fab03fa8a111ca91469941fb514e3e3ce6d793cb8f1e1347b/grpcio-1.76.0-cp313-cp313-macosx_11_0_universal2.whl", hash = "sha256:45e0111e73f43f735d70786557dc38141185072d7ff8dc1829d6a77ac1471468", size = 11821522, upload-time = "2025-10-21T16:21:51.142Z" }, + { url = "https://files.pythonhosted.org/packages/ac/31/2b8a235ab40c39cbc141ef647f8a6eb7b0028f023015a4842933bc0d6831/grpcio-1.76.0-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:83d57312a58dcfe2a3a0f9d1389b299438909a02db60e2f2ea2ae2d8034909d3", size = 6362558, upload-time = "2025-10-21T16:21:54.213Z" }, + { url = "https://files.pythonhosted.org/packages/bd/64/9784eab483358e08847498ee56faf8ff6ea8e0a4592568d9f68edc97e9e9/grpcio-1.76.0-cp313-cp313-manylinux2014_i686.manylinux_2_17_i686.whl", hash = "sha256:3e2a27c89eb9ac3d81ec8835e12414d73536c6e620355d65102503064a4ed6eb", size = 7049990, upload-time = "2025-10-21T16:21:56.476Z" }, + { url = 
"https://files.pythonhosted.org/packages/2b/94/8c12319a6369434e7a184b987e8e9f3b49a114c489b8315f029e24de4837/grpcio-1.76.0-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:61f69297cba3950a524f61c7c8ee12e55c486cb5f7db47ff9dcee33da6f0d3ae", size = 6575387, upload-time = "2025-10-21T16:21:59.051Z" }, + { url = "https://files.pythonhosted.org/packages/15/0f/f12c32b03f731f4a6242f771f63039df182c8b8e2cf8075b245b409259d4/grpcio-1.76.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:6a15c17af8839b6801d554263c546c69c4d7718ad4321e3166175b37eaacca77", size = 7166668, upload-time = "2025-10-21T16:22:02.049Z" }, + { url = "https://files.pythonhosted.org/packages/ff/2d/3ec9ce0c2b1d92dd59d1c3264aaec9f0f7c817d6e8ac683b97198a36ed5a/grpcio-1.76.0-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:25a18e9810fbc7e7f03ec2516addc116a957f8cbb8cbc95ccc80faa072743d03", size = 8124928, upload-time = "2025-10-21T16:22:04.984Z" }, + { url = "https://files.pythonhosted.org/packages/1a/74/fd3317be5672f4856bcdd1a9e7b5e17554692d3db9a3b273879dc02d657d/grpcio-1.76.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:931091142fd8cc14edccc0845a79248bc155425eee9a98b2db2ea4f00a235a42", size = 7589983, upload-time = "2025-10-21T16:22:07.881Z" }, + { url = "https://files.pythonhosted.org/packages/45/bb/ca038cf420f405971f19821c8c15bcbc875505f6ffadafe9ffd77871dc4c/grpcio-1.76.0-cp313-cp313-win32.whl", hash = "sha256:5e8571632780e08526f118f74170ad8d50fb0a48c23a746bef2a6ebade3abd6f", size = 3984727, upload-time = "2025-10-21T16:22:10.032Z" }, + { url = "https://files.pythonhosted.org/packages/41/80/84087dc56437ced7cdd4b13d7875e7439a52a261e3ab4e06488ba6173b0a/grpcio-1.76.0-cp313-cp313-win_amd64.whl", hash = "sha256:f9f7bd5faab55f47231ad8dba7787866b69f5e93bc306e3915606779bbfb4ba8", size = 4702799, upload-time = "2025-10-21T16:22:12.709Z" }, + { url = "https://files.pythonhosted.org/packages/b4/46/39adac80de49d678e6e073b70204091e76631e03e94928b9ea4ecf0f6e0e/grpcio-1.76.0-cp314-cp314-linux_armv7l.whl", hash = "sha256:ff8a59ea85a1f2191a0ffcc61298c571bc566332f82e5f5be1b83c9d8e668a62", size = 5808417, upload-time = "2025-10-21T16:22:15.02Z" }, + { url = "https://files.pythonhosted.org/packages/9c/f5/a4531f7fb8b4e2a60b94e39d5d924469b7a6988176b3422487be61fe2998/grpcio-1.76.0-cp314-cp314-macosx_11_0_universal2.whl", hash = "sha256:06c3d6b076e7b593905d04fdba6a0525711b3466f43b3400266f04ff735de0cd", size = 11828219, upload-time = "2025-10-21T16:22:17.954Z" }, + { url = "https://files.pythonhosted.org/packages/4b/1c/de55d868ed7a8bd6acc6b1d6ddc4aa36d07a9f31d33c912c804adb1b971b/grpcio-1.76.0-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:fd5ef5932f6475c436c4a55e4336ebbe47bd3272be04964a03d316bbf4afbcbc", size = 6367826, upload-time = "2025-10-21T16:22:20.721Z" }, + { url = "https://files.pythonhosted.org/packages/59/64/99e44c02b5adb0ad13ab3adc89cb33cb54bfa90c74770f2607eea629b86f/grpcio-1.76.0-cp314-cp314-manylinux2014_i686.manylinux_2_17_i686.whl", hash = "sha256:b331680e46239e090f5b3cead313cc772f6caa7d0fc8de349337563125361a4a", size = 7049550, upload-time = "2025-10-21T16:22:23.637Z" }, + { url = "https://files.pythonhosted.org/packages/43/28/40a5be3f9a86949b83e7d6a2ad6011d993cbe9b6bd27bea881f61c7788b6/grpcio-1.76.0-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:2229ae655ec4e8999599469559e97630185fdd53ae1e8997d147b7c9b2b72cba", size = 6575564, upload-time = "2025-10-21T16:22:26.016Z" }, + { url = 
"https://files.pythonhosted.org/packages/4b/a9/1be18e6055b64467440208a8559afac243c66a8b904213af6f392dc2212f/grpcio-1.76.0-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:490fa6d203992c47c7b9e4a9d39003a0c2bcc1c9aa3c058730884bbbb0ee9f09", size = 7176236, upload-time = "2025-10-21T16:22:28.362Z" }, + { url = "https://files.pythonhosted.org/packages/0f/55/dba05d3fcc151ce6e81327541d2cc8394f442f6b350fead67401661bf041/grpcio-1.76.0-cp314-cp314-musllinux_1_2_i686.whl", hash = "sha256:479496325ce554792dba6548fae3df31a72cef7bad71ca2e12b0e58f9b336bfc", size = 8125795, upload-time = "2025-10-21T16:22:31.075Z" }, + { url = "https://files.pythonhosted.org/packages/4a/45/122df922d05655f63930cf42c9e3f72ba20aadb26c100ee105cad4ce4257/grpcio-1.76.0-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:1c9b93f79f48b03ada57ea24725d83a30284a012ec27eab2cf7e50a550cbbbcc", size = 7592214, upload-time = "2025-10-21T16:22:33.831Z" }, + { url = "https://files.pythonhosted.org/packages/4a/6e/0b899b7f6b66e5af39e377055fb4a6675c9ee28431df5708139df2e93233/grpcio-1.76.0-cp314-cp314-win32.whl", hash = "sha256:747fa73efa9b8b1488a95d0ba1039c8e2dca0f741612d80415b1e1c560febf4e", size = 4062961, upload-time = "2025-10-21T16:22:36.468Z" }, + { url = "https://files.pythonhosted.org/packages/19/41/0b430b01a2eb38ee887f88c1f07644a1df8e289353b78e82b37ef988fb64/grpcio-1.76.0-cp314-cp314-win_amd64.whl", hash = "sha256:922fa70ba549fce362d2e2871ab542082d66e2aaf0c19480ea453905b01f384e", size = 4834462, upload-time = "2025-10-21T16:22:39.772Z" }, +] + [[package]] name = "grpclib" version = "0.4.7" @@ -2725,6 +2774,7 @@ version = "0.7.30" source = { registry = "https://pypi.org/simple" } sdist = { url = "https://files.pythonhosted.org/packages/bf/38/d1ef3ae08d8d857e5e0690c5b1e07bf7eb4a1cae5881d87215826dc6cadb/llguidance-0.7.30.tar.gz", hash = "sha256:e93bf75f2b6e48afb86a5cee23038746975e1654672bf5ba0ae75f7d4d4a2248", size = 1055528, upload-time = "2025-06-23T00:23:49.247Z" } wheels = [ + { url = "https://files.pythonhosted.org/packages/b3/e1/694c89986fcae7777184fc8b22baa0976eba15a6847221763f6ad211fc1f/llguidance-0.7.30-cp39-abi3-macosx_10_12_x86_64.whl", hash = "sha256:c80af02c118d2b0526bcecaab389af2ed094537a069b0fc724cd2a2f2ba3990f", size = 3327974, upload-time = "2025-06-23T00:23:47.556Z" }, { url = "https://files.pythonhosted.org/packages/fd/77/ab7a548ae189dc23900fdd37803c115c2339b1223af9e8eb1f4329b5935a/llguidance-0.7.30-cp39-abi3-macosx_11_0_arm64.whl", hash = "sha256:00a256d532911d2cf5ba4ef63e182944e767dd2402f38d63002016bc37755958", size = 3210709, upload-time = "2025-06-23T00:23:45.872Z" }, { url = "https://files.pythonhosted.org/packages/9c/5b/6a166564b14f9f805f0ea01ec233a84f55789cb7eeffe1d6224ccd0e6cdd/llguidance-0.7.30-cp39-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:af8741c867e4bc7e42f7cdc68350c076b4edd0ca10ecefbde75f15a9f6bc25d0", size = 14867038, upload-time = "2025-06-23T00:23:39.571Z" }, { url = "https://files.pythonhosted.org/packages/af/80/5a40b9689f17612434b820854cba9b8cabd5142072c491b5280fe5f7a35e/llguidance-0.7.30-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9edc409b9decd6cffba5f5bf3b4fbd7541f95daa8cbc9510cbf96c6ab1ffc153", size = 15004926, upload-time = "2025-06-23T00:23:43.965Z" }, @@ -2816,7 +2866,7 @@ wheels = [ [[package]] name = "logfire" -version = "4.0.0" +version = "4.14.2" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "executing" }, @@ -2828,9 +2878,9 @@ dependencies = [ { name = "tomli", marker = "python_full_version < 
'3.11'" }, { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/d8/85/4ee1ced49f2c378fd7df9f507d6426da3c3520957bfe56e6c049ccacd4e4/logfire-4.0.0.tar.gz", hash = "sha256:64d95fbf0f05c99a8b4c99a35b5b2971f11adbfbe9a73726df11d01c12f9959c", size = 512056, upload-time = "2025-07-22T15:12:05.951Z" } +sdist = { url = "https://files.pythonhosted.org/packages/5c/89/d26951b6b21790641720c12cfd6dca0cf7ead0f5ddd7de4299837b90b8b1/logfire-4.14.2.tar.gz", hash = "sha256:8dcedbd59c3d06a8794a93bbf09add788de3b74c45afa821750992f0c822c628", size = 548291, upload-time = "2025-10-24T20:14:39.115Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/b3/06/377ff0eb5d78ba893025eafed6104088eccefb0e538a9bed24e1f5d4fe53/logfire-4.0.0-py3-none-any.whl", hash = "sha256:4e50887d61954f849ec05343ca71b29fec5c0b6e4e945cabbceed664e37966e7", size = 211515, upload-time = "2025-07-22T15:12:02.113Z" }, + { url = "https://files.pythonhosted.org/packages/a7/92/4fba7b8f4f56f721ad279cb0c08164bffa14e93cfd184d1a4cc7151c52a2/logfire-4.14.2-py3-none-any.whl", hash = "sha256:caa8111b20f263f4ebb0ae380a62f2a214aeb07d5e2f03c9300fa096d0a8e692", size = 228364, upload-time = "2025-10-24T20:14:34.495Z" }, ] [package.optional-dependencies] @@ -4139,50 +4189,50 @@ wheels = [ [[package]] name = "opentelemetry-api" -version = "1.30.0" +version = "1.38.0" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "deprecated" }, { name = "importlib-metadata" }, + { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/2b/6d/bbbf879826b7f3c89a45252010b5796fb1f1a0d45d9dc4709db0ef9a06c8/opentelemetry_api-1.30.0.tar.gz", hash = "sha256:375893400c1435bf623f7dfb3bcd44825fe6b56c34d0667c542ea8257b1a1240", size = 63703, upload-time = "2025-02-04T18:17:13.789Z" } +sdist = { url = "https://files.pythonhosted.org/packages/08/d8/0f354c375628e048bd0570645b310797299754730079853095bf000fba69/opentelemetry_api-1.38.0.tar.gz", hash = "sha256:f4c193b5e8acb0912b06ac5b16321908dd0843d75049c091487322284a3eea12", size = 65242, upload-time = "2025-10-16T08:35:50.25Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/36/0a/eea862fae6413d8181b23acf8e13489c90a45f17986ee9cf4eab8a0b9ad9/opentelemetry_api-1.30.0-py3-none-any.whl", hash = "sha256:d5f5284890d73fdf47f843dda3210edf37a38d66f44f2b5aedc1e89ed455dc09", size = 64955, upload-time = "2025-02-04T18:16:46.167Z" }, + { url = "https://files.pythonhosted.org/packages/ae/a2/d86e01c28300bd41bab8f18afd613676e2bd63515417b77636fc1add426f/opentelemetry_api-1.38.0-py3-none-any.whl", hash = "sha256:2891b0197f47124454ab9f0cf58f3be33faca394457ac3e09daba13ff50aa582", size = 65947, upload-time = "2025-10-16T08:35:30.23Z" }, ] [[package]] name = "opentelemetry-exporter-otlp-proto-common" -version = "1.30.0" +version = "1.38.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "opentelemetry-proto" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/a2/d7/44098bf1ef89fc5810cdbda05faa2ae9322a0dbda4921cdc965dc68a9856/opentelemetry_exporter_otlp_proto_common-1.30.0.tar.gz", hash = "sha256:ddbfbf797e518411857d0ca062c957080279320d6235a279f7b64ced73c13897", size = 19640, upload-time = "2025-02-04T18:17:16.234Z" } +sdist = { url = "https://files.pythonhosted.org/packages/19/83/dd4660f2956ff88ed071e9e0e36e830df14b8c5dc06722dbde1841accbe8/opentelemetry_exporter_otlp_proto_common-1.38.0.tar.gz", hash = "sha256:e333278afab4695aa8114eeb7bf4e44e65c6607d54968271a249c180b2cb605c", size = 20431, 
upload-time = "2025-10-16T08:35:53.285Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/ee/54/f4b3de49f8d7d3a78fd6e6e1a6fd27dd342eb4d82c088b9078c6a32c3808/opentelemetry_exporter_otlp_proto_common-1.30.0-py3-none-any.whl", hash = "sha256:5468007c81aa9c44dc961ab2cf368a29d3475977df83b4e30aeed42aa7bc3b38", size = 18747, upload-time = "2025-02-04T18:16:51.512Z" }, + { url = "https://files.pythonhosted.org/packages/a7/9e/55a41c9601191e8cd8eb626b54ee6827b9c9d4a46d736f32abc80d8039fc/opentelemetry_exporter_otlp_proto_common-1.38.0-py3-none-any.whl", hash = "sha256:03cb76ab213300fe4f4c62b7d8f17d97fcfd21b89f0b5ce38ea156327ddda74a", size = 18359, upload-time = "2025-10-16T08:35:34.099Z" }, ] [[package]] name = "opentelemetry-exporter-otlp-proto-http" -version = "1.30.0" +version = "1.38.0" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "deprecated" }, { name = "googleapis-common-protos" }, { name = "opentelemetry-api" }, { name = "opentelemetry-exporter-otlp-proto-common" }, { name = "opentelemetry-proto" }, { name = "opentelemetry-sdk" }, { name = "requests" }, + { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/04/f9/abb9191d536e6a2e2b7903f8053bf859a76bf784e3ca19a5749550ef19e4/opentelemetry_exporter_otlp_proto_http-1.30.0.tar.gz", hash = "sha256:c3ae75d4181b1e34a60662a6814d0b94dd33b628bee5588a878bed92cee6abdc", size = 15073, upload-time = "2025-02-04T18:17:18.446Z" } +sdist = { url = "https://files.pythonhosted.org/packages/81/0a/debcdfb029fbd1ccd1563f7c287b89a6f7bef3b2902ade56797bfd020854/opentelemetry_exporter_otlp_proto_http-1.38.0.tar.gz", hash = "sha256:f16bd44baf15cbe07633c5112ffc68229d0edbeac7b37610be0b2def4e21e90b", size = 17282, upload-time = "2025-10-16T08:35:54.422Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/e1/3c/cdf34bc459613f2275aff9b258f35acdc4c4938dad161d17437de5d4c034/opentelemetry_exporter_otlp_proto_http-1.30.0-py3-none-any.whl", hash = "sha256:9578e790e579931c5ffd50f1e6975cbdefb6a0a0a5dea127a6ae87df10e0a589", size = 17245, upload-time = "2025-02-04T18:16:53.514Z" }, + { url = "https://files.pythonhosted.org/packages/e5/77/154004c99fb9f291f74aa0822a2f5bbf565a72d8126b3a1b63ed8e5f83c7/opentelemetry_exporter_otlp_proto_http-1.38.0-py3-none-any.whl", hash = "sha256:84b937305edfc563f08ec69b9cb2298be8188371217e867c1854d77198d0825b", size = 19579, upload-time = "2025-10-16T08:35:36.269Z" }, ] [[package]] name = "opentelemetry-instrumentation" -version = "0.51b0" +version = "0.59b0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "opentelemetry-api" }, @@ -4190,14 +4240,14 @@ dependencies = [ { name = "packaging" }, { name = "wrapt" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/ec/5a/4c7f02235ac1269b48f3855f6be1afc641f31d4888d28b90b732fbce7141/opentelemetry_instrumentation-0.51b0.tar.gz", hash = "sha256:4ca266875e02f3988536982467f7ef8c32a38b8895490ddce9ad9604649424fa", size = 27760, upload-time = "2025-02-04T18:21:09.279Z" } +sdist = { url = "https://files.pythonhosted.org/packages/04/ed/9c65cd209407fd807fa05be03ee30f159bdac8d59e7ea16a8fe5a1601222/opentelemetry_instrumentation-0.59b0.tar.gz", hash = "sha256:6010f0faaacdaf7c4dff8aac84e226d23437b331dcda7e70367f6d73a7db1adc", size = 31544, upload-time = "2025-10-16T08:39:31.959Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/40/2c/48fa93f1acca9f79a06da0df7bfe916632ecc7fce1971067b3e46bcae55b/opentelemetry_instrumentation-0.51b0-py3-none-any.whl", hash = 
"sha256:c6de8bd26b75ec8b0e54dff59e198946e29de6a10ec65488c357d4b34aa5bdcf", size = 30923, upload-time = "2025-02-04T18:19:37.829Z" }, + { url = "https://files.pythonhosted.org/packages/10/f5/7a40ff3f62bfe715dad2f633d7f1174ba1a7dd74254c15b2558b3401262a/opentelemetry_instrumentation-0.59b0-py3-none-any.whl", hash = "sha256:44082cc8fe56b0186e87ee8f7c17c327c4c2ce93bdbe86496e600985d74368ee", size = 33020, upload-time = "2025-10-16T08:38:31.463Z" }, ] [[package]] name = "opentelemetry-instrumentation-asgi" -version = "0.51b0" +version = "0.59b0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "asgiref" }, @@ -4206,28 +4256,28 @@ dependencies = [ { name = "opentelemetry-semantic-conventions" }, { name = "opentelemetry-util-http" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/9e/67/8aa6e1129f641f0f3f8786e6c5d18c1f2bbe490bd4b0e91a6879e85154d2/opentelemetry_instrumentation_asgi-0.51b0.tar.gz", hash = "sha256:b3fe97c00f0bfa934371a69674981d76591c68d937b6422a5716ca21081b4148", size = 24201, upload-time = "2025-02-04T18:21:14.321Z" } +sdist = { url = "https://files.pythonhosted.org/packages/b7/a4/cfbb6fc1ec0aa9bf5a93f548e6a11ab3ac1956272f17e0d399aa2c1f85bc/opentelemetry_instrumentation_asgi-0.59b0.tar.gz", hash = "sha256:2509d6fe9fd829399ce3536e3a00426c7e3aa359fc1ed9ceee1628b56da40e7a", size = 25116, upload-time = "2025-10-16T08:39:36.092Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/54/7e/0a95ab37302729543631a789ba8e71dea75c520495739dbbbdfdc580b401/opentelemetry_instrumentation_asgi-0.51b0-py3-none-any.whl", hash = "sha256:e8072993db47303b633c6ec1bc74726ba4d32bd0c46c28dfadf99f79521a324c", size = 16340, upload-time = "2025-02-04T18:19:49.924Z" }, + { url = "https://files.pythonhosted.org/packages/f3/88/fe02d809963b182aafbf5588685d7a05af8861379b0ec203d48e360d4502/opentelemetry_instrumentation_asgi-0.59b0-py3-none-any.whl", hash = "sha256:ba9703e09d2c33c52fa798171f344c8123488fcd45017887981df088452d3c53", size = 16797, upload-time = "2025-10-16T08:38:37.214Z" }, ] [[package]] name = "opentelemetry-instrumentation-asyncpg" -version = "0.51b0" +version = "0.59b0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "opentelemetry-api" }, { name = "opentelemetry-instrumentation" }, { name = "opentelemetry-semantic-conventions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/1f/fe/95eb7747a37d980787440db8001ab991f54ba4f47ea8635b43644eb2df5f/opentelemetry_instrumentation_asyncpg-0.51b0.tar.gz", hash = "sha256:366fb7f7e2c3a66de28b3770e7e795fd2612eace346dd842b77bbe61a97b7ff1", size = 8656, upload-time = "2025-02-04T18:21:16.107Z" } +sdist = { url = "https://files.pythonhosted.org/packages/64/2b/9bad3483380513b1c4c232dffbc8e54d1f38bad275f86462883b355f0d8e/opentelemetry_instrumentation_asyncpg-0.59b0.tar.gz", hash = "sha256:fada2fa14c8ee77b25c1f4ed37aa21a581449b456a78d814b54c6e5b051d3618", size = 8725, upload-time = "2025-10-16T08:39:38Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/02/35/ec8638338a1b4623172f86fa7c01a58f30fd5f39c053bbb3fabc9514d7fd/opentelemetry_instrumentation_asyncpg-0.51b0-py3-none-any.whl", hash = "sha256:6180c57c497cee1c787aeb5b090f92b1bb9ee90cb606932adfaf6bf3fdb494a5", size = 9992, upload-time = "2025-02-04T18:19:53.239Z" }, + { url = "https://files.pythonhosted.org/packages/c2/4e/27430be77f066b8c457e2e85d68009a7ff28d635298bce2486b7429da3dd/opentelemetry_instrumentation_asyncpg-0.59b0-py3-none-any.whl", hash = 
"sha256:538af20d9423bd05f2bdf4c1cab063539cb4db0835340c0b7f45836725e31cb0", size = 10087, upload-time = "2025-10-16T08:38:39.727Z" }, ] [[package]] name = "opentelemetry-instrumentation-dbapi" -version = "0.51b0" +version = "0.59b0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "opentelemetry-api" }, @@ -4235,14 +4285,14 @@ dependencies = [ { name = "opentelemetry-semantic-conventions" }, { name = "wrapt" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/84/b7/fdc107617b9f626632f5fbe444a6a91efa4a9d1e38447500802b8a12010c/opentelemetry_instrumentation_dbapi-0.51b0.tar.gz", hash = "sha256:740b5e17eef02a91a8d3966f06e5605817a7d875ae4d9dec8318ef652ccfc1fe", size = 13860, upload-time = "2025-02-04T18:21:23.948Z" } +sdist = { url = "https://files.pythonhosted.org/packages/60/aa/36a09652c98c65b42408d40f222fba031a3a281f1b6682e1b141b20b508d/opentelemetry_instrumentation_dbapi-0.59b0.tar.gz", hash = "sha256:c50112ae1cdb7f55bddcf57eca96aaa0f2dd78732be2b00953183439a4740493", size = 16308, upload-time = "2025-10-16T08:39:43.192Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/34/13/d3cd0292680ebd54ed6d55d7a81434bc2c6f7327d971c6690c98114d6abc/opentelemetry_instrumentation_dbapi-0.51b0-py3-none-any.whl", hash = "sha256:1b4dfb4f25b4ef509b70fb24c637436a40fe5fc8204933b956f1d0ccaa61735f", size = 12373, upload-time = "2025-02-04T18:20:09.771Z" }, + { url = "https://files.pythonhosted.org/packages/e5/9b/1739b5b7926cbae342880d7a56d59a847313e6568a96ba7d4873ce0c0996/opentelemetry_instrumentation_dbapi-0.59b0-py3-none-any.whl", hash = "sha256:672d59caa06754b42d4e722644d9fcd00a1f9f862e9ea5cef6d4da454515ac67", size = 13970, upload-time = "2025-10-16T08:38:48.342Z" }, ] [[package]] name = "opentelemetry-instrumentation-fastapi" -version = "0.51b0" +version = "0.59b0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "opentelemetry-api" }, @@ -4251,14 +4301,14 @@ dependencies = [ { name = "opentelemetry-semantic-conventions" }, { name = "opentelemetry-util-http" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/2d/dc/8db4422b5084177d1ef6c7855c69bf2e9e689f595a4a9b59e60588e0d427/opentelemetry_instrumentation_fastapi-0.51b0.tar.gz", hash = "sha256:1624e70f2f4d12ceb792d8a0c331244cd6723190ccee01336273b4559bc13abc", size = 19249, upload-time = "2025-02-04T18:21:28.379Z" } +sdist = { url = "https://files.pythonhosted.org/packages/ab/a7/7a6ce5009584ce97dbfd5ce77d4f9d9570147507363349d2cb705c402bcf/opentelemetry_instrumentation_fastapi-0.59b0.tar.gz", hash = "sha256:e8fe620cfcca96a7d634003df1bc36a42369dedcdd6893e13fb5903aeeb89b2b", size = 24967, upload-time = "2025-10-16T08:39:46.056Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/55/1c/ec2d816b78edf2404d7b3df6d09eefb690b70bfd191b7da06f76634f1bdc/opentelemetry_instrumentation_fastapi-0.51b0-py3-none-any.whl", hash = "sha256:10513bbc11a1188adb9c1d2c520695f7a8f2b5f4de14e8162098035901cd6493", size = 12117, upload-time = "2025-02-04T18:20:15.267Z" }, + { url = "https://files.pythonhosted.org/packages/35/27/5914c8bf140ffc70eff153077e225997c7b054f0bf28e11b9ab91b63b18f/opentelemetry_instrumentation_fastapi-0.59b0-py3-none-any.whl", hash = "sha256:0d8d00ff7d25cca40a4b2356d1d40a8f001e0668f60c102f5aa6bb721d660c4f", size = 13492, upload-time = "2025-10-16T08:38:52.312Z" }, ] [[package]] name = "opentelemetry-instrumentation-httpx" -version = "0.51b0" +version = "0.59b0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "opentelemetry-api" }, @@ 
-4267,71 +4317,71 @@ dependencies = [ { name = "opentelemetry-util-http" }, { name = "wrapt" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/b7/d5/4a3990c461ae7e55212115e0f8f3aa412b5ce6493579e85c292245ac69ea/opentelemetry_instrumentation_httpx-0.51b0.tar.gz", hash = "sha256:061d426a04bf5215a859fea46662e5074f920e5cbde7e6ad6825a0a1b595802c", size = 17700, upload-time = "2025-02-04T18:21:31.685Z" } +sdist = { url = "https://files.pythonhosted.org/packages/18/6b/1bdf36b68cace9b4eae3cbbade4150c71c90aa392b127dda5bb5c2a49307/opentelemetry_instrumentation_httpx-0.59b0.tar.gz", hash = "sha256:a1cb9b89d9f05a82701cc9ab9cfa3db54fd76932489449778b350bc1b9f0e872", size = 19886, upload-time = "2025-10-16T08:39:48.428Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/c3/ba/23d4ab6402408c01f1c3f32e0c04ea6dae575bf19bcb9a0049c9e768c983/opentelemetry_instrumentation_httpx-0.51b0-py3-none-any.whl", hash = "sha256:2e3fdf755ba6ead6ab43031497c3d55d4c796d0368eccc0ce48d304b7ec6486a", size = 14109, upload-time = "2025-02-04T18:20:19.947Z" }, + { url = "https://files.pythonhosted.org/packages/58/16/c1e0745d20af392ec9060693531d7f01239deb2d81e460d0c379719691b8/opentelemetry_instrumentation_httpx-0.59b0-py3-none-any.whl", hash = "sha256:7dc9f66aef4ca3904d877f459a70c78eafd06131dc64d713b9b1b5a7d0a48f05", size = 15197, upload-time = "2025-10-16T08:38:55.507Z" }, ] [[package]] name = "opentelemetry-instrumentation-sqlite3" -version = "0.51b0" +version = "0.59b0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "opentelemetry-api" }, { name = "opentelemetry-instrumentation" }, { name = "opentelemetry-instrumentation-dbapi" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/e8/2a/1755f34fd1d58858272970ce9f8386a488ce2aa16c2673373ed31cc60d33/opentelemetry_instrumentation_sqlite3-0.51b0.tar.gz", hash = "sha256:3bd5dbe2292a68b27b79c44a13a03b1443341404e02351d3886ee6526792ead1", size = 7930, upload-time = "2025-02-04T18:21:47.709Z" } +sdist = { url = "https://files.pythonhosted.org/packages/5d/c9/316d9800fbb64ac2b5474d17d13f96a37df86e5c06e348a7d143b3eb377f/opentelemetry_instrumentation_sqlite3-0.59b0.tar.gz", hash = "sha256:7b9989d805336a1e78a907b3863376cf4ff1dc96dd8a9e0d385f6bb3686c27ac", size = 7923, upload-time = "2025-10-16T08:40:01.625Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/d8/d0/6288eb2b6065b7766eee545729e6e68ac241ce82ec60a8452742414536c7/opentelemetry_instrumentation_sqlite3-0.51b0-py3-none-any.whl", hash = "sha256:77418bfec1b45f4d44a9a316c355aab33d36eb7cc1cd5d871f40acae36ae5c96", size = 9339, upload-time = "2025-02-04T18:20:51.607Z" }, + { url = "https://files.pythonhosted.org/packages/e5/ef/daf9075b22f59f45c8839dcde8d1c4fd3061b6a6692a61150fad6ca7a1a5/opentelemetry_instrumentation_sqlite3-0.59b0-py3-none-any.whl", hash = "sha256:ec13867102687426b835f6c499a287ee2f4195abfba85d372e011a795661914c", size = 9338, upload-time = "2025-10-16T08:39:11.545Z" }, ] [[package]] name = "opentelemetry-proto" -version = "1.30.0" +version = "1.38.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "protobuf" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/31/6e/c1ff2e3b0cd3a189a6be03fd4d63441d73d7addd9117ab5454e667b9b6c7/opentelemetry_proto-1.30.0.tar.gz", hash = "sha256:afe5c9c15e8b68d7c469596e5b32e8fc085eb9febdd6fb4e20924a93a0389179", size = 34362, upload-time = "2025-02-04T18:17:28.099Z" } +sdist = { url = 
"https://files.pythonhosted.org/packages/51/14/f0c4f0f6371b9cb7f9fa9ee8918bfd59ac7040c7791f1e6da32a1839780d/opentelemetry_proto-1.38.0.tar.gz", hash = "sha256:88b161e89d9d372ce723da289b7da74c3a8354a8e5359992be813942969ed468", size = 46152, upload-time = "2025-10-16T08:36:01.612Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/56/d7/85de6501f7216995295f7ec11e470142e6a6e080baacec1753bbf272e007/opentelemetry_proto-1.30.0-py3-none-any.whl", hash = "sha256:c6290958ff3ddacc826ca5abbeb377a31c2334387352a259ba0df37c243adc11", size = 55854, upload-time = "2025-02-04T18:17:08.024Z" }, + { url = "https://files.pythonhosted.org/packages/b6/6a/82b68b14efca5150b2632f3692d627afa76b77378c4999f2648979409528/opentelemetry_proto-1.38.0-py3-none-any.whl", hash = "sha256:b6ebe54d3217c42e45462e2a1ae28c3e2bf2ec5a5645236a490f55f45f1a0a18", size = 72535, upload-time = "2025-10-16T08:35:45.749Z" }, ] [[package]] name = "opentelemetry-sdk" -version = "1.30.0" +version = "1.38.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "opentelemetry-api" }, { name = "opentelemetry-semantic-conventions" }, { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/93/ee/d710062e8a862433d1be0b85920d0c653abe318878fef2d14dfe2c62ff7b/opentelemetry_sdk-1.30.0.tar.gz", hash = "sha256:c9287a9e4a7614b9946e933a67168450b9ab35f08797eb9bc77d998fa480fa18", size = 158633, upload-time = "2025-02-04T18:17:28.908Z" } +sdist = { url = "https://files.pythonhosted.org/packages/85/cb/f0eee1445161faf4c9af3ba7b848cc22a50a3d3e2515051ad8628c35ff80/opentelemetry_sdk-1.38.0.tar.gz", hash = "sha256:93df5d4d871ed09cb4272305be4d996236eedb232253e3ab864c8620f051cebe", size = 171942, upload-time = "2025-10-16T08:36:02.257Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/97/28/64d781d6adc6bda2260067ce2902bd030cf45aec657e02e28c5b4480b976/opentelemetry_sdk-1.30.0-py3-none-any.whl", hash = "sha256:14fe7afc090caad881addb6926cec967129bd9260c4d33ae6a217359f6b61091", size = 118717, upload-time = "2025-02-04T18:17:09.353Z" }, + { url = "https://files.pythonhosted.org/packages/2f/2e/e93777a95d7d9c40d270a371392b6d6f1ff170c2a3cb32d6176741b5b723/opentelemetry_sdk-1.38.0-py3-none-any.whl", hash = "sha256:1c66af6564ecc1553d72d811a01df063ff097cdc82ce188da9951f93b8d10f6b", size = 132349, upload-time = "2025-10-16T08:35:46.995Z" }, ] [[package]] name = "opentelemetry-semantic-conventions" -version = "0.51b0" +version = "0.59b0" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "deprecated" }, { name = "opentelemetry-api" }, + { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/1e/c0/0f9ef4605fea7f2b83d55dd0b0d7aebe8feead247cd6facd232b30907b4f/opentelemetry_semantic_conventions-0.51b0.tar.gz", hash = "sha256:3fabf47f35d1fd9aebcdca7e6802d86bd5ebc3bc3408b7e3248dde6e87a18c47", size = 107191, upload-time = "2025-02-04T18:17:29.903Z" } +sdist = { url = "https://files.pythonhosted.org/packages/40/bc/8b9ad3802cd8ac6583a4eb7de7e5d7db004e89cb7efe7008f9c8a537ee75/opentelemetry_semantic_conventions-0.59b0.tar.gz", hash = "sha256:7a6db3f30d70202d5bf9fa4b69bc866ca6a30437287de6c510fb594878aed6b0", size = 129861, upload-time = "2025-10-16T08:36:03.346Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/2e/75/d7bdbb6fd8630b4cafb883482b75c4fc276b6426619539d266e32ac53266/opentelemetry_semantic_conventions-0.51b0-py3-none-any.whl", hash = "sha256:fdc777359418e8d06c86012c3dc92c88a6453ba662e941593adb062e48c2eeae", size = 
177416, upload-time = "2025-02-04T18:17:11.305Z" }, + { url = "https://files.pythonhosted.org/packages/24/7d/c88d7b15ba8fe5c6b8f93be50fc11795e9fc05386c44afaf6b76fe191f9b/opentelemetry_semantic_conventions-0.59b0-py3-none-any.whl", hash = "sha256:35d3b8833ef97d614136e253c1da9342b4c3c083bbaf29ce31d572a1c3825eed", size = 207954, upload-time = "2025-10-16T08:35:48.054Z" }, ] [[package]] name = "opentelemetry-util-http" -version = "0.51b0" +version = "0.59b0" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/58/64/32510c0a803465eb6ef1f5bd514d0f5627f8abc9444ed94f7240faf6fcaa/opentelemetry_util_http-0.51b0.tar.gz", hash = "sha256:05edd19ca1cc3be3968b1e502fd94816901a365adbeaab6b6ddb974384d3a0b9", size = 8043, upload-time = "2025-02-04T18:21:59.811Z" } +sdist = { url = "https://files.pythonhosted.org/packages/34/f7/13cd081e7851c42520ab0e96efb17ffbd901111a50b8252ec1e240664020/opentelemetry_util_http-0.59b0.tar.gz", hash = "sha256:ae66ee91be31938d832f3b4bc4eb8a911f6eddd38969c4a871b1230db2a0a560", size = 9412, upload-time = "2025-10-16T08:40:11.335Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/48/dd/c371eeb9cc78abbdad231a27ce1a196a37ef96328d876ccbb381dea4c8ee/opentelemetry_util_http-0.51b0-py3-none-any.whl", hash = "sha256:0561d7a6e9c422b9ef9ae6e77eafcfcd32a2ab689f5e801475cbb67f189efa20", size = 7304, upload-time = "2025-02-04T18:21:05.483Z" }, + { url = "https://files.pythonhosted.org/packages/20/56/62282d1d4482061360449dacc990c89cad0fc810a2ed937b636300f55023/opentelemetry_util_http-0.59b0-py3-none-any.whl", hash = "sha256:6d036a07563bce87bf521839c0671b507a02a0d39d7ea61b88efa14c6e25355d", size = 7648, upload-time = "2025-10-16T08:39:25.706Z" }, ] [[package]] @@ -4481,11 +4531,11 @@ wheels = [ [[package]] name = "packaging" -version = "24.2" +version = "25.0" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/d0/63/68dbb6eb2de9cb10ee4c9c14a0148804425e13c4fb20d61cce69f53106da/packaging-24.2.tar.gz", hash = "sha256:c228a6dc5e932d346bc5739379109d49e8853dd8223571c7c5b55260edc0b97f", size = 163950, upload-time = "2024-11-08T09:47:47.202Z" } +sdist = { url = "https://files.pythonhosted.org/packages/a1/d4/1fc4078c65507b51b96ca8f8c3ba19e6a61c8253c72794544580a7b6c24d/packaging-25.0.tar.gz", hash = "sha256:d443872c98d677bf60f6a1f2f8c1cb748e8fe762d2bf9d3148b5599295b0fc4f", size = 165727, upload-time = "2025-04-19T11:48:59.673Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/88/ef/eb23f262cca3c0c4eb7ab1933c3b1f03d021f2c48f54763065b6f0e321be/packaging-24.2-py3-none-any.whl", hash = "sha256:09abb1bccd265c01f4a3aa3f7a7db064b36514d2cba19a2f694fe6150451a759", size = 65451, upload-time = "2024-11-08T09:47:44.722Z" }, + { url = "https://files.pythonhosted.org/packages/20/12/38679034af332785aac8774540895e234f4d07f7545804097de4b666afd8/packaging-25.0-py3-none-any.whl", hash = "sha256:29572ef2b1f17581046b3a2227d5c611fb25ec70ca1ba8554b24b0e69331a484", size = 66469, upload-time = "2025-04-19T11:48:57.875Z" }, ] [[package]] @@ -4942,16 +4992,17 @@ wheels = [ [[package]] name = "protobuf" -version = "5.29.3" +version = "6.33.0" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/f7/d1/e0a911544ca9993e0f17ce6d3cc0932752356c1b0a834397f28e63479344/protobuf-5.29.3.tar.gz", hash = "sha256:5da0f41edaf117bde316404bad1a486cb4ededf8e4a54891296f648e8e076620", size = 424945, upload-time = "2025-01-08T21:38:51.572Z" } +sdist = { url = 
"https://files.pythonhosted.org/packages/19/ff/64a6c8f420818bb873713988ca5492cba3a7946be57e027ac63495157d97/protobuf-6.33.0.tar.gz", hash = "sha256:140303d5c8d2037730c548f8c7b93b20bb1dc301be280c378b82b8894589c954", size = 443463, upload-time = "2025-10-15T20:39:52.159Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/dc/7a/1e38f3cafa022f477ca0f57a1f49962f21ad25850c3ca0acd3b9d0091518/protobuf-5.29.3-cp310-abi3-win32.whl", hash = "sha256:3ea51771449e1035f26069c4c7fd51fba990d07bc55ba80701c78f886bf9c888", size = 422708, upload-time = "2025-01-08T21:38:31.799Z" }, - { url = "https://files.pythonhosted.org/packages/61/fa/aae8e10512b83de633f2646506a6d835b151edf4b30d18d73afd01447253/protobuf-5.29.3-cp310-abi3-win_amd64.whl", hash = "sha256:a4fa6f80816a9a0678429e84973f2f98cbc218cca434abe8db2ad0bffc98503a", size = 434508, upload-time = "2025-01-08T21:38:35.489Z" }, - { url = "https://files.pythonhosted.org/packages/dd/04/3eaedc2ba17a088961d0e3bd396eac764450f431621b58a04ce898acd126/protobuf-5.29.3-cp38-abi3-macosx_10_9_universal2.whl", hash = "sha256:a8434404bbf139aa9e1300dbf989667a83d42ddda9153d8ab76e0d5dcaca484e", size = 417825, upload-time = "2025-01-08T21:38:36.642Z" }, - { url = "https://files.pythonhosted.org/packages/4f/06/7c467744d23c3979ce250397e26d8ad8eeb2bea7b18ca12ad58313c1b8d5/protobuf-5.29.3-cp38-abi3-manylinux2014_aarch64.whl", hash = "sha256:daaf63f70f25e8689c072cfad4334ca0ac1d1e05a92fc15c54eb9cf23c3efd84", size = 319573, upload-time = "2025-01-08T21:38:37.896Z" }, - { url = "https://files.pythonhosted.org/packages/a8/45/2ebbde52ad2be18d3675b6bee50e68cd73c9e0654de77d595540b5129df8/protobuf-5.29.3-cp38-abi3-manylinux2014_x86_64.whl", hash = "sha256:c027e08a08be10b67c06bf2370b99c811c466398c357e615ca88c91c07f0910f", size = 319672, upload-time = "2025-01-08T21:38:40.204Z" }, - { url = "https://files.pythonhosted.org/packages/fd/b2/ab07b09e0f6d143dfb839693aa05765257bceaa13d03bf1a696b78323e7a/protobuf-5.29.3-py3-none-any.whl", hash = "sha256:0a18ed4a24198528f2333802eb075e59dea9d679ab7a6c5efb017a59004d849f", size = 172550, upload-time = "2025-01-08T21:38:50.439Z" }, + { url = "https://files.pythonhosted.org/packages/7e/ee/52b3fa8feb6db4a833dfea4943e175ce645144532e8a90f72571ad85df4e/protobuf-6.33.0-cp310-abi3-win32.whl", hash = "sha256:d6101ded078042a8f17959eccd9236fb7a9ca20d3b0098bbcb91533a5680d035", size = 425593, upload-time = "2025-10-15T20:39:40.29Z" }, + { url = "https://files.pythonhosted.org/packages/7b/c6/7a465f1825872c55e0341ff4a80198743f73b69ce5d43ab18043699d1d81/protobuf-6.33.0-cp310-abi3-win_amd64.whl", hash = "sha256:9a031d10f703f03768f2743a1c403af050b6ae1f3480e9c140f39c45f81b13ee", size = 436882, upload-time = "2025-10-15T20:39:42.841Z" }, + { url = "https://files.pythonhosted.org/packages/e1/a9/b6eee662a6951b9c3640e8e452ab3e09f117d99fc10baa32d1581a0d4099/protobuf-6.33.0-cp39-abi3-macosx_10_9_universal2.whl", hash = "sha256:905b07a65f1a4b72412314082c7dbfae91a9e8b68a0cc1577515f8df58ecf455", size = 427521, upload-time = "2025-10-15T20:39:43.803Z" }, + { url = "https://files.pythonhosted.org/packages/10/35/16d31e0f92c6d2f0e77c2a3ba93185130ea13053dd16200a57434c882f2b/protobuf-6.33.0-cp39-abi3-manylinux2014_aarch64.whl", hash = "sha256:e0697ece353e6239b90ee43a9231318302ad8353c70e6e45499fa52396debf90", size = 324445, upload-time = "2025-10-15T20:39:44.932Z" }, + { url = "https://files.pythonhosted.org/packages/e6/eb/2a981a13e35cda8b75b5585aaffae2eb904f8f351bdd3870769692acbd8a/protobuf-6.33.0-cp39-abi3-manylinux2014_s390x.whl", hash = 
"sha256:e0a1715e4f27355afd9570f3ea369735afc853a6c3951a6afe1f80d8569ad298", size = 339159, upload-time = "2025-10-15T20:39:46.186Z" }, + { url = "https://files.pythonhosted.org/packages/21/51/0b1cbad62074439b867b4e04cc09b93f6699d78fd191bed2bbb44562e077/protobuf-6.33.0-cp39-abi3-manylinux2014_x86_64.whl", hash = "sha256:35be49fd3f4fefa4e6e2aacc35e8b837d6703c37a2168a55ac21e9b1bc7559ef", size = 323172, upload-time = "2025-10-15T20:39:47.465Z" }, + { url = "https://files.pythonhosted.org/packages/07/d1/0a28c21707807c6aacd5dc9c3704b2aa1effbf37adebd8caeaf68b17a636/protobuf-6.33.0-py3-none-any.whl", hash = "sha256:25c9e1963c6734448ea2d308cfa610e692b801304ba0908d7bfa564ac5132995", size = 170477, upload-time = "2025-10-15T20:39:51.311Z" }, ] [[package]] @@ -5330,7 +5381,7 @@ email = [ name = "pydantic-ai" source = { editable = "." } dependencies = [ - { name = "pydantic-ai-slim", extra = ["ag-ui", "anthropic", "bedrock", "cli", "cohere", "evals", "fastmcp", "google", "groq", "huggingface", "logfire", "mcp", "mistral", "openai", "retries", "temporal", "ui", "vertexai"] }, + { name = "pydantic-ai-slim", extra = ["ag-ui", "anthropic", "bedrock", "cli", "cohere", "evals", "fastmcp", "google", "grok", "groq", "huggingface", "logfire", "mcp", "mistral", "openai", "retries", "temporal", "ui", "vertexai"] }, ] [package.optional-dependencies] @@ -5409,7 +5460,7 @@ lint = [ requires-dist = [ { name = "fasta2a", marker = "extra == 'a2a'", specifier = ">=0.4.1" }, { name = "pydantic-ai-examples", marker = "extra == 'examples'", editable = "examples" }, - { name = "pydantic-ai-slim", extras = ["ag-ui", "anthropic", "bedrock", "cli", "cohere", "evals", "fastmcp", "google", "groq", "huggingface", "logfire", "mcp", "mistral", "openai", "retries", "temporal", "ui", "vertexai"], editable = "pydantic_ai_slim" }, + { name = "pydantic-ai-slim", extras = ["ag-ui", "anthropic", "bedrock", "cli", "cohere", "evals", "fastmcp", "google", "grok", "groq", "huggingface", "logfire", "mcp", "mistral", "openai", "retries", "temporal", "ui", "vertexai"], editable = "pydantic_ai_slim" }, { name = "pydantic-ai-slim", extras = ["dbos"], marker = "extra == 'dbos'", editable = "pydantic_ai_slim" }, { name = "pydantic-ai-slim", extras = ["outlines-llamacpp"], marker = "extra == 'outlines-llamacpp'", editable = "pydantic_ai_slim" }, { name = "pydantic-ai-slim", extras = ["outlines-mlxlm"], marker = "extra == 'outlines-mlxlm'", editable = "pydantic_ai_slim" }, @@ -5555,6 +5606,9 @@ fastmcp = [ google = [ { name = "google-genai" }, ] +grok = [ + { name = "xai-sdk" }, +] groq = [ { name = "groq" }, ] @@ -5667,8 +5721,9 @@ requires-dist = [ { name = "transformers", marker = "extra == 'outlines-transformers'", specifier = ">=4.0.0" }, { name = "typing-inspection", specifier = ">=0.4.0" }, { name = "vllm", marker = "(python_full_version < '3.12' and platform_machine != 'x86_64' and extra == 'outlines-vllm-offline') or (python_full_version < '3.12' and sys_platform != 'darwin' and extra == 'outlines-vllm-offline')" }, + { name = "xai-sdk", marker = "extra == 'grok'", specifier = ">=1.4.0" }, ] -provides-extras = ["a2a", "ag-ui", "anthropic", "bedrock", "cli", "cohere", "dbos", "duckduckgo", "evals", "fastmcp", "google", "groq", "huggingface", "logfire", "mcp", "mistral", "openai", "outlines-llamacpp", "outlines-mlxlm", "outlines-sglang", "outlines-transformers", "outlines-vllm-offline", "prefect", "retries", "tavily", "temporal", "ui", "vertexai"] +provides-extras = ["a2a", "ag-ui", "anthropic", "bedrock", "cli", "cohere", "dbos", "duckduckgo", 
"evals", "fastmcp", "google", "grok", "groq", "huggingface", "logfire", "mcp", "mistral", "openai", "outlines-llamacpp", "outlines-mlxlm", "outlines-sglang", "outlines-transformers", "outlines-vllm-offline", "prefect", "retries", "tavily", "temporal", "ui", "vertexai"] [[package]] name = "pydantic-core" @@ -8523,6 +8578,24 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/2d/82/f56956041adef78f849db6b289b282e72b55ab8045a75abad81898c28d19/wrapt-1.17.2-py3-none-any.whl", hash = "sha256:b18f2d1533a71f069c7f82d524a52599053d4c7166e9dd374ae2136b7f40f7c8", size = 23594, upload-time = "2025-01-14T10:35:44.018Z" }, ] +[[package]] +name = "xai-sdk" +version = "1.4.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "aiohttp" }, + { name = "grpcio" }, + { name = "opentelemetry-sdk" }, + { name = "packaging" }, + { name = "protobuf" }, + { name = "pydantic" }, + { name = "requests" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/c3/cf/c9ccc20bd419f4fce088cd3e1778fb6b3420526ff4599c2bf6caf1427e99/xai_sdk-1.4.0.tar.gz", hash = "sha256:90e6e0b929395816a8474a332e6d996fbd7c56c3e9922b3894d14ef90b4adc37", size = 314502, upload-time = "2025-11-07T23:55:07.722Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/de/e5/8cbdd56008e8194880151f151db62b1b2331d51de8f8e788b91524279611/xai_sdk-1.4.0-py3-none-any.whl", hash = "sha256:2635d661995ef1424fd5b5de6a9b7d6a11bad49a34afb19b04a330c40d90e0d1", size = 185691, upload-time = "2025-11-07T23:55:06.168Z" }, +] + [[package]] name = "xformers" version = "0.0.32.post1" @@ -8552,14 +8625,17 @@ dependencies = [ ] sdist = { url = "https://files.pythonhosted.org/packages/f2/a9/dc3c63cf7f082d183711e46ef34d10d8a135c2319dc581905d79449f52ea/xgrammar-0.1.25.tar.gz", hash = "sha256:70ce16b27e8082f20808ed759b0733304316facc421656f0f30cfce514b5b77a", size = 2297187, upload-time = "2025-09-21T05:58:58.942Z" } wheels = [ + { url = "https://files.pythonhosted.org/packages/c0/b4/8f78b56ebf64f161258f339cc5898bf761b4fb6c6805d0bca1bcaaaef4a1/xgrammar-0.1.25-cp310-cp310-macosx_10_14_x86_64.whl", hash = "sha256:d12d1078ee2b5c1531610489b433b77694a7786210ceb2c0c1c1eb058e9053c7", size = 679074, upload-time = "2025-09-21T05:58:20.344Z" }, { url = "https://files.pythonhosted.org/packages/52/38/b57120b73adcd342ef974bff14b2b584e7c47edf28d91419cb9325fd5ef2/xgrammar-0.1.25-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:c2e940541b7cddf3ef55a70f20d4c872af7f0d900bc0ed36f434bf7212e2e729", size = 622668, upload-time = "2025-09-21T05:58:22.269Z" }, { url = "https://files.pythonhosted.org/packages/19/8d/64430d01c21ca2b1d8c5a1ed47c90f8ac43717beafc9440d01d81acd5cfc/xgrammar-0.1.25-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2063e1c72f0c00f47ac8ce7ce0fcbff6fa77f79012e063369683844e2570c266", size = 8517569, upload-time = "2025-09-21T05:58:23.77Z" }, { url = "https://files.pythonhosted.org/packages/b1/c4/137d0e9cd038ff4141752c509dbeea0ec5093eb80815620c01b1f1c26d0a/xgrammar-0.1.25-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9785eafa251c996ebaa441f3b8a6c037538930104e265a64a013da0e6fd2ad86", size = 8709188, upload-time = "2025-09-21T05:58:26.246Z" }, { url = "https://files.pythonhosted.org/packages/6c/3d/c228c470d50865c9db3fb1e75a95449d0183a8248519b89e86dc481d6078/xgrammar-0.1.25-cp310-cp310-win_amd64.whl", hash = "sha256:42ecefd020038b3919a473fe5b9bb9d8d809717b8689a736b81617dec4acc59b", size = 698919, upload-time = "2025-09-21T05:58:28.368Z" }, + { url = 
"https://files.pythonhosted.org/packages/9e/b7/ca0ff7c91f24b2302e94b0e6c2a234cc5752b10da51eb937e7f2aa257fde/xgrammar-0.1.25-cp311-cp311-macosx_10_14_x86_64.whl", hash = "sha256:27d7ac4be05cf9aa258c109a8647092ae47cb1e28df7d27caced6ab44b72b799", size = 678801, upload-time = "2025-09-21T05:58:29.936Z" }, { url = "https://files.pythonhosted.org/packages/43/cd/fdf4fb1b5f9c301d381656a600ad95255a76fa68132978af6f06e50a46e1/xgrammar-0.1.25-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:151c1636188bc8c5cdf318cefc5ba23221c9c8cc07cb392317fb3f7635428150", size = 622565, upload-time = "2025-09-21T05:58:31.185Z" }, { url = "https://files.pythonhosted.org/packages/55/04/55a87e814bcab771d3e4159281fa382b3d5f14a36114f2f9e572728da831/xgrammar-0.1.25-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:35fc135650aa204bf84db7fe9c0c0f480b6b11419fe47d89f4bd21602ac33be9", size = 8517238, upload-time = "2025-09-21T05:58:32.835Z" }, { url = "https://files.pythonhosted.org/packages/31/f6/3c5210bc41b61fb32b66bf5c9fd8ec5edacfeddf9860e95baa9caa9a2c82/xgrammar-0.1.25-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fc19d6d7e8e51b6c9a266e949ac7fb3d2992447efeec7df32cca109149afac18", size = 8709514, upload-time = "2025-09-21T05:58:34.727Z" }, { url = "https://files.pythonhosted.org/packages/21/de/85714f307536b328cc16cc6755151865e8875378c8557c15447ca07dff98/xgrammar-0.1.25-cp311-cp311-win_amd64.whl", hash = "sha256:8fcb24f5a7acd5876165c50bd51ce4bf8e6ff897344a5086be92d1fe6695f7fe", size = 698722, upload-time = "2025-09-21T05:58:36.411Z" }, + { url = "https://files.pythonhosted.org/packages/bf/d7/a7bdb158afa88af7e6e0d312e9677ba5fb5e423932008c9aa2c45af75d5d/xgrammar-0.1.25-cp312-cp312-macosx_10_14_x86_64.whl", hash = "sha256:96500d7578c46e8551253b9211b02e02f54e147bc290479a64717d80dcf4f7e3", size = 678250, upload-time = "2025-09-21T05:58:37.936Z" }, { url = "https://files.pythonhosted.org/packages/10/9d/b20588a3209d544a3432ebfcf2e3b1a455833ee658149b08c18eef0c6f59/xgrammar-0.1.25-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:73ba9031e359447af53ce89dfb0775e7b9f4b358d513bcc28a6b4deace661dd5", size = 621550, upload-time = "2025-09-21T05:58:39.464Z" }, { url = "https://files.pythonhosted.org/packages/99/9c/39bb38680be3b6d6aa11b8a46a69fb43e2537d6728710b299fa9fc231ff0/xgrammar-0.1.25-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c519518ebc65f75053123baaf23776a21bda58f64101a64c2fc4aa467c9cd480", size = 8519097, upload-time = "2025-09-21T05:58:40.831Z" }, { url = "https://files.pythonhosted.org/packages/c6/c2/695797afa9922c30c45aa94e087ad33a9d87843f269461b622a65a39022a/xgrammar-0.1.25-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:47fdbfc6007df47de2142613220292023e88e4a570546b39591f053e4d9ec33f", size = 8712184, upload-time = "2025-09-21T05:58:43.142Z" }, From afd8a4ee6fb091f0073c285f14e3d48dfe72a1e9 Mon Sep 17 00:00:00 2001 From: Julian Bright Date: Tue, 11 Nov 2025 17:11:20 -0800 Subject: [PATCH 02/16] Adding flight booking example for live test --- .../flight_booking_grok.py | 248 ++++++++++++++++++ 1 file changed, 248 insertions(+) create mode 100644 examples/pydantic_ai_examples/flight_booking_grok.py diff --git a/examples/pydantic_ai_examples/flight_booking_grok.py b/examples/pydantic_ai_examples/flight_booking_grok.py new file mode 100644 index 0000000000..1acb497927 --- /dev/null +++ b/examples/pydantic_ai_examples/flight_booking_grok.py @@ -0,0 +1,248 @@ +"""Example of a multi-agent flow where one agent 
delegates work to another.
+
+In this scenario, a group of agents work together to find flights for a user.
+"""
+
+import datetime
+from dataclasses import dataclass
+from typing import Literal
+import os
+import logfire
+from pydantic import BaseModel, Field
+from rich.prompt import Prompt
+
+from pydantic_ai import Agent, ModelRetry, RunContext, RunUsage, UsageLimits
+from pydantic_ai.messages import ModelMessage
+
+# Import local GrokModel
+from pydantic_ai.models.grok import GrokModel
+
+logfire.configure()
+logfire.instrument_pydantic_ai()
+logfire.instrument_httpx()
+
+# Configure for xAI API
+xai_api_key = os.getenv("XAI_API_KEY")
+if not xai_api_key:
+    raise ValueError("XAI_API_KEY environment variable is required")
+
+
+# Create the model using the new GrokModel
+model = GrokModel("grok-4-fast-non-reasoning", api_key=xai_api_key)
+
+
+class FlightDetails(BaseModel):
+    """Details of the most suitable flight."""
+
+    flight_number: str
+    price: int
+    origin: str = Field(description="Three-letter airport code")
+    destination: str = Field(description="Three-letter airport code")
+    date: datetime.date
+
+
+class NoFlightFound(BaseModel):
+    """When no valid flight is found."""
+
+
+@dataclass
+class Deps:
+    web_page_text: str
+    req_origin: str
+    req_destination: str
+    req_date: datetime.date
+
+
+# This agent is responsible for controlling the flow of the conversation.
+search_agent = Agent[Deps, FlightDetails | NoFlightFound](
+    model=model,
+    output_type=FlightDetails | NoFlightFound,  # type: ignore
+    retries=4,
+    system_prompt=("Your job is to find the cheapest flight for the user on the given date. "),
+)
+
+
+# This agent is responsible for extracting flight details from web page text.
+extraction_agent = Agent(
+    model=model,
+    output_type=list[FlightDetails],
+    system_prompt="Extract all the flight details from the given text.",
+)
+
+
+@search_agent.tool
+async def extract_flights(ctx: RunContext[Deps]) -> list[FlightDetails]:
+    """Get details of all flights."""
+    # we pass the usage to the search agent so requests within this agent are counted
+    result = await extraction_agent.run(ctx.deps.web_page_text, usage=ctx.usage)
+    logfire.info("found {flight_count} flights", flight_count=len(result.output))
+    return result.output
+
+
+@search_agent.output_validator
+async def validate_output(
+    ctx: RunContext[Deps], output: FlightDetails | NoFlightFound
+) -> FlightDetails | NoFlightFound:
+    """Procedural validation that the flight meets the constraints."""
+    if isinstance(output, NoFlightFound):
+        return output
+
+    errors: list[str] = []
+    if output.origin != ctx.deps.req_origin:
+        errors.append(f"Flight should have origin {ctx.deps.req_origin}, not {output.origin}")
+    if output.destination != ctx.deps.req_destination:
+        errors.append(
+            f"Flight should have destination {ctx.deps.req_destination}, not {output.destination}"
+        )
+    if output.date != ctx.deps.req_date:
+        errors.append(f"Flight should be on {ctx.deps.req_date}, not {output.date}")
+
+    if errors:
+        raise ModelRetry("\n".join(errors))
+    else:
+        return output
+
+
+class SeatPreference(BaseModel):
+    row: int = Field(ge=1, le=30)
+    seat: Literal["A", "B", "C", "D", "E", "F"]
+
+
+class Failed(BaseModel):
+    """Unable to extract a seat selection."""
+
+
+# This agent is responsible for extracting the user's seat selection
+seat_preference_agent = Agent[None, SeatPreference | Failed](
+    model=model,
+    output_type=SeatPreference | Failed,
+    system_prompt=(
+        "Extract the user's seat preference. 
" + "Seats A and F are window seats. " + "Row 1 is the front row and has extra leg room. " + "Rows 14, and 20 also have extra leg room. " + ), +) + + +# in reality this would be downloaded from a booking site, +# potentially using another agent to navigate the site +flights_web_page = """ +1. Flight SFO-AK123 +- Price: $350 +- Origin: San Francisco International Airport (SFO) +- Destination: Ted Stevens Anchorage International Airport (ANC) +- Date: January 10, 2025 + +2. Flight SFO-AK456 +- Price: $370 +- Origin: San Francisco International Airport (SFO) +- Destination: Fairbanks International Airport (FAI) +- Date: January 10, 2025 + +3. Flight SFO-AK789 +- Price: $400 +- Origin: San Francisco International Airport (SFO) +- Destination: Juneau International Airport (JNU) +- Date: January 20, 2025 + +4. Flight NYC-LA101 +- Price: $250 +- Origin: San Francisco International Airport (SFO) +- Destination: Ted Stevens Anchorage International Airport (ANC) +- Date: January 10, 2025 + +5. Flight CHI-MIA202 +- Price: $200 +- Origin: Chicago O'Hare International Airport (ORD) +- Destination: Miami International Airport (MIA) +- Date: January 12, 2025 + +6. Flight BOS-SEA303 +- Price: $120 +- Origin: Boston Logan International Airport (BOS) +- Destination: Ted Stevens Anchorage International Airport (ANC) +- Date: January 12, 2025 + +7. Flight DFW-DEN404 +- Price: $150 +- Origin: Dallas/Fort Worth International Airport (DFW) +- Destination: Denver International Airport (DEN) +- Date: January 10, 2025 + +8. Flight ATL-HOU505 +- Price: $180 +- Origin: Hartsfield-Jackson Atlanta International Airport (ATL) +- Destination: George Bush Intercontinental Airport (IAH) +- Date: January 10, 2025 +""" + +# restrict how many requests this app can make to the LLM +usage_limits = UsageLimits(request_limit=15) + + +async def main(): + deps = Deps( + web_page_text=flights_web_page, + req_origin="SFO", + req_destination="ANC", + req_date=datetime.date(2025, 1, 10), + ) + message_history: list[ModelMessage] | None = None + usage: RunUsage = RunUsage() + # run the agent until a satisfactory flight is found + while True: + result = await search_agent.run( + f"Find me a flight from {deps.req_origin} to {deps.req_destination} on {deps.req_date}", + deps=deps, + usage=usage, + message_history=message_history, + usage_limits=usage_limits, + ) + if isinstance(result.output, NoFlightFound): + print("No flight found") + break + else: + flight = result.output + print(f"Flight found: {flight}") + answer = Prompt.ask( + "Do you want to buy this flight, or keep searching? (buy/*search)", + choices=["buy", "search", ""], + show_choices=False, + ) + if answer == "buy": + seat = await find_seat(usage) + await buy_tickets(flight, seat) + break + else: + message_history = result.all_messages( + output_tool_return_content="Please suggest another flight" + ) + + +async def find_seat(usage: RunUsage) -> SeatPreference: + message_history: list[ModelMessage] | None = None + while True: + answer = Prompt.ask("What seat would you like?") + + result = await seat_preference_agent.run( + answer, + message_history=message_history, + usage=usage, + usage_limits=usage_limits, + ) + if isinstance(result.output, SeatPreference): + return result.output + else: + print("Could not understand seat preference. 
Please try again.")
+            message_history = result.all_messages()
+
+
+async def buy_tickets(flight_details: FlightDetails, seat: SeatPreference):
+    print(f"Purchasing flight {flight_details=!r} {seat=!r}...")
+
+
+if __name__ == "__main__":
+    import asyncio
+
+    asyncio.run(main())

From 4513c8fb245cc67a14602c9f7d58ae86c2a541c5 Mon Sep 17 00:00:00 2001
From: Julian Bright
Date: Mon, 17 Nov 2025 18:59:14 -0800
Subject: [PATCH 03/16] Fix pre-commit issues

---
 .../flight_booking_grok.py                    |  65 ++---
 pydantic_ai_slim/pydantic_ai/models/grok.py   | 249 +++++++++---------
 pydantic_ai_slim/pydantic_ai/profiles/grok.py |   2 +-
 3 files changed, 162 insertions(+), 154 deletions(-)

diff --git a/examples/pydantic_ai_examples/flight_booking_grok.py b/examples/pydantic_ai_examples/flight_booking_grok.py
index 1acb497927..ec8673fb03 100644
--- a/examples/pydantic_ai_examples/flight_booking_grok.py
+++ b/examples/pydantic_ai_examples/flight_booking_grok.py
@@ -4,9 +4,10 @@
 """
 
 import datetime
+import os
 from dataclasses import dataclass
 from typing import Literal
-import os
+
 import logfire
 from pydantic import BaseModel, Field
 from rich.prompt import Prompt
@@ -22,13 +23,13 @@
 logfire.instrument_httpx()
 
 # Configure for xAI API
-xai_api_key = os.getenv("XAI_API_KEY")
+xai_api_key = os.getenv('XAI_API_KEY')
 if not xai_api_key:
-    raise ValueError("XAI_API_KEY environment variable is required")
+    raise ValueError('XAI_API_KEY environment variable is required')
 
 
 # Create the model using the new GrokModel
-model = GrokModel("grok-4-fast-non-reasoning", api_key=xai_api_key)
+model = GrokModel('grok-4-fast-non-reasoning', api_key=xai_api_key)
 
 
 class FlightDetails(BaseModel):
@@ -36,8 +37,8 @@ class FlightDetails(BaseModel):
 
     flight_number: str
     price: int
-    origin: str = Field(description="Three-letter airport code")
-    destination: str = Field(description="Three-letter airport code")
+    origin: str = Field(description='Three-letter airport code')
+    destination: str = Field(description='Three-letter airport code')
     date: datetime.date
 
 
@@ -58,7 +59,9 @@ class Deps:
     model=model,
     output_type=FlightDetails | NoFlightFound,  # type: ignore
     retries=4,
-    system_prompt=("Your job is to find the cheapest flight for the user on the given date. "),
+    system_prompt=(
+        'Your job is to find the cheapest flight for the user on the given date. 
' + ), ) @@ -66,7 +69,7 @@ class Deps: extraction_agent = Agent( model=model, output_type=list[FlightDetails], - system_prompt="Extract all the flight details from the given text.", + system_prompt='Extract all the flight details from the given text.', ) @@ -75,7 +78,7 @@ async def extract_flights(ctx: RunContext[Deps]) -> list[FlightDetails]: """Get details of all flights.""" # we pass the usage to the search agent so requests within this agent are counted result = await extraction_agent.run(ctx.deps.web_page_text, usage=ctx.usage) - logfire.info("found {flight_count} flights", flight_count=len(result.output)) + logfire.info('found {flight_count} flights', flight_count=len(result.output)) return result.output @@ -89,23 +92,25 @@ async def validate_output( errors: list[str] = [] if output.origin != ctx.deps.req_origin: - errors.append(f"Flight should have origin {ctx.deps.req_origin}, not {output.origin}") + errors.append( + f'Flight should have origin {ctx.deps.req_origin}, not {output.origin}' + ) if output.destination != ctx.deps.req_destination: errors.append( - f"Flight should have destination {ctx.deps.req_destination}, not {output.destination}" + f'Flight should have destination {ctx.deps.req_destination}, not {output.destination}' ) if output.date != ctx.deps.req_date: - errors.append(f"Flight should be on {ctx.deps.req_date}, not {output.date}") + errors.append(f'Flight should be on {ctx.deps.req_date}, not {output.date}') if errors: - raise ModelRetry("\n".join(errors)) + raise ModelRetry('\n'.join(errors)) else: return output class SeatPreference(BaseModel): row: int = Field(ge=1, le=30) - seat: Literal["A", "B", "C", "D", "E", "F"] + seat: Literal['A', 'B', 'C', 'D', 'E', 'F'] class Failed(BaseModel): @@ -118,9 +123,9 @@ class Failed(BaseModel): output_type=SeatPreference | Failed, system_prompt=( "Extract the user's seat preference. " - "Seats A and F are window seats. " - "Row 1 is the front row and has extra leg room. " - "Rows 14, and 20 also have extra leg room. " + 'Seats A and F are window seats. ' + 'Row 1 is the front row and has extra leg room. ' + 'Rows 14, and 20 also have extra leg room. ' ), ) @@ -184,8 +189,8 @@ class Failed(BaseModel): async def main(): deps = Deps( web_page_text=flights_web_page, - req_origin="SFO", - req_destination="ANC", + req_origin='SFO', + req_destination='ANC', req_date=datetime.date(2025, 1, 10), ) message_history: list[ModelMessage] | None = None @@ -193,37 +198,37 @@ async def main(): # run the agent until a satisfactory flight is found while True: result = await search_agent.run( - f"Find me a flight from {deps.req_origin} to {deps.req_destination} on {deps.req_date}", + f'Find me a flight from {deps.req_origin} to {deps.req_destination} on {deps.req_date}', deps=deps, usage=usage, message_history=message_history, usage_limits=usage_limits, ) if isinstance(result.output, NoFlightFound): - print("No flight found") + print('No flight found') break else: flight = result.output - print(f"Flight found: {flight}") + print(f'Flight found: {flight}') answer = Prompt.ask( - "Do you want to buy this flight, or keep searching? (buy/*search)", - choices=["buy", "search", ""], + 'Do you want to buy this flight, or keep searching? 
(buy/*search)', + choices=['buy', 'search', ''], show_choices=False, ) - if answer == "buy": + if answer == 'buy': seat = await find_seat(usage) await buy_tickets(flight, seat) break else: message_history = result.all_messages( - output_tool_return_content="Please suggest another flight" + output_tool_return_content='Please suggest another flight' ) async def find_seat(usage: RunUsage) -> SeatPreference: message_history: list[ModelMessage] | None = None while True: - answer = Prompt.ask("What seat would you like?") + answer = Prompt.ask('What seat would you like?') result = await seat_preference_agent.run( answer, @@ -234,15 +239,15 @@ async def find_seat(usage: RunUsage) -> SeatPreference: if isinstance(result.output, SeatPreference): return result.output else: - print("Could not understand seat preference. Please try again.") + print('Could not understand seat preference. Please try again.') message_history = result.all_messages() async def buy_tickets(flight_details: FlightDetails, seat: SeatPreference): - print(f"Purchasing flight {flight_details=!r} {seat=!r}...") + print(f'Purchasing flight {flight_details=!r} {seat=!r}...') -if __name__ == "__main__": +if __name__ == '__main__': import asyncio asyncio.run(main()) diff --git a/pydantic_ai_slim/pydantic_ai/models/grok.py b/pydantic_ai_slim/pydantic_ai/models/grok.py index 5ec8c0ebf6..7b073b37d1 100644 --- a/pydantic_ai_slim/pydantic_ai/models/grok.py +++ b/pydantic_ai_slim/pydantic_ai/models/grok.py @@ -1,34 +1,40 @@ """Grok model implementation using xAI SDK.""" import os +from collections.abc import AsyncIterator, Sequence from contextlib import asynccontextmanager -from typing import Any, AsyncIterator +from dataclasses import dataclass +from typing import Any + +import xai_sdk.chat as chat_types + +# Import xai_sdk components +from xai_sdk import AsyncClient +from xai_sdk.chat import assistant, system, tool, tool_result, user from .._run_context import RunContext +from .._utils import now_utc from ..messages import ( + FinishReason, ModelMessage, ModelRequest, + ModelRequestPart, ModelResponse, + ModelResponsePart, + ModelResponseStreamEvent, SystemPromptPart, - UserPromptPart, - ToolReturnPart, TextPart, ToolCallPart, - FinishReason, + ToolReturnPart, + UserPromptPart, ) from ..models import ( Model, ModelRequestParameters, - ModelSettings, StreamedResponse, ) +from ..settings import ModelSettings from ..usage import RequestUsage -from .._utils import now_utc - -# Import xai_sdk components -from xai_sdk import AsyncClient -from xai_sdk.chat import system, user, assistant, tool, tool_result -import xai_sdk.chat as chat_types class GrokModel(Model): @@ -53,9 +59,9 @@ def __init__( """ super().__init__(settings=settings) self._model_name = model_name - self._api_key = api_key or os.getenv("XAI_API_KEY") or "" + self._api_key = api_key or os.getenv('XAI_API_KEY') or '' if not self._api_key: - raise ValueError("XAI API key is required") + raise ValueError('XAI API key is required') @property def model_name(self) -> str: @@ -65,55 +71,64 @@ def model_name(self) -> str: @property def system(self) -> str: """The model provider.""" - return "xai" + return 'xai' def _map_messages(self, messages: list[ModelMessage]) -> list[chat_types.chat_pb2.Message]: """Convert pydantic_ai messages to xAI SDK messages.""" - xai_messages = [] + xai_messages: list[chat_types.chat_pb2.Message] = [] for message in messages: if isinstance(message, ModelRequest): - for part in message.parts: - if isinstance(part, SystemPromptPart): - 
xai_messages.append(system(part.content)) - elif isinstance(part, UserPromptPart): - # Handle user prompt content - if isinstance(part.content, str): - xai_messages.append(user(part.content)) - else: - # Handle complex content (images, etc.) - # For now, just concatenate text content - text_parts = [] - for item in part.content: - if isinstance(item, str): - text_parts.append(item) - if text_parts: - xai_messages.append(user(" ".join(text_parts))) - elif isinstance(part, ToolReturnPart): - xai_messages.append(tool_result(part.model_response_str())) + xai_messages.extend(self._map_request_parts(message.parts)) elif isinstance(message, ModelResponse): - content_parts = [] - for part in message.parts: - if isinstance(part, TextPart): - content_parts.append(part.content) - elif isinstance(part, ToolCallPart): - # Tool calls will be handled separately in the response processing - pass + if response_msg := self._map_response_parts(message.parts): + xai_messages.append(response_msg) + + return xai_messages + + def _map_request_parts(self, parts: Sequence[ModelRequestPart]) -> list[chat_types.chat_pb2.Message]: + """Map ModelRequest parts to xAI messages.""" + xai_messages: list[chat_types.chat_pb2.Message] = [] - if content_parts: - xai_messages.append(assistant(" ".join(content_parts))) + for part in parts: + if isinstance(part, SystemPromptPart): + xai_messages.append(system(part.content)) + elif isinstance(part, UserPromptPart): + if user_msg := self._map_user_prompt(part): + xai_messages.append(user_msg) + elif isinstance(part, ToolReturnPart): + xai_messages.append(tool_result(part.model_response_str())) return xai_messages - def _map_tools( - self, model_request_parameters: ModelRequestParameters - ) -> list[chat_types.chat_pb2.Tool]: + def _map_user_prompt(self, part: UserPromptPart) -> chat_types.chat_pb2.Message | None: + """Map a UserPromptPart to an xAI user message.""" + if isinstance(part.content, str): + return user(part.content) + + # Handle complex content (images, etc.) 
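+        # Only plain-string fragments are forwarded below; non-text items
+        # (e.g. ImageUrl or BinaryContent) are dropped rather than sent to the SDK.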
+ text_parts: list[str] = [item for item in part.content if isinstance(item, str)] + if text_parts: + return user(' '.join(text_parts)) + + return None + + def _map_response_parts(self, parts: Sequence[ModelResponsePart]) -> chat_types.chat_pb2.Message | None: + """Map ModelResponse parts to an xAI assistant message.""" + content_parts: list[str] = [part.content for part in parts if isinstance(part, TextPart)] + + if content_parts: + return assistant(' '.join(content_parts)) + + return None + + def _map_tools(self, model_request_parameters: ModelRequestParameters) -> list[chat_types.chat_pb2.Tool]: """Convert pydantic_ai tool definitions to xAI SDK tools.""" - tools = [] + tools: list[chat_types.chat_pb2.Tool] = [] for tool_def in model_request_parameters.tool_defs.values(): xai_tool = tool( name=tool_def.name, - description=tool_def.description or "", + description=tool_def.description or '', parameters=tool_def.parameters_json_schema, ) tools.append(xai_tool) @@ -133,31 +148,25 @@ async def request( xai_messages = self._map_messages(messages) # Convert tools if any - tools = ( - self._map_tools(model_request_parameters) - if model_request_parameters.tool_defs - else None - ) + tools = self._map_tools(model_request_parameters) if model_request_parameters.tool_defs else None # Filter model settings to only include xAI SDK compatible parameters - xai_settings = {} + xai_settings: dict[str, Any] = {} if model_settings: # Map pydantic_ai settings to xAI SDK parameters - if "temperature" in model_settings: - xai_settings["temperature"] = model_settings["temperature"] - if "top_p" in model_settings: - xai_settings["top_p"] = model_settings["top_p"] - if "max_tokens" in model_settings: - xai_settings["max_tokens"] = model_settings["max_tokens"] - if "stop_sequences" in model_settings: - xai_settings["stop"] = model_settings["stop_sequences"] - if "seed" in model_settings: - xai_settings["seed"] = model_settings["seed"] + if 'temperature' in model_settings: + xai_settings['temperature'] = model_settings['temperature'] + if 'top_p' in model_settings: + xai_settings['top_p'] = model_settings['top_p'] + if 'max_tokens' in model_settings: + xai_settings['max_tokens'] = model_settings['max_tokens'] + if 'stop_sequences' in model_settings: + xai_settings['stop'] = model_settings['stop_sequences'] + if 'seed' in model_settings: + xai_settings['seed'] = model_settings['seed'] # Create chat instance - chat = client.chat.create( - model=self._model_name, messages=xai_messages, tools=tools, **xai_settings - ) + chat = client.chat.create(model=self._model_name, messages=xai_messages, tools=tools, **xai_settings) # Sample the response response = await chat.sample() @@ -181,46 +190,42 @@ async def request_stream( xai_messages = self._map_messages(messages) # Convert tools if any - tools = ( - self._map_tools(model_request_parameters) - if model_request_parameters.tool_defs - else None - ) + tools = self._map_tools(model_request_parameters) if model_request_parameters.tool_defs else None # Filter model settings to only include xAI SDK compatible parameters - xai_settings = {} + xai_settings: dict[str, Any] = {} if model_settings: # Map pydantic_ai settings to xAI SDK parameters - if "temperature" in model_settings: - xai_settings["temperature"] = model_settings["temperature"] - if "top_p" in model_settings: - xai_settings["top_p"] = model_settings["top_p"] - if "max_tokens" in model_settings: - xai_settings["max_tokens"] = model_settings["max_tokens"] - if "stop_sequences" in model_settings: - 
xai_settings["stop"] = model_settings["stop_sequences"] - if "seed" in model_settings: - xai_settings["seed"] = model_settings["seed"] + if 'temperature' in model_settings: + xai_settings['temperature'] = model_settings['temperature'] + if 'top_p' in model_settings: + xai_settings['top_p'] = model_settings['top_p'] + if 'max_tokens' in model_settings: + xai_settings['max_tokens'] = model_settings['max_tokens'] + if 'stop_sequences' in model_settings: + xai_settings['stop'] = model_settings['stop_sequences'] + if 'seed' in model_settings: + xai_settings['seed'] = model_settings['seed'] # Create chat instance - chat = client.chat.create( - model=self._model_name, messages=xai_messages, tools=tools, **xai_settings - ) + chat = client.chat.create(model=self._model_name, messages=xai_messages, tools=tools, **xai_settings) # Stream the response response_stream = chat.stream() - streamed_response = GrokStreamedResponse(model_request_parameters) - streamed_response._model_name = self._model_name - streamed_response._response = response_stream - streamed_response._timestamp = now_utc() - streamed_response._provider_name = "xai" + streamed_response = GrokStreamedResponse( + model_request_parameters=model_request_parameters, + _model_name=self._model_name, + _response=response_stream, + _timestamp=now_utc(), + _provider_name='xai', + ) yield streamed_response def _process_response(self, response: chat_types.Response) -> ModelResponse: """Convert xAI SDK response to pydantic_ai ModelResponse.""" from typing import cast - parts = [] + parts: list[ModelResponsePart] = [] # Add text content if response.content: @@ -237,24 +242,22 @@ def _process_response(self, response: chat_types.Response) -> ModelResponse: ) # Convert usage - try to access attributes, default to 0 if not available - input_tokens = getattr(response.usage, "input_tokens", 0) - output_tokens = getattr(response.usage, "output_tokens", 0) + input_tokens = getattr(response.usage, 'input_tokens', 0) + output_tokens = getattr(response.usage, 'output_tokens', 0) usage = RequestUsage(input_tokens=input_tokens, output_tokens=output_tokens) # Map finish reason finish_reason_map = { - "stop": "stop", - "length": "length", - "content_filter": "content_filter", - "max_output_tokens": "length", - "cancelled": "error", - "failed": "error", + 'stop': 'stop', + 'length': 'length', + 'content_filter': 'content_filter', + 'max_output_tokens': 'length', + 'cancelled': 'error', + 'failed': 'error', } raw_finish_reason = response.finish_reason mapped_reason = ( - finish_reason_map.get(raw_finish_reason, "stop") - if isinstance(raw_finish_reason, str) - else "stop" + finish_reason_map.get(raw_finish_reason, 'stop') if isinstance(raw_finish_reason, str) else 'stop' ) finish_reason = cast(FinishReason, mapped_reason) @@ -263,11 +266,12 @@ def _process_response(self, response: chat_types.Response) -> ModelResponse: usage=usage, model_name=self._model_name, timestamp=now_utc(), - provider_name="xai", + provider_name='xai', finish_reason=finish_reason, ) +@dataclass class GrokStreamedResponse(StreamedResponse): """Implementation of `StreamedResponse` for xAI SDK.""" @@ -275,47 +279,46 @@ class GrokStreamedResponse(StreamedResponse): _response: Any # xai_sdk chat stream _timestamp: Any _provider_name: str - _usage: RequestUsage - provider_response_id: str | None - finish_reason: Any - async def _get_event_iterator(self): + async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: """Iterate over streaming events from xAI SDK.""" from typing 
import cast
 
         async for response, chunk in self._response:
             # Update usage if available
-            if hasattr(response, "usage"):
-                input_tokens = getattr(response.usage, "input_tokens", 0)
-                output_tokens = getattr(response.usage, "output_tokens", 0)
+            if hasattr(response, 'usage'):
+                input_tokens = getattr(response.usage, 'input_tokens', 0)
+                output_tokens = getattr(response.usage, 'output_tokens', 0)
                 self._usage = RequestUsage(input_tokens=input_tokens, output_tokens=output_tokens)
 
             # Set provider response ID
-            if hasattr(response, "id") and self.provider_response_id is None:
+            if hasattr(response, 'id') and self.provider_response_id is None:
                 self.provider_response_id = response.id
 
             # Handle finish reason
-            if hasattr(response, "finish_reason") and response.finish_reason:
+            if hasattr(response, 'finish_reason') and response.finish_reason:
                 finish_reason_map = {
-                    "stop": "stop",
-                    "length": "length",
-                    "content_filter": "content_filter",
-                    "max_output_tokens": "length",
-                    "cancelled": "error",
-                    "failed": "error",
+                    'stop': 'stop',
+                    'length': 'length',
+                    'content_filter': 'content_filter',
+                    'max_output_tokens': 'length',
+                    'cancelled': 'error',
+                    'failed': 'error',
                 }
-                mapped_reason = finish_reason_map.get(response.finish_reason, "stop")
+                mapped_reason = finish_reason_map.get(response.finish_reason, 'stop')
                 self.finish_reason = cast(FinishReason, mapped_reason)
 
             # Handle text content
-            if hasattr(chunk, "content") and chunk.content:
-                yield self._parts_manager.handle_text_delta(
-                    vendor_part_id="content",
+            if hasattr(chunk, 'content') and chunk.content:
+                event = self._parts_manager.handle_text_delta(
+                    vendor_part_id='content',
                     content=chunk.content,
                 )
+                if event is not None:
+                    yield event
 
             # Handle tool calls
-            if hasattr(chunk, "tool_calls"):
+            if hasattr(chunk, 'tool_calls'):
                 for tool_call in chunk.tool_calls:
                     yield self._parts_manager.handle_tool_call_part(
                         vendor_part_id=tool_call.id,
diff --git a/pydantic_ai_slim/pydantic_ai/profiles/grok.py b/pydantic_ai_slim/pydantic_ai/profiles/grok.py
index 9a1a9317c4..3a205c4143 100644
--- a/pydantic_ai_slim/pydantic_ai/profiles/grok.py
+++ b/pydantic_ai_slim/pydantic_ai/profiles/grok.py
@@ -4,6 +4,7 @@
 from . import ModelProfile
 
+
 @dataclass(kw_only=True)
 class GrokModelProfile(ModelProfile):
     """Profile for models used with GrokModel. 
@@ -15,7 +16,6 @@ class GrokModelProfile(ModelProfile): """Whether the model always has the web search built-in tool available.""" - def grok_model_profile(model_name: str) -> ModelProfile | None: """Get the model profile for a Grok model.""" return GrokModelProfile( From e60d3ead0806b4eaddb8a6bf442ad534654eeba5 Mon Sep 17 00:00:00 2001 From: Julian Bright Date: Mon, 17 Nov 2025 20:36:06 -0800 Subject: [PATCH 04/16] Updated grok to support passing in AsyncClient, added initial tests --- pydantic_ai_slim/pydantic_ai/models/grok.py | 40 +- tests/models/mock_grok.py | 198 ++++++ tests/models/test_grok.py | 716 ++++++++++++++++++++ 3 files changed, 939 insertions(+), 15 deletions(-) create mode 100644 tests/models/mock_grok.py create mode 100644 tests/models/test_grok.py diff --git a/pydantic_ai_slim/pydantic_ai/models/grok.py b/pydantic_ai_slim/pydantic_ai/models/grok.py index 7b073b37d1..a8c5a919e3 100644 --- a/pydantic_ai_slim/pydantic_ai/models/grok.py +++ b/pydantic_ai_slim/pydantic_ai/models/grok.py @@ -42,12 +42,14 @@ class GrokModel(Model): _model_name: str _api_key: str + _client: AsyncClient | None def __init__( self, model_name: str, *, api_key: str | None = None, + client: AsyncClient | None = None, settings: ModelSettings | None = None, ): """Initialize the Grok model. @@ -55,13 +57,18 @@ def __init__( Args: model_name: The name of the Grok model to use (e.g., "grok-3", "grok-4-fast-non-reasoning") api_key: The xAI API key. If not provided, uses XAI_API_KEY environment variable. + client: Optional AsyncClient instance for testing. If provided, api_key is ignored. settings: Optional model settings. """ super().__init__(settings=settings) self._model_name = model_name - self._api_key = api_key or os.getenv('XAI_API_KEY') or '' - if not self._api_key: - raise ValueError('XAI API key is required') + self._client = client + if client is None: + self._api_key = api_key or os.getenv('XAI_API_KEY') or '' + if not self._api_key: + raise ValueError('XAI API key is required') + else: + self._api_key = api_key or '' @property def model_name(self) -> str: @@ -141,8 +148,8 @@ async def request( model_request_parameters: ModelRequestParameters, ) -> ModelResponse: """Make a request to the Grok model.""" - # Create client in the current async context to avoid event loop issues - client = AsyncClient(api_key=self._api_key) + # Use injected client or create one in the current async context + client = self._client or AsyncClient(api_key=self._api_key) # Convert messages to xAI format xai_messages = self._map_messages(messages) @@ -183,8 +190,8 @@ async def request_stream( run_context: RunContext[Any] | None = None, ) -> AsyncIterator[StreamedResponse]: """Make a streaming request to the Grok model.""" - # Create client in the current async context to avoid event loop issues - client = AsyncClient(api_key=self._api_key) + # Use injected client or create one in the current async context + client = self._client or AsyncClient(api_key=self._api_key) # Convert messages to xAI format xai_messages = self._map_messages(messages) @@ -318,14 +325,17 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: yield event # Handle tool calls - if hasattr(chunk, 'tool_calls'): - for tool_call in chunk.tool_calls: - yield self._parts_manager.handle_tool_call_part( - vendor_part_id=tool_call.id, - tool_name=tool_call.function.name, - args=tool_call.function.arguments, - tool_call_id=tool_call.id, - ) + # Note: We use the accumulated Response tool calls, not the Chunk deltas, + # because 
pydantic validation needs complete JSON, not partial deltas + if hasattr(response, 'tool_calls'): + for tool_call in response.tool_calls: + if hasattr(tool_call.function, 'name') and tool_call.function.name: + yield self._parts_manager.handle_tool_call_part( + vendor_part_id=tool_call.id, + tool_name=tool_call.function.name, + args=tool_call.function.arguments, + tool_call_id=tool_call.id, + ) @property def model_name(self) -> str: diff --git a/tests/models/mock_grok.py b/tests/models/mock_grok.py new file mode 100644 index 0000000000..4157a01315 --- /dev/null +++ b/tests/models/mock_grok.py @@ -0,0 +1,198 @@ +from __future__ import annotations as _annotations + +from collections.abc import Sequence +from dataclasses import dataclass, field +from functools import cached_property +from typing import Any, cast + +from ..conftest import raise_if_exception, try_import +from .mock_async_stream import MockAsyncStream + +with try_import() as imports_successful: + import xai_sdk.chat as chat_types + from xai_sdk import AsyncClient + + MockResponse = chat_types.Response | Exception + # xai_sdk streaming returns tuples of (Response, chunk) where chunk type is not explicitly defined + MockResponseChunk = tuple[chat_types.Response, Any] | Exception + + +@dataclass +class MockGrok: + """Mock for xAI SDK AsyncClient to simulate Grok API responses.""" + + responses: MockResponse | Sequence[MockResponse] | None = None + stream_data: Sequence[MockResponseChunk] | Sequence[Sequence[MockResponseChunk]] | None = None + index: int = 0 + chat_create_kwargs: list[dict[str, Any]] = field(default_factory=list) + api_key: str = 'test-api-key' + + @cached_property + def chat(self) -> Any: + """Create mock chat interface.""" + return type('Chat', (), {'create': self.chat_create}) + + @classmethod + def create_mock( + cls, responses: MockResponse | Sequence[MockResponse], api_key: str = 'test-api-key' + ) -> AsyncClient: + """Create a mock AsyncClient for non-streaming responses.""" + return cast(AsyncClient, cls(responses=responses, api_key=api_key)) + + @classmethod + def create_mock_stream( + cls, + stream: Sequence[MockResponseChunk] | Sequence[Sequence[MockResponseChunk]], + api_key: str = 'test-api-key', + ) -> AsyncClient: + """Create a mock AsyncClient for streaming responses.""" + return cast(AsyncClient, cls(stream_data=stream, api_key=api_key)) + + def chat_create(self, *_args: Any, **kwargs: Any) -> MockChatInstance: + """Mock the chat.create method.""" + self.chat_create_kwargs.append(kwargs) + return MockChatInstance( + responses=self.responses, + stream_data=self.stream_data, + index=self.index, + parent=self, + ) + + +@dataclass +class MockChatInstance: + """Mock for the chat instance returned by client.chat.create().""" + + responses: MockResponse | Sequence[MockResponse] | None = None + stream_data: Sequence[MockResponseChunk] | Sequence[Sequence[MockResponseChunk]] | None = None + index: int = 0 + parent: MockGrok | None = None + + async def sample(self) -> chat_types.Response: + """Mock the sample() method for non-streaming responses.""" + assert self.responses is not None, 'you can only use sample() if responses are provided' + + if isinstance(self.responses, Sequence): + raise_if_exception(self.responses[self.index]) + response = cast(chat_types.Response, self.responses[self.index]) + else: + raise_if_exception(self.responses) + response = cast(chat_types.Response, self.responses) + + if self.parent: + self.parent.index += 1 + + return response + + def stream(self) -> 
MockAsyncStream[MockResponseChunk]: + """Mock the stream() method for streaming responses.""" + assert self.stream_data is not None, 'you can only use stream() if stream_data is provided' + + # Check if we have nested sequences (multiple streams) vs single stream + # We need to check if it's a list of tuples (single stream) vs list of lists (multiple streams) + if isinstance(self.stream_data, list) and len(self.stream_data) > 0: + first_item = self.stream_data[0] + # If first item is a list (not a tuple), we have multiple streams + if isinstance(first_item, list): + data = cast(list[MockResponseChunk], self.stream_data[self.index]) + else: + # Single stream - use the data as is + data = cast(list[MockResponseChunk], self.stream_data) + else: + data = cast(list[MockResponseChunk], self.stream_data) + + if self.parent: + self.parent.index += 1 + + return MockAsyncStream(iter(data)) + + +def get_mock_chat_create_kwargs(async_client: AsyncClient) -> list[dict[str, Any]]: + """Extract the kwargs passed to chat.create from a mock client.""" + if isinstance(async_client, MockGrok): + return async_client.chat_create_kwargs + else: # pragma: no cover + raise RuntimeError('Not a MockGrok instance') + + +@dataclass +class MockGrokResponse: + """Mock Response object that mimics xai_sdk.chat.Response interface.""" + + id: str = 'grok-123' + content: str = '' + tool_calls: list[Any] = field(default_factory=list) + finish_reason: str = 'stop' + usage: Any | None = None # Would be usage_pb2.SamplingUsage in real xai_sdk + + +@dataclass +class MockGrokToolCall: + """Mock ToolCall object that mimics chat_pb2.ToolCall interface.""" + + id: str + function: Any # Would be chat_pb2.Function with name and arguments + + +@dataclass +class MockGrokFunction: + """Mock Function object for tool calls.""" + + name: str + arguments: dict[str, Any] + + +def create_response( + content: str = '', + tool_calls: list[Any] | None = None, + finish_reason: str = 'stop', + usage: Any | None = None, +) -> MockGrokResponse: + """Create a mock Response object for testing. + + Returns a MockGrokResponse that mimics the xai_sdk.chat.Response interface. + """ + return MockGrokResponse( + id='grok-123', + content=content, + tool_calls=tool_calls or [], + finish_reason=finish_reason, + usage=usage, + ) + + +def create_tool_call( + id: str, + name: str, + arguments: dict[str, Any], +) -> MockGrokToolCall: + """Create a mock ToolCall object for testing. + + Returns a MockGrokToolCall that mimics the chat_pb2.ToolCall interface. + """ + return MockGrokToolCall( + id=id, + function=MockGrokFunction(name=name, arguments=arguments), + ) + + +@dataclass +class MockGrokResponseChunk: + """Mock response chunk for streaming.""" + + content: str = '' + tool_calls: list[Any] = field(default_factory=list) + + +def create_response_chunk( + content: str = '', + tool_calls: list[Any] | None = None, +) -> MockGrokResponseChunk: + """Create a mock response chunk object for testing. + + Returns a MockGrokResponseChunk for streaming responses. 
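+
+    Example (illustrative usage):
+
+        chunk = create_response_chunk(content='hello')
+        assert chunk.content == 'hello' and chunk.tool_calls == []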
+ """ + return MockGrokResponseChunk( + content=content, + tool_calls=tool_calls or [], + ) diff --git a/tests/models/test_grok.py b/tests/models/test_grok.py new file mode 100644 index 0000000000..11b0854d9c --- /dev/null +++ b/tests/models/test_grok.py @@ -0,0 +1,716 @@ +from __future__ import annotations as _annotations + +import json +from datetime import timezone +from typing import Any + +import pytest +from inline_snapshot import snapshot +from typing_extensions import TypedDict + +from pydantic_ai import ( + Agent, + BinaryContent, + ImageUrl, + ModelRequest, + ModelResponse, + ModelRetry, + RetryPromptPart, + SystemPromptPart, + TextPart, + ToolCallPart, + ToolReturnPart, + UserPromptPart, +) +from pydantic_ai.output import NativeOutput +from pydantic_ai.result import RunUsage +from pydantic_ai.settings import ModelSettings +from pydantic_ai.usage import RequestUsage + +from ..conftest import IsDatetime, IsNow, IsStr, try_import +from .mock_grok import ( + MockGrok, + MockGrokResponse, + MockGrokResponseChunk, + create_response, + create_tool_call, + get_mock_chat_create_kwargs, +) + +with try_import() as imports_successful: + import xai_sdk.chat as chat_types + + from pydantic_ai.models.grok import GrokModel + + MockResponse = chat_types.Response | Exception + # xai_sdk streaming returns tuples of (Response, chunk) where chunk type is not explicitly defined + MockResponseChunk = tuple[chat_types.Response, Any] | Exception + +pytestmark = [ + pytest.mark.skipif(not imports_successful(), reason='xai_sdk not installed'), + pytest.mark.anyio, + pytest.mark.vcr, +] + + +def test_init(): + m = GrokModel('grok-3', api_key='foobar') + assert m._api_key == 'foobar' + assert m.model_name == 'grok-3' + assert m.system == 'xai' + + +async def test_request_simple_success(allow_model_requests: None): + response = create_response(content='world') + mock_client = MockGrok.create_mock(response) + m = GrokModel('grok-3', client=mock_client) + agent = Agent(m) + + result = await agent.run('hello') + assert result.output == 'world' + assert result.usage() == snapshot(RunUsage(requests=1)) + + # reset the index so we get the same response again + mock_client.index = 0 # type: ignore + + result = await agent.run('hello', message_history=result.new_messages()) + assert result.output == 'world' + assert result.usage() == snapshot(RunUsage(requests=1)) + assert result.all_messages() == snapshot( + [ + ModelRequest( + parts=[UserPromptPart(content='hello', timestamp=IsNow(tz=timezone.utc))], + run_id=IsStr(), + ), + ModelResponse( + parts=[TextPart(content='world')], + model_name='grok-3', + timestamp=IsDatetime(), + provider_name='xai', + finish_reason='stop', + run_id=IsStr(), + ), + ModelRequest( + parts=[UserPromptPart(content='hello', timestamp=IsNow(tz=timezone.utc))], + run_id=IsStr(), + ), + ModelResponse( + parts=[TextPart(content='world')], + model_name='grok-3', + timestamp=IsDatetime(), + provider_name='xai', + finish_reason='stop', + run_id=IsStr(), + ), + ] + ) + + +async def test_request_simple_usage(allow_model_requests: None): + from types import SimpleNamespace + + response = create_response( + content='world', + usage=SimpleNamespace(input_tokens=2, output_tokens=1), + ) + mock_client = MockGrok.create_mock(response) + m = GrokModel('grok-3', client=mock_client) + agent = Agent(m) + + result = await agent.run('Hello') + assert result.output == 'world' + assert result.usage() == snapshot( + RunUsage( + requests=1, + input_tokens=2, + output_tokens=1, + ) + ) + + +async def 
test_grok_image_input(allow_model_requests: None): + """Test that Grok model handles image inputs (text is extracted from content).""" + response = create_response(content='done') + mock_client = MockGrok.create_mock(response) + model = GrokModel('grok-3', client=mock_client) + agent = Agent(model) + + image_url = ImageUrl('https://example.com/image.png') + binary_image = BinaryContent(b'\x89PNG', media_type='image/png') + + result = await agent.run(['Describe these inputs.', image_url, binary_image]) + assert result.output == 'done' + + +async def test_request_structured_response(allow_model_requests: None): + tool_call = create_tool_call( + id='123', + name='final_result', + arguments={'response': [1, 2, 123]}, + ) + response = create_response(tool_calls=[tool_call]) + mock_client = MockGrok.create_mock(response) + m = GrokModel('grok-3', client=mock_client) + agent = Agent(m, output_type=list[int]) + + result = await agent.run('Hello') + assert result.output == [1, 2, 123] + assert result.all_messages() == snapshot( + [ + ModelRequest( + parts=[UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc))], + run_id=IsStr(), + ), + ModelResponse( + parts=[ + ToolCallPart( + tool_name='final_result', + args={'response': [1, 2, 123]}, + tool_call_id='123', + ) + ], + model_name='grok-3', + timestamp=IsDatetime(), + provider_name='xai', + finish_reason='stop', + run_id=IsStr(), + ), + ModelRequest( + parts=[ + ToolReturnPart( + tool_name='final_result', + content='Final result processed.', + tool_call_id='123', + timestamp=IsNow(tz=timezone.utc), + ) + ], + run_id=IsStr(), + ), + ] + ) + + +async def test_request_tool_call(allow_model_requests: None): + from types import SimpleNamespace + + responses = [ + create_response( + tool_calls=[create_tool_call(id='1', name='get_location', arguments={'loc_name': 'San Fransisco'})], + usage=SimpleNamespace(input_tokens=2, output_tokens=1), + ), + create_response( + tool_calls=[create_tool_call(id='2', name='get_location', arguments={'loc_name': 'London'})], + usage=SimpleNamespace(input_tokens=3, output_tokens=2), + ), + create_response(content='final response'), + ] + mock_client = MockGrok.create_mock(responses) + m = GrokModel('grok-3', client=mock_client) + agent = Agent(m, system_prompt='this is the system prompt') + + @agent.tool_plain + async def get_location(loc_name: str) -> str: + if loc_name == 'London': + return json.dumps({'lat': 51, 'lng': 0}) + else: + raise ModelRetry('Wrong location, please try again') + + result = await agent.run('Hello') + assert result.output == 'final response' + assert result.all_messages() == snapshot( + [ + ModelRequest( + parts=[ + SystemPromptPart(content='this is the system prompt', timestamp=IsNow(tz=timezone.utc)), + UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc)), + ], + run_id=IsStr(), + ), + ModelResponse( + parts=[ + ToolCallPart( + tool_name='get_location', + args={'loc_name': 'San Fransisco'}, + tool_call_id='1', + ) + ], + usage=RequestUsage( + input_tokens=2, + output_tokens=1, + ), + model_name='grok-3', + timestamp=IsDatetime(), + provider_name='xai', + finish_reason='stop', + run_id=IsStr(), + ), + ModelRequest( + parts=[ + RetryPromptPart( + content='Wrong location, please try again', + tool_name='get_location', + tool_call_id='1', + timestamp=IsNow(tz=timezone.utc), + ) + ], + run_id=IsStr(), + ), + ModelResponse( + parts=[ + ToolCallPart( + tool_name='get_location', + args={'loc_name': 'London'}, + tool_call_id='2', + ) + ], + usage=RequestUsage( + input_tokens=3, + 
output_tokens=2, + ), + model_name='grok-3', + timestamp=IsDatetime(), + provider_name='xai', + finish_reason='stop', + run_id=IsStr(), + ), + ModelRequest( + parts=[ + ToolReturnPart( + tool_name='get_location', + content='{"lat": 51, "lng": 0}', + tool_call_id='2', + timestamp=IsNow(tz=timezone.utc), + ) + ], + run_id=IsStr(), + ), + ModelResponse( + parts=[TextPart(content='final response')], + model_name='grok-3', + timestamp=IsDatetime(), + provider_name='xai', + finish_reason='stop', + run_id=IsStr(), + ), + ] + ) + assert result.usage() == snapshot(RunUsage(requests=3, input_tokens=5, output_tokens=3, tool_calls=1)) + + +# Helpers for creating Grok streaming chunks +def grok_chunk(response: chat_types.Response, chunk: Any) -> tuple[chat_types.Response, Any]: + """Create a Grok streaming chunk (response, chunk) tuple.""" + return (response, chunk) + + +def grok_text_chunk(text: str, finish_reason: str = 'stop') -> tuple[chat_types.Response, Any]: + """Create a text streaming chunk for Grok. + + Note: For streaming, Response accumulates content, Chunk is the delta. + Since we can't easily track state across calls, we pass full accumulated text as response.content + and the delta as chunk.content. + """ + from types import SimpleNamespace + + # Create chunk (delta) - just this piece of text + chunk = MockGrokResponseChunk(content=text, tool_calls=[]) + + # Create response (accumulated) - for simplicity in mocks, we'll just use the same text + # In real usage, the Response object would accumulate over multiple chunks + response = MockGrokResponse( + id='grok-123', + content=text, # This will be accumulated by the streaming handler + tool_calls=[], + finish_reason=finish_reason if finish_reason else '', + usage=SimpleNamespace(input_tokens=2, output_tokens=1) if finish_reason else None, + ) + + return (response, chunk) + + +async def test_stream_text(allow_model_requests: None): + stream = [grok_text_chunk('hello '), grok_text_chunk('world')] + mock_client = MockGrok.create_mock_stream(stream) + m = GrokModel('grok-3', client=mock_client) + agent = Agent(m) + + async with agent.run_stream('') as result: + assert not result.is_complete + assert [c async for c in result.stream_text(debounce_by=None)] == snapshot(['hello ', 'hello world']) + assert result.is_complete + assert result.usage() == snapshot(RunUsage(requests=1, input_tokens=2, output_tokens=1)) + + +async def test_stream_text_finish_reason(allow_model_requests: None): + # Create streaming chunks with finish reasons + stream = [ + grok_text_chunk('hello ', ''), + grok_text_chunk('world', ''), + grok_text_chunk('.', 'stop'), + ] + mock_client = MockGrok.create_mock_stream(stream) + m = GrokModel('grok-3', client=mock_client) + agent = Agent(m) + + async with agent.run_stream('') as result: + assert not result.is_complete + assert [c async for c in result.stream_text(debounce_by=None)] == snapshot( + ['hello ', 'hello world', 'hello world.'] + ) + assert result.is_complete + async for response, is_last in result.stream_responses(debounce_by=None): + if is_last: + assert response == snapshot( + ModelResponse( + parts=[TextPart(content='hello world.')], + usage=RequestUsage(input_tokens=2, output_tokens=1), + model_name='grok-3', + timestamp=IsDatetime(), + provider_name='xai', + provider_response_id='grok-123', + finish_reason='stop', + ) + ) + + +def grok_tool_chunk( + tool_name: str | None, tool_arguments: str | None, finish_reason: str = '', accumulated_args: str = '' +) -> tuple[chat_types.Response, Any]: + """Create a tool call 
streaming chunk for Grok. + + Args: + tool_name: The tool name (should be provided in all chunks for proper tracking) + tool_arguments: The delta of arguments for this chunk + finish_reason: The finish reason (only in last chunk) + accumulated_args: The accumulated arguments string up to and including this chunk + + Note: Unlike the real xAI SDK which only sends the tool name in the first chunk, + our mock includes it in every chunk to ensure proper tool call tracking. + """ + from types import SimpleNamespace + + # Infer tool name from accumulated state if not provided + effective_tool_name = tool_name or ('final_result' if accumulated_args else None) + + # Create the chunk (delta) - includes tool name for proper tracking + chunk_tool_call = None + if effective_tool_name is not None or tool_arguments is not None: + chunk_tool_call = SimpleNamespace( + id='tool-123', + function=SimpleNamespace( + name=effective_tool_name, + # arguments should be a string (delta JSON), default to empty string + arguments=tool_arguments if tool_arguments is not None else '', + ), + ) + + # Chunk (delta) + chunk = MockGrokResponseChunk( + content='', + tool_calls=[chunk_tool_call] if chunk_tool_call else [], + ) + + # Response (accumulated) - contains the full accumulated tool call + response_tool_call = SimpleNamespace( + id='tool-123', + function=SimpleNamespace( + name=effective_tool_name, + arguments=accumulated_args, # Full accumulated arguments + ), + ) + + response = MockGrokResponse( + id='grok-123', + content='', + tool_calls=[response_tool_call] if (effective_tool_name is not None or accumulated_args) else [], + finish_reason=finish_reason, + usage=SimpleNamespace(input_tokens=20, output_tokens=1) if finish_reason else None, + ) + + return (response, chunk) + + +class MyTypedDict(TypedDict, total=False): + first: str + second: str + + +async def test_stream_structured(allow_model_requests: None): + stream = [ + grok_tool_chunk('final_result', None, accumulated_args=''), + grok_tool_chunk(None, '{"first": "One', accumulated_args='{"first": "One'), + grok_tool_chunk(None, '", "second": "Two"', accumulated_args='{"first": "One", "second": "Two"'), + grok_tool_chunk(None, '}', finish_reason='stop', accumulated_args='{"first": "One", "second": "Two"}'), + ] + mock_client = MockGrok.create_mock_stream(stream) + m = GrokModel('grok-3', client=mock_client) + agent = Agent(m, output_type=MyTypedDict) + + async with agent.run_stream('') as result: + assert not result.is_complete + assert [dict(c) async for c in result.stream_output(debounce_by=None)] == snapshot( + [{'first': 'One'}, {'first': 'One', 'second': 'Two'}, {'first': 'One', 'second': 'Two'}] + ) + assert result.is_complete + assert result.usage() == snapshot(RunUsage(requests=1, input_tokens=20, output_tokens=1)) + + +async def test_stream_structured_finish_reason(allow_model_requests: None): + stream = [ + grok_tool_chunk('final_result', None, accumulated_args=''), + grok_tool_chunk(None, '{"first": "One', accumulated_args='{"first": "One'), + grok_tool_chunk(None, '", "second": "Two"', accumulated_args='{"first": "One", "second": "Two"'), + grok_tool_chunk(None, '}', accumulated_args='{"first": "One", "second": "Two"}'), + grok_tool_chunk(None, None, finish_reason='stop', accumulated_args='{"first": "One", "second": "Two"}'), + ] + mock_client = MockGrok.create_mock_stream(stream) + m = GrokModel('grok-3', client=mock_client) + agent = Agent(m, output_type=MyTypedDict) + + async with agent.run_stream('') as result: + assert not result.is_complete 
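+        # stream_output re-validates the partially accumulated JSON after each
+        # chunk, so intermediate dicts appear until the tool arguments are complete.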
+ assert [dict(c) async for c in result.stream_output(debounce_by=None)] == snapshot( + [{'first': 'One'}, {'first': 'One', 'second': 'Two'}, {'first': 'One', 'second': 'Two'}] + ) + assert result.is_complete + + +async def test_stream_native_output(allow_model_requests: None): + stream = [ + grok_text_chunk('{"first": "One'), + grok_text_chunk('", "second": "Two"'), + grok_text_chunk('}'), + ] + mock_client = MockGrok.create_mock_stream(stream) + m = GrokModel('grok-3', client=mock_client) + agent = Agent(m, output_type=NativeOutput(MyTypedDict)) + + async with agent.run_stream('') as result: + assert not result.is_complete + assert [dict(c) async for c in result.stream_output(debounce_by=None)] == snapshot( + [{'first': 'One'}, {'first': 'One', 'second': 'Two'}, {'first': 'One', 'second': 'Two'}] + ) + assert result.is_complete + + +async def test_stream_tool_call_with_empty_text(allow_model_requests: None): + stream = [ + grok_tool_chunk('final_result', None, accumulated_args=''), + grok_tool_chunk(None, '{"first": "One', accumulated_args='{"first": "One'), + grok_tool_chunk(None, '", "second": "Two"', accumulated_args='{"first": "One", "second": "Two"'), + grok_tool_chunk(None, '}', finish_reason='stop', accumulated_args='{"first": "One", "second": "Two"}'), + ] + mock_client = MockGrok.create_mock_stream(stream) + m = GrokModel('grok-3', client=mock_client) + agent = Agent(m, output_type=[str, MyTypedDict]) + + async with agent.run_stream('') as result: + assert not result.is_complete + assert [c async for c in result.stream_output(debounce_by=None)] == snapshot( + [{'first': 'One'}, {'first': 'One', 'second': 'Two'}, {'first': 'One', 'second': 'Two'}] + ) + assert await result.get_output() == snapshot({'first': 'One', 'second': 'Two'}) + + +async def test_no_delta(allow_model_requests: None): + stream = [ + grok_text_chunk('hello '), + grok_text_chunk('world'), + ] + mock_client = MockGrok.create_mock_stream(stream) + m = GrokModel('grok-3', client=mock_client) + agent = Agent(m) + + async with agent.run_stream('') as result: + assert not result.is_complete + assert [c async for c in result.stream_text(debounce_by=None)] == snapshot(['hello ', 'hello world']) + assert result.is_complete + assert result.usage() == snapshot(RunUsage(requests=1, input_tokens=2, output_tokens=1)) + + +async def test_none_delta(allow_model_requests: None): + # Test handling of chunks without deltas + stream = [ + grok_text_chunk('hello '), + grok_text_chunk('world'), + ] + mock_client = MockGrok.create_mock_stream(stream) + m = GrokModel('grok-3', client=mock_client) + agent = Agent(m) + + async with agent.run_stream('') as result: + assert not result.is_complete + assert [c async for c in result.stream_text(debounce_by=None)] == snapshot(['hello ', 'hello world']) + assert result.is_complete + assert result.usage() == snapshot(RunUsage(requests=1, input_tokens=2, output_tokens=1)) + + +# Skip OpenAI-specific tests that don't apply to Grok +# test_system_prompt_role - OpenAI specific +# test_system_prompt_role_o1_mini - OpenAI specific +# test_openai_pass_custom_system_prompt_role - OpenAI specific +# test_openai_o1_mini_system_role - OpenAI specific + + +# Skip tests that are not applicable to Grok model +# The following tests were removed as they are OpenAI-specific: +# - test_system_prompt_role (OpenAI-specific system prompt roles) +# - test_system_prompt_role_o1_mini (OpenAI o1 specific) +# - test_openai_pass_custom_system_prompt_role (OpenAI-specific) +# - test_openai_o1_mini_system_role 
(OpenAI-specific) +# - test_parallel_tool_calls (OpenAI-specific parameter) +# - test_image_url_input (OpenAI-specific image handling - would need VCR cassettes for Grok) +# - test_image_url_input_force_download (OpenAI-specific) +# - test_image_url_input_force_download_response_api (OpenAI-specific) +# - test_openai_audio_url_input (OpenAI-specific audio) +# - test_document_url_input (OpenAI-specific documents) +# - test_image_url_tool_response (OpenAI-specific) +# - test_image_as_binary_content_tool_response (OpenAI-specific) +# - test_image_as_binary_content_input (OpenAI-specific) +# - test_audio_as_binary_content_input (OpenAI-specific) +# - test_binary_content_input_unknown_media_type (OpenAI-specific) + + +# Continue with model request/response tests +# Grok-specific tests for built-in tools + + +async def test_grok_web_search_tool(allow_model_requests: None): + """Test Grok model with web_search built-in tool.""" + # First response: tool call to web_search + tool_call = create_tool_call( + id='web-1', + name='web_search', + arguments={'query': 'latest news about AI'}, + ) + response1 = create_response(tool_calls=[tool_call]) + + # Second response: final answer + response2 = create_response(content='Based on web search: AI is advancing rapidly.') + + mock_client = MockGrok.create_mock([response1, response2]) + m = GrokModel('grok-3', client=mock_client) + agent = Agent(m) + + # Add a mock web search tool + @agent.tool_plain + async def web_search(query: str) -> str: + return f'Search results for: {query}' + + result = await agent.run('What is the latest news about AI?') + assert 'AI is advancing rapidly' in result.output + assert result.usage().requests == 2 + + +async def test_grok_model_retries(allow_model_requests: None): + """Test Grok model with retries.""" + # Create error response then success + success_response = create_response(content='Success after retry') + + mock_client = MockGrok.create_mock(success_response) + m = GrokModel('grok-3', client=mock_client) + agent = Agent(m) + result = await agent.run('hello') + assert result.output == 'Success after retry' + + +async def test_grok_model_settings(allow_model_requests: None): + """Test Grok model with various settings.""" + response = create_response(content='response with settings') + mock_client = MockGrok.create_mock(response) + m = GrokModel('grok-3', client=mock_client) + agent = Agent( + m, + model_settings=ModelSettings( + temperature=0.5, + max_tokens=100, + top_p=0.9, + ), + ) + + result = await agent.run('hello') + assert result.output == 'response with settings' + + # Verify settings were passed to the mock + kwargs = get_mock_chat_create_kwargs(mock_client) + assert len(kwargs) > 0 + + +async def test_grok_model_multiple_tool_calls(allow_model_requests: None): + """Test Grok model with multiple tool calls in sequence.""" + # Three responses: two tool calls, then final answer + responses = [ + create_response( + tool_calls=[create_tool_call(id='1', name='get_data', arguments={'key': 'value1'})], + ), + create_response( + tool_calls=[create_tool_call(id='2', name='process_data', arguments={'data': 'result1'})], + ), + create_response(content='Final processed result'), + ] + + mock_client = MockGrok.create_mock(responses) + m = GrokModel('grok-3', client=mock_client) + agent = Agent(m) + + @agent.tool_plain + async def get_data(key: str) -> str: + return f'data for {key}' + + @agent.tool_plain + async def process_data(data: str) -> str: + return f'processed {data}' + + result = await agent.run('Get and process 
data') + assert result.output == 'Final processed result' + assert result.usage().requests == 3 + assert result.usage().tool_calls == 2 + + +async def test_grok_stream_with_tool_calls(allow_model_requests: None): + """Test Grok streaming with tool calls.""" + # First stream: tool call + stream1 = [ + grok_tool_chunk('get_info', None, accumulated_args=''), + grok_tool_chunk(None, '{"query": "test"}', finish_reason='tool_calls', accumulated_args='{"query": "test"}'), + ] + # Second stream: final response after tool execution + stream2 = [ + grok_text_chunk('Info retrieved: Info about test', finish_reason='stop'), + ] + + mock_client = MockGrok.create_mock_stream([stream1, stream2]) + m = GrokModel('grok-3', client=mock_client) + agent = Agent(m) + + @agent.tool_plain + async def get_info(query: str) -> str: + return f'Info about {query}' + + async with agent.run_stream('Get information') as result: + # Consume the stream + [c async for c in result.stream_text(debounce_by=None)] + + # Verify the final output includes the tool result + assert result.is_complete + output = await result.get_output() + assert 'Info about test' in output + + +# Test for error handling +async def test_grok_model_invalid_api_key(): + """Test Grok model with invalid API key.""" + with pytest.raises(ValueError, match='XAI API key is required'): + GrokModel('grok-3', api_key='') + + +async def test_grok_model_properties(): + """Test Grok model properties.""" + m = GrokModel('grok-3', api_key='test-key') + + assert m.model_name == 'grok-3' + assert m.system == 'xai' + assert m._api_key == 'test-key' + + +# End of tests From 43e68904e1756f6af1ce7f0a1d25eab4628a0781 Mon Sep 17 00:00:00 2001 From: Julian Bright Date: Mon, 17 Nov 2025 20:53:43 -0800 Subject: [PATCH 05/16] Fix pyright --- tests/models/mock_grok.py | 17 ++++++++------- tests/models/test_grok.py | 44 ++++++++++++++++----------------------- 2 files changed, 28 insertions(+), 33 deletions(-) diff --git a/tests/models/mock_grok.py b/tests/models/mock_grok.py index 4157a01315..4295bd070d 100644 --- a/tests/models/mock_grok.py +++ b/tests/models/mock_grok.py @@ -147,17 +147,20 @@ def create_response( tool_calls: list[Any] | None = None, finish_reason: str = 'stop', usage: Any | None = None, -) -> MockGrokResponse: +) -> chat_types.Response: """Create a mock Response object for testing. Returns a MockGrokResponse that mimics the xai_sdk.chat.Response interface. 
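+
+    The chat_types.Response annotation exists purely to satisfy type-checkers;
+    at runtime the mock is duck-typed against the attributes GrokModel reads
+    (id, content, tool_calls, finish_reason, usage).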
""" - return MockGrokResponse( - id='grok-123', - content=content, - tool_calls=tool_calls or [], - finish_reason=finish_reason, - usage=usage, + return cast( + chat_types.Response, + MockGrokResponse( + id='grok-123', + content=content, + tool_calls=tool_calls or [], + finish_reason=finish_reason, + usage=usage, + ), ) diff --git a/tests/models/test_grok.py b/tests/models/test_grok.py index 11b0854d9c..a486d87562 100644 --- a/tests/models/test_grok.py +++ b/tests/models/test_grok.py @@ -2,7 +2,8 @@ import json from datetime import timezone -from typing import Any +from types import SimpleNamespace +from typing import Any, cast import pytest from inline_snapshot import snapshot @@ -53,14 +54,14 @@ ] -def test_init(): +def test_grok_init(): m = GrokModel('grok-3', api_key='foobar') - assert m._api_key == 'foobar' + # Check model properties without accessing private attributes assert m.model_name == 'grok-3' assert m.system == 'xai' -async def test_request_simple_success(allow_model_requests: None): +async def test_grok_request_simple_success(allow_model_requests: None): response = create_response(content='world') mock_client = MockGrok.create_mock(response) m = GrokModel('grok-3', client=mock_client) @@ -106,9 +107,7 @@ async def test_request_simple_success(allow_model_requests: None): ) -async def test_request_simple_usage(allow_model_requests: None): - from types import SimpleNamespace - +async def test_grok_request_simple_usage(allow_model_requests: None): response = create_response( content='world', usage=SimpleNamespace(input_tokens=2, output_tokens=1), @@ -142,7 +141,7 @@ async def test_grok_image_input(allow_model_requests: None): assert result.output == 'done' -async def test_request_structured_response(allow_model_requests: None): +async def test_grok_request_structured_response(allow_model_requests: None): tool_call = create_tool_call( id='123', name='final_result', @@ -190,9 +189,7 @@ async def test_request_structured_response(allow_model_requests: None): ) -async def test_request_tool_call(allow_model_requests: None): - from types import SimpleNamespace - +async def test_grok_request_tool_call(allow_model_requests: None): responses = [ create_response( tool_calls=[create_tool_call(id='1', name='get_location', arguments={'loc_name': 'San Fransisco'})], @@ -310,8 +307,6 @@ def grok_text_chunk(text: str, finish_reason: str = 'stop') -> tuple[chat_types. Since we can't easily track state across calls, we pass full accumulated text as response.content and the delta as chunk.content. """ - from types import SimpleNamespace - # Create chunk (delta) - just this piece of text chunk = MockGrokResponseChunk(content=text, tool_calls=[]) @@ -325,10 +320,10 @@ def grok_text_chunk(text: str, finish_reason: str = 'stop') -> tuple[chat_types. 
usage=SimpleNamespace(input_tokens=2, output_tokens=1) if finish_reason else None, ) - return (response, chunk) + return (cast(chat_types.Response, response), chunk) -async def test_stream_text(allow_model_requests: None): +async def test_grok_stream_text(allow_model_requests: None): stream = [grok_text_chunk('hello '), grok_text_chunk('world')] mock_client = MockGrok.create_mock_stream(stream) m = GrokModel('grok-3', client=mock_client) @@ -341,7 +336,7 @@ async def test_stream_text(allow_model_requests: None): assert result.usage() == snapshot(RunUsage(requests=1, input_tokens=2, output_tokens=1)) -async def test_stream_text_finish_reason(allow_model_requests: None): +async def test_grok_stream_text_finish_reason(allow_model_requests: None): # Create streaming chunks with finish reasons stream = [ grok_text_chunk('hello ', ''), @@ -387,8 +382,6 @@ def grok_tool_chunk( Note: Unlike the real xAI SDK which only sends the tool name in the first chunk, our mock includes it in every chunk to ensure proper tool call tracking. """ - from types import SimpleNamespace - # Infer tool name from accumulated state if not provided effective_tool_name = tool_name or ('final_result' if accumulated_args else None) @@ -427,7 +420,7 @@ def grok_tool_chunk( usage=SimpleNamespace(input_tokens=20, output_tokens=1) if finish_reason else None, ) - return (response, chunk) + return (cast(chat_types.Response, response), chunk) class MyTypedDict(TypedDict, total=False): @@ -435,7 +428,7 @@ class MyTypedDict(TypedDict, total=False): second: str -async def test_stream_structured(allow_model_requests: None): +async def test_grok_stream_structured(allow_model_requests: None): stream = [ grok_tool_chunk('final_result', None, accumulated_args=''), grok_tool_chunk(None, '{"first": "One', accumulated_args='{"first": "One'), @@ -455,7 +448,7 @@ async def test_stream_structured(allow_model_requests: None): assert result.usage() == snapshot(RunUsage(requests=1, input_tokens=20, output_tokens=1)) -async def test_stream_structured_finish_reason(allow_model_requests: None): +async def test_grok_stream_structured_finish_reason(allow_model_requests: None): stream = [ grok_tool_chunk('final_result', None, accumulated_args=''), grok_tool_chunk(None, '{"first": "One', accumulated_args='{"first": "One'), @@ -475,7 +468,7 @@ async def test_stream_structured_finish_reason(allow_model_requests: None): assert result.is_complete -async def test_stream_native_output(allow_model_requests: None): +async def test_grok_stream_native_output(allow_model_requests: None): stream = [ grok_text_chunk('{"first": "One'), grok_text_chunk('", "second": "Two"'), @@ -493,7 +486,7 @@ async def test_stream_native_output(allow_model_requests: None): assert result.is_complete -async def test_stream_tool_call_with_empty_text(allow_model_requests: None): +async def test_grok_stream_tool_call_with_empty_text(allow_model_requests: None): stream = [ grok_tool_chunk('final_result', None, accumulated_args=''), grok_tool_chunk(None, '{"first": "One', accumulated_args='{"first": "One'), @@ -512,7 +505,7 @@ async def test_stream_tool_call_with_empty_text(allow_model_requests: None): assert await result.get_output() == snapshot({'first': 'One', 'second': 'Two'}) -async def test_no_delta(allow_model_requests: None): +async def test_grok_no_delta(allow_model_requests: None): stream = [ grok_text_chunk('hello '), grok_text_chunk('world'), @@ -528,7 +521,7 @@ async def test_no_delta(allow_model_requests: None): assert result.usage() == snapshot(RunUsage(requests=1, 
input_tokens=2, output_tokens=1)) -async def test_none_delta(allow_model_requests: None): +async def test_grok_none_delta(allow_model_requests: None): # Test handling of chunks without deltas stream = [ grok_text_chunk('hello '), @@ -710,7 +703,6 @@ async def test_grok_model_properties(): assert m.model_name == 'grok-3' assert m.system == 'xai' - assert m._api_key == 'test-key' # End of tests From 4f15966c9050b46ae6f3058c55ac31b9bb75fc19 Mon Sep 17 00:00:00 2001 From: Julian Bright Date: Tue, 18 Nov 2025 13:40:09 -0800 Subject: [PATCH 06/16] Update grok all available grok models --- .../pydantic_ai/models/__init__.py | 3 + .../pydantic_ai/providers/grok.py | 1 + tests/models/test_fallback.py | 1 + tests/models/test_grok.py | 60 +++++++++---------- 4 files changed, 35 insertions(+), 30 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/models/__init__.py b/pydantic_ai_slim/pydantic_ai/models/__init__.py index 98214910bd..59ff5c6169 100644 --- a/pydantic_ai_slim/pydantic_ai/models/__init__.py +++ b/pydantic_ai_slim/pydantic_ai/models/__init__.py @@ -171,6 +171,9 @@ 'grok:grok-3-mini-fast', 'grok:grok-4', 'grok:grok-4-0709', + 'grok:grok-4-fast-non-reasoning', + 'grok:grok-4-fast-reasoning', + 'grok:grok-code-fast-1', 'groq:deepseek-r1-distill-llama-70b', 'groq:deepseek-r1-distill-qwen-32b', 'groq:distil-whisper-large-v3-en', diff --git a/pydantic_ai_slim/pydantic_ai/providers/grok.py b/pydantic_ai_slim/pydantic_ai/providers/grok.py index 65ae8946d7..1970c9696d 100644 --- a/pydantic_ai_slim/pydantic_ai/providers/grok.py +++ b/pydantic_ai_slim/pydantic_ai/providers/grok.py @@ -27,6 +27,7 @@ 'grok-4-0709', 'grok-4-fast-reasoning', 'grok-4-fast-non-reasoning', + 'grok-code-fast-1', 'grok-3', 'grok-3-mini', 'grok-3-fast', diff --git a/tests/models/test_fallback.py b/tests/models/test_fallback.py index d03726330a..62e3454bd3 100644 --- a/tests/models/test_fallback.py +++ b/tests/models/test_fallback.py @@ -415,6 +415,7 @@ def test_all_failed_instrumented(capfire: CaptureLogfire) -> None: 'gen_ai.agent.name': 'agent', 'logfire.msg': 'agent run', 'logfire.span_type': 'span', + 'logfire.exception.fingerprint': '0000000000000000000000000000000000000000000000000000000000000000', 'pydantic_ai.all_messages': [{'role': 'user', 'parts': [{'type': 'text', 'content': 'hello'}]}], 'logfire.json_schema': { 'type': 'object', diff --git a/tests/models/test_grok.py b/tests/models/test_grok.py index a486d87562..0470921ce1 100644 --- a/tests/models/test_grok.py +++ b/tests/models/test_grok.py @@ -55,16 +55,16 @@ def test_grok_init(): - m = GrokModel('grok-3', api_key='foobar') + m = GrokModel('grok-4-fast-non-reasoning', api_key='foobar') # Check model properties without accessing private attributes - assert m.model_name == 'grok-3' + assert m.model_name == 'grok-4-fast-non-reasoning' assert m.system == 'xai' async def test_grok_request_simple_success(allow_model_requests: None): response = create_response(content='world') mock_client = MockGrok.create_mock(response) - m = GrokModel('grok-3', client=mock_client) + m = GrokModel('grok-4-fast-non-reasoning', client=mock_client) agent = Agent(m) result = await agent.run('hello') @@ -85,7 +85,7 @@ async def test_grok_request_simple_success(allow_model_requests: None): ), ModelResponse( parts=[TextPart(content='world')], - model_name='grok-3', + model_name='grok-4-fast-non-reasoning', timestamp=IsDatetime(), provider_name='xai', finish_reason='stop', @@ -97,7 +97,7 @@ async def test_grok_request_simple_success(allow_model_requests: None): ), ModelResponse( 
parts=[TextPart(content='world')], - model_name='grok-3', + model_name='grok-4-fast-non-reasoning', timestamp=IsDatetime(), provider_name='xai', finish_reason='stop', @@ -113,7 +113,7 @@ async def test_grok_request_simple_usage(allow_model_requests: None): usage=SimpleNamespace(input_tokens=2, output_tokens=1), ) mock_client = MockGrok.create_mock(response) - m = GrokModel('grok-3', client=mock_client) + m = GrokModel('grok-4-fast-non-reasoning', client=mock_client) agent = Agent(m) result = await agent.run('Hello') @@ -131,7 +131,7 @@ async def test_grok_image_input(allow_model_requests: None): """Test that Grok model handles image inputs (text is extracted from content).""" response = create_response(content='done') mock_client = MockGrok.create_mock(response) - model = GrokModel('grok-3', client=mock_client) + model = GrokModel('grok-4-fast-non-reasoning', client=mock_client) agent = Agent(model) image_url = ImageUrl('https://example.com/image.png') @@ -149,7 +149,7 @@ async def test_grok_request_structured_response(allow_model_requests: None): ) response = create_response(tool_calls=[tool_call]) mock_client = MockGrok.create_mock(response) - m = GrokModel('grok-3', client=mock_client) + m = GrokModel('grok-4-fast-non-reasoning', client=mock_client) agent = Agent(m, output_type=list[int]) result = await agent.run('Hello') @@ -168,7 +168,7 @@ async def test_grok_request_structured_response(allow_model_requests: None): tool_call_id='123', ) ], - model_name='grok-3', + model_name='grok-4-fast-non-reasoning', timestamp=IsDatetime(), provider_name='xai', finish_reason='stop', @@ -202,7 +202,7 @@ async def test_grok_request_tool_call(allow_model_requests: None): create_response(content='final response'), ] mock_client = MockGrok.create_mock(responses) - m = GrokModel('grok-3', client=mock_client) + m = GrokModel('grok-4-fast-non-reasoning', client=mock_client) agent = Agent(m, system_prompt='this is the system prompt') @agent.tool_plain @@ -235,7 +235,7 @@ async def get_location(loc_name: str) -> str: input_tokens=2, output_tokens=1, ), - model_name='grok-3', + model_name='grok-4-fast-non-reasoning', timestamp=IsDatetime(), provider_name='xai', finish_reason='stop', @@ -264,7 +264,7 @@ async def get_location(loc_name: str) -> str: input_tokens=3, output_tokens=2, ), - model_name='grok-3', + model_name='grok-4-fast-non-reasoning', timestamp=IsDatetime(), provider_name='xai', finish_reason='stop', @@ -283,7 +283,7 @@ async def get_location(loc_name: str) -> str: ), ModelResponse( parts=[TextPart(content='final response')], - model_name='grok-3', + model_name='grok-4-fast-non-reasoning', timestamp=IsDatetime(), provider_name='xai', finish_reason='stop', @@ -326,7 +326,7 @@ def grok_text_chunk(text: str, finish_reason: str = 'stop') -> tuple[chat_types. 
async def test_grok_stream_text(allow_model_requests: None): stream = [grok_text_chunk('hello '), grok_text_chunk('world')] mock_client = MockGrok.create_mock_stream(stream) - m = GrokModel('grok-3', client=mock_client) + m = GrokModel('grok-4-fast-non-reasoning', client=mock_client) agent = Agent(m) async with agent.run_stream('') as result: @@ -344,7 +344,7 @@ async def test_grok_stream_text_finish_reason(allow_model_requests: None): grok_text_chunk('.', 'stop'), ] mock_client = MockGrok.create_mock_stream(stream) - m = GrokModel('grok-3', client=mock_client) + m = GrokModel('grok-4-fast-non-reasoning', client=mock_client) agent = Agent(m) async with agent.run_stream('') as result: @@ -359,7 +359,7 @@ async def test_grok_stream_text_finish_reason(allow_model_requests: None): ModelResponse( parts=[TextPart(content='hello world.')], usage=RequestUsage(input_tokens=2, output_tokens=1), - model_name='grok-3', + model_name='grok-4-fast-non-reasoning', timestamp=IsDatetime(), provider_name='xai', provider_response_id='grok-123', @@ -436,7 +436,7 @@ async def test_grok_stream_structured(allow_model_requests: None): grok_tool_chunk(None, '}', finish_reason='stop', accumulated_args='{"first": "One", "second": "Two"}'), ] mock_client = MockGrok.create_mock_stream(stream) - m = GrokModel('grok-3', client=mock_client) + m = GrokModel('grok-4-fast-non-reasoning', client=mock_client) agent = Agent(m, output_type=MyTypedDict) async with agent.run_stream('') as result: @@ -457,7 +457,7 @@ async def test_grok_stream_structured_finish_reason(allow_model_requests: None): grok_tool_chunk(None, None, finish_reason='stop', accumulated_args='{"first": "One", "second": "Two"}'), ] mock_client = MockGrok.create_mock_stream(stream) - m = GrokModel('grok-3', client=mock_client) + m = GrokModel('grok-4-fast-non-reasoning', client=mock_client) agent = Agent(m, output_type=MyTypedDict) async with agent.run_stream('') as result: @@ -475,7 +475,7 @@ async def test_grok_stream_native_output(allow_model_requests: None): grok_text_chunk('}'), ] mock_client = MockGrok.create_mock_stream(stream) - m = GrokModel('grok-3', client=mock_client) + m = GrokModel('grok-4-fast-non-reasoning', client=mock_client) agent = Agent(m, output_type=NativeOutput(MyTypedDict)) async with agent.run_stream('') as result: @@ -494,7 +494,7 @@ async def test_grok_stream_tool_call_with_empty_text(allow_model_requests: None) grok_tool_chunk(None, '}', finish_reason='stop', accumulated_args='{"first": "One", "second": "Two"}'), ] mock_client = MockGrok.create_mock_stream(stream) - m = GrokModel('grok-3', client=mock_client) + m = GrokModel('grok-4-fast-non-reasoning', client=mock_client) agent = Agent(m, output_type=[str, MyTypedDict]) async with agent.run_stream('') as result: @@ -511,7 +511,7 @@ async def test_grok_no_delta(allow_model_requests: None): grok_text_chunk('world'), ] mock_client = MockGrok.create_mock_stream(stream) - m = GrokModel('grok-3', client=mock_client) + m = GrokModel('grok-4-fast-non-reasoning', client=mock_client) agent = Agent(m) async with agent.run_stream('') as result: @@ -528,7 +528,7 @@ async def test_grok_none_delta(allow_model_requests: None): grok_text_chunk('world'), ] mock_client = MockGrok.create_mock_stream(stream) - m = GrokModel('grok-3', client=mock_client) + m = GrokModel('grok-4-fast-non-reasoning', client=mock_client) agent = Agent(m) async with agent.run_stream('') as result: @@ -582,7 +582,7 @@ async def test_grok_web_search_tool(allow_model_requests: None): response2 = create_response(content='Based on 
web search: AI is advancing rapidly.') mock_client = MockGrok.create_mock([response1, response2]) - m = GrokModel('grok-3', client=mock_client) + m = GrokModel('grok-4-fast-non-reasoning', client=mock_client) agent = Agent(m) # Add a mock web search tool @@ -601,7 +601,7 @@ async def test_grok_model_retries(allow_model_requests: None): success_response = create_response(content='Success after retry') mock_client = MockGrok.create_mock(success_response) - m = GrokModel('grok-3', client=mock_client) + m = GrokModel('grok-4-fast-non-reasoning', client=mock_client) agent = Agent(m) result = await agent.run('hello') assert result.output == 'Success after retry' @@ -611,7 +611,7 @@ async def test_grok_model_settings(allow_model_requests: None): """Test Grok model with various settings.""" response = create_response(content='response with settings') mock_client = MockGrok.create_mock(response) - m = GrokModel('grok-3', client=mock_client) + m = GrokModel('grok-4-fast-non-reasoning', client=mock_client) agent = Agent( m, model_settings=ModelSettings( @@ -643,7 +643,7 @@ async def test_grok_model_multiple_tool_calls(allow_model_requests: None): ] mock_client = MockGrok.create_mock(responses) - m = GrokModel('grok-3', client=mock_client) + m = GrokModel('grok-4-fast-non-reasoning', client=mock_client) agent = Agent(m) @agent.tool_plain @@ -673,7 +673,7 @@ async def test_grok_stream_with_tool_calls(allow_model_requests: None): ] mock_client = MockGrok.create_mock_stream([stream1, stream2]) - m = GrokModel('grok-3', client=mock_client) + m = GrokModel('grok-4-fast-non-reasoning', client=mock_client) agent = Agent(m) @agent.tool_plain @@ -694,14 +694,14 @@ async def get_info(query: str) -> str: async def test_grok_model_invalid_api_key(): """Test Grok model with invalid API key.""" with pytest.raises(ValueError, match='XAI API key is required'): - GrokModel('grok-3', api_key='') + GrokModel('grok-4-fast-non-reasoning', api_key='') async def test_grok_model_properties(): """Test Grok model properties.""" - m = GrokModel('grok-3', api_key='test-key') + m = GrokModel('grok-4-fast-non-reasoning', api_key='test-key') - assert m.model_name == 'grok-3' + assert m.model_name == 'grok-4-fast-non-reasoning' assert m.system == 'xai' From 185a6be2522f2a932a4f66c8c9082c06c2c7feae Mon Sep 17 00:00:00 2001 From: Julian Bright Date: Tue, 18 Nov 2025 14:24:29 -0800 Subject: [PATCH 07/16] Adding image tests that require an active API_KEY --- pydantic_ai_slim/pydantic_ai/models/grok.py | 42 ++++- pydantic_ai_slim/pydantic_ai/settings.py | 5 + tests/conftest.py | 5 + tests/models/test_grok.py | 164 ++++++++++++++++++-- 4 files changed, 196 insertions(+), 20 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/models/grok.py b/pydantic_ai_slim/pydantic_ai/models/grok.py index a8c5a919e3..cedde52ae3 100644 --- a/pydantic_ai_slim/pydantic_ai/models/grok.py +++ b/pydantic_ai_slim/pydantic_ai/models/grok.py @@ -10,12 +10,14 @@ # Import xai_sdk components from xai_sdk import AsyncClient -from xai_sdk.chat import assistant, system, tool, tool_result, user +from xai_sdk.chat import assistant, image, system, tool, tool_result, user from .._run_context import RunContext from .._utils import now_utc from ..messages import ( + BinaryContent, FinishReason, + ImageUrl, ModelMessage, ModelRequest, ModelRequestPart, @@ -113,10 +115,28 @@ def _map_user_prompt(self, part: UserPromptPart) -> chat_types.chat_pb2.Message if isinstance(part.content, str): return user(part.content) - # Handle complex content (images, etc.) 
- text_parts: list[str] = [item for item in part.content if isinstance(item, str)] - if text_parts: - return user(' '.join(text_parts)) + # Handle complex content (images, text, etc.) + content_items: list[chat_types.Content] = [] + + for item in part.content: + if isinstance(item, str): + content_items.append(item) + elif isinstance(item, ImageUrl): + # Get detail from vendor_metadata if available + detail: chat_types.ImageDetail = 'auto' + if item.vendor_metadata and 'detail' in item.vendor_metadata: + detail = item.vendor_metadata['detail'] + content_items.append(image(item.url, detail=detail)) + elif isinstance(item, BinaryContent): + if item.is_image: + # Convert binary content to data URI and use image() + content_items.append(image(item.data_uri, detail='auto')) + else: + # xAI SDK doesn't support non-image binary content yet + pass + + if content_items: + return user(*content_items) return None @@ -171,6 +191,12 @@ async def request( xai_settings['stop'] = model_settings['stop_sequences'] if 'seed' in model_settings: xai_settings['seed'] = model_settings['seed'] + if 'parallel_tool_calls' in model_settings: + xai_settings['parallel_tool_calls'] = model_settings['parallel_tool_calls'] + if 'presence_penalty' in model_settings: + xai_settings['presence_penalty'] = model_settings['presence_penalty'] + if 'frequency_penalty' in model_settings: + xai_settings['frequency_penalty'] = model_settings['frequency_penalty'] # Create chat instance chat = client.chat.create(model=self._model_name, messages=xai_messages, tools=tools, **xai_settings) @@ -213,6 +239,12 @@ async def request_stream( xai_settings['stop'] = model_settings['stop_sequences'] if 'seed' in model_settings: xai_settings['seed'] = model_settings['seed'] + if 'parallel_tool_calls' in model_settings: + xai_settings['parallel_tool_calls'] = model_settings['parallel_tool_calls'] + if 'presence_penalty' in model_settings: + xai_settings['presence_penalty'] = model_settings['presence_penalty'] + if 'frequency_penalty' in model_settings: + xai_settings['frequency_penalty'] = model_settings['frequency_penalty'] # Create chat instance chat = client.chat.create(model=self._model_name, messages=xai_messages, tools=tools, **xai_settings) diff --git a/pydantic_ai_slim/pydantic_ai/settings.py b/pydantic_ai_slim/pydantic_ai/settings.py index 6941eb1ab3..ea4fb2ff09 100644 --- a/pydantic_ai_slim/pydantic_ai/settings.py +++ b/pydantic_ai_slim/pydantic_ai/settings.py @@ -86,6 +86,7 @@ class ModelSettings(TypedDict, total=False): * OpenAI (some models, not o1) * Groq * Anthropic + * Grok """ seed: int @@ -112,6 +113,7 @@ class ModelSettings(TypedDict, total=False): * Gemini * Mistral * Outlines (LlamaCpp, SgLang, VLLMOffline) + * Grok """ frequency_penalty: float @@ -125,6 +127,7 @@ class ModelSettings(TypedDict, total=False): * Gemini * Mistral * Outlines (LlamaCpp, SgLang, VLLMOffline) + * Grok """ logit_bias: dict[str, int] @@ -149,6 +152,7 @@ class ModelSettings(TypedDict, total=False): * Groq * Cohere * Google + * Grok """ extra_headers: dict[str, str] @@ -159,6 +163,7 @@ class ModelSettings(TypedDict, total=False): * OpenAI * Anthropic * Groq + * Grok """ extra_body: object diff --git a/tests/conftest.py b/tests/conftest.py index 6b90ecfb28..b7b8d28895 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -412,6 +412,11 @@ def cerebras_api_key() -> str: return os.getenv('CEREBRAS_API_KEY', 'mock-api-key') +@pytest.fixture(scope='session') +def xai_api_key() -> str: + return os.getenv('XAI_API_KEY', 'mock-api-key') + + 
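+# Grok tests that exercise the real xAI gRPC endpoint skip themselves when
+# XAI_API_KEY is unset, since gRPC traffic is not captured by VCR cassettes.
+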
@pytest.fixture(scope='session') def bedrock_provider(): try: diff --git a/tests/models/test_grok.py b/tests/models/test_grok.py index 0470921ce1..ad27f077c8 100644 --- a/tests/models/test_grok.py +++ b/tests/models/test_grok.py @@ -1,6 +1,7 @@ from __future__ import annotations as _annotations import json +import os from datetime import timezone from types import SimpleNamespace from typing import Any, cast @@ -545,23 +546,155 @@ async def test_grok_none_delta(allow_model_requests: None): # test_openai_o1_mini_system_role - OpenAI specific +@pytest.mark.parametrize('parallel_tool_calls', [True, False]) +async def test_grok_parallel_tool_calls(allow_model_requests: None, parallel_tool_calls: bool) -> None: + tool_call = create_tool_call( + id='123', + name='final_result', + arguments={'response': [1, 2, 3]}, + ) + response = create_response(content='', tool_calls=[tool_call], finish_reason='tool_calls') + mock_client = MockGrok.create_mock(response) + m = GrokModel('grok-4-fast-non-reasoning', client=mock_client) + agent = Agent(m, output_type=list[int], model_settings=ModelSettings(parallel_tool_calls=parallel_tool_calls)) + + await agent.run('Hello') + assert get_mock_chat_create_kwargs(mock_client)[0]['parallel_tool_calls'] == parallel_tool_calls + + +async def test_grok_penalty_parameters(allow_model_requests: None) -> None: + response = create_response(content='test response') + mock_client = MockGrok.create_mock(response) + m = GrokModel('grok-4-fast-non-reasoning', client=mock_client) + + settings = ModelSettings( + temperature=0.7, + presence_penalty=0.5, + frequency_penalty=0.3, + parallel_tool_calls=False, + ) + + agent = Agent(m, model_settings=settings) + result = await agent.run('Hello') + + # Check that all settings were passed to the xAI SDK + kwargs = get_mock_chat_create_kwargs(mock_client)[0] + assert kwargs['temperature'] == 0.7 + assert kwargs['presence_penalty'] == 0.5 + assert kwargs['frequency_penalty'] == 0.3 + assert kwargs['parallel_tool_calls'] is False + assert result.output == 'test response' + + +async def test_grok_image_url_input(allow_model_requests: None): + response = create_response(content='world') + mock_client = MockGrok.create_mock(response) + m = GrokModel('grok-4-fast-non-reasoning', client=mock_client) + agent = Agent(m) + + result = await agent.run( + [ + 'hello', + ImageUrl(url='https://t3.ftcdn.net/jpg/00/85/79/92/360_F_85799278_0BBGV9OAdQDTLnKwAPBCcg1J7QtiieJY.jpg'), + ] + ) + assert result.output == 'world' + # Verify that the image URL was included in the messages + assert len(get_mock_chat_create_kwargs(mock_client)) == 1 + + +@pytest.mark.skipif(os.getenv('XAI_API_KEY') is None, reason='Requires XAI_API_KEY (gRPC, no cassettes)') +async def test_grok_image_url_tool_response(allow_model_requests: None, xai_api_key: str): + m = GrokModel('grok-4-fast-non-reasoning', api_key=xai_api_key) + agent = Agent(m) + + @agent.tool_plain + async def get_image() -> ImageUrl: + return ImageUrl(url='https://t3.ftcdn.net/jpg/00/85/79/92/360_F_85799278_0BBGV9OAdQDTLnKwAPBCcg1J7QtiieJY.jpg') + + result = await agent.run(['What food is in the image you can get from the get_image tool?']) + + # Verify structure with matchers for dynamic values + messages = result.all_messages() + assert len(messages) == 4 + + # Verify message types and key content + assert isinstance(messages[0], ModelRequest) + assert isinstance(messages[1], ModelResponse) + assert isinstance(messages[2], ModelRequest) + assert isinstance(messages[3], ModelResponse) + + # Verify tool was 
called + assert isinstance(messages[1].parts[0], ToolCallPart) + assert messages[1].parts[0].tool_name == 'get_image' + + # Verify image was passed back to model + assert isinstance(messages[2].parts[1], UserPromptPart) + assert isinstance(messages[2].parts[1].content, list) + assert any(isinstance(item, ImageUrl) for item in messages[2].parts[1].content) + + # Verify model responded about the image + assert isinstance(messages[3].parts[0], TextPart) + assert 'potato' in messages[3].parts[0].content.lower() + + +@pytest.mark.skipif(os.getenv('XAI_API_KEY') is None, reason='Requires XAI_API_KEY (gRPC, no cassettes)') +async def test_grok_image_as_binary_content_tool_response( + allow_model_requests: None, image_content: BinaryContent, xai_api_key: str +): + m = GrokModel('grok-4-fast-non-reasoning', api_key=xai_api_key) + agent = Agent(m) + + @agent.tool_plain + async def get_image() -> BinaryContent: + return image_content + + result = await agent.run(['What fruit is in the image you can get from the get_image tool?']) + + # Verify structure with matchers for dynamic values + messages = result.all_messages() + assert len(messages) == 4 + + # Verify message types and key content + assert isinstance(messages[0], ModelRequest) + assert isinstance(messages[1], ModelResponse) + assert isinstance(messages[2], ModelRequest) + assert isinstance(messages[3], ModelResponse) + + # Verify tool was called + assert isinstance(messages[1].parts[0], ToolCallPart) + assert messages[1].parts[0].tool_name == 'get_image' + + # Verify binary image content was passed back to model + assert isinstance(messages[2].parts[1], UserPromptPart) + assert isinstance(messages[2].parts[1].content, list) + has_binary_image = any(isinstance(item, BinaryContent) and item.is_image for item in messages[2].parts[1].content) + assert has_binary_image, 'Expected BinaryContent image in tool response' + + # Verify model responded about the image + assert isinstance(messages[3].parts[0], TextPart) + response_text = messages[3].parts[0].content.lower() + assert 'kiwi' in response_text or 'fruit' in response_text + + +@pytest.mark.skipif(os.getenv('XAI_API_KEY') is None, reason='Requires XAI_API_KEY (gRPC, no cassettes)') +async def test_grok_image_as_binary_content_input( + allow_model_requests: None, image_content: BinaryContent, xai_api_key: str +): + """Test passing binary image content directly as input (not from a tool).""" + m = GrokModel('grok-4-fast-non-reasoning', api_key=xai_api_key) + agent = Agent(m) + + result = await agent.run(['What fruit is in the image?', image_content]) + + # Verify the model received and processed the image + assert result.output + response_text = result.output.lower() + assert 'kiwi' in response_text or 'fruit' in response_text + + # Skip tests that are not applicable to Grok model # The following tests were removed as they are OpenAI-specific: -# - test_system_prompt_role (OpenAI-specific system prompt roles) -# - test_system_prompt_role_o1_mini (OpenAI o1 specific) -# - test_openai_pass_custom_system_prompt_role (OpenAI-specific) -# - test_openai_o1_mini_system_role (OpenAI-specific) -# - test_parallel_tool_calls (OpenAI-specific parameter) -# - test_image_url_input (OpenAI-specific image handling - would need VCR cassettes for Grok) -# - test_image_url_input_force_download (OpenAI-specific) -# - test_image_url_input_force_download_response_api (OpenAI-specific) -# - test_openai_audio_url_input (OpenAI-specific audio) -# - test_document_url_input (OpenAI-specific documents) -# - 
test_image_url_tool_response (OpenAI-specific) -# - test_image_as_binary_content_tool_response (OpenAI-specific) -# - test_image_as_binary_content_input (OpenAI-specific) -# - test_audio_as_binary_content_input (OpenAI-specific) -# - test_binary_content_input_unknown_media_type (OpenAI-specific) # Continue with model request/response tests @@ -691,6 +824,7 @@ async def get_info(query: str) -> str: # Test for error handling +@pytest.mark.skipif(os.getenv('XAI_API_KEY') is not None, reason='Skipped when XAI_API_KEY is set') async def test_grok_model_invalid_api_key(): """Test Grok model with invalid API key.""" with pytest.raises(ValueError, match='XAI API key is required'): From a632cbf3d3e3fb07655aeb3b6b8183ad1c9ce2dc Mon Sep 17 00:00:00 2001 From: Julian Bright Date: Tue, 18 Nov 2025 16:25:07 -0800 Subject: [PATCH 08/16] Adding MCP added stock analysis agent --- .../flight_booking_grok.py | 253 ------------------ .../stock_analysis_agent.py | 161 +++++++++++ pydantic_ai_slim/pydantic_ai/builtin_tools.py | 6 + pydantic_ai_slim/pydantic_ai/models/grok.py | 160 +++++++++-- tests/models/test_grok.py | 203 ++++++++++++-- 5 files changed, 483 insertions(+), 300 deletions(-) delete mode 100644 examples/pydantic_ai_examples/flight_booking_grok.py create mode 100644 examples/pydantic_ai_examples/stock_analysis_agent.py diff --git a/examples/pydantic_ai_examples/flight_booking_grok.py b/examples/pydantic_ai_examples/flight_booking_grok.py deleted file mode 100644 index ec8673fb03..0000000000 --- a/examples/pydantic_ai_examples/flight_booking_grok.py +++ /dev/null @@ -1,253 +0,0 @@ -"""Example of a multi-agent flow where one agent delegates work to another. - -In this scenario, a group of agents work together to find flights for a user. -""" - -import datetime -import os -from dataclasses import dataclass -from typing import Literal - -import logfire -from pydantic import BaseModel, Field -from rich.prompt import Prompt - -from pydantic_ai import Agent, ModelRetry, RunContext, RunUsage, UsageLimits -from pydantic_ai.messages import ModelMessage - -# Import local GrokModel -from pydantic_ai.models.grok import GrokModel - -logfire.configure() -logfire.instrument_pydantic_ai() -logfire.instrument_httpx() - -# Configure for xAI API -xai_api_key = os.getenv('XAI_API_KEY') -if not xai_api_key: - raise ValueError('XAI_API_KEY environment variable is required') - - -# Create the model using the new GrokModelpwd -model = GrokModel('grok-4-fast-non-reasoning', api_key=xai_api_key) - - -class FlightDetails(BaseModel): - """Details of the most suitable flight.""" - - flight_number: str - price: int - origin: str = Field(description='Three-letter airport code') - destination: str = Field(description='Three-letter airport code') - date: datetime.date - - -class NoFlightFound(BaseModel): - """When no valid flight is found.""" - - -@dataclass -class Deps: - web_page_text: str - req_origin: str - req_destination: str - req_date: datetime.date - - -# This agent is responsible for controlling the flow of the conversation. -search_agent = Agent[Deps, FlightDetails | NoFlightFound]( - model=model, - output_type=FlightDetails | NoFlightFound, # type: ignore - retries=4, - system_prompt=( - 'Your job is to find the cheapest flight for the user on the given date. ' - ), -) - - -# This agent is responsible for extracting flight details from web page text. 
-extraction_agent = Agent( - model=model, - output_type=list[FlightDetails], - system_prompt='Extract all the flight details from the given text.', -) - - -@search_agent.tool -async def extract_flights(ctx: RunContext[Deps]) -> list[FlightDetails]: - """Get details of all flights.""" - # we pass the usage to the search agent so requests within this agent are counted - result = await extraction_agent.run(ctx.deps.web_page_text, usage=ctx.usage) - logfire.info('found {flight_count} flights', flight_count=len(result.output)) - return result.output - - -@search_agent.output_validator -async def validate_output( - ctx: RunContext[Deps], output: FlightDetails | NoFlightFound -) -> FlightDetails | NoFlightFound: - """Procedural validation that the flight meets the constraints.""" - if isinstance(output, NoFlightFound): - return output - - errors: list[str] = [] - if output.origin != ctx.deps.req_origin: - errors.append( - f'Flight should have origin {ctx.deps.req_origin}, not {output.origin}' - ) - if output.destination != ctx.deps.req_destination: - errors.append( - f'Flight should have destination {ctx.deps.req_destination}, not {output.destination}' - ) - if output.date != ctx.deps.req_date: - errors.append(f'Flight should be on {ctx.deps.req_date}, not {output.date}') - - if errors: - raise ModelRetry('\n'.join(errors)) - else: - return output - - -class SeatPreference(BaseModel): - row: int = Field(ge=1, le=30) - seat: Literal['A', 'B', 'C', 'D', 'E', 'F'] - - -class Failed(BaseModel): - """Unable to extract a seat selection.""" - - -# This agent is responsible for extracting the user's seat selection -seat_preference_agent = Agent[None, SeatPreference | Failed]( - model=model, - output_type=SeatPreference | Failed, - system_prompt=( - "Extract the user's seat preference. " - 'Seats A and F are window seats. ' - 'Row 1 is the front row and has extra leg room. ' - 'Rows 14, and 20 also have extra leg room. ' - ), -) - - -# in reality this would be downloaded from a booking site, -# potentially using another agent to navigate the site -flights_web_page = """ -1. Flight SFO-AK123 -- Price: $350 -- Origin: San Francisco International Airport (SFO) -- Destination: Ted Stevens Anchorage International Airport (ANC) -- Date: January 10, 2025 - -2. Flight SFO-AK456 -- Price: $370 -- Origin: San Francisco International Airport (SFO) -- Destination: Fairbanks International Airport (FAI) -- Date: January 10, 2025 - -3. Flight SFO-AK789 -- Price: $400 -- Origin: San Francisco International Airport (SFO) -- Destination: Juneau International Airport (JNU) -- Date: January 20, 2025 - -4. Flight NYC-LA101 -- Price: $250 -- Origin: San Francisco International Airport (SFO) -- Destination: Ted Stevens Anchorage International Airport (ANC) -- Date: January 10, 2025 - -5. Flight CHI-MIA202 -- Price: $200 -- Origin: Chicago O'Hare International Airport (ORD) -- Destination: Miami International Airport (MIA) -- Date: January 12, 2025 - -6. Flight BOS-SEA303 -- Price: $120 -- Origin: Boston Logan International Airport (BOS) -- Destination: Ted Stevens Anchorage International Airport (ANC) -- Date: January 12, 2025 - -7. Flight DFW-DEN404 -- Price: $150 -- Origin: Dallas/Fort Worth International Airport (DFW) -- Destination: Denver International Airport (DEN) -- Date: January 10, 2025 - -8. 
Flight ATL-HOU505 -- Price: $180 -- Origin: Hartsfield-Jackson Atlanta International Airport (ATL) -- Destination: George Bush Intercontinental Airport (IAH) -- Date: January 10, 2025 -""" - -# restrict how many requests this app can make to the LLM -usage_limits = UsageLimits(request_limit=15) - - -async def main(): - deps = Deps( - web_page_text=flights_web_page, - req_origin='SFO', - req_destination='ANC', - req_date=datetime.date(2025, 1, 10), - ) - message_history: list[ModelMessage] | None = None - usage: RunUsage = RunUsage() - # run the agent until a satisfactory flight is found - while True: - result = await search_agent.run( - f'Find me a flight from {deps.req_origin} to {deps.req_destination} on {deps.req_date}', - deps=deps, - usage=usage, - message_history=message_history, - usage_limits=usage_limits, - ) - if isinstance(result.output, NoFlightFound): - print('No flight found') - break - else: - flight = result.output - print(f'Flight found: {flight}') - answer = Prompt.ask( - 'Do you want to buy this flight, or keep searching? (buy/*search)', - choices=['buy', 'search', ''], - show_choices=False, - ) - if answer == 'buy': - seat = await find_seat(usage) - await buy_tickets(flight, seat) - break - else: - message_history = result.all_messages( - output_tool_return_content='Please suggest another flight' - ) - - -async def find_seat(usage: RunUsage) -> SeatPreference: - message_history: list[ModelMessage] | None = None - while True: - answer = Prompt.ask('What seat would you like?') - - result = await seat_preference_agent.run( - answer, - message_history=message_history, - usage=usage, - usage_limits=usage_limits, - ) - if isinstance(result.output, SeatPreference): - return result.output - else: - print('Could not understand seat preference. Please try again.') - message_history = result.all_messages() - - -async def buy_tickets(flight_details: FlightDetails, seat: SeatPreference): - print(f'Purchasing flight {flight_details=!r} {seat=!r}...') - - -if __name__ == '__main__': - import asyncio - - asyncio.run(main()) diff --git a/examples/pydantic_ai_examples/stock_analysis_agent.py b/examples/pydantic_ai_examples/stock_analysis_agent.py new file mode 100644 index 0000000000..13ea704492 --- /dev/null +++ b/examples/pydantic_ai_examples/stock_analysis_agent.py @@ -0,0 +1,161 @@ +"""Example of using Grok's server-side tools (web_search, code_execution) with a local function. + +This agent: +1. Uses web_search to find the best performing NASDAQ stock over the last week +2. Uses code_execution to project the price using linear regression +3. 
Calls a local function project_price with the results +""" + +import os +from datetime import datetime + +import logfire +from pydantic import BaseModel, Field + +from pydantic_ai import ( + Agent, + BuiltinToolCallPart, + CodeExecutionTool, + ModelResponse, + RunContext, + WebSearchTool, +) +from pydantic_ai.models.grok import GrokModel + +logfire.configure() +logfire.instrument_pydantic_ai() + +# Configure for xAI API +xai_api_key = os.getenv('XAI_API_KEY') +if not xai_api_key: + raise ValueError('XAI_API_KEY environment variable is required') + + +# Create the model using GrokModel with server-side tools +model = GrokModel('grok-4-fast', api_key=xai_api_key) + + +class StockProjection(BaseModel): + """Projection of stock price at year end.""" + + stock_symbol: str = Field(description='Stock ticker symbol') + current_price: float = Field(description='Current stock price') + projected_price: float = Field(description='Projected price at end of year') + analysis: str = Field(description='Brief analysis of the projection') + + +# This agent uses server-side tools to research and analyze stocks +stock_analysis_agent = Agent[None, StockProjection]( + model=model, + output_type=StockProjection, + builtin_tools=[ + WebSearchTool(), # Server-side web search + CodeExecutionTool(), # Server-side code execution + ], + system_prompt=( + 'You are a stock analysis assistant. ' + 'Use web_search to find recent stock performance data on NASDAQ. ' + 'Use code_execution to perform linear regression for price projection. ' + 'After analysis, call project_price with your findings.' + ), +) + + +@stock_analysis_agent.tool +def project_price(ctx: RunContext[None], stock: str, price: float) -> str: + """Record the projected stock price. + + This is a local/client-side function that gets called with the analysis results. + + Args: + ctx: The run context (not used in this function) + stock: Stock ticker symbol + price: Projected price at end of year + """ + timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S') + logfire.info( + 'Stock projection recorded', + stock=stock, + projected_price=price, + timestamp=timestamp, + ) + print('\nšŸ“Š PROJECTION RECORDED:') + print(f' Stock: {stock}') + print(f' Projected End-of-Year Price: ${price:.2f}') + print(f' Timestamp: {timestamp}\n') + + return f'Projection for {stock} at ${price:.2f} has been recorded successfully.' + + +async def main(): + """Run the stock analysis agent.""" + query = ( + 'Can you find me the best performing stock on the NASDAQ over the last week, ' + 'and return the price projection for the end of the year using a simple linear regression. 
' + ) + + print('šŸ” Starting stock analysis...\n') + print(f'Query: {query}\n') + + result = await stock_analysis_agent.run(query) + + # Track which builtin tools were used + web_search_count = 0 + code_execution_count = 0 + + for message in result.all_messages(): + if isinstance(message, ModelResponse): + for part in message.parts: + if isinstance(part, BuiltinToolCallPart): + if 'web_search' in part.tool_name or 'browse' in part.tool_name: + web_search_count += 1 + logfire.info( + 'Server-side web_search tool called', + tool_name=part.tool_name, + tool_call_id=part.tool_call_id, + ) + elif 'code_execution' in part.tool_name: + code_execution_count += 1 + logfire.info( + 'Server-side code_execution tool called', + tool_name=part.tool_name, + tool_call_id=part.tool_call_id, + code=part.args_as_dict().get('code', 'N/A') + if part.args + else 'N/A', + ) + + print('\nāœ… Analysis complete!') + print('\nšŸ”§ Server-Side Tools Used:') + print(f' Web Search calls: {web_search_count}') + print(f' Code Execution calls: {code_execution_count}') + + print(f'\nStock: {result.output.stock_symbol}') + print(f'Current Price: ${result.output.current_price:.2f}') + print(f'Projected Year-End Price: ${result.output.projected_price:.2f}') + print(f'\nAnalysis: {result.output.analysis}') + + # Get the final response message for metadata + final_message = result.all_messages()[-1] + if isinstance(final_message, ModelResponse): + print('\nšŸ†” Response Metadata:') + if final_message.provider_response_id: + print(f' Response ID: {final_message.provider_response_id}') + if final_message.model_name: + print(f' Model: {final_message.model_name}') + if final_message.timestamp: + print(f' Timestamp: {final_message.timestamp}') + + # Show usage statistics + usage = result.usage() + print('\nšŸ“ˆ Usage Statistics:') + print(f' Requests: {usage.requests}') + print(f' Input Tokens: {usage.input_tokens}') + print(f' Output Tokens: {usage.output_tokens}') + print(f' Total Tokens: {usage.total_tokens}') + + +if __name__ == '__main__': + import asyncio + + asyncio.run(main()) diff --git a/pydantic_ai_slim/pydantic_ai/builtin_tools.py b/pydantic_ai_slim/pydantic_ai/builtin_tools.py index 5559b3124a..c53b0bd61c 100644 --- a/pydantic_ai_slim/pydantic_ai/builtin_tools.py +++ b/pydantic_ai_slim/pydantic_ai/builtin_tools.py @@ -75,6 +75,7 @@ class WebSearchTool(AbstractBuiltinTool): * OpenAI Responses * Groq * Google + * Grok """ search_context_size: Literal['low', 'medium', 'high'] = 'medium' @@ -159,6 +160,7 @@ class CodeExecutionTool(AbstractBuiltinTool): * Anthropic * OpenAI Responses * Google + * Grok """ kind: str = 'code_execution' @@ -280,6 +282,7 @@ class MCPServerTool(AbstractBuiltinTool): * OpenAI Responses * Anthropic + * Grok """ id: str @@ -298,6 +301,7 @@ class MCPServerTool(AbstractBuiltinTool): * OpenAI Responses * Anthropic + * Grok """ description: str | None = None @@ -315,6 +319,7 @@ class MCPServerTool(AbstractBuiltinTool): * OpenAI Responses * Anthropic + * Grok """ headers: dict[str, str] | None = None @@ -325,6 +330,7 @@ class MCPServerTool(AbstractBuiltinTool): Supported by: * OpenAI Responses + * Grok """ kind: str = 'mcp_server' diff --git a/pydantic_ai_slim/pydantic_ai/models/grok.py b/pydantic_ai_slim/pydantic_ai/models/grok.py index cedde52ae3..9d0778ca93 100644 --- a/pydantic_ai_slim/pydantic_ai/models/grok.py +++ b/pydantic_ai_slim/pydantic_ai/models/grok.py @@ -11,11 +11,16 @@ # Import xai_sdk components from xai_sdk import AsyncClient from xai_sdk.chat import assistant, image, system, tool, 
tool_result, user +from xai_sdk.tools import code_execution, get_tool_call_type, mcp, web_search # x_search not yet supported from .._run_context import RunContext from .._utils import now_utc +from ..builtin_tools import CodeExecutionTool, MCPServerTool, WebSearchTool +from ..exceptions import UserError from ..messages import ( BinaryContent, + BuiltinToolCallPart, + BuiltinToolReturnPart, FinishReason, ImageUrl, ModelMessage, @@ -161,6 +166,33 @@ def _map_tools(self, model_request_parameters: ModelRequestParameters) -> list[c tools.append(xai_tool) return tools + def _get_builtin_tools(self, model_request_parameters: ModelRequestParameters) -> list[chat_types.chat_pb2.Tool]: + """Convert pydantic_ai built-in tools to xAI SDK server-side tools.""" + tools: list[chat_types.chat_pb2.Tool] = [] + for builtin_tool in model_request_parameters.builtin_tools: + if isinstance(builtin_tool, WebSearchTool): + tools.append(web_search()) + elif isinstance(builtin_tool, CodeExecutionTool): + tools.append(code_execution()) + elif isinstance(builtin_tool, MCPServerTool): + tools.append( + mcp( + server_url=builtin_tool.url, + server_label=builtin_tool.id, + server_description=builtin_tool.description, + allowed_tool_names=builtin_tool.allowed_tools, + authorization=builtin_tool.authorization_token, + extra_headers=builtin_tool.headers, + ) + ) + else: + raise UserError( + f'`{builtin_tool.__class__.__name__}` is not supported by `GrokModel`. ' + f'Supported built-in tools: WebSearchTool, CodeExecutionTool, MCPServerTool. ' + f'If XSearchTool should be supported, please file an issue.' + ) + return tools + async def request( self, messages: list[ModelMessage], @@ -174,8 +206,13 @@ async def request( # Convert messages to xAI format xai_messages = self._map_messages(messages) - # Convert tools if any - tools = self._map_tools(model_request_parameters) if model_request_parameters.tool_defs else None + # Convert tools: combine built-in (server-side) tools and custom (client-side) tools + tools: list[chat_types.chat_pb2.Tool] = [] + if model_request_parameters.builtin_tools: + tools.extend(self._get_builtin_tools(model_request_parameters)) + if model_request_parameters.tool_defs: + tools.extend(self._map_tools(model_request_parameters)) + tools_param = tools if tools else None # Filter model settings to only include xAI SDK compatible parameters xai_settings: dict[str, Any] = {} @@ -199,7 +236,7 @@ async def request( xai_settings['frequency_penalty'] = model_settings['frequency_penalty'] # Create chat instance - chat = client.chat.create(model=self._model_name, messages=xai_messages, tools=tools, **xai_settings) + chat = client.chat.create(model=self._model_name, messages=xai_messages, tools=tools_param, **xai_settings) # Sample the response response = await chat.sample() @@ -222,8 +259,13 @@ async def request_stream( # Convert messages to xAI format xai_messages = self._map_messages(messages) - # Convert tools if any - tools = self._map_tools(model_request_parameters) if model_request_parameters.tool_defs else None + # Convert tools: combine built-in (server-side) tools and custom (client-side) tools + tools: list[chat_types.chat_pb2.Tool] = [] + if model_request_parameters.builtin_tools: + tools.extend(self._get_builtin_tools(model_request_parameters)) + if model_request_parameters.tool_defs: + tools.extend(self._map_tools(model_request_parameters)) + tools_param = tools if tools else None # Filter model settings to only include xAI SDK compatible parameters xai_settings: dict[str, Any] = {} @@ -247,7 
+289,7 @@ async def request_stream( xai_settings['frequency_penalty'] = model_settings['frequency_penalty'] # Create chat instance - chat = client.chat.create(model=self._model_name, messages=xai_messages, tools=tools, **xai_settings) + chat = client.chat.create(model=self._model_name, messages=xai_messages, tools=tools_param, **xai_settings) # Stream the response response_stream = chat.stream() @@ -266,19 +308,55 @@ def _process_response(self, response: chat_types.Response) -> ModelResponse: parts: list[ModelResponsePart] = [] - # Add text content - if response.content: - parts.append(TextPart(content=response.content)) - - # Add tool calls + # Add tool calls (both client-side and server-side) first + # For server-side tools, these were executed before generating the final content for tool_call in response.tool_calls: - parts.append( - ToolCallPart( - tool_name=tool_call.function.name, - args=tool_call.function.arguments, - tool_call_id=tool_call.id, + # Try to determine if this is a server-side tool + # In real responses, we can use get_tool_call_type() + # In mock responses, we default to client-side tools + is_server_side_tool = False + if hasattr(tool_call, 'type'): + try: + tool_type = get_tool_call_type(tool_call) + # If it's not a client-side tool, it's a server-side tool + is_server_side_tool = tool_type != 'client_side_tool' + except Exception: + # If we can't determine the type, treat as client-side + pass + + if is_server_side_tool: + # Server-side tools are executed by xAI, so we add both call and return parts + # The final result is in response.content + parts.append( + BuiltinToolCallPart( + tool_name=tool_call.function.name, + args=tool_call.function.arguments, + tool_call_id=tool_call.id, + provider_name='xai', + ) ) - ) + # Always add the return part for server-side tools since they're already executed + parts.append( + BuiltinToolReturnPart( + tool_name=tool_call.function.name, + content={'status': 'completed'}, + tool_call_id=tool_call.id, + provider_name='xai', + ) + ) + else: + # Client-side tool call (or mock) + parts.append( + ToolCallPart( + tool_name=tool_call.function.name, + args=tool_call.function.arguments, + tool_call_id=tool_call.id, + ) + ) + + # Add text content after tool calls (for server-side tools, this is the final result) + if response.content: + parts.append(TextPart(content=response.content)) # Convert usage - try to access attributes, default to 0 if not available input_tokens = getattr(response.usage, 'input_tokens', 0) @@ -356,18 +434,52 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: if event is not None: yield event - # Handle tool calls + # Handle tool calls (both client-side and server-side) # Note: We use the accumulated Response tool calls, not the Chunk deltas, # because pydantic validation needs complete JSON, not partial deltas if hasattr(response, 'tool_calls'): for tool_call in response.tool_calls: if hasattr(tool_call.function, 'name') and tool_call.function.name: - yield self._parts_manager.handle_tool_call_part( - vendor_part_id=tool_call.id, - tool_name=tool_call.function.name, - args=tool_call.function.arguments, - tool_call_id=tool_call.id, - ) + # Check if this is a server-side (built-in) tool + is_server_side_tool = False + if hasattr(tool_call, 'type'): + try: + tool_type = get_tool_call_type(tool_call) + # If it's not a client-side tool, it's a server-side tool + is_server_side_tool = tool_type != 'client_side_tool' + except Exception: + # If we can't determine the type, treat as 
client-side + pass + + if is_server_side_tool: + # Server-side tools - create BuiltinToolCallPart and BuiltinToolReturnPart + # These tools are already executed by xAI's infrastructure + call_part = BuiltinToolCallPart( + tool_name=tool_call.function.name, + args=tool_call.function.arguments, + tool_call_id=tool_call.id, + provider_name='xai', + ) + yield self._parts_manager.handle_part(vendor_part_id=tool_call.id, part=call_part) + + # Immediately yield the return part since the tool was already executed + return_part = BuiltinToolReturnPart( + tool_name=tool_call.function.name, + content={'status': 'completed'}, + tool_call_id=tool_call.id, + provider_name='xai', + ) + yield self._parts_manager.handle_part( + vendor_part_id=f'{tool_call.id}_return', part=return_part + ) + else: + # Client-side tools - use standard handler + yield self._parts_manager.handle_tool_call_part( + vendor_part_id=tool_call.id, + tool_name=tool_call.function.name, + args=tool_call.function.arguments, + tool_call_id=tool_call.id, + ) @property def model_name(self) -> str: diff --git a/tests/models/test_grok.py b/tests/models/test_grok.py index ad27f077c8..c9c3c843e1 100644 --- a/tests/models/test_grok.py +++ b/tests/models/test_grok.py @@ -13,6 +13,7 @@ from pydantic_ai import ( Agent, BinaryContent, + BuiltinToolCallPart, ImageUrl, ModelRequest, ModelResponse, @@ -52,6 +53,12 @@ pytest.mark.skipif(not imports_successful(), reason='xai_sdk not installed'), pytest.mark.anyio, pytest.mark.vcr, + pytest.mark.filterwarnings( + 'ignore:`BuiltinToolCallEvent` is deprecated, look for `PartStartEvent` and `PartDeltaEvent` with `BuiltinToolCallPart` instead.:DeprecationWarning' + ), + pytest.mark.filterwarnings( + 'ignore:`BuiltinToolResultEvent` is deprecated, look for `PartStartEvent` and `PartDeltaEvent` with `BuiltinToolReturnPart` instead.:DeprecationWarning' + ), ] @@ -693,39 +700,189 @@ async def test_grok_image_as_binary_content_input( assert 'kiwi' in response_text or 'fruit' in response_text -# Skip tests that are not applicable to Grok model -# The following tests were removed as they are OpenAI-specific: +# Grok built-in tools tests +# Built-in tools are executed server-side by xAI's infrastructure +# Based on: https://github.com/xai-org/xai-sdk-python/blob/main/examples/aio/server_side_tools.py -# Continue with model request/response tests -# Grok-specific tests for built-in tools +@pytest.mark.skipif(os.getenv('XAI_API_KEY') is None, reason='Requires XAI_API_KEY (gRPC, no cassettes)') +async def test_grok_builtin_web_search_tool(allow_model_requests: None, xai_api_key: str): + """Test Grok's built-in web_search tool.""" + from pydantic_ai import WebSearchTool + m = GrokModel('grok-4-fast', api_key=xai_api_key) + agent = Agent(m, builtin_tools=[WebSearchTool()]) -async def test_grok_web_search_tool(allow_model_requests: None): - """Test Grok model with web_search built-in tool.""" - # First response: tool call to web_search - tool_call = create_tool_call( - id='web-1', - name='web_search', - arguments={'query': 'latest news about AI'}, + result = await agent.run('What is the weather in San Francisco today?') + + # Verify the response + assert result.output + messages = result.all_messages() + assert len(messages) >= 2 + + # TODO: Add validation for built-in tool call parts once response parsing is fully tested + # Server-side tools are executed by xAI's infrastructure + + +@pytest.mark.skipif(os.getenv('XAI_API_KEY') is None, reason='Requires XAI_API_KEY (gRPC, no cassettes)') +async def 
test_grok_builtin_x_search_tool(allow_model_requests: None, xai_api_key: str): + """Test Grok's built-in x_search tool (X/Twitter search).""" + # Note: This test is skipped until XSearchTool is properly implemented + # from pydantic_ai.builtin_tools import AbstractBuiltinTool + # + # class XSearchTool(AbstractBuiltinTool): + # """X (Twitter) search tool - specific to Grok.""" + # kind: str = 'x_search' + # + # m = GrokModel('grok-4-fast', api_key=xai_api_key) + # agent = Agent(m, builtin_tools=[XSearchTool()]) + # result = await agent.run('What is the latest post from @elonmusk?') + # assert result.output + pytest.skip('XSearchTool not yet implemented in pydantic-ai') + + +@pytest.mark.skipif(os.getenv('XAI_API_KEY') is None, reason='Requires XAI_API_KEY (gRPC, no cassettes)') +async def test_grok_builtin_code_execution_tool(allow_model_requests: None, xai_api_key: str): + """Test Grok's built-in code_execution tool.""" + from pydantic_ai import CodeExecutionTool + + m = GrokModel('grok-4-fast', api_key=xai_api_key) + agent = Agent(m, builtin_tools=[CodeExecutionTool()]) + + # Use a simpler calculation similar to OpenAI tests + result = await agent.run('What is 65465 - 6544 * 65464 - 6 + 1.02255? Use code to calculate this.') + + # Verify the response + assert result.output + # Expected: 65465 - 6544*65464 - 6 + 1.02255 = -428330955.97745 + assert '-428' in result.output or 'million' in result.output.lower() + + messages = result.all_messages() + assert len(messages) >= 2 + + # TODO: Add validation for built-in tool call parts once response parsing is fully tested + # Server-side tools are executed by xAI's infrastructure + + +@pytest.mark.skipif(os.getenv('XAI_API_KEY') is None, reason='Requires XAI_API_KEY (gRPC, no cassettes)') +async def test_grok_builtin_multiple_tools(allow_model_requests: None, xai_api_key: str): + """Test using multiple built-in tools together.""" + from pydantic_ai import CodeExecutionTool, WebSearchTool + + m = GrokModel('grok-4-fast', api_key=xai_api_key) + agent = Agent( + m, + instructions='You are a helpful assistant.', + builtin_tools=[WebSearchTool(), CodeExecutionTool()], + ) + + result = await agent.run( + 'Search for the current price of Bitcoin and calculate its percentage change if it was $50000 last week.' 
+ ) - mock_client = MockGrok.create_mock([response1, response2]) - m = GrokModel('grok-4-fast-non-reasoning', client=mock_client) - agent = Agent(m) + # Verify the response + assert result.output + messages = result.all_messages() + assert len(messages) >= 2 + + # The model should use both tools (basic validation that registration works) + # TODO: Add validation for built-in tool usage once response parsing is fully tested + + +@pytest.mark.skipif(os.getenv('XAI_API_KEY') is None, reason='Requires XAI_API_KEY (gRPC, no cassettes)') +async def test_grok_builtin_tools_with_custom_tools(allow_model_requests: None, xai_api_key: str): + """Test mixing Grok's built-in tools with custom (client-side) tools.""" + from pydantic_ai import WebSearchTool + + m = GrokModel('grok-4-fast', api_key=xai_api_key) + agent = Agent(m, builtin_tools=[WebSearchTool()]) - # Add a mock web search tool @agent.tool_plain - async def web_search(query: str) -> str: - return f'Search results for: {query}' + def get_local_temperature(city: str) -> str: + """Get the local temperature for a city (mock).""" + return f'The local temperature in {city} is 72°F' + + result = await agent.run('What is the weather in Tokyo? Use web search and then get the local temperature.') + + # Verify the response + assert result.output + messages = result.all_messages() - result = await agent.run('What is the latest news about AI?') - assert 'AI is advancing rapidly' in result.output - assert result.usage().requests == 2 + # Should have both built-in tool calls and custom tool calls + assert len(messages) >= 4 # Request, builtin response, request, custom tool response + + +async def test_grok_builtin_tools_wiring(allow_model_requests: None): + """Test that built-in tools are correctly wired to xAI SDK.""" + from pydantic_ai import CodeExecutionTool, MCPServerTool, WebSearchTool + + response = create_response(content='Built-in tools are registered') + mock_client = MockGrok.create_mock(response) + m = GrokModel('grok-4-fast', client=mock_client) + agent = Agent( + m, + builtin_tools=[ + WebSearchTool(), + CodeExecutionTool(), + MCPServerTool( + id='test-mcp', + url='https://example.com/mcp', + description='Test MCP server', + authorization_token='test-token', + ), + ], + ) + + # If this runs without error, the built-in tools are correctly wired + result = await agent.run('Test built-in tools') + assert result.output == 'Built-in tools are registered' + + +@pytest.mark.skipif( + os.getenv('XAI_API_KEY') is None or os.getenv('LINEAR_ACCESS_TOKEN') is None, + reason='Requires XAI_API_KEY and LINEAR_ACCESS_TOKEN (gRPC, no cassettes)', +) +async def test_grok_builtin_mcp_server_tool(allow_model_requests: None, xai_api_key: str): + """Test Grok's MCP server tool with Linear.""" + from pydantic_ai import MCPServerTool + + linear_token = os.getenv('LINEAR_ACCESS_TOKEN') + m = GrokModel('grok-4-fast', api_key=xai_api_key) + agent = Agent( + m, + instructions='You are a helpful assistant.', + builtin_tools=[ + MCPServerTool( + id='linear', + url='https://mcp.linear.app/mcp', + description='MCP server for Linear the project management tool.', + authorization_token=linear_token, + ), + ], + ) + + result = await agent.run('Can you list my Linear issues? 
Keep your answer brief.') + + # Verify the response + assert result.output + messages = result.all_messages() + assert len(messages) >= 2 + + # Check that we have builtin tool call parts for MCP (server-side tool with server_label prefix) + response_message = messages[-1] + assert isinstance(response_message, ModelResponse) + + # Should have at least one BuiltinToolCallPart for MCP tools (prefixed with server_label, e.g. "linear.list_issues") + mcp_tool_calls = [ + part + for msg in messages + if isinstance(msg, ModelResponse) + for part in msg.parts + if isinstance(part, BuiltinToolCallPart) and part.tool_name.startswith('linear.') + ] + assert len(mcp_tool_calls) > 0, ( + f'Expected MCP tool calls with "linear." prefix, got parts: {[part for msg in messages if isinstance(msg, ModelResponse) for part in msg.parts]}' + ) async def test_grok_model_retries(allow_model_requests: None): From 18b21a57b736a7f47ba769473206746f00225416 Mon Sep 17 00:00:00 2001 From: Julian Bright Date: Tue, 18 Nov 2025 16:46:29 -0800 Subject: [PATCH 09/16] Updating docs --- docs/api/models/grok.md | 2 +- docs/models/grok.md | 47 ++++++++++++++++------------------------- 2 files changed, 19 insertions(+), 30 deletions(-) diff --git a/docs/api/models/grok.md b/docs/api/models/grok.md index 699c1e95f9..c37429873f 100644 --- a/docs/api/models/grok.md +++ b/docs/api/models/grok.md @@ -2,6 +2,6 @@ ## Setup -For details on how to set up authentication with this model, see [model configuration for Grokq](../../models/grokq.md). +For details on how to set up authentication with this model, see [model configuration for Grok](../../models/grok.md). ::: pydantic_ai.models.grok diff --git a/docs/models/grok.md b/docs/models/grok.md index db04423493..629d2a88e6 100644 --- a/docs/models/grok.md +++ b/docs/models/grok.md @@ -1,33 +1,31 @@ -# Groq +# Grok ## Install -To use `GroqModel`, you need to either install `pydantic-ai`, or install `pydantic-ai-slim` with the `groq` optional group: +To use `GrokModel`, you need to either install `pydantic-ai`, or install `pydantic-ai-slim` with the `grok` optional group: ```bash -pip/uv-add "pydantic-ai-slim[groq]" +pip/uv-add "pydantic-ai-slim[grok]" ``` ## Configuration -To use [Groq](https://groq.com/) through their API, go to [console.groq.com/keys](https://console.groq.com/keys) and follow your nose until you find the place to generate an API key. - -`GroqModelName` contains a list of available Groq models. +To use Grok from [xAI](https://x.ai/api) through their API, go to [console.x.ai](https://console.x.ai) and follow your nose until you find the place to create an API key. ## Environment variable Once you have the API key, you can set it as an environment variable: ```bash -export GROQ_API_KEY='your-api-key' +export XAI_API_KEY='your-api-key' ``` You can then use `GrokModel` by name: ```python from pydantic_ai import Agent -agent = Agent('groq:llama-3.3-70b-versatile') +agent = Agent('grok:grok-4-fast-non-reasoning') ... ``` @@ -35,43 +33,34 @@ Or initialise the model directly with just the model name: ```python from pydantic_ai import Agent -from pydantic_ai.models.groq import GroqModel +from pydantic_ai.models.grok import GrokModel -model = GroqModel('llama-3.3-70b-versatile') +model = GrokModel('grok-4-fast-non-reasoning') agent = Agent(model) ... 
``` -## `provider` argument - -You can provide a custom `Provider` via the `provider` argument: +You can provide your own `api_key` inline like so: ```python from pydantic_ai import Agent -from pydantic_ai.models.groq import GroqModel -from pydantic_ai.providers.groq import GroqProvider +from pydantic_ai.models.grok import GrokModel -model = GroqModel( - 'llama-3.3-70b-versatile', provider=GroqProvider(api_key='your-api-key') -) +model = GrokModel('grok-4-fast-non-reasoning', api_key='your-api-key') agent = Agent(model) ... ``` -You can also customize the `GroqProvider` with a custom `httpx.AsyncHTTPClient`: +You can also customize the `GrokModel` with a custom `xai_sdk.AsyncClient`: ```python -from httpx import AsyncClient +from xai_sdk import AsyncClient from pydantic_ai import Agent -from pydantic_ai.models.groq import GroqModel -from pydantic_ai.providers.groq import GroqProvider - -custom_http_client = AsyncClient(timeout=30) -model = GroqModel( - 'llama-3.3-70b-versatile', - provider=GroqProvider(api_key='your-api-key', http_client=custom_http_client), -) +from pydantic_ai.models.grok import GrokModel + +async_client = AsyncClient(api_key='your-api-key') +model = GrokModel('grok-4-fast-non-reasoning', client=async_client) agent = Agent(model) ... ``` From b41748544497495a392de454ff062052b891da41 Mon Sep 17 00:00:00 2001 From: Julian Bright Date: Tue, 18 Nov 2025 17:01:57 -0800 Subject: [PATCH 10/16] Set provider response id --- pydantic_ai_slim/pydantic_ai/models/grok.py | 1 + tests/models/test_grok.py | 6 ++++++ 2 files changed, 7 insertions(+) diff --git a/pydantic_ai_slim/pydantic_ai/models/grok.py b/pydantic_ai_slim/pydantic_ai/models/grok.py index 9d0778ca93..8819ec85b8 100644 --- a/pydantic_ai_slim/pydantic_ai/models/grok.py +++ b/pydantic_ai_slim/pydantic_ai/models/grok.py @@ -384,6 +384,7 @@ def _process_response(self, response: chat_types.Response) -> ModelResponse: model_name=self._model_name, timestamp=now_utc(), provider_name='xai', + provider_response_id=response.id if hasattr(response, 'id') else None, finish_reason=finish_reason, ) diff --git a/tests/models/test_grok.py b/tests/models/test_grok.py index c9c3c843e1..4ca65b616c 100644 --- a/tests/models/test_grok.py +++ b/tests/models/test_grok.py @@ -96,6 +96,7 @@ async def test_grok_request_simple_success(allow_model_requests: None): model_name='grok-4-fast-non-reasoning', timestamp=IsDatetime(), provider_name='xai', + provider_response_id='grok-123', finish_reason='stop', run_id=IsStr(), ), @@ -108,6 +109,7 @@ async def test_grok_request_simple_success(allow_model_requests: None): model_name='grok-4-fast-non-reasoning', timestamp=IsDatetime(), provider_name='xai', + provider_response_id='grok-123', finish_reason='stop', run_id=IsStr(), ), @@ -179,6 +181,7 @@ async def test_grok_request_structured_response(allow_model_requests: None): model_name='grok-4-fast-non-reasoning', timestamp=IsDatetime(), provider_name='xai', + provider_response_id='grok-123', finish_reason='stop', run_id=IsStr(), ), @@ -246,6 +249,7 @@ async def get_location(loc_name: str) -> str: model_name='grok-4-fast-non-reasoning', timestamp=IsDatetime(), provider_name='xai', + provider_response_id='grok-123', finish_reason='stop', run_id=IsStr(), ), @@ -275,6 +279,7 @@ async def get_location(loc_name: str) -> str: model_name='grok-4-fast-non-reasoning', timestamp=IsDatetime(), provider_name='xai', + provider_response_id='grok-123', finish_reason='stop', run_id=IsStr(), ), @@ -294,6 +299,7 @@ async def get_location(loc_name: str) -> str: 
model_name='grok-4-fast-non-reasoning', timestamp=IsDatetime(), provider_name='xai', + provider_response_id='grok-123', finish_reason='stop', run_id=IsStr(), ), From 0f1c11341ca0c8b7999fac40ab20e0b8080176ce Mon Sep 17 00:00:00 2001 From: Julian Bright Date: Tue, 18 Nov 2025 17:55:41 -0800 Subject: [PATCH 11/16] Adding reasoning and usage --- .../stock_analysis_agent.py | 1 + pydantic_ai_slim/pydantic_ai/models/grok.py | 291 +++++++---- tests/models/mock_grok.py | 12 + tests/models/test_grok.py | 451 +++++++++++++++++- 4 files changed, 661 insertions(+), 94 deletions(-) diff --git a/examples/pydantic_ai_examples/stock_analysis_agent.py b/examples/pydantic_ai_examples/stock_analysis_agent.py index 13ea704492..65d93d065b 100644 --- a/examples/pydantic_ai_examples/stock_analysis_agent.py +++ b/examples/pydantic_ai_examples/stock_analysis_agent.py @@ -136,6 +136,7 @@ async def main(): print(f'\nAnalysis: {result.output.analysis}') # Get the final response message for metadata + print(result.all_messages()) final_message = result.all_messages()[-1] if isinstance(final_message, ModelResponse): print('\nšŸ†” Response Metadata:') diff --git a/pydantic_ai_slim/pydantic_ai/models/grok.py b/pydantic_ai_slim/pydantic_ai/models/grok.py index 8819ec85b8..893ae6a02b 100644 --- a/pydantic_ai_slim/pydantic_ai/models/grok.py +++ b/pydantic_ai_slim/pydantic_ai/models/grok.py @@ -1,7 +1,7 @@ """Grok model implementation using xAI SDK.""" import os -from collections.abc import AsyncIterator, Sequence +from collections.abc import AsyncIterator, Iterator, Sequence from contextlib import asynccontextmanager from dataclasses import dataclass from typing import Any @@ -31,6 +31,7 @@ ModelResponseStreamEvent, SystemPromptPart, TextPart, + ThinkingPart, ToolCallPart, ToolReturnPart, UserPromptPart, @@ -308,6 +309,26 @@ def _process_response(self, response: chat_types.Response) -> ModelResponse: parts: list[ModelResponsePart] = [] + # Add reasoning/thinking content first if present + if hasattr(response, 'reasoning_content') and response.reasoning_content: + # reasoning_content is the human-readable summary + parts.append( + ThinkingPart( + content=response.reasoning_content, + signature=None, + provider_name='xai', + ) + ) + elif hasattr(response, 'encrypted_content') and response.encrypted_content: + # encrypted_content is a signature that can be sent back for reasoning continuity + parts.append( + ThinkingPart( + content='', # No readable content for encrypted-only reasoning + signature=response.encrypted_content, + provider_name='xai', + ) + ) + # Add tool calls (both client-side and server-side) first # For server-side tools, these were executed before generating the final content for tool_call in response.tool_calls: @@ -358,10 +379,8 @@ def _process_response(self, response: chat_types.Response) -> ModelResponse: if response.content: parts.append(TextPart(content=response.content)) - # Convert usage - try to access attributes, default to 0 if not available - input_tokens = getattr(response.usage, 'input_tokens', 0) - output_tokens = getattr(response.usage, 'output_tokens', 0) - usage = RequestUsage(input_tokens=input_tokens, output_tokens=output_tokens) + # Convert usage with detailed token information + usage = self._map_usage(response) # Map finish reason finish_reason_map = { @@ -388,6 +407,69 @@ def _process_response(self, response: chat_types.Response) -> ModelResponse: finish_reason=finish_reason, ) + def _map_usage(self, response: chat_types.Response) -> RequestUsage: + """Extract usage information from 
xAI SDK response, including reasoning tokens and cache tokens.""" + return GrokModel.extract_usage(response) + + @staticmethod + def extract_usage(response: chat_types.Response) -> RequestUsage: + """Extract usage information from xAI SDK response. + + Extracts token counts and additional usage details including: + - reasoning_tokens: Tokens used for model reasoning/thinking + - cache_read_tokens: Tokens read from prompt cache + - server_side_tools_used: Count of server-side (built-in) tools executed + """ + if not hasattr(response, 'usage'): + return RequestUsage() + + usage_obj = getattr(response, 'usage', None) + if not usage_obj: + return RequestUsage() + + prompt_tokens = getattr(usage_obj, 'prompt_tokens', 0) + completion_tokens = getattr(usage_obj, 'completion_tokens', 0) + + # Build details dict for additional usage metrics + details: dict[str, int] = {} + + # Add reasoning tokens if available + if hasattr(usage_obj, 'reasoning_tokens'): + reasoning_tokens = getattr(usage_obj, 'reasoning_tokens', 0) + if reasoning_tokens: + details['reasoning_tokens'] = reasoning_tokens + + # Add cached prompt tokens if available + if hasattr(usage_obj, 'cached_prompt_text_tokens'): + cached_tokens = getattr(usage_obj, 'cached_prompt_text_tokens', 0) + if cached_tokens: + details['cache_read_tokens'] = cached_tokens + + # Add server-side tools used count if available + if hasattr(usage_obj, 'server_side_tools_used'): + server_side_tools = getattr(usage_obj, 'server_side_tools_used', None) + # server_side_tools_used is a repeated field (list-like) in the real SDK + # but may be an int in mocks for simplicity + if server_side_tools: + if isinstance(server_side_tools, int): + tools_count = server_side_tools + else: + tools_count = len(server_side_tools) + if tools_count: + details['server_side_tools_used'] = tools_count + + if details: + return RequestUsage( + input_tokens=prompt_tokens, + output_tokens=completion_tokens, + details=details, + ) + else: + return RequestUsage( + input_tokens=prompt_tokens, + output_tokens=completion_tokens, + ) + @dataclass class GrokStreamedResponse(StreamedResponse): @@ -398,89 +480,134 @@ class GrokStreamedResponse(StreamedResponse): _timestamp: Any _provider_name: str + def _update_response_state(self, response: Any) -> None: + """Update response state including usage, response ID, and finish reason.""" + from typing import cast + + # Update usage + if hasattr(response, 'usage'): + self._usage = GrokModel.extract_usage(response) + + # Set provider response ID + if hasattr(response, 'id') and self.provider_response_id is None: + self.provider_response_id = response.id + + # Handle finish reason + if hasattr(response, 'finish_reason') and response.finish_reason: + finish_reason_map = { + 'stop': 'stop', + 'length': 'length', + 'content_filter': 'content_filter', + 'max_output_tokens': 'length', + 'cancelled': 'error', + 'failed': 'error', + } + mapped_reason = finish_reason_map.get(response.finish_reason, 'stop') + self.finish_reason = cast(FinishReason, mapped_reason) + + def _handle_reasoning_content(self, response: Any, reasoning_handled: bool) -> Iterator[ModelResponseStreamEvent]: + """Handle reasoning content (both readable and encrypted).""" + if reasoning_handled: + return + + if hasattr(response, 'reasoning_content') and response.reasoning_content: + # reasoning_content is the human-readable summary + thinking_part = ThinkingPart( + content=response.reasoning_content, + signature=None, + provider_name='xai', + ) + yield 
self._parts_manager.handle_part(vendor_part_id='reasoning', part=thinking_part) + elif hasattr(response, 'encrypted_content') and response.encrypted_content: + # encrypted_content is a signature that can be sent back for reasoning continuity + thinking_part = ThinkingPart( + content='', # No readable content for encrypted-only reasoning + signature=response.encrypted_content, + provider_name='xai', + ) + yield self._parts_manager.handle_part(vendor_part_id='encrypted_reasoning', part=thinking_part) + + def _handle_text_delta(self, chunk: Any) -> Iterator[ModelResponseStreamEvent]: + """Handle text content delta from chunk.""" + if hasattr(chunk, 'content') and chunk.content: + event = self._parts_manager.handle_text_delta( + vendor_part_id='content', + content=chunk.content, + ) + if event is not None: + yield event + + def _handle_single_tool_call(self, tool_call: Any) -> Iterator[ModelResponseStreamEvent]: + """Handle a single tool call, routing to server-side or client-side handler.""" + if not (hasattr(tool_call.function, 'name') and tool_call.function.name): + return + + # Determine if this is a server-side (built-in) tool + is_server_side_tool = False + if hasattr(tool_call, 'type'): + try: + tool_type = get_tool_call_type(tool_call) + is_server_side_tool = tool_type != 'client_side_tool' + except Exception: + pass # Treat as client-side if we can't determine + + if is_server_side_tool: + # Server-side tools - create BuiltinToolCallPart and BuiltinToolReturnPart + # These tools are already executed by xAI's infrastructure + call_part = BuiltinToolCallPart( + tool_name=tool_call.function.name, + args=tool_call.function.arguments, + tool_call_id=tool_call.id, + provider_name='xai', + ) + yield self._parts_manager.handle_part(vendor_part_id=tool_call.id, part=call_part) + + # Immediately yield the return part since the tool was already executed + return_part = BuiltinToolReturnPart( + tool_name=tool_call.function.name, + content={'status': 'completed'}, + tool_call_id=tool_call.id, + provider_name='xai', + ) + yield self._parts_manager.handle_part(vendor_part_id=f'{tool_call.id}_return', part=return_part) + else: + # Client-side tools - use standard handler + yield self._parts_manager.handle_tool_call_part( + vendor_part_id=tool_call.id, + tool_name=tool_call.function.name, + args=tool_call.function.arguments, + tool_call_id=tool_call.id, + ) + + def _handle_tool_calls(self, response: Any) -> Iterator[ModelResponseStreamEvent]: + """Handle tool calls (both client-side and server-side).""" + if not hasattr(response, 'tool_calls'): + return + + for tool_call in response.tool_calls: + yield from self._handle_single_tool_call(tool_call) + async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: """Iterate over streaming events from xAI SDK.""" - from typing import cast + reasoning_handled = False # Track if we've already handled reasoning content async for response, chunk in self._response: - # Update usage if available - if hasattr(response, 'usage'): - input_tokens = getattr(response.usage, 'input_tokens', 0) - output_tokens = getattr(response.usage, 'output_tokens', 0) - self._usage = RequestUsage(input_tokens=input_tokens, output_tokens=output_tokens) - - # Set provider response ID - if hasattr(response, 'id') and self.provider_response_id is None: - self.provider_response_id = response.id - - # Handle finish reason - if hasattr(response, 'finish_reason') and response.finish_reason: - finish_reason_map = { - 'stop': 'stop', - 'length': 'length', - 
'content_filter': 'content_filter', - 'max_output_tokens': 'length', - 'cancelled': 'error', - 'failed': 'error', - } - mapped_reason = finish_reason_map.get(response.finish_reason, 'stop') - self.finish_reason = cast(FinishReason, mapped_reason) - - # Handle text content - if hasattr(chunk, 'content') and chunk.content: - event = self._parts_manager.handle_text_delta( - vendor_part_id='content', - content=chunk.content, - ) - if event is not None: + self._update_response_state(response) + + # Handle reasoning content (only emit once) + reasoning_events = list(self._handle_reasoning_content(response, reasoning_handled)) + if reasoning_events: + reasoning_handled = True + for event in reasoning_events: yield event - # Handle tool calls (both client-side and server-side) - # Note: We use the accumulated Response tool calls, not the Chunk deltas, - # because pydantic validation needs complete JSON, not partial deltas - if hasattr(response, 'tool_calls'): - for tool_call in response.tool_calls: - if hasattr(tool_call.function, 'name') and tool_call.function.name: - # Check if this is a server-side (built-in) tool - is_server_side_tool = False - if hasattr(tool_call, 'type'): - try: - tool_type = get_tool_call_type(tool_call) - # If it's not a client-side tool, it's a server-side tool - is_server_side_tool = tool_type != 'client_side_tool' - except Exception: - # If we can't determine the type, treat as client-side - pass - - if is_server_side_tool: - # Server-side tools - create BuiltinToolCallPart and BuiltinToolReturnPart - # These tools are already executed by xAI's infrastructure - call_part = BuiltinToolCallPart( - tool_name=tool_call.function.name, - args=tool_call.function.arguments, - tool_call_id=tool_call.id, - provider_name='xai', - ) - yield self._parts_manager.handle_part(vendor_part_id=tool_call.id, part=call_part) - - # Immediately yield the return part since the tool was already executed - return_part = BuiltinToolReturnPart( - tool_name=tool_call.function.name, - content={'status': 'completed'}, - tool_call_id=tool_call.id, - provider_name='xai', - ) - yield self._parts_manager.handle_part( - vendor_part_id=f'{tool_call.id}_return', part=return_part - ) - else: - # Client-side tools - use standard handler - yield self._parts_manager.handle_tool_call_part( - vendor_part_id=tool_call.id, - tool_name=tool_call.function.name, - args=tool_call.function.arguments, - tool_call_id=tool_call.id, - ) + # Handle text content delta + for event in self._handle_text_delta(chunk): + yield event + + # Handle tool calls + for event in self._handle_tool_calls(response): + yield event @property def model_name(self) -> str: diff --git a/tests/models/mock_grok.py b/tests/models/mock_grok.py index 4295bd070d..aac8c605a5 100644 --- a/tests/models/mock_grok.py +++ b/tests/models/mock_grok.py @@ -124,6 +124,14 @@ class MockGrokResponse: tool_calls: list[Any] = field(default_factory=list) finish_reason: str = 'stop' usage: Any | None = None # Would be usage_pb2.SamplingUsage in real xai_sdk + reasoning_content: str = '' # Human-readable reasoning trace + encrypted_content: str = '' # Encrypted reasoning signature + + # Note: The real xAI SDK usage object uses protobuf fields: + # - prompt_tokens (not input_tokens) + # - completion_tokens (not output_tokens) + # - reasoning_tokens + # - cached_prompt_text_tokens @dataclass @@ -147,6 +155,8 @@ def create_response( tool_calls: list[Any] | None = None, finish_reason: str = 'stop', usage: Any | None = None, + reasoning_content: str = '', + 
encrypted_content: str = '', ) -> chat_types.Response: """Create a mock Response object for testing. @@ -160,6 +170,8 @@ def create_response( tool_calls=tool_calls or [], finish_reason=finish_reason, usage=usage, + reasoning_content=reasoning_content, + encrypted_content=encrypted_content, ), ) diff --git a/tests/models/test_grok.py b/tests/models/test_grok.py index 4ca65b616c..6629689710 100644 --- a/tests/models/test_grok.py +++ b/tests/models/test_grok.py @@ -21,6 +21,7 @@ RetryPromptPart, SystemPromptPart, TextPart, + ThinkingPart, ToolCallPart, ToolReturnPart, UserPromptPart, @@ -120,7 +121,7 @@ async def test_grok_request_simple_success(allow_model_requests: None): async def test_grok_request_simple_usage(allow_model_requests: None): response = create_response( content='world', - usage=SimpleNamespace(input_tokens=2, output_tokens=1), + usage=SimpleNamespace(prompt_tokens=2, completion_tokens=1), ) mock_client = MockGrok.create_mock(response) m = GrokModel('grok-4-fast-non-reasoning', client=mock_client) @@ -204,11 +205,11 @@ async def test_grok_request_tool_call(allow_model_requests: None): responses = [ create_response( tool_calls=[create_tool_call(id='1', name='get_location', arguments={'loc_name': 'San Fransisco'})], - usage=SimpleNamespace(input_tokens=2, output_tokens=1), + usage=SimpleNamespace(prompt_tokens=2, completion_tokens=1), ), create_response( tool_calls=[create_tool_call(id='2', name='get_location', arguments={'loc_name': 'London'})], - usage=SimpleNamespace(input_tokens=3, output_tokens=2), + usage=SimpleNamespace(prompt_tokens=3, completion_tokens=2), ), create_response(content='final response'), ] @@ -331,7 +332,35 @@ def grok_text_chunk(text: str, finish_reason: str = 'stop') -> tuple[chat_types. content=text, # This will be accumulated by the streaming handler tool_calls=[], finish_reason=finish_reason if finish_reason else '', - usage=SimpleNamespace(input_tokens=2, output_tokens=1) if finish_reason else None, + usage=SimpleNamespace(prompt_tokens=2, completion_tokens=1) if finish_reason else None, + ) + + return (cast(chat_types.Response, response), chunk) + + +def grok_reasoning_text_chunk( + text: str, reasoning_content: str = '', encrypted_content: str = '', finish_reason: str = 'stop' +) -> tuple[chat_types.Response, Any]: + """Create a text streaming chunk for Grok with reasoning content. 
+ + Args: + text: The text content delta + reasoning_content: The reasoning trace (accumulated, not a delta) + encrypted_content: The encrypted reasoning signature (accumulated, not a delta) + finish_reason: The finish reason + """ + # Create chunk (delta) - just this piece of text + chunk = MockGrokResponseChunk(content=text, tool_calls=[]) + + # Create response (accumulated) - includes reasoning content + response = MockGrokResponse( + id='grok-123', + content=text, + tool_calls=[], + finish_reason=finish_reason if finish_reason else '', + usage=SimpleNamespace(prompt_tokens=2, completion_tokens=1) if finish_reason else None, + reasoning_content=reasoning_content, + encrypted_content=encrypted_content, ) return (cast(chat_types.Response, response), chunk) @@ -431,7 +460,7 @@ def grok_tool_chunk( content='', tool_calls=[response_tool_call] if (effective_tool_name is not None or accumulated_args) else [], finish_reason=finish_reason, - usage=SimpleNamespace(input_tokens=20, output_tokens=1) if finish_reason else None, + usage=SimpleNamespace(prompt_tokens=20, completion_tokens=1) if finish_reason else None, ) return (cast(chat_types.Response, response), chunk) @@ -719,15 +748,15 @@ async def test_grok_builtin_web_search_tool(allow_model_requests: None, xai_api_ m = GrokModel('grok-4-fast', api_key=xai_api_key) agent = Agent(m, builtin_tools=[WebSearchTool()]) - result = await agent.run('What is the weather in San Francisco today?') - - # Verify the response + result = await agent.run('Return just the day of week for the date of Jan 1 in 2026?') assert result.output - messages = result.all_messages() - assert len(messages) >= 2 + assert result.output.lower() == 'thursday' - # TODO: Add validation for built-in tool call parts once response parsing is fully tested - # Server-side tools are executed by xAI's infrastructure + # Verify that server-side tools were used + usage = result.usage() + assert usage.details is not None + assert 'server_side_tools_used' in usage.details + assert usage.details['server_side_tools_used'] > 0 @pytest.mark.skipif(os.getenv('XAI_API_KEY') is None, reason='Requires XAI_API_KEY (gRPC, no cassettes)') @@ -1002,4 +1031,402 @@ async def test_grok_model_properties(): assert m.system == 'xai' +# Tests for reasoning/thinking content (similar to OpenAI Responses tests) + + +async def test_grok_reasoning_simple(allow_model_requests: None): + """Test Grok model with simple reasoning content.""" + response = create_response( + content='The answer is 4', + reasoning_content='Let me think: 2+2 equals 4', + usage=SimpleNamespace(prompt_tokens=10, completion_tokens=20), + ) + mock_client = MockGrok.create_mock(response) + m = GrokModel('grok-3', client=mock_client) + agent = Agent(m) + + result = await agent.run('What is 2+2?') + assert result.output == 'The answer is 4' + assert result.all_messages() == snapshot( + [ + ModelRequest( + parts=[UserPromptPart(content='What is 2+2?', timestamp=IsNow(tz=timezone.utc))], + run_id=IsStr(), + ), + ModelResponse( + parts=[ + ThinkingPart(content='Let me think: 2+2 equals 4', signature=None, provider_name='xai'), + TextPart(content='The answer is 4'), + ], + usage=RequestUsage(input_tokens=10, output_tokens=20), + model_name='grok-3', + timestamp=IsDatetime(), + provider_name='xai', + provider_response_id='grok-123', + finish_reason='stop', + run_id=IsStr(), + ), + ] + ) + + +async def test_grok_encrypted_content_only(allow_model_requests: None): + """Test Grok model with encrypted content (signature) only.""" + response = 
create_response( + content='4', + encrypted_content='abc123signature', + usage=SimpleNamespace(prompt_tokens=10, completion_tokens=5), + ) + mock_client = MockGrok.create_mock(response) + m = GrokModel('grok-3', client=mock_client) + agent = Agent(m) + + result = await agent.run('What is 2+2?') + assert result.output == '4' + assert result.all_messages() == snapshot( + [ + ModelRequest( + parts=[UserPromptPart(content='What is 2+2?', timestamp=IsNow(tz=timezone.utc))], + run_id=IsStr(), + ), + ModelResponse( + parts=[ + ThinkingPart(content='', signature='abc123signature', provider_name='xai'), + TextPart(content='4'), + ], + usage=RequestUsage(input_tokens=10, output_tokens=5), + model_name='grok-3', + timestamp=IsDatetime(), + provider_name='xai', + provider_response_id='grok-123', + finish_reason='stop', + run_id=IsStr(), + ), + ] + ) + + +async def test_grok_reasoning_without_summary(allow_model_requests: None): + """Test Grok model with encrypted content but no reasoning summary.""" + response = create_response( + content='4', + encrypted_content='encrypted123', + ) + mock_client = MockGrok.create_mock(response) + model = GrokModel('grok-3', client=mock_client) + + agent = Agent(model=model) + result = await agent.run('What is 2+2?') + assert result.all_messages() == snapshot( + [ + ModelRequest( + parts=[ + UserPromptPart( + content='What is 2+2?', + timestamp=IsDatetime(), + ) + ], + run_id=IsStr(), + ), + ModelResponse( + parts=[ + ThinkingPart(content='', signature='encrypted123', provider_name='xai'), + TextPart(content='4'), + ], + model_name='grok-3', + timestamp=IsDatetime(), + provider_name='xai', + provider_response_id='grok-123', + finish_reason='stop', + run_id=IsStr(), + ), + ] + ) + + +async def test_grok_reasoning_with_tool_calls(allow_model_requests: None): + """Test Grok model with reasoning content and tool calls.""" + responses = [ + create_response( + tool_calls=[create_tool_call(id='1', name='calculate', arguments={'expression': '2+2'})], + reasoning_content='I need to use the calculate tool to solve this', + usage=SimpleNamespace(prompt_tokens=10, completion_tokens=30), + ), + create_response( + content='The calculation shows that 2+2 equals 4', + usage=SimpleNamespace(prompt_tokens=15, completion_tokens=10), + ), + ] + mock_client = MockGrok.create_mock(responses) + m = GrokModel('grok-3', client=mock_client) + agent = Agent(m) + + @agent.tool_plain + async def calculate(expression: str) -> str: + return '4' + + result = await agent.run('What is 2+2?') + assert result.output == 'The calculation shows that 2+2 equals 4' + assert result.all_messages() == snapshot( + [ + ModelRequest( + parts=[UserPromptPart(content='What is 2+2?', timestamp=IsNow(tz=timezone.utc))], + run_id=IsStr(), + ), + ModelResponse( + parts=[ + ThinkingPart( + content='I need to use the calculate tool to solve this', signature=None, provider_name='xai' + ), + ToolCallPart( + tool_name='calculate', + args={'expression': '2+2'}, + tool_call_id='1', + ), + ], + usage=RequestUsage(input_tokens=10, output_tokens=30), + model_name='grok-3', + timestamp=IsDatetime(), + provider_name='xai', + provider_response_id='grok-123', + finish_reason='stop', + run_id=IsStr(), + ), + ModelRequest( + parts=[ + ToolReturnPart( + tool_name='calculate', + content='4', + tool_call_id='1', + timestamp=IsNow(tz=timezone.utc), + ) + ], + run_id=IsStr(), + ), + ModelResponse( + parts=[TextPart(content='The calculation shows that 2+2 equals 4')], + usage=RequestUsage(input_tokens=15, output_tokens=10), + 
model_name='grok-3', + timestamp=IsDatetime(), + provider_name='xai', + provider_response_id='grok-123', + finish_reason='stop', + run_id=IsStr(), + ), + ] + ) + + +async def test_grok_reasoning_with_encrypted_and_tool_calls(allow_model_requests: None): + """Test Grok model with encrypted reasoning content and tool calls.""" + responses = [ + create_response( + tool_calls=[create_tool_call(id='1', name='get_weather', arguments={'city': 'San Francisco'})], + encrypted_content='encrypted_reasoning_abc123', + usage=SimpleNamespace(prompt_tokens=20, completion_tokens=40), + ), + create_response( + content='The weather in San Francisco is sunny', + usage=SimpleNamespace(prompt_tokens=25, completion_tokens=12), + ), + ] + mock_client = MockGrok.create_mock(responses) + m = GrokModel('grok-3', client=mock_client) + agent = Agent(m) + + @agent.tool_plain + async def get_weather(city: str) -> str: + return 'sunny' + + result = await agent.run('What is the weather in San Francisco?') + assert result.output == 'The weather in San Francisco is sunny' + assert result.all_messages() == snapshot( + [ + ModelRequest( + parts=[ + UserPromptPart(content='What is the weather in San Francisco?', timestamp=IsNow(tz=timezone.utc)) + ], + run_id=IsStr(), + ), + ModelResponse( + parts=[ + ThinkingPart(content='', signature='encrypted_reasoning_abc123', provider_name='xai'), + ToolCallPart( + tool_name='get_weather', + args={'city': 'San Francisco'}, + tool_call_id='1', + ), + ], + usage=RequestUsage(input_tokens=20, output_tokens=40), + model_name='grok-3', + timestamp=IsDatetime(), + provider_name='xai', + provider_response_id='grok-123', + finish_reason='stop', + run_id=IsStr(), + ), + ModelRequest( + parts=[ + ToolReturnPart( + tool_name='get_weather', + content='sunny', + tool_call_id='1', + timestamp=IsNow(tz=timezone.utc), + ) + ], + run_id=IsStr(), + ), + ModelResponse( + parts=[TextPart(content='The weather in San Francisco is sunny')], + usage=RequestUsage(input_tokens=25, output_tokens=12), + model_name='grok-3', + timestamp=IsDatetime(), + provider_name='xai', + provider_response_id='grok-123', + finish_reason='stop', + run_id=IsStr(), + ), + ] + ) + + +async def test_grok_stream_with_reasoning(allow_model_requests: None): + """Test Grok streaming with reasoning content.""" + stream = [ + grok_reasoning_text_chunk('The answer', reasoning_content='Let me think about this...', finish_reason=''), + grok_reasoning_text_chunk(' is 4', reasoning_content='Let me think about this...', finish_reason='stop'), + ] + mock_client = MockGrok.create_mock_stream(stream) + m = GrokModel('grok-3', client=mock_client) + agent = Agent(m) + + async with agent.run_stream('What is 2+2?') as result: + assert not result.is_complete + text_chunks = [c async for c in result.stream_text(debounce_by=None)] + assert text_chunks == snapshot(['The answer', 'The answer is 4']) + assert result.is_complete + + # Verify the final response includes both reasoning and text + messages = result.all_messages() + assert len(messages) == 2 + assert isinstance(messages[1], ModelResponse) + assert len(messages[1].parts) == 2 + assert isinstance(messages[1].parts[0], ThinkingPart) + assert messages[1].parts[0].content == 'Let me think about this...' 
+ assert isinstance(messages[1].parts[1], TextPart) + assert messages[1].parts[1].content == 'The answer is 4' + + +async def test_grok_stream_with_encrypted_reasoning(allow_model_requests: None): + """Test Grok streaming with encrypted reasoning content.""" + stream = [ + grok_reasoning_text_chunk('The weather', encrypted_content='encrypted_abc123', finish_reason=''), + grok_reasoning_text_chunk(' is sunny', encrypted_content='encrypted_abc123', finish_reason='stop'), + ] + mock_client = MockGrok.create_mock_stream(stream) + m = GrokModel('grok-3', client=mock_client) + agent = Agent(m) + + async with agent.run_stream('What is the weather?') as result: + assert not result.is_complete + text_chunks = [c async for c in result.stream_text(debounce_by=None)] + assert text_chunks == snapshot(['The weather', 'The weather is sunny']) + assert result.is_complete + + # Verify the final response includes both encrypted reasoning and text + messages = result.all_messages() + assert len(messages) == 2 + assert isinstance(messages[1], ModelResponse) + assert len(messages[1].parts) == 2 + assert isinstance(messages[1].parts[0], ThinkingPart) + assert messages[1].parts[0].content == '' # No readable content for encrypted-only + assert messages[1].parts[0].signature == 'encrypted_abc123' + assert isinstance(messages[1].parts[1], TextPart) + assert messages[1].parts[1].content == 'The weather is sunny' + + +async def test_grok_usage_with_reasoning_tokens(allow_model_requests: None): + """Test that Grok model properly extracts reasoning_tokens and cache_read_tokens from usage.""" + # Create a mock usage object with reasoning_tokens and cached_prompt_text_tokens + mock_usage = SimpleNamespace( + prompt_tokens=100, + completion_tokens=50, + reasoning_tokens=25, + cached_prompt_text_tokens=30, + ) + response = create_response( + content='The answer is 42', + reasoning_content='Let me think deeply about this...', + usage=mock_usage, + ) + mock_client = MockGrok.create_mock(response) + m = GrokModel('grok-3', client=mock_client) + agent = Agent(m) + + result = await agent.run('What is the meaning of life?') + assert result.output == 'The answer is 42' + + # Verify usage includes details + usage = result.usage() + assert usage.input_tokens == 100 + assert usage.output_tokens == 50 + assert usage.total_tokens == 150 + assert usage.details == snapshot({'reasoning_tokens': 25, 'cache_read_tokens': 30}) + + +async def test_grok_usage_without_details(allow_model_requests: None): + """Test that Grok model handles usage without reasoning_tokens or cached tokens.""" + mock_usage = SimpleNamespace( + prompt_tokens=20, + completion_tokens=10, + ) + response = create_response( + content='Simple answer', + usage=mock_usage, + ) + mock_client = MockGrok.create_mock(response) + m = GrokModel('grok-3', client=mock_client) + agent = Agent(m) + + result = await agent.run('Simple question') + assert result.output == 'Simple answer' + + # Verify usage without details + usage = result.usage() + assert usage.input_tokens == 20 + assert usage.output_tokens == 10 + assert usage.total_tokens == 30 + # details should be empty dict when no additional usage info is provided + assert usage.details == snapshot({}) + + +async def test_grok_usage_with_server_side_tools(allow_model_requests: None): + """Test that Grok model properly extracts server_side_tools_used from usage.""" + # Create a mock usage object with server_side_tools_used + # Note: In the real SDK, server_side_tools_used is a repeated field (list-like), + # but we use an int in 
mocks for simplicity + mock_usage = SimpleNamespace( + prompt_tokens=50, + completion_tokens=30, + server_side_tools_used=2, + ) + response = create_response( + content='The answer based on web search', + usage=mock_usage, + ) + mock_client = MockGrok.create_mock(response) + m = GrokModel('grok-4-fast', client=mock_client) + agent = Agent(m) + + result = await agent.run('Search for something') + assert result.output == 'The answer based on web search' + + # Verify usage includes server_side_tools_used in details + usage = result.usage() + assert usage.input_tokens == 50 + assert usage.output_tokens == 30 + assert usage.total_tokens == 80 + assert usage.details == snapshot({'server_side_tools_used': 2}) + + # End of tests From 2b58f59914145c7a3ced79718cd0a158fedc6d02 Mon Sep 17 00:00:00 2001 From: Julian Bright Date: Tue, 18 Nov 2025 18:33:44 -0800 Subject: [PATCH 12/16] Update stock analytics agent to stream responses --- .../stock_analysis_agent.py | 131 +++++------------- tests/models/test_grok.py | 5 +- 2 files changed, 32 insertions(+), 104 deletions(-) diff --git a/examples/pydantic_ai_examples/stock_analysis_agent.py b/examples/pydantic_ai_examples/stock_analysis_agent.py index 65d93d065b..db894967fd 100644 --- a/examples/pydantic_ai_examples/stock_analysis_agent.py +++ b/examples/pydantic_ai_examples/stock_analysis_agent.py @@ -1,13 +1,11 @@ -"""Example of using Grok's server-side tools (web_search, code_execution) with a local function. +"""Example of using Grok's server-side web_search tool. This agent: -1. Uses web_search to find the best performing NASDAQ stock over the last week -2. Uses code_execution to project the price using linear regression -3. Calls a local function project_price with the results +1. Uses web_search to find the hottest performing stock yesterday +2. Provides buy analysis for the user """ import os -from datetime import datetime import logfire from pydantic import BaseModel, Field @@ -15,9 +13,6 @@ from pydantic_ai import ( Agent, BuiltinToolCallPart, - CodeExecutionTool, - ModelResponse, - RunContext, WebSearchTool, ) from pydantic_ai.models.grok import GrokModel @@ -35,126 +30,62 @@ model = GrokModel('grok-4-fast', api_key=xai_api_key) -class StockProjection(BaseModel): - """Projection of stock price at year end.""" +class StockAnalysis(BaseModel): + """Analysis of top performing stock.""" stock_symbol: str = Field(description='Stock ticker symbol') current_price: float = Field(description='Current stock price') - projected_price: float = Field(description='Projected price at end of year') - analysis: str = Field(description='Brief analysis of the projection') + buy_analysis: str = Field(description='Brief analysis for whether to buy the stock') -# This agent uses server-side tools to research and analyze stocks -stock_analysis_agent = Agent[None, StockProjection]( +# This agent uses server-side web search to research stocks +stock_analysis_agent = Agent[None, StockAnalysis]( model=model, - output_type=StockProjection, - builtin_tools=[ - WebSearchTool(), # Server-side web search - CodeExecutionTool(), # Server-side code execution - ], + output_type=StockAnalysis, + builtin_tools=[WebSearchTool()], system_prompt=( 'You are a stock analysis assistant. ' - 'Use web_search to find recent stock performance data on NASDAQ. ' - 'Use code_execution to perform linear regression for price projection. ' - 'After analysis, call project_price with your findings.' + 'Use web_search to find the hottest performing stock from yesterday on NASDAQ. 
' + 'Provide the current price and a brief buy analysis explaining whether this is a good buy.' ), ) -@stock_analysis_agent.tool -def project_price(ctx: RunContext[None], stock: str, price: float) -> str: - """Record the projected stock price. - - This is a local/client-side function that gets called with the analysis results. - - Args: - ctx: The run context (not used in this function) - stock: Stock ticker symbol - price: Projected price at end of year - """ - timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S') - logfire.info( - 'Stock projection recorded', - stock=stock, - projected_price=price, - timestamp=timestamp, - ) - print('\nšŸ“Š PROJECTION RECORDED:') - print(f' Stock: {stock}') - print(f' Projected End-of-Year Price: ${price:.2f}') - print(f' Timestamp: {timestamp}\n') - - return f'Projection for {stock} at ${price:.2f} has been recorded successfully.' - - async def main(): """Run the stock analysis agent.""" - query = ( - 'Can you find me the best performing stock on the NASDAQ over the last week, ' - 'and return the price project for the end of the year using a simple linear regression. ' - ) + query = 'What was the hottest performing stock on NASDAQ yesterday?' print('šŸ” Starting stock analysis...\n') print(f'Query: {query}\n') - result = await stock_analysis_agent.run(query) - - # Track which builtin tools were used - web_search_count = 0 - code_execution_count = 0 - - for message in result.all_messages(): - if isinstance(message, ModelResponse): + async with stock_analysis_agent.run_stream(query) as result: + # Stream responses as they happen + async for message, _is_last in result.stream_responses(): for part in message.parts: if isinstance(part, BuiltinToolCallPart): - if 'web_search' in part.tool_name or 'browse' in part.tool_name: - web_search_count += 1 - logfire.info( - 'Server-side web_search tool called', - tool_name=part.tool_name, - tool_call_id=part.tool_call_id, - ) - elif 'code_execution' in part.tool_name: - code_execution_count += 1 - logfire.info( - 'Server-side code_execution tool called', - tool_name=part.tool_name, - tool_call_id=part.tool_call_id, - code=part.args_as_dict().get('code', 'N/A') - if part.args - else 'N/A', - ) - - print('\nāœ… Analysis complete!') - print('\nšŸ”§ Server-Side Tools Used:') - print(f' Web Search calls: {web_search_count}') - print(f' Code Execution calls: {code_execution_count}') - - print(f'\nStock: {result.output.stock_symbol}') - print(f'Current Price: ${result.output.current_price:.2f}') - print(f'Projected Year-End Price: ${result.output.projected_price:.2f}') - print(f'\nAnalysis: {result.output.analysis}') - - # Get the final response message for metadata - print(result.all_messages()) - final_message = result.all_messages()[-1] - if isinstance(final_message, ModelResponse): - print('\nšŸ†” Response Metadata:') - if final_message.provider_response_id: - print(f' Response ID: {final_message.provider_response_id}') - if final_message.model_name: - print(f' Model: {final_message.model_name}') - if final_message.timestamp: - print(f' Timestamp: {final_message.timestamp}') + print(f'šŸ”§ Server-side tool: {part.tool_name}\n') + + # Access output after streaming is complete + output = await result.get_output() + + print('\nāœ… Analysis complete!\n') + + print(f'šŸ“Š Top Stock: {output.stock_symbol}') + print(f'šŸ’° Current Price: ${output.current_price:.2f}') + print(f'\nšŸ“ˆ Buy Analysis:\n{output.buy_analysis}') # Show usage statistics usage = result.usage() - print('\nšŸ“ˆ Usage Statistics:') + print('\nšŸ“Š 
Usage Statistics:') print(f' Requests: {usage.requests}') print(f' Input Tokens: {usage.input_tokens}') print(f' Output Tokens: {usage.output_tokens}') print(f' Total Tokens: {usage.total_tokens}') + # Show server-side tools usage if available + if usage.details and 'server_side_tools_used' in usage.details: + print(f' Server-Side Tools: {usage.details["server_side_tools_used"]}') + if __name__ == '__main__': import asyncio diff --git a/tests/models/test_grok.py b/tests/models/test_grok.py index 6629689710..2314704482 100644 --- a/tests/models/test_grok.py +++ b/tests/models/test_grok.py @@ -750,7 +750,7 @@ async def test_grok_builtin_web_search_tool(allow_model_requests: None, xai_api_ result = await agent.run('Return just the day of week for the date of Jan 1 in 2026?') assert result.output - assert result.output.lower() == 'thursday' + assert 'thursday' in result.output.lower() # Verify that server-side tools were used usage = result.usage() @@ -820,9 +820,6 @@ async def test_grok_builtin_multiple_tools(allow_model_requests: None, xai_api_k messages = result.all_messages() assert len(messages) >= 2 - # The model should use both tools (basic validation that registration works) - # TODO: Add validation for built-in tool usage once response parsing is fully tested - @pytest.mark.skipif(os.getenv('XAI_API_KEY') is None, reason='Requires XAI_API_KEY (gRPC, no cassettes)') async def test_grok_builtin_tools_with_custom_tools(allow_model_requests: None, xai_api_key: str): From 65bb6468bd50019d6aad4256d8fe6e4e3b82781e Mon Sep 17 00:00:00 2001 From: Julian Bright Date: Tue, 18 Nov 2025 18:39:20 -0800 Subject: [PATCH 13/16] Remove new line for tool calls --- examples/pydantic_ai_examples/stock_analysis_agent.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/pydantic_ai_examples/stock_analysis_agent.py b/examples/pydantic_ai_examples/stock_analysis_agent.py index db894967fd..4eacd665a2 100644 --- a/examples/pydantic_ai_examples/stock_analysis_agent.py +++ b/examples/pydantic_ai_examples/stock_analysis_agent.py @@ -63,7 +63,7 @@ async def main(): async for message, _is_last in result.stream_responses(): for part in message.parts: if isinstance(part, BuiltinToolCallPart): - print(f'šŸ”§ Server-side tool: {part.tool_name}\n') + print(f'šŸ”§ Server-side tool: {part.tool_name}') # Access output after streaming is complete output = await result.get_output() From b19479e76e3571a12a8e7d613ccd6e3848605f7c Mon Sep 17 00:00:00 2001 From: Julian Bright Date: Wed, 19 Nov 2025 12:41:47 -0800 Subject: [PATCH 14/16] Adding XAI_API_KEY to test_examples.py --- tests/test_examples.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_examples.py b/tests/test_examples.py index 85bae688d0..b9237dc551 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -180,6 +180,7 @@ def print(self, *args: Any, **kwargs: Any) -> None: env.set('DEEPSEEK_API_KEY', 'testing') env.set('OVHCLOUD_API_KEY', 'testing') env.set('PYDANTIC_AI_GATEWAY_API_KEY', 'testing') + env.set('XAI_API_KEY', 'testing') prefix_settings = example.prefix_settings() opt_test = prefix_settings.get('test', '') From 82682679aab2e53073076ab6ec399858155e0c4d Mon Sep 17 00:00:00 2001 From: Julian Bright Date: Wed, 19 Nov 2025 15:19:59 -0800 Subject: [PATCH 15/16] Adding grok-4-1-fast-non-reasoning as default --- docs/models/grok.md | 10 +-- .../pydantic_ai/models/__init__.py | 2 + pydantic_ai_slim/pydantic_ai/models/grok.py | 2 +- .../pydantic_ai/providers/grok.py | 2 + tests/models/test_grok.py | 70 
+++++++++---------- 5 files changed, 45 insertions(+), 41 deletions(-) diff --git a/docs/models/grok.md b/docs/models/grok.md index 629d2a88e6..af75cc9d50 100644 --- a/docs/models/grok.md +++ b/docs/models/grok.md @@ -10,7 +10,7 @@ pip/uv-add "pydantic-ai-slim[grok]" ## Configuration -To use Grok from [xAI](https://x.ai/api) through their API, go to [console.x.ai]https://console.x.ai) and follow your nose until you find the place to create an API key. +To use Grok from [xAI](https://x.ai/api) through their API, go your [console.x.ai](https://console.x.ai/team/default/api-keys) and follow your nose until you find the place to create an API key. ## Environment variable @@ -25,7 +25,7 @@ You can then use `GrokModel` by name: ```python from pydantic_ai import Agent -agent = Agent('grok:grok-4-fast-non-reasoning') +agent = Agent('grok:grok-4-1-fast-non-reasoning') ... ``` @@ -35,7 +35,7 @@ Or initialise the model directly with just the model name: from pydantic_ai import Agent from pydantic_ai.models.grok import GrokModel -model = GrokModel('grok-4-fast-non-reasoning') +model = GrokModel('grok-4-1-fast-non-reasoning') agent = Agent(model) ... ``` @@ -46,7 +46,7 @@ You can provide your own `api_key` inline like so: from pydantic_ai import Agent from pydantic_ai.models.grok import GrokModel -model = GrokModel('grok-4-fast-non-reasoning', api_key='your-api-key') +model = GrokModel('grok-4-1-fast-non-reasoning', api_key='your-api-key') agent = Agent(model) ... ``` @@ -60,7 +60,7 @@ async_client = AsyncClient(api_key='your-api-key') from pydantic_ai import Agent from pydantic_ai.models.grok import GrokModel -model = GrokModel('grok-4-fast-non-reasoning', client=async_client) +model = GrokModel('grok-4-1-fast-non-reasoning', client=async_client) agent = Agent(model) ... ``` diff --git a/pydantic_ai_slim/pydantic_ai/models/__init__.py b/pydantic_ai_slim/pydantic_ai/models/__init__.py index 222621db96..77776d6f24 100644 --- a/pydantic_ai_slim/pydantic_ai/models/__init__.py +++ b/pydantic_ai_slim/pydantic_ai/models/__init__.py @@ -175,6 +175,8 @@ 'grok:grok-3-mini-fast', 'grok:grok-4', 'grok:grok-4-0709', + 'grok:grok-4-1-fast-non-reasoning', + 'grok:grok-4-1-fast-reasoning', 'grok:grok-4-fast-non-reasoning', 'grok:grok-4-fast-reasoning', 'grok:grok-code-fast-1', diff --git a/pydantic_ai_slim/pydantic_ai/models/grok.py b/pydantic_ai_slim/pydantic_ai/models/grok.py index 893ae6a02b..8feff113ee 100644 --- a/pydantic_ai_slim/pydantic_ai/models/grok.py +++ b/pydantic_ai_slim/pydantic_ai/models/grok.py @@ -63,7 +63,7 @@ def __init__( """Initialize the Grok model. Args: - model_name: The name of the Grok model to use (e.g., "grok-3", "grok-4-fast-non-reasoning") + model_name: The name of the Grok model to use (e.g., "grok-4-1-fast-non-reasoning") api_key: The xAI API key. If not provided, uses XAI_API_KEY environment variable. client: Optional AsyncClient instance for testing. If provided, api_key is ignored. settings: Optional model settings. 
diff --git a/pydantic_ai_slim/pydantic_ai/providers/grok.py b/pydantic_ai_slim/pydantic_ai/providers/grok.py index 1970c9696d..5973b41c1b 100644 --- a/pydantic_ai_slim/pydantic_ai/providers/grok.py +++ b/pydantic_ai_slim/pydantic_ai/providers/grok.py @@ -25,6 +25,8 @@ GrokModelName = Literal[ 'grok-4', 'grok-4-0709', + 'grok-4-1-fast-reasoning', + 'grok-4-1-fast-non-reasoning', 'grok-4-fast-reasoning', 'grok-4-fast-non-reasoning', 'grok-code-fast-1', diff --git a/tests/models/test_grok.py b/tests/models/test_grok.py index 2314704482..c62eef0cd2 100644 --- a/tests/models/test_grok.py +++ b/tests/models/test_grok.py @@ -64,16 +64,16 @@ def test_grok_init(): - m = GrokModel('grok-4-fast-non-reasoning', api_key='foobar') + m = GrokModel('grok-4-1-fast-non-reasoning', api_key='foobar') # Check model properties without accessing private attributes - assert m.model_name == 'grok-4-fast-non-reasoning' + assert m.model_name == 'grok-4-1-fast-non-reasoning' assert m.system == 'xai' async def test_grok_request_simple_success(allow_model_requests: None): response = create_response(content='world') mock_client = MockGrok.create_mock(response) - m = GrokModel('grok-4-fast-non-reasoning', client=mock_client) + m = GrokModel('grok-4-1-fast-non-reasoning', client=mock_client) agent = Agent(m) result = await agent.run('hello') @@ -94,7 +94,7 @@ async def test_grok_request_simple_success(allow_model_requests: None): ), ModelResponse( parts=[TextPart(content='world')], - model_name='grok-4-fast-non-reasoning', + model_name='grok-4-1-fast-non-reasoning', timestamp=IsDatetime(), provider_name='xai', provider_response_id='grok-123', @@ -107,7 +107,7 @@ async def test_grok_request_simple_success(allow_model_requests: None): ), ModelResponse( parts=[TextPart(content='world')], - model_name='grok-4-fast-non-reasoning', + model_name='grok-4-1-fast-non-reasoning', timestamp=IsDatetime(), provider_name='xai', provider_response_id='grok-123', @@ -124,7 +124,7 @@ async def test_grok_request_simple_usage(allow_model_requests: None): usage=SimpleNamespace(prompt_tokens=2, completion_tokens=1), ) mock_client = MockGrok.create_mock(response) - m = GrokModel('grok-4-fast-non-reasoning', client=mock_client) + m = GrokModel('grok-4-1-fast-non-reasoning', client=mock_client) agent = Agent(m) result = await agent.run('Hello') @@ -142,7 +142,7 @@ async def test_grok_image_input(allow_model_requests: None): """Test that Grok model handles image inputs (text is extracted from content).""" response = create_response(content='done') mock_client = MockGrok.create_mock(response) - model = GrokModel('grok-4-fast-non-reasoning', client=mock_client) + model = GrokModel('grok-4-1-fast-non-reasoning', client=mock_client) agent = Agent(model) image_url = ImageUrl('https://example.com/image.png') @@ -160,7 +160,7 @@ async def test_grok_request_structured_response(allow_model_requests: None): ) response = create_response(tool_calls=[tool_call]) mock_client = MockGrok.create_mock(response) - m = GrokModel('grok-4-fast-non-reasoning', client=mock_client) + m = GrokModel('grok-4-1-fast-non-reasoning', client=mock_client) agent = Agent(m, output_type=list[int]) result = await agent.run('Hello') @@ -179,7 +179,7 @@ async def test_grok_request_structured_response(allow_model_requests: None): tool_call_id='123', ) ], - model_name='grok-4-fast-non-reasoning', + model_name='grok-4-1-fast-non-reasoning', timestamp=IsDatetime(), provider_name='xai', provider_response_id='grok-123', @@ -214,7 +214,7 @@ async def 
test_grok_request_tool_call(allow_model_requests: None): create_response(content='final response'), ] mock_client = MockGrok.create_mock(responses) - m = GrokModel('grok-4-fast-non-reasoning', client=mock_client) + m = GrokModel('grok-4-1-fast-non-reasoning', client=mock_client) agent = Agent(m, system_prompt='this is the system prompt') @agent.tool_plain @@ -247,7 +247,7 @@ async def get_location(loc_name: str) -> str: input_tokens=2, output_tokens=1, ), - model_name='grok-4-fast-non-reasoning', + model_name='grok-4-1-fast-non-reasoning', timestamp=IsDatetime(), provider_name='xai', provider_response_id='grok-123', @@ -277,7 +277,7 @@ async def get_location(loc_name: str) -> str: input_tokens=3, output_tokens=2, ), - model_name='grok-4-fast-non-reasoning', + model_name='grok-4-1-fast-non-reasoning', timestamp=IsDatetime(), provider_name='xai', provider_response_id='grok-123', @@ -297,7 +297,7 @@ async def get_location(loc_name: str) -> str: ), ModelResponse( parts=[TextPart(content='final response')], - model_name='grok-4-fast-non-reasoning', + model_name='grok-4-1-fast-non-reasoning', timestamp=IsDatetime(), provider_name='xai', provider_response_id='grok-123', @@ -369,7 +369,7 @@ def grok_reasoning_text_chunk( async def test_grok_stream_text(allow_model_requests: None): stream = [grok_text_chunk('hello '), grok_text_chunk('world')] mock_client = MockGrok.create_mock_stream(stream) - m = GrokModel('grok-4-fast-non-reasoning', client=mock_client) + m = GrokModel('grok-4-1-fast-non-reasoning', client=mock_client) agent = Agent(m) async with agent.run_stream('') as result: @@ -387,7 +387,7 @@ async def test_grok_stream_text_finish_reason(allow_model_requests: None): grok_text_chunk('.', 'stop'), ] mock_client = MockGrok.create_mock_stream(stream) - m = GrokModel('grok-4-fast-non-reasoning', client=mock_client) + m = GrokModel('grok-4-1-fast-non-reasoning', client=mock_client) agent = Agent(m) async with agent.run_stream('') as result: @@ -402,7 +402,7 @@ async def test_grok_stream_text_finish_reason(allow_model_requests: None): ModelResponse( parts=[TextPart(content='hello world.')], usage=RequestUsage(input_tokens=2, output_tokens=1), - model_name='grok-4-fast-non-reasoning', + model_name='grok-4-1-fast-non-reasoning', timestamp=IsDatetime(), provider_name='xai', provider_response_id='grok-123', @@ -479,7 +479,7 @@ async def test_grok_stream_structured(allow_model_requests: None): grok_tool_chunk(None, '}', finish_reason='stop', accumulated_args='{"first": "One", "second": "Two"}'), ] mock_client = MockGrok.create_mock_stream(stream) - m = GrokModel('grok-4-fast-non-reasoning', client=mock_client) + m = GrokModel('grok-4-1-fast-non-reasoning', client=mock_client) agent = Agent(m, output_type=MyTypedDict) async with agent.run_stream('') as result: @@ -500,7 +500,7 @@ async def test_grok_stream_structured_finish_reason(allow_model_requests: None): grok_tool_chunk(None, None, finish_reason='stop', accumulated_args='{"first": "One", "second": "Two"}'), ] mock_client = MockGrok.create_mock_stream(stream) - m = GrokModel('grok-4-fast-non-reasoning', client=mock_client) + m = GrokModel('grok-4-1-fast-non-reasoning', client=mock_client) agent = Agent(m, output_type=MyTypedDict) async with agent.run_stream('') as result: @@ -518,7 +518,7 @@ async def test_grok_stream_native_output(allow_model_requests: None): grok_text_chunk('}'), ] mock_client = MockGrok.create_mock_stream(stream) - m = GrokModel('grok-4-fast-non-reasoning', client=mock_client) + m = GrokModel('grok-4-1-fast-non-reasoning', 
client=mock_client) agent = Agent(m, output_type=NativeOutput(MyTypedDict)) async with agent.run_stream('') as result: @@ -537,7 +537,7 @@ async def test_grok_stream_tool_call_with_empty_text(allow_model_requests: None) grok_tool_chunk(None, '}', finish_reason='stop', accumulated_args='{"first": "One", "second": "Two"}'), ] mock_client = MockGrok.create_mock_stream(stream) - m = GrokModel('grok-4-fast-non-reasoning', client=mock_client) + m = GrokModel('grok-4-1-fast-non-reasoning', client=mock_client) agent = Agent(m, output_type=[str, MyTypedDict]) async with agent.run_stream('') as result: @@ -554,7 +554,7 @@ async def test_grok_no_delta(allow_model_requests: None): grok_text_chunk('world'), ] mock_client = MockGrok.create_mock_stream(stream) - m = GrokModel('grok-4-fast-non-reasoning', client=mock_client) + m = GrokModel('grok-4-1-fast-non-reasoning', client=mock_client) agent = Agent(m) async with agent.run_stream('') as result: @@ -571,7 +571,7 @@ async def test_grok_none_delta(allow_model_requests: None): grok_text_chunk('world'), ] mock_client = MockGrok.create_mock_stream(stream) - m = GrokModel('grok-4-fast-non-reasoning', client=mock_client) + m = GrokModel('grok-4-1-fast-non-reasoning', client=mock_client) agent = Agent(m) async with agent.run_stream('') as result: @@ -597,7 +597,7 @@ async def test_grok_parallel_tool_calls(allow_model_requests: None, parallel_too ) response = create_response(content='', tool_calls=[tool_call], finish_reason='tool_calls') mock_client = MockGrok.create_mock(response) - m = GrokModel('grok-4-fast-non-reasoning', client=mock_client) + m = GrokModel('grok-4-1-fast-non-reasoning', client=mock_client) agent = Agent(m, output_type=list[int], model_settings=ModelSettings(parallel_tool_calls=parallel_tool_calls)) await agent.run('Hello') @@ -607,7 +607,7 @@ async def test_grok_parallel_tool_calls(allow_model_requests: None, parallel_too async def test_grok_penalty_parameters(allow_model_requests: None) -> None: response = create_response(content='test response') mock_client = MockGrok.create_mock(response) - m = GrokModel('grok-4-fast-non-reasoning', client=mock_client) + m = GrokModel('grok-4-1-fast-non-reasoning', client=mock_client) settings = ModelSettings( temperature=0.7, @@ -631,7 +631,7 @@ async def test_grok_penalty_parameters(allow_model_requests: None) -> None: async def test_grok_image_url_input(allow_model_requests: None): response = create_response(content='world') mock_client = MockGrok.create_mock(response) - m = GrokModel('grok-4-fast-non-reasoning', client=mock_client) + m = GrokModel('grok-4-1-fast-non-reasoning', client=mock_client) agent = Agent(m) result = await agent.run( @@ -647,7 +647,7 @@ async def test_grok_image_url_input(allow_model_requests: None): @pytest.mark.skipif(os.getenv('XAI_API_KEY') is None, reason='Requires XAI_API_KEY (gRPC, no cassettes)') async def test_grok_image_url_tool_response(allow_model_requests: None, xai_api_key: str): - m = GrokModel('grok-4-fast-non-reasoning', api_key=xai_api_key) + m = GrokModel('grok-4-1-fast-non-reasoning', api_key=xai_api_key) agent = Agent(m) @agent.tool_plain @@ -684,7 +684,7 @@ async def get_image() -> ImageUrl: async def test_grok_image_as_binary_content_tool_response( allow_model_requests: None, image_content: BinaryContent, xai_api_key: str ): - m = GrokModel('grok-4-fast-non-reasoning', api_key=xai_api_key) + m = GrokModel('grok-4-1-fast-non-reasoning', api_key=xai_api_key) agent = Agent(m) @agent.tool_plain @@ -724,7 +724,7 @@ async def 
test_grok_image_as_binary_content_input(
     allow_model_requests: None, image_content: BinaryContent, xai_api_key: str
 ):
     """Test passing binary image content directly as input (not from a tool)."""
-    m = GrokModel('grok-4-fast-non-reasoning', api_key=xai_api_key)
+    m = GrokModel('grok-4-1-fast-non-reasoning', api_key=xai_api_key)
     agent = Agent(m)
 
     result = await agent.run(['What fruit is in the image?', image_content])
@@ -923,7 +923,7 @@ async def test_grok_model_retries(allow_model_requests: None):
     success_response = create_response(content='Success after retry')
 
     mock_client = MockGrok.create_mock(success_response)
-    m = GrokModel('grok-4-fast-non-reasoning', client=mock_client)
+    m = GrokModel('grok-4-1-fast-non-reasoning', client=mock_client)
     agent = Agent(m)
     result = await agent.run('hello')
     assert result.output == 'Success after retry'
@@ -933,7 +933,7 @@ async def test_grok_model_settings(allow_model_requests: None):
     """Test Grok model with various settings."""
     response = create_response(content='response with settings')
     mock_client = MockGrok.create_mock(response)
-    m = GrokModel('grok-4-fast-non-reasoning', client=mock_client)
+    m = GrokModel('grok-4-1-fast-non-reasoning', client=mock_client)
     agent = Agent(
         m,
         model_settings=ModelSettings(
@@ -965,7 +965,7 @@ async def test_grok_model_multiple_tool_calls(allow_model_requests: None):
     ]
 
     mock_client = MockGrok.create_mock(responses)
-    m = GrokModel('grok-4-fast-non-reasoning', client=mock_client)
+    m = GrokModel('grok-4-1-fast-non-reasoning', client=mock_client)
     agent = Agent(m)
 
     @agent.tool_plain
@@ -995,7 +995,7 @@ async def test_grok_stream_with_tool_calls(allow_model_requests: None):
     ]
 
     mock_client = MockGrok.create_mock_stream([stream1, stream2])
-    m = GrokModel('grok-4-fast-non-reasoning', client=mock_client)
+    m = GrokModel('grok-4-1-fast-non-reasoning', client=mock_client)
     agent = Agent(m)
 
     @agent.tool_plain
@@ -1017,14 +1017,14 @@ async def get_info(query: str) -> str:
 async def test_grok_model_invalid_api_key():
     """Test Grok model with invalid API key."""
     with pytest.raises(ValueError, match='XAI API key is required'):
-        GrokModel('grok-4-fast-non-reasoning', api_key='')
+        GrokModel('grok-4-1-fast-non-reasoning', api_key='')
 
 
 async def test_grok_model_properties():
     """Test Grok model properties."""
-    m = GrokModel('grok-4-fast-non-reasoning', api_key='test-key')
+    m = GrokModel('grok-4-1-fast-non-reasoning', api_key='test-key')
 
-    assert m.model_name == 'grok-4-fast-non-reasoning'
+    assert m.model_name == 'grok-4-1-fast-non-reasoning'
     assert m.system == 'xai'


From 843f588cbb0c1d120b93f0650c6cf8eb74d89100 Mon Sep 17 00:00:00 2001
From: Julian Bright
Date: Wed, 19 Nov 2025 15:23:27 -0800
Subject: [PATCH 16/16] Update to mention GrokModelName in docs

---
 docs/models/grok.md | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/docs/models/grok.md b/docs/models/grok.md
index af75cc9d50..24edcb0540 100644
--- a/docs/models/grok.md
+++ b/docs/models/grok.md
@@ -12,6 +12,8 @@ pip/uv-add "pydantic-ai-slim[grok]"
 
 To use Grok from [xAI](https://x.ai/api) through their API, go to your [console.x.ai](https://console.x.ai/team/default/api-keys) and follow your nose until you find the place to create an API key.
 
+`GrokModelName` contains a list of available Grok models.
+
 ## Environment variable
 
 Once you have the API key, you can set it as an environment variable:
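
Since this final patch points readers at `GrokModelName`, here is a short sketch of how that `Literal` (extended in `pydantic_ai/providers/grok.py` earlier in the series) gives type-checked model selection. Only the names visible in this series are listed; the real type contains more entries:

```python
from typing import Literal

# Subset of GrokModelName as it appears in this patch series; the full type
# lives in pydantic_ai.providers.grok.
GrokModelName = Literal[
    'grok-4',
    'grok-4-0709',
    'grok-4-1-fast-reasoning',
    'grok-4-1-fast-non-reasoning',
    'grok-4-fast-reasoning',
    'grok-4-fast-non-reasoning',
    'grok-code-fast-1',
]


def known_model_name(model: GrokModelName) -> str:
    # Mirrors the 'grok:' entries added to KnownModelName in
    # pydantic_ai/models/__init__.py, e.g. 'grok:grok-4-1-fast-non-reasoning'.
    return f'grok:{model}'


# A static type checker rejects typos such as known_model_name('grok-5'),
# while known_model_name('grok-4-1-fast-reasoning') passes.
```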