From 3183a73096cd52d11b746313f1aaebb62d4699fc Mon Sep 17 00:00:00 2001 From: Karol Damaszke Date: Fri, 18 Jul 2025 10:16:41 +0200 Subject: [PATCH] [SW-234750] Fix reading distributed data in quant_config (#284) --- .../fp8_quant/_quant_common/quant_config.py | 9 ++---- .../fp8_quant/prepare_quant/prepare_model.py | 32 ++++++++++++++++--- .../torch/algorithms/fp8_quant/save_load.py | 16 +++++++--- .../quantization/fp8_quant/test_save_load.py | 6 +++- 4 files changed, 47 insertions(+), 16 deletions(-) diff --git a/neural_compressor/torch/algorithms/fp8_quant/_quant_common/quant_config.py b/neural_compressor/torch/algorithms/fp8_quant/_quant_common/quant_config.py index b1bec97424f..689d6e6800b 100644 --- a/neural_compressor/torch/algorithms/fp8_quant/_quant_common/quant_config.py +++ b/neural_compressor/torch/algorithms/fp8_quant/_quant_common/quant_config.py @@ -25,15 +25,10 @@ from neural_compressor.torch.utils.auto_accelerator import auto_detect_accelerator, INCAcceleratorType from ..utils.logger import logger +from ..prepare_quant.prepare_model import get_world_size, get_local_rank from .._core.scale_methods.scale_method_parser import parse_scale_method, validate_and_populate_scale_method, convert_scale_method_strings_to_enum from .._core.scale_methods.scale_method_config import get_scale_method_from_config, check_scale_method_fields, ScaleMethodString, CfgStr, ScaleGranularity, ScaleValueType, ScaleRoundMethod -try: - world_size = torch.distributed.get_world_size() - local_rank = torch.distributed.get_rank() -except: - local_rank = int(os.getenv("LOCAL_RANK", "-1")) - world_size = int(os.getenv("WORLD_SIZE", "-1")) class QuantMode(Enum): NONE = 0 @@ -153,6 +148,8 @@ class Fp8cfg: cfg: Mapping[str, Any] def parse(custom_config: Mapping[str, str]) -> Fp8cfg: + world_size = get_world_size() + local_rank = get_local_rank() measured_global_config = { "dump_stats_path": "stats", "fp8_config": torch.float8_e4m3fn, # The parameters of the chosen Quantization methed diff --git a/neural_compressor/torch/algorithms/fp8_quant/prepare_quant/prepare_model.py b/neural_compressor/torch/algorithms/fp8_quant/prepare_quant/prepare_model.py index 2ab2e0b0260..be5bb8e247e 100644 --- a/neural_compressor/torch/algorithms/fp8_quant/prepare_quant/prepare_model.py +++ b/neural_compressor/torch/algorithms/fp8_quant/prepare_quant/prepare_model.py @@ -13,14 +13,34 @@ # limitations under the License. import os +import torch from typing import Optional -from .._core.save_measure import save_measurements -from .._core.utils import prepare_model -from .._quant_common.quant_config import Fp8cfg, _read_config_from_file, set_hqt_config +_world_size = -1 +_local_rank = -1 + + +def get_world_size(): + global _world_size + if _world_size == -1: + if torch.distributed.is_initialized(): + _world_size = torch.distributed.get_world_size() + return _world_size + + +def get_local_rank(): + global _local_rank + if _local_rank == -1: + if torch.distributed.is_initialized(): + _local_rank = torch.distributed.get_rank() + return _local_rank + + +def _prep_model_with_predefined_config(model, *, config): + from .._core.utils import prepare_model + from .._quant_common.quant_config import set_hqt_config -def _prep_model_with_predefined_config(model, *, config: Fp8cfg): set_hqt_config(model, config) prepare_model(model) @@ -31,6 +51,8 @@ def prep_model(model, config_path: Optional[str] = None): If `config_path` is not given or `None`, instead perform the legacy behavior of checking for env variable `QUANT_CONFIG`. """ + from .._quant_common.quant_config import Fp8cfg, _read_config_from_file + if config_path is None: config_path = os.getenv("QUANT_CONFIG") if config_path is None: @@ -44,4 +66,6 @@ def prep_model(model, config_path: Optional[str] = None): def finish_measurements(model): + from .._core.save_measure import save_measurements + save_measurements(model) diff --git a/neural_compressor/torch/algorithms/fp8_quant/save_load.py b/neural_compressor/torch/algorithms/fp8_quant/save_load.py index 3b16dcc41ea..df599a87009 100644 --- a/neural_compressor/torch/algorithms/fp8_quant/save_load.py +++ b/neural_compressor/torch/algorithms/fp8_quant/save_load.py @@ -17,7 +17,7 @@ import torch -from ._quant_common.quant_config import local_rank, world_size, HpDtype +from ._quant_common.quant_config import HpDtype from ._core.quant_dequant import QuantDequantBase from ._core.scale_handler import update_state_dict_method, ScaleFormat from ._core.quantized_func_wrappers import ( @@ -26,6 +26,7 @@ get_quantized_func_wrapper, OP_TYPE, ) +from .prepare_quant.prepare_model import get_world_size, get_local_rank from .utils.logger import logger from neural_compressor.common import options from neural_compressor.torch.utils import ( @@ -75,7 +76,7 @@ def save_rank_model(model, folder_prefix="", **kwargs): """Save state_dict for model from each rank.""" # workaround for [SW-199005] [HQT] casted fp8 tensor cannot get data pointer cur_accelerator.synchronize() - save_directory = add_rank_suffix(folder_prefix, local_rank, world_size) + save_directory = add_rank_suffix(folder_prefix, get_local_rank(), get_world_size()) os.makedirs(save_directory, exist_ok=True) safe_serialization = kwargs.get("safe_serialization", True) max_shard_size = kwargs.get("max_shard_size", f"{MAX_FILE_SIZE}GB") @@ -96,6 +97,8 @@ def gather_state_dict(folder_prefix, file_name, tp_mod_list=[]): """Gather state_dict from files saved by each rank.""" from safetensors.torch import load_file as safe_load_file + world_size = get_world_size() + def _is_in_list(name, tp_mod_list): for tp_name in tp_mod_list: if tp_name in name: @@ -122,6 +125,7 @@ def _is_in_list(name, tp_mod_list): def clean_rank_files(folder_prefix, file_name=None): """Clean files saved by each rank after gathering.""" + world_size = get_world_size() for i in range(world_size): # TODO: assuming tp_size == world_size folder_name = add_rank_suffix(folder_prefix, i, world_size) if file_name is None: @@ -375,6 +379,8 @@ def save(model, checkpoint_dir="saved_results", format="huggingface", **kwargs): checkpoint_dir (str, optional): path to checkpoint. Defaults to "saved_results". format (str, optional): defaults to 'huggingface'. """ + world_size = get_world_size() + local_rank = get_local_rank() format = get_enum_from_format(format) model = process_model_for_scalar_scale(model) if world_size > 1: @@ -455,6 +461,7 @@ def load_empty_raw_model(model_name_or_path, **kwargs): if model is None: with init_empty_weights(include_buffers=False): model = transformers.AutoModelForCausalLM.from_config(config, torch_dtype=hp_dtype) + world_size = get_world_size() if world_size > 1: import deepspeed from neural_compressor.torch.utils import get_non_persistent_buffers, load_non_persistent_buffers @@ -604,8 +611,7 @@ def load(model_name_or_path, format="huggingface", device="hpu", **kwargs): FP8 model. """ format = get_enum_from_format(format) - global world_size - world_size = kwargs.get("world_size", world_size) + world_size = kwargs.get("world_size", get_world_size()) assert format == SaveLoadFormat.HUGGINGFACE, "Currently, only huggingface models are supported." assert device in ["hpu", "cpu"], "Currently, only hpu & cpu device is supported for FP8 model." @@ -781,7 +787,7 @@ def load_scale_params(model, new_scale_params): param.data = new_scale -def get_new_rank_state_dict(all_rank_state_dict, model, world_size=world_size, local_rank=local_rank): +def get_new_rank_state_dict(all_rank_state_dict, model, world_size=get_world_size(), local_rank=get_local_rank()): """Get new rank state_dict for world_size. Args: diff --git a/test/3x/torch/quantization/fp8_quant/test_save_load.py b/test/3x/torch/quantization/fp8_quant/test_save_load.py index 48eeb147f80..344ed40a0fb 100644 --- a/test/3x/torch/quantization/fp8_quant/test_save_load.py +++ b/test/3x/torch/quantization/fp8_quant/test_save_load.py @@ -4,7 +4,7 @@ import torch import transformers -from neural_compressor.torch.algorithms.fp8_quant._quant_common.quant_config import local_rank, world_size +from neural_compressor.torch.algorithms.fp8_quant.prepare_quant.prepare_model import get_world_size, get_local_rank from neural_compressor.torch.quantization import FP8Config, convert, load, prepare, save from neural_compressor.torch.algorithms.fp8_quant._quant_common.helper_modules import PatchedLinear from neural_compressor.torch.utils import get_used_hpu_mem_MB @@ -45,6 +45,7 @@ def calib_func(model): def test_save_vllm_compatible_model(): name = "Qwen/Qwen2-0.5B-Instruct" + world_size = get_world_size() if world_size > 0: # Do not use random weights since multi-processes will get different weights for Embedding model = transformers.AutoModelForCausalLM.from_pretrained(name) @@ -77,6 +78,7 @@ def test_save_vllm_compatible_model(): @pytest.mark.skip(reason="[SW-226589] Skip this test since the model was updated") def test_load_model_provided_by_neuralmagic(): + world_size = get_world_size() model_name_or_path = "neuralmagic/Qwen2-0.5B-Instruct-FP8" hpu_mem0 = get_used_hpu_mem_MB() model = load(model_name_or_path, format="huggingface", device="hpu") @@ -117,6 +119,8 @@ def init_model(world_size): @torch.no_grad() @pytest.mark.parametrize("scale_method", ["maxabs_hw", "act_maxabs_hw_weights_pcs_maxabs_pow2"]) def test_default_save_load(scale_method): + world_size = get_world_size() + local_rank = get_local_rank() example_inputs = torch.tensor([[10, 20]], dtype=torch.long).to("hpu") model = init_model(world_size) # The default value of model.generation_config.max_length in transformers is 20