From 7f78aac817c4faf37570adfc958702f99dbb92db Mon Sep 17 00:00:00 2001 From: Clement-Wang26 Date: Thu, 16 Oct 2025 21:38:47 +0800 Subject: [PATCH] feat:use shared memory for intra-process communication instead of BRPC --- xllm/core/common/global_flags.cpp | 3 + xllm/core/common/global_flags.h | 4 +- xllm/core/common/options.h | 3 + xllm/core/distributed_runtime/CMakeLists.txt | 4 + .../core/distributed_runtime/comm_channel.cpp | 440 ++++++++ xllm/core/distributed_runtime/comm_channel.h | 139 +++ .../core/distributed_runtime/dist_manager.cpp | 107 +- xllm/core/distributed_runtime/dist_manager.h | 2 +- .../distributed_runtime/remote_worker.cpp | 398 +------ xllm/core/distributed_runtime/remote_worker.h | 41 +- xllm/core/distributed_runtime/shm_channel.cpp | 72 ++ xllm/core/distributed_runtime/shm_channel.h | 38 + .../distributed_runtime/worker_server.cpp | 70 +- xllm/core/distributed_runtime/worker_server.h | 45 +- .../distributed_runtime/worker_service.cpp | 267 +++-- .../core/distributed_runtime/worker_service.h | 19 + xllm/core/framework/eplb/CMakeLists.txt | 2 - .../framework/eplb/expert_weight_buffer_shm.h | 2 +- xllm/core/runtime/CMakeLists.txt | 2 + .../runtime/forward_shared_memory_manager.cpp | 1000 +++++++++++++++++ .../runtime/forward_shared_memory_manager.h | 124 ++ xllm/core/runtime/master.cpp | 9 +- xllm/core/runtime/master.h | 2 - xllm/core/runtime/options.h | 3 + xllm/core/runtime/params_utils.cpp | 2 + xllm/core/util/CMakeLists.txt | 2 + xllm/core/util/net.cpp | 22 + xllm/core/util/net.h | 2 + .../eplb => util}/shared_memory_manager.cpp | 25 +- .../eplb => util}/shared_memory_manager.h | 6 +- xllm/core/util/timer.cpp | 9 + xllm/core/util/timer.h | 4 +- xllm/xllm.cpp | 14 +- 33 files changed, 2328 insertions(+), 554 deletions(-) create mode 100644 xllm/core/distributed_runtime/comm_channel.cpp create mode 100644 xllm/core/distributed_runtime/comm_channel.h create mode 100644 xllm/core/distributed_runtime/shm_channel.cpp create mode 100644 xllm/core/distributed_runtime/shm_channel.h create mode 100644 xllm/core/runtime/forward_shared_memory_manager.cpp create mode 100644 xllm/core/runtime/forward_shared_memory_manager.h rename xllm/core/{framework/eplb => util}/shared_memory_manager.cpp (83%) rename xllm/core/{framework/eplb => util}/shared_memory_manager.h (92%) diff --git a/xllm/core/common/global_flags.cpp b/xllm/core/common/global_flags.cpp index 572f2b6e..dd95e2ea 100644 --- a/xllm/core/common/global_flags.cpp +++ b/xllm/core/common/global_flags.cpp @@ -275,6 +275,9 @@ DEFINE_int32(sleep_time_second, 3, "The sleep time for worker try to connect to server next time."); +DEFINE_bool(enable_shm, + true, + "Whether to enable shared memory for executing model."); // --- function call config --- DEFINE_string(tool_call_parser, diff --git a/xllm/core/common/global_flags.h b/xllm/core/common/global_flags.h index cf8c7186..10e1d794 100644 --- a/xllm/core/common/global_flags.h +++ b/xllm/core/common/global_flags.h @@ -199,4 +199,6 @@ DECLARE_int64(cache_size_per_token); DECLARE_int64(buffer_size_per_seq); -DECLARE_bool(enable_beam_search_kernel); \ No newline at end of file +DECLARE_bool(enable_beam_search_kernel); + +DECLARE_bool(enable_shm); diff --git a/xllm/core/common/options.h b/xllm/core/common/options.h index b0533104..836e6396 100644 --- a/xllm/core/common/options.h +++ b/xllm/core/common/options.h @@ -175,6 +175,9 @@ class Options { PROPERTY(bool, enable_offline_inference) = false; // for offline inference: the path to spawn worker binary PROPERTY(std::string, 
spawn_worker_path) = ""; + + // whether the worker and master are on the same machine. + PROPERTY(bool, is_local) = false; }; } // namespace xllm diff --git a/xllm/core/distributed_runtime/CMakeLists.txt b/xllm/core/distributed_runtime/CMakeLists.txt index 4a68308d..ac283262 100644 --- a/xllm/core/distributed_runtime/CMakeLists.txt +++ b/xllm/core/distributed_runtime/CMakeLists.txt @@ -20,6 +20,8 @@ cc_library( remote_worker.h worker_server.h worker_service.h + comm_channel.h + shm_channel.h SRCS disagg_pd_service.cpp disagg_pd_service_impl.cpp @@ -29,6 +31,8 @@ cc_library( remote_worker.cpp worker_server.cpp worker_service.cpp + comm_channel.cpp + shm_channel.cpp DEPS :api_service :runtime diff --git a/xllm/core/distributed_runtime/comm_channel.cpp b/xllm/core/distributed_runtime/comm_channel.cpp new file mode 100644 index 00000000..720a6b91 --- /dev/null +++ b/xllm/core/distributed_runtime/comm_channel.cpp @@ -0,0 +1,440 @@ +/* Copyright 2025 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "comm_channel.h" + +#include +#include + +namespace xllm { + +bool CommChannel::init_brpc(const std::string& server_address) { + options_.connection_type = "pooled"; + options_.timeout_ms = -1; + options_.connect_timeout_ms = -1; + options_.max_retry = 3; + + if (channel_.Init(server_address.c_str(), "", &options_) != 0) { + LOG(ERROR) << "Failed to initialize brpc Channel"; + return false; + } + + stub_.reset(new proto::DistributeWorker_Stub(&channel_)); + return true; +} + +bool CommChannel::hello() { + proto::Status req; + proto::Status resp; + brpc::Controller cntl; + + cntl.Reset(); + stub_->Hello(&cntl, &req, &resp, nullptr); + if (cntl.Failed() || !resp.ok()) { + LOG(ERROR) << "Hello request failed: " << cntl.ErrorText(); + return false; + } + return true; +} + +bool CommChannel::allocate_kv_cache( + const std::vector>& kv_cache_shape) { + proto::KVCacheShape shape; + shape.mutable_key_shape()->Reserve(kv_cache_shape[0].size()); + shape.mutable_value_shape()->Reserve(kv_cache_shape[1].size()); + + for (size_t i = 0; i < kv_cache_shape[0].size(); ++i) { + shape.add_key_shape(kv_cache_shape[0][i]); + shape.add_value_shape(kv_cache_shape[1][i]); + } + + proto::Status s; + brpc::Controller cntl; + stub_->AllocateKVCache(&cntl, &shape, &s, nullptr); + + if (cntl.Failed() || !s.ok()) { + LOG(ERROR) << "allocate_kv_cache failed: " << cntl.ErrorText(); + return false; + } + + return true; +} + +bool CommChannel::allocate_continuous_kv_cache( + const std::vector& options) { + proto::XTensorOptionsVec xtensor_options_vec; + xtensor_options_vec.mutable_key_options()->set_num_kv_heads( + options[0].num_kv_heads()); + xtensor_options_vec.mutable_key_options()->set_head_size( + options[0].head_size()); + xtensor_options_vec.mutable_key_options()->set_max_context_len( + options[0].max_context_len()); + xtensor_options_vec.mutable_key_options()->set_max_seqs_per_batch( + 
options[0].max_seqs_per_batch()); + xtensor_options_vec.mutable_value_options()->set_num_kv_heads( + options[1].num_kv_heads()); + xtensor_options_vec.mutable_value_options()->set_head_size( + options[1].head_size()); + xtensor_options_vec.mutable_value_options()->set_max_context_len( + options[1].max_context_len()); + xtensor_options_vec.mutable_value_options()->set_max_seqs_per_batch( + options[1].max_seqs_per_batch()); + + proto::Status s; + brpc::Controller cntl; + stub_->AllocateContinuousKVCache(&cntl, &xtensor_options_vec, &s, nullptr); + + if (cntl.Failed() || !s.ok()) { + LOG(ERROR) << "allocate_continuous_kv_cache failed: " << cntl.ErrorText(); + return false; + } + return true; +} + +bool CommChannel::get_device_info(std::string& device_ip, uint16_t& port) { + proto::Empty req; + proto::DeviceInfo resp; + brpc::Controller cntl; + + stub_->GetDeviceInfo(&cntl, &req, &resp, nullptr); + if (cntl.Failed()) { + LOG(ERROR) << "GetDeviceInfo failed: " << cntl.ErrorText(); + return false; + } + + device_ip = resp.device_ip(); + port = resp.listen_port(); + return true; +} + +bool CommChannel::get_cache_info(uint64_t& cluster_id, + std::string& addr, + int64_t& k_cache_id, + int64_t& v_cache_id) { + proto::Empty req; + proto::CacheInfo resp; + brpc::Controller cntl; + + stub_->GetCacheInfo(&cntl, &req, &resp, nullptr); + if (cntl.Failed()) { + LOG(ERROR) << "GetCacheInfo failed: " << cntl.ErrorText(); + return false; + } + + cluster_id = resp.cluster_id(); + addr = resp.addr(); + k_cache_id = resp.k_cache_id(); + v_cache_id = resp.v_cache_id(); + return true; +} + +bool CommChannel::link_cluster(const std::vector& cluster_ids, + const std::vector& addrs, + const std::vector& device_ips, + const std::vector& ports) { + proto::ClusterInfo cluster_info; + cluster_info.mutable_cluster_ids()->Reserve(cluster_ids.size()); + cluster_info.mutable_addrs()->Reserve(addrs.size()); + cluster_info.mutable_device_ips()->Reserve(device_ips.size()); + cluster_info.mutable_ports()->Reserve(ports.size()); + + for (size_t i = 0; i < cluster_ids.size(); ++i) { + cluster_info.add_cluster_ids(cluster_ids[i]); + cluster_info.add_addrs(addrs[i]); + cluster_info.add_device_ips(device_ips[i]); + cluster_info.add_ports(ports[i]); + } + + proto::Status s; + brpc::Controller cntl; + stub_->LinkCluster(&cntl, &cluster_info, &s, nullptr); + + if (cntl.Failed() || !s.ok()) { + LOG(ERROR) << "LinkCluster failed: " << cntl.ErrorText(); + return false; + } + return true; +} + +bool CommChannel::unlink_cluster(const std::vector& cluster_ids, + const std::vector& addrs, + const std::vector& device_ips, + const std::vector& ports) { + proto::ClusterInfo cluster_info; + cluster_info.mutable_cluster_ids()->Reserve(cluster_ids.size()); + cluster_info.mutable_addrs()->Reserve(addrs.size()); + cluster_info.mutable_device_ips()->Reserve(device_ips.size()); + cluster_info.mutable_ports()->Reserve(ports.size()); + + for (size_t i = 0; i < cluster_ids.size(); ++i) { + cluster_info.add_cluster_ids(cluster_ids[i]); + cluster_info.add_addrs(addrs[i]); + cluster_info.add_device_ips(device_ips[i]); + cluster_info.add_ports(ports[i]); + } + + proto::Status s; + brpc::Controller cntl; + stub_->UnlinkCluster(&cntl, &cluster_info, &s, nullptr); + + if (cntl.Failed() || !s.ok()) { + LOG(ERROR) << "UnlinkCluster failed: " << cntl.ErrorText(); + return false; + } + return true; +} + +bool CommChannel::init_model(const std::string& model_weights_path) { + proto::ModelPath request; + + request.set_model_weights_path(model_weights_path); + 
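+  // Passing a nullptr done-closure makes the brpc call below synchronous;
+  // it returns only once the worker has replied or the controller failed.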
proto::Status response; + brpc::Controller cntl; + stub_->InitModel(&cntl, &request, &response, nullptr); + if (cntl.Failed() || !response.ok()) { + LOG(ERROR) << "init_model failed: " << cntl.ErrorText(); + return false; + } + return true; +} + +bool CommChannel::init_model_async(const std::string& model_weights_path, + folly::Promise& promise) { + proto::ModelPath request; + + request.set_model_weights_path(model_weights_path); + auto done = new InitModelClosure(); + done->promise = std::move(promise); + stub_->InitModel(&done->cntl, &request, &done->response, done); + + return true; +} + +bool CommChannel::estimate_kv_cache_capacity(int64_t& available_memory, + int64_t& total_memory) { + proto::Empty req; + proto::DeviceMemory mem; + brpc::Controller cntl; + + stub_->ProfileDeviceMemory(&cntl, &req, &mem, nullptr); + if (cntl.Failed()) { + LOG(ERROR) << "estimate_kv_cache_capacity failed: " << cntl.ErrorText(); + return false; + } + + available_memory = mem.available_memory(); + total_memory = mem.total_memory(); + return true; +} + +bool CommChannel::pull_kv_blocks(const uint64_t src_cluster_id, + const std::string& src_addr, + const int64_t src_k_cache_id, + const int64_t src_v_cache_id, + const std::vector& src_blocks, + const std::vector& dst_blocks) { + proto::PullKVCacheRequest request; + request.set_cluster_id(src_cluster_id); + request.set_addr(src_addr); + request.set_k_cache_id(src_k_cache_id); + request.set_v_cache_id(src_v_cache_id); + + ADD_VECTOR_TO_PROTO(request.mutable_src_blocks(), src_blocks); + ADD_VECTOR_TO_PROTO(request.mutable_dst_blocks(), dst_blocks); + + proto::Status s; + brpc::Controller cntl; + stub_->PullKVCache(&cntl, &request, &s, nullptr); + + return !cntl.Failed() && s.ok(); +} + +void CommChannel::execute_model_async( + const std::vector& inputs, + folly::Promise>& promise) { + execute_model_with_brpc(inputs, promise); +} + +bool CommChannel::process_group_test() { + proto::Empty req; + proto::Status s; + brpc::Controller cntl; + + stub_->ProcessGroupTest(&cntl, &req, &s, nullptr); + if (cntl.Failed() || !s.ok()) { + LOG(ERROR) << "process_group_test failed: " << cntl.ErrorText(); + return false; + } + return true; +} + +bool CommChannel::allocate_kv_cache_with_transfer( + const uint64_t kv_cache_size, + const std::vector>& kv_cache_shape) { + proto::AllocateKVCacheWithTransferRequest request; + request.set_kv_cache_size(kv_cache_size); + + auto* shape = request.mutable_kv_cache_shape(); + shape->mutable_key_shape()->Reserve(kv_cache_shape[0].size()); + shape->mutable_value_shape()->Reserve(kv_cache_shape[1].size()); + + for (size_t i = 0; i < kv_cache_shape[0].size(); ++i) { + shape->add_key_shape(kv_cache_shape[0][i]); + shape->add_value_shape(kv_cache_shape[1][i]); + } + + proto::Status s; + brpc::Controller cntl; + stub_->AllocateKVCacheWithTransfer(&cntl, &request, &s, nullptr); + + if (cntl.Failed() || !s.ok()) { + LOG(ERROR) << "AllocateKVCacheWithTransfer failed: " << cntl.ErrorText(); + return false; + } + return true; +} + +bool CommChannel::load_kv_blocks_from_store_async( + const std::vector& cache_block_info, + folly::Promise& promise) { + proto::CacheBlockInfos pb_cache_block_info; + if (!cache_block_info_to_proto(cache_block_info, &pb_cache_block_info)) { + promise.setValue(0); + return false; + } + + auto done = new LoadKVCacheFromStoreClosure(); + done->promise = std::move(promise); + stub_->LoadKVCacheFromStore( + &done->cntl, &pb_cache_block_info, &done->response, done); + + return true; +} + +bool 
CommChannel::get_last_step_result_async( + folly::Promise>& promise) { + proto::Empty req; + proto::ForwardOutput pb_output; + brpc::Controller cntl; + stub_->GetLastStepResult(&cntl, &req, &pb_output, nullptr); + if (cntl.Failed()) { + LOG(ERROR) << "Get last step model output result failed, " + << cntl.ErrorText(); + return false; + } + + // parse tokens + RawForwardOutput raw_forward_output; + proto_to_forward_output(pb_output, raw_forward_output); + promise.setValue(std::move(raw_forward_output)); + + return true; +} + +bool CommChannel::get_active_activation_memory(int64_t& memory) { + proto::Empty req; + proto::ActivationMemory mem; + brpc::Controller cntl; + + stub_->GetActiveActivationMemory(&cntl, &req, &mem, nullptr); + if (cntl.Failed()) { + LOG(ERROR) << "GetActiveActivationMemory failed: " << cntl.ErrorText(); + return false; + } + + memory = mem.active_activation_memory(); + return true; +} + +bool CommChannel::get_active_activation_memory_async( + folly::Promise& promise) { + proto::Empty req; + proto::ActivationMemory mem; + brpc::Controller cntl; + + stub_->GetActiveActivationMemory(&cntl, &req, &mem, nullptr); + if (cntl.Failed()) { + LOG(ERROR) << "get_active_activation_memory_async failed: " + << cntl.ErrorText(); + promise.setValue(0); + return false; + } + promise.setValue(mem.active_activation_memory()); + return true; +} + +bool CommChannel::execute_model_with_brpc( + const std::vector& inputs, + folly::Promise>& promise) { + // convert to proto::BatchedForwardInputs + proto::BatchedForwardInputs pb_batched_fwd_inputs; + std::vector batched_fwd_inputs_vec; + batched_fwd_inputs_vec.reserve(inputs.size()); + for (auto i = 0; i < inputs.size(); ++i) { + proto::ForwardInput pb_fwd_input; + forward_input_to_proto(inputs[i], &pb_fwd_input); + batched_fwd_inputs_vec.push_back(std::move(pb_fwd_input)); + } + ADD_VECTOR_TO_PROTO(pb_batched_fwd_inputs.mutable_micro_inputs(), + batched_fwd_inputs_vec); + // call ExecuteModel with callback + auto done = new ExecuteModelClosure(); + done->promise = std::move(promise); + stub_->ExecuteModel( + &done->cntl, &pb_batched_fwd_inputs, &done->pb_output, done); + return true; +} + +void LoadKVCacheFromStoreClosure::Run() { + std::unique_ptr self_guard(this); + + bool success = !cntl.Failed(); + if (!success) { + promise.setValue(0); + } else { + promise.setValue(response.success_cnt()); + } + return; +} + +void ExecuteModelClosure::Run() { + std::unique_ptr self_guard(this); + + if (cntl.Failed()) { + LOG(ERROR) << "Execute_model_async failed. Error code : " + << cntl.ErrorCode() << ", error message : " << cntl.ErrorText(); + } + + RawForwardOutput raw_forward_output; + proto_to_forward_output(pb_output, raw_forward_output); + promise.setValue(raw_forward_output); + + return; +} + +void InitModelClosure::Run() { + std::unique_ptr self_guard(this); + + bool success = !cntl.Failed() && response.ok(); + if (!success) { + LOG(ERROR) << "Init_model_async failed, " << cntl.ErrorText(); + } else { + LOG(INFO) << "Init_model_async succeed."; + } + promise.setValue(success); + + return; +} +} // namespace xllm \ No newline at end of file diff --git a/xllm/core/distributed_runtime/comm_channel.h b/xllm/core/distributed_runtime/comm_channel.h new file mode 100644 index 00000000..a97850c0 --- /dev/null +++ b/xllm/core/distributed_runtime/comm_channel.h @@ -0,0 +1,139 @@ +/* Copyright 2025 The xLLM Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#pragma once + +#include +#include +#include + +#include +#include +#include + +#include "framework/xtensor/xtensor.h" +#include "runtime/forward_params.h" +#include "runtime/params_utils.h" +#include "worker.pb.h" + +namespace xllm { + +class CommChannel { + public: + CommChannel() = default; + virtual ~CommChannel() = default; + + bool init_brpc(const std::string& server_address); + + virtual bool hello(); + + virtual bool allocate_kv_cache( + const std::vector>& kv_cache_shape); + + virtual bool allocate_continuous_kv_cache( + const std::vector& options); + + virtual bool get_device_info(std::string& device_ip, uint16_t& port); + + virtual bool get_cache_info(uint64_t& cluster_id, + std::string& addr, + int64_t& k_cache_id, + int64_t& v_cache_id); + + virtual bool link_cluster(const std::vector& cluster_ids, + const std::vector& addrs, + const std::vector& device_ips, + const std::vector& ports); + + virtual bool unlink_cluster(const std::vector& cluster_ids, + const std::vector& addrs, + const std::vector& device_ips, + const std::vector& ports); + + virtual bool init_model(const std::string& model_weights_path); + + virtual bool init_model_async(const std::string& model_weights_path, + folly::Promise& promise); + + virtual bool estimate_kv_cache_capacity(int64_t& available_memory, + int64_t& total_memory); + + virtual bool pull_kv_blocks(const uint64_t src_cluster_id, + const std::string& src_addr, + const int64_t src_k_cache_id, + const int64_t src_v_cache_id, + const std::vector& src_blocks, + const std::vector& dst_blocks); + + virtual void execute_model_async( + const std::vector& inputs, + folly::Promise>& promise); + + virtual bool process_group_test(); + + virtual bool allocate_kv_cache_with_transfer( + const uint64_t kv_cache_size, + const std::vector>& kv_cache_shape); + + virtual bool load_kv_blocks_from_store_async( + const std::vector& cache_block_info, + folly::Promise& promise); + + virtual bool get_last_step_result_async( + folly::Promise>& promise); + + virtual bool get_active_activation_memory(int64_t& memory); + + virtual bool get_active_activation_memory_async( + folly::Promise& promise); + + protected: + bool execute_model_with_brpc( + const std::vector& inputs, + folly::Promise>& promise); + + private: + brpc::Channel channel_; + brpc::ChannelOptions options_; + std::unique_ptr stub_; +}; + +class InitModelClosure : public google::protobuf::Closure { + public: + void Run(); + + proto::Status response; + brpc::Controller cntl; + folly::Promise promise; +}; + +class ExecuteModelClosure : public google::protobuf::Closure { + public: + void Run(); + + proto::ForwardOutput pb_output; + brpc::Controller cntl; + folly::Promise> promise; +}; + +class LoadKVCacheFromStoreClosure : public google::protobuf::Closure { + public: + void Run(); + + proto::StoreResponse response; + brpc::Controller cntl; + folly::Promise promise; +}; +} // namespace xllm \ No 
newline at end of file
diff --git a/xllm/core/distributed_runtime/dist_manager.cpp b/xllm/core/distributed_runtime/dist_manager.cpp
index 046048bd..e7abfc7e 100644
--- a/xllm/core/distributed_runtime/dist_manager.cpp
+++ b/xllm/core/distributed_runtime/dist_manager.cpp
@@ -17,13 +17,16 @@ limitations under the License.
 
 #include
 
+#include "comm_channel.h"
 #include "distributed_runtime/collective_service.h"
 #include "framework/parallel_state/parallel_args.h"
 #include "framework/parallel_state/parallel_state.h"
 #include "framework/parallel_state/process_group.h"
+#include "runtime/forward_shared_memory_manager.h"
 #include "runtime/llm_worker_impl.h"
 #include "server/xllm_server_registry.h"
-
+#include "shm_channel.h"
+#include "util/net.h"
 namespace xllm {
 
 DistManager::DistManager(const runtime::Options& options) {
@@ -90,6 +93,53 @@ void DistManager::setup_single_node_workers(const runtime::Options& options) {
   }
 }
 
+namespace {
+std::unique_ptr<CommChannel> create_channel(const std::string& worker_addrs,
+                                            int r,
+                                            int dp_local_tp_size) {
+  std::unique_ptr<CommChannel> channel;
+
+  if (net::extract_ip(FLAGS_master_node_addr) ==
+          net::extract_ip(worker_addrs) &&
+      FLAGS_enable_shm) {
+    // create shared memory manager for local rank
+    bool is_driver = false;
+    int dp_group = r / dp_local_tp_size;
+    if (r % dp_local_tp_size == 0) {
+      is_driver = true;
+    }
+    channel = std::make_unique<ShmChannel>(dp_group, r, is_driver);
+  } else {
+    channel = std::make_unique<CommChannel>();
+  }
+
+  channel->init_brpc(worker_addrs);
+
+  return channel;
+}
+
+void prepare_shm(
+    int dp_local_tp_size,
+    int rank,
+    std::unique_ptr<ForwardSharedMemoryManager>& input_shm_manager,
+    std::unique_ptr<ForwardSharedMemoryManager>& output_shm_manager) {
+  bool is_creator;
+  int32_t dp_group = rank / dp_local_tp_size;
+
+  std::string name = ForwardSharedMemoryManager::create_unique_name(
+      dp_group, FORWARD_RAW_INPUT_TYPE, rank);
+  input_shm_manager = std::make_unique<ForwardSharedMemoryManager>(
+      name, PB_INPUT_SHM_SIZE, is_creator, FORWARD_RAW_INPUT_TYPE);
+  LOG(INFO) << "Create input shared memory manager with name: " << name;
+
+  name = ForwardSharedMemoryManager::create_unique_name(
+      dp_group, FORWARD_RAW_OUTPUT_TYPE, rank);
+  output_shm_manager = std::make_unique<ForwardSharedMemoryManager>(
+      name, PB_OUTPUT_SHM_SIZE, is_creator, FORWARD_RAW_OUTPUT_TYPE);
+  LOG(INFO) << "Create output shared memory manager with name: " << name;
+}
+}  // namespace
+
 void DistManager::setup_multi_node_workers(
     const runtime::Options& options,
     const std::string& master_node_addr) {
@@ -98,11 +148,11 @@ void DistManager::setup_multi_node_workers(
 
   // Process/Thread Worker Mode, we use it in multi-nodes serving.
   // Here, we assume that all node use same index devices. That is, if we set
-  // device='1,2,3,4' and nnodes=2, then both machine nodes will use the devices
-  // '1,2,3,4'. Therefore, the total world size is 2 * 4 = 8. This means that
-  // each of the two nodes will utilize four devices (specifically devices 1, 2,
-  // 3, and 4), resulting in a total of 8 devices being used across the entire
-  // distributed setup.
+  // device='1,2,3,4' and nnodes=2, then both machine nodes will use the
+  // devices '1,2,3,4'. Therefore, the total world size is 2 * 4 = 8. This
+  // means that each of the two nodes will utilize four devices (specifically
+  // devices 1, 2, 3, and 4), resulting in a total of 8 devices being used
+  // across the entire distributed setup.
 
   // To maintain interface consistency, we have implemented a new WorkerImpl
   // class. 
In this class, we create processes, initialize NCCL ProcessGroup, @@ -120,13 +170,17 @@ void DistManager::setup_multi_node_workers( const int32_t base_rank = options.node_rank() * each_node_ranks; const int32_t dp_size = options.dp_size(); const int32_t ep_size = options.ep_size(); + const int32_t dp_local_tp_size = world_size / dp_size; + LOG(INFO) << "Multi-node serving world_size = " << world_size << ", each_node_ranks = " << each_node_ranks << ", current node rank = " << options.node_rank() << ", nnodes = " << options.nnodes() << ", dp_size = " << dp_size - << ", ep_size = " << ep_size; + << ", ep_size = " << ep_size << ", tp_size = " << dp_local_tp_size; + CHECK_EQ((world_size % dp_size), 0) - << "Global world size must be divisible by dp size in multi-node serving " + << "Global world size must be divisible by dp size in multi-node " + "serving " "mode."; runtime::Options worker_server_options = options; @@ -134,6 +188,7 @@ void DistManager::setup_multi_node_workers( WorkerType worker_type = (options.task_type() == "generate") ? WorkerType::LLM : WorkerType::ELM; + // create local workers for (size_t i = 0; i < devices.size(); ++i) { // worldsize = 8 @@ -145,15 +200,25 @@ void DistManager::setup_multi_node_workers( // when start a offline inference task with multi-gpu/npu/mpu/... bool use_spawn_worker = options.enable_offline_inference() && i > 0; ParallelArgs parallel_args(rank, world_size, dp_size, nullptr, ep_size); - servers_.emplace_back(std::make_unique(i, - master_node_addr, - // done, - dones[i], - parallel_args, - devices[i], - worker_server_options, - worker_type, - use_spawn_worker)); + + std::unique_ptr input_shm_manager = nullptr; + std::unique_ptr output_shm_manager = nullptr; + if (options.is_local() && FLAGS_enable_shm) { + prepare_shm( + dp_local_tp_size, rank, input_shm_manager, output_shm_manager); + } + servers_.emplace_back( + std::make_unique(i, + master_node_addr, + // done, + dones[i], + parallel_args, + devices[i], + worker_server_options, + worker_type, + use_spawn_worker, + std::move(input_shm_manager), + std::move(output_shm_manager))); } // Master node need to wait all workers done @@ -185,8 +250,12 @@ void DistManager::setup_multi_node_workers( << r; return; } - worker_clients_.emplace_back(std::make_unique( - r, worker_addrs_map[r], devices[r % each_node_ranks])); + auto channel = create_channel(worker_addrs_map[r], r, dp_local_tp_size); + worker_clients_.emplace_back( + std::make_unique(r, + worker_addrs_map[r], + devices[r % each_node_ranks], + std::move(channel))); } } diff --git a/xllm/core/distributed_runtime/dist_manager.h b/xllm/core/distributed_runtime/dist_manager.h index cd39c08d..041d7373 100644 --- a/xllm/core/distributed_runtime/dist_manager.h +++ b/xllm/core/distributed_runtime/dist_manager.h @@ -20,8 +20,8 @@ limitations under the License. #include "distributed_runtime/worker_server.h" #include "distributed_runtime/worker_service.h" #include "framework/parallel_state/process_group.h" +#include "runtime/forward_shared_memory_manager.h" #include "runtime/options.h" - namespace xllm { class DistManager { public: diff --git a/xllm/core/distributed_runtime/remote_worker.cpp b/xllm/core/distributed_runtime/remote_worker.cpp index 840ce843..30867723 100644 --- a/xllm/core/distributed_runtime/remote_worker.cpp +++ b/xllm/core/distributed_runtime/remote_worker.cpp @@ -37,40 +37,23 @@ limitations under the License. 
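+// RemoteWorker now delegates all RPC transport to the injected CommChannel;
+// the shared-memory fast path is provided by its ShmChannel subclass.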
namespace xllm { RemoteWorker::RemoteWorker(int32_t global_rank, const std::string& server_address, - const torch::Device& d) - : global_rank_(global_rank), device_(d) { - // Initialize brpc channel - options_.connection_type = "pooled"; - options_.timeout_ms = -1; - options_.connect_timeout_ms = -1; - options_.max_retry = 3; - if (channel_.Init(server_address.c_str(), "", &options_) != 0) { - LOG(ERROR) << "Failed to initialize brpc channel"; - return; - } - // Initialize stub - stub_.reset(new proto::DistributeWorker_Stub(&channel_)); - + const torch::Device& d, + std::unique_ptr channel) + : global_rank_(global_rank), device_(d), channel_(std::move(channel)) { wait_for_server_ready(server_address); } bool RemoteWorker::wait_for_server_ready(const std::string& server_address) { - proto::Status req; - proto::Status resp; - // Retry until server initialize ready int try_count = 0; - brpc::Controller cntl; while (try_count < FLAGS_max_connect_count) { - cntl.Reset(); - stub_->Hello(&cntl, &req, &resp, nullptr); - if (cntl.Failed() || !resp.ok()) { - std::this_thread::sleep_for( - std::chrono::seconds(FLAGS_sleep_time_second)); - } else { + if (channel_->hello()) { LOG(INFO) << "RemoteWorker Hello connected, server_address: " << server_address << ", global_rank_: " << global_rank_; break; + } else { + std::this_thread::sleep_for( + std::chrono::seconds(FLAGS_sleep_time_second)); } try_count++; @@ -78,7 +61,7 @@ bool RemoteWorker::wait_for_server_ready(const std::string& server_address) { if (try_count >= FLAGS_max_connect_count) { LOG(ERROR) << "RemoteWorker Hello method failed, global_rank_ is " - << global_rank_ << ", error: " << cntl.ErrorText(); + << global_rank_; return false; } @@ -87,159 +70,52 @@ bool RemoteWorker::wait_for_server_ready(const std::string& server_address) { bool RemoteWorker::allocate_kv_cache( const std::vector>& kv_cache_shape) { - proto::KVCacheShape shape; - shape.mutable_key_shape()->Reserve(kv_cache_shape[0].size()); - shape.mutable_value_shape()->Reserve(kv_cache_shape[1].size()); - for (int32_t i = 0; i < kv_cache_shape[0].size(); ++i) { - shape.add_key_shape(kv_cache_shape[0][i]); - shape.add_value_shape(kv_cache_shape[1][i]); - } - proto::Status s; - brpc::Controller cntl; - stub_->AllocateKVCache(&cntl, &shape, &s, nullptr); - if (cntl.Failed() || !s.ok()) { - LOG(ERROR) << "allocate_kv_cache failed, " << cntl.ErrorText(); - return false; - } - return true; + return channel_->allocate_kv_cache(kv_cache_shape); } bool RemoteWorker::allocate_continuous_kv_cache( const std::vector& options) { - proto::XTensorOptionsVec xtensor_options_vec; - xtensor_options_vec.mutable_key_options()->set_num_kv_heads( - options[0].num_kv_heads()); - xtensor_options_vec.mutable_key_options()->set_head_size( - options[0].head_size()); - xtensor_options_vec.mutable_key_options()->set_max_context_len( - options[0].max_context_len()); - xtensor_options_vec.mutable_key_options()->set_max_seqs_per_batch( - options[0].max_seqs_per_batch()); - xtensor_options_vec.mutable_value_options()->set_num_kv_heads( - options[1].num_kv_heads()); - xtensor_options_vec.mutable_value_options()->set_head_size( - options[1].head_size()); - xtensor_options_vec.mutable_value_options()->set_max_context_len( - options[1].max_context_len()); - xtensor_options_vec.mutable_value_options()->set_max_seqs_per_batch( - options[1].max_seqs_per_batch()); - - proto::Status s; - brpc::Controller cntl; - stub_->AllocateContinuousKVCache(&cntl, &xtensor_options_vec, &s, nullptr); - if (cntl.Failed() || !s.ok()) { - 
LOG(ERROR) << "allocate_continuous_kv_cache failed, " << cntl.ErrorText(); - return false; - } - return true; + return channel_->allocate_continuous_kv_cache(options); } void RemoteWorker::get_device_info(std::string& device_ip, uint16_t& port) { - proto::Empty req; - proto::DeviceInfo resp; - brpc::Controller cntl; - stub_->GetDeviceInfo(&cntl, &req, &resp, nullptr); - if (cntl.Failed()) { - LOG(ERROR) << "GetDeviceInfo failed." << cntl.ErrorText(); - return; - } - device_ip = resp.device_ip(); - port = resp.listen_port(); + channel_->get_device_info(device_ip, port); } void RemoteWorker::get_cache_info(uint64_t& cluster_id, std::string& addr, int64_t& k_cache_id, int64_t& v_cache_id) { - proto::Empty req; - proto::CacheInfo resp; - brpc::Controller cntl; - stub_->GetCacheInfo(&cntl, &req, &resp, nullptr); - if (cntl.Failed()) { - LOG(ERROR) << "GetCacheInfo failed, " << cntl.ErrorText(); - return; - } - cluster_id = resp.cluster_id(); - addr = resp.addr(); - k_cache_id = resp.k_cache_id(); - v_cache_id = resp.v_cache_id(); + channel_->get_cache_info(cluster_id, addr, k_cache_id, v_cache_id); } bool RemoteWorker::link_cluster(const std::vector& cluster_ids, const std::vector& addrs, const std::vector& device_ips, const std::vector& ports) { - proto::ClusterInfo cluster_info; - cluster_info.mutable_cluster_ids()->Reserve(cluster_ids.size()); - cluster_info.mutable_addrs()->Reserve(addrs.size()); - cluster_info.mutable_device_ips()->Reserve(device_ips.size()); - cluster_info.mutable_ports()->Reserve(ports.size()); - for (int32_t i = 0; i < cluster_ids.size(); ++i) { - cluster_info.add_cluster_ids(cluster_ids[i]); - cluster_info.add_addrs(addrs[i]); - cluster_info.add_device_ips(device_ips[i]); - cluster_info.add_ports(ports[i]); - } - - proto::Status s; - brpc::Controller cntl; - stub_->LinkCluster(&cntl, &cluster_info, &s, nullptr); - if (cntl.Failed() || !s.ok()) { - LOG(INFO) << "LinkCluster failed, " << cntl.ErrorText(); - return false; - } - return true; + return channel_->link_cluster(cluster_ids, addrs, device_ips, ports); } bool RemoteWorker::unlink_cluster(const std::vector& cluster_ids, const std::vector& addrs, const std::vector& device_ips, const std::vector& ports) { - proto::ClusterInfo cluster_info; - cluster_info.mutable_cluster_ids()->Reserve(cluster_ids.size()); - cluster_info.mutable_addrs()->Reserve(addrs.size()); - cluster_info.mutable_device_ips()->Reserve(device_ips.size()); - cluster_info.mutable_ports()->Reserve(ports.size()); - for (int32_t i = 0; i < cluster_ids.size(); ++i) { - cluster_info.add_cluster_ids(cluster_ids[i]); - cluster_info.add_addrs(addrs[i]); - cluster_info.add_device_ips(device_ips[i]); - cluster_info.add_ports(ports[i]); - } - - proto::Status s; - brpc::Controller cntl; - stub_->UnlinkCluster(&cntl, &cluster_info, &s, nullptr); - if (cntl.Failed() || !s.ok()) { - LOG(INFO) << "UnlinkCluster failed, " << cntl.ErrorText(); - return false; - } - return true; + return channel_->unlink_cluster(cluster_ids, addrs, device_ips, ports); } bool RemoteWorker::init_model(const std::string& model_weights_path) { - proto::ModelPath request; - request.set_model_weights_path(model_weights_path); - proto::Status response; - brpc::Controller cntl; - stub_->InitModel(&cntl, &request, &response, nullptr); - if (cntl.Failed() || !response.ok()) { - LOG(ERROR) << "init_model failed, " << cntl.ErrorText(); - return false; - } - return true; + return channel_->init_model(model_weights_path); } std::tuple RemoteWorker::estimate_kv_cache_capacity() { proto::Empty req; 
proto::DeviceMemory mem; brpc::Controller cntl; - stub_->ProfileDeviceMemory(&cntl, &req, &mem, nullptr); - if (cntl.Failed()) { - LOG(ERROR) << "estimate_kv_cache_capacity failed: " << cntl.ErrorText(); - } - std::tuple result(mem.available_memory(), - mem.total_memory()); + int64_t available_memory = 0; + int64_t total_memory = 0; + + channel_->estimate_kv_cache_capacity(available_memory, total_memory); + std::tuple result(available_memory, total_memory); return result; } @@ -249,18 +125,12 @@ bool RemoteWorker::pull_kv_blocks(const uint64_t src_cluster_id, const int64_t src_v_cache_id, const std::vector& src_blocks, const std::vector& dst_blocks) { - proto::PullKVCacheRequest request; - request.set_cluster_id(src_cluster_id); - request.set_addr(src_addr); - request.set_k_cache_id(src_k_cache_id); - request.set_v_cache_id(src_v_cache_id); - ADD_VECTOR_TO_PROTO(request.mutable_src_blocks(), src_blocks); - ADD_VECTOR_TO_PROTO(request.mutable_dst_blocks(), dst_blocks); - proto::Status s; - brpc::Controller cntl; - stub_->PullKVCache(&cntl, &request, &s, nullptr); - - return s.ok(); + return channel_->pull_kv_blocks(src_cluster_id, + src_addr, + src_k_cache_id, + src_v_cache_id, + src_blocks, + dst_blocks); } ForwardInput RemoteWorker::prepare_inputs(Batch& batch) { @@ -279,13 +149,11 @@ RemoteWorker::estimate_kv_cache_capacity_async() { proto::Empty req; proto::DeviceMemory mem; brpc::Controller cntl; - stub_->ProfileDeviceMemory(&cntl, &req, &mem, nullptr); - if (cntl.Failed()) { - LOG(ERROR) << "estimate_kv_cache_capacity_async failed: " - << cntl.ErrorText(); - } - std::tuple result(mem.available_memory(), - mem.total_memory()); + int64_t available_memory = 0; + int64_t total_memory = 0; + + channel_->estimate_kv_cache_capacity(available_memory, total_memory); + std::tuple result(available_memory, total_memory); promise.setValue(result); }); return future; @@ -303,55 +171,17 @@ folly::SemiFuture> RemoteWorker::step_async( auto future = promise.getSemiFuture(); threadpool_.schedule( [this, inputs = inputs, promise = std::move(promise)]() mutable { - // 1. convert to proto::BatchedForwardInputs - proto::BatchedForwardInputs pb_batched_fwd_inputs; - std::vector batched_fwd_inputs_vec; - batched_fwd_inputs_vec.reserve(inputs.size()); - for (auto i = 0; i < inputs.size(); ++i) { - proto::ForwardInput pb_fwd_input; - forward_input_to_proto(inputs[i], &pb_fwd_input); - batched_fwd_inputs_vec.push_back(std::move(pb_fwd_input)); - } - ADD_VECTOR_TO_PROTO(pb_batched_fwd_inputs.mutable_micro_inputs(), - batched_fwd_inputs_vec); - - // 2. call ExecuteModel with callback - auto done = new ExecuteModelClosure(); - done->promise = std::move(promise); - stub_->ExecuteModel( - &done->cntl, &pb_batched_fwd_inputs, &done->pb_output, done); + channel_->execute_model_async(inputs, promise); }); return future; } -void ExecuteModelClosure::Run() { - std::unique_ptr self_guard(this); - - if (cntl.Failed()) { - LOG(ERROR) << "Execute_model_async failed. Error code : " - << cntl.ErrorCode() << ", error message : " << cntl.ErrorText(); - } - - // 3. 
parse tokens - RawForwardOutput raw_forward_output; - proto_to_forward_output(pb_output, raw_forward_output); - promise.setValue(raw_forward_output); - - return; -} - folly::SemiFuture RemoteWorker::process_group_test_async() { folly::Promise promise; auto future = promise.getSemiFuture(); threadpool_.schedule([this, promise = std::move(promise)]() mutable { - proto::Empty req; - proto::Status s; - brpc::Controller cntl; - stub_->ProcessGroupTest(&cntl, &req, &s, nullptr); - if (cntl.Failed() || !s.ok()) { - LOG(ERROR) << "process_group_test_async failed, " << cntl.ErrorText(); - } + channel_->process_group_test(); promise.setValue(); }); return future; @@ -364,50 +194,22 @@ folly::SemiFuture RemoteWorker::init_model_async( threadpool_.schedule( [this, model_weights_path, promise = std::move(promise)]() mutable { // call InitModel with callback - auto done = new InitModelClosure(); - done->promise = std::move(promise); - proto::ModelPath request; - request.set_model_weights_path(model_weights_path); - stub_->InitModel(&done->cntl, &request, &done->response, done); + channel_->init_model_async(model_weights_path, promise); }); return future; } -void InitModelClosure::Run() { - std::unique_ptr self_guard(this); - - bool success = !cntl.Failed() && response.ok(); - if (!success) { - LOG(ERROR) << "Init_model_async failed, " << cntl.ErrorText(); - } else { - LOG(INFO) << "Init_model_async succeed."; - } - promise.setValue(success); - - return; -} - folly::SemiFuture RemoteWorker::allocate_kv_cache_async( const std::vector>& kv_cache_shape) { folly::Promise promise; auto future = promise.getSemiFuture(); threadpool_.schedule( [this, kv_cache_shape, promise = std::move(promise)]() mutable { - proto::KVCacheShape shape; - shape.mutable_key_shape()->Reserve(kv_cache_shape[0].size()); - shape.mutable_value_shape()->Reserve(kv_cache_shape[1].size()); - for (int32_t i = 0; i < kv_cache_shape[0].size(); ++i) { - shape.add_key_shape(kv_cache_shape[0][i]); - shape.add_value_shape(kv_cache_shape[1][i]); - } - proto::Status s; - brpc::Controller cntl; - stub_->AllocateKVCache(&cntl, &shape, &s, nullptr); - if (cntl.Failed() || !s.ok()) { - LOG(ERROR) << "allocate_kv_cache_async failed, " << cntl.ErrorText(); + if (!channel_->allocate_kv_cache(kv_cache_shape)) { + LOG(ERROR) << "allocate_kv_cache_async failed"; promise.setValue(false); } else { - promise.setValue(s.ok()); + promise.setValue(true); } }); return future; @@ -418,32 +220,11 @@ folly::SemiFuture RemoteWorker::allocate_continuous_kv_cache_async( folly::Promise promise; auto future = promise.getSemiFuture(); threadpool_.schedule([this, options, promise = std::move(promise)]() mutable { - proto::XTensorOptionsVec xtensor_options_vec; - xtensor_options_vec.mutable_key_options()->set_num_kv_heads( - options[0].num_kv_heads()); - xtensor_options_vec.mutable_key_options()->set_head_size( - options[0].head_size()); - xtensor_options_vec.mutable_key_options()->set_max_context_len( - options[0].max_context_len()); - xtensor_options_vec.mutable_key_options()->set_max_seqs_per_batch( - options[0].max_seqs_per_batch()); - xtensor_options_vec.mutable_value_options()->set_num_kv_heads( - options[1].num_kv_heads()); - xtensor_options_vec.mutable_value_options()->set_head_size( - options[1].head_size()); - xtensor_options_vec.mutable_value_options()->set_max_context_len( - options[1].max_context_len()); - xtensor_options_vec.mutable_value_options()->set_max_seqs_per_batch( - options[1].max_seqs_per_batch()); - proto::Status s; - brpc::Controller cntl; - 
stub_->AllocateContinuousKVCache(&cntl, &xtensor_options_vec, &s, nullptr); - if (cntl.Failed() || !s.ok()) { - LOG(ERROR) << "allocate_continuous_kv_cache_async failed, " - << cntl.ErrorText(); + if (!channel_->allocate_continuous_kv_cache(options)) { + LOG(ERROR) << "allocate_continuous_kv_cache_async failed"; promise.setValue(false); } else { - promise.setValue(s.ok()); + promise.setValue(true); } }); return future; @@ -458,24 +239,12 @@ folly::SemiFuture RemoteWorker::allocate_kv_cache_with_transfer_async( kv_cache_size, kv_cache_shape, promise = std::move(promise)]() mutable { - proto::AllocateKVCacheWithTransferRequest request; - request.set_kv_cache_size(kv_cache_size); - request.mutable_kv_cache_shape()->mutable_key_shape()->Reserve( - kv_cache_shape[0].size()); - request.mutable_kv_cache_shape()->mutable_value_shape()->Reserve( - kv_cache_shape[1].size()); - for (int32_t i = 0; i < kv_cache_shape[0].size(); ++i) { - request.mutable_kv_cache_shape()->add_key_shape(kv_cache_shape[0][i]); - request.mutable_kv_cache_shape()->add_value_shape(kv_cache_shape[1][i]); - } - proto::Status s; - brpc::Controller cntl; - stub_->AllocateKVCacheWithTransfer(&cntl, &request, &s, nullptr); - if (cntl.Failed() || !s.ok()) { - LOG(ERROR) << "AllocateKVCacheWithTransfer failed, " << cntl.ErrorText(); + if (!channel_->allocate_kv_cache_with_transfer(kv_cache_size, + kv_cache_shape)) { + LOG(ERROR) << "AllocateKVCacheWithTransfer failed"; promise.setValue(false); } else { - promise.setValue(s.ok()); + promise.setValue(true); } }); return future; @@ -498,21 +267,16 @@ folly::SemiFuture RemoteWorker::pull_kv_blocks_async( &src_blocks, &dst_blocks, promise = std::move(promise)]() mutable { - proto::PullKVCacheRequest request; - request.set_cluster_id(src_cluster_id); - request.set_addr(src_addr); - request.set_k_cache_id(src_k_cache_id); - request.set_v_cache_id(src_v_cache_id); - ADD_VECTOR_TO_PROTO(request.mutable_src_blocks(), src_blocks); - ADD_VECTOR_TO_PROTO(request.mutable_dst_blocks(), dst_blocks); - proto::Status s; - brpc::Controller cntl; - stub_->PullKVCache(&cntl, &request, &s, nullptr); - if (cntl.Failed() || !s.ok()) { - LOG(ERROR) << "PullKVCache failed, " << cntl.ErrorText(); + if (!channel_->pull_kv_blocks(src_cluster_id, + src_addr, + src_k_cache_id, + src_v_cache_id, + src_blocks, + dst_blocks)) { + LOG(ERROR) << "PullKVCache failed"; promise.setValue(false); } else { - promise.setValue(s.ok()); + promise.setValue(true); } }); return future; @@ -525,32 +289,11 @@ folly::SemiFuture RemoteWorker::load_kv_blocks_from_store_async( general_threadpool_.schedule([this, cache_block_info = std::move(cache_block_info), promise = std::move(promise)]() mutable { - proto::CacheBlockInfos pb_cache_block_info; - if (!cache_block_info_to_proto(cache_block_info, &pb_cache_block_info)) { - promise.setValue(0); - return; - } - - auto done = new LoadKVCacheFromStoreClosure(); - done->promise = std::move(promise); - stub_->LoadKVCacheFromStore( - &done->cntl, &pb_cache_block_info, &done->response, done); + channel_->load_kv_blocks_from_store_async(cache_block_info, promise); }); return future; } -void LoadKVCacheFromStoreClosure::Run() { - std::unique_ptr self_guard(this); - - bool success = !cntl.Failed(); - if (!success) { - promise.setValue(0); - } else { - promise.setValue(response.success_cnt()); - } - return; -} - const torch::Device& RemoteWorker::device() const { LOG(ERROR) << "RemoteWorker Method device is UnImplemented."; } @@ -560,47 +303,22 @@ RemoteWorker::get_last_step_result_async() { 
folly::Promise> promise; auto future = promise.getSemiFuture(); threadpool_.schedule([this, promise = std::move(promise)]() mutable { - proto::Empty req; - proto::ForwardOutput pb_output; - brpc::Controller cntl; - stub_->GetLastStepResult(&cntl, &req, &pb_output, nullptr); - if (cntl.Failed()) { - LOG(ERROR) << "Get last step model output result failed, " - << cntl.ErrorText(); - } - - // parse tokens - RawForwardOutput raw_forward_output; - proto_to_forward_output(pb_output, raw_forward_output); - promise.setValue(std::move(raw_forward_output)); + channel_->get_last_step_result_async(promise); }); return future; } int64_t RemoteWorker::get_active_activation_memory() { - proto::Empty req; - proto::ActivationMemory mem; - brpc::Controller cntl; - stub_->GetActiveActivationMemory(&cntl, &req, &mem, nullptr); - if (cntl.Failed()) { - LOG(ERROR) << "get_active_activation_memory failed: " << cntl.ErrorText(); - } - return mem.active_activation_memory(); + int64_t memory = 0; + channel_->get_active_activation_memory(memory); + return memory; } folly::SemiFuture RemoteWorker::get_active_activation_memory_async() { folly::Promise promise; auto future = promise.getSemiFuture(); threadpool_.schedule([this, promise = std::move(promise)]() mutable { - proto::Empty req; - proto::ActivationMemory mem; - brpc::Controller cntl; - stub_->GetActiveActivationMemory(&cntl, &req, &mem, nullptr); - if (cntl.Failed()) { - LOG(ERROR) << "get_active_activation_memory_async failed: " - << cntl.ErrorText(); - } - promise.setValue(mem.active_activation_memory()); + channel_->get_active_activation_memory_async(promise); }); return future; } diff --git a/xllm/core/distributed_runtime/remote_worker.h b/xllm/core/distributed_runtime/remote_worker.h index 399a36d5..9884c606 100644 --- a/xllm/core/distributed_runtime/remote_worker.h +++ b/xllm/core/distributed_runtime/remote_worker.h @@ -19,6 +19,7 @@ limitations under the License. #include #include +#include "comm_channel.h" #include "common/macros.h" #include "framework/model/causal_lm.h" #include "framework/model/embedding_lm.h" @@ -28,6 +29,7 @@ limitations under the License. 
#include "framework/state_dict/state_dict.h" #include "runtime/executor.h" #include "runtime/forward_params.h" +#include "runtime/forward_shared_memory_manager.h" #include "runtime/worker_client.h" #include "util/threadpool.h" #include "worker.pb.h" @@ -38,7 +40,8 @@ class RemoteWorker : public WorkerClient { public: explicit RemoteWorker(int32_t global_rank, const std::string& server_address, - const torch::Device& d); + const torch::Device& d, + std::unique_ptr channel); virtual ~RemoteWorker() = default; bool wait_for_server_ready(const std::string& server_address); @@ -134,44 +137,12 @@ class RemoteWorker : public WorkerClient { private: int32_t global_rank_; - - // brpc connection resource - brpc::Channel channel_; - brpc::ChannelOptions options_; - std::unique_ptr stub_; - + // connection resource + std::unique_ptr channel_; ThreadPool threadpool_; // general working thread // do some overlap work with model execute ThreadPool general_threadpool_{5}; const torch::Device device_; }; - -class InitModelClosure : public google::protobuf::Closure { - public: - void Run(); - - proto::Status response; - brpc::Controller cntl; - folly::Promise promise; -}; - -class ExecuteModelClosure : public google::protobuf::Closure { - public: - void Run(); - - proto::ForwardOutput pb_output; - brpc::Controller cntl; - folly::Promise> promise; -}; - -class LoadKVCacheFromStoreClosure : public google::protobuf::Closure { - public: - void Run(); - - proto::StoreResponse response; - brpc::Controller cntl; - folly::Promise promise; -}; - } // namespace xllm diff --git a/xllm/core/distributed_runtime/shm_channel.cpp b/xllm/core/distributed_runtime/shm_channel.cpp new file mode 100644 index 00000000..7bc61f83 --- /dev/null +++ b/xllm/core/distributed_runtime/shm_channel.cpp @@ -0,0 +1,72 @@ +/* Copyright 2025 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "shm_channel.h" + +#include "common/global_flags.h" + +namespace xllm { + +ShmChannel::ShmChannel(int dp_group, int rank, bool is_driver) { + bool is_creator; + + if (is_driver) { + auto name = ForwardSharedMemoryManager::create_unique_name( + dp_group, FORWARD_RAW_INPUT_TYPE, rank); + input_shm_manager_ = std::make_unique( + name, PB_INPUT_SHM_SIZE, is_creator, FORWARD_RAW_INPUT_TYPE); + LOG(INFO) << "Create input shared memory manager with name: " << name; + } + + auto name = ForwardSharedMemoryManager::create_unique_name( + dp_group, FORWARD_RAW_OUTPUT_TYPE, rank); + output_shm_manager_ = std::make_unique( + name, PB_OUTPUT_SHM_SIZE, is_creator, FORWARD_RAW_OUTPUT_TYPE); + LOG(INFO) << "Create output shared memory manager with name: " << name; +} + +bool ShmChannel::execute_model_with_shm( + const std::vector& inputs, + RawForwardOutput& raw_output) { + // write to shared memory, then wait output. 
+ if (input_shm_manager_) { + int use_shm_ret = input_shm_manager_->raw_input_write(inputs); + if (use_shm_ret < 0) { + // fallback + FLAGS_enable_shm = false; + LOG(ERROR) + << "RemoteWorker SharedMemoryManager write failed, fallback to brpc."; + return false; + } + } + output_shm_manager_->raw_output_read(raw_output); + return true; +} + +void ShmChannel::execute_model_async( + const std::vector& inputs, + folly::Promise>& promise) { + if (FLAGS_enable_shm) { + // write to shared memory, then wait output. + RawForwardOutput raw_output; + bool shm_success = execute_model_with_shm(inputs, raw_output); + if (shm_success) { + promise.setValue(raw_output); + return; + } + } + execute_model_with_brpc(inputs, promise); +} +} // namespace xllm \ No newline at end of file diff --git a/xllm/core/distributed_runtime/shm_channel.h b/xllm/core/distributed_runtime/shm_channel.h new file mode 100644 index 00000000..6c44f0da --- /dev/null +++ b/xllm/core/distributed_runtime/shm_channel.h @@ -0,0 +1,38 @@ +/* Copyright 2025 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#pragma once +#include "comm_channel.h" +#include "runtime/forward_shared_memory_manager.h" + +namespace xllm { + +class ShmChannel : public CommChannel { + public: + explicit ShmChannel(int dp_group, int rank, bool is_driver); + ~ShmChannel() = default; + + void execute_model_async( + const std::vector& inputs, + folly::Promise>& promise) override; + + private: + bool execute_model_with_shm(const std::vector& inputs, + RawForwardOutput& raw_output); + std::unique_ptr input_shm_manager_ = nullptr; + std::unique_ptr output_shm_manager_ = nullptr; +}; + +} // namespace xllm \ No newline at end of file diff --git a/xllm/core/distributed_runtime/worker_server.cpp b/xllm/core/distributed_runtime/worker_server.cpp index 8fd14cf2..ffb006dc 100644 --- a/xllm/core/distributed_runtime/worker_server.cpp +++ b/xllm/core/distributed_runtime/worker_server.cpp @@ -51,15 +51,18 @@ extern char** environ; namespace xllm { -void WorkerServer::create_server(const runtime::Options& options, - std::atomic& done, - const std::string& master_node_addr, - const torch::Device& d, - int world_size, - int global_rank, - int32_t dp_size, - int local_rank, - int32_t ep_size) { +void WorkerServer::create_server( + const runtime::Options& options, + std::atomic& done, + const std::string& master_node_addr, + const torch::Device& d, + int world_size, + int global_rank, + int32_t dp_size, + int local_rank, + int32_t ep_size, + std::unique_ptr input_shm_manager, + std::unique_ptr output_shm_manager) { Device device(d); device.set_device(); LOG(INFO) << "Create worker server with device: " << device.index(); @@ -107,6 +110,11 @@ void WorkerServer::create_server(const runtime::Options& options, std::unique_ptr worker = std::make_unique(*parallel_args, device, options, worker_type); worker_service->set_worker(std::move(worker)); + if (FLAGS_enable_shm && 
input_shm_manager && output_shm_manager) { + worker_service->create_polling_shm_thread(std::move(input_shm_manager), + std::move(output_shm_manager)); + } + done.store(true); // Wait until Ctrl-C is pressed, then Stop() and Join() the server. @@ -160,14 +168,17 @@ void WorkerServer::create_spawn_server(int local_rank, done.store(true); } -WorkerServer::WorkerServer(int local_worker_idx, - const std::string& master_node_addr, - std::atomic& done, - const ParallelArgs& parallel_args, - const torch::Device& d, - const runtime::Options& options, - WorkerType worker_type, - bool use_spawn_worker) { +WorkerServer::WorkerServer( + int local_worker_idx, + const std::string& master_node_addr, + std::atomic& done, + const ParallelArgs& parallel_args, + const torch::Device& d, + const runtime::Options& options, + WorkerType worker_type, + bool use_spawn_worker, + std::unique_ptr input_shm_manager, + std::unique_ptr output_shm_manager) { if (worker_type == WorkerType::LLM || worker_type == WorkerType::ELM) { if (use_spawn_worker) { // start worker in a spawn process(for offline inference worker.) @@ -177,17 +188,20 @@ WorkerServer::WorkerServer(int local_worker_idx, } // start worker in a thread. - worker_thread_ = std::make_unique(&WorkerServer::create_server, - this, - std::cref(options), - std::ref(done), - std::cref(master_node_addr), - std::cref(d), - parallel_args.world_size(), - parallel_args.rank(), - parallel_args.dp_size(), - local_worker_idx, - parallel_args.ep_size()); + worker_thread_ = + std::make_unique(&WorkerServer::create_server, + this, + std::cref(options), + std::ref(done), + std::cref(master_node_addr), + std::cref(d), + parallel_args.world_size(), + parallel_args.rank(), + parallel_args.dp_size(), + local_worker_idx, + parallel_args.ep_size(), + std::move(input_shm_manager), + std::move(output_shm_manager)); } else { // TODO: support other model type later. LOG(ERROR) << "Unsupported model type: " << worker_type; diff --git a/xllm/core/distributed_runtime/worker_server.h b/xllm/core/distributed_runtime/worker_server.h index a3b8045c..f0e756d0 100644 --- a/xllm/core/distributed_runtime/worker_server.h +++ b/xllm/core/distributed_runtime/worker_server.h @@ -28,6 +28,7 @@ limitations under the License. 
#include "framework/model/model_input_params.h" #include "runtime/executor.h" #include "runtime/forward_params.h" +#include "runtime/forward_shared_memory_manager.h" #include "runtime/options.h" #include "runtime/worker_impl.h" #include "worker.pb.h" @@ -36,26 +37,32 @@ namespace xllm { class WorkerServer { public: - WorkerServer(int local_worker_idx, - const std::string& master_node_addr, - std::atomic& done, - const ParallelArgs& parallel_args, - const torch::Device& d, - const runtime::Options& options, - WorkerType worker_type, - bool use_spawn_worker = false); + WorkerServer( + int local_worker_idx, + const std::string& master_node_addr, + std::atomic& done, + const ParallelArgs& parallel_args, + const torch::Device& d, + const runtime::Options& options, + WorkerType worker_type, + bool use_spawn_worker = false, + std::unique_ptr input_shm_manager = nullptr, + std::unique_ptr output_shm_manager = nullptr); virtual ~WorkerServer(); - void create_server(const runtime::Options& options, - std::atomic& done, - const std::string& master_node_addr, - const torch::Device& d, - int world_sizse, - int global_rank, - int32_t dp_size, - int local_rank, - int32_t ep_size); + void create_server( + const runtime::Options& options, + std::atomic& done, + const std::string& master_node_addr, + const torch::Device& d, + int world_sizse, + int global_rank, + int32_t dp_size, + int local_rank, + int32_t ep_size, + std::unique_ptr input_shm_manager, + std::unique_ptr output_shm_manager); private: DISALLOW_COPY_AND_ASSIGN(WorkerServer); @@ -67,6 +74,10 @@ class WorkerServer { const torch::Device& d, const runtime::Options& options); + void create_shared_memory_polling( + std::unique_ptr input_shm_manager, + std::unique_ptr output_shm_manager); + bool sync_master_node(const std::string& master_node_addr, proto::AddressInfo& addr_info, proto::CommUniqueIdList& uids); diff --git a/xllm/core/distributed_runtime/worker_service.cpp b/xllm/core/distributed_runtime/worker_service.cpp index 827cbfb8..b6611d59 100644 --- a/xllm/core/distributed_runtime/worker_service.cpp +++ b/xllm/core/distributed_runtime/worker_service.cpp @@ -61,6 +61,177 @@ void WorkerService::set_worker(std::unique_ptr worker) { initialized_ = true; } +void WorkerService::step(BatchedForwardInputs& batched_fwd_inputs, + torch::Tensor& next_tokens, + torch::Tensor& logprobs, + torch::Tensor& top_tokens, + torch::Tensor& top_logprobs, + torch::Tensor& embeddings, + torch::Tensor& expert_load_data, + int32_t& prepared_layer_id, + torch::Tensor& src_seq_idxes, + torch::Tensor& out_tokens, + torch::Tensor& out_logprobs) { + device_.set_device(); + // execute model + auto future = worker_->step_async(batched_fwd_inputs); + + if (!options_.enable_schedule_overlap()) { + auto forward_outputs = std::move(future).get(); + // convert ForwardOutput to proto::ForwardOutput which contain Tokens. 
+ if (forward_outputs) { + DCHECK(forward_outputs.has_value()) << "Failed to execute model"; + const auto& sample_output = forward_outputs.value().sample_output; + const auto& beam_search_output = + forward_outputs.value().beam_search_output; + expert_load_data = + safe_to(forward_outputs.value().expert_load_data, torch::kCPU, true); + prepared_layer_id = forward_outputs.value().prepared_layer_id; + + { + c10::StreamGuard streamGuard = stream_->set_stream_guard(); + // only driver worker (rank=0) need to fill this + // [num_seq, ..., embed_dim] FloatTensor + embeddings = safe_to(sample_output.embeddings, + torch::dtype(torch::kFloat32).device(torch::kCPU), + true); + + // [num_seq] + next_tokens = safe_to(sample_output.next_tokens, torch::kCPU, true); + if (next_tokens.defined()) { + // [num_seq] + logprobs = safe_to(sample_output.logprobs, torch::kCPU, true); + + if (!beam_search_output.src_seq_idxes.defined()) { + // beam search kernel will provide final tokens/logprobs in beam + // search output, so keep top_tokens/top_logprobs undefined to + // avoid returning them. + // [num_seq, topk] + top_tokens = safe_to(sample_output.top_tokens, torch::kCPU, true); + // [num_seq, topk] + top_logprobs = + safe_to(sample_output.top_logprobs, torch::kCPU, true); + } + } + + // beam search output + // [num_seq] + src_seq_idxes = + safe_to(beam_search_output.src_seq_idxes, torch::kCPU, true); + if (src_seq_idxes.defined()) { + // [num_seq] + out_tokens = + safe_to(beam_search_output.out_tokens, torch::kCPU, true); + // [num_seq] + out_logprobs = + safe_to(beam_search_output.out_logprobs, + torch::dtype(torch::kFloat32).device(torch::kCPU), + true); + } + auto ret = stream_->synchronize(); + } + } + } else { + if (worker_->is_driver()) { + // construct fake output tensor + auto options = + torch::TensorOptions().dtype(torch::kInt32).device(torch::kCPU); + auto total_prefill_seq_len = 0; + auto total_num_sequences = 0; + for (auto& input : batched_fwd_inputs.micro_inputs) { + total_num_sequences += input.input_params.num_sequences; + total_prefill_seq_len += input.input_params.prefill_seq_len; + } + next_tokens = + torch::arange(-1, + -1 * (total_num_sequences - total_prefill_seq_len + 1), + -1, + options); + std::move(future).deferValue([](auto&&) {}); + } + expert_load_data = torch::zeros({1, 1}).to(torch::kInt64).contiguous(); + } +} + +void WorkerService::create_polling_shm_thread( + std::unique_ptr input_shm_manager, + std::unique_ptr output_shm_manager) { + polling_thread_ = std::make_unique( + [this, + input_shm_manager = std::move(input_shm_manager), + output_shm_manager = std::move(output_shm_manager)]() mutable { + Timer timer; + while (true) { + BatchedForwardInputs batched_fwd_inputs; + std::vector inputs; + input_shm_manager->raw_input_read(inputs); + timer.reset(); + // model output variables + torch::Tensor next_tokens; + torch::Tensor logprobs; + torch::Tensor top_tokens; + torch::Tensor top_logprobs; + torch::Tensor embeddings; + torch::Tensor expert_load_data; + int32_t prepared_layer_id = -1; + + // beam search kernel output + torch::Tensor src_seq_idxes; + torch::Tensor out_tokens; + torch::Tensor out_logprobs; + + auto micro_batches_num = inputs.size(); + batched_fwd_inputs.micro_inputs = std::move(inputs); + batched_fwd_inputs.concated_sampling_params = + batched_fwd_inputs.micro_inputs[0].sampling_params; + for (auto i = 1; i < micro_batches_num; ++i) { + batched_fwd_inputs.concated_sampling_params.concat( + batched_fwd_inputs.micro_inputs[i].sampling_params); + } + + // concat 
acc_logprob here for beam search together + if (micro_batches_num > 1) { + std::vector acc_logprob_vec; + acc_logprob_vec.reserve(micro_batches_num); + for (auto i = 0; i < micro_batches_num; ++i) { + acc_logprob_vec.push_back( + batched_fwd_inputs.micro_inputs[i].acc_logprob); + } + batched_fwd_inputs.acc_logprob = + torch::cat(acc_logprob_vec, /*dim=*/-1); + } else { + batched_fwd_inputs.acc_logprob = + batched_fwd_inputs.micro_inputs[0].acc_logprob; + } + + step(batched_fwd_inputs, + next_tokens, + logprobs, + top_tokens, + top_logprobs, + embeddings, + expert_load_data, + prepared_layer_id, + src_seq_idxes, + out_tokens, + out_logprobs); + + output_shm_manager->raw_output_write(next_tokens, + logprobs, + top_tokens, + top_logprobs, + embeddings, + expert_load_data, + prepared_layer_id, + src_seq_idxes, + out_tokens, + out_logprobs); + COUNTER_ADD(worker_service_latency_seconds, timer.elapsed_seconds()); + } + }); + return; +} + void WorkerService::Hello(::google::protobuf::RpcController* controller, const proto::Status* request, proto::Status* response, @@ -323,9 +494,7 @@ void WorkerService::ExecuteModel( pb_forward_output, done]() mutable { brpc::ClosureGuard done_guard(done); - device_.set_device(); Timer timer; - // convert proto::BatchedForwardInputs to BatchedForwardInputs auto micro_batches_num = pb_batched_fwd_inputs->micro_inputs().size(); BatchedForwardInputs batched_fwd_inputs; @@ -338,7 +507,7 @@ void WorkerService::ExecuteModel( batched_fwd_inputs.micro_inputs.push_back(std::move(forward_input)); } - // concat sampling parameters here for executing sample together + // concat sampling parameters batched_fwd_inputs.concated_sampling_params = batched_fwd_inputs.micro_inputs[0].sampling_params; for (auto i = 1; i < micro_batches_num; ++i) { @@ -373,86 +542,18 @@ void WorkerService::ExecuteModel( torch::Tensor out_tokens; torch::Tensor out_logprobs; - // execute model - auto future = worker_->step_async(batched_fwd_inputs); - - if (!options_.enable_schedule_overlap()) { - auto forward_outputs = std::move(future).get(); - // convert ForwardOutput to proto::ForwardOutput which contain Tokens. - if (forward_outputs) { - DCHECK(forward_outputs.has_value()) << "Failed to execute model"; - const auto& sample_output = forward_outputs.value().sample_output; - const auto& beam_search_output = - forward_outputs.value().beam_search_output; - expert_load_data = safe_to( - forward_outputs.value().expert_load_data, torch::kCPU, true); - prepared_layer_id = forward_outputs.value().prepared_layer_id; - - { - c10::StreamGuard streamGuard = stream_->set_stream_guard(); - // only driver worker (rank=0) need to fill this - // [num_seq, ..., embed_dim] FloatTensor - embeddings = - safe_to(sample_output.embeddings, - torch::dtype(torch::kFloat32).device(torch::kCPU), - true); - - // [num_seq] - next_tokens = safe_to(sample_output.next_tokens, torch::kCPU, true); - if (next_tokens.defined()) { - // [num_seq] - logprobs = safe_to(sample_output.logprobs, torch::kCPU, true); - - if (!beam_search_output.src_seq_idxes.defined()) { - // beam search kernel will provide final tokens/logprobs in beam - // search output, so keep top_tokens/top_logprobs undefined to - // avoid returning them. 
- // [num_seq, topk] - top_tokens = safe_to(sample_output.top_tokens, torch::kCPU, true); - // [num_seq, topk] - top_logprobs = - safe_to(sample_output.top_logprobs, torch::kCPU, true); - } - } - - // beam search output - // [num_seq] - src_seq_idxes = - safe_to(beam_search_output.src_seq_idxes, torch::kCPU, true); - if (src_seq_idxes.defined()) { - // [num_seq] - out_tokens = - safe_to(beam_search_output.out_tokens, torch::kCPU, true); - // [num_seq] - out_logprobs = - safe_to(beam_search_output.out_logprobs, - torch::dtype(torch::kFloat32).device(torch::kCPU), - true); - } - auto ret = stream_->synchronize(); - } - } - } else { - if (worker_->is_driver()) { - // construct fake output tensor - auto options = - torch::TensorOptions().dtype(torch::kInt32).device(torch::kCPU); - auto total_prefill_seq_len = 0; - auto total_num_sequences = 0; - for (auto& input : batched_fwd_inputs.micro_inputs) { - total_num_sequences += input.input_params.num_sequences; - total_prefill_seq_len += input.input_params.prefill_seq_len; - } - next_tokens = torch::arange( - -1, - -1 * (total_num_sequences - total_prefill_seq_len + 1), - -1, - options); - std::move(future).deferValue([](auto&&) {}); - } - expert_load_data = torch::zeros({1, 1}).to(torch::kInt64).contiguous(); - } - + step(batched_fwd_inputs, + next_tokens, + logprobs, + top_tokens, + top_logprobs, + embeddings, + expert_load_data, + prepared_layer_id, + src_seq_idxes, + out_tokens, + out_logprobs); + // convert to proto output forward_output_to_proto(next_tokens, logprobs, top_tokens, diff --git a/xllm/core/distributed_runtime/worker_service.h b/xllm/core/distributed_runtime/worker_service.h index 7cfb1f96..ef2566e4 100644 --- a/xllm/core/distributed_runtime/worker_service.h +++ b/xllm/core/distributed_runtime/worker_service.h @@ -17,6 +17,7 @@ limitations under the License. 
#include +#include "runtime/forward_shared_memory_manager.h" #include "runtime/worker.h" #include "worker.pb.h" @@ -33,6 +34,10 @@ class WorkerService : public proto::DistributeWorker { void set_worker(std::unique_ptr worker); + void create_polling_shm_thread( + std::unique_ptr input_shm_manager, + std::unique_ptr output_shm_manager); + // service functions void Hello(::google::protobuf::RpcController* controller, const proto::Status* request, @@ -117,6 +122,17 @@ class WorkerService : public proto::DistributeWorker { ::google::protobuf::Closure* done) override; private: + void step(BatchedForwardInputs& batched_fwd_inputs, + torch::Tensor& next_tokens, + torch::Tensor& logprobs, + torch::Tensor& top_tokens, + torch::Tensor& top_logprobs, + torch::Tensor& embeddings, + torch::Tensor& expert_load_data, + int32_t& prepared_layer_id, + torch::Tensor& src_seq_idxes, + torch::Tensor& out_tokens, + torch::Tensor& out_logprobs); DISALLOW_COPY_AND_ASSIGN(WorkerService); private: @@ -126,10 +142,13 @@ class WorkerService : public proto::DistributeWorker { bool initialized_; Device device_; + std::unique_ptr stream_; std::unique_ptr worker_; + std::unique_ptr polling_thread_; + ThreadPool threadpool_{5}; }; diff --git a/xllm/core/framework/eplb/CMakeLists.txt b/xllm/core/framework/eplb/CMakeLists.txt index 1df1fec7..f6bbc88b 100644 --- a/xllm/core/framework/eplb/CMakeLists.txt +++ b/xllm/core/framework/eplb/CMakeLists.txt @@ -15,14 +15,12 @@ cc_library( eplb_manager.h eplb_policy.h expert_weight_buffer_shm.h - shared_memory_manager.h expert_buffer_manager.h SRCS eplb_executor.cpp eplb_manager.cpp eplb_policy.cpp expert_weight_buffer_shm.cpp - shared_memory_manager.cpp expert_buffer_manager.cpp DEPS :request diff --git a/xllm/core/framework/eplb/expert_weight_buffer_shm.h b/xllm/core/framework/eplb/expert_weight_buffer_shm.h index dbc017cb..ad0b3184 100644 --- a/xllm/core/framework/eplb/expert_weight_buffer_shm.h +++ b/xllm/core/framework/eplb/expert_weight_buffer_shm.h @@ -24,7 +24,7 @@ limitations under the License. #include #include -#include "shared_memory_manager.h" +#include "util/shared_memory_manager.h" namespace xllm { diff --git a/xllm/core/runtime/CMakeLists.txt b/xllm/core/runtime/CMakeLists.txt index a606353a..54b10152 100644 --- a/xllm/core/runtime/CMakeLists.txt +++ b/xllm/core/runtime/CMakeLists.txt @@ -30,6 +30,7 @@ cc_library( xservice_client.h speculative_engine.h speculative_worker_impl.h + forward_shared_memory_manager.h SRCS executor.cpp base_executor_impl.cpp @@ -50,6 +51,7 @@ cc_library( params_utils.cpp speculative_engine.cpp speculative_worker_impl.cpp + forward_shared_memory_manager.cpp DEPS torch $<$:torch_npu> diff --git a/xllm/core/runtime/forward_shared_memory_manager.cpp b/xllm/core/runtime/forward_shared_memory_manager.cpp new file mode 100644 index 00000000..cd3ee268 --- /dev/null +++ b/xllm/core/runtime/forward_shared_memory_manager.cpp @@ -0,0 +1,1000 @@ +/* Copyright 2025 The xLLM Authors. All Rights Reserved. +Copyright 2024 The ScaleLLM Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + https://github.com/jd-opensource/xllm/blob/main/LICENSE +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "forward_shared_memory_manager.h"
+
+#include <string.h>
+
+#include <atomic>
+#include <thread>
+
+#include "core/common/global_flags.h"
+#include "core/util/net.h"
+#include "util/utils.h"
+
+#define INLINE inline __attribute__((always_inline))
+
+#if defined(__GNUC__)
+static inline bool(likely)(bool x) { return __builtin_expect((x), true); }
+static inline bool(unlikely)(bool x) { return __builtin_expect((x), false); }
+#else
+static inline bool(likely)(bool x) { return x; }
+static inline bool(unlikely)(bool x) { return x; }
+#endif
+
+namespace xllm {
+
+template <typename T>
+constexpr size_t type_size = sizeof(T);
+
+constexpr size_t sampling_param_fixed_size() {
+  return 5 * type_size<float>      // frequency_penalty, presence_penalty,
+                                   // repetition_penalty, temperature, top_p
+         + 2 * type_size<int64_t>  // top_k, top_logprobs
+         + 3 * type_size<bool>     // logprobs, do_sample, is_embeddings
+         + type_size<int32_t>;     // beam_width
+}
+
+constexpr size_t cache_block_info_fixed_size() {
+  return type_size<int32_t> * 2 +
+         16;  // device_block_id + host_block_id + hash_key
+}
+
+INLINE size_t get_string_size(const std::string& str) {
+  return type_size<uint64_t> + str.size();
+}
+
+template <typename T>
+INLINE size_t get_vector_size(const std::vector<T>& vec) {
+  return type_size<uint64_t> + vec.size() * type_size<T>;
+}
+
+template <typename T>
+INLINE size_t get_2d_vector_size(const std::vector<std::vector<T>>& vec2d) {
+  size_t size = type_size<uint64_t>;
+  for (const auto& vec : vec2d) {
+    size += get_vector_size(vec);
+  }
+  return size;
+}
+
+INLINE size_t get_instance_info_size(const InstanceInfo& info) {
+  size_t size = get_string_size(info.name) + get_string_size(info.rpc_address) +
+                get_string_size(info.type);
+
+  size += type_size<uint64_t> + info.cluster_ids.size() * type_size<uint64_t>;
+
+  size += type_size<uint64_t>;
+  for (const auto& addr : info.addrs) {
+    size += get_string_size(addr);
+  }
+
+  size += type_size<uint64_t> + info.k_cache_ids.size() * type_size<uint64_t> +
+          type_size<uint64_t> + info.v_cache_ids.size() * type_size<uint64_t> +
+          type_size<int32_t>  // dp_size
+          + type_size<uint64_t> +
+          info.ttft_profiling_data.size() *
+              (type_size<int32_t> + type_size<int32_t>);
+
+  return size;
+}
+
+INLINE size_t get_transfer_kv_info_size(const TransferKVInfo& info) {
+  return get_string_size(info.request_id) +
+         get_vector_size(info.local_blocks_ids) +
+         get_vector_size(info.remote_blocks_ids) +
+         type_size<int32_t>  // dp_rank
+         + get_instance_info_size(info.remote_instance_info);
+}
+
+INLINE size_t get_eplb_info_size(const EplbInfo& info) {
+  return type_size<int32_t>  // prepare_layer_id
+         + get_vector_size(info.expert_ids) +
+         type_size<int32_t>;  // update_layer_id
+}
+
+INLINE size_t calculate_raw_forward_input_size(const RawForwardInput& input) {
+  size_t total = 0;
+
+  total += get_vector_size(input.flatten_tokens_vec);
+  total += get_vector_size(input.flatten_positions_vec);
+  total += get_vector_size(input.selected_token_idxes);
+  total += get_vector_size(input.sample_idxes);
+  total += get_vector_size(input.unique_token_lens_vec);
+  total += get_vector_size(input.seq_lens);
+  total += get_vector_size(input.q_seq_lens);
+  total += get_vector_size(input.new_token_slot_ids);
+  total += get_vector_size(input.dp_global_token_nums);
+  total += get_vector_size(input.embedding_ids);
+  total += get_vector_size(input.src_block_indices);
+  total += get_vector_size(input.dst_block_indices);
+  total += get_vector_size(input.cum_sum);
+  total += get_vector_size(input.new_cache_slot_offsets);
+  total +=
get_vector_size(input.kv_cache_start_offsets); + total += get_vector_size(input.acc_logprob_vec); + total += get_vector_size(input.extra_token_ids); + + total += get_2d_vector_size(input.unique_token_ids_vec); + total += get_2d_vector_size(input.unique_token_counts_vec); + total += get_2d_vector_size(input.block_tables_vec); + total += get_2d_vector_size(input.embeddings); + + total += type_size + + input.sampling_params.size() * sampling_param_fixed_size(); + + total += type_size; + for (const auto& t : input.transfer_kv_infos) { + total += get_transfer_kv_info_size(t); + } + + const size_t cache_block_size = + input.async_copy_out_blocks.size() + input.copy_out_blocks.size() + + input.copy_in_blocks.size() + input.swap_blocks.size(); + total += type_size * 4 + + cache_block_size * cache_block_info_fixed_size(); + + total += type_size * 2 // empty_kv_cache + global_empty_kv_cache + + type_size * + 3 // max_seq_len + q_max_seq_len + prefill_seq_len + + type_size // num_sequences + + get_eplb_info_size(input.eplb_info); + + return total; +} + +template +INLINE void write_data(char*& buffer, const T& data) { + *reinterpret_cast(buffer) = data; + buffer += type_size; +} + +INLINE void write_string(char*& buffer, const std::string& str) { + const uint64_t len = str.size(); + write_data(buffer, len); + if (len > 0) { + std::memcpy(buffer, str.data(), len); + buffer += len; + } +} + +INLINE void write_sampling_param(char*& buffer, + const RequestSamplingParam& param) { + char* ptr = buffer; + *reinterpret_cast(ptr) = param.frequency_penalty; + ptr += type_size; + *reinterpret_cast(ptr) = param.presence_penalty; + ptr += type_size; + *reinterpret_cast(ptr) = param.repetition_penalty; + ptr += type_size; + *reinterpret_cast(ptr) = param.temperature; + ptr += type_size; + *reinterpret_cast(ptr) = param.top_p; + ptr += type_size; + *reinterpret_cast(ptr) = param.top_k; + ptr += type_size; + *reinterpret_cast(ptr) = param.logprobs; + ptr += type_size; + *reinterpret_cast(ptr) = param.top_logprobs; + ptr += type_size; + *reinterpret_cast(ptr) = param.do_sample; + ptr += type_size; + *reinterpret_cast(ptr) = param.is_embeddings; + ptr += type_size; + *reinterpret_cast(ptr) = param.beam_width; + ptr += type_size; + buffer = ptr; +} + +template +INLINE void write_vector(char*& buffer, const std::vector& vec) { + const uint64_t size = vec.size(); + write_data(buffer, size); + if (size > 0) { + const size_t bytes = size * type_size; + std::memcpy(buffer, vec.data(), bytes); + buffer += bytes; + } +} + +template +INLINE void write_2d_vector(char*& buffer, + const std::vector>& vec2d) { + write_data(buffer, (uint64_t)vec2d.size()); + for (const auto& vec : vec2d) { + write_vector(buffer, vec); + } +} + +INLINE void write_instance_info(char*& buffer, const InstanceInfo& info) { + write_string(buffer, info.name); + write_string(buffer, info.rpc_address); + write_string(buffer, info.type); + + write_vector(buffer, info.cluster_ids); + + write_data(buffer, (uint64_t)info.addrs.size()); + for (const auto& addr : info.addrs) { + write_string(buffer, addr); + } + + write_vector(buffer, info.k_cache_ids); + write_vector(buffer, info.v_cache_ids); + write_data(buffer, info.dp_size); + + const uint64_t prof_size = info.ttft_profiling_data.size(); + write_data(buffer, prof_size); + if (prof_size > 0) { + std::memcpy(buffer, + info.ttft_profiling_data.data(), + prof_size * sizeof(std::pair)); + buffer += prof_size * sizeof(std::pair); + } +} + +INLINE void write_transfer_kv_info(char*& buffer, const TransferKVInfo& 
info) { + write_string(buffer, info.request_id); + write_vector(buffer, info.local_blocks_ids); + write_vector(buffer, info.remote_blocks_ids); + write_data(buffer, info.dp_rank); + write_instance_info(buffer, info.remote_instance_info); +} + +INLINE void write_eplb_info(char*& buffer, const EplbInfo& info) { + write_data(buffer, info.prepare_layer_id); + write_vector(buffer, info.expert_ids); + write_data(buffer, info.update_layer_id); +} + +INLINE void write_cache_block_info(char*& buffer, const CacheBlockInfo& info) { + *reinterpret_cast(buffer) = info.device_block_id; + *reinterpret_cast(buffer + 4) = info.host_block_id; + if (info.hash_key) { + std::memcpy(buffer + 8, info.hash_key, 16); + } else { + std::memset(buffer + 8, 0, 16); + } + buffer += cache_block_info_fixed_size(); +} + +INLINE void write_cache_blocks(char*& buffer, + const std::vector& blocks) { + write_data(buffer, (uint64_t)blocks.size()); + if constexpr (sizeof(CacheBlockInfo) == cache_block_info_fixed_size()) { + if (!blocks.empty()) { + std::memcpy( + buffer, blocks.data(), blocks.size() * cache_block_info_fixed_size()); + buffer += blocks.size() * cache_block_info_fixed_size(); + } + } else { + for (const auto& b : blocks) { + write_cache_block_info(buffer, b); + } + } +} + +template +INLINE void read_data(const char*& buffer, T& data) { + data = *reinterpret_cast(buffer); + buffer += type_size; +} + +INLINE void read_string(const char*& buffer, std::string& str) { + uint64_t len; + read_data(buffer, len); + if (len > 0) { + str.assign(buffer, len); + buffer += len; + } else { + str.clear(); + } +} + +template +INLINE void read_vector(const char*& buffer, std::vector& vec) { + uint64_t size; + read_data(buffer, size); + vec.resize(size); + if (size > 0) { + const size_t bytes = size * type_size; + std::memcpy(vec.data(), buffer, bytes); + buffer += bytes; + } +} + +template +INLINE void read_2d_vector(const char*& buffer, + std::vector>& vec2d) { + uint64_t size; + read_data(buffer, size); + vec2d.resize(size); + for (auto& vec : vec2d) { + read_vector(buffer, vec); + } +} + +INLINE void read_sampling_param(const char*& buffer, + RequestSamplingParam& param) { + const char* ptr = buffer; + param.frequency_penalty = *reinterpret_cast(ptr); + ptr += type_size; + param.presence_penalty = *reinterpret_cast(ptr); + ptr += type_size; + param.repetition_penalty = *reinterpret_cast(ptr); + ptr += type_size; + param.temperature = *reinterpret_cast(ptr); + ptr += type_size; + param.top_p = *reinterpret_cast(ptr); + ptr += type_size; + param.top_k = *reinterpret_cast(ptr); + ptr += type_size; + param.logprobs = *reinterpret_cast(ptr); + ptr += type_size; + param.top_logprobs = *reinterpret_cast(ptr); + ptr += type_size; + param.do_sample = *reinterpret_cast(ptr); + ptr += type_size; + param.is_embeddings = *reinterpret_cast(ptr); + ptr += type_size; + param.beam_width = *reinterpret_cast(ptr); + ptr += type_size; + buffer = ptr; +} + +INLINE void read_instance_info(const char*& buffer, InstanceInfo& info) { + read_string(buffer, info.name); + read_string(buffer, info.rpc_address); + read_string(buffer, info.type); + + read_vector(buffer, info.cluster_ids); + + uint64_t addr_count; + read_data(buffer, addr_count); + info.addrs.resize(addr_count); + for (auto& addr : info.addrs) { + read_string(buffer, addr); + } + + read_vector(buffer, info.k_cache_ids); + read_vector(buffer, info.v_cache_ids); + read_data(buffer, info.dp_size); + + uint64_t prof_size; + read_data(buffer, prof_size); + 
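+  // the ttft profiling samples were memcpy'd as one packed array of pairs on
+  // the write side, so size the vector first and copy them back in one shot.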
info.ttft_profiling_data.resize(prof_size); + if (prof_size > 0) { + std::memcpy(info.ttft_profiling_data.data(), + buffer, + prof_size * sizeof(std::pair)); + buffer += prof_size * sizeof(std::pair); + } +} + +INLINE void read_transfer_kv_info(const char*& buffer, TransferKVInfo& info) { + read_string(buffer, info.request_id); + read_vector(buffer, info.local_blocks_ids); + read_vector(buffer, info.remote_blocks_ids); + read_data(buffer, info.dp_rank); + read_instance_info(buffer, info.remote_instance_info); +} + +INLINE void read_eplb_info(const char*& buffer, EplbInfo& info) { + read_data(buffer, info.prepare_layer_id); + read_vector(buffer, info.expert_ids); + read_data(buffer, info.update_layer_id); +} + +INLINE void read_cache_block_info(const char*& buffer, CacheBlockInfo& info) { + info.device_block_id = *reinterpret_cast(buffer); + info.host_block_id = *reinterpret_cast(buffer + 4); + // notice: a temporary pointer in the buffer is stored here + info.hash_key = + const_cast(reinterpret_cast(buffer + 8)); + buffer += 8 + 16; +} + +INLINE void read_cache_blocks(const char*& buffer, + std::vector& blocks) { + uint64_t size; + read_data(buffer, size); + blocks.resize(size); + for (auto& block : blocks) { + read_cache_block_info(buffer, block); + } +} + +INLINE void deserialize_raw_forward_input( + const char*& buffer, + RawForwardInput& input, + std::vector& tmp_sampling_params) { + read_vector(buffer, input.flatten_tokens_vec); + read_vector(buffer, input.flatten_positions_vec); + + uint64_t sp_count; + read_data(buffer, sp_count); + input.sampling_params.reserve(sp_count); + tmp_sampling_params.resize(sp_count); + for (size_t i = 0; i < sp_count; ++i) { + read_sampling_param(buffer, tmp_sampling_params[i]); + input.sampling_params.push_back(&tmp_sampling_params[i]); + } + + read_vector(buffer, input.selected_token_idxes); + read_vector(buffer, input.sample_idxes); + read_vector(buffer, input.unique_token_lens_vec); + read_vector(buffer, input.seq_lens); + read_vector(buffer, input.q_seq_lens); + read_vector(buffer, input.new_token_slot_ids); + read_vector(buffer, input.dp_global_token_nums); + read_vector(buffer, input.embedding_ids); + read_vector(buffer, input.src_block_indices); + read_vector(buffer, input.dst_block_indices); + read_vector(buffer, input.cum_sum); + read_vector(buffer, input.new_cache_slot_offsets); + read_vector(buffer, input.kv_cache_start_offsets); + read_vector(buffer, input.extra_token_ids); + read_vector(buffer, input.acc_logprob_vec); + + read_2d_vector(buffer, input.unique_token_ids_vec); + read_2d_vector(buffer, input.unique_token_counts_vec); + read_2d_vector(buffer, input.block_tables_vec); + read_2d_vector(buffer, input.embeddings); + + uint64_t transfer_count; + read_data(buffer, transfer_count); + input.transfer_kv_infos.resize(transfer_count); + for (auto& transfer : input.transfer_kv_infos) { + read_transfer_kv_info(buffer, transfer); + } + + read_cache_blocks(buffer, input.async_copy_out_blocks); + read_cache_blocks(buffer, input.copy_out_blocks); + read_cache_blocks(buffer, input.copy_in_blocks); + read_cache_blocks(buffer, input.swap_blocks); + + read_data(buffer, input.empty_kv_cache); + read_data(buffer, input.global_empty_kv_cache); + read_data(buffer, input.max_seq_len); + read_data(buffer, input.q_max_seq_len); + read_data(buffer, input.num_sequences); + read_eplb_info(buffer, input.eplb_info); + read_data(buffer, input.prefill_seq_len); +} + +INLINE void serialize_raw_forward_input(const RawForwardInput& input, + char*& buffer) { + 
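+  // NOTE: the field order below is the wire format; deserialize_raw_forward_input
+  // and calculate_raw_forward_input_size must stay in lockstep with it.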
write_vector(buffer, input.flatten_tokens_vec); + write_vector(buffer, input.flatten_positions_vec); + + const uint64_t sp_count = input.sampling_params.size(); + write_data(buffer, sp_count); + + for (const auto* sp : input.sampling_params) { + write_sampling_param(buffer, *sp); + } + + write_vector(buffer, input.selected_token_idxes); + write_vector(buffer, input.sample_idxes); + write_vector(buffer, input.unique_token_lens_vec); + write_vector(buffer, input.seq_lens); + write_vector(buffer, input.q_seq_lens); + write_vector(buffer, input.new_token_slot_ids); + write_vector(buffer, input.dp_global_token_nums); + write_vector(buffer, input.embedding_ids); + write_vector(buffer, input.src_block_indices); + write_vector(buffer, input.dst_block_indices); + write_vector(buffer, input.cum_sum); + write_vector(buffer, input.new_cache_slot_offsets); + write_vector(buffer, input.kv_cache_start_offsets); + write_vector(buffer, input.extra_token_ids); + write_vector(buffer, input.acc_logprob_vec); + + write_2d_vector(buffer, input.unique_token_ids_vec); + write_2d_vector(buffer, input.unique_token_counts_vec); + write_2d_vector(buffer, input.block_tables_vec); + write_2d_vector(buffer, input.embeddings); + + write_data(buffer, (uint64_t)input.transfer_kv_infos.size()); + for (const auto& t : input.transfer_kv_infos) { + write_transfer_kv_info(buffer, t); + } + + write_cache_blocks(buffer, input.async_copy_out_blocks); + write_cache_blocks(buffer, input.copy_out_blocks); + write_cache_blocks(buffer, input.copy_in_blocks); + write_cache_blocks(buffer, input.swap_blocks); + + *reinterpret_cast(buffer) = input.empty_kv_cache; + buffer += 1; + *reinterpret_cast(buffer) = input.global_empty_kv_cache; + buffer += 1; + *reinterpret_cast(buffer) = input.max_seq_len; + buffer += 4; + *reinterpret_cast(buffer) = input.q_max_seq_len; + buffer += 4; + *reinterpret_cast(buffer) = input.num_sequences; + buffer += 4; + write_eplb_info(buffer, input.eplb_info); + *reinterpret_cast(buffer) = input.prefill_seq_len; + buffer += 4; +} + +size_t calculate_raw_token_size(const RawToken& token) { + size_t size = type_size; // id + + size += type_size; + if (token.logprob.has_value()) { + size += type_size; + } + + size += type_size + token.top_tokens.size() * type_size; + size += type_size + token.top_logprobs.size() * type_size; + size += type_size + token.embeddings.size() * type_size; + + return size; +} + +size_t calculate_raw_sample_output_size(const RawSampleOutput& sample) { + size_t size = type_size; + for (const auto& token : sample.tokens) { + size += calculate_raw_token_size(token); + } + return size; +} + +size_t calculate_raw_forward_output_size(const RawForwardOutput& output) { + size_t size = 0; + + size += type_size; + for (const auto& sample : output.outputs) { + size += calculate_raw_sample_output_size(sample); + } + + size += get_vector_size(output.expert_load_data); + size += get_vector_size(output.src_seq_idxes); + size += get_vector_size(output.out_tokens); + size += get_vector_size(output.out_logprobs); + size += type_size; // prepared_layer_id + + return size; +} + +void write_raw_token(char*& buffer, const RawToken& token) { + write_data(buffer, token.id); + + write_data(buffer, token.logprob.has_value()); + if (token.logprob.has_value()) { + write_data(buffer, token.logprob.value()); + } + + write_vector(buffer, token.top_tokens); + write_vector(buffer, token.top_logprobs); + write_vector(buffer, token.embeddings); +} + +void write_raw_sample_output(char*& buffer, const RawSampleOutput& sample) { 
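+  // a sample is framed as a u64 token count followed by each token's payload
+  // (id, optional logprob, top-k ids/logprobs, embeddings).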
+  write_data(buffer, static_cast<uint64_t>(sample.tokens.size()));
+  for (const auto& token : sample.tokens) {
+    write_raw_token(buffer, token);
+  }
+}
+
+void read_raw_token(const char*& buffer, RawToken& token) {
+  read_data(buffer, token.id);
+
+  bool has_logprob;
+  read_data(buffer, has_logprob);
+  if (has_logprob) {
+    float logprob_val;
+    read_data(buffer, logprob_val);
+    token.logprob = logprob_val;
+  } else {
+    token.logprob = std::nullopt;
+  }
+
+  read_vector(buffer, token.top_tokens);
+  read_vector(buffer, token.top_logprobs);
+  read_vector(buffer, token.embeddings);
+}
+
+void read_raw_sample_output(const char*& buffer, RawSampleOutput& sample) {
+  uint64_t token_count;
+  read_data(buffer, token_count);
+  sample.tokens.resize(token_count);
+  for (auto& token : sample.tokens) {
+    read_raw_token(buffer, token);
+  }
+}
+
+void deserialize_raw_forward_output(const char* buffer,
+                                    RawForwardOutput& output) {
+  uint64_t outputs_count;
+  read_data(buffer, outputs_count);
+  output.outputs.resize(outputs_count);
+  for (auto& sample : output.outputs) {
+    read_raw_sample_output(buffer, sample);
+  }
+
+  read_vector(buffer, output.expert_load_data);
+  read_vector(buffer, output.src_seq_idxes);
+  read_vector(buffer, output.out_tokens);
+  read_vector(buffer, output.out_logprobs);
+
+  read_data(buffer, output.prepared_layer_id);
+}
+
+void serialize_raw_forward_output(const RawForwardOutput& output,
+                                  char*& buffer) {
+  write_data(buffer, static_cast<uint64_t>(output.outputs.size()));
+  for (const auto& sample : output.outputs) {
+    write_raw_sample_output(buffer, sample);
+  }
+
+  write_vector(buffer, output.expert_load_data);
+  write_vector(buffer, output.src_seq_idxes);
+  write_vector(buffer, output.out_tokens);
+  write_vector(buffer, output.out_logprobs);
+
+  write_data(buffer, output.prepared_layer_id);
+}
+
+ForwardSharedMemoryManager::ForwardSharedMemoryManager(const std::string& name,
+                                                       size_t size,
+                                                       bool& is_creator,
+                                                       ForwardType type)
+    : SharedMemoryManager(name, size, is_creator), forward_type_(type) {
+  control_ptr_ = static_cast<ControlMetadata*>(base_address());
+  metadata_addr_ =
+      static_cast<char*>(base_address()) + sizeof(ControlMetadata);
+}
+
+ForwardSharedMemoryManager::~ForwardSharedMemoryManager() = default;
+
+/* A stale shared memory file with the same name may survive `kill -9 xllm`,
+   but this doesn't affect usage. */
+std::string ForwardSharedMemoryManager::create_unique_name(int dp_group,
+                                                           int forward_type,
+                                                           int rank) {
+  std::string filename = "xllm_" + net::extract_port(FLAGS_master_node_addr);
+  if (forward_type == FORWARD_PB_INPUT_TYPE ||
+      forward_type == FORWARD_RAW_INPUT_TYPE) {
+    filename += "_dpg_" + std::to_string(dp_group) + "_input";
+  } else if (forward_type == FORWARD_PB_OUTPUT_TYPE ||
+             forward_type == FORWARD_RAW_OUTPUT_TYPE) {
+    filename += "_rank_" + std::to_string(rank) + "_output";
+  } else {
+    // TODO: support more types later
+  }
+
+  return filename;
+}
+
+bool ForwardSharedMemoryManager::raw_input_write(
+    const std::vector<RawForwardInput>& inputs) {
+  // account for the control header plus the u64 input-count prefix.
+  uint64_t total_size = sizeof(ControlMetadata) + type_size<uint64_t>;
+  for (const auto& input : inputs) {
+    total_size += calculate_raw_forward_input_size(input);
+  }
+  if (unlikely(total_size > size())) {
+    LOG(ERROR) << "raw input size overflow, total_size: " << total_size
+               << ", shm size: " << size();
+    return false;
+  }
+
+  char* data_ptr = static_cast<char*>(base_address()) + sizeof(ControlMetadata);
+  write_data(data_ptr, static_cast<uint64_t>(inputs.size()));
+  for (const auto& input : inputs) {
+    serialize_raw_forward_input(input, data_ptr);
+  }
+  std::atomic_thread_fence(std::memory_order_release);
+  control_ptr_->version = ++last_version_;
+
+  return true;
+}
+
+void convert_raw_forward_input_to_forward_input(RawForwardInput& raw_input,
+                                                ForwardInput& forward_input) {
+  auto tensor_options = torch::TensorOptions()
+                            .dtype(torch::kInt)
.device(torch::kCPU) + .pinned_memory(true); + + forward_input.token_ids = + torch::tensor(std::move(raw_input.flatten_tokens_vec), tensor_options); + forward_input.positions = + torch::tensor(std::move(raw_input.flatten_positions_vec), tensor_options); + + std::pair decode_seq_range{0, 0}; +#if defined(USE_NPU) + if (raw_input.q_seq_lens.size() >= 1) { + decode_seq_range = util::find_ones_indices(raw_input.q_seq_lens); + } +#endif + auto& input_params = forward_input.input_params; + input_params.empty_kv_cache = raw_input.empty_kv_cache; + input_params.global_empty_kv_cache = raw_input.global_empty_kv_cache; + input_params.num_sequences = raw_input.num_sequences; + input_params.kv_max_seq_len = raw_input.max_seq_len; + input_params.q_max_seq_len = raw_input.q_max_seq_len; + input_params.prefill_seq_len = raw_input.prefill_seq_len; + input_params.embedding_ids = std::move(raw_input.embedding_ids); + input_params.dp_global_token_nums = std::move(raw_input.dp_global_token_nums); + + input_params.kv_seq_lens = + torch::tensor(std::move(raw_input.seq_lens), tensor_options); + input_params.q_seq_lens = + torch::tensor(std::move(raw_input.q_seq_lens), tensor_options); + input_params.kv_seq_lens_vec = std::move(raw_input.seq_lens); + input_params.q_seq_lens_vec = std::move(raw_input.q_seq_lens); + + input_params.new_cache_slots = + torch::tensor(std::move(raw_input.new_token_slot_ids), tensor_options); + input_params.decode_seq_range = decode_seq_range; + util::pad_2d_vector(raw_input.block_tables_vec, 0); + input_params.block_tables = + create_2d_tensor(std::move(raw_input.block_tables_vec), torch::kInt); + + input_params.src_block_indices = + torch::tensor(std::move(raw_input.src_block_indices), tensor_options); + input_params.dst_block_indices = + torch::tensor(std::move(raw_input.dst_block_indices), tensor_options); + input_params.cum_sum = + torch::tensor(std::move(raw_input.cum_sum), tensor_options); + + input_params.async_copy_out_blocks = + std::move(raw_input.async_copy_out_blocks); + input_params.copy_out_blocks = std::move(raw_input.copy_out_blocks); + input_params.copy_in_blocks = std::move(raw_input.copy_in_blocks); + input_params.swap_blocks = std::move(raw_input.swap_blocks); + input_params.extra_token_ids = std::move(raw_input.extra_token_ids); + + input_params.new_cache_slot_offsets = torch::tensor( + std::move(raw_input.new_cache_slot_offsets), tensor_options); + input_params.kv_cache_start_offsets = torch::tensor( + std::move(raw_input.kv_cache_start_offsets), tensor_options); + if (!raw_input.embeddings.empty()) { + torch::Tensor embeddings = + create_2d_tensor(std::move(raw_input.embeddings), torch::kBFloat16); + input_params.mm_data = + MMData(MMType::EMBEDDING, {{"embedding", embeddings}}); + } + + if (!raw_input.selected_token_idxes.empty()) { + util::pad_2d_vector(raw_input.unique_token_ids_vec, 0); + util::pad_2d_vector(raw_input.unique_token_counts_vec, 0); + forward_input.sampling_params.init( + std::move(raw_input.sampling_params), + std::move(raw_input.selected_token_idxes), + std::move(raw_input.sample_idxes), + std::move(raw_input.unique_token_ids_vec), + std::move(raw_input.unique_token_counts_vec), + std::move(raw_input.unique_token_lens_vec)); + } + + forward_input.acc_logprob = torch::tensor( + std::move(raw_input.acc_logprob_vec), + torch::dtype(torch::kFloat32).device(torch::kCPU).pinned_memory(true)); + forward_input.transfer_kv_infos = std::move(raw_input.transfer_kv_infos); + forward_input.eplb_info = std::move(raw_input.eplb_info); +} + +void 
ForwardSharedMemoryManager::raw_input_read( + std::vector& inputs) { + while (true) { + if (control_ptr_->version != last_version_) { + last_version_ = control_ptr_->version; + break; + } + std::this_thread::sleep_for(std::chrono::nanoseconds(NUM_WAIT_NANOSECONDS)); + } + + const char* data_ptr = + static_cast(base_address()) + sizeof(ControlMetadata); + uint64_t count; + read_data(data_ptr, count); + + std::vector> tmp_sampling_params; + std::vector raw_inputs; + tmp_sampling_params.resize(count); + raw_inputs.resize(count); + for (uint64_t i = 0; i < count; ++i) { + deserialize_raw_forward_input( + data_ptr, raw_inputs[i], tmp_sampling_params[i]); + } + + // convert raw forward input to forward input + inputs.resize(raw_inputs.size()); + for (uint64_t i = 0; i < count; ++i) { + convert_raw_forward_input_to_forward_input(raw_inputs[i], inputs[i]); + } + + return; +} + +void convert_tensor_to_raw_output(const torch::Tensor& next_tokens, + const torch::Tensor& logprobs, + const torch::Tensor& top_tokens, + const torch::Tensor& top_logprobs, + const torch::Tensor& embeddings, + const torch::Tensor& expert_load_data, + int32_t prepared_layer_id, + const torch::Tensor& src_seq_idxes, + const torch::Tensor& out_tokens, + const torch::Tensor& out_logprobs, + RawForwardOutput& raw_output) { + raw_output.prepared_layer_id = prepared_layer_id; + + if (FLAGS_enable_eplb) { + torch::Tensor expert_load_data_flattened = + expert_load_data.view({-1}).contiguous(); + if (expert_load_data_flattened.defined()) { + const int64_t* data_ptr = expert_load_data_flattened.data_ptr(); + size_t size = static_cast(expert_load_data_flattened.size(0)); + raw_output.expert_load_data.assign(data_ptr, data_ptr + size); + } + } + + if (src_seq_idxes.defined() && src_seq_idxes.numel() > 0) { + const int32_t* data_ptr = src_seq_idxes.data_ptr(); + size_t size = static_cast(src_seq_idxes.size(0)); + raw_output.src_seq_idxes.assign(data_ptr, data_ptr + size); + } + + if (out_tokens.defined() && out_tokens.numel() > 0) { + const int32_t* data_ptr = out_tokens.data_ptr(); + size_t size = static_cast(out_tokens.size(0)); + raw_output.out_tokens.assign(data_ptr, data_ptr + size); + } + + if (out_logprobs.defined() && out_logprobs.numel() > 0) { + const float* data_ptr = out_logprobs.data_ptr(); + size_t size = static_cast(out_logprobs.size(0)); + raw_output.out_logprobs.assign(data_ptr, data_ptr + size); + } + + int32_t num_seqs = + next_tokens.defined() ? static_cast(next_tokens.size(0)) : 0; + if (embeddings.defined() && embeddings.numel() > 0) { + num_seqs = std::max(num_seqs, static_cast(embeddings.size(0))); + } + + raw_output.outputs.reserve(num_seqs); + for (int32_t output_idx = 0; output_idx < num_seqs; ++output_idx) { + RawSampleOutput raw_sample_output; + + if (next_tokens.defined() && next_tokens.dim() == 2) { + const auto curr_idx = output_idx; + const auto curr_next_tokens = next_tokens[curr_idx]; + const auto curr_logprobs = + logprobs.defined() ? logprobs[curr_idx] : logprobs; + const auto curr_top_tokens = + top_tokens.defined() ? top_tokens[curr_idx] : top_tokens; + const auto curr_top_logprobs = + top_logprobs.defined() ? top_logprobs[curr_idx] : top_logprobs; + const auto curr_embeddings = + embeddings.defined() ? 
embeddings[curr_idx] : embeddings;
+
+      int32_t num_tokens = curr_next_tokens.size(0);
+      raw_sample_output.tokens.reserve(num_tokens);
+
+      for (int32_t i = 0; i < num_tokens; ++i) {
+        const Token token = build_token(i,
+                                        curr_next_tokens,
+                                        curr_logprobs,
+                                        curr_top_tokens,
+                                        curr_top_logprobs);
+        if (token.id == -1) {
+          break;
+        }
+
+        RawToken raw_token;
+        raw_token.id = token.id;
+        raw_token.logprob = token.logprob;
+        raw_token.top_tokens = token.top_tokens;
+        raw_token.top_logprobs = token.top_logprobs;
+
+        if (curr_embeddings.defined()) {
+          const auto token_embeddings = curr_embeddings[i];
+          if (token_embeddings.defined()) {
+            const float* emb_ptr = token_embeddings.data_ptr<float>();
+            size_t emb_size = static_cast<size_t>(token_embeddings.size(0));
+            raw_token.embeddings.assign(emb_ptr, emb_ptr + emb_size);
+          }
+        }
+
+        raw_sample_output.tokens.push_back(std::move(raw_token));
+      }
+    } else {
+      RawToken raw_token;
+
+      if (next_tokens.defined() && next_tokens.numel() > 0) {
+        const Token token = build_token(
+            output_idx, next_tokens, logprobs, top_tokens, top_logprobs);
+        raw_token.id = token.id;
+        raw_token.logprob = token.logprob;
+        raw_token.top_tokens = token.top_tokens;
+        raw_token.top_logprobs = token.top_logprobs;
+      } else {
+        raw_token.id = -1;
+        raw_token.logprob = std::nullopt;
+      }
+
+      if (embeddings.defined()) {
+        const auto token_embeddings = embeddings[output_idx];
+        if (token_embeddings.defined()) {
+          const float* emb_ptr = token_embeddings.data_ptr<float>();
+          size_t emb_size = static_cast<size_t>(token_embeddings.size(0));
+          raw_token.embeddings.assign(emb_ptr, emb_ptr + emb_size);
+        }
+      }
+
+      raw_sample_output.tokens.push_back(std::move(raw_token));
+    }
+    raw_output.outputs.push_back(std::move(raw_sample_output));
+  }
+}
+
+bool ForwardSharedMemoryManager::raw_output_write(
+    const torch::Tensor& next_tokens,
+    const torch::Tensor& logprobs,
+    const torch::Tensor& top_tokens,
+    const torch::Tensor& top_logprobs,
+    const torch::Tensor& embeddings,
+    const torch::Tensor& expert_load_data,
+    int32_t prepared_layer_id,
+    const torch::Tensor& src_seq_idxes,
+    const torch::Tensor& out_tokens,
+    const torch::Tensor& out_logprobs) {
+  RawForwardOutput output;
+  convert_tensor_to_raw_output(next_tokens,
+                               logprobs,
+                               top_tokens,
+                               top_logprobs,
+                               embeddings,
+                               expert_load_data,
+                               prepared_layer_id,
+                               src_seq_idxes,
+                               out_tokens,
+                               out_logprobs,
+                               output);
+  uint64_t total_size = sizeof(ControlMetadata);
+  total_size += calculate_raw_forward_output_size(output);
+  if (unlikely(total_size > size())) {
+    LOG(ERROR) << "raw output size overflow, total_size: " << total_size
+               << ", shm size: " << size();
+    return false;
+  }
+
+  char* data_ptr = static_cast<char*>(base_address()) + sizeof(ControlMetadata);
+  serialize_raw_forward_output(output, data_ptr);
+  std::atomic_thread_fence(std::memory_order_release);
+  control_ptr_->version = ++last_version_;
+
+  return true;
+}
+
+void ForwardSharedMemoryManager::raw_output_read(RawForwardOutput& output) {
+  while (true) {
+    if (control_ptr_->version != last_version_) {
+      last_version_ = control_ptr_->version;
+      break;
+    }
+    std::this_thread::sleep_for(std::chrono::nanoseconds(NUM_WAIT_NANOSECONDS));
+  }
+
+  const char* data_ptr =
+      static_cast<const char*>(base_address()) + sizeof(ControlMetadata);
+  deserialize_raw_forward_output(data_ptr, output);
+
+  return;
+}
+
+void ForwardSharedMemoryManager::clear() {
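+  // zeroing also resets ControlMetadata::version to 0, so only call this
+  // while no peer is polling the segment with a cached last_version_.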
std::memset(base_address(), 0, size()); +} +} // namespace xllm \ No newline at end of file diff --git a/xllm/core/runtime/forward_shared_memory_manager.h b/xllm/core/runtime/forward_shared_memory_manager.h new file mode 100644 index 00000000..f94b5b96 --- /dev/null +++ b/xllm/core/runtime/forward_shared_memory_manager.h @@ -0,0 +1,124 @@ +/* Copyright 2025 The xLLM Authors. All Rights Reserved. +Copyright 2024 The ScaleLLM Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + https://github.com/jd-opensource/xllm/blob/main/LICENSE +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#pragma once + +#include + +#include "forward_params.h" +#include "params_utils.h" +#include "util/shared_memory_manager.h" + +#define PB_INPUT_SHM_SIZE (1024 * 1024 * 1024) // 1GB +#define PB_OUTPUT_SHM_SIZE (128 * 1024 * 1024) // 128MB +#define NUM_WAIT_NANOSECONDS (1000) // 1us + +namespace xllm { + +struct ControlMetadata { + volatile uint64_t version; +}; + +struct PbMetadata { + uint64_t pb_size; +}; + +enum ForwardType : int { + FORWARD_PB_INPUT_TYPE = 1, + FORWARD_PB_OUTPUT_TYPE = 2, + FORWARD_RAW_INPUT_TYPE = 3, + FORWARD_RAW_OUTPUT_TYPE = 4, +}; + +class ForwardSharedMemoryManager : public SharedMemoryManager { + public: + explicit ForwardSharedMemoryManager(const std::string& name, + size_t size, + bool& is_creator, + ForwardType type); + ~ForwardSharedMemoryManager(); + static std::string create_unique_name(int dp_group, + int forward_type, + int rank); + + template + bool pb_write(const PbType* pb_data) { + size_t data_size = pb_data->ByteSizeLong(); + if (data_size + sizeof(ControlMetadata) + sizeof(PbMetadata) > size()) { + LOG(ERROR) << "pb size overflow, data_size: " << data_size + << ", shm size: " << size(); + return false; + } + + auto metadata = reinterpret_cast(metadata_addr_); + metadata->pb_size = data_size; + + auto data_ptr = + reinterpret_cast(metadata_addr_) + sizeof(PbMetadata); + if (!pb_data->SerializeToArray(data_ptr, data_size)) { + LOG(ERROR) << "Failed to serialize protobuf data to shared memory"; + return false; + } + + std::atomic_thread_fence(std::memory_order_release); + control_ptr_->version = ++last_version_; + + return true; + }; + + template + bool pb_read(PbType& pb_data) { + while (true) { + if (control_ptr_->version != last_version_) { + last_version_ = control_ptr_->version; + break; + } + std::this_thread::sleep_for( + std::chrono::nanoseconds(NUM_WAIT_NANOSECONDS)); + } + + auto metadata = reinterpret_cast(metadata_addr_); + auto data_ptr = + reinterpret_cast(metadata_addr_) + sizeof(PbMetadata); + size_t pb_size = metadata->pb_size; + if (!pb_data.ParseFromArray(data_ptr, pb_size)) { + LOG(ERROR) << "Failed to parse pb data from shared memory"; + return false; + } + + return true; + }; + + bool raw_input_write(const std::vector& inputs); + void raw_input_read(std::vector& inputs); + bool raw_output_write(const torch::Tensor& next_tokens, + const torch::Tensor& logprobs, + const torch::Tensor& top_tokens, + const torch::Tensor& top_logprobs, + const torch::Tensor& embeddings, + 
const torch::Tensor& expert_load_data, + int32_t prepared_layer_id, + const torch::Tensor& src_seq_idxes, + const torch::Tensor& out_tokens, + const torch::Tensor& out_logprobs); + void raw_output_read(RawForwardOutput& outputs); + + void clear(); + + private: + ForwardType forward_type_; + uint64_t last_version_ = 0; + void* metadata_addr_ = nullptr; + ControlMetadata* control_ptr_ = nullptr; +}; +} // namespace xllm \ No newline at end of file diff --git a/xllm/core/runtime/master.cpp b/xllm/core/runtime/master.cpp index 8b1a4590..c4418e73 100644 --- a/xllm/core/runtime/master.cpp +++ b/xllm/core/runtime/master.cpp @@ -111,7 +111,8 @@ Master::Master(const Options& options, EngineType type) : options_(options) { .enable_cache_upload(options_.enable_cache_upload()) .enable_schedule_overlap(options_.enable_schedule_overlap()) .enable_offline_inference(options_.enable_offline_inference()) - .spawn_worker_path(options_.spawn_worker_path()); + .spawn_worker_path(options_.spawn_worker_path()) + .is_local(options_.is_local()); auto engine = std::make_unique(eng_options); engine_ = std::move(engine); @@ -152,7 +153,8 @@ Master::Master(const Options& options, EngineType type) : options_(options) { .enable_schedule_overlap(options_.enable_schedule_overlap()) .enable_cache_upload(options_.enable_cache_upload()) .enable_offline_inference(options_.enable_offline_inference()) - .spawn_worker_path(options_.spawn_worker_path()); + .spawn_worker_path(options_.spawn_worker_path()) + .is_local(options_.is_local()); if (options_.device_ip().has_value()) { spec_options.device_ip(options_.device_ip().value()); @@ -198,7 +200,8 @@ Master::Master(const Options& options, EngineType type) : options_(options) { .store_metadata_connstring(options_.store_metadata_connstring()) .enable_continuous_kvcache(options_.enable_continuous_kvcache()) .enable_offline_inference(options_.enable_offline_inference()) - .spawn_worker_path(options_.spawn_worker_path()); + .spawn_worker_path(options_.spawn_worker_path()) + .is_local(options_.is_local()); if (options_.device_ip().has_value()) { eng_options.device_ip(options_.device_ip().value()); diff --git a/xllm/core/runtime/master.h b/xllm/core/runtime/master.h index ee2ded92..c4d3dfc0 100644 --- a/xllm/core/runtime/master.h +++ b/xllm/core/runtime/master.h @@ -27,7 +27,6 @@ limitations under the License. #include "common/types.h" #include "framework/request/request_params.h" #include "runtime/engine.h" - namespace xllm { class Master { @@ -63,7 +62,6 @@ class Master { protected: Options options_; std::unique_ptr engine_; - RateLimiter rate_limiter_; }; diff --git a/xllm/core/runtime/options.h b/xllm/core/runtime/options.h index 9d2c03ec..062b63be 100644 --- a/xllm/core/runtime/options.h +++ b/xllm/core/runtime/options.h @@ -162,6 +162,9 @@ struct Options { PROPERTY(bool, enable_offline_inference) = false; // the path to spawn worker binary PROPERTY(std::string, spawn_worker_path) = ""; + + // whether the worker and master are on the same machine. 
+ PROPERTY(bool, is_local) = false; }; } // namespace runtime diff --git a/xllm/core/runtime/params_utils.cpp b/xllm/core/runtime/params_utils.cpp index 90e2b4c3..428c0c3e 100644 --- a/xllm/core/runtime/params_utils.cpp +++ b/xllm/core/runtime/params_utils.cpp @@ -335,6 +335,7 @@ void proto_to_forward_input(const proto::ForwardInput* pb_forward_input, std::vector(pb_forward_input->eplb_info().expert_ids().begin(), pb_forward_input->eplb_info().expert_ids().end()); eplb_info.update_layer_id = pb_forward_input->eplb_info().update_layer_id(); + COUNTER_ADD(proto_latency_seconds_proto2i, timer.elapsed_seconds()); } @@ -552,6 +553,7 @@ void proto_to_forward_output(const proto::ForwardOutput& pb_output, } raw_forward_output.outputs.emplace_back(s); } + COUNTER_ADD(proto_latency_seconds_proto2o, timer.elapsed_seconds()); } diff --git a/xllm/core/util/CMakeLists.txt b/xllm/core/util/CMakeLists.txt index 83dc2dc7..3318822e 100644 --- a/xllm/core/util/CMakeLists.txt +++ b/xllm/core/util/CMakeLists.txt @@ -27,6 +27,7 @@ cc_library( type_traits.h utils.h uuid.h + shared_memory_manager.h SRCS device_name_utils.cpp env_var.cpp @@ -37,6 +38,7 @@ cc_library( timer.cpp utils.cpp uuid.cpp + shared_memory_manager.cpp DEPS torch brpc diff --git a/xllm/core/util/net.cpp b/xllm/core/util/net.cpp index d7d6e476..e407bbb7 100644 --- a/xllm/core/util/net.cpp +++ b/xllm/core/util/net.cpp @@ -97,6 +97,28 @@ uint64_t convert_ip_port_to_uint64(const std::string& ip, uint16_t port) { uint32_t ip_network = ip_addr.s_addr; return (static_cast(ip_network) << 32) | port; } +// input example: 127.0.0.1:18889 +std::string extract_ip(const std::string& input) { + std::istringstream stream(input); + std::string ip; + + std::getline(stream, ip, ':'); + if (ip == "127.0.0.1") { + ip = get_local_ip_addr(); + } + return ip; +} + +std::string extract_port(const std::string& input) { + std::istringstream stream(input); + std::string ip; + std::string port; + + std::getline(stream, ip, ':'); + std::getline(stream, port, ':'); + + return port; +} void parse_host_port_from_addr(const std::string& addr, std::string& host, diff --git a/xllm/core/util/net.h b/xllm/core/util/net.h index 005e448b..01190d10 100644 --- a/xllm/core/util/net.h +++ b/xllm/core/util/net.h @@ -28,5 +28,7 @@ void parse_host_port_from_addr(const std::string& addr, std::string& host, int& port); +std::string extract_ip(const std::string& input); +std::string extract_port(const std::string& input); } // namespace net } // namespace xllm diff --git a/xllm/core/framework/eplb/shared_memory_manager.cpp b/xllm/core/util/shared_memory_manager.cpp similarity index 83% rename from xllm/core/framework/eplb/shared_memory_manager.cpp rename to xllm/core/util/shared_memory_manager.cpp index e81a1dcb..eaecb31f 100644 --- a/xllm/core/framework/eplb/shared_memory_manager.cpp +++ b/xllm/core/util/shared_memory_manager.cpp @@ -15,7 +15,13 @@ limitations under the License. #include "shared_memory_manager.h" +#include +#include +#include + +#include #include +#include namespace xllm { std::vector SharedMemoryManager::pending_cleanups; @@ -61,6 +67,9 @@ SharedMemoryManager::SharedMemoryManager(const std::string& name, close(fd_); LOG(FATAL) << "mmap failed: " << strerror(errno); } + + // Initialize memory to zero. 
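+  // (a newly created segment is already zero-filled by ftruncate, but an
+  // attached pre-existing one may still hold a stale version counter.)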
+ std::memset(addr_, 0, size_); } SharedMemoryManager::~SharedMemoryManager() { @@ -95,20 +104,4 @@ void SharedMemoryManager::cleanup_handler(int sig) { exit(sig); } -void* SharedMemoryManager::allocate(int64_t size, int64_t alignment) { - std::lock_guard lock(mutex_); - - // Calculate aligned size and check bounds - int64_t aligned_size = (size + alignment - 1) & ~(alignment - 1); - if (current_offset_ + aligned_size > size_) { - LOG(FATAL) << "Shared memory overflow, size_ = " << size_ - << ", aligned_size = " << aligned_size - << ", current_offset_ = " << current_offset_; - } - - // Return current offset and advance - void* ptr = static_cast(addr_) + current_offset_; - current_offset_ += aligned_size; - return ptr; -} } // namespace xllm \ No newline at end of file diff --git a/xllm/core/framework/eplb/shared_memory_manager.h b/xllm/core/util/shared_memory_manager.h similarity index 92% rename from xllm/core/framework/eplb/shared_memory_manager.h rename to xllm/core/util/shared_memory_manager.h index 9608d9f9..037bd766 100644 --- a/xllm/core/framework/eplb/shared_memory_manager.h +++ b/xllm/core/util/shared_memory_manager.h @@ -18,11 +18,11 @@ limitations under the License. #include #include #include -#include #include #include #include +#include namespace xllm { @@ -31,11 +31,10 @@ class SharedMemoryManager { explicit SharedMemoryManager(const std::string& name, size_t size, bool& is_creator); - ~SharedMemoryManager(); - void* allocate(int64_t size, int64_t alignment = alignof(max_align_t)); void* base_address() const { return addr_; } int64_t size() const { return size_; } + std::string name() const { return shm_name_; } private: std::string shm_name_; @@ -43,7 +42,6 @@ class SharedMemoryManager { void* addr_ = MAP_FAILED; int64_t size_ = 0; int64_t current_offset_ = 0; - std::mutex mutex_; static void cleanup_handler(int sig); static std::vector pending_cleanups; diff --git a/xllm/core/util/timer.cpp b/xllm/core/util/timer.cpp index 74f32628..f7964ba3 100644 --- a/xllm/core/util/timer.cpp +++ b/xllm/core/util/timer.cpp @@ -31,4 +31,13 @@ double Timer::elapsed_seconds() const { return absl::ToDoubleSeconds(absl::Now() - start_); } +// get the elapsed time in milliseconds +double Timer::elapsed_milliseconds() const { + return absl::ToDoubleMilliseconds(absl::Now() - start_); +} + +// get the elapsed time in microseconds +double Timer::elapsed_microseconds() const { + return absl::ToDoubleMicroseconds(absl::Now() - start_); +} } // namespace xllm \ No newline at end of file diff --git a/xllm/core/util/timer.h b/xllm/core/util/timer.h index 5f1f151d..0fc103be 100644 --- a/xllm/core/util/timer.h +++ b/xllm/core/util/timer.h @@ -27,8 +27,10 @@ class Timer final { // reset the timer void reset(); - // get the elapsed time in seconds + // get the elapsed time. double elapsed_seconds() const; + double elapsed_milliseconds() const; + double elapsed_microseconds() const; private: // the start time of the timer diff --git a/xllm/xllm.cpp b/xllm/xllm.cpp index 7c8526bb..19c3e3f2 100644 --- a/xllm/xllm.cpp +++ b/xllm/xllm.cpp @@ -91,6 +91,17 @@ int run() { FLAGS_host = net::get_local_ip_addr(); } + bool is_local = false; + if (FLAGS_host != "" && + net::extract_ip(FLAGS_master_node_addr) == FLAGS_host) { + is_local = true; + } else { + is_local = false; + } + + LOG(INFO) << "set worker role to " + << (is_local ? 
"local worker" : "remote worker"); + if (FLAGS_backend == "vlm") { FLAGS_enable_prefix_cache = false; FLAGS_enable_chunked_prefill = false; @@ -169,7 +180,8 @@ int run() { .max_global_ttft_ms(FLAGS_max_global_ttft_ms) .max_global_tpot_ms(FLAGS_max_global_tpot_ms) .max_requests_per_batch(FLAGS_max_requests_per_batch) - .enable_continuous_kvcache(FLAGS_enable_continuous_kvcache); + .enable_continuous_kvcache(FLAGS_enable_continuous_kvcache) + .is_local(is_local); InstanceName::name()->set_name(options.instance_name().value_or(""));