Skip to content

Commit ee1e106

Browse files
committed
Add an exponential backoff retry decorator to utils.py and use it to safely call the user-provided LLM function and the AIMon Detect function. The last encountered exception is re-raised upon failure.
1 parent a7f56e8 commit ee1e106

File tree

3 files changed

+101
-47
lines changed

3 files changed

+101
-47
lines changed

aimon/reprompting_api/pipeline.py

Lines changed: 35 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from aimon.reprompting_api.config import RepromptingConfig, StopReasons
22
from aimon.reprompting_api.telemetry import TelemetryLogger
33
from aimon.reprompting_api.reprompter import Reprompter
4-
from aimon.reprompting_api.utils import toxicity_check, get_failed_instructions_count, get_failed_instructions, get_residual_error_score, get_failed_toxicity_instructions
4+
from aimon.reprompting_api.utils import retry, toxicity_check, get_failed_instructions_count, get_failed_instructions, get_residual_error_score, get_failed_toxicity_instructions
55
from aimon import Detect
66
import time
77
import random
@@ -243,57 +243,50 @@ def _build_aimon_payload(self, context, user_query, user_instructions, generated
243243
}
244244
return payload
245245

246-
def _call_llm(self, prompt_template: Template, max_attempts, system_prompt=None, context=None, user_query=None):
    """
    Call the user-provided LLM function with exponential-backoff retries.

    Retries if the LLM call raises OR returns a non-string value. If all
    attempts fail, the last encountered exception from the LLM call is
    re-raised unchanged (the original exception type, not a RuntimeError).

    Args:
        prompt_template (Template): Prompt template for the LLM.
        max_attempts (int): Maximum number of attempts (not retries).
        system_prompt (str, optional): System prompt forwarded to the LLM function.
        context (str, optional): Context forwarded to the LLM function.
        user_query (str, optional): User query forwarded to the LLM function.

    Returns:
        str: LLM response text.

    Raises:
        TypeError: If the LLM returns a non-string value on the final attempt.
        Exception: Whatever exception the LLM function raised on the final
            attempt is re-raised by the retry decorator.
    """
    # Decorate a zero-argument closure so the retry decorator can re-invoke
    # the call with the same arguments on each attempt.
    @retry(exception_to_check=Exception, tries=max_attempts, delay=1, backoff=2, logger=logger)
    def backoff_call():
        result = self.llm_fn(prompt_template, system_prompt, context, user_query)
        # A non-string result is treated as a failure so it is retried too.
        if not isinstance(result, str):
            raise TypeError(f"LLM returned invalid type {type(result).__name__}, expected str.")
        return result
    return backoff_call()
277270

278-
def _detect_aimon_response(self, payload, max_attempts, base_delay=1):
271+
def _detect_aimon_response(self, payload, max_attempts):
279272
"""
280273
Calls AIMon Detect with exponential backoff and returns the detection result.
281274
282275
This method wraps the AIMon evaluation call, retrying if it fails due to transient
283276
errors (e.g., network issues or temporary service unavailability). It retries up to
284-
`max_attempts` times with exponential backoff before raising a RuntimeError.
277+
`max_attempts` times with exponential backoff before raising the last encountered
278+
exception from the AIMon Detect call.
285279
286280
Args:
287281
payload (dict): A dictionary containing 'context', 'user_query',
288282
'instructions', and 'generated_text' for evaluation.
289283
max_attempts (int): Maximum number of retry attempts.
290-
base_delay (float): Initial delay in seconds before exponential backoff.
291284
292285
Returns:
293286
object: The AIMon detection result containing evaluation scores and feedback.
294287
295288
Raises:
296-
RuntimeError: If AIMon Detect fails after all retry attempts.
289+
RuntimeError: If AIMon Detect fails after all retry attempts, re-raises the last encountered error.
297290
"""
298291
aimon_context = f"{payload['context']}\n\nUser Query:\n{payload['user_query']}"
299292
aimon_query = f"{payload['user_query']}\n\nInstructions:\n{payload['instructions']}"
@@ -302,21 +295,23 @@ def _detect_aimon_response(self, payload, max_attempts, base_delay=1):
302295
def run_detection(query, instructions, generated_text, context):
303296
return query, instructions, generated_text, context
304297

305-
for attempt in range(max_attempts):
306-
try:
307-
logger.debug(f"AIMon detect attempt {attempt+1} with payload: {payload}")
308-
_, _, _, _, result = run_detection(
309-
aimon_query,
310-
payload['instructions'],
311-
payload['generated_text'],
312-
aimon_context
313-
)
314-
return result
315-
except Exception as e:
316-
logger.debug(f"AIMon detect failed on attempt {attempt+1}: {e}")
317-
wait_time = base_delay * (2 ** attempt) + random.uniform(0, 0.1)
318-
time.sleep(wait_time)
319-
raise RuntimeError("AIMon detect call failed after maximum retries.")
298+
@retry(
299+
exception_to_check=Exception,
300+
tries=max_attempts,
301+
delay=1,
302+
backoff=2,
303+
logger=logger
304+
)
305+
def inner_detection():
306+
logger.debug(f"AIMon detect call with payload: {payload}")
307+
_, _, _, _, result = run_detection(
308+
aimon_query,
309+
payload['instructions'],
310+
payload['generated_text'],
311+
aimon_context
312+
)
313+
return result
314+
return inner_detection()
320315

321316
def get_response_feedback(self, result):
322317
"""

aimon/reprompting_api/utils.py

Lines changed: 61 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,62 @@
1414
- Toxicity failures are flagged when follow_probability > TOXICITY_THRESHOLD (default 0.25).
1515
- Residual error scoring penalizes low follow probabilities more heavily and adds a flat penalty for any toxicity failures.
1616
"""
17-
from typing import List
17+
from typing import Callable, Type, Union, Tuple, Optional, List
18+
from functools import wraps
19+
import logging
20+
import random
21+
import time
22+
23+
def retry(
    exception_to_check: Union[Type[BaseException], Tuple[Type[BaseException], ...]],
    tries: int = 5,
    delay: int = 3,
    backoff: int = 2,
    logger: Optional[logging.Logger] = None,
    log_level: int = logging.WARNING,
    re_raise: bool = True,
    jitter: float = 0.1
) -> Callable:
    """
    Retry calling the decorated function using an exponential backoff.

    :param exception_to_check: Exception or a tuple of exceptions to check.
    :param tries: Number of times to try (not retry) before giving up. Must be >= 1.
    :param delay: Initial delay between retries in seconds.
    :param backoff: Backoff multiplier e.g., a value of 2 will double the delay each retry.
    :param logger: Logger to use. If None, print.
    :param log_level: Logging level.
    :param re_raise: Whether to re-raise the exception after the last retry.
        When False, the wrapper returns None after the final failure.
    :param jitter: The maximum jitter to apply to the delay as a fraction of the delay.
    :raises ValueError: If ``tries`` is less than 1 (previously ``tries=0``
        silently still performed one attempt).
    """
    # Fail fast on a nonsensical attempt count instead of silently running once.
    if tries < 1:
        raise ValueError(f"tries must be >= 1, got {tries}")

    def deco_retry(func: Callable) -> Callable:
        @wraps(func)
        def f_retry(*args, **kwargs):
            current_delay = delay
            for attempt in range(1, tries + 1):
                try:
                    return func(*args, **kwargs)
                except exception_to_check as e:
                    final_attempt = attempt == tries
                    if final_attempt:
                        msg = f"Failed after {tries} tries. {e}"
                    else:
                        msg = f"{e}, Retrying in {current_delay} seconds..."
                    if logger:
                        logger.log(log_level, msg)
                    else:
                        print(msg)
                    if final_attempt:
                        if re_raise:
                            raise
                        return None
                    # Sleep with symmetric jitter of +/- (jitter * current_delay).
                    time.sleep(current_delay * (1 + jitter * (2 * random.random() - 1)))
                    current_delay *= backoff
        return f_retry
    return deco_retry
1873

1974
# toxicity threshold for AIMon detection; Follow probabilities above this are considered failures
2075
TOXICITY_THRESHOLD = 0.25
@@ -168,11 +223,14 @@ def penalized_average(probs: List[float]) -> float:
168223
Probabilities > 0.5 (passed instructions) receive no penalty
169224
170225
Args:
171-
probs (List[float]): A list of follow probabilities.
226+
probs (List[float]): A list of follow probabilities. Must be non-empty.
172227
173228
Returns:
174-
float: Penalized average.
229+
float: Penalized average. Return -1 if probs is empty.
175230
"""
231+
if not probs: # handle division by zero for empty list
232+
return -1
233+
176234
penalties = []
177235
for p in probs:
178236
if p >= 0.5:

tests/test_reprompting_failures.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import pytest
33
from string import Template
44
from together import Together
5+
import aimon
56
from aimon.reprompting_api.config import RepromptingConfig
67
from aimon.reprompting_api.runner import run_reprompting_pipeline
78

@@ -61,7 +62,7 @@ def get_config_with_invalid_aimon_api_key():
6162
def test_llm_failure():
6263
"""Should raise RuntimeError when the LLM function always fails."""
6364
config = get_config()
64-
with pytest.raises(RuntimeError, match="LLM call failed or returned invalid type after maximum retries."):
65+
with pytest.raises(RuntimeError, match="LLM call failed intentionally for testing"):
6566
run_reprompting_pipeline(
6667
user_query="Test LLM failure handling",
6768
context="Context for failure test",
@@ -85,9 +86,9 @@ def test_invalid_llm_fn():
8586

8687
@pytest.mark.integration
8788
def test_invalid_return_value():
88-
"""Should raise RuntimeError when the LLM returns a non-string value."""
89+
"""Should raise TypeError when the LLM returns a non-string value."""
8990
config = get_config()
90-
with pytest.raises(RuntimeError, match="LLM call failed or returned invalid type"):
91+
with pytest.raises(TypeError, match="LLM returned invalid type int, expected str."):
9192
run_reprompting_pipeline(
9293
user_query="Test invalid return type",
9394
context="Context for type error",
@@ -113,7 +114,7 @@ def test_empty_query():
113114
def test_invalid_api_key():
114115
"""Should fail due to invalid AIMon API key."""
115116
config = get_config_with_invalid_aimon_api_key()
116-
with pytest.raises(RuntimeError):
117+
with pytest.raises(aimon.AuthenticationError):
117118
run_reprompting_pipeline(
118119
user_query="Testing with invalid AIMon API key",
119120
context="Context for invalid key test",

0 commit comments

Comments
 (0)