
Commit 85d30e8: IREE inference
Parent: ec4106e

File tree: 3 files changed (+393, -0 lines)


src/inference/inference_iree.py

Lines changed: 308 additions & 0 deletions
@@ -0,0 +1,308 @@
import argparse
import sys
import traceback
from pathlib import Path

import postprocessing_data as pp
from inference_tools.loop_tools import loop_inference, get_exec_time
from io_adapter import IOAdapter
from io_model_wrapper import IREEModelWrapper
from reporter.report_writer import ReportWriter
from transformer import IREETransformer

import numpy as np

sys.path.append(str(Path(__file__).resolve().parents[1].joinpath('utils')))
from logger_conf import configure_logger  # noqa: E402

log = configure_logger()

try:
    import iree.runtime as ireert  # noqa: E402
except ImportError as e:
    log.error(f"IREE import error: {e}")
    sys.exit(1)

def cli_argument_parser():
    parser = argparse.ArgumentParser()

    parser.add_argument('-i', '--input',
                        help='Path to data.',
                        required=True,
                        type=str,
                        nargs='+',
                        dest='input')
    parser.add_argument('-m', '--model',
                        help='Path to .vmfb file with compiled model.',
                        required=True,
                        type=str,
                        dest='model')
    parser.add_argument('-in', '--input_name',
                        help='IREE module function name to execute.',
                        required=True,
                        type=str,
                        dest='input_name')
    parser.add_argument('-d', '--device',
                        help='Specify the target device to infer on (CPU by default).',
                        default='CPU',
                        type=str,
                        dest='device')
    parser.add_argument('-is', '--input_shape',
                        help='Input shape BxHxWxC, where B is a batch size, '
                             'H is an input tensor height, '
                             'W is an input tensor width, '
                             'C is a number of input tensor channels.',
                        required=True,
                        type=int,
                        nargs=4,
                        dest='input_shape')
    parser.add_argument('-b', '--batch_size',
                        help='Size of the processed pack. '
                             'Should be the same as B in the input_shape argument.',
                        default=1,
                        type=int,
                        dest='batch_size')
    parser.add_argument('-l', '--labels',
                        help='Labels mapping file.',
                        default=None,
                        type=str,
                        dest='labels')
    parser.add_argument('-nt', '--number_top',
                        help='Number of top results.',
                        default=5,
                        type=int,
                        dest='number_top')
    parser.add_argument('-t', '--task',
                        help='Task type. Default: feedforward.',
                        choices=['feedforward', 'classification'],
                        default='feedforward',
                        type=str,
                        dest='task')
    parser.add_argument('-ni', '--number_iter',
                        help='Number of inference iterations.',
                        default=1,
                        type=int,
                        dest='number_iter')
    # action='store_true' instead of the original type=bool: argparse passes the
    # raw string to bool(), so any non-empty value (even 'False') parses as True.
    parser.add_argument('--raw_output',
                        help='Raw output without logs.',
                        action='store_true',
                        dest='raw_output')
    parser.add_argument('--time',
                        required=False,
                        default=0,
                        type=int,
                        dest='time',
                        help='Optional. Maximum test duration. 0 if no restrictions.')
    parser.add_argument('--report_path',
                        type=Path,
                        default=Path(__file__).parent / 'iree_inference_report.json',
                        dest='report_path')
    parser.add_argument('--layout',
                        help='Input layout.',
                        default='NHWC',
                        choices=['NHWC', 'NCHW'],
                        type=str,
                        dest='layout')
    parser.add_argument('--norm',
                        help='Flag to normalize input images.',
                        action='store_true',
                        dest='norm')
    parser.add_argument('--mean',
                        help='Mean values.',
                        default=[0, 0, 0],
                        type=float,
                        nargs=3,
                        dest='mean')
    parser.add_argument('--std',
                        help='Standard deviation values.',
                        default=[1., 1., 1.],
                        type=float,
                        nargs=3,
                        dest='std')
    parser.add_argument('--channel_swap',
                        help='Parameter of channel swap.',
                        default=[2, 1, 0],
                        type=int,
                        nargs=3,
                        dest='channel_swap')

    return parser.parse_args()

def load_iree_model(model_path):
    try:
        # 'local-task' is IREE's multithreaded CPU driver.
        config = ireert.Config('local-task')

        with open(model_path, 'rb') as f:
            vmfb_buffer = f.read()

        vm_module = ireert.VmModule.from_flatbuffer(config.vm_instance, vmfb_buffer)
        context = ireert.SystemContext(config=config)
        context.add_vm_module(vm_module)

        log.info("Successfully loaded IREE model")
        return context

    except Exception as e:
        log.error(f"Failed to load IREE model: {e}")
        raise

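For comparison, recent iree.runtime releases also expose a one-call loading helper; a minimal sketch, assuming iree.runtime.load_vm_flatbuffer_file is available in the installed version and that the compiled module exports a main function (the path and entry point name here are placeholders):

import iree.runtime as ireert
import numpy as np

# Hypothetical path and entry point; mirrors the manual
# Config/VmModule/SystemContext wiring in load_iree_model above.
module = ireert.load_vm_flatbuffer_file('model.vmfb', driver='local-task')
dummy = np.zeros((1, 224, 224, 3), dtype=np.float32)
result = module.main(dummy)
result = result.to_host() if hasattr(result, 'to_host') else result
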
def get_inference_function(model_context, input_name):
    try:
        # Modules registered via add_vm_module are exposed on context.modules;
        # 'module' is the default name for a compiled VM module.
        main_module = model_context.modules.module
        inference_func = main_module[input_name]
        log.info(f"Using function '{input_name}' for inference")
        return inference_func

    except Exception as e:
        log.error(f"Failed to get inference function: {e}")
        raise

def inference_iree(inference_func, number_iter, batch_size, get_slice, test_duration):
    result = None
    time_infer = []

    if number_iter == 1:
        slice_input = get_slice()
        result, exec_time = infer_slice(inference_func, slice_input)
        time_infer.append(exec_time)
    else:
        time_infer = loop_inference(number_iter, test_duration)(
            inference_iteration
        )(inference_func, get_slice)['time_infer']

    log.info('Inference completed')
    return result, time_infer

def inference_iteration(inference_func, get_slice):
    slice_input = get_slice()
    _, exec_time = infer_slice(inference_func, slice_input)
    return exec_time

@get_exec_time()
def infer_slice(inference_func, slice_input):
    # A staging device is created per call; the input buffer is uploaded
    # to it before invoking the compiled function.
    config = ireert.Config('local-task')
    device = config.device

    input_name = list(slice_input.keys())[0]
    input_data = slice_input[input_name]

    input_buffer = ireert.asdevicearray(device, input_data)

    result = inference_func(input_buffer)

    if hasattr(result, 'to_host'):
        result = result.to_host()

    return result

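Because infer_slice builds a fresh Config (and thus a new device handle) on every call, that setup cost lands inside the timed region. A variant that hoists the device out of the hot path, sketched here as an illustration rather than as part of the commit:

# Hypothetical variant: create the staging device once and reuse it.
_STAGING_CONFIG = ireert.Config('local-task')

@get_exec_time()
def infer_slice_with_shared_device(inference_func, slice_input):
    input_name = next(iter(slice_input))
    input_buffer = ireert.asdevicearray(_STAGING_CONFIG.device, slice_input[input_name])
    result = inference_func(input_buffer)
    return result.to_host() if hasattr(result, 'to_host') else result
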
def prepare_output(result, task):
    if task == 'feedforward':
        return {}
    elif task == 'classification':
        if hasattr(result, 'to_host'):
            result = result.to_host()

        # Extract tensor from dict if needed
        if isinstance(result, dict):
            result_key = next(iter(result))
            logits = result[result_key]
            output_key = result_key
        else:
            logits = np.array(result)
            output_key = 'output'

        # Ensure correct shape (batch_size, num_classes)
        if logits.ndim == 1:
            logits = logits.reshape(1, -1)
        elif logits.ndim > 2:
            logits = logits.reshape(logits.shape[0], -1)

        # Apply softmax
        max_logits = np.max(logits, axis=-1, keepdims=True)
        exp_logits = np.exp(logits - max_logits)
        probabilities = exp_logits / np.sum(exp_logits, axis=-1, keepdims=True)

        return {output_key: probabilities}
    else:
        raise ValueError(f'Unsupported task {task}')

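To illustrate the numerically stable softmax used above, a self-contained check with made-up logits:

import numpy as np

logits = np.array([[2.0, 1.0, 0.1]])  # shape (batch_size=1, num_classes=3)
max_logits = np.max(logits, axis=-1, keepdims=True)
exp_logits = np.exp(logits - max_logits)  # subtracting the row max avoids overflow
probabilities = exp_logits / np.sum(exp_logits, axis=-1, keepdims=True)
print(probabilities)  # approximately [[0.659 0.242 0.099]]; each row sums to 1
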
def create_dict_for_transformer(args):
    return {
        'channel_swap': args.channel_swap,
        'mean': args.mean,
        'std': args.std,
        'norm': args.norm,
        'layout': args.layout,
        'input_shape': args.input_shape,
        'batch_size': args.batch_size,
    }

def main():
    args = cli_argument_parser()

    try:
        model_wrapper = IREEModelWrapper(args)
        data_transformer = IREETransformer(create_dict_for_transformer(args))
        io = IOAdapter.get_io_adapter(args, model_wrapper, data_transformer)

        report_writer = ReportWriter()
        report_writer.update_framework_info(name='IREE')
        report_writer.update_configuration_setup(
            batch_size=args.batch_size,
            iterations_num=args.number_iter,
            target_device=args.device
        )

        model_context = load_iree_model(args.model)
        inference_func = get_inference_function(model_context, args.input_name)

        log.info(f'Preparing input data: {args.input}')
        io.prepare_input(model_context, args.input)

        log.info(f'Starting inference ({args.number_iter} iterations) on {args.device}')
        result, inference_time = inference_iree(
            inference_func,
            args.number_iter,
            args.batch_size,
            io.get_slice_input,
            args.time
        )

        log.info('Computing performance metrics')
        inference_result = pp.calculate_performance_metrics_sync_mode(
            args.batch_size,
            inference_time
        )

        report_writer.update_execution_results(**inference_result)
        report_writer.write_report(args.report_path)

        if not args.raw_output:
            if args.number_iter == 1:
                try:
                    log.info('Converting output tensor to print results')
                    result = prepare_output(result, args.task)
                    log.info('Inference results')
                    io.process_output(result, log)
                except Exception as ex:
                    log.warning(f'Error when printing inference results: {str(ex)}')

            log.info(f'Performance results: {inference_result}')

    except Exception:
        log.error(traceback.format_exc())
        sys.exit(1)


if __name__ == '__main__':
    sys.exit(main() or 0)
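For reference, a classification run of this script might be invoked as follows (the model, image, and labels paths are placeholders, and 'main' is an assumed entry point name):

python3 inference_iree.py \
    -m model.vmfb \
    -in main \
    -i image.jpg \
    -is 1 224 224 3 \
    -b 1 \
    -t classification \
    -l labels.txt \
    -ni 10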

src/inference/io_model_wrapper.py

Lines changed: 17 additions & 0 deletions
@@ -409,3 +409,20 @@ def get_input_layer_dtype(self):
class ExecuTorchIOModelWrapper(TVMIOModelWrapper):
    pass


class IREEModelWrapper(IOModelWrapper):
    def __init__(self, args):
        self._input_names = [args.input_name]
        self._input_shapes = [args.input_shape]
        self._model_path = args.model

    def get_input_layer_names(self, model):
        return self._input_names

    def get_input_layer_shape(self, model, layer_name):
        return self._input_shapes[0]

    def get_input_layer_dtype(self, model, layer_name):
        # The wrapper assumes float32 inputs; the actual dtype is fixed
        # by the compiled .vmfb model.
        import numpy as np
        return np.float32
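A quick sanity check of the wrapper (a sketch: SimpleNamespace stands in for parsed CLI arguments, and the values are placeholders):

from types import SimpleNamespace

from io_model_wrapper import IREEModelWrapper

args = SimpleNamespace(input_name='main',
                       input_shape=[1, 224, 224, 3],
                       model='model.vmfb')
wrapper = IREEModelWrapper(args)
assert wrapper.get_input_layer_names(None) == ['main']
assert wrapper.get_input_layer_shape(None, 'main') == [1, 224, 224, 3]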
