-from __future__ import annotations
-
 import gc
 import os
 import shutil
 from abc import ABC
 from collections import defaultdict
-from typing import TYPE_CHECKING, Any, Callable, Literal
+from collections.abc import Callable, Mapping
+from typing import Any, Literal
 
 import numpy as np
 import torch
 import wandb
 from sklearn.metrics import f1_score
 from torch import BoolTensor, Tensor, nn
 from torch.nn.functional import softmax
+from torch.utils.data import DataLoader
 from torch.utils.tensorboard import SummaryWriter
 from tqdm import tqdm
 
 from aviary import ROOT
-
-if TYPE_CHECKING:
-    from collections.abc import Mapping
-
-    from torch.utils.data import DataLoader
-
-    from aviary.data import InMemoryDataLoader
+from aviary.data import InMemoryDataLoader, Normalizer
 
 TaskType = Literal["regression", "classification"]
 
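Two separate cleanups land in this import hunk: `Callable` and `Mapping` now come from `collections.abc` instead of `typing` (the `typing` aliases are deprecated since Python 3.9), and the `if TYPE_CHECKING:` block is dissolved because `DataLoader`, `InMemoryDataLoader`, and `Normalizer` are now needed at runtime rather than only in annotations. Dropping `from __future__ import annotations` points the same way, since annotations are then evaluated eagerly and the names they mention must actually exist. A minimal illustration of the distinction (hypothetical code, not from this diff):

    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        # visible to the type checker only; never imported at runtime
        from torch.utils.data import DataLoader

    def count_batches(loader: "DataLoader") -> int:
        # the quoted annotation is fine, but any runtime use such as
        # isinstance(loader, DataLoader) would raise NameError, which is
        # what forces runtime-used imports up to module level
        return len(loader)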
@@ -129,6 +123,14 @@ def fit(
                         for metric, val in metrics.items():
                             writer.add_scalar(f"{task}/train/{metric}", val, epoch)
 
+                if writer == "wandb":
+                    flat_train_metrics = {}
+                    for task, metrics in train_metrics.items():
+                        for metric, val in metrics.items():
+                            flat_train_metrics[f"train_{task}_{metric.lower()}"] = val
+                    flat_train_metrics["epoch"] = epoch
+                    wandb.log(flat_train_metrics)
+
                 # Validation
                 if val_loader is not None:
                     with torch.no_grad():
@@ -149,6 +151,14 @@ def fit(
                                     f"{task}/validation/{metric}", val, epoch
                                 )
 
+                    if writer == "wandb":
+                        flat_val_metrics = {}
+                        for task, metrics in val_metrics.items():
+                            for metric, val in metrics.items():
+                                flat_val_metrics[f"val_{task}_{metric.lower()}"] = val
+                        flat_val_metrics["epoch"] = epoch
+                        wandb.log(flat_val_metrics)
+
                     # TODO test all tasks to see if they are best,
                     # save a best model if any is best.
                     # TODO what are the costs of this approach.
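These two hunks (train and validation) flatten the nested `{task: {metric: value}}` dicts into single-level keys before calling `wandb.log`, tagging each entry with the epoch; a later hunk in `fit` removes the old call that logged the nested dicts wholesale. Flat keys such as `train_regression_mae` give each metric its own chart in the wandb UI instead of one opaque nested object. Roughly, the transformation looks like this (example values invented):

    train_metrics = {"regression": {"Loss": 0.42, "MAE": 0.13}}
    flat = {
        f"train_{task}_{metric.lower()}": val
        for task, metrics in train_metrics.items()
        for metric, val in metrics.items()
    }
    flat["epoch"] = 7
    # flat == {"train_regression_loss": 0.42, "train_regression_mae": 0.13, "epoch": 7}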
@@ -207,9 +217,6 @@ def fit(
                 # catch memory leak
                 gc.collect()
 
-                if writer == "wandb":
-                    wandb.log({"train": train_metrics, "validation": val_metrics})
-
         except KeyboardInterrupt:
             pass
 
@@ -271,7 +278,11 @@ def evaluate(
             mixed_loss: Tensor = 0  # type: ignore[assignment]
 
             for target_name, targets, output, normalizer in zip(
-                self.target_names, targets_list, outputs, normalizer_dict.values()
+                self.target_names,
+                targets_list,
+                outputs,
+                normalizer_dict.values(),
+                strict=False,
             ):
                 task, loss_func = loss_dict[target_name]
                 target_metrics = epoch_metrics[target_name]
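`zip(strict=...)` was added in Python 3.10 (PEP 618): `strict=True` raises `ValueError` when the iterables have unequal lengths, while `strict=False` spells out the classic silent-truncation default. Writing it explicitly, here and in the `predict` hunks below, is the usual way to satisfy a linter rule such as ruff's B905 while documenting that truncation is acceptable. For example:

    pairs = list(zip([1, 2, 3], ["a", "b"], strict=False))
    # pairs == [(1, "a"), (2, "b")]  (the extra element is dropped silently)

    list(zip([1, 2, 3], ["a", "b"], strict=True))  # raises ValueError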
@@ -318,7 +329,7 @@ def evaluate(
                 else:
                     raise ValueError(f"invalid task: {task}")
 
-                epoch_metrics[target_name]["Loss"].append(loss.cpu().item())
+                target_metrics["Loss"].append(loss.cpu().item())
 
                 # NOTE multitasking currently just uses a direct sum of individual
                 # target losses this should be okay but is perhaps sub-optimal
@@ -396,11 +407,13 @@ def predict(
         # for multitask learning
         targets = tuple(
             torch.cat(targets, dim=0).view(-1).cpu().numpy()
-            for targets in zip(*test_targets)
+            for targets in zip(*test_targets, strict=False)
+        )
+        predictions = tuple(
+            torch.cat(preds, dim=0) for preds in zip(*test_preds, strict=False)
         )
-        predictions = tuple(torch.cat(preds, dim=0) for preds in zip(*test_preds))
         # identifier columns
-        ids = tuple(np.concatenate(x) for x in zip(*test_ids))
+        ids = tuple(np.concatenate(x) for x in zip(*test_ids, strict=False))
         return targets, predictions, ids
 
     @torch.no_grad()
@@ -445,83 +458,6 @@ def __repr__(self) -> str:
         return f"{cls_name} with {n_params:,} trainable params at {n_epochs:,} epochs"
 
 
-class Normalizer:
-    """Normalize a Tensor and restore it later."""
-
-    def __init__(self) -> None:
-        """Initialize Normalizer with mean 0 and std 1."""
-        self.mean = torch.tensor(0)
-        self.std = torch.tensor(1)
-
-    def fit(self, tensor: Tensor, dim: int = 0, keepdim: bool = False) -> None:
-        """Compute the mean and standard deviation of the given tensor.
-
-        Args:
-            tensor (Tensor): Tensor to determine the mean and standard deviation over.
-            dim (int, optional): Which dimension to take mean and standard deviation
-                over. Defaults to 0.
-            keepdim (bool, optional): Whether to keep the reduced dimension in Tensor.
-                Defaults to False.
-        """
-        self.mean = torch.mean(tensor, dim, keepdim)
-        self.std = torch.std(tensor, dim, keepdim)
-
-    def norm(self, tensor: Tensor) -> Tensor:
-        """Normalize a Tensor.
-
-        Args:
-            tensor (Tensor): Tensor to be normalized
-
-        Returns:
-            Tensor: Normalized Tensor
-        """
-        return (tensor - self.mean) / self.std
-
-    def denorm(self, normed_tensor: Tensor) -> Tensor:
-        """Restore normalized Tensor to original.
-
-        Args:
-            normed_tensor (Tensor): Tensor to be restored
-
-        Returns:
-            Tensor: Restored Tensor
-        """
-        return normed_tensor * self.std + self.mean
-
-    def state_dict(self) -> dict[str, Tensor]:
-        """Get Normalizer parameters mean and std.
-
-        Returns:
-            dict[str, Tensor]: Dictionary storing Normalizer parameters.
-        """
-        return {"mean": self.mean, "std": self.std}
-
-    def load_state_dict(self, state_dict: dict[str, Tensor]) -> None:
-        """Overwrite Normalizer parameters given a new state_dict.
-
-        Args:
-            state_dict (dict[str, Tensor]): Dictionary storing Normalizer parameters.
-        """
-        self.mean = state_dict["mean"].cpu()
-        self.std = state_dict["std"].cpu()
-
-    @classmethod
-    def from_state_dict(cls, state_dict: dict[str, Tensor]) -> Normalizer:
-        """Create a new Normalizer given a state_dict.
-
-        Args:
-            state_dict (dict[str, Tensor]): Dictionary storing Normalizer parameters.
-
-        Returns:
-            Normalizer
-        """
-        instance = cls()
-        instance.mean = state_dict["mean"].cpu()
-        instance.std = state_dict["std"].cpu()
-
-        return instance
-
-
 def save_checkpoint(
     state: dict[str, Any], is_best: bool, model_name: str, run_id: int
 ) -> None:
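The `Normalizer` class is moved rather than deleted: per the import hunk at the top of the diff, it now lives in `aviary.data`, so existing callers keep working after updating their import. Based on the removed code, its interface is unchanged; a usage sketch:

    import torch
    from aviary.data import Normalizer

    targets = torch.randn(100)
    normalizer = Normalizer()
    normalizer.fit(targets)               # record mean/std along dim 0
    normed = normalizer.norm(targets)     # (x - mean) / std
    restored = normalizer.denorm(normed)  # x * std + mean

    # persists through checkpoints much like an nn.Module
    state = normalizer.state_dict()       # {"mean": ..., "std": ...}
    clone = Normalizer.from_state_dict(state)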
@@ -662,3 +598,12 @@ def masked_min(x: Tensor, mask: BoolTensor, dim: int = 0) -> Tensor:
     x_inf = x.float().masked_fill(~mask, float("inf"))
     x_min, _ = x_inf.min(dim=dim)
     return x_min
+
+
+AGGREGATORS: dict[str, Callable[[Tensor, BoolTensor, int], Tensor]] = {
+    "mean": masked_mean,
+    "std": masked_std,
+    "max": masked_max,
+    "min": masked_min,
+    "sum": lambda x, mask, dim: (x * mask).sum(dim=dim),
+}
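The new `AGGREGATORS` registry maps a name to a masked reduction with the shared `(x, mask, dim)` signature, letting callers pick a pooling operation by string instead of branching on it. Note the asymmetry: `"sum"` is an inline lambda that zeroes masked entries via `x * mask`, while the other entries delegate to the `masked_*` helpers defined above. A usage sketch, assuming the helpers handle an element-wise mask like `masked_min` above:

    import torch

    x = torch.randn(5, 8)          # 5 items, 8 features
    mask = torch.rand(5, 8) > 0.3  # BoolTensor; True marks valid entries

    pooled = AGGREGATORS["mean"](x, mask, 0)  # masked mean over dim 0, shape (8,)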