11import pytest
2- from unittest . mock import MagicMock
2+ from dataclasses import dataclass
33from completions .services .context_generation_service import ContextGenerationService
44
55
6+ @dataclass
7+ class FakeChunk :
8+ id : int
9+ content : str
10+
11+
612@pytest .fixture
713def fake_repository ():
8- repo = MagicMock ()
9- repo .get_top_k_chunks_by_similarity .return_value = [
10- MagicMock (content = "Chunk 1" ),
11- MagicMock (content = "Chunk 2" ),
12- MagicMock (content = "Chunk 3" ),
13- ]
14- return repo
14+ class FakeRepository :
15+ def __init__ (self ):
16+ self ._return_value = [
17+ FakeChunk (id = 1 , content = "Chunk 1" ),
18+ FakeChunk (id = 2 , content = "Chunk 2" ),
19+ FakeChunk (id = 3 , content = "Chunk 3" ),
20+ ]
21+ self .called_with = None
22+
23+ def get_top_k_chunks_by_similarity (self , query , k ):
24+ self .called_with = (query , k )
25+ return self ._return_value
26+
27+ def set_return_value (self , chunks ):
28+ self ._return_value = chunks
29+
30+ def assert_called_once_with (self , query , k ):
31+ assert self .called_with == (query , k )
32+
33+ return FakeRepository ()
1534
1635
1736def test_generate_context_returns_joined_content (fake_repository ):
@@ -20,35 +39,31 @@ def test_generate_context_returns_joined_content(fake_repository):
2039 )
2140 context = service .process (k = 3 )
2241
23- fake_repository .get_top_k_chunks_by_similarity .assert_called_once_with (
24- "example query" , 3
25- )
42+ fake_repository .assert_called_once_with ("example query" , 3 )
2643 assert context == "Chunk 1\n \n Chunk 2\n \n Chunk 3"
2744
2845
2946def test_generate_context_with_different_k (fake_repository ):
30- # Change fake_repository to return fewer chunks
31- fake_repository .get_top_k_chunks_by_similarity .return_value = [
32- MagicMock (content = "Only Chunk 1" ),
33- ]
47+ fake_repository .set_return_value (
48+ [
49+ FakeChunk (id = 1 , content = "Only Chunk 1" ),
50+ ]
51+ )
52+
3453 service = ContextGenerationService (
3554 query = "different query" , repository = fake_repository
3655 )
3756 context = service .process (k = 1 )
3857
39- fake_repository .get_top_k_chunks_by_similarity .assert_called_once_with (
40- "different query" , 1
41- )
58+ fake_repository .assert_called_once_with ("different query" , 1 )
4259 assert context == "Only Chunk 1"
4360
4461
4562def test_generate_context_when_no_chunks (fake_repository ):
46- # Make it return an empty list
47- fake_repository . get_top_k_chunks_by_similarity . return_value = []
63+ fake_repository . set_return_value ([])
64+
4865 service = ContextGenerationService (query = "empty query" , repository = fake_repository )
4966 context = service .process (k = 3 )
5067
51- fake_repository .get_top_k_chunks_by_similarity .assert_called_once_with (
52- "empty query" , 3
53- )
68+ fake_repository .assert_called_once_with ("empty query" , 3 )
5469 assert context == ""
0 commit comments