diff --git a/ppcls/engine/engine.py b/ppcls/engine/engine.py index 23ad0565d..2cc1dd248 100644 --- a/ppcls/engine/engine.py +++ b/ppcls/engine/engine.py @@ -111,7 +111,7 @@ def __init__(self, config, mode="train"): # set device assert self.config["Global"]["device"] in [ "cpu", "gpu", "xpu", "npu", "mlu", "dcu", "ascend", "intel_gpu", - "mps", "gcu" + "mps", "gcu", "iluvatar_gpu" ] self.device = paddle.set_device(self.config["Global"]["device"]) logger.info('train with paddle {} and device {}'.format(