diff --git a/README.md b/README.md
index e9ef49bd7..735178fda 100644
--- a/README.md
+++ b/README.md
@@ -33,6 +33,7 @@ DLI supports inference using the following frameworks:
- [ncnn][ncnn] (Python API).
- [PaddlePaddle][PaddlePaddle] (Python API).
- [ExecuTorch][executorch] (C++ and Python APIs)
+- [IREE][iree] (Python API).
More information about DLI is available on the web-site
([here][dli-ru-web-page] (in Russian)
@@ -105,6 +106,7 @@ Please consider citing the following papers.
for TensorFlow.
- `TensorFlowLite` is a directory of Dockerfiles for TensorFlow Lite.
- `TVM` is a directory of Dockerfiles for Apache TVM.
+ - `IREE` is a directory of Dockerfiles for IREE.
- `docs` directory contains auxiliary documentation. Please, find
complete documentation at the [Wiki page][dli-wiki].
@@ -158,6 +160,9 @@ Please consider citing the following papers.
- [`validation_results_tvm.md`](results/validation/validation_results_tvm.md)
is a table that confirms correctness of inference implementation
based on Apache TVM for several public models.
+ - [`validation_results_iree.md`](results/validation/validation_results_iree.md)
+ is a table that confirms correctness of inference implementation
+ based on IREE for several public models.
- [`mxnet_models_checklist.md`](results/mxnet_models_checklist.md) contains a list
of deep models inferred by MXNet checked in the DLI benchmark.
@@ -282,6 +287,7 @@ Report questions, issues and suggestions, using:
[ncnn]: https://github.com/Tencent/ncnn
[PaddlePaddle]: https://www.paddlepaddle.org.cn/en
[executorch]: https://pytorch.org/executorch-overview
+[iree]: https://iree.dev
[benchmark-app]: https://github.com/openvinotoolkit/openvino/tree/master/samples/cpp/benchmark_app
[dli-ru-web-page]: http://hpc-education.unn.ru/dli-ru
[dli-web-page]: http://hpc-education.unn.ru/dli
diff --git a/docker/IREE/Dockerfile b/docker/IREE/Dockerfile
new file mode 100644
index 000000000..c9bf64d24
--- /dev/null
+++ b/docker/IREE/Dockerfile
@@ -0,0 +1,19 @@
+FROM ubuntu_for_dli
+
+# Install IREE
+ARG IREE_VERSION=3.8.0
+RUN python3 -m pip install iree-base-compiler==${IREE_VERSION} iree-base-runtime==${IREE_VERSION} iree-turbine==${IREE_VERSION}
+
+# Install dependencies
+RUN python3 -m pip install opencv-python numpy
+
+# Install onnx for model conversion
+ARG ONNX_VERSION=1.19.1
+RUN python3 -m pip install onnx==${ONNX_VERSION}
+
+# Install torch for model conversion
+ARG TORCH_VERSION=2.9.1
+ARG TORCHVISION_VERSION=0.24.1
+RUN python3 -m pip install torch==${TORCH_VERSION} torchvision==${TORCHVISION_VERSION}
+
+WORKDIR /tmp/
diff --git a/requirements_frameworks.txt b/requirements_frameworks.txt
index 0334d3635..2684c5e55 100644
--- a/requirements_frameworks.txt
+++ b/requirements_frameworks.txt
@@ -18,8 +18,12 @@ dglgo==0.0.2
tflite
paddleslim==2.6.0
-paddlepaddle==2.6.0
+paddlepaddle==2.6.2
--extra-index-url https://mirror.baidu.com/pypi/simple
ncnn
spektral==1.3.0
+
+iree-base-compiler==3.8.0
+iree-base-runtime==3.8.0
+iree-turbine==3.8.0
\ No newline at end of file
diff --git a/results/validation/validation_results_iree.md b/results/validation/validation_results_iree.md
new file mode 100644
index 000000000..d4753fb8e
--- /dev/null
+++ b/results/validation/validation_results_iree.md
@@ -0,0 +1,112 @@
+# Validation results for the models inferring using IREE
+
+## Public models
+
+We infer models using the following APIs:
+
+1. IREE, when we load PyTorch models directly from the source format.
+
+ ```bash
+ python inference_iree.py -t classification -is 1 3 224 224 \
+ -mn densenet121 \
+ -tm torchvision.models \
+ -f pytorch \
+ -i data/ \
+ --norm --mean 0.485 0.456 0.406 --std 0.229 0.224 0.225 \
+ -l labels/image_net_synset.txt \
+ --layout NCHW --channel_swap 2 1 0 \
+ -fn main
+ ```
+
+1. IREE, when we load ONNX models directly from the source format.
+
+ ```bash
+ python inference_iree.py -t classification -is 1 3 224 224 \
+ -mn densenet121 \
+ -m densenet121.onnx \
+ -f onnx \
+ --onnx_opset_version 18 \
+ -i data/ \
+ --norm --mean 0.485 0.456 0.406 --std 0.229 0.224 0.225 \
+ -l labels/image_net_synset.txt \
+ --layout NCHW --channel_swap 2 1 0 \
+ -fn main_graph
+ ```
+
+1. PyTorch as source framework for reference.
+
+ ```bash
+ python inference_pytorch.py -t classification -is [1,3,224,224] \
+ --input_names data \
+ -mn densenet121 \
+ -mm torchvision.models \
+ -i data/ \
+ --mean [123.675,116.28,103.53] \
+ --input_scale [58.395,57.12,57.375] \
+ -l labels/image_net_synset.txt
+ ```
+
+### Notes
+
+1. Models in the ONNX format are loaded from the [onnx/models][onnx-models] repository.
+1. The model `squeezenet1.1` is missing from the [onnx/models][onnx-models] repository, so the corresponding column is marked with `-`.
+
+### Image classification
+
+#### Test image #1
+
+Data source: [ImageNet][imagenet]
+
+Image resolution: 709 x 510
+
+
+

+
+
+Model | Source Framework | Python API (source framework) | Python API (IREE, PyTorch) | Python API (IREE, ONNX) |
+-|-|-|-|-|
+densenet-121 | PyTorch | 0.9525911 Granny Smith<br>0.0132309 orange<br>0.0123391 lemon<br>0.0028140 banana<br>0.0020238 piggy bank, penny bank | 0.9523347 Granny Smith<br>0.0132272 orange<br>0.0125170 lemon<br>0.0027910 banana<br>0.0020333 piggy bank, penny bank | 0.9523349 Granny Smith<br>0.0132271 orange<br>0.0125169 lemon<br>0.0027909 banana<br>0.0020333 piggy bank, penny bank |
+efficientnet-b0 | PyTorch | 0.3421609 Granny Smith<br>0.1089311 piggy bank, penny bank<br>0.0693323 teapot<br>0.0249018 vase<br>0.0205339 saltshaker, salt shaker | 0.3421628 Granny Smith<br>0.1089310 piggy bank, penny bank<br>0.0693315 teapot<br>0.0249016 vase<br>0.0205339 saltshaker, salt shaker | 0.3421622 Granny Smith<br>0.1089308 piggy bank, penny bank<br>0.0693314 teapot<br>0.0249017 vase<br>0.0205338 saltshaker, salt shaker |
+googlenet-v1 | PyTorch | 0.5399834 Granny Smith<br>0.1101810 piggy bank, penny bank<br>0.0232574 vase<br>0.0213452 pitcher, ewer<br>0.0198953 bell pepper | 0.5432554 Granny Smith<br>0.1103971 piggy bank, penny bank<br>0.0232568 vase<br>0.0213901 pitcher, ewer<br>0.0196196 bell pepper | 0.5432543 Granny Smith<br>0.1103970 piggy bank, penny bank<br>0.0232569 vase<br>0.0213901 pitcher, ewer<br>0.0196196 bell pepper |
+resnet-50 | PyTorch | 0.9280675 Granny Smith<br>0.0129466 orange<br>0.0058861 lemon<br>0.0041993 necklace<br>0.0025445 banana | 0.9278086 Granny Smith<br>0.0129410 orange<br>0.0059573 lemon<br>0.0042141 necklace<br>0.0025712 banana | 0.4216066 Granny Smith<br>0.0661015 dumbbell<br>0.0348192 barbell<br>0.0049673 orange<br>0.0045203 syringe |
+squeezenet1.1 | PyTorch | 0.5913458 piggy bank, penny bank<br>0.0682889 Granny Smith<br>0.0610993 lemon<br>0.0596012 necklace<br>0.0492096 bucket, pail | 0.5895361 piggy bank, penny bank<br>0.0677933 Granny Smith<br>0.0610654 necklace<br>0.0610450 lemon<br>0.0490914 bucket, pail | - |
+
+#### Test image #2
+
+Data source: [ImageNet][imagenet]
+
+Image resolution: 500 x 500
+
+
+

+
+
+Model | Source Framework | Python API (source framework) | Python API (IREE, PyTorch) | Python API (IREE, ONNX) |
+-|-|-|-|-|
+densenet-121 | PyTorch | 0.9847536 junco, snowbird<br>0.0068679 chickadee<br>0.0034511 brambling, Fringilla montifringilla<br>0.0015685 water ouzel, dipper<br>0.0012343 indigo bunting, indigo finch, indigo bird, Passerina cyanea | 0.9841590 junco, snowbird<br>0.0072199 chickadee<br>0.0034962 brambling, Fringilla montifringilla<br>0.0016226 water ouzel, dipper<br>0.0012858 indigo bunting, indigo finch, indigo bird, Passerina cyanea | 0.9841590 junco, snowbird<br>0.0072199 chickadee<br>0.0034962 brambling, Fringilla montifringilla<br>0.0016226 water ouzel, dipper<br>0.0012858 indigo bunting, indigo finch, indigo bird, Passerina cyanea |
+efficientnet-b0 | PyTorch | 0.8903497 junco, snowbird<br>0.0147084 water ouzel, dipper<br>0.0074830 chickadee<br>0.0044766 brambling, Fringilla montifringilla<br>0.0027406 goldfinch, Carduelis carduelis | 0.8903519 junco, snowbird<br>0.0147081 water ouzel, dipper<br>0.0074829 chickadee<br>0.0044765 brambling, Fringilla montifringilla<br>0.0027406 goldfinch, Carduelis carduelis | 0.8903498 junco, snowbird<br>0.0147084 water ouzel, dipper<br>0.0074830 chickadee<br>0.0044766 brambling, Fringilla montifringilla<br>0.0027406 goldfinch, Carduelis carduelis |
+googlenet-v1 | PyTorch | 0.6449553 junco, snowbird<br>0.0752306 chickadee<br>0.0480572 brambling, Fringilla montifringilla<br>0.0298399 goldfinch, Carduelis carduelis<br>0.0126128 house finch, linnet, Carpodacus mexicanus | 0.6461055 junco, snowbird<br>0.0772564 chickadee<br>0.0468782 brambling, Fringilla montifringilla<br>0.0295897 goldfinch, Carduelis carduelis<br>0.0123322 house finch, linnet, Carpodacus mexicanus | 0.6461049 junco, snowbird<br>0.0772565 chickadee<br>0.0468783 brambling, Fringilla montifringilla<br>0.0295897 goldfinch, Carduelis carduelis<br>0.0123323 house finch, linnet, Carpodacus mexicanus |
+resnet-50 | PyTorch | 0.9809760 junco, snowbird<br>0.0049167 goldfinch, Carduelis carduelis<br>0.0036987 chickadee<br>0.0036697 water ouzel, dipper<br>0.0029304 brambling, Fringilla montifringilla | 0.9805012 junco, snowbird<br>0.0049154 goldfinch, Carduelis carduelis<br>0.0039196 chickadee<br>0.0038098 water ouzel, dipper<br>0.0028983 brambling, Fringilla montifringilla | 0.3845567 junco, snowbird<br>0.0091156 water ouzel, dipper<br>0.0054526 chickadee<br>0.0026206 indigo bunting, indigo finch, indigo bird, Passerina cyanea<br>0.0023612 brambling, Fringilla montifringilla |
+squeezenet1.1 | PyTorch | 0.9609295 junco, snowbird<br>0.0248581 chickadee<br>0.0042597 brambling, Fringilla montifringilla<br>0.0037157 goldfinch, Carduelis carduelis<br>0.0033528 ruffed grouse, partridge, Bonasa umbellus | 0.9614577 junco, snowbird<br>0.0250981 chickadee<br>0.0040701 brambling, Fringilla montifringilla<br>0.0035156 goldfinch, Carduelis carduelis<br>0.0030858 ruffed grouse, partridge, Bonasa umbellus | - |
+
+#### Test image #3
+
+Data source: [ImageNet][imagenet]
+
+Image resolution: 333 x 500
+
+
+

+
+
+Model | Source Framework | Python API (source framework) | Python API (IREE, PyTorch) | Python API (IREE, ONNX) |
+-|-|-|-|-|
+densenet-121 | PyTorch | 0.3047960 liner, ocean liner<br>0.1327189 breakwater, groin, groyne, mole, bulwark, seawall, jetty<br>0.1180288 container ship, containership, container vessel<br>0.0794686 drilling platform, offshore rig<br>0.0718431 dock, dockage, docking facility | 0.3022414 liner, ocean liner<br>0.1322474 breakwater, groin, groyne, mole, bulwark, seawall, jetty<br>0.1194614 container ship, containership, container vessel<br>0.0795042 drilling platform, offshore rig<br>0.0723073 dock, dockage, docking facility | 0.3022407 liner, ocean liner<br>0.1322481 breakwater, groin, groyne, mole, bulwark, seawall, jetty<br>0.1194605 container ship, containership, container vessel<br>0.0795041 drilling platform, offshore rig<br>0.0723069 dock, dockage, docking facility |
+efficientnet-b0 | PyTorch | 0.4476882 breakwater, groin, groyne, mole, bulwark, seawall, jetty<br>0.0953832 container ship, containership, container vessel<br>0.0872342 beacon, lighthouse, beacon light, pharos<br>0.0559825 drilling platform, offshore rig<br>0.0441807 liner, ocean liner | 0.4476875 breakwater, groin, groyne, mole, bulwark, seawall, jetty<br>0.0953838 container ship, containership, container vessel<br>0.0872344 beacon, lighthouse, beacon light, pharos<br>0.0559831 drilling platform, offshore rig<br>0.0441806 liner, ocean liner | 0.4476894 breakwater, groin, groyne, mole, bulwark, seawall, jetty<br>0.0953836 container ship, containership, container vessel<br>0.0872341 beacon, lighthouse, beacon light, pharos<br>0.0559827 drilling platform, offshore rig<br>0.0441803 liner, ocean liner |
+googlenet-v1 | PyTorch | 0.1330581 liner, ocean liner<br>0.0796951 drilling platform, offshore rig<br>0.0680323 container ship, containership, container vessel<br>0.0588053 breakwater, groin, groyne, mole, bulwark, seawall, jetty<br>0.0365606 fireboat | 0.1323653 liner, ocean liner<br>0.0796393 drilling platform, offshore rig<br>0.0678083 container ship, containership, container vessel<br>0.0585719 breakwater, groin, groyne, mole, bulwark, seawall, jetty<br>0.0366882 fireboat | 0.1323648 liner, ocean liner<br>0.0796394 drilling platform, offshore rig<br>0.0678085 container ship, containership, container vessel<br>0.0585720 breakwater, groin, groyne, mole, bulwark, seawall, jetty<br>0.0366881 fireboat |
+resnet-50 | PyTorch | 0.4818293 liner, ocean liner<br>0.0992477 breakwater, groin, groyne, mole, bulwark, seawall, jetty<br>0.0687505 container ship, containership, container vessel<br>0.0517874 dock, dockage, docking facility<br>0.0483462 pirate, pirate ship | 0.4759648 liner, ocean liner<br>0.1025407 breakwater, groin, groyne, mole, bulwark, seawall, jetty<br>0.0689996 container ship, containership, container vessel<br>0.0524496 dock, dockage, docking facility<br>0.0473777 pirate, pirate ship | 0.1220204 lifeboat<br>0.0430796 breakwater, groin, groyne, mole, bulwark, seawall, jetty<br>0.0360478 beacon, lighthouse, beacon light, pharos<br>0.0335465 dock, dockage, docking facility<br>0.0251255 liner, ocean liner |
+squeezenet1.1 | PyTorch | 0.4393108 liner, ocean liner<br>0.1895231 container ship, containership, container vessel<br>0.1506845 pirate, pirate ship<br>0.0962459 fireboat<br>0.0199389 drilling platform, offshore rig | 0.4413096 liner, ocean liner<br>0.1931005 container ship, containership, container vessel<br>0.1459103 pirate, pirate ship<br>0.0937753 fireboat<br>0.0198682 drilling platform, offshore rig | - |
+
+
+[imagenet]: http://www.image-net.org
+[onnx-models]: https://github.com/onnx/models/tree/main
diff --git a/src/accuracy_checker/process.py b/src/accuracy_checker/process.py
index 6661be780..df117097f 100644
--- a/src/accuracy_checker/process.py
+++ b/src/accuracy_checker/process.py
@@ -25,7 +25,7 @@ def execute(self, idx):
command_line = self.__fill_command_line()
if command_line == '':
self.__log.error('Command line is empty')
- self.__log.info(f'Start accuracy check for {idx+1} test: {self._test.model.name}')
+ self.__log.info(f'Start accuracy check for {idx + 1} test: {self._test.model.name}')
self.__log.info(f'Command line is : {command_line}')
self._executor.set_target_framework(self._test.framework)
command_line = self._executor.prepare_command_line(self._test, command_line)
diff --git a/src/benchmark/README.md b/src/benchmark/README.md
index c44fd7e9a..c03ac63ef 100644
--- a/src/benchmark/README.md
+++ b/src/benchmark/README.md
@@ -22,6 +22,7 @@ the following frameworks:
- [RKNN][rknn].
- [Spektral][spektral] (Python API).
- [PaddlePaddle][paddlepaddle] (Python API).
+- [IREE][iree] (Python API).
### Implemented algorithm
@@ -274,3 +275,4 @@ pip install openvino_dev[mxnet,caffe,caffe2,onnx,pytorch,tensorflow2]==
+
+`` inside the `` section to specify the path to the module with the model architecture (the value is passed through to the `--torch_module` parameter of the inference script).
### Configuration examples
#### Example configuration for measuring inference performance with the Intel Distribution of OpenVINO Toolkit
@@ -773,6 +788,47 @@
```
+#### Example configuration for measuring inference performance with IREE
+
+```xml
+<?xml version="1.0" encoding="utf-8"?>
+<Tests>
+    <Test>
+        <Model>
+            <Task>classification</Task>
+            <Name>resnet50</Name>
+            <Precision>FP32</Precision>
+            <SourceFramework>onnx</SourceFramework>
+            <ModelPath>/home/user/models/resnet50/resnet50.onnx</ModelPath>
+        </Model>
+        <Dataset>
+            <Name>ImageNet</Name>
+            <Path>/mnt/datasets/ILSVRC2012_img_val</Path>
+        </Dataset>
+        <FrameworkIndependent>
+            <InferenceFramework>IREE</InferenceFramework>
+            <BatchSize>1</BatchSize>
+            <Device>CPU</Device>
+            <IterationCount>20</IterationCount>
+            <TestTimeLimit>60</TestTimeLimit>
+        </FrameworkIndependent>
+        <FrameworkDependent>
+            <FunctionName>main</FunctionName>
+            <InputShape>1 3 224 224</InputShape>
+            <Layout>NCHW</Layout>
+            <Normalize>True</Normalize>
+            <Mean>0.485 0.456 0.406</Mean>
+            <Std>0.229 0.224 0.225</Std>
+            <ChannelSwap>2 1 0</ChannelSwap>
+            <TargetBackend>llvm-cpu</TargetBackend>
+            <OptLevel>3</OptLevel>
+            <OnnxOpsetVersion>17</OnnxOpsetVersion>
+            <ExtraCompileArgs>--iree-llvmcpu-target-cpu-features=host</ExtraCompileArgs>
+        </FrameworkDependent>
+    </Test>
+</Tests>
+```
+
#### Example configuration for measuring inference performance with the RKNN C++ API
```xml
diff --git a/src/inference/README.md b/src/inference/README.md
index da8fe747b..3157f765c 100644
--- a/src/inference/README.md
+++ b/src/inference/README.md
@@ -17,6 +17,7 @@
1. ncnn.
1. PaddlePaddle.
1. Spektral.
+1. IREE.
## Inference of deep models using Inference Engine
@@ -1482,6 +1483,124 @@ python inference_ncnn.py --model \
--batch_size
```
+## Inference of deep models using IREE
+
+#### Script
+
+```bash
+inference_iree.py
+```
+
+#### Common required arguments
+
+- `-fn / --function_name` - name of the function inside the IREE module that is called during inference.
+- `-i / --input` - path to an image or a directory with images
+  (file extensions `.jpg`, `.png`, `.bmp`, etc.).
+- `-is / --input_shape` - dimensions of the network input tensor in the
+  BxCxHxW format, where B is the batch size, C is the number of image channels,
+  H is the image height, W is the image width.
+
+#### The remaining parameters depend on the format in which the model is provided
+
+- **Ready-made IREE module (`.vmfb` or `.mlir`)**
+  - `-m / --model` - path to the model in the `.vmfb` format (a prebuilt binary) or the `.mlir` format (compiled before launch). Required.
+  - `-tb / --target_backend` - target backend for compilation and execution (`llvm-cpu`, `cuda`, `vulkan`, `metal`, `rocm`, `vmvx`, etc.). Default: `llvm-cpu`. Required if the model is in the `.mlir` format.
+
+- **ONNX model**
+  - `--source_framework onnx` - the framework the model is loaded from. Required.
+  - `-m / --model` - path to the model in the `.onnx` format. Required.
+  - `--onnx_opset_version` - ONNX opset version (default: `18`).
+  - `-tb / --target_backend` - target backend for compilation and execution (`llvm-cpu`, `cuda`, `vulkan`, `metal`, `rocm`, `vmvx`, etc.). Default: `llvm-cpu`. Required.
+
+- **PyTorch model from a file**
+  - `--source_framework pytorch` - the framework the model is loaded from. Required.
+  - `-m / --model` - path to the model in the `.pt` format. Required.
+  - `-w / --weights` - path to a `.pth` file with the model weights. Optional.
+  - `-tb / --target_backend` - target backend for compilation and execution (`llvm-cpu`, `cuda`, `vulkan`, `metal`, `rocm`, `vmvx`, etc.). Default: `llvm-cpu`. Required.
+
+- **PyTorch model from a module**
+  - `--source_framework pytorch` - the framework the model is loaded from. Required.
+  - `-tm / --torch_module` - path to a Python module or a relative path
+    to a Python file with the model architecture (for example, `torchvision.models` for the module with [models][torchvision_models]). Required.
+  - `-mn / --model_name` - model name. Required.
+  - `-w / --weights` - path to a `.pth` file with the model weights. Optional.
+  - `-tb / --target_backend` - target backend for compilation and execution (`llvm-cpu`, `cuda`, `vulkan`, `metal`, `rocm`, `vmvx`, etc.). Default: `llvm-cpu`. Required.
+
+#### Optional arguments
+
+- `-b / --batch_size` - number of images processed in a single pass
+  of the network. Default: `1`. The value of this parameter must be equal
+  to B in the `input_shape` parameter.
+- `-t / --task` - task name. The current implementation supports
+  the classification task (`classification`). Default: `feedforward`.
+- `-l / --labels` - path to a JSON file with the list of labels
+  for the classification task. Default: `image_net_labels.json`,
+  which corresponds to the labels of the ImageNet dataset.
+- `-nt / --number_top` - number of the best results printed for the
+  classification task. Default: the top `5` results.
+- `-ni / --number_iter` - number of forward passes through the network.
+  Default: `1` pass.
+- `--raw_output` - run the script without logs. Not set by default.
+- `--time` - time limit in seconds. If both `--time` and `-ni` are set, the longer-running scenario is executed.
+- `--report_path` - path to the `.json` report (default: `src/inference/iree_inference_report.json`).
+- `--layout` - input tensor layout (`NHWC` or `NCHW`, default: `NCHW`).
+- `--norm` - flag to normalize the input image (divides the values by `255` before further processing).
+- `--mean`, `--std`, `--channel_swap` - preprocessing parameters. Defaults: `mean=[0, 0, 0]`, `std=[1, 1, 1]`, `channel_swap=[2, 1, 0]`.
+- `--opt_level` - optimization level used when the model has to be compiled before inference (`0-3`, default: `2`).
+- `--extra_compile_args` - extra compilation flags (must be specified strictly at the end of the command line).
+  ```
+  --extra_compile_args --iree-llvmcpu-target-cpu=cascadelake --iree-llvmcpu-target-triple=x86_64-linux-gnu
+  ```
+
+#### Run examples
+
+**Prebuilt `.vmfb`**
+
+```bash
+python3 inference_iree.py \
+ -m compiled/resnet50.vmfb \
+ -fn main \
+ -i ./data/images \
+ -is 1 3 224 224 \
+ -b 1 -ni 100 \
+ -t classification \
+ -l ./labels/imagenet_synset.txt
+```
+
+**Automatic conversion ONNX -> MLIR -> VMFB**
+
+```bash
+python3 inference_iree.py \
+ --source_framework onnx \
+ -m ./models/efficientnet-b0.onnx \
+ --onnx_opset_version 18 \
+ -fn main_graph \
+ -i ./data/test.jpg \
+ -is 1 3 224 224 \
+ -tb llvm-cpu \
+ --opt_level 3 \
+ --extra_compile_args --iree-llvmcpu-target-cpu-features=host
+```
+
+**Automatic conversion of a PyTorch model from `torchvision`**
+
+```bash
+python3 inference_iree.py \
+ --source_framework pytorch \
+ -mn resnet50 \
+ -tm torchvision.models \
+ -fn main \
+ -t classification \
+ -i ./data/images \
+ -is 1 3 224 224 \
+ -tb llvm-cpu \
+ --mean 123.68 116.78 103.94 \
+ --std 58.40 57.12 57.38
+```
+
+The execution result is the set of the most likely classes the image belongs to.
+
[execution_providers]: https://onnxruntime.ai/docs/execution-providers
[gluon_modelzoo]: https://cv.gluon.ai/model_zoo/index.html
@@ -1492,3 +1611,4 @@ python inference_ncnn.py --model \
[dgl]: https://www.dgl.ai/pages/start.html
[ogb]: https://ogb.stanford.edu/
[tensorflow-gpu]: https://www.tensorflow.org/install/pip
+[iree]: https://iree.dev
diff --git a/src/inference/inference_iree.py b/src/inference/inference_iree.py
new file mode 100644
index 000000000..7ba741b6d
--- /dev/null
+++ b/src/inference/inference_iree.py
@@ -0,0 +1,290 @@
+import argparse
+import sys
+import traceback
+from pathlib import Path
+
+import postprocessing_data as pp
+from inference_tools.loop_tools import loop_inference, get_exec_time
+from io_adapter import IOAdapter
+from io_model_wrapper import IREEModelWrapper
+from reporter.report_writer import ReportWriter
+from transformer import IREETransformer
+from iree_auxiliary import (load_model, create_dict_for_transformer, prepare_output, validate_cli_args)
+
+
+sys.path.append(str(Path(__file__).resolve().parents[1].joinpath('utils')))
+from logger_conf import configure_logger # noqa: E402
+
+log = configure_logger()
+
+try:
+ import iree.runtime as ireert # noqa: E402
+except ImportError as e:
+ log.error(f'IREE import error: {e}')
+ sys.exit(1)
+
+
+def cli_argument_parser():
+ parser = argparse.ArgumentParser()
+ parser.add_argument('-f', '--source_framework',
+ help='Source model framework (required for automatic conversion to MLIR)',
+ type=str,
+ choices=['onnx', 'pytorch'],
+ dest='source_framework')
+ parser.add_argument('-m', '--model',
+ help='Path to source framework model (.onnx, .pt), '
+ 'to file with compiled model (.vmfb) '
+ 'or MLIR (.mlir).',
+ type=str,
+ dest='model')
+ parser.add_argument('-w', '--weights',
+ help='Path to a .pth file with trained weights. '
+ 'Available when source_framework=pytorch.',
+ type=str,
+ dest='model_weights')
+ parser.add_argument('-tm', '--torch_module',
+ help='Torch module with model architecture. '
+ 'Available when source_framework=pytorch.',
+ type=str,
+ dest='torch_module')
+ parser.add_argument('-mn', '--model_name',
+ help='Model name.',
+ type=str,
+ dest='model_name')
+ parser.add_argument('--onnx_opset_version',
+ help='ONNX opset version used for the model import. '
+ 'Available when source_framework=onnx.',
+ type=int,
+ dest='onnx_opset_version')
+ parser.add_argument('-fn', '--function_name',
+ help='IREE module function name to execute.',
+ required=True,
+ type=str,
+ dest='function_name')
+ parser.add_argument('-i', '--input',
+ help='Path to data.',
+ required=True,
+ type=str,
+ nargs='+',
+ dest='input')
+ parser.add_argument('-is', '--input_shape',
+ help='Input shape BxCxHxW, B is a batch size, '
+ 'C is an input tensor number of channels, '
+ 'H is an input tensor height, '
+ 'W is an input tensor width.',
+ required=True,
+ type=int,
+ nargs=4,
+ dest='input_shape')
+ parser.add_argument('-b', '--batch_size',
+ help='Size of the processed pack. '
+ 'Should be the same as B in input_shape argument.',
+ default=1,
+ type=int,
+ dest='batch_size')
+ parser.add_argument('-l', '--labels',
+ help='Labels mapping file.',
+ default=None,
+ type=str,
+ dest='labels')
+ parser.add_argument('-nt', '--number_top',
+ help='Number of top results.',
+ default=5,
+ type=int,
+ dest='number_top')
+ parser.add_argument('-t', '--task',
+ help='Task type. Default: feedforward.',
+ choices=['feedforward', 'classification'],
+ default='feedforward',
+ type=str,
+ dest='task')
+ parser.add_argument('-ni', '--number_iter',
+ help='Number of inference iterations.',
+ default=1,
+ type=int,
+ dest='number_iter')
+ parser.add_argument('--raw_output',
+ help='Raw output without logs.',
+ action='store_true',
+ dest='raw_output')
+ parser.add_argument('--time',
+ required=False,
+ default=0,
+ type=int,
+ dest='time',
+ help='Optional. Maximum test duration. 0 if no restrictions.')
+ parser.add_argument('--report_path',
+ type=Path,
+ default=Path(__file__).parent / 'iree_inference_report.json',
+ dest='report_path')
+ parser.add_argument('--layout',
+ help='Input layout.',
+ default='NCHW',
+ choices=['NHWC', 'NCHW'],
+ type=str,
+ dest='layout')
+ parser.add_argument('--norm',
+ help='Flag to normalize input images.',
+ action='store_true',
+ dest='norm')
+ parser.add_argument('--mean',
+ help='Mean values.',
+ default=[0, 0, 0],
+ type=float,
+ nargs=3,
+ dest='mean')
+ parser.add_argument('--std',
+ help='Standard deviation values.',
+ default=[1., 1., 1.],
+ type=float,
+ nargs=3,
+ dest='std')
+ parser.add_argument('--channel_swap',
+ help='Parameter of channel swap.',
+ default=[2, 1, 0],
+ type=int,
+ nargs=3,
+ dest='channel_swap')
+ parser.add_argument('-tb', '--target_backend',
+ help='Target backend, for example `llvm-cpu` for CPU.',
+ default='llvm-cpu',
+ type=str,
+ dest='target_backend')
+ parser.add_argument('--opt_level',
+ help='The optimization level of the compilation.',
+ type=int,
+ choices=[0, 1, 2, 3],
+ default=2)
+ parser.add_argument('--extra_compile_args',
+ help='The extra arguments for MLIR compilation.',
+ type=str,
+ nargs=argparse.REMAINDER,
+ default=[])
+ args = parser.parse_args()
+ validate_cli_args(args)
+ return args
+
+
+def get_inference_function(model_context, function_name):
+ try:
+ main_module = model_context.modules.module
+ inference_func = main_module[function_name]
+ log.info(f'Using function {function_name} for inference')
+ return inference_func
+
+ except Exception as e:
+ log.error(f'Failed to get inference function: {e}')
+ raise
+
+
+def inference_iree(inference_func, number_iter, get_slice, test_duration):
+ result = None
+ time_infer = []
+
+ if number_iter == 1:
+ slice_input = get_slice()
+ result, exec_time = infer_slice(inference_func, slice_input)
+ time_infer.append(exec_time)
+ else:
+ time_infer = loop_inference(number_iter, test_duration)(
+ inference_iteration,
+ )(inference_func, get_slice)['time_infer']
+
+ log.info('Inference completed')
+ return result, time_infer
+
+
+def inference_iteration(inference_func, get_slice):
+ slice_input = get_slice()
+ _, exec_time = infer_slice(inference_func, slice_input)
+ return exec_time
+
+
+@get_exec_time()
+def infer_slice(inference_func, slice_input):
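+ # Wrap each input in a device array for the local-task driver, call the
+ # compiled function, and copy the result back to host memory.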
+ config = ireert.Config('local-task')
+ device = config.device
+
+ input_buffers = []
+ for input_ in slice_input:
+ input_buffers.append(ireert.asdevicearray(device, input_))
+
+ result = inference_func(*input_buffers)
+
+ if hasattr(result, 'to_host'):
+ result = result.to_host()
+
+ return result
+
+
+def main():
+ args = cli_argument_parser()
+
+ try:
+ model_wrapper = IREEModelWrapper(args)
+ data_transformer = IREETransformer(create_dict_for_transformer(args))
+ io = IOAdapter.get_io_adapter(args, model_wrapper, data_transformer)
+
+ report_writer = ReportWriter()
+ report_writer.update_framework_info(name='IREE')
+ report_writer.update_configuration_setup(
+ batch_size=args.batch_size,
+ iterations_num=args.number_iter,
+ target_device=args.target_backend,
+ )
+
+ log.info('Loading model')
+ model_context = load_model(
+ model_path=args.model,
+ model_weights=args.model_weights,
+ torch_module=args.torch_module,
+ model_name=args.model_name,
+ onnx_opset_version=args.onnx_opset_version,
+ source_framework=args.source_framework,
+ input_shape=args.input_shape,
+ target_backend=args.target_backend,
+ opt_level=args.opt_level,
+ extra_compile_args=args.extra_compile_args,
+ )
+ inference_func = get_inference_function(model_context, args.function_name)
+
+ log.info(f'Preparing input data: {args.input}')
+ io.prepare_input(model_context, args.input)
+
+ log.info(f'Starting inference ({args.number_iter} iterations) on {args.target_backend}')
+ result, inference_time = inference_iree(
+ inference_func,
+ args.number_iter,
+ io.get_slice_input_iree,
+ args.time,
+ )
+
+ log.info('Computing performance metrics')
+ inference_result = pp.calculate_performance_metrics_sync_mode(
+ args.batch_size,
+ inference_time,
+ )
+
+ report_writer.update_execution_results(**inference_result)
+ report_writer.write_report(args.report_path)
+
+ if not args.raw_output:
+ if args.number_iter == 1:
+ try:
+ log.info('Converting output tensor to print results')
+ result = prepare_output(result, args.task)
+ log.info('Inference results')
+ io.process_output(result, log)
+ except Exception as ex:
+ log.warning(f'Error when printing inference results: {str(ex)}')
+
+ log.info(f'Performance results: {inference_result}')
+
+ except Exception:
+ log.error(traceback.format_exc())
+ sys.exit(1)
+
+
+if __name__ == '__main__':
+ sys.exit(main() or 0)
diff --git a/src/inference/io_adapter.py b/src/inference/io_adapter.py
index 75693e2e4..010a27748 100644
--- a/src/inference/io_adapter.py
+++ b/src/inference/io_adapter.py
@@ -186,6 +186,14 @@ def get_slice_input(self, *args, **kwargs):
return slice_input
+ def get_slice_input_iree(self, *args, **kwargs):
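+ # Draw batch_size items from each transformed-input generator and stack
+ # them into a single batched numpy array per network input.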
+ slice_input = []
+ for key in self._transformed_input:
+ data_gen = self._transformed_input[key]
+ slice_data = [copy.deepcopy(next(data_gen)) for _ in range(self._batch_size)]
+ slice_input.append(np.stack(slice_data))
+ return slice_input
+
def get_slice_input_mxnet(self, *args, **kwargs):
import mxnet
slice_input = dict.fromkeys(self._transformed_input.keys(), None)
@@ -425,7 +433,7 @@ def get_slice_input(self, *args, **kwargs):
return [self._prompts[0]] * self._batch_size
def process_output(self, result, log):
- output_text = '\n'.join([f'{i+1}) {text} ... \n' for i, text in enumerate(result)])
+ output_text = '\n'.join([f'{i + 1}) {text} ... \n' for i, text in enumerate(result)])
log.info(f'Generated results: \n{output_text}')
@@ -435,7 +443,7 @@ def get_slice_input(self, *args, **kwargs):
return self.audio_data, self.sampling_rate, self.audio_length
def process_output(self, result, log):
- output_text = '\n'.join([f'{i+1}) {text} ... \n' for i, text in enumerate(result)])
+ output_text = '\n'.join([f'{i + 1}) {text} ... \n' for i, text in enumerate(result)])
log.info(f'Generated results: \n{output_text}')
diff --git a/src/inference/io_model_wrapper.py b/src/inference/io_model_wrapper.py
index 1c68c89dd..2aa15d13b 100644
--- a/src/inference/io_model_wrapper.py
+++ b/src/inference/io_model_wrapper.py
@@ -409,3 +409,19 @@ def get_input_layer_dtype(self):
class ExecuTorchIOModelWrapper(TVMIOModelWrapper):
pass
+
+
+class IREEModelWrapper(IOModelWrapper):
+ def __init__(self, args):
+ self._input_shapes = [args.input_shape]
+ self._model_path = args.model
+
+ def get_input_layer_names(self, model):
+ return ['input']
+
+ def get_input_layer_shape(self, model, layer_name):
+ return self._input_shapes[0]
+
+ def get_input_layer_dtype(self, model, layer_name):
+ import numpy as np
+ return np.float32
diff --git a/src/inference/iree_auxiliary.py b/src/inference/iree_auxiliary.py
new file mode 100644
index 000000000..a508115c3
--- /dev/null
+++ b/src/inference/iree_auxiliary.py
@@ -0,0 +1,214 @@
+import os
+import sys
+import tempfile
+from pathlib import Path
+
+import numpy as np
+
+sys.path.append(str(Path(__file__).resolve().parents[1].joinpath('model_converters',
+ 'iree_converter',
+ 'iree_auxiliary')))
+from compiler import IREECompiler # noqa: E402
+from converter import IREEConverter # noqa: E402
+
+sys.path.append(str(Path(__file__).resolve().parents[1].joinpath('utils')))
+from logger_conf import configure_logger # noqa: E402
+
+log = configure_logger()
+
+try:
+ import iree.runtime as ireert # noqa: E402
+except ImportError as e:
+ log.error(f'IREE import error: {e}')
+ sys.exit(1)
+
+
+def _validate_iree_model_args(args):
+ if not args.model:
+ raise ValueError('Model path (-m/--model) is required')
+ if not os.path.exists(args.model):
+ raise FileNotFoundError(f'Model file not found: {args.model}')
+
+ file_type = args.model.split('.')[-1].lower()
+ supported_extensions = ['mlir', 'vmfb']
+ if file_type not in supported_extensions:
+ raise ValueError(f'Model extension must be one of {supported_extensions}')
+ if file_type == 'mlir' and not args.target_backend:
+ raise ValueError('target_backend is required when using .mlir model')
+
+
+def _validate_onnx_args(args):
+ if not args.model:
+ raise ValueError('Model path (-m/--model) is required for ONNX framework')
+ if not os.path.exists(args.model):
+ raise FileNotFoundError(f'Model file not found: {args.model}')
+
+ file_type = args.model.split('.')[-1]
+ if file_type == 'onnx':
+ if not args.onnx_opset_version:
+ raise ValueError('onnx_opset_version is required for ONNX framework')
+ else:
+ _validate_iree_model_args(args)
+
+
+def _validate_pytorch_args(args):
+ has_model_path = args.model is not None and args.model != ''
+ has_module_model = (args.torch_module is not None and args.torch_module != ''
+ and args.model_name is not None and args.model_name != '')
+
+ if not has_model_path and not has_module_model:
+ raise ValueError(
+ 'For PyTorch conversion, you must specify either model_path, '
+ 'or torch_module and model_name',
+ )
+
+ if has_model_path and has_module_model:
+ raise ValueError(
+ 'Provided incompatible parameters for PyTorch conversion (model_path and torch_module+model_name). '
+ 'Please choose only one method.',
+ )
+
+ if has_model_path:
+ if not os.path.exists(args.model):
+ raise FileNotFoundError(f'Model file not found: {args.model}')
+
+ file_type = args.model.split('.')[-1]
+ if file_type != 'pt':
+ _validate_iree_model_args(args)
+ else:
+ if not args.target_backend:
+ raise ValueError('target_backend is required when using conversion from torch module')
+
+ if args.model_weights and args.model_weights != '' and not os.path.exists(args.model_weights):
+ raise FileNotFoundError(f'Model weights not found: {args.model_weights}')
+
+
+def validate_cli_args(args):
+ if args.source_framework == 'onnx':
+ _validate_onnx_args(args)
+ elif args.source_framework == 'pytorch':
+ _validate_pytorch_args(args)
+ else:
+ _validate_iree_model_args(args)
+
+
+def _convert_model_to_mlir(model_path, model_weights, torch_module, model_name, onnx_opset_version,
+ source_framework, input_shape, output_mlir):
+ dictionary = {
+ 'source_framework': source_framework,
+ 'model_name': model_name,
+ 'model_path': model_path,
+ 'model_weights': model_weights,
+ 'torch_module': torch_module,
+ 'onnx_opset_version': onnx_opset_version,
+ 'input_shape': input_shape,
+ 'output_mlir': output_mlir,
+ }
+ converter = IREEConverter.get_converter(dictionary)
+ converter.convert_to_mlir()
+ return
+
+
+def _compile_mlir(mlir_path, target_backend, opt_level, extra_compile_args):
+ try:
+ log.info('Starting model compilation')
+ return IREECompiler.compile_model(mlir_path, target_backend, opt_level, extra_compile_args)
+ except Exception as e:
+ log.error(f'Failed to compile MLIR: {e}')
+ raise
+
+
+def _load_model_buffer(model_path, target_backend, opt_level, extra_compile_args):
+ if not os.path.exists(model_path):
+ raise FileNotFoundError(f'Model file not found: {model_path}')
+
+ file_type = model_path.split('.')[-1]
+
+ if file_type == 'mlir':
+ if target_backend is None:
+ raise ValueError('target_backend is required for MLIR compilation')
+ vmfb_buffer = _compile_mlir(model_path, target_backend, opt_level, extra_compile_args)
+ elif file_type == 'vmfb':
+ with open(model_path, 'rb') as f:
+ vmfb_buffer = f.read()
+ else:
+ raise ValueError(f'The file type {file_type} is not supported. Supported types: .mlir, .vmfb')
+
+ log.info(f'Successfully loaded model buffer from {model_path}')
+ return vmfb_buffer
+
+
+def _create_iree_context_from_buffer(vmfb_buffer):
+ try:
+ config = ireert.Config('local-task')
+ vm_module = ireert.VmModule.from_flatbuffer(config.vm_instance, vmfb_buffer)
+ context = ireert.SystemContext(config=config)
+ context.add_vm_module(vm_module)
+
+ log.info('Successfully created IREE context from buffer')
+ return context
+
+ except Exception as e:
+ log.error(f'Failed to create IREE context: {e}')
+ raise
+
+
+def load_model(model_path, model_weights, torch_module, model_name, onnx_opset_version,
+ source_framework, input_shape, target_backend, opt_level, extra_compile_args):
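+ # Models given in a source-framework format (ONNX, PyTorch) are first
+ # converted to a temporary .mlir file; .mlir files are compiled to a VMFB
+ # buffer; .vmfb files are read as-is.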
+ is_tmp_mlir = False
+ if model_path is None or model_path.split('.')[-1] not in ['vmfb', 'mlir']:
+ with tempfile.NamedTemporaryFile(mode='w+t', delete=False, suffix='.mlir') as temp:
+ output_mlir = temp.name
+ _convert_model_to_mlir(model_path,
+ model_weights,
+ torch_module,
+ model_name,
+ onnx_opset_version,
+ source_framework,
+ input_shape,
+ output_mlir)
+ model_path = output_mlir
+ is_tmp_mlir = True
+
+ vmfb_buffer = _load_model_buffer(
+ model_path,
+ target_backend=target_backend,
+ opt_level=opt_level,
+ extra_compile_args=extra_compile_args,
+ )
+
+ if is_tmp_mlir:
+ os.remove(model_path)
+
+ return _create_iree_context_from_buffer(vmfb_buffer)
+
+
+def prepare_output(result, task):
+ if task == 'feedforward':
+ return {}
+ elif task == 'classification':
+ if hasattr(result, 'to_host'):
+ result = result.to_host()
+
+ logits = np.array(result)
+
+ # Apply softmax
+ max_logits = np.max(logits, axis=-1, keepdims=True)
+ exp_logits = np.exp(logits - max_logits)
+ probabilities = exp_logits / np.sum(exp_logits, axis=-1, keepdims=True)
+
+ return {'output': probabilities}
+ else:
+ raise ValueError(f'Unsupported task {task}')
+
+
+def create_dict_for_transformer(args):
+ return {
+ 'channel_swap': args.channel_swap,
+ 'mean': args.mean,
+ 'std': args.std,
+ 'norm': args.norm,
+ 'layout': args.layout,
+ 'input_shape': args.input_shape,
+ 'batch_size': args.batch_size,
+ }
diff --git a/src/inference/transformer.py b/src/inference/transformer.py
index 2d85d0c06..42a27eaaf 100644
--- a/src/inference/transformer.py
+++ b/src/inference/transformer.py
@@ -368,3 +368,71 @@ def transform_images(self, images, shape, element_type, *args):
class ExecuTorchTransformer(TVMTransformer):
pass
+
+
+class IREETransformer(Transformer):
+ def __init__(self, converting):
+ self._converting = converting
+
+ def __set_norm(self, image):
+ if self._converting.get('norm', False):
+ image = image.astype(np.float32) / 255.0
+ return image
+
+ def __set_channel_swap(self, image):
+ channel_swap = self._converting.get('channel_swap')
+ if channel_swap is not None:
+ image = image[:, :, channel_swap]
+ return image
+
+ def __set_mean(self, image):
+ mean = self._converting.get('mean')
+ if mean is not None and len(mean) == 3:
+ image[:, :, 0] -= mean[0]
+ image[:, :, 1] -= mean[1]
+ image[:, :, 2] -= mean[2]
+ return image
+
+ def __set_std(self, image):
+ std = self._converting.get('std')
+ if std is not None and len(std) == 3:
+ image[:, :, 0] /= std[0]
+ image[:, :, 1] /= std[1]
+ image[:, :, 2] /= std[2]
+ return image
+
+ def __set_layout(self, image):
+ layout = self._converting['layout']
+ if layout is not None:
+ layout = LAYER_LAYOUT_TO_IMAGE[layout]
+ image = np.expand_dims(image, 0).transpose(layout)
+ return image
+
+ def __bgr_to_rgb(self, image):
+ return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+
+ def _transform(self, image):
+ transformed_image = self.__bgr_to_rgb(image)
+ transformed_image = self.__set_norm(transformed_image)
+ transformed_image = self.__set_channel_swap(transformed_image)
+ transformed_image = self.__set_mean(transformed_image)
+ transformed_image = self.__set_std(transformed_image)
+ transformed_image = self.__set_layout(transformed_image)
+ return transformed_image
+
+ def transform_images(self, images, shape, element_type, *args):
+ dataset_size = images.shape[0]
+ new_shape = [dataset_size] + shape[1:]
+ transformed_images = np.zeros(shape=new_shape, dtype=element_type)
+ for i in range(dataset_size):
+ transformed_images[i] = self._transform(images[i])
+ return transformed_images
+
+ def get_shape_in_chw_order(self, shape, *args):
+ layout = self._converting.get('layout', 'NHWC')
+ if layout == 'NHWC':
+ return shape[3], shape[1], shape[2]
+ elif layout == 'NCHW':
+ return shape[1], shape[2], shape[3]
+ else:
+ return shape[1:]
diff --git a/src/model_converters/README.md b/src/model_converters/README.md
index 7d7dea499..28679735c 100644
--- a/src/model_converters/README.md
+++ b/src/model_converters/README.md
@@ -16,6 +16,7 @@
format from TensorFlow and ONNX formats.
- `tvm_converter` contains converter and compiler
to the TVM format.
+- `iree_converter` contains tools to convert ONNX or PyTorch models to IREE MLIR and compile them to VMFB binaries.
## An overview of existing model converters
diff --git a/src/model_converters/iree_converter/README.md b/src/model_converters/iree_converter/README.md
new file mode 100644
index 000000000..604265df5
--- /dev/null
+++ b/src/model_converters/iree_converter/README.md
@@ -0,0 +1,119 @@
+# Conversion to the IREE format
+IREE converter supports conversion to the IREE MLIR format from ONNX and PyTorch formats.
+
+IREE compiler supports compilation from `.mlir` format to the `.vmfb` format for deployment on various backends.
+
+## IREE converter usage
+
+Basic usage of the script:
+
+```sh
+iree_converter.py --source_framework <source_framework> \
+    --model_name <model_name> \
+    --model <model_path> \
+    --weights <weights_path> \
+    --torch_module <torch_module> \
+    --input_shape <input_shape> \
+    --onnx_opset_version <onnx_opset_version> \
+    --output_mlir <output_mlir_path>
+```
+
+This script converts a model from `<source_framework>` to the IREE MLIR format.
+
+### IREE converter parameters
+- `-f / --source_framework` is a source framework where the model was trained. Required. Choices: `onnx`, `pytorch`.
+- `-mn / --model_name` is a model name. Required for PyTorch models loaded from module.
+- `-m / --model` is a path to an `.onnx` or `.pt` file with a trained model.
+- `-w / --weights` is a path to a `.pth` file with trained weights for PyTorch models.
+- `-tm / --torch_module` is a module with the model architecture for PyTorch models. Default: `torchvision.models`.
+- `-is / --input_shape` is an input shape in the format BxCxHxW, where B is a batch size, C is an input tensor number of channels, H is an input tensor height, W is an input tensor width. Required for PyTorch models.
+- `--onnx_opset_version` is an ONNX opset version for ONNX models. Default: `18`.
+- `-o / --output_mlir` is a path to save the MLIR file. Required.
+
+### Parameter combinations
+#### For ONNX models:
+- Required: `--source_framework onnx`, `--model <model.onnx>`, `--output_mlir <output.mlir>`
+- Optional: `--onnx_opset_version` (default: 18; the converter validates that the value is set, so keep the default or override it explicitly)
+#### For PyTorch models:
+Two loading methods are supported (mutually exclusive):
+1. From file:
+- Required: `--source_framework pytorch`, `--model <model.pt>`, `--input_shape B C H W`, `--output_mlir <output.mlir>`
+- Optional: `--model_name <name>` (used only for logging), `--weights <weights.pth>`
+1. From module:
+- Required: `--source_framework pytorch`, `--model_name <name>`, `--torch_module <module>`, `--input_shape B C H W`, `--output_mlir <output.mlir>`
+- Optional: `--weights <weights.pth>`
+
+> **Note:** `--model` and the pair (`--torch_module`, `--model_name`) are mutually exclusive. Passing both at the same time raises a validation error (`converter.py` enforces the rule). Likewise, `--input_shape` is validated only for PyTorch conversions, so you can omit it for ONNX.
+
+### Examples of usage
+ONNX model conversion ([source of the model efficientnet-b0.onnx](https://github.com/onnx/models/blob/main/Computer_Vision/efficientnet_b0_Opset17_timm/efficientnet_b0_Opset17.onnx)):
+```sh
+python3 iree_converter.py -f onnx -m efficientnet-b0.onnx \
+ --onnx_opset_version 18 \
+ -o ./output/efficientnet-b0.mlir
+```
+
+PyTorch model from file (`.pt` can be created using [tutorial](https://docs.pytorch.org/docs/main/notes/serialization.html#saving-and-loading-torch-nn-modules)):
+```sh
+python3 iree_converter.py -f pytorch -m resnet50.pt \
+ -is 1 3 224 224 \
+ -o ./output/resnet50.mlir
+```
+
+PyTorch model from [torchvision](https://docs.pytorch.org/vision/main/models.html) with pretrained weights:
+```sh
+python3 iree_converter.py -f pytorch -mn resnet50 \
+ -tm torchvision.models \
+ -is 1 3 224 224 \
+ -o ./output/resnet50.mlir
+```
+
+PyTorch model with custom weights:
+```sh
+python3 iree_converter.py -f pytorch -mn resnet50 \
+ -tm torchvision.models \
+ -w ./weights/resnet50-custom.pth \
+ -is 1 3 224 224 \
+ -o ./output/resnet50-custom.mlir
+```
+
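+The converter can also be driven from Python. A minimal sketch, assuming the
+`iree_auxiliary` directory is on `sys.path` (the dictionary keys mirror the
+CLI parameters above; the model and paths are illustrative):
+
+```python
+from converter import IREEConverter
+
+# Convert a torchvision resnet50 with pretrained weights to IREE MLIR.
+converter = IREEConverter.get_converter({
+    'source_framework': 'pytorch',
+    'model_name': 'resnet50',
+    'model_path': None,
+    'model_weights': None,
+    'torch_module': 'torchvision.models',
+    'onnx_opset_version': None,
+    'input_shape': [1, 3, 224, 224],
+    'output_mlir': './output/resnet50.mlir',
+})
+converter.convert_to_mlir()
+```
+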
+## IREE compiler usage
+
+Basic usage of the script:
+```sh
+iree_compiler.py --mlir <model.mlir> \
+    --target_backend <target_backend> \
+    --opt_level <opt_level> \
+    --output_file <output.vmfb> \
+    [--extra_args <extra_args>]
+This script compiles a model from the `.mlir` format to a deployable binary for the specified target backend.
+
+### IREE compiler parameters
+- `-m / --mlir` is a path to an .mlir file with a model. Required.
+- `-tb / --target_backend` is a target backend for compilation. Required. Examples: `llvm-cpu`, `cuda`, `vulkan`, `vmvx`.
+- `--opt_level` is an optimization level of the compilation. Choices: `0`, `1`, `2`, `3`. Default: `2`.
+- `-o / --output_file` is a path to save the compiled model. Required.
+- `--extra_args` is a list of extra arguments for compilation. Optional.
+
+### Supported target backends
+- `llvm-cpu` - CPU execution using LLVM.
+- `cuda` - NVIDIA GPU execution using CUDA.
+- `vulkan` - GPU execution using Vulkan API.
+- `vmvx` - Portable VM bytecode execution.
+- `metal` - Apple GPU execution using Metal.
+- `rocm` - AMD GPU execution using ROCm.
+
+### Examples of usage
+```sh
+python3 iree_compiler.py -m ./models/resnet50.mlir \
+ -tb llvm-cpu \
+ --opt_level 2 \
+ -o ./compiled/resnet50-cpu.vmfb
+```
+### Using extra arguments
+The `--extra_args` parameter allows passing additional compilation flags:
+```sh
+python3 iree_compiler.py -m model.mlir -tb llvm-cpu -o output.vmfb \
+ --extra_args --iree-llvmcpu-target-triple=x86_64-linux-gnu
+```
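+
+The compiler can also be invoked from Python through the `IREECompiler` helper
+in `iree_auxiliary/compiler.py`; a minimal sketch, assuming that directory is
+on `sys.path` and the paths are placeholders:
+
+```python
+from compiler import IREECompiler
+
+# Compile an MLIR file into a VMFB binary for the CPU backend;
+# opt_level=2 is forwarded to the compiler as --iree-opt-level=O2.
+IREECompiler.compile_model('models/resnet50.mlir',
+                           'llvm-cpu',
+                           2,
+                           [],
+                           output_file='compiled/resnet50-cpu.vmfb')
+```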
diff --git a/src/model_converters/iree_converter/__init__.py b/src/model_converters/iree_converter/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/src/model_converters/iree_converter/iree_auxiliary/__init__.py b/src/model_converters/iree_converter/iree_auxiliary/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/src/model_converters/iree_converter/iree_auxiliary/compiler.py b/src/model_converters/iree_converter/iree_auxiliary/compiler.py
new file mode 100644
index 000000000..3792a6abd
--- /dev/null
+++ b/src/model_converters/iree_converter/iree_auxiliary/compiler.py
@@ -0,0 +1,10 @@
+import os
+from iree.compiler.tools import compile_str, compile_file
+
+
+class IREECompiler:
+ @staticmethod
+ def compile_model(mlir, target, opt_level, extra_args, output_file=None):
+ extra_args.append(f'--iree-opt-level=O{opt_level}')
+ compile_func = compile_file if os.path.isfile(mlir) else compile_str
+ return compile_func(mlir, target_backends=[target], extra_args=extra_args, output_file=output_file)
diff --git a/src/model_converters/iree_converter/iree_auxiliary/converter.py b/src/model_converters/iree_converter/iree_auxiliary/converter.py
new file mode 100644
index 000000000..db5c7feec
--- /dev/null
+++ b/src/model_converters/iree_converter/iree_auxiliary/converter.py
@@ -0,0 +1,39 @@
+import abc
+import sys
+from pathlib import Path
+
+sys.path.append(str(Path(__file__).resolve().parent.parent.parent.parent.joinpath('utils')))
+from logger_conf import configure_logger # noqa: E402
+
+log = configure_logger()
+
+
+class IREEConverter(metaclass=abc.ABCMeta):
+ def __init__(self, args):
+ self.model_name = args.get('model_name', None)
+ self.output_mlir = args.get('output_mlir', None)
+ self.log = log
+
+ @abc.abstractmethod
+ def _convert_model_from_framework(self):
+ pass
+
+ @property
+ @abc.abstractmethod
+ def source_framework(self):
+ pass
+
+ @staticmethod
+ def get_converter(args):
+ framework = args['source_framework'].lower()
+ if framework == 'onnx':
+ from onnx_format import IREEConverterONNXFormat
+ return IREEConverterONNXFormat(args)
+ elif framework == 'pytorch':
+ from pytorch_format import IREEConverterPyTorchFormat
+ return IREEConverterPyTorchFormat(args)
+ else:
+ raise ValueError(f'Unsupported source framework: {framework}')
+
+ def convert_to_mlir(self):
+ self.log.info(f'Get IREE MLIR for {self.model_name} from {self.source_framework} framework')
+ self._convert_model_from_framework()
+ return
diff --git a/src/model_converters/iree_converter/iree_auxiliary/onnx_format.py b/src/model_converters/iree_converter/iree_auxiliary/onnx_format.py
new file mode 100644
index 000000000..0de82c784
--- /dev/null
+++ b/src/model_converters/iree_converter/iree_auxiliary/onnx_format.py
@@ -0,0 +1,38 @@
+import subprocess
+import os
+from converter import IREEConverter
+
+
+class IREEConverterONNXFormat(IREEConverter):
+ def __init__(self, args):
+ super().__init__(args)
+ self.model_path = args.get('model_path', None)
+ self.onnx_opset_version = args.get('onnx_opset_version', None)
+ self._validate_arguments()
+
+ @property
+ def source_framework(self):
+ return 'ONNX'
+
+ def _validate_arguments(self):
+ if self.model_path is None or self.model_path == '':
+ raise ValueError('The model_path parameter is required for ONNX conversion.')
+
+ if not os.path.exists(self.model_path):
+ raise FileNotFoundError(f'Model file not found: {self.model_path}')
+
+ if self.onnx_opset_version is None:
+ raise ValueError('The onnx_opset_version parameter is required for ONNX conversion.')
+
+ def _convert_model_from_framework(self):
+ import_args = [
+ 'iree-import-onnx',
+ self.model_path,
+ '--opset-version',
+ str(self.onnx_opset_version),
+ '-o',
+ self.output_mlir,
+ ]
+ subprocess.run(import_args, capture_output=True, check=True)
+ return
diff --git a/src/model_converters/iree_converter/iree_auxiliary/pytorch_format.py b/src/model_converters/iree_converter/iree_auxiliary/pytorch_format.py
new file mode 100644
index 000000000..787fef036
--- /dev/null
+++ b/src/model_converters/iree_converter/iree_auxiliary/pytorch_format.py
@@ -0,0 +1,86 @@
+import importlib
+import os
+from converter import IREEConverter
+
+
+class IREEConverterPyTorchFormat(IREEConverter):
+ def __init__(self, args):
+ super().__init__(args)
+ self.torch = importlib.import_module('torch')
+ self.aot = importlib.import_module('iree.turbine.aot')
+ self.model_path = args.get('model_path', None)
+ self.model_weights = args.get('model_weights', None)
+ self.module = args.get('torch_module', None)
+ self.input_shape = args.get('input_shape', None)
+ self._validate_arguments()
+
+ @property
+ def source_framework(self):
+ return 'PyTorch'
+
+ def _validate_arguments(self):
+ if self.input_shape is None:
+ raise ValueError('The input_shape parameter is required for PyTorch conversion.')
+
+ # Check load methods:
+ # 1. model_path (load from file)
+ # 2. module + model_name (load from torch module)
+ has_model_path = self.model_path is not None and self.model_path != ''
+ has_module_model = (self.module is not None
+ and self.module != ''
+ and self.model_name is not None
+ and self.model_name != '')
+
+ if not has_model_path and not has_module_model:
+ raise ValueError(
+ 'For PyTorch conversion, you must specify either model_path, \
+ or torch_module and model_name',
+ )
+
+ if has_model_path and has_module_model:
+ raise ValueError(
+ 'Provided incompatible parameters for PyTorch conversion (model_path and torch_module+model_name). \
+ Please choose only one method of this.',
+ )
+
+ if has_model_path and not os.path.exists(self.model_path):
+ raise FileNotFoundError(f'Model file not found: {self.model_path}')
+
+ if (self.model_weights is not None and self.model_weights != ''
+ and not os.path.exists(self.model_weights)):
+ raise FileNotFoundError(f'Model weights not found: {self.model_weights}')
+
+ def __get_model_from_path(self):
+ self.log.info(f'Loading model from path {self.model_path}')
+ file_type = self.model_path.split('.')[-1]
+ supported_extensions = ['pt']
+ if file_type not in supported_extensions:
+ raise ValueError(f'The file type {file_type} is not supported')
+ # weights_only=False is required to unpickle a full nn.Module with recent torch versions
+ model = self.torch.load(self.model_path, weights_only=False)
+ model.eval()
+ return model
+
+ def __get_model_from_module(self):
+ self.log.info(f'Loading model {self.model_name} from module')
+ model_cls = importlib.import_module(self.module).__getattribute__(self.model_name)
+ if self.model_weights is None or self.model_weights == '':
+ self.log.info('Loading pretrained model')
+ model = model_cls(weights='DEFAULT')
+ else:
+ self.log.info(f'Loading model with weights from file {self.model_weights}')
+ model = model_cls()
+ checkpoint = self.torch.load(self.model_weights, map_location='cpu')
+ model.load_state_dict(checkpoint, strict=False)
+ model.eval()
+ return model
+
+ def _convert_model_from_framework(self):
+ model = None
+ if self.module:
+ model = self.__get_model_from_module()
+ else:
+ model = self.__get_model_from_path()
+ example_arg = self.torch.randn(*self.input_shape)
+ export_output = self.aot.export(model, example_arg)
+ export_output.save_mlir(self.output_mlir)
+ return
diff --git a/src/model_converters/iree_converter/iree_compiler.py b/src/model_converters/iree_converter/iree_compiler.py
new file mode 100644
index 000000000..ad8f72b0a
--- /dev/null
+++ b/src/model_converters/iree_converter/iree_compiler.py
@@ -0,0 +1,54 @@
+import argparse
+import sys
+import os
+import traceback
+from pathlib import Path
+from iree_auxiliary.compiler import IREECompiler
+
+sys.path.append(str(Path(__file__).resolve().parent.parent.parent))
+from utils.logger_conf import configure_logger # noqa: E402
+
+log = configure_logger()
+
+
+def cli_argument_parser():
+ parser = argparse.ArgumentParser()
+ parser.add_argument('-m', '--mlir',
+ help='Path to an .mlir file with a model.',
+ required=True,
+ type=str)
+ parser.add_argument('-tb', '--target_backend',
+ help='Target backend, for example "llvm-cpu" for CPU.',
+ required=True,
+ type=str)
+ parser.add_argument('--opt_level',
+ help='The optimization level of the compilation.',
+ type=int,
+ choices=[0, 1, 2, 3],
+ default=2)
+ parser.add_argument('--extra_args',
+ help='The extra arguments for compilation.',
+ type=str,
+ nargs=argparse.REMAINDER,
+ default=[])
+ parser.add_argument('-o', '--output_file',
+ help='Path to compiled model.',
+ required=True,
+ type=str)
+ args = parser.parse_args()
+ return args
+
+
+def main():
+ args = cli_argument_parser()
+ try:
+ IREECompiler.compile_model(args.mlir, args.target_backend, args.opt_level, args.extra_args, args.output_file)
+ if os.path.exists(args.output_file):
+ print(f'The MLIR has been successfully compiled into {args.output_file}')
+ except Exception:
+ log.error(traceback.format_exc())
+ sys.exit(1)
+
+
+if __name__ == '__main__':
+ sys.exit(main() or 0)
diff --git a/src/model_converters/iree_converter/iree_converter.py b/src/model_converters/iree_converter/iree_converter.py
new file mode 100644
index 000000000..60ac06795
--- /dev/null
+++ b/src/model_converters/iree_converter/iree_converter.py
@@ -0,0 +1,89 @@
+import argparse
+import os
+import sys
+import traceback
+from pathlib import Path
+sys.path.append(str(Path(__file__).resolve().parent.joinpath('iree_auxiliary')))
+from converter import IREEConverter # noqa: E402
+
+sys.path.append(str(Path(__file__).resolve().parent.parent.parent))
+from utils.logger_conf import configure_logger # noqa: E402
+
+log = configure_logger()
+
+
+def cli_argument_parser():
+ parser = argparse.ArgumentParser()
+ parser.add_argument('-f', '--source_framework',
+ help='Source model framework',
+ required=True,
+ type=str,
+ choices=['onnx', 'pytorch'],
+ dest='source_framework')
+ parser.add_argument('-mn', '--model_name',
+ help='Model name.',
+ type=str,
+ dest='model_name')
+ parser.add_argument('-m', '--model',
+ help='Path to an .onnx or .pt file with a trained model.',
+ type=str,
+ dest='model_path')
+ parser.add_argument('-w', '--weights',
+ help='Path to a .pth file with trained weights.',
+ type=str,
+ dest='model_weights')
+ parser.add_argument('-tm', '--torch_module',
+ help='Torch module with model architecture.',
+ default='torchvision.models',
+ type=str,
+ dest='torch_module')
+ parser.add_argument('--onnx_opset_version',
+ help='ONNX opset version used for the import.',
+ type=int,
+ default=18,
+ dest='onnx_opset_version')
+ parser.add_argument('-is', '--input_shape',
+ help='Input shape BxCxHxW, B is a batch size,'
+ 'C is an input tensor number of channels,'
+ 'H is an input tensor height,'
+ 'W is an input tensor width.',
+ type=int,
+ nargs=4,
+ dest='input_shape')
+ parser.add_argument('-o', '--output_mlir',
+ help='Path to save the MLIR.',
+ required=True,
+ type=str,
+ dest='output_mlir')
+ args = parser.parse_args()
+ return args
+
+
+def create_dict_for_converter(args):
+ dictionary = {
+ 'source_framework': args.source_framework,
+ 'model_name': args.model_name,
+ 'model_path': args.model_path,
+ 'model_weights': args.model_weights,
+ 'torch_module': args.torch_module,
+ 'onnx_opset_version': args.onnx_opset_version,
+ 'input_shape': args.input_shape,
+ 'output_mlir': args.output_mlir,
+ }
+ return dictionary
+
+
+def main():
+ args = cli_argument_parser()
+ try:
+ converter = IREEConverter.get_converter(create_dict_for_converter(args))
+ converter.convert_to_mlir()
+ if os.path.exists(args.output_mlir):
+ print(f'The MLIR has been successfully saved into {args.output_mlir}')
+ except Exception:
+ log.error(traceback.format_exc())
+ sys.exit(1)
+
+
+if __name__ == '__main__':
+ sys.exit(main() or 0)
diff --git a/src/quantization/process.py b/src/quantization/process.py
index 506c97a9e..3bf5f5787 100644
--- a/src/quantization/process.py
+++ b/src/quantization/process.py
@@ -74,7 +74,7 @@ def execute(self, idx):
command_line = self.__fill_command_line()
if command_line == '':
self.__log.error('Command line is empty')
- self.__log.info(f'Start quantization model #{idx+1}!')
+ self.__log.info(f'Start quantization model #{idx + 1}!')
self.__log.info(f'Command line is : {command_line}')
self._status, self._output = self._executor.execute_process(command_line)
if type(self._output) is not list: