|
14 | 14 | import unittest |
15 | 15 |
|
16 | 16 | import numpy as np |
| 17 | +import torch |
17 | 18 | from parameterized import parameterized |
18 | 19 |
|
19 | 20 | from monai.data.box_utils import ( |
@@ -218,5 +219,46 @@ def test_value(self, input_data, mode2, expected_box, expected_area): |
218 | 219 | assert_allclose(nms_box, [1], type_test=False) |
219 | 220 |
|
220 | 221 |
|
| 222 | +class TestBoxUtilsDtype(unittest.TestCase): |
| 223 | + @parameterized.expand( |
| 224 | + [ |
| 225 | + # numpy dtypes |
| 226 | + (np.array([[0, 0, 0, 2, 2, 2]], dtype=np.int32), np.array([[0, 0, 0, 2, 2, 2]], dtype=np.int32)), |
| 227 | + (np.array([[0, 0, 0, 2, 2, 2]], dtype=np.float32), np.array([[0, 0, 0, 2, 2, 2]], dtype=np.float32)), |
| 228 | + # torch dtypes |
| 229 | + ( |
| 230 | + torch.tensor([[0, 0, 0, 2, 2, 2]], dtype=torch.int64), |
| 231 | + torch.tensor([[0, 0, 0, 2, 2, 2]], dtype=torch.int64), |
| 232 | + ), |
| 233 | + ( |
| 234 | + torch.tensor([[0, 0, 0, 2, 2, 2]], dtype=torch.float32), |
| 235 | + torch.tensor([[0, 0, 0, 2, 2, 2]], dtype=torch.float32), |
| 236 | + ), |
| 237 | + # mixed numpy (int + float) |
| 238 | + (np.array([[0, 0, 0, 2, 2, 2]], dtype=np.int32), np.array([[0, 0, 0, 2, 2, 2]], dtype=np.float32)), |
| 239 | + # mixed torch (int + float) |
| 240 | + ( |
| 241 | + torch.tensor([[0, 0, 0, 2, 2, 2]], dtype=torch.int64), |
| 242 | + torch.tensor([[0, 0, 0, 2, 2, 2]], dtype=torch.float32), |
| 243 | + ), |
| 244 | + ] |
| 245 | + ) |
| 246 | + def test_dtype_behavior(self, boxes1, boxes2): |
| 247 | + funcs = [box_iou, box_giou, box_pair_giou] |
| 248 | + for func in funcs: |
| 249 | + result = func(boxes1, boxes2) |
| 250 | + |
| 251 | + if isinstance(result, np.ndarray): |
| 252 | + self.assertTrue( |
| 253 | + np.issubdtype(result.dtype, np.floating), f"{func.__name__} expected float, got {result.dtype}" |
| 254 | + ) |
| 255 | + elif torch.is_tensor(result): |
| 256 | + self.assertTrue( |
| 257 | + torch.is_floating_point(result), f"{func.__name__} expected float tensor, got {result.dtype}" |
| 258 | + ) |
| 259 | + else: |
| 260 | + self.fail(f"Unexpected return type {type(result)}") |
| 261 | + |
| 262 | + |
221 | 263 | if __name__ == "__main__": |
222 | 264 | unittest.main() |
0 commit comments