@@ -252,21 +252,25 @@ static constexpr __device__ int mmq_get_granularity_device(const int /*mmq_x*/)
252
252
#endif // AMD_MFMA_AVAILABLE
253
253
254
254
#if defined(GGML_USE_HIP)
255
- static int mmq_get_nwarps_host (const int cc, const int warp_size ) {
256
- return amd_mfma_available (cc) ? 8 : 256 /warp_size ;
255
+ static int mmq_get_nwarps_host (const int cc) {
256
+ return amd_mfma_available (cc) ? 8 : 4 ;
257
257
}
258
258
#else
259
- static int mmq_get_nwarps_host (const int /* cc*/ , const int warp_size ) {
260
- return 256 /warp_size ;
259
+ static int mmq_get_nwarps_host (const int /* cc*/ ) {
260
+ return 8 ;
261
261
}
262
262
#endif // (GGML_USE_HIP)
263
263
264
264
static constexpr __device__ int mmq_get_nwarps_device () {
265
+ #if defined(GGML_USE_HIP)
265
266
#if defined(AMD_MFMA_AVAILABLE)
266
267
return 8 ;
267
268
#else
268
- return 256 / ggml_cuda_get_physical_warp_size () ;
269
+ return 4 ;
269
270
#endif // AMD_MFMA_AVAILABLE
271
+ #else
272
+ return 8 ;
273
+ #endif // defined(GGML_USE_HIP)
270
274
}
271
275
272
276
// ------------------------------------------------------------
@@ -3469,7 +3473,7 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a
3469
3473
const int cc = ggml_cuda_info ().devices [id].cc ;
3470
3474
const int nsm = ggml_cuda_info ().devices [id].nsm ;
3471
3475
const int warp_size = ggml_cuda_info ().devices [id].warp_size ;
3472
- const int nwarps = mmq_get_nwarps_host (cc, warp_size );
3476
+ const int nwarps = mmq_get_nwarps_host (cc);
3473
3477
const int mmq_y = get_mmq_y_host (cc);
3474
3478
3475
3479
const dim3 block_dims (warp_size, nwarps, 1 );
@@ -3556,7 +3560,7 @@ void mul_mat_q_case(ggml_backend_cuda_context & ctx, const mmq_args & args, cuda
3556
3560
const int cc = ggml_cuda_info ().devices [id].cc ;
3557
3561
const size_t smpbo = ggml_cuda_info ().devices [id].smpbo ;
3558
3562
const int warp_size = ggml_cuda_info ().devices [id].warp_size ;
3559
- const int nwarps = mmq_get_nwarps_host (cc, warp_size );
3563
+ const int nwarps = mmq_get_nwarps_host (cc);
3560
3564
3561
3565
const int mmq_x_max = get_mmq_x_max_host (cc);
3562
3566
const int mmq_y = get_mmq_y_host (cc);
0 commit comments