endless_cl_sim semantic segmentation accuracy calculation bug #1688

🐛 Describe the bug
When working with the Endless Continual Learning Simulator (EndlessCLSim), specifically in the semantic segmentation scenario, integrating accuracy_metrics into the evaluation plugin as

    eval_plugin = EvaluationPlugin(
        accuracy_metrics(minibatch=True, epoch=True, experience=True, stream=True),
        loggers=logger,
    )

causes a crash, because semantic segmentation requires per-pixel checks, unlike traditional classification.

I checked the accuracy calculation in the source code; it uses

        # Check if logits or labels
        if len(predicted_y.shape) > 1:
            # Logits -> transform to labels
            predicted_y = torch.max(predicted_y, 1)[1]

        if len(true_y.shape) > 1:
            # Logits -> transform to labels
            true_y = torch.max(true_y, 1)[1]

        true_positives = float(torch.sum(torch.eq(predicted_y, true_y)))
        total_patterns = len(true_y)
        self._mean_accuracy.update(true_positives / total_patterns, total_patterns)

However, this assumes the labels are one-dimensional. In a semantic segmentation task, accuracy should be calculated per pixel.

An example showing why this does not work:

During training, the inputs are:
predicted_y.shape is [batch_size, num_classes, height, width]
true_y.shape is [batch_size, height, width]

The code above transforms:
predicted_y.shape to [batch_size, height, width]
true_y.shape to [batch_size, width]

This leads to a dimension-mismatch crash in the line

 true_positives = float(torch.sum(torch.eq(predicted_y, true_y)))
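
A tensor-only sketch of the mismatch (shapes here are arbitrary and chosen only for illustration):

    import torch

    batch_size, num_classes, height, width = 2, 8, 4, 6
    predicted_y = torch.randn(batch_size, num_classes, height, width)   # logits
    true_y = torch.randint(0, num_classes, (batch_size, height, width)) # per-pixel labels

    # What the current metric code does:
    predicted_y = torch.max(predicted_y, 1)[1]  # [2, 4, 6] -- correct per-pixel predictions
    true_y = torch.max(true_y, 1)[1]            # [2, 6]    -- labels wrongly reduced over height

    # torch.eq raises here: [2, 4, 6] and [2, 6] cannot be broadcast together
    true_positives = float(torch.sum(torch.eq(predicted_y, true_y)))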

🐜 To Reproduce
A minimal working example:

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam
from avalanche.benchmarks.classic import EndlessCLSim
from avalanche.training import Naive
from avalanche.training.plugins import EvaluationPlugin
from avalanche.evaluation.metrics import (
    forgetting_metrics,
    accuracy_metrics,
    loss_metrics,
    ram_usage_metrics,
    timing_metrics,
    MAC_metrics,
)
from avalanche.logging import InteractiveLogger, CSVLogger
from avalanche.models import pytorchcv_wrapper
import argparse
import random
import numpy as np

# Set seeds for reproducibility
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--cuda", type=int, default=0, help="Use CUDA device index; -1 for CPU")
    parser.add_argument("--semseg", action="store_true", default=True, help="Enable semantic segmentation mode")
    parser.add_argument("--dataset_root", type=str, default=".", help="Dataset root")
    parser.add_argument("--scenario", type=str, default="Classes", choices=["Classes", "Illumination", "Weather"])
    parser.add_argument("--training_bs", type=int, default=16)
    parser.add_argument("--eval_bs", type=int, default=16)
    parser.add_argument("--lr", type=float, default=0.001)
    parser.add_argument("--epoch", type=int, default=1)
    args = parser.parse_args()

    device = torch.device(f"cuda:{args.cuda}" if args.cuda != -1 else "cpu")

    # Build model: use resnet20 (from pytorchcv_wrapper) with default settings for cifar10
    # (default backbone is designed for classification; we adjust it for semantic segmentation)
    model = pytorchcv_wrapper.resnet("cifar10", depth=20, pretrained=False)
    model.to(device)

    if args.semseg:
        # For semantic segmentation, we remove the final pooling and replace the classifier head.
        model.final_pool = nn.Identity()
        # Here we assume that the feature extractor outputs features with 64 channels.
        # We replace the classifier head with a segmentation head:
        model.output = nn.Sequential(
            nn.Conv2d(64, 512, kernel_size=3, padding=1),  # Increase feature depth
            nn.ReLU(),
            nn.Conv2d(512, 8, kernel_size=1)  # 8 segmentation classes
        )

        # Override the forward function: extract features, apply segmentation head,
        # and upsample the result to the original input spatial dimensions.
        def _seg_forward(x):
            input_size = x.shape[-2:]  # e.g., (135, 240)
            x = model.features(x)      # features, shape: [N, 64, H_feat, W_feat]
            x = model.final_pool(x)    # Identity (keeps current spatial size)
            x = model.output(x)        # logits, shape: [N, num_classes, H_feat, W_feat]
            x = F.interpolate(x, size=input_size, mode="bilinear", align_corners=False)
            return x

        model.forward = _seg_forward

    # Create the EndlessCLSim benchmark (only semantic segmentation is enabled)
    benchmark = EndlessCLSim(
        scenario=args.scenario,
        sequence_order=None,
        task_order=None,
        semseg=args.semseg,
        dataset_root=args.dataset_root,
    )

    # Retrieve training and testing streams
    train_stream = benchmark.train_stream
    test_stream = benchmark.test_stream

    # Set up optimizer and loss (using CrossEntropyLoss, which expects:
    # model output shape: [N, num_classes, H, W] and target shape: [N, H, W])
    optimizer = Adam(model.parameters(), lr=args.lr)
    criterion = torch.nn.CrossEntropyLoss()

    # Set up loggers (optional)
    interactive_logger = InteractiveLogger()
    csv_logger = CSVLogger("log_semseg.csv")
    logger = [interactive_logger, csv_logger]

    # Set up the evaluation plugin (accuracy metrics alone are enough to trigger the bug)
    eval_plugin = EvaluationPlugin(
        accuracy_metrics(minibatch=True, epoch=True, experience=True, stream=True),
        loggers=logger,
    )

    # Create the continual learning strategy (Naive in this case)
    cl_strategy = Naive(
        model,
        optimizer,
        criterion,
        train_mb_size=args.training_bs,
        train_epochs=args.epoch,
        eval_mb_size=args.eval_bs,
        device=device,
        evaluator=eval_plugin,
    )

    print("Starting experiment...")
    for experience in train_stream:
        cl_strategy.train(experience)
        res = cl_strategy.eval(test_stream)
        print("Evaluation results:", res)

    print("Experiment completed.")

🐝 Expected behavior
During training, the inputs are:
predicted_y.shape is [batch_size, num_classes, height, width]
true_y.shape is [batch_size, height, width]

The accuracy calculation should reduce these to:
predicted_y.shape to [batch_size, height, width]
true_y.shape to [batch_size, height, width]

During evaluation, the inputs are:
predicted_y.shape is [batch_size, height, width]
true_y.shape is [batch_size, num_classes, height, width]

The accuracy calculation should reduce these to:
predicted_y.shape to [batch_size, height, width]
true_y.shape to [batch_size, height, width]

Currently, I have modified avalanche/evaluation/metrics/accuracy.py to work around this error:

        '''
        # Check if logits or labels
        if len(predicted_y.shape) > 1:
            # Logits -> transform to labels
            predicted_y = torch.max(predicted_y, 1)[1]
        if len(true_y.shape) > 1:
            # Logits -> transform to labels
            true_y = torch.max(true_y, 1)[1]
        true_positives = float(torch.sum(torch.eq(predicted_y, true_y)))
        total_patterns = len(true_y)
        self._mean_accuracy.update(true_positives / total_patterns, total_patterns)
        '''

        # Logits / one-hot label maps -> per-pixel label maps (argmax over the class dimension)
        if predicted_y.dim() > 3:
            predicted_y = torch.argmax(predicted_y, dim=1)
        if true_y.dim() > 3:
            true_y = torch.argmax(true_y, dim=1)
        if predicted_y.shape != true_y.shape:
            raise ValueError(f"Size mismatch: predicted_y shape {predicted_y.shape} vs true_y shape {true_y.shape}")

        # Count correct pixels and normalize by the total number of pixels
        true_positives = float(torch.sum(torch.eq(predicted_y, true_y)))
        total_patterns = true_y.numel()
        self._mean_accuracy.update(true_positives / total_patterns, total_patterns)
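
For reference, the same logic as a standalone sketch (a hypothetical pixel_accuracy helper, not part of Avalanche) that handles both input layouts described above:

    import torch
    import torch.nn.functional as F

    def pixel_accuracy(predicted_y: torch.Tensor, true_y: torch.Tensor):
        # Reduce 4-D logits / one-hot maps to per-pixel label maps via argmax over the class dim
        if predicted_y.dim() > 3:
            predicted_y = torch.argmax(predicted_y, dim=1)  # [B, C, H, W] -> [B, H, W]
        if true_y.dim() > 3:
            true_y = torch.argmax(true_y, dim=1)            # [B, C, H, W] -> [B, H, W]
        if predicted_y.shape != true_y.shape:
            raise ValueError(f"Size mismatch: {predicted_y.shape} vs {true_y.shape}")
        correct = float(torch.sum(torch.eq(predicted_y, true_y)))
        return correct / true_y.numel(), true_y.numel()

    # Training-style inputs: logits [B, C, H, W] vs. per-pixel labels [B, H, W]
    acc, n_pixels = pixel_accuracy(torch.randn(2, 8, 4, 6), torch.randint(0, 8, (2, 4, 6)))

    # Evaluation-style inputs: label map [B, H, W] vs. one-hot targets [B, C, H, W]
    labels = torch.randint(0, 8, (2, 4, 6))
    one_hot = F.one_hot(labels, num_classes=8).permute(0, 3, 1, 2)  # [B, H, W, C] -> [B, C, H, W]
    acc, n_pixels = pixel_accuracy(labels, one_hot)  # acc == 1.0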

