@@ -506,7 +506,7 @@ def create_induction_head_dataset(tokenizer, seed, num_prompts=100):
506506 "text" : text ,
507507 }
508508 )
509- return dataset
509+ return Dataset . from_list ( dataset )
510510
511511
512512def test_induction_head_labels (tokenizer ):
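
Note: create_induction_head_dataset now returns a datasets.Dataset instead of a plain list of dicts, so downstream callers no longer convert it themselves. A minimal sketch of what Dataset.from_list does with rows shaped like the ones built above (the example rows are invented; the real rows may carry more fields than "text"):

    from datasets import Dataset

    rows = [{"text": "A B C A B"}, {"text": "X Y Z X Y"}]
    ds = Dataset.from_list(rows)     # columnar Dataset with a "text" column
    print(ds.column_names, len(ds))  # ['text'] 2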
@@ -547,16 +547,83 @@ def setup_training(
 ):
     """Set up the training configuration with Bergson callback."""

+    pad_id = -100
+
     def compute_metrics(eval_preds):
-        accuracy = (
-            (eval_preds.label_ids == eval_preds.predictions.argmax(axis=-1))
-            .astype(np.float32)
-            .mean()
-            .item()
-        )
-        return {
-            "accuracy": accuracy,
-        }
+        # predictions: (B, T, V) logits
+        # label_ids: input_ids with pad positions set to -100 by the collator, (B, T)
+        preds = eval_preds.predictions
+        input_ids = eval_preds.label_ids
+
+        correct = 0
+        total = 0
+        # for each sequence, evaluate only the final next-token prediction
+        for i in range(input_ids.shape[0]):
+            seq = input_ids[i]
+            # index j of the last non-pad token
+            non_pad = np.where(seq != pad_id)[0]
+            if len(non_pad) == 0:
+                continue
+            j = non_pad[-1]
+            if j == 0:
+                continue  # nothing to predict
+            pred_tok = preds[i, j - 1].argmax(-1)
+            tgt_tok = seq[j]
+            correct += int(pred_tok == tgt_tok)
+            total += 1
+
+        # avoid division by zero when every sequence was skipped
+        acc = (correct / total) if total > 0 else 0.0
+        return {"accuracy": acc}
+
+    # def compute_metrics(eval_preds):
+    #     print("compute_metrics")
+    #     # predictions: (B, T, V)
+    #     preds = eval_preds.predictions
+    #     label_ids = eval_preds.label_ids
+
+    #     correct = 0
+    #     total = 0
+
+    #     # how many examples to print
+    #     max_print = 5
+    #     printed = 0
+
+    #     for i in range(label_ids.shape[0]):
+    #         seq = label_ids[i]
+    #         # last non-pad index j
+    #         non_pad = np.where(seq != pad_id)[0]
+    #         if len(non_pad) == 0:
+    #             continue
+    #         j = non_pad[-1]
+    #         if j == 0:
+    #             continue
+
+    #         # predicted token at position j-1 (predicting token j)
+    #         pred_logits = preds[i, j - 1]
+    #         pred_tok = pred_logits.argmax(-1)
+    #         tgt_tok = seq[j]
+
+    #         correct += int(pred_tok == tgt_tok)
+    #         total += 1
+
+    #         # Trigger additional info approximately 1% of the time
+    #         if random.random() < 0.01:
+    #             if printed < max_print:
+    #                 seq_str = tokenizer.decode(seq[:j + 1], skip_special_tokens=True)
+    #                 pred_str = tokenizer.decode([pred_tok])
+    #                 tgt_str = tokenizer.decode([tgt_tok])
+    #                 print("=" * 40)
+    #                 print(f"Example {i}")
+    #                 print(f"Context up to target: {seq_str}")
+    #                 print(f"Target token id: {tgt_tok} ({tgt_str})")
+    #                 print(f"Predicted token id: {pred_tok} ({pred_str})")
+    #                 print(f"Match? {pred_tok == tgt_tok}")
+    #                 printed += 1
+
+    #     acc = correct / total
+
+    #     return {"accuracy": acc}

     data_collator = DataCollatorForLanguageModeling(
         tokenizer=tokenizer,
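
Note: the rewritten compute_metrics scores only the final next-token prediction in each sequence, treating -100 as the pad sentinel that the collator writes into label_ids. A tiny self-contained check of the indexing logic (batch, vocabulary, and token values invented for illustration):

    import numpy as np

    pad_id = -100
    preds = np.zeros((1, 4, 10))               # (B, T, V) logits
    preds[0, 1, 7] = 1.0                       # position 1 predicts token 7
    input_ids = np.array([[3, 5, 7, pad_id]])  # last non-pad index is j = 2

    non_pad = np.where(input_ids[0] != pad_id)[0]
    j = non_pad[-1]                            # 2
    pred_tok = preds[0, j - 1].argmax(-1)      # logits at j-1 predict token j
    assert pred_tok == input_ids[0, j]         # counted as correct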
@@ -614,12 +681,9 @@ def compute_metrics(eval_preds):


 def build_induction_index(
-    model, induction_prompts, output_dir, projection_dim, unit_norm
+    model, induction_dataset, output_dir, projection_dim, unit_norm
 ):
     """Build static query Bergson index using synthetic induction head data."""
-    # Convert induction prompts to dataset format
-    induction_dataset = Dataset.from_list(induction_prompts)
-
     # Create gradient processor
     processor = GradientProcessor(
         {},
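
Note: with the conversion moved into create_induction_head_dataset, build_induction_index now receives a datasets.Dataset directly. A cheap guard one could put at the top of the function (purely illustrative; not part of this commit):

    from datasets import Dataset

    def build_induction_index(model, induction_dataset, output_dir, projection_dim, unit_norm):
        assert isinstance(induction_dataset, Dataset), (
            "expected a datasets.Dataset; convert lists with Dataset.from_list first"
        )
        ...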
@@ -706,7 +770,7 @@ def main(args):

     # # Create induction head dataset
     # test_induction_head_labels(tokenizer)
-    induction_prompts = create_induction_head_dataset(
+    induction_dataset = create_induction_head_dataset(
         tokenizer, seed=seed, num_prompts=100
     )

@@ -718,20 +782,18 @@ def main(args):
     else:
         train_dataset, _ = load_data(tokenizer, name=dataset_name)

-    eval_dataset = Dataset.from_list(induction_prompts)
-
     trainer = setup_training(
         model,
         tokenizer,
         train_dataset,
-        eval_dataset,
+        eval_dataset=induction_dataset,
         output_dir=output_dir,
         projection_dim=projection_dim,
         wandb=False,
         num_train_epochs=num_train_epochs,
     )

-    trainer.train()
+    trainer.train()  # resume_from_checkpoint=True
     trainer.save_model(output_dir)
     tokenizer.save_pretrained(output_dir)

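
Note: the commented-out resume_from_checkpoint hint refers to the standard HF Trainer resumption API. If a run is interrupted, it can restart from the newest checkpoint saved under output_dir (the checkpoint-500 path below is a made-up example):

    # Resume from the most recent checkpoint in args.output_dir:
    trainer.train(resume_from_checkpoint=True)

    # Or resume from a specific checkpoint directory:
    trainer.train(resume_from_checkpoint=f"{output_dir}/checkpoint-500")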
@@ -748,7 +810,7 @@ def main(args):

     # Build Bergson index for induction head queries
     mean_induction_gradients, module_induction_gradients = build_induction_index(
-        model, induction_prompts, output_dir, projection_dim, unit_norm
+        model, induction_dataset, output_dir, projection_dim, unit_norm
     )
     model = model.cpu()

@@ -833,7 +895,7 @@ def main(args):
         print(f"Loaded module scores from {df_path}")
     else:
         data = []
-        for epoch_idx in range(trainer.args.num_train_epochs):
+        for epoch_idx in range(num_train_epochs):
             grads = Attributor(
                 index_path=f"{trainer.args.output_dir}/gradients/train/epoch_{epoch_idx}",
                 device="cpu",
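
Note: iterating over the local num_train_epochs instead of trainer.args.num_train_epochs sidesteps a type pitfall: TrainingArguments stores num_train_epochs as a float (3.0 by default), and range() rejects floats. A quick illustration:

    num_train_epochs = 3
    list(range(num_train_epochs))      # [0, 1, 2]

    float_epochs = 3.0                 # how TrainingArguments stores it
    # list(range(float_epochs))        # TypeError: 'float' object cannot be interpreted as an integer
    list(range(int(float_epochs)))     # an explicit cast would also work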
@@ -1025,7 +1087,7 @@ def main(args):
     from argparse import ArgumentParser

     parser = ArgumentParser()
-    parser.add_argument("--projection_dim", type=int, default=128)
+    parser.add_argument("--projection_dim", type=int, default=16)
     parser.add_argument("--seed", type=int, default=0)
     parser.add_argument("--train", action="store_true")
     parser.add_argument("--unit_norm", action="store_true")
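
Note: the default projection_dim drops from 128 to 16, trading some attribution fidelity for a much smaller gradient index. The sketch below illustrates the general idea of random-projection gradient compression; it is not Bergson's actual implementation, and the shapes are invented:

    import numpy as np

    rng = np.random.default_rng(0)
    grad = rng.normal(size=4096)                      # flattened per-module gradient
    k = 16                                            # projection_dim
    P = rng.normal(size=(k, grad.size)) / np.sqrt(k)  # JL-style random projection
    sketch = P @ grad                                 # k numbers stored per module
    print(sketch.shape)                               # (16,)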