1515from repositories .binary_repo import ImageBinaryRepository , ModelBinaryRepository
1616from services import ModelService
1717from services .job_service import JobService
18+ from utils .devices import Devices
1819from utils .experiment_loggers import TrackioLogger
1920
2021
@@ -59,6 +60,7 @@ async def _run_training_job(cls, job: Job, job_service: JobService) -> Model:
5960 await job_service .update_job_status (job_id = job .id , status = JobStatus .RUNNING , message = "Training started" )
6061 project_id = job .project_id
6162 model_name = job .payload .get ("model_name" )
63+ device = job .payload .get ("device" )
6264 if model_name is None :
6365 raise ValueError (f"Job { job .id } payload must contain 'model_name'" )
6466
@@ -73,7 +75,7 @@ async def _run_training_job(cls, job: Job, job_service: JobService) -> Model:
7375 try :
7476 # Use asyncio.to_thread to keep event loop responsive
7577 # TODO: Consider ProcessPoolExecutor for true parallelism with multiple jobs
76- trained_model = await asyncio .to_thread (cls ._train_model , model )
78+ trained_model = await asyncio .to_thread (cls ._train_model , model = model , device = device )
7779 if trained_model is None :
7880 raise ValueError ("Training failed - model is None" )
7981
@@ -94,7 +96,7 @@ async def _run_training_job(cls, job: Job, job_service: JobService) -> Model:
9496 raise e
9597
9698 @staticmethod
97- def _train_model (model : Model ) -> Model | None :
99+ def _train_model (model : Model , device : str | None = None ) -> Model | None :
98100 """
99101 Execute CPU-intensive model training using anomalib.
100102
@@ -104,13 +106,22 @@ def _train_model(model: Model) -> Model | None:
104106
105107 Args:
106108 model: Model object with training configuration
109+ device: Device to train on
107110
108111 Returns:
109112 Model: Trained model with updated export_path and is_ready=True
110113 """
111114 from core .logging import global_log_config
112115 from core .logging .handlers import LoggerStdoutWriter
113116
117+ if device and not Devices .is_device_supported_for_training (device ):
118+ raise ValueError (
119+ f"Device '{ device } ' is not supported for training. "
120+ f"Supported devices: { ', ' .join (Devices .training_devices ())} "
121+ )
122+
123+ logger .info (f"Training on device: { device or 'auto' } " )
124+
114125 model_binary_repo = ModelBinaryRepository (project_id = model .project_id , model_id = model .id )
115126 image_binary_repo = ImageBinaryRepository (project_id = model .project_id )
116127 image_folder_path = image_binary_repo .project_folder_path
@@ -134,6 +145,7 @@ def _train_model(model: Model) -> Model | None:
134145 default_root_dir = model .export_path ,
135146 logger = [trackio , tensorboard ],
136147 max_epochs = 10 ,
148+ accelerator = device ,
137149 )
138150
139151 # Execute training and export
0 commit comments