38 changes: 38 additions & 0 deletions tests/metrics/classification/test_calibration.py
@@ -124,3 +124,41 @@ def test_main(self) -> None:
def test_errors(self) -> None:
with pytest.raises(TypeError, match="is expected to be `int`"):
AdaptiveCalibrationError(task="multiclass", num_classes=None)

def test_plot_binary(self) -> None:
"""Test plot functionality for binary adaptive calibration error."""
metric = AdaptiveCalibrationError(task="binary", num_bins=2, norm="l1")
metric.update(
torch.as_tensor([0.25, 0.25, 0.55, 0.75, 0.75]),
torch.as_tensor([0, 0, 1, 1, 1]),
)
fig, ax = metric.plot()
assert isinstance(fig, plt.Figure)
assert ax[0].get_xlabel() == "Top-class Confidence (%)"
assert ax[0].get_ylabel() == "Success Rate (%)"
assert ax[1].get_xlabel() == "Top-class Confidence (%)"
assert ax[1].get_ylabel() == "Density (%)"

plt.close(fig)

def test_plot_multiclass(self) -> None:
"""Test plot functionality for multiclass adaptive calibration error."""
metric = AdaptiveCalibrationError(task="multiclass", num_bins=3, norm="l1", num_classes=3)
metric.update(
torch.as_tensor(
[
[0.25, 0.20, 0.55],
[0.55, 0.05, 0.40],
[0.10, 0.30, 0.60],
[0.90, 0.05, 0.05],
]
),
torch.as_tensor([0, 1, 2, 0]),
)
fig, ax = metric.plot()
assert isinstance(fig, plt.Figure)
assert ax[0].get_xlabel() == "Top-class Confidence (%)"
assert ax[0].get_ylabel() == "Success Rate (%)"
assert ax[1].get_xlabel() == "Top-class Confidence (%)"
assert ax[1].get_ylabel() == "Density (%)"
plt.close(fig)
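
The two tests above double as usage documentation. For readers who want the same flow outside pytest, here is a minimal standalone sketch; the import path is an assumption (the diff does not show the package root), and everything else mirrors the binary test.

# Standalone sketch of the plotting flow exercised by the tests above.
# NOTE: the import path below is assumed -- it is not shown in this diff.
import matplotlib.pyplot as plt
import torch

from torch_uncertainty.metrics import AdaptiveCalibrationError  # assumed path

metric = AdaptiveCalibrationError(task="binary", num_bins=2, norm="l1")
metric.update(
    torch.as_tensor([0.25, 0.25, 0.55, 0.75, 0.75]),  # top-class confidences
    torch.as_tensor([0, 0, 1, 1, 1]),                 # binary targets
)
fig, ax = metric.plot()  # fig: plt.Figure, ax: array of two Axes
fig.savefig("reliability_diagram.png")
plt.close(fig)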
@@ -10,6 +10,7 @@
from torchmetrics.metric import Metric
from torchmetrics.utilities.data import dim_zero_cat
from torchmetrics.utilities.enums import ClassificationTaskNoMultilabel
from torchmetrics.utilities.plot import _PLOT_OUT_TYPE


def _equal_binning_bucketize(
@@ -131,6 +132,44 @@ def compute(self) -> Tensor:
accuracies = dim_zero_cat(self.accuracies)
return _ace_compute(confidences, accuracies, self.n_bins, norm=self.norm)

def plot(self) -> _PLOT_OUT_TYPE:
"""Plot the adaptive calibration reliability diagram."""
# Import here to avoid circular dependency
from .calibration_error import reliability_chart

confidences = dim_zero_cat(self.confidences)
accuracies = dim_zero_cat(self.accuracies)

with torch.no_grad():
acc_bin, conf_bin, prop_bin = _equal_binning_bucketize(
confidences, accuracies, self.n_bins
)

np_acc_bin = acc_bin.cpu().numpy()
np_conf_bin = conf_bin.cpu().numpy()
np_prop_bin = prop_bin.cpu().numpy()

        # Uniform bin boundaries are used only for the chart's x-axis; the
        # metric value itself is computed with adaptive (equal-count) binning.
bin_boundaries = torch.linspace(
0,
1,
self.n_bins + 1,
dtype=torch.float,
device=confidences.device,
)[:-1] # Remove the last boundary to match the number of bins

np_bin_boundaries = bin_boundaries.cpu().numpy()

return reliability_chart(
accuracies=accuracies.cpu().numpy(),
confidences=confidences.cpu().numpy(),
bin_accuracies=np_acc_bin,
bin_confidences=np_conf_bin,
bin_sizes=np_prop_bin,
bins=np_bin_boundaries,
)


class MulticlassAdaptiveCalibrationError(Metric):
is_differentiable: bool = False
@@ -175,6 +214,44 @@ def compute(self) -> Tensor:
accuracies = dim_zero_cat(self.accuracies)
return _ace_compute(confidences, accuracies, self.n_bins, norm=self.norm)

def plot(self) -> _PLOT_OUT_TYPE:
"""Plot the adaptive calibration reliability diagram."""
# Import here to avoid circular dependency
from .calibration_error import reliability_chart

confidences = dim_zero_cat(self.confidences)
accuracies = dim_zero_cat(self.accuracies)

with torch.no_grad():
acc_bin, conf_bin, prop_bin = _equal_binning_bucketize(
confidences, accuracies, self.n_bins
)

np_acc_bin = acc_bin.cpu().numpy()
np_conf_bin = conf_bin.cpu().numpy()
np_prop_bin = prop_bin.cpu().numpy()

        # Uniform bin boundaries are used only for the chart's x-axis; the
        # metric value itself is computed with adaptive (equal-count) binning.
bin_boundaries = torch.linspace(
0,
1,
self.n_bins + 1,
dtype=torch.float,
device=confidences.device,
)[:-1] # Remove the last boundary to match the number of bins

np_bin_boundaries = bin_boundaries.cpu().numpy()

return reliability_chart(
accuracies=accuracies.cpu().numpy(),
confidences=confidences.cpu().numpy(),
bin_accuracies=np_acc_bin,
bin_confidences=np_conf_bin,
bin_sizes=np_prop_bin,
bins=np_bin_boundaries,
)


class AdaptiveCalibrationError:
def __new__(
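
Both plot() implementations lean on _equal_binning_bucketize, whose body is collapsed in this diff. For orientation, the sketch below shows the general equal-count ("adaptive") binning technique that ACE-style metrics use: sort samples by confidence, split them into num_bins near-equal groups, and average within each group; the L1 ACE is then the mean of |acc_b - conf_b| over the bins. This is a generic sketch under those assumptions, not the repository's exact implementation.

# Generic sketch of equal-count (adaptive) binning for calibration metrics;
# not the repository's exact _equal_binning_bucketize.
import torch

def equal_binning_sketch(
    confidences: torch.Tensor, accuracies: torch.Tensor, num_bins: int
):
    # Sort by confidence, then split into near-equal chunks (adaptive bins).
    order = confidences.argsort()
    conf_chunks = confidences[order].tensor_split(num_bins)
    acc_chunks = accuracies[order].float().tensor_split(num_bins)
    acc_bin = torch.stack([chunk.mean() for chunk in acc_chunks])
    conf_bin = torch.stack([chunk.mean() for chunk in conf_chunks])
    # Proportion of samples per bin -- the analogue of prop_bin / bin_sizes above.
    prop_bin = (
        torch.as_tensor([chunk.numel() for chunk in conf_chunks])
        / confidences.numel()
    )
    return acc_bin, conf_bin, prop_bin

acc_bin, conf_bin, prop_bin = equal_binning_sketch(
    torch.as_tensor([0.25, 0.25, 0.55, 0.75, 0.75]),
    torch.as_tensor([0, 0, 1, 1, 1]),
    num_bins=2,
)
ace_l1 = (acc_bin - conf_bin).abs().mean()  # L1 adaptive calibration error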