 
 from .utils import assert_type
 
+Precision = Literal["bf16", "fp16", "fp32", "int4", "int8"]
+
 
 @dataclass
 class DataConfig:
@@ -48,7 +50,7 @@ class IndexConfig:
     fsdp: bool = False
     """Whether to use Fully Sharded Data Parallel (FSDP) for collecting gradients."""
 
-    precision: Literal["bf16", "fp16", "fp32", "int4", "int8"] = "bf16"
+    precision: Precision = "bf16"
     """Precision to use for the model parameters."""
 
     projection_dim: int = 16
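Hoisting the Literal into a named Precision alias also lets the same set of values be checked at runtime; a minimal sketch, with a hypothetical validate_precision helper that is not part of this change:

from typing import Literal, get_args

# Mirrors the alias added in this diff.
Precision = Literal["bf16", "fp16", "fp32", "int4", "int8"]

def validate_precision(value: str) -> Precision:
    # Runtime check mirroring the static Literal constraint.
    if value not in get_args(Precision):
        raise ValueError(f"Unsupported precision: {value!r}")
    return value  # type: ignore[return-value]

assert validate_precision("bf16") == "bf16"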
@@ -99,7 +101,9 @@ def ceildiv(a: int, b: int) -> int:
     return -(-a // b)  # Equivalent to math.ceil(a / b) but faster for integers
 
 
-def allocate_batches(doc_lengths: list[int], N: int, world_size: Optional[int] = None) -> list[list[int]]:
+def allocate_batches(
+    doc_lengths: list[int], N: int, world_size: Optional[int] = None
+) -> list[list[int]]:
     """
     Allocate documents into batches that are then distributed evenly across
     a fixed number of workers.
@@ -183,7 +187,9 @@ def allocate_batches(doc_lengths: list[int], N: int, world_size: Optional[int] =
     while len(batches) < world_size:
         big = batches.pop(0)  # take the current largest
         if len(big) == 1:  # cannot split a singleton
-            raise RuntimeError("Not enough documents to give each worker at least one batch.")
+            raise RuntimeError(
+                "Not enough documents to give each worker at least one batch."
+            )
         batches.append([big.pop()])  # move one doc into new batch
         batches.append(big)  # put the remainder back
         # preserve cost constraint automatically
@@ -205,7 +211,9 @@ def allocate_batches(doc_lengths: list[int], N: int, world_size: Optional[int] =
         i += 1
 
     assert len(batches) == target_batches
-    assert all(max(doc_lengths[i] for i in batch) * len(batch) <= N for batch in batches)
+    assert all(
+        max(doc_lengths[i] for i in batch) * len(batch) <= N for batch in batches
+    )
 
     # ---------------------------------------------------------------------
     # 4) Round-robin assignment to workers
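The reflowed assertion encodes the batching contract: a batch's padded cost, the longest document in the batch times the batch size, must stay within the token budget N. An illustrative check with made-up inputs (not taken from the PR):

# Made-up example: three batches of document indices over five documents.
doc_lengths = [512, 128, 128, 64, 2048]
N = 2048  # per-batch budget: max length in the batch times the batch size

batches = [[4], [0, 1, 2], [3]]
assert all(
    max(doc_lengths[i] for i in batch) * len(batch) <= N for batch in batches
)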
@@ -219,7 +227,9 @@ def allocate_batches(doc_lengths: list[int], N: int, world_size: Optional[int] =
     return allocation[rank]
 
 
-def create_index(root: str, num_grads: int, grad_sizes: dict[str, int], dtype: DTypeLike) -> np.memmap:
+def create_index(
+    root: str, num_grads: int, grad_sizes: dict[str, int], dtype: DTypeLike
+) -> np.memmap:
     """Create a memory-mapped file for storing structured gradients
     and persist metadata."""
     grad_path = os.path.join(root, "gradients.bin")
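The general pattern behind the reflowed signature, one named memmap field per module sized by grad_sizes, can be sketched with NumPy alone; the module names, float32 dtype, and file name below are illustrative assumptions rather than the PR's exact metadata layout:

import numpy as np

grad_sizes = {"mlp.down_proj": 16, "self_attn.o_proj": 16}  # hypothetical modules
dtype = np.dtype([(name, np.float32, (size,)) for name, size in grad_sizes.items()])

# One row per gradient sample, one named field per module.
mmap = np.memmap("gradients.bin", dtype=dtype, mode="w+", shape=(8,))
mmap["mlp.down_proj"][0] = np.arange(16, dtype=np.float32)
mmap.flush()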
@@ -310,7 +320,9 @@ def load_shard(dir: str) -> Dataset:
     if concatenate_gradients:
         unstructured_data = structured_to_unstructured(mmap)
         flat = pa.array(unstructured_data.reshape(-1))
-        col_arrow = pa.FixedSizeListArray.from_arrays(flat, unstructured_data.shape[1])
+        col_arrow = pa.FixedSizeListArray.from_arrays(
+            flat, unstructured_data.shape[1]
+        )
 
         ds = ds.add_column("gradients", col_arrow, new_fingerprint="gradients")
     # Add a column for each module's gradient vectors
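The wrapped call converts the structured memmap into a flat matrix and packs it as an Arrow fixed-size-list column; a self-contained sketch of that conversion with dummy data:

import numpy as np
import pyarrow as pa
from numpy.lib.recfunctions import structured_to_unstructured

structured = np.zeros(4, dtype=[("a", np.float32, (3,)), ("b", np.float32, (2,))])
matrix = structured_to_unstructured(structured)  # shape (4, 5)
flat = pa.array(matrix.reshape(-1))
col = pa.FixedSizeListArray.from_arrays(flat, matrix.shape[1])  # 4 lists of size 5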
@@ -374,7 +386,9 @@ def tokenize(batch: dict, *, args: DataConfig, tokenizer):
                 {"role": "user", "content": assert_type(str, prompt)},
                 {"role": "assistant", "content": assert_type(str, resp)},
             ]
-            for prompt, resp in zip(batch[args.prompt_column], batch[args.completion_column])
+            for prompt, resp in zip(
+                batch[args.prompt_column], batch[args.completion_column]
+            )
         ]
     elif args.conversation_column:
         # We're dealing with a conversation dataset
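The comprehension being wrapped builds one two-turn conversation per (prompt, completion) pair, which a chat template can then render; a standalone sketch with placeholder column names and no real tokenizer:

batch = {"prompt": ["What is 2 + 2?"], "completion": ["4"]}

convos = [
    [
        {"role": "user", "content": prompt},
        {"role": "assistant", "content": resp},
    ]
    for prompt, resp in zip(batch["prompt"], batch["completion"])
]
# With a Hugging Face tokenizer one would typically continue with
# tokenizer.apply_chat_template(convos[0], tokenize=True); omitted here.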
@@ -421,4 +435,7 @@ def tokenize(batch: dict, *, args: DataConfig, tokenizer):
 def unflatten(x: torch.Tensor, shapes: dict[str, Sequence[int]], dim: int = -1):
     """Unflatten a tensor `x` into a dictionary of tensors with specified shapes."""
     numels = [math.prod(shape) for shape in shapes.values()]
-    return {name: x.unflatten(dim, shape) for (name, shape), x in zip(shapes.items(), x.split(numels, dim=dim))}
+    return {
+        name: x.unflatten(dim, shape)
+        for (name, shape), x in zip(shapes.items(), x.split(numels, dim=dim))
+    }
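A quick round trip showing what the reflowed dictionary comprehension does; the definition is copied from the hunk above so the snippet runs standalone, and the shapes are arbitrary examples:

import math
from typing import Sequence

import torch

def unflatten(x: torch.Tensor, shapes: dict[str, Sequence[int]], dim: int = -1):
    # Split the flat dimension into per-name chunks, then reshape each chunk.
    numels = [math.prod(shape) for shape in shapes.values()]
    return {
        name: x.unflatten(dim, shape)
        for (name, shape), x in zip(shapes.items(), x.split(numels, dim=dim))
    }

shapes = {"weight": (4, 3), "bias": (3,)}
flat = torch.arange(15.0)  # 4 * 3 + 3 = 15 elements along the last dim
out = unflatten(flat, shapes)
assert out["weight"].shape == (4, 3)
assert out["bias"].shape == (3,)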