Skip to content

Commit f384bb8

Browse files
authored
[SAM2] Fix inconsistent results with original implementation with input boxes (#40800)
* Fix inconsistencies with box input inference with original repo
* remove print
* always pad
* fix modular
1 parent 4cb41ad commit f384bb8

File tree

6 files changed

+31
-30
lines changed

6 files changed

+31
-30
lines changed

src/transformers/models/metaclip_2/modeling_metaclip_2.py

Lines changed: 2 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -960,9 +960,8 @@ def forward(
960960
interpolate_pos_encoding: bool = False,
961961
) -> MetaClip2Output:
962962
r"""
963-
Args:
964-
return_loss (`bool`, *optional*):
965-
Whether or not to return the contrastive loss.
963+
return_loss (`bool`, *optional*):
964+
Whether or not to return the contrastive loss.
966965
967966
Examples:
968967

src/transformers/models/metaclip_2/modular_metaclip_2.py

Lines changed: 2 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -551,9 +551,8 @@ def forward(
551551
interpolate_pos_encoding: bool = False,
552552
):
553553
r"""
554-
Args:
555-
return_loss (`bool`, *optional*):
556-
Whether or not to return the contrastive loss.
554+
return_loss (`bool`, *optional*):
555+
Whether or not to return the contrastive loss.
557556
558557
Examples:
559558

src/transformers/models/sam2/modeling_sam2.py

Lines changed: 6 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -793,13 +793,14 @@ def _embed_points(self, points: torch.Tensor, labels: torch.Tensor, pad: bool) -
793793

794794
def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor:
795795
"""Embeds box prompts."""
796-
boxes = boxes + 0.5 # Shift to center of pixel
797-
batch_size, nb_boxes = boxes.shape[:2]
798-
coords = boxes.reshape(batch_size, nb_boxes, 2, 2)
799-
input_shape = (self.input_image_size, self.input_image_size)
800-
corner_embedding = self.shared_embedding(coords, input_shape)
796+
boxes += 0.5 # Shift to center of pixel
797+
coords = boxes.view(*boxes.shape[:2], 2, 2)
798+
# add padding point for consistency with the original implementation
799+
coords = torch.nn.functional.pad(coords, (0, 0, 0, 1), mode="constant", value=0)
800+
corner_embedding = self.shared_embedding(coords, (self.input_image_size, self.input_image_size))
801801
corner_embedding[:, :, 0, :] += self.point_embed.weight[2]
802802
corner_embedding[:, :, 1, :] += self.point_embed.weight[3]
803+
corner_embedding[:, :, 2, :] = self.not_a_point_embed.weight.expand_as(corner_embedding[:, :, 2, :])
803804
return corner_embedding
804805

805806
def forward(

src/transformers/models/sam2/modular_sam2.py

Lines changed: 6 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -882,13 +882,14 @@ def _embed_points(self, points: torch.Tensor, labels: torch.Tensor, pad: bool) -
882882

883883
def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor:
884884
"""Embeds box prompts."""
885-
boxes = boxes + 0.5 # Shift to center of pixel
886-
batch_size, nb_boxes = boxes.shape[:2]
887-
coords = boxes.reshape(batch_size, nb_boxes, 2, 2)
888-
input_shape = (self.input_image_size, self.input_image_size)
889-
corner_embedding = self.shared_embedding(coords, input_shape)
885+
boxes += 0.5 # Shift to center of pixel
886+
coords = boxes.view(*boxes.shape[:2], 2, 2)
887+
# add padding point for consistency with the original implementation
888+
coords = torch.nn.functional.pad(coords, (0, 0, 0, 1), mode="constant", value=0)
889+
corner_embedding = self.shared_embedding(coords, (self.input_image_size, self.input_image_size))
890890
corner_embedding[:, :, 0, :] += self.point_embed.weight[2]
891891
corner_embedding[:, :, 1, :] += self.point_embed.weight[3]
892+
corner_embedding[:, :, 2, :] = self.not_a_point_embed.weight.expand_as(corner_embedding[:, :, 2, :])
892893
return corner_embedding
893894

894895

src/transformers/models/sam2_video/modeling_sam2_video.py

Lines changed: 6 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -1224,13 +1224,14 @@ def _embed_points(self, points: torch.Tensor, labels: torch.Tensor, pad: bool) -
12241224

12251225
def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor:
12261226
"""Embeds box prompts."""
1227-
boxes = boxes + 0.5 # Shift to center of pixel
1228-
batch_size, nb_boxes = boxes.shape[:2]
1229-
coords = boxes.reshape(batch_size, nb_boxes, 2, 2)
1230-
input_shape = (self.input_image_size, self.input_image_size)
1231-
corner_embedding = self.shared_embedding(coords, input_shape)
1227+
boxes += 0.5 # Shift to center of pixel
1228+
coords = boxes.view(*boxes.shape[:2], 2, 2)
1229+
# add padding point for consistency with the original implementation
1230+
coords = torch.nn.functional.pad(coords, (0, 0, 0, 1), mode="constant", value=0)
1231+
corner_embedding = self.shared_embedding(coords, (self.input_image_size, self.input_image_size))
12321232
corner_embedding[:, :, 0, :] += self.point_embed.weight[2]
12331233
corner_embedding[:, :, 1, :] += self.point_embed.weight[3]
1234+
corner_embedding[:, :, 2, :] = self.not_a_point_embed.weight.expand_as(corner_embedding[:, :, 2, :])
12341235
return corner_embedding
12351236

12361237
def forward(

tests/models/sam2/test_modeling_sam2.py

Lines changed: 9 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -901,7 +901,7 @@ def test_inference_batched_images_batched_boxes(self):
901901
self.assertEqual(outputs.pred_masks.shape, (2, 4, 1, 256, 256))
902902
torch.testing.assert_close(
903903
outputs.iou_scores,
904-
torch.tensor([[[0.9873], [0.9264], [0.9496], [0.9208]], [[0.9445], [0.9496], [0.9497], [0.9481]]]).to(
904+
torch.tensor([[[0.9904], [0.9689], [0.9770], [0.9079]], [[0.9739], [0.9816], [0.9838], [0.9781]]]).to(
905905
torch_device
906906
),
907907
atol=1e-4,
@@ -912,16 +912,16 @@ def test_inference_batched_images_batched_boxes(self):
912912
torch.tensor(
913913
[
914914
[
915-
[[[-7.6204, -11.9286], [-8.7747, -10.5662]]],
916-
[[[-17.1070, -23.4025], [-20.9608, -19.5600]]],
917-
[[[-20.5766, -29.4410], [-26.0739, -24.3225]]],
918-
[[[-19.7201, -29.0836], [-24.4915, -23.6377]]],
915+
[[[-11.1540, -18.3994], [-12.4230, -17.4403]]],
916+
[[[-19.3144, -29.3947], [-24.6341, -24.1144]]],
917+
[[[-24.2983, -37.6470], [-31.6659, -31.0893]]],
918+
[[[-25.4313, -44.0231], [-34.0903, -34.7447]]],
919919
],
920920
[
921-
[[[-18.5259, -23.5202], [-25.1906, -17.2518]]],
922-
[[[-20.1214, -25.4215], [-25.7877, -19.1169]]],
923-
[[[-21.0878, -24.7938], [-27.5625, -19.2650]]],
924-
[[[-20.5210, -22.5343], [-26.0968, -17.7544]]],
921+
[[[-22.5539, -30.4633], [-32.8940, -21.6813]]],
922+
[[[-23.6637, -31.3489], [-32.5095, -22.4442]]],
923+
[[[-25.2987, -30.9999], [-34.6243, -24.1717]]],
924+
[[[-26.3150, -30.5313], [-35.0152, -24.0271]]],
925925
],
926926
]
927927
).to(torch_device),

0 commit comments

Comments
 (0)