
Commit 7e350e6

use torch dtype + change path in run_test_compute_ekfac.sh to make the test run
1 parent 376c7fd commit 7e350e6

2 files changed (+131, -41 lines)


tests/ekfac_tests/compute_ekfac_ground_truth.py

Lines changed: 130 additions & 38 deletions
@@ -31,7 +31,12 @@
 from test_utils import set_all_seeds
 from torch import Tensor
 from tqdm import tqdm
-from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, PreTrainedModel
+from transformers import (
+    AutoModelForCausalLM,
+    AutoTokenizer,
+    BitsAndBytesConfig,
+    PreTrainedModel,
+)
 
 from bergson.data import DataConfig, IndexConfig, Precision, pad_and_tensor, tokenize
 from bergson.hessians.utils import TensorDict
@@ -44,7 +49,9 @@
 
 
 # %%
-def allocate_batches_test(doc_lengths: list[int], N: int, workers: Optional[int] = None) -> Batches:
+def allocate_batches_test(
+    doc_lengths: list[int], N: int, workers: Optional[int] = None
+) -> Batches:
     """
     Modification of allocate_batches to return a flat list of batches for testing.
 
@@ -103,7 +110,9 @@ def allocate_batches_test(doc_lengths: list[int], N: int, workers: Optional[int]
     while len(batches) < world_size:
         big = batches.pop(0)
         if len(big) == 1:
-            raise RuntimeError("Not enough documents to give each worker at least one batch.")
+            raise RuntimeError(
+                "Not enough documents to give each worker at least one batch."
+            )
         batches.append([big.pop()])
         batches.append(big)
 
@@ -121,7 +130,9 @@ def allocate_batches_test(doc_lengths: list[int], N: int, workers: Optional[int]
         i += 1
 
     assert len(batches) == target_batches
-    assert all(max(doc_lengths[i] for i in batch) * len(batch) <= N for batch in batches)
+    assert all(
+        max(doc_lengths[i] for i in batch) * len(batch) <= N for batch in batches
+    )
 
     # Round-robin assignment to workers
     allocation: Batches = [[] for _ in range(world_size)]
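
The reformatted assertion above is the batching invariant worth spelling out: every batch is padded to its longest document, so its padded token count is the batch size times that maximum length, and this product must stay within the per-batch token budget N. A minimal sketch of the bookkeeping, with hypothetical names not taken from this file:

    def padded_tokens(batch: list[int], doc_lengths: list[int]) -> int:
        # Each document is padded to the longest one in its batch, so the
        # padded size is the batch size times the maximum document length.
        return max(doc_lengths[i] for i in batch) * len(batch)

    doc_lengths = [10, 7, 3]  # hypothetical document lengths
    batch = [0, 2]            # indices into doc_lengths
    assert padded_tokens(batch, doc_lengths) == 20  # two docs padded to length 10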
@@ -143,7 +154,9 @@ def parse_config() -> tuple[Precision, Optional[str]]:
     output_dir: Optional[str]
 
     if len(sys.argv) > 1 and not hasattr(builtins, "__IPYTHON__"):
-        parser = argparse.ArgumentParser(description="Compute EKFAC ground truth for testing")
+        parser = argparse.ArgumentParser(
+            description="Compute EKFAC ground truth for testing"
+        )
         parser.add_argument(
             "--precision",
             type=str,
@@ -233,7 +246,9 @@ def setup_paths_and_config(
 
 
 if __name__ == "__main__" or TYPE_CHECKING:
-    cfg, test_path, workers, device, target_modules, dtype = setup_paths_and_config(precision, output_dir)
+    cfg, test_path, workers, device, target_modules, dtype = setup_paths_and_config(
+        precision, output_dir
+    )
 
 
 # %% [markdown]
@@ -259,7 +274,7 @@ def load_model_step(cfg: IndexConfig, dtype: torch.dtype) -> PreTrainedModel:
             if cfg.precision in ("int4", "int8")
             else None
         ),
-        torch_dtype=dtype,
+        dtype=dtype,
     )
     return model
 
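This one-word change is the substantive fix named in the commit message: recent transformers releases renamed the from_pretrained keyword torch_dtype to dtype and deprecate the old spelling. A minimal sketch under that assumption, with a hypothetical model id standing in for cfg.model:

    import torch
    from transformers import AutoModelForCausalLM

    model = AutoModelForCausalLM.from_pretrained(
        "gpt2",                # hypothetical id; the script passes cfg.model
        dtype=torch.bfloat16,  # spelled torch_dtype=... before the rename
    )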
@@ -282,7 +297,9 @@ def load_dataset_step(cfg: IndexConfig) -> Dataset:
     try:
         ds = load_dataset(data_str, split="train")
         if isinstance(ds, (DatasetDict, IterableDatasetDict)):
-            raise NotImplementedError("DatasetDicts and IterableDatasetDicts are not supported.")
+            raise NotImplementedError(
+                "DatasetDicts and IterableDatasetDicts are not supported."
+            )
     except ValueError as e:
         if "load_from_disk" in str(e):
             ds = Dataset.load_from_disk(data_str, keep_in_memory=False)
@@ -302,12 +319,18 @@ def tokenize_and_allocate_step(
     ds: Dataset, cfg: IndexConfig, workers: int
 ) -> tuple[Dataset, Batches, Any]:
     """Tokenize dataset and allocate batches."""
-    tokenizer = AutoTokenizer.from_pretrained(cfg.model, model_max_length=cfg.token_batch_size)
-    ds = ds.map(tokenize, batched=True, fn_kwargs=dict(args=cfg.data, tokenizer=tokenizer))
+    tokenizer = AutoTokenizer.from_pretrained(
+        cfg.model, model_max_length=cfg.token_batch_size
+    )
+    ds = ds.map(
+        tokenize, batched=True, fn_kwargs=dict(args=cfg.data, tokenizer=tokenizer)
+    )
     data = ds
 
     # Allocate batches
-    batches_world = allocate_batches_test(doc_lengths=ds["length"], N=cfg.token_batch_size, workers=workers)
+    batches_world = allocate_batches_test(
+        doc_lengths=ds["length"], N=cfg.token_batch_size, workers=workers
+    )
     assert len(batches_world) == workers
 
     return data, batches_world, tokenizer
@@ -400,8 +423,16 @@ def compute_covariances_step(
             gradient_covariances=gradient_covariances,
         )
 
-        save_file(activation_covariances, os.path.join(covariance_test_path_rank, "activation_covariance.safetensors"))
-        save_file(gradient_covariances, os.path.join(covariance_test_path_rank, "gradient_covariance.safetensors"))
+        save_file(
+            activation_covariances,
+            os.path.join(
+                covariance_test_path_rank, "activation_covariance.safetensors"
+            ),
+        )
+        save_file(
+            gradient_covariances,
+            os.path.join(covariance_test_path_rank, "gradient_covariance.safetensors"),
+        )
         with open(os.path.join(covariance_test_path_rank, "stats.json"), "w") as f:
             json.dump({"total_processed_rank": d["total_processed_rank"]}, f, indent=4)
         print(f"Rank {rank} processed {d['total_processed_rank']} tokens.")
@@ -417,7 +448,9 @@ def compute_covariances_step(
 
 
 # %%
-def combine_covariances_step(covariance_test_path: str, workers: int, device: torch.device) -> int:
+def combine_covariances_step(
+    covariance_test_path: str, workers: int, device: torch.device
+) -> int:
     """Combine covariance results from all ranks."""
     activation_covariances = TensorDict({})
     gradient_covariances = TensorDict({})
@@ -431,25 +464,41 @@ def combine_covariances_step(covariance_test_path: str, workers: int, device: to
         total_processed_global += d["total_processed_rank"]
 
         activation_covariances_rank = TensorDict(
-            load_file(os.path.join(covariance_test_path_rank, "activation_covariance.safetensors"))
+            load_file(
+                os.path.join(
+                    covariance_test_path_rank, "activation_covariance.safetensors"
+                )
+            )
         ).to(device)
 
         gradient_covariances_rank = TensorDict(
-            load_file(os.path.join(covariance_test_path_rank, "gradient_covariance.safetensors"))
+            load_file(
+                os.path.join(
+                    covariance_test_path_rank, "gradient_covariance.safetensors"
+                )
+            )
         ).to(device)
 
         if not activation_covariances:
             activation_covariances = activation_covariances_rank
         else:
-            activation_covariances = activation_covariances + activation_covariances_rank
+            activation_covariances = (
+                activation_covariances + activation_covariances_rank
+            )
 
         if not gradient_covariances:
             gradient_covariances = gradient_covariances_rank
         else:
             gradient_covariances = gradient_covariances + gradient_covariances_rank
 
-    save_file(activation_covariances.to_dict(), os.path.join(covariance_test_path, "activation_covariance.safetensors"))
-    save_file(gradient_covariances.to_dict(), os.path.join(covariance_test_path, "gradient_covariance.safetensors"))
+    save_file(
+        activation_covariances.to_dict(),
+        os.path.join(covariance_test_path, "activation_covariance.safetensors"),
+    )
+    save_file(
+        gradient_covariances.to_dict(),
+        os.path.join(covariance_test_path, "gradient_covariance.safetensors"),
+    )
     with open(os.path.join(covariance_test_path, "stats.json"), "w") as f:
         json.dump({"total_processed_global": total_processed_global}, f, indent=4)
     print(f"Global processed {total_processed_global} tokens.")
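
Note the reduction this function performs: each rank holds an unnormalized sum of outer products, so the global factor is just the elementwise sum over ranks, and division by the token count is deferred to the eigendecomposition step. A toy sketch with made-up shapes:

    import torch

    # Hypothetical per-rank accumulators of sum_t a_t a_t^T
    a_rank0 = torch.randn(4, 4)
    a_rank1 = torch.randn(4, 4)

    a_global = a_rank0 + a_rank1  # combining ranks is a plain sum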
@@ -462,15 +511,19 @@ def combine_covariances_step(covariance_test_path: str, workers: int, device: to
 
 if __name__ == "__main__" or TYPE_CHECKING:
     print("\n=== Combining Covariances ===")
-    total_processed_global = combine_covariances_step(covariance_test_path, workers, device)
+    total_processed_global = combine_covariances_step(
+        covariance_test_path, workers, device
+    )
 
 
 # %% [markdown]
 # ## 3. Compute eigenvalues and eigenvectors
 
 
 # %%
-def compute_eigenvectors_step(test_path: str, device: torch.device, dtype: torch.dtype) -> str:
+def compute_eigenvectors_step(
+    test_path: str, device: torch.device, dtype: torch.dtype
+) -> str:
     """Compute eigenvectors from covariances."""
     covariance_test_path = os.path.join(test_path, "covariances")
     eigenvectors_test_path = os.path.join(test_path, "eigenvectors")
@@ -481,8 +534,12 @@ def compute_eigenvectors_step(test_path: str, device: torch.device, dtype: torch
         d = json.load(f)
         total_processed_global = d["total_processed_global"]
 
-    activation_covariances = load_file(os.path.join(covariance_test_path, "activation_covariance.safetensors"))
-    gradient_covariances = load_file(os.path.join(covariance_test_path, "gradient_covariance.safetensors"))
+    activation_covariances = load_file(
+        os.path.join(covariance_test_path, "activation_covariance.safetensors")
+    )
+    gradient_covariances = load_file(
+        os.path.join(covariance_test_path, "gradient_covariance.safetensors")
+    )
 
     eigenvectors_activations = {}
     eigenvectors_gradients = {}
@@ -497,12 +554,20 @@ def compute_eigenvectors_step(test_path: str, device: torch.device, dtype: torch
 
         eigenvalues_a, eigenvectors_a = torch.linalg.eigh(a)
         eigenvalues_g, eigenvectors_g = torch.linalg.eigh(g)
-        print(f"{name}: eigenvectors_a.sum()={eigenvectors_a.sum()}, eigenvectors_g.sum()={eigenvectors_g.sum()}")
+        print(
+            f"{name}: eigenvectors_a.sum()={eigenvectors_a.sum()}, eigenvectors_g.sum()={eigenvectors_g.sum()}"
+        )
         eigenvectors_activations[name] = eigenvectors_a.to(dtype=dtype).contiguous()
         eigenvectors_gradients[name] = eigenvectors_g.to(dtype=dtype).contiguous()
 
-    save_file(eigenvectors_activations, os.path.join(eigenvectors_test_path, "eigenvectors_activations.safetensors"))
-    save_file(eigenvectors_gradients, os.path.join(eigenvectors_test_path, "eigenvectors_gradients.safetensors"))
+    save_file(
+        eigenvectors_activations,
+        os.path.join(eigenvectors_test_path, "eigenvectors_activations.safetensors"),
+    )
+    save_file(
+        eigenvectors_gradients,
+        os.path.join(eigenvectors_test_path, "eigenvectors_gradients.safetensors"),
+    )
 
     gc.collect()
     torch.cuda.empty_cache()
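
torch.linalg.eigh applies here because both covariance factors are symmetric positive semi-definite: it returns a real, ascending spectrum and an orthonormal eigenbasis Q with A ≈ Q diag(λ) Qᵀ. A self-contained sketch with made-up shapes:

    import torch

    x = torch.randn(128, 16)
    a = x.T @ x / x.shape[0]               # symmetric PSD covariance
    eigenvalues, q = torch.linalg.eigh(a)  # ascending eigenvalues, orthonormal Q
    assert torch.allclose(q @ torch.diag(eigenvalues) @ q.T, a, atol=1e-4)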
@@ -585,12 +650,18 @@ def compute_eigenvalue_corrections_step(
     os.makedirs(eigenvalue_correction_test_path, exist_ok=True)
 
     # Load eigenvectors
-    eigenvectors_activations = load_file(os.path.join(eigenvectors_test_path, "eigenvectors_activations.safetensors"))
-    eigenvectors_gradients = load_file(os.path.join(eigenvectors_test_path, "eigenvectors_gradients.safetensors"))
+    eigenvectors_activations = load_file(
+        os.path.join(eigenvectors_test_path, "eigenvectors_activations.safetensors")
+    )
+    eigenvectors_gradients = load_file(
+        os.path.join(eigenvectors_test_path, "eigenvectors_gradients.safetensors")
+    )
 
     total_processed_global = 0
     for rank in range(workers):
-        eigenvalue_correction_test_path_rank = os.path.join(eigenvalue_correction_test_path, f"rank_{rank}")
+        eigenvalue_correction_test_path_rank = os.path.join(
+            eigenvalue_correction_test_path, f"rank_{rank}"
+        )
         os.makedirs(eigenvalue_correction_test_path_rank, exist_ok=True)
 
         eigenvalue_corrections = {}
@@ -608,9 +679,14 @@ def compute_eigenvalue_corrections_step(
 
         save_file(
             eigenvalue_corrections,
-            os.path.join(eigenvalue_correction_test_path_rank, "eigenvalue_corrections.safetensors"),
+            os.path.join(
+                eigenvalue_correction_test_path_rank,
+                "eigenvalue_corrections.safetensors",
+            ),
         )
-        with open(os.path.join(eigenvalue_correction_test_path_rank, "stats.json"), "w") as f:
+        with open(
+            os.path.join(eigenvalue_correction_test_path_rank, "stats.json"), "w"
+        ) as f:
             json.dump({"total_processed_rank": d["total_processed_rank"]}, f, indent=4)
         print(f"Rank {rank} processed {d['total_processed_rank']} tokens.")
         total_processed_global += d["total_processed_rank"]
@@ -620,34 +696,50 @@ def compute_eigenvalue_corrections_step(
 
 if __name__ == "__main__" or TYPE_CHECKING:
     print("\n=== Computing Eigenvalue Corrections ===")
-    eigenvalue_correction_test_path, total_processed_global_lambda = compute_eigenvalue_corrections_step(
-        model, data, batches_world, device, target_modules, workers, test_path
+    eigenvalue_correction_test_path, total_processed_global_lambda = (
+        compute_eigenvalue_corrections_step(
+            model, data, batches_world, device, target_modules, workers, test_path
+        )
     )
 
 
 # %%
 def combine_eigenvalue_corrections_step(
-    eigenvalue_correction_test_path: str, workers: int, device: torch.device, total_processed_global: int
+    eigenvalue_correction_test_path: str,
+    workers: int,
+    device: torch.device,
+    total_processed_global: int,
 ) -> TensorDict:
     """Combine eigenvalue correction results from all ranks."""
     eigenvalue_corrections = TensorDict({})
 
     for rank in range(workers):
-        eigenvalue_correction_test_path_rank = os.path.join(eigenvalue_correction_test_path, f"rank_{rank}")
+        eigenvalue_correction_test_path_rank = os.path.join(
+            eigenvalue_correction_test_path, f"rank_{rank}"
+        )
 
         eigenvalue_corrections_rank = TensorDict(
-            load_file(os.path.join(eigenvalue_correction_test_path_rank, "eigenvalue_corrections.safetensors"))
+            load_file(
+                os.path.join(
+                    eigenvalue_correction_test_path_rank,
+                    "eigenvalue_corrections.safetensors",
+                )
+            )
         ).to(device)
 
         if not eigenvalue_corrections:
             eigenvalue_corrections = eigenvalue_corrections_rank
         else:
-            eigenvalue_corrections = eigenvalue_corrections + eigenvalue_corrections_rank
+            eigenvalue_corrections = (
+                eigenvalue_corrections + eigenvalue_corrections_rank
+            )
 
     eigenvalue_corrections.div_(total_processed_global)
     save_file(
         eigenvalue_corrections.to_dict(),
-        os.path.join(eigenvalue_correction_test_path, "eigenvalue_corrections.safetensors"),
+        os.path.join(
+            eigenvalue_correction_test_path, "eigenvalue_corrections.safetensors"
+        ),
     )
 
     return eigenvalue_corrections
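
Assuming the per-rank corrections are, as in standard EKFAC, unnormalized sums of squared gradients projected into the eigenbasis, combining them mirrors the covariance step: sum over ranks, then divide by the global token count, which is what eigenvalue_corrections.div_(total_processed_global) does above. A toy sketch with hypothetical names and shapes:

    import torch

    corrections_per_rank = [torch.rand(8, 8) for _ in range(4)]  # hypothetical
    total_processed_global = 1_000
    lambdas = sum(corrections_per_rank) / total_processed_global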

tests/ekfac_tests/run_test_compute_ekfac.sh

Lines changed: 1 addition & 3 deletions
@@ -2,9 +2,7 @@
 
 # Run all tests
 python test_compute_ekfac.py \
-    --test_dir "./test_files/pile_100_examples" \
+    --test_dir "/root/bergson/test_files/pile_100_examples" \
     --world_size 8 \
     --use_fsdp \
     --overwrite
-
-
