1+ # %%
2+ # %load_ext autoreload
3+ # %autoreload 2
4+
5+ # %%
16"""Compute EKFAC ground truth for testing.
27
38This script computes ground truth covariance matrices, eigenvectors, and eigenvalue
611"""
712
813import argparse
14+ import builtins
915import gc
1016import json
1117import os
18+ import sys
1219from dataclasses import asdict
13- from typing import Any , Optional
20+ from typing import TYPE_CHECKING , Any , Optional
1421
1522import torch
1623import torch .distributed as dist
2633from tqdm import tqdm
2734from transformers import AutoModelForCausalLM , AutoTokenizer , BitsAndBytesConfig , PreTrainedModel
2835
29- from bergson .data import DataConfig , IndexConfig , pad_and_tensor , tokenize
36+ from bergson .data import DataConfig , IndexConfig , Precision , pad_and_tensor , tokenize
3037from bergson .hessians .utils import TensorDict
3138from bergson .utils import assert_type
3239
40+ # %% [markdown]
41+ # ## -1. Helper functions
3342
43+
44+ # %%
3445def allocate_batches_test (doc_lengths : list [int ], N : int , workers : Optional [int ] = None ) -> list [list [list [int ]]]:
3546 """
3647 Modification of allocate_batches to return a flat list of batches for testing.
@@ -119,6 +130,7 @@ def allocate_batches_test(doc_lengths: list[int], N: int, workers: Optional[int]
119130 return allocation
120131
121132
133+ # %%
122134def compute_covariance (
123135 rank : int ,
124136 model : PreTrainedModel ,
@@ -217,41 +229,67 @@ def compute_eigenvalue_correction_amortized(
217229 return {"total_processed_rank" : total_processed }
218230
219231
220- def main ():
221- """Main function to compute EKFAC ground truth."""
222- parser = argparse .ArgumentParser (description = "Compute EKFAC ground truth for testing" )
223- parser .add_argument (
224- "--precision" ,
225- type = str ,
226- default = "fp32" ,
227- choices = ["fp32" , "fp16" , "bf16" , "int4" , "int8" ],
228- help = "Model precision (default: fp32)" ,
229- )
230- parser .add_argument (
231- "-o" ,
232- "--output-dir" ,
233- type = str ,
234- default = None ,
235- help = "Output directory for ground truth results (default: test_files/pile_100_examples/ground_truth)" ,
236- )
237- args = parser .parse_args ()
232+ # %% [markdown]
233+ # ## 0. Hyperparameters
234+
235+
236+ # %%
237+ def parse_config () -> tuple [Precision , Optional [str ]]:
238+ """Parse command-line arguments or return defaults."""
239+ precision : Precision
240+ output_dir : Optional [str ]
241+
242+ if len (sys .argv ) > 1 and not hasattr (builtins , "__IPYTHON__" ):
243+ parser = argparse .ArgumentParser (description = "Compute EKFAC ground truth for testing" )
244+ parser .add_argument (
245+ "--precision" ,
246+ type = str ,
247+ default = "fp32" ,
248+ choices = ["fp32" , "fp16" , "bf16" , "int4" , "int8" ],
249+ help = "Model precision (default: fp32)" ,
250+ )
251+ parser .add_argument (
252+ "-o" ,
253+ "--output-dir" ,
254+ type = str ,
255+ default = None ,
256+ help = "Output directory for ground truth results (default: test_files/pile_100_examples/ground_truth)" ,
257+ )
258+ args = parser .parse_args ()
259+ precision = args .precision
260+ output_dir = args .output_dir
261+ else :
262+ # Defaults for interactive execution or running without arguments
263+ precision = "fp32"
264+ output_dir = None
238265
239266 # Set random seeds for reproducibility
240267 set_all_seeds (42 )
241268
242- # Setup paths
269+ return precision , output_dir
270+
271+
272+ if __name__ == "__main__" or TYPE_CHECKING :
273+ precision , output_dir = parse_config ()
274+
275+
276+ # %%
277+ def setup_paths_and_config (
278+ precision : Precision , output_dir : Optional [str ] = None
279+ ) -> tuple [IndexConfig , str , int , torch .device , Any , torch .dtype ]:
280+ """Setup paths and configuration object."""
243281 current_path = os .getcwd ()
244282 parent_path = os .path .join (current_path , "test_files" , "pile_100_examples" )
245- if args . output_dir is not None :
246- test_path = args . output_dir
283+ if output_dir is not None :
284+ test_path = output_dir
247285 else :
248286 test_path = os .path .join (parent_path , "ground_truth" )
249287 os .makedirs (test_path , exist_ok = True )
250288
251289 # Configuration
252290 cfg = IndexConfig (run_path = "" )
253291 cfg .model = "EleutherAI/Pythia-14m"
254- cfg .precision = args . precision
292+ cfg .precision = precision
255293 cfg .fsdp = False
256294 cfg .data = DataConfig (dataset = os .path .join (parent_path , "data" ))
257295
@@ -288,7 +326,20 @@ def main():
288326 case other :
289327 raise ValueError (f"Unsupported precision: { other } " )
290328
291- # Load model
329+ return cfg , test_path , workers , device , target_modules , dtype
330+
331+
332+ if __name__ == "__main__" or TYPE_CHECKING :
333+ cfg , test_path , workers , device , target_modules , dtype = setup_paths_and_config (precision , output_dir )
334+
335+
336+ # %% [markdown]
337+ # ## 1. Loading model and data
338+
339+
340+ # %%
341+ def load_model_step (cfg : IndexConfig , dtype : torch .dtype ) -> PreTrainedModel :
342+ """Load the model."""
292343 print (f"Loading model { cfg .model } ..." )
293344 model = AutoModelForCausalLM .from_pretrained (
294345 cfg .model ,
@@ -307,9 +358,19 @@ def main():
307358 ),
308359 torch_dtype = dtype ,
309360 )
361+ return model
362+
310363
311- # Load dataset
364+ if __name__ == "__main__" or TYPE_CHECKING :
365+ model = load_model_step (cfg , dtype )
366+
367+
368+ # %%
369+ def load_dataset_step (cfg : IndexConfig ) -> Dataset :
370+ """Load and return the dataset."""
371+ data_str = cfg .data .dataset
312372 print (f"Loading dataset from { data_str } ..." )
373+
313374 if data_str .endswith (".csv" ):
314375 ds = assert_type (Dataset , Dataset .from_csv (data_str ))
315376 elif data_str .endswith (".json" ) or data_str .endswith (".jsonl" ):
@@ -326,7 +387,18 @@ def main():
326387 raise e
327388
328389 assert isinstance (ds , Dataset )
390+ return ds
391+
329392
393+ if __name__ == "__main__" or TYPE_CHECKING :
394+ ds = load_dataset_step (cfg )
395+
396+
397+ # %%
398+ def tokenize_and_allocate_step (
399+ ds : Dataset , cfg : IndexConfig , workers : int
400+ ) -> tuple [Dataset , list [list [list [int ]]], Any ]:
401+ """Tokenize dataset and allocate batches."""
330402 tokenizer = AutoTokenizer .from_pretrained (cfg .model , model_max_length = cfg .token_batch_size )
331403 ds = ds .map (tokenize , batched = True , fn_kwargs = dict (args = cfg .data , tokenizer = tokenizer ))
332404 data = ds
@@ -335,10 +407,30 @@ def main():
335407 batches_world = allocate_batches_test (doc_lengths = ds ["length" ], N = cfg .token_batch_size , workers = workers )
336408 assert len (batches_world ) == workers
337409
338- print ("\n === Computing Covariances ===" )
410+ return data , batches_world , tokenizer
411+
412+
413+ if __name__ == "__main__" or TYPE_CHECKING :
414+ data , batches_world , tokenizer = tokenize_and_allocate_step (ds , cfg , workers )
415+
416+
417+ # %% [markdown]
418+ # ## 2. Compute activation and gradient covariance
419+
420+
421+ # %%
422+ def compute_covariances_step (
423+ model : PreTrainedModel ,
424+ data : Dataset ,
425+ batches_world : list [list [list [int ]]],
426+ device : torch .device ,
427+ target_modules : Any ,
428+ workers : int ,
429+ test_path : str ,
430+ ) -> str :
431+ """Compute covariances for all ranks and save to disk."""
339432 covariance_test_path = os .path .join (test_path , "covariances" )
340433
341- total_processed_global = 0
342434 for rank in range (workers ):
343435 covariance_test_path_rank = os .path .join (covariance_test_path , f"rank_{ rank } " )
344436 os .makedirs (covariance_test_path_rank , exist_ok = True )
@@ -362,7 +454,19 @@ def main():
362454 json .dump ({"total_processed_rank" : d ["total_processed_rank" ]}, f , indent = 4 )
363455 print (f"Rank { rank } processed { d ['total_processed_rank' ]} tokens." )
364456
365- # Combine results from all ranks
457+ return covariance_test_path
458+
459+
460+ if __name__ == "__main__" or TYPE_CHECKING :
461+ print ("\n === Computing Covariances ===" )
462+ covariance_test_path = compute_covariances_step (
463+ model , data , batches_world , device , target_modules , workers , test_path
464+ )
465+
466+
467+ # %%
468+ def combine_covariances_step (covariance_test_path : str , workers : int , device : torch .device ) -> int :
469+ """Combine covariance results from all ranks."""
366470 activation_covariances = TensorDict ({})
367471 gradient_covariances = TensorDict ({})
368472 total_processed_global = 0
@@ -401,7 +505,22 @@ def main():
401505 gc .collect ()
402506 torch .cuda .empty_cache ()
403507
404- print ("\n === Computing Eigenvectors ===" )
508+ return total_processed_global
509+
510+
511+ if __name__ == "__main__" or TYPE_CHECKING :
512+ print ("\n === Combining Covariances ===" )
513+ total_processed_global = combine_covariances_step (covariance_test_path , workers , device )
514+
515+
516+ # %% [markdown]
517+ # ## 3. Compute eigenvalues and eigenvectors
518+
519+
520+ # %%
521+ def compute_eigenvectors_step (test_path : str , device : torch .device , dtype : torch .dtype ) -> str :
522+ """Compute eigenvectors from covariances."""
523+ covariance_test_path = os .path .join (test_path , "covariances" )
405524 eigenvectors_test_path = os .path .join (test_path , "eigenvectors" )
406525 os .makedirs (eigenvectors_test_path , exist_ok = True )
407526
@@ -436,7 +555,30 @@ def main():
436555 gc .collect ()
437556 torch .cuda .empty_cache ()
438557
439- print ("\n === Computing Eigenvalue Corrections ===" )
558+ return eigenvectors_test_path
559+
560+
561+ if __name__ == "__main__" or TYPE_CHECKING :
562+ print ("\n === Computing Eigenvectors ===" )
563+ eigenvectors_test_path = compute_eigenvectors_step (test_path , device , dtype )
564+
565+
566+ # %% [markdown]
567+ # ## 4. Compute eigenvalue correction
568+
569+
570+ # %%
571+ def compute_eigenvalue_corrections_step (
572+ model : PreTrainedModel ,
573+ data : Dataset ,
574+ batches_world : list [list [list [int ]]],
575+ device : torch .device ,
576+ target_modules : Any ,
577+ workers : int ,
578+ test_path : str ,
579+ ) -> tuple [str , int ]:
580+ """Compute eigenvalue corrections for all ranks."""
581+ eigenvectors_test_path = os .path .join (test_path , "eigenvectors" )
440582 eigenvalue_correction_test_path = os .path .join (test_path , "eigenvalue_corrections" )
441583 os .makedirs (eigenvalue_correction_test_path , exist_ok = True )
442584
@@ -471,7 +613,21 @@ def main():
471613 print (f"Rank { rank } processed { d ['total_processed_rank' ]} tokens." )
472614 total_processed_global += d ["total_processed_rank" ]
473615
474- # Combine results from all ranks
616+ return eigenvalue_correction_test_path , total_processed_global
617+
618+
619+ if __name__ == "__main__" or TYPE_CHECKING :
620+ print ("\n === Computing Eigenvalue Corrections ===" )
621+ eigenvalue_correction_test_path , total_processed_global_lambda = compute_eigenvalue_corrections_step (
622+ model , data , batches_world , device , target_modules , workers , test_path
623+ )
624+
625+
626+ # %%
627+ def combine_eigenvalue_corrections_step (
628+ eigenvalue_correction_test_path : str , workers : int , device : torch .device , total_processed_global : int
629+ ) -> TensorDict :
630+ """Combine eigenvalue correction results from all ranks."""
475631 eigenvalue_corrections = TensorDict ({})
476632
477633 for rank in range (workers ):
@@ -492,9 +648,12 @@ def main():
492648 os .path .join (eigenvalue_correction_test_path , "eigenvalue_corrections.safetensors" ),
493649 )
494650
495- print ("\n === Ground Truth Computation Complete ===" )
496- print (f"Results saved to: { test_path } " )
651+ return eigenvalue_corrections
497652
498653
499- if __name__ == "__main__" :
500- main ()
654+ if __name__ == "__main__" or TYPE_CHECKING :
655+ eigenvalue_corrections = combine_eigenvalue_corrections_step (
656+ eigenvalue_correction_test_path , workers , device , total_processed_global_lambda
657+ )
658+ print ("\n === Ground Truth Computation Complete ===" )
659+ print (f"Results saved to: { test_path } " )
0 commit comments