@@ -133,27 +133,35 @@ def test_block_store(M, N, dtype_str, layout, block_ptr, device, tmp_path: pathl
133133 support_block_io = torch .xpu .get_device_capability ()['has_subgroup_2d_block_io' ]
134134
135135 if block_ptr :
136+ load_ops = f"""
137+ %src_ptr = tt.make_tensor_ptr %src, [%M_i64, %N_i64], [%N_i64, %c1_i64], [%c0_i32, %c0_i32] {{order = array<i32: 1, 0>}} : <tensor<{ M } x{ N } x{ ty } , #layout>>
138+ %store_val = tt.load %src_ptr {{boundaryCheck = array<i32: 0, 1>, padding = 1 : i32}} : !tt.ptr<tensor<{ M } x{ N } x{ ty } , #layout>>
139+ """
136140 store_ops = f"""
137- %M_i64 = arith.constant { M } : i64
138- %N_i64 = arith.constant { N } : i64
139- %c1_i64 = arith.constant 1 : i64
140- %c0_i32 = arith.constant 0 : i32
141-
142- %blk_ptr = tt.make_tensor_ptr %dst, [%M_i64, %N_i64], [%N_i64, %c1_i64], [%c0_i32, %c0_i32] {{order = array<i32: 1, 0>}} : <tensor<{ M } x{ N } x{ ty } , #layout>>
143- tt.store %blk_ptr, %store_val {{ttig.block_io = "row_major", boundaryCheck = array<i32: 0, 1>}} : !tt.ptr<tensor<{ M } x{ N } x{ ty } , #layout>>
141+ %dst_ptr = tt.make_tensor_ptr %dst, [%M_i64, %N_i64], [%N_i64, %c1_i64], [%c0_i32, %c0_i32] {{order = array<i32: 1, 0>}} : <tensor<{ M } x{ N } x{ ty } , #layout>>
142+ tt.store %dst_ptr, %store_val {{ttig.block_io = "row_major", boundaryCheck = array<i32: 0, 1>}} : !tt.ptr<tensor<{ M } x{ N } x{ ty } , #layout>>
144143 """
145144 else :
145+ load_ops = f"""
146+ %src_base = tt.splat %src : !tt.ptr<{ ty } > -> tensor<{ M } x{ N } x!tt.ptr<{ ty } >, #layout>
147+ %src_ptr = tt.addptr %src_base, %8 : tensor<{ M } x{ N } x!tt.ptr<{ ty } >, #layout>, tensor<{ M } x{ N } xi32, #layout>
148+ %store_val = tt.load %src_ptr {{ttig.block_io = "row_major"}} : tensor<{ M } x{ N } x!tt.ptr<{ ty } >, #layout>
149+ """
146150 store_ops = f"""
147- %12 = tt.splat %dst : !tt.ptr<{ ty } > -> tensor<{ M } x{ N } x!tt.ptr<{ ty } >, #layout>
148- %13 = tt.addptr %12 , %8 : tensor<{ M } x{ N } x!tt.ptr<{ ty } >, #layout>, tensor<{ M } x{ N } xi32, #layout>
149- tt.store %13 , %store_val {{ttig.block_io = "row_major"}} : tensor<{ M } x{ N } x!tt.ptr<{ ty } >, #layout>
151+ %dst_base = tt.splat %dst : !tt.ptr<{ ty } > -> tensor<{ M } x{ N } x!tt.ptr<{ ty } >, #layout>
152+ %dst_ptr = tt.addptr %dst_base , %8 : tensor<{ M } x{ N } x!tt.ptr<{ ty } >, #layout>, tensor<{ M } x{ N } xi32, #layout>
153+ tt.store %dst_ptr , %store_val {{ttig.block_io = "row_major"}} : tensor<{ M } x{ N } x!tt.ptr<{ ty } >, #layout>
150154 """
151155
152156 ir = f"""
153157 #layout = { layout }
154158 module attributes {{{ "ttig.support_sg_2d_block," if support_block_io else "" } "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = { num_warps } : i32, ttg.target = "xpu", "ttg.threads-per-warp" = { threads_per_warp } : i32}} {{
155159 tt.func public @block_store(%src: !tt.ptr<{ ty } > {{tt.divisibility = 16 : i32}}, %dst: !tt.ptr<{ ty } > {{tt.divisibility = 16 : i32}}) {{
156160
161+ %M_i64 = arith.constant { M } : i64
162+ %N_i64 = arith.constant { N } : i64
163+ %c1_i64 = arith.constant 1 : i64
164+ %c0_i32 = arith.constant 0 : i32
157165 %stride = arith.constant dense<{ N } > : tensor<{ M } x1xi32, #layout>
158166 %1 = tt.make_range {{end = { M } : i32, start = 0 : i32}} : tensor<{ M } xi32, #ttg.slice<{{dim = 1, parent = #layout}}>>
159167 %2 = tt.expand_dims %1 {{axis = 1 : i32}} : tensor<{ M } xi32, #ttg.slice<{{dim = 1, parent = #layout}}>> -> tensor<{ M } x1xi32, #layout>
@@ -163,9 +171,7 @@ def test_block_store(M, N, dtype_str, layout, block_ptr, device, tmp_path: pathl
163171 %6 = tt.broadcast %3 : tensor<{ M } x1xi32, #layout> -> tensor<{ M } x{ N } xi32, #layout>
164172 %7 = tt.broadcast %5 : tensor<1x{ N } xi32, #layout> -> tensor<{ M } x{ N } xi32, #layout>
165173 %8 = arith.addi %6, %7 : tensor<{ M } x{ N } xi32, #layout>
166- %9 = tt.splat %src : !tt.ptr<{ ty } > -> tensor<{ M } x{ N } x!tt.ptr<{ ty } >, #layout>
167- %10 = tt.addptr %9, %8 : tensor<{ M } x{ N } x!tt.ptr<{ ty } >, #layout>, tensor<{ M } x{ N } xi32, #layout>
168- %store_val = tt.load %10 : tensor<{ M } x{ N } x!tt.ptr<{ ty } >, #layout>
174+ { load_ops }
169175
170176 { store_ops }
171177
@@ -191,3 +197,5 @@ def test_block_store(M, N, dtype_str, layout, block_ptr, device, tmp_path: pathl
191197
192198 if support_block_io :
193199 assert 'spirv_Subgroup2DBlockStoreINTEL' in kernel .asm ['llir' ] or 'GenISA.LSC2DBlockWrite' in kernel .asm ['llir' ]
200+ if not block_ptr :
201+ assert 'spirv_Subgroup2DBlockLoad' in kernel .asm ['llir' ] or 'GenISA.LSC2DBlockRead' in kernel .asm ['llir' ]
0 commit comments