Voxel shift maps
As mentioned in today's call, I think that adding the full fieldmap into nitransforms would be difficult. But thinking a bit more about voxel shifts, I think we can support them fairly straightforwardly as an argument to `apply()`:
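```python
import numpy as np
from scipy import ndimage as ndi

from nitransforms.linear import Affine


class TransformBase:
    def apply(
        self,
        spatialimage,
        reference=None,
        voxel_shift_map=None,
        order=3,
        mode="constant",
        cval=0.0,
        prefilter=True,
        output_dtype=None,
    ):
        ...
        # Ignoring transposes and homogeneous coordinates for brevity
        data = np.asanyarray(spatialimage.dataobj)
        rascoords = self.map(reference.ndcoords)
        # The image affine maps voxel indices to RAS, so its inverse is needed
        # to bring the mapped RAS coordinates back to voxel indices
        voxcoords = (
            Affine(np.linalg.inv(spatialimage.affine))
            .map(rascoords)
            .reshape((reference.ndim, *reference.shape))
        )
        if voxel_shift_map is not None:  # a bare `if` on an ndarray raises ValueError
            # voxel_shift_map must have shape (reference.ndim, *reference.shape)
            # Alternately, we could accept it in (*reference.shape, reference.ndim) and roll axes
            voxcoords += voxel_shift_map
        resampled = ndi.map_coordinates(
            data,
            voxcoords,
            output=output_dtype,
            order=order,
            mode=mode,
            cval=cval,
            prefilter=prefilter,
        )
```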
Because `map` operates on RAS coordinates rather than voxel indices, we cannot apply the voxel shifts inside it, so we probably do not want to make the VSM part of the transform itself.
We specifically do not want to describe voxel shift maps in the world space of the target image. While it may be possible to fit the VSM in at the end of the chain, after the motion-correction transforms, any such solution would be more complicated than the approach above.
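A minimal, self-contained sketch of the idea above, using only NumPy and SciPy rather than nitransforms. It assumes the spatial transform is the identity, and `resample_with_vsm` is a hypothetical helper, not an existing API. It computes source voxel coordinates for every reference voxel, adds the VSM in voxel units, and resamples with `scipy.ndimage.map_coordinates`.

```python
import numpy as np
from scipy import ndimage as ndi


def resample_with_vsm(data, src_affine, ref_shape, ref_affine, vsm=None, order=1):
    """Resample ``data`` onto a reference grid, adding an optional voxel shift map.

    Assumption: the spatial transform is the identity, so the only mapping is
    reference voxels -> RAS -> source voxels. ``vsm`` has shape
    (ndim, *ref_shape) and is expressed in source voxel units.
    """
    ndim = len(ref_shape)
    # Reference voxel indices flattened to (ndim, N), plus a homogeneous row of ones
    ijk = np.indices(ref_shape).reshape(ndim, -1)
    ijk_h = np.vstack([ijk, np.ones((1, ijk.shape[1]))])
    # Reference voxels -> RAS -> source voxel indices (inverse of the source affine)
    voxcoords = (np.linalg.inv(src_affine) @ ref_affine @ ijk_h)[:ndim]
    voxcoords = voxcoords.reshape((ndim, *ref_shape))
    if vsm is not None:
        # The shift is added in voxel space, after leaving RAS
        voxcoords = voxcoords + vsm
    return ndi.map_coordinates(data, voxcoords, order=order, mode="nearest")


data = np.arange(4 * 5 * 6, dtype=float).reshape(4, 5, 6)
affine = np.diag([2.0, 2.0, 2.0, 1.0])  # 2 mm isotropic voxels
vsm = np.zeros((3, 4, 5, 6))
vsm[1] = 1.0  # shift every sample by one voxel along the second (j) axis
shifted = resample_with_vsm(data, affine, data.shape, affine, vsm)
```

With a one-voxel shift along j, `shifted[:, j, :]` samples `data[:, j + 1, :]`. A VSM derived from a fieldmap is typically nonzero only along the phase-encoding axis, so in practice all but one component of `vsm` would be zero.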
Per-volume transformations
The above discussion works for an individual volume. In order to correctly handle VSMs in a motion-corrected frame, we need `TransformChain`s to become aware that they are involved in a per-volume transform. Unfortunately, right now, `TransformChain`s are iterable over transforms, while `LinearTransformsMapping`s are iterable over volumes, which at the very least means straightforward API composition isn't going to work.
Currently, `LinearTransformsMapping` handles the per-volume iteration itself, inside `apply()`:
nitransforms/nitransforms/linear.py
Lines 395 to 498 in 1674e86
```python
def apply(
    self,
    spatialimage,
    reference=None,
    order=3,
    mode="constant",
    cval=0.0,
    prefilter=True,
    output_dtype=None,
):
    """
    Apply a transformation to an image, resampling on the reference spatial object.

    Parameters
    ----------
    spatialimage : `spatialimage`
        The image object containing the data to be resampled in reference
        space
    reference : spatial object, optional
        The image, surface, or combination thereof containing the coordinates
        of samples that will be sampled.
    order : int, optional
        The order of the spline interpolation, default is 3.
        The order has to be in the range 0-5.
    mode : {"constant", "reflect", "nearest", "mirror", "wrap"}, optional
        Determines how the input image is extended when the resamplings overflows
        a border. Default is "constant".
    cval : float, optional
        Constant value for ``mode="constant"``. Default is 0.0.
    prefilter: bool, optional
        Determines if the image's data array is prefiltered with
        a spline filter before interpolation. The default is ``True``,
        which will create a temporary *float64* array of filtered values
        if *order > 1*. If setting this to ``False``, the output will be
        slightly blurred if *order > 1*, unless the input is prefiltered,
        i.e. it is the result of calling the spline filter on the original
        input.

    Returns
    -------
    resampled : `spatialimage` or ndarray
        The data imaged after resampling to reference space.

    """
    if reference is not None and isinstance(reference, (str, Path)):
        reference = _nbload(str(reference))

    _ref = (
        self.reference if reference is None else SpatialReference.factory(reference)
    )

    if isinstance(spatialimage, (str, Path)):
        spatialimage = _nbload(str(spatialimage))

    data = np.squeeze(np.asanyarray(spatialimage.dataobj))
    output_dtype = output_dtype or data.dtype

    ycoords = self.map(_ref.ndcoords.T)
    targets = ImageGrid(spatialimage).index(  # data should be an image
        _as_homogeneous(np.vstack(ycoords), dim=_ref.ndim)
    )

    if data.ndim == 4:
        if len(self) != data.shape[-1]:
            raise ValueError(
                "Attempting to apply %d transforms on a file with "
                "%d timepoints" % (len(self), data.shape[-1])
            )
        targets = targets.reshape((len(self), -1, targets.shape[-1]))
        resampled = np.stack(
            [
                ndi.map_coordinates(
                    data[..., t],
                    targets[t, ..., : _ref.ndim].T,
                    output=output_dtype,
                    order=order,
                    mode=mode,
                    cval=cval,
                    prefilter=prefilter,
                )
                for t in range(data.shape[-1])
            ],
            axis=0,
        )
    elif data.ndim in (2, 3):
        resampled = ndi.map_coordinates(
            data,
            targets[..., : _ref.ndim].T,
            output=output_dtype,
            order=order,
            mode=mode,
            cval=cval,
            prefilter=prefilter,
        )

    if isinstance(_ref, ImageGrid):  # If reference is grid, reshape
        newdata = resampled.reshape((len(self), *_ref.shape))
        moved = spatialimage.__class__(
            np.moveaxis(newdata, 0, -1), _ref.affine, spatialimage.header
        )
        moved.header.set_data_dtype(output_dtype)
        return moved

    return resampled
```
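To make the per-volume case concrete, here is a sketch of how a loop like the one in `LinearTransformsMapping.apply()` above could add a VSM after each volume's motion transform. It is plain NumPy/SciPy, and `resample_series` is a hypothetical helper. It assumes each 4x4 matrix maps reference RAS to source RAS for its volume, and that one VSM sampled on the reference grid (in source voxel units) is shared by all volumes; both are assumptions of this sketch.

```python
import numpy as np
from scipy import ndimage as ndi


def resample_series(data, src_affine, ref_shape, ref_affine, volume_xfms, vsm=None, order=1):
    """Resample a 4D series volume by volume, adding a VSM after the motion transform.

    volume_xfms: one 4x4 matrix per volume, assumed to map reference RAS to source RAS.
    vsm: optional array of shape (ndim, *ref_shape) in source voxel units.
    """
    ndim = len(ref_shape)
    if len(volume_xfms) != data.shape[-1]:
        raise ValueError("need exactly one transform per volume")
    # Reference voxel indices in homogeneous form, shape (ndim + 1, N)
    ijk = np.indices(ref_shape).reshape(ndim, -1)
    ijk_h = np.vstack([ijk, np.ones((1, ijk.shape[1]))])
    src_inv = np.linalg.inv(src_affine)
    volumes = []
    for t, xfm in enumerate(volume_xfms):
        # reference voxels -> reference RAS -> source RAS (per volume) -> source voxels
        vox = (src_inv @ xfm @ ref_affine @ ijk_h)[:ndim].reshape((ndim, *ref_shape))
        if vsm is not None:
            vox = vox + vsm  # same shift map for every volume in this sketch
        volumes.append(ndi.map_coordinates(data[..., t], vox, order=order, mode="nearest"))
    return np.stack(volumes, axis=-1)


data = np.random.default_rng(0).random((4, 5, 6, 2))
affine = np.diag([2.0, 2.0, 2.0, 1.0])
shift_y = np.eye(4)
shift_y[1, 3] = 2.0  # volume 1 is translated 2 mm (one voxel) along j
out = resample_series(data, affine, (4, 5, 6), affine, [np.eye(4), shift_y])
```

Here volume 0 is returned unchanged, while volume 1 is sampled one voxel further along j.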
A VSM- and multivolume-aware `TransformChain` could do what we want in `apply()`. Another thought is that we could treat transforms as data objects rather than actors, moving resampling out of the transform classes and into a standalone function. The interface could be:
```python
def apply_transform(
    source: SpatialImage,
    target: Pointset,
    transform: TransformBase,
    shift_map: np.ndarray,
    # map_coordinates args (order, mode, cval, prefilter, output, ...)
    **map_coordinates_kwargs,
) -> np.ndarray:
    ...
```
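A runnable sketch of this "transforms as data" split, under simplifying assumptions: the source is passed as a raw array plus affine instead of a `SpatialImage`, the target is an `(N, 3)` array of RAS points standing in for a `Pointset`, and `AffineMap` is a hypothetical minimal transform that only knows how to `map` points. All sampling lives in the function, not in the transform.

```python
import numpy as np
from scipy import ndimage as ndi


class AffineMap:
    """Hypothetical minimal transform: it maps points and nothing else."""

    def __init__(self, matrix):
        self.matrix = np.asarray(matrix, dtype=float)

    def map(self, points):
        # points: (N, 3) RAS coordinates -> (N, 3) RAS coordinates
        return points @ self.matrix[:3, :3].T + self.matrix[:3, 3]


def apply_transform(source_data, source_affine, target_points, transform,
                    shift_map=None, **map_coordinates_kwargs):
    ras = transform.map(target_points)
    inv = np.linalg.inv(source_affine)
    # RAS -> source voxel indices, transposed to (3, N) for map_coordinates
    vox = (ras @ inv[:3, :3].T + inv[:3, 3]).T
    if shift_map is not None:
        vox = vox + shift_map  # shift_map: (3, N), in source voxel units
    return ndi.map_coordinates(source_data, vox, **map_coordinates_kwargs)


data = np.arange(4 * 5 * 6).reshape(4, 5, 6)
points = np.array([[0.0, 0.0, 0.0], [1.0, 2.0, 3.0]])
xfm = AffineMap(np.array([[1, 0, 0, 1], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]]))
out = apply_transform(data, np.eye(4), points, xfm, order=0)
vsm = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 0.0]])  # shift only the first point along j
out_shifted = apply_transform(data, np.eye(4), points, xfm, shift_map=vsm, order=0)
```

One appeal of this design is that `AffineMap` stays trivially composable, since the VSM and interpolation options never pass through it.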
If we give up on defining `apply()` correctly for each transform class, and instead let transforms focus on composing and mapping, things might become cleaner. Just imagining how we might approach chains that include per-volume transforms:
```python
from __future__ import annotations

import itertools
from collections.abc import Iterator


class TransformBase:
    n_transforms: int = 1

    def iter_transforms(self) -> Iterator[TransformBase]:
        """Repeat current transform as often as required"""
        return itertools.repeat(self)


class AffineSeries(TransformBase):
    @property
    def n_transforms(self) -> int:
        return len(self.series)

    def iter_transforms(self) -> Iterator[TransformBase]:
        """Iterate over the defined series"""
        return iter(self.series)


class TransformChain(TransformBase):
    @property
    def n_transforms(self) -> int:
        # Single transforms broadcast; the shortest series sets the length
        lengths = [xfm.n_transforms for xfm in self.chain if xfm.n_transforms != 1]
        return min(lengths) if lengths else 1

    def iter_transforms(self) -> Iterator[TransformChain]:
        """Iterate over all transforms in chain, simultaneously, stopping with first to stop"""
        # Each zipped tuple holds one transform per chain element. A chain made
        # only of single transforms never stops, so callers should bound it
        # with n_transforms.
        return map(TransformChain, zip(*(xfm.iter_transforms() for xfm in self.chain)))
```
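To check that these broadcasting rules behave as intended, here is a self-contained toy version of the three classes above. The `name` attribute, the constructors, and the `__repr__` methods are hypothetical additions for illustration only; real transforms would carry matrices or fields rather than labels.

```python
from __future__ import annotations

import itertools
from collections.abc import Iterator, Sequence


class TransformBase:
    n_transforms: int = 1

    def __init__(self, name: str):
        self.name = name  # hypothetical label for the demo

    def iter_transforms(self) -> Iterator[TransformBase]:
        return itertools.repeat(self)  # a single transform repeats indefinitely

    def __repr__(self) -> str:
        return self.name


class AffineSeries(TransformBase):
    def __init__(self, series: Sequence[TransformBase]):
        self.series = list(series)

    @property
    def n_transforms(self) -> int:
        return len(self.series)

    def iter_transforms(self) -> Iterator[TransformBase]:
        return iter(self.series)

    def __repr__(self) -> str:
        return f"AffineSeries({len(self.series)})"


class TransformChain(TransformBase):
    def __init__(self, chain):
        self.chain = list(chain)

    @property
    def n_transforms(self) -> int:
        lengths = [xfm.n_transforms for xfm in self.chain if xfm.n_transforms != 1]
        return min(lengths) if lengths else 1

    def iter_transforms(self) -> Iterator[TransformChain]:
        return map(TransformChain, zip(*(xfm.iter_transforms() for xfm in self.chain)))

    def __repr__(self) -> str:
        return " -> ".join(repr(xfm) for xfm in self.chain)


coreg = TransformBase("coreg")
motion = AffineSeries([TransformBase(f"hmc{t}") for t in range(3)])
chain = TransformChain([coreg, motion])
per_volume = [repr(c) for c in chain.iter_transforms()]
# A chain of single transforms is infinite, so bound it with n_transforms
static = TransformChain([coreg])
static_volumes = list(itertools.islice(static.iter_transforms(), static.n_transforms))
```

Here `chain` expands to one chain per volume (`coreg -> hmc0`, `coreg -> hmc1`, `coreg -> hmc2`), while the static chain has to be truncated explicitly.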