
Commit 76d2a00

Function 2: One-Click Model Update
1 parent 8f37219 commit 76d2a00

2 files changed: 246 additions & 0 deletions

xinference/api/restful_api.py

Lines changed: 55 additions & 0 deletions
@@ -203,6 +203,10 @@ class AddModelRequest(BaseModel):
     model_json: Dict[str, Any]


+class UpdateModelRequest(BaseModel):
+    model_type: str
+
+
 class BuildGradioInterfaceRequest(BaseModel):
     model_type: str
     model_name: str
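The new UpdateModelRequest model means the update endpoint expects a JSON body carrying a single model_type field. A minimal sketch of such a body, written as a Python dict (the value "llm" is only an illustrative choice; any type the supervisor recognizes should be accepted):

# Hypothetical request body for the new update endpoint; this commit only
# defines the "model_type" field on UpdateModelRequest.
update_request_body = {"model_type": "llm"}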
@@ -915,6 +919,16 @@ async def internal_exception_handler(request: Request, exc: Exception):
                 else None
             ),
         )
+        self._router.add_api_route(
+            "/v1/models/update_type",
+            self.update_model_type,
+            methods=["POST"],
+            dependencies=(
+                [Security(self._auth_service, scopes=["models:add"])]
+                if self.is_authenticated()
+                else None
+            ),
+        )
         self._router.add_api_route(
             "/v1/cache/models",
             self.list_cached_models,
@@ -3192,6 +3206,47 @@ async def add_model(self, request: Request) -> JSONResponse:
             content={"message": f"Model added successfully for type: {model_type}"}
         )

+    async def update_model_type(self, request: Request) -> JSONResponse:
+        try:
+            # Parse request
+            raw_json = await request.json()
+            logger.info(f"[DEBUG] Update model type API called with: {raw_json}")
+
+            body = UpdateModelRequest.parse_obj(raw_json)
+            model_type = body.model_type
+
+            logger.info(f"[DEBUG] Parsed model_type for update: {model_type}")
+
+            # Get supervisor reference
+            supervisor_ref = await self._get_supervisor_ref()
+
+            # Call supervisor to update model type
+            logger.info(
+                f"[DEBUG] Calling supervisor.update_model_type with model_type: {model_type}"
+            )
+            await supervisor_ref.update_model_type(model_type)
+            logger.info(f"[DEBUG] Supervisor.update_model_type completed successfully")
+
+        except ValueError as re:
+            logger.error(
+                f"[DEBUG] ValueError in update_model_type API: {re}", exc_info=True
+            )
+            raise HTTPException(status_code=400, detail=str(re))
+        except Exception as e:
+            logger.error(
+                f"[DEBUG] Unexpected error in update_model_type API: {e}", exc_info=True
+            )
+            raise HTTPException(status_code=500, detail=str(e))
+
+        logger.info(
+            f"[DEBUG] Update model type API completed successfully for model_type: {model_type}"
+        )
+        return JSONResponse(
+            content={
+                "message": f"Model configurations updated successfully for type: {model_type}"
+            }
+        )
+
     async def list_model_registrations(
         self, model_type: str, detailed: bool = Query(False)
     ) -> JSONResponse:
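Taken together, the route registration and handler above can be exercised with a plain HTTP POST. A minimal client-side sketch using the requests library, assuming a locally running Xinference endpoint (the base URL is deployment-specific, and an Authorization header with the models:add scope is only needed when authentication is enabled):

import requests

# Hypothetical client call for the endpoint added in this commit; adjust the
# base URL to wherever your Xinference server is listening.
base_url = "http://127.0.0.1:9997"
resp = requests.post(
    f"{base_url}/v1/models/update_type",
    json={"model_type": "llm"},  # parsed into UpdateModelRequest on the server
    timeout=60,
)
resp.raise_for_status()
# On success the handler returns e.g.
# {"message": "Model configurations updated successfully for type: llm"}
print(resp.json())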

xinference/core/supervisor.py

Lines changed: 191 additions & 0 deletions
@@ -1483,6 +1483,197 @@ async def _sync_register_model(
             logger.warning(f"finish unregister model: {model} for {name}")
             raise e

+    @log_async(logger=logger)
+    async def update_model_type(self, model_type: str):
+        """
+        Update model configurations for a specific model type by downloading
+        the latest JSON from the remote API and storing it locally.
+
+        Args:
+            model_type: Type of model (LLM, embedding, image, etc.)
+        """
+        import json
+
+        import requests
+
+        logger.info(
+            f"[DEBUG SUPERVISOR] update_model_type called with model_type: {model_type}"
+        )
+
+        # Validate model type
+        normalized_model_type = "LLM" if model_type.lower() == "llm" else model_type
+        supported_types = list(self._custom_register_type_to_cls.keys())
+
+        if normalized_model_type not in supported_types:
+            logger.error(
+                f"[DEBUG SUPERVISOR] Unsupported model type: {normalized_model_type}"
+            )
+            raise ValueError(
+                f"Unsupported model type '{model_type}'. "
+                f"Supported types are: {', '.join(supported_types)}"
+            )
+
+        # Use normalized model type for the rest of the function
+        model_type = normalized_model_type
+        logger.info(f"[DEBUG SUPERVISOR] Using model_type: '{model_type}' for update")
+
+        # Construct the URL to download JSON
+        url = f"https://model.xinference.io/api/models/download?model_type={model_type.lower()}"
+        logger.info(f"[DEBUG SUPERVISOR] Downloading model configurations from: {url}")
+
+        try:
+            # Download JSON from remote API
+            response = requests.get(url, timeout=30)
+            response.raise_for_status()
+
+            # Parse JSON response
+            model_data = response.json()
+            logger.info(
+                f"[DEBUG SUPERVISOR] Successfully downloaded JSON for model type: {model_type}"
+            )
+            logger.info(f"[DEBUG SUPERVISOR] JSON data type: {type(model_data)}")
+
+            if isinstance(model_data, dict):
+                logger.info(
+                    f"[DEBUG SUPERVISOR] JSON data keys: {list(model_data.keys())}"
+                )
+            elif isinstance(model_data, list):
+                logger.info(
+                    f"[DEBUG SUPERVISOR] JSON data contains {len(model_data)} items"
+                )
+                if model_data:
+                    logger.info(
+                        f"[DEBUG SUPERVISOR] First item keys: {list(model_data[0].keys()) if isinstance(model_data[0], dict) else 'Not a dict'}"
+                    )
+
+            # Store the JSON data using CacheManager
+            logger.info(f"[DEBUG SUPERVISOR] Storing model configurations...")
+            await self._store_model_configurations(model_type, model_data)
+            logger.info(f"[DEBUG SUPERVISOR] Model configurations stored successfully")
+
+        except requests.exceptions.RequestException as e:
+            logger.error(
+                f"[DEBUG SUPERVISOR] Network error downloading model configurations: {e}"
+            )
+            raise ValueError(f"Failed to download model configurations: {str(e)}")
+        except json.JSONDecodeError as e:
+            logger.error(f"[DEBUG SUPERVISOR] JSON decode error: {e}")
+            raise ValueError(f"Invalid JSON response from remote API: {str(e)}")
+        except Exception as e:
+            logger.error(
+                f"[DEBUG SUPERVISOR] Unexpected error during model update: {e}",
+                exc_info=True,
+            )
+            raise ValueError(f"Failed to update model configurations: {str(e)}")
+
+    async def _store_model_configurations(self, model_type: str, model_data):
+        """
+        Store model configurations using the appropriate CacheManager.
+
+        Args:
+            model_type: Type of model
+            model_data: JSON data containing model configurations
+        """
+
+        logger.info(
+            f"[DEBUG SUPERVISOR] Storing configurations for model type: {model_type}"
+        )
+
+        try:
+            # Create a temporary model spec to get CacheManager instance
+            # We need to determine the appropriate model spec class for this model type
+            model_spec_cls, _, _, _ = self._custom_register_type_to_cls[model_type]
+
+            # Handle different response formats
+            if isinstance(model_data, dict):
+                # Single model configuration
+                logger.info(f"[DEBUG SUPERVISOR] Processing single model configuration")
+                await self._store_single_model_config(
+                    model_type, model_data, model_spec_cls
+                )
+            elif isinstance(model_data, list):
+                # Multiple model configurations
+                logger.info(
+                    f"[DEBUG SUPERVISOR] Processing {len(model_data)} model configurations"
+                )
+                for i, model_config in enumerate(model_data):
+                    if isinstance(model_config, dict):
+                        logger.info(f"[DEBUG SUPERVISOR] Processing model config {i+1}")
+                        await self._store_single_model_config(
+                            model_type, model_config, model_spec_cls
+                        )
+                    else:
+                        logger.warning(
+                            f"[DEBUG SUPERVISOR] Skipping invalid model config {i+1}: not a dict"
+                        )
+            else:
+                raise ValueError(
+                    f"Invalid model data format: expected dict or list, got {type(model_data)}"
+                )
+
+        except Exception as e:
+            logger.error(
+                f"[DEBUG SUPERVISOR] Error storing model configurations: {e}",
+                exc_info=True,
+            )
+            raise
+
+    async def _store_single_model_config(
+        self, model_type: str, model_config: dict, model_spec_cls
+    ):
+        """
+        Store a single model configuration.
+
+        Args:
+            model_type: Type of model
+            model_config: Single model configuration dictionary
+            model_spec_cls: Model specification class
+        """
+        from ..model.cache_manager import CacheManager
+
+        # Ensure required fields are present
+        if "model_name" not in model_config:
+            logger.warning(
+                f"[DEBUG SUPERVISOR] Skipping model config without model_name: {model_config}"
+            )
+            return
+
+        model_name = model_config["model_name"]
+        logger.info(f"[DEBUG SUPERVISOR] Storing model: {model_name}")
+
+        # Validate model name format
+        from ..model.utils import is_valid_model_name
+
+        if not is_valid_model_name(model_name):
+            logger.warning(
+                f"[DEBUG SUPERVISOR] Skipping model with invalid name: {model_name}"
+            )
+            return
+
+        try:
+            # Convert model hub JSON format to Xinference expected format
+            converted_config = self._convert_model_json_format(model_config)
+            logger.info(f"[DEBUG SUPERVISOR] Converted model config for: {model_name}")
+
+            # Create model spec instance
+            model_spec = model_spec_cls.parse_obj(converted_config)
+            logger.info(f"[DEBUG SUPERVISOR] Created model spec for: {model_name}")
+
+            # Create CacheManager and store the configuration
+            cache_manager = CacheManager(model_spec)
+            cache_manager.register_custom_model(model_type)
+            logger.info(
+                f"[DEBUG SUPERVISOR] Stored model configuration for: {model_name}"
+            )
+
+        except Exception as e:
+            logger.error(
+                f"[DEBUG SUPERVISOR] Error storing model {model_name}: {e}",
+                exc_info=True,
+            )
+            # Continue with other models instead of failing completely
+            return
+
     @log_async(logger=logger)
     async def unregister_model(self, model_type: str, model_name: str):
         if model_type in self._custom_register_type_to_cls:
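For context, _store_model_configurations accepts either a single JSON object or a list of objects, and _store_single_model_config skips any entry that lacks a model_name or fails is_valid_model_name. A hypothetical payload shape, sketched in Python (fields beyond model_name are not specified by this commit, so they are deliberately omitted):

# Hypothetical shape of the JSON served by the download URL. The diff only
# requires each entry to be a dict with a "model_name" key; remaining fields
# are passed through self._convert_model_json_format (not part of this diff)
# and must parse into the model spec class registered for the model_type.
example_payload = [
    {"model_name": "my-custom-llm"},
]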

0 commit comments
