Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions src/ggml-metal/ggml-metal-device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1660,3 +1660,25 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_opt_step_sgd(ggml_metal_li

return res;
}

// Returns the compute pipeline for GGML_OP_DIAG_MASK_INF.
//
// The kernel name is "kernel_diag_mask_inf_<type>", where <type> is the
// ggml_type_name() of src[0]. The pipeline is looked up in the library first
// and compiled only when the lookup yields nothing.
//
// Preconditions (asserted): op is a DIAG_MASK_INF node and src[0] is contiguous.
// NOTE(review): only a kernel_diag_mask_inf_f32 kernel is visible in this
// change - confirm that non-F32 inputs are rejected before reaching this point.
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_diag_mask_inf(
        ggml_metal_library_t lib,
        const struct ggml_tensor * op) {
    GGML_ASSERT(op->op == GGML_OP_DIAG_MASK_INF);
    GGML_ASSERT(ggml_is_contiguous(op->src[0]));

    char base[256];
    char name[256];

    // this op has no specialization suffix, so name is identical to base
    snprintf(base, sizeof(base), "kernel_diag_mask_inf_%s", ggml_type_name(op->src[0]->type));
    snprintf(name, sizeof(name), "%s", base);

    ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
    if (!res) {
        // not available from the lookup - compile it now
        res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
    }

    return res;
}

1 change: 1 addition & 0 deletions src/ggml-metal/ggml-metal-device.h
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_arange (ggml_me
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_timestep_embedding(ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_opt_step_adamw (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_opt_step_sgd (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_diag_mask_inf (ggml_metal_library_t lib, const struct ggml_tensor * op);

ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_pad(
ggml_metal_library_t lib,
Expand Down
2 changes: 2 additions & 0 deletions src/ggml-metal/ggml-metal-device.m
Original file line number Diff line number Diff line change
Expand Up @@ -1020,6 +1020,8 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
case GGML_OP_OPT_STEP_ADAMW:
case GGML_OP_OPT_STEP_SGD:
return has_simdgroup_reduction;
case GGML_OP_DIAG_MASK_INF:
return true;
default:
return false;
}
Expand Down
10 changes: 10 additions & 0 deletions src/ggml-metal/ggml-metal-impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -887,4 +887,14 @@ typedef struct {
int64_t np;
} ggml_metal_kargs_opt_step_sgd;

// Kernel arguments for kernel_diag_mask_inf_*; filled on the host by
// ggml_metal_op_diag_mask_inf and bound to buffer slot 0.
// The field order and types are part of the host <-> kernel contract.
typedef struct {
int32_t ne00; // nc: src0->ne[0], number of columns the kernel walks per row
int32_t ne01; // nr: src0->ne[1], rows per channel (kernel derives j = row % ne01)
int32_t nrows; // ggml_nrows(src0), total row count; threads with row >= nrows exit
int32_t n_past; // op param 0; columns with index > n_past + j are masked
uint64_t nb0; // src0->nb[0], byte stride between consecutive columns
uint64_t nb1; // src0->nb[1], byte stride between consecutive rows
uint64_t nb2; // src0->nb[2], byte stride between consecutive channels
} ggml_metal_kargs_diag_mask_inf;

#endif // GGML_METAL_IMPL
47 changes: 47 additions & 0 deletions src/ggml-metal/ggml-metal-ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -436,6 +436,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
{
n_fuse = ggml_metal_op_opt_step_sgd(ctx, idx);
} break;
case GGML_OP_DIAG_MASK_INF:
{
n_fuse = ggml_metal_op_diag_mask_inf(ctx, idx);
} break;
default:
{
GGML_LOG_ERROR("%s: error: node %3d, op = %8s not implemented\n", __func__, idx, ggml_op_name(node->op));
Expand Down Expand Up @@ -3842,3 +3846,46 @@ int ggml_metal_op_opt_step_sgd(ggml_metal_op_t ctx, int idx) {

return 1;
}

// Encodes GGML_OP_DIAG_MASK_INF for the graph node at position idx.
//
// Packs the shape and byte strides of src0 together with n_past (op param 0)
// into ggml_metal_kargs_diag_mask_inf, binds args / src0 / dst to buffer slots
// 0 / 1 / 2 and dispatches one threadgroup per row of src0.
//
// Returns 1; the caller stores this value in n_fuse.
int ggml_metal_op_diag_mask_inf(ggml_metal_op_t ctx, int idx) {
    ggml_tensor * dst  = ctx->node(idx);
    ggml_tensor * src0 = dst->src[0];

    // the kernel arguments are 32-bit, so the 64-bit tensor dims are narrowed here
    const int32_t n_rows = ggml_nrows(src0);

    ggml_metal_kargs_diag_mask_inf args = {
        .ne00   = (int32_t) src0->ne[0],
        .ne01   = (int32_t) src0->ne[1],
        .nrows  = n_rows,
        .n_past = ggml_get_op_params_i32(dst, 0),
        .nb0    = (uint64_t) src0->nb[0],
        .nb1    = (uint64_t) src0->nb[1],
        .nb2    = (uint64_t) src0->nb[2],
    };

    ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_diag_mask_inf(ctx->lib, dst);

    ggml_metal_encoder_t enc = ctx->enc;

    ggml_metal_encoder_set_pipeline(enc, pipeline);
    ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);
    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(src0), 1);
    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(dst),  2);

    // one single-thread threadgroup per row
    // NOTE(review): the kernel already guards row >= nrows, so a wider threadgroup
    // with a rounded-up grid looks possible - confirm and measure before changing
    ggml_metal_encoder_dispatch_threadgroups(enc, n_rows, 1, 1, 1, 1, 1);

    return 1;
}
1 change: 1 addition & 0 deletions src/ggml-metal/ggml-metal-ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ int ggml_metal_op_argsort (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_leaky_relu (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_opt_step_adamw (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_opt_step_sgd (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_diag_mask_inf (ggml_metal_op_t ctx, int idx);

#ifdef __cplusplus
}
Expand Down
37 changes: 37 additions & 0 deletions src/ggml-metal/ggml-metal.metal
Original file line number Diff line number Diff line change
Expand Up @@ -9672,3 +9672,40 @@ kernel void kernel_opt_step_sgd_f32(

x[gid] = x[gid] * (1.0f - pars[0] * pars[1]) - pars[0] * g[gid];
}

// Diagonal mask for F32 data: copies src to dst and replaces every element
// whose column index is greater than n_past + j with -INFINITY, where j is the
// index of the row inside its channel (row % ne01).
//
// Launch layout: 1D grid, one thread per row. Threads whose grid position is
// >= args.nrows return immediately, so the grid may be rounded up.
//
// Addressing: element offsets are built from the byte strides in args
// (nb0/nb1/nb2) and the same offset is applied to both src and dst. The channel
// index row / ne01 is scaled by nb2 only.
// NOTE(review): this presumably relies on the contiguity asserted on the host
// for src[0] and on dst sharing that layout - confirm.
kernel void kernel_diag_mask_inf_f32(
        constant ggml_metal_kargs_diag_mask_inf & args [[ buffer(0) ]],
        device const float * src [[ buffer(1) ]],
        device       float * dst [[ buffer(2) ]],
        uint row [[ thread_position_in_grid ]]) {

    const int nc     = args.ne00;   // columns per row
    const int nr     = args.ne01;   // rows per channel
    const int nrows  = args.nrows;  // total number of rows
    const int n_past = args.n_past;

    if (row >= nrows) {
        return;
    }

    const int j = row % nr; // row index inside the channel
    const int k = row / nr; // channel index

    const uint64_t nb0 = args.nb0;
    const uint64_t nb1 = args.nb1;
    const uint64_t nb2 = args.nb2;

    const size_t base = k*nb2 + j*nb1;

    // keep the source pointer const-qualified while doing byte arithmetic
    device const char * src_row = (device const char *) src + base;
    device       char * dst_row = (device       char *) dst + base;

    for (int col = 0; col < nc; ++col) {
        const size_t off = col*nb0;

        // j >= 0, so col > n_past + j already implies col >= n_past.
        // Masked positions get a real -INFINITY rather than a large finite
        // value, as the op name (DIAG_MASK_INF) indicates; they are not read
        // from src at all.
        const float v = (col > n_past + j)
            ? -INFINITY
            : *((device const float *)(src_row + off));

        *((device float *)(dst_row + off)) = v;
    }
}