diff --git a/tests/metrics/classification/test_calibration.py b/tests/metrics/classification/test_calibration.py
index 615a3129..33a807a5 100644
--- a/tests/metrics/classification/test_calibration.py
+++ b/tests/metrics/classification/test_calibration.py
@@ -124,3 +124,41 @@ def test_main(self) -> None:
     def test_errors(self) -> None:
         with pytest.raises(TypeError, match="is expected to be `int`"):
             AdaptiveCalibrationError(task="multiclass", num_classes=None)
+
+    def test_plot_binary(self) -> None:
+        """Test plot functionality for binary adaptive calibration error."""
+        metric = AdaptiveCalibrationError(task="binary", num_bins=2, norm="l1")
+        metric.update(
+            torch.as_tensor([0.25, 0.25, 0.55, 0.75, 0.75]),
+            torch.as_tensor([0, 0, 1, 1, 1]),
+        )
+        fig, ax = metric.plot()
+        assert isinstance(fig, plt.Figure)
+        assert ax[0].get_xlabel() == "Top-class Confidence (%)"
+        assert ax[0].get_ylabel() == "Success Rate (%)"
+        assert ax[1].get_xlabel() == "Top-class Confidence (%)"
+        assert ax[1].get_ylabel() == "Density (%)"
+
+        plt.close(fig)
+
+    def test_plot_multiclass(self) -> None:
+        """Test plot functionality for multiclass adaptive calibration error."""
+        metric = AdaptiveCalibrationError(task="multiclass", num_bins=3, norm="l1", num_classes=3)
+        metric.update(
+            torch.as_tensor(
+                [
+                    [0.25, 0.20, 0.55],
+                    [0.55, 0.05, 0.40],
+                    [0.10, 0.30, 0.60],
+                    [0.90, 0.05, 0.05],
+                ]
+            ),
+            torch.as_tensor([0, 1, 2, 0]),
+        )
+        fig, ax = metric.plot()
+        assert isinstance(fig, plt.Figure)
+        assert ax[0].get_xlabel() == "Top-class Confidence (%)"
+        assert ax[0].get_ylabel() == "Success Rate (%)"
+        assert ax[1].get_xlabel() == "Top-class Confidence (%)"
+        assert ax[1].get_ylabel() == "Density (%)"
+        plt.close(fig)
diff --git a/torch_uncertainty/metrics/classification/adaptive_calibration_error.py b/torch_uncertainty/metrics/classification/adaptive_calibration_error.py
index 2a452858..4465c383 100644
--- a/torch_uncertainty/metrics/classification/adaptive_calibration_error.py
+++ b/torch_uncertainty/metrics/classification/adaptive_calibration_error.py
@@ -10,6 +10,7 @@
 from torchmetrics.metric import Metric
 from torchmetrics.utilities.data import dim_zero_cat
 from torchmetrics.utilities.enums import ClassificationTaskNoMultilabel
+from torchmetrics.utilities.plot import _PLOT_OUT_TYPE


 def _equal_binning_bucketize(
@@ -131,6 +132,44 @@ def compute(self) -> Tensor:
         accuracies = dim_zero_cat(self.accuracies)
         return _ace_compute(confidences, accuracies, self.n_bins, norm=self.norm)

+    def plot(self) -> _PLOT_OUT_TYPE:
+        """Plot the adaptive calibration reliability diagram."""
+        # Import here to avoid circular dependency
+        from .calibration_error import reliability_chart
+
+        confidences = dim_zero_cat(self.confidences)
+        accuracies = dim_zero_cat(self.accuracies)
+
+        with torch.no_grad():
+            acc_bin, conf_bin, prop_bin = _equal_binning_bucketize(
+                confidences, accuracies, self.n_bins
+            )
+
+        np_acc_bin = acc_bin.cpu().numpy()
+        np_conf_bin = conf_bin.cpu().numpy()
+        np_prop_bin = prop_bin.cpu().numpy()
+
+        # For visualization purposes, use uniform bin boundaries to match the plotting expectations
+        # Note: This is just for visualization - the actual adaptive computation uses adaptive binning
+        bin_boundaries = torch.linspace(
+            0,
+            1,
+            self.n_bins + 1,
+            dtype=torch.float,
+            device=confidences.device,
+        )[:-1]  # Remove the last boundary to match the number of bins
+
+        np_bin_boundaries = bin_boundaries.cpu().numpy()
+
+        return reliability_chart(
+            accuracies=accuracies.cpu().numpy(),
+            confidences=confidences.cpu().numpy(),
+            bin_accuracies=np_acc_bin,
+            bin_confidences=np_conf_bin,
+            bin_sizes=np_prop_bin,
+            bins=np_bin_boundaries,
+        )


 class MulticlassAdaptiveCalibrationError(Metric):
     is_differentiable: bool = False
@@ -175,6 +214,44 @@ def compute(self) -> Tensor:
         accuracies = dim_zero_cat(self.accuracies)
         return _ace_compute(confidences, accuracies, self.n_bins, norm=self.norm)

+    def plot(self) -> _PLOT_OUT_TYPE:
+        """Plot the adaptive calibration reliability diagram."""
+        # Import here to avoid circular dependency
+        from .calibration_error import reliability_chart
+
+        confidences = dim_zero_cat(self.confidences)
+        accuracies = dim_zero_cat(self.accuracies)
+
+        with torch.no_grad():
+            acc_bin, conf_bin, prop_bin = _equal_binning_bucketize(
+                confidences, accuracies, self.n_bins
+            )
+
+        np_acc_bin = acc_bin.cpu().numpy()
+        np_conf_bin = conf_bin.cpu().numpy()
+        np_prop_bin = prop_bin.cpu().numpy()
+
+        # For visualization purposes, use uniform bin boundaries to match the plotting expectations
+        # Note: This is just for visualization - the actual adaptive computation uses adaptive binning
+        bin_boundaries = torch.linspace(
+            0,
+            1,
+            self.n_bins + 1,
+            dtype=torch.float,
+            device=confidences.device,
+        )[:-1]  # Remove the last boundary to match the number of bins
+
+        np_bin_boundaries = bin_boundaries.cpu().numpy()
+
+        return reliability_chart(
+            accuracies=accuracies.cpu().numpy(),
+            confidences=confidences.cpu().numpy(),
+            bin_accuracies=np_acc_bin,
+            bin_confidences=np_conf_bin,
+            bin_sizes=np_prop_bin,
+            bins=np_bin_boundaries,
+        )


 class AdaptiveCalibrationError:
     def __new__(