|
1 | 1 | import argparse |
2 | | -import os |
3 | 2 | import sys |
4 | 3 | import traceback |
5 | | -import tempfile |
6 | 4 | from pathlib import Path |
7 | 5 |
|
8 | 6 | import postprocessing_data as pp |
|
11 | 9 | from io_model_wrapper import IREEModelWrapper |
12 | 10 | from reporter.report_writer import ReportWriter |
13 | 11 | from transformer import IREETransformer |
| 12 | +from iree_auxiliary import (load_model, create_dict_for_transformer, prepare_output, validate_cli_args) |
14 | 13 |
|
15 | | -import numpy as np |
16 | | - |
17 | | -sys.path.append(str(Path(__file__).resolve().parents[1].joinpath('model_converters', |
18 | | - 'iree_converter', |
19 | | - 'iree_auxiliary'))) |
20 | | -from compiler import IREECompiler # noqa: E402 |
21 | | -from converter import IREEConverter # noqa: E402 |
22 | 14 |
|
23 | 15 | sys.path.append(str(Path(__file__).resolve().parents[1].joinpath('utils'))) |
24 | 16 | from logger_conf import configure_logger # noqa: E402 |
|
32 | 24 | sys.exit(1) |
33 | 25 |
|
34 | 26 |
|
def validate_cli_args(args):
    """Validate parsed command-line arguments.

    Currently a stub: it reads ``args.model`` but accepts any value,
    because a missing model path is legal when a torch module is supplied
    instead (the loader converts it to MLIR on the fly).
    NOTE(review): real cross-argument validation presumably belongs here —
    confirm the intended constraints before tightening.
    """
    _ = args.model  # intentional no-op read; still fails fast on a malformed namespace
41 | | - |
42 | 27 | def cli_argument_parser(): |
43 | 28 | parser = argparse.ArgumentParser() |
44 | 29 | parser.add_argument('-f', '--source_framework', |
45 | 30 | help='Source model framework (required for automatic conversion to MLIR)', |
46 | 31 | type=str, |
47 | 32 | choices=['onnx', 'pytorch'], |
48 | | - dest='source_framework') |
| 33 | + dest='source_framework') |
49 | 34 | parser.add_argument('-m', '--model', |
50 | 35 | help='Path to source framework model (.onnx, .pt),' |
51 | 36 | 'to file with compiled model (.vmfb)' |
@@ -181,96 +166,6 @@ def cli_argument_parser(): |
181 | 166 | return args |
182 | 167 |
|
183 | 168 |
|
def convert_model_to_mlir(model_path, model_weights, torch_module, model_name,
                          onnx_opset_version, source_framework, input_shape, output_mlir):
    """Convert a source-framework model (ONNX / PyTorch) to MLIR.

    Packs the conversion parameters into the dict expected by
    ``IREEConverter.get_converter`` and runs the conversion; the resulting
    MLIR is written to ``output_mlir``.
    """
    converter_params = {
        'source_framework': source_framework,
        'model_name': model_name,
        'model_path': model_path,
        'model_weights': model_weights,
        'torch_module': torch_module,
        'onnx_opset_version': onnx_opset_version,
        'input_shape': input_shape,
        'output_mlir': output_mlir,
    }
    IREEConverter.get_converter(converter_params).convert_to_mlir()
198 | | - |
199 | | - |
def compile_mlir(mlir_path, target_backend, opt_level, extra_compile_args):
    """Compile an MLIR file to a VMFB flatbuffer via ``IREECompiler``.

    Logs the failure before re-raising so the caller sees both the log
    entry and the original exception.
    """
    try:
        log.info('Starting model compilation')
        compiled_buffer = IREECompiler.compile(mlir_path, target_backend, opt_level, extra_compile_args)
    except Exception as e:
        log.error(f'Failed to compile MLIR: {e}')
        raise
    return compiled_buffer
208 | | - |
def load_model_buffer(model_path, target_backend, opt_level, extra_compile_args):
    """Return a VMFB flatbuffer for the given model file.

    ``.vmfb`` files are read as-is; ``.mlir`` files are compiled first
    (``target_backend`` is mandatory in that case).

    Raises:
        FileNotFoundError: if ``model_path`` does not exist.
        ValueError: for an unsupported extension or a missing backend.
    """
    if not os.path.exists(model_path):
        raise FileNotFoundError(f'Model file not found: {model_path}')

    file_type = model_path.split('.')[-1]

    if file_type == 'vmfb':
        # Precompiled module: just read the raw bytes.
        with open(model_path, 'rb') as model_file:
            vmfb_buffer = model_file.read()
    elif file_type == 'mlir':
        if target_backend is None:
            raise ValueError('target_backend is required for MLIR compilation')
        vmfb_buffer = compile_mlir(model_path, target_backend, opt_level, extra_compile_args)
    else:
        raise ValueError(f'The file type {file_type} is not supported. Supported types: .mlir, .vmfb')

    log.info(f'Successfully loaded model buffer from {model_path}')
    return vmfb_buffer
227 | | - |
228 | | - |
def create_iree_context_from_buffer(vmfb_buffer):
    """Build an IREE runtime context hosting the compiled VMFB module.

    Uses the 'local-task' driver; failures are logged and re-raised.
    """
    try:
        runtime_config = ireert.Config('local-task')
        module = ireert.VmModule.from_flatbuffer(runtime_config.vm_instance, vmfb_buffer)
        system_context = ireert.SystemContext(config=runtime_config)
        system_context.add_vm_module(module)
        log.info('Successfully created IREE context from buffer')
        return system_context
    except Exception as e:
        log.error(f'Failed to create IREE context: {e}')
        raise
242 | | - |
243 | | - |
def load_model(model_path, model_weights, torch_module, model_name, onnx_opset_version,
               source_framework, input_shape, target_backend, opt_level, extra_compile_args):
    """Load a model and return an IREE execution context for it.

    Accepts three kinds of input: a source-framework model (converted to a
    temporary .mlir first), a ready .mlir file (compiled to VMFB), or a
    precompiled .vmfb file.

    Returns:
        ireert.SystemContext wrapping the compiled module.
    """
    is_tmp_mlir = False
    if model_path is None or model_path.split('.')[-1] not in ['vmfb', 'mlir']:
        # Source-framework model: allocate a temp .mlir path. Close the
        # handle before the converter writes to it (writing while the
        # NamedTemporaryFile handle is open fails on Windows).
        with tempfile.NamedTemporaryFile(mode='w+t', delete=False, suffix='.mlir') as temp:
            output_mlir = temp.name
        convert_model_to_mlir(model_path,
                              model_weights,
                              torch_module,
                              model_name,
                              onnx_opset_version,
                              source_framework,
                              input_shape,
                              output_mlir)
        model_path = output_mlir
        is_tmp_mlir = True

    try:
        vmfb_buffer = load_model_buffer(
            model_path,
            target_backend=target_backend,
            opt_level=opt_level,
            extra_compile_args=extra_compile_args
        )
    finally:
        # Remove the temporary MLIR even when compilation fails
        # (previously it leaked on the error path).
        if is_tmp_mlir:
            os.remove(model_path)

    return create_iree_context_from_buffer(vmfb_buffer)
272 | | - |
273 | | - |
274 | 169 | def get_inference_function(model_context, function_name): |
275 | 170 | try: |
276 | 171 | main_module = model_context.modules.module |
@@ -323,50 +218,6 @@ def infer_slice(inference_func, slice_input): |
323 | 218 | return result |
324 | 219 |
|
325 | 220 |
|
def prepare_output(result, task):
    """Post-process raw inference output for the given task.

    Args:
        result: raw output of the IREE inference call; may be a device
            array exposing ``to_host()``, a dict of named outputs, or any
            array-like object.
        task: benchmark task name ('feedforward' or 'classification').

    Returns:
        dict: empty for 'feedforward'; for 'classification', a single-entry
        dict mapping the output name to a (batch_size, num_classes) array
        of softmax probabilities.

    Raises:
        ValueError: for an unsupported task name.
    """
    if task == 'feedforward':
        return {}
    if task != 'classification':
        raise ValueError(f'Unsupported task {task}')

    # Device arrays expose to_host(); copy to a host-side object first.
    if hasattr(result, 'to_host'):
        result = result.to_host()

    # Extract tensor from dict if needed. np.asarray also covers dict
    # values that are plain lists (previously only the non-dict branch
    # converted, so a list value crashed on .ndim below).
    if isinstance(result, dict):
        output_key = next(iter(result))
        logits = np.asarray(result[output_key])
    else:
        logits = np.asarray(result)
        output_key = 'output'

    # Ensure correct shape (batch_size, num_classes)
    if logits.ndim == 1:
        logits = logits.reshape(1, -1)
    elif logits.ndim > 2:
        logits = logits.reshape(logits.shape[0], -1)

    # Numerically stable softmax (subtract the row max before exp).
    max_logits = np.max(logits, axis=-1, keepdims=True)
    exp_logits = np.exp(logits - max_logits)
    probabilities = exp_logits / np.sum(exp_logits, axis=-1, keepdims=True)

    return {output_key: probabilities}
356 | | - |
357 | | - |
def create_dict_for_transformer(args):
    """Collect preprocessing parameters from parsed CLI args into the dict
    consumed by the transformer.

    Args:
        args: argparse.Namespace with channel_swap, mean, std, norm,
            layout, input_shape and batch_size attributes.

    Returns:
        dict mapping each parameter name to its value.
    """
    # Plain attribute access instead of getattr(args, 'name') — identical
    # behavior (both raise AttributeError on a missing attribute), clearer.
    return {
        'channel_swap': args.channel_swap,
        'mean': args.mean,
        'std': args.std,
        'norm': args.norm,
        'layout': args.layout,
        'input_shape': args.input_shape,
        'batch_size': args.batch_size,
    }
368 | | - |
369 | | - |
370 | 221 | def main(): |
371 | 222 | args = cli_argument_parser() |
372 | 223 |
|
|
0 commit comments