|
| 1 | +# Copyright 2002 MONAI Consortium |
| 2 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 3 | +# you may not use this file except in compliance with the License. |
| 4 | +# You may obtain a copy of the License at |
| 5 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 6 | +# Unless required by applicable law or agreed to in writing, software |
| 7 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 8 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 9 | +# See the License for the specific language governing permissions and |
| 10 | +# limitations under the License. |
| 11 | + |
| 12 | +from typing import Any, Dict, Tuple, Union |
| 13 | +import logging |
| 14 | +from monai.deploy.core import Image |
| 15 | +from monai.deploy.operators.monai_bundle_inference_operator import MonaiBundleInferenceOperator, get_bundle_config |
| 16 | +from monai.deploy.utils.importutil import optional_import |
| 17 | +from monai.transforms import SpatialResample, ConcatItemsd |
| 18 | +import numpy as np |
| 19 | +MONAI_UTILS = "monai.utils" |
| 20 | +nibabel, _ = optional_import("nibabel", "3.2.1") |
| 21 | +torch, _ = optional_import("torch", "1.10.2") |
| 22 | + |
| 23 | +NdarrayOrTensor, _ = optional_import("monai.config", name="NdarrayOrTensor") |
| 24 | +MetaTensor, _ = optional_import("monai.data.meta_tensor", name="MetaTensor") |
| 25 | +PostFix, _ = optional_import("monai.utils.enums", name="PostFix") # For the default meta_key_postfix |
| 26 | +first, _ = optional_import("monai.utils.misc", name="first") |
| 27 | +ensure_tuple, _ = optional_import(MONAI_UTILS, name="ensure_tuple") |
| 28 | +convert_to_dst_type, _ = optional_import(MONAI_UTILS, name="convert_to_dst_type") |
| 29 | +Key, _ = optional_import(MONAI_UTILS, name="ImageMetaKey") |
| 30 | +MetaKeys, _ = optional_import(MONAI_UTILS, name="MetaKeys") |
| 31 | +SpaceKeys, _ = optional_import(MONAI_UTILS, name="SpaceKeys") |
| 32 | +Compose_, _ = optional_import("monai.transforms", name="Compose") |
| 33 | +ConfigParser_, _ = optional_import("monai.bundle", name="ConfigParser") |
| 34 | +MapTransform_, _ = optional_import("monai.transforms", name="MapTransform") |
| 35 | +SimpleInferer, _ = optional_import("monai.inferers", name="SimpleInferer") |
| 36 | + |
| 37 | +Compose: Any = Compose_ |
| 38 | +MapTransform: Any = MapTransform_ |
| 39 | +ConfigParser: Any = ConfigParser_ |
| 40 | +__all__ = ["MONetBundleInferenceOperator"] |
| 41 | + |
| 42 | + |
| 43 | +def define_affine_from_meta(meta: Dict[str, Any]) -> np.ndarray: |
| 44 | + """ |
| 45 | + Define an affine matrix from the metadata of a tensor. |
| 46 | +
|
| 47 | + Parameters |
| 48 | + ---------- |
| 49 | + meta : Dict[str, Any] |
| 50 | + Metadata dictionary containing 'pixdim', 'origin', and 'direction'. |
| 51 | +
|
| 52 | + Returns |
| 53 | + ------- |
| 54 | + np.ndarray |
| 55 | + A 4x4 affine matrix constructed from the metadata. |
| 56 | + """ |
| 57 | + if "pixdim" not in meta or "origin" not in meta or "direction" not in meta: |
| 58 | + return meta.get("affine", np.eye(4)) |
| 59 | + pixdim = meta["pixdim"] |
| 60 | + origin = meta["origin"] |
| 61 | + direction = meta["direction"].reshape(3, 3) |
| 62 | + |
| 63 | + # Extract 3D spacing |
| 64 | + spacing = pixdim[1:4] # drop the first element (usually 1 for time dim) |
| 65 | + |
| 66 | + # Scale the direction vectors by spacing to get rotation+scale part |
| 67 | + affine = direction * spacing[np.newaxis, :] |
| 68 | + |
| 69 | + # Append origin to get 3x4 affine matrix |
| 70 | + affine = np.column_stack((affine, origin)) |
| 71 | + |
| 72 | + # Make it a full 4x4 affine |
| 73 | + affine_4x4 = np.vstack((affine, [0, 0, 0, 1])) |
| 74 | + pixdim = meta["pixdim"] |
| 75 | + origin = meta["origin"] |
| 76 | + direction = meta["direction"].reshape(3, 3) |
| 77 | + |
| 78 | + # Extract 3D spacing |
| 79 | + spacing = pixdim[1:4] # drop the first element (usually 1 for time dim) |
| 80 | + |
| 81 | + # Scale the direction vectors by spacing to get rotation+scale part |
| 82 | + affine = direction * spacing[np.newaxis, :] |
| 83 | + |
| 84 | + # Append origin to get 3x4 affine matrix |
| 85 | + affine = np.column_stack((affine, origin)) |
| 86 | + |
| 87 | + # Make it a full 4x4 affine |
| 88 | + return torch.Tensor(np.vstack((affine, [0, 0, 0, 1]))) |
| 89 | + |
| 90 | +class MONetBundleInferenceOperator(MonaiBundleInferenceOperator): |
| 91 | + """ |
| 92 | + A specialized operator for performing inference using the MONet bundle. |
| 93 | + This operator extends the `MonaiBundleInferenceOperator` to support nnUNet-specific |
| 94 | + configurations and prediction logic. It initializes the nnUNet predictor and provides |
| 95 | + a method for performing inference on input data. |
| 96 | +
|
| 97 | + Attributes |
| 98 | + ---------- |
| 99 | + _nnunet_predictor : torch.nn.Module |
| 100 | + The nnUNet predictor module used for inference. |
| 101 | +
|
| 102 | + Methods |
| 103 | + ------- |
| 104 | + _init_config(config_names) |
| 105 | + Initializes the configuration for the nnUNet bundle, including parsing the bundle |
| 106 | + configuration and setting up the nnUNet predictor. |
| 107 | + predict(data, *args, **kwargs) |
| 108 | + Performs inference on the input data using the nnUNet predictor. |
| 109 | + """ |
| 110 | + |
| 111 | + def __init__( |
| 112 | + self, |
| 113 | + *args, |
| 114 | + **kwargs, |
| 115 | + ): |
| 116 | + |
| 117 | + super().__init__(*args, **kwargs) |
| 118 | + |
| 119 | + self._logger = logging.getLogger("{}.{}".format(__name__, type(self).__name__)) |
| 120 | + self._nnunet_predictor: torch.nn.Module = None |
| 121 | + self.ref_modality = None |
| 122 | + if "ref_modality" in kwargs: |
| 123 | + self.ref_modality = kwargs["ref_modality"] |
| 124 | + |
| 125 | + def _init_config(self, config_names): |
| 126 | + |
| 127 | + super()._init_config(config_names) |
| 128 | + parser = get_bundle_config(str(self._bundle_path), config_names) |
| 129 | + self._parser = parser |
| 130 | + |
| 131 | + self._nnunet_predictor = parser.get_parsed_content("network_def") |
| 132 | + |
| 133 | + def predict(self, data: Any, *args, **kwargs) -> Union[Image, Any, Tuple[Any, ...], Dict[Any, Any]]: |
| 134 | + """Predicts output using the inferer.""" |
| 135 | + |
| 136 | + self._nnunet_predictor.predictor.network = self._model_network |
| 137 | + # os.environ['nnUNet_def_n_proc'] = "1" |
| 138 | + |
| 139 | + if len(kwargs) > 0: |
| 140 | + multimodal_data = {"image": data} |
| 141 | + if self.ref_modality is not None: |
| 142 | + if self.ref_modality not in kwargs: |
| 143 | + target_affine_4x4 = define_affine_from_meta(data.meta) |
| 144 | + spatial_size = data.shape[1:4] |
| 145 | + if "pixdim" in data.meta: |
| 146 | + pixdim = data.meta["pixdim"] |
| 147 | + else: |
| 148 | + pixdim = np.abs(np.array(target_affine_4x4[:3, :3].diagonal().tolist())) |
| 149 | + else: |
| 150 | + target_affine_4x4 = define_affine_from_meta(kwargs[self.ref_modality].meta) |
| 151 | + spatial_size = kwargs[self.ref_modality].shape[1:4] |
| 152 | + if "pixdim" in kwargs[self.ref_modality].meta: |
| 153 | + pixdim = kwargs[self.ref_modality].meta["pixdim"] |
| 154 | + else: |
| 155 | + pixdim = np.abs(np.array(target_affine_4x4[:3, :3].diagonal().tolist())) |
| 156 | + else: |
| 157 | + target_affine_4x4 = define_affine_from_meta(data.meta) |
| 158 | + spatial_size = data.shape[1:4] |
| 159 | + if "pixdim" in data.meta: |
| 160 | + pixdim = data.meta["pixdim"] |
| 161 | + else: |
| 162 | + pixdim = np.abs(np.array(target_affine_4x4[:3, :3].diagonal().tolist())) |
| 163 | + |
| 164 | + for key in kwargs.keys(): |
| 165 | + if isinstance(kwargs[key], MetaTensor): |
| 166 | + source_affine_4x4 = define_affine_from_meta(kwargs[key].meta) |
| 167 | + kwargs[key].meta["affine"] = torch.Tensor(source_affine_4x4) |
| 168 | + kwargs[key].meta["pixdim"] = pixdim |
| 169 | + self._logger.info(f"Resampling {key} from {source_affine_4x4} to {target_affine_4x4}") |
| 170 | + |
| 171 | + multimodal_data[key] = SpatialResample(mode="bilinear")(kwargs[key], dst_affine=target_affine_4x4, |
| 172 | + spatial_size=spatial_size, |
| 173 | + ) |
| 174 | + source_affine_4x4 = define_affine_from_meta(data.meta) |
| 175 | + data.meta["affine"] = torch.Tensor(source_affine_4x4) |
| 176 | + data.meta["pixdim"] = pixdim |
| 177 | + multimodal_data["image"] = SpatialResample(mode="bilinear")( |
| 178 | + data, dst_affine=target_affine_4x4, spatial_size=spatial_size |
| 179 | + ) |
| 180 | + |
| 181 | + self._logger.info(f"Resampling 'image' from from {source_affine_4x4} to {target_affine_4x4}") |
| 182 | + data = ConcatItemsd(keys=list(multimodal_data.keys()),name="image")(multimodal_data)["image"] |
| 183 | + data.meta["pixdim"] = np.insert(pixdim, 0, 0) |
| 184 | + |
| 185 | + if len(data.shape) == 4: |
| 186 | + data = data[None] |
| 187 | + return self._nnunet_predictor(data) |
0 commit comments