Skip to content

Commit ff5558e

Browse files
committed
feat: support mlu dp and ep.
1 parent 8bbe6f6 commit ff5558e

28 files changed

+393
-56
lines changed

xllm/core/framework/parallel_state/collective_communicator.cpp

Lines changed: 54 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -108,21 +108,72 @@ void CollectiveCommunicator::create_process_groups_cncl(
108108
int global_rank = parallel_args_->rank();
109109
int world_size = parallel_args_->world_size();
110110
int dp_size = parallel_args_->dp_size();
111-
process_group_ = std::make_unique<ProcessGroupCncl>(
112-
global_rank, world_size, world_size, ++port, host, "world_group", device);
111+
int ep_size = parallel_args_->ep_size();
112+
process_group_ = std::make_unique<ProcessGroupCncl>(global_rank,
113+
world_size,
114+
world_size,
115+
++port,
116+
false,
117+
host,
118+
"world_group",
119+
device);
120+
parallel_args_->process_group_ = process_group_.get();
121+
113122
int tp_size = world_size / dp_size;
114123
CHECK_EQ(tp_size * dp_size, world_size);
115124
int port_offset = global_rank / tp_size + 1;
116125
tp_group_ = std::make_unique<ProcessGroupCncl>(global_rank,
117126
world_size,
118127
tp_size,
119128
port + port_offset,
129+
false,
120130
host,
121131
"tp_group",
122132
device);
123-
parallel_args_->process_group_ = process_group_.get();
124133
parallel_args_->tp_group_ = tp_group_.get();
134+
port += dp_size;
135+
136+
if (dp_size > 1) {
137+
port_offset = global_rank % tp_size + 1;
138+
dp_local_process_group_ =
139+
std::make_unique<ProcessGroupCncl>(global_rank,
140+
world_size,
141+
dp_size,
142+
port + port_offset,
143+
true,
144+
host,
145+
"dp_group",
146+
device);
147+
parallel_args_->dp_local_process_group_ = dp_local_process_group_.get();
148+
port += tp_size;
149+
}
150+
151+
if (ep_size > 1) {
152+
int moe_tp_size = world_size / ep_size;
153+
port_offset = global_rank / moe_tp_size + 1;
154+
moe_tp_group_ = std::make_unique<ProcessGroupCncl>(global_rank,
155+
world_size,
156+
moe_tp_size,
157+
port + port_offset,
158+
false,
159+
host,
160+
"moe_tp_group",
161+
device);
162+
parallel_args_->moe_tp_group_ = moe_tp_group_.get();
163+
port += ep_size;
164+
port_offset = global_rank % moe_tp_size + 1;
165+
moe_ep_group_ = std::make_unique<ProcessGroupCncl>(global_rank,
166+
world_size,
167+
ep_size,
168+
port + port_offset,
169+
true,
170+
host,
171+
"moe_ep_group",
172+
device);
173+
parallel_args_->moe_ep_group_ = moe_ep_group_.get();
174+
}
125175
}
176+
126177
#endif
127178

128179
const ParallelArgs* CollectiveCommunicator::parallel_args() {

xllm/core/framework/parallel_state/collective_communicator.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,8 @@ class CollectiveCommunicator {
4545
std::unique_ptr<ProcessGroup> dp_local_process_group_;
4646
#if defined(USE_MLU)
4747
std::unique_ptr<ProcessGroup> tp_group_;
48+
std::unique_ptr<ProcessGroup> moe_tp_group_;
49+
std::unique_ptr<ProcessGroup> moe_ep_group_;
4850
#endif
4951
};
5052

xllm/core/framework/parallel_state/mlu_process_group.cpp

Lines changed: 25 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,29 @@ limitations under the License.
1919

2020
namespace {
2121

22+
std::pair<int, std::vector<uint64_t>> get_trans_group_rank(int world_size,
23+
int global_rank,
24+
int split_size) {
25+
int trans_group_count = split_size;
26+
int trans_group_size = world_size / split_size;
27+
int trans_group_index = global_rank % trans_group_size;
28+
int trans_index = global_rank / trans_group_size;
29+
std::vector<uint64_t> trans_group_ranks;
30+
for (int i = 0; i < trans_group_count; i++) {
31+
uint64_t rank = i * trans_group_size + trans_group_index;
32+
trans_group_ranks.push_back(rank);
33+
}
34+
35+
return {trans_index, trans_group_ranks};
36+
}
37+
2238
std::pair<int, std::vector<uint64_t>> get_group_rank(int world_size,
2339
int global_rank,
24-
int split_size) {
40+
int split_size,
41+
bool trans) {
42+
if (trans) {
43+
return get_trans_group_rank(world_size, global_rank, split_size);
44+
}
2545
int target_group_index = global_rank / split_size;
2646
uint64_t start_rank = target_group_index * split_size;
2747
uint64_t end_rank = start_rank + split_size;
@@ -41,6 +61,7 @@ ProcessGroupCncl::ProcessGroupCncl(int rank,
4161
int world_size,
4262
int rank_size,
4363
int port,
64+
bool trans,
4465
const std::string& host,
4566
const std::string& group_name,
4667
const torch::Device& device)
@@ -52,19 +73,18 @@ ProcessGroupCncl::ProcessGroupCncl(int rank,
5273
cncl_pg_options->group_name = group_name;
5374
if (world_size != rank_size) {
5475
auto [local_rank, group_ranks] =
55-
get_group_rank(world_size, rank, rank_size);
76+
get_group_rank(world_size, rank, rank_size, trans);
5677
cncl_pg_options->global_ranks_in_group = group_ranks;
5778
rank_ = local_rank;
5879
}
5980

6081
c10d::TCPStoreOptions tcp_options;
6182
tcp_options.isServer = (rank_ == 0);
6283
tcp_options.port = port;
63-
6484
c10::intrusive_ptr<c10d::Store> store =
6585
c10::make_intrusive<c10d::TCPStore>(host, tcp_options);
6686
cncl_pg_ = std::make_unique<torch_mlu::ProcessGroupCNCL>(
67-
store, rank, world_size, cncl_pg_options);
87+
store, rank_, world_size_, cncl_pg_options);
6888
}
6989

7090
// Destructor.
@@ -75,7 +95,7 @@ void ProcessGroupCncl::allreduce(torch::Tensor& input) {
7595
cncl_pg_->allreduce(input_tensors)->wait();
7696
}
7797

78-
void ProcessGroupCncl::allgather(torch::Tensor input,
98+
void ProcessGroupCncl::allgather(const torch::Tensor& input,
7999
std::vector<torch::Tensor>& outputs) {
80100
std::vector<torch::Tensor> input_tensors = {input};
81101
std::vector<std::vector<torch::Tensor>> output_tensors = {outputs};

xllm/core/framework/parallel_state/mlu_process_group.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ class ProcessGroupCncl : public ProcessGroup {
2828
int world_size,
2929
int rank_size,
3030
int port,
31+
bool trans,
3132
const std::string& host,
3233
const std::string& group_name,
3334
const torch::Device& device);
@@ -41,7 +42,7 @@ class ProcessGroupCncl : public ProcessGroup {
4142

4243
void allreduce(torch::Tensor& input) override;
4344

44-
void allgather(torch::Tensor input,
45+
void allgather(const torch::Tensor& input,
4546
std::vector<torch::Tensor>& outputs) override;
4647

4748
private:

xllm/core/framework/parallel_state/npu_process_group.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ void ProcessGroupHCCL::allreduce(torch::Tensor& input) {
109109
// /*comm=*/comm_,
110110
// /*stream=*/stream));
111111
}
112-
void ProcessGroupHCCL::allgather(torch::Tensor input,
112+
void ProcessGroupHCCL::allgather(const torch::Tensor& input,
113113
std::vector<torch::Tensor>& outputs) {
114114
check_input(input);
115115
// CHECK(outputs.size() == world_size())

xllm/core/framework/parallel_state/npu_process_group.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ class ProcessGroupHCCL : public ProcessGroup {
3333

3434
void allreduce(torch::Tensor& input) override;
3535

36-
void allgather(torch::Tensor input,
36+
void allgather(const torch::Tensor& input,
3737
std::vector<torch::Tensor>& outputs) override;
3838

3939
private:

xllm/core/framework/parallel_state/parallel_state.cpp

Lines changed: 82 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,33 @@ limitations under the License.
1515

1616
#include "parallel_state.h"
1717

18+
#include "core/util/utils.h"
19+
1820
#if defined(USE_NPU)
1921
#include "hccl/hccl.h"
2022
#include "npu_process_group.h"
2123
#endif
2224

25+
namespace {
26+
27+
torch::Tensor remove_paddings_after_all_gather(
28+
const torch::Tensor& input,
29+
int64_t padding_to_token_num,
30+
const std::vector<int>& token_num_list) {
31+
std::vector<torch::Tensor> group_tensors;
32+
int64_t offset = 0;
33+
for (const auto& token_num : token_num_list) {
34+
if (token_num != 0) {
35+
auto tensor_slice = input.slice(0, offset, offset + token_num);
36+
group_tensors.push_back(tensor_slice);
37+
}
38+
offset += padding_to_token_num;
39+
}
40+
41+
return torch::cat(group_tensors).contiguous();
42+
}
43+
} // namespace
44+
2345
namespace xllm {
2446
namespace parallel_state {
2547

@@ -45,7 +67,9 @@ std::optional<ParallelArgs> get_dp_attn_parallel_args(
4567
parallel_args.dp_size());
4668
}
4769

48-
torch::Tensor gather(torch::Tensor input, ProcessGroup* process_group) {
70+
torch::Tensor gather(const torch::Tensor& input,
71+
ProcessGroup* process_group,
72+
int dim) {
4973
if (!process_group) {
5074
return input;
5175
}
@@ -61,10 +85,56 @@ torch::Tensor gather(torch::Tensor input, ProcessGroup* process_group) {
6185
}
6286
// blocking call
6387
process_group->allgather(input, tensors);
64-
return torch::cat(tensors, /*dim=*/-1).contiguous();
88+
return torch::cat(tensors, /*dim=*/dim).contiguous();
89+
}
90+
91+
torch::Tensor gather(const torch::Tensor& input,
92+
ProcessGroup* process_group,
93+
const std::vector<int32_t>& token_num_list) {
94+
if (!process_group) {
95+
return input;
96+
}
97+
const auto world_size = process_group->world_size();
98+
const auto rank = process_group->rank();
99+
if (world_size == 1) {
100+
return input;
101+
}
102+
if (token_num_list.empty()) {
103+
return gather(input, process_group, 0);
104+
}
105+
CHECK_EQ(token_num_list.size(), world_size)
106+
<< "token_num_list size " << token_num_list.size()
107+
<< " does not match world_size " << world_size;
108+
109+
const bool num_tokens_equal =
110+
std::all_of(token_num_list.begin(),
111+
token_num_list.end(),
112+
[first_token_num = token_num_list[0]](int64_t num) {
113+
return num == first_token_num;
114+
});
115+
if (num_tokens_equal) {
116+
return gather(input, process_group, 0);
117+
}
118+
119+
int32_t max_num_tokens = xllm::util::max(token_num_list);
120+
int32_t num_padding = max_num_tokens - token_num_list[rank];
121+
auto padded_input = input;
122+
if (token_num_list[rank] == 0) {
123+
// If the current rank has zero tokens, create a padding tensor
124+
padded_input =
125+
torch::empty({max_num_tokens, input.size(-1)}, input.options());
126+
} else if (num_padding > 0) {
127+
std::vector<int64_t> pad = {0, 0, 0, num_padding};
128+
padded_input = torch::nn::functional::pad(
129+
input, torch::nn::functional::PadFuncOptions(pad));
130+
}
131+
132+
auto gathered_input = gather(padded_input, process_group, 0);
133+
return remove_paddings_after_all_gather(
134+
gathered_input, max_num_tokens, token_num_list);
65135
}
66136

67-
torch::Tensor reduce(torch::Tensor input, ProcessGroup* process_group) {
137+
torch::Tensor reduce(torch::Tensor& input, ProcessGroup* process_group) {
68138
if (!process_group) {
69139
return input;
70140
}
@@ -76,7 +146,9 @@ torch::Tensor reduce(torch::Tensor input, ProcessGroup* process_group) {
76146
return input;
77147
}
78148

79-
torch::Tensor scatter(torch::Tensor input, ProcessGroup* process_group) {
149+
torch::Tensor scatter(torch::Tensor input,
150+
ProcessGroup* process_group,
151+
int dim) {
80152
if (!process_group) {
81153
return input;
82154
}
@@ -86,13 +158,13 @@ torch::Tensor scatter(torch::Tensor input, ProcessGroup* process_group) {
86158
}
87159

88160
// get the size for last dimension
89-
const auto last_dim_size = input.size(-1);
90-
CHECK(last_dim_size % world_size == 0)
91-
<< "last_dim_size " << last_dim_size
92-
<< " cannot be divided by world_size " << world_size;
161+
const auto dim_size = input.size(dim);
162+
CHECK(dim_size % world_size == 0)
163+
<< "dim_size " << dim_size << " cannot be divided by world_size "
164+
<< world_size;
93165

94166
// torch::split does not create contiguous tensors by default.
95-
const auto tensor_list = input.split(last_dim_size / world_size, /*dim=*/-1);
167+
const auto tensor_list = input.split(dim_size / world_size, dim);
96168
const auto rank = process_group->rank();
97169
return tensor_list[rank];
98170
}
@@ -126,4 +198,4 @@ std::vector<std::unique_ptr<ProcessGroup>> create_npu_process_groups(
126198
}
127199

128200
} // namespace parallel_state
129-
} // namespace xllm
201+
} // namespace xllm

xllm/core/framework/parallel_state/parallel_state.h

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,11 +25,19 @@ namespace parallel_state {
2525
std::optional<ParallelArgs> get_dp_attn_parallel_args(
2626
const ParallelArgs& parallel_args);
2727

28-
torch::Tensor gather(torch::Tensor input, ProcessGroup* process_group);
28+
torch::Tensor gather(const torch::Tensor& input,
29+
ProcessGroup* process_group,
30+
int dim = -1);
2931

30-
torch::Tensor reduce(torch::Tensor input, ProcessGroup* process_group);
32+
torch::Tensor gather(const torch::Tensor& input,
33+
ProcessGroup* process_group,
34+
const std::vector<int32_t>& token_num_list);
3135

32-
torch::Tensor scatter(torch::Tensor input, ProcessGroup* process_group);
36+
torch::Tensor reduce(torch::Tensor& input, ProcessGroup* process_group);
37+
38+
torch::Tensor scatter(torch::Tensor input,
39+
ProcessGroup* process_group,
40+
int dim = -1);
3341

3442
// Create a process group where each process has a single device
3543
// devices: list of devices to create process groups on.

xllm/core/framework/parallel_state/process_group.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ class ProcessGroup {
3737
virtual void allreduce(torch::Tensor& input) = 0;
3838

3939
// allgather: gather tensors from all processes and concatenate them.
40-
virtual void allgather(torch::Tensor input,
40+
virtual void allgather(const torch::Tensor& input,
4141
std::vector<torch::Tensor>& outputs) = 0;
4242

4343
private:

xllm/core/framework/sampling/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ cc_library(
2020
glog::glog
2121
torch
2222
$<$<BOOL:${USE_NPU}>:xllm_ops>
23+
$<$<BOOL:${USE_MLU}>:mlu_kernels>
2324
)
2425

2526
cc_test(

0 commit comments

Comments
 (0)