25 changes: 25 additions & 0 deletions application/backend/src/api/endpoints/devices_endpoints.py
@@ -0,0 +1,25 @@
# Copyright (C) 2025 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

from fastapi import APIRouter

from api.endpoints import API_PREFIX
from pydantic_models.devices import DeviceList
from utils.devices import Devices

device_router = APIRouter(
    prefix=API_PREFIX,
    tags=["Devices"],
)


@device_router.get("/inference-devices")
async def get_inference_devices() -> DeviceList:
    """Endpoint to get list of supported devices for inference."""
    return DeviceList(devices=Devices.inference_devices())


@device_router.get("/training-devices")
async def get_training_devices() -> DeviceList:
    """Endpoint to get list of supported devices for training."""
    return DeviceList(devices=Devices.training_devices())
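A quick way to smoke-test the two new routes: a minimal sketch using FastAPI's `TestClient`. It assumes the `app` object from `main.py` (shown further down) is importable and that `API_PREFIX` resolves to `/api`; both are assumptions, not confirmed by this diff.

```python
# Hypothetical smoke test for the new device endpoints.
from fastapi.testclient import TestClient

from main import app  # assumed import path

client = TestClient(app)

resp = client.get("/api/inference-devices")  # real prefix comes from API_PREFIX
assert resp.status_code == 200
assert "devices" in resp.json()  # e.g. {"devices": ["cpu", "gpu"]}
```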
9 changes: 0 additions & 9 deletions application/backend/src/api/endpoints/model_endpoints.py
@@ -10,7 +10,6 @@
from api.media_rest_validator import MediaRestValidator
from exceptions import ResourceNotFoundException
from pydantic_models import Model, ModelList, PredictionResponse
from pydantic_models.model import SupportedDevices
from services import ModelService
from services.exceptions import DeviceNotFoundError

@@ -69,11 +68,3 @@ async def predict(
        return await model_service.predict_image(model, image_bytes, request.app.state.active_models, device=device)
    except DeviceNotFoundError as e:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))


@model_router.post(":supported-devices")
async def get_supported_devices(
    model_service: Annotated[ModelService, Depends(get_model_service)],
) -> SupportedDevices:
    """Endpoint to get list of supported devices for inference"""
    return model_service.get_supported_devices()
1 change: 0 additions & 1 deletion application/backend/src/core/logging/setup.py
@@ -50,7 +50,6 @@ def setup_logging(config: LogConfig | None = None) -> None:
        config = LogConfig()

    # overwrite global log_config
    global global_log_config
    global_log_config = config

    for worker_name, log_file in global_log_config.worker_log_info.items():
2 changes: 2 additions & 0 deletions application/backend/src/main.py
@@ -15,6 +15,7 @@
from loguru import logger
from starlette.responses import JSONResponse, Response

from api.endpoints.devices_endpoints import device_router
from api.endpoints.job_endpoints import job_router
from api.endpoints.media_endpoints import media_router
from api.endpoints.model_endpoints import model_router
@@ -57,6 +58,7 @@
app.include_router(sink_router)
app.include_router(webrtc_router)
app.include_router(trainable_model_router)
app.include_router(device_router)


@app.exception_handler(GetiBaseException)
15 changes: 15 additions & 0 deletions application/backend/src/pydantic_models/devices.py
@@ -0,0 +1,15 @@
# Copyright (C) 2025 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
from pydantic import BaseModel


class DeviceList(BaseModel):
    devices: list[str]

    model_config = {
        "json_schema_extra": {
            "example": {
                "devices": ["CPU", "XPU", "NPU"],
            }
        }
    }
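For reference, the schema round-trips as a plain JSON object; a tiny sketch with the class redefined locally so it runs standalone:

```python
from pydantic import BaseModel


class DeviceList(BaseModel):
    devices: list[str]


payload = DeviceList(devices=["cpu", "gpu"])
print(payload.model_dump_json())                        # {"devices":["cpu","gpu"]}
print(DeviceList.model_validate({"devices": ["npu"]}))  # devices=['npu']
```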
1 change: 1 addition & 0 deletions application/backend/src/pydantic_models/job.py
@@ -49,3 +49,4 @@ class JobSubmitted(BaseModel):
class TrainJobPayload(BaseModel):
    project_id: UUID = Field(exclude=True)
    model_name: str
    device: str | None
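Note that in Pydantic v2 an annotation like `device: str | None` with no default is still a required field (it merely accepts `None` as a value), which is why the `conftest.py` fixture further down now passes `device=None` explicitly. A standalone sketch of the distinction:

```python
from pydantic import BaseModel, ValidationError


class Payload(BaseModel):
    device: str | None  # required, but nullable


try:
    Payload()
except ValidationError as e:
    print(e.errors()[0]["type"])  # 'missing' -- the field must be supplied

print(Payload(device=None).device)  # None is a valid explicit value
```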
12 changes: 0 additions & 12 deletions application/backend/src/pydantic_models/model.py
@@ -83,15 +83,3 @@ def validate_score_range(self) -> "PredictionResponse":
            }
        }
    }


class SupportedDevices(BaseModel):
    devices: list[str]

    model_config = {
        "json_schema_extra": {
            "example": {
                "devices": ["CPU", "GPU", "NPU"]
            }
        }
    }
15 changes: 3 additions & 12 deletions application/backend/src/services/model_service.py
@@ -4,24 +4,22 @@
import base64
import io
from dataclasses import dataclass
from functools import lru_cache
from multiprocessing.synchronize import Event as EventClass
from uuid import UUID

import cv2
import numpy as np
import openvino as ov
import openvino.properties.hint as ov_hints
from anomalib.deploy import ExportType, OpenVINOInferencer
from loguru import logger
from PIL import Image

from db import get_async_db_session_ctx
from pydantic_models import Model, ModelList, PredictionLabel, PredictionResponse
from pydantic_models.model import SupportedDevices
from repositories import ModelRepository
from repositories.binary_repo import ModelBinaryRepository
from services.exceptions import DeviceNotFoundError
from utils.devices import Devices

DEFAULT_DEVICE = "AUTO"

@@ -103,11 +101,11 @@ async def load_inference_model(cls, model: Model, device: str | None = None) ->
            return await asyncio.to_thread(
                OpenVINOInferencer,
                path=model_path,
                device=_device,
                device=_device.upper(),  # OV always expects uppercase device names
                config={ov_hints.performance_mode: ov_hints.PerformanceMode.LATENCY},
            )
        except Exception as e:
            if _device not in cls.get_supported_devices().devices:
            if device and not Devices.is_device_supported_for_inference(device):
                raise DeviceNotFoundError(device_name=_device) from e
            raise e

@@ -195,10 +193,3 @@ def _run_prediction_pipeline(inference_model: OpenVINOInferencer, image_bytes: b
        score = float(pred.pred_score.item())

        return {"anomaly_map": im_base64, "label": label, "score": score}

    @staticmethod
    @lru_cache
    def get_supported_devices() -> SupportedDevices:
        """Get list of supported devices for inference."""
        core = ov.Core()
        return SupportedDevices(devices=core.available_devices)
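The reworked `except` branch above maps a failure to `DeviceNotFoundError` only when the caller actually passed an unsupported device, instead of querying the now-removed `get_supported_devices()`. A standalone sketch of that pattern, with hypothetical stand-ins (`load_model` and a local `DeviceNotFoundError`) for the real inferencer and exception type:

```python
class DeviceNotFoundError(RuntimeError):
    """Stand-in for services.exceptions.DeviceNotFoundError."""


def load_model(path: str, device: str) -> object:
    # Stand-in for the OpenVINOInferencer constructor; always fails here.
    raise RuntimeError("compile failed")


def load_with_device_check(path: str, device: str | None, supported: list[str]) -> object:
    _device = (device or "AUTO").upper()  # OpenVINO expects uppercase device names
    try:
        return load_model(path, _device)
    except Exception as exc:
        # Blame the device only if the caller supplied one that is unsupported;
        # otherwise surface the original failure unchanged.
        if device and device.casefold() not in supported:
            raise DeviceNotFoundError(f"device not found: {_device}") from exc
        raise


try:
    load_with_device_check("model.xml", "tpu", supported=["cpu", "gpu"])
except DeviceNotFoundError as err:
    print(err)  # device not found: TPU
```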
16 changes: 14 additions & 2 deletions application/backend/src/services/training_service.py
@@ -15,6 +15,7 @@
from repositories.binary_repo import ImageBinaryRepository, ModelBinaryRepository
from services import ModelService
from services.job_service import JobService
from utils.devices import Devices
from utils.experiment_loggers import TrackioLogger


@@ -59,6 +60,7 @@ async def _run_training_job(cls, job: Job, job_service: JobService) -> Model:
        await job_service.update_job_status(job_id=job.id, status=JobStatus.RUNNING, message="Training started")
        project_id = job.project_id
        model_name = job.payload.get("model_name")
        device = job.payload.get("device")
        if model_name is None:
            raise ValueError(f"Job {job.id} payload must contain 'model_name'")

@@ -73,7 +75,7 @@ async def _run_training_job(cls, job: Job, job_service: JobService) -> Model:
        try:
            # Use asyncio.to_thread to keep event loop responsive
            # TODO: Consider ProcessPoolExecutor for true parallelism with multiple jobs
            trained_model = await asyncio.to_thread(cls._train_model, model)
            trained_model = await asyncio.to_thread(cls._train_model, model=model, device=device)
            if trained_model is None:
                raise ValueError("Training failed - model is None")
@@ -94,7 +96,7 @@ async def _run_training_job(cls, job: Job, job_service: JobService) -> Model:
            raise e

    @staticmethod
    def _train_model(model: Model) -> Model | None:
    def _train_model(model: Model, device: str | None = None) -> Model | None:
        """
        Execute CPU-intensive model training using anomalib.

@@ -104,13 +106,22 @@ def _train_model(model: Model) -> Model | None:

        Args:
            model: Model object with training configuration
            device: Device to train on

        Returns:
            Model: Trained model with updated export_path and is_ready=True
        """
        from core.logging import global_log_config
        from core.logging.handlers import LoggerStdoutWriter

        if device and not Devices.is_device_supported_for_training(device):
            raise ValueError(
                f"Device '{device}' is not supported for training. "
                f"Supported devices: {', '.join(Devices.training_devices())}"
            )

        logger.info(f"Training on device: {device or 'auto'}")

        model_binary_repo = ModelBinaryRepository(project_id=model.project_id, model_id=model.id)
        image_binary_repo = ImageBinaryRepository(project_id=model.project_id)
        image_folder_path = image_binary_repo.project_folder_path
@@ -134,6 +145,7 @@ def _train_model(model: Model) -> Model | None:
            default_root_dir=model.export_path,
            logger=[trackio, tensorboard],
            max_epochs=10,
            accelerator=device or "auto",  # avoid passing None to the Trainer
        )

        # Execute training and export
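`accelerator` is the standard Lightning `Trainer` argument, which anomalib's `Engine` appears to forward along with `max_epochs` and `logger`. A minimal sketch of the intended fallback, assuming Lightning is installed (the `device` value is illustrative):

```python
from lightning.pytorch import Trainer

device = None  # taken from the job payload in the real code

# None should fall back to Lightning's automatic selection ("auto"); a string
# such as "cpu" or "cuda" pins training to that accelerator.
trainer = Trainer(max_epochs=1, accelerator=device or "auto", logger=False)
print(trainer.strategy.root_device)  # e.g. cpu
```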
68 changes: 68 additions & 0 deletions application/backend/src/utils/devices.py
@@ -0,0 +1,68 @@
# Copyright (C) 2025 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
from functools import lru_cache

import openvino as ov
from lightning.pytorch.accelerators import AcceleratorRegistry


class Devices:
    """Utility class for device-related operations."""

    @staticmethod
    @lru_cache
    def training_devices() -> list[str]:
        """Get list of supported devices for training."""
        devices = []
        for device_name, device_info in AcceleratorRegistry.items():
            accelerator = device_info["accelerator"]
            if accelerator.is_available():
                devices.append(device_name.casefold())
        return devices

    @staticmethod
    @lru_cache
    def inference_devices() -> list[str]:
        """Get list of supported devices for inference."""
        ov_core = ov.Core()
        return [device.casefold() for device in ov_core.available_devices]

    @classmethod
    @lru_cache
    def _is_device_supported(cls, device_name: str, for_training: bool = False) -> bool:
        """Check if a device is supported for inference or training.

        Args:
            device_name (str): Name of the device to check.
            for_training (bool): If True, check for training devices; otherwise, check for inference devices.

        Returns:
            bool: True if the device is supported, False otherwise.
        """
        device_name = device_name.casefold()
        if for_training:
            return device_name in cls.training_devices()
        return device_name in cls.inference_devices()

    @classmethod
    def is_device_supported_for_inference(cls, device_name: str) -> bool:
        """Check if a device is supported for inference.

        Args:
            device_name (str): Name of the device to check.

        Returns:
            bool: True if the device is supported for inference, False otherwise.
        """
        return cls._is_device_supported(device_name, for_training=False)

    @classmethod
    def is_device_supported_for_training(cls, device_name: str) -> bool:
        """Check if a device is supported for training.

        Args:
            device_name (str): Name of the device to check.

        Returns:
            bool: True if the device is supported for training, False otherwise.
        """
        return cls._is_device_supported(device_name, for_training=True)
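End to end, the helper behaves as below; a short sketch assuming an environment with both OpenVINO and Lightning installed (device lists vary per machine).

```python
from utils.devices import Devices

# Both lists are casefolded and, thanks to lru_cache, computed once per process.
print(Devices.inference_devices())  # e.g. ['cpu', 'gpu'] from ov.Core().available_devices
print(Devices.training_devices())   # e.g. ['cpu', 'cuda'] from Lightning's AcceleratorRegistry

# Lookups casefold their input, so matching is case-insensitive.
assert Devices.is_device_supported_for_inference("CPU") == Devices.is_device_supported_for_inference("cpu")
```

One side effect of `@lru_cache` on the classmethod is that the class becomes part of the cache key and is held for the life of the process; harmless for a module-level utility like this.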
2 changes: 1 addition & 1 deletion application/backend/src/workers/inference.py
@@ -134,7 +134,7 @@ async def run_loop(self) -> None:  # noqa: C901, PLR0912, PLR0915
            # Refresh loaded model reference if changed
            if self._loaded_model is None or self._loaded_model.id != active_model.id:
                self._loaded_model = active_model
                logger.info("Using model '%s' (%s) for inference", self._loaded_model.name, self._loaded_model.id)
                logger.info(f"Using model '{self._loaded_model.name}' ({self._loaded_model.id}) for inference")

            await self._handle_model_reload()

2 changes: 2 additions & 0 deletions application/backend/tests/unit/conftest.py
@@ -51,6 +51,7 @@ def fxt_job_payload(fxt_project):
    return TrainJobPayload(
        project_id=fxt_project.id,
        model_name="padim",
        device=None,
    )


@@ -76,6 +77,7 @@ def fxt_model(fxt_project):
format="openvino",
is_ready=True,
export_path="/path/to/model",
train_job_id=uuid.uuid4(),
)


1 change: 1 addition & 0 deletions application/backend/tests/unit/endpoints/test_models.py
@@ -20,6 +20,7 @@ def fxt_model(fxt_project):
name="test_model",
project_id=fxt_project.id,
export_path="/path/to/model",
train_job_id=uuid4(),
)


@@ -264,7 +264,13 @@ def test_update_pipeline_running_to_running_model_change(
        self, fxt_pipeline_service, fxt_pipeline, fxt_pipeline_repository, fxt_model_service
    ):
        """Test updating running pipeline with model change."""
        new_model = Model(id=uuid.uuid4(), project_id=fxt_pipeline.project_id, name="new_model", format="openvino")
        new_model = Model(
            id=uuid.uuid4(),
            project_id=fxt_pipeline.project_id,
            name="new_model",
            format="openvino",
            train_job_id=uuid.uuid4(),
        )
        updated_pipeline = fxt_pipeline.model_copy(update={"model": new_model, "model_id": new_model.id})
        fxt_pipeline_repository.get_by_id.return_value = fxt_pipeline
        fxt_pipeline_repository.update.return_value = updated_pipeline
@@ -191,7 +191,7 @@ def test_train_pending_job_cleanup_on_failure(

        with patch("services.training_service.asyncio.to_thread") as mock_to_thread:
            # Mock the training to succeed first, setting export_path, then fail
            def mock_train_model(cls, model):
            def mock_train_model(cls, model, device=None):
                model.export_path = "/path/to/model"
                raise Exception("Training failed")
