 import torch.distributed as dist

 import deepspeed
-from transformers import AutoConfig, AutoTokenizer
+from huggingface_hub import try_to_load_from_cache
+from transformers import AutoConfig

 from ..utils import print_rank_n, run_rank_n
-from .model import Model, get_downloaded_model_path, get_hf_model_class, load_tokenizer
+from .model import Model, get_hf_model_class


 # basic DeepSpeed inference model class for benchmarking
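
For context on the new import: `try_to_load_from_cache` resolves a file from the local Hugging Face cache without touching the network. A minimal sketch of its behavior (the repo id here is hypothetical):

```python
import os

from huggingface_hub import try_to_load_from_cache

# Returns the absolute path of the cached file when it exists locally.
# It may also return None (never cached) or a non-str sentinel marking a
# file known to be absent from the repo, so an isinstance check is the
# safest way to detect a hit.
path = try_to_load_from_cache(
    "bigscience/bloom",  # hypothetical repo id
    "config.json",
    cache_dir=os.getenv("TRANSFORMERS_CACHE"),  # falls back to HUGGINGFACE_HUB_CACHE
)
if isinstance(path, str):
    print("config cached at:", path)
```
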
@@ -24,26 +25,22 @@ def __init__(self, args: Namespace) -> None:

         world_size = int(os.getenv("WORLD_SIZE", "1"))

-        downloaded_model_path = get_downloaded_model_path(args.model_name)
-
-        self.tokenizer = load_tokenizer(downloaded_model_path)
-        self.pad = self.tokenizer.pad_token_id
-
         # create dummy tensors for allocating space which will be filled with
         # the actual weights while calling deepspeed.init_inference in the
         # following code
         with deepspeed.OnDevice(dtype=torch.float16, device="meta"):
             self.model = get_hf_model_class(args.model_class).from_config(
-                AutoConfig.from_pretrained(downloaded_model_path), torch_dtype=torch.bfloat16
+                AutoConfig.from_pretrained(args.model_name), torch_dtype=torch.bfloat16
             )
         self.model = self.model.eval()

+        downloaded_model_path = get_model_path(args.model_name)
+
         if args.dtype in [torch.float16, torch.int8]:
             # We currently support the weights provided by microsoft (which are
             # pre-sharded)
-            if args.use_pre_sharded_checkpoints:
-                checkpoints_json = os.path.join(downloaded_model_path, "ds_inference_config.json")
-
+            checkpoints_json = os.path.join(downloaded_model_path, "ds_inference_config.json")
+            if os.path.isfile(checkpoints_json):
                 self.model = deepspeed.init_inference(
                     self.model,
                     mp_size=world_size,
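
The hunk above relies on DeepSpeed's meta-device flow: the model skeleton is materialized without weights, and `deepspeed.init_inference` later fills it from a checkpoint description. A minimal sketch of that pattern, assuming a generic causal-LM class and a pre-sharded checkpoint json (the full `init_inference` argument list is truncated in this diff, so the call below is illustrative, not the repo's exact one):

```python
import os

import torch
import deepspeed
from transformers import AutoConfig, AutoModelForCausalLM

config = AutoConfig.from_pretrained("some-org/some-model")  # hypothetical id

# Allocate the module tree on the meta device: shapes and dtypes only,
# no real weight storage is touched yet.
with deepspeed.OnDevice(dtype=torch.float16, device="meta"):
    model = AutoModelForCausalLM.from_config(config, torch_dtype=torch.bfloat16)
model = model.eval()

# init_inference injects fused inference kernels and streams the actual
# fp16 shards described by the checkpoint json into the meta parameters.
model = deepspeed.init_inference(
    model,
    mp_size=int(os.getenv("WORLD_SIZE", "1")),
    dtype=torch.float16,
    checkpoint="ds_inference_config.json",  # path assumed for illustration
    replace_with_kernel_inject=True,
)
```
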
@@ -74,6 +71,8 @@ def __init__(self, args: Namespace) -> None:
         print_rank_n("Model loaded")
         dist.barrier()

+        self.post_init(args.model_name)
+

 class TemporaryCheckpointsJSON:
     def __init__(self, model_path: str):
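
The tokenizer setup deleted from `__init__` in the first hunk presumably moves into the base class behind the new `post_init` call; the diff does not show that side, so the following is a purely hypothetical sketch of such a hook:

```python
from transformers import AutoTokenizer

class Model:
    # Hypothetical: restores the tokenizer/pad-token setup that this diff
    # removes from the DeepSpeed subclass's __init__, now keyed on the hub
    # model name instead of a pre-downloaded local path.
    def post_init(self, model_name: str) -> None:
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.pad = self.tokenizer.pad_token_id
```
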
@@ -91,5 +90,15 @@ def __enter__(self):
         run_rank_n(partial(self.write_checkpoints_json, model_path=self.model_path), barrier=True)
         return self.tmp_file

-    def __exit__(self, type, value, traceback):
-        return
+
+def get_model_path(model_name: str):
+    config_file = "config.json"
+
+    # will fall back to HUGGINGFACE_HUB_CACHE
+    config_path = try_to_load_from_cache(model_name, config_file, cache_dir=os.getenv("TRANSFORMERS_CACHE"))
+
+    if config_path is not None:
+        return os.path.dirname(config_path)
+    # treat the model name as an explicit model path
+    elif os.path.isfile(os.path.join(model_name, config_file)):
+        return model_name
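
Usage of the new helper for the three cases it distinguishes (paths hypothetical); note that when the model is neither cached nor a local directory it falls through and implicitly returns None:

```python
get_model_path("bigscience/bloom")       # cached hub repo -> its snapshot directory
get_model_path("/data/ckpts/bloom")      # local dir containing config.json -> returned as-is
get_model_path("not-cached/not-a-dir")   # neither -> None (implicit fall-through)
```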