Skip to content

Commit b034ed6

Browse files
committed
vulkan: adapt integer dot mmv to mmv small m optimization (#15355)
1 parent cf2ef80 commit b034ed6

File tree

3 files changed

+3 -3 lines changed

3 files changed

+3 -3 lines changed

ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.comp

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -101,7 +101,7 @@ layout (constant_id = 0) const uint BLOCK_SIZE = 32;
101 101
layout (constant_id = 1) const uint NUM_ROWS = 1;
102 102
layout (constant_id = 2) const uint NUM_COLS = 1;
103 103

104-
#ifdef USE_SUBGROUPS
104+
#ifdef USE_SUBGROUP_ADD_NO_SHMEM
105 105
void reduce_result(inout FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const in uint32_t d_offset, const in uint32_t first_row, const in uint32_t num_rows, const in uint32_t tid) {
106 106
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
107 107
[[unroll]] for (uint n = 0; n < num_rows; ++n) {

ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -3,7 +3,7 @@
3 3
#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
4 4
#extension GL_EXT_integer_dot_product : require
5 5

6-
#ifdef USE_SUBGROUPS
6+
#ifdef USE_SUBGROUP_ADD_NO_SHMEM
7 7
#extension GL_KHR_shader_subgroup_basic : require
8 8
#extension GL_KHR_shader_subgroup_arithmetic : require
9 9

ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -498,7 +498,7 @@ void process_shaders() {
498 498
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
499 499
if (is_legacy_quant(tname)) {
500 500
string_to_spv("mul_mat_vec_" + tname + "_q8_1_f32", "mul_mat_vecq.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}, {"ACC_TYPE", "float"}}));
501-
string_to_spv("mul_mat_vec_" + tname + "_q8_1_f32_subgroup", "mul_mat_vecq.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}, {"ACC_TYPE", "float"}, {"USE_SUBGROUPS", "1"}}));
501+
string_to_spv("mul_mat_vec_" + tname + "_q8_1_f32_subgroup", "mul_mat_vecq.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}, {"ACC_TYPE", "float"}, {"USE_SUBGROUP_ADD_NO_SHMEM", "1"}}));
502 502
}
503 503
#endif
504 504

0 commit comments

Comments
 (0)