From 56b6195e73874ed383c993a7faf26b60fc377ab1 Mon Sep 17 00:00:00 2001 From: "dongsheng.yan" Date: Mon, 7 Jul 2025 16:54:41 +0800 Subject: [PATCH 01/10] Add onnxruntime as wasi-nn backend --- build-scripts/config_common.cmake | 7 +- core/iwasm/libraries/wasi-nn/README.md | 3 +- .../wasi-nn/cmake/Findonnxruntime.cmake | 77 ++ .../libraries/wasi-nn/cmake/wasi_nn.cmake | 28 + .../iwasm/libraries/wasi-nn/include/wasi_nn.h | 2 +- .../libraries/wasi-nn/include/wasi_nn_types.h | 2 +- core/iwasm/libraries/wasi-nn/src/wasi_nn.c | 17 +- .../libraries/wasi-nn/src/wasi_nn_onnx.cpp | 828 ++++++++++++++++++ 8 files changed, 959 insertions(+), 5 deletions(-) create mode 100644 core/iwasm/libraries/wasi-nn/cmake/Findonnxruntime.cmake create mode 100644 core/iwasm/libraries/wasi-nn/src/wasi_nn_onnx.cpp diff --git a/build-scripts/config_common.cmake b/build-scripts/config_common.cmake index d89c3a285e..366abb16ab 100644 --- a/build-scripts/config_common.cmake +++ b/build-scripts/config_common.cmake @@ -546,7 +546,8 @@ if (WAMR_BUILD_WASI_NN EQUAL 1) # Variant backends if (NOT WAMR_BUILD_WASI_NN_TFLITE EQUAL 1 AND NOT WAMR_BUILD_WASI_NN_OPENVINO EQUAL 1 AND - NOT WAMR_BUILD_WASI_NN_LLAMACPP EQUAL 1) + NOT WAMR_BUILD_WASI_NN_LLAMACPP EQUAL 1 AND + NOT WAMR_BUILD_WASI_NN_ONNX EQUAL 1) message (FATAL_ERROR " Need to select a backend for WASI-NN") endif () @@ -562,6 +563,10 @@ if (WAMR_BUILD_WASI_NN EQUAL 1) message (" WASI-NN: backend llamacpp enabled") add_definitions (-DWASM_ENABLE_WASI_NN_LLAMACPP) endif () + if (WAMR_BUILD_WASI_NN_ONNX EQUAL 1) + message (" WASI-NN: backend onnx enabled") + add_definitions (-DWASM_ENABLE_WASI_NN_ONNX) + endif () # Variant devices if (WAMR_BUILD_WASI_NN_ENABLE_GPU EQUAL 1) message (" WASI-NN: GPU enabled") diff --git a/core/iwasm/libraries/wasi-nn/README.md b/core/iwasm/libraries/wasi-nn/README.md index 2e926a0327..e16891a1ba 100644 --- a/core/iwasm/libraries/wasi-nn/README.md +++ b/core/iwasm/libraries/wasi-nn/README.md @@ -26,6 +26,7 @@ $ cmake -DWAMR_BUILD_WASI_NN=1 ... - `WAMR_BUILD_WASI_NN_TFLITE`. This option designates TensorFlow Lite as the backend. - `WAMR_BUILD_WASI_NN_OPENVINO`. This option designates OpenVINO as the backend. - `WAMR_BUILD_WASI_NN_LLAMACPP`. This option designates Llama.cpp as the backend. +- `WAMR_BUILD_WASI_NN_ONNX`. This option designates ONNX Runtime as the backend. ### Wasm @@ -151,7 +152,7 @@ docker run \ Supported: -- Graph encoding: `tensorflowlite`, `openvino` and `ggml` +- Graph encoding: `tensorflowlite`, `openvino`, `ggml` and `onnx` - Execution target: `cpu` for all. `gpu` and `tpu` for `tensorflowlite`. - Tensor type: `fp32`. diff --git a/core/iwasm/libraries/wasi-nn/cmake/Findonnxruntime.cmake b/core/iwasm/libraries/wasi-nn/cmake/Findonnxruntime.cmake new file mode 100644 index 0000000000..41df9d5770 --- /dev/null +++ b/core/iwasm/libraries/wasi-nn/cmake/Findonnxruntime.cmake @@ -0,0 +1,77 @@ +# Copyright 2025 Sony Semiconductor Solutions Corporation. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +# Find ONNX Runtime library +# +# This module defines the following variables: +# +# :: +# +# onnxruntime_FOUND - True if onnxruntime is found +# onnxruntime_INCLUDE_DIRS - Include directories for onnxruntime +# onnxruntime_LIBRARIES - List of libraries for onnxruntime +# onnxruntime_VERSION - Version of onnxruntime +# +# :: +# +# Example usage: +# +# find_package(onnxruntime) +# if(onnxruntime_FOUND) +# target_link_libraries(app onnxruntime) +# endif() + +# First try to find ONNX Runtime using the CMake config file + +# If not found via CMake config, try to find manually +find_path(onnxruntime_INCLUDE_DIR + NAMES onnxruntime_c_api.h + PATHS + /usr/include + /usr/local/include + /opt/onnxruntime/include + $ENV{ONNXRUNTIME_ROOT}/include + ${CMAKE_CURRENT_LIST_DIR}/../../../../.. +) + +find_library(onnxruntime_LIBRARY + NAMES onnxruntime + PATHS + /usr/lib + /usr/local/lib + /opt/onnxruntime/lib + $ENV{ONNXRUNTIME_ROOT}/lib + ${CMAKE_CURRENT_LIST_DIR}/../../../../.. +) + +# Try to determine version from header file +if(onnxruntime_INCLUDE_DIR) + file(STRINGS "${onnxruntime_INCLUDE_DIR}/onnxruntime_c_api.h" onnxruntime_version_str + REGEX "^#define[\t ]+ORT_API_VERSION[\t ]+[0-9]+") + + if(onnxruntime_version_str) + string(REGEX REPLACE "^#define[\t ]+ORT_API_VERSION[\t ]+([0-9]+)" "\\1" + onnxruntime_VERSION "${onnxruntime_version_str}") + endif() +endif() + +include(FindPackageHandleStandardArgs) +find_package_handle_standard_args(onnxruntime + REQUIRED_VARS onnxruntime_LIBRARY onnxruntime_INCLUDE_DIR + VERSION_VAR onnxruntime_VERSION +) + +if(onnxruntime_FOUND) + set(onnxruntime_LIBRARIES ${onnxruntime_LIBRARY}) + set(onnxruntime_INCLUDE_DIRS ${onnxruntime_INCLUDE_DIR}) + + if(NOT TARGET onnxruntime) + add_library(onnxruntime UNKNOWN IMPORTED) + set_target_properties(onnxruntime PROPERTIES + IMPORTED_LOCATION "${onnxruntime_LIBRARY}" + INTERFACE_INCLUDE_DIRECTORIES "${onnxruntime_INCLUDE_DIRS}" + ) + endif() +endif() + +mark_as_advanced(onnxruntime_INCLUDE_DIR onnxruntime_LIBRARY) diff --git a/core/iwasm/libraries/wasi-nn/cmake/wasi_nn.cmake b/core/iwasm/libraries/wasi-nn/cmake/wasi_nn.cmake index b771b1c402..370fcc468a 100644 --- a/core/iwasm/libraries/wasi-nn/cmake/wasi_nn.cmake +++ b/core/iwasm/libraries/wasi-nn/cmake/wasi_nn.cmake @@ -109,3 +109,31 @@ if(WAMR_BUILD_WASI_NN_LLAMACPP EQUAL 1) install(TARGETS wasi_nn_llamacpp DESTINATION lib) endif() + +# - onnx +if(WAMR_BUILD_WASI_NN_ONNX EQUAL 1) + find_package(onnxruntime REQUIRED) + enable_language(CXX) + + add_library( + wasi_nn_onnx + SHARED + ${WASI_NN_ROOT}/src/wasi_nn_onnx.cpp + ) + + target_include_directories( + wasi_nn_onnx + PUBLIC + ${onnxruntime_INCLUDE_DIR}/onnx + ${onnxruntime_INCLUDE_DIR} + ) + + target_link_libraries( + wasi_nn_onnx + PUBLIC + vmlib + onnxruntime + ) + + install(TARGETS wasi_nn_onnx DESTINATION lib) +endif() diff --git a/core/iwasm/libraries/wasi-nn/include/wasi_nn.h b/core/iwasm/libraries/wasi-nn/include/wasi_nn.h index cda26324eb..8576624674 100644 --- a/core/iwasm/libraries/wasi-nn/include/wasi_nn.h +++ b/core/iwasm/libraries/wasi-nn/include/wasi_nn.h @@ -21,7 +21,7 @@ #else #define WASI_NN_IMPORT(name) \ __attribute__((import_module("wasi_nn"), import_name(name))) -#warning You are using "wasi_nn", which is a legacy WAMR-specific ABI. It's deperecated and will likely be removed in future versions of WAMR. Please use "wasi_ephemeral_nn" instead. (For a WASM module, use the wasi_ephemeral_nn.h header instead. For the runtime configurations, enable WASM_ENABLE_WASI_EPHEMERAL_NN/WAMR_BUILD_WASI_EPHEMERAL_NN.) +#warning "You are using \"wasi_nn\", which is a legacy WAMR-specific ABI. It's deprecated and will likely be removed in future versions of WAMR. Please use \"wasi_ephemeral_nn\" instead. (For a WASM module, use the wasi_ephemeral_nn.h header instead. For the runtime configurations, enable WASM_ENABLE_WASI_EPHEMERAL_NN/WAMR_BUILD_WASI_EPHEMERAL_NN.)" #endif /** diff --git a/core/iwasm/libraries/wasi-nn/include/wasi_nn_types.h b/core/iwasm/libraries/wasi-nn/include/wasi_nn_types.h index 952fb65e28..23b697a8cf 100644 --- a/core/iwasm/libraries/wasi-nn/include/wasi_nn_types.h +++ b/core/iwasm/libraries/wasi-nn/include/wasi_nn_types.h @@ -27,7 +27,7 @@ extern "C" { #define WASI_NN_TYPE_NAME(name) WASI_NN_NAME(type_##name) #define WASI_NN_ENCODING_NAME(name) WASI_NN_NAME(encoding_##name) #define WASI_NN_TARGET_NAME(name) WASI_NN_NAME(target_##name) -#define WASI_NN_ERROR_TYPE WASI_NN_NAME(error); +#define WASI_NN_ERROR_TYPE WASI_NN_NAME(error) #endif /** diff --git a/core/iwasm/libraries/wasi-nn/src/wasi_nn.c b/core/iwasm/libraries/wasi-nn/src/wasi_nn.c index 7921ec9539..8781c32105 100644 --- a/core/iwasm/libraries/wasi-nn/src/wasi_nn.c +++ b/core/iwasm/libraries/wasi-nn/src/wasi_nn.c @@ -21,7 +21,8 @@ #include "wasm_export.h" #if WASM_ENABLE_WASI_EPHEMERAL_NN == 0 -#warning You are using "wasi_nn", which is a legacy WAMR-specific ABI. It's deperecated and will likely be removed in future versions of WAMR. Please use "wasi_ephemeral_nn" instead. (For a WASM module, use the wasi_ephemeral_nn.h header instead. For the runtime configurations, enable WASM_ENABLE_WASI_EPHEMERAL_NN/WAMR_BUILD_WASI_EPHEMERAL_NN.) +#warning \ + "You are using \"wasi_nn\", which is a legacy WAMR-specific ABI. It's deprecated and will likely be removed in future versions of WAMR. Please use \"wasi_ephemeral_nn\" instead. (For a WASM module, use the wasi_ephemeral_nn.h header instead. For the runtime configurations, enable WASM_ENABLE_WASI_EPHEMERAL_NN/WAMR_BUILD_WASI_EPHEMERAL_NN.)" #endif #define HASHMAP_INITIAL_SIZE 20 @@ -33,6 +34,7 @@ #define TFLITE_BACKEND_LIB "libwasi_nn_tflite" LIB_EXTENTION #define OPENVINO_BACKEND_LIB "libwasi_nn_openvino" LIB_EXTENTION #define LLAMACPP_BACKEND_LIB "libwasi_nn_llamacpp" LIB_EXTENTION +#define ONNX_BACKEND_LIB "libwasi_nn_onnx" LIB_EXTENTION /* Global variables */ static korp_mutex wasi_nn_lock; @@ -240,6 +242,17 @@ choose_a_backend() return openvino; } +#ifndef NDEBUG + NN_WARN_PRINTF("%s", dlerror()); +#endif + + handle = dlopen(ONNX_BACKEND_LIB, RTLD_LAZY); + if (handle) { + NN_INFO_PRINTF("Using onnx backend"); + dlclose(handle); + return onnx; + } + #ifndef NDEBUG NN_WARN_PRINTF("%s", dlerror()); #endif @@ -363,6 +376,8 @@ graph_encoding_to_backend_lib_name(graph_encoding encoding) return TFLITE_BACKEND_LIB; case ggml: return LLAMACPP_BACKEND_LIB; + case onnx: + return ONNX_BACKEND_LIB; default: return NULL; } diff --git a/core/iwasm/libraries/wasi-nn/src/wasi_nn_onnx.cpp b/core/iwasm/libraries/wasi-nn/src/wasi_nn_onnx.cpp new file mode 100644 index 0000000000..ab9042d013 --- /dev/null +++ b/core/iwasm/libraries/wasi-nn/src/wasi_nn_onnx.cpp @@ -0,0 +1,828 @@ +/* + * Copyright 2025 Sony Semiconductor Solutions Corporation. + * SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + */ + +#include +#include +#include +#include +#include +#include +#include "bh_platform.h" +#include "wasi_nn_backend.h" +#include "utils/logger.h" +#include "onnxruntime_c_api.h" + +/* Maximum number of graphs and execution contexts */ +#define MAX_GRAPHS 10 +#define MAX_CONTEXTS 10 + +/* ONNX Runtime context structure */ +typedef struct { + OrtEnv *env; + OrtSessionOptions *session_options; + OrtAllocator *allocator; + const OrtApi *ort_api; + std::mutex mutex; + bool is_initialized; +} OnnxRuntimeContext; + +/* Graph structure */ +typedef struct { + OrtSession *session; + bool is_initialized; +} OnnxRuntimeGraph; + +/* Execution context structure */ +typedef struct { + OrtMemoryInfo *memory_info; + std::vector input_names; + std::vector output_names; + std::unordered_map inputs; + std::unordered_map outputs; + OnnxRuntimeGraph *graph; + bool is_initialized; +} OnnxRuntimeExecCtx; + +/* Global variables */ +static OnnxRuntimeContext g_ort_ctx; +static OnnxRuntimeGraph g_graphs[MAX_GRAPHS]; +static OnnxRuntimeExecCtx g_exec_ctxs[MAX_CONTEXTS]; + +/* Helper functions */ +static void +check_status_and_log(OrtStatus *status) +{ + if (status != nullptr) { + const char *msg = g_ort_ctx.ort_api->GetErrorMessage(status); + NN_ERR_PRINTF("ONNX Runtime error: %s", msg); + g_ort_ctx.ort_api->ReleaseStatus(status); + } +} + +static wasi_nn_error +convert_ort_error_to_wasi_nn_error(OrtStatus *status) +{ + if (status == nullptr) { + return success; + } + + wasi_nn_error err; + OrtErrorCode code = g_ort_ctx.ort_api->GetErrorCode(status); + const char *msg = g_ort_ctx.ort_api->GetErrorMessage(status); + + NN_ERR_PRINTF("ONNX Runtime error: %s", msg); + + switch (code) { + case ORT_INVALID_ARGUMENT: + err = invalid_argument; + break; + case ORT_RUNTIME_EXCEPTION: + err = runtime_error; + break; + case ORT_NOT_IMPLEMENTED: + err = unsupported_operation; + break; + case ORT_INVALID_PROTOBUF: + err = invalid_encoding; + break; + case ORT_MODEL_LOADED: + err = too_large; + break; + case ORT_INVALID_GRAPH: + err = invalid_encoding; + break; + default: + err = runtime_error; + break; + } + + g_ort_ctx.ort_api->ReleaseStatus(status); + return err; +} + +static tensor_type +convert_ort_type_to_wasi_nn_type(ONNXTensorElementDataType ort_type) +{ + switch (ort_type) { + case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: + return fp32; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16: + return fp16; +#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0 + case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE: + return fp64; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8: + return u8; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32: + return i32; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64: + return i64; +#else + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8: + return up8; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32: + return ip32; +#endif + default: + NN_WARN_PRINTF("Unsupported ONNX tensor type: %d", ort_type); + return fp32; // Default to fp32 + } +} + +static ONNXTensorElementDataType +convert_wasi_nn_type_to_ort_type(tensor_type type) +{ + switch (type) { + case fp32: + return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; + case fp16: + return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16; +#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0 + case fp64: + return ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE; + case u8: + return ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8; + case i32: + return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32; + case i64: + return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64; +#else + case up8: + return ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8; + case ip32: + return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32; +#endif + default: + NN_WARN_PRINTF("Unsupported wasi-nn tensor type: %d", type); + return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; // Default to float + } +} + +static size_t +get_tensor_element_size(tensor_type type) +{ + switch (type) { + case fp32: + return 4; + case fp16: + return 2; +#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0 + case fp64: + return 8; + case u8: + return 1; + case i32: + return 4; + case i64: + return 8; +#else + case up8: + return 1; + case ip32: + return 4; +#endif + default: + NN_WARN_PRINTF("Unsupported tensor type: %d", type); + return 4; // Default to 4 bytes (float) + } +} + +/* Backend API implementation */ + +extern "C" { + +__attribute__((visibility("default"))) wasi_nn_error +init_backend(void **onnx_ctx) +{ + std::lock_guard lock(g_ort_ctx.mutex); + + if (g_ort_ctx.is_initialized) { + *onnx_ctx = &g_ort_ctx; + return success; + } + + g_ort_ctx.ort_api = OrtGetApiBase()->GetApi(ORT_API_VERSION); + if (!g_ort_ctx.ort_api) { + NN_ERR_PRINTF("Failed to get ONNX Runtime API"); + return runtime_error; + } + + NN_INFO_PRINTF("Creating ONNX Runtime environment..."); + OrtStatus *status = g_ort_ctx.ort_api->CreateEnv(ORT_LOGGING_LEVEL_VERBOSE, + "wasi-nn", &g_ort_ctx.env); + if (status != nullptr) { + const char *error_message = g_ort_ctx.ort_api->GetErrorMessage(status); + wasi_nn_error err = convert_ort_error_to_wasi_nn_error(status); + NN_ERR_PRINTF("Failed to create ONNX Runtime environment: %s", + error_message); + g_ort_ctx.ort_api->ReleaseStatus(status); + return err; + } + NN_INFO_PRINTF("ONNX Runtime environment created successfully"); + + status = + g_ort_ctx.ort_api->CreateSessionOptions(&g_ort_ctx.session_options); + if (status != nullptr) { + wasi_nn_error err = convert_ort_error_to_wasi_nn_error(status); + g_ort_ctx.ort_api->ReleaseEnv(g_ort_ctx.env); + NN_ERR_PRINTF("Failed to create ONNX Runtime session options"); + return err; + } + + status = g_ort_ctx.ort_api->SetSessionGraphOptimizationLevel( + g_ort_ctx.session_options, ORT_ENABLE_BASIC); + if (status != nullptr) { + wasi_nn_error err = convert_ort_error_to_wasi_nn_error(status); + g_ort_ctx.ort_api->ReleaseSessionOptions(g_ort_ctx.session_options); + g_ort_ctx.ort_api->ReleaseEnv(g_ort_ctx.env); + NN_ERR_PRINTF("Failed to set graph optimization level"); + return err; + } + + status = + g_ort_ctx.ort_api->GetAllocatorWithDefaultOptions(&g_ort_ctx.allocator); + if (status != nullptr) { + wasi_nn_error err = convert_ort_error_to_wasi_nn_error(status); + g_ort_ctx.ort_api->ReleaseSessionOptions(g_ort_ctx.session_options); + g_ort_ctx.ort_api->ReleaseEnv(g_ort_ctx.env); + NN_ERR_PRINTF("Failed to get default allocator"); + return err; + } + + for (int i = 0; i < MAX_GRAPHS; i++) { + g_graphs[i].is_initialized = false; + g_graphs[i].session = nullptr; + } + + for (int i = 0; i < MAX_CONTEXTS; i++) { + g_exec_ctxs[i].is_initialized = false; + g_exec_ctxs[i].memory_info = nullptr; + g_exec_ctxs[i].graph = nullptr; + g_exec_ctxs[i].input_names.clear(); + g_exec_ctxs[i].output_names.clear(); + g_exec_ctxs[i].inputs.clear(); + g_exec_ctxs[i].outputs.clear(); + } + + g_ort_ctx.is_initialized = true; + *onnx_ctx = &g_ort_ctx; + + NN_INFO_PRINTF("ONNX Runtime backend initialized"); + return success; +} + +__attribute__((visibility("default"))) wasi_nn_error +deinit_backend(void *onnx_ctx) +{ + OnnxRuntimeContext *ctx = (OnnxRuntimeContext *)onnx_ctx; + std::lock_guard lock(ctx->mutex); + + if (!ctx->is_initialized) { + return success; + } + + for (int i = 0; i < MAX_GRAPHS; i++) { + if (g_graphs[i].is_initialized) { + ctx->ort_api->ReleaseSession(g_graphs[i].session); + g_graphs[i].is_initialized = false; + } + } + + for (int i = 0; i < MAX_CONTEXTS; i++) { + if (g_exec_ctxs[i].is_initialized) { + for (auto &input : g_exec_ctxs[i].inputs) { + ctx->ort_api->ReleaseValue(input.second); + } + for (auto &output : g_exec_ctxs[i].outputs) { + ctx->ort_api->ReleaseValue(output.second); + } + ctx->ort_api->ReleaseMemoryInfo(g_exec_ctxs[i].memory_info); + g_exec_ctxs[i].is_initialized = false; + } + } + + ctx->ort_api->ReleaseSessionOptions(ctx->session_options); + ctx->ort_api->ReleaseEnv(ctx->env); + ctx->is_initialized = false; + + NN_INFO_PRINTF("ONNX Runtime backend deinitialized"); + return success; +} + +__attribute__((visibility("default"))) wasi_nn_error +load(void *onnx_ctx, graph_builder_array *builder, graph_encoding encoding, + execution_target target, graph *g) +{ + if (encoding != onnx) { + NN_ERR_PRINTF("Unsupported encoding: %d", encoding); + return invalid_encoding; + } + + if (target != cpu) { + NN_ERR_PRINTF("Only CPU target is supported"); + return unsupported_operation; + } + + OnnxRuntimeContext *ctx = (OnnxRuntimeContext *)onnx_ctx; + std::lock_guard lock(ctx->mutex); + + int graph_index = -1; + for (int i = 0; i < MAX_GRAPHS; i++) { + if (!g_graphs[i].is_initialized) { + graph_index = i; + break; + } + } + + if (graph_index == -1) { + NN_ERR_PRINTF("Maximum number of graphs reached"); + return runtime_error; + } + + if (builder->size == 0 || builder->buf == NULL) { + NN_ERR_PRINTF("No model data provided"); + return invalid_argument; + } + + NN_INFO_PRINTF("[ONNX Runtime] Loading model of size %zu bytes...", + builder->buf[0].size); + + if (builder->buf[0].size > 16) { + NN_INFO_PRINTF( + "Model header bytes: %02x %02x %02x %02x %02x %02x %02x %02x", + ((uint8_t *)builder->buf[0].buf)[0], + ((uint8_t *)builder->buf[0].buf)[1], + ((uint8_t *)builder->buf[0].buf)[2], + ((uint8_t *)builder->buf[0].buf)[3], + ((uint8_t *)builder->buf[0].buf)[4], + ((uint8_t *)builder->buf[0].buf)[5], + ((uint8_t *)builder->buf[0].buf)[6], + ((uint8_t *)builder->buf[0].buf)[7]); + } + + OrtStatus *status = ctx->ort_api->CreateSessionFromArray( + ctx->env, builder->buf[0].buf, builder->buf[0].size, + ctx->session_options, &g_graphs[graph_index].session); + + if (status != nullptr) { + const char *error_message = ctx->ort_api->GetErrorMessage(status); + wasi_nn_error err = convert_ort_error_to_wasi_nn_error(status); + NN_ERR_PRINTF("Failed to create ONNX Runtime session: %s", + error_message); + ctx->ort_api->ReleaseStatus(status); + return err; + } + + NN_INFO_PRINTF("ONNX Runtime session created successfully"); + + g_graphs[graph_index].is_initialized = true; + *g = graph_index; + + NN_INFO_PRINTF("ONNX model loaded as graph %d", graph_index); + return success; +} + +__attribute__((visibility("default"))) wasi_nn_error +load_by_name(void *onnx_ctx, const char *name, uint32_t filename_len, graph *g) +{ + OnnxRuntimeContext *ctx = (OnnxRuntimeContext *)onnx_ctx; + std::lock_guard lock(ctx->mutex); + + int graph_index = -1; + for (int i = 0; i < MAX_GRAPHS; i++) { + if (!g_graphs[i].is_initialized) { + graph_index = i; + break; + } + } + + if (graph_index == -1) { + NN_ERR_PRINTF("Maximum number of graphs reached"); + return runtime_error; + } + + OrtStatus *status = ctx->ort_api->CreateSession( + ctx->env, name, ctx->session_options, &g_graphs[graph_index].session); + + if (status != nullptr) { + wasi_nn_error err = convert_ort_error_to_wasi_nn_error(status); + NN_ERR_PRINTF("Failed to create ONNX Runtime session from file: %s", + name); + return err; + } + + g_graphs[graph_index].is_initialized = true; + *g = graph_index; + + NN_INFO_PRINTF("ONNX model loaded from file %s as graph %d", name, + graph_index); + return success; +} + +__attribute__((visibility("default"))) wasi_nn_error +init_execution_context(void *onnx_ctx, graph g, graph_execution_context *ctx) +{ + if (g >= MAX_GRAPHS || !g_graphs[g].is_initialized) { + NN_ERR_PRINTF("Invalid graph handle: %d", g); + return invalid_argument; + } + + OnnxRuntimeContext *ort_ctx = (OnnxRuntimeContext *)onnx_ctx; + std::lock_guard lock(ort_ctx->mutex); + + int ctx_index = -1; + for (int i = 0; i < MAX_CONTEXTS; i++) { + if (!g_exec_ctxs[i].is_initialized) { + ctx_index = i; + break; + } + } + + if (ctx_index == -1) { + NN_ERR_PRINTF("Maximum number of execution contexts reached"); + return runtime_error; + } + + OnnxRuntimeExecCtx *exec_ctx = &g_exec_ctxs[ctx_index]; + exec_ctx->graph = &g_graphs[g]; + + OrtStatus *status = ort_ctx->ort_api->CreateCpuMemoryInfo( + OrtArenaAllocator, OrtMemTypeDefault, &exec_ctx->memory_info); + if (status != nullptr) { + wasi_nn_error err = convert_ort_error_to_wasi_nn_error(status); + NN_ERR_PRINTF("Failed to create CPU memory info"); + return err; + } + + size_t num_input_nodes; + status = ort_ctx->ort_api->SessionGetInputCount(exec_ctx->graph->session, + &num_input_nodes); + if (status != nullptr) { + wasi_nn_error err = convert_ort_error_to_wasi_nn_error(status); + ort_ctx->ort_api->ReleaseMemoryInfo(exec_ctx->memory_info); + NN_ERR_PRINTF("Failed to get input count"); + return err; + } + + for (size_t i = 0; i < num_input_nodes; i++) { + char *input_name; + status = ort_ctx->ort_api->SessionGetInputName( + exec_ctx->graph->session, i, ort_ctx->allocator, &input_name); + if (status != nullptr) { + wasi_nn_error err = convert_ort_error_to_wasi_nn_error(status); + ort_ctx->ort_api->ReleaseMemoryInfo(exec_ctx->memory_info); + NN_ERR_PRINTF("Failed to get input name"); + return err; + } + exec_ctx->input_names.push_back(input_name); + } + + size_t num_output_nodes; + status = ort_ctx->ort_api->SessionGetOutputCount(exec_ctx->graph->session, + &num_output_nodes); + if (status != nullptr) { + wasi_nn_error err = convert_ort_error_to_wasi_nn_error(status); + ort_ctx->ort_api->ReleaseMemoryInfo(exec_ctx->memory_info); + for (const char *name : exec_ctx->input_names) { + ort_ctx->allocator->Free(ort_ctx->allocator, (void *)name); + } + NN_ERR_PRINTF("Failed to get output count"); + return err; + } + + for (size_t i = 0; i < num_output_nodes; i++) { + char *output_name; + status = ort_ctx->ort_api->SessionGetOutputName( + exec_ctx->graph->session, i, ort_ctx->allocator, &output_name); + if (status != nullptr) { + wasi_nn_error err = convert_ort_error_to_wasi_nn_error(status); + ort_ctx->ort_api->ReleaseMemoryInfo(exec_ctx->memory_info); + for (const char *name : exec_ctx->input_names) { + ort_ctx->allocator->Free(ort_ctx->allocator, (void *)name); + } + NN_ERR_PRINTF("Failed to get output name"); + return err; + } + exec_ctx->output_names.push_back(output_name); + } + + exec_ctx->is_initialized = true; + *ctx = ctx_index; + + NN_INFO_PRINTF("Execution context %d initialized for graph %d", ctx_index, + g); + return success; +} + +__attribute__((visibility("default"))) wasi_nn_error +set_input(void *onnx_ctx, graph_execution_context ctx, uint32_t index, + tensor *input_tensor) +{ + if (ctx >= MAX_CONTEXTS || !g_exec_ctxs[ctx].is_initialized) { + NN_ERR_PRINTF("Invalid execution context handle: %d", ctx); + return invalid_argument; + } + + if (index >= g_exec_ctxs[ctx].input_names.size()) { + NN_ERR_PRINTF("Invalid input index: %d (max: %zu)", index, + g_exec_ctxs[ctx].input_names.size() - 1); + return invalid_argument; + } + + OnnxRuntimeContext *ort_ctx = (OnnxRuntimeContext *)onnx_ctx; + std::lock_guard lock(ort_ctx->mutex); + OnnxRuntimeExecCtx *exec_ctx = &g_exec_ctxs[ctx]; + + OrtTypeInfo *type_info = nullptr; + OrtStatus *status = ort_ctx->ort_api->SessionGetInputTypeInfo( + exec_ctx->graph->session, index, &type_info); + if (status != nullptr) { + ort_ctx->ort_api->ReleaseTypeInfo(type_info); + return runtime_error; + } + + const OrtTensorTypeAndShapeInfo *tensor_info; + status = + ort_ctx->ort_api->CastTypeInfoToTensorInfo(type_info, &tensor_info); + if (status != nullptr) { + ort_ctx->ort_api->ReleaseTypeInfo(type_info); + return runtime_error; + } + + size_t num_model_dims; + status = ort_ctx->ort_api->GetDimensionsCount(tensor_info, &num_model_dims); + std::vector model_dims(num_model_dims); + status = ort_ctx->ort_api->GetDimensions(tensor_info, model_dims.data(), + num_model_dims); + + size_t model_tensor_size = 1; + for (size_t i = 0; i < num_model_dims; ++i) + model_tensor_size *= model_dims[i]; + + size_t input_tensor_size = 1; + for (size_t i = 0; i < input_tensor->dimensions->size; ++i) + input_tensor_size *= input_tensor->dimensions->buf[i]; + + void *input_tensor_data = input_tensor->data.buf; + void *input_tensor_scaled_data = NULL; + ort_ctx->ort_api->ReleaseTypeInfo(type_info); + size_t num_dims = input_tensor->dimensions->size; + int64_t *ort_dims = (int64_t *)malloc(num_dims * sizeof(int64_t)); + if (!ort_dims) { + NN_ERR_PRINTF("Failed to allocate memory for tensor dimensions"); + return runtime_error; + } + + for (size_t i = 0; i < num_dims; i++) { + ort_dims[i] = input_tensor->dimensions->buf[i]; + } + + ONNXTensorElementDataType ort_type = convert_wasi_nn_type_to_ort_type( + static_cast(input_tensor->type)); + + OrtValue *input_value = nullptr; + size_t total_elements = 1; + for (size_t i = 0; i < num_dims; i++) { + total_elements *= input_tensor->dimensions->buf[i]; + } + + status = ort_ctx->ort_api->CreateTensorWithDataAsOrtValue( + exec_ctx->memory_info, input_tensor->data.buf, + get_tensor_element_size(static_cast(input_tensor->type)) + * total_elements, + ort_dims, num_dims, ort_type, &input_value); + + free(ort_dims); + + if (status != nullptr) { + wasi_nn_error err = convert_ort_error_to_wasi_nn_error(status); + NN_ERR_PRINTF("Failed to create input tensor"); + return err; + } + + if (exec_ctx->inputs.count(index) > 0) { + ort_ctx->ort_api->ReleaseValue(exec_ctx->inputs[index]); + } + exec_ctx->inputs[index] = input_value; + + NN_INFO_PRINTF("Input tensor set for context %d, index %d", ctx, index); + return success; +} + +__attribute__((visibility("default"))) wasi_nn_error +compute(void *onnx_ctx, graph_execution_context ctx) +{ + if (ctx >= MAX_CONTEXTS || !g_exec_ctxs[ctx].is_initialized) { + NN_ERR_PRINTF("Invalid execution context handle: %d", ctx); + return invalid_argument; + } + + OnnxRuntimeContext *ort_ctx = (OnnxRuntimeContext *)onnx_ctx; + std::lock_guard lock(ort_ctx->mutex); + OnnxRuntimeExecCtx *exec_ctx = &g_exec_ctxs[ctx]; + + std::vector input_values; + std::vector input_names; + + for (size_t i = 0; i < exec_ctx->input_names.size(); i++) { + if (exec_ctx->inputs.count(i) == 0) { + NN_ERR_PRINTF("Input tensor not set for index %zu", i); + return invalid_argument; + } + input_values.push_back(exec_ctx->inputs[i]); + input_names.push_back(exec_ctx->input_names[i]); + } + + for (auto &output : exec_ctx->outputs) { + ort_ctx->ort_api->ReleaseValue(output.second); + } + exec_ctx->outputs.clear(); + + std::vector output_values(exec_ctx->output_names.size()); + + OrtStatus *status = ort_ctx->ort_api->Run( + exec_ctx->graph->session, nullptr, input_names.data(), + input_values.data(), input_values.size(), exec_ctx->output_names.data(), + exec_ctx->output_names.size(), output_values.data()); + + for (size_t i = 0; i < output_values.size(); i++) { + exec_ctx->outputs[i] = output_values[i]; + } + + if (status != nullptr) { + wasi_nn_error err = convert_ort_error_to_wasi_nn_error(status); + NN_ERR_PRINTF("Failed to run inference"); + return err; + } + + NN_INFO_PRINTF("Inference computed for context %d", ctx); + return success; +} + +__attribute__((visibility("default"))) wasi_nn_error +get_output(void *onnx_ctx, graph_execution_context ctx, uint32_t index, + tensor_data *out_buffer, uint32_t *out_buffer_size) +{ + if (ctx >= MAX_CONTEXTS || !g_exec_ctxs[ctx].is_initialized) { + NN_ERR_PRINTF("Invalid execution context handle: %d", ctx); + return invalid_argument; + } + + if (index >= g_exec_ctxs[ctx].output_names.size()) { + NN_ERR_PRINTF("Invalid output index: %d (max: %zu)", index, + g_exec_ctxs[ctx].output_names.size() - 1); + return invalid_argument; + } + + OnnxRuntimeContext *ort_ctx = (OnnxRuntimeContext *)onnx_ctx; + std::lock_guard lock(ort_ctx->mutex); + OnnxRuntimeExecCtx *exec_ctx = &g_exec_ctxs[ctx]; + + OrtValue *output_value = exec_ctx->outputs[index]; + if (!output_value) { + NN_ERR_PRINTF("Output tensor not available for index %d", index); + return runtime_error; + } + + OrtTensorTypeAndShapeInfo *tensor_info; + OrtStatus *status = + ort_ctx->ort_api->GetTensorTypeAndShape(output_value, &tensor_info); + if (status != nullptr) { + wasi_nn_error err = convert_ort_error_to_wasi_nn_error(status); + NN_ERR_PRINTF("Failed to get tensor type and shape"); + return err; + } + + ONNXTensorElementDataType element_type; + status = ort_ctx->ort_api->GetTensorElementType(tensor_info, &element_type); + if (status != nullptr) { + wasi_nn_error err = convert_ort_error_to_wasi_nn_error(status); + ort_ctx->ort_api->ReleaseTensorTypeAndShapeInfo(tensor_info); + NN_ERR_PRINTF("Failed to get tensor element type"); + return err; + } + + size_t num_dims; + status = ort_ctx->ort_api->GetDimensionsCount(tensor_info, &num_dims); + if (status != nullptr) { + wasi_nn_error err = convert_ort_error_to_wasi_nn_error(status); + ort_ctx->ort_api->ReleaseTensorTypeAndShapeInfo(tensor_info); + NN_ERR_PRINTF("Failed to get tensor dimensions count"); + return err; + } + + int64_t *dims = (int64_t *)malloc(num_dims * sizeof(int64_t)); + if (!dims) { + ort_ctx->ort_api->ReleaseTensorTypeAndShapeInfo(tensor_info); + NN_ERR_PRINTF("Failed to allocate memory for tensor dimensions"); + return runtime_error; + } + + status = ort_ctx->ort_api->GetDimensions(tensor_info, dims, num_dims); + if (status != nullptr) { + wasi_nn_error err = convert_ort_error_to_wasi_nn_error(status); + free(dims); + ort_ctx->ort_api->ReleaseTensorTypeAndShapeInfo(tensor_info); + NN_ERR_PRINTF("Failed to get tensor dimensions"); + return err; + } + + size_t tensor_size; + status = + ort_ctx->ort_api->GetTensorShapeElementCount(tensor_info, &tensor_size); + if (status != nullptr) { + wasi_nn_error err = convert_ort_error_to_wasi_nn_error(status); + free(dims); + ort_ctx->ort_api->ReleaseTensorTypeAndShapeInfo(tensor_info); + NN_ERR_PRINTF("Failed to get tensor element count"); + return err; + } + + NN_INFO_PRINTF("Output tensor dimensions: "); + for (size_t i = 0; i < num_dims; i++) { + NN_INFO_PRINTF(" dim[%zu] = %lld", i, dims[i]); + } + NN_INFO_PRINTF("Total elements: %zu", tensor_size); + + ort_ctx->ort_api->ReleaseTensorTypeAndShapeInfo(tensor_info); + free(dims); + + if (tensor_size == 0) { + NN_ERR_PRINTF("Tensor is empty (zero elements)"); + return runtime_error; + } + + void *tensor_data = nullptr; + status = ort_ctx->ort_api->GetTensorMutableData(output_value, &tensor_data); + if (status != nullptr) { + wasi_nn_error err = convert_ort_error_to_wasi_nn_error(status); + NN_ERR_PRINTF("Failed to get tensor data"); + return err; + } + + if (tensor_data == nullptr) { + NN_ERR_PRINTF("Tensor data pointer is null"); + return runtime_error; + } + + size_t element_size; + switch (element_type) { + case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: + element_size = sizeof(float); + break; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16: + element_size = sizeof(uint16_t); + break; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE: + element_size = sizeof(double); + break; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32: + element_size = sizeof(int32_t); + break; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64: + element_size = sizeof(int64_t); + break; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8: + element_size = sizeof(uint8_t); + break; + default: + NN_ERR_PRINTF("Unsupported tensor element type: %d", element_type); + return unsupported_operation; + } + + size_t output_size_bytes = tensor_size * element_size; + + NN_INFO_PRINTF("Output tensor size: %zu elements, element size: %zu bytes, " + "total: %zu bytes", + tensor_size, element_size, output_size_bytes); + + if (*out_buffer_size < output_size_bytes) { + NN_ERR_PRINTF( + "Output buffer too small: %u bytes provided, %zu bytes needed", + *out_buffer_size, output_size_bytes); + *out_buffer_size = output_size_bytes; + return invalid_argument; + } + + if (tensor_data == nullptr) { + NN_ERR_PRINTF("Tensor data is null"); + return runtime_error; + } + + if (out_buffer->buf == nullptr) { + NN_ERR_PRINTF("Output buffer is null"); + return invalid_argument; + } + + memcpy(out_buffer->buf, tensor_data, output_size_bytes); + *out_buffer_size = output_size_bytes; + + NN_INFO_PRINTF( + "Output tensor retrieved for context %d, index %d, size %zu bytes", ctx, + index, output_size_bytes); + return success; +} + +} /* End of extern "C" */ From 6cab94c38253f6483d4bf4d7eaf2882a63d0b668 Mon Sep 17 00:00:00 2001 From: "dongsheng.yan" Date: Thu, 17 Jul 2025 14:20:59 +0800 Subject: [PATCH 02/10] follow up some review comments 1, type converter btw wasi-nn and onnx runtime returns bool instead of type 2, out_buffer_size does not hold the expected size. 3, onnx runtime does not need calculate input_tenser size. --- .../iwasm/libraries/wasi-nn/include/wasi_nn.h | 3 +- .../libraries/wasi-nn/src/wasi_nn_onnx.cpp | 118 ++++++++---------- 2 files changed, 56 insertions(+), 65 deletions(-) diff --git a/core/iwasm/libraries/wasi-nn/include/wasi_nn.h b/core/iwasm/libraries/wasi-nn/include/wasi_nn.h index 8576624674..56ae36d2ed 100644 --- a/core/iwasm/libraries/wasi-nn/include/wasi_nn.h +++ b/core/iwasm/libraries/wasi-nn/include/wasi_nn.h @@ -21,7 +21,8 @@ #else #define WASI_NN_IMPORT(name) \ __attribute__((import_module("wasi_nn"), import_name(name))) -#warning "You are using \"wasi_nn\", which is a legacy WAMR-specific ABI. It's deprecated and will likely be removed in future versions of WAMR. Please use \"wasi_ephemeral_nn\" instead. (For a WASM module, use the wasi_ephemeral_nn.h header instead. For the runtime configurations, enable WASM_ENABLE_WASI_EPHEMERAL_NN/WAMR_BUILD_WASI_EPHEMERAL_NN.)" +#warning \ + "You are using \"wasi_nn\", which is a legacy WAMR-specific ABI. It's deprecated and will likely be removed in future versions of WAMR. Please use \"wasi_ephemeral_nn\" instead. (For a WASM module, use the wasi_ephemeral_nn.h header instead. For the runtime configurations, enable WASM_ENABLE_WASI_EPHEMERAL_NN/WAMR_BUILD_WASI_EPHEMERAL_NN.)" #endif /** diff --git a/core/iwasm/libraries/wasi-nn/src/wasi_nn_onnx.cpp b/core/iwasm/libraries/wasi-nn/src/wasi_nn_onnx.cpp index ab9042d013..4121439625 100644 --- a/core/iwasm/libraries/wasi-nn/src/wasi_nn_onnx.cpp +++ b/core/iwasm/libraries/wasi-nn/src/wasi_nn_onnx.cpp @@ -102,91 +102,81 @@ convert_ort_error_to_wasi_nn_error(OrtStatus *status) return err; } -static tensor_type -convert_ort_type_to_wasi_nn_type(ONNXTensorElementDataType ort_type) +static bool +convert_ort_type_to_wasi_nn_type(ONNXTensorElementDataType ort_type, tensor_type *tensor_type) { switch (ort_type) { case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: - return fp32; + *tensor_type = fp32; + break; case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16: - return fp16; + *tensor_type = fp16; + break; #if WASM_ENABLE_WASI_EPHEMERAL_NN != 0 case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE: - return fp64; + *tensor_type = fp64; + break; case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8: - return u8; + *tensor_type = u8; + break; case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32: - return i32; + *tensor_type = i32; + break; case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64: - return i64; + *tensor_type = i64; + break; #else case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8: - return up8; + *tensor_type = up8; + break; case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32: - return ip32; + *tensor_type = ip32; + break; #endif default: NN_WARN_PRINTF("Unsupported ONNX tensor type: %d", ort_type); - return fp32; // Default to fp32 + return false; } -} -static ONNXTensorElementDataType -convert_wasi_nn_type_to_ort_type(tensor_type type) -{ - switch (type) { - case fp32: - return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; - case fp16: - return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16; -#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0 - case fp64: - return ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE; - case u8: - return ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8; - case i32: - return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32; - case i64: - return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64; -#else - case up8: - return ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8; - case ip32: - return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32; -#endif - default: - NN_WARN_PRINTF("Unsupported wasi-nn tensor type: %d", type); - return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; // Default to float - } + return true; } -static size_t -get_tensor_element_size(tensor_type type) +static bool +convert_wasi_nn_type_to_ort_type(tensor_type type, ONNXTensorElementDataType *ort_type) { switch (type) { case fp32: - return 4; + *ort_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; + break; case fp16: - return 2; + *ort_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16; + break; #if WASM_ENABLE_WASI_EPHEMERAL_NN != 0 case fp64: - return 8; + *ort_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE; + break; case u8: - return 1; + *ort_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8; + break; case i32: - return 4; + *ort_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32; + break; case i64: - return 8; + *ort_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64; + break; #else case up8: - return 1; + *ort_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8; + break; case ip32: - return 4; + *ort_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32; + break; #endif default: - NN_WARN_PRINTF("Unsupported tensor type: %d", type); - return 4; // Default to 4 bytes (float) + NN_WARN_PRINTF("Unsupported wasi-nn tensor type: %d", type); + return false; // Default to float } + return true; } /* Backend API implementation */ @@ -579,8 +569,12 @@ set_input(void *onnx_ctx, graph_execution_context ctx, uint32_t index, ort_dims[i] = input_tensor->dimensions->buf[i]; } - ONNXTensorElementDataType ort_type = convert_wasi_nn_type_to_ort_type( - static_cast(input_tensor->type)); + ONNXTensorElementDataType ort_type; + if (!convert_wasi_nn_type_to_ort_type( + static_cast(input_tensor->type), &ort_type)) { + NN_ERR_PRINTF("Failed to convert tensor type"); + return runtime_error; + } OrtValue *input_value = nullptr; size_t total_elements = 1; @@ -589,9 +583,7 @@ set_input(void *onnx_ctx, graph_execution_context ctx, uint32_t index, } status = ort_ctx->ort_api->CreateTensorWithDataAsOrtValue( - exec_ctx->memory_info, input_tensor->data.buf, - get_tensor_element_size(static_cast(input_tensor->type)) - * total_elements, + exec_ctx->memory_info, input_tensor->data.buf,input_tensor->data.size, ort_dims, num_dims, ort_type, &input_value); free(ort_dims); @@ -793,18 +785,16 @@ get_output(void *onnx_ctx, graph_execution_context ctx, uint32_t index, } size_t output_size_bytes = tensor_size * element_size; - - NN_INFO_PRINTF("Output tensor size: %zu elements, element size: %zu bytes, " - "total: %zu bytes", - tensor_size, element_size, output_size_bytes); - - if (*out_buffer_size < output_size_bytes) { + if (out_buffer->size < output_size_bytes) { NN_ERR_PRINTF( "Output buffer too small: %u bytes provided, %zu bytes needed", - *out_buffer_size, output_size_bytes); + out_buffer->size, output_size_bytes); *out_buffer_size = output_size_bytes; - return invalid_argument; + return too_large; } + NN_INFO_PRINTF("Output tensor size: %zu elements, element size: %zu bytes, " + "total: %zu bytes", + tensor_size, element_size, output_size_bytes); if (tensor_data == nullptr) { NN_ERR_PRINTF("Tensor data is null"); From c866d0569587731bd85b1a3ccac1228180509e0c Mon Sep 17 00:00:00 2001 From: "dongsheng.yan" Date: Thu, 17 Jul 2025 14:30:39 +0800 Subject: [PATCH 03/10] clang-format --- core/iwasm/libraries/wasi-nn/src/wasi_nn_onnx.cpp | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/core/iwasm/libraries/wasi-nn/src/wasi_nn_onnx.cpp b/core/iwasm/libraries/wasi-nn/src/wasi_nn_onnx.cpp index 4121439625..acfc668a14 100644 --- a/core/iwasm/libraries/wasi-nn/src/wasi_nn_onnx.cpp +++ b/core/iwasm/libraries/wasi-nn/src/wasi_nn_onnx.cpp @@ -103,7 +103,8 @@ convert_ort_error_to_wasi_nn_error(OrtStatus *status) } static bool -convert_ort_type_to_wasi_nn_type(ONNXTensorElementDataType ort_type, tensor_type *tensor_type) +convert_ort_type_to_wasi_nn_type(ONNXTensorElementDataType ort_type, + tensor_type *tensor_type) { switch (ort_type) { case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: @@ -142,7 +143,8 @@ convert_ort_type_to_wasi_nn_type(ONNXTensorElementDataType ort_type, tensor_type } static bool -convert_wasi_nn_type_to_ort_type(tensor_type type, ONNXTensorElementDataType *ort_type) +convert_wasi_nn_type_to_ort_type(tensor_type type, + ONNXTensorElementDataType *ort_type) { switch (type) { case fp32: @@ -583,7 +585,7 @@ set_input(void *onnx_ctx, graph_execution_context ctx, uint32_t index, } status = ort_ctx->ort_api->CreateTensorWithDataAsOrtValue( - exec_ctx->memory_info, input_tensor->data.buf,input_tensor->data.size, + exec_ctx->memory_info, input_tensor->data.buf, input_tensor->data.size, ort_dims, num_dims, ort_type, &input_value); free(ort_dims); From 4b503412fb759b8192ae5fdf17ef15ce8926a4af Mon Sep 17 00:00:00 2001 From: "dongsheng.yan" Date: Tue, 29 Jul 2025 11:52:43 +0800 Subject: [PATCH 04/10] remove global context --- .../iwasm/libraries/wasi-nn/include/wasi_nn.h | 3 +- .../libraries/wasi-nn/include/wasi_nn_types.h | 2 +- core/iwasm/libraries/wasi-nn/src/wasi_nn.c | 3 +- .../libraries/wasi-nn/src/wasi_nn_onnx.cpp | 246 +++++++++--------- 4 files changed, 125 insertions(+), 129 deletions(-) diff --git a/core/iwasm/libraries/wasi-nn/include/wasi_nn.h b/core/iwasm/libraries/wasi-nn/include/wasi_nn.h index 56ae36d2ed..cda26324eb 100644 --- a/core/iwasm/libraries/wasi-nn/include/wasi_nn.h +++ b/core/iwasm/libraries/wasi-nn/include/wasi_nn.h @@ -21,8 +21,7 @@ #else #define WASI_NN_IMPORT(name) \ __attribute__((import_module("wasi_nn"), import_name(name))) -#warning \ - "You are using \"wasi_nn\", which is a legacy WAMR-specific ABI. It's deprecated and will likely be removed in future versions of WAMR. Please use \"wasi_ephemeral_nn\" instead. (For a WASM module, use the wasi_ephemeral_nn.h header instead. For the runtime configurations, enable WASM_ENABLE_WASI_EPHEMERAL_NN/WAMR_BUILD_WASI_EPHEMERAL_NN.)" +#warning You are using "wasi_nn", which is a legacy WAMR-specific ABI. It's deperecated and will likely be removed in future versions of WAMR. Please use "wasi_ephemeral_nn" instead. (For a WASM module, use the wasi_ephemeral_nn.h header instead. For the runtime configurations, enable WASM_ENABLE_WASI_EPHEMERAL_NN/WAMR_BUILD_WASI_EPHEMERAL_NN.) #endif /** diff --git a/core/iwasm/libraries/wasi-nn/include/wasi_nn_types.h b/core/iwasm/libraries/wasi-nn/include/wasi_nn_types.h index 23b697a8cf..952fb65e28 100644 --- a/core/iwasm/libraries/wasi-nn/include/wasi_nn_types.h +++ b/core/iwasm/libraries/wasi-nn/include/wasi_nn_types.h @@ -27,7 +27,7 @@ extern "C" { #define WASI_NN_TYPE_NAME(name) WASI_NN_NAME(type_##name) #define WASI_NN_ENCODING_NAME(name) WASI_NN_NAME(encoding_##name) #define WASI_NN_TARGET_NAME(name) WASI_NN_NAME(target_##name) -#define WASI_NN_ERROR_TYPE WASI_NN_NAME(error) +#define WASI_NN_ERROR_TYPE WASI_NN_NAME(error); #endif /** diff --git a/core/iwasm/libraries/wasi-nn/src/wasi_nn.c b/core/iwasm/libraries/wasi-nn/src/wasi_nn.c index 8781c32105..787c3a432d 100644 --- a/core/iwasm/libraries/wasi-nn/src/wasi_nn.c +++ b/core/iwasm/libraries/wasi-nn/src/wasi_nn.c @@ -21,8 +21,7 @@ #include "wasm_export.h" #if WASM_ENABLE_WASI_EPHEMERAL_NN == 0 -#warning \ - "You are using \"wasi_nn\", which is a legacy WAMR-specific ABI. It's deprecated and will likely be removed in future versions of WAMR. Please use \"wasi_ephemeral_nn\" instead. (For a WASM module, use the wasi_ephemeral_nn.h header instead. For the runtime configurations, enable WASM_ENABLE_WASI_EPHEMERAL_NN/WAMR_BUILD_WASI_EPHEMERAL_NN.)" +#warning You are using "wasi_nn", which is a legacy WAMR-specific ABI. It's deperecated and will likely be removed in future versions of WAMR. Please use "wasi_ephemeral_nn" instead. (For a WASM module, use the wasi_ephemeral_nn.h header instead. For the runtime configurations, enable WASM_ENABLE_WASI_EPHEMERAL_NN/WAMR_BUILD_WASI_EPHEMERAL_NN.) #endif #define HASHMAP_INITIAL_SIZE 20 diff --git a/core/iwasm/libraries/wasi-nn/src/wasi_nn_onnx.cpp b/core/iwasm/libraries/wasi-nn/src/wasi_nn_onnx.cpp index acfc668a14..90a01dabf8 100644 --- a/core/iwasm/libraries/wasi-nn/src/wasi_nn_onnx.cpp +++ b/core/iwasm/libraries/wasi-nn/src/wasi_nn_onnx.cpp @@ -15,18 +15,8 @@ #include "onnxruntime_c_api.h" /* Maximum number of graphs and execution contexts */ -#define MAX_GRAPHS 10 -#define MAX_CONTEXTS 10 - -/* ONNX Runtime context structure */ -typedef struct { - OrtEnv *env; - OrtSessionOptions *session_options; - OrtAllocator *allocator; - const OrtApi *ort_api; - std::mutex mutex; - bool is_initialized; -} OnnxRuntimeContext; +#define MAX_GRAPHS 4 +#define MAX_CONTEXTS 4 /* Graph structure */ typedef struct { @@ -45,32 +35,40 @@ typedef struct { bool is_initialized; } OnnxRuntimeExecCtx; -/* Global variables */ -static OnnxRuntimeContext g_ort_ctx; -static OnnxRuntimeGraph g_graphs[MAX_GRAPHS]; -static OnnxRuntimeExecCtx g_exec_ctxs[MAX_CONTEXTS]; +/* ONNX Runtime context structure */ +typedef struct { + OrtEnv *env; + OrtSessionOptions *session_options; + OrtAllocator *allocator; + const OrtApi *ort_api; + std::mutex mutex; + bool is_initialized; + OnnxRuntimeGraph graphs[MAX_GRAPHS]; + OnnxRuntimeExecCtx exec_ctxs[MAX_CONTEXTS]; +} OnnxRuntimeContext; /* Helper functions */ static void -check_status_and_log(OrtStatus *status) +check_status_and_log(const OnnxRuntimeContext *ctx, OrtStatus *status) { if (status != nullptr) { - const char *msg = g_ort_ctx.ort_api->GetErrorMessage(status); + const char *msg = ctx->ort_api->GetErrorMessage(status); NN_ERR_PRINTF("ONNX Runtime error: %s", msg); - g_ort_ctx.ort_api->ReleaseStatus(status); + ctx->ort_api->ReleaseStatus(status); } } static wasi_nn_error -convert_ort_error_to_wasi_nn_error(OrtStatus *status) +convert_ort_error_to_wasi_nn_error(const OnnxRuntimeContext *ctx, + OrtStatus *status) { if (status == nullptr) { return success; } wasi_nn_error err; - OrtErrorCode code = g_ort_ctx.ort_api->GetErrorCode(status); - const char *msg = g_ort_ctx.ort_api->GetErrorMessage(status); + OrtErrorCode code = ctx->ort_api->GetErrorCode(status); + const char *msg = ctx->ort_api->GetErrorMessage(status); NN_ERR_PRINTF("ONNX Runtime error: %s", msg); @@ -98,7 +96,7 @@ convert_ort_error_to_wasi_nn_error(OrtStatus *status) break; } - g_ort_ctx.ort_api->ReleaseStatus(status); + ctx->ort_api->ReleaseStatus(status); return err; } @@ -188,81 +186,81 @@ extern "C" { __attribute__((visibility("default"))) wasi_nn_error init_backend(void **onnx_ctx) { - std::lock_guard lock(g_ort_ctx.mutex); - - if (g_ort_ctx.is_initialized) { - *onnx_ctx = &g_ort_ctx; - return success; - } - - g_ort_ctx.ort_api = OrtGetApiBase()->GetApi(ORT_API_VERSION); - if (!g_ort_ctx.ort_api) { + wasi_nn_error err = success; + OrtStatus *status = nullptr; + OnnxRuntimeContext *ctx = nullptr; + ctx = new OnnxRuntimeContext(); + ctx->ort_api = OrtGetApiBase()->GetApi(ORT_API_VERSION); + if (!ctx->ort_api) { NN_ERR_PRINTF("Failed to get ONNX Runtime API"); - return runtime_error; + err = runtime_error; + goto fail; } NN_INFO_PRINTF("Creating ONNX Runtime environment..."); - OrtStatus *status = g_ort_ctx.ort_api->CreateEnv(ORT_LOGGING_LEVEL_VERBOSE, - "wasi-nn", &g_ort_ctx.env); + status = ctx->ort_api->CreateEnv(ORT_LOGGING_LEVEL_VERBOSE, "wasi-nn", + &ctx->env); if (status != nullptr) { - const char *error_message = g_ort_ctx.ort_api->GetErrorMessage(status); - wasi_nn_error err = convert_ort_error_to_wasi_nn_error(status); + const char *error_message = ctx->ort_api->GetErrorMessage(status); + err = convert_ort_error_to_wasi_nn_error(ctx, status); NN_ERR_PRINTF("Failed to create ONNX Runtime environment: %s", error_message); - g_ort_ctx.ort_api->ReleaseStatus(status); - return err; + ctx->ort_api->ReleaseStatus(status); + goto fail; } NN_INFO_PRINTF("ONNX Runtime environment created successfully"); - status = - g_ort_ctx.ort_api->CreateSessionOptions(&g_ort_ctx.session_options); + status = ctx->ort_api->CreateSessionOptions(&ctx->session_options); if (status != nullptr) { - wasi_nn_error err = convert_ort_error_to_wasi_nn_error(status); - g_ort_ctx.ort_api->ReleaseEnv(g_ort_ctx.env); + err = convert_ort_error_to_wasi_nn_error(ctx, status); + ctx->ort_api->ReleaseEnv(ctx->env); NN_ERR_PRINTF("Failed to create ONNX Runtime session options"); - return err; + goto fail; } - status = g_ort_ctx.ort_api->SetSessionGraphOptimizationLevel( - g_ort_ctx.session_options, ORT_ENABLE_BASIC); + status = ctx->ort_api->SetSessionGraphOptimizationLevel( + ctx->session_options, ORT_ENABLE_BASIC); if (status != nullptr) { - wasi_nn_error err = convert_ort_error_to_wasi_nn_error(status); - g_ort_ctx.ort_api->ReleaseSessionOptions(g_ort_ctx.session_options); - g_ort_ctx.ort_api->ReleaseEnv(g_ort_ctx.env); + err = convert_ort_error_to_wasi_nn_error(ctx, status); + ctx->ort_api->ReleaseSessionOptions(ctx->session_options); + ctx->ort_api->ReleaseEnv(ctx->env); NN_ERR_PRINTF("Failed to set graph optimization level"); - return err; + goto fail; } - status = - g_ort_ctx.ort_api->GetAllocatorWithDefaultOptions(&g_ort_ctx.allocator); + status = ctx->ort_api->GetAllocatorWithDefaultOptions(&ctx->allocator); if (status != nullptr) { - wasi_nn_error err = convert_ort_error_to_wasi_nn_error(status); - g_ort_ctx.ort_api->ReleaseSessionOptions(g_ort_ctx.session_options); - g_ort_ctx.ort_api->ReleaseEnv(g_ort_ctx.env); + err = convert_ort_error_to_wasi_nn_error(ctx, status); + ctx->ort_api->ReleaseSessionOptions(ctx->session_options); + ctx->ort_api->ReleaseEnv(ctx->env); NN_ERR_PRINTF("Failed to get default allocator"); - return err; + goto fail; } for (int i = 0; i < MAX_GRAPHS; i++) { - g_graphs[i].is_initialized = false; - g_graphs[i].session = nullptr; + ctx->graphs[i].is_initialized = false; + ctx->graphs[i].session = nullptr; } for (int i = 0; i < MAX_CONTEXTS; i++) { - g_exec_ctxs[i].is_initialized = false; - g_exec_ctxs[i].memory_info = nullptr; - g_exec_ctxs[i].graph = nullptr; - g_exec_ctxs[i].input_names.clear(); - g_exec_ctxs[i].output_names.clear(); - g_exec_ctxs[i].inputs.clear(); - g_exec_ctxs[i].outputs.clear(); + ctx->exec_ctxs[i].is_initialized = false; + ctx->exec_ctxs[i].memory_info = nullptr; + ctx->exec_ctxs[i].graph = nullptr; + ctx->exec_ctxs[i].input_names.clear(); + ctx->exec_ctxs[i].output_names.clear(); + ctx->exec_ctxs[i].inputs.clear(); + ctx->exec_ctxs[i].outputs.clear(); } - g_ort_ctx.is_initialized = true; - *onnx_ctx = &g_ort_ctx; + ctx->is_initialized = true; + *onnx_ctx = ctx; NN_INFO_PRINTF("ONNX Runtime backend initialized"); return success; + +fail: + delete (ctx); + return err; } __attribute__((visibility("default"))) wasi_nn_error @@ -276,22 +274,22 @@ deinit_backend(void *onnx_ctx) } for (int i = 0; i < MAX_GRAPHS; i++) { - if (g_graphs[i].is_initialized) { - ctx->ort_api->ReleaseSession(g_graphs[i].session); - g_graphs[i].is_initialized = false; + if (ctx->graphs[i].is_initialized) { + ctx->ort_api->ReleaseSession(ctx->graphs[i].session); + ctx->graphs[i].is_initialized = false; } } for (int i = 0; i < MAX_CONTEXTS; i++) { - if (g_exec_ctxs[i].is_initialized) { - for (auto &input : g_exec_ctxs[i].inputs) { + if (ctx->exec_ctxs[i].is_initialized) { + for (auto &input : ctx->exec_ctxs[i].inputs) { ctx->ort_api->ReleaseValue(input.second); } - for (auto &output : g_exec_ctxs[i].outputs) { + for (auto &output : ctx->exec_ctxs[i].outputs) { ctx->ort_api->ReleaseValue(output.second); } - ctx->ort_api->ReleaseMemoryInfo(g_exec_ctxs[i].memory_info); - g_exec_ctxs[i].is_initialized = false; + ctx->ort_api->ReleaseMemoryInfo(ctx->exec_ctxs[i].memory_info); + ctx->exec_ctxs[i].is_initialized = false; } } @@ -299,6 +297,8 @@ deinit_backend(void *onnx_ctx) ctx->ort_api->ReleaseEnv(ctx->env); ctx->is_initialized = false; + delete (ctx); + NN_INFO_PRINTF("ONNX Runtime backend deinitialized"); return success; } @@ -322,7 +322,7 @@ load(void *onnx_ctx, graph_builder_array *builder, graph_encoding encoding, int graph_index = -1; for (int i = 0; i < MAX_GRAPHS; i++) { - if (!g_graphs[i].is_initialized) { + if (!ctx->graphs[i].is_initialized) { graph_index = i; break; } @@ -356,11 +356,11 @@ load(void *onnx_ctx, graph_builder_array *builder, graph_encoding encoding, OrtStatus *status = ctx->ort_api->CreateSessionFromArray( ctx->env, builder->buf[0].buf, builder->buf[0].size, - ctx->session_options, &g_graphs[graph_index].session); + ctx->session_options, &ctx->graphs[graph_index].session); if (status != nullptr) { const char *error_message = ctx->ort_api->GetErrorMessage(status); - wasi_nn_error err = convert_ort_error_to_wasi_nn_error(status); + wasi_nn_error err = convert_ort_error_to_wasi_nn_error(ctx, status); NN_ERR_PRINTF("Failed to create ONNX Runtime session: %s", error_message); ctx->ort_api->ReleaseStatus(status); @@ -369,7 +369,7 @@ load(void *onnx_ctx, graph_builder_array *builder, graph_encoding encoding, NN_INFO_PRINTF("ONNX Runtime session created successfully"); - g_graphs[graph_index].is_initialized = true; + ctx->graphs[graph_index].is_initialized = true; *g = graph_index; NN_INFO_PRINTF("ONNX model loaded as graph %d", graph_index); @@ -384,7 +384,7 @@ load_by_name(void *onnx_ctx, const char *name, uint32_t filename_len, graph *g) int graph_index = -1; for (int i = 0; i < MAX_GRAPHS; i++) { - if (!g_graphs[i].is_initialized) { + if (!ctx->graphs[i].is_initialized) { graph_index = i; break; } @@ -395,17 +395,18 @@ load_by_name(void *onnx_ctx, const char *name, uint32_t filename_len, graph *g) return runtime_error; } - OrtStatus *status = ctx->ort_api->CreateSession( - ctx->env, name, ctx->session_options, &g_graphs[graph_index].session); + OrtStatus *status = + ctx->ort_api->CreateSession(ctx->env, name, ctx->session_options, + &ctx->graphs[graph_index].session); if (status != nullptr) { - wasi_nn_error err = convert_ort_error_to_wasi_nn_error(status); + wasi_nn_error err = convert_ort_error_to_wasi_nn_error(ctx, status); NN_ERR_PRINTF("Failed to create ONNX Runtime session from file: %s", name); return err; } - g_graphs[graph_index].is_initialized = true; + ctx->graphs[graph_index].is_initialized = true; *g = graph_index; NN_INFO_PRINTF("ONNX model loaded from file %s as graph %d", name, @@ -416,17 +417,17 @@ load_by_name(void *onnx_ctx, const char *name, uint32_t filename_len, graph *g) __attribute__((visibility("default"))) wasi_nn_error init_execution_context(void *onnx_ctx, graph g, graph_execution_context *ctx) { - if (g >= MAX_GRAPHS || !g_graphs[g].is_initialized) { + OnnxRuntimeContext *ort_ctx = (OnnxRuntimeContext *)onnx_ctx; + + if (g >= MAX_GRAPHS || !ort_ctx->graphs[g].is_initialized) { NN_ERR_PRINTF("Invalid graph handle: %d", g); return invalid_argument; } - OnnxRuntimeContext *ort_ctx = (OnnxRuntimeContext *)onnx_ctx; std::lock_guard lock(ort_ctx->mutex); - int ctx_index = -1; for (int i = 0; i < MAX_CONTEXTS; i++) { - if (!g_exec_ctxs[i].is_initialized) { + if (!ort_ctx->exec_ctxs[i].is_initialized) { ctx_index = i; break; } @@ -437,13 +438,13 @@ init_execution_context(void *onnx_ctx, graph g, graph_execution_context *ctx) return runtime_error; } - OnnxRuntimeExecCtx *exec_ctx = &g_exec_ctxs[ctx_index]; - exec_ctx->graph = &g_graphs[g]; + OnnxRuntimeExecCtx *exec_ctx = &ort_ctx->exec_ctxs[ctx_index]; + exec_ctx->graph = &ort_ctx->graphs[g]; OrtStatus *status = ort_ctx->ort_api->CreateCpuMemoryInfo( OrtArenaAllocator, OrtMemTypeDefault, &exec_ctx->memory_info); if (status != nullptr) { - wasi_nn_error err = convert_ort_error_to_wasi_nn_error(status); + wasi_nn_error err = convert_ort_error_to_wasi_nn_error(ort_ctx, status); NN_ERR_PRINTF("Failed to create CPU memory info"); return err; } @@ -452,7 +453,7 @@ init_execution_context(void *onnx_ctx, graph g, graph_execution_context *ctx) status = ort_ctx->ort_api->SessionGetInputCount(exec_ctx->graph->session, &num_input_nodes); if (status != nullptr) { - wasi_nn_error err = convert_ort_error_to_wasi_nn_error(status); + wasi_nn_error err = convert_ort_error_to_wasi_nn_error(ort_ctx, status); ort_ctx->ort_api->ReleaseMemoryInfo(exec_ctx->memory_info); NN_ERR_PRINTF("Failed to get input count"); return err; @@ -463,7 +464,8 @@ init_execution_context(void *onnx_ctx, graph g, graph_execution_context *ctx) status = ort_ctx->ort_api->SessionGetInputName( exec_ctx->graph->session, i, ort_ctx->allocator, &input_name); if (status != nullptr) { - wasi_nn_error err = convert_ort_error_to_wasi_nn_error(status); + wasi_nn_error err = + convert_ort_error_to_wasi_nn_error(ort_ctx, status); ort_ctx->ort_api->ReleaseMemoryInfo(exec_ctx->memory_info); NN_ERR_PRINTF("Failed to get input name"); return err; @@ -475,7 +477,7 @@ init_execution_context(void *onnx_ctx, graph g, graph_execution_context *ctx) status = ort_ctx->ort_api->SessionGetOutputCount(exec_ctx->graph->session, &num_output_nodes); if (status != nullptr) { - wasi_nn_error err = convert_ort_error_to_wasi_nn_error(status); + wasi_nn_error err = convert_ort_error_to_wasi_nn_error(ort_ctx, status); ort_ctx->ort_api->ReleaseMemoryInfo(exec_ctx->memory_info); for (const char *name : exec_ctx->input_names) { ort_ctx->allocator->Free(ort_ctx->allocator, (void *)name); @@ -489,7 +491,8 @@ init_execution_context(void *onnx_ctx, graph g, graph_execution_context *ctx) status = ort_ctx->ort_api->SessionGetOutputName( exec_ctx->graph->session, i, ort_ctx->allocator, &output_name); if (status != nullptr) { - wasi_nn_error err = convert_ort_error_to_wasi_nn_error(status); + wasi_nn_error err = + convert_ort_error_to_wasi_nn_error(ort_ctx, status); ort_ctx->ort_api->ReleaseMemoryInfo(exec_ctx->memory_info); for (const char *name : exec_ctx->input_names) { ort_ctx->allocator->Free(ort_ctx->allocator, (void *)name); @@ -512,20 +515,21 @@ __attribute__((visibility("default"))) wasi_nn_error set_input(void *onnx_ctx, graph_execution_context ctx, uint32_t index, tensor *input_tensor) { - if (ctx >= MAX_CONTEXTS || !g_exec_ctxs[ctx].is_initialized) { + OnnxRuntimeContext *ort_ctx = (OnnxRuntimeContext *)onnx_ctx; + + if (ctx >= MAX_CONTEXTS || !ort_ctx->exec_ctxs[ctx].is_initialized) { NN_ERR_PRINTF("Invalid execution context handle: %d", ctx); return invalid_argument; } - if (index >= g_exec_ctxs[ctx].input_names.size()) { + if (index >= ort_ctx->exec_ctxs[ctx].input_names.size()) { NN_ERR_PRINTF("Invalid input index: %d (max: %zu)", index, - g_exec_ctxs[ctx].input_names.size() - 1); + ort_ctx->exec_ctxs[ctx].input_names.size() - 1); return invalid_argument; } - OnnxRuntimeContext *ort_ctx = (OnnxRuntimeContext *)onnx_ctx; std::lock_guard lock(ort_ctx->mutex); - OnnxRuntimeExecCtx *exec_ctx = &g_exec_ctxs[ctx]; + OnnxRuntimeExecCtx *exec_ctx = &ort_ctx->exec_ctxs[ctx]; OrtTypeInfo *type_info = nullptr; OrtStatus *status = ort_ctx->ort_api->SessionGetInputTypeInfo( @@ -549,14 +553,6 @@ set_input(void *onnx_ctx, graph_execution_context ctx, uint32_t index, status = ort_ctx->ort_api->GetDimensions(tensor_info, model_dims.data(), num_model_dims); - size_t model_tensor_size = 1; - for (size_t i = 0; i < num_model_dims; ++i) - model_tensor_size *= model_dims[i]; - - size_t input_tensor_size = 1; - for (size_t i = 0; i < input_tensor->dimensions->size; ++i) - input_tensor_size *= input_tensor->dimensions->buf[i]; - void *input_tensor_data = input_tensor->data.buf; void *input_tensor_scaled_data = NULL; ort_ctx->ort_api->ReleaseTypeInfo(type_info); @@ -591,7 +587,7 @@ set_input(void *onnx_ctx, graph_execution_context ctx, uint32_t index, free(ort_dims); if (status != nullptr) { - wasi_nn_error err = convert_ort_error_to_wasi_nn_error(status); + wasi_nn_error err = convert_ort_error_to_wasi_nn_error(ort_ctx, status); NN_ERR_PRINTF("Failed to create input tensor"); return err; } @@ -608,14 +604,15 @@ set_input(void *onnx_ctx, graph_execution_context ctx, uint32_t index, __attribute__((visibility("default"))) wasi_nn_error compute(void *onnx_ctx, graph_execution_context ctx) { - if (ctx >= MAX_CONTEXTS || !g_exec_ctxs[ctx].is_initialized) { + OnnxRuntimeContext *ort_ctx = (OnnxRuntimeContext *)onnx_ctx; + + if (ctx >= MAX_CONTEXTS || !ort_ctx->exec_ctxs[ctx].is_initialized) { NN_ERR_PRINTF("Invalid execution context handle: %d", ctx); return invalid_argument; } - OnnxRuntimeContext *ort_ctx = (OnnxRuntimeContext *)onnx_ctx; std::lock_guard lock(ort_ctx->mutex); - OnnxRuntimeExecCtx *exec_ctx = &g_exec_ctxs[ctx]; + OnnxRuntimeExecCtx *exec_ctx = &ort_ctx->exec_ctxs[ctx]; std::vector input_values; std::vector input_names; @@ -646,7 +643,7 @@ compute(void *onnx_ctx, graph_execution_context ctx) } if (status != nullptr) { - wasi_nn_error err = convert_ort_error_to_wasi_nn_error(status); + wasi_nn_error err = convert_ort_error_to_wasi_nn_error(ort_ctx, status); NN_ERR_PRINTF("Failed to run inference"); return err; } @@ -659,20 +656,21 @@ __attribute__((visibility("default"))) wasi_nn_error get_output(void *onnx_ctx, graph_execution_context ctx, uint32_t index, tensor_data *out_buffer, uint32_t *out_buffer_size) { - if (ctx >= MAX_CONTEXTS || !g_exec_ctxs[ctx].is_initialized) { + OnnxRuntimeContext *ort_ctx = (OnnxRuntimeContext *)onnx_ctx; + + if (ctx >= MAX_CONTEXTS || !ort_ctx->exec_ctxs[ctx].is_initialized) { NN_ERR_PRINTF("Invalid execution context handle: %d", ctx); return invalid_argument; } - if (index >= g_exec_ctxs[ctx].output_names.size()) { + if (index >= ort_ctx->exec_ctxs[ctx].output_names.size()) { NN_ERR_PRINTF("Invalid output index: %d (max: %zu)", index, - g_exec_ctxs[ctx].output_names.size() - 1); + ort_ctx->exec_ctxs[ctx].output_names.size() - 1); return invalid_argument; } - OnnxRuntimeContext *ort_ctx = (OnnxRuntimeContext *)onnx_ctx; std::lock_guard lock(ort_ctx->mutex); - OnnxRuntimeExecCtx *exec_ctx = &g_exec_ctxs[ctx]; + OnnxRuntimeExecCtx *exec_ctx = &ort_ctx->exec_ctxs[ctx]; OrtValue *output_value = exec_ctx->outputs[index]; if (!output_value) { @@ -684,7 +682,7 @@ get_output(void *onnx_ctx, graph_execution_context ctx, uint32_t index, OrtStatus *status = ort_ctx->ort_api->GetTensorTypeAndShape(output_value, &tensor_info); if (status != nullptr) { - wasi_nn_error err = convert_ort_error_to_wasi_nn_error(status); + wasi_nn_error err = convert_ort_error_to_wasi_nn_error(ort_ctx, status); NN_ERR_PRINTF("Failed to get tensor type and shape"); return err; } @@ -692,7 +690,7 @@ get_output(void *onnx_ctx, graph_execution_context ctx, uint32_t index, ONNXTensorElementDataType element_type; status = ort_ctx->ort_api->GetTensorElementType(tensor_info, &element_type); if (status != nullptr) { - wasi_nn_error err = convert_ort_error_to_wasi_nn_error(status); + wasi_nn_error err = convert_ort_error_to_wasi_nn_error(ort_ctx, status); ort_ctx->ort_api->ReleaseTensorTypeAndShapeInfo(tensor_info); NN_ERR_PRINTF("Failed to get tensor element type"); return err; @@ -701,7 +699,7 @@ get_output(void *onnx_ctx, graph_execution_context ctx, uint32_t index, size_t num_dims; status = ort_ctx->ort_api->GetDimensionsCount(tensor_info, &num_dims); if (status != nullptr) { - wasi_nn_error err = convert_ort_error_to_wasi_nn_error(status); + wasi_nn_error err = convert_ort_error_to_wasi_nn_error(ort_ctx, status); ort_ctx->ort_api->ReleaseTensorTypeAndShapeInfo(tensor_info); NN_ERR_PRINTF("Failed to get tensor dimensions count"); return err; @@ -716,7 +714,7 @@ get_output(void *onnx_ctx, graph_execution_context ctx, uint32_t index, status = ort_ctx->ort_api->GetDimensions(tensor_info, dims, num_dims); if (status != nullptr) { - wasi_nn_error err = convert_ort_error_to_wasi_nn_error(status); + wasi_nn_error err = convert_ort_error_to_wasi_nn_error(ort_ctx, status); free(dims); ort_ctx->ort_api->ReleaseTensorTypeAndShapeInfo(tensor_info); NN_ERR_PRINTF("Failed to get tensor dimensions"); @@ -727,7 +725,7 @@ get_output(void *onnx_ctx, graph_execution_context ctx, uint32_t index, status = ort_ctx->ort_api->GetTensorShapeElementCount(tensor_info, &tensor_size); if (status != nullptr) { - wasi_nn_error err = convert_ort_error_to_wasi_nn_error(status); + wasi_nn_error err = convert_ort_error_to_wasi_nn_error(ort_ctx, status); free(dims); ort_ctx->ort_api->ReleaseTensorTypeAndShapeInfo(tensor_info); NN_ERR_PRINTF("Failed to get tensor element count"); @@ -751,7 +749,7 @@ get_output(void *onnx_ctx, graph_execution_context ctx, uint32_t index, void *tensor_data = nullptr; status = ort_ctx->ort_api->GetTensorMutableData(output_value, &tensor_data); if (status != nullptr) { - wasi_nn_error err = convert_ort_error_to_wasi_nn_error(status); + wasi_nn_error err = convert_ort_error_to_wasi_nn_error(ort_ctx, status); NN_ERR_PRINTF("Failed to get tensor data"); return err; } From 801eb2b8d2172eafb35db502b44d1980ba29c7b9 Mon Sep 17 00:00:00 2001 From: "dongsheng.yan" Date: Thu, 31 Jul 2025 15:40:44 +0800 Subject: [PATCH 05/10] put checks under the lock --- .../libraries/wasi-nn/src/wasi_nn_onnx.cpp | 26 +++++++++++++++---- 1 file changed, 21 insertions(+), 5 deletions(-) diff --git a/core/iwasm/libraries/wasi-nn/src/wasi_nn_onnx.cpp b/core/iwasm/libraries/wasi-nn/src/wasi_nn_onnx.cpp index 90a01dabf8..5c5f36727d 100644 --- a/core/iwasm/libraries/wasi-nn/src/wasi_nn_onnx.cpp +++ b/core/iwasm/libraries/wasi-nn/src/wasi_nn_onnx.cpp @@ -174,7 +174,7 @@ convert_wasi_nn_type_to_ort_type(tensor_type type, #endif default: NN_WARN_PRINTF("Unsupported wasi-nn tensor type: %d", type); - return false; // Default to float + return false; } return true; } @@ -418,13 +418,17 @@ __attribute__((visibility("default"))) wasi_nn_error init_execution_context(void *onnx_ctx, graph g, graph_execution_context *ctx) { OnnxRuntimeContext *ort_ctx = (OnnxRuntimeContext *)onnx_ctx; + if (!onnx_ctx) { + return runtime_error; + } + + std::lock_guard lock(ort_ctx->mutex); if (g >= MAX_GRAPHS || !ort_ctx->graphs[g].is_initialized) { NN_ERR_PRINTF("Invalid graph handle: %d", g); return invalid_argument; } - std::lock_guard lock(ort_ctx->mutex); int ctx_index = -1; for (int i = 0; i < MAX_CONTEXTS; i++) { if (!ort_ctx->exec_ctxs[i].is_initialized) { @@ -516,6 +520,11 @@ set_input(void *onnx_ctx, graph_execution_context ctx, uint32_t index, tensor *input_tensor) { OnnxRuntimeContext *ort_ctx = (OnnxRuntimeContext *)onnx_ctx; + if (!onnx_ctx) { + return runtime_error; + } + + std::lock_guard lock(ort_ctx->mutex); if (ctx >= MAX_CONTEXTS || !ort_ctx->exec_ctxs[ctx].is_initialized) { NN_ERR_PRINTF("Invalid execution context handle: %d", ctx); @@ -528,7 +537,6 @@ set_input(void *onnx_ctx, graph_execution_context ctx, uint32_t index, return invalid_argument; } - std::lock_guard lock(ort_ctx->mutex); OnnxRuntimeExecCtx *exec_ctx = &ort_ctx->exec_ctxs[ctx]; OrtTypeInfo *type_info = nullptr; @@ -605,13 +613,17 @@ __attribute__((visibility("default"))) wasi_nn_error compute(void *onnx_ctx, graph_execution_context ctx) { OnnxRuntimeContext *ort_ctx = (OnnxRuntimeContext *)onnx_ctx; + if (!onnx_ctx) { + return runtime_error; + } + + std::lock_guard lock(ort_ctx->mutex); if (ctx >= MAX_CONTEXTS || !ort_ctx->exec_ctxs[ctx].is_initialized) { NN_ERR_PRINTF("Invalid execution context handle: %d", ctx); return invalid_argument; } - std::lock_guard lock(ort_ctx->mutex); OnnxRuntimeExecCtx *exec_ctx = &ort_ctx->exec_ctxs[ctx]; std::vector input_values; @@ -657,6 +669,11 @@ get_output(void *onnx_ctx, graph_execution_context ctx, uint32_t index, tensor_data *out_buffer, uint32_t *out_buffer_size) { OnnxRuntimeContext *ort_ctx = (OnnxRuntimeContext *)onnx_ctx; + if (!onnx_ctx) { + return runtime_error; + } + + std::lock_guard lock(ort_ctx->mutex); if (ctx >= MAX_CONTEXTS || !ort_ctx->exec_ctxs[ctx].is_initialized) { NN_ERR_PRINTF("Invalid execution context handle: %d", ctx); @@ -669,7 +686,6 @@ get_output(void *onnx_ctx, graph_execution_context ctx, uint32_t index, return invalid_argument; } - std::lock_guard lock(ort_ctx->mutex); OnnxRuntimeExecCtx *exec_ctx = &ort_ctx->exec_ctxs[ctx]; OrtValue *output_value = exec_ctx->outputs[index]; From cd3cb6c2d1ec89820887e9602ed4a5c06f5fd333 Mon Sep 17 00:00:00 2001 From: "dongsheng.yan" Date: Tue, 5 Aug 2025 09:44:06 +0800 Subject: [PATCH 06/10] use WASM_ENABLE_WASI_EPHEMERAL_NN --- core/iwasm/libraries/wasi-nn/src/wasi_nn_onnx.cpp | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/core/iwasm/libraries/wasi-nn/src/wasi_nn_onnx.cpp b/core/iwasm/libraries/wasi-nn/src/wasi_nn_onnx.cpp index 5c5f36727d..5c655244f1 100644 --- a/core/iwasm/libraries/wasi-nn/src/wasi_nn_onnx.cpp +++ b/core/iwasm/libraries/wasi-nn/src/wasi_nn_onnx.cpp @@ -14,6 +14,10 @@ #include "utils/logger.h" #include "onnxruntime_c_api.h" +#if WASM_ENABLE_WASI_EPHEMERAL_NN == 0 +#error This backend doesn't support legacy "wasi_nn" abi. Please enable WASM_ENABLE_WASI_EPHEMERAL_NN. +#endif + /* Maximum number of graphs and execution contexts */ #define MAX_GRAPHS 4 #define MAX_CONTEXTS 4 From 38e9d2b5d605fd0a77162ea1e49bbe97515ce9c1 Mon Sep 17 00:00:00 2001 From: "dongsheng.yan" Date: Tue, 5 Aug 2025 12:02:20 +0800 Subject: [PATCH 07/10] tensor type will not support legacy wasi-nn abi --- .../libraries/wasi-nn/src/wasi_nn_onnx.cpp | 18 ------------------ 1 file changed, 18 deletions(-) diff --git a/core/iwasm/libraries/wasi-nn/src/wasi_nn_onnx.cpp b/core/iwasm/libraries/wasi-nn/src/wasi_nn_onnx.cpp index 5c655244f1..e310d56531 100644 --- a/core/iwasm/libraries/wasi-nn/src/wasi_nn_onnx.cpp +++ b/core/iwasm/libraries/wasi-nn/src/wasi_nn_onnx.cpp @@ -115,7 +115,6 @@ convert_ort_type_to_wasi_nn_type(ONNXTensorElementDataType ort_type, case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16: *tensor_type = fp16; break; -#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0 case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE: *tensor_type = fp64; break; @@ -128,14 +127,6 @@ convert_ort_type_to_wasi_nn_type(ONNXTensorElementDataType ort_type, case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64: *tensor_type = i64; break; -#else - case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8: - *tensor_type = up8; - break; - case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32: - *tensor_type = ip32; - break; -#endif default: NN_WARN_PRINTF("Unsupported ONNX tensor type: %d", ort_type); return false; @@ -155,7 +146,6 @@ convert_wasi_nn_type_to_ort_type(tensor_type type, case fp16: *ort_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16; break; -#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0 case fp64: *ort_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE; break; @@ -168,14 +158,6 @@ convert_wasi_nn_type_to_ort_type(tensor_type type, case i64: *ort_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64; break; -#else - case up8: - *ort_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8; - break; - case ip32: - *ort_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32; - break; -#endif default: NN_WARN_PRINTF("Unsupported wasi-nn tensor type: %d", type); return false; From fa6c3a3175986f87132624e203562dcaffb256d5 Mon Sep 17 00:00:00 2001 From: "dongsheng.yan" Date: Fri, 8 Aug 2025 10:14:23 +0800 Subject: [PATCH 08/10] using the CMake config file --- core/iwasm/libraries/wasi-nn/cmake/Findonnxruntime.cmake | 9 +++++++++ core/iwasm/libraries/wasi-nn/cmake/wasi_nn.cmake | 4 +--- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/core/iwasm/libraries/wasi-nn/cmake/Findonnxruntime.cmake b/core/iwasm/libraries/wasi-nn/cmake/Findonnxruntime.cmake index 41df9d5770..d73082f587 100644 --- a/core/iwasm/libraries/wasi-nn/cmake/Findonnxruntime.cmake +++ b/core/iwasm/libraries/wasi-nn/cmake/Findonnxruntime.cmake @@ -22,6 +22,15 @@ # endif() # First try to find ONNX Runtime using the CMake config file +# FIXME: This is a temporary workaround for ONNX Runtime's broken CMake config on Linux. +# See https://github.com/microsoft/onnxruntime/issues/25279 +# Once the upstream issue is fixed, this conditional can be safely removed. +if(NOT CMAKE_SYSTEM_NAME STREQUAL "Linux") + find_package(onnxruntime CONFIG QUIET) + if(onnxruntime_FOUND) + return() + endif() +endif() # If not found via CMake config, try to find manually find_path(onnxruntime_INCLUDE_DIR diff --git a/core/iwasm/libraries/wasi-nn/cmake/wasi_nn.cmake b/core/iwasm/libraries/wasi-nn/cmake/wasi_nn.cmake index 370fcc468a..fefca7d4ef 100644 --- a/core/iwasm/libraries/wasi-nn/cmake/wasi_nn.cmake +++ b/core/iwasm/libraries/wasi-nn/cmake/wasi_nn.cmake @@ -123,9 +123,7 @@ if(WAMR_BUILD_WASI_NN_ONNX EQUAL 1) target_include_directories( wasi_nn_onnx - PUBLIC - ${onnxruntime_INCLUDE_DIR}/onnx - ${onnxruntime_INCLUDE_DIR} + PUBLIC ${INTERFACE_INCLUDE_DIRECTORIES} ) target_link_libraries( From 0c164b6ac62d0ca42b726eb6cd0da6df0dd60d66 Mon Sep 17 00:00:00 2001 From: "dongsheng.yan" Date: Fri, 8 Aug 2025 12:04:13 +0800 Subject: [PATCH 09/10] Manually set the imported target with name space --- core/iwasm/libraries/wasi-nn/cmake/Findonnxruntime.cmake | 8 ++++---- core/iwasm/libraries/wasi-nn/cmake/wasi_nn.cmake | 7 +------ 2 files changed, 5 insertions(+), 10 deletions(-) diff --git a/core/iwasm/libraries/wasi-nn/cmake/Findonnxruntime.cmake b/core/iwasm/libraries/wasi-nn/cmake/Findonnxruntime.cmake index d73082f587..db8f287e36 100644 --- a/core/iwasm/libraries/wasi-nn/cmake/Findonnxruntime.cmake +++ b/core/iwasm/libraries/wasi-nn/cmake/Findonnxruntime.cmake @@ -73,10 +73,10 @@ find_package_handle_standard_args(onnxruntime if(onnxruntime_FOUND) set(onnxruntime_LIBRARIES ${onnxruntime_LIBRARY}) set(onnxruntime_INCLUDE_DIRS ${onnxruntime_INCLUDE_DIR}) - - if(NOT TARGET onnxruntime) - add_library(onnxruntime UNKNOWN IMPORTED) - set_target_properties(onnxruntime PROPERTIES + + if(NOT TARGET onnxruntime::onnxruntime) + add_library(onnxruntime::onnxruntime UNKNOWN IMPORTED) + set_target_properties(onnxruntime::onnxruntime PROPERTIES IMPORTED_LOCATION "${onnxruntime_LIBRARY}" INTERFACE_INCLUDE_DIRECTORIES "${onnxruntime_INCLUDE_DIRS}" ) diff --git a/core/iwasm/libraries/wasi-nn/cmake/wasi_nn.cmake b/core/iwasm/libraries/wasi-nn/cmake/wasi_nn.cmake index fefca7d4ef..56a7b44e4a 100644 --- a/core/iwasm/libraries/wasi-nn/cmake/wasi_nn.cmake +++ b/core/iwasm/libraries/wasi-nn/cmake/wasi_nn.cmake @@ -121,16 +121,11 @@ if(WAMR_BUILD_WASI_NN_ONNX EQUAL 1) ${WASI_NN_ROOT}/src/wasi_nn_onnx.cpp ) - target_include_directories( - wasi_nn_onnx - PUBLIC ${INTERFACE_INCLUDE_DIRECTORIES} - ) - target_link_libraries( wasi_nn_onnx PUBLIC vmlib - onnxruntime + onnxruntime::onnxruntime ) install(TARGETS wasi_nn_onnx DESTINATION lib) From c70e22e645c8745d30d59929c5d60f406ef59ada Mon Sep 17 00:00:00 2001 From: "dongsheng.yan" Date: Tue, 12 Aug 2025 17:05:11 +0800 Subject: [PATCH 10/10] follow up comments remove dead code release memory --- .../libraries/wasi-nn/src/wasi_nn_onnx.cpp | 63 ++++++------------- 1 file changed, 19 insertions(+), 44 deletions(-) diff --git a/core/iwasm/libraries/wasi-nn/src/wasi_nn_onnx.cpp b/core/iwasm/libraries/wasi-nn/src/wasi_nn_onnx.cpp index e310d56531..44d8d66135 100644 --- a/core/iwasm/libraries/wasi-nn/src/wasi_nn_onnx.cpp +++ b/core/iwasm/libraries/wasi-nn/src/wasi_nn_onnx.cpp @@ -51,17 +51,6 @@ typedef struct { OnnxRuntimeExecCtx exec_ctxs[MAX_CONTEXTS]; } OnnxRuntimeContext; -/* Helper functions */ -static void -check_status_and_log(const OnnxRuntimeContext *ctx, OrtStatus *status) -{ - if (status != nullptr) { - const char *msg = ctx->ort_api->GetErrorMessage(status); - NN_ERR_PRINTF("ONNX Runtime error: %s", msg); - ctx->ort_api->ReleaseStatus(status); - } -} - static wasi_nn_error convert_ort_error_to_wasi_nn_error(const OnnxRuntimeContext *ctx, OrtStatus *status) @@ -104,37 +93,6 @@ convert_ort_error_to_wasi_nn_error(const OnnxRuntimeContext *ctx, return err; } -static bool -convert_ort_type_to_wasi_nn_type(ONNXTensorElementDataType ort_type, - tensor_type *tensor_type) -{ - switch (ort_type) { - case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: - *tensor_type = fp32; - break; - case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16: - *tensor_type = fp16; - break; - case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE: - *tensor_type = fp64; - break; - case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8: - *tensor_type = u8; - break; - case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32: - *tensor_type = i32; - break; - case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64: - *tensor_type = i64; - break; - default: - NN_WARN_PRINTF("Unsupported ONNX tensor type: %d", ort_type); - return false; - } - - return true; -} - static bool convert_wasi_nn_type_to_ort_type(tensor_type type, ONNXTensorElementDataType *ort_type) @@ -191,7 +149,6 @@ init_backend(void **onnx_ctx) err = convert_ort_error_to_wasi_nn_error(ctx, status); NN_ERR_PRINTF("Failed to create ONNX Runtime environment: %s", error_message); - ctx->ort_api->ReleaseStatus(status); goto fail; } NN_INFO_PRINTF("ONNX Runtime environment created successfully"); @@ -274,6 +231,17 @@ deinit_backend(void *onnx_ctx) for (auto &output : ctx->exec_ctxs[i].outputs) { ctx->ort_api->ReleaseValue(output.second); } + + for (auto name : ctx->exec_ctxs[i].input_names) { + free((void *)name); + } + ctx->exec_ctxs[i].input_names.clear(); + + for (auto name : ctx->exec_ctxs[i].output_names) { + free((void *)name); + } + ctx->exec_ctxs[i].output_names.clear(); + ctx->ort_api->ReleaseMemoryInfo(ctx->exec_ctxs[i].memory_info); ctx->exec_ctxs[i].is_initialized = false; } @@ -293,6 +261,10 @@ __attribute__((visibility("default"))) wasi_nn_error load(void *onnx_ctx, graph_builder_array *builder, graph_encoding encoding, execution_target target, graph *g) { + if (!onnx_ctx) { + return runtime_error; + } + if (encoding != onnx) { NN_ERR_PRINTF("Unsupported encoding: %d", encoding); return invalid_encoding; @@ -349,7 +321,6 @@ load(void *onnx_ctx, graph_builder_array *builder, graph_encoding encoding, wasi_nn_error err = convert_ort_error_to_wasi_nn_error(ctx, status); NN_ERR_PRINTF("Failed to create ONNX Runtime session: %s", error_message); - ctx->ort_api->ReleaseStatus(status); return err; } @@ -365,6 +336,10 @@ load(void *onnx_ctx, graph_builder_array *builder, graph_encoding encoding, __attribute__((visibility("default"))) wasi_nn_error load_by_name(void *onnx_ctx, const char *name, uint32_t filename_len, graph *g) { + if (!onnx_ctx) { + return runtime_error; + } + OnnxRuntimeContext *ctx = (OnnxRuntimeContext *)onnx_ctx; std::lock_guard lock(ctx->mutex);