Skip to content

Commit 30567c2

Browse files
Yozer and qubvel authored
[timm_wrapper] add support for gradient checkpointing (#39287)
* feat: add support for gradient checkpointing in TimmWrapperModel and TimmWrapperForImageClassification * ruff fix * refactor + add test for not supported model * ruff * Update src/transformers/models/timm_wrapper/modeling_timm_wrapper.py Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com> * Update src/transformers/models/timm_wrapper/modeling_timm_wrapper.py Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com> * Update src/transformers/models/timm_wrapper/modeling_timm_wrapper.py Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com> * Update src/transformers/models/timm_wrapper/modeling_timm_wrapper.py Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com> --------- Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com>
1 parent a44dcbe commit 30567c2

File tree

2 files changed

+32
-0
lines changed

2 files changed

+32
-0
lines changed

src/transformers/models/timm_wrapper/modeling_timm_wrapper.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,10 @@ def __init__(self, *args, **kwargs):
7070
requires_backends(self, ["vision", "timm"])
7171
super().__init__(*args, **kwargs)
7272

73+
def post_init(self):
74+
self.supports_gradient_checkpointing = self._timm_model_supports_gradient_checkpointing()
75+
super().post_init()
76+
7377
@staticmethod
7478
def _fix_state_dict_key_on_load(key) -> tuple[str, bool]:
7579
"""
@@ -107,6 +111,24 @@ def _init_weights(self, module):
107111
if module.bias is not None:
108112
module.bias.data.zero_()
109113

114+
def _timm_model_supports_gradient_checkpointing(self):
115+
"""
116+
Check if the timm model supports gradient checkpointing by checking if the `set_grad_checkpointing` method is available.
117+
Some timm models will have the method but will raise an AssertionError when called so in this case we return False.
118+
"""
119+
if not hasattr(self.timm_model, "set_grad_checkpointing"):
120+
return False
121+
122+
try:
123+
self.timm_model.set_grad_checkpointing(enable=True)
124+
self.timm_model.set_grad_checkpointing(enable=False)
125+
return True
126+
except Exception:
127+
return False
128+
129+
def _set_gradient_checkpointing(self, enable: bool = True, *args, **kwargs):
130+
self.timm_model.set_grad_checkpointing(enable)
131+
110132

111133
class TimmWrapperModel(TimmWrapperPreTrainedModel):
112134
"""

tests/models/timm_wrapper/test_modeling_timm_wrapper.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,16 @@ def test_mismatched_shapes_have_properly_initialized_weights(self):
170170
def test_model_is_small(self):
171171
pass
172172

173+
def test_gradient_checkpointing(self):
174+
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
175+
model = TimmWrapperModel._from_config(config)
176+
self.assertTrue(model.supports_gradient_checkpointing)
177+
178+
def test_gradient_checkpointing_on_non_supported_model(self):
179+
config = TimmWrapperConfig.from_pretrained("timm/hrnet_w18.ms_aug_in1k")
180+
model = TimmWrapperModel._from_config(config)
181+
self.assertFalse(model.supports_gradient_checkpointing)
182+
173183
def test_forward_signature(self):
174184
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
175185

0 commit comments

Comments (0)