 import torch
-from ._quant_common.quant_config import local_rank, world_size, HpDtype
+from ._quant_common.quant_config import HpDtype
 from ._core.quant_dequant import QuantDequantBase
 from ._core.scale_handler import update_state_dict_method, ScaleFormat
 from ._core.quantized_func_wrappers import (
     get_quantized_func_wrapper,
     OP_TYPE,
 )
+from .prepare_quant.prepare_model import get_world_size, get_local_rank
 from .utils.logger import logger
 from neural_compressor.common import options
 from neural_compressor.torch.utils import (
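
The import change above is the heart of this patch: the module no longer binds `local_rank` and `world_size` once at import time, but calls `get_local_rank()` / `get_world_size()` wherever the values are actually needed. A minimal sketch of what such getters could look like, assuming they resolve the distributed state at call time and fall back to launcher environment variables (the real implementations live in `prepare_quant/prepare_model.py`):

```python
import os

import torch.distributed as dist


def get_world_size() -> int:
    """Resolve world size at call time instead of snapshotting it at import."""
    if dist.is_available() and dist.is_initialized():
        return dist.get_world_size()
    return int(os.environ.get("WORLD_SIZE", "1"))  # assumption: launcher-provided fallback


def get_local_rank() -> int:
    """Resolve this process's local rank, -1 when running single-process."""
    return int(os.environ.get("LOCAL_RANK", "-1"))  # assumption: launcher-provided fallback
```
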
@@ -75,7 +76,7 @@ def save_rank_model(model, folder_prefix="", **kwargs):
     """Save state_dict for model from each rank."""
     # workaround for [SW-199005] [HQT] casted fp8 tensor cannot get data pointer
     cur_accelerator.synchronize()
-    save_directory = add_rank_suffix(folder_prefix, local_rank, world_size)
+    save_directory = add_rank_suffix(folder_prefix, get_local_rank(), get_world_size())
     os.makedirs(save_directory, exist_ok=True)
     safe_serialization = kwargs.get("safe_serialization", True)
     max_shard_size = kwargs.get("max_shard_size", f"{MAX_FILE_SIZE}GB")
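
`save_rank_model` now derives its per-rank output directory from the runtime rank and world size. The exact naming scheme of `add_rank_suffix` is not shown in this diff; a hypothetical sketch of such a helper, for illustration only:

```python
def add_rank_suffix(prefix: str, rank: int, world_size: int) -> str:
    # Hypothetical scheme: "<prefix>_<rank>_<world_size>", e.g. "saved_results_0_8".
    return f"{prefix}_{rank}_{world_size}"
```
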
@@ -96,6 +97,8 @@ def gather_state_dict(folder_prefix, file_name, tp_mod_list=[]):
     """Gather state_dict from files saved by each rank."""
     from safetensors.torch import load_file as safe_load_file

+    world_size = get_world_size()
+
     def _is_in_list(name, tp_mod_list):
         for tp_name in tp_mod_list:
             if tp_name in name:
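
With `world_size` looked up inside `gather_state_dict`, the gather step can walk each rank's folder and merge the shards. A simplified sketch of that loop, assuming one safetensors file per rank and the hypothetical directory naming above; the real function additionally handles tensors of tensor-parallel modules listed in `tp_mod_list`:

```python
from safetensors.torch import load_file as safe_load_file


def gather_rank_shards(folder_prefix: str, file_name: str, world_size: int) -> dict:
    """Illustrative merge of per-rank shards into a single state_dict."""
    gathered = {}
    for rank in range(world_size):
        rank_dir = f"{folder_prefix}_{rank}_{world_size}"  # hypothetical naming, see above
        gathered.update(safe_load_file(f"{rank_dir}/{file_name}"))
    return gathered
```
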
@@ -122,6 +125,7 @@ def _is_in_list(name, tp_mod_list):

 def clean_rank_files(folder_prefix, file_name=None):
     """Clean files saved by each rank after gathering."""
+    world_size = get_world_size()
     for i in range(world_size):  # TODO: assuming tp_size == world_size
         folder_name = add_rank_suffix(folder_prefix, i, world_size)
         if file_name is None:
@@ -375,6 +379,8 @@ def save(model, checkpoint_dir="saved_results", format="huggingface", **kwargs):
         checkpoint_dir (str, optional): path to checkpoint. Defaults to "saved_results".
         format (str, optional): defaults to 'huggingface'.
     """
+    world_size = get_world_size()
+    local_rank = get_local_rank()
     format = get_enum_from_format(format)
     model = process_model_for_scalar_scale(model)
     if world_size > 1:
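
Because `save()` now queries the distributed state when it is called, the same call works for single-card and multi-card runs without touching module globals. A hedged usage example; the import path is an assumption and the FP8 conversion of `model` is elided:

```python
# Assumed entry point; adjust to however `save` is exported in your install.
from neural_compressor.torch.quantization import save

# Each rank calls save(); per-rank folders are gathered into one huggingface checkpoint.
save(model, checkpoint_dir="saved_results", format="huggingface")
```
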
@@ -455,6 +461,7 @@ def load_empty_raw_model(model_name_or_path, **kwargs):
     if model is None:
         with init_empty_weights(include_buffers=False):
             model = transformers.AutoModelForCausalLM.from_config(config, torch_dtype=hp_dtype)
+    world_size = get_world_size()
     if world_size > 1:
         import deepspeed
         from neural_compressor.torch.utils import get_non_persistent_buffers, load_non_persistent_buffers
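
`load_empty_raw_model` now reads the world size right before deciding whether to shard the skeleton with DeepSpeed. For reference, a minimal single-process sketch of building such an empty-weight skeleton with `accelerate` (the DeepSpeed branch and the FP8 weight loading are omitted):

```python
import torch
import transformers
from accelerate import init_empty_weights

config = transformers.AutoConfig.from_pretrained("path/to/model")  # placeholder path
with init_empty_weights(include_buffers=False):
    # Parameters are allocated on the "meta" device; real weights are loaded afterwards.
    model = transformers.AutoModelForCausalLM.from_config(config, torch_dtype=torch.bfloat16)
```
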
@@ -604,8 +611,7 @@ def load(model_name_or_path, format="huggingface", device="hpu", **kwargs):
604
611
FP8 model.
605
612
"""
606
613
format = get_enum_from_format (format )
607
- global world_size
608
- world_size = kwargs .get ("world_size" , world_size )
614
+ world_size = kwargs .get ("world_size" , get_world_size ())
609
615
assert format == SaveLoadFormat .HUGGINGFACE , "Currently, only huggingface models are supported."
610
616
assert device in ["hpu" , "cpu" ], "Currently, only hpu & cpu device is supported for FP8 model."
611
617
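
On the load side, the `global world_size` statement disappears: callers may still force a value through the `world_size` kwarg, otherwise the runtime value from `get_world_size()` is used. A hedged usage example with a placeholder model path:

```python
from neural_compressor.torch.quantization import load  # assumed public entry point

# world_size is optional; when omitted it defaults to get_world_size() at call time.
model = load("path/to/fp8_checkpoint", format="huggingface", device="hpu", world_size=2)
```
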
@@ -781,7 +787,7 @@ def load_scale_params(model, new_scale_params):
             param.data = new_scale


-def get_new_rank_state_dict(all_rank_state_dict, model, world_size=world_size, local_rank=local_rank):
+def get_new_rank_state_dict(all_rank_state_dict, model, world_size=get_world_size(), local_rank=get_local_rank()):
     """Get new rank state_dict for world_size.

     Args:
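
One general Python detail worth keeping in mind for the new `get_new_rank_state_dict` signature: default-argument expressions are evaluated once, when the `def` statement runs, so `world_size=get_world_size()` captures the value at module import time; callers that need the live values can pass them explicitly. A tiny illustration of the difference:

```python
def snapshot_default(world_size=get_world_size()):  # evaluated once, at definition time
    return world_size


def live_default(world_size=None):  # resolved on every call
    return get_world_size() if world_size is None else world_size
```
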