
Commit c3b24de

Merge pull request #53 from smarter/fix-ground-truth
ekfac: fix compute_ekfac_ground_truth, add minimal CI
2 parents (8232b77 + 7e350e6), commit c3b24de

28 files changed: +1428, -1639 lines

.github/workflows/build.yml

Lines changed: 39 additions & 0 deletions

@@ -0,0 +1,39 @@
+name: build
+
+on:
+  push:
+    branches:
+      - ekfac
+  pull_request:
+    branches:
+      - ekfac
+jobs:
+  build:
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/checkout@v4
+      - uses: actions/setup-python@v5
+        with:
+          python-version: "3.10"
+      - name: Install dependencies
+        run: |
+          python -m pip install --upgrade pip
+          pip install -e ".[dev,faiss]"
+      # TODO: Proper test infrastructure for tests/ekfac_tests
+      # - name: Run tests
+      #   run: pytest
+      # TODO: run pyright on whole codebase
+      - name: Type Checking bergson/hessians
+        uses: jakebailey/pyright-action@v1
+        with:
+          version: 1.1.406
+          working-directory: bergson/hessians
+      - name: Type Checking tests/ekfac_tests
+        uses: jakebailey/pyright-action@v1
+        with:
+          version: 1.1.406
+          working-directory: tests/ekfac_tests
+      - name: build
+        run: pip wheel --no-deps -w dist .
+        env:
+          HF_HUB_DOWNLOAD_TIMEOUT: 100

bergson/collection.py

Lines changed: 3 additions & 1 deletion

@@ -72,7 +72,9 @@ def callback(name: str, g: torch.Tensor):
     grad_sizes = {name: math.prod(s) for name, s in collector.shapes().items()}
 
     # Allocate structured space ahead of time for the gradients
-    grad_buffer = create_index(cfg.run_path, num_grads=len(data), grad_sizes=grad_sizes, dtype=np.float16)
+    grad_buffer = create_index(
+        cfg.run_path, num_grads=len(data), grad_sizes=grad_sizes, dtype=np.float16
+    )
 
     per_doc_losses = torch.full(
         (len(data),),
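
For context on the buffer allocated here: create_index (see the bergson/data.py diff below) returns a structured np.memmap with one float16 field per module. A minimal, self-contained sketch of that layout follows; the module names, sizes, row count, and the /tmp path are made up for illustration and are not bergson's actual implementation.

import os

import numpy as np

# Hypothetical sizes; in the commit these come from collector.shapes() as above.
grad_sizes = {"q_proj": 16, "v_proj": 16}
num_grads = 8  # one row per document

# One named float16 field per module, mirroring a structured on-disk layout.
dtype = np.dtype([(name, np.float16, (size,)) for name, size in grad_sizes.items()])

os.makedirs("/tmp/bergson_demo", exist_ok=True)
grad_buffer = np.memmap(
    os.path.join("/tmp/bergson_demo", "gradients.bin"),
    dtype=dtype,
    mode="w+",
    shape=(num_grads,),
)
grad_buffer["q_proj"][0] = np.ones(16, dtype=np.float16)  # write one row's gradient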

bergson/data.py

Lines changed: 25 additions & 8 deletions

@@ -16,6 +16,8 @@
 
 from .utils import assert_type
 
+Precision = Literal["bf16", "fp16", "fp32", "int4", "int8"]
+
 
 @dataclass
 class DataConfig:
@@ -48,7 +50,7 @@ class IndexConfig:
     fsdp: bool = False
     """Whether to use Fully Sharded Data Parallel (FSDP) for collecing gradients."""
 
-    precision: Literal["bf16", "fp16", "fp32", "int4", "int8"] = "bf16"
+    precision: Precision = "bf16"
     """Precision to use for the model parameters."""
 
     projection_dim: int = 16
@@ -99,7 +101,9 @@ def ceildiv(a: int, b: int) -> int:
     return -(-a // b)  # Equivalent to math.ceil(a / b) but faster for integers
 
 
-def allocate_batches(doc_lengths: list[int], N: int, world_size: Optional[int] = None) -> list[list[int]]:
+def allocate_batches(
+    doc_lengths: list[int], N: int, world_size: Optional[int] = None
+) -> list[list[int]]:
     """
     Allocate documents into batches that are then distributed evenly across
     a fixed number of workers.
@@ -183,7 +187,9 @@ def allocate_batches(doc_lengths: list[int], N: int, world_size: Optional[int] =
     while len(batches) < world_size:
         big = batches.pop(0)  # take the current largest
         if len(big) == 1:  # cannot split a singleton
-            raise RuntimeError("Not enough documents to give each worker at least one batch.")
+            raise RuntimeError(
+                "Not enough documents to give each worker at least one batch."
+            )
         batches.append([big.pop()])  # move one doc into new batch
         batches.append(big)  # put the remainder back
         # preserve cost constraint automatically
@@ -205,7 +211,9 @@ def allocate_batches(doc_lengths: list[int], N: int, world_size: Optional[int] =
         i += 1
 
     assert len(batches) == target_batches
-    assert all(max(doc_lengths[i] for i in batch) * len(batch) <= N for batch in batches)
+    assert all(
+        max(doc_lengths[i] for i in batch) * len(batch) <= N for batch in batches
+    )
 
     # ---------------------------------------------------------------------
     # 4) Round-robin assignment to workers
@@ -219,7 +227,9 @@ def allocate_batches(doc_lengths: list[int], N: int, world_size: Optional[int] =
     return allocation[rank]
 
 
-def create_index(root: str, num_grads: int, grad_sizes: dict[str, int], dtype: DTypeLike) -> np.memmap:
+def create_index(
+    root: str, num_grads: int, grad_sizes: dict[str, int], dtype: DTypeLike
+) -> np.memmap:
     """Create a memory-mapped file for storing structured gradients
     and persist metadata."""
     grad_path = os.path.join(root, "gradients.bin")
@@ -310,7 +320,9 @@ def load_shard(dir: str) -> Dataset:
     if concatenate_gradients:
         unstructured_data = structured_to_unstructured(mmap)
         flat = pa.array(unstructured_data.reshape(-1))
-        col_arrow = pa.FixedSizeListArray.from_arrays(flat, unstructured_data.shape[1])
+        col_arrow = pa.FixedSizeListArray.from_arrays(
+            flat, unstructured_data.shape[1]
+        )
 
         ds = ds.add_column("gradients", col_arrow, new_fingerprint="gradients")
     # Add a column for each module's gradient vectors
@@ -374,7 +386,9 @@ def tokenize(batch: dict, *, args: DataConfig, tokenizer):
                 {"role": "user", "content": assert_type(str, prompt)},
                 {"role": "assistant", "content": assert_type(str, resp)},
             ]
-            for prompt, resp in zip(batch[args.prompt_column], batch[args.completion_column])
+            for prompt, resp in zip(
+                batch[args.prompt_column], batch[args.completion_column]
+            )
         ]
     elif args.conversation_column:
         # We're dealing with a conversation dataset
@@ -421,4 +435,7 @@ def tokenize(batch: dict, *, args: DataConfig, tokenizer):
 def unflatten(x: torch.Tensor, shapes: dict[str, Sequence[int]], dim: int = -1):
     """Unflatten a tensor `x` into a dictionary of tensors with specified shapes."""
     numels = [math.prod(shape) for shape in shapes.values()]
-    return {name: x.unflatten(dim, shape) for (name, shape), x in zip(shapes.items(), x.split(numels, dim=dim))}
+    return {
+        name: x.unflatten(dim, shape)
+        for (name, shape), x in zip(shapes.items(), x.split(numels, dim=dim))
+    }
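
A note on the last hunk: unflatten reverses the flattening of per-module gradients into a single vector. The sketch below copies the function body from the hunk so it is self-contained; the shapes dict and the input tensor are invented purely for illustration.

import math
from typing import Sequence

import torch


def unflatten(x: torch.Tensor, shapes: dict[str, Sequence[int]], dim: int = -1):
    """Unflatten a tensor `x` into a dictionary of tensors with specified shapes."""
    numels = [math.prod(shape) for shape in shapes.values()]
    return {
        name: x.unflatten(dim, shape)
        for (name, shape), x in zip(shapes.items(), x.split(numels, dim=dim))
    }


shapes = {"q_proj": (4, 3), "k_proj": (2, 3)}             # hypothetical module shapes
flat = torch.arange(4 * 3 + 2 * 3, dtype=torch.float32)  # one flattened gradient row
parts = unflatten(flat, shapes)
assert parts["q_proj"].shape == (4, 3) and parts["k_proj"].shape == (2, 3)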

bergson/distributed.py

Lines changed: 24 additions & 7 deletions

@@ -64,22 +64,30 @@ def setup_data_pipeline(cfg: IndexConfig) -> Dataset | IterableDataset:
         ds = load_dataset(data_str, split="train")
 
         if isinstance(ds, DatasetDict) or isinstance(ds, IterableDatasetDict):
-            raise NotImplementedError("DatasetDicts and IterableDatasetDicts are not supported.")
+            raise NotImplementedError(
+                "DatasetDicts and IterableDatasetDicts are not supported."
+            )
     except ValueError as e:
         # Automatically use load_from_disk if appropriate
         if "load_from_disk" in str(e):
             ds = Dataset.load_from_disk(data_str, keep_in_memory=False)
         else:
             raise e
 
-    tokenizer = AutoTokenizer.from_pretrained(cfg.model, model_max_length=cfg.token_batch_size)
+    tokenizer = AutoTokenizer.from_pretrained(
+        cfg.model, model_max_length=cfg.token_batch_size
+    )
 
-    ds = ds.map(tokenize, batched=True, fn_kwargs=dict(args=cfg.data, tokenizer=tokenizer))
+    ds = ds.map(
+        tokenize, batched=True, fn_kwargs=dict(args=cfg.data, tokenizer=tokenizer)
+    )
 
     return ds
 
 
-def setup_model_and_peft(cfg: IndexConfig, rank: int, dtype: torch.dtype) -> tuple[AutoModelForCausalLM, set | None]:
+def setup_model_and_peft(
+    cfg: IndexConfig, rank: int, dtype: torch.dtype
+) -> tuple[AutoModelForCausalLM, set | None]:
     """Handle model loading, quantization, FSDP, and PEFT detection"""
 
     torch.manual_seed(42)
@@ -141,7 +149,9 @@ def setup_model_and_peft(cfg: IndexConfig, rank: int, dtype: torch.dtype) -> tup
                 model.get_submodule(processed_name)
                 target_modules.add(processed_name)
             except AttributeError:
-                print(f"Adapter parameter '{processed_name}' not found in the model.")
+                print(
+                    f"Adapter parameter '{processed_name}' not found in the model."
+                )
 
     # Configure gradients
     model.requires_grad_(False)
@@ -223,7 +233,11 @@ def worker_wrapper(
         case "fp32":
             dtype = torch.float32
         case "int4" | "int8":
-            dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
+            dtype = (
+                torch.bfloat16
+                if torch.cuda.is_bf16_supported()
+                else torch.float16
+            )
         case other:
             raise ValueError(f"Unsupported precision: {other}")
 
@@ -305,7 +319,10 @@ def distributed_computing(
     ctx = start_processes(
         "build",
         worker_wrapper,
-        args={i: (i, world_size, cfg, ds, worker_fn, setup_model, setup_processor) for i in range(world_size)},
+        args={
+            i: (i, world_size, cfg, ds, worker_fn, setup_model, setup_processor)
+            for i in range(world_size)
+        },
         envs={
             i: {
                 "LOCAL_RANK": str(i),

bergson/gradients.py

Lines changed: 18 additions & 8 deletions

@@ -162,7 +162,9 @@ def to_adafactor(self) -> AdafactorNormalizer:
         and the factored second moments.
         """
         # We assume avg_sq is a square matrix of shape [O, I]
-        assert self.avg_sq.ndim == 2, f"Expected 2D tensor for avg_sq, got {self.avg_sq.ndim}D"
+        assert (
+            self.avg_sq.ndim == 2
+        ), f"Expected 2D tensor for avg_sq, got {self.avg_sq.ndim}D"
 
         # Compute row and column means
         return AdafactorNormalizer(
@@ -213,9 +215,6 @@ def save(self, path: str):
             json.dump(cfg, f, indent=2)
 
 
-
-
-
 @dataclass
 class GradientCollector(ContextDecorator):
     """
@@ -346,7 +345,12 @@ def _save_input(self, module: nn.Module, inp: tuple, _):
         if p is not None and not isinstance(norm, AdamNormalizer):
             i = module.in_features
 
-            x = x @ self.projection(name=name, m=p, n=i, side="right", dtype=x.dtype, device=x.device).T
+            x = (
+                x
+                @ self.projection(
+                    name=name, m=p, n=i, side="right", dtype=x.dtype, device=x.device
+                ).T
+            )
 
         module._inputs = x
 
@@ -387,14 +391,20 @@ def _process_grad(self, module: nn.Module, _, grad_out):
 
             # Project the gradients to the lower-dimensional space
            if p is not None:
-                A = self.projection(name=name, m=p, n=o, side="left", dtype=G.dtype, device=G.device)
-                B = self.projection(name=name, m=p, n=i, side="right", dtype=G.dtype, device=G.device)
+                A = self.projection(
+                    name=name, m=p, n=o, side="left", dtype=G.dtype, device=G.device
+                )
+                B = self.projection(
+                    name=name, m=p, n=i, side="right", dtype=G.dtype, device=G.device
+                )
                 P = A @ P @ B.T  # [N, p, q]
 
         # Both Adafactor and no normalizer, we can project G first
         else:
             if p is not None:
-                A = self.projection(name=name, m=p, n=o, side="left", dtype=G.dtype, device=G.device)
+                A = self.projection(
+                    name=name, m=p, n=o, side="left", dtype=G.dtype, device=G.device
+                )
                 G = G @ A.T  # [N, S, p]
 
             P = G.mT @ I  # [N, O/p, S] @ [N, S, I/q] → [N, O/p, I/q]
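
The projection hunks above are easier to follow with shapes written out: the per-example gradient of a Linear layer is an outer product of output gradients and saved inputs, and A/B compress it from [O, I] to [p, q]. Below is a shape-only sketch with random Gaussian matrices standing in for self.projection; all sizes are made up.

import torch

N, S = 2, 5            # batch size, sequence length
out_f, in_f = 8, 6     # out_features (O) and in_features (I) of the Linear layer
p = q = 3              # projection dimensions

G = torch.randn(N, S, out_f)  # gradients w.r.t. the layer's outputs
X = torch.randn(N, S, in_f)   # saved layer inputs

P = G.mT @ X                  # [N, out_f, S] @ [N, S, in_f] -> [N, out_f, in_f]

A = torch.randn(p, out_f)     # stands in for projection(..., side="left")
B = torch.randn(q, in_f)      # stands in for projection(..., side="right")

P_small = A @ P @ B.T         # [N, p, q], as in `P = A @ P @ B.T` above
assert P_small.shape == (N, p, q)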

bergson/hessians/attribute.py

Lines changed: 25 additions & 8 deletions

@@ -13,14 +13,20 @@
 # ## 1. Load index for query and train data
 
 parser = argparse.ArgumentParser(description="Process normalization flag.")
-parser.add_argument("--normalize", action="store_true", help="Gradients will be unit normalized.")
+parser.add_argument(
+    "--normalize", action="store_true", help="Gradients will be unit normalized."
+)
 args = parser.parse_args()
 
 device = "cuda:1"
 
 # %%
-base_path = "/mnt/ssd-1/gpaulo/emergent-misalignment/emergent-misalignment-eleuther/data/"
-index_dataset = load_dataset("json", data_files=f"{base_path}merged-medical-reformatted.jsonl")["train"]
+base_path = (
+    "/mnt/ssd-1/gpaulo/emergent-misalignment/emergent-misalignment-eleuther/data/"
+)
+index_dataset = load_dataset(
+    "json", data_files=f"{base_path}merged-medical-reformatted.jsonl"
+)["train"]
 index_path = "/mnt/ssd-1/gpaulo/emergent-misalignment/qwen14_merged_medical_proj16/merged_medical_no_normalizer"
 queries_path = "/mnt/ssd-1/louis/emergent_misalignment/test_query_ekfac"
 
@@ -37,17 +43,25 @@
 normalize = args.normalize
 
 attribution_dict = {}
-output_path = "/mnt/ssd-1/louis/emergent_misalignment/test_query_ekfac_attribution_no_normalizer"
+output_path = (
+    "/mnt/ssd-1/louis/emergent_misalignment/test_query_ekfac_attribution_no_normalizer"
+)
 if normalize:
     output_path += "_unit_norm"
 os.makedirs(output_path, exist_ok=True)
 
 for name in tqdm(list(names)):
     index_tensor = torch.from_numpy(index[name]).to(device=device, dtype=torch.float32)
-    queries_tensor = torch.from_numpy(queries[name]).to(device=device, dtype=torch.float32)
+    queries_tensor = torch.from_numpy(queries[name]).to(
+        device=device, dtype=torch.float32
+    )
     if normalize:
-        index_tensor = index_tensor / (torch.norm(index_tensor, dim=1, keepdim=True) + 1e-10)
-        queries_tensor = queries_tensor / (torch.norm(queries_tensor, dim=1, keepdim=True) + 1e-10)
+        index_tensor = index_tensor / (
+            torch.norm(index_tensor, dim=1, keepdim=True) + 1e-10
+        )
+        queries_tensor = queries_tensor / (
+            torch.norm(queries_tensor, dim=1, keepdim=True) + 1e-10
+        )
     # Compute result on GPU
     result_tensor = index_tensor @ queries_tensor.T
 
@@ -56,7 +70,10 @@
 
     # Create memory-mapped file with .bin extension
     mmap_file = np.memmap(
-        os.path.join(output_path, f"{name}_attribution.npy"), dtype=np.float32, mode="w+", shape=result_shape
+        os.path.join(output_path, f"{name}_attribution.npy"),
+        dtype=np.float32,
+        mode="w+",
+        shape=result_shape,
     )
 
     # Copy GPU result directly to memmap
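
As a side note on the --normalize flag handled above: dividing both gradient matrices by their row norms turns the dot-product attribution scores into cosine similarities. A small illustration with random stand-in tensors (the shapes and the CPU device are arbitrary here, not the script's real data):

import torch

index_tensor = torch.randn(4, 16)    # stand-in for one module's training gradients
queries_tensor = torch.randn(3, 16)  # stand-in for the query gradients

raw_scores = index_tensor @ queries_tensor.T  # [4, 3] dot products

eps = 1e-10
index_unit = index_tensor / (index_tensor.norm(dim=1, keepdim=True) + eps)
queries_unit = queries_tensor / (queries_tensor.norm(dim=1, keepdim=True) + eps)
cosine_scores = index_unit @ queries_unit.T   # [4, 3] cosine similarities in [-1, 1]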
