Skip to content

Commit c464166

Browse files
committed
[LoadStoreOpToLLVM] Refactor block load lowering of tt.load with tensor pointer.
Signed-off-by: Lu,Chengjun <chengjun.lu@intel.com>
1 parent e1e912b commit c464166

File tree

3 files changed: +663 −631 lines changed

python/test/unit/intel/test_block_store.py

Lines changed: 42 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -119,52 +119,69 @@ def warps_per_cta(layout):
119119
@pytest.mark.parametrize("dtype_str", ["float32", "float16", "int8"])
120120
@pytest.mark.parametrize("layout", layouts)
121121
@pytest.mark.parametrize("block_ptr", [True, False])
122+
@pytest.mark.parametrize("transpose", [True, False])
122123
@pytest.mark.skipif(not is_xpu(), reason="Block store tests are specific to the XPU backend")
123-
def test_block_store(M, N, dtype_str, layout, block_ptr, device, tmp_path: pathlib.Path):
124+
def test_block_store(M, N, dtype_str, layout, block_ptr, transpose, device, tmp_path: pathlib.Path):
124125

125126
warps = warps_per_cta(layout)
126127
num_warps = int(np.prod(warps))
127128
threads_per_warp = layout.threads_per_warp
129+
128130
threads_per_warp = int(np.prod(threads_per_warp))
129131

130132
ty = {"float32": "f32", "float16": "f16", "bfloat16": "i16", "int8": "i8"}[dtype_str]
131133

132134
support_block_io = torch.xpu.get_device_capability()['has_subgroup_2d_block_io']
133135

136+
block_io = "\"column_major\"" if transpose else "\"row_major\""
137+
134138
if block_ptr:
139+
load_ops = f"""
140+
%src_ptr = tt.make_tensor_ptr %src, [%M_i64, %N_i64], {"[%c1_i64, %M_i64]" if transpose else "[%N_i64, %c1_i64]"}, [%c0_i32, %c0_i32] {{order = array<i32: 1, 0>}} : <tensor<{M}x{N}x{ty}, #layout>>
141+
%store_val = tt.load %src_ptr {{ttig.block_io = {block_io}, boundaryCheck = array<i32: 0, 1>, padding = 1 : i32}} : !tt.ptr<tensor<{M}x{N}x{ty}, #layout>>
142+
"""
135143
store_ops = f"""
136-
%M_i64 = arith.constant {M} : i64
137-
%N_i64 = arith.constant {N} : i64
138-
%c1_i64 = arith.constant 1 : i64
139-
%c0_i32 = arith.constant 0 : i32
140-
141-
%blk_ptr = tt.make_tensor_ptr %dst, [%M_i64, %N_i64], [%N_i64, %c1_i64], [%c0_i32, %c0_i32] {{order = array<i32: 1, 0>}} : <tensor<{M}x{N}x{ty}, #layout>>
142-
tt.store %blk_ptr, %store_val {{ttig.block_io = "row_major", boundaryCheck = array<i32: 0, 1>}} : !tt.ptr<tensor<{M}x{N}x{ty}, #layout>>
144+
%dst_ptr = tt.make_tensor_ptr %dst, [%M_i64, %N_i64], [%N_i64, %c1_i64], [%c0_i32, %c0_i32] {{order = array<i32: 1, 0>}} : <tensor<{M}x{N}x{ty}, #layout>>
145+
tt.store %dst_ptr, %store_val {{ttig.block_io = "row_major", boundaryCheck = array<i32: 0, 1>}} : !tt.ptr<tensor<{M}x{N}x{ty}, #layout>>
143146
"""
144147
else:
148+
load_ops = f"""
149+
%src_base = tt.splat %src : !tt.ptr<{ty}> -> tensor<{M}x{N}x!tt.ptr<{ty}>, #layout>
150+
%src_ptr = tt.addptr %src_base, {"%col_major_off" if transpose else "%row_major_off" } : tensor<{M}x{N}x!tt.ptr<{ty}>, #layout>, tensor<{M}x{N}xi32, #layout>
151+
%store_val = tt.load %src_ptr {{ttig.block_io = {block_io}}} : tensor<{M}x{N}x!tt.ptr<{ty}>, #layout>
152+
"""
145153
store_ops = f"""
146-
%12 = tt.splat %dst : !tt.ptr<{ty}> -> tensor<{M}x{N}x!tt.ptr<{ty}>, #layout>
147-
%13 = tt.addptr %12, %8 : tensor<{M}x{N}x!tt.ptr<{ty}>, #layout>, tensor<{M}x{N}xi32, #layout>
148-
tt.store %13, %store_val {{ttig.block_io = "row_major"}} : tensor<{M}x{N}x!tt.ptr<{ty}>, #layout>
154+
%dst_base = tt.splat %dst : !tt.ptr<{ty}> -> tensor<{M}x{N}x!tt.ptr<{ty}>, #layout>
155+
%dst_ptr = tt.addptr %dst_base, %row_major_off : tensor<{M}x{N}x!tt.ptr<{ty}>, #layout>, tensor<{M}x{N}xi32, #layout>
156+
tt.store %dst_ptr, %store_val {{ttig.block_io = "row_major"}} : tensor<{M}x{N}x!tt.ptr<{ty}>, #layout>
149157
"""
150158

151159
ir = f"""
152160
#layout = {layout}
153161
module attributes {{{"ttig.support_sg_2d_block," if support_block_io else ""} "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = {num_warps} : i32, ttg.target = "xpu", "ttg.threads-per-warp" = {threads_per_warp} : i32}} {{
154162
tt.func public @block_store(%src: !tt.ptr<{ty}> {{tt.divisibility = 16 : i32}}, %dst: !tt.ptr<{ty}> {{tt.divisibility = 16 : i32}}) {{
155163
156-
%stride = arith.constant dense<{N}> : tensor<{M}x1xi32, #layout>
164+
%M_i64 = arith.constant {M} : i64
165+
%N_i64 = arith.constant {N} : i64
166+
%c1_i64 = arith.constant 1 : i64
167+
%c0_i32 = arith.constant 0 : i32
168+
%stride_N = arith.constant dense<{N}> : tensor<{M}x1xi32, #layout>
157169
%1 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #ttg.slice<{{dim = 1, parent = #layout}}>>
158170
%2 = tt.expand_dims %1 {{axis = 1 : i32}} : tensor<{M}xi32, #ttg.slice<{{dim = 1, parent = #layout}}>> -> tensor<{M}x1xi32, #layout>
159-
%3 = arith.muli %2, %stride : tensor<{M}x1xi32, #layout>
171+
%row_stride = arith.muli %2, %stride_N : tensor<{M}x1xi32, #layout>
160172
%4 = tt.make_range {{end = {N} : i32, start = 0 : i32}} : tensor<{N}xi32, #ttg.slice<{{dim = 0, parent = #layout}}>>
161173
%5 = tt.expand_dims %4 {{axis = 0 : i32}} : tensor<{N}xi32, #ttg.slice<{{dim = 0, parent = #layout}}>> -> tensor<1x{N}xi32, #layout>
162-
%6 = tt.broadcast %3 : tensor<{M}x1xi32, #layout> -> tensor<{M}x{N}xi32, #layout>
174+
%6 = tt.broadcast %row_stride : tensor<{M}x1xi32, #layout> -> tensor<{M}x{N}xi32, #layout>
163175
%7 = tt.broadcast %5 : tensor<1x{N}xi32, #layout> -> tensor<{M}x{N}xi32, #layout>
164-
%8 = arith.addi %6, %7 : tensor<{M}x{N}xi32, #layout>
165-
%9 = tt.splat %src : !tt.ptr<{ty}> -> tensor<{M}x{N}x!tt.ptr<{ty}>, #layout>
166-
%10 = tt.addptr %9, %8 : tensor<{M}x{N}x!tt.ptr<{ty}>, #layout>, tensor<{M}x{N}xi32, #layout>
167-
%store_val = tt.load %10 : tensor<{M}x{N}x!tt.ptr<{ty}>, #layout>
176+
%row_major_off = arith.addi %6, %7 : tensor<{M}x{N}xi32, #layout>
177+
178+
179+
%stride_M = arith.constant dense<{M}> : tensor<1x{N}xi32, #layout>
180+
%col_stride = arith.muli %5, %stride_M : tensor<1x{N}xi32, #layout>
181+
%8 = tt.broadcast %2 : tensor<{M}x1xi32, #layout> -> tensor<{M}x{N}xi32, #layout>
182+
%9 = tt.broadcast %col_stride : tensor<1x{N}xi32, #layout> -> tensor<{M}x{N}xi32, #layout>
183+
%col_major_off = arith.addi %8, %9 : tensor<{M}x{N}xi32, #layout>
184+
{load_ops}
168185
169186
{store_ops}
170187
@@ -185,8 +202,15 @@ def test_block_store(M, N, dtype_str, layout, block_ptr, device, tmp_path: pathl
185202
temp_file.write_text(ir)
186203
kernel = triton.compile(str(temp_file))
187204

205+
a = a.permute(1, 0).contiguous().permute(1, 0) if transpose else a
206+
207+
print("a:", a.shape, a.stride())
208+
print("x:", x.shape, x.stride())
209+
188210
kernel[(1, 1, 1)](a, x)
189211
assert torch.equal(a, x)
190212

191213
if support_block_io:
192214
assert 'spirv_Subgroup2DBlockStoreINTEL' in kernel.asm['llir'] or 'GenISA.LSC2DBlockWrite' in kernel.asm['llir']
215+
if not block_ptr:
216+
assert 'spirv_Subgroup2DBlockLoad' in kernel.asm['llir'] or 'GenISA.LSC2DBlockRead' in kernel.asm['llir']

Comments (0)