Skip to content

Commit 23c271e

Browse files
Authored by ferreirafabio80 (Fabio Ferreira), with co-authors coderabbitai[bot], pre-commit-ci[bot], and KumoLiu (YunLiu)
feat: add activation checkpointing to unet (#8554)
### Description Introduces an optional `use_checkpointing` flag in the `UNet` implementation. When enabled, intermediate activations in the encoder–decoder blocks are recomputed during the backward pass instead of being stored in memory. - Implemented via a lightweight `_ActivationCheckpointWrapper` wrapper around sub-blocks. - Checkpointing is only applied during training to avoid overhead at inference. ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [x] New tests added to cover the changes. - [x] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [x] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [x] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: Fábio S. Ferreira <ferreira.fabio80@gmail.com> Signed-off-by: Fabio Ferreira <f.ferreira@qureight.com> Co-authored-by: Fabio Ferreira <f.ferreira@qureight.com> Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com>
1 parent 16e469c commit 23c271e

File tree

3 files changed

+254
-1
lines changed

3 files changed

+254
-1
lines changed
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
# Copyright (c) MONAI Consortium
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
# Unless required by applicable law or agreed to in writing, software
7+
# distributed under the License is distributed on an "AS IS" BASIS,
8+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
# See the License for the specific language governing permissions and
10+
# limitations under the License.
11+
12+
from __future__ import annotations
13+
14+
from typing import cast
15+
16+
import torch
17+
import torch.nn as nn
18+
from torch.utils.checkpoint import checkpoint
19+
20+
21+
class ActivationCheckpointWrapper(nn.Module):
    """Wrapper applying activation checkpointing to a module during training.

    When active, intermediate activations of the wrapped module are recomputed
    during the backward pass instead of being stored, trading extra compute for
    lower peak memory usage.

    Args:
        module: The module to wrap with activation checkpointing.
    """

    def __init__(self, module: nn.Module) -> None:
        super().__init__()
        self.module = module

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass with activation checkpointing applied only during training.

        Checkpointing is bypassed when it cannot pay off — module in eval mode
        or gradients globally disabled — so inference incurs no recomputation
        overhead and produces outputs identical to the unwrapped module.

        Args:
            x: Input tensor.

        Returns:
            Output tensor from the wrapped module.
        """
        # A backward pass can only occur while training with grad enabled;
        # otherwise run the module directly (no checkpointing overhead).
        if not (self.training and torch.is_grad_enabled()):
            return self.module(x)
        # use_reentrant=False is the recommended non-reentrant implementation;
        # it supports inputs that do not require grad and keyword-free calls.
        return cast(torch.Tensor, checkpoint(self.module, x, use_reentrant=False))

monai/networks/nets/unet.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,12 @@
1717
import torch
1818
import torch.nn as nn
1919

20+
from monai.networks.blocks.activation_checkpointing import ActivationCheckpointWrapper
2021
from monai.networks.blocks.convolutions import Convolution, ResidualUnit
2122
from monai.networks.layers.factories import Act, Norm
2223
from monai.networks.layers.simplelayers import SkipConnection
2324

24-
__all__ = ["UNet", "Unet"]
25+
__all__ = ["UNet", "Unet", "CheckpointUNet"]
2526

2627

2728
class UNet(nn.Module):
@@ -298,4 +299,29 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
298299
return x
299300

300301

302+
class CheckpointUNet(UNet):
    """UNet variant whose encoder-decoder connection blocks use activation checkpointing.

    Constructor arguments are identical to `UNet`. When training with gradients
    enabled, activations inside each connection block are recomputed during the
    backward pass instead of being kept in memory, lowering peak memory usage in
    exchange for additional compute.
    """

    def _get_connection_block(self, down_path: nn.Module, up_path: nn.Module, subblock: nn.Module) -> nn.Module:
        """Build a connection block whose three components are checkpoint-wrapped.

        Args:
            down_path: encoding half of the layer; wrapped with checkpointing.
            up_path: decoding half of the layer; wrapped with checkpointing.
            subblock: block defining the next (deeper) layer; wrapped with checkpointing.

        Returns:
            The connection block assembled by ``UNet`` from the wrapped components.
        """
        # wrap all three parts uniformly before delegating assembly to the base class
        wrapped_down, wrapped_up, wrapped_sub = (
            ActivationCheckpointWrapper(part) for part in (down_path, up_path, subblock)
        )
        return super()._get_connection_block(wrapped_down, wrapped_up, wrapped_sub)
325+
326+
301327
Unet = UNet
Lines changed: 186 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,186 @@
1+
# Copyright (c) MONAI Consortium
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
# Unless required by applicable law or agreed to in writing, software
7+
# distributed under the License is distributed on an "AS IS" BASIS,
8+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
# See the License for the specific language governing permissions and
10+
# limitations under the License.
11+
12+
from __future__ import annotations
13+
14+
import unittest
15+
16+
import torch
17+
from parameterized import parameterized
18+
19+
from monai.networks import eval_mode
20+
from monai.networks.layers import Act, Norm
21+
from monai.networks.nets.unet import CheckpointUNet, UNet
22+
23+
# Run on GPU when available so tests exercise the device used in CI.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Each case is a triple: [UNet constructor kwargs, input tensor shape, expected output shape].
TEST_CASE_0 = [  # single channel 2D, batch 16, no residual
    {
        "spatial_dims": 2,
        "in_channels": 1,
        "out_channels": 3,
        "channels": (16, 32, 64),
        "strides": (2, 2),
        "num_res_units": 0,
    },
    (16, 1, 32, 32),
    (16, 3, 32, 32),
]

TEST_CASE_1 = [  # single channel 2D, batch 16
    {
        "spatial_dims": 2,
        "in_channels": 1,
        "out_channels": 3,
        "channels": (16, 32, 64),
        "strides": (2, 2),
        "num_res_units": 1,
    },
    (16, 1, 32, 32),
    (16, 3, 32, 32),
]

TEST_CASE_2 = [  # single channel 3D, batch 16
    {
        "spatial_dims": 3,
        "in_channels": 1,
        "out_channels": 3,
        "channels": (16, 32, 64),
        "strides": (2, 2),
        "num_res_units": 1,
    },
    (16, 1, 32, 24, 48),
    (16, 3, 32, 24, 48),
]

TEST_CASE_3 = [  # 4-channel 3D, batch 16
    {
        "spatial_dims": 3,
        "in_channels": 4,
        "out_channels": 3,
        "channels": (16, 32, 64),
        "strides": (2, 2),
        "num_res_units": 1,
    },
    (16, 4, 32, 64, 48),
    (16, 3, 32, 64, 48),
]

TEST_CASE_4 = [  # 4-channel 3D, batch 16, batch normalization
    {
        "spatial_dims": 3,
        "in_channels": 4,
        "out_channels": 3,
        "channels": (16, 32, 64),
        "strides": (2, 2),
        "num_res_units": 1,
        "norm": Norm.BATCH,
    },
    (16, 4, 32, 64, 48),
    (16, 3, 32, 64, 48),
]

TEST_CASE_5 = [  # 4-channel 3D, batch 16, LeakyReLU activation
    {
        "spatial_dims": 3,
        "in_channels": 4,
        "out_channels": 3,
        "channels": (16, 32, 64),
        "strides": (2, 2),
        "num_res_units": 1,
        "act": (Act.LEAKYRELU, {"negative_slope": 0.2}),
        "adn_ordering": "NA",
    },
    (16, 4, 32, 64, 48),
    (16, 3, 32, 64, 48),
]

TEST_CASE_6 = [  # 4-channel 3D, batch 16, LeakyReLU activation explicit
    {
        "spatial_dims": 3,
        "in_channels": 4,
        "out_channels": 3,
        "channels": (16, 32, 64),
        "strides": (2, 2),
        "num_res_units": 1,
        "act": (torch.nn.LeakyReLU, {"negative_slope": 0.2}),
    },
    (16, 4, 32, 64, 48),
    (16, 3, 32, 64, 48),
]

# Aggregate list consumed by parameterized.expand in the test class below.
CASES = [TEST_CASE_0, TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5, TEST_CASE_6]
121+
122+
123+
class TestCheckpointUNet(unittest.TestCase):
    @parameterized.expand(CASES)
    def test_shape(self, input_param, input_shape, expected_shape):
        """Check CheckpointUNet output shapes across all parameterized configurations.

        Args:
            input_param: Dictionary of UNet constructor arguments.
            input_shape: Tuple specifying input tensor dimensions.
            expected_shape: Tuple specifying expected output tensor dimensions.
        """
        model = CheckpointUNet(**input_param).to(device)
        with eval_mode(model):
            output = model.forward(torch.randn(input_shape).to(device))
            self.assertEqual(output.shape, expected_shape)

    def test_checkpointing_equivalence_eval(self):
        """Confirm eval parity when checkpointing is inactive."""
        params = {
            "spatial_dims": 2,
            "in_channels": 1,
            "out_channels": 2,
            "channels": (8, 16, 32),
            "strides": (2, 2),
            "num_res_units": 1,
        }

        x = torch.randn(2, 1, 32, 32, device=device)

        # identical seeds -> identical weights in both networks
        torch.manual_seed(42)
        net_plain = UNet(**params).to(device)

        torch.manual_seed(42)
        net_ckpt = CheckpointUNet(**params).to(device)

        # checkpointing has no effect in eval mode, so outputs must agree
        with eval_mode(net_ckpt), eval_mode(net_plain):
            y_ckpt = net_ckpt(x)
            y_plain = net_plain(x)

        # shapes first, then numerical equivalence
        self.assertEqual(y_ckpt.shape, y_plain.shape)
        self.assertTrue(
            torch.allclose(y_ckpt, y_plain, atol=1e-6, rtol=1e-5),
            f"Eval-mode outputs differ: max abs diff={torch.max(torch.abs(y_ckpt - y_plain)).item():.2e}",
        )

    def test_checkpointing_activates_training(self):
        """Verify checkpointing recomputes activations during training."""
        params = {
            "spatial_dims": 2,
            "in_channels": 1,
            "out_channels": 1,
            "channels": (8, 16, 32),
            "strides": (2, 2),
            "num_res_units": 1,
        }

        net = CheckpointUNet(**params).to(device)
        net.train()

        x = torch.randn(2, 1, 32, 32, device=device, requires_grad=True)
        y = net(x)
        loss = y.mean()
        loss.backward()

        # at least one parameter must have received a nonzero gradient
        grad_norm = sum(p.grad.abs().sum() for p in net.parameters() if p.grad is not None)
        self.assertGreater(grad_norm.item(), 0.0)


if __name__ == "__main__":
    unittest.main()

0 commit comments

Comments
 (0)