Skip to content

Commit 541eb3f

Browse files
committed
fix(tests): Fix tests
1 parent da821fc commit 541eb3f

File tree

4 files changed

+155
-109
lines changed

4 files changed

+155
-109
lines changed

packages/gg_api_core/src/gg_api_core/tools/remediate_secret_incidents.py

Lines changed: 51 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -68,16 +68,20 @@ class RemediateSecretIncidentsResult(BaseModel):
6868
env_example_content: str | None = Field(default=None, description="Suggested .env.example content")
6969
env_example_instructions: list[str] | None = Field(default=None, description="Instructions for .env.example")
7070
git_commands: dict[str, Any] | None = Field(default=None, description="Git commands to fix history")
71-
applied_filters: dict[str, Any] = Field(default_factory=dict, description="Filters applied when querying occurrences")
71+
applied_filters: dict[str, Any] = Field(default_factory=dict,
72+
description="Filters applied when querying occurrences")
7273
suggestion: str = Field(default="", description="Suggestions for interpreting results")
74+
sub_tools_results: dict[str, Any] = Field(default_factory=dict, description="Results from sub tools")
7375

7476

7577
class RemediateSecretIncidentsError(BaseModel):
7678
"""Error result from remediating secret incidents."""
7779
error: str = Field(description="Error message")
80+
sub_tools_results: dict[str, Any] = Field(default_factory=dict, description="Results from sub tools")
7881

7982

80-
async def remediate_secret_incidents(params: RemediateSecretIncidentsParams) -> RemediateSecretIncidentsResult | RemediateSecretIncidentsError:
83+
async def remediate_secret_incidents(
84+
params: RemediateSecretIncidentsParams) -> RemediateSecretIncidentsResult | RemediateSecretIncidentsError:
8185
"""
8286
Find and remediate secret incidents in the current repository using EXACT match locations.
8387
@@ -124,12 +128,31 @@ async def remediate_secret_incidents(params: RemediateSecretIncidentsParams) ->
124128

125129
try:
126130
# Get detailed occurrences with exact match locations
127-
# Extract filter parameters from list_repo_occurrences_params
128-
occurrences_result = await list_repo_occurrences(params.list_repo_occurrences_params)
131+
# Build ListRepoOccurrencesParams by combining repository info with filters
132+
from .list_repo_occurrences import ListRepoOccurrencesParams
133+
134+
occurrences_params = ListRepoOccurrencesParams(
135+
repository_name=params.repository_name,
136+
source_id=params.source_id,
137+
from_date=params.list_repo_occurrences_params.from_date,
138+
to_date=params.list_repo_occurrences_params.to_date,
139+
presence=params.list_repo_occurrences_params.presence,
140+
tags=params.list_repo_occurrences_params.tags,
141+
exclude_tags=params.list_repo_occurrences_params.exclude_tags,
142+
status=params.list_repo_occurrences_params.status,
143+
severity=params.list_repo_occurrences_params.severity,
144+
validity=params.list_repo_occurrences_params.validity,
145+
ordering=None,
146+
per_page=20,
147+
cursor=None,
148+
get_all=params.get_all,
149+
)
150+
occurrences_result = await list_repo_occurrences(occurrences_params)
129151

130152
# Check if list_repo_occurrences returned an error
131-
if getattr(occurrences_result, "error"):
132-
return RemediateSecretIncidentsError(error=occurrences_result.error)
153+
if hasattr(occurrences_result, "error") and occurrences_result.error:
154+
return RemediateSecretIncidentsError(error=occurrences_result.error,
155+
sub_tools_results={"list_repo_occurrences": occurrences_result})
133156
occurrences = occurrences_result.occurrences
134157

135158
# Filter by assignee if mine=True
@@ -157,6 +180,7 @@ async def remediate_secret_incidents(params: RemediateSecretIncidentsParams) ->
157180
remediation_steps=[],
158181
applied_filters=occurrences_result.applied_filters or {},
159182
suggestion=occurrences_result.suggestion or "",
183+
sub_tools_results={"list_repo_occurrences": occurrences_result}
160184
)
161185

162186
# Process occurrences for remediation with exact location data
@@ -168,22 +192,34 @@ async def remediate_secret_incidents(params: RemediateSecretIncidentsParams) ->
168192
create_env_example=params.create_env_example,
169193
)
170194
logger.debug(
171-
f"Remediation processing complete, returning result with {len(result.get('remediation_steps', []))} steps"
195+
f"Remediation processing complete, returning result with {len(result.remediation_steps)} steps"
172196
)
173-
# Convert dict result to Pydantic model
174-
return RemediateSecretIncidentsResult(**result)
197+
198+
# Add sub_tools_results and applied_filters/suggestion from occurrences_result
199+
result_dict = result.model_dump()
200+
result_dict["sub_tools_results"] = {
201+
"list_repo_occurrences": {
202+
"total_occurrences": result.summary.get("total_occurrences",
203+
len(occurrences)) if result.summary else len(occurrences),
204+
"affected_files": result.summary.get("affected_files", 0) if result.summary else 0,
205+
}
206+
}
207+
result_dict["applied_filters"] = occurrences_result.applied_filters or {}
208+
result_dict["suggestion"] = occurrences_result.suggestion or ""
209+
210+
return RemediateSecretIncidentsResult(**result_dict)
175211

176212
except Exception as e:
177213
logger.error(f"Error remediating incidents: {str(e)}")
178214
return RemediateSecretIncidentsError(error=f"Failed to remediate incidents: {str(e)}")
179215

180216

181217
async def _process_occurrences_for_remediation(
182-
occurrences: list[dict[str, Any]],
183-
repository_name: str,
184-
include_git_commands: bool = True,
185-
create_env_example: bool = True,
186-
) -> dict[str, Any]:
218+
occurrences: list[dict[str, Any]],
219+
repository_name: str,
220+
include_git_commands: bool = True,
221+
create_env_example: bool = True,
222+
) -> RemediateSecretIncidentsResult:
187223
"""
188224
Process occurrences for remediation using exact match locations.
189225
@@ -319,4 +355,4 @@ async def _process_occurrences_for_remediation(
319355
if git_commands:
320356
result["git_commands"] = git_commands
321357

322-
return result
358+
return RemediateSecretIncidentsResult(**result)

tests/tools/test_list_repo_incidents.py

Lines changed: 18 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ async def test_list_repo_incidents_with_repository_name(
1818
"""
1919
# Mock the client response
2020
mock_response = {
21-
"data": [
21+
"incidents": [
2222
{
2323
"id": "incident_1",
2424
"detector": {"name": "AWS Access Key"},
@@ -56,7 +56,8 @@ async def test_list_repo_incidents_with_repository_name(
5656
assert call_kwargs["mine"] is True
5757

5858
# Verify response
59-
assert result == mock_response
59+
assert result.total_count == mock_response["total_count"]
60+
assert len(result.incidents) == len(mock_response["incidents"])
6061

6162
@pytest.mark.asyncio
6263
async def test_list_repo_incidents_with_source_id(self, mock_gitguardian_client):
@@ -106,9 +107,9 @@ async def test_list_repo_incidents_with_source_id(self, mock_gitguardian_client)
106107
assert call_args[1]["with_sources"] == "false"
107108

108109
# Verify response
109-
assert "source_id" in result
110-
assert result["source_id"] == "source_123"
111-
assert len(result["incidents"]) == 1
110+
assert hasattr(result, "source_id")
111+
assert result.source_id == "source_123"
112+
assert len(result.incidents) == 1
112113

113114
@pytest.mark.asyncio
114115
async def test_list_repo_incidents_with_filters(self, mock_gitguardian_client):
@@ -182,8 +183,8 @@ async def test_list_repo_incidents_get_all(self, mock_gitguardian_client):
182183
mock_gitguardian_client.paginate_all.assert_called_once()
183184

184185
# Verify response
185-
assert result["total_count"] == 3
186-
assert len(result["incidents"]) == 3
186+
assert result.total_count == 3
187+
assert len(result.incidents) == 3
187188

188189
@pytest.mark.asyncio
189190
async def test_list_repo_incidents_no_repository_or_source(
@@ -212,8 +213,8 @@ async def test_list_repo_incidents_no_repository_or_source(
212213
)
213214

214215
# Verify error response
215-
assert "error" in result
216-
assert "Either repository_name or source_id must be provided" in result["error"]
216+
assert hasattr(result, "error")
217+
assert "Either repository_name or source_id must be provided" in result.error
217218

218219
@pytest.mark.asyncio
219220
async def test_list_repo_incidents_client_error(self, mock_gitguardian_client):
@@ -246,8 +247,8 @@ async def test_list_repo_incidents_client_error(self, mock_gitguardian_client):
246247
)
247248

248249
# Verify error response
249-
assert "error" in result
250-
assert "Failed to list repository incidents" in result["error"]
250+
assert hasattr(result, "error")
251+
assert "Failed to list repository incidents" in result.error
251252

252253
@pytest.mark.asyncio
253254
async def test_list_repo_incidents_with_cursor(self, mock_gitguardian_client):
@@ -322,9 +323,9 @@ async def test_list_repo_incidents_source_id_list_response(
322323
)
323324

324325
# Verify response format
325-
assert result["source_id"] == "source_123"
326-
assert result["total_count"] == 2
327-
assert len(result["incidents"]) == 2
326+
assert result.source_id == "source_123"
327+
assert result.total_count == 2
328+
assert len(result.incidents) == 2
328329

329330
@pytest.mark.asyncio
330331
async def test_list_repo_incidents_get_all_dict_response(
@@ -348,6 +349,6 @@ async def test_list_repo_incidents_get_all_dict_response(
348349
)
349350

350351
# Verify response
351-
assert result["source_id"] == "source_123"
352-
assert result["total_count"] == 2
353-
assert len(result["incidents"]) == 2
352+
assert result.source_id == "source_123"
353+
assert result.total_count == 2
354+
assert len(result.incidents) == 2

tests/tools/test_list_repo_occurrences.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from unittest.mock import AsyncMock
22

33
import pytest
4+
from pydantic import ValidationError
45
from gg_api_core.tools.list_repo_occurrences import list_repo_occurrences, ListRepoOccurrencesParams
56

67

@@ -184,11 +185,11 @@ async def test_list_repo_occurrences_no_repository_or_source(
184185
):
185186
"""
186187
GIVEN: Neither repository_name nor source_id provided
187-
WHEN: Attempting to list occurrences
188-
THEN: An error is returned
188+
WHEN: Attempting to create params
189+
THEN: A ValidationError is raised
189190
"""
190-
# Call the function without repository_name or source_id
191-
result = await list_repo_occurrences(
191+
# Try to create params without repository_name or source_id
192+
with pytest.raises(ValidationError) as exc_info:
192193
ListRepoOccurrencesParams(
193194
repository_name=None,
194195
source_id=None,
@@ -201,11 +202,11 @@ async def test_list_repo_occurrences_no_repository_or_source(
201202
cursor=None,
202203
get_all=False,
203204
)
204-
)
205205

206-
# Verify error response
207-
assert hasattr(result, "error")
208-
assert "Either repository_name or source_id must be provided" in result.error
206+
# Verify error message
207+
errors = exc_info.value.errors()
208+
assert len(errors) == 1
209+
assert "Either 'source_id' or 'repository_name' must be provided" in str(errors[0])
209210

210211
@pytest.mark.asyncio
211212
async def test_list_repo_occurrences_client_error(self, mock_gitguardian_client):

0 commit comments

Comments
 (0)