Skip to content

Commit 7bc04d3

Browse files
committed
num3:
1 parent 36f4401 commit 7bc04d3

File tree

7 files changed

+163
-55
lines changed

7 files changed

+163
-55
lines changed

xinference/model/audio/__init__.py

Lines changed: 66 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -67,18 +67,81 @@ def register_custom_model():
6767

6868
def register_builtin_model():
6969
# Use unified function for audio models
70-
from ..utils import register_builtin_models_unified, flatten_model_src
70+
from ..utils import flatten_model_src, register_builtin_models_unified
71+
72+
def convert_audio_model_format(model_json):
73+
"""
74+
Convert audio model hub JSON format to Xinference expected format.
75+
Add missing required fields for AudioModelFamilyV2.
76+
"""
77+
converted = model_json.copy()
78+
79+
# Apply conversion logic to handle null model_id and other issues
80+
if converted.get("model_id") is None and "model_src" in converted:
81+
model_src = converted["model_src"]
82+
# Extract model_id from available sources
83+
if "huggingface" in model_src and "model_id" in model_src["huggingface"]:
84+
converted["model_id"] = model_src["huggingface"]["model_id"]
85+
elif "modelscope" in model_src and "model_id" in model_src["modelscope"]:
86+
converted["model_id"] = model_src["modelscope"]["model_id"]
87+
88+
# Extract model_revision if available
89+
if converted.get("model_revision") is None and "model_src" in converted:
90+
model_src = converted["model_src"]
91+
if (
92+
"huggingface" in model_src
93+
and "model_revision" in model_src["huggingface"]
94+
):
95+
converted["model_revision"] = model_src["huggingface"]["model_revision"]
96+
elif (
97+
"modelscope" in model_src
98+
and "model_revision" in model_src["modelscope"]
99+
):
100+
converted["model_revision"] = model_src["modelscope"]["model_revision"]
101+
102+
return converted
103+
104+
def audio_special_handling(registry, model_type):
105+
"""Handle audio's special registration logic"""
106+
from ..custom import RegistryManager
107+
from .custom import register_audio
108+
109+
registry_mgr = RegistryManager.get_registry("audio")
110+
existing_model_names = {
111+
spec.model_name for spec in registry_mgr.get_custom_models()
112+
}
113+
114+
for model_name, model_families in BUILTIN_AUDIO_MODELS.items():
115+
for model_family in model_families:
116+
if model_family.model_name not in existing_model_names:
117+
try:
118+
# Actually register model to RegistryManager
119+
register_audio(model_family, persist=False)
120+
existing_model_names.add(model_family.model_name)
121+
except ValueError as e:
122+
# Capture conflict errors and output warnings instead of raising exceptions
123+
import warnings
124+
125+
warnings.warn(str(e))
126+
except Exception as e:
127+
import warnings
128+
129+
warnings.warn(
130+
f"Error registering audio model {model_family.model_name}: {e}"
131+
)
71132

72133
loaded_count = register_builtin_models_unified(
73134
model_type="audio",
74135
flatten_func=flatten_model_src,
75-
model_class=AudioModelFamilyV2,
136+
model_class=CustomAudioModelFamilyV2,
76137
builtin_registry=BUILTIN_AUDIO_MODELS,
138+
custom_convert_func=convert_audio_model_format,
77139
custom_defaults={
78140
"multilingual": True,
79141
"model_lang": ["en", "zh"],
80142
"version": 2,
81-
}
143+
},
144+
special_handling=audio_special_handling,
82145
)
83146

84147

xinference/model/embedding/__init__.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ def register_custom_model():
7171

7272
def register_builtin_model():
7373
# Use unified function for embedding models
74-
from ..utils import register_builtin_models_unified, flatten_quantizations
74+
from ..utils import flatten_quantizations, register_builtin_models_unified
7575
from .embed_family import BUILTIN_EMBEDDING_MODELS
7676

7777
def embedding_special_handling(registry, model_type):
@@ -89,12 +89,16 @@ def embedding_special_handling(registry, model_type):
8989
register_embedding(model_family, persist=False)
9090
existing_model_names.add(model_family.model_name)
9191
except ValueError as e:
92-
# 捕获冲突错误并输出警告,而不是抛出异常
92+
# Capture conflict errors and output warnings instead of raising exceptions
9393
import warnings
94+
9495
warnings.warn(str(e))
9596
except Exception as e:
9697
import warnings
97-
warnings.warn(f"Error registering embedding model {model_family.model_name}: {e}")
98+
99+
warnings.warn(
100+
f"Error registering embedding model {model_family.model_name}: {e}"
101+
)
98102

99103
loaded_count = register_builtin_models_unified(
100104
model_type="embedding",

xinference/model/image/__init__.py

Lines changed: 45 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -63,30 +63,67 @@ def register_custom_model():
6363

6464
def register_builtin_model():
6565
# Use unified function for image models
66-
from ..utils import register_builtin_models_unified, flatten_model_src
66+
from ..utils import flatten_model_src, register_builtin_models_unified
67+
68+
def convert_image_model_format(model_json):
69+
"""
70+
Convert image model hub JSON format to Xinference expected format.
71+
Add missing required fields for ImageModelFamilyV2.
72+
"""
73+
converted = model_json.copy()
74+
75+
# Add missing required fields from model_src if they exist
76+
if "model_src" in converted and "huggingface" in converted["model_src"]:
77+
hf_info = converted["model_src"]["huggingface"]
78+
if "model_id" in hf_info and "model_id" not in converted:
79+
converted["model_id"] = hf_info["model_id"]
80+
if "model_revision" in hf_info and "model_revision" not in converted:
81+
converted["model_revision"] = hf_info["model_revision"]
82+
83+
# Add other missing required fields with defaults
84+
if "version" not in converted:
85+
converted["version"] = 2
86+
if "model_lang" not in converted:
87+
converted["model_lang"] = ["en"]
88+
89+
return converted
6790

6891
def image_special_handling(registry, model_type):
6992
"""Handle image's special registration logic"""
7093
from ...constants import XINFERENCE_MODEL_DIR
7194
from ..custom import RegistryManager
95+
from .custom import register_image
7296

7397
registry_mgr = RegistryManager.get_registry("image")
74-
existing_model_names = {spec.model_name for spec in registry_mgr.get_custom_models()}
98+
existing_model_names = {
99+
spec.model_name for spec in registry_mgr.get_custom_models()
100+
}
75101

76102
for model_name, model_families in BUILTIN_IMAGE_MODELS.items():
77103
for model_family in model_families:
78104
if model_family.model_name not in existing_model_names:
79-
# Update model descriptions for the new builtin model
80-
IMAGE_MODEL_DESCRIPTIONS.update(
81-
generate_image_description(model_family)
82-
)
83-
existing_model_names.add(model_family.model_name)
105+
try:
106+
# Actually register model to RegistryManager
107+
register_image(model_family, persist=False)
108+
existing_model_names.add(model_family.model_name)
109+
except ValueError as e:
110+
# Capture conflict errors and output warnings instead of raising exceptions
111+
import warnings
112+
113+
warnings.warn(str(e))
114+
except Exception as e:
115+
import warnings
116+
117+
warnings.warn(
118+
f"Error registering image model {model_family.model_name}: {e}"
119+
)
84120

85121
loaded_count = register_builtin_models_unified(
86122
model_type="image",
87123
flatten_func=flatten_model_src,
88-
model_class=ImageModelFamilyV2,
124+
model_class=CustomImageModelFamilyV2,
89125
builtin_registry=BUILTIN_IMAGE_MODELS,
126+
custom_convert_func=convert_image_model_format,
90127
special_handling=image_special_handling,
91128
)
92129

xinference/model/llm/__init__.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -136,8 +136,8 @@ def register_custom_model():
136136

137137
def register_builtin_model():
138138
# Use unified function for LLM models
139-
from ..utils import register_builtin_models_unified, flatten_quantizations
140139
from ..custom import RegistryManager
140+
from ..utils import flatten_quantizations, register_builtin_models_unified
141141
from .custom import register_llm
142142

143143
def llm_special_handling(registry, model_type):
@@ -154,12 +154,16 @@ def llm_special_handling(registry, model_type):
154154
register_llm(model_family, persist=False)
155155
existing_model_names.add(model_family.model_name)
156156
except ValueError as e:
157-
# 捕获冲突错误并输出警告,而不是抛出异常
157+
# Capture conflict errors and output warnings instead of raising exceptions
158158
import warnings
159+
159160
warnings.warn(str(e))
160161
except Exception as e:
161162
import warnings
162-
warnings.warn(f"Error registering LLM model {model_family.model_name}: {e}")
163+
164+
warnings.warn(
165+
f"Error registering LLM model {model_family.model_name}: {e}"
166+
)
163167

164168
loaded_count = register_builtin_models_unified(
165169
model_type="llm",

xinference/model/rerank/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ def register_custom_model():
7070

7171
def register_builtin_model():
7272
# Use unified function for rerank models
73-
from ..utils import register_builtin_models_unified, flatten_quantizations
73+
from ..utils import flatten_quantizations, register_builtin_models_unified
7474

7575
loaded_count = register_builtin_models_unified(
7676
model_type="rerank",

xinference/model/utils.py

Lines changed: 36 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -817,35 +817,31 @@ def load_complete_builtin_models(
817817

818818
def register_builtin_models_unified(
819819
model_type: str,
820-
flatten_func: callable,
821-
model_class: type,
822-
builtin_registry: dict,
823-
custom_convert_func: callable = None,
824-
custom_defaults: dict = None,
825-
special_handling: callable = None,
820+
flatten_func: Callable,
821+
model_class: Type,
822+
builtin_registry: Dict[str, Any],
823+
custom_convert_func: Optional[Callable] = None,
824+
custom_defaults: Optional[Dict[str, Any]] = None,
825+
special_handling: Optional[Callable] = None,
826826
):
827827
"""
828-
统一的内置模型注册函数
828+
Unified builtin model registration function
829829
830830
Args:
831-
model_type: 模型类型 ('llm', 'embedding', 'rerank', 'audio', 'image', 'video')
832-
flatten_func: 扁平化函数 (flatten_quantizations flatten_model_src)
833-
model_class: 模型类 (如 LLMFamilyV2)
834-
builtin_registry: 内置模型注册表
835-
custom_convert_func: 自定义转换函数 (可选)
836-
custom_defaults: 自定义默认值 (可选,用于Audio模型)
837-
special_handling: 特殊处理函数 (可选,用于Image/LLM模型)
831+
model_type: Model type ('llm', 'embedding', 'rerank', 'audio', 'image', 'video')
832+
flatten_func: Flatten function (flatten_quantizations or flatten_model_src)
833+
model_class: Model class (e.g. LLMFamilyV2)
834+
builtin_registry: Builtin model registry
835+
custom_convert_func: Custom conversion function (optional)
836+
custom_defaults: Custom default values (optional, for Audio models)
837+
special_handling: Special handling function (optional, for Image/LLM models)
838838
839839
Returns:
840-
int: 成功加载的模型数量
840+
int: Number of successfully loaded models
841841
"""
842-
import codecs
843-
import json
844-
from ..constants import XINFERENCE_MODEL_DIR
845-
846842
logger = logging.getLogger(__name__)
847843

848-
# 默认转换函数
844+
# Default conversion function
849845
def default_convert_func(model_json):
850846
if "model_specs" not in model_json:
851847
return model_json
@@ -860,43 +856,47 @@ def default_convert_func(model_json):
860856
result["model_specs"] = flattened_specs
861857
return result
862858

863-
# 使用自定义转换函数或默认函数
859+
# Use custom conversion function or default function
864860
convert_func = custom_convert_func or default_convert_func
865861

866-
# 使用统一的加载函数
862+
# Use unified loading function
867863
loaded_count = load_complete_builtin_models(
868864
model_type=model_type,
869865
builtin_registry=builtin_registry,
870866
convert_format_func=convert_func,
871867
model_class=model_class,
872868
)
873869

874-
# 应用自定义默认值 (用于Audio模型)
875-
if custom_defaults and loaded_count > 0:
870+
# Apply custom defaults (for Audio models)
871+
if custom_defaults is not None and loaded_count > 0:
876872
_apply_custom_defaults(builtin_registry, custom_defaults, model_type)
877873

878-
# 执行特殊处理 (用于Image/LLM模型)
879-
if special_handling and loaded_count > 0:
874+
# Execute special handling (for Image/LLM models)
875+
if special_handling is not None and loaded_count > 0:
880876
special_handling(builtin_registry, model_type)
881877

882-
logger.info(f"Successfully loaded {loaded_count} {model_type} models using unified function")
878+
logger.info(
879+
f"Successfully loaded {loaded_count} {model_type} models using unified function"
880+
)
883881
return loaded_count
884882

885883

886-
def _apply_custom_defaults(registry: dict, defaults: dict, model_type: str):
884+
def _apply_custom_defaults(
885+
registry: Dict[str, Any], defaults: Dict[str, Any], model_type: str
886+
):
887887
"""
888-
应用自定义默认值到模型规格
888+
Apply custom defaults to model specifications
889889
890890
Args:
891-
registry: 模型注册表
892-
defaults: 默认值字典
893-
model_type: 模型类型
891+
registry: Model registry
892+
defaults: Default values dictionary
893+
model_type: Model type
894894
"""
895895
for model_name, model_specs in registry.items():
896896
if isinstance(model_specs, list):
897-
# 对于使用列表结构的模型类型 (audio, image, video, llm)
897+
# For model types using list structure (audio, image, video, llm)
898898
for spec in model_specs:
899-
if hasattr(spec, '__dict__'):
899+
if hasattr(spec, "__dict__"):
900900
for key, value in defaults.items():
901901
if not hasattr(spec, key):
902902
setattr(spec, key, value)
@@ -905,8 +905,8 @@ def _apply_custom_defaults(registry: dict, defaults: dict, model_type: str):
905905
if key not in spec:
906906
spec[key] = value
907907
else:
908-
# 对于使用单一结构的模型类型 (embedding, rerank)
909-
if hasattr(model_specs, '__dict__'):
908+
# For model types using single structure (embedding, rerank)
909+
if hasattr(model_specs, "__dict__"):
910910
for key, value in defaults.items():
911911
if not hasattr(model_specs, key):
912912
setattr(model_specs, key, value)

xinference/model/video/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ def register_builtin_model():
6868
ensuring real-time updates without server restart.
6969
"""
7070
# Use unified function for video models
71-
from ..utils import register_builtin_models_unified, flatten_model_src
71+
from ..utils import flatten_model_src, register_builtin_models_unified
7272

7373
def video_convert_func(model_json):
7474
"""Video-specific conversion function"""

0 commit comments

Comments
 (0)