Commit e655ec6
Update induction heads eval dataset
1 parent eda302f

1 file changed: examples/find_induction_heads.py (+84, -22 lines)

@@ -506,7 +506,7 @@ def create_induction_head_dataset(tokenizer, seed, num_prompts=100):
                 "text": text,
             }
         )
-    return dataset
+    return Dataset.from_list(dataset)
 
 
 def test_induction_head_labels(tokenizer):
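
Note on this hunk: create_induction_head_dataset now returns a datasets.Dataset built from the list of prompt dicts, so callers no longer convert it themselves. A minimal sketch of the pattern, assuming the Hugging Face datasets library (the rows here are made up, not taken from the script):

    from datasets import Dataset

    # Rows mirror the {"text": ...} records the function appends above.
    rows = [
        {"text": "alpha beta gamma alpha beta"},
        {"text": "red blue green red blue"},
    ]

    # Dataset.from_list builds a columnar Dataset from a list of dicts,
    # ready to be passed straight to a Trainer as eval_dataset.
    ds = Dataset.from_list(rows)
    print(ds.num_rows)    # 2
    print(ds[0]["text"])  # "alpha beta gamma alpha beta"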
@@ -547,16 +547,83 @@ def setup_training(
 ):
     """Set up the training configuration with Bergson callback."""
 
+    pad_id = -100
+
     def compute_metrics(eval_preds):
-        accuracy = (
-            (eval_preds.label_ids == eval_preds.predictions.argmax(axis=-1))
-            .astype(np.float32)
-            .mean()
-            .item()
-        )
-        return {
-            "accuracy": accuracy,
-        }
+        # predictions: (B, T, V)
+        # label_ids: with your collator, this equals input_ids: (B, T)
+        preds = eval_preds.predictions
+        input_ids = eval_preds.label_ids
+
+        correct = 0
+        total = 0
+        # for each sequence, evaluate the final next-token prediction
+        for i in range(input_ids.shape[0]):
+            seq = input_ids[i]
+            # last non-pad index j
+            non_pad = np.where(seq != pad_id)[0]
+            if len(non_pad) == 0:
+                continue
+            j = non_pad[-1]
+            if j == 0:
+                continue  # nothing to predict
+            pred_tok = preds[i, j - 1].argmax(-1)
+            tgt_tok = seq[j]
+            correct += int(pred_tok == tgt_tok)
+            total += 1
+
+        # avoid div-by-zero
+        acc = (correct / total) if total > 0 else 0.0
+        return {"accuracy": acc}
+
+    # def compute_metrics(eval_preds):
+    #     print("compute_metrics")
+    #     # predictions: (B, T, V)
+    #     preds = eval_preds.predictions
+    #     label_ids = eval_preds.label_ids
+
+    #     correct = 0
+    #     total = 0
+
+    #     # how many examples to print
+    #     max_print = 5
+    #     printed = 0
+
+    #     for i in range(label_ids.shape[0]):
+    #         seq = label_ids[i]
+    #         # last non-pad index j
+    #         non_pad = np.where(seq != pad_id)[0]
+    #         if len(non_pad) == 0:
+    #             continue
+    #         j = non_pad[-1]
+    #         if j == 0:
+    #             continue
+
+    #         # predicted token at position j-1 (predicting token j)
+    #         pred_logits = preds[i, j - 1]
+    #         pred_tok = pred_logits.argmax(-1)
+    #         tgt_tok = seq[j]
+
+    #         correct += int(pred_tok == tgt_tok)
+    #         total += 1
+
+    #         # print additional info for roughly 1% of examples
+    #         if random.random() < 0.01:
+    #             if printed < max_print:
+    #                 seq_str = tokenizer.decode(seq[:j + 1], skip_special_tokens=True)
+    #                 pred_str = tokenizer.decode([pred_tok])
+    #                 tgt_str = tokenizer.decode([tgt_tok])
+    #                 print("=" * 40)
+    #                 print(f"Example {i}")
+    #                 print(f"Context up to target: {seq_str}")
+    #                 print(f"Target token id: {tgt_tok} ({tgt_str})")
+    #                 print(f"Predicted token id: {pred_tok} ({pred_str})")
+    #                 print(f"Match? {pred_tok == tgt_tok}")
+    #                 printed += 1
+
+    #     acc = correct / total
+
+    #     return {"accuracy": acc}
 
     data_collator = DataCollatorForLanguageModeling(
         tokenizer=tokenizer,
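
To make the new metric concrete, here is a self-contained toy run of the same last-token accuracy logic, assuming only numpy (the toy logits and token ids are illustrative):

    import numpy as np

    pad_id = -100

    # Toy batch: B=2 sequences, T=4 positions, V=5 vocab entries.
    # Logits always favor token 3 in sequence 0 and token 1 in sequence 1.
    preds = np.zeros((2, 4, 5), dtype=np.float32)
    preds[0, :, 3] = 1.0
    preds[1, :, 1] = 1.0

    # Sequence 0 ends in token 3 (a hit); sequence 1 is padded at the end
    # and its last real token is 2 (a miss).
    input_ids = np.array([
        [4, 2, 4, 3],
        [0, 1, 2, pad_id],
    ])

    correct = total = 0
    for i in range(input_ids.shape[0]):
        seq = input_ids[i]
        non_pad = np.where(seq != pad_id)[0]   # indices of real tokens
        if len(non_pad) == 0:
            continue
        j = non_pad[-1]                        # last real token position
        if j == 0:
            continue                           # nothing to predict
        pred_tok = preds[i, j - 1].argmax(-1)  # prediction made at j-1
        correct += int(pred_tok == seq[j])
        total += 1

    print(correct / total)  # 0.5: sequence 0 correct, sequence 1 wrong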
@@ -614,12 +681,9 @@ def compute_metrics(eval_preds):
 
 
 def build_induction_index(
-    model, induction_prompts, output_dir, projection_dim, unit_norm
+    model, induction_dataset, output_dir, projection_dim, unit_norm
 ):
     """Build static query Bergson index using synthetic induction head data."""
-    # Convert induction prompts to dataset format
-    induction_dataset = Dataset.from_list(induction_prompts)
-
     # Create gradient processor
     processor = GradientProcessor(
         {},
@@ -706,7 +770,7 @@ def main(args):
 
     # # Create induction head dataset
     # test_induction_head_labels(tokenizer)
-    induction_prompts = create_induction_head_dataset(
+    induction_dataset = create_induction_head_dataset(
         tokenizer, seed=seed, num_prompts=100
     )
 
@@ -718,20 +782,18 @@ def main(args):
     else:
         train_dataset, _ = load_data(tokenizer, name=dataset_name)
 
-    eval_dataset = Dataset.from_list(induction_prompts)
-
     trainer = setup_training(
         model,
         tokenizer,
         train_dataset,
-        eval_dataset,
+        eval_dataset=induction_dataset,
         output_dir=output_dir,
         projection_dim=projection_dim,
         wandb=False,
         num_train_epochs=num_train_epochs,
     )
 
-    trainer.train()
+    trainer.train()  # resume_from_checkpoint=True)
     trainer.save_model(output_dir)
     tokenizer.save_pretrained(output_dir)
 
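One subtlety the new compute_metrics leans on: with DataCollatorForLanguageModeling(mlm=False), the labels the Trainer hands to eval_preds.label_ids are a copy of input_ids with pad positions replaced by -100, which is why the metric can recover each sequence and its padding from label_ids alone. A small sketch of that collator behavior, assuming transformers; the gpt2 checkpoint is just for illustration:

    from transformers import AutoTokenizer, DataCollatorForLanguageModeling

    tok = AutoTokenizer.from_pretrained("gpt2")
    tok.pad_token = tok.eos_token  # GPT-2 ships without a pad token

    collator = DataCollatorForLanguageModeling(tokenizer=tok, mlm=False)
    batch = collator([tok("a b a b"), tok("a b")])

    # labels mirror input_ids, with padding replaced by -100,
    # the same pad_id the new metric masks out.
    print(batch["input_ids"])
    print(batch["labels"])
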
@@ -748,7 +810,7 @@ def main(args):
 
     # Build Bergson index for induction head queries
    mean_induction_gradients, module_induction_gradients = build_induction_index(
-        model, induction_prompts, output_dir, projection_dim, unit_norm
+        model, induction_dataset, output_dir, projection_dim, unit_norm
     )
     model = model.cpu()
 
@@ -833,7 +895,7 @@ def main(args):
         print(f"Loaded module scores from {df_path}")
     else:
         data = []
-        for epoch_idx in range(trainer.args.num_train_epochs):
+        for epoch_idx in range(num_train_epochs):
             grads = Attributor(
                 index_path=f"{trainer.args.output_dir}/gradients/train/epoch_{epoch_idx}",
                 device="cpu",
10251087
from argparse import ArgumentParser
10261088

10271089
parser = ArgumentParser()
1028-
parser.add_argument("--projection_dim", type=int, default=128)
1090+
parser.add_argument("--projection_dim", type=int, default=16)
10291091
parser.add_argument("--seed", type=int, default=0)
10301092
parser.add_argument("--train", action="store_true")
10311093
parser.add_argument("--unit_norm", action="store_true")
