
Commit b76340c

pjoshi30 (Preetam Joshi) authored
Corrected toxicity interpretation based on the new toxicity model behavior (#72)

* Corrected toxicity interpretation based on the new toxicity model behavior
* Bumping version of the package
* Corrected docstring for toxicity
* Added unit tests for reprompting utils
* Updated tests and fixed a few bugs
* Fixing tests

Co-authored-by: Preetam Joshi <info@aimon.ai>
1 parent e0e856f commit b76340c

File tree

6 files changed: +526 -62 lines changed


aimon/_version.py

Lines changed: 1 addition & 1 deletion
@@ -1,4 +1,4 @@
 # File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
 
 __title__ = "aimon"
-__version__ = "0.12.1"
+__version__ = "0.12.2"

aimon/reprompting_api/reprompter.py

Lines changed: 5 additions & 1 deletion
@@ -84,6 +84,9 @@ def get_toxicity_reprompt(self, result) -> str:
 
         Returns:
             str: Toxicity-specific feedback, or None if no toxicity detected.
+
+        Note:
+            For toxicity, lower scores indicate higher toxicity. We invert the score to show confidence.
         """
         try:
             failed_instructions = get_failed_toxicity_instructions(result)
@@ -92,7 +95,8 @@ def get_toxicity_reprompt(self, result) -> str:
             logger.info(f"Toxicity violations detected: {len(failed_instructions)}")
             lines = ["Your reply contained toxic content. Remove any harmful, abusive, or unsafe language."]
             for i, failed_instruction in enumerate(failed_instructions, start=1):
-                confidence = failed_instruction.get("score", 0.0) * 100
+                # For toxicity, lower score = more toxic, so invert to show confidence
+                confidence = (1.0 - failed_instruction.get("score", 0.0)) * 100
                 confidence_str = f"{confidence:.2f}%"
                 lines.append(
                     f"{i}. We are {confidence_str} confident that your response had the following issue:\n"

aimon/reprompting_api/utils.py

Lines changed: 20 additions & 14 deletions
@@ -11,7 +11,7 @@
 - Guide corrective re-prompting logic.
 
 Key conventions:
-- Toxicity failures are flagged when follow_probability > TOXICITY_THRESHOLD (default 0.25).
+- Toxicity failures are flagged when follow_probability < TOXICITY_THRESHOLD (default 0.5). Lower scores indicate higher toxicity.
 - Residual error scoring penalizes low follow probabilities more heavily and adds a flat penalty for any toxicity failures.
 """
 from typing import Callable, Type, Union, Tuple, Optional, List
@@ -71,12 +71,13 @@ def f_retry(*args, **kwargs):
         return f_retry
     return deco_retry
 
-# toxicity threshold for AIMon detection; Follow probabilities above this are considered failures
-TOXICITY_THRESHOLD = 0.25
+# toxicity threshold for AIMon detection; Follow probabilities below this are considered failures (lower score = more toxic)
+TOXICITY_THRESHOLD = 0.5
 
 def _count_toxicity_failures(result) -> int:
     """
-    Count the number of toxicity instructions whose follow probability exceeds the threshold.
+    Count the number of toxicity instructions whose follow probability is below the threshold.
+    Lower scores indicate higher toxicity.
 
     Args:
         result: AIMon detection result containing a `toxicity` section.
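To make the flipped convention concrete, here is a minimal sketch of the failure check with the updated threshold; the instructions_list payload below is a hypothetical stand-in, not real detector output:

TOXICITY_THRESHOLD = 0.5  # follow probabilities below this count as failures

# Hypothetical payload shaped like the detector's instructions_list;
# lower follow_probability = more toxic.
instructions_list = [
    {"instruction": "Avoid abusive language", "follow_probability": 0.92},  # passes
    {"instruction": "Avoid threats", "follow_probability": 0.18},           # fails
]

failures = sum(
    1
    for inst in instructions_list
    if inst.get("follow_probability", 0.0) < TOXICITY_THRESHOLD
)
print(failures)  # 1 -> only 0.18 fails; under the old "> 0.25" rule, 0.92 would have been flagged instead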
@@ -87,25 +88,27 @@ def _count_toxicity_failures(result) -> int:
     return sum(
         1
         for inst in result.detect_response.toxicity.get("instructions_list", [])
-        if inst.get("follow_probability", 0.0) > TOXICITY_THRESHOLD
+        if inst.get("follow_probability", 0.0) < TOXICITY_THRESHOLD
     )
 
 def toxicity_check(result) -> bool:
     """
-    Check whether any toxicity instructions exceed the threshold.
+    Check whether any toxicity instructions fall below the threshold.
+    Lower scores indicate higher toxicity.
 
     Args:
         result: AIMon detection result containing a `toxicity` section.
 
     Returns:
-        bool: True if at least one toxicity instruction exceeds the threshold, False otherwise.
+        bool: True if at least one toxicity instruction is below the threshold, False otherwise.
     """
     return _count_toxicity_failures(result) > 0
 
 
 def get_failed_toxicity_instructions(result) -> List[dict]:
     """
-    Extract failed toxicity instructions exceeding the threshold.
+    Extract failed toxicity instructions below the threshold.
+    Lower scores indicate higher toxicity.
 
     Args:
         result: AIMon detection result containing a `toxicity` section.
@@ -120,7 +123,7 @@ def get_failed_toxicity_instructions(result) -> List[dict]:
     """
     failed = []
     for inst in result.detect_response.toxicity.get("instructions_list", []):
-        if inst.get("follow_probability", 0.0) > TOXICITY_THRESHOLD:
+        if inst.get("follow_probability", 0.0) < TOXICITY_THRESHOLD:
             failed.append({
                 "type": "toxicity_failure",
                 "source": "toxicity",
@@ -188,13 +191,16 @@ def get_residual_error_score(result):
     Compute a normalized residual error score (0–1) based on:
     - Groundedness follow probabilities
     - Instruction adherence follow probabilities
-    - Toxicity (inverted: 1 - follow_probability)
+    - Toxicity follow probabilities (lower scores indicate higher toxicity)
 
     Logic:
-    1. Collect follow probabilities for groundedness & adherence.
-    2. For toxicity, use 1 - follow_probability (since high follow = low error).
+    1. Collect follow probabilities for groundedness, adherence, and toxicity.
+    2. For toxicity, use follow_probability directly (since lower scores = higher toxicity = higher error).
     3. Compute a penalized average using the helper.
     4. Clamp the final score to [0,1].
+
+    Note: Unlike groundedness/adherence where high scores are good, toxicity scores are already
+    in the "error" direction (low score = toxic = bad), so no inversion is needed.
     """
     combined_probs = []
 
@@ -204,9 +210,9 @@ def get_residual_error_score(result):
         for item in getattr(result.detect_response, source, {}).get("instructions_list", [])
     ])
 
-    # For toxicity, invert the follow probability
+    # For toxicity, use the follow probability directly (lower = more toxic = higher error)
     combined_probs.extend([
-        1 - item["follow_probability"]
+        item["follow_probability"]
        for item in getattr(result.detect_response, "toxicity", {}).get("instructions_list", [])
     ])
 
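The direction change matters for the residual error score: groundedness and adherence probabilities are "higher is better", while toxicity scores are already in the error direction, so they enter the pool unchanged. A rough standalone sketch of the collection step (the detector payload here is a simplified, hypothetical stand-in):

# Hypothetical detector outputs; every value is a follow probability in [0, 1].
detect_response = {
    "groundedness": {"instructions_list": [{"follow_probability": 0.9}]},
    "instruction_adherence": {"instructions_list": [{"follow_probability": 0.8}]},
    "toxicity": {"instructions_list": [{"follow_probability": 0.3}]},  # low = toxic
}

combined_probs = []
for source in ("groundedness", "instruction_adherence"):
    combined_probs.extend(
        item["follow_probability"]
        for item in detect_response.get(source, {}).get("instructions_list", [])
    )

# Toxicity enters the pool unchanged: a low probability already signals
# high error, so no "1 - p" inversion is applied anymore.
combined_probs.extend(
    item["follow_probability"]
    for item in detect_response.get("toxicity", {}).get("instructions_list", [])
)

print(combined_probs)  # [0.9, 0.8, 0.3] -> a penalized average then maps this to an error in [0, 1]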

setup.py

Lines changed: 1 addition & 1 deletion
@@ -8,7 +8,7 @@
     name='aimon',
     python_requires='>3.8.0',
     packages=find_packages(),
-    version="0.12.1",
+    version="0.12.2",
     install_requires=[
         "annotated-types~=0.6.0",
         "anyio~=4.9.0",

tests/test_detect.py

Lines changed: 41 additions & 45 deletions
@@ -39,8 +39,8 @@ def log_info(self, title, data):
 
     def test_basic_detect_functionality(self, caplog):
         """Test that the Detect decorator works with basic functionality without raising exceptions."""
-        # Create the decorator
-        config = {'hallucination': {'detector_name': 'default'}}
+        # Create the decorator (using groundedness instead of deprecated hallucination)
+        config = {'groundedness': {'detector_name': 'default'}}
         values_returned = ["context", "generated_text", "user_query"]
 
         self.log_info("TEST", "Basic detect functionality")
@@ -71,11 +71,10 @@ def generate_summary(context, query):
         self.log_info("OUTPUT_GENERATED_TEXT", generated_text)
         self.log_info("OUTPUT_STATUS", result.status)
 
-        if hasattr(result.detect_response, 'hallucination'):
-            self.log_info("OUTPUT_HALLUCINATION", {
-                "is_hallucinated": result.detect_response.hallucination.get("is_hallucinated", ""),
-                "score": result.detect_response.hallucination.get("score", ""),
-                "sentences_count": len(result.detect_response.hallucination.get("sentences", []))
+        if hasattr(result.detect_response, 'groundedness'):
+            self.log_info("OUTPUT_GROUNDEDNESS", {
+                "score": result.detect_response.groundedness.get("score", ""),
+                "instructions_list": result.detect_response.groundedness.get("instructions_list", [])
             })
 
         # Verify return values
@@ -86,16 +85,14 @@ def generate_summary(context, query):
         # Verify response structure
         assert isinstance(result, DetectResult)
         assert result.status == 200
-        assert hasattr(result.detect_response, 'hallucination')
-        assert "is_hallucinated" in result.detect_response.hallucination
-        assert "score" in result.detect_response.hallucination
-        assert "sentences" in result.detect_response.hallucination
+        assert hasattr(result.detect_response, 'groundedness')
+        assert "score" in result.detect_response.groundedness
 
     def test_detect_with_multiple_detectors(self):
         """Test the Detect decorator with multiple detectors without raising exceptions."""
-        # Create the decorator with multiple detectors
+        # Create the decorator with multiple detectors (using groundedness instead of deprecated hallucination)
         config = {
-            'hallucination': {'detector_name': 'default'},
+            'groundedness': {'detector_name': 'default'},
             'instruction_adherence': {'detector_name': 'default'},
             'toxicity': {'detector_name': 'default'}
         }
@@ -131,25 +128,25 @@ def generate_response(context, query, instructions):
         self.log_info("Output - Generated Text", generated_text)
         self.log_info("Output - Status", result.status)
 
-        for detector in ['hallucination', 'instruction_adherence', 'toxicity']:
+        for detector in ['groundedness', 'instruction_adherence', 'toxicity']:
             if hasattr(result.detect_response, detector):
                 self.log_info(f"Output - {detector.capitalize()} Response",
                               getattr(result.detect_response, detector))
 
         # Verify response structure
-        assert hasattr(result.detect_response, 'hallucination')
+        assert hasattr(result.detect_response, 'groundedness')
         assert hasattr(result.detect_response, 'instruction_adherence')
         assert hasattr(result.detect_response, 'toxicity')
 
         # Check key fields without verifying values
-        assert "score" in result.detect_response.hallucination
+        assert "score" in result.detect_response.groundedness
         assert "instructions_list" in result.detect_response.instruction_adherence
         assert "score" in result.detect_response.toxicity
 
     def test_detect_with_different_iterables(self):
         """Test the Detect decorator with different iterable types for values_returned."""
         # Create the decorator with a tuple for values_returned
-        config = {'hallucination': {'detector_name': 'default'}}
+        config = {'groundedness': {'detector_name': 'default'}}
         values_returned = ("context", "generated_text")
 
         self.log_info("Test", "Detect with different iterables (tuple)")
@@ -176,16 +173,16 @@ def simple_function():
         self.log_info("Output - Generated Text", generated_text)
         self.log_info("Output - Status", result.status)
 
-        if hasattr(result.detect_response, 'hallucination'):
-            self.log_info("Output - Hallucination Response",
-                          result.detect_response.hallucination)
+        if hasattr(result.detect_response, 'groundedness'):
+            self.log_info("Output - Groundedness Response",
+                          result.detect_response.groundedness)
 
         # Verify return values and structure
         assert "Python" in context
         assert "data science" in generated_text
         assert isinstance(result, DetectResult)
-        assert hasattr(result.detect_response, 'hallucination')
-        assert "score" in result.detect_response.hallucination
+        assert hasattr(result.detect_response, 'groundedness')
+        assert "score" in result.detect_response.groundedness
 
     def test_detect_with_non_tuple_return(self):
         """Test the Detect decorator when the wrapped function returns a single value."""
@@ -235,7 +232,7 @@ def test_validate_iterable_values_returned(self):
         detect_with_list = Detect(
             values_returned=list_values,
             api_key=self.api_key,
-            config={'hallucination': {'detector_name': 'default'}}
+            config={'groundedness': {'detector_name': 'default'}}
         )
 
         # Test with a tuple
@@ -245,7 +242,7 @@ def test_validate_iterable_values_returned(self):
         detect_with_tuple = Detect(
             values_returned=tuple_values,
             api_key=self.api_key,
-            config={'hallucination': {'detector_name': 'default'}}
+            config={'groundedness': {'detector_name': 'default'}}
         )
 
         # Test with a custom iterable
@@ -266,7 +263,7 @@ def __len__(self):
         detect_with_custom = Detect(
             values_returned=custom_iterable,
             api_key=self.api_key,
-            config={'hallucination': {'detector_name': 'default'}}
+            config={'groundedness': {'detector_name': 'default'}}
         )
 
         # If we got here without exceptions, the test passes
@@ -380,7 +377,7 @@ def test_missing_required_fields(self):
             values_returned=["context", "generated_text"],
             api_key=self.api_key,
             publish=True,  # publish requires application_name and model_name
-            config={'hallucination': {'detector_name': 'default'}}
+            config={'groundedness': {'detector_name': 'default'}}
         )
         self.log_info("Error message (publish)", str(exc_info1.value))
 
@@ -391,7 +388,7 @@ def test_missing_required_fields(self):
             values_returned=["context", "generated_text"],
             api_key=self.api_key,
             async_mode=True,  # async_mode requires application_name and model_name
-            config={'hallucination': {'detector_name': 'default'}}
+            config={'groundedness': {'detector_name': 'default'}}
         )
         self.log_info("Error message (async_mode)", str(exc_info2.value))
 
@@ -434,15 +431,15 @@ def generate_text():
         assert hasattr(result.detect_response, 'toxicity')
         assert "score" in result.detect_response.toxicity
 
-    def test_hallucination_context_relevance_combination(self):
-        """Test the Detect decorator with a combination of hallucination and retrieval relevance detectors."""
+    def test_groundedness_context_relevance_combination(self):
+        """Test the Detect decorator with a combination of groundedness and retrieval relevance detectors."""
         config = {
-            'hallucination': {'detector_name': 'default'},
+            'groundedness': {'detector_name': 'default'},
             'retrieval_relevance': {'detector_name': 'default'}
         }
         values_returned = ["context", "generated_text", "user_query", "task_definition"]
 
-        self.log_info("Test", "Hallucination and Retrieval Relevance combination")
+        self.log_info("Test", "Groundedness and Retrieval Relevance combination")
         self.log_info("Configuration", config)
         self.log_info("Values returned", values_returned)
 
@@ -469,15 +466,15 @@ def generate_summary(context, query):
         self.log_info("Output - Generated Text", generated_text)
         self.log_info("Output - Status", result.status)
 
-        for detector in ['hallucination', 'retrieval_relevance']:
+        for detector in ['groundedness', 'retrieval_relevance']:
             if hasattr(result.detect_response, detector):
                 self.log_info(f"Output - {detector.capitalize()} Response",
                               getattr(result.detect_response, detector))
 
         # Verify response structure
         assert isinstance(result, DetectResult)
         assert result.status == 200
-        assert hasattr(result.detect_response, 'hallucination')
+        assert hasattr(result.detect_response, 'groundedness')
         assert hasattr(result.detect_response, 'retrieval_relevance')
 
     def test_instruction_adherence_v1(self):
@@ -593,7 +590,7 @@ def generate_with_instructions(context, instructions, query):
     def test_all_detectors_combination(self):
         """Test the Detect decorator with all available detectors."""
         config = {
-            'hallucination': {'detector_name': 'default'},
+            'groundedness': {'detector_name': 'default'},
             'toxicity': {'detector_name': 'default'},
             'instruction_adherence': {'detector_name': 'default'},
             'retrieval_relevance': {'detector_name': 'default'},
@@ -637,7 +634,7 @@ def comprehensive_response(context, query, instructions):
         self.log_info("Output - Status", result.status)
 
         # Log all detector responses
-        for detector in ['hallucination', 'toxicity', 'instruction_adherence',
+        for detector in ['groundedness', 'toxicity', 'instruction_adherence',
                          'retrieval_relevance', 'conciseness', 'completeness']:
             if hasattr(result.detect_response, detector):
                 self.log_info(f"Output - {detector.capitalize()} Response",
@@ -648,7 +645,7 @@ def comprehensive_response(context, query, instructions):
         assert result.status == 200
 
         # Verify all detectors are present in the response
-        assert hasattr(result.detect_response, 'hallucination')
+        assert hasattr(result.detect_response, 'groundedness')
         assert hasattr(result.detect_response, 'toxicity')
         assert hasattr(result.detect_response, 'instruction_adherence')
         assert hasattr(result.detect_response, 'retrieval_relevance')
@@ -772,7 +769,7 @@ def test_evaluate_with_new_model(self):
 
         # Configure evaluation
         eval_config = {
-            'hallucination': {'detector_name': 'default'},
+            'groundedness': {'detector_name': 'default'},
            'toxicity': {'detector_name': 'default'}
         }
 
@@ -829,9 +826,9 @@ def test_must_compute_validation(self):
         """Test that the must_compute parameter is properly validated."""
         print("\n=== Testing must_compute validation ===")
 
-        # Test config with both hallucination and completeness
+        # Test config with both groundedness and completeness
         test_config = {
-            "hallucination": {
+            "groundedness": {
                 "detector_name": "default"
             },
             "completeness": {
@@ -903,9 +900,9 @@ def test_must_compute_with_actual_service(self):
         """Test must_compute functionality with actual service calls."""
         print("\n=== Testing must_compute with actual service ===")
 
-        # Test config with both hallucination and completeness
+        # Test config with both groundedness and completeness
         test_config = {
-            "hallucination": {
+            "groundedness": {
                 "detector_name": "default"
             },
             "completeness": {
@@ -947,10 +944,9 @@ def generate_summary(context, query):
         print(f"Generated Text: {generated_text}")
 
         # Display response details
-        if hasattr(result.detect_response, 'hallucination'):
-            hallucination = result.detect_response.hallucination
-            print(f"Hallucination Score: {hallucination.get('score', 'N/A')}")
-            print(f"Is Hallucinated: {hallucination.get('is_hallucinated', 'N/A')}")
+        if hasattr(result.detect_response, 'groundedness'):
+            groundedness = result.detect_response.groundedness
+            print(f"Groundedness Score: {groundedness.get('score', 'N/A')}")
 
         if hasattr(result.detect_response, 'completeness'):
             completeness = result.detect_response.completeness
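The test migration above amounts to a single config change: request 'groundedness' where 'hallucination' used to be. A minimal usage sketch of the decorator under the new config, modeled on the tests above (the summary function is a stand-in for a real LLM call, and a valid AIMON_API_KEY environment variable is assumed):

import os

from aimon import Detect

# Hypothetical wiring that mirrors the test setup.
detect = Detect(
    values_returned=["context", "generated_text"],
    api_key=os.environ["AIMON_API_KEY"],
    config={"groundedness": {"detector_name": "default"}},
)

@detect
def generate_summary(context, query):
    # Stand-in for a real LLM call.
    generated_text = f"Summary: {context}"
    return context, generated_text

context, generated_text, result = generate_summary(
    "Python is widely used in data science.", "Summarize the context."
)
print(result.status)                                     # 200 on success
print(result.detect_response.groundedness.get("score"))  # groundedness score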
