187 changes: 187 additions & 0 deletions devel/monet_bundle_inference_operator.py
@@ -0,0 +1,187 @@
# Copyright 2002 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Dict, Tuple, Union
import logging
from monai.deploy.core import Image
from monai.deploy.operators.monai_bundle_inference_operator import MonaiBundleInferenceOperator, get_bundle_config
from monai.deploy.utils.importutil import optional_import
from monai.transforms import SpatialResample, ConcatItemsd
import numpy as np
MONAI_UTILS = "monai.utils"
nibabel, _ = optional_import("nibabel", "3.2.1")
torch, _ = optional_import("torch", "1.10.2")

NdarrayOrTensor, _ = optional_import("monai.config", name="NdarrayOrTensor")
MetaTensor, _ = optional_import("monai.data.meta_tensor", name="MetaTensor")
PostFix, _ = optional_import("monai.utils.enums", name="PostFix") # For the default meta_key_postfix
first, _ = optional_import("monai.utils.misc", name="first")
ensure_tuple, _ = optional_import(MONAI_UTILS, name="ensure_tuple")
convert_to_dst_type, _ = optional_import(MONAI_UTILS, name="convert_to_dst_type")
Key, _ = optional_import(MONAI_UTILS, name="ImageMetaKey")
MetaKeys, _ = optional_import(MONAI_UTILS, name="MetaKeys")
SpaceKeys, _ = optional_import(MONAI_UTILS, name="SpaceKeys")
Compose_, _ = optional_import("monai.transforms", name="Compose")
ConfigParser_, _ = optional_import("monai.bundle", name="ConfigParser")
MapTransform_, _ = optional_import("monai.transforms", name="MapTransform")
SimpleInferer, _ = optional_import("monai.inferers", name="SimpleInferer")

Compose: Any = Compose_
MapTransform: Any = MapTransform_
ConfigParser: Any = ConfigParser_
__all__ = ["MONetBundleInferenceOperator"]


def define_affine_from_meta(meta: Dict[str, Any]) -> np.ndarray:
"""
Define an affine matrix from the metadata of a tensor.

Parameters
----------
meta : Dict[str, Any]
Metadata dictionary containing 'pixdim', 'origin', and 'direction'.

Returns
-------
np.ndarray
A 4x4 affine matrix constructed from the metadata.
"""
if "pixdim" not in meta or "origin" not in meta or "direction" not in meta:
return meta.get("affine", np.eye(4))
pixdim = meta["pixdim"]
origin = meta["origin"]
direction = meta["direction"].reshape(3, 3)

# Extract 3D spacing
spacing = pixdim[1:4] # drop the first element (usually 1 for time dim)

# Scale the direction vectors by spacing to get rotation+scale part
affine = direction * spacing[np.newaxis, :]

# Append origin to get 3x4 affine matrix
affine = np.column_stack((affine, origin))

# Make it a full 4x4 affine
affine_4x4 = np.vstack((affine, [0, 0, 0, 1]))

Check warning on line 73 in devel/monet_bundle_inference_operator.py (SonarQubeCloud / SonarCloud Code Analysis): Remove the unused local variable "affine_4x4". See more on https://sonarcloud.io/project/issues?id=Project-MONAI_monai-deploy-app-sdk&issues=AZsSKoCqsxC2bMcg5aOJ&open=AZsSKoCqsxC2bMcg5aOJ&pullRequest=574
pixdim = meta["pixdim"]
origin = meta["origin"]
direction = meta["direction"].reshape(3, 3)

# Extract 3D spacing
spacing = pixdim[1:4] # drop the first element (usually 1 for time dim)

# Scale the direction vectors by spacing to get rotation+scale part
affine = direction * spacing[np.newaxis, :]

# Append origin to get 3x4 affine matrix
affine = np.column_stack((affine, origin))

# Make it a full 4x4 affine
return torch.Tensor(np.vstack((affine, [0, 0, 0, 1])))
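A minimal sketch of the arithmetic above, not part of the diff and using made-up metadata values: each column of the direction matrix is scaled by its voxel spacing, and the origin becomes the translation column of the homogeneous affine.

import numpy as np

# Made-up metadata: 1 x 1 x 2.5 mm spacing, identity orientation, origin at (-120, -120, 60).
pixdim = np.array([1.0, 1.0, 1.0, 2.5])
origin = np.array([-120.0, -120.0, 60.0])
direction = np.eye(3)

spacing = pixdim[1:4]                          # skip pixdim[0], as in the code above
affine = direction * spacing[np.newaxis, :]    # scale each direction column by its spacing
affine = np.column_stack((affine, origin))     # 3x4: rotation/scale plus translation column
affine = np.vstack((affine, [0, 0, 0, 1]))     # homogeneous 4x4
# Result: diag(1.0, 1.0, 2.5) in the upper-left 3x3 block, translation column (-120, -120, 60).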

class MONetBundleInferenceOperator(MonaiBundleInferenceOperator):
"""
A specialized operator for performing inference using the MONet bundle.
This operator extends the `MonaiBundleInferenceOperator` to support nnUNet-specific
configurations and prediction logic. It initializes the nnUNet predictor and provides
a method for performing inference on input data.

Attributes
----------
_nnunet_predictor : torch.nn.Module
The nnUNet predictor module used for inference.

Methods
-------
_init_config(config_names)
Initializes the configuration for the nnUNet bundle, including parsing the bundle
configuration and setting up the nnUNet predictor.
predict(data, *args, **kwargs)
Performs inference on the input data using the nnUNet predictor.
"""

def __init__(
self,
*args,
**kwargs,
):

super().__init__(*args, **kwargs)

self._logger = logging.getLogger("{}.{}".format(__name__, type(self).__name__))
self._nnunet_predictor: torch.nn.Module = None
self.ref_modality = None
if "ref_modality" in kwargs:
self.ref_modality = kwargs["ref_modality"]

def _init_config(self, config_names):

super()._init_config(config_names)
parser = get_bundle_config(str(self._bundle_path), config_names)
self._parser = parser

self._nnunet_predictor = parser.get_parsed_content("network_def")

def predict(self, data: Any, *args, **kwargs) -> Union[Image, Any, Tuple[Any, ...], Dict[Any, Any]]:

Check failure on line 133 in devel/monet_bundle_inference_operator.py (SonarQubeCloud / SonarCloud Code Analysis): Refactor this function to reduce its Cognitive Complexity from 28 to the 15 allowed. See more on https://sonarcloud.io/project/issues?id=Project-MONAI_monai-deploy-app-sdk&issues=AZsSKoCqsxC2bMcg5aOK&open=AZsSKoCqsxC2bMcg5aOK&pullRequest=574
"""Predicts output using the inferer."""

self._nnunet_predictor.predictor.network = self._model_network
# os.environ['nnUNet_def_n_proc'] = "1"

Check warning on line 137 in devel/monet_bundle_inference_operator.py (SonarQubeCloud / SonarCloud Code Analysis): Remove this commented out code. See more on https://sonarcloud.io/project/issues?id=Project-MONAI_monai-deploy-app-sdk&issues=AZsSKoCqsxC2bMcg5aOL&open=AZsSKoCqsxC2bMcg5aOL&pullRequest=574

if len(kwargs) > 0:
multimodal_data = {"image": data}
if self.ref_modality is not None:
if self.ref_modality not in kwargs:
target_affine_4x4 = define_affine_from_meta(data.meta)
spatial_size = data.shape[1:4]
if "pixdim" in data.meta:
pixdim = data.meta["pixdim"]
else:
pixdim = np.abs(np.array(target_affine_4x4[:3, :3].diagonal().tolist()))
else:
target_affine_4x4 = define_affine_from_meta(kwargs[self.ref_modality].meta)
spatial_size = kwargs[self.ref_modality].shape[1:4]
if "pixdim" in kwargs[self.ref_modality].meta:
pixdim = kwargs[self.ref_modality].meta["pixdim"]
else:
pixdim = np.abs(np.array(target_affine_4x4[:3, :3].diagonal().tolist()))
else:
target_affine_4x4 = define_affine_from_meta(data.meta)
spatial_size = data.shape[1:4]
if "pixdim" in data.meta:
pixdim = data.meta["pixdim"]
else:
pixdim = np.abs(np.array(target_affine_4x4[:3, :3].diagonal().tolist()))

for key in kwargs.keys():
if isinstance(kwargs[key], MetaTensor):
source_affine_4x4 = define_affine_from_meta(kwargs[key].meta)
kwargs[key].meta["affine"] = torch.Tensor(source_affine_4x4)
kwargs[key].meta["pixdim"] = pixdim
self._logger.info(f"Resampling {key} from {source_affine_4x4} to {target_affine_4x4}")

multimodal_data[key] = SpatialResample(mode="bilinear")(kwargs[key], dst_affine=target_affine_4x4, spatial_size=spatial_size)
source_affine_4x4 = define_affine_from_meta(data.meta)
data.meta["affine"] = torch.Tensor(source_affine_4x4)
data.meta["pixdim"] = pixdim
multimodal_data["image"] = SpatialResample(mode="bilinear")(
data, dst_affine=target_affine_4x4, spatial_size=spatial_size
)

self._logger.info(f"Resampling 'image' from from {source_affine_4x4} to {target_affine_4x4}")
data = ConcatItemsd(keys=list(multimodal_data.keys()),name="image")(multimodal_data)["image"]
data.meta["pixdim"] = np.insert(pixdim, 0, 0)

if len(data.shape) == 4:
data = data[None]
return self._nnunet_predictor(data)
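Below is a self-contained sketch, not part of the diff, of the resample-and-concatenate step that this predict performs for an extra modality passed as a keyword argument; the tensors, shapes, and affines are invented for illustration.

import torch
from monai.data import MetaTensor
from monai.transforms import ConcatItemsd, SpatialResample

# Reference image: 1 channel, 96 x 96 x 64 voxels, 1 mm isotropic (illustrative values only).
image = MetaTensor(torch.rand(1, 96, 96, 64), affine=torch.eye(4))
# Extra modality on a coarser 2 mm grid, to be resampled onto the reference grid.
pet = MetaTensor(torch.rand(1, 48, 48, 32), affine=torch.diag(torch.tensor([2.0, 2.0, 2.0, 1.0])))

# Resample the extra modality to the reference affine and spatial size, then stack channels.
pet_resampled = SpatialResample(mode="bilinear")(pet, dst_affine=image.affine, spatial_size=image.shape[1:])
stacked = ConcatItemsd(keys=["image", "pet"], name="image")({"image": image, "pet": pet_resampled})["image"]
print(stacked.shape)  # torch.Size([2, 96, 96, 64])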
3 changes: 3 additions & 0 deletions monai/deploy/operators/__init__.py
@@ -25,6 +25,7 @@
IOMapping
ModelInfo
MonaiBundleInferenceOperator
MONetBundleInferenceOperator
MonaiSegInferenceOperator
PNGConverterOperator
PublisherOperator
@@ -49,6 +50,7 @@
from .inference_operator import InferenceOperator
from .monai_bundle_inference_operator import BundleConfigNames, IOMapping, MonaiBundleInferenceOperator
from .monai_seg_inference_operator import MonaiSegInferenceOperator
from .monet_bundle_inference_operator import MONetBundleInferenceOperator
from .nii_data_loader_operator import NiftiDataLoader
from .png_converter_operator import PNGConverterOperator
from .publisher_operator import PublisherOperator
@@ -69,6 +71,7 @@
"IOMapping",
"ModelInfo",
"MonaiBundleInferenceOperator",
"MONetBundleInferenceOperator",
"MonaiSegInferenceOperator",
"NiftiDataLoader",
"PNGConverterOperator",
5 changes: 3 additions & 2 deletions monai/deploy/operators/monai_bundle_inference_operator.py
@@ -186,7 +186,7 @@ def get_bundle_config(bundle_path, config_names):
Gets the configuration parser from the specified Torchscript bundle file path.
"""

bundle_suffixes = (".json", ".yaml", "yml") # The only supported file ext(s)
bundle_suffixes = (".json", ".yaml", ".yml") # The only supported file ext(s)
config_folder = "extra"

def _read_from_archive(archive, root_name: str, config_name: str, do_search=True):
@@ -216,7 +216,7 @@ def _read_from_archive(archive, root_name: str, config_name: str, do_search=True
name_list = archive.namelist()
for suffix in bundle_suffixes:
for n in name_list:
if (f"{config_name}{suffix}").casefold in n.casefold():
if (f"{config_name}{suffix}").casefold() in n.casefold():
logging.debug(f"Trying to read content of config {config_name!r} from {n!r}.")
content_text = archive.read(n)
break
@@ -745,6 +745,7 @@ def compute(self, op_input, op_output, context):
# value: NdarrayOrTensor # MyPy complaints
value, meta_data = self._receive_input(name, op_input, context)
value = convert_to_dst_type(value, dst=value)[0]
meta_data = meta_data or {}
if not isinstance(meta_data, dict):
raise ValueError("`meta_data` must be a dict.")
value = MetaTensor.ensure_torch_and_prune_meta(value, meta_data)
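The casefold() fix above is worth a one-line illustration (plain Python, not part of the diff): without the call parentheses the left operand of `in` is a bound method rather than a string, so the old membership test could not match case-insensitively.

name = "extra/inference.JSON"
print("inference.json".casefold() in name.casefold())  # True: case-insensitive substring match
print(type("inference.json".casefold))                 # <class 'builtin_function_or_method'>, not a str
# The old form '"...".casefold in name.casefold()' therefore raises TypeError instead of matching.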
91 changes: 91 additions & 0 deletions monai/deploy/operators/monet_bundle_inference_operator.py
@@ -0,0 +1,91 @@
# Copyright 2002 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Dict, Tuple, Union

from monai.deploy.core import Image
from monai.deploy.operators.monai_bundle_inference_operator import MonaiBundleInferenceOperator, get_bundle_config
from monai.deploy.utils.importutil import optional_import
from monai.transforms import ConcatItemsd, ResampleToMatch
from monai.deploy.core.models.torch_model import TorchScriptModel
torch, _ = optional_import("torch", "1.10.2")
MetaTensor, _ = optional_import("monai.data.meta_tensor", name="MetaTensor")
__all__ = ["MONetBundleInferenceOperator"]


class MONetBundleInferenceOperator(MonaiBundleInferenceOperator):
"""
A specialized operator for performing inference using the MONet bundle.
This operator extends the `MonaiBundleInferenceOperator` to support nnUNet-specific
configurations and prediction logic. It initializes the nnUNet predictor and provides
a method for performing inference on input data.

Attributes
----------
_nnunet_predictor : torch.nn.Module
The nnUNet predictor module used for inference.

Methods
-------
_init_config(config_names)
Initializes the configuration for the nnUNet bundle, including parsing the bundle
configuration and setting up the nnUNet predictor.
predict(data, *args, **kwargs)
Performs inference on the input data using the nnUNet predictor.
"""

def __init__(
self,
*args,
**kwargs,
):

super().__init__(*args, **kwargs)

self._nnunet_predictor: torch.nn.Module = None

def _init_config(self, config_names):

super()._init_config(config_names)
parser = get_bundle_config(str(self._bundle_path), config_names)
self._parser = parser

self._nnunet_predictor = parser.get_parsed_content("network_def")

def _set_model_network(self, model_network):
"""
Sets the model network for the nnUNet predictor.

Parameters
----------
model_network : torch.nn.Module or torch.jit.ScriptModule
The model network to be used for inference.
"""
if not isinstance(model_network, (torch.nn.Module, torch.jit.ScriptModule, TorchScriptModel)):
raise TypeError("model_network must be an instance of torch.nn.Module, torch.jit.ScriptModule, or TorchScriptModel")
self._nnunet_predictor.predictor.network = model_network

def predict(self, data: Any, *args, **kwargs) -> Union[Image, Any, Tuple[Any, ...], Dict[Any, Any]]:
"""Predicts output using the inferer. If multimodal data is provided as keyword arguments,
it concatenates the data with the main input data."""

self._set_model_network(self._model_network)

if len(kwargs) > 0:
multimodal_data = {"image": data}
for key in kwargs.keys():
if isinstance(kwargs[key], MetaTensor):
multimodal_data[key] = ResampleToMatch(mode="bilinear")(kwargs[key], img_dst=data)
data = ConcatItemsd(keys=list(multimodal_data.keys()), name="image")(multimodal_data)["image"]
if len(data.shape) == 4:
data = data[None]
return self._nnunet_predictor(data)
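Finally, a sketch of the multimodal branch of this predict, not part of the diff and using made-up tensors: the extra modality is resampled onto the main image grid with ResampleToMatch, channel-concatenated with ConcatItemsd, and given a batch dimension before reaching the predictor.

import torch
from monai.data import MetaTensor
from monai.transforms import ConcatItemsd, ResampleToMatch

# Illustrative inputs: main image on a 1 mm grid, extra modality on a coarser 2 mm grid.
image = MetaTensor(torch.rand(1, 96, 96, 64), affine=torch.eye(4))
pet = MetaTensor(torch.rand(1, 48, 48, 32), affine=torch.diag(torch.tensor([2.0, 2.0, 2.0, 1.0])))

pet_on_image_grid = ResampleToMatch(mode="bilinear")(pet, img_dst=image)   # match the main input grid
stacked = ConcatItemsd(keys=["image", "pet"], name="image")({"image": image, "pet": pet_on_image_grid})["image"]
batch = stacked[None] if len(stacked.shape) == 4 else stacked              # add batch dim, as predict does
print(batch.shape)  # torch.Size([1, 2, 96, 96, 64])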