Skip to content

Commit 3cff626

Browse files
committed
fix box_iou
1 parent 0968da2 commit 3cff626

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

monai/data/box_utils.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -842,16 +842,18 @@ def box_iou(boxes1: NdarrayOrTensor, boxes2: NdarrayOrTensor) -> NdarrayOrTensor
842842

843843
inter, union = _box_inter_union(boxes1_t, boxes2_t, compute_dtype=COMPUTE_DTYPE)
844844

845-
# compute IoU and convert back to original box_dtype
845+
# compute IoU and convert back to original box_dtype or float32
846846
iou_t = inter / (union + torch.finfo(COMPUTE_DTYPE).eps) # (N,M)
847+
if not box_dtype.is_floating_point:
848+
box_dtype = COMPUTE_DTYPE
847849
iou_t = iou_t.to(dtype=box_dtype)
848850

849851
# check if NaN or Inf
850852
if torch.isnan(iou_t).any() or torch.isinf(iou_t).any():
851853
raise ValueError("Box IoU is NaN or Inf.")
852854

853855
# convert tensor back to numpy if needed
854-
iou, *_ = convert_to_dst_type(src=iou_t, dst=boxes1)
856+
iou, *_ = convert_to_dst_type(src=iou_t, dst=boxes1, dtype=box_dtype)
855857
return iou
856858

857859

0 commit comments

Comments
 (0)