Skip to content

Commit 4c4fbb5

Browse files
Add MONetBundleInferenceOperator implementation
- Introduced a new operator, MONetBundleInferenceOperator, for performing inference using the MONet bundle. - Extended functionality from MonaiBundleInferenceOperator to support nnUNet-specific configurations. - Implemented methods for initializing configurations and performing predictions with multimodal data handling. This addition enhances the inference capabilities within the MONAI framework.
1 parent a96b680 commit 4c4fbb5

File tree

1 file changed

+91
-0
lines changed

1 file changed

+91
-0
lines changed
Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
# Copyright 2002 MONAI Consortium
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
# Unless required by applicable law or agreed to in writing, software
7+
# distributed under the License is distributed on an "AS IS" BASIS,
8+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
# See the License for the specific language governing permissions and
10+
# limitations under the License.
11+
12+
from typing import Any, Dict, Tuple, Union
13+
14+
from monai.deploy.core import Image
15+
from monai.deploy.operators.monai_bundle_inference_operator import MonaiBundleInferenceOperator, get_bundle_config
16+
from monai.deploy.utils.importutil import optional_import
17+
from monai.transforms import ConcatItemsd, ResampleToMatch
18+
from monai.deploy.core.models.torch_model import TorchScriptModel
19+
torch, _ = optional_import("torch", "1.10.2")
20+
MetaTensor, _ = optional_import("monai.data.meta_tensor", name="MetaTensor")
21+
__all__ = ["MONetBundleInferenceOperator"]
22+
23+
24+
class MONetBundleInferenceOperator(MonaiBundleInferenceOperator):
25+
"""
26+
A specialized operator for performing inference using the MONet bundle.
27+
This operator extends the `MonaiBundleInferenceOperator` to support nnUNet-specific
28+
configurations and prediction logic. It initializes the nnUNet predictor and provides
29+
a method for performing inference on input data.
30+
31+
Attributes
32+
----------
33+
_nnunet_predictor : torch.nn.Module
34+
The nnUNet predictor module used for inference.
35+
36+
Methods
37+
-------
38+
_init_config(config_names)
39+
Initializes the configuration for the nnUNet bundle, including parsing the bundle
40+
configuration and setting up the nnUNet predictor.
41+
predict(data, *args, **kwargs)
42+
Performs inference on the input data using the nnUNet predictor.
43+
"""
44+
45+
def __init__(
46+
self,
47+
*args,
48+
**kwargs,
49+
):
50+
51+
super().__init__(*args, **kwargs)
52+
53+
self._nnunet_predictor: torch.nn.Module = None
54+
55+
def _init_config(self, config_names):
56+
57+
super()._init_config(config_names)
58+
parser = get_bundle_config(str(self._bundle_path), config_names)
59+
self._parser = parser
60+
61+
self._nnunet_predictor = parser.get_parsed_content("network_def")
62+
63+
def _set_model_network(self, model_network):
64+
"""
65+
Sets the model network for the nnUNet predictor.
66+
67+
Parameters
68+
----------
69+
model_network : torch.nn.Module or torch.jit.ScriptModule
70+
The model network to be used for inference.
71+
"""
72+
if not isinstance(model_network, torch.nn.Module) and not torch.jit.isinstance(model_network, torch.jit.ScriptModule) and not isinstance(model_network, TorchScriptModel):
73+
raise TypeError("model_network must be an instance of torch.nn.Module or torch.jit.ScriptModule")
74+
self._nnunet_predictor.predictor.network = model_network
75+
76+
def predict(self, data: Any, *args, **kwargs) -> Union[Image, Any, Tuple[Any, ...], Dict[Any, Any]]:
77+
"""Predicts output using the inferer. If multimodal data is provided as keyword arguments,
78+
it concatenates the data with the main input data."""
79+
80+
self._set_model_network(self._model_network)
81+
82+
if len(kwargs) > 0:
83+
multimodal_data = {"image": data}
84+
for key in kwargs.keys():
85+
if isinstance(kwargs[key], MetaTensor):
86+
multimodal_data[key] = ResampleToMatch(mode="bilinear")(kwargs[key], img_dst=data
87+
)
88+
data = ConcatItemsd(keys=list(multimodal_data.keys()),name="image")(multimodal_data)["image"]
89+
if len(data.shape) == 4:
90+
data = data[None]
91+
return self._nnunet_predictor(data)

0 commit comments

Comments
 (0)