|
4 | 4 |
|
5 | 5 | import logging |
6 | 6 | import os |
7 | | -from typing import Any, Optional |
| 7 | +from typing import Any, Dict, Mapping, Optional, Sequence |
8 | 8 |
|
9 | 9 | import anthropic |
10 | 10 | import openai |
11 | | -from autogen_core.models import ModelInfo |
| 11 | +from autogen_core.models import ( |
| 12 | + ChatCompletionClient, |
| 13 | + ModelInfo, |
| 14 | + SystemMessage, |
| 15 | + UserMessage, |
| 16 | +) |
12 | 17 | from autogen_ext.models.anthropic import AnthropicChatCompletionClient |
13 | 18 | from autogen_ext.models.openai import OpenAIChatCompletionClient |
14 | 19 | from tenacity import ( |
|
28 | 33 |
|
29 | 34 |
|
30 | 35 | class RetryableModelClient: |
31 | | - """Wrapper that adds retry logic to any model client.""" |
| 36 | + """Wrap a client and retry `create` on transient API errors.""" |
32 | 37 |
|
33 | 38 | def __init__(self, client: Any, max_retries: int = 3): |
34 | 39 | self.client = client |
@@ -60,7 +65,7 @@ def __getattr__(self, name: str) -> Any: |
60 | 65 |
|
61 | 66 |
|
62 | 67 | def get_model_client(model_name: str, seed: Optional[int] = None, **kwargs: Any) -> Any: |
63 | | - """Get a model client for the given model name.""" |
| 68 | + """Legacy factory: return a retry-wrapped client for `model_name`.""" |
64 | 69 | n = model_name.lower() |
65 | 70 |
|
66 | 71 | if n.startswith(("gpt-", "o1-", "o3-", "gpt-5")): |
@@ -104,3 +109,267 @@ def get_model_client(model_name: str, seed: Optional[int] = None, **kwargs: Any) |
104 | 109 | return RetryableModelClient(client) |
105 | 110 |
|
106 | 111 | raise ValueError(f"Unsupported model '{model_name}'.") |
| 112 | + |
| 113 | + |
| 114 | +def get_standard_model_client( |
| 115 | + model_name: str, |
| 116 | + *, |
| 117 | + seed: Optional[int] = None, |
| 118 | + **kwargs: Any, |
| 119 | +) -> ChatCompletionClient: |
| 120 | + """Build a plain client for use with `async_call_model`.""" |
| 121 | + n = model_name.lower() |
| 122 | + |
| 123 | + # OpenAI GPT / o-series models |
| 124 | + if n.startswith(("gpt-", "o1-", "o3-", "gpt-5")): |
| 125 | + return OpenAIChatCompletionClient(model=model_name, seed=seed, **kwargs) |
| 126 | + |
| 127 | + # Anthropic Claude models |
| 128 | + if "claude" in n: |
| 129 | + kwargs.setdefault("timeout", None) |
| 130 | + return AnthropicChatCompletionClient(model=model_name, **kwargs) |
| 131 | + |
| 132 | + # Gemini via OpenAI-compatible AI Studio endpoint |
| 133 | + if "gemini" in n: |
| 134 | + api_key = kwargs.pop("api_key", os.getenv("GOOGLE_API_KEY")) |
| 135 | + if not api_key: |
| 136 | + raise ValueError("Set GOOGLE_API_KEY for Gemini (AI Studio).") |
| 137 | + |
| 138 | + model_info = kwargs.pop( |
| 139 | + "model_info", |
| 140 | + ModelInfo( |
| 141 | + vision=True, |
| 142 | + function_calling=True, |
| 143 | + json_output=True, |
| 144 | + structured_output=True, |
| 145 | + family="unknown", |
| 146 | + ), |
| 147 | + ) |
| 148 | + |
| 149 | + return OpenAIChatCompletionClient( |
| 150 | + model=model_name, |
| 151 | + base_url=GEMINI_STUDIO_BASE, |
| 152 | + api_key=api_key, |
| 153 | + model_info=model_info, |
| 154 | + **kwargs, |
| 155 | + ) |
| 156 | + |
| 157 | + raise ValueError(f"Unsupported model '{model_name}'.") |
| 158 | + |
| 159 | + |
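A minimal usage sketch for the new factory, assuming the relevant provider keys (OPENAI_API_KEY, ANTHROPIC_API_KEY, GOOGLE_API_KEY) are set in the environment; the model names below are illustrative placeholders:

# Sketch: build plain clients for each supported provider family.
openai_client = get_standard_model_client("gpt-4o", seed=42)
claude_client = get_standard_model_client("claude-3-5-sonnet-latest")
gemini_client = get_standard_model_client("gemini-1.5-flash")

# Anything else falls through to the final ValueError.
try:
    get_standard_model_client("llama-3-70b")
except ValueError as err:
    print(err)  # Unsupported model 'llama-3-70b'.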
| 160 | +class ModelCallError(RuntimeError): |
| 161 | + """Error raised when a standardized model call fails.""" |
| 162 | + |
| 163 | + |
| 164 | +class ModelCallMode: |
| 165 | + """Output modes for `async_call_model`.""" |
| 166 | + |
| 167 | + TEXT = "text" |
| 168 | + JSON_PARSE = "json_parse" |
| 169 | + STRUCTURED = "structured" |
| 170 | + |
| 171 | + |
| 172 | +async def async_call_model( |
| 173 | + model_client: ChatCompletionClient, |
| 174 | + *, |
| 175 | + model_name: Optional[str] = None, |
| 176 | + system_prompt: Optional[str] = None, |
| 177 | + user_prompt: Optional[str] = None, |
| 178 | + messages: Optional[Sequence[Any]] = None, |
| 179 | + mode: str = ModelCallMode.TEXT, |
| 180 | + temperature: Optional[float] = None, |
| 181 | + max_tokens: Optional[int] = None, |
| 182 | + top_p: Optional[float] = None, |
| 183 | + seed: Optional[int] = None, |
| 184 | + max_attempts: int = 3, |
| 185 | + extra_kwargs: Optional[Mapping[str, Any]] = None, |
| 186 | +) -> Any: |
| 187 | + """Perform a standard async model call with provider-aware args and output modes. |
| 188 | +
|
| 189 | + - Builds messages from prompts if `messages` is None. |
| 190 | + - Maps `temperature`, `max_tokens`, `top_p`, `seed` to the right provider kwargs. |
| 191 | + - `mode`: |
| 192 | + - TEXT: return `str` content. |
| 193 | + - JSON_PARSE: parse JSON and return `dict`. |
| 194 | + - STRUCTURED: return the raw provider response. |
| 195 | +    - Retries on empty content, JSON parse failures, or a rejected 'temperature' |
| 196 | +      kwarg; other errors raise `ModelCallError` immediately. |
| 197 | + """ |
| 198 | + # Try to infer model name if not provided explicitly. |
| 199 | + resolved_model_name: Optional[str] = model_name |
| 200 | + if resolved_model_name is None: |
| 201 | + underlying = getattr(model_client, "client", model_client) |
| 202 | + resolved_model_name = getattr(underlying, "model", None) |
| 203 | + |
| 204 | + # Identify provider family from the model name. |
| 205 | + provider: Optional[str] = None |
| 206 | + lowered_name = ( |
| 207 | + resolved_model_name.lower() if isinstance(resolved_model_name, str) else "" |
| 208 | + ) |
| 209 | + if lowered_name.startswith(("gpt-", "o1-", "o3-", "gpt-5")): |
| 210 | + provider = "openai" |
| 211 | + elif "claude" in lowered_name: |
| 212 | + provider = "anthropic" |
| 213 | + elif "gemini" in lowered_name: |
| 214 | + provider = "gemini" |
| 215 | + |
| 216 | + if messages is None: |
| 217 | + if user_prompt is None and system_prompt is None: |
| 218 | + raise ValueError( |
| 219 | + "Either 'messages' or at least one of 'system_prompt' / 'user_prompt' " |
| 220 | + "must be provided." |
| 221 | + ) |
| 222 | + |
| 223 | + built_messages: list[Any] = [] |
| 224 | + if system_prompt: |
| 225 | + built_messages.append(SystemMessage(content=system_prompt)) |
| 226 | + if user_prompt: |
| 227 | + built_messages.append(UserMessage(content=user_prompt, source="user")) |
| 228 | + messages = built_messages |
| 229 | + |
| 230 | + if max_attempts < 1: |
| 231 | + raise ValueError("max_attempts must be at least 1") |
| 232 | + |
| 233 | + last_error: Exception | None = None |
| 234 | + drop_temperature_for_model = False |
| 235 | + |
| 236 | + for attempt in range(1, max_attempts + 1): |
| 237 | + request_kwargs: Dict[str, Any] = {} |
| 238 | + |
| 239 | + if temperature is not None and not drop_temperature_for_model: |
| 240 | + if provider == "openai" and lowered_name: |
| 241 | +                # Reasoning models ("o1", "o3", "o3-mini", "o4-mini") often |
| 242 | +                # reject or ignore 'temperature', so skip sending it for them. |
| 243 | + if any( |
| 244 | + key in lowered_name for key in ("o1", "o3-mini", "o3", "o4-mini") |
| 245 | + ): |
| 246 | + logger.debug( |
| 247 | + "Not sending 'temperature' for model '%s' due to known " |
| 248 | + "limitations.", |
| 249 | + resolved_model_name, |
| 250 | + ) |
| 251 | + else: |
| 252 | + request_kwargs["temperature"] = temperature |
| 253 | + elif provider in {"anthropic", "gemini", None}: |
| 254 | + # Anthropic Claude and Gemini generally support temperature; |
| 255 | + # for unknown providers we optimistically pass it through. |
| 256 | + request_kwargs["temperature"] = temperature |
| 257 | + |
| 258 | + # Map unified `max_tokens` to provider-specific kwarg. |
| 259 | + if max_tokens is not None: |
| 260 | + if provider in {"openai", "gemini"}: |
| 261 | + request_kwargs["max_completion_tokens"] = max_tokens |
| 262 | + elif provider == "anthropic": |
| 263 | + request_kwargs["max_tokens"] = max_tokens |
| 264 | + else: |
| 265 | + request_kwargs["max_tokens"] = max_tokens |
| 266 | + |
| 267 | +        # `top_p` for OpenAI-style providers; unknown providers get it optimistically. |
| 268 | + if top_p is not None and provider in {"openai", "gemini", None}: |
| 269 | + request_kwargs["top_p"] = top_p |
| 270 | + if seed is not None: |
| 271 | + request_kwargs["seed"] = seed |
| 272 | + |
| 273 | + # Output / structured config |
| 274 | + if mode in (ModelCallMode.JSON_PARSE, ModelCallMode.STRUCTURED): |
| 275 | + # Many clients support json_output / structured_output flags. |
| 276 | + # Some may ignore these silently; others might raise if unsupported. |
| 277 | + request_kwargs.setdefault("json_output", True) |
| 278 | + if mode == ModelCallMode.STRUCTURED: |
| 279 | + request_kwargs.setdefault("structured_output", True) |
| 280 | + |
| 281 | + # Extra kwargs always win |
| 282 | + if extra_kwargs: |
| 283 | + request_kwargs.update(extra_kwargs) |
| 284 | + |
| 285 | + try: |
| 286 | + response = await model_client.create( |
| 287 | + messages=list(messages), # type: ignore[arg-type] |
| 288 | + **request_kwargs, |
| 289 | + ) |
| 290 | + except TypeError as exc: |
| 291 | + # Some models (e.g., certain reasoning or o-series models) do not |
| 292 | + # support temperature or other generation parameters. If the error |
| 293 | +            # message clearly points to 'temperature', drop it for the remaining attempts. |
| 294 | + if ( |
| 295 | + "temperature" in str(exc) |
| 296 | + and "temperature" in request_kwargs |
| 297 | + and not drop_temperature_for_model |
| 298 | + ): |
| 299 | + logger.warning( |
| 300 | + "Model rejected 'temperature' parameter; retrying without it. " |
| 301 | + "Error was: %s", |
| 302 | + exc, |
| 303 | + ) |
| 304 | + drop_temperature_for_model = True |
| 305 | + last_error = exc |
| 306 | + continue |
| 307 | + last_error = exc |
| 308 | + logger.error("Model call failed with TypeError: %s", exc) |
| 309 | + break |
| 310 | + except Exception as exc: # pragma: no cover - network/SDK errors |
| 311 | + # Let lower-level client / infrastructure handle any network or |
| 312 | + # transient retries. At this layer we convert to ModelCallError |
| 313 | + # without additional retry loops to avoid duplicating behaviour. |
| 314 | + logger.error("Model call failed with unexpected error: %s", exc) |
| 315 | + last_error = exc |
| 316 | + break |
| 317 | + |
| 318 | + # Extract content in a provider-agnostic way. |
| 319 | + content = getattr(response, "content", None) |
| 320 | + if content is None: |
| 321 | + last_error = ModelCallError("Model returned empty response content") |
| 322 | + logger.warning( |
| 323 | + "Empty response content on attempt %d/%d", attempt, max_attempts |
| 324 | + ) |
| 325 | + if attempt < max_attempts: |
| 326 | + continue |
| 327 | + break |
| 328 | + |
| 329 | + # Normalize to string for text / JSON modes. |
| 330 | + if isinstance(content, (list, tuple)): |
| 331 | + content_str = "\n".join(str(part) for part in content) |
| 332 | + else: |
| 333 | + content_str = str(content) |
| 334 | + |
| 335 | + content_str = content_str.strip() |
| 336 | + if not content_str: |
| 337 | + last_error = ModelCallError("Model returned empty response content") |
| 338 | + logger.warning( |
| 339 | + "Blank response content on attempt %d/%d", attempt, max_attempts |
| 340 | + ) |
| 341 | + if attempt < max_attempts: |
| 342 | + continue |
| 343 | + break |
| 344 | + |
| 345 | + if mode == ModelCallMode.TEXT: |
| 346 | + return content_str |
| 347 | + |
| 348 | + if mode == ModelCallMode.JSON_PARSE: |
| 349 | + import json |
| 350 | + |
| 351 | + try: |
| 352 | + return json.loads(content_str) |
| 353 | + except Exception as exc: # pragma: no cover - JSON edge cases |
| 354 | + last_error = ModelCallError( |
| 355 | + f"Failed to parse JSON from model response: {exc}" |
| 356 | + ) |
| 357 | + logger.warning( |
| 358 | + "JSON parsing failed on attempt %d/%d: %s", |
| 359 | + attempt, |
| 360 | + max_attempts, |
| 361 | + exc, |
| 362 | + ) |
| 363 | + if attempt < max_attempts: |
| 364 | + continue |
| 365 | + break |
| 366 | + |
| 367 | + # STRUCTURED mode: return provider object as-is to the caller. |
| 368 | + return response |
| 369 | + |
| 370 | + # If we get here, all attempts failed. |
| 371 | + if last_error is None: |
| 372 | + raise ModelCallError("Model call failed for unknown reasons") |
| 373 | + if isinstance(last_error, ModelCallError): |
| 374 | + raise last_error |
| 375 | + raise ModelCallError(f"Model call failed: {last_error}") from last_error |
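
A hedged end-to-end sketch of `async_call_model` in TEXT and JSON_PARSE modes; the prompts, model name, and generation settings are placeholders, and the call assumes the wrapped client accepts the mapped generation kwargs:

import asyncio


async def _demo() -> None:
    # Placeholder model; any name handled by get_standard_model_client works here.
    client = get_standard_model_client("gpt-4o-mini", seed=7)

    # Plain-text completion: returns the stripped string content.
    text = await async_call_model(
        client,
        system_prompt="You are a terse assistant.",
        user_prompt="Name one prime number.",
        mode=ModelCallMode.TEXT,
        temperature=0.0,
        max_tokens=32,
    )
    print(text)

    # JSON mode: the content is parsed with json.loads and returned as a dict.
    data = await async_call_model(
        client,
        user_prompt='Reply with the JSON object {"ok": true}.',
        mode=ModelCallMode.JSON_PARSE,
        max_attempts=2,
    )
    print(data.get("ok"))  # True, if the model complied.


if __name__ == "__main__":
    asyncio.run(_demo())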