From cbd0fe1b5f92ebc3dc17a49a58fe334d32d809e8 Mon Sep 17 00:00:00 2001 From: YAMAMOTO Takashi Date: Mon, 28 Jul 2025 09:55:33 +0900 Subject: [PATCH] wasi_nn_tensorflowlite.cpp: make this compatible with wasmedge for wasi_ephemeral_nn, * implement u8 input * stop dealing with quantization. * wasi-nn doesn't have a concept of quantization or pre/post-processing. I can't think of any way to make the backend perform zero-point/scale processing without risking breaking other applications. * there seem to be applications which just use u8 inputs/outputs for a quantized model. (see [1] for an example.) for certain kinds of inputs/outputs, it usually just works. this commit keeps the legacy wasi_nn logic intact for now. tested with [1] with [2] applied. WAMR with this patch: ``` Read graph weights, size in bytes: 3561598 [wasi_nn.c:297 WARNING] load_by_name_with_config() not found [wasi_nn_tensorflowlite.cpp:272 WARNING] Default encoding is CPU. Loaded graph into wasi-nn with ID: Graph#0 Read input tensor, size in bytes: 150528 1.) [166](198)Aix galericulata 2.) [34](1)Gallus gallus domesticus 3.) [158](1)Coccothraustes coccothraustes 4.) [778](1)Sitta europaea 5.) [819](1)Anas platyrhynchos ``` wasmedge: ``` Read graph weights, size in bytes: 3561598 Loaded graph into wasi-nn with ID: Graph#0 Read input tensor, size in bytes: 150528 1.) [166](198)Aix galericulata 2.) [34](1)Gallus gallus domesticus 3.) [158](1)Coccothraustes coccothraustes 4.) [778](1)Sitta europaea 5.) [819](1)Anas platyrhynchos ``` and "Aix galericulata" seems like a reasonable classification of the image to my eyes. 
[1] https://github.com/second-state/WasmEdge-WASINN-examples/tree/67f174bab59d98c1b52f7367ec0928701dc998f9/tflite-birds_v1-image [2] https://github.com/second-state/WasmEdge-WASINN-examples/pull/204 Related: https://github.com/bytecodealliance/wasm-micro-runtime/issues/3555 https://github.com/bytecodealliance/wasm-micro-runtime/issues/2611 --- .../wasi-nn/src/wasi_nn_tensorflowlite.cpp | 72 +++++++++++-------- 1 file changed, 44 insertions(+), 28 deletions(-) diff --git a/core/iwasm/libraries/wasi-nn/src/wasi_nn_tensorflowlite.cpp b/core/iwasm/libraries/wasi-nn/src/wasi_nn_tensorflowlite.cpp index 819bd52aff..9ac54e6644 100644 --- a/core/iwasm/libraries/wasi-nn/src/wasi_nn_tensorflowlite.cpp +++ b/core/iwasm/libraries/wasi-nn/src/wasi_nn_tensorflowlite.cpp @@ -9,6 +9,7 @@ #include "wasi_nn_backend.h" #include "wasm_export.h" +#include #include #include #include @@ -279,29 +280,53 @@ set_input(void *tflite_ctx, graph_execution_context ctx, uint32_t index, tensor *input_tensor) { TFLiteContext *tfl_ctx = (TFLiteContext *)tflite_ctx; + TfLiteType tfl_type; - if (input_tensor->type != fp32) { - NN_ERR_PRINTF("unsupported input tensor type %u", input_tensor->type); - return runtime_error; + switch (input_tensor->type) { + case fp32: + tfl_type = TfLiteType::kTfLiteFloat32; + break; +#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0 + case u8: + tfl_type = TfLiteType::kTfLiteUInt8; + break; +#endif + default: + NN_ERR_PRINTF("unsupported input tensor type %u", + input_tensor->type); + return runtime_error; } wasi_nn_error res; if (success != (res = is_valid_graph_execution_context(tfl_ctx, ctx))) return res; - uint32_t num_tensors = - tfl_ctx->interpreters[ctx].interpreter->inputs().size(); + auto interpreter = tfl_ctx->interpreters[ctx].interpreter.get(); + + uint32_t num_tensors = interpreter->inputs().size(); NN_DBG_PRINTF("Number of tensors (%d)", num_tensors); if (index + 1 > num_tensors) { return runtime_error; } - auto tensor = 
tfl_ctx->interpreters[ctx].interpreter->input_tensor(index); + auto tensor = interpreter->input_tensor(index); if (tensor == NULL) { NN_ERR_PRINTF("Missing memory"); return too_large; } +#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0 + if (TfLiteTensorType(tensor) != tfl_type) { + NN_ERR_PRINTF("Type mismatch"); + return runtime_error; + } + + if (TfLiteTensorCopyFromBuffer(tensor, input_tensor->data.buf, + input_tensor->data.size) + != kTfLiteOk) { + return runtime_error; + } +#else uint32_t model_tensor_size = 1; for (int i = 0; i < tensor->dims->size; ++i) model_tensor_size *= (uint32_t)tensor->dims->data[i]; @@ -346,6 +371,7 @@ set_input(void *tflite_ctx, graph_execution_context ctx, uint32_t index, it[i] = (uint8_t)(input_tensor_f[i] / scale + zero_point); } } +#endif return success; } @@ -388,14 +414,19 @@ get_output(void *tflite_ctx, graph_execution_context ctx, uint32_t index, return too_large; } - if (tensor->quantization.type == kTfLiteNoQuantization) { - NN_DBG_PRINTF("No quantization information"); #if WASM_ENABLE_WASI_EPHEMERAL_NN != 0 - if (output_tensor->size < tensor->bytes) { - NN_ERR_PRINTF("Insufficient memory to copy tensor %d", index); - return too_large; - } + size_t sz = TfLiteTensorByteSize(tensor); + if (output_tensor->size < sz) { + NN_ERR_PRINTF("Insufficient memory to copy tensor %d", index); + return too_large; + } + if (TfLiteTensorCopyToBuffer(tensor, output_tensor->buf, sz) != kTfLiteOk) { + return runtime_error; + } + *output_tensor_size = sz; #else + if (tensor->quantization.type == kTfLiteNoQuantization) { + NN_DBG_PRINTF("No quantization information"); /* * for now, maintain the bug-to-bug compatibility with the old abi, * where the size here is the number of fp32, not bytes. 
@@ -404,18 +435,13 @@ get_output(void *tflite_ctx, graph_execution_context ctx, uint32_t index, NN_ERR_PRINTF("Insufficient memory to copy tensor %d", index); return too_large; } -#endif bh_memcpy_s(output_tensor->buf, output_tensor->size, tensor->data.data, tensor->bytes); -#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0 - *output_tensor_size = tensor->bytes; -#else /* * for now, maintain the bug-to-bug compatibility with the old abi, * where the size here is the number of fp32, not bytes. */ *output_tensor_size = tensor->bytes / sizeof(float); -#endif } else { // TODO: Assuming uint8 quantized networks. TfLiteAffineQuantization *quant_info = @@ -429,12 +455,6 @@ get_output(void *tflite_ctx, graph_execution_context ctx, uint32_t index, for (int i = 0; i < (int)tensor->dims->size; ++i) model_tensor_size *= (uint32_t)tensor->dims->data[i]; -#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0 - if (output_tensor->size / sizeof(float) < model_tensor_size) { - NN_ERR_PRINTF("Insufficient memory to copy tensor %d", index); - return too_large; - } -#else /* * for now, maintain the bug-to-bug compatibility with the old abi, * where the size here is the number of fp32, not bytes. @@ -443,7 +463,6 @@ get_output(void *tflite_ctx, graph_execution_context ctx, uint32_t index, NN_ERR_PRINTF("Insufficient memory to copy tensor %d", index); return too_large; } -#endif uint8_t *ot = tfl_ctx->interpreters[ctx] .interpreter->typed_output_tensor(index); @@ -458,16 +477,13 @@ get_output(void *tflite_ctx, graph_execution_context ctx, uint32_t index, output_tensor_f[i] = (ot[i] - zero_point) * scale; } -#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0 - *output_tensor_size = model_tensor_size * sizeof(float); -#else /* * for now, maintain the bug-to-bug compatibility with the old abi, * where the size here is the number of fp32, not bytes. */ *output_tensor_size = model_tensor_size; -#endif } +#endif return success; }