Commit ad1dc0a

Merge pull request #46 from VectorInstitute/std_model_call
Add standardized async model call helper and separate new client factory.
2 parents: bd9fc33 + 2bb81e0

1 file changed (+273, -4)

src/utils/model_client_utils.py

@@ -4,11 +4,16 @@
 
 import logging
 import os
-from typing import Any, Optional
+from typing import Any, Dict, Mapping, Optional, Sequence
 
 import anthropic
 import openai
-from autogen_core.models import ModelInfo
+from autogen_core.models import (
+    ChatCompletionClient,
+    ModelInfo,
+    SystemMessage,
+    UserMessage,
+)
 from autogen_ext.models.anthropic import AnthropicChatCompletionClient
 from autogen_ext.models.openai import OpenAIChatCompletionClient
 from tenacity import (
@@ -28,7 +33,7 @@
 
 
 class RetryableModelClient:
-    """Wrapper that adds retry logic to any model client."""
+    """Wrap a client and retry `create` on transient API errors."""
 
     def __init__(self, client: Any, max_retries: int = 3):
         self.client = client
@@ -60,7 +65,7 @@ def __getattr__(self, name: str) -> Any:
 
 
 def get_model_client(model_name: str, seed: Optional[int] = None, **kwargs: Any) -> Any:
-    """Get a model client for the given model name."""
+    """Legacy factory: return a retry-wrapped client for `model_name`."""
     n = model_name.lower()
 
     if n.startswith(("gpt-", "o1-", "o3-", "gpt-5")):
@@ -104,3 +109,267 @@ def get_model_client(model_name: str, seed: Optional[int] = None, **kwargs: Any)
         return RetryableModelClient(client)
 
     raise ValueError(f"Unsupported model '{model_name}'.")
+
+
+def get_standard_model_client(
+    model_name: str,
+    *,
+    seed: Optional[int] = None,
+    **kwargs: Any,
+) -> ChatCompletionClient:
+    """Build a plain client for use with `async_call_model`."""
+    n = model_name.lower()
+
+    # OpenAI GPT / o-series models
+    if n.startswith(("gpt-", "o1-", "o3-", "gpt-5")):
+        return OpenAIChatCompletionClient(model=model_name, seed=seed, **kwargs)
+
+    # Anthropic Claude models
+    if "claude" in n:
+        kwargs.setdefault("timeout", None)
+        return AnthropicChatCompletionClient(model=model_name, **kwargs)
+
+    # Gemini via OpenAI-compatible AI Studio endpoint
+    if "gemini" in n:
+        api_key = kwargs.pop("api_key", os.getenv("GOOGLE_API_KEY"))
+        if not api_key:
+            raise ValueError("Set GOOGLE_API_KEY for Gemini (AI Studio).")
+
+        model_info = kwargs.pop(
+            "model_info",
+            ModelInfo(
+                vision=True,
+                function_calling=True,
+                json_output=True,
+                structured_output=True,
+                family="unknown",
+            ),
+        )
+
+        return OpenAIChatCompletionClient(
+            model=model_name,
+            base_url=GEMINI_STUDIO_BASE,
+            api_key=api_key,
+            model_info=model_info,
+            **kwargs,
+        )
+
+    raise ValueError(f"Unsupported model '{model_name}'.")
+
+
+class ModelCallError(RuntimeError):
+    """Error raised when a standardized model call fails."""
+
+
+class ModelCallMode:
+    """Output modes for `async_call_model`."""
+
+    TEXT = "text"
+    JSON_PARSE = "json_parse"
+    STRUCTURED = "structured"
+
+
+async def async_call_model(
+    model_client: ChatCompletionClient,
+    *,
+    model_name: Optional[str] = None,
+    system_prompt: Optional[str] = None,
+    user_prompt: Optional[str] = None,
+    messages: Optional[Sequence[Any]] = None,
+    mode: str = ModelCallMode.TEXT,
+    temperature: Optional[float] = None,
+    max_tokens: Optional[int] = None,
+    top_p: Optional[float] = None,
+    seed: Optional[int] = None,
+    max_attempts: int = 3,
+    extra_kwargs: Optional[Mapping[str, Any]] = None,
+) -> Any:
+    """Perform a standard async model call with provider-aware args and output modes.
+
+    - Builds messages from prompts if `messages` is None.
+    - Maps `temperature`, `max_tokens`, `top_p`, `seed` to the right provider kwargs.
+    - `mode`:
+        - TEXT: return `str` content.
+        - JSON_PARSE: parse JSON and return `dict`.
+        - STRUCTURED: return the raw provider response.
+    - Retries only for empty content / JSON parse failures; other errors raise
+      `ModelCallError` immediately.
+    """
+    # Try to infer model name if not provided explicitly.
+    resolved_model_name: Optional[str] = model_name
+    if resolved_model_name is None:
+        underlying = getattr(model_client, "client", model_client)
+        resolved_model_name = getattr(underlying, "model", None)
+
+    # Identify provider family from the model name.
+    provider: Optional[str] = None
+    lowered_name = (
+        resolved_model_name.lower() if isinstance(resolved_model_name, str) else ""
+    )
+    if lowered_name.startswith(("gpt-", "o1-", "o3-", "gpt-5")):
+        provider = "openai"
+    elif "claude" in lowered_name:
+        provider = "anthropic"
+    elif "gemini" in lowered_name:
+        provider = "gemini"
+
+    if messages is None:
+        if user_prompt is None and system_prompt is None:
+            raise ValueError(
+                "Either 'messages' or at least one of 'system_prompt' / 'user_prompt' "
+                "must be provided."
+            )
+
+        built_messages: list[Any] = []
+        if system_prompt:
+            built_messages.append(SystemMessage(content=system_prompt))
+        if user_prompt:
+            built_messages.append(UserMessage(content=user_prompt, source="user"))
+        messages = built_messages
+
+    if max_attempts < 1:
+        raise ValueError("max_attempts must be at least 1")
+
+    last_error: Exception | None = None
+    drop_temperature_for_model = False
+
+    for attempt in range(1, max_attempts + 1):
+        request_kwargs: Dict[str, Any] = {}
+
+        if temperature is not None and not drop_temperature_for_model:
+            if provider == "openai" and lowered_name:
+                # "o1" models: special handling, often ignore temperature.
+                # "o3-mini", "o3", "o4-mini": temperature is not always supported.
+                if any(
+                    key in lowered_name for key in ("o1", "o3-mini", "o3", "o4-mini")
+                ):
+                    logger.debug(
+                        "Not sending 'temperature' for model '%s' due to known "
+                        "limitations.",
+                        resolved_model_name,
+                    )
+                else:
+                    request_kwargs["temperature"] = temperature
+            elif provider in {"anthropic", "gemini", None}:
+                # Anthropic Claude and Gemini generally support temperature;
+                # for unknown providers we optimistically pass it through.
+                request_kwargs["temperature"] = temperature
+
+        # Map unified `max_tokens` to provider-specific kwarg.
+        if max_tokens is not None:
+            if provider in {"openai", "gemini"}:
+                request_kwargs["max_completion_tokens"] = max_tokens
+            elif provider == "anthropic":
+                request_kwargs["max_tokens"] = max_tokens
+            else:
+                request_kwargs["max_tokens"] = max_tokens
+
+        # `top_p` only for OpenAI-style providers.
+        if top_p is not None and provider in {"openai", "gemini", None}:
+            request_kwargs["top_p"] = top_p
+        if seed is not None:
+            request_kwargs["seed"] = seed
+
+        # Output / structured config
+        if mode in (ModelCallMode.JSON_PARSE, ModelCallMode.STRUCTURED):
+            # Many clients support json_output / structured_output flags.
+            # Some may ignore these silently; others might raise if unsupported.
+            request_kwargs.setdefault("json_output", True)
+            if mode == ModelCallMode.STRUCTURED:
+                request_kwargs.setdefault("structured_output", True)
+
+        # Extra kwargs always win
+        if extra_kwargs:
+            request_kwargs.update(extra_kwargs)
+
+        try:
+            response = await model_client.create(
+                messages=list(messages),  # type: ignore[arg-type]
+                **request_kwargs,
+            )
+        except TypeError as exc:
+            # Some models (e.g., certain reasoning or o-series models) do not
+            # support temperature or other generation parameters. If the error
+            # message clearly points to 'temperature', drop it and retry once.
+            if (
+                "temperature" in str(exc)
+                and "temperature" in request_kwargs
+                and not drop_temperature_for_model
+            ):
+                logger.warning(
+                    "Model rejected 'temperature' parameter; retrying without it. "
+                    "Error was: %s",
+                    exc,
+                )
+                drop_temperature_for_model = True
+                last_error = exc
+                continue
+            last_error = exc
+            logger.error("Model call failed with TypeError: %s", exc)
+            break
+        except Exception as exc:  # pragma: no cover - network/SDK errors
+            # Let lower-level client / infrastructure handle any network or
+            # transient retries. At this layer we convert to ModelCallError
+            # without additional retry loops to avoid duplicating behaviour.
+            logger.error("Model call failed with unexpected error: %s", exc)
+            last_error = exc
+            break
+
+        # Extract content in a provider-agnostic way.
+        content = getattr(response, "content", None)
+        if content is None:
+            last_error = ModelCallError("Model returned empty response content")
+            logger.warning(
+                "Empty response content on attempt %d/%d", attempt, max_attempts
+            )
+            if attempt < max_attempts:
+                continue
+            break
+
+        # Normalize to string for text / JSON modes.
+        if isinstance(content, (list, tuple)):
+            content_str = "\n".join(str(part) for part in content)
+        else:
+            content_str = str(content)
+
+        content_str = content_str.strip()
+        if not content_str:
+            last_error = ModelCallError("Model returned empty response content")
+            logger.warning(
+                "Blank response content on attempt %d/%d", attempt, max_attempts
+            )
+            if attempt < max_attempts:
+                continue
+            break
+
+        if mode == ModelCallMode.TEXT:
+            return content_str
+
+        if mode == ModelCallMode.JSON_PARSE:
+            import json
+
+            try:
+                return json.loads(content_str)
+            except Exception as exc:  # pragma: no cover - JSON edge cases
+                last_error = ModelCallError(
+                    f"Failed to parse JSON from model response: {exc}"
+                )
+                logger.warning(
+                    "JSON parsing failed on attempt %d/%d: %s",
+                    attempt,
+                    max_attempts,
+                    exc,
+                )
+                if attempt < max_attempts:
+                    continue
+                break
+
+        # STRUCTURED mode: return provider object as-is to the caller.
+        return response
+
+    # If we get here, all attempts failed.
+    if last_error is None:
+        raise ModelCallError("Model call failed for unknown reasons")
+    if isinstance(last_error, ModelCallError):
+        raise last_error
+    raise ModelCallError(f"Model call failed: {last_error}") from last_error
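
Below is a minimal usage sketch of the two new entry points. It is not part of the commit: the import path is inferred from the file location (src/utils/model_client_utils.py), and the model name, seed, and prompts are illustrative placeholders.

import asyncio

from utils.model_client_utils import (
    ModelCallMode,
    async_call_model,
    get_standard_model_client,
)


async def main() -> None:
    # The new factory returns a plain ChatCompletionClient, unlike the
    # legacy get_model_client, which wraps clients in RetryableModelClient.
    client = get_standard_model_client("gpt-4o", seed=42)

    # TEXT mode returns the response content as a stripped string.
    answer = await async_call_model(
        client,
        system_prompt="You are a concise assistant.",
        user_prompt="Name one prime number.",
        mode=ModelCallMode.TEXT,
        temperature=0.2,
        max_tokens=64,
    )
    print(answer)


asyncio.run(main())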
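
A second sketch shows the JSON_PARSE mode and the failure contract from the docstring: blank content and JSON parse failures are retried up to max_attempts, and any exhausted or converted failure surfaces as ModelCallError. The client argument and the prompt are again placeholders:

from utils.model_client_utils import ModelCallError, ModelCallMode, async_call_model


async def fetch_record(client) -> dict:
    try:
        # JSON_PARSE mode json-decodes the content and returns a dict.
        return await async_call_model(
            client,
            user_prompt='Reply with JSON only: {"ok": true}',
            mode=ModelCallMode.JSON_PARSE,
            max_attempts=3,
        )
    except ModelCallError as exc:
        # Empty output, unparsable JSON, and converted SDK errors all
        # arrive here once attempts are exhausted.
        raise RuntimeError(f"Model call did not produce valid JSON: {exc}") from exc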
