Skip to content

Commit f9c1f92

Browse files
committed
Support full-score calculation with FAISS (`k=None` returns all examples); assume the modified FAISS implementation in the induction heads script
1 parent e655ec6 commit f9c1f92

File tree

5 files changed

+155
-102
lines changed

5 files changed

+155
-102
lines changed

bergson/attributor.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,10 @@ def __init__(
7474
self.grads[name] /= norm
7575

7676
def search(
77-
self, queries: dict[str, Tensor], k: int, modules: list[str] | None = None
77+
self,
78+
queries: dict[str, Tensor],
79+
k: int | None,
80+
modules: list[str] | None = None,
7881
) -> tuple[Tensor, Tensor]:
7982
"""
8083
Search for the `k` nearest examples in the index based on the query or queries. If `k` is `None`, all examples in the index are returned.
@@ -112,7 +115,7 @@ def search(
112115
)
113116

114117
modules = modules or list(q.keys())
115-
k = min(k, self.N)
118+
k = min(k or self.N, self.N)
116119

117120
scores = torch.stack(
118121
[q[name] @ self.grads[name].mT for name in modules], dim=-1

bergson/faiss_index.py

Lines changed: 99 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,8 @@
11
import json
2-
import math
32
import os
43
from dataclasses import dataclass
54
from pathlib import Path
6-
from time import time
5+
from time import perf_counter
76
from typing import Protocol
87

98
import numpy as np
@@ -124,10 +123,12 @@ def index_to_device(index: Index, device: str) -> Index:
124123

125124

126125
class FaissIndex:
127-
"""FAISS index."""
126+
"""Shard-based FAISS index."""
128127

129128
shards: list[Index]
130129

130+
faiss_cfg: FaissConfig
131+
131132
def __init__(self, path: str, faiss_cfg: FaissConfig, device: str, unit_norm: bool):
132133
try:
133134
import faiss
@@ -145,96 +146,137 @@ def __init__(self, path: str, faiss_cfg: FaissConfig, device: str, unit_norm: bo
145146
f"{'_unit_norm' if unit_norm else ''}"
146147
)
147148
)
149+
faiss_path.mkdir(exist_ok=True, parents=True)
148150

149-
if not (faiss_path.exists() and any(faiss_path.iterdir())):
151+
if not any(faiss_path.iterdir()):
150152
print("Building FAISS index...")
151-
start = time()
153+
start = perf_counter()
154+
155+
root_path = Path(path)
156+
if (root_path / "info.json").exists():
157+
info_paths = [root_path / "info.json"]
158+
else:
159+
info_paths = [
160+
shard_path / "info.json"
161+
for shard_path in sorted(root_path.iterdir())
162+
if shard_path.is_dir() and (shard_path / "info.json").exists()
163+
]
164+
165+
if not info_paths:
166+
raise FileNotFoundError(f"No gradient metadata found under {path}")
167+
168+
total_grads = sum(
169+
[json.load(open(info_path))["num_grads"] for info_path in info_paths]
170+
)
152171

153-
faiss_path.mkdir(exist_ok=True, parents=True)
172+
assert faiss_cfg.num_shards <= total_grads and faiss_cfg.num_shards > 0
154173

155-
num_dataset_shards = len(list(Path(path).iterdir()))
156-
shards_per_index = math.ceil(num_dataset_shards / faiss_cfg.num_shards)
174+
# Set the number of grads for each faiss index shard
175+
base_shard_size = total_grads // faiss_cfg.num_shards
176+
remainder = total_grads % faiss_cfg.num_shards
177+
shard_sizes = [base_shard_size] * (faiss_cfg.num_shards)
178+
shard_sizes[-1] += remainder
179+
180+
# Verify all gradients will be consumed
181+
assert (
182+
sum(shard_sizes) == total_grads
183+
), f"Shard sizes {shard_sizes} don't sum to total_grads {total_grads}"
157184

158185
dl = gradients_loader(path)
159-
buffer = []
160-
index_idx = 0
186+
buffer: list[NDArray] = []
187+
buffer_size = 0
188+
shard_idx = 0
161189

162-
for grads in tqdm(dl, desc="Loading gradients"):
163-
grads = structured_to_unstructured(grads)
190+
def build_shard_from_buffer(
191+
buffer_parts: list[NDArray], shard_idx: int
192+
) -> None:
193+
print(f"Building shard {shard_idx}...")
194+
grads_chunk = np.concatenate(buffer_parts, axis=0)
195+
buffer_parts.clear()
164196

165-
if unit_norm:
166-
grads = normalize_grads(grads, device, faiss_cfg.batch_size)
197+
index = faiss.index_factory(
198+
grads_chunk.shape[1],
199+
faiss_cfg.index_factory,
200+
faiss.METRIC_INNER_PRODUCT,
201+
)
202+
index = index_to_device(index, device)
203+
if faiss_cfg.max_train_examples is not None:
204+
train_examples = min(
205+
faiss_cfg.max_train_examples, grads_chunk.shape[0]
206+
)
207+
else:
208+
train_examples = grads_chunk.shape[0]
209+
index.train(grads_chunk[:train_examples])
210+
index.add(grads_chunk)
167211

168-
buffer.append(grads)
212+
del grads_chunk
169213

170-
if len(buffer) == shards_per_index:
171-
# Build index shard
172-
print(f"Building shard {index_idx}...")
214+
index = index_to_device(index, "cpu")
215+
faiss.write_index(index, str(faiss_path / f"{shard_idx}.faiss"))
173216

174-
grads = np.concatenate(buffer, axis=0)
175-
buffer = []
217+
for grads in tqdm(dl, desc="Loading gradients"):
218+
grads = structured_to_unstructured(grads)
176219

177-
index = faiss.index_factory(
178-
grads.shape[1],
179-
faiss_cfg.index_factory,
180-
faiss.METRIC_INNER_PRODUCT,
181-
)
182-
index = index_to_device(index, device)
183-
train_examples = faiss_cfg.max_train_examples or grads.shape[0]
184-
index.train(grads[:train_examples])
185-
index.add(grads)
220+
if unit_norm:
221+
grads = normalize_grads(grads, device, faiss_cfg.batch_size)
186222

187-
# Write index to disk
188-
del grads
189-
index = index_to_device(index, "cpu")
190-
faiss.write_index(index, str(faiss_path / f"{index_idx}.faiss"))
223+
batch_idx = 0
224+
batch_size = grads.shape[0]
225+
while batch_idx < batch_size and shard_idx < faiss_cfg.num_shards:
226+
remaining_in_shard = shard_sizes[shard_idx] - buffer_size
227+
take = min(remaining_in_shard, batch_size - batch_idx)
191228

192-
index_idx += 1
229+
if take > 0:
230+
buffer.append(grads[batch_idx : batch_idx + take])
231+
buffer_size += take
232+
batch_idx += take
193233

194-
if buffer:
195-
grads = np.concatenate(buffer, axis=0)
196-
buffer = []
197-
index = faiss.index_factory(
198-
grads.shape[1], faiss_cfg.index_factory, faiss.METRIC_INNER_PRODUCT
199-
)
200-
index = index_to_device(index, device)
201-
index.train(grads)
202-
index.add(grads)
234+
if buffer_size == shard_sizes[shard_idx]:
235+
build_shard_from_buffer(buffer, shard_idx)
236+
buffer = []
237+
buffer_size = 0
238+
shard_idx += 1
203239

204-
# Write index to disk
205240
del grads
206-
index = index_to_device(index, "cpu")
207-
faiss.write_index(index, str(faiss_path / f"{index_idx}.faiss"))
208241

209-
print(f"Built index in {(time() - start) / 60:.2f} minutes.")
210-
del buffer, index
242+
assert shard_idx == faiss_cfg.num_shards
243+
print(f"Built index in {(perf_counter() - start) / 60:.2f} minutes.")
244+
245+
shard_paths = sorted(
246+
(c for c in faiss_path.glob("*.faiss") if c.stem.isdigit()),
247+
key=lambda p: int(p.stem),
248+
)
211249

212250
shards = []
213-
for i in range(faiss_cfg.num_shards):
251+
for shard_path in shard_paths:
214252
shard = faiss.read_index(
215-
str(faiss_path / f"{i}.faiss"),
253+
str(shard_path),
216254
faiss.IO_FLAG_MMAP | faiss.IO_FLAG_READ_ONLY,
217255
)
218256
if not faiss_cfg.mmap_index:
219257
shard = index_to_device(shard, device)
220258

221259
shards.append(shard)
222260

261+
if len(shards) != faiss_cfg.num_shards:
262+
faiss_cfg.num_shards = len(shards)
263+
223264
self.shards = shards
224265

225-
def search(self, q: NDArray, k: int) -> tuple[NDArray, NDArray]:
266+
def search(self, q: NDArray, k: int | None) -> tuple[NDArray, NDArray]:
226267
"""Note: if fewer than `k` examples are found FAISS will return items
227-
with the index -1 and the maximum negative distance."""
268+
with the index -1 and the maximum negative distance. If `k` is `None`,
269+
all examples will be returned."""
228270
shard_distances = []
229271
shard_indices = []
230272
offset = 0
231273

232-
for index in self.shards:
233-
index.nprobe = self.faiss_cfg.nprobe
234-
distances, indices = index.search(q, k)
274+
for shard in self.shards:
275+
shard.nprobe = self.faiss_cfg.nprobe
276+
distances, indices = shard.search(q, k or shard.ntotal)
235277

236278
indices += offset
237-
offset += index.ntotal
279+
offset += shard.ntotal
238280

239281
shard_distances.append(distances)
240282
shard_indices.append(indices)
@@ -244,7 +286,7 @@ def search(self, q: NDArray, k: int) -> tuple[NDArray, NDArray]:
244286

245287
# Rerank results overfetched from multiple shards
246288
if len(self.shards) > 1:
247-
topk_indices = np.argsort(distances, axis=1)[:, :k]
289+
topk_indices = np.argsort(distances, axis=1)[:, : k or self.ntotal]
248290
indices = indices[np.arange(indices.shape[0])[:, None], topk_indices]
249291
distances = distances[np.arange(distances.shape[0])[:, None], topk_indices]
250292

bergson/huggingface.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ def __init__(
3535
accumulate_grads: bool = False,
3636
use_optimizer_state: bool = True,
3737
track_order: bool = False,
38+
shard_size: int | None = 200_000,
3839
):
3940
"""
4041
Args:
@@ -50,6 +51,8 @@ def __init__(
5051
normalize the gradients. If `False`, no normalization is
5152
applied.
5253
track_order: Whether to record the shuffled order of training data.
54+
head_cfgs: Information used to split matrix-valued parameters into
55+
per-head matrices before down projection.
5356
"""
5457
super().__init__()
5558

0 commit comments

Comments
 (0)