diff --git a/test/prototype/safetensors/test_safetensors_support.py b/test/prototype/safetensors/test_safetensors_support.py
index bfec170fd0..25691efe2a 100644
--- a/test/prototype/safetensors/test_safetensors_support.py
+++ b/test/prototype/safetensors/test_safetensors_support.py
@@ -75,8 +75,9 @@ def test_safetensors(self, config, act_pre_scale=False):
             save_file(tensors_data_dict, f.name, metadata=metadata)
             tensors_data_dict, metadata = load_data(file_path=f.name, device="cuda")
             reconstructed_dict = unflatten_tensor_state_dict(
-                tensors_data_dict, metadata
+                tensors_data_dict, metadata, is_last_file=True
             )
+            assert not tensors_data_dict  # every loaded tensor was consumed

             model = torch.nn.Sequential(
                 torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
diff --git a/torchao/prototype/safetensors/safetensors_support.py b/torchao/prototype/safetensors/safetensors_support.py
index 63623dcb15..8fb2230186 100644
--- a/torchao/prototype/safetensors/safetensors_support.py
+++ b/torchao/prototype/safetensors/safetensors_support.py
@@ -1,3 +1,4 @@
+import importlib
 import json
 import logging
 from typing import Any, Dict
@@ -12,10 +13,13 @@

 logger: logging.Logger = logging.getLogger(__name__)

+_torchao_quantization_module = importlib.import_module("torchao.quantization")
+

 def unflatten_tensor_state_dict(
     tensors_data_dict: Dict[str, Any],
     metadata: Dict[str, Any],
+    is_last_file: bool = False,
 ):
     """
     Reconstructs tensor subclass state dict from provided torch.Tensor data and metadata dictionary
@@ -68,30 +72,50 @@
     result = {}

     for tensor_name in tensor_names:
+        to_be_deleted = []
+
         module_fqn, weight_name = tensor_name.rsplit(".", 1)

         prefix = f"{module_fqn}._{weight_name}_"

         tensor_tensors = {}
+
         for key, value in combined_data.items():
             if key.startswith(prefix):
                 # Remove the prefix
                 tensor_tensors[key[len(prefix) :]] = value
+                to_be_deleted.append(key)  # component keys of a tensor subclass

         tensor_metadata = json.loads(metadata.get(tensor_name))
         tensor_type = tensor_metadata.get("_type")

         if tensor_type in ALLOWED_TENSORS_SUBCLASSES:
-            if not tensor_tensors:
-                # we allow the option of loading in state_dict info for a single tensor
-                # if tensor state dict info is not loaded in yet, we wait for it to be provided
-                # in a future call
+            tensor_cls = getattr(_torchao_quantization_module, tensor_type)
+            complete_tensor_data = list(tensor_cls.tensor_data_names)
+            if hasattr(tensor_cls, "optional_tensor_data_names"):
+                complete_tensor_data.extend(tensor_cls.optional_tensor_data_names)
+
+            # if not all tensor data is present yet (i.e. qdata is missing), we
+            # wait for it to be provided in a future call
+            if not tensor_tensors or (
+                not is_last_file
+                and len(tensor_tensors) != len(complete_tensor_data)
+            ):
                 continue
             tensor_metadata["_data"].update(tensor_tensors)
             result[tensor_name] = object_from_dict(tensor_metadata)
         elif tensor_type == torch.Tensor.__name__:
+            # we allow the option of loading in state_dict info for a single tensor
+            # if tensor state dict info is not loaded in yet, we wait for it to be
+            # provided in a future call
+            if tensor_name not in tensors_data_dict:
+                continue
             result[tensor_name] = tensors_data_dict[tensor_name]
+            to_be_deleted.append(tensor_name)  # keys for plain torch.Tensor have no prefix
         else:
             raise ValueError(f"Unsupported tensor type: {tensor_type}")
+
+        for key_to_delete in to_be_deleted:
+            del tensors_data_dict[key_to_delete]

     return result
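
For reviewers, a minimal sketch of how the new `is_last_file` flag is meant to drive multi-file loading: consumed keys are now deleted from `tensors_data_dict` in place, incomplete entries (e.g. a subclass whose `qdata` lives in a later shard) are left behind to be completed by a later call, and the flag forces the final call to flush whatever remains. The shard names below are invented, and `load_data` is the helper used in the test (its import path is not shown in this diff); only the `unflatten_tensor_state_dict` signature comes from this PR.

```python
# Hypothetical usage sketch, not part of this PR. Shard names are made up;
# load_data is assumed importable as in the test above.
from torchao.prototype.safetensors.safetensors_support import (
    unflatten_tensor_state_dict,
)

shard_paths = [
    "model-00001-of-00002.safetensors",
    "model-00002-of-00002.safetensors",
]

state_dict = {}
pending = {}  # flattened tensors whose sibling components live in a later shard
for i, path in enumerate(shard_paths):
    tensors_data_dict, metadata = load_data(file_path=path, device="cuda")
    pending.update(tensors_data_dict)
    # unflatten_tensor_state_dict deletes consumed keys from `pending`;
    # incomplete entries stay behind until a later shard (or
    # is_last_file=True) completes them.
    state_dict.update(
        unflatten_tensor_state_dict(
            pending,
            metadata,
            is_last_file=(i == len(shard_paths) - 1),
        )
    )
assert not pending  # mirrors the assert added to the test
```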