Skip to content

Commit 701aca5

Browse files
authored
feat(inspect): add training device selection (#3056)
* add training device selection Signed-off-by: Ma, Xiangxiang <xiangxiang.ma@intel.com> * fix style Signed-off-by: Ma, Xiangxiang <xiangxiang.ma@intel.com> * improve error handling Signed-off-by: Ma, Xiangxiang <xiangxiang.ma@intel.com> * reset uv lock * refactor device names Signed-off-by: Ma, Xiangxiang <xiangxiang.ma@intel.com> --------- Signed-off-by: Ma, Xiangxiang <xiangxiang.ma@intel.com>
1 parent 31c3d2f commit 701aca5

File tree

15 files changed

+140
-39
lines changed

15 files changed

+140
-39
lines changed
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
# Copyright (C) 2025 Intel Corporation
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
from fastapi import APIRouter
5+
6+
from api.endpoints import API_PREFIX
7+
from pydantic_models.devices import DeviceList
8+
from utils.devices import Devices
9+
10+
# Router grouping the device-discovery endpoints under the shared API prefix.
# NOTE(review): the OpenAPI tag is "Job" although these routes list devices;
# confirm this grouping is intentional.
device_router = APIRouter(prefix=API_PREFIX, tags=["Job"])
14+
15+
16+
@device_router.get("/inference-devices")
async def get_inference_devices() -> DeviceList:
    """Endpoint to get list of supported devices for inference"""
    # Device discovery is delegated to the Devices utility; this handler only
    # wraps the resulting names in the response schema.
    available = Devices.inference_devices()
    return DeviceList(devices=available)
20+
21+
22+
@device_router.get("/training-devices")
async def get_training_devices() -> DeviceList:
    """Endpoint to get list of supported devices for training"""
    # Device discovery is delegated to the Devices utility; this handler only
    # wraps the resulting names in the response schema.
    available = Devices.training_devices()
    return DeviceList(devices=available)

application/backend/src/api/endpoints/model_endpoints.py

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
from api.media_rest_validator import MediaRestValidator
1111
from exceptions import ResourceNotFoundException
1212
from pydantic_models import Model, ModelList, PredictionResponse
13-
from pydantic_models.model import SupportedDevices
1413
from services import ModelService
1514
from services.exceptions import DeviceNotFoundError
1615

@@ -69,11 +68,3 @@ async def predict(
6968
return await model_service.predict_image(model, image_bytes, request.app.state.active_models, device=device)
7069
except DeviceNotFoundError as e:
7170
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
72-
73-
74-
@model_router.post(":supported-devices")
75-
async def get_supported_devices(
76-
model_service: Annotated[ModelService, Depends(get_model_service)],
77-
) -> SupportedDevices:
78-
"""Endpoint to get list of supported devices for inference"""
79-
return model_service.get_supported_devices()

application/backend/src/core/logging/setup.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,6 @@ def setup_logging(config: LogConfig | None = None) -> None:
5050
config = LogConfig()
5151

5252
# overwrite global log_config
53-
global global_log_config
5453
global_log_config = config
5554

5655
for worker_name, log_file in global_log_config.worker_log_info.items():

application/backend/src/main.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from loguru import logger
1616
from starlette.responses import JSONResponse, Response
1717

18+
from api.endpoints.devices_endpoints import device_router
1819
from api.endpoints.job_endpoints import job_router
1920
from api.endpoints.media_endpoints import media_router
2021
from api.endpoints.model_endpoints import model_router
@@ -57,6 +58,7 @@
5758
app.include_router(sink_router)
5859
app.include_router(webrtc_router)
5960
app.include_router(trainable_model_router)
61+
app.include_router(device_router)
6062

6163

6264
@app.exception_handler(GetiBaseException)
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
# Copyright (C) 2025 Intel Corporation
2+
# SPDX-License-Identifier: Apache-2.0
3+
from pydantic import BaseModel
4+
5+
6+
class DeviceList(BaseModel):
    # Response schema for the device-listing endpoints: a flat list of device
    # names. A `#` comment is used instead of a class docstring so the
    # generated JSON schema description is left untouched.
    devices: list[str]

    # The example uses lower-case names because the device lists are
    # case-folded before they are returned, so real responses never contain
    # upper-case entries.
    model_config = {
        "json_schema_extra": {
            "example": {
                "devices": ["cpu", "xpu", "npu"],
            }
        }
    }

application/backend/src/pydantic_models/job.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,3 +49,4 @@ class JobSubmitted(BaseModel):
4949
class TrainJobPayload(BaseModel):
    # Payload stored with a training job.
    # `project_id` is declared with `exclude=True`, so it is left out when the
    # model is serialized.
    project_id: UUID = Field(exclude=True)
    model_name: str
    # Optional training device. The explicit `None` default keeps requests
    # that omit the field valid; without it the field is required (although
    # nullable) and older clients would be rejected with a validation error.
    device: str | None = None

application/backend/src/pydantic_models/model.py

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -83,15 +83,3 @@ def validate_score_range(self) -> "PredictionResponse":
8383
}
8484
}
8585
}
86-
87-
88-
class SupportedDevices(BaseModel):
89-
devices: list[str]
90-
91-
model_config = {
92-
"json_schema_extra": {
93-
"example": {
94-
"devices": ["CPU", "GPU", "NPU"]
95-
}
96-
}
97-
}

application/backend/src/services/model_service.py

Lines changed: 3 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -4,24 +4,22 @@
44
import base64
55
import io
66
from dataclasses import dataclass
7-
from functools import lru_cache
87
from multiprocessing.synchronize import Event as EventClass
98
from uuid import UUID
109

1110
import cv2
1211
import numpy as np
13-
import openvino as ov
1412
import openvino.properties.hint as ov_hints
1513
from anomalib.deploy import ExportType, OpenVINOInferencer
1614
from loguru import logger
1715
from PIL import Image
1816

1917
from db import get_async_db_session_ctx
2018
from pydantic_models import Model, ModelList, PredictionLabel, PredictionResponse
21-
from pydantic_models.model import SupportedDevices
2219
from repositories import ModelRepository
2320
from repositories.binary_repo import ModelBinaryRepository
2421
from services.exceptions import DeviceNotFoundError
22+
from utils.devices import Devices
2523

2624
DEFAULT_DEVICE = "AUTO"
2725

@@ -103,11 +101,11 @@ async def load_inference_model(cls, model: Model, device: str | None = None) ->
103101
return await asyncio.to_thread(
104102
OpenVINOInferencer,
105103
path=model_path,
106-
device=_device,
104+
device=_device.upper(), # OV always expects uppercase device names
107105
config={ov_hints.performance_mode: ov_hints.PerformanceMode.LATENCY},
108106
)
109107
except Exception as e:
110-
if _device not in cls.get_supported_devices().devices:
108+
if device and not Devices.is_device_supported_for_inference(device):
111109
raise DeviceNotFoundError(device_name=_device) from e
112110
raise e
113111

@@ -195,10 +193,3 @@ def _run_prediction_pipeline(inference_model: OpenVINOInferencer, image_bytes: b
195193
score = float(pred.pred_score.item())
196194

197195
return {"anomaly_map": im_base64, "label": label, "score": score}
198-
199-
@staticmethod
200-
@lru_cache
201-
def get_supported_devices() -> SupportedDevices:
202-
"""Get list of supported devices for inference."""
203-
core = ov.Core()
204-
return SupportedDevices(devices=core.available_devices)

application/backend/src/services/training_service.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from repositories.binary_repo import ImageBinaryRepository, ModelBinaryRepository
1616
from services import ModelService
1717
from services.job_service import JobService
18+
from utils.devices import Devices
1819
from utils.experiment_loggers import TrackioLogger
1920

2021

@@ -59,6 +60,7 @@ async def _run_training_job(cls, job: Job, job_service: JobService) -> Model:
5960
await job_service.update_job_status(job_id=job.id, status=JobStatus.RUNNING, message="Training started")
6061
project_id = job.project_id
6162
model_name = job.payload.get("model_name")
63+
device = job.payload.get("device")
6264
if model_name is None:
6365
raise ValueError(f"Job {job.id} payload must contain 'model_name'")
6466

@@ -73,7 +75,7 @@ async def _run_training_job(cls, job: Job, job_service: JobService) -> Model:
7375
try:
7476
# Use asyncio.to_thread to keep event loop responsive
7577
# TODO: Consider ProcessPoolExecutor for true parallelism with multiple jobs
76-
trained_model = await asyncio.to_thread(cls._train_model, model)
78+
trained_model = await asyncio.to_thread(cls._train_model, model=model, device=device)
7779
if trained_model is None:
7880
raise ValueError("Training failed - model is None")
7981

@@ -94,7 +96,7 @@ async def _run_training_job(cls, job: Job, job_service: JobService) -> Model:
9496
raise e
9597

9698
@staticmethod
97-
def _train_model(model: Model) -> Model | None:
99+
def _train_model(model: Model, device: str | None = None) -> Model | None:
98100
"""
99101
Execute CPU-intensive model training using anomalib.
100102
@@ -104,13 +106,22 @@ def _train_model(model: Model) -> Model | None:
104106
105107
Args:
106108
model: Model object with training configuration
109+
device: Device to train on
107110
108111
Returns:
109112
Model: Trained model with updated export_path and is_ready=True
110113
"""
111114
from core.logging import global_log_config
112115
from core.logging.handlers import LoggerStdoutWriter
113116

117+
if device and not Devices.is_device_supported_for_training(device):
118+
raise ValueError(
119+
f"Device '{device}' is not supported for training. "
120+
f"Supported devices: {', '.join(Devices.training_devices())}"
121+
)
122+
123+
logger.info(f"Training on device: {device or 'auto'}")
124+
114125
model_binary_repo = ModelBinaryRepository(project_id=model.project_id, model_id=model.id)
115126
image_binary_repo = ImageBinaryRepository(project_id=model.project_id)
116127
image_folder_path = image_binary_repo.project_folder_path
@@ -134,6 +145,7 @@ def _train_model(model: Model) -> Model | None:
134145
default_root_dir=model.export_path,
135146
logger=[trackio, tensorboard],
136147
max_epochs=10,
148+
accelerator=device,
137149
)
138150

139151
# Execute training and export
Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
# Copyright (C) 2025 Intel Corporation
2+
# SPDX-License-Identifier: Apache-2.0
3+
from functools import lru_cache
4+
5+
import openvino as ov
6+
from lightning.pytorch.accelerators import AcceleratorRegistry
7+
8+
9+
class Devices:
    """Utility class for device-related operations."""

    @staticmethod
    @lru_cache
    def training_devices() -> list[str]:
        """Get list of supported devices for training.

        Returns:
            list[str]: Case-folded names of every entry in Lightning's
            ``AcceleratorRegistry`` whose accelerator reports
            ``is_available()``. The result is cached after the first call, so
            all callers receive the same list object and must not mutate it.
        """
        return [
            device_name.casefold()
            for device_name, device_info in AcceleratorRegistry.items()
            if device_info["accelerator"].is_available()
        ]

    @staticmethod
    @lru_cache
    def inference_devices() -> list[str]:
        """Get list of supported devices for inference.

        Returns:
            list[str]: Case-folded names from OpenVINO's
            ``Core().available_devices``. The result is cached after the first
            call, so all callers receive the same list object and must not
            mutate it.
        """
        ov_core = ov.Core()
        return [device.casefold() for device in ov_core.available_devices]

    @classmethod
    def _is_device_supported(cls, device_name: str, for_training: bool = False) -> bool:
        """Check if a device is supported for inference or training.

        The lookup is intentionally not cached: the device lists it consults
        are already cached, and a second cache keyed on caller-supplied names
        could return stale answers if those list caches are ever cleared.

        Args:
            device_name (str): Name of the device to check (case-insensitive).
            for_training (bool): If True, check for training devices; otherwise, check for inference devices.

        Returns:
            bool: True if the device is supported, False otherwise.
        """
        # The device lists hold case-folded names, so fold the query the same way.
        folded_name = device_name.casefold()
        supported = cls.training_devices() if for_training else cls.inference_devices()
        return folded_name in supported

    @classmethod
    def is_device_supported_for_inference(cls, device_name: str) -> bool:
        """Check if a device is supported for inference.

        Args:
            device_name (str): Name of the device to check.

        Returns:
            bool: True if the device is supported for inference, False otherwise.
        """
        return cls._is_device_supported(device_name, for_training=False)

    @classmethod
    def is_device_supported_for_training(cls, device_name: str) -> bool:
        """Check if a device is supported for training.

        Args:
            device_name (str): Name of the device to check.

        Returns:
            bool: True if the device is supported for training, False otherwise.
        """
        return cls._is_device_supported(device_name, for_training=True)

0 commit comments

Comments
 (0)