@@ -119,52 +119,69 @@ def warps_per_cta(layout):
119119@pytest .mark .parametrize ("dtype_str" , ["float32" , "float16" , "int8" ])
120120@pytest .mark .parametrize ("layout" , layouts )
121121@pytest .mark .parametrize ("block_ptr" , [True , False ])
122+ @pytest .mark .parametrize ("transpose" , [True , False ])
122123@pytest .mark .skipif (not is_xpu (), reason = "Block store tests are specific to the XPU backend" )
123- def test_block_store (M , N , dtype_str , layout , block_ptr , device , tmp_path : pathlib .Path ):
124+ def test_block_store (M , N , dtype_str , layout , block_ptr , transpose , device , tmp_path : pathlib .Path ):
124125
125126 warps = warps_per_cta (layout )
126127 num_warps = int (np .prod (warps ))
127128 threads_per_warp = layout .threads_per_warp
129+
128130 threads_per_warp = int (np .prod (threads_per_warp ))
129131
130132 ty = {"float32" : "f32" , "float16" : "f16" , "bfloat16" : "i16" , "int8" : "i8" }[dtype_str ]
131133
132134 support_block_io = torch .xpu .get_device_capability ()['has_subgroup_2d_block_io' ]
133135
136+ block_io = "\" column_major\" " if transpose else "\" row_major\" "
137+
134138 if block_ptr :
139+ load_ops = f"""
140+ %src_ptr = tt.make_tensor_ptr %src, [%M_i64, %N_i64], { "[%c1_i64, %M_i64]" if transpose else "[%N_i64, %c1_i64]" } , [%c0_i32, %c0_i32] {{order = array<i32: 1, 0>}} : <tensor<{ M } x{ N } x{ ty } , #layout>>
141+ %store_val = tt.load %src_ptr {{ttig.block_io = { block_io } , boundaryCheck = array<i32: 0, 1>, padding = 1 : i32}} : !tt.ptr<tensor<{ M } x{ N } x{ ty } , #layout>>
142+ """
135143 store_ops = f"""
136- %M_i64 = arith.constant { M } : i64
137- %N_i64 = arith.constant { N } : i64
138- %c1_i64 = arith.constant 1 : i64
139- %c0_i32 = arith.constant 0 : i32
140-
141- %blk_ptr = tt.make_tensor_ptr %dst, [%M_i64, %N_i64], [%N_i64, %c1_i64], [%c0_i32, %c0_i32] {{order = array<i32: 1, 0>}} : <tensor<{ M } x{ N } x{ ty } , #layout>>
142- tt.store %blk_ptr, %store_val {{ttig.block_io = "row_major", boundaryCheck = array<i32: 0, 1>}} : !tt.ptr<tensor<{ M } x{ N } x{ ty } , #layout>>
144+ %dst_ptr = tt.make_tensor_ptr %dst, [%M_i64, %N_i64], [%N_i64, %c1_i64], [%c0_i32, %c0_i32] {{order = array<i32: 1, 0>}} : <tensor<{ M } x{ N } x{ ty } , #layout>>
145+ tt.store %dst_ptr, %store_val {{ttig.block_io = "row_major", boundaryCheck = array<i32: 0, 1>}} : !tt.ptr<tensor<{ M } x{ N } x{ ty } , #layout>>
143146 """
144147 else :
148+ load_ops = f"""
149+ %src_base = tt.splat %src : !tt.ptr<{ ty } > -> tensor<{ M } x{ N } x!tt.ptr<{ ty } >, #layout>
150+ %src_ptr = tt.addptr %src_base, { "%col_major_off" if transpose else "%row_major_off" } : tensor<{ M } x{ N } x!tt.ptr<{ ty } >, #layout>, tensor<{ M } x{ N } xi32, #layout>
151+ %store_val = tt.load %src_ptr {{ttig.block_io = { block_io } }} : tensor<{ M } x{ N } x!tt.ptr<{ ty } >, #layout>
152+ """
145153 store_ops = f"""
146- %12 = tt.splat %dst : !tt.ptr<{ ty } > -> tensor<{ M } x{ N } x!tt.ptr<{ ty } >, #layout>
147- %13 = tt.addptr %12 , %8 : tensor<{ M } x{ N } x!tt.ptr<{ ty } >, #layout>, tensor<{ M } x{ N } xi32, #layout>
148- tt.store %13 , %store_val {{ttig.block_io = "row_major"}} : tensor<{ M } x{ N } x!tt.ptr<{ ty } >, #layout>
154+ %dst_base = tt.splat %dst : !tt.ptr<{ ty } > -> tensor<{ M } x{ N } x!tt.ptr<{ ty } >, #layout>
155+ %dst_ptr = tt.addptr %dst_base , %row_major_off : tensor<{ M } x{ N } x!tt.ptr<{ ty } >, #layout>, tensor<{ M } x{ N } xi32, #layout>
156+ tt.store %dst_ptr , %store_val {{ttig.block_io = "row_major"}} : tensor<{ M } x{ N } x!tt.ptr<{ ty } >, #layout>
149157 """
150158
151159 ir = f"""
152160 #layout = { layout }
153161 module attributes {{{ "ttig.support_sg_2d_block," if support_block_io else "" } "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = { num_warps } : i32, ttg.target = "xpu", "ttg.threads-per-warp" = { threads_per_warp } : i32}} {{
154162 tt.func public @block_store(%src: !tt.ptr<{ ty } > {{tt.divisibility = 16 : i32}}, %dst: !tt.ptr<{ ty } > {{tt.divisibility = 16 : i32}}) {{
155163
156- %stride = arith.constant dense<{ N } > : tensor<{ M } x1xi32, #layout>
164+ %M_i64 = arith.constant { M } : i64
165+ %N_i64 = arith.constant { N } : i64
166+ %c1_i64 = arith.constant 1 : i64
167+ %c0_i32 = arith.constant 0 : i32
168+ %stride_N = arith.constant dense<{ N } > : tensor<{ M } x1xi32, #layout>
157169 %1 = tt.make_range {{end = { M } : i32, start = 0 : i32}} : tensor<{ M } xi32, #ttg.slice<{{dim = 1, parent = #layout}}>>
158170 %2 = tt.expand_dims %1 {{axis = 1 : i32}} : tensor<{ M } xi32, #ttg.slice<{{dim = 1, parent = #layout}}>> -> tensor<{ M } x1xi32, #layout>
159- %3 = arith.muli %2, %stride : tensor<{ M } x1xi32, #layout>
171+ %row_stride = arith.muli %2, %stride_N : tensor<{ M } x1xi32, #layout>
160172 %4 = tt.make_range {{end = { N } : i32, start = 0 : i32}} : tensor<{ N } xi32, #ttg.slice<{{dim = 0, parent = #layout}}>>
161173 %5 = tt.expand_dims %4 {{axis = 0 : i32}} : tensor<{ N } xi32, #ttg.slice<{{dim = 0, parent = #layout}}>> -> tensor<1x{ N } xi32, #layout>
162- %6 = tt.broadcast %3 : tensor<{ M } x1xi32, #layout> -> tensor<{ M } x{ N } xi32, #layout>
174+ %6 = tt.broadcast %row_stride : tensor<{ M } x1xi32, #layout> -> tensor<{ M } x{ N } xi32, #layout>
163175 %7 = tt.broadcast %5 : tensor<1x{ N } xi32, #layout> -> tensor<{ M } x{ N } xi32, #layout>
164- %8 = arith.addi %6, %7 : tensor<{ M } x{ N } xi32, #layout>
165- %9 = tt.splat %src : !tt.ptr<{ ty } > -> tensor<{ M } x{ N } x!tt.ptr<{ ty } >, #layout>
166- %10 = tt.addptr %9, %8 : tensor<{ M } x{ N } x!tt.ptr<{ ty } >, #layout>, tensor<{ M } x{ N } xi32, #layout>
167- %store_val = tt.load %10 : tensor<{ M } x{ N } x!tt.ptr<{ ty } >, #layout>
176+ %row_major_off = arith.addi %6, %7 : tensor<{ M } x{ N } xi32, #layout>
177+
178+
179+ %stride_M = arith.constant dense<{ M } > : tensor<1x{ N } xi32, #layout>
180+ %col_stride = arith.muli %5, %stride_M : tensor<1x{ N } xi32, #layout>
181+ %8 = tt.broadcast %2 : tensor<{ M } x1xi32, #layout> -> tensor<{ M } x{ N } xi32, #layout>
182+ %9 = tt.broadcast %col_stride : tensor<1x{ N } xi32, #layout> -> tensor<{ M } x{ N } xi32, #layout>
183+ %col_major_off = arith.addi %8, %9 : tensor<{ M } x{ N } xi32, #layout>
184+ { load_ops }
168185
169186 { store_ops }
170187
@@ -185,8 +202,15 @@ def test_block_store(M, N, dtype_str, layout, block_ptr, device, tmp_path: pathl
185202 temp_file .write_text (ir )
186203 kernel = triton .compile (str (temp_file ))
187204
205+ a = a .permute (1 , 0 ).contiguous ().permute (1 , 0 ) if transpose else a
206+
207+ print ("a:" , a .shape , a .stride ())
208+ print ("x:" , x .shape , x .stride ())
209+
188210 kernel [(1 , 1 , 1 )](a , x )
189211 assert torch .equal (a , x )
190212
191213 if support_block_io :
192214 assert 'spirv_Subgroup2DBlockStoreINTEL' in kernel .asm ['llir' ] or 'GenISA.LSC2DBlockWrite' in kernel .asm ['llir' ]
215+ if not block_ptr :
216+ assert 'spirv_Subgroup2DBlockLoad' in kernel .asm ['llir' ] or 'GenISA.LSC2DBlockRead' in kernel .asm ['llir' ]