1515from .. import exc
1616from .._compat import get_tensor_descriptor_fn_name
1717from .ast_extension import expr_from_string
18+ from .ast_extension import statement_from_string
1819from .compile_environment import CompileEnvironment
1920from .device_function import DeviceFunction
2021from .host_function import HostFunction
@@ -353,7 +354,6 @@ def codegen_load(
353354 )
354355 assert extra_mask is None
355356 indexing = BlockedSubscriptIndexing .create (state , fake_tensor , subscript )
356-
357357 # Load from tensor descriptor with permuted offsets
358358 load_expr = expr_from_string (
359359 f"{ indexing .tensor_descriptor (state )} .load({ indexing .offsets_str_permuted (state )} )"
@@ -383,10 +383,24 @@ def codegen_store(
383383 )
384384 assert extra_mask is None
385385 indexing = BlockedSubscriptIndexing .create (state , fake_tensor , subscript )
386+ store_value = indexing .reshape_store (state , value )
387+
388+ config = DeviceFunction .current ().config
389+ epilogue_subtiles = state .config .epilogue_subtiling
390+ if torch .cuda .get_device_capability () >= (9 , 0 ) and (
391+ idx := state .device_function .device_store_index
392+ ) < len (epilogue_subtiles ):
393+ subtile_split = epilogue_subtiles [idx ]
394+ state .device_function .device_store_index += 1
395+
396+ subtile_codegen = self ._codegen_epilogue_subtile_store (
397+ state , fake_tensor , indexing , store_value , subtile_split , config
398+ )
399+ if subtile_codegen is not None :
400+ return subtile_codegen
386401
387402 # Apply permutation to the value being stored if needed
388403 desc_arg = indexing .tensor_descriptor_arg (state )
389- store_value = indexing .reshape_store (state , value )
390404
391405 if desc_arg .permutation is not None :
392406 # Apply permutation to the value
@@ -400,6 +414,95 @@ def codegen_store(
400414 value = store_value ,
401415 )
402416
417+ def _codegen_epilogue_subtile_store (
418+ self ,
419+ state : CodegenState ,
420+ fake_tensor : torch .Tensor ,
421+ indexing : BlockedSubscriptIndexing ,
422+ store_value : ast .AST ,
423+ subtile_split : int ,
424+ config : Config ,
425+ ) -> ast .AST | None :
426+ # Currently support 2D tiles without permutations
427+ if (
428+ len (indexing .block_shape ) != 2
429+ or len (indexing .offsets ) != 2
430+ or subtile_split == 0
431+ ):
432+ return None
433+
434+ env = CompileEnvironment .current ()
435+ block_m , block_n = indexing .block_shape
436+ try :
437+ block_n_hint = env .size_hint (block_n )
438+ block_idx = env .get_block_id (block_n )
439+ block_size = env .block_sizes [block_idx ].from_config (config )
440+ except Exception :
441+ return None
442+
443+ if block_n_hint % 2 != 0 or block_size <= 16 :
444+ return None
445+
446+ device_fn = state .device_function
447+ codegen = state .codegen
448+
449+ block_m_str = device_fn .literal_expr (block_m )
450+ block_n_str = device_fn .literal_expr (block_n )
451+ indexing .block_shape [1 ] //= subtile_split
452+
453+ # TODO(PaulZhang12): Support more epilogue subtile configs besides 2
454+ block_n_half_str = f"({ block_n_str } // { subtile_split } )"
455+
456+ # Lift the store value into a temporary variable for reuse
457+ acc_var = codegen .lift (store_value , prefix = "acc" )
458+
459+ reshape_expr = expr_from_string (
460+ "tl.reshape({acc}, [{dim_m}, 2, {dim_half}]).permute(0, 2, 1)" ,
461+ acc = acc_var ,
462+ dim_m = expr_from_string (block_m_str ),
463+ dim_half = expr_from_string (block_n_half_str ),
464+ )
465+ reshape_var = codegen .lift (reshape_expr , prefix = "acc" )
466+
467+ acc0_name = codegen .tmpvar (prefix = "acc" )
468+ acc1_name = codegen .tmpvar (prefix = "acc" )
469+ codegen .add_statement (
470+ statement_from_string (
471+ f"{ acc0_name } , { acc1_name } = tl.split({{acc}})" ,
472+ acc = reshape_var ,
473+ )
474+ )
475+ acc0 = expr_from_string (acc0_name )
476+ acc1 = expr_from_string (acc1_name )
477+
478+ desc_name = indexing .tensor_descriptor (state )
479+ offset0 = expr_from_string (indexing .offsets [0 ])
480+ offset1 = expr_from_string (indexing .offsets [1 ])
481+
482+ # First subtile store
483+ codegen .add_statement (
484+ statement_from_string (
485+ f"{ desc_name } .store([{{off0}}, {{off1}}], {{value}})" ,
486+ off0 = offset0 ,
487+ off1 = offset1 ,
488+ value = acc0 ,
489+ )
490+ )
491+
492+ offset1_shifted = expr_from_string (
493+ "({offset} + {half})" ,
494+ offset = expr_from_string (indexing .offsets [1 ]),
495+ half = expr_from_string (block_n_half_str ),
496+ )
497+
498+ # Emit second subtile store as the expression returned to the caller
499+ return expr_from_string (
500+ f"{ desc_name } .store([{{off0}}, {{off1}}], {{value}})" ,
501+ off0 = offset0 ,
502+ off1 = offset1_shifted ,
503+ value = acc1 ,
504+ )
505+
403506
404507class StackIndexingStrategy :
405508 """
0 commit comments