engine/controllers/models.cc (1 addition, 30 deletions)
@@ -50,36 +50,7 @@ void Models::PullModel(const HttpRequestPtr& req,
     desired_model_name = name_value;
   }
 
-  auto handle_model_input =
-      [&, model_handle]() -> cpp::result<DownloadTask, std::string> {
-    CTL_INF("Handle model input, model handle: " + model_handle);
-    if (string_utils::StartsWith(model_handle, "https")) {
-      return model_service_->HandleDownloadUrlAsync(
-          model_handle, desired_model_id, desired_model_name);
-    } else if (model_handle.find(":") != std::string::npos) {
-      auto model_and_branch = string_utils::SplitBy(model_handle, ":");
-      if (model_and_branch.size() == 3) {
-        auto mh = url_parser::Url{
-            .protocol = "https",
-            .host = kHuggingFaceHost,
-            .pathParams = {
-                model_and_branch[0],
-                model_and_branch[1],
-                "resolve",
-                "main",
-                model_and_branch[2],
-            }}.ToFullPath();
-        return model_service_->HandleDownloadUrlAsync(mh, desired_model_id,
-                                                      desired_model_name);
-      }
-      return model_service_->DownloadModelFromCortexsoAsync(
-          model_and_branch[0], model_and_branch[1], desired_model_id);
-    }
-
-    return cpp::fail("Invalid model handle or not supported!");
-  };
-
-  auto result = handle_model_input();
+  auto result = model_service_->PullModel(model_handle, desired_model_id, desired_model_name);
   if (result.has_error()) {
     Json::Value ret;
     ret["message"] = result.error();
engine/services/model_service.cc (24 additions, 0 deletions)
@@ -911,6 +911,30 @@ cpp::result<bool, std::string> ModelService::GetModelStatus(
   }
 }
 
+cpp::result<DownloadTask, std::string> ModelService::PullModel(
+    const std::string& model_handle,
+    const std::optional<std::string>& desired_model_id,
+    const std::optional<std::string>& desired_model_name) {
+  CTL_INF("Handle model input, model handle: " + model_handle);
+
+  if (string_utils::StartsWith(model_handle, "https"))
+    return HandleDownloadUrlAsync(model_handle, desired_model_id,
+                                  desired_model_name);
+
+  if (model_handle.find(":") == std::string::npos)
+    return cpp::fail("Invalid model handle or not supported!");
+
+  auto model_and_branch = string_utils::SplitBy(model_handle, ":");
+
+  // cortexso format - model:branch
+  // NOTE: desired_model_name is not used by cortexso downloader
+  if (model_and_branch.size() == 2)
+    return DownloadModelFromCortexsoAsync(
+        model_and_branch[0], model_and_branch[1], desired_model_id);
+
+  return cpp::fail("Invalid model handle or not supported!");
+}
+
 cpp::result<ModelPullInfo, std::string> ModelService::GetModelPullInfo(
     const std::string& input) {
   if (input.empty()) {
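As a sanity check on the dispatch rule above, the following is a self-contained sketch that mirrors PullModel's branching using only the standard library. The SplitBy stand-in and the https prefix test approximate the project's string_utils helpers, and the handles in main are placeholders; it classifies handles instead of launching downloads.

#include <iostream>
#include <sstream>
#include <string>
#include <vector>

// Stand-in for string_utils::SplitBy (the real helper takes a string delimiter).
static std::vector<std::string> SplitBy(const std::string& s, char delim) {
  std::vector<std::string> parts;
  std::stringstream ss(s);
  std::string item;
  while (std::getline(ss, item, delim)) parts.push_back(item);
  return parts;
}

// Mirrors the branching in ModelService::PullModel, returning a label
// instead of a DownloadTask.
static std::string Classify(const std::string& handle) {
  if (handle.rfind("https", 0) == 0) return "HandleDownloadUrlAsync";
  if (handle.find(':') == std::string::npos)
    return "fail: invalid or unsupported handle";
  auto parts = SplitBy(handle, ':');
  if (parts.size() == 2) return "DownloadModelFromCortexsoAsync";  // model:branch
  return "fail: invalid or unsupported handle";
}

int main() {
  const std::vector<std::string> handles = {
      "https://huggingface.co/org/repo/resolve/main/model.gguf",  // placeholder URL
      "tinyllama:1b",        // placeholder cortexso handle
      "org:repo:model.gguf"  // 3-part handle: no longer matched here
  };
  for (const auto& h : handles)
    std::cout << h << " -> " << Classify(h) << '\n';
}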
engine/services/model_service.h (13 additions, 8 deletions)
@@ -39,13 +39,14 @@ class ModelService {
               std::shared_ptr<EngineServiceI> engine_svc,
               cortex::TaskQueue& task_queue);
 
+  cpp::result<DownloadTask, std::string> PullModel(
+      const std::string& model_handle,
+      const std::optional<std::string>& desired_model_id,
+      const std::optional<std::string>& desired_model_name);
+
   cpp::result<std::string, std::string> AbortDownloadModel(
       const std::string& task_id);
 
-  cpp::result<DownloadTask, std::string> DownloadModelFromCortexsoAsync(
-      const std::string& name, const std::string& branch = "main",
-      std::optional<std::string> temp_model_id = std::nullopt);
-
   std::optional<config::ModelConfig> GetDownloadedModel(
       const std::string& modelId) const;
 
@@ -67,10 +67,6 @@
   cpp::result<ModelPullInfo, std::string> GetModelPullInfo(
       const std::string& model_handle);
 
-  cpp::result<DownloadTask, std::string> HandleDownloadUrlAsync(
-      const std::string& url, std::optional<std::string> temp_model_id,
-      std::optional<std::string> temp_name);
-
   bool HasModel(const std::string& id) const;
 
   std::optional<hardware::Estimation> GetEstimation(
@@ -89,6 +86,14 @@
   std::string GetEngineByModelId(const std::string& model_id) const;
 
  private:
+  cpp::result<DownloadTask, std::string> DownloadModelFromCortexsoAsync(
+      const std::string& name, const std::string& branch = "main",
+      std::optional<std::string> temp_model_id = std::nullopt);
+
+  cpp::result<DownloadTask, std::string> HandleDownloadUrlAsync(
+      const std::string& url, std::optional<std::string> temp_model_id,
+      std::optional<std::string> temp_name);
+
   cpp::result<std::optional<std::string>, std::string> MayFallbackToCpu(
       const std::string& model_path, int ngl, int ctx_len, int n_batch = 2048,
       int n_ubatch = 2048, const std::string& kv_cache_type = "f16");
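With DownloadModelFromCortexsoAsync and HandleDownloadUrlAsync moved under private:, PullModel becomes the only pull entry point on the public surface. A minimal sketch of what that means at a call site, assuming a model_service_ pointer as in the controller and a placeholder handle:

// Compiles: the unified entry point ("tinyllama:1b" is a placeholder handle).
auto task = model_service_->PullModel("tinyllama:1b", std::nullopt, std::nullopt);

// No longer compile outside ModelService after this header change:
// model_service_->DownloadModelFromCortexsoAsync("tinyllama", "1b");
// model_service_->HandleDownloadUrlAsync(url, std::nullopt, std::nullopt);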