Skip to content
3 changes: 3 additions & 0 deletions invokeai/backend/model_manager/model_on_disk.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,9 @@ def load_state_dict(self, path: Optional[Path] = None) -> StateDict:

path = self.resolve_weight_file(path)

if path in self._state_dict_cache:
return self._state_dict_cache[path]

with SilenceWarnings():
if path.suffix.endswith((".ckpt", ".pt", ".pth", ".bin")):
scan_result = scan_file_path(path)
Expand Down
55 changes: 43 additions & 12 deletions invokeai/backend/quantization/gguf/loaders.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,53 @@
import gc
from pathlib import Path

import gguf
import torch

from invokeai.backend.quantization.gguf.ggml_tensor import GGMLTensor
from invokeai.backend.quantization.gguf.utils import TORCH_COMPATIBLE_QTYPES
from invokeai.backend.util.logging import InvokeAILogger

logger = InvokeAILogger.get_logger()


class WrappedGGUFReader:
"""Wrapper around GGUFReader that adds a close() method."""

def __init__(self, path: Path):
self.reader = gguf.GGUFReader(path)

def __enter__(self):
return self.reader

def __exit__(self, exc_type, exc_val, exc_tb):
self.close()
return False

def close(self):
"""Explicitly close the memory-mapped file."""
if hasattr(self.reader, "data"):
try:
self.reader.data.flush()
del self.reader.data
except (AttributeError, OSError, ValueError) as e:
logger.warning(f"Wasn't able to close GGUF memory map: {e}")
del self.reader
gc.collect()


def gguf_sd_loader(path: Path, compute_dtype: torch.dtype) -> dict[str, GGMLTensor]:
reader = gguf.GGUFReader(path)

sd: dict[str, GGMLTensor] = {}
for tensor in reader.tensors:
torch_tensor = torch.from_numpy(tensor.data)
shape = torch.Size(tuple(int(v) for v in reversed(tensor.shape)))
if tensor.tensor_type in TORCH_COMPATIBLE_QTYPES:
torch_tensor = torch_tensor.view(*shape)
sd[tensor.name] = GGMLTensor(
torch_tensor, ggml_quantization_type=tensor.tensor_type, tensor_shape=shape, compute_dtype=compute_dtype
)
return sd
with WrappedGGUFReader(path) as reader:
sd: dict[str, GGMLTensor] = {}
for tensor in reader.tensors:
torch_tensor = torch.from_numpy(tensor.data)
shape = torch.Size(tuple(int(v) for v in reversed(tensor.shape)))
if tensor.tensor_type in TORCH_COMPATIBLE_QTYPES:
torch_tensor = torch_tensor.view(*shape)
sd[tensor.name] = GGMLTensor(
torch_tensor,
ggml_quantization_type=tensor.tensor_type,
tensor_shape=shape,
compute_dtype=compute_dtype,
)
return sd