diff --git a/README.md b/README.md index fbb5aec..ecbb12e 100644 --- a/README.md +++ b/README.md @@ -71,7 +71,11 @@ trainer.train() ## Attention Head Gradients -By default Bergson collects gradients for named parameter matrices, but gradients for individual attention heads within an attention module can be collected too. To collect per-head gradients configure an AttentionConfig for each module of interest. +By default Bergson collects gradients for named parameter matrices, but gradients for individual attention heads within an attention module may also be collected. To collect per-head gradients programmatically configure an AttentionConfig for each module of interest, or specify attention modules and a single shared configuration using the command line tool. + +```bash +bergson build runs/test --model EleutherAI/pythia-14m --dataset NeelNanda/pile-10k --truncation --split_attention_modules "h.0.attn.attention.out_proj" --attention.num_heads 16 --attention.head_size 4 --attention.head_dim 2 +``` ```python from bergson import AttentionConfig, IndexConfig, DataConfig @@ -99,7 +103,7 @@ Where a reward signal is available we compute gradients using a weighted advanta bergson build --model --dataset --reward_column ``` -## Queries +## Index Queries We provide a query Attributor which supports unit normalized gradients and KNN search out of the box. @@ -127,6 +131,12 @@ with attr.trace(model.base_model, 5) as result: model.zero_grad() ``` +## On-The-Fly Queries + +```bash +bergson query runs/scores --query_path runs/query_index --model EleutherAI/pythia-14m --dataset NeelNanda/pile-10k --save_index False --truncation +``` + # Development ```bash diff --git a/bergson/__main__.py b/bergson/__main__.py index 27cf7a8..632e15d 100644 --- a/bergson/__main__.py +++ b/bergson/__main__.py @@ -6,7 +6,26 @@ from .build import build_gradient_dataset from .data import IndexConfig, QueryConfig -from .query import query_gradient_dataset +from .dynamic_query import query_gradient_dataset +from .static_query import query_existing + + +@dataclass +class StaticQuery: + """Query an on-disk gradient index.""" + + query_cfg: QueryConfig + + index_cfg: IndexConfig + + k: int | None = 50 + + def execute(self): + """Query an on-disk gradient index.""" + assert self.query_cfg.scores_path + assert self.query_cfg.query_path + + query_existing(self.query_cfg, self.index_cfg, self.k) @dataclass @@ -35,12 +54,14 @@ class Query: def execute(self): """Query the gradient dataset.""" + assert self.query_cfg.scores_path + assert self.query_cfg.query_path if os.path.exists(self.index_cfg.run_path) and self.index_cfg.save_index: raise ValueError( "Index path already exists and save_index is True - " "running this query will overwrite the existing gradients. " - "If you meant to query the existing gradients, use " + "If you meant to query the existing gradients use " "Attributor instead." 
) @@ -51,7 +72,7 @@ def execute(self): class Main: """Routes to the subcommands.""" - command: Union[Build, Query] + command: Union[Build, Query, StaticQuery] def execute(self): """Run the script.""" diff --git a/bergson/attributor.py b/bergson/attributor.py index eb2c1af..60859c2 100644 --- a/bergson/attributor.py +++ b/bergson/attributor.py @@ -74,8 +74,11 @@ def __init__( self.grads[name] /= norm def search( - self, queries: dict[str, Tensor], k: int, modules: list[str] | None = None - ) -> tuple[Tensor, Tensor]: + self, + queries: dict[str, Tensor], + k: int | None, + modules: list[str] | None = None, + ): """ Search for the `k` nearest examples in the index based on the query or queries. @@ -112,7 +115,7 @@ def search( ) modules = modules or list(q.keys()) - k = min(k, self.N) + k = min(k or self.N, self.N) scores = torch.stack( [q[name] @ self.grads[name].mT for name in modules], dim=-1 diff --git a/bergson/build.py b/bergson/build.py index e0eac1d..35f4088 100644 --- a/bergson/build.py +++ b/bergson/build.py @@ -168,6 +168,8 @@ def worker( save_index=cfg.save_index, save_processor=cfg.save_processor, drop_columns=cfg.drop_columns, + create_custom_query=cfg.in_memory_index, + module_wise=cfg.module_wise, ) else: # Convert each shard to a Dataset then map over its gradients @@ -194,6 +196,8 @@ def flush(): # Save a processor state checkpoint after each shard save_processor=cfg.save_processor, drop_columns=cfg.drop_columns, + create_custom_query=cfg.in_memory_index, + module_wise=cfg.module_wise, ) buf.clear() shard_id += 1 diff --git a/bergson/collection.py b/bergson/collection.py index 7ef81da..42e1260 100644 --- a/bergson/collection.py +++ b/bergson/collection.py @@ -1,6 +1,7 @@ import math -from typing import Callable, Literal +from typing import Literal +# import os import numpy as np import torch import torch.distributed as dist @@ -9,7 +10,7 @@ from tqdm.auto import tqdm from transformers import PreTrainedModel -from .data import create_index, pad_and_tensor +from .data import Query, create_index, pad_and_tensor from .gradients import AttentionConfig, GradientCollector, GradientProcessor from .peft import set_peft_enabled @@ -29,11 +30,17 @@ def collect_gradients( save_index: bool = True, save_processor: bool = True, drop_columns: bool = False, - query_callback: Callable[[dict[str, torch.Tensor]], torch.Tensor] | None = None, + query: Query | None = None, + module_wise: bool = False, + create_custom_query: bool = False, ): """ Compute projected gradients using a subset of the dataset. 
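+
+    New arguments introduced in this patch: `query` wraps a scoring callback that is
+    invoked on each batch's gradients and accumulates per-example scores,
+    `module_wise` hands each module's gradient to the query as soon as it is
+    collected (and requires `skip_preconditioners`), and `create_custom_query`
+    is currently disabled because its code path is commented out below.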
""" + assert not create_custom_query, "create_custom_query is commented out" + if module_wise and query: + assert skip_preconditioners + rank = dist.get_rank() if dist.is_initialized() else 0 if attention_cfgs is None: @@ -53,26 +60,9 @@ def collect_gradients( lo = torch.finfo(dtype).min hi = torch.finfo(dtype).max - def callback(name: str, g: torch.Tensor): - g = g.flatten(1).clamp_(lo, hi) - if save_index: - # Asynchronously move the gradient to CPU and convert to the final dtype - mod_grads[name] = g.to(device="cpu", dtype=dtype, non_blocking=True) - else: - mod_grads[name] = g.to(dtype=dtype) - - # Compute the outer product of the flattened gradient - if not skip_preconditioners: - g = g.float() - preconditioner = preconditioners.get(name, None) - if preconditioner is None: - preconditioners[name] = g.mT @ g - else: - preconditioner.addmm_(g.mT, g) - collector = GradientCollector( model.base_model, - callback, + lambda _: None, processor, target_modules=target_modules, attention_cfgs=attention_cfgs, @@ -87,6 +77,52 @@ def callback(name: str, g: torch.Tensor): if save_index else None ) + # if create_custom_query: + # num_grads = sum(len(indices) for indices in batches) + # print("file size in GB", sum(list(grad_sizes.values())) * + # np.dtype(np_dtype).itemsize / 1024**3) + # grads = { + # name: torch.zeros(1, grad_sizes[name], dtype=torch.float16, device="cpu") + # for name in grad_sizes.keys() + # } + # else: + # grads = {} + # num_grads = -1 + + def callback(name: str, g: torch.Tensor, indices: list[int]): + g = g.flatten(1).clamp_(lo, hi) + if grad_buffer is not None: # or grads: + # Asynchronously move the gradient to CPU and convert to the final dtype + mod_grads[name] = g.to(device="cpu", dtype=dtype, non_blocking=True) + + if module_wise: + # Consume gradient immediately + torch.cuda.synchronize() + # if grads: + # grads[name][0, :] += mod_grads[name].sum(dim=0) / len(indices) + # elif grad_buffer is not None: + if grad_buffer is not None: + grad_buffer[name][indices] = mod_grads[name].numpy() + + mod_grads.pop(name) + else: + # TODO do we need the dtype conversion + mod_grads[name] = g.to(dtype=dtype) + if module_wise and query: + query(indices, mod_grads, name) + mod_grads.pop(name) + + # Compute the outer product of the flattened gradient + if not skip_preconditioners: + g = g.float() + preconditioner = preconditioners.get(name, None) + if preconditioner is None: + preconditioners[name] = g.mT @ g + else: + preconditioner.addmm_(g.mT, g) + + # Update collect with callback + collector.closure = callback per_doc_losses = torch.full( (len(data),), @@ -94,12 +130,6 @@ def callback(name: str, g: torch.Tensor): dtype=dtype, fill_value=0.0, ) - per_doc_scores = torch.full( - (len(data),), - device=model.device, - dtype=dtype, - fill_value=0.0, - ) for indices in tqdm(batches, disable=rank != 0, desc="Building index"): batch = data[indices] @@ -117,7 +147,14 @@ def callback(name: str, g: torch.Tensor): ref_lps = torch.log_softmax(model(x).logits[:, :-1], dim=-1) set_peft_enabled(model, True) - with collector: + with GradientCollector( + model.base_model, + callback, + processor, + target_modules=target_modules, + attention_cfgs=attention_cfgs, + indices=indices, + ) as collector: ft_lps = torch.log_softmax(model(x).logits[:, :-1], dim=-1) # Compute average KL across all unmasked tokens @@ -128,7 +165,14 @@ def callback(name: str, g: torch.Tensor): losses.mean().backward() else: - with collector: + with GradientCollector( + model.base_model, + callback, + processor, + 
target_modules=target_modules, + attention_cfgs=attention_cfgs, + indices=indices, + ) as collector: logits = model(x).logits[:, :-1] losses = F.cross_entropy( @@ -144,7 +188,7 @@ def callback(name: str, g: torch.Tensor): model.zero_grad() - if grad_buffer is not None: + if grad_buffer is not None and not module_wise: # Weirdly you need to explicitly synchronize here in order to make # sure that the nonblocking copies actually finish before we call # .numpy() @@ -153,12 +197,15 @@ def callback(name: str, g: torch.Tensor): # It turns out that it's very important for efficiency to write the # gradients sequentially instead of first concatenating them, then # writing to one vector + # if custom: + # for name in mod_grads.keys(): + # grads[name] += mod_grads[name].sum(dim=0) / len(indices) + # else: for module_name in mod_grads.keys(): grad_buffer[module_name][indices] = mod_grads[module_name].numpy() - if query_callback is not None: - scores = query_callback(mod_grads) - per_doc_scores[indices] = scores.detach().type_as(per_doc_scores) + if query and not module_wise: + query(indices, mod_grads) mod_grads.clear() per_doc_losses[indices] = losses.detach().type_as(per_doc_losses) @@ -178,12 +225,6 @@ def callback(name: str, g: torch.Tensor): feature=Value("float16" if dtype == torch.float16 else "float32"), new_fingerprint="loss", ) - data = data.add_column( - "scores", - per_doc_scores.cpu().numpy(), - feature=Value("float16" if dtype == torch.float16 else "float32"), - new_fingerprint="scores", - ) data.save_to_disk(path + "/data.hf") if save_processor: @@ -193,6 +234,14 @@ def callback(name: str, g: torch.Tensor): if grad_buffer is not None: grad_buffer.flush() + # Make sure the scores are written to disk + if query: + query.save_scores(rank) + + # if create_custom_query: + # torch.save(grads, os.path.join(path, f"accum_mean_grads_{rank}.pth")) + # torch.save(num_grads, os.path.join(path, f"num_grads_{rank}.pth")) + def process_preconditioners( processor: GradientProcessor, diff --git a/bergson/data.py b/bergson/data.py index 24afd7e..1f1ed2f 100644 --- a/bergson/data.py +++ b/bergson/data.py @@ -2,9 +2,10 @@ import math import os import random +import uuid from dataclasses import dataclass from pathlib import Path -from typing import Literal, Sequence +from typing import Callable, Literal, Sequence import numpy as np import pyarrow as pa @@ -25,11 +26,118 @@ from .utils import assert_type +class Query: + """ + Wraps a query scoring callback and stores the resulting scores in a tensor. 
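+
+    The wrapper is invoked once per batch as `query(indices, mod_grads)`, or as
+    `query(indices, mod_grads, name)` in module-wise mode, where per-module scores
+    are accumulated together with the gradients' sums of squares. `save_scores`
+    writes the collected scores to `scores_path` as a Hugging Face dataset
+    (rank 0 only).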
+ """ + + def __init__( + self, + query_callback: Callable[..., torch.Tensor], + num_items: int, + num_scores: int, + scores_path: str, + *, + dtype: torch.dtype = torch.float32, + device: torch.device | str = "cpu", + rank: int, + module_wise: bool = False, + ): + self._query_callback = query_callback + self._scores_path = scores_path + self.scores = torch.zeros((num_items, num_scores), dtype=dtype, device=device) + self.rank = rank + self.num_written = 0 + + self.module_wise = module_wise + if self.module_wise: + self.sum_of_squares = torch.zeros((num_items,), dtype=dtype, device=device) + + def __call__( + self, + indices: list[int], + mod_grads: dict[str, torch.Tensor], + name: str | None = None, + ): + if name: + # Accumulate module-wise scores + scores, sum_of_squares = self._query_callback(mod_grads, name) + self.sum_of_squares[indices] += sum_of_squares + + if scores.ndim == 1: + scores = scores.unsqueeze(-1) + + self.scores[indices] += scores.to( + device=self.scores.device, dtype=self.scores.dtype + ) + + else: + scores = self._query_callback(mod_grads) + + if scores.ndim == 1: + scores = scores.unsqueeze(-1) + + scores = scores.to(device=self.scores.device, dtype=self.scores.dtype) + self.scores[indices] = scores + + def save_scores(self, rank: int): + if rank != 0: + return + + dataset = Dataset.from_dict( + { + "scores": self.scores.cpu().numpy().tolist(), + "indices": list(range(len(self.scores))), + } + ) + try: + dataset.save_to_disk(self._scores_path) + except Exception as e: + # Handle collisions with existing datasets + print(f"Error writing scores to disk: {e}") + random_hash = str(uuid.uuid4())[:8] + alternate_path = self._scores_path.replace(".hf", f"_{random_hash}.hf") + print(f"Writing to alternate path: {alternate_path}") + dataset.save_to_disk(alternate_path) + + if self.module_wise: + dataset = Dataset.from_dict( + { + "sum_of_squares": self.sum_of_squares.cpu().numpy().tolist(), + "indices": list(range(len(self.sum_of_squares))), + } + ) + dataset.save_to_disk(self._scores_path.replace(".hf", "_sum_of_squares.hf")) + + normalized_scores = np.zeros_like(self.scores.cpu().numpy()) + batch_size = 1024 + for i in range(0, len(self.sum_of_squares), batch_size): + batch = self.sum_of_squares[i : i + batch_size] + normalized_scores[i : i + batch_size] = ( + (self.scores[i : i + batch_size] / (batch.sqrt() + 1e-12)) + .cpu() + .numpy() + ) + dataset = Dataset.from_dict( + { + "normalized_scores": normalized_scores.tolist(), + "indices": list(range(len(self.sum_of_squares))), + } + ) + + dataset.save_to_disk( + self._scores_path.replace(".hf", "_normalized_scores.hf") + ) + + @dataclass class DataConfig: dataset: str = "EleutherAI/SmolLM2-135M-10B" """Dataset identifier to build the index from.""" + subset: str | None = None + """Subset of the dataset to use for building the index.""" + split: str = "train" """Split of the dataset to use for building the index.""" @@ -70,29 +178,30 @@ class QueryConfig: """Config for querying an index on the fly.""" query_path: str = "" - """Path to the query dataset.""" + """Path to the existing query index.""" - query_method: Literal["mean", "nearest"] = "mean" - """Method to use for computing the query.""" + score: Literal["mean", "nearest", "individual"] = "mean" + """Method for scoring the gradients with the query. 
If mean + gradients will be scored by their similarity with the mean + query gradients, otherwise by the most similar query gradient.""" - save_processor: bool = True - """Whether to write the query dataset gradient processor - to disk.""" + scores_path: str = "" + """Path to the directory where query scores should be written.""" query_preconditioner_path: str | None = None - """Path to a precomputed preconditioner. The precomputed - preconditioner is applied to the query dataset gradients.""" + """Path to a precomputed preconditioner to be applied to + the query dataset gradients.""" index_preconditioner_path: str | None = None - """Path to a precomputed preconditioner. The precomputed - preconditioner is applied to the query dataset gradients. - This does not affect the ability to compute a new - preconditioner during gradient collection.""" + """Path to a precomputed preconditioner to be applied to + the query dataset gradients. This does not affect the + ability to compute a new preconditioner during gradient + collection.""" - mixing_coefficient: float = 0.5 + mixing_coefficient: float = 0.99 """Coefficient to weight the application of the query preconditioner and the pre-computed index preconditioner. 0.0 means only use the - query preconditioner and 1.0 means only use the index preconditioner.""" + index preconditioner and 1.0 means only use the query preconditioner.""" modules: list[str] = field(default_factory=list) """Modules to use for the query. If empty, all modules will be used.""" @@ -114,6 +223,9 @@ class IndexConfig: save_index: bool = True """Whether to write the gradient index to disk.""" + in_memory_index: bool = False + """Whether to keep the gradient index in memory and torch.save it to disk.""" + save_processor: bool = True """Whether to write the gradient processor to disk.""" @@ -175,6 +287,9 @@ class IndexConfig: """Configuration for each attention module to be split into head matrices. 
Used for attention modules specified in `split_attention_modules`."""
 
+    module_wise: bool = False
+    """Whether to compute the gradients module-wise."""
+
     @property
     def partial_run_path(self) -> str:
         """Temporary path used while writing build artifacts."""
@@ -357,7 +472,10 @@ def create_index(
 
 
 def load_data_string(
-    data_str: str, split: str = "train", streaming: bool = False
+    data_str: str,
+    split: str = "train",
+    subset: str | None = None,
+    streaming: bool = False,
 ) -> Dataset | IterableDataset:
     """Load a dataset from a string identifier or path."""
     if data_str.endswith(".csv"):
@@ -366,7 +484,12 @@ def load_data_string(
         ds = assert_type(Dataset, Dataset.from_json(data_str))
     else:
         try:
-            ds = load_dataset(data_str, split=split, streaming=streaming)
+            if subset:
+                ds = load_dataset(
+                    data_str, subset, split=split, streaming=streaming
+                )
+            else:
+                ds = load_dataset(data_str, split=split, streaming=streaming)
 
             if isinstance(ds, DatasetDict) or isinstance(ds, IterableDatasetDict):
                 raise NotImplementedError(
diff --git a/bergson/query.py b/bergson/dynamic_query.py
similarity index 64%
rename from bergson/query.py
rename to bergson/dynamic_query.py
index 522ee67..c711797 100644
--- a/bergson/query.py
+++ b/bergson/dynamic_query.py
@@ -2,12 +2,13 @@
 import os
 import socket
 from datetime import timedelta
+from pathlib import Path
 from typing import cast
 
 import torch
 import torch.distributed as dist
 import torch.multiprocessing as mp
-from datasets import Dataset, IterableDataset
+from datasets import Dataset, IterableDataset, Sequence, Value, load_from_disk
 from peft import PeftConfig, PeftModel
 from torch.distributed.elastic.multiprocessing import DefaultLogsSpecs, start_processes
 from torch.distributed.fsdp import fully_shard
@@ -23,6 +24,7 @@
 from .collection import collect_gradients
 from .data import (
     IndexConfig,
+    Query,
     QueryConfig,
     allocate_batches,
     load_data_string,
@@ -34,7 +36,7 @@
 from .utils import assert_type, get_layer_list
 
 
-def get_query_data(index_cfg: IndexConfig, query_cfg: QueryConfig):
+def get_query_data(query_cfg: QueryConfig):
     """
     Load and optionally precondition the query dataset. Preconditioners may be
     mixed as described in https://arxiv.org/html/2410.17413v1#S3.
@@ -46,12 +48,38 @@ def get_query_data(index_cfg: IndexConfig, query_cfg: QueryConfig):
             "Please build a query dataset index first."
         )
 
+    if query_cfg.query_path.endswith("full_accum_mean_mod_grads.hf"):
+        print("Short circuiting code")
+        query_ds = load_from_disk(query_cfg.query_path)
+        if not query_cfg.modules:
+            query_cfg.modules = list(query_ds.column_names)
+            print(f"Modules: {query_cfg.modules}")
+        query_ds = query_ds.with_format("torch", columns=query_cfg.modules)
+
+        return query_ds
+
     # Load the query dataset
     with open(os.path.join(query_cfg.query_path, "info.json"), "r") as f:
         target_modules = json.load(f)["dtype"]["names"]
 
     query_ds = load_gradient_dataset(query_cfg.query_path, concatenate_gradients=False)
-    query_ds = query_ds.with_format("torch", columns=target_modules)
+    query_ds = query_ds.with_format(
+        "torch", columns=target_modules, dtype=torch.float32
+    )
+
+    # Ensure all gradient columns are materialized as float32 tensors with their native
+    # sequence length.
+ for col in target_modules: + feature = query_ds.features[col] + if isinstance(feature, Sequence): + length = feature.length + else: + length = None + query_ds = query_ds.cast_column( + col, + Sequence(Value("float32"), length=length), + ) + print("Length", length) use_q = query_cfg.query_preconditioner_path is not None use_i = query_cfg.index_preconditioner_path is not None @@ -79,11 +107,17 @@ def get_query_data(index_cfg: IndexConfig, query_cfg: QueryConfig): if (q and i) else (q or i) ) - mixed_preconditioner = {k: v.cuda() for k, v in mixed_preconditioner.items()} + mixed_preconditioner = { + k: v.cuda().to(torch.float64) for k, v in mixed_preconditioner.items() + } def precondition(batch): + print("batch") for name in target_modules: - batch[name] = (batch[name].cuda() @ mixed_preconditioner[name]).cpu() + result = ( + batch[name].cuda().to(torch.float64) @ mixed_preconditioner[name] + ) + batch[name] = result.cpu() return batch @@ -91,9 +125,126 @@ def precondition(batch): precondition, batched=True, batch_size=query_cfg.batch_size ) + for name in target_modules: + assert query_ds[0][name].sum() != 0 + return query_ds +def get_individual_query( + query_ds: Dataset, query_cfg: QueryConfig, device: torch.device, dtype: torch.dtype +): + """ + Compute the individual query and return a callback function that scores gradients + according to their inner products or cosine similarities with the individual + queries. + """ + queries = torch.cat([query_ds[:][name] for name in query_cfg.modules], dim=1).to( + device=device, dtype=dtype + ) + + if query_cfg.unit_normalize: + queries /= queries.norm(dim=1, keepdim=True) + + # Assert on device + assert queries.device == device + + def callback(mod_grads: dict[str, torch.Tensor]): + grads = torch.cat([mod_grads[name] for name in query_cfg.modules], dim=1) + if query_cfg.unit_normalize: + grads /= grads.norm(dim=1, keepdim=True) + + # Return a score for every query + return grads @ queries.T + + return callback + + +def get_module_wise_mean_query( + query_ds: Dataset, + query_cfg: QueryConfig, + device: torch.device, + dtype: torch.dtype, + precomputed_mean: bool = True, +): + """ + Compute the mean query and return a callback function that scores gradients + according to their inner products or cosine similarities with the mean query. + """ + # Accumulate on CPU to avoid holding full-resolution gradients on GPU. 
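+    # When `precomputed_mean` is True, the query dataset is expected to contain a
+    # single row of already-averaged gradients per module (e.g. the
+    # full_accum_mean_mod_grads.hf dataset produced by examples/assemble_query.py).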
+ if not precomputed_mean: + acc_device = torch.device("cpu") + acc = { + module: torch.zeros_like( + query_ds[0][module], device=acc_device, dtype=torch.float32 + ) + for module in query_cfg.modules + } + + def sum_(*cols): + for module, x in zip(query_cfg.modules, cols): + x = x.to(device=acc_device, dtype=torch.float32) + if query_cfg.unit_normalize: + x = x / (x.norm(dim=1, keepdim=True) + 1e-12) + acc[module].add_(x.sum(0)) + + query_ds.map( + sum_, + input_columns=query_cfg.modules, + batched=True, + batch_size=query_cfg.batch_size, + ) + + callback_query = { + module: (acc[module] / len(query_ds)).to( + device=device, dtype=dtype, non_blocking=True + ) + for module in query_cfg.modules + } + else: + if query_cfg.unit_normalize: + callback_query = { + module: query_ds[0][module].to( + device=device, dtype=torch.float32, non_blocking=True + ) + for module in query_cfg.modules + } + norm = ( + torch.cat( + [ + (query_ds[0][module].to(torch.float32) ** 2).sum(dim=1) + for module in query_cfg.modules + ], + dim=0, + ) + .sum() + .sqrt() + ) + print(norm.shape) + callback_query = { + module: (callback_query[module].to(torch.float32) / norm).to( + dtype=dtype + ) + for module in query_cfg.modules + } + else: + callback_query = { + module: query_ds[0][module].to( + device=device, dtype=dtype, non_blocking=True + ) + for module in query_cfg.modules + } + + @torch.inference_mode() + def callback(mod_grads: dict[str, torch.Tensor], name: str): + module_scores = mod_grads[name] @ callback_query[name] + sum_of_squares = (mod_grads[name] ** 2).sum(dim=1) + + return module_scores, sum_of_squares + + return callback + + def get_mean_query( query_ds: Dataset, query_cfg: QueryConfig, device: torch.device, dtype: torch.dtype ): @@ -101,18 +252,21 @@ def get_mean_query( Compute the mean query and return a callback function that scores gradients according to their inner products or cosine similarities with the mean query. """ + # Accumulate on CPU to avoid holding full-resolution gradients on GPU. 
+ acc_device = torch.device("cpu") acc = { module: torch.zeros_like( - query_ds[0][module], device=device, dtype=torch.float32 + query_ds[0][module], device=acc_device, dtype=torch.float32 ) for module in query_cfg.modules } def sum_(*cols): for module, x in zip(query_cfg.modules, cols): + x = x.to(device=acc_device, dtype=torch.float32) if query_cfg.unit_normalize: x = x / (x.norm(dim=1, keepdim=True) + 1e-12) - acc[module] += x.to(device=device, dtype=torch.float32).sum(0) + acc[module].add_(x.sum(0)) query_ds.map( sum_, @@ -123,7 +277,9 @@ def sum_(*cols): callback_query = torch.cat( [ - (acc[module] / len(query_ds)).to(device=device, dtype=dtype) + (acc[module] / len(query_ds)).to( + device=device, dtype=dtype, non_blocking=True + ) for module in query_cfg.modules ], dim=0, @@ -298,25 +454,51 @@ def worker( else: attention_cfgs = {} - with open(os.path.join(query_cfg.query_path, "info.json"), "r") as f: - query_cfg.modules = json.load(f)["dtype"]["names"] + if not query_cfg.modules: + with open(os.path.join(query_cfg.query_path, "info.json"), "r") as f: + query_cfg.modules = json.load(f)["dtype"]["names"] query_ds = query_ds.with_format("torch", columns=query_cfg.modules) - query_device = torch.device(f"cuda:{rank}") query_dtype = dtype if dtype != "auto" else torch.float16 - if query_cfg.query_method == "mean": - query_callback = get_mean_query(query_ds, query_cfg, query_device, query_dtype) - elif query_cfg.query_method == "nearest": - query_callback = get_nearest_query( - query_ds, query_cfg, query_device, query_dtype + if query_cfg.score == "mean": + if index_cfg.module_wise: + base_query_callback = get_module_wise_mean_query( + query_ds, query_cfg, model.device, query_dtype + ) + else: + base_query_callback = get_mean_query( + query_ds, query_cfg, model.device, query_dtype + ) + num_scores = 1 + elif query_cfg.score == "nearest": + base_query_callback = get_nearest_query( + query_ds, query_cfg, model.device, query_dtype + ) + num_scores = 1 + elif query_cfg.score == "individual": + base_query_callback = get_individual_query( + query_ds, query_cfg, model.device, query_dtype ) + num_scores = len(query_ds) else: - raise ValueError(f"Invalid query method: {query_cfg.query_method}") + raise ValueError(f"Invalid query scoring method: {query_cfg.score}") + + scores_dtype = torch.float32 if model.dtype == torch.float32 else torch.float16 if isinstance(ds, Dataset): batches = allocate_batches(ds["length"][:], index_cfg.token_batch_size) + query = Query( + base_query_callback, + len(ds), + num_scores, + str(Path(query_cfg.scores_path) / "scores.hf"), + dtype=scores_dtype, + device=model.device, + rank=rank, + module_wise=index_cfg.module_wise, + ) collect_gradients( model, ds, @@ -329,9 +511,10 @@ def worker( target_modules=target_modules, attention_cfgs=attention_cfgs, drop_columns=index_cfg.drop_columns, - query_callback=query_callback, + query=query, save_index=index_cfg.save_index, save_processor=index_cfg.save_processor, + module_wise=index_cfg.module_wise, ) else: # Convert each shard to a Dataset then collect its gradients @@ -345,6 +528,18 @@ def flush(): batches = allocate_batches( ds_shard["length"][:], index_cfg.token_batch_size ) + query = Query( + base_query_callback, + len(ds_shard), + num_scores, + str( + Path(query_cfg.scores_path) / f"shard-{shard_id:05d}" / "scores.hf" + ), + dtype=scores_dtype, + device=model.device, + rank=rank, + module_wise=index_cfg.module_wise, + ) collect_gradients( model, ds_shard, @@ -357,9 +552,10 @@ def flush(): target_modules=target_modules, 
attention_cfgs=attention_cfgs, drop_columns=index_cfg.drop_columns, - query_callback=query_callback, + query=query, save_index=index_cfg.save_index, save_processor=index_cfg.save_processor, + module_wise=index_cfg.module_wise, ) buf.clear() shard_id += 1 @@ -415,7 +611,7 @@ def query_gradient_dataset(query_cfg: QueryConfig, index_cfg: IndexConfig): new_fingerprint="advantage", # type: ignore ) - query_ds = get_query_data(index_cfg, query_cfg) + query_ds = get_query_data(query_cfg) world_size = torch.cuda.device_count() if world_size <= 1: @@ -452,5 +648,5 @@ def query_gradient_dataset(query_cfg: QueryConfig, index_cfg: IndexConfig): try: os.rename(index_cfg.partial_run_path, index_cfg.run_path) - except Exception: - pass + except Exception as e: + print(f"Error renaming index path: {e}") diff --git a/bergson/faiss_index.py b/bergson/faiss_index.py index 87f3a72..e42beb6 100644 --- a/bergson/faiss_index.py +++ b/bergson/faiss_index.py @@ -1,9 +1,8 @@ import json -import math import os from dataclasses import dataclass from pathlib import Path -from time import time +from time import perf_counter from types import ModuleType from typing import TYPE_CHECKING, Protocol @@ -95,7 +94,7 @@ def load_shard(shard_dir: str) -> np.memmap: yield load_shard(root_dir) else: for shard_path in sorted(root_path.iterdir()): - if shard_path.is_dir(): + if shard_path.is_dir() and "shard" in shard_path.name: yield load_shard(str(shard_path)) @@ -135,10 +134,12 @@ def index_to_device(index: Index, device: str) -> Index: class FaissIndex: - """FAISS index.""" + """Shard-based FAISS index.""" shards: list[Index] + faiss_cfg: FaissConfig + def __init__(self, path: str, faiss_cfg: FaissConfig, device: str, unit_norm: bool): faiss = _require_faiss() @@ -152,75 +153,114 @@ def __init__(self, path: str, faiss_cfg: FaissConfig, device: str, unit_norm: bo f"{'_unit_norm' if unit_norm else ''}" ) ) + faiss_path.mkdir(exist_ok=True, parents=True) - if not (faiss_path.exists() and any(faiss_path.iterdir())): + if len(list(faiss_path.iterdir())) != faiss_cfg.num_shards: print("Building FAISS index...") - start = time() + start = perf_counter() + + root_path = Path(path) + if (root_path / "info.json").exists(): + info_paths = [root_path / "info.json"] + else: + info_paths = [ + shard_path / "info.json" + for shard_path in sorted(root_path.iterdir()) + if shard_path.is_dir() and (shard_path / "info.json").exists() + ] + + assert info_paths, f"No gradient metadata found under {path}" + + total_grads = sum( + [json.load(open(info_path))["num_grads"] for info_path in info_paths] + ) - faiss_path.mkdir(exist_ok=True, parents=True) + assert faiss_cfg.num_shards <= total_grads and faiss_cfg.num_shards > 0 - num_dataset_shards = len(list(Path(path).iterdir())) - shards_per_index = math.ceil(num_dataset_shards / faiss_cfg.num_shards) + # Set the number of grads for each faiss index shard + base_shard_size = total_grads // faiss_cfg.num_shards + remainder = total_grads % faiss_cfg.num_shards + shard_sizes = [base_shard_size] * (faiss_cfg.num_shards) + shard_sizes[-1] += remainder - dl = gradients_loader(path) - buffer = [] - index_idx = 0 + # Verify all gradients will be consumed + assert ( + sum(shard_sizes) == total_grads + ), f"Shard sizes {shard_sizes} don't sum to total_grads {total_grads}" - for grads in tqdm(dl, desc="Loading gradients"): - if grads.dtype.names is not None: - grads = structured_to_unstructured(grads) - - if unit_norm: - grads = normalize_grads(grads, device, faiss_cfg.batch_size) + dl = gradients_loader(path) + 
buffer: list[NDArray] = [] + buffer_size = 0 + shard_idx = 0 + + def build_shard_from_buffer( + buffer_parts: list[NDArray], shard_idx: int + ) -> None: + shard_path = faiss_path / f"{shard_idx}.faiss" + if shard_path.exists(): + print(f"Shard {shard_idx} already exists, skipping...") + return + + print(f"Building shard {shard_idx}...") + grads_chunk = np.concatenate(buffer_parts, axis=0) + buffer_parts.clear() - buffer.append(grads) + index = faiss.index_factory( + grads_chunk.shape[1], + faiss_cfg.index_factory, + faiss.METRIC_INNER_PRODUCT, + ) + index = index_to_device(index, device) + if faiss_cfg.max_train_examples is not None: + train_examples = min( + faiss_cfg.max_train_examples, grads_chunk.shape[0] + ) + else: + train_examples = grads_chunk.shape[0] + index.train(grads_chunk[:train_examples]) + index.add(grads_chunk) - if len(buffer) == shards_per_index: - # Build index shard - print(f"Building shard {index_idx}...") + del grads_chunk - grads = np.concatenate(buffer, axis=0) - buffer = [] + index = index_to_device(index, "cpu") + faiss.write_index(index, str(shard_path)) - index = faiss.index_factory( - grads.shape[1], - faiss_cfg.index_factory, - faiss.METRIC_INNER_PRODUCT, - ) - index = index_to_device(index, device) - train_examples = faiss_cfg.max_train_examples or grads.shape[0] - index.train(grads[:train_examples]) - index.add(grads) + for grads in tqdm(dl, desc="Loading gradients"): + grads = structured_to_unstructured(grads) + if unit_norm: + grads = normalize_grads(grads, device, faiss_cfg.batch_size) - # Write index to disk - del grads - index = index_to_device(index, "cpu") - faiss.write_index(index, str(faiss_path / f"{index_idx}.faiss")) + batch_idx = 0 + batch_size = grads.shape[0] + while batch_idx < batch_size and shard_idx < faiss_cfg.num_shards: + remaining_in_shard = shard_sizes[shard_idx] - buffer_size + take = min(remaining_in_shard, batch_size - batch_idx) - index_idx += 1 + if take > 0: + buffer.append(grads[batch_idx : batch_idx + take]) + buffer_size += take + batch_idx += take - if buffer: - grads = np.concatenate(buffer, axis=0) - buffer = [] - index = faiss.index_factory( - grads.shape[1], faiss_cfg.index_factory, faiss.METRIC_INNER_PRODUCT - ) - index = index_to_device(index, device) - index.train(grads) - index.add(grads) + if buffer_size == shard_sizes[shard_idx]: + build_shard_from_buffer(buffer, shard_idx) + buffer = [] + buffer_size = 0 + shard_idx += 1 - # Write index to disk del grads - index = index_to_device(index, "cpu") - faiss.write_index(index, str(faiss_path / f"{index_idx}.faiss")) - print(f"Built index in {(time() - start) / 60:.2f} minutes.") - del buffer + assert shard_idx == faiss_cfg.num_shards + print(f"Built index in {(perf_counter() - start) / 60:.2f} minutes.") + + shard_paths = sorted( + (c for c in faiss_path.glob("*.faiss") if c.stem.isdigit()), + key=lambda p: int(p.stem), + ) shards = [] - for i in range(faiss_cfg.num_shards): + for shard_path in shard_paths: shard = faiss.read_index( - str(faiss_path / f"{i}.faiss"), + str(shard_path), faiss.IO_FLAG_MMAP | faiss.IO_FLAG_READ_ONLY, ) if not faiss_cfg.mmap_index: @@ -228,21 +268,26 @@ def __init__(self, path: str, faiss_cfg: FaissConfig, device: str, unit_norm: bo shards.append(shard) + # TODO raise? 
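+        # For now the config is silently adjusted to match the shards found on disk
+        # instead of raising.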
+ if len(shards) != faiss_cfg.num_shards: + faiss_cfg.num_shards = len(shards) + self.shards = shards - def search(self, q: NDArray, k: int) -> tuple[NDArray, NDArray]: + def search(self, q: NDArray, k: int | None) -> tuple[NDArray, NDArray]: """Note: if fewer than `k` examples are found FAISS will return items - with the index -1 and the maximum negative distance.""" + with the index -1 and the maximum negative distance. If `k` is `None`, + all examples will be returned.""" shard_distances = [] shard_indices = [] offset = 0 - for index in self.shards: - index.nprobe = self.faiss_cfg.nprobe - distances, indices = index.search(q, k) + for shard in self.shards: + shard.nprobe = self.faiss_cfg.nprobe + distances, indices = shard.search(q, k or shard.ntotal) indices += offset - offset += index.ntotal + offset += shard.ntotal shard_distances.append(distances) shard_indices.append(indices) @@ -252,7 +297,7 @@ def search(self, q: NDArray, k: int) -> tuple[NDArray, NDArray]: # Rerank results overfetched from multiple shards if len(self.shards) > 1: - topk_indices = np.argsort(distances, axis=1)[:, :k] + topk_indices = np.argsort(distances, axis=1)[:, : k or self.ntotal] indices = indices[np.arange(indices.shape[0])[:, None], topk_indices] distances = distances[np.arange(distances.shape[0])[:, None], topk_indices] diff --git a/bergson/gradients.py b/bergson/gradients.py index 325ebd7..e4b523b 100644 --- a/bergson/gradients.py +++ b/bergson/gradients.py @@ -353,6 +353,9 @@ class GradientCollector(ContextDecorator): Dictionary of head configurations for each module to be split into head matrices. """ + indices: list[int] | None = None + """Indices of the data points to collect gradients for.""" + def __post_init__(self): self._fwd_hooks: list[RemovableHandle] = [] self._bwd_hooks: list[RemovableHandle] = [] @@ -555,7 +558,7 @@ def _process_grad(self, module: nn.Module, _, grad_out): P = G.mT @ I # [N, O/p, S] @ [N, S, I/q] → [N, O/p, I/q] - self.closure(name, P) + self.closure(name, P, self.indices) # Save memory ASAP del module._inputs diff --git a/bergson/static_query.py b/bergson/static_query.py new file mode 100644 index 0000000..d197d42 --- /dev/null +++ b/bergson/static_query.py @@ -0,0 +1,69 @@ +from pathlib import Path +from time import perf_counter + +import torch +from datasets import Dataset + +from bergson import Attributor, FaissConfig + +from .data import ( + IndexConfig, + QueryConfig, + load_gradients, +) +from .dynamic_query import get_query_data + + +@torch.inference_mode() +def query_existing( + query_cfg: QueryConfig, + index_cfg: IndexConfig, + k: int | None, + device="cpu", +): + if not query_cfg.modules: + query_cfg.modules = list(load_gradients(query_cfg.query_path).dtype.names) + + query_ds = get_query_data(query_cfg) + query_ds = query_ds.with_format("torch", columns=query_cfg.modules) + + start = perf_counter() + attr = Attributor( + index_cfg.run_path, + device=device, + faiss_cfg=FaissConfig(mmap_index=True, num_shards=5), + unit_norm=query_cfg.unit_normalize, + ) + print(f"Attributor loaded in {perf_counter() - start}") + + print("Searching...") + start = perf_counter() + + search_inputs = { + name: torch.as_tensor(query_ds[:][name]).to(device) + for name in query_cfg.modules + } + + scores, indices = attr.search(search_inputs, k) + + print(f"Query time: {perf_counter() - start}") + + scores_tensor = torch.as_tensor(scores).cpu() + indices_tensor = torch.as_tensor(indices).to(torch.int64).cpu() + + num_queries, num_scores = scores_tensor.shape + print( + f"Collected 
scores for {num_queries} queries " f"with {num_scores} scores each." + ) + + output_path = Path(query_cfg.scores_path) / "scores.hf" + output_path.parent.mkdir(parents=True, exist_ok=True) + + dataset = Dataset.from_dict( + { + "scores": scores_tensor.tolist(), + "indices": indices_tensor.tolist(), + } + ) + dataset.save_to_disk(output_path) + print(f"Saved raw search results to {output_path}") diff --git a/examples/analyze_query.py b/examples/analyze_query.py new file mode 100644 index 0000000..f1fabc4 --- /dev/null +++ b/examples/analyze_query.py @@ -0,0 +1,238 @@ +import shutil +from argparse import ArgumentParser +from pathlib import Path + +import torch +from datasets import Dataset, IterableDataset, concatenate_datasets, load_from_disk + +from bergson.data import DataConfig, IndexConfig, load_data_string + + +def parse_args(): + parser = ArgumentParser() + parser.add_argument( + "--scores_path", + type=Path, + required=True, + help="Directory containing the query scores (scores.hf).", + ) + parser.add_argument( + "--dataset", type=str, default="EleutherAI/deep-ignorance-pretraining-mix" + ) + parser.add_argument("--split", type=str, default="train") + + return parser.parse_args() + + +def save_strategy(name: str, dataset: Dataset, output_root: Path): + strategy_dir = output_root / name + strategy_dir.mkdir(parents=True, exist_ok=True) + + csv_path = strategy_dir / f"{name}.csv" + dataset.to_csv(csv_path) + + hf_path = strategy_dir / "data.hf" + if hf_path.exists(): + shutil.rmtree(hf_path) + dataset.save_to_disk(hf_path, num_proc=1) + + return csv_path, hf_path + + +def get_text_rows( + indices: list[int], + scores: list[float], + index_cfg: IndexConfig, + score_column: str = "score", +): + ds = load_data_string( + index_cfg.data.dataset, index_cfg.data.split, streaming=index_cfg.streaming + ) + + if not index_cfg.streaming: + assert isinstance(ds, Dataset), "Dataset required for direct selection" + + selected = ds.select(indices) + return selected.add_column(score_column, scores, new_fingerprint="score") + + assert isinstance(ds, IterableDataset), "IterableDataset required for streaming" + + pending_idxs = {idx: score for idx, score in zip(indices, scores)} + data = [] + + for i, row in enumerate(ds): + if i in pending_idxs: + score = pending_idxs.pop(i) + data.append( + { + score_column: score, + **row, + } + ) + if not pending_idxs: + break + + assert not pending_idxs, ( + f"Could not collect all rows for the requested indices." + f" Missing indices: {sorted(pending_idxs)}" + ) + + return Dataset.from_list(data) + + +def main(): + do_query_normalization = True + + args = parse_args() + + limit = 50 + + file_name = "scores.hf" + col_name = "scores" + + sum_of_squares_file_name = "scores_sum_of_squares.hf" + sum_of_squares_col_name = "sum_of_squares" + + # file_name = "scores_normalized_scores.hf" + # col_name = "normalized_scores" + + query_path = "runs/wmdp_bio_robust_mcqa_means_0/full_accum_mean_mod_grads.hf" + query_scores_path = Path(args.scores_path) + assert query_scores_path.exists(), f"Query scores not found at {query_scores_path}." 
+ + if (query_scores_path / "shard-00000").exists(): + datasets = [] + ss_datasets = [] + for shard_path in sorted(query_scores_path.iterdir()): + if shard_path.is_dir() and "shard" in shard_path.name: + datasets.append(load_from_disk(shard_path / file_name)) + ss_datasets.append( + load_from_disk(shard_path / sum_of_squares_file_name) + ) + scores_ds = concatenate_datasets(datasets) + ss_ds = concatenate_datasets(ss_datasets) + else: + scores_ds = load_from_disk(query_scores_path / file_name) + ss_ds = load_from_disk(query_scores_path / sum_of_squares_file_name) + + scores_tensor = torch.tensor(scores_ds[col_name], dtype=torch.float32) + ss_tensor = torch.tensor(ss_ds[sum_of_squares_col_name], dtype=torch.float32) + values_norm = ss_tensor.sqrt() + + # breakpoint() + try: + indices_tensor = torch.tensor(scores_ds["indices"], dtype=torch.int64) + except Exception as e: + print(f"Error loading indices: {e}") + indices_tensor = torch.tensor(range(len(scores_ds)), dtype=torch.int64) + indices_tensor = indices_tensor.unsqueeze(-1) + print(indices_tensor.shape) + + # ss_tensor = ss_tensor.unsqueeze(-1) + # batch_size = 1024 + # for i in range(0, len(ss_tensor), batch_size): + # batch = ss_tensor[i:i+batch_size] + # scores_tensor[i:i+batch_size] = + # scores_tensor[i:i+batch_size] / (batch.sqrt() + 1e-12) + if do_query_normalization: + query_ds = load_from_disk(query_path) + query_ds = query_ds.with_format("torch", columns=list(query_ds.column_names)) + query_norm = ( + torch.stack( + [(query_ds[module][:] ** 2).sum() for module in query_ds.column_names] + ) + .sum() + .sqrt() + ) + scores_tensor /= (query_norm * (values_norm + 1e-12)).unsqueeze(-1) + else: + scores_tensor = scores_tensor / (values_norm + 1e-12) + + print(f"max score: {scores_tensor.max()}, min score: {scores_tensor.min()}") + + assert scores_tensor.ndim == 2 and indices_tensor.ndim == 2, ( + f"Expected scores and indices to have shape (num_queries, k). " + f"Got {scores_tensor.shape} and {indices_tensor.shape}." + ) + + # num_scores_per_doc, k = scores_tensor.shape + # if num_scores_per_doc != 7: + # raise ValueError( + # f"Expected exactly 7 queries (6 for max, 1 for mean); " + # f"received {num_scores_per_doc} for each of {k} top documents." + # ) + # if num_scores_per_doc != 1: + # raise ValueError( + # f"Expected exactly 1 query (for mean); " + # f"received {num_scores_per_doc} for each of {k} top documents." 
+ # ) + + # Skip first row, flatten remaining 6 rows + # scores_flat = scores_tensor[1:].reshape(-1) # Shape: (6 * num_neighbors,) + # indices_flat = indices_tensor[1:].reshape(-1) # Shape: (6 * num_neighbors,) + + # # Use scatter_reduce to get max score for each unique index + # max_doc_idx = int(indices_flat.max().item() + 1) + # max_scores = torch.full((max_doc_idx,), float("-inf"), dtype=torch.float32) + # max_scores.scatter_reduce_(0, indices_flat, scores_flat, reduce="amax") + + # # Convert to dict, excluding indices that were never seen + # max_scores_by_index = { + # idx: score.item() + # for idx, score in enumerate(max_scores) + # if score != float("-inf") + # } + + # sorted_max = sorted( + # max_scores_by_index.items(), key=lambda item: item[1], reverse=True + # ) + # max_indices = [idx for idx, _ in sorted_max] + # max_scores = [score for _, score in sorted_max] + + mean_indices = indices_tensor[:, 0] # .tolist() + mean_scores = scores_tensor[:, 0] # .tolist() + + top_k = min( + limit, scores_tensor.shape[0] + ) # Handle case where there are fewer than 50 + top_scores, top_positions = torch.topk(mean_scores, k=top_k, largest=True) + + top_k_indices = mean_indices[top_positions].tolist() + top_k_scores = top_scores.tolist() + + print("top k scores", top_k_scores) + + mean_indices = top_k_indices + mean_scores = top_k_scores + + print( + f"max score: {torch.tensor(mean_scores).max()}, " + f"min score: {torch.tensor(mean_scores).min()}" + ) + + # TODO get the top limit + # if limit is not None: + # max_indices = max_indices[:limit] + # max_scores = max_scores[:limit] + # mean_indices = mean_indices[:limit] + # mean_scores = mean_scores[:limit] + + data_cfg = DataConfig(dataset=args.dataset, split=args.split) + index_cfg = IndexConfig(run_path="", data=data_cfg, streaming=True) + + # max_ds = + # get_text_rows(max_indices, max_scores, index_cfg, score_column="max_score") + mean_ds = get_text_rows( + mean_indices, mean_scores, index_cfg, score_column="mean_score" + ) + + query_scores_path.mkdir(parents=True, exist_ok=True) + # max_csv, max_hf = save_strategy("max", max_ds, query_scores_path) + mean_csv, mean_hf = save_strategy("mean", mean_ds, query_scores_path) + + # print(f"Saved max top results to {max_csv} and {max_hf}") + print(f"Saved mean top results to {mean_csv} and {mean_hf}") + + +if __name__ == "__main__": + main() diff --git a/examples/assemble_query.py b/examples/assemble_query.py new file mode 100644 index 0000000..44148e7 --- /dev/null +++ b/examples/assemble_query.py @@ -0,0 +1,472 @@ +# TODAY +# Assemble dataset!! 6 queries +# Try multi node generation +# I believe the MCQA and Cloze setups are pulled from the same eval and are +# both roughly 1k rows, like the original wmdp-bio as a whole. 
+ +import os +import shutil +import socket +from collections import defaultdict +from pathlib import Path + +import numpy as np +import torch +import torch.multiprocessing as mp +from datasets import ( + Dataset, + concatenate_datasets, + get_dataset_config_names, + load_dataset, +) +from numpy.lib.recfunctions import ( + structured_to_unstructured, + unstructured_to_structured, +) +from torch.distributed.elastic.multiprocessing import DefaultLogsSpecs, start_processes +from transformers import AutoTokenizer + +from bergson import DataConfig, IndexConfig, load_gradients +from bergson.build import dist_worker, estimate_advantage, worker +from bergson.data import create_index, load_gradient_dataset +from bergson.utils import assert_type + + +def load_mcqa_dataset(): + def map(x): + """Many of the questions require the choices to be given to be coherent, e.g. + 'Which of the following is the correct answer to the question?'""" + + choices = [f"{i}. {choice}" for i, choice in enumerate(x["choices"])] + prompt = " \n ".join( + [x["question"]] + + ["Choices: "] + + choices + + ["Answer: "] + + [f"{choices[int(x['answer'])]}"] + ) + + return { + "text": prompt, + # "prompt": prompt, + # "completion": f"{choices[int(x['answer'])]}", + "subset": x["subset"], + } + + mcqa_dataset_name = "EleutherAI/wmdp_bio_robust_mcqa" + subsets = get_dataset_config_names(mcqa_dataset_name) + mcqa_datasets = [] + for subset in subsets: + ds = assert_type( + Dataset, load_dataset(mcqa_dataset_name, subset, split="robust") + ) + ds = ds.add_column("subset", [subset] * len(ds)) + mcqa_datasets.append(ds) + + mcqa_ds = concatenate_datasets(mcqa_datasets) + + return mcqa_ds.map(map, remove_columns=["choices", "answer", "question"]) + + +def tokenize( + batch: dict, *, args: DataConfig, tokenizer, apply_chat_template: bool = True +): + """Tokenize a batch of data with `tokenizer` according to `args`.""" + kwargs = dict( + return_attention_mask=False, + return_length=True, + truncation=args.truncation, + ) + if args.completion_column: + # We're dealing with a prompt-completion dataset + convos = [ + [ + {"role": "user", "content": assert_type(str, prompt)}, + {"role": "assistant", "content": assert_type(str, resp)}, + ] + for prompt, resp in zip( + batch[args.prompt_column], batch[args.completion_column] + ) + ] + elif args.conversation_column: + # We're dealing with a conversation dataset + convos = assert_type(list, batch[args.conversation_column]) + else: + # We're dealing with vanilla next-token prediction + return tokenizer(batch[args.prompt_column], **kwargs) + + strings = convos + + encodings = tokenizer(strings, **kwargs) + labels_list: list[list[int]] = [] + + for i, convo in enumerate(convos): + # Find the spans of the assistant's responses in the tokenized output + pos = 0 + spans: list[tuple[int, int]] = [] + + for msg in convo: + if msg["role"] != "assistant": + continue + + ans = msg["content"] + start = strings[i].rfind(ans, pos) + if start < 0: + raise RuntimeError( + "Failed to find completion in the chat-formatted conversation. " + "Make sure the chat template does not alter the completion, e.g. " + "by removing leading whitespace." 
+ ) + + # move past this match + pos = start + len(ans) + + start_token = encodings.char_to_token(i, start) + end_token = encodings.char_to_token(i, pos) + spans.append((start_token, end_token)) + + # Labels are -100 everywhere except where the assistant's response is + tokens = encodings["input_ids"][i] + labels = [-100] * len(tokens) + for start, end in spans: + if start is not None and end is not None: + labels[start:end] = tokens[start:end] + + labels_list.append(labels) + + return dict(**encodings, labels=labels_list) + + +def tokenize_mcqa( + batch: dict, + *, + tokenizer, + args: DataConfig, + answer_marker: str = "Answer:", +): + """ + Custom tokenizer for this MCQA experiment that only keeps labels on the + final answer span so gradient collection ignores the rest of the prompt. + + Codex wrote this. + """ + # TODO integrate custom masking into tokenize if necessary + return tokenize(batch, args=args, tokenizer=tokenizer, apply_chat_template=False) + + if not tokenizer.is_fast: + raise ValueError("Fast tokenizer required for answer span alignment.") + + kwargs = dict( + return_attention_mask=False, + return_length=True, + truncation=args.truncation, + return_offsets_mapping=True, + ) + + encodings = tokenizer(batch[args.prompt_column], **kwargs) + + whitespace = {" ", "\t", "\n", "\r", "\f", "\v"} + labels: list[list[int]] = [] + answer_token_indices: list[list[int]] = [] + + for text, offsets, token_ids in zip( + batch[args.prompt_column], + encodings["offset_mapping"], + encodings["input_ids"], + ): + marker_idx = text.rfind(answer_marker) + if marker_idx < 0: + raise ValueError(f"Failed to locate '{answer_marker}' in:\n{text}") + + start = marker_idx + len(answer_marker) + while start < len(text) and text[start] in whitespace: + start += 1 + + if start >= len(text): + raise ValueError( + f"No answer text found after '{answer_marker}' in:\n{text}" + ) + + end = len(text) + while end > start and text[end - 1] in whitespace: + end -= 1 + + if end <= start: + raise ValueError(f"Empty answer span detected in:\n{text}") + + example_labels = [-100] * len(token_ids) + supervised_indices: list[int] = [] + + for idx, span in enumerate(offsets): + if span is None: + continue + + tok_start, tok_end = span + if tok_start is None or tok_end is None or tok_start == tok_end: + continue + + if max(tok_start, start) < min(tok_end, end): + example_labels[idx] = token_ids[idx] + supervised_indices.append(idx) + + if not supervised_indices: + raise RuntimeError( + "Failed to align answer text with any tokens.\n" f"Example:\n{text}" + ) + + labels.append(example_labels) + answer_token_indices.append(supervised_indices) + + encodings.pop("offset_mapping") + encodings["labels"] = labels + encodings["answer_token_indices"] = answer_token_indices + return encodings + + +def build_mcqa_index(cfg: IndexConfig, ds_path: str): + # TODO Set labels mask to final answer + # TODO are these labels always one token? + + # In many cases the token_batch_size may be smaller than the max length allowed by + # the model. 
If cfg.data.truncation is True, we use the tokenizer to truncate + tokenizer = AutoTokenizer.from_pretrained(cfg.model, revision=cfg.revision) + tokenizer.model_max_length = min(tokenizer.model_max_length, cfg.token_batch_size) + + # Do all the data loading and preprocessing on the main process + ds = Dataset.load_from_disk(ds_path, keep_in_memory=False) + + remove_columns = ds.column_names if cfg.drop_columns else None + ds = ds.map( + tokenize_mcqa, + batched=True, + fn_kwargs=dict(args=cfg.data, tokenizer=tokenizer), + remove_columns=remove_columns, + ) + if cfg.data.reward_column: + assert isinstance(ds, Dataset), "Dataset required for advantage estimation" + ds = ds.add_column( + "advantage", + estimate_advantage(ds, cfg.data), + new_fingerprint="advantage", # type: ignore + ) + + world_size = torch.cuda.device_count() + if world_size <= 1: + # Run the worker directly if no distributed training is needed. This is great + # for debugging purposes. + worker(0, 1, cfg, ds) + else: + # Set up multiprocessing and distributed training + mp.set_sharing_strategy("file_system") + + # Find an available port for distributed training + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(("", 0)) + _, port = s.getsockname() + + ctx = start_processes( + "build", + dist_worker, + args={i: (i, world_size, cfg, ds) for i in range(world_size)}, + envs={ + i: { + "LOCAL_RANK": str(i), + "MASTER_ADDR": "localhost", + "MASTER_PORT": str(port), + } + for i in range(world_size) + }, + logs_specs=DefaultLogsSpecs(), + ) + ctx.wait() + + try: + os.rename(cfg.partial_run_path, cfg.run_path) + except Exception: + pass + + +def create_query_index( + query_ds: Dataset, run_path: str, assembled_dataset_path: str, index_dtype: np.dtype +): + structured_mmap = load_gradients(run_path) + mmap_dtype = structured_mmap.dtype + + # Copy into memory + gradient_tensor = torch.tensor(structured_to_unstructured(structured_mmap)).to( + torch.float32 + ) + + print("mmap sum", gradient_tensor.sum()) + print("mmap sum", gradient_tensor.abs().sum()) + + # Group mmap gradient rows by the subset they came from + subset_gradients = defaultdict(list) + for grads_row, ds_row in zip(gradient_tensor, query_ds): + subset_gradients[ds_row["subset"]].append(grads_row) + + subset_mean_gradients = {"overall": gradient_tensor.mean(dim=0)} + for subset, gradients in subset_gradients.items(): + mean_gradient = torch.stack(gradients).mean(dim=0) + subset_mean_gradients[subset] = mean_gradient + + # Copy everything from the origin run path to the new path + # except gradients.bin and data.hf + os.makedirs(assembled_dataset_path, exist_ok=True) + for item in os.listdir(run_path): + if item not in ["gradients.bin", "data.hf"]: + dest = Path(assembled_dataset_path) / item + shutil.copy(Path(run_path) / item, dest) + + if (Path(assembled_dataset_path) / "data.hf").exists(): + if (Path(assembled_dataset_path) / "data.hf").is_file(): + (Path(assembled_dataset_path) / "data.hf").unlink() + else: + shutil.rmtree(Path(assembled_dataset_path) / "data.hf") + + # Write structured mean queries to data.hf + np_mean_grads = np.stack( + [item.numpy() for item in list(subset_mean_gradients.values())], axis=0 + ) + # structured_np_mean_grads = unstructured_to_structured(np_mean_grads, mmap_dtype) + # data = [ + # { + # name: structured_np_mean_grads[name][i].tolist() + # for name in mmap_dtype.names + # } + # for i in range(structured_np_mean_grads.shape[0]) + # ] + + means_dataset = Dataset.from_dict( + { + "scores": [0.0] * 
len(subset_mean_gradients), + } + ) + means_dataset.save_to_disk(Path(assembled_dataset_path) / "data.hf") + + mean_grad_stack = torch.stack(list(subset_mean_gradients.values())) + first_query_grad = gradient_tensor[0].unsqueeze(0).expand_as(mean_grad_stack) + cosine_sims = torch.nn.functional.cosine_similarity( + mean_grad_stack, first_query_grad, dim=1 + ) + + # Assemble grad sizes + grad_sizes = {} + for name in mmap_dtype.names: + field_dtype = mmap_dtype.fields[name][0] + subdtype = field_dtype.subdtype + assert subdtype is not None + + _, shape = subdtype + grad_sizes[name] = int(np.prod(shape)) + + # Create and populate the index + index_grads = create_index( + str(assembled_dataset_path), len(subset_mean_gradients), grad_sizes, index_dtype + ) + index_grads[:] = unstructured_to_structured( + np_mean_grads.astype(index_dtype), mmap_dtype + ) + index_grads.flush() + + load_gradient_dataset(assembled_dataset_path) + + mean_grad_stack = torch.stack(list(subset_mean_gradients.values())) + first_query_grad = gradient_tensor[1].unsqueeze(0).expand_as(mean_grad_stack) + cosine_sims = torch.nn.functional.cosine_similarity( + mean_grad_stack, first_query_grad, dim=1 + ) + if torch.any(cosine_sims <= 0.09): + raise ValueError( + f"Cosine similarity between mean gradients and the first query gradient " + f"is not greater than 0.09. Cosine sims: {cosine_sims}" + ) + else: + print(f"Cosine sims: {cosine_sims}") + + +def main(): + projection_dim = 0 + model_name = "EleutherAI/deep_ignorance_pretraining_baseline_small" + ds_path = f"runs/ds_wmdp_bio_robust_mcqa_{projection_dim}" + index_path = f"runs/wmdp_bio_robust_mcqa_means_{projection_dim}" + assembled_dataset_path = f"runs/mean_biorisk_queries_{projection_dim}" + + mcqa_ds = assert_type(Dataset, load_mcqa_dataset()) + mcqa_ds.save_to_disk(ds_path) + + data_config = DataConfig( + dataset=ds_path, + # prompt_column="prompt", + # completion_column="completion", + prompt_column="text", + ) + + # cfg = IndexConfig( + # run_path=index_path, + # save_index=True, + # save_processor=True, + # precision="fp16", + # data=data_config, + # fsdp=True, + # model=model_name, + # projection_dim=projection_dim, + # reshape_to_square=True, + # ) + + cfg = IndexConfig( + run_path=index_path, + # save_index=True, + save_index=False, + # save_processor=True, + save_processor=False, + precision="fp16", + data=data_config, + fsdp=True, + model=model_name, + projection_dim=projection_dim, + reshape_to_square=True, + module_wise=True, + in_memory_index=True, + skip_preconditioners=True, + token_batch_size=1024, + ) + + # build_mcqa_index(cfg, ds_path) + + # Sum all the accumulated mean gradients into a single tensor + world_size = 8 + accum = {} + for rank in range(world_size): + accum_mean_mod_grads = torch.load( + os.path.join(cfg.run_path, f"accum_mean_grads_{rank}.pth") + ) + if not accum: + accum = accum_mean_mod_grads + else: + for name in accum_mean_mod_grads.keys(): + accum[name] += accum_mean_mod_grads[name] + + # Convert to HF DS + accum_ds = Dataset.from_dict({name: accum[name].numpy() for name in accum}) + accum_ds.save_to_disk( + os.path.join(cfg.run_path, "full_accum_mean_mod_grads.hf"), num_shards=1 + ) + + # Trackstar uses 2**16 with an 8B model + # We are collecting gradients for a ~2.7B model + # We are using ~2**13 I think + # modules = set(load_gradients(cfg.run_path).dtype.names) + # print( + # f"Full projection dim: {len(modules) * cfg.projection_dim * cfg.projection_dim}" + # ) + + exit() + + create_query_index( + mcqa_ds, cfg.run_path, 
assembled_dataset_path, index_dtype=np.float16 + ) + + +if __name__ == "__main__": + main() diff --git a/examples/setup_ds.py b/examples/setup_ds.py new file mode 100644 index 0000000..6ac1319 --- /dev/null +++ b/examples/setup_ds.py @@ -0,0 +1,23 @@ +import subprocess + +from datasets import load_dataset +from transformers import AutoModelForCausalLM, AutoTokenizer + +model_name = "EleutherAI/deep_ignorance_pretraining_baseline_small" +ds_name = "EleutherAI/deep-ignorance-pretraining-mix" +# model_name = "HuggingFaceTB/SmolLM2-135M" +# ds_name = "NeelNanda/pile-10k" + +model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto") +tokenizer = AutoTokenizer.from_pretrained(model_name) + +ds = load_dataset(ds_name, split="train") + +# Kick off a build to warm up the cache +subprocess.run( + ( + f"bergson build runs/test --model {model_name} --dataset {ds_name} " + f"--truncation --save_index False --split[:10]" + ), + shell=True, +) diff --git a/examples/warmup.sh b/examples/warmup.sh new file mode 100644 index 0000000..e709792 --- /dev/null +++ b/examples/warmup.sh @@ -0,0 +1,19 @@ +#!/bin/bash +#SBATCH --job-name=setup_ds +#SBATCH --output=setup_ds_%j.out +#SBATCH --error=setup_ds_%j.err +#SBATCH --time=5:00:00 +#SBATCH --gpus=1 # This allocates 72 CPU cores +#SBATCH --mail-type=BEGIN,END,FAIL +#SBATCH --mail-user=luciarosequirke@gmail.com + + +cd /home/lucia/bergson +source .venv/bin/activate + +# Set HF cache + +# Run the Python command +HUGGINGFACE_HUB_CACHE="/projects/a5k/public/lucia" \ +TRANSFORMERS_CACHE="/projects/a5k/public/lucia" \ +uv run python -m examples.setup_ds
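The static query path added in `bergson/static_query.py` can also be driven directly from Python. A minimal sketch, with placeholder paths for the on-disk gradient index and the prebuilt query index (field names follow the `QueryConfig`/`IndexConfig` definitions above):

```python
from bergson.data import IndexConfig, QueryConfig
from bergson.static_query import query_existing

query_cfg = QueryConfig(
    query_path="runs/query_index",  # existing query gradient index (placeholder)
    scores_path="runs/scores",      # scores.hf is written under this directory
    unit_normalize=True,
)
index_cfg = IndexConfig(run_path="runs/test")  # gradient index to search (placeholder)

# Builds (or mmaps) the FAISS shards for the index, searches them with the query
# gradients, and saves the raw scores and indices to runs/scores/scores.hf.
# Pass k=None to return every indexed example instead of the top 50.
query_existing(query_cfg, index_cfg, k=50)
```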