@@ -2,11 +2,53 @@
 import json
 import logging
 import time
+import re
 from typing import List, Optional
 from pydantic import BaseModel, Field
 import google.generativeai as genai
+from google.api_core.exceptions import ResourceExhausted
 import ollama
 
+
+# --- Gemini API Wrapper with Dynamic Retry Logic ---
+def call_gemini_with_retry(model, prompt: str, max_retries: int = 2):
+    """
+    Calls the Gemini API with a dynamic retry mechanism based on the API's feedback.
+    """
+    attempt = 0
+    while attempt < max_retries:
+        try:
+            response = model.generate_content(prompt)
+            return response
+        except ResourceExhausted as e:
+            attempt += 1
+            error_message = str(e)
+
+            # Use regex to find the retry delay in the error message
+            match = re.search(r"retry_delay {\s*seconds: (\d+)\s*}", error_message)
+
+            if match:
+                wait_time = int(match.group(1)) + 1  # Add a 1-second buffer
+                logging.warning(
+                    f"Gemini API quota exceeded. Retrying after {wait_time} seconds (attempt {attempt}/{max_retries})."
+                )
+                time.sleep(wait_time)
+            else:
+                # If no specific delay is found, wait a default time before retrying
+                logging.warning(
+                    f"Gemini API quota exceeded, but no retry_delay found. "
+                    f"Waiting 60 seconds before attempt {attempt}/{max_retries}."
+                )
+                time.sleep(60)  # Fallback wait time
+
+        except Exception as e:
+            logging.error(f"An unexpected error occurred calling Gemini API: {e}")
+            raise e  # Re-raise other exceptions immediately
+
+    logging.error(f"Gemini API call failed after {max_retries} attempts.")
+    raise Exception("Gemini API call failed after multiple retries.")
+
+
 # --- Pydantic Models for Agent Communication ---
 # These models define the "contracts" for data passed between agents.
 
@@ -92,7 +134,7 @@ def agent_step_planner(query: str, model_provider: str, model_name: str) -> List
         cleaned_response = response['message']['content']
     else: # Default to online
         model = genai.GenerativeModel(model_name)
-        response = model.generate_content(prompt)
+        response = call_gemini_with_retry(model, prompt)
         cleaned_response = response.text.strip().lstrip("```json").rstrip("```").strip()
 
     planned_steps_data = json.loads(cleaned_response)
@@ -157,7 +199,7 @@ def agent_element_identifier(steps: List[PlannedStep], model_provider: str, mode
         cleaned_response = response['message']['content']
     else: # Default to online
         model = genai.GenerativeModel(model_name)
-        response = model.generate_content(prompt)
+        response = call_gemini_with_retry(model, prompt)
         cleaned_response = response.text.strip().lstrip("```json").rstrip("```").strip()
 
     locator_data = json.loads(cleaned_response)
@@ -271,7 +313,7 @@ def agent_code_validator(code: str, model_provider: str, model_name: str) -> Val
         cleaned_response = response['message']['content']
     else: # Default to online
         model = genai.GenerativeModel(model_name)
-        response = model.generate_content(prompt)
+        response = call_gemini_with_retry(model, prompt)
         cleaned_response = response.text.strip().lstrip("```json").rstrip("```").strip()
 
     validation_data = json.loads(cleaned_response)
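For reference, a minimal usage sketch of the new wrapper (not part of the diff above) follows. It assumes genai has already been configured with a valid API key; the model name, prompt, and retry count are illustrative placeholders.

# Minimal usage sketch -- assumes genai.configure() has been called with a valid
# API key; the model name, prompt, and max_retries value are illustrative only.
import google.generativeai as genai

genai.configure(api_key="YOUR_API_KEY")  # placeholder, not a real key
model = genai.GenerativeModel("gemini-1.5-flash")  # illustrative model name

# Waits out the retry_delay reported in ResourceExhausted errors, up to 3 attempts.
response = call_gemini_with_retry(model, "Return a JSON list of planned test steps.", max_retries=3)
print(response.text)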