From 007d0bdcd759594c59a8ff13b9b618c4da466156 Mon Sep 17 00:00:00 2001 From: Lucia Quirke Date: Wed, 10 Sep 2025 11:11:08 +0000 Subject: [PATCH 01/15] Add induction heads script --- bergson/collection.py | 4 +- bergson/huggingface.py | 32 ++ examples/pretrain_transformer.py | 531 +++++++++++++++++++++++++++++++ 3 files changed, 565 insertions(+), 2 deletions(-) create mode 100644 examples/pretrain_transformer.py diff --git a/bergson/collection.py b/bergson/collection.py index a65b25a..eed2f3f 100644 --- a/bergson/collection.py +++ b/bergson/collection.py @@ -109,7 +109,7 @@ def callback(name: str, g: torch.Tensor): if "advantage" in batch: losses *= torch.tensor(batch["advantage"], device=losses.device) - losses.mean().backward() + losses.sum().backward() else: with collector: logits = model(x).logits[:, :-1] @@ -123,7 +123,7 @@ def callback(name: str, g: torch.Tensor): if "advantage" in batch: losses *= torch.tensor(batch["advantage"], device=losses.device) - losses.mean().backward() + losses.sum().backward() # Weirdly you need to explicitly synchronize here in order to make sure that # the nonblocking copies actually finish before we call .numpy() diff --git a/bergson/huggingface.py b/bergson/huggingface.py index cc7130f..ff1eab3 100644 --- a/bergson/huggingface.py +++ b/bergson/huggingface.py @@ -71,6 +71,7 @@ def __init__( self.mod_grads = {} self.batch_indices: Tensor | None = None + self.training_order: list[dict] = [] # TODO: Handle this more elegantly self.torch_dtype = torch.float32 if self.dtype == np.float32 else torch.float16 @@ -83,6 +84,12 @@ def write_grads(self, grad_buffer: np.memmap): self.mod_grads.clear() + def on_step_begin(self, args, state, control, **kwargs): + """Track the current step and epoch for training order recording.""" + if self.track_training_order: + self._current_step = state.global_step + self._current_epoch = int(state.epoch or 0) + def on_train_begin( self, args: TrainingArguments, @@ -158,6 +165,7 @@ def on_epoch_begin( # Set up the gradient buffers for the evaluation datasets if eval_dataloader is None: + print("No evaluation dataloader found") return elif isinstance(eval_dataloader, dict): eval_datasets = eval_dataloader @@ -212,6 +220,12 @@ def on_module_backward(self, name: str, g: Tensor): device="cpu", dtype=self.torch_dtype, non_blocking=True ) + if (self.mod_grads[name].pow(2).sum(dim=1) == 0).any(): + print( + f"{self.mod_grads[name].pow(2).sum(dim=1).eq(0).sum().item()} " + "sum of squares == 0 rows found in gradients after fp16" + ) + def on_substep_end( self, args: TrainingArguments, @@ -258,6 +272,24 @@ def on_step_end( if not self.use_optimizer_state: return + # Record training order if enabled + if self.training_order is not None: + if self.batch_indices is None: + raise ValueError( + "Batch indices are not available for training order tracking" + ) + + rank = dist.get_rank() if dist.is_initialized() else 0 + self.training_order.extend( + { + "_idx": int(idx), + "rank": rank, + "global_step": getattr(self, "_current_step", 0), + "epoch": getattr(self, "_current_epoch", 0), + } + for idx in self.batch_indices.tolist() + ) + # The optimizer doesn't actually know the names of the parameters model = getattr(model, "base_model", model) param_to_name = { diff --git a/examples/pretrain_transformer.py b/examples/pretrain_transformer.py new file mode 100644 index 0000000..cad66ea --- /dev/null +++ b/examples/pretrain_transformer.py @@ -0,0 +1,531 @@ +#!/usr/bin/env python3 +""" +Pretrain a two-layer transformer and try to identify the 
formation of induction heads +from the influence functions wrt simple induction head completions gradients. + +This script: +1. Creates a 2-layer transformer using HF transformers architecture +2. Trains on TinyStories dataset using HF Trainer with Bergson callback +3. Builds a static query Bergson index using synthetic induction head data +4. Uploads the trained model to HF hub +""" + +from pathlib import Path + +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +import pyarrow.parquet as pq +import torch +from datasets import Dataset, load_dataset +from transformers import ( + AutoTokenizer, + DataCollatorForLanguageModeling, + GPTNeoConfig, + GPTNeoForCausalLM, + Trainer, + TrainingArguments, +) + +import wandb +from bergson.attributor import Attributor + +# from bergson.data import load_gradient_dataset +from bergson.collection import collect_gradients +from bergson.gradients import GradientProcessor +from bergson.huggingface import ( + GradientCollectorCallback, + prepare_for_gradient_collection, +) + + +def check_logins(): + """Check if user is logged into HF hub and wandb.""" + print("Checking authentication...") + + # Check HF hub login + try: + from huggingface_hub import whoami + + whoami() + print("✓ Logged into Hugging Face Hub") + except Exception as e: + print("✗ Not logged into Hugging Face Hub. Please run: huggingface-cli login") + raise e + + # Check wandb login + try: + wandb.login() + print("✓ Logged into Weights & Biases") + except Exception as e: + print("✗ Not logged into Weights & Biases. Please run: wandb login") + raise e + + +def create_transformer(): + """Create a transformer model using GPTNeo architecture.""" + tokenizer = AutoTokenizer.from_pretrained("EleutherAI/TinyStories-restricted") + + # TODO use the EleutherAI 10k token tokenizer custom-built for TinyStories + # Padding and truncation = True + config = GPTNeoConfig( + vocab_size=len(tokenizer), + hidden_size=256, + intermediate_size=1024, + num_layers=2, + num_heads=4, + max_position_embeddings=1024, + attention_types=[[["global"], 2]], + window_size=256, + resid_pdrop=0.0, + embd_pdrop=0.0, + attn_pdrop=0.0, + layer_norm_epsilon=1e-5, + initializer_range=0.02, + use_cache=True, + # Token IDs from the tokenizer + pad_token_id=tokenizer.pad_token_id, + bos_token_id=tokenizer.bos_token_id, + eos_token_id=tokenizer.eos_token_id, + ) + model = GPTNeoForCausalLM(config) + + # Set pad token + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + + print( + f"Model created with {sum(p.numel() for p in model.parameters()):,} parameters" + ) + return model, tokenizer + + +def load_tinystories_data(tokenizer, max_length=512, N=10000): + """Load and preprocess TinyStories dataset.""" + dataset = load_dataset("roneneldan/TinyStories", split="train") + dataset = dataset.select(range(min(N, len(dataset)))) + + def tokenize_function(examples): + # Tokenize the text + tokenized = tokenizer( + examples["text"], + truncation=True, + padding=False, + max_length=max_length, + return_tensors=None, + ) + + # For language modeling, labels are the same as input_ids + # TODO probably remove this + # tokenized["labels"] = tokenized["input_ids"].copy() + + return tokenized + + # Tokenize the dataset + tokenized_dataset = dataset.map( + tokenize_function, + batched=True, + remove_columns=dataset.column_names, + desc="Tokenizing dataset", + ) + + # Split into train/eval + train_eval = tokenized_dataset.train_test_split(test_size=0.1, seed=42) + train_dataset = train_eval["train"] + 
eval_dataset = train_eval["test"] + + print(f"Training samples: {len(train_dataset)}") + print(f"Evaluation samples: {len(eval_dataset)}") + + return train_dataset, eval_dataset + + +def create_induction_head_dataset(tokenizer, num_prompts=10): + """Create synthetic induction head dataset for building the query index.""" + print(f"Creating {num_prompts} synthetic induction head prompts...") + + # Create induction head patterns: [A][B] ... [A] -> [B] + # These are designed to test if the model learns to copy tokens + # from earlier in the sequence + + # Generate diverse induction head patterns + patterns = [ + "The cat sat on the mat. The cat", + "Once upon a time, there was a princess. Once upon a time", + "In the forest, the wolf howled. In the forest", + "The sun shines bright today. The sun", + "My favorite color is blue. My favorite color", + "The dog ran in the park. The dog", + "She loves to read books. She loves", + "The moon is full tonight. The moon", + "He plays guitar every day. He plays", + "The bird sings a sweet song. The bird", + ] + + # Take the requested number of prompts + selected_prompts = patterns[:num_prompts] + + # Tokenize the prompts + tokenized_prompts = [] + for prompt in selected_prompts: + # Split into input and target (everything after the last space) + parts = prompt.rsplit(" ", 1) + if len(parts) == 2: + input_text, target = parts + tokenized = tokenizer( + input_text, + return_tensors="pt", + padding=False, + truncation=True, + max_length=512, + ) + # Get the target token ID + target_tokens = tokenizer( + target, return_tensors="pt", add_special_tokens=False + ) + if target_tokens["input_ids"].numel() > 0: + target_token_id = target_tokens["input_ids"][0, 0].item() + tokenized_prompts.append( + { + "input_ids": tokenized["input_ids"][0], + "attention_mask": tokenized["attention_mask"][0], + "target_token": target_token_id, + "text": prompt, + } + ) + + print(f"Created {len(tokenized_prompts)} induction head prompts") + return tokenized_prompts + + +def setup_training( + model, + tokenizer, + train_dataset, + eval_dataset, + output_dir: str, + projection_dim: int, + wandb: bool = True, +): + """Set up the training configuration with Bergson callback.""" + data_collator = DataCollatorForLanguageModeling( + tokenizer=tokenizer, + mlm=False, + ) + + training_args = TrainingArguments( + output_dir=output_dir, + overwrite_output_dir=True, + num_train_epochs=3, + per_device_train_batch_size=16, + # per_device_eval_batch_size=16, + gradient_accumulation_steps=1, + warmup_steps=100, + learning_rate=5e-4, + weight_decay=0.01, + logging_dir=f"{output_dir}/logs", + logging_steps=10, + # save_strategy="steps", + # save_steps=1000, + # save_total_limit=3, + # load_best_model_at_end=True, + # metric_for_best_model="train_loss", + # greater_is_better=False, + report_to="wandb" if wandb else None, + run_name="2-layer-transformer-tinystories", + seed=42, + fp16=False, + dataloader_drop_last=True, + ) + + bergson_callback = GradientCollectorCallback( + path=f"{output_dir}/gradients", + projection_dim=projection_dim, + dtype=np.float32, + accumulate_grads=False, + track_training_order=True, + ) + + # Create trainer + trainer = Trainer( + model=model, + args=training_args, + train_dataset=train_dataset, + # eval_dataset=eval_dataset, + data_collator=data_collator, + callbacks=[bergson_callback], + ) + + # Prepare for gradient collection + trainer = prepare_for_gradient_collection(trainer) + + return trainer + + +def build_induction_index(model, induction_prompts, output_dir, 
projection_dim): + """Build static query Bergson index using synthetic induction head data.""" + print("Building Bergson index for induction head queries...") + + # Convert induction prompts to dataset format + induction_data = [] + for prompt_data in induction_prompts: + # Create a simple dataset entry + induction_data.append( + { + "input_ids": prompt_data["input_ids"].tolist(), + "attention_mask": prompt_data["attention_mask"].tolist(), + "labels": prompt_data["input_ids"].tolist(), # For language modeling + "text": prompt_data["text"], + } + ) + + induction_dataset = Dataset.from_list(induction_data) + + # Create gradient processor + processor = GradientProcessor( + {}, + projection_dim=projection_dim, + reshape_to_square=False, + ) + + # Collect gradients for the induction head dataset + print("Collecting gradients for induction head dataset...") + collect_gradients( + model=model, + data=induction_dataset, + processor=processor, + path=f"{output_dir}/induction_gradients", + skip_preconditioners=False, + ) + + # Build the attributor for querying + print("Building attributor for querying...") + attributor = Attributor( + index_path=f"{output_dir}/induction_gradients", + device="cuda" if torch.cuda.is_available() else "cpu", + dtype=torch.float32, + ) + + # Collect mean gradient from attributor index + mean_gradient = attributor.grads.mean(dim=0) + + print("In-context index built successfully! Returning mean gradient...") + return mean_gradient + + +def upload_to_hub(model, tokenizer, model_name="2layer-transformer-tinystories"): + """Upload the trained model to Hugging Face Hub.""" + print(f"Uploading model to Hugging Face Hub as {model_name}...") + + try: + # Push model and tokenizer + model.push_to_hub(model_name) + tokenizer.push_to_hub(model_name) + print(f"✓ Successfully uploaded to https://huggingface.co/{model_name}") + except Exception as e: + print(f"✗ Failed to upload to HF Hub: {e}") + raise e + + +def main(projection_dim=128): + tag = "" + k = 1000 + unit_norm = False + + print( + "Starting 2-layer transformer pretraining with Bergson gradient collection..." 
+ ) + + # Check authentication + check_logins() + + # Set device + device = "cuda" if torch.cuda.is_available() else "cpu" + print(f"Using device: {device}") + + # Create model and tokenizer + model, tokenizer = create_transformer() + model = model.to(device) + + # # Load TinyStories data + train_dataset, eval_dataset = load_tinystories_data(tokenizer) + + # # Create induction head dataset + induction_prompts = create_induction_head_dataset(tokenizer) + + # # Set up training + trainer = setup_training( + model, + tokenizer, + train_dataset, + eval_dataset, + output_dir=f"examples/runs/transformer_2_layer{'_' + tag if tag else ''}", + projection_dim=projection_dim, + wandb=False, + ) + + trainer.train() + + # trainer.save_model(trainer.args.output_dir) + # tokenizer.save_pretrained(trainer.args.output_dir) + + # upload_to_hub(model, tokenizer) + + # Reload model and tokenizer + model = GPTNeoForCausalLM.from_pretrained(trainer.args.output_dir) + tokenizer = AutoTokenizer.from_pretrained(trainer.args.output_dir) + model = model.to(device) + + # Build Bergson index for induction head queries + mean_induction_gradients = build_induction_index( + model, induction_prompts, trainer.args.output_dir, projection_dim + ) + model = model.cpu() + + # Read Bergson index from training + epoch_attributors = [ + Attributor( + str( + Path(trainer.args.output_dir) / "gradients" / "train" / f"epoch_{epoch}" + ), + device=device, + unit_norm=unit_norm, + dtype=torch.float32, + # faiss_cfg=FaissConfig( + ) + for epoch in [0] # range(trainer.args.num_train_epochs) + ] + # Load parquet table containing training order + training_order = pq.read_table( + str(Path(trainer.args.output_dir) / "gradients" / "training_order.parquet") + ).to_pandas() + + # Test the attributor with a sample query + print("Testing Bergson index with sample query...") + test_prompt = "The cat sat on the mat. The cat" + test_input = tokenizer(test_prompt, return_tensors="pt").to(device) + + # Mask out everything except the last token in the labels + test_input["labels"] = test_input["input_ids"].clone() + test_input["labels"][:, :-1] = -100 + + top_data = [] + + model = model.to(device) + for epoch_idx, epoch_attributor in enumerate(epoch_attributors): + # print(f"Top {k} most influential training examples for epoch {epoch_idx}:") + + with epoch_attributor.trace(model.base_model, k=k) as result: + outputs = model(**test_input) + outputs.loss.backward() + model.zero_grad() + + skips = 0 + for i, (score, idx) in enumerate( + zip(result.scores.squeeze(), result.indices.squeeze()) + ): + + if idx.item() != -1: + # Get the training order + training_metadata = training_order[ + (training_order["_idx"] == idx.item()) + & (training_order["epoch"] == epoch_idx) + ] + if training_metadata.empty: + skips += 1 + continue + for row in training_metadata.itertuples(index=False): + # print(f"{i+1}. 
Score: {score.item():.4f}, + # Global step: {row.global_step}, Index: {idx.item()}") + top_data.append( + { + "epoch": epoch_idx, + "global_step": row.global_step, + "index": idx.item(), + "score": score.item(), + } + ) + print(f"Skipped {skips} examples for epoch {epoch_idx}") + + top_data = pd.DataFrame(top_data) + + # Scatter plot of scores over time + plt.figure(figsize=(12, 8)) + + for epoch in sorted(top_data["epoch"].unique()): + epoch_data = top_data[top_data["epoch"] == epoch] + plt.scatter( + epoch_data["global_step"], + epoch_data["score"], + alpha=0.6, + s=20, + label=f"Epoch {epoch}", + ) + + plt.xlabel("Cumulative Training Steps") + plt.ylabel("Influence Score") + plt.title("Most Influential Training Examples Per Epoch (Normalized)") + plt.legend() + plt.grid(True, alpha=0.3) + fig_name = ( + f'training_dynamics{"_" + tag if tag else ""}{"_norm" if unit_norm else ""}.pdf' + ) + plt.savefig( + fig_name, + format="pdf", + bbox_inches="tight", + ) + plt.show() + + # Calculate the inner products with the training gradients + data = [] + for epoch_idx, attributor in enumerate(epoch_attributors): + inner_products = attributor.grads.float() @ mean_induction_gradients.float() + for i, score in enumerate(inner_products.squeeze()): + training_metadata = training_order[ + (training_order["_idx"] == i) & (training_order["epoch"] == epoch_idx) + ] + if len(training_metadata) != 1: + continue + + for row in training_metadata.itertuples(index=False): + data.append( + { + "epoch": epoch_idx, + "global_step": row.global_step, + "index": i, + "score": score.item(), + } + ) + data = pd.DataFrame(data) + + plt.figure(figsize=(12, 8)) + for epoch in sorted(data["epoch"].unique()): + epoch_data = data[data["epoch"] == epoch] + plt.scatter( + epoch_data["global_step"], + epoch_data["score"], + alpha=0.6, + s=20, + label=f"Epoch {epoch}", + ) + + plt.xlabel("Cumulative Training Steps") + plt.ylabel("Influence Score") + plt.title("Most Influential Training Examples Per Epoch (Normalized)") + plt.legend() + plt.grid(True, alpha=0.3) + fig_name = ( + f'training_dynamics_mean_induction{"_" + tag if tag else ""}' + f'{"_norm" if unit_norm else ""}.pdf' + ) + plt.savefig( + fig_name, + format="pdf", + bbox_inches="tight", + ) + plt.show() + + +if __name__ == "__main__": + main() From 0ef1dce7b5d2a5e586e9441796a8802cc9a6ba12 Mon Sep 17 00:00:00 2001 From: Lucia Quirke Date: Wed, 10 Sep 2025 21:30:40 +0000 Subject: [PATCH 02/15] Rename example script --- bergson/huggingface.py | 3 ++- examples/{pretrain_transformer.py => find_induction_heads.py} | 0 2 files changed, 2 insertions(+), 1 deletion(-) rename examples/{pretrain_transformer.py => find_induction_heads.py} (100%) diff --git a/bergson/huggingface.py b/bergson/huggingface.py index ff1eab3..1ef00cf 100644 --- a/bergson/huggingface.py +++ b/bergson/huggingface.py @@ -5,6 +5,8 @@ from typing import Sized import numpy as np +import pyarrow as pa +import pyarrow.parquet as pq import torch import torch.distributed as dist from datasets import Dataset @@ -354,7 +356,6 @@ def on_train_end( def _save_order(self): """Save the training order to disk, handling distributed training.""" - assert self.order is not None os.makedirs(self.path, exist_ok=True) if dist.is_initialized(): diff --git a/examples/pretrain_transformer.py b/examples/find_induction_heads.py similarity index 100% rename from examples/pretrain_transformer.py rename to examples/find_induction_heads.py From 0e8fb08d4616ed1c94c4c8f03499a4f50b1115f9 Mon Sep 17 00:00:00 2001 From: Lucia Quirke Date: 
Thu, 11 Sep 2025 00:08:10 +0000 Subject: [PATCH 03/15] Add module plots --- bergson/attributor.py | 62 ++++++++++++ bergson/huggingface.py | 4 +- examples/find_induction_heads.py | 160 ++++++++++++++++++++++++++++--- 3 files changed, 212 insertions(+), 14 deletions(-) diff --git a/bergson/attributor.py b/bergson/attributor.py index eb2c1af..82f1fc5 100644 --- a/bergson/attributor.py +++ b/bergson/attributor.py @@ -43,6 +43,7 @@ def __init__( dtype: torch.dtype = torch.float32, unit_norm: bool = False, faiss_cfg: FaissConfig | None = None, + modules: bool = False, ): self.device = device self.dtype = dtype @@ -120,6 +121,67 @@ def search( return torch.topk(scores, k) + def search_module( + self, queries: Tensor, k: int, module: str + ) -> tuple[Tensor, Tensor]: + """ + Search for the `k` nearest examples in the index based on the query or queries. + If fewer than `k` examples are found FAISS will return items with the index -1 + and the maximum negative distance. + + Args: + queries: The query tensor of shape [..., d]. + k: The number of nearest examples to return for each query. + nprobe: The number of FAISS vector clusters to search if using ANN. + + Returns: + A namedtuple containing the top `k` indices and inner products for each + query. Both have shape [..., k]. + """ + assert isinstance( + self.grads, dict + ), "Gradients must be a dictionary of tensors." + assert module in self.grads, f"Module {module} not found in gradients." + + k = min(k, self.grads[module].shape[0]) + + q = queries + + if self.unit_norm: + q /= q.norm(dim=1, keepdim=True) + + if not self.faiss_cfg: + return torch.topk(q.to(self.device) @ self.grads[module].mT, k) + + q = q.cpu().numpy() + + shard_distances = [] + shard_indices = [] + offset = 0 + + for index in self.faiss_shards: + index.nprobe = self.faiss_cfg.nprobe + distances, indices = index.search(q, k) + + indices += offset + offset += index.ntotal + + shard_distances.append(distances) + shard_indices.append(indices) + + distances = np.concatenate(shard_distances, axis=1) + indices = np.concatenate(shard_indices, axis=1) + + # Rerank results overfetched from multiple shards + if len(self.faiss_shards) > 1: + topk_indices = np.argsort(distances, axis=1)[:, :k] + indices = indices[np.arange(indices.shape[0])[:, None], topk_indices] + distances = distances[np.arange(distances.shape[0])[:, None], topk_indices] + + return torch.from_numpy(distances.squeeze()), torch.from_numpy( + indices.squeeze() + ) + @contextmanager def trace( self, module: nn.Module, k: int, *, precondition: bool = False diff --git a/bergson/huggingface.py b/bergson/huggingface.py index 1ef00cf..c2a5946 100644 --- a/bergson/huggingface.py +++ b/bergson/huggingface.py @@ -33,7 +33,6 @@ def __init__( path: str, head_cfgs: dict[str, HeadConfig], projection_dim: int = 16, - dtype: DTypeLike = np.float16, accumulate_grads: bool = False, use_optimizer_state: bool = True, track_order: bool = False, @@ -74,6 +73,7 @@ def __init__( self.mod_grads = {} self.batch_indices: Tensor | None = None self.training_order: list[dict] = [] + self.torch_dtype = torch_dtype # TODO: Handle this more elegantly self.torch_dtype = torch.float32 if self.dtype == np.float32 else torch.float16 @@ -225,7 +225,7 @@ def on_module_backward(self, name: str, g: Tensor): if (self.mod_grads[name].pow(2).sum(dim=1) == 0).any(): print( f"{self.mod_grads[name].pow(2).sum(dim=1).eq(0).sum().item()} " - "sum of squares == 0 rows found in gradients after fp16" + f"sum of squares == 0 rows found in gradients after {self.torch_dtype}" 
) def on_substep_end( diff --git a/examples/find_induction_heads.py b/examples/find_induction_heads.py index cad66ea..9dca5d3 100644 --- a/examples/find_induction_heads.py +++ b/examples/find_induction_heads.py @@ -216,7 +216,7 @@ def setup_training( training_args = TrainingArguments( output_dir=output_dir, overwrite_output_dir=True, - num_train_epochs=3, + num_train_epochs=2, per_device_train_batch_size=16, # per_device_eval_batch_size=16, gradient_accumulation_steps=1, @@ -242,6 +242,7 @@ def setup_training( path=f"{output_dir}/gradients", projection_dim=projection_dim, dtype=np.float32, + torch_dtype=torch.float32, accumulate_grads=False, track_training_order=True, ) @@ -278,6 +279,10 @@ def build_induction_index(model, induction_prompts, output_dir, projection_dim): "text": prompt_data["text"], } ) + # Mask out everything except the last token in the labels + labels = [-100] * len(prompt_data["input_ids"]) + labels[-1] = prompt_data["input_ids"][-1] + induction_data[-1]["labels"] = labels induction_dataset = Dataset.from_list(induction_data) @@ -309,8 +314,19 @@ def build_induction_index(model, induction_prompts, output_dir, projection_dim): # Collect mean gradient from attributor index mean_gradient = attributor.grads.mean(dim=0) + attributor = Attributor( + index_path=f"{output_dir}/induction_gradients", + device="cuda" if torch.cuda.is_available() else "cpu", + dtype=torch.float32, + modules=True, + ) + + mean_module_gradients = { + name: attributor.grads[name].mean(dim=0) for name in attributor.grads.keys() + } + print("In-context index built successfully! Returning mean gradient...") - return mean_gradient + return mean_gradient, mean_module_gradients def upload_to_hub(model, tokenizer, model_name="2layer-transformer-tinystories"): @@ -328,9 +344,9 @@ def upload_to_hub(model, tokenizer, model_name="2layer-transformer-tinystories") def main(projection_dim=128): - tag = "" + tag = "mask_query" k = 1000 - unit_norm = False + unit_norm = True print( "Starting 2-layer transformer pretraining with Bergson gradient collection..." 
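
The label masking added to `build_induction_index` above keeps only the final position as a supervision target, so the cross-entropy loss, and therefore the query gradient that Bergson collects, depends only on the induction completion token rather than the whole prompt. A minimal standalone sketch of that masking (illustrative only, not part of the patch; `-100` is the ignore index used by PyTorch cross-entropy):

```python
import torch


def mask_all_but_last(input_ids: torch.Tensor) -> torch.Tensor:
    """Return labels that supervise only the final token of the sequence."""
    labels = torch.full_like(input_ids, -100)  # -100 positions are ignored by the loss
    labels[-1] = input_ids[-1]
    return labels


ids = torch.tensor([464, 3797, 3332, 319, 262, 2603, 13, 464, 3797])
print(mask_all_but_last(ids))  # tensor([-100, -100, ..., -100, 3797])
```
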
@@ -364,10 +380,10 @@ def main(projection_dim=128): wandb=False, ) - trainer.train() + # trainer.train() - # trainer.save_model(trainer.args.output_dir) - # tokenizer.save_pretrained(trainer.args.output_dir) + trainer.save_model(trainer.args.output_dir) + tokenizer.save_pretrained(trainer.args.output_dir) # upload_to_hub(model, tokenizer) @@ -377,7 +393,7 @@ def main(projection_dim=128): model = model.to(device) # Build Bergson index for induction head queries - mean_induction_gradients = build_induction_index( + mean_induction_gradients, module_induction_gradients = build_induction_index( model, induction_prompts, trainer.args.output_dir, projection_dim ) model = model.cpu() @@ -393,7 +409,7 @@ def main(projection_dim=128): dtype=torch.float32, # faiss_cfg=FaissConfig( ) - for epoch in [0] # range(trainer.args.num_train_epochs) + for epoch in range(trainer.args.num_train_epochs) ] # Load parquet table containing training order training_order = pq.read_table( @@ -464,7 +480,10 @@ def main(projection_dim=128): plt.xlabel("Cumulative Training Steps") plt.ylabel("Influence Score") - plt.title("Most Influential Training Examples Per Epoch (Normalized)") + plt.title( + f"Most Influential Training Examples Per Epoch " + f"({'Normalized' if unit_norm else 'Unnormalized'})" + ) plt.legend() plt.grid(True, alpha=0.3) fig_name = ( @@ -512,7 +531,10 @@ def main(projection_dim=128): plt.xlabel("Cumulative Training Steps") plt.ylabel("Influence Score") - plt.title("Most Influential Training Examples Per Epoch (Normalized)") + plt.title( + f"Most Influential Training Examples Per Epoch " + f"({'Normalized' if unit_norm else 'Unnormalized'})" + ) plt.legend() plt.grid(True, alpha=0.3) fig_name = ( @@ -524,7 +546,121 @@ def main(projection_dim=128): format="pdf", bbox_inches="tight", ) - plt.show() + + # Produce the same plot but split out by module (i.e. 
key in the grads mmap) + + # Second, produce the module-wise scores + del epoch_attributors + import os + + os.makedirs("module_figures", exist_ok=True) + + for epoch_idx in range(trainer.args.num_train_epochs): + module_attributor = Attributor( + index_path=f"{trainer.args.output_dir}/gradients/train/epoch_{epoch_idx}", + device=device, + dtype=torch.float32, + modules=True, + ) + for name, grads in module_attributor.grads.items(): + data = [] + inner_products = grads.float() @ module_induction_gradients[name].float() + for i, score in enumerate(inner_products.squeeze()): + training_metadata = training_order[ + (training_order["_idx"] == i) + & (training_order["epoch"] == epoch_idx) + ] + if len(training_metadata) != 1: + continue + for row in training_metadata.itertuples(index=False): + data.append( + { + "global_step": row.global_step, + "epoch": epoch_idx, + "module": name, + "score": score.item(), + } + ) + data = pd.DataFrame(data) + module_data = data + + plt.figure(figsize=(12, 8)) + + plt.scatter( + module_data["global_step"], + module_data["score"], + alpha=0.6, + s=20, + label=f"Module {name}", + ) + plt.xlabel("Training Step") + plt.ylabel("Influence Score") + plt.title( + f"Most Influential Training Examples for {name} " + f"({'Normalized' if unit_norm else 'Unnormalized'})" + ) + plt.legend() + plt.grid(True, alpha=0.3) + fig_name = ( + f'training_dynamics_induction{"_" + tag if tag else ""}' + f'{"_norm" if unit_norm else ""}_{name}.pdf' + ) + plt.savefig( + os.path.join("module_figures", fig_name), + format="pdf", + bbox_inches="tight", + ) + plt.close() + + # Add a line plot with absolute sum of the gradients for each module + plt.figure(figsize=(12, 8)) + + module_data = module_data.groupby("global_step", as_index=False).agg( + score=("score", lambda s: s.abs().sum()) + ) + plt.plot( + module_data["global_step"], module_data["score"], label=f"Module {name}" + ) + plt.xlabel("Training Step") + plt.ylabel("Absolute Sum of Gradients") + plt.title(f"Absolute Sum of Gradients for {name}") + plt.legend() + plt.grid(True, alpha=0.3) + fig_name = ( + f'training_dynamics_induction_abs_sum{"_" + tag if tag else ""}' + f'{"_norm" if unit_norm else ""}_{name}.pdf' + ) + plt.savefig( + os.path.join("module_figures", fig_name), + format="pdf", + bbox_inches="tight", + ) + plt.close() + + # Add a line plot with the sum of the gradients for each module + # Sum points at each global step + module_data = module_data.groupby("global_step", as_index=False).agg( + score=("score", lambda s: s.sum()) + ) + plt.figure(figsize=(12, 8)) + plt.plot( + module_data["global_step"], module_data["score"], label=f"Module {name}" + ) + plt.xlabel("Training Step") + plt.ylabel("Sum of Gradients") + plt.title(f"Sum of Gradients for {name}") + plt.legend() + plt.grid(True, alpha=0.3) + fig_name = ( + f'training_dynamics_induction_sum{"_" + tag if tag else ""}' + f'{"_norm" if unit_norm else ""}_{name}.pdf' + ) + plt.savefig( + os.path.join("module_figures", fig_name), + format="pdf", + bbox_inches="tight", + ) + plt.close() if __name__ == "__main__": From 436fb5dfa54dd88c8579b9897e91bf738bdd7f66 Mon Sep 17 00:00:00 2001 From: Lucia Quirke Date: Fri, 12 Sep 2025 02:10:11 +0000 Subject: [PATCH 04/15] Clean up huggingface callback; simplify induction heads model arch --- bergson/huggingface.py | 30 +- examples/find_induction_heads.py | 813 ++++++++++++++++++++----------- 2 files changed, 535 insertions(+), 308 deletions(-) diff --git a/bergson/huggingface.py b/bergson/huggingface.py index c2a5946..542b2e4 
100644 --- a/bergson/huggingface.py +++ b/bergson/huggingface.py @@ -5,8 +5,6 @@ from typing import Sized import numpy as np -import pyarrow as pa -import pyarrow.parquet as pq import torch import torch.distributed as dist from datasets import Dataset @@ -41,7 +39,6 @@ def __init__( Args: path: The path to save the gradients projection_dim: The dimension to project the gradients onto - dtype: The dtype of the on-disk gradient store accumulate_grads: Whether to take the sum of the gradients of the same example across epochs. If `False`, the gradients for each epoch are stored separately. @@ -72,8 +69,9 @@ def __init__( self.mod_grads = {} self.batch_indices: Tensor | None = None - self.training_order: list[dict] = [] - self.torch_dtype = torch_dtype + + # TODO: Handle this more elegantly + self.torch_dtype = torch.float32 if self.dtype == np.float32 else torch.float16 # TODO: Handle this more elegantly self.torch_dtype = torch.float32 if self.dtype == np.float32 else torch.float16 @@ -88,7 +86,7 @@ def write_grads(self, grad_buffer: np.memmap): def on_step_begin(self, args, state, control, **kwargs): """Track the current step and epoch for training order recording.""" - if self.track_training_order: + if self.order: self._current_step = state.global_step self._current_epoch = int(state.epoch or 0) @@ -222,12 +220,6 @@ def on_module_backward(self, name: str, g: Tensor): device="cpu", dtype=self.torch_dtype, non_blocking=True ) - if (self.mod_grads[name].pow(2).sum(dim=1) == 0).any(): - print( - f"{self.mod_grads[name].pow(2).sum(dim=1).eq(0).sum().item()} " - f"sum of squares == 0 rows found in gradients after {self.torch_dtype}" - ) - def on_substep_end( self, args: TrainingArguments, @@ -275,17 +267,14 @@ def on_step_end( return # Record training order if enabled - if self.training_order is not None: - if self.batch_indices is None: - raise ValueError( - "Batch indices are not available for training order tracking" - ) + if self.order: + assert ( + self.batch_indices is not None + ), "Batch indices are not available for training order tracking" - rank = dist.get_rank() if dist.is_initialized() else 0 - self.training_order.extend( + self.order.extend( { "_idx": int(idx), - "rank": rank, "global_step": getattr(self, "_current_step", 0), "epoch": getattr(self, "_current_epoch", 0), } @@ -356,6 +345,7 @@ def on_train_end( def _save_order(self): """Save the training order to disk, handling distributed training.""" + assert self.order is not None os.makedirs(self.path, exist_ok=True) if dist.is_initialized(): diff --git a/examples/find_induction_heads.py b/examples/find_induction_heads.py index 9dca5d3..7850b70 100644 --- a/examples/find_induction_heads.py +++ b/examples/find_induction_heads.py @@ -10,22 +10,33 @@ 4. 
Uploads the trained model to HF hub """ +# attn_only.py +import math +import os +import random from pathlib import Path +from typing import List, Optional, Tuple import matplotlib.pyplot as plt import numpy as np import pandas as pd import pyarrow.parquet as pq import torch +import torch.nn as nn +import torch.nn.functional as F from datasets import Dataset, load_dataset from transformers import ( + AutoConfig, + AutoModelForCausalLM, AutoTokenizer, DataCollatorForLanguageModeling, - GPTNeoConfig, - GPTNeoForCausalLM, + PretrainedConfig, + PreTrainedModel, Trainer, TrainingArguments, ) +from transformers.generation.utils import GenerationMixin +from transformers.modeling_outputs import CausalLMOutputWithPast import wandb from bergson.attributor import Attributor @@ -39,6 +50,233 @@ ) +class AttnOnlyConfig(PretrainedConfig): + model_type = "attn_only" + + def __init__( + self, + vocab_size=50257, + hidden_size=768, + num_hidden_layers=12, + num_attention_heads=12, + max_position_embeddings=2048, + layer_norm_epsilon=1e-5, + resid_pdrop=0.0, + embd_pdrop=0.0, + attn_pdrop=0.0, + use_cache=True, + **kwargs, + ): + super().__init__(**kwargs) + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.max_position_embeddings = max_position_embeddings + self.layer_norm_epsilon = layer_norm_epsilon + self.resid_pdrop = resid_pdrop + self.embd_pdrop = embd_pdrop + self.attn_pdrop = attn_pdrop + self.use_cache = use_cache + + +class CausalSelfAttention(nn.Module): + def __init__(self, config: AttnOnlyConfig): + super().__init__() + assert config.hidden_size % config.num_attention_heads == 0 + self.n_head = config.num_attention_heads + self.head_dim = config.hidden_size // config.num_attention_heads + self.c_attn = nn.Linear(config.hidden_size, 3 * config.hidden_size, bias=True) + self.c_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=True) + self.attn_drop = nn.Dropout(config.attn_pdrop) + self.resid_drop = nn.Dropout(config.resid_pdrop) + self.register_buffer( + "mask", + torch.tril( + torch.ones( + config.max_position_embeddings, config.max_position_embeddings + ) + ).view( + 1, 1, config.max_position_embeddings, config.max_position_embeddings + ), + persistent=False, + ) + + def _split_heads(self, x): + B, T, C = x.shape + x = x.view(B, T, self.n_head, self.head_dim).transpose(1, 2) + return x + + def _merge_heads(self, x): + B, _, T, _ = x.shape + return x.transpose(1, 2).contiguous().view(B, T, self.n_head * self.head_dim) + + def forward( + self, + x, + pos_emb, + layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + use_cache: bool = True, + attn_mask: Optional[torch.Tensor] = None, + ): + B, T, C = x.shape + qkv = self.c_attn(x) + q, k, v = qkv.split(C, dim=2) + + # add position to q and k only + q = q + pos_emb + k = k + pos_emb + + q = self._split_heads(q) + k = self._split_heads(k) + v = self._split_heads(v) + + if layer_past is not None: + pk, pv = layer_past + k = torch.cat([pk, k], dim=2) + v = torch.cat([pv, v], dim=2) + + att = (q @ k.transpose(-2, -1)) / math.sqrt(self.head_dim) + causal = self.mask[:, :, :T, : k.size(-2)] + att = att.masked_fill(causal == 0, float("-inf")) + if attn_mask is not None: + att = att + attn_mask + att = F.softmax(att, dim=-1) + att = self.attn_drop(att) + y = att @ v + y = self._merge_heads(y) + y = self.resid_drop(self.c_proj(y)) + + present = (k, v) if use_cache else None + return y, present + + +class 
AttnOnlyBlock(nn.Module): + def __init__(self, config: AttnOnlyConfig): + super().__init__() + # self.ln_1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_epsilon) + self.attn = CausalSelfAttention(config) + + def forward( + self, + x, + pos_emb, + layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + use_cache: bool = True, + attn_mask: Optional[torch.Tensor] = None, + ): + # self.ln_1(x) + a, present = self.attn( + x, pos_emb, layer_past=layer_past, use_cache=use_cache, attn_mask=attn_mask + ) + x = x + a + return x, present + + +class AttnOnlyForCausalLM(PreTrainedModel, GenerationMixin): + config_class = AttnOnlyConfig + + def __init__(self, config: AttnOnlyConfig): + super().__init__(config) + self.wte = nn.Embedding(config.vocab_size, config.hidden_size) + self.wpe = nn.Embedding(config.max_position_embeddings, config.hidden_size) + self.drop = nn.Dropout(config.embd_pdrop) + self.h = nn.ModuleList( + [AttnOnlyBlock(config) for _ in range(config.num_hidden_layers)] + ) + # self.ln_f = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_epsilon) + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + self.apply(self._init_weights) + + def _init_weights(self, module): + if isinstance(module, (nn.Linear, nn.Embedding)): + nn.init.normal_(module.weight, mean=0.0, std=0.02) + if isinstance(module, nn.Linear) and module.bias is not None: + nn.init.zeros_(module.bias) + if isinstance(module, nn.LayerNorm): + nn.init.ones_(module.weight) + nn.init.zeros_(module.bias) + + # HF helpers + def get_input_embeddings(self): + return self.wte + + def set_input_embeddings(self, new_emb): + self.wte = new_emb + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_lm_head): + self.lm_head = new_lm_head + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + attention_mask=None, + **kwargs, + ): + if past_key_values is not None: + input_ids = input_ids[:, -1:] + return { + "input_ids": input_ids, + "attention_mask": attention_mask, + "past_key_values": past_key_values, + "use_cache": True, + } + + def forward( + self, + input_ids: torch.LongTensor, + attention_mask: Optional[torch.Tensor] = None, + labels: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None, + use_cache: Optional[bool] = None, + ) -> CausalLMOutputWithPast: + B, T = input_ids.size() + pos = torch.arange(0, T, dtype=torch.long, device=input_ids.device).unsqueeze(0) + x = self.wte(input_ids) # + self.wpe(pos) + x = self.drop(x) + + pos_emb = self.wpe(pos) + presents = [] + for i, block in enumerate(self.h): + layer_past = None if past_key_values is None else past_key_values[i] + x, present = block( + x, + pos_emb, + layer_past=layer_past, + use_cache=self.config.use_cache if use_cache is None else use_cache, + ) + if present is not None: + presents.append(present) + + # x = self.ln_f(x) + logits = self.lm_head(x) + + loss = None + if labels is not None: + shift_logits = logits[:, :-1, :].contiguous() + shift_labels = labels[:, 1:].contiguous() + loss = F.cross_entropy( + shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1) + ) + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=presents if presents else None, + hidden_states=None, + attentions=None, + ) + + +AutoConfig.register("attn_only", AttnOnlyConfig) +AutoModelForCausalLM.register(AttnOnlyConfig, AttnOnlyForCausalLM) + + def check_logins(): """Check if user 
is logged into HF hub and wandb.""" print("Checking authentication...") @@ -64,31 +302,42 @@ def check_logins(): def create_transformer(): """Create a transformer model using GPTNeo architecture.""" - tokenizer = AutoTokenizer.from_pretrained("EleutherAI/TinyStories-restricted") - - # TODO use the EleutherAI 10k token tokenizer custom-built for TinyStories - # Padding and truncation = True - config = GPTNeoConfig( + tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neo-1.3B") + # Alternative: use the EleutherAI 10k token tokenizer custom-built for TinyStories + + # config = GPTNeoConfig( + # vocab_size=len(tokenizer), + # hidden_size=768, + # intermediate_size=2, + # num_layers=2, + # num_heads=2, + # max_position_embeddings=1024, + # attention_types=[[["global"], 2]], + # window_size=256, + # resid_pdrop=0.0, + # embd_pdrop=0.0, + # attn_pdrop=0.0, + # layer_norm_epsilon=1e-5, + # initializer_range=0.02, + # use_cache=True, + # # Token IDs from the tokenizer + # pad_token_id=tokenizer.pad_token_id, + # bos_token_id=tokenizer.bos_token_id, + # eos_token_id=tokenizer.eos_token_id, + # ) + # model = GPTNeoForCausalLM(config) + + cfg = AttnOnlyConfig( vocab_size=len(tokenizer), - hidden_size=256, - intermediate_size=1024, - num_layers=2, - num_heads=4, + hidden_size=768, + num_hidden_layers=2, + num_attention_heads=12, max_position_embeddings=1024, - attention_types=[[["global"], 2]], - window_size=256, - resid_pdrop=0.0, - embd_pdrop=0.0, - attn_pdrop=0.0, - layer_norm_epsilon=1e-5, - initializer_range=0.02, - use_cache=True, - # Token IDs from the tokenizer - pad_token_id=tokenizer.pad_token_id, - bos_token_id=tokenizer.bos_token_id, - eos_token_id=tokenizer.eos_token_id, ) - model = GPTNeoForCausalLM(config) + model = AttnOnlyForCausalLM(cfg) + + # AutoConfig.register("attn_only", AttnOnlyConfig) + # AutoModelForCausalLM.register(AttnOnlyConfig, AttnOnlyForCausalLM) # Set pad token if tokenizer.pad_token is None: @@ -102,8 +351,9 @@ def create_transformer(): def load_tinystories_data(tokenizer, max_length=512, N=10000): """Load and preprocess TinyStories dataset.""" - dataset = load_dataset("roneneldan/TinyStories", split="train") - dataset = dataset.select(range(min(N, len(dataset)))) + dataset = load_dataset("EleutherAI/SmolLM2-135M-10B", split="train") + # dataset = load_dataset("roneneldan/TinyStories", split="train") + # dataset = dataset.select(range(min(N, len(dataset)))) def tokenize_function(examples): # Tokenize the text @@ -140,62 +390,124 @@ def tokenize_function(examples): return train_dataset, eval_dataset -def create_induction_head_dataset(tokenizer, num_prompts=10): - """Create synthetic induction head dataset for building the query index.""" - print(f"Creating {num_prompts} synthetic induction head prompts...") - - # Create induction head patterns: [A][B] ... 
[A] -> [B] - # These are designed to test if the model learns to copy tokens - # from earlier in the sequence +def build_single_token_vocab(tokenizer, wordlist, max_words=500): + singles = [] + for w in wordlist: + toks = tokenizer(w, add_special_tokens=False)["input_ids"] + if len(toks) == 1: + singles.append(w) + if len(singles) >= max_words: + break + return singles + + +def create_induction_head_dataset(tokenizer, seed, num_prompts=100): + random.seed(seed) + + # crude word list, can be expanded + base_words = [ + "cat", + "dog", + "bird", + "wolf", + "bear", + "sun", + "moon", + "star", + "book", + "tree", + "car", + "road", + "sky", + "song", + "color", + "blue", + "green", + "red", + "gold", + "day", + "night", + "king", + "queen", + "child", + "story", + ] + vocab = build_single_token_vocab(tokenizer, base_words) + print(f"Vocab size: {len(vocab)}") - # Generate diverse induction head patterns patterns = [ - "The cat sat on the mat. The cat", - "Once upon a time, there was a princess. Once upon a time", - "In the forest, the wolf howled. In the forest", - "The sun shines bright today. The sun", - "My favorite color is blue. My favorite color", - "The dog ran in the park. The dog", - "She loves to read books. She loves", - "The moon is full tonight. The moon", - "He plays guitar every day. He plays", - "The bird sings a sweet song. The bird", + "The {A} saw the {B}. The {A}", + "Once the {A} met the {B}, later the {A}", + "In the story the {A} followed the {B}. The {A}", + "My favorite is the {A} with the {B}. The {A}", + "Everyone said the {A} remembers the {B}. The {A}", ] - # Take the requested number of prompts - selected_prompts = patterns[:num_prompts] - - # Tokenize the prompts - tokenized_prompts = [] - for prompt in selected_prompts: - # Split into input and target (everything after the last space) - parts = prompt.rsplit(" ", 1) - if len(parts) == 2: - input_text, target = parts - tokenized = tokenizer( - input_text, - return_tensors="pt", - padding=False, - truncation=True, - max_length=512, - ) - # Get the target token ID - target_tokens = tokenizer( - target, return_tensors="pt", add_special_tokens=False - ) - if target_tokens["input_ids"].numel() > 0: - target_token_id = target_tokens["input_ids"][0, 0].item() - tokenized_prompts.append( - { - "input_ids": tokenized["input_ids"][0], - "attention_mask": tokenized["attention_mask"][0], - "target_token": target_token_id, - "text": prompt, - } - ) + dataset = [] + for _ in range(num_prompts): + try: + A, B = random.sample(vocab, 2) + except ValueError: + print(f"Vocab size: {len(vocab)}") + breakpoint() + raise ValueError("Not enough unique tokens in vocab") + + template = random.choice(patterns) + text = template.format(A=A, B=B) + toks = tokenizer(text, return_tensors="pt", add_special_tokens=False) + input_ids = toks["input_ids"][0] + labels = torch.full_like(input_ids, -100) + + A_id = tokenizer(A, add_special_tokens=False)["input_ids"][0] + B_id = tokenizer(B, add_special_tokens=False)["input_ids"][0] + + # mask all A and B positions + matches_A = (input_ids == A_id).nonzero(as_tuple=True)[0] + matches_B = (input_ids == B_id).nonzero(as_tuple=True)[0] + labels[matches_A] = A_id + labels[matches_B] = B_id + + # explicitly make sure final label is B + labels[-1] = B_id + + dataset.append( + { + "input_ids": input_ids, + "attention_mask": toks["attention_mask"][0], + "labels": labels, + "A": A, + "B": B, + "text": text, + } + ) + return dataset + + +def test_induction_head_labels(): + tokenizer = 
AutoTokenizer.from_pretrained("gpt2") + dataset = create_induction_head_dataset(tokenizer, seed=0, num_prompts=3) + + for ex in dataset: + input_ids = ex["input_ids"] + labels = ex["labels"] + + A_id = tokenizer(ex["A"], add_special_tokens=False)["input_ids"][0] + B_id = tokenizer(ex["B"], add_special_tokens=False)["input_ids"][0] + + # check only {A, B, -100} appear + allowed = {A_id, B_id, -100} + assert set(labels.tolist()).issubset(allowed) - print(f"Created {len(tokenized_prompts)} induction head prompts") - return tokenized_prompts + # every A in input_ids must be in labels + for pos in (input_ids == A_id).nonzero(as_tuple=True)[0]: + assert labels[pos] == A_id + + # every B in input_ids must be in labels + for pos in (input_ids == B_id).nonzero(as_tuple=True)[0]: + assert labels[pos] == B_id + + # final token must be B + assert labels[-1].item() == B_id def setup_training( @@ -216,11 +528,11 @@ def setup_training( training_args = TrainingArguments( output_dir=output_dir, overwrite_output_dir=True, - num_train_epochs=2, - per_device_train_batch_size=16, - # per_device_eval_batch_size=16, + num_train_epochs=1, + per_device_train_batch_size=8, + # per_device_eval_batch_size=8, gradient_accumulation_steps=1, - warmup_steps=100, + warmup_steps=1000, learning_rate=5e-4, weight_decay=0.01, logging_dir=f"{output_dir}/logs", @@ -232,7 +544,7 @@ def setup_training( # metric_for_best_model="train_loss", # greater_is_better=False, report_to="wandb" if wandb else None, - run_name="2-layer-transformer-tinystories", + run_name="2-layer-transformer-smollm2-corpus", seed=42, fp16=False, dataloader_drop_last=True, @@ -244,7 +556,7 @@ def setup_training( dtype=np.float32, torch_dtype=torch.float32, accumulate_grads=False, - track_training_order=True, + track_order=True, ) # Create trainer @@ -263,7 +575,9 @@ def setup_training( return trainer -def build_induction_index(model, induction_prompts, output_dir, projection_dim): +def build_induction_index( + model, induction_prompts, output_dir, projection_dim, unit_norm +): """Build static query Bergson index using synthetic induction head data.""" print("Building Bergson index for induction head queries...") @@ -309,6 +623,7 @@ def build_induction_index(model, induction_prompts, output_dir, projection_dim): index_path=f"{output_dir}/induction_gradients", device="cuda" if torch.cuda.is_available() else "cpu", dtype=torch.float32, + unit_norm=unit_norm, ) # Collect mean gradient from attributor index @@ -319,6 +634,7 @@ def build_induction_index(model, induction_prompts, output_dir, projection_dim): device="cuda" if torch.cuda.is_available() else "cpu", dtype=torch.float32, modules=True, + unit_norm=unit_norm, ) mean_module_gradients = { @@ -343,10 +659,14 @@ def upload_to_hub(model, tokenizer, model_name="2layer-transformer-tinystories") raise e -def main(projection_dim=128): - tag = "mask_query" - k = 1000 - unit_norm = True +def main(args): + unit_norm = args.unit_norm + tag = args.tag + + projection_dim = args.projection_dim + seed = args.seed + train = args.train + plot = False print( "Starting 2-layer transformer pretraining with Bergson gradient collection..." 
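
Alongside the influence analysis, it can help to check directly whether the model has learned the induction behaviour that the synthetic prompts introduced earlier in this diff are probing. A rough, hypothetical helper (not part of the patch) that measures how often greedy next-token prediction after "... The {A}" returns the paired token {B}; it assumes a `model` and `tokenizer` built as above and a dataset produced by `create_induction_head_dataset`:

```python
import torch


@torch.no_grad()
def induction_top1(model, tokenizer, dataset) -> float:
    """Fraction of prompts whose final next-token prediction is the B token."""
    device = next(model.parameters()).device
    hits = 0
    for ex in dataset:
        ids = ex["input_ids"].unsqueeze(0).to(device)
        next_id = model(input_ids=ids).logits[0, -1].argmax().item()
        # Mirror the patch's tokenisation of B (first sub-token, no special tokens)
        b_id = tokenizer(ex["B"], add_special_tokens=False)["input_ids"][0]
        hits += int(next_id == b_id)
    return hits / max(len(dataset), 1)
```
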
@@ -361,13 +681,15 @@ def main(projection_dim=128): # Create model and tokenizer model, tokenizer = create_transformer() - model = model.to(device) # # Load TinyStories data train_dataset, eval_dataset = load_tinystories_data(tokenizer) # # Create induction head dataset - induction_prompts = create_induction_head_dataset(tokenizer) + test_induction_head_labels() + induction_prompts = create_induction_head_dataset( + tokenizer, seed=seed, num_prompts=10 + ) # # Set up training trainer = setup_training( @@ -380,125 +702,52 @@ def main(projection_dim=128): wandb=False, ) - # trainer.train() + if train: + trainer.train() + trainer.save_model(trainer.args.output_dir) + tokenizer.save_pretrained(trainer.args.output_dir) - trainer.save_model(trainer.args.output_dir) - tokenizer.save_pretrained(trainer.args.output_dir) + if not plot: + return # upload_to_hub(model, tokenizer) # Reload model and tokenizer - model = GPTNeoForCausalLM.from_pretrained(trainer.args.output_dir) + # model = AutoModelForCausalLM.from_pretrained(trainer.args.output_dir) + model = AttnOnlyForCausalLM.from_pretrained(trainer.args.output_dir) tokenizer = AutoTokenizer.from_pretrained(trainer.args.output_dir) model = model.to(device) # Build Bergson index for induction head queries mean_induction_gradients, module_induction_gradients = build_induction_index( - model, induction_prompts, trainer.args.output_dir, projection_dim + model, induction_prompts, trainer.args.output_dir, projection_dim, unit_norm ) model = model.cpu() - # Read Bergson index from training - epoch_attributors = [ - Attributor( - str( - Path(trainer.args.output_dir) / "gradients" / "train" / f"epoch_{epoch}" - ), - device=device, - unit_norm=unit_norm, - dtype=torch.float32, - # faiss_cfg=FaissConfig( - ) - for epoch in range(trainer.args.num_train_epochs) - ] # Load parquet table containing training order training_order = pq.read_table( str(Path(trainer.args.output_dir) / "gradients" / "training_order.parquet") ).to_pandas() - # Test the attributor with a sample query - print("Testing Bergson index with sample query...") - test_prompt = "The cat sat on the mat. The cat" - test_input = tokenizer(test_prompt, return_tensors="pt").to(device) - - # Mask out everything except the last token in the labels - test_input["labels"] = test_input["input_ids"].clone() - test_input["labels"][:, :-1] = -100 - - top_data = [] - - model = model.to(device) - for epoch_idx, epoch_attributor in enumerate(epoch_attributors): - # print(f"Top {k} most influential training examples for epoch {epoch_idx}:") - - with epoch_attributor.trace(model.base_model, k=k) as result: - outputs = model(**test_input) - outputs.loss.backward() - model.zero_grad() - - skips = 0 - for i, (score, idx) in enumerate( - zip(result.scores.squeeze(), result.indices.squeeze()) - ): - - if idx.item() != -1: - # Get the training order - training_metadata = training_order[ - (training_order["_idx"] == idx.item()) - & (training_order["epoch"] == epoch_idx) - ] - if training_metadata.empty: - skips += 1 - continue - for row in training_metadata.itertuples(index=False): - # print(f"{i+1}. 
Score: {score.item():.4f}, - # Global step: {row.global_step}, Index: {idx.item()}") - top_data.append( - { - "epoch": epoch_idx, - "global_step": row.global_step, - "index": idx.item(), - "score": score.item(), - } - ) - print(f"Skipped {skips} examples for epoch {epoch_idx}") - - top_data = pd.DataFrame(top_data) - - # Scatter plot of scores over time - plt.figure(figsize=(12, 8)) - - for epoch in sorted(top_data["epoch"].unique()): - epoch_data = top_data[top_data["epoch"] == epoch] - plt.scatter( - epoch_data["global_step"], - epoch_data["score"], - alpha=0.6, - s=20, - label=f"Epoch {epoch}", - ) - - plt.xlabel("Cumulative Training Steps") - plt.ylabel("Influence Score") - plt.title( - f"Most Influential Training Examples Per Epoch " - f"({'Normalized' if unit_norm else 'Unnormalized'})" - ) - plt.legend() - plt.grid(True, alpha=0.3) - fig_name = ( - f'training_dynamics{"_" + tag if tag else ""}{"_norm" if unit_norm else ""}.pdf' - ) - plt.savefig( - fig_name, - format="pdf", - bbox_inches="tight", - ) - plt.show() + # Plots + os.makedirs("figures", exist_ok=True) # Calculate the inner products with the training gradients data = [] - for epoch_idx, attributor in enumerate(epoch_attributors): + for epoch_idx in range(trainer.args.num_train_epochs): + # Read Bergson index from training + attributor = Attributor( + str( + Path(trainer.args.output_dir) + / "gradients" + / "train" + / f"epoch_{epoch_idx}" + ), + device=device, + unit_norm=unit_norm, + dtype=torch.float32, + # faiss_cfg=FaissConfig( + ) inner_products = attributor.grads.float() @ mean_induction_gradients.float() for i, score in enumerate(inner_products.squeeze()): training_metadata = training_order[ @@ -519,28 +768,22 @@ def main(projection_dim=128): data = pd.DataFrame(data) plt.figure(figsize=(12, 8)) - for epoch in sorted(data["epoch"].unique()): - epoch_data = data[data["epoch"] == epoch] - plt.scatter( - epoch_data["global_step"], - epoch_data["score"], - alpha=0.6, - s=20, - label=f"Epoch {epoch}", - ) - + plt.scatter( + data["global_step"], + data["score"], + alpha=0.6, + s=20, + # Use epoch for color + c=data["epoch"], + ) plt.xlabel("Cumulative Training Steps") plt.ylabel("Influence Score") plt.title( - f"Most Influential Training Examples Per Epoch " + f"Most Influential Training Examples " f"({'Normalized' if unit_norm else 'Unnormalized'})" ) - plt.legend() plt.grid(True, alpha=0.3) - fig_name = ( - f'training_dynamics_mean_induction{"_" + tag if tag else ""}' - f'{"_norm" if unit_norm else ""}.pdf' - ) + fig_name = f"figures/scores_{tag}" f'{"_norm" if unit_norm else ""}.pdf' plt.savefig( fig_name, format="pdf", @@ -548,22 +791,22 @@ def main(projection_dim=128): ) # Produce the same plot but split out by module (i.e. 
key in the grads mmap) - - # Second, produce the module-wise scores - del epoch_attributors - import os - - os.makedirs("module_figures", exist_ok=True) - + data = [] for epoch_idx in range(trainer.args.num_train_epochs): module_attributor = Attributor( index_path=f"{trainer.args.output_dir}/gradients/train/epoch_{epoch_idx}", device=device, dtype=torch.float32, modules=True, + unit_norm=unit_norm, ) for name, grads in module_attributor.grads.items(): - data = [] + if "attention" not in name and "attn" not in name: + print(f"Skipping {name}") + continue + else: + print(f"Processing {name}") + inner_products = grads.float() @ module_induction_gradients[name].float() for i, score in enumerate(inner_products.squeeze()): training_metadata = training_order[ @@ -581,87 +824,81 @@ def main(projection_dim=128): "score": score.item(), } ) - data = pd.DataFrame(data) - module_data = data - plt.figure(figsize=(12, 8)) + df = pd.DataFrame(data) + print(df) - plt.scatter( - module_data["global_step"], - module_data["score"], - alpha=0.6, - s=20, - label=f"Module {name}", - ) - plt.xlabel("Training Step") - plt.ylabel("Influence Score") - plt.title( - f"Most Influential Training Examples for {name} " - f"({'Normalized' if unit_norm else 'Unnormalized'})" - ) - plt.legend() - plt.grid(True, alpha=0.3) - fig_name = ( - f'training_dynamics_induction{"_" + tag if tag else ""}' - f'{"_norm" if unit_norm else ""}_{name}.pdf' - ) - plt.savefig( - os.path.join("module_figures", fig_name), - format="pdf", - bbox_inches="tight", - ) - plt.close() + for module in df["module"].unique(): + name = module + module_data = df[df["module"] == module] + print(module_data) - # Add a line plot with absolute sum of the gradients for each module - plt.figure(figsize=(12, 8)) + plt.figure(figsize=(12, 8)) - module_data = module_data.groupby("global_step", as_index=False).agg( - score=("score", lambda s: s.abs().sum()) - ) - plt.plot( - module_data["global_step"], module_data["score"], label=f"Module {name}" - ) - plt.xlabel("Training Step") - plt.ylabel("Absolute Sum of Gradients") - plt.title(f"Absolute Sum of Gradients for {name}") - plt.legend() - plt.grid(True, alpha=0.3) - fig_name = ( - f'training_dynamics_induction_abs_sum{"_" + tag if tag else ""}' - f'{"_norm" if unit_norm else ""}_{name}.pdf' - ) - plt.savefig( - os.path.join("module_figures", fig_name), - format="pdf", - bbox_inches="tight", - ) - plt.close() + plt.scatter( + module_data["global_step"], + module_data["score"], + # c=module_data["epoch"], + alpha=0.6, + s=20, + label=f"Module {name}", + ) + plt.xlabel("Training Step") + plt.ylabel("Influence Score") + plt.title( + f"Most Influential Training Examples for {name} " + f"({'Normalized' if unit_norm else 'Unnormalized'})" + ) + plt.legend() + plt.grid(True, alpha=0.3) + fig_name = ( + f"figures/module_scores_{tag}" f'{"_norm" if unit_norm else ""}_{name}.pdf' + ) + plt.savefig( + fig_name, + format="pdf", + bbox_inches="tight", + ) + plt.close() - # Add a line plot with the sum of the gradients for each module - # Sum points at each global step - module_data = module_data.groupby("global_step", as_index=False).agg( - score=("score", lambda s: s.sum()) - ) - plt.figure(figsize=(12, 8)) - plt.plot( - module_data["global_step"], module_data["score"], label=f"Module {name}" - ) - plt.xlabel("Training Step") - plt.ylabel("Sum of Gradients") - plt.title(f"Sum of Gradients for {name}") - plt.legend() - plt.grid(True, alpha=0.3) - fig_name = ( - f'training_dynamics_induction_sum{"_" + tag if tag else ""}' - 
f'{"_norm" if unit_norm else ""}_{name}.pdf' - ) - plt.savefig( - os.path.join("module_figures", fig_name), - format="pdf", - bbox_inches="tight", - ) - plt.close() + # Add a line plot with the sum of the gradients for each module + # Sum points at each global step + module_data = module_data.groupby(["global_step", "epoch"], as_index=False).agg( + score=("score", "sum") + ) + plt.figure(figsize=(12, 8)) + plt.plot( + module_data["global_step"], + module_data["score"], + label=f"Module {name}", # c=module_data["epoch"] + ) + plt.xlabel("Training Step") + plt.ylabel("Sum of Gradients") + plt.title(f"Sum of Gradients for {name}") + plt.legend() + plt.grid(True, alpha=0.3) + fig_name = ( + f'figures/sum{"_" + tag if tag else ""}' + f'{"_norm" if unit_norm else ""}_{name}.pdf' + ) + plt.savefig( + fig_name, + format="pdf", + bbox_inches="tight", + ) + plt.close() + + # Can we use SVCCA to align the gradients? if __name__ == "__main__": - main() + from argparse import ArgumentParser + + parser = ArgumentParser() + parser.add_argument("--projection_dim", type=int, default=128) + parser.add_argument("--seed", type=int, default=0) + parser.add_argument("--train", action="store_true") + parser.add_argument("--unit_norm", action="store_true") + parser.add_argument("--tag", type=str, default="") + args = parser.parse_args() + main(args) From 2785bd915dd59cba9ff5993b6ec042a6a9edae6f Mon Sep 17 00:00:00 2001 From: Lucia Quirke Date: Thu, 18 Sep 2025 03:45:13 +0000 Subject: [PATCH 05/15] Update from rebase --- bergson/attributor.py | 62 -------------------------------- bergson/huggingface.py | 30 +++------------- examples/find_induction_heads.py | 20 ++++++++--- 3 files changed, 20 insertions(+), 92 deletions(-) diff --git a/bergson/attributor.py b/bergson/attributor.py index 82f1fc5..eb2c1af 100644 --- a/bergson/attributor.py +++ b/bergson/attributor.py @@ -43,7 +43,6 @@ def __init__( dtype: torch.dtype = torch.float32, unit_norm: bool = False, faiss_cfg: FaissConfig | None = None, - modules: bool = False, ): self.device = device self.dtype = dtype @@ -121,67 +120,6 @@ def search( return torch.topk(scores, k) - def search_module( - self, queries: Tensor, k: int, module: str - ) -> tuple[Tensor, Tensor]: - """ - Search for the `k` nearest examples in the index based on the query or queries. - If fewer than `k` examples are found FAISS will return items with the index -1 - and the maximum negative distance. - - Args: - queries: The query tensor of shape [..., d]. - k: The number of nearest examples to return for each query. - nprobe: The number of FAISS vector clusters to search if using ANN. - - Returns: - A namedtuple containing the top `k` indices and inner products for each - query. Both have shape [..., k]. - """ - assert isinstance( - self.grads, dict - ), "Gradients must be a dictionary of tensors." - assert module in self.grads, f"Module {module} not found in gradients." 
- - k = min(k, self.grads[module].shape[0]) - - q = queries - - if self.unit_norm: - q /= q.norm(dim=1, keepdim=True) - - if not self.faiss_cfg: - return torch.topk(q.to(self.device) @ self.grads[module].mT, k) - - q = q.cpu().numpy() - - shard_distances = [] - shard_indices = [] - offset = 0 - - for index in self.faiss_shards: - index.nprobe = self.faiss_cfg.nprobe - distances, indices = index.search(q, k) - - indices += offset - offset += index.ntotal - - shard_distances.append(distances) - shard_indices.append(indices) - - distances = np.concatenate(shard_distances, axis=1) - indices = np.concatenate(shard_indices, axis=1) - - # Rerank results overfetched from multiple shards - if len(self.faiss_shards) > 1: - topk_indices = np.argsort(distances, axis=1)[:, :k] - indices = indices[np.arange(indices.shape[0])[:, None], topk_indices] - distances = distances[np.arange(distances.shape[0])[:, None], topk_indices] - - return torch.from_numpy(distances.squeeze()), torch.from_numpy( - indices.squeeze() - ) - @contextmanager def trace( self, module: nn.Module, k: int, *, precondition: bool = False diff --git a/bergson/huggingface.py b/bergson/huggingface.py index 542b2e4..119a5a2 100644 --- a/bergson/huggingface.py +++ b/bergson/huggingface.py @@ -31,6 +31,7 @@ def __init__( path: str, head_cfgs: dict[str, HeadConfig], projection_dim: int = 16, + dtype: DTypeLike = np.float16, accumulate_grads: bool = False, use_optimizer_state: bool = True, track_order: bool = False, @@ -38,7 +39,10 @@ def __init__( """ Args: path: The path to save the gradients + head_cfgs: Information used to split matrix-valued parameters into + per-head matrices before down projection. projection_dim: The dimension to project the gradients onto + dtype: The dtype of the on-disk gradient store accumulate_grads: Whether to take the sum of the gradients of the same example across epochs. If `False`, the gradients for each epoch are stored separately. @@ -46,8 +50,6 @@ def __init__( normalize the gradients. If `False`, no normalization is applied. track_order: Whether to record the shuffled order of training data. - head_cfgs: Information used to split matrix-valued parameters into - per-head matrices before down projection. 
""" super().__init__() @@ -73,9 +75,6 @@ def __init__( # TODO: Handle this more elegantly self.torch_dtype = torch.float32 if self.dtype == np.float32 else torch.float16 - # TODO: Handle this more elegantly - self.torch_dtype = torch.float32 if self.dtype == np.float32 else torch.float16 - def write_grads(self, grad_buffer: np.memmap): # Ensure the nonblocking copies are all finished torch.cuda.synchronize() @@ -84,12 +83,6 @@ def write_grads(self, grad_buffer: np.memmap): self.mod_grads.clear() - def on_step_begin(self, args, state, control, **kwargs): - """Track the current step and epoch for training order recording.""" - if self.order: - self._current_step = state.global_step - self._current_epoch = int(state.epoch or 0) - def on_train_begin( self, args: TrainingArguments, @@ -266,21 +259,6 @@ def on_step_end( if not self.use_optimizer_state: return - # Record training order if enabled - if self.order: - assert ( - self.batch_indices is not None - ), "Batch indices are not available for training order tracking" - - self.order.extend( - { - "_idx": int(idx), - "global_step": getattr(self, "_current_step", 0), - "epoch": getattr(self, "_current_epoch", 0), - } - for idx in self.batch_indices.tolist() - ) - # The optimizer doesn't actually know the names of the parameters model = getattr(model, "base_model", model) param_to_name = { diff --git a/examples/find_induction_heads.py b/examples/find_induction_heads.py index 7850b70..e58fe26 100644 --- a/examples/find_induction_heads.py +++ b/examples/find_induction_heads.py @@ -39,9 +39,10 @@ from transformers.modeling_outputs import CausalLMOutputWithPast import wandb -from bergson.attributor import Attributor # from bergson.data import load_gradient_dataset +from bergson import HeadConfig +from bergson.attributor import Attributor from bergson.collection import collect_gradients from bergson.gradients import GradientProcessor from bergson.huggingface import ( @@ -349,9 +350,11 @@ def create_transformer(): return model, tokenizer -def load_tinystories_data(tokenizer, max_length=512, N=10000): +def load_tinystories_data(tokenizer, max_length=512, N: int | None = 10_000): """Load and preprocess TinyStories dataset.""" dataset = load_dataset("EleutherAI/SmolLM2-135M-10B", split="train") + if N is not None: + dataset = dataset.select(range(min(N, len(dataset)))) # dataset = load_dataset("roneneldan/TinyStories", split="train") # dataset = dataset.select(range(min(N, len(dataset)))) @@ -552,9 +555,14 @@ def setup_training( bergson_callback = GradientCollectorCallback( path=f"{output_dir}/gradients", + head_cfgs={ + "h.0.attn.c_attn": HeadConfig(12, 192, 2), + "h.0.attn.c_proj": HeadConfig(12, 64, 2), + "h.1.attn.c_attn": HeadConfig(12, 192, 2), + "h.1.attn.c_proj": HeadConfig(12, 64, 2), + }, projection_dim=projection_dim, dtype=np.float32, - torch_dtype=torch.float32, accumulate_grads=False, track_order=True, ) @@ -683,7 +691,10 @@ def main(args): model, tokenizer = create_transformer() # # Load TinyStories data - train_dataset, eval_dataset = load_tinystories_data(tokenizer) + if args.small: + train_dataset, eval_dataset = load_tinystories_data(tokenizer, N=1000) + else: + train_dataset, eval_dataset = load_tinystories_data(tokenizer) # # Create induction head dataset test_induction_head_labels() @@ -899,6 +910,7 @@ def main(args): parser.add_argument("--seed", type=int, default=0) parser.add_argument("--train", action="store_true") parser.add_argument("--unit_norm", action="store_true") + parser.add_argument("--small", action="store_true") 
parser.add_argument("--tag", type=str, default="") args = parser.parse_args() main(args) From 76c9c4ec0ea97a7ff20c99335c1fc7d058d13db9 Mon Sep 17 00:00:00 2001 From: Lucia Quirke Date: Fri, 19 Sep 2025 06:24:29 +0000 Subject: [PATCH 06/15] Update script --- examples/find_induction_heads.py | 286 ++++++++++++++++++++----------- 1 file changed, 185 insertions(+), 101 deletions(-) diff --git a/examples/find_induction_heads.py b/examples/find_induction_heads.py index e58fe26..6ddb82d 100644 --- a/examples/find_induction_heads.py +++ b/examples/find_induction_heads.py @@ -1,16 +1,15 @@ #!/usr/bin/env python3 """ Pretrain a two-layer transformer and try to identify the formation of induction heads -from the influence functions wrt simple induction head completions gradients. +from the influence functions with respect to simple induction head completion gradients. This script: -1. Creates a 2-layer transformer using HF transformers architecture -2. Trains on TinyStories dataset using HF Trainer with Bergson callback +1. Creates a 2-layer attention-only transformer +2. Trains using the HF Trainer with the Bergson callback to collect gradients 3. Builds a static query Bergson index using synthetic induction head data -4. Uploads the trained model to HF hub +4. Plots the influence of the training examples on the induction heads """ -# attn_only.py import math import os import random @@ -20,11 +19,10 @@ import matplotlib.pyplot as plt import numpy as np import pandas as pd -import pyarrow.parquet as pq import torch import torch.nn as nn import torch.nn.functional as F -from datasets import Dataset, load_dataset +from datasets import Dataset, load_dataset, load_from_disk from transformers import ( AutoConfig, AutoModelForCausalLM, @@ -39,16 +37,12 @@ from transformers.modeling_outputs import CausalLMOutputWithPast import wandb - -# from bergson.data import load_gradient_dataset -from bergson import HeadConfig -from bergson.attributor import Attributor -from bergson.collection import collect_gradients -from bergson.gradients import GradientProcessor +from bergson import Attributor, GradientProcessor, HeadConfig, collect_gradients from bergson.huggingface import ( GradientCollectorCallback, prepare_for_gradient_collection, ) +from bergson.utils import assert_type class AttnOnlyConfig(PretrainedConfig): @@ -302,31 +296,10 @@ def check_logins(): def create_transformer(): - """Create a transformer model using GPTNeo architecture.""" + """Create an attention-only transformer.""" tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neo-1.3B") - # Alternative: use the EleutherAI 10k token tokenizer custom-built for TinyStories - - # config = GPTNeoConfig( - # vocab_size=len(tokenizer), - # hidden_size=768, - # intermediate_size=2, - # num_layers=2, - # num_heads=2, - # max_position_embeddings=1024, - # attention_types=[[["global"], 2]], - # window_size=256, - # resid_pdrop=0.0, - # embd_pdrop=0.0, - # attn_pdrop=0.0, - # layer_norm_epsilon=1e-5, - # initializer_range=0.02, - # use_cache=True, - # # Token IDs from the tokenizer - # pad_token_id=tokenizer.pad_token_id, - # bos_token_id=tokenizer.bos_token_id, - # eos_token_id=tokenizer.eos_token_id, - # ) - # model = GPTNeoForCausalLM(config) + # Alternative: use the EleutherAI 10k token tokenizer custom-built for TinyStories, + # but it's harder to find good single-token words cfg = AttnOnlyConfig( vocab_size=len(tokenizer), @@ -350,13 +323,14 @@ def create_transformer(): return model, tokenizer -def load_tinystories_data(tokenizer, max_length=512, N: int 
| None = 10_000): - """Load and preprocess TinyStories dataset.""" - dataset = load_dataset("EleutherAI/SmolLM2-135M-10B", split="train") +def load_data( + tokenizer, N: int | None, name="EleutherAI/SmolLM2-135M-10B", max_length=512 +): + """Load and preprocess dataset.""" + dataset = load_dataset(name, split="train") + dataset = assert_type(Dataset, dataset) if N is not None: dataset = dataset.select(range(min(N, len(dataset)))) - # dataset = load_dataset("roneneldan/TinyStories", split="train") - # dataset = dataset.select(range(min(N, len(dataset)))) def tokenize_function(examples): # Tokenize the text @@ -547,7 +521,7 @@ def setup_training( # metric_for_best_model="train_loss", # greater_is_better=False, report_to="wandb" if wandb else None, - run_name="2-layer-transformer-smollm2-corpus", + run_name="2-layer-transformer-SmolLM2-corpus", seed=42, fp16=False, dataloader_drop_last=True, @@ -622,7 +596,13 @@ def build_induction_index( data=induction_dataset, processor=processor, path=f"{output_dir}/induction_gradients", - skip_preconditioners=False, + skip_preconditioners=True, + head_cfgs={ + "h.0.attn.c_attn": HeadConfig(12, 192, 2), + "h.0.attn.c_proj": HeadConfig(12, 64, 2), + "h.1.attn.c_attn": HeadConfig(12, 192, 2), + "h.1.attn.c_proj": HeadConfig(12, 64, 2), + }, ) # Build the attributor for querying @@ -635,21 +615,15 @@ def build_induction_index( ) # Collect mean gradient from attributor index - mean_gradient = attributor.grads.mean(dim=0) - attributor = Attributor( - index_path=f"{output_dir}/induction_gradients", - device="cuda" if torch.cuda.is_available() else "cpu", - dtype=torch.float32, - modules=True, - unit_norm=unit_norm, + mean_gradient = torch.cat([grad for grad in attributor.grads.values()], dim=1).mean( + dim=0 ) - mean_module_gradients = { name: attributor.grads[name].mean(dim=0) for name in attributor.grads.keys() } - print("In-context index built successfully! Returning mean gradient...") + print("In-context index built successfully! Returning mean gradients...") return mean_gradient, mean_module_gradients @@ -668,13 +642,15 @@ def upload_to_hub(model, tokenizer, model_name="2layer-transformer-tinystories") def main(args): + dataset_name = "EleutherAI/SmolLM2-135M-10B" + unit_norm = args.unit_norm tag = args.tag projection_dim = args.projection_dim seed = args.seed train = args.train - plot = False + plot = args.plot print( "Starting 2-layer transformer pretraining with Bergson gradient collection..." 
@@ -690,11 +666,11 @@ def main(args): # Create model and tokenizer model, tokenizer = create_transformer() - # # Load TinyStories data + # Load data if args.small: - train_dataset, eval_dataset = load_tinystories_data(tokenizer, N=1000) + train_dataset, eval_dataset = load_data(tokenizer, name=dataset_name, N=1000) else: - train_dataset, eval_dataset = load_tinystories_data(tokenizer) + train_dataset, eval_dataset = load_data(tokenizer, name=dataset_name) # # Create induction head dataset test_induction_head_labels() @@ -736,8 +712,8 @@ def main(args): model = model.cpu() # Load parquet table containing training order - training_order = pq.read_table( - str(Path(trainer.args.output_dir) / "gradients" / "training_order.parquet") + training_order = load_from_disk( + str(Path(trainer.args.output_dir) / "gradients" / "order.hf") ).to_pandas() # Plots @@ -747,7 +723,7 @@ def main(args): data = [] for epoch_idx in range(trainer.args.num_train_epochs): # Read Bergson index from training - attributor = Attributor( + grads = Attributor( str( Path(trainer.args.output_dir) / "gradients" @@ -757,9 +733,22 @@ def main(args): device=device, unit_norm=unit_norm, dtype=torch.float32, - # faiss_cfg=FaissConfig( - ) - inner_products = attributor.grads.float() @ mean_induction_gradients.float() + ).grads + + inner_products = None + offset = 0 + for grad in grads.values(): + d = grad.shape[1] + mean_block = mean_induction_gradients[offset : offset + d] + offset += d + for i in range(0, grad.shape[0], 1024): + batch = grad[i : i + 1024].to(mean_block.device, dtype=torch.float32) + contrib = batch @ mean_block.float() + if inner_products is None: + inner_products = torch.zeros(grad.shape[0], device=contrib.device) + inner_products[i : i + 1024] += contrib + inner_products = inner_products.cpu() + for i, score in enumerate(inner_products.squeeze()): training_metadata = training_order[ (training_order["_idx"] == i) & (training_order["epoch"] == epoch_idx) @@ -802,47 +791,66 @@ def main(args): ) # Produce the same plot but split out by module (i.e. 
key in the grads mmap) - data = [] - for epoch_idx in range(trainer.args.num_train_epochs): - module_attributor = Attributor( - index_path=f"{trainer.args.output_dir}/gradients/train/epoch_{epoch_idx}", - device=device, - dtype=torch.float32, - modules=True, - unit_norm=unit_norm, - ) - for name, grads in module_attributor.grads.items(): - if "attention" not in name and "attn" not in name: - print(f"Skipping {name}") - continue - else: - print(f"Processing {name}") - - inner_products = grads.float() @ module_induction_gradients[name].float() - for i, score in enumerate(inner_products.squeeze()): - training_metadata = training_order[ - (training_order["_idx"] == i) - & (training_order["epoch"] == epoch_idx) - ] - if len(training_metadata) != 1: + df_path = f"figures/module_scores_{tag}{'_norm' if unit_norm else ''}.csv" + if os.path.exists(df_path): + df = pd.read_csv(df_path) + print(f"Loaded module scores from {df_path}") + else: + data = [] + for epoch_idx in range(trainer.args.num_train_epochs): + grads = Attributor( + index_path=f"{trainer.args.output_dir}/gradients/train/epoch_{epoch_idx}", + device="cpu", + dtype=torch.float32, + unit_norm=unit_norm, + ).grads + + # module_inner_products = {} + offset = 0 + for name, grad in grads.items(): + if "attention" not in name and "attn" not in name: + print(f"Skipping {name}") continue - for row in training_metadata.itertuples(index=False): - data.append( - { - "global_step": row.global_step, - "epoch": epoch_idx, - "module": name, - "score": score.item(), - } + else: + print(f"Processing {name}") + + d = grad.shape[1] + mean_block = mean_induction_gradients[offset : offset + d] + offset += d + scores = [] + for i in range(0, grad.shape[0], 1024): + batch = grad[i : i + 1024].to( + mean_block.device, dtype=torch.float32 ) - - df = pd.DataFrame(data) - print(df) - - for module in df["module"].unique(): + scores.append(batch @ mean_block.float()) + mod_inner_products = torch.cat(scores, dim=0).cpu() + + for i, score in enumerate(mod_inner_products.squeeze()): + training_metadata = training_order[ + (training_order["_idx"] == i) + & (training_order["epoch"] == epoch_idx) + ] + if len(training_metadata) != 1: + continue + for row in training_metadata.itertuples(index=False): + data.append( + { + "global_step": row.global_step, + "epoch": epoch_idx, + "module": name, + "score": score.item(), + } + ) + + df = pd.DataFrame(data) + df.to_csv(df_path, index=False) + + attn_modules = [name for name in df["module"].unique() if "attn" in name] + non_attn_modules = [name for name in df["module"].unique() if "attn" not in name] + + for module in non_attn_modules: name = module module_data = df[df["module"] == module] - print(module_data) plt.figure(figsize=(12, 8)) @@ -899,7 +907,82 @@ def main(args): ) plt.close() - # Can we use SVCCA to align the gradients? 
+ # Plot all attention heads in one file + n = len(attn_modules) + cols = math.ceil(math.sqrt(n)) + rows = math.ceil(n / cols) + + fig, axes = plt.subplots( + rows, cols, figsize=(5 * cols, 4 * rows), squeeze=False, sharey=True + ) + + for ax, module in zip(axes.flatten(), attn_modules): + module_data = df[df["module"] == module] + ax.scatter( + module_data["global_step"], + module_data["score"], + alpha=0.6, + s=20, + ) + ax.set_title(module) + ax.set_xlabel("Step") + ax.set_ylabel("Score") + ax.grid(True, alpha=0.3) + + plt.tight_layout() + fig.savefig(f"figures/all_heads_scores_{tag}{'_norm' if unit_norm else ''}.pdf") + plt.close(fig) + + # Single figure with each attention modules' sum-of-scores over steps + fig, ax = plt.subplots(figsize=(6, 4)) + + for module in attn_modules: + module_data = df[df["module"] == module] + summed = module_data.groupby("global_step")["score"].sum().reset_index() + ax.plot(summed["global_step"], summed["score"], label=module, alpha=0.7) + + ax.set_xlabel("Step") + ax.set_ylabel("Sum of Scores") + ax.grid(True, alpha=0.3) + ax.legend(fontsize=8) + ax.legend().remove() + + plt.tight_layout() + fig.savefig(f"figures/all_heads_sum_scores_{tag}{'_norm' if unit_norm else ''}.pdf") + plt.close(fig) + + # Single figure with each attention modules' sum-of-scores summed over steps + sums = [df[df["module"] == m]["score"].sum() for m in attn_modules] + + fig, ax = plt.subplots(figsize=(6, 4)) + ax.bar(range(len(attn_modules)), sums) + ax.set_xticks(range(len(attn_modules))) + ax.set_xticklabels(attn_modules, rotation=90) + ax.set_ylabel("Sum of Scores") + ax.set_xlabel("Module") + ax.grid(True, axis="y", alpha=0.3) + + plt.tight_layout() + fig.savefig( + f"figures/all_heads_sum_scores_bar_{tag}{'_norm' if unit_norm else ''}.pdf" + ) + plt.close(fig) + + # Step 1: pick checkpoint steps + # Step 2: compute a bunch of gradients at this step using the static index build + # and save it + # Step 1.5: fix the horrible static index build bug + + # Can we use optimal transport to align the gradients? + # Should we transport the activations then transport the gradients in the same way? + # Or should we transport the gradients directly? + + # To compute the optimal transport maps we just need a huge dataset of training + # gradients at different steps. + + # Once we have optimal transport maps we can optimal transport the gradients to the + # trained model distribution. Then we can compute the influence of the training + # examples on the induction heads. 
if __name__ == "__main__": @@ -912,5 +995,6 @@ def main(args): parser.add_argument("--unit_norm", action="store_true") parser.add_argument("--small", action="store_true") parser.add_argument("--tag", type=str, default="") + parser.add_argument("--plot", action="store_true") args = parser.parse_args() main(args) From 86051fe4d15100f893e7e9d1abb00dcb699b7f7f Mon Sep 17 00:00:00 2001 From: Lucia Quirke Date: Tue, 23 Sep 2025 00:49:52 +0000 Subject: [PATCH 07/15] configurable attn only transformer --- examples/find_induction_heads.py | 41 +++++++++++++++++++++++--------- 1 file changed, 30 insertions(+), 11 deletions(-) diff --git a/examples/find_induction_heads.py b/examples/find_induction_heads.py index 6ddb82d..bcba826 100644 --- a/examples/find_induction_heads.py +++ b/examples/find_induction_heads.py @@ -60,6 +60,8 @@ def __init__( embd_pdrop=0.0, attn_pdrop=0.0, use_cache=True, + layer_norm=False, + special_pos_embed=True, **kwargs, ): super().__init__(**kwargs) @@ -73,6 +75,7 @@ def __init__( self.embd_pdrop = embd_pdrop self.attn_pdrop = attn_pdrop self.use_cache = use_cache + self.layer_norm = layer_norm class CausalSelfAttention(nn.Module): @@ -119,8 +122,9 @@ def forward( q, k, v = qkv.split(C, dim=2) # add position to q and k only - q = q + pos_emb - k = k + pos_emb + if self.config.special_pos_embed: + q = q + pos_emb + k = k + pos_emb q = self._split_heads(q) k = self._split_heads(k) @@ -149,7 +153,8 @@ def forward( class AttnOnlyBlock(nn.Module): def __init__(self, config: AttnOnlyConfig): super().__init__() - # self.ln_1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_epsilon) + if config.layer_norm: + self.ln_1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_epsilon) self.attn = CausalSelfAttention(config) def forward( @@ -160,7 +165,9 @@ def forward( use_cache: bool = True, attn_mask: Optional[torch.Tensor] = None, ): - # self.ln_1(x) + if self.ln_1 is not None: + x = self.ln_1(x) + a, present = self.attn( x, pos_emb, layer_past=layer_past, use_cache=use_cache, attn_mask=attn_mask ) @@ -179,7 +186,8 @@ def __init__(self, config: AttnOnlyConfig): self.h = nn.ModuleList( [AttnOnlyBlock(config) for _ in range(config.num_hidden_layers)] ) - # self.ln_f = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_epsilon) + if config.layer_norm: + self.ln_f = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_epsilon) self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) self.apply(self._init_weights) @@ -232,10 +240,14 @@ def forward( ) -> CausalLMOutputWithPast: B, T = input_ids.size() pos = torch.arange(0, T, dtype=torch.long, device=input_ids.device).unsqueeze(0) - x = self.wte(input_ids) # + self.wpe(pos) - x = self.drop(x) + x = self.wte(input_ids) pos_emb = self.wpe(pos) + if not self.config.special_pos_embed: + x = x + pos_emb + + x = self.drop(x) + presents = [] for i, block in enumerate(self.h): layer_past = None if past_key_values is None else past_key_values[i] @@ -248,7 +260,9 @@ def forward( if present is not None: presents.append(present) - # x = self.ln_f(x) + if self.ln_f is not None: + x = self.ln_f(x) + logits = self.lm_head(x) loss = None @@ -295,7 +309,7 @@ def check_logins(): raise e -def create_transformer(): +def create_transformer(special_pos_embed): """Create an attention-only transformer.""" tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neo-1.3B") # Alternative: use the EleutherAI 10k token tokenizer custom-built for TinyStories, @@ -307,6 +321,8 @@ def create_transformer(): num_hidden_layers=2, 
num_attention_heads=12, max_position_embeddings=1024, + layer_norm=False, + special_pos_embed=special_pos_embed, ) model = AttnOnlyForCausalLM(cfg) @@ -324,7 +340,7 @@ def create_transformer(): def load_data( - tokenizer, N: int | None, name="EleutherAI/SmolLM2-135M-10B", max_length=512 + tokenizer, N: int | None = None, name="EleutherAI/SmolLM2-135M-10B", max_length=512 ): """Load and preprocess dataset.""" dataset = load_dataset(name, split="train") @@ -664,7 +680,9 @@ def main(args): print(f"Using device: {device}") # Create model and tokenizer - model, tokenizer = create_transformer() + model, tokenizer = create_transformer( + special_pos_embed=not args.no_special_pos_embed + ) # Load data if args.small: @@ -996,5 +1014,6 @@ def main(args): parser.add_argument("--small", action="store_true") parser.add_argument("--tag", type=str, default="") parser.add_argument("--plot", action="store_true") + parser.add_argument("--no_special_pos_embed", action="store_false") args = parser.parse_args() main(args) From f78ee12bd4f0ff8e72335382273dd7d94ccfdaea Mon Sep 17 00:00:00 2001 From: Lucia Quirke Date: Tue, 23 Sep 2025 04:46:49 +0000 Subject: [PATCH 08/15] clean up --- examples/find_induction_heads.py | 74 ++++++++++++++++++-------------- 1 file changed, 41 insertions(+), 33 deletions(-) diff --git a/examples/find_induction_heads.py b/examples/find_induction_heads.py index bcba826..acd48c1 100644 --- a/examples/find_induction_heads.py +++ b/examples/find_induction_heads.py @@ -76,6 +76,7 @@ def __init__( self.attn_pdrop = attn_pdrop self.use_cache = use_cache self.layer_norm = layer_norm + self.special_pos_embed = special_pos_embed class CausalSelfAttention(nn.Module): @@ -88,6 +89,7 @@ def __init__(self, config: AttnOnlyConfig): self.c_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=True) self.attn_drop = nn.Dropout(config.attn_pdrop) self.resid_drop = nn.Dropout(config.resid_pdrop) + self.special_pos_embed = config.special_pos_embed self.register_buffer( "mask", torch.tril( @@ -122,7 +124,7 @@ def forward( q, k, v = qkv.split(C, dim=2) # add position to q and k only - if self.config.special_pos_embed: + if self.special_pos_embed: q = q + pos_emb k = k + pos_emb @@ -155,6 +157,8 @@ def __init__(self, config: AttnOnlyConfig): super().__init__() if config.layer_norm: self.ln_1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_epsilon) + else: + self.ln_1 = None self.attn = CausalSelfAttention(config) def forward( @@ -188,6 +192,8 @@ def __init__(self, config: AttnOnlyConfig): ) if config.layer_norm: self.ln_f = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_epsilon) + else: + self.ln_f = None self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) self.apply(self._init_weights) @@ -511,6 +517,7 @@ def setup_training( output_dir: str, projection_dim: int, wandb: bool = True, + num_train_epochs: int = 1, ): """Set up the training configuration with Bergson callback.""" data_collator = DataCollatorForLanguageModeling( @@ -521,7 +528,7 @@ def setup_training( training_args = TrainingArguments( output_dir=output_dir, overwrite_output_dir=True, - num_train_epochs=1, + num_train_epochs=num_train_epochs, per_device_train_batch_size=8, # per_device_eval_batch_size=8, gradient_accumulation_steps=1, @@ -659,6 +666,7 @@ def upload_to_hub(model, tokenizer, model_name="2layer-transformer-tinystories") def main(args): dataset_name = "EleutherAI/SmolLM2-135M-10B" + num_train_epochs = 1 unit_norm = args.unit_norm tag = args.tag @@ -668,6 +676,8 @@ def main(args): 
train = args.train plot = args.plot + output_dir = f"examples/runs/transformer_2_layer{'_' + tag if tag else ''}" + print( "Starting 2-layer transformer pretraining with Bergson gradient collection..." ) @@ -684,33 +694,36 @@ def main(args): special_pos_embed=not args.no_special_pos_embed ) - # Load data - if args.small: - train_dataset, eval_dataset = load_data(tokenizer, name=dataset_name, N=1000) - else: - train_dataset, eval_dataset = load_data(tokenizer, name=dataset_name) - # # Create induction head dataset test_induction_head_labels() induction_prompts = create_induction_head_dataset( tokenizer, seed=seed, num_prompts=10 ) - # # Set up training - trainer = setup_training( - model, - tokenizer, - train_dataset, - eval_dataset, - output_dir=f"examples/runs/transformer_2_layer{'_' + tag if tag else ''}", - projection_dim=projection_dim, - wandb=False, - ) - if train: + # Set up training + # Load data + if args.small: + train_dataset, eval_dataset = load_data( + tokenizer, name=dataset_name, N=1000 + ) + else: + train_dataset, eval_dataset = load_data(tokenizer, name=dataset_name) + + trainer = setup_training( + model, + tokenizer, + train_dataset, + eval_dataset, + output_dir=output_dir, + projection_dim=projection_dim, + wandb=False, + num_train_epochs=num_train_epochs, + ) + trainer.train() - trainer.save_model(trainer.args.output_dir) - tokenizer.save_pretrained(trainer.args.output_dir) + trainer.save_model(output_dir) + tokenizer.save_pretrained(output_dir) if not plot: return @@ -718,20 +731,20 @@ def main(args): # upload_to_hub(model, tokenizer) # Reload model and tokenizer - # model = AutoModelForCausalLM.from_pretrained(trainer.args.output_dir) - model = AttnOnlyForCausalLM.from_pretrained(trainer.args.output_dir) - tokenizer = AutoTokenizer.from_pretrained(trainer.args.output_dir) + # model = AutoModelForCausalLM.from_pretrained(output_dir) + model = AttnOnlyForCausalLM.from_pretrained(output_dir) + tokenizer = AutoTokenizer.from_pretrained(output_dir) model = model.to(device) # Build Bergson index for induction head queries mean_induction_gradients, module_induction_gradients = build_induction_index( - model, induction_prompts, trainer.args.output_dir, projection_dim, unit_norm + model, induction_prompts, output_dir, projection_dim, unit_norm ) model = model.cpu() # Load parquet table containing training order training_order = load_from_disk( - str(Path(trainer.args.output_dir) / "gradients" / "order.hf") + str(Path(output_dir) / "gradients" / "order.hf") ).to_pandas() # Plots @@ -739,15 +752,10 @@ def main(args): # Calculate the inner products with the training gradients data = [] - for epoch_idx in range(trainer.args.num_train_epochs): + for epoch_idx in range(num_train_epochs): # Read Bergson index from training grads = Attributor( - str( - Path(trainer.args.output_dir) - / "gradients" - / "train" - / f"epoch_{epoch_idx}" - ), + str(Path(output_dir) / "gradients" / "train" / f"epoch_{epoch_idx}"), device=device, unit_norm=unit_norm, dtype=torch.float32, From 0bc453ce3bb8fd6c6223d46521304c7e936cea49 Mon Sep 17 00:00:00 2001 From: Lucia Quirke Date: Thu, 25 Sep 2025 03:34:28 +0000 Subject: [PATCH 09/15] tweaks and fixes --- bergson/data.py | 2 +- bergson/faiss_index.py | 5 ++--- bergson/huggingface.py | 25 ++++++++++++++++++------- 3 files changed, 21 insertions(+), 11 deletions(-) diff --git a/bergson/data.py b/bergson/data.py index 29d5cb3..598ea27 100644 --- a/bergson/data.py +++ b/bergson/data.py @@ -104,7 +104,7 @@ class IndexConfig: streaming: bool = False """Whether 
to use streaming mode for the dataset.""" - stream_shard_size: int = 100_000 + stream_shard_size: int = 400_000 """Shard size for streaming the dataset into Dataset objects.""" revision: str | None = None diff --git a/bergson/faiss_index.py b/bergson/faiss_index.py index 970fa62..fab29d7 100644 --- a/bergson/faiss_index.py +++ b/bergson/faiss_index.py @@ -91,7 +91,7 @@ def load_shard(shard_dir: str) -> np.memmap: yield load_shard(root_dir) else: for shard_path in sorted(root_path.iterdir()): - if shard_path.is_dir(): + if shard_path.is_dir() and "shard" in shard_path.name: yield load_shard(str(shard_path)) @@ -160,8 +160,7 @@ def __init__(self, path: str, faiss_cfg: FaissConfig, device: str, unit_norm: bo index_idx = 0 for grads in tqdm(dl, desc="Loading gradients"): - if grads.dtype.names is not None: - grads = structured_to_unstructured(grads) + grads = structured_to_unstructured(grads) if unit_norm: grads = normalize_grads(grads, device, faiss_cfg.batch_size) diff --git a/bergson/huggingface.py b/bergson/huggingface.py index 119a5a2..9bc853a 100644 --- a/bergson/huggingface.py +++ b/bergson/huggingface.py @@ -95,7 +95,7 @@ def on_train_begin( if not hasattr(args, "__gradient_collection_enabled__"): raise RuntimeError( "Gradient collection is not enabled. Please enable it by " - "calling bergson.prepare_gradient_collection on the trainer." + "calling bergson.prepare_for_gradient_collection on the trainer." ) if isinstance(model, PeftModel): @@ -133,7 +133,7 @@ def on_epoch_begin( state: TrainerState, control: TrainerControl, *, - eval_dataloader: DataLoader | dict[str, DataLoader], + eval_dataloader: DataLoader | dict[str, DataLoader] | None, train_dataloader: DataLoader, **kwargs, ): @@ -158,9 +158,16 @@ def on_epoch_begin( # Set up the gradient buffers for the evaluation datasets if eval_dataloader is None: - print("No evaluation dataloader found") - return - elif isinstance(eval_dataloader, dict): + # HF Trainer doesn't expose the evaluation dataloaders + if hasattr(args, "eval_dataset"): + eval_dataloader = DataLoader( + args.eval_dataset, batch_size=1, shuffle=False + ) + else: + print("Warning: no evaluation dataloader found") + return + + if isinstance(eval_dataloader, dict): eval_datasets = eval_dataloader else: eval_datasets = {"eval": eval_dataloader} @@ -303,9 +310,11 @@ def on_step_end( proc.normalizers = normalizers + def on_evaluate(self, args, state, control, **kwargs): + print("on_evaluate") + def on_prediction_step(self, args, state, control, **kwargs): - dataset_name = kwargs["inputs"]["dataset_name"] - self.write_grads(self.eval_grad_buffers[dataset_name]) + print("on_prediction_step") def on_train_end( self, @@ -366,6 +375,8 @@ def prepare_for_gradient_collection(trainer: Trainer): lambda ex, idx: {"_idx": idx}, with_indices=True ) + trainer.args.eval_dataset = trainer.eval_dataset + trainer._set_signature_columns_if_needed() trainer._signature_columns.append("_idx") From 1f2451c157c6ee537c8bed834a1504a8781b5601 Mon Sep 17 00:00:00 2001 From: Lucia Quirke Date: Thu, 25 Sep 2025 03:35:05 +0000 Subject: [PATCH 10/15] research commit --- examples/find_induction_heads.py | 159 +++++++++++++++++-------------- 1 file changed, 89 insertions(+), 70 deletions(-) diff --git a/examples/find_induction_heads.py b/examples/find_induction_heads.py index acd48c1..ad1d1d1 100644 --- a/examples/find_induction_heads.py +++ b/examples/find_induction_heads.py @@ -349,10 +349,9 @@ def load_data( tokenizer, N: int | None = None, name="EleutherAI/SmolLM2-135M-10B", max_length=512 ): """Load 
and preprocess dataset.""" - dataset = load_dataset(name, split="train") + split = f"train[:{N}]" if N is not None else "train" + dataset = load_dataset(name, split=split) dataset = assert_type(Dataset, dataset) - if N is not None: - dataset = dataset.select(range(min(N, len(dataset)))) def tokenize_function(examples): # Tokenize the text @@ -403,8 +402,26 @@ def build_single_token_vocab(tokenizer, wordlist, max_words=500): def create_induction_head_dataset(tokenizer, seed, num_prompts=100): random.seed(seed) - # crude word list, can be expanded - base_words = [ + # Separate words into appropriate A and B categories for sensible bigrams + A_words = [ + "blue", + "green", + "red", + "gold", + "happy", + "sad", + "big", + "small", + "fast", + "slow", + "smart", + "kind", + "brave", + "wise", + "young", + "old", + ] + B_words = [ "cat", "dog", "bird", @@ -419,71 +436,73 @@ def create_induction_head_dataset(tokenizer, seed, num_prompts=100): "road", "sky", "song", - "color", - "blue", - "green", - "red", - "gold", - "day", - "night", "king", "queen", "child", "story", + "house", + "river", + "mountain", + "flower", + "cloud", ] - vocab = build_single_token_vocab(tokenizer, base_words) - print(f"Vocab size: {len(vocab)}") + + A_vocab = build_single_token_vocab(tokenizer, A_words) + B_vocab = build_single_token_vocab(tokenizer, B_words) + print(f"A vocab size: {len(A_vocab)}") + print(f"B vocab size: {len(B_vocab)}") + + # Verify that all words are indeed single tokens + print("A vocab:", A_vocab) + print("B vocab:", B_vocab) patterns = [ - "The {A} saw the {B}. The {A}", - "Once the {A} met the {B}, later the {A}", - "In the story the {A} followed the {B}. The {A}", - "My favorite is the {A} with the {B}. The {A}", - "Everyone said the {A} remembers the {B}. The {A}", + "The {A} {B} was happy. The {A} {B}", + "Once the {A} {B} played, later the {A} {B}", + "In the story the {A} {B} ran fast. The {A} {B}", + "My favorite is the {A} {B} that sings. The {A} {B}", + "Everyone said the {A} {B} is smart. 
The {A} {B}", ] dataset = [] for _ in range(num_prompts): try: - A, B = random.sample(vocab, 2) + A = random.choice(A_vocab) + B = random.choice(B_vocab) except ValueError: - print(f"Vocab size: {len(vocab)}") - breakpoint() + print(f"A vocab size: {len(A_vocab)}, B vocab size: {len(B_vocab)}") raise ValueError("Not enough unique tokens in vocab") template = random.choice(patterns) text = template.format(A=A, B=B) - toks = tokenizer(text, return_tensors="pt", add_special_tokens=False) - input_ids = toks["input_ids"][0] - labels = torch.full_like(input_ids, -100) - - A_id = tokenizer(A, add_special_tokens=False)["input_ids"][0] - B_id = tokenizer(B, add_special_tokens=False)["input_ids"][0] - - # mask all A and B positions - matches_A = (input_ids == A_id).nonzero(as_tuple=True)[0] - matches_B = (input_ids == B_id).nonzero(as_tuple=True)[0] - labels[matches_A] = A_id - labels[matches_B] = B_id + toks = tokenizer( + text, + add_special_tokens=False, + padding="max_length", + truncation=True, + max_length=16, + ) + input_ids = toks["input_ids"] + labels = [-100] * len(input_ids) - # explicitly make sure final label is B - labels[-1] = B_id + # Set the last non-padding token as the target + for i in range(len(input_ids) - 1, -1, -1): + if input_ids[i] != tokenizer.pad_token_id: + labels[i] = input_ids[i] + break dataset.append( { "input_ids": input_ids, - "attention_mask": toks["attention_mask"][0], + "attention_mask": toks["attention_mask"], "labels": labels, - "A": A, - "B": B, "text": text, } ) return dataset -def test_induction_head_labels(): - tokenizer = AutoTokenizer.from_pretrained("gpt2") +def test_induction_head_labels(tokenizer): dataset = create_induction_head_dataset(tokenizer, seed=0, num_prompts=3) for ex in dataset: @@ -520,6 +539,18 @@ def setup_training( num_train_epochs: int = 1, ): """Set up the training configuration with Bergson callback.""" + + def compute_metrics(eval_preds): + accuracy = ( + (eval_preds.label_ids == eval_preds.predictions) + .astype(np.float32) + .mean() + .item() + ) + return { + "accuracy": accuracy, + } + data_collator = DataCollatorForLanguageModeling( tokenizer=tokenizer, mlm=False, @@ -530,16 +561,18 @@ def setup_training( overwrite_output_dir=True, num_train_epochs=num_train_epochs, per_device_train_batch_size=8, - # per_device_eval_batch_size=8, + per_device_eval_batch_size=128, gradient_accumulation_steps=1, warmup_steps=1000, learning_rate=5e-4, weight_decay=0.01, logging_dir=f"{output_dir}/logs", logging_steps=10, - # save_strategy="steps", - # save_steps=1000, - # save_total_limit=3, + eval_steps=10, + eval_strategy="steps", + save_strategy="steps", + save_steps=1000, + save_total_limit=3, # load_best_model_at_end=True, # metric_for_best_model="train_loss", # greater_is_better=False, @@ -569,9 +602,10 @@ def setup_training( model=model, args=training_args, train_dataset=train_dataset, - # eval_dataset=eval_dataset, + eval_dataset=eval_dataset, data_collator=data_collator, callbacks=[bergson_callback], + compute_metrics=compute_metrics, ) # Prepare for gradient collection @@ -587,23 +621,7 @@ def build_induction_index( print("Building Bergson index for induction head queries...") # Convert induction prompts to dataset format - induction_data = [] - for prompt_data in induction_prompts: - # Create a simple dataset entry - induction_data.append( - { - "input_ids": prompt_data["input_ids"].tolist(), - "attention_mask": prompt_data["attention_mask"].tolist(), - "labels": prompt_data["input_ids"].tolist(), # For language modeling - "text": 
prompt_data["text"], - } - ) - # Mask out everything except the last token in the labels - labels = [-100] * len(prompt_data["input_ids"]) - labels[-1] = prompt_data["input_ids"][-1] - induction_data[-1]["labels"] = labels - - induction_dataset = Dataset.from_list(induction_data) + induction_dataset = Dataset.from_list(induction_prompts) # Create gradient processor processor = GradientProcessor( @@ -665,7 +683,8 @@ def upload_to_hub(model, tokenizer, model_name="2layer-transformer-tinystories") def main(args): - dataset_name = "EleutherAI/SmolLM2-135M-10B" + # dataset_name = "EleutherAI/SmolLM2-135M-10B" + dataset_name = "RonenEldan/TinyStories" num_train_epochs = 1 unit_norm = args.unit_norm @@ -695,20 +714,20 @@ def main(args): ) # # Create induction head dataset - test_induction_head_labels() + # test_induction_head_labels(tokenizer) induction_prompts = create_induction_head_dataset( - tokenizer, seed=seed, num_prompts=10 + tokenizer, seed=seed, num_prompts=100 ) if train: # Set up training # Load data if args.small: - train_dataset, eval_dataset = load_data( - tokenizer, name=dataset_name, N=1000 - ) + train_dataset, _ = load_data(tokenizer, name=dataset_name, N=10_000) else: - train_dataset, eval_dataset = load_data(tokenizer, name=dataset_name) + train_dataset, _ = load_data(tokenizer, name=dataset_name) + + eval_dataset = Dataset.from_list(induction_prompts) trainer = setup_training( model, From 24b4c5fcd755dac36b066974cedb26d829726f92 Mon Sep 17 00:00:00 2001 From: Lucia Quirke Date: Thu, 25 Sep 2025 05:47:23 +0000 Subject: [PATCH 11/15] Add eval logging --- examples/find_induction_heads.py | 43 +++++++++++++------------------- 1 file changed, 17 insertions(+), 26 deletions(-) diff --git a/examples/find_induction_heads.py b/examples/find_induction_heads.py index ad1d1d1..89ce1e9 100644 --- a/examples/find_induction_heads.py +++ b/examples/find_induction_heads.py @@ -44,6 +44,13 @@ ) from bergson.utils import assert_type +HEAD_CFGS = { + "h.0.attn.c_attn": HeadConfig(12, 192, 2), + "h.0.attn.c_proj": HeadConfig(12, 64, 2), + "h.1.attn.c_attn": HeadConfig(12, 192, 2), + "h.1.attn.c_proj": HeadConfig(12, 64, 2), +} + class AttnOnlyConfig(PretrainedConfig): model_type = "attn_only" @@ -542,7 +549,7 @@ def setup_training( def compute_metrics(eval_preds): accuracy = ( - (eval_preds.label_ids == eval_preds.predictions) + (eval_preds.label_ids == eval_preds.predictions.argmax(axis=-1)) .astype(np.float32) .mean() .item() @@ -568,29 +575,21 @@ def compute_metrics(eval_preds): weight_decay=0.01, logging_dir=f"{output_dir}/logs", logging_steps=10, - eval_steps=10, + eval_steps=100, eval_strategy="steps", save_strategy="steps", - save_steps=1000, - save_total_limit=3, - # load_best_model_at_end=True, - # metric_for_best_model="train_loss", - # greater_is_better=False, + save_steps=10_000, + # save_total_limit=3, report_to="wandb" if wandb else None, run_name="2-layer-transformer-SmolLM2-corpus", seed=42, fp16=False, - dataloader_drop_last=True, + dataloader_drop_last=False, ) bergson_callback = GradientCollectorCallback( path=f"{output_dir}/gradients", - head_cfgs={ - "h.0.attn.c_attn": HeadConfig(12, 192, 2), - "h.0.attn.c_proj": HeadConfig(12, 64, 2), - "h.1.attn.c_attn": HeadConfig(12, 192, 2), - "h.1.attn.c_proj": HeadConfig(12, 64, 2), - }, + head_cfgs=HEAD_CFGS, projection_dim=projection_dim, dtype=np.float32, accumulate_grads=False, @@ -618,8 +617,6 @@ def build_induction_index( model, induction_prompts, output_dir, projection_dim, unit_norm ): """Build static query Bergson index using 
synthetic induction head data.""" - print("Building Bergson index for induction head queries...") - # Convert induction prompts to dataset format induction_dataset = Dataset.from_list(induction_prompts) @@ -638,12 +635,7 @@ def build_induction_index( processor=processor, path=f"{output_dir}/induction_gradients", skip_preconditioners=True, - head_cfgs={ - "h.0.attn.c_attn": HeadConfig(12, 192, 2), - "h.0.attn.c_proj": HeadConfig(12, 64, 2), - "h.1.attn.c_attn": HeadConfig(12, 192, 2), - "h.1.attn.c_proj": HeadConfig(12, 64, 2), - }, + head_cfgs=HEAD_CFGS, ) # Build the attributor for querying @@ -656,7 +648,6 @@ def build_induction_index( ) # Collect mean gradient from attributor index - mean_gradient = torch.cat([grad for grad in attributor.grads.values()], dim=1).mean( dim=0 ) @@ -683,8 +674,8 @@ def upload_to_hub(model, tokenizer, model_name="2layer-transformer-tinystories") def main(args): - # dataset_name = "EleutherAI/SmolLM2-135M-10B" - dataset_name = "RonenEldan/TinyStories" + dataset_name = "EleutherAI/SmolLM2-135M-10B" + # dataset_name = "RonenEldan/TinyStories" num_train_epochs = 1 unit_norm = args.unit_norm @@ -723,7 +714,7 @@ def main(args): # Set up training # Load data if args.small: - train_dataset, _ = load_data(tokenizer, name=dataset_name, N=10_000) + train_dataset, _ = load_data(tokenizer, name=dataset_name, N=20_000) else: train_dataset, _ = load_data(tokenizer, name=dataset_name) From 7602411408724b22927b8d5ad525249673ae02f5 Mon Sep 17 00:00:00 2001 From: Lucia Quirke Date: Wed, 1 Oct 2025 06:05:21 +0000 Subject: [PATCH 12/15] Update induction heads eval dataset --- examples/find_induction_heads.py | 106 ++++++++++++++++++++++++------- 1 file changed, 84 insertions(+), 22 deletions(-) diff --git a/examples/find_induction_heads.py b/examples/find_induction_heads.py index 89ce1e9..b3204aa 100644 --- a/examples/find_induction_heads.py +++ b/examples/find_induction_heads.py @@ -506,7 +506,7 @@ def create_induction_head_dataset(tokenizer, seed, num_prompts=100): "text": text, } ) - return dataset + return Dataset.from_list(dataset) def test_induction_head_labels(tokenizer): @@ -547,16 +547,83 @@ def setup_training( ): """Set up the training configuration with Bergson callback.""" + pad_id = -100 + def compute_metrics(eval_preds): - accuracy = ( - (eval_preds.label_ids == eval_preds.predictions.argmax(axis=-1)) - .astype(np.float32) - .mean() - .item() - ) - return { - "accuracy": accuracy, - } + # predictions: (B, T, V) + # label_ids: with your collator, this equals input_ids: (B, T) + preds = eval_preds.predictions + input_ids = eval_preds.label_ids + + correct = 0 + total = 0 + # for each sequence, evaluate the final next-token prediction + for i in range(input_ids.shape[0]): + seq = input_ids[i] + # last non-pad index j + non_pad = np.where(seq != pad_id)[0] + if len(non_pad) == 0: + continue + j = non_pad[-1] + if j == 0: + continue # nothing to predict + pred_tok = preds[i, j - 1].argmax(-1) + tgt_tok = seq[j] + correct += int(pred_tok == tgt_tok) + total += 1 + + # avoid div-by-zero + acc = (correct / total) if total > 0 else 0.0 + return {"accuracy": acc} + + # def compute_metrics(eval_preds): + # print("compute_metrics") + # # predictions: (B, T, V) + # preds = eval_preds.predictions + # label_ids = eval_preds.label_ids + + # correct = 0 + # total = 0 + + # # how many examples to print + # max_print = 5 + # printed = 0 + + # for i in range(label_ids.shape[0]): + # seq = label_ids[i] + # # last non-pad index j + # non_pad = np.where(seq != pad_id)[0] + # if 
len(non_pad) == 0: + # continue + # j = non_pad[-1] + # if j == 0: + # continue + + # # predicted token at position j-1 (predicting token j) + # pred_logits = preds[i, j - 1] + # pred_tok = pred_logits.argmax(-1) + # tgt_tok = seq[j] + + # correct += int(pred_tok == tgt_tok) + # total += 1 + + # # Trigger additional info approximately 1% of the time + # if random.random() < 0.01: + # if printed < max_print: + # seq_str = tokenizer.decode(seq[:j + 1], skip_special_tokens=True) + # pred_str = tokenizer.decode([pred_tok]) + # tgt_str = tokenizer.decode([tgt_tok]) + # print("=" * 40) + # print(f"Example {i}") + # print(f"Context up to target: {seq_str}") + # print(f"Target token id: {tgt_tok} ({tgt_str})") + # print(f"Predicted token id: {pred_tok} ({pred_str})") + # print(f"Match? {pred_tok == tgt_tok}") + # printed += 1 + + # acc = correct / total + + # return {"accuracy": acc} data_collator = DataCollatorForLanguageModeling( tokenizer=tokenizer, @@ -614,12 +681,9 @@ def compute_metrics(eval_preds): def build_induction_index( - model, induction_prompts, output_dir, projection_dim, unit_norm + model, induction_dataset, output_dir, projection_dim, unit_norm ): """Build static query Bergson index using synthetic induction head data.""" - # Convert induction prompts to dataset format - induction_dataset = Dataset.from_list(induction_prompts) - # Create gradient processor processor = GradientProcessor( {}, @@ -706,7 +770,7 @@ def main(args): # # Create induction head dataset # test_induction_head_labels(tokenizer) - induction_prompts = create_induction_head_dataset( + induction_dataset = create_induction_head_dataset( tokenizer, seed=seed, num_prompts=100 ) @@ -718,20 +782,18 @@ def main(args): else: train_dataset, _ = load_data(tokenizer, name=dataset_name) - eval_dataset = Dataset.from_list(induction_prompts) - trainer = setup_training( model, tokenizer, train_dataset, - eval_dataset, + eval_dataset=induction_dataset, output_dir=output_dir, projection_dim=projection_dim, wandb=False, num_train_epochs=num_train_epochs, ) - trainer.train() + trainer.train() # resume_from_checkpoint=True) trainer.save_model(output_dir) tokenizer.save_pretrained(output_dir) @@ -748,7 +810,7 @@ def main(args): # Build Bergson index for induction head queries mean_induction_gradients, module_induction_gradients = build_induction_index( - model, induction_prompts, output_dir, projection_dim, unit_norm + model, induction_dataset, output_dir, projection_dim, unit_norm ) model = model.cpu() @@ -833,7 +895,7 @@ def main(args): print(f"Loaded module scores from {df_path}") else: data = [] - for epoch_idx in range(trainer.args.num_train_epochs): + for epoch_idx in range(num_train_epochs): grads = Attributor( index_path=f"{trainer.args.output_dir}/gradients/train/epoch_{epoch_idx}", device="cpu", @@ -1025,7 +1087,7 @@ def main(args): from argparse import ArgumentParser parser = ArgumentParser() - parser.add_argument("--projection_dim", type=int, default=128) + parser.add_argument("--projection_dim", type=int, default=16) parser.add_argument("--seed", type=int, default=0) parser.add_argument("--train", action="store_true") parser.add_argument("--unit_norm", action="store_true") From 85bd366266faefb03f9c9a0b4e247fc5c0d0381b Mon Sep 17 00:00:00 2001 From: Lucia Quirke Date: Tue, 7 Oct 2025 04:55:43 +0000 Subject: [PATCH 13/15] Support full scores calculation with FAISS; assume mod faiss impl in induction heads script --- bergson/attributor.py | 7 +- bergson/faiss_index.py | 156 ++++++++++++++++++---------- 
bergson/huggingface.py | 3 + examples/find_induction_heads.py | 89 ++++++++-------- examples/trainer_grad_collection.py | 2 +- 5 files changed, 155 insertions(+), 102 deletions(-) diff --git a/bergson/attributor.py b/bergson/attributor.py index eb2c1af..ced3065 100644 --- a/bergson/attributor.py +++ b/bergson/attributor.py @@ -74,7 +74,10 @@ def __init__( self.grads[name] /= norm def search( - self, queries: dict[str, Tensor], k: int, modules: list[str] | None = None + self, + queries: dict[str, Tensor], + k: int | None, + modules: list[str] | None = None, ) -> tuple[Tensor, Tensor]: """ Search for the `k` nearest examples in the index based on the query or queries. @@ -112,7 +115,7 @@ def search( ) modules = modules or list(q.keys()) - k = min(k, self.N) + k = min(k or self.N, self.N) scores = torch.stack( [q[name] @ self.grads[name].mT for name in modules], dim=-1 diff --git a/bergson/faiss_index.py b/bergson/faiss_index.py index fab29d7..a076e7f 100644 --- a/bergson/faiss_index.py +++ b/bergson/faiss_index.py @@ -1,9 +1,8 @@ import json -import math import os from dataclasses import dataclass from pathlib import Path -from time import time +from time import perf_counter from typing import Protocol import numpy as np @@ -124,10 +123,12 @@ def index_to_device(index: Index, device: str) -> Index: class FaissIndex: - """FAISS index.""" + """Shard-based FAISS index.""" shards: list[Index] + faiss_cfg: FaissConfig + def __init__(self, path: str, faiss_cfg: FaissConfig, device: str, unit_norm: bool): try: import faiss @@ -145,74 +146,111 @@ def __init__(self, path: str, faiss_cfg: FaissConfig, device: str, unit_norm: bo f"{'_unit_norm' if unit_norm else ''}" ) ) + faiss_path.mkdir(exist_ok=True, parents=True) - if not (faiss_path.exists() and any(faiss_path.iterdir())): + if not any(faiss_path.iterdir()): print("Building FAISS index...") - start = time() + start = perf_counter() + + root_path = Path(path) + if (root_path / "info.json").exists(): + info_paths = [root_path / "info.json"] + else: + info_paths = [ + shard_path / "info.json" + for shard_path in sorted(root_path.iterdir()) + if shard_path.is_dir() and (shard_path / "info.json").exists() + ] + + if not info_paths: + raise FileNotFoundError(f"No gradient metadata found under {path}") + + total_grads = sum( + [json.load(open(info_path))["num_grads"] for info_path in info_paths] + ) - faiss_path.mkdir(exist_ok=True, parents=True) + assert faiss_cfg.num_shards <= total_grads and faiss_cfg.num_shards > 0 - num_dataset_shards = len(list(Path(path).iterdir())) - shards_per_index = math.ceil(num_dataset_shards / faiss_cfg.num_shards) + # Set the number of grads for each faiss index shard + base_shard_size = total_grads // faiss_cfg.num_shards + remainder = total_grads % faiss_cfg.num_shards + shard_sizes = [base_shard_size] * (faiss_cfg.num_shards) + shard_sizes[-1] += remainder + + # Verify all gradients will be consumed + assert ( + sum(shard_sizes) == total_grads + ), f"Shard sizes {shard_sizes} don't sum to total_grads {total_grads}" dl = gradients_loader(path) - buffer = [] - index_idx = 0 + buffer: list[NDArray] = [] + buffer_size = 0 + shard_idx = 0 - for grads in tqdm(dl, desc="Loading gradients"): - grads = structured_to_unstructured(grads) + def build_shard_from_buffer( + buffer_parts: list[NDArray], shard_idx: int + ) -> None: + print(f"Building shard {shard_idx}...") + grads_chunk = np.concatenate(buffer_parts, axis=0) + buffer_parts.clear() - if unit_norm: - grads = normalize_grads(grads, device, faiss_cfg.batch_size) + index = 
faiss.index_factory( + grads_chunk.shape[1], + faiss_cfg.index_factory, + faiss.METRIC_INNER_PRODUCT, + ) + index = index_to_device(index, device) + if faiss_cfg.max_train_examples is not None: + train_examples = min( + faiss_cfg.max_train_examples, grads_chunk.shape[0] + ) + else: + train_examples = grads_chunk.shape[0] + index.train(grads_chunk[:train_examples]) + index.add(grads_chunk) - buffer.append(grads) + del grads_chunk - if len(buffer) == shards_per_index: - # Build index shard - print(f"Building shard {index_idx}...") + index = index_to_device(index, "cpu") + faiss.write_index(index, str(faiss_path / f"{shard_idx}.faiss")) - grads = np.concatenate(buffer, axis=0) - buffer = [] + for grads in tqdm(dl, desc="Loading gradients"): + grads = structured_to_unstructured(grads) - index = faiss.index_factory( - grads.shape[1], - faiss_cfg.index_factory, - faiss.METRIC_INNER_PRODUCT, - ) - index = index_to_device(index, device) - train_examples = faiss_cfg.max_train_examples or grads.shape[0] - index.train(grads[:train_examples]) - index.add(grads) + if unit_norm: + grads = normalize_grads(grads, device, faiss_cfg.batch_size) - # Write index to disk - del grads - index = index_to_device(index, "cpu") - faiss.write_index(index, str(faiss_path / f"{index_idx}.faiss")) + batch_idx = 0 + batch_size = grads.shape[0] + while batch_idx < batch_size and shard_idx < faiss_cfg.num_shards: + remaining_in_shard = shard_sizes[shard_idx] - buffer_size + take = min(remaining_in_shard, batch_size - batch_idx) - index_idx += 1 + if take > 0: + buffer.append(grads[batch_idx : batch_idx + take]) + buffer_size += take + batch_idx += take - if buffer: - grads = np.concatenate(buffer, axis=0) - buffer = [] - index = faiss.index_factory( - grads.shape[1], faiss_cfg.index_factory, faiss.METRIC_INNER_PRODUCT - ) - index = index_to_device(index, device) - index.train(grads) - index.add(grads) + if buffer_size == shard_sizes[shard_idx]: + build_shard_from_buffer(buffer, shard_idx) + buffer = [] + buffer_size = 0 + shard_idx += 1 - # Write index to disk del grads - index = index_to_device(index, "cpu") - faiss.write_index(index, str(faiss_path / f"{index_idx}.faiss")) - print(f"Built index in {(time() - start) / 60:.2f} minutes.") - del buffer, index + assert shard_idx == faiss_cfg.num_shards + print(f"Built index in {(perf_counter() - start) / 60:.2f} minutes.") + + shard_paths = sorted( + (c for c in faiss_path.glob("*.faiss") if c.stem.isdigit()), + key=lambda p: int(p.stem), + ) shards = [] - for i in range(faiss_cfg.num_shards): + for shard_path in shard_paths: shard = faiss.read_index( - str(faiss_path / f"{i}.faiss"), + str(shard_path), faiss.IO_FLAG_MMAP | faiss.IO_FLAG_READ_ONLY, ) if not faiss_cfg.mmap_index: @@ -220,21 +258,25 @@ def __init__(self, path: str, faiss_cfg: FaissConfig, device: str, unit_norm: bo shards.append(shard) + if len(shards) != faiss_cfg.num_shards: + faiss_cfg.num_shards = len(shards) + self.shards = shards - def search(self, q: NDArray, k: int) -> tuple[NDArray, NDArray]: + def search(self, q: NDArray, k: int | None) -> tuple[NDArray, NDArray]: """Note: if fewer than `k` examples are found FAISS will return items - with the index -1 and the maximum negative distance.""" + with the index -1 and the maximum negative distance. 
If `k` is `None`, + all examples will be returned.""" shard_distances = [] shard_indices = [] offset = 0 - for index in self.shards: - index.nprobe = self.faiss_cfg.nprobe - distances, indices = index.search(q, k) + for shard in self.shards: + shard.nprobe = self.faiss_cfg.nprobe + distances, indices = shard.search(q, k or shard.ntotal) indices += offset - offset += index.ntotal + offset += shard.ntotal shard_distances.append(distances) shard_indices.append(indices) @@ -244,7 +286,7 @@ def search(self, q: NDArray, k: int) -> tuple[NDArray, NDArray]: # Rerank results overfetched from multiple shards if len(self.shards) > 1: - topk_indices = np.argsort(distances, axis=1)[:, :k] + topk_indices = np.argsort(distances, axis=1)[:, : k or self.ntotal] indices = indices[np.arange(indices.shape[0])[:, None], topk_indices] distances = distances[np.arange(distances.shape[0])[:, None], topk_indices] diff --git a/bergson/huggingface.py b/bergson/huggingface.py index 9bc853a..656ffca 100644 --- a/bergson/huggingface.py +++ b/bergson/huggingface.py @@ -35,6 +35,7 @@ def __init__( accumulate_grads: bool = False, use_optimizer_state: bool = True, track_order: bool = False, + shard_size: int | None = 200_000, ): """ Args: @@ -50,6 +51,8 @@ def __init__( normalize the gradients. If `False`, no normalization is applied. track_order: Whether to record the shuffled order of training data. + head_cfgs: Information used to split matrix-valued parameters into + per-head matrices before down projection. """ super().__init__() diff --git a/examples/find_induction_heads.py b/examples/find_induction_heads.py index b3204aa..e54be32 100644 --- a/examples/find_induction_heads.py +++ b/examples/find_induction_heads.py @@ -37,7 +37,13 @@ from transformers.modeling_outputs import CausalLMOutputWithPast import wandb -from bergson import Attributor, GradientProcessor, HeadConfig, collect_gradients +from bergson import ( + Attributor, + FaissConfig, + GradientProcessor, + HeadConfig, + collect_gradients, +) from bergson.huggingface import ( GradientCollectorCallback, prepare_for_gradient_collection, @@ -681,7 +687,11 @@ def compute_metrics(eval_preds): def build_induction_index( - model, induction_dataset, output_dir, projection_dim, unit_norm + model, + induction_dataset, + output_dir, + projection_dim, + unit_norm, ): """Build static query Bergson index using synthetic induction head data.""" # Create gradient processor @@ -712,15 +722,13 @@ def build_induction_index( ) # Collect mean gradient from attributor index - mean_gradient = torch.cat([grad for grad in attributor.grads.values()], dim=1).mean( - dim=0 - ) mean_module_gradients = { - name: attributor.grads[name].mean(dim=0) for name in attributor.grads.keys() + name: attributor.grads[name].mean(dim=0, keepdim=True) + for name in attributor.grads.keys() } print("In-context index built successfully! 
Returning mean gradients...") - return mean_gradient, mean_module_gradients + return mean_module_gradients def upload_to_hub(model, tokenizer, model_name="2layer-transformer-tinystories"): @@ -809,8 +817,12 @@ def main(args): model = model.to(device) # Build Bergson index for induction head queries - mean_induction_gradients, module_induction_gradients = build_induction_index( - model, induction_dataset, output_dir, projection_dim, unit_norm + mean_module_induction_gradients = build_induction_index( + model, + induction_dataset, + output_dir, + projection_dim, + unit_norm, ) model = model.cpu() @@ -826,26 +838,25 @@ def main(args): data = [] for epoch_idx in range(num_train_epochs): # Read Bergson index from training - grads = Attributor( + attributor = Attributor( str(Path(output_dir) / "gradients" / "train" / f"epoch_{epoch_idx}"), - device=device, + device="cpu", unit_norm=unit_norm, dtype=torch.float32, - ).grads - - inner_products = None - offset = 0 - for grad in grads.values(): - d = grad.shape[1] - mean_block = mean_induction_gradients[offset : offset + d] - offset += d - for i in range(0, grad.shape[0], 1024): - batch = grad[i : i + 1024].to(mean_block.device, dtype=torch.float32) - contrib = batch @ mean_block.float() - if inner_products is None: - inner_products = torch.zeros(grad.shape[0], device=contrib.device) - inner_products[i : i + 1024] += contrib - inner_products = inner_products.cpu() + faiss_cfg=FaissConfig( + mmap_index=True, index_factory="IVF1,SQfp16", num_shards=10 + ), + ) + + # returns from largest to smallest 3 2 1 ... + inner_products, indices = attributor.search( + mean_module_induction_gradients, k=None + ) + del attributor + + # put in original order + order = indices.argsort(dim=-1) + inner_products = torch.gather(inner_products, -1, order) for i, score in enumerate(inner_products.squeeze()): training_metadata = training_order[ @@ -896,32 +907,26 @@ def main(args): else: data = [] for epoch_idx in range(num_train_epochs): - grads = Attributor( + attributor = Attributor( index_path=f"{trainer.args.output_dir}/gradients/train/epoch_{epoch_idx}", device="cpu", - dtype=torch.float32, unit_norm=unit_norm, - ).grads + dtype=torch.float32, + faiss_cfg=FaissConfig( + mmap_index=True, index_factory="IVF1,SQfp16", num_shards=10 + ), + ) - # module_inner_products = {} - offset = 0 - for name, grad in grads.items(): + for name, grad in mean_module_induction_gradients.items(): if "attention" not in name and "attn" not in name: print(f"Skipping {name}") continue else: print(f"Processing {name}") - d = grad.shape[1] - mean_block = mean_induction_gradients[offset : offset + d] - offset += d - scores = [] - for i in range(0, grad.shape[0], 1024): - batch = grad[i : i + 1024].to( - mean_block.device, dtype=torch.float32 - ) - scores.append(batch @ mean_block.float()) - mod_inner_products = torch.cat(scores, dim=0).cpu() + mod_inner_products, _ = attributor.search( + {name: grad}, k=None, modules=[name] + ) for i, score in enumerate(mod_inner_products.squeeze()): training_metadata = training_order[ diff --git a/examples/trainer_grad_collection.py b/examples/trainer_grad_collection.py index 612ce11..ff9a92f 100644 --- a/examples/trainer_grad_collection.py +++ b/examples/trainer_grad_collection.py @@ -120,7 +120,7 @@ def main(args: IndexConfig): conversation_column=args.data.conversation_column, ) dataset = load_data_string( - args.data.dataset, args.data.split, streaming=args.streaming + args.data.dataset, args.data.split, streaming=args.stream ) dataset = dataset.map( 
tokenize, From 0850f5d2adc3196e78e5f49f7ea845975e5ea2dc Mon Sep 17 00:00:00 2001 From: Lucia Quirke Date: Tue, 14 Oct 2025 07:43:01 +0000 Subject: [PATCH 14/15] Fix induction heads types --- examples/find_induction_heads.py | 60 ++++++++++++++------------------ 1 file changed, 26 insertions(+), 34 deletions(-) diff --git a/examples/find_induction_heads.py b/examples/find_induction_heads.py index e54be32..ee0f5da 100644 --- a/examples/find_induction_heads.py +++ b/examples/find_induction_heads.py @@ -686,14 +686,14 @@ def compute_metrics(eval_preds): return trainer -def build_induction_index( +def mean_query_gradients( model, induction_dataset, output_dir, projection_dim, unit_norm, ): - """Build static query Bergson index using synthetic induction head data.""" + """Build on-disk Bergson index using synthetic induction head data.""" # Create gradient processor processor = GradientProcessor( {}, @@ -746,6 +746,8 @@ def upload_to_hub(model, tokenizer, model_name="2layer-transformer-tinystories") def main(args): + check_logins() + dataset_name = "EleutherAI/SmolLM2-135M-10B" # dataset_name = "RonenEldan/TinyStories" num_train_epochs = 1 @@ -756,7 +758,7 @@ def main(args): projection_dim = args.projection_dim seed = args.seed train = args.train - plot = args.plot + analyze = args.analyze output_dir = f"examples/runs/transformer_2_layer{'_' + tag if tag else ''}" @@ -764,27 +766,20 @@ def main(args): "Starting 2-layer transformer pretraining with Bergson gradient collection..." ) - # Check authentication - check_logins() - - # Set device - device = "cuda" if torch.cuda.is_available() else "cpu" + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") print(f"Using device: {device}") - # Create model and tokenizer model, tokenizer = create_transformer( special_pos_embed=not args.no_special_pos_embed ) # # Create induction head dataset - # test_induction_head_labels(tokenizer) + # test_induction_head_labels(tokenizer) # Outdated induction_dataset = create_induction_head_dataset( tokenizer, seed=seed, num_prompts=100 ) if train: - # Set up training - # Load data if args.small: train_dataset, _ = load_data(tokenizer, name=dataset_name, N=20_000) else: @@ -805,19 +800,14 @@ def main(args): trainer.save_model(output_dir) tokenizer.save_pretrained(output_dir) - if not plot: + if not analyze: return # upload_to_hub(model, tokenizer) - # Reload model and tokenizer - # model = AutoModelForCausalLM.from_pretrained(output_dir) - model = AttnOnlyForCausalLM.from_pretrained(output_dir) - tokenizer = AutoTokenizer.from_pretrained(output_dir) - model = model.to(device) - - # Build Bergson index for induction head queries - mean_module_induction_gradients = build_induction_index( + # Get mean module gradients for induction head queries + model = model.to(device) # type: ignore + mean_module_induction_gradients = mean_query_gradients( model, induction_dataset, output_dir, @@ -827,14 +817,15 @@ def main(args): model = model.cpu() # Load parquet table containing training order - training_order = load_from_disk( - str(Path(output_dir) / "gradients" / "order.hf") - ).to_pandas() + training_order_ds = assert_type( + Dataset, load_from_disk(str(Path(output_dir) / "gradients" / "order.hf")) + ) + training_order = assert_type(pd.DataFrame, training_order_ds.to_pandas()) - # Plots + # Analyze data os.makedirs("figures", exist_ok=True) - # Calculate the inner products with the training gradients + # Calculate the mean query gradients' inner products with the training gradients data = [] for epoch_idx in 
range(num_train_epochs): # Read Bergson index from training @@ -848,15 +839,12 @@ def main(args): ), ) - # returns from largest to smallest 3 2 1 ... + # Ordered from largest to smallest like (3 2 1 ...) inner_products, indices = attributor.search( mean_module_induction_gradients, k=None ) - del attributor - - # put in original order - order = indices.argsort(dim=-1) - inner_products = torch.gather(inner_products, -1, order) + # Restore original order + inner_products = torch.gather(inner_products, -1, indices.argsort(dim=-1)) for i, score in enumerate(inner_products.squeeze()): training_metadata = training_order[ @@ -876,6 +864,7 @@ def main(args): ) data = pd.DataFrame(data) + # Visualize the influence scores plt.figure(figsize=(12, 8)) plt.scatter( data["global_step"], @@ -899,6 +888,9 @@ def main(args): bbox_inches="tight", ) + print("Module-wise scores not yet supported for FAISS index") + exit() + # Produce the same plot but split out by module (i.e. key in the grads mmap) df_path = f"figures/module_scores_{tag}{'_norm' if unit_norm else ''}.csv" if os.path.exists(df_path): @@ -1074,7 +1066,7 @@ def main(args): # Step 1: pick checkpoint steps # Step 2: compute a bunch of gradients at this step using the static index build # and save it - # Step 1.5: fix the horrible static index build bug + # Step 1.5: fix the static index build bug # Can we use optimal transport to align the gradients? # Should we transport the activations then transport the gradients in the same way? @@ -1098,7 +1090,7 @@ def main(args): parser.add_argument("--unit_norm", action="store_true") parser.add_argument("--small", action="store_true") parser.add_argument("--tag", type=str, default="") - parser.add_argument("--plot", action="store_true") + parser.add_argument("--analyze", action="store_true") parser.add_argument("--no_special_pos_embed", action="store_false") args = parser.parse_args() main(args) From c40f83a7b393b8becaba3569c8e7db128fb75dea Mon Sep 17 00:00:00 2001 From: Lucia Quirke Date: Tue, 14 Oct 2025 08:00:11 +0000 Subject: [PATCH 15/15] Fix induction heads types --- examples/find_induction_heads.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/examples/find_induction_heads.py b/examples/find_induction_heads.py index ee0f5da..3b227d1 100644 --- a/examples/find_induction_heads.py +++ b/examples/find_induction_heads.py @@ -857,7 +857,9 @@ def main(args): data.append( { "epoch": epoch_idx, - "global_step": row.global_step, + "global_step": row[ + training_metadata.columns.get_loc("global_step") + ], "index": i, "score": score.item(), }
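
Note on the shard sizing used when building the FAISS index in bergson/faiss_index.py above: the total gradient count is split evenly across faiss_cfg.num_shards, with any remainder folded into the last shard, so every gradient row lands in exactly one shard. A minimal standalone sketch of that split follows; the helper name shard_sizes is illustrative and not part of bergson.

def shard_sizes(total_grads: int, num_shards: int) -> list[int]:
    # Same arithmetic as the FaissIndex constructor: an even base size per
    # shard, with the remainder appended to the final shard.
    assert 0 < num_shards <= total_grads
    base, remainder = divmod(total_grads, num_shards)
    sizes = [base] * num_shards
    sizes[-1] += remainder
    assert sum(sizes) == total_grads  # every gradient is assigned exactly once
    return sizes

print(shard_sizes(10_007, 10))  # [1000, 1000, ..., 1000, 1007]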
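
The k=None search in examples/find_induction_heads.py returns scores ordered from largest to smallest per query, with indices giving the dataset row each score belongs to; PATCH 14/15 then restores dataset order via indices.argsort followed by torch.gather. A minimal sketch of that reordering, assuming plain PyTorch tensors; restore_dataset_order is an illustrative name, not a bergson function.

import torch

def restore_dataset_order(scores: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
    # scores[b, j] is the j-th largest score for query b and indices[b, j] is
    # the dataset row it came from. Gathering by indices.argsort puts each
    # score back at its original row position.
    order = indices.argsort(dim=-1)
    return torch.gather(scores, -1, order)

# Round-trip check: sorting then restoring reproduces the original ordering.
g = torch.Generator().manual_seed(0)
dataset_order_scores = torch.randn(1, 8, generator=g)
sorted_scores, sorted_idx = dataset_order_scores.sort(dim=-1, descending=True)
assert torch.equal(restore_dataset_order(sorted_scores, sorted_idx), dataset_order_scores)

This is the same gather applied to inner_products in the per-epoch loop, which is what lets the scores be joined positionally with the training-order metadata.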