
Commit dff774e

[feat] support tensor parallelism for MQA/GQA models when num_kv_heads < world_size (#137)
1 parent d176e87 commit dff774e

File tree

13 files changed: +206 additions, -54 deletions


src/common/slice.h

Lines changed: 2 additions & 2 deletions
@@ -62,13 +62,13 @@ class Slice final {
 
 // help comparison operators between slices and std::vector
 template <typename T>
-bool operator==(const Slice<T>& lhs, const std::vector<T>& rhs) {
+inline bool operator==(const Slice<T>& lhs, const std::vector<T>& rhs) {
   return lhs.size() == rhs.size() &&
          std::equal(lhs.begin(), lhs.end(), rhs.begin());
 }
 
 template <typename T>
-bool operator==(const std::vector<T>& lhs, const Slice<T>& rhs) {
+inline bool operator==(const std::vector<T>& lhs, const Slice<T>& rhs) {
   return lhs.size() == rhs.size() &&
          std::equal(lhs.begin(), lhs.end(), rhs.begin());
 }

src/engine/llm_engine.cpp

Lines changed: 1 addition & 1 deletion
@@ -141,7 +141,7 @@ bool LLMEngine::init_model(const std::string& model_weights_path) {
   const int world_size = static_cast<int>(workers_.size());
   const int64_t n_heads = args_.n_heads();
   const int64_t n_kv_heads = args_.n_kv_heads().value_or(n_heads);
-  n_local_kv_heads_ = n_kv_heads / world_size;
+  n_local_kv_heads_ = std::max<int64_t>(1, n_kv_heads / world_size);
   head_dim_ = args_.head_dim();
   dtype_ = parse_dtype(args_.dtype(), options_.devices()[0]);
 

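Why the clamp matters: when an MQA/GQA checkpoint is sharded across more workers than it has KV heads, the old integer division produced zero local KV heads (and therefore a zero-sized KV cache). Clamping to one keeps a full KV head per worker, which the new QKVColumnParallelLinear backs by replicating the KV projection weights. A minimal arithmetic sketch with illustrative values (not taken from the commit):

#include <algorithm>
#include <cstdint>
#include <iostream>

int main() {
  // hypothetical MQA model: 1 KV head, 4-way tensor parallelism
  const int64_t n_kv_heads = 1;
  const int64_t world_size = 4;

  const int64_t before = n_kv_heads / world_size;                       // 0
  const int64_t after = std::max<int64_t>(1, n_kv_heads / world_size);  // 1

  // replication ratio used when there are fewer KV heads than workers
  const int64_t kv_replication_ratio = world_size / n_kv_heads;         // 4

  std::cout << "old n_local_kv_heads = " << before
            << ", new n_local_kv_heads = " << after
            << ", kv_replication_ratio = " << kv_replication_ratio << '\n';
  return 0;
}
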
src/layers/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
@@ -6,9 +6,11 @@ cc_library(
   linear
   HDRS
     linear.h
+    qkv_linear.h
     linear_impl.h
   SRCS
     linear.cpp
+    qkv_linear.cpp
     linear_impl.cpp
   DEPS
     :state_dict

src/layers/qkv_linear.cpp

Lines changed: 64 additions & 0 deletions
@@ -0,0 +1,64 @@
+#include "qkv_linear.h"
+
+#include <absl/strings/match.h>
+#include <glog/logging.h>
+#include <torch/torch.h>
+
+namespace llm {
+QKVColumnParallelLinearImpl::QKVColumnParallelLinearImpl(
+    int64_t hidden_size,
+    int64_t n_heads,
+    int64_t n_kv_heads,
+    int64_t head_dim,
+    bool bias,
+    bool gather_output,
+    const QuantArgs& quant_args,
+    const ParallelArgs& parallel_args,
+    const torch::TensorOptions& options) {
+  // calculate logical kv heads with support of MQA/GQA
+  const int32_t world_size = parallel_args.world_size();
+  if (n_kv_heads >= world_size) {
+    // partition kv heads evenly across world_size for MHA
+    CHECK_EQ(n_kv_heads % world_size, 0)
+        << "kv_heads can't be partitioned evenly across world_size";
+    kv_replication_ratio_ = 1;
+  } else {
+    // replicate kv heads evenly across world_size for GQA/MQA
+    CHECK_EQ(world_size % n_kv_heads, 0)
+        << "kv heads can't be replicated evenly across world_size";
+    kv_replication_ratio_ = world_size / n_kv_heads;
+    n_kv_heads = world_size;
+  }
+
+  parallel_linear_ = ColumnParallelLinear(hidden_size,
+                                          (n_heads + 2 * n_kv_heads) * head_dim,
+                                          bias,
+                                          gather_output,
+                                          quant_args,
+                                          parallel_args,
+                                          options);
+}
+
+// special load_state_dict for fused cases
+void QKVColumnParallelLinearImpl::load_state_dict(
+    const StateDict& state_dict,
+    const std::vector<std::string_view>& prefixes,
+    const std::vector<std::string_view>& kv_prefixes) {
+  if (kv_replication_ratio_ > 1) {
+    // replicate kv heads
+    auto kv_replicated_state_dict = state_dict.select_with_transform(
+        "", [&](const std::string_view& name, const torch::Tensor& tensor) {
+          for (const auto& kv_prefix : kv_prefixes) {
+            if (absl::StartsWith(name, kv_prefix)) {
+              return tensor.repeat({kv_replication_ratio_, 1});
+            }
+          }
+          return tensor;
+        });
+    parallel_linear_->load_state_dict(kv_replicated_state_dict, prefixes);
+  } else {
+    parallel_linear_->load_state_dict(state_dict, prefixes);
+  }
+}
+
+}  // namespace llm

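The replication above relies on torch::Tensor::repeat tiling the K/V projection weight along its output dimension, so that the subsequent column-parallel shard on every rank still contains a complete KV head. A standalone libtorch sketch with made-up shapes (not part of the commit) showing the effect of tensor.repeat({kv_replication_ratio_, 1}):

#include <torch/torch.h>

#include <iostream>

int main() {
  // hypothetical GQA setup: 2 KV heads, head_dim 4, hidden_size 8, 4 workers
  const int64_t n_kv_heads = 2, head_dim = 4, hidden_size = 8, world_size = 4;
  const int64_t kv_replication_ratio = world_size / n_kv_heads;  // 2

  // k_proj.weight has shape [n_kv_heads * head_dim, hidden_size] = [8, 8]
  auto k_weight = torch::randn({n_kv_heads * head_dim, hidden_size});

  // tile along dim 0, as the transform in load_state_dict does
  auto replicated = k_weight.repeat({kv_replication_ratio, 1});

  // [16, 8]: splitting dim 0 into world_size chunks leaves every rank
  // with a full copy of one KV head's projection rows
  std::cout << replicated.size(0) << " x " << replicated.size(1) << '\n';
  return 0;
}
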
src/layers/qkv_linear.h

Lines changed: 48 additions & 0 deletions
@@ -0,0 +1,48 @@
+#pragma once
+
+#include <glog/logging.h>
+#include <torch/torch.h>
+
+#include "linear.h"
+#include "model_loader/state_dict.h"
+#include "model_parallel/parallel_args.h"
+#include "quantization/quant_args.h"
+
+namespace llm {
+
+// a thin wrapper to handle state_dict loading for QKV with
+// support of MQA/GQA
+class QKVColumnParallelLinearImpl : public torch::nn::Module {
+ public:
+  QKVColumnParallelLinearImpl(int64_t hidden_size,
+                              int64_t n_heads,
+                              int64_t n_kv_heads,
+                              int64_t head_dim,
+                              bool bias,
+                              bool gather_output,
+                              const QuantArgs& quant_args,
+                              const ParallelArgs& parallel_args,
+                              const torch::TensorOptions& options);
+
+  torch::Tensor forward(torch::Tensor input) const {
+    return parallel_linear_->forward(input);
+  }
+
+  // special load_state_dict for fused cases
+  void load_state_dict(const StateDict& state_dict,
+                       const std::vector<std::string_view>& prefixes,
+                       const std::vector<std::string_view>& kv_prefixes);
+
+  void verify_loaded_weights(const std::string& prefix = "") const {
+    parallel_linear_->verify_loaded_weights(prefix);
+  }
+
+ private:
+  ColumnParallelLinear parallel_linear_{nullptr};
+
+  // replication ratio of kv heads for MQA/GQA cases
+  int64_t kv_replication_ratio_ = 1;
+};
+TORCH_MODULE(QKVColumnParallelLinear);
+
+}  // namespace llm

src/model_loader/state_dict.cpp

Lines changed: 4 additions & 3 deletions
@@ -1,6 +1,7 @@
 #include "state_dict.h"
 
 #include <ATen/core/TensorBody.h>
+#include <absl/strings/match.h>
 #include <caffe2/serialize/inline_container.h>
 #include <glog/logging.h>
 #include <torch/csrc/jit/serialization/import_read.h>
@@ -183,7 +184,8 @@ torch::Tensor StateDict::get_tensor(const std::string_view& tensor_name) const {
     return torch::Tensor{nullptr};
   }
   // apply transform function if exists
-  return transform_func_ ? transform_func_(it->second) : it->second;
+  return transform_func_ ? transform_func_(tensor_name, it->second)
+                         : it->second;
 }
 
 torch::Tensor StateDict::get_sharded_tensor(const std::string_view& tensor_name,
@@ -231,8 +233,7 @@ torch::Tensor StateDict::get_sharded_tensor(const std::string_view& tensor_name,
 StateDict StateDict::select(const std::string_view& prefix) const {
   std::unordered_map<std::string, torch::Tensor> selected;
   for (const auto& [name, tensor] : dict_) {
-    std::size_t found = name.find(prefix);
-    if (found == 0) {
+    if (absl::StartsWith(name, prefix)) {
       selected[name.substr(prefix.length())] = tensor;
     }
   }

src/model_loader/state_dict.h

Lines changed: 2 additions & 1 deletion
@@ -44,7 +44,8 @@ class StateDict final {
 
   // select all tensors whose name starts with prefix and apply the transform
   // for each tensor.
-  using TensorTransform = std::function<torch::Tensor(torch::Tensor)>;
+  using TensorTransform = std::function<torch::Tensor(const std::string_view&,
+                                                      const torch::Tensor&)>;
   StateDict select_with_transform(const std::string_view& prefix,
                                   TensorTransform transform_func) const;
 

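With the widened TensorTransform signature, a transform can branch on the tensor name instead of applying the same function to every selected tensor; this is what enables the KV-only replication in qkv_linear.cpp. A minimal illustrative caller (the prefix, condition, and helper name are hypothetical, and the StateDict is assumed to be loaded elsewhere):

#include <absl/strings/match.h>
#include <torch/torch.h>

#include "model_loader/state_dict.h"

namespace llm {

// hypothetical helper: transpose only the tensors under "mlp." whose
// (prefix-stripped) name starts with "c_fc.", pass everything else through
StateDict transpose_c_fc(const StateDict& state_dict) {
  return state_dict.select_with_transform(
      "mlp.",
      [](const std::string_view& name, const torch::Tensor& tensor) {
        if (absl::StartsWith(name, "c_fc.")) {
          return tensor.t();
        }
        return tensor;
      });
}

}  // namespace llm
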
src/models/huggingface/baichuan.h

Lines changed: 2 additions & 1 deletion
@@ -437,7 +437,8 @@ class BaichuanForCausalLMImpl : public torch::nn::Module {
       // Baichuan2 normalizes the head weights:
       // https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat/blob/main/modeling_baichuan.py#L508
       lm_head_->load_state_dict(state_dict.select_with_transform(
-          "lm_head.", [](torch::Tensor tensor) {
+          "lm_head.",
+          [](const std::string_view& /*name*/, const torch::Tensor& tensor) {
             return torch::nn::functional::normalize(tensor);
           }));
     } else {

src/models/huggingface/gemma.h

Lines changed: 30 additions & 16 deletions
@@ -1,14 +1,18 @@
 #pragma once
+#include <absl/strings/match.h>
 #include <glog/logging.h>
 #include <torch/torch.h>
 
+#include <string>
+
 #include "chat_template/coded_chat_template.h"
 #include "layers/activation.h"
 #include "layers/attention/attention.h"
 #include "layers/attention/handler.h"
 #include "layers/embedding.h"
 #include "layers/linear.h"
 #include "layers/normalization.h"
+#include "layers/qkv_linear.h"
 #include "memory/kv_cache.h"
 #include "models/model_args.h"
 #include "models/model_registry.h"
@@ -87,26 +91,28 @@ class GemmaAttentionImpl : public torch::nn::Module {
     const int32_t world_size = parallel_args.world_size();
     const int64_t hidden_size = args.hidden_size();
     const int64_t n_heads = args.n_heads();
-    const int64_t head_dim = args.head_dim();
     const int64_t n_kv_heads = args.n_kv_heads().value_or(n_heads);
+    const int64_t head_dim = args.head_dim();
     const int64_t n_local_heads = n_heads / world_size;
-    const int64_t n_local_kv_heads = n_kv_heads / world_size;
+    const int64_t n_local_kv_heads =
+        std::max<int64_t>(1, n_kv_heads / world_size);
 
     // size for q, k, v
     qkv_sizes_ = {n_local_heads * head_dim,
                   n_local_kv_heads * head_dim,
                   n_local_kv_heads * head_dim};
 
     // register submodules
-    qkv_proj_ = register_module(
-        "qkv_proj",
-        ColumnParallelLinear(hidden_size,
-                             (n_heads + 2 * n_kv_heads) * head_dim,
-                             /*bias=*/false,
-                             /*gather_output=*/false,
-                             quant_args,
-                             parallel_args,
-                             options));
+    qkv_proj_ = register_module("qkv_proj",
+                                QKVColumnParallelLinear(hidden_size,
+                                                        n_heads,
+                                                        n_kv_heads,
+                                                        head_dim,
+                                                        /*bias=*/false,
+                                                        /*gather_output=*/false,
+                                                        quant_args,
+                                                        parallel_args,
+                                                        options));
 
     o_proj_ = register_module("o_proj",
                               RowParallelLinear(n_heads * head_dim,
@@ -141,7 +147,8 @@ class GemmaAttentionImpl : public torch::nn::Module {
   // load the weight from the checkpoint
   void load_state_dict(const StateDict& state_dict) {
     // call each submodule's load_state_dict function
-    qkv_proj_->load_state_dict(state_dict, {"q_proj.", "k_proj.", "v_proj."});
+    qkv_proj_->load_state_dict(
+        state_dict, {"q_proj.", "k_proj.", "v_proj."}, {"k_proj.", "v_proj."});
     o_proj_->load_state_dict(state_dict.select("o_proj."));
   }
 
@@ -152,7 +159,7 @@ class GemmaAttentionImpl : public torch::nn::Module {
 
  private:
   // parameter members, must be registered
-  ColumnParallelLinear qkv_proj_{nullptr};
+  QKVColumnParallelLinear qkv_proj_{nullptr};
 
   RowParallelLinear o_proj_{nullptr};
 
@@ -207,12 +214,16 @@ class GemmaDecoderLayerImpl : public torch::nn::Module {
   void load_state_dict(const StateDict& state_dict) {
     input_layernorm_->load_state_dict((state_dict.select_with_transform(
         "input_layernorm.",
-        [](torch::Tensor tensor) { return tensor + 1.0f; })));
+        [](const std::string_view& /*name*/, const torch::Tensor& tensor) {
+          return tensor + 1.0f;
+        })));
     mlp_->load_state_dict(state_dict.select("mlp."));
     post_attention_layernorm_->load_state_dict(
         (state_dict.select_with_transform(
            "post_attention_layernorm.",
-            [](torch::Tensor tensor) { return tensor + 1.0f; })));
+            [](const std::string_view& /*name*/, const torch::Tensor& tensor) {
+              return tensor + 1.0f;
+            })));
     self_attn_->load_state_dict(state_dict.select("self_attn."));
   }
   void verify_loaded_weights(const std::string& prefix) const {
@@ -301,7 +312,10 @@ class GemmaModelImpl : public torch::nn::Module {
     // GemmaRMSNorm is different from Llama's in that it multiplies
     // (1 + weight) to the output, instead of just weight.
     norm_->load_state_dict((state_dict.select_with_transform(
-        "norm.", [](torch::Tensor tensor) { return tensor + 1.0f; })));
+        "norm.",
+        [](const std::string_view& /*name*/, const torch::Tensor& tensor) {
+          return tensor + 1.0f;
+        })));
   }
 
   void verify_loaded_weights(const std::string& prefix) const {

src/models/huggingface/gpt2.h

Lines changed: 16 additions & 4 deletions
@@ -57,9 +57,15 @@ class GPT2MLPImpl : public torch::nn::Module {
     // GPT-2 implementation uses Conv1D instead of Linear. As a result, we
     // need to transpose the weight.
     c_fc_->load_state_dict(state_dict.select_with_transform(
-        "c_fc.", [](torch::Tensor tensor) { return tensor.t(); }));
+        "c_fc.",
+        [](const std::string_view& /*name*/, const torch::Tensor& tensor) {
+          return tensor.t();
+        }));
     c_proj_->load_state_dict(state_dict.select_with_transform(
-        "c_proj.", [](torch::Tensor tensor) { return tensor.t(); }));
+        "c_proj.",
+        [](const std::string_view& /*name*/, const torch::Tensor& tensor) {
+          return tensor.t();
+        }));
   }
 
   void verify_loaded_weights(const std::string& prefix) const {
@@ -134,9 +140,15 @@ class GPT2AttentionImpl : public torch::nn::Module {
     // GPT-2 implementation uses Conv1D instead of Linear. As a result, we
     // need to transpose the weight.
     c_attn_->load_state_dict(state_dict.select_with_transform(
-        "c_attn.", [](torch::Tensor tensor) { return tensor.t(); }));
+        "c_attn.",
+        [](const std::string_view& /*name*/, const torch::Tensor& tensor) {
+          return tensor.t();
+        }));
     c_proj_->load_state_dict(state_dict.select_with_transform(
-        "c_proj.", [](torch::Tensor tensor) { return tensor.t(); }));
+        "c_proj.",
+        [](const std::string_view& /*name*/, const torch::Tensor& tensor) {
+          return tensor.t();
+        }));
   }
 
   void verify_loaded_weights(const std::string& prefix) const {
