Skip to content

cuda : add softcap fusion #14907

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Jul 29, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 43 additions & 6 deletions ggml/src/ggml-cuda/ggml-cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
#include "ggml-cuda/quantize.cuh"
#include "ggml-cuda/rope.cuh"
#include "ggml-cuda/scale.cuh"
#include "ggml-cuda/softcap.cuh"
#include "ggml-cuda/softmax.cuh"
#include "ggml-cuda/ssm-conv.cuh"
#include "ggml-cuda/ssm-scan.cuh"
Expand Down Expand Up @@ -2766,7 +2767,12 @@ static void update_cuda_graph_executable(ggml_backend_cuda_context * cuda_ctx) {
}
#endif

static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, std::initializer_list<enum ggml_op> ops) {
static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, std::initializer_list<enum ggml_op> ops, std::initializer_list<enum ggml_unary_op> unary_ops) {
#ifndef NDEBUG
const size_t num_unary = std::count(ops.begin(), ops.end(), GGML_OP_UNARY);
GGML_ASSERT(unary_ops.size() == num_unary);
#endif

if (!ggml_can_fuse(cgraph, node_idx, ops)) {
return false;
}
Expand Down Expand Up @@ -2794,9 +2800,32 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
if (!ggml_is_contiguous_rows(mul->src[0]) || !ggml_is_contiguous_rows(mul->src[1])) {
return false;
}

return true;
}

return true;
if (ops.size() == 3 && ops.begin()[0] == GGML_OP_SCALE && ops.begin()[1] == GGML_OP_UNARY && ops.begin()[2] == GGML_OP_SCALE
&& unary_ops.size() == 1 && unary_ops.begin()[0] == GGML_UNARY_OP_TANH) {
const ggml_tensor *scale = cgraph->nodes[node_idx];
const ggml_tensor *tanh = cgraph->nodes[node_idx+1];
const ggml_tensor *scale2 = cgraph->nodes[node_idx+2];

GGML_ASSERT(scale->src[0]->type == GGML_TYPE_F32);
GGML_ASSERT(scale->type == GGML_TYPE_F32);

if (ggml_get_unary_op(tanh) != GGML_UNARY_OP_TANH) {
return false;
}

// Check for bias
if (ggml_get_op_params_f32(scale, 1) != 0.0f || ggml_get_op_params_f32(scale2, 1) != 0.0f) {
return false;
}

return true;
}

return false;
}

static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph,
Expand All @@ -2817,10 +2846,18 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
}

static bool disable_fusion = (getenv("GGML_CUDA_DISABLE_FUSION") != nullptr);
if (!disable_fusion && ggml_cuda_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
ggml_cuda_op_rms_norm_fused(*cuda_ctx, node, cgraph->nodes[i+1]);
i++;
continue;
if (!disable_fusion) {
if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL }, {})) {
ggml_cuda_op_rms_norm_fused(*cuda_ctx, node, cgraph->nodes[i+1]);
i++;
continue;
}

if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_SCALE, GGML_OP_UNARY, GGML_OP_SCALE }, { GGML_UNARY_OP_TANH })) {
i += 2;
ggml_cuda_op_softcap(*cuda_ctx, cgraph->nodes[i], node);
continue;
}
}
#ifndef NDEBUG
assert(node->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device));
Expand Down
34 changes: 34 additions & 0 deletions ggml/src/ggml-cuda/softcap.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
#include "softcap.cuh"

static __global__ void softcap_f32(const float * x, float * dst, const float scale, const float softcap, const int k) {
const int i = blockDim.x*blockIdx.x + threadIdx.x;

if (i >= k) {
return;
}

dst[i] = tanhf(scale * x[i]) * softcap;
}

static void softcap_f32_cuda(const float * x, float * dst, const float scale, const float softcap, const int k, cudaStream_t stream) {
const int num_blocks = (k + CUDA_SOFTCAP_BLOCK_SIZE - 1) / CUDA_SOFTCAP_BLOCK_SIZE;
softcap_f32<<<num_blocks, CUDA_SOFTCAP_BLOCK_SIZE, 0, stream>>>(x, dst, scale, softcap, k);
}

// fused GGML_OP_SCALE + GGML_UNARY_OP_TANH + GGML_OP_SCALE
void ggml_cuda_op_softcap(ggml_backend_cuda_context & ctx, ggml_tensor * dst, ggml_tensor * src) {
const ggml_tensor * src0 = src->src[0];
const float * src0_d = (const float *)src0->data;
float * dst_d = (float *)dst->data;
cudaStream_t stream = ctx.stream();

GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);

float scale;
float softcap;
memcpy(&scale, (float *) src->op_params + 0, sizeof(float));
memcpy(&softcap, (float *) dst->op_params + 0, sizeof(float));

softcap_f32_cuda(src0_d, dst_d, scale, softcap, ggml_nelements(src0), stream);
}
5 changes: 5 additions & 0 deletions ggml/src/ggml-cuda/softcap.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
#include "common.cuh"

#define CUDA_SOFTCAP_BLOCK_SIZE 256

void ggml_cuda_op_softcap(ggml_backend_cuda_context & ctx, ggml_tensor * dst, ggml_tensor * src);
36 changes: 36 additions & 0 deletions tests/test-backend-ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2514,6 +2514,41 @@ struct test_scale : public test_case {
}
};

// GGML_OP_SCALE + GGML_UNARY_OP_TANH + GGML_OP_SCALE
struct test_softcap : public test_case {
const ggml_type type;
const std::array<int64_t, 4> ne;
float softcap;

std::string op_desc(ggml_tensor * t) override {
GGML_UNUSED(t);
return "SOFTCAP";
}

bool run_whole_graph() override { return true; }

std::string vars() override {
return VARS_TO_STR3(type, ne, softcap);
}

test_softcap(ggml_type type = GGML_TYPE_F32,
std::array<int64_t, 4> ne = {10, 10, 10, 10},
float softcap = 30.0f)
: type(type), ne(ne), softcap(softcap) {}

ggml_tensor * build_graph(ggml_context * ctx) override {
ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());

ggml_set_param(a);
ggml_set_name(a, "a");

ggml_tensor * out = ggml_scale(ctx, ggml_tanh(ctx, ggml_scale(ctx, a, 1.0f / softcap)), softcap);
ggml_set_name(out, "out");

return out;
}
};

// GGML_OP_SILU_BACK
struct test_silu_back : public test_case {
const ggml_type type;
Expand Down Expand Up @@ -5390,6 +5425,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
test_cases.emplace_back(new test_add1());
test_cases.emplace_back(new test_scale());
test_cases.emplace_back(new test_scale(GGML_TYPE_F32, {10, 10, 10, 10}, 2.0f, 1.0f));
test_cases.emplace_back(new test_softcap(GGML_TYPE_F32, {10, 10, 10, 10}, 50.0f));
test_cases.emplace_back(new test_silu_back());

for (float eps : {0.0f, 1e-6f, 1e-4f, 1e-1f}) {
Expand Down