/* Copyright 2025 The xLLM Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    https://github.com/jd-opensource/xllm/blob/main/LICENSE

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "layer_utils.h"

#include <numeric>

#include "framework/parallel_state/parallel_state.h"

namespace xllm {
namespace layer {

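// Returns true when the current data-parallel rank has zero tokens in this
// step (a "dummy run"), i.e. it executes only to keep collective operations
// in lockstep with the other ranks.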
bool is_dummy_run(const ModelInputParams& input_params,
                  const ParallelArgs& parallel_args) {
  int dp_rank = 0;
  if (parallel_args.dp_size() > 1) {
    dp_rank = parallel_args.dp_local_process_group_->rank();
  }
  return input_params.dp_global_token_nums[dp_rank] == 0;
}

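// Joins the collectives (data-parallel gather, tensor/expert-parallel
// reduce) that active ranks issue during the step, so a rank holding only
// dummy data does not desynchronize them.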
torch::Tensor dummy_run(torch::Tensor& input,
                        const ModelInputParams& input_params,
                        const ParallelArgs& parallel_args) {
  // Without data or expert parallelism there are no collectives to join.
  if (parallel_args.dp_size() <= 1 && parallel_args.ep_size() <= 1) {
    return input;
  }

  // With expert parallelism, reductions run over the full process group
  // rather than the tensor-parallel group.
  auto tp_pg = parallel_args.tp_group_;
  if (parallel_args.ep_size() > 1) {
    tp_pg = parallel_args.process_group_;
  }
  bool need_slice = false;
  if (parallel_args.dp_size() > 1 && parallel_args.ep_size() > 1) {
    // Gather tokens from all data-parallel ranks so every rank holds the
    // global batch; the local portion is sliced back out below.
    input = parallel_state::gather(input,
                                   parallel_args.dp_local_process_group_,
                                   input_params.dp_global_token_nums);
    need_slice = true;
  }
  if (tp_pg->world_size() > 1) {
    input = parallel_state::reduce(input, tp_pg);
  }
  if (need_slice) {
    // Recover this rank's slice: its offset is the sum of the token counts
    // of all preceding data-parallel ranks.
    const auto& dp_tokens = input_params.dp_global_token_nums;
    const int dp_rank = parallel_args.dp_local_process_group_->rank();
    auto start =
        std::accumulate(dp_tokens.begin(), dp_tokens.begin() + dp_rank, 0);
    auto end = start + dp_tokens[dp_rank];
    input = input.slice(0, start, end);
  }
  return input;
}

}  // namespace layer
}  // namespace xllm
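
// ---------------------------------------------------------------------------
// Illustrative call-site sketch (not part of this file): `MyLayer::forward`
// and its surrounding class are hypothetical names used only to show how
// these helpers are meant to pair up. A layer checks is_dummy_run() and, on
// a tokenless rank, routes through dummy_run() so the collectives stay
// matched across ranks:
//
//   torch::Tensor MyLayer::forward(torch::Tensor& h,
//                                  const ModelInputParams& params,
//                                  const ParallelArgs& args) {
//     if (is_dummy_run(params, args)) {
//       return dummy_run(h, params, args);  // join collectives, skip compute
//     }
//     // ... regular compute path ...
//   }
// ---------------------------------------------------------------------------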