diff --git a/src/frontends/onnx/frontend/include/openvino/frontend/onnx/decoder.hpp b/src/frontends/onnx/frontend/include/openvino/frontend/onnx/decoder.hpp
index a7b3fd0f2cac67..3cfde73f9476ba 100644
--- a/src/frontends/onnx/frontend/include/openvino/frontend/onnx/decoder.hpp
+++ b/src/frontends/onnx/frontend/include/openvino/frontend/onnx/decoder.hpp
@@ -16,6 +16,7 @@ struct ONNX_FRONTEND_API TensorMetaInfo {
     ov::PartialShape m_partial_shape;
     ov::element::Type m_element_type;
     const uint8_t* m_tensor_data;
+    ov::Any m_tensor_data_any;
     size_t m_tensor_data_size;
     const std::string* m_tensor_name;
     std::shared_ptr<std::string> m_external_location;
diff --git a/src/frontends/onnx/frontend/src/core/decoder_proto.hpp b/src/frontends/onnx/frontend/src/core/decoder_proto.hpp
index 946a262e1ccb0a..a41923f48f24f3 100644
--- a/src/frontends/onnx/frontend/src/core/decoder_proto.hpp
+++ b/src/frontends/onnx/frontend/src/core/decoder_proto.hpp
@@ -166,7 +166,7 @@ class DecoderProto : public ov::frontend::onnx::DecoderBaseOperation {
     }
 
     const std::string& get_domain() const override {
-        return (m_node->has_domain() && m_node->domain() != "ai.onnx" ? m_node->domain() : DEFAULT_DOMAIN);
+        return (m_node->has_domain() ? m_node->domain() : DEFAULT_DOMAIN);
     }
 
     bool has_attribute(const std::string& name) const override;
diff --git a/src/frontends/onnx/frontend/src/core/graph_iterator_proto.cpp b/src/frontends/onnx/frontend/src/core/graph_iterator_proto.cpp
index a8a457b631573a..697bf1723304ee 100644
--- a/src/frontends/onnx/frontend/src/core/graph_iterator_proto.cpp
+++ b/src/frontends/onnx/frontend/src/core/graph_iterator_proto.cpp
@@ -6,10 +6,17 @@
 #include 
+#include 
+#include 
+#include 
+#include 
 #include 
-#include 
+#include 
+#include 
+#include 
 
 #include "decoder_proto.hpp"
+#include "openvino/frontend/graph_iterator.hpp"
 #include "openvino/frontend/onnx/graph_iterator.hpp"
 #include "openvino/util/file_util.hpp"
 #include "openvino/util/wstring_convert_util.hpp"
@@ -259,6 +266,11 @@ ov::frontend::onnx::TensorMetaInfo extract_tensor_meta_info(const TensorProto* t
             static_cast<const uint8_t*>(static_cast<const void*>(tensor_info->double_data().data()));
         tensor_meta_info.m_tensor_data_size = tensor_info->double_data_size();
         break;
+    case TensorProto_DataType::TensorProto_DataType_STRING:
+        tensor_meta_info.m_tensor_data_any =
+            std::vector<std::string>(tensor_info->string_data().begin(), tensor_info->string_data().end());
+        tensor_meta_info.m_tensor_data_size = tensor_info->string_data_size();
+        break;
     default:
         throw std::runtime_error("Unsupported type " +
                                  ::ONNX_NAMESPACE::TensorProto_DataType_Name(tensor_info->data_type()));
@@ -304,10 +316,10 @@ void GraphIteratorProto::initialize(const std::string& path) {
     m_model_dir = std::make_shared<std::string>(ov::util::get_directory(path).string());
     try {
         std::ifstream model_file(path, std::ios::binary | std::ios::in);
-        FRONT_END_GENERAL_CHECK(model_file && model_file.is_open(), "Model file does not exist: ", path);
+        FRONT_END_GENERAL_CHECK(model_file && model_file.is_open(), "Could not open the file: ", path);
 
         m_model = std::make_shared<ModelProto>();
-        m_model->ParseFromIstream(&model_file);
+        FRONT_END_GENERAL_CHECK(m_model->ParseFromIstream(&model_file), "Model can't be parsed");
         model_file.close();
         if (m_model->has_graph()) {
             m_graph = &m_model->graph();
@@ -320,6 +332,8 @@ void GraphIteratorProto::initialize(const std::string& path) {
         m_graph = nullptr;
         node_index = 0;
         m_decoders.clear();
+        m_tensors.clear();
+        throw;
     }
 }
@@ -445,7 +459,7 @@ void GraphIteratorProto::reset() {
                 output_tensors.push_back(&this->get_tensor(empty_name, &tensor_owner)->get_tensor_info());
             }
         }
-        const std::string& domain = node.has_domain() && node.domain() != "ai.onnx" ? node.domain() : DEFAULT_DOMAIN;
+        const std::string& domain = node.has_domain() ? node.domain() : DEFAULT_DOMAIN;
         int64_t opset = get_opset_version(domain);
         if (opset == -1) {
             // Forcing a first opset instead of failing
@@ -477,7 +491,8 @@ std::int64_t GraphIteratorProto::get_opset_version(const std::string& domain) co
     });
 
     for (const auto& opset_import : opset_imports) {
-        if (domain == opset_import.domain()) {
+        if ((domain == DEFAULT_DOMAIN && opset_import.domain() == "ai.onnx") ||
+            (domain == "ai.onnx" && opset_import.domain() == DEFAULT_DOMAIN) || (domain == opset_import.domain())) {
             return opset_import.version();
         }
     }
@@ -485,21 +500,6 @@ std::int64_t GraphIteratorProto::get_opset_version(const std::string& domain) co
     return -1;
 }
 
-}  // namespace onnx
-}  // namespace frontend
-}  // namespace ov
-
-#include 
-#include 
-#include 
-#include 
-#include 
-#include 
-#include 
-
-namespace ov {
-namespace frontend {
-namespace onnx {
 namespace detail {
 namespace {
 enum Field {
diff --git a/src/frontends/onnx/frontend/src/core/node.cpp b/src/frontends/onnx/frontend/src/core/node.cpp
index 97c57c81149018..dc7c3d90e9a134 100644
--- a/src/frontends/onnx/frontend/src/core/node.cpp
+++ b/src/frontends/onnx/frontend/src/core/node.cpp
@@ -800,6 +800,7 @@ Tensor Node::get_attribute_value(const std::string& name) const {
                                              std::vector<std::string>{*tensor_meta_info.m_tensor_name},
                                              tensor_meta_info.m_tensor_data,
                                              tensor_meta_info.m_tensor_data_size,
+                                             tensor_meta_info.m_tensor_data_any,
                                              tensor_meta_info.m_external_location,
                                              tensor_meta_info.m_is_raw);
     return {tensor_place};
@@ -826,6 +827,7 @@ SparseTensor Node::get_attribute_value(const std::string& name) const {
                                              std::vector<std::string>{*values_meta_info.m_tensor_name},
                                              values_meta_info.m_tensor_data,
                                              values_meta_info.m_tensor_data_size,
+                                             values_meta_info.m_tensor_data_any,
                                              values_meta_info.m_external_location,
                                              values_meta_info.m_is_raw);
 
@@ -839,6 +841,7 @@ SparseTensor Node::get_attribute_value(const std::string& name) const {
                                              std::vector<std::string>{*indices_meta_info.m_tensor_name},
                                              indices_meta_info.m_tensor_data,
                                              indices_meta_info.m_tensor_data_size,
+                                             indices_meta_info.m_tensor_data_any,
                                              indices_meta_info.m_external_location,
                                              indices_meta_info.m_is_raw);
     return {values_place, indices_place, sparse_tensor_info.m_partial_shape};
diff --git a/src/frontends/onnx/frontend/src/core/tensor.cpp b/src/frontends/onnx/frontend/src/core/tensor.cpp
index 6e4574b30dc8f0..5ba4f461295e79 100644
--- a/src/frontends/onnx/frontend/src/core/tensor.cpp
+++ b/src/frontends/onnx/frontend/src/core/tensor.cpp
@@ -366,13 +366,15 @@ std::vector Tensor::get_data() const {
 
 template <>
 std::vector<std::string> Tensor::get_data() const {
-    if (m_tensor_place != nullptr) {
-        FRONT_END_NOT_IMPLEMENTED(get_data);
-    }
-
     if (has_external_data()) {
         FRONT_END_THROW("External strings are not supported");
     }
+    if (m_tensor_place != nullptr) {
+        FRONT_END_GENERAL_CHECK(!m_tensor_place->is_raw(), "Loading strings from raw data isn't supported");
+        FRONT_END_GENERAL_CHECK(m_tensor_place->get_data_any().is<std::vector<std::string>>(),
+                                "Tensor data type mismatch for strings");
+        return m_tensor_place->get_data_any().as<std::vector<std::string>>();
+    }
     if (m_tensor_proto->has_raw_data()) {
         FRONT_END_THROW("Loading strings from raw data isn't supported");
     }
@@ -477,11 +479,6 @@ std::shared_ptr<ov::op::v0::Constant> Tensor::get_ov_constant() const {
                                     "UINT4, UINT8, UINT16, UINT32, UINT64, STRING");
         }
     } else if (element_count == shape_size(m_shape) && m_tensor_place != nullptr) {
-#if 0
-        constant = std::make_shared<ov::op::v0::Constant>(m_tensor_place->get_element_type(),
-                                                          m_tensor_place->get_partial_shape().get_shape(),
-                                                          m_tensor_place->get_data());
-#endif
         switch (m_tensor_place->get_element_type()) {
         case ov::element::f32:
         case ov::element::f64:
@@ -524,6 +521,9 @@ std::shared_ptr<ov::op::v0::Constant> Tensor::get_ov_constant() const {
         case ov::element::f8e5m2:
            constant = std::make_shared<ov::op::v0::Constant>(ov_type, m_shape, get_data().data());
            break;
+        case ov::element::string:
+           constant = std::make_shared<ov::op::v0::Constant>(ov_type, m_shape, get_data<std::string>().data());
+           break;
         default:
             ONNX_UNSUPPORTED_DATA_TYPE(
                 m_tensor_proto->data_type(),
diff --git a/src/frontends/onnx/frontend/src/core/tensor.hpp b/src/frontends/onnx/frontend/src/core/tensor.hpp
index aa904610323189..dae0bb3547101e 100644
--- a/src/frontends/onnx/frontend/src/core/tensor.hpp
+++ b/src/frontends/onnx/frontend/src/core/tensor.hpp
@@ -85,11 +85,13 @@ class TensorONNXPlace : public ov::frontend::onnx::TensorPlace {
                     const std::vector<std::string>& names,
                     const void* data,
                     const size_t data_size,
+                    const ov::Any& data_any,
                     std::shared_ptr<std::string> data_location,
                     const bool is_raw)
         : ov::frontend::onnx::TensorPlace(input_model, pshape, type, names),
           m_input_model(input_model),
           m_data(data),
+          m_data_any(data_any),
          m_data_size(data_size),
          m_data_location(data_location),
          m_is_raw(is_raw) {};
@@ -125,6 +127,10 @@ class TensorONNXPlace : public ov::frontend::onnx::TensorPlace {
         return m_data_size;
     }
 
+    const ov::Any get_data_any() const {
+        return m_data_any;
+    }
+
     std::shared_ptr<std::string> get_data_location() const {
         return m_data_location;
     }
@@ -140,6 +146,7 @@ class TensorONNXPlace : public ov::frontend::onnx::TensorPlace {
     int64_t m_input_idx = -1, m_output_idx = -1;
     const ov::frontend::InputModel& m_input_model;
     const void* m_data;
+    ov::Any m_data_any;
     size_t m_data_size;
     std::shared_ptr<std::string> m_data_location;
     bool m_is_raw;
diff --git a/src/frontends/onnx/frontend/src/input_model.cpp b/src/frontends/onnx/frontend/src/input_model.cpp
index 00d9972f9c5eac..ee9d83120842a3 100644
--- a/src/frontends/onnx/frontend/src/input_model.cpp
+++ b/src/frontends/onnx/frontend/src/input_model.cpp
@@ -656,6 +656,7 @@ std::shared_ptr<TensorONNXPlace> decode_tensor_place(
                                           std::vector<std::string>{*tensor_meta_info.m_tensor_name},
                                           tensor_meta_info.m_tensor_data,
                                           tensor_meta_info.m_tensor_data_size,
+                                          tensor_meta_info.m_tensor_data_any,
                                           tensor_meta_info.m_external_location,
                                           tensor_meta_info.m_is_raw);
     return tensor_place;
@@ -774,19 +775,8 @@ InputModel::InputModelONNXImpl::InputModelONNXImpl(const GraphIterator::Ptr& gra
                                                    const ov::frontend::InputModel& input_model,
                                                    const std::shared_ptr<TelemetryExtension>& telemetry,
                                                    const bool enable_mmap)
-    : m_graph_iterator(graph_iterator),
-      m_input_model(input_model),
-      m_telemetry(telemetry),
-      m_enable_mmap(enable_mmap) {
-    FRONT_END_GENERAL_CHECK(m_graph_iterator, "Null pointer specified for GraphIterator");
-    if (m_enable_mmap) {
-        m_mmap_cache = std::make_shared<std::map<std::string, std::shared_ptr<ov::MappedMemory>>>();
-        m_stream_cache = nullptr;
-    } else {
-        m_mmap_cache = nullptr;
-        m_stream_cache = std::make_shared<std::map<std::string, std::shared_ptr<std::ifstream>>>();
-    }
-    load_model();
+    : InputModelONNXImpl(graph_iterator, input_model, enable_mmap) {
+    m_telemetry = telemetry;
 }
 
 InputModel::InputModelONNXImpl::InputModelONNXImpl(const GraphIterator::Ptr& graph_iterator,
diff --git a/src/frontends/onnx/tests/onnx_importer_test.cpp b/src/frontends/onnx/tests/onnx_importer_test.cpp
index 6dc364edd45221..fca22e1ad1cfef 100644
--- a/src/frontends/onnx/tests/onnx_importer_test.cpp
+++ b/src/frontends/onnx/tests/onnx_importer_test.cpp
@@ -104,7 +104,7 @@ TEST(ONNX_Importer_Tests, ImportModelWithNotSupportedOp) {
 TEST(ONNX_Importer_Tests, ImportModelWhenFileDoesNotExist) {
     try {
         auto model = convert_model("not_exist_file.onnx");
-        FAIL() << "Any expection was thrown despite the ONNX model file does not exist";
+        FAIL() << "No exception was thrown even though the ONNX model file does not exist";
     } catch (const Exception& error) {
         EXPECT_PRED_FORMAT2(testing::IsSubstring, std::string("Could not open the file"), error.what());
     } catch (...) {