@@ -1098,7 +1098,7 @@ void ggml_cann_binary_op(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
1098
1098
* @param dst The destination tensor. Its src[0] is treated as the input tensor.
1099
1099
*/
1100
1100
template <void unary_op (ggml_backend_cann_context&, aclTensor*, aclTensor*)>
1101
- void ggml_cann_unary_op (ggml_backend_cann_context& ctx, ggml_tensor* dst) {
1101
+ void ggml_cann_op_unary (ggml_backend_cann_context& ctx, ggml_tensor* dst) {
1102
1102
ggml_tensor* src = dst->src [0 ];
1103
1103
1104
1104
aclTensor* acl_src = ggml_cann_create_tensor (src);
@@ -1109,49 +1109,125 @@ template <void unary_op(ggml_backend_cann_context&, aclTensor*, aclTensor*)>
1109
1109
}
1110
1110
1111
1111
/**
1112
- * @brief Applies a unary operation to a ggml tensor using the CANN backend.
1112
+ * @brief Applies a unary operation to a ggml tensor using the CANN backend.
1113
1113
*
1114
- * @details This function performs a unary operation on the input tensor using
1115
- * a user-provided lambda or callable object `unary_op`, which accepts the CANN
1116
- * context and two ACL tensors (source and destination). Internally, this function
1117
- * creates ACL representations of the ggml tensors and invokes the unary operation.
1118
- * The result is stored in the destination tensor `dst`. This utility abstracts the
1119
- * common boilerplate of tensor conversion and cleanup when implementing unary ops.
1114
+ * @details This function applies a unary operation to the input tensor using
1115
+ * a user-provided lambda or callable `unary_op`. The lambda receives the
1116
+ * CANN backend context and two ACL tensors: the source and the destination.
1120
1117
*
1121
- * @param unary_op A callable that performs the unary operation using CANN APIs.
1122
- * @param ctx The CANN context used for operations.
1123
- * @param dst The destination tensor where the result will be stored.
1124
- * The source tensor is retrieved from `dst->src[0]`.
1118
+ * Internally, this function handles the conversion from GGML tensors to ACL tensors,
1119
+ * calls the provided unary op, and manages resource cleanup. The input is assumed
1120
+ * to be `dst->src[0]`, and the result is written to `dst`.
1121
+ *
1122
+ * This utility simplifies writing unary op wrappers by abstracting tensor preparation.
1123
+ *
1124
+ * @param unary_op A callable that performs the unary operation using CANN ACL APIs.
1125
+ * @param ctx The CANN context for operation execution.
1126
+ * @param dst The destination ggml_tensor where the result will be stored.
1127
+ * The input tensor is assumed to be `dst->src[0]`.
1128
+ *
1129
+ * @see GGML_CANN_CALL_OP_UNARY
1125
1130
*/
1126
- void ggml_cann_unary_op (
1131
+ void ggml_cann_op_unary (
1127
1132
std::function<void (ggml_backend_cann_context&, aclTensor*, aclTensor*)> unary_op,
1128
1133
ggml_backend_cann_context& ctx, ggml_tensor* dst);
1129
1134
1130
1135
/**
1131
- * @brief Helper macro to invoke a unary ACL operation using ggml_cann_unary_op.
1136
+ * @brief Applies a gated (GLU-style) unary operation using the CANN backend.
1137
+ *
1138
+ * @details This function performs a gated activation such as GEGLU or ReGLU.
1139
+ * It supports two input modes:
1140
+ *
1141
+ * 1. **Dual input mode**: `dst->src[0]` and `dst->src[1]` are both valid tensors.
1142
+ * These are used directly as the value and gate tensors.
1143
+ *
1144
+ * 2. **Packed input mode**: Only `dst->src[0]` is valid, and it is assumed to
1145
+ * contain a concatenation of value and gate along the first dimension. This tensor
1146
+ * will be split into two equal halves to form the value and gate inputs.
1147
+ *
1148
+ * The function applies a user-provided unary operation (e.g., GELU) to the value tensor,
1149
+ * then multiplies the result in-place with the gate tensor:
1150
+ *
1151
+ * @code
1152
+ * dst = unary_op(value) * gate;
1153
+ * @endcode
1154
+ *
1155
+ * The `swapped` parameter (from `dst->op_params[1]`) allows flipping the
1156
+ * order of value/gate in the packed input case.
1157
+ *
1158
+ * @param unary_op A callable that performs the unary operation using CANN ACL APIs.
1159
+ * It receives (ctx, acl_value_tensor, acl_output_tensor).
1160
+ * @param ctx The CANN context used for execution.
1161
+ * @param dst The destination ggml_tensor. Source tensors are in `dst->src[0]` and optionally `src[1]`.
1162
+ *
1163
+ * @see GGML_CANN_CALL_OP_UNARY_GATED
1164
+ */
1165
+ void ggml_cann_op_unary_gated (
1166
+ std::function<void (ggml_backend_cann_context&, aclTensor*, aclTensor*)> unary_op,
1167
+ ggml_backend_cann_context& ctx, ggml_tensor* dst);
1168
+
1169
+ /**
1170
+ * @brief Helper macro to call a unary ACL operator via ggml_cann_op_unary.
1171
+ *
1172
+ * This macro wraps the specified ACLNN unary operator name into a lambda expression,
1173
+ * and passes it to `ggml_cann_op_unary`, which handles the common logic for executing
1174
+ * unary ops in the CANN backend.
1175
+ *
1176
+ * Internally, this macro expands to a lambda like:
1177
+ * @code
1178
+ * [](ggml_backend_cann_context& ctx, aclTensor* acl_src, aclTensor* acl_dst) {
1179
+ * GGML_CANN_CALL_ACLNN_OP(ctx, OP_NAME, acl_src, acl_dst);
1180
+ * };
1181
+ * @endcode
1182
+ *
1183
+ * This lambda is then passed to `ggml_cann_op_unary`, which applies the operation.
1184
+ *
+ * @note The expansion passes `ctx` and `dst` to `ggml_cann_op_unary` by name.
+ * They are not macro parameters, so variables with exactly those names must
+ * be visible where the macro is used. The lambda itself captures nothing and
+ * only uses the `ctx` it receives as its own first parameter.
+ *
+ * @note The body is wrapped in `do { ... } while (0)`, so the macro is used as
+ * a single statement and the call site supplies the trailing semicolon.
+ *
1185
+ * @param OP_NAME The name of the ACL unary operator to invoke via GGML_CANN_CALL_ACLNN_OP.
1186
+ *
1187
+ * @see ggml_cann_op_unary
1188
+ * @see GGML_CANN_CALL_ACLNN_OP
1189
+ */
1190
+ #define GGML_CANN_CALL_OP_UNARY (OP_NAME ) \
1191
+ do { \
1192
+ auto lambda = [](ggml_backend_cann_context& ctx, \
1193
+ aclTensor* acl_src, \
1194
+ aclTensor* acl_dst) { \
1195
+ GGML_CANN_CALL_ACLNN_OP (ctx, OP_NAME, acl_src, acl_dst); \
1196
+ }; \
1197
+ ggml_cann_op_unary (lambda, ctx, dst); \
1198
+ } \
1199
+ while (0 )
1200
+
1201
+ /**
1202
+ * @brief Helper macro to call a gated unary ACL operator via ggml_cann_op_unary_gated.
1132
1203
*
1133
- * This macro defines an inline lambda wrapping a specific ACL operation name,
1134
- * and passes it to the templated ggml_cann_unary_op function. It simplifies
1135
- * calling unary ops by hiding the lambda boilerplate.
1204
+ * This macro wraps the specified ACLNN unary operator name into a lambda expression,
1205
+ * and passes it to `ggml_cann_op_unary_gated`, which handles the common logic for
1206
+ * executing gated unary ops in the CANN backend.
1136
1207
*
1137
- * Internally, the lambda will call:
1208
+ * Internally, this macro expands to a lambda like:
1138
1209
* @code
1139
- * GGML_CANN_CALL_ACLNN_OP(ctx, OP_NAME, acl_src, acl_dst);
1210
+ * [](ggml_backend_cann_context& ctx, aclTensor* acl_src, aclTensor* acl_dst) {
1211
+ * GGML_CANN_CALL_ACLNN_OP(ctx, OP_NAME, acl_src, acl_dst);
1212
+ * };
1140
1213
* @endcode
1141
1214
*
1215
+ * This lambda is then passed to `ggml_cann_op_unary_gated`, which applies the operation.
1216
+ *
+ * @note The expansion passes `ctx` and `dst` to `ggml_cann_op_unary_gated` by
+ * name. They are not macro parameters, so variables with exactly those names
+ * must be visible where the macro is used. The lambda itself captures nothing
+ * and only uses the `ctx` it receives as its own first parameter.
+ *
+ * @note The body is wrapped in `do { ... } while (0)`, so the macro is used as
+ * a single statement and the call site supplies the trailing semicolon.
+ *
1142
1217
* @param OP_NAME The name of the ACL unary operator to invoke via GGML_CANN_CALL_ACLNN_OP.
1143
1218
*
1144
- * @see ggml_cann_unary_op
1219
+ * @see ggml_cann_op_unary_gated
1145
1220
* @see GGML_CANN_CALL_ACLNN_OP
1146
1221
*/
1147
- #define GGML_CANN_CALL_UNARY_OP (OP_NAME ) \
1222
+ #define GGML_CANN_CALL_OP_UNARY_GATED (OP_NAME ) \
1148
1223
do { \
1149
1224
auto lambda = [](ggml_backend_cann_context& ctx, \
1150
1225
aclTensor* acl_src, \
1151
1226
aclTensor* acl_dst) { \
1152
1227
GGML_CANN_CALL_ACLNN_OP (ctx, OP_NAME, acl_src, acl_dst); \
1153
1228
}; \
1154
- ggml_cann_unary_op (lambda, ctx, dst); \
1229
+ ggml_cann_op_unary_gated (lambda, ctx, dst); \
1155
1230
} \
1156
1231
while (0 )
1232
+
1157
1233
#endif // CANN_ACLNN_OPS
0 commit comments