test/prototype/safetensors/test_safetensors_support.py (3 changes: 2 additions & 1 deletion)

@@ -75,8 +75,9 @@ def test_safetensors(self, config, act_pre_scale=False):
         save_file(tensors_data_dict, f.name, metadata=metadata)
         tensors_data_dict, metadata = load_data(file_path=f.name, device="cuda")
         reconstructed_dict = unflatten_tensor_state_dict(
-            tensors_data_dict, metadata
+            tensors_data_dict, metadata, is_last_file=True
         )
+        assert not tensors_data_dict
 
         model = torch.nn.Sequential(
             torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
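The new is_last_file flag and the trailing assert exercise incremental loading: unflatten_tensor_state_dict deletes consumed entries from tensors_data_dict in place, so after the final file the dict should be empty. A rough sketch of the multi-shard pattern this enables, in Python (the shard names and carry-over bookkeeping here are illustrative, not a documented API):

# Hypothetical multi-shard loading loop; file names are illustrative only.
shard_paths = [
    "model-00001-of-00002.safetensors",
    "model-00002-of-00002.safetensors",
]

state_dict = {}
pending = {}       # tensor components still waiting on data from later shards
all_metadata = {}  # per-tensor metadata merged across shards

for i, path in enumerate(shard_paths):
    tensors_data_dict, metadata = load_data(file_path=path, device="cuda")
    tensors_data_dict.update(pending)
    all_metadata.update(metadata)
    state_dict.update(
        unflatten_tensor_state_dict(
            tensors_data_dict,
            all_metadata,
            is_last_file=(i == len(shard_paths) - 1),
        )
    )
    # consumed keys were deleted in place; whatever remains is incomplete
    pending = tensors_data_dict

# with is_last_file=True on the final shard, everything should be consumed
assert not pending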
torchao/prototype/safetensors/safetensors_support.py (38 changes: 34 additions & 4 deletions)
@@ -1,3 +1,4 @@
+import importlib
 import json
 import logging
 from typing import Any, Dict
@@ -12,10 +13,13 @@
 
 logger: logging.Logger = logging.getLogger(__name__)
 
+_torchao_quantization_module = importlib.import_module("torchao.quantization")
+
 
 def unflatten_tensor_state_dict(
     tensors_data_dict: Dict[str, Any],
     metadata: Dict[str, Any],
+    is_last_file: bool = False,
 ):
     """
     Reconstructs tensor subclass state dict from provided torch.Tensor data and metadata dictionary
@@ -68,30 +72,56 @@ def unflatten_tensor_state_dict(
     result = {}
 
     for tensor_name in tensor_names:
+        to_be_deleted = []
+
         module_fqn, weight_name = tensor_name.rsplit(".", 1)
 
         prefix = f"{module_fqn}._{weight_name}_"
         tensor_tensors = {}
 
         for key, value in combined_data.items():
             if key.startswith(prefix):
                 # Remove the prefix
                 tensor_tensors[key[len(prefix) :]] = value
+                full_tensor_name_in_state_dict = key
+                to_be_deleted.append(
+                    full_tensor_name_in_state_dict
+                )  # for tensor subclass
 
         tensor_metadata = json.loads(metadata.get(tensor_name))
         tensor_type = tensor_metadata.get("_type")
 
         if tensor_type in ALLOWED_TENSORS_SUBCLASSES:
-            if not tensor_tensors:
-                # we allow the option of loading in state_dict info for a single tensor
-                # if tensor state dict info is not loaded in yet, we wait for it to be provided
-                # in a future call
+            tensor_cls = getattr(_torchao_quantization_module, tensor_type)
+            complete_tensor_data = list(tensor_cls.tensor_data_names)
+            if hasattr(tensor_cls, "optional_tensor_data_names"):
+                # extend, not append: appending the list itself would nest it
+                # and throw off the length comparison below
+                complete_tensor_data.extend(tensor_cls.optional_tensor_data_names)
+
+            # if not all tensor data is present (i.e. missing qdata), we wait
+            # for it to be loaded in from a future call
+            if not tensor_tensors or (
+                not is_last_file
+                and len(tensor_tensors) != len(complete_tensor_data)
+            ):
                 continue
             tensor_metadata["_data"].update(tensor_tensors)
             result[tensor_name] = object_from_dict(tensor_metadata)
         elif tensor_type == torch.Tensor.__name__:
+            # we allow the option of loading in state_dict info for a single tensor
+            # if tensor state dict info is not loaded in yet, we wait for it to be
+            # provided in a future call
+            if tensor_name not in tensors_data_dict:
+                continue
             result[tensor_name] = tensors_data_dict[tensor_name]
+            to_be_deleted.append(
+                tensor_name
+            )  # add here because the key for a plain torch.Tensor has no prefix
         else:
             raise ValueError(f"Unsupported tensor type: {tensor_type}")
 
+        for stale_key in to_be_deleted:
+            del tensors_data_dict[stale_key]
+
     return result
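For reference, a minimal sketch in Python of the completeness check above, using a made-up FakeQuantizedTensor stand-in (real subclasses come from torchao.quantization and declare tensor_data_names and, optionally, optional_tensor_data_names):

# Illustration only: FakeQuantizedTensor is a hypothetical stand-in, not a torchao class.
class FakeQuantizedTensor:
    tensor_data_names = ["qdata", "scale"]
    optional_tensor_data_names = ["zero_point"]

complete_tensor_data = list(FakeQuantizedTensor.tensor_data_names)
if hasattr(FakeQuantizedTensor, "optional_tensor_data_names"):
    # extend() keeps the list flat; append() would add a nested list and
    # make the len() comparison below never succeed
    complete_tensor_data.extend(FakeQuantizedTensor.optional_tensor_data_names)

assert complete_tensor_data == ["qdata", "scale", "zero_point"]

# A shard that has only delivered "scale" so far is incomplete, so
# reconstruction of this tensor is deferred to a later call
tensor_tensors = {"scale": "..."}
assert len(tensor_tensors) != len(complete_tensor_data)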

