feat: add activation checkpointing to unet #8554

Base: dev
New file: monai/networks/blocks/activation_checkpointing.py (@@ -0,0 +1,41 @@)

```python
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from typing import cast

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class ActivationCheckpointWrapper(nn.Module):
    """Wrapper applying activation checkpointing to a module during training.
    Args:
        module: The module to wrap with activation checkpointing.
    """

    def __init__(self, module: nn.Module) -> None:
        super().__init__()
        self.module = module

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass with optional activation checkpointing.
        Args:
            x: Input tensor.
        Returns:
            Output tensor from the wrapped module.
        """
        return cast(torch.Tensor, checkpoint(self.module, x, use_reentrant=False))
```
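For context on what the wrapper does: `torch.utils.checkpoint.checkpoint` runs the module without caching its intermediate activations and re-runs the forward during backward to rebuild them. A minimal standalone sketch (the block and shapes here are illustrative, not from the PR):

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

# Illustrative block; any nn.Module taking a single tensor works.
block = nn.Sequential(
    nn.Conv2d(1, 8, 3, padding=1),
    nn.ReLU(),
    nn.Conv2d(8, 1, 3, padding=1),
)
x = torch.randn(2, 1, 32, 32, requires_grad=True)

# Forward: the block's internal activations are not kept for backward.
y = checkpoint(block, x, use_reentrant=False)

# Backward: the block's forward is recomputed to obtain gradients,
# trading extra compute for lower peak memory.
y.sum().backward()
```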
File: monai/networks/nets/unet.py

```diff
@@ -17,11 +17,12 @@
 import torch
 import torch.nn as nn
 
+from monai.networks.blocks.activation_checkpointing import ActivationCheckpointWrapper
 from monai.networks.blocks.convolutions import Convolution, ResidualUnit
 from monai.networks.layers.factories import Act, Norm
 from monai.networks.layers.simplelayers import SkipConnection
 
-__all__ = ["UNet", "Unet"]
+__all__ = ["UNet", "Unet", "CheckpointUNet"]
 
 
 class UNet(nn.Module):
@@ -298,4 +299,19 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         return x
 
 
+class CheckpointUNet(UNet):
+    """UNet variant that wraps internal connection blocks with activation checkpointing.
+
+    See `UNet` for constructor arguments. During training with gradients enabled,
+    intermediate activations inside encoder–decoder connections are recomputed in
+    the backward pass to reduce peak memory usage at the cost of extra compute.
+    """
```
|
Comment on lines
+302
to
+308
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Replace EN DASH with regular hyphen. Line 306 contains - intermediate activations inside encoder–decoder connections are recomputed in
+ intermediate activations inside encoder-decoder connections are recomputed in📝 Committable suggestion
Suggested change
🧰 Tools🪛 Ruff (0.14.3)306-306: Docstring contains ambiguous (RUF002) 🤖 Prompt for AI Agents |
```diff
+
+    def _get_connection_block(self, down_path: nn.Module, up_path: nn.Module, subblock: nn.Module) -> nn.Module:
+        subblock = ActivationCheckpointWrapper(subblock)
+        down_path = ActivationCheckpointWrapper(down_path)
+        up_path = ActivationCheckpointWrapper(up_path)
+        return super()._get_connection_block(down_path, up_path, subblock)
```
Review comment on lines +302 to +314:

🧩 Analysis chain: Verify unit test coverage for `CheckpointUNet`. Past review comments requested tests confirming, among other things, forward parity with `UNet`.

Verification scripts run against the branch:

```bash
# Search for test files covering CheckpointUNet
rg -n --type=py -C3 'CheckpointUNet|CheckpointUnet|checkpointunet' -g '*test*.py'
# Length of output: 0

# Search for CheckpointUNet anywhere in the codebase (not just test files)
rg -n 'CheckpointUNet' --type=py
# Length of output: 135

# Check for existing UNet tests
rg -n 'class.*UNet.*Test|def test.*unet' --type=py -i
# Length of output: 3075
```

Conclusion: add unit tests. `CheckpointUNet` has zero test coverage; tests must confirm, at minimum, that its forward pass produces results identical to the parent `UNet`.
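A minimal parity test along the lines the reviewer asks for might look like the sketch below. It assumes that seeding the RNG before construction yields identical weights, since both classes build the same layers in the same order and `ActivationCheckpointWrapper` adds no parameters of its own; the test name and argument values are illustrative.

```python
import torch
from monai.networks.nets import UNet
from monai.networks.nets.unet import CheckpointUNet  # added by this PR


def test_checkpoint_unet_forward_parity():
    kwargs = dict(spatial_dims=2, in_channels=1, out_channels=2, channels=(8, 16, 32), strides=(2, 2))
    # Same seed, same construction order: both networks start from identical weights.
    torch.manual_seed(0)
    plain = UNet(**kwargs)
    torch.manual_seed(0)
    checkpointed = CheckpointUNet(**kwargs)
    plain.eval()
    checkpointed.eval()
    x = torch.randn(1, 1, 32, 32)
    with torch.no_grad():
        assert torch.allclose(plain(x), checkpointed(x), atol=1e-6)
```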
Review comment on lines +310 to +314 (🛠️ Refactor suggestion | 🟠 Major):

Add a Google-style docstring to the overridden method. Per the coding guidelines, all definitions require docstrings with Args/Returns sections.

```diff
     def _get_connection_block(self, down_path: nn.Module, up_path: nn.Module, subblock: nn.Module) -> nn.Module:
+        """
+        Returns connection block with activation checkpointing applied to all components.
+
+        Args:
+            down_path: encoding half of the layer (will be wrapped with checkpointing).
+            up_path: decoding half of the layer (will be wrapped with checkpointing).
+            subblock: block defining the next layer (will be wrapped with checkpointing).
+
+        Returns:
+            Connection block with all components wrapped for activation checkpointing.
+        """
         subblock = ActivationCheckpointWrapper(subblock)
```
```diff
+
+
 Unet = UNet
```
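For illustration, a usage sketch under the assumption that `CheckpointUNet` is constructed exactly like `UNet`, as its docstring states; the argument values here are arbitrary:

```python
import torch
from monai.networks.nets.unet import CheckpointUNet

model = CheckpointUNet(
    spatial_dims=3,
    in_channels=1,
    out_channels=2,
    channels=(16, 32, 64, 128),
    strides=(2, 2, 2),
)
model.train()  # checkpointing trades compute for memory during training

x = torch.randn(1, 1, 64, 64, 64)
y = model(x)        # connection blocks discard intermediate activations
y.sum().backward()  # ...and recompute them here, lowering peak memory
```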
Review comment on `ActivationCheckpointWrapper.forward`:

Gate checkpointing to active training passes. The docstring promises training-only checkpointing, but `forward` always recomputes, so eval/no-grad still pays the checkpoint dispatch. Wrap the call with `self.training`, `torch.is_grad_enabled()`, and an `x.requires_grad` check, falling back to the plain module call otherwise, to avoid needless recompute overhead while preserving the memory trade-off during training. (docs.pytorch.org)

```diff
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         """Forward pass with optional activation checkpointing.
         Args:
             x: Input tensor.
         Returns:
             Output tensor from the wrapped module.
         """
-        return cast(torch.Tensor, checkpoint(self.module, x, use_reentrant=False))
+        if self.training and torch.is_grad_enabled() and x.requires_grad:
+            return cast(torch.Tensor, checkpoint(self.module, x, use_reentrant=False))
+        return cast(torch.Tensor, self.module(x))
```
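A quick standalone illustration of the three conditions that gate would check (a sketch, not part of the PR):

```python
import torch
import torch.nn as nn

module = nn.Linear(4, 4)
x = torch.randn(2, 4)

module.train()
print(module.training)           # True: module is in a training phase
print(torch.is_grad_enabled())   # True outside no_grad/inference_mode contexts
print(x.requires_grad)           # False here, so the gate would skip checkpointing

with torch.no_grad():
    print(torch.is_grad_enabled())  # False: recomputing activations would buy nothing
```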