14 changes: 12 additions & 2 deletions README.md
@@ -71,7 +71,11 @@ trainer.train()

## Attention Head Gradients

By default Bergson collects gradients for named parameter matrices, but gradients for individual attention heads within an attention module can be collected too. To collect per-head gradients configure an AttentionConfig for each module of interest.
By default Bergson collects gradients for named parameter matrices, but gradients for individual attention heads within an attention module may also be collected. To collect per-head gradients, configure an AttentionConfig programmatically for each module of interest, or specify the attention modules and a single shared configuration via the command-line tool.

```bash
bergson build runs/test --model EleutherAI/pythia-14m --dataset NeelNanda/pile-10k --truncation --split_attention_modules "h.0.attn.attention.out_proj" --attention.num_heads 16 --attention.head_size 4 --attention.head_dim 2
```
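
Here `--split_attention_modules` selects the attention modules whose gradients are split per head, and the `--attention.*` flags populate the single shared `AttentionConfig`.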

```python
from bergson import AttentionConfig, IndexConfig, DataConfig
@@ -99,7 +103,7 @@ Where a reward signal is available we compute gradients using a weighted advanta
bergson build <output_path> --model <model_name> --dataset <dataset_name> --reward_column <reward_column_name>
```

## Queries
## Index Queries

We provide a query `Attributor` that supports unit-normalized gradients and KNN search out of the box.

@@ -127,6 +131,12 @@ with attr.trace(model.base_model, 5) as result:
model.zero_grad()
```

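For intuition, the scoring step in `Attributor.search` reduces to inner products between query gradients and the stored index gradients, followed by top-k. Below is a standalone sketch of that math; the module name and shapes are illustrative, and summing per-module scores before top-k is an assumption, since the reduction is not shown in this diff.

```python
import torch

# Illustrative index: 3 stored examples, gradient dimension 8, one module.
index_grads = {"out_proj": torch.randn(3, 8)}
query = {"out_proj": torch.randn(1, 8)}

# Per-module inner-product scores, stacked as in Attributor.search.
scores = torch.stack(
    [query[m] @ index_grads[m].mT for m in index_grads], dim=-1
)

# Reduce over modules and take the k most similar indexed examples.
top_scores, top_indices = scores.sum(-1).topk(k=2)
```
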
## On-The-Fly Queries

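Scores for a dataset can also be computed against an existing query index while its gradients are collected, writing per-document scores to disk without persisting the new gradient index (`--save_index False`):
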
```bash
bergson query runs/scores --query_path runs/query_index --model EleutherAI/pythia-14m --dataset NeelNanda/pile-10k --save_index False --truncation
```

# Development

```bash
27 changes: 24 additions & 3 deletions bergson/__main__.py
@@ -6,7 +6,26 @@

from .build import build_gradient_dataset
from .data import IndexConfig, QueryConfig
from .query import query_gradient_dataset
from .dynamic_query import query_gradient_dataset
from .static_query import query_existing


@dataclass
class StaticQuery:
"""Query an on-disk gradient index."""

query_cfg: QueryConfig

index_cfg: IndexConfig

k: int | None = None

def execute(self):
"""Query an on-disk gradient index."""
assert self.query_cfg.scores_path
assert self.query_cfg.query_path

query_existing(self.query_cfg, self.index_cfg, self.k)


@dataclass
@@ -35,12 +54,14 @@ class Query:

def execute(self):
"""Query the gradient dataset."""
assert self.query_cfg.scores_path
assert self.query_cfg.query_path

if os.path.exists(self.index_cfg.run_path) and self.index_cfg.save_index:
raise ValueError(
"Index path already exists and save_index is True - "
"running this query will overwrite the existing gradients. "
"If you meant to query the existing gradients, use "
"If you meant to query the existing gradients use "
"Attributor instead."
)

@@ -51,7 +72,7 @@ def execute(self):
class Main:
"""Routes to the subcommands."""

command: Union[Build, Query]
command: Union[Build, Query, StaticQuery]

def execute(self):
"""Run the script."""
9 changes: 6 additions & 3 deletions bergson/attributor.py
@@ -74,8 +74,11 @@ def __init__(
self.grads[name] /= norm

def search(
self, queries: dict[str, Tensor], k: int, modules: list[str] | None = None
) -> tuple[Tensor, Tensor]:
self,
queries: dict[str, Tensor],
k: int | None,
modules: list[str] | None = None,
):
"""
Search for the `k` nearest examples in the index based on the query or queries.

@@ -112,7 +115,7 @@ def search(
)

modules = modules or list(q.keys())
k = min(k, self.N)
k = min(k or self.N, self.N)
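# A k of None (or 0) falls back to self.N, scoring every indexed example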

scores = torch.stack(
[q[name] @ self.grads[name].mT for name in modules], dim=-1
23 changes: 5 additions & 18 deletions bergson/collection.py
@@ -1,5 +1,5 @@
import math
from typing import Callable, Literal
from typing import Literal

import numpy as np
import torch
@@ -9,7 +9,7 @@
from tqdm.auto import tqdm
from transformers import PreTrainedModel

from .data import create_index, pad_and_tensor
from .data import Query, create_index, pad_and_tensor
from .gradients import AttentionConfig, GradientCollector, GradientProcessor
from .peft import set_peft_enabled

@@ -29,7 +29,7 @@ def collect_gradients(
save_index: bool = True,
save_processor: bool = True,
drop_columns: bool = False,
query_callback: Callable[[dict[str, torch.Tensor]], torch.Tensor] | None = None,
query: Query | None = None,
):
"""
Compute projected gradients using a subset of the dataset.
@@ -94,12 +94,6 @@ def callback(name: str, g: torch.Tensor):
dtype=dtype,
fill_value=0.0,
)
per_doc_scores = torch.full(
(len(data),),
device=model.device,
dtype=dtype,
fill_value=0.0,
)

for indices in tqdm(batches, disable=rank != 0, desc="Building index"):
batch = data[indices]
@@ -156,9 +150,8 @@ def callback(name: str, g: torch.Tensor):
for module_name in mod_grads.keys():
grad_buffer[module_name][indices] = mod_grads[module_name].numpy()

if query_callback is not None:
scores = query_callback(mod_grads)
per_doc_scores[indices] = scores.detach().type_as(per_doc_scores)
if query:
query(indices, mod_grads)
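# Query buffers these scores and flushes them to scores_path once all rows are written (see bergson/data.py)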

mod_grads.clear()
per_doc_losses[indices] = losses.detach().type_as(per_doc_losses)
@@ -178,12 +171,6 @@ def callback(name: str, g: torch.Tensor):
feature=Value("float16" if dtype == torch.float16 else "float32"),
new_fingerprint="loss",
)
data = data.add_column(
"scores",
per_doc_scores.cpu().numpy(),
feature=Value("float16" if dtype == torch.float16 else "float32"),
new_fingerprint="scores",
)
data.save_to_disk(path + "/data.hf")

if save_processor:
98 changes: 81 additions & 17 deletions bergson/data.py
@@ -2,9 +2,10 @@
import math
import os
import random
import uuid
from dataclasses import dataclass
from pathlib import Path
from typing import Literal, Sequence
from typing import Callable, Literal, Sequence

import numpy as np
import pyarrow as pa
@@ -25,11 +26,65 @@
from .utils import assert_type


class Query:
"""
Wraps a query scoring callback and stores the resulting scores in a tensor.
"""

def __init__(
self,
query_callback: Callable[[dict[str, torch.Tensor]], torch.Tensor],
num_items: int,
num_scores: int,
scores_path: str,
*,
dtype: torch.dtype = torch.float32,
device: torch.device | str = "cpu",
):
self._query_callback = query_callback
self._scores_path = scores_path
self.scores = torch.zeros((num_items, num_scores), dtype=dtype, device=device)
self.num_written = 0

def __call__(self, indices: list[int], mod_grads: dict[str, torch.Tensor]):
scores = self._query_callback(mod_grads).detach()
if scores.ndim == 1:
scores = scores.unsqueeze(-1)
assert scores.shape[0] == len(indices)
assert scores.shape[1] == self.scores.shape[1]

scores = scores.to(device=self.scores.device, dtype=self.scores.dtype)
self.scores[indices] = scores
self.num_written += len(indices)

if self.num_written >= len(self.scores):
self.flush()

def flush(self):
dataset = Dataset.from_dict(
{
"scores": self.scores.cpu().numpy().tolist(),
}
)
try:
dataset.save_to_disk(self._scores_path)
except Exception as e:
# Handle collisions with existing datasets
print(f"Error writing scores to disk: {e}")
random_hash = str(uuid.uuid4())[:8]
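# Note: if the scores path lacks a ".hf" suffix, replace() is a no-op and the retry hits the same path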
alternate_path = self._scores_path.replace(".hf", f"_{random_hash}.hf")
print(f"Writing to alternate path: {alternate_path}")
dataset.save_to_disk(alternate_path)

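A minimal usage sketch of `Query`, assuming this PR's module layout (the callback and shapes below are toy examples):

```python
import torch
from bergson.data import Query

# Toy callback: one score per document, the mean gradient norm across modules.
def score(mod_grads: dict[str, torch.Tensor]) -> torch.Tensor:
    return torch.stack([g.norm(dim=-1) for g in mod_grads.values()]).mean(dim=0)

query = Query(score, num_items=4, num_scores=1, scores_path="scores.hf")
query([0, 1], {"mlp": torch.randn(2, 8)})  # buffers the first two rows
query([2, 3], {"mlp": torch.randn(2, 8)})  # final rows trigger flush() to disk
```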

@dataclass
class DataConfig:
dataset: str = "EleutherAI/SmolLM2-135M-10B"
"""Dataset identifier to build the index from."""

subset: str | None = None
"""Subset of the dataset to use for building the index."""

split: str = "train"
"""Split of the dataset to use for building the index."""

@@ -70,29 +125,30 @@ class QueryConfig:
"""Config for querying an index on the fly."""

query_path: str = ""
"""Path to the query dataset."""
"""Path to the existing query index."""

query_method: Literal["mean", "nearest"] = "mean"
"""Method to use for computing the query."""
score: Literal["mean", "nearest", "individual"] = "mean"
"""Method for scoring the gradients with the query. If mean
gradients will be scored by their similarity with the mean
query gradients, otherwise by the most similar query gradient."""

save_processor: bool = True
"""Whether to write the query dataset gradient processor
to disk."""
scores_path: str = ""
"""Path to the directory where query scores should be written."""

query_preconditioner_path: str | None = None
"""Path to a precomputed preconditioner. The precomputed
preconditioner is applied to the query dataset gradients."""
"""Path to a precomputed preconditioner to be applied to
the query dataset gradients."""

index_preconditioner_path: str | None = None
"""Path to a precomputed preconditioner. The precomputed
preconditioner is applied to the query dataset gradients.
This does not affect the ability to compute a new
preconditioner during gradient collection."""
"""Path to a precomputed preconditioner to be applied to
the query dataset gradients. This does not affect the
ability to compute a new preconditioner during gradient
collection."""

mixing_coefficient: float = 0.5
mixing_coefficient: float = 0.99
"""Coefficient to weight the application of the query preconditioner
and the pre-computed index preconditioner. 0.0 means only use the
query preconditioner and 1.0 means only use the index preconditioner."""
index preconditioner and 1.0 means only use the query preconditioner."""

modules: list[str] = field(default_factory=list)
"""Modules to use for the query. If empty, all modules will be used."""
@@ -357,7 +413,10 @@ def create_index(


def load_data_string(
data_str: str, split: str = "train", streaming: bool = False
data_str: str,
split: str = "train",
subset: str | None = None,
streaming: bool = False,
) -> Dataset | IterableDataset:
"""Load a dataset from a string identifier or path."""
if data_str.endswith(".csv"):
@@ -366,7 +425,12 @@ def load_data_string(
ds = assert_type(Dataset, Dataset.from_json(data_str))
else:
try:
ds = load_dataset(data_str, split=split, streaming=streaming)
if subset:
ds = load_dataset(
data_str, name=subset, split=split, streaming=streaming
)
else:
ds = load_dataset(data_str, split=split, streaming=streaming)

if isinstance(ds, DatasetDict) or isinstance(ds, IterableDatasetDict):
raise NotImplementedError(