Commit 454c0a7

Use torch.get_autocast_dtype instead of torch.get_autocast_gpu_dtype (#42055)
Update dtype handling for PyTorch 2.4 compatibility in flash attention models
1 parent f4c8497 commit 454c0a7
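
Every file in this commit applies the same guard: prefer the device-generic `torch.get_autocast_dtype`, added in PyTorch 2.4, and fall back to the older CUDA-specific `torch.get_autocast_gpu_dtype` (deprecated in 2.4) on earlier releases. A minimal sketch of the pattern, factored into a standalone helper; the helper name `_autocast_dtype` is illustrative and not part of the commit:

import torch

def _autocast_dtype(device_type: str = "cuda") -> torch.dtype:
    # PyTorch 2.4 introduced a device-generic query that takes the
    # device type as an argument.
    if hasattr(torch, "get_autocast_dtype"):
        return torch.get_autocast_dtype(device_type)
    # Fallback for PyTorch < 2.4; this CUDA-specific query is
    # deprecated in newer releases.
    return torch.get_autocast_gpu_dtype()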

11 files changed (+24 −6 lines)


src/transformers/integrations/flash_attention.py

Lines changed: 6 additions & 1 deletion
@@ -15,7 +15,12 @@ def get_target_dtype(query: torch.Tensor, module: torch.nn.Module) -> torch.dtyp
     """If the query is in float32, return a target dtype compatible with flash attention. Return None otherwise."""
     if query.dtype == torch.float32:
         if torch.is_autocast_enabled():
-            return torch.get_autocast_gpu_dtype()
+            # NOTE: `torch.get_autocast_dtype` is there starting from PyTorch 2.4
+            return (
+                torch.get_autocast_dtype("cuda")
+                if hasattr(torch, "get_autocast_dtype")
+                else torch.get_autocast_gpu_dtype()
+            )
         # Handle the case where the model is quantized
         elif hasattr(module.config, "_pre_quantization_dtype"):
             return module.config._pre_quantization_dtype
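
For illustration, a sketch of what the branch above evaluates to under an active autocast context (assumes a CUDA-capable PyTorch install; the bfloat16 choice is arbitrary):

import torch

with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
    if torch.is_autocast_enabled():
        target_dtype = (
            torch.get_autocast_dtype("cuda")
            if hasattr(torch, "get_autocast_dtype")
            else torch.get_autocast_gpu_dtype()
        )
        print(target_dtype)  # torch.bfloat16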

src/transformers/kernels/falcon_mamba/selective_scan_with_ln_interface.py

Lines changed: 9 additions & 5 deletions
@@ -250,12 +250,16 @@ def forward(
         delta_rank = delta_proj_weight.shape[1]
         d_state = A.shape[-1] * (1 if not A.is_complex() else 2)
         if torch.is_autocast_enabled():
-            x_proj_weight = x_proj_weight.to(dtype=torch.get_autocast_gpu_dtype())
-            delta_proj_weight = delta_proj_weight.to(dtype=torch.get_autocast_gpu_dtype())
-            out_proj_weight = out_proj_weight.to(dtype=torch.get_autocast_gpu_dtype())
-            out_proj_bias = (
-                out_proj_bias.to(dtype=torch.get_autocast_gpu_dtype()) if out_proj_bias is not None else None
+            # NOTE: `torch.get_autocast_dtype` is there starting from PyTorch 2.4
+            target_dtype = (
+                torch.get_autocast_dtype("cuda")
+                if hasattr(torch, "get_autocast_dtype")
+                else torch.get_autocast_gpu_dtype()
             )
+            x_proj_weight = x_proj_weight.to(dtype=target_dtype)
+            delta_proj_weight = delta_proj_weight.to(dtype=target_dtype)
+            out_proj_weight = out_proj_weight.to(dtype=target_dtype)
+            out_proj_bias = out_proj_bias.to(dtype=target_dtype) if out_proj_bias is not None else None
         if xz.stride(-1) != 1:
             xz = xz.contiguous()
         conv1d_weight = rearrange(conv1d_weight, "d 1 w -> d w")

src/transformers/models/diffllama/modeling_diffllama.py

Lines changed: 1 addition & 0 deletions
@@ -353,6 +353,7 @@ def forward(
         device_type = query_states.device.type if query_states.device.type != "mps" else "cpu"
         if input_dtype == torch.float32:
             if torch.is_autocast_enabled():
+                # NOTE: `torch.get_autocast_dtype` is there starting from PyTorch 2.4
                 target_dtype = (
                     torch.get_autocast_dtype(device_type)
                     if hasattr(torch, "get_autocast_dtype")

src/transformers/models/diffllama/modular_diffllama.py

Lines changed: 1 addition & 0 deletions
@@ -229,6 +229,7 @@ def forward(
         device_type = query_states.device.type if query_states.device.type != "mps" else "cpu"
         if input_dtype == torch.float32:
             if torch.is_autocast_enabled():
+                # NOTE: `torch.get_autocast_dtype` is there starting from PyTorch 2.4
                 target_dtype = (
                     torch.get_autocast_dtype(device_type)
                     if hasattr(torch, "get_autocast_dtype")

src/transformers/models/esm/modeling_esmfold.py

Lines changed: 1 addition & 0 deletions
@@ -137,6 +137,7 @@ class EsmForProteinFoldingOutput(ModelOutput):
 
 def is_fp16_enabled(device_type):
     # Autocast world
+    # NOTE: `torch.get_autocast_dtype` is there starting from PyTorch 2.4
     autocast_dtype = (
         torch.get_autocast_dtype(device_type)
         if hasattr(torch, "get_autocast_dtype")

src/transformers/models/falcon/modeling_falcon.py

Lines changed: 1 addition & 0 deletions
@@ -513,6 +513,7 @@ def forward(
         device_type = query_layer.device.type if query_layer.device.type != "mps" else "cpu"
         if input_dtype == torch.float32:
             if torch.is_autocast_enabled():
+                # NOTE: `torch.get_autocast_dtype` is there starting from PyTorch 2.4
                 target_dtype = (
                     torch.get_autocast_dtype(device_type)
                     if hasattr(torch, "get_autocast_dtype")

src/transformers/models/gpt_neo/modeling_gpt_neo.py

Lines changed: 1 addition & 0 deletions
@@ -230,6 +230,7 @@ def forward(
         device_type = query.device.type if query.device.type != "mps" else "cpu"
         if query.dtype == torch.float32:
             if torch.is_autocast_enabled():
+                # NOTE: `torch.get_autocast_dtype` is there starting from PyTorch 2.4
                 target_dtype = (
                     torch.get_autocast_dtype(device_type)
                     if hasattr(torch, "get_autocast_dtype")

src/transformers/models/gptj/modeling_gptj.py

Lines changed: 1 addition & 0 deletions
@@ -327,6 +327,7 @@ def forward(
         device_type = query.device.type if query.device.type != "mps" else "cpu"
         if input_dtype == torch.float32:
             if torch.is_autocast_enabled():
+                # NOTE: `torch.get_autocast_dtype` is there starting from PyTorch 2.4
                 target_dtype = (
                     torch.get_autocast_dtype(device_type)
                     if hasattr(torch, "get_autocast_dtype")

src/transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py

Lines changed: 1 addition & 0 deletions
@@ -592,6 +592,7 @@ def forward(
         device_type = query_states.device.type if query_states.device.type != "mps" else "cpu"
         if input_dtype == torch.float32:
             if torch.is_autocast_enabled():
+                # NOTE: `torch.get_autocast_dtype` is there starting from PyTorch 2.4
                 target_dtype = (
                     torch.get_autocast_dtype(device_type)
                     if hasattr(torch, "get_autocast_dtype")

src/transformers/models/mimi/modeling_mimi.py

Lines changed: 1 addition & 0 deletions
@@ -806,6 +806,7 @@ def forward(
         device_type = query_states.device.type if query_states.device.type != "mps" else "cpu"
         if input_dtype == torch.float32:
             if torch.is_autocast_enabled():
+                # NOTE: `torch.get_autocast_dtype` is there starting from PyTorch 2.4
                 target_dtype = (
                     torch.get_autocast_dtype(device_type)
                     if hasattr(torch, "get_autocast_dtype")
