88
99#include < string_view>
1010
11+ #include " intel_npu/config/options.hpp"
1112#include " intel_npu/prefix.hpp"
1213#include " intel_npu/utils/utils.hpp"
1314#include " intel_npu/utils/zero/zero_api.hpp"
1718#include " openvino/core/dimension.hpp"
1819#include " openvino/core/partial_shape.hpp"
1920
21+ // ext version >= 1.5, support API (pfnQueryNetworkCreate2, pfnQueryContextMemory)
22+ #define NotSupportAPIGraphQueryNetworkV2 (T ) (T < ZE_GRAPH_EXT_VERSION_1_5)
23+
24+ // For ext version >= 1.5, pfnCreate2 api is available
25+ #define NotSupportGraph2 (T ) (T < ZE_GRAPH_EXT_VERSION_1_5)
26+
2027// A bug inside the driver makes the "pfnGraphGetArgumentMetadata" call not safe for use prior to
2128// "ze_graph_dditable_ext_1_6_t".
2229// See: E#117498
@@ -133,8 +140,8 @@ ZeGraphExtWrappers::ZeGraphExtWrappers(const std::shared_ptr<ZeroInitStructsHold
133140 _logger.debug (" capabilities:" );
134141 _logger.debug (" -SupportQuery: %d" , true );
135142 _logger.debug (" -SupportAPIGraphQueryNetworkV1: %d" , true );
136- _logger.debug (" -SupportAPIGraphQueryNetworkV2 :%d" , true );
137- _logger.debug (" -SupportpfnCreate2 :%d" , true );
143+ _logger.debug (" -SupportAPIGraphQueryNetworkV2 :%d" , ! NotSupportAPIGraphQueryNetworkV2 (_graphExtVersion) );
144+ _logger.debug (" -SupportpfnCreate2 :%d" , ! NotSupportGraph2 (_graphExtVersion) );
138145 _logger.debug (" -SupportArgumentMetadata :%d" , !NotSupportArgumentMetadata (_graphExtVersion));
139146 _logger.debug (" -UseCopyForNativeBinary :%d" , UseCopyForNativeBinary (_graphExtVersion));
140147}
@@ -272,37 +279,19 @@ static std::unordered_set<std::string> parseQueryResult(std::vector<char>& data)
272279 return result;
273280}
274281
275- std::unordered_set<std::string> ZeGraphExtWrappers::queryGraph (SerializedIR serializedIR,
276- const std::string& buildFlags) const {
277- // For ext version >= 1.5
278- ze_graph_query_network_handle_t hGraphQueryNetwork = nullptr ;
279-
280- // For ext version >= 1.5
281- ze_graph_desc_2_t desc = {ZE_STRUCTURE_TYPE_GRAPH_DESC_PROPERTIES,
282- nullptr ,
283- ZE_GRAPH_FORMAT_NGRAPH_LITE,
284- serializedIR.first ,
285- serializedIR.second .get (),
286- buildFlags.c_str (),
287- ZE_GRAPH_FLAG_NONE};
288-
289- // Create querynetwork handle
290- _logger.debug (" For ext larger than 1.4 - perform pfnQueryNetworkCreate2" );
291- ze_result_t result = _zeroInitStruct->getGraphDdiTable ().pfnQueryNetworkCreate2 (_zeroInitStruct->getContext (),
292- _zeroInitStruct->getDevice (),
293- &desc,
294- &hGraphQueryNetwork);
295- THROW_ON_FAIL_FOR_LEVELZERO_EXT (" pfnQueryNetworkCreate2" , result, _zeroInitStruct->getGraphDdiTable ());
296-
297- _logger.debug (" queryGraph - perform pfnQueryNetworkGetSupportedLayers to get size" );
282+ std::unordered_set<std::string> ZeGraphExtWrappers::getQueryResultFromSupportedLayers (
283+ ze_graph_query_network_handle_t & hGraphQueryNetwork) const {
284+ // Get the size of query result
285+ _logger.debug (" getQueryResultFromSupportedLayers - perform pfnQueryNetworkGetSupportedLayers to get size" );
298286 size_t size = 0 ;
299- result = _zeroInitStruct->getGraphDdiTable ().pfnQueryNetworkGetSupportedLayers (hGraphQueryNetwork, &size, nullptr );
287+ auto result =
288+ _zeroInitStruct->getGraphDdiTable ().pfnQueryNetworkGetSupportedLayers (hGraphQueryNetwork, &size, nullptr );
300289 THROW_ON_FAIL_FOR_LEVELZERO_EXT (" pfnQueryNetworkGetSupportedLayers get size of query result" ,
301290 result,
302291 _zeroInitStruct->getGraphDdiTable ());
303292
304293 // Get the result data of query
305- _logger.debug (" queryGraph - perform pfnQueryNetworkGetSupportedLayers to get data" );
294+ _logger.debug (" getQueryResultFromSupportedLayers - perform pfnQueryNetworkGetSupportedLayers to get data" );
306295 std::vector<char > supportedLayers (size);
307296 result = _zeroInitStruct->getGraphDdiTable ().pfnQueryNetworkGetSupportedLayers (hGraphQueryNetwork,
308297 &size,
@@ -311,13 +300,57 @@ std::unordered_set<std::string> ZeGraphExtWrappers::queryGraph(SerializedIR seri
311300 result,
312301 _zeroInitStruct->getGraphDdiTable ());
313302
314- _logger.debug (" queryGraph - perform pfnQueryNetworkDestroy" );
303+ _logger.debug (" getQueryResultFromSupportedLayers - perform pfnQueryNetworkDestroy" );
315304 result = _zeroInitStruct->getGraphDdiTable ().pfnQueryNetworkDestroy (hGraphQueryNetwork);
316305 THROW_ON_FAIL_FOR_LEVELZERO_EXT (" pfnQueryNetworkDestroy" , result, _zeroInitStruct->getGraphDdiTable ());
317306
318307 return parseQueryResult (supportedLayers);
319308}
320309
310+ std::unordered_set<std::string> ZeGraphExtWrappers::queryGraph (SerializedIR serializedIR,
311+ const std::string& buildFlags) const {
312+ if (NotSupportAPIGraphQueryNetworkV2 (_graphExtVersion)) {
313+ // For ext version == 1.4, query network is supported
314+ ze_graph_query_network_handle_t hGraphQueryNetwork = nullptr ;
315+
316+ ze_graph_desc_t desc = {ZE_STRUCTURE_TYPE_GRAPH_DESC_PROPERTIES,
317+ nullptr ,
318+ ZE_GRAPH_FORMAT_NGRAPH_LITE,
319+ serializedIR.first ,
320+ serializedIR.second .get (),
321+ buildFlags.c_str ()};
322+
323+ _logger.debug (" queryGraph - perform pfnQueryNetworkCreate" );
324+ auto result = _zeroInitStruct->getGraphDdiTable ().pfnQueryNetworkCreate (_zeroInitStruct->getContext (),
325+ _zeroInitStruct->getDevice (),
326+ &desc,
327+ &hGraphQueryNetwork);
328+ THROW_ON_FAIL_FOR_LEVELZERO_EXT (" pfnQueryNetworkCreate" , result, _zeroInitStruct->getGraphDdiTable ());
329+
330+ return getQueryResultFromSupportedLayers (hGraphQueryNetwork);
331+ } else {
332+ // For ext version >= 1.5
333+ ze_graph_query_network_handle_t hGraphQueryNetwork = nullptr ;
334+
335+ ze_graph_desc_2_t desc = {ZE_STRUCTURE_TYPE_GRAPH_DESC_PROPERTIES,
336+ nullptr ,
337+ ZE_GRAPH_FORMAT_NGRAPH_LITE,
338+ serializedIR.first ,
339+ serializedIR.second .get (),
340+ buildFlags.c_str (),
341+ ZE_GRAPH_FLAG_NONE};
342+
343+ _logger.debug (" queryGraph - perform pfnQueryNetworkCreate2" );
344+ auto result = _zeroInitStruct->getGraphDdiTable ().pfnQueryNetworkCreate2 (_zeroInitStruct->getContext (),
345+ _zeroInitStruct->getDevice (),
346+ &desc,
347+ &hGraphQueryNetwork);
348+ THROW_ON_FAIL_FOR_LEVELZERO_EXT (" pfnQueryNetworkCreate2" , result, _zeroInitStruct->getGraphDdiTable ());
349+
350+ return getQueryResultFromSupportedLayers (hGraphQueryNetwork);
351+ }
352+ }
353+
321354bool ZeGraphExtWrappers::canCpuVaBeImported (void * data, size_t size) const {
322355 if (_graphExtVersion < ZE_MAKE_VERSION (1 , 13 ) ||
323356 !utils::memory_and_size_aligned_to_standard_page_size (data, size)) {
@@ -341,59 +374,100 @@ bool ZeGraphExtWrappers::canCpuVaBeImported(void* data, size_t size) const {
341374
342375GraphDescriptor ZeGraphExtWrappers::getGraphDescriptor (SerializedIR serializedIR,
343376 const std::string& buildFlags,
344- const uint32_t & flags) const {
345- // For ext version >= 1.5, calling pfnCreate2 api in _zeroInitStruct->getGraphDdiTable()
346- ze_graph_desc_2_t desc = {ZE_STRUCTURE_TYPE_GRAPH_DESC_PROPERTIES,
347- nullptr ,
348- ZE_GRAPH_FORMAT_NGRAPH_LITE,
349- serializedIR.first ,
350- serializedIR.second .get (),
351- buildFlags.c_str (),
352- flags};
353-
354- _logger.debug (" getGraphDescriptor - perform pfnCreate2" );
355- // Create querynetwork handle
377+ const Config& config) const {
356378 ze_graph_handle_t graphHandle = nullptr ;
357- auto result = _zeroInitStruct->getGraphDdiTable ().pfnCreate2 (_zeroInitStruct->getContext (),
358- _zeroInitStruct->getDevice (),
359- &desc,
360- &graphHandle);
361- THROW_ON_FAIL_FOR_LEVELZERO_EXT (" pfnCreate2" , result, _zeroInitStruct->getGraphDdiTable ());
379+
380+ if (NotSupportGraph2 (_graphExtVersion)) {
381+ // For ext version <1.5, calling pfnCreate api
382+ ze_graph_desc_t desc = {ZE_STRUCTURE_TYPE_GRAPH_DESC_PROPERTIES,
383+ nullptr ,
384+ ZE_GRAPH_FORMAT_NGRAPH_LITE,
385+ serializedIR.first ,
386+ serializedIR.second .get (),
387+ buildFlags.c_str ()};
388+
389+ _logger.debug (" getGraphDescriptor - perform pfnCreate" );
390+ auto result = _zeroInitStruct->getGraphDdiTable ().pfnCreate (_zeroInitStruct->getContext (),
391+ _zeroInitStruct->getDevice (),
392+ &desc,
393+ &graphHandle);
394+ THROW_ON_FAIL_FOR_LEVELZERO_EXT (" pfnCreate" , result, _zeroInitStruct->getGraphDdiTable ());
395+ } else {
396+ // If UMD Caching is requested to be bypassed or if OV cache is enabled, disable driver caching
397+ uint32_t flags = ZE_GRAPH_FLAG_NONE;
398+ const auto set_cache_dir = config.get <CACHE_DIR>();
399+ if (!set_cache_dir.empty () || config.get <BYPASS_UMD_CACHING>()) {
400+ flags = flags | ZE_GRAPH_FLAG_DISABLE_CACHING;
401+ }
402+
403+ // For ext version >= 1.5, calling pfnCreate2
404+ ze_graph_desc_2_t desc = {ZE_STRUCTURE_TYPE_GRAPH_DESC_PROPERTIES,
405+ nullptr ,
406+ ZE_GRAPH_FORMAT_NGRAPH_LITE,
407+ serializedIR.first ,
408+ serializedIR.second .get (),
409+ buildFlags.c_str (),
410+ flags};
411+
412+ _logger.debug (" getGraphDescriptor - perform pfnCreate2" );
413+ auto result = _zeroInitStruct->getGraphDdiTable ().pfnCreate2 (_zeroInitStruct->getContext (),
414+ _zeroInitStruct->getDevice (),
415+ &desc,
416+ &graphHandle);
417+ THROW_ON_FAIL_FOR_LEVELZERO_EXT (" pfnCreate2" , result, _zeroInitStruct->getGraphDdiTable ());
418+ }
362419
363420 return GraphDescriptor{graphHandle};
364421}
365422
366423GraphDescriptor ZeGraphExtWrappers::getGraphDescriptor (void * blobData, size_t blobSize) const {
367424 ze_graph_handle_t graphHandle = nullptr ;
425+ bool setPersistentFlag = false ;
368426
369427 if (blobSize == 0 ) {
370428 OPENVINO_THROW (" Empty blob" );
371429 }
372430
373- uint32_t flags = 0 ;
374- bool setPersistentFlag = canCpuVaBeImported (blobData, blobSize);
431+ if (NotSupportGraph2 (_graphExtVersion)) {
432+ // For ext version < 1.5, calling pfnCreate api
433+ ze_graph_desc_t desc = {ZE_STRUCTURE_TYPE_GRAPH_DESC_PROPERTIES,
434+ nullptr ,
435+ ZE_GRAPH_FORMAT_NATIVE,
436+ blobSize,
437+ reinterpret_cast <const uint8_t *>(blobData),
438+ nullptr };
439+
440+ _logger.debug (" getGraphHandle - perform pfnCreate" );
441+ auto result = _zeroInitStruct->getGraphDdiTable ().pfnCreate (_zeroInitStruct->getContext (),
442+ _zeroInitStruct->getDevice (),
443+ &desc,
444+ &graphHandle);
445+ THROW_ON_FAIL_FOR_LEVELZERO_EXT (" pfnCreate" , result, _zeroInitStruct->getGraphDdiTable ());
446+ } else {
447+ uint32_t flags = 0 ;
448+ setPersistentFlag = canCpuVaBeImported (blobData, blobSize);
449+ if (setPersistentFlag) {
450+ _logger.debug (" getGraphDescriptor - set ZE_GRAPH_FLAG_INPUT_GRAPH_PERSISTENT" );
451+ flags = ZE_GRAPH_FLAG_INPUT_GRAPH_PERSISTENT;
452+ }
375453
376- if (setPersistentFlag) {
377- _logger.debug (" getGraphDescriptor - set ZE_GRAPH_FLAG_INPUT_GRAPH_PERSISTENT" );
378- flags = ZE_GRAPH_FLAG_INPUT_GRAPH_PERSISTENT;
454+ // For ext version >= 1.5, calling pfnCreate2
455+ ze_graph_desc_2_t desc = {ZE_STRUCTURE_TYPE_GRAPH_DESC_PROPERTIES,
456+ nullptr ,
457+ ZE_GRAPH_FORMAT_NATIVE,
458+ blobSize,
459+ reinterpret_cast <const uint8_t *>(blobData),
460+ nullptr ,
461+ flags};
462+
463+ _logger.debug (" getGraphDescriptor - perform pfnCreate2" );
464+ auto result = _zeroInitStruct->getGraphDdiTable ().pfnCreate2 (_zeroInitStruct->getContext (),
465+ _zeroInitStruct->getDevice (),
466+ &desc,
467+ &graphHandle);
468+ THROW_ON_FAIL_FOR_LEVELZERO_EXT (" pfnCreate2" , result, _zeroInitStruct->getGraphDdiTable ());
379469 }
380470
381- ze_graph_desc_2_t desc = {ZE_STRUCTURE_TYPE_GRAPH_DESC_PROPERTIES,
382- nullptr ,
383- ZE_GRAPH_FORMAT_NATIVE,
384- blobSize,
385- reinterpret_cast <const uint8_t *>(blobData),
386- nullptr ,
387- flags};
388-
389- _logger.debug (" getGraphDescriptor - perform pfnCreate2" );
390-
391- auto result = _zeroInitStruct->getGraphDdiTable ().pfnCreate2 (_zeroInitStruct->getContext (),
392- _zeroInitStruct->getDevice (),
393- &desc,
394- &graphHandle);
395- THROW_ON_FAIL_FOR_LEVELZERO_EXT (" pfnCreate2" , result, _zeroInitStruct->getGraphDdiTable ());
396-
397471 return GraphDescriptor{graphHandle, setPersistentFlag};
398472}
399473
0 commit comments