
Commit 018af48

feat: optimize the QKV computation performance in DiT model. (#273)
1 parent 322803c commit 018af48
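
This commit fuses the separate Q, K, and V projections in the Flux attention blocks into a single GEMM: the three [out_dim, in_dim] weight matrices are concatenated row-wise into one [3*out_dim, in_dim] parameter, so each forward pass launches one matrix multiply instead of three and splits the result with chunk(3, -1). A minimal standalone LibTorch sketch of the equivalence (illustrative names, not the xllm API):

#include <torch/torch.h>

int main() {
  const int64_t dim = 64;
  auto x = torch::randn({2, 16, dim});  // [batch, seq, dim]

  // Three separate projections, as before this commit.
  auto w_q = torch::randn({dim, dim}), b_q = torch::randn({dim});
  auto w_k = torch::randn({dim, dim}), b_k = torch::randn({dim});
  auto w_v = torch::randn({dim, dim}), b_v = torch::randn({dim});
  auto q_ref = torch::nn::functional::linear(x, w_q, b_q);
  auto k_ref = torch::nn::functional::linear(x, w_k, b_k);
  auto v_ref = torch::nn::functional::linear(x, w_v, b_v);

  // Fused: one [3*dim, dim] weight, one GEMM, then split the output.
  auto w_qkv = torch::cat({w_q, w_k, w_v}, 0);
  auto b_qkv = torch::cat({b_q, b_k, b_v}, 0);
  auto chunks =
      torch::nn::functional::linear(x, w_qkv, b_qkv).chunk(3, /*dim=*/-1);

  // Same results up to floating-point reduction order.
  TORCH_CHECK(torch::allclose(chunks[0], q_ref, /*rtol=*/1e-4, /*atol=*/1e-5));
  TORCH_CHECK(torch::allclose(chunks[1], k_ref, /*rtol=*/1e-4, /*atol=*/1e-5));
  TORCH_CHECK(torch::allclose(chunks[2], v_ref, /*rtol=*/1e-4, /*atol=*/1e-5));
}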

File tree: 1 file changed (+165, -88 lines)


xllm/models/dit/dit.h

Lines changed: 165 additions & 88 deletions
@@ -142,16 +142,12 @@ class FluxSingleAttentionImpl : public torch::nn::Module {
     auto head_dim = model_args.head_dim();
     auto query_dim = heads_ * head_dim;
     auto out_dim = query_dim;
-    to_q_ = register_module("to_q",
-                            DiTLinear(query_dim, out_dim, true /*has_bias*/));
-    to_k_ = register_module("to_k",
-                            DiTLinear(query_dim, out_dim, true /*has_bias*/));
-    to_v_ = register_module("to_v",
-                            DiTLinear(query_dim, out_dim, true /*has_bias*/));

-    to_q_->to(options_);
-    to_k_->to(options_);
-    to_v_->to(options_);
+    fused_qkv_weight_ = register_parameter(
+        "fused_qkv_weight", torch::empty({3 * query_dim, out_dim}, options_));
+
+    fused_qkv_bias_ = register_parameter("fused_qkv_bias",
+                                         torch::empty({3 * out_dim}, options_));

     norm_q_ = register_module("norm_q",
                               DiTRMSNorm(head_dim,
@@ -170,19 +166,15 @@ class FluxSingleAttentionImpl : public torch::nn::Module {
   torch::Tensor forward(const torch::Tensor& hidden_states,
                         const torch::Tensor& image_rotary_emb) {
     int64_t batch_size, channel, height, width;
+    batch_size = hidden_states.size(0);

-    // Reshape 4D input to [B, seq_len, C]
-    torch::Tensor hidden_states_ =
-        hidden_states;  // Use copy to avoid modifying input
-    batch_size = hidden_states_.size(0);
-
-    // Self-attention: use hidden_states as context
-    torch::Tensor context = hidden_states_;
+    auto qkv = torch::nn::functional::linear(
+        hidden_states, fused_qkv_weight_, fused_qkv_bias_);
+    auto chunks = qkv.chunk(3, -1);

-    // Compute QKV projections
-    torch::Tensor query = to_q_->forward(hidden_states_);
-    torch::Tensor key = to_k_->forward(context);
-    torch::Tensor value = to_v_->forward(context);
+    torch::Tensor query = chunks[0];
+    torch::Tensor key = chunks[1];
+    torch::Tensor value = chunks[2];

     // Reshape for multi-head attention
     int64_t inner_dim = key.size(-1);
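
Since the fused weight stacks W_q, W_k, and W_v along dim 0, the fused output's last dimension is laid out as [Q | K | V], and chunk(3, -1) recovers the three projections as equal, ordered slices. A toy illustration of that split (hypothetical values, same chunk call as in the diff):

#include <torch/torch.h>
#include <iostream>

int main() {
  // A last dim of size 6 splits into three ordered slices of size 2.
  auto qkv = torch::arange(6).reshape({1, 1, 6});  // [[[0, 1, 2, 3, 4, 5]]]
  auto parts = qkv.chunk(3, /*dim=*/-1);
  std::cout << parts[0] << parts[1] << parts[2];  // [[0,1]] [[2,3]] [[4,5]]
}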
@@ -210,26 +202,53 @@ class FluxSingleAttentionImpl : public torch::nn::Module {
     norm_q_->load_state_dict(state_dict.get_dict_with_prefix("norm_q."));
     // norm_k
     norm_k_->load_state_dict(state_dict.get_dict_with_prefix("norm_k."));
-    // to_q
-    to_q_->load_state_dict(state_dict.get_dict_with_prefix("to_q."));
-    // to_k
-    to_k_->load_state_dict(state_dict.get_dict_with_prefix("to_k."));
-    // to_v
-    to_v_->load_state_dict(state_dict.get_dict_with_prefix("to_v."));
+
+    auto to_q_weight = state_dict.get_tensor("to_q.weight");
+    auto to_q_bias = state_dict.get_tensor("to_q.bias");
+    auto to_k_weight = state_dict.get_tensor("to_k.weight");
+    auto to_k_bias = state_dict.get_tensor("to_k.bias");
+    auto to_v_weight = state_dict.get_tensor("to_v.weight");
+    auto to_v_bias = state_dict.get_tensor("to_v.bias");
+
+    if (to_q_weight.defined() && to_k_weight.defined() &&
+        to_v_weight.defined()) {
+      auto fused_qkv_weight =
+          torch::cat({to_q_weight, to_k_weight, to_v_weight}, 0).contiguous();
+      DCHECK_EQ(fused_qkv_weight_.sizes(), fused_qkv_weight.sizes())
+          << "fused_qkv_weight_ size mismatch: expected "
+          << fused_qkv_weight_.sizes() << " but got "
+          << fused_qkv_weight.sizes();
+      fused_qkv_weight_.data().copy_(fused_qkv_weight.to(
+          fused_qkv_weight_.device(), fused_qkv_weight_.dtype()));
+      is_qkv_weight_loaded_ = true;
+    }
+
+    if (to_q_bias.defined() && to_k_bias.defined() && to_v_bias.defined()) {
+      auto fused_qkv_bias =
+          torch::cat({to_q_bias, to_k_bias, to_v_bias}, 0).contiguous();
+      DCHECK_EQ(fused_qkv_bias_.sizes(), fused_qkv_bias.sizes())
+          << "fused_qkv_bias_ size mismatch: expected "
+          << fused_qkv_bias_.sizes() << " but got " << fused_qkv_bias.sizes();
+      fused_qkv_bias_.data().copy_(
+          fused_qkv_bias.to(fused_qkv_bias_.device(), fused_qkv_bias_.dtype()));
+      is_qkv_bias_loaded_ = true;
+    }
   }

   void verify_loaded_weights(const std::string& prefix) const {
+    CHECK(is_qkv_weight_loaded_)
+        << "weight is not loaded for " << prefix + "qkv_proj.weight";
+    CHECK(is_qkv_bias_loaded_)
+        << "bias is not loaded for " << prefix + "qkv_proj.bias";
     norm_q_->verify_loaded_weights(prefix + "norm_q.");
     norm_k_->verify_loaded_weights(prefix + "norm_k.");
-    to_q_->verify_loaded_weights(prefix + "to_q.");
-    to_k_->verify_loaded_weights(prefix + "to_k.");
-    to_v_->verify_loaded_weights(prefix + "to_v.");
   }

  private:
-  DiTLinear to_q_{nullptr};
-  DiTLinear to_k_{nullptr};
-  DiTLinear to_v_{nullptr};
+  bool is_qkv_weight_loaded_{false};
+  bool is_qkv_bias_loaded_{false};
+  torch::Tensor fused_qkv_weight_{};
+  torch::Tensor fused_qkv_bias_{};
   int64_t heads_;
   DiTRMSNorm norm_q_{nullptr};
   DiTRMSNorm norm_k_{nullptr};
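
Checkpoints still ship per-projection to_q/to_k/to_v tensors, so load_state_dict concatenates them along the output dimension into the pre-registered fused parameters at load time. A hedged standalone sketch of that fusion step (random tensors stand in for state_dict.get_tensor(), which is the real source here):

#include <torch/torch.h>

int main() {
  const int64_t dim = 64;
  // Pre-registered fused parameters, mirroring the constructor above.
  auto fused_w = torch::empty({3 * dim, dim});
  auto fused_b = torch::empty({3 * dim});

  // Stand-ins for state_dict.get_tensor("to_q.weight") etc.
  auto w_q = torch::randn({dim, dim}), w_k = torch::randn({dim, dim}),
       w_v = torch::randn({dim, dim});
  auto b_q = torch::randn({dim}), b_k = torch::randn({dim}),
       b_v = torch::randn({dim});

  // Row blocks [0, dim), [dim, 2*dim), [2*dim, 3*dim) hold Q, K, V; this
  // ordering is what makes chunk(3, -1) on the forward output valid.
  fused_w.copy_(torch::cat({w_q, w_k, w_v}, 0).contiguous());
  fused_b.copy_(torch::cat({b_q, b_k, b_v}, 0).contiguous());
  return 0;
}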
@@ -248,29 +267,24 @@ class FluxAttentionImpl : public torch::nn::Module {
     auto out_dim = query_dim;
     auto added_kv_proj_dim = query_dim;

-    to_q_ = register_module("to_q", DiTLinear(query_dim, out_dim, true));
-    to_k_ = register_module("to_k", DiTLinear(query_dim, out_dim, true));
-    to_v_ = register_module("to_v", DiTLinear(query_dim, out_dim, true));
-    add_q_proj_ = register_module("add_q_proj",
-                                  DiTLinear(added_kv_proj_dim, out_dim, true));
-
-    add_k_proj_ = register_module("add_k_proj",
-                                  DiTLinear(added_kv_proj_dim, out_dim, true));
-
-    add_v_proj_ = register_module("add_v_proj",
-                                  DiTLinear(added_kv_proj_dim, out_dim, true));
-
     to_out_ = register_module("to_out", DiTLinear(out_dim, query_dim, true));

     to_add_out_ = register_module("to_add_out",
                                   DiTLinear(out_dim, added_kv_proj_dim, true));

-    to_q_->to(options_);
-    to_k_->to(options_);
-    to_v_->to(options_);
-    add_q_proj_->to(options_);
-    add_k_proj_->to(options_);
-    add_v_proj_->to(options_);
+    fused_qkv_weight_ = register_parameter(
+        "fused_qkv_weight", torch::empty({3 * query_dim, out_dim}, options_));
+
+    fused_qkv_bias_ = register_parameter("fused_qkv_bias",
+                                         torch::empty({3 * out_dim}, options_));
+
+    fused_add_qkv_weight_ = register_parameter(
+        "fused_add_qkv_weight",
+        torch::empty({3 * added_kv_proj_dim, out_dim}, options_));
+
+    fused_add_qkv_bias_ = register_parameter(
+        "fused_add_qkv_bias", torch::empty({3 * out_dim}, options_));
+
     to_out_->to(options_);
     to_add_out_->to(options_);

@@ -330,9 +344,15 @@ class FluxAttentionImpl : public torch::nn::Module {
               .transpose(1, 2);
     }
     int64_t batch_size = encoder_hidden_states_reshaped.size(0);
-    torch::Tensor query = to_q_->forward(hidden_states_reshaped);
-    torch::Tensor key = to_k_->forward(hidden_states_reshaped);
-    torch::Tensor value = to_v_->forward(hidden_states_reshaped);
+
+    auto qkv = torch::nn::functional::linear(
+        hidden_states_reshaped, fused_qkv_weight_, fused_qkv_bias_);
+
+    auto chunks = qkv.chunk(3, -1);
+    torch::Tensor query = chunks[0];
+    torch::Tensor key = chunks[1];
+    torch::Tensor value = chunks[2];
+
     int64_t inner_dim = key.size(-1);
     int64_t attn_heads = heads_;

@@ -342,13 +362,17 @@ class FluxAttentionImpl : public torch::nn::Module {
     value = value.view({batch_size, -1, attn_heads, head_dim}).transpose(1, 2);
     if (norm_q_) query = norm_q_->forward(query);
     if (norm_k_) key = norm_k_->forward(key);
-    // encoder hidden states
-    torch::Tensor encoder_hidden_states_query_proj =
-        add_q_proj_->forward(encoder_hidden_states_reshaped);
-    torch::Tensor encoder_hidden_states_key_proj =
-        add_k_proj_->forward(encoder_hidden_states_reshaped);
-    torch::Tensor encoder_hidden_states_value_proj =
-        add_v_proj_->forward(encoder_hidden_states_reshaped);
+
+    auto encoder_qkv =
+        torch::nn::functional::linear(encoder_hidden_states_reshaped,
+                                      fused_add_qkv_weight_,
+                                      fused_add_qkv_bias_);
+
+    auto encoder_chunks = encoder_qkv.chunk(3, -1);
+    torch::Tensor encoder_hidden_states_query_proj = encoder_chunks[0];
+    torch::Tensor encoder_hidden_states_key_proj = encoder_chunks[1];
+    torch::Tensor encoder_hidden_states_value_proj = encoder_chunks[2];
+
     encoder_hidden_states_query_proj =
         encoder_hidden_states_query_proj
             .view({batch_size, -1, attn_heads, head_dim})
@@ -396,12 +420,6 @@ class FluxAttentionImpl : public torch::nn::Module {
   }

   void load_state_dict(const StateDict& state_dict) {
-    // to_q
-    to_q_->load_state_dict(state_dict.get_dict_with_prefix("to_q."));
-    // to_k
-    to_k_->load_state_dict(state_dict.get_dict_with_prefix("to_k."));
-    // to_v
-    to_v_->load_state_dict(state_dict.get_dict_with_prefix("to_v."));
     // to_out
     to_out_->load_state_dict(state_dict.get_dict_with_prefix("to_out.0."));
     // to_add_out
@@ -417,39 +435,98 @@ class FluxAttentionImpl : public torch::nn::Module {
     // norm_added_k
     norm_added_k_->load_state_dict(
         state_dict.get_dict_with_prefix("norm_added_k."));
-    // add_q_proj
-    add_q_proj_->load_state_dict(
-        state_dict.get_dict_with_prefix("add_q_proj."));
-    // add_k_proj
-    add_k_proj_->load_state_dict(
-        state_dict.get_dict_with_prefix("add_k_proj."));
-    // add_v_proj
-    add_v_proj_->load_state_dict(
-        state_dict.get_dict_with_prefix("add_v_proj."));
+
+    auto to_q_weight = state_dict.get_tensor("to_q.weight");
+    auto to_q_bias = state_dict.get_tensor("to_q.bias");
+    auto to_k_weight = state_dict.get_tensor("to_k.weight");
+    auto to_k_bias = state_dict.get_tensor("to_k.bias");
+    auto to_v_weight = state_dict.get_tensor("to_v.weight");
+    auto to_v_bias = state_dict.get_tensor("to_v.bias");
+
+    if (to_q_weight.defined() && to_k_weight.defined() &&
+        to_v_weight.defined()) {
+      auto fused_qkv_weight =
+          torch::cat({to_q_weight, to_k_weight, to_v_weight}, 0).contiguous();
+      DCHECK_EQ(fused_qkv_weight_.sizes(), fused_qkv_weight.sizes())
+          << "fused_qkv_weight_ size mismatch: expected "
+          << fused_qkv_weight_.sizes() << " but got "
+          << fused_qkv_weight.sizes();
+      fused_qkv_weight_.data().copy_(fused_qkv_weight.to(
+          fused_qkv_weight_.device(), fused_qkv_weight_.dtype()));
+      is_qkv_weight_loaded_ = true;
+    }
+
+    if (to_q_bias.defined() && to_k_bias.defined() && to_v_bias.defined()) {
+      auto fused_qkv_bias =
+          torch::cat({to_q_bias, to_k_bias, to_v_bias}, 0).contiguous();
+      DCHECK_EQ(fused_qkv_bias_.sizes(), fused_qkv_bias.sizes())
+          << "fused_qkv_bias_ size mismatch: expected "
+          << fused_qkv_bias_.sizes() << " but got " << fused_qkv_bias.sizes();
+      fused_qkv_bias_.data().copy_(
+          fused_qkv_bias.to(fused_qkv_bias_.device(), fused_qkv_bias_.dtype()));
+      is_qkv_bias_loaded_ = true;
+    }
+
+    auto add_q_weight = state_dict.get_tensor("add_q_proj.weight");
+    auto add_q_bias = state_dict.get_tensor("add_q_proj.bias");
+    auto add_k_weight = state_dict.get_tensor("add_k_proj.weight");
+    auto add_k_bias = state_dict.get_tensor("add_k_proj.bias");
+    auto add_v_weight = state_dict.get_tensor("add_v_proj.weight");
+    auto add_v_bias = state_dict.get_tensor("add_v_proj.bias");
+
+    if (add_q_weight.defined() && add_k_weight.defined() &&
+        add_v_weight.defined()) {
+      auto fused_add_qkv_weight =
+          torch::cat({add_q_weight, add_k_weight, add_v_weight}, 0)
+              .contiguous();
+      DCHECK_EQ(fused_add_qkv_weight_.sizes(), fused_add_qkv_weight.sizes())
+          << "fused_add_qkv_weight_ size mismatch: expected "
+          << fused_add_qkv_weight_.sizes() << " but got "
+          << fused_add_qkv_weight.sizes();
+      fused_add_qkv_weight_.data().copy_(fused_add_qkv_weight.to(
+          fused_add_qkv_weight_.device(), fused_add_qkv_weight_.dtype()));
+      is_add_qkv_weight_loaded_ = true;
+    }
+
+    if (add_q_bias.defined() && add_k_bias.defined() && add_v_bias.defined()) {
+      auto fused_add_qkv_bias =
+          torch::cat({add_q_bias, add_k_bias, add_v_bias}, 0).contiguous();
+      DCHECK_EQ(fused_add_qkv_bias_.sizes(), fused_add_qkv_bias.sizes())
+          << "fused_add_qkv_bias_ size mismatch: expected "
+          << fused_add_qkv_bias_.sizes() << " but got "
+          << fused_add_qkv_bias.sizes();
+      fused_add_qkv_bias_.data().copy_(fused_add_qkv_bias.to(
+          fused_add_qkv_bias_.device(), fused_add_qkv_bias_.dtype()));
+      is_add_qkv_bias_loaded_ = true;
+    }
   }

   void verify_loaded_weights(const std::string& prefix) const {
+    CHECK(is_qkv_weight_loaded_)
+        << "weight is not loaded for " << prefix + "qkv_proj.weight";
+    CHECK(is_qkv_bias_loaded_)
+        << "bias is not loaded for " << prefix + "qkv_proj.bias";
+    CHECK(is_add_qkv_weight_loaded_)
+        << "weight is not loaded for " << prefix + "add_qkv_proj.weight";
+    CHECK(is_add_qkv_bias_loaded_)
+        << "bias is not loaded for " << prefix + "add_qkv_proj.bias";
     norm_q_->verify_loaded_weights(prefix + "norm_q.");
     norm_k_->verify_loaded_weights(prefix + "norm_k.");
     norm_added_q_->verify_loaded_weights(prefix + "norm_added_q.");
     norm_added_k_->verify_loaded_weights(prefix + "norm_added_k.");
-    to_q_->verify_loaded_weights(prefix + "to_q.");
-    to_k_->verify_loaded_weights(prefix + "to_k.");
-    to_v_->verify_loaded_weights(prefix + "to_v.");
     to_out_->verify_loaded_weights(prefix + "to_out.0.");
     to_add_out_->verify_loaded_weights(prefix + "to_add_out.");
-    add_q_proj_->verify_loaded_weights(prefix + "add_q_proj.");
-    add_k_proj_->verify_loaded_weights(prefix + "add_k_proj.");
-    add_v_proj_->verify_loaded_weights(prefix + "add_v_proj.");
   }

  private:
-  DiTLinear to_q_{nullptr};
-  DiTLinear to_k_{nullptr};
-  DiTLinear to_v_{nullptr};
-  DiTLinear add_q_proj_{nullptr};
-  DiTLinear add_k_proj_{nullptr};
-  DiTLinear add_v_proj_{nullptr};
+  bool is_qkv_weight_loaded_{false};
+  bool is_qkv_bias_loaded_{false};
+  bool is_add_qkv_weight_loaded_{false};
+  bool is_add_qkv_bias_loaded_{false};
+  torch::Tensor fused_qkv_weight_{};
+  torch::Tensor fused_qkv_bias_{};
+  torch::Tensor fused_add_qkv_weight_{};
+  torch::Tensor fused_add_qkv_bias_{};
   DiTLinear to_out_{nullptr};
   DiTLinear to_add_out_{nullptr};

@@ -1119,8 +1196,8 @@ class FluxTransformerBlockImpl : public torch::nn::Module {
   FluxAttention attn_{nullptr};
   torch::nn::LayerNorm norm2_{nullptr};
   FeedForward ff_{nullptr};
-  torch::nn::LayerNorm norm2_context_{nullptr};
   FeedForward ff_context_{nullptr};
+  torch::nn::LayerNorm norm2_context_{nullptr};
   torch::TensorOptions options_;
 };
 TORCH_MODULE(FluxTransformerBlock);
