|
| 1 | +import importlib |
1 | 2 | import json |
2 | 3 | import logging |
3 | 4 | from typing import Any, Dict |
|
12 | 13 |
|
13 | 14 | logger: logging.Logger = logging.getLogger(__name__) |
14 | 15 |
|
| 16 | +_torchao_quantization_module = importlib.import_module("torchao.quantization") |
| 17 | + |
15 | 18 |
|
16 | 19 | def unflatten_tensor_state_dict( |
17 | 20 | tensors_data_dict: Dict[str, Any], |
18 | 21 | metadata: Dict[str, Any], |
| 22 | + is_last_file: bool = False, |
19 | 23 | ): |
20 | 24 | """ |
21 | 25 | Reconstructs tensor subclass state dict from provided torch.Tensor data and metadata dictionary |
@@ -68,30 +72,56 @@ def unflatten_tensor_state_dict( |
68 | 72 | result = {} |
69 | 73 |
|
70 | 74 | for tensor_name in tensor_names: |
| 75 | + to_be_deleted = [] |
| 76 | + |
71 | 77 | module_fqn, weight_name = tensor_name.rsplit(".", 1) |
72 | 78 |
|
73 | 79 | prefix = f"{module_fqn}._{weight_name}_" |
74 | 80 | tensor_tensors = {} |
| 81 | + |
75 | 82 | for key, value in combined_data.items(): |
76 | 83 | if key.startswith(prefix): |
77 | 84 | # Remove the prefix |
78 | 85 | tensor_tensors[key[len(prefix) :]] = value |
| 86 | + full_tensor_name_in_state_dict = key |
| 87 | + to_be_deleted.append( |
| 88 | + full_tensor_name_in_state_dict |
| 89 | + ) # for tensor subclass |
79 | 90 |
|
80 | 91 | tensor_metadata = json.loads(metadata.get(tensor_name)) |
81 | 92 | tensor_type = tensor_metadata.get("_type") |
82 | 93 |
|
83 | 94 | if tensor_type in ALLOWED_TENSORS_SUBCLASSES: |
84 | | - if not tensor_tensors: |
85 | | - # we allow the option of loading in state_dict info for a single tensor |
86 | | - # if tensor state dict info is not loaded in yet, we wait for it to be provided |
87 | | - # in a future call |
| 95 | + tensor_cls = getattr(_torchao_quantization_module, tensor_type) |
| 96 | + complete_tensor_data = list(tensor_cls.tensor_data_names) |
| 97 | + if hasattr(tensor_cls, "optional_tensor_data_names"): |
|  | 98 | +                complete_tensor_data.extend(tensor_cls.optional_tensor_data_names)
| 99 | + |
| 100 | + # if not all tensor data is present (ie missing qdata) we wait for it |
| 101 | + # to be loaded in from a future call |
| 102 | + if not tensor_tensors or ( |
| 103 | + not is_last_file |
|  | 104 | +                and len(tensor_tensors) != len(complete_tensor_data)
| 105 | + ): |
88 | 106 | continue |
89 | 107 | tensor_metadata["_data"].update(tensor_tensors) |
90 | 108 | result[tensor_name] = object_from_dict(tensor_metadata) |
91 | 109 | elif tensor_type == torch.Tensor.__name__: |
| 110 | + # we allow the option of loading in state_dict info for a single tensor |
| 111 | + # if tensor state dict info is not loaded in yet, we wait for it to be provided |
| 112 | + # in a future call |
|  | 113 | +            if tensor_name not in tensors_data_dict:
| 114 | + continue |
92 | 115 | result[tensor_name] = tensors_data_dict[tensor_name] |
| 116 | + to_be_deleted.append( |
| 117 | + tensor_name |
| 118 | + ) # add here because key for torch.Tensor has no prefix |
93 | 119 | else: |
94 | 120 | raise ValueError(f"Unsupported tensor type: {tensor_type}") |
| 121 | + |
|  | 122 | +        for consumed_key in to_be_deleted:
|  | 123 | +            del tensors_data_dict[consumed_key]
| 124 | + |
95 | 125 | return result |
96 | 126 |
|
97 | 127 |
|
|
0 commit comments