src/transformers/quantizers/quantizer_fp_quant.py (6 changes: 3 additions & 3 deletions)

@@ -20,7 +20,7 @@
 if TYPE_CHECKING:
     from ..modeling_utils import PreTrainedModel

-from ..utils import is_fp_quant_available, is_qutlass_available, is_torch_available, logging
+from ..utils import is_fp_quant_available, is_qutlass_available, is_torch_available, is_torch_xpu_available, logging
 from ..utils.quantization_config import QuantizationConfigMixin

@@ -45,9 +45,9 @@ def __init__(self, quantization_config: QuantizationConfigMixin, **kwargs):
         self.quantization_config = quantization_config

     def validate_environment(self, device_map, **kwargs):
-        if not torch.cuda.is_available():
+        if not torch.cuda.is_available() and not is_torch_xpu_available():
             raise NotImplementedError(
-                "FPQuant quantization is only supported on GPU. Please use a different quantizer."
+                "FPQuant quantization is only supported on GPU or Intel XPU. Please use a different quantizer."
             )

         if not is_qutlass_available() and not self.quantization_config.pseudoquantization:
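For context, a minimal standalone sketch of the device gate this hunk introduces; `pick_device` is a hypothetical helper, not part of the quantizer, and assumes `is_torch_xpu_available` behaves as exported from `transformers.utils`:

```python
import torch

from transformers.utils import is_torch_xpu_available


def pick_device() -> str:
    """Hypothetical helper mirroring the validate_environment gate above."""
    # Prefer CUDA when present, fall back to Intel XPU, otherwise fail
    # the same way the quantizer does.
    if torch.cuda.is_available():
        return "cuda"
    if is_torch_xpu_available():
        return "xpu"
    raise NotImplementedError(
        "FPQuant quantization is only supported on GPU or Intel XPU. Please use a different quantizer."
    )
```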
tests/quantization/fp_quant_integration/test_fp_quant.py (31 changes: 16 additions & 15 deletions)

@@ -22,14 +22,14 @@
     require_accelerate,
     require_fp_quant,
     require_qutlass,
-    require_torch_gpu,
-    require_torch_multi_gpu,
+    require_torch_accelerator,
+    require_torch_multi_accelerator,
     slow,
     torch_device,
 )


-@require_torch_gpu
+@require_torch_accelerator
 class FPQuantConfigTest(unittest.TestCase):
     def test_to_dict(self):
         """
@@ -53,7 +53,7 @@ def test_from_dict(self):


 @slow
-@require_torch_gpu
+@require_torch_accelerator
 @require_fp_quant
 @require_accelerate
 class FPQuantBaseTest(unittest.TestCase):
@@ -64,7 +64,7 @@ class FPQuantBaseTest(unittest.TestCase):

     EXPECTED_OUTPUT = "1 2 3 4 5 6"

-    device_map = "cuda"
+    device_map = torch_device

     @classmethod
     def getQuantizationConfig(cls):
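A short illustration of why `device_map = torch_device` makes the fixture portable; this assumes the usual behavior of `transformers.testing_utils.torch_device`, which resolves to the first available accelerator backend:

```python
from transformers.testing_utils import torch_device

# torch_device resolves to an available backend ("cuda", "xpu", ...) and
# falls back to "cpu", so the same test fixture runs unchanged on NVIDIA
# GPUs, Intel XPUs, or CPU-only machines.
device_map = torch_device
```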
@@ -77,10 +77,10 @@ def setUpClass(cls):
         Setup quantized model
         """

-        quantization_config = cls.getQuantizationConfig()
+        cls.quantization_config = cls.getQuantizationConfig()
         cls.tokenizer = AutoTokenizer.from_pretrained(cls.model_name)
         cls.quantized_model = AutoModelForCausalLM.from_pretrained(
-            cls.model_name, device_map=cls.device_map, quantization_config=quantization_config
+            cls.model_name, device_map=cls.device_map, quantization_config=cls.quantization_config
         )

     def tearDown(self):

[Collaborator review comment on the `cls.quantization_config` line: maybe @SunMarc could have a second look here.]
@@ -111,24 +111,25 @@ def test_save_pretrained(self):
         output = model.generate(**input_ids, max_new_tokens=self.max_new_tokens)
         self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)

-    @require_torch_multi_gpu
-    def test_quantized_model_multi_gpu(self):
+    @require_torch_multi_accelerator
+    def test_quantized_model_multi_accelerator(self):
         """
-        Simple test that checks if the quantized model is working properly with multiple GPUs
-        set CUDA_VISIBLE_DEVICES=0,1 if you have more than 2 GPUs
+        Simple test that checks if the quantized model is working properly with multiple accelerators
+        set CUDA_VISIBLE_DEVICES=0,1 if you have more than 2 CUDA GPUs, or set ZE_AFFINITY_MASK=0,1
+        if you have more than 2 Intel XPUs
         """
         input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(torch_device)
-        quantization_config = FPQuantConfig()

         quantized_model = AutoModelForCausalLM.from_pretrained(
-            self.model_name, device_map="auto", quantization_config=quantization_config
+            self.model_name, device_map="auto", quantization_config=self.quantization_config
         )
         self.assertTrue(set(quantized_model.hf_device_map.values()) == {0, 1})

         output = quantized_model.generate(**input_ids, max_new_tokens=self.max_new_tokens)
         self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)

-    @require_torch_multi_gpu
-    def test_save_pretrained_multi_gpu(self):
+    @require_torch_multi_accelerator
+    def test_save_pretrained_multi_accelerator(self):
         """
         Simple test that checks if the quantized model is working properly after being saved and loaded
         """
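Taken together, the test changes amount to the following hedged usage sketch; the checkpoint id is a placeholder and `FPQuantConfig()` is assumed to work with its defaults, as in the test above:

```python
# Illustrative only: hypothetical checkpoint id, default FPQuantConfig.
# For the multi-accelerator case, restrict visible devices before launch:
#   CUDA_VISIBLE_DEVICES=0,1 python demo.py   # NVIDIA GPUs
#   ZE_AFFINITY_MASK=0,1 python demo.py       # Intel XPUs
from transformers import AutoModelForCausalLM, AutoTokenizer, FPQuantConfig
from transformers.testing_utils import torch_device

model_name = "org/fp-quant-model"  # hypothetical FPQuant-compatible checkpoint

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",  # shards across all visible accelerators
    quantization_config=FPQuantConfig(),
)

inputs = tokenizer("1 2 3 4 5", return_tensors="pt").to(torch_device)
output = model.generate(**inputs, max_new_tokens=6)
print(tokenizer.decode(output[0], skip_special_tokens=True))
```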