Skip to content

Commit 586b6e6

Browse files
ZnerualLaurenz-Ruzicka
authored andcommitted
Fix missing None default values for Gemma3n model in get_placeholder_mask (#39991) (#40024)
* Fix missing None default values for Gemma3n model in get_placeholder_mask (#39991) * Switched definition of optional from| None to Optiona[] (Issue #39991) --------- Co-authored-by: Laurenz Ruzicka <Laurenz.Ruzicka@ait.ac.at>
1 parent 95ae07d commit 586b6e6

File tree

2 files changed

+8
-8
lines changed

2 files changed

+8
-8
lines changed

src/transformers/models/gemma3n/modeling_gemma3n.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1963,10 +1963,10 @@ def get_image_features(self, pixel_values: torch.Tensor) -> torch.Tensor:
19631963

19641964
def get_placeholder_mask(
19651965
self,
1966-
input_ids: torch.LongTensor,
1967-
inputs_embeds: torch.FloatTensor,
1968-
image_features: torch.FloatTensor,
1969-
audio_features: torch.FloatTensor,
1966+
input_ids: Optional[torch.LongTensor] = None,
1967+
inputs_embeds: Optional[torch.FloatTensor] = None,
1968+
image_features: Optional[torch.FloatTensor] = None,
1969+
audio_features: Optional[torch.FloatTensor] = None,
19701970
):
19711971
"""
19721972
Obtains multimodal placeholdr mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is

src/transformers/models/gemma3n/modular_gemma3n.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2261,10 +2261,10 @@ def get_image_features(self, pixel_values: torch.Tensor) -> torch.Tensor:
22612261

22622262
def get_placeholder_mask(
22632263
self,
2264-
input_ids: torch.LongTensor,
2265-
inputs_embeds: torch.FloatTensor,
2266-
image_features: torch.FloatTensor,
2267-
audio_features: torch.FloatTensor,
2264+
input_ids: Optional[torch.LongTensor] = None,
2265+
inputs_embeds: Optional[torch.FloatTensor] = None,
2266+
image_features: Optional[torch.FloatTensor] = None,
2267+
audio_features: Optional[torch.FloatTensor] = None,
22682268
):
22692269
"""
22702270
Obtains multimodal placeholdr mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is

0 commit comments

Comments
 (0)