
Conversation

@ferreirafabio80

@ferreirafabio80 ferreirafabio80 commented Sep 3, 2025

Description

Introduces an optional use_checkpointing flag in the UNet implementation. When enabled, intermediate activations in the encoder–decoder blocks are recomputed during the backward pass instead of being stored in memory.

  • Implemented via a lightweight _ActivationCheckpointWrapper around sub-blocks.
  • Checkpointing is only applied during training to avoid overhead at inference.
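
A minimal usage sketch of the flag described above (the constructor arguments shown are illustrative; `use_checkpointing=True` is the new option added by this PR):

import torch
from monai.networks.nets import UNet

model = UNet(
    spatial_dims=3,
    in_channels=1,
    out_channels=2,
    channels=(16, 32, 64, 128),
    strides=(2, 2, 2),
    use_checkpointing=True,  # recompute sub-block activations during backward
)

x = torch.randn(1, 1, 96, 96, 96)
model.train()
loss = model(x).sum()
loss.backward()  # wrapped sub-block activations are recomputed here instead of being kept in memory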

Types of changes

  • Non-breaking change (fix or new feature that would not break existing functionality).
  • Breaking change (fix or new feature that would cause existing functionality to change).
  • New tests added to cover the changes.
  • Integration tests passed locally by running ./runtests.sh -f -u --net --coverage.
  • Quick tests passed locally by running ./runtests.sh --quick --unittests --disttests.
  • In-line docstrings updated.
  • Documentation updated, tested make html command in the docs/ folder.

@coderabbitai
Contributor

coderabbitai bot commented Sep 3, 2025

Walkthrough

Adds a new module monai/networks/blocks/activation_checkpointing.py defining ActivationCheckpointWrapper, a torch.nn.Module that wraps an inner module and applies torch.utils.checkpoint.checkpoint (using use_reentrant=False) in forward. Adds a new public CheckpointUNet(UNet) subclass in monai/networks/nets/unet.py that overrides _get_connection_block to wrap the connection subblock, down_path, and up_path with ActivationCheckpointWrapper. Updates public exports to include CheckpointUNet. Existing UNet class and aliases remain unchanged; UNet init not modified.
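
A condensed sketch of the two pieces described above, the wrapper block and the subclass override (simplified; the guards, full docstrings, and exact module layout are in the files named above):

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

from monai.networks.nets import UNet


class ActivationCheckpointWrapper(nn.Module):
    """Wrap a module and recompute its activations during the backward pass."""

    def __init__(self, module: nn.Module) -> None:
        super().__init__()
        self.module = module

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # non-reentrant checkpointing, as described in the walkthrough
        return checkpoint(self.module, x, use_reentrant=False)


class CheckpointUNet(UNet):
    def _get_connection_block(self, down_path: nn.Module, up_path: nn.Module, subblock: nn.Module) -> nn.Module:
        # wrap all three components so their activations are recomputed in backward
        subblock = ActivationCheckpointWrapper(subblock)
        down_path = ActivationCheckpointWrapper(down_path)
        up_path = ActivationCheckpointWrapper(up_path)
        return super()._get_connection_block(down_path, up_path, subblock)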

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~25 minutes

Areas to pay attention to:

  • monai/networks/blocks/activation_checkpointing.py: correct usage of torch.utils.checkpoint.checkpoint, the use_reentrant=False choice, input typing/shape assumptions, and gradient/device semantics.
  • monai/networks/nets/unet.py: integration of ActivationCheckpointWrapper in _get_connection_block, ensuring module wiring and forward signatures remain compatible.
  • Public exports: __all__ updated to include CheckpointUNet.
  • Potential impacts on state_dict/serialization and behavioral differences between training vs. evaluation when wrappers are present.
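
On the last point: wrapping a sub-block in a module whose child attribute is named module inserts an extra "module." level into state_dict keys, so checkpoints saved from a plain UNet will not load into the wrapped variant (or vice versa) without remapping. A hedged sketch of such a remap, with illustrative key names:

import torch


def strip_wrapper_prefix(state_dict: dict) -> dict:
    # Remove the ".module" segments introduced by the checkpoint wrapper so that
    # weights trained with the wrapped model can be loaded into a plain UNet.
    # Illustrative only; the exact key layout depends on where wrappers are inserted.
    return {key.replace(".module.", "."): value for key, value in state_dict.items()}


# plain_unet.load_state_dict(strip_wrapper_prefix(torch.load("wrapped_ckpt.pt")))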

Pre-merge checks and finishing touches

✅ Passed checks (2 passed)
  • Title check: ✅ Passed. The title clearly and concisely describes the main change: adding activation checkpointing capability to the UNet implementation.
  • Description check: ✅ Passed. The description covers the feature, implementation approach, and training-only application. Most checklist items are addressed; the documentation update is not checked.
✨ Finishing touches
  • 📝 Generate docstrings
🧪 Generate unit tests (beta)
  • Create PR with unit tests
  • Post copyable unit tests in a comment

Comment @coderabbitai help to get the list of available commands and usage tips.

@ferreirafabio80 changed the title from "Feat/add activation checkpointing to unet" to "feat: add activation checkpointing to unet" on Sep 3, 2025
Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

🧹 Nitpick comments (4)
monai/networks/nets/unet.py (4)

29-33: Add a brief class docstring to the wrapper.
Improves discoverability and meets docstring guidelines.

Apply this diff:

 class _ActivationCheckpointWrapper(nn.Module):
-    def __init__(self, module: nn.Module) -> None:
+    """Apply activation checkpointing to the wrapped module during training."""
+    def __init__(self, module: nn.Module) -> None:
         super().__init__()
         self.module = module

134-135: Document the new use_checkpointing arg in the class docstring and user docs.
State the trade-offs (memory vs. compute), note that it is training-only and incompatible with torch.no_grad, and that RNG state is preserved by default.

Proposed docstring snippet to add under “Args”:

use_checkpointing: if True, apply activation checkpointing to internal sub-blocks
    during training to reduce memory at the cost of extra compute. Checkpointing
    is bypassed in eval and when gradients are disabled. Defaults to False.

I can open a docs patch and add a short example enabling the flag.


163-164: Static wrapping caveat: runtime flips won’t take effect.
After init, changing self.use_checkpointing won’t rewrap existing blocks. Either document this or add a small helper to (re)build the model if you expect runtime toggling.

Do you expect users to toggle this at runtime? If yes, I can sketch a safe rewrap helper.
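
For reference, a rewrap helper along those lines could be as simple as rebuilding the model, since wrapping happens at construction time (hypothetical helper; it assumes the caller still has the constructor kwargs and that wrapping does not alter parameter key names):

from monai.networks.nets import UNet


def rebuild_with_checkpointing(constructor_kwargs: dict, state_dict: dict, enabled: bool) -> UNet:
    # Hypothetical "toggle": rebuild with the flag flipped and reload the weights.
    model = UNet(**constructor_kwargs, use_checkpointing=enabled)
    model.load_state_dict(state_dict)  # assumes wrapping does not change parameter key names
    return model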


210-212: Checkpointing scope is subblock-only; consider an optional broader policy.
The current placement is a good default. If further memory savings are needed, offer a policy to also wrap down_path/up_path (with a warning about the extra compute).

I can add a checkpoint_policy: Literal["subblock","all"] = "subblock" in __init__ and wire it here on request.

📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Knowledge Base: Disabled due to Reviews > Disable Knowledge Base setting

📥 Commits

Reviewing files that changed from the base of the PR and between 0968da2 and 66edcb5.

📒 Files selected for processing (1)
  • monai/networks/nets/unet.py (5 hunks)
🧰 Additional context used
📓 Path-based instructions (1)
**/*.py

⚙️ CodeRabbit configuration file

Review the Python code for quality and correctness. Ensure variable names adhere to PEP8 style guides, are sensible and informative in regards to their function, though permitting simple names for loop and comprehension variables. Ensure routine names are meaningful in regards to their function and use verbs, adjectives, and nouns in a semantically appropriate way. Docstrings should be present for all definition which describe each variable, return value, and raised exception in the appropriate section of the Google-style of docstrings. Examine code for logical error or inconsistencies, and suggest what may be changed to addressed these. Suggest any enhancements for code improving efficiency, maintainability, comprehensibility, and correctness. Ensure new or modified definitions will be covered by existing or new unit tests.

Files:

  • monai/networks/nets/unet.py
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (19)
  • GitHub Check: min-dep-pytorch (2.7.1)
  • GitHub Check: min-dep-pytorch (2.8.0)
  • GitHub Check: min-dep-pytorch (2.5.1)
  • GitHub Check: min-dep-os (macOS-latest)
  • GitHub Check: min-dep-os (ubuntu-latest)
  • GitHub Check: min-dep-pytorch (2.6.0)
  • GitHub Check: min-dep-py3 (3.10)
  • GitHub Check: min-dep-os (windows-latest)
  • GitHub Check: min-dep-py3 (3.12)
  • GitHub Check: min-dep-py3 (3.9)
  • GitHub Check: min-dep-py3 (3.11)
  • GitHub Check: flake8-py3 (mypy)
  • GitHub Check: flake8-py3 (codeformat)
  • GitHub Check: quick-py3 (macOS-latest)
  • GitHub Check: quick-py3 (ubuntu-latest)
  • GitHub Check: flake8-py3 (pytype)
  • GitHub Check: build-docs
  • GitHub Check: quick-py3 (windows-latest)
  • GitHub Check: packaging
🔇 Additional comments (1)
monai/networks/nets/unet.py (1)

16-21: Imports for checkpointing look good.
cast and checkpoint are appropriate for the new wrapper.

ferreirafabio80 and others added 3 commits September 4, 2025 15:36
Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
Signed-off-by: Fábio S. Ferreira <ferreira.fabio80@gmail.com>
Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

♻️ Duplicate comments (1)
monai/networks/nets/unet.py (1)

29-43: Good guard + compatibility fallback.
Training/grad-enabled checks and use_reentrant=False with TypeError fallback are the right call. This addresses the prior review note.

🧹 Nitpick comments (5)
monai/networks/nets/unet.py (5)

29-43: Avoid per-iteration TypeError cost: detect use_reentrant support once.
Resolve support at import/init time to prevent raising an exception every forward on older torch.

Apply:

@@
-class _ActivationCheckpointWrapper(nn.Module):
+_SUPPORTS_USE_REENTRANT: bool | None = None
+
+class _ActivationCheckpointWrapper(nn.Module):
@@
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        if self.training and torch.is_grad_enabled() and x.requires_grad:
-            try:
-                return cast(torch.Tensor, checkpoint(self.module, x, use_reentrant=False))
-            except TypeError:
-                # Fallback for older PyTorch without `use_reentrant`
-                return cast(torch.Tensor, checkpoint(self.module, x))
-        return cast(torch.Tensor, self.module(x))
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        if self.training and torch.is_grad_enabled() and x.requires_grad:
+            global _SUPPORTS_USE_REENTRANT
+            if _SUPPORTS_USE_REENTRANT is None:
+                try:
+                    # probe once
+                    checkpoint(self.module, x, use_reentrant=False)  # type: ignore[arg-type]
+                    _SUPPORTS_USE_REENTRANT = True
+                except TypeError:
+                    _SUPPORTS_USE_REENTRANT = False
+                except Exception:
+                    # do not change behavior on unexpected errors; fall back below
+                    _SUPPORTS_USE_REENTRANT = False
+            if _SUPPORTS_USE_REENTRANT:
+                return cast(torch.Tensor, checkpoint(self.module, x, use_reentrant=False))
+            return cast(torch.Tensor, checkpoint(self.module, x))
+        return cast(torch.Tensor, self.module(x))

Add outside the hunk (file header):

import inspect  # if you switch to signature probing instead of try/except

Note: PyTorch recommends passing use_reentrant explicitly going forward. (docs.pytorch.org)


29-43: TorchScript: make wrapper script-safe.
try/except and dynamic checkpoint calls won’t script. Short-circuit under scripting.

Apply:

@@
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        if torch.jit.is_scripting():
+            # Avoid checkpoint in scripted graphs
+            return cast(torch.Tensor, self.module(x))

29-43: Docstring completeness.
Add Google-style docstrings for the wrapper’s class/init/forward (inputs, returns, raises).

Example:

@@
-class _ActivationCheckpointWrapper(nn.Module):
-    """Apply activation checkpointing to the wrapped module during training."""
+class _ActivationCheckpointWrapper(nn.Module):
+    """Wrap a module and apply activation checkpointing during training.
+
+    Args:
+        module: The submodule to checkpoint.
+
+    Returns:
+        torch.Tensor: Output tensor from the wrapped submodule.
+
+    Raises:
+        RuntimeError: If checkpoint fails at runtime.
+    """

90-92: Tighten the use_checkpointing docstring and add a BN caveat.
Keep it on one Args entry and note the BatchNorm limitation.

-        use_checkpointing: if True, apply activation checkpointing to internal sub-blocks during training to reduce memory
-        at the cost of extra compute. Checkpointing is bypassed in eval and when gradients are disabled. Defaults to False.
+        use_checkpointing: If True, apply activation checkpointing to internal sub-blocks during training to reduce
+            memory at the cost of extra compute. Bypassed in eval and when gradients are disabled. Note: avoid with
+            BatchNorm layers due to running-stat updates during recomputation. Defaults to False.

Reference on RNG determinism (dropout is handled by default). (docs.pytorch.org)


217-219: Wrap site is fine; consider BN detection here instead of inside the wrapper.
Keeps policy close to where wrapping occurs and avoids per-instance warnings.

-        if self.use_checkpointing:
-            subblock = _ActivationCheckpointWrapper(subblock)
+        if self.use_checkpointing:
+            has_bn = any(isinstance(m, nn.modules.batchnorm._BatchNorm) for m in subblock.modules())
+            if has_bn:
+                warnings.warn(
+                    "Skipping activation checkpointing for this subblock (contains BatchNorm).",
+                    RuntimeWarning,
+                )
+            else:
+                subblock = _ActivationCheckpointWrapper(subblock)
         return nn.Sequential(down_path, SkipConnection(subblock), up_path)

Rationale: avoids double-updating BN stats during recomputation. (github.com)

📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Knowledge Base: Disabled due to Reviews > Disable Knowledge Base setting

📥 Commits

Reviewing files that changed from the base of the PR and between e66e357 and e112457.

📒 Files selected for processing (1)
  • monai/networks/nets/unet.py (6 hunks)
🧰 Additional context used
📓 Path-based instructions (1)
**/*.py

⚙️ CodeRabbit configuration file

Review the Python code for quality and correctness. Ensure variable names adhere to PEP8 style guides, are sensible and informative in regards to their function, though permitting simple names for loop and comprehension variables. Ensure routine names are meaningful in regards to their function and use verbs, adjectives, and nouns in a semantically appropriate way. Docstrings should be present for all definition which describe each variable, return value, and raised exception in the appropriate section of the Google-style of docstrings. Examine code for logical error or inconsistencies, and suggest what may be changed to addressed these. Suggest any enhancements for code improving efficiency, maintainability, comprehensibility, and correctness. Ensure new or modified definitions will be covered by existing or new unit tests.

Files:

  • monai/networks/nets/unet.py
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (19)
  • GitHub Check: min-dep-pytorch (2.6.0)
  • GitHub Check: min-dep-py3 (3.12)
  • GitHub Check: min-dep-py3 (3.10)
  • GitHub Check: min-dep-py3 (3.11)
  • GitHub Check: min-dep-py3 (3.9)
  • GitHub Check: min-dep-pytorch (2.7.1)
  • GitHub Check: min-dep-os (ubuntu-latest)
  • GitHub Check: min-dep-pytorch (2.5.1)
  • GitHub Check: min-dep-os (macOS-latest)
  • GitHub Check: min-dep-pytorch (2.8.0)
  • GitHub Check: min-dep-os (windows-latest)
  • GitHub Check: quick-py3 (macOS-latest)
  • GitHub Check: flake8-py3 (pytype)
  • GitHub Check: quick-py3 (windows-latest)
  • GitHub Check: flake8-py3 (mypy)
  • GitHub Check: flake8-py3 (codeformat)
  • GitHub Check: build-docs
  • GitHub Check: packaging
  • GitHub Check: quick-py3 (ubuntu-latest)
🔇 Additional comments (3)
monai/networks/nets/unet.py (3)

16-16: LGTM: imports are correct and scoped.
cast and checkpoint are needed by the wrapper; no issues.

Also applies to: 20-20


141-141: Public API addition: ensure tests and docs cover new flag.
Add unit tests for parity (on/off), eval bypass, and no_grad() bypass; document in release notes/configs.

Proposed minimal tests:

  • Forward/backward equivalence within tolerance for use_checkpointing={False,True} on a tiny UNet.
  • model.eval() and torch.no_grad() paths skip checkpoint (no checkpoint mocks invoked).
  • Autocast path under CUDA runs without dtype mismatches.

Do you want a patch with pytest tests added under tests/test_unet_checkpointing.py?
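
For instance, the parity case might be sketched like this (hypothetical test; the flag-based constructor from this revision and small illustrative shapes are assumed):

import torch
from monai.networks.nets import UNet


def _build(use_checkpointing: bool) -> UNet:
    torch.manual_seed(0)  # same seed so both models start from identical weights
    return UNet(
        spatial_dims=2,
        in_channels=1,
        out_channels=2,
        channels=(4, 8, 16),
        strides=(2, 2),
        use_checkpointing=use_checkpointing,
    )


def test_forward_backward_parity():
    plain, ckpt = _build(False), _build(True)
    plain.train()
    ckpt.train()
    x = torch.randn(2, 1, 32, 32)
    out_plain, out_ckpt = plain(x), ckpt(x)
    assert torch.allclose(out_plain, out_ckpt, atol=1e-6)
    out_plain.sum().backward()
    out_ckpt.sum().backward()
    for p, q in zip(plain.parameters(), ckpt.parameters()):
        assert torch.allclose(p.grad, q.grad, atol=1e-6)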


170-170: LGTM: stores flag on the instance.
No concerns.

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 0

🧹 Nitpick comments (4)
monai/networks/nets/unet.py (4)

29-51: Warn once for BN, and add param/return docstring per guidelines.

Avoid repeated RuntimeWarnings and document the wrapper’s contract.

 class _ActivationCheckpointWrapper(nn.Module):
-    """Apply activation checkpointing to the wrapped module during training."""
+    """Apply activation checkpointing to the wrapped module during training.
+
+    Args:
+        module: submodule to wrap.
+    Returns:
+        torch.Tensor: output of the wrapped module.
+    Warnings:
+        Skips checkpointing and emits a RuntimeWarning if the submodule contains
+        BatchNorm to avoid double-updating running stats.
+    """
     def __init__(self, module: nn.Module) -> None:
         super().__init__()
         # Pre-detect BatchNorm presence for fast path
         self._has_bn = any(isinstance(m, nn.modules.batchnorm._BatchNorm) for m in module.modules())
         self.module = module
+        self._bn_warned = False
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         if self.training and torch.is_grad_enabled() and x.requires_grad:
             if self._has_bn:
-                warnings.warn(
-                    "Activation checkpointing skipped for a subblock containing BatchNorm to avoid double-updating "
-                    "running statistics during recomputation.",
-                    RuntimeWarning,
-                )
+                if not self._bn_warned:
+                    warnings.warn(
+                        "Activation checkpointing skipped for a subblock containing BatchNorm to avoid double-updating "
+                        "running statistics during recomputation.",
+                        RuntimeWarning,
+                    )
+                    self._bn_warned = True
                 return cast(torch.Tensor, self.module(x))
             try:
                 return cast(torch.Tensor, checkpoint(self.module, x, use_reentrant=False))
             except TypeError:
                 # Fallback for older PyTorch without `use_reentrant`
                 return cast(torch.Tensor, checkpoint(self.module, x))
         return cast(torch.Tensor, self.module(x))

Minimal tests to add:

  • Training vs eval parity (values match with/without checkpointing).
  • BN subblock emits RuntimeWarning and bypasses checkpointing.
  • Guard under torch.no_grad() and when requires_grad=False.

99-101: Clarify use_checkpointing behavior in docs (BN, grad, training-only, build-time).

Make expectations explicit for users.

-        use_checkpointing: if True, apply activation checkpointing to internal sub-blocks during training to reduce memory
-        at the cost of extra compute. Checkpointing is bypassed in eval and when gradients are disabled. Defaults to False.
+        use_checkpointing: if True, apply activation checkpointing to internal sub-blocks during training to reduce
+            memory at the cost of extra compute. Checkpointing is active only when `self.training` is True, gradients
+            are enabled, and inputs require gradients; it is bypassed in eval and when grads are disabled.
+            Sub-blocks containing BatchNorm are not checkpointed to avoid double-updating running stats.
+            Note: this is a build-time option; changing it after initialization will not retroactively wrap existing
+            sub-blocks. Defaults to False.

179-180: Flag is build-time only; consider asserting or documenting.

Changing self.use_checkpointing post-init has no effect since wrappers are created during construction. The doc update above covers this; alternatively, convert to a read-only attribute to avoid confusion.


226-228: Wrapping only the subblock is OK; consider optional coverage toggle.

If desired later, expose an opt-in to also wrap down/up blocks for additional savings.

📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Knowledge Base: Disabled due to Reviews > Disable Knowledge Base setting

📥 Commits

Reviewing files that changed from the base of the PR and between e112457 and f673ca1.

📒 Files selected for processing (1)
  • monai/networks/nets/unet.py (6 hunks)
🧰 Additional context used
📓 Path-based instructions (1)
**/*.py

⚙️ CodeRabbit configuration file

Review the Python code for quality and correctness. Ensure variable names adhere to PEP8 style guides, are sensible and informative in regards to their function, though permitting simple names for loop and comprehension variables. Ensure routine names are meaningful in regards to their function and use verbs, adjectives, and nouns in a semantically appropriate way. Docstrings should be present for all definition which describe each variable, return value, and raised exception in the appropriate section of the Google-style of docstrings. Examine code for logical error or inconsistencies, and suggest what may be changed to addressed these. Suggest any enhancements for code improving efficiency, maintainability, comprehensibility, and correctness. Ensure new or modified definitions will be covered by existing or new unit tests.

Files:

  • monai/networks/nets/unet.py
🔇 Additional comments (2)
monai/networks/nets/unet.py (2)

16-21: Imports look good.

Scoped import of checkpoint plus cast is appropriate.


150-151: API change is sensible and non-breaking.

Parameter added at the end; default preserves behavior.

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 0

♻️ Duplicate comments (1)
monai/networks/nets/unet.py (1)

29-43: Skip checkpointing for subblocks with BatchNorm to avoid double-updating running stats.
Checkpoint recompute updates BN running stats twice under training. Detect BN in the wrapped module and bypass checkpointing with a warning.

Apply this diff:

 class _ActivationCheckpointWrapper(nn.Module):
-    """Apply activation checkpointing to the wrapped module during training."""
+    """Apply activation checkpointing to the wrapped module during training.

+    Skips checkpointing for submodules containing BatchNorm to avoid double-updating
+    running statistics during recomputation.
+    """
     def __init__(self, module: nn.Module) -> None:
         super().__init__()
         self.module = module
+        self._has_bn = any(isinstance(m, nn.modules.batchnorm._BatchNorm) for m in module.modules())

     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        if self.training and torch.is_grad_enabled() and x.requires_grad:
+        if self.training and torch.is_grad_enabled() and x.requires_grad:
+            if self._has_bn:
+                warnings.warn(
+                    "Activation checkpointing skipped for a subblock containing BatchNorm "
+                    "to avoid double-updating running statistics during recomputation.",
+                    RuntimeWarning,
+                )
+                return cast(torch.Tensor, self.module(x))
             try:
                 return cast(torch.Tensor, checkpoint(self.module, x, use_reentrant=False))
             except TypeError:
                 # Fallback for older PyTorch without `use_reentrant`
                 return cast(torch.Tensor, checkpoint(self.module, x))
         return cast(torch.Tensor, self.module(x))
🧹 Nitpick comments (3)
monai/networks/nets/unet.py (3)

90-92: Clarify arg docs and surface BN caveat.
Tighten wording and document BN behavior for transparency.

Apply this diff:

-        use_checkpointing: if True, apply activation checkpointing to internal sub-blocks during training to reduce memory
-        at the cost of extra compute. Checkpointing is bypassed in eval and when gradients are disabled. Defaults to False.
+        use_checkpointing: If True, applies activation checkpointing to internal sub-blocks during training to reduce
+            memory at the cost of extra compute. Bypassed in eval mode and when gradients are disabled.
+            Note: sub-blocks containing BatchNorm are executed without checkpointing to avoid double-updating
+            running statistics. Defaults to False.

217-219: Placement of wrapper is sensible; consider optional breadth control.
Future enhancement: expose a knob to checkpoint down/up paths too for deeper memory savings on very deep nets.


141-142: Add tests to lock behavior.

  • Parity: forward/backward equivalence (outputs/grad norms) with vs. without checkpointing.
  • Modes: train vs. eval; torch.no_grad().
  • Norms: with InstanceNorm and with BatchNorm (assert BN path skips with warning).

I can draft unit tests targeting UNet’s smallest config to keep runtime minimal—want me to open a follow-up?

📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Knowledge Base: Disabled due to Reviews > Disable Knowledge Base setting

📥 Commits

Reviewing files that changed from the base of the PR and between f673ca1 and 69540ff.

📒 Files selected for processing (1)
  • monai/networks/nets/unet.py (6 hunks)
🧰 Additional context used
📓 Path-based instructions (1)
**/*.py

⚙️ CodeRabbit configuration file

Review the Python code for quality and correctness. Ensure variable names adhere to PEP8 style guides, are sensible and informative in regards to their function, though permitting simple names for loop and comprehension variables. Ensure routine names are meaningful in regards to their function and use verbs, adjectives, and nouns in a semantically appropriate way. Docstrings should be present for all definition which describe each variable, return value, and raised exception in the appropriate section of the Google-style of docstrings. Examine code for logical error or inconsistencies, and suggest what may be changed to addressed these. Suggest any enhancements for code improving efficiency, maintainability, comprehensibility, and correctness. Ensure new or modified definitions will be covered by existing or new unit tests.

Files:

  • monai/networks/nets/unet.py
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (19)
  • GitHub Check: build-docs
  • GitHub Check: packaging
  • GitHub Check: flake8-py3 (pytype)
  • GitHub Check: quick-py3 (windows-latest)
  • GitHub Check: quick-py3 (macOS-latest)
  • GitHub Check: quick-py3 (ubuntu-latest)
  • GitHub Check: flake8-py3 (mypy)
  • GitHub Check: flake8-py3 (codeformat)
  • GitHub Check: min-dep-os (ubuntu-latest)
  • GitHub Check: min-dep-pytorch (2.8.0)
  • GitHub Check: min-dep-os (macOS-latest)
  • GitHub Check: min-dep-os (windows-latest)
  • GitHub Check: min-dep-pytorch (2.6.0)
  • GitHub Check: min-dep-py3 (3.9)
  • GitHub Check: min-dep-py3 (3.12)
  • GitHub Check: min-dep-pytorch (2.5.1)
  • GitHub Check: min-dep-pytorch (2.7.1)
  • GitHub Check: min-dep-py3 (3.11)
  • GitHub Check: min-dep-py3 (3.10)
🔇 Additional comments (3)
monai/networks/nets/unet.py (3)

16-21: LGTM: imports for cast/checkpoint are correct.
Direct import of checkpoint and use of typing.cast are appropriate.


35-42: Validate AMP behavior under fallback (reentrant) checkpointing.
Older Torch (fallback path) may not replay autocast exactly; please verify mixed-precision parity.

Minimal check: run a forward/backward with torch.autocast and compare loss/grad norms with and without checkpointing on a small UNet to ensure deltas are within numerical noise.
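
A rough version of that check (a sketch only; it assumes a CUDA device and the flag-based constructor at this stage of the PR):

import torch
from monai.networks.nets import UNet


def _grad_norm(model: torch.nn.Module) -> torch.Tensor:
    return torch.stack([p.grad.norm() for p in model.parameters() if p.grad is not None]).norm()


def amp_parity_check() -> None:
    results = {}
    for flag in (False, True):
        torch.manual_seed(0)
        model = UNet(
            spatial_dims=2, in_channels=1, out_channels=2, channels=(4, 8, 16), strides=(2, 2),
            use_checkpointing=flag,
        ).cuda().train()
        x = torch.randn(2, 1, 32, 32, device="cuda")
        with torch.autocast(device_type="cuda", dtype=torch.float16):
            loss = model(x).float().sum()
        loss.backward()
        results[flag] = (loss.detach(), _grad_norm(model))
    # with and without checkpointing, deltas should stay within numerical noise
    assert torch.allclose(results[False][0], results[True][0], rtol=1e-3)
    assert torch.allclose(results[False][1], results[True][1], rtol=1e-3)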


141-142: API addition looks good.
Name and default match MONAI conventions.

@ericspod
Member

Hi @ferreirafabio80, thanks for the contribution, but I would suggest this isn't necessarily the way to go with adapting this class. Perhaps instead you can create a subclass of UNet and override the method:

class CheckpointUNet(UNet):
    def _get_connection_block(self, down_path: nn.Module, up_path: nn.Module, subblock: nn.Module) -> nn.Module:
        subblock = _ActivationCheckpointWrapper(subblock)
        return super()._get_connection_block(down_path, up_path, subblock)

This would suffice for your own use if you just wanted such a definition. I think the _ActivationCheckpointWrapper class may be a good thing to add to the blocks submodule instead, so it should have a public name like CheckpointWrapper.

I see also that checkpoint is used elsewhere in MONAI already like here without the checks for training and gradient that you have in your class, so I wonder if these are needed at all?

@ferreirafabio80
Author

Hi @ericspod, thank you for your comments.

Yes, that also works. I've defined a subclass and overridden the method as you suggested.

Regarding the _ActivationCheckpointWrapper class, should I create a new script in the blocks submodule or add it to an existing one?

I was probably being extremely careful with the checks in _ActivationCheckpointWrapper, but I agree we can drop the checks.

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 5

♻️ Duplicate comments (2)
monai/networks/nets/unet.py (2)

35-36: Missing training and gradient guards causes eval overhead and no_grad crashes.

The forward unconditionally calls checkpoint. This will:

  1. Apply checkpointing during inference (eval mode) → unnecessary compute overhead.
  2. Fail under torch.no_grad() → runtime error.

Apply this diff:

 def forward(self, x: torch.Tensor) -> torch.Tensor:
-    return cast(torch.Tensor, checkpoint(self.module, x, use_reentrant=False))
+    if self.training and torch.is_grad_enabled() and x.requires_grad:
+        return cast(torch.Tensor, checkpoint(self.module, x, use_reentrant=False))
+    return cast(torch.Tensor, self.module(x))

29-37: BatchNorm in checkpointed subblocks will double-update running stats.

Checkpoint recomputes the forward pass during backward, causing BatchNorm layers to update running_mean/running_var twice per training step, skewing statistics.

Consider detecting BatchNorm in __init__ and either warning or skipping checkpoint:

 class _ActivationCheckpointWrapper(nn.Module):
     """Apply activation checkpointing to the wrapped module during training."""
     def __init__(self, module: nn.Module) -> None:
         super().__init__()
         self.module = module
+        self._has_bn = any(isinstance(m, nn.modules.batchnorm._BatchNorm) for m in module.modules())
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        return cast(torch.Tensor, checkpoint(self.module, x, use_reentrant=False))
+        if self.training and torch.is_grad_enabled() and x.requires_grad:
+            if self._has_bn:
+                warnings.warn(
+                    "Activation checkpointing skipped for subblock with BatchNorm to avoid double-update of running stats.",
+                    RuntimeWarning,
+                )
+                return cast(torch.Tensor, self.module(x))
+            return cast(torch.Tensor, checkpoint(self.module, x, use_reentrant=False))
+        return cast(torch.Tensor, self.module(x))
📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Knowledge base: Disabled due to Reviews -> Disable Knowledge Base setting

📥 Commits

Reviewing files that changed from the base of the PR and between 69540ff and 42ec757.

📒 Files selected for processing (1)
  • monai/networks/nets/unet.py (4 hunks)
🧰 Additional context used
📓 Path-based instructions (1)
**/*.py

⚙️ CodeRabbit configuration file

Review the Python code for quality and correctness. Ensure variable names adhere to PEP8 style guides, are sensible and informative in regards to their function, though permitting simple names for loop and comprehension variables. Ensure routine names are meaningful in regards to their function and use verbs, adjectives, and nouns in a semantically appropriate way. Docstrings should be present for all definition which describe each variable, return value, and raised exception in the appropriate section of the Google-style of docstrings. Examine code for logical error or inconsistencies, and suggest what may be changed to addressed these. Suggest any enhancements for code improving efficiency, maintainability, comprehensibility, and correctness. Ensure new or modified definitions will be covered by existing or new unit tests.

Files:

  • monai/networks/nets/unet.py
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (19)
  • GitHub Check: flake8-py3 (pytype)
  • GitHub Check: quick-py3 (macOS-latest)
  • GitHub Check: build-docs
  • GitHub Check: packaging
  • GitHub Check: quick-py3 (ubuntu-latest)
  • GitHub Check: min-dep-py3 (3.11)
  • GitHub Check: min-dep-py3 (3.12)
  • GitHub Check: min-dep-pytorch (2.5.1)
  • GitHub Check: flake8-py3 (codeformat)
  • GitHub Check: quick-py3 (windows-latest)
  • GitHub Check: min-dep-os (macOS-latest)
  • GitHub Check: min-dep-pytorch (2.6.0)
  • GitHub Check: flake8-py3 (mypy)
  • GitHub Check: min-dep-pytorch (2.8.0)
  • GitHub Check: min-dep-os (ubuntu-latest)
  • GitHub Check: min-dep-pytorch (2.7.1)
  • GitHub Check: min-dep-py3 (3.9)
  • GitHub Check: min-dep-py3 (3.10)
  • GitHub Check: min-dep-os (windows-latest)
🔇 Additional comments (2)
monai/networks/nets/unet.py (2)

16-16: LGTM on imports.

Both cast and checkpoint are used in the new wrapper and are correctly imported.

Also applies to: 20-20


316-316: Clarify checkpointing scope: only subblock vs. entire connection block.

Only subblock (the recursive nested structure) is wrapped, while down_path and up_path (encoder/decoder convolutions at each level) are not checkpointed. Is this intentional?

Typical UNet checkpointing strategies checkpoint entire encoder/decoder blocks for maximum memory savings. Consider whether down_path and up_path should also be wrapped, or document the rationale for checkpointing only the recursive substructure.

# Alternative: checkpoint all three components
def _get_connection_block(self, down_path: nn.Module, up_path: nn.Module, subblock: nn.Module) -> nn.Module:
    down_path = _ActivationCheckpointWrapper(down_path)
    up_path = _ActivationCheckpointWrapper(up_path)
    subblock = _ActivationCheckpointWrapper(subblock)
    return super()._get_connection_block(down_path, up_path, subblock)

Comment on lines 29 to 37
class _ActivationCheckpointWrapper(nn.Module):
    """Apply activation checkpointing to the wrapped module during training."""
    def __init__(self, module: nn.Module) -> None:
        super().__init__()
        self.module = module

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return cast(torch.Tensor, checkpoint(self.module, x, use_reentrant=False))


🛠️ Refactor suggestion | 🟠 Major

Add comprehensive Google-style docstring.

Per coding guidelines, document all parameters, return values, and behavior.

As per coding guidelines.

Apply this diff:

 class _ActivationCheckpointWrapper(nn.Module):
-    """Apply activation checkpointing to the wrapped module during training."""
+    """
+    Wrapper applying activation checkpointing to a module during training.
+    
+    During the forward pass in training mode, intermediate activations are not stored;
+    they are recomputed during the backward pass to reduce memory usage.
+    
+    Args:
+        module: The module to wrap with activation checkpointing.
+    """
     def __init__(self, module: nn.Module) -> None:
         super().__init__()
         self.module = module
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """
+        Forward pass with optional activation checkpointing.
+        
+        Args:
+            x: Input tensor.
+            
+        Returns:
+            Output tensor from the wrapped module.
+        """
         return cast(torch.Tensor, checkpoint(self.module, x, use_reentrant=False))
🤖 Prompt for AI Agents
In monai/networks/nets/unet.py around lines 29 to 37, the
_ActivationCheckpointWrapper class and its forward method lack a comprehensive
Google-style docstring; add a docstring for the class describing its purpose
(applies activation checkpointing to a wrapped nn.Module during training),
document the constructor parameter module (type and role) and any stored
attributes, and add a Google-style docstring on forward documenting Args (x:
torch.Tensor — shape/expected dtype/context), Returns (torch.Tensor — shape and
that it is the wrapped module's output), and Behavior/Raises (explain that
checkpoint(...) is used with use_reentrant=False, when it will be active—only
during forward in training—and any exceptions propagated from the wrapped
module); keep wording concise and follow Google docstring sections (Args,
Returns, Raises, and a short Example/Notes if useful).

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 3

♻️ Duplicate comments (2)
monai/networks/nets/unet.py (2)

26-26: Export CheckpointUNet via __all__.

Public class isn’t exported; importing it from monai.networks.nets will fail. Add it to __all__.

-__all__ = ["UNet", "Unet"]
+__all__ = ["UNet", "Unet", "CheckpointUNet"]

36-37: Guard checkpoint under grad-enabled training only.

Checkpointing unconditionally in forward will crash under torch.no_grad() (common in eval) because checkpoint requires gradients. Gate on self.training, torch.is_grad_enabled(), and a grad-requiring input before calling checkpoint.

     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        return cast(torch.Tensor, checkpoint(self.module, x, use_reentrant=False))
+        if self.training and torch.is_grad_enabled() and x.requires_grad:
+            return cast(torch.Tensor, checkpoint(self.module, x, use_reentrant=False))
+        return cast(torch.Tensor, self.module(x))
📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Knowledge base: Disabled due to Reviews -> Disable Knowledge Base setting

📥 Commits

Reviewing files that changed from the base of the PR and between a2e8474 and 4c4782e.

📒 Files selected for processing (1)
  • monai/networks/nets/unet.py (3 hunks)
🧰 Additional context used
📓 Path-based instructions (1)
**/*.py

⚙️ CodeRabbit configuration file

Review the Python code for quality and correctness. Ensure variable names adhere to PEP8 style guides, are sensible and informative in regards to their function, though permitting simple names for loop and comprehension variables. Ensure routine names are meaningful in regards to their function and use verbs, adjectives, and nouns in a semantically appropriate way. Docstrings should be present for all definition which describe each variable, return value, and raised exception in the appropriate section of the Google-style of docstrings. Examine code for logical error or inconsistencies, and suggest what may be changed to addressed these. Suggest any enhancements for code improving efficiency, maintainability, comprehensibility, and correctness. Ensure new or modified definitions will be covered by existing or new unit tests.

Files:

  • monai/networks/nets/unet.py
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (19)
  • GitHub Check: min-dep-py3 (3.10)
  • GitHub Check: min-dep-pytorch (2.7.1)
  • GitHub Check: min-dep-py3 (3.11)
  • GitHub Check: min-dep-pytorch (2.5.1)
  • GitHub Check: min-dep-os (ubuntu-latest)
  • GitHub Check: min-dep-pytorch (2.8.0)
  • GitHub Check: min-dep-py3 (3.12)
  • GitHub Check: min-dep-os (macOS-latest)
  • GitHub Check: min-dep-py3 (3.9)
  • GitHub Check: min-dep-pytorch (2.6.0)
  • GitHub Check: min-dep-os (windows-latest)
  • GitHub Check: quick-py3 (macOS-latest)
  • GitHub Check: packaging
  • GitHub Check: quick-py3 (ubuntu-latest)
  • GitHub Check: quick-py3 (windows-latest)
  • GitHub Check: flake8-py3 (codeformat)
  • GitHub Check: flake8-py3 (mypy)
  • GitHub Check: flake8-py3 (pytype)
  • GitHub Check: build-docs

Comment on lines 29 to 37
class _ActivationCheckpointWrapper(nn.Module):
    """Apply activation checkpointing to the wrapped module during training."""

    def __init__(self, module: nn.Module) -> None:
        super().__init__()
        self.module = module

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return cast(torch.Tensor, checkpoint(self.module, x, use_reentrant=False))

🛠️ Refactor suggestion | 🟠 Major

Add Google-style docstrings.

Class and forward docstrings need Args/Returns sections per guidelines. Document the wrapped module, checkpoint guard details, and returned tensor.

As per coding guidelines.

🤖 Prompt for AI Agents
In monai/networks/nets/unet.py around lines 29 to 37, the
_ActivationCheckpointWrapper class and its forward method lack Google-style
docstrings; add a class-level docstring that briefly describes purpose, document
the module parameter as "module: nn.Module — module to wrap for activation
checkpointing", mention that checkpointing is applied during training to save
memory and that use_reentrant=False is used as the checkpoint guard, and add a
forward method docstring with Args: x (torch.Tensor): input tensor to the
wrapped module and Returns: torch.Tensor: output tensor from the wrapped module
(with activations checkpointed); keep wording concise and follow Google-style
"Args/Returns" formatting.

@ericspod
Member

Regarding the _ActivationCheckpointWrapper class, should I create a new script in the blocks submodule or add it to an existing one?

Sorry for the delay. I think we should put this into a new file in the monai/networks/blocks directory since it doesn't really go anywhere else. You can then give it a good name and docstring comments, plus whatever else CodeRabbit has said that's reasonable. Thanks!

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 2

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
monai/networks/nets/unet.py (1)

25-25: Export the new public variant.
Expose CheckpointUNet alongside UNet/Unet so from monai.networks.nets import CheckpointUNet works consistently.

-__all__ = ["UNet", "Unet"]
+__all__ = ["UNet", "Unet", "CheckpointUNet"]
♻️ Duplicate comments (1)
monai/networks/nets/unet.py (1)

302-307: Docstring required for CheckpointUNet.
Add a Google-style class docstring describing the checkpointing behavior, inherited args, and trade-offs per our guidelines.

 class CheckpointUNet(UNet):
+    """
+    UNet variant that wraps internal connection blocks with activation checkpointing.
+
+    See `UNet` for constructor arguments. During training with gradients enabled,
+    intermediate activations inside encoder–decoder connections are recomputed in
+    the backward pass to reduce peak memory usage at the cost of extra compute.
+    """
📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Cache: Disabled due to data retention organization setting

Knowledge base: Disabled due to Reviews -> Disable Knowledge Base setting

📥 Commits

Reviewing files that changed from the base of the PR and between 515c659 and da5a3a4.

📒 Files selected for processing (2)
  • monai/networks/blocks/activation_checkpointing.py (1 hunks)
  • monai/networks/nets/unet.py (2 hunks)
🧰 Additional context used
📓 Path-based instructions (1)
**/*.py

⚙️ CodeRabbit configuration file

Review the Python code for quality and correctness. Ensure variable names adhere to PEP8 style guides, are sensible and informative in regards to their function, though permitting simple names for loop and comprehension variables. Ensure routine names are meaningful in regards to their function and use verbs, adjectives, and nouns in a semantically appropriate way. Docstrings should be present for all definition which describe each variable, return value, and raised exception in the appropriate section of the Google-style of docstrings. Examine code for logical error or inconsistencies, and suggest what may be changed to addressed these. Suggest any enhancements for code improving efficiency, maintainability, comprehensibility, and correctness. Ensure new or modified definitions will be covered by existing or new unit tests.

Files:

  • monai/networks/blocks/activation_checkpointing.py
  • monai/networks/nets/unet.py
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (19)
  • GitHub Check: quick-py3 (ubuntu-latest)
  • GitHub Check: quick-py3 (windows-latest)
  • GitHub Check: quick-py3 (macOS-latest)
  • GitHub Check: packaging
  • GitHub Check: build-docs
  • GitHub Check: min-dep-py3 (3.9)
  • GitHub Check: flake8-py3 (codeformat)
  • GitHub Check: flake8-py3 (mypy)
  • GitHub Check: flake8-py3 (pytype)
  • GitHub Check: min-dep-py3 (3.10)
  • GitHub Check: min-dep-py3 (3.11)
  • GitHub Check: min-dep-pytorch (2.7.1)
  • GitHub Check: min-dep-py3 (3.12)
  • GitHub Check: min-dep-os (windows-latest)
  • GitHub Check: min-dep-pytorch (2.5.1)
  • GitHub Check: min-dep-os (macOS-latest)
  • GitHub Check: min-dep-pytorch (2.8.0)
  • GitHub Check: min-dep-os (ubuntu-latest)
  • GitHub Check: min-dep-pytorch (2.6.0)

Comment on lines +32 to +41
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass with optional activation checkpointing.

        Args:
            x: Input tensor.

        Returns:
            Output tensor from the wrapped module.
        """
        return cast(torch.Tensor, checkpoint(self.module, x, use_reentrant=False))

⚠️ Potential issue | 🟠 Major

Gate checkpointing to active training passes.
The docstring promises training-only checkpointing, but forward always recomputes, so eval/no-grad still pays the checkpoint dispatch. Wrap the call with self.training, torch.is_grad_enabled(), and an x.requires_grad check, falling back to the plain module call otherwise, to avoid needless recompute overhead while preserving the memory trade-off during training.(docs.pytorch.org)

     def forward(self, x: torch.Tensor) -> torch.Tensor:
         """Forward pass with optional activation checkpointing.
 
         Args:
             x: Input tensor.
 
         Returns:
             Output tensor from the wrapped module.
         """
-        return cast(torch.Tensor, checkpoint(self.module, x, use_reentrant=False))
+        if self.training and torch.is_grad_enabled() and x.requires_grad:
+            return cast(torch.Tensor, checkpoint(self.module, x, use_reentrant=False))
+        return cast(torch.Tensor, self.module(x))
🤖 Prompt for AI Agents
In monai/networks/blocks/activation_checkpointing.py around lines 32 to 41,
forward always calls checkpoint(self.module, x, use_reentrant=False) even during
eval/no-grad, causing unnecessary recompute; change it to only use
torch.utils.checkpoint when running training and gradients are enabled: check
self.training and torch.is_grad_enabled() and that input tensor x.requires_grad
before calling checkpoint(..., use_reentrant=False); otherwise call and return
self.module(x) directly to avoid unnecessary checkpoint overhead while
preserving training memory savings.

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 3

🧹 Nitpick comments (1)
monai/networks/nets/unet.py (1)

25-25: Sort __all__ alphabetically.

-__all__ = ["UNet", "Unet", "CheckpointUNet"]
+__all__ = ["CheckpointUNet", "UNet", "Unet"]
📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Cache: Disabled due to data retention organization setting

Knowledge base: Disabled due to Reviews -> Disable Knowledge Base setting

📥 Commits

Reviewing files that changed from the base of the PR and between da5a3a4 and 43dec88.

📒 Files selected for processing (1)
  • monai/networks/nets/unet.py (2 hunks)
🧰 Additional context used
📓 Path-based instructions (1)
**/*.py

⚙️ CodeRabbit configuration file

Review the Python code for quality and correctness. Ensure variable names adhere to PEP8 style guides, are sensible and informative in regards to their function, though permitting simple names for loop and comprehension variables. Ensure routine names are meaningful in regards to their function and use verbs, adjectives, and nouns in a semantically appropriate way. Docstrings should be present for all definition which describe each variable, return value, and raised exception in the appropriate section of the Google-style of docstrings. Examine code for logical error or inconsistencies, and suggest what may be changed to addressed these. Suggest any enhancements for code improving efficiency, maintainability, comprehensibility, and correctness. Ensure new or modified definitions will be covered by existing or new unit tests.

Files:

  • monai/networks/nets/unet.py
🪛 Ruff (0.14.3)
monai/networks/nets/unet.py

25-25: __all__ is not sorted

Apply an isort-style sorting to __all__

(RUF022)


306-306: Docstring contains ambiguous – (EN DASH). Did you mean - (HYPHEN-MINUS)?

(RUF002)

🔇 Additional comments (1)
monai/networks/nets/unet.py (1)

20-20: LGTM.

Import is clean and the wrapper is properly used in CheckpointUNet.

Comment on lines +302 to +308
class CheckpointUNet(UNet):
    """UNet variant that wraps internal connection blocks with activation checkpointing.

    See `UNet` for constructor arguments. During training with gradients enabled,
    intermediate activations inside encoder–decoder connections are recomputed in
    the backward pass to reduce peak memory usage at the cost of extra compute.
    """

⚠️ Potential issue | 🟡 Minor

Replace EN DASH with regular hyphen.

Line 306 contains – (EN DASH) instead of - (HYPHEN-MINUS).

-    intermediate activations inside encoder–decoder connections are recomputed in
+    intermediate activations inside encoder-decoder connections are recomputed in
🧰 Tools
🪛 Ruff (0.14.3)

306-306: Docstring contains ambiguous – (EN DASH). Did you mean - (HYPHEN-MINUS)?

(RUF002)

🤖 Prompt for AI Agents
In monai/networks/nets/unet.py around lines 302 to 308, the class docstring
contains an EN DASH (`–`) on line 306 used in "encoder–decoder"; replace that
character with a regular hyphen-minus (`-`) so it reads "encoder-decoder".
Update the docstring text only, keeping punctuation and spacing consistent.

Comment on lines +302 to +314
class CheckpointUNet(UNet):
    """UNet variant that wraps internal connection blocks with activation checkpointing.

    See `UNet` for constructor arguments. During training with gradients enabled,
    intermediate activations inside encoder–decoder connections are recomputed in
    the backward pass to reduce peak memory usage at the cost of extra compute.
    """

    def _get_connection_block(self, down_path: nn.Module, up_path: nn.Module, subblock: nn.Module) -> nn.Module:
        subblock = ActivationCheckpointWrapper(subblock)
        down_path = ActivationCheckpointWrapper(down_path)
        up_path = ActivationCheckpointWrapper(up_path)
        return super()._get_connection_block(down_path, up_path, subblock)

⚠️ Potential issue | 🔴 Critical

🧩 Analysis chain

Verify unit test coverage for CheckpointUNet.

Past review comments requested tests that confirm: (a) forward parity with UNet in eval/no-grad mode, (b) checkpointing engages only during training.


🏁 Script executed:

#!/bin/bash
# Search for test files covering CheckpointUNet
rg -n --type=py -C3 'CheckpointUNet|CheckpointUnet|checkpointunet' -g '*test*.py'

Length of output: 0



---

🏁 Script executed:

```shell
#!/bin/bash
# Broader search for test files in common test directories
find . -type f -name "*test*.py" -path "*/test*" | head -20

Length of output: 1066


🏁 Script executed:

#!/bin/bash
# Search for CheckpointUNet anywhere in the codebase (not just test files)
rg -n 'CheckpointUNet' --type=py

Length of output: 135


🏁 Script executed:

#!/bin/bash
# Check for UNet tests
rg -n 'class.*UNet.*Test|def test.*unet' --type=py -i

Length of output: 3075



---

🏁 Script executed:

```shell
#!/bin/bash
# Inspect test_unet.py for any CheckpointUNet references
rg -n 'CheckpointUNet|checkpoint' tests/networks/nets/test_unet.py -i

Length of output: 0


Add unit tests for CheckpointUNet to validate wrapper behavior.

CheckpointUNet has zero test coverage. Tests must confirm: (a) forward pass produces identical results to parent UNet in eval mode or with torch.no_grad(), and (b) activation checkpointing engages only during training with gradients enabled.

🧰 Tools
🪛 Ruff (0.14.3)

306-306: Docstring contains ambiguous (EN DASH). Did you mean - (HYPHEN-MINUS)?

(RUF002)

🤖 Prompt for AI Agents
In monai/networks/nets/unet.py around lines 302 to 314, add unit tests (e.g.,
monai/networks/nets/test_unet_checkpoint.py) that create a UNet and a
CheckpointUNet with identical seeded initialization and the same input tensor;
assert their forward outputs are equal in eval mode and when wrapped with
torch.no_grad(); then verify activation checkpointing is active only during
training with gradients by monkeypatching or wrapping
ActivationCheckpointWrapper.forward to count invocations: run a training
forward+backward (output.sum().backward()) with requires_grad enabled and assert
the wrapper.forward is invoked more than once (indicating recomputation), and
run the same in eval or torch.no_grad() and assert it is invoked exactly once.
Ensure deterministic seeding, zero gradients between runs, and use
torch.allclose with a tight tolerance for output comparisons.
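
A simpler hedged sketch of such a test, focusing on parity rather than invocation counting (hypothetical file tests/networks/nets/test_checkpoint_unet.py; shapes and tolerances are illustrative):

import torch
from monai.networks.nets import CheckpointUNet, UNet


def _build(cls):
    torch.manual_seed(0)  # identical seeding so both networks share initial weights
    return cls(spatial_dims=2, in_channels=1, out_channels=2, channels=(4, 8, 16), strides=(2, 2))


def test_checkpoint_unet_matches_unet():
    plain, ckpt = _build(UNet), _build(CheckpointUNet)
    x = torch.randn(1, 1, 32, 32)

    # (a) forward parity in eval / no-grad mode
    plain.eval()
    ckpt.eval()
    with torch.no_grad():
        assert torch.allclose(plain(x), ckpt(x), atol=1e-6)

    # (b) training: outputs and gradients should still agree when checkpointing engages
    plain.train()
    ckpt.train()
    out_plain, out_ckpt = plain(x), ckpt(x)
    assert torch.allclose(out_plain, out_ckpt, atol=1e-6)
    out_plain.sum().backward()
    out_ckpt.sum().backward()
    for p, q in zip(plain.parameters(), ckpt.parameters()):
        assert torch.allclose(p.grad, q.grad, atol=1e-6)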

Comment on lines +310 to +314
    def _get_connection_block(self, down_path: nn.Module, up_path: nn.Module, subblock: nn.Module) -> nn.Module:
        subblock = ActivationCheckpointWrapper(subblock)
        down_path = ActivationCheckpointWrapper(down_path)
        up_path = ActivationCheckpointWrapper(up_path)
        return super()._get_connection_block(down_path, up_path, subblock)

🛠️ Refactor suggestion | 🟠 Major

Add Google-style docstring to overridden method.

Per coding guidelines, all definitions require docstrings with Args/Returns sections.

As per coding guidelines.

     def _get_connection_block(self, down_path: nn.Module, up_path: nn.Module, subblock: nn.Module) -> nn.Module:
+        """
+        Returns connection block with activation checkpointing applied to all components.
+        
+        Args:
+            down_path: encoding half of the layer (will be wrapped with checkpointing).
+            up_path: decoding half of the layer (will be wrapped with checkpointing).
+            subblock: block defining the next layer (will be wrapped with checkpointing).
+            
+        Returns:
+            Connection block with all components wrapped for activation checkpointing.
+        """
         subblock = ActivationCheckpointWrapper(subblock)
🤖 Prompt for AI Agents
In monai/networks/nets/unet.py around lines 310 to 314, the overridden
_get_connection_block method is missing a Google-style docstring; add a
docstring immediately above the def that follows Google style with a short
summary line, an Args section documenting down_path (nn.Module): the down path
module, up_path (nn.Module): the up path module, and subblock (nn.Module): the
connecting subblock, and a Returns section documenting nn.Module: the connection
block returned (note the method wraps the three inputs with
ActivationCheckpointWrapper and delegates to super()._get_connection_block);
keep wording concise and include types for each parameter and the return.
