Skip to content

Commit c7cb117

Browse files
committed
Reorder cells to be more convenient as a notebook
All covariance-related cells are next to each other, same for all eigendecomposition-related code.
1 parent 5e33a30 commit c7cb117

File tree

1 file changed

+99
-99
lines changed

1 file changed

+99
-99
lines changed

tests/ekfac_tests/compute_ekfac_ground_truth.py

Lines changed: 99 additions & 99 deletions
Original file line numberDiff line numberDiff line change
@@ -130,105 +130,6 @@ def allocate_batches_test(doc_lengths: list[int], N: int, workers: Optional[int]
130130
return allocation
131131

132132

133-
# %%
def compute_covariance(
    rank: int,
    model: PreTrainedModel,
    data: Dataset,
    batches_world: list[list[list[int]]],
    device: torch.device,
    target_modules: Any,
    activation_covariances: dict[str, Tensor],
    gradient_covariances: dict[str, Tensor],
) -> dict[str, Any]:
    """Compute activation and gradient covariances for a single worker.

    Runs the worker's share of batches through ``model`` under a
    ``GroundTruthCovarianceCollector``, which accumulates per-module
    activation and gradient covariances into the dicts passed in
    (mutated in place). Returns the per-sequence losses and the number
    of token positions processed by this rank.
    """
    processed_tokens = 0
    worker_batches = batches_world[rank]
    per_batch_losses = []

    collector = GroundTruthCovarianceCollector(
        model=model.base_model,
        activation_covariances=activation_covariances,
        gradient_covariances=gradient_covariances,
        target_modules=target_modules,
    )

    for indices in tqdm(worker_batches, desc=f"Rank {rank} covariances"):
        rows = data[indices]
        inputs, targets = pad_and_tensor(
            rows["input_ids"],
            labels=rows.get("labels"),
            device=device,
        )

        processed_tokens += inputs.numel()

        # NOTE(review): backward() must run inside the collector context so its
        # backward hooks fire; original indentation was lost in the diff scrape —
        # confirm against the repository.
        with collector:
            logits = model(inputs).logits
            # Next-token loss: logits at position t predict the token at t+1.
            token_losses = F.cross_entropy(
                logits[:, :-1].reshape(-1, logits.size(-1)),
                targets[:, 1:].flatten(),
                reduction="none",
            ).reshape_as(targets[:, 1:])

            # Sum over sequence positions, then average over the batch.
            token_losses = token_losses.sum(1)
            token_losses.mean().backward()
        per_batch_losses.append(token_losses.detach().cpu())
        model.zero_grad()

    return {"losses": per_batch_losses, "total_processed_rank": processed_tokens}
180-
181-
182-
# %%
def compute_eigenvalue_correction_amortized(
    rank: int,
    model: PreTrainedModel,
    data: Dataset,
    batches_world: list[list[list[int]]],
    device: torch.device,
    target_modules: Any,
    eigenvalue_corrections: dict[str, Tensor],
    eigenvectors_activations: dict[str, Tensor],
    eigenvectors_gradients: dict[str, Tensor],
) -> dict[str, int]:
    """Compute eigenvalue corrections using the amortized method.

    Runs the worker's share of batches under a
    ``GroundTruthAmortizedLambdaCollector`` that projects activations and
    gradients onto the supplied eigenvectors and accumulates the
    eigenvalue corrections in place. Returns the number of token
    positions processed by this rank.
    """
    processed_tokens = 0
    worker_batches = batches_world[rank]

    collector = GroundTruthAmortizedLambdaCollector(
        model=model.base_model,
        eigenvalue_corrections=eigenvalue_corrections,
        eigenvectors_activations=eigenvectors_activations,
        eigenvectors_gradients=eigenvectors_gradients,
        device=device,
        target_modules=target_modules,
    )

    for indices in tqdm(worker_batches, desc=f"Rank {rank} eigenvalue corrections"):
        rows = data[indices]
        inputs, targets = pad_and_tensor(
            rows["input_ids"],
            labels=rows.get("labels"),
            device=device,
        )

        processed_tokens += inputs.numel()

        # NOTE(review): backward() must run inside the collector context so its
        # backward hooks fire; original indentation was lost in the diff scrape —
        # confirm against the repository.
        with collector:
            logits = model(inputs).logits
            # Next-token loss: logits at position t predict the token at t+1.
            token_losses = F.cross_entropy(
                logits[:, :-1].reshape(-1, logits.size(-1)),
                targets[:, 1:].flatten(),
                reduction="none",
            ).reshape_as(targets[:, 1:])

            # Sum over sequence positions, then average over the batch.
            token_losses = token_losses.sum(1)
            token_losses.mean().backward()
        model.zero_grad()

    return {"total_processed_rank": processed_tokens}
230-
231-
232133
# %% [markdown]
233134
# ## 0. Hyperparameters
234135

@@ -418,6 +319,55 @@ def tokenize_and_allocate_step(
418319
# ## 2. Compute activation and gradient covariance
419320

420321

322+
# %%
def compute_covariance(
    rank: int,
    model: PreTrainedModel,
    data: Dataset,
    batches_world: list[list[list[int]]],
    device: torch.device,
    target_modules: Any,
    activation_covariances: dict[str, Tensor],
    gradient_covariances: dict[str, Tensor],
) -> dict[str, Any]:
    """Compute activation and gradient covariances for a single worker.

    Iterates over this rank's batch allocation, and for each batch does a
    forward/backward pass while a ``GroundTruthCovarianceCollector`` hooks
    the targeted modules and accumulates covariances into the provided
    dicts (mutated in place). Returns per-sequence losses plus the total
    number of token positions this rank processed.
    """
    total_tokens = 0
    my_batches = batches_world[rank]
    collected_losses: list[Tensor] = []

    collector = GroundTruthCovarianceCollector(
        model=model.base_model,
        activation_covariances=activation_covariances,
        gradient_covariances=gradient_covariances,
        target_modules=target_modules,
    )

    for batch_indices in tqdm(my_batches, desc=f"Rank {rank} covariances"):
        examples = data[batch_indices]
        input_ids, labels = pad_and_tensor(
            examples["input_ids"],
            labels=examples.get("labels"),
            device=device,
        )

        total_tokens += input_ids.numel()

        # NOTE(review): the backward pass stays inside the collector context so
        # its hooks observe the gradients; exact original indentation was lost
        # in the diff scrape — confirm against the repository.
        with collector:
            logits = model(input_ids).logits
            # Shifted cross-entropy: position t predicts token t+1.
            seq_losses = F.cross_entropy(
                logits[:, :-1].reshape(-1, logits.size(-1)),
                labels[:, 1:].flatten(),
                reduction="none",
            ).reshape_as(labels[:, 1:])

            # Per-sequence total loss, averaged across the batch for backward.
            seq_losses = seq_losses.sum(1)
            seq_losses.mean().backward()
        collected_losses.append(seq_losses.detach().cpu())
        model.zero_grad()

    return {"losses": collected_losses, "total_processed_rank": total_tokens}
369+
370+
421371
# %%
422372
def compute_covariances_step(
423373
model: PreTrainedModel,
@@ -567,6 +517,56 @@ def compute_eigenvectors_step(test_path: str, device: torch.device, dtype: torch
567517
# ## 4. Compute eigenvalue correction
568518

569519

520+
# %%
def compute_eigenvalue_correction_amortized(
    rank: int,
    model: PreTrainedModel,
    data: Dataset,
    batches_world: list[list[list[int]]],
    device: torch.device,
    target_modules: Any,
    eigenvalue_corrections: dict[str, Tensor],
    eigenvectors_activations: dict[str, Tensor],
    eigenvectors_gradients: dict[str, Tensor],
) -> dict[str, int]:
    """Compute eigenvalue corrections using the amortized method.

    Iterates over this rank's batch allocation doing forward/backward
    passes while a ``GroundTruthAmortizedLambdaCollector`` projects the
    hooked activations/gradients onto the given eigenvector bases and
    accumulates corrections in place. Returns the total number of token
    positions this rank processed.
    """
    total_tokens = 0
    my_batches = batches_world[rank]

    collector = GroundTruthAmortizedLambdaCollector(
        model=model.base_model,
        eigenvalue_corrections=eigenvalue_corrections,
        eigenvectors_activations=eigenvectors_activations,
        eigenvectors_gradients=eigenvectors_gradients,
        device=device,
        target_modules=target_modules,
    )

    for batch_indices in tqdm(my_batches, desc=f"Rank {rank} eigenvalue corrections"):
        examples = data[batch_indices]
        input_ids, labels = pad_and_tensor(
            examples["input_ids"],
            labels=examples.get("labels"),
            device=device,
        )

        total_tokens += input_ids.numel()

        # NOTE(review): the backward pass stays inside the collector context so
        # its hooks observe the gradients; exact original indentation was lost
        # in the diff scrape — confirm against the repository.
        with collector:
            logits = model(input_ids).logits
            # Shifted cross-entropy: position t predicts token t+1.
            seq_losses = F.cross_entropy(
                logits[:, :-1].reshape(-1, logits.size(-1)),
                labels[:, 1:].flatten(),
                reduction="none",
            ).reshape_as(labels[:, 1:])

            # Per-sequence total loss, averaged across the batch for backward.
            seq_losses = seq_losses.sum(1)
            seq_losses.mean().backward()
        model.zero_grad()

    return {"total_processed_rank": total_tokens}
568+
569+
570570
# %%
571571
def compute_eigenvalue_corrections_step(
572572
model: PreTrainedModel,

0 commit comments

Comments
 (0)