Skip to content

Commit 11dd5a4

Browse files
authored
CANN: Implement GLU ops (#14884)
Implement REGLU, GEGLU, SWIGLU ops according to #14158
1 parent 9b8f3c6 commit 11dd5a4

File tree

4 files changed

+194
-40
lines changed

4 files changed

+194
-40
lines changed

ggml/src/ggml-cann/acl_tensor.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,14 +77,16 @@ aclTensor* ggml_cann_create_tensor(const ggml_tensor* tensor, int64_t* ne,
7777
for (int i = 0; i < final_dims; i++) {
7878
acl_storage_len += (acl_ne[i] - 1) * acl_stride[i];
7979
}
80+
size_t elem_offset = offset / ggml_element_size(tensor);
81+
acl_storage_len += elem_offset;
8082

8183
// Reverse ne and stride.
8284
std::reverse(acl_ne, acl_ne + final_dims);
8385
std::reverse(acl_stride, acl_stride + final_dims);
8486

8587
aclTensor* acl_tensor = aclCreateTensor(
8688
acl_ne, final_dims, ggml_cann_type_mapping(tensor->type), acl_stride,
87-
offset / ggml_element_size(tensor), format, &acl_storage_len, 1,
89+
elem_offset, format, &acl_storage_len, 1,
8890
tensor->data);
8991

9092
return acl_tensor;

ggml/src/ggml-cann/aclnn_ops.cpp

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ void bcast_shape(ggml_tensor * src0, ggml_tensor * src1, ggml_tensor * dst, aclT
9999
}
100100
}
101101

102-
void ggml_cann_unary_op(
102+
void ggml_cann_op_unary(
103103
std::function<void(ggml_backend_cann_context&, aclTensor*, aclTensor*)> unary_op,
104104
ggml_backend_cann_context& ctx, ggml_tensor* dst) {
105105
ggml_tensor* src = dst->src[0];
@@ -111,6 +111,42 @@ void ggml_cann_unary_op(
111111
ggml_cann_release_resources(ctx, acl_src, acl_dst);
112112
}
113113

114+
void ggml_cann_op_unary_gated(
115+
std::function<void(ggml_backend_cann_context&, aclTensor*, aclTensor*)> unary_op,
116+
ggml_backend_cann_context& ctx, ggml_tensor* dst) {
117+
ggml_tensor* src0 = dst->src[0];
118+
ggml_tensor* src1 = dst->src[1];
119+
120+
GGML_ASSERT(ggml_is_contiguous_1(src0));
121+
GGML_ASSERT(ggml_is_contiguous_1(dst));
122+
const int32_t swapped = ggml_get_op_params_i32(dst, 1);
123+
124+
aclTensor* acl_dst = ggml_cann_create_tensor(dst);
125+
aclTensor *acl_src0 = nullptr, *acl_src1 = nullptr;
126+
if(src1) {
127+
GGML_ASSERT(ggml_is_contiguous_1(src1));
128+
GGML_ASSERT(src0->type == src1->type);
129+
130+
acl_src0 = ggml_cann_create_tensor(src0);
131+
acl_src1 = ggml_cann_create_tensor(src1);
132+
} else {
133+
int64_t ne[] = {src0->ne[0] / 2, src0->ne[1], src0->ne[2], src0->ne[3]};
134+
size_t nb[] = {src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3]};
135+
acl_src0 = ggml_cann_create_tensor(src0, ne, nb, GGML_MAX_DIMS, ACL_FORMAT_ND, 0);
136+
acl_src1 = ggml_cann_create_tensor(src0, ne, nb, GGML_MAX_DIMS, ACL_FORMAT_ND, ne[0] * ggml_element_size(src0));
137+
if (swapped) {
138+
std::swap(acl_src0, acl_src1);
139+
}
140+
}
141+
142+
unary_op(ctx, acl_src0, acl_dst);
143+
GGML_CANN_CALL_ACLNN_OP(ctx, InplaceMul, acl_dst, acl_src1);
144+
145+
ggml_cann_release_resources(ctx, acl_src0, acl_dst);
146+
if(src1)
147+
ggml_cann_release_resources(ctx, acl_src1);
148+
}
149+
114150
/**
115151
* @brief Repeats elements of a tensor along each dimension according to the
116152
* specified repeat array.

ggml/src/ggml-cann/aclnn_ops.h

Lines changed: 98 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1098,7 +1098,7 @@ void ggml_cann_binary_op(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
10981098
* @param dst The destination tensor. Its src[0] is treated as the input tensor.
10991099
*/
11001100
template <void unary_op(ggml_backend_cann_context&, aclTensor*, aclTensor*)>
1101-
void ggml_cann_unary_op(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
1101+
void ggml_cann_op_unary(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
11021102
ggml_tensor* src = dst->src[0];
11031103

11041104
aclTensor* acl_src = ggml_cann_create_tensor(src);
@@ -1109,49 +1109,125 @@ template <void unary_op(ggml_backend_cann_context&, aclTensor*, aclTensor*)>
11091109
}
11101110

11111111
/**
1112-
* @brief Applies a unary operation to a ggml tensor using the CANN backend.
1112+
* @brief Applies a unary operation to a ggml tensor using the CANN backend.
11131113
*
1114-
* @details This function performs a unary operation on the input tensor using
1115-
* a user-provided lambda or callable object `unary_op`, which accepts the CANN
1116-
* context and two ACL tensors (source and destination). Internally, this function
1117-
* creates ACL representations of the ggml tensors and invokes the unary operation.
1118-
* The result is stored in the destination tensor `dst`. This utility abstracts the
1119-
* common boilerplate of tensor conversion and cleanup when implementing unary ops.
1114+
* @details This function applies a unary operation to the input tensor using
1115+
* a user-provided lambda or callable `unary_op`. The lambda receives the
1116+
* CANN backend context and two ACL tensors: the source and the destination.
11201117
*
1121-
* @param unary_op A callable that performs the unary operation using CANN APIs.
1122-
* @param ctx The CANN context used for operations.
1123-
* @param dst The destination tensor where the result will be stored.
1124-
* The source tensor is retrieved from `dst->src[0]`.
1118+
* Internally, this function handles the conversion from GGML tensors to ACL tensors,
1119+
* calls the provided unary op, and manages resource cleanup. The input is assumed
1120+
* to be `dst->src[0]`, and the result is written to `dst`.
1121+
*
1122+
* This utility simplifies writing unary op wrappers by abstracting tensor preparation.
1123+
*
1124+
* @param unary_op A callable that performs the unary operation using CANN ACL APIs.
1125+
* @param ctx The CANN context for operation execution.
1126+
* @param dst The destination ggml_tensor where the result will be stored.
1127+
* The input tensor is assumed to be `dst->src[0]`.
1128+
*
1129+
* @see GGML_CANN_CALL_OP_UNARY
11251130
*/
1126-
void ggml_cann_unary_op(
1131+
void ggml_cann_op_unary(
11271132
std::function<void(ggml_backend_cann_context&, aclTensor*, aclTensor*)> unary_op,
11281133
ggml_backend_cann_context& ctx, ggml_tensor* dst);
11291134

11301135
/**
1131-
* @brief Helper macro to invoke a unary ACL operation using ggml_cann_unary_op.
1136+
* @brief Applies a gated (GLU-style) unary operation using the CANN backend.
1137+
*
1138+
* @details This function performs a gated activation such as GEGLU or ReGLU.
1139+
* It supports two input modes:
1140+
*
1141+
* 1. **Dual input mode**: `dst->src[0]` and `dst->src[1]` are both valid tensors.
1142+
* These are used directly as the value and gate tensors.
1143+
*
1144+
* 2. **Packed input mode**: Only `dst->src[0]` is valid, and it is assumed to
1145+
* contain a concatenation of value and gate along the first dimension. This tensor
1146+
* will be split into two equal halves to form the value and gate inputs.
1147+
*
1148+
* The function applies a user-provided unary operation (e.g., GELU) to the value tensor,
1149+
* then multiplies the result in-place with the gate tensor:
1150+
*
1151+
* @code
1152+
* dst = unary_op(value) * gate;
1153+
* @endcode
1154+
*
1155+
* The `swapped` parameter (from `dst->op_params[1]`) allows flipping the
1156+
* order of value/gate in the packed input case.
1157+
*
1158+
* @param unary_op A callable that performs the unary operation using CANN ACL APIs.
1159+
* It receives (ctx, acl_value_tensor, acl_output_tensor).
1160+
* @param ctx The CANN context used for execution.
1161+
* @param dst The destination ggml_tensor. Source tensors are in `dst->src[0]` and optionally `src[1]`.
1162+
*
1163+
* @see GGML_CANN_CALL_OP_UNARY_GATED
1164+
*/
1165+
void ggml_cann_op_unary_gated(
1166+
std::function<void(ggml_backend_cann_context&, aclTensor*, aclTensor*)> unary_op,
1167+
ggml_backend_cann_context& ctx, ggml_tensor* dst);
1168+
1169+
/**
1170+
* @brief Helper macro to call a unary ACL operator via ggml_cann_op_unary.
1171+
*
1172+
* This macro wraps the specified ACLNN unary operator name into a lambda expression,
1173+
* and passes it to `ggml_cann_op_unary`, which handles the common logic for executing
1174+
* unary ops in the CANN backend.
1175+
*
1176+
* Internally, this macro expands to a lambda like:
1177+
* @code
1178+
* [](ggml_backend_cann_context& ctx, aclTensor* acl_src, aclTensor* acl_dst) {
1179+
* GGML_CANN_CALL_ACLNN_OP(ctx, OP_NAME, acl_src, acl_dst);
1180+
* };
1181+
* @endcode
1182+
*
1183+
* This lambda is then passed to `ggml_cann_op_unary`, which applies the operation.
1184+
*
1185+
* @param OP_NAME The name of the ACL unary operator to invoke via GGML_CANN_CALL_ACLNN_OP.
1186+
*
1187+
* @see ggml_cann_op_unary
1188+
* @see GGML_CANN_CALL_ACLNN_OP
1189+
*/
1190+
#define GGML_CANN_CALL_OP_UNARY(OP_NAME) \
1191+
do { \
1192+
auto lambda = [](ggml_backend_cann_context& ctx, \
1193+
aclTensor* acl_src, \
1194+
aclTensor* acl_dst) { \
1195+
GGML_CANN_CALL_ACLNN_OP(ctx, OP_NAME, acl_src, acl_dst); \
1196+
}; \
1197+
ggml_cann_op_unary(lambda, ctx, dst); \
1198+
} \
1199+
while (0)
1200+
1201+
/**
1202+
* @brief Helper macro to call a gated unary ACL operator via ggml_cann_op_unary_gated.
11321203
*
1133-
* This macro defines an inline lambda wrapping a specific ACL operation name,
1134-
* and passes it to the templated ggml_cann_unary_op function. It simplifies
1135-
* calling unary ops by hiding the lambda boilerplate.
1204+
* This macro wraps the specified ACLNN unary operator name into a lambda expression,
1205+
* and passes it to `ggml_cann_op_unary_gated`, which handles the common logic for
1206+
* executing gated unary ops in the CANN backend.
11361207
*
1137-
* Internally, the lambda will call:
1208+
* Internally, this macro expands to a lambda like:
11381209
* @code
1139-
* GGML_CANN_CALL_ACLNN_OP(ctx, OP_NAME, acl_src, acl_dst);
1210+
* [](ggml_backend_cann_context& ctx, aclTensor* acl_src, aclTensor* acl_dst) {
1211+
* GGML_CANN_CALL_ACLNN_OP(ctx, OP_NAME, acl_src, acl_dst);
1212+
* };
11401213
* @endcode
11411214
*
1215+
* This lambda is then passed to `ggml_cann_op_unary_gated`, which applies the operation.
1216+
*
11421217
* @param OP_NAME The name of the ACL unary operator to invoke via GGML_CANN_CALL_ACLNN_OP.
11431218
*
1144-
* @see ggml_cann_unary_op
1219+
* @see ggml_cann_op_unary_gated
11451220
* @see GGML_CANN_CALL_ACLNN_OP
11461221
*/
1147-
#define GGML_CANN_CALL_UNARY_OP(OP_NAME) \
1222+
#define GGML_CANN_CALL_OP_UNARY_GATED(OP_NAME) \
11481223
do { \
11491224
auto lambda = [](ggml_backend_cann_context& ctx, \
11501225
aclTensor* acl_src, \
11511226
aclTensor* acl_dst) { \
11521227
GGML_CANN_CALL_ACLNN_OP(ctx, OP_NAME, acl_src, acl_dst); \
11531228
}; \
1154-
ggml_cann_unary_op(lambda, ctx, dst); \
1229+
ggml_cann_op_unary_gated(lambda, ctx, dst); \
11551230
} \
11561231
while (0)
1232+
11571233
#endif // CANN_ACLNN_OPS

ggml/src/ggml-cann/ggml-cann.cpp

Lines changed: 56 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1681,48 +1681,50 @@ static bool ggml_cann_compute_forward(ggml_backend_cann_context& ctx,
16811681
case GGML_OP_UNARY:
16821682
switch (ggml_get_unary_op(dst)) {
16831683
case GGML_UNARY_OP_ABS:
1684-
GGML_CANN_CALL_UNARY_OP(Abs);
1684+
GGML_CANN_CALL_OP_UNARY(Abs);
16851685
break;
16861686
case GGML_UNARY_OP_NEG:
1687-
GGML_CANN_CALL_UNARY_OP(Neg);
1687+
GGML_CANN_CALL_OP_UNARY(Neg);
16881688
break;
16891689
case GGML_UNARY_OP_GELU:
1690-
GGML_CANN_CALL_UNARY_OP(Gelu);
1690+
case GGML_UNARY_OP_GELU_ERF:
1691+
// aclnnGelu internally uses the erf-based approximation.
1692+
GGML_CANN_CALL_OP_UNARY(Gelu);
16911693
break;
16921694
case GGML_UNARY_OP_SILU:
1693-
GGML_CANN_CALL_UNARY_OP(Silu);
1695+
GGML_CANN_CALL_OP_UNARY(Silu);
16941696
break;
16951697
case GGML_UNARY_OP_GELU_QUICK: {
16961698
auto lambda = [](ggml_backend_cann_context& ctx,
16971699
aclTensor* acl_src,
16981700
aclTensor* acl_dst) {
16991701
GGML_CANN_CALL_ACLNN_OP(ctx, GeluV2, acl_src, 0, acl_dst);
17001702
};
1701-
ggml_cann_unary_op(lambda, ctx, dst);
1703+
ggml_cann_op_unary(lambda, ctx, dst);
17021704
} break;
17031705
case GGML_UNARY_OP_TANH:
1704-
GGML_CANN_CALL_UNARY_OP(Tanh);
1706+
GGML_CANN_CALL_OP_UNARY(Tanh);
17051707
break;
17061708
case GGML_UNARY_OP_RELU:
1707-
GGML_CANN_CALL_UNARY_OP(Relu);
1709+
GGML_CANN_CALL_OP_UNARY(Relu);
17081710
break;
17091711
case GGML_UNARY_OP_SIGMOID:
1710-
GGML_CANN_CALL_UNARY_OP(Sigmoid);
1712+
GGML_CANN_CALL_OP_UNARY(Sigmoid);
17111713
break;
17121714
case GGML_UNARY_OP_HARDSIGMOID:
1713-
GGML_CANN_CALL_UNARY_OP(Hardsigmoid);
1715+
GGML_CANN_CALL_OP_UNARY(Hardsigmoid);
17141716
break;
17151717
case GGML_UNARY_OP_HARDSWISH:
1716-
GGML_CANN_CALL_UNARY_OP(Hardswish);
1718+
GGML_CANN_CALL_OP_UNARY(Hardswish);
17171719
break;
17181720
case GGML_UNARY_OP_EXP:
1719-
GGML_CANN_CALL_UNARY_OP(Exp);
1721+
GGML_CANN_CALL_OP_UNARY(Exp);
17201722
break;
17211723
case GGML_UNARY_OP_ELU:
17221724
ggml_cann_elu(ctx, dst);
17231725
break;
17241726
case GGML_UNARY_OP_SGN:
1725-
GGML_CANN_CALL_UNARY_OP(Sign);
1727+
GGML_CANN_CALL_OP_UNARY(Sign);
17261728
break;
17271729
case GGML_UNARY_OP_STEP:
17281730
ggml_cann_step(ctx, dst);
@@ -1731,6 +1733,31 @@ static bool ggml_cann_compute_forward(ggml_backend_cann_context& ctx,
17311733
return false;
17321734
}
17331735
break;
1736+
case GGML_OP_GLU:
1737+
switch (ggml_get_glu_op(dst)) {
1738+
case GGML_GLU_OP_REGLU:
1739+
GGML_CANN_CALL_OP_UNARY_GATED(Relu);
1740+
break;
1741+
case GGML_GLU_OP_GEGLU:
1742+
case GGML_GLU_OP_GEGLU_ERF:
1743+
// aclnnGelu internally uses the erf-based approximation.
1744+
GGML_CANN_CALL_OP_UNARY_GATED(Gelu);
1745+
break;
1746+
case GGML_GLU_OP_SWIGLU:
1747+
GGML_CANN_CALL_OP_UNARY_GATED(Silu);
1748+
break;
1749+
case GGML_GLU_OP_GEGLU_QUICK: {
1750+
auto lambda = [](ggml_backend_cann_context& ctx,
1751+
aclTensor* acl_src,
1752+
aclTensor* acl_dst) {
1753+
GGML_CANN_CALL_ACLNN_OP(ctx, GeluV2, acl_src, 0, acl_dst);
1754+
};
1755+
ggml_cann_op_unary_gated(lambda, ctx, dst);
1756+
} break;
1757+
default:
1758+
return false;
1759+
}
1760+
break;
17341761
case GGML_OP_NORM:
17351762
ggml_cann_norm(ctx, dst);
17361763
break;
@@ -1773,7 +1800,7 @@ static bool ggml_cann_compute_forward(ggml_backend_cann_context& ctx,
17731800
ggml_cann_binary_op<aclnn_mul>(ctx, dst);
17741801
break;
17751802
case GGML_OP_SQRT:
1776-
GGML_CANN_CALL_UNARY_OP(Sqrt);
1803+
GGML_CANN_CALL_OP_UNARY(Sqrt);
17771804
break;
17781805
case GGML_OP_CLAMP:
17791806
ggml_cann_clamp(ctx, dst);
@@ -1818,16 +1845,16 @@ static bool ggml_cann_compute_forward(ggml_backend_cann_context& ctx,
18181845
ggml_cann_argmax(ctx, dst);
18191846
break;
18201847
case GGML_OP_COS:
1821-
ggml_cann_unary_op<aclnn_cos>(ctx, dst);
1848+
ggml_cann_op_unary<aclnn_cos>(ctx, dst);
18221849
break;
18231850
case GGML_OP_SIN:
1824-
ggml_cann_unary_op<aclnn_sin>(ctx, dst);
1851+
ggml_cann_op_unary<aclnn_sin>(ctx, dst);
18251852
break;
18261853
case GGML_OP_CONV_TRANSPOSE_1D:
18271854
ggml_cann_conv_transpose_1d(ctx, dst);
18281855
break;
18291856
case GGML_OP_LOG:
1830-
GGML_CANN_CALL_UNARY_OP(Log);
1857+
GGML_CANN_CALL_OP_UNARY(Log);
18311858
break;
18321859
case GGML_OP_MEAN:
18331860
ggml_cann_mean(ctx, dst);
@@ -2101,10 +2128,23 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
21012128
case GGML_UNARY_OP_ELU:
21022129
case GGML_UNARY_OP_SGN:
21032130
case GGML_UNARY_OP_STEP:
2131+
case GGML_UNARY_OP_GELU_ERF:
21042132
return true;
21052133
default:
21062134
return false;
21072135
}
2136+
case GGML_OP_GLU:
2137+
switch (ggml_get_glu_op(op)) {
2138+
case GGML_GLU_OP_REGLU:
2139+
case GGML_GLU_OP_GEGLU:
2140+
case GGML_GLU_OP_SWIGLU:
2141+
case GGML_GLU_OP_GEGLU_ERF:
2142+
case GGML_GLU_OP_GEGLU_QUICK:
2143+
return true;
2144+
default:
2145+
return false;
2146+
}
2147+
break;
21082148
case GGML_OP_MUL_MAT: {
21092149
switch (op->src[0]->type) {
21102150
case GGML_TYPE_F16:

0 commit comments

Comments
 (0)