
Commit 1742d11

[loading] Fix device when source and target are different (#42246)
* fix device
* fix
* CI
* simplify a bit
1 parent 16924cd commit 1742d11

2 files changed (+19, −32 lines)

src/transformers/core_model_loading.py

Lines changed: 12 additions & 7 deletions
@@ -312,16 +312,16 @@ class ConversionEntry:
 GLOBAL_WORKERS = min(16, (os.cpu_count() or 8) * 2)  # NVMe: 8-16; HDD/NFS: 2-4
 
 
-def _materialize_copy(tensor, dtype=None):
+def _materialize_copy(tensor, device=None, dtype=None):
     tensor = tensor[...]
-    if dtype is not None:
-        tensor = tensor.to(dtype)
+    if dtype is not None or device is not None:
+        tensor = tensor.to(device=device, dtype=dtype)
     return tensor
 
 
-def spawn_materialize(thread_pool, tensor, dtype=None) -> Future:
+def spawn_materialize(thread_pool, tensor, device=None, dtype=None) -> Future:
     def _job():
-        return _materialize_copy(tensor, dtype)
+        return _materialize_copy(tensor, device, dtype)
 
     return thread_pool.submit(_job)
 
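For orientation, here is a self-contained version of the two helpers as they look after this hunk, runnable outside the library; the thread pool setup and the toy tensor at the bottom are illustrative additions, not part of the commit.

```python
from concurrent.futures import Future, ThreadPoolExecutor

import torch


def _materialize_copy(tensor, device=None, dtype=None):
    # `[...]` materializes lazy checkpoint slices; a plain tensor passes through as a view
    tensor = tensor[...]
    if dtype is not None or device is not None:
        # one call handles the dtype cast and the host-to-device copy together
        tensor = tensor.to(device=device, dtype=dtype)
    return tensor


def spawn_materialize(thread_pool, tensor, device=None, dtype=None) -> Future:
    # schedule the copy on a worker thread and hand back a Future immediately
    def _job():
        return _materialize_copy(tensor, device, dtype)

    return thread_pool.submit(_job)


if __name__ == "__main__":
    pool = ThreadPoolExecutor(max_workers=2)
    fut = spawn_materialize(pool, torch.ones(4), device="cpu", dtype=torch.float16)
    print(fut.result().dtype)  # torch.float16
    pool.shutdown()
```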
@@ -447,7 +447,10 @@ def convert_and_load_state_dict_in_model(
 
     prefix = model.base_model_prefix
     tp_plan = tp_plan or {}  # {glob_pattern: plan_obj_or_key}
-    device_map = device_map or {}  # {exact_target_key: device}
+    device_map = device_map or {"": "cpu"}  # {exact_target_key: device}
+    device_map_regex = re.compile(
+        "|".join(rf"({k})" for k in sorted(device_map.keys(), key=lambda x: x.count("."), reverse=True))
+    )
     dtype_plan = dtype_plan or {}  # {glob_pattern: dtype}
     weight_mapping = weight_mapping or {}  # {glob_pattern: WeightConverter}
     meta_model_state_dict = model.state_dict()
@@ -534,7 +537,9 @@ def convert_and_load_state_dict_in_model(
                 )
 
                 if future is None:  # If not TP, async materialize the tensors. TODO handle disk offload?
-                    future = spawn_materialize(thread_pool, tensor, _dtype)
+                    device_match = device_map_regex.match(first_target_key)
+                    param_device = device_map[device_match.group()] if device_match else device_map.get("", "cpu")
+                    future = spawn_materialize(thread_pool, tensor, param_device, _dtype)
                 entry.collected_tensors[target_key].setdefault(converter_key, []).append(future)
 
         # 2. Actually convert the ckpt
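To see what the new `device_map_regex` lookup does, here is a minimal sketch with a made-up `device_map` and parameter names; the regex construction and the fallback to the `""` entry mirror the hunks above, everything else is illustrative.

```python
import re

# hypothetical device map: two layers pinned to GPUs, everything else on CPU
device_map = {"model.layers.0": "cuda:0", "model.layers.1": "cuda:1", "": "cpu"}

# most specific keys (most dots) first, so longer prefixes win over the "" catch-all
device_map_regex = re.compile(
    "|".join(rf"({k})" for k in sorted(device_map.keys(), key=lambda x: x.count("."), reverse=True))
)

for name in ["model.layers.0.self_attn.q_proj.weight", "lm_head.weight"]:
    device_match = device_map_regex.match(name)
    param_device = device_map[device_match.group()] if device_match else device_map.get("", "cpu")
    print(name, "->", param_device)
# model.layers.0.self_attn.q_proj.weight -> cuda:0
# lm_head.weight -> cpu
```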

src/transformers/modeling_utils.py

Lines changed: 7 additions & 25 deletions
@@ -4186,9 +4186,6 @@ def _load_pretrained_model(
         expanded_device_map = expand_device_map(device_map, expected_keys)
         caching_allocator_warmup(model, expanded_device_map, hf_quantizer)
 
-        if device_map is None:
-            device_map = {"": "cpu"}
-        keys = sorted(device_map.keys(), key=len, reverse=True)
         tp_plan = getattr(model, "_tp_plan", None)
         error_msgs = []
 
@@ -4211,33 +4208,18 @@ def _load_pretrained_model(
             missing_keys, unexpected_keys, mismatched_keys, misc = set(), set(), set(), set()
         else:
             all_pointer = set()
+            # Checkpoints are safetensors
             if checkpoint_files is not None and checkpoint_files[0].endswith(".safetensors"):
-                pattern = re.compile(r"(" + "|".join(map(re.escape, keys)) + r")")
-                if sharded_metadata is None:
-                    k_v_iterator = dict.fromkeys(
-                        safe_open(checkpoint_files[0], framework="pt").keys(), checkpoint_files[0].rsplit("/", 1)[1]
-                    ).items()
-                else:
-                    k_v_iterator = sharded_metadata["weight_map"].items()
-
                 merged_state_dict = {}
-                for k, v in k_v_iterator:
-                    match = pattern.match(k)
-                    if match and match.group(1) != "":
-                        device = device_map[match.group(1)]
-                    else:
-                        device = device_map.get("", "cpu")
-                    if isinstance(device, torch.device):
-                        device = device.index  # safetensors only
-                    if device == "disk":
-                        device = "cpu"  # we read to cpu to then write to disk
-                    file_pointer = safe_open(
-                        os.path.join(checkpoint_files[0].rsplit("/", 1)[0], v), framework="pt", device=device
-                    )
+                for file in checkpoint_files:
+                    file_pointer = safe_open(file, framework="pt", device="cpu")
                     all_pointer.add(file_pointer)
-                    merged_state_dict[k] = file_pointer.get_slice(k)  # don't materialize yet
+                    for k in file_pointer.keys():
+                        merged_state_dict[k] = file_pointer.get_slice(k)  # don't materialize yet
+            # User passed an explicit state_dict
             elif state_dict is not None:
                 merged_state_dict = state_dict
+            # Checkpoints are .bin
             elif checkpoint_files is not None:
                 merged_state_dict = {}
                 for ckpt_file in checkpoint_files:
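The simplified safetensors branch above can be exercised on its own: open every shard on CPU, collect lazy slices, and defer the real read (and the device placement, now handled in `core_model_loading.py`) until a slice is indexed. The temporary shard, tensor names, and shapes below are invented for the example.

```python
import os
import tempfile

import torch
from safetensors import safe_open
from safetensors.torch import save_file

# build a throwaway single-shard checkpoint to stand in for `checkpoint_files`
tmpdir = tempfile.mkdtemp()
shard = os.path.join(tmpdir, "model-00001-of-00001.safetensors")
save_file({"embed.weight": torch.randn(4, 8), "lm_head.weight": torch.randn(8, 4)}, shard)

checkpoint_files = [shard]
merged_state_dict, all_pointer = {}, set()
for file in checkpoint_files:
    file_pointer = safe_open(file, framework="pt", device="cpu")
    all_pointer.add(file_pointer)  # keep the handle alive while slices are still pending
    for k in file_pointer.keys():
        merged_state_dict[k] = file_pointer.get_slice(k)  # don't materialize yet

print(sorted(merged_state_dict))  # ['embed.weight', 'lm_head.weight']
print(merged_state_dict["embed.weight"][...].shape)  # the actual read happens here: torch.Size([4, 8])
```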
