We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
2 parents 4552e91 + 96bdf12 commit 1706c96Copy full SHA for 1706c96
benchmark/data_frame_text_benchmark.py
@@ -457,7 +457,8 @@ def main_torch(
457
458
if dataset.task_type == TaskType.BINARY_CLASSIFICATION:
459
out_channels = 1
460
- loss_fun = BCEWithLogitsLoss()
+ label_imbalance = sum(train_tensor_frame.y) / len(train_tensor_frame.y)
461
+ loss_fun = BCEWithLogitsLoss(pos_weight=1 / label_imbalance)
462
metric_computer = AUROC(task='binary').to(device)
463
higher_is_better = True
464
elif dataset.task_type == TaskType.MULTICLASS_CLASSIFICATION:
0 commit comments