19 changes: 12 additions & 7 deletions src/transformers/core_model_loading.py
@@ -312,16 +312,16 @@ class ConversionEntry:
 GLOBAL_WORKERS = min(16, (os.cpu_count() or 8) * 2)  # NVMe: 8-16; HDD/NFS: 2-4


-def _materialize_copy(tensor, dtype=None):
+def _materialize_copy(tensor, device=None, dtype=None):
     tensor = tensor[...]
-    if dtype is not None:
-        tensor = tensor.to(dtype)
+    if dtype is not None or device is not None:
+        tensor = tensor.to(device=device, dtype=dtype)
     return tensor


-def spawn_materialize(thread_pool, tensor, dtype=None) -> Future:
+def spawn_materialize(thread_pool, tensor, device=None, dtype=None) -> Future:
     def _job():
-        return _materialize_copy(tensor, dtype)
+        return _materialize_copy(tensor, device, dtype)

     return thread_pool.submit(_job)
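With `device` threaded through, each worker can now cast and place a tensor in a single `.to()` call. A minimal sketch of the behavior (the thread pool and the plain tensor standing in for a lazy safetensors slice are illustrative, not from the PR):

```python
# Minimal sketch of the new materialize path; pool size and the stand-in
# tensor are illustrative, not from the PR.
from concurrent.futures import ThreadPoolExecutor

import torch


def _materialize_copy(tensor, device=None, dtype=None):
    tensor = tensor[...]  # force the lazy slice into a real tensor
    if dtype is not None or device is not None:
        # one fused .to() handles the cast and the host-to-device move together
        tensor = tensor.to(device=device, dtype=dtype)
    return tensor


with ThreadPoolExecutor(max_workers=4) as pool:
    lazy = torch.ones(4, 4, dtype=torch.float32)  # stand-in for a safetensors slice
    fut = pool.submit(_materialize_copy, lazy, "cpu", torch.float16)
    print(fut.result().dtype)  # torch.float16
```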

@@ -447,7 +447,10 @@ def convert_and_load_state_dict_in_model(

     prefix = model.base_model_prefix
     tp_plan = tp_plan or {}  # {glob_pattern: plan_obj_or_key}
-    device_map = device_map or {}  # {exact_target_key: device}
+    device_map = device_map or {"": "cpu"}  # {exact_target_key: device}
+    device_map_regex = re.compile(
+        "|".join(rf"({k})" for k in sorted(device_map.keys(), key=lambda x: x.count("."), reverse=True))
+    )
     dtype_plan = dtype_plan or {}  # {glob_pattern: dtype}
     weight_mapping = weight_mapping or {}  # {glob_pattern: WeightConverter}
     meta_model_state_dict = model.state_dict()
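The alternation is ordered by dot count, so deeper (more specific) `device_map` entries are tried before shallower ones, and the catch-all `""` entry matches last, since `re` alternation is first-match-wins. A hedged sketch of the lookup, with an illustrative map and keys:

```python
# Illustrative device_map and keys, not from the PR; demonstrates the
# most-specific-entry-wins behavior of the alternation ordering.
import re

device_map = {"": "cpu", "model.layers.3": "cuda:1", "model.embed_tokens": "cuda:0"}
device_map_regex = re.compile(
    "|".join(rf"({k})" for k in sorted(device_map.keys(), key=lambda x: x.count("."), reverse=True))
)

for key in ("model.layers.3.self_attn.q_proj.weight", "lm_head.weight"):
    match = device_map_regex.match(key)
    device = device_map[match.group()] if match else device_map.get("", "cpu")
    print(key, "->", device)  # cuda:1, then cpu (the "" catch-all)
```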
@@ -534,7 +537,9 @@ def convert_and_load_state_dict_in_model(
                 )

                 if future is None:  # If not TP, async materialize the tensors. TODO handle disk offload?
-                    future = spawn_materialize(thread_pool, tensor, _dtype)
+                    device_match = device_map_regex.match(first_target_key)
+                    param_device = device_map[device_match.group()] if device_match else device_map.get("", "cpu")
+                    future = spawn_materialize(thread_pool, tensor, param_device, _dtype)
                 entry.collected_tensors[target_key].setdefault(converter_key, []).append(future)

     # 2. Actually convert the ckpt
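Because placement now happens through `Tensor.to` rather than through `safe_open(..., device=...)`, the old normalization of `torch.device` objects to plain indices (removed below in `modeling_utils.py`) appears to be unnecessary: `.to()` accepts strings, indices, and `torch.device` objects alike. A small illustrative check:

```python
# Illustrative check, not from the PR: Tensor.to normalizes common device
# spellings itself, so no pre-conversion to an index is required.
import torch

t = torch.zeros(2)
for dev in ("cpu", torch.device("cpu")):
    assert t.to(device=dev).device.type == "cpu"
```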
32 changes: 7 additions & 25 deletions src/transformers/modeling_utils.py
@@ -4186,9 +4186,6 @@ def _load_pretrained_model(
         expanded_device_map = expand_device_map(device_map, expected_keys)
         caching_allocator_warmup(model, expanded_device_map, hf_quantizer)

-        if device_map is None:
-            device_map = {"": "cpu"}
-        keys = sorted(device_map.keys(), key=len, reverse=True)
         tp_plan = getattr(model, "_tp_plan", None)
         error_msgs = []
@@ -4211,33 +4208,18 @@ def _load_pretrained_model(
             missing_keys, unexpected_keys, mismatched_keys, misc = set(), set(), set(), set()
         else:
             all_pointer = set()
             # Checkpoints are safetensors
             if checkpoint_files is not None and checkpoint_files[0].endswith(".safetensors"):
-                pattern = re.compile(r"(" + "|".join(map(re.escape, keys)) + r")")
-                if sharded_metadata is None:
-                    k_v_iterator = dict.fromkeys(
-                        safe_open(checkpoint_files[0], framework="pt").keys(), checkpoint_files[0].rsplit("/", 1)[1]
-                    ).items()
-                else:
-                    k_v_iterator = sharded_metadata["weight_map"].items()
-
                 merged_state_dict = {}
-                for k, v in k_v_iterator:
-                    match = pattern.match(k)
-                    if match and match.group(1) != "":
-                        device = device_map[match.group(1)]
-                    else:
-                        device = device_map.get("", "cpu")
-                    if isinstance(device, torch.device):
-                        device = device.index  # safetensors only
-                    if device == "disk":
-                        device = "cpu"  # we read to cpu to then write to disk
-                    file_pointer = safe_open(
-                        os.path.join(checkpoint_files[0].rsplit("/", 1)[0], v), framework="pt", device=device
-                    )
+                for file in checkpoint_files:
+                    file_pointer = safe_open(file, framework="pt", device="cpu")
                     all_pointer.add(file_pointer)
-                    merged_state_dict[k] = file_pointer.get_slice(k)  # don't materialize yet
+                    for k in file_pointer.keys():
+                        merged_state_dict[k] = file_pointer.get_slice(k)  # don't materialize yet
             # User passed an explicit state_dict
             elif state_dict is not None:
                 merged_state_dict = state_dict
             # Checkpoints are .bin
             elif checkpoint_files is not None:
                 merged_state_dict = {}
                 for ckpt_file in checkpoint_files:
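The simplified branch opens every shard on CPU and only collects lazy `get_slice` handles; device placement is deferred to the materialize step in `core_model_loading.py` above. A hedged sketch of the loop in isolation, with illustrative shard paths:

```python
# Hedged sketch; the shard filenames are illustrative. Nothing is read into
# memory here: get_slice returns a lazy handle that materializes on tensor[...]
from safetensors import safe_open

checkpoint_files = [
    "model-00001-of-00002.safetensors",
    "model-00002-of-00002.safetensors",
]

all_pointer, merged_state_dict = set(), {}
for file in checkpoint_files:
    file_pointer = safe_open(file, framework="pt", device="cpu")
    all_pointer.add(file_pointer)
    for k in file_pointer.keys():
        merged_state_dict[k] = file_pointer.get_slice(k)  # don't materialize yet
```

Keeping the open handles in `all_pointer` presumably lets the caller close the memory-mapped files once loading completes, rather than while slices are still pending.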