@@ -393,6 +393,47 @@ def test_inference_mask_generation_video_multi_objects_multi_points(self):
393393 rtol = 1e-4 ,
394394 )
395395
396+ def test_inference_mask_generation_video_batched_bb (self ):
397+ raw_video = prepare_video ()
398+ inference_session = self .processor .init_video_session (video = raw_video , inference_device = torch_device )
399+ ann_frame_idx = 0 # the frame index we interact with
400+ ann_obj_ids = [2 , 3 ] # give a unique id to each object we interact with (it can be any integers)
401+
402+ self .processor .add_inputs_to_inference_session (
403+ inference_session = inference_session ,
404+ frame_idx = ann_frame_idx ,
405+ obj_ids = ann_obj_ids ,
406+ input_boxes = [[[300 , 0 , 500 , 400 ], [400 , 0 , 600 , 400 ]]],
407+ )
408+
409+ frames = []
410+ for sam2_video_output in self .video_model .propagate_in_video_iterator (
411+ inference_session = inference_session ,
412+ start_frame_idx = ann_frame_idx ,
413+ max_frame_num_to_track = 2 ,
414+ ):
415+ video_res_masks = self .processor .post_process_masks (
416+ [sam2_video_output .pred_masks ], [raw_video .shape [- 3 :- 1 ]], binarize = False
417+ )[0 ]
418+ print (video_res_masks .shape )
419+ frames .append (video_res_masks )
420+ frames = torch .stack (frames , dim = 0 )
421+ self .assertEqual (frames .shape , (3 , 2 , 1 , raw_video .shape [- 3 ], raw_video .shape [- 2 ]))
422+ print (frames .shape )
423+ print (frames [:3 , :, :, :2 , :2 ])
424+ torch .testing .assert_close (
425+ frames [:3 , :, :, :2 , :2 ],
426+ torch .tensor (
427+ [
428+ [[[[- 13.1427 , - 13.1427 ], [- 13.7753 , - 13.7753 ]]], [[[- 8.4576 , - 8.4576 ], [- 8.7329 , - 8.7329 ]]]],
429+ [[[[- 14.9998 , - 14.9998 ], [- 15.7086 , - 15.7086 ]]], [[[- 9.2998 , - 9.2998 ], [- 9.8947 , - 9.8947 ]]]],
430+ [[[[- 15.4558 , - 15.4558 ], [- 16.1649 , - 16.1649 ]]], [[[- 10.4880 , - 10.4880 ], [- 11.2098 , - 11.2098 ]]]],
431+ ]
432+ ).to (torch_device ),
433+ atol = 1e-4 ,
434+ rtol = 1e-4 ,
435+ )
436+
396437 def test_inference_propagate_video_from_mask_input (self ):
397438 raw_video = prepare_video ()
398439 inference_session = self .processor .init_video_session (video = raw_video , inference_device = torch_device )
0 commit comments