@@ -1,14 +1,18 @@
 #pragma once
+#include <absl/strings/match.h>
 #include <glog/logging.h>
 #include <torch/torch.h>
 
+#include <string>
+
 #include "chat_template/coded_chat_template.h"
 #include "layers/activation.h"
 #include "layers/attention/attention.h"
 #include "layers/attention/handler.h"
 #include "layers/embedding.h"
 #include "layers/linear.h"
 #include "layers/normalization.h"
+#include "layers/qkv_linear.h"
 #include "memory/kv_cache.h"
 #include "models/model_args.h"
 #include "models/model_registry.h"
@@ -87,26 +91,28 @@ class GemmaAttentionImpl : public torch::nn::Module {
     const int32_t world_size = parallel_args.world_size();
     const int64_t hidden_size = args.hidden_size();
     const int64_t n_heads = args.n_heads();
-    const int64_t head_dim = args.head_dim();
     const int64_t n_kv_heads = args.n_kv_heads().value_or(n_heads);
+    const int64_t head_dim = args.head_dim();
     const int64_t n_local_heads = n_heads / world_size;
-    const int64_t n_local_kv_heads = n_kv_heads / world_size;
+    const int64_t n_local_kv_heads =
+        std::max<int64_t>(1, n_kv_heads / world_size);
 
     // size for q, k, v
     qkv_sizes_ = {n_local_heads * head_dim,
                   n_local_kv_heads * head_dim,
                   n_local_kv_heads * head_dim};
 
     // register submodules
-    qkv_proj_ = register_module(
-        "qkv_proj",
-        ColumnParallelLinear(hidden_size,
-                             (n_heads + 2 * n_kv_heads) * head_dim,
-                             /*bias=*/false,
-                             /*gather_output=*/false,
-                             quant_args,
-                             parallel_args,
-                             options));
+    qkv_proj_ = register_module("qkv_proj",
+                                QKVColumnParallelLinear(hidden_size,
+                                                        n_heads,
+                                                        n_kv_heads,
+                                                        head_dim,
+                                                        /*bias=*/false,
+                                                        /*gather_output=*/false,
+                                                        quant_args,
+                                                        parallel_args,
+                                                        options));
 
     o_proj_ = register_module("o_proj",
                               RowParallelLinear(n_heads * head_dim,
@@ -141,7 +147,8 @@ class GemmaAttentionImpl : public torch::nn::Module {
   // load the weight from the checkpoint
   void load_state_dict(const StateDict& state_dict) {
     // call each submodule's load_state_dict function
-    qkv_proj_->load_state_dict(state_dict, {"q_proj.", "k_proj.", "v_proj."});
+    qkv_proj_->load_state_dict(
+        state_dict, {"q_proj.", "k_proj.", "v_proj."}, {"k_proj.", "v_proj."});
     o_proj_->load_state_dict(state_dict.select("o_proj."));
   }
 
@@ -152,7 +159,7 @@ class GemmaAttentionImpl : public torch::nn::Module {
 
  private:
   // parameter members, must be registered
-  ColumnParallelLinear qkv_proj_{nullptr};
+  QKVColumnParallelLinear qkv_proj_{nullptr};
 
   RowParallelLinear o_proj_{nullptr};
 
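For orientation, here is a minimal standalone sketch, not part of this commit, of how a fused QKV projection output of the shape implied by qkv_sizes_ can be split back into query, key and value chunks. The concrete sizes (n_local_heads = 8, n_local_kv_heads = 1, head_dim = 256, 4 tokens) are made-up placeholders, and split_with_sizes is the standard libtorch call; how QKVColumnParallelLinear itself consumes these sizes is not shown in this diff.

#include <torch/torch.h>

#include <vector>

int main() {
  // Hypothetical local sizes: n_local_heads = 8, n_local_kv_heads = 1, head_dim = 256.
  const std::vector<int64_t> qkv_sizes = {8 * 256, 1 * 256, 1 * 256};
  // Stand-in for a fused qkv projection output: [n_tokens, sum(qkv_sizes)].
  torch::Tensor qkv = torch::randn({4, 8 * 256 + 2 * 256});
  // Split along the last dimension into the three chunks.
  auto chunks = qkv.split_with_sizes(qkv_sizes, /*dim=*/-1);
  torch::Tensor query = chunks[0];  // [4, 2048]
  torch::Tensor key = chunks[1];    // [4, 256]
  torch::Tensor value = chunks[2];  // [4, 256]
  return 0;
}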
@@ -207,12 +214,16 @@ class GemmaDecoderLayerImpl : public torch::nn::Module {
   void load_state_dict(const StateDict& state_dict) {
     input_layernorm_->load_state_dict((state_dict.select_with_transform(
         "input_layernorm.",
-        [](torch::Tensor tensor) { return tensor + 1.0f; })));
+        [](const std::string_view& /*name*/, const torch::Tensor& tensor) {
+          return tensor + 1.0f;
+        })));
     mlp_->load_state_dict(state_dict.select("mlp."));
     post_attention_layernorm_->load_state_dict(
         (state_dict.select_with_transform(
             "post_attention_layernorm.",
-            [](torch::Tensor tensor) { return tensor + 1.0f; })));
+            [](const std::string_view& /*name*/, const torch::Tensor& tensor) {
+              return tensor + 1.0f;
+            })));
     self_attn_->load_state_dict(state_dict.select("self_attn."));
   }
   void verify_loaded_weights(const std::string& prefix) const {
@@ -301,7 +312,10 @@ class GemmaModelImpl : public torch::nn::Module {
     // GemmaRMSNorm is different from Llama's in that it multiplies
     // (1 + weight) to the output, instead of just weight.
     norm_->load_state_dict((state_dict.select_with_transform(
-        "norm.", [](torch::Tensor tensor) { return tensor + 1.0f; })));
+        "norm.",
+        [](const std::string_view& /*name*/, const torch::Tensor& tensor) {
+          return tensor + 1.0f;
+        })));
   }
 
   void verify_loaded_weights(const std::string& prefix) const {
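As a side note on the tensor + 1.0f transform used in the loaders above: the sketch below is a minimal, self-contained illustration (a hand-rolled RMSNorm, not this repo's Normalization layer) of why shifting the checkpoint weight by one at load time lets a standard weight-scaled RMSNorm reproduce Gemma's (1 + weight) scaling.

#include <torch/torch.h>

// Gemma-style RMSNorm: scales the normalized input by (1 + weight).
torch::Tensor gemma_rms_norm(const torch::Tensor& x,
                             const torch::Tensor& w,
                             double eps = 1e-6) {
  auto variance = x.pow(2).mean(/*dim=*/-1, /*keepdim=*/true);
  return x * torch::rsqrt(variance + eps) * (w + 1.0f);
}

// Standard RMSNorm: scales by the weight only.
torch::Tensor rms_norm(const torch::Tensor& x,
                       const torch::Tensor& w,
                       double eps = 1e-6) {
  auto variance = x.pow(2).mean(/*dim=*/-1, /*keepdim=*/true);
  return x * torch::rsqrt(variance + eps) * w;
}

int main() {
  torch::Tensor x = torch::randn({2, 8});
  torch::Tensor checkpoint_w = torch::randn({8});  // zero-centered, as stored by Gemma
  // Shifting the weight once at load time (tensor + 1.0f) makes the standard
  // RMSNorm match Gemma's (1 + weight) behavior exactly.
  TORCH_CHECK(torch::allclose(gemma_rms_norm(x, checkpoint_w),
                              rms_norm(x, checkpoint_w + 1.0f)));
  return 0;
}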