 # limitations under the License.
 
 
-import os
-import re
-
 from ..utils import is_compressed_tensors_available, is_torch_available, logging
 from ..utils.quantization_config import CompressedTensorsConfig
 from .base import HfQuantizer
@@ -55,45 +52,6 @@ def __init__(self, quantization_config: CompressedTensorsConfig, **kwargs):
         self.run_compressed = quantization_config.run_compressed
         self.quantization_config = quantization_config
 
-    def update_missing_keys_after_loading(self, model, missing_keys: list[str], prefix: str) -> list[str]:
-        """
-        Update missing keys after loading the model. This is necessary for compressed tensors
-        to load the model correctly. We expect weights to be present in missing keys.
-        The weight's are re-constructed by ModelCompressor in _process_model_after_weight_loading
-
-        This function cleans up expected missing keys and returns the remaining missing keys
-        """
-
-        if self.run_compressed:
-            return missing_keys
-
-        # We expect some keys to be missing for
-        # compressed models
-        # This is fine as the weights are reconstructed by ModelCompressor
-        # in _process_model_after_weight_loading
-
-        expected_missing_keys = self.compressor.get_missing_module_keys(model)
-        return [
-            key for key in missing_keys if not any(re.match(f".*{pattern}", key) for pattern in expected_missing_keys)
-        ]
-
-    def update_unexpected_keys(self, model, unexpected_keys: list[str], prefix: str) -> list[str]:
-        """
-        Override this method if you want to adjust the `unexpected_keys`.
-
-        Args:
-            unexpected_keys (`list[str]`, *optional*):
-                The list of unexpected keys in the checkpoint compared to the state dict of the model
-        """
-
-        if self.run_compressed:
-            return unexpected_keys
-
-        # We expect some unexpected keys in model
-        # safetensors file for compressed models
-        keys_to_ignore = self.compressor.get_unexpected_file_keys(model)
-        return [key for key in unexpected_keys if not any(re.match(f".*{pattern}", key) for pattern in keys_to_ignore)]
-
     def validate_environment(self, *args, **kwargs):
         if not is_compressed_tensors_available():
             raise ImportError(
@@ -117,31 +75,21 @@ def _process_model_before_weight_loading(self, model, **kwargs):
 
         ct_quantization_config = self.compressor.quantization_config
 
-        if self.run_compressed:
-            apply_quantization_config(model, ct_quantization_config, run_compressed=True)
-        elif not self.quantization_config.is_quantization_compressed:
-            apply_quantization_config(model, ct_quantization_config)
+        # Always initialize compressed wrappers to match the checkpoint
+        apply_quantization_config(model, ct_quantization_config, self.run_compressed)
+        if (
+            self.quantization_config.is_quantization_compressed
+            or self.quantization_config.is_sparsification_compressed
+        ):
+            self.compressor.compress_model(model=model)
 
     def _process_model_after_weight_loading(self, model, **kwargs):
         """Decompress loaded model if necessary - need for qat"""
 
         if (
             self.quantization_config.is_quantization_compressed and not self.run_compressed
         ) or self.quantization_config.is_sparsification_compressed:
-            config = kwargs.get("config")
-            cache_path = config._name_or_path
-
-            if not os.path.exists(cache_path):
-                from transformers.utils import cached_file
-
-                config_file_path = cached_file(cache_path, "config.json")
-                cache_path = os.path.sep.join(config_file_path.split(os.path.sep)[:-1])
-
-            if self.quantization_config.is_quantization_compressed and not self.run_compressed:
-                from compressed_tensors.quantization import QuantizationStatus
-
-                self.compressor.quantization_config.quantization_status = QuantizationStatus.FROZEN
-            self.compressor.decompress(model_path=cache_path, model=model)
+            self.compressor.decompress_model(model=model)
 
     def update_tp_plan(self, config):
         additional_plan = {
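
A minimal usage sketch of the load path these hooks cover (not taken from the diff; the checkpoint id is a placeholder): with `run_compressed=False`, the quantizer applies the compressed wrappers and calls `compress_model` before the checkpoint weights are loaded, then `decompress_model` restores dense weights in memory after loading, with no lookup of the checkpoint path on disk.

# Sketch: load a compressed-tensors checkpoint and decompress it in memory.
# "org/model-w4a16-compressed" is a hypothetical model id, not a real repo.
from transformers import AutoModelForCausalLM, CompressedTensorsConfig

model = AutoModelForCausalLM.from_pretrained(
    "org/model-w4a16-compressed",
    quantization_config=CompressedTensorsConfig(run_compressed=False),  # decompress after load
    torch_dtype="auto",
)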