GGML: Fix leak of backend buffer memory address in RPC #14882

Open · wants to merge 1 commit into master
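In short: the RPC server used to hand the raw `ggml_backend_buffer_t` pointer value back to clients as `remote_ptr`, leaking real server memory addresses over the wire and letting a client feed an arbitrary forged pointer back in. This patch makes `remote_ptr` an opaque, randomly generated 64-bit handle that the server resolves through a private map, so unknown handles are rejected instead of dereferenced. A minimal standalone sketch of that pattern (the names `handle_registry`, `make_handle`, `insert`, and `lookup` are illustrative, not from the patch):

```cpp
#include <cstdint>
#include <random>
#include <unordered_map>

// Sketch of the handle-indirection pattern this patch adopts: clients only
// ever see random 64-bit keys; the real pointers stay inside the server.
struct handle_registry {
    std::random_device rd;
    std::unordered_map<uint64_t, void *> entries; // handle -> private pointer

    // std::random_device::result_type is unsigned int (commonly 32 bits),
    // so two draws are combined into one 64-bit handle, as random_id() below does.
    uint64_t make_handle() {
        uint64_t high = static_cast<uint64_t>(rd()) << 32;
        uint64_t low  = static_cast<uint64_t>(rd());
        return high | low;
    }

    uint64_t insert(void * ptr) {
        uint64_t h = make_handle();
        entries[h] = ptr; // safe to send over the wire; reveals nothing about ptr
        return h;
    }

    void * lookup(uint64_t h) const {
        auto it = entries.find(h);
        return it == entries.end() ? nullptr : it->second; // unknown handle -> reject
    }
};
```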
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 37 additions & 20 deletions ggml/src/ggml-rpc/ggml-rpc.cpp
@@ -8,6 +8,7 @@
 #include <vector>
 #include <memory>
 #include <mutex>
+#include <random>
 #include <unordered_map>
 #include <unordered_set>
 #ifdef _WIN32
@@ -876,6 +877,7 @@ class rpc_server {
     bool graph_compute(const std::vector<uint8_t> & input, rpc_msg_graph_compute_rsp & response);
     bool init_tensor(const rpc_msg_init_tensor_req & request);
     bool get_alloc_size(const rpc_msg_get_alloc_size_req & request, rpc_msg_get_alloc_size_rsp & response);
+    uint64_t random_id();

private:
    bool get_cached_file(uint64_t hash, std::vector<uint8_t> & data);
@@ -885,12 +887,19 @@ class rpc_server {
                           const std::unordered_map<uint64_t, const rpc_tensor*> & tensor_ptrs,
                           std::unordered_map<uint64_t, struct ggml_tensor*> & tensor_map);

+
     ggml_backend_t backend;
     const char * cache_dir;
-    std::unordered_set<ggml_backend_buffer_t> buffers;
+    std::random_device rd;
+    // map from remote_ptr key to actual buffer pointer
+    std::unordered_map<uint64_t, ggml_backend_buffer_t> buffers;
 };
-
+uint64_t rpc_server::random_id() {
+    uint64_t high = static_cast<uint64_t>(rd()) << 32;
+    uint64_t low = static_cast<uint64_t>(rd());
+    return (high | low);
+}

 void rpc_server::hello(rpc_msg_hello_rsp & response) {
     response.major = RPC_PROTO_MAJOR_VERSION;
     response.minor = RPC_PROTO_MINOR_VERSION;
@@ -934,10 +943,12 @@ void rpc_server::alloc_buffer(const rpc_msg_alloc_buffer_req & request, rpc_msg_
     response.remote_ptr = 0;
     response.remote_size = 0;
     if (buffer != nullptr) {
-        response.remote_ptr = reinterpret_cast<uint64_t>(buffer);
+        uint64_t rpk = random_id();
+        response.remote_ptr = rpk;
         response.remote_size = buffer->size;
-        GGML_PRINT_DEBUG("[%s] size: %" PRIu64 " -> remote_ptr: %" PRIx64 ", remote_size: %" PRIu64 "\n", __func__, request.size, response.remote_ptr, response.remote_size);
-        buffers.insert(buffer);
+        GGML_PRINT_DEBUG("[%s] size: %" PRIu64 " -> handle: %" PRIu64 ", remote_size: %" PRIu64 "\n",
+                         __func__, request.size, rpk, response.remote_size);
+        buffers[rpk] = buffer;
     } else {
         GGML_LOG_ERROR("[%s] size: %" PRIu64 " -> failed\n", __func__, request.size);
     }
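On the wire nothing changes shape: `remote_ptr` is still a `uint64_t`, it just no longer decodes to anything. The diff touches only `rpc_server`, which suggests clients already echo `remote_ptr` back verbatim as an opaque token. A small hypothetical check (not part of the patch) of the lifecycle that `alloc_buffer` and `free_buffer` now implement:

```cpp
#include <cassert>
#include <cstdint>
#include <random>
#include <unordered_map>

// Hypothetical unit check: a handle resolves while live, a guessed handle
// never resolves, and a freed handle stops resolving.
int main() {
    std::random_device rd;
    std::unordered_map<uint64_t, int> buffers; // stand-in for handle -> buffer

    uint64_t h = (static_cast<uint64_t>(rd()) << 32) | static_cast<uint64_t>(rd());
    buffers[h] = 42;                               // alloc_buffer: register handle

    assert(buffers.find(h)     != buffers.end());  // live handle resolves
    assert(buffers.find(h ^ 1) == buffers.end());  // forged handle is rejected
    buffers.erase(h);                              // free_buffer: unregister
    assert(buffers.find(h)     == buffers.end());  // stale handle is rejected
    return 0;
}
```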
@@ -959,35 +970,38 @@ void rpc_server::get_max_size(rpc_msg_get_max_size_rsp & response) {

 bool rpc_server::buffer_get_base(const rpc_msg_buffer_get_base_req & request, rpc_msg_buffer_get_base_rsp & response) {
     GGML_PRINT_DEBUG("[%s] remote_ptr: %" PRIx64 "\n", __func__, request.remote_ptr);
-    ggml_backend_buffer_t buffer = reinterpret_cast<ggml_backend_buffer_t>(request.remote_ptr);
-    if (buffers.find(buffer) == buffers.end()) {
-        GGML_LOG_ERROR("[%s] buffer not found\n", __func__);
+    auto it = buffers.find(request.remote_ptr);
+    if (it == buffers.end()) {
+        GGML_LOG_ERROR("[%s] buffer handle not found: %" PRIu64 "\n", __func__, request.remote_ptr);
         return false;
     }
+    ggml_backend_buffer_t buffer = it->second;
     void * base = ggml_backend_buffer_get_base(buffer);
     response.base_ptr = reinterpret_cast<uint64_t>(base);
     return true;
 }

 bool rpc_server::free_buffer(const rpc_msg_free_buffer_req & request) {
     GGML_PRINT_DEBUG("[%s] remote_ptr: %" PRIx64 "\n", __func__, request.remote_ptr);
-    ggml_backend_buffer_t buffer = reinterpret_cast<ggml_backend_buffer_t>(request.remote_ptr);
-    if (buffers.find(buffer) == buffers.end()) {
-        GGML_LOG_ERROR("[%s] buffer not found\n", __func__);
+    auto it = buffers.find(request.remote_ptr);
+    if (it == buffers.end()) {
+        GGML_LOG_ERROR("[%s] buffer handle not found: %" PRIu64 "\n", __func__, request.remote_ptr);
         return false;
     }
+    ggml_backend_buffer_t buffer = it->second;
     ggml_backend_buffer_free(buffer);
-    buffers.erase(buffer);
+    buffers.erase(it);
     return true;
 }

 bool rpc_server::buffer_clear(const rpc_msg_buffer_clear_req & request) {
     GGML_PRINT_DEBUG("[%s] remote_ptr: %" PRIx64 ", value: %u\n", __func__, request.remote_ptr, request.value);
-    ggml_backend_buffer_t buffer = reinterpret_cast<ggml_backend_buffer_t>(request.remote_ptr);
-    if (buffers.find(buffer) == buffers.end()) {
-        GGML_LOG_ERROR("[%s] buffer not found\n", __func__);
+    auto it = buffers.find(request.remote_ptr);
+    if (it == buffers.end()) {
+        GGML_LOG_ERROR("[%s] buffer handle not found: %" PRIu64 "\n", __func__, request.remote_ptr);
         return false;
     }
+    ggml_backend_buffer_t buffer = it->second;
     ggml_backend_buffer_clear(buffer, request.value);
     return true;
 }
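The three functions above now share the same prologue: look the handle up, log and fail on a miss, then use `it->second`. A possible private helper that factors this out (purely a sketch, not part of this PR; it reuses the file's existing `GGML_LOG_ERROR` macro and assumes `<cinttypes>` is available for `PRIu64`):

```cpp
// Hypothetical helper, not in this patch: resolve a client-supplied handle
// to the underlying buffer, or log and return nullptr on a miss.
ggml_backend_buffer_t rpc_server::find_buffer(uint64_t handle) {
    auto it = buffers.find(handle);
    if (it == buffers.end()) {
        GGML_LOG_ERROR("[%s] buffer handle not found: %" PRIu64 "\n", __func__, handle);
        return nullptr;
    }
    return it->second;
}
```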
@@ -1011,8 +1025,11 @@ ggml_tensor * rpc_server::deserialize_tensor(struct ggml_context * ctx, const rp
     for (uint32_t i = 0; i < GGML_MAX_DIMS; i++) {
         result->nb[i] = tensor->nb[i];
     }
-    result->buffer = reinterpret_cast<ggml_backend_buffer_t>(tensor->buffer);
-    if (result->buffer && buffers.find(result->buffer) == buffers.end()) {
+    // convert the remote_ptr handle to an actual buffer pointer
+    auto it_buf = buffers.find(tensor->buffer);
+    if (it_buf != buffers.end()) {
+        result->buffer = it_buf->second;
+    } else {
         result->buffer = nullptr;
     }

@@ -1273,7 +1290,7 @@ ggml_tensor * rpc_server::create_node(uint64_t id,
     const rpc_tensor * tensor = it_ptr->second;

     struct ggml_tensor * result = deserialize_tensor(ctx, tensor);
-    if (result == nullptr) {
+    if (result == nullptr || result->buffer == nullptr) {
         return nullptr;
     }
     tensor_map[id] = result;
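Net effect of the two hunks above: deserialize_tensor only attaches buffers that resolve through the server's private map (anything else becomes a null buffer), and create_node now treats an unresolved buffer as a hard failure, so a forged or stale remote_ptr can no longer be turned into a server-side pointer dereference.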
@@ -1366,8 +1383,8 @@ bool rpc_server::graph_compute(const std::vector<uint8_t> & input, rpc_msg_graph
 }

 rpc_server::~rpc_server() {
-    for (auto buffer : buffers) {
-        ggml_backend_buffer_free(buffer);
+    for (auto & kv : buffers) {
+        ggml_backend_buffer_free(kv.second);
     }
 }
