7 changes: 5 additions & 2 deletions bergson/attributor.py
@@ -74,7 +74,10 @@ def __init__(
self.grads[name] /= norm

def search(
self, queries: dict[str, Tensor], k: int, modules: list[str] | None = None
self,
queries: dict[str, Tensor],
k: int | None,
modules: list[str] | None = None,
) -> tuple[Tensor, Tensor]:
"""
Search for the `k` nearest examples in the index based on the query or queries.
@@ -112,7 +115,7 @@ def search(
)

modules = modules or list(q.keys())
k = min(k, self.N)
k = min(k or self.N, self.N)

scores = torch.stack(
[q[name] @ self.grads[name].mT for name in modules], dim=-1
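Note: `search` now accepts `k=None`, which `min(k or self.N, self.N)` resolves to "score every example in the index". A toy illustration of that fallback (made-up numbers, not bergson code; `N` stands in for `self.N`):

```python
# Toy illustration of the new k handling; N stands in for self.N, the index size.
N = 1_000
for k in (10, None, 5_000):
    effective_k = min(k or N, N)
    print(f"k={k} -> effective_k={effective_k}")  # 10 -> 10, None -> 1000, 5000 -> 1000
```

One subtlety of the `k or N` idiom: `k=0` is falsy, so it would also fall through to `N`.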
4 changes: 2 additions & 2 deletions bergson/collection.py
@@ -109,7 +109,7 @@ def callback(name: str, g: torch.Tensor):
if "advantage" in batch:
losses *= torch.tensor(batch["advantage"], device=losses.device)

losses.mean().backward()
losses.sum().backward()
else:
with collector:
logits = model(x).logits[:, :-1]
@@ -123,7 +123,7 @@ def callback(name: str, g: torch.Tensor):
if "advantage" in batch:
losses *= torch.tensor(batch["advantage"], device=losses.device)

losses.mean().backward()
losses.sum().backward()

# Weirdly you need to explicitly synchronize here in order to make sure that
# the nonblocking copies actually finish before we call .numpy()
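Note: switching `losses.mean().backward()` to `losses.sum().backward()` plausibly keeps each example's collected gradient at its natural scale instead of dividing it by the batch size. A toy illustration of the difference (not bergson code):

```python
import torch

# Illustrative only: with .mean(), each per-example gradient is scaled by 1/B
# (B = batch size), so the stored gradients would depend on how examples were
# batched; with .sum(), each example contributes its own unscaled gradient.
scale = torch.tensor([1.0, 2.0, 3.0])
w = torch.ones(3, requires_grad=True)

(w * scale).sum().backward()
print(w.grad)  # tensor([1., 2., 3.])  -- per-example gradients, unscaled

w.grad = None
(w * scale).mean().backward()
print(w.grad)  # tensor([0.3333, 0.6667, 1.0000])  -- each scaled by 1/3
```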
2 changes: 1 addition & 1 deletion bergson/data.py
@@ -104,7 +104,7 @@ class IndexConfig:
streaming: bool = False
"""Whether to use streaming mode for the dataset."""

stream_shard_size: int = 100_000
stream_shard_size: int = 400_000
"""Shard size for streaming the dataset into Dataset objects."""

revision: str | None = None
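Note: `stream_shard_size` controls how many streamed rows are materialized into each intermediate `Dataset` shard, and the default grows from 100k to 400k rows. A rough sketch of the chunking idea (illustrative only, not bergson's streaming code):

```python
from itertools import islice

# Illustrative only: materialize a streaming iterable into fixed-size chunks,
# mirroring the intent of `stream_shard_size`.
def iter_shards(stream, shard_size=400_000):
    it = iter(stream)
    while chunk := list(islice(it, shard_size)):
        yield chunk  # in bergson each chunk would become a `datasets.Dataset`

sizes = [len(s) for s in iter_shards(range(1_000_000))]
print(sizes)  # [400000, 400000, 200000]
```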
159 changes: 100 additions & 59 deletions bergson/faiss_index.py
@@ -1,9 +1,8 @@
import json
import math
import os
from dataclasses import dataclass
from pathlib import Path
from time import time
from time import perf_counter
from typing import Protocol

import numpy as np
@@ -91,7 +90,7 @@ def load_shard(shard_dir: str) -> np.memmap:
yield load_shard(root_dir)
else:
for shard_path in sorted(root_path.iterdir()):
if shard_path.is_dir():
if shard_path.is_dir() and "shard" in shard_path.name:
yield load_shard(str(shard_path))


@@ -124,10 +123,12 @@ def index_to_device(index: Index, device: str) -> Index:


class FaissIndex:
"""FAISS index."""
"""Shard-based FAISS index."""

shards: list[Index]

faiss_cfg: FaissConfig

def __init__(self, path: str, faiss_cfg: FaissConfig, device: str, unit_norm: bool):
try:
import faiss
@@ -145,97 +146,137 @@ def __init__(self, path: str, faiss_cfg: FaissConfig, device: str, unit_norm: bool):
f"{'_unit_norm' if unit_norm else ''}"
)
)
faiss_path.mkdir(exist_ok=True, parents=True)

if not (faiss_path.exists() and any(faiss_path.iterdir())):
if not any(faiss_path.iterdir()):
print("Building FAISS index...")
start = time()
start = perf_counter()

root_path = Path(path)
if (root_path / "info.json").exists():
info_paths = [root_path / "info.json"]
else:
info_paths = [
shard_path / "info.json"
for shard_path in sorted(root_path.iterdir())
if shard_path.is_dir() and (shard_path / "info.json").exists()
]

if not info_paths:
raise FileNotFoundError(f"No gradient metadata found under {path}")

total_grads = sum(
[json.load(open(info_path))["num_grads"] for info_path in info_paths]
)

assert faiss_cfg.num_shards <= total_grads and faiss_cfg.num_shards > 0

faiss_path.mkdir(exist_ok=True, parents=True)
# Set the number of grads for each faiss index shard
base_shard_size = total_grads // faiss_cfg.num_shards
remainder = total_grads % faiss_cfg.num_shards
shard_sizes = [base_shard_size] * (faiss_cfg.num_shards)
shard_sizes[-1] += remainder

num_dataset_shards = len(list(Path(path).iterdir()))
shards_per_index = math.ceil(num_dataset_shards / faiss_cfg.num_shards)
# Verify all gradients will be consumed
assert (
sum(shard_sizes) == total_grads
), f"Shard sizes {shard_sizes} don't sum to total_grads {total_grads}"

dl = gradients_loader(path)
buffer = []
index_idx = 0
buffer: list[NDArray] = []
buffer_size = 0
shard_idx = 0

for grads in tqdm(dl, desc="Loading gradients"):
if grads.dtype.names is not None:
grads = structured_to_unstructured(grads)
def build_shard_from_buffer(
buffer_parts: list[NDArray], shard_idx: int
) -> None:
print(f"Building shard {shard_idx}...")
grads_chunk = np.concatenate(buffer_parts, axis=0)
buffer_parts.clear()

if unit_norm:
grads = normalize_grads(grads, device, faiss_cfg.batch_size)
index = faiss.index_factory(
grads_chunk.shape[1],
faiss_cfg.index_factory,
faiss.METRIC_INNER_PRODUCT,
)
index = index_to_device(index, device)
if faiss_cfg.max_train_examples is not None:
train_examples = min(
faiss_cfg.max_train_examples, grads_chunk.shape[0]
)
else:
train_examples = grads_chunk.shape[0]
index.train(grads_chunk[:train_examples])
index.add(grads_chunk)

buffer.append(grads)
del grads_chunk

if len(buffer) == shards_per_index:
# Build index shard
print(f"Building shard {index_idx}...")
index = index_to_device(index, "cpu")
faiss.write_index(index, str(faiss_path / f"{shard_idx}.faiss"))

grads = np.concatenate(buffer, axis=0)
buffer = []
for grads in tqdm(dl, desc="Loading gradients"):
grads = structured_to_unstructured(grads)

index = faiss.index_factory(
grads.shape[1],
faiss_cfg.index_factory,
faiss.METRIC_INNER_PRODUCT,
)
index = index_to_device(index, device)
train_examples = faiss_cfg.max_train_examples or grads.shape[0]
index.train(grads[:train_examples])
index.add(grads)
if unit_norm:
grads = normalize_grads(grads, device, faiss_cfg.batch_size)

# Write index to disk
del grads
index = index_to_device(index, "cpu")
faiss.write_index(index, str(faiss_path / f"{index_idx}.faiss"))
batch_idx = 0
batch_size = grads.shape[0]
while batch_idx < batch_size and shard_idx < faiss_cfg.num_shards:
remaining_in_shard = shard_sizes[shard_idx] - buffer_size
take = min(remaining_in_shard, batch_size - batch_idx)

index_idx += 1
if take > 0:
buffer.append(grads[batch_idx : batch_idx + take])
buffer_size += take
batch_idx += take

if buffer:
grads = np.concatenate(buffer, axis=0)
buffer = []
index = faiss.index_factory(
grads.shape[1], faiss_cfg.index_factory, faiss.METRIC_INNER_PRODUCT
)
index = index_to_device(index, device)
index.train(grads)
index.add(grads)
if buffer_size == shard_sizes[shard_idx]:
build_shard_from_buffer(buffer, shard_idx)
buffer = []
buffer_size = 0
shard_idx += 1

# Write index to disk
del grads
index = index_to_device(index, "cpu")
faiss.write_index(index, str(faiss_path / f"{index_idx}.faiss"))

print(f"Built index in {(time() - start) / 60:.2f} minutes.")
del buffer, index
assert shard_idx == faiss_cfg.num_shards
print(f"Built index in {(perf_counter() - start) / 60:.2f} minutes.")

shard_paths = sorted(
(c for c in faiss_path.glob("*.faiss") if c.stem.isdigit()),
key=lambda p: int(p.stem),
)

shards = []
for i in range(faiss_cfg.num_shards):
for shard_path in shard_paths:
shard = faiss.read_index(
str(faiss_path / f"{i}.faiss"),
str(shard_path),
faiss.IO_FLAG_MMAP | faiss.IO_FLAG_READ_ONLY,
)
if not faiss_cfg.mmap_index:
shard = index_to_device(shard, device)

shards.append(shard)

if len(shards) != faiss_cfg.num_shards:
faiss_cfg.num_shards = len(shards)

self.shards = shards

def search(self, q: NDArray, k: int) -> tuple[NDArray, NDArray]:
def search(self, q: NDArray, k: int | None) -> tuple[NDArray, NDArray]:
"""Note: if fewer than `k` examples are found FAISS will return items
with the index -1 and the maximum negative distance."""
with the index -1 and the maximum negative distance. If `k` is `None`,
all examples will be returned."""
shard_distances = []
shard_indices = []
offset = 0

for index in self.shards:
index.nprobe = self.faiss_cfg.nprobe
distances, indices = index.search(q, k)
for shard in self.shards:
shard.nprobe = self.faiss_cfg.nprobe
distances, indices = shard.search(q, k or shard.ntotal)

indices += offset
offset += index.ntotal
offset += shard.ntotal

shard_distances.append(distances)
shard_indices.append(indices)
Expand All @@ -245,7 +286,7 @@ def search(self, q: NDArray, k: int) -> tuple[NDArray, NDArray]:

# Rerank results overfetched from multiple shards
if len(self.shards) > 1:
topk_indices = np.argsort(distances, axis=1)[:, :k]
topk_indices = np.argsort(distances, axis=1)[:, : k or self.ntotal]
indices = indices[np.arange(indices.shape[0])[:, None], topk_indices]
distances = distances[np.arange(distances.shape[0])[:, None], topk_indices]
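Note: the index is now built from exactly `num_shards` FAISS shards of near-equal size, with the remainder folded into the last shard, and incoming gradient batches are sliced so they can straddle a shard boundary. A worked toy example of that bookkeeping (made-up sizes, FAISS calls omitted):

```python
# 10 gradients split across 3 FAISS shards: the remainder goes to the last shard.
total_grads, num_shards = 10, 3
base, remainder = divmod(total_grads, num_shards)
shard_sizes = [base] * num_shards
shard_sizes[-1] += remainder
assert shard_sizes == [3, 3, 4] and sum(shard_sizes) == total_grads

# Incoming gradient batches are sliced across shard boundaries.
batches = [4, 4, 2]  # batch sizes yielded by the gradient loader
buffer_size, shard_idx = 0, 0
for batch in batches:
    taken = 0
    while taken < batch and shard_idx < num_shards:
        take = min(shard_sizes[shard_idx] - buffer_size, batch - taken)
        buffer_size += take
        taken += take
        if buffer_size == shard_sizes[shard_idx]:
            print(f"shard {shard_idx} full with {buffer_size} gradients")
            buffer_size, shard_idx = 0, shard_idx + 1
```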

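Note: multi-shard `search` queries each shard separately, shifts its local result ids by a running offset so they index into the global gradient store, and treats `k=None` as "return every example in the shard". A toy sketch of the offset bookkeeping (the actual FAISS calls are omitted):

```python
import numpy as np

# Two hypothetical shards with 5 and 6 vectors; each per-shard "search" returns
# local indices, which are shifted by a running offset to make them global ids.
shard_ntotals = [5, 6]
local_indices = [np.array([[2, 0, 1]]), np.array([[1, 4, 3]])]

global_indices, offset = [], 0
for ntotal, idx in zip(shard_ntotals, local_indices):
    global_indices.append(idx + offset)
    offset += ntotal

print(np.concatenate(global_indices, axis=1))  # [[2 0 1 6 9 8]]

# k=None falls back to "all examples in the shard":
k = None
print(k or shard_ntotals[0])  # 5
```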
31 changes: 23 additions & 8 deletions bergson/huggingface.py
@@ -35,10 +35,13 @@ def __init__(
accumulate_grads: bool = False,
use_optimizer_state: bool = True,
track_order: bool = False,
shard_size: int | None = 200_000,
):
"""
Args:
path: The path to save the gradients
head_cfgs: Information used to split matrix-valued parameters into
per-head matrices before down projection.
projection_dim: The dimension to project the gradients onto
dtype: The dtype of the on-disk gradient store
accumulate_grads: Whether to take the sum of the gradients
@@ -48,8 +51,8 @@ def __init__(
normalize the gradients. If `False`, no normalization is
applied.
track_order: Whether to record the shuffled order of training data.
head_cfgs: Information used to split matrix-valued parameters into
per-head matrices before down projection.
head_cfgs: Information used to split matrix-valued parameters into
per-head matrices before down projection.
"""
super().__init__()

@@ -95,7 +98,7 @@ def on_train_begin(
if not hasattr(args, "__gradient_collection_enabled__"):
raise RuntimeError(
"Gradient collection is not enabled. Please enable it by "
"calling bergson.prepare_gradient_collection on the trainer."
"calling bergson.prepare_for_gradient_collection on the trainer."
)

if isinstance(model, PeftModel):
@@ -133,7 +136,7 @@ def on_epoch_begin(
state: TrainerState,
control: TrainerControl,
*,
eval_dataloader: DataLoader | dict[str, DataLoader],
eval_dataloader: DataLoader | dict[str, DataLoader] | None,
train_dataloader: DataLoader,
**kwargs,
):
@@ -158,8 +161,16 @@ def on_epoch_begin(

# Set up the gradient buffers for the evaluation datasets
if eval_dataloader is None:
return
elif isinstance(eval_dataloader, dict):
# HF Trainer doesn't expose the evaluation dataloaders
if hasattr(args, "eval_dataset"):
eval_dataloader = DataLoader(
args.eval_dataset, batch_size=1, shuffle=False
)
else:
print("Warning: no evaluation dataloader found")
return

if isinstance(eval_dataloader, dict):
eval_datasets = eval_dataloader
else:
eval_datasets = {"eval": eval_dataloader}
@@ -302,9 +313,11 @@ def on_step_end(

proc.normalizers = normalizers

def on_evaluate(self, args, state, control, **kwargs):
print("on_evaluate")

def on_prediction_step(self, args, state, control, **kwargs):
dataset_name = kwargs["inputs"]["dataset_name"]
self.write_grads(self.eval_grad_buffers[dataset_name])
print("on_prediction_step")

def on_train_end(
self,
@@ -365,6 +378,8 @@ def prepare_for_gradient_collection(trainer: Trainer):
lambda ex, idx: {"_idx": idx}, with_indices=True
)

trainer.args.eval_dataset = trainer.eval_dataset

trainer._set_signature_columns_if_needed()
trainer._signature_columns.append("_idx")

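Note: because the HF `Trainer` does not hand its evaluation dataloaders to callbacks, `prepare_for_gradient_collection` now stashes `trainer.eval_dataset` on `trainer.args`, and the callback rebuilds a `DataLoader` from it at epoch start. A minimal, hypothetical wiring sketch based only on the names visible in this diff (`model`, `train_ds`, and `eval_ds` are placeholders assumed to be defined elsewhere):

```python
from transformers import Trainer, TrainingArguments
from bergson import prepare_for_gradient_collection

# `model`, `train_ds`, and `eval_ds` are assumed to exist already; the point
# here is only the order of the calls.
trainer = Trainer(
    model=model,
    args=TrainingArguments(output_dir="out"),
    train_dataset=train_ds,
    eval_dataset=eval_ds,
    # The bergson gradient-collection callback (its class name is not shown in
    # this diff) would also be passed in via `callbacks=[...]`.
)

# Adds an `_idx` column for order tracking, stashes `eval_dataset` on `args` so
# the callback can rebuild an eval DataLoader at epoch start, and registers
# `_idx` as a signature column so the Trainer does not drop it.
prepare_for_gradient_collection(trainer)
trainer.train()
```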