Skip to content

Commit 33963cc

Browse files
committed
Add support for graph ext version 1.4
Signed-off-by: Bogdan Pereanu <bogdan.pereanu@intel.com>
1 parent 7f07884 commit 33963cc

File tree

7 files changed

+309
-198
lines changed

7 files changed

+309
-198
lines changed

src/plugins/intel_npu/src/compiler_adapter/include/ze_graph_ext_wrappers.hpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ class ZeGraphExtWrappers {
3535

3636
GraphDescriptor getGraphDescriptor(SerializedIR serializedIR,
3737
const std::string& buildFlags,
38-
const uint32_t& flags) const;
38+
const Config& config) const;
3939

4040
GraphDescriptor getGraphDescriptor(void* data, size_t size) const;
4141

@@ -60,6 +60,9 @@ class ZeGraphExtWrappers {
6060
bool isBlobDataImported(const GraphDescriptor& graphDescriptor) const;
6161

6262
private:
63+
std::unordered_set<std::string> getQueryResultFromSupportedLayers(
64+
ze_graph_query_network_handle_t& hGraphQueryNetwork) const;
65+
6366
void getMetadata(ze_graph_handle_t graphHandle,
6467
uint32_t index,
6568
std::vector<IODescriptor>& inputs,

src/plugins/intel_npu/src/compiler_adapter/src/driver_compiler_adapter.cpp

Lines changed: 2 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -101,15 +101,8 @@ std::shared_ptr<IGraph> DriverCompilerAdapter::compile(const std::shared_ptr<con
101101

102102
_logger.debug("compileIR Build flags : %s", buildFlags.c_str());
103103

104-
// If UMD Caching is requested to be bypassed or if OV cache is enabled, disable driver caching
105-
uint32_t flags = ZE_GRAPH_FLAG_NONE;
106-
const auto set_cache_dir = config.get<CACHE_DIR>();
107-
if (!set_cache_dir.empty() || config.get<BYPASS_UMD_CACHING>()) {
108-
flags = flags | ZE_GRAPH_FLAG_DISABLE_CACHING;
109-
}
110-
111104
_logger.debug("compile start");
112-
auto graphDesc = _zeGraphExt->getGraphDescriptor(std::move(serializedIR), buildFlags, flags);
105+
auto graphDesc = _zeGraphExt->getGraphDescriptor(std::move(serializedIR), buildFlags, config);
113106
_logger.debug("compile end");
114107

115108
OV_ITT_TASK_NEXT(COMPILE_BLOB, "getNetworkMeta");
@@ -161,13 +154,6 @@ std::shared_ptr<IGraph> DriverCompilerAdapter::compileWS(const std::shared_ptr<o
161154
}
162155
FilteredConfig updatedConfig = *plgConfig;
163156

164-
// If UMD Caching is requested to be bypassed or if OV cache is enabled, disable driver caching
165-
uint32_t flags = ZE_GRAPH_FLAG_NONE;
166-
const auto set_cache_dir = config.get<CACHE_DIR>();
167-
if (!set_cache_dir.empty() || config.get<BYPASS_UMD_CACHING>()) {
168-
flags = flags | ZE_GRAPH_FLAG_DISABLE_CACHING;
169-
}
170-
171157
// WS v3 is based on a stateless compiler. We'll use a separate config entry for informing the compiler the index of
172158
// the current call iteration.
173159
std::vector<NetworkMetadata> initNetworkMetadata;
@@ -191,7 +177,7 @@ std::shared_ptr<IGraph> DriverCompilerAdapter::compileWS(const std::shared_ptr<o
191177
buildFlags += irSerializer.serializeConfig(updatedConfig, compilerVersion);
192178

193179
_logger.debug("compile start");
194-
auto graphDesc = _zeGraphExt->getGraphDescriptor(serializedIR, buildFlags, flags);
180+
auto graphDesc = _zeGraphExt->getGraphDescriptor(serializedIR, buildFlags, config);
195181
_logger.debug("compile end");
196182

197183
OV_ITT_TASK_NEXT(COMPILE_BLOB, "getNetworkMeta");

src/plugins/intel_npu/src/compiler_adapter/src/ze_graph_ext_wrappers.cpp

Lines changed: 140 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
#include <string_view>
1010

11+
#include "intel_npu/config/options.hpp"
1112
#include "intel_npu/prefix.hpp"
1213
#include "intel_npu/utils/utils.hpp"
1314
#include "intel_npu/utils/zero/zero_api.hpp"
@@ -17,6 +18,12 @@
1718
#include "openvino/core/dimension.hpp"
1819
#include "openvino/core/partial_shape.hpp"
1920

21+
// For ext version >= 1.5, the pfnQueryNetworkCreate2 and pfnQueryContextMemory APIs are supported
22+
#define NotSupportAPIGraphQueryNetworkV2(T) (T < ZE_GRAPH_EXT_VERSION_1_5)
23+
24+
// For ext version >= 1.5, pfnCreate2 api is available
25+
#define NotSupportGraph2(T) (T < ZE_GRAPH_EXT_VERSION_1_5)
26+
2027
// A bug inside the driver makes the "pfnGraphGetArgumentMetadata" call not safe for use prior to
2128
// "ze_graph_dditable_ext_1_6_t".
2229
// See: E#117498
@@ -133,8 +140,8 @@ ZeGraphExtWrappers::ZeGraphExtWrappers(const std::shared_ptr<ZeroInitStructsHold
133140
_logger.debug("capabilities:");
134141
_logger.debug("-SupportQuery: %d", true);
135142
_logger.debug("-SupportAPIGraphQueryNetworkV1: %d", true);
136-
_logger.debug("-SupportAPIGraphQueryNetworkV2 :%d", true);
137-
_logger.debug("-SupportpfnCreate2 :%d", true);
143+
_logger.debug("-SupportAPIGraphQueryNetworkV2 :%d", !NotSupportAPIGraphQueryNetworkV2(_graphExtVersion));
144+
_logger.debug("-SupportpfnCreate2 :%d", !NotSupportGraph2(_graphExtVersion));
138145
_logger.debug("-SupportArgumentMetadata :%d", !NotSupportArgumentMetadata(_graphExtVersion));
139146
_logger.debug("-UseCopyForNativeBinary :%d", UseCopyForNativeBinary(_graphExtVersion));
140147
}
@@ -272,37 +279,19 @@ static std::unordered_set<std::string> parseQueryResult(std::vector<char>& data)
272279
return result;
273280
}
274281

275-
std::unordered_set<std::string> ZeGraphExtWrappers::queryGraph(SerializedIR serializedIR,
276-
const std::string& buildFlags) const {
277-
// For ext version >= 1.5
278-
ze_graph_query_network_handle_t hGraphQueryNetwork = nullptr;
279-
280-
// For ext version >= 1.5
281-
ze_graph_desc_2_t desc = {ZE_STRUCTURE_TYPE_GRAPH_DESC_PROPERTIES,
282-
nullptr,
283-
ZE_GRAPH_FORMAT_NGRAPH_LITE,
284-
serializedIR.first,
285-
serializedIR.second.get(),
286-
buildFlags.c_str(),
287-
ZE_GRAPH_FLAG_NONE};
288-
289-
// Create querynetwork handle
290-
_logger.debug("For ext larger than 1.4 - perform pfnQueryNetworkCreate2");
291-
ze_result_t result = _zeroInitStruct->getGraphDdiTable().pfnQueryNetworkCreate2(_zeroInitStruct->getContext(),
292-
_zeroInitStruct->getDevice(),
293-
&desc,
294-
&hGraphQueryNetwork);
295-
THROW_ON_FAIL_FOR_LEVELZERO_EXT("pfnQueryNetworkCreate2", result, _zeroInitStruct->getGraphDdiTable());
296-
297-
_logger.debug("queryGraph - perform pfnQueryNetworkGetSupportedLayers to get size");
282+
std::unordered_set<std::string> ZeGraphExtWrappers::getQueryResultFromSupportedLayers(
283+
ze_graph_query_network_handle_t& hGraphQueryNetwork) const {
284+
// Get the size of query result
285+
_logger.debug("getQueryResultFromSupportedLayers - perform pfnQueryNetworkGetSupportedLayers to get size");
298286
size_t size = 0;
299-
result = _zeroInitStruct->getGraphDdiTable().pfnQueryNetworkGetSupportedLayers(hGraphQueryNetwork, &size, nullptr);
287+
auto result =
288+
_zeroInitStruct->getGraphDdiTable().pfnQueryNetworkGetSupportedLayers(hGraphQueryNetwork, &size, nullptr);
300289
THROW_ON_FAIL_FOR_LEVELZERO_EXT("pfnQueryNetworkGetSupportedLayers get size of query result",
301290
result,
302291
_zeroInitStruct->getGraphDdiTable());
303292

304293
// Get the result data of query
305-
_logger.debug("queryGraph - perform pfnQueryNetworkGetSupportedLayers to get data");
294+
_logger.debug("getQueryResultFromSupportedLayers - perform pfnQueryNetworkGetSupportedLayers to get data");
306295
std::vector<char> supportedLayers(size);
307296
result = _zeroInitStruct->getGraphDdiTable().pfnQueryNetworkGetSupportedLayers(hGraphQueryNetwork,
308297
&size,
@@ -311,13 +300,57 @@ std::unordered_set<std::string> ZeGraphExtWrappers::queryGraph(SerializedIR seri
311300
result,
312301
_zeroInitStruct->getGraphDdiTable());
313302

314-
_logger.debug("queryGraph - perform pfnQueryNetworkDestroy");
303+
_logger.debug("getQueryResultFromSupportedLayers - perform pfnQueryNetworkDestroy");
315304
result = _zeroInitStruct->getGraphDdiTable().pfnQueryNetworkDestroy(hGraphQueryNetwork);
316305
THROW_ON_FAIL_FOR_LEVELZERO_EXT("pfnQueryNetworkDestroy", result, _zeroInitStruct->getGraphDdiTable());
317306

318307
return parseQueryResult(supportedLayers);
319308
}
320309

310+
std::unordered_set<std::string> ZeGraphExtWrappers::queryGraph(SerializedIR serializedIR,
311+
const std::string& buildFlags) const {
312+
if (NotSupportAPIGraphQueryNetworkV2(_graphExtVersion)) {
313+
// For ext version < 1.5, only the original pfnQueryNetworkCreate API is available
314+
ze_graph_query_network_handle_t hGraphQueryNetwork = nullptr;
315+
316+
ze_graph_desc_t desc = {ZE_STRUCTURE_TYPE_GRAPH_DESC_PROPERTIES,
317+
nullptr,
318+
ZE_GRAPH_FORMAT_NGRAPH_LITE,
319+
serializedIR.first,
320+
serializedIR.second.get(),
321+
buildFlags.c_str()};
322+
323+
_logger.debug("queryGraph - perform pfnQueryNetworkCreate");
324+
auto result = _zeroInitStruct->getGraphDdiTable().pfnQueryNetworkCreate(_zeroInitStruct->getContext(),
325+
_zeroInitStruct->getDevice(),
326+
&desc,
327+
&hGraphQueryNetwork);
328+
THROW_ON_FAIL_FOR_LEVELZERO_EXT("pfnQueryNetworkCreate", result, _zeroInitStruct->getGraphDdiTable());
329+
330+
return getQueryResultFromSupportedLayers(hGraphQueryNetwork);
331+
} else {
332+
// For ext version >= 1.5
333+
ze_graph_query_network_handle_t hGraphQueryNetwork = nullptr;
334+
335+
ze_graph_desc_2_t desc = {ZE_STRUCTURE_TYPE_GRAPH_DESC_PROPERTIES,
336+
nullptr,
337+
ZE_GRAPH_FORMAT_NGRAPH_LITE,
338+
serializedIR.first,
339+
serializedIR.second.get(),
340+
buildFlags.c_str(),
341+
ZE_GRAPH_FLAG_NONE};
342+
343+
_logger.debug("queryGraph - perform pfnQueryNetworkCreate2");
344+
auto result = _zeroInitStruct->getGraphDdiTable().pfnQueryNetworkCreate2(_zeroInitStruct->getContext(),
345+
_zeroInitStruct->getDevice(),
346+
&desc,
347+
&hGraphQueryNetwork);
348+
THROW_ON_FAIL_FOR_LEVELZERO_EXT("pfnQueryNetworkCreate2", result, _zeroInitStruct->getGraphDdiTable());
349+
350+
return getQueryResultFromSupportedLayers(hGraphQueryNetwork);
351+
}
352+
}
353+
321354
bool ZeGraphExtWrappers::canCpuVaBeImported(void* data, size_t size) const {
322355
if (_graphExtVersion < ZE_MAKE_VERSION(1, 13) ||
323356
!utils::memory_and_size_aligned_to_standard_page_size(data, size)) {
@@ -341,59 +374,100 @@ bool ZeGraphExtWrappers::canCpuVaBeImported(void* data, size_t size) const {
341374

342375
GraphDescriptor ZeGraphExtWrappers::getGraphDescriptor(SerializedIR serializedIR,
343376
const std::string& buildFlags,
344-
const uint32_t& flags) const {
345-
// For ext version >= 1.5, calling pfnCreate2 api in _zeroInitStruct->getGraphDdiTable()
346-
ze_graph_desc_2_t desc = {ZE_STRUCTURE_TYPE_GRAPH_DESC_PROPERTIES,
347-
nullptr,
348-
ZE_GRAPH_FORMAT_NGRAPH_LITE,
349-
serializedIR.first,
350-
serializedIR.second.get(),
351-
buildFlags.c_str(),
352-
flags};
353-
354-
_logger.debug("getGraphDescriptor - perform pfnCreate2");
355-
// Create querynetwork handle
377+
const Config& config) const {
356378
ze_graph_handle_t graphHandle = nullptr;
357-
auto result = _zeroInitStruct->getGraphDdiTable().pfnCreate2(_zeroInitStruct->getContext(),
358-
_zeroInitStruct->getDevice(),
359-
&desc,
360-
&graphHandle);
361-
THROW_ON_FAIL_FOR_LEVELZERO_EXT("pfnCreate2", result, _zeroInitStruct->getGraphDdiTable());
379+
380+
if (NotSupportGraph2(_graphExtVersion)) {
381+
// For ext version < 1.5, calling pfnCreate api
382+
ze_graph_desc_t desc = {ZE_STRUCTURE_TYPE_GRAPH_DESC_PROPERTIES,
383+
nullptr,
384+
ZE_GRAPH_FORMAT_NGRAPH_LITE,
385+
serializedIR.first,
386+
serializedIR.second.get(),
387+
buildFlags.c_str()};
388+
389+
_logger.debug("getGraphDescriptor - perform pfnCreate");
390+
auto result = _zeroInitStruct->getGraphDdiTable().pfnCreate(_zeroInitStruct->getContext(),
391+
_zeroInitStruct->getDevice(),
392+
&desc,
393+
&graphHandle);
394+
THROW_ON_FAIL_FOR_LEVELZERO_EXT("pfnCreate", result, _zeroInitStruct->getGraphDdiTable());
395+
} else {
396+
// If UMD Caching is requested to be bypassed or if OV cache is enabled, disable driver caching
397+
uint32_t flags = ZE_GRAPH_FLAG_NONE;
398+
const auto set_cache_dir = config.get<CACHE_DIR>();
399+
if (!set_cache_dir.empty() || config.get<BYPASS_UMD_CACHING>()) {
400+
flags = flags | ZE_GRAPH_FLAG_DISABLE_CACHING;
401+
}
402+
403+
// For ext version >= 1.5, calling pfnCreate2
404+
ze_graph_desc_2_t desc = {ZE_STRUCTURE_TYPE_GRAPH_DESC_PROPERTIES,
405+
nullptr,
406+
ZE_GRAPH_FORMAT_NGRAPH_LITE,
407+
serializedIR.first,
408+
serializedIR.second.get(),
409+
buildFlags.c_str(),
410+
flags};
411+
412+
_logger.debug("getGraphDescriptor - perform pfnCreate2");
413+
auto result = _zeroInitStruct->getGraphDdiTable().pfnCreate2(_zeroInitStruct->getContext(),
414+
_zeroInitStruct->getDevice(),
415+
&desc,
416+
&graphHandle);
417+
THROW_ON_FAIL_FOR_LEVELZERO_EXT("pfnCreate2", result, _zeroInitStruct->getGraphDdiTable());
418+
}
362419

363420
return GraphDescriptor{graphHandle};
364421
}
365422

366423
GraphDescriptor ZeGraphExtWrappers::getGraphDescriptor(void* blobData, size_t blobSize) const {
367424
ze_graph_handle_t graphHandle = nullptr;
425+
bool setPersistentFlag = false;
368426

369427
if (blobSize == 0) {
370428
OPENVINO_THROW("Empty blob");
371429
}
372430

373-
uint32_t flags = 0;
374-
bool setPersistentFlag = canCpuVaBeImported(blobData, blobSize);
431+
if (NotSupportGraph2(_graphExtVersion)) {
432+
// For ext version < 1.5, calling pfnCreate api
433+
ze_graph_desc_t desc = {ZE_STRUCTURE_TYPE_GRAPH_DESC_PROPERTIES,
434+
nullptr,
435+
ZE_GRAPH_FORMAT_NATIVE,
436+
blobSize,
437+
reinterpret_cast<const uint8_t*>(blobData),
438+
nullptr};
439+
440+
_logger.debug("getGraphHandle - perform pfnCreate");
441+
auto result = _zeroInitStruct->getGraphDdiTable().pfnCreate(_zeroInitStruct->getContext(),
442+
_zeroInitStruct->getDevice(),
443+
&desc,
444+
&graphHandle);
445+
THROW_ON_FAIL_FOR_LEVELZERO_EXT("pfnCreate", result, _zeroInitStruct->getGraphDdiTable());
446+
} else {
447+
uint32_t flags = 0;
448+
setPersistentFlag = canCpuVaBeImported(blobData, blobSize);
449+
if (setPersistentFlag) {
450+
_logger.debug("getGraphDescriptor - set ZE_GRAPH_FLAG_INPUT_GRAPH_PERSISTENT");
451+
flags = ZE_GRAPH_FLAG_INPUT_GRAPH_PERSISTENT;
452+
}
375453

376-
if (setPersistentFlag) {
377-
_logger.debug("getGraphDescriptor - set ZE_GRAPH_FLAG_INPUT_GRAPH_PERSISTENT");
378-
flags = ZE_GRAPH_FLAG_INPUT_GRAPH_PERSISTENT;
454+
// For ext version >= 1.5, calling pfnCreate2
455+
ze_graph_desc_2_t desc = {ZE_STRUCTURE_TYPE_GRAPH_DESC_PROPERTIES,
456+
nullptr,
457+
ZE_GRAPH_FORMAT_NATIVE,
458+
blobSize,
459+
reinterpret_cast<const uint8_t*>(blobData),
460+
nullptr,
461+
flags};
462+
463+
_logger.debug("getGraphDescriptor - perform pfnCreate2");
464+
auto result = _zeroInitStruct->getGraphDdiTable().pfnCreate2(_zeroInitStruct->getContext(),
465+
_zeroInitStruct->getDevice(),
466+
&desc,
467+
&graphHandle);
468+
THROW_ON_FAIL_FOR_LEVELZERO_EXT("pfnCreate2", result, _zeroInitStruct->getGraphDdiTable());
379469
}
380470

381-
ze_graph_desc_2_t desc = {ZE_STRUCTURE_TYPE_GRAPH_DESC_PROPERTIES,
382-
nullptr,
383-
ZE_GRAPH_FORMAT_NATIVE,
384-
blobSize,
385-
reinterpret_cast<const uint8_t*>(blobData),
386-
nullptr,
387-
flags};
388-
389-
_logger.debug("getGraphDescriptor - perform pfnCreate2");
390-
391-
auto result = _zeroInitStruct->getGraphDdiTable().pfnCreate2(_zeroInitStruct->getContext(),
392-
_zeroInitStruct->getDevice(),
393-
&desc,
394-
&graphHandle);
395-
THROW_ON_FAIL_FOR_LEVELZERO_EXT("pfnCreate2", result, _zeroInitStruct->getGraphDdiTable());
396-
397471
return GraphDescriptor{graphHandle, setPersistentFlag};
398472
}
399473

src/plugins/intel_npu/tests/functional/internal/backend/zero_tensor_tests.hpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,6 @@ class ZeroTensorTests : public ov::test::behavior::OVPluginTestBase,
6161
ov::element::Type type;
6262
std::tie(targetDevice, configuration, type) = obj.param;
6363
std::replace(targetDevice.begin(), targetDevice.end(), ':', '_');
64-
targetDevice = ov::test::utils::getTestsPlatformFromEnvironmentOr(ov::test::utils::DEVICE_NPU);
6564

6665
std::ostringstream result;
6766
result << "targetDevice=" << targetDevice << "_";

src/plugins/intel_npu/tests/functional/internal/backend/zero_variable_state_tests.hpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,6 @@ class ZeroVariableStateTests : public ov::test::behavior::OVPluginTestBase,
5959
ov::AnyMap configuration;
6060
std::tie(targetDevice, configuration) = obj.param;
6161
std::replace(targetDevice.begin(), targetDevice.end(), ':', '_');
62-
targetDevice = ov::test::utils::getTestsPlatformFromEnvironmentOr(ov::test::utils::DEVICE_NPU);
6362

6463
std::ostringstream result;
6564
result << "targetDevice=" << targetDevice << "_";

0 commit comments

Comments
 (0)