diff --git a/README.md b/README.md index e9ef49bd7..735178fda 100644 --- a/README.md +++ b/README.md @@ -33,6 +33,7 @@ DLI supports inference using the following frameworks: - [ncnn][ncnn] (Python API). - [PaddlePaddle][PaddlePaddle] (Python API). - [ExecuTorch][executorch] (C++ and Python APIs) +- [IREE][iree] (Python API) More information about DLI is available on the web-site ([here][dli-ru-web-page] (in Russian) @@ -105,6 +106,7 @@ Please consider citing the following papers. for TensorFlow. - `TensorFlowLite` is a directory of Dockerfiles for TensorFlow Lite. - `TVM` is a directory of Dockerfiles for Apache TVM. + - `IREE` is a directory of Dockerfiles for IREE. - `docs` directory contains auxiliary documentation. Please, find complete documentation at the [Wiki page][dli-wiki]. @@ -158,6 +160,9 @@ Please consider citing the following papers. - [`validation_results_tvm.md`](results/validation/validation_results_tvm.md) is a table that confirms correctness of inference implementation based on Apache TVM for several public models. + - [`validation_results_iree.md`](results/validation/validation_results_iree.md) + is a table that confirms correctness of inference implementation + based on IREE for several public models. - [`mxnet_models_checklist.md`](results/mxnet_models_checklist.md) contains a list of deep models inferred by MXNet checked in the DLI benchmark. @@ -282,6 +287,7 @@ Report questions, issues and suggestions, using: [ncnn]: https://github.com/Tencent/ncnn [PaddlePaddle]: https://www.paddlepaddle.org.cn/en [executorch]: https://pytorch.org/executorch-overview +[iree]: https://iree.dev [benchmark-app]: https://github.com/openvinotoolkit/openvino/tree/master/samples/cpp/benchmark_app [dli-ru-web-page]: http://hpc-education.unn.ru/dli-ru [dli-web-page]: http://hpc-education.unn.ru/dli diff --git a/docker/IREE/Dockerfile b/docker/IREE/Dockerfile new file mode 100644 index 000000000..c9bf64d24 --- /dev/null +++ b/docker/IREE/Dockerfile @@ -0,0 +1,19 @@ +FROM ubuntu_for_dli + +# Install IREE +ARG IREE_VERSION=3.8.0 +RUN python3 -m pip install iree-base-compiler==${IREE_VERSION} iree-base-runtime==${IREE_VERSION} iree-turbine==${IREE_VERSION} + +# Install dependencies +RUN python3 -m pip install opencv-python numpy + +# Install onnx for model conversion +ARG ONNX_VERSION=1.19.1 +RUN python3 -m pip install onnx==${ONNX_VERSION} + +# Install torch for model conversion +ARG TORCH_VERSION=2.9.1 +ARG TORCHVISION_VERSION=0.24.1 +RUN python3 -m pip install torch==${TORCH_VERSION} torchvision==${TORCHVISION_VERSION} + +WORKDIR /tmp/ diff --git a/requirements_frameworks.txt b/requirements_frameworks.txt index 0334d3635..2684c5e55 100644 --- a/requirements_frameworks.txt +++ b/requirements_frameworks.txt @@ -18,8 +18,12 @@ dglgo==0.0.2 tflite paddleslim==2.6.0 -paddlepaddle==2.6.0 +paddlepaddle==2.6.2 --extra-index-url https://mirror.baidu.com/pypi/simple ncnn spektral==1.3.0 + +iree-base-compiler==3.8.0 +iree-base-runtime==3.8.0 +iree-turbine==3.8.0 \ No newline at end of file diff --git a/results/validation/validation_results_iree.md b/results/validation/validation_results_iree.md new file mode 100644 index 000000000..d4753fb8e --- /dev/null +++ b/results/validation/validation_results_iree.md @@ -0,0 +1,112 @@ +# Validation results for the models inferring using IREE + +## Public models + +We infer models using the following APIs: + +1. IREE, when we load PyTorch models directly from source format. 
+
+   ```bash
+   python inference_iree.py -t classification -is 1 3 224 224 \
+       -mn densenet121 \
+       -tm torchvision.models \
+       -f pytorch \
+       -i data/ \
+       --norm --mean 0.485 0.456 0.406 --std 0.229 0.224 0.225 \
+       -l labels/image_net_synset.txt \
+       --layout NCHW --channel_swap 2 1 0 \
+       -fn main
+   ```
+
+1. IREE, when we load ONNX models directly from source format.
+
+   ```bash
+   python inference_iree.py -t classification -is 1 3 224 224 \
+       -mn densenet121 \
+       -m densenet121.onnx \
+       -f onnx \
+       --onnx_opset_version 18 \
+       -i data/ \
+       --norm --mean 0.485 0.456 0.406 --std 0.229 0.224 0.225 \
+       -l labels/image_net_synset.txt \
+       --layout NCHW --channel_swap 2 1 0 \
+       -fn main_graph
+   ```
+
+1. PyTorch as source framework for reference.
+
+   ```bash
+   python inference_pytorch.py -t classification -is [1,3,224,224] \
+       --input_names data \
+       -mn densenet121 \
+       -mm torchvision.models \
+       -i data/ \
+       --mean [123.675,116.28,103.53] \
+       --input_scale [58.395,57.12,57.375] \
+       -l labels/image_net_synset.txt
+   ```
+
+### Notes
+
+1. Models in the ONNX format are loaded from the [onnx/models][onnx-models] repository.
+1. The model `squeezenet1.1` is missing from the [onnx/models][onnx-models] repository.
+
+### Image classification
+
+#### Test image #1
+
+Data source: [ImageNet][imagenet]
+
+Image resolution: 709 x 510
+
+<div style='float: center'>
+ +
+ +Model | Source Framework | Python API (source framework) | Python API (IREE, PyTorch) | Python API (IREE, ONNX) | +-|-|-|-|-| +densenet-121 | PyTorch | 0.9525911 Granny Smith
0.0132309 orange
0.0123391 lemon
0.0028140 banana
0.0020238 piggy bank, penny bank | 0.9523347 Granny Smith
0.0132272 orange
0.0125170 lemon
0.0027910 banana
0.0020333 piggy bank, penny bank | 0.9523349 Granny Smith
0.0132271 orange
0.0125169 lemon
0.0027909 banana
0.0020333 piggy bank, penny bank | +efficientnet-b0 | PyTorch | 0.3421609 Granny Smith
0.1089311 piggy bank, penny bank
0.0693323 teapot
0.0249018 vase
0.0205339 saltshaker, salt shaker | 0.3421628 Granny Smith
0.1089310 piggy bank, penny bank
0.0693315 teapot
0.0249016 vase
0.0205339 saltshaker, salt shaker | 0.3421622 Granny Smith
0.1089308 piggy bank, penny bank
0.0693314 teapot
0.0249017 vase
0.0205338 saltshaker, salt shaker | +googlenet-v1 | PyTorch | 0.5399834 Granny Smith
0.1101810 piggy bank, penny bank
0.0232574 vase
0.0213452 pitcher, ewer
0.0198953 bell pepper | 0.5432554 Granny Smith
0.1103971 piggy bank, penny bank
0.0232568 vase
0.0213901 pitcher, ewer
0.0196196 bell pepper | 0.5432543 Granny Smith
0.1103970 piggy bank, penny bank
0.0232569 vase
0.0213901 pitcher, ewer
0.0196196 bell pepper | +resnet-50 | PyTorch | 0.9280675 Granny Smith
0.0129466 orange
0.0058861 lemon
0.0041993 necklace
0.0025445 banana | 0.9278086 Granny Smith
0.0129410 orange
0.0059573 lemon
0.0042141 necklace
0.0025712 banana | 0.4216066 Granny Smith
0.0661015 dumbbell
0.0348192 barbell
0.0049673 orange
0.0045203 syringe | +squeezenet1.1 | PyTorch | 0.5913458 piggy bank, penny bank
0.0682889 Granny Smith
0.0610993 lemon
0.0596012 necklace
0.0492096 bucket, pail | 0.5895361 piggy bank, penny bank
0.0677933 Granny Smith
0.0610654 necklace
0.0610450 lemon
0.0490914 bucket, pail | - | + +#### Test image #2 + +Data source: [ImageNet][imagenet] + +Image resolution: 500 x 500 + +
+ +
+ +Model | Source Framework | Python API (source framework) | Python API (IREE, PyTorch) | Python API (IREE, ONNX) | +-|-|-|-|-| +densenet-121 | PyTorch | 0.9847536 junco, snowbird
0.0068679 chickadee
0.0034511 brambling, Fringilla montifringilla
0.0015685 water ouzel, dipper
0.0012343 indigo bunting, indigo finch, indigo bird, Passerina cyanea | 0.9841590 junco, snowbird
0.0072199 chickadee
0.0034962 brambling, Fringilla montifringilla
0.0016226 water ouzel, dipper
0.0012858 indigo bunting, indigo finch, indigo bird, Passerina cyanea | 0.9841590 junco, snowbird
0.0072199 chickadee
0.0034962 brambling, Fringilla montifringilla
0.0016226 water ouzel, dipper
0.0012858 indigo bunting, indigo finch, indigo bird, Passerina cyanea | +efficientnet-b0 | PyTorch | 0.8903497 junco, snowbird
0.0147084 water ouzel, dipper
0.0074830 chickadee
0.0044766 brambling, Fringilla montifringilla
0.0027406 goldfinch, Carduelis carduelis | 0.8903519 junco, snowbird
0.0147081 water ouzel, dipper
0.0074829 chickadee
0.0044765 brambling, Fringilla montifringilla
0.0027406 goldfinch, Carduelis carduelis | 0.8903498 junco, snowbird
0.0147084 water ouzel, dipper
0.0074830 chickadee
0.0044766 brambling, Fringilla montifringilla
0.0027406 goldfinch, Carduelis carduelis | +googlenet-v1 | PyTorch | 0.6449553 junco, snowbird
0.0752306 chickadee
0.0480572 brambling, Fringilla montifringilla
0.0298399 goldfinch, Carduelis carduelis
0.0126128 house finch, linnet, Carpodacus mexicanus | 0.6461055 junco, snowbird
0.0772564 chickadee
0.0468782 brambling, Fringilla montifringilla
0.0295897 goldfinch, Carduelis carduelis
0.0123322 house finch, linnet, Carpodacus mexicanus | 0.6461049 junco, snowbird
0.0772565 chickadee
0.0468783 brambling, Fringilla montifringilla
0.0295897 goldfinch, Carduelis carduelis
0.0123323 house finch, linnet, Carpodacus mexicanus | +resnet-50 | PyTorch | 0.9809760 junco, snowbird
0.0049167 goldfinch, Carduelis carduelis
0.0036987 chickadee
0.0036697 water ouzel, dipper
0.0029304 brambling, Fringilla montifringilla | 0.9805012 junco, snowbird
0.0049154 goldfinch, Carduelis carduelis
0.0039196 chickadee
0.0038098 water ouzel, dipper
0.0028983 brambling, Fringilla montifringilla | 0.3845567 junco, snowbird
0.0091156 water ouzel, dipper
0.0054526 chickadee
0.0026206 indigo bunting, indigo finch, indigo bird, Passerina cyanea
0.0023612 brambling, Fringilla montifringilla | +squeezenet1.1 | PyTorch | 0.9609295 junco, snowbird
0.0248581 chickadee
0.0042597 brambling, Fringilla montifringilla
0.0037157 goldfinch, Carduelis carduelis
0.0033528 ruffed grouse, partridge, Bonasa umbellus | 0.9614577 junco, snowbird
0.0250981 chickadee
0.0040701 brambling, Fringilla montifringilla
0.0035156 goldfinch, Carduelis carduelis
0.0030858 ruffed grouse, partridge, Bonasa umbellus | - | + +#### Test image #3 + +Data source: [ImageNet][imagenet] + +Image resolution: 333 x 500 + +
+ +
+ +Model | Source Framework | Python API (source framework) | Python API (IREE, PyTorch) | Python API (IREE, ONNX) | +-|-|-|-|-| +densenet-121 | PyTorch | 0.3047960 liner, ocean liner
0.1327189 breakwater, groin, groyne, mole, bulwark, seawall, jetty
0.1180288 container ship, containership, container vessel
0.0794686 drilling platform, offshore rig
0.0718431 dock, dockage, docking facility | 0.3022414 liner, ocean liner
0.1322474 breakwater, groin, groyne, mole, bulwark, seawall, jetty
0.1194614 container ship, containership, container vessel
0.0795042 drilling platform, offshore rig
0.0723073 dock, dockage, docking facility | 0.3022407 liner, ocean liner
0.1322481 breakwater, groin, groyne, mole, bulwark, seawall, jetty
0.1194605 container ship, containership, container vessel
0.0795041 drilling platform, offshore rig
0.0723069 dock, dockage, docking facility | +efficientnet-b0 | PyTorch | 0.4476882 breakwater, groin, groyne, mole, bulwark, seawall, jetty
0.0953832 container ship, containership, container vessel
0.0872342 beacon, lighthouse, beacon light, pharos
0.0559825 drilling platform, offshore rig
0.0441807 liner, ocean liner | 0.4476875 breakwater, groin, groyne, mole, bulwark, seawall, jetty
0.0953838 container ship, containership, container vessel
0.0872344 beacon, lighthouse, beacon light, pharos
0.0559831 drilling platform, offshore rig
0.0441806 liner, ocean liner | 0.4476894 breakwater, groin, groyne, mole, bulwark, seawall, jetty
0.0953836 container ship, containership, container vessel
0.0872341 beacon, lighthouse, beacon light, pharos
0.0559827 drilling platform, offshore rig
0.0441803 liner, ocean liner | +googlenet-v1 | PyTorch | 0.1330581 liner, ocean liner
0.0796951 drilling platform, offshore rig
0.0680323 container ship, containership, container vessel
0.0588053 breakwater, groin, groyne, mole, bulwark, seawall, jetty
0.0365606 fireboat | 0.1323653 liner, ocean liner
0.0796393 drilling platform, offshore rig
0.0678083 container ship, containership, container vessel
0.0585719 breakwater, groin, groyne, mole, bulwark, seawall, jetty
0.0366882 fireboat | 0.1323648 liner, ocean liner
0.0796394 drilling platform, offshore rig
0.0678085 container ship, containership, container vessel
0.0585720 breakwater, groin, groyne, mole, bulwark, seawall, jetty
0.0366881 fireboat | +resnet-50 | PyTorch | 0.4818293 liner, ocean liner
0.0992477 breakwater, groin, groyne, mole, bulwark, seawall, jetty
0.0687505 container ship, containership, container vessel
0.0517874 dock, dockage, docking facility
0.0483462 pirate, pirate ship | 0.4759648 liner, ocean liner
0.1025407 breakwater, groin, groyne, mole, bulwark, seawall, jetty
0.0689996 container ship, containership, container vessel
0.0524496 dock, dockage, docking facility
0.0473777 pirate, pirate ship | 0.1220204 lifeboat
0.0430796 breakwater, groin, groyne, mole, bulwark, seawall, jetty
0.0360478 beacon, lighthouse, beacon light, pharos
0.0335465 dock, dockage, docking facility
0.0251255 liner, ocean liner | +squeezenet1.1 | PyTorch | 0.4393108 liner, ocean liner
0.1895231 container ship, containership, container vessel
0.1506845 pirate, pirate ship
0.0962459 fireboat
0.0199389 drilling platform, offshore rig | 0.4413096 liner, ocean liner
0.1931005 container ship, containership, container vessel
0.1459103 pirate, pirate ship
0.0937753 fireboat
0.0198682 drilling platform, offshore rig | - |
+
+
+[imagenet]: http://www.image-net.org
+[onnx-models]: https://github.com/onnx/models/tree/main
diff --git a/src/accuracy_checker/process.py b/src/accuracy_checker/process.py
index 6661be780..df117097f 100644
--- a/src/accuracy_checker/process.py
+++ b/src/accuracy_checker/process.py
@@ -25,7 +25,7 @@ def execute(self, idx):
         command_line = self.__fill_command_line()
         if command_line == '':
             self.__log.error('Command line is empty')
-        self.__log.info(f'Start accuracy check for {idx+1} test: {self._test.model.name}')
+        self.__log.info(f'Start accuracy check for {idx + 1} test: {self._test.model.name}')
         self.__log.info(f'Command line is : {command_line}')
         self._executor.set_target_framework(self._test.framework)
         command_line = self._executor.prepare_command_line(self._test, command_line)
diff --git a/src/benchmark/README.md b/src/benchmark/README.md
index c44fd7e9a..c03ac63ef 100644
--- a/src/benchmark/README.md
+++ b/src/benchmark/README.md
@@ -22,6 +22,7 @@ the following frameworks:
 - [RKNN][rknn].
 - [Spektral][spektral] (Python API).
 - [PaddlePaddle][paddlepaddle] (Python API).
+- [IREE][iree] (Python API).
 
 ### Implemented algorithm
 
@@ -274,3 +275,4 @@ pip install openvino_dev[mxnet,caffe,caffe2,onnx,pytorch,tensorflow2]==`<version>`
+[iree]: https://iree.dev
diff --git a/src/configs/README.md b/src/configs/README.md
--- a/src/configs/README.md
+++ b/src/configs/README.md
+For PyTorch models, the `` tag can be used inside the `<FrameworkDependent>` section to specify the path to the module with the model architecture (the value is forwarded to the `--torch_module` parameter of the inference script).
+
 ### Configuration examples
 
 #### Example of a configuration for performance measurement with the Intel Distribution of OpenVINO Toolkit
@@ -773,6 +788,47 @@
 
 ```
 
+#### Example of a configuration for performance measurement with IREE
+
+```xml
+<Test>
+    <Model>
+        <Task>classification</Task>
+        <Name>resnet50</Name>
+        <Precision>FP32</Precision>
+        <SourceFramework>onnx</SourceFramework>
+        <ModelPath>/home/user/models/resnet50/resnet50.onnx</ModelPath>
+    </Model>
+    <Dataset>
+        <Name>ImageNet</Name>
+        <Path>/mnt/datasets/ILSVRC2012_img_val</Path>
+    </Dataset>
+    <FrameworkIndependent>
+        <InferenceFramework>IREE</InferenceFramework>
+        <BatchSize>1</BatchSize>
+        <Device>CPU</Device>
+        <IterationCount>20</IterationCount>
+        <TestTimeLimit>60</TestTimeLimit>
+    </FrameworkIndependent>
+    <FrameworkDependent>
+        <FunctionName>main</FunctionName>
+        <InputShape>1 3 224 224</InputShape>
+        <Layout>NCHW</Layout>
+        <Normalize>True</Normalize>
+        <Mean>0.485 0.456 0.406</Mean>
+        <Std>0.229 0.224 0.225</Std>
+        <ChannelSwap>2 1 0</ChannelSwap>
+        <TargetBackend>llvm-cpu</TargetBackend>
+        <OptLevel>3</OptLevel>
+        <OnnxOpsetVersion>17</OnnxOpsetVersion>
+        <ExtraCompileArgs>--iree-llvmcpu-target-cpu-features=host</ExtraCompileArgs>
+    </FrameworkDependent>
+</Test>
+```
+
 #### Example of a configuration for performance measurement with the RKNN C++ API
 
 ```xml
diff --git a/src/inference/README.md b/src/inference/README.md
index da8fe747b..3157f765c 100644
--- a/src/inference/README.md
+++ b/src/inference/README.md
@@ -17,6 +17,7 @@
 1. ncnn.
 1. PaddlePaddle.
 1. Spektral.
+1. IREE.
 
 ## Deep model inference using the Inference Engine
 
@@ -1482,6 +1483,124 @@ python inference_ncnn.py --model <path to model> \
     --batch_size <batch size>
 ```
 
+## Deep model inference using IREE
+
+#### Script
+
+```bash
+inference_iree.py
+```
+
+#### Common required arguments
+
+- `-fn / --function_name` - name of the function inside the IREE module that is called during inference.
+- `-i / --input` - path to an image or a directory with images
+  (file extensions `.jpg`, `.png`, `.bmp`, etc.).
+- `-is / --input_shape` - input tensor shape in the BxCxHxW format,
+  where B is a batch size, C is a number of image channels,
+  H is an image height, W is an image width.
+
+#### The remaining parameters depend on the format in which the model is provided
+
+- **Prebuilt IREE module (`.vmfb` or `.mlir`)**
+  - `-m / --model` - path to the model in the `.vmfb` format (a prebuilt binary) or in the `.mlir` format (compiled before the run). Required.
+  - `-tb / --target_backend` - target backend for compilation and execution (`llvm-cpu`, `cuda`, `vulkan`, `metal`, `rocm`, `vmvx`, etc.). Default: `llvm-cpu`. Required if the model is in the `.mlir` format.
+
+- **ONNX model**
+  - `--source_framework onnx` - framework the model is loaded from. Required.
+  - `-m / --model` - path to the model in the `.onnx` format. Required.
+  - `--onnx_opset_version` - ONNX opset version (default: `18`).
+  - `-tb / --target_backend` - target backend for compilation and execution (`llvm-cpu`, `cuda`, `vulkan`, `metal`, `rocm`, `vmvx`, etc.). Default: `llvm-cpu`. Required.
+
+- **PyTorch model from a file**
+  - `--source_framework pytorch` - framework the model is loaded from. Required.
+  - `-m / --model` - path to the model in the `.pt` format. Required.
+  - `-w / --weights` - path to the file with the model weights in the `.pth` format. Optional.
+  - `-tb / --target_backend` - target backend for compilation and execution (`llvm-cpu`, `cuda`, `vulkan`, `metal`, `rocm`, `vmvx`, etc.). Default: `llvm-cpu`. Required.
+
+- **PyTorch model from a module**
+  - `--source_framework pytorch` - framework the model is loaded from. Required.
+  - `-tm / --torch_module` - path to a Python module or a relative path to a Python file with the model architecture (e.g. `torchvision.models` for the module with [models][torchvision_models]). Required.
+  - `-mn / --model_name` - model name. Required.
+  - `-w / --weights` - path to the file with the model weights in the `.pth` format. Optional.
+  - `-tb / --target_backend` - target backend for compilation and execution (`llvm-cpu`, `cuda`, `vulkan`, `metal`, `rocm`, `vmvx`, etc.). Default: `llvm-cpu`. Required.
+
+#### Optional arguments
+
+- `-b / --batch_size` - number of images processed in a single pass through the network. Default: `1`. The value must be equal to B in the `input_shape` parameter.
+- `-t / --task` - task name. The current implementation supports the classification task (`classification`). Default: `feedforward`.
+- `-l / --labels` - path to a JSON file with the list of labels for the classification task. Default: `image_net_labels.json`, which corresponds to the labels of the ImageNet dataset.
+- `-nt / --number_top` - number of top results printed for the classification task. Default: the `5` best results.
+- `-ni / --number_iter` - number of forward passes through the network. Default: `1` pass.
+- `--raw_output` - run the script without logs. Not set by default.
+- `--time` - time limit in seconds. If both `--time` and `-ni` are set, the longer scenario wins.
+- `--report_path` - path to the `.json` report (default: `src/inference/iree_inference_report.json`).
+- `--layout` - input tensor layout (`NHWC` or `NCHW`, default: `NCHW`).
+- `--norm` - flag to normalize the input image (divides the values by `255` before further processing).
+- `--mean`, `--std`, `--channel_swap` - preprocessing parameters. Defaults: `mean=[0, 0, 0]`, `std=[1, 1, 1]`, `channel_swap=[2, 1, 0]`.
+- `--opt_level` - optimization level used if the model has to be compiled before inference starts (`0-3`, default: `2`).
+- `--extra_compile_args` - extra compilation flags (must be specified strictly at the end of the command line), for example:
+ ``` + --extra_compile_args --iree-llvmcpu-target-cpu=cascadelake --iree-llvmcpu-target-triple=x86_64-linux-gnu + ``` + +#### Примеры запуска + +**Готовый `.vmfb`** + +```bash +python3 inference_iree.py \ + -m compiled/resnet50.vmfb \ + -fn main \ + -i ./data/images \ + -is 1 3 224 224 \ + -b 1 -ni 100 \ + -t classification \ + -l ./labels/imagenet_synset.txt +``` + +**Автоконвертация ONNX -> MLIR -> VMFB** + +```bash +python3 inference_iree.py \ + --source_framework onnx \ + -m ./models/efficientnet-b0.onnx \ + --onnx_opset_version 18 \ + -fn main \ + -i ./data/test.jpg \ + -is 1 3 224 224 \ + -tb llvm-cpu \ + --opt_level 3 \ + --extra_compile_args --iree-vulkan-target-triple=rdna2-pc-linux-gnu +``` + +** Автоконвертация Pytorch модели из `torchvision`** + +```bash +python3 inference_iree.py \ + --source_framework pytorch \ + -mn resnet50 \ + -tm torchvision.models \ + -fn classification \ + -i ./data/images \ + -is 1 3 224 224 \ + -tb llvm-cpu \ + --mean 123.68 116.78 103.94 \ + --std 58.40 57.12 57.38 +``` + +Результат выполнения: набор наиболее вероятных классов, которым принадлежит +изображение. + [execution_providers]: https://onnxruntime.ai/docs/execution-providers [gluon_modelzoo]: https://cv.gluon.ai/model_zoo/index.html @@ -1492,3 +1611,4 @@ python inference_ncnn.py --model \ [dgl]: https://www.dgl.ai/pages/start.html [ogb]: https://ogb.stanford.edu/ [tensorflow-gpu]: https://www.tensorflow.org/install/pip +[iree]: https://iree.dev diff --git a/src/inference/inference_iree.py b/src/inference/inference_iree.py new file mode 100644 index 000000000..7ba741b6d --- /dev/null +++ b/src/inference/inference_iree.py @@ -0,0 +1,290 @@ +import argparse +import sys +import traceback +from pathlib import Path + +import postprocessing_data as pp +from inference_tools.loop_tools import loop_inference, get_exec_time +from io_adapter import IOAdapter +from io_model_wrapper import IREEModelWrapper +from reporter.report_writer import ReportWriter +from transformer import IREETransformer +from iree_auxiliary import (load_model, create_dict_for_transformer, prepare_output, validate_cli_args) + + +sys.path.append(str(Path(__file__).resolve().parents[1].joinpath('utils'))) +from logger_conf import configure_logger # noqa: E402 + +log = configure_logger() + +try: + import iree.runtime as ireert # noqa: E402 +except ImportError as e: + log.error(f'IREE import error: {e}') + sys.exit(1) + + +def cli_argument_parser(): + parser = argparse.ArgumentParser() + parser.add_argument('-f', '--source_framework', + help='Source model framework (required for automatic conversion to MLIR)', + type=str, + choices=['onnx', 'pytorch'], + dest='source_framework') + parser.add_argument('-m', '--model', + help='Path to source framework model (.onnx, .pt),' + 'to file with compiled model (.vmfb)' + 'or MLIR (.mlir).', + type=str, + dest='model') + parser.add_argument('-w', '--weights', + help='Path to an .pth file with a trained weights.' + 'Availiable when source_framework=pytorch ', + type=str, + dest='model_weights') + parser.add_argument('-tm', '--torch_module', + help='Torch module with model architecture.' + 'Availiable when source_framework=pytorch', + type=str, + dest='torch_module') + parser.add_argument('-mn', '--model_name', + help='Model name.', + type=str, + dest='model_name') + parser.add_argument('--onnx_opset_version', + help='Path to an .onnx with a trained model.' 
+                             'Available when source_framework=onnx.',
+                        type=int,
+                        dest='onnx_opset_version')
+    parser.add_argument('-fn', '--function_name',
+                        help='IREE module function name to execute.',
+                        required=True,
+                        type=str,
+                        dest='function_name')
+    parser.add_argument('-i', '--input',
+                        help='Path to data.',
+                        required=True,
+                        type=str,
+                        nargs='+',
+                        dest='input')
+    parser.add_argument('-is', '--input_shape',
+                        help='Input shape BxCxHxW, B is a batch size, '
+                             'C is an input tensor number of channels, '
+                             'H is an input tensor height, '
+                             'W is an input tensor width.',
+                        required=True,
+                        type=int,
+                        nargs=4,
+                        dest='input_shape')
+    parser.add_argument('-b', '--batch_size',
+                        help='Size of the processed pack. '
+                             'Should be the same as B in the input_shape argument.',
+                        default=1,
+                        type=int,
+                        dest='batch_size')
+    parser.add_argument('-l', '--labels',
+                        help='Labels mapping file.',
+                        default=None,
+                        type=str,
+                        dest='labels')
+    parser.add_argument('-nt', '--number_top',
+                        help='Number of top results.',
+                        default=5,
+                        type=int,
+                        dest='number_top')
+    parser.add_argument('-t', '--task',
+                        help='Task type. Default: feedforward.',
+                        choices=['feedforward', 'classification'],
+                        default='feedforward',
+                        type=str,
+                        dest='task')
+    parser.add_argument('-ni', '--number_iter',
+                        help='Number of inference iterations.',
+                        default=1,
+                        type=int,
+                        dest='number_iter')
+    parser.add_argument('--raw_output',
+                        help='Raw output without logs.',
+                        default=False,
+                        type=bool,
+                        dest='raw_output')
+    parser.add_argument('--time',
+                        required=False,
+                        default=0,
+                        type=int,
+                        dest='time',
+                        help='Optional. Maximum test duration. 0 if no restrictions.')
+    parser.add_argument('--report_path',
+                        type=Path,
+                        default=Path(__file__).parent / 'iree_inference_report.json',
+                        dest='report_path')
+    parser.add_argument('--layout',
+                        help='Input layout.',
+                        default='NCHW',
+                        choices=['NHWC', 'NCHW'],
+                        type=str,
+                        dest='layout')
+    parser.add_argument('--norm',
+                        help='Flag to normalize input images.',
+                        action='store_true',
+                        dest='norm')
+    parser.add_argument('--mean',
+                        help='Mean values.',
+                        default=[0, 0, 0],
+                        type=float,
+                        nargs=3,
+                        dest='mean')
+    parser.add_argument('--std',
+                        help='Standard deviation values.',
+                        default=[1., 1., 1.],
+                        type=float,
+                        nargs=3,
+                        dest='std')
+    parser.add_argument('--channel_swap',
+                        help='Parameter of channel swap.',
+                        default=[2, 1, 0],
+                        type=int,
+                        nargs=3,
+                        dest='channel_swap')
+    parser.add_argument('-tb', '--target_backend',
+                        help='Target backend, for example `llvm-cpu` for CPU.',
+                        default='llvm-cpu',
+                        type=str,
+                        dest='target_backend')
+    parser.add_argument('--opt_level',
+                        help='The optimization level of the compilation.',
+                        type=int,
+                        choices=[0, 1, 2, 3],
+                        default=2)
+    parser.add_argument('--extra_compile_args',
+                        help='The extra arguments for MLIR compilation.',
+                        type=str,
+                        nargs=argparse.REMAINDER,
+                        default=[])
+    args = parser.parse_args()
+    validate_cli_args(args)
+    return args
+
+
+def get_inference_function(model_context, function_name):
+    try:
+        main_module = model_context.modules.module
+        inference_func = main_module[function_name]
+        log.info(f'Using function {function_name} for inference')
+        return inference_func
+
+    except Exception as e:
+        log.error(f'Failed to get inference function: {e}')
+        raise
+
+
+def inference_iree(inference_func, number_iter, get_slice, test_duration):
+    result = None
+    time_infer = []
+
+    if number_iter == 1:
+        slice_input = get_slice()
+        result, exec_time = infer_slice(inference_func, slice_input)
+        time_infer.append(exec_time)
+    else:
+        time_infer = loop_inference(number_iter, test_duration)(
+            inference_iteration,
+        )(inference_func, get_slice)['time_infer']
+
+    log.info('Inference completed')
+    return result, time_infer
+
+
+def inference_iteration(inference_func, get_slice):
+    slice_input = get_slice()
+    _, exec_time = infer_slice(inference_func, slice_input)
+    return exec_time
+
+
+@get_exec_time()
+def infer_slice(inference_func, slice_input):
+    config = ireert.Config('local-task')
+    device = config.device
+
+    input_buffers = []
+    for input_ in slice_input:
+        input_buffers.append(ireert.asdevicearray(device, input_))
+
+    result = inference_func(*input_buffers)
+
+    if hasattr(result, 'to_host'):
+        result = result.to_host()
+
+    return result
+
+
+def main():
+    args = cli_argument_parser()
+
+    try:
+        model_wrapper = IREEModelWrapper(args)
+        data_transformer = IREETransformer(create_dict_for_transformer(args))
+        io = IOAdapter.get_io_adapter(args, model_wrapper, data_transformer)
+
+        report_writer = ReportWriter()
+        report_writer.update_framework_info(name='IREE')
+        report_writer.update_configuration_setup(
+            batch_size=args.batch_size,
+            iterations_num=args.number_iter,
+            target_device=args.target_backend,
+        )
+
+        log.info('Loading model')
+        model_context = load_model(
+            model_path=args.model,
+            model_weights=args.model_weights,
+            torch_module=args.torch_module,
+            model_name=args.model_name,
+            onnx_opset_version=args.onnx_opset_version,
+            source_framework=args.source_framework,
+            input_shape=args.input_shape,
+            target_backend=args.target_backend,
+            opt_level=args.opt_level,
+            extra_compile_args=args.extra_compile_args,
+        )
+        inference_func = get_inference_function(model_context, args.function_name)
+
+        log.info(f'Preparing input data: {args.input}')
+        io.prepare_input(model_context, args.input)
+
+        log.info(f'Starting inference ({args.number_iter} iterations) on {args.target_backend}')
+        result, inference_time = inference_iree(
+            inference_func,
+            args.number_iter,
+            io.get_slice_input_iree,
+            args.time,
+        )
+
+        log.info('Computing performance metrics')
+        inference_result = pp.calculate_performance_metrics_sync_mode(
+            args.batch_size,
+            inference_time,
+        )
+
+        report_writer.update_execution_results(**inference_result)
+        report_writer.write_report(args.report_path)
+
+        if not args.raw_output:
+            if args.number_iter == 1:
+                try:
+                    log.info('Converting output tensor to print results')
+                    result = prepare_output(result, args.task)
+                    log.info('Inference results')
+                    io.process_output(result, log)
+                except Exception as ex:
+                    log.warning(f'Error when printing inference results: {str(ex)}')
+
+        log.info(f'Performance results: {inference_result}')
+
+    except Exception:
+        log.error(traceback.format_exc())
+        sys.exit(1)
+
+
+if __name__ == '__main__':
+    sys.exit(main() or 0)
diff --git a/src/inference/io_adapter.py b/src/inference/io_adapter.py
index 75693e2e4..010a27748 100644
--- a/src/inference/io_adapter.py
+++ b/src/inference/io_adapter.py
@@ -186,6 +186,14 @@ def get_slice_input(self, *args, **kwargs):
 
         return slice_input
 
+    def get_slice_input_iree(self, *args, **kwargs):
+        slice_input = []
+        for key in self._transformed_input:
+            data_gen = self._transformed_input[key]
+            slice_data = [copy.deepcopy(next(data_gen)) for _ in range(self._batch_size)]
+            slice_input.append(np.stack(slice_data))
+        return slice_input
+
     def get_slice_input_mxnet(self, *args, **kwargs):
         import mxnet
         slice_input = dict.fromkeys(self._transformed_input.keys(), None)
@@ -425,7 +433,7 @@ def get_slice_input(self, *args, **kwargs):
         return [self._prompts[0]] * self._batch_size
 
     def process_output(self, result, log):
-        output_text = '\n'.join([f'{i+1}) {text} ... \n' for i, text in enumerate(result)])
+        output_text = '\n'.join([f'{i + 1}) {text} ... \n' for i, text in enumerate(result)])
         log.info(f'Generated results: \n{output_text}')
 
 
@@ -435,7 +443,7 @@ def get_slice_input(self, *args, **kwargs):
         return self.audio_data, self.sampling_rate, self.audio_length
 
     def process_output(self, result, log):
-        output_text = '\n'.join([f'{i+1}) {text} ... \n' for i, text in enumerate(result)])
+        output_text = '\n'.join([f'{i + 1}) {text} ... \n' for i, text in enumerate(result)])
         log.info(f'Generated results: \n{output_text}')
 
 
diff --git a/src/inference/io_model_wrapper.py b/src/inference/io_model_wrapper.py
index 1c68c89dd..2aa15d13b 100644
--- a/src/inference/io_model_wrapper.py
+++ b/src/inference/io_model_wrapper.py
@@ -409,3 +409,19 @@ def get_input_layer_dtype(self):
 
 class ExecuTorchIOModelWrapper(TVMIOModelWrapper):
     pass
+
+
+class IREEModelWrapper(IOModelWrapper):
+    def __init__(self, args):
+        self._input_shapes = [args.input_shape]
+        self._model_path = args.model
+
+    def get_input_layer_names(self, model):
+        return ['input']
+
+    def get_input_layer_shape(self, model, layer_name):
+        return self._input_shapes[0]
+
+    def get_input_layer_dtype(self, model, layer_name):
+        import numpy as np
+        return np.float32
diff --git a/src/inference/iree_auxiliary.py b/src/inference/iree_auxiliary.py
new file mode 100644
index 000000000..a508115c3
--- /dev/null
+++ b/src/inference/iree_auxiliary.py
@@ -0,0 +1,214 @@
+import os
+import sys
+import tempfile
+from pathlib import Path
+
+import numpy as np
+
+sys.path.append(str(Path(__file__).resolve().parents[1].joinpath('model_converters',
+                                                                 'iree_converter',
+                                                                 'iree_auxiliary')))
+from compiler import IREECompiler  # noqa: E402
+from converter import IREEConverter  # noqa: E402
+
+sys.path.append(str(Path(__file__).resolve().parents[1].joinpath('utils')))
+from logger_conf import configure_logger  # noqa: E402
+
+log = configure_logger()
+
+try:
+    import iree.runtime as ireert  # noqa: E402
+except ImportError as e:
+    log.error(f'IREE import error: {e}')
+    sys.exit(1)
+
+
+def _validate_iree_model_args(args):
+    if not args.model:
+        raise ValueError('Model path (-m/--model) is required')
+    if not os.path.exists(args.model):
+        raise FileNotFoundError(f'File not found: {args.model}')
+
+    file_type = args.model.split('.')[-1].lower()
+    supported_extensions = ['mlir', 'vmfb']
+    if file_type not in supported_extensions:
+        raise ValueError(f'Model must be one of the following file types: {supported_extensions}')
+    if file_type == 'mlir' and not args.target_backend:
+        raise ValueError('target_backend is required when using .mlir model')
+
+
+def _validate_onnx_args(args):
+    if not args.model:
+        raise ValueError('Model path (-m/--model) is required for ONNX framework')
+    if not os.path.exists(args.model):
+        raise FileNotFoundError(f'Model file not found: {args.model}')
+
+    file_type = args.model.split('.')[-1]
+    if file_type == 'onnx':
+        if not args.onnx_opset_version:
+            raise ValueError('onnx_opset_version is required for ONNX framework')
+    else:
+        _validate_iree_model_args(args)
+
+
+def _validate_pytorch_args(args):
+    has_model_path = args.model is not None and args.model != ''
+    has_module_model = (args.torch_module is not None and args.torch_module != ''
+                        and args.model_name is not None and args.model_name != '')
+
+    if not has_model_path and not has_module_model:
+        raise ValueError(
+            'For PyTorch conversion, you must specify either model_path '
+            'or torch_module and model_name',
+        )
+
+    if has_model_path and has_module_model:
+        raise ValueError(
+            'Provided incompatible parameters for PyTorch conversion (model_path and torch_module+model_name). '
+            'Please choose only one method.',
+        )
+
+    if has_model_path:
+        if not os.path.exists(args.model):
+            raise FileNotFoundError(f'Model file not found: {args.model}')
+
+        file_type = args.model.split('.')[-1]
+        if file_type != 'pt':
+            _validate_iree_model_args(args)
+    else:
+        if not args.target_backend:
+            raise ValueError('target_backend is required when using conversion from torch module')
+
+    if args.model_weights and args.model_weights != '' and not os.path.exists(args.model_weights):
+        raise FileNotFoundError(f'Model weights not found: {args.model_weights}')
+
+
+def validate_cli_args(args):
+    if args.source_framework == 'onnx':
+        _validate_onnx_args(args)
+    elif args.source_framework == 'pytorch':
+        _validate_pytorch_args(args)
+    else:
+        _validate_iree_model_args(args)
+
+
+def _convert_model_to_mlir(model_path, model_weights, torch_module, model_name, onnx_opset_version,
+                           source_framework, input_shape, output_mlir):
+    dictionary = {
+        'source_framework': source_framework,
+        'model_name': model_name,
+        'model_path': model_path,
+        'model_weights': model_weights,
+        'torch_module': torch_module,
+        'onnx_opset_version': onnx_opset_version,
+        'input_shape': input_shape,
+        'output_mlir': output_mlir,
+    }
+    converter = IREEConverter.get_converter(dictionary)
+    converter.convert_to_mlir()
+    return
+
+
+def _compile_mlir(mlir_path, target_backend, opt_level, extra_compile_args):
+    try:
+        log.info('Starting model compilation')
+        return IREECompiler.compile_model(mlir_path, target_backend, opt_level, extra_compile_args)
+    except Exception as e:
+        log.error(f'Failed to compile MLIR: {e}')
+        raise
+
+
+def _load_model_buffer(model_path, target_backend, opt_level, extra_compile_args):
+    if not os.path.exists(model_path):
+        raise FileNotFoundError(f'Model file not found: {model_path}')
+
+    file_type = model_path.split('.')[-1]
+
+    if file_type == 'mlir':
+        if target_backend is None:
+            raise ValueError('target_backend is required for MLIR compilation')
+        vmfb_buffer = _compile_mlir(model_path, target_backend, opt_level, extra_compile_args)
+    elif file_type == 'vmfb':
+        with open(model_path, 'rb') as f:
+            vmfb_buffer = f.read()
+    else:
+        raise ValueError(f'The file type {file_type} is not supported. Supported types: .mlir, .vmfb')
+
+    log.info(f'Successfully loaded model buffer from {model_path}')
+    return vmfb_buffer
+
+
+def _create_iree_context_from_buffer(vmfb_buffer):
+    try:
+        config = ireert.Config('local-task')
+        vm_module = ireert.VmModule.from_flatbuffer(config.vm_instance, vmfb_buffer)
+        context = ireert.SystemContext(config=config)
+        context.add_vm_module(vm_module)
+
+        log.info('Successfully created IREE context from buffer')
+        return context
+
+    except Exception as e:
+        log.error(f'Failed to create IREE context: {e}')
+        raise
+
+
+def load_model(model_path, model_weights, torch_module, model_name, onnx_opset_version,
+               source_framework, input_shape, target_backend, opt_level, extra_compile_args):
+    is_tmp_mlir = False
+    if model_path is None or model_path.split('.')[-1] not in ['vmfb', 'mlir']:
+        with tempfile.NamedTemporaryFile(mode='w+t', delete=False, suffix='.mlir') as temp:
+            output_mlir = temp.name
+            _convert_model_to_mlir(model_path,
+                                   model_weights,
+                                   torch_module,
+                                   model_name,
+                                   onnx_opset_version,
+                                   source_framework,
+                                   input_shape,
+                                   output_mlir)
+        model_path = output_mlir
+        is_tmp_mlir = True
+
+    vmfb_buffer = _load_model_buffer(
+        model_path,
+        target_backend=target_backend,
+        opt_level=opt_level,
+        extra_compile_args=extra_compile_args,
+    )
+
+    if is_tmp_mlir:
+        os.remove(model_path)
+
+    return _create_iree_context_from_buffer(vmfb_buffer)
+
+
+def prepare_output(result, task):
+    if task == 'feedforward':
+        return {}
+    elif task == 'classification':
+        if hasattr(result, 'to_host'):
+            result = result.to_host()
+
+        logits = np.array(result)
+
+        # Apply a numerically stable softmax to convert logits to probabilities
+        max_logits = np.max(logits, axis=-1, keepdims=True)
+        exp_logits = np.exp(logits - max_logits)
+        probabilities = exp_logits / np.sum(exp_logits, axis=-1, keepdims=True)
+
+        return {'output': probabilities}
+    else:
+        raise ValueError(f'Unsupported task {task}')
+
+
+def create_dict_for_transformer(args):
+    return {
+        'channel_swap': args.channel_swap,
+        'mean': args.mean,
+        'std': args.std,
+        'norm': args.norm,
+        'layout': args.layout,
+        'input_shape': args.input_shape,
+        'batch_size': args.batch_size,
+    }
diff --git a/src/inference/transformer.py b/src/inference/transformer.py
index 2d85d0c06..42a27eaaf 100644
--- a/src/inference/transformer.py
+++ b/src/inference/transformer.py
@@ -368,3 +368,71 @@ def transform_images(self, images, shape, element_type, *args):
 
 class ExecuTorchTransformer(TVMTransformer):
     pass
+
+
+class IREETransformer(Transformer):
+    def __init__(self, converting):
+        self._converting = converting
+
+    def __set_norm(self, image):
+        if self._converting.get('norm', False):
+            image = image.astype(np.float32) / 255.0
+        return image
+
+    def __set_channel_swap(self, image):
+        channel_swap = self._converting.get('channel_swap')
+        if channel_swap is not None:
+            image = image[:, :, channel_swap]
+        return image
+
+    def __set_mean(self, image):
+        mean = self._converting.get('mean')
+        if mean is not None and len(mean) == 3:
+            image[:, :, 0] -= mean[0]
+            image[:, :, 1] -= mean[1]
+            image[:, :, 2] -= mean[2]
+        return image
+
+    def __set_std(self, image):
+        std = self._converting.get('std')
+        if std is not None and len(std) == 3:
+            image[:, :, 0] /= std[0]
+            image[:, :, 1] /= std[1]
+            image[:, :, 2] /= std[2]
+        return image
+
+    def __set_layout(self, image):
+        layout = self._converting['layout']
+        if layout is not None:
+            layout = LAYER_LAYOUT_TO_IMAGE[layout]
+            image = np.expand_dims(image, 0).transpose(layout)
+        return image
+
+    def __bgr_to_rgb(self, image):
+        return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+
+    def _transform(self, image):
+        transformed_image = self.__bgr_to_rgb(image)
+        transformed_image = self.__set_norm(transformed_image)
+        transformed_image = self.__set_channel_swap(transformed_image)
+        transformed_image = self.__set_mean(transformed_image)
+        transformed_image = self.__set_std(transformed_image)
+        transformed_image = self.__set_layout(transformed_image)
+        return transformed_image
+
+    def transform_images(self, images, shape, element_type, *args):
+        dataset_size = images.shape[0]
+        new_shape = [dataset_size] + shape[1:]
+        transformed_images = np.zeros(shape=new_shape, dtype=element_type)
+        for i in range(dataset_size):
+            transformed_images[i] = self._transform(images[i])
+        return transformed_images
+
+    def get_shape_in_chw_order(self, shape, *args):
+        layout = self._converting.get('layout', 'NHWC')
+        if layout == 'NHWC':
+            return shape[3], shape[1], shape[2]
+        elif layout == 'NCHW':
+            return shape[1], shape[2], shape[3]
+        else:
+            return shape[1:]
diff --git a/src/model_converters/README.md b/src/model_converters/README.md
index 7d7dea499..28679735c 100644
--- a/src/model_converters/README.md
+++ b/src/model_converters/README.md
@@ -16,6 +16,7 @@ format from TensorFlow and ONNX formats.
 
 - `tvm_converter` contains converter and compiler to the TVM format.
 
+- `iree_converter` contains tools to convert ONNX or PyTorch models to IREE MLIR and compile them to VMFB binaries.
 
 ## An overview of existing model converters
diff --git a/src/model_converters/iree_converter/README.md b/src/model_converters/iree_converter/README.md
new file mode 100644
index 000000000..604265df5
--- /dev/null
+++ b/src/model_converters/iree_converter/README.md
@@ -0,0 +1,119 @@
+# Conversion to the IREE format
+
+The IREE converter supports conversion to the IREE MLIR format from the ONNX and PyTorch formats.
+
+The IREE compiler supports compilation from the `.mlir` format to the `.vmfb` format for deployment on various backends.
+
+## IREE converter usage
+
+Basic usage of the script:
+
+```sh
+iree_converter.py --source_framework <source framework> \
+                  --model_name <model name> \
+                  --model <path to model file> \
+                  --weights <path to weights file> \
+                  --torch_module <torch module> \
+                  --input_shape <input shape> \
+                  --onnx_opset_version <onnx opset version> \
+                  --output_mlir <path to output mlir file>
+```
+
+This script converts a model from `<source framework>` to the IREE MLIR format.
+
+### IREE converter parameters
+
+- `-f / --source_framework` is a source framework where the model was trained. Required. Choices: `onnx`, `pytorch`.
+- `-mn / --model_name` is a model name. Required for PyTorch models loaded from a module.
+- `-m / --model` is a path to an `.onnx` or `.pt` file with a trained model.
+- `-w / --weights` is a path to a `.pth` file with trained weights for PyTorch models.
+- `-tm / --torch_module` is a module with the model architecture for PyTorch models. Default: `torchvision.models`.
+- `-is / --input_shape` is an input shape in the format BxCxHxW, where B is a batch size, C is an input tensor number of channels, H is an input tensor height, W is an input tensor width. Required for PyTorch models.
+- `--onnx_opset_version` is an ONNX opset version for ONNX models. Default: `18`.
+- `-o / --output_mlir` is a path to save the MLIR file. Required.
+
+### Parameter combinations
+
+#### For ONNX models:
+
+- Required: `--source_framework onnx`, `--model <path to model>`, `--output_mlir <path to output mlir>`
+- Optional: `--onnx_opset_version` (default: 18; the converter validates that the value is set, so keep the default or override it explicitly)
+
+#### For PyTorch models:
+
+Two loading methods are supported (mutually exclusive):
+
+1. From file:
+- Required: `--source_framework pytorch`, `--model <path to model>`, `--input_shape B C H W`, `--output_mlir <path to output mlir>`
+- Optional: `--model_name <model name>` (used only for logging), `--weights <path to weights>`
+1. From module:
+- Required: `--source_framework pytorch`, `--model_name <model name>`, `--torch_module <module>`, `--input_shape B C H W`, `--output_mlir <path to output mlir>`
+- Optional: `--weights <path to weights>`
+
+> **Note:** `--model` and the pair (`--torch_module`, `--model_name`) are mutually exclusive. Passing both at the same time will raise a validation error (`converter.py` enforces the rule). Likewise, `--input_shape` is only validated for PyTorch conversions, so you can omit it for ONNX.
+
+### Examples of usage
+
+ONNX model conversion ([source of the model efficientnet-b0.onnx](https://github.com/onnx/models/blob/main/Computer_Vision/efficientnet_b0_Opset17_timm/efficientnet_b0_Opset17.onnx)):
+
+```sh
+python3 iree_converter.py -f onnx -m efficientnet-b0.onnx \
+    --onnx_opset_version 18 \
+    -o ./output/efficientnet-b0.mlir
+```
+
+PyTorch model from a file (`.pt` can be created using [tutorial](https://docs.pytorch.org/docs/main/notes/serialization.html#saving-and-loading-torch-nn-modules)):
+
+```sh
+python3 iree_converter.py -f pytorch -m resnet50.pt \
+    -is 1 3 224 224 \
+    -o ./output/resnet50.mlir
+```
+
+PyTorch model from [torchvision](https://docs.pytorch.org/vision/main/models.html) with pretrained weights:
+
+```sh
+python3 iree_converter.py -f pytorch -mn resnet50 \
+    -tm torchvision.models \
+    -is 1 3 224 224 \
+    -o ./output/resnet50.mlir
+```
+
+PyTorch model with custom weights:
+
+```sh
+python3 iree_converter.py -f pytorch -mn resnet50 \
+    -tm torchvision.models \
+    -w ./weights/resnet50-custom.pth \
+    -is 1 3 224 224 \
+    -o ./output/resnet50-custom.mlir
+```
+
+## IREE compiler usage
+
+Basic usage of the script:
+
+```sh
+iree_compiler.py --mlir <path to mlir file> \
+                 --target_backend <target backend> \
+                 --opt_level <optimization level> \
+                 --output_file <path to output file> \
+                 [--extra_args <list of extra arguments>]
+```
+
+This script compiles a model from the `.mlir` format to a deployable binary format for the specified target backend.
+
+### IREE compiler parameters
+
+- `-m / --mlir` is a path to an `.mlir` file with a model. Required.
+- `-tb / --target_backend` is a target backend for compilation. Required. Examples: `llvm-cpu`, `cuda`, `vulkan`, `vmvx`.
+- `--opt_level` is an optimization level of the compilation. Choices: `0`, `1`, `2`, `3`. Default: `2`.
+- `-o / --output_file` is a path to save the compiled model. Required.
+- `--extra_args` is a list of extra arguments for compilation. Optional.
+
+### Supported target backends
+
+- `llvm-cpu` - CPU execution using LLVM.
+- `cuda` - NVIDIA GPU execution using CUDA.
+- `vulkan` - GPU execution using the Vulkan API.
+- `vmvx` - portable VM bytecode execution.
+- `metal` - Apple GPU execution using Metal.
+- `rocm` - AMD GPU execution using ROCm.
+
+### Examples of usage
+
+```sh
+python3 iree_compiler.py -m ./models/resnet50.mlir \
+    -tb llvm-cpu \
+    --opt_level 2 \
+    -o ./compiled/resnet50-cpu.vmfb
+```
+
+### Using extra arguments
+
+The `--extra_args` parameter allows passing additional compilation flags:
+
+```sh
+python3 iree_compiler.py -m model.mlir -tb llvm-cpu -o output.vmfb \
+    --extra_args --iree-llvmcpu-target-triple=x86_64-linux-gnu
+```
diff --git a/src/model_converters/iree_converter/__init__.py b/src/model_converters/iree_converter/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/src/model_converters/iree_converter/iree_auxiliary/__init__.py b/src/model_converters/iree_converter/iree_auxiliary/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/src/model_converters/iree_converter/iree_auxiliary/compiler.py b/src/model_converters/iree_converter/iree_auxiliary/compiler.py
new file mode 100644
index 000000000..3792a6abd
--- /dev/null
+++ b/src/model_converters/iree_converter/iree_auxiliary/compiler.py
@@ -0,0 +1,10 @@
+import os
+from iree.compiler.tools import compile_str, compile_file
+
+
+class IREECompiler:
+    @staticmethod
+    def compile_model(mlir, target, opt_level, extra_args, output_file=None):
+        extra_args.append(f'--iree-opt-level=O{opt_level}')
+        # Accept either a path to an .mlir file or an MLIR string.
+        compile_func = compile_file if os.path.isfile(mlir) else compile_str
+        return compile_func(mlir, target_backends=[target], extra_args=extra_args, output_file=output_file)
diff --git a/src/model_converters/iree_converter/iree_auxiliary/converter.py b/src/model_converters/iree_converter/iree_auxiliary/converter.py
new file mode 100644
index 000000000..db5c7feec
--- /dev/null
+++ b/src/model_converters/iree_converter/iree_auxiliary/converter.py
@@ -0,0 +1,39 @@
+import abc
+import sys
+from pathlib import Path
+
+sys.path.append(str(Path(__file__).resolve().parent.parent.parent.parent.joinpath('utils')))
+from logger_conf import configure_logger  # noqa: E402
+
+log = configure_logger()
+
+
+class IREEConverter(metaclass=abc.ABCMeta):
+    def __init__(self, args):
+        self.model_name = args.get('model_name', None)
+        self.output_mlir = args.get('output_mlir', None)
+        self.log = log
+
+    @abc.abstractmethod
+    def _convert_model_from_framework(self):
+        pass
+
+    @property
+    @abc.abstractmethod
+    def source_framework(self):
+        pass
+
+    @staticmethod
+    def get_converter(args):
+        framework = args['source_framework'].lower()
+        if framework == 'onnx':
+            from onnx_format import IREEConverterONNXFormat
+            return IREEConverterONNXFormat(args)
+        elif framework == 'pytorch':
+            from pytorch_format import IREEConverterPyTorchFormat
+            return IREEConverterPyTorchFormat(args)
+        raise ValueError(f'Unsupported source framework: {framework}')
+
+    def convert_to_mlir(self):
+        self.log.info(f'Get IREE MLIR for {self.model_name} from {self.source_framework} framework')
+        self._convert_model_from_framework()
+        return
diff --git a/src/model_converters/iree_converter/iree_auxiliary/onnx_format.py b/src/model_converters/iree_converter/iree_auxiliary/onnx_format.py
new file mode 100644
index 000000000..0de82c784
--- /dev/null
+++ b/src/model_converters/iree_converter/iree_auxiliary/onnx_format.py
@@ -0,0 +1,38 @@
+import subprocess
+import os
+from converter import IREEConverter
+
+
+class IREEConverterONNXFormat(IREEConverter):
+    def __init__(self, args):
+        super().__init__(args)
+        self.model_path = args.get('model_path', None)
+        self.onnx_opset_version = args.get('onnx_opset_version', None)
+        self._validate_arguments()
+
+    @property
+    def source_framework(self):
+        return 'ONNX'
+
+    def _validate_arguments(self):
+        if self.model_path is None or self.model_path == '':
+            raise ValueError('The model_path parameter is required for ONNX conversion.')
+
+        if not os.path.exists(self.model_path):
+            raise FileNotFoundError(f'Model file not found: {self.model_path}')
+
+        if self.onnx_opset_version is None:
+            raise ValueError('The onnx_opset_version parameter is required for ONNX conversion.')
+
+    def _convert_model_from_framework(self):
+        import_args = [
+            'iree-import-onnx',
+            self.model_path,
+            '--opset-version',
+            str(self.onnx_opset_version),
+            '-o',
+            self.output_mlir,
+        ]
+        # Run the importer directly (no shell) and surface failures instead of silently ignoring them.
+        result = subprocess.run(import_args, capture_output=True, text=True)
+        if result.returncode != 0:
+            raise RuntimeError(f'iree-import-onnx failed: {result.stderr}')
+        return
diff --git a/src/model_converters/iree_converter/iree_auxiliary/pytorch_format.py b/src/model_converters/iree_converter/iree_auxiliary/pytorch_format.py
new file mode 100644
index 000000000..787fef036
--- /dev/null
+++ b/src/model_converters/iree_converter/iree_auxiliary/pytorch_format.py
@@ -0,0 +1,86 @@
+import importlib
+import os
+from converter import IREEConverter
+
+
+class IREEConverterPyTorchFormat(IREEConverter):
+    def __init__(self, args):
+        super().__init__(args)
+        self.torch = importlib.import_module('torch')
+        self.aot = importlib.import_module('iree.turbine.aot')
+        self.model_path = args.get('model_path', None)
+        self.model_weights = args.get('model_weights', None)
+        self.module = args.get('torch_module', None)
+        self.input_shape = args.get('input_shape', None)
+        self._validate_arguments()
+
+    @property
+    def source_framework(self):
+        return 'PyTorch'
+
+    def _validate_arguments(self):
+        if self.input_shape is None:
+            raise ValueError('The input_shape parameter is required for PyTorch conversion.')
+
+        # Check load methods:
+        # 1. model_path (load from file)
+        # 2. module + model_name (load from torch module)
+        has_model_path = self.model_path is not None and self.model_path != ''
+        has_module_model = (self.module is not None
+                            and self.module != ''
+                            and self.model_name is not None
+                            and self.model_name != '')
+
+        if not has_model_path and not has_module_model:
+            raise ValueError(
+                'For PyTorch conversion, you must specify either model_path '
+                'or torch_module and model_name',
+            )
+
+        if has_model_path and has_module_model:
+            raise ValueError(
+                'Provided incompatible parameters for PyTorch conversion (model_path and torch_module+model_name). '
+                'Please choose only one method.',
+            )
+
+        if has_model_path and not os.path.exists(self.model_path):
+            raise FileNotFoundError(f'Model file not found: {self.model_path}')
+
+        if (self.model_weights is not None and self.model_weights != ''
+                and not os.path.exists(self.model_weights)):
+            raise FileNotFoundError(f'Model weights not found: {self.model_weights}')
+
+    def __get_model_from_path(self):
+        self.log.info(f'Loading model from path {self.model_path}')
+        file_type = self.model_path.split('.')[-1]
+        supported_extensions = ['pt']
+        if file_type not in supported_extensions:
+            raise ValueError(f'The file type {file_type} is not supported')
+        model = self.torch.load(self.model_path)
+        model.eval()
+        return model
+
+    def __get_model_from_module(self):
+        self.log.info(f'Loading model {self.model_name} from module')
+        model_cls = importlib.import_module(self.module).__getattribute__(self.model_name)
+        if self.model_weights is None or self.model_weights == '':
+            self.log.info('Loading pretrained model')
+            model = model_cls(weights='DEFAULT')
+        else:
+            self.log.info(f'Loading model with weights from file {self.model_weights}')
+            model = model_cls()
+            # Weights are loaded on CPU; the exported MLIR is device-agnostic.
+            checkpoint = self.torch.load(self.model_weights, map_location='cpu')
+            model.load_state_dict(checkpoint, strict=False)
+        model.eval()
+        return model
+
+    def _convert_model_from_framework(self):
+        # model_path and (torch_module, model_name) are mutually exclusive,
+        # so the loading method is picked by the presence of model_path.
+        if self.model_path:
+            model = self.__get_model_from_path()
+        else:
+            model = self.__get_model_from_module()
+        example_arg = self.torch.randn(*self.input_shape)
+        export_output = self.aot.export(model, example_arg)
+        export_output.save_mlir(self.output_mlir)
+        return
diff --git a/src/model_converters/iree_converter/iree_compiler.py b/src/model_converters/iree_converter/iree_compiler.py
new file mode 100644
index 000000000..ad8f72b0a
--- /dev/null
+++ b/src/model_converters/iree_converter/iree_compiler.py
@@ -0,0 +1,54 @@
+import argparse
+import sys
+import os
+import traceback
+from pathlib import Path
+from iree_auxiliary.compiler import IREECompiler
+
+sys.path.append(str(Path(__file__).resolve().parent.parent.parent))
+from utils.logger_conf import configure_logger  # noqa: E402
+
+log = configure_logger()
+
+
+def cli_argument_parser():
+    parser = argparse.ArgumentParser()
+    parser.add_argument('-m', '--mlir',
+                        help='Path to an .mlir file with a model.',
+                        required=True,
+                        type=str)
+    parser.add_argument('-tb', '--target_backend',
+                        help='Target backend, for example "llvm-cpu" for CPU.',
+                        required=True,
+                        type=str)
+    parser.add_argument('--opt_level',
+                        help='The optimization level of the compilation.',
+                        type=int,
+                        choices=[0, 1, 2, 3],
+                        default=2)
+    parser.add_argument('--extra_args',
+                        help='The extra arguments for compilation.',
+                        type=str,
+                        nargs=argparse.REMAINDER,
+                        default=[])
+    parser.add_argument('-o', '--output_file',
+                        help='Path to the compiled model.',
+                        required=True,
+                        type=str)
+    args = parser.parse_args()
+    return args
+
+
+def main():
+    args = cli_argument_parser()
+    try:
+        IREECompiler.compile_model(args.mlir, args.target_backend, args.opt_level, args.extra_args, args.output_file)
+        if os.path.exists(args.output_file):
+            print(f'The MLIR has been successfully compiled into {args.output_file}')
+    except Exception:
+        log.error(traceback.format_exc())
+        sys.exit(1)
+
+
+if __name__ == '__main__':
+    sys.exit(main() or 0)
diff --git a/src/model_converters/iree_converter/iree_converter.py b/src/model_converters/iree_converter/iree_converter.py
new file mode 100644
index 000000000..60ac06795
--- /dev/null
+++ b/src/model_converters/iree_converter/iree_converter.py
@@ -0,0 +1,89 @@
+import argparse
+import os
+import sys
+import traceback
+from pathlib import Path
+sys.path.append(str(Path(__file__).resolve().parent.joinpath('iree_auxiliary')))
+from converter import IREEConverter  # noqa: E402
+
+sys.path.append(str(Path(__file__).resolve().parent.parent.parent))
+from utils.logger_conf import configure_logger  # noqa: E402
+
+log = configure_logger()
+
+
+def cli_argument_parser():
+    parser = argparse.ArgumentParser()
+    parser.add_argument('-f', '--source_framework',
+                        help='Source model framework.',
+                        required=True,
+                        type=str,
+                        choices=['onnx', 'pytorch'],
+                        dest='source_framework')
+    parser.add_argument('-mn', '--model_name',
+                        help='Model name.',
+                        type=str,
+                        dest='model_name')
+    parser.add_argument('-m', '--model',
+                        help='Path to an .onnx or .pt file with a trained model.',
+                        type=str,
+                        dest='model_path')
+    parser.add_argument('-w', '--weights',
+                        help='Path to a .pth file with trained weights.',
+                        type=str,
+                        dest='model_weights')
+    parser.add_argument('-tm', '--torch_module',
+                        help='Torch module with the model architecture.',
+                        default='torchvision.models',
+                        type=str,
+                        dest='torch_module')
+    parser.add_argument('--onnx_opset_version',
+                        help='ONNX opset version used to import the model.',
+                        type=int,
+                        default=18,
+                        dest='onnx_opset_version')
+    parser.add_argument('-is', '--input_shape',
+                        help='Input shape BxCxHxW, B is a batch size, '
+                             'C is an input tensor number of channels, '
+                             'H is an input tensor height, '
+                             'W is an input tensor width.',
+                        type=int,
+                        nargs=4,
+                        dest='input_shape')
+    parser.add_argument('-o', '--output_mlir',
+                        help='Path to save the MLIR.',
+                        required=True,
+                        type=str,
+                        dest='output_mlir')
+    args = parser.parse_args()
+    return args
+
+
+def create_dict_for_converter(args):
+    dictionary = {
+        'source_framework': args.source_framework,
+        'model_name': args.model_name,
+        'model_path': args.model_path,
+        'model_weights': args.model_weights,
+        'torch_module': args.torch_module,
+        'onnx_opset_version': args.onnx_opset_version,
+        'input_shape': args.input_shape,
+        'output_mlir': args.output_mlir,
+    }
+    return dictionary
+
+
+def main():
+    args = cli_argument_parser()
+    try:
+        converter = IREEConverter.get_converter(create_dict_for_converter(args))
+        converter.convert_to_mlir()
+        if os.path.exists(args.output_mlir):
+            print(f'The MLIR has been successfully saved into {args.output_mlir}')
+    except Exception:
+        log.error(traceback.format_exc())
+        sys.exit(1)
+
+
+if __name__ == '__main__':
+    sys.exit(main() or 0)
diff --git a/src/quantization/process.py b/src/quantization/process.py
index 506c97a9e..3bf5f5787 100644
--- a/src/quantization/process.py
+++ b/src/quantization/process.py
@@ -74,7 +74,7 @@ def execute(self, idx):
         command_line = self.__fill_command_line()
         if command_line == '':
             self.__log.error('Command line is empty')
-        self.__log.info(f'Start quantization model #{idx+1}!')
+        self.__log.info(f'Start quantization model #{idx + 1}!')
         self.__log.info(f'Command line is : {command_line}')
         self._status, self._output = self._executor.execute_process(command_line)
         if type(self._output) is not list:
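
Under the hood, the pieces added by this patch compose into an export → compile → run flow. A condensed sketch of the PyTorch-from-module path (a simplified illustration, not the exact implementation; the calls mirror `pytorch_format.py`, `compiler.py`, and `iree_auxiliary.py` above, and `torchvision` is assumed to be installed):

```python
import iree.runtime as ireert
import iree.turbine.aot as aot
import torch
import torchvision.models as models
from iree.compiler.tools import compile_file

# 1. Export the PyTorch model to MLIR via the iree-turbine AOT exporter.
model = models.resnet50(weights='DEFAULT').eval()
example_arg = torch.randn(1, 3, 224, 224)
aot.export(model, example_arg).save_mlir('resnet50.mlir')

# 2. Compile the MLIR for the selected backend.
vmfb = compile_file('resnet50.mlir', target_backends=['llvm-cpu'],
                    extra_args=['--iree-opt-level=O2'])

# 3. Create a runtime context and call the exported function.
config = ireert.Config('local-task')
vm_module = ireert.VmModule.from_flatbuffer(config.vm_instance, vmfb)
context = ireert.SystemContext(config=config)
context.add_vm_module(vm_module)
main = context.modules.module['main']
output = main(ireert.asdevicearray(config.device, example_arg.numpy())).to_host()
```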
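
The command-line tools chain the same way. A possible end-to-end run of the three new scripts (the paths and the model file are illustrative; `main_graph` is the function name used for ONNX imports in the examples above):

```bash
# 1. Convert the ONNX model to IREE MLIR.
python3 src/model_converters/iree_converter/iree_converter.py \
    -f onnx -m resnet50.onnx --onnx_opset_version 18 -o resnet50.mlir

# 2. Compile the MLIR for the CPU backend.
python3 src/model_converters/iree_converter/iree_compiler.py \
    -m resnet50.mlir -tb llvm-cpu --opt_level 2 -o resnet50.vmfb

# 3. Run classification on the compiled binary.
python3 src/inference/inference_iree.py \
    -m resnet50.vmfb -fn main_graph -t classification \
    -i data/ -is 1 3 224 224 \
    --norm --mean 0.485 0.456 0.406 --std 0.229 0.224 0.225 \
    --layout NCHW --channel_swap 2 1 0
```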