Skip to content

Commit 7df9e5e

Browse files
committed
Add a type alias Batches for readability
1 parent c7cb117 commit 7df9e5e

File tree

1 file changed

+9
-7
lines changed

1 file changed

+9
-7
lines changed

tests/ekfac_tests/compute_ekfac_ground_truth.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -37,12 +37,14 @@
3737
from bergson.hessians.utils import TensorDict
3838
from bergson.utils import assert_type
3939

+ Batches = list[list[list[int]]]
+
4042
# %% [markdown]
4143
# ## -1. Helper functions
4244

4345

4446
# %%
- def allocate_batches_test(doc_lengths: list[int], N: int, workers: Optional[int] = None) -> list[list[list[int]]]:
+ def allocate_batches_test(doc_lengths: list[int], N: int, workers: Optional[int] = None) -> Batches:
4648
"""
4749
Modification of allocate_batches to return a flat list of batches for testing.
4850
@@ -122,7 +124,7 @@ def allocate_batches_test(doc_lengths: list[int], N: int, workers: Optional[int]
122124
assert all(max(doc_lengths[i] for i in batch) * len(batch) <= N for batch in batches)
123125

124126
# Round-robin assignment to workers
- allocation: list[list[list[int]]] = [[] for _ in range(world_size)]
+ allocation: Batches = [[] for _ in range(world_size)]
126128
for b_idx, batch in enumerate(batches):
127129
allocation[b_idx % world_size].append(batch)
128130

@@ -298,7 +300,7 @@ def load_dataset_step(cfg: IndexConfig) -> Dataset:
298300
# %%
299301
def tokenize_and_allocate_step(
300302
ds: Dataset, cfg: IndexConfig, workers: int
- ) -> tuple[Dataset, list[list[list[int]]], Any]:
+ ) -> tuple[Dataset, Batches, Any]:
302304
"""Tokenize dataset and allocate batches."""
303305
tokenizer = AutoTokenizer.from_pretrained(cfg.model, model_max_length=cfg.token_batch_size)
304306
ds = ds.map(tokenize, batched=True, fn_kwargs=dict(args=cfg.data, tokenizer=tokenizer))
@@ -324,7 +326,7 @@ def compute_covariance(
324326
rank: int,
325327
model: PreTrainedModel,
326328
data: Dataset,
- batches_world: list[list[list[int]]],
+ batches_world: Batches,
328330
device: torch.device,
329331
target_modules: Any,
330332
activation_covariances: dict[str, Tensor],
@@ -372,7 +374,7 @@ def compute_covariance(
372374
def compute_covariances_step(
373375
model: PreTrainedModel,
374376
data: Dataset,
- batches_world: list[list[list[int]]],
+ batches_world: Batches,
376378
device: torch.device,
377379
target_modules: Any,
378380
workers: int,
@@ -522,7 +524,7 @@ def compute_eigenvalue_correction_amortized(
522524
rank: int,
523525
model: PreTrainedModel,
524526
data: Dataset,
- batches_world: list[list[list[int]]],
+ batches_world: Batches,
526528
device: torch.device,
527529
target_modules: Any,
528530
eigenvalue_corrections: dict[str, Tensor],
@@ -571,7 +573,7 @@ def compute_eigenvalue_correction_amortized(
571573
def compute_eigenvalue_corrections_step(
572574
model: PreTrainedModel,
573575
data: Dataset,
- batches_world: list[list[list[int]]],
+ batches_world: Batches,
575577
device: torch.device,
576578
target_modules: Any,
577579
workers: int,

0 commit comments

Comments (0)