Skip to content

convert : handle pre-quantized models #14810

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
239 changes: 174 additions & 65 deletions convert_hf_to_gguf.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,10 +64,8 @@ class ModelBase:
endianess: gguf.GGUFEndian
use_temp_file: bool
lazy: bool
part_names: list[str]
is_safetensors: bool
hparams: dict[str, Any]
tensor_names: set[str] | None
model_tensors: dict[str, Callable[[], Tensor]]
gguf_writer: gguf.GGUFWriter
model_name: str | None
metadata_override: Path | None
Expand Down Expand Up @@ -99,24 +97,8 @@ def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path,
self.use_temp_file = use_temp_file
self.lazy = not eager or (remote_hf_model_id is not None)
self.remote_hf_model_id = remote_hf_model_id
if remote_hf_model_id is not None:
self.is_safetensors = True

def get_remote_tensors() -> Iterator[tuple[str, Tensor]]:
logger.info(f"Using remote model with HuggingFace id: {remote_hf_model_id}")
remote_tensors = gguf.utility.SafetensorRemote.get_list_tensors_hf_model(remote_hf_model_id)
self.tensor_names = set(name for name in remote_tensors.keys())
for name, remote_tensor in gguf.utility.SafetensorRemote.get_list_tensors_hf_model(remote_hf_model_id).items():
yield (name, LazyTorchTensor.from_remote_tensor(remote_tensor))

self.get_tensors = get_remote_tensors
else:
self.part_names = ModelBase.get_model_part_names(self.dir_model, "model", ".safetensors")
self.is_safetensors = len(self.part_names) > 0
if not self.is_safetensors:
self.part_names = ModelBase.get_model_part_names(self.dir_model, "pytorch_model", ".bin")
self.hparams = ModelBase.load_hparams(self.dir_model) if hparams is None else hparams
self.tensor_names = None
self.model_tensors = self.index_tensors(remote_hf_model_id=remote_hf_model_id)
self.metadata_override = metadata_override
self.model_name = model_name
self.dir_model_card = dir_model # overridden in convert_lora_to_gguf.py
Expand All @@ -132,6 +114,8 @@ def get_remote_tensors() -> Iterator[tuple[str, Tensor]]:
logger.info(f"choosing --outtype bf16 from first tensor type ({first_tensor.dtype})")
self.ftype = gguf.LlamaFileType.MOSTLY_BF16

self.dequant_model()

# Configure GGUF Writer
self.gguf_writer = gguf.GGUFWriter(path=None, arch=gguf.MODEL_ARCH_NAMES[self.model_arch], endianess=self.endianess, use_temp_file=self.use_temp_file,
split_max_tensors=split_max_tensors, split_max_size=split_max_size, dry_run=dry_run, small_first_shard=small_first_shard)
Expand All @@ -150,63 +134,209 @@ def find_hparam(self, keys: Iterable[str], optional: bool = False) -> Any:
return None
raise KeyError(f"could not find any of: {keys}")

def get_tensors(self) -> Iterator[tuple[str, Tensor]]:
tensor_names_from_parts: set[str] = set()
def index_tensors(self, remote_hf_model_id: str | None = None) -> dict[str, Callable[[], Tensor]]:
tensors: dict[str, Callable[[], Tensor]] = {}

if remote_hf_model_id is not None:
is_safetensors = True

index_name = "model.safetensors" if self.is_safetensors else "pytorch_model.bin"
logger.info(f"Using remote model with HuggingFace id: {remote_hf_model_id}")
remote_tensors = gguf.utility.SafetensorRemote.get_list_tensors_hf_model(remote_hf_model_id)
for name, remote_tensor in remote_tensors.items():
tensors[name] = lambda r=remote_tensor: LazyTorchTensor.from_remote_tensor(r)

return tensors

part_names: list[str] = ModelBase.get_model_part_names(self.dir_model, "model", ".safetensors")
is_safetensors: bool = len(part_names) > 0
if not is_safetensors:
part_names = ModelBase.get_model_part_names(self.dir_model, "pytorch_model", ".bin")

tensor_names_from_index: set[str] = set()

index_name = "model.safetensors" if is_safetensors else "pytorch_model.bin"
index_name += ".index.json"
index_file = self.dir_model / index_name

if index_file.is_file():
self.tensor_names = set()
logger.info(f"gguf: loading model weight map from '{index_name}'")
with open(index_file, "r", encoding="utf-8") as f:
index: dict[str, Any] = json.load(f)
weight_map = index.get("weight_map")
if weight_map is None or not isinstance(weight_map, dict):
raise ValueError(f"Can't load 'weight_map' from {index_name!r}")
self.tensor_names.update(weight_map.keys())
tensor_names_from_index.update(weight_map.keys())
else:
self.tensor_names = tensor_names_from_parts
weight_map = {}

for part_name in self.part_names:
logger.info(f"gguf: loading model part '{part_name}'")
for part_name in part_names:
logger.info(f"gguf: indexing model part '{part_name}'")
ctx: ContextManager[Any]
if self.is_safetensors:
if is_safetensors:
from safetensors import safe_open
ctx = cast(ContextManager[Any], safe_open(self.dir_model / part_name, framework="pt", device="cpu"))
else:
ctx = contextlib.nullcontext(torch.load(str(self.dir_model / part_name), map_location="cpu", mmap=True, weights_only=True))

with ctx as model_part:
tensor_names_from_parts.update(model_part.keys())
assert model_part is not None

for name in model_part.keys():
if self.is_safetensors:
if is_safetensors:
if self.lazy:
data = model_part.get_slice(name)
data = LazyTorchTensor.from_safetensors_slice(data)
data_gen = lambda data=data: LazyTorchTensor.from_safetensors_slice(data) # noqa: E731
else:
data = model_part.get_tensor(name)
data_gen = lambda data=data: data # noqa: E731
else:
data = model_part[name]
if self.lazy:
data = LazyTorchTensor.from_eager(data)
yield name, data
data_gen = lambda data=data: LazyTorchTensor.from_eager(data) # noqa: E731
else:
data_gen = lambda data=data: data # noqa: E731
tensors[name] = data_gen

# verify tensor name presence and identify potentially missing files
if len(tensor_names_from_parts.symmetric_difference(self.tensor_names)) > 0:
missing = sorted(self.tensor_names.difference(tensor_names_from_parts))
extra = sorted(tensor_names_from_parts.difference(self.tensor_names))
missing_files = sorted(set(weight_map[n] for n in missing if n in weight_map))
if len(extra) == 0 and len(missing_files) > 0:
raise ValueError(f"Missing or incomplete model files: {missing_files}\n"
f"Missing tensors: {missing}")
if len(tensor_names_from_index) > 0:
tensor_names_from_parts = set(tensors.keys())
if len(tensor_names_from_parts.symmetric_difference(tensor_names_from_index)) > 0:
missing = sorted(tensor_names_from_index.difference(tensor_names_from_parts))
extra = sorted(tensor_names_from_parts.difference(tensor_names_from_index))
missing_files = sorted(set(weight_map[n] for n in missing if n in weight_map))
if len(extra) == 0 and len(missing_files) > 0:
raise ValueError(f"Missing or incomplete model files: {missing_files}\n"
f"Missing tensors: {missing}")
else:
raise ValueError("Mismatch between weight map and model parts for tensor names:\n"
f"Missing tensors: {missing}\n"
f"Extra tensors: {extra}")

return tensors

def dequant_model(self):
tensors_to_remove: list[str] = []
new_tensors: dict[str, Callable[[], Tensor]] = {}

if (quant_config := self.hparams.get("quantization_config")) and isinstance(quant_config, dict):
quant_method = quant_config.get("quant_method")

def dequant_bitnet(weight: Tensor, scale: Tensor) -> Tensor:
weight = weight.view(torch.uint8)
orig_shape = weight.shape

shift = torch.tensor([0, 2, 4, 6], dtype=torch.uint8).reshape((4, *(1 for _ in range(len(orig_shape)))))
data = weight.unsqueeze(0).expand((4, *orig_shape)) >> shift
data = data & 3
data = (data.float() - 1).reshape((orig_shape[0] * 4, *orig_shape[1:]))

# The scale is inverted
return data / scale.float()

def dequant_simple(weight: Tensor, scale: Tensor) -> Tensor:
scale = scale.float()

if (weight_block_size := quant_config.get("weight_block_size")):
# TODO: make sure it's a list of integers
for i, size in enumerate(weight_block_size):
scale = scale.repeat_interleave(size, i)

return weight.float() * scale

# ref: https://github.com/ModelCloud/GPTQModel/blob/037c5c0f6c9e33c500d975b038d02e7ca437546d/gptqmodel/nn_modules/qlinear/__init__.py#L437-L476
def dequant_gptq(g_idx: Tensor, qweight: Tensor, qzeros: Tensor, scales: Tensor) -> Tensor:
bits = quant_config["bits"]
assert bits in (2, 3, 4, 8)
assert qweight.dtype == qzeros.dtype
maxq = (2 ** bits) - 1
weight = None
zeros = None
pack_dtype_bits = qweight.dtype.itemsize * 8

if bits in [2, 4, 8]:
pack_factor = pack_dtype_bits // bits
wf = torch.tensor(list(range(0, pack_dtype_bits, bits)), dtype=torch.int32).unsqueeze(0)
if self.lazy:
wf = LazyTorchTensor.from_eager(wf)

zeros = torch.bitwise_right_shift(
qzeros.unsqueeze(2).expand(-1, -1, pack_factor),
wf.unsqueeze(0)
).to(torch.int16 if bits == 8 else torch.int8)
zeros = torch.bitwise_and(zeros, maxq).reshape(scales.shape)

weight = torch.bitwise_and(
torch.bitwise_right_shift(
qweight.unsqueeze(1).expand(-1, pack_factor, -1),
wf.unsqueeze(-1)
).to(torch.int16 if bits == 8 else torch.int8),
maxq
)
elif bits == 3:
raise NotImplementedError("3-bit gptq dequantization is not yet implemented")

assert weight is not None
assert zeros is not None

weight = weight.reshape(weight.shape[0] * weight.shape[1], weight.shape[2])

# gptq_v2 doesn't need to offset zeros
if quant_config.get("checkpoint_format", "gptq") == "gptq":
zeros += 1

return (scales[g_idx].float() * (weight - zeros[g_idx]).float()).T

if quant_method == "bitnet":
for name in self.model_tensors.keys():
if name.endswith(".weight_scale"):
weight_name = name.removesuffix("_scale")
w = self.model_tensors[weight_name]
s = self.model_tensors[name]
self.model_tensors[weight_name] = lambda w=w, s=s: dequant_bitnet(w(), s())
tensors_to_remove.append(name)
elif quant_method == "fp8":
for name in self.model_tensors.keys():
if name.endswith(".weight_scale_inv"):
weight_name = name.removesuffix("_scale_inv")
w = self.model_tensors[weight_name]
s = self.model_tensors[name]
self.model_tensors[weight_name] = lambda w=w, s=s: dequant_simple(w(), s())
tensors_to_remove.append(name)
elif quant_method == "gptq":
for name in self.model_tensors.keys():
if name.endswith(".qweight"):
base_name = name.removesuffix(".qweight")
g_idx = self.model_tensors[base_name + ".g_idx"]
qweight = self.model_tensors[base_name + ".qweight"]
qzeros = self.model_tensors[base_name + ".qzeros"]
scales = self.model_tensors[base_name + ".scales"]
new_tensors[base_name + ".weight"] = (
lambda g=g_idx, z=qzeros, w=qweight, s=scales: dequant_gptq(
g(), w(), z(), s()
)
)
tensors_to_remove += [
base_name + n
for n in (
".g_idx",
".qzeros",
".qweight",
".scales",
)
]
else:
raise ValueError("Mismatch between weight map and model parts for tensor names:\n"
f"Missing tensors: {missing}\n"
f"Extra tensors: {extra}")
raise NotImplementedError(f"Quant method is not yet supported: {quant_method!r}")

for name in tensors_to_remove:
if name in self.model_tensors:
del self.model_tensors[name]

for name, value in new_tensors.items():
self.model_tensors[name] = value

def get_tensors(self) -> Iterator[tuple[str, Tensor]]:
for name, gen in self.model_tensors.items():
yield name, gen()

def format_tensor_name(self, key: gguf.MODEL_TENSOR, bid: int | None = None, suffix: str = ".weight") -> str:
if key not in gguf.MODEL_TENSORS[self.model_arch]:
Expand Down Expand Up @@ -3860,27 +3990,6 @@ def set_gguf_parameters(self):
self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR)
self.gguf_writer.add_rope_scaling_factor(1.0)

_has_tok_embd = False

def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
del bid # unused

output_name = self.format_tensor_name(gguf.MODEL_TENSOR.OUTPUT)
tok_embd_name = self.format_tensor_name(gguf.MODEL_TENSOR.TOKEN_EMBD)

new_name = self.map_tensor_name(name)

# assuming token_embd.weight is seen before output.weight
if not self._has_tok_embd and new_name == self.format_tensor_name(gguf.MODEL_TENSOR.OUTPUT):
# even though the tensor file(s) does not contain the word embeddings they are still in the weight map
if self.tensor_names and "transformer.wte.weight" in self.tensor_names:
logger.debug(f"{tok_embd_name} not found before {output_name}, assuming they are tied")
self.tensor_names.remove("transformer.wte.weight")
elif new_name == tok_embd_name:
self._has_tok_embd = True

return [(new_name, data_torch)]


@ModelBase.register("InternLM2ForCausalLM")
class InternLM2Model(TextModel):
Expand Down