Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions truss/base/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@

TRTLLM_MIN_MEMORY_REQUEST_GI = 10
HF_MODELS_API_URL = "https://huggingface.co/api/models"
EMPTY_HF_REPO = "michaelfeil/empty-model"
HF_MAIN_BRANCH = "main"
HF_ACCESS_TOKEN_KEY = "hf_access_token"
TRUSSLESS_MAX_PAYLOAD_SIZE = "64M"
# Alias for TEMPLATES_DIR
Expand Down
5 changes: 5 additions & 0 deletions truss/cli/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from truss.remote.baseten.service import BasetenService
from truss.remote.remote_factory import USER_TRUSSRC_PATH, RemoteFactory
from truss.trt_llm.config_checks import (
empty_checkpoint_repo_trt_llm_builder,
has_no_tags_trt_llm_builder,
memory_updated_for_trt_llm_builder,
uses_trt_llm_builder,
Expand Down Expand Up @@ -597,6 +598,10 @@ def push(
console.print(message_oai, style="red")
sys.exit(1)

message_empty_repo = empty_checkpoint_repo_trt_llm_builder(tr)
if message_empty_repo:
console.print(message_empty_repo, style="yellow")

trt_llm_build_config = tr.spec.config.trt_llm.build
if (
trt_llm_build_config.quantization_type
Expand Down
38 changes: 37 additions & 1 deletion truss/trt_llm/config_checks.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
import requests

from truss.base.constants import (
EMPTY_HF_REPO,
HF_MAIN_BRANCH,
HF_MODELS_API_URL,
OPENAI_COMPATIBLE_TAG,
OPENAI_NON_COMPATIBLE_TAG,
TRTLLM_MIN_MEMORY_REQUEST_GI,
)
from truss.base.trt_llm_config import TrussTRTLLMModel
from truss.base.trt_llm_config import CheckpointSource, TrussTRTLLMModel
from truss.truss_handle.truss_handle import TrussHandle


Expand Down Expand Up @@ -93,6 +95,32 @@ def add_openai_tag(tr: TrussHandle) -> str:
return ("", False)


def empty_checkpoint_repo_trt_llm_builder(tr: TrussHandle) -> str:
    """
    Point the checkpoint repository at a known-empty HF repo when
    cache_internal is used with inference stack v2, since the download
    workflow still runs in that configuration.

    Returns a human-readable notice when the repository was rewritten,
    otherwise an empty string. NOTE: mutates ``tr``'s checkpoint repository
    in place as a side effect.
    """
    # Guard: only applies to trt_llm-builder trusses.
    if not uses_trt_llm_builder(tr):
        return ""
    assert tr.spec.config.trt_llm is not None
    trt_llm_config = tr.spec.config.trt_llm.root

    # Guard: only applies when cache_internal is in play on the v2 stack.
    if not uses_cache_internal(tr) or trt_llm_config.inference_stack != "v2":
        return ""

    build_config = trt_llm_config.build
    # Nothing to rewrite if no checkpoint repository is configured.
    if build_config is None or build_config.checkpoint_repository is None:
        return ""

    # Redirect the download to the empty placeholder repo on the main branch.
    repository = build_config.checkpoint_repository
    repository.source = CheckpointSource.HF
    repository.repo = EMPTY_HF_REPO
    repository.revision = HF_MAIN_BRANCH
    return f"Set checkpoint repository to download empty HF repo ({EMPTY_HF_REPO}) because cache_internal was specified"


def memory_updated_for_trt_llm_builder(tr: TrussHandle) -> bool:
if uses_trt_llm_builder(tr):
if tr.spec.memory_in_bytes < TRTLLM_MIN_MEMORY_REQUEST_GI * 1024**3:
Expand All @@ -113,3 +141,11 @@ def _is_model_public(model_id: str) -> bool:

def uses_trt_llm_builder(tr: TrussHandle) -> bool:
    """Return True when the truss config declares a trt_llm section."""
    trt_llm_section = tr.spec.config.trt_llm
    return trt_llm_section is not None


def uses_cache_internal(tr: TrussHandle) -> bool:
    """Return True when the truss config has at least one cache_internal model.

    Idiom fix: rely on container truthiness instead of ``len(...) > 0``;
    ``bool(None)`` is False, so a missing ``models`` list still yields False.
    """
    cache_internal = tr.spec.config.cache_internal
    return cache_internal is not None and bool(cache_internal.models)
Loading