Skip to content

Commit fb9d972

Browse files
committed
Add unit tests and improve docstrings for box_iou, box_giou, and box_pair_giou
1 parent 1bc93ad commit fb9d972

File tree

2 files changed

+48
-3
lines changed

2 files changed

+48
-3
lines changed

monai/data/box_utils.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -826,7 +826,8 @@ def box_iou(boxes1: NdarrayOrTensor, boxes2: NdarrayOrTensor) -> NdarrayOrTensor
826826
boxes2: bounding boxes, Mx4 or Mx6 torch tensor or ndarray. The box mode is assumed to be ``StandardMode``
827827
828828
Returns:
829-
The output is always floating-point (size: (N, M)):
829+
An array/tensor matching the container type of ``boxes1`` (NumPy ndarray or Torch tensor), always
830+
floating-point with size ``(N, M)``:
830831
- if ``boxes1`` has a floating-point dtype, the same dtype is used.
831832
- if ``boxes1`` has an integer dtype, the result is returned as ``torch.float32``.
832833
@@ -871,7 +872,8 @@ def box_giou(boxes1: NdarrayOrTensor, boxes2: NdarrayOrTensor) -> NdarrayOrTenso
871872
boxes2: bounding boxes, Mx4 or Mx6 torch tensor or ndarray. The box mode is assumed to be ``StandardMode``
872873
873874
Returns:
874-
The output is always floating-point (size: (N, M)):
875+
An array/tensor matching the container type of ``boxes1`` (NumPy ndarray or Torch tensor), always
876+
floating-point with size ``(N, M)``:
875877
- if ``boxes1`` has a floating-point dtype, the same dtype is used.
876878
- if ``boxes1`` has an integer dtype, the result is returned as ``torch.float32``.
877879
@@ -934,7 +936,8 @@ def box_pair_giou(boxes1: NdarrayOrTensor, boxes2: NdarrayOrTensor) -> NdarrayOr
934936
boxes2: bounding boxes, same shape with boxes1. The box mode is assumed to be ``StandardMode``
935937
936938
Returns:
937-
The output is always floating-point (size: (N,)):
939+
An array/tensor matching the container type of ``boxes1`` (NumPy ndarray or Torch tensor), always
940+
floating-point with size ``(N, )``:
938941
- if ``boxes1`` has a floating-point dtype, the same dtype is used.
939942
- if ``boxes1`` has an integer dtype, the result is returned as ``torch.float32``.
940943

tests/data/test_box_utils.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
import unittest
1515

1616
import numpy as np
17+
import torch
1718
from parameterized import parameterized
1819

1920
from monai.data.box_utils import (
@@ -218,5 +219,46 @@ def test_value(self, input_data, mode2, expected_box, expected_area):
218219
assert_allclose(nms_box, [1], type_test=False)
219220

220221

222+
class TestBoxUtilsDtype(unittest.TestCase):
223+
@parameterized.expand(
224+
[
225+
# numpy dtypes
226+
(np.array([[0, 0, 0, 2, 2, 2]], dtype=np.int32), np.array([[0, 0, 0, 2, 2, 2]], dtype=np.int32)),
227+
(np.array([[0, 0, 0, 2, 2, 2]], dtype=np.float32), np.array([[0, 0, 0, 2, 2, 2]], dtype=np.float32)),
228+
# torch dtypes
229+
(
230+
torch.tensor([[0, 0, 0, 2, 2, 2]], dtype=torch.int64),
231+
torch.tensor([[0, 0, 0, 2, 2, 2]], dtype=torch.int64),
232+
),
233+
(
234+
torch.tensor([[0, 0, 0, 2, 2, 2]], dtype=torch.float32),
235+
torch.tensor([[0, 0, 0, 2, 2, 2]], dtype=torch.float32),
236+
),
237+
# mixed numpy (int + float)
238+
(np.array([[0, 0, 0, 2, 2, 2]], dtype=np.int32), np.array([[0, 0, 0, 2, 2, 2]], dtype=np.float32)),
239+
# mixed torch (int + float)
240+
(
241+
torch.tensor([[0, 0, 0, 2, 2, 2]], dtype=torch.int64),
242+
torch.tensor([[0, 0, 0, 2, 2, 2]], dtype=torch.float32),
243+
),
244+
]
245+
)
246+
def test_dtype_behavior(self, boxes1, boxes2):
247+
funcs = [box_iou, box_giou, box_pair_giou]
248+
for func in funcs:
249+
result = func(boxes1, boxes2)
250+
251+
if isinstance(result, np.ndarray):
252+
self.assertTrue(
253+
np.issubdtype(result.dtype, np.floating), f"{func.__name__} expected float, got {result.dtype}"
254+
)
255+
elif torch.is_tensor(result):
256+
self.assertTrue(
257+
torch.is_floating_point(result), f"{func.__name__} expected float tensor, got {result.dtype}"
258+
)
259+
else:
260+
self.fail(f"Unexpected return type {type(result)}")
261+
262+
221263
if __name__ == "__main__":
222264
unittest.main()

0 commit comments

Comments
 (0)