Commit 96bdf12
Handle label imbalance in binary classification tasks on text benchmark (#376)

Labels in the text benchmarks are imbalanced, and weighting the positive labels improves performance. Experiments were run on the `fake` dataset (5% positive labels) with `text_embedded` and `RoBERTa` encodings:

- `ResNet`: 91.1% -> 93.4%
- `FTTransformer`: unchanged
- `Trompt`: 95.2% -> 95.8%

The differences were even starker with distilled RoBERTa, but we aren't reporting those anywhere, so I didn't note them down. More results are pending.

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 893678f · commit 96bdf12

File tree

1 file changed (+2, -1)

benchmark/data_frame_text_benchmark.py

Lines changed: 2 additions & 1 deletion
```diff
@@ -457,7 +457,8 @@ def main_torch(
     if dataset.task_type == TaskType.BINARY_CLASSIFICATION:
         out_channels = 1
-        loss_fun = BCEWithLogitsLoss()
+        label_imbalance = sum(train_tensor_frame.y) / len(train_tensor_frame.y)
+        loss_fun = BCEWithLogitsLoss(pos_weight=1 / label_imbalance)
         metric_computer = AUROC(task='binary').to(device)
         higher_is_better = True
     elif dataset.task_type == TaskType.MULTICLASS_CLASSIFICATION:
```
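For context, the fix relies on the `pos_weight` argument of PyTorch's `BCEWithLogitsLoss`, which scales the loss contribution of positive targets. Below is a minimal, self-contained sketch of the same weighting; the toy labels are a hypothetical stand-in for `train_tensor_frame.y`:

```python
import torch
from torch.nn import BCEWithLogitsLoss

# Toy labels standing in for train_tensor_frame.y (1 positive out of 5).
y = torch.tensor([1.0, 0.0, 0.0, 0.0, 0.0])

# Fraction of positive labels: 0.2 here, roughly 0.05 on the `fake` dataset.
label_imbalance = y.sum() / len(y)

# pos_weight multiplies the loss on positive targets, so rare positives
# contribute to the total loss on par with the abundant negatives.
loss_fun = BCEWithLogitsLoss(pos_weight=1 / label_imbalance)

logits = torch.zeros(5)     # raw (pre-sigmoid) model outputs
print(loss_fun(logits, y))  # weighted mean BCE over the batch
```

With 5% positives this gives `pos_weight` of about 20, so each positive example carries roughly the weight of the ~19 negatives around it; the closely related convention `pos_weight = num_negatives / num_positives` yields nearly the same value when positives are rare.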
