Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 36 additions & 2 deletions src/specify_cli/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
import json
from pathlib import Path
from typing import Optional, Tuple
from enum import Enum

import typer
import httpx
Expand Down Expand Up @@ -62,6 +63,27 @@ def _github_auth_headers(cli_token: str | None = None) -> dict:
token = _github_token(cli_token)
return {"Authorization": f"Bearer {token}"} if token else {}

def _confirm_overwrite(item: Path) -> bool:
prompt = f"[bright_blue]{str(item)}[/bright_blue] exists. Overwrite it with a template?"
keep = OverwriteOptionsEnum.KEEP.value
OVERWRITE_OPTIONS[keep] = OVERWRITE_OPTIONS[keep].format(item.name)

result = select_with_arrows(
OVERWRITE_OPTIONS,
prompt_text=prompt,
default_key=OverwriteOptionsEnum.OVERWRITE.value,
)

if result is OverwriteOptionsEnum.OVERWRITE.value:
return True
return False

def _overwrite_is_allowed(item: Path, force: bool = False) -> bool:
if not force and str(item) in FILES_TO_CONFIRM_OVERWRITING:
if not _confirm_overwrite(item):
return False
return True

# Constants
AI_CHOICES = {
"copilot": "GitHub Copilot",
Expand All @@ -82,6 +104,17 @@ def _github_auth_headers(cli_token: str | None = None) -> dict:
# Claude CLI local installation path after migrate-installer
CLAUDE_LOCAL_PATH = Path.home() / ".claude" / "local" / "claude"

class OverwriteOptionsEnum(Enum):
OVERWRITE = "Overwrite"
KEEP = "Keep"

OVERWRITE_OPTIONS = {
OverwriteOptionsEnum.OVERWRITE.value: "Proceed with overwriting",
OverwriteOptionsEnum.KEEP.value: "Keep [blue]{}[/blue] and proceed",
}

FILES_TO_CONFIRM_OVERWRITING = ["memory/constitution.md"]

# ASCII Art Banner
BANNER = """
███████╗██████╗ ███████╗ ██████╗██╗███████╗██╗ ██╗
Expand Down Expand Up @@ -543,7 +576,7 @@ def download_template_from_github(ai_assistant: str, download_dir: Path, *, scri
return zip_path, metadata


def download_and_extract_template(project_path: Path, ai_assistant: str, script_type: str, is_current_dir: bool = False, *, verbose: bool = True, tracker: StepTracker | None = None, client: httpx.Client = None, debug: bool = False, github_token: str = None) -> Path:
def download_and_extract_template(project_path: Path, ai_assistant: str, script_type: str, is_current_dir: bool = False, *, verbose: bool = True, tracker: StepTracker | None = None, client: httpx.Client = None, debug: bool = False, github_token: str = None, force: bool = False) -> Path:
"""Download the latest release and extract it to create a new project.
Returns project_path. Uses tracker if provided (with keys: fetch, download, extract, cleanup)
"""
Expand Down Expand Up @@ -631,6 +664,7 @@ def download_and_extract_template(project_path: Path, ai_assistant: str, script_
if sub_item.is_file():
rel_path = sub_item.relative_to(item)
dest_file = dest_path / rel_path
if not _overwrite_is_allowed(rel_path, force): continue
dest_file.parent.mkdir(parents=True, exist_ok=True)
shutil.copy2(sub_item, dest_file)
else:
Expand Down Expand Up @@ -965,7 +999,7 @@ def init(
local_ssl_context = ssl_context if verify else False
local_client = httpx.Client(verify=local_ssl_context)

download_and_extract_template(project_path, selected_ai, selected_script, here, verbose=False, tracker=tracker, client=local_client, debug=debug, github_token=github_token)
download_and_extract_template(project_path, selected_ai, selected_script, here, verbose=False, tracker=tracker, client=local_client, debug=debug, github_token=github_token, force=force)

# Ensure scripts are executable (POSIX)
ensure_executable_scripts(project_path, tracker=tracker)
Expand Down