Skip to content

Commit a5a2d85

Browse files
committed
Add results to completions
1 parent 6cec5fe commit a5a2d85

File tree

6 files changed

+51
-23
lines changed

6 files changed

+51
-23
lines changed

requirements.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,11 +32,13 @@ prompt_toolkit==3.0.51
3232
psycopg2-binary==2.9.10
3333
PyMuPDF==1.25.5
3434
pytest==8.3.5
35+
pytest-describe==2.2.0
3536
python-dateutil==2.9.0.post0
3637
PyYAML==6.0.2
3738
redis==5.2.1
3839
regex==2024.11.6
3940
requests==2.32.3
41+
result==0.17.0
4042
safetensors==0.5.3
4143
scikit-learn==1.6.1
4244
scipy==1.15.2

src/apps/dev_cli.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from completions import CompletionService
66
from infrastructure.db.source_repository import SourceRepository
77
from config.logger import logger
8+
from result import Ok
89

910
if __name__ == "__main__":
1011
with SourceRepository() as repository:
@@ -14,5 +15,12 @@
1415
q = input("\nAsk a question (or 'exit'): ")
1516
if q.lower() in {"exit", "quit"}:
1617
break
17-
answer = service.create(q)
18-
logger.info("dev_cli.answer_displayed", answer=answer)
18+
19+
result = service.create(q)
20+
21+
if isinstance(result, Ok):
22+
logger.info("dev_cli.answer_displayed", answer=result.ok_value)
23+
print(f"\nAnswer: {result.ok_value}")
24+
else:
25+
logger.error("dev_cli.answer_failed", error=repr(result.err_value))
26+
print(f"\nError: {result.err_value}")
Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from result import Ok, Result
12
from completions.prompts import create_prompt
23
from infrastructure.ollama import create_completion
34
from infrastructure.db.source_repository import SourceRepository
@@ -9,16 +10,26 @@ class CompletionService:
910
def __init__(self, repository: SourceRepository):
1011
self.repository = repository
1112

12-
def create(self, query, k=3):
13+
def create(self, query: str, k: int = 3) -> Result[str, Exception]:
1314
logger.info("completion_service.create", query=query)
1415

1516
prompt_context = ContextGenerationService(query, self.repository).process(k)
1617
logger.info("completion_service.context_generated", query=query)
1718

18-
answer = create_completion(create_prompt(prompt_context, query))
19-
logger.info(
20-
"completion_service.answer_generated",
21-
query=query,
22-
answer=answer[0:10] + "...",
23-
)
24-
return answer
19+
prompt = create_prompt(query, prompt_context)
20+
result = create_completion(prompt)
21+
22+
if isinstance(result, Ok):
23+
logger.info(
24+
"completion_service.answer_generated",
25+
query=query,
26+
answer=result.ok_value[:10] + "...",
27+
)
28+
else:
29+
logger.error(
30+
"completion_service.answer_generation_error",
31+
query=query,
32+
error=result.err_value,
33+
)
34+
35+
return result

src/config/logger.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@ def __init__(self, name: Optional[str] = None):
2525
self.logger = logging.getLogger(name or __name__)
2626
self.logger.propagate = False # 🔒 Prevent double logging
2727

28-
# Only add handler if none exist
2928
if not self.logger.handlers:
3029
handler = logging.StreamHandler()
3130
formatter = ColorFormatter("%(asctime)s [%(levelname)s] %(message)s")
@@ -48,7 +47,7 @@ def debug(self, event: str, **data: Any):
4847

4948
def _format(self, event: str, data: Dict[str, Any]) -> str:
5049
structured = {"event": event, **data}
51-
return json.dumps(structured)
50+
return json.dumps(structured, default=str)
5251

5352

5453
logger = SemanticLogger(__name__)

src/infrastructure/ollama.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,16 @@
11
import requests
2+
from result import Ok, Err, Result
23
from config import OLLAMA_MODEL, OLLAMA_URL
34

45

5-
def create_completion(prompt):
6-
res = requests.post(
7-
f"{OLLAMA_URL}/api/generate",
8-
json={"model": OLLAMA_MODEL, "prompt": prompt, "stream": False},
9-
)
10-
return res.json()["response"]
6+
def create_completion(prompt: str) -> Result[str, Exception]:
7+
try:
8+
res = requests.post(
9+
f"{OLLAMA_URL}/api/generate",
10+
json={"model": OLLAMA_MODEL, "prompt": prompt, "stream": False},
11+
timeout=30,
12+
)
13+
res.raise_for_status()
14+
return Ok(res.json()["response"])
15+
except (requests.RequestException, ValueError, KeyError) as e:
16+
return Err(e)

tests/completions/services/test_completion_service.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import pytest
22
from unittest.mock import MagicMock, patch
3+
from result import Ok
34
from completions import CompletionService
45

56

@@ -27,23 +28,24 @@ def service(fake_repository):
2728

2829
@patch(
2930
"completions.services.completion_service.create_completion",
30-
return_value="LLM Answer",
31+
return_value=Ok("LLM Answer"),
3132
)
3233
@patch(
3334
"completions.services.completion_service.create_prompt",
3435
return_value="Generated Prompt",
3536
)
36-
def test_create(mock_create_prompt, mock_create_completion, service, capsys):
37-
"""Should call the LLM with the correct prompt and print the answer."""
37+
def test_create(mock_create_prompt, mock_create_completion, service):
38+
"""Should call the LLM with the correct prompt and return Ok result."""
3839

3940
result = service.create("What is AI?", k=3)
4041

4142
service.repository.get_top_k_chunks_by_similarity.assert_called_once_with(
4243
"What is AI?", 3
4344
)
4445
mock_create_prompt.assert_called_once_with(
45-
"Chunk 1\n\nChunk 2\n\nChunk 3", "What is AI?"
46+
"What is AI?", "Chunk 1\n\nChunk 2\n\nChunk 3"
4647
)
4748
mock_create_completion.assert_called_once_with("Generated Prompt")
4849

49-
assert result == "LLM Answer"
50+
assert isinstance(result, Ok)
51+
assert result.ok_value == "LLM Answer"

0 commit comments

Comments
 (0)