
Commit 67ad720

Gossity authored and zhaotianyi committed
fix: code format and fix moe dummy run.
1 parent 39b7067 commit 67ad720

8 files changed, +115 −28 lines changed

xllm/core/layers/common/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
@@ -16,6 +16,7 @@ cc_library(
     qwen3_moe_decoder_layer.h
     linear_impl.h
     word_embedding_impl.h
+    layer_utils.h
   SRCS
     qwen3_attention.cpp
     attention.cpp
@@ -26,6 +27,7 @@ cc_library(
     qwen3_decoder_layer.cpp
     qwen3_moe_decoder_layer.cpp
     linear_impl.cpp
+    layer_utils.cpp
   DEPS
     "-Wl,--whole-archive"
     "-Wl,--no-whole-archive"

xllm/core/layers/common/fused_moe.cpp

Lines changed: 2 additions & 15 deletions
@@ -130,28 +130,13 @@ torch::Tensor FusedMoEImpl::forward_expert(
 torch::Tensor FusedMoEImpl::forward(const torch::Tensor& hidden_states,
                                     const ModelInputParams& input_params) {
   auto input = hidden_states;
-  const auto& dp_tokens = input_params.dp_global_token_nums;
-  int dp_rank = 0;
   bool need_slice = false;
   if (parallel_args_.dp_size() > 1 && parallel_args_.ep_size() > 1) {
-    dp_rank = parallel_args_.dp_local_process_group_->rank();
     input = parallel_state::gather(input,
                                    parallel_args_.dp_local_process_group_,
                                    input_params.dp_global_token_nums);
     need_slice = true;
   }
-  // fake run for dp rank with zero tokens
-  if (dp_tokens[dp_rank] == 0) {
-    // If the current dp rank has zero tokens, return an empty tensor
-    input = parallel_state::reduce(input, tp_pg_);
-    if (need_slice) {
-      auto start =
-          std::accumulate(dp_tokens.begin(), dp_tokens.begin() + dp_rank, 0);
-      auto end = start + dp_tokens[dp_rank];
-      return input.slice(0, start, end);
-    }
-    return input;
-  }
 
   pack_params();
   std::optional<torch::Tensor> shared_output = std::nullopt;
@@ -162,6 +147,8 @@ torch::Tensor FusedMoEImpl::forward(const torch::Tensor& hidden_states,
   auto output = forward_expert(input, router_logits, shared_output);
 
   if (need_slice) {
+    const auto& dp_tokens = input_params.dp_global_token_nums;
+    const int dp_rank = parallel_args_.dp_local_process_group_->rank();
     auto start =
         std::accumulate(dp_tokens.begin(), dp_tokens.begin() + dp_rank, 0);
     auto end = start + dp_tokens[dp_rank];
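
Note: the slice bookkeeping retained in forward() is easiest to check with concrete numbers. A standalone sketch — the token counts are made up for illustration; only the std::accumulate prefix sum and the slice bounds come from the code above:

#include <iostream>
#include <numeric>
#include <vector>

// With dp_global_token_nums = {3, 0, 2}, the DP gather concatenates
// 3 + 0 + 2 = 5 rows; each rank then slices its own rows back out.
int main() {
  const std::vector<int> dp_tokens = {3, 0, 2};  // hypothetical counts
  for (int dp_rank = 0; dp_rank < 3; ++dp_rank) {
    const int start =
        std::accumulate(dp_tokens.begin(), dp_tokens.begin() + dp_rank, 0);
    const int end = start + dp_tokens[dp_rank];
    // rank 0 -> [0, 3), rank 1 -> [3, 3) (empty), rank 2 -> [3, 5)
    std::cout << "rank " << dp_rank << ": rows [" << start << ", " << end
              << ")\n";
  }
  return 0;
}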
xllm/core/layers/common/layer_utils.cpp

Lines changed: 65 additions & 0 deletions

@@ -0,0 +1,65 @@
+/* Copyright 2025 The xLLM Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    https://github.com/jd-opensource/xllm/blob/main/LICENSE
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "layer_utils.h"
+
+#include "framework/parallel_state/parallel_state.h"
+
+namespace xllm {
+namespace layer {
+
+bool is_dummy_run(const ModelInputParams& input_params,
+                  const ParallelArgs& parallel_args) {
+  int dp_rank = 0;
+  if (parallel_args.dp_size() > 1) {
+    dp_rank = parallel_args.dp_local_process_group_->rank();
+  }
+  return input_params.dp_global_token_nums[dp_rank] == 0;
+}
+
+torch::Tensor dummy_run(torch::Tensor& input,
+                        const ModelInputParams& input_params,
+                        const ParallelArgs& parallel_args) {
+  if (parallel_args.dp_size() <= 1 && parallel_args.ep_size() <= 1) {
+    return input;
+  }
+
+  auto tp_pg = parallel_args.tp_group_;
+  if (parallel_args.ep_size() > 1) {
+    tp_pg = parallel_args.process_group_;
+  }
+  bool need_slice = false;
+  if (parallel_args.dp_size() > 1 && parallel_args.ep_size() > 1) {
+    input = parallel_state::gather(input,
+                                   parallel_args.dp_local_process_group_,
+                                   input_params.dp_global_token_nums);
+    need_slice = true;
+  }
+  if (tp_pg->world_size() > 1) {
+    input = parallel_state::reduce(input, tp_pg);
+  }
+  if (need_slice) {
+    const auto& dp_tokens = input_params.dp_global_token_nums;
+    const int dp_rank = parallel_args.dp_local_process_group_->rank();
+    auto start =
+        std::accumulate(dp_tokens.begin(), dp_tokens.begin() + dp_rank, 0);
+    auto end = start + dp_tokens[dp_rank];
+    input = input.slice(0, start, end);
+  }
+  return input;
+}
+
+}  // namespace layer
+}  // namespace xllm
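
Note: the point of issuing gather/reduce even when this rank holds zero tokens is, presumably, to keep collective calls matched with the ranks that do have tokens. A hypothetical trace of the zero-token rank, assuming dp_global_token_nums = {4, 0}, dp_rank 1, and dp_size > 1 with ep_size > 1 (a local libtorch tensor stands in for the distributed calls):

#include <torch/torch.h>
#include <iostream>

int main() {
  // gather() would hand every rank all 4 + 0 = 4 rows; reduce() then sums
  // partial results across the TP/EP group.
  auto gathered = torch::randn({4, 8});  // stand-in for gather() + reduce()
  // dp_rank 1's prefix sum is 4, so its slice is rows [4, 4):
  auto out = gathered.slice(/*dim=*/0, /*start=*/4, /*end=*/4);
  std::cout << out.size(0) << " rows\n";  // 0 rows: an empty dummy output
  return 0;
}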
xllm/core/layers/common/layer_utils.h

Lines changed: 31 additions & 0 deletions

@@ -0,0 +1,31 @@
+/* Copyright 2025 The xLLM Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    https://github.com/jd-opensource/xllm/blob/main/LICENSE
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#pragma once
+#include "framework/model/model_input_params.h"
+#include "framework/parallel_state/parallel_args.h"
+
+namespace xllm {
+namespace layer {
+
+bool is_dummy_run(const ModelInputParams& input_params,
+                  const ParallelArgs& parallel_args);
+
+torch::Tensor dummy_run(torch::Tensor& input,
+                        const ModelInputParams& input_params,
+                        const ParallelArgs& parallel_args);
+
+}  // namespace layer
+}  // namespace xllm
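
Note: taken together, the intended call pattern is roughly the following sketch. SomeDecoderImpl is a hypothetical stand-in; the real call sites are the Qwen3 decoder hunks below.

torch::Tensor SomeDecoderImpl::forward(torch::Tensor& x,
                                       const ModelInputParams& input_params) {
  if (layer::is_dummy_run(input_params, parallel_args_)) {
    // No tokens on this DP rank: skip the layer body, but still issue the
    // matching gather/reduce collectives when dp/ep parallelism is active.
    return layer::dummy_run(x, input_params, parallel_args_);
  }
  // ... normal attention + MLP path ...
  return x;
}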

xllm/core/layers/common/qwen3_decoder_layer.cpp

Lines changed: 8 additions & 1 deletion
@@ -17,10 +17,13 @@ limitations under the License.
 
 #include <glog/logging.h>
 
+#include "layer_utils.h"
+
 namespace xllm {
 namespace layer {
 
-Qwen3DecoderImpl::Qwen3DecoderImpl(const ModelContext& context) {
+Qwen3DecoderImpl::Qwen3DecoderImpl(const ModelContext& context)
+    : parallel_args_(context.get_parallel_args()) {
   const auto& model_args = context.get_model_args();
   const auto& quant_args = context.get_quant_args();
   const auto& parallel_args = context.get_parallel_args();
@@ -65,6 +68,10 @@ torch::Tensor Qwen3DecoderImpl::forward(torch::Tensor& x,
     const AttentionMetadata& attn_metadata,
     KVCache& kv_cache,
     const ModelInputParams& input_params) {
+  bool is_dummy_run = layer::is_dummy_run(input_params, parallel_args_);
+  if (is_dummy_run) {
+    return x;
+  }
   // Pre-attention norm
   auto residual = x;
   x = input_norm_->forward(x);

xllm/core/layers/common/qwen3_decoder_layer.h

Lines changed: 2 additions & 0 deletions
@@ -53,6 +53,8 @@ class Qwen3DecoderImpl : public torch::nn::Module {
   DenseMLP mlp_{nullptr};
   RmsNorm input_norm_{nullptr};
   RmsNorm post_norm_{nullptr};
+
+  ParallelArgs parallel_args_;
 };
 
 }  // namespace layer

xllm/core/layers/common/qwen3_moe_decoder_layer.cpp

Lines changed: 5 additions & 7 deletions
@@ -17,6 +17,8 @@ limitations under the License.
 
 #include <glog/logging.h>
 
+#include "layer_utils.h"
+
 namespace xllm {
 namespace layer {
 
@@ -95,13 +97,9 @@ torch::Tensor Qwen3MoeDecoderImpl::forward(
     const AttentionMetadata& attn_metadata,
     KVCache& kv_cache,
     const ModelInputParams& input_params) {
-  const auto& dp_rank = parallel_args_.dp_local_process_group_->rank();
-  if (input_params.dp_global_token_nums[dp_rank] == 0) {
-    if (moe_mlp_) {
-      return moe_mlp_(x, input_params);
-    } else {
-      return x;
-    }
+  bool is_dummy_run = layer::is_dummy_run(input_params, parallel_args_);
+  if (is_dummy_run) {
+    return layer::dummy_run(x, input_params, parallel_args_);
   }
   // Pre-attention norm
   torch::Tensor residual = x;
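
Note: read against the commit title, this hunk appears to be the actual "fix moe dummy run": the zero-token ("fake run") handling that previously lived inside FusedMoEImpl::forward is now centralized in layer::dummy_run, so the decoder short-circuits before touching the MoE block instead of detouring through moe_mlp_ (or, when moe_mlp_ was null, skipping the collectives entirely).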

xllm/models/llm/llm_model_base.h

Lines changed: 0 additions & 5 deletions
@@ -308,11 +308,6 @@
   auto cancated_h = torch::cat(hs, 0);
   return norm_(cancated_h, 0);
 #elif defined(USE_MLU)
-  CHECK(input_params.size() == 1)
-      << "invalid input_params size: " << input_params.size();
-  if (input_params[0].q_max_seq_len == 0) {
-    return hs[0];
-  }
   bool is_prefill = input_params[0].q_max_seq_len > 1;
   auto attn_metadata =
       layer::AttentionMetadata::build(input_params[0], is_prefill);
