Skip to content
Open
Show file tree
Hide file tree
Changes from 18 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
180 changes: 179 additions & 1 deletion monai/transforms/post/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,14 @@
remove_small_objects,
)
from monai.transforms.utils_pytorch_numpy_unification import unravel_index
from monai.utils import TransformBackends, convert_data_type, convert_to_tensor, ensure_tuple, look_up_option
from monai.utils import (
TransformBackends,
convert_data_type,
convert_to_tensor,
ensure_tuple,
get_equivalent_dtype,
look_up_option,
)
from monai.utils.type_conversion import convert_to_dst_type

__all__ = [
Expand All @@ -54,6 +61,7 @@
"SobelGradients",
"VoteEnsemble",
"Invert",
"GenerateHeatmap",
"DistanceTransformEDT",
]

Expand Down Expand Up @@ -742,6 +750,176 @@ def __call__(self, img: Sequence[NdarrayOrTensor] | NdarrayOrTensor) -> NdarrayO
return self.post_convert(out_pt, img)


class GenerateHeatmap(Transform):
    """
    Generate per-landmark Gaussian heatmaps for 2D or 3D coordinates.

    Notes:
        - Coordinates are interpreted in voxel units and expected in (Y, X) for 2D or (Z, Y, X) for 3D.
        - Target spatial_shape is (Y, X) for 2D and (Z, Y, X) for 3D.
        - Output layout uses channel-first convention with one channel per landmark:
            - Non-batched points (N, D): (N, Y, X) for 2D or (N, Z, Y, X) for 3D
            - Batched points (B, N, D): (B, N, Y, X) for 2D or (B, N, Z, Y, X) for 3D
        - Each channel corresponds to one landmark.
        - A landmark with a non-finite coordinate, or one lying outside ``[0, size)`` along any axis,
          is skipped and its channel stays all zeros.

    Args:
        sigma: gaussian standard deviation. A single value is broadcast across all spatial dimensions.
        spatial_shape: optional fallback spatial shape. If ``None`` it must be provided when calling the transform.
            A single int value will be broadcast to all spatial dimensions.
        truncated: extent, in multiples of ``sigma``, used to crop the gaussian support window.
        normalize: normalize every heatmap channel to ``[0, 1]`` when ``True``.
        dtype: target dtype for the generated heatmaps (accepts numpy or torch dtypes).

    Raises:
        ValueError: when ``sigma`` or ``truncated`` is non-positive, or ``spatial_shape`` cannot be resolved.

    """

    backend = [TransformBackends.NUMPY, TransformBackends.TORCH]

    def __init__(
        self,
        sigma: Sequence[float] | float = 5.0,
        spatial_shape: Sequence[int] | int | None = None,
        truncated: float = 4.0,
        normalize: bool = True,
        dtype: np.dtype | torch.dtype | type = np.float32,
    ) -> None:
        if isinstance(sigma, Sequence) and not isinstance(sigma, (str, bytes)):
            if any(s <= 0 for s in sigma):
                raise ValueError("sigma values must be positive.")
            self._sigma = tuple(float(s) for s in sigma)
        else:
            if float(sigma) <= 0:
                raise ValueError("sigma must be positive.")
            self._sigma = (float(sigma),)
        if truncated <= 0:
            raise ValueError("truncated must be positive.")
        self.truncated = float(truncated)
        self.normalize = normalize
        # Keep both flavours of the requested dtype so the output can follow the input container type.
        self.torch_dtype = get_equivalent_dtype(dtype, torch.Tensor)
        self.numpy_dtype = get_equivalent_dtype(dtype, np.ndarray)
        # ``ensure_tuple`` lets a bare int through as a 1-tuple (iterating an int directly raises TypeError);
        # the 1-tuple is broadcast to the landmark dimensionality later in ``_resolve_spatial_shape``.
        self.spatial_shape = None if spatial_shape is None else tuple(int(s) for s in ensure_tuple(spatial_shape))

    def __call__(self, points: NdarrayOrTensor, spatial_shape: Sequence[int] | int | None = None) -> NdarrayOrTensor:
        """
        Render one Gaussian heatmap channel per landmark.

        Args:
            points: landmark coordinates with shape ``(num_points, spatial_dims)`` or
                ``(B, num_points, spatial_dims)``, where ``spatial_dims`` is 2 or 3.
            spatial_shape: spatial shape of the output; overrides the value given at construction time.
                A single int is broadcast to all spatial dimensions.

        Returns:
            Heatmaps with shape ``(num_points, *spatial_shape)`` or ``(B, num_points, *spatial_shape)``,
            converted to the container type of ``points`` with the dtype configured on the transform.

        Raises:
            ValueError: when ``points`` is not 2D/3D-shaped, the landmarks are not 2D or 3D,
                or ``spatial_shape``/``sigma`` cannot be matched to the landmark dimensionality.
        """
        original_points = points
        points_t = convert_to_tensor(points, dtype=torch.float32, track_meta=False)

        is_batched = points_t.ndim == 3
        if not is_batched:
            if points_t.ndim != 2:
                raise ValueError(
                    "points must be a 2D or 3D array with shape (num_points, spatial_dims) or (B, num_points, spatial_dims)."
                )
            points_t = points_t.unsqueeze(0)  # Add a batch dimension

        if points_t.shape[-1] not in (2, 3):
            raise ValueError("GenerateHeatmap only supports 2D or 3D landmarks.")

        device = points_t.device
        batch_size, num_points, spatial_dims = points_t.shape

        target_shape = self._resolve_spatial_shape(spatial_shape, spatial_dims)
        sigma = self._resolve_sigma(spatial_dims)
        # Half-width (in voxels) of the window in which the Gaussian is evaluated, per axis.
        radius = tuple(int(np.ceil(self.truncated * s)) for s in sigma)

        heatmap = torch.zeros((batch_size, num_points, *target_shape), dtype=self.torch_dtype, device=device)
        image_bounds = tuple(int(s) for s in target_shape)
        # Convert all coordinates to Python floats once instead of once per landmark inside the loop.
        points_list = points_t.tolist()
        for b_idx, batch_points in enumerate(points_list):
            for idx, center_vals in enumerate(batch_points):
                # Skip NaN/inf landmarks and landmarks outside the image: their channel stays zero.
                if not np.all(np.isfinite(center_vals)):
                    continue
                if not self._is_inside(center_vals, image_bounds):
                    continue
                window_slices, coord_shifts = self._make_window(center_vals, radius, image_bounds, device)
                if window_slices is None:
                    continue
                # ``region`` is a view into ``heatmap``; the in-place copy writes the result back.
                region = heatmap[b_idx, idx][window_slices]
                gaussian = self._evaluate_gaussian(coord_shifts, sigma)
                region.copy_(torch.maximum(region, gaussian))
                if self.normalize:
                    # The sampled peak is below 1 when the center falls between voxels; rescale it to 1.
                    # Fall back to a divisor of 1 when the peak is not positive to avoid dividing by zero.
                    peak = heatmap[b_idx, idx].amax()
                    denom = torch.where(peak > 0, peak, torch.ones_like(peak))
                    heatmap[b_idx, idx].div_(denom)

        if not is_batched:
            heatmap = heatmap.squeeze(0)

        target_dtype = self.torch_dtype if isinstance(original_points, (torch.Tensor, MetaTensor)) else self.numpy_dtype
        converted, _, _ = convert_to_dst_type(heatmap, original_points, dtype=target_dtype)
        return converted

    def _resolve_spatial_shape(self, call_shape: Sequence[int] | int | None, spatial_dims: int) -> tuple[int, ...]:
        """
        Pick the call-time shape if given, else the constructor shape, and expand it to ``spatial_dims`` entries.

        Raises:
            ValueError: when no shape is available or its length is neither 1 nor ``spatial_dims``.
        """
        shape = call_shape if call_shape is not None else self.spatial_shape
        if shape is None:
            raise ValueError("spatial_shape must be provided either at construction time or call time.")
        shape_tuple = ensure_tuple(shape)
        if len(shape_tuple) != spatial_dims:
            if len(shape_tuple) != 1:
                raise ValueError(
                    "spatial_shape length must match the landmarks' spatial dims (or pass a single int to broadcast)."
                )
            # A single value is repeated for every spatial dimension.
            shape_tuple = shape_tuple * spatial_dims  # type: ignore
        return tuple(int(s) for s in shape_tuple)

    def _resolve_sigma(self, spatial_dims: int) -> tuple[float, ...]:
        """
        Expand the stored sigma to one value per spatial dimension.

        Raises:
            ValueError: when the stored sigma has a length other than 1 or ``spatial_dims``.
        """
        if len(self._sigma) == spatial_dims:
            return self._sigma
        if len(self._sigma) == 1:
            return self._sigma * spatial_dims
        raise ValueError("sigma sequence length must equal the number of spatial dimensions.")

    @staticmethod
    def _is_inside(center: Sequence[float], bounds: tuple[int, ...]) -> bool:
        """Return ``True`` when every coordinate ``c`` satisfies ``0 <= c < size`` for its axis."""
        return all(0 <= c < size for c, size in zip(center, bounds))

    def _make_window(
        self, center: Sequence[float], radius: tuple[int, ...], bounds: tuple[int, ...], device: torch.device
    ) -> tuple[tuple[slice, ...] | None, tuple[torch.Tensor, ...]]:
        """
        Build the clipped evaluation window around ``center``.

        Args:
            center: landmark coordinates, one per spatial axis.
            radius: window half-width in voxels, one per spatial axis.
            bounds: image size, one per spatial axis.
            device: device on which the coordinate offsets are created.

        Returns:
            A pair ``(slices, coord_shifts)`` where ``slices`` selects the window in the image and
            ``coord_shifts`` holds, per axis, the offsets of the window's voxel indices from the center.
            ``(None, ())`` is returned when the window is empty along any axis.
        """
        slices: list[slice] = []
        coord_shifts: list[torch.Tensor] = []
        for c, r, size in zip(center, radius, bounds):
            # Window covers floor(c - r) .. ceil(c + r) inclusive, clipped to the image extent.
            start = max(int(np.floor(c - r)), 0)
            stop = min(int(np.ceil(c + r)) + 1, size)
            if start >= stop:
                return None, ()
            slices.append(slice(start, stop))
            coord_shifts.append(torch.arange(start, stop, device=device, dtype=torch.float32) - float(c))
        return tuple(slices), tuple(coord_shifts)

    def _evaluate_gaussian(self, coord_shifts: tuple[torch.Tensor, ...], sigma: tuple[float, ...]) -> torch.Tensor:
        """
        Evaluate Gaussian at given coordinate shifts with specified sigmas.

        Args:
            coord_shifts: Per-dimension coordinate offsets from center.
            sigma: Per-dimension standard deviations.

        Returns:
            Gaussian values at the specified coordinates, cast to the transform's torch dtype.
        """
        device = coord_shifts[0].device
        shape = tuple(len(axis) for axis in coord_shifts)
        if 0 in shape:
            return torch.zeros(shape, dtype=self.torch_dtype, device=device)
        # Accumulate the separable squared distances in float32, broadcasting each axis into the full window.
        exponent = torch.zeros(shape, dtype=torch.float32, device=device)
        for dim, (shift, sig) in enumerate(zip(coord_shifts, sigma)):
            scaled = (shift.to(torch.float32) / float(sig)) ** 2
            reshape_shape = [1] * len(coord_shifts)
            reshape_shape[dim] = shift.numel()
            exponent += scaled.reshape(reshape_shape)
        gauss = torch.exp(-0.5 * exponent)
        return gauss.to(dtype=self.torch_dtype)


class ProbNMS(Transform):
"""
Performs probability based non-maximum suppression (NMS) on the probabilities map via
Expand Down
Loading
Loading