Skip to content

Commit dfea4ff

Browse files
Add monet_bundle_inference_operator.py
- Introduced a new file containing the implementation of the MONetBundleInferenceOperator. - This operator extends the MonaiBundleInferenceOperator to facilitate inference with nnUNet-specific configurations. - Implemented methods for configuration initialization and multimodal data prediction, enhancing the MONAI framework's inference capabilities.
1 parent 4c4fbb5 commit dfea4ff

File tree

1 file changed

+187
-0
lines changed

1 file changed

+187
-0
lines changed
Lines changed: 187 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,187 @@
1+
# Copyright 2002 MONAI Consortium
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
# Unless required by applicable law or agreed to in writing, software
7+
# distributed under the License is distributed on an "AS IS" BASIS,
8+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
# See the License for the specific language governing permissions and
10+
# limitations under the License.
11+
12+
from typing import Any, Dict, Tuple, Union
13+
import logging
14+
from monai.deploy.core import Image
15+
from monai.deploy.operators.monai_bundle_inference_operator import MonaiBundleInferenceOperator, get_bundle_config
16+
from monai.deploy.utils.importutil import optional_import
17+
from monai.transforms import SpatialResample, ConcatItemsd
18+
import numpy as np
19+
MONAI_UTILS = "monai.utils"
20+
nibabel, _ = optional_import("nibabel", "3.2.1")
21+
torch, _ = optional_import("torch", "1.10.2")
22+
23+
NdarrayOrTensor, _ = optional_import("monai.config", name="NdarrayOrTensor")
24+
MetaTensor, _ = optional_import("monai.data.meta_tensor", name="MetaTensor")
25+
PostFix, _ = optional_import("monai.utils.enums", name="PostFix") # For the default meta_key_postfix
26+
first, _ = optional_import("monai.utils.misc", name="first")
27+
ensure_tuple, _ = optional_import(MONAI_UTILS, name="ensure_tuple")
28+
convert_to_dst_type, _ = optional_import(MONAI_UTILS, name="convert_to_dst_type")
29+
Key, _ = optional_import(MONAI_UTILS, name="ImageMetaKey")
30+
MetaKeys, _ = optional_import(MONAI_UTILS, name="MetaKeys")
31+
SpaceKeys, _ = optional_import(MONAI_UTILS, name="SpaceKeys")
32+
Compose_, _ = optional_import("monai.transforms", name="Compose")
33+
ConfigParser_, _ = optional_import("monai.bundle", name="ConfigParser")
34+
MapTransform_, _ = optional_import("monai.transforms", name="MapTransform")
35+
SimpleInferer, _ = optional_import("monai.inferers", name="SimpleInferer")
36+
37+
Compose: Any = Compose_
38+
MapTransform: Any = MapTransform_
39+
ConfigParser: Any = ConfigParser_
40+
__all__ = ["MONetBundleInferenceOperator"]
41+
42+
43+
def define_affine_from_meta(meta: Dict[str, Any]) -> np.ndarray:
44+
"""
45+
Define an affine matrix from the metadata of a tensor.
46+
47+
Parameters
48+
----------
49+
meta : Dict[str, Any]
50+
Metadata dictionary containing 'pixdim', 'origin', and 'direction'.
51+
52+
Returns
53+
-------
54+
np.ndarray
55+
A 4x4 affine matrix constructed from the metadata.
56+
"""
57+
if "pixdim" not in meta or "origin" not in meta or "direction" not in meta:
58+
return meta.get("affine", np.eye(4))
59+
pixdim = meta["pixdim"]
60+
origin = meta["origin"]
61+
direction = meta["direction"].reshape(3, 3)
62+
63+
# Extract 3D spacing
64+
spacing = pixdim[1:4] # drop the first element (usually 1 for time dim)
65+
66+
# Scale the direction vectors by spacing to get rotation+scale part
67+
affine = direction * spacing[np.newaxis, :]
68+
69+
# Append origin to get 3x4 affine matrix
70+
affine = np.column_stack((affine, origin))
71+
72+
# Make it a full 4x4 affine
73+
affine_4x4 = np.vstack((affine, [0, 0, 0, 1]))
74+
pixdim = meta["pixdim"]
75+
origin = meta["origin"]
76+
direction = meta["direction"].reshape(3, 3)
77+
78+
# Extract 3D spacing
79+
spacing = pixdim[1:4] # drop the first element (usually 1 for time dim)
80+
81+
# Scale the direction vectors by spacing to get rotation+scale part
82+
affine = direction * spacing[np.newaxis, :]
83+
84+
# Append origin to get 3x4 affine matrix
85+
affine = np.column_stack((affine, origin))
86+
87+
# Make it a full 4x4 affine
88+
return torch.Tensor(np.vstack((affine, [0, 0, 0, 1])))
89+
90+
class MONetBundleInferenceOperator(MonaiBundleInferenceOperator):
91+
"""
92+
A specialized operator for performing inference using the MONet bundle.
93+
This operator extends the `MonaiBundleInferenceOperator` to support nnUNet-specific
94+
configurations and prediction logic. It initializes the nnUNet predictor and provides
95+
a method for performing inference on input data.
96+
97+
Attributes
98+
----------
99+
_nnunet_predictor : torch.nn.Module
100+
The nnUNet predictor module used for inference.
101+
102+
Methods
103+
-------
104+
_init_config(config_names)
105+
Initializes the configuration for the nnUNet bundle, including parsing the bundle
106+
configuration and setting up the nnUNet predictor.
107+
predict(data, *args, **kwargs)
108+
Performs inference on the input data using the nnUNet predictor.
109+
"""
110+
111+
def __init__(
112+
self,
113+
*args,
114+
**kwargs,
115+
):
116+
117+
super().__init__(*args, **kwargs)
118+
119+
self._logger = logging.getLogger("{}.{}".format(__name__, type(self).__name__))
120+
self._nnunet_predictor: torch.nn.Module = None
121+
self.ref_modality = None
122+
if "ref_modality" in kwargs:
123+
self.ref_modality = kwargs["ref_modality"]
124+
125+
def _init_config(self, config_names):
126+
127+
super()._init_config(config_names)
128+
parser = get_bundle_config(str(self._bundle_path), config_names)
129+
self._parser = parser
130+
131+
self._nnunet_predictor = parser.get_parsed_content("network_def")
132+
133+
def predict(self, data: Any, *args, **kwargs) -> Union[Image, Any, Tuple[Any, ...], Dict[Any, Any]]:
134+
"""Predicts output using the inferer."""
135+
136+
self._nnunet_predictor.predictor.network = self._model_network
137+
# os.environ['nnUNet_def_n_proc'] = "1"
138+
139+
if len(kwargs) > 0:
140+
multimodal_data = {"image": data}
141+
if self.ref_modality is not None:
142+
if self.ref_modality not in kwargs:
143+
target_affine_4x4 = define_affine_from_meta(data.meta)
144+
spatial_size = data.shape[1:4]
145+
if "pixdim" in data.meta:
146+
pixdim = data.meta["pixdim"]
147+
else:
148+
pixdim = np.abs(np.array(target_affine_4x4[:3, :3].diagonal().tolist()))
149+
else:
150+
target_affine_4x4 = define_affine_from_meta(kwargs[self.ref_modality].meta)
151+
spatial_size = kwargs[self.ref_modality].shape[1:4]
152+
if "pixdim" in kwargs[self.ref_modality].meta:
153+
pixdim = kwargs[self.ref_modality].meta["pixdim"]
154+
else:
155+
pixdim = np.abs(np.array(target_affine_4x4[:3, :3].diagonal().tolist()))
156+
else:
157+
target_affine_4x4 = define_affine_from_meta(data.meta)
158+
spatial_size = data.shape[1:4]
159+
if "pixdim" in data.meta:
160+
pixdim = data.meta["pixdim"]
161+
else:
162+
pixdim = np.abs(np.array(target_affine_4x4[:3, :3].diagonal().tolist()))
163+
164+
for key in kwargs.keys():
165+
if isinstance(kwargs[key], MetaTensor):
166+
source_affine_4x4 = define_affine_from_meta(kwargs[key].meta)
167+
kwargs[key].meta["affine"] = torch.Tensor(source_affine_4x4)
168+
kwargs[key].meta["pixdim"] = pixdim
169+
self._logger.info(f"Resampling {key} from {source_affine_4x4} to {target_affine_4x4}")
170+
171+
multimodal_data[key] = SpatialResample(mode="bilinear")(kwargs[key], dst_affine=target_affine_4x4,
172+
spatial_size=spatial_size,
173+
)
174+
source_affine_4x4 = define_affine_from_meta(data.meta)
175+
data.meta["affine"] = torch.Tensor(source_affine_4x4)
176+
data.meta["pixdim"] = pixdim
177+
multimodal_data["image"] = SpatialResample(mode="bilinear")(
178+
data, dst_affine=target_affine_4x4, spatial_size=spatial_size
179+
)
180+
181+
self._logger.info(f"Resampling 'image' from from {source_affine_4x4} to {target_affine_4x4}")
182+
data = ConcatItemsd(keys=list(multimodal_data.keys()),name="image")(multimodal_data)["image"]
183+
data.meta["pixdim"] = np.insert(pixdim, 0, 0)
184+
185+
if len(data.shape) == 4:
186+
data = data[None]
187+
return self._nnunet_predictor(data)

0 commit comments

Comments
 (0)