Skip to content

Commit 951d5d9

Browse files
Merge pull request #55 from amazon-science/feature/prompt_tuner_fixes
Bug Fixes for Prompt Tuner
2 parents a1ae458 + a8b98c7 commit 951d5d9

File tree

6 files changed

+69
-27
lines changed

6 files changed

+69
-27
lines changed

src/fmcore/llm/mixins/llm_mixins.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
from typing import Optional, Union
22

3+
from pydantic import Field
4+
35
from fmcore.llm.types.llm_types import LLMConfig, DistributedLLMConfig
46
from fmcore.types.mixins_types import Mixin
57
from fmcore.types.typed import MutableTyped
@@ -13,4 +15,4 @@ class LLMConfigMixin(MutableTyped, Mixin):
1315
llm_config (Optional[LLMConfig]): The LLM configuration object.
1416
"""
1517

16-
llm_config: Union[LLMConfig, DistributedLLMConfig]
18+
llm_config: Union[LLMConfig, DistributedLLMConfig] = Field(union_mode="left_to_right")

src/fmcore/llm/mixins/provider_mixins.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,8 @@ class AWSAccountMixin(MutableTyped, Mixin):
1616
region (str): The AWS region where the account operates. Defaults to 'us-east-1'.
1717
"""
1818

19-
role_arn: str
20-
region: str = Field(default=AWSRegion.US_EAST_1.value)
19+
role_arn: Optional[str] = Field(default=None)
20+
region: Optional[str] = Field(default="us-east-1")
2121

2222

2323
class APIKeyServiceMixin(MutableTyped, Mixin):

src/fmcore/prompt_tuner/dspy/optimizer_wrapper/miprov2/miprov2_optimizer_types.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
from typing import Optional
22

3+
from pydantic import Field
4+
35
from fmcore.prompt_tuner.types.enums.optimizer_enums import OptimizerMetricType, DSPyOptimizerType
46
from fmcore.prompt_tuner.types.mixins.optimizer_mixins import (
57
StudentConfigMixin,
@@ -22,8 +24,10 @@ class MIPROv2OptimizerParams(BaseOptimizerParams):
2224
"""
2325

2426
optimizer_metric: str = OptimizerMetricType.ACCURACY
25-
auto: Optional[str] = "light"
26-
num_candidates: int = 7
27+
num_candidates: Optional[int] = Field(default=7)
28+
max_errors: Optional[int] = Field(default=10)
29+
minibatch: Optional[bool] = Field(default=False)
30+
auto: Optional[str] = None
2731

2832

2933
class MIPROv2OptimizerConfig(

src/fmcore/prompt_tuner/dspy/utils/dspy_utils.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -168,13 +168,8 @@ def evaluate_func(example: dspy.Example, prediction: dspy.Prediction, trace=None
168168
"output": prediction.toDict(),
169169
}
170170

171-
try:
172-
# We are using this hack because dspy doesn't support async
173-
decision = AsyncUtils.execute(evaluator.aevaluate(data=row))
174-
except Exception as e:
175-
# Defaulting to false incase of failures
176-
Log.info(f"Error {e} during evaluating {row}")
177-
decision = False
171+
# We are using this hack because dspy doesn't support async
172+
decision = AsyncUtils.execute(evaluator.aevaluate(data=row))
178173

179174
return decision
180175

src/fmcore/prompt_tuner/evaluator/llm_as_a_judge_boolean/llm_as_a_judge_boolean_evaluator.py

Lines changed: 52 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from fmcore.mapper.llm_response_json_mapper import LLMResponseJsonMapper
1212
from fmcore.mapper.criteria_checker_mapper import CriteriaCheckerMapper
1313
from fmcore.mapper.llm_inference_mapper import LLMInferenceMapper
14+
from fmcore.utils.logging_utils import Log
1415

1516

1617
class LLMAsJudgeBooleanEvaluator(BaseEvaluator[Dict, bool]):
@@ -72,19 +73,38 @@ def _get_instance(cls, *, evaluator_config: EvaluatorConfig) -> "LLMAsJudgeBoole
7273

7374
def evaluate(self, data: Dict) -> bool:
7475
"""
75-
Processes the input data by using the llm_as_a_judge_boolean_mapper to evaluate the context.
76+
Processes the input data using the llm_as_a_judge_boolean_mapper to evaluate the context.
7677
7778
Args:
7879
data (BooleanLLMJudgeInput): Input data containing context for evaluation.
7980
8081
Returns:
81-
BooleanLLMJudgeOutput: Evaluation result as a boolean decision.
82+
bool: Evaluation result as a boolean decision.
8283
"""
83-
# Format the context into messages using the template
84-
formatted_message: BaseMessage = self.text_prompt_mapper.map(data)
85-
llm_response: BaseMessage = self.llm_inference_mapper.map([formatted_message])
86-
json_response: Dict = self.json_mapper.map(llm_response.content)
87-
decision: bool = self.criteria_checker.map(json_response)
84+
formatted_message = llm_response = json_response = decision = None
85+
86+
try:
87+
formatted_message = self.text_prompt_mapper.map(data)
88+
llm_response = self.llm_inference_mapper.map([formatted_message])
89+
json_response = self.json_mapper.map(llm_response.content)
90+
decision = self.criteria_checker.map(json_response)
91+
92+
if not isinstance(decision, bool):
93+
raise ValueError("Decision is not a boolean value")
94+
95+
except Exception as e:
96+
Log.error(
97+
"[SYNC EVALUATION ERROR]\t\t ->"
98+
f"[INPUT DATA]: {data}\t\t ->"
99+
f"[PROMPT]: {self.evaluator_config.evaluator_params.prompt}\t\t ->"
100+
f"[FORMATTED MESSAGE]: {formatted_message}\t\t ->"
101+
f"[LLM RESPONSE]: {llm_response}\t\t ->"
102+
f"[JSON RESPONSE]: {json_response}\t\t ->"
103+
f"[DECISION]: {decision}\t\t ->"
104+
f"[ERROR]: {e}"
105+
)
106+
raise
107+
88108
return decision
89109

90110
async def aevaluate(self, data: Dict) -> bool:
@@ -95,11 +115,30 @@ async def aevaluate(self, data: Dict) -> bool:
95115
data (BooleanLLMJudgeInput): Input data containing context for evaluation.
96116
97117
Returns:
98-
BooleanLLMJudgeOutput: Evaluation result as a boolean decision.
118+
bool: Evaluation result as a boolean decision.
99119
"""
100-
# Format the context into messages using the template
101-
formatted_message: BaseMessage = await self.text_prompt_mapper.amap(data)
102-
llm_response: BaseMessage = await self.llm_inference_mapper.amap([formatted_message])
103-
json_response: Dict = await self.json_mapper.amap(llm_response.content)
104-
decision: bool = await self.criteria_checker.amap(json_response)
120+
formatted_message = llm_response = json_response = decision = None
121+
122+
try:
123+
formatted_message = await self.text_prompt_mapper.amap(data)
124+
llm_response = await self.llm_inference_mapper.amap([formatted_message])
125+
json_response = await self.json_mapper.amap(llm_response.content)
126+
decision = await self.criteria_checker.amap(json_response)
127+
128+
if not isinstance(decision, bool):
129+
raise ValueError("Decision is not a boolean value")
130+
131+
except Exception as e:
132+
Log.error(
133+
"[ASYNC EVALUATION ERROR]\t\t->"
134+
f"[INPUT DATA]: {data}\t\t ->"
135+
f"[PROMPT]: {self.evaluator_config.evaluator_params.prompt}\t\t ->"
136+
f"[FORMATTED MESSAGE]: {formatted_message}\t\t ->"
137+
f"[LLM RESPONSE]: {llm_response}\t\t ->"
138+
f"[JSON RESPONSE]: {json_response}\t\t ->"
139+
f"[DECISION]: {decision}\t\t ->"
140+
f"[ERROR]: {e}"
141+
)
142+
raise
143+
105144
return decision

src/fmcore/prompt_tuner/types/mixins/optimizer_mixins.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
from typing import Optional, Union
22

3+
from pydantic import Field
4+
35
from fmcore.llm.types.llm_types import LLMConfig, DistributedLLMConfig
46
from fmcore.prompt_tuner.evaluator.types.evaluator_types import EvaluatorConfig
57
from fmcore.types.mixins_types import Mixin
@@ -14,7 +16,7 @@ class StudentConfigMixin(MutableTyped, Mixin):
1416
student_config (Optional[LLMConfig]): The LLM configuration object for student model
1517
"""
1618

17-
student_config: Union[LLMConfig, DistributedLLMConfig]
19+
student_config: Union[LLMConfig, DistributedLLMConfig] = Field(union_mode="left_to_right")
1820

1921

2022
class TeacherConfigMixin(MutableTyped, Mixin):
@@ -25,7 +27,7 @@ class TeacherConfigMixin(MutableTyped, Mixin):
2527
teacher_config (Optional[LLMConfig]): The LLM configuration object for teacher model
2628
"""
2729

28-
teacher_config: Union[LLMConfig, DistributedLLMConfig]
30+
teacher_config: Union[LLMConfig, DistributedLLMConfig] = Field(union_mode="left_to_right")
2931

3032

3133
class EvaluatorConfigMixin(MutableTyped, Mixin):

0 commit comments

Comments (0)