From 59b05dbea99bf5c69f7b6e119a41cdd43de4d2a0 Mon Sep 17 00:00:00 2001 From: Shuxin Yang Date: Sat, 4 Oct 2025 23:05:00 -0700 Subject: [PATCH 1/9] [AMD][DRAFT] revamp range analysis --- .../amd-convert-buffer-ops-small-tensor.mlir | 19 +- .../TritonGPU/amd/amd-convert-buffer-ops.mlir | 197 +++++- test/TritonGPU/amd/amd-range-analysis.mlir | 212 +++--- .../amd/include/Analysis/RangeAnalysis.h | 43 +- .../amd/lib/Analysis/RangeAnalysis.cpp | 651 ++++++++++++++---- .../ConvertToBufferOps.cpp | 205 +----- .../TritonAMDGPUTransforms/FoldTrueCmpIOp.cpp | 3 +- .../lib/Analysis/TestAMDRangeAnalysis.cpp | 3 +- 8 files changed, 896 insertions(+), 437 deletions(-) diff --git a/test/TritonGPU/amd/amd-convert-buffer-ops-small-tensor.mlir b/test/TritonGPU/amd/amd-convert-buffer-ops-small-tensor.mlir index 2b146c4c51bf..75a1004dfbe3 100644 --- a/test/TritonGPU/amd/amd-convert-buffer-ops-small-tensor.mlir +++ b/test/TritonGPU/amd/amd-convert-buffer-ops-small-tensor.mlir @@ -36,13 +36,15 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} { %8 = tt.addptr %7, %4 : tensor<256x!tt.ptr, #blocked0>, tensor<256xi32, #blocked0> // COMMON: buffer_load %arg0[%[[offset]]] %9 = tt.load %6 : tensor<256x!tt.ptr, #blocked0> - // COMMON: buffer_load %arg1[%[[offset]]] + // Note: offset = pid * 256 + arange(0, 256); byte-ofst="offset * sizeof(i32)" may not fall into range of 2G. + // COMMON-NOT: buffer_load %arg1[%[[offset]]] %10 = tt.load %8 : tensor<256x!tt.ptr, #blocked0> // COMMON: %[[data:.*]] = arith.addf %11 = arith.addf %9, %10 : tensor<256xf32, #blocked0> %12 = tt.splat %arg2 : !tt.ptr -> tensor<256x!tt.ptr, #blocked0> %13 = tt.addptr %12, %4 : tensor<256x!tt.ptr, #blocked0>, tensor<256xi32, #blocked0> - // COMMON: buffer_store %[[data]], %arg2[%[[offset]]] + // Note: see the explanation above + // COMMON-NOT: buffer_store %[[data]], %arg2[%[[offset]]] tt.store %13, %11 : tensor<256x!tt.ptr, #blocked0> tt.return } @@ -70,7 +72,10 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { %5 = tt.addptr %arg0, %1 : !tt.ptr, i32 %8 = tt.splat %5 : !tt.ptr -> tensor<1024x!tt.ptr, #blocked> %9 = tt.addptr %8, %4 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> - // COMMON: buffer_load %[[scalar_ptr]][%[[offset]]] + // Note: the base "scalar_ptr" points to arg0 which is a large-tensor. + // the offset="%sub + arange(0,1024)" where "%sub=pid*1024-128", + // We can prove "offset > 0", but cannot prove byte-offset < 2G. + // COMMON-NOT: buffer_load %[[scalar_ptr]][%[[offset]]] %10 = tt.load %9 : tensor<1024x!tt.ptr, #blocked> tt.return %10 : tensor<1024xf32, #blocked> } @@ -122,7 +127,9 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // COMMON: %[[offset_32_bit:.*]] = arith.trunci %narrow4 = arith.trunci %4 : tensor<1024xi64, #blocked> to tensor <1024xi32, #blocked> %9 = tt.addptr %8, %narrow4 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> - // COMMON: buffer_load %[[scalar_ptr]][%[[offset_32_bit]]] + // Note: base is arg0 which is large-tensor, the offset=int(long(pid*1024) * long(arange(0, 1024)) + // offset is in [0, i32-max]. 
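+    // Concretely: with elemIdx as large as i32-max = 2147483647 and sizeof(f32) = 4,
+    // the byte offset can reach roughly 8G, well past the 2G buffer-op limit,
+    // so the conversion must not be applied here.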
+ // COMMON-NOT: buffer_load %[[scalar_ptr]][%[[offset_32_bit]]] %10 = tt.load %9 : tensor<1024x!tt.ptr, #blocked> tt.return %10 : tensor<1024xf32, #blocked> } @@ -555,7 +562,9 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { %5 = tt.addptr %arg0, %1 : !tt.ptr, i32 %6 = tt.splat %5 : !tt.ptr -> tensor<1024x!tt.ptr, #blocked> %7 = tt.addptr %6, %4 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> - // COMMON: %[[loaded:.*]] = amdgpu.buffer_atomic_rmw fadd, acq_rel, gpu, %arg1, %[[scalar_ptr]][%[[offset]]] + // Note: the large tensor is accessed, offset is in the range of [0, smax]. + // without tl.assume the range would be [-128, smax] + // COMMON-NOT: amdgpu.buffer_atomic_rmw %8 = tt.atomic_rmw fadd, acq_rel, gpu, %7, %arg1 : (tensor<1024x!tt.ptr, #blocked>, tensor<1024xf32, #blocked>) -> tensor<1024xf32, #blocked> tt.return %8 : tensor<1024xf32, #blocked> } diff --git a/test/TritonGPU/amd/amd-convert-buffer-ops.mlir b/test/TritonGPU/amd/amd-convert-buffer-ops.mlir index 7b04877a9454..bb5121af2cbf 100644 --- a/test/TritonGPU/amd/amd-convert-buffer-ops.mlir +++ b/test/TritonGPU/amd/amd-convert-buffer-ops.mlir @@ -16,15 +16,17 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} { %6 = tt.addptr %5, %4 : tensor<256x!tt.ptr, #blocked0>, tensor<256xi32, #blocked0> %7 = tt.splat %arg1 : !tt.ptr -> tensor<256x!tt.ptr, #blocked0> %8 = tt.addptr %7, %4 : tensor<256x!tt.ptr, #blocked0>, tensor<256xi32, #blocked0> - // COMMON: buffer_load %arg0[%[[offset]]] + // Note: large-tensor with elemIdx=pid*256 + arange(0, 256), elemIdx ∈ [0, smax] + // COMMON-NOT: buffer_load %9 = tt.load %6 : tensor<256x!tt.ptr, #blocked0> - // COMMON: buffer_load %arg1[%[[offset]]] + // COMMON-NOT: buffer_load %10 = tt.load %8 : tensor<256x!tt.ptr, #blocked0> // COMMON: %[[data:.*]] = arith.addf %11 = arith.addf %9, %10 : tensor<256xf32, #blocked0> %12 = tt.splat %arg2 : !tt.ptr -> tensor<256x!tt.ptr, #blocked0> %13 = tt.addptr %12, %4 : tensor<256x!tt.ptr, #blocked0>, tensor<256xi32, #blocked0> - // COMMON: buffer_store %[[data]], %arg2[%[[offset]]] + // Note: large-tensor with elemIdx ∈ [0, smax] + // COMMON-NOT: buffer_store tt.store %13, %11 : tensor<256x!tt.ptr, #blocked0> tt.return } @@ -43,6 +45,9 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.targ %1 = tt.expand_dims %0 {axis = 1 : i32} : tensor<256xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<256x1xi32, #blocked> %cmp = arith.cmpi sgt, %arg6, %c0_i32 : i32 llvm.intr.assume %cmp : i1 + %arg6_upper = arith.constant 4194304 : i32 + %cmp2 = arith.cmpi slt, %arg6, %arg6_upper : i32 + llvm.intr.assume %cmp2 : i1 %2 = tt.splat %arg6 : i32 -> tensor<256x1xi32, #blocked> %3 = arith.muli %1, %2 : tensor<256x1xi32, #blocked> %4 = tt.addptr %arg0, %c32_i32 : !tt.ptr, i32 @@ -78,6 +83,9 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.targ %23 = arith.addi %21, %20 : tensor<256x64xi32, #blocked> %24 = tt.splat %22 : !tt.ptr -> tensor<256x64x!tt.ptr, #blocked> %25 = tt.addptr %24, %23 : tensor<256x64x!tt.ptr, #blocked>, tensor<256x64xi32, #blocked> + %ofst_upper = arith.constant 1073741823 : i32 + %cmp3 = arith.cmpi slt, %ofst_upper, %ofst_upper : i32 + llvm.intr.assume %cmp3 : i1 // COMMON: %[[splatb:.*]] = tt.splat %arg[[#strideb:]] // COMMON: %[[mulb:.*]] = arith.muli %[[splatb]], %[[#]] @@ -85,12 +93,14 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.targ // COMMON: %[[bcast0b:.*]] = tt.broadcast %[[#]] // 
COMMON: %[[ptrb:.*]] = tt.addptr // COMMON: %[[offsetb:.*]] = arith.addi %[[bcast0b]], %[[bcast1b]] - // COMMON: buffer_store %[[buffer]], %[[ptrb]][%[[offsetb]]] stride = %arg[[#strideb]] + // COMMON-NOT: buffer_store tt.store %25, %12 : tensor<256x64x!tt.ptr, #blocked> tt.return } } + + // ----- #blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> @@ -113,7 +123,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { %5 = tt.addptr %arg0, %1 : !tt.ptr, i32 %8 = tt.splat %5 : !tt.ptr -> tensor<1024x!tt.ptr, #blocked> %9 = tt.addptr %8, %4 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> - // COMMON: buffer_load %[[scalar_ptr]][%[[offset]]] + // COMMON-NOT: buffer_load %[[scalar_ptr]][%[[offset]]] %10 = tt.load %9 : tensor<1024x!tt.ptr, #blocked> tt.return %10 : tensor<1024xf32, #blocked> } @@ -165,7 +175,9 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // COMMON: %[[offset_32_bit:.*]] = arith.trunci %narrow4 = arith.trunci %4 : tensor<1024xi64, #blocked> to tensor <1024xi32, #blocked> %9 = tt.addptr %8, %narrow4 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> - // COMMON: buffer_load %[[scalar_ptr]][%[[offset_32_bit]]] + // Note: base is arg0 which is large-tensor, the offset=int(long(pid*1024) * long(arange(0, 1024)) + // offset is in [0, i32-max]. + // COMMON-NOT: buffer_load %10 = tt.load %9 : tensor<1024x!tt.ptr, #blocked> tt.return %10 : tensor<1024xf32, #blocked> } @@ -265,9 +277,9 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { %14 = tt.addptr %13, %11 : tensor<16x!tt.ptr, #blocked>, tensor<16xi32, #blocked> %15 = tt.splat %arg1 : !tt.ptr -> tensor<16x!tt.ptr, #blocked> %16 = tt.addptr %15, %offsets : tensor<16x!tt.ptr, #blocked>, tensor<16xi32, #blocked> - // COMMON: %[[loaded:.*]] = amdgpu.buffer_load %arg1[%[[offsets]]] + // COMMON-NOT: amdgpu.buffer_load %17 = tt.load %16 : tensor<16x!tt.ptr, #blocked> - // COMMON: amdgpu.buffer_store %[[loaded]], %arg0[%[[range]]] + // COMMON: amdgpu.buffer_store tt.store %14, %17 : tensor<16x!tt.ptr, #blocked> tt.return } @@ -364,11 +376,13 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { %1 = tt.make_range {end = 8 : i32, start = 0 : i32} : tensor<8xi32, #blocked> %2 = tt.splat %arg0 : !tt.ptr -> tensor<8x!tt.ptr, #blocked> %3 = tt.addptr %2, %0 : tensor<8x!tt.ptr, #blocked>, tensor<8xi32, #blocked> - // COMMON: %[[loaded:.*]] = amdgpu.buffer_load %arg0[%{{.*}}] + // Note: index is tt.histogram ∈ [0, smax) + // COMMON-NOT: amdgpu.buffer_load %4 = tt.load %3 : tensor<8x!tt.ptr, #blocked> %5 = tt.splat %arg1 : !tt.ptr -> tensor<8x!tt.ptr, #blocked> %6 = tt.addptr %5, %1 : tensor<8x!tt.ptr, #blocked>, tensor<8xi32, #blocked> - // COMMON: amdgpu.buffer_store %[[loaded]], %arg1[%{{.*}}] + // Note: index is tt.histogram ∈ [0, smax) + // COMMON: amdgpu.buffer_store tt.store %6, %4 : tensor<8x!tt.ptr, #blocked> tt.return } @@ -391,11 +405,11 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { %8 = arith.addi %6, %7 : tensor<8xi32, #blocked> %9 = tt.splat %arg0 : !tt.ptr -> tensor<8x!tt.ptr, #blocked> %10 = tt.addptr %9, %8 : tensor<8x!tt.ptr, #blocked>, tensor<8xi32, #blocked> - // COMMON: %[[loaded:.*]] = amdgpu.buffer_load %arg0[%{{.*}}] + // COMMON-NOT: amdgpu.buffer_load %11 = tt.load %10 : tensor<8x!tt.ptr, #blocked> %12 = tt.splat %arg1 : !tt.ptr -> tensor<8x!tt.ptr, #blocked> %13 = tt.addptr %12, %7 : tensor<8x!tt.ptr, #blocked>, 
tensor<8xi32, #blocked> - // COMMON: amdgpu.buffer_store %[[loaded]], %arg1[%{{.*}}] + // COMMON: amdgpu.buffer_store tt.store %13, %11 : tensor<8x!tt.ptr, #blocked> tt.return } @@ -426,11 +440,12 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { %17 = arith.addi %15, %16 : tensor<8xi32, #blocked> %18 = tt.splat %arg0 : !tt.ptr -> tensor<8x!tt.ptr, #blocked> %19 = tt.addptr %18, %17 : tensor<8x!tt.ptr, #blocked>, tensor<8xi32, #blocked> - // COMMON: %[[loaded:.*]] = amdgpu.buffer_load %arg0[%{{.*}}] + // Note: above operations can only prove elmtIdx >= 0 not don't reveal its upper bound. + // COMMON-NOT: amdgpu.buffer_load %20 = tt.load %19 : tensor<8x!tt.ptr, #blocked> %21 = tt.splat %arg1 : !tt.ptr -> tensor<8x!tt.ptr, #blocked> %22 = tt.addptr %21, %16 : tensor<8x!tt.ptr, #blocked>, tensor<8xi32, #blocked> - // COMMON: amdgpu.buffer_store %[[loaded]], %arg1[%{{.*}}] + // COMMON: amdgpu.buffer_store tt.store %22, %20 : tensor<8x!tt.ptr, #blocked> tt.return } @@ -450,11 +465,13 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { %5 = arith.trunci %4 : tensor<8xi64, #blocked> to tensor<8xi32, #blocked> %6 = tt.splat %arg0 : !tt.ptr -> tensor<8x!tt.ptr, #blocked> %7 = tt.addptr %6, %5 : tensor<8x!tt.ptr, #blocked>, tensor<8xi32, #blocked> - // COMMON: %[[loaded:.*]] = amdgpu.buffer_load %arg0[%{{.*}}] + // Note: elemIdx is (int32)(arange(0, 8) + (uint64)(uint32)arg2) + // elemIdx is not necessarilly >=0 + // COMMON-NOT: amdgpu.buffer_load %8 = tt.load %7: tensor<8x!tt.ptr, #blocked> %9 = tt.splat %arg1 : !tt.ptr -> tensor<8x!tt.ptr, #blocked> %10 = tt.addptr %9, %2 : tensor<8x!tt.ptr, #blocked>, tensor<8xi32, #blocked> - // COMMON: amdgpu.buffer_store %[[loaded]], %arg1[%{{.*}}] + // COMMON: amdgpu.buffer_store tt.store %10, %8 : tensor<8x!tt.ptr, #blocked> tt.return } @@ -490,12 +507,15 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { %4 = arith.trunci %3 : tensor<8xi64, #blocked> to tensor<8xi32, #blocked> %5 = tt.splat %arg0 : !tt.ptr -> tensor<8x!tt.ptr, #blocked> %6 = tt.addptr %5, %4 : tensor<8x!tt.ptr, #blocked>, tensor<8xi32, #blocked> - // COMMON: %[[loaded:.*]] = amdgpu.buffer_load %arg0[%{{.*}}] + // Note: It's not able to prove that the value range of elmtIdx in [0,1G]. + // testing case traverse_if_2nd, traverse_if_2nd_v2 and traverse_if_2nd_v3 + // works better than this case for this purpose. 
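+    // (Assuming the same 2-byte bf16 element type used by the traverse_if_2nd variants,
+    // the requirement is elmtIdx * sizeof(bf16) <= 2G - 2, i.e. elmtIdx <= 2^30 - 1 = 1073741823,
+    // which is the [0, 1G) range mentioned above; nothing in this test bounds elmtIdx that tightly.)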
+ // COMMON-NOT:amdgpu.buffer_load %7 = tt.load %6: tensor<8x!tt.ptr, #blocked> %8 = tt.make_range {end = 8 : i32, start = 0 : i32} : tensor<8xi32, #blocked> %9 = tt.splat %arg1 : !tt.ptr -> tensor<8x!tt.ptr, #blocked> %10 = tt.addptr %9, %8 : tensor<8x!tt.ptr, #blocked>, tensor<8xi32, #blocked> - // COMMON: amdgpu.buffer_store %[[loaded]], %arg1[%{{.*}}] + // COMMON: amdgpu.buffer_store tt.store %10, %7 : tensor<8x!tt.ptr, #blocked> tt.return } @@ -505,8 +525,52 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { #blocked = #ttg.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { - // COMMON-LABEL: traverse_if - tt.func @traverse_if(%arg0: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg2 : i32, %arg3 : i32) { + // COMMON-LABEL: traverse_if_2nd + tt.func @traverse_if_2nd(%arg0: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg2 : i32, %arg3 : i32) { + %c0_i32 = arith.constant 0 : i32 + %c2_i32 = arith.constant 2 : i32 + %c5_i32 = arith.constant 7 : i32 + %c7_i32 = arith.constant 5 : i32 + %zeros = arith.constant dense<0> : tensor<8xi32, #blocked> + %0 = arith.extui %arg2 : i32 to i64 + %1 = arith.remui %arg2, %c2_i32 : i32 + %2 = arith.cmpi eq, %1, %c0_i32 : i32 + %3, %4 = scf.if %2 -> (tensor<8xi64, #blocked>, tensor<8xi32, #blocked>) { + %20 = tt.make_range {end = 8 : i32, start = 0 : i32} : tensor<8xi32, #blocked> + %21 = arith.extui %20 : tensor<8xi32, #blocked> to tensor<8xi64, #blocked> + %22 = tt.splat %arg3 : i32 -> tensor<8xi32, #blocked> + %23 = arith.extui %22 : tensor<8xi32, #blocked> to tensor<8xi64, #blocked> + %24 = arith.addi %21, %23 : tensor<8xi64, #blocked> + %25 = tt.make_range {end = 9 : i32, start = 1 : i32} : tensor<8xi32, #blocked> + scf.yield %24, %25 : tensor<8xi64, #blocked>, tensor<8xi32, #blocked> + } else { + %30 = tt.make_range {end = 16 : i32, start = 8 : i32} : tensor<8xi32, #blocked> + %31 = arith.extui %30 : tensor<8xi32, #blocked> to tensor<8xi64, #blocked> + %32 = tt.splat %0 : i64 -> tensor<8xi64, #blocked> + %33 = arith.addi %31, %32 : tensor<8xi64, #blocked> + scf.yield %33, %zeros : tensor<8xi64, #blocked>, tensor<8xi32, #blocked> + } + %5 = arith.trunci %3 : tensor<8xi64, #blocked> to tensor<8xi32, #blocked> + %6 = arith.addi %4, %5 : tensor<8xi32, #blocked> + %7 = tt.splat %arg0 : !tt.ptr -> tensor<8x!tt.ptr, #blocked> + %8 = tt.addptr %7, %6 : tensor<8x!tt.ptr, #blocked>, tensor<8xi32, #blocked> + // COMMON-NOT: amdgpu.buffer_load + %9 = tt.load %8: tensor<8x!tt.ptr, #blocked> + %10 = tt.make_range {end = 8 : i32, start = 0 : i32} : tensor<8xi32, #blocked> + %11 = tt.splat %arg1 : !tt.ptr -> tensor<8x!tt.ptr, #blocked> + %12 = tt.addptr %11, %10 : tensor<8x!tt.ptr, #blocked>, tensor<8xi32, #blocked> + // COMMON: amdgpu.buffer_store + tt.store %12, %9 : tensor<8x!tt.ptr, #blocked> + tt.return + } +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { + // COMMON-LABEL: traverse_if_2nd_v2 + tt.func @traverse_if_2nd_v2(%arg0: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg2 : i32, %arg3 : i32) { 
%c0_i32 = arith.constant 0 : i32 %c2_i32 = arith.constant 2 : i32 %c5_i32 = arith.constant 7 : i32 @@ -534,6 +598,22 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { %6 = arith.addi %4, %5 : tensor<8xi32, #blocked> %7 = tt.splat %arg0 : !tt.ptr -> tensor<8x!tt.ptr, #blocked> %8 = tt.addptr %7, %6 : tensor<8x!tt.ptr, #blocked>, tensor<8xi32, #blocked> + + // Note: + // elmtIdx = %6 = %4 + %5, value-range(%4) = [0,7], value-range(%5) = [0, umax] + // %5 = max([0,8] + arg3, [8,16) + arg2), to make %6 * sizeof(bf16) <= 2G - 2byte + // arg3 ∈ [0, 1G-1-8-7 = 1073741808), arg2 ∈ [-8, 1G-1-15-8=1073741800] + %cmp1 = arith.cmpi sge, %arg2, %c0_i32 : i32 + llvm.intr.assume %cmp1 : i1 + %cmp2 = arith.cmpi sge, %arg3, %c0_i32 : i32 + llvm.intr.assume %cmp2 : i1 + %arg_up2 = arith.constant 1073741800 : i32 + %arg_up3 = arith.constant 1073741808 : i32 + %cmp3 = arith.cmpi slt, %arg2, %arg_up2 : i32 + %cmp4 = arith.cmpi slt, %arg3, %arg_up3 : i32 + llvm.intr.assume %cmp3 : i1 + llvm.intr.assume %cmp4 : i1 + // COMMON: %[[loaded:.*]] = amdgpu.buffer_load %arg0[%{{.*}}] %9 = tt.load %8: tensor<8x!tt.ptr, #blocked> %10 = tt.make_range {end = 8 : i32, start = 0 : i32} : tensor<8xi32, #blocked> @@ -547,6 +627,68 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // ----- +#blocked = #ttg.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { + // COMMON-LABEL: traverse_if_2nd_v3 + tt.func @traverse_if_2nd_v3(%arg0: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg2 : i32, %arg3 : i32) { + %c0_i32 = arith.constant 0 : i32 + %c2_i32 = arith.constant 2 : i32 + %c5_i32 = arith.constant 7 : i32 + %c7_i32 = arith.constant 5 : i32 + %zeros = arith.constant dense<0> : tensor<8xi32, #blocked> + %0 = arith.extui %arg2 : i32 to i64 + %1 = arith.remui %arg2, %c2_i32 : i32 + %2 = arith.cmpi eq, %1, %c0_i32 : i32 + %3, %4 = scf.if %2 -> (tensor<8xi64, #blocked>, tensor<8xi32, #blocked>) { + %20 = tt.make_range {end = 8 : i32, start = 0 : i32} : tensor<8xi32, #blocked> + %21 = arith.extui %20 : tensor<8xi32, #blocked> to tensor<8xi64, #blocked> + %22 = tt.splat %arg3 : i32 -> tensor<8xi32, #blocked> + %23 = arith.extui %22 : tensor<8xi32, #blocked> to tensor<8xi64, #blocked> + %24 = arith.addi %21, %23 : tensor<8xi64, #blocked> + %25 = tt.make_range {end = 9 : i32, start = 1 : i32} : tensor<8xi32, #blocked> + scf.yield %24, %25 : tensor<8xi64, #blocked>, tensor<8xi32, #blocked> + } else { + %30 = tt.make_range {end = 16 : i32, start = 8 : i32} : tensor<8xi32, #blocked> + %31 = arith.extui %30 : tensor<8xi32, #blocked> to tensor<8xi64, #blocked> + %32 = tt.splat %0 : i64 -> tensor<8xi64, #blocked> + %33 = arith.addi %31, %32 : tensor<8xi64, #blocked> + scf.yield %33, %zeros : tensor<8xi64, #blocked>, tensor<8xi32, #blocked> + } + %5 = arith.trunci %3 : tensor<8xi64, #blocked> to tensor<8xi32, #blocked> + %6 = arith.addi %4, %5 : tensor<8xi32, #blocked> + %7 = tt.splat %arg0 : !tt.ptr -> tensor<8x!tt.ptr, #blocked> + %8 = tt.addptr %7, %6 : tensor<8x!tt.ptr, #blocked>, tensor<8xi32, #blocked> + + // Note: + // elmtIdx = %6 = %4 + %5, value-range(%4) = [0,7], value-range(%5) = [0, umax] + // %5 = max([0,8] + arg3, [8,16) + arg2), to make %6 * sizeof(bf16) <= 2G - 2byte + // arg3 ∈ [0, 1G-1-8-7 = 1073741808), arg2 ∈ [-8, 1G-1-15-8=1073741800] + %cmp1 = arith.cmpi sge, %arg2, 
%c0_i32 : i32 + llvm.intr.assume %cmp1 : i1 + %cmp2 = arith.cmpi sge, %arg3, %c0_i32 : i32 + llvm.intr.assume %cmp2 : i1 + // the only difference between traverse_if_2nd_v3 and traverse_if_2nd_v2 + // is arg_up2. In v3 the upper bound is bumped by 1. + %arg_up2 = arith.constant 1073741801 : i32 + %arg_up3 = arith.constant 1073741808 : i32 + %cmp3 = arith.cmpi slt, %arg2, %arg_up2 : i32 + %cmp4 = arith.cmpi slt, %arg3, %arg_up3 : i32 + llvm.intr.assume %cmp3 : i1 + llvm.intr.assume %cmp4 : i1 + + // COMMON-NOT: amdgpu.buffer_load + %9 = tt.load %8: tensor<8x!tt.ptr, #blocked> + %10 = tt.make_range {end = 8 : i32, start = 0 : i32} : tensor<8xi32, #blocked> + %11 = tt.splat %arg1 : !tt.ptr -> tensor<8x!tt.ptr, #blocked> + %12 = tt.addptr %11, %10 : tensor<8x!tt.ptr, #blocked>, tensor<8xi32, #blocked> + // COMMON: amdgpu.buffer_store + tt.store %12, %9 : tensor<8x!tt.ptr, #blocked> + tt.return + } +} + +// ----- + #blocked = #ttg.blocked<{sizePerThread = [2], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}> module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} { // COMMON-LABEL: atomic_add_bf16 @@ -589,7 +731,9 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { %5 = tt.addptr %arg0, %1 : !tt.ptr, i32 %6 = tt.splat %5 : !tt.ptr -> tensor<1024x!tt.ptr, #blocked> %7 = tt.addptr %6, %4 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> - // COMMON: %[[loaded:.*]] = amdgpu.buffer_atomic_rmw fadd, acq_rel, gpu, %arg1, %[[scalar_ptr]][%[[offset]]] + // Note: the large tensor is accessed, offset is in the range of [0, smax]. + // without tl.assume the range would be [-128, smax] + // COMMON-NOT: amdgpu.buffer_atomic_rmw %8 = tt.atomic_rmw fadd, acq_rel, gpu, %7, %arg1 : (tensor<1024x!tt.ptr, #blocked>, tensor<1024xf32, #blocked>) -> tensor<1024xf32, #blocked> tt.return %8 : tensor<1024xf32, #blocked> } @@ -629,6 +773,13 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.targ // COMMON: %[[bcast0:.*]] = tt.broadcast %[[#]] // COMMON: %[[offset:.*]] = arith.addi %[[bcast0]], %[[bcast1]] + // Note: offset(i.e. 
elmtIdx) = bcast0 + bcast1 + // = arange(0, 64) + arg6 * arange(0, 256) + // to make elmtIdx * sizeof(f16) ∈ [0, 2G], arg6 must be in [0, 4210752] + %arg6_up = arith.constant 4210752: i32 + %cmp2 = arith.cmpi slt, %arg6, %arg6_up : i32 + llvm.intr.assume %cmp2 : i1 + // COMMON: %[[buffer:.*]] = amdgpu.buffer_load_to_local %[[ptr]][%[[offset]]] stride = %arg[[#stride]] into %arg10 %12 = ttg.async_copy_global_to_local %11, %arg10 : tensor<256x64x!tt.ptr, #blocked> -> <256x64xf16, #shared, #smem, mutable> @@ -644,10 +795,10 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.targ // COMMON: %[[buffer:.*]] = amdgpu.buffer_load_to_local %[[ptr]][%[[offset]]] mask = %arg11 other = %arg12 stride = %arg[[#stride]] cacheModifier = ca into %arg10 %16 = ttg.async_copy_global_to_local %11, %arg10 mask %arg11 other %arg12 cacheModifier = ca: tensor<256x64x!tt.ptr, #blocked> -> <256x64xf16, #shared, #smem, mutable> - // COMMON: %[[buffer:.*]] = amdgpu.buffer_load_to_local %[[ptr]][%[[offset]]] mask = %arg11 other = %arg12 stride = %arg[[#stride]] cacheModifier = cg into %arg10 + // COMMONx: %[[buffer:.*]] = amdgpu.buffer_load_to_local %[[ptr]][%[[offset]]] mask = %arg11 other = %arg12 stride = %arg[[#stride]] cacheModifier = cg into %arg10 %17 = ttg.async_copy_global_to_local %11, %arg10 mask %arg11 other %arg12 cacheModifier = cg: tensor<256x64x!tt.ptr, #blocked> -> <256x64xf16, #shared, #smem, mutable> - // COMMON: %[[buffer:.*]] = amdgpu.buffer_load_to_local %[[ptr]][%[[offset]]] mask = %arg11 other = %arg12 stride = %arg[[#stride]] cacheModifier = cv into %arg10 + // COMMONx: %[[buffer:.*]] = amdgpu.buffer_load_to_local %[[ptr]][%[[offset]]] mask = %arg11 other = %arg12 stride = %arg[[#stride]] cacheModifier = cv into %arg10 %18 = ttg.async_copy_global_to_local %11, %arg10 mask %arg11 other %arg12 cacheModifier = cv: tensor<256x64x!tt.ptr, #blocked> -> <256x64xf16, #shared, #smem, mutable> tt.return } diff --git a/test/TritonGPU/amd/amd-range-analysis.mlir b/test/TritonGPU/amd/amd-range-analysis.mlir index 2c40188e1c44..92498ed94c17 100644 --- a/test/TritonGPU/amd/amd-range-analysis.mlir +++ b/test/TritonGPU/amd/amd-range-analysis.mlir @@ -97,14 +97,14 @@ module attributes {"ttg.num-warps" = 4 : i32} { %1 = arith.muli %0, %c1024_i32 : i32 %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> %3 = tt.addptr %arg0, %1 : !tt.ptr, i32 - // expected-remark@+2 {{unsigned : [0, 1024] signed : [0, 1024]}} + // expected-remark@+2 {{unsigned : [0, 1023] signed : [0, 1023]}} // expected-remark@+1 {{non-neg}} %4 = arith.extsi %2 : tensor<1024xi32> to tensor<1024xi64> %5 = tt.addptr %3, %1 : !tt.ptr, i32 - // expected-remark@+2 {{unsigned : [0, 1024] signed : [0, 1024]}} + // expected-remark@+2 {{unsigned : [0, 1023] signed : [0, 1023]}} // expected-remark@+1 {{non-neg}} %6 = arith.extsi %2 : tensor<1024xi32> to tensor<1024xi64> - // expected-remark@+2 {{unsigned : [0, 2048] signed : [0, 2048]}} + // expected-remark@+2 {{unsigned : [0, 2046] signed : [0, 2046]}} // expected-remark@+1 {{non-neg}} %7 = arith.addi %6, %4 : tensor<1024xi64> %8 = tt.splat %5 : !tt.ptr -> tensor<1024x!tt.ptr> @@ -132,7 +132,7 @@ module attributes {"ttg.num-warps" = 4 : i32} { %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> %3 = tt.addptr %arg0, %1 : !tt.ptr, i32 %4 = tt.addptr %3, %1 : !tt.ptr, i32 - // expected-remark@+2 {{unsigned : [0, 2048] signed : [0, 2048]}} + // expected-remark@+2 {{unsigned : [0, 2046] signed : [0, 2046]}} // expected-remark@+1 
{{non-neg}} %5 = arith.addi %2, %2 : tensor<1024xi32> %6 = tt.splat %4 : !tt.ptr -> tensor<1024x!tt.ptr> @@ -162,18 +162,18 @@ module attributes {"ttg.num-warps" = 4 : i32} { %1 = arith.muli %0, %c1024_i32 : i32 %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> %3 = tt.addptr %arg0, %1 : !tt.ptr, i32 - // expected-remark@+2 {{unsigned : [0, 1024] signed : [0, 1024]}} + // expected-remark@+2 {{unsigned : [0, 1023] signed : [0, 1023]}} // expected-remark@+1 {{non-neg}} %4 = arith.extsi %2 : tensor<1024xi32> to tensor<1024xi64> - // expected-remark@+3 {{result 1: unsigned : [0, 131072] signed : [0, 131072]}} + // expected-remark@+3 {{result 1: unsigned : [0, 130944] signed : [0, 130944]}} // expected-remark@+2 {{result 1: non-neg}} // expected-remark@+1 {{inferred total trip count: 128}} %5:3 = scf.for %arg2 = %c0 to %c128 step %c1 iter_args(%arg3 = %3, %arg4 = %4, %arg5 = %arg1) -> (!tt.ptr, tensor<1024xi64>, tensor<1024xf32>) { %12 = tt.addptr %arg3, %1 : !tt.ptr, i32 - // expected-remark@+2 {{unsigned : [0, 1024] signed : [0, 1024]}} + // expected-remark@+2 {{unsigned : [0, 1023] signed : [0, 1023]}} // expected-remark@+1 {{non-neg}} %13 = arith.extsi %2 : tensor<1024xi32> to tensor<1024xi64> - // expected-remark@+2 {{unsigned : [0, 132096] signed : [0, 132096]}} + // expected-remark@+2 {{unsigned : [0, 131967] signed : [0, 131967]}} // expected-remark@+1 {{non-neg}} %14 = arith.addi %13, %arg4 : tensor<1024xi64> %15 = tt.splat %12 : !tt.ptr -> tensor<1024x!tt.ptr> @@ -183,10 +183,10 @@ module attributes {"ttg.num-warps" = 4 : i32} { scf.yield %12, %14, %18 : !tt.ptr, tensor<1024xi64>, tensor<1024xf32> } %6 = tt.addptr %5#0, %1 : !tt.ptr, i32 - // expected-remark@+2 {{unsigned : [0, 1024] signed : [0, 1024]}} + // expected-remark@+2 {{unsigned : [0, 1023] signed : [0, 1023]}} // expected-remark@+1 {{non-neg}} %7 = arith.extsi %2 : tensor<1024xi32> to tensor<1024xi64> - // expected-remark@+2 {{unsigned : [0, 132096] signed : [0, 132096]}} + // expected-remark@+2 {{unsigned : [0, 131967] signed : [0, 131967]}} // expected-remark@+1 {{non-neg}} %8 = arith.addi %7, %5#1 : tensor<1024xi64> %9 = tt.splat %6 : !tt.ptr -> tensor<1024x!tt.ptr> @@ -216,15 +216,15 @@ module attributes {"ttg.num-warps" = 4 : i32} { // expected-remark@+1 {{non-neg}} %1 = arith.muli %0, %c1024_i32 : i32 %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> - // expected-remark@+3 {{result 1: unsigned : [0, 130048] signed : [0, 130048]}} + // expected-remark@+3 {{result 1: unsigned : [0, 129921] signed : [0, 129921]}} // expected-remark@+2 {{result 1: non-neg}} // expected-remark@+1 {{inferred total trip count: 128}} %3:3 = scf.for %arg2 = %c0 to %c128 step %c1 iter_args(%arg3 = %arg0, %arg4 = %cst, %arg5 = %arg1) -> (!tt.ptr, tensor<1024xi64>, tensor<1024xf32>) { %10 = tt.addptr %arg3, %1 : !tt.ptr, i32 - // expected-remark@+2 {{unsigned : [0, 1024] signed : [0, 1024]}} + // expected-remark@+2 {{unsigned : [0, 1023] signed : [0, 1023]}} // expected-remark@+1 {{non-neg}} %11 = arith.extsi %2 : tensor<1024xi32> to tensor<1024xi64> - // expected-remark@+2 {{unsigned : [0, 131072] signed : [0, 131072]}} + // expected-remark@+2 {{unsigned : [0, 130944] signed : [0, 130944]}} // expected-remark@+1 {{non-neg}} %12 = arith.addi %11, %arg4 : tensor<1024xi64> %13 = tt.splat %10 : !tt.ptr -> tensor<1024x!tt.ptr> @@ -234,10 +234,10 @@ module attributes {"ttg.num-warps" = 4 : i32} { scf.yield %10, %12, %16 : !tt.ptr, tensor<1024xi64>, tensor<1024xf32> } %4 = tt.addptr %3#0, %1 : !tt.ptr, 
i32 - // expected-remark@+2 {{unsigned : [0, 1024] signed : [0, 1024]}} + // expected-remark@+2 {{unsigned : [0, 1023] signed : [0, 1023]}} // expected-remark@+1 {{non-neg}} %5 = arith.extsi %2 : tensor<1024xi32> to tensor<1024xi64> - // expected-remark@+2 {{unsigned : [0, 131072] signed : [0, 131072]}} + // expected-remark@+2 {{unsigned : [0, 130944] signed : [0, 130944]}} // expected-remark@+1 {{non-neg}} %6 = arith.addi %5, %3#1 : tensor<1024xi64> %7 = tt.splat %4 : !tt.ptr -> tensor<1024x!tt.ptr> @@ -267,19 +267,19 @@ module attributes {"ttg.num-warps" = 4 : i32} { // expected-remark@+1 {{non-neg}} %1 = arith.muli %0, %c1024_i32 : i32 %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> - // expected-remark@+3 {{result 1: unsigned : [0, 15360] signed : [0, 15360]}} + // expected-remark@+3 {{result 1: unsigned : [0, 15345] signed : [0, 15345]}} // expected-remark@+2 {{result 1: non-neg}} // expected-remark@+1 {{inferred total trip count: 16}} %3:3 = scf.for %arg2 = %c0 to %c16 step %c1 iter_args(%arg3 = %arg0, %arg4 = %cst, %arg5 = %arg1) -> (!tt.ptr, tensor<1024xi64>, tensor<1024xf32>) { - // expected-remark@+3 {{result 1: unsigned : [0, 261120] signed : [0, 261120]}} + // expected-remark@+3 {{result 1: unsigned : [0, 260865] signed : [0, 260865]}} // expected-remark@+2 {{result 1: non-neg}} // expected-remark@+1 {{inferred total trip count: 256}} %10:3 = scf.for %arg6 = %c0 to %c16 step %c1 iter_args(%arg7 = %arg3, %arg8 = %arg4, %arg9 = %arg5) -> (!tt.ptr, tensor<1024xi64>, tensor<1024xf32>) { %11 = tt.addptr %arg7, %1 : !tt.ptr, i32 - // expected-remark@+2 {{unsigned : [0, 1024] signed : [0, 1024]}} + // expected-remark@+2 {{unsigned : [0, 1023] signed : [0, 1023]}} // expected-remark@+1 {{non-neg}} %12 = arith.extsi %2 : tensor<1024xi32> to tensor<1024xi64> - // expected-remark@+2 {{unsigned : [0, 262144] signed : [0, 262144]}} + // expected-remark@+2 {{unsigned : [0, 261888] signed : [0, 261888]}} // expected-remark@+1 {{non-neg}} %13 = arith.addi %12, %arg8 : tensor<1024xi64> %14 = tt.splat %11 : !tt.ptr -> tensor<1024x!tt.ptr> @@ -291,10 +291,10 @@ module attributes {"ttg.num-warps" = 4 : i32} { scf.yield %10#0, %10#1, %10#2 : !tt.ptr, tensor<1024xi64>, tensor<1024xf32> } %4 = tt.addptr %3#0, %1 : !tt.ptr, i32 - // expected-remark@+2 {{unsigned : [0, 1024] signed : [0, 1024]}} + // expected-remark@+2 {{unsigned : [0, 1023] signed : [0, 1023]}} // expected-remark@+1 {{non-neg}} %5 = arith.extsi %2 : tensor<1024xi32> to tensor<1024xi64> - // expected-remark@+2 {{unsigned : [0, 16384] signed : [0, 16384]}} + // expected-remark@+2 {{unsigned : [0, 16368] signed : [0, 16368]}} // expected-remark@+1 {{non-neg}} %6 = arith.addi %5, %3#1 : tensor<1024xi64> %7 = tt.splat %4 : !tt.ptr -> tensor<1024x!tt.ptr> @@ -331,7 +331,7 @@ module attributes {"ttg.num-warps" = 4 : i32} { // expected-remark@+1 {{inferred total trip count: 16384}} %10:3 = scf.for %arg6 = %c0 to %c128 step %c1 iter_args(%arg7 = %arg3, %arg8 = %arg4, %arg9 = %arg5) -> (!tt.ptr, tensor<1024xi64>, tensor<1024xf32>) { %11 = tt.addptr %arg7, %1 : !tt.ptr, i32 - // expected-remark@+2 {{unsigned : [0, 1024] signed : [0, 1024]}} + // expected-remark@+2 {{unsigned : [0, 1023] signed : [0, 1023]}} // expected-remark@+1 {{non-neg}} %12 = arith.extsi %2 : tensor<1024xi32> to tensor<1024xi64> // expected-remark@+1 {{unsigned : [0, 18446744073709551615] signed : [-9223372036854775808, 9223372036854775807]}} @@ -345,7 +345,7 @@ module attributes {"ttg.num-warps" = 4 : i32} { scf.yield %10#0, %10#1, %10#2 : !tt.ptr, 
tensor<1024xi64>, tensor<1024xf32> } %4 = tt.addptr %3#0, %1 : !tt.ptr, i32 - // expected-remark@+2 {{unsigned : [0, 1024] signed : [0, 1024]}} + // expected-remark@+2 {{unsigned : [0, 1023] signed : [0, 1023]}} // expected-remark@+1 {{non-neg}} %5 = arith.extsi %2 : tensor<1024xi32> to tensor<1024xi64> // expected-remark@+1 {{unsigned : [0, 18446744073709551615] signed : [-9223372036854775808, 9223372036854775807]}} @@ -375,11 +375,11 @@ module attributes {"ttg.num-warps" = 4 : i32} { // expected-remark@+1 {{non-neg}} %1 = arith.muli %0, %c1024_i32 : i32 %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> - // expected-remark@+2 {{result 1: unsigned : [0, 1024] signed : [0, 1024]}} + // expected-remark@+2 {{result 1: unsigned : [0, 1023] signed : [0, 1023]}} // expected-remark@+1 {{result 1: non-neg}} %3:2 = scf.if %arg2 -> (!tt.ptr, tensor<1024xi64>) { %8 = tt.addptr %arg0, %1 : !tt.ptr, i32 - // expected-remark@+2 {{unsigned : [0, 1024] signed : [0, 1024]}} + // expected-remark@+2 {{unsigned : [0, 1023] signed : [0, 1023]}} // expected-remark@+1 {{non-neg}} %9 = arith.extsi %2 : tensor<1024xi32> to tensor<1024xi64> scf.yield %8, %9 : !tt.ptr, tensor<1024xi64> @@ -387,7 +387,7 @@ module attributes {"ttg.num-warps" = 4 : i32} { %8 = tt.addptr %arg0, %1 : !tt.ptr, i32 scf.yield %8, %cst : !tt.ptr, tensor<1024xi64> } - // expected-remark@+2 {{unsigned : [0, 1024] signed : [0, 1024]}} + // expected-remark@+2 {{unsigned : [0, 1023] signed : [0, 1023]}} // expected-remark@+1 {{non-neg}} %4 = arith.trunci %3#1 : tensor<1024xi64> to tensor<1024xi32> %5 = tt.splat %3#0 : !tt.ptr -> tensor<1024x!tt.ptr> @@ -416,7 +416,7 @@ module attributes {"ttg.num-warps" = 4 : i32} { %1 = arith.muli %0, %c1024_i32 : i32 %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> %3 = tt.addptr %arg0, %1 : !tt.ptr, i32 - // expected-remark@+2 {{unsigned : [0, 1024] signed : [0, 1024]}} + // expected-remark@+2 {{unsigned : [0, 1023] signed : [0, 1023]}} // expected-remark@+1 {{non-neg}} %4 = arith.extsi %2 : tensor<1024xi32> to tensor<1024xi64> cf.cond_br %arg1, ^bb1(%arg0, %cst : !tt.ptr, tensor<1024xi64>), ^bb2(%3, %4 : !tt.ptr, tensor<1024xi64>) @@ -429,7 +429,7 @@ module attributes {"ttg.num-warps" = 4 : i32} { %10 = tt.load %9 : tensor<1024x!tt.ptr> tt.return %10 : tensor<1024xf32> ^bb2(%11: !tt.ptr, %12: tensor<1024xi64>): // pred: ^bb0 - // expected-remark@+2 {{unsigned : [0, 1024] signed : [0, 1024]}} + // expected-remark@+2 {{unsigned : [0, 1023] signed : [0, 1023]}} // expected-remark@+1 {{non-neg}} %13 = arith.trunci %12 : tensor<1024xi64> to tensor<1024xi32> %14 = tt.splat %11 : !tt.ptr -> tensor<1024x!tt.ptr> @@ -483,7 +483,7 @@ module attributes {"ttg.num-warps" = 4 : i32} { %1 = arith.muli %0, %c256_i32 : i32 %2 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> %3 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 1, parent = #blocked}>> - // expected-remark@+2 {{unsigned : [0, 16] signed : [0, 16]}} + // expected-remark@+2 {{unsigned : [0, 15] signed : [0, 15]}} // expected-remark@+1 {{non-neg}} %4 = tt.expand_dims %3 {axis = 1 : i32} : tensor<16xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<16x1xi32, #blocked> // expected-remark@+1 {{unsigned : [0, 4294967295] signed : [-2147483648, 2147483647]}} @@ -492,10 +492,10 @@ module attributes {"ttg.num-warps" = 4 : i32} { %6 = arith.muli %4, %5 : tensor<16x1xi32, #blocked> // expected-remark@+1 {{unsigned : [0, 
4294967295] signed : [-2147483648, 2147483647]}} %7 = tt.broadcast %6 : tensor<16x1xi32, #blocked> -> tensor<16x256xi32, #blocked> - // expected-remark@+2 {{unsigned : [0, 256] signed : [0, 256]}} + // expected-remark@+2 {{unsigned : [0, 255] signed : [0, 255]}} // expected-remark@+1 {{non-neg}} %8 = tt.expand_dims %2 {axis = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x256xi32, #blocked> - // expected-remark@+2 {{unsigned : [0, 256] signed : [0, 256]}} + // expected-remark@+2 {{unsigned : [0, 255] signed : [0, 255]}} // expected-remark@+1 {{non-neg}} %9 = tt.broadcast %8 : tensor<1x256xi32, #blocked> -> tensor<16x256xi32, #blocked> // expected-remark@+1 {{unsigned : [0, 4294967295] signed : [-2147483648, 2147483647]}} @@ -535,7 +535,7 @@ module attributes {"ttg.num-warps" = 4 : i32} { %7 = arith.muli %4, %6 : tensor<128x1xi32, #blocked> %8 = tt.broadcast %7 : tensor<128x1xi32, #blocked> -> tensor<128x16xi32, #blocked> %9 = tt.expand_dims %3 {axis = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x16xi32, #blocked> - // expected-remark@+2 {{unsigned : [0, 16] signed : [0, 16]}} + // expected-remark@+2 {{unsigned : [0, 15] signed : [0, 15]}} // expected-remark@+1 {{non-neg}} %10 = tt.broadcast %9 : tensor<1x16xi32, #blocked> -> tensor<128x16xi32, #blocked> // expected-remark@+1 {{unsigned : [0, 4294967295] signed : [-2147483648, 2147483647]}} @@ -567,14 +567,14 @@ module attributes {"ttg.num-warps" = 4 : i32} { %1 = arith.muli %0, %c1024_i32 : i32 %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> %3 = tt.addptr %arg0, %1 : !tt.ptr, i32 - // expected-remark@+2 {{unsigned : [0, 1024] signed : [0, 1024]}} + // expected-remark@+2 {{unsigned : [0, 1023] signed : [0, 1023]}} // expected-remark@+1 {{non-neg}} %4 = arith.extsi %2 : tensor<1024xi32> to tensor<1024xi64> %5 = arith.select %arg1, %arg0, %3 : !tt.ptr - // expected-remark@+2 {{unsigned : [0, 1024] signed : [0, 1024]}} + // expected-remark@+2 {{unsigned : [0, 1023] signed : [0, 1023]}} // expected-remark@+1 {{non-neg}} %6 = arith.select %arg1, %cst, %4 : tensor<1024xi64> - // expected-remark@+2 {{unsigned : [0, 1024] signed : [0, 1024]}} + // expected-remark@+2 {{unsigned : [0, 1023] signed : [0, 1023]}} // expected-remark@+1 {{non-neg}} %7 = arith.trunci %6 : tensor<1024xi64> to tensor<1024xi32> %8 = tt.splat %5 : !tt.ptr -> tensor<1024x!tt.ptr> @@ -630,24 +630,24 @@ module attributes {"ttg.num-warps" = 4 : i32} { llvm.intr.assume %cmpule_pid : i1 %1 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> %2 = tt.addptr %arg0, %0 : !tt.ptr, i32 - // expected-remark@+2 {{unsigned : [0, 1024] signed : [0, 1024]}} + // expected-remark@+2 {{unsigned : [0, 1023] signed : [0, 1023]}} // expected-remark@+1 {{non-neg}} %3 = arith.extsi %1 : tensor<1024xi32> to tensor<1024xi64> - // expected-remark@+3 {{result 1: unsigned : [0, 131072] signed : [0, 131072]}} + // expected-remark@+3 {{result 1: unsigned : [0, 130944] signed : [0, 130944]}} // expected-remark@+2 {{result 1: non-neg}} // expected-remark@+1 {{inferred total trip count: 128}} %4:3 = scf.for %arg2 = %c0 to %c128 step %c1 iter_args(%arg3 = %2, %arg4 = %3, %arg5 = %arg1) -> (!tt.ptr, tensor<1024xi64>, tensor<1024xf32>) { - // expected-remark@+2 {{unsigned : [0, 131072] signed : [0, 131072]}} + // expected-remark@+2 {{unsigned : [0, 130944] signed : [0, 130944]}} // expected-remark@+1 {{non-neg}} %11 = arith.trunci %arg4 : tensor<1024xi64> to tensor<1024xi32> %12 = tt.splat %arg3 : 
!tt.ptr -> tensor<1024x!tt.ptr> %13 = tt.addptr %12, %11 : tensor<1024x!tt.ptr>, tensor<1024xi32> %14 = tt.load %13 : tensor<1024x!tt.ptr> %15 = tt.addptr %arg3, %0 : !tt.ptr, i32 - // expected-remark@+2 {{unsigned : [0, 1024] signed : [0, 1024]}} + // expected-remark@+2 {{unsigned : [0, 1023] signed : [0, 1023]}} // expected-remark@+1 {{non-neg}} %16 = arith.extsi %1 : tensor<1024xi32> to tensor<1024xi64> - // expected-remark@+2 {{unsigned : [0, 132096] signed : [0, 132096]}} + // expected-remark@+2 {{unsigned : [0, 131967] signed : [0, 131967]}} // expected-remark@+1 {{non-neg}} %17 = arith.addi %16, %arg4 : tensor<1024xi64> %18 = tt.addptr %15, %0 : !tt.ptr, i32 @@ -655,10 +655,10 @@ module attributes {"ttg.num-warps" = 4 : i32} { scf.yield %18, %17, %19 : !tt.ptr, tensor<1024xi64>, tensor<1024xf32> } {tt.divisibility_arg1 = dense<16> : tensor<1xi32>, tt.divisibility_arg2 = dense<16> : tensor<1xi32>} %5 = tt.addptr %4#0, %0 : !tt.ptr, i32 - // expected-remark@+2 {{unsigned : [0, 1024] signed : [0, 1024]}} + // expected-remark@+2 {{unsigned : [0, 1023] signed : [0, 1023]}} // expected-remark@+1 {{non-neg}} %6 = arith.extsi %1 : tensor<1024xi32> to tensor<1024xi64> - // expected-remark@+2 {{unsigned : [0, 132096] signed : [0, 132096]}} + // expected-remark@+2 {{unsigned : [0, 131967] signed : [0, 131967]}} // expected-remark@+1 {{non-neg}} %7 = arith.addi %6, %4#1 : tensor<1024xi64> %8 = tt.splat %5 : !tt.ptr -> tensor<1024x!tt.ptr> @@ -744,24 +744,24 @@ module attributes {"ttg.num-warps" = 4 : i32} { %1 = arith.muli %0, %c1024_i32 : i32 %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> %3 = tt.addptr %arg0, %1 : !tt.ptr, i32 - // expected-remark@+2 {{unsigned : [0, 1024] signed : [0, 1024]}} + // expected-remark@+2 {{unsigned : [0, 1023] signed : [0, 1023]}} // expected-remark@+1 {{non-neg}} %4 = arith.extsi %2 : tensor<1024xi32> to tensor<1024xi64> %5 = tt.addptr %arg0, %1 : !tt.ptr, i32 - // expected-remark@+2 {{unsigned : [0, 1024] signed : [0, 1024]}} + // expected-remark@+2 {{unsigned : [0, 1023] signed : [0, 1023]}} // expected-remark@+1 {{non-neg}} %6 = arith.extsi %2 : tensor<1024xi32> to tensor<1024xi64> - // expected-remark@+5 {{result 1: unsigned : [0, 131072] signed : [0, 131072]}} - // expected-remark@+4 {{result 3: unsigned : [0, 131072] signed : [0, 131072]}} + // expected-remark@+5 {{result 1: unsigned : [0, 130944] signed : [0, 130944]}} + // expected-remark@+4 {{result 3: unsigned : [0, 130944] signed : [0, 130944]}} // expected-remark@+3 {{result 1: non-neg}} // expected-remark@+2 {{result 3: non-neg}} // expected-remark@+1 {{inferred total trip count: 128}} %7:5 = scf.for %arg2 = %c0 to %c128 step %c1 iter_args(%arg3 = %5, %arg4 = %6, %arg5 = %3, %arg6 = %4, %arg7 = %arg1) -> (!tt.ptr, tensor<1024xi64>, !tt.ptr, tensor<1024xi64>, tensor<1024xf32>) { %14 = tt.addptr %arg5, %1 : !tt.ptr, i32 - // expected-remark@+2 {{unsigned : [0, 1024] signed : [0, 1024]}} + // expected-remark@+2 {{unsigned : [0, 1023] signed : [0, 1023]}} // expected-remark@+1 {{non-neg}} %15 = arith.extsi %2 : tensor<1024xi32> to tensor<1024xi64> - // expected-remark@+2 {{unsigned : [0, 132096] signed : [0, 132096]}} + // expected-remark@+2 {{unsigned : [0, 131967] signed : [0, 131967]}} // expected-remark@+1 {{non-neg}} %16 = arith.addi %15, %arg6 : tensor<1024xi64> %17 = tt.splat %14 : !tt.ptr -> tensor<1024x!tt.ptr> @@ -771,10 +771,10 @@ module attributes {"ttg.num-warps" = 4 : i32} { scf.yield %14, %16, %arg3, %arg4, %20 : !tt.ptr, tensor<1024xi64>, !tt.ptr, 
tensor<1024xi64>, tensor<1024xf32> } %8 = tt.addptr %7#0, %1 : !tt.ptr, i32 - // expected-remark@+2 {{unsigned : [0, 1024] signed : [0, 1024]}} + // expected-remark@+2 {{unsigned : [0, 1023] signed : [0, 1023]}} // expected-remark@+1 {{non-neg}} %9 = arith.extsi %2 : tensor<1024xi32> to tensor<1024xi64> - // expected-remark@+2 {{unsigned : [0, 132096] signed : [0, 132096]}} + // expected-remark@+2 {{unsigned : [0, 131967] signed : [0, 131967]}} // expected-remark@+1 {{non-neg}} %10 = arith.addi %9, %7#1 : tensor<1024xi64> %11 = tt.splat %8 : !tt.ptr -> tensor<1024x!tt.ptr> @@ -804,24 +804,24 @@ module attributes {"ttg.num-warps" = 4 : i32} { %1 = arith.muli %0, %c1024_i32 : i32 %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> %3 = tt.addptr %arg0, %1 : !tt.ptr, i32 - // expected-remark@+2 {{unsigned : [0, 1024] signed : [0, 1024]}} + // expected-remark@+2 {{unsigned : [0, 1023] signed : [0, 1023]}} // expected-remark@+1 {{non-neg}} %4 = arith.extsi %2 : tensor<1024xi32> to tensor<1024xi64> %5 = tt.addptr %arg1, %1 : !tt.ptr, i32 - // expected-remark@+2 {{unsigned : [0, 1024] signed : [0, 1024]}} + // expected-remark@+2 {{unsigned : [0, 1023] signed : [0, 1023]}} // expected-remark@+1 {{non-neg}} %6 = arith.extsi %2 : tensor<1024xi32> to tensor<1024xi64> - // expected-remark@+5 {{result 1: unsigned : [0, 131072] signed : [0, 131072]}} - // expected-remark@+4 {{result 4: unsigned : [0, 131072] signed : [0, 131072]}} + // expected-remark@+5 {{result 1: unsigned : [0, 130944] signed : [0, 130944]}} + // expected-remark@+4 {{result 4: unsigned : [0, 130944] signed : [0, 130944]}} // expected-remark@+3 {{result 1: non-neg}} // expected-remark@+2 {{result 4: non-neg}} // expected-remark@+1 {{inferred total trip count: 128}} %7:6 = scf.for %arg3 = %c0 to %c128 step %c1 iter_args(%arg4 = %3, %arg5 = %4, %arg6 = %arg2, %arg7 = %5, %arg8 = %6, %arg9 = %arg2) -> (!tt.ptr, tensor<1024xi64>, tensor<1024xf32>, !tt.ptr, tensor<1024xi64>, tensor<1024xf32>) { %20 = tt.addptr %arg4, %1 : !tt.ptr, i32 - // expected-remark@+2 {{unsigned : [0, 1024] signed : [0, 1024]}} + // expected-remark@+2 {{unsigned : [0, 1023] signed : [0, 1023]}} // expected-remark@+1 {{non-neg}} %21 = arith.extsi %2 : tensor<1024xi32> to tensor<1024xi64> - // expected-remark@+2 {{unsigned : [0, 132096] signed : [0, 132096]}} + // expected-remark@+2 {{unsigned : [0, 131967] signed : [0, 131967]}} // expected-remark@+1 {{non-neg}} %22 = arith.addi %21, %arg5 : tensor<1024xi64> %23 = tt.splat %20 : !tt.ptr -> tensor<1024x!tt.ptr> @@ -829,10 +829,10 @@ module attributes {"ttg.num-warps" = 4 : i32} { %25 = tt.load %24 : tensor<1024x!tt.ptr> %26 = arith.addf %25, %arg6 : tensor<1024xf32> %27 = tt.addptr %arg7, %1 : !tt.ptr, i32 - // expected-remark@+2 {{unsigned : [0, 1024] signed : [0, 1024]}} + // expected-remark@+2 {{unsigned : [0, 1023] signed : [0, 1023]}} // expected-remark@+1 {{non-neg}} %28 = arith.extsi %2 : tensor<1024xi32> to tensor<1024xi64> - // expected-remark@+2 {{unsigned : [0, 132096] signed : [0, 132096]}} + // expected-remark@+2 {{unsigned : [0, 131967] signed : [0, 131967]}} // expected-remark@+1 {{non-neg}} %29 = arith.addi %28, %arg8 : tensor<1024xi64> %30 = tt.splat %27 : !tt.ptr -> tensor<1024x!tt.ptr> @@ -842,20 +842,20 @@ module attributes {"ttg.num-warps" = 4 : i32} { scf.yield %27, %29, %33, %20, %22, %26 : !tt.ptr, tensor<1024xi64>, tensor<1024xf32>, !tt.ptr, tensor<1024xi64>, tensor<1024xf32> } %8 = tt.addptr %7#0, %1 : !tt.ptr, i32 - // expected-remark@+2 {{unsigned : [0, 1024] signed : 
[0, 1024]}} + // expected-remark@+2 {{unsigned : [0, 1023] signed : [0, 1023]}} // expected-remark@+1 {{non-neg}} %9 = arith.extsi %2 : tensor<1024xi32> to tensor<1024xi64> - // expected-remark@+2 {{unsigned : [0, 132096] signed : [0, 132096]}} + // expected-remark@+2 {{unsigned : [0, 131967] signed : [0, 131967]}} // expected-remark@+1 {{non-neg}} %10 = arith.addi %9, %7#1 : tensor<1024xi64> %11 = tt.splat %8 : !tt.ptr -> tensor<1024x!tt.ptr> %12 = tt.addptr %11, %10 : tensor<1024x!tt.ptr>, tensor<1024xi64> %13 = tt.load %12 : tensor<1024x!tt.ptr> %14 = tt.addptr %7#3, %1 : !tt.ptr, i32 - // expected-remark@+2 {{unsigned : [0, 1024] signed : [0, 1024]}} + // expected-remark@+2 {{unsigned : [0, 1023] signed : [0, 1023]}} // expected-remark@+1 {{non-neg}} %15 = arith.extsi %2 : tensor<1024xi32> to tensor<1024xi64> - // expected-remark@+2 {{unsigned : [0, 132096] signed : [0, 132096]}} + // expected-remark@+2 {{unsigned : [0, 131967] signed : [0, 131967]}} // expected-remark@+1 {{non-neg}} %16 = arith.addi %15, %7#4 : tensor<1024xi64> %17 = tt.splat %14 : !tt.ptr -> tensor<1024x!tt.ptr> @@ -886,14 +886,14 @@ module attributes {"ttg.num-warps" = 4 : i32} { %1 = arith.muli %0, %c1024_i32 : i32 %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> %3 = tt.addptr %arg0, %1 : !tt.ptr, i32 - // expected-remark@+2 {{unsigned : [0, 1024] signed : [0, 1024]}} + // expected-remark@+2 {{unsigned : [0, 1023] signed : [0, 1023]}} // expected-remark@+1 {{non-neg}} %4 = arith.extsi %2 : tensor<1024xi32> to tensor<1024xi64> // expected-remark@+2 {{result 1: unsigned : [0, 18446744073709551615] signed : [-9223372036854775808, 9223372036854775807]}} // expected-remark@+1 {{inferred total trip count: 1025}} %5:3 = scf.for %arg2 = %c0 to %c128 step %K iter_args(%arg3 = %3, %arg4 = %4, %arg5 = %arg1) -> (!tt.ptr, tensor<1024xi64>, tensor<1024xf32>) { %12 = tt.addptr %arg3, %1 : !tt.ptr, i32 - // expected-remark@+2 {{unsigned : [0, 1024] signed : [0, 1024]}} + // expected-remark@+2 {{unsigned : [0, 1023] signed : [0, 1023]}} // expected-remark@+1 {{non-neg}} %13 = arith.extsi %2 : tensor<1024xi32> to tensor<1024xi64> // expected-remark@+1 {{unsigned : [0, 18446744073709551615] signed : [-9223372036854775808, 9223372036854775807]}} @@ -905,7 +905,7 @@ module attributes {"ttg.num-warps" = 4 : i32} { scf.yield %12, %14, %18 : !tt.ptr, tensor<1024xi64>, tensor<1024xf32> } %6 = tt.addptr %5#0, %1 : !tt.ptr, i32 - // expected-remark@+2 {{unsigned : [0, 1024] signed : [0, 1024]}} + // expected-remark@+2 {{unsigned : [0, 1023] signed : [0, 1023]}} // expected-remark@+1 {{non-neg}} %7 = arith.extsi %2 : tensor<1024xi32> to tensor<1024xi64> // expected-remark@+1 {{unsigned : [0, 18446744073709551615] signed : [-9223372036854775808, 9223372036854775807]}} @@ -1235,32 +1235,32 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { tt.func @join_cat_transitive_nonneg(%arg0: !tt.ptr, %arg1: !tt.ptr) { %0 = tt.make_range {end = 8 : i32, start = 0 : i32} : tensor<8xi32> %1 = tt.make_range {end = 10 : i32, start = 2 : i32} : tensor<8xi32> - // expected-remark@+2 {{unsigned : [0, 10] signed : [0, 10]}} + // expected-remark@+2 {{unsigned : [0, 9] signed : [0, 9]}} // expected-remark@+1 {{non-neg}} %2 = tt.join %0, %1 : tensor<8xi32> -> tensor<8x2xi32> %3 = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32> %4 = tt.make_range {end = 8 : i32, start = 4 : i32} : tensor<4xi32> - // expected-remark@+2 {{unsigned : [0, 8] signed : [0, 8]}} + // expected-remark@+2 {{unsigned : [0, 7] 
signed : [0, 7]}} // expected-remark@+1 {{non-neg}} %5 = tt.join %3, %4 : tensor<4xi32> -> tensor<4x2xi32> - // expected-remark@+2 {{unsigned : [0, 8] signed : [0, 8]}} + // expected-remark@+2 {{unsigned : [0, 7] signed : [0, 7]}} // expected-remark@+1 {{non-neg}} %6 = tt.cat %5, %5 : tensor<4x2xi32> -> tensor<8x2xi32> - // expected-remark@+2 {{unsigned : [0, 18] signed : [0, 18]}} + // expected-remark@+2 {{unsigned : [0, 16] signed : [0, 16]}} // expected-remark@+1 {{non-neg}} %7 = arith.addi %2, %6 : tensor<8x2xi32> %zeros = arith.constant dense<0> : tensor<8x1xi32> %ones = arith.constant dense<1> : tensor<8x1xi32> - // expected-remark@+2 {{unsigned : [0, 18] signed : [0, 18]}} + // expected-remark@+2 {{unsigned : [0, 16] signed : [0, 16]}} // expected-remark@+1 {{non-neg}} %8 = tt.gather %7[%zeros] {axis = 1 : i32} : (tensor<8x2xi32>, tensor<8x1xi32>) -> tensor<8x1xi32> - // expected-remark@+2 {{unsigned : [0, 18] signed : [0, 18]}} + // expected-remark@+2 {{unsigned : [0, 16] signed : [0, 16]}} // expected-remark@+1 {{non-neg}} %9 = tt.gather %7[%ones] {axis = 1 : i32} : (tensor<8x2xi32>, tensor<8x1xi32>) -> tensor<8x1xi32> - // expected-remark@+2 {{unsigned : [0, 36] signed : [0, 36]}} + // expected-remark@+2 {{unsigned : [0, 32] signed : [0, 32]}} // expected-remark@+1 {{non-neg}} %10 = arith.addi %8, %9 : tensor<8x1xi32> - // expected-remark@+2 {{unsigned : [0, 36] signed : [0, 36]}} + // expected-remark@+2 {{unsigned : [0, 32] signed : [0, 32]}} // expected-remark@+1 {{non-neg}} %11 = tt.reshape %10 allow_reorder : tensor<8x1xi32> -> tensor<8xi32> tt.return @@ -1316,7 +1316,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // expected-remark@+1 {{non-neg}} %6 = tt.splat %5 : i32 -> tensor<8xi32> %7 = tt.make_range {end = 8 : i32, start = 0 : i32} : tensor<8xi32> - // expected-remark@+1 {{unsigned : [0, 2147483655] signed : [-2147483648, 2147483647]}} + // expected-remark@+1 {{unsigned : [0, 2147483647] signed : [0, 2147483647]}} %8 = arith.addi %6, %7 : tensor<8xi32> tt.return } @@ -1332,31 +1332,31 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { %1 = tt.expand_dims %0 {axis = 0 : i32} : tensor<16xi32> -> tensor<1x16xi32> %2 = tt.reshape %1 allow_reorder : tensor<1x16xi32> -> tensor<8x2xi32> %3 = tt.reshape %1 allow_reorder : tensor<1x16xi32> -> tensor<2x8xi32> - // expected-remark@+2 {{unsigned : [0, 16] signed : [0, 16]}} + // expected-remark@+2 {{unsigned : [0, 15] signed : [0, 15]}} // expected-remark@+1 {{non-neg}} %4 = tt.trans %3 {order = array} : tensor<2x8xi32> -> tensor<8x2xi32> - // expected-remark@+2 {{unsigned : [0, 16] signed : [0, 16]}} + // expected-remark@+2 {{unsigned : [0, 15] signed : [0, 15]}} // expected-remark@+1 {{non-neg}} %5 = ttg.convert_layout %4 : tensor<8x2xi32> -> tensor<8x2xi32> - // expected-remark@+2 {{unsigned : [0, 32] signed : [0, 32]}} + // expected-remark@+2 {{unsigned : [0, 30] signed : [0, 30]}} // expected-remark@+1 {{non-neg}} %6 = arith.addi %5, %2 : tensor<8x2xi32> %7 = tt.make_range {end = 10 : i32, start = 2 : i32} : tensor<8xi32> - // expected-remark@+2 {{unsigned : [2, 10] signed : [2, 10]}} + // expected-remark@+2 {{unsigned : [2, 9] signed : [2, 9]}} // expected-remark@+1 {{non-neg}} %8 = ttg.convert_layout %7 : tensor<8xi32> -> tensor<8xi32> %9 = tt.expand_dims %8 {axis = 0 : i32} : tensor<8xi32> -> tensor<1x8xi32> %10 = tt.broadcast %9 : tensor<1x8xi32> -> tensor<2x8xi32> %11 = tt.reshape %10 allow_reorder : tensor<2x8xi32> -> tensor<8x2xi32> %12 = tt.splat %c10_i32 : i32 
-> tensor<8x2xi32> - // expected-remark@+2 {{unsigned : [7, 15] signed : [7, 15]}} + // expected-remark@+2 {{unsigned : [7, 14] signed : [7, 14]}} // expected-remark@+1 {{non-neg}} %13 = arith.addi %11, %12 : tensor<8x2xi32> - // expected-remark@+2 {{unsigned : [0, 15] signed : [0, 15]}} + // expected-remark@+2 {{unsigned : [0, 14] signed : [0, 14]}} // expected-remark@+1 {{non-neg}} %14 = arith.minsi %13, %5 : tensor<8x2xi32> - // expected-remark@+4 {{result 0: unsigned : [2, 10] signed : [2, 10]}} - // expected-remark@+3 {{result 1: unsigned : [2, 10] signed : [2, 10]}} + // expected-remark@+4 {{result 0: unsigned : [2, 9] signed : [2, 9]}} + // expected-remark@+3 {{result 1: unsigned : [2, 9] signed : [2, 9]}} // expected-remark@+2 {{result 0: non-neg}} // expected-remark@+1 {{result 1: non-neg}} %15, %16 = tt.split %11: tensor<8x2xi32> -> tensor<8xi32> @@ -1538,7 +1538,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.targ // expected-remark@+5 {{arg 5: unsigned : [0, 4294967295] signed : [-2147483648, 2147483647]}} // expected-remark@+4 {{arg 6: unsigned : [1, 2147483647] signed : [1, 2147483647]}} // expected-remark@+3 {{arg 7: unsigned : [0, 4294967295] signed : [-2147483648, 2147483647]}} - // expected-remark@+2 {{arg 8: unsigned : [1, 2147483647] signed : [1, 1023]}} + // expected-remark@+2 {{arg 8: unsigned : [1, 1023] signed : [1, 1023]}} // expected-remark@+1 {{arg 9: unsigned : [0, 4294967295] signed : [-2147483648, 2147483647]}} tt.func public @buffer_stride(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: !tt.ptr, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: i32, %arg7: i32, %arg8: i32, %arg9: i32) { %c1024_i32 = arith.constant 1024 : i32 @@ -1546,7 +1546,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.targ %c32_i32 = arith.constant 32 : i32 %c0_i32 = arith.constant 0 : i32 %0 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 1, parent = #blocked}>> - // expected-remark@+2 {{unsigned : [0, 256] signed : [0, 256]}} + // expected-remark@+2 {{unsigned : [0, 255] signed : [0, 255]}} // expected-remark@+1 {{non-neg}} %1 = tt.expand_dims %0 {axis = 1 : i32} : tensor<256xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<256x1xi32, #blocked> // expected-remark@+2 {{unsigned : [1, 1] signed : [-1, -1]}} @@ -1556,29 +1556,29 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.targ // expected-remark@+2 {{unsigned : [1, 2147483647] signed : [1, 2147483647]}} // expected-remark@+1 {{non-neg}} %2 = tt.splat %arg6 : i32 -> tensor<256x1xi32, #blocked> - // expected-remark@+1 {{unsigned : [0, 4294967295] signed : [-2147483648, 2147483647]}} + // expected-remark@+1 {{unsigned : [0, 2147483647] signed : [0, 2147483647]}} %3 = arith.muli %1, %2 : tensor<256x1xi32, #blocked> %4 = tt.addptr %arg0, %c32_i32 : !tt.ptr, i32 - // expected-remark@+1 {{unsigned : [0, 4294967295] signed : [-2147483648, 2147483647]}} + // expected-remark@+1 {{unsigned : [0, 2147483647] signed : [0, 2147483647]}} %5 = tt.broadcast %3 : tensor<256x1xi32, #blocked> -> tensor<256x64xi32, #blocked> %6 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked}>> - // expected-remark@+2 {{unsigned : [0, 64] signed : [0, 64]}} + // expected-remark@+2 {{unsigned : [0, 63] signed : [0, 63]}} // expected-remark@+1 {{non-neg}} %7 = tt.expand_dims %6 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x64xi32, #blocked> - // 
expected-remark@+2 {{unsigned : [0, 64] signed : [0, 64]}} + // expected-remark@+2 {{unsigned : [0, 63] signed : [0, 63]}} // expected-remark@+1 {{non-neg}} %8 = tt.broadcast %7 : tensor<1x64xi32, #blocked> -> tensor<256x64xi32, #blocked> - // expected-remark@+1 {{unsigned : [0, 4294967295] signed : [-2147483648, 2147483647]}} + // expected-remark@+1 {{unsigned : [0, 2147483647] signed : [0, 2147483647]}} %9 = arith.addi %8, %5 : tensor<256x64xi32, #blocked> %10 = tt.splat %4 : !tt.ptr -> tensor<256x64x!tt.ptr, #blocked> %11 = tt.addptr %10, %9 : tensor<256x64x!tt.ptr, #blocked>, tensor<256x64xi32, #blocked> %12 = tt.load %11 : tensor<256x64x!tt.ptr, #blocked> %13 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 1, parent = #blocked}>> %14 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked}>> - // expected-remark@+2 {{unsigned : [0, 256] signed : [0, 256]}} + // expected-remark@+2 {{unsigned : [0, 255] signed : [0, 255]}} // expected-remark@+1 {{non-neg}} %15 = tt.expand_dims %13 {axis = 1 : i32} : tensor<256xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<256x1xi32, #blocked> - // expected-remark@+2 {{unsigned : [0, 64] signed : [0, 64]}} + // expected-remark@+2 {{unsigned : [0, 63] signed : [0, 63]}} // expected-remark@+1 {{non-neg}} %16 = tt.expand_dims %14 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x64xi32, #blocked> // expected-remark@+2 {{unsigned : [1, 1] signed : [-1, -1]}} @@ -1589,21 +1589,21 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.targ // expected-remark@+1 {{result is true}} %cmp2 = arith.cmpi slt, %arg8, %c1024_i32 : i32 llvm.intr.assume %cmp2 : i1 - // expected-remark@+2 {{unsigned : [1, 2147483647] signed : [1, 1023]}} + // expected-remark@+2 {{unsigned : [1, 1023] signed : [1, 1023]}} // expected-remark@+1 {{non-neg}} %17 = tt.splat %arg8 : i32 -> tensor<256x1xi32, #blocked> - // expected-remark@+2 {{unsigned : [0, 261888] signed : [0, 261888]}} + // expected-remark@+2 {{unsigned : [0, 260865] signed : [0, 260865]}} // expected-remark@+1 {{non-neg}} %18 = arith.muli %17, %15 : tensor<256x1xi32, #blocked> %19 = tt.addptr %arg2, %c48_i32 : !tt.ptr, i32 - // expected-remark@+2 {{unsigned : [0, 261888] signed : [0, 261888]}} + // expected-remark@+2 {{unsigned : [0, 260865] signed : [0, 260865]}} // expected-remark@+1 {{non-neg}} %20 = tt.broadcast %18 : tensor<256x1xi32, #blocked> -> tensor<256x64xi32, #blocked> - // expected-remark@+2 {{unsigned : [0, 64] signed : [0, 64]}} + // expected-remark@+2 {{unsigned : [0, 63] signed : [0, 63]}} // expected-remark@+1 {{non-neg}} %21 = tt.broadcast %16 : tensor<1x64xi32, #blocked> -> tensor<256x64xi32, #blocked> %22 = tt.addptr %19, %c48_i32 : !tt.ptr, i32 - // expected-remark@+2 {{unsigned : [0, 261952] signed : [0, 261952]}} + // expected-remark@+2 {{unsigned : [0, 260928] signed : [0, 260928]}} // expected-remark@+1 {{non-neg}} %23 = arith.addi %21, %20 : tensor<256x64xi32, #blocked> %24 = tt.splat %22 : !tt.ptr -> tensor<256x64x!tt.ptr, #blocked> @@ -1652,7 +1652,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ // expected-remark@+2 {{unsigned : [0, 2097120] signed : [0, 2097120]}} // expected-remark@+1 {{non-neg}} %5 = tt.splat %3 : i32 -> tensor<32xi32, #blocked> - // expected-remark@+2 {{unsigned : [0, 2097152] signed : [0, 2097152]}} + // expected-remark@+2 {{unsigned : [0, 2097151] signed : [0, 2097151]}} // 
expected-remark@+1 {{non-neg}} %6 = arith.addi %5, %4 : tensor<32xi32, #blocked> // expected-remark@+1 {{unsigned : [0, 4294967295] signed : [-2147483648, 2147483647]}} @@ -1660,13 +1660,13 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ // expected-remark@+1 {{unsigned : [0, 4294967295] signed : [-16777216, 16777215]}} %8 = arith.divsi %7, %c128_i32 : i32 %9 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #blocked> - // expected-remark@+2 {{unsigned : [0, 2097152] signed : [0, 2097152]}} + // expected-remark@+2 {{unsigned : [0, 2097151] signed : [0, 2097151]}} // expected-remark@+1 {{non-neg}} %10 = ttg.convert_layout %6 : tensor<32xi32, #blocked> -> tensor<32xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> - // expected-remark@+2 {{unsigned : [0, 2097152] signed : [0, 2097152]}} + // expected-remark@+2 {{unsigned : [0, 2097151] signed : [0, 2097151]}} // expected-remark@+1 {{non-neg}} %11 = tt.expand_dims %10 {axis = 1 : i32} : tensor<32xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<32x1xi32, #blocked1> - // expected-remark@+2 {{unsigned : [0, 2097152] signed : [0, 2097152]}} + // expected-remark@+2 {{unsigned : [0, 2097151] signed : [0, 2097151]}} // expected-remark@+1 {{non-neg}} %12 = ttg.convert_layout %11 : tensor<32x1xi32, #blocked1> -> tensor<32x1xi32, #blocked2> // expected-remark@+1 {{unsigned : [0, 4294967295] signed : [-2147483648, 2147483647]}} @@ -1685,19 +1685,19 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ // expected-remark@+2 {{unsigned : [0, 2147483392] signed : [0, 2147483392]}} // expected-remark@+1 {{non-neg}} %27 = tt.splat %26 : i32 -> tensor<128xi32, #blocked> - // expected-remark@+2 {{unsigned : [0, 2147483520] signed : [0, 2147483520]}} + // expected-remark@+2 {{unsigned : [0, 2147483519] signed : [0, 2147483519]}} // expected-remark@+1 {{non-neg}} %28 = arith.addi %27, %9 : tensor<128xi32, #blocked> - // expected-remark@+2 {{unsigned : [0, 2147483520] signed : [0, 2147483520]}} + // expected-remark@+2 {{unsigned : [0, 2147483519] signed : [0, 2147483519]}} // expected-remark@+1 {{non-neg}} %29 = ttg.convert_layout %28 : tensor<128xi32, #blocked> -> tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked4}>> - // expected-remark@+2 {{unsigned : [0, 2147483520] signed : [0, 2147483520]}} + // expected-remark@+2 {{unsigned : [0, 2147483519] signed : [0, 2147483519]}} // expected-remark@+1 {{non-neg}} %30 = tt.expand_dims %29 {axis = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked4}>> -> tensor<1x128xi32, #blocked4> - // expected-remark@+2 {{unsigned : [0, 2147483520] signed : [0, 2147483520]}} + // expected-remark@+2 {{unsigned : [0, 2147483519] signed : [0, 2147483519]}} // expected-remark@+1 {{non-neg}} %31 = ttg.convert_layout %30 : tensor<1x128xi32, #blocked4> -> tensor<1x128xi32, #blocked3> - // expected-remark@+2 {{unsigned : [0, 2147483520] signed : [0, 2147483520]}} + // expected-remark@+2 {{unsigned : [0, 2147483519] signed : [0, 2147483519]}} // expected-remark@+1 {{non-neg}} %32 = tt.broadcast %31 : tensor<1x128xi32, #blocked3> -> tensor<32x128xi32, #blocked3> %33 = tt.addptr %18, %32 : tensor<32x128x!tt.ptr, #blocked3>, tensor<32x128xi32, #blocked3> @@ -1714,7 +1714,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ // expected-remark@+2 {{unsigned : [0, 65536] signed : [0, 65536]}} // expected-remark@+1 {{non-neg}} %20 = tt.splat %2 : i32 -> tensor<32xi32, #blocked> - // expected-remark@+1 
{{unsigned : [0, 4294967295] signed : [-2147483648, 2147483647]}} + // expected-remark@+1 {{unsigned : [0, 2147483647] signed : [0, 2147483647]}} %21 = arith.muli %6, %20 : tensor<32xi32, #blocked> %22 = tt.splat %arg1 : !tt.ptr -> tensor<32x!tt.ptr, #blocked> %23 = tt.addptr %22, %21 : tensor<32x!tt.ptr, #blocked>, tensor<32xi32, #blocked> diff --git a/third_party/amd/include/Analysis/RangeAnalysis.h b/third_party/amd/include/Analysis/RangeAnalysis.h index e14f21a89cde..be19640a878e 100644 --- a/third_party/amd/include/Analysis/RangeAnalysis.h +++ b/third_party/amd/include/Analysis/RangeAnalysis.h @@ -4,6 +4,7 @@ #include "mlir/Analysis/DataFlow/IntegerRangeAnalysis.h" #include "mlir/Analysis/DataFlow/SparseAnalysis.h" #include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/IR/Dominance.h" #include "mlir/Interfaces/LoopLikeInterface.h" namespace mlir::triton { @@ -32,15 +33,20 @@ namespace mlir::triton::AMD { /// See visitRegionSuccessors. struct TritonIntegerRangeAnalysis : dataflow::IntegerRangeAnalysis { using dataflow::IntegerRangeAnalysis::IntegerRangeAnalysis; + using Base = dataflow::IntegerRangeAnalysis; TritonIntegerRangeAnalysis( DataFlowSolver &solver, - const DenseMap> &assumptions) - : dataflow::IntegerRangeAnalysis(solver), assumptions(assumptions) {} + const DenseMap> &assumptions, + DominanceInfo *dominanceInfo) + : dataflow::IntegerRangeAnalysis(solver), assumptions(assumptions), + domInfo(dominanceInfo) {} void setToEntryState(dataflow::IntegerValueRangeLattice *lattice) override; void initializeFuncOp(triton::FuncOp funcOp); + LogicalResult initialize(Operation *top) override; + LogicalResult visitOperation( Operation *op, ArrayRef operands, @@ -95,7 +101,8 @@ struct TritonIntegerRangeAnalysis : dataflow::IntegerRangeAnalysis { /// llvm.intr.assume %assumesltlhs : i1 /// for %K, will produce a final range /// [0, 2147483647] ∩ [-2147483648, 128] = [0, 128] - std::optional maybeGetAssumedRange(Value anchor) const; + std::optional maybeGetAssumedRange(Value anchor, + Block *useBlock) const; int64_t getTotalLoopTripCount(LoopLikeOpInterface loop); @@ -125,6 +132,36 @@ struct TritonIntegerRangeAnalysis : dataflow::IntegerRangeAnalysis { /// If one uses collectAssumptions below then `assumptions` will look like /// %K -> {arith.cmpi slt..., arith.cmpi sge}. llvm::DenseMap> assumptions; + + /// The defaultTransferFunc is the default transfer function for this dataflow + /// problem. + /// @param[in] op: the Operation in question + /// @param[in] result: a particular value defined by this op. Note that op + /// may define multiple values. 
+ /// @param[in] srcLattices: lattices of all source operands + /// @param[in] destLattices: lattices all all result values + /// @param[in] incomingRange: the value-range inffered for result + void defaultTransferFunc( + Operation *op, Value result, + ArrayRef srcLattices, + ArrayRef destLattices, + const IntegerValueRange &incomingRange); + +private: + void visitYieldHelper(Operation *yieldOp, Value value); + LogicalResult visitOperationHelper( + Operation *op, + ArrayRef operands, + ArrayRef resultsLattices); + + std::optional rectifyInfferableRange( + InferIntRangeInterface interface, + ArrayRef srcLattices, + const IntegerValueRange &range); + + DenseSet signedIntValues; + llvm::SmallMapVector opResultAssumption; + DominanceInfo *domInfo = nullptr; }; std::optional>> diff --git a/third_party/amd/lib/Analysis/RangeAnalysis.cpp b/third_party/amd/lib/Analysis/RangeAnalysis.cpp index 69e483b883f2..c4b023b49298 100644 --- a/third_party/amd/lib/Analysis/RangeAnalysis.cpp +++ b/third_party/amd/lib/Analysis/RangeAnalysis.cpp @@ -4,7 +4,9 @@ #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/IR/Iterators.h" #include "mlir/Interfaces/Utils/InferIntRangeCommon.h" +#include "third_party/amd/include/Dialect/TritonAMDGPU/IR/Dialect.h" #include "triton/Dialect/Triton/IR/Dialect.h" #include "triton/Dialect/TritonGPU/IR/Dialect.h" #include "llvm/ADT/TypeSwitch.h" @@ -22,68 +24,6 @@ using namespace mlir; namespace tt = mlir::triton; -std::optional -triton::AMD::TritonIntegerRangeAnalysis::maybeGetTripCount( - LoopLikeOpInterface loop) { - std::optional lowerBound = loop.getSingleLowerBound(); - std::optional upperBound = loop.getSingleUpperBound(); - std::optional step = loop.getSingleStep(); - std::optional iv = loop.getSingleInductionVar(); - if (!iv) - return {}; - - unsigned int width = ConstantIntRanges::getStorageBitwidth(iv->getType()); - - auto getLoopRangeInfo = [&](std::optional loopBound, - Block *block, - std::optional getUpper = std::nullopt, - std::optional defaultVal = std::nullopt) { - if (loopBound.has_value()) { - if (auto attr = dyn_cast(*loopBound)) { - if (auto bound = dyn_cast_or_null(attr)) - return bound.getValue(); - } else if (auto value = llvm::dyn_cast_if_present(*loopBound)) { - const dataflow::IntegerValueRangeLattice *lattice = - getLatticeElementFor(getProgramPointBefore(block), value); - if (lattice != nullptr && !lattice->getValue().isUninitialized()) - return getUpper ? lattice->getValue().getValue().smax() - : lattice->getValue().getValue().smin(); - } - } - if (defaultVal) - return *defaultVal; - return getUpper ? APInt::getSignedMaxValue(width) - : APInt::getSignedMinValue(width); - }; - - Block *block = iv->getParentBlock(); - APInt min = getLoopRangeInfo(lowerBound, block, - /*getUpper=*/false); - APInt max = getLoopRangeInfo(upperBound, block, - /*getUpper=*/true); - // We can assume step is 1 if no range information as that gives us the upper - // bound of the number of iterations. - APInt stepValDefault = {width, 1, /*isSigned=*/true}; - APInt stepVal = - getLoopRangeInfo(step, block, /*getUpper=*/{}, stepValDefault); - - if (stepVal.isNegative()) - std::swap(min, max); - // This is necessary to catch a case like this: - // # range = [0 1024] - // K = .... - // # range = [1, 64] - // k = ... 
- // # range = [0, 16] -> stepVal = range.smin() = 0 - // step = ceildiv(K, k) - if (stepVal.isZero()) - stepVal = stepValDefault; - if (max.sge(min)) - return llvm::divideCeilSigned(max.getSExtValue() - min.getSExtValue(), - stepVal.getSExtValue()); - return {}; -} - namespace { constexpr int64_t kDefaultMaxTripCount = 1024; @@ -98,6 +38,25 @@ void getEnclosingLoops(Operation &op, SmallVector &ops) { } } +tt::FuncOp getEnclosingFunction(Value v) { + tt::FuncOp funcOp = nullptr; + + auto definingOp = v.getDefiningOp(); + if (!definingOp) + if (auto blk = v.getParentBlock()) + definingOp = blk->getParentOp(); + + if (definingOp) { + funcOp = dyn_cast_or_null(definingOp); + if (!funcOp) + funcOp = definingOp->getParentOfType(); + } + assert(funcOp && "No enclosing tt::FuncOp"); + return funcOp; +} + +Block *getFuncEntryBlock(tt::FuncOp func) { return &func.getRegion().front(); } + void inferResultRangesPID(Operation *op, uint64_t max, SetIntRangeFn setResultRange) { assert(op->getNumResults() == 1 && "expected op to have one result"); @@ -126,7 +85,7 @@ void inferResultRanges(tt::MakeRangeOp *op, SetIntRangeFn setResultRange) { /*min*/ {/*numBits*/ bitWidth, /*val*/ op->getStart(), /*isSigned*/ elTy.isSigned()}, /*max*/ - {/*numBits*/ bitWidth, /*val*/ op->getEnd(), + {/*numBits*/ bitWidth, /*val*/ op->getEnd() - 1, /*isSigned*/ elTy.isSigned()}, /*isSigned*/ elTy.isSigned())); } @@ -164,14 +123,50 @@ void inferResultRangesMaxNonNegSigned(Operation *op, } } -std::optional maybeGetAssumedRange(Operation *assumption, - Value anchor) { +// Given an assumption operaiton, try to derive the value range of the value +// 's value range at the somewhere in the block "useBlock". +// Note that +// - The value "anchor" is defined or referenced in the "useBlock" +// - The location of the reference of "anchor" in the "useBlock" does not +// matter because the IR is in SSA form, the value-range of a quantity +// does not change through out the entire block. +// - The assumption should be ignored if it does not dominate the "useBlock". +// +// Consider following cases: +// +// case 1: both s2 and s3 are applicable to s1 because they dominate s1 +// s2: assume y > 5 +// ... +// if cond +// s3: assume z < 3 +// s1: x = y + z +// +// case 2: s2 is applicable to s1 even if s2 stay after s1. +// blk: +// s1: x = y + z +// s2: assume y > 5 +// +// case 3: s2 is not applicable to s1 because the block of else-caluse does not +// domoinate the then-clause block. 
+// if cond +// s1: x = y + z +// else +// s2: assume y > 5 +// +std::optional +maybeGetAssumedRangeHelper(Operation *assumption, Value anchor, Block *useBlock, + DominanceInfo *domInfo) { + arith::CmpIOp cmpOp = llvm::dyn_cast(assumption); if (!cmpOp) { emitRemark(assumption->getLoc(), "unsupported assumption operation"); return {}; } + Block *anchorBlock = anchor.getParentBlock(); + if (!anchorBlock || !domInfo->dominates(anchorBlock, useBlock)) + return {}; + bool isSigned = true; switch (cmpOp.getPredicate()) { case arith::CmpIPredicate::uge: @@ -248,10 +243,204 @@ std::optional maybeGetAssumedRange(Operation *assumption, return {}; } +std::optional +maybeGetAssumedRange(const SetVector &allAssumptions, Value anchor, + Block *useBlock, DominanceInfo *domInfo) { + + std::optional result; + for (auto assumption : allAssumptions) { + auto tmpResult = + maybeGetAssumedRangeHelper(assumption, anchor, useBlock, domInfo); + if (!tmpResult.has_value()) + continue; + + if (result.has_value()) + result = (*result).intersection(*tmpResult); + else + result = *tmpResult; + } + + if (result) { + const auto &val = *result; + if (val.smin().isNonNegative()) { + // Consider 0 < x && x < 1024. + // When processing x > 0, the value range of x is + // vr1={umin=0, umax=0xf...f, smin=0, smax=0x7...f} + // When processing x < 1024, the value range of x is: + // vr2={umin=0, umax=0xf...f, smin=..., smax=1024} + // and + // vr1 ∩ vr2 = {umin=0, umax=0xf...f, smin=0, smax=1024} + // note that the umax=0xf...f is annoying, need to change to 1024. + return ConstantIntRanges::range(val.smin(), val.smax(), true); + } + } + return result; +} + +// arith dialect in general does not differentiate signed int and unsigned int; +// integer value is signed or unsigned depends on how it's used. +static void collectValueOfSignedInt(Operation *top, DenseSet &valueSet) { + SetVector worklist; + + // Initialize the worklist with some known signed interger values. + top->walk([&](Operation *op) { + llvm::TypeSwitch(op) + .Case( + [&](auto addPtrOp) { worklist.insert(addPtrOp.getOffset()); }) + .Case([&](auto binop) { + worklist.insert(binop.getResult()); + worklist.insert(binop.getOperand(0)); + worklist.insert(binop.getOperand(1)); + }) + .Case( + [&](auto sExt) { worklist.insert(sExt.getResult()); }) + .Case([&](auto cmpOp) { + switch (cmpOp.getPredicate()) { + case arith::CmpIPredicate::sgt: + case arith::CmpIPredicate::sge: + case arith::CmpIPredicate::sle: + case arith::CmpIPredicate::slt: + worklist.insert(cmpOp.getOperand(0)); + worklist.insert(cmpOp.getOperand(1)); + break; + case arith::CmpIPredicate::uge: + case arith::CmpIPredicate::ugt: + case arith::CmpIPredicate::ule: + case arith::CmpIPredicate::ult: + worklist.insert(cmpOp.getOperand(0)); + worklist.insert(cmpOp.getOperand(1)); + break; + default: + break; + }; + }); + }); + + valueSet.clear(); + auto addToWorklist = [&](Value v) { + if (!valueSet.count(v)) + worklist.insert(v); + }; + + while (!worklist.empty()) { + auto v = worklist.back(); + worklist.pop_back(); + Operation *op = v.getDefiningOp(); + + // If the result of this op is signed int, then its source operands are + // singed int. 
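+    // For example, if %z = arith.addi %x, %y has already been classified as
+    // signed, then %x and %y are pushed onto the worklist and classified as
+    // signed as well (the value names here are purely illustrative).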
+ if (op) { + llvm::TypeSwitch(op) + .Case([&](auto binOp) { + addToWorklist(binOp.getOperand(0)); + addToWorklist(binOp.getOperand(1)); + }) + .Case( + [&](auto unary) { addToWorklist(unary.getOperand()); }); + } + + SmallVector results; + if (op) + results = op->getResults(); + else + results.push_back(v); + + for (auto result : results) { + if (valueSet.count(result)) + continue; + + valueSet.insert(result); + + for (mlir::OpOperand &use : result.getUses()) { + llvm::TypeSwitch(use.getOwner()) + .Case( + [&](auto op) { addToWorklist(op.getResult()); }) + .Case( + [&](auto binOp) { addToWorklist(binOp.getResult()); }); + } + } + } + + LLVM_DEBUG({ + DBGS() << "Values considered as signed int (begin)\n"; + OpPrintingFlags flags; + flags.skipRegions(true); + for (auto v : valueSet) { + DBGS() << " - "; + v.print(llvm::dbgs(), flags); + llvm::dbgs() << "\n"; + } + DBGS() << "Values considered as signed int (end)\n"; + }); +} + } // namespace namespace mlir::triton::AMD { +std::optional +TritonIntegerRangeAnalysis::maybeGetTripCount(LoopLikeOpInterface loop) { + std::optional lowerBound = loop.getSingleLowerBound(); + std::optional upperBound = loop.getSingleUpperBound(); + std::optional step = loop.getSingleStep(); + std::optional iv = loop.getSingleInductionVar(); + if (!iv) + return {}; + + unsigned int width = ConstantIntRanges::getStorageBitwidth(iv->getType()); + + auto getLoopRangeInfo = [&](std::optional loopBound, + Block *block, + std::optional getUpper = std::nullopt, + std::optional defaultVal = std::nullopt) { + if (loopBound.has_value()) { + if (auto attr = dyn_cast(*loopBound)) { + if (auto bound = dyn_cast_or_null(attr)) + return bound.getValue(); + } else if (auto value = llvm::dyn_cast_if_present(*loopBound)) { + const dataflow::IntegerValueRangeLattice *lattice = + getLatticeElementFor(getProgramPointBefore(block), value); + if (lattice != nullptr && !lattice->getValue().isUninitialized()) + return getUpper ? lattice->getValue().getValue().smax() + : lattice->getValue().getValue().smin(); + } + } + if (defaultVal) + return *defaultVal; + return getUpper ? APInt::getSignedMaxValue(width) + : APInt::getSignedMinValue(width); + }; + + Block *block = iv->getParentBlock(); + APInt min = getLoopRangeInfo(lowerBound, block, + /*getUpper=*/false); + APInt max = getLoopRangeInfo(upperBound, block, + /*getUpper=*/true); + // We can assume step is 1 if no range information as that gives us the upper + // bound of the number of iterations. + APInt stepValDefault = {width, 1, /*isSigned=*/true}; + APInt stepVal = + getLoopRangeInfo(step, block, /*getUpper=*/{}, stepValDefault); + + if (stepVal.isNegative()) + std::swap(min, max); + // This is necessary to catch a case like this: + // # range = [0 1024] + // K = .... + // # range = [1, 64] + // k = ... 
+ // # range = [0, 16] -> stepVal = range.smin() = 0 + // step = ceildiv(K, k) + if (stepVal.isZero()) + stepVal = stepValDefault; + if (max.sge(min)) + return llvm::divideCeilSigned(max.getSExtValue() - min.getSExtValue(), + stepVal.getSExtValue()); + return {}; +} + bool isEmptyInitializedRange(ConstantIntRanges rv) { if (!rv.umin().getBitWidth() || !rv.umax().getBitWidth() || !rv.smin().getBitWidth() || !rv.smax().getBitWidth()) @@ -294,26 +483,20 @@ bool cmpIIsStaticallyTrue(const DataFlowSolver &solver, arith::CmpIOp cmpOp) { return false; } +LogicalResult TritonIntegerRangeAnalysis::initialize(Operation *top) { + signedIntValues.clear(); + collectValueOfSignedInt(top, signedIntValues); + return Base::initialize(top); +} + std::optional -TritonIntegerRangeAnalysis::maybeGetAssumedRange(Value anchor) const { - auto matchingAssumptions = this->assumptions.lookup(anchor); +TritonIntegerRangeAnalysis::maybeGetAssumedRange(Value anchor, + Block *useBlock) const { + const auto &matchingAssumptions = this->assumptions.lookup(anchor); if (matchingAssumptions.empty()) return {}; - unsigned bitWidth = ConstantIntRanges::getStorageBitwidth(anchor.getType()); - assert(bitWidth > 0 && "expected non-zero bitwidth"); - ConstantIntRanges constIntRange = ConstantIntRanges::maxRange(bitWidth); - if (llvm::isa_and_nonnull(anchor.getDefiningOp())) { - constIntRange = ConstantIntRanges::range( - APInt::getZero(bitWidth), - APInt(bitWidth, kDefaultMaxPrograms - 1, true), true); - } - - for (auto assumption : matchingAssumptions) { - if (auto constIntRange_ = ::maybeGetAssumedRange(assumption, anchor)) - constIntRange = constIntRange.intersection(*constIntRange_); - } - return constIntRange; + return ::maybeGetAssumedRange(matchingAssumptions, anchor, useBlock, domInfo); } int64_t @@ -333,8 +516,10 @@ void TritonIntegerRangeAnalysis::setToEntryState( if (!llvm::isa(getElementTypeOrSelf(anchor)) && !llvm::isa(getElementTypeOrSelf(anchor))) return; + + Block *entryBlock = getFuncEntryBlock(getEnclosingFunction(anchor)); IntegerValueRange range = IntegerValueRange::getMaxRange(anchor); - if (auto maybeRange = maybeGetAssumedRange(anchor)) + if (auto maybeRange = maybeGetAssumedRange(anchor, entryBlock)) range = *maybeRange; auto changed = lattice->join(range); LLVM_DEBUG({ @@ -347,52 +532,250 @@ void TritonIntegerRangeAnalysis::setToEntryState( propagateIfChanged(lattice, changed); } +void TritonIntegerRangeAnalysis::defaultTransferFunc( + Operation *op, Value resultVal, + ArrayRef srcLattices, + ArrayRef resultsLattices, + const IntegerValueRange &incomingRange) { + + // step 1: Preparation + // - Get the lattice associated with given particular result value. + // - Make a copy of value-range just inferred, as we need to do some + // change to it before it's joined to the existing lattice. + auto result = dyn_cast(resultVal); + if (!result) + return; + assert(llvm::is_contained(op->getResults(), result)); + + dataflow::IntegerValueRangeLattice *lattice = + resultsLattices[result.getResultNumber()]; + IntegerValueRange incomingRange_ = incomingRange; + + // step 2: Some range value in MLIR lib is too conservative, update the + // value-range before it is jointed to the lattice. + if (auto inferrable = dyn_cast(op)) { + auto res = rectifyInfferableRange(inferrable, srcLattices, incomingRange_); + if (res.has_value()) + incomingRange_ = std::move(*res); + } + + // step 3: If there is assumed value range, the assumed one take precedence. 
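+  // For instance, if the default inference yields the full i32 range but
+  // llvm.intr.assume constrains the value to (0, 1024), the lattice is
+  // updated with the intersected range [1, 1023].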
+  // TODO: I think this is a bit conservative; the better way is:
+  //    final_range = (old_range ∪ incomingRange) ∩ assume_range
+  if (auto iter = opResultAssumption.find(resultVal);
+      iter != opResultAssumption.end()) {
+    const auto &range = iter->second;
+    if (auto maybeRange = maybeGetAssumedRange(resultVal, op->getBlock())) {
+      incomingRange_ =
+          IntegerValueRange(incomingRange.getValue().intersection(range));
+    }
+  }
+
+  // step 4: Update the value range. Note that we are using the `join`
+  // operation, which means `union`. The transfer function must be monotone!
+  // The resolver would otherwise fall into an infinite loop.
+  ChangeResult changed = lattice->join(incomingRange_);
+  LLVM_DEBUG({
+    OpPrintingFlags flags;
+    flags.skipRegions(true);
+    DBGS() << ((changed == ChangeResult::Change) ? ">Inferred range for: "
+                                                 : ">Remain unchanged: ");
+    resultVal.printAsOperand(llvm::dbgs(), flags);
+    llvm::dbgs() << ", resulting state:" << lattice->getValue()
+                 << ", in value-range: " << incomingRange_ << "\n";
+  });
+
+  // step 5: Add the ops that depend on this op to the worklist. The resolver
+  // will iterate over all items in the worklist until it becomes empty.
+  propagateIfChanged(lattice, changed);
+}
+
+std::optional
+TritonIntegerRangeAnalysis::rectifyInfferableRange(
+    InferIntRangeInterface rface,
+    ArrayRef srcLattices,
+    const IntegerValueRange &range) {
+
+  auto op = rface.getOperation();
+
+  // step 1: rule out operations we cannot handle
+  if (!llvm::isa(op) ||
+      range.isUninitialized()) {
+    return std::nullopt;
+  }
+
+  auto isPos = [](const ConstantIntRanges &range) {
+    // Return true iff, in both the unsigned and the signed representation,
+    // the most significant bit is always 0.
+    return range.umax().isNonNegative() && range.smax().isNonNegative() &&
+           range.smin().isNonNegative();
+  };
+
+  // Not applicable to those bin-ops yielding unsigned int.
+  if (!signedIntValues.count(op->getResult(0)))
+    return std::nullopt;
+
+  // step 2: Do nothing if the value-range is already a non-negative range.
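+  // For example, a product of two non-negative i32 operands may be inferred
+  // with an unsigned max of 0xffffffff even though the result is only used
+  // as a signed quantity; the steps below tighten such a bound to 0x7fffffff.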
+ const ConstantIntRanges &resultRange = range.getValue(); + + if (isPos(resultRange)) + return std::nullopt; + + // step 3: special handling of arith::TruncIOp + if (llvm::isa(op)) { + if (!srcLattices[0] || srcLattices[0]->getValue().isUninitialized()) + return std::nullopt; + + const ConstantIntRanges srcRange = srcLattices[0]->getValue().getValue(); + if (!isPos(srcRange)) + return std::nullopt; + + // assume NSW + APInt umax = APInt::getSignedMaxValue(resultRange.umax().getBitWidth()); + return ConstantIntRanges::fromUnsigned(resultRange.umin(), umax); + } + + // step 4: rule out some messy situations + // If the MSB of umin is "1", bailout + if (!resultRange.umin().isNonNegative()) + return std::nullopt; + + // If the value-ranges of operands are somehow missing, we can do nothing + if (!srcLattices[0] || !srcLattices[1] || + srcLattices[0]->getValue().isUninitialized() || + srcLattices[1]->getValue().isUninitialized()) + return std::nullopt; + + auto opndRange0 = srcLattices[0]->getValue().getValue(); + auto opndRange1 = srcLattices[1]->getValue().getValue(); + + // bail out if one of operands' is not non-negative + if (!isPos(opndRange0) || !isPos(opndRange1)) + return std::nullopt; + + APInt umax(resultRange.umax()); + if (!umax.isNonNegative()) { + // Saturate umax to 0x7f...f + umax = APInt::getSignedMaxValue(umax.getBitWidth()); + } + + return ConstantIntRanges::fromUnsigned(resultRange.umin(), umax); +} + +void TritonIntegerRangeAnalysis::visitYieldHelper(Operation *op, Value value) { + auto yieldOp = dyn_cast(op); + LDBG("visit yieldOp: " << yieldOp); + + dataflow::IntegerValueRangeLattice *srcLattice = getLatticeElement(value); + + for (auto iter : llvm::enumerate(yieldOp->getOperands())) { + if (iter.value() != value) + continue; + + size_t idx = iter.index(); + Operation *parentOp = yieldOp->getParentOp(); + + if (auto ifOp = dyn_cast(parentOp)) { + // Get the corresponding scf.if result and its lattice + mlir::OpResult res = parentOp->getResult(idx); + dataflow::IntegerValueRangeLattice *resLattice = getLatticeElement(res); + auto changed = resLattice->join(*srcLattice); + propagateIfChanged(resLattice, changed); + + LLVM_DEBUG({ + OpPrintingFlags flags; + flags.skipRegions(true); + DBGS() << ((changed == ChangeResult::Change) + ? ">yieldOp bring change: " + : ">yieldOp bring no change:"); + res.printAsOperand(llvm::dbgs(), flags); + llvm::dbgs() << ", resulting value-range: " + << resLattice->getValue().getValue() + << ", in value-range: " + << srcLattice->getValue().getValue() << "\n"; + }); + } + } +} + LogicalResult TritonIntegerRangeAnalysis::visitOperation( Operation *op, ArrayRef operands, ArrayRef resultsLattices) { - LDBG("Inferring ranges for " << *op); - // This callback is almost exactly like the callback in - // IntegerRangeAnalysis::visitOperation except we do not "short-cicruit" the - // analysis by inferring a maximum range for loop results (instead we - // perform a check based on visit counts in visitRegionSuccessors). 
- auto joinCallback = [&op, &resultsLattices, - this](Value v, const IntegerValueRange &incomingRange) { - auto result = dyn_cast(v); - if (!result) - return; - assert(llvm::is_contained(op->getResults(), result)); - - dataflow::IntegerValueRangeLattice *lattice = - resultsLattices[result.getResultNumber()]; - IntegerValueRange incomingRange_ = incomingRange; - if (auto maybeRange = maybeGetAssumedRange(v)) { - incomingRange_ = - IntegerValueRange(incomingRange.getValue().intersection(*maybeRange)); - } - ChangeResult changed = lattice->join(incomingRange_); + + // step 1: Figure out the implied value-range of result-value. + opResultAssumption.clear(); + for (mlir::OpResult result : op->getResults()) { + auto assumedRange = maybeGetAssumedRange(result, op->getBlock()); + if (assumedRange.has_value()) + opResultAssumption.insert(std::pair(result, *assumedRange)); + } + + // step 2: call helper function inferring the value range. If assumed value- + // range is present, the transfer-function will intersect the assumed value- + // value with the inferred value range. + LogicalResult visitResult = + visitOperationHelper(op, operands, resultsLattices); + + // step 3: If previous step failed to infer value-range, apply assumed + // value-range is present. + for (auto [index, lattice] : llvm::enumerate(resultsLattices)) { + Value result = op->getResult(index); + const auto assumedIter = opResultAssumption.find(result); + if (assumedIter == opResultAssumption.end()) + continue; + + const mlir::IntegerValueRange &vr = lattice->getValue(); + if (!vr.isUninitialized() && !AMD::isEmptyInitializedRange(vr.getValue())) + continue; + + const ConstantIntRanges &assumedVr = assumedIter->second; + IntegerValueRange range(assumedVr); + auto changed = lattice->join(range); + LLVM_DEBUG({ if (changed == ChangeResult::Change) { - DBGS() << "Inferred range for "; - v.printAsOperand(llvm::dbgs(), {}); - llvm::dbgs() << " to " << incomingRange_ << "\n"; + DBGS() << ">Force apply assumed value range. value:"; + result.printAsOperand(llvm::dbgs(), {}); + llvm::dbgs() << ", range:" << range << "\n"; } }); propagateIfChanged(lattice, changed); - }; + } - // Initialize lattices with assumptions. - for (const auto &resultLattice : resultsLattices) { - if (!resultLattice->getValue().isUninitialized()) - continue; - auto anchor = resultLattice->getAnchor(); - if (auto assumptions = this->assumptions.lookup(anchor); - !assumptions.empty()) { - setToEntryState(resultLattice); - return success(); + // step 4: The dataflow framework does not understand SCF. It skip yieldOp + // as it has no result. To workaround this problem, we visit all yieldOp + // which depends on this operation. + for (int resIdx = 0, resEnd = op->getNumResults(); resIdx < resEnd; + ++resIdx) { + mlir::OpResult res = op->getResult(resIdx); + + for (mlir::OpOperand &use : res.getUses()) { + mlir::Operation *depOp = use.getOwner(); + if (auto yield = dyn_cast(depOp)) + visitYieldHelper(yield, res); } } + return visitResult; +} + +LogicalResult TritonIntegerRangeAnalysis::visitOperationHelper( + Operation *op, + ArrayRef operands, + ArrayRef resultsLattices) { + LDBG("Inferring ranges for " << *op); + + // This callback is almost exactly like the callback in + // IntegerRangeAnalysis::visitOperation except we do not "short-cicruit" the + // analysis by inferring a maximum range for loop results (instead we + // perform a check based on visit counts in visitRegionSuccessors). 
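+  // The callback below simply forwards every inferred result range to
+  // defaultTransferFunc, which applies the rectification and assumption
+  // handling described above.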
+ auto joinCallback = [&op, &operands, &resultsLattices, + this](Value v, const IntegerValueRange &incomingRange) { + this->defaultTransferFunc(op, v, operands, resultsLattices, incomingRange); + }; + // Ops with fixed/constant ranges. if (llvm::isa( op)) { @@ -418,6 +801,11 @@ LogicalResult TritonIntegerRangeAnalysis::visitOperation( return lattice->getValue(); }); + if (auto sliceOp = dyn_cast(op)) { + joinCallback(sliceOp->getResult(0), argIntValueRanges[0]); + return success(); + } + // Ops with actually changing/variable input/output ranges. if (llvm::isa(op)) { @@ -446,6 +834,10 @@ LogicalResult TritonIntegerRangeAnalysis::visitOperation( return success(); } + // TODO: It looks like inferResultRangesFromOptional does not handle bunch + // of operations very well: + // - arith.shrui, e.g. arith.shrui %arg3, %c5_i32 + // if (auto inferrable = dyn_cast(op)) { inferrable.inferResultRangesFromOptional(argIntValueRanges, joinCallback); return success(); @@ -456,16 +848,23 @@ LogicalResult TritonIntegerRangeAnalysis::visitOperation( } void TritonIntegerRangeAnalysis::initializeFuncOp(tt::FuncOp op) { + Block *entryBlock = getFuncEntryBlock(op); for (BlockArgument argument : op.getArguments()) { - if (!this->assumptions.lookup(argument).empty()) { - dataflow::IntegerValueRangeLattice *argLattice = - getLatticeElement(argument); - auto anchor = argLattice->getAnchor(); - IntegerValueRange range = IntegerValueRange::getMaxRange(anchor); - if (auto maybeRange = maybeGetAssumedRange(anchor)) - range = *maybeRange; - (void)argLattice->join(range); - } + if (!this->assumptions.count(argument)) + continue; + + dataflow::IntegerValueRangeLattice *argLattice = + getLatticeElement(argument); + + IntegerValueRange range = IntegerValueRange::getMaxRange(argument); + if (auto maybeRange = maybeGetAssumedRange(argument, entryBlock)) + range = *maybeRange; + + // The lattice must in "bottom" state, The join() operation is to set the + // state to the given "range". 
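+    // For example, given llvm.intr.assume on (%arg8 > 0) and (%arg8 < 1024),
+    // the entry lattice of %arg8 is seeded with [1, 1023] instead of the
+    // full i32 range (as exercised by the buffer_stride test).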
+ assert(argLattice->getValue().isUninitialized() && + "lattice must be in bottom state"); + (void)argLattice->join(range); } } @@ -474,7 +873,7 @@ void TritonIntegerRangeAnalysis::visitRegionSuccessors( RegionBranchPoint successor, ArrayRef abstractLattices) { LLVM_DEBUG({ - DBGS() << "Inferring ranges for "; + DBGS() << "Visit Region Succesors of "; OpPrintingFlags flags; flags.skipRegions(true); branch.print(llvm::dbgs(), flags); @@ -609,7 +1008,7 @@ struct FoldTrueCmpIOp : OpRewritePattern { using OpRewritePattern::OpRewritePattern; FoldTrueCmpIOp(MLIRContext *context, DataFlowSolver *solver) - : OpRewritePattern(context), solver(solver) {}; + : OpRewritePattern(context), solver(solver){}; LogicalResult matchAndRewrite(arith::CmpIOp cmpOp, PatternRewriter &rewriter) const override { diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/ConvertToBufferOps.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/ConvertToBufferOps.cpp index fb0892c90b1b..2c47964ca8ef 100644 --- a/third_party/amd/lib/TritonAMDGPUTransforms/ConvertToBufferOps.cpp +++ b/third_party/amd/lib/TritonAMDGPUTransforms/ConvertToBufferOps.cpp @@ -53,182 +53,40 @@ bool isSplatOneConstTensor(const Value v) { return false; } -bool verifyNonSmallerByAssumption( - Value expr, const DenseMap> &assumptions, - const std::function &matchesOther) { - if (!assumptions.contains(expr)) +bool isByteOffsetSmallerThan2GB(triton::AddPtrOp addPtrOp, std::shared_ptr solver) { + Value elemIdx = addPtrOp.getOffset(); + LDBG("Determing element index value range: " << elemIdx); + + // step 1: get the value range of the element index + const auto *lattice = solver->lookupState(elemIdx); + const mlir::IntegerValueRange &vr = lattice->getValue(); + if (vr.isUninitialized() || AMD::isEmptyInitializedRange(vr.getValue())) { + LDBG("cannot get meaningful value range"); return false; - for (Operation *assume : assumptions.at(expr)) { - auto cmpOp = llvm::dyn_cast(assume); - if (!cmpOp) - continue; - switch (cmpOp.getPredicate()) { - case arith::CmpIPredicate::eq: - case arith::CmpIPredicate::sge: - case arith::CmpIPredicate::sgt: { - if (cmpOp.getLhs() == expr && matchesOther(cmpOp.getRhs())) { - LDBG(" " << expr << " non-neg by assumption " << cmpOp); - return true; - } - break; - } - case arith::CmpIPredicate::sle: - case arith::CmpIPredicate::slt: { - if (cmpOp.getRhs() == expr && matchesOther(cmpOp.getLhs())) { - LDBG(" " << expr << " non-neg by assumption " << cmpOp); - return true; - } - break; - } - default: - break; - } - } - return false; -} - -bool verifyNonSmallerByAssumption( - Value expr, const DenseMap> &assumptions, - Value other) { - return verifyNonSmallerByAssumption( - expr, assumptions, [&](auto otherAssum) { return otherAssum == other; }); -} - -bool verifyNonNegativeExpr( - Value expr, const DenseMap> &assumptions, - std::shared_ptr solver) { - LDBG("Determing if non-negative: " << expr); - - auto nonNegativePred = [&solver](Value v) -> bool { - if (const auto *r = - solver->lookupState(v)) { - if (r->getValue().isUninitialized()) - return false; - if (AMD::isEmptyInitializedRange(r->getValue().getValue())) - return false; - } - return succeeded(dataflow::staticallyNonNegative(*solver, v)); }; - if (nonNegativePred(expr)) - return true; + const auto &smin = vr.getValue().smin(); + const auto &smax = vr.getValue().smax(); - // Recurse if the operation is defined - Operation *op = expr.getDefiningOp(); - if (!op) { - LDBG(" No defining op, assuming possibly negative"); + LDBG("Element idx range: " << smin << " : " << smax); + if 
(smin.isNegative() || smax.isNegative()) return false; - } - bool nonNegative = - llvm::TypeSwitch(expr.getDefiningOp()) - // Various unary triton ops that don't change the sign of the operand - .Case([&](auto unaryOp) { - return verifyNonNegativeExpr(unaryOp.getOperand(), assumptions, - solver); - }) - .Case([&](auto gatherOp) { - return verifyNonNegativeExpr(gatherOp.getSrc(), assumptions, - solver); - }) - // Joining two non-negative tensors is still non-negative - .Case([&](auto joinOp) { - return verifyNonNegativeExpr(joinOp.getLhs(), assumptions, - solver) && - verifyNonNegativeExpr(joinOp.getRhs(), assumptions, solver); - }) - // Returns a tensor representing histogram: histograms only contain - // buckets of non-negative values. - .Case([&](auto) { return true; }) - .Case([&](auto makeRangeOp) { - // See the warning in TritonOps.td: getStart/getEnd return unsigned, - // so we need to look through get*Attr. - return makeRangeOp.getStartAttr().getInt() >= 0 && - makeRangeOp.getEndAttr().getInt() >= 0; - }) - .Case( - [&](auto constIntOp) { return constIntOp.value() >= 0; }) - .Case([&](arith::ConstantOp constOp) { - Value val = constOp.getResult(); - DenseIntElementsAttr constVal; - if (matchPattern(val, m_Constant(&constVal)) && constVal.isSplat()) - return constVal.getSplatValue().isNonNegative(); - return false; - }) - .Case([&](auto) { - // These are defined as signless, but are actually unsigned - return true; - }) - .Case([&](auto maxOp) { - // max(a,b) >= 0 iff a>=0 || b>=0 - return verifyNonNegativeExpr(maxOp.getLhs(), assumptions, solver) || - verifyNonNegativeExpr(maxOp.getRhs(), assumptions, solver); - }) - .Case([&](auto remsiOp) { - // a % b >= 0 iff a>=0 - return verifyNonNegativeExpr(remsiOp.getLhs(), assumptions, solver); - }) - .Case([&](Operation *unaryOp) { - // a = OP b >= 0 iff b >= 0 - return verifyNonNegativeExpr(unaryOp->getOperand(0), assumptions, - solver); - }) - // Casting from arbitrary data does *not* guarantee the offset is in - // range (even if pointer, or the data is non-negative when - // interpreted as the src's type). - .Case( - [&](auto) { return false; }) - .Case( - // These OPs also return unsigned values. - // TODO: We can also sniff whether a Value is unsigned by looking - // for whether or not it's used as an argument to one of - // these OPs. - [&](auto uOp) { return true; }) - .Case( - // Generally speaking, a OP b >= 0 iff a >= 0 && b >= 0 when - // OP != sub - [&](Operation *binOp) { - return verifyNonNegativeExpr(binOp->getOperand(0), assumptions, - solver) && - verifyNonNegativeExpr(binOp->getOperand(1), assumptions, - solver); - }) - // TODO: more scf - .Case([&](auto ifOp) { - auto results = ifOp.getResults(); - auto it = std::find(results.begin(), results.end(), expr); - assert(it != results.end() && "expr should be the result of ifOp"); - auto resultIdx = it - results.begin(); - - // If we're here then we must have both then/else regions - // (each with 1 block) and each region must terminate with an - // `scf.yield` expression. 
- auto thenYield = cast(ifOp.thenYield()); - auto elseYield = cast(ifOp.elseYield()); - return verifyNonNegativeExpr(thenYield->getOperand(resultIdx), - assumptions, solver) && - verifyNonNegativeExpr(elseYield->getOperand(resultIdx), - assumptions, solver); - }) - .Case([&](auto op) { - // If a user annotates tl.assume(a >= b) then we know a - b >= 0 - return verifyNonSmallerByAssumption(op.getLhs(), assumptions, - op.getRhs()); - }) - .Case([&](auto op) { - return verifyNonNegativeExpr(op->getOperand(0), assumptions, - solver); - }) - .Default([&](Operation *) { - // Conservatively assume that the expression is negative - LDBG(" Unhandled op, cannot assume non-negative"); - return false; - }); - return nonNegative; + // step 2: get element size + Type elemTy = getElementTypeOrSelf(addPtrOp.getType()); + while (auto ptrTy = dyn_cast(elemTy)) + elemTy = ptrTy.getPointeeType(); + + // step 3: check of byte-offset is within 2G + int64_t elemBitSz = elemTy.getIntOrFloatBitWidth(); + int64_t elemMaxIdx = smax.getSExtValue(); + int64_t byteOfst = (elemBitSz * elemMaxIdx + elemBitSz + 7)/8; + int64_t szLimit2GB = (1L << 31) - 1; + + LDBG("element bit sz:" << elemBitSz << ", max byte offset:" << byteOfst << + ((szLimit2GB > byteOfst) ? ", out or range" : ",in range")); + + return byteOfst <= szLimit2GB ; } bool isFuncArgWith32bitPtrRange(mlir::Value value) { @@ -294,7 +152,8 @@ bool canUseBufferOps(Value ptr, LDBG("base-ptr as tt.pointer_range=32 attribute"); return true; } - return verifyNonNegativeExpr(offset, assumptions, std::move(solver)); + + return isByteOffsetSmallerThan2GB(addPtrOp, std::move(solver)); } // Extract stride of the blocked offset of LD/ST ops. @@ -718,8 +577,10 @@ struct TritonAMDGPUConvertToBufferOpsPass DenseMap> assumptions = AMD::TritonIntegerRangeAnalysis::collectAssumptions(getOperation()); std::shared_ptr solver = createDataFlowSolver(); + AMD::TritonIntegerRangeAnalysis *rangeAnalysis = - solver->load(assumptions); + solver->load(assumptions, + &getAnalysis()); AMD::initializeFuncOps(mod, rangeAnalysis); if (failed(solver->initializeAndRun(getOperation()))) return signalPassFailure(); diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/FoldTrueCmpIOp.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/FoldTrueCmpIOp.cpp index 37f411f31403..3531f81f66b2 100644 --- a/third_party/amd/lib/TritonAMDGPUTransforms/FoldTrueCmpIOp.cpp +++ b/third_party/amd/lib/TritonAMDGPUTransforms/FoldTrueCmpIOp.cpp @@ -19,7 +19,8 @@ struct TritonAMDFoldTrueCmpIOpPass ModuleOp mod = getOperation(); std::unique_ptr solver = createDataFlowSolver(); AMD::TritonIntegerRangeAnalysis *rangeAnalysis = - solver->load(assumptions); + solver->load(assumptions, & + getAnalysis()); AMD::initializeFuncOps(mod, rangeAnalysis); if (failed(solver->initializeAndRun(getOperation()))) return signalPassFailure(); diff --git a/third_party/amd/test/lib/Analysis/TestAMDRangeAnalysis.cpp b/third_party/amd/test/lib/Analysis/TestAMDRangeAnalysis.cpp index df20182df198..c4d2b7e43710 100644 --- a/third_party/amd/test/lib/Analysis/TestAMDRangeAnalysis.cpp +++ b/third_party/amd/test/lib/Analysis/TestAMDRangeAnalysis.cpp @@ -31,7 +31,8 @@ struct TestAMDRangeAnalysisPass AMD::TritonIntegerRangeAnalysis::collectAssumptions(getOperation()); std::shared_ptr solver = createDataFlowSolver(); AMD::TritonIntegerRangeAnalysis *rangeAnalysis = - solver->load(assumptions); + solver->load(assumptions, + &getAnalysis()); AMD::initializeFuncOps(mod, rangeAnalysis); if (failed(solver->initializeAndRun(getOperation()))) return 
signalPassFailure(); From 8f0746dee2a86835141b3d2a17f413b48607db4a Mon Sep 17 00:00:00 2001 From: Shuxin Yang Date: Sun, 5 Oct 2025 10:05:21 -0700 Subject: [PATCH 2/9] git format --- .../amd/lib/Analysis/RangeAnalysis.cpp | 2 +- .../ConvertToBufferOps.cpp | 19 +++++++++++-------- .../TritonAMDGPUTransforms/FoldTrueCmpIOp.cpp | 4 ++-- .../lib/Analysis/TestAMDRangeAnalysis.cpp | 4 ++-- 4 files changed, 16 insertions(+), 13 deletions(-) diff --git a/third_party/amd/lib/Analysis/RangeAnalysis.cpp b/third_party/amd/lib/Analysis/RangeAnalysis.cpp index c4b023b49298..fd52b006b01d 100644 --- a/third_party/amd/lib/Analysis/RangeAnalysis.cpp +++ b/third_party/amd/lib/Analysis/RangeAnalysis.cpp @@ -1008,7 +1008,7 @@ struct FoldTrueCmpIOp : OpRewritePattern { using OpRewritePattern::OpRewritePattern; FoldTrueCmpIOp(MLIRContext *context, DataFlowSolver *solver) - : OpRewritePattern(context), solver(solver){}; + : OpRewritePattern(context), solver(solver) {}; LogicalResult matchAndRewrite(arith::CmpIOp cmpOp, PatternRewriter &rewriter) const override { diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/ConvertToBufferOps.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/ConvertToBufferOps.cpp index 2c47964ca8ef..b274eb74e692 100644 --- a/third_party/amd/lib/TritonAMDGPUTransforms/ConvertToBufferOps.cpp +++ b/third_party/amd/lib/TritonAMDGPUTransforms/ConvertToBufferOps.cpp @@ -53,12 +53,14 @@ bool isSplatOneConstTensor(const Value v) { return false; } -bool isByteOffsetSmallerThan2GB(triton::AddPtrOp addPtrOp, std::shared_ptr solver) { +bool isByteOffsetSmallerThan2GB(triton::AddPtrOp addPtrOp, + std::shared_ptr solver) { Value elemIdx = addPtrOp.getOffset(); LDBG("Determing element index value range: " << elemIdx); // step 1: get the value range of the element index - const auto *lattice = solver->lookupState(elemIdx); + const auto *lattice = + solver->lookupState(elemIdx); const mlir::IntegerValueRange &vr = lattice->getValue(); if (vr.isUninitialized() || AMD::isEmptyInitializedRange(vr.getValue())) { LDBG("cannot get meaningful value range"); @@ -80,13 +82,14 @@ bool isByteOffsetSmallerThan2GB(triton::AddPtrOp addPtrOp, std::shared_ptr byteOfst) ? ", out or range" : ",in range")); + LDBG("element bit sz:" << elemBitSz << ", max byte offset:" << byteOfst + << ((szLimit2GB > byteOfst) ? 
", out or range" + : ",in range")); - return byteOfst <= szLimit2GB ; + return byteOfst <= szLimit2GB; } bool isFuncArgWith32bitPtrRange(mlir::Value value) { @@ -579,8 +582,8 @@ struct TritonAMDGPUConvertToBufferOpsPass std::shared_ptr solver = createDataFlowSolver(); AMD::TritonIntegerRangeAnalysis *rangeAnalysis = - solver->load(assumptions, - &getAnalysis()); + solver->load( + assumptions, &getAnalysis()); AMD::initializeFuncOps(mod, rangeAnalysis); if (failed(solver->initializeAndRun(getOperation()))) return signalPassFailure(); diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/FoldTrueCmpIOp.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/FoldTrueCmpIOp.cpp index 3531f81f66b2..07defebe6cd6 100644 --- a/third_party/amd/lib/TritonAMDGPUTransforms/FoldTrueCmpIOp.cpp +++ b/third_party/amd/lib/TritonAMDGPUTransforms/FoldTrueCmpIOp.cpp @@ -19,8 +19,8 @@ struct TritonAMDFoldTrueCmpIOpPass ModuleOp mod = getOperation(); std::unique_ptr solver = createDataFlowSolver(); AMD::TritonIntegerRangeAnalysis *rangeAnalysis = - solver->load(assumptions, & - getAnalysis()); + solver->load( + assumptions, &getAnalysis()); AMD::initializeFuncOps(mod, rangeAnalysis); if (failed(solver->initializeAndRun(getOperation()))) return signalPassFailure(); diff --git a/third_party/amd/test/lib/Analysis/TestAMDRangeAnalysis.cpp b/third_party/amd/test/lib/Analysis/TestAMDRangeAnalysis.cpp index c4d2b7e43710..da911bef13f6 100644 --- a/third_party/amd/test/lib/Analysis/TestAMDRangeAnalysis.cpp +++ b/third_party/amd/test/lib/Analysis/TestAMDRangeAnalysis.cpp @@ -31,8 +31,8 @@ struct TestAMDRangeAnalysisPass AMD::TritonIntegerRangeAnalysis::collectAssumptions(getOperation()); std::shared_ptr solver = createDataFlowSolver(); AMD::TritonIntegerRangeAnalysis *rangeAnalysis = - solver->load(assumptions, - &getAnalysis()); + solver->load( + assumptions, &getAnalysis()); AMD::initializeFuncOps(mod, rangeAnalysis); if (failed(solver->initializeAndRun(getOperation()))) return signalPassFailure(); From 60496cf28f28e356379625e9212d94408e0c7849 Mon Sep 17 00:00:00 2001 From: Shuxin Yang Date: Sun, 5 Oct 2025 10:47:14 -0700 Subject: [PATCH 3/9] fix crash --- .../lib/TritonAMDGPUTransforms/ConvertToBufferOps.cpp | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/ConvertToBufferOps.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/ConvertToBufferOps.cpp index b274eb74e692..134d8835b968 100644 --- a/third_party/amd/lib/TritonAMDGPUTransforms/ConvertToBufferOps.cpp +++ b/third_party/amd/lib/TritonAMDGPUTransforms/ConvertToBufferOps.cpp @@ -61,9 +61,16 @@ bool isByteOffsetSmallerThan2GB(triton::AddPtrOp addPtrOp, // step 1: get the value range of the element index const auto *lattice = solver->lookupState(elemIdx); + if (!lattice) { + // Note not always able to get lattice, e.g. the offset is obtained from + // tt.load. 
+ LDBG("cannot get lattice associated with the offset"); + return false; + } + const mlir::IntegerValueRange &vr = lattice->getValue(); if (vr.isUninitialized() || AMD::isEmptyInitializedRange(vr.getValue())) { - LDBG("cannot get meaningful value range"); + LDBG("Cannot get value range of the offset"); return false; }; From 8ad6a23c963ccc08ed4d005cc6834e4ed0b278f5 Mon Sep 17 00:00:00 2001 From: Shuxin Yang Date: Mon, 6 Oct 2025 23:34:06 -0700 Subject: [PATCH 4/9] fix potential bugs and add tests --- test/TritonGPU/amd/amd-range-analysis.mlir | 251 ++++++++++++++++++ .../amd/include/Analysis/RangeAnalysis.h | 5 +- .../amd/lib/Analysis/RangeAnalysis.cpp | 96 ++++++- .../ConvertToBufferOps.cpp | 22 +- .../lib/Analysis/TestAMDRangeAnalysis.cpp | 3 +- 5 files changed, 359 insertions(+), 18 deletions(-) diff --git a/test/TritonGPU/amd/amd-range-analysis.mlir b/test/TritonGPU/amd/amd-range-analysis.mlir index 92498ed94c17..2904d2709637 100644 --- a/test/TritonGPU/amd/amd-range-analysis.mlir +++ b/test/TritonGPU/amd/amd-range-analysis.mlir @@ -1726,3 +1726,254 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ tt.return } } + +// ----- + +//def scfif_range1(x, y, output_ptr,n_elements, BLOCK_SIZE: tl.constexpr, ): +// tl.assume(y < 100) +// tl.assume(y > 1) +// pid = tl.program_id(axis=0) +// block_start = pid * BLOCK_SIZE +// offsets = block_start + tl.arange(0, BLOCK_SIZE) +// mask = offsets < n_elements +// if x > y: +// z = x + 3 +// else: +// z = y + 4; # to check z in [6, 103] +// z2 = z + 1 # to check z2 in [0, umax]/[smin, smax] +// tl.store(output_ptr + offsets, z2, mask) +// +#blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx1100", "ttg.threads-per-warp" = 32 : i32} { + tt.func public @scfif_range1(%x: i32, %y: i32, %output_ptr: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %n_elements: i32 {tt.divisibility = 16 : i32}) { + %c4_i32 = arith.constant 4 : i32 + %c3_i32 = arith.constant 3 : i32 + %c1024_i32 = arith.constant 1024 : i32 + %c1_i32 = arith.constant 1 : i32 + %c100_i32 = arith.constant 100 : i32 + %0 = arith.cmpi slt, %y, %c100_i32 : i32 + llvm.intr.assume %0 : i1 + %1 = arith.cmpi sgt, %y, %c1_i32 : i32 + llvm.intr.assume %1 : i1 + %2 = tt.get_program_id x : i32 + %3 = arith.muli %2, %c1024_i32 : i32 + %4 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked> + %5 = tt.splat %3 : i32 -> tensor<1024xi32, #blocked> + %6 = arith.addi %5, %4 : tensor<1024xi32, #blocked> + %7 = tt.splat %n_elements : i32 -> tensor<1024xi32, #blocked> + %8 = arith.cmpi slt, %6, %7 : tensor<1024xi32, #blocked> + %9 = arith.cmpi sgt, %x, %y : i32 + %10 = scf.if %9 -> (i32) { + %z = arith.addi %x, %c3_i32 : i32 + scf.yield %z : i32 + } else { + // expected-remark@+1 {{unsigned : [6, 103] signed : [6, 103]}} + %z = arith.addi %y, %c4_i32 : i32 + scf.yield %z : i32 + } + // expected-remark@+1 {{unsigned : [0, 4294967295] signed : [-2147483648, 2147483647]}} + %11 = arith.addi %10, %c1_i32 : i32 + %12 = arith.addi %5, %4 : tensor<1024xi32, #blocked> + %13 = arith.sitofp %11 : i32 to f32 + %14 = tt.splat %13 : f32 -> tensor<1024xf32, #blocked> + %15 = tt.splat %output_ptr : !tt.ptr -> tensor<1024x!tt.ptr, #blocked> + %16 = tt.addptr %15, %12 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> + tt.store %16, %14, %8 : tensor<1024x!tt.ptr, #blocked> + tt.return + } +} + +// 
----- + +//def scfif_range2(x, y, output_ptr,n_elements, BLOCK_SIZE: tl.constexpr, ): +// tl.assume(y < 100) +// tl.assume(y > 1) +// tl.assume(x < 20) +// tl.assume(x > 0) +// pid = tl.program_id(axis=0) +// block_start = pid * BLOCK_SIZE +// offsets = block_start + tl.arange(0, BLOCK_SIZE) +// mask = offsets < n_elements +// if x > y: +// z = x + 3 // check z in [4, 22] +// else: +// z = y + 4; // check z in [6, 103] +// z2 = z + 1 // check z2 in [5, 104] +// tl.store(output_ptr + offsets, z2, mask) + +#blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx1100", "ttg.threads-per-warp" = 32 : i32} { + tt.func public @scfif_range2(%x: i32, %y: i32, %output_ptr: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %n_elements: i32 {tt.divisibility = 16 : i32}) { + %c4_i32 = arith.constant 4 : i32 + %c3_i32 = arith.constant 3 : i32 + %c1024_i32 = arith.constant 1024 : i32 + %c0_i32 = arith.constant 0 : i32 + %c20_i32 = arith.constant 20 : i32 + %c1_i32 = arith.constant 1 : i32 + %c100_i32 = arith.constant 100 : i32 + %0 = arith.cmpi slt, %y, %c100_i32 : i32 + llvm.intr.assume %0 : i1 + %1 = arith.cmpi sgt, %y, %c1_i32 : i32 + llvm.intr.assume %1 : i1 + %2 = arith.cmpi slt, %x, %c20_i32 : i32 + llvm.intr.assume %2 : i1 + %3 = arith.cmpi sgt, %x, %c0_i32 : i32 + llvm.intr.assume %3 : i1 + %4 = tt.get_program_id x : i32 + %5 = arith.muli %4, %c1024_i32 : i32 + %6 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked> + %7 = tt.splat %5 : i32 -> tensor<1024xi32, #blocked> + %8 = arith.addi %7, %6 : tensor<1024xi32, #blocked> + %9 = tt.splat %n_elements : i32 -> tensor<1024xi32, #blocked> + %10 = arith.cmpi slt, %8, %9 : tensor<1024xi32, #blocked> + %11 = arith.cmpi sgt, %x, %y : i32 + %12 = scf.if %11 -> (i32) { + // expected-remark@+1 {{unsigned : [4, 22] signed : [4, 22]}} + %z = arith.addi %x, %c3_i32 : i32 + scf.yield %z : i32 + } else { + // expected-remark@+1 {{unsigned : [6, 103] signed : [6, 103]}} + %z = arith.addi %y, %c4_i32 : i32 + scf.yield %z : i32 + } + // expected-remark@+1 {{unsigned : [5, 104] signed : [5, 104]}} + %13 = arith.addi %12, %c1_i32 : i32 + %14 = arith.addi %7, %6 : tensor<1024xi32, #blocked> + %15 = arith.sitofp %13 : i32 to f32 + %16 = tt.splat %15 : f32 -> tensor<1024xf32, #blocked> + %17 = tt.splat %output_ptr : !tt.ptr -> tensor<1024x!tt.ptr, #blocked> + %18 = tt.addptr %17, %14 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> + tt.store %18, %16, %10 : tensor<1024x!tt.ptr, #blocked> + tt.return + } +} + +// ----- + +//def scfif_range3(x, y, output_ptr,n_elements, BLOCK_SIZE: tl.constexpr, ): +// tl.assume(y < 100) +// tl.assume(y > 1) +// pid = tl.program_id(axis=0) +// block_start = pid * BLOCK_SIZE +// offsets = block_start + tl.arange(0, BLOCK_SIZE) +// mask = offsets < n_elements +// if x > y: +// z = x + 3 +// else: +// tl.assume(x < 20) # should not have impact to the x occurrences in then block! 
+// tl.assume(x > 0) +// z = y + 4; +// z2 = z + 1 +// tl.store(output_ptr + offsets, z2, mask) + +#blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx1100", "ttg.threads-per-warp" = 32 : i32} { + tt.func public @scfif_range3(%x: i32, %y: i32, %output_ptr: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %n_elements: i32 {tt.divisibility = 16 : i32}) { + %c4_i32 = arith.constant 4 : i32 + %c0_i32 = arith.constant 0 : i32 + %c20_i32 = arith.constant 20 : i32 + %c3_i32 = arith.constant 3 : i32 + %c1024_i32 = arith.constant 1024 : i32 + %c1_i32 = arith.constant 1 : i32 + %c100_i32 = arith.constant 100 : i32 + %0 = arith.cmpi slt, %y, %c100_i32 : i32 + llvm.intr.assume %0 : i1 + %1 = arith.cmpi sgt, %y, %c1_i32 : i32 + llvm.intr.assume %1 : i1 + %2 = tt.get_program_id x : i32 + %3 = arith.muli %2, %c1024_i32 : i32 + %4 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked> + %5 = tt.splat %3 : i32 -> tensor<1024xi32, #blocked> + %6 = arith.addi %5, %4 : tensor<1024xi32, #blocked> + %7 = tt.splat %n_elements : i32 -> tensor<1024xi32, #blocked> + %8 = arith.cmpi slt, %6, %7 : tensor<1024xi32, #blocked> + %9 = arith.cmpi sgt, %x, %y : i32 + %10 = scf.if %9 -> (i32) { + // expected-remark@+1 {{[0, 4294967295] signed : [-2147483648, 2147483647]}} + %z = arith.addi %x, %c3_i32 : i32 + scf.yield %z : i32 + } else { + %17 = arith.cmpi slt, %x, %c20_i32 : i32 + llvm.intr.assume %17 : i1 + %18 = arith.cmpi sgt, %x, %c0_i32 : i32 + llvm.intr.assume %18 : i1 + // expected-remark@+1 {{[6, 103] signed : [6, 103]}} + %z = arith.addi %y, %c4_i32 : i32 + scf.yield %z : i32 + } + // expected-remark@+1 {{[0, 4294967295] signed : [-2147483648, 2147483647]}} + %11 = arith.addi %10, %c1_i32 : i32 + %12 = arith.addi %5, %4 : tensor<1024xi32, #blocked> + %13 = arith.sitofp %11 : i32 to f32 + %14 = tt.splat %13 : f32 -> tensor<1024xf32, #blocked> + %15 = tt.splat %output_ptr : !tt.ptr -> tensor<1024x!tt.ptr, #blocked> + %16 = tt.addptr %15, %12 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> + tt.store %16, %14, %8 : tensor<1024x!tt.ptr, #blocked> + tt.return + } +} + +// ----- + +//def scfif_range4(x, y, output_ptr,n_elements, BLOCK_SIZE: tl.constexpr, ): +// tl.assume(y < 100) +// tl.assume(y > 1) +// pid = tl.program_id(axis=0) +// block_start = pid * BLOCK_SIZE +// offsets = block_start + tl.arange(0, BLOCK_SIZE) +// mask = offsets < n_elements +// if x > y: +// z = x + 3 // check the tl.assume is applicable to this statement +// tl.assume(x < 20) +// tl.assume(x > 0) +// else: +// z = y + 4; +// z2 = z + 1 +// tl.store(output_ptr + offsets, z2, mask) + +#blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx1100", "ttg.threads-per-warp" = 32 : i32} { + tt.func public @scfif_range4(%x: i32 loc("x"), %y: i32 loc("y"), %output_ptr: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32} loc("output_ptr"), %n_elements: i32 {tt.divisibility = 16 : i32} loc("n_elements")) attributes {noinline = false} { + %c4_i32 = arith.constant 4 : i32 + %c0_i32 = arith.constant 0 : i32 + %c20_i32 = arith.constant 20 : i32 + %c3_i32 = arith.constant 3 : i32 + %c1024_i32 = arith.constant 1024 : i32 + %c1_i32 = arith.constant 1 : i32 + %c100_i32 = arith.constant 100 : i32 + %0 = arith.cmpi 
slt, %y, %c100_i32 : i32 + llvm.intr.assume %0 : i1 + %1 = arith.cmpi sgt, %y, %c1_i32 : i32 + llvm.intr.assume %1 : i1 + %2 = tt.get_program_id x : i32 + %3 = arith.muli %2, %c1024_i32 : i32 + %4 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked> + %5 = tt.splat %3 : i32 -> tensor<1024xi32, #blocked> + %6 = arith.addi %5, %4 : tensor<1024xi32, #blocked> + %7 = tt.splat %n_elements : i32 -> tensor<1024xi32, #blocked> + %8 = arith.cmpi slt, %6, %7 : tensor<1024xi32, #blocked> + %9 = arith.cmpi sgt, %x, %y : i32 + %10 = scf.if %9 -> (i32) { + %17 = arith.cmpi slt, %x, %c20_i32 : i32 + llvm.intr.assume %17 : i1 + %18 = arith.cmpi sgt, %x, %c0_i32 : i32 + llvm.intr.assume %18 : i1 + // expected-remark@+1 {{unsigned : [4, 22] signed : [4, 22]}} + %z = arith.addi %x, %c3_i32 : i32 + scf.yield %z : i32 + } else { + // expected-remark@+1 {{unsigned : [6, 103] signed : [6, 103]}} + %z = arith.addi %y, %c4_i32 : i32 + scf.yield %z : i32 + } + // expected-remark@+1 {{unsigned : [5, 104] signed : [5, 104]}} + %11 = arith.addi %10, %c1_i32 : i32 + %12 = arith.addi %5, %4 : tensor<1024xi32, #blocked> + %13 = arith.sitofp %11 : i32 to f32 + %14 = tt.splat %13 : f32 -> tensor<1024xf32, #blocked> + %15 = tt.splat %output_ptr : !tt.ptr -> tensor<1024x!tt.ptr, #blocked> + %16 = tt.addptr %15, %12 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> + tt.store %16, %14, %8 : tensor<1024x!tt.ptr, #blocked> + tt.return + } +} diff --git a/third_party/amd/include/Analysis/RangeAnalysis.h b/third_party/amd/include/Analysis/RangeAnalysis.h index be19640a878e..37370bd5a9e0 100644 --- a/third_party/amd/include/Analysis/RangeAnalysis.h +++ b/third_party/amd/include/Analysis/RangeAnalysis.h @@ -37,9 +37,9 @@ struct TritonIntegerRangeAnalysis : dataflow::IntegerRangeAnalysis { TritonIntegerRangeAnalysis( DataFlowSolver &solver, const DenseMap> &assumptions, - DominanceInfo *dominanceInfo) + DominanceInfo *dominanceInfo, bool assumeNoArithOverflow_ = false) : dataflow::IntegerRangeAnalysis(solver), assumptions(assumptions), - domInfo(dominanceInfo) {} + domInfo(dominanceInfo), assumeNoArithOverflow(assumeNoArithOverflow_) {} void setToEntryState(dataflow::IntegerValueRangeLattice *lattice) override; @@ -162,6 +162,7 @@ struct TritonIntegerRangeAnalysis : dataflow::IntegerRangeAnalysis { DenseSet signedIntValues; llvm::SmallMapVector opResultAssumption; DominanceInfo *domInfo = nullptr; + bool assumeNoArithOverflow = false; }; std::optional>> diff --git a/third_party/amd/lib/Analysis/RangeAnalysis.cpp b/third_party/amd/lib/Analysis/RangeAnalysis.cpp index fd52b006b01d..b6e5ce53f145 100644 --- a/third_party/amd/lib/Analysis/RangeAnalysis.cpp +++ b/third_party/amd/lib/Analysis/RangeAnalysis.cpp @@ -15,6 +15,53 @@ #include #include +// Some implementation notes: +// 1: tl.assume statements +// - A value may have multiple assume-operations (assume-ops for short) +// associated with it. At point p, we only take into account those assume-ops +// whose enclosing basic blocks dominate the basic-block where p belong to. +// - See some examples in the comment to maybeGetAssumedRangeHelper(). +// - The assumed value-range is inferred right before an operation is visited, +// but the assume value-range does not directly apply to the lattice. +// - After the operation is visited, if lattices of the results remain in +// "bottom" state (i.e. uninitialized), the assumed value-range will be used +// instead. 
+// - The transfer-function is supposed to apply intersection between inferred +// value-range and assume value-range. However, the +// IntegerValueRangeLattice::meet() seems to be a silent no-op. For now, +// the transfer-function only use assumed-value range. Intersecting the +// inferred value-range with assume value-range still guarantee monotonicity +// of the transfer function. +// +// 2. SCF. +// - Unfortunately, the data-flow framework does not understand SCF! Running +// this pass together DCE and constant-propagation analysis make things even +// more complicated. +// - The override function visitOperaion() does not visit yieldOp, because +// it has no result and is ignored by the framework. +// - On top of that, while the framework provides a way to visit SCF's +// LHs/incoming-value (via visitRegionSuccessors), it does not provide a hook +// for processing the RHS. +// - To workaround this problem, right after an operation is visited, we check +// if it is used by a yield-op. If so, process the yield-op immediately. +// - Here is the steps about how SCF.if is handled. +// +// x, y = scf.if cond { +// then-clause +// yield a, b +// } else { +// else-clause +// yield c, d +// } +// z = add x, y +// +// o. visitRegionSuccessors(scf.if) is called, which does nothing. Once it +// returns, the base-class return immediately, and hence the +// visitOperation(scf.if) is not called. +// o. The DCE analysis initially mark blocks of then- and else-clause as +// "dead", and they are skipped for now. +// o. +// #undef DEBUG_TYPE #define DEBUG_TYPE "tritonamdgpu-range-analysis" #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") @@ -80,6 +127,7 @@ void inferResultRanges(tt::MakeRangeOp *op, SetIntRangeFn setResultRange) { assert(llvm::isa(resTy.getElementType()) && "expected int type"); IntegerType elTy = llvm::cast(resTy.getElementType()); auto bitWidth = mlir::ConstantIntRanges::getStorageBitwidth(elTy); + // NOTE: make_range(begin, end) yields a half open interval, [begin, end). setResultRange(result, ConstantIntRanges::range( /*min*/ {/*numBits*/ bitWidth, /*val*/ op->getStart(), @@ -163,8 +211,9 @@ maybeGetAssumedRangeHelper(Operation *assumption, Value anchor, Block *useBlock, return {}; } - Block *anchorBlock = anchor.getParentBlock(); - if (!anchorBlock || !domInfo->dominates(anchorBlock, useBlock)) + // The block where tl.assume resides must dominate the block where the value + // is referenced! + if (!useBlock || !domInfo->dominates(cmpOp->getBlock(), useBlock)) return {}; bool isSigned = true; @@ -277,8 +326,17 @@ maybeGetAssumedRange(const SetVector &allAssumptions, Value anchor, return result; } -// arith dialect in general does not differentiate signed int and unsigned int; -// integer value is signed or unsigned depends on how it's used. +// Many operations in arith dialect do not differentiate signed int and unsigned +// int, e.g., arith::AddIOp, arith::MullOp. This function try to extrapolate the +// type (sint or uint) of the Operation from the its UD and DU chains. +// +// TODO: This function seems to be useful for proving a quantity is a +// non-negative. However, it is less so in proving a quantity is smaller than +// specified upper bound. In fact, turning off this feature only sees 5 lines +// difference in amd-range-analysis.mlir. For now, it is turned on only in +// TestAMDRangeAnalysis.cpp. For now, we just keep this code for a while and +// see if it will be useful for some real world applications. 
+// static void collectValueOfSignedInt(Operation *top, DenseSet &valueSet) { SetVector worklist; @@ -485,7 +543,8 @@ bool cmpIIsStaticallyTrue(const DataFlowSolver &solver, arith::CmpIOp cmpOp) { LogicalResult TritonIntegerRangeAnalysis::initialize(Operation *top) { signedIntValues.clear(); - collectValueOfSignedInt(top, signedIntValues); + if (assumeNoArithOverflow) + collectValueOfSignedInt(top, signedIntValues); return Base::initialize(top); } @@ -704,7 +763,7 @@ LogicalResult TritonIntegerRangeAnalysis::visitOperation( ArrayRef operands, ArrayRef resultsLattices) { - // step 1: Figure out the implied value-range of result-value. + // step 1: Figure out the implied value-range of result and source operands opResultAssumption.clear(); for (mlir::OpResult result : op->getResults()) { auto assumedRange = maybeGetAssumedRange(result, op->getBlock()); @@ -712,11 +771,32 @@ LogicalResult TritonIntegerRangeAnalysis::visitOperation( opResultAssumption.insert(std::pair(result, *assumedRange)); } - // step 2: call helper function inferring the value range. If assumed value- + llvm::SmallVector + opndValueRanges; + + llvm::SmallVector, 4> + newSrcLattices; + + for (auto [index, opnd] : llvm::enumerate(op->getOperands())) { + auto assumedRange = maybeGetAssumedRange(opnd, op->getBlock()); + if (!assumedRange.has_value()) { + opndValueRanges.push_back(operands[index]); + continue; + } + + auto newLattice = + std::make_unique(opnd); + (void)newLattice->join(IntegerValueRange(*assumedRange)); + opndValueRanges.push_back(newLattice.get()); + newSrcLattices.push_back(std::move(newLattice)); + } + assert(opndValueRanges.size() == operands.size() && "size disagree"); + + // step 3: call helper function inferring the value range. If assumed value- // range is present, the transfer-function will intersect the assumed value- // value with the inferred value range. LogicalResult visitResult = - visitOperationHelper(op, operands, resultsLattices); + visitOperationHelper(op, opndValueRanges, resultsLattices); // step 3: If previous step failed to infer value-range, apply assumed // value-range is present. diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/ConvertToBufferOps.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/ConvertToBufferOps.cpp index 134d8835b968..0710a6ef65ed 100644 --- a/third_party/amd/lib/TritonAMDGPUTransforms/ConvertToBufferOps.cpp +++ b/third_party/amd/lib/TritonAMDGPUTransforms/ConvertToBufferOps.cpp @@ -56,15 +56,15 @@ bool isSplatOneConstTensor(const Value v) { bool isByteOffsetSmallerThan2GB(triton::AddPtrOp addPtrOp, std::shared_ptr solver) { Value elemIdx = addPtrOp.getOffset(); - LDBG("Determing element index value range: " << elemIdx); + LDBG("Determing value-range of element-index: " << elemIdx); - // step 1: get the value range of the element index + // step 1: Get the value range of the element index const auto *lattice = solver->lookupState(elemIdx); if (!lattice) { - // Note not always able to get lattice, e.g. the offset is obtained from - // tt.load. - LDBG("cannot get lattice associated with the offset"); + // Note that it is not always able to get lattice, e.g. the element-index + // is defined by a tt.load. 
+ LDBG("Cannot get lattice"); return false; } @@ -77,15 +77,23 @@ bool isByteOffsetSmallerThan2GB(triton::AddPtrOp addPtrOp, const auto &smin = vr.getValue().smin(); const auto &smax = vr.getValue().smax(); - LDBG("Element idx range: " << smin << " : " << smax); + LDBG("Element-index value-range: " << smin << " : " << smax); if (smin.isNegative() || smax.isNegative()) return false; - // step 2: get element size + // step 2: Get element type and size. + // e.g. addPtrOp.getType is tensor<64x64x!tt.ptr, then elemTy is + // !tt.ptr, and dereferencing elemTy gets f16. + // TODO: Not sure if we need to keep dereferencing in a loop. Type elemTy = getElementTypeOrSelf(addPtrOp.getType()); while (auto ptrTy = dyn_cast(elemTy)) elemTy = ptrTy.getPointeeType(); + if (!elemTy || !elemTy.isIntOrFloat()) { + LDBG("unknown element type: " << elemTy); + return false; + } + // step 3: check of byte-offset is within 2G int64_t elemBitSz = elemTy.getIntOrFloatBitWidth(); int64_t elemMaxIdx = smax.getSExtValue(); diff --git a/third_party/amd/test/lib/Analysis/TestAMDRangeAnalysis.cpp b/third_party/amd/test/lib/Analysis/TestAMDRangeAnalysis.cpp index da911bef13f6..105fa1e8cffc 100644 --- a/third_party/amd/test/lib/Analysis/TestAMDRangeAnalysis.cpp +++ b/third_party/amd/test/lib/Analysis/TestAMDRangeAnalysis.cpp @@ -32,7 +32,8 @@ struct TestAMDRangeAnalysisPass std::shared_ptr solver = createDataFlowSolver(); AMD::TritonIntegerRangeAnalysis *rangeAnalysis = solver->load( - assumptions, &getAnalysis()); + assumptions, &getAnalysis(), + /*assumeNoArithOverflow=*/true); AMD::initializeFuncOps(mod, rangeAnalysis); if (failed(solver->initializeAndRun(getOperation()))) return signalPassFailure(); From 917688a5890c8846537178d0acd645dba6552470 Mon Sep 17 00:00:00 2001 From: Shuxin Yang Date: Tue, 7 Oct 2025 10:09:10 -0700 Subject: [PATCH 5/9] add tech note; preparing for code review --- .../amd/lib/Analysis/RangeAnalysis.cpp | 132 ++++++++++++------ 1 file changed, 92 insertions(+), 40 deletions(-) diff --git a/third_party/amd/lib/Analysis/RangeAnalysis.cpp b/third_party/amd/lib/Analysis/RangeAnalysis.cpp index b6e5ce53f145..bda5bbac6cd9 100644 --- a/third_party/amd/lib/Analysis/RangeAnalysis.cpp +++ b/third_party/amd/lib/Analysis/RangeAnalysis.cpp @@ -15,53 +15,105 @@ #include #include -// Some implementation notes: -// 1: tl.assume statements -// - A value may have multiple assume-operations (assume-ops for short) +// Some notes: +// +// 1. Framework +// 1.1) This pass is based on MLIR's dataflow framework. In hindsight, maybe it +// is ill-fit for what we need. +// 1.2) If I understand correctly, the MLIR's dataflow framework is a +// combination +// of traditional iterative dataflow analysis and Sparse Conditional +// Constant propagation (SCCP). +// 1.3) Iterative dataflow analysis requires transfer function to be monotone. +// However, not all value-ranges keep increasing when the analysis progress. +// Consider the expression x - y, while x and y's value-range may keep +// increasing, the difference between them does not necessarily keep +// increasing as well. +// 1.4) SCCP part is not necessary for this pass. We don't expect many dead +// code at +// the moment this analysis is invoked. The SCCP part only make the anlaysis +// take longer time to converge, and it make more complicated to workaround +// the framework's limitations. +// 1.5) The MLIR dataflow framework does not understand SCF. On top of that it +// provides little interfaces to customize it. 
So, we have to rely on hack +// to sidestep these limitations. +// 1.6 Maybe just walking the code top-dowm is suffice for range-analysis? +// For loops, figuring out IVs' value-ranges before loops are entered, and +// progress to loop-body, without visiting back-edge for non-SCF loops. +// 1.7 As with SCCP which maintain two worklists, one for control-flow +// dependence, one for data-flow dependence. The framework seems to maintain +// a single unified worklist, with each item being a pair of +// . +// +// 2: tl.assume statements +// 2.1) A value may have multiple assume-operations (assume-ops for short) // associated with it. At point p, we only take into account those assume-ops -// whose enclosing basic blocks dominate the basic-block where p belong to. -// - See some examples in the comment to maybeGetAssumedRangeHelper(). -// - The assumed value-range is inferred right before an operation is visited, -// but the assume value-range does not directly apply to the lattice. -// - After the operation is visited, if lattices of the results remain in -// "bottom" state (i.e. uninitialized), the assumed value-range will be used -// instead. -// - The transfer-function is supposed to apply intersection between inferred -// value-range and assume value-range. However, the -// IntegerValueRangeLattice::meet() seems to be a silent no-op. For now, -// the transfer-function only use assumed-value range. Intersecting the -// inferred value-range with assume value-range still guarantee monotonicity -// of the transfer function. +// whose enclosing basic blocks dominate the basic-block where p belongs to. +// 2.2) See some examples in the comment to maybeGetAssumedRangeHelper(). +// 2.3) The assumed value-range for source and result operands are inferred +// right +// before an operation is visited. +// 2.4) For now, if a value a assumed value-range, we use assumed value-range. +// We should use the intersection of assumed-value-range and inferred-value- +// range. However, it is not always possible: iterative dataflow analysis +// requires that the transfer function must be monotone; in general it's +// dangerous to use both meet() and join() operations. In this pass, +// intersecting inferred value-range with assumed-value-range still guarantee +// its monotonicity. However, the underlying lattice's meet() operation is +// a silent no-op. // -// 2. SCF. -// - Unfortunately, the data-flow framework does not understand SCF! Running -// this pass together DCE and constant-propagation analysis make things even -// more complicated. -// - The override function visitOperaion() does not visit yieldOp, because -// it has no result and is ignored by the framework. -// - On top of that, while the framework provides a way to visit SCF's -// LHs/incoming-value (via visitRegionSuccessors), it does not provide a hook -// for processing the RHS. -// - To workaround this problem, right after an operation is visited, we check -// if it is used by a yield-op. If so, process the yield-op immediately. -// - Here is the steps about how SCF.if is handled. +// 3. SCF. +// 3.1 As mentioned above, MLIR's dataflow framework does not understand SCF. +// 3.2 For example, yield-op will not be visited by subclass's +// visitOperation(). +// That is because the base-class think yield-op has zero result and take +// for granted it has no value to analyze. +// 3.3 The built-in SCCP part makes the visit order somewhat complicated. +// Operations are not visited in forward order. 
+// 3.4 This is an example explaining how to SCF is processed, and how we +// workaround this problem. // +// op0: cond = ... // x, y = scf.if cond { -// then-clause -// yield a, b +// // then-block +// op1: a = ... +// op2: yield a, b // } else { -// else-clause -// yield c, d +// // else-block +// op3: d = +// op4: yield c, d // } -// z = add x, y -// -// o. visitRegionSuccessors(scf.if) is called, which does nothing. Once it -// returns, the base-class return immediately, and hence the -// visitOperation(scf.if) is not called. -// o. The DCE analysis initially mark blocks of then- and else-clause as -// "dead", and they are skipped for now. -// o. +// op5: z = add x, y // +// step 1: as mentioned in 1.7, multiple analyses comprise the framework with +// an unified worklist. DCE kick in first, when it visit the scf.if, the +// "cond" does not have lattice associated with it. So it initially +// considered both then-block and else-block are dead. +// step 2: after DCE going over all items in the worklist, range-analysis gets +// the chance. op0 is visited, a non-bottom lattice is created for op0's LHS. +// step 3. The baseclass (belong to framework) visits the scf.if +// it calls this class's visitRegionSuccessors(). Basically, +// visitRegionSuccessors() gives subclass a chance to prepare for RHS for +// SCF operations. This class does nothing for scf.if. +// step 3: The base-class returns once sub-class's visitRegionSuccessors() +// returns. Therefor, this class (subclass)'s visitOperand() function is +// *NOT* called with with scf.if. +// step 4: The base-class tries to visit the sub-regions (i.e. then- and else- +// blocks), only finds they are dead (due to step 1) and hence skip them. +// step 5: after step 4, the lattice of x and y are in "bottom" state. +// When op5 is visit, range-analysis find one of source operands is in +// "bottom" state, and do not update z's state. +// ... +// next round starts. +// step 5: DCE found "cond" has non-bottom state associated with it, and mark +// then- and else-block "live" accordingly. +// step 6: Range-analysis get a chance to visit the then- and else-block. +// step 7: when op1 is visited. *HACK KICK IN*. Range-analysis found op1 is +// used by yield-op, it then in turn updates x's state. +// step 8: likewise, then op3's visited, y's state is updated as well. +// step 9: finally, x and y has non-bottom state, when op5 is visited, z's +// state is updated. + #undef DEBUG_TYPE #define DEBUG_TYPE "tritonamdgpu-range-analysis" #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") From 4fe971a4ceaf88689d0954e1c9083b780e21a9c9 Mon Sep 17 00:00:00 2001 From: Shuxin Yang Date: Thu, 9 Oct 2025 11:15:10 -0700 Subject: [PATCH 6/9] remove yeildop hack --- .../amd/lib/Analysis/RangeAnalysis.cpp | 160 ++++-------------- 1 file changed, 36 insertions(+), 124 deletions(-) diff --git a/third_party/amd/lib/Analysis/RangeAnalysis.cpp b/third_party/amd/lib/Analysis/RangeAnalysis.cpp index bda5bbac6cd9..630e56391c6e 100644 --- a/third_party/amd/lib/Analysis/RangeAnalysis.cpp +++ b/third_party/amd/lib/Analysis/RangeAnalysis.cpp @@ -21,29 +21,20 @@ // 1.1) This pass is based on MLIR's dataflow framework. In hindsight, maybe it // is ill-fit for what we need. // 1.2) If I understand correctly, the MLIR's dataflow framework is a -// combination -// of traditional iterative dataflow analysis and Sparse Conditional -// Constant propagation (SCCP). 
+// combination of traditional iterative dataflow analysis and a mighty +// Sparse Conditional Constant propagation (SCCP). // 1.3) Iterative dataflow analysis requires transfer function to be monotone. // However, not all value-ranges keep increasing when the analysis progress. // Consider the expression x - y, while x and y's value-range may keep // increasing, the difference between them does not necessarily keep // increasing as well. -// 1.4) SCCP part is not necessary for this pass. We don't expect many dead -// code at -// the moment this analysis is invoked. The SCCP part only make the anlaysis -// take longer time to converge, and it make more complicated to workaround -// the framework's limitations. -// 1.5) The MLIR dataflow framework does not understand SCF. On top of that it -// provides little interfaces to customize it. So, we have to rely on hack -// to sidestep these limitations. -// 1.6 Maybe just walking the code top-dowm is suffice for range-analysis? +// 1.4) The 1st C in SCCP, i.e. "conditional" part in SCCP part is unnecessary +// for this pass, because we don't expect many dead code at the moment when +// this analysis is invoked. Price for being "conditional" is less about +// compile time but complexity (in terms of debugging and understanding). +// 1.5 Maybe just walking the code top-dowm is suffice for range-analysis: // For loops, figuring out IVs' value-ranges before loops are entered, and // progress to loop-body, without visiting back-edge for non-SCF loops. -// 1.7 As with SCCP which maintain two worklists, one for control-flow -// dependence, one for data-flow dependence. The framework seems to maintain -// a single unified worklist, with each item being a pair of -// . // // 2: tl.assume statements // 2.1) A value may have multiple assume-operations (assume-ops for short) @@ -53,66 +44,16 @@ // 2.3) The assumed value-range for source and result operands are inferred // right // before an operation is visited. -// 2.4) For now, if a value a assumed value-range, we use assumed value-range. -// We should use the intersection of assumed-value-range and inferred-value- -// range. However, it is not always possible: iterative dataflow analysis +// 2.4) For now, if a value has a assumed value-range, we use assumed +// value-range and ignore its inferred value range. It would be nice to +// use the intersection of assumed-value-range and inferred-value-range. +// However, it is not always possible: iterative dataflow analysis // requires that the transfer function must be monotone; in general it's // dangerous to use both meet() and join() operations. In this pass, // intersecting inferred value-range with assumed-value-range still guarantee // its monotonicity. However, the underlying lattice's meet() operation is // a silent no-op. // -// 3. SCF. -// 3.1 As mentioned above, MLIR's dataflow framework does not understand SCF. -// 3.2 For example, yield-op will not be visited by subclass's -// visitOperation(). -// That is because the base-class think yield-op has zero result and take -// for granted it has no value to analyze. -// 3.3 The built-in SCCP part makes the visit order somewhat complicated. -// Operations are not visited in forward order. -// 3.4 This is an example explaining how to SCF is processed, and how we -// workaround this problem. -// -// op0: cond = ... -// x, y = scf.if cond { -// // then-block -// op1: a = ... 
-// op2: yield a, b -// } else { -// // else-block -// op3: d = -// op4: yield c, d -// } -// op5: z = add x, y -// -// step 1: as mentioned in 1.7, multiple analyses comprise the framework with -// an unified worklist. DCE kick in first, when it visit the scf.if, the -// "cond" does not have lattice associated with it. So it initially -// considered both then-block and else-block are dead. -// step 2: after DCE going over all items in the worklist, range-analysis gets -// the chance. op0 is visited, a non-bottom lattice is created for op0's LHS. -// step 3. The baseclass (belong to framework) visits the scf.if -// it calls this class's visitRegionSuccessors(). Basically, -// visitRegionSuccessors() gives subclass a chance to prepare for RHS for -// SCF operations. This class does nothing for scf.if. -// step 3: The base-class returns once sub-class's visitRegionSuccessors() -// returns. Therefor, this class (subclass)'s visitOperand() function is -// *NOT* called with with scf.if. -// step 4: The base-class tries to visit the sub-regions (i.e. then- and else- -// blocks), only finds they are dead (due to step 1) and hence skip them. -// step 5: after step 4, the lattice of x and y are in "bottom" state. -// When op5 is visit, range-analysis find one of source operands is in -// "bottom" state, and do not update z's state. -// ... -// next round starts. -// step 5: DCE found "cond" has non-bottom state associated with it, and mark -// then- and else-block "live" accordingly. -// step 6: Range-analysis get a chance to visit the then- and else-block. -// step 7: when op1 is visited. *HACK KICK IN*. Range-analysis found op1 is -// used by yield-op, it then in turn updates x's state. -// step 8: likewise, then op3's visited, y's state is updated as well. -// step 9: finally, x and y has non-bottom state, when op5 is visited, z's -// state is updated. #undef DEBUG_TYPE #define DEBUG_TYPE "tritonamdgpu-range-analysis" @@ -223,7 +164,7 @@ void inferResultRangesMaxNonNegSigned(Operation *op, } } -// Given an assumption operaiton, try to derive the value range of the value +// Given an assumption operation, try to derive the value range of the value // 's value range at the somewhere in the block "useBlock". // Note that // - The value "anchor" is defined or referenced in the "useBlock" @@ -683,7 +624,7 @@ void TritonIntegerRangeAnalysis::defaultTransferFunc( } // step 4: Update the value range. Note that we are using `join` operation - // which means `union`. Transfer funtion must be monotone! The resolver + // which means `union`. Transfer function must be monotone! The resolver // would otherwise fall into infinite loop. ChangeResult changed = lattice->join(incomingRange_); LLVM_DEBUG({ @@ -718,12 +659,12 @@ TritonIntegerRangeAnalysis::rectifyInfferableRange( auto isPos = [](const ConstantIntRanges &range) { // Return true iff in both unsigned and signed representation, the most - // siganificant bit is always 0. + // significant bit is always 0. return range.umax().isNonNegative() && range.smax().isNonNegative() && range.smin().isNonNegative(); }; - // Not appliable to those bin-ops yielding unsigned int. + // Not applicable to those bin-ops yielding unsigned int. 
if (!signedIntValues.count(op->getResult(0))) return std::nullopt; @@ -774,42 +715,6 @@ TritonIntegerRangeAnalysis::rectifyInfferableRange( return ConstantIntRanges::fromUnsigned(resultRange.umin(), umax); } -void TritonIntegerRangeAnalysis::visitYieldHelper(Operation *op, Value value) { - auto yieldOp = dyn_cast(op); - LDBG("visit yieldOp: " << yieldOp); - - dataflow::IntegerValueRangeLattice *srcLattice = getLatticeElement(value); - - for (auto iter : llvm::enumerate(yieldOp->getOperands())) { - if (iter.value() != value) - continue; - - size_t idx = iter.index(); - Operation *parentOp = yieldOp->getParentOp(); - - if (auto ifOp = dyn_cast(parentOp)) { - // Get the corresponding scf.if result and its lattice - mlir::OpResult res = parentOp->getResult(idx); - dataflow::IntegerValueRangeLattice *resLattice = getLatticeElement(res); - auto changed = resLattice->join(*srcLattice); - propagateIfChanged(resLattice, changed); - - LLVM_DEBUG({ - OpPrintingFlags flags; - flags.skipRegions(true); - DBGS() << ((changed == ChangeResult::Change) - ? ">yieldOp bring change: " - : ">yieldOp bring no change:"); - res.printAsOperand(llvm::dbgs(), flags); - llvm::dbgs() << ", resulting value-range: " - << resLattice->getValue().getValue() - << ", in value-range: " - << srcLattice->getValue().getValue() << "\n"; - }); - } - } -} - LogicalResult TritonIntegerRangeAnalysis::visitOperation( Operation *op, ArrayRef operands, @@ -876,20 +781,6 @@ LogicalResult TritonIntegerRangeAnalysis::visitOperation( propagateIfChanged(lattice, changed); } - // step 4: The dataflow framework does not understand SCF. It skip yieldOp - // as it has no result. To workaround this problem, we visit all yieldOp - // which depends on this operation. - for (int resIdx = 0, resEnd = op->getNumResults(); resIdx < resEnd; - ++resIdx) { - mlir::OpResult res = op->getResult(resIdx); - - for (mlir::OpOperand &use : res.getUses()) { - mlir::Operation *depOp = use.getOwner(); - if (auto yield = dyn_cast(depOp)) - visitYieldHelper(yield, res); - } - } - return visitResult; } @@ -1045,6 +936,27 @@ void TritonIntegerRangeAnalysis::visitRegionSuccessors( assert(predecessors->allPredecessorsKnown() && "unexpected unresolved region successors"); + // Note: It does not seems to be quite obvious; this loop could update SCF + // operations' LHS. e.g. If the given "branch" argument is scf.if, and the + // scf.if construct looks like following: + // x = scf.if cond + // m = ... // op_m + // yield m + // else + // n = ... // op_n + // yield n + // + // This loop tries to update lattice(x) = join(lattice(m), lattice(n), + // proovided lattice(m) and lattice(n) are initialized. + // + // Note that the state of lattice(m) and lattice(n) was updated in the + // "previous" round. In this "round", the scf.if is vsitied right now, and + // it takes this moment to update its LHS. + // + // Alternatively, when we visit, say op_m, we notice its result is used by + // a yieldOp, get the yieldOp's corresponding receiver, in this case x, and + // update its state accordingly. + // for (Operation *op : predecessors->getKnownPredecessors()) { std::optional operands; if (op == branch) { From aceb15ed67d84c93f788701aa0fe0ba0f3ad3ce5 Mon Sep 17 00:00:00 2001 From: Shuxin Yang Date: Sat, 11 Oct 2025 09:51:28 -0700 Subject: [PATCH 7/9] address code review comment and remove the confusing feature. The confusing feature is to perform value-range analysis assuming arithmetic op has nsw and nuw flags (even they are not present) e.g. 
pid * block_size will still fit in [0, smax] despite that the pid itself is in [0, smax] --- test/TritonGPU/amd/amd-range-analysis.mlir | 10 +- .../amd/include/Analysis/RangeAnalysis.h | 5 - .../amd/lib/Analysis/RangeAnalysis.cpp | 208 +----------------- .../ConvertToBufferOps.cpp | 4 +- .../lib/Analysis/TestAMDRangeAnalysis.cpp | 3 +- 5 files changed, 17 insertions(+), 213 deletions(-) diff --git a/test/TritonGPU/amd/amd-range-analysis.mlir b/test/TritonGPU/amd/amd-range-analysis.mlir index 2904d2709637..5d975534e948 100644 --- a/test/TritonGPU/amd/amd-range-analysis.mlir +++ b/test/TritonGPU/amd/amd-range-analysis.mlir @@ -1316,7 +1316,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // expected-remark@+1 {{non-neg}} %6 = tt.splat %5 : i32 -> tensor<8xi32> %7 = tt.make_range {end = 8 : i32, start = 0 : i32} : tensor<8xi32> - // expected-remark@+1 {{unsigned : [0, 2147483647] signed : [0, 2147483647]}} + // expected-remark@+1 {{unsigned : [0, 2147483654] signed : [-2147483648, 2147483647]}} %8 = arith.addi %6, %7 : tensor<8xi32> tt.return } @@ -1556,10 +1556,10 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.targ // expected-remark@+2 {{unsigned : [1, 2147483647] signed : [1, 2147483647]}} // expected-remark@+1 {{non-neg}} %2 = tt.splat %arg6 : i32 -> tensor<256x1xi32, #blocked> - // expected-remark@+1 {{unsigned : [0, 2147483647] signed : [0, 2147483647]}} + // expected-remark@+1 {{unsigned : [0, 4294967295] signed : [-2147483648, 2147483647]}} %3 = arith.muli %1, %2 : tensor<256x1xi32, #blocked> %4 = tt.addptr %arg0, %c32_i32 : !tt.ptr, i32 - // expected-remark@+1 {{unsigned : [0, 2147483647] signed : [0, 2147483647]}} + // expected-remark@+1 {{unsigned : [0, 4294967295] signed : [-2147483648, 2147483647]}} %5 = tt.broadcast %3 : tensor<256x1xi32, #blocked> -> tensor<256x64xi32, #blocked> %6 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked}>> // expected-remark@+2 {{unsigned : [0, 63] signed : [0, 63]}} @@ -1568,7 +1568,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.targ // expected-remark@+2 {{unsigned : [0, 63] signed : [0, 63]}} // expected-remark@+1 {{non-neg}} %8 = tt.broadcast %7 : tensor<1x64xi32, #blocked> -> tensor<256x64xi32, #blocked> - // expected-remark@+1 {{unsigned : [0, 2147483647] signed : [0, 2147483647]}} + // expected-remark@+1 {{unsigned : [0, 4294967295] signed : [-2147483648, 2147483647]}} %9 = arith.addi %8, %5 : tensor<256x64xi32, #blocked> %10 = tt.splat %4 : !tt.ptr -> tensor<256x64x!tt.ptr, #blocked> %11 = tt.addptr %10, %9 : tensor<256x64x!tt.ptr, #blocked>, tensor<256x64xi32, #blocked> @@ -1714,7 +1714,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ // expected-remark@+2 {{unsigned : [0, 65536] signed : [0, 65536]}} // expected-remark@+1 {{non-neg}} %20 = tt.splat %2 : i32 -> tensor<32xi32, #blocked> - // expected-remark@+1 {{unsigned : [0, 2147483647] signed : [0, 2147483647]}} + // expected-remark@+1 {{unsigned : [0, 4294967295] signed : [-2147483648, 2147483647]}} %21 = arith.muli %6, %20 : tensor<32xi32, #blocked> %22 = tt.splat %arg1 : !tt.ptr -> tensor<32x!tt.ptr, #blocked> %23 = tt.addptr %22, %21 : tensor<32x!tt.ptr, #blocked>, tensor<32xi32, #blocked> diff --git a/third_party/amd/include/Analysis/RangeAnalysis.h b/third_party/amd/include/Analysis/RangeAnalysis.h index 37370bd5a9e0..df5ff673bac6 100644 --- a/third_party/amd/include/Analysis/RangeAnalysis.h +++ 
b/third_party/amd/include/Analysis/RangeAnalysis.h @@ -154,11 +154,6 @@ struct TritonIntegerRangeAnalysis : dataflow::IntegerRangeAnalysis { ArrayRef operands, ArrayRef resultsLattices); - std::optional rectifyInfferableRange( - InferIntRangeInterface interface, - ArrayRef srcLattices, - const IntegerValueRange &range); - DenseSet signedIntValues; llvm::SmallMapVector opResultAssumption; DominanceInfo *domInfo = nullptr; diff --git a/third_party/amd/lib/Analysis/RangeAnalysis.cpp b/third_party/amd/lib/Analysis/RangeAnalysis.cpp index 630e56391c6e..78362d1fec58 100644 --- a/third_party/amd/lib/Analysis/RangeAnalysis.cpp +++ b/third_party/amd/lib/Analysis/RangeAnalysis.cpp @@ -42,8 +42,7 @@ // whose enclosing basic blocks dominate the basic-block where p belongs to. // 2.2) See some examples in the comment to maybeGetAssumedRangeHelper(). // 2.3) The assumed value-range for source and result operands are inferred -// right -// before an operation is visited. +// right before an operation is visited. // 2.4) For now, if a value has a assumed value-range, we use assumed // value-range and ignore its inferred value range. It would be nice to // use the intersection of assumed-value-range and inferred-value-range. @@ -87,10 +86,12 @@ tt::FuncOp getEnclosingFunction(Value v) { definingOp = blk->getParentOp(); if (definingOp) { - funcOp = dyn_cast_or_null(definingOp); - if (!funcOp) + if (auto selfIsFunc = dyn_cast(definingOp)) + funcOp = selfIsFunc; + else funcOp = definingOp->getParentOfType(); } + assert(funcOp && "No enclosing tt::FuncOp"); return funcOp; } @@ -319,114 +320,6 @@ maybeGetAssumedRange(const SetVector &allAssumptions, Value anchor, return result; } -// Many operations in arith dialect do not differentiate signed int and unsigned -// int, e.g., arith::AddIOp, arith::MullOp. This function try to extrapolate the -// type (sint or uint) of the Operation from the its UD and DU chains. -// -// TODO: This function seems to be useful for proving a quantity is a -// non-negative. However, it is less so in proving a quantity is smaller than -// specified upper bound. In fact, turning off this feature only sees 5 lines -// difference in amd-range-analysis.mlir. For now, it is turned on only in -// TestAMDRangeAnalysis.cpp. For now, we just keep this code for a while and -// see if it will be useful for some real world applications. -// -static void collectValueOfSignedInt(Operation *top, DenseSet &valueSet) { - SetVector worklist; - - // Initialize the worklist with some known signed interger values. 
- top->walk([&](Operation *op) { - llvm::TypeSwitch(op) - .Case( - [&](auto addPtrOp) { worklist.insert(addPtrOp.getOffset()); }) - .Case([&](auto binop) { - worklist.insert(binop.getResult()); - worklist.insert(binop.getOperand(0)); - worklist.insert(binop.getOperand(1)); - }) - .Case( - [&](auto sExt) { worklist.insert(sExt.getResult()); }) - .Case([&](auto cmpOp) { - switch (cmpOp.getPredicate()) { - case arith::CmpIPredicate::sgt: - case arith::CmpIPredicate::sge: - case arith::CmpIPredicate::sle: - case arith::CmpIPredicate::slt: - worklist.insert(cmpOp.getOperand(0)); - worklist.insert(cmpOp.getOperand(1)); - break; - case arith::CmpIPredicate::uge: - case arith::CmpIPredicate::ugt: - case arith::CmpIPredicate::ule: - case arith::CmpIPredicate::ult: - worklist.insert(cmpOp.getOperand(0)); - worklist.insert(cmpOp.getOperand(1)); - break; - default: - break; - }; - }); - }); - - valueSet.clear(); - auto addToWorklist = [&](Value v) { - if (!valueSet.count(v)) - worklist.insert(v); - }; - - while (!worklist.empty()) { - auto v = worklist.back(); - worklist.pop_back(); - Operation *op = v.getDefiningOp(); - - // If the result of this op is signed int, then its source operands are - // singed int. - if (op) { - llvm::TypeSwitch(op) - .Case([&](auto binOp) { - addToWorklist(binOp.getOperand(0)); - addToWorklist(binOp.getOperand(1)); - }) - .Case( - [&](auto unary) { addToWorklist(unary.getOperand()); }); - } - - SmallVector results; - if (op) - results = op->getResults(); - else - results.push_back(v); - - for (auto result : results) { - if (valueSet.count(result)) - continue; - - valueSet.insert(result); - - for (mlir::OpOperand &use : result.getUses()) { - llvm::TypeSwitch(use.getOwner()) - .Case( - [&](auto op) { addToWorklist(op.getResult()); }) - .Case( - [&](auto binOp) { addToWorklist(binOp.getResult()); }); - } - } - } - - LLVM_DEBUG({ - DBGS() << "Values considered as signed int (begin)\n"; - OpPrintingFlags flags; - flags.skipRegions(true); - for (auto v : valueSet) { - DBGS() << " - "; - v.print(llvm::dbgs(), flags); - llvm::dbgs() << "\n"; - } - DBGS() << "Values considered as signed int (end)\n"; - }); -} - } // namespace namespace mlir::triton::AMD { @@ -536,8 +429,6 @@ bool cmpIIsStaticallyTrue(const DataFlowSolver &solver, arith::CmpIOp cmpOp) { LogicalResult TritonIntegerRangeAnalysis::initialize(Operation *top) { signedIntValues.clear(); - if (assumeNoArithOverflow) - collectValueOfSignedInt(top, signedIntValues); return Base::initialize(top); } @@ -603,15 +494,7 @@ void TritonIntegerRangeAnalysis::defaultTransferFunc( resultsLattices[result.getResultNumber()]; IntegerValueRange incomingRange_ = incomingRange; - // step 2: Some range value in MLIR lib is too conservative, update the - // value-range before it is jointed to the lattice. - if (auto inferrable = dyn_cast(op)) { - auto res = rectifyInfferableRange(inferrable, srcLattices, incomingRange_); - if (res.has_value()) - incomingRange_ = std::move(*res); - } - - // step 3: If there is assumed value range, the assumed one take precedence. + // step 2: If there is assumed value range, the assumed one take precedence. 
// TODO: I think this is bit conservative, the better way is: // final_range = (old_range ∪ incomingRange) ∩ assume_range if (auto iter = opResultAssumption.find(resultVal); @@ -619,11 +502,11 @@ void TritonIntegerRangeAnalysis::defaultTransferFunc( const auto &range = iter->second; if (auto maybeRange = maybeGetAssumedRange(resultVal, op->getBlock())) { incomingRange_ = - IntegerValueRange(incomingRange.getValue().intersection(range)); + IntegerValueRange(incomingRange_.getValue().intersection(range)); } } - // step 4: Update the value range. Note that we are using `join` operation + // step 3: Update the value range. Note that we are using `join` operation // which means `union`. Transfer function must be monotone! The resolver // would otherwise fall into infinite loop. ChangeResult changed = lattice->join(incomingRange_); @@ -637,84 +520,11 @@ void TritonIntegerRangeAnalysis::defaultTransferFunc( << ", in value-range: " << incomingRange_ << "\n"; }); - // step 5: Add those ops that depends on this op to the worklist. The resolver + // step 4: Add those ops that depends on this op to the worklist. The resolver // will iterate all items in the worklist until it become empty. propagateIfChanged(lattice, changed); } -std::optional -TritonIntegerRangeAnalysis::rectifyInfferableRange( - InferIntRangeInterface rface, - ArrayRef srcLattices, - const IntegerValueRange &range) { - - auto op = rface.getOperation(); - - // step 1: rule out some operations we cannot handle - if (!llvm::isa(op) || - range.isUninitialized()) { - return std::nullopt; - } - - auto isPos = [](const ConstantIntRanges &range) { - // Return true iff in both unsigned and signed representation, the most - // significant bit is always 0. - return range.umax().isNonNegative() && range.smax().isNonNegative() && - range.smin().isNonNegative(); - }; - - // Not applicable to those bin-ops yielding unsigned int. - if (!signedIntValues.count(op->getResult(0))) - return std::nullopt; - - // step 2: Do nothing if the value-range is already a non-negative range. 
- const ConstantIntRanges &resultRange = range.getValue(); - - if (isPos(resultRange)) - return std::nullopt; - - // step 3: special handling of arith::TruncIOp - if (llvm::isa(op)) { - if (!srcLattices[0] || srcLattices[0]->getValue().isUninitialized()) - return std::nullopt; - - const ConstantIntRanges srcRange = srcLattices[0]->getValue().getValue(); - if (!isPos(srcRange)) - return std::nullopt; - - // assume NSW - APInt umax = APInt::getSignedMaxValue(resultRange.umax().getBitWidth()); - return ConstantIntRanges::fromUnsigned(resultRange.umin(), umax); - } - - // step 4: rule out some messy situations - // If the MSB of umin is "1", bailout - if (!resultRange.umin().isNonNegative()) - return std::nullopt; - - // If the value-ranges of operands are somehow missing, we can do nothing - if (!srcLattices[0] || !srcLattices[1] || - srcLattices[0]->getValue().isUninitialized() || - srcLattices[1]->getValue().isUninitialized()) - return std::nullopt; - - auto opndRange0 = srcLattices[0]->getValue().getValue(); - auto opndRange1 = srcLattices[1]->getValue().getValue(); - - // bail out if one of operands' is not non-negative - if (!isPos(opndRange0) || !isPos(opndRange1)) - return std::nullopt; - - APInt umax(resultRange.umax()); - if (!umax.isNonNegative()) { - // Saturate umax to 0x7f...f - umax = APInt::getSignedMaxValue(umax.getBitWidth()); - } - - return ConstantIntRanges::fromUnsigned(resultRange.umin(), umax); -} - LogicalResult TritonIntegerRangeAnalysis::visitOperation( Operation *op, ArrayRef operands, diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/ConvertToBufferOps.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/ConvertToBufferOps.cpp index 0710a6ef65ed..f2347e269ac8 100644 --- a/third_party/amd/lib/TritonAMDGPUTransforms/ConvertToBufferOps.cpp +++ b/third_party/amd/lib/TritonAMDGPUTransforms/ConvertToBufferOps.cpp @@ -101,8 +101,8 @@ bool isByteOffsetSmallerThan2GB(triton::AddPtrOp addPtrOp, int64_t szLimit2GB = (1L << 31) - 1; LDBG("element bit sz:" << elemBitSz << ", max byte offset:" << byteOfst - << ((szLimit2GB > byteOfst) ? ", out or range" - : ",in range")); + << ((szLimit2GB > byteOfst) ? 
", out of range" + : ", in range")); return byteOfst <= szLimit2GB; } diff --git a/third_party/amd/test/lib/Analysis/TestAMDRangeAnalysis.cpp b/third_party/amd/test/lib/Analysis/TestAMDRangeAnalysis.cpp index 105fa1e8cffc..da911bef13f6 100644 --- a/third_party/amd/test/lib/Analysis/TestAMDRangeAnalysis.cpp +++ b/third_party/amd/test/lib/Analysis/TestAMDRangeAnalysis.cpp @@ -32,8 +32,7 @@ struct TestAMDRangeAnalysisPass std::shared_ptr solver = createDataFlowSolver(); AMD::TritonIntegerRangeAnalysis *rangeAnalysis = solver->load( - assumptions, &getAnalysis(), - /*assumeNoArithOverflow=*/true); + assumptions, &getAnalysis()); AMD::initializeFuncOps(mod, rangeAnalysis); if (failed(solver->initializeAndRun(getOperation()))) return signalPassFailure(); From cd02ecc52a8bd0a9c322f1f6a4676ddefc802cd5 Mon Sep 17 00:00:00 2001 From: Shuxin Yang Date: Mon, 13 Oct 2025 09:26:27 -0700 Subject: [PATCH 8/9] fix typo and rebase --- third_party/amd/lib/Analysis/RangeAnalysis.cpp | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/third_party/amd/lib/Analysis/RangeAnalysis.cpp b/third_party/amd/lib/Analysis/RangeAnalysis.cpp index 78362d1fec58..23e6231a7012 100644 --- a/third_party/amd/lib/Analysis/RangeAnalysis.cpp +++ b/third_party/amd/lib/Analysis/RangeAnalysis.cpp @@ -32,7 +32,7 @@ // for this pass, because we don't expect many dead code at the moment when // this analysis is invoked. Price for being "conditional" is less about // compile time but complexity (in terms of debugging and understanding). -// 1.5 Maybe just walking the code top-dowm is suffice for range-analysis: +// 1.5 Maybe just walking the code top-dowm is sufficient for range-analysis: // For loops, figuring out IVs' value-ranges before loops are entered, and // progress to loop-body, without visiting back-edge for non-SCF loops. // @@ -306,7 +306,7 @@ maybeGetAssumedRange(const SetVector &allAssumptions, Value anchor, if (result) { const auto &val = *result; if (val.smin().isNonNegative()) { - // Consider 0 < x && x < 1024. + // Consider 0 <= x && x <= 1024. // When processing x > 0, the value range of x is // vr1={umin=0, umax=0xf...f, smin=0, smax=0x7...f} // When processing x < 1024, the value range of x is: @@ -559,7 +559,7 @@ LogicalResult TritonIntegerRangeAnalysis::visitOperation( } assert(opndValueRanges.size() == operands.size() && "size disagree"); - // step 3: call helper function inferring the value range. If assumed value- + // step 2: call helper function inferring the value range. If assumed value- // range is present, the transfer-function will intersect the assumed value- // value with the inferred value range. LogicalResult visitResult = @@ -757,10 +757,10 @@ void TritonIntegerRangeAnalysis::visitRegionSuccessors( // yield n // // This loop tries to update lattice(x) = join(lattice(m), lattice(n), - // proovided lattice(m) and lattice(n) are initialized. + // provided lattice(m) and lattice(n) are initialized. // // Note that the state of lattice(m) and lattice(n) was updated in the - // "previous" round. In this "round", the scf.if is vsitied right now, and + // "previous" round. In this "round", the scf.if is visitied right now, and // it takes this moment to update its LHS. 
// // Alternatively, when we visit, say op_m, we notice its result is used by From 87dd763d8e14226334d6211c4cebea526a221646 Mon Sep 17 00:00:00 2001 From: Shuxin Yang Date: Mon, 13 Oct 2025 10:50:14 -0700 Subject: [PATCH 9/9] dummy change to trigger cancelled test --- third_party/amd/lib/Analysis/RangeAnalysis.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/third_party/amd/lib/Analysis/RangeAnalysis.cpp b/third_party/amd/lib/Analysis/RangeAnalysis.cpp index 23e6231a7012..aef572307ed2 100644 --- a/third_party/amd/lib/Analysis/RangeAnalysis.cpp +++ b/third_party/amd/lib/Analysis/RangeAnalysis.cpp @@ -52,7 +52,6 @@ // intersecting inferred value-range with assumed-value-range still guarantee // its monotonicity. However, the underlying lattice's meet() operation is // a silent no-op. -// #undef DEBUG_TYPE #define DEBUG_TYPE "tritonamdgpu-range-analysis"
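
Note on the isByteOffsetSmallerThan2GB() change above: the conversion is only taken when the element-index range reported by the range analysis is non-negative and the largest byte offset stays below szLimit2GB = (1L << 31) - 1. The self-contained sketch below merely restates that arithmetic for illustration; fitsInBufferOffsetSketch is a made-up name rather than code from the pass, and it assumes a whole-byte element type and an index derived from i32 so the multiplication cannot overflow.

#include <cstdint>
#include <cstdio>

// Sketch: given the signed element-index range [smin, smax] from the range
// analysis and the element bit width, return true iff every byte offset
// provably fits below the 2 GB buffer-offset limit.
static bool fitsInBufferOffsetSketch(int64_t smin, int64_t smax,
                                     int64_t elemBitWidth) {
  // A possibly negative index can never be proven in range.
  if (smin < 0 || smax < 0)
    return false;
  // Largest byte offset reachable with this index range.
  const int64_t maxByteOffset = smax * elemBitWidth / 8;
  // 2 GB - 1, mirroring szLimit2GB in the pass.
  const int64_t limit = (int64_t{1} << 31) - 1;
  return maxByteOffset <= limit;
}

int main() {
  // For 32-bit elements the element index may be at most
  // (2^31 - 1) / 4 = 536870911 (integer division); one past that no longer
  // fits, so an index range of [0, INT32_MAX] alone is not sufficient.
  std::printf("%d %d\n", fitsInBufferOffsetSketch(0, 536870911, 32),
              fitsInBufferOffsetSketch(0, 536870912, 32)); // prints "1 0"
  return 0;
}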