Skip to content

Commit 0e290bb

Browse files
committed
modify unflatten for vllm
1 parent 6c78c4d commit 0e290bb

File tree

2 files changed

+35
-5
lines changed

2 files changed

+35
-5
lines changed

test/prototype/safetensors/test_safetensors_support.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,8 +75,9 @@ def test_safetensors(self, config, act_pre_scale=False):
7575
save_file(tensors_data_dict, f.name, metadata=metadata)
7676
tensors_data_dict, metadata = load_data(file_path=f.name, device="cuda")
7777
reconstructed_dict = unflatten_tensor_state_dict(
78-
tensors_data_dict, metadata
78+
tensors_data_dict, metadata, is_last_file=True
7979
)
80+
assert not tensors_data_dict
8081

8182
model = torch.nn.Sequential(
8283
torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")

torchao/prototype/safetensors/safetensors_support.py

Lines changed: 33 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import importlib
12
import json
23
import logging
34
from typing import Any, Dict
@@ -12,10 +13,13 @@
1213

1314
logger: logging.Logger = logging.getLogger(__name__)
1415

16+
_torchao_quantization_module = importlib.import_module("torchao.quantization")
17+
1518

1619
def unflatten_tensor_state_dict(
1720
tensors_data_dict: Dict[str, Any],
1821
metadata: Dict[str, Any],
22+
is_last_file: bool = False,
1923
):
2024
"""
2125
Reconstructs tensor subclass state dict from provided torch.Tensor data and metadata dictionary
@@ -68,30 +72,55 @@ def unflatten_tensor_state_dict(
6872
result = {}
6973

7074
for tensor_name in tensor_names:
75+
to_be_deleted = []
76+
7177
module_fqn, weight_name = tensor_name.rsplit(".", 1)
7278

7379
prefix = f"{module_fqn}._{weight_name}_"
7480
tensor_tensors = {}
81+
7582
for key, value in combined_data.items():
7683
if key.startswith(prefix):
7784
# Remove the prefix
7885
tensor_tensors[key[len(prefix) :]] = value
86+
full_tensor_name_in_state_dict = key
87+
to_be_deleted.append(
88+
full_tensor_name_in_state_dict
89+
) # for tensor subclass
7990

8091
tensor_metadata = json.loads(metadata.get(tensor_name))
8192
tensor_type = tensor_metadata.get("_type")
8293

8394
if tensor_type in ALLOWED_TENSORS_SUBCLASSES:
84-
if not tensor_tensors:
85-
# we allow the option of loading in state_dict info for a single tensor
86-
# if tensor state dict info is not loaded in yet, we wait for it to be provided
87-
# in a future call
95+
tensor_cls = getattr(_torchao_quantization_module, tensor_type)
96+
complete_tensor_data = list(tensor_cls.tensor_data_names)
97+
if hasattr(tensor_cls, "optional_tensor_data_names"):
98+
complete_tensor_data.append(list(tensor_cls.optional_tensor_data_names))
99+
100+
# if not all tensor data is present (ie missing qdata) we wait for it
101+
# to be loaded in from a future call
102+
if not tensor_tensors or (
103+
is_last_file and not len(tensor_tensors) is len(complete_tensor_data)
104+
):
88105
continue
89106
tensor_metadata["_data"].update(tensor_tensors)
90107
result[tensor_name] = object_from_dict(tensor_metadata)
91108
elif tensor_type == torch.Tensor.__name__:
109+
# we allow the option of loading in state_dict info for a single tensor
110+
# if tensor state dict info is not loaded in yet, we wait for it to be provided
111+
# in a future call
112+
if tensor_name not in tensors_data_dict.keys():
113+
continue
92114
result[tensor_name] = tensors_data_dict[tensor_name]
115+
to_be_deleted.append(
116+
tensor_name
117+
) # add here because key for torch.Tensor has no prefix
93118
else:
94119
raise ValueError(f"Unsupported tensor type: {tensor_type}")
120+
121+
for tensor_name in to_be_deleted:
122+
del tensors_data_dict[tensor_name]
123+
95124
return result
96125

97126

0 commit comments

Comments
 (0)