     TensorParallelEmbedding,
     TensorParallelRowLinear,
     get_linear,
+    Fp8Linear,
 )
 from text_generation_server.layers.attention import (
     Seqlen,
     attention,
-    paged_attention,
+    paged_attention_mla,
     set_block_mapping,
     HPUPagedAttentionMetadata,
 )
 import habana_frameworks.torch as htorch
 
 
+def get_and_maybe_dequant_weights(layer: torch.nn.Module) -> torch.Tensor:
+    if isinstance(layer, Fp8Linear):
+        eye = torch.eye(
+            layer.qweight.shape[-1], dtype=torch.bfloat16, device=layer.qweight.device
+        )
+        dequant_weights = layer(eye)
+        del eye
+        # standardize to (output, input)
+        return dequant_weights.T
+    return layer.weight
+
+
 class DeepseekV2Config(PretrainedConfig):
     def __init__(
         self,
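Note on the `get_and_maybe_dequant_weights` helper added above: a bias-free linear layer computes `x @ W.T`, so running an identity matrix through an `Fp8Linear` returns the dequantized `W.T` in bf16, and one more transpose restores the usual `(out_features, in_features)` layout. A minimal sketch of the same trick, with a stock `nn.Linear` standing in for `Fp8Linear` (the sizes are illustrative only):

    import torch

    # Any bias-free linear computes y = x @ W.T, so feeding it the identity
    # matrix returns W.T; this is how get_and_maybe_dequant_weights recovers
    # a plain weight from a quantized layer.
    linear = torch.nn.Linear(in_features=8, out_features=4, bias=False)

    eye = torch.eye(linear.in_features, dtype=linear.weight.dtype)
    recovered = linear(eye).T  # back to (out_features, in_features), like .weight

    assert torch.allclose(recovered, linear.weight)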
@@ -246,6 +259,45 @@ def __init__(
             0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
         ).repeat_interleave(self.num_groups)
 
+        kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj.linear).T
+        kv_b_proj_weight = kv_b_proj_weight.view(
+            self.kv_lora_rank,
+            self.num_heads,
+            self.qk_nope_head_dim + self.value_head_size,
+        )
+
+        W_UK, W_UV = kv_b_proj_weight.split(
+            [self.qk_nope_head_dim, self.value_head_size], dim=-1
+        )
+        # Convert from (L, N, V) to (N, L, V)
+        self.W_UV = W_UV.transpose(0, 1)
+        # Convert from (L, N, P) to (N, P, L)
+        self.W_UK_T = W_UK.permute(1, 2, 0)
+
+    def _q_proj_and_k_up_proj(self, x):
+        q_proj = self.q_proj if self.q_lora_rank is None else self.q_b_proj
+        q_nope, q_pe = (
+            q_proj(x)
+            .view(-1, self.num_heads, self.head_size)
+            .split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
+        )
+
+        # Convert from (B, N, P) to (N, B, P)
+        q_nope = q_nope.transpose(0, 1)
+        # Multiply (N, B, P) x (N, P, L) -> (N, B, L)
+        ql_nope = torch.bmm(q_nope, self.W_UK_T)
+        # Convert from (N, B, L) to (B, N, L)
+        return ql_nope.transpose(0, 1), q_pe
+
+    def _v_up_proj_and_o_proj(self, x):
+        # Convert from (B, N, L) to (N, B, L)
+        x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)
+        # Multiply (N, B, L) x (N, L, V) -> (N, B, V)
+        x = torch.bmm(x, self.W_UV)
+        # Convert from (N, B, V) to (B, N * V)
+        x = x.transpose(0, 1).reshape(-1, self.num_heads * self.value_head_size)
+        return self.o_proj(x)
+
     def forward(
         self,
         hidden_states: torch.Tensor,
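The `W_UK_T`/`W_UV` buffers and the two helpers added above implement the MLA weight-absorption identity used on the decode path: the no-RoPE key is a linear up-projection of the cached latent, so `q_nope · k_nope` equals `(q_nope @ W_UK) · c_kv`, and scores can be taken directly against the latent cache; the value up-projection is likewise deferred until after attention in `_v_up_proj_and_o_proj`. A small self-check of the score equivalence on toy shapes (all names and sizes below are illustrative, not taken from the model config):

    import torch

    # Toy sizes: L = kv_lora_rank, N = heads, P = qk_nope_head_dim,
    # V = value_head_size, T = tokens.
    L, N, P, V, T = 16, 4, 8, 8, 3

    # kv_b_proj weight reshaped as in __init__ above: (L, N, P + V).
    kv_b = torch.randn(L, N, P + V)
    W_UK, W_UV = kv_b.split([P, V], dim=-1)   # (L, N, P) and (L, N, V)
    W_UK_T = W_UK.permute(1, 2, 0)            # (N, P, L)

    q_nope = torch.randn(T, N, P)             # no-RoPE part of the queries
    c_kv = torch.randn(T, L)                  # normalized latent, one per cached token

    # Absorbed path (decode): fold W_UK into the query, score against the latent.
    ql_nope = torch.bmm(q_nope.transpose(0, 1), W_UK_T)          # (N, T, L)
    scores_absorbed = torch.einsum("nbl,tl->nbt", ql_nope, c_kv)

    # Reference path (prefill): materialize full no-RoPE keys from the latent.
    k_nope = torch.einsum("tl,lnp->tnp", c_kv, W_UK)              # (T, N, P)
    scores_full = torch.einsum("bnp,tnp->nbt", q_nope, k_nope)

    assert torch.allclose(scores_absorbed, scores_full, atol=1e-5)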
@@ -258,28 +310,28 @@ def forward(
         hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
     ):
         if self.q_lora_rank is None:
-            query = self.q_proj(hidden_states)
+            hidden_states_or_q_c = hidden_states
         else:
-            query = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))[0])
-        query = query.view(-1, self.num_heads, self.head_size)
-
-        _, query_pe = torch.split(
-            query, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
-        )
+            hidden_states_or_q_c = self.q_a_layernorm(self.q_a_proj(hidden_states))[0]
 
         compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
         compressed_kv, key_pe = torch.split(
             compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
         )
 
         key_pe = key_pe.view(-1, 1, self.qk_rope_head_dim)
-        kv = self.kv_b_proj(self.kv_a_layernorm(compressed_kv.contiguous())[0]).view(
-            -1, self.num_key_value_heads, self.qk_nope_head_dim + self.value_head_size
-        )
+        kv_c_normed = self.kv_a_layernorm(compressed_kv.contiguous())[0]
 
-        key_nope, value = torch.split(
-            kv, [self.qk_nope_head_dim, self.value_head_size], dim=-1
-        )
+        # Prefill
+        if cu_seqlen_prefill is not None:
+            q_proj = self.q_proj if self.q_lora_rank is None else self.q_b_proj
+            query = q_proj(hidden_states_or_q_c)
+            query = query.view(-1, self.num_heads, self.head_size)
+            query_nope, query_pe = torch.split(
+                query, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
+            )
+        else:
+            query_nope, query_pe = self._q_proj_and_k_up_proj(hidden_states_or_q_c)
 
         batch_size, heads, head_dim = query_pe.shape
         query_pe = (
@@ -294,33 +346,47 @@ def forward(
             .reshape(batch_size, heads, head_dim)
         )
         self.rotary_emb(query_pe, key_pe, cos, sin)
-
-        query[..., self.qk_nope_head_dim :] = query_pe
-        key = torch.empty_like(query)
-        key[..., : self.qk_nope_head_dim] = key_nope
-        key[..., self.qk_nope_head_dim :] = key_pe
-
-        # We need to pad the heads because Flash Attention does not support
-        # qk and v with different head sizes.
-        query = torch.nn.functional.pad(
-            query, (0, self.head_pad_size - self.head_size), value=0
-        )
-        key = torch.nn.functional.pad(
-            key, (0, self.head_pad_size - self.head_size), value=0
-        )
-        value = torch.nn.functional.pad(
-            value, (0, self.head_pad_size - self.value_head_size), value=0
+        latent_vec_k = torch.concat(
+            (kv_c_normed, key_pe.view(-1, self.qk_rope_head_dim)), dim=-1
         )
+        latent_vec_k = latent_vec_k.view(-1, self.qk_rope_head_dim + self.kv_lora_rank)
+
+        latent_vec_k = latent_vec_k.unflatten(0, (slots.size(0), -1))
 
         kv_cache.store(
-            key=key,
-            value=value,
+            key=latent_vec_k,
+            value=None,
             slots=slots,
             kv_scales=self.kv_scales,
         )
 
-        # Prefill
         if cu_seqlen_prefill is not None:
+            kv = self.kv_b_proj(kv_c_normed).view(
+                -1,
+                self.num_key_value_heads,
+                self.qk_nope_head_dim + self.value_head_size,
+            )
+
+            key_nope, value = torch.split(
+                kv, [self.qk_nope_head_dim, self.value_head_size], dim=-1
+            )
+            query[..., self.qk_nope_head_dim :] = query_pe
+            key = torch.empty_like(query)
+            key[..., : self.qk_nope_head_dim] = key_nope
+            key[..., self.qk_nope_head_dim :] = key_pe
+
+            # We need to pad the heads because Flash Attention does not support
+            # qk and v with different head sizes.
+            query = torch.nn.functional.pad(
+                query, (0, self.head_pad_size - self.head_size), value=0
+            )
+            key = torch.nn.functional.pad(
+                key, (0, self.head_pad_size - self.head_size), value=0
+            )
+            value = torch.nn.functional.pad(
+                value, (0, self.head_pad_size - self.value_head_size), value=0
+            )
+
             # flash attention
             attn_output = attention(
                 query=query,
@@ -331,24 +397,26 @@ def forward(
                 seqlen=seqlen,
                 softmax_scale=self.softmax_scale,
             )
-        # Decode
+            attn_output = attn_output[..., : self.value_head_size]
+
+            return self.o_proj(
+                attn_output.reshape(-1, self.num_heads * self.value_head_size)
+            )
         else:
-            attn_output = paged_attention(
+            # Decode
+            query = torch.cat([query_nope, query_pe], dim=-1)
+            attn_output = paged_attention_mla(
                 query,
                 kv_cache,
                 self.kv_head_mapping,
                 self.softmax_scale,
                 seqlen,
                 kv_scales=self.kv_scales,
                 hpu_attention_meta=hpu_attention_meta,
+                kv_lora_rank=self.kv_lora_rank,
             )
-
-        # Remove padding.
-        attn_output = attn_output[..., : self.value_head_size]
-
-        return self.o_proj(
-            attn_output.reshape(-1, self.num_heads * self.value_head_size)
-        )
+            attn_output = self._v_up_proj_and_o_proj(attn_output)
+            return attn_output
 
 
 class DeepseekV2MLP(nn.Module):
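Taken together, the change stores one latent vector of width `kv_lora_rank + qk_rope_head_dim` per token in the paged KV cache (with `value=None`), reconstructs full keys and values from that latent only on the prefill path, and on the decode path scores absorbed queries directly against the latent via `paged_attention_mla` before up-projecting the output with `_v_up_proj_and_o_proj`. The per-token cache footprint therefore drops from two padded per-head tensors (`2 * num_key_value_heads * head_pad_size` elements) to a single latent row. A rough back-of-the-envelope, using DeepSeek-V2-sized numbers that are assumptions here rather than values read from this file:

    # Illustrative per-token KV-cache width, in elements. The dimensions below
    # (kv_lora_rank=512, qk_rope_head_dim=64, 128 heads padded to head size 256)
    # are assumed DeepSeek-V2-scale values, not read from this diff.
    latent_width = 512 + 64                  # one latent vector per token
    padded_kv_width = 2 * 128 * 256          # padded K and V for every head
    print(latent_width, padded_kv_width, round(padded_kv_width / latent_width))  # 576 65536 114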