Skip to content

Commit d94df3f

Browse files
Nic-Mapre-commit-ci[bot]KumoLiuericspod
authored
8134 Add unit test for responsive inference (#8146)
Fixes #8134 . ### Description This PR added unit test to cover the realtime inference with bundles. And updated `BundleWorkflow` to support cyclically calling the `run` function with all components instantiated. ### Types of changes <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: Nic Ma <nma@nvidia.com> Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> Co-authored-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com>
1 parent b1e915c commit d94df3f

File tree

4 files changed

+162
-3
lines changed

4 files changed

+162
-3
lines changed

monai/bundle/reference_resolver.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -192,6 +192,16 @@ def get_resolved_content(self, id: str, **kwargs: Any) -> ConfigExpression | str
192192
"""
193193
return self._resolve_one_item(id=id, **kwargs)
194194

195+
def remove_resolved_content(self, id: str) -> Any | None:
196+
"""
197+
Remove the resolved ``ConfigItem`` by id.
198+
199+
Args:
200+
id: id name of the expected item.
201+
202+
"""
203+
return self.resolved_content.pop(id) if id in self.resolved_content else None
204+
195205
@classmethod
196206
def normalize_id(cls, id: str | int) -> str:
197207
"""

monai/bundle/workflows.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -394,8 +394,23 @@ def check_properties(self) -> list[str] | None:
394394
ret.extend(wrong_props)
395395
return ret
396396

397-
def _run_expr(self, id: str, **kwargs: dict) -> Any:
398-
return self.parser.get_parsed_content(id, **kwargs) if id in self.parser else None
397+
def _run_expr(self, id: str, **kwargs: dict) -> list[Any]:
398+
"""
399+
Evaluate the expression or expression list given by `id`. The resolved values from the evaluations are not stored,
400+
allowing this to be evaluated repeatedly (eg. in streaming applications) without restarting the hosting process.
401+
"""
402+
ret = []
403+
if id in self.parser:
404+
# suppose all the expressions are in a list, run and reset the expressions
405+
if isinstance(self.parser[id], list):
406+
for i in range(len(self.parser[id])):
407+
sub_id = f"{id}{ID_SEP_KEY}{i}"
408+
ret.append(self.parser.get_parsed_content(sub_id, **kwargs))
409+
self.parser.ref_resolver.remove_resolved_content(sub_id)
410+
else:
411+
ret.append(self.parser.get_parsed_content(id, **kwargs))
412+
self.parser.ref_resolver.remove_resolved_content(id)
413+
return ret
399414

400415
def _get_prop_id(self, name: str, property: dict) -> Any:
401416
prop_id = property[BundlePropertyConfig.ID]

tests/test_bundle_workflow.py

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
from monai.data import Dataset
2727
from monai.inferers import SimpleInferer, SlidingWindowInferer
2828
from monai.networks.nets import UNet
29-
from monai.transforms import Compose, LoadImage
29+
from monai.transforms import Compose, LoadImage, LoadImaged, SaveImaged
3030
from tests.nonconfig_workflow import NonConfigWorkflow
3131

3232
TEST_CASE_1 = [os.path.join(os.path.dirname(__file__), "testing_data", "inference.json")]
@@ -35,6 +35,8 @@
3535

3636
TEST_CASE_3 = [os.path.join(os.path.dirname(__file__), "testing_data", "config_fl_train.json")]
3737

38+
TEST_CASE_4 = [os.path.join(os.path.dirname(__file__), "testing_data", "responsive_inference.json")]
39+
3840
TEST_CASE_NON_CONFIG_WRONG_LOG = [None, "logging.conf", "Cannot find the logging config file: logging.conf."]
3941

4042

@@ -45,7 +47,9 @@ def setUp(self):
4547
self.expected_shape = (128, 128, 128)
4648
test_image = np.random.rand(*self.expected_shape)
4749
self.filename = os.path.join(self.data_dir, "image.nii")
50+
self.filename1 = os.path.join(self.data_dir, "image1.nii")
4851
nib.save(nib.Nifti1Image(test_image, np.eye(4)), self.filename)
52+
nib.save(nib.Nifti1Image(test_image, np.eye(4)), self.filename1)
4953

5054
def tearDown(self):
5155
shutil.rmtree(self.data_dir)
@@ -115,6 +119,35 @@ def test_inference_config(self, config_file):
115119
self._test_inferer(inferer)
116120
self.assertEqual(inferer.workflow_type, None)
117121

122+
@parameterized.expand([TEST_CASE_4])
123+
def test_responsive_inference_config(self, config_file):
124+
input_loader = LoadImaged(keys="image")
125+
output_saver = SaveImaged(keys="pred", output_dir=self.data_dir, output_postfix="seg")
126+
127+
# test standard MONAI model-zoo config workflow
128+
inferer = ConfigWorkflow(
129+
workflow_type="infer",
130+
config_file=config_file,
131+
logging_file=os.path.join(os.path.dirname(__file__), "testing_data", "logging.conf"),
132+
)
133+
# FIXME: temp add the property for test, we should add it to some formal realtime infer properties
134+
inferer.add_property(name="dataflow", required=True, config_id="dataflow")
135+
136+
inferer.initialize()
137+
inferer.dataflow.update(input_loader({"image": self.filename}))
138+
inferer.run()
139+
output_saver(inferer.dataflow)
140+
self.assertTrue(os.path.exists(os.path.join(self.data_dir, "image", "image_seg.nii.gz")))
141+
142+
# bundle is instantiated and idle, just change the input for next inference
143+
inferer.dataflow.clear()
144+
inferer.dataflow.update(input_loader({"image": self.filename1}))
145+
inferer.run()
146+
output_saver(inferer.dataflow)
147+
self.assertTrue(os.path.exists(os.path.join(self.data_dir, "image1", "image1_seg.nii.gz")))
148+
149+
inferer.finalize()
150+
118151
@parameterized.expand([TEST_CASE_3])
119152
def test_train_config(self, config_file):
120153
# test standard MONAI model-zoo config workflow
Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
{
2+
"imports": [
3+
"$from collections import defaultdict"
4+
],
5+
"bundle_root": "will override",
6+
"device": "$torch.device('cpu')",
7+
"network_def": {
8+
"_target_": "UNet",
9+
"spatial_dims": 3,
10+
"in_channels": 1,
11+
"out_channels": 2,
12+
"channels": [
13+
2,
14+
2,
15+
4,
16+
8,
17+
4
18+
],
19+
"strides": [
20+
2,
21+
2,
22+
2,
23+
2
24+
],
25+
"num_res_units": 2,
26+
"norm": "batch"
27+
},
28+
"network": "$@network_def.to(@device)",
29+
"dataflow": "$defaultdict()",
30+
"preprocessing": {
31+
"_target_": "Compose",
32+
"transforms": [
33+
{
34+
"_target_": "EnsureChannelFirstd",
35+
"keys": "image"
36+
},
37+
{
38+
"_target_": "ScaleIntensityd",
39+
"keys": "image"
40+
},
41+
{
42+
"_target_": "RandRotated",
43+
"_disabled_": true,
44+
"keys": "image"
45+
}
46+
]
47+
},
48+
"dataset": {
49+
"_target_": "Dataset",
50+
"data": [
51+
"@dataflow"
52+
],
53+
"transform": "@preprocessing"
54+
},
55+
"dataloader": {
56+
"_target_": "DataLoader",
57+
"dataset": "@dataset",
58+
"batch_size": 1,
59+
"shuffle": false,
60+
"num_workers": 0
61+
},
62+
"inferer": {
63+
"_target_": "SlidingWindowInferer",
64+
"roi_size": [
65+
64,
66+
64,
67+
32
68+
],
69+
"sw_batch_size": 4,
70+
"overlap": 0.25
71+
},
72+
"postprocessing": {
73+
"_target_": "Compose",
74+
"transforms": [
75+
{
76+
"_target_": "Activationsd",
77+
"keys": "pred",
78+
"softmax": true
79+
},
80+
{
81+
"_target_": "AsDiscreted",
82+
"keys": "pred",
83+
"argmax": true
84+
}
85+
]
86+
},
87+
"evaluator": {
88+
"_target_": "SupervisedEvaluator",
89+
"device": "@device",
90+
"val_data_loader": "@dataloader",
91+
"network": "@network",
92+
"inferer": "@inferer",
93+
"postprocessing": "@postprocessing",
94+
"amp": false,
95+
"epoch_length": 1
96+
},
97+
"run": [
98+
"$@evaluator.run()",
99+
"$@dataflow.update(@evaluator.state.output[0])"
100+
]
101+
}

0 commit comments

Comments
 (0)