3131from test_utils import set_all_seeds
3232from torch import Tensor
3333from tqdm import tqdm
34- from transformers import AutoModelForCausalLM , AutoTokenizer , BitsAndBytesConfig , PreTrainedModel
34+ from transformers import (
35+ AutoModelForCausalLM ,
36+ AutoTokenizer ,
37+ BitsAndBytesConfig ,
38+ PreTrainedModel ,
39+ )
3540
3641from bergson .data import DataConfig , IndexConfig , Precision , pad_and_tensor , tokenize
3742from bergson .hessians .utils import TensorDict
4449
4550
4651# %%
47- def allocate_batches_test (doc_lengths : list [int ], N : int , workers : Optional [int ] = None ) -> Batches :
52+ def allocate_batches_test (
53+ doc_lengths : list [int ], N : int , workers : Optional [int ] = None
54+ ) -> Batches :
4855 """
4956 Modification of allocate_batches to return a flat list of batches for testing.
5057
@@ -103,7 +110,9 @@ def allocate_batches_test(doc_lengths: list[int], N: int, workers: Optional[int]
103110 while len (batches ) < world_size :
104111 big = batches .pop (0 )
105112 if len (big ) == 1 :
106- raise RuntimeError ("Not enough documents to give each worker at least one batch." )
113+ raise RuntimeError (
114+ "Not enough documents to give each worker at least one batch."
115+ )
107116 batches .append ([big .pop ()])
108117 batches .append (big )
109118
@@ -121,7 +130,9 @@ def allocate_batches_test(doc_lengths: list[int], N: int, workers: Optional[int]
121130 i += 1
122131
123132 assert len (batches ) == target_batches
124- assert all (max (doc_lengths [i ] for i in batch ) * len (batch ) <= N for batch in batches )
133+ assert all (
134+ max (doc_lengths [i ] for i in batch ) * len (batch ) <= N for batch in batches
135+ )
125136
126137 # Round-robin assignment to workers
127138 allocation : Batches = [[] for _ in range (world_size )]
@@ -143,7 +154,9 @@ def parse_config() -> tuple[Precision, Optional[str]]:
143154 output_dir : Optional [str ]
144155
145156 if len (sys .argv ) > 1 and not hasattr (builtins , "__IPYTHON__" ):
146- parser = argparse .ArgumentParser (description = "Compute EKFAC ground truth for testing" )
157+ parser = argparse .ArgumentParser (
158+ description = "Compute EKFAC ground truth for testing"
159+ )
147160 parser .add_argument (
148161 "--precision" ,
149162 type = str ,
@@ -233,7 +246,9 @@ def setup_paths_and_config(
233246
234247
235248if __name__ == "__main__" or TYPE_CHECKING :
236- cfg , test_path , workers , device , target_modules , dtype = setup_paths_and_config (precision , output_dir )
249+ cfg , test_path , workers , device , target_modules , dtype = setup_paths_and_config (
250+ precision , output_dir
251+ )
237252
238253
239254# %% [markdown]
@@ -259,7 +274,7 @@ def load_model_step(cfg: IndexConfig, dtype: torch.dtype) -> PreTrainedModel:
259274 if cfg .precision in ("int4" , "int8" )
260275 else None
261276 ),
262- torch_dtype = dtype ,
277+ dtype = dtype ,
263278 )
264279 return model
265280
@@ -282,7 +297,9 @@ def load_dataset_step(cfg: IndexConfig) -> Dataset:
282297 try :
283298 ds = load_dataset (data_str , split = "train" )
284299 if isinstance (ds , (DatasetDict , IterableDatasetDict )):
285- raise NotImplementedError ("DatasetDicts and IterableDatasetDicts are not supported." )
300+ raise NotImplementedError (
301+ "DatasetDicts and IterableDatasetDicts are not supported."
302+ )
286303 except ValueError as e :
287304 if "load_from_disk" in str (e ):
288305 ds = Dataset .load_from_disk (data_str , keep_in_memory = False )
@@ -302,12 +319,18 @@ def tokenize_and_allocate_step(
302319 ds : Dataset , cfg : IndexConfig , workers : int
303320) -> tuple [Dataset , Batches , Any ]:
304321 """Tokenize dataset and allocate batches."""
305- tokenizer = AutoTokenizer .from_pretrained (cfg .model , model_max_length = cfg .token_batch_size )
306- ds = ds .map (tokenize , batched = True , fn_kwargs = dict (args = cfg .data , tokenizer = tokenizer ))
322+ tokenizer = AutoTokenizer .from_pretrained (
323+ cfg .model , model_max_length = cfg .token_batch_size
324+ )
325+ ds = ds .map (
326+ tokenize , batched = True , fn_kwargs = dict (args = cfg .data , tokenizer = tokenizer )
327+ )
307328 data = ds
308329
309330 # Allocate batches
310- batches_world = allocate_batches_test (doc_lengths = ds ["length" ], N = cfg .token_batch_size , workers = workers )
331+ batches_world = allocate_batches_test (
332+ doc_lengths = ds ["length" ], N = cfg .token_batch_size , workers = workers
333+ )
311334 assert len (batches_world ) == workers
312335
313336 return data , batches_world , tokenizer
@@ -400,8 +423,16 @@ def compute_covariances_step(
400423 gradient_covariances = gradient_covariances ,
401424 )
402425
403- save_file (activation_covariances , os .path .join (covariance_test_path_rank , "activation_covariance.safetensors" ))
404- save_file (gradient_covariances , os .path .join (covariance_test_path_rank , "gradient_covariance.safetensors" ))
426+ save_file (
427+ activation_covariances ,
428+ os .path .join (
429+ covariance_test_path_rank , "activation_covariance.safetensors"
430+ ),
431+ )
432+ save_file (
433+ gradient_covariances ,
434+ os .path .join (covariance_test_path_rank , "gradient_covariance.safetensors" ),
435+ )
405436 with open (os .path .join (covariance_test_path_rank , "stats.json" ), "w" ) as f :
406437 json .dump ({"total_processed_rank" : d ["total_processed_rank" ]}, f , indent = 4 )
407438 print (f"Rank { rank } processed { d ['total_processed_rank' ]} tokens." )
@@ -417,7 +448,9 @@ def compute_covariances_step(
417448
418449
419450# %%
420- def combine_covariances_step (covariance_test_path : str , workers : int , device : torch .device ) -> int :
451+ def combine_covariances_step (
452+ covariance_test_path : str , workers : int , device : torch .device
453+ ) -> int :
421454 """Combine covariance results from all ranks."""
422455 activation_covariances = TensorDict ({})
423456 gradient_covariances = TensorDict ({})
@@ -431,25 +464,41 @@ def combine_covariances_step(covariance_test_path: str, workers: int, device: to
431464 total_processed_global += d ["total_processed_rank" ]
432465
433466 activation_covariances_rank = TensorDict (
434- load_file (os .path .join (covariance_test_path_rank , "activation_covariance.safetensors" ))
467+ load_file (
468+ os .path .join (
469+ covariance_test_path_rank , "activation_covariance.safetensors"
470+ )
471+ )
435472 ).to (device )
436473
437474 gradient_covariances_rank = TensorDict (
438- load_file (os .path .join (covariance_test_path_rank , "gradient_covariance.safetensors" ))
475+ load_file (
476+ os .path .join (
477+ covariance_test_path_rank , "gradient_covariance.safetensors"
478+ )
479+ )
439480 ).to (device )
440481
441482 if not activation_covariances :
442483 activation_covariances = activation_covariances_rank
443484 else :
444- activation_covariances = activation_covariances + activation_covariances_rank
485+ activation_covariances = (
486+ activation_covariances + activation_covariances_rank
487+ )
445488
446489 if not gradient_covariances :
447490 gradient_covariances = gradient_covariances_rank
448491 else :
449492 gradient_covariances = gradient_covariances + gradient_covariances_rank
450493
451- save_file (activation_covariances .to_dict (), os .path .join (covariance_test_path , "activation_covariance.safetensors" ))
452- save_file (gradient_covariances .to_dict (), os .path .join (covariance_test_path , "gradient_covariance.safetensors" ))
494+ save_file (
495+ activation_covariances .to_dict (),
496+ os .path .join (covariance_test_path , "activation_covariance.safetensors" ),
497+ )
498+ save_file (
499+ gradient_covariances .to_dict (),
500+ os .path .join (covariance_test_path , "gradient_covariance.safetensors" ),
501+ )
453502 with open (os .path .join (covariance_test_path , "stats.json" ), "w" ) as f :
454503 json .dump ({"total_processed_global" : total_processed_global }, f , indent = 4 )
455504 print (f"Global processed { total_processed_global } tokens." )
@@ -462,15 +511,19 @@ def combine_covariances_step(covariance_test_path: str, workers: int, device: to
462511
463512if __name__ == "__main__" or TYPE_CHECKING :
464513 print ("\n === Combining Covariances ===" )
465- total_processed_global = combine_covariances_step (covariance_test_path , workers , device )
514+ total_processed_global = combine_covariances_step (
515+ covariance_test_path , workers , device
516+ )
466517
467518
468519# %% [markdown]
469520# ## 3. Compute eigenvalues and eigenvectors
470521
471522
472523# %%
473- def compute_eigenvectors_step (test_path : str , device : torch .device , dtype : torch .dtype ) -> str :
524+ def compute_eigenvectors_step (
525+ test_path : str , device : torch .device , dtype : torch .dtype
526+ ) -> str :
474527 """Compute eigenvectors from covariances."""
475528 covariance_test_path = os .path .join (test_path , "covariances" )
476529 eigenvectors_test_path = os .path .join (test_path , "eigenvectors" )
@@ -481,8 +534,12 @@ def compute_eigenvectors_step(test_path: str, device: torch.device, dtype: torch
481534 d = json .load (f )
482535 total_processed_global = d ["total_processed_global" ]
483536
484- activation_covariances = load_file (os .path .join (covariance_test_path , "activation_covariance.safetensors" ))
485- gradient_covariances = load_file (os .path .join (covariance_test_path , "gradient_covariance.safetensors" ))
537+ activation_covariances = load_file (
538+ os .path .join (covariance_test_path , "activation_covariance.safetensors" )
539+ )
540+ gradient_covariances = load_file (
541+ os .path .join (covariance_test_path , "gradient_covariance.safetensors" )
542+ )
486543
487544 eigenvectors_activations = {}
488545 eigenvectors_gradients = {}
@@ -497,12 +554,20 @@ def compute_eigenvectors_step(test_path: str, device: torch.device, dtype: torch
497554
498555 eigenvalues_a , eigenvectors_a = torch .linalg .eigh (a )
499556 eigenvalues_g , eigenvectors_g = torch .linalg .eigh (g )
500- print (f"{ name } : eigenvectors_a.sum()={ eigenvectors_a .sum ()} , eigenvectors_g.sum()={ eigenvectors_g .sum ()} " )
557+ print (
558+ f"{ name } : eigenvectors_a.sum()={ eigenvectors_a .sum ()} , eigenvectors_g.sum()={ eigenvectors_g .sum ()} "
559+ )
501560 eigenvectors_activations [name ] = eigenvectors_a .to (dtype = dtype ).contiguous ()
502561 eigenvectors_gradients [name ] = eigenvectors_g .to (dtype = dtype ).contiguous ()
503562
504- save_file (eigenvectors_activations , os .path .join (eigenvectors_test_path , "eigenvectors_activations.safetensors" ))
505- save_file (eigenvectors_gradients , os .path .join (eigenvectors_test_path , "eigenvectors_gradients.safetensors" ))
563+ save_file (
564+ eigenvectors_activations ,
565+ os .path .join (eigenvectors_test_path , "eigenvectors_activations.safetensors" ),
566+ )
567+ save_file (
568+ eigenvectors_gradients ,
569+ os .path .join (eigenvectors_test_path , "eigenvectors_gradients.safetensors" ),
570+ )
506571
507572 gc .collect ()
508573 torch .cuda .empty_cache ()
@@ -585,12 +650,18 @@ def compute_eigenvalue_corrections_step(
585650 os .makedirs (eigenvalue_correction_test_path , exist_ok = True )
586651
587652 # Load eigenvectors
588- eigenvectors_activations = load_file (os .path .join (eigenvectors_test_path , "eigenvectors_activations.safetensors" ))
589- eigenvectors_gradients = load_file (os .path .join (eigenvectors_test_path , "eigenvectors_gradients.safetensors" ))
653+ eigenvectors_activations = load_file (
654+ os .path .join (eigenvectors_test_path , "eigenvectors_activations.safetensors" )
655+ )
656+ eigenvectors_gradients = load_file (
657+ os .path .join (eigenvectors_test_path , "eigenvectors_gradients.safetensors" )
658+ )
590659
591660 total_processed_global = 0
592661 for rank in range (workers ):
593- eigenvalue_correction_test_path_rank = os .path .join (eigenvalue_correction_test_path , f"rank_{ rank } " )
662+ eigenvalue_correction_test_path_rank = os .path .join (
663+ eigenvalue_correction_test_path , f"rank_{ rank } "
664+ )
594665 os .makedirs (eigenvalue_correction_test_path_rank , exist_ok = True )
595666
596667 eigenvalue_corrections = {}
@@ -608,9 +679,14 @@ def compute_eigenvalue_corrections_step(
608679
609680 save_file (
610681 eigenvalue_corrections ,
611- os .path .join (eigenvalue_correction_test_path_rank , "eigenvalue_corrections.safetensors" ),
682+ os .path .join (
683+ eigenvalue_correction_test_path_rank ,
684+ "eigenvalue_corrections.safetensors" ,
685+ ),
612686 )
613- with open (os .path .join (eigenvalue_correction_test_path_rank , "stats.json" ), "w" ) as f :
687+ with open (
688+ os .path .join (eigenvalue_correction_test_path_rank , "stats.json" ), "w"
689+ ) as f :
614690 json .dump ({"total_processed_rank" : d ["total_processed_rank" ]}, f , indent = 4 )
615691 print (f"Rank { rank } processed { d ['total_processed_rank' ]} tokens." )
616692 total_processed_global += d ["total_processed_rank" ]
@@ -620,34 +696,50 @@ def compute_eigenvalue_corrections_step(
620696
621697if __name__ == "__main__" or TYPE_CHECKING :
622698 print ("\n === Computing Eigenvalue Corrections ===" )
623- eigenvalue_correction_test_path , total_processed_global_lambda = compute_eigenvalue_corrections_step (
624- model , data , batches_world , device , target_modules , workers , test_path
699+ eigenvalue_correction_test_path , total_processed_global_lambda = (
700+ compute_eigenvalue_corrections_step (
701+ model , data , batches_world , device , target_modules , workers , test_path
702+ )
625703 )
626704
627705
628706# %%
629707def combine_eigenvalue_corrections_step (
630- eigenvalue_correction_test_path : str , workers : int , device : torch .device , total_processed_global : int
708+ eigenvalue_correction_test_path : str ,
709+ workers : int ,
710+ device : torch .device ,
711+ total_processed_global : int ,
631712) -> TensorDict :
632713 """Combine eigenvalue correction results from all ranks."""
633714 eigenvalue_corrections = TensorDict ({})
634715
635716 for rank in range (workers ):
636- eigenvalue_correction_test_path_rank = os .path .join (eigenvalue_correction_test_path , f"rank_{ rank } " )
717+ eigenvalue_correction_test_path_rank = os .path .join (
718+ eigenvalue_correction_test_path , f"rank_{ rank } "
719+ )
637720
638721 eigenvalue_corrections_rank = TensorDict (
639- load_file (os .path .join (eigenvalue_correction_test_path_rank , "eigenvalue_corrections.safetensors" ))
722+ load_file (
723+ os .path .join (
724+ eigenvalue_correction_test_path_rank ,
725+ "eigenvalue_corrections.safetensors" ,
726+ )
727+ )
640728 ).to (device )
641729
642730 if not eigenvalue_corrections :
643731 eigenvalue_corrections = eigenvalue_corrections_rank
644732 else :
645- eigenvalue_corrections = eigenvalue_corrections + eigenvalue_corrections_rank
733+ eigenvalue_corrections = (
734+ eigenvalue_corrections + eigenvalue_corrections_rank
735+ )
646736
647737 eigenvalue_corrections .div_ (total_processed_global )
648738 save_file (
649739 eigenvalue_corrections .to_dict (),
650- os .path .join (eigenvalue_correction_test_path , "eigenvalue_corrections.safetensors" ),
740+ os .path .join (
741+ eigenvalue_correction_test_path , "eigenvalue_corrections.safetensors"
742+ ),
651743 )
652744
653745 return eigenvalue_corrections
0 commit comments