Skip to content

Commit a475712

Browse files
committed
Revert "CUDA: fix MMQ nwarps for AMD with warp_size==32 (ggml-org#15014)"
This reverts commit 9c35706.
1 parent f430916 commit a475712

File tree

1 file changed

+11
-7
lines changed

1 file changed

+11
-7
lines changed

ggml/src/ggml-cuda/mmq.cuh

Lines changed: 11 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -252,21 +252,25 @@ static constexpr __device__ int mmq_get_granularity_device(const int /*mmq_x*/)
252252
#endif // AMD_MFMA_AVAILABLE
253253

254254
#if defined(GGML_USE_HIP)
255-
static int mmq_get_nwarps_host(const int cc, const int warp_size) {
256-
return amd_mfma_available(cc) ? 8 : 256/warp_size;
255+
static int mmq_get_nwarps_host(const int cc) {
256+
return amd_mfma_available(cc) ? 8 : 4;
257257
}
258258
#else
259-
static int mmq_get_nwarps_host(const int /*cc*/, const int warp_size) {
260-
return 256/warp_size;
259+
static int mmq_get_nwarps_host(const int /*cc*/) {
260+
return 8;
261261
}
262262
#endif // (GGML_USE_HIP)
263263

264264
static constexpr __device__ int mmq_get_nwarps_device() {
265+
#if defined(GGML_USE_HIP)
265266
#if defined(AMD_MFMA_AVAILABLE)
266267
return 8;
267268
#else
268-
return 256/ggml_cuda_get_physical_warp_size();
269+
return 4;
269270
#endif // AMD_MFMA_AVAILABLE
271+
#else
272+
return 8;
273+
#endif // defined(GGML_USE_HIP)
270274
}
271275

272276
// ------------------------------------------------------------
@@ -3469,7 +3473,7 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a
34693473
const int cc = ggml_cuda_info().devices[id].cc;
34703474
const int nsm = ggml_cuda_info().devices[id].nsm;
34713475
const int warp_size = ggml_cuda_info().devices[id].warp_size;
3472-
const int nwarps = mmq_get_nwarps_host(cc, warp_size);
3476+
const int nwarps = mmq_get_nwarps_host(cc);
34733477
const int mmq_y = get_mmq_y_host(cc);
34743478

34753479
const dim3 block_dims(warp_size, nwarps, 1);
@@ -3556,7 +3560,7 @@ void mul_mat_q_case(ggml_backend_cuda_context & ctx, const mmq_args & args, cuda
35563560
const int cc = ggml_cuda_info().devices[id].cc;
35573561
const size_t smpbo = ggml_cuda_info().devices[id].smpbo;
35583562
const int warp_size = ggml_cuda_info().devices[id].warp_size;
3559-
const int nwarps = mmq_get_nwarps_host(cc, warp_size);
3563+
const int nwarps = mmq_get_nwarps_host(cc);
35603564

35613565
const int mmq_x_max = get_mmq_x_max_host(cc);
35623566
const int mmq_y = get_mmq_y_host(cc);

0 commit comments

Comments (0)