@@ -15,11 +15,33 @@ limitations under the License.
 
 #include "parallel_state.h"
 
+#include "core/util/utils.h"
+
 #if defined(USE_NPU)
 #include "hccl/hccl.h"
 #include "npu_process_group.h"
 #endif
 
+namespace {
+
+torch::Tensor remove_paddings_after_all_gather(
+    const torch::Tensor& input,
+    int64_t padding_to_token_num,
+    const std::vector<int>& token_num_list) {
+  std::vector<torch::Tensor> group_tensors;
+  int64_t offset = 0;
+  for (const auto& token_num : token_num_list) {
+    if (token_num != 0) {
+      auto tensor_slice = input.slice(0, offset, offset + token_num);
+      group_tensors.push_back(tensor_slice);
+    }
+    offset += padding_to_token_num;
+  }
+
+  return torch::cat(group_tensors).contiguous();
+}
+
+}  // namespace
+
 namespace xllm {
 namespace parallel_state {
 
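The helper above walks the gathered tensor in strides of `padding_to_token_num` rows and keeps only the first `token_num_list[rank]` rows of each rank's segment, so ranks that contributed zero tokens drop out entirely. A minimal standalone sketch of that slicing logic on toy shapes (libtorch only, not part of the change):

```cpp
#include <torch/torch.h>

#include <iostream>
#include <vector>

int main() {
  // Pretend three ranks contributed {2, 0, 1} tokens and each segment was
  // padded to 2 rows before the all-gather, giving 3 * 2 = 6 rows total.
  auto gathered = torch::arange(6 * 4).reshape({6, 4});
  const int64_t padding_to_token_num = 2;
  const std::vector<int> token_num_list = {2, 0, 1};

  std::vector<torch::Tensor> group_tensors;
  int64_t offset = 0;
  for (const auto& token_num : token_num_list) {
    if (token_num != 0) {
      // Keep only the real rows at the front of this rank's segment.
      group_tensors.push_back(gathered.slice(0, offset, offset + token_num));
    }
    offset += padding_to_token_num;  // skip the whole padded segment
  }
  auto result = torch::cat(group_tensors).contiguous();
  std::cout << result.sizes() << std::endl;  // prints [3, 4]
}
```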
@@ -45,7 +67,9 @@ std::optional<ParallelArgs> get_dp_attn_parallel_args(
                       parallel_args.dp_size());
 }
 
-torch::Tensor gather(torch::Tensor input, ProcessGroup* process_group) {
+torch::Tensor gather(const torch::Tensor& input,
+                     ProcessGroup* process_group,
+                     int dim) {
   if (!process_group) {
     return input;
   }
@@ -61,10 +85,56 @@ torch::Tensor gather(torch::Tensor input, ProcessGroup* process_group) {
   }
   // blocking call
   process_group->allgather(input, tensors);
-  return torch::cat(tensors, /*dim=*/-1).contiguous();
+  return torch::cat(tensors, /*dim=*/dim).contiguous();
+}
+
+torch::Tensor gather(const torch::Tensor& input,
+                     ProcessGroup* process_group,
+                     const std::vector<int32_t>& token_num_list) {
+  if (!process_group) {
+    return input;
+  }
+  const auto world_size = process_group->world_size();
+  const auto rank = process_group->rank();
+  if (world_size == 1) {
+    return input;
+  }
+  if (token_num_list.empty()) {
+    return gather(input, process_group, 0);
+  }
+  CHECK_EQ(token_num_list.size(), world_size)
+      << "token_num_list size " << token_num_list.size()
+      << " does not match world_size " << world_size;
+
+  const bool num_tokens_equal =
+      std::all_of(token_num_list.begin(),
+                  token_num_list.end(),
+                  [first_token_num = token_num_list[0]](int64_t num) {
+                    return num == first_token_num;
+                  });
+  if (num_tokens_equal) {
+    return gather(input, process_group, 0);
+  }
+
+  int32_t max_num_tokens = xllm::util::max(token_num_list);
+  int32_t num_padding = max_num_tokens - token_num_list[rank];
+  auto padded_input = input;
+  if (token_num_list[rank] == 0) {
+    // If the current rank has zero tokens, create a padding tensor
+    padded_input =
+        torch::empty({max_num_tokens, input.size(-1)}, input.options());
+  } else if (num_padding > 0) {
+    std::vector<int64_t> pad = {0, 0, 0, num_padding};
+    padded_input = torch::nn::functional::pad(
+        input, torch::nn::functional::PadFuncOptions(pad));
+  }
+
+  auto gathered_input = gather(padded_input, process_group, 0);
+  return remove_paddings_after_all_gather(
+      gathered_input, max_num_tokens, token_num_list);
 }
 
-torch::Tensor reduce(torch::Tensor input, ProcessGroup* process_group) {
+torch::Tensor reduce(torch::Tensor& input, ProcessGroup* process_group) {
   if (!process_group) {
     return input;
   }
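In the uneven case, every rank pads its activations up to the per-rank maximum before the (necessarily uniform) all-gather; a rank with zero tokens contributes an uninitialized `torch::empty` placeholder, since its rows are discarded afterwards anyway. `PadFuncOptions({0, 0, 0, n})` pads the last dimension by (0, 0) and the row dimension by (0, n), i.e. it appends `n` zero rows. A small sketch of just that padding step, with assumed shapes:

```cpp
#include <torch/torch.h>

#include <iostream>
#include <vector>

int main() {
  auto input = torch::ones({3, 8});  // this rank holds 3 tokens, hidden dim 8
  const int64_t max_num_tokens = 5;  // largest token count across ranks
  const int64_t num_padding = max_num_tokens - input.size(0);

  // {left, right, top, bottom} over the last two dims: append rows only.
  std::vector<int64_t> pad = {0, 0, 0, num_padding};
  auto padded = torch::nn::functional::pad(
      input, torch::nn::functional::PadFuncOptions(pad));
  std::cout << padded.sizes() << std::endl;  // prints [5, 8]
}
```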
@@ -76,7 +146,9 @@ torch::Tensor reduce(torch::Tensor input, ProcessGroup* process_group) {
   return input;
 }
 
-torch::Tensor scatter(torch::Tensor input, ProcessGroup* process_group) {
+torch::Tensor scatter(torch::Tensor input,
+                      ProcessGroup* process_group,
+                      int dim) {
   if (!process_group) {
     return input;
   }
@@ -86,13 +158,13 @@ torch::Tensor scatter(torch::Tensor input, ProcessGroup* process_group) {
   }
 
-  // get the size for last dimension
-  const auto last_dim_size = input.size(-1);
-  CHECK(last_dim_size % world_size == 0)
-      << "last_dim_size " << last_dim_size
-      << " cannot be divided by world_size " << world_size;
+  // get the size of the dimension being scattered
+  const auto dim_size = input.size(dim);
+  CHECK(dim_size % world_size == 0)
+      << "dim_size " << dim_size << " cannot be divided by world_size "
+      << world_size;
 
   // torch::split does not create contiguous tensors by default.
-  const auto tensor_list = input.split(last_dim_size / world_size, /*dim=*/-1);
+  const auto tensor_list = input.split(dim_size / world_size, dim);
   const auto rank = process_group->rank();
   return tensor_list[rank];
 }
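`scatter` now splits along an arbitrary `dim` instead of hard-coding the last dimension; each rank keeps chunk number `rank` out of `world_size` equal chunks. A local emulation of the per-rank result (no process group involved, shapes assumed):

```cpp
#include <torch/torch.h>

#include <iostream>

int main() {
  const int world_size = 4;
  auto input = torch::randn({8, 16});

  // Along dim 0 each rank receives 8 / 4 = 2 rows; along dim -1 it would
  // receive 16 / 4 = 4 columns instead.
  const auto tensor_list = input.split(input.size(0) / world_size, /*dim=*/0);
  for (int rank = 0; rank < world_size; ++rank) {
    std::cout << "rank " << rank << ": " << tensor_list[rank].sizes()
              << std::endl;  // each prints [2, 16]
  }
}
```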
@@ -126,4 +198,4 @@ std::vector<std::unique_ptr<ProcessGroup>> create_npu_process_groups(
 }
 
 }  // namespace parallel_state
-}  // namespace xllm
+}  // namespace xllm