Commit 376c7fd

running precommit hook
1 parent 7df9e5e commit 376c7fd

19 files changed (+385, -131 lines)

bergson/collection.py
Lines changed: 3 additions & 1 deletion

@@ -72,7 +72,9 @@ def callback(name: str, g: torch.Tensor):
     grad_sizes = {name: math.prod(s) for name, s in collector.shapes().items()}
 
     # Allocate structured space ahead of time for the gradients
-    grad_buffer = create_index(cfg.run_path, num_grads=len(data), grad_sizes=grad_sizes, dtype=np.float16)
+    grad_buffer = create_index(
+        cfg.run_path, num_grads=len(data), grad_sizes=grad_sizes, dtype=np.float16
+    )
 
     per_doc_losses = torch.full(
         (len(data),),

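The only change above is line wrapping. For readers unfamiliar with the call being wrapped: `create_index` (defined in bergson/data.py below) returns a structured `np.memmap` with one fixed-size gradient field per module. A minimal sketch of such a buffer, with hypothetical module names and sizes:

```python
import numpy as np

# Hypothetical module names and per-module gradient sizes, for illustration only.
grad_sizes = {"mlp.down_proj": 16, "self_attn.q_proj": 16}
dtype = np.dtype([(name, np.float16, size) for name, size in grad_sizes.items()])

# One row per document, one fixed-size float16 vector per module.
buf = np.memmap("gradients.bin", dtype=dtype, mode="w+", shape=(4,))
buf["mlp.down_proj"][0] = np.ones(16, dtype=np.float16)
buf.flush()  # persist the mapped pages to disk
```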
bergson/data.py
Lines changed: 23 additions & 7 deletions

@@ -18,6 +18,7 @@
 
 Precision = Literal["bf16", "fp16", "fp32", "int4", "int8"]
 
+
 @dataclass
 class DataConfig:
     dataset: str = "EleutherAI/SmolLM2-135M-10B"

@@ -100,7 +101,9 @@ def ceildiv(a: int, b: int) -> int:
     return -(-a // b)  # Equivalent to math.ceil(a / b) but faster for integers
 
 
-def allocate_batches(doc_lengths: list[int], N: int, world_size: Optional[int] = None) -> list[list[int]]:
+def allocate_batches(
+    doc_lengths: list[int], N: int, world_size: Optional[int] = None
+) -> list[list[int]]:
     """
     Allocate documents into batches that are then distributed evenly across
     a fixed number of workers.

@@ -184,7 +187,9 @@ def allocate_batches(doc_lengths: list[int], N: int, world_size: Optional[int] =
     while len(batches) < world_size:
         big = batches.pop(0)  # take the current largest
         if len(big) == 1:  # cannot split a singleton
-            raise RuntimeError("Not enough documents to give each worker at least one batch.")
+            raise RuntimeError(
+                "Not enough documents to give each worker at least one batch."
+            )
         batches.append([big.pop()])  # move one doc into new batch
         batches.append(big)  # put the remainder back
         # preserve cost constraint automatically

@@ -206,7 +211,9 @@ def allocate_batches(doc_lengths: list[int], N: int, world_size: Optional[int] =
         i += 1
 
     assert len(batches) == target_batches
-    assert all(max(doc_lengths[i] for i in batch) * len(batch) <= N for batch in batches)
+    assert all(
+        max(doc_lengths[i] for i in batch) * len(batch) <= N for batch in batches
+    )
 
     # ---------------------------------------------------------------------
     # 4) Round-robin assignment to workers

@@ -220,7 +227,9 @@ def allocate_batches(doc_lengths: list[int], N: int, world_size: Optional[int] =
     return allocation[rank]
 
 
-def create_index(root: str, num_grads: int, grad_sizes: dict[str, int], dtype: DTypeLike) -> np.memmap:
+def create_index(
+    root: str, num_grads: int, grad_sizes: dict[str, int], dtype: DTypeLike
+) -> np.memmap:
     """Create a memory-mapped file for storing structured gradients
     and persist metadata."""
     grad_path = os.path.join(root, "gradients.bin")

@@ -311,7 +320,9 @@ def load_shard(dir: str) -> Dataset:
     if concatenate_gradients:
         unstructured_data = structured_to_unstructured(mmap)
         flat = pa.array(unstructured_data.reshape(-1))
-        col_arrow = pa.FixedSizeListArray.from_arrays(flat, unstructured_data.shape[1])
+        col_arrow = pa.FixedSizeListArray.from_arrays(
+            flat, unstructured_data.shape[1]
+        )
 
         ds = ds.add_column("gradients", col_arrow, new_fingerprint="gradients")
     # Add a column for each module's gradient vectors

@@ -375,7 +386,9 @@ def tokenize(batch: dict, *, args: DataConfig, tokenizer):
                 {"role": "user", "content": assert_type(str, prompt)},
                 {"role": "assistant", "content": assert_type(str, resp)},
             ]
-            for prompt, resp in zip(batch[args.prompt_column], batch[args.completion_column])
+            for prompt, resp in zip(
+                batch[args.prompt_column], batch[args.completion_column]
+            )
         ]
     elif args.conversation_column:
         # We're dealing with a conversation dataset

@@ -422,4 +435,7 @@ def tokenize(batch: dict, *, args: DataConfig, tokenizer):
 def unflatten(x: torch.Tensor, shapes: dict[str, Sequence[int]], dim: int = -1):
     """Unflatten a tensor `x` into a dictionary of tensors with specified shapes."""
     numels = [math.prod(shape) for shape in shapes.values()]
-    return {name: x.unflatten(dim, shape) for (name, shape), x in zip(shapes.items(), x.split(numels, dim=dim))}
+    return {
+        name: x.unflatten(dim, shape)
+        for (name, shape), x in zip(shapes.items(), x.split(numels, dim=dim))
+    }

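All eight hunks above are pure line wrapping. The reflowed `unflatten` comprehension is the least obvious of them; a round-trip sketch of what it computes, with made-up module shapes:

```python
import math

import torch

# Hypothetical per-module shapes; a flat vector is split and reshaped per module.
shapes = {"fc1.weight": (4, 3), "fc2.weight": (2, 4)}
numels = [math.prod(shape) for shape in shapes.values()]  # [12, 8]

flat = torch.arange(sum(numels), dtype=torch.float32)
parts = {
    name: chunk.unflatten(-1, shape)
    for (name, shape), chunk in zip(shapes.items(), flat.split(numels, dim=-1))
}
assert parts["fc1.weight"].shape == (4, 3)
assert parts["fc2.weight"].shape == (2, 4)
```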
bergson/distributed.py
Lines changed: 24 additions & 7 deletions

@@ -64,22 +64,30 @@ def setup_data_pipeline(cfg: IndexConfig) -> Dataset | IterableDataset:
         ds = load_dataset(data_str, split="train")
 
         if isinstance(ds, DatasetDict) or isinstance(ds, IterableDatasetDict):
-            raise NotImplementedError("DatasetDicts and IterableDatasetDicts are not supported.")
+            raise NotImplementedError(
+                "DatasetDicts and IterableDatasetDicts are not supported."
+            )
     except ValueError as e:
         # Automatically use load_from_disk if appropriate
         if "load_from_disk" in str(e):
             ds = Dataset.load_from_disk(data_str, keep_in_memory=False)
         else:
             raise e
 
-    tokenizer = AutoTokenizer.from_pretrained(cfg.model, model_max_length=cfg.token_batch_size)
+    tokenizer = AutoTokenizer.from_pretrained(
+        cfg.model, model_max_length=cfg.token_batch_size
+    )
 
-    ds = ds.map(tokenize, batched=True, fn_kwargs=dict(args=cfg.data, tokenizer=tokenizer))
+    ds = ds.map(
+        tokenize, batched=True, fn_kwargs=dict(args=cfg.data, tokenizer=tokenizer)
+    )
 
     return ds
 
 
-def setup_model_and_peft(cfg: IndexConfig, rank: int, dtype: torch.dtype) -> tuple[AutoModelForCausalLM, set | None]:
+def setup_model_and_peft(
+    cfg: IndexConfig, rank: int, dtype: torch.dtype
+) -> tuple[AutoModelForCausalLM, set | None]:
     """Handle model loading, quantization, FSDP, and PEFT detection"""
 
     torch.manual_seed(42)

@@ -141,7 +149,9 @@ def setup_model_and_peft(cfg: IndexConfig, rank: int, dtype: torch.dtype) -> tup
                 model.get_submodule(processed_name)
                 target_modules.add(processed_name)
             except AttributeError:
-                print(f"Adapter parameter '{processed_name}' not found in the model.")
+                print(
+                    f"Adapter parameter '{processed_name}' not found in the model."
+                )
 
     # Configure gradients
     model.requires_grad_(False)

@@ -223,7 +233,11 @@ def worker_wrapper(
         case "fp32":
             dtype = torch.float32
         case "int4" | "int8":
-            dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
+            dtype = (
+                torch.bfloat16
+                if torch.cuda.is_bf16_supported()
+                else torch.float16
+            )
         case other:
             raise ValueError(f"Unsupported precision: {other}")
 

@@ -305,7 +319,10 @@ def distributed_computing(
     ctx = start_processes(
         "build",
         worker_wrapper,
-        args={i: (i, world_size, cfg, ds, worker_fn, setup_model, setup_processor) for i in range(world_size)},
+        args={
+            i: (i, world_size, cfg, ds, worker_fn, setup_model, setup_processor)
+            for i in range(world_size)
+        },
         envs={
             i: {
                 "LOCAL_RANK": str(i),

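Again formatting only. The reflowed `"int4" | "int8"` arm picks a half-precision compute dtype by hardware capability; the same check as a standalone sketch (the function name is ours, not the repo's):

```python
import torch

def pick_compute_dtype() -> torch.dtype:
    # Quantized weights still need a floating compute dtype; prefer bf16
    # when the GPU supports it, otherwise fall back to fp16.
    if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
        return torch.bfloat16
    return torch.float16
```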
bergson/gradients.py
Lines changed: 18 additions & 8 deletions

@@ -162,7 +162,9 @@ def to_adafactor(self) -> AdafactorNormalizer:
         and the factored second moments.
         """
         # We assume avg_sq is a square matrix of shape [O, I]
-        assert self.avg_sq.ndim == 2, f"Expected 2D tensor for avg_sq, got {self.avg_sq.ndim}D"
+        assert (
+            self.avg_sq.ndim == 2
+        ), f"Expected 2D tensor for avg_sq, got {self.avg_sq.ndim}D"
 
         # Compute row and column means
         return AdafactorNormalizer(

@@ -213,9 +215,6 @@ def save(self, path: str):
             json.dump(cfg, f, indent=2)
 
 
-
-
-
 @dataclass
 class GradientCollector(ContextDecorator):
     """

@@ -346,7 +345,12 @@ def _save_input(self, module: nn.Module, inp: tuple, _):
         if p is not None and not isinstance(norm, AdamNormalizer):
             i = module.in_features
 
-            x = x @ self.projection(name=name, m=p, n=i, side="right", dtype=x.dtype, device=x.device).T
+            x = (
+                x
+                @ self.projection(
+                    name=name, m=p, n=i, side="right", dtype=x.dtype, device=x.device
+                ).T
+            )
 
         module._inputs = x

@@ -387,14 +391,20 @@ def _process_grad(self, module: nn.Module, _, grad_out):
 
             # Project the gradients to the lower-dimensional space
             if p is not None:
-                A = self.projection(name=name, m=p, n=o, side="left", dtype=G.dtype, device=G.device)
-                B = self.projection(name=name, m=p, n=i, side="right", dtype=G.dtype, device=G.device)
+                A = self.projection(
+                    name=name, m=p, n=o, side="left", dtype=G.dtype, device=G.device
+                )
+                B = self.projection(
+                    name=name, m=p, n=i, side="right", dtype=G.dtype, device=G.device
+                )
                 P = A @ P @ B.T  # [N, p, q]
 
         # Both Adafactor and no normalizer, we can project G first
         else:
             if p is not None:
-                A = self.projection(name=name, m=p, n=o, side="left", dtype=G.dtype, device=G.device)
+                A = self.projection(
+                    name=name, m=p, n=o, side="left", dtype=G.dtype, device=G.device
+                )
                 G = G @ A.T  # [N, S, p]
 
             P = G.mT @ I  # [N, O/p, S] @ [N, S, I/q] → [N, O/p, I/q]

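The wrapped `self.projection` calls implement a two-sided projection: each per-example gradient matrix of shape [O, I] is compressed to [p, q] via a left factor A and a right factor B, as in `P = A @ P @ B.T`. A toy version with illustrative dimensions and plain Gaussian factors (the repo's `projection` may construct its matrices differently):

```python
import torch

# Illustrative sizes: N examples, O x I gradient matrices, p x q targets.
N, O, I, p, q = 2, 8, 6, 3, 3
G = torch.randn(N, O, I)
A = torch.randn(p, O) / O**0.5  # "left" factor
B = torch.randn(q, I) / I**0.5  # "right" factor

P = A @ G @ B.T  # broadcasts over N: [N, p, q]
assert P.shape == (N, p, q)
```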
bergson/hessians/attribute.py
Lines changed: 25 additions & 8 deletions

@@ -13,14 +13,20 @@
 # ## 1. Load index for query and train data
 
 parser = argparse.ArgumentParser(description="Process normalization flag.")
-parser.add_argument("--normalize", action="store_true", help="Gradients will be unit normalized.")
+parser.add_argument(
+    "--normalize", action="store_true", help="Gradients will be unit normalized."
+)
 args = parser.parse_args()
 
 device = "cuda:1"
 
 # %%
-base_path = "/mnt/ssd-1/gpaulo/emergent-misalignment/emergent-misalignment-eleuther/data/"
-index_dataset = load_dataset("json", data_files=f"{base_path}merged-medical-reformatted.jsonl")["train"]
+base_path = (
+    "/mnt/ssd-1/gpaulo/emergent-misalignment/emergent-misalignment-eleuther/data/"
+)
+index_dataset = load_dataset(
+    "json", data_files=f"{base_path}merged-medical-reformatted.jsonl"
+)["train"]
 index_path = "/mnt/ssd-1/gpaulo/emergent-misalignment/qwen14_merged_medical_proj16/merged_medical_no_normalizer"
 queries_path = "/mnt/ssd-1/louis/emergent_misalignment/test_query_ekfac"

@@ -37,17 +43,25 @@
 normalize = args.normalize
 
 attribution_dict = {}
-output_path = "/mnt/ssd-1/louis/emergent_misalignment/test_query_ekfac_attribution_no_normalizer"
+output_path = (
+    "/mnt/ssd-1/louis/emergent_misalignment/test_query_ekfac_attribution_no_normalizer"
+)
 if normalize:
     output_path += "_unit_norm"
 os.makedirs(output_path, exist_ok=True)
 
 for name in tqdm(list(names)):
     index_tensor = torch.from_numpy(index[name]).to(device=device, dtype=torch.float32)
-    queries_tensor = torch.from_numpy(queries[name]).to(device=device, dtype=torch.float32)
+    queries_tensor = torch.from_numpy(queries[name]).to(
+        device=device, dtype=torch.float32
+    )
     if normalize:
-        index_tensor = index_tensor / (torch.norm(index_tensor, dim=1, keepdim=True) + 1e-10)
-        queries_tensor = queries_tensor / (torch.norm(queries_tensor, dim=1, keepdim=True) + 1e-10)
+        index_tensor = index_tensor / (
+            torch.norm(index_tensor, dim=1, keepdim=True) + 1e-10
+        )
+        queries_tensor = queries_tensor / (
+            torch.norm(queries_tensor, dim=1, keepdim=True) + 1e-10
+        )
     # Compute result on GPU
     result_tensor = index_tensor @ queries_tensor.T

@@ -56,7 +70,10 @@
 
     # Create memory-mapped file with .bin extension
     mmap_file = np.memmap(
-        os.path.join(output_path, f"{name}_attribution.npy"), dtype=np.float32, mode="w+", shape=result_shape
+        os.path.join(output_path, f"{name}_attribution.npy"),
+        dtype=np.float32,
+        mode="w+",
+        shape=result_shape,
     )
 
     # Copy GPU result directly to memmap

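The normalization branch above turns the raw inner products into cosine-style attribution scores: both gradient sets are unit-normalized (the 1e-10 guards against zero-norm rows) before the matmul. In miniature, with random stand-in tensors:

```python
import torch

index_tensor = torch.randn(5, 16)    # 5 training examples, 16-dim gradient features
queries_tensor = torch.randn(2, 16)  # 2 query examples

index_tensor = index_tensor / (torch.norm(index_tensor, dim=1, keepdim=True) + 1e-10)
queries_tensor = queries_tensor / (
    torch.norm(queries_tensor, dim=1, keepdim=True) + 1e-10
)

scores = index_tensor @ queries_tensor.T  # [5, 2] train-by-query attribution matrix
```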
bergson/hessians/collector.py
Lines changed: 39 additions & 12 deletions

@@ -263,8 +263,14 @@ def teardown(self) -> None:
         os.makedirs(gradient_path, exist_ok=True)
 
         # Save sharded covariance matrices
-        save_file(self.A_cov_dict, os.path.join(activation_path, f"shard_{self.rank}.safetensors"))
-        save_file(self.S_cov_dict, os.path.join(gradient_path, f"shard_{self.rank}.safetensors"))
+        save_file(
+            self.A_cov_dict,
+            os.path.join(activation_path, f"shard_{self.rank}.safetensors"),
+        )
+        save_file(
+            self.S_cov_dict,
+            os.path.join(gradient_path, f"shard_{self.rank}.safetensors"),
+        )
 
 
 @dataclass(kw_only=True)

@@ -286,11 +292,15 @@ def setup(self) -> None:
         """Load eigenvectors and initialize storage."""
         # Load precomputed eigenvectors
         self.eigen_a = load_file(
-            os.path.join(self.path, f"activation_eigen_sharded/shard_{self.rank}.safetensors"),
+            os.path.join(
+                self.path, f"activation_eigen_sharded/shard_{self.rank}.safetensors"
+            ),
             device=f"cuda:{self.rank}",
         )
         self.eigen_g = load_file(
-            os.path.join(self.path, f"gradient_eigen_sharded/shard_{self.rank}.safetensors"),
+            os.path.join(
+                self.path, f"gradient_eigen_sharded/shard_{self.rank}.safetensors"
+            ),
             device=f"cuda:{self.rank}",
         )
 

@@ -303,7 +313,9 @@ def forward_hook(self, name: str, a: Tensor) -> None:
         # a shape: [N, S, I]
 
         # Transform: a @ eigen_a
-        transformed = self.shard_computer._matmul(vector_nsa=a, matrix_cb=self.eigen_a[name])  # shape [N, S, I]
+        transformed = self.shard_computer._matmul(
+            vector_nsa=a, matrix_cb=self.eigen_a[name]
+        )  # shape [N, S, I]
 
         # Cache for use in backward pass
         self.transformed_a_cache[name] = transformed

@@ -313,11 +325,15 @@ def backward_hook(self, name: str, g: Tensor) -> None:
         # g shape: [N, S, O]
 
         # Transform: g @ eigen_g
-        transformed_g = self.shard_computer._matmul(vector_nsa=g, matrix_cb=self.eigen_g[name])  # shape [N, S, O]
+        transformed_g = self.shard_computer._matmul(
+            vector_nsa=g, matrix_cb=self.eigen_g[name]
+        )  # shape [N, S, O]
 
         # Compute outer product: sum_n (transformed_a_n^T @ transformed_g_n)
         # Einstein notation: [N, S, I] x [N, S, O] -> [N, O, I]
-        transformed_grad_shard = torch.einsum("N S I, N S O -> N O I", self.transformed_a_cache[name], transformed_g)
+        transformed_grad_shard = torch.einsum(
+            "N S I, N S O -> N O I", self.transformed_a_cache[name], transformed_g
+        )
 
         # Square and sum over batch
         transformed_grad_shard = (transformed_grad_shard**2).sum(dim=0).contiguous()

@@ -333,15 +349,26 @@ def backward_hook(self, name: str, g: Tensor) -> None:
 
         # Accumulate (with CPU offloading for memory efficiency)
         if name not in self.eigenvalue_corrections:
-            self.eigenvalue_corrections[name] = transformed_grad_shard[start_row:end_row, :].contiguous()
+            self.eigenvalue_corrections[name] = transformed_grad_shard[
+                start_row:end_row, :
+            ].contiguous()
         else:
-            self.eigenvalue_corrections[name] = self.eigenvalue_corrections[name].to(device=self.device)
-            self.eigenvalue_corrections[name].add_(transformed_grad_shard[start_row:end_row, :].contiguous())
-            self.eigenvalue_corrections[name] = self.eigenvalue_corrections[name].to(device="cpu", non_blocking=False)
+            self.eigenvalue_corrections[name] = self.eigenvalue_corrections[name].to(
+                device=self.device
+            )
+            self.eigenvalue_corrections[name].add_(
+                transformed_grad_shard[start_row:end_row, :].contiguous()
+            )
+            self.eigenvalue_corrections[name] = self.eigenvalue_corrections[name].to(
+                device="cpu", non_blocking=False
+            )
 
     def teardown(self) -> None:
         """Save eigenvalue corrections to disk."""
         output_path = os.path.join(self.path, "eigenvalue_correction_sharded")
         os.makedirs(output_path, exist_ok=True)
 
-        save_file(self.eigenvalue_corrections, os.path.join(output_path, f"shard_{self.rank}.safetensors"))
+        save_file(
+            self.eigenvalue_corrections,
+            os.path.join(output_path, f"shard_{self.rank}.safetensors"),
+        )

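The `backward_hook` logic above computes the eigenvalue correction: per-example outer products of the eigenbasis-transformed activations and output gradients, squared, then summed over the batch. A self-contained miniature of that einsum, with illustrative shapes:

```python
import torch

# N examples, S sequence positions, I input features, O output features.
N, S, I, O = 4, 7, 5, 3
a = torch.randn(N, S, I)  # activations, already rotated into the eigenbasis
g = torch.randn(N, S, O)  # output gradients, already rotated into the eigenbasis

per_example = torch.einsum("N S I, N S O -> N O I", a, g)  # [N, O, I]
correction = (per_example**2).sum(dim=0)  # [O, I], second moments over the batch
assert correction.shape == (O, I)
```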