@@ -901,7 +901,7 @@ def test_inference_batched_images_batched_boxes(self):
901901 self .assertEqual (outputs .pred_masks .shape , (2 , 4 , 1 , 256 , 256 ))
902902 torch .testing .assert_close (
903903 outputs .iou_scores ,
904- torch .tensor ([[[0.9873 ], [0.9264 ], [0.9496 ], [0.9208 ]], [[0.9445 ], [0.9496 ], [0.9497 ], [0.9481 ]]]).to (
904+ torch .tensor ([[[0.9904 ], [0.9689 ], [0.9770 ], [0.9079 ]], [[0.9739 ], [0.9816 ], [0.9838 ], [0.9781 ]]]).to (
905905 torch_device
906906 ),
907907 atol = 1e-4 ,
@@ -912,16 +912,16 @@ def test_inference_batched_images_batched_boxes(self):
912912 torch .tensor (
913913 [
914914 [
915- [[[- 7.6204 , - 11.9286 ], [- 8.7747 , - 10.5662 ]]],
916- [[[- 17.1070 , - 23.4025 ], [- 20.9608 , - 19.5600 ]]],
917- [[[- 20.5766 , - 29.4410 ], [- 26.0739 , - 24.3225 ]]],
918- [[[- 19.7201 , - 29.0836 ], [- 24.4915 , - 23.6377 ]]],
915+ [[[- 11.1540 , - 18.3994 ], [- 12.4230 , - 17.4403 ]]],
916+ [[[- 19.3144 , - 29.3947 ], [- 24.6341 , - 24.1144 ]]],
917+ [[[- 24.2983 , - 37.6470 ], [- 31.6659 , - 31.0893 ]]],
918+ [[[- 25.4313 , - 44.0231 ], [- 34.0903 , - 34.7447 ]]],
919919 ],
920920 [
921- [[[- 18.5259 , - 23.5202 ], [- 25.1906 , - 17.2518 ]]],
922- [[[- 20.1214 , - 25.4215 ], [- 25.7877 , - 19.1169 ]]],
923- [[[- 21.0878 , - 24.7938 ], [- 27.5625 , - 19.2650 ]]],
924- [[[- 20.5210 , - 22.5343 ], [- 26.0968 , - 17.7544 ]]],
921+ [[[- 22.5539 , - 30.4633 ], [- 32.8940 , - 21.6813 ]]],
922+ [[[- 23.6637 , - 31.3489 ], [- 32.5095 , - 22.4442 ]]],
923+ [[[- 25.2987 , - 30.9999 ], [- 34.6243 , - 24.1717 ]]],
924+ [[[- 26.3150 , - 30.5313 ], [- 35.0152 , - 24.0271 ]]],
925925 ],
926926 ]
927927 ).to (torch_device ),
0 commit comments