Skip to content

Commit b9c91a9

Browse files
committed
chore(typing): Add Pydantic models for tools return : remediate and list
1 parent 7e64584 commit b9c91a9

File tree

5 files changed

+195
-142
lines changed

5 files changed

+195
-142
lines changed

packages/gg_api_core/src/gg_api_core/tools/list_repo_incidents.py

Lines changed: 64 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,22 @@ class ListRepoIncidentsParams(BaseModel):
128128
validity: list[str] | None = Field(default=DEFAULT_VALIDITIES, description="Filter by validity (list of validity names)")
129129

130130

131-
async def list_repo_incidents(params: ListRepoIncidentsParams) -> dict[str, Any]:
131+
class ListRepoIncidentsResult(BaseModel):
132+
"""Result from listing repository incidents."""
133+
source_id: str | None = Field(default=None, description="Source ID of the repository")
134+
incidents: list[dict[str, Any]] = Field(default_factory=list, description="List of incident objects")
135+
total_count: int = Field(description="Total number of incidents")
136+
next_cursor: str | None = Field(default=None, description="Pagination cursor for next page")
137+
applied_filters: dict[str, Any] = Field(default_factory=dict, description="Filters that were applied to the query")
138+
suggestion: str = Field(default="", description="Suggestions for interpreting or modifying the results")
139+
140+
141+
class ListRepoIncidentsError(BaseModel):
142+
"""Error result from listing repository incidents."""
143+
error: str = Field(description="Error message")
144+
145+
146+
async def list_repo_incidents(params: ListRepoIncidentsParams) -> ListRepoIncidentsResult | ListRepoIncidentsError:
132147
"""
133148
List secret incidents or occurrences related to a specific repository.
134149
@@ -145,7 +160,7 @@ async def list_repo_incidents(params: ListRepoIncidentsParams) -> dict[str, Any]
145160

146161
# Validate that at least one of repository_name or source_id is provided
147162
if not params.repository_name and not params.source_id:
148-
return {"error": "Either repository_name or source_id must be provided"}
163+
return ListRepoIncidentsError(error="Either repository_name or source_id must be provided")
149164

150165
logger.debug(f"Listing incidents with repository_name={params.repository_name}, source_id={params.source_id}")
151166

@@ -185,64 +200,54 @@ async def list_repo_incidents(params: ListRepoIncidentsParams) -> dict[str, Any]
185200
incidents_result = await client.paginate_all(f"/sources/{params.source_id}/incidents/secrets", api_params)
186201
if isinstance(incidents_result, list):
187202
count = len(incidents_result)
188-
return {
189-
"source_id": params.source_id,
190-
"incidents": incidents_result,
191-
"total_count": count,
192-
"applied_filters": _build_filter_info(params),
193-
"suggestion": _build_suggestion(params, count),
194-
}
203+
return ListRepoIncidentsResult(
204+
source_id=params.source_id,
205+
incidents=incidents_result,
206+
total_count=count,
207+
applied_filters=_build_filter_info(params),
208+
suggestion=_build_suggestion(params, count),
209+
)
195210
elif isinstance(incidents_result, dict):
196211
count = incidents_result.get("total_count", len(incidents_result.get("data", [])))
197-
return {
198-
"source_id": params.source_id,
199-
"incidents": incidents_result.get("data", []),
200-
"total_count": count,
201-
"applied_filters": _build_filter_info(params),
202-
"suggestion": _build_suggestion(params, count),
203-
}
212+
return ListRepoIncidentsResult(
213+
source_id=params.source_id,
214+
incidents=incidents_result.get("data", []),
215+
total_count=count,
216+
applied_filters=_build_filter_info(params),
217+
suggestion=_build_suggestion(params, count),
218+
)
204219
else:
205220
# Fallback for unexpected types
206-
return {
207-
"source_id": params.source_id,
208-
"incidents": [],
209-
"total_count": 0,
210-
"error": f"Unexpected response type: {type(incidents_result).__name__}",
211-
"applied_filters": _build_filter_info(params),
212-
"suggestion": _build_suggestion(params, 0),
213-
}
221+
return ListRepoIncidentsError(
222+
error=f"Unexpected response type: {type(incidents_result).__name__}",
223+
)
214224
else:
215225
incidents_result = await client.list_source_incidents(params.source_id, **api_params)
216226
if isinstance(incidents_result, dict):
217227
count = incidents_result.get("total_count", 0)
218-
return {
219-
"source_id": params.source_id,
220-
"incidents": incidents_result.get("data", []),
221-
"next_cursor": incidents_result.get("next_cursor"),
222-
"total_count": count,
223-
"applied_filters": _build_filter_info(params),
224-
"suggestion": _build_suggestion(params, count),
225-
}
228+
return ListRepoIncidentsResult(
229+
source_id=params.source_id,
230+
incidents=incidents_result.get("data", []),
231+
next_cursor=incidents_result.get("next_cursor"),
232+
total_count=count,
233+
applied_filters=_build_filter_info(params),
234+
suggestion=_build_suggestion(params, count),
235+
)
226236
elif isinstance(incidents_result, list):
227237
# Handle case where API returns a list directly
228238
count = len(incidents_result)
229-
return {
230-
"source_id": params.source_id,
231-
"incidents": incidents_result,
232-
"total_count": count,
233-
"applied_filters": _build_filter_info(params),
234-
"suggestion": _build_suggestion(params, count),
235-
}
239+
return ListRepoIncidentsResult(
240+
source_id=params.source_id,
241+
incidents=incidents_result,
242+
total_count=count,
243+
applied_filters=_build_filter_info(params),
244+
suggestion=_build_suggestion(params, count),
245+
)
236246
else:
237247
# Fallback for unexpected types
238-
return {
239-
"source_id": params.source_id,
240-
"incidents": [],
241-
"total_count": 0,
242-
"error": f"Unexpected response type: {type(incidents_result).__name__}",
243-
"applied_filters": _build_filter_info(params),
244-
"suggestion": _build_suggestion(params, 0),
245-
}
248+
return ListRepoIncidentsError(
249+
error=f"Unexpected response type: {type(incidents_result).__name__}",
250+
)
246251
else:
247252
# Use repository_name lookup (legacy path)
248253
result = await client.list_repo_incidents_directly(
@@ -259,14 +264,20 @@ async def list_repo_incidents(params: ListRepoIncidentsParams) -> dict[str, Any]
259264
mine=params.mine,
260265
)
261266

262-
# Enrich result with filter info
267+
# Enrich result with filter info and convert to Pydantic model
263268
if isinstance(result, dict):
264269
count = result.get("total_count", len(result.get("incidents", [])))
265-
result["applied_filters"] = _build_filter_info(params)
266-
result["suggestion"] = _build_suggestion(params, count)
267-
268-
return result
270+
return ListRepoIncidentsResult(
271+
source_id=result.get("source_id"),
272+
incidents=result.get("incidents", []),
273+
total_count=count,
274+
next_cursor=result.get("next_cursor"),
275+
applied_filters=_build_filter_info(params),
276+
suggestion=_build_suggestion(params, count),
277+
)
278+
else:
279+
return ListRepoIncidentsError(error="Unexpected result format from legacy path")
269280

270281
except Exception as e:
271282
logger.error(f"Error listing repository incidents: {str(e)}")
272-
return {"error": f"Failed to list repository incidents: {str(e)}"}
283+
return ListRepoIncidentsError(error=f"Failed to list repository incidents: {str(e)}")

packages/gg_api_core/src/gg_api_core/tools/list_repo_occurrences.py

Lines changed: 42 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,22 @@ class ListRepoOccurrencesParams(ListRepoOccurrencesFilters, ListRepoOccurrencesB
7474
pass
7575

7676

77+
class ListRepoOccurrencesResult(BaseModel):
78+
"""Result from listing repository occurrences."""
79+
repository: str | None = Field(default=None, description="Repository name")
80+
occurrences_count: int = Field(description="Number of occurrences returned")
81+
occurrences: list[dict[str, Any]] = Field(default_factory=list, description="List of occurrence objects")
82+
cursor: str | None = Field(default=None, description="Pagination cursor for next page")
83+
has_more: bool = Field(default=False, description="Whether more results are available")
84+
applied_filters: dict[str, Any] = Field(default_factory=dict, description="Filters that were applied to the query")
85+
suggestion: str = Field(default="", description="Suggestions for interpreting or modifying the results")
86+
87+
88+
class ListRepoOccurrencesError(BaseModel):
89+
"""Error result from listing repository occurrences."""
90+
error: str = Field(description="Error message")
91+
92+
7793
def _build_filter_info(params: ListRepoOccurrencesParams) -> dict[str, Any]:
7894
"""Build a dictionary describing the filters applied to the query."""
7995
filters = {}
@@ -127,7 +143,7 @@ def _build_suggestion(params: ListRepoOccurrencesParams, occurrences_count: int)
127143
return "\n".join(suggestions) if suggestions else ""
128144

129145

130-
async def list_repo_occurrences(params: ListRepoOccurrencesParams) -> dict[str, Any]:
146+
async def list_repo_occurrences(params: ListRepoOccurrencesParams) -> ListRepoOccurrencesResult | ListRepoOccurrencesError:
131147
"""
132148
List secret occurrences for a specific repository using the GitGuardian v1/occurrences/secrets API.
133149
@@ -159,7 +175,7 @@ async def list_repo_occurrences(params: ListRepoOccurrencesParams) -> dict[str,
159175

160176
# Validate that at least one of repository_name or source_id is provided
161177
if not params.repository_name and not params.source_id:
162-
return {"error": "Either repository_name or source_id must be provided"}
178+
return ListRepoOccurrencesError(error="Either repository_name or source_id must be provided")
163179

164180
logger.debug(f"Listing occurrences with repository_name={params.repository_name}, source_id={params.source_id}")
165181

@@ -208,34 +224,34 @@ async def list_repo_occurrences(params: ListRepoOccurrencesParams) -> dict[str,
208224
if isinstance(result, dict):
209225
occurrences = result.get("occurrences", [])
210226
count = len(occurrences)
211-
return {
212-
"repository": params.repository_name,
213-
"occurrences_count": count,
214-
"occurrences": occurrences,
215-
"cursor": result.get("cursor"),
216-
"has_more": result.get("has_more", False),
217-
"applied_filters": _build_filter_info(params),
218-
"suggestion": _build_suggestion(params, count),
219-
}
227+
return ListRepoOccurrencesResult(
228+
repository=params.repository_name,
229+
occurrences_count=count,
230+
occurrences=occurrences,
231+
cursor=result.get("cursor"),
232+
has_more=result.get("has_more", False),
233+
applied_filters=_build_filter_info(params),
234+
suggestion=_build_suggestion(params, count),
235+
)
220236
elif isinstance(result, list):
221237
# If get_all=True, we get a list directly
222238
count = len(result)
223-
return {
224-
"repository": params.repository_name,
225-
"occurrences_count": count,
226-
"occurrences": result,
227-
"applied_filters": _build_filter_info(params),
228-
"suggestion": _build_suggestion(params, count),
229-
}
239+
return ListRepoOccurrencesResult(
240+
repository=params.repository_name,
241+
occurrences_count=count,
242+
occurrences=result,
243+
applied_filters=_build_filter_info(params),
244+
suggestion=_build_suggestion(params, count),
245+
)
230246
else:
231-
return {
232-
"repository": params.repository_name,
233-
"occurrences_count": 0,
234-
"occurrences": [],
235-
"applied_filters": _build_filter_info(params),
236-
"suggestion": _build_suggestion(params, 0),
237-
}
247+
return ListRepoOccurrencesResult(
248+
repository=params.repository_name,
249+
occurrences_count=0,
250+
occurrences=[],
251+
applied_filters=_build_filter_info(params),
252+
suggestion=_build_suggestion(params, 0),
253+
)
238254

239255
except Exception as e:
240256
logger.error(f"Error listing repository occurrences: {str(e)}")
241-
return {"error": f"Failed to list repository occurrences: {str(e)}"}
257+
return ListRepoOccurrencesError(error=f"Failed to list repository occurrences: {str(e)}")

packages/gg_api_core/src/gg_api_core/tools/remediate_secret_incidents.py

Lines changed: 39 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,25 @@ def validate_source_or_repository(self) -> "RemediateSecretIncidentsParams":
5656
return self
5757

5858

59-
async def remediate_secret_incidents(params: RemediateSecretIncidentsParams) -> dict[str, Any]:
59+
class RemediateSecretIncidentsResult(BaseModel):
60+
"""Result from remediating secret incidents."""
61+
repository_info: dict[str, Any] = Field(description="Information about the repository")
62+
summary: dict[str, Any] | None = Field(default=None, description="Summary of occurrences, files, and secret types")
63+
remediation_steps: list[dict[str, Any]] = Field(default_factory=list, description="Steps for remediating each file")
64+
message: str | None = Field(default=None, description="Message when no occurrences found")
65+
env_example_content: str | None = Field(default=None, description="Suggested .env.example content")
66+
env_example_instructions: list[str] | None = Field(default=None, description="Instructions for .env.example")
67+
git_commands: dict[str, Any] | None = Field(default=None, description="Git commands to fix history")
68+
applied_filters: dict[str, Any] = Field(default_factory=dict, description="Filters applied when querying occurrences")
69+
suggestion: str = Field(default="", description="Suggestions for interpreting results")
70+
71+
72+
class RemediateSecretIncidentsError(BaseModel):
73+
"""Error result from remediating secret incidents."""
74+
error: str = Field(description="Error message")
75+
76+
77+
async def remediate_secret_incidents(params: RemediateSecretIncidentsParams) -> RemediateSecretIncidentsResult | RemediateSecretIncidentsError:
6078
"""
6179
Find and remediate secret incidents in the current repository using EXACT match locations.
6280
@@ -108,10 +126,17 @@ async def remediate_secret_incidents(params: RemediateSecretIncidentsParams) ->
108126
)
109127
occurrences_result = await list_repo_occurrences(occurrences_params)
110128

111-
if "error" in occurrences_result:
112-
return {"error": occurrences_result["error"]}
129+
# Check if list_repo_occurrences returned an error
130+
if isinstance(occurrences_result, dict) and "error" in occurrences_result:
131+
return RemediateSecretIncidentsError(error=occurrences_result["error"])
132+
133+
# Since list_repo_occurrences now returns Pydantic models, handle both old dict and new model formats
134+
if hasattr(occurrences_result, 'model_dump'):
135+
occurrences_dict = occurrences_result.model_dump()
136+
else:
137+
occurrences_dict = occurrences_result
113138

114-
occurrences = occurrences_result.get("occurrences", [])
139+
occurrences = occurrences_dict.get("occurrences", [])
115140

116141
# Filter by assignee if mine=True
117142
if params.mine:
@@ -132,13 +157,13 @@ async def remediate_secret_incidents(params: RemediateSecretIncidentsParams) ->
132157
logger.warning(f"Could not filter by assignee: {str(e)}")
133158

134159
if not occurrences:
135-
return {
136-
"repository_info": {"name": params.repository_name},
137-
"message": "No secret occurrences found for this repository that match the criteria.",
138-
"remediation_steps": [],
139-
"applied_filters": occurrences_result.get("applied_filters", {}),
140-
"suggestion": occurrences_result.get("suggestion", ""),
141-
}
160+
return RemediateSecretIncidentsResult(
161+
repository_info={"name": params.repository_name},
162+
message="No secret occurrences found for this repository that match the criteria.",
163+
remediation_steps=[],
164+
applied_filters=occurrences_dict.get("applied_filters", {}),
165+
suggestion=occurrences_dict.get("suggestion", ""),
166+
)
142167

143168
# Process occurrences for remediation with exact location data
144169
logger.debug(f"Processing {len(occurrences)} occurrences with exact locations for remediation")
@@ -151,11 +176,12 @@ async def remediate_secret_incidents(params: RemediateSecretIncidentsParams) ->
151176
logger.debug(
152177
f"Remediation processing complete, returning result with {len(result.get('remediation_steps', []))} steps"
153178
)
154-
return result
179+
# Convert dict result to Pydantic model
180+
return RemediateSecretIncidentsResult(**result)
155181

156182
except Exception as e:
157183
logger.error(f"Error remediating incidents: {str(e)}")
158-
return {"error": f"Failed to remediate incidents: {str(e)}"}
184+
return RemediateSecretIncidentsError(error=f"Failed to remediate incidents: {str(e)}")
159185

160186

161187
async def _process_occurrences_for_remediation(

0 commit comments

Comments
 (0)