
Commit 0408bc1

[feat] added support for kv_cache with different strides. (#143)

1 parent a7925c9

5 files changed: 95 additions & 83 deletions
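The core change: set_kv_cache previously read a single kv_stride from keys and reused it for values; it now reads a stride per tensor (k_stride, v_stride) and only requires the last two dimensions of each input to be contiguous. Such inputs arise when keys or values are views into a wider tensor, which is exactly how the updated test below constructs them. A minimal libtorch sketch of such a pair (shapes borrowed from the test, everything else illustrative):

#include <torch/torch.h>
#include <iostream>

int main() {
  const int64_t num_slots = 4, num_kv_heads = 12, head_dim = 128;
  // keys is a view into a tensor with twice as many heads, so its stride along
  // the token dim differs from that of the contiguous values tensor, while the
  // last two dims stay contiguous (stride(-1) == 1, stride(-2) == head_dim).
  auto keys = torch::rand({num_slots, num_kv_heads * 2, head_dim})
                  .slice(/*dim=*/1, /*start=*/0, /*end=*/num_kv_heads);
  auto values = torch::rand({num_slots, num_kv_heads, head_dim});
  std::cout << "keys.stride(0)   = " << keys.stride(0)     // 2 * 12 * 128 = 3072
            << "\nvalues.stride(0) = " << values.stride(0)  // 12 * 128 = 1536
            << std::endl;
}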

src/kernels/kv_cache_kernels.cu

Lines changed: 25 additions & 16 deletions
@@ -12,10 +12,11 @@ __global__ void set_kv_cache_kernel(
     const T* __restrict__ values,  // [n_tokens, n_heads, head_dim]
     T* __restrict__ key_cache,
     T* __restrict__ value_cache,
-    int kv_stride,
-    int n_kv_heads,
-    int head_dim,
-    int block_size) {
+    int64_t k_stride,
+    int64_t v_stride,
+    int64_t n_kv_heads,
+    int64_t head_dim,
+    int64_t block_size) {
   // block/token index
   const int64_t bid = blockIdx.x;
   // which slot to write to
@@ -29,8 +30,9 @@ __global__ void set_kv_cache_kernel(
   const int64_t block_base_idx = block_idx * block_size * n_kv_heads * head_dim;

   // copy value one by one for the token
-  for (int i = threadIdx.x; i < n_kv_heads * head_dim; i += blockDim.x) {
-    const int64_t src_idx = bid * kv_stride + i;
+  for (int64_t i = threadIdx.x; i < n_kv_heads * head_dim; i += blockDim.x) {
+    const int64_t k_src_idx = bid * k_stride + i;
+    const int64_t v_src_idx = bid * v_stride + i;

     // cache: [n_blocks, block_size, n_heads, head_dim]
     const int64_t head_base_idx =
@@ -42,8 +44,8 @@ __global__ void set_kv_cache_kernel(
     const int head_offset = i % head_dim;
     const int64_t dst_idx = head_base_idx + head_idx * head_dim + head_offset;

-    key_cache[dst_idx] = keys[src_idx];
-    value_cache[dst_idx] = values[src_idx];
+    key_cache[dst_idx] = keys[k_src_idx];
+    value_cache[dst_idx] = values[v_src_idx];
   }
 }

@@ -53,15 +55,21 @@ void set_kv_cache(
     const torch::Tensor& values,  // [n_tokens, n_kv_heads, head_dim]
     torch::Tensor& key_cache,     // [n_blocks, block_size, n_heads, head_dim]
     torch::Tensor& value_cache) {
-  const int n_tokens = keys.size(0);
-  const int n_kv_heads = keys.size(-2);
-  const int head_dim = keys.size(-1);
-  const int block_size = key_cache.size(-3);
-  const int kv_stride = keys.stride(0);
-  const int n = n_kv_heads * head_dim;
+  // keys and values should be continuous at n_kv_heads and head_dim dims
+  CHECK(keys.stride(-1) == 1 && keys.stride(-2) == keys.size(-1));
+  CHECK(values.stride(-1) == 1 && values.stride(-2) == values.size(-1));
+
+  const int64_t n_tokens = keys.size(-3);
+  const int64_t n_kv_heads = keys.size(-2);
+  const int64_t head_dim = keys.size(-1);
+  const int64_t block_size = key_cache.size(-3);
+  // it is possible that keys and values have different strides
+  const int64_t k_stride = keys.stride(-3);
+  const int64_t v_stride = values.stride(-3);
+  const int64_t n = n_kv_heads * head_dim;

   dim3 grid(n_tokens);
-  dim3 block(std::min(n, 1024));
+  dim3 block(std::min<int>(n, 1024));
   DISPATCH_FLOATING_TYPES(keys.scalar_type(), "set_kv_cache_kernel", [&] {
     set_kv_cache_kernel<scalar_t>
         <<<grid, block, 0, at::cuda::getCurrentCUDAStream()>>>(
@@ -70,7 +78,8 @@ void set_kv_cache(
             values.data_ptr<scalar_t>(),
             key_cache.data_ptr<scalar_t>(),
             value_cache.data_ptr<scalar_t>(),
-            kv_stride,
+            k_stride,
+            v_stride,
             n_kv_heads,
             head_dim,
             block_size);
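For orientation, here is a host-side sketch of the per-token copy the kernel performs. The slot-to-block decomposition and the head_base_idx arithmetic are assumed (they sit outside the hunks shown above), so read this as an approximation rather than the kernel itself:

#include <cstdint>

// Approximate CPU reference for one token `bid` written to cache slot `slot_id`.
// keys/values are [n_tokens, n_kv_heads, head_dim] views whose last two dims are
// contiguous; k_stride/v_stride are their element strides along the token dim.
void set_kv_cache_reference(const float* keys, const float* values,
                            float* key_cache, float* value_cache,
                            int64_t bid, int64_t slot_id,
                            int64_t k_stride, int64_t v_stride,
                            int64_t n_kv_heads, int64_t head_dim,
                            int64_t block_size) {
  const int64_t block_idx = slot_id / block_size;     // assumed slot mapping
  const int64_t block_offset = slot_id % block_size;  // assumed slot mapping
  const int64_t block_base_idx = block_idx * block_size * n_kv_heads * head_dim;
  const int64_t head_base_idx =
      block_base_idx + block_offset * n_kv_heads * head_dim;  // assumed
  for (int64_t i = 0; i < n_kv_heads * head_dim; ++i) {
    // keys and values may be views with different token-dim strides, hence the
    // two source indices; the destination layout is shared.
    const int64_t k_src_idx = bid * k_stride + i;
    const int64_t v_src_idx = bid * v_stride + i;
    const int64_t head_idx = i / head_dim;
    const int64_t head_offset = i % head_dim;
    // cache: [n_blocks, block_size, n_heads, head_dim]
    const int64_t dst_idx = head_base_idx + head_idx * head_dim + head_offset;
    key_cache[dst_idx] = keys[k_src_idx];
    value_cache[dst_idx] = values[v_src_idx];
  }
}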

src/kernels/layernorm_kernels.cu

Lines changed: 26 additions & 26 deletions
@@ -16,14 +16,14 @@ __global__ void rms_norm_kernel(T* __restrict__ out,
     const T* __restrict__ input,
     const T* __restrict__ weight,
     const float epsilon,
-    int n) {
-  const int tidx = threadIdx.x;
-  const int bidx = blockIdx.x;
+    int64_t n) {
+  const auto tidx = threadIdx.x;
+  const auto bidx = blockIdx.x;

   __shared__ float s_variance;
   float variance = 0.0f;

-  for (int i = tidx; i < n; i += blockDim.x) {
+  for (int64_t i = tidx; i < n; i += blockDim.x) {
     const float x = input[bidx * n + i];
     variance += x * x;
   }
@@ -33,8 +33,8 @@ __global__ void rms_norm_kernel(T* __restrict__ out,
   }
   __syncthreads();

-  for (int i = tidx; i < n; i += blockDim.x) {
-    const int idx = bidx * n + i;
+  for (int64_t i = tidx; i < n; i += blockDim.x) {
+    const int64_t idx = bidx * n + i;
     const float x = input[idx];
     out[idx] = (T)(x * s_variance) * weight[i];
   }
@@ -47,10 +47,10 @@ void rms_norm(torch::Tensor& out,
   DCHECK(input.is_contiguous()) << "input tensor must be contiguous";
   DCHECK(out.is_contiguous()) << "output tensor must be contiguous";

-  const int n = input.size(1);
+  const int64_t n = input.size(1);

   dim3 grid(input.size(0));
-  dim3 block(std::min(n, 1024));
+  dim3 block(std::min<int>(n, 1024));
   DISPATCH_FLOATING_TYPES(input.scalar_type(), "rms_norm_kernel", [&] {
     rms_norm_kernel<scalar_t>
         <<<grid, block, 0, at::cuda::getCurrentCUDAStream()>>>(
@@ -73,15 +73,15 @@ __global__ void rms_norm_residual_kernel(T* __restrict__ out,
     const T* __restrict__ input,
     const T* __restrict__ weight,
     const float epsilon,
-    int n) {
-  const int tidx = threadIdx.x;
-  const int bidx = blockIdx.x;
+    int64_t n) {
+  const auto tidx = threadIdx.x;
+  const auto bidx = blockIdx.x;

   __shared__ float s_variance;
   float variance = 0.0f;

-  for (int i = tidx; i < n; i += blockDim.x) {
-    const int idx = bidx * n + i;
+  for (int64_t i = tidx; i < n; i += blockDim.x) {
+    const int64_t idx = bidx * n + i;
     const float r = residual[idx];
     const float x = r + input[idx];
     residual[idx] = x;
@@ -93,8 +93,8 @@ __global__ void rms_norm_residual_kernel(T* __restrict__ out,
   }
   __syncthreads();

-  for (int i = tidx; i < n; i += blockDim.x) {
-    const int idx = bidx * n + i;
+  for (int64_t i = tidx; i < n; i += blockDim.x) {
+    const int64_t idx = bidx * n + i;
     const float x = residual[idx];
     out[idx] = (T)(x * s_variance) * weight[i];
   }
@@ -109,10 +109,10 @@ void rms_norm_residual(torch::Tensor& out,
   DCHECK(out.is_contiguous()) << "output tensor must be contiguous";
   DCHECK(residual.is_contiguous()) << "residual tensor must be contiguous";

-  const int n = input.size(1);
+  const int64_t n = input.size(1);

   dim3 grid(input.size(0));
-  dim3 block(std::min(n, 1024));
+  dim3 block(std::min<int>(n, 1024));
   DISPATCH_FLOATING_TYPES(input.scalar_type(), "rms_norm_residual_kernel", [&] {
     rms_norm_residual_kernel<scalar_t>
         <<<grid, block, 0, at::cuda::getCurrentCUDAStream()>>>(
@@ -133,17 +133,17 @@ __global__ void layer_norm_kernel(T* __restrict__ out,
     const T* __restrict__ weight,
     const T* __restrict__ bias,
     const float epsilon,
-    int n) {
-  const int tidx = threadIdx.x;
-  const int bidx = blockIdx.x;
+    int64_t n) {
+  const auto tidx = threadIdx.x;
+  const auto bidx = blockIdx.x;

   __shared__ float s_mean;
   __shared__ float s_variance;
   float mean = 0.0f;
   float variance = 0.0f;

   // calculate mean of the input.
-  for (int i = tidx; i < n; i += blockDim.x) {
+  for (int64_t i = tidx; i < n; i += blockDim.x) {
     mean += input[bidx * n + i];
   }
   mean = block_reduce_sum<float>(mean);
@@ -153,7 +153,7 @@ __global__ void layer_norm_kernel(T* __restrict__ out,
   __syncthreads();

   // calculate variance of the input.
-  for (int i = tidx; i < n; i += blockDim.x) {
+  for (int64_t i = tidx; i < n; i += blockDim.x) {
     const float x = input[bidx * n + i] - s_mean;
     variance += x * x;
   }
@@ -163,8 +163,8 @@ __global__ void layer_norm_kernel(T* __restrict__ out,
   }
   __syncthreads();

-  for (int i = tidx; i < n; i += blockDim.x) {
-    const int idx = bidx * n + i;
+  for (int64_t i = tidx; i < n; i += blockDim.x) {
+    const int64_t idx = bidx * n + i;
     float local_out = (input[idx] - s_mean) * s_variance * weight[i];
     if (bias != nullptr) {
       local_out += bias[i];
@@ -181,10 +181,10 @@ void layer_norm(torch::Tensor& out,
   DCHECK(input.is_contiguous()) << "input tensor must be contiguous";
   DCHECK(out.is_contiguous()) << "output tensor must be contiguous";

-  const int n = input.size(1);
+  const int64_t n = input.size(1);

   dim3 grid(input.size(0));
-  dim3 block(std::min(n, 1024));
+  dim3 block(std::min<int>(n, 1024));
   DISPATCH_FLOATING_TYPES(input.scalar_type(), "layer_norm_kernel", [&] {
     layer_norm_kernel<scalar_t>
         <<<grid, block, 0, at::cuda::getCurrentCUDAStream()>>>(
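The layernorm changes are mechanical: loop indices and sizes widen to int64_t, and the block-size computation gains an explicit template argument because std::min cannot deduce a common type from an int64_t and an int literal. A standalone illustration (not part of the commit):

#include <algorithm>
#include <cstdint>
#include <iostream>

int main() {
  const int64_t n = 4096;
  // std::min(n, 1024);                      // would not compile: int64_t vs int
  const int block = std::min<int>(n, 1024);  // forces both arguments to int,
                                             // fine here since n is a hidden dim
  std::cout << block << std::endl;           // prints 1024
}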

src/kernels/pos_embedding_kernels.cu

Lines changed: 32 additions & 33 deletions
@@ -33,47 +33,47 @@ struct RotaryEmbedding {
 // inplace update query and key
 template <typename T>
 __global__ void rotary_embedding_kernel(
-    T* __restrict__ query,  // [n_tokens, n_heads, head_dim]
-    T* __restrict__ key,    // [n_tokens, n_kv_heads, head_dim]
+    T* __restrict__ querys,  // [n_tokens, n_heads, head_dim]
+    T* __restrict__ keys,    // [n_tokens, n_kv_heads, head_dim]
     const int* __restrict__ positions,  // [n_tokens]
     const T* __restrict__ cos_sin,      // [max_positions, 2, rotary_dim/2]
-    int head_dim,
-    int rotary_dim,
-    int n_heads,
-    int n_kv_heads,
-    int q_stride,
-    int k_stride,
+    int64_t head_dim,
+    int64_t rotary_dim,
+    int64_t n_heads,
+    int64_t n_kv_heads,
+    int64_t q_stride,
+    int64_t k_stride,
     bool interleaved) {
   const int tidx = threadIdx.x;
   const int bidx = blockIdx.x;

   // figure out cos sin base ptr for the token
-  const int n = rotary_dim / 2;
+  const int64_t n = rotary_dim / 2;
   const T* cos_sin_base = cos_sin + positions[bidx] * rotary_dim;
   const T* cos = cos_sin_base;
   const T* sin = cos_sin_base + n;

   // apply rotary embedding to query head by head
   // q base ptr for the token
-  T* q_base = query + bidx * q_stride;
-  for (int i = tidx; i < n_heads * n; i += blockDim.x) {
+  T* q_base = querys + bidx * q_stride;
+  for (int64_t i = tidx; i < n_heads * n; i += blockDim.x) {
     // head idx
-    const int h_idx = i / n;
+    const int64_t h_idx = i / n;
     // rotary idx within head
-    const int r_idx = i % n;
+    const int64_t r_idx = i % n;
     // q ptr for the head
     T* q = q_base + h_idx * head_dim;
     RotaryEmbedding<T>::apply(q, cos, sin, r_idx, n, interleaved);
   }

   // apply rotary embedding to key head by head
   // k base ptr for the token
-  T* k_base = key + bidx * k_stride;
-  for (int i = tidx; i < n_kv_heads * n; i += blockDim.x) {
+  T* k_base = keys + bidx * k_stride;
+  for (int64_t i = tidx; i < n_kv_heads * n; i += blockDim.x) {
     // head idx
-    const int h_idx = i / n;
+    const int64_t h_idx = i / n;
     // rotary idx within head
-    const int r_idx = i % n;
+    const int64_t r_idx = i % n;
     // k ptr for the head
     T* k = k_base + h_idx * head_dim;
     RotaryEmbedding<T>::apply(k, cos, sin, r_idx, n, interleaved);
@@ -82,31 +82,30 @@ __global__ void rotary_embedding_kernel(

 // apply rotary embedding to query and key inplace
 void apply_rotary_pos_emb(
-    torch::Tensor& query,  // [n_tokens, n_heads, head_dim]
-    torch::Tensor& key,    // [n_tokens, n_kv_heads, head_dim]
+    torch::Tensor& querys,  // [n_tokens, n_heads, head_dim]
+    torch::Tensor& keys,    // [n_tokens, n_kv_heads, head_dim]
     const torch::Tensor& positions,  // [n_tokens]
     const torch::Tensor& cos_sin,    // [max_positions, 2, rotary_dim/2]
     int rotary_dim,
     bool interleaved) {
-  DCHECK(query.is_cuda()) << "query must be on gpu";
-  DCHECK(key.is_cuda()) << "key must be on gpu";
-  DCHECK(query.dim() == 3) << "query must be 3d";
-  DCHECK(key.dim() == 3) << "key must be 3d";
+  // keys and values should be continuous at n_kv_heads and head_dim dims
+  CHECK(querys.stride(-1) == 1 && querys.stride(-2) == querys.size(-1));
+  CHECK(keys.stride(-1) == 1 && keys.stride(-2) == keys.size(-1));

-  const int n_tokens = query.size(0);
-  const int n_heads = query.size(1);
-  const int n_kv_heads = key.size(1);
-  const int head_dim = query.size(2);
-  const int q_stride = query.stride(0);
-  const int k_stride = key.stride(0);
+  const int64_t n_tokens = querys.size(-3);
+  const int64_t n_heads = querys.size(-2);
+  const int64_t n_kv_heads = keys.size(-2);
+  const int64_t head_dim = querys.size(-1);
+  const int64_t q_stride = querys.stride(-3);
+  const int64_t k_stride = keys.stride(-3);

   const dim3 grid(n_tokens);
-  const dim3 block(std::min(1024, n_heads * rotary_dim) / 2);
-  DISPATCH_FLOATING_TYPES(query.scalar_type(), "rotary_embedding_kernel", [&] {
+  const dim3 block(std::min<int>(1024, n_heads * rotary_dim) / 2);
+  DISPATCH_FLOATING_TYPES(querys.scalar_type(), "rotary_embedding_kernel", [&] {
     rotary_embedding_kernel<scalar_t>
         <<<grid, block, 0, at::cuda::getCurrentCUDAStream()>>>(
-            query.data_ptr<scalar_t>(),
-            key.data_ptr<scalar_t>(),
+            querys.data_ptr<scalar_t>(),
+            keys.data_ptr<scalar_t>(),
             positions.data_ptr<int>(),
             cos_sin.data_ptr<scalar_t>(),
             head_dim,
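The rotary path gets the same treatment: the 3-D DCHECKs give way to contiguity checks on the last two dims, and q_stride/k_stride are read from stride(-3) of each tensor. One case this admits (a hedged sketch, not taken from the commit) is applying the embedding directly to query/key views carved out of a fused QKV projection, where both views inherit the wide token-dim stride of the parent tensor:

#include <torch/torch.h>
#include <iostream>

int main() {
  const int64_t n_tokens = 4, n_heads = 8, n_kv_heads = 2, head_dim = 64;
  // Fused projection output: [n_tokens, n_heads + 2 * n_kv_heads, head_dim].
  auto qkv = torch::rand({n_tokens, n_heads + 2 * n_kv_heads, head_dim});
  auto querys = qkv.slice(/*dim=*/1, /*start=*/0, /*end=*/n_heads);
  auto keys = qkv.slice(/*dim=*/1, /*start=*/n_heads, /*end=*/n_heads + n_kv_heads);
  // Both views satisfy the new CHECKs: the last two dims are contiguous ...
  std::cout << std::boolalpha
            << (querys.stride(-1) == 1 && querys.stride(-2) == querys.size(-1))
            << " "
            << (keys.stride(-1) == 1 && keys.stride(-2) == keys.size(-1)) << "\n";
  // ... while stride(-3) reflects the parent tensor's full width.
  std::cout << "q_stride = " << querys.stride(-3)  // (8 + 2 * 2) * 64 = 768
            << ", k_stride = " << keys.stride(-3)  // also 768
            << std::endl;
}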

src/layers/attention/handler.h

Lines changed: 1 addition & 0 deletions
@@ -53,6 +53,7 @@ class AttentionHandler {
       const torch::Tensor& key,    // [n_tokens, n_kv_heads, head_dim]
       const torch::Tensor& value,  // [n_tokens, n_kv_heads, head_dim]
       const InputParameters& input_params) = 0;
+
   // create an attention handler
   static std::unique_ptr<AttentionHandler> create_handler(
       const ModelArgs& args,

src/memory/kv_cache_test.cpp

Lines changed: 11 additions & 8 deletions
@@ -57,11 +57,10 @@ TEST(KVCacheTest, Basic) {
 }

 TEST(KVCacheTest, Random) {
-  const int num_kv_heads = 12;
-  const int head_dim = 128;
-  const int block_size = 4;
-  const int x = 8;
-  const int num_blocks = 2;
+  const int64_t num_kv_heads = 12;
+  const int64_t head_dim = 128;
+  const int64_t block_size = 4;
+  const int64_t num_blocks = 2;

   // auto dtype = torch::kFloat16;
   torch::set_default_dtype(
@@ -82,17 +81,21 @@ TEST(KVCacheTest, Random) {
   for (int32_t i = 0; i < 10000; ++i) {
     using ISlice = torch::indexing::Slice;

-    const int sample_size = std::min(num_blocks * block_size, 10);
-    const int num_slots = i % sample_size + 1;
+    const int64_t sample_size = std::min<int64_t>(num_blocks * block_size, 10);
+    const int64_t num_slots = i % sample_size + 1;
     torch::Tensor slot_ids =
         torch::randperm(num_blocks * block_size,
                         torch::dtype(torch::kInt).device(device))
            .index({ISlice(0, num_slots)});

+    // construct keys and values with different strides
     torch::Tensor keys =
-        torch::rand({num_slots, num_kv_heads, head_dim}, torch::device(device));
+        torch::rand({num_slots, num_kv_heads * 2, head_dim},
+                    torch::device(device))
+            .slice(/*dim=*/1, /*start=*/0, /*end=*/num_kv_heads);
     torch::Tensor values =
         torch::rand({num_slots, num_kv_heads, head_dim}, torch::device(device));
+    EXPECT_NE(keys.stride(0), values.stride(0));

     kv_cache.set_kv_cache_cuda(slot_ids, keys, values);
