diff --git a/benchmarks/BM_resnet101/LICENSE b/benchmarks/BM_resnet101/LICENSE new file mode 100644 index 00000000..ff5bd9ad --- /dev/null +++ b/benchmarks/BM_resnet101/LICENSE @@ -0,0 +1,6 @@ +Test video - excerpt from Sintel https://durian.blender.org + +License (https://durian.blender.org/sharing/): +CC BY 3.0 + +© copyright Blender Foundation | www.sintel.org \ No newline at end of file diff --git a/benchmarks/BM_resnet101/install.sh b/benchmarks/BM_resnet101/install.sh new file mode 100644 index 00000000..b9c672d3 --- /dev/null +++ b/benchmarks/BM_resnet101/install.sh @@ -0,0 +1,24 @@ +#!/bin/bash -ex + +# The MIT License (MIT) +# +# Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES +# +# Permission is hereby granted, free of charge, to any person obtaining a copy of +# this software and associated documentation files (the "Software"), to deal in +# the Software without restriction, including without limitation the rights to +# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of +# the Software, and to permit persons to whom the Software is furnished to do so, +# subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS +# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR +# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER +# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN +# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
+ +pip3 install torch torchvision --index-url https://download.pytorch.org/whl/cu118 diff --git a/benchmarks/BM_resnet101/model_repository/dali_postprocessing/1/dali.py b/benchmarks/BM_resnet101/model_repository/dali_postprocessing/1/dali.py new file mode 100644 index 00000000..c249efe9 --- /dev/null +++ b/benchmarks/BM_resnet101/model_repository/dali_postprocessing/1/dali.py @@ -0,0 +1,50 @@ +# The MIT License (MIT) +# +# Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES +# +# Permission is hereby granted, free of charge, to any person obtaining a copy of +# this software and associated documentation files (the "Software"), to deal in +# the Software without restriction, including without limitation the rights to +# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of +# the Software, and to permit persons to whom the Software is furnished to do so, +# subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS +# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR +# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER +# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN +# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
from nvidia.dali import fn
from nvidia.dali import pipeline_def
from nvidia.dali.plugin.triton import autoserialize
import nvidia.dali.types as types


@autoserialize
@pipeline_def(
    batch_size=256,
    num_threads=4,
    device_id=0,
    output_ndim=[3],
    output_dtype=[types.UINT8],
)
def dali_postprocessing_pipe(class_idx=0, prob_threshold=0.6):
    """
    DALI post-processing pipeline: mask the original frames with a thresholded class plane.

    The pipeline reads four named inputs ("original", "video_width",
    "video_height", "probabilities"), selects one plane of the probabilities,
    resizes it to the requested width/height with nearest-neighbour
    interpolation, thresholds it and multiplies the original image by the
    resulting mask.

    Args:
        class_idx: Index of the plane taken from the "probabilities" input.
            Shall be correlated with the `seg_class_name` argument in the
            Model instance.
        prob_threshold: Values strictly greater than this threshold are kept
            in the mask.

    Returns:
        The "original" input multiplied by the UINT8 mask.
    """
    # reshape() is used purely to attach a layout; the data shape is untouched.
    frame = fn.reshape(
        fn.external_source(device="gpu", name="original"),
        layout="HWC",
    )

    target_w = fn.external_source(device="cpu", name="video_width")
    target_h = fn.external_source(device="cpu", name="video_height")

    # Same trick as above: only the layout is set here.
    scores = fn.reshape(
        fn.external_source(device="gpu", name="probabilities"),
        layout="CHW",
    )

    # Pick the requested plane and give it back a trailing "C" axis.
    class_scores = fn.expand_dims(scores[class_idx], axes=[2], new_axis_names="C")
    class_scores = fn.resize(
        class_scores,
        resize_x=target_w,
        resize_y=target_h,
        interp_type=types.DALIInterpType.INTERP_NN,
    )

    keep = fn.cast(class_scores > prob_threshold, dtype=types.UINT8)
    return frame * keep
use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of +# the Software, and to permit persons to whom the Software is furnished to do so, +# subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS +# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR +# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER +# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN +# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +name: "dali_postprocessing" +backend: "dali" +max_batch_size: 256 +input [ + { + name: "original" + data_type: TYPE_UINT8 + dims: [ -1, -1, -1 ] + }, + { + name: "probabilities" + data_type: TYPE_FP32 + dims: [ -1, -1, -1 ] + }, + { + name: "video_width" + data_type: TYPE_FP32 + dims: [ 1 ] + }, + { + name: "video_height" + data_type: TYPE_FP32 + dims: [ 1 ] + } +] + +output [ + { + name: "segmented" + data_type: TYPE_UINT8 + dims: [ -1, -1, -1 ] + } +] diff --git a/benchmarks/BM_resnet101/model_repository/dali_preprocessing/1/dali.py b/benchmarks/BM_resnet101/model_repository/dali_preprocessing/1/dali.py new file mode 100644 index 00000000..2a68d3c8 --- /dev/null +++ b/benchmarks/BM_resnet101/model_repository/dali_preprocessing/1/dali.py @@ -0,0 +1,46 @@ +# The MIT License (MIT) +# +# Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES +# +# Permission is hereby granted, free of charge, to any person obtaining a copy of +# this software and associated documentation files (the "Software"), to deal in +# the Software without restriction, including without limitation the rights to +# use, copy, modify, merge, publish, distribute, sublicense, and/or sell 
from nvidia.dali import fn
from nvidia.dali import pipeline_def
from nvidia.dali.plugin.triton import autoserialize
import nvidia.dali.types as types


@autoserialize
@pipeline_def(batch_size=16, num_threads=4, device_id=0)
def dali_preprocessing_pipe():
    """
    DALI pre-processing pipeline: decode a video and normalize its frames.

    Reads the "encoded" input, decodes it with the "mixed" video decoder,
    resizes the frames to 224x224 and normalizes them into FLOAT data with
    "FCHW" layout.

    Returns:
        A pair (decoded frames, normalized frames), published under the
        operator names "original" and "preprocessed".
    """
    side = 224
    channel_mean = [0.485 * 255, 0.456 * 255, 0.406 * 255]
    channel_std = [0.229 * 255, 0.224 * 255, 0.225 * 255]

    bitstream = fn.external_source(name="encoded")
    frames = fn.experimental.decoders.video(bitstream, device="mixed", name="original")

    resized = fn.resize(frames, resize_x=side, resize_y=side)
    normalized = fn.crop_mirror_normalize(
        resized,
        dtype=types.FLOAT,
        output_layout="FCHW",
        crop=(side, side),
        mean=channel_mean,
        std=channel_std,
        name="preprocessed",
    )

    # Both outputs are listed in the model config's split_along_outer_axis parameter.
    return frames, normalized
+ +name: "dali_preprocessing" +backend: "dali" +max_batch_size: 16 +input [ + { + name: "encoded" + data_type: TYPE_UINT8 + dims: [ -1 ] + } +] + +output [ + { + name: "original" + data_type: TYPE_UINT8 + dims: [ -1, -1, -1 ] + }, + { + name: "preprocessed" + data_type: TYPE_FP32 + dims: [ -1, -1, -1 ] + } +] + +parameters [ + { + key: "split_along_outer_axis", + value: { string_value: "original:preprocessed" } + } +] diff --git a/benchmarks/BM_resnet101/model_repository/resnet101/1/model.py b/benchmarks/BM_resnet101/model_repository/resnet101/1/model.py new file mode 100644 index 00000000..5aacef5a --- /dev/null +++ b/benchmarks/BM_resnet101/model_repository/resnet101/1/model.py @@ -0,0 +1,102 @@ +# Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
import logging

import json
# import nvtx  # pytype: disable=import-error
import torch  # pytype: disable=import-error
from torchvision.models import segmentation as segmentation_models  # pytype: disable=import-error
import triton_python_backend_utils as pb_utils
from torch.utils.dlpack import from_dlpack, to_dlpack


class SegmentationPyTorch:
    """
    Callable wrapper around torchvision's FCN-ResNet101 followed by a softmax.

    Excerpt from CV-CUDA segmentation example:
    https://github.com/CVCUDA/CV-CUDA/blob/release_v0.3.x/samples/segmentation/python/model_inference.py
    """

    def __init__(self, seg_class_name, device_id):
        """
        Build the model and move it to the requested CUDA device.

        Args:
            seg_class_name: Name of the class to segment; it must be one of
                the categories listed in the weights' metadata.
            device_id: CUDA device index the model is placed on.

        Raises:
            ValueError: If `seg_class_name` is not among the weights' categories.
        """
        self.logger = logging.getLogger(__name__)
        self.device_id = device_id

        # The underlying pytorch model that we use for inference is the FCN
        # model from torchvision. The class-name -> index mapping comes from
        # the weights' meta properties.
        torch_model = segmentation_models.fcn_resnet101
        weights = segmentation_models.FCN_ResNet101_Weights.DEFAULT
        categories = weights.meta["categories"]

        try:
            self.class_index = categories.index(seg_class_name)
        except ValueError as err:
            # Chain explicitly so the traceback shows this is a translation of
            # the lookup failure rather than an error raised while handling it.
            raise ValueError(
                "Requested segmentation class '%s' is not supported by the "
                "fcn_resnet101 model. All supported class names are: %s"
                % (seg_class_name, ", ".join(categories))
            ) from err

        # Inference uses PyTorch to run a segmentation model on the
        # pre-processed input and outputs the segmentation masks.
        class FCN_Softmax(torch.nn.Module):  # noqa: N801
            """Apply softmax over dim=1 of the wrapped FCN's "out" tensor."""

            def __init__(self, fcn):
                super().__init__()
                self.fcn = fcn

            def forward(self, x):
                infer_output = self.fcn(x)["out"]
                return torch.nn.functional.softmax(infer_output, dim=1)

        fcn_base = torch_model(weights=weights)
        fcn_base.eval()
        self.model = FCN_Softmax(fcn_base).cuda(self.device_id)
        self.model.eval()

        self.logger.info("Using PyTorch as the inference engine.")

    def __call__(self, tensor):
        """Run the model on `tensor` with gradients disabled and return its output."""
        # nvtx.push_range("inference.torch")

        with torch.no_grad():
            segmented = self.model(tensor)

        # nvtx.pop_range()
        return segmented


class TritonPythonModel:
    """Triton Python-backend entry point that serves `SegmentationPyTorch`."""

    def __init__(self):
        self.segmentation_model = SegmentationPyTorch(
            seg_class_name="__background__",
            device_id=0,
        )

    def initialize(self, args):
        """
        Parse the model config and record the dtype of the "probabilities" output.

        Args:
            args: Dictionary supplied by Triton; only 'model_config' (a JSON
                string) is read here.
        """
        self.model_config = model_config = json.loads(args['model_config'])
        output0_config = pb_utils.get_output_config_by_name(model_config, "probabilities")
        self.output_dtype = pb_utils.triton_string_to_numpy(output0_config['data_type'])

    def execute(self, requests):
        """
        Run segmentation for every request.

        Args:
            requests: Iterable of Triton inference requests, each carrying a
                "preprocessed" input tensor.

        Returns:
            A list with one InferenceResponse per request, each holding a
            "probabilities" output tensor.
        """
        # Place inputs on the same device the model was moved to, instead of
        # whatever the current CUDA device happens to be.
        device_id = self.segmentation_model.device_id
        responses = []

        for request in requests:
            in0 = pb_utils.get_input_tensor_by_name(request, "preprocessed")
            in0_t = from_dlpack(in0.to_dlpack()).cuda(device_id)
            out0_t = self.segmentation_model(in0_t)
            out0 = pb_utils.Tensor.from_dlpack("probabilities", to_dlpack(out0_t))

            response = pb_utils.InferenceResponse(output_tensors=[out0])
            responses.append(response)
        return responses
the Software without restriction, including without limitation the rights to +# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of +# the Software, and to permit persons to whom the Software is furnished to do so, +# subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS +# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR +# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER +# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN +# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +name: "resnet101" +backend: "python" +max_batch_size: 256 +input [ + { + name: "preprocessed" + data_type: TYPE_FP32 + dims: [ -1, -1, -1 ] + } +] + +output [ + { + name: "probabilities" + data_type: TYPE_FP32 + dims: [ -1, -1, -1 ] + } +] diff --git a/benchmarks/BM_resnet101/model_repository/segmentation_bls/1/model.py b/benchmarks/BM_resnet101/model_repository/segmentation_bls/1/model.py new file mode 100644 index 00000000..b14aade8 --- /dev/null +++ b/benchmarks/BM_resnet101/model_repository/segmentation_bls/1/model.py @@ -0,0 +1,67 @@ +# Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
import json

import numpy as np
import torch
import triton_python_backend_utils as pb_utils


def run_inference(model_name, inputs, output_names):
    """
    Execute a BLS inference request against another model and collect its outputs.

    Args:
        model_name: Name of the model to call.
        inputs: List of pb_utils.Tensor objects passed as the request inputs.
        output_names: Names of the outputs to request.

    Returns:
        A list of output tensors, in the same order as `output_names`.

    Raises:
        pb_utils.TritonModelException: If the response reports an error.
    """
    request = pb_utils.InferenceRequest(
        model_name=model_name,
        requested_output_names=output_names,
        inputs=inputs)
    response = request.exec()

    if response.has_error():
        raise pb_utils.TritonModelException(
            response.error().message())

    # Build the list eagerly: a lazy map() can only be consumed once and would
    # defer the lookups to wherever the caller happens to unpack it.
    return [pb_utils.get_output_tensor_by_name(response, oname) for oname in output_names]


def extract_subtensor(tensor, start_idx, size):
    """
    Return a slice of `tensor` along its first axis as a new pb_utils.Tensor.

    Args:
        tensor: Source pb_utils.Tensor.
        start_idx: First index (along axis 0) included in the slice.
        size: Number of entries to take.

    Returns:
        A pb_utils.Tensor with the same name as `tensor`, built via DLPack from
        the sliced torch tensor.
    """
    tensor_pt = torch.from_dlpack(tensor.to_dlpack())
    subtensor = tensor_pt[start_idx: start_idx + size]
    return pb_utils.Tensor.from_dlpack(tensor.name(), torch.utils.dlpack.to_dlpack(subtensor))


class TritonPythonModel:
    """BLS model chaining dali_preprocessing -> resnet101 -> dali_postprocessing."""

    def initialize(self, args):
        """Store the parsed model config (args['model_config'] is a JSON string)."""
        self.model_config = json.loads(args['model_config'])

    def execute(self, requests):
        """
        Run the three-stage segmentation pipeline for every request.

        Args:
            requests: Iterable of Triton inference requests, each carrying an
                "encoded" input tensor.

        Returns:
            A list with one InferenceResponse per request, each holding the
            "original" and "segmented" tensors.
        """
        responses = []
        for request in requests:
            in_encoded = pb_utils.get_input_tensor_by_name(request, "encoded")

            original, preprocessed = run_inference("dali_preprocessing", [in_encoded], ["original", "preprocessed"])

            probabilities, = run_inference("resnet101", [preprocessed], ["probabilities"])

            # Axis 0 of "original" is used as the batch size, axes 1 and 2 as
            # the frame height and width handed to the post-processing model.
            original_shape = original.shape()
            batch_size = original_shape[0]
            in_height = original_shape[1]
            in_width = original_shape[2]
            video_height = pb_utils.Tensor("video_height", np.full((batch_size, 1), in_height, dtype=np.float32))
            video_width = pb_utils.Tensor("video_width", np.full((batch_size, 1), in_width, dtype=np.float32))

            segmented, = run_inference("dali_postprocessing", [original, probabilities, video_width, video_height],
                                       ["segmented"])

            inference_response = pb_utils.InferenceResponse(output_tensors=[original, segmented])
            responses.append(inference_response)

        return responses
b/benchmarks/BM_resnet101/model_repository/segmentation_bls/config.pbtxt new file mode 100644 index 00000000..89173c4c --- /dev/null +++ b/benchmarks/BM_resnet101/model_repository/segmentation_bls/config.pbtxt @@ -0,0 +1,44 @@ +# The MIT License (MIT) +# +# Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES +# +# Permission is hereby granted, free of charge, to any person obtaining a copy of +# this software and associated documentation files (the "Software"), to deal in +# the Software without restriction, including without limitation the rights to +# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of +# the Software, and to permit persons to whom the Software is furnished to do so, +# subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS +# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR +# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER +# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN +# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
#!/bin/bash -ex

# Runs perf_analyzer against the segmentation_bls model, first over a range of
# concurrency levels and then once per configured batch size. Models are
# loaded before and unloaded after every perf_analyzer run.
#
# Usage: run-benchmarks.sh [GRPC_ADDR]
#   GRPC_ADDR may also be supplied through the environment; it defaults to
#   localhost:8001.

: "${GRPC_ADDR:=${1:-localhost:8001}}"

# Models in load order; unload_models walks this list backwards.
MODELS=(dali_preprocessing resnet101 dali_postprocessing segmentation_bls)

load_models() {
  echo "Loading models..."
  local model
  for model in "${MODELS[@]}"; do
    python scripts/model-loader.py -u "${GRPC_ADDR}" load -m "${model}"
  done
  sleep 5
  echo "...models loaded"
}

unload_models() {
  echo "Unloading models..."
  local idx
  for ((idx = ${#MODELS[@]} - 1; idx >= 0; idx--)); do
    python scripts/model-loader.py -u "${GRPC_ADDR}" unload -m "${MODELS[idx]}"
  done
  sleep 5
  echo "...models unloaded"
}

TIME_WINDOW=10000
BATCH_SIZES=(1 2)

# Arrays keep every argument intact regardless of its content.
PERF_ANALYZER_ARGS=(-i grpc -u "${GRPC_ADDR}" "-p${TIME_WINDOW}")

# Size in bytes of the sample; computed once, before any model is loaded, so a
# missing sample aborts the script without leaving models loaded.
ENCODED_SIZE=$(stat --printf="%s" test_sample/encoded)
COMMON_ARGS=(-m segmentation_bls --input-data test_sample --shape "encoded:${ENCODED_SIZE}")

echo "ResNet101 Benchmark: single-sample"
load_models
perf_analyzer "${PERF_ANALYZER_ARGS[@]}" "${COMMON_ARGS[@]}" --concurrency-range=16:128:16
unload_models

echo "ResNet101 Benchmark: batched"
for BS in "${BATCH_SIZES[@]}"; do
  load_models
  perf_analyzer "${PERF_ANALYZER_ARGS[@]}" "${COMMON_ARGS[@]}" "-b${BS}"
  unload_models
done
import argparse
import sys


def get_args(argv=None):
    """
    Parse the command line of the model loader.

    Args:
        argv: Optional list of argument strings. When None (the default),
            argparse reads the arguments from sys.argv, as before.

    Returns:
        argparse.Namespace with the attributes `action` ('load', 'unload' or
        'reload'), `url` (defaults to 'localhost:8001') and `model`.
    """
    parser = argparse.ArgumentParser(description='Load or unload a model in Triton server.')
    parser.add_argument('action', action='store', choices=['load', 'unload', 'reload'])
    parser.add_argument('-u', '--url', required=False, action='store', default='localhost:8001', help='Server url.')
    parser.add_argument('-m', '--model', required=True, action='store', help='Model name.')
    return parser.parse_args(argv)


def main(args):
    """
    Load, unload or reload a model on the server.

    'reload' unloads the model first and then loads it again.

    Args:
        args: Namespace with `url`, `action` and `model`, as produced by get_args().
    """
    # Imported here so that argument parsing (and --help) works even where the
    # Triton client library is not installed.
    import tritonclient.grpc as t_client

    client = t_client.InferenceServerClient(url=args.url)
    if args.action in ['reload', 'unload']:
        client.unload_model(args.model)
        print('Successfully unloaded model', args.model)

    if args.action in ['reload', 'load']:
        client.load_model(args.model)
        print('Successfully loaded model', args.model)


if __name__ == '__main__':
    args = get_args()
    main(args)
a copy of +# this software and associated documentation files (the "Software"), to deal in +# the Software without restriction, including without limitation the rights to +# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of +# the Software, and to permit persons to whom the Software is furnished to do so, +# subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS +# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR +# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER +# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN +# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +mkdir -p test_sample +cp test_video/sintel_trailer_short.mp4 test_sample/encoded diff --git a/benchmarks/BM_resnet101/test_video/sintel_trailer_short.mp4 b/benchmarks/BM_resnet101/test_video/sintel_trailer_short.mp4 new file mode 100644 index 00000000..c92d0eb3 Binary files /dev/null and b/benchmarks/BM_resnet101/test_video/sintel_trailer_short.mp4 differ