diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml new file mode 100644 index 0000000..2fdfe3e --- /dev/null +++ b/.github/workflows/build.yml @@ -0,0 +1,39 @@ +name: build + +on: + push: + branches: + - ekfac + pull_request: + branches: + - ekfac +jobs: + build: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-python@v5 + with: + python-version: "3.10" + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install -e ".[dev,faiss]" + # TODO: Proper test infrastructure for tests/ekfac_tests + # - name: Run tests + # run: pytest + # TODO: run pyright on whole codebase + - name: Type Checking bergson/hessians + uses: jakebailey/pyright-action@v1 + with: + version: 1.1.406 + working-directory: bergson/hessians + - name: Type Checking tests/ekfac_tests + uses: jakebailey/pyright-action@v1 + with: + version: 1.1.406 + working-directory: tests/ekfac_tests + - name: build + run: pip wheel --no-deps -w dist . +env: + HF_HUB_DOWNLOAD_TIMEOUT: 100 diff --git a/bergson/collection.py b/bergson/collection.py index 2d1bb54..aaa4235 100644 --- a/bergson/collection.py +++ b/bergson/collection.py @@ -72,7 +72,9 @@ def callback(name: str, g: torch.Tensor): grad_sizes = {name: math.prod(s) for name, s in collector.shapes().items()} # Allocate structured space ahead of time for the gradients - grad_buffer = create_index(cfg.run_path, num_grads=len(data), grad_sizes=grad_sizes, dtype=np.float16) + grad_buffer = create_index( + cfg.run_path, num_grads=len(data), grad_sizes=grad_sizes, dtype=np.float16 + ) per_doc_losses = torch.full( (len(data),), diff --git a/bergson/data.py b/bergson/data.py index 808f11e..f41f998 100644 --- a/bergson/data.py +++ b/bergson/data.py @@ -16,6 +16,8 @@ from .utils import assert_type +Precision = Literal["bf16", "fp16", "fp32", "int4", "int8"] + @dataclass class DataConfig: @@ -48,7 +50,7 @@ class IndexConfig: fsdp: bool = False """Whether to use Fully Sharded Data Parallel (FSDP) for collecing gradients.""" - precision: Literal["bf16", "fp16", "fp32", "int4", "int8"] = "bf16" + precision: Precision = "bf16" """Precision to use for the model parameters.""" projection_dim: int = 16 @@ -99,7 +101,9 @@ def ceildiv(a: int, b: int) -> int: return -(-a // b) # Equivalent to math.ceil(a / b) but faster for integers -def allocate_batches(doc_lengths: list[int], N: int, world_size: Optional[int] = None) -> list[list[int]]: +def allocate_batches( + doc_lengths: list[int], N: int, world_size: Optional[int] = None +) -> list[list[int]]: """ Allocate documents into batches that are then distributed evenly across a fixed number of workers. @@ -183,7 +187,9 @@ def allocate_batches(doc_lengths: list[int], N: int, world_size: Optional[int] = while len(batches) < world_size: big = batches.pop(0) # take the current largest if len(big) == 1: # cannot split a singleton - raise RuntimeError("Not enough documents to give each worker at least one batch.") + raise RuntimeError( + "Not enough documents to give each worker at least one batch." 
+ ) batches.append([big.pop()]) # move one doc into new batch batches.append(big) # put the remainder back # preserve cost constraint automatically @@ -205,7 +211,9 @@ def allocate_batches(doc_lengths: list[int], N: int, world_size: Optional[int] = i += 1 assert len(batches) == target_batches - assert all(max(doc_lengths[i] for i in batch) * len(batch) <= N for batch in batches) + assert all( + max(doc_lengths[i] for i in batch) * len(batch) <= N for batch in batches + ) # --------------------------------------------------------------------- # 4) Round-robin assignment to workers @@ -219,7 +227,9 @@ def allocate_batches(doc_lengths: list[int], N: int, world_size: Optional[int] = return allocation[rank] -def create_index(root: str, num_grads: int, grad_sizes: dict[str, int], dtype: DTypeLike) -> np.memmap: +def create_index( + root: str, num_grads: int, grad_sizes: dict[str, int], dtype: DTypeLike +) -> np.memmap: """Create a memory-mapped file for storing structured gradients and persist metadata.""" grad_path = os.path.join(root, "gradients.bin") @@ -310,7 +320,9 @@ def load_shard(dir: str) -> Dataset: if concatenate_gradients: unstructured_data = structured_to_unstructured(mmap) flat = pa.array(unstructured_data.reshape(-1)) - col_arrow = pa.FixedSizeListArray.from_arrays(flat, unstructured_data.shape[1]) + col_arrow = pa.FixedSizeListArray.from_arrays( + flat, unstructured_data.shape[1] + ) ds = ds.add_column("gradients", col_arrow, new_fingerprint="gradients") # Add a column for each module's gradient vectors @@ -374,7 +386,9 @@ def tokenize(batch: dict, *, args: DataConfig, tokenizer): {"role": "user", "content": assert_type(str, prompt)}, {"role": "assistant", "content": assert_type(str, resp)}, ] - for prompt, resp in zip(batch[args.prompt_column], batch[args.completion_column]) + for prompt, resp in zip( + batch[args.prompt_column], batch[args.completion_column] + ) ] elif args.conversation_column: # We're dealing with a conversation dataset @@ -421,4 +435,7 @@ def tokenize(batch: dict, *, args: DataConfig, tokenizer): def unflatten(x: torch.Tensor, shapes: dict[str, Sequence[int]], dim: int = -1): """Unflatten a tensor `x` into a dictionary of tensors with specified shapes.""" numels = [math.prod(shape) for shape in shapes.values()] - return {name: x.unflatten(dim, shape) for (name, shape), x in zip(shapes.items(), x.split(numels, dim=dim))} + return { + name: x.unflatten(dim, shape) + for (name, shape), x in zip(shapes.items(), x.split(numels, dim=dim)) + } diff --git a/bergson/distributed.py b/bergson/distributed.py index 2d9aeb2..60fb978 100644 --- a/bergson/distributed.py +++ b/bergson/distributed.py @@ -64,7 +64,9 @@ def setup_data_pipeline(cfg: IndexConfig) -> Dataset | IterableDataset: ds = load_dataset(data_str, split="train") if isinstance(ds, DatasetDict) or isinstance(ds, IterableDatasetDict): - raise NotImplementedError("DatasetDicts and IterableDatasetDicts are not supported.") + raise NotImplementedError( + "DatasetDicts and IterableDatasetDicts are not supported." 
+ ) except ValueError as e: # Automatically use load_from_disk if appropriate if "load_from_disk" in str(e): @@ -72,14 +74,20 @@ def setup_data_pipeline(cfg: IndexConfig) -> Dataset | IterableDataset: else: raise e - tokenizer = AutoTokenizer.from_pretrained(cfg.model, model_max_length=cfg.token_batch_size) + tokenizer = AutoTokenizer.from_pretrained( + cfg.model, model_max_length=cfg.token_batch_size + ) - ds = ds.map(tokenize, batched=True, fn_kwargs=dict(args=cfg.data, tokenizer=tokenizer)) + ds = ds.map( + tokenize, batched=True, fn_kwargs=dict(args=cfg.data, tokenizer=tokenizer) + ) return ds -def setup_model_and_peft(cfg: IndexConfig, rank: int, dtype: torch.dtype) -> tuple[AutoModelForCausalLM, set | None]: +def setup_model_and_peft( + cfg: IndexConfig, rank: int, dtype: torch.dtype +) -> tuple[AutoModelForCausalLM, set | None]: """Handle model loading, quantization, FSDP, and PEFT detection""" torch.manual_seed(42) @@ -141,7 +149,9 @@ def setup_model_and_peft(cfg: IndexConfig, rank: int, dtype: torch.dtype) -> tup model.get_submodule(processed_name) target_modules.add(processed_name) except AttributeError: - print(f"Adapter parameter '{processed_name}' not found in the model.") + print( + f"Adapter parameter '{processed_name}' not found in the model." + ) # Configure gradients model.requires_grad_(False) @@ -223,7 +233,11 @@ def worker_wrapper( case "fp32": dtype = torch.float32 case "int4" | "int8": - dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16 + dtype = ( + torch.bfloat16 + if torch.cuda.is_bf16_supported() + else torch.float16 + ) case other: raise ValueError(f"Unsupported precision: {other}") @@ -305,7 +319,10 @@ def distributed_computing( ctx = start_processes( "build", worker_wrapper, - args={i: (i, world_size, cfg, ds, worker_fn, setup_model, setup_processor) for i in range(world_size)}, + args={ + i: (i, world_size, cfg, ds, worker_fn, setup_model, setup_processor) + for i in range(world_size) + }, envs={ i: { "LOCAL_RANK": str(i), diff --git a/bergson/gradients.py b/bergson/gradients.py index e88a7af..d7ebd7b 100644 --- a/bergson/gradients.py +++ b/bergson/gradients.py @@ -162,7 +162,9 @@ def to_adafactor(self) -> AdafactorNormalizer: and the factored second moments. 
""" # We assume avg_sq is a square matrix of shape [O, I] - assert self.avg_sq.ndim == 2, f"Expected 2D tensor for avg_sq, got {self.avg_sq.ndim}D" + assert ( + self.avg_sq.ndim == 2 + ), f"Expected 2D tensor for avg_sq, got {self.avg_sq.ndim}D" # Compute row and column means return AdafactorNormalizer( @@ -213,9 +215,6 @@ def save(self, path: str): json.dump(cfg, f, indent=2) - - - @dataclass class GradientCollector(ContextDecorator): """ @@ -346,7 +345,12 @@ def _save_input(self, module: nn.Module, inp: tuple, _): if p is not None and not isinstance(norm, AdamNormalizer): i = module.in_features - x = x @ self.projection(name=name, m=p, n=i, side="right", dtype=x.dtype, device=x.device).T + x = ( + x + @ self.projection( + name=name, m=p, n=i, side="right", dtype=x.dtype, device=x.device + ).T + ) module._inputs = x @@ -387,14 +391,20 @@ def _process_grad(self, module: nn.Module, _, grad_out): # Project the gradients to the lower-dimensional space if p is not None: - A = self.projection(name=name, m=p, n=o, side="left", dtype=G.dtype, device=G.device) - B = self.projection(name=name, m=p, n=i, side="right", dtype=G.dtype, device=G.device) + A = self.projection( + name=name, m=p, n=o, side="left", dtype=G.dtype, device=G.device + ) + B = self.projection( + name=name, m=p, n=i, side="right", dtype=G.dtype, device=G.device + ) P = A @ P @ B.T # [N, p, q] # Both Adafactor and no normalizer, we can project G first else: if p is not None: - A = self.projection(name=name, m=p, n=o, side="left", dtype=G.dtype, device=G.device) + A = self.projection( + name=name, m=p, n=o, side="left", dtype=G.dtype, device=G.device + ) G = G @ A.T # [N, S, p] P = G.mT @ I # [N, O/p, S] @ [N, S, I/q] → [N, O/p, I/q] diff --git a/bergson/hessians/attribute.py b/bergson/hessians/attribute.py index 096c9eb..dcd53a2 100644 --- a/bergson/hessians/attribute.py +++ b/bergson/hessians/attribute.py @@ -13,14 +13,20 @@ # ## 1. Load index for query and train data parser = argparse.ArgumentParser(description="Process normalization flag.") -parser.add_argument("--normalize", action="store_true", help="Gradients will be unit normalized.") +parser.add_argument( + "--normalize", action="store_true", help="Gradients will be unit normalized." 
+) args = parser.parse_args() device = "cuda:1" # %% -base_path = "/mnt/ssd-1/gpaulo/emergent-misalignment/emergent-misalignment-eleuther/data/" -index_dataset = load_dataset("json", data_files=f"{base_path}merged-medical-reformatted.jsonl")["train"] +base_path = ( + "/mnt/ssd-1/gpaulo/emergent-misalignment/emergent-misalignment-eleuther/data/" +) +index_dataset = load_dataset( + "json", data_files=f"{base_path}merged-medical-reformatted.jsonl" +)["train"] index_path = "/mnt/ssd-1/gpaulo/emergent-misalignment/qwen14_merged_medical_proj16/merged_medical_no_normalizer" queries_path = "/mnt/ssd-1/louis/emergent_misalignment/test_query_ekfac" @@ -37,17 +43,25 @@ normalize = args.normalize attribution_dict = {} -output_path = "/mnt/ssd-1/louis/emergent_misalignment/test_query_ekfac_attribution_no_normalizer" +output_path = ( + "/mnt/ssd-1/louis/emergent_misalignment/test_query_ekfac_attribution_no_normalizer" +) if normalize: output_path += "_unit_norm" os.makedirs(output_path, exist_ok=True) for name in tqdm(list(names)): index_tensor = torch.from_numpy(index[name]).to(device=device, dtype=torch.float32) - queries_tensor = torch.from_numpy(queries[name]).to(device=device, dtype=torch.float32) + queries_tensor = torch.from_numpy(queries[name]).to( + device=device, dtype=torch.float32 + ) if normalize: - index_tensor = index_tensor / (torch.norm(index_tensor, dim=1, keepdim=True) + 1e-10) - queries_tensor = queries_tensor / (torch.norm(queries_tensor, dim=1, keepdim=True) + 1e-10) + index_tensor = index_tensor / ( + torch.norm(index_tensor, dim=1, keepdim=True) + 1e-10 + ) + queries_tensor = queries_tensor / ( + torch.norm(queries_tensor, dim=1, keepdim=True) + 1e-10 + ) # Compute result on GPU result_tensor = index_tensor @ queries_tensor.T @@ -56,7 +70,10 @@ # Create memory-mapped file with .bin extension mmap_file = np.memmap( - os.path.join(output_path, f"{name}_attribution.npy"), dtype=np.float32, mode="w+", shape=result_shape + os.path.join(output_path, f"{name}_attribution.npy"), + dtype=np.float32, + mode="w+", + shape=result_shape, ) # Copy GPU result directly to memmap diff --git a/bergson/hessians/collector.py b/bergson/hessians/collector.py index 6ead497..da75f2f 100644 --- a/bergson/hessians/collector.py +++ b/bergson/hessians/collector.py @@ -263,8 +263,14 @@ def teardown(self) -> None: os.makedirs(gradient_path, exist_ok=True) # Save sharded covariance matrices - save_file(self.A_cov_dict, os.path.join(activation_path, f"shard_{self.rank}.safetensors")) - save_file(self.S_cov_dict, os.path.join(gradient_path, f"shard_{self.rank}.safetensors")) + save_file( + self.A_cov_dict, + os.path.join(activation_path, f"shard_{self.rank}.safetensors"), + ) + save_file( + self.S_cov_dict, + os.path.join(gradient_path, f"shard_{self.rank}.safetensors"), + ) @dataclass(kw_only=True) @@ -286,11 +292,15 @@ def setup(self) -> None: """Load eigenvectors and initialize storage.""" # Load precomputed eigenvectors self.eigen_a = load_file( - os.path.join(self.path, f"activation_eigen_sharded/shard_{self.rank}.safetensors"), + os.path.join( + self.path, f"activation_eigen_sharded/shard_{self.rank}.safetensors" + ), device=f"cuda:{self.rank}", ) self.eigen_g = load_file( - os.path.join(self.path, f"gradient_eigen_sharded/shard_{self.rank}.safetensors"), + os.path.join( + self.path, f"gradient_eigen_sharded/shard_{self.rank}.safetensors" + ), device=f"cuda:{self.rank}", ) @@ -303,7 +313,9 @@ def forward_hook(self, name: str, a: Tensor) -> None: # a shape: [N, S, I] # Transform: a @ eigen_a - transformed = 
self.shard_computer._matmul(vector_nsa=a, matrix_cb=self.eigen_a[name]) # shape [N, S, I] + transformed = self.shard_computer._matmul( + vector_nsa=a, matrix_cb=self.eigen_a[name] + ) # shape [N, S, I] # Cache for use in backward pass self.transformed_a_cache[name] = transformed @@ -313,11 +325,15 @@ def backward_hook(self, name: str, g: Tensor) -> None: # g shape: [N, S, O] # Transform: g @ eigen_g - transformed_g = self.shard_computer._matmul(vector_nsa=g, matrix_cb=self.eigen_g[name]) # shape [N, S, O] + transformed_g = self.shard_computer._matmul( + vector_nsa=g, matrix_cb=self.eigen_g[name] + ) # shape [N, S, O] # Compute outer product: sum_n (transformed_a_n^T @ transformed_g_n) # Einstein notation: [N, S, I] x [N, S, O] -> [N, O, I] - transformed_grad_shard = torch.einsum("N S I, N S O -> N O I", self.transformed_a_cache[name], transformed_g) + transformed_grad_shard = torch.einsum( + "N S I, N S O -> N O I", self.transformed_a_cache[name], transformed_g + ) # Square and sum over batch transformed_grad_shard = (transformed_grad_shard**2).sum(dim=0).contiguous() @@ -333,15 +349,26 @@ def backward_hook(self, name: str, g: Tensor) -> None: # Accumulate (with CPU offloading for memory efficiency) if name not in self.eigenvalue_corrections: - self.eigenvalue_corrections[name] = transformed_grad_shard[start_row:end_row, :].contiguous() + self.eigenvalue_corrections[name] = transformed_grad_shard[ + start_row:end_row, : + ].contiguous() else: - self.eigenvalue_corrections[name] = self.eigenvalue_corrections[name].to(device=self.device) - self.eigenvalue_corrections[name].add_(transformed_grad_shard[start_row:end_row, :].contiguous()) - self.eigenvalue_corrections[name] = self.eigenvalue_corrections[name].to(device="cpu", non_blocking=False) + self.eigenvalue_corrections[name] = self.eigenvalue_corrections[name].to( + device=self.device + ) + self.eigenvalue_corrections[name].add_( + transformed_grad_shard[start_row:end_row, :].contiguous() + ) + self.eigenvalue_corrections[name] = self.eigenvalue_corrections[name].to( + device="cpu", non_blocking=False + ) def teardown(self) -> None: """Save eigenvalue corrections to disk.""" output_path = os.path.join(self.path, "eigenvalue_correction_sharded") os.makedirs(output_path, exist_ok=True) - save_file(self.eigenvalue_corrections, os.path.join(output_path, f"shard_{self.rank}.safetensors")) + save_file( + self.eigenvalue_corrections, + os.path.join(output_path, f"shard_{self.rank}.safetensors"), + ) diff --git a/bergson/hessians/ekfac_compute.py b/bergson/hessians/ekfac_compute.py index 9664a1e..3a82afa 100644 --- a/bergson/hessians/ekfac_compute.py +++ b/bergson/hessians/ekfac_compute.py @@ -25,7 +25,11 @@ from transformers import PreTrainedModel from bergson.data import IndexConfig, create_index, load_gradients, pad_and_tensor -from bergson.hessians.collector import CovarianceCollector, HookCollectorBase, LambdaCollector +from bergson.hessians.collector import ( + CovarianceCollector, + HookCollectorBase, + LambdaCollector, +) from bergson.hessians.logger import get_logger from bergson.hessians.sharded_computation import ShardedMul @@ -70,14 +74,20 @@ def __init__( self.cfg = cfg - self.logger = get_logger("EkfacComputer", level="DEBUG" if cfg.debug else "INFO") + self.logger = get_logger( + "EkfacComputer", level="DEBUG" if cfg.debug else "INFO" + ) ### Distributed related - self.shard_computer = ShardedMul(target_info=self.target_info, lambda_damp_factor=cfg.lambda_damp_factor) + self.shard_computer = ShardedMul( + 
target_info=self.target_info, lambda_damp_factor=cfg.lambda_damp_factor + ) self.rank = dist.get_rank() if dist.is_initialized() else 0 self.world_size = dist.get_world_size() if dist.is_initialized() else 1 - self.logger.info(f"Computing EKFAC for {list(self.target_info)} target modules.") + self.logger.info( + f"Computing EKFAC for {list(self.target_info)} target modules." + ) def compute_covariance(self): cov_collector = CovarianceCollector( @@ -90,7 +100,9 @@ def compute_covariance(self): self._collector(cov_collector, desc="covariances") - def compute_eigendecomposition(self, covariance_type: Literal["activation", "gradient"]): + def compute_eigendecomposition( + self, covariance_type: Literal["activation", "gradient"] + ): """This is Eq. 18 from above reference.""" total_processed = torch.load( os.path.join(self.path, "total_processed_covariances.pt"), @@ -98,7 +110,9 @@ def compute_eigendecomposition(self, covariance_type: Literal["activation", "gra ) random.seed(0) - shuffled_target_info = random.sample(list(self.target_info), len(list(self.target_info))) + shuffled_target_info = random.sample( + list(self.target_info), len(list(self.target_info)) + ) target_info_rank = shuffled_target_info[self.rank :: self.world_size] @@ -112,7 +126,10 @@ def compute_eigendecomposition(self, covariance_type: Literal["activation", "gra leave=False, ): matrix = self.shard_computer._compute_full_matrix( - key, shard_path=os.path.join(self.path, f"{covariance_type}_covariance_sharded") + key, + shard_path=os.path.join( + self.path, f"{covariance_type}_covariance_sharded" + ), ) # type: ignore original_dtype = matrix.dtype @@ -129,13 +146,17 @@ def compute_eigendecomposition(self, covariance_type: Literal["activation", "gra eigenvalues, eigenvectors = torch.linalg.eigh(matrix_normalized) except Exception as e: - raise RuntimeError(f"Eigendecomposition failed for {key} of type {covariance_type}") from e + raise RuntimeError( + f"Eigendecomposition failed for {key} of type {covariance_type}" + ) from e eigenvectors = eigenvectors.to(original_dtype).to(device="cpu").contiguous() covariance_eigenvectors[key] = eigenvectors covariance_eigenvectors = self.shard_computer._merge_and_shard_dict( - input_dict=covariance_eigenvectors, covariance_type=covariance_type, dtype=self.dtype + input_dict=covariance_eigenvectors, + covariance_type=covariance_type, + dtype=self.dtype, ) eigen_path = os.path.join(self.path, f"{covariance_type}_eigen_sharded") @@ -199,7 +220,9 @@ def _collector(self, collector, desc: Optional[str] = None): step = 0 with prof: - for sl in tqdm(self.batches, disable=self.rank != 0, desc=f"Computing {desc}"): + for sl in tqdm( + self.batches, disable=self.rank != 0, desc=f"Computing {desc}" + ): batch = self.data[sl] x, y = pad_and_tensor( batch["input_ids"], # type: ignore @@ -209,7 +232,14 @@ def _collector(self, collector, desc: Optional[str] = None): total_processed += x.numel() - with collector, record_function(f"step_{step}") if self.cfg.profile else nullcontext(): + with ( + collector, + ( + record_function(f"step_{step}") + if self.cfg.profile + else nullcontext() + ), + ): logits = self.model(x).logits logits = logits[:, :-1].reshape(-1, logits.size(-1)) @@ -249,7 +279,9 @@ def _collector(self, collector, desc: Optional[str] = None): dist.all_reduce(total_processed, op=dist.ReduceOp.SUM) if self.rank == 0: - torch.save(total_processed, os.path.join(self.path, f"total_processed_{desc}.pt")) + torch.save( + total_processed, os.path.join(self.path, f"total_processed_{desc}.pt") + ) 
self.logger.info(f"Total processed: {total_processed.item()}") @@ -262,14 +294,18 @@ def __init__( self.path = os.path.join(cfg.ekfac_path, "influence_results") self.gradient_path = cfg.gradient_path - self.logger = get_logger("EkfacApplicator", level="DEBUG" if cfg.debug else "INFO") + self.logger = get_logger( + "EkfacApplicator", level="DEBUG" if cfg.debug else "INFO" + ) ### Distributed related self.rank = dist.get_rank() if dist.is_initialized() else 0 self.world_size = dist.get_world_size() if dist.is_initialized() else 1 self.device = f"cuda:{self.rank}" - self.sharded_computer = ShardedMul(target_info=None, lambda_damp_factor=cfg.lambda_damp_factor) + self.sharded_computer = ShardedMul( + target_info=None, lambda_damp_factor=cfg.lambda_damp_factor + ) match cfg.precision: case "bf16": @@ -279,7 +315,9 @@ def __init__( case "fp32": self.dtype = torch.float32 case "int4" | "int8": - self.dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16 + self.dtype = ( + torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16 + ) case other: raise ValueError(f"Unsupported precision: {other}") @@ -308,10 +346,18 @@ def prepare_attribution(self): dtype=eigen_a[name].dtype, ) - proj_shards_wpt = torch.chunk(proj_pi, self.world_size, dim=1) # (w, p, i/w) - result_shard_pi = torch.einsum("t i, p t-> p i", eigen_a[name], proj_shards_wpt[self.rank]).contiguous() - - dist.all_reduce(result_shard_pi, op=dist.ReduceOp.SUM) if dist.is_initialized() else None + proj_shards_wpt = torch.chunk( + proj_pi, self.world_size, dim=1 + ) # (w, p, i/w) + result_shard_pi = torch.einsum( + "t i, p t-> p i", eigen_a[name], proj_shards_wpt[self.rank] + ).contiguous() + + ( + dist.all_reduce(result_shard_pi, op=dist.ReduceOp.SUM) + if dist.is_initialized() + else None + ) shard_size = result_shard_pi.shape[0] // self.world_size start_row = self.rank * shard_size @@ -319,7 +365,9 @@ def prepare_attribution(self): random_eigen_a[name] = result_shard_pi[start_row:end_row, :] - random_activation_path = os.path.join(self.path, "random_activation_eigen_sharded") + random_activation_path = os.path.join( + self.path, "random_activation_eigen_sharded" + ) os.makedirs(random_activation_path, exist_ok=True) save_file( random_eigen_a, @@ -334,9 +382,17 @@ def prepare_attribution(self): side="left", dtype=eigen_g[name].dtype, ) - proj_shards_wqr = torch.chunk(proj_qo, self.world_size, dim=1) # (w, q, o/w) - result_shard_qo = torch.einsum("q r, r o -> q o", proj_shards_wqr[self.rank], eigen_g[name]).contiguous() - dist.all_reduce(result_shard_qo, op=dist.ReduceOp.SUM) if dist.is_initialized() else None + proj_shards_wqr = torch.chunk( + proj_qo, self.world_size, dim=1 + ) # (w, q, o/w) + result_shard_qo = torch.einsum( + "q r, r o -> q o", proj_shards_wqr[self.rank], eigen_g[name] + ).contiguous() + ( + dist.all_reduce(result_shard_qo, op=dist.ReduceOp.SUM) + if dist.is_initialized() + else None + ) shard_size = result_shard_qo.shape[0] // self.world_size start_row = self.rank * shard_size @@ -350,8 +406,12 @@ def prepare_attribution(self): os.path.join(random_gradient_path, f"shard_{self.rank}.safetensors"), ) - self.logger.info(f"Saved random activation eigenvectors to {random_activation_path}") - self.logger.info(f"Saved random gradient eigenvectors to {random_gradient_path}") + self.logger.info( + f"Saved random activation eigenvectors to {random_activation_path}" + ) + self.logger.info( + f"Saved random gradient eigenvectors to {random_gradient_path}" + ) self.logger.info("-*-" * 50) @@ -366,7 
+426,8 @@ def compute_ivhp_sharded(self): ) random_eigen_a = load_file( - self.path + f"/random_activation_eigen_sharded/shard_{self.rank}.safetensors", + self.path + + f"/random_activation_eigen_sharded/shard_{self.rank}.safetensors", device=f"cuda:{self.rank}", ) random_eigen_g = load_file( @@ -387,7 +448,10 @@ def compute_ivhp_sharded(self): lambda_factor[k] = v.to(dtype=torch.float32) grad_sizes = { - name: random_eigen_g[name].shape[0] * self.world_size * random_eigen_a[name].shape[0] * self.world_size + name: random_eigen_g[name].shape[0] + * self.world_size + * random_eigen_a[name].shape[0] + * self.world_size for name in random_eigen_a } @@ -403,7 +467,9 @@ def compute_ivhp_sharded(self): dtype=np.float32, ) - self.logger.info(f"Loaded gradients for {len(mmap)} queries and computing IVHP...") + self.logger.info( + f"Loaded gradients for {len(mmap)} queries and computing IVHP..." + ) for i in tqdm( range(math.ceil(info["num_grads"] / self.cfg.gradient_batch_size)), @@ -411,7 +477,8 @@ def compute_ivhp_sharded(self): disable=self.rank != 0, ): batch_slice = slice( - i * self.cfg.gradient_batch_size, min((i + 1) * self.cfg.gradient_batch_size, info["num_grads"]) + i * self.cfg.gradient_batch_size, + min((i + 1) * self.cfg.gradient_batch_size, info["num_grads"]), ) # profile profiler = self._setup_profiler() @@ -437,14 +504,27 @@ def compute_ivhp_sharded(self): self.logger.info(f"Saved IVHP gradients to {self.cfg.run_path}") - def compute_ivhp_batch(self, eigen_a, mmap, eigen_g, lambda_factor, random_eigen_a, random_eigen_g, batch_slice): + def compute_ivhp_batch( + self, + eigen_a, + mmap, + eigen_g, + lambda_factor, + random_eigen_a, + random_eigen_g, + batch_slice, + ): transformed_gradients: dict[str, Tensor] = {} for k, v in eigen_a.items(): gradients_noi = torch.from_numpy(mmap[k][batch_slice]).to( device=self.device, dtype=torch.float32 ) # shape [num_grads, out*in] - gradients_noi = gradients_noi.view(-1, eigen_g[k].shape[1], eigen_a[k].shape[1]) - transformed_gradients[k] = self.sharded_computer._matmul(vector_nsa=gradients_noi, matrix_cb=v) + gradients_noi = gradients_noi.view( + -1, eigen_g[k].shape[1], eigen_a[k].shape[1] + ) + transformed_gradients[k] = self.sharded_computer._matmul( + vector_nsa=gradients_noi, matrix_cb=v + ) self.logger.debug("Finished G @ Q_A") @@ -463,7 +543,9 @@ def compute_ivhp_batch(self, eigen_a, mmap, eigen_g, lambda_factor, random_eigen torch.cuda.empty_cache() for k, v in lambda_factor.items(): - self.sharded_computer._hadamard(matrix_noi=transformed_gradients[k], lambda_ci=v) # this is in-place + self.sharded_computer._hadamard( + matrix_noi=transformed_gradients[k], lambda_ci=v + ) # this is in-place self.logger.debug("Finished G'/lambda") diff --git a/bergson/hessians/logger.py b/bergson/hessians/logger.py index b7c64cf..85ed375 100644 --- a/bergson/hessians/logger.py +++ b/bergson/hessians/logger.py @@ -22,7 +22,9 @@ def filter(self, record): # Create a function to get loggers with consistent naming -def get_logger(name: Optional[str] = None, level: Optional[str] = None) -> logging.Logger: +def get_logger( + name: Optional[str] = None, level: Optional[str] = None +) -> logging.Logger: """ Get a logger with the configured format. 
diff --git a/bergson/hessians/scripts/ekfac_apply.sh b/bergson/hessians/scripts/ekfac_apply.sh index f44171b..fef4b83 100755 --- a/bergson/hessians/scripts/ekfac_apply.sh +++ b/bergson/hessians/scripts/ekfac_apply.sh @@ -6,5 +6,3 @@ python ../ekfac_apply.py /mnt/ssd-1/louis/emergent_misalignment/gradients_data/m --apply_ekfac \ --gradient_path "/mnt/ssd-1/louis/emergent_misalignment/gradients_data/merged_code/query" \ --gradient_batch_size 40 \ - - diff --git a/bergson/hessians/scripts/ekfac_apply_sweep.sh b/bergson/hessians/scripts/ekfac_apply_sweep.sh index 93f6529..94f776b 100644 --- a/bergson/hessians/scripts/ekfac_apply_sweep.sh +++ b/bergson/hessians/scripts/ekfac_apply_sweep.sh @@ -5,7 +5,7 @@ python ../ekfac_apply.py ekfac_merged_medical_eval \ --apply_ekfac \ --gradient_path "/root/bergson/bergson/hessians/scripts/test_query" \ --gradient_batch_size 50 \ - + python ../ekfac_apply.py ekfac_merged_medical_eval_sampled \ --projection_dim 16 \ --apply_ekfac \ @@ -17,4 +17,4 @@ python ../ekfac_apply.py ekfac_merged_medical_train_sampled \ --projection_dim 16 \ --apply_ekfac \ --gradient_path "/root/bergson/bergson/hessians/scripts/test_query" \ - --gradient_batch_size 50 \ \ No newline at end of file + --gradient_batch_size 50 \ diff --git a/bergson/hessians/scripts/query_random_and_default.sh b/bergson/hessians/scripts/query_random_and_default.sh index 0382367..de314b8 100755 --- a/bergson/hessians/scripts/query_random_and_default.sh +++ b/bergson/hessians/scripts/query_random_and_default.sh @@ -18,7 +18,7 @@ python -m bergson "${BASE_OUTPUT_PATH}/query" \ --token_batch_size 2048 \ --normalizer none \ --fsdp \ - --projection_dim 0 + --projection_dim 0 echo "Completed projection_dim 0 run" diff --git a/bergson/hessians/sharded_computation.py b/bergson/hessians/sharded_computation.py index b107988..ec9e5d1 100644 --- a/bergson/hessians/sharded_computation.py +++ b/bergson/hessians/sharded_computation.py @@ -15,7 +15,9 @@ def __init__(self, target_info, lambda_damp_factor=0.1): self.rank = dist.get_rank() if self.dist else 0 self.world_size = dist.get_world_size() if self.dist else 1 - self.device = torch.device(f"cuda:{self.rank}" if torch.cuda.is_available() else "cpu") + self.device = torch.device( + f"cuda:{self.rank}" if torch.cuda.is_available() else "cpu" + ) self.target_info = target_info self.lambda_damp_factor = lambda_damp_factor @@ -32,12 +34,16 @@ def _init_covariance_dict( # Activation covariance A^T A has shape [in_dim, in_dim] in_dim = weight_shape[1] shard_in_dim = in_dim if not self.dist else in_dim // self.world_size - activation_covariance_dict[name] = torch.zeros((shard_in_dim, in_dim), device=self.device, dtype=dtype) + activation_covariance_dict[name] = torch.zeros( + (shard_in_dim, in_dim), device=self.device, dtype=dtype + ) # Gradient covariance G^T G has shape [out_dim, out_dim] out_dim = weight_shape[0] shard_out_dim = out_dim if not self.dist else out_dim // self.world_size - gradient_covariance_dict[name] = torch.zeros((shard_out_dim, out_dim), device=self.device, dtype=dtype) + gradient_covariance_dict[name] = torch.zeros( + (shard_out_dim, out_dim), device=self.device, dtype=dtype + ) def _matmul( self, @@ -46,11 +52,12 @@ def _matmul( ) -> Float[Tensor, "n s b"]: """Vector-matrix multiplication. - If not distributed, this does usual multiplication with a=c. - - If distributed, assumes that c=a/world_size and does sharded multiplication.""" + - If distributed, assumes that c=a/world_size and does sharded multiplication. 
+ """ - assert vector_nsa.shape[2] == matrix_cb.shape[0] * self.world_size, ( - f"Vector shape {vector_nsa.shape} not compatible with matrix shape {matrix_cb.shape} and world_size {self.world_size}" - ) + assert ( + vector_nsa.shape[2] == matrix_cb.shape[0] * self.world_size + ), f"Vector shape {vector_nsa.shape} not compatible with matrix shape {matrix_cb.shape} and world_size {self.world_size}" if not self.dist: result_nsb = torch.einsum("n s c, c b-> n s b", vector_nsa, matrix_cb) @@ -81,7 +88,9 @@ def _compute_full_matrix( """ files = os.listdir(shard_path) - assert len(files) == self.world_size, f"Expected {self.world_size} shards, found {len(files)} in {shard_path}" + assert ( + len(files) == self.world_size + ), f"Expected {self.world_size} shards, found {len(files)} in {shard_path}" full_matrix = None @@ -89,14 +98,20 @@ def _compute_full_matrix( full_path_rank = os.path.join( shard_path, "shard_0.safetensors" ) # TODO: Does this work with different CUDA visible devices? - with safe_open(full_path_rank, framework="pt", device=f"cuda:{self.rank}") as f: + with safe_open( + full_path_rank, framework="pt", device=f"cuda:{self.rank}" + ) as f: full_matrix = f.get_tensor(name) else: full_matrix_list = [] for shard_id in range(self.world_size): - shard_path_rank = os.path.join(shard_path, f"shard_{shard_id}.safetensors") - with safe_open(shard_path_rank, framework="pt", device=f"cuda:{self.rank}") as f: + shard_path_rank = os.path.join( + shard_path, f"shard_{shard_id}.safetensors" + ) + with safe_open( + shard_path_rank, framework="pt", device=f"cuda:{self.rank}" + ) as f: local_matrix = f.get_tensor(name) full_matrix_list.append(local_matrix) @@ -107,7 +122,10 @@ def _compute_full_matrix( return full_matrix def _merge_and_shard_dict( - self, input_dict: dict[str, torch.Tensor], covariance_type: Literal["activation", "gradient"], dtype + self, + input_dict: dict[str, torch.Tensor], + covariance_type: Literal["activation", "gradient"], + dtype, ) -> dict[str, torch.Tensor]: """This function takes a dict of tensors, where each rank will have *full* eigenvectors of *some* modules. 
It then redistributes the tensors across all ranks, @@ -127,13 +145,21 @@ def _merge_and_shard_dict( tensor = input_dict[key].to(device=self.device) shard_size = d // self.world_size - dist.all_reduce(tensor, op=dist.ReduceOp.SUM) if dist.is_initialized() else None + ( + dist.all_reduce(tensor, op=dist.ReduceOp.SUM) + if dist.is_initialized() + else None + ) shard = torch.empty(shard_size, d, device=self.device, dtype=dtype) - shard.copy_(tensor[self.rank * shard_size : (self.rank + 1) * shard_size, :]) + shard.copy_( + tensor[self.rank * shard_size : (self.rank + 1) * shard_size, :] + ) result_dict[key] = shard.to(device="cpu", non_blocking=True) - assert shard.shape[0] == shard_size, f"Shard shape {shard.shape} does not match expected {shard_size}" + assert ( + shard.shape[0] == shard_size + ), f"Shard shape {shard.shape} does not match expected {shard_size}" del tensor @@ -142,10 +168,14 @@ def _merge_and_shard_dict( return result_dict - def _hadamard(self, matrix_noi: Float[Tensor, "n o i"], lambda_ci: Float[Tensor, "c i"]): + def _hadamard( + self, matrix_noi: Float[Tensor, "n o i"], lambda_ci: Float[Tensor, "c i"] + ): if not self.dist: global_lambda_mean = lambda_ci.mean() - inverse_lambda = (lambda_ci + self.lambda_damp_factor * global_lambda_mean).reciprocal() + inverse_lambda = ( + lambda_ci + self.lambda_damp_factor * global_lambda_mean + ).reciprocal() matrix_noi.mul_(inverse_lambda) else: self._sharded_hadamard(matrix_noi, lambda_ci) @@ -162,7 +192,9 @@ def _sharded_matmul( Returns: [n, s, b] """ # Split the vector into shards - vector_shards_wnsc = torch.chunk(vector_nsa, self.world_size, dim=-1) # (w, n, s, a/w) + vector_shards_wnsc = torch.chunk( + vector_nsa, self.world_size, dim=-1 + ) # (w, n, s, a/w) n, s, b = vector_nsa.shape[0], vector_nsa.shape[1], matrix_cb.shape[1] result_nsb = torch.zeros( @@ -178,13 +210,17 @@ def _sharded_matmul( shard_cb = torch.zeros_like(matrix_cb) dist.broadcast(shard_cb, src=rank_index) - result_nsb += torch.einsum("n s c, c b-> n s b", vector_shards_wnsc[rank_index], shard_cb) # [B, c] + result_nsb += torch.einsum( + "n s c, c b-> n s b", vector_shards_wnsc[rank_index], shard_cb + ) # [B, c] if self.rank != rank_index: del shard_cb return result_nsb - def _sharded_hadamard(self, matrix_noi: Float[Tensor, "n o i"], lambda_ci: Float[Tensor, "c i"]): + def _sharded_hadamard( + self, matrix_noi: Float[Tensor, "n o i"], lambda_ci: Float[Tensor, "c i"] + ): """ Sharded in-place element-wise multiplication for distributed training. 
gradients: [n, o, i] @@ -207,7 +243,9 @@ def _sharded_hadamard(self, matrix_noi: Float[Tensor, "n o i"], lambda_ci: Float start_row = rank_index * shard_ci.shape[0] end_row = (rank_index + 1) * shard_ci.shape[0] - inverse_lambda = (shard_ci + self.lambda_damp_factor * global_lambda_mean).reciprocal() + inverse_lambda = ( + shard_ci + self.lambda_damp_factor * global_lambda_mean + ).reciprocal() matrix_noi[:, start_row:end_row, :].mul_(inverse_lambda) @@ -229,7 +267,9 @@ def _sharded_transpose_matmul( x, y = (matrix_noi.shape[1], matrix_bc.shape[0] * self.world_size) - result_nxy = torch.zeros(matrix_noi.shape[0], x, y, device=matrix_noi.device, dtype=matrix_noi.dtype) + result_nxy = torch.zeros( + matrix_noi.shape[0], x, y, device=matrix_noi.device, dtype=matrix_noi.dtype + ) for rank_index in range(self.world_size): if rank_index == self.rank: @@ -242,7 +282,9 @@ def _sharded_transpose_matmul( start_row = rank_index * shard_size end_row = (rank_index + 1) * shard_size - result_nxy[:, :, start_row:end_row].copy_(torch.einsum("n o i, c i -> n o c", matrix_noi, shard_bc)) + result_nxy[:, :, start_row:end_row].copy_( + torch.einsum("n o i, c i -> n o c", matrix_noi, shard_bc) + ) if self.rank != rank_index: del shard_bc diff --git a/bergson/utils.py b/bergson/utils.py index dfc71b7..a5974b8 100644 --- a/bergson/utils.py +++ b/bergson/utils.py @@ -18,7 +18,11 @@ def assert_type(typ: Type[T], obj: Any) -> T: def get_layer_list(model: PreTrainedModel | PeftModel) -> nn.ModuleList: """Get the list of layers to train SAEs on.""" N = assert_type(int, model.config.num_hidden_layers) - candidates = [mod for mod in model.base_model.modules() if isinstance(mod, nn.ModuleList) and len(mod) == N] + candidates = [ + mod + for mod in model.base_model.modules() + if isinstance(mod, nn.ModuleList) and len(mod) == N + ] assert len(candidates) == 1, "Could not find the list of layers." 
return candidates[0] diff --git a/pyproject.toml b/pyproject.toml index dd662a7..2c02cd3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,6 +26,8 @@ version = "0.0.1" [project.optional-dependencies] dev = [ "pre-commit", + "pytest", + "pyright", ] example = [ "trl", diff --git a/tests/ekfac_tests/compute_ekfac_ground_truth.ipynb b/tests/ekfac_tests/compute_ekfac_ground_truth.ipynb deleted file mode 100644 index 10536ee..0000000 --- a/tests/ekfac_tests/compute_ekfac_ground_truth.ipynb +++ /dev/null @@ -1,1480 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Imports" - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [], - "source": [ - "%load_ext autoreload\n", - "%autoreload 2" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [], - "source": [ - "import gc\n", - "import json\n", - "import os\n", - "from dataclasses import asdict\n", - "from typing import Optional\n", - "\n", - "import torch\n", - "import torch.distributed as dist\n", - "import torch.nn.functional as F\n", - "from datasets import Dataset, DatasetDict, IterableDatasetDict, load_dataset\n", - "from safetensors.torch import load_file, save_file\n", - "from tqdm.notebook import tqdm\n", - "from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig\n", - "\n", - "from bergson.data import DataConfig, IndexConfig, pad_and_tensor, tokenize\n", - "from bergson.gradients import (\n", - " GradientProcessor,\n", - ")\n", - "from bergson.hessians.collector import EkfacCollector\n", - "from bergson.hessians.utils import TensorDict\n", - "from bergson.utils import assert_type" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## -1. Helper functions" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [], - "source": [ - "def allocate_batches_test(doc_lengths: list[int], N: int, workers: Optional[int] = None) -> list[list[list[int]]]:\n", - " \"\"\"\n", - " Modification of allocate_batches to return a flat list of batches for testing(instead of returning allocation[rank])\n", - " Allocate documents into batches that are then distributed evenly across\n", - " a fixed number of workers.\n", - "\n", - " Parameters\n", - " ----------\n", - " doc_lengths : Sequence[int]\n", - " Length (in tokens) of each document. The *i-th* document is referred to\n", - " internally by its index ``i``.\n", - " workers : int\n", - " Number of parallel workers ( 1 ≤ workers ≤ 8).\n", - " N : int\n", - " Hard memory budget per *batch*, expressed as\n", - " ``max(length in batch) * (# docs in batch) ≤ N``.\n", - "\n", - " Returns\n", - " -------\n", - " list[list[list[int]]]\n", - " ``allocation[w][b]`` is the list of document indices that belong to the\n", - " *b-th* batch assigned to worker ``w``. Every worker receives the same\n", - " number of (non-empty) batches.\n", - "\n", - " Raises\n", - " ------\n", - " AllocationError\n", - " If the three hard constraints cannot be satisfied.\n", - "\n", - " Notes\n", - " -----\n", - " 1. **Per-batch cost constraint**: Each batch is padded to the maximum\n", - " sequence length *inside that batch*, so its cost in “token × examples”\n", - " units is ``max_len_in_batch * batch_size``. This must stay ≤ ``N``.\n", - " 2. 
**Bin-packing strategy**: We use *first-fit decreasing* (FFD) to obtain\n", - " an initial near-minimal set of batches, then split some of the larger\n", - " batches (never increases cost) until\n", - "\n", - " * every worker has at least one batch,\n", - " * the total number of batches is a multiple of ``workers``.\n", - "\n", - " Because each split only lowers the cost of the two resulting batches,\n", - " the constraint in (1) remains satisfied throughout.\n", - " \"\"\"\n", - "\n", - " if workers is None:\n", - " world_size = dist.get_world_size() if dist.is_initialized() else 1\n", - " else:\n", - " world_size = workers\n", - "\n", - " if not doc_lengths:\n", - " raise RuntimeError(\"Empty document list.\")\n", - " if max(doc_lengths) > N: # a single document would overflow any batch\n", - " raise RuntimeError(\"At least one document is too long for the budget N.\")\n", - "\n", - " # ---------------------------------------------------------------------\n", - " # 1) First-fit decreasing (FFD) bin packing under the cost function\n", - " # cost(batch) = max_len_in_batch * len(batch)\n", - " # ---------------------------------------------------------------------\n", - " docs_sorted = sorted(enumerate(doc_lengths), key=lambda x: x[1], reverse=True)\n", - " batches: list[list[int]] = [] # holds document *indices*\n", - " batch_meta = [] # (max_len, size) for each batch\n", - "\n", - " for idx, length in docs_sorted:\n", - " placed = False\n", - " for j, (mx, sz) in enumerate(batch_meta):\n", - " new_mx = max(mx, length)\n", - " new_sz = sz + 1\n", - " if new_mx * new_sz <= N: # still fits\n", - " batches[j].append(idx)\n", - " batch_meta[j] = (new_mx, new_sz)\n", - " placed = True\n", - " break\n", - "\n", - " if not placed: # open a new batch\n", - " batches.append([idx])\n", - " batch_meta.append((length, 1))\n", - "\n", - " # ---------------------------------------------------------------------\n", - " # 2) Ensure every worker gets ≥ 1 batch\n", - " # ---------------------------------------------------------------------\n", - " if len(batches) < world_size:\n", - " # split the largest batches (by size) until we have ≥ workers batches\n", - " batches.sort(key=len, reverse=True)\n", - " while len(batches) < world_size:\n", - " big = batches.pop(0) # take the current largest\n", - " if len(big) == 1: # cannot split a singleton\n", - " raise RuntimeError(\"Not enough documents to give each worker at least one batch.\")\n", - " batches.append([big.pop()]) # move one doc into new batch\n", - " batches.append(big) # put the remainder back\n", - " # preserve cost constraint automatically\n", - "\n", - " # ---------------------------------------------------------------------\n", - " # 3) Pad the number of batches to a multiple of `workers`\n", - " # ---------------------------------------------------------------------\n", - " k = -(-len(batches) // world_size) # ceiling division\n", - " target_batches = world_size * k # == k batches per worker\n", - "\n", - " # Split arbitrary (non-singleton) batches until we reach the target\n", - " i = 0\n", - " while len(batches) < target_batches:\n", - " batch = batches[i % len(batches)]\n", - " if len(batch) == 1:\n", - " i += 1 # try another batch\n", - " continue\n", - " batches.append([batch.pop()]) # split off a singleton\n", - " i += 1\n", - "\n", - " assert len(batches) == target_batches\n", - " assert all(max(doc_lengths[i] for i in batch) * len(batch) <= N for batch in batches)\n", - "\n", - " # 
---------------------------------------------------------------------\n", - " # 4) Round-robin assignment to workers\n", - " # ---------------------------------------------------------------------\n", - " allocation: list[list[list[int]]] = [[] for _ in range(world_size)]\n", - " for b_idx, batch in enumerate(batches):\n", - " allocation[b_idx % world_size].append(batch)\n", - "\n", - " # sanity: equal # of batches per worker\n", - " assert len({len(b) for b in allocation}) == 1\n", - " return allocation" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 0. Hyperparameters" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": {}, - "outputs": [], - "source": [ - "import torch\n", - "import numpy as np\n", - "import random\n", - "import os\n", - "from torch.backends import cudnn\n", - "\n", - "\n", - "# Set all random seeds\n", - "def set_all_seeds(seed=42):\n", - " random.seed(seed)\n", - " np.random.seed(seed)\n", - " torch.manual_seed(seed)\n", - " torch.cuda.manual_seed(seed)\n", - " torch.cuda.manual_seed_all(seed) # for multi-GPU\n", - " os.environ[\"PYTHONHASHSEED\"] = str(seed)\n", - "\n", - "\n", - "set_all_seeds(42) # or whatever seed you prefer\n", - "\n", - "# Force deterministic behavior (sacrifices speed for reproducibility)\n", - "torch.backends.cudnn.deterministic = True\n", - "torch.backends.cudnn.benchmark = False\n", - "torch.use_deterministic_algorithms(True)\n", - "\n", - "# Set environment variables for additional determinism\n", - "os.environ[\"CUDA_LAUNCH_BLOCKING\"] = \"1\"\n", - "os.environ[\"CUBLAS_WORKSPACE_CONFIG\"] = \":4096:8\"\n" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": {}, - "outputs": [], - "source": [ - "current_path = os.getcwd()\n", - "parent_path = os.path.join(current_path, \"test_files\", \"pile_100_examples\")\n", - "\n", - "test_path = parent_path + \"/ground_truth\"\n", - "ekfac_run_path = parent_path + \"/run/influence_results\"\n", - "\n", - "\n", - "os.makedirs(test_path, exist_ok=True)\n", - "cfg = IndexConfig(run_path=\"\") # empty run path because we are not using it to save data\n", - "cfg.model = \"EleutherAI/Pythia-14m\"\n", - "cfg.precision = \"fp32\"\n", - "cfg.fsdp = False\n", - "\n", - "\n", - "cfg.data = DataConfig(dataset=parent_path + \"/data\")\n", - "# cfg.data = DataConfig(dataset=\"NeelNanda/pile-10k\")\n", - "\n", - "data_str = cfg.data.dataset\n", - "\n", - "# save cfg\n", - "with open(os.path.join(test_path, \"index_config.json\"), \"w\") as f:\n", - " json.dump(asdict(cfg), f, indent=4)\n", - "\n", - "\n", - "workers = 8 # simulating n workers, but we will run on a single GPU to get ground truth\n", - "device = torch.device(\"cuda:0\")" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": {}, - "outputs": [], - "source": [ - "match cfg.precision:\n", - " case \"bf16\":\n", - " dtype = torch.bfloat16\n", - " case \"fp16\":\n", - " dtype = torch.float16\n", - " case \"fp32\":\n", - " dtype = torch.float32\n", - " case \"int4\" | \"int8\":\n", - " dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16\n", - " case other:\n", - " raise ValueError(f\"Unsupported precision: {other}\")\n" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": {}, - "outputs": [], - "source": [ - "debug_name = \"layers.0.mlp.dense_h_to_4h\" # for debugging" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 1. 
Loading model and data" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "`torch_dtype` is deprecated! Use `dtype` instead!\n" - ] - } - ], - "source": [ - "model = AutoModelForCausalLM.from_pretrained(\n", - " cfg.model,\n", - " device_map=\"cuda\",\n", - " quantization_config=(\n", - " BitsAndBytesConfig(\n", - " load_in_4bit=cfg.precision == \"int4\",\n", - " load_in_8bit=cfg.precision == \"int8\",\n", - " bnb_4bit_compute_dtype=dtype,\n", - " bnb_4bit_quant_storage=dtype,\n", - " bnb_4bit_quant_type=\"nf4\",\n", - " bnb_4bit_use_double_quant=True,\n", - " )\n", - " if cfg.precision in (\"int4\", \"int8\")\n", - " else None\n", - " ),\n", - " torch_dtype=dtype,\n", - ")\n" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "metadata": {}, - "outputs": [], - "source": [ - "data_str = cfg.data.dataset\n", - "if data_str.endswith(\".csv\"):\n", - " ds = assert_type(Dataset, Dataset.from_csv(data_str))\n", - "elif data_str.endswith(\".json\") or data_str.endswith(\".jsonl\"):\n", - " ds = assert_type(Dataset, Dataset.from_json(data_str))\n", - "else:\n", - " try:\n", - " ds = load_dataset(data_str, split=\"train\")\n", - "\n", - " if isinstance(ds, DatasetDict) or isinstance(ds, IterableDatasetDict):\n", - " raise NotImplementedError(\"DatasetDicts and IterableDatasetDicts are not supported.\")\n", - " except ValueError as e:\n", - " # Automatically use load_from_disk if appropriate\n", - " if \"load_from_disk\" in str(e):\n", - " ds = Dataset.load_from_disk(data_str, keep_in_memory=False)\n", - " else:\n", - " raise e\n" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "metadata": {}, - "outputs": [], - "source": [ - "assert isinstance(ds, Dataset) # pleasing the typechecker\n", - "\n", - "tokenizer = AutoTokenizer.from_pretrained(cfg.model, model_max_length=cfg.token_batch_size)\n", - "\n", - "\n", - "ds = ds.map(\n", - " tokenize,\n", - " batched=True,\n", - " fn_kwargs=dict(args=cfg.data, tokenizer=tokenizer),\n", - ")\n", - "\n", - "data = ds" - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "metadata": {}, - "outputs": [], - "source": [ - "batches_world = allocate_batches_test(doc_lengths=ds[\"length\"], N=cfg.token_batch_size, workers=workers)\n", - "assert len(batches_world) == workers" - ] - }, - { - "cell_type": "code", - "execution_count": 12, - "metadata": {}, - "outputs": [], - "source": [ - "target_modules = None\n", - "normalizers = {}\n", - "\n", - "processor = GradientProcessor(\n", - " projection_dim=None,\n", - ")\n" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 2. 
Compute activation and gradient covariance" - ] - }, - { - "cell_type": "code", - "execution_count": 13, - "metadata": {}, - "outputs": [], - "source": [ - "covariance_test_path = os.path.join(test_path, \"covariances\")" - ] - }, - { - "cell_type": "code", - "execution_count": 15, - "metadata": {}, - "outputs": [], - "source": [ - "def compute_covariance(rank: int, activation_covariances={}, gradient_covariances={}):\n", - " total_processed = 0\n", - " batches = batches_world[rank]\n", - "\n", - " loss_list = []\n", - "\n", - " def callback_activation(name: str, a: torch.Tensor):\n", - " activation_covariance = activation_covariances.get(name, None) # Our stored slice\n", - "\n", - " a = a.reshape(-1, a.shape[-1]) # [N*S, O]\n", - "\n", - " # if name == debug_name:\n", - " # print(a[0, 0])\n", - " update = a.mT @ a\n", - " # if name == debug_name:\n", - " # print(update[0, 0])\n", - "\n", - " if activation_covariance is None:\n", - " activation_covariances[name] = update\n", - "\n", - " else:\n", - " # Add it to our permanently stored slice\n", - " activation_covariance.add_(update)\n", - "\n", - " def callback_gradient(name: str, g: torch.Tensor):\n", - " gradient_covariance = gradient_covariances.get(name, None)\n", - "\n", - " g = g.reshape(-1, g.shape[-1]) # [N*S, O]\n", - "\n", - " # if name == debug_name:\n", - " # print(g.abs().sum(), rank)\n", - " update = g.mT @ g\n", - "\n", - " if gradient_covariance is None:\n", - " gradient_covariances[name] = update\n", - " else:\n", - " gradient_covariance.add_(update)\n", - "\n", - " # def callback_gradient(name: str, g: torch.Tensor):\n", - " # gradient_covariance = gradient_covariances.get(name, None)\n", - "\n", - " # g = g.reshape(-1, g.shape[-1]) # [N*S, O]\n", - " # # if name == debug_name:\n", - " # # print(g.abs().sum(), rank)\n", - "\n", - " # if gradient_covariance is None:\n", - " # gradient_covariances[name] = g.sum(dim=0)\n", - " # else:\n", - " # gradient_covariance.add_(g.sum(dim=0))\n", - "\n", - " collector = EkfacCollector(\n", - " model.base_model,\n", - " closure=callback_gradient,\n", - " target_modules=target_modules,\n", - " fwd_closure=callback_activation,\n", - " )\n", - " for sl in tqdm(batches):\n", - " batch = data[sl]\n", - " x, y = pad_and_tensor(\n", - " batch[\"input_ids\"], # type: ignore\n", - " labels=batch.get(\"labels\"), # type: ignore\n", - " device=device,\n", - " )\n", - "\n", - " total_processed += x.numel()\n", - "\n", - " with collector:\n", - " logits = model(x).logits\n", - " losses = F.cross_entropy(\n", - " logits[:, :-1].reshape(-1, logits.size(-1)),\n", - " y[:, 1:].flatten(),\n", - " reduction=\"none\",\n", - " ).reshape_as(y[:, 1:])\n", - "\n", - " losses = losses.sum(1)\n", - "\n", - " losses.mean().backward()\n", - "\n", - " loss_list.append(losses.detach().cpu())\n", - "\n", - " model.zero_grad()\n", - "\n", - " return {\"losses\": loss_list, \"total_processed_rank\": total_processed}\n" - ] - }, - { - "cell_type": "code", - "execution_count": 16, - "metadata": {}, - "outputs": [ - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "f846066dadf64e259e1dc6bf97ab707d", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - " 0%| | 0/2 [00:00 N S O I\", g, activation_cache[name])\n", - "\n", - " gradient = torch.einsum(\"N S O I, I J -> N S O J\", gradient, eigenvector_a)\n", - " gradient = torch.einsum(\" O P, N S O J -> N S P J\", eigenvector_g, gradient)\n", - " gradient = gradient.sum(dim=1) # sum over sequence length\n", - "\n", - " gradient = 
gradient**2\n", - " correction = gradient.sum(dim=0)\n", - "\n", - " if name not in eigenvalue_corrections:\n", - " eigenvalue_corrections[name] = correction\n", - " else:\n", - " eigenvalue_corrections[name].add_(correction)\n", - "\n", - " collector = EkfacCollector(\n", - " model.base_model,\n", - " closure=callback_gradient,\n", - " target_modules=target_modules,\n", - " fwd_closure=callback_activation,\n", - " )\n", - " for sl in tqdm(batches):\n", - " batch = data[sl]\n", - " x, y = pad_and_tensor(\n", - " batch[\"input_ids\"], # type: ignore\n", - " labels=batch.get(\"labels\"), # type: ignore\n", - " device=device,\n", - " )\n", - "\n", - " total_processed += x.numel()\n", - "\n", - " with collector:\n", - " logits = model(x).logits\n", - " losses = F.cross_entropy(\n", - " logits[:, :-1].reshape(-1, logits.size(-1)),\n", - " y[:, 1:].flatten(),\n", - " reduction=\"none\",\n", - " ).reshape_as(y[:, 1:])\n", - "\n", - " losses = losses.sum(1)\n", - "\n", - " losses.mean().backward()\n", - "\n", - " model.zero_grad()\n", - "\n", - " return {\"losses\": loss_list, \"total_processed_rank\": total_processed}\n" - ] - }, - { - "cell_type": "code", - "execution_count": 54, - "metadata": {}, - "outputs": [], - "source": [ - "gradient_cache_amortized = {}\n", - "\n", - "\n", - "def compute_eigenvalue_correction_amortized(\n", - " rank: int,\n", - " eigenvalue_corrections,\n", - " eigenvectors_activations=eigenvectors_activations,\n", - " eigenvectors_gradients=eigenvectors_gradients,\n", - "):\n", - " total_processed = 0\n", - " batches = batches_world[rank]\n", - "\n", - " def callback_activation(name: str, a: torch.Tensor):\n", - " # a = torch.ones_like(a) # for debugging, pretend all activations are 1\n", - " activation = a # [N, S, I]\n", - " activation_cache[name] = activation\n", - "\n", - " def callback_gradient(name: str, g: torch.Tensor):\n", - " eigenvector_a = eigenvectors_activations[name].to(device=device)\n", - " eigenvector_g = eigenvectors_gradients[name].to(device=device)\n", - " gradient = g # [N, S, O]\n", - "\n", - " transformed_a = torch.einsum(\"N S I, I J -> N S J\", activation_cache[name], eigenvector_a)\n", - " transformed_g = torch.einsum(\"O P, N S O -> N S P\", eigenvector_g, gradient)\n", - " correction = (torch.einsum(\" N S O, N S I ->N O I\", transformed_g, transformed_a) ** 2).sum(dim=0).contiguous()\n", - " # torch.save(activation_cache[name], \"activation.pt\")\n", - " # torch.save(transformed_a, \"transformed_a.pt\")\n", - " # torch.save(transformed_g, \"transformed_g.pt\")\n", - " # torch.save(correction, \"correction.pt\")\n", - "\n", - " if name == debug_name:\n", - " gradient_cache_amortized[name] = g\n", - " if name not in eigenvalue_corrections:\n", - " eigenvalue_corrections[name] = correction\n", - " else:\n", - " eigenvalue_corrections[name].add_(correction)\n", - "\n", - " collector = EkfacCollector(\n", - " model.base_model,\n", - " closure=callback_gradient,\n", - " target_modules=target_modules,\n", - " fwd_closure=callback_activation,\n", - " )\n", - " for sl in tqdm(batches):\n", - " batch = data[sl]\n", - " x, y = pad_and_tensor(\n", - " batch[\"input_ids\"], # type: ignore\n", - " labels=batch.get(\"labels\"), # type: ignore\n", - " device=device,\n", - " )\n", - "\n", - " total_processed += x.numel()\n", - "\n", - " with collector:\n", - " logits = model(x).logits\n", - " losses = F.cross_entropy(\n", - " logits[:, :-1].reshape(-1, logits.size(-1)),\n", - " y[:, 1:].flatten(),\n", - " reduction=\"none\",\n", - " ).reshape_as(y[:, 
1:])\n", - "\n", - " losses = losses.sum(1)\n", - "\n", - " losses.mean().backward()\n", - "\n", - " model.zero_grad()\n", - "\n", - " return {\"losses\": loss_list, \"total_processed_rank\": total_processed}\n" - ] - }, - { - "cell_type": "code", - "execution_count": 55, - "metadata": {}, - "outputs": [ - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "c0beeae5588d4f5e92fce5ba10687a97", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - " 0%| | 0/2 [00:00 Batches: + """ + Modification of allocate_batches to return a flat list of batches for testing. + + Allocate documents into batches that are then distributed evenly across + a fixed number of workers. + + Parameters + ---------- + doc_lengths : Sequence[int] + Length (in tokens) of each document. + workers : int + Number of parallel workers ( 1 ≤ workers ≤ 8). + N : int + Hard memory budget per *batch*, expressed as + ``max(length in batch) * (# docs in batch) ≤ N``. + + Returns + ------- + list[list[list[int]]] + ``allocation[w][b]`` is the list of document indices that belong to the + *b-th* batch assigned to worker ``w``. + """ + if workers is None: + world_size = dist.get_world_size() if dist.is_initialized() else 1 + else: + world_size = workers + + if not doc_lengths: + raise RuntimeError("Empty document list.") + if max(doc_lengths) > N: + raise RuntimeError("At least one document is too long for the budget N.") + + # First-fit decreasing (FFD) bin packing + docs_sorted = sorted(enumerate(doc_lengths), key=lambda x: x[1], reverse=True) + batches: list[list[int]] = [] + batch_meta = [] + + for idx, length in docs_sorted: + placed = False + for j, (mx, sz) in enumerate(batch_meta): + new_mx = max(mx, length) + new_sz = sz + 1 + if new_mx * new_sz <= N: + batches[j].append(idx) + batch_meta[j] = (new_mx, new_sz) + placed = True + break + + if not placed: + batches.append([idx]) + batch_meta.append((length, 1)) + + # Ensure every worker gets ≥ 1 batch + if len(batches) < world_size: + batches.sort(key=len, reverse=True) + while len(batches) < world_size: + big = batches.pop(0) + if len(big) == 1: + raise RuntimeError( + "Not enough documents to give each worker at least one batch." + ) + batches.append([big.pop()]) + batches.append(big) + + # Pad the number of batches to a multiple of `workers` + k = -(-len(batches) // world_size) + target_batches = world_size * k + + i = 0 + while len(batches) < target_batches: + batch = batches[i % len(batches)] + if len(batch) == 1: + i += 1 + continue + batches.append([batch.pop()]) + i += 1 + + assert len(batches) == target_batches + assert all( + max(doc_lengths[i] for i in batch) * len(batch) <= N for batch in batches + ) + + # Round-robin assignment to workers + allocation: Batches = [[] for _ in range(world_size)] + for b_idx, batch in enumerate(batches): + allocation[b_idx % world_size].append(batch) + + assert len({len(b) for b in allocation}) == 1 + return allocation + + +# %% [markdown] +# ## 0. 
Hyperparameters + + +# %% +def parse_config() -> tuple[Precision, Optional[str]]: + """Parse command-line arguments or return defaults.""" + precision: Precision + output_dir: Optional[str] + + if len(sys.argv) > 1 and not hasattr(builtins, "__IPYTHON__"): + parser = argparse.ArgumentParser( + description="Compute EKFAC ground truth for testing" + ) + parser.add_argument( + "--precision", + type=str, + default="fp32", + choices=["fp32", "fp16", "bf16", "int4", "int8"], + help="Model precision (default: fp32)", + ) + parser.add_argument( + "-o", + "--output-dir", + type=str, + default=None, + help="Output directory for ground truth results (default: test_files/pile_100_examples/ground_truth)", + ) + args = parser.parse_args() + precision = args.precision + output_dir = args.output_dir + else: + # Defaults for interactive execution or running without arguments + precision = "fp32" + output_dir = None + + # Set random seeds for reproducibility + set_all_seeds(42) + + return precision, output_dir + + +if __name__ == "__main__" or TYPE_CHECKING: + precision, output_dir = parse_config() + + +# %% +def setup_paths_and_config( + precision: Precision, output_dir: Optional[str] = None +) -> tuple[IndexConfig, str, int, torch.device, Any, torch.dtype]: + """Setup paths and configuration object.""" + current_path = os.getcwd() + parent_path = os.path.join(current_path, "test_files", "pile_100_examples") + if output_dir is not None: + test_path = output_dir + else: + test_path = os.path.join(parent_path, "ground_truth") + os.makedirs(test_path, exist_ok=True) + + # Configuration + cfg = IndexConfig(run_path="") + cfg.model = "EleutherAI/Pythia-14m" + cfg.precision = precision + cfg.fsdp = False + cfg.data = DataConfig(dataset=os.path.join(parent_path, "data")) + + data_str = cfg.data.dataset + + # Create pile-100 dataset if it doesn't exist + if not os.path.exists(data_str): + full_dataset = load_dataset("NeelNanda/pile-10k", split="train") + assert isinstance(full_dataset, Dataset), "Expected Dataset, got something else" + subset = full_dataset.select(range(100)) + os.makedirs(os.path.dirname(data_str), exist_ok=True) + subset.save_to_disk(data_str) + print(f"Generated pile-100 in {data_str}") + + # Save config + with open(os.path.join(test_path, "index_config.json"), "w") as f: + json.dump(asdict(cfg), f, indent=4) + + # Setup + workers = 8 + device = torch.device("cuda:0") + target_modules = None + + # Determine dtype + match cfg.precision: + case "bf16": + dtype = torch.bfloat16 + case "fp16": + dtype = torch.float16 + case "fp32": + dtype = torch.float32 + case "int4" | "int8": + dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16 + case other: + raise ValueError(f"Unsupported precision: {other}") + + return cfg, test_path, workers, device, target_modules, dtype + + +if __name__ == "__main__" or TYPE_CHECKING: + cfg, test_path, workers, device, target_modules, dtype = setup_paths_and_config( + precision, output_dir + ) + + +# %% [markdown] +# ## 1. 
Loading model and data + + +# %% +def load_model_step(cfg: IndexConfig, dtype: torch.dtype) -> PreTrainedModel: + """Load the model.""" + print(f"Loading model {cfg.model}...") + model = AutoModelForCausalLM.from_pretrained( + cfg.model, + device_map="cuda", + quantization_config=( + BitsAndBytesConfig( + load_in_4bit=cfg.precision == "int4", + load_in_8bit=cfg.precision == "int8", + bnb_4bit_compute_dtype=dtype, + bnb_4bit_quant_storage=dtype, + bnb_4bit_quant_type="nf4", + bnb_4bit_use_double_quant=True, + ) + if cfg.precision in ("int4", "int8") + else None + ), + dtype=dtype, + ) + return model + + +if __name__ == "__main__" or TYPE_CHECKING: + model = load_model_step(cfg, dtype) + + +# %% +def load_dataset_step(cfg: IndexConfig) -> Dataset: + """Load and return the dataset.""" + data_str = cfg.data.dataset + print(f"Loading dataset from {data_str}...") + + if data_str.endswith(".csv"): + ds = assert_type(Dataset, Dataset.from_csv(data_str)) + elif data_str.endswith(".json") or data_str.endswith(".jsonl"): + ds = assert_type(Dataset, Dataset.from_json(data_str)) + else: + try: + ds = load_dataset(data_str, split="train") + if isinstance(ds, (DatasetDict, IterableDatasetDict)): + raise NotImplementedError( + "DatasetDicts and IterableDatasetDicts are not supported." + ) + except ValueError as e: + if "load_from_disk" in str(e): + ds = Dataset.load_from_disk(data_str, keep_in_memory=False) + else: + raise e + + assert isinstance(ds, Dataset) + return ds + + +if __name__ == "__main__" or TYPE_CHECKING: + ds = load_dataset_step(cfg) + + +# %% +def tokenize_and_allocate_step( + ds: Dataset, cfg: IndexConfig, workers: int +) -> tuple[Dataset, Batches, Any]: + """Tokenize dataset and allocate batches.""" + tokenizer = AutoTokenizer.from_pretrained( + cfg.model, model_max_length=cfg.token_batch_size + ) + ds = ds.map( + tokenize, batched=True, fn_kwargs=dict(args=cfg.data, tokenizer=tokenizer) + ) + data = ds + + # Allocate batches + batches_world = allocate_batches_test( + doc_lengths=ds["length"], N=cfg.token_batch_size, workers=workers + ) + assert len(batches_world) == workers + + return data, batches_world, tokenizer + + +if __name__ == "__main__" or TYPE_CHECKING: + data, batches_world, tokenizer = tokenize_and_allocate_step(ds, cfg, workers) + + +# %% [markdown] +# ## 2. 
Compute activation and gradient covariance + + +# %% +def compute_covariance( + rank: int, + model: PreTrainedModel, + data: Dataset, + batches_world: Batches, + device: torch.device, + target_modules: Any, + activation_covariances: dict[str, Tensor], + gradient_covariances: dict[str, Tensor], +) -> dict[str, Any]: + """Compute activation and gradient covariances for a single worker.""" + total_processed = 0 + batches = batches_world[rank] + loss_list = [] + + collector = GroundTruthCovarianceCollector( + model=model.base_model, + activation_covariances=activation_covariances, + gradient_covariances=gradient_covariances, + target_modules=target_modules, + ) + + for sl in tqdm(batches, desc=f"Rank {rank} covariances"): + batch = data[sl] + x, y = pad_and_tensor( + batch["input_ids"], + labels=batch.get("labels"), + device=device, + ) + + total_processed += x.numel() + + with collector: + logits = model(x).logits + losses = F.cross_entropy( + logits[:, :-1].reshape(-1, logits.size(-1)), + y[:, 1:].flatten(), + reduction="none", + ).reshape_as(y[:, 1:]) + + losses = losses.sum(1) + losses.mean().backward() + loss_list.append(losses.detach().cpu()) + model.zero_grad() + + return {"losses": loss_list, "total_processed_rank": total_processed} + + +# %% +def compute_covariances_step( + model: PreTrainedModel, + data: Dataset, + batches_world: Batches, + device: torch.device, + target_modules: Any, + workers: int, + test_path: str, +) -> str: + """Compute covariances for all ranks and save to disk.""" + covariance_test_path = os.path.join(test_path, "covariances") + + for rank in range(workers): + covariance_test_path_rank = os.path.join(covariance_test_path, f"rank_{rank}") + os.makedirs(covariance_test_path_rank, exist_ok=True) + + activation_covariances = {} + gradient_covariances = {} + d = compute_covariance( + rank=rank, + model=model, + data=data, + batches_world=batches_world, + device=device, + target_modules=target_modules, + activation_covariances=activation_covariances, + gradient_covariances=gradient_covariances, + ) + + save_file( + activation_covariances, + os.path.join( + covariance_test_path_rank, "activation_covariance.safetensors" + ), + ) + save_file( + gradient_covariances, + os.path.join(covariance_test_path_rank, "gradient_covariance.safetensors"), + ) + with open(os.path.join(covariance_test_path_rank, "stats.json"), "w") as f: + json.dump({"total_processed_rank": d["total_processed_rank"]}, f, indent=4) + print(f"Rank {rank} processed {d['total_processed_rank']} tokens.") + + return covariance_test_path + + +if __name__ == "__main__" or TYPE_CHECKING: + print("\n=== Computing Covariances ===") + covariance_test_path = compute_covariances_step( + model, data, batches_world, device, target_modules, workers, test_path + ) + + +# %% +def combine_covariances_step( + covariance_test_path: str, workers: int, device: torch.device +) -> int: + """Combine covariance results from all ranks.""" + activation_covariances = TensorDict({}) + gradient_covariances = TensorDict({}) + total_processed_global = 0 + + for rank in range(workers): + covariance_test_path_rank = os.path.join(covariance_test_path, f"rank_{rank}") + + with open(os.path.join(covariance_test_path_rank, "stats.json"), "r") as f: + d = json.load(f) + total_processed_global += d["total_processed_rank"] + + activation_covariances_rank = TensorDict( + load_file( + os.path.join( + covariance_test_path_rank, "activation_covariance.safetensors" + ) + ) + ).to(device) + + gradient_covariances_rank = TensorDict( + load_file( + 
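+                # Raw per-rank sums of g^T g saved by compute_covariances_step; they
+                # are added elementwise across ranks here and only normalized by the
+                # global token count later, in compute_eigenvectors_step.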
os.path.join( + covariance_test_path_rank, "gradient_covariance.safetensors" + ) + ) + ).to(device) + + if not activation_covariances: + activation_covariances = activation_covariances_rank + else: + activation_covariances = ( + activation_covariances + activation_covariances_rank + ) + + if not gradient_covariances: + gradient_covariances = gradient_covariances_rank + else: + gradient_covariances = gradient_covariances + gradient_covariances_rank + + save_file( + activation_covariances.to_dict(), + os.path.join(covariance_test_path, "activation_covariance.safetensors"), + ) + save_file( + gradient_covariances.to_dict(), + os.path.join(covariance_test_path, "gradient_covariance.safetensors"), + ) + with open(os.path.join(covariance_test_path, "stats.json"), "w") as f: + json.dump({"total_processed_global": total_processed_global}, f, indent=4) + print(f"Global processed {total_processed_global} tokens.") + + gc.collect() + torch.cuda.empty_cache() + + return total_processed_global + + +if __name__ == "__main__" or TYPE_CHECKING: + print("\n=== Combining Covariances ===") + total_processed_global = combine_covariances_step( + covariance_test_path, workers, device + ) + + +# %% [markdown] +# ## 3. Compute eigenvalues and eigenvectors + + +# %% +def compute_eigenvectors_step( + test_path: str, device: torch.device, dtype: torch.dtype +) -> str: + """Compute eigenvectors from covariances.""" + covariance_test_path = os.path.join(test_path, "covariances") + eigenvectors_test_path = os.path.join(test_path, "eigenvectors") + os.makedirs(eigenvectors_test_path, exist_ok=True) + + # Load covariances + with open(os.path.join(covariance_test_path, "stats.json"), "r") as f: + d = json.load(f) + total_processed_global = d["total_processed_global"] + + activation_covariances = load_file( + os.path.join(covariance_test_path, "activation_covariance.safetensors") + ) + gradient_covariances = load_file( + os.path.join(covariance_test_path, "gradient_covariance.safetensors") + ) + + eigenvectors_activations = {} + eigenvectors_gradients = {} + + for name in activation_covariances.keys(): + a = activation_covariances[name].to(dtype=torch.float64, device=device) + g = gradient_covariances[name].to(dtype=torch.float64, device=device) + a = (a + a.T).div(2) + g = (g + g.T).div(2) + a.div_(total_processed_global) + g.div_(total_processed_global) + + eigenvalues_a, eigenvectors_a = torch.linalg.eigh(a) + eigenvalues_g, eigenvectors_g = torch.linalg.eigh(g) + print( + f"{name}: eigenvectors_a.sum()={eigenvectors_a.sum()}, eigenvectors_g.sum()={eigenvectors_g.sum()}" + ) + eigenvectors_activations[name] = eigenvectors_a.to(dtype=dtype).contiguous() + eigenvectors_gradients[name] = eigenvectors_g.to(dtype=dtype).contiguous() + + save_file( + eigenvectors_activations, + os.path.join(eigenvectors_test_path, "eigenvectors_activations.safetensors"), + ) + save_file( + eigenvectors_gradients, + os.path.join(eigenvectors_test_path, "eigenvectors_gradients.safetensors"), + ) + + gc.collect() + torch.cuda.empty_cache() + + return eigenvectors_test_path + + +if __name__ == "__main__" or TYPE_CHECKING: + print("\n=== Computing Eigenvectors ===") + eigenvectors_test_path = compute_eigenvectors_step(test_path, device, dtype) + + +# %% [markdown] +# ## 4. 
Compute eigenvalue correction + + +# %% +def compute_eigenvalue_correction_amortized( + rank: int, + model: PreTrainedModel, + data: Dataset, + batches_world: Batches, + device: torch.device, + target_modules: Any, + eigenvalue_corrections: dict[str, Tensor], + eigenvectors_activations: dict[str, Tensor], + eigenvectors_gradients: dict[str, Tensor], +) -> dict[str, int]: + """Compute eigenvalue corrections using the amortized method.""" + total_processed = 0 + batches = batches_world[rank] + + collector = GroundTruthAmortizedLambdaCollector( + model=model.base_model, + eigenvalue_corrections=eigenvalue_corrections, + eigenvectors_activations=eigenvectors_activations, + eigenvectors_gradients=eigenvectors_gradients, + device=device, + target_modules=target_modules, + ) + + for sl in tqdm(batches, desc=f"Rank {rank} eigenvalue corrections"): + batch = data[sl] + x, y = pad_and_tensor( + batch["input_ids"], + labels=batch.get("labels"), + device=device, + ) + + total_processed += x.numel() + + with collector: + logits = model(x).logits + losses = F.cross_entropy( + logits[:, :-1].reshape(-1, logits.size(-1)), + y[:, 1:].flatten(), + reduction="none", + ).reshape_as(y[:, 1:]) + + losses = losses.sum(1) + losses.mean().backward() + model.zero_grad() + + return {"total_processed_rank": total_processed} + + +# %% +def compute_eigenvalue_corrections_step( + model: PreTrainedModel, + data: Dataset, + batches_world: Batches, + device: torch.device, + target_modules: Any, + workers: int, + test_path: str, +) -> tuple[str, int]: + """Compute eigenvalue corrections for all ranks.""" + eigenvectors_test_path = os.path.join(test_path, "eigenvectors") + eigenvalue_correction_test_path = os.path.join(test_path, "eigenvalue_corrections") + os.makedirs(eigenvalue_correction_test_path, exist_ok=True) + + # Load eigenvectors + eigenvectors_activations = load_file( + os.path.join(eigenvectors_test_path, "eigenvectors_activations.safetensors") + ) + eigenvectors_gradients = load_file( + os.path.join(eigenvectors_test_path, "eigenvectors_gradients.safetensors") + ) + + total_processed_global = 0 + for rank in range(workers): + eigenvalue_correction_test_path_rank = os.path.join( + eigenvalue_correction_test_path, f"rank_{rank}" + ) + os.makedirs(eigenvalue_correction_test_path_rank, exist_ok=True) + + eigenvalue_corrections = {} + d = compute_eigenvalue_correction_amortized( + rank=rank, + model=model, + data=data, + batches_world=batches_world, + device=device, + target_modules=target_modules, + eigenvalue_corrections=eigenvalue_corrections, + eigenvectors_activations=eigenvectors_activations, + eigenvectors_gradients=eigenvectors_gradients, + ) + + save_file( + eigenvalue_corrections, + os.path.join( + eigenvalue_correction_test_path_rank, + "eigenvalue_corrections.safetensors", + ), + ) + with open( + os.path.join(eigenvalue_correction_test_path_rank, "stats.json"), "w" + ) as f: + json.dump({"total_processed_rank": d["total_processed_rank"]}, f, indent=4) + print(f"Rank {rank} processed {d['total_processed_rank']} tokens.") + total_processed_global += d["total_processed_rank"] + + return eigenvalue_correction_test_path, total_processed_global + + +if __name__ == "__main__" or TYPE_CHECKING: + print("\n=== Computing Eigenvalue Corrections ===") + eigenvalue_correction_test_path, total_processed_global_lambda = ( + compute_eigenvalue_corrections_step( + model, data, batches_world, device, target_modules, workers, test_path + ) + ) + + +# %% +def combine_eigenvalue_corrections_step( + 
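+    # Descriptive note: the per-rank files written above hold raw sums of squared
+    # rotated per-example gradients; this function adds them across ranks and
+    # divides by the global token count to produce the final eigenvalue corrections.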
eigenvalue_correction_test_path: str, + workers: int, + device: torch.device, + total_processed_global: int, +) -> TensorDict: + """Combine eigenvalue correction results from all ranks.""" + eigenvalue_corrections = TensorDict({}) + + for rank in range(workers): + eigenvalue_correction_test_path_rank = os.path.join( + eigenvalue_correction_test_path, f"rank_{rank}" + ) + + eigenvalue_corrections_rank = TensorDict( + load_file( + os.path.join( + eigenvalue_correction_test_path_rank, + "eigenvalue_corrections.safetensors", + ) + ) + ).to(device) + + if not eigenvalue_corrections: + eigenvalue_corrections = eigenvalue_corrections_rank + else: + eigenvalue_corrections = ( + eigenvalue_corrections + eigenvalue_corrections_rank + ) + + eigenvalue_corrections.div_(total_processed_global) + save_file( + eigenvalue_corrections.to_dict(), + os.path.join( + eigenvalue_correction_test_path, "eigenvalue_corrections.safetensors" + ), + ) + + return eigenvalue_corrections + + +if __name__ == "__main__" or TYPE_CHECKING: + eigenvalue_corrections = combine_eigenvalue_corrections_step( + eigenvalue_correction_test_path, workers, device, total_processed_global_lambda + ) + print("\n=== Ground Truth Computation Complete ===") + print(f"Results saved to: {test_path}") diff --git a/tests/ekfac_tests/ground_truth/__init__.py b/tests/ekfac_tests/ground_truth/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/ekfac_tests/ground_truth/collector.py b/tests/ekfac_tests/ground_truth/collector.py new file mode 100644 index 0000000..74adf74 --- /dev/null +++ b/tests/ekfac_tests/ground_truth/collector.py @@ -0,0 +1,121 @@ +"""Ground truth collector for EKFAC testing. + +This module provides a collector that computes activation and gradient covariances +for ground truth comparison, mimicking the old EkfacCollector (removed in commit 8232b77). 
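+
+All three collectors share the same driving loop: install the hooks, run a
+forward and backward pass, and read the accumulated statistics out of the
+mappings passed in. A minimal sketch (``model``, ``x`` and ``y`` are placeholders
+for a causal LM and a padded token batch as prepared in
+``compute_ground_truth.py``; ``F`` is ``torch.nn.functional``)::
+
+    activation_covariances: dict[str, Tensor] = {}
+    gradient_covariances: dict[str, Tensor] = {}
+    collector = GroundTruthCovarianceCollector(
+        model=model.base_model,
+        activation_covariances=activation_covariances,
+        gradient_covariances=gradient_covariances,
+        target_modules=None,
+    )
+    with collector:
+        logits = model(x).logits
+        losses = F.cross_entropy(
+            logits[:, :-1].reshape(-1, logits.size(-1)),
+            y[:, 1:].flatten(),
+            reduction="none",
+        ).reshape_as(y[:, 1:])
+        losses.sum(1).mean().backward()
+    model.zero_grad()
+
+GroundTruthNonAmortizedLambdaCollector and GroundTruthAmortizedLambdaCollector
+accumulate the same eigenvalue correction: rotating the per-example sum of outer
+products, ``Q_G^T (sum_s g_s a_s^T) Q_A``, equals summing the outer products of
+the rotated vectors ``Q_G^T g_s`` and ``Q_A^T a_s``, so the amortized variant
+avoids materializing the ``[N, S, O, I]`` tensor.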
+""" + +from collections.abc import Mapping, MutableMapping +from dataclasses import dataclass + +import torch +from torch import Tensor + +from bergson.hessians.collector import HookCollectorBase + + +@dataclass(kw_only=True) +class GroundTruthCovarianceCollector(HookCollectorBase): + activation_covariances: MutableMapping[str, Tensor] + gradient_covariances: MutableMapping[str, Tensor] + + def setup(self) -> None: + pass + + def teardown(self) -> None: + pass + + def forward_hook(self, name: str, a: Tensor) -> None: + a = a.reshape(-1, a.shape[-1]) # [N*S, O] + + update = a.mT @ a + + if name not in self.activation_covariances: + self.activation_covariances[name] = update + else: + self.activation_covariances[name].add_(update) + + def backward_hook(self, name: str, g: Tensor) -> None: + g = g.reshape(-1, g.shape[-1]) # [N*S, O] + + update = g.mT @ g + + if name not in self.gradient_covariances: + self.gradient_covariances[name] = update + else: + self.gradient_covariances[name].add_(update) + + +@dataclass(kw_only=True) +class GroundTruthNonAmortizedLambdaCollector(HookCollectorBase): + eigenvalue_corrections: MutableMapping[str, Tensor] + eigenvectors_activations: Mapping[str, Tensor] + eigenvectors_gradients: Mapping[str, Tensor] + device: torch.device + + def setup(self) -> None: + self.activation_cache: dict[str, Tensor] = {} + + def teardown(self) -> None: + self.activation_cache.clear() + + def forward_hook(self, name: str, a: Tensor) -> None: + self.activation_cache[name] = a + + def backward_hook(self, name: str, g: Tensor) -> None: + eigenvector_a = self.eigenvectors_activations[name].to(device=self.device) + eigenvector_g = self.eigenvectors_gradients[name].to(device=self.device) + + activation = self.activation_cache[name] # [N, S, I] + gradient = g # [N, S, O] + + gradient = torch.einsum("N S O, N S I -> N S O I", gradient, activation) + + gradient = torch.einsum("N S O I, I J -> N S O J", gradient, eigenvector_a) + gradient = torch.einsum("O P, N S O J -> N S P J", eigenvector_g, gradient) + + gradient = gradient.sum(dim=1) # sum over sequence length + + gradient = gradient**2 + correction = gradient.sum(dim=0) + + if name not in self.eigenvalue_corrections: + self.eigenvalue_corrections[name] = correction + else: + self.eigenvalue_corrections[name].add_(correction) + + +@dataclass(kw_only=True) +class GroundTruthAmortizedLambdaCollector(HookCollectorBase): + eigenvalue_corrections: MutableMapping[str, Tensor] + eigenvectors_activations: Mapping[str, Tensor] + eigenvectors_gradients: Mapping[str, Tensor] + device: torch.device + + def setup(self) -> None: + self.activation_cache: dict[str, Tensor] = {} + + def teardown(self) -> None: + self.activation_cache.clear() + + def forward_hook(self, name: str, a: Tensor) -> None: + self.activation_cache[name] = a + + def backward_hook(self, name: str, g: Tensor) -> None: + eigenvector_a = self.eigenvectors_activations[name].to(device=self.device) + eigenvector_g = self.eigenvectors_gradients[name].to(device=self.device) + + activation = self.activation_cache[name] # [N, S, I] + + transformed_a = torch.einsum("N S I, I J -> N S J", activation, eigenvector_a) + transformed_g = torch.einsum("O P, N S O -> N S P", eigenvector_g, g) + + correction = ( + (torch.einsum("N S O, N S I -> N O I", transformed_g, transformed_a) ** 2) + .sum(dim=0) + .contiguous() + ) + + if name not in self.eigenvalue_corrections: + self.eigenvalue_corrections[name] = correction + else: + self.eigenvalue_corrections[name].add_(correction) diff --git 
a/tests/ekfac_tests/run_apply_compute_ekfac.sh b/tests/ekfac_tests/run_apply_compute_ekfac.sh index d94c12e..3d5117c 100755 --- a/tests/ekfac_tests/run_apply_compute_ekfac.sh +++ b/tests/ekfac_tests/run_apply_compute_ekfac.sh @@ -8,4 +8,3 @@ python test_apply_ekfac.py \ --use_fsdp \ --world_size 8 \ --gradient_batch_size 10 \ - diff --git a/tests/ekfac_tests/run_test_compute_ekfac.sh b/tests/ekfac_tests/run_test_compute_ekfac.sh index e64d91d..739a0b5 100755 --- a/tests/ekfac_tests/run_test_compute_ekfac.sh +++ b/tests/ekfac_tests/run_test_compute_ekfac.sh @@ -2,9 +2,7 @@ # Run all tests python test_compute_ekfac.py \ - --test_dir "./test_files/pile_100_examples" \ + --test_dir "/root/bergson/test_files/pile_100_examples" \ --world_size 8 \ --use_fsdp \ --overwrite - - diff --git a/tests/ekfac_tests/test_apply_ekfac.py b/tests/ekfac_tests/test_apply_ekfac.py index a336512..414cf2e 100644 --- a/tests/ekfac_tests/test_apply_ekfac.py +++ b/tests/ekfac_tests/test_apply_ekfac.py @@ -61,18 +61,24 @@ def test_gradients(run_path, ground_truth_path): - ground_truth = load_file(os.path.join(ground_truth_path, "gradients.safetensors"), device="cuda") + ground_truth = load_file( + os.path.join(ground_truth_path, "gradients.safetensors"), device="cuda" + ) computed_mmap = load_gradients(run_path) for k in ground_truth.keys(): ground_truth_tensor = ground_truth[k].to(dtype=torch.float32) computed_tensor = ( - torch.from_numpy(computed_mmap[k].copy()).to(device="cuda").view(-1, *ground_truth_tensor.shape[1:]) + torch.from_numpy(computed_mmap[k].copy()) + .to(device="cuda") + .view(-1, *ground_truth_tensor.shape[1:]) ).to(dtype=torch.float32) if not (ground_truth_tensor.shape == computed_tensor.shape): - raise ValueError(f"Shape mismatch for key {k}: {ground_truth_tensor.shape} vs {computed_tensor.shape}") + raise ValueError( + f"Shape mismatch for key {k}: {ground_truth_tensor.shape} vs {computed_tensor.shape}" + ) if not torch.allclose(ground_truth_tensor, computed_tensor, rtol=1e-3, atol=0): abs_diff = torch.abs(ground_truth_tensor - computed_tensor) @@ -86,7 +92,9 @@ def test_gradients(run_path, ground_truth_path): gt_val = ground_truth_tensor.flatten()[argmax_idx].item() comp_val = computed_tensor.flatten()[argmax_idx].item() - print(f"Mismatch '{k}': max_abs={max_abs_diff:.2e}, max_rel={max_rel_diff:.2e}") + print( + f"Mismatch '{k}': max_abs={max_abs_diff:.2e}, max_rel={max_rel_diff:.2e}" + ) print(f" At {tuple(coords)}: gt={gt_val:.2e}, comp={comp_val:.2e}") @@ -101,9 +109,13 @@ def main(): ] for file_name in required_files: - assert os.path.exists(os.path.join(ground_truth_path, file_name)), f"Missing required file: {file_name}" + assert os.path.exists( + os.path.join(ground_truth_path, file_name) + ), f"Missing required file: {file_name}" - cfg_json = json.load(open(os.path.join(ground_truth_path, "index_config.json"), "r")) + cfg_json = json.load( + open(os.path.join(ground_truth_path, "index_config.json"), "r") + ) print(cfg_json) cfg = IndexConfig(**cfg_json) diff --git a/tests/ekfac_tests/test_compute_ekfac.py b/tests/ekfac_tests/test_compute_ekfac.py index b81db06..e978832 100644 --- a/tests/ekfac_tests/test_compute_ekfac.py +++ b/tests/ekfac_tests/test_compute_ekfac.py @@ -6,6 +6,7 @@ from test_covariance import test_covariances from test_eigenvalue_correction import test_eigenvalue_correction from test_eigenvectors import test_eigenvectors +from test_utils import set_all_seeds from bergson.data import DataConfig, IndexConfig from bergson.distributed import distributed_computing @@ -55,35 
+56,11 @@ run_path = os.path.join(test_dir, "run/influence_results") -import os -import random - -import numpy as np - - -def deterministic_cuda(seed=42): - # Set all random seeds - - random.seed(seed) - np.random.seed(seed) - torch.manual_seed(seed) - torch.cuda.manual_seed(seed) - torch.cuda.manual_seed_all(seed) # for multi-GPU - os.environ["PYTHONHASHSEED"] = str(seed) - - # Force deterministic behavior (sacrifices speed for reproducibility) - torch.backends.cudnn.deterministic = True - torch.backends.cudnn.benchmark = False - torch.use_deterministic_algorithms(True) - - # Set environment variables for additional determinism - os.environ["CUDA_LAUNCH_BLOCKING"] = "1" - os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" - - def test_total_processed_examples(): - total_processed_ground_truth_path = os.path.join(ground_truth_path, "covariances/stats.json") - total_processed_run_path = os.path.join(run_path, "total_processed.pt") + total_processed_ground_truth_path = os.path.join( + ground_truth_path, "covariances/stats.json" + ) + total_processed_run_path = os.path.join(run_path, "total_processed_covariances.pt") with open(total_processed_ground_truth_path, "r") as f: ground_truth_data = json.load(f) @@ -106,7 +83,7 @@ def test_total_processed_examples(): def main(): # assert covariances, eigenvalue_corrections, eigenvectors and index_config.json exist - deterministic_cuda(seed=42) + set_all_seeds(seed=42) required_files = [ "covariances", "eigenvalue_corrections", @@ -115,9 +92,13 @@ def main(): ] for file_name in required_files: - assert os.path.exists(os.path.join(ground_truth_path, file_name)), f"Missing required file: {file_name}" + assert os.path.exists( + os.path.join(ground_truth_path, file_name) + ), f"Missing required file: {file_name}" - cfg_json = json.load(open(os.path.join(ground_truth_path, "index_config.json"), "r")) + cfg_json = json.load( + open(os.path.join(ground_truth_path, "index_config.json"), "r") + ) cfg = IndexConfig(**cfg_json) diff --git a/tests/ekfac_tests/test_covariance.py b/tests/ekfac_tests/test_covariance.py index ceff66e..9824d75 100644 --- a/tests/ekfac_tests/test_covariance.py +++ b/tests/ekfac_tests/test_covariance.py @@ -15,7 +15,9 @@ def test_covariances( covariances_ground_truth_path = os.path.join( ground_truth_path, f"covariances/{covariance_type}_covariance.safetensors" ) - covariances_run_path = os.path.join(run_path, f"{covariance_type}_covariance_sharded") + covariances_run_path = os.path.join( + run_path, f"{covariance_type}_covariance_sharded" + ) # load ground_truth ground_truth_covariances = TensorDict(load_file(covariances_ground_truth_path)) @@ -24,20 +26,29 @@ def test_covariances( # load run covariances shards and concatenate them run_covariances_shards = [ - os.path.join(covariances_run_path, f"shard_{rank}.safetensors") for rank in range(world_size) + os.path.join(covariances_run_path, f"shard_{rank}.safetensors") + for rank in range(world_size) ] run_covariances_list = [(load_file(shard)) for shard in run_covariances_shards] run_covariances = {} for k, v in run_covariances_list[0].items(): - run_covariances[k] = torch.cat([shard[k] for shard in run_covariances_list], dim=0) + run_covariances[k] = torch.cat( + [shard[k] for shard in run_covariances_list], dim=0 + ) run_covariances = TensorDict(run_covariances) - diff = ground_truth_covariances.sub(run_covariances).div(ground_truth_covariances).abs() + diff = ( + ground_truth_covariances.sub(run_covariances) + .div(ground_truth_covariances) + .abs() + ) rtol = 1e-10 atol = 0 - 
equal_dict = ground_truth_covariances.allclose(run_covariances, rtol=rtol, atol=atol) + equal_dict = ground_truth_covariances.allclose( + run_covariances, rtol=rtol, atol=atol + ) if all(equal_dict.values()): print(f"{covariance_type} covariances match") diff --git a/tests/ekfac_tests/test_eigenvalue_correction.py b/tests/ekfac_tests/test_eigenvalue_correction.py index 8f6576c..5c3fe8b 100644 --- a/tests/ekfac_tests/test_eigenvalue_correction.py +++ b/tests/ekfac_tests/test_eigenvalue_correction.py @@ -29,7 +29,7 @@ def test_eigenvalue_correction(ground_truth_path, run_path): lambda_run = TensorDict(lambda_run) - total_processed_run_path = os.path.join(run_path, "total_processed_lambda.pt") + total_processed_run_path = os.path.join(run_path, "total_processed_lambda_correction.pt") total = torch.load(total_processed_run_path).to(device=lambda_run[list(lambda_run.keys())[0]].device) lambda_run.div_(total) rtol = 1e-10 diff --git a/tests/ekfac_tests/test_eigenvectors.py b/tests/ekfac_tests/test_eigenvectors.py index 0526d01..79eb5f6 100644 --- a/tests/ekfac_tests/test_eigenvectors.py +++ b/tests/ekfac_tests/test_eigenvectors.py @@ -23,16 +23,21 @@ def test_eigenvectors( world_size = len(os.listdir(eigenvectors_run_path)) # number of shards # load run eigenvectors shards and concatenate them run_eigenvectors_shards = [ - os.path.join(eigenvectors_run_path, f"shard_{rank}.safetensors") for rank in range(world_size) + os.path.join(eigenvectors_run_path, f"shard_{rank}.safetensors") + for rank in range(world_size) ] run_eigenvectors_list = [(load_file(shard)) for shard in run_eigenvectors_shards] run_eigenvectors = {} for k, v in run_eigenvectors_list[0].items(): - run_eigenvectors[k] = torch.cat([shard[k] for shard in run_eigenvectors_list], dim=0) + run_eigenvectors[k] = torch.cat( + [shard[k] for shard in run_eigenvectors_list], dim=0 + ) run_eigenvectors = TensorDict(run_eigenvectors) - equal_dict = ground_truth_eigenvectors.allclose(run_eigenvectors, atol=0, rtol=1e-10) + equal_dict = ground_truth_eigenvectors.allclose( + run_eigenvectors, atol=0, rtol=1e-10 + ) if all(equal_dict.values()): print(f"{eigenvector_type} eigenvectors match!") @@ -47,7 +52,9 @@ def test_eigenvectors( # Find location of max difference max_diff_flat_idx = torch.argmax(diff[k]) max_diff_idx = torch.unravel_index(max_diff_flat_idx, diff[k].shape) - relative_diff = 100 * max_diff[k] / ground_truth_eigenvectors[k][max_diff_idx].abs() + relative_diff = ( + 100 * max_diff[k] / ground_truth_eigenvectors[k][max_diff_idx].abs() + ) if max_diff[k] < 1e-6 and relative_diff < 1e-3: print(f"Eigenvector {k} small differences within tolerance.") else: @@ -55,7 +62,6 @@ def test_eigenvectors( f"Eigenvalue corrections {k} does not match with absolute difference {max_diff[k]:.3f} and max " f"rel. difference {relative_diff:.3f} %!" 
) - print("\n") print("-*" * 50) diff --git a/tests/ekfac_tests/test_utils.py b/tests/ekfac_tests/test_utils.py new file mode 100644 index 0000000..d3d6f10 --- /dev/null +++ b/tests/ekfac_tests/test_utils.py @@ -0,0 +1,27 @@ +"""Common utilities for EKFAC tests.""" + +import os +import random + +import numpy as np +import torch + + +def set_all_seeds(seed: int = 42) -> None: + """Set all random seeds for reproducibility.""" + # Set all random seeds + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + os.environ["PYTHONHASHSEED"] = str(seed) + + # Force deterministic behavior + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + torch.use_deterministic_algorithms(True) + + # Set environment variables for additional determinism + os.environ["CUDA_LAUNCH_BLOCKING"] = "1" + os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" diff --git a/uv.lock b/uv.lock index 5e30004..4b816c5 100644 --- a/uv.lock +++ b/uv.lock @@ -190,6 +190,8 @@ dependencies = [ [package.optional-dependencies] dev = [ { name = "pre-commit" }, + { name = "pyright" }, + { name = "pytest" }, ] example = [ { name = "trl" }, @@ -210,6 +212,8 @@ requires-dist = [ { name = "natsort" }, { name = "peft", specifier = ">=0.16.0" }, { name = "pre-commit", marker = "extra == 'dev'" }, + { name = "pyright", marker = "extra == 'dev'" }, + { name = "pytest", marker = "extra == 'dev'" }, { name = "simple-parsing" }, { name = "torch" }, { name = "transformers" }, @@ -701,6 +705,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/76/c6/c88e154df9c4e1a2a66ccf0005a88dfb2650c1dffb6f5ce603dfbd452ce3/idna-3.10-py3-none-any.whl", hash = "sha256:946d195a0d259cbba61165e88e65941f16e9b36ea6ddb97f00452bae8b1287d3", size = 70442, upload-time = "2024-09-15T18:07:37.964Z" }, ] +[[package]] +name = "iniconfig" +version = "2.1.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/f2/97/ebf4da567aa6827c909642694d71c9fcf53e5b504f2d96afea02718862f3/iniconfig-2.1.0.tar.gz", hash = "sha256:3abbd2e30b36733fee78f9c7f7308f2d0050e88f0087fd25c2645f63c773e1c7", size = 4793, upload-time = "2025-03-19T20:09:59.721Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/2c/e1/e6716421ea10d38022b952c159d5161ca1193197fb744506875fbb87ea7b/iniconfig-2.1.0-py3-none-any.whl", hash = "sha256:9deba5723312380e77435581c6bf4935c94cbfab9b1ed33ef8d238ea168eb760", size = 6050, upload-time = "2025-03-19T20:10:01.071Z" }, +] + [[package]] name = "ipykernel" version = "6.30.1" @@ -1410,6 +1423,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/fe/39/979e8e21520d4e47a0bbe349e2713c0aac6f3d853d0e5b34d76206c439aa/platformdirs-4.3.8-py3-none-any.whl", hash = "sha256:ff7059bb7eb1179e2685604f4aaf157cfd9535242bd23742eadc3c13542139b4", size = 18567, upload-time = "2025-05-07T22:47:40.376Z" }, ] +[[package]] +name = "pluggy" +version = "1.6.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/f9/e2/3e91f31a7d2b083fe6ef3fa267035b518369d9511ffab804f839851d2779/pluggy-1.6.0.tar.gz", hash = "sha256:7dcc130b76258d33b90f61b658791dede3486c3e6bfb003ee5c9bfb396dd22f3", size = 69412, upload-time = "2025-05-15T12:30:07.975Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/54/20/4d324d65cc6d9205fabedc306948156824eb9f0ee1633355a8f7ec5c66bf/pluggy-1.6.0-py3-none-any.whl", hash = 
"sha256:e920276dd6813095e9377c0bc5566d94c932c33b27a3e3945d8389c374dd4746", size = 20538, upload-time = "2025-05-15T12:30:06.134Z" }, +] + [[package]] name = "pre-commit" version = "4.2.0" @@ -1621,6 +1643,37 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/c7/21/705964c7812476f378728bdf590ca4b771ec72385c533964653c68e86bdc/pygments-2.19.2-py3-none-any.whl", hash = "sha256:86540386c03d588bb81d44bc3928634ff26449851e99741617ecb9037ee5ec0b", size = 1225217, upload-time = "2025-06-21T13:39:07.939Z" }, ] +[[package]] +name = "pyright" +version = "1.1.406" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "nodeenv" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/f7/16/6b4fbdd1fef59a0292cbb99f790b44983e390321eccbc5921b4d161da5d1/pyright-1.1.406.tar.gz", hash = "sha256:c4872bc58c9643dac09e8a2e74d472c62036910b3bd37a32813989ef7576ea2c", size = 4113151, upload-time = "2025-10-02T01:04:45.488Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/f6/a2/e309afbb459f50507103793aaef85ca4348b66814c86bc73908bdeb66d12/pyright-1.1.406-py3-none-any.whl", hash = "sha256:1d81fb43c2407bf566e97e57abb01c811973fdb21b2df8df59f870f688bdca71", size = 5980982, upload-time = "2025-10-02T01:04:43.137Z" }, +] + +[[package]] +name = "pytest" +version = "8.4.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "colorama", marker = "sys_platform == 'win32'" }, + { name = "exceptiongroup", marker = "python_full_version < '3.11'" }, + { name = "iniconfig" }, + { name = "packaging" }, + { name = "pluggy" }, + { name = "pygments" }, + { name = "tomli", marker = "python_full_version < '3.11'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/a3/5c/00a0e072241553e1a7496d638deababa67c5058571567b92a7eaa258397c/pytest-8.4.2.tar.gz", hash = "sha256:86c0d0b93306b961d58d62a4db4879f27fe25513d4b969df351abdddb3c30e01", size = 1519618, upload-time = "2025-09-04T14:34:22.711Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a8/a4/20da314d277121d6534b3a980b29035dcd51e6744bd79075a6ce8fa4eb8d/pytest-8.4.2-py3-none-any.whl", hash = "sha256:872f880de3fc3a5bdc88a11b39c9710c3497a547cfa9320bc3c5e62fbf272e79", size = 365750, upload-time = "2025-09-04T14:34:20.226Z" }, +] + [[package]] name = "python-dateutil" version = "2.9.0.post0" @@ -1979,6 +2032,55 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/41/f2/fd673d979185f5dcbac4be7d09461cbb99751554ffb6718d0013af8604cb/tokenizers-0.21.4-cp39-abi3-win_amd64.whl", hash = "sha256:475d807a5c3eb72c59ad9b5fcdb254f6e17f53dfcbb9903233b0dfa9c943b597", size = 2507568, upload-time = "2025-07-28T15:48:55.456Z" }, ] +[[package]] +name = "tomli" +version = "2.3.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/52/ed/3f73f72945444548f33eba9a87fc7a6e969915e7b1acc8260b30e1f76a2f/tomli-2.3.0.tar.gz", hash = "sha256:64be704a875d2a59753d80ee8a533c3fe183e3f06807ff7dc2232938ccb01549", size = 17392, upload-time = "2025-10-08T22:01:47.119Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/b3/2e/299f62b401438d5fe1624119c723f5d877acc86a4c2492da405626665f12/tomli-2.3.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:88bd15eb972f3664f5ed4b57c1634a97153b4bac4479dcb6a495f41921eb7f45", size = 153236, upload-time = "2025-10-08T22:01:00.137Z" }, + { url = 
"https://files.pythonhosted.org/packages/86/7f/d8fffe6a7aefdb61bced88fcb5e280cfd71e08939da5894161bd71bea022/tomli-2.3.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:883b1c0d6398a6a9d29b508c331fa56adbcdff647f6ace4dfca0f50e90dfd0ba", size = 148084, upload-time = "2025-10-08T22:01:01.63Z" }, + { url = "https://files.pythonhosted.org/packages/47/5c/24935fb6a2ee63e86d80e4d3b58b222dafaf438c416752c8b58537c8b89a/tomli-2.3.0-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:d1381caf13ab9f300e30dd8feadb3de072aeb86f1d34a8569453ff32a7dea4bf", size = 234832, upload-time = "2025-10-08T22:01:02.543Z" }, + { url = "https://files.pythonhosted.org/packages/89/da/75dfd804fc11e6612846758a23f13271b76d577e299592b4371a4ca4cd09/tomli-2.3.0-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:a0e285d2649b78c0d9027570d4da3425bdb49830a6156121360b3f8511ea3441", size = 242052, upload-time = "2025-10-08T22:01:03.836Z" }, + { url = "https://files.pythonhosted.org/packages/70/8c/f48ac899f7b3ca7eb13af73bacbc93aec37f9c954df3c08ad96991c8c373/tomli-2.3.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:0a154a9ae14bfcf5d8917a59b51ffd5a3ac1fd149b71b47a3a104ca4edcfa845", size = 239555, upload-time = "2025-10-08T22:01:04.834Z" }, + { url = "https://files.pythonhosted.org/packages/ba/28/72f8afd73f1d0e7829bfc093f4cb98ce0a40ffc0cc997009ee1ed94ba705/tomli-2.3.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:74bf8464ff93e413514fefd2be591c3b0b23231a77f901db1eb30d6f712fc42c", size = 245128, upload-time = "2025-10-08T22:01:05.84Z" }, + { url = "https://files.pythonhosted.org/packages/b6/eb/a7679c8ac85208706d27436e8d421dfa39d4c914dcf5fa8083a9305f58d9/tomli-2.3.0-cp311-cp311-win32.whl", hash = "sha256:00b5f5d95bbfc7d12f91ad8c593a1659b6387b43f054104cda404be6bda62456", size = 96445, upload-time = "2025-10-08T22:01:06.896Z" }, + { url = "https://files.pythonhosted.org/packages/0a/fe/3d3420c4cb1ad9cb462fb52967080575f15898da97e21cb6f1361d505383/tomli-2.3.0-cp311-cp311-win_amd64.whl", hash = "sha256:4dc4ce8483a5d429ab602f111a93a6ab1ed425eae3122032db7e9acf449451be", size = 107165, upload-time = "2025-10-08T22:01:08.107Z" }, + { url = "https://files.pythonhosted.org/packages/ff/b7/40f36368fcabc518bb11c8f06379a0fd631985046c038aca08c6d6a43c6e/tomli-2.3.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:d7d86942e56ded512a594786a5ba0a5e521d02529b3826e7761a05138341a2ac", size = 154891, upload-time = "2025-10-08T22:01:09.082Z" }, + { url = "https://files.pythonhosted.org/packages/f9/3f/d9dd692199e3b3aab2e4e4dd948abd0f790d9ded8cd10cbaae276a898434/tomli-2.3.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:73ee0b47d4dad1c5e996e3cd33b8a76a50167ae5f96a2607cbe8cc773506ab22", size = 148796, upload-time = "2025-10-08T22:01:10.266Z" }, + { url = "https://files.pythonhosted.org/packages/60/83/59bff4996c2cf9f9387a0f5a3394629c7efa5ef16142076a23a90f1955fa/tomli-2.3.0-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:792262b94d5d0a466afb5bc63c7daa9d75520110971ee269152083270998316f", size = 242121, upload-time = "2025-10-08T22:01:11.332Z" }, + { url = "https://files.pythonhosted.org/packages/45/e5/7c5119ff39de8693d6baab6c0b6dcb556d192c165596e9fc231ea1052041/tomli-2.3.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:4f195fe57ecceac95a66a75ac24d9d5fbc98ef0962e09b2eddec5d39375aae52", size = 250070, upload-time = "2025-10-08T22:01:12.498Z" }, + { url 
= "https://files.pythonhosted.org/packages/45/12/ad5126d3a278f27e6701abde51d342aa78d06e27ce2bb596a01f7709a5a2/tomli-2.3.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:e31d432427dcbf4d86958c184b9bfd1e96b5b71f8eb17e6d02531f434fd335b8", size = 245859, upload-time = "2025-10-08T22:01:13.551Z" }, + { url = "https://files.pythonhosted.org/packages/fb/a1/4d6865da6a71c603cfe6ad0e6556c73c76548557a8d658f9e3b142df245f/tomli-2.3.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:7b0882799624980785240ab732537fcfc372601015c00f7fc367c55308c186f6", size = 250296, upload-time = "2025-10-08T22:01:14.614Z" }, + { url = "https://files.pythonhosted.org/packages/a0/b7/a7a7042715d55c9ba6e8b196d65d2cb662578b4d8cd17d882d45322b0d78/tomli-2.3.0-cp312-cp312-win32.whl", hash = "sha256:ff72b71b5d10d22ecb084d345fc26f42b5143c5533db5e2eaba7d2d335358876", size = 97124, upload-time = "2025-10-08T22:01:15.629Z" }, + { url = "https://files.pythonhosted.org/packages/06/1e/f22f100db15a68b520664eb3328fb0ae4e90530887928558112c8d1f4515/tomli-2.3.0-cp312-cp312-win_amd64.whl", hash = "sha256:1cb4ed918939151a03f33d4242ccd0aa5f11b3547d0cf30f7c74a408a5b99878", size = 107698, upload-time = "2025-10-08T22:01:16.51Z" }, + { url = "https://files.pythonhosted.org/packages/89/48/06ee6eabe4fdd9ecd48bf488f4ac783844fd777f547b8d1b61c11939974e/tomli-2.3.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:5192f562738228945d7b13d4930baffda67b69425a7f0da96d360b0a3888136b", size = 154819, upload-time = "2025-10-08T22:01:17.964Z" }, + { url = "https://files.pythonhosted.org/packages/f1/01/88793757d54d8937015c75dcdfb673c65471945f6be98e6a0410fba167ed/tomli-2.3.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:be71c93a63d738597996be9528f4abe628d1adf5e6eb11607bc8fe1a510b5dae", size = 148766, upload-time = "2025-10-08T22:01:18.959Z" }, + { url = "https://files.pythonhosted.org/packages/42/17/5e2c956f0144b812e7e107f94f1cc54af734eb17b5191c0bbfb72de5e93e/tomli-2.3.0-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:c4665508bcbac83a31ff8ab08f424b665200c0e1e645d2bd9ab3d3e557b6185b", size = 240771, upload-time = "2025-10-08T22:01:20.106Z" }, + { url = "https://files.pythonhosted.org/packages/d5/f4/0fbd014909748706c01d16824eadb0307115f9562a15cbb012cd9b3512c5/tomli-2.3.0-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:4021923f97266babc6ccab9f5068642a0095faa0a51a246a6a02fccbb3514eaf", size = 248586, upload-time = "2025-10-08T22:01:21.164Z" }, + { url = "https://files.pythonhosted.org/packages/30/77/fed85e114bde5e81ecf9bc5da0cc69f2914b38f4708c80ae67d0c10180c5/tomli-2.3.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:a4ea38c40145a357d513bffad0ed869f13c1773716cf71ccaa83b0fa0cc4e42f", size = 244792, upload-time = "2025-10-08T22:01:22.417Z" }, + { url = "https://files.pythonhosted.org/packages/55/92/afed3d497f7c186dc71e6ee6d4fcb0acfa5f7d0a1a2878f8beae379ae0cc/tomli-2.3.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:ad805ea85eda330dbad64c7ea7a4556259665bdf9d2672f5dccc740eb9d3ca05", size = 248909, upload-time = "2025-10-08T22:01:23.859Z" }, + { url = "https://files.pythonhosted.org/packages/f8/84/ef50c51b5a9472e7265ce1ffc7f24cd4023d289e109f669bdb1553f6a7c2/tomli-2.3.0-cp313-cp313-win32.whl", hash = "sha256:97d5eec30149fd3294270e889b4234023f2c69747e555a27bd708828353ab606", size = 96946, upload-time = "2025-10-08T22:01:24.893Z" }, + { url = 
"https://files.pythonhosted.org/packages/b2/b7/718cd1da0884f281f95ccfa3a6cc572d30053cba64603f79d431d3c9b61b/tomli-2.3.0-cp313-cp313-win_amd64.whl", hash = "sha256:0c95ca56fbe89e065c6ead5b593ee64b84a26fca063b5d71a1122bf26e533999", size = 107705, upload-time = "2025-10-08T22:01:26.153Z" }, + { url = "https://files.pythonhosted.org/packages/19/94/aeafa14a52e16163008060506fcb6aa1949d13548d13752171a755c65611/tomli-2.3.0-cp314-cp314-macosx_10_13_x86_64.whl", hash = "sha256:cebc6fe843e0733ee827a282aca4999b596241195f43b4cc371d64fc6639da9e", size = 154244, upload-time = "2025-10-08T22:01:27.06Z" }, + { url = "https://files.pythonhosted.org/packages/db/e4/1e58409aa78eefa47ccd19779fc6f36787edbe7d4cd330eeeedb33a4515b/tomli-2.3.0-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:4c2ef0244c75aba9355561272009d934953817c49f47d768070c3c94355c2aa3", size = 148637, upload-time = "2025-10-08T22:01:28.059Z" }, + { url = "https://files.pythonhosted.org/packages/26/b6/d1eccb62f665e44359226811064596dd6a366ea1f985839c566cd61525ae/tomli-2.3.0-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:c22a8bf253bacc0cf11f35ad9808b6cb75ada2631c2d97c971122583b129afbc", size = 241925, upload-time = "2025-10-08T22:01:29.066Z" }, + { url = "https://files.pythonhosted.org/packages/70/91/7cdab9a03e6d3d2bb11beae108da5bdc1c34bdeb06e21163482544ddcc90/tomli-2.3.0-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:0eea8cc5c5e9f89c9b90c4896a8deefc74f518db5927d0e0e8d4a80953d774d0", size = 249045, upload-time = "2025-10-08T22:01:31.98Z" }, + { url = "https://files.pythonhosted.org/packages/15/1b/8c26874ed1f6e4f1fcfeb868db8a794cbe9f227299402db58cfcc858766c/tomli-2.3.0-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:b74a0e59ec5d15127acdabd75ea17726ac4c5178ae51b85bfe39c4f8a278e879", size = 245835, upload-time = "2025-10-08T22:01:32.989Z" }, + { url = "https://files.pythonhosted.org/packages/fd/42/8e3c6a9a4b1a1360c1a2a39f0b972cef2cc9ebd56025168c4137192a9321/tomli-2.3.0-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:b5870b50c9db823c595983571d1296a6ff3e1b88f734a4c8f6fc6188397de005", size = 253109, upload-time = "2025-10-08T22:01:34.052Z" }, + { url = "https://files.pythonhosted.org/packages/22/0c/b4da635000a71b5f80130937eeac12e686eefb376b8dee113b4a582bba42/tomli-2.3.0-cp314-cp314-win32.whl", hash = "sha256:feb0dacc61170ed7ab602d3d972a58f14ee3ee60494292d384649a3dc38ef463", size = 97930, upload-time = "2025-10-08T22:01:35.082Z" }, + { url = "https://files.pythonhosted.org/packages/b9/74/cb1abc870a418ae99cd5c9547d6bce30701a954e0e721821df483ef7223c/tomli-2.3.0-cp314-cp314-win_amd64.whl", hash = "sha256:b273fcbd7fc64dc3600c098e39136522650c49bca95df2d11cf3b626422392c8", size = 107964, upload-time = "2025-10-08T22:01:36.057Z" }, + { url = "https://files.pythonhosted.org/packages/54/78/5c46fff6432a712af9f792944f4fcd7067d8823157949f4e40c56b8b3c83/tomli-2.3.0-cp314-cp314t-macosx_10_13_x86_64.whl", hash = "sha256:940d56ee0410fa17ee1f12b817b37a4d4e4dc4d27340863cc67236c74f582e77", size = 163065, upload-time = "2025-10-08T22:01:37.27Z" }, + { url = "https://files.pythonhosted.org/packages/39/67/f85d9bd23182f45eca8939cd2bc7050e1f90c41f4a2ecbbd5963a1d1c486/tomli-2.3.0-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:f85209946d1fe94416debbb88d00eb92ce9cd5266775424ff81bc959e001acaf", size = 159088, upload-time = "2025-10-08T22:01:38.235Z" }, + { url = 
"https://files.pythonhosted.org/packages/26/5a/4b546a0405b9cc0659b399f12b6adb750757baf04250b148d3c5059fc4eb/tomli-2.3.0-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:a56212bdcce682e56b0aaf79e869ba5d15a6163f88d5451cbde388d48b13f530", size = 268193, upload-time = "2025-10-08T22:01:39.712Z" }, + { url = "https://files.pythonhosted.org/packages/42/4f/2c12a72ae22cf7b59a7fe75b3465b7aba40ea9145d026ba41cb382075b0e/tomli-2.3.0-cp314-cp314t-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:c5f3ffd1e098dfc032d4d3af5c0ac64f6d286d98bc148698356847b80fa4de1b", size = 275488, upload-time = "2025-10-08T22:01:40.773Z" }, + { url = "https://files.pythonhosted.org/packages/92/04/a038d65dbe160c3aa5a624e93ad98111090f6804027d474ba9c37c8ae186/tomli-2.3.0-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:5e01decd096b1530d97d5d85cb4dff4af2d8347bd35686654a004f8dea20fc67", size = 272669, upload-time = "2025-10-08T22:01:41.824Z" }, + { url = "https://files.pythonhosted.org/packages/be/2f/8b7c60a9d1612a7cbc39ffcca4f21a73bf368a80fc25bccf8253e2563267/tomli-2.3.0-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:8a35dd0e643bb2610f156cca8db95d213a90015c11fee76c946aa62b7ae7e02f", size = 279709, upload-time = "2025-10-08T22:01:43.177Z" }, + { url = "https://files.pythonhosted.org/packages/7e/46/cc36c679f09f27ded940281c38607716c86cf8ba4a518d524e349c8b4874/tomli-2.3.0-cp314-cp314t-win32.whl", hash = "sha256:a1f7f282fe248311650081faafa5f4732bdbfef5d45fe3f2e702fbc6f2d496e0", size = 107563, upload-time = "2025-10-08T22:01:44.233Z" }, + { url = "https://files.pythonhosted.org/packages/84/ff/426ca8683cf7b753614480484f6437f568fd2fda2edbdf57a2d3d8b27a0b/tomli-2.3.0-cp314-cp314t-win_amd64.whl", hash = "sha256:70a251f8d4ba2d9ac2542eecf008b3c8a9fc5c3f9f02c56a9d7952612be2fdba", size = 119756, upload-time = "2025-10-08T22:01:45.234Z" }, + { url = "https://files.pythonhosted.org/packages/77/b8/0135fadc89e73be292b473cb820b4f5a08197779206b33191e801feeae40/tomli-2.3.0-py3-none-any.whl", hash = "sha256:e95b1af3c5b07d9e643909b5abbec77cd9f1217e6d0bca72b0234736b9fb1f1b", size = 14408, upload-time = "2025-10-08T22:01:46.04Z" }, +] + [[package]] name = "torch" version = "2.7.1"