@@ -130,105 +130,6 @@ def allocate_batches_test(doc_lengths: list[int], N: int, workers: Optional[int]
130130 return allocation
131131
132132
133- # %%
134- def compute_covariance (
135- rank : int ,
136- model : PreTrainedModel ,
137- data : Dataset ,
138- batches_world : list [list [list [int ]]],
139- device : torch .device ,
140- target_modules : Any ,
141- activation_covariances : dict [str , Tensor ],
142- gradient_covariances : dict [str , Tensor ],
143- ) -> dict [str , Any ]:
144- """Compute activation and gradient covariances for a single worker."""
145- total_processed = 0
146- batches = batches_world [rank ]
147- loss_list = []
148-
149- collector = GroundTruthCovarianceCollector (
150- model = model .base_model ,
151- activation_covariances = activation_covariances ,
152- gradient_covariances = gradient_covariances ,
153- target_modules = target_modules ,
154- )
155-
156- for sl in tqdm (batches , desc = f"Rank { rank } covariances" ):
157- batch = data [sl ]
158- x , y = pad_and_tensor (
159- batch ["input_ids" ],
160- labels = batch .get ("labels" ),
161- device = device ,
162- )
163-
164- total_processed += x .numel ()
165-
166- with collector :
167- logits = model (x ).logits
168- losses = F .cross_entropy (
169- logits [:, :- 1 ].reshape (- 1 , logits .size (- 1 )),
170- y [:, 1 :].flatten (),
171- reduction = "none" ,
172- ).reshape_as (y [:, 1 :])
173-
174- losses = losses .sum (1 )
175- losses .mean ().backward ()
176- loss_list .append (losses .detach ().cpu ())
177- model .zero_grad ()
178-
179- return {"losses" : loss_list , "total_processed_rank" : total_processed }
180-
181-
182- # %%
183- def compute_eigenvalue_correction_amortized (
184- rank : int ,
185- model : PreTrainedModel ,
186- data : Dataset ,
187- batches_world : list [list [list [int ]]],
188- device : torch .device ,
189- target_modules : Any ,
190- eigenvalue_corrections : dict [str , Tensor ],
191- eigenvectors_activations : dict [str , Tensor ],
192- eigenvectors_gradients : dict [str , Tensor ],
193- ) -> dict [str , int ]:
194- """Compute eigenvalue corrections using the amortized method."""
195- total_processed = 0
196- batches = batches_world [rank ]
197-
198- collector = GroundTruthAmortizedLambdaCollector (
199- model = model .base_model ,
200- eigenvalue_corrections = eigenvalue_corrections ,
201- eigenvectors_activations = eigenvectors_activations ,
202- eigenvectors_gradients = eigenvectors_gradients ,
203- device = device ,
204- target_modules = target_modules ,
205- )
206-
207- for sl in tqdm (batches , desc = f"Rank { rank } eigenvalue corrections" ):
208- batch = data [sl ]
209- x , y = pad_and_tensor (
210- batch ["input_ids" ],
211- labels = batch .get ("labels" ),
212- device = device ,
213- )
214-
215- total_processed += x .numel ()
216-
217- with collector :
218- logits = model (x ).logits
219- losses = F .cross_entropy (
220- logits [:, :- 1 ].reshape (- 1 , logits .size (- 1 )),
221- y [:, 1 :].flatten (),
222- reduction = "none" ,
223- ).reshape_as (y [:, 1 :])
224-
225- losses = losses .sum (1 )
226- losses .mean ().backward ()
227- model .zero_grad ()
228-
229- return {"total_processed_rank" : total_processed }
230-
231-
232133# %% [markdown]
233134# ## 0. Hyperparameters
234135
@@ -418,6 +319,55 @@ def tokenize_and_allocate_step(
418319# ## 2. Compute activation and gradient covariance
419320
420321
322+ # %%
def compute_covariance(
    rank: int,
    model: PreTrainedModel,
    data: Dataset,
    batches_world: list[list[list[int]]],
    device: torch.device,
    target_modules: Any,
    activation_covariances: dict[str, Tensor],
    gradient_covariances: dict[str, Tensor],
) -> dict[str, Any]:
    """Accumulate activation and gradient covariances for one worker rank.

    Runs a forward/backward pass over every batch assigned to ``rank`` while
    the covariance collector's hooks are active, so ``activation_covariances``
    and ``gradient_covariances`` are updated in place as a side effect.

    Returns a dict with the detached per-sequence losses of each batch
    ("losses") and the number of input tokens, padding included, that this
    rank processed ("total_processed_rank").
    """
    token_count = 0
    per_batch_losses: list[Tensor] = []

    collector = GroundTruthCovarianceCollector(
        model=model.base_model,
        activation_covariances=activation_covariances,
        gradient_covariances=gradient_covariances,
        target_modules=target_modules,
    )

    for indices in tqdm(batches_world[rank], desc=f"Rank {rank} covariances"):
        examples = data[indices]
        inputs, targets = pad_and_tensor(
            examples["input_ids"],
            labels=examples.get("labels"),
            device=device,
        )
        token_count += inputs.numel()

        # Hooks must be live for both the forward and the backward pass.
        with collector:
            logits = model(inputs).logits
            # Next-token objective: position t predicts token t+1.
            token_losses = F.cross_entropy(
                logits[:, :-1].reshape(-1, logits.size(-1)),
                targets[:, 1:].flatten(),
                reduction="none",
            ).reshape_as(targets[:, 1:])
            seq_losses = token_losses.sum(1)
            seq_losses.mean().backward()
            per_batch_losses.append(seq_losses.detach().cpu())
            model.zero_grad()

    return {"losses": per_batch_losses, "total_processed_rank": token_count}
369+
370+
421371# %%
422372def compute_covariances_step (
423373 model : PreTrainedModel ,
@@ -567,6 +517,56 @@ def compute_eigenvectors_step(test_path: str, device: torch.device, dtype: torch
567517# ## 4. Compute eigenvalue correction
568518
569519
520+ # %%
def compute_eigenvalue_correction_amortized(
    rank: int,
    model: PreTrainedModel,
    data: Dataset,
    batches_world: list[list[list[int]]],
    device: torch.device,
    target_modules: Any,
    eigenvalue_corrections: dict[str, Tensor],
    eigenvectors_activations: dict[str, Tensor],
    eigenvectors_gradients: dict[str, Tensor],
) -> dict[str, int]:
    """Accumulate amortized eigenvalue corrections for one worker rank.

    Runs a forward/backward pass over every batch assigned to ``rank`` while
    the amortized-lambda collector's hooks are active; the collector updates
    ``eigenvalue_corrections`` in place using the supplied eigenvector bases.

    Returns a dict with the number of input tokens, padding included, that
    this rank processed ("total_processed_rank").
    """
    token_count = 0

    collector = GroundTruthAmortizedLambdaCollector(
        model=model.base_model,
        eigenvalue_corrections=eigenvalue_corrections,
        eigenvectors_activations=eigenvectors_activations,
        eigenvectors_gradients=eigenvectors_gradients,
        device=device,
        target_modules=target_modules,
    )

    for indices in tqdm(
        batches_world[rank], desc=f"Rank {rank} eigenvalue corrections"
    ):
        examples = data[indices]
        inputs, targets = pad_and_tensor(
            examples["input_ids"],
            labels=examples.get("labels"),
            device=device,
        )
        token_count += inputs.numel()

        # Hooks must be live for both the forward and the backward pass.
        with collector:
            logits = model(inputs).logits
            # Next-token objective: position t predicts token t+1.
            token_losses = F.cross_entropy(
                logits[:, :-1].reshape(-1, logits.size(-1)),
                targets[:, 1:].flatten(),
                reduction="none",
            ).reshape_as(targets[:, 1:])
            seq_losses = token_losses.sum(1)
            seq_losses.mean().backward()
            model.zero_grad()

    return {"total_processed_rank": token_count}
568+
569+
570570# %%
571571def compute_eigenvalue_corrections_step (
572572 model : PreTrainedModel ,
0 commit comments