Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 44 additions & 28 deletions core/iwasm/libraries/wasi-nn/src/wasi_nn_tensorflowlite.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include "wasi_nn_backend.h"
#include "wasm_export.h"

#include <tensorflow/lite/c/c_api.h>
#include <tensorflow/lite/interpreter.h>
#include <tensorflow/lite/kernels/register.h>
#include <tensorflow/lite/model.h>
Expand Down Expand Up @@ -279,29 +280,53 @@ set_input(void *tflite_ctx, graph_execution_context ctx, uint32_t index,
tensor *input_tensor)
{
TFLiteContext *tfl_ctx = (TFLiteContext *)tflite_ctx;
TfLiteType tfl_type;

if (input_tensor->type != fp32) {
NN_ERR_PRINTF("unsupported input tensor type %u", input_tensor->type);
return runtime_error;
switch (input_tensor->type) {
case fp32:
tfl_type = TfLiteType::kTfLiteFloat32;
break;
#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
case u8:
tfl_type = TfLiteType::kTfLiteUInt8;
break;
#endif
default:
NN_ERR_PRINTF("unsupported input tensor type %u",
input_tensor->type);
return runtime_error;
}

wasi_nn_error res;
if (success != (res = is_valid_graph_execution_context(tfl_ctx, ctx)))
return res;

uint32_t num_tensors =
tfl_ctx->interpreters[ctx].interpreter->inputs().size();
auto interpreter = tfl_ctx->interpreters[ctx].interpreter.get();

uint32_t num_tensors = interpreter->inputs().size();
NN_DBG_PRINTF("Number of tensors (%d)", num_tensors);
if (index + 1 > num_tensors) {
return runtime_error;
}

auto tensor = tfl_ctx->interpreters[ctx].interpreter->input_tensor(index);
auto tensor = interpreter->input_tensor(index);
if (tensor == NULL) {
NN_ERR_PRINTF("Missing memory");
return too_large;
}

#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
if (TfLiteTensorType(tensor) != tfl_type) {
NN_ERR_PRINTF("Type mismatch");
return runtime_error;
}

if (TfLiteTensorCopyFromBuffer(tensor, input_tensor->data.buf,
input_tensor->data.size)
!= kTfLiteOk) {
return runtime_error;
}
#else
uint32_t model_tensor_size = 1;
for (int i = 0; i < tensor->dims->size; ++i)
model_tensor_size *= (uint32_t)tensor->dims->data[i];
Expand Down Expand Up @@ -346,6 +371,7 @@ set_input(void *tflite_ctx, graph_execution_context ctx, uint32_t index,
it[i] = (uint8_t)(input_tensor_f[i] / scale + zero_point);
}
}
#endif

return success;
}
Expand Down Expand Up @@ -388,14 +414,19 @@ get_output(void *tflite_ctx, graph_execution_context ctx, uint32_t index,
return too_large;
}

if (tensor->quantization.type == kTfLiteNoQuantization) {
NN_DBG_PRINTF("No quantization information");
#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
if (output_tensor->size < tensor->bytes) {
NN_ERR_PRINTF("Insufficient memory to copy tensor %d", index);
return too_large;
}
size_t sz = TfLiteTensorByteSize(tensor);
if (output_tensor->size < sz) {
NN_ERR_PRINTF("Insufficient memory to copy tensor %d", index);
return too_large;
}
if (TfLiteTensorCopyToBuffer(tensor, output_tensor->buf, sz) != kTfLiteOk) {
return runtime_error;
}
*output_tensor_size = sz;
#else
if (tensor->quantization.type == kTfLiteNoQuantization) {
NN_DBG_PRINTF("No quantization information");
/*
* for now, maintain the bug-to-bug compatibility with the old abi,
* where the size here is the number of fp32, not bytes.
Expand All @@ -404,18 +435,13 @@ get_output(void *tflite_ctx, graph_execution_context ctx, uint32_t index,
NN_ERR_PRINTF("Insufficient memory to copy tensor %d", index);
return too_large;
}
#endif
bh_memcpy_s(output_tensor->buf, output_tensor->size, tensor->data.data,
tensor->bytes);
#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
*output_tensor_size = tensor->bytes;
#else
/*
* for now, maintain the bug-to-bug compatibility with the old abi,
* where the size here is the number of fp32, not bytes.
*/
*output_tensor_size = tensor->bytes / sizeof(float);
#endif
}
else { // TODO: Assuming uint8 quantized networks.
TfLiteAffineQuantization *quant_info =
Expand All @@ -429,12 +455,6 @@ get_output(void *tflite_ctx, graph_execution_context ctx, uint32_t index,
for (int i = 0; i < (int)tensor->dims->size; ++i)
model_tensor_size *= (uint32_t)tensor->dims->data[i];

#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
if (output_tensor->size / sizeof(float) < model_tensor_size) {
NN_ERR_PRINTF("Insufficient memory to copy tensor %d", index);
return too_large;
}
#else
/*
* for now, maintain the bug-to-bug compatibility with the old abi,
* where the size here is the number of fp32, not bytes.
Expand All @@ -443,7 +463,6 @@ get_output(void *tflite_ctx, graph_execution_context ctx, uint32_t index,
NN_ERR_PRINTF("Insufficient memory to copy tensor %d", index);
return too_large;
}
#endif

uint8_t *ot = tfl_ctx->interpreters[ctx]
.interpreter->typed_output_tensor<uint8_t>(index);
Expand All @@ -458,16 +477,13 @@ get_output(void *tflite_ctx, graph_execution_context ctx, uint32_t index,
output_tensor_f[i] = (ot[i] - zero_point) * scale;
}

#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
*output_tensor_size = model_tensor_size * sizeof(float);
#else
/*
* for now, maintain the bug-to-bug compatibility with the old abi,
* where the size here is the number of fp32, not bytes.
*/
*output_tensor_size = model_tensor_size;
#endif
}
#endif

return success;
}
Expand Down
Loading