Skip to content

Commit 138b288

Browse files
authored
cuda : add softcap fusion (#14907)
1 parent bbd0f91 commit 138b288

File tree

4 files changed

+118
-6
lines changed

4 files changed

+118
-6
lines changed

ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 43 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
#include "ggml-cuda/rope.cuh"
3434
#include "ggml-cuda/roll.cuh"
3535
#include "ggml-cuda/scale.cuh"
36+
#include "ggml-cuda/softcap.cuh"
3637
#include "ggml-cuda/softmax.cuh"
3738
#include "ggml-cuda/ssm-conv.cuh"
3839
#include "ggml-cuda/ssm-scan.cuh"
@@ -2770,7 +2771,12 @@ static void update_cuda_graph_executable(ggml_backend_cuda_context * cuda_ctx) {
27702771
}
27712772
#endif
27722773

2773-
static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, std::initializer_list<enum ggml_op> ops) {
2774+
static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, std::initializer_list<enum ggml_op> ops, std::initializer_list<enum ggml_unary_op> unary_ops) {
2775+
#ifndef NDEBUG
2776+
const size_t num_unary = std::count(ops.begin(), ops.end(), GGML_OP_UNARY);
2777+
GGML_ASSERT(unary_ops.size() == num_unary);
2778+
#endif
2779+
27742780
if (!ggml_can_fuse(cgraph, node_idx, ops)) {
27752781
return false;
27762782
}
@@ -2798,9 +2804,32 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
27982804
if (!ggml_is_contiguous_rows(mul->src[0]) || !ggml_is_contiguous_rows(mul->src[1])) {
27992805
return false;
28002806
}
2807+
2808+
return true;
28012809
}
28022810

2803-
return true;
2811+
if (ops.size() == 3 && ops.begin()[0] == GGML_OP_SCALE && ops.begin()[1] == GGML_OP_UNARY && ops.begin()[2] == GGML_OP_SCALE
2812+
&& unary_ops.size() == 1 && unary_ops.begin()[0] == GGML_UNARY_OP_TANH) {
2813+
const ggml_tensor *scale = cgraph->nodes[node_idx];
2814+
const ggml_tensor *tanh = cgraph->nodes[node_idx+1];
2815+
const ggml_tensor *scale2 = cgraph->nodes[node_idx+2];
2816+
2817+
GGML_ASSERT(scale->src[0]->type == GGML_TYPE_F32);
2818+
GGML_ASSERT(scale->type == GGML_TYPE_F32);
2819+
2820+
if (ggml_get_unary_op(tanh) != GGML_UNARY_OP_TANH) {
2821+
return false;
2822+
}
2823+
2824+
// Check for bias
2825+
if (ggml_get_op_params_f32(scale, 1) != 0.0f || ggml_get_op_params_f32(scale2, 1) != 0.0f) {
2826+
return false;
2827+
}
2828+
2829+
return true;
2830+
}
2831+
2832+
return false;
28042833
}
28052834

28062835
static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph,
@@ -2821,10 +2850,18 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
28212850
}
28222851

28232852
static bool disable_fusion = (getenv("GGML_CUDA_DISABLE_FUSION") != nullptr);
2824-
if (!disable_fusion && ggml_cuda_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
2825-
ggml_cuda_op_rms_norm_fused(*cuda_ctx, node, cgraph->nodes[i+1]);
2826-
i++;
2827-
continue;
2853+
if (!disable_fusion) {
2854+
if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL }, {})) {
2855+
ggml_cuda_op_rms_norm_fused(*cuda_ctx, node, cgraph->nodes[i+1]);
2856+
i++;
2857+
continue;
2858+
}
2859+
2860+
if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_SCALE, GGML_OP_UNARY, GGML_OP_SCALE }, { GGML_UNARY_OP_TANH })) {
2861+
i += 2;
2862+
ggml_cuda_op_softcap(*cuda_ctx, cgraph->nodes[i], node);
2863+
continue;
2864+
}
28282865
}
28292866
#ifndef NDEBUG
28302867
assert(node->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device));

ggml/src/ggml-cuda/softcap.cu

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
#include "softcap.cuh"
2+
3+
static __global__ void softcap_f32(const float * x, float * dst, const float scale, const float softcap, const int k) {
4+
const int i = blockDim.x*blockIdx.x + threadIdx.x;
5+
6+
if (i >= k) {
7+
return;
8+
}
9+
10+
dst[i] = tanhf(scale * x[i]) * softcap;
11+
}
12+
13+
static void softcap_f32_cuda(const float * x, float * dst, const float scale, const float softcap, const int k, cudaStream_t stream) {
14+
const int num_blocks = (k + CUDA_SOFTCAP_BLOCK_SIZE - 1) / CUDA_SOFTCAP_BLOCK_SIZE;
15+
softcap_f32<<<num_blocks, CUDA_SOFTCAP_BLOCK_SIZE, 0, stream>>>(x, dst, scale, softcap, k);
16+
}
17+
18+
// fused GGML_OP_SCALE + GGML_UNARY_OP_TANH + GGML_OP_SCALE
19+
void ggml_cuda_op_softcap(ggml_backend_cuda_context & ctx, ggml_tensor * dst, ggml_tensor * src) {
20+
const ggml_tensor * src0 = src->src[0];
21+
const float * src0_d = (const float *)src0->data;
22+
float * dst_d = (float *)dst->data;
23+
cudaStream_t stream = ctx.stream();
24+
25+
GGML_ASSERT(src0->type == GGML_TYPE_F32);
26+
GGML_ASSERT( dst->type == GGML_TYPE_F32);
27+
28+
float scale;
29+
float softcap;
30+
memcpy(&scale, (float *) src->op_params + 0, sizeof(float));
31+
memcpy(&softcap, (float *) dst->op_params + 0, sizeof(float));
32+
33+
softcap_f32_cuda(src0_d, dst_d, scale, softcap, ggml_nelements(src0), stream);
34+
}

ggml/src/ggml-cuda/softcap.cuh

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
#include "common.cuh"
2+
3+
#define CUDA_SOFTCAP_BLOCK_SIZE 256
4+
5+
void ggml_cuda_op_softcap(ggml_backend_cuda_context & ctx, ggml_tensor * dst, ggml_tensor * src);

tests/test-backend-ops.cpp

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2545,6 +2545,41 @@ struct test_scale : public test_case {
25452545
}
25462546
};
25472547

2548+
// GGML_OP_SCALE + GGML_UNARY_OP_TANH + GGML_OP_SCALE
2549+
struct test_softcap : public test_case {
2550+
const ggml_type type;
2551+
const std::array<int64_t, 4> ne;
2552+
float softcap;
2553+
2554+
std::string op_desc(ggml_tensor * t) override {
2555+
GGML_UNUSED(t);
2556+
return "SOFTCAP";
2557+
}
2558+
2559+
bool run_whole_graph() override { return true; }
2560+
2561+
std::string vars() override {
2562+
return VARS_TO_STR3(type, ne, softcap);
2563+
}
2564+
2565+
test_softcap(ggml_type type = GGML_TYPE_F32,
2566+
std::array<int64_t, 4> ne = {10, 10, 10, 10},
2567+
float softcap = 30.0f)
2568+
: type(type), ne(ne), softcap(softcap) {}
2569+
2570+
ggml_tensor * build_graph(ggml_context * ctx) override {
2571+
ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
2572+
2573+
ggml_set_param(a);
2574+
ggml_set_name(a, "a");
2575+
2576+
ggml_tensor * out = ggml_scale(ctx, ggml_tanh(ctx, ggml_scale(ctx, a, 1.0f / softcap)), softcap);
2577+
ggml_set_name(out, "out");
2578+
2579+
return out;
2580+
}
2581+
};
2582+
25482583
// GGML_OP_SILU_BACK
25492584
struct test_silu_back : public test_case {
25502585
const ggml_type type;
@@ -5421,6 +5456,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
54215456
test_cases.emplace_back(new test_add1());
54225457
test_cases.emplace_back(new test_scale());
54235458
test_cases.emplace_back(new test_scale(GGML_TYPE_F32, {10, 10, 10, 10}, 2.0f, 1.0f));
5459+
test_cases.emplace_back(new test_softcap(GGML_TYPE_F32, {10, 10, 10, 10}, 50.0f));
54245460
test_cases.emplace_back(new test_silu_back());
54255461

54265462
for (float eps : {0.0f, 1e-6f, 1e-4f, 1e-1f}) {

0 commit comments

Comments
 (0)