@@ -4186,9 +4186,6 @@ def _load_pretrained_model(
41864186 expanded_device_map = expand_device_map (device_map , expected_keys )
41874187 caching_allocator_warmup (model , expanded_device_map , hf_quantizer )
41884188
4189- if device_map is None :
4190- device_map = {"" : "cpu" }
4191- keys = sorted (device_map .keys (), key = len , reverse = True )
41924189 tp_plan = getattr (model , "_tp_plan" , None )
41934190 error_msgs = []
41944191
@@ -4211,33 +4208,18 @@ def _load_pretrained_model(
42114208 missing_keys , unexpected_keys , mismatched_keys , misc = set (), set (), set (), set ()
42124209 else :
42134210 all_pointer = set ()
4211+ # Checkpoints are safetensors
42144212 if checkpoint_files is not None and checkpoint_files [0 ].endswith (".safetensors" ):
4215- pattern = re .compile (r"(" + "|" .join (map (re .escape , keys )) + r")" )
4216- if sharded_metadata is None :
4217- k_v_iterator = dict .fromkeys (
4218- safe_open (checkpoint_files [0 ], framework = "pt" ).keys (), checkpoint_files [0 ].rsplit ("/" , 1 )[1 ]
4219- ).items ()
4220- else :
4221- k_v_iterator = sharded_metadata ["weight_map" ].items ()
4222-
42234213 merged_state_dict = {}
4224- for k , v in k_v_iterator :
4225- match = pattern .match (k )
4226- if match and match .group (1 ) != "" :
4227- device = device_map [match .group (1 )]
4228- else :
4229- device = device_map .get ("" , "cpu" )
4230- if isinstance (device , torch .device ):
4231- device = device .index # safetensors only
4232- if device == "disk" :
4233- device = "cpu" # we read to cpu to then write to disk
4234- file_pointer = safe_open (
4235- os .path .join (checkpoint_files [0 ].rsplit ("/" , 1 )[0 ], v ), framework = "pt" , device = device
4236- )
4214+ for file in checkpoint_files :
4215+ file_pointer = safe_open (file , framework = "pt" , device = "cpu" )
42374216 all_pointer .add (file_pointer )
4238- merged_state_dict [k ] = file_pointer .get_slice (k ) # don't materialize yet
4217+ for k in file_pointer .keys ():
4218+ merged_state_dict [k ] = file_pointer .get_slice (k ) # don't materialize yet
4219+ # User passed an explicit state_dict
42394220 elif state_dict is not None :
42404221 merged_state_dict = state_dict
4222+ # Checkpoints are .bin
42414223 elif checkpoint_files is not None :
42424224 merged_state_dict = {}
42434225 for ckpt_file in checkpoint_files :
0 commit comments