Commit 7a55c14
feat(core): Add timeout to OllamaClient
Adds a configurable timeout to `OllamaClient` to prevent indefinite waits during API calls, so operations complete within a reasonable timeframe instead of blocking the main thread. This makes interaction with the Ollama API more reliable overall.

Affected files:
- M setup.py
- M smart_git_commit.py
- M tests/test_smart_git_commit.py
- + run_tests.py
1 parent: 229079a

File tree

4 files changed (+388 additions, -25 deletions)
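The `timeout` introduced here threads through every layer in the diff below: a new `--timeout` CLI flag, the `SmartGitCommitWorkflow` constructor, and `OllamaClient` itself. A minimal usage sketch, assuming Ollama is running locally and that "llama3" (a hypothetical model name, not a project default) is installed:

    from smart_git_commit import OllamaClient, SmartGitCommitWorkflow

    # Client-level: fail fast if Ollama does not answer within 5 seconds
    # (5 is an illustrative value; the committed default is 10).
    client = OllamaClient(host="http://localhost:11434", model="llama3", timeout=5)

    # Workflow-level: the timeout is forwarded to the internal OllamaClient.
    workflow = SmartGitCommitWorkflow(repo_path=".", use_ai=True, timeout=5)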

run_tests.py

Lines changed: 70 additions & 0 deletions
@@ -0,0 +1,70 @@
+#!/usr/bin/env python3
+"""
+Helper script to run tests with coverage reporting for the Smart Git Commit tool.
+"""
+
+import os
+import sys
+import subprocess
+import argparse
+
+
+def run_tests(verbose=False, coverage=False):
+    """Run the test suite with optional coverage reporting."""
+    print("Running Smart Git Commit tests...")
+
+    cmd = ["python", "-m", "unittest", "discover", "-s", "tests"]
+
+    if verbose:
+        cmd.append("-v")
+
+    if coverage:
+        try:
+            # Check if coverage is installed
+            import coverage
+            print("Running tests with coverage reporting...")
+
+            # Create a coverage object
+            cov = coverage.Coverage()
+            cov.start()
+
+            # Run the tests
+            subprocess.run(cmd, check=True)
+
+            # Stop coverage and generate report
+            cov.stop()
+            cov.save()
+
+            print("\nCoverage report:")
+            cov.report()
+
+            # Generate HTML report
+            report_dir = os.path.join(os.path.dirname(__file__), "coverage_html")
+            cov.html_report(directory=report_dir)
+            print(f"\nDetailed HTML coverage report generated in: {report_dir}")
+
+            return True
+
+        except ImportError:
+            print("Warning: coverage package not installed. Running tests without coverage.")
+            coverage = False
+
+    if not coverage:
+        # Run tests without coverage
+        result = subprocess.run(cmd)
+        return result.returncode == 0
+
+
+def parse_args():
+    """Parse command line arguments."""
+    parser = argparse.ArgumentParser(description="Run tests for Smart Git Commit")
+    parser.add_argument("-v", "--verbose", action="store_true", help="Run tests in verbose mode")
+    parser.add_argument("-c", "--coverage", action="store_true", help="Generate coverage report")
+
+    return parser.parse_args()
+
+
+if __name__ == "__main__":
+    args = parse_args()
+    success = run_tests(verbose=args.verbose, coverage=args.coverage)
+    sys.exit(0 if success else 1)
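The helper can also be driven programmatically; a small sketch, assuming it is imported from the repository root so `run_tests.py` is on the path:

    from run_tests import run_tests

    # Equivalent to `python run_tests.py --verbose`; returns True when the
    # unittest run exits with code 0.
    ok = run_tests(verbose=True)
    print("tests passed" if ok else "tests failed")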

setup.py

Lines changed: 1 addition & 1 deletion
@@ -13,7 +13,7 @@
 
 setup(
     name="smart-git-commit",
-    version="0.1.4",
+    version="0.1.5",
     description="AI-powered Git commit workflow tool",
     long_description=long_description,
     long_description_content_type="text/markdown",

smart_git_commit.py

Lines changed: 110 additions & 24 deletions
@@ -197,32 +197,39 @@ def generate_commit_message(self) -> str:
 class OllamaClient:
     """Client for interacting with Ollama API with GPU acceleration."""
 
-    def __init__(self, host: str = "http://localhost:11434", model: Optional[str] = None):
+    def __init__(self, host: str = "http://localhost:11434", model: Optional[str] = None, timeout: int = 10):
         """
         Initialize the Ollama client.
 
         Args:
             host: Host for Ollama API
             model: Model to use for Ollama, if None will prompt user to select one
+            timeout: Timeout in seconds for HTTP requests
         """
         self.host = host
         self.headers = {"Content-Type": "application/json"}
-        self.available_models = self._get_available_models()
+        self.timeout = timeout
 
-        if not self.available_models:
-            logger.warning("No models found in Ollama. Make sure Ollama is running.")
-            raise RuntimeError("No Ollama models available")
+        try:
+            self.available_models = self._get_available_models()
 
-        if model is None:
-            self.model = self._select_model()
-        else:
-            if model not in self.available_models:
-                logger.warning(f"Model {model} not found. Available models: {', '.join(self.available_models)}")
+            if not self.available_models:
+                logger.warning("No models found in Ollama. Make sure Ollama is running.")
+                raise RuntimeError("No Ollama models available")
+
+            if model is None:
                 self.model = self._select_model()
             else:
-                self.model = model
-
-        logger.info(f"Using Ollama model: {self.model}")
+                if model not in self.available_models:
+                    logger.warning(f"Model {model} not found. Available models: {', '.join(self.available_models)}")
+                    self.model = self._select_model()
+                else:
+                    self.model = model
+
+            logger.info(f"Using Ollama model: {self.model}")
+        except Exception as e:
+            logger.error(f"Error initializing Ollama client: {str(e)}")
+            raise
 
     def _get_host_connection(self) -> Tuple[str, int]:
         """Parse host string and return connection parameters."""
@@ -239,9 +246,25 @@ def _get_host_connection(self) -> Tuple[str, int]:
                 host = self.host.split(':')[0]  # Handle case if port is included
                 port = 11434
 
-            # Test connection before returning
+            # Test connection before returning with a short timeout
+            socket.setdefaulttimeout(self.timeout)
             socket.getaddrinfo(host, port)
             return host, port
+        except socket.gaierror as e:
+            logger.warning(f"DNS resolution error for {self.host}: {str(e)}")
+            # Fall back to localhost if specified host fails
+            if self.host != "localhost" and self.host != "http://localhost:11434":
+                logger.info("Trying localhost as fallback")
+                self.host = "http://localhost:11434"
+                return "localhost", 11434
+            raise
+        except socket.timeout:
+            logger.warning(f"Connection timeout to {self.host}")
+            if self.host != "localhost" and self.host != "http://localhost:11434":
+                logger.info("Trying localhost as fallback")
+                self.host = "http://localhost:11434"
+                return "localhost", 11434
+            raise RuntimeError(f"Connection timeout to {self.host}")
         except Exception as e:
             logger.warning(f"Connection error to {self.host}: {str(e)}")
             # Fall back to localhost if specified host fails
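A note on `socket.setdefaulttimeout` as used above: it sets a process-wide default that every socket created afterwards inherits, so its effect is broader than this one connection test. A standalone sketch:

    import socket

    # Applies to every socket created after this call, anywhere in the process.
    socket.setdefaulttimeout(10)

    s = socket.socket()
    print(s.gettimeout())  # 10.0 -- inherited from the process-wide default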
@@ -255,9 +278,14 @@ def _get_available_models(self) -> List[str]:
         """Get a list of available models from Ollama."""
         try:
             host, port = self._get_host_connection()
-            conn = http.client.HTTPConnection(host, port)
+            conn = http.client.HTTPConnection(host, port, timeout=self.timeout)
             conn.request("GET", "/api/tags")
             response = conn.getresponse()
+
+            if response.status != 200:
+                logger.warning(f"Failed to get models: HTTP {response.status} {response.reason}")
+                return self._get_models_from_cli()
+
             data = json.loads(response.read().decode())
 
             # Different Ollama API versions might return models differently

@@ -271,6 +299,15 @@ def _get_available_models(self) -> List[str]:
             # Try to run ollama list directly if API doesn't work
             return self._get_models_from_cli()
 
+        except json.JSONDecodeError:
+            logger.warning("Invalid JSON response from Ollama API")
+            return self._get_models_from_cli()
+        except http.client.HTTPException as e:
+            logger.warning(f"HTTP error when connecting to Ollama: {str(e)}")
+            return self._get_models_from_cli()
+        except socket.timeout:
+            logger.warning("Connection timeout when retrieving models from Ollama API")
+            return self._get_models_from_cli()
         except Exception as e:
             logger.warning(f"Failed to get models from Ollama API: {str(e)}")
             # Try command-line fallback

@@ -285,8 +322,9 @@ def _get_models_from_cli(self) -> List[str]:
                 stderr=subprocess.PIPE,
                 text=True
             )
-            stdout, stderr = process.communicate()
+            stdout, stderr = process.communicate(timeout=self.timeout)
             if process.returncode != 0:
+                logger.warning(f"Ollama CLI failed with error: {stderr}")
                 return []
 
             models = []

@@ -297,7 +335,14 @@ def _get_models_from_cli(self) -> List[str]:
                 if parts:
                     models.append(parts[0])
             return models
-        except Exception:
+        except subprocess.TimeoutExpired:
+            logger.warning("Timeout running 'ollama list' command")
+            return []
+        except FileNotFoundError:
+            logger.warning("Ollama command not found in PATH")
+            return []
+        except Exception as e:
+            logger.warning(f"Error getting models from CLI: {str(e)}")
             return []
 
     def _select_model(self) -> str:
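For context on the `communicate(timeout=...)` call above: `subprocess.TimeoutExpired` is raised without killing the child, so the pattern documented in the standard library also reaps the stalled process. A standalone sketch of that idiom (not verbatim project code):

    import subprocess

    process = subprocess.Popen(
        ["ollama", "list"],
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        text=True,
    )
    try:
        stdout, stderr = process.communicate(timeout=10)
    except subprocess.TimeoutExpired:
        process.kill()                          # the child keeps running unless killed
        stdout, stderr = process.communicate()  # collect any remaining output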
@@ -321,12 +366,16 @@ def _select_model(self) -> str:
                 if selection in self.available_models:
                     return selection
                 print("Please enter a valid model number or name")
+            except KeyboardInterrupt:
+                # If user interrupts, use first model as default
+                print("\nInterrupted, using first available model")
+                return self.available_models[0]
 
     def generate(self, prompt: str, system_prompt: str = "", max_tokens: int = 2000) -> str:
         """Generate text using Ollama."""
         try:
             host, port = self._get_host_connection()
-            conn = http.client.HTTPConnection(host, port)
+            conn = http.client.HTTPConnection(host, port, timeout=self.timeout)
 
             data = {
                 "model": self.model,

@@ -338,9 +387,23 @@ def generate(self, prompt: str, system_prompt: str = "", max_tokens: int = 2000)
 
             conn.request("POST", "/api/generate", json.dumps(data), self.headers)
             response = conn.getresponse()
+
+            if response.status != 200:
+                logger.warning(f"Failed to generate text: HTTP {response.status} {response.reason}")
+                return ""
+
             result = json.loads(response.read().decode())
 
             return result.get("response", "")
+        except json.JSONDecodeError:
+            logger.warning("Invalid JSON response from Ollama API during generation")
+            return ""
+        except http.client.HTTPException as e:
+            logger.warning(f"HTTP error when generating text: {str(e)}")
+            return ""
+        except socket.timeout:
+            logger.warning("Timeout when generating text with Ollama")
+            return ""
         except Exception as e:
             logger.warning(f"Failed to generate text with Ollama: {str(e)}")
             return ""
@@ -350,7 +413,7 @@ class SmartGitCommitWorkflow:
     """Manages the workflow for analyzing, grouping, and committing changes with AI assistance."""
 
     def __init__(self, repo_path: str = ".", ollama_host: str = "http://localhost:11434",
-                 ollama_model: Optional[str] = None, use_ai: bool = True):
+                 ollama_model: Optional[str] = None, use_ai: bool = True, timeout: int = 10):
         """
         Initialize the workflow.
 

@@ -359,16 +422,18 @@ def __init__(self, repo_path: str = ".", ollama_host: str = "http://localhost:11
             ollama_host: Host for Ollama API
             ollama_model: Model to use for Ollama, if None will prompt user to select
             use_ai: Whether to use AI-powered analysis
+            timeout: Timeout in seconds for HTTP requests to Ollama
         """
         self.repo_path = repo_path
        self.changes: List[GitChange] = []
         self.commit_groups: List[CommitGroup] = []
         self.use_ai = use_ai
         self.ollama = None
+        self.timeout = timeout
 
         if use_ai:
             try:
-                self.ollama = OllamaClient(host=ollama_host, model=ollama_model)
+                self.ollama = OllamaClient(host=ollama_host, model=ollama_model, timeout=timeout)
             except Exception as e:
                 logger.warning(f"Failed to initialize Ollama client: {str(e)}")
                 logger.info("Falling back to rule-based analysis")

@@ -448,6 +513,10 @@ def load_changes(self) -> None:
                 status = line[:2].strip()
                 filename = line[3:].strip()
 
+                # Handle files renamed by git, reported as "old -> new"
+                if " -> " in filename:
+                    old_path, filename = filename.split(" -> ")
+
                 # Get diff content for modified files
                 content_diff = None
                 if status != "??":  # Not for untracked files
@@ -884,15 +953,30 @@ def execute_commits(self, interactive: bool = True) -> None:
 
             # Execute the commit
             # Write commit message with UTF-8 encoding explicitly
-            commit_msg_path = os.path.join(self.repo_path, ".git", "COMMIT_EDITMSG")
             try:
+                # First make sure .git directory exists
+                git_dir = os.path.join(self.repo_path, ".git")
+                if not os.path.isdir(git_dir):
+                    # Try to find the git directory
+                    stdout, _ = self._run_git_command(["rev-parse", "--git-dir"])
+                    git_dir = stdout.strip()
+                    if not os.path.isdir(git_dir):
+                        git_dir = os.path.join(self.repo_path, git_dir)
+
+                # Now create the commit message file
+                commit_msg_path = os.path.join(git_dir, "COMMIT_EDITMSG")
+
                 with open(commit_msg_path, "w", encoding='utf-8') as f:
                     f.write(commit_message)
 
-                stdout, code = self._run_git_command(["commit", "-F", os.path.join(".git", "COMMIT_EDITMSG")])
+                stdout, code = self._run_git_command(["commit", "-F", commit_msg_path])
+            except Exception as e:
+                logger.error(f"Failed to create or use commit message file: {str(e)}")
+                # Try direct commit as fallback
+                stdout, code = self._run_git_command(["commit", "-m", commit_message])
             finally:
                 # Clean up the temporary commit message file
-                if os.path.exists(commit_msg_path):
+                if 'commit_msg_path' in locals() and os.path.exists(commit_msg_path):
                     try:
                         os.remove(commit_msg_path)
                     except OSError as e:
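The `rev-parse --git-dir` fallback above exists because git may print the git directory as a path relative to the working directory (plain ".git", or a longer relative path from a subdirectory or worktree), hence the second `isdir` check and the join. A standalone sketch of the same resolution logic:

    import os
    import subprocess

    repo_path = "."  # illustrative
    out = subprocess.run(
        ["git", "-C", repo_path, "rev-parse", "--git-dir"],
        capture_output=True, text=True, check=True,
    ).stdout.strip()

    # Anchor a relative result to the repository path, as the hunk above does.
    git_dir = out if os.path.isabs(out) else os.path.join(repo_path, out)
    print(git_dir)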
@@ -962,14 +1046,16 @@ def main() -> int:
     parser.add_argument("--ollama-host", help="Host for Ollama API", default="http://localhost:11434")
     parser.add_argument("--ollama-model", help="Model to use for Ollama (will prompt if not specified)")
     parser.add_argument("--no-ai", action="store_true", help="Disable AI-powered analysis")
+    parser.add_argument("--timeout", type=int, help="Timeout in seconds for HTTP requests", default=10)
     args = parser.parse_args()
 
     try:
         workflow = SmartGitCommitWorkflow(
             repo_path=args.repo_path,
             ollama_host=args.ollama_host,
             ollama_model=args.ollama_model,
-            use_ai=not args.no_ai
+            use_ai=not args.no_ai,
+            timeout=args.timeout
         )
 
         workflow.load_changes()
