diff --git a/ggml/src/ggml-rpc/ggml-rpc.cpp b/ggml/src/ggml-rpc/ggml-rpc.cpp
index 29bc421d58f5c..fdaf63c561ca5 100644
--- a/ggml/src/ggml-rpc/ggml-rpc.cpp
+++ b/ggml/src/ggml-rpc/ggml-rpc.cpp
@@ -8,6 +8,7 @@
 #include <vector>
 #include <memory>
 #include <mutex>
+#include <random>
 #include <unordered_map>
 #include <unordered_set>
 #ifdef _WIN32
@@ -876,6 +877,7 @@ class rpc_server {
     bool graph_compute(const std::vector<uint8_t> & input, rpc_msg_graph_compute_rsp & response);
     bool init_tensor(const rpc_msg_init_tensor_req & request);
     bool get_alloc_size(const rpc_msg_get_alloc_size_req & request, rpc_msg_get_alloc_size_rsp & response);
+    uint64_t random_id();
 
 private:
     bool get_cached_file(uint64_t hash, std::vector<uint8_t> & data);
@@ -885,12 +887,19 @@ class rpc_server {
                               const std::unordered_map<uint64_t, const rpc_tensor*> & tensor_ptrs,
                               std::unordered_map<uint64_t, struct ggml_tensor*> & tensor_map);
-
 
     ggml_backend_t backend;
     const char * cache_dir;
-    std::unordered_set<ggml_backend_buffer_t> buffers;
+    std::random_device rd;
+    // map from remote_ptr key to actual buffer pointer
+    std::unordered_map<uint64_t, ggml_backend_buffer_t> buffers;
 };
 
+uint64_t rpc_server::random_id() {
+    uint64_t high = static_cast<uint64_t>(rd()) << 32;
+    uint64_t low = static_cast<uint64_t>(rd());
+    return (high | low);
+}
+
 void rpc_server::hello(rpc_msg_hello_rsp & response) {
     response.major = RPC_PROTO_MAJOR_VERSION;
     response.minor = RPC_PROTO_MINOR_VERSION;
@@ -934,10 +943,12 @@ void rpc_server::alloc_buffer(const rpc_msg_alloc_buffer_req & request, rpc_msg_
     response.remote_ptr = 0;
     response.remote_size = 0;
     if (buffer != nullptr) {
-        response.remote_ptr = reinterpret_cast<uint64_t>(buffer);
+        uint64_t rpk = random_id();
+        response.remote_ptr = rpk;
         response.remote_size = buffer->size;
-        GGML_PRINT_DEBUG("[%s] size: %" PRIu64 " -> remote_ptr: %" PRIx64 ", remote_size: %" PRIu64 "\n", __func__, request.size, response.remote_ptr, response.remote_size);
-        buffers.insert(buffer);
+        GGML_PRINT_DEBUG("[%s] size: %" PRIu64 " -> handle: %" PRIu64 ", remote_size: %" PRIu64 "\n",
+                         __func__, request.size, rpk, response.remote_size);
+        buffers[rpk] = buffer;
     } else {
         GGML_LOG_ERROR("[%s] size: %" PRIu64 " -> failed\n", __func__, request.size);
     }
@@ -959,11 +970,12 @@ void rpc_server::get_max_size(rpc_msg_get_max_size_rsp & response) {
 
 bool rpc_server::buffer_get_base(const rpc_msg_buffer_get_base_req & request, rpc_msg_buffer_get_base_rsp & response) {
     GGML_PRINT_DEBUG("[%s] remote_ptr: %" PRIx64 "\n", __func__, request.remote_ptr);
-    ggml_backend_buffer_t buffer = reinterpret_cast<ggml_backend_buffer_t>(request.remote_ptr);
-    if (buffers.find(buffer) == buffers.end()) {
-        GGML_LOG_ERROR("[%s] buffer not found\n", __func__);
+    auto it = buffers.find(request.remote_ptr);
+    if (it == buffers.end()) {
+        GGML_LOG_ERROR("[%s] buffer handle not found: %" PRIu64 "\n", __func__, request.remote_ptr);
         return false;
     }
+    ggml_backend_buffer_t buffer = it->second;
     void * base = ggml_backend_buffer_get_base(buffer);
     response.base_ptr = reinterpret_cast<uint64_t>(base);
     return true;
@@ -971,23 +983,25 @@ bool rpc_server::buffer_get_base(const rpc_msg_buffer_get_base_req & request, rp
 
 bool rpc_server::free_buffer(const rpc_msg_free_buffer_req & request) {
     GGML_PRINT_DEBUG("[%s] remote_ptr: %" PRIx64 "\n", __func__, request.remote_ptr);
-    ggml_backend_buffer_t buffer = reinterpret_cast<ggml_backend_buffer_t>(request.remote_ptr);
-    if (buffers.find(buffer) == buffers.end()) {
-        GGML_LOG_ERROR("[%s] buffer not found\n", __func__);
+    auto it = buffers.find(request.remote_ptr);
+    if (it == buffers.end()) {
+        GGML_LOG_ERROR("[%s] buffer handle not found: %" PRIu64 "\n", __func__, request.remote_ptr);
         return false;
     }
+    ggml_backend_buffer_t buffer = it->second;
     ggml_backend_buffer_free(buffer);
-    buffers.erase(buffer);
+    buffers.erase(it);
     return true;
 }
 
 bool rpc_server::buffer_clear(const rpc_msg_buffer_clear_req & request) {
     GGML_PRINT_DEBUG("[%s] remote_ptr: %" PRIx64 ", value: %u\n", __func__, request.remote_ptr, request.value);
-    ggml_backend_buffer_t buffer = reinterpret_cast<ggml_backend_buffer_t>(request.remote_ptr);
-    if (buffers.find(buffer) == buffers.end()) {
-        GGML_LOG_ERROR("[%s] buffer not found\n", __func__);
+    auto it = buffers.find(request.remote_ptr);
+    if (it == buffers.end()) {
+        GGML_LOG_ERROR("[%s] buffer handle not found: %" PRIu64 "\n", __func__, request.remote_ptr);
         return false;
     }
+    ggml_backend_buffer_t buffer = it->second;
     ggml_backend_buffer_clear(buffer, request.value);
     return true;
 }
@@ -1011,8 +1025,11 @@ ggml_tensor * rpc_server::deserialize_tensor(struct ggml_context * ctx, const rp
     for (uint32_t i = 0; i < GGML_MAX_DIMS; i++) {
         result->nb[i] = tensor->nb[i];
     }
-    result->buffer = reinterpret_cast<ggml_backend_buffer_t>(tensor->buffer);
-    if (result->buffer && buffers.find(result->buffer) == buffers.end()) {
+    // convert the remote_ptr handle to an actual buffer pointer
+    auto it_buf = buffers.find(tensor->buffer);
+    if (it_buf != buffers.end()) {
+        result->buffer = it_buf->second;
+    } else {
         result->buffer = nullptr;
     }
 
@@ -1273,7 +1290,7 @@ ggml_tensor * rpc_server::create_node(uint64_t id,
     const rpc_tensor * tensor = it_ptr->second;
 
     struct ggml_tensor * result = deserialize_tensor(ctx, tensor);
-    if (result == nullptr) {
+    if (result == nullptr || result->buffer == nullptr) {
         return nullptr;
     }
     tensor_map[id] = result;
@@ -1366,8 +1383,8 @@ bool rpc_server::graph_compute(const std::vector<uint8_t> & input, rpc_msg_graph
 }
 
 rpc_server::~rpc_server() {
-    for (auto buffer : buffers) {
-        ggml_backend_buffer_free(buffer);
+    for (auto &kv : buffers) {
+        ggml_backend_buffer_free(kv.second);
     }
 }
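
The change above boils down to one pattern: alloc_buffer hands the client an opaque, randomly generated 64-bit handle instead of the raw ggml_backend_buffer_t pointer, and every later request translates the handle back through the buffers map, failing cleanly when the handle is unknown. The following is a minimal standalone sketch of that registry pattern, for illustration only; demo_buffer, handle_registry and the collision-retry loop are hypothetical names and additions, not code from ggml-rpc.cpp.

// Standalone sketch of an opaque-handle registry in the spirit of the new
// rpc_server::buffers map. All names below are illustrative.
#include <cstdint>
#include <cstdio>
#include <random>
#include <string>
#include <unordered_map>

struct demo_buffer {          // stand-in for ggml_backend_buffer_t
    std::string name;
    size_t      size;
};

class handle_registry {
public:
    // register a resource under a fresh random 64-bit handle and return it
    uint64_t insert(demo_buffer * buf) {
        uint64_t id = random_id();
        // the patch itself does not retry on collision; this loop is only an
        // extra safeguard for the sketch
        while (entries.count(id) != 0) {
            id = random_id();
        }
        entries[id] = buf;
        return id;
    }

    // resolve a handle back to the resource; nullptr if unknown or stale
    demo_buffer * find(uint64_t id) const {
        auto it = entries.find(id);
        return it == entries.end() ? nullptr : it->second;
    }

    bool erase(uint64_t id) { return entries.erase(id) != 0; }

private:
    // same construction as rpc_server::random_id(): two 32-bit draws from
    // std::random_device combined into one 64-bit value
    uint64_t random_id() {
        uint64_t high = static_cast<uint64_t>(rd()) << 32;
        uint64_t low  = static_cast<uint64_t>(rd());
        return high | low;
    }

    std::random_device rd;
    std::unordered_map<uint64_t, demo_buffer *> entries;
};

int main() {
    handle_registry reg;
    demo_buffer buf{"weights", 1024};

    uint64_t handle = reg.insert(&buf);   // analogous to what alloc_buffer returns as remote_ptr
    printf("handle: %llx -> size %zu\n", (unsigned long long) handle, reg.find(handle)->size);

    reg.erase(handle);                    // analogous to free_buffer
    printf("lookup after erase: %p\n", (void *) reg.find(handle));
    return 0;
}

Since the handle space is 64 bits and the IDs come from std::random_device, an accidental collision is vanishingly unlikely; the retry loop in the sketch is a belt-and-braces guard that the actual patch omits. The practical effect on the server is that a client can no longer hand back an arbitrary pointer value as remote_ptr and have it dereferenced: an unknown handle simply fails the buffers.find() lookup and the request is rejected.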