Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ class ZeGraphExtWrappers {

GraphDescriptor getGraphDescriptor(SerializedIR serializedIR,
const std::string& buildFlags,
const uint32_t& flags) const;
const Config& config) const;

GraphDescriptor getGraphDescriptor(void* data, size_t size) const;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -101,15 +101,8 @@ std::shared_ptr<IGraph> DriverCompilerAdapter::compile(const std::shared_ptr<con

_logger.debug("compileIR Build flags : %s", buildFlags.c_str());

// If UMD Caching is requested to be bypassed or if OV cache is enabled, disable driver caching
uint32_t flags = ZE_GRAPH_FLAG_NONE;
const auto set_cache_dir = config.get<CACHE_DIR>();
if (!set_cache_dir.empty() || config.get<BYPASS_UMD_CACHING>()) {
flags = flags | ZE_GRAPH_FLAG_DISABLE_CACHING;
}

_logger.debug("compile start");
auto graphDesc = _zeGraphExt->getGraphDescriptor(std::move(serializedIR), buildFlags, flags);
auto graphDesc = _zeGraphExt->getGraphDescriptor(std::move(serializedIR), buildFlags, config);
_logger.debug("compile end");

OV_ITT_TASK_NEXT(COMPILE_BLOB, "getNetworkMeta");
Expand Down Expand Up @@ -161,13 +154,6 @@ std::shared_ptr<IGraph> DriverCompilerAdapter::compileWS(const std::shared_ptr<o
}
FilteredConfig updatedConfig = *plgConfig;

// If UMD Caching is requested to be bypassed or if OV cache is enabled, disable driver caching
uint32_t flags = ZE_GRAPH_FLAG_NONE;
const auto set_cache_dir = config.get<CACHE_DIR>();
if (!set_cache_dir.empty() || config.get<BYPASS_UMD_CACHING>()) {
flags = flags | ZE_GRAPH_FLAG_DISABLE_CACHING;
}

// WS v3 is based on a stateless compiler. We'll use a separate config entry for informing the compiler the index of
// the current call iteration.
std::vector<NetworkMetadata> initNetworkMetadata;
Expand All @@ -191,7 +177,7 @@ std::shared_ptr<IGraph> DriverCompilerAdapter::compileWS(const std::shared_ptr<o
buildFlags += irSerializer.serializeConfig(updatedConfig, compilerVersion);

_logger.debug("compile start");
auto graphDesc = _zeGraphExt->getGraphDescriptor(serializedIR, buildFlags, flags);
auto graphDesc = _zeGraphExt->getGraphDescriptor(serializedIR, buildFlags, config);
_logger.debug("compile end");

OV_ITT_TASK_NEXT(COMPILE_BLOB, "getNetworkMeta");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

#include <string_view>

#include "intel_npu/config/options.hpp"
#include "intel_npu/prefix.hpp"
#include "intel_npu/utils/utils.hpp"
#include "intel_npu/utils/zero/zero_api.hpp"
Expand Down Expand Up @@ -131,10 +132,6 @@ ZeGraphExtWrappers::ZeGraphExtWrappers(const std::shared_ptr<ZeroInitStructsHold
ZE_MAJOR_VERSION(_graphExtVersion),
ZE_MINOR_VERSION(_graphExtVersion));
_logger.debug("capabilities:");
_logger.debug("-SupportQuery: %d", true);
_logger.debug("-SupportAPIGraphQueryNetworkV1: %d", true);
_logger.debug("-SupportAPIGraphQueryNetworkV2 :%d", true);
_logger.debug("-SupportpfnCreate2 :%d", true);
_logger.debug("-SupportArgumentMetadata :%d", !NotSupportArgumentMetadata(_graphExtVersion));
_logger.debug("-UseCopyForNativeBinary :%d", UseCopyForNativeBinary(_graphExtVersion));
}
Expand Down Expand Up @@ -274,10 +271,8 @@ static std::unordered_set<std::string> parseQueryResult(std::vector<char>& data)

std::unordered_set<std::string> ZeGraphExtWrappers::queryGraph(SerializedIR serializedIR,
const std::string& buildFlags) const {
// For ext version >= 1.5
ze_graph_query_network_handle_t hGraphQueryNetwork = nullptr;

// For ext version >= 1.5
ze_graph_desc_2_t desc = {ZE_STRUCTURE_TYPE_GRAPH_DESC_PROPERTIES,
nullptr,
ZE_GRAPH_FORMAT_NGRAPH_LITE,
Expand All @@ -286,14 +281,14 @@ std::unordered_set<std::string> ZeGraphExtWrappers::queryGraph(SerializedIR seri
buildFlags.c_str(),
ZE_GRAPH_FLAG_NONE};

// Create querynetwork handle
_logger.debug("For ext larger than 1.4 - perform pfnQueryNetworkCreate2");
ze_result_t result = _zeroInitStruct->getGraphDdiTable().pfnQueryNetworkCreate2(_zeroInitStruct->getContext(),
_zeroInitStruct->getDevice(),
&desc,
&hGraphQueryNetwork);
_logger.debug("queryGraph - perform pfnQueryNetworkCreate2");
auto result = _zeroInitStruct->getGraphDdiTable().pfnQueryNetworkCreate2(_zeroInitStruct->getContext(),
_zeroInitStruct->getDevice(),
&desc,
&hGraphQueryNetwork);
THROW_ON_FAIL_FOR_LEVELZERO_EXT("pfnQueryNetworkCreate2", result, _zeroInitStruct->getGraphDdiTable());

// Get the size of query result
_logger.debug("queryGraph - perform pfnQueryNetworkGetSupportedLayers to get size");
size_t size = 0;
result = _zeroInitStruct->getGraphDdiTable().pfnQueryNetworkGetSupportedLayers(hGraphQueryNetwork, &size, nullptr);
Expand Down Expand Up @@ -341,8 +336,16 @@ bool ZeGraphExtWrappers::canCpuVaBeImported(void* data, size_t size) const {

GraphDescriptor ZeGraphExtWrappers::getGraphDescriptor(SerializedIR serializedIR,
const std::string& buildFlags,
const uint32_t& flags) const {
// For ext version >= 1.5, calling pfnCreate2 api in _zeroInitStruct->getGraphDdiTable()
const Config& config) const {
ze_graph_handle_t graphHandle = nullptr;

// If UMD Caching is requested to be bypassed or if OV cache is enabled, disable driver caching
uint32_t flags = ZE_GRAPH_FLAG_NONE;
const auto set_cache_dir = config.get<CACHE_DIR>();
if (!set_cache_dir.empty() || config.get<BYPASS_UMD_CACHING>()) {
flags = flags | ZE_GRAPH_FLAG_DISABLE_CACHING;
}

ze_graph_desc_2_t desc = {ZE_STRUCTURE_TYPE_GRAPH_DESC_PROPERTIES,
nullptr,
ZE_GRAPH_FORMAT_NGRAPH_LITE,
Expand All @@ -352,8 +355,6 @@ GraphDescriptor ZeGraphExtWrappers::getGraphDescriptor(SerializedIR serializedIR
flags};

_logger.debug("getGraphDescriptor - perform pfnCreate2");
// Create querynetwork handle
ze_graph_handle_t graphHandle = nullptr;
auto result = _zeroInitStruct->getGraphDdiTable().pfnCreate2(_zeroInitStruct->getContext(),
_zeroInitStruct->getDevice(),
&desc,
Expand All @@ -365,14 +366,14 @@ GraphDescriptor ZeGraphExtWrappers::getGraphDescriptor(SerializedIR serializedIR

GraphDescriptor ZeGraphExtWrappers::getGraphDescriptor(void* blobData, size_t blobSize) const {
ze_graph_handle_t graphHandle = nullptr;
bool setPersistentFlag = false;

if (blobSize == 0) {
OPENVINO_THROW("Empty blob");
}

uint32_t flags = 0;
bool setPersistentFlag = canCpuVaBeImported(blobData, blobSize);

setPersistentFlag = canCpuVaBeImported(blobData, blobSize);
if (setPersistentFlag) {
_logger.debug("getGraphDescriptor - set ZE_GRAPH_FLAG_INPUT_GRAPH_PERSISTENT");
flags = ZE_GRAPH_FLAG_INPUT_GRAPH_PERSISTENT;
Expand All @@ -387,7 +388,6 @@ GraphDescriptor ZeGraphExtWrappers::getGraphDescriptor(void* blobData, size_t bl
flags};

_logger.debug("getGraphDescriptor - perform pfnCreate2");

auto result = _zeroInitStruct->getGraphDdiTable().pfnCreate2(_zeroInitStruct->getContext(),
_zeroInitStruct->getDevice(),
&desc,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,6 @@ class ZeroTensorTests : public ov::test::behavior::OVPluginTestBase,
ov::element::Type type;
std::tie(targetDevice, configuration, type) = obj.param;
std::replace(targetDevice.begin(), targetDevice.end(), ':', '_');
targetDevice = ov::test::utils::getTestsPlatformFromEnvironmentOr(ov::test::utils::DEVICE_NPU);

std::ostringstream result;
result << "targetDevice=" << targetDevice << "_";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,6 @@ class ZeroVariableStateTests : public ov::test::behavior::OVPluginTestBase,
ov::AnyMap configuration;
std::tie(targetDevice, configuration) = obj.param;
std::replace(targetDevice.begin(), targetDevice.end(), ':', '_');
targetDevice = ov::test::utils::getTestsPlatformFromEnvironmentOr(ov::test::utils::DEVICE_NPU);

std::ostringstream result;
result << "targetDevice=" << targetDevice << "_";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,23 +4,27 @@

#include "zero_graph.hpp"

#include <common_test_utils/test_assertions.hpp>

namespace {
std::vector<int> graphDescflags = {ZE_GRAPH_FLAG_NONE, ZE_GRAPH_FLAG_DISABLE_CACHING, ZE_GRAPH_FLAG_ENABLE_PROFILING};
const std::vector<ov::AnyMap> configsGraphCompilationTests = {{},
{ov::cache_dir("test")},
{ov::intel_npu::bypass_umd_caching(true)}};

// tested versions interval is [1.5, CURRENT + 1)
auto extVersions = ::testing::Range(ZE_MAKE_VERSION(1, 5), ZE_GRAPH_EXT_VERSION_CURRENT + 1);
auto graphExtVersions = ::testing::Range(ZE_MAKE_VERSION(1, 5), ZE_GRAPH_EXT_VERSION_CURRENT + 1);

INSTANTIATE_TEST_SUITE_P(smoke_BehaviorTest,
ZeroGraphCompilationTests,
::testing::Combine(::testing::ValuesIn(graphDescflags), extVersions),
::testing::Combine(::testing::Values(ov::test::utils::DEVICE_NPU),
::testing::ValuesIn(configsGraphCompilationTests),
graphExtVersions),
ZeroGraphTest::getTestCaseName);

std::vector<int> noneGraphDescflags = {ZE_GRAPH_FLAG_NONE};
const std::vector<ov::AnyMap> emptyConfigsTests = {{}};

INSTANTIATE_TEST_SUITE_P(smoke_BehaviorTest,
ZeroGraphTest,
::testing::Combine(::testing::ValuesIn(noneGraphDescflags), extVersions),
::testing::Combine(::testing::Values(ov::test::utils::DEVICE_NPU),
::testing::ValuesIn(emptyConfigsTests),
graphExtVersions),
ZeroGraphTest::getTestCaseName);
} // namespace
Loading
Loading