
[SW-234750] Fix reading distributed data in quant_config (#284) #2255

Merged: 1 commit, Jul 24, 2025
neural_compressor/torch/algorithms/fp8_quant/_quant_common/quant_config.py
@@ -25,15 +25,10 @@

from neural_compressor.torch.utils.auto_accelerator import auto_detect_accelerator, INCAcceleratorType
from ..utils.logger import logger
from ..prepare_quant.prepare_model import get_world_size, get_local_rank
from .._core.scale_methods.scale_method_parser import parse_scale_method, validate_and_populate_scale_method, convert_scale_method_strings_to_enum
from .._core.scale_methods.scale_method_config import get_scale_method_from_config, check_scale_method_fields, ScaleMethodString, CfgStr, ScaleGranularity, ScaleValueType, ScaleRoundMethod

try:
    world_size = torch.distributed.get_world_size()
    local_rank = torch.distributed.get_rank()
except:
    local_rank = int(os.getenv("LOCAL_RANK", "-1"))
    world_size = int(os.getenv("WORLD_SIZE", "-1"))

class QuantMode(Enum):
    NONE = 0
@@ -153,6 +148,8 @@ class Fp8cfg:
    cfg: Mapping[str, Any]

    def parse(custom_config: Mapping[str, str]) -> Fp8cfg:
        world_size = get_world_size()
        local_rank = get_local_rank()
        measured_global_config = {
            "dump_stats_path": "stats",
            "fp8_config": torch.float8_e4m3fn,  # The parameters of the chosen Quantization method
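Note on the hunk above: the removed block read the distributed state once, at import time of `quant_config`, inside a bare `try/except`; any process that initialized `torch.distributed` after importing the module was stuck with the environment-variable fallback. The added import routes the lookup through the new `get_world_size()`/`get_local_rank()` helpers so `Fp8cfg.parse` reads the values when it actually runs. A minimal, self-contained sketch of the difference (the `demo_` name is illustrative, not part of the patch):

```python
import os
import torch

# Import-time read (the removed pattern): the value is frozen at import,
# even if torch.distributed.init_process_group() is called afterwards.
try:
    EAGER_WORLD_SIZE = torch.distributed.get_world_size()
except Exception:
    EAGER_WORLD_SIZE = int(os.getenv("WORLD_SIZE", "-1"))


def demo_lazy_world_size() -> int:
    """Call-time read (the added pattern): reflects the current distributed state."""
    if torch.distributed.is_initialized():
        return torch.distributed.get_world_size()
    return -1


print(EAGER_WORLD_SIZE, demo_lazy_world_size())
```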
neural_compressor/torch/algorithms/fp8_quant/prepare_quant/prepare_model.py
@@ -13,14 +13,34 @@
# limitations under the License.

import os
import torch
from typing import Optional

from .._core.save_measure import save_measurements
from .._core.utils import prepare_model
from .._quant_common.quant_config import Fp8cfg, _read_config_from_file, set_hqt_config

_world_size = -1
_local_rank = -1


def get_world_size():
    global _world_size
    if _world_size == -1:
        if torch.distributed.is_initialized():
            _world_size = torch.distributed.get_world_size()
    return _world_size


def get_local_rank():
    global _local_rank
    if _local_rank == -1:
        if torch.distributed.is_initialized():
            _local_rank = torch.distributed.get_rank()
    return _local_rank


def _prep_model_with_predefined_config(model, *, config):
    from .._core.utils import prepare_model
    from .._quant_common.quant_config import set_hqt_config

def _prep_model_with_predefined_config(model, *, config: Fp8cfg):
    set_hqt_config(model, config)
    prepare_model(model)

@@ -31,6 +51,8 @@ def prep_model(model, config_path: Optional[str] = None):
    If `config_path` is not given or `None`,
    instead perform the legacy behavior of checking for env variable `QUANT_CONFIG`.
    """
    from .._quant_common.quant_config import Fp8cfg, _read_config_from_file

    if config_path is None:
        config_path = os.getenv("QUANT_CONFIG")
    if config_path is None:
@@ -44,4 +66,6 @@


def finish_measurements(model):
    from .._core.save_measure import save_measurements

    save_measurements(model)
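The helpers above cache the first successful read in module-level globals and keep returning the -1 sentinel until the default process group exists, so repeated calls stay cheap and never raise. Moving the `Fp8cfg`/`prepare_model`/`save_measurements` imports into the function bodies appears intended to let `quant_config` import these helpers without an import cycle. A hedged usage sketch (the `gloo` backend and launcher environment are assumptions, not part of the patch, and it assumes the `fp8_quant` package imports cleanly on the host):

```python
import torch.distributed as dist

from neural_compressor.torch.algorithms.fp8_quant.prepare_quant.prepare_model import (
    get_local_rank,
    get_world_size,
)

print(get_world_size(), get_local_rank())  # -1 -1: process group not initialized yet

# Under a launcher such as torchrun, RANK/WORLD_SIZE/MASTER_ADDR are already set.
dist.init_process_group(backend="gloo")
print(get_world_size(), get_local_rank())  # real values, cached for later calls
```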
16 changes: 11 additions & 5 deletions neural_compressor/torch/algorithms/fp8_quant/save_load.py
@@ -17,7 +17,7 @@

import torch

from ._quant_common.quant_config import local_rank, world_size, HpDtype
from ._quant_common.quant_config import HpDtype
from ._core.quant_dequant import QuantDequantBase
from ._core.scale_handler import update_state_dict_method, ScaleFormat
from ._core.quantized_func_wrappers import (
@@ -26,6 +26,7 @@
    get_quantized_func_wrapper,
    OP_TYPE,
)
from .prepare_quant.prepare_model import get_world_size, get_local_rank
from .utils.logger import logger
from neural_compressor.common import options
from neural_compressor.torch.utils import (
@@ -75,7 +76,7 @@ def save_rank_model(model, folder_prefix="", **kwargs):
    """Save state_dict for model from each rank."""
    # workaround for [SW-199005] [HQT] casted fp8 tensor cannot get data pointer
    cur_accelerator.synchronize()
    save_directory = add_rank_suffix(folder_prefix, local_rank, world_size)
    save_directory = add_rank_suffix(folder_prefix, get_local_rank(), get_world_size())
    os.makedirs(save_directory, exist_ok=True)
    safe_serialization = kwargs.get("safe_serialization", True)
    max_shard_size = kwargs.get("max_shard_size", f"{MAX_FILE_SIZE}GB")
@@ -96,6 +97,8 @@ def gather_state_dict(folder_prefix, file_name, tp_mod_list=[]):
    """Gather state_dict from files saved by each rank."""
    from safetensors.torch import load_file as safe_load_file

    world_size = get_world_size()

    def _is_in_list(name, tp_mod_list):
        for tp_name in tp_mod_list:
            if tp_name in name:
@@ -122,6 +125,7 @@ def _is_in_list(name, tp_mod_list):

def clean_rank_files(folder_prefix, file_name=None):
    """Clean files saved by each rank after gathering."""
    world_size = get_world_size()
    for i in range(world_size):  # TODO: assuming tp_size == world_size
        folder_name = add_rank_suffix(folder_prefix, i, world_size)
        if file_name is None:
@@ -375,6 +379,8 @@ def save(model, checkpoint_dir="saved_results", format="huggingface", **kwargs):
        checkpoint_dir (str, optional): path to checkpoint. Defaults to "saved_results".
        format (str, optional): defaults to 'huggingface'.
    """
    world_size = get_world_size()
    local_rank = get_local_rank()
    format = get_enum_from_format(format)
    model = process_model_for_scalar_scale(model)
    if world_size > 1:
@@ -455,6 +461,7 @@ def load_empty_raw_model(model_name_or_path, **kwargs):
    if model is None:
        with init_empty_weights(include_buffers=False):
            model = transformers.AutoModelForCausalLM.from_config(config, torch_dtype=hp_dtype)
    world_size = get_world_size()
    if world_size > 1:
        import deepspeed
        from neural_compressor.torch.utils import get_non_persistent_buffers, load_non_persistent_buffers
@@ -604,8 +611,7 @@ def load(model_name_or_path, format="huggingface", device="hpu", **kwargs):
        FP8 model.
    """
    format = get_enum_from_format(format)
    global world_size
    world_size = kwargs.get("world_size", world_size)
    world_size = kwargs.get("world_size", get_world_size())
    assert format == SaveLoadFormat.HUGGINGFACE, "Currently, only huggingface models are supported."
    assert device in ["hpu", "cpu"], "Currently, only hpu & cpu device is supported for FP8 model."

@@ -781,7 +787,7 @@ def load_scale_params(model, new_scale_params):
            param.data = new_scale


def get_new_rank_state_dict(all_rank_state_dict, model, world_size=world_size, local_rank=local_rank):
def get_new_rank_state_dict(all_rank_state_dict, model, world_size=get_world_size(), local_rank=get_local_rank()):
    """Get new rank state_dict for world_size.

    Args:
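With the module-level `local_rank`/`world_size` constants gone, the save/load paths now call `get_world_size()`/`get_local_rank()` where they need the values, and `load()` resolves `world_size` from its kwargs with the helper as the fallback instead of mutating a module global. A hedged call sketch (the model identifier is a placeholder, not from the patch):

```python
from neural_compressor.torch.quantization import load

# world_size falls back to get_world_size() when not supplied; passing it
# explicitly overrides auto-detection without touching any module-level state.
model = load("org/model-fp8", format="huggingface", device="cpu", world_size=1)
```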
6 changes: 5 additions & 1 deletion test/3x/torch/quantization/fp8_quant/test_save_load.py
@@ -4,7 +4,7 @@
import torch
import transformers

from neural_compressor.torch.algorithms.fp8_quant._quant_common.quant_config import local_rank, world_size
from neural_compressor.torch.algorithms.fp8_quant.prepare_quant.prepare_model import get_world_size, get_local_rank
from neural_compressor.torch.quantization import FP8Config, convert, load, prepare, save
from neural_compressor.torch.algorithms.fp8_quant._quant_common.helper_modules import PatchedLinear
from neural_compressor.torch.utils import get_used_hpu_mem_MB
@@ -45,6 +45,7 @@ def calib_func(model):

def test_save_vllm_compatible_model():
    name = "Qwen/Qwen2-0.5B-Instruct"
    world_size = get_world_size()
    if world_size > 0:
        # Do not use random weights since multi-processes will get different weights for Embedding
        model = transformers.AutoModelForCausalLM.from_pretrained(name)
@@ -77,6 +78,7 @@ def test_save_vllm_compatible_model():

@pytest.mark.skip(reason="[SW-226589] Skip this test since the model was updated")
def test_load_model_provided_by_neuralmagic():
    world_size = get_world_size()
    model_name_or_path = "neuralmagic/Qwen2-0.5B-Instruct-FP8"
    hpu_mem0 = get_used_hpu_mem_MB()
    model = load(model_name_or_path, format="huggingface", device="hpu")
@@ -117,6 +119,8 @@ def init_model(world_size):
@torch.no_grad()
@pytest.mark.parametrize("scale_method", ["maxabs_hw", "act_maxabs_hw_weights_pcs_maxabs_pow2"])
def test_default_save_load(scale_method):
    world_size = get_world_size()
    local_rank = get_local_rank()
    example_inputs = torch.tensor([[10, 20]], dtype=torch.long).to("hpu")
    model = init_model(world_size)
    # The default value of model.generation_config.max_length in transformers is 20
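The test updates follow the same pattern: rank information is fetched inside each test rather than at import. A small illustrative sketch in the same spirit (this test is not part of the patch):

```python
import pytest

from neural_compressor.torch.algorithms.fp8_quant.prepare_quant.prepare_model import (
    get_local_rank,
    get_world_size,
)


def test_rank_helpers_return_sentinel_without_distributed():
    # In a plain single-process pytest run the process group is never
    # initialized, so both helpers keep returning the -1 sentinel.
    if get_world_size() != -1:
        pytest.skip("running under a distributed launcher")
    assert get_local_rank() == -1
```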