
Commit 8502b41

[Sam2Video] Fix video inference with batched boxes and add test (#40797)
fix video inference with batched boxes and add test
1 parent: f384bb8

File tree

3 files changed (+43, -4 lines)


src/transformers/models/sam2_video/modular_sam2_video.py

Lines changed: 1 addition & 2 deletions
@@ -836,8 +836,7 @@ def process_new_points_or_boxes_for_video_frame(
                     "(please use clear_old_points=True instead)"
                 )
             box_coords = input_boxes.reshape(1, -1, 2, 2)
-            box_labels = torch.tensor([2, 3], dtype=torch.int32)
-            box_labels = box_labels.reshape(1, -1, 2)
+            box_labels = torch.tensor([2, 3], dtype=torch.int32).repeat(1, box_coords.shape[1], 1)
             input_points = torch.cat([box_coords, input_points], dim=2)
             input_labels = torch.cat([box_labels, input_labels], dim=2)
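
Why the one-liner fixes batched boxes: box_coords has shape (1, num_boxes, 2, 2) after the reshape above, but the old reshape(1, -1, 2) on the two-element label tensor always produced shape (1, 1, 2), so the corner labels (2 and 3 mark a box's top-left and bottom-right corners in SAM 2's prompt encoding) covered only the first box. The fix tiles the label pair once per box. A minimal standalone sketch of the before/after shapes:

    import torch

    # Two boxes for one frame: corner coordinates have shape (1, num_boxes, 2, 2).
    input_boxes = torch.tensor([[300.0, 0.0, 500.0, 400.0], [400.0, 0.0, 600.0, 400.0]])
    box_coords = input_boxes.reshape(1, -1, 2, 2)  # (1, 2, 2, 2)

    # Before: reshape(1, -1, 2) on a two-element tensor is always (1, 1, 2),
    # i.e. labels for a single box regardless of num_boxes.
    old_labels = torch.tensor([2, 3], dtype=torch.int32).reshape(1, -1, 2)
    assert old_labels.shape == (1, 1, 2)

    # After: repeat the label pair along the box dimension to match box_coords.
    new_labels = torch.tensor([2, 3], dtype=torch.int32).repeat(1, box_coords.shape[1], 1)
    assert new_labels.shape == (1, 2, 2)

With matching shapes, the torch.cat calls along dim=2 can then prepend each box's corner points and labels to any click prompts for the same objects.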

src/transformers/models/sam2_video/processing_sam2_video.py

Lines changed: 1 addition & 2 deletions
@@ -721,8 +721,7 @@ def process_new_points_or_boxes_for_video_frame(
                     "(please use clear_old_points=True instead)"
                 )
             box_coords = input_boxes.reshape(1, -1, 2, 2)
-            box_labels = torch.tensor([2, 3], dtype=torch.int32)
-            box_labels = box_labels.reshape(1, -1, 2)
+            box_labels = torch.tensor([2, 3], dtype=torch.int32).repeat(1, box_coords.shape[1], 1)
             input_points = torch.cat([box_coords, input_points], dim=2)
             input_labels = torch.cat([box_labels, input_labels], dim=2)
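
This hunk is identical to the one in modular_sam2_video.py above: in Transformers' modular-model setup, processing_sam2_video.py is generated from the modular file, so the same one-line fix lands in both places.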

tests/models/sam2_video/test_modeling_sam2_video.py

Lines changed: 41 additions & 0 deletions
@@ -393,6 +393,47 @@ def test_inference_mask_generation_video_multi_objects_multi_points(self):
             rtol=1e-4,
         )
 
+    def test_inference_mask_generation_video_batched_bb(self):
+        raw_video = prepare_video()
+        inference_session = self.processor.init_video_session(video=raw_video, inference_device=torch_device)
+        ann_frame_idx = 0  # the frame index we interact with
+        ann_obj_ids = [2, 3]  # give a unique id to each object we interact with (it can be any integers)
+
+        self.processor.add_inputs_to_inference_session(
+            inference_session=inference_session,
+            frame_idx=ann_frame_idx,
+            obj_ids=ann_obj_ids,
+            input_boxes=[[[300, 0, 500, 400], [400, 0, 600, 400]]],
+        )
+
+        frames = []
+        for sam2_video_output in self.video_model.propagate_in_video_iterator(
+            inference_session=inference_session,
+            start_frame_idx=ann_frame_idx,
+            max_frame_num_to_track=2,
+        ):
+            video_res_masks = self.processor.post_process_masks(
+                [sam2_video_output.pred_masks], [raw_video.shape[-3:-1]], binarize=False
+            )[0]
+            print(video_res_masks.shape)
+            frames.append(video_res_masks)
+        frames = torch.stack(frames, dim=0)
+        self.assertEqual(frames.shape, (3, 2, 1, raw_video.shape[-3], raw_video.shape[-2]))
+        print(frames.shape)
+        print(frames[:3, :, :, :2, :2])
+        torch.testing.assert_close(
+            frames[:3, :, :, :2, :2],
+            torch.tensor(
+                [
+                    [[[[-13.1427, -13.1427], [-13.7753, -13.7753]]], [[[-8.4576, -8.4576], [-8.7329, -8.7329]]]],
+                    [[[[-14.9998, -14.9998], [-15.7086, -15.7086]]], [[[-9.2998, -9.2998], [-9.8947, -9.8947]]]],
+                    [[[[-15.4558, -15.4558], [-16.1649, -16.1649]]], [[[-10.4880, -10.4880], [-11.2098, -11.2098]]]],
+                ]
+            ).to(torch_device),
+            atol=1e-4,
+            rtol=1e-4,
+        )
+
     def test_inference_propagate_video_from_mask_input(self):
         raw_video = prepare_video()
         inference_session = self.processor.init_video_session(video=raw_video, inference_device=torch_device)
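
For reference, a sketch of the user-facing path the new test exercises. The session and propagation calls mirror the test above; the checkpoint name, the Sam2VideoModel/Sam2VideoProcessor class names, and the video-loading helper are assumptions for illustration:

    import torch
    from transformers import Sam2VideoModel, Sam2VideoProcessor

    device = "cuda" if torch.cuda.is_available() else "cpu"
    checkpoint = "facebook/sam2.1-hiera-tiny"  # assumed checkpoint name
    model = Sam2VideoModel.from_pretrained(checkpoint).to(device)
    processor = Sam2VideoProcessor.from_pretrained(checkpoint)

    video = load_video_frames()  # hypothetical helper; (num_frames, height, width, 3) array

    session = processor.init_video_session(video=video, inference_device=device)
    # One box list per frame, one xyxy box per object id -- the batched-box case fixed here.
    processor.add_inputs_to_inference_session(
        inference_session=session,
        frame_idx=0,
        obj_ids=[2, 3],
        input_boxes=[[[300, 0, 500, 400], [400, 0, 600, 400]]],
    )
    for output in model.propagate_in_video_iterator(inference_session=session, start_frame_idx=0):
        # (num_objects, 1, height, width) mask logits for the current frame
        masks = processor.post_process_masks(
            [output.pred_masks], [video.shape[-3:-1]], binarize=False
        )[0]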
