Skip to content

Commit c9f8d32

Browse files
borisfompre-commit-ci[bot]KumoLiuyiheng-wang-nvbinliunls
authored
Added TRTWrapper (#7990)
### Description Added alternative class to ONNX->TRT export and wrap TRT engines for inference. It encapsulates filesystem persistence and does not rely on torch-tensortrt for execution. Also can be used to run ONNX with onnxruntime. ### Types of changes <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: Boris Fomitchev <bfomitchev@nvidia.com> Signed-off-by: Yiheng Wang <vennw@nvidia.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> Co-authored-by: Yiheng Wang <68361391+yiheng-wang-nv@users.noreply.github.com> Co-authored-by: Yiheng Wang <vennw@nvidia.com> Co-authored-by: binliunls <107988372+binliunls@users.noreply.github.com>
1 parent fa1ef8b commit c9f8d32

17 files changed

+1121
-51
lines changed

Dockerfile

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,4 +56,5 @@ RUN apt-get update \
5656
&& rm -rf /var/lib/apt/lists/*
5757
# append /opt/tools to runtime path for NGC CLI to be accessible from all file system locations
5858
ENV PATH=${PATH}:/opt/tools
59+
ENV POLYGRAPHY_AUTOINSTALL_DEPS=1
5960
WORKDIR /opt/monai

docs/source/config_syntax.md

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ Content:
1616
- [`$` to evaluate as Python expressions](#to-evaluate-as-python-expressions)
1717
- [`%` to textually replace configuration elements](#to-textually-replace-configuration-elements)
1818
- [`_target_` (`_disabled_`, `_desc_`, `_requires_`, `_mode_`) to instantiate a Python object](#instantiate-a-python-object)
19+
- [`+` to alter semantics of merging config keys from multiple configuration files](#multiple-config-files)
1920
- [The command line interface](#the-command-line-interface)
2021
- [Recommendations](#recommendations)
2122

@@ -175,6 +176,47 @@ _Description:_ `_requires_`, `_disabled_`, `_desc_`, and `_mode_` are optional k
175176
- `"debug"` -- execute with debug prompt and return the return value of ``pdb.runcall(_target_, **kwargs)``,
176177
see also [`pdb.runcall`](https://docs.python.org/3/library/pdb.html#pdb.runcall).
177178

179+
## Multiple config files
180+
181+
_Description:_ Multiple config files may be specified on the command line.
182+
The content of those config files is being merged. When same keys are specifiled in more than one config file,
183+
the value associated with the key is being overridden, in the order config files are specified.
184+
If the desired behaviour is to merge values from both files, the key in second config file should be prefixed with `+`.
185+
The value types for the merged contents must match and be both of `dict` or both of `list` type.
186+
`dict` values will be merged via update(), `list` values - concatenated via extend().
187+
Here's an example. In this case, "amp" value will be overridden by extra_config.json.
188+
`imports` and `preprocessing#transforms` lists will be merged. An error would be thrown if the value type in `"+imports"` is not `list`:
189+
190+
config.json:
191+
```json
192+
{
193+
"amp": "$True"
194+
"imports": [
195+
"$import torch"
196+
],
197+
"preprocessing": {
198+
"_target_": "Compose",
199+
"transforms": [
200+
"$@t1",
201+
"$@t2"
202+
]
203+
},
204+
}
205+
```
206+
207+
extra_config.json:
208+
```json
209+
{
210+
"amp": "$False"
211+
"+imports": [
212+
"$from monai.networks import trt_compile"
213+
],
214+
"+preprocessing#transforms": [
215+
"$@t3"
216+
]
217+
}
218+
```
219+
178220
## The command line interface
179221

180222
In addition to the Pythonic APIs, a few command line interfaces (CLI) are provided to interact with the bundle.

monai/bundle/config_parser.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020

2121
from monai.bundle.config_item import ComponentLocator, ConfigComponent, ConfigExpression, ConfigItem
2222
from monai.bundle.reference_resolver import ReferenceResolver
23-
from monai.bundle.utils import ID_REF_KEY, ID_SEP_KEY, MACRO_KEY
23+
from monai.bundle.utils import ID_REF_KEY, ID_SEP_KEY, MACRO_KEY, merge_kv
2424
from monai.config import PathLike
2525
from monai.utils import ensure_tuple, look_up_option, optional_import
2626
from monai.utils.misc import CheckKeyDuplicatesYamlLoader, check_key_duplicates
@@ -423,8 +423,10 @@ def load_config_files(cls, files: PathLike | Sequence[PathLike] | dict, **kwargs
423423
if isinstance(files, str) and not Path(files).is_file() and "," in files:
424424
files = files.split(",")
425425
for i in ensure_tuple(files):
426-
for k, v in (cls.load_config_file(i, **kwargs)).items():
427-
parser[k] = v
426+
config_dict = cls.load_config_file(i, **kwargs)
427+
for k, v in config_dict.items():
428+
merge_kv(parser, k, v)
429+
428430
return parser.get() # type: ignore
429431

430432
@classmethod

monai/bundle/scripts.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
from monai.apps.utils import _basename, download_url, extractall, get_logger
3333
from monai.bundle.config_item import ConfigComponent
3434
from monai.bundle.config_parser import ConfigParser
35-
from monai.bundle.utils import DEFAULT_INFERENCE, DEFAULT_METADATA
35+
from monai.bundle.utils import DEFAULT_INFERENCE, DEFAULT_METADATA, merge_kv
3636
from monai.bundle.workflows import BundleWorkflow, ConfigWorkflow
3737
from monai.config import IgniteInfo, PathLike
3838
from monai.data import load_net_with_metadata, save_net_with_metadata
@@ -105,7 +105,7 @@ def update_kwargs(args: str | dict | None = None, ignore_none: bool = True, **kw
105105
if isinstance(v, dict) and isinstance(args_.get(k), dict):
106106
args_[k] = update_kwargs(args_[k], ignore_none, **v)
107107
else:
108-
args_[k] = v
108+
merge_kv(args_, k, v)
109109
return args_
110110

111111

monai/bundle/utils.py

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313

1414
import json
1515
import os
16+
import warnings
1617
import zipfile
1718
from typing import Any
1819

@@ -21,12 +22,21 @@
2122

2223
yaml, _ = optional_import("yaml")
2324

24-
__all__ = ["ID_REF_KEY", "ID_SEP_KEY", "EXPR_KEY", "MACRO_KEY", "DEFAULT_MLFLOW_SETTINGS", "DEFAULT_EXP_MGMT_SETTINGS"]
25+
__all__ = [
26+
"ID_REF_KEY",
27+
"ID_SEP_KEY",
28+
"EXPR_KEY",
29+
"MACRO_KEY",
30+
"MERGE_KEY",
31+
"DEFAULT_MLFLOW_SETTINGS",
32+
"DEFAULT_EXP_MGMT_SETTINGS",
33+
]
2534

2635
ID_REF_KEY = "@" # start of a reference to a ConfigItem
2736
ID_SEP_KEY = "::" # separator for the ID of a ConfigItem
2837
EXPR_KEY = "$" # start of a ConfigExpression
2938
MACRO_KEY = "%" # start of a macro of a config
39+
MERGE_KEY = "+" # prefix indicating merge instead of override in case of multiple configs.
3040

3141
_conf_values = get_config_values()
3242

@@ -233,3 +243,27 @@ def load_bundle_config(bundle_path: str, *config_names: str, **load_kw_args: Any
233243
parser.read_config(f=cdata)
234244

235245
return parser
246+
247+
248+
def merge_kv(args: dict | Any, k: str, v: Any) -> None:
249+
"""
250+
Update the `args` dict-like object with the key/value pair `k` and `v`.
251+
"""
252+
if k.startswith(MERGE_KEY):
253+
"""
254+
Both values associated with `+`-prefixed key pair must be of `dict` or `list` type.
255+
`dict` values will be merged, `list` values - concatenated.
256+
"""
257+
id = k[1:]
258+
if id in args:
259+
if isinstance(v, dict) and isinstance(args[id], dict):
260+
args[id].update(v)
261+
elif isinstance(v, list) and isinstance(args[id], list):
262+
args[id].extend(v)
263+
else:
264+
raise ValueError(ValueError(f"config must be dict or list for key `{k}`, but got {type(v)}: {v}."))
265+
else:
266+
warnings.warn(f"Can't merge entry ['{k}'], '{id}' is not in target dict - copying instead.")
267+
args[id] = v
268+
else:
269+
args[k] = v

monai/handlers/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,5 +40,6 @@
4040
from .stats_handler import StatsHandler
4141
from .surface_distance import SurfaceDistance
4242
from .tensorboard_handlers import TensorBoardHandler, TensorBoardImageHandler, TensorBoardStatsHandler
43+
from .trt_handler import TrtHandler
4344
from .utils import from_engine, ignore_data, stopping_fn_from_loss, stopping_fn_from_metric, write_metrics_reports
4445
from .validation_handler import ValidationHandler

monai/handlers/trt_handler.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
# Copyright (c) MONAI Consortium
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
# Unless required by applicable law or agreed to in writing, software
7+
# distributed under the License is distributed on an "AS IS" BASIS,
8+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
# See the License for the specific language governing permissions and
10+
# limitations under the License.
11+
12+
from __future__ import annotations
13+
14+
from typing import TYPE_CHECKING
15+
16+
from monai.config import IgniteInfo
17+
from monai.networks import trt_compile
18+
from monai.utils import min_version, optional_import
19+
20+
Events, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Events")
21+
if TYPE_CHECKING:
22+
from ignite.engine import Engine
23+
else:
24+
Engine, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Engine")
25+
26+
27+
class TrtHandler:
28+
"""
29+
TrtHandler acts as an Ignite handler to apply TRT acceleration to the model.
30+
Usage example::
31+
handler = TrtHandler(model=model, base_path="/test/checkpoint.pt", args={"precision": "fp16"})
32+
handler.attach(engine)
33+
engine.run()
34+
"""
35+
36+
def __init__(self, model, base_path, args=None, submodule=None):
37+
"""
38+
Args:
39+
base_path: TRT path basename. TRT plan(s) saved to "base_path[.submodule].plan"
40+
args: passed to trt_compile(). See trt_compile() for details.
41+
submodule : Hierarchical ids of submodules to convert, e.g. 'image_decoder.decoder'
42+
"""
43+
self.model = model
44+
self.base_path = base_path
45+
self.args = args
46+
self.submodule = submodule
47+
48+
def attach(self, engine: Engine) -> None:
49+
"""
50+
Args:
51+
engine: Ignite Engine, it can be a trainer, validator or evaluator.
52+
"""
53+
self.logger = engine.logger
54+
engine.add_event_handler(Events.STARTED, self)
55+
56+
def __call__(self, engine: Engine) -> None:
57+
"""
58+
Args:
59+
engine: Ignite Engine, it can be a trainer, validator or evaluator.
60+
"""
61+
trt_compile(self.model, self.base_path, args=self.args, submodule=self.submodule, logger=self.logger)

monai/networks/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,9 @@
1111

1212
from __future__ import annotations
1313

14+
from .trt_compiler import trt_compile
1415
from .utils import (
16+
add_casts_around_norms,
1517
convert_to_onnx,
1618
convert_to_torchscript,
1719
convert_to_trt,

monai/networks/nets/swin_unetr.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -320,7 +320,7 @@ def _check_input_size(self, spatial_shape):
320320
)
321321

322322
def forward(self, x_in):
323-
if not torch.jit.is_scripting():
323+
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
324324
self._check_input_size(x_in.shape[2:])
325325
hidden_states_out = self.swinViT(x_in, self.normalize)
326326
enc0 = self.encoder1(x_in)
@@ -1046,14 +1046,14 @@ def __init__(
10461046

10471047
def proj_out(self, x, normalize=False):
10481048
if normalize:
1049-
x_shape = x.size()
1049+
x_shape = x.shape
1050+
# Force trace() to generate a constant by casting to int
1051+
ch = int(x_shape[1])
10501052
if len(x_shape) == 5:
1051-
n, ch, d, h, w = x_shape
10521053
x = rearrange(x, "n c d h w -> n d h w c")
10531054
x = F.layer_norm(x, [ch])
10541055
x = rearrange(x, "n d h w c -> n c d h w")
10551056
elif len(x_shape) == 4:
1056-
n, ch, h, w = x_shape
10571057
x = rearrange(x, "n c h w -> n h w c")
10581058
x = F.layer_norm(x, [ch])
10591059
x = rearrange(x, "n h w c -> n c h w")

0 commit comments

Comments
 (0)