@@ -263,8 +263,14 @@ def teardown(self) -> None:
         os.makedirs(gradient_path, exist_ok=True)
 
         # Save sharded covariance matrices
-        save_file(self.A_cov_dict, os.path.join(activation_path, f"shard_{self.rank}.safetensors"))
-        save_file(self.S_cov_dict, os.path.join(gradient_path, f"shard_{self.rank}.safetensors"))
+        save_file(
+            self.A_cov_dict,
+            os.path.join(activation_path, f"shard_{self.rank}.safetensors"),
+        )
+        save_file(
+            self.S_cov_dict,
+            os.path.join(gradient_path, f"shard_{self.rank}.safetensors"),
+        )
 
 
 @dataclass(kw_only=True)
@@ -286,11 +292,15 @@ def setup(self) -> None:
286292 """Load eigenvectors and initialize storage."""
287293 # Load precomputed eigenvectors
288294 self .eigen_a = load_file (
289- os .path .join (self .path , f"activation_eigen_sharded/shard_{ self .rank } .safetensors" ),
295+ os .path .join (
296+ self .path , f"activation_eigen_sharded/shard_{ self .rank } .safetensors"
297+ ),
290298 device = f"cuda:{ self .rank } " ,
291299 )
292300 self .eigen_g = load_file (
293- os .path .join (self .path , f"gradient_eigen_sharded/shard_{ self .rank } .safetensors" ),
301+ os .path .join (
302+ self .path , f"gradient_eigen_sharded/shard_{ self .rank } .safetensors"
303+ ),
294304 device = f"cuda:{ self .rank } " ,
295305 )
296306
@@ -303,7 +313,9 @@ def forward_hook(self, name: str, a: Tensor) -> None:
         # a shape: [N, S, I]
 
         # Transform: a @ eigen_a
-        transformed = self.shard_computer._matmul(vector_nsa=a, matrix_cb=self.eigen_a[name])  # shape [N, S, I]
+        transformed = self.shard_computer._matmul(
+            vector_nsa=a, matrix_cb=self.eigen_a[name]
+        )  # shape [N, S, I]
 
         # Cache for use in backward pass
         self.transformed_a_cache[name] = transformed
@@ -313,11 +325,15 @@ def backward_hook(self, name: str, g: Tensor) -> None:
         # g shape: [N, S, O]
 
         # Transform: g @ eigen_g
-        transformed_g = self.shard_computer._matmul(vector_nsa=g, matrix_cb=self.eigen_g[name])  # shape [N, S, O]
+        transformed_g = self.shard_computer._matmul(
+            vector_nsa=g, matrix_cb=self.eigen_g[name]
+        )  # shape [N, S, O]
 
         # Compute outer product: sum_n (transformed_a_n^T @ transformed_g_n)
         # Einstein notation: [N, S, I] x [N, S, O] -> [N, O, I]
-        transformed_grad_shard = torch.einsum("N S I, N S O -> N O I", self.transformed_a_cache[name], transformed_g)
+        transformed_grad_shard = torch.einsum(
+            "N S I, N S O -> N O I", self.transformed_a_cache[name], transformed_g
+        )
 
         # Square and sum over batch
         transformed_grad_shard = (transformed_grad_shard**2).sum(dim=0).contiguous()
@@ -333,15 +349,26 @@ def backward_hook(self, name: str, g: Tensor) -> None:
 
         # Accumulate (with CPU offloading for memory efficiency)
         if name not in self.eigenvalue_corrections:
-            self.eigenvalue_corrections[name] = transformed_grad_shard[start_row:end_row, :].contiguous()
+            self.eigenvalue_corrections[name] = transformed_grad_shard[
+                start_row:end_row, :
+            ].contiguous()
         else:
-            self.eigenvalue_corrections[name] = self.eigenvalue_corrections[name].to(device=self.device)
-            self.eigenvalue_corrections[name].add_(transformed_grad_shard[start_row:end_row, :].contiguous())
-            self.eigenvalue_corrections[name] = self.eigenvalue_corrections[name].to(device="cpu", non_blocking=False)
+            self.eigenvalue_corrections[name] = self.eigenvalue_corrections[name].to(
+                device=self.device
+            )
+            self.eigenvalue_corrections[name].add_(
+                transformed_grad_shard[start_row:end_row, :].contiguous()
+            )
+            self.eigenvalue_corrections[name] = self.eigenvalue_corrections[name].to(
+                device="cpu", non_blocking=False
+            )
 
     def teardown(self) -> None:
         """Save eigenvalue corrections to disk."""
         output_path = os.path.join(self.path, "eigenvalue_correction_sharded")
         os.makedirs(output_path, exist_ok=True)
 
-        save_file(self.eigenvalue_corrections, os.path.join(output_path, f"shard_{self.rank}.safetensors"))
+        save_file(
+            self.eigenvalue_corrections,
+            os.path.join(output_path, f"shard_{self.rank}.safetensors"),
+        )