diff --git a/src/specify_cli/__init__.py b/src/specify_cli/__init__.py index 83d2fdf87..2959e636e 100644 --- a/src/specify_cli/__init__.py +++ b/src/specify_cli/__init__.py @@ -32,6 +32,7 @@ import json from pathlib import Path from typing import Optional, Tuple +from enum import Enum import typer import httpx @@ -62,6 +63,27 @@ def _github_auth_headers(cli_token: str | None = None) -> dict: token = _github_token(cli_token) return {"Authorization": f"Bearer {token}"} if token else {} +def _confirm_overwrite(item: Path) -> bool: + prompt = f"[bright_blue]{str(item)}[/bright_blue] exists. Overwrite it with a template?" + keep = OverwriteOptionsEnum.KEEP.value + OVERWRITE_OPTIONS[keep] = OVERWRITE_OPTIONS[keep].format(item.name) + + result = select_with_arrows( + OVERWRITE_OPTIONS, + prompt_text=prompt, + default_key=OverwriteOptionsEnum.OVERWRITE.value, + ) + + if result is OverwriteOptionsEnum.OVERWRITE.value: + return True + return False + +def _overwrite_is_allowed(item: Path, force: bool = False) -> bool: + if not force and str(item) in FILES_TO_CONFIRM_OVERWRITING: + if not _confirm_overwrite(item): + return False + return True + # Constants AI_CHOICES = { "copilot": "GitHub Copilot", @@ -82,6 +104,17 @@ def _github_auth_headers(cli_token: str | None = None) -> dict: # Claude CLI local installation path after migrate-installer CLAUDE_LOCAL_PATH = Path.home() / ".claude" / "local" / "claude" +class OverwriteOptionsEnum(Enum): + OVERWRITE = "Overwrite" + KEEP = "Keep" + +OVERWRITE_OPTIONS = { + OverwriteOptionsEnum.OVERWRITE.value: "Proceed with overwriting", + OverwriteOptionsEnum.KEEP.value: "Keep [blue]{}[/blue] and proceed", +} + +FILES_TO_CONFIRM_OVERWRITING = ["memory/constitution.md"] + # ASCII Art Banner BANNER = """ ███████╗██████╗ ███████╗ ██████╗██╗███████╗██╗ ██╗ @@ -543,7 +576,7 @@ def download_template_from_github(ai_assistant: str, download_dir: Path, *, scri return zip_path, metadata -def download_and_extract_template(project_path: Path, ai_assistant: str, script_type: str, is_current_dir: bool = False, *, verbose: bool = True, tracker: StepTracker | None = None, client: httpx.Client = None, debug: bool = False, github_token: str = None) -> Path: +def download_and_extract_template(project_path: Path, ai_assistant: str, script_type: str, is_current_dir: bool = False, *, verbose: bool = True, tracker: StepTracker | None = None, client: httpx.Client = None, debug: bool = False, github_token: str = None, force: bool = False) -> Path: """Download the latest release and extract it to create a new project. Returns project_path. Uses tracker if provided (with keys: fetch, download, extract, cleanup) """ @@ -631,6 +664,7 @@ def download_and_extract_template(project_path: Path, ai_assistant: str, script_ if sub_item.is_file(): rel_path = sub_item.relative_to(item) dest_file = dest_path / rel_path + if not _overwrite_is_allowed(rel_path, force): continue dest_file.parent.mkdir(parents=True, exist_ok=True) shutil.copy2(sub_item, dest_file) else: @@ -965,7 +999,7 @@ def init( local_ssl_context = ssl_context if verify else False local_client = httpx.Client(verify=local_ssl_context) - download_and_extract_template(project_path, selected_ai, selected_script, here, verbose=False, tracker=tracker, client=local_client, debug=debug, github_token=github_token) + download_and_extract_template(project_path, selected_ai, selected_script, here, verbose=False, tracker=tracker, client=local_client, debug=debug, github_token=github_token, force=force) # Ensure scripts are executable (POSIX) ensure_executable_scripts(project_path, tracker=tracker)