From 656584f1a7fb3d2728a6d6e1aaf0486c32411df3 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Sigbj=C3=B8rn=20Skj=C3=A6ret?=
Date: Sun, 27 Jul 2025 21:41:22 +0200
Subject: [PATCH 1/6] add softcap fusion

---
 ggml/src/ggml-cuda/ggml-cuda.cu | 89 ++++++++++++++++++++++++---------
 ggml/src/ggml-cuda/softcap.cu   | 33 ++++++++++++
 ggml/src/ggml-cuda/softcap.cuh  |  5 ++
 tests/test-backend-ops.cpp      | 36 +++++++++++++
 4 files changed, 140 insertions(+), 23 deletions(-)
 create mode 100644 ggml/src/ggml-cuda/softcap.cu
 create mode 100644 ggml/src/ggml-cuda/softcap.cuh

diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
index 03c380897cd8a..6cea616c92f96 100644
--- a/ggml/src/ggml-cuda/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
@@ -32,6 +32,7 @@
 #include "ggml-cuda/quantize.cuh"
 #include "ggml-cuda/rope.cuh"
 #include "ggml-cuda/scale.cuh"
+#include "ggml-cuda/softcap.cuh"
 #include "ggml-cuda/softmax.cuh"
 #include "ggml-cuda/ssm-conv.cuh"
 #include "ggml-cuda/ssm-scan.cuh"
@@ -2766,34 +2767,59 @@ static void update_cuda_graph_executable(ggml_backend_cuda_context * cuda_ctx) {
 }
 #endif
 
-static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, std::initializer_list<enum ggml_op> ops) {
+static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, std::initializer_list<enum ggml_op> ops, std::initializer_list<enum ggml_unary_op> unary_ops) {
     if (!ggml_can_fuse(cgraph, node_idx, ops)) {
         return false;
     }
 
-    if (ops.size() == 2 && ops.begin()[0] == GGML_OP_RMS_NORM && ops.begin()[1] == GGML_OP_MUL) {
-        const ggml_tensor *rms_norm = cgraph->nodes[node_idx];
-        const ggml_tensor *mul      = cgraph->nodes[node_idx+1];
+    switch (ops.size()) {
+        case 2:
+            if (ops.begin()[0] == GGML_OP_RMS_NORM && ops.begin()[1] == GGML_OP_MUL) {
+                const ggml_tensor *rms_norm = cgraph->nodes[node_idx];
+                const ggml_tensor *mul      = cgraph->nodes[node_idx+1];
 
-        GGML_ASSERT(rms_norm->src[0]->type == GGML_TYPE_F32);
-        GGML_ASSERT(rms_norm->type == GGML_TYPE_F32);
+                GGML_ASSERT(rms_norm->src[0]->type == GGML_TYPE_F32);
+                GGML_ASSERT(rms_norm->type == GGML_TYPE_F32);
 
-        //rms norm only supports F32
-        if (mul->src[0]->type != GGML_TYPE_F32 ||
-            mul->src[1]->type != GGML_TYPE_F32 ||
-            mul->type != GGML_TYPE_F32) {
-            return false;
-        }
+                //rms norm only supports F32
+                if (mul->src[0]->type != GGML_TYPE_F32 ||
+                    mul->src[1]->type != GGML_TYPE_F32 ||
+                    mul->type != GGML_TYPE_F32) {
+                    return false;
+                }
 
-        //if rms norm is the B operand, then we don't handle broadcast
-        if (rms_norm == mul->src[1] && !ggml_are_same_shape(mul->src[0], rms_norm->src[1])) {
-            return false;
-        }
+                //if rms norm is the B operand, then we don't handle broadcast
+                if (rms_norm == mul->src[1] && !ggml_are_same_shape(mul->src[0], rms_norm->src[1])) {
+                    return false;
+                }
+
+                //rms_norm kernel assumes contigous rows
+                if (!ggml_is_contiguous_rows(mul->src[0]) || !ggml_is_contiguous_rows(mul->src[1])) {
+                    return false;
+                }
+            }
+            break;
+        case 3:
+            if (ops.begin()[0] == GGML_OP_SCALE && ops.begin()[1] == GGML_OP_UNARY && ops.begin()[2] == GGML_OP_SCALE
+                && unary_ops.size() == 1 && unary_ops.begin()[0] == GGML_UNARY_OP_TANH) {
+                const ggml_tensor *scale  = cgraph->nodes[node_idx];
+                const ggml_tensor *tanh   = cgraph->nodes[node_idx+1];
+                const ggml_tensor *scale2 = cgraph->nodes[node_idx+2];
+
+                GGML_ASSERT(scale->src[0]->type == GGML_TYPE_F32);
+                GGML_ASSERT(scale->type == GGML_TYPE_F32);
+
+                if (tanh->src[0] != scale || scale2->src[0] != tanh) {
+                    return false;
+                }
 
-        //rms_norm kernel assumes contigous rows
-        if (!ggml_is_contiguous_rows(mul->src[0]) || !ggml_is_contiguous_rows(mul->src[1])) {
+                if (ggml_get_op_params_f32(scale, 1) != 0.0f || ggml_get_op_params_f32(scale2, 1) != 0.0f) {
+                    return false;
+                }
+            }
+            break;
+        default:
             return false;
-        }
     }
 
     return true;
@@ -2817,10 +2843,27 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
         }
 
         static bool disable_fusion = (getenv("GGML_CUDA_DISABLE_FUSION") != nullptr);
-        if (!disable_fusion && ggml_cuda_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
-            ggml_cuda_op_rms_norm_fused(*cuda_ctx, node, cgraph->nodes[i+1]);
-            i++;
-            continue;
+        if (!disable_fusion) {
+            if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL }, {})) {
+                ggml_cuda_op_rms_norm_fused(*cuda_ctx, node, cgraph->nodes[i+1]);
+                i++;
+                continue;
+            }
+
+            if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_SCALE, GGML_OP_UNARY, GGML_OP_SCALE }, { GGML_UNARY_OP_TANH })) {
+                ggml_tensor * src0 = node->src[0];
+                float scale = ggml_get_op_params_f32(node, 0);
+
+                i += 2; node = cgraph->nodes[i];
+                float softcap = ggml_get_op_params_f32(node, 0);
+
+                ggml_set_op_params_f32(node, 0, scale);
+                ggml_set_op_params_f32(node, 1, softcap);
+                node->src[0] = src0;
+
+                ggml_cuda_op_softcap(*cuda_ctx, node);
+                continue;
+            }
         }
 #ifndef NDEBUG
         assert(node->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device));
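
Note: the net effect of the fused path above is easier to see outside the diff. Below is a minimal standalone sketch (not part of the patch; the names and the example cap value are illustrative) of the elementwise function that the SCALE -> TANH -> SCALE chain collapses into, using the typical scale = 1/cap pairing of softcapped models:

    // reference for one element: scale(tanh(scale(x, 1/cap)), cap)
    #include <cmath>
    #include <cstdio>

    static float softcap_ref(float x, float scale, float softcap) {
        return std::tanh(x * scale) * softcap;
    }

    int main() {
        const float cap = 50.0f; // e.g. a Gemma-2-style logit cap
        for (float x : {0.5f, 10.0f, 1000.0f}) {
            // small inputs pass through almost unchanged, large ones saturate toward +/-cap
            std::printf("%8.1f -> %8.4f\n", x, softcap_ref(x, 1.0f / cap, cap));
        }
        return 0;
    }
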
diff --git a/ggml/src/ggml-cuda/softcap.cu b/ggml/src/ggml-cuda/softcap.cu
new file mode 100644
index 0000000000000..15ce38250fc08
--- /dev/null
+++ b/ggml/src/ggml-cuda/softcap.cu
@@ -0,0 +1,33 @@
+#include "softcap.cuh"
+
+static __global__ void softcap_f32(const float * x, float * dst, const float scale, const float softcap, const int k) {
+    const int i = blockDim.x*blockIdx.x + threadIdx.x;
+
+    if (i >= k) {
+        return;
+    }
+
+    dst[i] = tanhf(scale * x[i]) * softcap;
+}
+
+static void softcap_f32_cuda(const float * x, float * dst, const float scale, const float softcap, const int k, cudaStream_t stream) {
+    const int num_blocks = (k + CUDA_SOFTCAP_BLOCK_SIZE - 1) / CUDA_SOFTCAP_BLOCK_SIZE;
+    softcap_f32<<<num_blocks, CUDA_SOFTCAP_BLOCK_SIZE, 0, stream>>>(x, dst, scale, softcap, k);
+}
+
+void ggml_cuda_op_softcap(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    const float * src0_d = (const float *)src0->data;
+    float * dst_d = (float *)dst->data;
+    cudaStream_t stream = ctx.stream();
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+    float scale;
+    float softcap;
+    memcpy(&scale,   (float *) dst->op_params + 0, sizeof(float));
+    memcpy(&softcap, (float *) dst->op_params + 1, sizeof(float));
+
+    softcap_f32_cuda(src0_d, dst_d, scale, softcap, ggml_nelements(src0), stream);
+}
diff --git a/ggml/src/ggml-cuda/softcap.cuh b/ggml/src/ggml-cuda/softcap.cuh
new file mode 100644
index 0000000000000..2b875bfb0aabe
--- /dev/null
+++ b/ggml/src/ggml-cuda/softcap.cuh
@@ -0,0 +1,5 @@
+#include "common.cuh"
+
+#define CUDA_SOFTCAP_BLOCK_SIZE 256
+
+void ggml_cuda_op_softcap(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
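
Note: the launch math in softcap_f32_cuda is the standard one-thread-per-element pattern. A small host-side sketch (illustrative only; BLOCK_SIZE stands in for CUDA_SOFTCAP_BLOCK_SIZE) of how the ceil-division maps element counts to blocks, and why the kernel needs its i >= k early-out:

    #include <cstdio>

    #define BLOCK_SIZE 256

    int main() {
        for (int k : {1, 256, 257, 10 * 10 * 10 * 10}) {
            const int num_blocks = (k + BLOCK_SIZE - 1) / BLOCK_SIZE;
            // the last block is usually only partially populated
            std::printf("k=%6d -> blocks=%3d, threads launched=%6d\n",
                        k, num_blocks, num_blocks * BLOCK_SIZE);
        }
        return 0;
    }
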
diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp
index 7fb02a78899a6..e0d4feeeda9bb 100644
--- a/tests/test-backend-ops.cpp
+++ b/tests/test-backend-ops.cpp
@@ -2514,6 +2514,41 @@ struct test_scale : public test_case {
     }
 };
 
+// GGML_OP_SCALE + GGML_UNARY_OP_TANH + GGML_OP_SCALE
+struct test_softcap : public test_case {
+    const ggml_type type;
+    const std::array<int64_t, 4> ne;
+    float softcap;
+
+    std::string op_desc(ggml_tensor * t) override {
+        GGML_UNUSED(t);
+        return "SOFTCAP";
+    }
+
+    bool run_whole_graph() override { return true; }
+
+    std::string vars() override {
+        return VARS_TO_STR3(type, ne, softcap);
+    }
+
+    test_softcap(ggml_type type = GGML_TYPE_F32,
+            std::array<int64_t, 4> ne = {10, 10, 10, 10},
+            float softcap = 30.0f)
+        : type(type), ne(ne), softcap(softcap) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
+
+        ggml_set_param(a);
+        ggml_set_name(a, "a");
+
+        ggml_tensor * out = ggml_scale(ctx, ggml_tanh(ctx, ggml_scale(ctx, a, 1.0f / softcap)), softcap);
+        ggml_set_name(out, "out");
+
+        return out;
+    }
+};
+
 // GGML_OP_SILU_BACK
 struct test_silu_back : public test_case {
     const ggml_type type;
@@ -5390,6 +5425,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
     test_cases.emplace_back(new test_add1());
     test_cases.emplace_back(new test_scale());
     test_cases.emplace_back(new test_scale(GGML_TYPE_F32, {10, 10, 10, 10}, 2.0f, 1.0f));
+    test_cases.emplace_back(new test_softcap(GGML_TYPE_F32, {10, 10, 10, 10}, 50.0f));
     test_cases.emplace_back(new test_silu_back());
 
     for (float eps : {0.0f, 1e-6f, 1e-4f, 1e-1f}) {

From 415b825d76b05a7254648fec2a20974c4b1b6152 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Sigbj=C3=B8rn=20Skj=C3=A6ret?=
Date: Mon, 28 Jul 2025 10:15:17 +0200
Subject: [PATCH 2/6] undo switch block

ggml-ci
---
 ggml/src/ggml-cuda/ggml-cuda.cu | 83 ++++++++++++++++-----------------
 1 file changed, 40 insertions(+), 43 deletions(-)

diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
index 6cea616c92f96..a923ad47891d9 100644
--- a/ggml/src/ggml-cuda/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
@@ -2772,57 +2772,54 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
         return false;
     }
 
-    switch (ops.size()) {
-        case 2:
-            if (ops.begin()[0] == GGML_OP_RMS_NORM && ops.begin()[1] == GGML_OP_MUL) {
-                const ggml_tensor *rms_norm = cgraph->nodes[node_idx];
-                const ggml_tensor *mul      = cgraph->nodes[node_idx+1];
-
-                GGML_ASSERT(rms_norm->src[0]->type == GGML_TYPE_F32);
-                GGML_ASSERT(rms_norm->type == GGML_TYPE_F32);
-
-                //rms norm only supports F32
-                if (mul->src[0]->type != GGML_TYPE_F32 ||
-                    mul->src[1]->type != GGML_TYPE_F32 ||
-                    mul->type != GGML_TYPE_F32) {
-                    return false;
-                }
-
-                //if rms norm is the B operand, then we don't handle broadcast
-                if (rms_norm == mul->src[1] && !ggml_are_same_shape(mul->src[0], rms_norm->src[1])) {
-                    return false;
-                }
-
-                //rms_norm kernel assumes contigous rows
-                if (!ggml_is_contiguous_rows(mul->src[0]) || !ggml_is_contiguous_rows(mul->src[1])) {
-                    return false;
-                }
-            }
-            break;
-        case 3:
-            if (ops.begin()[0] == GGML_OP_SCALE && ops.begin()[1] == GGML_OP_UNARY && ops.begin()[2] == GGML_OP_SCALE
-                && unary_ops.size() == 1 && unary_ops.begin()[0] == GGML_UNARY_OP_TANH) {
-                const ggml_tensor *scale  = cgraph->nodes[node_idx];
-                const ggml_tensor *tanh   = cgraph->nodes[node_idx+1];
-                const ggml_tensor *scale2 = cgraph->nodes[node_idx+2];
-
-                GGML_ASSERT(scale->src[0]->type == GGML_TYPE_F32);
-                GGML_ASSERT(scale->type == GGML_TYPE_F32);
-
-                if (tanh->src[0] != scale || scale2->src[0] != tanh) {
-                    return false;
-                }
-
-                if (ggml_get_op_params_f32(scale, 1) != 0.0f || ggml_get_op_params_f32(scale2, 1) != 0.0f) {
-                    return false;
-                }
-            }
-            break;
-        default:
-            return false;
-    }
+    if (ops.size() == 2 && ops.begin()[0] == GGML_OP_RMS_NORM && ops.begin()[1] == GGML_OP_MUL) {
+        const ggml_tensor *rms_norm = cgraph->nodes[node_idx];
+        const ggml_tensor *mul      = cgraph->nodes[node_idx+1];
+
+        GGML_ASSERT(rms_norm->src[0]->type == GGML_TYPE_F32);
+        GGML_ASSERT(rms_norm->type == GGML_TYPE_F32);
+
+        //rms norm only supports F32
+        if (mul->src[0]->type != GGML_TYPE_F32 ||
+            mul->src[1]->type != GGML_TYPE_F32 ||
+            mul->type != GGML_TYPE_F32) {
+            return false;
+        }
+
+        //if rms norm is the B operand, then we don't handle broadcast
+        if (rms_norm == mul->src[1] && !ggml_are_same_shape(mul->src[0], rms_norm->src[1])) {
+            return false;
+        }
+
+        //rms_norm kernel assumes contigous rows
+        if (!ggml_is_contiguous_rows(mul->src[0]) || !ggml_is_contiguous_rows(mul->src[1])) {
+            return false;
+        }
+
+        return true;
+    }
+
+    if (ops.size() == 3 && ops.begin()[0] == GGML_OP_SCALE && ops.begin()[1] == GGML_OP_UNARY && ops.begin()[2] == GGML_OP_SCALE
+        && unary_ops.size() == 1 && unary_ops.begin()[0] == GGML_UNARY_OP_TANH) {
+        const ggml_tensor *scale  = cgraph->nodes[node_idx];
+        const ggml_tensor *tanh   = cgraph->nodes[node_idx+1];
+        const ggml_tensor *scale2 = cgraph->nodes[node_idx+2];
+
+        GGML_ASSERT(scale->src[0]->type == GGML_TYPE_F32);
+        GGML_ASSERT(scale->type == GGML_TYPE_F32);
+
+        if (tanh->src[0] != scale || scale2->src[0] != tanh) {
+            return false;
+        }
+
+        if (ggml_get_op_params_f32(scale, 1) != 0.0f || ggml_get_op_params_f32(scale2, 1) != 0.0f) {
+            return false;
+        }
+
+        return true;
+    }
 
-    return true;
+    return false;
 }
 
 static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph,
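
Note: the two ggml_get_op_params_f32(..., 1) checks exist because a SCALE node can also carry a bias in its second op param (via ggml_scale_bias), while the fused kernel computes a pure cap * tanh(scale * x). A compact sketch of that guard (assumed parameter layout, not patch code):

    #include <cassert>

    struct scale_params { float scale; float bias; }; // stand-in for op_params[0..1]

    static bool softcap_fusable(scale_params s1, scale_params s2) {
        // any non-zero bias cannot be folded into cap * tanh(scale * x)
        return s1.bias == 0.0f && s2.bias == 0.0f;
    }

    int main() {
        assert( softcap_fusable({1.0f / 30.0f, 0.0f}, {30.0f, 0.0f}));
        assert(!softcap_fusable({1.0f / 30.0f, 0.0f}, {30.0f, 1.5f}));
        return 0;
    }
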
From 43c75774e143098399c9e0ebcaac8e336d5dbd9d Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Sigbj=C3=B8rn=20Skj=C3=A6ret?=
Date: Mon, 28 Jul 2025 17:13:02 +0200
Subject: [PATCH 3/6] simplify by swapping op params

ggml-ci
---
 ggml/src/ggml-cuda/ggml-cuda.cu | 4 +---
 ggml/src/ggml-cuda/softcap.cu   | 4 ++--
 2 files changed, 3 insertions(+), 5 deletions(-)

diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
index a923ad47891d9..f6639fd997654 100644
--- a/ggml/src/ggml-cuda/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
@@ -2852,10 +2852,8 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
                 ggml_tensor * src0 = node->src[0];
                 float scale = ggml_get_op_params_f32(node, 0);
 
                 i += 2; node = cgraph->nodes[i];
-                float softcap = ggml_get_op_params_f32(node, 0);
 
-                ggml_set_op_params_f32(node, 0, scale);
-                ggml_set_op_params_f32(node, 1, softcap);
+                ggml_set_op_params_f32(node, 1, scale);
                 node->src[0] = src0;
 
                 ggml_cuda_op_softcap(*cuda_ctx, node);
diff --git a/ggml/src/ggml-cuda/softcap.cu b/ggml/src/ggml-cuda/softcap.cu
index 15ce38250fc08..1252d789e85b4 100644
--- a/ggml/src/ggml-cuda/softcap.cu
+++ b/ggml/src/ggml-cuda/softcap.cu
@@ -26,8 +26,8 @@ void ggml_cuda_op_softcap(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
 
     float scale;
     float softcap;
-    memcpy(&scale,   (float *) dst->op_params + 0, sizeof(float));
-    memcpy(&softcap, (float *) dst->op_params + 1, sizeof(float));
+    memcpy(&softcap, (float *) dst->op_params + 0, sizeof(float));
+    memcpy(&scale,   (float *) dst->op_params + 1, sizeof(float));
 
     softcap_f32_cuda(src0_d, dst_d, scale, softcap, ggml_nelements(src0), stream);
 }
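
Note: after this commit the fusion performs only one write. The cap is already sitting in the second SCALE node's op_params[0], so only the first node's scale factor has to be copied into the unused op_params[1] slot. A sketch of the resulting layout (op_params modeled as raw 32-bit slots, as ggml stores them; values are illustrative):

    #include <cstdint>
    #include <cstring>
    #include <cassert>

    int main() {
        int32_t op_params[2] = {0, 0};
        const float softcap = 30.0f;          // written by ggml_scale() at graph build time
        const float scale   = 1.0f / softcap; // harvested from the first SCALE node

        std::memcpy(&op_params[0], &softcap, sizeof(float)); // already in place before fusion
        std::memcpy(&op_params[1], &scale,   sizeof(float)); // the single write the fusion adds

        float s, c; // what ggml_cuda_op_softcap() reads back, in the same order
        std::memcpy(&c, &op_params[0], sizeof(float));
        std::memcpy(&s, &op_params[1], sizeof(float));
        assert(c == softcap && s == scale);
        return 0;
    }
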
From fd1c028eedce4d14f400f6c96affe4ce0c00d48e Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Sigbj=C3=B8rn=20Skj=C3=A6ret?=
Date: Mon, 28 Jul 2025 17:53:35 +0200
Subject: [PATCH 4/6] pass src tensor instead

---
 ggml/src/ggml-cuda/ggml-cuda.cu | 11 ++---------
 ggml/src/ggml-cuda/softcap.cu   |  6 +++---
 ggml/src/ggml-cuda/softcap.cuh  |  2 +-
 3 files changed, 6 insertions(+), 13 deletions(-)

diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
index f6639fd997654..dbced60a40947 100644
--- a/ggml/src/ggml-cuda/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
@@ -2848,15 +2848,8 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
             }
 
             if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_SCALE, GGML_OP_UNARY, GGML_OP_SCALE }, { GGML_UNARY_OP_TANH })) {
-                ggml_tensor * src0 = node->src[0];
-                float scale = ggml_get_op_params_f32(node, 0);
-
-                i += 2; node = cgraph->nodes[i];
-
-                ggml_set_op_params_f32(node, 1, scale);
-                node->src[0] = src0;
-
-                ggml_cuda_op_softcap(*cuda_ctx, node);
+                i += 2;
+                ggml_cuda_op_softcap(*cuda_ctx, cgraph->nodes[i], node);
                 continue;
             }
diff --git a/ggml/src/ggml-cuda/softcap.cu b/ggml/src/ggml-cuda/softcap.cu
index 1252d789e85b4..cb505d3565cc7 100644
--- a/ggml/src/ggml-cuda/softcap.cu
+++ b/ggml/src/ggml-cuda/softcap.cu
@@ -15,8 +15,8 @@ static void softcap_f32_cuda(const float * x, float * dst, const float scale, co
     softcap_f32<<<num_blocks, CUDA_SOFTCAP_BLOCK_SIZE, 0, stream>>>(x, dst, scale, softcap, k);
 }
 
-void ggml_cuda_op_softcap(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
-    const ggml_tensor * src0 = dst->src[0];
+void ggml_cuda_op_softcap(ggml_backend_cuda_context & ctx, ggml_tensor * dst, ggml_tensor * src) {
+    const ggml_tensor * src0 = src->src[0];
     const float * src0_d = (const float *)src0->data;
     float * dst_d = (float *)dst->data;
     cudaStream_t stream = ctx.stream();
@@ -26,8 +26,8 @@ void ggml_cuda_op_softcap(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
 
     float scale;
     float softcap;
+    memcpy(&scale,   (float *) src->op_params + 0, sizeof(float));
     memcpy(&softcap, (float *) dst->op_params + 0, sizeof(float));
-    memcpy(&scale,   (float *) dst->op_params + 1, sizeof(float));
 
     softcap_f32_cuda(src0_d, dst_d, scale, softcap, ggml_nelements(src0), stream);
 }
diff --git a/ggml/src/ggml-cuda/softcap.cuh b/ggml/src/ggml-cuda/softcap.cuh
index 2b875bfb0aabe..6d34fb2bee416 100644
--- a/ggml/src/ggml-cuda/softcap.cuh
+++ b/ggml/src/ggml-cuda/softcap.cuh
@@ -2,4 +2,4 @@
 
 #define CUDA_SOFTCAP_BLOCK_SIZE 256
 
-void ggml_cuda_op_softcap(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+void ggml_cuda_op_softcap(ggml_backend_cuda_context & ctx, ggml_tensor * dst, ggml_tensor * src);
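
Note: passing the first SCALE node into ggml_cuda_op_softcap removes the remaining graph mutation: the scale is read straight from the src node and the cap from the dst node, so re-evaluating the same cgraph sees unmodified nodes. A reduced sketch of the new calling convention (types simplified, names assumed):

    struct node { float op_params[2]; };

    // mirrors ggml_cuda_op_softcap(ctx, dst, src): everything is read-only
    static void op_softcap(const node & dst, const node & src, float * s, float * c) {
        *s = src.op_params[0]; // scale, from the first GGML_OP_SCALE
        *c = dst.op_params[0]; // softcap, from the second GGML_OP_SCALE
    }

    int main() {
        node first = {{1.0f / 30.0f, 0.0f}};
        node last  = {{30.0f, 0.0f}};
        float s, c;
        op_softcap(last, first, &s, &c);
        return (s == first.op_params[0] && c == last.op_params[0]) ? 0 : 1;
    }
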
From 4ec0e680bdea78f386cb4145af201bff6f3140eb Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Sigbj=C3=B8rn=20Skj=C3=A6ret?=
Date: Mon, 28 Jul 2025 22:35:15 +0200
Subject: [PATCH 5/6] check that number of unary ops matches in debug

ggml-ci
---
 ggml/src/ggml-cuda/ggml-cuda.cu | 11 ++++++-----
 ggml/src/ggml-cuda/softcap.cu   |  1 +
 2 files changed, 7 insertions(+), 5 deletions(-)

diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
index dbced60a40947..19c828aa363e8 100644
--- a/ggml/src/ggml-cuda/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
@@ -2768,6 +2768,11 @@ static void update_cuda_graph_executable(ggml_backend_cuda_context * cuda_ctx) {
 #endif
 
 static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, std::initializer_list<enum ggml_op> ops, std::initializer_list<enum ggml_unary_op> unary_ops) {
+#ifndef NDEBUG
+    const size_t num_unary = std::count(ops.begin(), ops.end(), GGML_OP_UNARY);
+    GGML_ASSERT(unary_ops.size() == num_unary);
+#endif
+
     if (!ggml_can_fuse(cgraph, node_idx, ops)) {
         return false;
     }
@@ -2802,16 +2807,12 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
     if (ops.size() == 3 && ops.begin()[0] == GGML_OP_SCALE && ops.begin()[1] == GGML_OP_UNARY && ops.begin()[2] == GGML_OP_SCALE
         && unary_ops.size() == 1 && unary_ops.begin()[0] == GGML_UNARY_OP_TANH) {
         const ggml_tensor *scale  = cgraph->nodes[node_idx];
-        const ggml_tensor *tanh   = cgraph->nodes[node_idx+1];
         const ggml_tensor *scale2 = cgraph->nodes[node_idx+2];
 
         GGML_ASSERT(scale->src[0]->type == GGML_TYPE_F32);
         GGML_ASSERT(scale->type == GGML_TYPE_F32);
 
-        if (tanh->src[0] != scale || scale2->src[0] != tanh) {
-            return false;
-        }
-
+        // Check for bias
         if (ggml_get_op_params_f32(scale, 1) != 0.0f || ggml_get_op_params_f32(scale2, 1) != 0.0f) {
             return false;
         }
diff --git a/ggml/src/ggml-cuda/softcap.cu b/ggml/src/ggml-cuda/softcap.cu
index cb505d3565cc7..40dfe45d65cf6 100644
--- a/ggml/src/ggml-cuda/softcap.cu
+++ b/ggml/src/ggml-cuda/softcap.cu
@@ -15,6 +15,7 @@ static void softcap_f32_cuda(const float * x, float * dst, const float scale, co
     softcap_f32<<<num_blocks, CUDA_SOFTCAP_BLOCK_SIZE, 0, stream>>>(x, dst, scale, softcap, k);
 }
 
+// fused GGML_OP_SCALE + GGML_UNARY_OP_TANH + GGML_OP_SCALE
 void ggml_cuda_op_softcap(ggml_backend_cuda_context & ctx, ggml_tensor * dst, ggml_tensor * src) {
     const ggml_tensor * src0 = src->src[0];
     const float * src0_d = (const float *)src0->data;

From bacf8392187a5d64b3b91c71f6accd447c6547a2 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Sigbj=C3=B8rn=20Skj=C3=A6ret?=
Date: Tue, 29 Jul 2025 00:27:09 +0200
Subject: [PATCH 6/6] completely forgot to check the unary op

---
 ggml/src/ggml-cuda/ggml-cuda.cu | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
index 19c828aa363e8..36b353f9a7607 100644
--- a/ggml/src/ggml-cuda/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
@@ -2807,11 +2807,16 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
     if (ops.size() == 3 && ops.begin()[0] == GGML_OP_SCALE && ops.begin()[1] == GGML_OP_UNARY && ops.begin()[2] == GGML_OP_SCALE
         && unary_ops.size() == 1 && unary_ops.begin()[0] == GGML_UNARY_OP_TANH) {
         const ggml_tensor *scale  = cgraph->nodes[node_idx];
+        const ggml_tensor *tanh   = cgraph->nodes[node_idx+1];
         const ggml_tensor *scale2 = cgraph->nodes[node_idx+2];
 
         GGML_ASSERT(scale->src[0]->type == GGML_TYPE_F32);
         GGML_ASSERT(scale->type == GGML_TYPE_F32);
 
+        if (ggml_get_unary_op(tanh) != GGML_UNARY_OP_TANH) {
+            return false;
+        }
+
         // Check for bias
         if (ggml_get_op_params_f32(scale, 1) != 0.0f || ggml_get_op_params_f32(scale2, 1) != 0.0f) {
            return false;