@@ -272,7 +272,11 @@ def __init__(
             )
 
         self.norm_1 = nn.Identity() if not config.norm_1 else config.norm_class(config.n_embd, eps=config.norm_eps)
-        self.attn = CausalSelfAttention(config, block_idx)
+        self.attn = (
+            CausalSelfAttention(config, block_idx)
+            if not config.latent_attention
+            else MultiheadLatentAttention(config, block_idx)
+        )
         self.post_attention_norm = (
             config.norm_class(config.n_embd, eps=config.norm_eps) if config.post_attention_norm else nn.Identity()
         )
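
For context, the new branch is taken when `config.latent_attention` is truthy. As a hedged illustration only (assuming the PR adds these fields to `litgpt.config.Config`; the values below are made up, not taken from any checkpoint), a config selecting the latent-attention path might look like:

    from litgpt.config import Config

    # Hypothetical values, for illustration only; the field names are the ones read by MultiheadLatentAttention.
    config = Config(
        n_embd=512,
        n_head=8,
        n_query_groups=8,
        latent_attention=True,   # routes Block.__init__ to MultiheadLatentAttention
        q_lora_rank=128,         # rank of the low-rank query projection
        kv_lora_rank=64,         # rank of the shared compressed KV latent
        qk_nope_head_dim=32,     # per-head query/key dims without RoPE
        qk_rope_head_dim=16,     # per-head query/key dims that receive RoPE
        v_head_dim=32,           # per-head value dim
    )
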
@@ -549,6 +553,146 @@ def _load_from_state_dict(self, state_dict: dict, prefix: str, *args: Any, **kwa
         super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)
 
 
+class MultiheadLatentAttention(nn.Module):
+    def __init__(self, config: Config, block_idx: int) -> None:
+        super().__init__()
+
+        self.q_a_proj = nn.Linear(config.n_embd, config.q_lora_rank, bias=config.attn_bias)
+        self.q_a_norm = RMSNorm(config.q_lora_rank, eps=config.norm_eps)
+        self.q_b_proj = nn.Linear(config.q_lora_rank, config.n_head * config.qk_head_dim, bias=config.bias)
+
+        self.kv_a_proj_with_mqa = nn.Linear(
+            config.n_embd, config.kv_lora_rank + config.qk_rope_head_dim, bias=config.attn_bias
+        )
+        self.kv_a_norm = RMSNorm(config.kv_lora_rank, eps=config.norm_eps)
+        self.kv_b_proj = nn.Linear(
+            config.kv_lora_rank,
+            config.n_query_groups * (config.qk_nope_head_dim + config.v_head_dim),
+            bias=config.bias,
+        )
+
+        # output projection
+        self.proj = nn.Linear(config.n_head * config.v_head_dim, config.n_embd, bias=config.bias)
+        # disabled by default
+        self.kv_cache: Optional[KVCache] = None
+
+        self.config = config
+        self.block_idx = block_idx
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        cos: torch.Tensor,
+        sin: torch.Tensor,
+        mask: Optional[torch.Tensor] = None,
+        input_pos: Optional[torch.Tensor] = None,
+        input_pos_maxp1: Optional[int] = None,
+    ) -> torch.Tensor:
+        # Notation:
+        # - B          | batch size
+        # - T          | time-step (sequence length)
+        # - C          | model's embeddings size (n_embd)
+        # - C*         | attention's embeddings size
+        # - hs         | head size
+        # - nh_(q,k,v) | number of heads for query, key and value
+        # - n_query_groups = nh_k = nh_v | number of query groups sharing key and value heads
+        #   alternative notation: num_kv_groups = n_query_groups
+        B, T, C = x.size()  # batch size, sequence length, embedding dimensionality (n_embd)
+
+        q = self.q_b_proj(self.q_a_norm(self.q_a_proj(x)))  # (B, T, n_head * qk_head_dim)
+        q = q.view(B, T, -1, self.config.qk_head_dim)  # (B, T, n_head, qk_head_dim)
+        q = q.transpose(1, 2)  # (B, n_head, T, qk_head_dim)
+        q_pass, q_rot = torch.split(q, [self.config.qk_nope_head_dim, self.config.qk_rope_head_dim], dim=-1)
+
+        compressed_kv = self.kv_a_proj_with_mqa(x)  # (B, T, kv_lora_rank + qk_rope_head_dim)
+        k_pass, k_rot = torch.split(compressed_kv, [self.config.kv_lora_rank, self.config.qk_rope_head_dim], dim=-1)
+
+        # Up-project the normalized KV latent to per-group no-PE keys and values.
+        k_pass = self.kv_b_proj(self.kv_a_norm(k_pass))  # (B, T, n_query_groups * (qk_nope_head_dim + v_head_dim))
+        k_pass = k_pass.view(B, T, self.config.n_query_groups, -1)
+        k_pass = k_pass.transpose(1, 2)  # (B, n_query_groups, T, qk_nope_head_dim + v_head_dim)
+
+        k_pass, v = torch.split(k_pass, [self.config.qk_nope_head_dim, self.config.v_head_dim], dim=-1)
+        k_rot = k_rot.view(B, 1, T, self.config.qk_rope_head_dim)  # (B, 1, T, qk_rope_head_dim)
+
+        # Unlike standard positional embeddings, rotary embeddings must be applied at every layer.
+        q_roped = apply_rope(q_rot, cos, sin)
+        k_roped = apply_rope(k_rot, cos, sin)
+        k_roped = k_roped.expand(*k_pass.shape[:-1], -1)  # (B, n_query_groups, T, qk_rope_head_dim)
+
+        q = torch.cat((q_pass, q_roped), dim=-1)
+        k = torch.cat((k_pass, k_roped), dim=-1)
+
+        # Apply kv-cache during inference.
+        if input_pos is not None:
+            if not isinstance(self.kv_cache, KVCache):
+                raise TypeError("You need to call `gpt.set_kv_cache()`")
+            k, v = self.kv_cache(input_pos, k, v)
+            if input_pos_maxp1 is not None:
+                # Subselect along sequence dimension
+                k = k[..., :input_pos_maxp1, :]
+                v = v[..., :input_pos_maxp1, :]
+            # k, v: (B, nh_k, input_pos_maxp1, hs)
+            # If input_pos_maxp1 is None -> max_seq_length
+
+        # Grouped queries: balance the number of heads across all three matrices.
+        # NOTE: flash attention requires it in training mode.
+        # Multi-query: this step can be skipped since there is only 1 head, allowing us to use broadcasting.
+        if self.config.n_query_groups != self.config.n_head and (input_pos is None or self.config.n_query_groups != 1):
+            q_per_kv = self.config.n_head // self.config.n_query_groups
+            k = k.repeat_interleave(q_per_kv, dim=1)  # (B, nh_q, T, hs)
+            v = v.repeat_interleave(q_per_kv, dim=1)  # (B, nh_q, T, hs)
+
+        # Efficient attention using Flash Attention CUDA kernels.
+        # NOTE: efficient implementation is disabled if `mask` is not None or softcapping is enabled.
+        # ↓ (B, nh, T, hs) @ (B, nh, T, hs).mT --> (B, nh, T, T) @ (B, nh, T, hs) --> (B, nh, T, hs)
+        y = self.scaled_dot_product_attention(q, k, v, mask)
+
+        # Re-assemble all head outputs side by side.
+        y = y.reshape(B, T, self.config.n_head * self.config.v_head_dim)
+
+        # Output projection.
+        return self.proj(y)  # (B, T, C)
+
+    def scaled_dot_product_attention(
+        self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, mask: Optional[torch.Tensor] = None
+    ) -> torch.Tensor:
+        scale = 1.0 / math.sqrt(self.config.attention_scores_scalar or self.config.qk_head_dim)
+
+        # with softcapping we cannot use SDPA
+        if self.config.attention_logit_softcapping is not None:
+            scores = q @ k.mT * scale
+            scores = do_softcapping(scores, self.config.attention_logit_softcapping)
+            if mask is None:
+                mask = torch.ones(q.size(2), q.size(2), dtype=q.dtype, device=q.device).triu(diagonal=1)
+                mask.masked_fill_(mask.bool(), torch.finfo(q.dtype).min)
+            scores = scores + mask
+            scores = F.softmax(scores, dim=-1, dtype=torch.float).to(dtype=q.dtype)
+            y = scores @ v
+        else:
+            y = F.scaled_dot_product_attention(
+                q, k, v, attn_mask=mask, dropout_p=0.0, scale=scale, is_causal=mask is None
+            )
+        return y.transpose(1, 2)
+
+    def build_kv_cache(
+        self,
+        batch_size: int,
+        max_seq_length: int,
+        rope_cache_length: Optional[int] = None,
+        device: Optional[torch.device] = None,
+        dtype: Optional[torch.dtype] = None,
+    ) -> "KVCache":
+        v_shape = (batch_size, self.config.n_head, max_seq_length, self.config.v_head_dim)
+        k_shape = (batch_size, self.config.n_head, max_seq_length, self.config.qk_head_dim)
+
+        if rope_cache_length is not None:
+            print("Warning: `rope_cache_length` has no effect on MultiheadLatentAttention!")
+        if self.config.rotary_percentage != 1.0:
+            print("Warning: `rotary_percentage` has no effect on MultiheadLatentAttention!")
+
+        return KVCache(k_shape, v_shape, device=device, dtype=dtype)
+
+
 class GptNeoxMLP(nn.Module):
     def __init__(self, config: Config, intermediate_size: Optional[int] = None) -> None:
         super().__init__()
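
To make the tensor bookkeeping in `MultiheadLatentAttention.forward` easier to follow, here is a minimal standalone sketch of the low-rank query and shared-KV projections. It uses made-up hyperparameters and bare `nn.Linear` layers (no `Config`, no RMSNorm on the latents, no RoPE, no KV cache), so it only illustrates the shape flow, not the PR's actual module.

    import torch
    import torch.nn as nn

    # Hypothetical hyperparameters, chosen only for illustration (not from any real checkpoint).
    n_embd, n_head, n_query_groups = 512, 8, 8
    q_lora_rank, kv_lora_rank = 128, 64
    qk_nope_head_dim, qk_rope_head_dim, v_head_dim = 32, 16, 32
    qk_head_dim = qk_nope_head_dim + qk_rope_head_dim  # 48

    B, T = 2, 10
    x = torch.randn(B, T, n_embd)

    # Low-rank query path: down-project to q_lora_rank, then up-project to per-head queries.
    q_a = nn.Linear(n_embd, q_lora_rank, bias=False)
    q_b = nn.Linear(q_lora_rank, n_head * qk_head_dim, bias=False)
    q = q_b(q_a(x)).view(B, T, n_head, qk_head_dim).transpose(1, 2)  # (B, n_head, T, qk_head_dim)

    # Shared KV path: one compressed latent per token plus a small rotary-only key slice.
    kv_a = nn.Linear(n_embd, kv_lora_rank + qk_rope_head_dim, bias=False)
    kv_b = nn.Linear(kv_lora_rank, n_query_groups * (qk_nope_head_dim + v_head_dim), bias=False)
    compressed = kv_a(x)  # (B, T, kv_lora_rank + qk_rope_head_dim)
    k_latent, k_rot = compressed.split([kv_lora_rank, qk_rope_head_dim], dim=-1)
    kv = kv_b(k_latent).view(B, T, n_query_groups, qk_nope_head_dim + v_head_dim).transpose(1, 2)
    k_pass, v = kv.split([qk_nope_head_dim, v_head_dim], dim=-1)

    print(q.shape)       # (2, 8, 10, 48)
    print(k_pass.shape)  # (2, 8, 10, 32) -- concatenated with the broadcast RoPE keys to form k
    print(k_rot.shape)   # (2, 10, 16)    -- reshaped to (B, 1, T, 16), rotated, then broadcast over heads
    print(v.shape)       # (2, 8, 10, 32)

Only the `kv_lora_rank + qk_rope_head_dim` latent needs to be cached per token in principle, which is the motivation for the compressed KV path; the module above still materializes full per-head keys and values before calling the KV cache.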