@@ -817,35 +817,31 @@ def load_complete_builtin_models(
817817
818818def register_builtin_models_unified (
819819 model_type : str ,
820- flatten_func : callable ,
821- model_class : type ,
822- builtin_registry : dict ,
823- custom_convert_func : callable = None ,
824- custom_defaults : dict = None ,
825- special_handling : callable = None ,
820+ flatten_func : Callable ,
821+ model_class : Type ,
822+ builtin_registry : Dict [ str , Any ] ,
823+ custom_convert_func : Optional [ Callable ] = None ,
824+ custom_defaults : Optional [ Dict [ str , Any ]] = None ,
825+ special_handling : Optional [ Callable ] = None ,
826826):
827827 """
828- 统一的内置模型注册函数
828+ Unified builtin model registration function
829829
830830 Args:
831- model_type: 模型类型 ('llm', 'embedding', 'rerank', 'audio', 'image', 'video')
832- flatten_func: 扁平化函数 (flatten_quantizations 或 flatten_model_src)
833- model_class: 模型类 (如 LLMFamilyV2)
834- builtin_registry: 内置模型注册表
835- custom_convert_func: 自定义转换函数 (可选 )
836- custom_defaults: 自定义默认值 (可选,用于Audio模型 )
837- special_handling: 特殊处理函数 (可选,用于Image/LLM模型 )
831+ model_type: Model type ('llm', 'embedding', 'rerank', 'audio', 'image', 'video')
832+ flatten_func: Flatten function (flatten_quantizations or flatten_model_src)
833+ model_class: Model class (e.g. LLMFamilyV2)
834+ builtin_registry: Builtin model registry
835+ custom_convert_func: Custom conversion function (optional )
836+ custom_defaults: Custom default values (optional, for Audio models )
837+ special_handling: Special handling function (optional, for Image/LLM models )
838838
839839 Returns:
840- int: 成功加载的模型数量
840+ int: Number of successfully loaded models
841841 """
842- import codecs
843- import json
844- from ..constants import XINFERENCE_MODEL_DIR
845-
846842 logger = logging .getLogger (__name__ )
847843
848- # 默认转换函数
844+ # Default conversion function
849845 def default_convert_func (model_json ):
850846 if "model_specs" not in model_json :
851847 return model_json
@@ -860,43 +856,47 @@ def default_convert_func(model_json):
860856 result ["model_specs" ] = flattened_specs
861857 return result
862858
863- # 使用自定义转换函数或默认函数
859+ # Use custom conversion function or default function
864860 convert_func = custom_convert_func or default_convert_func
865861
866- # 使用统一的加载函数
862+ # Use unified loading function
867863 loaded_count = load_complete_builtin_models (
868864 model_type = model_type ,
869865 builtin_registry = builtin_registry ,
870866 convert_format_func = convert_func ,
871867 model_class = model_class ,
872868 )
873869
874- # 应用自定义默认值 (用于Audio模型 )
875- if custom_defaults and loaded_count > 0 :
870+ # Apply custom defaults (for Audio models )
871+ if custom_defaults is not None and loaded_count > 0 :
876872 _apply_custom_defaults (builtin_registry , custom_defaults , model_type )
877873
878- # 执行特殊处理 (用于Image/LLM模型 )
879- if special_handling and loaded_count > 0 :
874+ # Execute special handling (for Image/LLM models )
875+ if special_handling is not None and loaded_count > 0 :
880876 special_handling (builtin_registry , model_type )
881877
882- logger .info (f"Successfully loaded { loaded_count } { model_type } models using unified function" )
878+ logger .info (
879+ f"Successfully loaded { loaded_count } { model_type } models using unified function"
880+ )
883881 return loaded_count
884882
885883
886- def _apply_custom_defaults (registry : dict , defaults : dict , model_type : str ):
884+ def _apply_custom_defaults (
885+ registry : Dict [str , Any ], defaults : Dict [str , Any ], model_type : str
886+ ):
887887 """
888- 应用自定义默认值到模型规格
888+ Apply custom defaults to model specifications
889889
890890 Args:
891- registry: 模型注册表
892- defaults: 默认值字典
893- model_type: 模型类型
891+ registry: Model registry
892+ defaults: Default values dictionary
893+ model_type: Model type
894894 """
895895 for model_name , model_specs in registry .items ():
896896 if isinstance (model_specs , list ):
897- # 对于使用列表结构的模型类型 (audio, image, video, llm)
897+ # For model types using list structure (audio, image, video, llm)
898898 for spec in model_specs :
899- if hasattr (spec , ' __dict__' ):
899+ if hasattr (spec , " __dict__" ):
900900 for key , value in defaults .items ():
901901 if not hasattr (spec , key ):
902902 setattr (spec , key , value )
@@ -905,8 +905,8 @@ def _apply_custom_defaults(registry: dict, defaults: dict, model_type: str):
905905 if key not in spec :
906906 spec [key ] = value
907907 else :
908- # 对于使用单一结构的模型类型 (embedding, rerank)
909- if hasattr (model_specs , ' __dict__' ):
908+ # For model types using single structure (embedding, rerank)
909+ if hasattr (model_specs , " __dict__" ):
910910 for key , value in defaults .items ():
911911 if not hasattr (model_specs , key ):
912912 setattr (model_specs , key , value )
0 commit comments