From fa3fe8b2489889ab46ba34ab880a7a3f0848bdcf Mon Sep 17 00:00:00 2001 From: Zhang Jianyu Date: Wed, 12 Nov 2025 21:08:33 +0800 Subject: [PATCH] refactor pad_reflect_1d to make the UT case pass --- ggml/src/ggml-sycl/common.hpp | 26 +++++ ggml/src/ggml-sycl/pad_reflect_1d.cpp | 146 +++++++++++++++----------- ggml/src/ggml-sycl/pad_reflect_1d.hpp | 2 + 3 files changed, 115 insertions(+), 59 deletions(-) diff --git a/ggml/src/ggml-sycl/common.hpp b/ggml/src/ggml-sycl/common.hpp index 338fa08cda205..637630c1d230a 100644 --- a/ggml/src/ggml-sycl/common.hpp +++ b/ggml/src/ggml-sycl/common.hpp @@ -617,4 +617,30 @@ static __dpct_inline__ float get_alibi_slope(const float max_bias, return dpct::pow(base, exph); } +static const sycl::uint3 init_fastdiv_values(uint32_t d) { + GGML_ASSERT(d != 0); + + uint32_t L = 0; + while (L < 32 && (uint32_t{ 1 } << L) < d) { + L++; + } + + uint32_t mp = (uint32_t) ((uint64_t{ 1 } << 32) * ((uint64_t{ 1 } << L) - d) / d + 1); + return sycl::uint3(mp, L, d); +} + + +static __dpct_inline__ uint32_t fastdiv(uint32_t n, const sycl::uint3 fastdiv_values) { + const uint32_t hi = sycl::mul_hi(n, fastdiv_values.x()); + return (hi + n) >> fastdiv_values.y(); +} + + +static __dpct_inline__ sycl::uint2 fast_div_modulo(uint32_t n, const sycl::uint3 fastdiv_values) { + const uint32_t div_val = fastdiv(n, fastdiv_values); + const uint32_t mod_val = n - div_val * fastdiv_values.z(); + return sycl::uint2(div_val, mod_val); +} + + #endif // GGML_SYCL_COMMON_HPP diff --git a/ggml/src/ggml-sycl/pad_reflect_1d.cpp b/ggml/src/ggml-sycl/pad_reflect_1d.cpp index e56655a98a106..85e993628c65f 100644 --- a/ggml/src/ggml-sycl/pad_reflect_1d.cpp +++ b/ggml/src/ggml-sycl/pad_reflect_1d.cpp @@ -1,72 +1,100 @@ #include "pad_reflect_1d.hpp" -void pad_reflect_1d_f32(const float* src,float* dst, - const int64_t ne0, const int64_t ne02, const int p0, const int p1, - const int64_t nb0, const int64_t nb1, const int64_t nb2, const int64_t nb3, - const int64_t nb00, const int64_t nb01, const int64_t nb02, const int64_t nb03, - const sycl::nd_item<3> &item_ct1){ - - const int i0 = item_ct1.get_group(0) * SYCL_CONCAT_BLOCK_SIZE + item_ct1.get_local_id(0); - const int i1 = item_ct1.get_group(1); - const int g2 = item_ct1.get_group(2); - const int i2 = g2 % ne02; - const int i3 = g2 / ne02; - - if (i0 >= p0 + ne0 + p1) return; - - int t = i0 - p0; - int period = 2 * ne0 -2; - int m = t % period; - m += (m < 0) * period; - int center = ne0 -1; - int srci0 = center - abs(center - m); - - int offest_src = i3*nb3 + i2*nb2 + i1*nb1 + srci0*nb0; - int offest_dst = i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00; - dst[offest_dst] = src[offest_src]; +static void pad_reflect_1d_kernel_f32( + const void *__restrict__ src0, void *__restrict__ dst, const int64_t ne0, + const int64_t ne00, const sycl::uint3 ne01, const int64_t ne02, + const int64_t ne03, const int64_t nb00, const int64_t nb01, + const int64_t nb02, const int64_t nb03, const int64_t nb0, + const int64_t nb1, const int64_t nb2, const int64_t nb3, const int p0, + const int p1, sycl::nd_item<3> item_ct1) { + const int64_t i3 = item_ct1.get_group(0); + const int64_t i2 = item_ct1.get_group(1); + + const sycl::uint2 div_mod_packed = + fast_div_modulo(item_ct1.get_group(2), ne01); + const int64_t tile1 = div_mod_packed.y(); + const int64_t tile0 = div_mod_packed.x(); + const int64_t i1 = tile1; + const int64_t i0 = + item_ct1.get_local_id(2) + tile0 * item_ct1.get_local_range(2); + + if (i0 >= ne0 || i1 >= ne01.z() || i2 >= ne02 || i3 >= ne03) { + return; + } + + const char *src0_ptr = + (const char *)src0 + i3 * nb03 + i2 * nb02 + i1 * nb01; + char *dst_ptr = (char *)dst + i3 * nb3 + i2 * nb2 + i1 * nb1; + + const int64_t rel_i0 = i0 - p0; // relative i0 in src0 + int64_t src_idx; + + if (rel_i0 < 0) { + // Left padding - reflect + src_idx = -rel_i0; + } else if (rel_i0 < ne00) { + // Middle - copy + src_idx = rel_i0; + } else { + // Right padding - reflect + src_idx = 2 * ne00 - 2 - rel_i0; + } + const float value = *(const float *)(src0_ptr + src_idx * nb00); + *(float *)(dst_ptr + i0 * nb0) = value; + + GGML_UNUSED(p1); } -void ggml_sycl_op_pad_reflect_1d(ggml_backend_sycl_context& ctx, ggml_tensor* dst){ +void ggml_sycl_op_pad_reflect_1d(ggml_backend_sycl_context &ctx, + ggml_tensor *dst) { - const ggml_tensor * src0 = dst->src[0]; - queue_ptr stream = ctx.stream(); + const ggml_tensor *src0 = dst->src[0]; + dpct::queue_ptr stream = ctx.stream(); GGML_ASSERT(src0->type == GGML_TYPE_F32); - GGML_ASSERT( dst->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); - const int32_t * opts = (const int32_t *) dst->op_params; + const int32_t *opts = (const int32_t *)dst->op_params; const int p0 = opts[0]; const int p1 = opts[1]; - const int64_t ne0 = src0->ne[0]; - - const int64_t ne00 = dst->ne[0]; - const int64_t ne01 = dst->ne[1]; - const int64_t ne02 = dst->ne[2]; - const int64_t ne03 = dst->ne[3]; - - const int64_t nb00 = dst->nb[0]; - const int64_t nb01 = dst->nb[1]; - const int64_t nb02 = dst->nb[2]; - const int64_t nb03 = dst->nb[3]; - const int64_t nb0 = src0->nb[0]; - const int64_t nb1 = src0->nb[1]; - const int64_t nb2 = src0->nb[2]; - const int64_t nb3 = src0->nb[3]; - - int num_blocks = (ne00 + SYCL_CONCAT_BLOCK_SIZE - 1) / SYCL_CONCAT_BLOCK_SIZE; - sycl::range<3> global(num_blocks * SYCL_CONCAT_BLOCK_SIZE, ne01, ne02*ne03); - sycl::range<3> local(SYCL_CONCAT_BLOCK_SIZE, 1, 1); - - stream->parallel_for( - sycl::nd_range<3>(global, - local), - [=](sycl::nd_item<3> item_ct1) { pad_reflect_1d_f32( - (const float *) src0->data, (float *) dst->data, - ne0, ne02, p0, p1, - nb0, nb1, nb2, nb3, - nb00, nb01, nb02, nb03 - , item_ct1); - }); + const int64_t ne00 = src0->ne[0]; + const int64_t ne01 = src0->ne[1]; + const sycl::uint3 ne01_packed = init_fastdiv_values(ne01); + const int64_t ne02 = src0->ne[2]; + const int64_t ne03 = src0->ne[3]; + + const int64_t ne0 = dst->ne[0]; + + GGML_ASSERT(ne0 == ne00 + p0 + p1); + + constexpr int64_t bx = SYCL_PAD_REFLECT_1D_BLOCK_SIZE; + const int64_t tiles0 = (ne0 + bx - 1) / bx; + const dpct::dim3 grid_dims((unsigned)(ne01 * tiles0), (unsigned)ne02, + (unsigned)ne03); + const dpct::dim3 block_dims((unsigned)bx, 1, 1); + + stream->submit([&](sycl::handler &cgh) { + auto src0_data_ct0 = src0->data; + auto dst_data_ct1 = dst->data; + auto src0_nb_ct7 = src0->nb[0]; + auto src0_nb_ct8 = src0->nb[1]; + auto src0_nb_ct9 = src0->nb[2]; + auto src0_nb_ct10 = src0->nb[3]; + auto dst_nb_ct11 = dst->nb[0]; + auto dst_nb_ct12 = dst->nb[1]; + auto dst_nb_ct13 = dst->nb[2]; + auto dst_nb_ct14 = dst->nb[3]; + + cgh.parallel_for(sycl::nd_range<3>(grid_dims * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) { + pad_reflect_1d_kernel_f32( + src0_data_ct0, dst_data_ct1, ne0, ne00, + ne01_packed, ne02, ne03, src0_nb_ct7, + src0_nb_ct8, src0_nb_ct9, src0_nb_ct10, + dst_nb_ct11, dst_nb_ct12, dst_nb_ct13, + dst_nb_ct14, p0, p1, item_ct1); + }); + }); } diff --git a/ggml/src/ggml-sycl/pad_reflect_1d.hpp b/ggml/src/ggml-sycl/pad_reflect_1d.hpp index a24509dea6384..45aaf9a9111f2 100644 --- a/ggml/src/ggml-sycl/pad_reflect_1d.hpp +++ b/ggml/src/ggml-sycl/pad_reflect_1d.hpp @@ -3,6 +3,8 @@ #include "common.hpp" +#define SYCL_PAD_REFLECT_1D_BLOCK_SIZE 256 + void ggml_sycl_op_pad_reflect_1d(ggml_backend_sycl_context& ctx, ggml_tensor* dst); #endif // GGML_SYCL_PAD_REFLECT_1D_HPP