diff --git a/doc/source/_static/model_update.png b/doc/source/_static/model_update.png new file mode 100644 index 0000000000..24b5104303 Binary files /dev/null and b/doc/source/_static/model_update.png differ diff --git a/doc/source/locale/zh_CN/LC_MESSAGES/models/model_update.po b/doc/source/locale/zh_CN/LC_MESSAGES/models/model_update.po new file mode 100644 index 0000000000..c3976fca83 --- /dev/null +++ b/doc/source/locale/zh_CN/LC_MESSAGES/models/model_update.po @@ -0,0 +1,413 @@ +# SOME DESCRIPTIVE TITLE. +# Copyright (C) 2025, Xorbits Inc. +# This file is distributed under the same license as the Xinference package. +# FIRST AUTHOR , 2025. +# +#, fuzzy +msgid "" +msgstr "" +"Project-Id-Version: Xinference \n" +"Report-Msgid-Bugs-To: \n" +"POT-Creation-Date: 2025-10-31 18:37+0800\n" +"PO-Revision-Date: YEAR-MO-DA HO:MI+ZONE\n" +"Last-Translator: FULL NAME \n" +"Language: zh_CN\n" +"Language-Team: zh_CN \n" +"Plural-Forms: nplurals=1; plural=0;\n" +"MIME-Version: 1.0\n" +"Content-Type: text/plain; charset=utf-8\n" +"Content-Transfer-Encoding: 8bit\n" +"Generated-By: Babel 2.17.0\n" + +#: ../../source/models/model_update.rst:5 +msgid "Model Update" +msgstr "模型更新" + +#: ../../source/models/model_update.rst:8 +msgid "" +"This section briefly introduces two common operations on the \"Launch " +"Model\" page: updating model lists and adding models. They correspond to " +"the \"Type Selection + Update\" and \"Add Model\" buttons at the top of " +"the page, facilitating quick refresh of models of a certain type or " +"adding new models locally." +msgstr "" +"本节简要介绍了“启动模型”页面上的两个常见操作:更新模型列表和添加模型。" +"它们分别对应页面顶部的“类型选择 + 更新”和“添加模型”按钮,用于快速刷新" +"某一类型的模型或在本地添加新模型。" + +#: ../../source/models/model_update.rst:15 +msgid "Update Models (Launch Model Page)" +msgstr "更新模型(启动模型页面)" + +#: ../../source/models/model_update.rst:17 +msgid "" +"Operation Location: \"Type Selection\" dropdown and \"Update\" button at " +"the top right of the page." +msgstr "操作位置:页面右上角的“类型选择”下拉框和“更新”按钮。" + +#: ../../source/models/model_update.rst:18 +#: ../../source/models/model_update.rst:26 +msgid "Usage:" +msgstr "使用方法:" + +#: ../../source/models/model_update.rst:19 +msgid "" +"Select a model type from the dropdown (such as llm, embedding, rerank, " +"image, audio, video)." +msgstr "" +"从下拉框中选择模型类型(例如 llm、embedding、rerank、image、audio、video" +")。" + +#: ../../source/models/model_update.rst:20 +msgid "" +"Click the \"Update\" button, the page will send an update request to the " +"backend, then automatically jump to the corresponding Tab and refresh the" +" model list of that type." +msgstr "" +"点击“更新”按钮,页面将向后端发送更新请求,然后自动跳转到对应的标签页并" +"刷新该类型的模型列表。" + +#: ../../source/models/model_update.rst:23 +msgid "Add Model (Launch Model Page)" +msgstr "添加模型(启动模型页面)" + +#: ../../source/models/model_update.rst:25 +msgid "Operation Location: \"Add Model\" button at the top right of the page." +msgstr "操作位置:页面右上角的“添加模型”按钮。" + +#: ../../source/models/model_update.rst:27 +msgid "Click \"Add Model\" to open the add dialog." +msgstr "点击“添加模型”以打开添加对话框。" + +#: ../../source/models/model_update.rst:28 +msgid "Complete the model addition process in the dialog." +msgstr "在对话框中完成模型添加流程。" + +#: ../../source/models/model_update.rst:29 +msgid "" +"After successful addition, the page will jump to the corresponding type " +"Tab to immediately view the latest model list." 
+msgstr "添加成功后,页面将跳转到对应类型的标签页,并立即显示最新的模型列表。" + +#: ../../source/models/model_update.rst:32 +msgid "Xinference Models Hub User Guide" +msgstr "Xinference 模型中心用户指南" + +#: ../../source/models/model_update.rst:35 +msgid "Overview" +msgstr "概述" + +#: ../../source/models/model_update.rst:37 +msgid "" +"Xinference Models Hub is a full-stack platform for managing and sharing " +"models. It provides a comprehensive solution for model registration, " +"browsing, review workflows, and collaborative model management." +msgstr "" +"Xinference 模型中心是一个用于管理和共享模型的全栈平台,提供了模型注册、" +"浏览、审核流程和协作管理的完整解决方案。" + +#: ../../source/models/model_update.rst:40 +msgid "You can visit the Models Hub at: https://model.xinference.io" +msgstr "您可以通过以下网址访问模型中心:https://model.xinference.io" + +#: ../../source/models/model_update.rst:43 +msgid "Quick Start" +msgstr "快速开始" + +#: ../../source/models/model_update.rst:46 +msgid "User Registration and Login" +msgstr "用户注册与登录" + +#: ../../source/models/model_update.rst:48 +msgid "**Registration**" +msgstr "**注册**" + +#: ../../source/models/model_update.rst:50 +msgid "Open the website registration page" +msgstr "打开网站注册页面" + +#: ../../source/models/model_update.rst:51 +msgid "Fill in the necessary information and submit" +msgstr "填写必要的信息并提交" + +#: ../../source/models/model_update.rst:53 +msgid "**Login**" +msgstr "**登录**" + +#: ../../source/models/model_update.rst:55 +msgid "Open the website login page" +msgstr "打开网站登录页面" + +#: ../../source/models/model_update.rst:56 +msgid "After successful login, you will be redirected to the model list page" +msgstr "登录成功后,将跳转到模型列表页面" + +#: ../../source/models/model_update.rst:58 +msgid "**Password Reset**" +msgstr "**重置密码**" + +#: ../../source/models/model_update.rst:60 +msgid "Click the \"Forgot Password\" link on the login page" +msgstr "点击登录页面的“忘记密码”链接" + +#: ../../source/models/model_update.rst:61 +msgid "Follow the instructions in the email to reset your password" +msgstr "按照邮件中的说明重置密码" + +#: ../../source/models/model_update.rst:63 +msgid "**Logout**" +msgstr "**退出登录**" + +#: ../../source/models/model_update.rst:65 +msgid "Click the avatar in the top right corner of the page" +msgstr "点击页面右上角的头像" + +#: ../../source/models/model_update.rst:66 +msgid "Select \"Logout\" from the dropdown menu" +msgstr "在下拉菜单中选择“退出登录”" + +#: ../../source/models/model_update.rst:69 +msgid "Core Features" +msgstr "核心功能" + +#: ../../source/models/model_update.rst:72 +msgid "Browse Models" +msgstr "浏览模型" + +#: ../../source/models/model_update.rst:74 +msgid "**Model List (Homepage)**" +msgstr "**模型列表(首页)**" + +#: ../../source/models/model_update.rst:76 +msgid "**Function:** Browse available models, click any model to view details" +msgstr "**功能:** 浏览可用模型,点击任意模型查看详情" + +#: ../../source/models/model_update.rst:77 +msgid "**Location:** \"Models\" menu in the website navigation bar" +msgstr "**位置:** 网站导航栏中的“模型”菜单" + +#: ../../source/models/model_update.rst:80 +msgid "Some advanced models are only visible to authorized users." 
+msgstr "部分高级模型仅对授权用户可见。" + +#: ../../source/models/model_update.rst:82 +msgid "**Model Details and Documentation**" +msgstr "**模型详情与文档**" + +#: ../../source/models/model_update.rst:84 +msgid "**Function:** View detailed information about models" +msgstr "**功能:** 查看模型的详细信息" + +#: ../../source/models/model_update.rst:85 +msgid "" +"**Default Display:** \"README\" tab - view model description, usage " +"instructions, and notes" +msgstr "**默认显示:** “README” 标签页 - 查看模型描述、使用说明和注意事项" + +#: ../../source/models/model_update.rst:86 +msgid "**Other Tabs:** Settings (authorized users), review status" +msgstr "**其他标签页:** 设置(仅授权用户)、审核状态" + +#: ../../source/models/model_update.rst:89 +msgid "User Center" +msgstr "用户中心" + +#: ../../source/models/model_update.rst:91 +msgid "**Function:** View and manage personal information" +msgstr "**功能:** 查看并管理个人信息" + +#: ../../source/models/model_update.rst:92 +msgid "" +"**Location:** Click the avatar in the top right corner, select \"User " +"Center\"" +msgstr "**位置:** 点击右上角头像,选择“用户中心”" + +#: ../../source/models/model_update.rst:93 +msgid "**Content:** Personal profile settings" +msgstr "**内容:** 个人资料设置" + +#: ../../source/models/model_update.rst:96 +msgid "Model Management (Authorized Users)" +msgstr "模型管理(授权用户)" + +#: ../../source/models/model_update.rst:99 +msgid "Model Registration" +msgstr "模型注册" + +#: ../../source/models/model_update.rst:101 +msgid "**Function:** Submit new models to the platform" +msgstr "**功能:** 向平台提交新模型" + +#: ../../source/models/model_update.rst:102 +msgid "" +"**Location:** Click the avatar in the top right corner, select \"Model " +"Registration\"" +msgstr "**位置:** 点击右上角头像,选择“模型注册”" + +#: ../../source/models/model_update.rst:103 +#: ../../source/models/model_update.rst:126 +msgid "**Required Permissions:**" +msgstr "**所需权限:**" + +#: ../../source/models/model_update.rst:105 +#: ../../source/models/model_update.rst:128 +msgid "**Private Models:** Model registration permission" +msgstr "**私有模型:** 需要模型注册权限" + +#: ../../source/models/model_update.rst:106 +msgid "**Public Models:** Public model registration permission" +msgstr "**公共模型:** 需要公共模型注册权限" + +#: ../../source/models/model_update.rst:107 +msgid "**Enterprise Models:** Enterprise model registration permission" +msgstr "**企业模型:** 需要企业模型注册权限" + +#: ../../source/models/model_update.rst:109 +#: ../../source/models/model_update.rst:161 +msgid "**Operation Process:**" +msgstr "**操作流程:**" + +#: ../../source/models/model_update.rst:111 +msgid "Fill in basic model information" +msgstr "填写模型的基本信息" + +#: ../../source/models/model_update.rst:112 +msgid "" +"Fill in Readme (can be automatically obtained by clicking the Get Readme " +"button)" +msgstr "填写 Readme(可通过点击“获取 Readme”按钮自动获取)" + +#: ../../source/models/model_update.rst:113 +msgid "Submit (to register public models, enable the Public Model parameter)" +msgstr "提交(如需注册公共模型,请启用“公共模型”参数)" + +#: ../../source/models/model_update.rst:115 +#: ../../source/models/model_update.rst:144 +msgid "**Notes:**" +msgstr "**注意事项:**" + +#: ../../source/models/model_update.rst:117 +msgid "Regular users can only register private models" +msgstr "普通用户仅可注册私有模型" + +#: ../../source/models/model_update.rst:118 +msgid "" +"Public model registration requires review, and can be used publicly after" +" approval (no review needed if you have public model registration " +"permission)" +msgstr "" +"公共模型注册需要审核,通过后才能公开使用(若具有公共模型注册权限则无需" +"审核)。" + +#: ../../source/models/model_update.rst:119 +msgid "" +"Enterprise model registration requires enabling the Public Model " 
+"parameter first" +msgstr "企业模型注册需先启用“公共模型”参数。" + +#: ../../source/models/model_update.rst:122 +msgid "My Models" +msgstr "我的模型" + +#: ../../source/models/model_update.rst:124 +msgid "" +"**Function:** View models associated with your account (models you " +"registered)" +msgstr "**功能:** 查看与你账户相关的模型(即你注册的模型)" + +#: ../../source/models/model_update.rst:125 +msgid "" +"**Location:** Click the avatar in the top right corner, select \"My " +"Models\"" +msgstr "**位置:** 点击右上角头像,选择“我的模型”" + +#: ../../source/models/model_update.rst:129 +msgid "**Public Models:** Model registration permission" +msgstr "**公共模型:** 需要模型注册权限" + +#: ../../source/models/model_update.rst:130 +msgid "**Enterprise Models:** Model registration permission" +msgstr "**企业模型:** 需要模型注册权限" + +#: ../../source/models/model_update.rst:133 +msgid "Model Maintenance" +msgstr "模型维护" + +#: ../../source/models/model_update.rst:135 +msgid "**Function:** Modify and manage existing models" +msgstr "**功能:** 修改和管理已有模型" + +#: ../../source/models/model_update.rst:136 +msgid "**Location:** Click the \"Settings\" icon on the model details page" +msgstr "**位置:** 在模型详情页点击“设置”图标" + +#: ../../source/models/model_update.rst:138 +msgid "**Permission Requirements:**" +msgstr "**权限要求:**" + +#: ../../source/models/model_update.rst:140 +msgid "" +"**Private Models:** Model ownership or any public model management " +"permission" +msgstr "**私有模型:** 需拥有模型所有权或公共模型管理权限" + +#: ../../source/models/model_update.rst:141 +msgid "" +"**Advanced Models:** Advanced model update, delete, or expiration " +"permission" +msgstr "**高级模型:** 需具备高级模型更新、删除或过期管理权限" + +#: ../../source/models/model_update.rst:142 +msgid "**Public Models:** Public model update, delete, or expiration permission" +msgstr "**公共模型:** 需具备公共模型更新、删除或过期管理权限" + +#: ../../source/models/model_update.rst:146 +msgid "" +"Updating JSON or modifying expiration attributes of public models will " +"automatically create a PR to the xorbitsai/inference repository" +msgstr "" +"更新公共模型的 JSON 或修改过期属性时,将自动创建一个 PR 提交到 xorbitsai/" +"inference 仓库。" + +#: ../../source/models/model_update.rst:149 +msgid "Review Workflow" +msgstr "审核流程" + +#: ../../source/models/model_update.rst:151 +msgid "**For Model Submitters:**" +msgstr "**针对模型提交者:**" + +#: ../../source/models/model_update.rst:153 +msgid "Submit models for review" +msgstr "提交模型以进行审核" + +#: ../../source/models/model_update.rst:154 +msgid "Check review status on the model details page" +msgstr "在模型详情页查看审核状态" + +#: ../../source/models/model_update.rst:155 +msgid "Make modifications based on reviewer feedback if needed" +msgstr "如有需要,根据审核者的反馈进行修改" + +#: ../../source/models/model_update.rst:157 +msgid "**For Reviewers:**" +msgstr "**针对审核者:**" + +#: ../../source/models/model_update.rst:159 +msgid "" +"**Required Permissions:** Model review list permission, model review " +"permission" +msgstr "**所需权限:** 模型审核列表权限、模型审核权限" + +#: ../../source/models/model_update.rst:163 +msgid "Enter the review queue page" +msgstr "进入审核队列页面" + +#: ../../source/models/model_update.rst:164 +msgid "Evaluate model quality and compliance" +msgstr "评估模型质量与合规性" + +#: ../../source/models/model_update.rst:165 +msgid "Approve or reject and provide feedback" +msgstr "批准或拒绝模型,并提供反馈意见" + diff --git a/doc/source/models/index.rst b/doc/source/models/index.rst index 03aaaf70dc..dd780d0e83 100644 --- a/doc/source/models/index.rst +++ b/doc/source/models/index.rst @@ -250,6 +250,7 @@ Model Usage model_abilities/index builtin/index custom + model_update sources/sources virtualenv lora diff --git 
a/doc/source/models/model_update.rst b/doc/source/models/model_update.rst new file mode 100644 index 0000000000..953ba4c5aa --- /dev/null +++ b/doc/source/models/model_update.rst @@ -0,0 +1,165 @@ +.. _model_update: + +============ +Model Update +============ +.. versionadded:: v1.13.0 + +This section briefly introduces two common operations on the "Launch Model" page: updating model lists and adding models. They correspond to the "Type Selection + Update" and "Add Model" buttons at the top of the page, facilitating quick refresh of models of a certain type or adding new models locally. + +.. raw:: html + + <img src="../_static/model_update.png" alt="model update interface"> + +Update Models (Launch Model Page) +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +- Operation Location: "Type Selection" dropdown and "Update" button at the top right of the page. +- Usage: +1. Select a model type from the dropdown (such as llm, embedding, rerank, image, audio, video). +2. Click the "Update" button, the page will send an update request to the backend, then automatically jump to the corresponding Tab and refresh the model list of that type. + +Add Model (Launch Model Page) +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +- Operation Location: "Add Model" button at the top right of the page. +- Usage: +1. Click "Add Model" to open the add dialog. +2. Complete the model addition process in the dialog. +3. After successful addition, the page will jump to the corresponding type Tab to immediately view the latest model list. + +Xinference Models Hub User Guide +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +Overview +-------- + +Xinference Models Hub is a full-stack platform for managing and sharing models. +It provides a comprehensive solution for model registration, browsing, review workflows, and collaborative model management. + +You can visit the Models Hub at: https://model.xinference.io + +Quick Start +----------- + +User Registration and Login +^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +**Registration** + +1. Open the website registration page +2. Fill in the necessary information and submit + +**Login** + +1. Open the website login page +2. After successful login, you will be redirected to the model list page + +**Password Reset** + +1. Click the "Forgot Password" link on the login page +2. Follow the instructions in the email to reset your password + +**Logout** + +1. Click the avatar in the top right corner of the page +2. Select "Logout" from the dropdown menu + +Core Features +------------- + +Browse Models +^^^^^^^^^^^^^ + +**Model List (Homepage)** + +* **Function:** Browse available models, click any model to view details +* **Location:** "Models" menu in the website navigation bar + +.. note:: + Some advanced models are only visible to authorized users.
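+
+**REST equivalents of the Launch Model operations**
+
+The "Update" and "Add Model" buttons described at the top of this page call two
+HTTP endpoints on the local Xinference server (``POST /v1/models/update_type``
+and ``POST /v1/models/add``). The sketch below shows how to invoke them directly;
+the host, port, and model name are placeholders for a default local deployment,
+and an authorization header may be required if authentication is enabled.
+
+.. code-block:: python
+
+   import requests
+
+   # Assumed default local endpoint; adjust host and port for your deployment.
+   BASE_URL = "http://127.0.0.1:9997"
+
+   # Equivalent of the "Add Model" button: fetch a model definition from the
+   # Models Hub by name and register it locally ("my-model" is a placeholder).
+   resp = requests.post(f"{BASE_URL}/v1/models/add", json={"model_name": "my-model"})
+   resp.raise_for_status()
+   print(resp.json())  # {"data": {"model_name": ..., "model_type": ...}}
+
+   # Equivalent of the type selector plus "Update" button: refresh the model
+   # list for one model type (llm, embedding, rerank, image, audio, video).
+   resp = requests.post(f"{BASE_URL}/v1/models/update_type", json={"model_type": "llm"})
+   resp.raise_for_status()
+   print(resp.json()["message"])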
+ +**Model Details and Documentation** + +* **Function:** View detailed information about models +* **Default Display:** "README" tab - view model description, usage instructions, and notes +* **Other Tabs:** Settings (authorized users), review status + +User Center +^^^^^^^^^^^ + +* **Function:** View and manage personal information +* **Location:** Click the avatar in the top right corner, select "User Center" +* **Content:** Personal profile settings + +Model Management (Authorized Users) +----------------------------------- + +Model Registration +^^^^^^^^^^^^^^^^^^ + +* **Function:** Submit new models to the platform +* **Location:** Click the avatar in the top right corner, select "Model Registration" +* **Required Permissions:** + + * **Private Models:** Model registration permission + * **Public Models:** Public model registration permission + * **Enterprise Models:** Enterprise model registration permission + +**Operation Process:** + +1. Fill in basic model information +2. Fill in Readme (can be automatically obtained by clicking the Get Readme button) +3. Submit (to register public models, enable the Public Model parameter) + +**Notes:** + + * Regular users can only register private models + * Public model registration requires review, and can be used publicly after approval (no review needed if you have public model registration permission) + * Enterprise model registration requires enabling the Public Model parameter first + +My Models +^^^^^^^^^ + +* **Function:** View models associated with your account (models you registered) +* **Location:** Click the avatar in the top right corner, select "My Models" +* **Required Permissions:** + + * **Private Models:** Model registration permission + * **Public Models:** Model registration permission + * **Enterprise Models:** Model registration permission + +Model Maintenance +^^^^^^^^^^^^^^^^^ + +* **Function:** Modify and manage existing models +* **Location:** Click the "Settings" icon on the model details page + +* **Permission Requirements:** + + * **Private Models:** Model ownership or any public model management permission + * **Advanced Models:** Advanced model update, delete, or expiration permission + * **Public Models:** Public model update, delete, or expiration permission + +**Notes:** + + * Updating JSON or modifying expiration attributes of public models will automatically create a PR to the xorbitsai/inference repository + +Review Workflow +^^^^^^^^^^^^^^^ + +**For Model Submitters:** + +1. Submit models for review +2. Check review status on the model details page +3. Make modifications based on reviewer feedback if needed + +**For Reviewers:** + +* **Required Permissions:** Model review list permission, model review permission + +**Operation Process:** + +1. Enter the review queue page +2. Evaluate model quality and compliance +3. 
Approve or reject and provide feedback \ No newline at end of file diff --git a/xinference/api/restful_api.py b/xinference/api/restful_api.py index 84c7b18d80..3da2e7bd1a 100644 --- a/xinference/api/restful_api.py +++ b/xinference/api/restful_api.py @@ -198,6 +198,14 @@ class RegisterModelRequest(BaseModel): persist: bool +class AddModelRequest(BaseModel): + model_name: str + + +class UpdateModelRequest(BaseModel): + model_type: str + + class BuildGradioInterfaceRequest(BaseModel): model_type: str model_name: str @@ -900,6 +908,26 @@ async def internal_exception_handler(request: Request, exc: Exception): else None ), ) + self._router.add_api_route( + "/v1/models/add", + self.add_model, + methods=["POST"], + dependencies=( + [Security(self._auth_service, scopes=["models:add"])] + if self.is_authenticated() + else None + ), + ) + self._router.add_api_route( + "/v1/models/update_type", + self.update_model_type, + methods=["POST"], + dependencies=( + [Security(self._auth_service, scopes=["models:add"])] + if self.is_authenticated() + else None + ), + ) self._router.add_api_route( "/v1/cache/models", self.list_cached_models, @@ -3123,13 +3151,90 @@ async def unregister_model(self, model_type: str, model_name: str) -> JSONRespon raise HTTPException(status_code=500, detail=str(e)) return JSONResponse(content=None) + async def add_model(self, request: Request) -> JSONResponse: + try: + # Parse request + raw_json = await request.json() + body = AddModelRequest.parse_obj(raw_json) + model_name = body.model_name + + supervisor_ref = await self._get_supervisor_ref() + + # Call supervisor with model_name only + await supervisor_ref.add_model(model_name) + + # Get model type information to return to frontend + import requests + + try: + info_url = f"https://model.xinference.io/api/models/{model_name}" + info_response = requests.get(info_url, timeout=30) + if info_response.status_code == 200: + model_info = info_response.json() + model_type = model_info.get("data", {}).get("model_type") + else: + model_type = "unknown" + except Exception: + model_type = "unknown" + + except ValueError as re: + logger.error(f"ValueError in add_model API: {re}", exc_info=True) + logger.error(f"ValueError details: {type(re).__name__}: {re}") + raise HTTPException(status_code=400, detail=str(re)) + except Exception as e: + logger.error(f"Unexpected error in add_model API: {e}", exc_info=True) + logger.error(f"Error details: {type(e).__name__}: {e}") + import traceback + + logger.error(f"Full traceback: {traceback.format_exc()}") + raise HTTPException(status_code=500, detail=str(e)) + + response_data = { + "data": { + "model_name": model_name, + "model_type": model_type, + }, + } + + return JSONResponse(content=response_data) + + async def update_model_type(self, request: Request) -> JSONResponse: + try: + # Parse request + raw_json = await request.json() + + body = UpdateModelRequest.parse_obj(raw_json) + model_type = body.model_type + + # Get supervisor reference + supervisor_ref = await self._get_supervisor_ref() + + await supervisor_ref.update_model_type(model_type) + + except ValueError as re: + logger.error(f"ValueError in update_model_type API: {re}", exc_info=True) + raise HTTPException(status_code=400, detail=str(re)) + except Exception as e: + logger.error( + f"Unexpected error in update_model_type API: {e}", exc_info=True + ) + raise HTTPException(status_code=500, detail=str(e)) + + return JSONResponse( + content={ + "message": f"Model configurations updated successfully for type: {model_type}" + } + ) + async def 
list_model_registrations( self, model_type: str, detailed: bool = Query(False) ) -> JSONResponse: try: + data = await (await self._get_supervisor_ref()).list_model_registrations( model_type, detailed=detailed ) + # Remove duplicate model names. model_names = set() final_data = [] @@ -3137,11 +3242,20 @@ async def list_model_registrations( if item["model_name"] not in model_names: model_names.add(item["model_name"]) final_data.append(item) + return JSONResponse(content=final_data) except ValueError as re: + logger.error( + f"ValueError in list_model_registrations: {re}", + exc_info=True, + ) logger.error(re, exc_info=True) raise HTTPException(status_code=400, detail=str(re)) except Exception as e: + logger.error( + f"Unexpected error in list_model_registrations: {e}", + exc_info=True, + ) logger.error(e, exc_info=True) raise HTTPException(status_code=500, detail=str(e)) diff --git a/xinference/core/supervisor.py b/xinference/core/supervisor.py index 1ed96cd703..43bbe4ad97 100644 --- a/xinference/core/supervisor.py +++ b/xinference/core/supervisor.py @@ -46,6 +46,7 @@ ) from ..core.model import ModelActor from ..core.status_guard import InstanceInfo, LaunchStatus +from ..model.cache_manager import CacheManager from ..model.utils import get_engine_params_by_name from ..types import PeftModelConfig from .metrics import record_metrics @@ -116,6 +117,13 @@ def __init__(self): self._uptime = None self._lock = asyncio.Lock() + def _log_debug(self, message: str): + """Helper method to log debug information.""" + import logging + + logger = logging.getLogger(__name__) + logger.debug(f"[SupervisorActor] {message}") + @classmethod def default_uid(cls) -> str: return "supervisor" @@ -134,6 +142,19 @@ async def __post_create__(self): if not XINFERENCE_DISABLE_HEALTH_CHECK: # Run _check_dead_nodes() in a dedicated thread. 
from ..isolation import Isolation + # Load persisted models on startup + try: + from ..model.utils import load_persisted_models_to_registry + + loaded_count = load_persisted_models_to_registry() + if loaded_count > 0: + logger.info( + f"Supervisor loaded {loaded_count} persisted models on startup" + ) + except Exception as e: + logger.warning( + f"Supervisor failed to load persisted models on startup: {e}" + ) self._isolation = Isolation(asyncio.new_event_loop(), threaded=True) self._isolation.start() @@ -217,6 +238,13 @@ async def __post_create__(self): register_rerank, unregister_rerank, ) + from ..model.video import ( + CustomVideoModelFamilyV2, + generate_video_description, + get_video_model_descriptions, + register_video, + unregister_video, + ) self._custom_register_type_to_cls: Dict[str, Tuple] = { # type: ignore "LLM": ( @@ -249,6 +277,12 @@ async def __post_create__(self): unregister_audio, generate_audio_description, ), + "video": ( + CustomVideoModelFamilyV2, + register_video, + unregister_video, + generate_video_description, + ), "flexible": ( FlexibleModelSpec, register_flexible_model, @@ -264,6 +298,7 @@ async def __post_create__(self): model_version_infos.update(get_rerank_model_descriptions()) model_version_infos.update(get_image_model_descriptions()) model_version_infos.update(get_audio_model_descriptions()) + model_version_infos.update(get_video_model_descriptions()) model_version_infos.update(get_flexible_model_descriptions()) await self._cache_tracker_ref.record_model_version( model_version_infos, self.address @@ -427,179 +462,307 @@ def _get_spec_dicts( ) return specs, list(download_hubs) - async def _to_llm_reg( - self, llm_family: "LLMFamilyV2", is_builtin: bool + async def _to_model_reg( + self, + model_family, + is_builtin: bool, + cache_manager_class=None, + use_spec_dicts: bool = False, + cache_status: Optional[bool] = None, ) -> Dict[str, Any]: - from ..model.llm.cache_manager import LLMCacheManager - - instance_cnt = await self.get_instance_count(llm_family.model_name) - version_cnt = await self.get_model_version_count(llm_family.model_name) + instance_cnt = await self.get_instance_count(model_family.model_name) + version_cnt = await self.get_model_version_count(model_family.model_name) if self.is_local_deployment(): - # TODO: does not work when the supervisor and worker are running on separate nodes. 
- _llm_family = llm_family.copy() - specs, download_hubs = self._get_spec_dicts(_llm_family, LLMCacheManager) + # Local deployment - include additional fields + if use_spec_dicts: + # For LLM/Embedding/Rerank - use spec_dicts + _family = model_family.copy() + specs, download_hubs = self._get_spec_dicts( + _family, cache_manager_class + ) + res = { + **model_family.dict(), + "is_builtin": is_builtin, + "model_specs": specs, + "download_hubs": download_hubs, + } + else: + # For Image/Audio/Video - use cache status + if cache_manager_class: + cache_manager = cache_manager_class(model_family) + actual_cache_status = cache_manager.get_cache_status() + else: + actual_cache_status = ( + cache_status if cache_status is not None else True + ) + + res = { + **model_family.dict(), + "cache_status": actual_cache_status, + "is_builtin": is_builtin, + } + else: + # Remote deployment - basic info only res = { - **llm_family.dict(), + **model_family.dict(), "is_builtin": is_builtin, - "model_specs": specs, - "download_hubs": download_hubs, } - else: - res = {**llm_family.dict(), "is_builtin": is_builtin} + res["model_version_count"] = version_cnt res["model_instance_count"] = instance_cnt return res + async def _collect_model_registrations( + self, + builtin_models, + register_builtin_func, + get_registered_func, + cache_manager_class, + to_model_reg_func, + detailed: bool = False, + is_dict_builtins: bool = False, + builtin_transform_func=None, + model_type: Optional[str] = None, + ) -> List[Dict[str, Any]]: + """ + Generic helper function to collect model registrations for any model type. + + Args: + builtin_models: Collection of builtin models (can be list, dict, etc.) + register_builtin_func: Function to register builtin models + get_registered_func: Function to get registered models + cache_manager_class: CacheManager class for checking model source + to_model_reg_func: Function to convert model to registration format + detailed: Whether to include detailed information + is_dict_builtins: Whether builtin_models is a dict (for embedding/rerank style) + builtin_transform_func: Optional function to transform builtin model data (for image/audio with download_hubs) + model_type: The model type string (e.g., "llm", "embedding") + + Returns: + List of model registration dictionaries + """ + ret = [] + + # Register builtin models + register_builtin_func() + + # Get names of hardcoded builtin models + if is_dict_builtins: + builtin_names = set(builtin_models.keys()) + else: + builtin_names = {model_obj.model_name for model_obj in builtin_models} + + # Process all models (builtin + registered) + all_models = [] + + # Add builtin models + if is_dict_builtins: + for model_name, model_obj in builtin_models.items(): + all_models.append( + (model_obj, True, "builtin") + ) # (model_obj, is_builtin, source) + else: + for model_obj in builtin_models: + all_models.append((model_obj, True, "builtin")) + + # Add registered models + registered_models = get_registered_func() + for model_obj in registered_models: + model_name = getattr(model_obj, "model_name", None) + if model_name is None: + continue + + # Use unified source resolution logic with debug info + source = cache_manager_class.resolve_model_source( + model_name, model_type, builtin_names + ) + is_builtin = source != "user" + + # Debug: Log model source determination for testing + self._log_debug( + f"[MODEL_SOURCE_DEBUG] Model: {model_name}, Type: {model_type}, Source: {source}, Is_builtin: {is_builtin}" + ) + + all_models.append((model_obj, is_builtin, source)) 
+ + # Process all models + for model_obj, is_builtin, source in all_models: + # Debug: Check what we have + self._log_debug( + f"[DEBUG] Processing model_obj: type={type(model_obj)}, is_builtin={is_builtin}, source={source}" + ) + + # Initialize actual_model_obj + actual_model_obj = model_obj + + # Apply transform functions for builtin models (needed for both detailed and non-detailed modes) + if source == "builtin" and builtin_transform_func: + if is_dict_builtins: + model_name = getattr(model_obj, "model_name", None) + if model_name: + # Apply transformation for dict-style builtins + families = builtin_models.get(model_name) + self._log_debug( + f"[DEBUG] Transform: model_name={model_name}, families type={type(families)}" + ) + actual_model_obj = builtin_transform_func(model_name, families) + self._log_debug( + f"[DEBUG] Transform result: type={type(actual_model_obj)}, model_name={getattr(actual_model_obj, 'model_name', 'NO_NAME')}" + ) + else: + actual_model_obj = model_obj + else: + # For list-style builtins, we don't have model_name and builtin_models + # In this case, model_obj should already be the correct format + # Debug: Check the type of model_obj + self._log_debug( + f"[DEBUG] model_obj type: {type(model_obj)}, value: {model_obj}" + ) + actual_model_obj = model_obj + + # Critical check: ensure actual_model_obj is not a list + if isinstance(actual_model_obj, list): + # This should not happen, but if it does, we need to handle it + print(f"ERROR: actual_model_obj is a list: {actual_model_obj}") + self._log_debug( + f"[ERROR] actual_model_obj is a list: {actual_model_obj}" + ) + # Try to get the first item if it's a list + if len(actual_model_obj) > 0: + actual_model_obj = actual_model_obj[0] + else: + continue + + if detailed: + reg_data = await to_model_reg_func( + actual_model_obj, is_builtin=is_builtin + ) + + # Apply post-transform if needed (for download_hubs) + if hasattr(builtin_transform_func, "post_transform"): + if is_dict_builtins and model_name: + reg_data = builtin_transform_func.post_transform( + model_name, builtin_models.get(model_name), reg_data + ) + else: + # For list-style builtins, we may not have the right data for post-transform + # Skip post-transform in this case to avoid errors + pass + + ret.append(reg_data) + else: + # Non-detailed mode - use actual_model_obj which should have model_name after transformation + model_name = getattr(actual_model_obj, "model_name", None) + self._log_debug( + f"[NON-DETAILED] First attempt: model_name={model_name}, actual_model_obj type={type(actual_model_obj)}" + ) + + if not model_name: + # Fallback to original model_obj if actual_model_obj doesn't have model_name + model_name = getattr(model_obj, "model_name", None) + self._log_debug( + f"[NON-DETAILED] Fallback: model_name={model_name}, model_obj type={type(model_obj)}" + ) + + if not model_name: + # Last resort: try to extract from the first element if model_obj is a list (shouldn't happen now) + if isinstance(model_obj, list) and len(model_obj) > 0: + model_name = getattr(model_obj[0], "model_name", "unknown") + self._log_debug( + f"[NON-DETAILED] Last resort list: model_name={model_name}" + ) + else: + model_name = "unknown" + self._log_debug(f"[NON-DETAILED] Final unknown") + + self._log_debug(f"[NON-DETAILED] Final model_name: {model_name}") + ret.append({"model_name": model_name, "is_builtin": is_builtin}) + + return ret + + async def _to_llm_reg( + self, llm_family: "LLMFamilyV2", is_builtin: bool + ) -> Dict[str, Any]: + from ..model.llm.cache_manager import 
LLMCacheManager + + return await self._to_model_reg( + model_family=llm_family, + is_builtin=is_builtin, + cache_manager_class=LLMCacheManager, + use_spec_dicts=True, + ) + async def _to_embedding_model_reg( self, model_family: "EmbeddingModelFamilyV2", is_builtin: bool ) -> Dict[str, Any]: from ..model.embedding.cache_manager import EmbeddingCacheManager - instance_cnt = await self.get_instance_count(model_family.model_name) - version_cnt = await self.get_model_version_count(model_family.model_name) - - if self.is_local_deployment(): - _family = model_family.copy() - # TODO: does not work when the supervisor and worker are running on separate nodes. - specs, download_hubs = self._get_spec_dicts(_family, EmbeddingCacheManager) - res = { - **model_family.dict(), - "is_builtin": is_builtin, - "model_specs": specs, - "download_hubs": download_hubs, - } - else: - res = { - **model_family.dict(), - "is_builtin": is_builtin, - } - res["model_version_count"] = version_cnt - res["model_instance_count"] = instance_cnt - return res + return await self._to_model_reg( + model_family=model_family, + is_builtin=is_builtin, + cache_manager_class=EmbeddingCacheManager, + use_spec_dicts=True, + ) async def _to_rerank_model_reg( self, model_family: "RerankModelFamilyV2", is_builtin: bool ) -> Dict[str, Any]: from ..model.rerank.cache_manager import RerankCacheManager - instance_cnt = await self.get_instance_count(model_family.model_name) - version_cnt = await self.get_model_version_count(model_family.model_name) - - if self.is_local_deployment(): - _family = model_family.copy() - # TODO: does not work when the supervisor and worker are running on separate nodes. - specs, download_hubs = self._get_spec_dicts(_family, RerankCacheManager) - res = { - **model_family.dict(), - "is_builtin": is_builtin, - "model_specs": specs, - "download_hubs": download_hubs, - } - else: - res = { - **model_family.dict(), - "is_builtin": is_builtin, - } - res["model_version_count"] = version_cnt - res["model_instance_count"] = instance_cnt - return res + return await self._to_model_reg( + model_family=model_family, + is_builtin=is_builtin, + cache_manager_class=RerankCacheManager, + use_spec_dicts=True, + ) async def _to_image_model_reg( self, model_family: "ImageModelFamilyV2", is_builtin: bool ) -> Dict[str, Any]: from ..model.image.cache_manager import ImageCacheManager - instance_cnt = await self.get_instance_count(model_family.model_name) - version_cnt = await self.get_model_version_count(model_family.model_name) - - if self.is_local_deployment(): - # TODO: does not work when the supervisor and worker are running on separate nodes. 
- cache_manager = ImageCacheManager(model_family) - res = { - **model_family.dict(), - "cache_status": cache_manager.get_cache_status(), - "is_builtin": is_builtin, - } - else: - res = { - **model_family.dict(), - "is_builtin": is_builtin, - } - res["model_version_count"] = version_cnt - res["model_instance_count"] = instance_cnt - return res + return await self._to_model_reg( + model_family=model_family, + is_builtin=is_builtin, + cache_manager_class=ImageCacheManager, + use_spec_dicts=False, + ) async def _to_audio_model_reg( self, model_family: "AudioModelFamilyV2", is_builtin: bool ) -> Dict[str, Any]: - from ..model.cache_manager import CacheManager - - instance_cnt = await self.get_instance_count(model_family.model_name) - version_cnt = await self.get_model_version_count(model_family.model_name) - cache_manager = CacheManager(model_family) + from ..model.audio.cache_manager import AudioCacheManager - if self.is_local_deployment(): - # TODO: does not work when the supervisor and worker are running on separate nodes. - res = { - **model_family.dict(), - "cache_status": cache_manager.get_cache_status(), - "is_builtin": is_builtin, - } - else: - res = { - **model_family.dict(), - "is_builtin": is_builtin, - } - res["model_version_count"] = version_cnt - res["model_instance_count"] = instance_cnt - return res + return await self._to_model_reg( + model_family=model_family, + is_builtin=is_builtin, + cache_manager_class=AudioCacheManager, + use_spec_dicts=False, + ) async def _to_video_model_reg( self, model_family: "VideoModelFamilyV2", is_builtin: bool ) -> Dict[str, Any]: - from ..model.cache_manager import CacheManager + from ..model.video.cache_manager import VideoCacheManager - instance_cnt = await self.get_instance_count(model_family.model_name) - version_cnt = await self.get_model_version_count(model_family.model_name) - cache_manager = CacheManager(model_family) - - if self.is_local_deployment(): - # TODO: does not work when the supervisor and worker are running on separate nodes. 
- res = { - **model_family.dict(), - "cache_status": cache_manager.get_cache_status(), - "is_builtin": is_builtin, - } - else: - res = { - **model_family.dict(), - "is_builtin": is_builtin, - } - res["model_version_count"] = version_cnt - res["model_instance_count"] = instance_cnt - return res + return await self._to_model_reg( + model_family=model_family, + is_builtin=is_builtin, + cache_manager_class=VideoCacheManager, + use_spec_dicts=False, + ) async def _to_flexible_model_reg( self, model_spec: "FlexibleModelSpec", is_builtin: bool ) -> Dict[str, Any]: - instance_cnt = await self.get_instance_count(model_spec.model_name) - version_cnt = await self.get_model_version_count(model_spec.model_name) - - if self.is_local_deployment(): - res = { - **model_spec.dict(), - "cache_status": True, - "is_builtin": is_builtin, - } - else: - res = { - **model_spec.dict(), - "is_builtin": is_builtin, - } - res["model_version_count"] = version_cnt - res["model_instance_count"] = instance_cnt - return res + return await self._to_model_reg( + model_family=model_spec, + is_builtin=is_builtin, + cache_manager_class=None, + use_spec_dicts=False, + cache_status=True, + ) @log_async(logger=logger) async def list_model_registrations( @@ -613,133 +776,184 @@ def sort_helper(item): if not self.is_local_deployment(): workers = list(self._worker_address_to_worker.values()) for worker in workers: - ret.extend(await worker.list_model_registrations(model_type, detailed)) - - if model_type == "LLM": - from ..model.llm import BUILTIN_LLM_FAMILIES, get_user_defined_llm_families - - for family in BUILTIN_LLM_FAMILIES: - if detailed: - ret.append(await self._to_llm_reg(family, True)) - else: - ret.append({"model_name": family.model_name, "is_builtin": True}) - - for family in get_user_defined_llm_families(): - if detailed: - ret.append(await self._to_llm_reg(family, False)) - else: - ret.append({"model_name": family.model_name, "is_builtin": False}) + worker_data = await worker.list_model_registrations( + model_type, detailed + ) + ret.extend(worker_data) + if model_type.upper() == "LLM": + from ..model.llm import ( + BUILTIN_LLM_FAMILIES, + get_registered_llm_families, + register_builtin_model, + ) + from ..model.llm.cache_manager import LLMCacheManager + + model_regs = await self._collect_model_registrations( + builtin_models=BUILTIN_LLM_FAMILIES, + register_builtin_func=register_builtin_model, + get_registered_func=get_registered_llm_families, + cache_manager_class=LLMCacheManager, + to_model_reg_func=self._to_llm_reg, + detailed=detailed, + is_dict_builtins=False, + model_type="llm", + ) + ret.extend(model_regs) ret.sort(key=sort_helper) return ret elif model_type == "embedding": - from ..model.embedding import BUILTIN_EMBEDDING_MODELS - from ..model.embedding.custom import get_user_defined_embeddings - - for model_name, family in BUILTIN_EMBEDDING_MODELS.items(): - if detailed: - ret.append( - await self._to_embedding_model_reg(family, is_builtin=True) - ) - else: - ret.append({"model_name": model_name, "is_builtin": True}) - - for model_spec in get_user_defined_embeddings(): - if detailed: - ret.append( - await self._to_embedding_model_reg(model_spec, is_builtin=False) - ) - else: - ret.append( - {"model_name": model_spec.model_name, "is_builtin": False} - ) - + from ..model.embedding import ( + BUILTIN_EMBEDDING_MODELS, + register_builtin_model, + ) + from ..model.embedding.cache_manager import EmbeddingCacheManager + from ..model.embedding.custom import get_registered_embeddings + + model_regs = await 
self._collect_model_registrations( + builtin_models=BUILTIN_EMBEDDING_MODELS, + register_builtin_func=register_builtin_model, + get_registered_func=get_registered_embeddings, + cache_manager_class=EmbeddingCacheManager, + to_model_reg_func=self._to_embedding_model_reg, + detailed=detailed, + is_dict_builtins=True, + model_type="embedding", + ) + ret.extend(model_regs) ret.sort(key=sort_helper) return ret elif model_type == "image": - from ..model.image import BUILTIN_IMAGE_MODELS - from ..model.image.custom import get_user_defined_images - - for model_name, families in BUILTIN_IMAGE_MODELS.items(): - if detailed: - family = [x for x in families if x.model_hub == "huggingface"][0] - info = await self._to_image_model_reg(family, is_builtin=True) - info["download_hubs"] = [x.model_hub for x in families] - ret.append(info) - else: - ret.append({"model_name": model_name, "is_builtin": True}) - - for model_spec in get_user_defined_images(): - if detailed: - ret.append( - await self._to_image_model_reg(model_spec, is_builtin=False) - ) - else: - ret.append( - {"model_name": model_spec.model_name, "is_builtin": False} - ) - + from ..model.image import BUILTIN_IMAGE_MODELS, register_builtin_model + from ..model.image.cache_manager import ImageCacheManager + from ..model.image.custom import get_registered_images + + class ImageTransformFunc: + def __init__(self): + self.post_transform = self.image_post_transform + + def __call__(self, model_name, families): + """Transform function for image models to select huggingface family and add download_hubs""" + if isinstance(families, list): + return [x for x in families if x.model_hub == "huggingface"][0] + else: + # For single model objects + return families + + def image_post_transform(self, model_name, families, reg_data): + """Post-transform function for image models to add download_hubs""" + if isinstance(families, list): + reg_data["download_hubs"] = [x.model_hub for x in families] + return reg_data + + image_transform_func = ImageTransformFunc() + + model_regs = await self._collect_model_registrations( + builtin_models=BUILTIN_IMAGE_MODELS, + register_builtin_func=register_builtin_model, + get_registered_func=get_registered_images, + cache_manager_class=ImageCacheManager, + to_model_reg_func=self._to_image_model_reg, + detailed=detailed, + is_dict_builtins=True, + builtin_transform_func=image_transform_func, + model_type="image", + ) + ret.extend(model_regs) ret.sort(key=sort_helper) return ret elif model_type == "audio": - from ..model.audio import BUILTIN_AUDIO_MODELS - from ..model.audio.custom import get_user_defined_audios - - for model_name, families in BUILTIN_AUDIO_MODELS.items(): - if detailed: - family = [x for x in families if x.model_hub == "huggingface"][0] - info = await self._to_audio_model_reg(family, is_builtin=True) - info["download_hubs"] = [x.model_hub for x in families] - ret.append(info) - else: - ret.append({"model_name": model_name, "is_builtin": True}) - - for model_spec in get_user_defined_audios(): - if detailed: - ret.append( - await self._to_audio_model_reg(model_spec, is_builtin=False) - ) - else: - ret.append( - {"model_name": model_spec.model_name, "is_builtin": False} - ) - + from ..model.audio import BUILTIN_AUDIO_MODELS, register_builtin_model + from ..model.audio.cache_manager import AudioCacheManager + from ..model.audio.custom import get_registered_audios + + class AudioTransformFunc: + def __init__(self): + self.post_transform = self.audio_post_transform + + def __call__(self, model_name, families): + 
"""Transform function for audio models to select huggingface family""" + if isinstance(families, list): + return [x for x in families if x.model_hub == "huggingface"][0] + else: + return families + + def audio_post_transform(self, model_name, families, reg_data): + """Post-transform function for audio models to add download_hubs""" + if isinstance(families, list): + reg_data["download_hubs"] = [x.model_hub for x in families] + return reg_data + + audio_transform_func = AudioTransformFunc() + + model_regs = await self._collect_model_registrations( + builtin_models=BUILTIN_AUDIO_MODELS, + register_builtin_func=register_builtin_model, + get_registered_func=get_registered_audios, + cache_manager_class=AudioCacheManager, + to_model_reg_func=self._to_audio_model_reg, + detailed=detailed, + is_dict_builtins=True, + builtin_transform_func=audio_transform_func, + model_type="audio", + ) + ret.extend(model_regs) ret.sort(key=sort_helper) return ret elif model_type == "video": - from ..model.video import BUILTIN_VIDEO_MODELS - - for model_name, families in BUILTIN_VIDEO_MODELS.items(): - if detailed: - family = [x for x in families if x.model_hub == "huggingface"][0] - info = await self._to_video_model_reg(family, is_builtin=True) - info["download_hubs"] = [x.model_hub for x in families] - ret.append(info) - else: - ret.append({"model_name": model_name, "is_builtin": True}) - + from ..model.video import BUILTIN_VIDEO_MODELS, register_builtin_model + from ..model.video.cache_manager import VideoCacheManager + from ..model.video.custom import get_registered_videos + + class VideoTransformFunc: + def __init__(self): + self.post_transform = self.video_post_transform + + def __call__(self, model_name, families): + """Transform function for video models to select huggingface family""" + if isinstance(families, list): + return [x for x in families if x.model_hub == "huggingface"][0] + else: + return families + + def video_post_transform(self, model_name, families, reg_data): + """Post-transform function for video models to add download_hubs""" + if isinstance(families, list): + reg_data["download_hubs"] = [x.model_hub for x in families] + return reg_data + + video_transform_func = VideoTransformFunc() + + model_regs = await self._collect_model_registrations( + builtin_models=BUILTIN_VIDEO_MODELS, + register_builtin_func=register_builtin_model, + get_registered_func=get_registered_videos, + cache_manager_class=VideoCacheManager, + to_model_reg_func=self._to_video_model_reg, + detailed=detailed, + is_dict_builtins=True, + builtin_transform_func=video_transform_func, + model_type="video", + ) + ret.extend(model_regs) ret.sort(key=sort_helper) return ret elif model_type == "rerank": - from ..model.rerank import BUILTIN_RERANK_MODELS - from ..model.rerank.custom import get_user_defined_reranks - - for model_name, family in BUILTIN_RERANK_MODELS.items(): - if detailed: - ret.append(await self._to_rerank_model_reg(family, is_builtin=True)) - else: - ret.append({"model_name": model_name, "is_builtin": True}) - - for model_spec in get_user_defined_reranks(): - if detailed: - ret.append( - await self._to_rerank_model_reg(model_spec, is_builtin=False) - ) - else: - ret.append( - {"model_name": model_spec.model_name, "is_builtin": False} - ) - + from ..model.rerank import BUILTIN_RERANK_MODELS, register_builtin_model + from ..model.rerank.cache_manager import RerankCacheManager + from ..model.rerank.custom import get_registered_reranks + + model_regs = await self._collect_model_registrations( + 
builtin_models=BUILTIN_RERANK_MODELS, + register_builtin_func=register_builtin_model, + get_registered_func=get_registered_reranks, + cache_manager_class=RerankCacheManager, + to_model_reg_func=self._to_rerank_model_reg, + detailed=detailed, + is_dict_builtins=True, + model_type="rerank", + ) + ret.extend(model_regs) ret.sort(key=sort_helper) return ret elif model_type == "flexible": @@ -748,13 +962,29 @@ def sort_helper(item): ret = [] for model_spec in get_flexible_models(): + cache_manager = CacheManager(model_spec) + is_persisted_model = False + if hasattr(cache_manager, "_v2_custom_dir_prefix"): + import os + + potential_persist_path = os.path.join( + cache_manager._v2_custom_dir_prefix, + "flexible", + f"{model_spec.model_name}.json", + ) + is_persisted_model = os.path.exists(potential_persist_path) + + is_builtin = is_persisted_model # Treat persisted models as built-in + if detailed: ret.append( - await self._to_flexible_model_reg(model_spec, is_builtin=False) + await self._to_flexible_model_reg( + model_spec, is_builtin=is_builtin + ) ) else: ret.append( - {"model_name": model_spec.model_name, "is_builtin": False} + {"model_name": model_spec.model_name, "is_builtin": is_builtin} ) ret.sort(key=sort_helper) @@ -772,27 +1002,27 @@ async def get_model_registration(self, model_type: str, model_name: str) -> Any: if f is not None: return f - if model_type == "LLM": - from ..model.llm import BUILTIN_LLM_FAMILIES, get_user_defined_llm_families + if model_type.upper() == "LLM": + from ..model.llm import BUILTIN_LLM_FAMILIES, get_registered_llm_families - for f in BUILTIN_LLM_FAMILIES + get_user_defined_llm_families(): + for f in BUILTIN_LLM_FAMILIES + get_registered_llm_families(): if f.model_name == model_name: return f raise ValueError(f"Model {model_name} not found") elif model_type == "embedding": from ..model.embedding import BUILTIN_EMBEDDING_MODELS - from ..model.embedding.custom import get_user_defined_embeddings + from ..model.embedding.custom import get_registered_embeddings for f in ( - list(BUILTIN_EMBEDDING_MODELS.values()) + get_user_defined_embeddings() + list(BUILTIN_EMBEDDING_MODELS.values()) + get_registered_embeddings() ): if f.model_name == model_name: return f raise ValueError(f"Model {model_name} not found") elif model_type == "image": from ..model.image import BUILTIN_IMAGE_MODELS - from ..model.image.custom import get_user_defined_images + from ..model.image.custom import get_registered_images if model_name in BUILTIN_IMAGE_MODELS: return [ @@ -801,13 +1031,13 @@ async def get_model_registration(self, model_type: str, model_name: str) -> Any: if x.model_hub == "huggingface" ][0] else: - for f in get_user_defined_images(): + for f in get_registered_images(): if f.model_name == model_name: return f raise ValueError(f"Model {model_name} not found") elif model_type == "audio": from ..model.audio import BUILTIN_AUDIO_MODELS - from ..model.audio.custom import get_user_defined_audios + from ..model.audio.custom import get_registered_audios if model_name in BUILTIN_AUDIO_MODELS: return [ @@ -816,15 +1046,15 @@ async def get_model_registration(self, model_type: str, model_name: str) -> Any: if x.model_hub == "huggingface" ][0] else: - for f in get_user_defined_audios(): + for f in get_registered_audios(): if f.model_name == model_name: return f raise ValueError(f"Model {model_name} not found") elif model_type == "rerank": from ..model.rerank import BUILTIN_RERANK_MODELS - from ..model.rerank.custom import get_user_defined_reranks + from ..model.rerank.custom import 
get_registered_reranks - for f in list(BUILTIN_RERANK_MODELS.values()) + get_user_defined_reranks(): + for f in list(BUILTIN_RERANK_MODELS.values()) + get_registered_reranks(): if f.model_name == model_name: return f raise ValueError(f"Model {model_name} not found") @@ -837,6 +1067,7 @@ async def get_model_registration(self, model_type: str, model_name: str) -> Any: raise ValueError(f"Model {model_name} not found") elif model_type == "video": from ..model.video import BUILTIN_VIDEO_MODELS + from ..model.video.custom import get_registered_videos if model_name in BUILTIN_VIDEO_MODELS: return [ @@ -844,6 +1075,10 @@ async def get_model_registration(self, model_type: str, model_name: str) -> Any: for x in BUILTIN_VIDEO_MODELS[model_name] if x.model_hub == "huggingface" ][0] + else: + for f in get_registered_videos(): + if f.model_name == model_name: + return f raise ValueError(f"Model {model_name} not found") else: raise ValueError(f"Unsupported model type: {model_type}") @@ -911,50 +1146,151 @@ async def register_model( f"Worker ip address {worker_ip} is not in the cluster." ) + # Select worker to handle registration if target_ip_worker_ref: - await target_ip_worker_ref.register_model(model_type, model, persist) - return + # Specific worker requested + chosen_worker = target_ip_worker_ref + else: + # Choose first available worker + if self._worker_address_to_worker: + chosen_worker = list(self._worker_address_to_worker.values())[0] + else: + raise RuntimeError("No workers available for model registration") try: - register_fn(model_spec, persist) + # Forward registration to chosen worker - worker handles everything + await chosen_worker.register_model(model_type, model, persist) + + # Record model version for cache tracking await self._cache_tracker_ref.record_model_version( generate_fn(model_spec), self.address ) - await self._sync_register_model( - model_type, model, persist, model_spec.model_name - ) - except ValueError as e: - raise e except Exception as e: - unregister_fn(model_spec.model_name, raise_error=False) + # Registration failed, cleanup if needed + logger.error(f"Model registration failed on worker: {e}") raise e else: raise ValueError(f"Unsupported model type: {model_type}") - async def _sync_register_model( - self, model_type: str, model: str, persist: bool, model_name: str - ): - logger.info(f"begin sync model: {model_name} to worker") + @log_async(logger=logger) + async def add_model(self, model_name: str): + """ + Add a new model by forwarding the request to all workers. + + Args: + model_name: Name of the model to add + """ + + self._log_debug(f"[ADD_MODEL_DEBUG] Adding model: {model_name}") + try: - # Sync model to all workers. - for name, worker in self._worker_address_to_worker.items(): - logger.info(f"sync model: {model_name} to {name}") - if name == self.address: - # Ignore: when worker and supervisor at the same node. 
- logger.info( - f"ignore sync model: {model_name} to {name} for same node" - ) - else: - await worker.register_model(model_type, model, persist) - logger.info(f"success sync model: {model_name} to {name}") + # Forward the add_model request to all workers + tasks = [] + for worker_address, worker_ref in self._worker_address_to_worker.items(): + self._log_debug( + f"[ADD_MODEL_DEBUG] Forwarding to worker: {worker_address}" + ) + tasks.append(worker_ref.add_model(model_name)) + + # Wait for all workers to complete the operation + if tasks: + results = await asyncio.gather(*tasks, return_exceptions=True) + successful_workers = 0 + failed_workers = 0 + + for i, result in enumerate(results): + if isinstance(result, Exception): + failed_workers += 1 + error_msg = str(result) + worker_addr = list(self._worker_address_to_worker.keys())[i] + + self._log_debug( + f"[ADD_MODEL_DEBUG] Worker {worker_addr} failed: {error_msg}" + ) + else: + successful_workers += 1 + worker_addr = list(self._worker_address_to_worker.keys())[i] + self._log_debug( + f"[ADD_MODEL_DEBUG] Worker {worker_addr} succeeded" + ) + + # Determine overall result + if successful_workers == 0: + # All workers failed + raise RuntimeError(f"All workers failed to add model {model_name}") + else: + logger.warning(f"No workers available to forward add_model request") + + self._log_debug( + f"[ADD_MODEL_DEBUG] Successfully completed add_model for: {model_name}" + ) + + except Exception as e: + self._log_debug( + f"[ADD_MODEL_DEBUG] Failed to add model {model_name}: {str(e)}" + ) + logger.error( + f"Error during add_model forwarding: {str(e)}", + exc_info=True, + ) + raise ValueError(f"Failed to add model: {str(e)}") + + async def update_model_type(self, model_type: str): + """ + Update model configurations for a specific model type by forwarding + the request to all workers. + + Args: + model_type: Type of model (LLM, embedding, image, etc.) + """ + + self._log_debug( + f"[UPDATE_MODEL_TYPE_DEBUG] Starting update_model_type for: {model_type}" + ) + self._log_debug( + f"[UPDATE_MODEL_TYPE_DEBUG] Available workers: {list(self._worker_address_to_worker.keys())}" + ) + + try: + # Forward the update_model_type request to all workers + tasks = [] + for worker_address, worker_ref in self._worker_address_to_worker.items(): + self._log_debug( + f"[UPDATE_MODEL_TYPE_DEBUG] Forwarding update_model_type to worker: {worker_address}" + ) + tasks.append(worker_ref.update_model_type(model_type)) + + # Wait for all workers to complete the operation + if tasks: + results = await asyncio.gather(*tasks, return_exceptions=True) + for i, result in enumerate(results): + if isinstance(result, Exception): + self._log_debug( + f"[UPDATE_MODEL_TYPE_DEBUG] Worker {list(self._worker_address_to_worker.keys())[i]} failed: {result}" + ) + else: + self._log_debug( + f"[UPDATE_MODEL_TYPE_DEBUG] Worker {list(self._worker_address_to_worker.keys())[i]} succeeded" + ) + else: + logger.warning( + f"No workers available to forward update_model_type request" + ) + + self._log_debug( + f"[UPDATE_MODEL_TYPE_DEBUG] Successfully completed update_model_type for: {model_type}" + ) + except Exception as e: - # If sync fails, unregister the model in all workers. 
- for name, worker in self._worker_address_to_worker.items(): - logger.warning(f"ready to unregister model for {name}") - await worker.unregister_model(model_type, model_name) - logger.warning(f"finish unregister model: {model} for {name}") - raise e + self._log_debug( + f"[UPDATE_MODEL_TYPE_DEBUG] Failed to update model type {model_type}: {str(e)}" + ) + logger.error( + f"Error during update_model_type forwarding: {str(e)}", + exc_info=True, + ) + raise ValueError(f"Failed to update model type: {str(e)}") @log_async(logger=logger) async def unregister_model(self, model_type: str, model_name: str): diff --git a/xinference/core/worker.py b/xinference/core/worker.py index 3a211b19e3..4458d6db25 100644 --- a/xinference/core/worker.py +++ b/xinference/core/worker.py @@ -272,6 +272,22 @@ async def __post_create__(self): register_rerank, unregister_rerank, ) + from ..model.video import ( + CustomVideoModelFamilyV2, + generate_video_description, + register_video, + unregister_video, + ) + + # Load persisted models on startup + try: + from ..model.utils import load_persisted_models_to_registry + + loaded_count = load_persisted_models_to_registry() + if loaded_count > 0: + logger.info(f"Loaded {loaded_count} persisted models on startup") + except Exception as e: + logger.warning(f"Failed to load persisted models on startup: {e}") self._custom_register_type_to_cls: Dict[str, Tuple] = { # type: ignore "LLM": ( @@ -310,6 +326,12 @@ async def __post_create__(self): unregister_flexible_model, generate_flexible_model_description, ), + "video": ( + CustomVideoModelFamilyV2, + register_video, + unregister_video, + generate_video_description, + ), } logger.info("Purge cache directory: %s", XINFERENCE_CACHE_DIR) @@ -652,6 +674,604 @@ async def unregister_model(self, model_type: str, model_name: str): else: raise ValueError(f"Unsupported model type: {model_type}") + def _check_model_file_exists(self, model_type: str, model_name: str) -> bool: + """ + Check if a model file already exists in the filesystem. + + Args: + model_type: Type of the model (llm, embedding, audio, etc.) + model_name: Name of the model + + Returns: + True if model file exists, False otherwise + """ + import json + import os + + from ..constants import XINFERENCE_MODEL_DIR + + try: + model_type_lower = model_type.lower() + builtin_dir = os.path.join( + XINFERENCE_MODEL_DIR, "v2", "builtin", model_type_lower + ) + + # Check individual model file + individual_file = os.path.join(builtin_dir, f"{model_name}.json") + if os.path.exists(individual_file): + return True + + # Check unified model file + unified_file = os.path.join(builtin_dir, f"{model_type_lower}_models.json") + if os.path.exists(unified_file): + try: + with open(unified_file, "r", encoding="utf-8") as f: + models = json.load(f) + if isinstance(models, list): + return any( + m.get("model_name") == model_name for m in models + ) + elif isinstance(models, dict): + return model_name in models + except (json.JSONDecodeError, IOError): + pass + + return False + except Exception: + return False + + async def _check_model_already_registered( + self, model_type: str, model_name: str + ) -> bool: + """ + Check if a model is already registered in the system. + + Args: + model_type: Type of the model (llm, embedding, audio, etc.) 
+ model_name: Name of the model + + Returns: + True if model is already registered, False otherwise + """ + try: + # Check if model is already registered using existing method + existing_model = await self.get_model_registration(model_type, model_name) + return existing_model is not None + except ValueError as e: + # "not found" error means model is not registered + if "not found" in str(e): + return False + # Other ValueError means there's a real issue + raise + except Exception: + # For any other exception, assume model might exist and be cautious + logger.warning( + f"Error checking if model {model_name} is registered, assuming it exists" + ) + return True + + @log_async(logger=logger) + async def add_model(self, model_name: str): + """ + Add a new model by first getting its type information, then downloading + its JSON configuration and storing it as an individual file. + + Args: + model_name: Name of the model to add + """ + import json + + import requests + + try: + # Step 1: Get model details first to determine the model type + logger.info(f"Getting model type information for: {model_name}") + info_url = f"https://model.xinference.io/api/models/{model_name}" + info_response = requests.get(info_url, timeout=30) + info_response.raise_for_status() + + model_info = info_response.json() + + # Extract model_type from the response - handle nested data structure + model_type = None + if "data" in model_info and "model_type" in model_info["data"]: + model_type = model_info["data"]["model_type"] + elif "model_type" in model_info: + model_type = model_info["model_type"] + + if not model_type: + logger.error(f"No model_type found in model info for: {model_name}") + logger.error(f"Response structure: {list(model_info.keys())}") + if "data" in model_info: + logger.error(f"Data structure: {list(model_info['data'].keys())}") + raise ValueError(f"Model type not found for model: {model_name}") + + logger.info(f"Retrieved model type: {model_type} for model: {model_name}") + + # Step 1.5: Log model existence but don't block the download + logger.info(f"Checking if model {model_name} already exists") + file_exists = self._check_model_file_exists(model_type, model_name) + registered = await self._check_model_already_registered( + model_type, model_name + ) + + if file_exists or registered: + logger.info( + f"Model {model_name} (type: {model_type}) already exists, proceeding with download anyway" + ) + if file_exists: + logger.info( + " - Model file exists in filesystem, will be overwritten" + ) + if registered: + logger.info(" - Model is registered in the system") + + # Step 2: Download model JSON configuration using the original download API + logger.info(f"Downloading model configuration for: {model_name}") + download_url = f"https://model.xinference.io/api/models/download?model_name={model_name}" + download_response = requests.get(download_url, timeout=30) + download_response.raise_for_status() + + model_data = download_response.json() + logger.info(f"Downloaded model configuration for: {model_name}") + + # Validate model type is supported + supported_types = list(self._custom_register_type_to_cls.keys()) + normalized_for_validation = model_type + + if model_type.lower() == "llm" and "LLM" in supported_types: + normalized_for_validation = "LLM" + elif model_type.lower() == "llm" and "llm" in supported_types: + normalized_for_validation = "llm" + + if normalized_for_validation not in supported_types: + logger.error(f"Unsupported model type: {normalized_for_validation}") + raise ValueError( + 
f"Unsupported model type '{model_type}'. " + f"Supported types are: {', '.join(supported_types)}" + ) + + # Step 3: Save the model to ensure it's properly registered + await self._save_model_as_individual_file(model_type, model_data) + + # Dynamically reload built-in models to make the new model immediately available + try: + if model_type.lower() == "llm": + from ..model.llm import register_builtin_model + + register_builtin_model() + elif model_type.lower() == "embedding": + from ..model.embedding import register_builtin_model + + register_builtin_model() + elif model_type.lower() == "audio": + from ..model.audio import register_builtin_model + + register_builtin_model() + elif model_type.lower() == "image": + from ..model.image import register_builtin_model + + register_builtin_model() + elif model_type.lower() == "rerank": + from ..model.rerank import register_builtin_model + + register_builtin_model() + elif model_type.lower() == "video": + from ..model.video import register_builtin_model + + register_builtin_model() + else: + logger.warning( + f"No dynamic loading available for model type: {model_type}" + ) + except Exception as reload_error: + logger.error( + f"Error reloading built-in models: {reload_error}", + exc_info=True, + ) + # Don't fail the add if reload fails, just log the error + + logger.info(f"Successfully added model: {model_name} (type: {model_type})") + + except requests.exceptions.RequestException as e: + logger.error(f"Network error downloading model configuration: {e}") + raise ValueError(f"Failed to download model configuration: {str(e)}") + except json.JSONDecodeError as e: + logger.error(f"JSON decode error: {e}") + raise ValueError(f"Invalid JSON response from remote API: {str(e)}") + except Exception as e: + logger.error( + f"Unexpected error during model addition: {e}", + exc_info=True, + ) + raise ValueError(f"Failed to add model: {str(e)}") + + async def _add_model_as_individual_file( + self, model_type: str, model_name: str, model_data + ): + """ + Add a single model as an individual JSON file for the model type. + + Args: + model_type: Type of the model (llm, embedding, audio, etc.) + model_name: Name of the model + model_data: Model configuration data + """ + import json + import os + + from ..constants import XINFERENCE_MODEL_DIR + + try: + model_type_lower = model_type.lower() + + # Use the model-specific directory path + builtin_dir = os.path.join( + XINFERENCE_MODEL_DIR, "v2", "builtin", model_type_lower + ) + + # Ensure directory exists + os.makedirs(builtin_dir, exist_ok=True) + + # Save as individual JSON file named after the model + json_file_path = os.path.join(builtin_dir, f"{model_name}.json") + + # Check if file already exists (extra safety check) + if os.path.exists(json_file_path): + logger.warning( + f"Model file {json_file_path} already exists, overwriting..." 
+ ) + # Create backup of existing file + backup_path = f"{json_file_path}.backup" + try: + import shutil + + shutil.copy2(json_file_path, backup_path) + logger.info(f"Created backup of existing model file: {backup_path}") + except Exception as backup_error: + logger.warning(f"Failed to create backup: {backup_error}") + + # Save the model configuration + with open(json_file_path, "w", encoding="utf-8") as f: + json.dump(model_data, f, indent=2, ensure_ascii=False) + + logger.info( + f"Added model {model_name} as individual file: {json_file_path}" + ) + + except Exception as e: + logger.error( + f"Error adding model as individual file: {str(e)}", + exc_info=True, + ) + raise ValueError(f"Failed to store model configuration: {str(e)}") + + def _infer_model_type(self, model_data): + """ + Infer model type from the structure of the downloaded JSON. + """ + # Method 1: Check if model_type is explicitly present + if "model_type" in model_data: + return model_data["model_type"] + + # Method 2: Infer from model_ability field + model_ability = model_data.get("model_ability", []) + if model_ability: + if "embed" in model_ability: + return "embedding" + elif "rerank" in model_ability: + return "rerank" + elif "chat" in model_ability or "generate" in model_ability: + return "llm" + elif "image-to-text" in model_ability or "text-to-image" in model_ability: + return "image" + elif "audio-to-text" in model_ability or "text-to-audio" in model_ability: + return "audio" + elif "text-to-video" in model_ability or "video-to-text" in model_ability: + return "video" + + # Method 3: Infer from specific fields + if "dimensions" in model_data and "max_tokens" in model_data: + return "embedding" + + if "context_length" in model_data and "chat_template" in model_data: + return "llm" + + # Method 4: Infer from model_specs structure + model_specs = model_data.get("model_specs", []) + if model_specs: + # Check if any spec has embedding-like characteristics + for spec in model_specs: + if "dimensions" in spec: + return "embedding" + + # Check if any spec has LLM-like characteristics + if "chat_template" in model_data or "context_length" in model_data: + return "llm" + + # Method 5: Default to LLM as the most common type + logger.warning( + f"Could not definitively determine model type for {model_data.get('model_name', 'unknown')}, defaulting to 'llm'" + ) + return "llm" + + async def _add_single_model_to_unified_json(self, model_type: str, model_data): + """ + Add a single model to the unified JSON file for the model type. + This follows the same storage pattern as update_model_type. 
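For orientation, the worker-side helpers above all persist under the same per-type builtin directory; a small sketch of the resulting layout (the concrete model name is only an illustration, not part of this change):

import os
from xinference.constants import XINFERENCE_MODEL_DIR

builtin_dir = os.path.join(XINFERENCE_MODEL_DIR, "v2", "builtin", "llm")

# update_model_type keeps the whole catalogue for a type in one unified file ...
unified_file = os.path.join(builtin_dir, "llm_models.json")
# ... while add_model writes one file per model, named after the model.
individual_file = os.path.join(builtin_dir, "example-llm.json")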
+ """ + import json + + from ..constants import XINFERENCE_MODEL_DIR + + try: + model_type_lower = model_type.lower() + + # Use the unified JSON file path (same as update_model_type logic) + builtin_dir = os.path.join( + XINFERENCE_MODEL_DIR, "v2", "builtin", model_type_lower + ) + json_file_path = os.path.join( + builtin_dir, f"{model_type_lower}_models.json" + ) + + # Ensure directory exists + os.makedirs(builtin_dir, exist_ok=True) + + # Read existing unified JSON file if it exists + existing_models = [] + if os.path.exists(json_file_path): + try: + with open(json_file_path, "r", encoding="utf-8") as f: + existing_data = json.load(f) + # Handle both array format and object format + if isinstance(existing_data, list): + existing_models = existing_data + else: + # If it's an object, try to extract models + existing_models = ( + list(existing_data.values()) if existing_data else [] + ) + except (json.JSONDecodeError, IOError) as e: + logger.warning(f"Error reading existing models file: {e}") + existing_models = [] + + # Check if model already exists in the unified file + model_name = model_data.get("model_name") + if model_name: + for i, existing_model in enumerate(existing_models): + if existing_model.get("model_name") == model_name: + logger.warning( + f"Model {model_name} already exists, updating..." + ) + existing_models[i] = model_data + break + else: + # Model doesn't exist, add it + existing_models.append(model_data) + + # Save the updated unified JSON file + with open(json_file_path, "w", encoding="utf-8") as f: + json.dump(existing_models, f, indent=2, ensure_ascii=False) + + logger.info( + f"Added model {model_name} to unified JSON file: {json_file_path}" + ) + + except Exception as e: + logger.error( + f"Error adding model to unified JSON: {str(e)}", + exc_info=True, + ) + raise ValueError(f"Failed to store model configuration: {str(e)}") + + @log_async(logger=logger) + async def update_model_type(self, model_type: str): + """ + Update model configurations for a specific model type by downloading + the latest JSON from the remote API and storing it locally. + + Args: + model_type: Type of model (LLM, embedding, image, etc.) + """ + import json + + import requests + + supported_types = list(self._custom_register_type_to_cls.keys()) + + normalized_for_validation = model_type + if model_type.lower() == "llm" and "LLM" in supported_types: + normalized_for_validation = "LLM" + elif model_type.lower() == "llm" and "llm" in supported_types: + normalized_for_validation = "llm" + + if normalized_for_validation not in supported_types: + logger.error(f"Unsupported model type: {normalized_for_validation}") + raise ValueError( + f"Unsupported model type '{model_type}'. 
" + f"Supported types are: {', '.join(supported_types)}" + ) + + # Construct the URL to download JSON + url = f"https://model.xinference.io/api/models/download?model_type={model_type.lower()}" + + try: + # Download JSON from remote API + response = requests.get(url, timeout=30) + response.raise_for_status() + + # Parse JSON response + model_data = response.json() + + # Store the JSON data using CacheManager as built-in models + await self._store_complete_model_configurations(model_type, model_data) + + # Dynamically reload built-in models to make them immediately available + try: + if model_type.lower() == "llm": + from ..model.llm import register_builtin_model + + register_builtin_model() + elif model_type.lower() == "embedding": + from ..model.embedding import register_builtin_model + + register_builtin_model() + elif model_type.lower() == "audio": + from ..model.audio import register_builtin_model + + register_builtin_model() + elif model_type.lower() == "image": + from ..model.image import register_builtin_model + + register_builtin_model() + elif model_type.lower() == "rerank": + from ..model.rerank import register_builtin_model + + register_builtin_model() + elif model_type.lower() == "video": + from ..model.video import register_builtin_model + + register_builtin_model() + else: + logger.warning( + f"No dynamic loading available for model type: {model_type}" + ) + except Exception as reload_error: + logger.error( + f"Error reloading built-in models: {reload_error}", + exc_info=True, + ) + # Don't fail the update if reload fails, just log the error + + except requests.exceptions.RequestException as e: + logger.error(f"Network error downloading model configurations: {e}") + raise ValueError(f"Failed to download model configurations: {str(e)}") + except json.JSONDecodeError as e: + logger.error(f"JSON decode error: {e}") + raise ValueError(f"Invalid JSON response from remote API: {str(e)}") + except Exception as e: + logger.error( + f"Unexpected error during model update: {e}", + exc_info=True, + ) + raise ValueError(f"Failed to update model configurations: {str(e)}") + + async def _store_model_configurations(self, model_type: str, model_data): + """ + Store model configurations as separate JSON files (one per model). + This follows the same pattern as CacheManager.register_builtin_model. 
+ + Args: + model_type: Type of model (as provided by user, e.g., "llm") + model_data: JSON data containing model configurations (can be single dict or list) + """ + import json + + from ..constants import XINFERENCE_MODEL_DIR + + try: + # Ensure model_data is a list for consistent processing + if isinstance(model_data, dict): + models_to_store = [model_data] + elif isinstance(model_data, list): + models_to_store = model_data + else: + raise ValueError(f"Invalid model_data type: {type(model_data)}") + + model_type_lower = model_type.lower() + builtin_dir = os.path.join( + XINFERENCE_MODEL_DIR, "v2", "builtin", model_type_lower + ) + + # Ensure directory exists + os.makedirs(builtin_dir, exist_ok=True) + + # Store each model as a separate JSON file + for model_dict in models_to_store: + if not isinstance(model_dict, dict): + logger.warning(f"Skipping invalid model data: {model_dict}") + continue + + model_name = model_dict.get("model_name") + if not model_name: + logger.warning(f"Skipping model without model_name: {model_dict}") + continue + + # Create file path using model name (same as CacheManager pattern) + json_file_path = os.path.join(builtin_dir, f"{model_name}.json") + + # Store the model as a separate JSON file + with open(json_file_path, "w", encoding="utf-8") as f: + json.dump(model_dict, f, indent=2, ensure_ascii=False) + + except Exception as e: + logger.error( + f"Error storing model configurations: {str(e)}", + exc_info=True, + ) + raise ValueError(f"Failed to store model configurations: {str(e)}") + + async def _store_complete_model_configurations(self, model_type: str, model_data): + """ + Store complete model configurations as a unified JSON file. + This is used by update_model_type to preserve the original JSON structure. + + Args: + model_type: Type of model (as provided by user, e.g., "llm") + model_data: JSON data containing model configurations (complete array) + """ + import json + + from ..constants import XINFERENCE_MODEL_DIR + + try: + model_type_lower = model_type.lower() + + # Use the unified JSON file path (same as original update_model_type logic) + builtin_dir = os.path.join( + XINFERENCE_MODEL_DIR, "v2", "builtin", model_type_lower + ) + json_file_path = os.path.join( + builtin_dir, f"{model_type_lower}_models.json" + ) + + # Ensure directory exists + os.makedirs(builtin_dir, exist_ok=True) + + # Store the complete JSON file (preserving original structure) + with open(json_file_path, "w", encoding="utf-8") as f: + json.dump(model_data, f, indent=2, ensure_ascii=False) + + except Exception as e: + logger.error( + f"Error storing complete model configurations: {str(e)}", + exc_info=True, + ) + raise ValueError(f"Failed to store complete model configurations: {str(e)}") + + async def _save_model_as_individual_file(self, model_type: str, model_data): + """ + Save a model as an individual JSON file named after the model. 
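The remote call that `update_model_type` performs above reduces to a single GET against the hub's download endpoint (`add_model` hits the same endpoint with `model_name` instead of `model_type`); a minimal sketch, with the type value chosen only as an example:

import requests

resp = requests.get(
    "https://model.xinference.io/api/models/download",
    params={"model_type": "llm"},
    timeout=30,
)
resp.raise_for_status()
model_configs = resp.json()  # the complete set of family definitions for that type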
+ + Args: + model_type: Type of model (as provided by user, e.g., "llm") + model_data: JSON data containing a single model configuration + """ + try: + model_name = model_data.get("model_name") + if not model_name: + raise ValueError("Model name not found in model data") + + await self._add_model_as_individual_file(model_type, model_name, model_data) + + except Exception as e: + logger.error( + f"Error saving model as individual file: {str(e)}", + exc_info=True, + ) + raise ValueError(f"Failed to save model as individual file: {str(e)}") + @log_async(logger=logger) async def list_model_registrations( self, model_type: str, detailed: bool = False @@ -661,41 +1281,41 @@ def sort_helper(item): return item.get("model_name").lower() if model_type == "LLM": - from ..model.llm import get_user_defined_llm_families + from ..model.llm import get_registered_llm_families ret = [] - for family in get_user_defined_llm_families(): + for family in get_registered_llm_families(): ret.append({"model_name": family.model_name, "is_builtin": False}) ret.sort(key=sort_helper) return ret elif model_type == "embedding": - from ..model.embedding.custom import get_user_defined_embeddings + from ..model.embedding.custom import get_registered_embeddings ret = [] - for model_spec in get_user_defined_embeddings(): + for model_spec in get_registered_embeddings(): ret.append({"model_name": model_spec.model_name, "is_builtin": False}) ret.sort(key=sort_helper) return ret elif model_type == "image": - from ..model.image.custom import get_user_defined_images + from ..model.image.custom import get_registered_images ret = [] - for model_spec in get_user_defined_images(): + for model_spec in get_registered_images(): ret.append({"model_name": model_spec.model_name, "is_builtin": False}) ret.sort(key=sort_helper) return ret elif model_type == "audio": - from ..model.audio.custom import get_user_defined_audios + from ..model.audio.custom import get_registered_audios ret = [] - for model_spec in get_user_defined_audios(): + for model_spec in get_registered_audios(): ret.append({"model_name": model_spec.model_name, "is_builtin": False}) ret.sort(key=sort_helper) @@ -703,11 +1323,11 @@ def sort_helper(item): elif model_type == "video": return [] elif model_type == "rerank": - from ..model.rerank.custom import get_user_defined_reranks + from ..model.rerank.custom import get_registered_reranks ret = [] - for model_spec in get_user_defined_reranks(): + for model_spec in get_registered_reranks(): ret.append({"model_name": model_spec.model_name, "is_builtin": False}) ret.sort(key=sort_helper) @@ -728,35 +1348,35 @@ def sort_helper(item): @log_sync(logger=logger) async def get_model_registration(self, model_type: str, model_name: str) -> Any: if model_type == "LLM": - from ..model.llm import get_user_defined_llm_families + from ..model.llm import get_registered_llm_families - for f in get_user_defined_llm_families(): + for f in get_registered_llm_families(): if f.model_name == model_name: return f elif model_type == "embedding": - from ..model.embedding.custom import get_user_defined_embeddings + from ..model.embedding.custom import get_registered_embeddings - for f in get_user_defined_embeddings(): + for f in get_registered_embeddings(): if f.model_name == model_name: return f elif model_type == "image": - from ..model.image.custom import get_user_defined_images + from ..model.image.custom import get_registered_images - for f in get_user_defined_images(): + for f in get_registered_images(): if f.model_name == model_name: return f elif 
model_type == "audio": - from ..model.audio.custom import get_user_defined_audios + from ..model.audio.custom import get_registered_audios - for f in get_user_defined_audios(): + for f in get_registered_audios(): if f.model_name == model_name: return f elif model_type == "video": return None elif model_type == "rerank": - from ..model.rerank.custom import get_user_defined_reranks + from ..model.rerank.custom import get_registered_reranks - for f in get_user_defined_reranks(): + for f in get_registered_reranks(): if f.model_name == model_name: return f return None diff --git a/xinference/model/audio/__init__.py b/xinference/model/audio/__init__.py index 9465771917..b13c7b73d4 100644 --- a/xinference/model/audio/__init__.py +++ b/xinference/model/audio/__init__.py @@ -14,14 +14,19 @@ import codecs import json +import logging import os import platform import sys import warnings -from typing import Dict, List +from typing import Any, Dict, List from ...constants import XINFERENCE_MODEL_DIR from ..utils import flatten_model_src + +logger = logging.getLogger(__name__) + + from .core import ( AUDIO_MODEL_DESCRIPTIONS, AudioModelFamilyV2, @@ -30,7 +35,7 @@ ) from .custom import ( CustomAudioModelFamilyV2, - get_user_defined_audios, + get_registered_audios, register_audio, unregister_audio, ) @@ -60,6 +65,37 @@ def register_custom_model(): warnings.warn(f"{user_defined_audio_dir}/{f} has error, {e}") +def register_builtin_model(): + # Use unified loading function with flatten_model_src + audio-specific defaults + from ..utils import flatten_model_src, load_complete_builtin_models + + def convert_audio_with_flatten(model_json): + flattened_list = flatten_model_src(model_json) + if not flattened_list: + return model_json + + result = flattened_list[0] + + # Add required defaults for audio models + if "multilingual" not in result: + result["multilingual"] = True + if "model_lang" not in result: + result["model_lang"] = ["en", "zh"] + if "version" not in result: + result["version"] = 2 + + return result + + loaded_count = load_complete_builtin_models( + model_type="audio", + builtin_registry=BUILTIN_AUDIO_MODELS, + convert_format_func=convert_audio_with_flatten, + model_class=AudioModelFamilyV2, + ) + + logger.info(f"Successfully loaded {loaded_count} audio models from complete JSON") + + def _need_filter(spec: dict): if (sys.platform != "darwin" or platform.processor() != "arm") and spec.get( "engine", "" @@ -80,7 +116,7 @@ def _install(): register_custom_model() # register model description - for ud_audio in get_user_defined_audios(): + for ud_audio in get_registered_audios(): AUDIO_MODEL_DESCRIPTIONS.update(generate_audio_description(ud_audio)) diff --git a/xinference/model/audio/cache_manager.py b/xinference/model/audio/cache_manager.py new file mode 100644 index 0000000000..6364751cf4 --- /dev/null +++ b/xinference/model/audio/cache_manager.py @@ -0,0 +1,40 @@ +from typing import TYPE_CHECKING + +from ..cache_manager import CacheManager + +if TYPE_CHECKING: + from .core import AudioModelFamilyV2 + + +class AudioCacheManager(CacheManager): + def __init__(self, model_family: "AudioModelFamilyV2"): + super().__init__(model_family) + + @staticmethod + def is_model_from_builtin_dir(model_name: str, model_type: str) -> bool: + """ + Check if a model comes from the builtin directory for audio models. 
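Each model package now gains a `register_builtin_model()` hook like the audio one above; it is what the worker invokes after writing new JSON so additions become visible without restarting. An illustrative call inside a running Xinference process:

from xinference.model import audio

# Re-scan the builtin audio JSON (under XINFERENCE_MODEL_DIR/v2/builtin/audio,
# per the loaders above) and register any models not yet known.
audio.register_builtin_model()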
+ """ + return CacheManager.is_model_from_builtin_dir(model_name, model_type) + + @staticmethod + def resolve_model_source( + model_name: str, model_type: str, builtin_model_names=None + ) -> str: + """ + Resolve the source of an audio model. + """ + return CacheManager.resolve_model_source( + model_name, model_type, builtin_model_names + ) + + @staticmethod + def is_builtin_model( + model_name: str, model_type: str, builtin_model_names=None + ) -> bool: + """ + Determine if an audio model should be considered builtin. + """ + return CacheManager.is_builtin_model( + model_name, model_type, builtin_model_names + ) diff --git a/xinference/model/audio/core.py b/xinference/model/audio/core.py index e2d147fa38..666bf45c1d 100644 --- a/xinference/model/audio/core.py +++ b/xinference/model/audio/core.py @@ -100,9 +100,9 @@ def match_audio( ) -> AudioModelFamilyV2: from ..utils import download_from_modelscope from . import BUILTIN_AUDIO_MODELS - from .custom import get_user_defined_audios + from .custom import get_registered_audios - for model_spec in get_user_defined_audios(): + for model_spec in get_registered_audios(): if model_spec.model_name == model_name: return model_spec diff --git a/xinference/model/audio/custom.py b/xinference/model/audio/custom.py index 8024078481..b38fc1378e 100644 --- a/xinference/model/audio/custom.py +++ b/xinference/model/audio/custom.py @@ -83,7 +83,11 @@ def __init__(self): self.builtin_models = list(BUILTIN_AUDIO_MODELS.keys()) -def get_user_defined_audios() -> List[CustomAudioModelFamilyV2]: +def get_registered_audios() -> List[CustomAudioModelFamilyV2]: + """ + Get all audio families registered in the registry (both user-defined and editor-defined). + This excludes hardcoded builtin models. + """ from ..custom import RegistryManager registry = RegistryManager.get_registry("audio") diff --git a/xinference/model/cache_manager.py b/xinference/model/cache_manager.py index ae9a9f1bfd..e0d444e91f 100644 --- a/xinference/model/cache_manager.py +++ b/xinference/model/cache_manager.py @@ -1,3 +1,4 @@ +import json import logging import os from typing import TYPE_CHECKING @@ -16,8 +17,12 @@ def __init__(self, model_family: "CacheableModelSpec"): self._model_family = model_family self._v2_cache_dir_prefix = os.path.join(XINFERENCE_CACHE_DIR, "v2") self._v2_custom_dir_prefix = os.path.join(XINFERENCE_MODEL_DIR, "v2") + self._v2_builtin_dir_prefix = os.path.join( + XINFERENCE_MODEL_DIR, "v2", "builtin" + ) os.makedirs(self._v2_cache_dir_prefix, exist_ok=True) os.makedirs(self._v2_custom_dir_prefix, exist_ok=True) + os.makedirs(self._v2_builtin_dir_prefix, exist_ok=True) self._cache_dir = os.path.join( self._v2_cache_dir_prefix, self._model_family.model_name.replace(".", "_") ) @@ -109,9 +114,21 @@ def cache(self) -> str: return self._cache() def register_custom_model(self, model_type: str): + model_type_dir = model_type.lower() persist_path = os.path.join( self._v2_custom_dir_prefix, - model_type, + model_type_dir, + f"{self._model_family.model_name}.json", + ) + os.makedirs(os.path.dirname(persist_path), exist_ok=True) + with open(persist_path, mode="w") as fd: + fd.write(self._model_family.json()) + + def register_builtin_model(self, model_type: str): + model_type_dir = model_type.lower() + persist_path = os.path.join( + self._v2_builtin_dir_prefix, + model_type_dir, f"{self._model_family.model_name}.json", ) os.makedirs(os.path.dirname(persist_path), exist_ok=True) @@ -119,9 +136,10 @@ def register_custom_model(self, model_type: str): fd.write(self._model_family.json()) 
def unregister_custom_model(self, model_type: str): + model_type_dir = model_type.lower() persist_path = os.path.join( self._v2_custom_dir_prefix, - model_type, + model_type_dir, f"{self._model_family.model_name}.json", ) if os.path.exists(persist_path): @@ -139,3 +157,108 @@ def unregister_custom_model(self, model_type: str): logger.warning( f"Cache directory is not a soft link, please remove it manually." ) + + @staticmethod + def is_model_from_builtin_dir(model_name: str, model_type: str) -> bool: + """ + Check if a model comes from the builtin directory (added via update_model_type) + or from the custom directory (true user custom models). + + This method encapsulates the logic for parsing builtin model metadata + and determining if a model was added through the update_model_type mechanism. + + Args: + model_name: Name of the model to check + model_type: Type of the model (e.g., "llm", "embedding", "image") + + Returns: + True if model is from builtin directory (added via update_model_type), + False if model is from custom directory or not found. + """ + from ..constants import XINFERENCE_MODEL_DIR + + # Check builtin directory (update_model_type models) + builtin_dir = os.path.join( + XINFERENCE_MODEL_DIR, "v2", "builtin", model_type.lower() + ) + builtin_file = os.path.join(builtin_dir, f"{model_name}.json") + + if os.path.exists(builtin_file): + return True + + # Also check unified JSON file for models added via update_model_type + unified_json = os.path.join(builtin_dir, f"{model_type.lower()}_models.json") + if os.path.exists(unified_json): + try: + with open(unified_json, "r", encoding="utf-8") as f: + data = json.load(f) + + # Check if model_name exists in this JSON file + if isinstance(data, list): + return any(model.get("model_name") == model_name for model in data) + elif isinstance(data, dict): + if data.get("model_name") == model_name: + return True + else: + # Check dict values + return any( + isinstance(value, dict) + and value.get("model_name") == model_name + for value in data.values() + ) + except Exception: + # If JSON parsing fails, assume model is not from builtin dir + pass + + return False + + @staticmethod + def resolve_model_source( + model_name: str, model_type: str, builtin_model_names=None + ) -> str: + """ + Resolve the source of a model with unified logic. + + This method provides a single source of truth for determining model sources, + replacing scattered logic across the codebase. + + Args: + model_name: Name of the model to check + model_type: Type of the model (e.g., "llm", "embedding", "image") + builtin_model_names: Set of hardcoded builtin model names (for checking builtin status) + + Returns: + "builtin" - Hardcoded builtin models (in code) + "editor_builtin" - Models added via update_model_type mechanism + "user" - True user-defined models (custom directory) + """ + # 1. Check if it's a hardcoded builtin model + if builtin_model_names and model_name in builtin_model_names: + return "builtin" + + # 2. Check if it's an editor-defined model (via update_model_type) + if CacheManager.is_model_from_builtin_dir(model_name, model_type): + return "editor_builtin" + + # 3. Otherwise it's a user-defined model + return "user" + + @staticmethod + def is_builtin_model( + model_name: str, model_type: str, builtin_model_names=None + ) -> bool: + """ + Determine if a model should be considered builtin. 
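A short, illustrative use of the new source-resolution helpers on `CacheManager` (the model names below are placeholders, not taken from this change):

from xinference.model.cache_manager import CacheManager

hardcoded = {"some-builtin-llm"}  # names shipped in code; placeholder set

source = CacheManager.resolve_model_source("some-builtin-llm", "llm", hardcoded)
# "builtin"        -> defined in the codebase
# "editor_builtin" -> JSON added via add_model / update_model_type
# "user"           -> registered under the user's custom directory

assert CacheManager.is_builtin_model("some-builtin-llm", "llm", hardcoded)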
+ + Args: + model_name: Name of the model to check + model_type: Type of the model + builtin_model_names: Set of hardcoded builtin model names + + Returns: + True if model is builtin (hardcoded or editor-defined), False if user-defined + """ + source = CacheManager.resolve_model_source( + model_name, model_type, builtin_model_names + ) + return source != "user" diff --git a/xinference/model/custom.py b/xinference/model/custom.py index f08a09dfea..a1adee9aea 100644 --- a/xinference/model/custom.py +++ b/xinference/model/custom.py @@ -118,6 +118,7 @@ def get_registry(cls, model_type: str) -> ModelRegistry: from .image.custom import ImageModelRegistry from .llm.custom import LLMModelRegistry from .rerank.custom import RerankModelRegistry + from .video.custom import VideoModelRegistry if model_type not in cls._instances: if model_type == "rerank": @@ -126,6 +127,8 @@ def get_registry(cls, model_type: str) -> ModelRegistry: cls._instances[model_type] = ImageModelRegistry() elif model_type == "audio": cls._instances[model_type] = AudioModelRegistry() + elif model_type == "video": + cls._instances[model_type] = VideoModelRegistry() elif model_type == "llm": cls._instances[model_type] = LLMModelRegistry() elif model_type == "flexible": diff --git a/xinference/model/embedding/__init__.py b/xinference/model/embedding/__init__.py index f1e822e112..05ff82c6f9 100644 --- a/xinference/model/embedding/__init__.py +++ b/xinference/model/embedding/__init__.py @@ -14,11 +14,16 @@ import codecs import json +import logging import os import warnings from typing import Any, Dict, List from ..utils import flatten_quantizations + +logger = logging.getLogger(__name__) + + from .core import ( EMBEDDING_MODEL_DESCRIPTIONS, EmbeddingModelFamilyV2, @@ -27,7 +32,7 @@ ) from .custom import ( CustomEmbeddingModelFamilyV2, - get_user_defined_embeddings, + get_registered_embeddings, register_embedding, unregister_embedding, ) @@ -64,6 +69,100 @@ def register_custom_model(): warnings.warn(f"{user_defined_embedding_dir}/{f} has error, {e}") +def register_builtin_model(): + # Use unified loading function with flatten_quantizations for embedding models + from ..custom import RegistryManager + from ..utils import flatten_quantizations, load_complete_builtin_models + from .embed_family import BUILTIN_EMBEDDING_MODELS + + registry = RegistryManager.get_registry("embedding") + existing_model_names = {spec.model_name for spec in registry.get_custom_models()} + + def convert_embedding_with_quantizations(model_json): + if "model_specs" not in model_json: + return model_json + + # Process each model_spec with flatten_quantizations (like builtin embedding loading) + result = model_json.copy() + flattened_specs = [] + for spec in result["model_specs"]: + if "model_src" in spec: + flattened_specs.extend(flatten_quantizations(spec)) + else: + flattened_specs.append(spec) + result["model_specs"] = flattened_specs + + return result + + loaded_count = load_complete_builtin_models( + model_type="embedding", + builtin_registry=BUILTIN_EMBEDDING_MODELS, # Use actual registry + convert_format_func=convert_embedding_with_quantizations, + model_class=EmbeddingModelFamilyV2, + ) + + # Manually handle embedding's special registration logic + if loaded_count > 0: + from ...constants import XINFERENCE_MODEL_DIR + from ..custom import RegistryManager + + registry = RegistryManager.get_registry("embedding") + existing_model_names = { + spec.model_name for spec in registry.get_custom_models() + } + + builtin_embedding_dir = os.path.join( + 
XINFERENCE_MODEL_DIR, "v2", "builtin", "embedding" + ) + complete_json_path = os.path.join( + builtin_embedding_dir, "embedding_models.json" + ) + + if os.path.exists(complete_json_path): + with codecs.open(complete_json_path, encoding="utf-8") as fd: + model_data = json.load(fd) + + models_to_register = [] + if isinstance(model_data, list): + models_to_register = model_data + elif isinstance(model_data, dict): + if "model_name" in model_data: + models_to_register = [model_data] + else: + for key, value in model_data.items(): + if isinstance(value, dict) and "model_name" in value: + models_to_register.append(value) + + for model_data in models_to_register: + try: + from ..utils import flatten_quantizations + + converted_data = model_data.copy() + if "model_specs" in converted_data: + flattened_specs = [] + for spec in converted_data["model_specs"]: + if "model_src" in spec: + flattened_specs.extend(flatten_quantizations(spec)) + else: + flattened_specs.append(spec) + converted_data["model_specs"] = flattened_specs + builtin_embedding_family = EmbeddingModelFamilyV2.parse_obj( + converted_data + ) + + if builtin_embedding_family.model_name not in existing_model_names: + register_embedding(builtin_embedding_family, persist=False) + existing_model_names.add(builtin_embedding_family.model_name) + except Exception as e: + warnings.warn( + f"Error parsing model {model_data.get('model_name', 'Unknown')}: {e}" + ) + + logger.info( + f"Successfully loaded {loaded_count} embedding models from complete JSON" + ) + + def check_format_with_engine(model_format, engine): if model_format in ["ggufv2"] and engine not in ["llama.cpp"]: return False @@ -151,7 +250,7 @@ def _install(): register_custom_model() # register model description - for ud_embedding in get_user_defined_embeddings(): + for ud_embedding in get_registered_embeddings(): EMBEDDING_MODEL_DESCRIPTIONS.update( generate_embedding_description(ud_embedding) ) diff --git a/xinference/model/embedding/cache_manager.py b/xinference/model/embedding/cache_manager.py index 306488716d..f2be81f54d 100644 --- a/xinference/model/embedding/cache_manager.py +++ b/xinference/model/embedding/cache_manager.py @@ -33,3 +33,32 @@ def cache(self) -> str: return self.cache_helper.cache_from_modelscope() else: raise ValueError(f"Unknown model hub: {spec.model_hub}") + + @staticmethod + def is_model_from_builtin_dir(model_name: str, model_type: str) -> bool: + """ + Check if a model comes from the builtin directory for embedding models. + """ + return CacheManager.is_model_from_builtin_dir(model_name, model_type) + + @staticmethod + def resolve_model_source( + model_name: str, model_type: str, builtin_model_names=None + ) -> str: + """ + Resolve the source of an embedding model. + """ + return CacheManager.resolve_model_source( + model_name, model_type, builtin_model_names + ) + + @staticmethod + def is_builtin_model( + model_name: str, model_type: str, builtin_model_names=None + ) -> bool: + """ + Determine if an embedding model should be considered builtin. 
+ """ + return CacheManager.is_builtin_model( + model_name, model_type, builtin_model_names + ) diff --git a/xinference/model/embedding/custom.py b/xinference/model/embedding/custom.py index 180d2f690a..2e889d5e0a 100644 --- a/xinference/model/embedding/custom.py +++ b/xinference/model/embedding/custom.py @@ -69,7 +69,11 @@ def remove_ud_model_files(self, model_family: "CustomEmbeddingModelFamilyV2"): cache_manager.unregister_custom_model(self.model_type) -def get_user_defined_embeddings() -> List[EmbeddingModelFamilyV2]: +def get_registered_embeddings() -> List[EmbeddingModelFamilyV2]: + """ + Get all embedding families registered in the registry (both user-defined and editor-defined). + This excludes hardcoded builtin models. + """ from ..custom import RegistryManager registry = RegistryManager.get_registry("embedding") diff --git a/xinference/model/embedding/embed_family.py b/xinference/model/embedding/embed_family.py index a572d7cb68..60c2682792 100644 --- a/xinference/model/embedding/embed_family.py +++ b/xinference/model/embedding/embed_family.py @@ -37,14 +37,14 @@ def match_embedding( ] = None, ) -> "EmbeddingModelFamilyV2": from ..utils import download_from_modelscope - from .custom import get_user_defined_embeddings + from .custom import get_registered_embeddings target_family = None if model_name in BUILTIN_EMBEDDING_MODELS: target_family = BUILTIN_EMBEDDING_MODELS[model_name] else: - for model_family in get_user_defined_embeddings(): + for model_family in get_registered_embeddings(): if model_name == model_family.model_name: target_family = model_family break diff --git a/xinference/model/image/__init__.py b/xinference/model/image/__init__.py index 14230ea41c..31e76bbfcc 100644 --- a/xinference/model/image/__init__.py +++ b/xinference/model/image/__init__.py @@ -14,10 +14,16 @@ import codecs import json +import logging import os import warnings +from typing import Any, Dict from ..utils import flatten_model_src + +logger = logging.getLogger(__name__) + + from .core import ( BUILTIN_IMAGE_MODELS, IMAGE_MODEL_DESCRIPTIONS, @@ -27,7 +33,7 @@ ) from .custom import ( CustomImageModelFamilyV2, - get_user_defined_images, + get_registered_images, register_image, unregister_image, ) @@ -55,9 +61,128 @@ def register_custom_model(): warnings.warn(f"{user_defined_image_dir}/{f} has error, {e}") +def register_builtin_model(): + import json + + from ...constants import XINFERENCE_MODEL_DIR + from ..custom import RegistryManager + + registry = RegistryManager.get_registry("image") + existing_model_names = {spec.model_name for spec in registry.get_custom_models()} + + builtin_image_dir = os.path.join(XINFERENCE_MODEL_DIR, "v2", "builtin", "image") + if os.path.isdir(builtin_image_dir): + # First, try to load from the complete JSON file + complete_json_path = os.path.join(builtin_image_dir, "image_models.json") + if os.path.exists(complete_json_path): + try: + with codecs.open(complete_json_path, encoding="utf-8") as fd: + model_data = json.load(fd) + + # Handle different formats + models_to_register = [] + if isinstance(model_data, list): + # Multiple models in a list + models_to_register = model_data + elif isinstance(model_data, dict): + # Single model + if "model_name" in model_data: + models_to_register = [model_data] + else: + # Models dict - extract models + for key, value in model_data.items(): + if isinstance(value, dict) and "model_name" in value: + models_to_register.append(value) + + # Register all models from the complete JSON + for model_data in models_to_register: + try: + # 
Convert format using flatten_model_src + from ..utils import flatten_model_src + + flattened_list = flatten_model_src(model_data) + converted_data = ( + flattened_list[0] if flattened_list else model_data + ) + builtin_image_family = ImageModelFamilyV2.parse_obj( + converted_data + ) + + # Only register if model doesn't already exist + if builtin_image_family.model_name not in existing_model_names: + # Add to BUILTIN_IMAGE_MODELS directly for proper builtin registration + if ( + builtin_image_family.model_name + not in BUILTIN_IMAGE_MODELS + ): + BUILTIN_IMAGE_MODELS[ + builtin_image_family.model_name + ] = [] + BUILTIN_IMAGE_MODELS[ + builtin_image_family.model_name + ].append(builtin_image_family) + # Update model descriptions for the new builtin model + IMAGE_MODEL_DESCRIPTIONS.update( + generate_image_description(builtin_image_family) + ) + existing_model_names.add(builtin_image_family.model_name) + except Exception as e: + warnings.warn( + f"Error parsing image model {model_data.get('model_name', 'Unknown')}: {e}" + ) + + logger.info( + f"Successfully registered {len(models_to_register)} image models from complete JSON" + ) + + except Exception as e: + warnings.warn( + f"Error loading complete JSON file {complete_json_path}: {e}" + ) + # Fall back to individual files if complete JSON loading fails + + # Fall back: load individual JSON files (backward compatibility) + individual_files = [ + f + for f in os.listdir(builtin_image_dir) + if f.endswith(".json") and f != "image_models.json" + ] + for f in individual_files: + try: + with codecs.open( + os.path.join(builtin_image_dir, f), encoding="utf-8" + ) as fd: + model_data = json.load(fd) + # Apply flatten_model_src to individual files + from ..utils import flatten_model_src + + flattened_list = flatten_model_src(model_data) + converted_data = flattened_list[0] if flattened_list else model_data + builtin_image_family = ImageModelFamilyV2.parse_obj(converted_data) + + # Only register if model doesn't already exist + if builtin_image_family.model_name not in existing_model_names: + # Add to BUILTIN_IMAGE_MODELS directly for proper builtin registration + if builtin_image_family.model_name not in BUILTIN_IMAGE_MODELS: + BUILTIN_IMAGE_MODELS[builtin_image_family.model_name] = [] + BUILTIN_IMAGE_MODELS[builtin_image_family.model_name].append( + builtin_image_family + ) + # Update model descriptions for the new builtin model + IMAGE_MODEL_DESCRIPTIONS.update( + generate_image_description(builtin_image_family) + ) + existing_model_names.add(builtin_image_family.model_name) + except Exception as e: + warnings.warn(f"{builtin_image_dir}/{f} has error, {e}") + + def _install(): load_model_family_from_json("model_spec.json", BUILTIN_IMAGE_MODELS) + # Load models from complete JSON file (from update_model_type) + register_builtin_model() + # register model description for model_name, model_specs in BUILTIN_IMAGE_MODELS.items(): model_spec = [x for x in model_specs if x.model_hub == "huggingface"][0] @@ -65,7 +190,7 @@ def _install(): register_custom_model() - for ud_image in get_user_defined_images(): + for ud_image in get_registered_images(): IMAGE_MODEL_DESCRIPTIONS.update(generate_image_description(ud_image)) diff --git a/xinference/model/image/cache_manager.py b/xinference/model/image/cache_manager.py index 37a89519da..57d315ee98 100644 --- a/xinference/model/image/cache_manager.py +++ b/xinference/model/image/cache_manager.py @@ -116,3 +116,32 @@ def cache_lightning(self, lightning_version: Optional[str] = None): raise NotImplementedError 
return full_path + + @staticmethod + def is_model_from_builtin_dir(model_name: str, model_type: str) -> bool: + """ + Check if a model comes from the builtin directory for image models. + """ + return CacheManager.is_model_from_builtin_dir(model_name, model_type) + + @staticmethod + def resolve_model_source( + model_name: str, model_type: str, builtin_model_names=None + ) -> str: + """ + Resolve the source of an image model. + """ + return CacheManager.resolve_model_source( + model_name, model_type, builtin_model_names + ) + + @staticmethod + def is_builtin_model( + model_name: str, model_type: str, builtin_model_names=None + ) -> bool: + """ + Determine if an image model should be considered builtin. + """ + return CacheManager.is_builtin_model( + model_name, model_type, builtin_model_names + ) diff --git a/xinference/model/image/core.py b/xinference/model/image/core.py index b4baa09bcd..46be16945f 100644 --- a/xinference/model/image/core.py +++ b/xinference/model/image/core.py @@ -121,9 +121,9 @@ def match_diffusion( ) -> ImageModelFamilyV2: from ..utils import download_from_modelscope from . import BUILTIN_IMAGE_MODELS - from .custom import get_user_defined_images + from .custom import get_registered_images - for model_spec in get_user_defined_images(): + for model_spec in get_registered_images(): if model_spec.model_name == model_name: return model_spec diff --git a/xinference/model/image/custom.py b/xinference/model/image/custom.py index 3e3e2a81b9..a8c75433b4 100644 --- a/xinference/model/image/custom.py +++ b/xinference/model/image/custom.py @@ -43,7 +43,11 @@ def __init__(self): self.builtin_models = list(BUILTIN_IMAGE_MODELS.keys()) -def get_user_defined_images() -> List[ImageModelFamilyV2]: +def get_registered_images() -> List[ImageModelFamilyV2]: + """ + Get all image families registered in the registry (both user-defined and editor-defined). + This excludes hardcoded builtin models. + """ from ..custom import RegistryManager registry = RegistryManager.get_registry("image") diff --git a/xinference/model/llm/__init__.py b/xinference/model/llm/__init__.py index a4c4704ce4..342bf2c533 100644 --- a/xinference/model/llm/__init__.py +++ b/xinference/model/llm/__init__.py @@ -13,17 +13,23 @@ # limitations under the License. 
import codecs import json +import logging import os import warnings +from typing import Any, Dict from ..utils import flatten_quantizations + +logger = logging.getLogger(__name__) + + from .core import ( LLM, LLM_VERSION_INFOS, generate_llm_version_info, get_llm_version_infos, ) -from .custom import get_user_defined_llm_families, register_llm, unregister_llm +from .custom import get_registered_llm_families, register_llm, unregister_llm from .llm_family import ( BUILTIN_LLM_FAMILIES, BUILTIN_LLM_MODEL_CHAT_FAMILIES, @@ -128,6 +134,121 @@ def register_custom_model(): warnings.warn(f"{user_defined_llm_dir}/{f} has error, {e}") +def register_builtin_model(): + # Use unified loading function with flatten_quantizations for LLM + from ..utils import flatten_quantizations, load_complete_builtin_models + from .llm_family import BUILTIN_LLM_FAMILIES + + def convert_llm_with_quantizations(model_json): + if "model_specs" not in model_json: + return model_json + + # Process each model_spec with flatten_quantizations (like builtin LLM loading) + result = model_json.copy() + flattened_specs = [] + for spec in result["model_specs"]: + if "model_src" in spec: + flattened_specs.extend(flatten_quantizations(spec)) + else: + flattened_specs.append(spec) + result["model_specs"] = flattened_specs + + return result + + loaded_count = load_complete_builtin_models( + model_type="llm", + builtin_registry={}, # Temporarily use empty dict, we handle it manually + convert_format_func=convert_llm_with_quantizations, + model_class=LLMFamilyV2, + ) + + # Manually handle LLM's special registration logic + if loaded_count > 0: + from ...constants import XINFERENCE_MODEL_DIR + from ..custom import RegistryManager + + registry = RegistryManager.get_registry("llm") + existing_model_names = { + spec.model_name for spec in registry.get_custom_models() + } + + builtin_llm_dir = os.path.join(XINFERENCE_MODEL_DIR, "v2", "builtin", "llm") + complete_json_path = os.path.join(builtin_llm_dir, "llm_models.json") + + if os.path.exists(complete_json_path): + with codecs.open(complete_json_path, encoding="utf-8") as fd: + model_data = json.load(fd) + + models_to_register = [] + if isinstance(model_data, list): + models_to_register = model_data + elif isinstance(model_data, dict): + if "model_name" in model_data: + models_to_register = [model_data] + else: + for key, value in model_data.items(): + if isinstance(value, dict) and "model_name" in value: + models_to_register.append(value) + + for model_data in models_to_register: + try: + from ..utils import flatten_quantizations + + converted_data = model_data.copy() + if "model_specs" in converted_data: + flattened_specs = [] + for spec in converted_data["model_specs"]: + if "model_src" in spec: + flattened_specs.extend(flatten_quantizations(spec)) + else: + flattened_specs.append(spec) + converted_data["model_specs"] = flattened_specs + builtin_llm_family = LLMFamilyV2.parse_obj(converted_data) + + if builtin_llm_family.model_name not in existing_model_names: + register_llm(builtin_llm_family, persist=False) + existing_model_names.add(builtin_llm_family.model_name) + except Exception as e: + warnings.warn( + f"Error parsing model {model_data.get('model_name', 'Unknown')}: {e}" + ) + + # Also load individual JSON files (for models added via add_model) + if os.path.isdir(builtin_llm_dir): + individual_files = [ + f + for f in os.listdir(builtin_llm_dir) + if f.endswith(".json") and f != "llm_models.json" + ] + for f in individual_files: + try: + with codecs.open( + 
os.path.join(builtin_llm_dir, f), encoding="utf-8" + ) as fd: + model_data = json.load(fd) + + # Apply flatten_quantizations to individual files + from ..utils import flatten_quantizations + + converted_data = model_data.copy() + if "model_specs" in converted_data: + flattened_specs = [] + for spec in converted_data["model_specs"]: + if "model_src" in spec: + flattened_specs.extend(flatten_quantizations(spec)) + else: + flattened_specs.append(spec) + converted_data["model_specs"] = flattened_specs + + builtin_llm_family = LLMFamilyV2.parse_obj(converted_data) + + if builtin_llm_family.model_name not in existing_model_names: + register_llm(builtin_llm_family, persist=False) + existing_model_names.add(builtin_llm_family.model_name) + except Exception as e: + warnings.warn(f"Error parsing LLM model {f}: {e}") + + def load_model_family_from_json(json_filename, target_families): json_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), json_filename) for json_obj in json.load(codecs.open(json_path, "r", encoding="utf-8")): @@ -210,5 +331,5 @@ def _install(): register_custom_model() # register model description - for ud_llm in get_user_defined_llm_families(): + for ud_llm in get_registered_llm_families(): LLM_VERSION_INFOS.update(generate_llm_version_info(ud_llm)) diff --git a/xinference/model/llm/cache_manager.py b/xinference/model/llm/cache_manager.py index 665c121ba1..2653e91749 100644 --- a/xinference/model/llm/cache_manager.py +++ b/xinference/model/llm/cache_manager.py @@ -304,3 +304,32 @@ def cache(self) -> str: return self.cache_from_csghub() else: raise ValueError(f"Unknown model hub: {self._model_hub}") + + @staticmethod + def is_model_from_builtin_dir(model_name: str, model_type: str) -> bool: + """ + Check if a model comes from the builtin directory for LLM models. + """ + return CacheManager.is_model_from_builtin_dir(model_name, model_type) + + @staticmethod + def resolve_model_source( + model_name: str, model_type: str, builtin_model_names=None + ) -> str: + """ + Resolve the source of an LLM model. + """ + return CacheManager.resolve_model_source( + model_name, model_type, builtin_model_names + ) + + @staticmethod + def is_builtin_model( + model_name: str, model_type: str, builtin_model_names=None + ) -> bool: + """ + Determine if an LLM model should be considered builtin. + """ + return CacheManager.is_builtin_model( + model_name, model_type, builtin_model_names + ) diff --git a/xinference/model/llm/custom.py b/xinference/model/llm/custom.py index 65cf8f8afd..8d96a341eb 100644 --- a/xinference/model/llm/custom.py +++ b/xinference/model/llm/custom.py @@ -67,7 +67,11 @@ def remove_ud_model_files(self, llm_family: "LLMFamilyV2"): cache_manager.unregister_custom_model(self.model_type) -def get_user_defined_llm_families(): +def get_registered_llm_families(): + """ + Get all LLM families registered in the registry (both user-defined and editor-defined). + This excludes hardcoded builtin models. + """ from ..custom import RegistryManager registry = RegistryManager.get_registry("llm") diff --git a/xinference/model/llm/llm_family.py b/xinference/model/llm/llm_family.py index 628c59e98b..0e58dc9269 100644 --- a/xinference/model/llm/llm_family.py +++ b/xinference/model/llm/llm_family.py @@ -479,9 +479,9 @@ def match_llm( """ Find an LLM family, spec, and quantization that satisfy given criteria. 
""" - from .custom import get_user_defined_llm_families + from .custom import get_registered_llm_families - user_defined_llm_families = get_user_defined_llm_families() + user_defined_llm_families = get_registered_llm_families() def _match_quantization(q: Union[str, None], quant: str): # Currently, the quantization name could include both uppercase and lowercase letters, diff --git a/xinference/model/llm/tests/test_llm_family.py b/xinference/model/llm/tests/test_llm_family.py index 4c67989755..4437004e74 100644 --- a/xinference/model/llm/tests/test_llm_family.py +++ b/xinference/model/llm/tests/test_llm_family.py @@ -207,7 +207,11 @@ def test_cache_from_uri_local(): def test_custom_llm(): - from ..custom import get_user_defined_llm_families, register_llm, unregister_llm + from ..custom import get_registered_llm_families as get_user_defined_llm_families + from ..custom import ( + register_llm, + unregister_llm, + ) spec = LlamaCppLLMSpecV2( model_format="ggufv2", @@ -239,7 +243,11 @@ def test_custom_llm(): def test_persistent_custom_llm(): from ....constants import XINFERENCE_MODEL_DIR - from ..custom import get_user_defined_llm_families, register_llm, unregister_llm + from ..custom import get_registered_llm_families as get_user_defined_llm_families + from ..custom import ( + register_llm, + unregister_llm, + ) spec = LlamaCppLLMSpecV2( model_format="ggufv2", @@ -663,7 +671,11 @@ def test_quert_engine_SGLang(): def test_query_engine_general(): - from ..custom import get_user_defined_llm_families, register_llm, unregister_llm + from ..custom import get_registered_llm_families as get_user_defined_llm_families + from ..custom import ( + register_llm, + unregister_llm, + ) from ..llama_cpp.core import XllamaCppModel from ..llm_family import LLM_ENGINES, check_engine_by_spec_parameters diff --git a/xinference/model/rerank/__init__.py b/xinference/model/rerank/__init__.py index 36334cb9fc..9f23f0e594 100644 --- a/xinference/model/rerank/__init__.py +++ b/xinference/model/rerank/__init__.py @@ -14,12 +14,17 @@ import codecs import json +import logging import os import warnings from typing import Any, Dict, List from ...constants import XINFERENCE_MODEL_DIR from ..utils import flatten_quantizations + +logger = logging.getLogger(__name__) + + from .core import ( RERANK_MODEL_DESCRIPTIONS, RerankModelFamilyV2, @@ -28,7 +33,7 @@ ) from .custom import ( CustomRerankModelFamilyV2, - get_user_defined_reranks, + get_registered_reranks, register_rerank, unregister_rerank, ) @@ -63,6 +68,85 @@ def register_custom_model(): warnings.warn(f"{user_defined_rerank_dir}/{f} has error, {e}") +def register_builtin_model(): + # Use unified loading function with flatten_quantizations for rerank models + from ..utils import flatten_quantizations, load_complete_builtin_models + from .rerank_family import BUILTIN_RERANK_MODELS + + def convert_rerank_with_quantizations(model_json): + if "model_specs" not in model_json: + return model_json + + # Process each model_spec with flatten_quantizations (like builtin rerank loading) + result = model_json.copy() + flattened_specs = [] + for spec in result["model_specs"]: + if "model_src" in spec: + flattened_specs.extend(flatten_quantizations(spec)) + else: + flattened_specs.append(spec) + result["model_specs"] = flattened_specs + + return result + + loaded_count = load_complete_builtin_models( + model_type="rerank", + builtin_registry=BUILTIN_RERANK_MODELS, # Use actual registry + convert_format_func=convert_rerank_with_quantizations, + model_class=RerankModelFamilyV2, + ) + + # 
Manually handle rerank's special registration logic + if loaded_count > 0: + builtin_rerank_dir = os.path.join( + XINFERENCE_MODEL_DIR, "v2", "builtin", "rerank" + ) + complete_json_path = os.path.join(builtin_rerank_dir, "rerank_models.json") + + if os.path.exists(complete_json_path): + with codecs.open(complete_json_path, encoding="utf-8") as fd: + model_data = json.load(fd) + + models_to_register = [] + if isinstance(model_data, list): + models_to_register = model_data + elif isinstance(model_data, dict): + if "model_name" in model_data: + models_to_register = [model_data] + else: + for key, value in model_data.items(): + if isinstance(value, dict) and "model_name" in value: + models_to_register.append(value) + + for model_data in models_to_register: + try: + from ..utils import flatten_quantizations + + converted_data = model_data.copy() + if "model_specs" in converted_data: + flattened_specs = [] + for spec in converted_data["model_specs"]: + if "model_src" in spec: + flattened_specs.extend(flatten_quantizations(spec)) + else: + flattened_specs.append(spec) + converted_data["model_specs"] = flattened_specs + builtin_rerank_family = RerankModelFamilyV2.parse_obj( + converted_data + ) + + if builtin_rerank_family.model_name not in BUILTIN_RERANK_MODELS: + BUILTIN_RERANK_MODELS[builtin_rerank_family.model_name] = ( + builtin_rerank_family + ) + except Exception as e: + warnings.warn( + f"Error parsing model {model_data.get('model_name', 'Unknown')}: {e}" + ) + + logger.info(f"Successfully loaded {loaded_count} rerank models from complete JSON") + + def generate_engine_config_by_model_name(model_family: "RerankModelFamilyV2"): model_name = model_family.model_name engines: Dict[str, List[Dict[str, Any]]] = RERANK_ENGINES.get( @@ -127,5 +211,5 @@ def _install(): register_custom_model() # register model description - for ud_rerank in get_user_defined_reranks(): + for ud_rerank in get_registered_reranks(): RERANK_MODEL_DESCRIPTIONS.update(generate_rerank_description(ud_rerank)) diff --git a/xinference/model/rerank/cache_manager.py b/xinference/model/rerank/cache_manager.py index 5bf63c8f62..0d39666987 100644 --- a/xinference/model/rerank/cache_manager.py +++ b/xinference/model/rerank/cache_manager.py @@ -33,3 +33,32 @@ def cache(self) -> str: return self.cache_helper.cache_from_modelscope() else: raise ValueError(f"Unknown model hub: {spec.model_hub}") + + @staticmethod + def is_model_from_builtin_dir(model_name: str, model_type: str) -> bool: + """ + Check if a model comes from the builtin directory for rerank models. + """ + return CacheManager.is_model_from_builtin_dir(model_name, model_type) + + @staticmethod + def resolve_model_source( + model_name: str, model_type: str, builtin_model_names=None + ) -> str: + """ + Resolve the source of a rerank model. + """ + return CacheManager.resolve_model_source( + model_name, model_type, builtin_model_names + ) + + @staticmethod + def is_builtin_model( + model_name: str, model_type: str, builtin_model_names=None + ) -> bool: + """ + Determine if a rerank model should be considered builtin. 
+ """ + return CacheManager.is_builtin_model( + model_name, model_type, builtin_model_names + ) diff --git a/xinference/model/rerank/custom.py b/xinference/model/rerank/custom.py index c09fdd40be..1e22dfaf54 100644 --- a/xinference/model/rerank/custom.py +++ b/xinference/model/rerank/custom.py @@ -67,7 +67,11 @@ def remove_ud_model_files(self, model_family: "CustomRerankModelFamilyV2"): cache_manager.unregister_custom_model(self.model_type) -def get_user_defined_reranks() -> List[CustomRerankModelFamilyV2]: +def get_registered_reranks() -> List[CustomRerankModelFamilyV2]: + """ + Get all rerank families registered in the registry (both user-defined and editor-defined). + This excludes hardcoded builtin models. + """ from ..custom import RegistryManager registry = RegistryManager.get_registry("rerank") diff --git a/xinference/model/rerank/rerank_family.py b/xinference/model/rerank/rerank_family.py index 62639d06cf..1cbcc681d9 100644 --- a/xinference/model/rerank/rerank_family.py +++ b/xinference/model/rerank/rerank_family.py @@ -36,14 +36,14 @@ def match_rerank( ] = None, ) -> "RerankModelFamilyV2": from ..utils import download_from_modelscope - from .custom import get_user_defined_reranks + from .custom import get_registered_reranks target_family = None if model_name in BUILTIN_RERANK_MODELS: target_family = BUILTIN_RERANK_MODELS[model_name] else: - for model_family in get_user_defined_reranks(): + for model_family in get_registered_reranks(): if model_name == model_family.model_name: target_family = model_family break diff --git a/xinference/model/utils.py b/xinference/model/utils.py index ea5dec74d5..87844c3189 100644 --- a/xinference/model/utils.py +++ b/xinference/model/utils.py @@ -603,6 +603,21 @@ def flatten_quantizations(input_json: dict): if key != "quantizations": record[key] = value + # Add required defaults for ggufv2 format if model_file_name_template is missing + if "model_format" in record and record["model_format"] == "ggufv2": + if "model_file_name_template" not in record: + # Generate default template from model_id + model_id = record.get("model_id", "") + if model_id: + # Extract model name from model_id (last part after /) + model_name = model_id.split("/")[-1] + # Remove potential suffixes + if "-GGUF" in model_name: + model_name = model_name.replace("-GGUF", "") + record["model_file_name_template"] = ( + f"{model_name.lower()}-{{quantization}}.gguf" + ) + flattened.append(record) return flattened @@ -709,3 +724,299 @@ def _wrapper(self, *args, **kwargs): return _async_wrapper else: return _wrapper + + +def load_complete_builtin_models( + model_type: str, builtin_registry: dict, convert_format_func=None, model_class=None +): + """ + Load complete JSON files for built-in models in a unified way. + This function loads both the traditional unified JSON file and individual model files. 
+ + Args: + model_type: Model type (llm, embedding, audio, image, video, rerank) + builtin_registry: Built-in model registry dictionary + convert_format_func: Format conversion function (optional) + model_class: Model class (optional) + + Returns: + int: Number of successfully loaded models + """ + import codecs + import json + import logging + import os + + from ..constants import XINFERENCE_MODEL_DIR + + logger = logging.getLogger(__name__) + + builtin_dir = os.path.join(XINFERENCE_MODEL_DIR, "v2", "builtin", model_type) + complete_json_path = os.path.join(builtin_dir, f"{model_type}_models.json") + + loaded_count = 0 + loaded_from_unified = set() # Track models loaded from unified file + + # First, try to load from the traditional unified JSON file + if os.path.exists(complete_json_path): + try: + with codecs.open(complete_json_path, encoding="utf-8") as fd: + model_data = json.load(fd) + + models_to_register = [] + if isinstance(model_data, list): + models_to_register = model_data + elif isinstance(model_data, dict): + if "model_name" in model_data: + models_to_register = [model_data] + else: + for key, value in model_data.items(): + if isinstance(value, dict) and "model_name" in value: + models_to_register.append(value) + + for data in models_to_register: + try: + # Apply format conversion function (if provided) + if convert_format_func: + data = convert_format_func(data) + + # Create model instance (if model class is provided) + if model_class: + model = model_class.parse_obj(data) + model_name = model.model_name + else: + model_name = data.get("model_name", "unknown") + model = data + + # Add to registry based on model type + if model_type in ["audio", "image", "video", "llm"]: + # These model types use list structure: dict[model_name] = [model1, model2, ...] 
+ if model_name not in builtin_registry: + builtin_registry[model_name] = [model] + else: + builtin_registry[model_name].append(model) + else: + # embedding, rerank use single model structure: dict[model_name] = model + builtin_registry[model_name] = model + + loaded_from_unified.add( + model_name + ) # Track that this model was loaded from unified file + loaded_count += 1 + logger.info( + f"Loaded {model_type} builtin model from unified file: {model_name}" + ) + + except Exception as e: + logger.warning( + f"Failed to load {model_type} model {data.get('model_name', 'Unknown')} from unified file: {e}" + ) + + logger.info( + f"Successfully loaded {loaded_count} {model_type} models from unified JSON" + ) + except Exception as e: + logger.warning(f"Failed to load unified JSON {complete_json_path}: {e}") + + # Second, load individual model files (for models added via add_model) + if os.path.isdir(builtin_dir): + for filename in os.listdir(builtin_dir): + # Skip the unified JSON file and other non-JSON files + if filename == f"{model_type}_models.json" or not filename.endswith( + ".json" + ): + continue + + file_path = os.path.join(builtin_dir, filename) + try: + with codecs.open(file_path, encoding="utf-8") as fd: + model_data = json.load(fd) + + # Skip if this doesn't look like a valid model file + if not isinstance(model_data, dict) or "model_name" not in model_data: + continue + + # Skip if we already loaded this model from unified file in THIS SESSION + model_name = model_data["model_name"] + if model_name in loaded_from_unified: + continue + + # Apply format conversion function (if provided) + if convert_format_func: + model_data = convert_format_func(model_data) + + # Create model instance (if model class is provided) + if model_class: + model = model_class.parse_obj(model_data) + model_name = model.model_name + else: + model_name = model_data.get("model_name", "unknown") + model = model_data + + # Add to registry based on model type + if model_type in ["audio", "image", "video", "llm"]: + # These model types use list structure: dict[model_name] = [model1, model2, ...] + builtin_registry[model_name] = [model] + else: + # embedding, rerank use single model structure: dict[model_name] = model + builtin_registry[model_name] = model + + loaded_count += 1 + logger.info( + f"Loaded {model_type} builtin model from individual file: {model_name}" + ) + + except Exception as e: + logger.warning( + f"Failed to load {model_type} model from {filename}: {e}" + ) + + return loaded_count + + +def load_persisted_models_to_registry(): + """ + Scan and load all persisted user models into the registry. + This function should be called when Worker starts up. 
+ """ + import json + import logging + import os + + logger = logging.getLogger(__name__) + from ..constants import XINFERENCE_MODEL_DIR + + builtin_dir = os.path.join(XINFERENCE_MODEL_DIR, "v2", "builtin") + + if not os.path.exists(builtin_dir): + logger.info(f"Builtin models directory not found: {builtin_dir}") + return 0 + + loaded_count = 0 + + # Iterate through all model types + for model_type in ["llm", "embedding", "image", "audio", "video", "rerank"]: + type_dir = os.path.join(builtin_dir, model_type) + if not os.path.isdir(type_dir): + continue + + logger.info(f"Loading {model_type} models from {type_dir}") + + # Scan individual model files + for model_file in os.listdir(type_dir): + if model_file.endswith(".json") and not model_file.endswith("_models.json"): + try: + model_path = os.path.join(type_dir, model_file) + with open(model_path, "r", encoding="utf-8") as f: + model_data = json.load(f) + + # Call the corresponding registration function + success = register_model_by_type(model_type, model_data) + if success: + loaded_count += 1 + logger.info(f"✓ Loaded {model_type} model: {model_file}") + else: + logger.warning( + f"✗ Failed to load {model_type} model: {model_file}" + ) + + except Exception as e: + logger.error(f"Error loading model {model_file}: {e}") + + logger.info(f"Total loaded {loaded_count} persisted models") + return loaded_count + + +def register_model_by_type(model_type, model_data): + """ + Call the appropriate registration function based on model type. + """ + try: + if model_type == "llm": + from .llm import register_llm + from .llm.llm_family import LLMFamilyV2 + + # Apply flatten_quantizations to LLM data to fill missing fields + converted_data = model_data.copy() + if "model_specs" in converted_data: + flattened_specs = [] + for spec in converted_data["model_specs"]: + if "model_src" in spec: + flattened_specs.extend(flatten_quantizations(spec)) + else: + flattened_specs.append(spec) + converted_data["model_specs"] = flattened_specs + + model_spec = LLMFamilyV2.parse_obj(converted_data) + register_llm(model_spec, persist=False) + return True + + elif model_type == "embedding": + from .embedding import register_embedding + from .embedding.core import EmbeddingModelFamilyV2 + + # Apply flatten_quantizations to embedding data to fill missing fields + converted_data = model_data.copy() + if "model_specs" in converted_data: + flattened_specs = [] + for spec in converted_data["model_specs"]: + if "model_src" in spec: + flattened_specs.extend(flatten_quantizations(spec)) + else: + flattened_specs.append(spec) + converted_data["model_specs"] = flattened_specs + + model_spec = EmbeddingModelFamilyV2.parse_obj(converted_data) + register_embedding(model_spec, persist=False) + return True + + elif model_type == "image": + from .image import register_image + from .image.custom import CustomImageModelFamilyV2 + + model_spec = CustomImageModelFamilyV2.parse_obj(model_data) + register_image(model_spec, persist=False) + return True + + elif model_type == "audio": + from .audio import register_audio + from .audio.custom import CustomAudioModelFamilyV2 + + model_spec = CustomAudioModelFamilyV2.parse_obj(model_data) + register_audio(model_spec, persist=False) + return True + + elif model_type == "video": + from .video import register_video + from .video.custom import CustomVideoModelFamilyV2 + + model_spec = CustomVideoModelFamilyV2.parse_obj(model_data) + register_video(model_spec, persist=False) + return True + + elif model_type == "rerank": + from .rerank import 
register_rerank + from .rerank.custom import CustomRerankModelFamilyV2 + + # Apply flatten_quantizations to rerank data to fill missing fields + converted_data = model_data.copy() + if "model_specs" in converted_data: + flattened_specs = [] + for spec in converted_data["model_specs"]: + if "model_src" in spec: + flattened_specs.extend(flatten_quantizations(spec)) + else: + flattened_specs.append(spec) + converted_data["model_specs"] = flattened_specs + + model_spec = CustomRerankModelFamilyV2.parse_obj(converted_data) + register_rerank(model_spec, persist=False) + return True + + else: + logger.warning(f"Unknown model type: {model_type}") + return False + + except Exception as e: + logger.error(f"Error registering {model_type} model: {e}") + return False diff --git a/xinference/model/video/__init__.py b/xinference/model/video/__init__.py index 5002fcc039..7e6c6f628f 100644 --- a/xinference/model/video/__init__.py +++ b/xinference/model/video/__init__.py @@ -14,9 +14,16 @@ import codecs import json +import logging import os +import warnings +from typing import Any, Dict from ..utils import flatten_model_src + +logger = logging.getLogger(__name__) + + from .core import ( BUILTIN_VIDEO_MODELS, VIDEO_MODEL_DESCRIPTIONS, @@ -24,11 +31,63 @@ generate_video_description, get_video_model_descriptions, ) +from .custom import ( + CustomVideoModelFamilyV2, + get_registered_videos, + register_video, + unregister_video, +) + + +def register_custom_model(): + from ...constants import XINFERENCE_MODEL_DIR + from ..custom import migrate_from_v1_to_v2 + + # migrate from v1 to v2 first + migrate_from_v1_to_v2("video", CustomVideoModelFamilyV2) + + user_defined_video_dir = os.path.join(XINFERENCE_MODEL_DIR, "v2", "video") + if os.path.isdir(user_defined_video_dir): + for f in os.listdir(user_defined_video_dir): + try: + with codecs.open( + os.path.join(user_defined_video_dir, f), encoding="utf-8" + ) as fd: + user_defined_video_family = CustomVideoModelFamilyV2.parse_obj( + json.load(fd) + ) + register_video(user_defined_video_family, persist=False) + except Exception as e: + warnings.warn(f"{user_defined_video_dir}/{f} has error, {e}") + + +def register_builtin_model(): + """ + Dynamically load built-in video models from builtin/video directory. + This function is called every time model list is requested, + ensuring real-time updates without server restart. 
+ """ + # Use unified loading function with flatten_model_src + from ..utils import flatten_model_src, load_complete_builtin_models + + loaded_count = load_complete_builtin_models( + model_type="video", + builtin_registry=BUILTIN_VIDEO_MODELS, + convert_format_func=lambda x: ( + flatten_model_src(x)[0] if flatten_model_src(x) else x + ), + model_class=VideoModelFamilyV2, + ) + + logger.info(f"Successfully loaded {loaded_count} video models from complete JSON") def _install(): load_model_family_from_json("model_spec.json", BUILTIN_VIDEO_MODELS) + # Load models from complete JSON file (from update_model_type) + register_builtin_model() + # register model description for model_name, model_specs in BUILTIN_VIDEO_MODELS.items(): model_spec = [x for x in model_specs if x.model_hub == "huggingface"][0] diff --git a/xinference/model/video/cache_manager.py b/xinference/model/video/cache_manager.py new file mode 100644 index 0000000000..a55da23058 --- /dev/null +++ b/xinference/model/video/cache_manager.py @@ -0,0 +1,40 @@ +from typing import TYPE_CHECKING + +from ..cache_manager import CacheManager + +if TYPE_CHECKING: + from .core import VideoModelFamilyV2 + + +class VideoCacheManager(CacheManager): + def __init__(self, model_family: "VideoModelFamilyV2"): + super().__init__(model_family) + + @staticmethod + def is_model_from_builtin_dir(model_name: str, model_type: str) -> bool: + """ + Check if a model comes from the builtin directory for video models. + """ + return CacheManager.is_model_from_builtin_dir(model_name, model_type) + + @staticmethod + def resolve_model_source( + model_name: str, model_type: str, builtin_model_names=None + ) -> str: + """ + Resolve the source of a video model. + """ + return CacheManager.resolve_model_source( + model_name, model_type, builtin_model_names + ) + + @staticmethod + def is_builtin_model( + model_name: str, model_type: str, builtin_model_names=None + ) -> bool: + """ + Determine if a video model should be considered builtin. + """ + return CacheManager.is_builtin_model( + model_name, model_type, builtin_model_names + ) diff --git a/xinference/model/video/custom.py b/xinference/model/video/custom.py new file mode 100644 index 0000000000..841917c42e --- /dev/null +++ b/xinference/model/video/custom.py @@ -0,0 +1,74 @@ +# Copyright 2022-2023 XProbe Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import logging +from typing import TYPE_CHECKING, List, Optional + +from ..._compat import ( + Literal, +) +from ..custom import ModelRegistry +from .core import VideoModelFamilyV2 + +logger = logging.getLogger(__name__) + + +class CustomVideoModelFamilyV2(VideoModelFamilyV2): + version: Literal[2] = 2 + model_id: Optional[str] # type: ignore + model_revision: Optional[str] # type: ignore + model_uri: Optional[str] + + +if TYPE_CHECKING: + from typing import TypeVar + + _T = TypeVar("_T", bound="CustomVideoModelFamilyV2") + + +class VideoModelRegistry(ModelRegistry): + model_type = "video" + + def __init__(self): + super().__init__() + + def get_user_defined_models(self) -> List["CustomVideoModelFamilyV2"]: + return self.get_custom_models() + + +video_registry = VideoModelRegistry() + + +def register_video(model_spec: CustomVideoModelFamilyV2, persist: bool = True): + from ..custom import RegistryManager + + registry = RegistryManager.get_registry("video") + registry.register(model_spec, persist) + + +def unregister_video(model_name: str, raise_error: bool = True): + from ..custom import RegistryManager + + registry = RegistryManager.get_registry("video") + registry.unregister(model_name, raise_error) + + +def get_registered_videos() -> List[CustomVideoModelFamilyV2]: + """ + Get all video families registered in the registry (both user-defined and editor-defined). + This excludes hardcoded builtin models. + """ + from ..custom import RegistryManager + + registry = RegistryManager.get_registry("video") + return registry.get_custom_models() diff --git a/xinference/ui/web/ui/src/locales/en.json b/xinference/ui/web/ui/src/locales/en.json index a12662732d..437fb45a5c 100644 --- a/xinference/ui/web/ui/src/locales/en.json +++ b/xinference/ui/web/ui/src/locales/en.json @@ -124,7 +124,23 @@ "featured": "featured", "all": "all", "cancelledSuccessfully": "Cancelled Successfully!", - "mustBeUnique": "{{key}} must be unique" + "mustBeUnique": "{{key}} must be unique", + "addModel": "Add Model", + "addModelDialog": { + "introPrefix": "To add a model, please go to the", + "platformLinkText": "Xinference Model Hub", + "introSuffix": "and fill in the corresponding model name.", + "modelName": "Model Name", + "modelName.tip": "Please enter the model name", + "placeholder": "e.g. 
qwen3 (case-sensitive)" + }, + "update": "Update", + "error": { + "name_not_matched": "No exact model name match found (case-sensitive)", + "downloadFailed": "Download failed", + "requestFailed": "Request failed", + "json_parse_error": "Failed to parse JSON" + } }, "runningModels": { diff --git a/xinference/ui/web/ui/src/locales/ja.json b/xinference/ui/web/ui/src/locales/ja.json index dc1636bfd3..e4075f9e1d 100644 --- a/xinference/ui/web/ui/src/locales/ja.json +++ b/xinference/ui/web/ui/src/locales/ja.json @@ -124,7 +124,23 @@ "featured": "おすすめとお気に入り", "all": "すべて", "cancelledSuccessfully": "正常にキャンセルされました!", - "mustBeUnique": "{{key}} は一意でなければなりません" + "mustBeUnique": "{{key}} は一意でなければなりません", + "addModel": "モデルを追加", + "addModelDialog": { + "introPrefix": "モデルを追加するには、", + "platformLinkText": "Xinference モデルセンター", + "introSuffix": "で対応するモデル名を入力してください。", + "modelName": "モデル名", + "modelName.tip": "モデル名を入力してください", + "placeholder": "例:qwen3(大文字と小文字を区別します)" + }, + "update": "更新", + "error": { + "name_not_matched": "完全に一致するモデル名が見つかりません(大文字と小文字を区別します)", + "downloadFailed": "ダウンロードに失敗しました", + "requestFailed": "リクエストに失敗しました", + "json_parse_error": "JSON の解析に失敗しました" + } }, "runningModels": { diff --git a/xinference/ui/web/ui/src/locales/ko.json b/xinference/ui/web/ui/src/locales/ko.json index 17ad7626a6..36fd0cd0c2 100644 --- a/xinference/ui/web/ui/src/locales/ko.json +++ b/xinference/ui/web/ui/src/locales/ko.json @@ -124,7 +124,23 @@ "featured": "추천 및 즐겨찾기", "all": "모두", "cancelledSuccessfully": "성공적으로 취소되었습니다!", - "mustBeUnique": "{{key}} 는 고유해야 합니다" + "mustBeUnique": "{{key}} 는 고유해야 합니다", + "addModel": "모델 추가", + "addModelDialog": { + "introPrefix": "모델을 추가하려면", + "platformLinkText": "Xinference 모델 센터", + "introSuffix": "에서 해당 모델 이름을 입력하세요.", + "modelName": "모델 이름", + "modelName.tip": "모델 이름을 입력하세요", + "placeholder": "예: qwen3 (대소문자를 구분합니다)" + }, + "update": "업데이트", + "error": { + "name_not_matched": "완전히 일치하는 모델 이름을 찾을 수 없습니다(대소문자 구분)", + "downloadFailed": "다운로드 실패", + "requestFailed": "요청 실패", + "json_parse_error": "JSON 구문 분석에 실패했습니다" + } }, "runningModels": { diff --git a/xinference/ui/web/ui/src/locales/zh.json b/xinference/ui/web/ui/src/locales/zh.json index 36daec1756..3a0a1d7a19 100644 --- a/xinference/ui/web/ui/src/locales/zh.json +++ b/xinference/ui/web/ui/src/locales/zh.json @@ -124,7 +124,23 @@ "featured": "推荐和收藏", "all": "全部", "cancelledSuccessfully": "取消成功!", - "mustBeUnique": "{{key}} 必须唯一" + "mustBeUnique": "{{key}} 必须唯一", + "addModel": "添加模型", + "addModelDialog": { + "introPrefix": "添加模型需基于", + "platformLinkText": "Xinference 模型中心", + "introSuffix": ",填写模型对应的名称", + "modelName": "模型名称", + "modelName.tip": "请输入模型名称", + "placeholder": "例如:qwen3(需大小写完全匹配)" + }, + "update": "更新", + "error": { + "name_not_matched": "未找到完全匹配的模型名称(需大小写一致)", + "downloadFailed": "下载失败", + "requestFailed": "请求失败", + "json_parse_error": "JSON 解析失败" + } }, "runningModels": { diff --git a/xinference/ui/web/ui/src/scenes/launch_model/LaunchModel.js b/xinference/ui/web/ui/src/scenes/launch_model/LaunchModel.js index cba7bf9a65..623a122b6d 100644 --- a/xinference/ui/web/ui/src/scenes/launch_model/LaunchModel.js +++ b/xinference/ui/web/ui/src/scenes/launch_model/LaunchModel.js @@ -10,9 +10,11 @@ import { Select, } from '@mui/material' import React, { + forwardRef, useCallback, useContext, useEffect, + useImperativeHandle, useRef, useState, } from 'react' @@ -28,494 +30,507 @@ import ModelCard from './modelCard' // Toggle pagination globally for this page. Set to false to disable pagination and load all items. 
const ENABLE_PAGINATION = false -const LaunchModelComponent = ({ modelType, gpuAvailable, featureModels }) => { - const { isCallingApi, setIsCallingApi, endPoint } = useContext(ApiContext) - const { isUpdatingModel } = useContext(ApiContext) - const { setErrorMsg } = useContext(ApiContext) - const [cookie] = useCookies(['token']) - - const [registrationData, setRegistrationData] = useState([]) - // States used for filtering - const [searchTerm, setSearchTerm] = useState('') - const [status, setStatus] = useState('') - const [statusArr, setStatusArr] = useState([]) - const [collectionArr, setCollectionArr] = useState([]) - const [filterArr, setFilterArr] = useState([]) - const { t } = useTranslation() - const [modelListType, setModelListType] = useState('featured') - const [modelAbilityData, setModelAbilityData] = useState({ - type: modelType, - modelAbility: '', - options: [], - }) - const [selectedModel, setSelectedModel] = useState(null) - const [isOpenLaunchModelDrawer, setIsOpenLaunchModelDrawer] = useState(false) - - // Pagination status - const [displayedData, setDisplayedData] = useState([]) - const [currentPage, setCurrentPage] = useState(1) - const [hasMore, setHasMore] = useState(true) - const itemsPerPage = 20 - const loaderRef = useRef(null) - - const filter = useCallback( - (registration) => { - if (searchTerm !== '') { - if (!registration || typeof searchTerm !== 'string') return false - const modelName = registration.model_name - ? registration.model_name.toLowerCase() - : '' - const modelDescription = registration.model_description - ? registration.model_description.toLowerCase() - : '' +const LaunchModelComponent = forwardRef( + ({ modelType, gpuAvailable, featureModels }, ref) => { + const { isCallingApi, setIsCallingApi, endPoint } = useContext(ApiContext) + const { isUpdatingModel } = useContext(ApiContext) + const { setErrorMsg } = useContext(ApiContext) + const [cookie] = useCookies(['token']) + + const [registrationData, setRegistrationData] = useState([]) + // States used for filtering + const [searchTerm, setSearchTerm] = useState('') + const [status, setStatus] = useState('') + const [statusArr, setStatusArr] = useState([]) + const [collectionArr, setCollectionArr] = useState([]) + const [filterArr, setFilterArr] = useState([]) + const { t } = useTranslation() + const [modelListType, setModelListType] = useState('featured') + const [modelAbilityData, setModelAbilityData] = useState({ + type: modelType, + modelAbility: '', + options: [], + }) + const [selectedModel, setSelectedModel] = useState(null) + const [isOpenLaunchModelDrawer, setIsOpenLaunchModelDrawer] = + useState(false) + + // Pagination status + const [displayedData, setDisplayedData] = useState([]) + const [currentPage, setCurrentPage] = useState(1) + const [hasMore, setHasMore] = useState(true) + const itemsPerPage = 20 + const loaderRef = useRef(null) + + const filter = useCallback( + (registration) => { + if (searchTerm !== '') { + if (!registration || typeof searchTerm !== 'string') return false + const modelName = registration.model_name + ? registration.model_name.toLowerCase() + : '' + const modelDescription = registration.model_description + ? 
registration.model_description.toLowerCase() + : '' + + if ( + !modelName.includes(searchTerm.toLowerCase()) && + !modelDescription.includes(searchTerm.toLowerCase()) + ) { + return false + } + } - if ( - !modelName.includes(searchTerm.toLowerCase()) && - !modelDescription.includes(searchTerm.toLowerCase()) - ) { - return false + if (modelListType === 'featured') { + if ( + featureModels.length && + !featureModels.includes(registration.model_name) && + !collectionArr?.includes(registration.model_name) + ) { + return false + } } - } - if (modelListType === 'featured') { if ( - featureModels.length && - !featureModels.includes(registration.model_name) && - !collectionArr?.includes(registration.model_name) - ) { + modelAbilityData.modelAbility && + ((Array.isArray(registration.model_ability) && + registration.model_ability.indexOf(modelAbilityData.modelAbility) < + 0) || + (typeof registration.model_ability === 'string' && + registration.model_ability !== modelAbilityData.modelAbility)) + ) return false - } - } - if ( - modelAbilityData.modelAbility && - ((Array.isArray(registration.model_ability) && - registration.model_ability.indexOf(modelAbilityData.modelAbility) < - 0) || - (typeof registration.model_ability === 'string' && - registration.model_ability !== modelAbilityData.modelAbility)) - ) - return false - - if (statusArr.length === 1) { - if (statusArr[0] === 'cached') { + if (statusArr.length === 1) { + if (statusArr[0] === 'cached') { + const judge = + registration.model_specs?.some((spec) => filterCache(spec)) || + registration?.cache_status + return judge + } else { + return collectionArr?.includes(registration.model_name) + } + } else if (statusArr.length > 1) { const judge = registration.model_specs?.some((spec) => filterCache(spec)) || registration?.cache_status - return judge - } else { - return collectionArr?.includes(registration.model_name) + return judge && collectionArr?.includes(registration.model_name) } - } else if (statusArr.length > 1) { - const judge = - registration.model_specs?.some((spec) => filterCache(spec)) || - registration?.cache_status - return judge && collectionArr?.includes(registration.model_name) - } - return true - }, - [ - searchTerm, - modelListType, - featureModels, - collectionArr, - modelAbilityData.modelAbility, - statusArr, - ] - ) - - const filterCache = useCallback((spec) => { - if (Array.isArray(spec.cache_status)) { - return spec.cache_status?.some((cs) => cs) - } else { - return spec.cache_status === true - } - }, []) - - function getUniqueModelAbilities(arr) { - const uniqueAbilities = new Set() + return true + }, + [ + searchTerm, + modelListType, + featureModels, + collectionArr, + modelAbilityData.modelAbility, + statusArr, + ] + ) - arr.forEach((item) => { - if (Array.isArray(item.model_ability)) { - item.model_ability.forEach((ability) => { - uniqueAbilities.add(ability) - }) + const filterCache = useCallback((spec) => { + if (Array.isArray(spec.cache_status)) { + return spec.cache_status?.some((cs) => cs) + } else { + return spec.cache_status === true } - }) + }, []) - return Array.from(uniqueAbilities) - } + function getUniqueModelAbilities(arr) { + const uniqueAbilities = new Set() - const update = () => { - if ( - isCallingApi || - isUpdatingModel || - (cookie.token !== 'no_auth' && !sessionStorage.getItem('token')) - ) - return - - try { - setIsCallingApi(true) - - fetchWrapper - .get(`/v1/model_registrations/${modelType}?detailed=true`) - .then((data) => { - const builtinRegistrations = data.filter((v) => v.is_builtin) - 
setModelAbilityData({ - ...modelAbilityData, - options: getUniqueModelAbilities(builtinRegistrations), + arr.forEach((item) => { + if (Array.isArray(item.model_ability)) { + item.model_ability.forEach((ability) => { + uniqueAbilities.add(ability) }) - setRegistrationData(builtinRegistrations) - const collectionData = JSON.parse( - localStorage.getItem('collectionArr') - ) - setCollectionArr(collectionData) + } + }) - // Reset pagination status - setCurrentPage(1) - setHasMore(true) - }) - .catch((error) => { - console.error('Error:', error) - if (error.response.status !== 403 && error.response.status !== 401) { - setErrorMsg(error.message) - } - }) - } catch (error) { - console.error('Error:', error) - } finally { - setIsCallingApi(false) + return Array.from(uniqueAbilities) } - } - useEffect(() => { - update() - }, [cookie.token]) + const update = () => { + if ( + isCallingApi || + isUpdatingModel || + (cookie.token !== 'no_auth' && !sessionStorage.getItem('token')) + ) + return + + try { + setIsCallingApi(true) + + fetchWrapper + .get(`/v1/model_registrations/${modelType}?detailed=true`) + .then((data) => { + const builtinRegistrations = data.filter((v) => v.is_builtin) + setModelAbilityData({ + ...modelAbilityData, + options: getUniqueModelAbilities(builtinRegistrations), + }) + setRegistrationData(builtinRegistrations) + const collectionData = JSON.parse( + localStorage.getItem('collectionArr') + ) + setCollectionArr(collectionData) + + // Reset pagination status + setCurrentPage(1) + setHasMore(true) + }) + .catch((error) => { + console.error('Error:', error) + if ( + error.response.status !== 403 && + error.response.status !== 401 + ) { + setErrorMsg(error.message) + } + }) + } catch (error) { + console.error('Error:', error) + } finally { + setIsCallingApi(false) + } + } - // Update pagination data - const updateDisplayedData = useCallback(() => { - const filteredData = registrationData.filter((registration) => - filter(registration) - ) + useEffect(() => { + update() + }, [cookie.token]) - const sortedData = [...filteredData].sort((a, b) => { - if (modelListType === 'featured') { - const indexA = featureModels.indexOf(a.model_name) - const indexB = featureModels.indexOf(b.model_name) - return ( - (indexA !== -1 ? indexA : Infinity) - - (indexB !== -1 ? indexB : Infinity) - ) + // Update pagination data + const updateDisplayedData = useCallback(() => { + const filteredData = registrationData.filter((registration) => + filter(registration) + ) + + const sortedData = [...filteredData].sort((a, b) => { + if (modelListType === 'featured') { + const indexA = featureModels.indexOf(a.model_name) + const indexB = featureModels.indexOf(b.model_name) + return ( + (indexA !== -1 ? indexA : Infinity) - + (indexB !== -1 ? 
indexB : Infinity) + ) + } + return 0 + }) + + // If pagination is disabled, show all data at once + if (!ENABLE_PAGINATION) { + setDisplayedData(sortedData) + setHasMore(false) + return } - return 0 - }) - // If pagination is disabled, show all data at once - if (!ENABLE_PAGINATION) { - setDisplayedData(sortedData) - setHasMore(false) - return - } + const startIndex = (currentPage - 1) * itemsPerPage + const endIndex = currentPage * itemsPerPage + const newData = sortedData.slice(startIndex, endIndex) - const startIndex = (currentPage - 1) * itemsPerPage - const endIndex = currentPage * itemsPerPage - const newData = sortedData.slice(startIndex, endIndex) + if (currentPage === 1) { + setDisplayedData(newData) + } else { + setDisplayedData((prev) => [...prev, ...newData]) + } + setHasMore(endIndex < sortedData.length) + }, [ + registrationData, + filter, + modelListType, + featureModels, + currentPage, + itemsPerPage, + ]) - if (currentPage === 1) { - setDisplayedData(newData) - } else { - setDisplayedData((prev) => [...prev, ...newData]) - } - setHasMore(endIndex < sortedData.length) - }, [ - registrationData, - filter, - modelListType, - featureModels, - currentPage, - itemsPerPage, - ]) - - useEffect(() => { - updateDisplayedData() - }, [updateDisplayedData]) - - // Reset pagination when filters change - useEffect(() => { - setCurrentPage(1) - setHasMore(true) - }, [searchTerm, modelAbilityData.modelAbility, status, modelListType]) - - // Infinite scroll observer - useEffect(() => { - if (!ENABLE_PAGINATION) return - - const observer = new IntersectionObserver( - (entries) => { - if (entries[0].isIntersecting && hasMore && !isCallingApi) { - setCurrentPage((prev) => prev + 1) - } - }, - { threshold: 1.0 } - ) + useEffect(() => { + updateDisplayedData() + }, [updateDisplayedData]) - if (loaderRef.current) { - observer.observe(loaderRef.current) - } + // Reset pagination when filters change + useEffect(() => { + setCurrentPage(1) + setHasMore(true) + }, [searchTerm, modelAbilityData.modelAbility, status, modelListType]) + + // Infinite scroll observer + useEffect(() => { + if (!ENABLE_PAGINATION) return + + const observer = new IntersectionObserver( + (entries) => { + if (entries[0].isIntersecting && hasMore && !isCallingApi) { + setCurrentPage((prev) => prev + 1) + } + }, + { threshold: 1.0 } + ) - return () => { if (loaderRef.current) { - observer.unobserve(loaderRef.current) + observer.observe(loaderRef.current) } - } - }, [hasMore, isCallingApi, currentPage]) - const getCollectionArr = (data) => { - setCollectionArr(data) - } + return () => { + if (loaderRef.current) { + observer.unobserve(loaderRef.current) + } + } + }, [hasMore, isCallingApi, currentPage]) - const handleChangeFilter = (type, value) => { - const typeMap = { - modelAbility: { - setter: (value) => { - setModelAbilityData({ - ...modelAbilityData, - modelAbility: value, - }) - }, - filterArr: modelAbilityData.options, - }, - status: { setter: setStatus, filterArr: [] }, + const getCollectionArr = (data) => { + setCollectionArr(data) } - const { setter, filterArr: excludeArr } = typeMap[type] || {} - if (!setter) return + const handleChangeFilter = (type, value) => { + const typeMap = { + modelAbility: { + setter: (value) => { + setModelAbilityData({ + ...modelAbilityData, + modelAbility: value, + }) + }, + filterArr: modelAbilityData.options, + }, + status: { setter: setStatus, filterArr: [] }, + } - setter(value) + const { setter, filterArr: excludeArr } = typeMap[type] || {} + if (!setter) return - const 
updatedFilterArr = Array.from( - new Set([ - ...filterArr.filter((item) => !excludeArr.includes(item)), - value, - ]) - ) + setter(value) + + const updatedFilterArr = Array.from( + new Set([ + ...filterArr.filter((item) => !excludeArr.includes(item)), + value, + ]) + ) - setFilterArr(updatedFilterArr) + setFilterArr(updatedFilterArr) - if (type === 'status') { - setStatusArr( - updatedFilterArr.filter( - (item) => ![...modelAbilityData.options].includes(item) + if (type === 'status') { + setStatusArr( + updatedFilterArr.filter( + (item) => ![...modelAbilityData.options].includes(item) + ) ) - ) - } + } - // Reset pagination status - setDisplayedData([]) - setCurrentPage(1) - setHasMore(true) - } + // Reset pagination status + setDisplayedData([]) + setCurrentPage(1) + setHasMore(true) + } - const handleDeleteChip = (item) => { - setFilterArr( - filterArr.filter((subItem) => { - return subItem !== item - }) - ) - if (item === modelAbilityData.modelAbility) { - setModelAbilityData({ - ...modelAbilityData, - modelAbility: '', - }) - } else { - setStatusArr( - statusArr.filter((subItem) => { + const handleDeleteChip = (item) => { + setFilterArr( + filterArr.filter((subItem) => { return subItem !== item }) ) - if (item === status) setStatus('') - } - - // Reset pagination status - setCurrentPage(1) - setHasMore(true) - } - - const handleModelType = (newModelType) => { - if (newModelType !== null) { - setModelListType(newModelType) + if (item === modelAbilityData.modelAbility) { + setModelAbilityData({ + ...modelAbilityData, + modelAbility: '', + }) + } else { + setStatusArr( + statusArr.filter((subItem) => { + return subItem !== item + }) + ) + if (item === status) setStatus('') + } // Reset pagination status - setDisplayedData([]) setCurrentPage(1) setHasMore(true) } - } - function getLabel(item) { - const translation = t(`launchModel.${item}`) - return translation === `launchModel.${item}` ? item : translation - } + const handleModelType = (newModelType) => { + if (newModelType !== null) { + setModelListType(newModelType) - return ( - -
{ - const hasAbility = modelAbilityData.options.length > 0 - const hasFeature = featureModels.length > 0 - - const baseColumns = hasAbility ? ['200px', '150px'] : ['200px'] - const altColumns = hasAbility ? ['150px', '150px'] : ['150px'] - - const columns = hasFeature - ? [...baseColumns, '150px', '1fr'] - : [...altColumns, '1fr'] - - return columns.join(' ') - })(), - columnGap: '20px', - margin: '30px 2rem', - alignItems: 'center', - }} - > - {featureModels.length > 0 && ( - - - + + + + )} + {modelAbilityData.options.length > 0 && ( + + + {t('launchModel.modelAbility')} + + + + )} - - {t('launchModel.modelAbility')} + + {t('launchModel.status')} - )} - - {t('launchModel.status')} - - - - - { - setSearchTerm(e.target.value) - }} - size="small" - hotkey="Enter" - t={t} - /> - -
-
- {filterArr.map((item, index) => ( - handleDeleteChip(item)} - /> - ))} -
-
- {displayedData.map((filteredRegistration) => ( - + { + setSearchTerm(e.target.value) + }} + size="small" + hotkey="Enter" + t={t} + /> + +
+
+ {filterArr.map((item, index) => ( + handleDeleteChip(item)} + /> + ))} +
+
+ {displayedData.map((filteredRegistration) => ( + { + setSelectedModel(filteredRegistration) + setIsOpenLaunchModelDrawer(true) + }} + /> + ))} +
+ +
+ {ENABLE_PAGINATION && hasMore && !isCallingApi && ( +
+ +
+ )} +
+ + {selectedModel && ( + { - setSelectedModel(filteredRegistration) - setIsOpenLaunchModelDrawer(true) - }} + gpuAvailable={gpuAvailable} + open={isOpenLaunchModelDrawer} + onClose={() => setIsOpenLaunchModelDrawer(false)} /> - ))} - - -
- {ENABLE_PAGINATION && hasMore && !isCallingApi && ( -
- -
)} -
- - {selectedModel && ( - setIsOpenLaunchModelDrawer(false)} - /> - )} -
- ) -} + + ) + } +) + +LaunchModelComponent.displayName = 'LaunchModelComponent' export default LaunchModelComponent diff --git a/xinference/ui/web/ui/src/scenes/launch_model/components/addModelDialog.js b/xinference/ui/web/ui/src/scenes/launch_model/components/addModelDialog.js new file mode 100644 index 0000000000..885ef07a70 --- /dev/null +++ b/xinference/ui/web/ui/src/scenes/launch_model/components/addModelDialog.js @@ -0,0 +1,109 @@ +import { + Button, + Dialog, + DialogActions, + DialogContent, + DialogTitle, + TextField, +} from '@mui/material' +import React, { useContext, useState } from 'react' +import { useTranslation } from 'react-i18next' + +import { ApiContext } from '../../../components/apiContext' +import fetchWrapper from '../../../components/fetchWrapper' + +const AddModelDialog = ({ open, onClose, onUpdateList }) => { + const { t } = useTranslation() + const [modelName, setModelName] = useState('') + const [loading, setLoading] = useState(false) + const { setErrorMsg } = useContext(ApiContext) + + const handleFormSubmit = async (e) => { + e.preventDefault() + const name = modelName?.trim() + if (!name) { + setErrorMsg(t('launchModel.addModelDialog.modelName.tip')) + return + } + setLoading(true) + setErrorMsg('') + + fetchWrapper + .post('/v1/models/add', { model_name: modelName }) + .then((data) => { + onClose(`/launch_model/${data.data.model_type}`) + onUpdateList(data.data.model_type) + }) + .catch((error) => { + console.error('Error:', error) + if (error.response.status !== 403 && error.response.status !== 401) { + setErrorMsg(error.message) + } + }) + .finally(() => { + setLoading(false) + }) + } + + return ( + onClose()} width={500}> + {t('launchModel.addModel')} + +
+
+ {t('launchModel.addModelDialog.introPrefix')}{' '} + + {t('launchModel.addModelDialog.platformLinkText')} + + {t('launchModel.addModelDialog.introSuffix')} +
+
+ { + setModelName(e.target.value) + }} + disabled={loading} + /> + +
+
+ + + + +
+ ) +} + +export default AddModelDialog diff --git a/xinference/ui/web/ui/src/scenes/launch_model/components/launchModelDrawer.js b/xinference/ui/web/ui/src/scenes/launch_model/components/launchModelDrawer.js index 1169f06269..7a5bda45e8 100644 --- a/xinference/ui/web/ui/src/scenes/launch_model/components/launchModelDrawer.js +++ b/xinference/ui/web/ui/src/scenes/launch_model/components/launchModelDrawer.js @@ -407,9 +407,9 @@ const LaunchModelDrawer = ({ }, []) useEffect(() => { - if (modelEngineType.includes(modelType)) + if (open && modelEngineType.includes(modelType)) fetchModelEngine(modelData.model_name, modelType) - }, [modelData.model_name, modelType]) + }, [open, modelData.model_name, modelType]) useEffect(() => { if (formData.__isInitializing) { diff --git a/xinference/ui/web/ui/src/scenes/launch_model/index.js b/xinference/ui/web/ui/src/scenes/launch_model/index.js index 24f886a80d..4ac6cff612 100644 --- a/xinference/ui/web/ui/src/scenes/launch_model/index.js +++ b/xinference/ui/web/ui/src/scenes/launch_model/index.js @@ -1,6 +1,7 @@ -import { TabContext, TabList, TabPanel } from '@mui/lab' -import { Box, Tab } from '@mui/material' -import React, { useContext, useEffect, useState } from 'react' +import Add from '@mui/icons-material/Add' +import { LoadingButton, TabContext, TabList, TabPanel } from '@mui/lab' +import { Box, Button, MenuItem, Select, Tab } from '@mui/material' +import React, { useContext, useEffect, useRef, useState } from 'react' import { useCookies } from 'react-cookie' import { useTranslation } from 'react-i18next' import { useNavigate } from 'react-router-dom' @@ -11,6 +12,7 @@ import fetchWrapper from '../../components/fetchWrapper' import SuccessMessageSnackBar from '../../components/successMessageSnackBar' import Title from '../../components/Title' import { isValidBearerToken } from '../../components/utils' +import AddModelDialog from './components/addModelDialog' import { featureModels } from './data/data' import LaunchCustom from './launchCustom' import LaunchModelComponent from './LaunchModel' @@ -22,13 +24,17 @@ const LaunchModel = () => { : '/launch_model/llm' ) const [gpuAvailable, setGPUAvailable] = useState(-1) + const [open, setOpen] = useState(false) + const [modelType, setModelType] = useState('llm') + const [loading, setLoading] = useState(false) const { setErrorMsg } = useContext(ApiContext) const [cookie] = useCookies(['token']) const navigate = useNavigate() const { t } = useTranslation() + const LaunchModelRefs = useRef({}) - const handleTabChange = (event, newValue) => { + const handleTabChange = (newValue) => { setValue(newValue) navigate(newValue) sessionStorage.setItem('modelType', newValue) @@ -59,14 +65,56 @@ const LaunchModel = () => { } }, [cookie.token]) + const updateList = (modelType) => { + LaunchModelRefs.current[modelType]?.update() + } + + const handleClose = (value) => { + setOpen(false) + if (value) { + handleTabChange(value) + } + } + + const updateModels = () => { + setLoading(true) + fetchWrapper + .post('/v1/models/update_type', { model_type: modelType }) + .then(() => { + handleTabChange(`/launch_model/${modelType}`) + updateList(modelType) + }) + .catch((error) => { + console.error('Error:', error) + if (error.response.status !== 403 && error.response.status !== 401) { + setErrorMsg(error.message) + } + }) + .finally(() => { + setLoading(false) + }) + } + return ( <ErrorMessageSnackBar /> <SuccessMessageSnackBar /> <TabContext value={value}> - <Box sx={{ borderBottom: 1, borderColor: 'divider' }}> - <TabList 
value={value} onChange={handleTabChange} aria-label="tabs"> + <Box + sx={{ + borderBottom: 1, + borderColor: 'divider', + display: 'flex', + justifyContent: 'space-between', + alignItems: 'center', + }} + > + <TabList + value={value} + onChange={(_, value) => handleTabChange(value)} + aria-label="tabs" + > <Tab label={t('model.languageModels')} value="/launch_model/llm" /> <Tab label={t('model.embeddingModels')} @@ -81,6 +129,53 @@ const LaunchModel = () => { value="/launch_model/custom/llm" /> </TabList> + <Box + sx={{ + display: 'flex', + alignItems: 'center', + gap: '10px', + }} + > + <Box sx={{ display: 'flex', gap: 0 }}> + <Select + value={modelType} + onChange={(e) => setModelType(e.target.value)} + size="small" + sx={{ + borderTopRightRadius: 0, + borderBottomRightRadius: 0, + minWidth: 100, + }} + > + <MenuItem value="llm">LLM</MenuItem> + <MenuItem value="embedding">Embedding</MenuItem> + <MenuItem value="rerank">Rerank</MenuItem> + <MenuItem value="image">Image</MenuItem> + <MenuItem value="audio">Audio</MenuItem> + <MenuItem value="video">Video</MenuItem> + </Select> + + <LoadingButton + variant="contained" + onClick={updateModels} + loading={loading} + sx={{ + borderTopLeftRadius: 0, + borderBottomLeftRadius: 0, + whiteSpace: 'nowrap', + }} + > + {t('launchModel.update')} + </LoadingButton> + </Box> + <Button + variant="outlined" + startIcon={<Add />} + onClick={() => setOpen(true)} + > + {t('launchModel.addModel')} + </Button> + </Box> </Box> <TabPanel value="/launch_model/llm" sx={{ padding: 0 }}> <LaunchModelComponent @@ -89,6 +184,7 @@ const LaunchModel = () => { featureModels={ featureModels.find((item) => item.type === 'llm').feature_models } + ref={(ref) => (LaunchModelRefs.current.llm = ref)} /> </TabPanel> <TabPanel value="/launch_model/embedding" sx={{ padding: 0 }}> @@ -99,6 +195,7 @@ const LaunchModel = () => { featureModels.find((item) => item.type === 'embedding') .feature_models } + ref={(ref) => (LaunchModelRefs.current.embedding = ref)} /> </TabPanel> <TabPanel value="/launch_model/rerank" sx={{ padding: 0 }}> @@ -109,6 +206,7 @@ const LaunchModel = () => { featureModels.find((item) => item.type === 'rerank') .feature_models } + ref={(ref) => (LaunchModelRefs.current.rerank = ref)} /> </TabPanel> <TabPanel value="/launch_model/image" sx={{ padding: 0 }}> @@ -118,6 +216,7 @@ const LaunchModel = () => { featureModels={ featureModels.find((item) => item.type === 'image').feature_models } + ref={(ref) => (LaunchModelRefs.current.image = ref)} /> </TabPanel> <TabPanel value="/launch_model/audio" sx={{ padding: 0 }}> @@ -127,6 +226,7 @@ const LaunchModel = () => { featureModels={ featureModels.find((item) => item.type === 'audio').feature_models } + ref={(ref) => (LaunchModelRefs.current.audio = ref)} /> </TabPanel> <TabPanel value="/launch_model/video" sx={{ padding: 0 }}> @@ -136,12 +236,20 @@ const LaunchModel = () => { featureModels={ featureModels.find((item) => item.type === 'video').feature_models } + ref={(ref) => (LaunchModelRefs.current.video = ref)} /> </TabPanel> <TabPanel value="/launch_model/custom/llm" sx={{ padding: 0 }}> <LaunchCustom gpuAvailable={gpuAvailable} /> </TabPanel> </TabContext> + {open && ( + <AddModelDialog + onUpdateList={updateList} + open={open} + onClose={handleClose} + /> + )} </Box> ) }
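
The changes above converge on two new helpers in xinference/model/utils.py: load_complete_builtin_models(), which each per-type register_builtin_model() now calls to populate its registry from XINFERENCE_MODEL_DIR/v2/builtin/<type>, and load_persisted_models_to_registry(), which replays persisted models at worker startup through register_model_by_type(). Below is a minimal sketch of driving those helpers directly for the rerank type, assuming the patch is applied and a v2/builtin/rerank directory exists; the import paths follow the ones used inside the patch, and the surrounding wiring is illustrative rather than part of the change itself:

    # Sketch only: exercise the new builtin-model loaders added in this patch.
    from xinference.model.rerank.core import RerankModelFamilyV2
    from xinference.model.rerank.rerank_family import BUILTIN_RERANK_MODELS
    from xinference.model.utils import (
        load_complete_builtin_models,
        load_persisted_models_to_registry,
    )

    # Populate the in-memory registry from rerank_models.json plus any individual
    # per-model JSON files written into builtin/rerank by the "Add Model" flow.
    loaded = load_complete_builtin_models(
        model_type="rerank",
        builtin_registry=BUILTIN_RERANK_MODELS,  # updated in place
        convert_format_func=None,                # or a flatten_quantizations wrapper
        model_class=RerankModelFamilyV2,
    )
    print(f"loaded {loaded} builtin rerank models")

    # At worker startup, the catch-all entry point walks every model type under
    # v2/builtin/ and re-registers whatever was persisted there.
    print(f"re-registered {load_persisted_models_to_registry()} persisted models")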