Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions ggml/src/ggml-sycl/common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -617,4 +617,30 @@ static __dpct_inline__ float get_alibi_slope(const float max_bias,
return dpct::pow(base, exph);
}

static const sycl::uint3 init_fastdiv_values(uint32_t d) {
GGML_ASSERT(d != 0);

uint32_t L = 0;
while (L < 32 && (uint32_t{ 1 } << L) < d) {
L++;
}

uint32_t mp = (uint32_t) ((uint64_t{ 1 } << 32) * ((uint64_t{ 1 } << L) - d) / d + 1);
return sycl::uint3(mp, L, d);
}


static __dpct_inline__ uint32_t fastdiv(uint32_t n, const sycl::uint3 fastdiv_values) {
const uint32_t hi = sycl::mul_hi<unsigned>(n, fastdiv_values.x());
return (hi + n) >> fastdiv_values.y();
}


static __dpct_inline__ sycl::uint2 fast_div_modulo(uint32_t n, const sycl::uint3 fastdiv_values) {
const uint32_t div_val = fastdiv(n, fastdiv_values);
const uint32_t mod_val = n - div_val * fastdiv_values.z();
return sycl::uint2(div_val, mod_val);
}


#endif // GGML_SYCL_COMMON_HPP
146 changes: 87 additions & 59 deletions ggml/src/ggml-sycl/pad_reflect_1d.cpp
Original file line number Diff line number Diff line change
@@ -1,72 +1,100 @@
#include "pad_reflect_1d.hpp"

void pad_reflect_1d_f32(const float* src,float* dst,
const int64_t ne0, const int64_t ne02, const int p0, const int p1,
const int64_t nb0, const int64_t nb1, const int64_t nb2, const int64_t nb3,
const int64_t nb00, const int64_t nb01, const int64_t nb02, const int64_t nb03,
const sycl::nd_item<3> &item_ct1){

const int i0 = item_ct1.get_group(0) * SYCL_CONCAT_BLOCK_SIZE + item_ct1.get_local_id(0);
const int i1 = item_ct1.get_group(1);
const int g2 = item_ct1.get_group(2);
const int i2 = g2 % ne02;
const int i3 = g2 / ne02;

if (i0 >= p0 + ne0 + p1) return;

int t = i0 - p0;
int period = 2 * ne0 -2;
int m = t % period;
m += (m < 0) * period;
int center = ne0 -1;
int srci0 = center - abs(center - m);

int offest_src = i3*nb3 + i2*nb2 + i1*nb1 + srci0*nb0;
int offest_dst = i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00;
dst[offest_dst] = src[offest_src];
static void pad_reflect_1d_kernel_f32(
const void *__restrict__ src0, void *__restrict__ dst, const int64_t ne0,
const int64_t ne00, const sycl::uint3 ne01, const int64_t ne02,
const int64_t ne03, const int64_t nb00, const int64_t nb01,
const int64_t nb02, const int64_t nb03, const int64_t nb0,
const int64_t nb1, const int64_t nb2, const int64_t nb3, const int p0,
const int p1, sycl::nd_item<3> item_ct1) {

const int64_t i3 = item_ct1.get_group(0);
const int64_t i2 = item_ct1.get_group(1);

const sycl::uint2 div_mod_packed =
fast_div_modulo(item_ct1.get_group(2), ne01);
const int64_t tile1 = div_mod_packed.y();
const int64_t tile0 = div_mod_packed.x();
const int64_t i1 = tile1;
const int64_t i0 =
item_ct1.get_local_id(2) + tile0 * item_ct1.get_local_range(2);

if (i0 >= ne0 || i1 >= ne01.z() || i2 >= ne02 || i3 >= ne03) {
return;
}

const char *src0_ptr =
(const char *)src0 + i3 * nb03 + i2 * nb02 + i1 * nb01;
char *dst_ptr = (char *)dst + i3 * nb3 + i2 * nb2 + i1 * nb1;

const int64_t rel_i0 = i0 - p0; // relative i0 in src0
int64_t src_idx;

if (rel_i0 < 0) {
// Left padding - reflect
src_idx = -rel_i0;
} else if (rel_i0 < ne00) {
// Middle - copy
src_idx = rel_i0;
} else {
// Right padding - reflect
src_idx = 2 * ne00 - 2 - rel_i0;
}
const float value = *(const float *)(src0_ptr + src_idx * nb00);
*(float *)(dst_ptr + i0 * nb0) = value;

GGML_UNUSED(p1);
}

void ggml_sycl_op_pad_reflect_1d(ggml_backend_sycl_context& ctx, ggml_tensor* dst){
void ggml_sycl_op_pad_reflect_1d(ggml_backend_sycl_context &ctx,
ggml_tensor *dst) {

const ggml_tensor * src0 = dst->src[0];
queue_ptr stream = ctx.stream();
const ggml_tensor *src0 = dst->src[0];
dpct::queue_ptr stream = ctx.stream();

GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
GGML_ASSERT(dst->type == GGML_TYPE_F32);

const int32_t * opts = (const int32_t *) dst->op_params;
const int32_t *opts = (const int32_t *)dst->op_params;
const int p0 = opts[0];
const int p1 = opts[1];

const int64_t ne0 = src0->ne[0];

const int64_t ne00 = dst->ne[0];
const int64_t ne01 = dst->ne[1];
const int64_t ne02 = dst->ne[2];
const int64_t ne03 = dst->ne[3];

const int64_t nb00 = dst->nb[0];
const int64_t nb01 = dst->nb[1];
const int64_t nb02 = dst->nb[2];
const int64_t nb03 = dst->nb[3];
const int64_t nb0 = src0->nb[0];
const int64_t nb1 = src0->nb[1];
const int64_t nb2 = src0->nb[2];
const int64_t nb3 = src0->nb[3];

int num_blocks = (ne00 + SYCL_CONCAT_BLOCK_SIZE - 1) / SYCL_CONCAT_BLOCK_SIZE;
sycl::range<3> global(num_blocks * SYCL_CONCAT_BLOCK_SIZE, ne01, ne02*ne03);
sycl::range<3> local(SYCL_CONCAT_BLOCK_SIZE, 1, 1);

stream->parallel_for(
sycl::nd_range<3>(global,
local),
[=](sycl::nd_item<3> item_ct1) { pad_reflect_1d_f32(
(const float *) src0->data, (float *) dst->data,
ne0, ne02, p0, p1,
nb0, nb1, nb2, nb3,
nb00, nb01, nb02, nb03
, item_ct1);
});
const int64_t ne00 = src0->ne[0];
const int64_t ne01 = src0->ne[1];
const sycl::uint3 ne01_packed = init_fastdiv_values(ne01);
const int64_t ne02 = src0->ne[2];
const int64_t ne03 = src0->ne[3];

const int64_t ne0 = dst->ne[0];

GGML_ASSERT(ne0 == ne00 + p0 + p1);

constexpr int64_t bx = SYCL_PAD_REFLECT_1D_BLOCK_SIZE;
const int64_t tiles0 = (ne0 + bx - 1) / bx;
const dpct::dim3 grid_dims((unsigned)(ne01 * tiles0), (unsigned)ne02,
(unsigned)ne03);
const dpct::dim3 block_dims((unsigned)bx, 1, 1);

stream->submit([&](sycl::handler &cgh) {
auto src0_data_ct0 = src0->data;
auto dst_data_ct1 = dst->data;
auto src0_nb_ct7 = src0->nb[0];
auto src0_nb_ct8 = src0->nb[1];
auto src0_nb_ct9 = src0->nb[2];
auto src0_nb_ct10 = src0->nb[3];
auto dst_nb_ct11 = dst->nb[0];
auto dst_nb_ct12 = dst->nb[1];
auto dst_nb_ct13 = dst->nb[2];
auto dst_nb_ct14 = dst->nb[3];

cgh.parallel_for(sycl::nd_range<3>(grid_dims * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) {
pad_reflect_1d_kernel_f32(
src0_data_ct0, dst_data_ct1, ne0, ne00,
ne01_packed, ne02, ne03, src0_nb_ct7,
src0_nb_ct8, src0_nb_ct9, src0_nb_ct10,
dst_nb_ct11, dst_nb_ct12, dst_nb_ct13,
dst_nb_ct14, p0, p1, item_ct1);
});
});
}
2 changes: 2 additions & 0 deletions ggml/src/ggml-sycl/pad_reflect_1d.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@

#include "common.hpp"

#define SYCL_PAD_REFLECT_1D_BLOCK_SIZE 256

void ggml_sycl_op_pad_reflect_1d(ggml_backend_sycl_context& ctx, ggml_tensor* dst);

#endif // GGML_SYCL_PAD_REFLECT_1D_HPP
Loading