
Commit 5e33a30

Make compute_ekfac_ground_truth usable as a notebook
Use the jupytext percent format, which VS Code can interpret directly (and which can also be converted to ipynb using `jupytext --to notebook`). We want compute_ekfac_ground_truth to be:

- Importable from other files. So it should be split into functions, and importing it shouldn't have side effects.
- Usable as a script we run directly. So it should have a main that parses input arguments and runs everything.
- Usable as a notebook. So it should be split into cells, where each cell can be executed individually and produce some output.

To regain usability as a notebook without compromising the other use cases, we split the logic that used to be in `main()` into multiple statements guarded by `if __name__ == "__main__"` at the end of the cell that defines the relevant function (since each of these guarded statements defines some variable, we actually need `or TYPE_CHECKING` to ensure they are visible to the type checker).
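As an illustration of that pattern (a minimal sketch, not part of this commit; `do_step` and `result` are made-up names), a percent-format cell looks like this:

# %% [markdown]
# ## Example step

# %%
from typing import TYPE_CHECKING


def do_step(x: int) -> int:
    """Importable from other files without side effects."""
    return x * 2


# Runs when the file is executed as a script or cell by cell in an editor,
# but not on import; `or TYPE_CHECKING` keeps `result` visible to the type checker.
if __name__ == "__main__" or TYPE_CHECKING:
    result = do_step(21)
    print(result)

Executing the file as a script and running the cells interactively then produce the same output, while importing the module stays side-effect free.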
1 parent 51a25a6 commit 5e33a30

File tree

1 file changed: +195, -36 lines

tests/ekfac_tests/compute_ekfac_ground_truth.py

Lines changed: 195 additions & 36 deletions
@@ -1,3 +1,8 @@
+# %%
+# %load_ext autoreload
+# %autoreload 2
+
+# %%
 """Compute EKFAC ground truth for testing.
 
 This script computes ground truth covariance matrices, eigenvectors, and eigenvalue
@@ -6,11 +11,13 @@
 """
 
 import argparse
+import builtins
 import gc
 import json
 import os
+import sys
 from dataclasses import asdict
-from typing import Any, Optional
+from typing import TYPE_CHECKING, Any, Optional
 
 import torch
 import torch.distributed as dist
@@ -26,11 +33,15 @@
 from tqdm import tqdm
 from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, PreTrainedModel
 
-from bergson.data import DataConfig, IndexConfig, pad_and_tensor, tokenize
+from bergson.data import DataConfig, IndexConfig, Precision, pad_and_tensor, tokenize
 from bergson.hessians.utils import TensorDict
 from bergson.utils import assert_type
 
+# %% [markdown]
+# ## -1. Helper functions
 
+
+# %%
 def allocate_batches_test(doc_lengths: list[int], N: int, workers: Optional[int] = None) -> list[list[list[int]]]:
     """
     Modification of allocate_batches to return a flat list of batches for testing.
@@ -119,6 +130,7 @@ def allocate_batches_test(doc_lengths: list[int], N: int, workers: Optional[int]
     return allocation
 
 
+# %%
 def compute_covariance(
     rank: int,
     model: PreTrainedModel,
@@ -217,41 +229,67 @@ def compute_eigenvalue_correction_amortized(
     return {"total_processed_rank": total_processed}
 
 
-def main():
-    """Main function to compute EKFAC ground truth."""
-    parser = argparse.ArgumentParser(description="Compute EKFAC ground truth for testing")
-    parser.add_argument(
-        "--precision",
-        type=str,
-        default="fp32",
-        choices=["fp32", "fp16", "bf16", "int4", "int8"],
-        help="Model precision (default: fp32)",
-    )
-    parser.add_argument(
-        "-o",
-        "--output-dir",
-        type=str,
-        default=None,
-        help="Output directory for ground truth results (default: test_files/pile_100_examples/ground_truth)",
-    )
-    args = parser.parse_args()
+# %% [markdown]
+# ## 0. Hyperparameters
+
+
+# %%
+def parse_config() -> tuple[Precision, Optional[str]]:
+    """Parse command-line arguments or return defaults."""
+    precision: Precision
+    output_dir: Optional[str]
+
+    if len(sys.argv) > 1 and not hasattr(builtins, "__IPYTHON__"):
+        parser = argparse.ArgumentParser(description="Compute EKFAC ground truth for testing")
+        parser.add_argument(
+            "--precision",
+            type=str,
+            default="fp32",
+            choices=["fp32", "fp16", "bf16", "int4", "int8"],
+            help="Model precision (default: fp32)",
+        )
+        parser.add_argument(
+            "-o",
+            "--output-dir",
+            type=str,
+            default=None,
+            help="Output directory for ground truth results (default: test_files/pile_100_examples/ground_truth)",
+        )
+        args = parser.parse_args()
+        precision = args.precision
+        output_dir = args.output_dir
+    else:
+        # Defaults for interactive execution or running without arguments
+        precision = "fp32"
+        output_dir = None
 
     # Set random seeds for reproducibility
     set_all_seeds(42)
 
-    # Setup paths
+    return precision, output_dir
+
+
+if __name__ == "__main__" or TYPE_CHECKING:
+    precision, output_dir = parse_config()
+
+
+# %%
+def setup_paths_and_config(
+    precision: Precision, output_dir: Optional[str] = None
+) -> tuple[IndexConfig, str, int, torch.device, Any, torch.dtype]:
+    """Setup paths and configuration object."""
     current_path = os.getcwd()
     parent_path = os.path.join(current_path, "test_files", "pile_100_examples")
-    if args.output_dir is not None:
-        test_path = args.output_dir
+    if output_dir is not None:
+        test_path = output_dir
     else:
         test_path = os.path.join(parent_path, "ground_truth")
     os.makedirs(test_path, exist_ok=True)
 
     # Configuration
     cfg = IndexConfig(run_path="")
     cfg.model = "EleutherAI/Pythia-14m"
-    cfg.precision = args.precision
+    cfg.precision = precision
     cfg.fsdp = False
     cfg.data = DataConfig(dataset=os.path.join(parent_path, "data"))
 
@@ -288,7 +326,20 @@ def main():
         case other:
             raise ValueError(f"Unsupported precision: {other}")
 
-    # Load model
+    return cfg, test_path, workers, device, target_modules, dtype
+
+
+if __name__ == "__main__" or TYPE_CHECKING:
+    cfg, test_path, workers, device, target_modules, dtype = setup_paths_and_config(precision, output_dir)
+
+
+# %% [markdown]
+# ## 1. Loading model and data
+
+
+# %%
+def load_model_step(cfg: IndexConfig, dtype: torch.dtype) -> PreTrainedModel:
+    """Load the model."""
     print(f"Loading model {cfg.model}...")
     model = AutoModelForCausalLM.from_pretrained(
         cfg.model,
@@ -307,9 +358,19 @@ def main():
         ),
         torch_dtype=dtype,
     )
+    return model
+
 
-    # Load dataset
+if __name__ == "__main__" or TYPE_CHECKING:
+    model = load_model_step(cfg, dtype)
+
+
+# %%
+def load_dataset_step(cfg: IndexConfig) -> Dataset:
+    """Load and return the dataset."""
+    data_str = cfg.data.dataset
     print(f"Loading dataset from {data_str}...")
+
     if data_str.endswith(".csv"):
         ds = assert_type(Dataset, Dataset.from_csv(data_str))
     elif data_str.endswith(".json") or data_str.endswith(".jsonl"):
@@ -326,7 +387,18 @@ def main():
                 raise e
 
     assert isinstance(ds, Dataset)
+    return ds
+
 
+if __name__ == "__main__" or TYPE_CHECKING:
+    ds = load_dataset_step(cfg)
+
+
+# %%
+def tokenize_and_allocate_step(
+    ds: Dataset, cfg: IndexConfig, workers: int
+) -> tuple[Dataset, list[list[list[int]]], Any]:
+    """Tokenize dataset and allocate batches."""
     tokenizer = AutoTokenizer.from_pretrained(cfg.model, model_max_length=cfg.token_batch_size)
     ds = ds.map(tokenize, batched=True, fn_kwargs=dict(args=cfg.data, tokenizer=tokenizer))
     data = ds
@@ -335,10 +407,30 @@ def main():
     batches_world = allocate_batches_test(doc_lengths=ds["length"], N=cfg.token_batch_size, workers=workers)
     assert len(batches_world) == workers
 
-    print("\n=== Computing Covariances ===")
+    return data, batches_world, tokenizer
+
+
+if __name__ == "__main__" or TYPE_CHECKING:
+    data, batches_world, tokenizer = tokenize_and_allocate_step(ds, cfg, workers)
+
+
+# %% [markdown]
+# ## 2. Compute activation and gradient covariance
+
+
+# %%
+def compute_covariances_step(
+    model: PreTrainedModel,
+    data: Dataset,
+    batches_world: list[list[list[int]]],
+    device: torch.device,
+    target_modules: Any,
+    workers: int,
+    test_path: str,
+) -> str:
+    """Compute covariances for all ranks and save to disk."""
     covariance_test_path = os.path.join(test_path, "covariances")
 
-    total_processed_global = 0
     for rank in range(workers):
         covariance_test_path_rank = os.path.join(covariance_test_path, f"rank_{rank}")
         os.makedirs(covariance_test_path_rank, exist_ok=True)
@@ -362,7 +454,19 @@ def main():
             json.dump({"total_processed_rank": d["total_processed_rank"]}, f, indent=4)
         print(f"Rank {rank} processed {d['total_processed_rank']} tokens.")
 
-    # Combine results from all ranks
+    return covariance_test_path
+
+
+if __name__ == "__main__" or TYPE_CHECKING:
+    print("\n=== Computing Covariances ===")
+    covariance_test_path = compute_covariances_step(
+        model, data, batches_world, device, target_modules, workers, test_path
+    )
+
+
+# %%
+def combine_covariances_step(covariance_test_path: str, workers: int, device: torch.device) -> int:
+    """Combine covariance results from all ranks."""
     activation_covariances = TensorDict({})
     gradient_covariances = TensorDict({})
     total_processed_global = 0
@@ -401,7 +505,22 @@ def main():
     gc.collect()
     torch.cuda.empty_cache()
 
-    print("\n=== Computing Eigenvectors ===")
+    return total_processed_global
+
+
+if __name__ == "__main__" or TYPE_CHECKING:
+    print("\n=== Combining Covariances ===")
+    total_processed_global = combine_covariances_step(covariance_test_path, workers, device)
+
+
+# %% [markdown]
+# ## 3. Compute eigenvalues and eigenvectors
+
+
+# %%
+def compute_eigenvectors_step(test_path: str, device: torch.device, dtype: torch.dtype) -> str:
+    """Compute eigenvectors from covariances."""
+    covariance_test_path = os.path.join(test_path, "covariances")
     eigenvectors_test_path = os.path.join(test_path, "eigenvectors")
     os.makedirs(eigenvectors_test_path, exist_ok=True)
 
@@ -436,7 +555,30 @@ def main():
     gc.collect()
     torch.cuda.empty_cache()
 
-    print("\n=== Computing Eigenvalue Corrections ===")
+    return eigenvectors_test_path
+
+
+if __name__ == "__main__" or TYPE_CHECKING:
+    print("\n=== Computing Eigenvectors ===")
+    eigenvectors_test_path = compute_eigenvectors_step(test_path, device, dtype)
+
+
+# %% [markdown]
+# ## 4. Compute eigenvalue correction
+
+
+# %%
+def compute_eigenvalue_corrections_step(
+    model: PreTrainedModel,
+    data: Dataset,
+    batches_world: list[list[list[int]]],
+    device: torch.device,
+    target_modules: Any,
+    workers: int,
+    test_path: str,
+) -> tuple[str, int]:
+    """Compute eigenvalue corrections for all ranks."""
+    eigenvectors_test_path = os.path.join(test_path, "eigenvectors")
     eigenvalue_correction_test_path = os.path.join(test_path, "eigenvalue_corrections")
     os.makedirs(eigenvalue_correction_test_path, exist_ok=True)
 
@@ -471,7 +613,21 @@ def main():
         print(f"Rank {rank} processed {d['total_processed_rank']} tokens.")
         total_processed_global += d["total_processed_rank"]
 
-    # Combine results from all ranks
+    return eigenvalue_correction_test_path, total_processed_global
+
+
+if __name__ == "__main__" or TYPE_CHECKING:
+    print("\n=== Computing Eigenvalue Corrections ===")
+    eigenvalue_correction_test_path, total_processed_global_lambda = compute_eigenvalue_corrections_step(
+        model, data, batches_world, device, target_modules, workers, test_path
+    )
+
+
+# %%
+def combine_eigenvalue_corrections_step(
+    eigenvalue_correction_test_path: str, workers: int, device: torch.device, total_processed_global: int
+) -> TensorDict:
+    """Combine eigenvalue correction results from all ranks."""
     eigenvalue_corrections = TensorDict({})
 
     for rank in range(workers):
@@ -492,9 +648,12 @@ def main():
         os.path.join(eigenvalue_correction_test_path, "eigenvalue_corrections.safetensors"),
     )
 
-    print("\n=== Ground Truth Computation Complete ===")
-    print(f"Results saved to: {test_path}")
+    return eigenvalue_corrections
 
 
-if __name__ == "__main__":
-    main()
+if __name__ == "__main__" or TYPE_CHECKING:
+    eigenvalue_corrections = combine_eigenvalue_corrections_step(
+        eigenvalue_correction_test_path, workers, device, total_processed_global_lambda
+    )
+    print("\n=== Ground Truth Computation Complete ===")
+    print(f"Results saved to: {test_path}")
