From 9ab5dd21986afc8041caf14a647315af6bda60e7 Mon Sep 17 00:00:00 2001 From: Lucia Quirke Date: Fri, 17 Oct 2025 08:39:40 +0000 Subject: [PATCH 1/5] Assemble query --- README.md | 14 +- bergson/data.py | 30 +++-- bergson/query.py | 6 +- examples/assemble_query.py | 257 +++++++++++++++++++++++++++++++++++++ 4 files changed, 291 insertions(+), 16 deletions(-) create mode 100644 examples/assemble_query.py diff --git a/README.md b/README.md index fbb5aec..ecbb12e 100644 --- a/README.md +++ b/README.md @@ -71,7 +71,11 @@ trainer.train() ## Attention Head Gradients -By default Bergson collects gradients for named parameter matrices, but gradients for individual attention heads within an attention module can be collected too. To collect per-head gradients configure an AttentionConfig for each module of interest. +By default Bergson collects gradients for named parameter matrices, but gradients for individual attention heads within an attention module may also be collected. To collect per-head gradients programmatically configure an AttentionConfig for each module of interest, or specify attention modules and a single shared configuration using the command line tool. + +```bash +bergson build runs/test --model EleutherAI/pythia-14m --dataset NeelNanda/pile-10k --truncation --split_attention_modules "h.0.attn.attention.out_proj" --attention.num_heads 16 --attention.head_size 4 --attention.head_dim 2 +``` ```python from bergson import AttentionConfig, IndexConfig, DataConfig @@ -99,7 +103,7 @@ Where a reward signal is available we compute gradients using a weighted advanta bergson build --model --dataset --reward_column ``` -## Queries +## Index Queries We provide a query Attributor which supports unit normalized gradients and KNN search out of the box. @@ -127,6 +131,12 @@ with attr.trace(model.base_model, 5) as result: model.zero_grad() ``` +## On-The-Fly Queries + +```bash +bergson query runs/scores --query_path runs/query_index --model EleutherAI/pythia-14m --dataset NeelNanda/pile-10k --save_index False --truncation +``` + # Development ```bash diff --git a/bergson/data.py b/bergson/data.py index 24afd7e..b130d20 100644 --- a/bergson/data.py +++ b/bergson/data.py @@ -30,6 +30,9 @@ class DataConfig: dataset: str = "EleutherAI/SmolLM2-135M-10B" """Dataset identifier to build the index from.""" + subset: str | None = None + """Subset of the dataset to use for building the index.""" + split: str = "train" """Split of the dataset to use for building the index.""" @@ -70,24 +73,26 @@ class QueryConfig: """Config for querying an index on the fly.""" query_path: str = "" - """Path to the query dataset.""" + """Path to the existing query index.""" - query_method: Literal["mean", "nearest"] = "mean" - """Method to use for computing the query.""" + score: Literal["mean", "nearest"] = "mean" + """Method for scoring the gradients with the query. If mean + gradients will be scored by their similarity with the mean + query gradients, otherwise by the most similar query gradient.""" save_processor: bool = True """Whether to write the query dataset gradient processor to disk.""" query_preconditioner_path: str | None = None - """Path to a precomputed preconditioner. The precomputed - preconditioner is applied to the query dataset gradients.""" + """Path to a precomputed preconditioner to be applied to + the query dataset gradients.""" index_preconditioner_path: str | None = None - """Path to a precomputed preconditioner. The precomputed - preconditioner is applied to the query dataset gradients. 
- This does not affect the ability to compute a new - preconditioner during gradient collection.""" + """Path to a precomputed preconditioner to be applied to + the query dataset gradients. This does not affect the + ability to compute a new preconditioner during gradient + collection.""" mixing_coefficient: float = 0.5 """Coefficient to weight the application of the query preconditioner @@ -357,7 +362,7 @@ def create_index( def load_data_string( - data_str: str, split: str = "train", streaming: bool = False + data_str: str, split: str = "train", subset: str | None = None, streaming: bool = False ) -> Dataset | IterableDataset: """Load a dataset from a string identifier or path.""" if data_str.endswith(".csv"): @@ -366,7 +371,10 @@ def load_data_string( ds = assert_type(Dataset, Dataset.from_json(data_str)) else: try: - ds = load_dataset(data_str, split=split, streaming=streaming) + if subset: + ds = load_dataset(data_str, split=split, subset=subset, streaming=streaming) + else: + ds = load_dataset(data_str, split=split, streaming=streaming) if isinstance(ds, DatasetDict) or isinstance(ds, IterableDatasetDict): raise NotImplementedError( diff --git a/bergson/query.py b/bergson/query.py index 522ee67..abf23f1 100644 --- a/bergson/query.py +++ b/bergson/query.py @@ -306,14 +306,14 @@ def worker( query_device = torch.device(f"cuda:{rank}") query_dtype = dtype if dtype != "auto" else torch.float16 - if query_cfg.query_method == "mean": + if query_cfg.score == "mean": query_callback = get_mean_query(query_ds, query_cfg, query_device, query_dtype) - elif query_cfg.query_method == "nearest": + elif query_cfg.score == "nearest": query_callback = get_nearest_query( query_ds, query_cfg, query_device, query_dtype ) else: - raise ValueError(f"Invalid query method: {query_cfg.query_method}") + raise ValueError(f"Invalid query scoring method: {query_cfg.score}") if isinstance(ds, Dataset): batches = allocate_batches(ds["length"][:], index_cfg.token_batch_size) diff --git a/examples/assemble_query.py b/examples/assemble_query.py new file mode 100644 index 0000000..79aa9c6 --- /dev/null +++ b/examples/assemble_query.py @@ -0,0 +1,257 @@ +# TODAY +# Assemble dataset!! 6 queries +# Try multi node generation +# I believe the MCQA and Cloze setups are pulled from the same eval and are both roughly 1k rows, like the original wmdp-bio as a whole. + +import os +import socket +from collections import defaultdict +import torch.multiprocessing as mp +from torch.distributed.elastic.multiprocessing import DefaultLogsSpecs, start_processes +from numpy.lib.recfunctions import structured_to_unstructured +from datasets import load_dataset, Dataset +from transformers import AutoTokenizer +from datasets import get_dataset_config_names, concatenate_datasets +import torch + +from bergson.utils import assert_type +from bergson import load_gradients, IndexConfig, DataConfig +from bergson.build import estimate_advantage, dist_worker, worker + + +def load_mcqa_dataset(): + def map(x): + """Many of the questions require the choices to be given to be coherent, e.g. + 'Which of the following is the correct answer to the question?'""" + + choices = [f"{i}. 
{choice}" for i, choice in enumerate(x["choices"])] + prompt = " \n ".join([x["question"]] + ["Choices: "] + choices + ["Answer: "] + [f"{choices[int(x['answer'])]}"]) + + return { + "text": prompt, + "subset": x["subset"] + } + mcqa_dataset_name = "EleutherAI/wmdp_bio_robust_mcqa" + subsets = get_dataset_config_names(mcqa_dataset_name) + mcqa_datasets = [] + for subset in subsets: + ds = assert_type(Dataset, load_dataset(mcqa_dataset_name, subset, split="robust")) + ds = ds.add_column("subset", [subset] * len(ds)) + mcqa_datasets.append(ds) + + mcqa_ds = concatenate_datasets(mcqa_datasets) + + return mcqa_ds.map(map, remove_columns=["choices", "answer", "question"]) + + +def tokenize_mcqa( + batch: dict, + *, + tokenizer, + args: DataConfig, + answer_marker: str = "Answer:", +): + """ + Custom tokenizer for this MCQA experiment that only keeps labels on the + final answer span so gradient collection ignores the rest of the prompt. + + Codex wrote this. + """ + + if not tokenizer.is_fast: + raise ValueError("Fast tokenizer required for answer span alignment.") + + kwargs = dict( + return_attention_mask=False, + return_length=True, + truncation=args.truncation, + return_offsets_mapping=True, + ) + + encodings = tokenizer(batch[args.prompt_column], **kwargs) + + whitespace = {" ", "\t", "\n", "\r", "\f", "\v"} + labels: list[list[int]] = [] + answer_token_indices: list[list[int]] = [] + + for text, offsets, token_ids in zip( + batch[args.prompt_column], + encodings["offset_mapping"], + encodings["input_ids"], + ): + marker_idx = text.rfind(answer_marker) + if marker_idx < 0: + raise ValueError(f"Failed to locate '{answer_marker}' in:\n{text}") + + start = marker_idx + len(answer_marker) + while start < len(text) and text[start] in whitespace: + start += 1 + + if start >= len(text): + raise ValueError( + f"No answer text found after '{answer_marker}' in:\n{text}" + ) + + end = len(text) + while end > start and text[end - 1] in whitespace: + end -= 1 + + if end <= start: + raise ValueError(f"Empty answer span detected in:\n{text}") + + example_labels = [-100] * len(token_ids) + supervised_indices: list[int] = [] + + for idx, span in enumerate(offsets): + if span is None: + continue + + tok_start, tok_end = span + if tok_start is None or tok_end is None or tok_start == tok_end: + continue + + if max(tok_start, start) < min(tok_end, end): + example_labels[idx] = token_ids[idx] + supervised_indices.append(idx) + + if not supervised_indices: + raise RuntimeError( + "Failed to align answer text with any tokens.\n" + f"Example:\n{text}" + ) + + labels.append(example_labels) + answer_token_indices.append(supervised_indices) + + encodings.pop("offset_mapping") + encodings["labels"] = labels + encodings["answer_token_indices"] = answer_token_indices + return encodings + + +def build_mcqa_index(cfg: IndexConfig, ds_path: str): + # TODO Set labels mask to final answer + # TODO are these labels always one token? + + # In many cases the token_batch_size may be smaller than the max length allowed by + # the model. 
If cfg.data.truncation is True, we use the tokenizer to truncate + tokenizer = AutoTokenizer.from_pretrained(cfg.model, revision=cfg.revision) + tokenizer.model_max_length = min(tokenizer.model_max_length, cfg.token_batch_size) + + # Do all the data loading and preprocessing on the main process + ds = Dataset.load_from_disk(ds_path, keep_in_memory=False) + + remove_columns = ds.column_names if cfg.drop_columns else None + ds = ds.map( + tokenize_mcqa, + batched=True, + fn_kwargs=dict(args=cfg.data, tokenizer=tokenizer), + remove_columns=remove_columns, + ) + if cfg.data.reward_column: + assert isinstance(ds, Dataset), "Dataset required for advantage estimation" + ds = ds.add_column( + "advantage", + estimate_advantage(ds, cfg.data), + new_fingerprint="advantage", # type: ignore + ) + + world_size = torch.cuda.device_count() + if world_size <= 1: + # Run the worker directly if no distributed training is needed. This is great + # for debugging purposes. + worker(0, 1, cfg, ds) + else: + # Set up multiprocessing and distributed training + mp.set_sharing_strategy("file_system") + + # Find an available port for distributed training + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(("", 0)) + _, port = s.getsockname() + + ctx = start_processes( + "build", + dist_worker, + args={i: (i, world_size, cfg, ds) for i in range(world_size)}, + envs={ + i: { + "LOCAL_RANK": str(i), + "MASTER_ADDR": "localhost", + "MASTER_PORT": str(port), + } + for i in range(world_size) + }, + logs_specs=DefaultLogsSpecs(), + ) + ctx.wait() + + try: + os.rename(cfg.partial_run_path, cfg.run_path) + except Exception: + pass + + +def assemble_query(query_ds: Dataset, run_path: str): + mmap = load_gradients(run_path) + + target_modules = set(mmap.dtype.names) + + print(f"Full projection dim: {len(target_modules) * 64}") + + mmap = structured_to_unstructured(mmap) + + print("mmap sum", torch.from_numpy(mmap).sum()) + + # Group mmap gradient rows by the subset they came from + subset_gradients = defaultdict(list) + for mmap_row, ds_row in zip(mmap, query_ds): + subset_gradients[ds_row["subset"]].append(torch.from_numpy(mmap_row)) + + subset_mean_gradients = {} + for subset, gradients in subset_gradients.items(): + print(f"Subset {subset} has {len(gradients)} gradients") + print(f"Computing mean gradient for subset {subset}") + mean_gradient = torch.stack(gradients).mean(dim=0) + print(f"Mean gradient for subset {subset}: {mean_gradient.mean(), mean_gradient.std(), mean_gradient.shape}") + subset_mean_gradients[subset] = mean_gradient + + + overall_mean_gradient = torch.from_numpy(mmap).mean(dim=0) + print(f"Overall mean gradient: {overall_mean_gradient.mean(), overall_mean_gradient.std(), overall_mean_gradient.shape}") + + breakpoint() + + # Maybe just overwrite the original dataset with the means for each + # subset and the overall mean + + +def main(): + mcqa_ds = assert_type(Dataset, load_mcqa_dataset()) + ds_path = "runs/ds_wmdp_bio_robust_mcqa" + mcqa_ds.save_to_disk(ds_path) + + data_config = DataConfig( + dataset=ds_path, + prompt_column="text", + ) + + cfg = IndexConfig( + run_path="runs/wmdp_bio_robust_mcqa_means", + save_index=True, + save_processor=True, + precision="fp32", + data=data_config, + fsdp=True, + model="EleutherAI/deep-ignorance-unfiltered", + projection_dim=64, + reshape_to_square=True, + ) + + build_mcqa_index(cfg, ds_path) + + assemble_query(mcqa_ds, cfg.run_path) + + +if __name__ == "__main__": + main() From fa4a7776f605783b09087aabcba9d859c90b4ecf Mon Sep 17 00:00:00 2001 
From: Lucia Quirke Date: Sun, 19 Oct 2025 05:10:30 +0200 Subject: [PATCH 2/5] assemble query --- bergson/data.py | 21 ++-- bergson/query.py | 4 +- examples/assemble_query.py | 238 +++++++++++++++++++++++++++++++------ 3 files changed, 215 insertions(+), 48 deletions(-) diff --git a/bergson/data.py b/bergson/data.py index b130d20..6775063 100644 --- a/bergson/data.py +++ b/bergson/data.py @@ -85,19 +85,19 @@ class QueryConfig: to disk.""" query_preconditioner_path: str | None = None - """Path to a precomputed preconditioner to be applied to + """Path to a precomputed preconditioner to be applied to the query dataset gradients.""" index_preconditioner_path: str | None = None - """Path to a precomputed preconditioner to be applied to - the query dataset gradients. This does not affect the - ability to compute a new preconditioner during gradient + """Path to a precomputed preconditioner to be applied to + the query dataset gradients. This does not affect the + ability to compute a new preconditioner during gradient collection.""" - mixing_coefficient: float = 0.5 + mixing_coefficient: float = 0.99 """Coefficient to weight the application of the query preconditioner and the pre-computed index preconditioner. 0.0 means only use the - query preconditioner and 1.0 means only use the index preconditioner.""" + index preconditioner and 1.0 means only use the query preconditioner.""" modules: list[str] = field(default_factory=list) """Modules to use for the query. If empty, all modules will be used.""" @@ -362,7 +362,10 @@ def create_index( def load_data_string( - data_str: str, split: str = "train", subset: str | None = None, streaming: bool = False + data_str: str, + split: str = "train", + subset: str | None = None, + streaming: bool = False, ) -> Dataset | IterableDataset: """Load a dataset from a string identifier or path.""" if data_str.endswith(".csv"): @@ -372,7 +375,9 @@ def load_data_string( else: try: if subset: - ds = load_dataset(data_str, split=split, subset=subset, streaming=streaming) + ds = load_dataset( + data_str, split=split, subset=subset, streaming=streaming + ) else: ds = load_dataset(data_str, split=split, streaming=streaming) diff --git a/bergson/query.py b/bergson/query.py index abf23f1..336990b 100644 --- a/bergson/query.py +++ b/bergson/query.py @@ -34,7 +34,7 @@ from .utils import assert_type, get_layer_list -def get_query_data(index_cfg: IndexConfig, query_cfg: QueryConfig): +def get_query_data(query_cfg: QueryConfig): """ Load and optionally precondition the query dataset. Preconditioners may be mixed as described in https://arxiv.org/html/2410.17413v1#S3. @@ -415,7 +415,7 @@ def query_gradient_dataset(query_cfg: QueryConfig, index_cfg: IndexConfig): new_fingerprint="advantage", # type: ignore ) - query_ds = get_query_data(index_cfg, query_cfg) + query_ds = get_query_data(query_cfg) world_size = torch.cuda.device_count() if world_size <= 1: diff --git a/examples/assemble_query.py b/examples/assemble_query.py index 79aa9c6..7c2f639 100644 --- a/examples/assemble_query.py +++ b/examples/assemble_query.py @@ -1,22 +1,35 @@ # TODAY # Assemble dataset!! 6 queries # Try multi node generation -# I believe the MCQA and Cloze setups are pulled from the same eval and are both roughly 1k rows, like the original wmdp-bio as a whole. +# I believe the MCQA and Cloze setups are pulled from the same eval and are +# both roughly 1k rows, like the original wmdp-bio as a whole. 
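+# For reference, a sketch (field names from EleutherAI/wmdp_bio_robust_mcqa) of the
+# flattened prompt that load_mcqa_dataset below assembles from each MCQA row:
+#
+#   <question> \n Choices: \n 0. <choice> \n 1. <choice> ... \n Answer: \n <k>. <correct choice>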
import os +import shutil import socket from collections import defaultdict +from pathlib import Path + +import numpy as np +import torch import torch.multiprocessing as mp +from datasets import ( + Dataset, + concatenate_datasets, + get_dataset_config_names, + load_dataset, +) +from numpy.lib.recfunctions import ( + structured_to_unstructured, + unstructured_to_structured, +) from torch.distributed.elastic.multiprocessing import DefaultLogsSpecs, start_processes -from numpy.lib.recfunctions import structured_to_unstructured -from datasets import load_dataset, Dataset from transformers import AutoTokenizer -from datasets import get_dataset_config_names, concatenate_datasets -import torch +from bergson import DataConfig, IndexConfig, load_gradients +from bergson.build import dist_worker, estimate_advantage, worker +from bergson.data import create_index from bergson.utils import assert_type -from bergson import load_gradients, IndexConfig, DataConfig -from bergson.build import estimate_advantage, dist_worker, worker def load_mcqa_dataset(): @@ -25,25 +38,105 @@ def map(x): 'Which of the following is the correct answer to the question?'""" choices = [f"{i}. {choice}" for i, choice in enumerate(x["choices"])] - prompt = " \n ".join([x["question"]] + ["Choices: "] + choices + ["Answer: "] + [f"{choices[int(x['answer'])]}"]) + prompt = " \n ".join( + [x["question"]] + + ["Choices: "] + + choices + + ["Answer: "] + + [f"{choices[int(x['answer'])]}"] + ) return { "text": prompt, - "subset": x["subset"] + # "prompt": prompt, + # "completion": f"{choices[int(x['answer'])]}", + "subset": x["subset"], } + mcqa_dataset_name = "EleutherAI/wmdp_bio_robust_mcqa" subsets = get_dataset_config_names(mcqa_dataset_name) mcqa_datasets = [] for subset in subsets: - ds = assert_type(Dataset, load_dataset(mcqa_dataset_name, subset, split="robust")) + ds = assert_type( + Dataset, load_dataset(mcqa_dataset_name, subset, split="robust") + ) ds = ds.add_column("subset", [subset] * len(ds)) mcqa_datasets.append(ds) mcqa_ds = concatenate_datasets(mcqa_datasets) - + return mcqa_ds.map(map, remove_columns=["choices", "answer", "question"]) +def tokenize( + batch: dict, *, args: DataConfig, tokenizer, apply_chat_template: bool = True +): + """Tokenize a batch of data with `tokenizer` according to `args`.""" + kwargs = dict( + return_attention_mask=False, + return_length=True, + truncation=args.truncation, + ) + if args.completion_column: + # We're dealing with a prompt-completion dataset + convos = [ + [ + {"role": "user", "content": assert_type(str, prompt)}, + {"role": "assistant", "content": assert_type(str, resp)}, + ] + for prompt, resp in zip( + batch[args.prompt_column], batch[args.completion_column] + ) + ] + elif args.conversation_column: + # We're dealing with a conversation dataset + convos = assert_type(list, batch[args.conversation_column]) + else: + # We're dealing with vanilla next-token prediction + return tokenizer(batch[args.prompt_column], **kwargs) + + strings = convos + + encodings = tokenizer(strings, **kwargs) + labels_list: list[list[int]] = [] + + for i, convo in enumerate(convos): + # Find the spans of the assistant's responses in the tokenized output + pos = 0 + spans: list[tuple[int, int]] = [] + + for msg in convo: + if msg["role"] != "assistant": + continue + + ans = msg["content"] + start = strings[i].rfind(ans, pos) + if start < 0: + raise RuntimeError( + "Failed to find completion in the chat-formatted conversation. " + "Make sure the chat template does not alter the completion, e.g. 
" + "by removing leading whitespace." + ) + + # move past this match + pos = start + len(ans) + + start_token = encodings.char_to_token(i, start) + end_token = encodings.char_to_token(i, pos) + spans.append((start_token, end_token)) + + # Labels are -100 everywhere except where the assistant's response is + tokens = encodings["input_ids"][i] + labels = [-100] * len(tokens) + for start, end in spans: + if start is not None and end is not None: + labels[start:end] = tokens[start:end] + + labels_list.append(labels) + + return dict(**encodings, labels=labels_list) + + def tokenize_mcqa( batch: dict, *, @@ -57,6 +150,8 @@ def tokenize_mcqa( Codex wrote this. """ + # TODO integrate custom masking into tokenize if necessary + return tokenize(batch, args=args, tokenizer=tokenizer, apply_chat_template=False) if not tokenizer.is_fast: raise ValueError("Fast tokenizer required for answer span alignment.") @@ -116,8 +211,7 @@ def tokenize_mcqa( if not supervised_indices: raise RuntimeError( - "Failed to align answer text with any tokens.\n" - f"Example:\n{text}" + "Failed to align answer text with any tokens.\n" f"Example:\n{text}" ) labels.append(example_labels) @@ -192,65 +286,133 @@ def build_mcqa_index(cfg: IndexConfig, ds_path: str): pass -def assemble_query(query_ds: Dataset, run_path: str): - mmap = load_gradients(run_path) - - target_modules = set(mmap.dtype.names) +def assemble_query(query_ds: Dataset, run_path: str, assembled_dataset_path: str): + structured_mmap = load_gradients(run_path) + mmap_dtype = structured_mmap.dtype - print(f"Full projection dim: {len(target_modules) * 64}") - - mmap = structured_to_unstructured(mmap) + # Copy into memory + gradient_tensor = torch.tensor(structured_to_unstructured(structured_mmap)).to( + torch.float32 + ) - print("mmap sum", torch.from_numpy(mmap).sum()) + print("mmap sum", gradient_tensor.sum()) + print("mmap sum", gradient_tensor.abs().sum()) # Group mmap gradient rows by the subset they came from subset_gradients = defaultdict(list) - for mmap_row, ds_row in zip(mmap, query_ds): - subset_gradients[ds_row["subset"]].append(torch.from_numpy(mmap_row)) + for grads_row, ds_row in zip(gradient_tensor, query_ds): + subset_gradients[ds_row["subset"]].append(grads_row) - subset_mean_gradients = {} + subset_mean_gradients = {"overall": gradient_tensor.mean(dim=0)} for subset, gradients in subset_gradients.items(): - print(f"Subset {subset} has {len(gradients)} gradients") - print(f"Computing mean gradient for subset {subset}") mean_gradient = torch.stack(gradients).mean(dim=0) - print(f"Mean gradient for subset {subset}: {mean_gradient.mean(), mean_gradient.std(), mean_gradient.shape}") subset_mean_gradients[subset] = mean_gradient + # Copy everything from the origin run path to the new path + # except data.hf and gradients.bin + os.makedirs(assembled_dataset_path, exist_ok=True) + for file in os.listdir(run_path): + if file != "data.hf" and file != "gradients.bin": + dest = Path(assembled_dataset_path) / file + shutil.copy(Path(run_path) / file, dest) + + # Write the mean queries to data.hf + # subset_mean_gradients["overall"] = overall_mean_gradient + # means_dataset = Dataset.from_dict(subset_mean_gradients) + # means_dataset.save_to_disk(Path(assembled_dataset_path) / "data.hf") + + mean_grad_stack = torch.stack(list(subset_mean_gradients.values())) + first_query_grad = gradient_tensor[0].unsqueeze(0).expand_as(mean_grad_stack) + cosine_sims = torch.nn.functional.cosine_similarity( + mean_grad_stack, first_query_grad, dim=1 + ) + + if 
torch.any(cosine_sims <= 0.09): + raise ValueError( + f"Cosine similarity between mean gradients and the first query gradient " + f"is not greater than 0.09. Cosine sims: {cosine_sims}" + ) + else: + print(f"Cosine sims: {cosine_sims}") + + # Create_index with the 7 gradients + index_dtype = np.float16 - overall_mean_gradient = torch.from_numpy(mmap).mean(dim=0) - print(f"Overall mean gradient: {overall_mean_gradient.mean(), overall_mean_gradient.std(), overall_mean_gradient.shape}") + mean_gradients_unstructured = np.stack( + [v.numpy() for v in subset_mean_gradients.values()], + axis=0, + ).astype(index_dtype) + mean_gradients_structured = unstructured_to_structured( + mean_gradients_unstructured, mmap_dtype + ) - breakpoint() + grad_sizes = {} + for name in mmap_dtype.names: + field_dtype = mmap_dtype.fields[name][0] + subdtype = field_dtype.subdtype + assert subdtype is not None - # Maybe just overwrite the original dataset with the means for each - # subset and the overall mean + _, shape = subdtype + grad_sizes[name] = int(np.prod(shape)) + + index_grads = create_index( + str(assembled_dataset_path), len(subset_mean_gradients), grad_sizes, index_dtype + ) + index_grads[:] = mean_gradients_structured + index_grads.flush() + + mean_grad_stack = torch.stack(list(subset_mean_gradients.values())) + first_query_grad = gradient_tensor[1].unsqueeze(0).expand_as(mean_grad_stack) + cosine_sims = torch.nn.functional.cosine_similarity( + mean_grad_stack, first_query_grad, dim=1 + ) + if torch.any(cosine_sims <= 0.09): + raise ValueError( + f"Cosine similarity between mean gradients and the first query gradient " + f"is not greater than 0.09. Cosine sims: {cosine_sims}" + ) + else: + print(f"Cosine sims: {cosine_sims}") def main(): + projection_dim = 64 + model_name = "EleutherAI/deep_ignorance_pretraining_baseline_small" + ds_path = f"runs/ds_wmdp_bio_robust_mcqa_{projection_dim}" + index_path = f"runs/wmdp_bio_robust_mcqa_means_{projection_dim}" + assembled_dataset_path = f"runs/mean_biorisk_queries_{projection_dim}" + mcqa_ds = assert_type(Dataset, load_mcqa_dataset()) - ds_path = "runs/ds_wmdp_bio_robust_mcqa" mcqa_ds.save_to_disk(ds_path) data_config = DataConfig( dataset=ds_path, + # prompt_column="prompt", + # completion_column="completion", prompt_column="text", ) cfg = IndexConfig( - run_path="runs/wmdp_bio_robust_mcqa_means", + run_path=index_path, save_index=True, save_processor=True, - precision="fp32", + precision="fp16", data=data_config, fsdp=True, - model="EleutherAI/deep-ignorance-unfiltered", - projection_dim=64, + model=model_name, + projection_dim=projection_dim, reshape_to_square=True, ) build_mcqa_index(cfg, ds_path) - assemble_query(mcqa_ds, cfg.run_path) + # Trackstar uses 2**16 with an 8B model + # We are collecting gradients for a ~2.7B model + # We are using ~2**13 I think + modules = set(load_gradients(cfg.run_path).dtype.names) + print(f"Full projection dim: {len(modules) * cfg.projection_dim}") + + assemble_query(mcqa_ds, cfg.run_path, assembled_dataset_path) if __name__ == "__main__": From e33c324f79340da19162597da6c80cdcebbe8210 Mon Sep 17 00:00:00 2001 From: Lucia Quirke Date: Mon, 20 Oct 2025 02:35:11 +0200 Subject: [PATCH 3/5] save progress --- bergson/__main__.py | 20 +++- bergson/attributor.py | 9 +- bergson/collection.py | 3 +- bergson/data.py | 2 +- bergson/faiss_index.py | 160 +++++++++++++++++++------------ bergson/query.py | 36 +++++++ bergson/static_query.py | 190 +++++++++++++++++++++++++++++++++++++ examples/assemble_query.py | 73 +++++++------- 8 
files changed, 393 insertions(+), 100 deletions(-) create mode 100644 bergson/static_query.py diff --git a/bergson/__main__.py b/bergson/__main__.py index 27cf7a8..d76b304 100644 --- a/bergson/__main__.py +++ b/bergson/__main__.py @@ -7,6 +7,24 @@ from .build import build_gradient_dataset from .data import IndexConfig, QueryConfig from .query import query_gradient_dataset +from .static_query import query_gradient_dataset as query_static_gradient_dataset + + +@dataclass +class StaticQuery: + """Query an on-disk gradient index.""" + + scores_path: str + + query_cfg: QueryConfig + + index_cfg: IndexConfig + + k: int | None = None + + def execute(self): + """Query an on-disk gradient index.""" + query_static_gradient_dataset(self.scores_path, self.query_cfg, self.index_cfg, self.k) @dataclass @@ -51,7 +69,7 @@ def execute(self): class Main: """Routes to the subcommands.""" - command: Union[Build, Query] + command: Union[Build, Query, StaticQuery] def execute(self): """Run the script.""" diff --git a/bergson/attributor.py b/bergson/attributor.py index eb2c1af..60859c2 100644 --- a/bergson/attributor.py +++ b/bergson/attributor.py @@ -74,8 +74,11 @@ def __init__( self.grads[name] /= norm def search( - self, queries: dict[str, Tensor], k: int, modules: list[str] | None = None - ) -> tuple[Tensor, Tensor]: + self, + queries: dict[str, Tensor], + k: int | None, + modules: list[str] | None = None, + ): """ Search for the `k` nearest examples in the index based on the query or queries. @@ -112,7 +115,7 @@ def search( ) modules = modules or list(q.keys()) - k = min(k, self.N) + k = min(k or self.N, self.N) scores = torch.stack( [q[name] @ self.grads[name].mT for name in modules], dim=-1 diff --git a/bergson/collection.py b/bergson/collection.py index 7ef81da..b40f1f5 100644 --- a/bergson/collection.py +++ b/bergson/collection.py @@ -30,6 +30,7 @@ def collect_gradients( save_processor: bool = True, drop_columns: bool = False, query_callback: Callable[[dict[str, torch.Tensor]], torch.Tensor] | None = None, + num_scores: int = 1, ): """ Compute projected gradients using a subset of the dataset. @@ -95,7 +96,7 @@ def callback(name: str, g: torch.Tensor): fill_value=0.0, ) per_doc_scores = torch.full( - (len(data),), + (len(data), num_scores), device=model.device, dtype=dtype, fill_value=0.0, diff --git a/bergson/data.py b/bergson/data.py index 6775063..8ce0a26 100644 --- a/bergson/data.py +++ b/bergson/data.py @@ -75,7 +75,7 @@ class QueryConfig: query_path: str = "" """Path to the existing query index.""" - score: Literal["mean", "nearest"] = "mean" + score: Literal["mean", "nearest", "individual"] = "mean" """Method for scoring the gradients with the query. 
If mean gradients will be scored by their similarity with the mean query gradients, otherwise by the most similar query gradient.""" diff --git a/bergson/faiss_index.py b/bergson/faiss_index.py index 87f3a72..9be0082 100644 --- a/bergson/faiss_index.py +++ b/bergson/faiss_index.py @@ -3,7 +3,7 @@ import os from dataclasses import dataclass from pathlib import Path -from time import time +from time import perf_counter from types import ModuleType from typing import TYPE_CHECKING, Protocol @@ -95,7 +95,7 @@ def load_shard(shard_dir: str) -> np.memmap: yield load_shard(root_dir) else: for shard_path in sorted(root_path.iterdir()): - if shard_path.is_dir(): + if shard_path.is_dir() and "shard" in shard_path.name: yield load_shard(str(shard_path)) @@ -135,10 +135,12 @@ def index_to_device(index: Index, device: str) -> Index: class FaissIndex: - """FAISS index.""" + """Shard-based FAISS index.""" shards: list[Index] + faiss_cfg: FaissConfig + def __init__(self, path: str, faiss_cfg: FaissConfig, device: str, unit_norm: bool): faiss = _require_faiss() @@ -152,75 +154,110 @@ def __init__(self, path: str, faiss_cfg: FaissConfig, device: str, unit_norm: bo f"{'_unit_norm' if unit_norm else ''}" ) ) + faiss_path.mkdir(exist_ok=True, parents=True) - if not (faiss_path.exists() and any(faiss_path.iterdir())): + if not any(faiss_path.iterdir()): print("Building FAISS index...") - start = time() + start = perf_counter() + + root_path = Path(path) + if (root_path / "info.json").exists(): + info_paths = [root_path / "info.json"] + else: + info_paths = [ + shard_path / "info.json" + for shard_path in sorted(root_path.iterdir()) + if shard_path.is_dir() and (shard_path / "info.json").exists() + ] + + if not info_paths: + raise FileNotFoundError(f"No gradient metadata found under {path}") + + total_grads = sum( + [json.load(open(info_path))["num_grads"] for info_path in info_paths] + ) - faiss_path.mkdir(exist_ok=True, parents=True) + assert faiss_cfg.num_shards <= total_grads and faiss_cfg.num_shards > 0 - num_dataset_shards = len(list(Path(path).iterdir())) - shards_per_index = math.ceil(num_dataset_shards / faiss_cfg.num_shards) + # Set the number of grads for each faiss index shard + base_shard_size = total_grads // faiss_cfg.num_shards + remainder = total_grads % faiss_cfg.num_shards + shard_sizes = [base_shard_size] * (faiss_cfg.num_shards) + shard_sizes[-1] += remainder - dl = gradients_loader(path) - buffer = [] - index_idx = 0 + # Verify all gradients will be consumed + assert ( + sum(shard_sizes) == total_grads + ), f"Shard sizes {shard_sizes} don't sum to total_grads {total_grads}" - for grads in tqdm(dl, desc="Loading gradients"): - if grads.dtype.names is not None: - grads = structured_to_unstructured(grads) + dl = gradients_loader(path) + buffer: list[NDArray] = [] + buffer_size = 0 + shard_idx = 0 - if unit_norm: - grads = normalize_grads(grads, device, faiss_cfg.batch_size) + def build_shard_from_buffer( + buffer_parts: list[NDArray], shard_idx: int + ) -> None: + print(f"Building shard {shard_idx}...") + grads_chunk = np.concatenate(buffer_parts, axis=0) + buffer_parts.clear() - buffer.append(grads) + index = faiss.index_factory( + grads_chunk.shape[1], + faiss_cfg.index_factory, + faiss.METRIC_INNER_PRODUCT, + ) + index = index_to_device(index, device) + if faiss_cfg.max_train_examples is not None: + train_examples = min( + faiss_cfg.max_train_examples, grads_chunk.shape[0] + ) + else: + train_examples = grads_chunk.shape[0] + index.train(grads_chunk[:train_examples]) + 
index.add(grads_chunk) - if len(buffer) == shards_per_index: - # Build index shard - print(f"Building shard {index_idx}...") + del grads_chunk - grads = np.concatenate(buffer, axis=0) - buffer = [] + index = index_to_device(index, "cpu") + faiss.write_index(index, str(faiss_path / f"{shard_idx}.faiss")) - index = faiss.index_factory( - grads.shape[1], - faiss_cfg.index_factory, - faiss.METRIC_INNER_PRODUCT, - ) - index = index_to_device(index, device) - train_examples = faiss_cfg.max_train_examples or grads.shape[0] - index.train(grads[:train_examples]) - index.add(grads) + for grads in tqdm(dl, desc="Loading gradients"): + grads = structured_to_unstructured(grads) + if unit_norm: + grads = normalize_grads(grads, device, faiss_cfg.batch_size) - # Write index to disk - del grads - index = index_to_device(index, "cpu") - faiss.write_index(index, str(faiss_path / f"{index_idx}.faiss")) + batch_idx = 0 + batch_size = grads.shape[0] + while batch_idx < batch_size and shard_idx < faiss_cfg.num_shards: + remaining_in_shard = shard_sizes[shard_idx] - buffer_size + take = min(remaining_in_shard, batch_size - batch_idx) - index_idx += 1 + if take > 0: + buffer.append(grads[batch_idx : batch_idx + take]) + buffer_size += take + batch_idx += take - if buffer: - grads = np.concatenate(buffer, axis=0) - buffer = [] - index = faiss.index_factory( - grads.shape[1], faiss_cfg.index_factory, faiss.METRIC_INNER_PRODUCT - ) - index = index_to_device(index, device) - index.train(grads) - index.add(grads) + if buffer_size == shard_sizes[shard_idx]: + build_shard_from_buffer(buffer, shard_idx) + buffer = [] + buffer_size = 0 + shard_idx += 1 - # Write index to disk del grads - index = index_to_device(index, "cpu") - faiss.write_index(index, str(faiss_path / f"{index_idx}.faiss")) - print(f"Built index in {(time() - start) / 60:.2f} minutes.") - del buffer + assert shard_idx == faiss_cfg.num_shards + print(f"Built index in {(perf_counter() - start) / 60:.2f} minutes.") + + shard_paths = sorted( + (c for c in faiss_path.glob("*.faiss") if c.stem.isdigit()), + key=lambda p: int(p.stem), + ) shards = [] - for i in range(faiss_cfg.num_shards): + for shard_path in shard_paths: shard = faiss.read_index( - str(faiss_path / f"{i}.faiss"), + str(shard_path), faiss.IO_FLAG_MMAP | faiss.IO_FLAG_READ_ONLY, ) if not faiss_cfg.mmap_index: @@ -228,21 +265,26 @@ def __init__(self, path: str, faiss_cfg: FaissConfig, device: str, unit_norm: bo shards.append(shard) + # TODO raise? + if len(shards) != faiss_cfg.num_shards: + faiss_cfg.num_shards = len(shards) + self.shards = shards - def search(self, q: NDArray, k: int) -> tuple[NDArray, NDArray]: + def search(self, q: NDArray, k: int | None) -> tuple[NDArray, NDArray]: """Note: if fewer than `k` examples are found FAISS will return items - with the index -1 and the maximum negative distance.""" + with the index -1 and the maximum negative distance. 
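Results from all shards are merged and reranked before being returned.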
If `k` is `None`, + all examples will be returned.""" shard_distances = [] shard_indices = [] offset = 0 - for index in self.shards: - index.nprobe = self.faiss_cfg.nprobe - distances, indices = index.search(q, k) + for shard in self.shards: + shard.nprobe = self.faiss_cfg.nprobe + distances, indices = shard.search(q, k or shard.ntotal) indices += offset - offset += index.ntotal + offset += shard.ntotal shard_distances.append(distances) shard_indices.append(indices) @@ -252,7 +294,7 @@ def search(self, q: NDArray, k: int) -> tuple[NDArray, NDArray]: # Rerank results overfetched from multiple shards if len(self.shards) > 1: - topk_indices = np.argsort(distances, axis=1)[:, :k] + topk_indices = np.argsort(distances, axis=1)[:, : k or self.ntotal] indices = indices[np.arange(indices.shape[0])[:, None], topk_indices] distances = distances[np.arange(distances.shape[0])[:, None], topk_indices] diff --git a/bergson/query.py b/bergson/query.py index 336990b..5699b97 100644 --- a/bergson/query.py +++ b/bergson/query.py @@ -94,6 +94,36 @@ def precondition(batch): return query_ds +def get_individual_query( + query_ds: Dataset, query_cfg: QueryConfig, device: torch.device, dtype: torch.dtype +): + """ + Compute the individual query and return a callback function that scores gradients + according to their inner products or cosine similarities with the individual queries. + Requires a custom setup in the score saving code. + """ + queries = torch.cat([query_ds[:][name] for name in query_cfg.modules], dim=1).to( + device=device, dtype=dtype + ) + + if query_cfg.unit_normalize: + queries /= queries.norm(dim=1, keepdim=True) + + # Assert on device + assert queries.device == device + + def callback(mod_grads: dict[str, torch.Tensor]): + grads = torch.cat([mod_grads[name] for name in query_cfg.modules], dim=1) + if query_cfg.unit_normalize: + grads /= grads.norm(dim=1, keepdim=True) + + # Return a score for every query + print(grads.device, queries.device) + return grads @ queries.T + + return callback + + def get_mean_query( query_ds: Dataset, query_cfg: QueryConfig, device: torch.device, dtype: torch.dtype ): @@ -312,6 +342,10 @@ def worker( query_callback = get_nearest_query( query_ds, query_cfg, query_device, query_dtype ) + elif query_cfg.score == "individual": + query_callback = get_individual_query( + query_ds, query_cfg, query_device, query_dtype + ) else: raise ValueError(f"Invalid query scoring method: {query_cfg.score}") @@ -332,6 +366,7 @@ def worker( query_callback=query_callback, save_index=index_cfg.save_index, save_processor=index_cfg.save_processor, + num_scores=1 if query_cfg.score == "individual" else len(query_ds), ) else: # Convert each shard to a Dataset then collect its gradients @@ -360,6 +395,7 @@ def flush(): query_callback=query_callback, save_index=index_cfg.save_index, save_processor=index_cfg.save_processor, + num_scores=1 if query_cfg.score == "individual" else len(query_ds), ) buf.clear() shard_id += 1 diff --git a/bergson/static_query.py b/bergson/static_query.py new file mode 100644 index 0000000..5259779 --- /dev/null +++ b/bergson/static_query.py @@ -0,0 +1,190 @@ +import json +import os +from time import perf_counter +from pathlib import Path + +import torch +from datasets import Dataset, IterableDataset +from transformers import AutoTokenizer + +from bergson import Attributor, FaissConfig +from .data import ( + IndexConfig, + QueryConfig, + load_data_string, + load_gradient_dataset, + tokenize, + load_gradients, +) +from .gradients import GradientProcessor + + +def 
get_query_data(query_cfg: QueryConfig): + """ + Load and optionally precondition the query dataset. Preconditioners + may be mixed as described in https://arxiv.org/html/2410.17413v1#S3. + """ + # Collect the query gradients if they don't exist + if not os.path.exists(query_cfg.query_path): + raise FileNotFoundError( + f"Query dataset not found at {query_cfg.query_path}. " + "Please build a query dataset index first." + ) + + # Load the query dataset + with open(os.path.join(query_cfg.query_path, "info.json"), "r") as f: + target_modules = json.load(f)["dtype"]["names"] + + query_ds = load_gradient_dataset(query_cfg.query_path, concatenate_gradients=False) + query_ds = query_ds.with_format( + "torch", columns=target_modules, format_kwargs={"dtype": torch.float64} + ) + + use_q = query_cfg.query_preconditioner_path is not None + use_i = query_cfg.index_preconditioner_path is not None + + if use_q or use_i: + q, i = {}, {} + if use_q: + assert query_cfg.query_preconditioner_path is not None + q = GradientProcessor.load( + query_cfg.query_preconditioner_path, + map_location="cuda", + ).preconditioners + if use_i: + assert query_cfg.index_preconditioner_path is not None + i = GradientProcessor.load( + query_cfg.index_preconditioner_path, map_location="cuda" + ).preconditioners + + mixed_preconditioner = ( + { + k: q[k] * query_cfg.mixing_coefficient + + i[k] * (1 - query_cfg.mixing_coefficient) + for k in q + } + if (q and i) + else (q or i) + ) + mixed_preconditioner = { + k: v.cuda().to(torch.float64) for k, v in mixed_preconditioner.items() + } + + def precondition(batch): + for name in target_modules: + breakpoint() + print(batch[name].shape, mixed_preconditioner[name].shape) + batch[name] = ( + batch[name].cuda().to(torch.float64) @ mixed_preconditioner[name] + ).cpu() + + return batch + + breakpoint() + query_ds = query_ds.map( + precondition, batched=True, batch_size=query_cfg.batch_size + ) + + return query_ds + + +def get_text_rows(indices, scores, index_cfg: IndexConfig): + ds = load_data_string( + index_cfg.data.dataset, index_cfg.data.split, streaming=index_cfg.streaming + ) + + + if not index_cfg.streaming: + assert isinstance(ds, Dataset), "Dataset required for direct selection" + return ds.select(indices) + else: + rows = [] + assert isinstance(ds, IterableDataset), "IterableDataset required for streaming" + # Loop through the dataset and collect the indices + for i, row in enumerate(ds): + if i in indices: + rows.append(row) + return Dataset.from_list(rows) + + +@torch.inference_mode() +def query_gradient_dataset( + query_cfg: QueryConfig, index_cfg: IndexConfig, device="cpu", k: int | None = 50 +): + # In many cases the token_batch_size may be smaller than the max length allowed by + # the model. 
If cfg.data.truncation is True, we use the tokenizer to truncate + tokenizer = AutoTokenizer.from_pretrained( + index_cfg.model, revision=index_cfg.revision + ) + tokenizer.model_max_length = min( + tokenizer.model_max_length, index_cfg.token_batch_size + ) + + # # Do all the data loading and preprocessing on the main process + # ds = load_data_string( + # index_cfg.data.dataset, index_cfg.data.split, streaming=index_cfg.streaming + # ) + + # remove_columns = ds.column_names if index_cfg.drop_columns else None + # ds = ds.map( + # tokenize, + # batched=True, + # fn_kwargs=dict(args=index_cfg.data, tokenizer=tokenizer), + # remove_columns=remove_columns, + # ) + # if index_cfg.data.reward_column: + # assert isinstance(ds, Dataset), "Dataset required for advantage estimation" + # ds = ds.add_column( + # "advantage", + # estimate_advantage(ds, index_cfg.data), + # new_fingerprint="advantage", # type: ignore + # ) + + if not query_cfg.modules: + query_cfg.modules = list(load_gradients(query_cfg.query_path).dtype.names) + + query_ds = get_query_data(query_cfg) + query_ds = query_ds.with_format("torch", columns=query_cfg.modules) + + start = perf_counter() + attr = Attributor( + index_cfg.run_path, + device=device, + faiss_cfg=FaissConfig("IVF1,SQfp16", mmap_index=True, num_shards=5), + unit_norm=query_cfg.unit_normalize, + ) + print(f"Attributor loaded in {perf_counter() - start}") + + # print({name: torch.tensor(query_ds[:][name]) for name in query_cfg.modules}) + + print("Searching...") + start = perf_counter() + scores, indices = attr.search( + { + name: torch.tensor(query_ds[:][name]).to(device) + for name in query_cfg.modules + }, + k, + ) + + print(f"Query time: {perf_counter() - start}") + + data = { + "scores": scores, + "indices": indices, + } + print(data) + print("Max score", scores.max()) + print("Min score", scores.min()) + print("Mean score", scores.mean()) + print("Std score", scores.std()) + + dataset = Dataset.from_dict(data) + dataset.save_to_disk(Path(query_cfg.query_path) / "trial" / "static_query.hf") + + # Get the text rows associated with the top 50 scores + text_ds = get_text_rows(indices, scores, index_cfg) + print(text_ds) + + # Add the scores to the text rows + dataset.save_to_disk(Path(query_cfg.query_path) / "trial" / "static_query.hf") diff --git a/examples/assemble_query.py b/examples/assemble_query.py index 7c2f639..498cf94 100644 --- a/examples/assemble_query.py +++ b/examples/assemble_query.py @@ -28,10 +28,11 @@ from bergson import DataConfig, IndexConfig, load_gradients from bergson.build import dist_worker, estimate_advantage, worker -from bergson.data import create_index +from bergson.data import create_index, load_gradient_dataset from bergson.utils import assert_type + def load_mcqa_dataset(): def map(x): """Many of the questions require the choices to be given to be coherent, e.g. 
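Aside: `create_query_index` below leans on NumPy's structured/unstructured conversions to turn per-module gradient records into a flat matrix and back. A minimal self-contained sketch of that round trip (the module names, shapes, and dtype here are illustrative, not the real index layout):

```python
import numpy as np
from numpy.lib.recfunctions import (
    structured_to_unstructured,
    unstructured_to_structured,
)

# Hypothetical gradient record: two modules, four projected dims each.
dtype = np.dtype([("h.0.mlp", np.float16, (4,)), ("h.1.mlp", np.float16, (4,))])
rows = np.zeros(3, dtype=dtype)

flat = structured_to_unstructured(rows)   # plain (3, 8) array, one row per document
mean = flat.mean(axis=0, keepdims=True)   # (1, 8) mean "query" gradient
back = unstructured_to_structured(mean.astype(np.float16), dtype)
assert back.dtype == dtype and back.shape == (1,)
```

The same pattern is what lets the example write per-subset mean gradients back into a fresh `gradients.bin` via `create_index`.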
@@ -285,8 +286,7 @@ def build_mcqa_index(cfg: IndexConfig, ds_path: str): except Exception: pass - -def assemble_query(query_ds: Dataset, run_path: str, assembled_dataset_path: str): +def create_query_index(query_ds: Dataset, run_path: str, assembled_dataset_path: str, index_dtype: np.dtype): structured_mmap = load_gradients(run_path) mmap_dtype = structured_mmap.dtype @@ -309,17 +309,32 @@ def assemble_query(query_ds: Dataset, run_path: str, assembled_dataset_path: str subset_mean_gradients[subset] = mean_gradient # Copy everything from the origin run path to the new path - # except data.hf and gradients.bin + # except gradients.bin and data.hf os.makedirs(assembled_dataset_path, exist_ok=True) - for file in os.listdir(run_path): - if file != "data.hf" and file != "gradients.bin": - dest = Path(assembled_dataset_path) / file - shutil.copy(Path(run_path) / file, dest) - - # Write the mean queries to data.hf - # subset_mean_gradients["overall"] = overall_mean_gradient - # means_dataset = Dataset.from_dict(subset_mean_gradients) - # means_dataset.save_to_disk(Path(assembled_dataset_path) / "data.hf") + for item in os.listdir(run_path): + if item not in ["gradients.bin", "data.hf"]: + dest = Path(assembled_dataset_path) / item + shutil.copy(Path(run_path) / item, dest) + + + if (Path(assembled_dataset_path) / "data.hf").exists(): + if (Path(assembled_dataset_path) / "data.hf").is_file(): + (Path(assembled_dataset_path) / "data.hf").unlink() + else: + shutil.rmtree(Path(assembled_dataset_path) / "data.hf") + + # Write structured mean queries to data.hf + np_mean_grads = np.stack([item.numpy() for item in list(subset_mean_gradients.values())], axis=0) + structured_np_mean_grads = unstructured_to_structured(np_mean_grads, mmap_dtype) + # data = [ + # {name: structured_np_mean_grads[name][i].tolist() for name in mmap_dtype.names} + # for i in range(structured_np_mean_grads.shape[0]) + # ] + + means_dataset = Dataset.from_dict({ + "scores": [0.] * len(subset_mean_gradients), + }) + means_dataset.save_to_disk(Path(assembled_dataset_path) / "data.hf") mean_grad_stack = torch.stack(list(subset_mean_gradients.values())) first_query_grad = gradient_tensor[0].unsqueeze(0).expand_as(mean_grad_stack) @@ -327,25 +342,7 @@ def assemble_query(query_ds: Dataset, run_path: str, assembled_dataset_path: str mean_grad_stack, first_query_grad, dim=1 ) - if torch.any(cosine_sims <= 0.09): - raise ValueError( - f"Cosine similarity between mean gradients and the first query gradient " - f"is not greater than 0.09. 
Cosine sims: {cosine_sims}" - ) - else: - print(f"Cosine sims: {cosine_sims}") - - # Create_index with the 7 gradients - index_dtype = np.float16 - - mean_gradients_unstructured = np.stack( - [v.numpy() for v in subset_mean_gradients.values()], - axis=0, - ).astype(index_dtype) - mean_gradients_structured = unstructured_to_structured( - mean_gradients_unstructured, mmap_dtype - ) - + # Assemble grad sizes grad_sizes = {} for name in mmap_dtype.names: field_dtype = mmap_dtype.fields[name][0] @@ -355,12 +352,18 @@ def assemble_query(query_ds: Dataset, run_path: str, assembled_dataset_path: str _, shape = subdtype grad_sizes[name] = int(np.prod(shape)) + # Create and populate the index index_grads = create_index( str(assembled_dataset_path), len(subset_mean_gradients), grad_sizes, index_dtype ) - index_grads[:] = mean_gradients_structured + structured_mean_grads = unstructured_to_structured(np_mean_grads.astype(index_dtype), mmap_dtype) + index_grads[:] = unstructured_to_structured( + np_mean_grads.astype(index_dtype), mmap_dtype + ) index_grads.flush() + ds = load_gradient_dataset(assembled_dataset_path) + mean_grad_stack = torch.stack(list(subset_mean_gradients.values())) first_query_grad = gradient_tensor[1].unsqueeze(0).expand_as(mean_grad_stack) cosine_sims = torch.nn.functional.cosine_similarity( @@ -376,7 +379,7 @@ def assemble_query(query_ds: Dataset, run_path: str, assembled_dataset_path: str def main(): - projection_dim = 64 + projection_dim = 16 model_name = "EleutherAI/deep_ignorance_pretraining_baseline_small" ds_path = f"runs/ds_wmdp_bio_robust_mcqa_{projection_dim}" index_path = f"runs/wmdp_bio_robust_mcqa_means_{projection_dim}" @@ -412,7 +415,7 @@ def main(): modules = set(load_gradients(cfg.run_path).dtype.names) print(f"Full projection dim: {len(modules) * cfg.projection_dim}") - assemble_query(mcqa_ds, cfg.run_path, assembled_dataset_path) + create_query_index(mcqa_ds, cfg.run_path, assembled_dataset_path, index_dtype=np.float16) if __name__ == "__main__": From 0d3ad9e9d34fd1851729ac1debeee28c736a9373 Mon Sep 17 00:00:00 2001 From: Lucia Quirke Date: Mon, 20 Oct 2025 05:09:42 +0200 Subject: [PATCH 4/5] Add static query and query analysis --- bergson/__main__.py | 4 +- bergson/collection.py | 23 +---- bergson/data.py | 3 + bergson/query.py | 94 +++++++++++++++-- bergson/query_existing.py | 70 +++++++++++++ bergson/static_query.py | 190 ---------------------------------- examples/query_analysis.py | 202 +++++++++++++++++++++++++++++++++++++ 7 files changed, 366 insertions(+), 220 deletions(-) create mode 100644 bergson/query_existing.py delete mode 100644 bergson/static_query.py create mode 100644 examples/query_analysis.py diff --git a/bergson/__main__.py b/bergson/__main__.py index d76b304..93db26f 100644 --- a/bergson/__main__.py +++ b/bergson/__main__.py @@ -7,7 +7,7 @@ from .build import build_gradient_dataset from .data import IndexConfig, QueryConfig from .query import query_gradient_dataset -from .static_query import query_gradient_dataset as query_static_gradient_dataset +from .query_existing import query_existing @dataclass @@ -24,7 +24,7 @@ class StaticQuery: def execute(self): """Query an on-disk gradient index.""" - query_static_gradient_dataset(self.scores_path, self.query_cfg, self.index_cfg, self.k) + query_existing(self.scores_path, self.query_cfg, self.index_cfg, self.k) @dataclass diff --git a/bergson/collection.py b/bergson/collection.py index b40f1f5..6b615c0 100644 --- a/bergson/collection.py +++ b/bergson/collection.py @@ -1,5 +1,5 @@ import math -from 
typing import Callable, Literal +from typing import Literal import numpy as np import torch @@ -12,6 +12,7 @@ from .data import create_index, pad_and_tensor from .gradients import AttentionConfig, GradientCollector, GradientProcessor from .peft import set_peft_enabled +from .query import Query def collect_gradients( @@ -29,8 +30,7 @@ def collect_gradients( save_index: bool = True, save_processor: bool = True, drop_columns: bool = False, - query_callback: Callable[[dict[str, torch.Tensor]], torch.Tensor] | None = None, - num_scores: int = 1, + query: Query | None = None, ): """ Compute projected gradients using a subset of the dataset. @@ -95,12 +95,6 @@ def callback(name: str, g: torch.Tensor): dtype=dtype, fill_value=0.0, ) - per_doc_scores = torch.full( - (len(data), num_scores), - device=model.device, - dtype=dtype, - fill_value=0.0, - ) for indices in tqdm(batches, disable=rank != 0, desc="Building index"): batch = data[indices] @@ -157,9 +151,8 @@ def callback(name: str, g: torch.Tensor): for module_name in mod_grads.keys(): grad_buffer[module_name][indices] = mod_grads[module_name].numpy() - if query_callback is not None: - scores = query_callback(mod_grads) - per_doc_scores[indices] = scores.detach().type_as(per_doc_scores) + if query: + query(indices, mod_grads) mod_grads.clear() per_doc_losses[indices] = losses.detach().type_as(per_doc_losses) @@ -179,12 +172,6 @@ def callback(name: str, g: torch.Tensor): feature=Value("float16" if dtype == torch.float16 else "float32"), new_fingerprint="loss", ) - data = data.add_column( - "scores", - per_doc_scores.cpu().numpy(), - feature=Value("float16" if dtype == torch.float16 else "float32"), - new_fingerprint="scores", - ) data.save_to_disk(path + "/data.hf") if save_processor: diff --git a/bergson/data.py b/bergson/data.py index 8ce0a26..8a19746 100644 --- a/bergson/data.py +++ b/bergson/data.py @@ -80,6 +80,9 @@ class QueryConfig: gradients will be scored by their similarity with the mean query gradients, otherwise by the most similar query gradient.""" + scores_path: str = "" + """Path to the directory where query scores should be written.""" + save_processor: bool = True """Whether to write the query dataset gradient processor to disk.""" diff --git a/bergson/query.py b/bergson/query.py index 5699b97..9bda82a 100644 --- a/bergson/query.py +++ b/bergson/query.py @@ -1,8 +1,9 @@ import json import os import socket +import uuid from datetime import timedelta -from typing import cast +from typing import Callable, cast import torch import torch.distributed as dist @@ -34,6 +35,57 @@ from .utils import assert_type, get_layer_list +class Query: + """ + Wraps a query scoring callback and stores the resulting scores in a tensor. 
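+    Scores are flushed to `scores_path` once every document has been scored.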
+ """ + + def __init__( + self, + query_callback: Callable[[dict[str, torch.Tensor]], torch.Tensor], + num_items: int, + num_scores: int, + scores_path: str, + *, + dtype: torch.dtype = torch.float32, + device: torch.device | str = "cpu", + ): + self._query_callback = query_callback + self._scores_path = scores_path + self.scores = torch.zeros((num_items, num_scores), dtype=dtype, device=device) + self.num_written = 0 + + def __call__(self, indices: list[int], mod_grads: dict[str, torch.Tensor]): + scores = self._query_callback(mod_grads).detach() + if scores.ndim == 1: + scores = scores.unsqueeze(-1) + assert scores.shape[0] == len(indices) + assert scores.shape[1] == self.scores.shape[1] + + scores = scores.to(device=self.scores.device, dtype=self.scores.dtype) + self.scores[indices] = scores + self.num_written += len(indices) + + if self.num_written >= len(self.scores): + self.flush() + + def flush(self): + dataset = Dataset.from_dict( + { + "scores": self.scores.cpu().numpy().tolist(), + } + ) + try: + dataset.save_to_disk(self._scores_path) + except Exception as e: + # Handle collisions with existing datasets + print(f"Error writing scores to disk: {e}") + random_hash = str(uuid.uuid4())[:8] + alternate_path = self._scores_path.replace(".hf", f"_{random_hash}.hf") + print(f"Writing to alternate path: {alternate_path}") + dataset.save_to_disk(alternate_path) + + def get_query_data(query_cfg: QueryConfig): """ Load and optionally precondition the query dataset. Preconditioners @@ -99,8 +151,8 @@ def get_individual_query( ): """ Compute the individual query and return a callback function that scores gradients - according to their inner products or cosine similarities with the individual queries. - Requires a custom setup in the score saving code. + according to their inner products or cosine similarities with the individual + queries. 
""" queries = torch.cat([query_ds[:][name] for name in query_cfg.modules], dim=1).to( device=device, dtype=dtype @@ -333,24 +385,40 @@ def worker( query_ds = query_ds.with_format("torch", columns=query_cfg.modules) + print(f"Model device: {model.device}") query_device = torch.device(f"cuda:{rank}") query_dtype = dtype if dtype != "auto" else torch.float16 if query_cfg.score == "mean": - query_callback = get_mean_query(query_ds, query_cfg, query_device, query_dtype) + base_query_callback = get_mean_query( + query_ds, query_cfg, query_device, query_dtype + ) + num_scores = 1 elif query_cfg.score == "nearest": - query_callback = get_nearest_query( + base_query_callback = get_nearest_query( query_ds, query_cfg, query_device, query_dtype ) + num_scores = 1 elif query_cfg.score == "individual": - query_callback = get_individual_query( + base_query_callback = get_individual_query( query_ds, query_cfg, query_device, query_dtype ) + num_scores = len(query_ds) else: raise ValueError(f"Invalid query scoring method: {query_cfg.score}") + scores_dtype = torch.float32 if model.dtype == torch.float32 else torch.float16 + if isinstance(ds, Dataset): batches = allocate_batches(ds["length"][:], index_cfg.token_batch_size) + query = Query( + base_query_callback, + len(ds), + num_scores, + query_cfg.scores_path, + dtype=scores_dtype, + device=query_device, + ) collect_gradients( model, ds, @@ -363,10 +431,9 @@ def worker( target_modules=target_modules, attention_cfgs=attention_cfgs, drop_columns=index_cfg.drop_columns, - query_callback=query_callback, + query=query, save_index=index_cfg.save_index, save_processor=index_cfg.save_processor, - num_scores=1 if query_cfg.score == "individual" else len(query_ds), ) else: # Convert each shard to a Dataset then collect its gradients @@ -380,6 +447,14 @@ def flush(): batches = allocate_batches( ds_shard["length"][:], index_cfg.token_batch_size ) + query = Query( + base_query_callback, + len(ds_shard), + num_scores, + query_cfg.scores_path, + dtype=scores_dtype, + device=query_device, + ) collect_gradients( model, ds_shard, @@ -392,10 +467,9 @@ def flush(): target_modules=target_modules, attention_cfgs=attention_cfgs, drop_columns=index_cfg.drop_columns, - query_callback=query_callback, + query=query, save_index=index_cfg.save_index, save_processor=index_cfg.save_processor, - num_scores=1 if query_cfg.score == "individual" else len(query_ds), ) buf.clear() shard_id += 1 diff --git a/bergson/query_existing.py b/bergson/query_existing.py new file mode 100644 index 0000000..3720636 --- /dev/null +++ b/bergson/query_existing.py @@ -0,0 +1,70 @@ +from pathlib import Path +from time import perf_counter + +import torch +from datasets import Dataset + +from bergson import Attributor, FaissConfig + +from .data import ( + IndexConfig, + QueryConfig, + load_gradients, +) +from .query import get_query_data + + +@torch.inference_mode() +def query_existing( + scores_path: str, + query_cfg: QueryConfig, + index_cfg: IndexConfig, + k: int | None, + device="cpu", +): + if not query_cfg.modules: + query_cfg.modules = list(load_gradients(query_cfg.query_path).dtype.names) + + query_ds = get_query_data(query_cfg) + query_ds = query_ds.with_format("torch", columns=query_cfg.modules) + + start = perf_counter() + attr = Attributor( + index_cfg.run_path, + device=device, + faiss_cfg=FaissConfig(mmap_index=True, num_shards=5), + unit_norm=query_cfg.unit_normalize, + ) + print(f"Attributor loaded in {perf_counter() - start}") + + print("Searching...") + start = perf_counter() + + search_inputs 
= { + name: torch.as_tensor(query_ds[:][name]).to(device) + for name in query_cfg.modules + } + + scores, indices = attr.search(search_inputs, k) + + print(f"Query time: {perf_counter() - start}") + + scores_tensor = torch.as_tensor(scores).cpu() + indices_tensor = torch.as_tensor(indices).to(torch.int64).cpu() + + num_queries, num_scores = scores_tensor.shape + print( + f"Collected scores for {num_queries} queries " f"with {num_scores} scores each." + ) + + output_path = Path(scores_path) / "scores.hf" + output_path.parent.mkdir(parents=True, exist_ok=True) + + dataset = Dataset.from_dict( + { + "scores": scores_tensor.tolist(), + "indices": indices_tensor.tolist(), + } + ) + dataset.save_to_disk(output_path) + print(f"Saved raw search results to {output_path}") diff --git a/bergson/static_query.py b/bergson/static_query.py deleted file mode 100644 index 5259779..0000000 --- a/bergson/static_query.py +++ /dev/null @@ -1,190 +0,0 @@ -import json -import os -from time import perf_counter -from pathlib import Path - -import torch -from datasets import Dataset, IterableDataset -from transformers import AutoTokenizer - -from bergson import Attributor, FaissConfig -from .data import ( - IndexConfig, - QueryConfig, - load_data_string, - load_gradient_dataset, - tokenize, - load_gradients, -) -from .gradients import GradientProcessor - - -def get_query_data(query_cfg: QueryConfig): - """ - Load and optionally precondition the query dataset. Preconditioners - may be mixed as described in https://arxiv.org/html/2410.17413v1#S3. - """ - # Collect the query gradients if they don't exist - if not os.path.exists(query_cfg.query_path): - raise FileNotFoundError( - f"Query dataset not found at {query_cfg.query_path}. " - "Please build a query dataset index first." 
- ) - - # Load the query dataset - with open(os.path.join(query_cfg.query_path, "info.json"), "r") as f: - target_modules = json.load(f)["dtype"]["names"] - - query_ds = load_gradient_dataset(query_cfg.query_path, concatenate_gradients=False) - query_ds = query_ds.with_format( - "torch", columns=target_modules, format_kwargs={"dtype": torch.float64} - ) - - use_q = query_cfg.query_preconditioner_path is not None - use_i = query_cfg.index_preconditioner_path is not None - - if use_q or use_i: - q, i = {}, {} - if use_q: - assert query_cfg.query_preconditioner_path is not None - q = GradientProcessor.load( - query_cfg.query_preconditioner_path, - map_location="cuda", - ).preconditioners - if use_i: - assert query_cfg.index_preconditioner_path is not None - i = GradientProcessor.load( - query_cfg.index_preconditioner_path, map_location="cuda" - ).preconditioners - - mixed_preconditioner = ( - { - k: q[k] * query_cfg.mixing_coefficient - + i[k] * (1 - query_cfg.mixing_coefficient) - for k in q - } - if (q and i) - else (q or i) - ) - mixed_preconditioner = { - k: v.cuda().to(torch.float64) for k, v in mixed_preconditioner.items() - } - - def precondition(batch): - for name in target_modules: - breakpoint() - print(batch[name].shape, mixed_preconditioner[name].shape) - batch[name] = ( - batch[name].cuda().to(torch.float64) @ mixed_preconditioner[name] - ).cpu() - - return batch - - breakpoint() - query_ds = query_ds.map( - precondition, batched=True, batch_size=query_cfg.batch_size - ) - - return query_ds - - -def get_text_rows(indices, scores, index_cfg: IndexConfig): - ds = load_data_string( - index_cfg.data.dataset, index_cfg.data.split, streaming=index_cfg.streaming - ) - - - if not index_cfg.streaming: - assert isinstance(ds, Dataset), "Dataset required for direct selection" - return ds.select(indices) - else: - rows = [] - assert isinstance(ds, IterableDataset), "IterableDataset required for streaming" - # Loop through the dataset and collect the indices - for i, row in enumerate(ds): - if i in indices: - rows.append(row) - return Dataset.from_list(rows) - - -@torch.inference_mode() -def query_gradient_dataset( - query_cfg: QueryConfig, index_cfg: IndexConfig, device="cpu", k: int | None = 50 -): - # In many cases the token_batch_size may be smaller than the max length allowed by - # the model. 
If cfg.data.truncation is True, we use the tokenizer to truncate - tokenizer = AutoTokenizer.from_pretrained( - index_cfg.model, revision=index_cfg.revision - ) - tokenizer.model_max_length = min( - tokenizer.model_max_length, index_cfg.token_batch_size - ) - - # # Do all the data loading and preprocessing on the main process - # ds = load_data_string( - # index_cfg.data.dataset, index_cfg.data.split, streaming=index_cfg.streaming - # ) - - # remove_columns = ds.column_names if index_cfg.drop_columns else None - # ds = ds.map( - # tokenize, - # batched=True, - # fn_kwargs=dict(args=index_cfg.data, tokenizer=tokenizer), - # remove_columns=remove_columns, - # ) - # if index_cfg.data.reward_column: - # assert isinstance(ds, Dataset), "Dataset required for advantage estimation" - # ds = ds.add_column( - # "advantage", - # estimate_advantage(ds, index_cfg.data), - # new_fingerprint="advantage", # type: ignore - # ) - - if not query_cfg.modules: - query_cfg.modules = list(load_gradients(query_cfg.query_path).dtype.names) - - query_ds = get_query_data(query_cfg) - query_ds = query_ds.with_format("torch", columns=query_cfg.modules) - - start = perf_counter() - attr = Attributor( - index_cfg.run_path, - device=device, - faiss_cfg=FaissConfig("IVF1,SQfp16", mmap_index=True, num_shards=5), - unit_norm=query_cfg.unit_normalize, - ) - print(f"Attributor loaded in {perf_counter() - start}") - - # print({name: torch.tensor(query_ds[:][name]) for name in query_cfg.modules}) - - print("Searching...") - start = perf_counter() - scores, indices = attr.search( - { - name: torch.tensor(query_ds[:][name]).to(device) - for name in query_cfg.modules - }, - k, - ) - - print(f"Query time: {perf_counter() - start}") - - data = { - "scores": scores, - "indices": indices, - } - print(data) - print("Max score", scores.max()) - print("Min score", scores.min()) - print("Mean score", scores.mean()) - print("Std score", scores.std()) - - dataset = Dataset.from_dict(data) - dataset.save_to_disk(Path(query_cfg.query_path) / "trial" / "static_query.hf") - - # Get the text rows associated with the top 50 scores - text_ds = get_text_rows(indices, scores, index_cfg) - print(text_ds) - - # Add the scores to the text rows - dataset.save_to_disk(Path(query_cfg.query_path) / "trial" / "static_query.hf") diff --git a/examples/query_analysis.py b/examples/query_analysis.py new file mode 100644 index 0000000..26bbabe --- /dev/null +++ b/examples/query_analysis.py @@ -0,0 +1,202 @@ +import shutil +from argparse import ArgumentParser +from pathlib import Path + +import torch +from datasets import Dataset, IterableDataset, load_from_disk + +from bergson.data import DataConfig, IndexConfig, load_data_string + + +def parse_args(): + parser = ArgumentParser( + description=( + "Aggregate static query results assuming six max-accumulated queries " + "and one mean-based query." 
+ ) + ) + parser.add_argument( + "--trial-dir", + type=Path, + required=True, + help="Directory containing the saved static query results (static_query.hf).", + ) + parser.add_argument( + "--dataset", + type=str, + required=True, + help="HF dataset identifier matching the index build.", + ) + parser.add_argument( + "--split", + type=str, + default="train", + help="Dataset split used for the index (default: train).", + ) + parser.add_argument( + "--subset", + type=str, + default=None, + help="Dataset subset used for the index, if any.", + ) + parser.add_argument( + "--streaming", + action="store_true", + help="Flag indicating the index dataset was streamed.", + ) + parser.add_argument( + "--run-path", + type=str, + default="static-query", + help="Run path label to instantiate IndexConfig (not used for IO).", + ) + parser.add_argument( + "--limit", + type=int, + default=None, + help="Maximum number of rows to keep per strategy (defaults to k).", + ) + parser.add_argument( + "--output-dir", + type=Path, + default=None, + help="Directory to write strategy outputs (defaults to trial directory).", + ) + parser.add_argument( + "--raw-name", + type=str, + default="static_query.hf", + help="Name of the saved raw search dataset inside the trial directory.", + ) + return parser.parse_args() + + +def build_index_config(args) -> IndexConfig: + data_cfg = DataConfig(dataset=args.dataset, split=args.split, subset=args.subset) + return IndexConfig(run_path=args.run_path, data=data_cfg, streaming=args.streaming) + + +def save_strategy(name: str, dataset: Dataset, output_root: Path): + strategy_dir = output_root / name + strategy_dir.mkdir(parents=True, exist_ok=True) + + csv_path = strategy_dir / f"{name}.csv" + dataset.to_csv(csv_path) + + hf_path = strategy_dir / "hf" + if hf_path.exists(): + shutil.rmtree(hf_path) + dataset.save_to_disk(hf_path) + + return csv_path, hf_path + + +def get_text_rows( + indices: list[int], + scores: list[float], + index_cfg: IndexConfig, + score_column: str = "score", +): + ds = load_data_string( + index_cfg.data.dataset, index_cfg.data.split, streaming=index_cfg.streaming + ) + + if not index_cfg.streaming: + assert isinstance(ds, Dataset), "Dataset required for direct selection" + + selected = ds.select(indices) + return selected.add_column(score_column, scores, new_fingerprint="score") + + assert isinstance(ds, IterableDataset), "IterableDataset required for streaming" + + pending_idxs = {idx: score for idx, score in zip(indices, scores)} + data = [] + + for i, row in enumerate(ds): + if i in pending_idxs: + score = pending_idxs.pop(i) + data.append( + { + score_column: score, + **row, + } + ) + if not pending_idxs: + break + + assert not pending_idxs, ( + f"Could not collect all rows for the requested indices." + f" Missing indices: {sorted(pending_idxs)}" + ) + + return Dataset.from_list(data) + + +def main(): + args = parse_args() + + trial_dir = args.trial_dir + output_root = args.output_dir or trial_dir + raw_path = trial_dir / args.raw_name + + if not raw_path.exists(): + raise FileNotFoundError( + f"Static query results not found at {raw_path}. " + "Run bergson.static_query.query_gradient_dataset first." + ) + + raw_ds = load_from_disk(raw_path) + scores_tensor = torch.tensor(raw_ds["scores"], dtype=torch.float32) + indices_tensor = torch.tensor(raw_ds["indices"], dtype=torch.int64) + + if scores_tensor.ndim != 2 or indices_tensor.ndim != 2: + raise ValueError( + "Expected scores and indices to have shape (num_queries, k). 
" + f"Got {scores_tensor.shape} and {indices_tensor.shape}." + ) + + num_queries, num_neighbors = scores_tensor.shape + if num_queries != 7: + raise ValueError( + f"Expected exactly 7 queries (6 for max, 1 for mean); " + f"received {num_queries}." + ) + + limit = args.limit or num_neighbors + + max_scores_by_index: dict[int, float] = {} + for idx, score in zip( + indices_tensor[1:].reshape(-1).tolist(), + scores_tensor[1:].reshape(-1).tolist(), + ): + idx = int(idx) + score = float(score) + best = max_scores_by_index.get(idx) + if best is None or score > best: + max_scores_by_index[idx] = score + + sorted_max = sorted( + max_scores_by_index.items(), key=lambda item: item[1], reverse=True + )[:limit] + max_indices = [idx for idx, _ in sorted_max] + max_scores = [score for _, score in sorted_max] + + mean_indices = indices_tensor[0].tolist()[:limit] + mean_scores = scores_tensor[0].tolist()[:limit] + + index_cfg = build_index_config(args) + max_ds = get_text_rows(max_indices, max_scores, index_cfg, score_column="max_score") + mean_ds = get_text_rows( + mean_indices, mean_scores, index_cfg, score_column="mean_score" + ) + + output_root.mkdir(parents=True, exist_ok=True) + max_csv, max_hf = save_strategy("max_strategy", max_ds, output_root) + mean_csv, mean_hf = save_strategy("mean_strategy", mean_ds, output_root) + + print(f"Saved max strategy outputs to {max_csv} and {max_hf}") + print(f"Saved mean strategy outputs to {mean_csv} and {mean_hf}") + + +if __name__ == "__main__": + main() From 60d2718e383215a78b868f9e61db525ab800892b Mon Sep 17 00:00:00 2001 From: Lucia Quirke Date: Tue, 21 Oct 2025 06:38:47 +0200 Subject: [PATCH 5/5] Fix precision issue in query --- bergson/__main__.py | 15 +- bergson/collection.py | 3 +- bergson/data.py | 58 ++++++- bergson/{query.py => dynamic_query.py} | 98 ++++-------- bergson/faiss_index.py | 13 +- .../{query_existing.py => static_query.py} | 5 +- examples/query_analysis.py | 148 +++++++----------- 7 files changed, 162 insertions(+), 178 deletions(-) rename bergson/{query.py => dynamic_query.py} (86%) rename bergson/{query_existing.py => static_query.py} (94%) diff --git a/bergson/__main__.py b/bergson/__main__.py index 93db26f..36a60b0 100644 --- a/bergson/__main__.py +++ b/bergson/__main__.py @@ -6,16 +6,14 @@ from .build import build_gradient_dataset from .data import IndexConfig, QueryConfig -from .query import query_gradient_dataset -from .query_existing import query_existing +from .dynamic_query import query_gradient_dataset +from .static_query import query_existing @dataclass class StaticQuery: """Query an on-disk gradient index.""" - scores_path: str - query_cfg: QueryConfig index_cfg: IndexConfig @@ -24,7 +22,10 @@ class StaticQuery: def execute(self): """Query an on-disk gradient index.""" - query_existing(self.scores_path, self.query_cfg, self.index_cfg, self.k) + assert self.query_cfg.scores_path + assert self.query_cfg.query_path + + query_existing(self.query_cfg, self.index_cfg, self.k) @dataclass @@ -53,12 +54,14 @@ class Query: def execute(self): """Query the gradient dataset.""" + assert self.query_cfg.scores_path + assert self.query_cfg.query_path if os.path.exists(self.index_cfg.run_path) and self.index_cfg.save_index: raise ValueError( "Index path already exists and save_index is True - " "running this query will overwrite the existing gradients. " - "If you meant to query the existing gradients, use " + "If you meant to query the existing gradients use " "Attributor instead." 
) diff --git a/bergson/collection.py b/bergson/collection.py index 6b615c0..809264b 100644 --- a/bergson/collection.py +++ b/bergson/collection.py @@ -9,10 +9,9 @@ from tqdm.auto import tqdm from transformers import PreTrainedModel -from .data import create_index, pad_and_tensor +from .data import Query, create_index, pad_and_tensor from .gradients import AttentionConfig, GradientCollector, GradientProcessor from .peft import set_peft_enabled -from .query import Query def collect_gradients( diff --git a/bergson/data.py b/bergson/data.py index 8a19746..087d007 100644 --- a/bergson/data.py +++ b/bergson/data.py @@ -2,9 +2,10 @@ import math import os import random +import uuid from dataclasses import dataclass from pathlib import Path -from typing import Literal, Sequence +from typing import Callable, Literal, Sequence import numpy as np import pyarrow as pa @@ -25,6 +26,57 @@ from .utils import assert_type +class Query: + """ + Wraps a query scoring callback and stores the resulting scores in a tensor. + """ + + def __init__( + self, + query_callback: Callable[[dict[str, torch.Tensor]], torch.Tensor], + num_items: int, + num_scores: int, + scores_path: str, + *, + dtype: torch.dtype = torch.float32, + device: torch.device | str = "cpu", + ): + self._query_callback = query_callback + self._scores_path = scores_path + self.scores = torch.zeros((num_items, num_scores), dtype=dtype, device=device) + self.num_written = 0 + + def __call__(self, indices: list[int], mod_grads: dict[str, torch.Tensor]): + scores = self._query_callback(mod_grads).detach() + if scores.ndim == 1: + scores = scores.unsqueeze(-1) + assert scores.shape[0] == len(indices) + assert scores.shape[1] == self.scores.shape[1] + + scores = scores.to(device=self.scores.device, dtype=self.scores.dtype) + self.scores[indices] = scores + self.num_written += len(indices) + + if self.num_written >= len(self.scores): + self.flush() + + def flush(self): + dataset = Dataset.from_dict( + { + "scores": self.scores.cpu().numpy().tolist(), + } + ) + try: + dataset.save_to_disk(self._scores_path) + except Exception as e: + # Handle collisions with existing datasets + print(f"Error writing scores to disk: {e}") + random_hash = str(uuid.uuid4())[:8] + alternate_path = self._scores_path.replace(".hf", f"_{random_hash}.hf") + print(f"Writing to alternate path: {alternate_path}") + dataset.save_to_disk(alternate_path) + + @dataclass class DataConfig: dataset: str = "EleutherAI/SmolLM2-135M-10B" @@ -83,10 +135,6 @@ class QueryConfig: scores_path: str = "" """Path to the directory where query scores should be written.""" - save_processor: bool = True - """Whether to write the query dataset gradient processor - to disk.""" - query_preconditioner_path: str | None = None """Path to a precomputed preconditioner to be applied to the query dataset gradients.""" diff --git a/bergson/query.py b/bergson/dynamic_query.py similarity index 86% rename from bergson/query.py rename to bergson/dynamic_query.py index 9bda82a..550e955 100644 --- a/bergson/query.py +++ b/bergson/dynamic_query.py @@ -1,14 +1,14 @@ import json import os import socket -import uuid from datetime import timedelta -from typing import Callable, cast +from pathlib import Path +from typing import cast import torch import torch.distributed as dist import torch.multiprocessing as mp -from datasets import Dataset, IterableDataset +from datasets import Dataset, IterableDataset, Sequence, Value from peft import PeftConfig, PeftModel from torch.distributed.elastic.multiprocessing import 
DefaultLogsSpecs, start_processes from torch.distributed.fsdp import fully_shard @@ -24,6 +24,7 @@ from .collection import collect_gradients from .data import ( IndexConfig, + Query, QueryConfig, allocate_batches, load_data_string, @@ -35,57 +36,6 @@ from .utils import assert_type, get_layer_list -class Query: - """ - Wraps a query scoring callback and stores the resulting scores in a tensor. - """ - - def __init__( - self, - query_callback: Callable[[dict[str, torch.Tensor]], torch.Tensor], - num_items: int, - num_scores: int, - scores_path: str, - *, - dtype: torch.dtype = torch.float32, - device: torch.device | str = "cpu", - ): - self._query_callback = query_callback - self._scores_path = scores_path - self.scores = torch.zeros((num_items, num_scores), dtype=dtype, device=device) - self.num_written = 0 - - def __call__(self, indices: list[int], mod_grads: dict[str, torch.Tensor]): - scores = self._query_callback(mod_grads).detach() - if scores.ndim == 1: - scores = scores.unsqueeze(-1) - assert scores.shape[0] == len(indices) - assert scores.shape[1] == self.scores.shape[1] - - scores = scores.to(device=self.scores.device, dtype=self.scores.dtype) - self.scores[indices] = scores - self.num_written += len(indices) - - if self.num_written >= len(self.scores): - self.flush() - - def flush(self): - dataset = Dataset.from_dict( - { - "scores": self.scores.cpu().numpy().tolist(), - } - ) - try: - dataset.save_to_disk(self._scores_path) - except Exception as e: - # Handle collisions with existing datasets - print(f"Error writing scores to disk: {e}") - random_hash = str(uuid.uuid4())[:8] - alternate_path = self._scores_path.replace(".hf", f"_{random_hash}.hf") - print(f"Writing to alternate path: {alternate_path}") - dataset.save_to_disk(alternate_path) - - def get_query_data(query_cfg: QueryConfig): """ Load and optionally precondition the query dataset. 
Preconditioners
@@ -103,7 +53,14 @@ def get_query_data(query_cfg: QueryConfig):
         target_modules = json.load(f)["dtype"]["names"]
 
     query_ds = load_gradient_dataset(query_cfg.query_path, concatenate_gradients=False)
-    query_ds = query_ds.with_format("torch", columns=target_modules)
+    query_ds = query_ds.with_format(
+        "torch", columns=target_modules, dtype=torch.float32
+    )
+
+    # Pin each gradient column to a fixed-length float32 sequence so batches
+    # decode with consistent shapes.
+    for col in target_modules:
+        query_ds = query_ds.cast_column(col, Sequence(Value("float32"), length=1024))
 
     use_q = query_cfg.query_preconditioner_path is not None
     use_i = query_cfg.index_preconditioner_path is not None
@@ -131,11 +87,18 @@ def get_query_data(query_cfg: QueryConfig):
         if (q and i)
         else (q or i)
     )
-    mixed_preconditioner = {k: v.cuda() for k, v in mixed_preconditioner.items()}
+    mixed_preconditioner = {
+        k: v.cuda().to(torch.float32) for k, v in mixed_preconditioner.items()
+    }
 
     def precondition(batch):
+        # Apply the mixed preconditioner in float32 to avoid half-precision
+        # loss in the matmul.
         for name in target_modules:
-            batch[name] = (batch[name].cuda() @ mixed_preconditioner[name]).cpu()
+            result = (
+                batch[name].cuda().to(torch.float32) @ mixed_preconditioner[name]
+            )
+            batch[name] = result.cpu()
 
         return batch
 
@@ -143,6 +105,10 @@ def precondition(batch):
         precondition, batched=True, batch_size=query_cfg.batch_size
     )
 
+    # Sanity-check that the preconditioned query gradients are nonzero.
+    for name in target_modules:
+        assert query_ds[0][name].sum() != 0
+
     return query_ds
 
 
@@ -170,7 +135,6 @@ def callback(mod_grads: dict[str, torch.Tensor]):
         grads /= grads.norm(dim=1, keepdim=True)
 
         # Return a score for every query
-        print(grads.device, queries.device)
         return grads @ queries.T
 
     return callback
@@ -385,23 +349,21 @@ def worker(
 
     query_ds = query_ds.with_format("torch", columns=query_cfg.modules)
 
-    print(f"Model device: {model.device}")
-
-    query_device = torch.device(f"cuda:{rank}")
     query_dtype = dtype if dtype != "auto" else torch.float16
 
     if query_cfg.score == "mean":
         base_query_callback = get_mean_query(
-            query_ds, query_cfg, query_device, query_dtype
+            query_ds, query_cfg, model.device, query_dtype
         )
         num_scores = 1
     elif query_cfg.score == "nearest":
         base_query_callback = get_nearest_query(
-            query_ds, query_cfg, query_device, query_dtype
+            query_ds, query_cfg, model.device, query_dtype
         )
         num_scores = 1
     elif query_cfg.score == "individual":
         base_query_callback = get_individual_query(
-            query_ds, query_cfg, query_device, query_dtype
+            query_ds, query_cfg, model.device, query_dtype
         )
         num_scores = len(query_ds)
     else:
@@ -415,9 +377,9 @@ def worker(
             base_query_callback,
             len(ds),
             num_scores,
-            query_cfg.scores_path,
+            str(Path(query_cfg.scores_path) / "scores.hf"),
             dtype=scores_dtype,
-            device=query_device,
+            device=model.device,
         )
         collect_gradients(
             model,
@@ -451,9 +413,11 @@ def flush():
                 base_query_callback,
                 len(ds_shard),
                 num_scores,
-                query_cfg.scores_path,
+                str(
+                    Path(query_cfg.scores_path) / f"shard-{shard_id:05d}" / "scores.hf"
+                ),
                 dtype=scores_dtype,
-                device=query_device,
+                device=model.device,
             )
             collect_gradients(
                 model,
diff --git a/bergson/faiss_index.py b/bergson/faiss_index.py
index 9be0082..e42beb6 100644
--- a/bergson/faiss_index.py
+++ b/bergson/faiss_index.py
@@ -1,5 +1,4 @@
 import json
-import math
 import os
 from dataclasses import dataclass
 from pathlib import Path
@@ -156,7 +155,7 @@ def __init__(self, path: str, faiss_cfg: FaissConfig, device: str, unit_norm: bo
         )
         faiss_path.mkdir(exist_ok=True, parents=True)
 
-        if not any(faiss_path.iterdir()):
+        if len(list(faiss_path.iterdir())) != faiss_cfg.num_shards:
            print("Building FAISS index...")
            start = perf_counter()
 
@@ -170,8 +169,7 @@ def __init__(self, path: str,
faiss_cfg: FaissConfig, device: str, unit_norm: bo if shard_path.is_dir() and (shard_path / "info.json").exists() ] - if not info_paths: - raise FileNotFoundError(f"No gradient metadata found under {path}") + assert info_paths, f"No gradient metadata found under {path}" total_grads = sum( [json.load(open(info_path))["num_grads"] for info_path in info_paths] @@ -198,6 +196,11 @@ def __init__(self, path: str, faiss_cfg: FaissConfig, device: str, unit_norm: bo def build_shard_from_buffer( buffer_parts: list[NDArray], shard_idx: int ) -> None: + shard_path = faiss_path / f"{shard_idx}.faiss" + if shard_path.exists(): + print(f"Shard {shard_idx} already exists, skipping...") + return + print(f"Building shard {shard_idx}...") grads_chunk = np.concatenate(buffer_parts, axis=0) buffer_parts.clear() @@ -220,7 +223,7 @@ def build_shard_from_buffer( del grads_chunk index = index_to_device(index, "cpu") - faiss.write_index(index, str(faiss_path / f"{shard_idx}.faiss")) + faiss.write_index(index, str(shard_path)) for grads in tqdm(dl, desc="Loading gradients"): grads = structured_to_unstructured(grads) diff --git a/bergson/query_existing.py b/bergson/static_query.py similarity index 94% rename from bergson/query_existing.py rename to bergson/static_query.py index 3720636..d197d42 100644 --- a/bergson/query_existing.py +++ b/bergson/static_query.py @@ -11,12 +11,11 @@ QueryConfig, load_gradients, ) -from .query import get_query_data +from .dynamic_query import get_query_data @torch.inference_mode() def query_existing( - scores_path: str, query_cfg: QueryConfig, index_cfg: IndexConfig, k: int | None, @@ -57,7 +56,7 @@ def query_existing( f"Collected scores for {num_queries} queries " f"with {num_scores} scores each." ) - output_path = Path(scores_path) / "scores.hf" + output_path = Path(query_cfg.scores_path) / "scores.hf" output_path.parent.mkdir(parents=True, exist_ok=True) dataset = Dataset.from_dict( diff --git a/examples/query_analysis.py b/examples/query_analysis.py index 26bbabe..de88a2f 100644 --- a/examples/query_analysis.py +++ b/examples/query_analysis.py @@ -3,77 +3,32 @@ from pathlib import Path import torch -from datasets import Dataset, IterableDataset, load_from_disk +from datasets import Dataset, IterableDataset, concatenate_datasets, load_from_disk from bergson.data import DataConfig, IndexConfig, load_data_string def parse_args(): - parser = ArgumentParser( - description=( - "Aggregate static query results assuming six max-accumulated queries " - "and one mean-based query." 
-        )
-    )
+    parser = ArgumentParser()
     parser.add_argument(
-        "--trial-dir",
+        "--query_scores",
         type=Path,
         required=True,
-        help="Directory containing the saved static query results (static_query.hf).",
-    )
-    parser.add_argument(
-        "--dataset",
-        type=str,
-        required=True,
-        help="HF dataset identifier matching the index build.",
-    )
-    parser.add_argument(
-        "--split",
-        type=str,
-        default="train",
-        help="Dataset split used for the index (default: train).",
-    )
-    parser.add_argument(
-        "--subset",
-        type=str,
-        default=None,
-        help="Dataset subset used for the index, if any.",
-    )
-    parser.add_argument(
-        "--streaming",
-        action="store_true",
-        help="Flag indicating the index dataset was streamed.",
-    )
-    parser.add_argument(
-        "--run-path",
-        type=str,
-        default="static-query",
-        help="Run path label to instantiate IndexConfig (not used for IO).",
-    )
-    parser.add_argument(
-        "--limit",
-        type=int,
-        default=None,
-        help="Maximum number of rows to keep per strategy (defaults to k).",
-    )
-    parser.add_argument(
-        "--output-dir",
-        type=Path,
-        default=None,
-        help="Directory to write strategy outputs (defaults to trial directory).",
+        help="Directory containing the saved query results (scores.hf).",
     )
-    parser.add_argument(
-        "--raw-name",
-        type=str,
-        default="static_query.hf",
-        help="Name of the saved raw search dataset inside the trial directory.",
-    )
-    return parser.parse_args()
-
+    parser.add_argument(
+        "--dataset", type=str, default="EleutherAI/deep-ignorance-pretraining-mix"
+    )
+    parser.add_argument("--split", type=str, default="train")
+    parser.add_argument("--streaming", action="store_true")
 
-def build_index_config(args) -> IndexConfig:
-    data_cfg = DataConfig(dataset=args.dataset, split=args.split, subset=args.subset)
-    return IndexConfig(run_path=args.run_path, data=data_cfg, streaming=args.streaming)
+    return parser.parse_args()
 
 
 def save_strategy(name: str, dataset: Dataset, output_root: Path):
@@ -83,7 +38,7 @@ def save_strategy(name: str, dataset: Dataset, output_root: Path):
     csv_path = strategy_dir / f"{name}.csv"
     dataset.to_csv(csv_path)
 
-    hf_path = strategy_dir / "hf"
+    hf_path = strategy_dir / "data.hf"
     if hf_path.exists():
         shutil.rmtree(hf_path)
     dataset.save_to_disk(hf_path)
@@ -135,19 +90,27 @@ def get_text_rows(
 def main():
     args = parse_args()
 
-    trial_dir = args.trial_dir
-    output_root = args.output_dir or trial_dir
-    raw_path = trial_dir / args.raw_name
+    trial_dir = Path(args.query_scores)
 
-    if not raw_path.exists():
-        raise FileNotFoundError(
-            f"Static query results not found at {raw_path}. "
-            "Run bergson.static_query.query_gradient_dataset first."
-        )
+    assert trial_dir.exists(), f"Query scores not found at {trial_dir}."
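+    # Scores are saved either as a single scores.hf (static queries) or as
+    # one dataset per shard, e.g. shard-00000/scores.hf (dynamic queries);
+    # stitch any shards back into a single table before aggregating.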
- raw_ds = load_from_disk(raw_path) - scores_tensor = torch.tensor(raw_ds["scores"], dtype=torch.float32) - indices_tensor = torch.tensor(raw_ds["indices"], dtype=torch.int64) + if (trial_dir / "shard-00000").exists(): + datasets = [] + for shard_path in sorted(trial_dir.iterdir()): + if shard_path.is_dir() and "shard" in shard_path.name: + try: + datasets.append(load_from_disk(shard_path / "data.hf")) + except FileNotFoundError: + datasets.append(load_from_disk(shard_path / "scores.hf")) + scores_ds = concatenate_datasets(datasets) + else: + try: + scores_ds = load_from_disk(trial_dir / "data.hf") + except FileNotFoundError: + scores_ds = load_from_disk(trial_dir / "scores.hf") + + scores_tensor = torch.tensor(scores_ds["scores"], dtype=torch.float32) + indices_tensor = torch.tensor(scores_ds["indices"], dtype=torch.int64) if scores_tensor.ndim != 2 or indices_tensor.ndim != 2: raise ValueError( @@ -155,47 +118,52 @@ def main(): f"Got {scores_tensor.shape} and {indices_tensor.shape}." ) - num_queries, num_neighbors = scores_tensor.shape + num_queries, num_docs = scores_tensor.shape if num_queries != 7: raise ValueError( f"Expected exactly 7 queries (6 for max, 1 for mean); " - f"received {num_queries}." + f"received {num_queries} for each of {num_docs} documents." ) - limit = args.limit or num_neighbors + # Skip first row, flatten remaining 6 rows + scores_flat = scores_tensor[1:].reshape(-1) # Shape: (6 * num_neighbors,) + indices_flat = indices_tensor[1:].reshape(-1) # Shape: (6 * num_neighbors,) + + # Use scatter_reduce to get max score for each unique index + num_docs = int(indices_flat.max().item() + 1) + max_scores = torch.full((num_docs,), float("-inf"), dtype=torch.float32) + max_scores.scatter_reduce_(0, indices_flat, scores_flat, reduce="amax") - max_scores_by_index: dict[int, float] = {} - for idx, score in zip( - indices_tensor[1:].reshape(-1).tolist(), - scores_tensor[1:].reshape(-1).tolist(), - ): - idx = int(idx) - score = float(score) - best = max_scores_by_index.get(idx) - if best is None or score > best: - max_scores_by_index[idx] = score + # Convert to dict, excluding indices that were never seen + max_scores_by_index = { + idx: score.item() + for idx, score in enumerate(max_scores) + if score != float("-inf") + } sorted_max = sorted( max_scores_by_index.items(), key=lambda item: item[1], reverse=True - )[:limit] + ) max_indices = [idx for idx, _ in sorted_max] max_scores = [score for _, score in sorted_max] - mean_indices = indices_tensor[0].tolist()[:limit] - mean_scores = scores_tensor[0].tolist()[:limit] + mean_indices = indices_tensor[0].tolist() + mean_scores = scores_tensor[0].tolist() + + data_cfg = DataConfig(dataset=args.dataset, split=args.split) + index_cfg = IndexConfig(run_path="", data=data_cfg, streaming=args.streaming) - index_cfg = build_index_config(args) max_ds = get_text_rows(max_indices, max_scores, index_cfg, score_column="max_score") mean_ds = get_text_rows( mean_indices, mean_scores, index_cfg, score_column="mean_score" ) - output_root.mkdir(parents=True, exist_ok=True) - max_csv, max_hf = save_strategy("max_strategy", max_ds, output_root) - mean_csv, mean_hf = save_strategy("mean_strategy", mean_ds, output_root) + trial_dir.mkdir(parents=True, exist_ok=True) + max_csv, max_hf = save_strategy("max", max_ds, trial_dir) + mean_csv, mean_hf = save_strategy("mean", mean_ds, trial_dir) - print(f"Saved max strategy outputs to {max_csv} and {max_hf}") - print(f"Saved mean strategy outputs to {mean_csv} and {mean_hf}") + print(f"Saved max top results to 
{max_csv} and {max_hf}") + print(f"Saved mean top results to {mean_csv} and {mean_hf}") if __name__ == "__main__":
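
The two mechanisms these patches introduce can be exercised in isolation. Both sketches below are illustrative rather than part of the diffs above: tensor widths, module names, and the toy score matrix are stand-ins for real query output.

First, the score-sink contract behind the new `Query` wrapper: a callback maps a batch of per-module gradients to per-example scores, and the wrapper scatters them into a preallocated `(num_items, num_scores)` buffer, flushing to disk once every row has been written.

```python
import torch

torch.manual_seed(0)

# Stand-in for a unit-normalized mean query gradient (width is illustrative).
query = torch.randn(64)
query /= query.norm()


def mean_query_callback(mod_grads: dict[str, torch.Tensor]) -> torch.Tensor:
    # Concatenate per-module gradients, unit-normalize each row, then take
    # cosine similarity with the query: one score per example in the batch.
    grads = torch.cat(list(mod_grads.values()), dim=1)
    grads = grads / grads.norm(dim=1, keepdim=True)
    return grads @ query


scores = torch.zeros(10, 1)  # preallocated (num_items, num_scores) buffer
batch_indices = [3, 7]       # dataset rows covered by this batch
batch_grads = {"layers.0.mlp": torch.randn(2, 64)}
scores[batch_indices] = mean_query_callback(batch_grads).unsqueeze(-1)
```

Second, the aggregation in `examples/query_analysis.py`: with the mean-based query in row 0 and topic queries in the remaining rows, the max strategy keeps each document's best score under any topic query (the vectorized `scatter_reduce_` rewrite of the earlier dict loop), while the mean strategy simply ranks row 0.

```python
import torch

# Toy (num_queries, num_docs) matrix in place of a saved scores.hf.
scores = torch.tensor(
    [
        [0.10, 0.40, 0.20, 0.30],  # mean query
        [0.05, 0.90, 0.10, 0.00],  # topic query 1
        [0.70, 0.20, 0.00, 0.15],  # topic query 2
    ]
)
indices = torch.arange(scores.shape[1]).expand_as(scores)

flat_scores = scores[1:].reshape(-1)
flat_indices = indices[1:].reshape(-1)
num_docs = int(flat_indices.max().item()) + 1

# Per-document max over all topic queries in one call.
best = torch.full((num_docs,), float("-inf"))
best.scatter_reduce_(0, flat_indices, flat_scores, reduce="amax")

print(torch.argsort(best, descending=True).tolist())       # [1, 0, 3, 2]
print(torch.argsort(scores[0], descending=True).tolist())  # [1, 3, 2, 0]
```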