Skip to content

Commit 279b93c

Browse files
committed
Update test mocks to be json serializable
1 parent df786d7 commit 279b93c

File tree

2 files changed

+49
-29
lines changed

2 files changed

+49
-29
lines changed

tests/completions/services/test_completion_service.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,19 @@
33
from completions import CompletionService
44

55

6+
class FakeChunk:
7+
def __init__(self, id, content):
8+
self.id = id
9+
self.content = content
10+
11+
612
@pytest.fixture
713
def fake_repository():
814
repo = MagicMock()
915
repo.get_top_k_chunks_by_similarity.return_value = [
10-
MagicMock(content="Chunk 1"),
11-
MagicMock(content="Chunk 2"),
12-
MagicMock(content="Chunk 3"),
16+
FakeChunk(id=1, content="Chunk 1"),
17+
FakeChunk(id=2, content="Chunk 2"),
18+
FakeChunk(id=3, content="Chunk 3"),
1319
]
1420
return repo
1521

@@ -30,7 +36,7 @@ def service(fake_repository):
3036
def test_create(mock_create_prompt, mock_create_completion, service, capsys):
3137
"""Should call the LLM with the correct prompt and print the answer."""
3238

33-
service.create("What is AI?", k=3)
39+
result = service.create("What is AI?", k=3)
3440

3541
service.repository.get_top_k_chunks_by_similarity.assert_called_once_with(
3642
"What is AI?", 3
@@ -40,5 +46,4 @@ def test_create(mock_create_prompt, mock_create_completion, service, capsys):
4046
)
4147
mock_create_completion.assert_called_once_with("Generated Prompt")
4248

43-
captured = capsys.readouterr()
44-
assert "LLM Answer" in captured.out
49+
assert result == "LLM Answer"
Lines changed: 38 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,36 @@
11
import pytest
2-
from unittest.mock import MagicMock
2+
from dataclasses import dataclass
33
from completions.services.context_generation_service import ContextGenerationService
44

55

6+
@dataclass
7+
class FakeChunk:
8+
id: int
9+
content: str
10+
11+
612
@pytest.fixture
713
def fake_repository():
8-
repo = MagicMock()
9-
repo.get_top_k_chunks_by_similarity.return_value = [
10-
MagicMock(content="Chunk 1"),
11-
MagicMock(content="Chunk 2"),
12-
MagicMock(content="Chunk 3"),
13-
]
14-
return repo
14+
class FakeRepository:
15+
def __init__(self):
16+
self._return_value = [
17+
FakeChunk(id=1, content="Chunk 1"),
18+
FakeChunk(id=2, content="Chunk 2"),
19+
FakeChunk(id=3, content="Chunk 3"),
20+
]
21+
self.called_with = None
22+
23+
def get_top_k_chunks_by_similarity(self, query, k):
24+
self.called_with = (query, k)
25+
return self._return_value
26+
27+
def set_return_value(self, chunks):
28+
self._return_value = chunks
29+
30+
def assert_called_once_with(self, query, k):
31+
assert self.called_with == (query, k)
32+
33+
return FakeRepository()
1534

1635

1736
def test_generate_context_returns_joined_content(fake_repository):
@@ -20,35 +39,31 @@ def test_generate_context_returns_joined_content(fake_repository):
2039
)
2140
context = service.process(k=3)
2241

23-
fake_repository.get_top_k_chunks_by_similarity.assert_called_once_with(
24-
"example query", 3
25-
)
42+
fake_repository.assert_called_once_with("example query", 3)
2643
assert context == "Chunk 1\n\nChunk 2\n\nChunk 3"
2744

2845

2946
def test_generate_context_with_different_k(fake_repository):
30-
# Change fake_repository to return fewer chunks
31-
fake_repository.get_top_k_chunks_by_similarity.return_value = [
32-
MagicMock(content="Only Chunk 1"),
33-
]
47+
fake_repository.set_return_value(
48+
[
49+
FakeChunk(id=1, content="Only Chunk 1"),
50+
]
51+
)
52+
3453
service = ContextGenerationService(
3554
query="different query", repository=fake_repository
3655
)
3756
context = service.process(k=1)
3857

39-
fake_repository.get_top_k_chunks_by_similarity.assert_called_once_with(
40-
"different query", 1
41-
)
58+
fake_repository.assert_called_once_with("different query", 1)
4259
assert context == "Only Chunk 1"
4360

4461

4562
def test_generate_context_when_no_chunks(fake_repository):
46-
# Make it return an empty list
47-
fake_repository.get_top_k_chunks_by_similarity.return_value = []
63+
fake_repository.set_return_value([])
64+
4865
service = ContextGenerationService(query="empty query", repository=fake_repository)
4966
context = service.process(k=3)
5067

51-
fake_repository.get_top_k_chunks_by_similarity.assert_called_once_with(
52-
"empty query", 3
53-
)
68+
fake_repository.assert_called_once_with("empty query", 3)
5469
assert context == ""

0 commit comments

Comments
 (0)