
Commit 2fbd25c

SunMarc authored and Cyrilvallez committed

Fix bnb fsdp loading for pre-quantized checkpoint (#41415)

* fix
* fix
* get_param_name
* fix device name
1 parent a92b1e8 commit 2fbd25c

File tree

5 files changed (+32, -25 lines)
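
For context, the scenario this commit fixes is loading a checkpoint that was already quantized with bitsandbytes 4-bit while the model is sharded with FSDP. Below is a minimal sketch of that load path, not part of the diff: the repo id is hypothetical, and FSDP is assumed to be enabled outside the script (for example via accelerate launch with an FSDP config), which is what makes is_fsdp_enabled() return True inside _load_state_dict_into_meta_model.

# Hypothetical example of the affected load path (not part of the diff).
# Run under an FSDP-enabled launcher, e.g.:
#   accelerate launch --config_file fsdp_config.yaml load.py
from transformers import AutoModelForCausalLM

model_id = "my-org/llama-3-8b-bnb-4bit"  # hypothetical pre-quantized bnb checkpoint
# The quantization config ships with the checkpoint, so no BitsAndBytesConfig is passed here.
model = AutoModelForCausalLM.from_pretrained(model_id)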

src/transformers/modeling_utils.py

Lines changed: 8 additions & 12 deletions

@@ -777,21 +777,17 @@ def _load_state_dict_into_meta_model(
         # and then cast it to CPU to avoid excessive memory usage on each GPU
         # in comparison to the sharded model across GPUs.
         if is_fsdp_enabled() or is_deepspeed_zero3_enabled():
-            param_name = hf_quantizer.update_param_name(param_name)
+            param_name = hf_quantizer.get_param_name(param_name)
             module, param_type = get_module_from_name(model, param_name)
             value = getattr(module, param_type)
-            # special case for gpt_oss model, we wait for the param to be leave the meta device before casting it to cpu
-            if model.config.model_type == "gpt_oss" and value.device.type == "meta":
+            # We need to wait until the quantized value is created
+            if value.device.type == "meta":
                 continue
-            param_to = "cpu"
-            if is_fsdp_enabled() and not is_local_dist_rank_0():
-                param_to = "meta"
-            val_kwargs = {}
-            if (hasattr(module, "weight") and module.weight.__class__.__name__ == "Int8Params") or (
-                value.dtype == torch.uint8 or value.dtype == torch.int8
-            ):
+            val_kwargs = value.__dict__
+            if not value.is_floating_point():
                 val_kwargs["requires_grad"] = False
-            value = type(value)(value.data.to(param_to), **val_kwargs, **value.__dict__)
+            device = "meta" if is_fsdp_enabled() and not is_local_dist_rank_0() else "cpu"
+            value = type(value)(value.data.to(device), **val_kwargs)
             setattr(module, param_type, value)

         # Remove the param from the state dict if it was not loaded on the fly to avoid wasting memory

@@ -6070,7 +6066,7 @@ def caching_allocator_warmup(model: PreTrainedModel, expanded_device_map: dict,
         # For example in the case of MXFP4 quantization, we need to update the param name to the original param name
         # because the checkpoint contains blocks, and scales, but since we are dequantizing, we need to use the original param name
         if hf_quantizer is not None:
-            param_name = hf_quantizer.update_param_name(param_name)
+            param_name = hf_quantizer.get_param_name(param_name)

         try:
             param = model.get_parameter_or_buffer(param_name)
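
To make the new re-wrapping pattern concrete, here is a small, self-contained sketch. Assumption: a plain torch.nn.Parameter stands in for the bitsandbytes Params4bit/Int8Params objects, whose extra attributes (quant_state and friends) would travel through __dict__ the same way.

import torch

# Stand-in for the quantized parameter: an integer tensor wrapped as a Parameter.
value = torch.nn.Parameter(
    torch.randint(0, 255, (4, 4), dtype=torch.uint8), requires_grad=False
)

# Carry over any extra instance attributes (empty for a plain Parameter,
# but holds quant_state & co. for the bitsandbytes parameter classes).
val_kwargs = dict(value.__dict__)
# Integer tensors cannot require grad, so force requires_grad=False.
if not value.is_floating_point():
    val_kwargs["requires_grad"] = False

# Non-rank-0 FSDP processes keep the value on "meta" to avoid duplicating it in RAM.
is_rank_0 = True  # placeholder for the is_fsdp_enabled()/is_local_dist_rank_0() check
device = "cpu" if is_rank_0 else "meta"
new_value = type(value)(value.data.to(device), **val_kwargs)
print(type(new_value), new_value.device, new_value.requires_grad)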

src/transformers/quantizers/base.py

Lines changed: 1 addition & 1 deletion

@@ -283,7 +283,7 @@ def _dequantize(self, model):
             f"{self.quantization_config.quant_method} has no implementation of `dequantize`, please raise an issue on GitHub."
         )

-    def update_param_name(self, param_name: str) -> str:
+    def get_param_name(self, param_name: str) -> str:
        """
        Override this method if you want to adjust the `param_name`.
        """

src/transformers/quantizers/quantizer_bnb_4bit.py

Lines changed: 16 additions & 5 deletions

@@ -160,6 +160,19 @@ def param_needs_quantization(self, model: "PreTrainedModel", param_name: str, **
         module, name = get_module_from_name(model, param_name)
         return isinstance(module, bnb.nn.Linear4bit) and name != "bias"

+    def get_param_name(self, param_name: str) -> str:
+        """
+        Get the right param_name in order to get the module associated with the param.
+        This is useful for quantized stats like absmax or quant_map, as we need to update the param_name to get the module since they are stored in ...weight.absmax.
+        """
+        if self.pre_quantized:
+            # We need to get the param name of quantized weights and not its components. Otherwise, we won't be able to get the nn.Module associated.
+            if any(param_name.endswith(x) for x in self.bnb_keys):
+                param_name = (
+                    param_name.rsplit(".", 1)[0] if "quant_state." not in param_name else param_name.rsplit(".", 2)[0]
+                )
+        return param_name
+
     def create_quantized_param(
         self,
         model: "PreTrainedModel",

@@ -170,12 +183,10 @@ def create_quantized_param(
     ):
         import bitsandbytes as bnb

-        is_quant_stat = any(param_name.endswith(x) for x in self.bnb_keys)
         full_name = param_name
-        if is_quant_stat:
-            param_name = (
-                param_name.rsplit(".", 1)[0] if "quant_state." not in param_name else param_name.rsplit(".", 2)[0]
-            )
+
+        # update param name to get the weights instead of the quantized stats
+        param_name = self.get_param_name(param_name)
         module, tensor_name = get_module_from_name(model, param_name)

         # `torch.Tensor.to(<int num>)` is not supported by `torch_npu` (see this [issue](https://github.com/Ascend/pytorch/issues/16)).
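
As an illustration of the renaming logic added above, here is a standalone sketch. The bnb_keys list is an assumption mirroring typical bitsandbytes 4-bit state-dict suffixes (absmax, quant_map, serialized quant_state entries); the real values live on the quantizer instance.

# Hypothetical, standalone version of the suffix-stripping logic above.
bnb_keys = [
    "absmax",
    "quant_map",
    "nested_absmax",
    "nested_quant_map",
    "quant_state.bitsandbytes__nf4",
]

def get_param_name(param_name: str) -> str:
    # Map a quantization-stat key back to the weight it belongs to, so that
    # get_module_from_name() can resolve the owning nn.Module.
    if any(param_name.endswith(x) for x in bnb_keys):
        return param_name.rsplit(".", 1)[0] if "quant_state." not in param_name else param_name.rsplit(".", 2)[0]
    return param_name

print(get_param_name("model.layers.0.self_attn.q_proj.weight.absmax"))
# -> model.layers.0.self_attn.q_proj.weight
print(get_param_name("model.layers.0.self_attn.q_proj.weight.quant_state.bitsandbytes__nf4"))
# -> model.layers.0.self_attn.q_proj.weight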

src/transformers/quantizers/quantizer_mxfp4.py

Lines changed: 1 addition & 1 deletion

@@ -365,7 +365,7 @@ def update_ep_plan(self, config):
         )
         return config

-    def update_param_name(self, param_name: str) -> str:
+    def get_param_name(self, param_name: str) -> str:
         if self.quantization_config.dequantize:
             if "_blocks" in param_name:
                 return param_name.replace("_blocks", "")

tests/quantization/mxfp4/test_mxfp4.py

Lines changed: 6 additions & 6 deletions

@@ -265,7 +265,7 @@ def test_update_expected_keys(self):

         self.assertEqual(set(updated_keys), set(expected_updated))

-    def test_update_param_name_dequantize(self):
+    def test_get_param_name_dequantize(self):
         """Test parameter name updating when dequantizing"""
         from transformers.quantizers.quantizer_mxfp4 import Mxfp4HfQuantizer

@@ -274,28 +274,28 @@ def test_update_param_name_dequantize(self):

         # Should remove _blocks suffix
         param_name = "model.layers.0.mlp.experts.gate_up_proj_blocks"
-        updated_name = quantizer.update_param_name(param_name)
+        updated_name = quantizer.get_param_name(param_name)
         self.assertEqual(updated_name, "model.layers.0.mlp.experts.gate_up_proj")

         # Should remove _scales suffix
         param_name = "model.layers.0.mlp.experts.down_proj_scales"
-        updated_name = quantizer.update_param_name(param_name)
+        updated_name = quantizer.get_param_name(param_name)
         self.assertEqual(updated_name, "model.layers.0.mlp.experts.down_proj")

         # Should not change other names
         param_name = "model.embed_tokens.weight"
-        updated_name = quantizer.update_param_name(param_name)
+        updated_name = quantizer.get_param_name(param_name)
         self.assertEqual(updated_name, "model.embed_tokens.weight")

-    def test_update_param_name_no_dequantize(self):
+    def test_get_param_name_no_dequantize(self):
         """Test parameter name updating when not dequantizing"""
         from transformers.quantizers.quantizer_mxfp4 import Mxfp4HfQuantizer

         config = Mxfp4Config(dequantize=False)
         quantizer = Mxfp4HfQuantizer(config)

         param_name = "model.layers.0.mlp.experts.gate_up_proj_blocks"
-        updated_name = quantizer.update_param_name(param_name)
+        updated_name = quantizer.get_param_name(param_name)
         self.assertEqual(updated_name, param_name)

     def test_is_trainable(self):

0 commit comments