4444)
4545from  bergson .utils  import  assert_type 
4646
# Attention-head layouts keyed by module path, shared between training-time
# gradient collection and induction-index building (deduplicated from the two
# call sites that previously inlined identical dicts).
# NOTE(review): args appear to be (n_heads, per-module dim, n_layers-ish) —
# confirm against the HeadConfig definition; both layers use identical shapes.
HEAD_CFGS = {
    f"h.{layer}.attn.{module}": HeadConfig(12, width, 2)
    for layer in (0, 1)
    for module, width in (("c_attn", 192), ("c_proj", 64))
}
4754
4855class  AttnOnlyConfig (PretrainedConfig ):
4956    model_type  =  "attn_only" 
@@ -542,7 +549,7 @@ def setup_training(
542549
543550    def  compute_metrics (eval_preds ):
544551        accuracy  =  (
545-             (eval_preds .label_ids  ==  eval_preds .predictions )
552+             (eval_preds .label_ids  ==  eval_preds .predictions . argmax ( axis = - 1 ) )
546553            .astype (np .float32 )
547554            .mean ()
548555            .item ()
@@ -568,29 +575,21 @@ def compute_metrics(eval_preds):
568575        weight_decay = 0.01 ,
569576        logging_dir = f"{ output_dir }  /logs" ,
570577        logging_steps = 10 ,
571-         eval_steps = 10 ,
578+         eval_steps = 100 ,
572579        eval_strategy = "steps" ,
573580        save_strategy = "steps" ,
574-         save_steps = 1000 ,
575-         save_total_limit = 3 ,
576-         # load_best_model_at_end=True, 
577-         # metric_for_best_model="train_loss", 
578-         # greater_is_better=False, 
581+         save_steps = 10_000 ,
582+         # save_total_limit=3, 
579583        report_to = "wandb"  if  wandb  else  None ,
580584        run_name = "2-layer-transformer-SmolLM2-corpus" ,
581585        seed = 42 ,
582586        fp16 = False ,
583-         dataloader_drop_last = True ,
587+         dataloader_drop_last = False ,
584588    )
585589
586590    bergson_callback  =  GradientCollectorCallback (
587591        path = f"{ output_dir }  /gradients" ,
588-         head_cfgs = {
589-             "h.0.attn.c_attn" : HeadConfig (12 , 192 , 2 ),
590-             "h.0.attn.c_proj" : HeadConfig (12 , 64 , 2 ),
591-             "h.1.attn.c_attn" : HeadConfig (12 , 192 , 2 ),
592-             "h.1.attn.c_proj" : HeadConfig (12 , 64 , 2 ),
593-         },
592+         head_cfgs = HEAD_CFGS ,
594593        projection_dim = projection_dim ,
595594        dtype = np .float32 ,
596595        accumulate_grads = False ,
@@ -618,8 +617,6 @@ def build_induction_index(
618617    model , induction_prompts , output_dir , projection_dim , unit_norm 
619618):
620619    """Build static query Bergson index using synthetic induction head data.""" 
621-     print ("Building Bergson index for induction head queries..." )
622- 
623620    # Convert induction prompts to dataset format 
624621    induction_dataset  =  Dataset .from_list (induction_prompts )
625622
@@ -638,12 +635,7 @@ def build_induction_index(
638635        processor = processor ,
639636        path = f"{ output_dir }  /induction_gradients" ,
640637        skip_preconditioners = True ,
641-         head_cfgs = {
642-             "h.0.attn.c_attn" : HeadConfig (12 , 192 , 2 ),
643-             "h.0.attn.c_proj" : HeadConfig (12 , 64 , 2 ),
644-             "h.1.attn.c_attn" : HeadConfig (12 , 192 , 2 ),
645-             "h.1.attn.c_proj" : HeadConfig (12 , 64 , 2 ),
646-         },
638+         head_cfgs = HEAD_CFGS ,
647639    )
648640
649641    # Build the attributor for querying 
@@ -656,7 +648,6 @@ def build_induction_index(
656648    )
657649
658650    # Collect mean gradient from attributor index 
659- 
660651    mean_gradient  =  torch .cat ([grad  for  grad  in  attributor .grads .values ()], dim = 1 ).mean (
661652        dim = 0 
662653    )
@@ -683,8 +674,8 @@ def upload_to_hub(model, tokenizer, model_name="2layer-transformer-tinystories")
683674
684675
685676def  main (args ):
686-     #  dataset_name = "EleutherAI/SmolLM2-135M-10B"
687-     dataset_name  =  "RonenEldan/TinyStories" 
677+     dataset_name  =  "EleutherAI/SmolLM2-135M-10B" 
678+     #  dataset_name = "RonenEldan/TinyStories"
688679    num_train_epochs  =  1 
689680
690681    unit_norm  =  args .unit_norm 
@@ -723,7 +714,7 @@ def main(args):
723714        # Set up training 
724715        # Load data 
725716        if  args .small :
726-             train_dataset , _  =  load_data (tokenizer , name = dataset_name , N = 10_000 )
717+             train_dataset , _  =  load_data (tokenizer , name = dataset_name , N = 20_000 )
727718        else :
728719            train_dataset , _  =  load_data (tokenizer , name = dataset_name )
729720
0 commit comments