🐛 Describe the bug
When working with the Endless Continual Learning Simulator (EndlessCLSim), specifically in the semantic segmentation scenario, integrating accuracy_metrics into the evaluation plugin as
eval_plugin = EvaluationPlugin(
    accuracy_metrics(minibatch=True, epoch=True, experience=True, stream=True),
    loggers=logger,
)
causes a crash, because semantic segmentation requires a pixel-level check, which is different from traditional classification.
I checked the accuracy calculation in the source code; it uses
# Check if logits or labels
if len(predicted_y.shape) > 1:
    # Logits -> transform to labels
    predicted_y = torch.max(predicted_y, 1)[1]
if len(true_y.shape) > 1:
    # Logits -> transform to labels
    true_y = torch.max(true_y, 1)[1]
true_positives = float(torch.sum(torch.eq(predicted_y, true_y)))
total_patterns = len(true_y)
self._mean_accuracy.update(true_positives / total_patterns, total_patterns)
However, this assumes the labels are one-dimensional. In a semantic segmentation task, accuracy should be computed per pixel.
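For contrast, here is a minimal sketch (plain PyTorch, illustrative shapes) of why that logic is harmless for ordinary classification, where the targets are one-dimensional:

import torch

# Classification: logits [batch_size, num_classes], integer labels [batch_size]
predicted_y = torch.randn(4, 10)
true_y = torch.randint(0, 10, (4,))

# The metric's reduction: logits -> predicted labels along dim 1
predicted_y = torch.max(predicted_y, 1)[1]  # shape [4]
# true_y is already 1-D, so the second branch is skipped and shapes match

accuracy = float(torch.sum(torch.eq(predicted_y, true_y))) / len(true_y)
print(accuracy)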
An example of why this does not work for segmentation:
During training, the inputs are:
predicted_y.shape is [batch_size, num_classes, height, width]
true_y.shape is [batch_size, height, width]
The code above changes:
predicted_y.shape to [batch_size, height, width]
true_y.shape to [batch_size, width] (the labels, which are not logits, are wrongly reduced along the height dimension)
This causes a dimension-mismatch crash in the line
true_positives = float(torch.sum(torch.eq(predicted_y, true_y)))
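The mismatch can be reproduced in isolation (a minimal sketch in plain PyTorch; the 8 classes and 135×240 resolution mirror the reproduction script below):

import torch

batch_size, num_classes, height, width = 2, 8, 135, 240
predicted_y = torch.randn(batch_size, num_classes, height, width)    # segmentation logits
true_y = torch.randint(0, num_classes, (batch_size, height, width))  # per-pixel labels

# The metric's logic, applied blindly:
predicted_y = torch.max(predicted_y, 1)[1]  # [2, 135, 240] -- correct
true_y = torch.max(true_y, 1)[1]            # [2, 240]      -- labels wrongly reduced

# Shapes [2, 135, 240] and [2, 240] cannot be broadcast together:
torch.eq(predicted_y, true_y)  # raises RuntimeError (size mismatch)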
🐜 To Reproduce
A minimal working example:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam
from avalanche.benchmarks.classic import EndlessCLSim
from avalanche.training import Naive
from avalanche.training.plugins import EvaluationPlugin
from avalanche.evaluation.metrics import (
    forgetting_metrics,
    accuracy_metrics,
    loss_metrics,
    ram_usage_metrics,
    timing_metrics,
    MAC_metrics,
)
from avalanche.logging import InteractiveLogger, CSVLogger
from avalanche.models import pytorchcv_wrapper
import argparse
import random
import numpy as np

# Set seeds for reproducibility
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--cuda", type=int, default=0, help="Use CUDA device index; -1 for CPU")
    parser.add_argument("--semseg", action="store_true", default=True, help="Enable semantic segmentation mode")
    parser.add_argument("--dataset_root", type=str, default=".", help="Dataset root")
    parser.add_argument("--scenario", type=str, default="Classes", choices=["Classes", "Illumination", "Weather"])
    parser.add_argument("--training_bs", type=int, default=16)
    parser.add_argument("--eval_bs", type=int, default=16)
    parser.add_argument("--lr", type=float, default=0.001)
    parser.add_argument("--epoch", type=int, default=1)
    args = parser.parse_args()

    device = torch.device(f"cuda:{args.cuda}" if args.cuda != -1 else "cpu")

    # Build model: use resnet20 (from pytorchcv_wrapper) with default settings for cifar10
    # (the default backbone is designed for classification; we adapt it for semantic segmentation)
    model = pytorchcv_wrapper.resnet("cifar10", depth=20, pretrained=False)
    model.to(device)

    if args.semseg:
        # For semantic segmentation, remove the final pooling and replace the classifier head.
        model.final_pool = nn.Identity()
        # Here we assume that the feature extractor outputs features with 64 channels.
        # Replace the classifier head with a segmentation head:
        model.output = nn.Sequential(
            nn.Conv2d(64, 512, kernel_size=3, padding=1),  # increase feature depth
            nn.ReLU(),
            nn.Conv2d(512, 8, kernel_size=1),  # 8 segmentation classes
        )

        # Override the forward function: extract features, apply the segmentation head,
        # and upsample the result to the original input spatial dimensions.
        def _seg_forward(x):
            input_size = x.shape[-2:]  # e.g., (135, 240)
            x = model.features(x)      # features, shape: [N, 64, H_feat, W_feat]
            x = model.final_pool(x)    # Identity (keeps the current spatial size)
            x = model.output(x)        # logits, shape: [N, num_classes, H_feat, W_feat]
            x = F.interpolate(x, size=input_size, mode="bilinear", align_corners=False)
            return x

        model.forward = _seg_forward

    # Create the EndlessCLSim benchmark (only semantic segmentation is enabled)
    benchmark = EndlessCLSim(
        scenario=args.scenario,
        sequence_order=None,
        task_order=None,
        semseg=args.semseg,
        dataset_root=args.dataset_root,
    )

    # Retrieve the training and testing streams
    train_stream = benchmark.train_stream
    test_stream = benchmark.test_stream

    # Set up the optimizer and loss (CrossEntropyLoss expects
    # model output shape [N, num_classes, H, W] and target shape [N, H, W])
    optimizer = Adam(model.parameters(), lr=args.lr)
    criterion = torch.nn.CrossEntropyLoss()

    # Set up loggers (optional)
    interactive_logger = InteractiveLogger()
    csv_logger = CSVLogger("log_semseg.csv")
    logger = [interactive_logger, csv_logger]

    # Set up the evaluation plugin (accuracy metrics only)
    eval_plugin = EvaluationPlugin(
        accuracy_metrics(minibatch=True, epoch=True, experience=True, stream=True),
        loggers=logger,
    )

    # Create the continual learning strategy (Naive in this case)
    cl_strategy = Naive(
        model,
        optimizer,
        criterion,
        train_mb_size=args.training_bs,
        train_epochs=args.epoch,
        eval_mb_size=args.eval_bs,
        device=device,
        evaluator=eval_plugin,
    )

    print("Starting experiment...")
    for experience in train_stream:
        cl_strategy.train(experience)
        res = cl_strategy.eval(test_stream)
        print("Evaluation results:", res)
    print("Experiment completed.")
🐝 Expected behavior
During training, the inputs are:
predicted_y.shape is [batch_size, num_classes, height, width]
true_y.shape is [batch_size, height, width]
The accuracy calculation should reduce these to:
predicted_y.shape: [batch_size, height, width]
true_y.shape: [batch_size, height, width] (left unchanged)
During evaluation, the inputs are:
predicted_y.shape is [batch_size, height, width]
true_y.shape is [batch_size, num_classes, height, width]
The accuracy calculation should reduce these to:
predicted_y.shape: [batch_size, height, width] (left unchanged)
true_y.shape: [batch_size, height, width]
Currently, I have modified avalanche/evaluation/metrics/accuracy.py to bypass this error (the old code is kept in the triple-quoted block for reference):
'''
# Check if logits or labels
if len(predicted_y.shape) > 1:
    # Logits -> transform to labels
    predicted_y = torch.max(predicted_y, 1)[1]
if len(true_y.shape) > 1:
    # Logits -> transform to labels
    true_y = torch.max(true_y, 1)[1]
true_positives = float(torch.sum(torch.eq(predicted_y, true_y)))
total_patterns = len(true_y)
self._mean_accuracy.update(true_positives / total_patterns, total_patterns)
'''
if predicted_y.dim() > 3:
    predicted_y = torch.argmax(predicted_y, dim=1)
if true_y.dim() > 3:
    true_y = torch.argmax(true_y, dim=1)
if predicted_y.shape != true_y.shape:
    raise ValueError(f"Size mismatch: predicted_y shape {predicted_y.shape} vs true_y shape {true_y.shape}")
true_positives = float(torch.sum(torch.eq(predicted_y, true_y)))
total_patterns = true_y.numel()
self._mean_accuracy.update(true_positives / total_patterns, total_patterns)
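As a quick sanity check, the patched logic can be exercised on dummy tensors (a standalone sketch; the self._mean_accuracy.update(...) call is replaced by returning the plain ratio):

import torch
import torch.nn.functional as F

def pixel_accuracy(predicted_y, true_y):
    # The patched logic, extracted into a standalone function for testing
    if predicted_y.dim() > 3:
        predicted_y = torch.argmax(predicted_y, dim=1)
    if true_y.dim() > 3:
        true_y = torch.argmax(true_y, dim=1)
    if predicted_y.shape != true_y.shape:
        raise ValueError(
            f"Size mismatch: predicted_y shape {predicted_y.shape} "
            f"vs true_y shape {true_y.shape}"
        )
    true_positives = float(torch.sum(torch.eq(predicted_y, true_y)))
    total_patterns = true_y.numel()
    return true_positives / total_patterns

# Training-style inputs: logits [N, C, H, W] vs integer pixel labels [N, H, W]
labels = torch.randint(0, 8, (2, 135, 240))
logits = torch.randn(2, 8, 135, 240)
print(pixel_accuracy(logits, labels))  # ~0.125 for random logits over 8 classes

# One-hot "logits" that encode the labels exactly should give 1.0
one_hot = F.one_hot(labels, num_classes=8).permute(0, 3, 1, 2).float()
print(pixel_accuracy(one_hot, labels))  # 1.0

The one-hot case also mirrors the evaluation-time inputs described above, where true_y arrives as [batch_size, num_classes, height, width] and must be reduced with argmax over dim 1.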