diff --git a/paddlex/inference/models/doc_vlm/predictor.py b/paddlex/inference/models/doc_vlm/predictor.py index 9d252b454..4c9591ac8 100644 --- a/paddlex/inference/models/doc_vlm/predictor.py +++ b/paddlex/inference/models/doc_vlm/predictor.py @@ -27,7 +27,7 @@ from ....modules.doc_vlm.model_list import MODELS from ....utils import logging from ....utils.deps import require_genai_client_plugin -from ....utils.device import TemporaryDeviceChanger +from ....utils.device import TemporaryDeviceChanger, constr_device from ....utils.env import get_device_type from ...common.batch_sampler import DocVLMBatchSampler from ..base import BasePredictor @@ -56,6 +56,14 @@ def __init__(self, *args, **kwargs): import paddle self.device = kwargs.get("device", None) + if self.device is None and self.pp_option is not None: + if self.pp_option.device_type is not None: + device_ids = ( + None + if self.pp_option.device_id is None + else [self.pp_option.device_id] + ) + self.device = constr_device(self.pp_option.device_type, device_ids) self.dtype = ( "bfloat16" if ("npu" in get_device_type() or paddle.amp.is_bfloat16_supported())