@@ -197,32 +197,39 @@ def generate_commit_message(self) -> str:
197
197
class OllamaClient :
198
198
"""Client for interacting with Ollama API with GPU acceleration."""
199
199
200
- def __init__ (self , host : str = "http://localhost:11434" , model : Optional [str ] = None ):
200
+ def __init__ (self , host : str = "http://localhost:11434" , model : Optional [str ] = None , timeout : int = 10 ):
201
201
"""
202
202
Initialize the Ollama client.
203
203
204
204
Args:
205
205
host: Host for Ollama API
206
206
model: Model to use for Ollama, if None will prompt user to select one
207
+ timeout: Timeout in seconds for HTTP requests
207
208
"""
208
209
self .host = host
209
210
self .headers = {"Content-Type" : "application/json" }
210
- self .available_models = self . _get_available_models ()
211
+ self .timeout = timeout
211
212
212
- if not self .available_models :
213
- logger .warning ("No models found in Ollama. Make sure Ollama is running." )
214
- raise RuntimeError ("No Ollama models available" )
213
+ try :
214
+ self .available_models = self ._get_available_models ()
215
215
216
- if model is None :
217
- self . model = self . _select_model ( )
218
- else :
219
- if model not in self . available_models :
220
- logger . warning ( f"Model { model } not found. Available models: { ', ' . join ( self . available_models ) } " )
216
+ if not self . available_models :
217
+ logger . warning ( "No models found in Ollama. Make sure Ollama is running." )
218
+ raise RuntimeError ( "No Ollama models available" )
219
+
220
+ if model is None :
221
221
self .model = self ._select_model ()
222
222
else :
223
- self .model = model
224
-
225
- logger .info (f"Using Ollama model: { self .model } " )
223
+ if model not in self .available_models :
224
+ logger .warning (f"Model { model } not found. Available models: { ', ' .join (self .available_models )} " )
225
+ self .model = self ._select_model ()
226
+ else :
227
+ self .model = model
228
+
229
+ logger .info (f"Using Ollama model: { self .model } " )
230
+ except Exception as e :
231
+ logger .error (f"Error initializing Ollama client: { str (e )} " )
232
+ raise
226
233
227
234
def _get_host_connection (self ) -> Tuple [str , int ]:
228
235
"""Parse host string and return connection parameters."""
@@ -239,9 +246,25 @@ def _get_host_connection(self) -> Tuple[str, int]:
239
246
host = self .host .split (':' )[0 ] # Handle case if port is included
240
247
port = 11434
241
248
242
- # Test connection before returning
249
+ # Test connection before returning with a short timeout
250
+ socket .setdefaulttimeout (self .timeout )
243
251
socket .getaddrinfo (host , port )
244
252
return host , port
253
+ except socket .gaierror as e :
254
+ logger .warning (f"DNS resolution error for { self .host } : { str (e )} " )
255
+ # Fall back to localhost if specified host fails
256
+ if self .host != "localhost" and self .host != "http://localhost:11434" :
257
+ logger .info ("Trying localhost as fallback" )
258
+ self .host = "http://localhost:11434"
259
+ return "localhost" , 11434
260
+ raise
261
+ except socket .timeout :
262
+ logger .warning (f"Connection timeout to { self .host } " )
263
+ if self .host != "localhost" and self .host != "http://localhost:11434" :
264
+ logger .info ("Trying localhost as fallback" )
265
+ self .host = "http://localhost:11434"
266
+ return "localhost" , 11434
267
+ raise RuntimeError (f"Connection timeout to { self .host } " )
245
268
except Exception as e :
246
269
logger .warning (f"Connection error to { self .host } : { str (e )} " )
247
270
# Fall back to localhost if specified host fails
@@ -255,9 +278,14 @@ def _get_available_models(self) -> List[str]:
255
278
"""Get a list of available models from Ollama."""
256
279
try :
257
280
host , port = self ._get_host_connection ()
258
- conn = http .client .HTTPConnection (host , port )
281
+ conn = http .client .HTTPConnection (host , port , timeout = self . timeout )
259
282
conn .request ("GET" , "/api/tags" )
260
283
response = conn .getresponse ()
284
+
285
+ if response .status != 200 :
286
+ logger .warning (f"Failed to get models: HTTP { response .status } { response .reason } " )
287
+ return self ._get_models_from_cli ()
288
+
261
289
data = json .loads (response .read ().decode ())
262
290
263
291
# Different Ollama API versions might return models differently
@@ -271,6 +299,15 @@ def _get_available_models(self) -> List[str]:
271
299
# Try to run ollama list directly if API doesn't work
272
300
return self ._get_models_from_cli ()
273
301
302
+ except json .JSONDecodeError :
303
+ logger .warning ("Invalid JSON response from Ollama API" )
304
+ return self ._get_models_from_cli ()
305
+ except http .client .HTTPException as e :
306
+ logger .warning (f"HTTP error when connecting to Ollama: { str (e )} " )
307
+ return self ._get_models_from_cli ()
308
+ except socket .timeout :
309
+ logger .warning ("Connection timeout when retrieving models from Ollama API" )
310
+ return self ._get_models_from_cli ()
274
311
except Exception as e :
275
312
logger .warning (f"Failed to get models from Ollama API: { str (e )} " )
276
313
# Try command-line fallback
@@ -285,8 +322,9 @@ def _get_models_from_cli(self) -> List[str]:
285
322
stderr = subprocess .PIPE ,
286
323
text = True
287
324
)
288
- stdout , stderr = process .communicate ()
325
+ stdout , stderr = process .communicate (timeout = self . timeout )
289
326
if process .returncode != 0 :
327
+ logger .warning (f"Ollama CLI failed with error: { stderr } " )
290
328
return []
291
329
292
330
models = []
@@ -297,7 +335,14 @@ def _get_models_from_cli(self) -> List[str]:
297
335
if parts :
298
336
models .append (parts [0 ])
299
337
return models
300
- except Exception :
338
+ except subprocess .TimeoutExpired :
339
+ logger .warning ("Timeout running 'ollama list' command" )
340
+ return []
341
+ except FileNotFoundError :
342
+ logger .warning ("Ollama command not found in PATH" )
343
+ return []
344
+ except Exception as e :
345
+ logger .warning (f"Error getting models from CLI: { str (e )} " )
301
346
return []
302
347
303
348
def _select_model (self ) -> str :
@@ -321,12 +366,16 @@ def _select_model(self) -> str:
321
366
if selection in self .available_models :
322
367
return selection
323
368
print ("Please enter a valid model number or name" )
369
+ except KeyboardInterrupt :
370
+ # If user interrupts, use first model as default
371
+ print ("\n Interrupted, using first available model" )
372
+ return self .available_models [0 ]
324
373
325
374
def generate (self , prompt : str , system_prompt : str = "" , max_tokens : int = 2000 ) -> str :
326
375
"""Generate text using Ollama."""
327
376
try :
328
377
host , port = self ._get_host_connection ()
329
- conn = http .client .HTTPConnection (host , port )
378
+ conn = http .client .HTTPConnection (host , port , timeout = self . timeout )
330
379
331
380
data = {
332
381
"model" : self .model ,
@@ -338,9 +387,23 @@ def generate(self, prompt: str, system_prompt: str = "", max_tokens: int = 2000)
338
387
339
388
conn .request ("POST" , "/api/generate" , json .dumps (data ), self .headers )
340
389
response = conn .getresponse ()
390
+
391
+ if response .status != 200 :
392
+ logger .warning (f"Failed to generate text: HTTP { response .status } { response .reason } " )
393
+ return ""
394
+
341
395
result = json .loads (response .read ().decode ())
342
396
343
397
return result .get ("response" , "" )
398
+ except json .JSONDecodeError :
399
+ logger .warning ("Invalid JSON response from Ollama API during generation" )
400
+ return ""
401
+ except http .client .HTTPException as e :
402
+ logger .warning (f"HTTP error when generating text: { str (e )} " )
403
+ return ""
404
+ except socket .timeout :
405
+ logger .warning ("Timeout when generating text with Ollama" )
406
+ return ""
344
407
except Exception as e :
345
408
logger .warning (f"Failed to generate text with Ollama: { str (e )} " )
346
409
return ""
@@ -350,7 +413,7 @@ class SmartGitCommitWorkflow:
350
413
"""Manages the workflow for analyzing, grouping, and committing changes with AI assistance."""
351
414
352
415
def __init__ (self , repo_path : str = "." , ollama_host : str = "http://localhost:11434" ,
353
- ollama_model : Optional [str ] = None , use_ai : bool = True ):
416
+ ollama_model : Optional [str ] = None , use_ai : bool = True , timeout : int = 10 ):
354
417
"""
355
418
Initialize the workflow.
356
419
@@ -359,16 +422,18 @@ def __init__(self, repo_path: str = ".", ollama_host: str = "http://localhost:11
359
422
ollama_host: Host for Ollama API
360
423
ollama_model: Model to use for Ollama, if None will prompt user to select
361
424
use_ai: Whether to use AI-powered analysis
425
+ timeout: Timeout in seconds for HTTP requests to Ollama
362
426
"""
363
427
self .repo_path = repo_path
364
428
self .changes : List [GitChange ] = []
365
429
self .commit_groups : List [CommitGroup ] = []
366
430
self .use_ai = use_ai
367
431
self .ollama = None
432
+ self .timeout = timeout
368
433
369
434
if use_ai :
370
435
try :
371
- self .ollama = OllamaClient (host = ollama_host , model = ollama_model )
436
+ self .ollama = OllamaClient (host = ollama_host , model = ollama_model , timeout = timeout )
372
437
except Exception as e :
373
438
logger .warning (f"Failed to initialize Ollama client: { str (e )} " )
374
439
logger .info ("Falling back to rule-based analysis" )
@@ -448,6 +513,10 @@ def load_changes(self) -> None:
448
513
status = line [:2 ].strip ()
449
514
filename = line [3 :].strip ()
450
515
516
+ # Remove any leading "backend/" or similar prefix that might come from running in a subdirectory
517
+ if " -> " in filename : # Handle renamed files
518
+ old_path , filename = filename .split (" -> " )
519
+
451
520
# Get diff content for modified files
452
521
content_diff = None
453
522
if status != "??" : # Not for untracked files
@@ -884,15 +953,30 @@ def execute_commits(self, interactive: bool = True) -> None:
884
953
885
954
# Execute the commit
886
955
# Write commit message with UTF-8 encoding explicitly
887
- commit_msg_path = os .path .join (self .repo_path , ".git" , "COMMIT_EDITMSG" )
888
956
try :
957
+ # First make sure .git directory exists
958
+ git_dir = os .path .join (self .repo_path , ".git" )
959
+ if not os .path .isdir (git_dir ):
960
+ # Try to find the git directory
961
+ stdout , _ = self ._run_git_command (["rev-parse" , "--git-dir" ])
962
+ git_dir = stdout .strip ()
963
+ if not os .path .isdir (git_dir ):
964
+ git_dir = os .path .join (self .repo_path , git_dir )
965
+
966
+ # Now create the commit message file
967
+ commit_msg_path = os .path .join (git_dir , "COMMIT_EDITMSG" )
968
+
889
969
with open (commit_msg_path , "w" , encoding = 'utf-8' ) as f :
890
970
f .write (commit_message )
891
971
892
- stdout , code = self ._run_git_command (["commit" , "-F" , os .path .join (".git" , "COMMIT_EDITMSG" )])
972
+ stdout , code = self ._run_git_command (["commit" , "-F" , commit_msg_path ])
973
+ except Exception as e :
974
+ logger .error (f"Failed to create or use commit message file: { str (e )} " )
975
+ # Try direct commit as fallback
976
+ stdout , code = self ._run_git_command (["commit" , "-m" , commit_message ])
893
977
finally :
894
978
# Clean up the temporary commit message file
895
- if os .path .exists (commit_msg_path ):
979
+ if 'commit_msg_path' in locals () and os .path .exists (commit_msg_path ):
896
980
try :
897
981
os .remove (commit_msg_path )
898
982
except OSError as e :
@@ -962,14 +1046,16 @@ def main() -> int:
962
1046
parser .add_argument ("--ollama-host" , help = "Host for Ollama API" , default = "http://localhost:11434" )
963
1047
parser .add_argument ("--ollama-model" , help = "Model to use for Ollama (will prompt if not specified)" )
964
1048
parser .add_argument ("--no-ai" , action = "store_true" , help = "Disable AI-powered analysis" )
1049
+ parser .add_argument ("--timeout" , type = int , help = "Timeout in seconds for HTTP requests" , default = 10 )
965
1050
args = parser .parse_args ()
966
1051
967
1052
try :
968
1053
workflow = SmartGitCommitWorkflow (
969
1054
repo_path = args .repo_path ,
970
1055
ollama_host = args .ollama_host ,
971
1056
ollama_model = args .ollama_model ,
972
- use_ai = not args .no_ai
1057
+ use_ai = not args .no_ai ,
1058
+ timeout = args .timeout
973
1059
)
974
1060
975
1061
workflow .load_changes ()
0 commit comments