Skip to content

Commit a1ae458

Browse files
SibaRajendranSiba Rajendran
andauthored
Dspy Bug Fixes (#54)
Co-authored-by: Siba Rajendran <rajsiba@amazon.com>
1 parent 342a056 commit a1ae458

File tree

5 files changed

+41
-3
lines changed

5 files changed

+41
-3
lines changed

src/fmcore/prompt_tuner/dspy/dspy_prompt_tuner.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,10 @@
1+
import os
2+
import tempfile
3+
4+
# Set the DSP_CACHEDIR environment variable to the system's default temporary directory
5+
os.environ["DSP_CACHEDIR"] = tempfile.gettempdir()
6+
os.environ["DSPY_CACHEDIR"] = tempfile.gettempdir()
7+
18
import dspy
29
import pandas as pd
310
from typing import Dict, List

src/fmcore/prompt_tuner/dspy/lm_adapters/dspy_adapter.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
from fmcore.llm.types.llm_types import LLMConfig
55
from langchain_core.messages import BaseMessage
66

7+
from fmcore.utils.async_utils import AsyncUtils
8+
79

810
class DSPyLLMAdapter(dspy.LM):
911
"""
@@ -65,7 +67,8 @@ def __call__(
6567
if prompt:
6668
messages = [{"role": "user", "content": prompt}]
6769

68-
response = self.llm.invoke(messages)
70+
# We are using this hack because dspy doesn't support async
71+
response = AsyncUtils.execute(self.llm.ainvoke(messages))
6972
result = [response.content]
7073

7174
# Update history with DSPy constructs, which currently support only dictionaries

src/fmcore/prompt_tuner/dspy/utils/dspy_utils.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99
from fmcore.prompt_tuner.evaluator.types.evaluator_types import EvaluatorConfig
1010
from fmcore.prompt_tuner.types.prompt_tuner_types import PromptConfig, PromptEvaluationResult
1111
from fmcore.types.enums.dataset_enums import DatasetType
12+
from fmcore.utils.async_utils import AsyncUtils
13+
from fmcore.utils.logging_utils import Log
1214

1315

1416
class DSPyUtils:
@@ -166,7 +168,15 @@ def evaluate_func(example: dspy.Example, prediction: dspy.Prediction, trace=None
166168
"output": prediction.toDict(),
167169
}
168170

169-
return evaluator.evaluate(data=row)
171+
try:
172+
# We are using this hack because dspy doesn't support async
173+
decision = AsyncUtils.execute(evaluator.aevaluate(data=row))
174+
except Exception as e:
175+
# Defaulting to false incase of failures
176+
Log.info(f"Error {e} during evaluating {row}")
177+
decision = False
178+
179+
return decision
170180

171181
return evaluate_func
172182

src/fmcore/prompt_tuner/evaluator/llm_as_a_judge_boolean/llm_as_a_judge_boolean_evaluator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ async def aevaluate(self, data: Dict) -> bool:
9898
BooleanLLMJudgeOutput: Evaluation result as a boolean decision.
9999
"""
100100
# Format the context into messages using the template
101-
formatted_message: BaseMessage = await self.text_prompt_mapper.amap(data.context)
101+
formatted_message: BaseMessage = await self.text_prompt_mapper.amap(data)
102102
llm_response: BaseMessage = await self.llm_inference_mapper.amap([formatted_message])
103103
json_response: Dict = await self.json_mapper.amap(llm_response.content)
104104
decision: bool = await self.criteria_checker.amap(json_response)

src/fmcore/utils/async_utils.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
import asyncio
2+
from concurrent.futures import ThreadPoolExecutor
3+
4+
import asyncio
5+
from concurrent.futures import ThreadPoolExecutor
6+
7+
8+
class AsyncUtils:
9+
_executor = ThreadPoolExecutor()
10+
11+
@staticmethod
12+
def execute(coro):
13+
"""
14+
Executes an async coroutine in a thread pool executor.
15+
- No event loop interaction: Just submit coroutines to the thread pool.
16+
- Always uses ThreadPoolExecutor to run async functions.
17+
"""
18+
return AsyncUtils._executor.submit(lambda: asyncio.run(coro)).result()

0 commit comments

Comments
 (0)