Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Empty file.
188 changes: 188 additions & 0 deletions nitransforms/analysis/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,188 @@
# emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: nil -*-
# vi: set ft=python sts=4 ts=4 sw=4 et:
"""
Utilities to aid in performing and evaluating image registration.

This module provides functions to compute displacements of image coordinates
under a transformation, useful for assessing the accuracy of image registration
processes.

"""

from __future__ import annotations

from itertools import product
from typing import Tuple

import nibabel as nb
import numpy as np
from scipy.stats import zscore

from nitransforms.base import TransformBase


RADIUS = 50.0
"""Typical radius (in mm) of a sphere mimicking the size of a typical human brain."""


def compute_fd_from_motion(motion_parameters: np.ndarray, radius: float = RADIUS) -> np.ndarray:
"""Compute framewise displacement (FD) from motion parameters.

Each row in the motion parameters represents one frame, and columns
represent each coordinate axis ``x``, `y``, and ``z``. Translation
parameters are followed by rotation parameters column-wise.

Parameters
----------
motion_parameters : :obj:`numpy.ndarray`
Motion parameters.
radius : :obj:`float`, optional
Radius (in mm) of a sphere mimicking the size of a typical human brain.

Returns
-------
:obj:`numpy.ndarray`
The framewise displacement (FD) as the sum of absolute differences
between consecutive frames.
"""

translations = motion_parameters[:, :3]
rotations_deg = motion_parameters[:, 3:]
rotations_rad = np.deg2rad(rotations_deg)

# Compute differences between consecutive frames
d_translations = np.vstack([np.zeros((1, 3)), np.diff(translations, axis=0)])
d_rotations = np.vstack([np.zeros((1, 3)), np.diff(rotations_rad, axis=0)])

# Convert rotations from radians to displacement on a sphere
rotation_displacement = d_rotations * radius

# Compute FD as sum of absolute differences
return np.sum(np.abs(d_translations) + np.abs(rotation_displacement), axis=1)


def compute_fd_from_transform(
img: nb.spatialimages.SpatialImage,
test_xfm: TransformBase,
radius: float = RADIUS,
) -> float:
"""
Compute the framewise displacement (FD) for a given transformation.

Parameters
----------
img : :obj:`~nibabel.spatialimages.SpatialImage`
The reference image. Used to extract the center coordinates.
test_xfm : :obj:`~nitransforms.base.TransformBase`
The transformation to test. Applied to coordinates around the image center.
radius : :obj:`float`, optional
The radius (in mm) of the spherical neighborhood around the center of the image.

Returns
-------
:obj:`float`
The average framewise displacement (FD) for the test transformation.

"""
affine = img.affine
# Compute the center of the image in voxel space
center_ijk = 0.5 * (np.array(img.shape[:3]) - 1)
# Convert to world coordinates
center_xyz = nb.affines.apply_affine(affine, center_ijk)
# Generate coordinates of points at radius distance from center
fd_coords = np.array(list(product(*((radius, -radius),) * 3))) + center_xyz
# Compute the average displacement from the test transformation
return np.mean(np.linalg.norm(test_xfm.map(fd_coords) - fd_coords, axis=-1))


def displacements_within_mask(
mask_img: nb.spatialimages.SpatialImage,
test_xfm: TransformBase,
reference_xfm: TransformBase | None = None,
) -> np.ndarray:
"""
Compute the distance between voxel coordinates mapped through two transforms.

Parameters
----------
mask_img : :obj:`~nibabel.spatialimages.SpatialImage`
A mask image that defines the region of interest. Voxel coordinates
within the mask are transformed.
test_xfm : :obj:`~nitransforms.base.TransformBase`
The transformation to test. This transformation is applied to the
voxel coordinates.
reference_xfm : :obj:`~nitransforms.base.TransformBase`, optional
A reference transformation to compare with. If ``None``, the identity
transformation is assumed (no transformation).

Returns
-------
:obj:`~numpy.ndarray`
An array of displacements (in mm) for each voxel within the mask.

"""
# Mask data as boolean (True for voxels inside the mask)
maskdata = np.asanyarray(mask_img.dataobj) > 0
# Convert voxel coordinates to world coordinates using affine transform
xyz = nb.affines.apply_affine(
mask_img.affine,
np.argwhere(maskdata),
)
# Apply the test transformation
targets = test_xfm.map(xyz)

# Compute the difference (displacement) between the test and reference transformations
diffs = targets - xyz if reference_xfm is None else targets - reference_xfm.map(xyz)
return np.linalg.norm(diffs, axis=-1)


def extract_motion_parameters(affine: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
"""Extract translation (mm) and rotation (degrees) parameters from an affine matrix.

Parameters
----------
affine : :obj:`~numpy.ndarray`
The affine transformation matrix.

Returns
-------
:obj:`tuple`
Extracted translation and rotation parameters.
"""

translation = affine[:3, 3]
rotation_rad = np.arctan2(
[affine[2, 1], affine[0, 2], affine[1, 0]], [affine[2, 2], affine[0, 0], affine[1, 1]]
)
rotation_deg = np.rad2deg(rotation_rad)
return *translation, *rotation_deg


def identify_spikes(fd: np.ndarray, threshold: float = 2.0):
"""Identify motion spikes in framewise displacement data.

Identifies high-motion frames as timepoint exceeding a given threshold value
based on z-score normalized framewise displacement (FD) values.

Parameters
----------
fd : :obj:`~numpy.ndarray`
Framewise displacement data.
threshold : :obj:`float`, optional
Threshold value to determine motion spikes.

Returns
-------
indices : :obj:`~numpy.ndarray`
Indices of identified motion spikes.
mask : :obj:`~numpy.ndarray`
Mask of identified motion spikes.
"""

# Normalize (z-score)
fd_norm = zscore(fd)

mask = fd_norm > threshold
indices = np.where(mask)[0]

return indices, mask
11 changes: 11 additions & 0 deletions nitransforms/tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
# emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: nil -*-
# vi: set ft=python sts=4 ts=4 sw=4 et:

import numpy as np
import pytest


@pytest.fixture(autouse=True)
def random_number_generator(request):
"""Automatically set a fixed-seed random number generator for all tests."""
request.node.rng = np.random.default_rng(1234)
156 changes: 156 additions & 0 deletions nitransforms/tests/test_analysis.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
# emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: nil -*-
# vi: set ft=python sts=4 ts=4 sw=4 et:

import numpy as np
import nibabel as nb
import pytest

import nitransforms as nt

from nitransforms.analysis.utils import (
compute_fd_from_motion,
compute_fd_from_transform,
displacements_within_mask,
extract_motion_parameters,
identify_spikes,
)


@pytest.fixture
def identity_affine():
return np.eye(4)


@pytest.fixture
def simple_mask_img(identity_affine):
# 3x3x3 mask with center voxel as 1, rest 0
data = np.zeros((3, 3, 3), dtype=np.uint8)
data[1, 1, 1] = 1
return nb.Nifti1Image(data, identity_affine)


@pytest.fixture
def translation_transform():
# Simple translation of (1, 2, 3) mm
return nt.linear.Affine(map=np.array([
[1, 0, 0, 1],
[0, 1, 0, 2],
[0, 0, 1, 3],
[0, 0, 0, 1],
]))


@pytest.fixture
def rotation_transform():
# 90 degree rotation around z axis
angle = np.pi / 2
rot = np.array([
[np.cos(angle), -np.sin(angle), 0, 0],
[np.sin(angle), np.cos(angle), 0, 0],
[0, 0, 1, 0],
[0, 0, 0, 1],
])
return nt.linear.Affine(map=rot)


@pytest.mark.parametrize(
"test_xfm, reference_xfm, expected",
[
(nt.linear.Affine(np.eye(4)), None, np.zeros(1)),
(nt.linear.Affine(np.array([
[1, 0, 0, 1],
[0, 1, 0, 2],
[0, 0, 1, 3],
[0, 0, 0, 1],
])), None, [np.linalg.norm([1, 2, 3])]),
(nt.linear.Affine(np.array([
[1, 0, 0, 1],
[0, 1, 0, 2],
[0, 0, 1, 3],
[0, 0, 0, 1],
])), nt.linear.Affine(np.eye(4)), [np.linalg.norm([1, 2, 3])]),
],
)
def test_displacements_within_mask(simple_mask_img, test_xfm, reference_xfm, expected):
disp = displacements_within_mask(simple_mask_img, test_xfm, reference_xfm)
np.testing.assert_allclose(disp, expected)


@pytest.mark.parametrize(
"test_xfm, expected",
[
(nt.linear.Affine(np.eye(4)), 0),
(nt.linear.Affine(np.array([
[1, 0, 0, 1],
[0, 1, 0, 2],
[0, 0, 1, 3],
[0, 0, 0, 1],
])), np.linalg.norm([1, 2, 3])),
],
)
def test_compute_fd_from_transform(simple_mask_img, test_xfm, expected):
fd = compute_fd_from_transform(simple_mask_img, test_xfm)
assert np.isclose(fd, expected)


@pytest.mark.parametrize(
"motion_params, radius, expected",
[
(np.zeros((5, 6)), 50, np.zeros(5)), # 5 frames, 3 trans, 3 rot
(
np.array([
[0,0,0,0,0,0],
[2,0,0,0,0,0], # 2mm translation in x at frame 1
[2,0,0,90,0,0],
]), # 90deg rotation in x at frame 2
50,
[0, 2, abs(np.deg2rad(90)) * 50]
), # First frame: 0, Second: translation 2mm, Third: rotation (pi/2)*50
],
)
def test_compute_fd_from_motion(motion_params, radius, expected):
fd = compute_fd_from_motion(motion_params, radius=radius)
np.testing.assert_allclose(fd, expected, atol=1e-4)


@pytest.mark.parametrize(
"affine, expected_trans, expected_rot",
[
(np.eye(4) + np.array([[0,0,0,10],[0,0,0,15],[0,0,0,20],[0,0,0,0]]), # translation only
[10, 15, 20], [0, 0, 0]),
(np.array([
[1, 0, 0, 0],
[0, np.cos(np.deg2rad(30)), -np.sin(np.deg2rad(30)), 0],
[0, np.sin(np.deg2rad(30)), np.cos(np.deg2rad(30)), 0],
[0, 0, 0, 1], # rotation only
]), [0, 0, 0], [30, 0, 0]), # Only one rot will be close to 30
],
)
def test_extract_motion_parameters(affine, expected_trans, expected_rot):
params = extract_motion_parameters(affine)
assert np.allclose(params[:3], expected_trans)
# For rotation case, at least one value close to 30
if np.any(np.abs(expected_rot)):
assert np.any(np.isclose(np.abs(params[3:]), 30, atol=1e-4))
else:
assert np.allclose(params[3:], expected_rot)


def test_identify_spikes(request):
rng = request.node.rng

n_samples = 450

fd = rng.normal(0, 5, n_samples)
threshold = 2.0

expected_indices = np.asarray(
[5, 57, 85, 100, 127, 180, 191, 202, 335, 393, 409]
)
expected_mask = np.zeros(n_samples, dtype=bool)
expected_mask[expected_indices] = True

obtained_indices, obtained_mask = identify_spikes(fd, threshold=threshold)

assert np.array_equal(obtained_indices, expected_indices)
assert np.array_equal(obtained_mask, expected_mask)
Loading