
Commit 0a5036b

CUDA: add roll (#14919)
* CUDA: add roll
* Make everything const, use __restrict__
1 parent 8ad7b3e commit 0a5036b

File tree

3 files changed: +81 -0 lines changed
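
For orientation before the diffs: GGML_OP_ROLL rotates a tensor's elements along each of its four dimensions by a per-dimension shift, wrapping around at the edges, in the spirit of numpy.roll / torch.roll. In one dimension the effect looks like this (an illustrative C++ sketch, not ggml code):

#include <cstdio>

// Illustrative only: 1-D roll by shift s with wrap-around.
// Element i of the output comes from element (i - s) mod ne of the input.
int main() {
    const int ne = 5, s = 2;
    const float src[ne] = {0, 1, 2, 3, 4};
    float dst[ne];
    for (int i = 0; i < ne; i++) {
        dst[i] = src[((i - s) % ne + ne) % ne];
    }
    for (int i = 0; i < ne; i++) {
        printf("%g ", dst[i]); // prints: 3 4 0 1 2
    }
    printf("\n");
    return 0;
}

The CUDA kernel added below applies exactly this source-index computation independently in all four dimensions.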

ggml/src/ggml-cuda/ggml-cuda.cu
Lines changed: 9 additions & 0 deletions

@@ -31,6 +31,7 @@
 #include "ggml-cuda/pool2d.cuh"
 #include "ggml-cuda/quantize.cuh"
 #include "ggml-cuda/rope.cuh"
+#include "ggml-cuda/roll.cuh"
 #include "ggml-cuda/scale.cuh"
 #include "ggml-cuda/softmax.cuh"
 #include "ggml-cuda/ssm-conv.cuh"

@@ -2419,6 +2420,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct ggml_tensor * dst) {
         case GGML_OP_ROPE_BACK:
             ggml_cuda_op_rope_back(ctx, dst);
             break;
+        case GGML_OP_ROLL:
+            ggml_cuda_op_roll(ctx, dst);
+            break;
         case GGML_OP_IM2COL:
             ggml_cuda_op_im2col(ctx, dst);
             break;

@@ -3411,6 +3415,11 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
             memcpy(&max_bias, (const float *) op->op_params + 1, sizeof(float));
             return max_bias == 0.0f;
         }
+        case GGML_OP_ROLL:
+            if(op->src[0]->type == GGML_TYPE_F32) {
+                return true;
+            }
+            return false;
         case GGML_OP_ROPE:
         case GGML_OP_ROPE_BACK: {
             return op->src[0]->nb[0] == ggml_type_size(op->src[0]->type) && ggml_is_contiguous_2(op->src[0]);
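
With dispatch and the supports_op gate in place, the op is reachable through ggml's public graph API. Below is a minimal sketch of exercising GGML_OP_ROLL on the CUDA backend; it assumes the ggml_roll(ctx, a, shift0, shift1, shift2, shift3) builder and the usual backend helpers, and is meant as an illustration rather than a tested program:

#include "ggml.h"
#include "ggml-alloc.h"
#include "ggml-backend.h"
#include "ggml-cuda.h"

int main() {
    // no_alloc: tensor data lives in the backend buffer, not in this context
    struct ggml_init_params params = {
        /*.mem_size   =*/ ggml_tensor_overhead() * 8 + ggml_graph_overhead(),
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ true,
    };
    struct ggml_context * ctx = ggml_init(params);

    struct ggml_tensor * x = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 8, 4, 2, 1);
    struct ggml_tensor * y = ggml_roll(ctx, x, /*shift0 =*/ 3, /*shift1 =*/ 1, 0, 0);

    struct ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, y);

    ggml_backend_t backend = ggml_backend_cuda_init(0 /* device */);
    ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors(ctx, backend);

    // ... upload input with ggml_backend_tensor_set(x, data, 0, ggml_nbytes(x)) ...
    ggml_backend_graph_compute(backend, gf); // dispatches ggml_cuda_op_roll for y

    ggml_backend_buffer_free(buf);
    ggml_backend_free(backend);
    ggml_free(ctx);
    return 0;
}

Note that supports_op above only accepts F32 sources, so graphs rolling any other type will be scheduled elsewhere (typically on the CPU backend).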

ggml/src/ggml-cuda/roll.cu
Lines changed: 67 additions & 0 deletions

@@ -0,0 +1,67 @@
+#include "ggml-cuda/common.cuh"
+#include "roll.cuh"
+
+static __forceinline__ __device__ int64_t wrap_index(const int64_t idx, const int64_t ne) {
+    if (idx < 0) {
+        return idx + ne;
+    }
+    if (idx >= ne) {
+        return idx - ne;
+    }
+    return idx;
+}
+
+static __global__ void roll_f32_cuda(const float * __restrict__ src,
+                                     float       * __restrict__ dst,
+                                     const int64_t ne00,
+                                     const int64_t ne01,
+                                     const int64_t ne02,
+                                     const int64_t ne03,
+                                     const int     s0,
+                                     const int     s1,
+                                     const int     s2,
+                                     const int     s3) {
+    const int64_t idx        = int64_t(blockDim.x) * blockIdx.x + threadIdx.x;
+    const int64_t n_elements = ne00 * ne01 * ne02 * ne03;
+
+    if (idx >= n_elements) {
+        return;
+    }
+
+    const int64_t i0 = idx % ne00;
+    const int64_t i1 = (idx / ne00) % ne01;
+    const int64_t i2 = (idx / (ne00 * ne01)) % ne02;
+    const int64_t i3 = (idx / (ne00 * ne01 * ne02)) % ne03;
+
+    const int64_t d0 = wrap_index(i0 - s0, ne00);
+    const int64_t d1 = wrap_index(i1 - s1, ne01);
+    const int64_t d2 = wrap_index(i2 - s2, ne02);
+    const int64_t d3 = wrap_index(i3 - s3, ne03);
+
+    dst[i3 * (ne00 * ne01 * ne02) + i2 * (ne01 * ne00) + i1 * ne00 + i0] =
+        src[d3 * (ne00 * ne01 * ne02) + d2 * (ne01 * ne00) + d1 * ne00 + d0];
+}
+
+void ggml_cuda_op_roll(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    int s0 = dst->op_params[0];
+    int s1 = dst->op_params[1];
+    int s2 = dst->op_params[2];
+    int s3 = dst->op_params[3];
+
+    const ggml_tensor * src0 = dst->src[0];
+    const float * src0_d = (const float *) dst->src[0]->data;
+    float * dst_d = (float *) dst->data;
+
+    GGML_TENSOR_UNARY_OP_LOCALS;
+
+    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
+    GGML_ASSERT(ggml_are_same_shape(dst->src[0], dst));
+
+    cudaStream_t stream = ctx.stream();
+
+    int64_t sz         = (ne00 * ne01 * ne02 * ne03);
+    int64_t num_blocks = (sz + CUDA_ROLL_BLOCK_SIZE - 1) / CUDA_ROLL_BLOCK_SIZE;
+
+    roll_f32_cuda<<<num_blocks, CUDA_ROLL_BLOCK_SIZE, 0, stream>>>(
+        src0_d, dst_d, ne00, ne01, ne02, ne03, s0, s1, s2, s3);
+}
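
A natural way to validate the kernel is against a host-side reference with the same semantics. One caveat worth noting: wrap_index performs a single wrap, so it is only correct while |s| <= ne, i.e. while the shifted index is off by at most one period. The sketch below (helper names are mine, not part of the commit) uses a full Euclidean modulo instead, so it tolerates any shift magnitude:

#include <cstdint>

// Hypothetical test helper: CPU reference for roll_f32_cuda above.
static int64_t wrap_mod(const int64_t i, const int64_t ne) {
    return ((i % ne) + ne) % ne; // Euclidean modulo, valid for any i
}

static void roll_f32_ref(const float * src, float * dst,
                         const int64_t ne00, const int64_t ne01,
                         const int64_t ne02, const int64_t ne03,
                         const int s0, const int s1, const int s2, const int s3) {
    for (int64_t i3 = 0; i3 < ne03; i3++) {
        for (int64_t i2 = 0; i2 < ne02; i2++) {
            for (int64_t i1 = 0; i1 < ne01; i1++) {
                for (int64_t i0 = 0; i0 < ne00; i0++) {
                    // same source-index math as the kernel, one element at a time
                    const int64_t d0 = wrap_mod(i0 - s0, ne00);
                    const int64_t d1 = wrap_mod(i1 - s1, ne01);
                    const int64_t d2 = wrap_mod(i2 - s2, ne02);
                    const int64_t d3 = wrap_mod(i3 - s3, ne03);
                    dst[((i3 * ne02 + i2) * ne01 + i1) * ne00 + i0] =
                        src[((d3 * ne02 + d2) * ne01 + d1) * ne00 + d0];
                }
            }
        }
    }
}

Comparing this against the device output for a grid of shifts, including negative ones and |s| == ne, covers the interesting wrap_index cases.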

ggml/src/ggml-cuda/roll.cuh
Lines changed: 5 additions & 0 deletions

@@ -0,0 +1,5 @@
+#include "common.cuh"
+
+#define CUDA_ROLL_BLOCK_SIZE 256
+
+void ggml_cuda_op_roll(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
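
A quick check of the launch arithmetic in ggml_cuda_op_roll: with CUDA_ROLL_BLOCK_SIZE = 256 and, say, sz = 1000 elements, num_blocks = (1000 + 255) / 256 = 4, so 1024 threads are launched and the 24 surplus threads in the last block exit early at the idx >= n_elements guard.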
