3737from bergson .hessians .utils import TensorDict
3838from bergson .utils import assert_type
3939
40+ Batches = list [list [list [int ]]]
41+
4042# %% [markdown]
4143# ## -1. Helper functions
4244
4345
4446# %%
45- def allocate_batches_test (doc_lengths : list [int ], N : int , workers : Optional [int ] = None ) -> list [ list [ list [ int ]]] :
47+ def allocate_batches_test (doc_lengths : list [int ], N : int , workers : Optional [int ] = None ) -> Batches :
4648 """
4749 Modification of allocate_batches to return a flat list of batches for testing.
4850
@@ -122,7 +124,7 @@ def allocate_batches_test(doc_lengths: list[int], N: int, workers: Optional[int]
122124 assert all (max (doc_lengths [i ] for i in batch ) * len (batch ) <= N for batch in batches )
123125
124126 # Round-robin assignment to workers
125- allocation : list [ list [ list [ int ]]] = [[] for _ in range (world_size )]
127+ allocation : Batches = [[] for _ in range (world_size )]
126128 for b_idx , batch in enumerate (batches ):
127129 allocation [b_idx % world_size ].append (batch )
128130
@@ -298,7 +300,7 @@ def load_dataset_step(cfg: IndexConfig) -> Dataset:
298300# %%
299301def tokenize_and_allocate_step (
300302 ds : Dataset , cfg : IndexConfig , workers : int
301- ) -> tuple [Dataset , list [ list [ list [ int ]]] , Any ]:
303+ ) -> tuple [Dataset , Batches , Any ]:
302304 """Tokenize dataset and allocate batches."""
303305 tokenizer = AutoTokenizer .from_pretrained (cfg .model , model_max_length = cfg .token_batch_size )
304306 ds = ds .map (tokenize , batched = True , fn_kwargs = dict (args = cfg .data , tokenizer = tokenizer ))
@@ -324,7 +326,7 @@ def compute_covariance(
324326 rank : int ,
325327 model : PreTrainedModel ,
326328 data : Dataset ,
327- batches_world : list [ list [ list [ int ]]] ,
329+ batches_world : Batches ,
328330 device : torch .device ,
329331 target_modules : Any ,
330332 activation_covariances : dict [str , Tensor ],
@@ -372,7 +374,7 @@ def compute_covariance(
372374def compute_covariances_step (
373375 model : PreTrainedModel ,
374376 data : Dataset ,
375- batches_world : list [ list [ list [ int ]]] ,
377+ batches_world : Batches ,
376378 device : torch .device ,
377379 target_modules : Any ,
378380 workers : int ,
@@ -522,7 +524,7 @@ def compute_eigenvalue_correction_amortized(
522524 rank : int ,
523525 model : PreTrainedModel ,
524526 data : Dataset ,
525- batches_world : list [ list [ list [ int ]]] ,
527+ batches_world : Batches ,
526528 device : torch .device ,
527529 target_modules : Any ,
528530 eigenvalue_corrections : dict [str , Tensor ],
@@ -571,7 +573,7 @@ def compute_eigenvalue_correction_amortized(
571573def compute_eigenvalue_corrections_step (
572574 model : PreTrainedModel ,
573575 data : Dataset ,
574- batches_world : list [ list [ list [ int ]]] ,
576+ batches_world : Batches ,
575577 device : torch .device ,
576578 target_modules : Any ,
577579 workers : int ,
0 commit comments