17 changes: 11 additions & 6 deletions src/diffusers/models/model_loading_utils.py
@@ -233,7 +233,7 @@ def load_model_dict_into_meta(
    empty_state_dict = model.state_dict()

    for param_name, param in state_dict.items():
-        if param_name not in empty_state_dict:
+        if param_name in unexpected_keys:
Member

Should this not be?

Suggested change
-        if param_name in unexpected_keys:
+        if param_name not in empty_state_dict or param_name in unexpected_keys:

Contributor Author (@Disty0, Oct 13, 2025)

Parameters that will be added during quantization aren't in the `empty_state_dict` yet. They will be added to the model by `create_quantized_param` within this loop.

Transformers uses `param_name not in expected_keys` for this check. I used the unexpected keys here instead because diffusers doesn't pass the expected keys to this loop.
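
For anyone following along, a minimal self-contained sketch of the key bookkeeping assumed here (the names below are illustrative, not the actual diffusers variables):

# Toy setup: the checkpoint carries one regular weight plus one quantizer-created tensor.
checkpoint_keys = {"linear.weight", "linear.weight.scales"}   # loaded_keys
model_keys = {"linear.weight"}                                # keys of the initialized (meta) model
quantizer_added_keys = {"linear.weight.scales"}               # illustrative: keys create_quantized_param will add later

expected_keys = model_keys | quantizer_added_keys
unexpected_keys = checkpoint_keys - expected_keys             # empty in this toy case

for param_name in checkpoint_keys:
    # `param_name in unexpected_keys` behaves like `param_name not in expected_keys`,
    # so the quantizer-created key (absent from empty_state_dict) is still loaded
    # instead of being skipped.
    if param_name in unexpected_keys:
        continue
    print("loading", param_name)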

            continue

        set_module_kwargs = {}
@@ -260,10 +260,15 @@ def load_model_dict_into_meta(
        # For compatibility with PyTorch load_state_dict which converts state dict dtype to existing dtype in model, and which
        # uses `param.copy_(input_param)` that preserves the contiguity of the parameter in the model.
        # Reference: https://github.com/pytorch/pytorch/blob/db79ceb110f6646523019a59bbd7b838f43d4a86/torch/nn/modules/module.py#L2040C29-L2040C29
-        old_param = model
-        splits = param_name.split(".")
-        for split in splits:
-            old_param = getattr(old_param, split)
+        if param_name in empty_state_dict:
+            old_param = model
+            splits = param_name.split(".")
+            for split in splits:
+                old_param = getattr(old_param, split)
+        else:
+            # hf_quantizer can add parameters that don't exist in the model yet;
+            # they will be in the loaded state dict when pre_quantized
Member

I would also provide more details on when this can arise.

Contributor Author

Pushed a new commit that fixes the failing pipeline tests when `unexpected_keys` is None. Also added more details to these comment lines.

+            old_param = None

        if not isinstance(old_param, (torch.nn.Parameter, torch.Tensor)):
            old_param = None
@@ -279,7 +284,7 @@ def load_model_dict_into_meta(

        # bnb params are flattened.
        # gguf quants have a different shape based on the type of quantization applied
-        if empty_state_dict[param_name].shape != param.shape:
+        if param_name in empty_state_dict and empty_state_dict[param_name].shape != param.shape:
Member

Just add a small comment for that, as we will probably refactor the loading at some point to match what we have in transformers.

Contributor Author

Added: 745b041
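
As an aside, a toy illustration of why the shape check now also tests membership in empty_state_dict (plain tensors only, no real quantizer backend; the .absmax key is hypothetical):

import torch

# The meta model only declares the dense weight; a pre-quantized checkpoint may carry
# a packed weight of a different shape plus extra state the model doesn't know about yet.
empty_state_dict = {"linear.weight": torch.empty(8, 8, device="meta")}
state_dict = {
    "linear.weight": torch.zeros(32, 1, dtype=torch.uint8),  # packed/flattened, shape differs
    "linear.weight.absmax": torch.zeros(4),                  # hypothetical quantizer-created key
}

for name, param in state_dict.items():
    # Guarding with `name in empty_state_dict` avoids a KeyError on quantizer-created keys;
    # the shape comparison is only meaningful for keys the model already has.
    if name in empty_state_dict and empty_state_dict[name].shape != param.shape:
        print(f"{name}: shape mismatch is expected for pre-quantized weights")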

        if (
            is_quantized
            and hf_quantizer.pre_quantized
7 changes: 6 additions & 1 deletion src/diffusers/models/modeling_utils.py
@@ -1564,12 +1564,17 @@ def _load_pretrained_model(
        dduf_entries: Optional[Dict[str, DDUFEntry]] = None,
        is_parallel_loading_enabled: Optional[bool] = False,
    ):
+        is_quantized = hf_quantizer is not None
        model_state_dict = model.state_dict()
        expected_keys = list(model_state_dict.keys())
+        if is_quantized:
+            expected_keys = hf_quantizer.update_expected_keys(model, expected_keys, loaded_keys)
        missing_keys = list(set(expected_keys) - set(loaded_keys))
-        if hf_quantizer is not None:
+        if is_quantized:
            missing_keys = hf_quantizer.update_missing_keys(model, missing_keys, prefix="")
        unexpected_keys = list(set(loaded_keys) - set(expected_keys))
+        if is_quantized:
+            unexpected_keys = hf_quantizer.update_unexpected_keys(model, unexpected_keys)
        # Some models may have keys that are not in the state by design, removing them before needlessly warning
        # the user.
        if cls._keys_to_ignore_on_load_unexpected is not None:
22 changes: 22 additions & 0 deletions src/diffusers/quantizers/base.py
@@ -110,6 +110,28 @@ def update_missing_keys(self, model, missing_keys: List[str], prefix: str) -> List[str]:
"""
return missing_keys

+    def update_expected_keys(self, model, expected_keys: list[str], loaded_keys: list[str]) -> list[str]:
+        """
+        Override this method if you want to adjust the `expected_keys`.
+
+        Args:
+            expected_keys (`list[str]`, *optional*):
+                The list of the expected keys in the initialized model.
+            loaded_keys (`list[str]`, *optional*):
+                The list of the loaded keys in the checkpoint.
+        """
+        return expected_keys
+
+    def update_unexpected_keys(self, model, unexpected_keys: list[str]) -> list[str]:
+        """
+        Override this method if you want to adjust the `unexpected_keys`.
+
+        Args:
+            unexpected_keys (`list[str]`, *optional*):
+                The list of the unexpected keys in the checkpoint compared to the state dict of the model.
+        """
+        return unexpected_keys

    def get_special_dtypes_update(self, model, torch_dtype: "torch.dtype") -> Dict[str, "torch.dtype"]:
        """
        returns dtypes for modules that are not quantized - used for the computation of the device_map in case one
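
As a usage note, here is a minimal sketch of how a quantizer backend might override the two new hooks. The MyQuantizer class and the .absmax suffix are illustrative assumptions, not an existing diffusers backend, and a real backend would also have to implement the remaining abstract methods of DiffusersQuantizer:

from diffusers.quantizers.base import DiffusersQuantizer

class MyQuantizer(DiffusersQuantizer):
    """Illustrative backend that stores one extra `.absmax` tensor per quantized weight."""

    def update_expected_keys(self, model, expected_keys, loaded_keys):
        # Pre-quantized checkpoints ship extra tensors the meta model doesn't declare yet;
        # count them as expected so loading doesn't flag them as missing or unexpected.
        extra = [k for k in loaded_keys if k.endswith(".absmax")]
        return expected_keys + extra

    def update_unexpected_keys(self, model, unexpected_keys):
        # Keep quantization-state keys out of the "unexpected" bucket so that
        # load_model_dict_into_meta doesn't skip them with `continue`.
        return [k for k in unexpected_keys if not k.endswith(".absmax")]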