Add onnxruntime as wasi-nn backend #4485
Conversation
what's the relationship with #4304?
#endif
        default:
            NN_WARN_PRINTF("Unsupported ONNX tensor type: %d", ort_type);
            return fp32; // Default to fp32
why?
is there anything relying on this?
Follow-up:
The type converter between wasi-nn and the ONNX runtime now returns bool instead of a type.
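A minimal sketch of the bool-plus-out-parameter converter shape being discussed; the function name and the wasi-nn enum subset are illustrative assumptions, not the PR's exact code:

    // Illustration only: report success via bool and write the converted
    // type through an out-parameter, so "unsupported" is never conflated
    // with a valid type value.
    #include <onnxruntime_c_api.h>

    enum tensor_type { fp16 = 0, fp32, up8, ip32 }; // assumed wasi-nn subset

    static bool
    convert_tensor_type(tensor_type type, ONNXTensorElementDataType *out)
    {
        switch (type) {
            case fp32:
                *out = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
                return true;
            case up8:
                *out = ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8;
                return true;
            default:
                return false; // caller maps this to an invalid_argument error
        }
    }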
    std::lock_guard<std::mutex> lock(g_ort_ctx.mutex);

    if (g_ort_ctx.is_initialized) {
        *onnx_ctx = &g_ort_ctx;
why do you use globals?
i guess resources like graphs are not expected to be shared among unrelated instances.
I think it is not for sharing among instances; it is just a cache for the next run.
as g_graphs is a global, any instance can access any graph.
ditto for g_exec_ctxs.
why cache? is it very expensive to build these objects?
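A minimal sketch of the per-instance alternative the reviewer is hinting at, assuming the backend hands back an opaque context pointer; Graph, ExecutionContext, and the error enum values stand in for the PR's real types:

    #include <mutex>
    #include <new>

    struct Graph { bool is_initialized = false; /* ... */ };
    struct ExecutionContext { bool is_initialized = false; /* ... */ };

    static const int MAX_GRAPHS = 4;   // assumed limits
    static const int MAX_CONTEXTS = 4;

    struct OnnxRuntimeContext {
        std::mutex mutex;
        Graph graphs[MAX_GRAPHS];              // owned by this instance only
        ExecutionContext exec_ctxs[MAX_CONTEXTS];
    };

    enum wasi_nn_error { success = 0, runtime_error }; // subset; see wasi_nn_types.h

    wasi_nn_error
    init_backend(void **onnx_ctx)
    {
        // One context per instance instead of a shared global: graphs
        // created through this context are invisible to unrelated
        // instances and are freed when the instance is deinitialized.
        auto *ctx = new (std::nothrow) OnnxRuntimeContext();
        if (!ctx)
            return runtime_error;
        *onnx_ctx = ctx;
        return success;
    }

With this shape, graphs and execution contexts die with their instance, and no cross-instance cache is needed unless building these objects proves measurably expensive.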
"total: %zu bytes", | ||
tensor_size, element_size, output_size_bytes); | ||
|
||
if (*out_buffer_size < output_size_bytes) { |
Follow-up:
out_buffer_size will not hold the expected size.
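A minimal sketch of the in/out size convention under discussion; reporting the required size on failure, the too_large error name (from recent wasi-nn revisions), and output_data are assumptions, not the PR's exact code:

    // Sketch only: on entry *out_buffer_size is the caller's capacity;
    // on exit it holds the byte count.
    if (*out_buffer_size < output_size_bytes) {
        // Report the required size so the guest can retry with a larger
        // buffer (assumed convention), then fail with the wasi-nn
        // buffer-too-small error.
        *out_buffer_size = (uint32_t)output_size_bytes;
        return too_large;
    }
    memcpy(out_buffer, output_data, output_size_bytes); // output_data: assumed name
    *out_buffer_size = (uint32_t)output_size_bytes;     // bytes actually written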
    status = ort_ctx->ort_api->CreateTensorWithDataAsOrtValue(
        exec_ctx->memory_info, input_tensor->data.buf,
        get_tensor_element_size(static_cast<tensor_type>(input_tensor->type))
            * total_elements,
do you really need to calculate the size by yourself?
isn't input_tensor->data.size enough?
Follow-up:
The ONNX runtime will not calculate the input_tensor size.
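A minimal sketch of the reviewer's suggestion, passing the guest-supplied byte length straight through; shape, onnx_type, and input_value are assumed from the surrounding code:

    // Sketch only: use the byte length wasi-nn already carries instead of
    // recomputing it from dimensions and element size.
    status = ort_ctx->ort_api->CreateTensorWithDataAsOrtValue(
        exec_ctx->memory_info, input_tensor->data.buf,
        input_tensor->data.size, // byte length provided by the guest
        shape.data(), shape.size(), onnx_type, &input_value);

A cheap consistency check (data.size == element_size * total_elements) could still catch malformed guest input before handing it to the runtime.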
1. Adapt to the latest wasi-nn architecture and support WAMR_BUILD_WASI_EPHEMERAL_NN
1. The type converter between wasi-nn and the ONNX runtime returns bool instead of a type.
2. out_buffer_size does not hold the expected size.
3. The ONNX runtime does not need to calculate the input_tensor size.
@@ -21,7 +21,8 @@
 #else
 #define WASI_NN_IMPORT(name) \
     __attribute__((import_module("wasi_nn"), import_name(name)))
-#warning You are using "wasi_nn", which is a legacy WAMR-specific ABI. It's deperecated and will likely be removed in future versions of WAMR. Please use "wasi_ephemeral_nn" instead. (For a WASM module, use the wasi_ephemeral_nn.h header instead. For the runtime configurations, enable WASM_ENABLE_WASI_EPHEMERAL_NN/WAMR_BUILD_WASI_EPHEMERAL_NN.)
+#warning \
+    "You are using \"wasi_nn\", which is a legacy WAMR-specific ABI. It's deprecated and will likely be removed in future versions of WAMR. Please use \"wasi_ephemeral_nn\" instead. (For a WASM module, use the wasi_ephemeral_nn.h header instead. For the runtime configurations, enable WASM_ENABLE_WASI_EPHEMERAL_NN/WAMR_BUILD_WASI_EPHEMERAL_NN.)"
it's simpler to move unrelated changes to separate PRs.
1. clang-format changed the line.
2. Agreed that it is unrelated; can it be OK for this time?
what version of clang-format are you using?
wamr is currently using clang-format-14.
it didn't change the line for me.
@@ -27,7 +27,7 @@ extern "C" {
 #define WASI_NN_TYPE_NAME(name) WASI_NN_NAME(type_##name)
 #define WASI_NN_ENCODING_NAME(name) WASI_NN_NAME(encoding_##name)
 #define WASI_NN_TARGET_NAME(name) WASI_NN_NAME(target_##name)
-#define WASI_NN_ERROR_TYPE WASI_NN_NAME(error);
+#define WASI_NN_ERROR_TYPE WASI_NN_NAME(error)
ditto
ditto
Agreed that it is unrelated; can it be OK for this time?
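For context, a minimal illustration of why the trailing semicolon in the macro was a bug; the use site shown is hypothetical:

    /* With "#define WASI_NN_ERROR_TYPE WASI_NN_NAME(error);", a use such as
     *     WASI_NN_ERROR_TYPE err = success;
     * expands to
     *     WASI_NN_NAME(error); err = success;   // 'err' is never declared
     * Dropping the semicolon yields the intended single declaration:
     *     WASI_NN_NAME(error) err = success;
     */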
@@ -21,7 +21,8 @@
 #include "wasm_export.h"

 #if WASM_ENABLE_WASI_EPHEMERAL_NN == 0
-#warning You are using "wasi_nn", which is a legacy WAMR-specific ABI. It's deperecated and will likely be removed in future versions of WAMR. Please use "wasi_ephemeral_nn" instead. (For a WASM module, use the wasi_ephemeral_nn.h header instead. For the runtime configurations, enable WASM_ENABLE_WASI_EPHEMERAL_NN/WAMR_BUILD_WASI_EPHEMERAL_NN.)
+#warning \
+    "You are using \"wasi_nn\", which is a legacy WAMR-specific ABI. It's deprecated and will likely be removed in future versions of WAMR. Please use \"wasi_ephemeral_nn\" instead. (For a WASM module, use the wasi_ephemeral_nn.h header instead. For the runtime configurations, enable WASM_ENABLE_WASI_EPHEMERAL_NN/WAMR_BUILD_WASI_EPHEMERAL_NN.)"
ditto
    size_t input_tensor_size = 1;
    for (size_t i = 0; i < input_tensor->dimensions->size; ++i)
        input_tensor_size *= input_tensor->dimensions->buf[i];
model_tensor_size and input_tensor_size are now unused?
#endif
        default:
            NN_WARN_PRINTF("Unsupported wasi-nn tensor type: %d", type);
            return false; // Default to float
stale comment?
yes
    wasi_nn_error err = success;
    OrtStatus *status = nullptr;
    OnnxRuntimeContext *ctx = nullptr;
    ctx = new OnnxRuntimeContext();
we typically use wasm_runtime_malloc for memory allocation.
i don't think wamr in general is prepared for C++ exceptions.
also, malloc/free is used in this file. it's better to be consistent.
I used new just considering the std::xxx objects wrapped inside: new triggers their construction together, and delete releases them automatically. It seems wasi_nn_tensorflowlite also does that...
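If new is kept for the sake of the C++ members, a non-throwing variant would sidestep the exception concern; a minimal sketch, with the error value assumed from the surrounding code:

    #include <new> // std::nothrow

    // Sketch only: constructs the std::mutex / std::vector members like
    // plain new, but returns nullptr instead of throwing std::bad_alloc.
    OnnxRuntimeContext *ctx = new (std::nothrow) OnnxRuntimeContext();
    if (!ctx) {
        NN_ERR_PRINTF("Failed to allocate ONNX runtime context");
        return runtime_error;
    }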
    {
        OnnxRuntimeContext *ort_ctx = (OnnxRuntimeContext *)onnx_ctx;

        if (g >= MAX_GRAPHS || !ort_ctx->graphs[g].is_initialized) {
it's simpler to make these checks under the lock.
    {
        OnnxRuntimeContext *ort_ctx = (OnnxRuntimeContext *)onnx_ctx;

        if (ctx >= MAX_CONTEXTS || !ort_ctx->exec_ctxs[ctx].is_initialized) {
it's simpler to make these checks under the lock.
    {
        OnnxRuntimeContext *ort_ctx = (OnnxRuntimeContext *)onnx_ctx;

        if (ctx >= MAX_CONTEXTS || !ort_ctx->exec_ctxs[ctx].is_initialized) {
it's simpler to make these checks under the lock.
    {
        OnnxRuntimeContext *ort_ctx = (OnnxRuntimeContext *)onnx_ctx;

        if (ctx >= MAX_CONTEXTS || !ort_ctx->exec_ctxs[ctx].is_initialized) {
it's simpler to make these checks under the lock.
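A minimal sketch of what "under the lock" means here, assuming ort_ctx->mutex guards the exec_ctxs table; the wrapper function and error names are illustrative:

    // Sketch only: take the lock before the bounds/initialization check,
    // so the check and the subsequent use of the slot cannot race with a
    // concurrent deinit of the same slot.
    wasi_nn_error
    check_and_use_exec_ctx(OnnxRuntimeContext *ort_ctx, uint32_t ctx)
    {
        std::lock_guard<std::mutex> lock(ort_ctx->mutex);
        if (ctx >= MAX_CONTEXTS || !ort_ctx->exec_ctxs[ctx].is_initialized)
            return invalid_argument; // checked while holding the lock
        // ... use ort_ctx->exec_ctxs[ctx] while the lock is still held ...
        return success;
    }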
        NN_ERR_PRINTF("Failed to get input name");
        return err;
    }
    exec_ctx->input_names.push_back(input_name);
can push_back raise an exception?
shall we catch that?
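A minimal sketch of catching the potential allocation failure locally; the AllocatorFree cleanup is an assumption about how input_name was allocated (via the session's OrtAllocator):

    // Sketch only: push_back can throw std::bad_alloc; catching it here
    // keeps the exception from unwinding into the C runtime.
    try {
        exec_ctx->input_names.push_back(input_name);
    }
    catch (const std::exception &) {
        ort_ctx->ort_api->AllocatorFree(allocator, input_name); // assumed owner
        return runtime_error;
    }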
Steps to verify:
Generate output.bin with shape [1, 100, 4] and f32 type, whose contents match the sample's output.
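A minimal sketch of a size sanity check for the generated file, assuming output.bin sits in the current directory; [1, 100, 4] f32 elements come to 1 * 100 * 4 * 4 = 1600 bytes:

    // Sketch only: verify output.bin has the expected byte size for
    // shape [1, 100, 4] with f32 elements (400 elements * 4 bytes).
    #include <cstdio>

    int main()
    {
        std::FILE *f = std::fopen("output.bin", "rb");
        if (!f)
            return 1;
        std::fseek(f, 0, SEEK_END);
        long size = std::ftell(f);
        std::fclose(f);
        return size == 1600 ? 0 : 1; // 0 means the size matches
    }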