Skip to content

Commit d002d08

Browse files
committed
copy-back for non-USM memory
Signed-off-by: Mateusz P. Nowak <mateusz.p.nowak@intel.com>
1 parent f1b62a9 commit d002d08

File tree

2 files changed

+48
-62
lines changed

2 files changed

+48
-62
lines changed

unified-runtime/source/adapters/level_zero/v2/memory.cpp

Lines changed: 45 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,9 @@
88
//
99
//===----------------------------------------------------------------------===//
1010

11+
#include "memory.hpp"
1112
#include "../ur_interface_loader.hpp"
1213
#include "context.hpp"
13-
#include "memory.hpp"
1414

1515
#include "../helpers/memory_helpers.hpp"
1616
#include "../image_common.hpp"
@@ -66,33 +66,45 @@ ur_integrated_buffer_handle_t::ur_integrated_buffer_handle_t(
6666
if (ret == UR_RESULT_SUCCESS && memProps.type != ZE_MEMORY_TYPE_UNKNOWN) {
6767
// Already a USM allocation - just use it directly without import
6868
this->ptr = usm_unique_ptr_t(hostPtr, [](void *) {});
69-
} else {
70-
// Not USM - try to import it
71-
bool hostPtrImported =
72-
maybeImportUSM(hContext->getPlatform()->ZeDriverHandleExpTranslated,
73-
hContext->getZeHandle(), hostPtr, size);
69+
// No copy-back needed for USM pointers
70+
return;
71+
}
7472

75-
if (!hostPtrImported) {
76-
throw UR_RESULT_ERROR_INVALID_VALUE;
77-
}
73+
// Not USM - try to import it
74+
bool hostPtrImported =
75+
maybeImportUSM(hContext->getPlatform()->ZeDriverHandleExpTranslated,
76+
hContext->getZeHandle(), hostPtr, size);
7877

78+
if (hostPtrImported) {
79+
// Successfully imported - use it with release
7980
this->ptr = usm_unique_ptr_t(hostPtr, [hContext](void *ptr) {
8081
ZeUSMImport.doZeUSMRelease(
8182
hContext->getPlatform()->ZeDriverHandleExpTranslated, ptr);
8283
});
84+
// No copy-back needed for imported pointers
85+
return;
8386
}
84-
} else {
85-
// No host pointer - allocate new USM host memory
86-
void *rawPtr;
87-
UR_CALL_THROWS(hContext->getDefaultUSMPool()->allocate(
88-
hContext, nullptr, nullptr, UR_USM_TYPE_HOST, size, &rawPtr));
89-
90-
this->ptr = usm_unique_ptr_t(rawPtr, [hContext](void *ptr) {
91-
auto ret = hContext->getDefaultUSMPool()->free(ptr);
92-
if (ret != UR_RESULT_SUCCESS) {
93-
UR_LOG(ERR, "Failed to free host memory: {}", ret);
94-
}
95-
});
87+
88+
// Import failed - allocate backing buffer and set up copy-back
89+
}
90+
91+
// No host pointer, or import failed - allocate new USM host memory
92+
void *rawPtr;
93+
UR_CALL_THROWS(hContext->getDefaultUSMPool()->allocate(
94+
hContext, nullptr, nullptr, UR_USM_TYPE_HOST, size, &rawPtr));
95+
96+
this->ptr = usm_unique_ptr_t(rawPtr, [hContext](void *ptr) {
97+
auto ret = hContext->getDefaultUSMPool()->free(ptr);
98+
if (ret != UR_RESULT_SUCCESS) {
99+
UR_LOG(ERR, "Failed to free host memory: {}", ret);
100+
}
101+
});
102+
103+
if (hostPtr) {
104+
// Copy data from user pointer to our backing buffer
105+
std::memcpy(this->ptr.get(), hostPtr, size);
106+
// Remember to copy back on destruction
107+
writeBackPtr = hostPtr;
96108
}
97109
}
98110

@@ -108,6 +120,12 @@ ur_integrated_buffer_handle_t::ur_integrated_buffer_handle_t(
108120
});
109121
}
110122

123+
ur_integrated_buffer_handle_t::~ur_integrated_buffer_handle_t() {
124+
if (writeBackPtr) {
125+
std::memcpy(writeBackPtr, ptr.get(), size);
126+
}
127+
}
128+
111129
void *ur_integrated_buffer_handle_t::getDevicePtr(
112130
ur_device_handle_t /*hDevice*/, device_access_mode_t /*access*/,
113131
size_t offset, size_t /*size*/, ze_command_list_handle_t /*cmdList*/,
@@ -564,47 +582,12 @@ ur_result_t urMemBufferCreate(ur_context_handle_t hContext,
564582
void *hostPtr = pProperties ? pProperties->pHost : nullptr;
565583
auto accessMode = ur_mem_buffer_t::getDeviceAccessMode(flags);
566584

567-
// For integrated devices, we can use zero-copy host buffers when:
568-
// 1. No host pointer is provided (we'll allocate USM host memory)
569-
// 2. Host pointer is already USM memory
570-
// 3. Host pointer can be imported as USM
571-
// Otherwise, fall back to discrete buffer (explicit copies).
572-
if (useHostBuffer(hContext) && hostPtr) {
573-
// Check what type of memory this pointer is
574-
ZeStruct<ze_memory_allocation_properties_t> memProps;
575-
auto ret =
576-
getMemoryAttrs(hContext->getZeHandle(), hostPtr, nullptr, &memProps);
577-
578-
if (ret == UR_RESULT_SUCCESS) {
579-
if (memProps.type != ZE_MEMORY_TYPE_UNKNOWN) {
580-
// Already USM memory (host, device, or shared) - use integrated path
581-
*phBuffer = ur_mem_handle_t_::create<ur_integrated_buffer_handle_t>(
582-
hContext, hostPtr, size, accessMode);
583-
return UR_RESULT_SUCCESS;
584-
}
585-
586-
// Memory type is UNKNOWN - try to import it
587-
bool canImport =
588-
maybeImportUSM(hContext->getPlatform()->ZeDriverHandleExpTranslated,
589-
hContext->getZeHandle(), hostPtr, size);
590-
if (!canImport) {
591-
// Cannot import: fall back to discrete buffer path
592-
*phBuffer = ur_mem_handle_t_::create<ur_discrete_buffer_handle_t>(
593-
hContext, hostPtr, size, accessMode);
594-
return UR_RESULT_SUCCESS;
595-
}
596-
// Successfully imported: release it now, constructor will import again
597-
ZeUSMImport.doZeUSMRelease(
598-
hContext->getPlatform()->ZeDriverHandleExpTranslated, hostPtr);
599-
} else {
600-
// Cannot get memory attributes: fall back to discrete buffer
601-
*phBuffer = ur_mem_handle_t_::create<ur_discrete_buffer_handle_t>(
602-
hContext, hostPtr, size, accessMode);
603-
return UR_RESULT_SUCCESS;
604-
}
605-
}
606-
607-
// Use integrated buffer path (no hostPtr, or hostPtr is USM/importable)
585+
// For integrated devices, use zero-copy host buffers. The integrated buffer
586+
// constructor will handle all cases:
587+
// 1. No host pointer - allocate USM host memory
588+
// 2. Host pointer is already USM - use directly
589+
// 3. Host pointer can be imported - import it
590+
// 4. Otherwise - allocate USM and copy-back on destruction
608591
if (useHostBuffer(hContext)) {
609592
*phBuffer = ur_mem_handle_t_::create<ur_integrated_buffer_handle_t>(
610593
hContext, hostPtr, size, accessMode);

unified-runtime/source/adapters/level_zero/v2/memory.hpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,8 @@ struct ur_integrated_buffer_handle_t : ur_mem_buffer_t {
9898
size_t size, device_access_mode_t accessMode,
9999
bool ownHostPtr);
100100

101+
~ur_integrated_buffer_handle_t();
102+
101103
void *getDevicePtr(ur_device_handle_t, device_access_mode_t, size_t offset,
102104
size_t size, ze_command_list_handle_t cmdList,
103105
wait_list_view &waitListView) override;
@@ -109,6 +111,7 @@ struct ur_integrated_buffer_handle_t : ur_mem_buffer_t {
109111

110112
private:
111113
usm_unique_ptr_t ptr;
114+
void *writeBackPtr = nullptr;
112115
};
113116

114117
struct host_allocation_desc_t {

0 commit comments

Comments
 (0)