Merged
25 commits
de2b6bd
feat: add optional gradient checkpointing to unet
Sep 3, 2025
66edcb5
fix: small ruff issue
Sep 3, 2025
e66e357
Update monai/networks/nets/unet.py
ferreirafabio80 Sep 4, 2025
feefcaa
docs: update docstrings
Sep 4, 2025
e112457
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 4, 2025
f673ca1
fix: avoid BatchNorm subblocks
Sep 4, 2025
69540ff
fix: revert batch norm changes
Sep 4, 2025
42ec757
refactor: creates a subclass of UNet and overrides the get connection…
Oct 1, 2025
a2e8474
chore: remove use checkpointing from doc string
Oct 1, 2025
4c4782e
fix: linting issues
Oct 2, 2025
515c659
feat: add activation checkpointing to down and up paths to be more ef…
Oct 8, 2025
da5a3a4
refactor: move activation checkpointing wrapper to blocks
Nov 4, 2025
43dec88
chore: add docstrings to checkpointed unet
Nov 4, 2025
84c0f48
test: add checkpoint unet test
Nov 7, 2025
5805515
fix: change test name
Nov 7, 2025
1aa8e3c
fix: simplify test and make sure that checkpoint unet runs well in tr…
Nov 7, 2025
447d9f2
fix: set seed
Nov 7, 2025
b20a19e
fix: fix testing bugs
Nov 7, 2025
41f000f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 7, 2025
a068c0e
chore: add test docstrings
Nov 10, 2025
26668cd
DCO Remediation Commit for Fabio Ferreira <f.ferreira@qureight.com>
Nov 10, 2025
814fa80
fix: remove test script save
Nov 13, 2025
c45ee48
fix: tighten tolerance for numerical equivalence
Nov 13, 2025
4349d3f
chore: update doc strings
Nov 14, 2025
885993b
Merge branch 'dev' into feat/add_activation_checkpointing_to_unet
KumoLiu Nov 14, 2025
24 changes: 24 additions & 0 deletions monai/networks/nets/unet.py
@@ -13,9 +13,11 @@

import warnings
from collections.abc import Sequence
from typing import cast

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

from monai.networks.blocks.convolutions import Convolution, ResidualUnit
from monai.networks.layers.factories import Act, Norm
@@ -24,6 +26,22 @@
__all__ = ["UNet", "Unet"]


class _ActivationCheckpointWrapper(nn.Module):
    """Apply activation checkpointing to the wrapped module during training."""

    def __init__(self, module: nn.Module) -> None:
        super().__init__()
        self.module = module

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Checkpoint only when it can pay off: training mode, gradients enabled,
        # and an input that participates in autograd.
        if self.training and torch.is_grad_enabled() and x.requires_grad:
            try:
                return cast(torch.Tensor, checkpoint(self.module, x, use_reentrant=False))
            except TypeError:
                # Fallback for older PyTorch without `use_reentrant`.
                return cast(torch.Tensor, checkpoint(self.module, x))
        return cast(torch.Tensor, self.module(x))
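
As a minimal sketch of the wrapper's behavior (illustrative only, not part of unet.py; `block` and `x` are made-up names, and the `use_reentrant=False` path assumes PyTorch >= 1.11):

import torch
import torch.nn as nn

block = nn.Sequential(nn.Conv2d(4, 4, 3, padding=1), nn.ReLU())
wrapped = _ActivationCheckpointWrapper(block)

x = torch.randn(2, 4, 8, 8, requires_grad=True)

wrapped.train()
wrapped(x).sum().backward()  # activations are recomputed during backward rather than stored

wrapped.eval()
with torch.no_grad():
    _ = wrapped(x)  # checkpointing is bypassed: a plain self.module(x) call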


class UNet(nn.Module):
"""
Enhanced version of UNet which has residual units implemented with the ResidualUnit class.
@@ -69,6 +87,8 @@ class UNet(nn.Module):
            if a conv layer is directly followed by a batch norm layer, bias should be False.
        adn_ordering: a string representing the ordering of activation (A), normalization (N), and dropout (D).
            Defaults to "NDA". See also: :py:class:`monai.networks.blocks.ADN`.
        use_checkpointing: if True, apply activation checkpointing to internal sub-blocks during training to reduce
            memory usage at the cost of extra compute. Checkpointing is bypassed in eval mode and when gradients
            are disabled. Defaults to False.

    Examples::

@@ -118,6 +138,7 @@ def __init__(
        dropout: float = 0.0,
        bias: bool = True,
        adn_ordering: str = "NDA",
        use_checkpointing: bool = False,
    ) -> None:
        super().__init__()

Expand Down Expand Up @@ -146,6 +167,7 @@ def __init__(
        self.dropout = dropout
        self.bias = bias
        self.adn_ordering = adn_ordering
        self.use_checkpointing = use_checkpointing

        def _create_block(
            inc: int, outc: int, channels: Sequence[int], strides: Sequence[int], is_top: bool
@@ -192,6 +214,8 @@ def _get_connection_block(self, down_path: nn.Module, up_path: nn.Module, subblo
            subblock: block defining the next layer in the network.
        Returns: block for this layer: `nn.Sequential(down_path, SkipConnection(subblock), up_path)`
        """
        if self.use_checkpointing:
            subblock = _ActivationCheckpointWrapper(subblock)
        return nn.Sequential(down_path, SkipConnection(subblock), up_path)

    def _get_down_layer(self, in_channels: int, out_channels: int, strides: int, is_top: bool) -> nn.Module:
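
Putting the pieces together, a hedged usage sketch of the new flag; the remaining constructor arguments follow the existing UNet API, and the shapes, channel, and stride values here are arbitrary:

import torch
from monai.networks.nets import UNet

model = UNet(
    spatial_dims=3,
    in_channels=1,
    out_channels=2,
    channels=(16, 32, 64, 128),
    strides=(2, 2, 2),
    num_res_units=2,
    use_checkpointing=True,  # the flag added by this PR
)

model.train()
x = torch.randn(1, 1, 64, 64, 64)
model(x).sum().backward()  # wrapped sub-blocks recompute activations during backward

model.eval()
with torch.no_grad():
    y = model(x)  # eval path falls through to a plain forward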