diff --git a/monai/apps/generation/maisi/networks/autoencoderkl_maisi.py b/monai/apps/generation/maisi/networks/autoencoderkl_maisi.py index 6251ea8e83..86b4e68864 100644 --- a/monai/apps/generation/maisi/networks/autoencoderkl_maisi.py +++ b/monai/apps/generation/maisi/networks/autoencoderkl_maisi.py @@ -232,6 +232,10 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: if self.print_info: logger.info(f"Number of splits: {self.num_splits}") + if self.dim_split <= 1 and self.num_splits <= 1: + x = self.conv(x) + return x + # compute size of splits l = x.size(self.dim_split + 2) split_size = l // self.num_splits