
Commit 98df691

yao-matrix, ydshieh, and SunMarc authored and committed
extend fp_quant cases to xpu (huggingface#41833)
* extend fp_quant UTs to xpu

  Signed-off-by: Yao, Matrix <matrix.yao@intel.com>

* fix style

  Signed-off-by: Yao, Matrix <matrix.yao@intel.com>

* Update tests/quantization/fp_quant_integration/test_fp_quant.py

  Co-authored-by: Yih-Dar <2521628+ydshieh@users.noreply.github.com>

---------

Signed-off-by: Yao, Matrix <matrix.yao@intel.com>
Co-authored-by: Yih-Dar <2521628+ydshieh@users.noreply.github.com>
Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
1 parent aa4104f commit 98df691

File tree

2 files changed (+19, -18 lines)


src/transformers/quantizers/quantizer_fp_quant.py

Lines changed: 3 additions & 3 deletions
@@ -20,7 +20,7 @@
 if TYPE_CHECKING:
     from ..modeling_utils import PreTrainedModel
 
-from ..utils import is_fp_quant_available, is_qutlass_available, is_torch_available, logging
+from ..utils import is_fp_quant_available, is_qutlass_available, is_torch_available, is_torch_xpu_available, logging
 from ..utils.quantization_config import QuantizationConfigMixin
 
 
@@ -45,9 +45,9 @@ def __init__(self, quantization_config: QuantizationConfigMixin, **kwargs):
         self.quantization_config = quantization_config
 
     def validate_environment(self, device_map, **kwargs):
-        if not torch.cuda.is_available():
+        if not torch.cuda.is_available() and not is_torch_xpu_available():
             raise NotImplementedError(
-                "FPQuant quantization is only supported on GPU. Please use a different quantizer."
+                "FPQuant quantization is only supported on GPU or Intel XPU. Please use a different quantizer."
             )
 
         if not is_qutlass_available() and not self.quantization_config.pseudoquantization:
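For context, the reworked guard accepts either backend before continuing. Below is a minimal standalone sketch of the same check, assuming a transformers build that exposes is_torch_xpu_available from transformers.utils (the quantizer itself uses the package-relative import ..utils); the pick_accelerator helper is ours for illustration, not part of the quantizer:

import torch
from transformers.utils import is_torch_xpu_available

def pick_accelerator() -> str:
    # Mirror the order of the guard in validate_environment: CUDA first,
    # then the Intel XPU path that this commit enables.
    if torch.cuda.is_available():
        return "cuda"
    if is_torch_xpu_available():
        return "xpu"
    raise NotImplementedError(
        "FPQuant quantization is only supported on GPU or Intel XPU. "
        "Please use a different quantizer."
    )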

tests/quantization/fp_quant_integration/test_fp_quant.py

Lines changed: 16 additions & 15 deletions
@@ -22,14 +22,14 @@
     require_accelerate,
     require_fp_quant,
     require_qutlass,
-    require_torch_gpu,
-    require_torch_multi_gpu,
+    require_torch_accelerator,
+    require_torch_multi_accelerator,
     slow,
     torch_device,
 )
 
 
-@require_torch_gpu
+@require_torch_accelerator
 class FPQuantConfigTest(unittest.TestCase):
     def test_to_dict(self):
         """
@@ -53,7 +53,7 @@ def test_from_dict(self):
 
 
 @slow
-@require_torch_gpu
+@require_torch_accelerator
 @require_fp_quant
 @require_accelerate
 class FPQuantBaseTest(unittest.TestCase):
@@ -64,7 +64,7 @@ class FPQuantBaseTest(unittest.TestCase):
 
     EXPECTED_OUTPUT = "1 2 3 4 5 6"
 
-    device_map = "cuda"
+    device_map = torch_device
 
     @classmethod
     def getQuantizationConfig(cls):
@@ -77,10 +77,10 @@ def setUpClass(cls):
         Setup quantized model
         """
 
-        quantization_config = cls.getQuantizationConfig()
+        cls.quantization_config = cls.getQuantizationConfig()
         cls.tokenizer = AutoTokenizer.from_pretrained(cls.model_name)
         cls.quantized_model = AutoModelForCausalLM.from_pretrained(
-            cls.model_name, device_map=cls.device_map, quantization_config=quantization_config
+            cls.model_name, device_map=cls.device_map, quantization_config=cls.quantization_config
         )
 
     def tearDown(self):
@@ -111,24 +111,25 @@ def test_save_pretrained(self):
         output = model.generate(**input_ids, max_new_tokens=self.max_new_tokens)
         self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)
 
-    @require_torch_multi_gpu
-    def test_quantized_model_multi_gpu(self):
+    @require_torch_multi_accelerator
+    def test_quantized_model_multi_accelerator(self):
         """
-        Simple test that checks if the quantized model is working properly with multiple GPUs
-        set CUDA_VISIBLE_DEVICES=0,1 if you have more than 2 GPUs
+        Simple test that checks if the quantized model is working properly with multiple accelerators.
+        Set CUDA_VISIBLE_DEVICES=0,1 if you have more than 2 CUDA GPUs. Or set ZE_AFFINITY_MASK=0,1
+        if you have more than 2 Intel XPUs.
         """
         input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(torch_device)
-        quantization_config = FPQuantConfig()
+
         quantized_model = AutoModelForCausalLM.from_pretrained(
-            self.model_name, device_map="auto", quantization_config=quantization_config
+            self.model_name, device_map="auto", quantization_config=self.quantization_config
         )
         self.assertTrue(set(quantized_model.hf_device_map.values()) == {0, 1})
 
         output = quantized_model.generate(**input_ids, max_new_tokens=self.max_new_tokens)
         self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)
 
-    @require_torch_multi_gpu
-    def test_save_pretrained_multi_gpu(self):
+    @require_torch_multi_accelerator
+    def test_save_pretrained_multi_accelerator(self):
         """
         Simple test that checks if the quantized model is working properly after being saved and loaded
         """
