11import json
2- import math
32import os
43from dataclasses import dataclass
54from pathlib import Path
6- from time import time
5+ from time import perf_counter
76from typing import Protocol
87
98import numpy as np
@@ -124,10 +123,12 @@ def index_to_device(index: Index, device: str) -> Index:
124123
125124
126125class FaissIndex :
127- """FAISS index."""
126+ """Shard-based FAISS index."""
128127
129128 shards : list [Index ]
130129
130+ faiss_cfg : FaissConfig
131+
131132 def __init__ (self , path : str , faiss_cfg : FaissConfig , device : str , unit_norm : bool ):
132133 try :
133134 import faiss
@@ -145,96 +146,137 @@ def __init__(self, path: str, faiss_cfg: FaissConfig, device: str, unit_norm: bo
145146 f"{ '_unit_norm' if unit_norm else '' } "
146147 )
147148 )
149+ faiss_path .mkdir (exist_ok = True , parents = True )
148150
149- if not ( faiss_path . exists () and any (faiss_path .iterdir () )):
151+ if not any (faiss_path .iterdir ()):
150152 print ("Building FAISS index..." )
151- start = time ()
153+ start = perf_counter ()
154+
155+ root_path = Path (path )
156+ if (root_path / "info.json" ).exists ():
157+ info_paths = [root_path / "info.json" ]
158+ else :
159+ info_paths = [
160+ shard_path / "info.json"
161+ for shard_path in sorted (root_path .iterdir ())
162+ if shard_path .is_dir () and (shard_path / "info.json" ).exists ()
163+ ]
164+
165+ if not info_paths :
166+ raise FileNotFoundError (f"No gradient metadata found under { path } " )
167+
168+ total_grads = sum (
169+ [json .load (open (info_path ))["num_grads" ] for info_path in info_paths ]
170+ )
152171
153- faiss_path . mkdir ( exist_ok = True , parents = True )
172+ assert faiss_cfg . num_shards <= total_grads and faiss_cfg . num_shards > 0
154173
155- num_dataset_shards = len (list (Path (path ).iterdir ()))
156- shards_per_index = math .ceil (num_dataset_shards / faiss_cfg .num_shards )
174+ # Set the number of grads for each faiss index shard
175+ base_shard_size = total_grads // faiss_cfg .num_shards
176+ remainder = total_grads % faiss_cfg .num_shards
177+ shard_sizes = [base_shard_size ] * (faiss_cfg .num_shards )
178+ shard_sizes [- 1 ] += remainder
179+
180+ # Verify all gradients will be consumed
181+ assert (
182+ sum (shard_sizes ) == total_grads
183+ ), f"Shard sizes { shard_sizes } don't sum to total_grads { total_grads } "
157184
158185 dl = gradients_loader (path )
159- buffer = []
160- index_idx = 0
186+ buffer : list [NDArray ] = []
187+ buffer_size = 0
188+ shard_idx = 0
161189
162- for grads in tqdm (dl , desc = "Loading gradients" ):
163- grads = structured_to_unstructured (grads )
190+ def build_shard_from_buffer (
191+ buffer_parts : list [NDArray ], shard_idx : int
192+ ) -> None :
193+ print (f"Building shard { shard_idx } ..." )
194+ grads_chunk = np .concatenate (buffer_parts , axis = 0 )
195+ buffer_parts .clear ()
164196
165- if unit_norm :
166- grads = normalize_grads (grads , device , faiss_cfg .batch_size )
197+ index = faiss .index_factory (
198+ grads_chunk .shape [1 ],
199+ faiss_cfg .index_factory ,
200+ faiss .METRIC_INNER_PRODUCT ,
201+ )
202+ index = index_to_device (index , device )
203+ if faiss_cfg .max_train_examples is not None :
204+ train_examples = min (
205+ faiss_cfg .max_train_examples , grads_chunk .shape [0 ]
206+ )
207+ else :
208+ train_examples = grads_chunk .shape [0 ]
209+ index .train (grads_chunk [:train_examples ])
210+ index .add (grads_chunk )
167211
168- buffer . append ( grads )
212+ del grads_chunk
169213
170- if len (buffer ) == shards_per_index :
171- # Build index shard
172- print (f"Building shard { index_idx } ..." )
214+ index = index_to_device (index , "cpu" )
215+ faiss .write_index (index , str (faiss_path / f"{ shard_idx } .faiss" ))
173216
174- grads = np . concatenate ( buffer , axis = 0 )
175- buffer = []
217+ for grads in tqdm ( dl , desc = "Loading gradients" ):
218+ grads = structured_to_unstructured ( grads )
176219
177- index = faiss .index_factory (
178- grads .shape [1 ],
179- faiss_cfg .index_factory ,
180- faiss .METRIC_INNER_PRODUCT ,
181- )
182- index = index_to_device (index , device )
183- train_examples = faiss_cfg .max_train_examples or grads .shape [0 ]
184- index .train (grads [:train_examples ])
185- index .add (grads )
220+ if unit_norm :
221+ grads = normalize_grads (grads , device , faiss_cfg .batch_size )
186222
187- # Write index to disk
188- del grads
189- index = index_to_device (index , "cpu" )
190- faiss .write_index (index , str (faiss_path / f"{ index_idx } .faiss" ))
223+ batch_idx = 0
224+ batch_size = grads .shape [0 ]
225+ while batch_idx < batch_size and shard_idx < faiss_cfg .num_shards :
226+ remaining_in_shard = shard_sizes [shard_idx ] - buffer_size
227+ take = min (remaining_in_shard , batch_size - batch_idx )
191228
192- index_idx += 1
229+ if take > 0 :
230+ buffer .append (grads [batch_idx : batch_idx + take ])
231+ buffer_size += take
232+ batch_idx += take
193233
194- if buffer :
195- grads = np .concatenate (buffer , axis = 0 )
196- buffer = []
197- index = faiss .index_factory (
198- grads .shape [1 ], faiss_cfg .index_factory , faiss .METRIC_INNER_PRODUCT
199- )
200- index = index_to_device (index , device )
201- index .train (grads )
202- index .add (grads )
234+ if buffer_size == shard_sizes [shard_idx ]:
235+ build_shard_from_buffer (buffer , shard_idx )
236+ buffer = []
237+ buffer_size = 0
238+ shard_idx += 1
203239
204- # Write index to disk
205240 del grads
206- index = index_to_device (index , "cpu" )
207- faiss .write_index (index , str (faiss_path / f"{ index_idx } .faiss" ))
208241
209- print (f"Built index in { (time () - start ) / 60 :.2f} minutes." )
210- del buffer , index
242+ assert shard_idx == faiss_cfg .num_shards
243+ print (f"Built index in { (perf_counter () - start ) / 60 :.2f} minutes." )
244+
245+ shard_paths = sorted (
246+ (c for c in faiss_path .glob ("*.faiss" ) if c .stem .isdigit ()),
247+ key = lambda p : int (p .stem ),
248+ )
211249
212250 shards = []
213- for i in range ( faiss_cfg . num_shards ) :
251+ for shard_path in shard_paths :
214252 shard = faiss .read_index (
215- str (faiss_path / f" { i } .faiss" ),
253+ str (shard_path ),
216254 faiss .IO_FLAG_MMAP | faiss .IO_FLAG_READ_ONLY ,
217255 )
218256 if not faiss_cfg .mmap_index :
219257 shard = index_to_device (shard , device )
220258
221259 shards .append (shard )
222260
261+ if len (shards ) != faiss_cfg .num_shards :
262+ faiss_cfg .num_shards = len (shards )
263+
223264 self .shards = shards
224265
225- def search (self , q : NDArray , k : int ) -> tuple [NDArray , NDArray ]:
266+ def search (self , q : NDArray , k : int | None ) -> tuple [NDArray , NDArray ]:
226267 """Note: if fewer than `k` examples are found FAISS will return items
227- with the index -1 and the maximum negative distance."""
268+ with the index -1 and the maximum negative distance. If `k` is `None`,
269+ all examples will be returned."""
228270 shard_distances = []
229271 shard_indices = []
230272 offset = 0
231273
232- for index in self .shards :
233- index .nprobe = self .faiss_cfg .nprobe
234- distances , indices = index .search (q , k )
274+ for shard in self .shards :
275+ shard .nprobe = self .faiss_cfg .nprobe
276+ distances , indices = shard .search (q , k or shard . ntotal )
235277
236278 indices += offset
237- offset += index .ntotal
279+ offset += shard .ntotal
238280
239281 shard_distances .append (distances )
240282 shard_indices .append (indices )
@@ -244,7 +286,7 @@ def search(self, q: NDArray, k: int) -> tuple[NDArray, NDArray]:
244286
245287 # Rerank results overfetched from multiple shards
246288 if len (self .shards ) > 1 :
247- topk_indices = np .argsort (distances , axis = 1 )[:, :k ]
289+ topk_indices = np .argsort (distances , axis = 1 )[:, : k or self . ntotal ]
248290 indices = indices [np .arange (indices .shape [0 ])[:, None ], topk_indices ]
249291 distances = distances [np .arange (distances .shape [0 ])[:, None ], topk_indices ]
250292
0 commit comments