Skip to content

Commit f6c9677

Browse files
committed
Add source repository tests
1 parent 9975633 commit f6c9677

File tree

1 file changed

+90
-0
lines changed

1 file changed

+90
-0
lines changed
Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
import pytest
2+
from unittest.mock import patch, MagicMock
3+
from infrastructure.db.models import Source
4+
from infrastructure.db.source_repository import (
5+
SourceRepository,
6+
)
7+
8+
9+
@pytest.fixture
10+
def fake_session():
11+
return MagicMock()
12+
13+
14+
@pytest.fixture
15+
def repo(fake_session):
16+
with patch(
17+
"infrastructure.db.source_repository.SessionLocal", return_value=fake_session
18+
):
19+
yield SourceRepository()
20+
21+
22+
def test_has_chunks_true(fake_session, repo):
23+
"""Should return True if a chunk exists."""
24+
fake_session.query.return_value.first.return_value = True
25+
assert repo.has_chunks() is True
26+
27+
28+
def test_has_chunks_false(fake_session, repo):
29+
"""Should return False if no chunks exist."""
30+
fake_session.query.return_value.first.return_value = None
31+
assert repo.has_chunks() is False
32+
33+
34+
def test_get_or_create_source_existing(fake_session, repo):
35+
"""Should return existing source if found."""
36+
fake_source = MagicMock()
37+
fake_session.query.return_value.filter_by.return_value.first.return_value = (
38+
fake_source
39+
)
40+
41+
source = repo.get_or_create_source("Test Source", "http://example.com")
42+
43+
assert source == fake_source
44+
fake_session.add.assert_not_called()
45+
fake_session.commit.assert_not_called()
46+
47+
48+
def test_get_or_create_source_new(fake_session, repo):
49+
"""Should create and return a new source if not found."""
50+
fake_session.query.return_value.filter_by.return_value.first.return_value = None
51+
52+
source = repo.get_or_create_source("New Source", "http://newsite.com")
53+
54+
assert isinstance(source, Source)
55+
fake_session.add.assert_called_once()
56+
fake_session.commit.assert_called_once()
57+
fake_session.refresh.assert_called_once_with(source)
58+
59+
60+
def test_save_chunks(fake_session, repo):
61+
"""Should add chunks and commit them."""
62+
fake_source = MagicMock()
63+
chunks = ["chunk one", "chunk two"]
64+
embeddings = [[0.1] * 384, [0.2] * 384]
65+
66+
repo.save_chunks(fake_source, chunks, embeddings)
67+
68+
# Should add two chunks
69+
assert fake_session.add.call_count == 2
70+
fake_session.commit.assert_called_once()
71+
72+
73+
@patch("infrastructure.db.source_repository.SentenceTransformer")
74+
def test_get_top_k_chunks_by_similarity(mock_transformer, fake_session, repo):
75+
"""Should query top K chunks by similarity."""
76+
# Mock SentenceTransformer.encode
77+
mock_model_instance = MagicMock()
78+
mock_model_instance.encode.return_value = [[0.5] * 384]
79+
mock_transformer.return_value = mock_model_instance
80+
81+
# Mock session.execute
82+
fake_execute_result = MagicMock()
83+
fake_execute_result.scalars.return_value.all.return_value = ["chunk1", "chunk2"]
84+
fake_session.execute.return_value = fake_execute_result
85+
86+
chunks = repo.get_top_k_chunks_by_similarity("What is AI?", k=2)
87+
88+
mock_model_instance.encode.assert_called_once_with(["What is AI?"])
89+
fake_session.execute.assert_called_once()
90+
assert chunks == ["chunk1", "chunk2"]

0 commit comments

Comments
 (0)