Skip to content

Commit ffdd10f

Browse files
shanjiaz and kylesayrs authored
Allow compression on meta device (#39039)
* disable gradient calculation for int weights

  Signed-off-by: shanjiaz <zsjwpianpian@gmail.com>

* Update src/transformers/quantizers/quantizer_compressed_tensors.py

  Co-authored-by: Kyle Sayers <kylesayrs@gmail.com>

* updated model procession before/after weight loading

  Signed-off-by: shanjiaz <zsjwpianpian@gmail.com>

* fix style

  Signed-off-by: shanjiaz <zsjwpianpian@gmail.com>

* reformat

  Signed-off-by: shanjiaz <zsjwpianpian@gmail.com>

* fix style

  Signed-off-by: shanjiaz <zsjwpianpian@gmail.com>

---------

Signed-off-by: shanjiaz <zsjwpianpian@gmail.com>
Co-authored-by: Kyle Sayers <kylesayrs@gmail.com>
1 parent f0e7781 commit ffdd10f

File tree

1 file changed

+8
-60
lines changed

1 file changed

+8
-60
lines changed

src/transformers/quantizers/quantizer_compressed_tensors.py

Lines changed: 8 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,6 @@
1313
# limitations under the License.
1414

1515

16-
import os
17-
import re
18-
1916
from ..utils import is_compressed_tensors_available, is_torch_available, logging
2017
from ..utils.quantization_config import CompressedTensorsConfig
2118
from .base import HfQuantizer
@@ -55,45 +52,6 @@ def __init__(self, quantization_config: CompressedTensorsConfig, **kwargs):
5552
self.run_compressed = quantization_config.run_compressed
5653
self.quantization_config = quantization_config
5754

58-
def update_missing_keys_after_loading(self, model, missing_keys: list[str], prefix: str) -> list[str]:
59-
"""
60-
Update missing keys after loading the model. This is necessary for compressed tensors
61-
to load the model correctly. We expect weights to be present in missing keys.
62-
The weight's are re-constructed by ModelCompressor in _process_model_after_weight_loading
63-
64-
This function cleans up expected missing keys and returns the remaining missing keys
65-
"""
66-
67-
if self.run_compressed:
68-
return missing_keys
69-
70-
# We expect some keys to be missing for
71-
# compressed models
72-
# This is fine as the weights are reconstructed by ModelCompressor
73-
# in _process_model_after_weight_loading
74-
75-
expected_missing_keys = self.compressor.get_missing_module_keys(model)
76-
return [
77-
key for key in missing_keys if not any(re.match(f".*{pattern}", key) for pattern in expected_missing_keys)
78-
]
79-
80-
def update_unexpected_keys(self, model, unexpected_keys: list[str], prefix: str) -> list[str]:
81-
"""
82-
Override this method if you want to adjust the `unexpected_keys`.
83-
84-
Args:
85-
unexpected_keys (`list[str]`, *optional*):
86-
The list of unexpected keys in the checkpoint compared to the state dict of the model
87-
"""
88-
89-
if self.run_compressed:
90-
return unexpected_keys
91-
92-
# We expect some unexpected keys in model
93-
# safetensors file for compressed models
94-
keys_to_ignore = self.compressor.get_unexpected_file_keys(model)
95-
return [key for key in unexpected_keys if not any(re.match(f".*{pattern}", key) for pattern in keys_to_ignore)]
96-
9755
def validate_environment(self, *args, **kwargs):
9856
if not is_compressed_tensors_available():
9957
raise ImportError(
@@ -117,31 +75,21 @@ def _process_model_before_weight_loading(self, model, **kwargs):
11775

11876
ct_quantization_config = self.compressor.quantization_config
11977

120-
if self.run_compressed:
121-
apply_quantization_config(model, ct_quantization_config, run_compressed=True)
122-
elif not self.quantization_config.is_quantization_compressed:
123-
apply_quantization_config(model, ct_quantization_config)
78+
# Always initialize compressed wrappers to match the checkpoint
79+
apply_quantization_config(model, ct_quantization_config, self.run_compressed)
80+
if (
81+
self.quantization_config.is_quantization_compressed
82+
or self.quantization_config.is_sparsification_compressed
83+
):
84+
self.compressor.compress_model(model=model)
12485

12586
def _process_model_after_weight_loading(self, model, **kwargs):
12687
"""Decompress loaded model if necessary - need for qat"""
12788

12889
if (
12990
self.quantization_config.is_quantization_compressed and not self.run_compressed
13091
) or self.quantization_config.is_sparsification_compressed:
131-
config = kwargs.get("config")
132-
cache_path = config._name_or_path
133-
134-
if not os.path.exists(cache_path):
135-
from transformers.utils import cached_file
136-
137-
config_file_path = cached_file(cache_path, "config.json")
138-
cache_path = os.path.sep.join(config_file_path.split(os.path.sep)[:-1])
139-
140-
if self.quantization_config.is_quantization_compressed and not self.run_compressed:
141-
from compressed_tensors.quantization import QuantizationStatus
142-
143-
self.compressor.quantization_config.quantization_status = QuantizationStatus.FROZEN
144-
self.compressor.decompress(model_path=cache_path, model=model)
92+
self.compressor.decompress_model(model=model)
14593

14694
def update_tp_plan(self, config):
14795
additional_plan = {

0 commit comments

Comments (0)