diff --git a/test/TritonGPU/amd/amd-convert-buffer-ops-small-tensor.mlir b/test/TritonGPU/amd/amd-convert-buffer-ops-small-tensor.mlir index 2b146c4c51bf..75a1004dfbe3 100644 --- a/test/TritonGPU/amd/amd-convert-buffer-ops-small-tensor.mlir +++ b/test/TritonGPU/amd/amd-convert-buffer-ops-small-tensor.mlir @@ -36,13 +36,15 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} { %8 = tt.addptr %7, %4 : tensor<256x!tt.ptr, #blocked0>, tensor<256xi32, #blocked0> // COMMON: buffer_load %arg0[%[[offset]]] %9 = tt.load %6 : tensor<256x!tt.ptr, #blocked0> - // COMMON: buffer_load %arg1[%[[offset]]] + // Note: offset = pid * 256 + arange(0, 256); the byte offset "offset * sizeof(i32)" may not fall into the 2G range. + // COMMON-NOT: buffer_load %arg1[%[[offset]]] %10 = tt.load %8 : tensor<256x!tt.ptr, #blocked0> // COMMON: %[[data:.*]] = arith.addf %11 = arith.addf %9, %10 : tensor<256xf32, #blocked0> %12 = tt.splat %arg2 : !tt.ptr -> tensor<256x!tt.ptr, #blocked0> %13 = tt.addptr %12, %4 : tensor<256x!tt.ptr, #blocked0>, tensor<256xi32, #blocked0> - // COMMON: buffer_store %[[data]], %arg2[%[[offset]]] + // Note: see the explanation above + // COMMON-NOT: buffer_store %[[data]], %arg2[%[[offset]]] tt.store %13, %11 : tensor<256x!tt.ptr, #blocked0> tt.return } @@ -70,7 +72,10 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { %5 = tt.addptr %arg0, %1 : !tt.ptr, i32 %8 = tt.splat %5 : !tt.ptr -> tensor<1024x!tt.ptr, #blocked> %9 = tt.addptr %8, %4 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> - // COMMON: buffer_load %[[scalar_ptr]][%[[offset]]] + // Note: the base "scalar_ptr" points to arg0, which is a large tensor. + // The offset is "%sub + arange(0,1024)", where "%sub = pid*1024 - 128". + // We can prove "offset > 0", but cannot prove byte-offset < 2G. + // COMMON-NOT: buffer_load %[[scalar_ptr]][%[[offset]]] %10 = tt.load %9 : tensor<1024x!tt.ptr, #blocked> tt.return %10 : tensor<1024xf32, #blocked> } @@ -122,7 +127,9 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // COMMON: %[[offset_32_bit:.*]] = arith.trunci %narrow4 = arith.trunci %4 : tensor<1024xi64, #blocked> to tensor <1024xi32, #blocked> %9 = tt.addptr %8, %narrow4 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> - // COMMON: buffer_load %[[scalar_ptr]][%[[offset_32_bit]]] + // Note: the base is arg0, which is a large tensor; the offset is int(long(pid*1024) * long(arange(0, 1024))). + // The offset is in [0, i32-max], so the byte offset may still exceed 2G. + // COMMON-NOT: buffer_load %[[scalar_ptr]][%[[offset_32_bit]]] %10 = tt.load %9 : tensor<1024x!tt.ptr, #blocked> tt.return %10 : tensor<1024xf32, #blocked> } @@ -555,7 +562,9 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { %5 = tt.addptr %arg0, %1 : !tt.ptr, i32 %6 = tt.splat %5 : !tt.ptr -> tensor<1024x!tt.ptr, #blocked> %7 = tt.addptr %6, %4 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> - // COMMON: %[[loaded:.*]] = amdgpu.buffer_atomic_rmw fadd, acq_rel, gpu, %arg1, %[[scalar_ptr]][%[[offset]]] + // Note: a large tensor is accessed; the offset is in the range [0, smax];
+ // without tl.assume the range would be [-128, smax] + // COMMON-NOT: amdgpu.buffer_atomic_rmw %8 = tt.atomic_rmw fadd, acq_rel, gpu, %7, %arg1 : (tensor<1024x!tt.ptr, #blocked>, tensor<1024xf32, #blocked>) -> tensor<1024xf32, #blocked> tt.return %8 : tensor<1024xf32, #blocked> } diff --git a/test/TritonGPU/amd/amd-convert-buffer-ops.mlir b/test/TritonGPU/amd/amd-convert-buffer-ops.mlir index 7b04877a9454..bb5121af2cbf 100644 --- a/test/TritonGPU/amd/amd-convert-buffer-ops.mlir +++ b/test/TritonGPU/amd/amd-convert-buffer-ops.mlir @@ -16,15 +16,17 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} { %6 = tt.addptr %5, %4 : tensor<256x!tt.ptr, #blocked0>, tensor<256xi32, #blocked0> %7 = tt.splat %arg1 : !tt.ptr -> tensor<256x!tt.ptr, #blocked0> %8 = tt.addptr %7, %4 : tensor<256x!tt.ptr, #blocked0>, tensor<256xi32, #blocked0> - // COMMON: buffer_load %arg0[%[[offset]]] + // Note: large-tensor with elemIdx=pid*256 + arange(0, 256), elemIdx ∈ [0, smax] + // COMMON-NOT: buffer_load %9 = tt.load %6 : tensor<256x!tt.ptr, #blocked0> - // COMMON: buffer_load %arg1[%[[offset]]] + // COMMON-NOT: buffer_load %10 = tt.load %8 : tensor<256x!tt.ptr, #blocked0> // COMMON: %[[data:.*]] = arith.addf %11 = arith.addf %9, %10 : tensor<256xf32, #blocked0> %12 = tt.splat %arg2 : !tt.ptr -> tensor<256x!tt.ptr, #blocked0> %13 = tt.addptr %12, %4 : tensor<256x!tt.ptr, #blocked0>, tensor<256xi32, #blocked0> - // COMMON: buffer_store %[[data]], %arg2[%[[offset]]] + // Note: large-tensor with elemIdx ∈ [0, smax] + // COMMON-NOT: buffer_store tt.store %13, %11 : tensor<256x!tt.ptr, #blocked0> tt.return } @@ -43,6 +45,9 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.targ %1 = tt.expand_dims %0 {axis = 1 : i32} : tensor<256xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<256x1xi32, #blocked> %cmp = arith.cmpi sgt, %arg6, %c0_i32 : i32 llvm.intr.assume %cmp : i1 + %arg6_upper = arith.constant 4194304 : i32 + %cmp2 = arith.cmpi slt, %arg6, %arg6_upper : i32 + llvm.intr.assume %cmp2 : i1 %2 = tt.splat %arg6 : i32 -> tensor<256x1xi32, #blocked> %3 = arith.muli %1, %2 : tensor<256x1xi32, #blocked> %4 = tt.addptr %arg0, %c32_i32 : !tt.ptr, i32 @@ -78,6 +83,9 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.targ %23 = arith.addi %21, %20 : tensor<256x64xi32, #blocked> %24 = tt.splat %22 : !tt.ptr -> tensor<256x64x!tt.ptr, #blocked> %25 = tt.addptr %24, %23 : tensor<256x64x!tt.ptr, #blocked>, tensor<256x64xi32, #blocked> + %ofst_upper = arith.constant 1073741823 : i32 + %cmp3 = arith.cmpi slt, %ofst_upper, %ofst_upper : i32 + llvm.intr.assume %cmp3 : i1 // COMMON: %[[splatb:.*]] = tt.splat %arg[[#strideb:]] // COMMON: %[[mulb:.*]] = arith.muli %[[splatb]], %[[#]] @@ -85,12 +93,14 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.targ // COMMON: %[[bcast0b:.*]] = tt.broadcast %[[#]] // COMMON: %[[ptrb:.*]] = tt.addptr // COMMON: %[[offsetb:.*]] = arith.addi %[[bcast0b]], %[[bcast1b]] - // COMMON: buffer_store %[[buffer]], %[[ptrb]][%[[offsetb]]] stride = %arg[[#strideb]] + // COMMON-NOT: buffer_store tt.store %25, %12 : tensor<256x64x!tt.ptr, #blocked> tt.return } } + + // ----- #blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> @@ -113,7 +123,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { %5 = tt.addptr %arg0, %1 : !tt.ptr, i32 %8 = tt.splat %5 : !tt.ptr -> tensor<1024x!tt.ptr, #blocked> %9 = tt.addptr %8, %4 : 
tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> - // COMMON: buffer_load %[[scalar_ptr]][%[[offset]]] + // COMMON-NOT: buffer_load %[[scalar_ptr]][%[[offset]]] %10 = tt.load %9 : tensor<1024x!tt.ptr, #blocked> tt.return %10 : tensor<1024xf32, #blocked> } @@ -165,7 +175,9 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // COMMON: %[[offset_32_bit:.*]] = arith.trunci %narrow4 = arith.trunci %4 : tensor<1024xi64, #blocked> to tensor <1024xi32, #blocked> %9 = tt.addptr %8, %narrow4 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> - // COMMON: buffer_load %[[scalar_ptr]][%[[offset_32_bit]]] + // Note: the base is arg0, which is a large tensor; the offset is int(long(pid*1024) * long(arange(0, 1024))). + // The offset is in [0, i32-max], so the byte offset may still exceed 2G. + // COMMON-NOT: buffer_load %10 = tt.load %9 : tensor<1024x!tt.ptr, #blocked> tt.return %10 : tensor<1024xf32, #blocked> } @@ -265,9 +277,9 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { %14 = tt.addptr %13, %11 : tensor<16x!tt.ptr, #blocked>, tensor<16xi32, #blocked> %15 = tt.splat %arg1 : !tt.ptr -> tensor<16x!tt.ptr, #blocked> %16 = tt.addptr %15, %offsets : tensor<16x!tt.ptr, #blocked>, tensor<16xi32, #blocked> - // COMMON: %[[loaded:.*]] = amdgpu.buffer_load %arg1[%[[offsets]]] + // COMMON-NOT: amdgpu.buffer_load %17 = tt.load %16 : tensor<16x!tt.ptr, #blocked> - // COMMON: amdgpu.buffer_store %[[loaded]], %arg0[%[[range]]] + // COMMON: amdgpu.buffer_store tt.store %14, %17 : tensor<16x!tt.ptr, #blocked> tt.return } @@ -364,11 +376,13 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { %1 = tt.make_range {end = 8 : i32, start = 0 : i32} : tensor<8xi32, #blocked> %2 = tt.splat %arg0 : !tt.ptr -> tensor<8x!tt.ptr, #blocked> %3 = tt.addptr %2, %0 : tensor<8x!tt.ptr, #blocked>, tensor<8xi32, #blocked> - // COMMON: %[[loaded:.*]] = amdgpu.buffer_load %arg0[%{{.*}}] + // Note: index is tt.histogram ∈ [0, smax) + // COMMON-NOT: amdgpu.buffer_load %4 = tt.load %3 : tensor<8x!tt.ptr, #blocked> %5 = tt.splat %arg1 : !tt.ptr -> tensor<8x!tt.ptr, #blocked> %6 = tt.addptr %5, %1 : tensor<8x!tt.ptr, #blocked>, tensor<8xi32, #blocked> - // COMMON: amdgpu.buffer_store %[[loaded]], %arg1[%{{.*}}] + // Note: index is tt.histogram ∈ [0, smax) + // COMMON: amdgpu.buffer_store tt.store %6, %4 : tensor<8x!tt.ptr, #blocked> tt.return } @@ -391,11 +405,11 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { %8 = arith.addi %6, %7 : tensor<8xi32, #blocked> %9 = tt.splat %arg0 : !tt.ptr -> tensor<8x!tt.ptr, #blocked> %10 = tt.addptr %9, %8 : tensor<8x!tt.ptr, #blocked>, tensor<8xi32, #blocked> - // COMMON: %[[loaded:.*]] = amdgpu.buffer_load %arg0[%{{.*}}] + // COMMON-NOT: amdgpu.buffer_load %11 = tt.load %10 : tensor<8x!tt.ptr, #blocked> %12 = tt.splat %arg1 : !tt.ptr -> tensor<8x!tt.ptr, #blocked> %13 = tt.addptr %12, %7 : tensor<8x!tt.ptr, #blocked>, tensor<8xi32, #blocked> - // COMMON: amdgpu.buffer_store %[[loaded]], %arg1[%{{.*}}] + // COMMON: amdgpu.buffer_store tt.store %13, %11 : tensor<8x!tt.ptr, #blocked> tt.return } @@ -426,11 +440,12 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { %17 = arith.addi %15, %16 : tensor<8xi32, #blocked> %18 = tt.splat %arg0 : !tt.ptr -> tensor<8x!tt.ptr, #blocked> %19 = tt.addptr %18, %17 : tensor<8x!tt.ptr, #blocked>, tensor<8xi32, #blocked> - // COMMON: %[[loaded:.*]] = amdgpu.buffer_load %arg0[%{{.*}}] + // Note: the operations above can only prove elmtIdx >= 0 but don't reveal its upper bound.
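+ // Both ends of the range are needed for the rewrite; roughly the condition the
+ // conversion needs (an illustrative Python sketch; 4-byte elements assumed):
+ //   def can_use_buffer_ops(idx_min, idx_max, elem_bytes=4):
+ //       return idx_min >= 0 and idx_max * elem_bytes < 2**31
+ //   can_use_buffer_ops(0, 2**31 - 1)  # False: the byte offset would exceed 2G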
+ // COMMON-NOT: amdgpu.buffer_load %20 = tt.load %19 : tensor<8x!tt.ptr, #blocked> %21 = tt.splat %arg1 : !tt.ptr -> tensor<8x!tt.ptr, #blocked> %22 = tt.addptr %21, %16 : tensor<8x!tt.ptr, #blocked>, tensor<8xi32, #blocked> - // COMMON: amdgpu.buffer_store %[[loaded]], %arg1[%{{.*}}] + // COMMON: amdgpu.buffer_store tt.store %22, %20 : tensor<8x!tt.ptr, #blocked> tt.return } @@ -450,11 +465,13 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { %5 = arith.trunci %4 : tensor<8xi64, #blocked> to tensor<8xi32, #blocked> %6 = tt.splat %arg0 : !tt.ptr -> tensor<8x!tt.ptr, #blocked> %7 = tt.addptr %6, %5 : tensor<8x!tt.ptr, #blocked>, tensor<8xi32, #blocked> - // COMMON: %[[loaded:.*]] = amdgpu.buffer_load %arg0[%{{.*}}] + // Note: elemIdx is (int32)(arange(0, 8) + (uint64)(uint32)arg2) + // elemIdx is not necessarily >= 0 + // COMMON-NOT: amdgpu.buffer_load %8 = tt.load %7: tensor<8x!tt.ptr, #blocked> %9 = tt.splat %arg1 : !tt.ptr -> tensor<8x!tt.ptr, #blocked> %10 = tt.addptr %9, %2 : tensor<8x!tt.ptr, #blocked>, tensor<8xi32, #blocked> - // COMMON: amdgpu.buffer_store %[[loaded]], %arg1[%{{.*}}] + // COMMON: amdgpu.buffer_store tt.store %10, %8 : tensor<8x!tt.ptr, #blocked> tt.return } @@ -490,12 +507,15 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { %4 = arith.trunci %3 : tensor<8xi64, #blocked> to tensor<8xi32, #blocked> %5 = tt.splat %arg0 : !tt.ptr -> tensor<8x!tt.ptr, #blocked> %6 = tt.addptr %5, %4 : tensor<8x!tt.ptr, #blocked>, tensor<8xi32, #blocked> - // COMMON: %[[loaded:.*]] = amdgpu.buffer_load %arg0[%{{.*}}] + // Note: we cannot prove that the value range of elmtIdx lies in [0, 1G]. + // The test cases traverse_if_2nd, traverse_if_2nd_v2, and traverse_if_2nd_v3 + // cover this purpose better than this case does.
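+ // For illustration, a standalone NumPy sketch (not part of the test) of how the
+ // extui/add/trunci chain can wrap, which is why no useful bound survives:
+ //   import numpy as np
+ //   idx64 = np.arange(8, dtype=np.uint64) + np.uint64(0xFFFFFFFF)
+ //   print(idx64.astype(np.int32))  # [-1  0  1  2  3  4  5  6]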
+ // COMMON-NOT: amdgpu.buffer_load %7 = tt.load %6: tensor<8x!tt.ptr, #blocked> %8 = tt.make_range {end = 8 : i32, start = 0 : i32} : tensor<8xi32, #blocked> %9 = tt.splat %arg1 : !tt.ptr -> tensor<8x!tt.ptr, #blocked> %10 = tt.addptr %9, %8 : tensor<8x!tt.ptr, #blocked>, tensor<8xi32, #blocked> - // COMMON: amdgpu.buffer_store %[[loaded]], %arg1[%{{.*}}] + // COMMON: amdgpu.buffer_store tt.store %10, %7 : tensor<8x!tt.ptr, #blocked> tt.return } @@ -505,8 +525,52 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { #blocked = #ttg.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { - // COMMON-LABEL: traverse_if - tt.func @traverse_if(%arg0: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg2 : i32, %arg3 : i32) { + // COMMON-LABEL: traverse_if_2nd + tt.func @traverse_if_2nd(%arg0: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg2 : i32, %arg3 : i32) { + %c0_i32 = arith.constant 0 : i32 + %c2_i32 = arith.constant 2 : i32 + %c5_i32 = arith.constant 7 : i32 + %c7_i32 = arith.constant 5 : i32 + %zeros = arith.constant dense<0> : tensor<8xi32, #blocked> + %0 = arith.extui %arg2 : i32 to i64 + %1 = arith.remui %arg2, %c2_i32 : i32 + %2 = arith.cmpi eq, %1, %c0_i32 : i32 + %3, %4 = scf.if %2 -> (tensor<8xi64, #blocked>, tensor<8xi32, #blocked>) { + %20 = tt.make_range {end = 8 : i32, start = 0 : i32} : tensor<8xi32, #blocked> + %21 = arith.extui %20 : tensor<8xi32, #blocked> to tensor<8xi64, #blocked> + %22 = tt.splat %arg3 : i32 -> tensor<8xi32, #blocked> + %23 = arith.extui %22 : tensor<8xi32, #blocked> to tensor<8xi64, #blocked> + %24 = arith.addi %21, %23 : tensor<8xi64, #blocked> + %25 = tt.make_range {end = 9 : i32, start = 1 : i32} : tensor<8xi32, #blocked> + scf.yield %24, %25 : tensor<8xi64, #blocked>, tensor<8xi32, #blocked> + } else { + %30 = tt.make_range {end = 16 : i32, start = 8 : i32} : tensor<8xi32, #blocked> + %31 = arith.extui %30 : tensor<8xi32, #blocked> to tensor<8xi64, #blocked> + %32 = tt.splat %0 : i64 -> tensor<8xi64, #blocked> + %33 = arith.addi %31, %32 : tensor<8xi64, #blocked> + scf.yield %33, %zeros : tensor<8xi64, #blocked>, tensor<8xi32, #blocked> + } + %5 = arith.trunci %3 : tensor<8xi64, #blocked> to tensor<8xi32, #blocked> + %6 = arith.addi %4, %5 : tensor<8xi32, #blocked> + %7 = tt.splat %arg0 : !tt.ptr -> tensor<8x!tt.ptr, #blocked> + %8 = tt.addptr %7, %6 : tensor<8x!tt.ptr, #blocked>, tensor<8xi32, #blocked> + // COMMON-NOT: amdgpu.buffer_load + %9 = tt.load %8: tensor<8x!tt.ptr, #blocked> + %10 = tt.make_range {end = 8 : i32, start = 0 : i32} : tensor<8xi32, #blocked> + %11 = tt.splat %arg1 : !tt.ptr -> tensor<8x!tt.ptr, #blocked> + %12 = tt.addptr %11, %10 : tensor<8x!tt.ptr, #blocked>, tensor<8xi32, #blocked> + // COMMON: amdgpu.buffer_store + tt.store %12, %9 : tensor<8x!tt.ptr, #blocked> + tt.return + } +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { + // COMMON-LABEL: traverse_if_2nd_v2 + tt.func @traverse_if_2nd_v2(%arg0: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg2 : i32, %arg3 : i32) {
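+ // The assumes below bound arg2/arg3 so that the byte offset provably stays inside
+ // the 2G buffer window. A quick sanity check of those constants (plain Python, G
+ // being an illustrative name; bf16 elements, so elmtIdx must stay <= 2**30 - 1):
+ //   G = 1 << 30
+ //   assert G - 1 - 8 - 7 == 1073741808   # upper bound used for arg3
+ //   assert G - 1 - 15 - 8 == 1073741800  # upper bound used for arg2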
%c0_i32 = arith.constant 0 : i32 %c2_i32 = arith.constant 2 : i32 %c5_i32 = arith.constant 7 : i32 @@ -534,6 +598,22 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { %6 = arith.addi %4, %5 : tensor<8xi32, #blocked> %7 = tt.splat %arg0 : !tt.ptr -> tensor<8x!tt.ptr, #blocked> %8 = tt.addptr %7, %6 : tensor<8x!tt.ptr, #blocked>, tensor<8xi32, #blocked> + + // Note: + // elmtIdx = %6 = %4 + %5, value-range(%4) = [0,7], value-range(%5) = [0, umax] + // %5 = max([0,8] + arg3, [8,16) + arg2), to make %6 * sizeof(bf16) <= 2G - 2byte + // arg3 ∈ [0, 1G-1-8-7 = 1073741808), arg2 ∈ [-8, 1G-1-15-8=1073741800] + %cmp1 = arith.cmpi sge, %arg2, %c0_i32 : i32 + llvm.intr.assume %cmp1 : i1 + %cmp2 = arith.cmpi sge, %arg3, %c0_i32 : i32 + llvm.intr.assume %cmp2 : i1 + %arg_up2 = arith.constant 1073741800 : i32 + %arg_up3 = arith.constant 1073741808 : i32 + %cmp3 = arith.cmpi slt, %arg2, %arg_up2 : i32 + %cmp4 = arith.cmpi slt, %arg3, %arg_up3 : i32 + llvm.intr.assume %cmp3 : i1 + llvm.intr.assume %cmp4 : i1 + // COMMON: %[[loaded:.*]] = amdgpu.buffer_load %arg0[%{{.*}}] %9 = tt.load %8: tensor<8x!tt.ptr, #blocked> %10 = tt.make_range {end = 8 : i32, start = 0 : i32} : tensor<8xi32, #blocked> @@ -547,6 +627,68 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // ----- +#blocked = #ttg.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { + // COMMON-LABEL: traverse_if_2nd_v3 + tt.func @traverse_if_2nd_v3(%arg0: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg2 : i32, %arg3 : i32) { + %c0_i32 = arith.constant 0 : i32 + %c2_i32 = arith.constant 2 : i32 + %c5_i32 = arith.constant 7 : i32 + %c7_i32 = arith.constant 5 : i32 + %zeros = arith.constant dense<0> : tensor<8xi32, #blocked> + %0 = arith.extui %arg2 : i32 to i64 + %1 = arith.remui %arg2, %c2_i32 : i32 + %2 = arith.cmpi eq, %1, %c0_i32 : i32 + %3, %4 = scf.if %2 -> (tensor<8xi64, #blocked>, tensor<8xi32, #blocked>) { + %20 = tt.make_range {end = 8 : i32, start = 0 : i32} : tensor<8xi32, #blocked> + %21 = arith.extui %20 : tensor<8xi32, #blocked> to tensor<8xi64, #blocked> + %22 = tt.splat %arg3 : i32 -> tensor<8xi32, #blocked> + %23 = arith.extui %22 : tensor<8xi32, #blocked> to tensor<8xi64, #blocked> + %24 = arith.addi %21, %23 : tensor<8xi64, #blocked> + %25 = tt.make_range {end = 9 : i32, start = 1 : i32} : tensor<8xi32, #blocked> + scf.yield %24, %25 : tensor<8xi64, #blocked>, tensor<8xi32, #blocked> + } else { + %30 = tt.make_range {end = 16 : i32, start = 8 : i32} : tensor<8xi32, #blocked> + %31 = arith.extui %30 : tensor<8xi32, #blocked> to tensor<8xi64, #blocked> + %32 = tt.splat %0 : i64 -> tensor<8xi64, #blocked> + %33 = arith.addi %31, %32 : tensor<8xi64, #blocked> + scf.yield %33, %zeros : tensor<8xi64, #blocked>, tensor<8xi32, #blocked> + } + %5 = arith.trunci %3 : tensor<8xi64, #blocked> to tensor<8xi32, #blocked> + %6 = arith.addi %4, %5 : tensor<8xi32, #blocked> + %7 = tt.splat %arg0 : !tt.ptr -> tensor<8x!tt.ptr, #blocked> + %8 = tt.addptr %7, %6 : tensor<8x!tt.ptr, #blocked>, tensor<8xi32, #blocked> + + // Note: + // elmtIdx = %6 = %4 + %5, value-range(%4) = [0,7], value-range(%5) = [0, umax] + // %5 = max([0,8] + arg3, [8,16) + arg2), to make %6 * sizeof(bf16) <= 2G - 2byte + // arg3 ∈ [0, 1G-1-8-7 = 1073741808), arg2 ∈ [-8, 1G-1-15-8=1073741800] + %cmp1 = arith.cmpi sge, %arg2, 
%c0_i32 : i32 + llvm.intr.assume %cmp1 : i1 + %cmp2 = arith.cmpi sge, %arg3, %c0_i32 : i32 + llvm.intr.assume %cmp2 : i1 + // the only difference between traverse_if_2nd_v3 and traverse_if_2nd_v2 + // is arg_up2. In v3 the upper bound is bumped by 1. + %arg_up2 = arith.constant 1073741801 : i32 + %arg_up3 = arith.constant 1073741808 : i32 + %cmp3 = arith.cmpi slt, %arg2, %arg_up2 : i32 + %cmp4 = arith.cmpi slt, %arg3, %arg_up3 : i32 + llvm.intr.assume %cmp3 : i1 + llvm.intr.assume %cmp4 : i1 + + // COMMON-NOT: amdgpu.buffer_load + %9 = tt.load %8: tensor<8x!tt.ptr, #blocked> + %10 = tt.make_range {end = 8 : i32, start = 0 : i32} : tensor<8xi32, #blocked> + %11 = tt.splat %arg1 : !tt.ptr -> tensor<8x!tt.ptr, #blocked> + %12 = tt.addptr %11, %10 : tensor<8x!tt.ptr, #blocked>, tensor<8xi32, #blocked> + // COMMON: amdgpu.buffer_store + tt.store %12, %9 : tensor<8x!tt.ptr, #blocked> + tt.return + } +} + +// ----- + #blocked = #ttg.blocked<{sizePerThread = [2], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}> module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} { // COMMON-LABEL: atomic_add_bf16 @@ -589,7 +731,9 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { %5 = tt.addptr %arg0, %1 : !tt.ptr, i32 %6 = tt.splat %5 : !tt.ptr -> tensor<1024x!tt.ptr, #blocked> %7 = tt.addptr %6, %4 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> - // COMMON: %[[loaded:.*]] = amdgpu.buffer_atomic_rmw fadd, acq_rel, gpu, %arg1, %[[scalar_ptr]][%[[offset]]] + // Note: the large tensor is accessed, offset is in the range of [0, smax]. + // without tl.assume the range would be [-128, smax] + // COMMON-NOT: amdgpu.buffer_atomic_rmw %8 = tt.atomic_rmw fadd, acq_rel, gpu, %7, %arg1 : (tensor<1024x!tt.ptr, #blocked>, tensor<1024xf32, #blocked>) -> tensor<1024xf32, #blocked> tt.return %8 : tensor<1024xf32, #blocked> } @@ -629,6 +773,13 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.targ // COMMON: %[[bcast0:.*]] = tt.broadcast %[[#]] // COMMON: %[[offset:.*]] = arith.addi %[[bcast0]], %[[bcast1]] + // Note: offset(i.e. 
elmtIdx) = bcast0 + bcast1 + // = arange(0, 64) + arg6 * arange(0, 256) + // to make elmtIdx * sizeof(f16) ∈ [0, 2G], arg6 must be in [0, 4210752] + %arg6_up = arith.constant 4210752: i32 + %cmp2 = arith.cmpi slt, %arg6, %arg6_up : i32 + llvm.intr.assume %cmp2 : i1 + // COMMON: %[[buffer:.*]] = amdgpu.buffer_load_to_local %[[ptr]][%[[offset]]] stride = %arg[[#stride]] into %arg10 %12 = ttg.async_copy_global_to_local %11, %arg10 : tensor<256x64x!tt.ptr, #blocked> -> <256x64xf16, #shared, #smem, mutable> @@ -644,10 +795,10 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.targ // COMMON: %[[buffer:.*]] = amdgpu.buffer_load_to_local %[[ptr]][%[[offset]]] mask = %arg11 other = %arg12 stride = %arg[[#stride]] cacheModifier = ca into %arg10 %16 = ttg.async_copy_global_to_local %11, %arg10 mask %arg11 other %arg12 cacheModifier = ca: tensor<256x64x!tt.ptr, #blocked> -> <256x64xf16, #shared, #smem, mutable> - // COMMON: %[[buffer:.*]] = amdgpu.buffer_load_to_local %[[ptr]][%[[offset]]] mask = %arg11 other = %arg12 stride = %arg[[#stride]] cacheModifier = cg into %arg10 + // COMMONx: %[[buffer:.*]] = amdgpu.buffer_load_to_local %[[ptr]][%[[offset]]] mask = %arg11 other = %arg12 stride = %arg[[#stride]] cacheModifier = cg into %arg10 %17 = ttg.async_copy_global_to_local %11, %arg10 mask %arg11 other %arg12 cacheModifier = cg: tensor<256x64x!tt.ptr, #blocked> -> <256x64xf16, #shared, #smem, mutable> - // COMMON: %[[buffer:.*]] = amdgpu.buffer_load_to_local %[[ptr]][%[[offset]]] mask = %arg11 other = %arg12 stride = %arg[[#stride]] cacheModifier = cv into %arg10 + // COMMONx: %[[buffer:.*]] = amdgpu.buffer_load_to_local %[[ptr]][%[[offset]]] mask = %arg11 other = %arg12 stride = %arg[[#stride]] cacheModifier = cv into %arg10 %18 = ttg.async_copy_global_to_local %11, %arg10 mask %arg11 other %arg12 cacheModifier = cv: tensor<256x64x!tt.ptr, #blocked> -> <256x64xf16, #shared, #smem, mutable> tt.return } diff --git a/test/TritonGPU/amd/amd-range-analysis.mlir b/test/TritonGPU/amd/amd-range-analysis.mlir index 2c40188e1c44..5d975534e948 100644 --- a/test/TritonGPU/amd/amd-range-analysis.mlir +++ b/test/TritonGPU/amd/amd-range-analysis.mlir @@ -97,14 +97,14 @@ module attributes {"ttg.num-warps" = 4 : i32} { %1 = arith.muli %0, %c1024_i32 : i32 %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> %3 = tt.addptr %arg0, %1 : !tt.ptr, i32 - // expected-remark@+2 {{unsigned : [0, 1024] signed : [0, 1024]}} + // expected-remark@+2 {{unsigned : [0, 1023] signed : [0, 1023]}} // expected-remark@+1 {{non-neg}} %4 = arith.extsi %2 : tensor<1024xi32> to tensor<1024xi64> %5 = tt.addptr %3, %1 : !tt.ptr, i32 - // expected-remark@+2 {{unsigned : [0, 1024] signed : [0, 1024]}} + // expected-remark@+2 {{unsigned : [0, 1023] signed : [0, 1023]}} // expected-remark@+1 {{non-neg}} %6 = arith.extsi %2 : tensor<1024xi32> to tensor<1024xi64> - // expected-remark@+2 {{unsigned : [0, 2048] signed : [0, 2048]}} + // expected-remark@+2 {{unsigned : [0, 2046] signed : [0, 2046]}} // expected-remark@+1 {{non-neg}} %7 = arith.addi %6, %4 : tensor<1024xi64> %8 = tt.splat %5 : !tt.ptr -> tensor<1024x!tt.ptr> @@ -132,7 +132,7 @@ module attributes {"ttg.num-warps" = 4 : i32} { %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> %3 = tt.addptr %arg0, %1 : !tt.ptr, i32 %4 = tt.addptr %3, %1 : !tt.ptr, i32 - // expected-remark@+2 {{unsigned : [0, 2048] signed : [0, 2048]}} + // expected-remark@+2 {{unsigned : [0, 2046] signed : [0, 2046]}} // expected-remark@+1 
{{non-neg}} %5 = arith.addi %2, %2 : tensor<1024xi32> %6 = tt.splat %4 : !tt.ptr -> tensor<1024x!tt.ptr> @@ -162,18 +162,18 @@ module attributes {"ttg.num-warps" = 4 : i32} { %1 = arith.muli %0, %c1024_i32 : i32 %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> %3 = tt.addptr %arg0, %1 : !tt.ptr, i32 - // expected-remark@+2 {{unsigned : [0, 1024] signed : [0, 1024]}} + // expected-remark@+2 {{unsigned : [0, 1023] signed : [0, 1023]}} // expected-remark@+1 {{non-neg}} %4 = arith.extsi %2 : tensor<1024xi32> to tensor<1024xi64> - // expected-remark@+3 {{result 1: unsigned : [0, 131072] signed : [0, 131072]}} + // expected-remark@+3 {{result 1: unsigned : [0, 130944] signed : [0, 130944]}} // expected-remark@+2 {{result 1: non-neg}} // expected-remark@+1 {{inferred total trip count: 128}} %5:3 = scf.for %arg2 = %c0 to %c128 step %c1 iter_args(%arg3 = %3, %arg4 = %4, %arg5 = %arg1) -> (!tt.ptr, tensor<1024xi64>, tensor<1024xf32>) { %12 = tt.addptr %arg3, %1 : !tt.ptr, i32 - // expected-remark@+2 {{unsigned : [0, 1024] signed : [0, 1024]}} + // expected-remark@+2 {{unsigned : [0, 1023] signed : [0, 1023]}} // expected-remark@+1 {{non-neg}} %13 = arith.extsi %2 : tensor<1024xi32> to tensor<1024xi64> - // expected-remark@+2 {{unsigned : [0, 132096] signed : [0, 132096]}} + // expected-remark@+2 {{unsigned : [0, 131967] signed : [0, 131967]}} // expected-remark@+1 {{non-neg}} %14 = arith.addi %13, %arg4 : tensor<1024xi64> %15 = tt.splat %12 : !tt.ptr -> tensor<1024x!tt.ptr> @@ -183,10 +183,10 @@ module attributes {"ttg.num-warps" = 4 : i32} { scf.yield %12, %14, %18 : !tt.ptr, tensor<1024xi64>, tensor<1024xf32> } %6 = tt.addptr %5#0, %1 : !tt.ptr, i32 - // expected-remark@+2 {{unsigned : [0, 1024] signed : [0, 1024]}} + // expected-remark@+2 {{unsigned : [0, 1023] signed : [0, 1023]}} // expected-remark@+1 {{non-neg}} %7 = arith.extsi %2 : tensor<1024xi32> to tensor<1024xi64> - // expected-remark@+2 {{unsigned : [0, 132096] signed : [0, 132096]}} + // expected-remark@+2 {{unsigned : [0, 131967] signed : [0, 131967]}} // expected-remark@+1 {{non-neg}} %8 = arith.addi %7, %5#1 : tensor<1024xi64> %9 = tt.splat %6 : !tt.ptr -> tensor<1024x!tt.ptr> @@ -216,15 +216,15 @@ module attributes {"ttg.num-warps" = 4 : i32} { // expected-remark@+1 {{non-neg}} %1 = arith.muli %0, %c1024_i32 : i32 %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> - // expected-remark@+3 {{result 1: unsigned : [0, 130048] signed : [0, 130048]}} + // expected-remark@+3 {{result 1: unsigned : [0, 129921] signed : [0, 129921]}} // expected-remark@+2 {{result 1: non-neg}} // expected-remark@+1 {{inferred total trip count: 128}} %3:3 = scf.for %arg2 = %c0 to %c128 step %c1 iter_args(%arg3 = %arg0, %arg4 = %cst, %arg5 = %arg1) -> (!tt.ptr, tensor<1024xi64>, tensor<1024xf32>) { %10 = tt.addptr %arg3, %1 : !tt.ptr, i32 - // expected-remark@+2 {{unsigned : [0, 1024] signed : [0, 1024]}} + // expected-remark@+2 {{unsigned : [0, 1023] signed : [0, 1023]}} // expected-remark@+1 {{non-neg}} %11 = arith.extsi %2 : tensor<1024xi32> to tensor<1024xi64> - // expected-remark@+2 {{unsigned : [0, 131072] signed : [0, 131072]}} + // expected-remark@+2 {{unsigned : [0, 130944] signed : [0, 130944]}} // expected-remark@+1 {{non-neg}} %12 = arith.addi %11, %arg4 : tensor<1024xi64> %13 = tt.splat %10 : !tt.ptr -> tensor<1024x!tt.ptr> @@ -234,10 +234,10 @@ module attributes {"ttg.num-warps" = 4 : i32} { scf.yield %10, %12, %16 : !tt.ptr, tensor<1024xi64>, tensor<1024xf32> } %4 = tt.addptr %3#0, %1 : !tt.ptr, 
i32 - // expected-remark@+2 {{unsigned : [0, 1024] signed : [0, 1024]}} + // expected-remark@+2 {{unsigned : [0, 1023] signed : [0, 1023]}} // expected-remark@+1 {{non-neg}} %5 = arith.extsi %2 : tensor<1024xi32> to tensor<1024xi64> - // expected-remark@+2 {{unsigned : [0, 131072] signed : [0, 131072]}} + // expected-remark@+2 {{unsigned : [0, 130944] signed : [0, 130944]}} // expected-remark@+1 {{non-neg}} %6 = arith.addi %5, %3#1 : tensor<1024xi64> %7 = tt.splat %4 : !tt.ptr -> tensor<1024x!tt.ptr> @@ -267,19 +267,19 @@ module attributes {"ttg.num-warps" = 4 : i32} { // expected-remark@+1 {{non-neg}} %1 = arith.muli %0, %c1024_i32 : i32 %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> - // expected-remark@+3 {{result 1: unsigned : [0, 15360] signed : [0, 15360]}} + // expected-remark@+3 {{result 1: unsigned : [0, 15345] signed : [0, 15345]}} // expected-remark@+2 {{result 1: non-neg}} // expected-remark@+1 {{inferred total trip count: 16}} %3:3 = scf.for %arg2 = %c0 to %c16 step %c1 iter_args(%arg3 = %arg0, %arg4 = %cst, %arg5 = %arg1) -> (!tt.ptr, tensor<1024xi64>, tensor<1024xf32>) { - // expected-remark@+3 {{result 1: unsigned : [0, 261120] signed : [0, 261120]}} + // expected-remark@+3 {{result 1: unsigned : [0, 260865] signed : [0, 260865]}} // expected-remark@+2 {{result 1: non-neg}} // expected-remark@+1 {{inferred total trip count: 256}} %10:3 = scf.for %arg6 = %c0 to %c16 step %c1 iter_args(%arg7 = %arg3, %arg8 = %arg4, %arg9 = %arg5) -> (!tt.ptr, tensor<1024xi64>, tensor<1024xf32>) { %11 = tt.addptr %arg7, %1 : !tt.ptr, i32 - // expected-remark@+2 {{unsigned : [0, 1024] signed : [0, 1024]}} + // expected-remark@+2 {{unsigned : [0, 1023] signed : [0, 1023]}} // expected-remark@+1 {{non-neg}} %12 = arith.extsi %2 : tensor<1024xi32> to tensor<1024xi64> - // expected-remark@+2 {{unsigned : [0, 262144] signed : [0, 262144]}} + // expected-remark@+2 {{unsigned : [0, 261888] signed : [0, 261888]}} // expected-remark@+1 {{non-neg}} %13 = arith.addi %12, %arg8 : tensor<1024xi64> %14 = tt.splat %11 : !tt.ptr -> tensor<1024x!tt.ptr> @@ -291,10 +291,10 @@ module attributes {"ttg.num-warps" = 4 : i32} { scf.yield %10#0, %10#1, %10#2 : !tt.ptr, tensor<1024xi64>, tensor<1024xf32> } %4 = tt.addptr %3#0, %1 : !tt.ptr, i32 - // expected-remark@+2 {{unsigned : [0, 1024] signed : [0, 1024]}} + // expected-remark@+2 {{unsigned : [0, 1023] signed : [0, 1023]}} // expected-remark@+1 {{non-neg}} %5 = arith.extsi %2 : tensor<1024xi32> to tensor<1024xi64> - // expected-remark@+2 {{unsigned : [0, 16384] signed : [0, 16384]}} + // expected-remark@+2 {{unsigned : [0, 16368] signed : [0, 16368]}} // expected-remark@+1 {{non-neg}} %6 = arith.addi %5, %3#1 : tensor<1024xi64> %7 = tt.splat %4 : !tt.ptr -> tensor<1024x!tt.ptr> @@ -331,7 +331,7 @@ module attributes {"ttg.num-warps" = 4 : i32} { // expected-remark@+1 {{inferred total trip count: 16384}} %10:3 = scf.for %arg6 = %c0 to %c128 step %c1 iter_args(%arg7 = %arg3, %arg8 = %arg4, %arg9 = %arg5) -> (!tt.ptr, tensor<1024xi64>, tensor<1024xf32>) { %11 = tt.addptr %arg7, %1 : !tt.ptr, i32 - // expected-remark@+2 {{unsigned : [0, 1024] signed : [0, 1024]}} + // expected-remark@+2 {{unsigned : [0, 1023] signed : [0, 1023]}} // expected-remark@+1 {{non-neg}} %12 = arith.extsi %2 : tensor<1024xi32> to tensor<1024xi64> // expected-remark@+1 {{unsigned : [0, 18446744073709551615] signed : [-9223372036854775808, 9223372036854775807]}} @@ -345,7 +345,7 @@ module attributes {"ttg.num-warps" = 4 : i32} { scf.yield %10#0, %10#1, %10#2 : !tt.ptr, 
tensor<1024xi64>, tensor<1024xf32> } %4 = tt.addptr %3#0, %1 : !tt.ptr, i32 - // expected-remark@+2 {{unsigned : [0, 1024] signed : [0, 1024]}} + // expected-remark@+2 {{unsigned : [0, 1023] signed : [0, 1023]}} // expected-remark@+1 {{non-neg}} %5 = arith.extsi %2 : tensor<1024xi32> to tensor<1024xi64> // expected-remark@+1 {{unsigned : [0, 18446744073709551615] signed : [-9223372036854775808, 9223372036854775807]}} @@ -375,11 +375,11 @@ module attributes {"ttg.num-warps" = 4 : i32} { // expected-remark@+1 {{non-neg}} %1 = arith.muli %0, %c1024_i32 : i32 %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> - // expected-remark@+2 {{result 1: unsigned : [0, 1024] signed : [0, 1024]}} + // expected-remark@+2 {{result 1: unsigned : [0, 1023] signed : [0, 1023]}} // expected-remark@+1 {{result 1: non-neg}} %3:2 = scf.if %arg2 -> (!tt.ptr, tensor<1024xi64>) { %8 = tt.addptr %arg0, %1 : !tt.ptr, i32 - // expected-remark@+2 {{unsigned : [0, 1024] signed : [0, 1024]}} + // expected-remark@+2 {{unsigned : [0, 1023] signed : [0, 1023]}} // expected-remark@+1 {{non-neg}} %9 = arith.extsi %2 : tensor<1024xi32> to tensor<1024xi64> scf.yield %8, %9 : !tt.ptr, tensor<1024xi64> @@ -387,7 +387,7 @@ module attributes {"ttg.num-warps" = 4 : i32} { %8 = tt.addptr %arg0, %1 : !tt.ptr, i32 scf.yield %8, %cst : !tt.ptr, tensor<1024xi64> } - // expected-remark@+2 {{unsigned : [0, 1024] signed : [0, 1024]}} + // expected-remark@+2 {{unsigned : [0, 1023] signed : [0, 1023]}} // expected-remark@+1 {{non-neg}} %4 = arith.trunci %3#1 : tensor<1024xi64> to tensor<1024xi32> %5 = tt.splat %3#0 : !tt.ptr -> tensor<1024x!tt.ptr> @@ -416,7 +416,7 @@ module attributes {"ttg.num-warps" = 4 : i32} { %1 = arith.muli %0, %c1024_i32 : i32 %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> %3 = tt.addptr %arg0, %1 : !tt.ptr, i32 - // expected-remark@+2 {{unsigned : [0, 1024] signed : [0, 1024]}} + // expected-remark@+2 {{unsigned : [0, 1023] signed : [0, 1023]}} // expected-remark@+1 {{non-neg}} %4 = arith.extsi %2 : tensor<1024xi32> to tensor<1024xi64> cf.cond_br %arg1, ^bb1(%arg0, %cst : !tt.ptr, tensor<1024xi64>), ^bb2(%3, %4 : !tt.ptr, tensor<1024xi64>) @@ -429,7 +429,7 @@ module attributes {"ttg.num-warps" = 4 : i32} { %10 = tt.load %9 : tensor<1024x!tt.ptr> tt.return %10 : tensor<1024xf32> ^bb2(%11: !tt.ptr, %12: tensor<1024xi64>): // pred: ^bb0 - // expected-remark@+2 {{unsigned : [0, 1024] signed : [0, 1024]}} + // expected-remark@+2 {{unsigned : [0, 1023] signed : [0, 1023]}} // expected-remark@+1 {{non-neg}} %13 = arith.trunci %12 : tensor<1024xi64> to tensor<1024xi32> %14 = tt.splat %11 : !tt.ptr -> tensor<1024x!tt.ptr> @@ -483,7 +483,7 @@ module attributes {"ttg.num-warps" = 4 : i32} { %1 = arith.muli %0, %c256_i32 : i32 %2 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> %3 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 1, parent = #blocked}>> - // expected-remark@+2 {{unsigned : [0, 16] signed : [0, 16]}} + // expected-remark@+2 {{unsigned : [0, 15] signed : [0, 15]}} // expected-remark@+1 {{non-neg}} %4 = tt.expand_dims %3 {axis = 1 : i32} : tensor<16xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<16x1xi32, #blocked> // expected-remark@+1 {{unsigned : [0, 4294967295] signed : [-2147483648, 2147483647]}} @@ -492,10 +492,10 @@ module attributes {"ttg.num-warps" = 4 : i32} { %6 = arith.muli %4, %5 : tensor<16x1xi32, #blocked> // expected-remark@+1 {{unsigned : [0, 
4294967295] signed : [-2147483648, 2147483647]}} %7 = tt.broadcast %6 : tensor<16x1xi32, #blocked> -> tensor<16x256xi32, #blocked> - // expected-remark@+2 {{unsigned : [0, 256] signed : [0, 256]}} + // expected-remark@+2 {{unsigned : [0, 255] signed : [0, 255]}} // expected-remark@+1 {{non-neg}} %8 = tt.expand_dims %2 {axis = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x256xi32, #blocked> - // expected-remark@+2 {{unsigned : [0, 256] signed : [0, 256]}} + // expected-remark@+2 {{unsigned : [0, 255] signed : [0, 255]}} // expected-remark@+1 {{non-neg}} %9 = tt.broadcast %8 : tensor<1x256xi32, #blocked> -> tensor<16x256xi32, #blocked> // expected-remark@+1 {{unsigned : [0, 4294967295] signed : [-2147483648, 2147483647]}} @@ -535,7 +535,7 @@ module attributes {"ttg.num-warps" = 4 : i32} { %7 = arith.muli %4, %6 : tensor<128x1xi32, #blocked> %8 = tt.broadcast %7 : tensor<128x1xi32, #blocked> -> tensor<128x16xi32, #blocked> %9 = tt.expand_dims %3 {axis = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x16xi32, #blocked> - // expected-remark@+2 {{unsigned : [0, 16] signed : [0, 16]}} + // expected-remark@+2 {{unsigned : [0, 15] signed : [0, 15]}} // expected-remark@+1 {{non-neg}} %10 = tt.broadcast %9 : tensor<1x16xi32, #blocked> -> tensor<128x16xi32, #blocked> // expected-remark@+1 {{unsigned : [0, 4294967295] signed : [-2147483648, 2147483647]}} @@ -567,14 +567,14 @@ module attributes {"ttg.num-warps" = 4 : i32} { %1 = arith.muli %0, %c1024_i32 : i32 %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> %3 = tt.addptr %arg0, %1 : !tt.ptr, i32 - // expected-remark@+2 {{unsigned : [0, 1024] signed : [0, 1024]}} + // expected-remark@+2 {{unsigned : [0, 1023] signed : [0, 1023]}} // expected-remark@+1 {{non-neg}} %4 = arith.extsi %2 : tensor<1024xi32> to tensor<1024xi64> %5 = arith.select %arg1, %arg0, %3 : !tt.ptr - // expected-remark@+2 {{unsigned : [0, 1024] signed : [0, 1024]}} + // expected-remark@+2 {{unsigned : [0, 1023] signed : [0, 1023]}} // expected-remark@+1 {{non-neg}} %6 = arith.select %arg1, %cst, %4 : tensor<1024xi64> - // expected-remark@+2 {{unsigned : [0, 1024] signed : [0, 1024]}} + // expected-remark@+2 {{unsigned : [0, 1023] signed : [0, 1023]}} // expected-remark@+1 {{non-neg}} %7 = arith.trunci %6 : tensor<1024xi64> to tensor<1024xi32> %8 = tt.splat %5 : !tt.ptr -> tensor<1024x!tt.ptr> @@ -630,24 +630,24 @@ module attributes {"ttg.num-warps" = 4 : i32} { llvm.intr.assume %cmpule_pid : i1 %1 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> %2 = tt.addptr %arg0, %0 : !tt.ptr, i32 - // expected-remark@+2 {{unsigned : [0, 1024] signed : [0, 1024]}} + // expected-remark@+2 {{unsigned : [0, 1023] signed : [0, 1023]}} // expected-remark@+1 {{non-neg}} %3 = arith.extsi %1 : tensor<1024xi32> to tensor<1024xi64> - // expected-remark@+3 {{result 1: unsigned : [0, 131072] signed : [0, 131072]}} + // expected-remark@+3 {{result 1: unsigned : [0, 130944] signed : [0, 130944]}} // expected-remark@+2 {{result 1: non-neg}} // expected-remark@+1 {{inferred total trip count: 128}} %4:3 = scf.for %arg2 = %c0 to %c128 step %c1 iter_args(%arg3 = %2, %arg4 = %3, %arg5 = %arg1) -> (!tt.ptr, tensor<1024xi64>, tensor<1024xf32>) { - // expected-remark@+2 {{unsigned : [0, 131072] signed : [0, 131072]}} + // expected-remark@+2 {{unsigned : [0, 130944] signed : [0, 130944]}} // expected-remark@+1 {{non-neg}} %11 = arith.trunci %arg4 : tensor<1024xi64> to tensor<1024xi32> %12 = tt.splat %arg3 : 
!tt.ptr -> tensor<1024x!tt.ptr> %13 = tt.addptr %12, %11 : tensor<1024x!tt.ptr>, tensor<1024xi32> %14 = tt.load %13 : tensor<1024x!tt.ptr> %15 = tt.addptr %arg3, %0 : !tt.ptr, i32 - // expected-remark@+2 {{unsigned : [0, 1024] signed : [0, 1024]}} + // expected-remark@+2 {{unsigned : [0, 1023] signed : [0, 1023]}} // expected-remark@+1 {{non-neg}} %16 = arith.extsi %1 : tensor<1024xi32> to tensor<1024xi64> - // expected-remark@+2 {{unsigned : [0, 132096] signed : [0, 132096]}} + // expected-remark@+2 {{unsigned : [0, 131967] signed : [0, 131967]}} // expected-remark@+1 {{non-neg}} %17 = arith.addi %16, %arg4 : tensor<1024xi64> %18 = tt.addptr %15, %0 : !tt.ptr, i32 @@ -655,10 +655,10 @@ module attributes {"ttg.num-warps" = 4 : i32} { scf.yield %18, %17, %19 : !tt.ptr, tensor<1024xi64>, tensor<1024xf32> } {tt.divisibility_arg1 = dense<16> : tensor<1xi32>, tt.divisibility_arg2 = dense<16> : tensor<1xi32>} %5 = tt.addptr %4#0, %0 : !tt.ptr, i32 - // expected-remark@+2 {{unsigned : [0, 1024] signed : [0, 1024]}} + // expected-remark@+2 {{unsigned : [0, 1023] signed : [0, 1023]}} // expected-remark@+1 {{non-neg}} %6 = arith.extsi %1 : tensor<1024xi32> to tensor<1024xi64> - // expected-remark@+2 {{unsigned : [0, 132096] signed : [0, 132096]}} + // expected-remark@+2 {{unsigned : [0, 131967] signed : [0, 131967]}} // expected-remark@+1 {{non-neg}} %7 = arith.addi %6, %4#1 : tensor<1024xi64> %8 = tt.splat %5 : !tt.ptr -> tensor<1024x!tt.ptr> @@ -744,24 +744,24 @@ module attributes {"ttg.num-warps" = 4 : i32} { %1 = arith.muli %0, %c1024_i32 : i32 %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> %3 = tt.addptr %arg0, %1 : !tt.ptr, i32 - // expected-remark@+2 {{unsigned : [0, 1024] signed : [0, 1024]}} + // expected-remark@+2 {{unsigned : [0, 1023] signed : [0, 1023]}} // expected-remark@+1 {{non-neg}} %4 = arith.extsi %2 : tensor<1024xi32> to tensor<1024xi64> %5 = tt.addptr %arg0, %1 : !tt.ptr, i32 - // expected-remark@+2 {{unsigned : [0, 1024] signed : [0, 1024]}} + // expected-remark@+2 {{unsigned : [0, 1023] signed : [0, 1023]}} // expected-remark@+1 {{non-neg}} %6 = arith.extsi %2 : tensor<1024xi32> to tensor<1024xi64> - // expected-remark@+5 {{result 1: unsigned : [0, 131072] signed : [0, 131072]}} - // expected-remark@+4 {{result 3: unsigned : [0, 131072] signed : [0, 131072]}} + // expected-remark@+5 {{result 1: unsigned : [0, 130944] signed : [0, 130944]}} + // expected-remark@+4 {{result 3: unsigned : [0, 130944] signed : [0, 130944]}} // expected-remark@+3 {{result 1: non-neg}} // expected-remark@+2 {{result 3: non-neg}} // expected-remark@+1 {{inferred total trip count: 128}} %7:5 = scf.for %arg2 = %c0 to %c128 step %c1 iter_args(%arg3 = %5, %arg4 = %6, %arg5 = %3, %arg6 = %4, %arg7 = %arg1) -> (!tt.ptr, tensor<1024xi64>, !tt.ptr, tensor<1024xi64>, tensor<1024xf32>) { %14 = tt.addptr %arg5, %1 : !tt.ptr, i32 - // expected-remark@+2 {{unsigned : [0, 1024] signed : [0, 1024]}} + // expected-remark@+2 {{unsigned : [0, 1023] signed : [0, 1023]}} // expected-remark@+1 {{non-neg}} %15 = arith.extsi %2 : tensor<1024xi32> to tensor<1024xi64> - // expected-remark@+2 {{unsigned : [0, 132096] signed : [0, 132096]}} + // expected-remark@+2 {{unsigned : [0, 131967] signed : [0, 131967]}} // expected-remark@+1 {{non-neg}} %16 = arith.addi %15, %arg6 : tensor<1024xi64> %17 = tt.splat %14 : !tt.ptr -> tensor<1024x!tt.ptr> @@ -771,10 +771,10 @@ module attributes {"ttg.num-warps" = 4 : i32} { scf.yield %14, %16, %arg3, %arg4, %20 : !tt.ptr, tensor<1024xi64>, !tt.ptr, 
tensor<1024xi64>, tensor<1024xf32> } %8 = tt.addptr %7#0, %1 : !tt.ptr, i32 - // expected-remark@+2 {{unsigned : [0, 1024] signed : [0, 1024]}} + // expected-remark@+2 {{unsigned : [0, 1023] signed : [0, 1023]}} // expected-remark@+1 {{non-neg}} %9 = arith.extsi %2 : tensor<1024xi32> to tensor<1024xi64> - // expected-remark@+2 {{unsigned : [0, 132096] signed : [0, 132096]}} + // expected-remark@+2 {{unsigned : [0, 131967] signed : [0, 131967]}} // expected-remark@+1 {{non-neg}} %10 = arith.addi %9, %7#1 : tensor<1024xi64> %11 = tt.splat %8 : !tt.ptr -> tensor<1024x!tt.ptr> @@ -804,24 +804,24 @@ module attributes {"ttg.num-warps" = 4 : i32} { %1 = arith.muli %0, %c1024_i32 : i32 %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> %3 = tt.addptr %arg0, %1 : !tt.ptr, i32 - // expected-remark@+2 {{unsigned : [0, 1024] signed : [0, 1024]}} + // expected-remark@+2 {{unsigned : [0, 1023] signed : [0, 1023]}} // expected-remark@+1 {{non-neg}} %4 = arith.extsi %2 : tensor<1024xi32> to tensor<1024xi64> %5 = tt.addptr %arg1, %1 : !tt.ptr, i32 - // expected-remark@+2 {{unsigned : [0, 1024] signed : [0, 1024]}} + // expected-remark@+2 {{unsigned : [0, 1023] signed : [0, 1023]}} // expected-remark@+1 {{non-neg}} %6 = arith.extsi %2 : tensor<1024xi32> to tensor<1024xi64> - // expected-remark@+5 {{result 1: unsigned : [0, 131072] signed : [0, 131072]}} - // expected-remark@+4 {{result 4: unsigned : [0, 131072] signed : [0, 131072]}} + // expected-remark@+5 {{result 1: unsigned : [0, 130944] signed : [0, 130944]}} + // expected-remark@+4 {{result 4: unsigned : [0, 130944] signed : [0, 130944]}} // expected-remark@+3 {{result 1: non-neg}} // expected-remark@+2 {{result 4: non-neg}} // expected-remark@+1 {{inferred total trip count: 128}} %7:6 = scf.for %arg3 = %c0 to %c128 step %c1 iter_args(%arg4 = %3, %arg5 = %4, %arg6 = %arg2, %arg7 = %5, %arg8 = %6, %arg9 = %arg2) -> (!tt.ptr, tensor<1024xi64>, tensor<1024xf32>, !tt.ptr, tensor<1024xi64>, tensor<1024xf32>) { %20 = tt.addptr %arg4, %1 : !tt.ptr, i32 - // expected-remark@+2 {{unsigned : [0, 1024] signed : [0, 1024]}} + // expected-remark@+2 {{unsigned : [0, 1023] signed : [0, 1023]}} // expected-remark@+1 {{non-neg}} %21 = arith.extsi %2 : tensor<1024xi32> to tensor<1024xi64> - // expected-remark@+2 {{unsigned : [0, 132096] signed : [0, 132096]}} + // expected-remark@+2 {{unsigned : [0, 131967] signed : [0, 131967]}} // expected-remark@+1 {{non-neg}} %22 = arith.addi %21, %arg5 : tensor<1024xi64> %23 = tt.splat %20 : !tt.ptr -> tensor<1024x!tt.ptr> @@ -829,10 +829,10 @@ module attributes {"ttg.num-warps" = 4 : i32} { %25 = tt.load %24 : tensor<1024x!tt.ptr> %26 = arith.addf %25, %arg6 : tensor<1024xf32> %27 = tt.addptr %arg7, %1 : !tt.ptr, i32 - // expected-remark@+2 {{unsigned : [0, 1024] signed : [0, 1024]}} + // expected-remark@+2 {{unsigned : [0, 1023] signed : [0, 1023]}} // expected-remark@+1 {{non-neg}} %28 = arith.extsi %2 : tensor<1024xi32> to tensor<1024xi64> - // expected-remark@+2 {{unsigned : [0, 132096] signed : [0, 132096]}} + // expected-remark@+2 {{unsigned : [0, 131967] signed : [0, 131967]}} // expected-remark@+1 {{non-neg}} %29 = arith.addi %28, %arg8 : tensor<1024xi64> %30 = tt.splat %27 : !tt.ptr -> tensor<1024x!tt.ptr> @@ -842,20 +842,20 @@ module attributes {"ttg.num-warps" = 4 : i32} { scf.yield %27, %29, %33, %20, %22, %26 : !tt.ptr, tensor<1024xi64>, tensor<1024xf32>, !tt.ptr, tensor<1024xi64>, tensor<1024xf32> } %8 = tt.addptr %7#0, %1 : !tt.ptr, i32 - // expected-remark@+2 {{unsigned : [0, 1024] signed : 
[0, 1024]}} + // expected-remark@+2 {{unsigned : [0, 1023] signed : [0, 1023]}} // expected-remark@+1 {{non-neg}} %9 = arith.extsi %2 : tensor<1024xi32> to tensor<1024xi64> - // expected-remark@+2 {{unsigned : [0, 132096] signed : [0, 132096]}} + // expected-remark@+2 {{unsigned : [0, 131967] signed : [0, 131967]}} // expected-remark@+1 {{non-neg}} %10 = arith.addi %9, %7#1 : tensor<1024xi64> %11 = tt.splat %8 : !tt.ptr -> tensor<1024x!tt.ptr> %12 = tt.addptr %11, %10 : tensor<1024x!tt.ptr>, tensor<1024xi64> %13 = tt.load %12 : tensor<1024x!tt.ptr> %14 = tt.addptr %7#3, %1 : !tt.ptr, i32 - // expected-remark@+2 {{unsigned : [0, 1024] signed : [0, 1024]}} + // expected-remark@+2 {{unsigned : [0, 1023] signed : [0, 1023]}} // expected-remark@+1 {{non-neg}} %15 = arith.extsi %2 : tensor<1024xi32> to tensor<1024xi64> - // expected-remark@+2 {{unsigned : [0, 132096] signed : [0, 132096]}} + // expected-remark@+2 {{unsigned : [0, 131967] signed : [0, 131967]}} // expected-remark@+1 {{non-neg}} %16 = arith.addi %15, %7#4 : tensor<1024xi64> %17 = tt.splat %14 : !tt.ptr -> tensor<1024x!tt.ptr> @@ -886,14 +886,14 @@ module attributes {"ttg.num-warps" = 4 : i32} { %1 = arith.muli %0, %c1024_i32 : i32 %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> %3 = tt.addptr %arg0, %1 : !tt.ptr, i32 - // expected-remark@+2 {{unsigned : [0, 1024] signed : [0, 1024]}} + // expected-remark@+2 {{unsigned : [0, 1023] signed : [0, 1023]}} // expected-remark@+1 {{non-neg}} %4 = arith.extsi %2 : tensor<1024xi32> to tensor<1024xi64> // expected-remark@+2 {{result 1: unsigned : [0, 18446744073709551615] signed : [-9223372036854775808, 9223372036854775807]}} // expected-remark@+1 {{inferred total trip count: 1025}} %5:3 = scf.for %arg2 = %c0 to %c128 step %K iter_args(%arg3 = %3, %arg4 = %4, %arg5 = %arg1) -> (!tt.ptr, tensor<1024xi64>, tensor<1024xf32>) { %12 = tt.addptr %arg3, %1 : !tt.ptr, i32 - // expected-remark@+2 {{unsigned : [0, 1024] signed : [0, 1024]}} + // expected-remark@+2 {{unsigned : [0, 1023] signed : [0, 1023]}} // expected-remark@+1 {{non-neg}} %13 = arith.extsi %2 : tensor<1024xi32> to tensor<1024xi64> // expected-remark@+1 {{unsigned : [0, 18446744073709551615] signed : [-9223372036854775808, 9223372036854775807]}} @@ -905,7 +905,7 @@ module attributes {"ttg.num-warps" = 4 : i32} { scf.yield %12, %14, %18 : !tt.ptr, tensor<1024xi64>, tensor<1024xf32> } %6 = tt.addptr %5#0, %1 : !tt.ptr, i32 - // expected-remark@+2 {{unsigned : [0, 1024] signed : [0, 1024]}} + // expected-remark@+2 {{unsigned : [0, 1023] signed : [0, 1023]}} // expected-remark@+1 {{non-neg}} %7 = arith.extsi %2 : tensor<1024xi32> to tensor<1024xi64> // expected-remark@+1 {{unsigned : [0, 18446744073709551615] signed : [-9223372036854775808, 9223372036854775807]}} @@ -1235,32 +1235,32 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { tt.func @join_cat_transitive_nonneg(%arg0: !tt.ptr, %arg1: !tt.ptr) { %0 = tt.make_range {end = 8 : i32, start = 0 : i32} : tensor<8xi32> %1 = tt.make_range {end = 10 : i32, start = 2 : i32} : tensor<8xi32> - // expected-remark@+2 {{unsigned : [0, 10] signed : [0, 10]}} + // expected-remark@+2 {{unsigned : [0, 9] signed : [0, 9]}} // expected-remark@+1 {{non-neg}} %2 = tt.join %0, %1 : tensor<8xi32> -> tensor<8x2xi32> %3 = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32> %4 = tt.make_range {end = 8 : i32, start = 4 : i32} : tensor<4xi32> - // expected-remark@+2 {{unsigned : [0, 8] signed : [0, 8]}} + // expected-remark@+2 {{unsigned : [0, 7] 
signed : [0, 7]}} // expected-remark@+1 {{non-neg}} %5 = tt.join %3, %4 : tensor<4xi32> -> tensor<4x2xi32> - // expected-remark@+2 {{unsigned : [0, 8] signed : [0, 8]}} + // expected-remark@+2 {{unsigned : [0, 7] signed : [0, 7]}} // expected-remark@+1 {{non-neg}} %6 = tt.cat %5, %5 : tensor<4x2xi32> -> tensor<8x2xi32> - // expected-remark@+2 {{unsigned : [0, 18] signed : [0, 18]}} + // expected-remark@+2 {{unsigned : [0, 16] signed : [0, 16]}} // expected-remark@+1 {{non-neg}} %7 = arith.addi %2, %6 : tensor<8x2xi32> %zeros = arith.constant dense<0> : tensor<8x1xi32> %ones = arith.constant dense<1> : tensor<8x1xi32> - // expected-remark@+2 {{unsigned : [0, 18] signed : [0, 18]}} + // expected-remark@+2 {{unsigned : [0, 16] signed : [0, 16]}} // expected-remark@+1 {{non-neg}} %8 = tt.gather %7[%zeros] {axis = 1 : i32} : (tensor<8x2xi32>, tensor<8x1xi32>) -> tensor<8x1xi32> - // expected-remark@+2 {{unsigned : [0, 18] signed : [0, 18]}} + // expected-remark@+2 {{unsigned : [0, 16] signed : [0, 16]}} // expected-remark@+1 {{non-neg}} %9 = tt.gather %7[%ones] {axis = 1 : i32} : (tensor<8x2xi32>, tensor<8x1xi32>) -> tensor<8x1xi32> - // expected-remark@+2 {{unsigned : [0, 36] signed : [0, 36]}} + // expected-remark@+2 {{unsigned : [0, 32] signed : [0, 32]}} // expected-remark@+1 {{non-neg}} %10 = arith.addi %8, %9 : tensor<8x1xi32> - // expected-remark@+2 {{unsigned : [0, 36] signed : [0, 36]}} + // expected-remark@+2 {{unsigned : [0, 32] signed : [0, 32]}} // expected-remark@+1 {{non-neg}} %11 = tt.reshape %10 allow_reorder : tensor<8x1xi32> -> tensor<8xi32> tt.return @@ -1316,7 +1316,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // expected-remark@+1 {{non-neg}} %6 = tt.splat %5 : i32 -> tensor<8xi32> %7 = tt.make_range {end = 8 : i32, start = 0 : i32} : tensor<8xi32> - // expected-remark@+1 {{unsigned : [0, 2147483655] signed : [-2147483648, 2147483647]}} + // expected-remark@+1 {{unsigned : [0, 2147483654] signed : [-2147483648, 2147483647]}} %8 = arith.addi %6, %7 : tensor<8xi32> tt.return } @@ -1332,31 +1332,31 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { %1 = tt.expand_dims %0 {axis = 0 : i32} : tensor<16xi32> -> tensor<1x16xi32> %2 = tt.reshape %1 allow_reorder : tensor<1x16xi32> -> tensor<8x2xi32> %3 = tt.reshape %1 allow_reorder : tensor<1x16xi32> -> tensor<2x8xi32> - // expected-remark@+2 {{unsigned : [0, 16] signed : [0, 16]}} + // expected-remark@+2 {{unsigned : [0, 15] signed : [0, 15]}} // expected-remark@+1 {{non-neg}} %4 = tt.trans %3 {order = array} : tensor<2x8xi32> -> tensor<8x2xi32> - // expected-remark@+2 {{unsigned : [0, 16] signed : [0, 16]}} + // expected-remark@+2 {{unsigned : [0, 15] signed : [0, 15]}} // expected-remark@+1 {{non-neg}} %5 = ttg.convert_layout %4 : tensor<8x2xi32> -> tensor<8x2xi32> - // expected-remark@+2 {{unsigned : [0, 32] signed : [0, 32]}} + // expected-remark@+2 {{unsigned : [0, 30] signed : [0, 30]}} // expected-remark@+1 {{non-neg}} %6 = arith.addi %5, %2 : tensor<8x2xi32> %7 = tt.make_range {end = 10 : i32, start = 2 : i32} : tensor<8xi32> - // expected-remark@+2 {{unsigned : [2, 10] signed : [2, 10]}} + // expected-remark@+2 {{unsigned : [2, 9] signed : [2, 9]}} // expected-remark@+1 {{non-neg}} %8 = ttg.convert_layout %7 : tensor<8xi32> -> tensor<8xi32> %9 = tt.expand_dims %8 {axis = 0 : i32} : tensor<8xi32> -> tensor<1x8xi32> %10 = tt.broadcast %9 : tensor<1x8xi32> -> tensor<2x8xi32> %11 = tt.reshape %10 allow_reorder : tensor<2x8xi32> -> tensor<8x2xi32> %12 = tt.splat 
%c10_i32 : i32 -> tensor<8x2xi32> - // expected-remark@+2 {{unsigned : [7, 15] signed : [7, 15]}} + // expected-remark@+2 {{unsigned : [7, 14] signed : [7, 14]}} // expected-remark@+1 {{non-neg}} %13 = arith.addi %11, %12 : tensor<8x2xi32> - // expected-remark@+2 {{unsigned : [0, 15] signed : [0, 15]}} + // expected-remark@+2 {{unsigned : [0, 14] signed : [0, 14]}} // expected-remark@+1 {{non-neg}} %14 = arith.minsi %13, %5 : tensor<8x2xi32> - // expected-remark@+4 {{result 0: unsigned : [2, 10] signed : [2, 10]}} - // expected-remark@+3 {{result 1: unsigned : [2, 10] signed : [2, 10]}} + // expected-remark@+4 {{result 0: unsigned : [2, 9] signed : [2, 9]}} + // expected-remark@+3 {{result 1: unsigned : [2, 9] signed : [2, 9]}} // expected-remark@+2 {{result 0: non-neg}} // expected-remark@+1 {{result 1: non-neg}} %15, %16 = tt.split %11: tensor<8x2xi32> -> tensor<8xi32> @@ -1538,7 +1538,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.targ // expected-remark@+5 {{arg 5: unsigned : [0, 4294967295] signed : [-2147483648, 2147483647]}} // expected-remark@+4 {{arg 6: unsigned : [1, 2147483647] signed : [1, 2147483647]}} // expected-remark@+3 {{arg 7: unsigned : [0, 4294967295] signed : [-2147483648, 2147483647]}} - // expected-remark@+2 {{arg 8: unsigned : [1, 2147483647] signed : [1, 1023]}} + // expected-remark@+2 {{arg 8: unsigned : [1, 1023] signed : [1, 1023]}} // expected-remark@+1 {{arg 9: unsigned : [0, 4294967295] signed : [-2147483648, 2147483647]}} tt.func public @buffer_stride(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: !tt.ptr, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: i32, %arg7: i32, %arg8: i32, %arg9: i32) { %c1024_i32 = arith.constant 1024 : i32 @@ -1546,7 +1546,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.targ %c32_i32 = arith.constant 32 : i32 %c0_i32 = arith.constant 0 : i32 %0 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 1, parent = #blocked}>> - // expected-remark@+2 {{unsigned : [0, 256] signed : [0, 256]}} + // expected-remark@+2 {{unsigned : [0, 255] signed : [0, 255]}} // expected-remark@+1 {{non-neg}} %1 = tt.expand_dims %0 {axis = 1 : i32} : tensor<256xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<256x1xi32, #blocked> // expected-remark@+2 {{unsigned : [1, 1] signed : [-1, -1]}} @@ -1562,10 +1562,10 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.targ // expected-remark@+1 {{unsigned : [0, 4294967295] signed : [-2147483648, 2147483647]}} %5 = tt.broadcast %3 : tensor<256x1xi32, #blocked> -> tensor<256x64xi32, #blocked> %6 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked}>> - // expected-remark@+2 {{unsigned : [0, 64] signed : [0, 64]}} + // expected-remark@+2 {{unsigned : [0, 63] signed : [0, 63]}} // expected-remark@+1 {{non-neg}} %7 = tt.expand_dims %6 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x64xi32, #blocked> - // expected-remark@+2 {{unsigned : [0, 64] signed : [0, 64]}} + // expected-remark@+2 {{unsigned : [0, 63] signed : [0, 63]}} // expected-remark@+1 {{non-neg}} %8 = tt.broadcast %7 : tensor<1x64xi32, #blocked> -> tensor<256x64xi32, #blocked> // expected-remark@+1 {{unsigned : [0, 4294967295] signed : [-2147483648, 2147483647]}} @@ -1575,10 +1575,10 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.targ %12 = tt.load %11 : tensor<256x64x!tt.ptr, #blocked> %13 = tt.make_range {end 
= 256 : i32, start = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 1, parent = #blocked}>> %14 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked}>> - // expected-remark@+2 {{unsigned : [0, 256] signed : [0, 256]}} + // expected-remark@+2 {{unsigned : [0, 255] signed : [0, 255]}} // expected-remark@+1 {{non-neg}} %15 = tt.expand_dims %13 {axis = 1 : i32} : tensor<256xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<256x1xi32, #blocked> - // expected-remark@+2 {{unsigned : [0, 64] signed : [0, 64]}} + // expected-remark@+2 {{unsigned : [0, 63] signed : [0, 63]}} // expected-remark@+1 {{non-neg}} %16 = tt.expand_dims %14 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x64xi32, #blocked> // expected-remark@+2 {{unsigned : [1, 1] signed : [-1, -1]}} @@ -1589,21 +1589,21 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.targ // expected-remark@+1 {{result is true}} %cmp2 = arith.cmpi slt, %arg8, %c1024_i32 : i32 llvm.intr.assume %cmp2 : i1 - // expected-remark@+2 {{unsigned : [1, 2147483647] signed : [1, 1023]}} + // expected-remark@+2 {{unsigned : [1, 1023] signed : [1, 1023]}} // expected-remark@+1 {{non-neg}} %17 = tt.splat %arg8 : i32 -> tensor<256x1xi32, #blocked> - // expected-remark@+2 {{unsigned : [0, 261888] signed : [0, 261888]}} + // expected-remark@+2 {{unsigned : [0, 260865] signed : [0, 260865]}} // expected-remark@+1 {{non-neg}} %18 = arith.muli %17, %15 : tensor<256x1xi32, #blocked> %19 = tt.addptr %arg2, %c48_i32 : !tt.ptr, i32 - // expected-remark@+2 {{unsigned : [0, 261888] signed : [0, 261888]}} + // expected-remark@+2 {{unsigned : [0, 260865] signed : [0, 260865]}} // expected-remark@+1 {{non-neg}} %20 = tt.broadcast %18 : tensor<256x1xi32, #blocked> -> tensor<256x64xi32, #blocked> - // expected-remark@+2 {{unsigned : [0, 64] signed : [0, 64]}} + // expected-remark@+2 {{unsigned : [0, 63] signed : [0, 63]}} // expected-remark@+1 {{non-neg}} %21 = tt.broadcast %16 : tensor<1x64xi32, #blocked> -> tensor<256x64xi32, #blocked> %22 = tt.addptr %19, %c48_i32 : !tt.ptr, i32 - // expected-remark@+2 {{unsigned : [0, 261952] signed : [0, 261952]}} + // expected-remark@+2 {{unsigned : [0, 260928] signed : [0, 260928]}} // expected-remark@+1 {{non-neg}} %23 = arith.addi %21, %20 : tensor<256x64xi32, #blocked> %24 = tt.splat %22 : !tt.ptr -> tensor<256x64x!tt.ptr, #blocked> @@ -1652,7 +1652,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ // expected-remark@+2 {{unsigned : [0, 2097120] signed : [0, 2097120]}} // expected-remark@+1 {{non-neg}} %5 = tt.splat %3 : i32 -> tensor<32xi32, #blocked> - // expected-remark@+2 {{unsigned : [0, 2097152] signed : [0, 2097152]}} + // expected-remark@+2 {{unsigned : [0, 2097151] signed : [0, 2097151]}} // expected-remark@+1 {{non-neg}} %6 = arith.addi %5, %4 : tensor<32xi32, #blocked> // expected-remark@+1 {{unsigned : [0, 4294967295] signed : [-2147483648, 2147483647]}} @@ -1660,13 +1660,13 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ // expected-remark@+1 {{unsigned : [0, 4294967295] signed : [-16777216, 16777215]}} %8 = arith.divsi %7, %c128_i32 : i32 %9 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #blocked> - // expected-remark@+2 {{unsigned : [0, 2097152] signed : [0, 2097152]}} + // expected-remark@+2 {{unsigned : [0, 2097151] signed : [0, 2097151]}} // expected-remark@+1 {{non-neg}} %10 = ttg.convert_layout %6 : 
tensor<32xi32, #blocked> -> tensor<32xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> - // expected-remark@+2 {{unsigned : [0, 2097152] signed : [0, 2097152]}} + // expected-remark@+2 {{unsigned : [0, 2097151] signed : [0, 2097151]}} // expected-remark@+1 {{non-neg}} %11 = tt.expand_dims %10 {axis = 1 : i32} : tensor<32xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<32x1xi32, #blocked1> - // expected-remark@+2 {{unsigned : [0, 2097152] signed : [0, 2097152]}} + // expected-remark@+2 {{unsigned : [0, 2097151] signed : [0, 2097151]}} // expected-remark@+1 {{non-neg}} %12 = ttg.convert_layout %11 : tensor<32x1xi32, #blocked1> -> tensor<32x1xi32, #blocked2> // expected-remark@+1 {{unsigned : [0, 4294967295] signed : [-2147483648, 2147483647]}} @@ -1685,19 +1685,19 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ // expected-remark@+2 {{unsigned : [0, 2147483392] signed : [0, 2147483392]}} // expected-remark@+1 {{non-neg}} %27 = tt.splat %26 : i32 -> tensor<128xi32, #blocked> - // expected-remark@+2 {{unsigned : [0, 2147483520] signed : [0, 2147483520]}} + // expected-remark@+2 {{unsigned : [0, 2147483519] signed : [0, 2147483519]}} // expected-remark@+1 {{non-neg}} %28 = arith.addi %27, %9 : tensor<128xi32, #blocked> - // expected-remark@+2 {{unsigned : [0, 2147483520] signed : [0, 2147483520]}} + // expected-remark@+2 {{unsigned : [0, 2147483519] signed : [0, 2147483519]}} // expected-remark@+1 {{non-neg}} %29 = ttg.convert_layout %28 : tensor<128xi32, #blocked> -> tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked4}>> - // expected-remark@+2 {{unsigned : [0, 2147483520] signed : [0, 2147483520]}} + // expected-remark@+2 {{unsigned : [0, 2147483519] signed : [0, 2147483519]}} // expected-remark@+1 {{non-neg}} %30 = tt.expand_dims %29 {axis = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked4}>> -> tensor<1x128xi32, #blocked4> - // expected-remark@+2 {{unsigned : [0, 2147483520] signed : [0, 2147483520]}} + // expected-remark@+2 {{unsigned : [0, 2147483519] signed : [0, 2147483519]}} // expected-remark@+1 {{non-neg}} %31 = ttg.convert_layout %30 : tensor<1x128xi32, #blocked4> -> tensor<1x128xi32, #blocked3> - // expected-remark@+2 {{unsigned : [0, 2147483520] signed : [0, 2147483520]}} + // expected-remark@+2 {{unsigned : [0, 2147483519] signed : [0, 2147483519]}} // expected-remark@+1 {{non-neg}} %32 = tt.broadcast %31 : tensor<1x128xi32, #blocked3> -> tensor<32x128xi32, #blocked3> %33 = tt.addptr %18, %32 : tensor<32x128x!tt.ptr, #blocked3>, tensor<32x128xi32, #blocked3> @@ -1726,3 +1726,254 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ tt.return } } + +// ----- + +//def scfif_range1(x, y, output_ptr,n_elements, BLOCK_SIZE: tl.constexpr, ): +// tl.assume(y < 100) +// tl.assume(y > 1) +// pid = tl.program_id(axis=0) +// block_start = pid * BLOCK_SIZE +// offsets = block_start + tl.arange(0, BLOCK_SIZE) +// mask = offsets < n_elements +// if x > y: +// z = x + 3 +// else: +// z = y + 4; # to check z in [6, 103] +// z2 = z + 1 # to check z2 in [0, umax]/[smin, smax] +// tl.store(output_ptr + offsets, z2, mask) +// +#blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx1100", "ttg.threads-per-warp" = 32 : i32} { + tt.func public @scfif_range1(%x: i32, %y: i32, %output_ptr: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %n_elements: 
i32 {tt.divisibility = 16 : i32}) { + %c4_i32 = arith.constant 4 : i32 + %c3_i32 = arith.constant 3 : i32 + %c1024_i32 = arith.constant 1024 : i32 + %c1_i32 = arith.constant 1 : i32 + %c100_i32 = arith.constant 100 : i32 + %0 = arith.cmpi slt, %y, %c100_i32 : i32 + llvm.intr.assume %0 : i1 + %1 = arith.cmpi sgt, %y, %c1_i32 : i32 + llvm.intr.assume %1 : i1 + %2 = tt.get_program_id x : i32 + %3 = arith.muli %2, %c1024_i32 : i32 + %4 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked> + %5 = tt.splat %3 : i32 -> tensor<1024xi32, #blocked> + %6 = arith.addi %5, %4 : tensor<1024xi32, #blocked> + %7 = tt.splat %n_elements : i32 -> tensor<1024xi32, #blocked> + %8 = arith.cmpi slt, %6, %7 : tensor<1024xi32, #blocked> + %9 = arith.cmpi sgt, %x, %y : i32 + %10 = scf.if %9 -> (i32) { + %z = arith.addi %x, %c3_i32 : i32 + scf.yield %z : i32 + } else { + // expected-remark@+1 {{unsigned : [6, 103] signed : [6, 103]}} + %z = arith.addi %y, %c4_i32 : i32 + scf.yield %z : i32 + } + // expected-remark@+1 {{unsigned : [0, 4294967295] signed : [-2147483648, 2147483647]}} + %11 = arith.addi %10, %c1_i32 : i32 + %12 = arith.addi %5, %4 : tensor<1024xi32, #blocked> + %13 = arith.sitofp %11 : i32 to f32 + %14 = tt.splat %13 : f32 -> tensor<1024xf32, #blocked> + %15 = tt.splat %output_ptr : !tt.ptr -> tensor<1024x!tt.ptr, #blocked> + %16 = tt.addptr %15, %12 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> + tt.store %16, %14, %8 : tensor<1024x!tt.ptr, #blocked> + tt.return + } +} + +// ----- + +//def scfif_range2(x, y, output_ptr,n_elements, BLOCK_SIZE: tl.constexpr, ): +// tl.assume(y < 100) +// tl.assume(y > 1) +// tl.assume(x < 20) +// tl.assume(x > 0) +// pid = tl.program_id(axis=0) +// block_start = pid * BLOCK_SIZE +// offsets = block_start + tl.arange(0, BLOCK_SIZE) +// mask = offsets < n_elements +// if x > y: +// z = x + 3 // check z in [4, 22] +// else: +// z = y + 4; // check z in [6, 103] +// z2 = z + 1 // check z2 in [5, 104] +// tl.store(output_ptr + offsets, z2, mask) + +#blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx1100", "ttg.threads-per-warp" = 32 : i32} { + tt.func public @scfif_range2(%x: i32, %y: i32, %output_ptr: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %n_elements: i32 {tt.divisibility = 16 : i32}) { + %c4_i32 = arith.constant 4 : i32 + %c3_i32 = arith.constant 3 : i32 + %c1024_i32 = arith.constant 1024 : i32 + %c0_i32 = arith.constant 0 : i32 + %c20_i32 = arith.constant 20 : i32 + %c1_i32 = arith.constant 1 : i32 + %c100_i32 = arith.constant 100 : i32 + %0 = arith.cmpi slt, %y, %c100_i32 : i32 + llvm.intr.assume %0 : i1 + %1 = arith.cmpi sgt, %y, %c1_i32 : i32 + llvm.intr.assume %1 : i1 + %2 = arith.cmpi slt, %x, %c20_i32 : i32 + llvm.intr.assume %2 : i1 + %3 = arith.cmpi sgt, %x, %c0_i32 : i32 + llvm.intr.assume %3 : i1 + %4 = tt.get_program_id x : i32 + %5 = arith.muli %4, %c1024_i32 : i32 + %6 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked> + %7 = tt.splat %5 : i32 -> tensor<1024xi32, #blocked> + %8 = arith.addi %7, %6 : tensor<1024xi32, #blocked> + %9 = tt.splat %n_elements : i32 -> tensor<1024xi32, #blocked> + %10 = arith.cmpi slt, %8, %9 : tensor<1024xi32, #blocked> + %11 = arith.cmpi sgt, %x, %y : i32 + %12 = scf.if %11 -> (i32) { + // expected-remark@+1 {{unsigned : [4, 22] signed : [4, 22]}} + %z = arith.addi %x, %c3_i32 : i32 + 
scf.yield %z : i32 + } else { + // expected-remark@+1 {{unsigned : [6, 103] signed : [6, 103]}} + %z = arith.addi %y, %c4_i32 : i32 + scf.yield %z : i32 + } + // expected-remark@+1 {{unsigned : [5, 104] signed : [5, 104]}} + %13 = arith.addi %12, %c1_i32 : i32 + %14 = arith.addi %7, %6 : tensor<1024xi32, #blocked> + %15 = arith.sitofp %13 : i32 to f32 + %16 = tt.splat %15 : f32 -> tensor<1024xf32, #blocked> + %17 = tt.splat %output_ptr : !tt.ptr -> tensor<1024x!tt.ptr, #blocked> + %18 = tt.addptr %17, %14 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> + tt.store %18, %16, %10 : tensor<1024x!tt.ptr, #blocked> + tt.return + } +} + +// ----- + +//def scfif_range3(x, y, output_ptr,n_elements, BLOCK_SIZE: tl.constexpr, ): +// tl.assume(y < 100) +// tl.assume(y > 1) +// pid = tl.program_id(axis=0) +// block_start = pid * BLOCK_SIZE +// offsets = block_start + tl.arange(0, BLOCK_SIZE) +// mask = offsets < n_elements +// if x > y: +// z = x + 3 +// else: +// tl.assume(x < 20) # must not affect the occurrences of x in the then block! +// tl.assume(x > 0) +// z = y + 4; +// z2 = z + 1 +// tl.store(output_ptr + offsets, z2, mask) + +#blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx1100", "ttg.threads-per-warp" = 32 : i32} { + tt.func public @scfif_range3(%x: i32, %y: i32, %output_ptr: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %n_elements: i32 {tt.divisibility = 16 : i32}) { + %c4_i32 = arith.constant 4 : i32 + %c0_i32 = arith.constant 0 : i32 + %c20_i32 = arith.constant 20 : i32 + %c3_i32 = arith.constant 3 : i32 + %c1024_i32 = arith.constant 1024 : i32 + %c1_i32 = arith.constant 1 : i32 + %c100_i32 = arith.constant 100 : i32 + %0 = arith.cmpi slt, %y, %c100_i32 : i32 + llvm.intr.assume %0 : i1 + %1 = arith.cmpi sgt, %y, %c1_i32 : i32 + llvm.intr.assume %1 : i1 + %2 = tt.get_program_id x : i32 + %3 = arith.muli %2, %c1024_i32 : i32 + %4 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked> + %5 = tt.splat %3 : i32 -> tensor<1024xi32, #blocked> + %6 = arith.addi %5, %4 : tensor<1024xi32, #blocked> + %7 = tt.splat %n_elements : i32 -> tensor<1024xi32, #blocked> + %8 = arith.cmpi slt, %6, %7 : tensor<1024xi32, #blocked> + %9 = arith.cmpi sgt, %x, %y : i32 + %10 = scf.if %9 -> (i32) { + // expected-remark@+1 {{[0, 4294967295] signed : [-2147483648, 2147483647]}} + %z = arith.addi %x, %c3_i32 : i32 + scf.yield %z : i32 + } else { + %17 = arith.cmpi slt, %x, %c20_i32 : i32 + llvm.intr.assume %17 : i1 + %18 = arith.cmpi sgt, %x, %c0_i32 : i32 + llvm.intr.assume %18 : i1 + // expected-remark@+1 {{[6, 103] signed : [6, 103]}} + %z = arith.addi %y, %c4_i32 : i32 + scf.yield %z : i32 + } + // expected-remark@+1 {{[0, 4294967295] signed : [-2147483648, 2147483647]}} + %11 = arith.addi %10, %c1_i32 : i32 + %12 = arith.addi %5, %4 : tensor<1024xi32, #blocked> + %13 = arith.sitofp %11 : i32 to f32 + %14 = tt.splat %13 : f32 -> tensor<1024xf32, #blocked> + %15 = tt.splat %output_ptr : !tt.ptr -> tensor<1024x!tt.ptr, #blocked> + %16 = tt.addptr %15, %12 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> + tt.store %16, %14, %8 : tensor<1024x!tt.ptr, #blocked> + tt.return + } +} + +// ----- + +//def scfif_range4(x, y, output_ptr,n_elements, BLOCK_SIZE: tl.constexpr, ): +// tl.assume(y < 100) +// tl.assume(y > 1) +// pid = tl.program_id(axis=0) +// block_start = pid * BLOCK_SIZE +// offsets = 
block_start + tl.arange(0, BLOCK_SIZE) +// mask = offsets < n_elements +// if x > y: +// z = x + 3 // check the tl.assume is applicable to this statement +// tl.assume(x < 20) +// tl.assume(x > 0) +// else: +// z = y + 4; +// z2 = z + 1 +// tl.store(output_ptr + offsets, z2, mask) + +#blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx1100", "ttg.threads-per-warp" = 32 : i32} { + tt.func public @scfif_range4(%x: i32 loc("x"), %y: i32 loc("y"), %output_ptr: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32} loc("output_ptr"), %n_elements: i32 {tt.divisibility = 16 : i32} loc("n_elements")) attributes {noinline = false} { + %c4_i32 = arith.constant 4 : i32 + %c0_i32 = arith.constant 0 : i32 + %c20_i32 = arith.constant 20 : i32 + %c3_i32 = arith.constant 3 : i32 + %c1024_i32 = arith.constant 1024 : i32 + %c1_i32 = arith.constant 1 : i32 + %c100_i32 = arith.constant 100 : i32 + %0 = arith.cmpi slt, %y, %c100_i32 : i32 + llvm.intr.assume %0 : i1 + %1 = arith.cmpi sgt, %y, %c1_i32 : i32 + llvm.intr.assume %1 : i1 + %2 = tt.get_program_id x : i32 + %3 = arith.muli %2, %c1024_i32 : i32 + %4 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked> + %5 = tt.splat %3 : i32 -> tensor<1024xi32, #blocked> + %6 = arith.addi %5, %4 : tensor<1024xi32, #blocked> + %7 = tt.splat %n_elements : i32 -> tensor<1024xi32, #blocked> + %8 = arith.cmpi slt, %6, %7 : tensor<1024xi32, #blocked> + %9 = arith.cmpi sgt, %x, %y : i32 + %10 = scf.if %9 -> (i32) { + %17 = arith.cmpi slt, %x, %c20_i32 : i32 + llvm.intr.assume %17 : i1 + %18 = arith.cmpi sgt, %x, %c0_i32 : i32 + llvm.intr.assume %18 : i1 + // expected-remark@+1 {{unsigned : [4, 22] signed : [4, 22]}} + %z = arith.addi %x, %c3_i32 : i32 + scf.yield %z : i32 + } else { + // expected-remark@+1 {{unsigned : [6, 103] signed : [6, 103]}} + %z = arith.addi %y, %c4_i32 : i32 + scf.yield %z : i32 + } + // expected-remark@+1 {{unsigned : [5, 104] signed : [5, 104]}} + %11 = arith.addi %10, %c1_i32 : i32 + %12 = arith.addi %5, %4 : tensor<1024xi32, #blocked> + %13 = arith.sitofp %11 : i32 to f32 + %14 = tt.splat %13 : f32 -> tensor<1024xf32, #blocked> + %15 = tt.splat %output_ptr : !tt.ptr -> tensor<1024x!tt.ptr, #blocked> + %16 = tt.addptr %15, %12 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> + tt.store %16, %14, %8 : tensor<1024x!tt.ptr, #blocked> + tt.return + } +} diff --git a/third_party/amd/include/Analysis/RangeAnalysis.h b/third_party/amd/include/Analysis/RangeAnalysis.h index e14f21a89cde..df5ff673bac6 100644 --- a/third_party/amd/include/Analysis/RangeAnalysis.h +++ b/third_party/amd/include/Analysis/RangeAnalysis.h @@ -4,6 +4,7 @@ #include "mlir/Analysis/DataFlow/IntegerRangeAnalysis.h" #include "mlir/Analysis/DataFlow/SparseAnalysis.h" #include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/IR/Dominance.h" #include "mlir/Interfaces/LoopLikeInterface.h" namespace mlir::triton { @@ -32,15 +33,20 @@ namespace mlir::triton::AMD { /// See visitRegionSuccessors. 
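The scfif_range tests above all hinge on one rule: an assume refines a value's range only in blocks dominated by the assume's block. A minimal standalone C++ sketch of that filtering (the Interval type, block names, and dominance table are invented for illustration; the pass itself uses mlir::DominanceInfo):

    #include <algorithm>
    #include <cstdint>
    #include <iostream>
    #include <map>
    #include <optional>
    #include <set>
    #include <string>
    #include <utility>
    #include <vector>

    // Toy signed interval; the real analysis uses mlir::ConstantIntRanges.
    struct Interval {
      int64_t lo, hi;
    };

    // Diamond CFG: entry -> thenBB / elseBB -> exit. kDominates[a] is the set
    // of blocks dominated by a (every path to them passes through a).
    const std::map<std::string, std::set<std::string>> kDominates = {
        {"entry", {"entry", "thenBB", "elseBB", "exit"}},
        {"thenBB", {"thenBB"}},
        {"elseBB", {"elseBB"}},
        {"exit", {"exit"}}};

    bool dominates(const std::string &a, const std::string &b) {
      auto it = kDominates.find(a);
      return it != kDominates.end() && it->second.count(b) != 0;
    }

    // Intersect every assume whose block dominates useBlock; ignore the rest.
    // This is exactly what makes scfif_range3's else-block assumes invisible
    // to the then-block.
    std::optional<Interval>
    assumedRange(const std::vector<std::pair<std::string, Interval>> &assumes,
                 const std::string &useBlock) {
      std::optional<Interval> r;
      for (const auto &[blk, iv] : assumes) {
        if (!dominates(blk, useBlock))
          continue;
        if (!r)
          r = iv;
        else
          r = Interval{std::max(r->lo, iv.lo), std::min(r->hi, iv.hi)};
      }
      return r;
    }

    int main() {
      // tl.assume(x > 0) and tl.assume(x < 20) placed in the else-block, as in
      // scfif_range3: x is known to be in [1, 19] there, unknown elsewhere.
      std::vector<std::pair<std::string, Interval>> assumes = {
          {"elseBB", {1, 19}}};
      std::cout << "thenBB: "
                << (assumedRange(assumes, "thenBB") ? "applies" : "ignored")
                << "\n"; // ignored
      std::cout << "elseBB: "
                << (assumedRange(assumes, "elseBB") ? "applies" : "ignored")
                << "\n"; // applies
    }
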
struct TritonIntegerRangeAnalysis : dataflow::IntegerRangeAnalysis { using dataflow::IntegerRangeAnalysis::IntegerRangeAnalysis; + using Base = dataflow::IntegerRangeAnalysis; TritonIntegerRangeAnalysis( DataFlowSolver &solver, - const DenseMap> &assumptions) - : dataflow::IntegerRangeAnalysis(solver), assumptions(assumptions) {} + const DenseMap> &assumptions, + DominanceInfo *dominanceInfo, bool assumeNoArithOverflow_ = false) + : dataflow::IntegerRangeAnalysis(solver), assumptions(assumptions), + domInfo(dominanceInfo), assumeNoArithOverflow(assumeNoArithOverflow_) {} void setToEntryState(dataflow::IntegerValueRangeLattice *lattice) override; void initializeFuncOp(triton::FuncOp funcOp); + LogicalResult initialize(Operation *top) override; + LogicalResult visitOperation( Operation *op, ArrayRef operands, @@ -95,7 +101,8 @@ struct TritonIntegerRangeAnalysis : dataflow::IntegerRangeAnalysis { /// llvm.intr.assume %assumesltlhs : i1 /// for %K, will produce a final range /// [0, 2147483647] ∩ [-2147483648, 128] = [0, 128] - std::optional maybeGetAssumedRange(Value anchor) const; + std::optional maybeGetAssumedRange(Value anchor, + Block *useBlock) const; int64_t getTotalLoopTripCount(LoopLikeOpInterface loop); @@ -125,6 +132,32 @@ struct TritonIntegerRangeAnalysis : dataflow::IntegerRangeAnalysis { /// If one uses collectAssumptions below then `assumptions` will look like /// %K -> {arith.cmpi slt..., arith.cmpi sge}. llvm::DenseMap> assumptions; + + /// The defaultTransferFunc is the default transfer function for this dataflow + /// problem. + /// @param[in] op: the Operation in question + /// @param[in] result: a particular value defined by this op. Note that op + /// may define multiple values. + /// @param[in] srcLattices: lattices of all source operands + /// @param[in] destLattices: lattices of all result values + /// @param[in] incomingRange: the value-range inferred for the result + void defaultTransferFunc( + Operation *op, Value result, + ArrayRef srcLattices, + ArrayRef destLattices, + const IntegerValueRange &incomingRange); + +private: + void visitYieldHelper(Operation *yieldOp, Value value); + LogicalResult visitOperationHelper( + Operation *op, + ArrayRef operands, + ArrayRef resultsLattices); + + DenseSet signedIntValues; + llvm::SmallMapVector opResultAssumption; + DominanceInfo *domInfo = nullptr; + bool assumeNoArithOverflow = false; }; std::optional>> diff --git a/third_party/amd/lib/Analysis/RangeAnalysis.cpp b/third_party/amd/lib/Analysis/RangeAnalysis.cpp index 69e483b883f2..aef572307ed2 100644 --- a/third_party/amd/lib/Analysis/RangeAnalysis.cpp +++ b/third_party/amd/lib/Analysis/RangeAnalysis.cpp @@ -4,7 +4,9 @@ #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/IR/Iterators.h" #include "mlir/Interfaces/Utils/InferIntRangeCommon.h" +#include "third_party/amd/include/Dialect/TritonAMDGPU/IR/Dialect.h" #include "triton/Dialect/Triton/IR/Dialect.h" #include "triton/Dialect/TritonGPU/IR/Dialect.h" #include "llvm/ADT/TypeSwitch.h" @@ -13,6 +15,44 @@ #include #include +// Some notes: +// +// 1. Framework +// 1.1) This pass is based on MLIR's dataflow framework. In hindsight, it may +// be a poor fit for what we need. +// 1.2) If I understand correctly, MLIR's dataflow framework is a +// combination of traditional iterative dataflow analysis and +// Sparse Conditional Constant Propagation (SCCP). +// 1.3) Iterative dataflow analysis requires the transfer function to be monotone. 
+// However, not all value-ranges keep increasing as the analysis progresses. +// Consider the expression x - y: while x's and y's value-ranges may keep +// increasing, the difference between them does not necessarily keep +// increasing as well. +// 1.4) The first C in SCCP, i.e. the "conditional" part, is unnecessary +// for this pass, because we don't expect much dead code by the time +// this analysis is invoked. The price of being "conditional" is less about +// compile time than about complexity (in terms of debugging and understanding). +// 1.5) Maybe just walking the code top-down is sufficient for range-analysis: +// for loops, figure out the IVs' value-ranges before the loops are entered, and +// progress to the loop bodies, without visiting back-edges for non-SCF loops. +// +// 2. tl.assume statements +// 2.1) A value may have multiple assume-operations (assume-ops for short) +// associated with it. At point p, we only take into account those assume-ops +// whose enclosing basic blocks dominate the basic block containing p. +// 2.2) See some examples in the comment to maybeGetAssumedRangeHelper(). +// 2.3) The assumed value-ranges for source and result operands are inferred +// right before an operation is visited. +// 2.4) For now, if a value has an assumed value-range, we use the assumed +// value-range and ignore its inferred value-range. It would be nice to +// use the intersection of the assumed value-range and the inferred value-range. +// However, it is not always possible: iterative dataflow analysis +// requires that the transfer function be monotone; in general it's +// dangerous to use both meet() and join() operations. In this pass, +// intersecting the inferred value-range with the assumed value-range would +// still guarantee monotonicity; however, the underlying lattice's meet() +// operation is a silent no-op. #undef DEBUG_TYPE #define DEBUG_TYPE "tritonamdgpu-range-analysis" #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") @@ -22,68 +62,6 @@ using namespace mlir; namespace tt = mlir::triton; -std::optional -triton::AMD::TritonIntegerRangeAnalysis::maybeGetTripCount( - LoopLikeOpInterface loop) { - std::optional lowerBound = loop.getSingleLowerBound(); - std::optional upperBound = loop.getSingleUpperBound(); - std::optional step = loop.getSingleStep(); - std::optional iv = loop.getSingleInductionVar(); - if (!iv) - return {}; - - unsigned int width = ConstantIntRanges::getStorageBitwidth(iv->getType()); - - auto getLoopRangeInfo = [&](std::optional loopBound, - Block *block, - std::optional getUpper = std::nullopt, - std::optional defaultVal = std::nullopt) { - if (loopBound.has_value()) { - if (auto attr = dyn_cast(*loopBound)) { - if (auto bound = dyn_cast_or_null(attr)) - return bound.getValue(); - } else if (auto value = llvm::dyn_cast_if_present(*loopBound)) { - const dataflow::IntegerValueRangeLattice *lattice = - getLatticeElementFor(getProgramPointBefore(block), value); - if (lattice != nullptr && !lattice->getValue().isUninitialized()) - return getUpper ? lattice->getValue().getValue().smax() - : lattice->getValue().getValue().smin(); - } - } - if (defaultVal) - return *defaultVal; - return getUpper ? APInt::getSignedMaxValue(width) - : APInt::getSignedMinValue(width); - }; - - Block *block = iv->getParentBlock(); - APInt min = getLoopRangeInfo(lowerBound, block, - /*getUpper=*/false); - APInt max = getLoopRangeInfo(upperBound, block, - /*getUpper=*/true); - // We can assume step is 1 if no range information as that gives us the upper - // bound of the number of iterations. 
- APInt stepValDefault = {width, 1, /*isSigned=*/true}; - APInt stepVal = - getLoopRangeInfo(step, block, /*getUpper=*/{}, stepValDefault); - - if (stepVal.isNegative()) - std::swap(min, max); - // This is necessary to catch a case like this: - // # range = [0 1024] - // K = .... - // # range = [1, 64] - // k = ... - // # range = [0, 16] -> stepVal = range.smin() = 0 - // step = ceildiv(K, k) - if (stepVal.isZero()) - stepVal = stepValDefault; - if (max.sge(min)) - return llvm::divideCeilSigned(max.getSExtValue() - min.getSExtValue(), - stepVal.getSExtValue()); - return {}; -} - namespace { constexpr int64_t kDefaultMaxTripCount = 1024; @@ -98,6 +76,27 @@ void getEnclosingLoops(Operation &op, SmallVector &ops) { } } +tt::FuncOp getEnclosingFunction(Value v) { + tt::FuncOp funcOp = nullptr; + + auto definingOp = v.getDefiningOp(); + if (!definingOp) + if (auto blk = v.getParentBlock()) + definingOp = blk->getParentOp(); + + if (definingOp) { + if (auto selfIsFunc = dyn_cast(definingOp)) + funcOp = selfIsFunc; + else + funcOp = definingOp->getParentOfType(); + } + + assert(funcOp && "No enclosing tt::FuncOp"); + return funcOp; +} + +Block *getFuncEntryBlock(tt::FuncOp func) { return &func.getRegion().front(); } + void inferResultRangesPID(Operation *op, uint64_t max, SetIntRangeFn setResultRange) { assert(op->getNumResults() == 1 && "expected op to have one result"); @@ -121,12 +120,13 @@ void inferResultRanges(tt::MakeRangeOp *op, SetIntRangeFn setResultRange) { assert(llvm::isa(resTy.getElementType()) && "expected int type"); IntegerType elTy = llvm::cast(resTy.getElementType()); auto bitWidth = mlir::ConstantIntRanges::getStorageBitwidth(elTy); + // NOTE: make_range(begin, end) yields a half-open interval, [begin, end). setResultRange(result, ConstantIntRanges::range( /*min*/ {/*numBits*/ bitWidth, /*val*/ op->getStart(), /*isSigned*/ elTy.isSigned()}, /*max*/ - {/*numBits*/ bitWidth, /*val*/ op->getEnd(), + {/*numBits*/ bitWidth, /*val*/ op->getEnd() - 1, /*isSigned*/ elTy.isSigned()}, /*isSigned*/ elTy.isSigned())); } @@ -164,14 +164,51 @@ void inferResultRangesMaxNonNegSigned(Operation *op, } } -std::optional maybeGetAssumedRange(Operation *assumption, - Value anchor) { +// Given an assumption operation, try to derive the value range of the value +// "anchor" at some point in the block "useBlock". +// Note that +// - The value "anchor" is defined or referenced in the "useBlock" +// - The location of the reference of "anchor" in the "useBlock" does not +// matter: because the IR is in SSA form, the value-range of a quantity +// does not change throughout the entire block. +// - The assumption should be ignored if it does not dominate the "useBlock". +// +// Consider the following cases: +// +// case 1: both s2 and s3 are applicable to s1 because they dominate s1 +// s2: assume y > 5 +// ... +// if cond +// s3: assume z < 3 +// s1: x = y + z +// +// case 2: s2 is applicable to s1 even though s2 appears after s1. +// blk: +// s1: x = y + z +// s2: assume y > 5 +// +// case 3: s2 is not applicable to s1 because the else-clause's block does not +// dominate the then-clause's block. 
+// if cond +// s1: x = y + z +// else +// s2: assume y > 5 +// +std::optional +maybeGetAssumedRangeHelper(Operation *assumption, Value anchor, Block *useBlock, + DominanceInfo *domInfo) { + arith::CmpIOp cmpOp = llvm::dyn_cast(assumption); if (!cmpOp) { emitRemark(assumption->getLoc(), "unsupported assumption operation"); return {}; } + // The block where tl.assume resides must dominate the block where the value + // is referenced! + if (!useBlock || !domInfo->dominates(cmpOp->getBlock(), useBlock)) + return {}; + bool isSigned = true; switch (cmpOp.getPredicate()) { case arith::CmpIPredicate::uge: @@ -248,10 +285,105 @@ std::optional maybeGetAssumedRange(Operation *assumption, return {}; } +std::optional +maybeGetAssumedRange(const SetVector &allAssumptions, Value anchor, + Block *useBlock, DominanceInfo *domInfo) { + + std::optional result; + for (auto assumption : allAssumptions) { + auto tmpResult = + maybeGetAssumedRangeHelper(assumption, anchor, useBlock, domInfo); + if (!tmpResult.has_value()) + continue; + + if (result.has_value()) + result = (*result).intersection(*tmpResult); + else + result = *tmpResult; + } + + if (result) { + const auto &val = *result; + if (val.smin().isNonNegative()) { + // Consider 0 <= x && x <= 1024. + // When processing x > 0, the value range of x is + // vr1={umin=0, umax=0xf...f, smin=0, smax=0x7...f} + // When processing x < 1024, the value range of x is: + // vr2={umin=0, umax=0xf...f, smin=..., smax=1024} + // and + // vr1 ∩ vr2 = {umin=0, umax=0xf...f, smin=0, smax=1024} + // note that umax=0xf...f is too loose; rebuilding the range from the + // signed bounds below tightens it to 1024. + return ConstantIntRanges::range(val.smin(), val.smax(), true); + } + } + return result; +} + } // namespace namespace mlir::triton::AMD { +std::optional +TritonIntegerRangeAnalysis::maybeGetTripCount(LoopLikeOpInterface loop) { + std::optional lowerBound = loop.getSingleLowerBound(); + std::optional upperBound = loop.getSingleUpperBound(); + std::optional step = loop.getSingleStep(); + std::optional iv = loop.getSingleInductionVar(); + if (!iv) + return {}; + + unsigned int width = ConstantIntRanges::getStorageBitwidth(iv->getType()); + + auto getLoopRangeInfo = [&](std::optional loopBound, + Block *block, + std::optional getUpper = std::nullopt, + std::optional defaultVal = std::nullopt) { + if (loopBound.has_value()) { + if (auto attr = dyn_cast(*loopBound)) { + if (auto bound = dyn_cast_or_null(attr)) + return bound.getValue(); + } else if (auto value = llvm::dyn_cast_if_present(*loopBound)) { + const dataflow::IntegerValueRangeLattice *lattice = + getLatticeElementFor(getProgramPointBefore(block), value); + if (lattice != nullptr && !lattice->getValue().isUninitialized()) + return getUpper ? lattice->getValue().getValue().smax() + : lattice->getValue().getValue().smin(); + } + } + if (defaultVal) + return *defaultVal; + return getUpper ? APInt::getSignedMaxValue(width) + : APInt::getSignedMinValue(width); + }; + + Block *block = iv->getParentBlock(); + APInt min = getLoopRangeInfo(lowerBound, block, + /*getUpper=*/false); + APInt max = getLoopRangeInfo(upperBound, block, + /*getUpper=*/true); + // We can assume step is 1 if no range information as that gives us the upper + // bound of the number of iterations. + APInt stepValDefault = {width, 1, /*isSigned=*/true}; + APInt stepVal = + getLoopRangeInfo(step, block, /*getUpper=*/{}, stepValDefault); + + if (stepVal.isNegative()) + std::swap(min, max); + // This is necessary to catch a case like this: + // # range = [0 1024] + // K = .... 
+ // # range = [1, 64] + // k = ... + // # range = [0, 16] -> stepVal = range.smin() = 0 + // step = ceildiv(K, k) + if (stepVal.isZero()) + stepVal = stepValDefault; + if (max.sge(min)) + return llvm::divideCeilSigned(max.getSExtValue() - min.getSExtValue(), + stepVal.getSExtValue()); + return {}; +} + bool isEmptyInitializedRange(ConstantIntRanges rv) { if (!rv.umin().getBitWidth() || !rv.umax().getBitWidth() || !rv.smin().getBitWidth() || !rv.smax().getBitWidth()) @@ -294,26 +426,19 @@ bool cmpIIsStaticallyTrue(const DataFlowSolver &solver, arith::CmpIOp cmpOp) { return false; } +LogicalResult TritonIntegerRangeAnalysis::initialize(Operation *top) { + signedIntValues.clear(); + return Base::initialize(top); +} + std::optional -TritonIntegerRangeAnalysis::maybeGetAssumedRange(Value anchor) const { - auto matchingAssumptions = this->assumptions.lookup(anchor); +TritonIntegerRangeAnalysis::maybeGetAssumedRange(Value anchor, + Block *useBlock) const { + const auto &matchingAssumptions = this->assumptions.lookup(anchor); if (matchingAssumptions.empty()) return {}; - unsigned bitWidth = ConstantIntRanges::getStorageBitwidth(anchor.getType()); - assert(bitWidth > 0 && "expected non-zero bitwidth"); - ConstantIntRanges constIntRange = ConstantIntRanges::maxRange(bitWidth); - if (llvm::isa_and_nonnull(anchor.getDefiningOp())) { - constIntRange = ConstantIntRanges::range( - APInt::getZero(bitWidth), - APInt(bitWidth, kDefaultMaxPrograms - 1, true), true); - } - - for (auto assumption : matchingAssumptions) { - if (auto constIntRange_ = ::maybeGetAssumedRange(assumption, anchor)) - constIntRange = constIntRange.intersection(*constIntRange_); - } - return constIntRange; + return ::maybeGetAssumedRange(matchingAssumptions, anchor, useBlock, domInfo); } int64_t @@ -333,8 +458,10 @@ void TritonIntegerRangeAnalysis::setToEntryState( if (!llvm::isa(getElementTypeOrSelf(anchor)) && !llvm::isa(getElementTypeOrSelf(anchor))) return; + + Block *entryBlock = getFuncEntryBlock(getEnclosingFunction(anchor)); IntegerValueRange range = IntegerValueRange::getMaxRange(anchor); - if (auto maybeRange = maybeGetAssumedRange(anchor)) + if (auto maybeRange = maybeGetAssumedRange(anchor, entryBlock)) range = *maybeRange; auto changed = lattice->join(range); LLVM_DEBUG({ @@ -347,52 +474,140 @@ void TritonIntegerRangeAnalysis::setToEntryState( propagateIfChanged(lattice, changed); } +void TritonIntegerRangeAnalysis::defaultTransferFunc( + Operation *op, Value resultVal, + ArrayRef srcLattices, + ArrayRef resultsLattices, + const IntegerValueRange &incomingRange) { + + // step 1: Preparation + // - Get the lattice associated with the given result value. + // - Make a copy of the value-range just inferred, as we need to make some + // changes to it before it's joined into the existing lattice. + auto result = dyn_cast(resultVal); + if (!result) + return; + assert(llvm::is_contained(op->getResults(), result)); + + dataflow::IntegerValueRangeLattice *lattice = + resultsLattices[result.getResultNumber()]; + IntegerValueRange incomingRange_ = incomingRange; + + // step 2: If there is an assumed value-range, the assumed one takes precedence. 
+ // TODO: I think this is a bit conservative; a better way would be: + // final_range = (old_range ∪ incomingRange) ∩ assume_range + if (auto iter = opResultAssumption.find(resultVal); + iter != opResultAssumption.end()) { + const auto &range = iter->second; + if (auto maybeRange = maybeGetAssumedRange(resultVal, op->getBlock())) { + incomingRange_ = + IntegerValueRange(incomingRange_.getValue().intersection(range)); + } + } + + // step 3: Update the value range. Note that we are using the `join` operation, + // which means `union`. The transfer function must be monotone! The resolver + // would otherwise fall into an infinite loop. + ChangeResult changed = lattice->join(incomingRange_); + LLVM_DEBUG({ + OpPrintingFlags flags; + flags.skipRegions(true); + DBGS() << ((changed == ChangeResult::Change) ? ">Inferred range for: " + : ">Remain unchanged: "); + resultVal.printAsOperand(llvm::dbgs(), flags); + llvm::dbgs() << ", resulting state:" << lattice->getValue() + << ", in value-range: " << incomingRange_ << "\n"; + }); + + // step 4: Add those ops that depend on this op to the worklist. The resolver + // will iterate over all items in the worklist until it becomes empty. + propagateIfChanged(lattice, changed); +} + LogicalResult TritonIntegerRangeAnalysis::visitOperation( Operation *op, ArrayRef operands, ArrayRef resultsLattices) { - LDBG("Inferring ranges for " << *op); - // This callback is almost exactly like the callback in - // IntegerRangeAnalysis::visitOperation except we do not "short-cicruit" the - // analysis by inferring a maximum range for loop results (instead we - // perform a check based on visit counts in visitRegionSuccessors). - auto joinCallback = [&op, &resultsLattices, - this](Value v, const IntegerValueRange &incomingRange) { - auto result = dyn_cast(v); - if (!result) - return; - assert(llvm::is_contained(op->getResults(), result)); - - dataflow::IntegerValueRangeLattice *lattice = - resultsLattices[result.getResultNumber()]; - IntegerValueRange incomingRange_ = incomingRange; - if (auto maybeRange = maybeGetAssumedRange(v)) { - incomingRange_ = - IntegerValueRange(incomingRange.getValue().intersection(*maybeRange)); + + // step 1: Figure out the implied value-range of the result and source operands + opResultAssumption.clear(); + for (mlir::OpResult result : op->getResults()) { + auto assumedRange = maybeGetAssumedRange(result, op->getBlock()); + if (assumedRange.has_value()) + opResultAssumption.insert(std::pair(result, *assumedRange)); + } + + llvm::SmallVector + opndValueRanges; + + llvm::SmallVector, 4> + newSrcLattices; + + for (auto [index, opnd] : llvm::enumerate(op->getOperands())) { + auto assumedRange = maybeGetAssumedRange(opnd, op->getBlock()); + if (!assumedRange.has_value()) { + opndValueRanges.push_back(operands[index]); + continue; } - ChangeResult changed = lattice->join(incomingRange_); + + auto newLattice = + std::make_unique(opnd); + (void)newLattice->join(IntegerValueRange(*assumedRange)); + opndValueRanges.push_back(newLattice.get()); + newSrcLattices.push_back(std::move(newLattice)); + } + assert(opndValueRanges.size() == operands.size() && "size disagree"); + + // step 2: Call the helper function to infer the value range. If an assumed + // value-range is present, the transfer function will intersect the assumed + // value-range with the inferred value-range. + LogicalResult visitResult = + visitOperationHelper(op, opndValueRanges, resultsLattices); + + // step 3: If the previous step failed to infer a value-range, apply the + // assumed value-range if present. 
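The loop that follows implements this forcing step; seen in isolation, the join/meet discipline from steps 2-4 above is just interval arithmetic. A toy C++ sketch (the Interval type and helper names are illustrative, not the MLIR lattice API):

    #include <algorithm>
    #include <cstdint>
    #include <iostream>

    struct Interval {
      int64_t lo, hi;
    };

    // join == union hull: lattice states only ever widen, which keeps the
    // fixpoint iteration monotone and therefore terminating.
    Interval join(Interval a, Interval b) {
      return {std::min(a.lo, b.lo), std::max(a.hi, b.hi)};
    }

    // meet == intersection: applied once to the incoming range (against the
    // assumed range), never to the stored lattice state, so the monotonicity
    // of the stored state is preserved.
    Interval meet(Interval a, Interval b) {
      return {std::max(a.lo, b.lo), std::min(a.hi, b.hi)};
    }

    int main() {
      Interval lattice{6, 103};  // state contributed by an earlier visit
      Interval incoming{4, 22};  // freshly inferred range for this visit
      Interval assumed{0, 1023}; // from llvm.intr.assume
      Interval next = join(lattice, meet(incoming, assumed));
      std::cout << "[" << next.lo << ", " << next.hi << "]\n"; // [4, 103]
    }
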
+ for (auto [index, lattice] : llvm::enumerate(resultsLattices)) { + Value result = op->getResult(index); + const auto assumedIter = opResultAssumption.find(result); + if (assumedIter == opResultAssumption.end()) + continue; + + const mlir::IntegerValueRange &vr = lattice->getValue(); + if (!vr.isUninitialized() && !AMD::isEmptyInitializedRange(vr.getValue())) + continue; + + const ConstantIntRanges &assumedVr = assumedIter->second; + IntegerValueRange range(assumedVr); + auto changed = lattice->join(range); + LLVM_DEBUG({ if (changed == ChangeResult::Change) { - DBGS() << "Inferred range for "; - v.printAsOperand(llvm::dbgs(), {}); - llvm::dbgs() << " to " << incomingRange_ << "\n"; + DBGS() << ">Force apply assumed value range. value:"; + result.printAsOperand(llvm::dbgs(), {}); + llvm::dbgs() << ", range:" << range << "\n"; } }); propagateIfChanged(lattice, changed); - }; - - // Initialize lattices with assumptions. - for (const auto &resultLattice : resultsLattices) { - if (!resultLattice->getValue().isUninitialized()) - continue; - auto anchor = resultLattice->getAnchor(); - if (auto assumptions = this->assumptions.lookup(anchor); - !assumptions.empty()) { - setToEntryState(resultLattice); - return success(); - } } + return visitResult; +} + +LogicalResult TritonIntegerRangeAnalysis::visitOperationHelper( + Operation *op, + ArrayRef operands, + ArrayRef resultsLattices) { + LDBG("Inferring ranges for " << *op); + + // This callback is almost exactly like the callback in + // IntegerRangeAnalysis::visitOperation except we do not "short-circuit" the + // analysis by inferring a maximum range for loop results (instead we + // perform a check based on visit counts in visitRegionSuccessors). + auto joinCallback = [&op, &operands, &resultsLattices, + this](Value v, const IntegerValueRange &incomingRange) { + this->defaultTransferFunc(op, v, operands, resultsLattices, incomingRange); + }; + // Ops with fixed/constant ranges. if (llvm::isa( op)) { @@ -418,6 +633,11 @@ LogicalResult TritonIntegerRangeAnalysis::visitOperation( return lattice->getValue(); }); + if (auto sliceOp = dyn_cast(op)) { + joinCallback(sliceOp->getResult(0), argIntValueRanges[0]); + return success(); + } + // Ops with actually changing/variable input/output ranges. if (llvm::isa(op)) { @@ -446,6 +666,10 @@ LogicalResult TritonIntegerRangeAnalysis::visitOperation( return success(); } + // TODO: It looks like inferResultRangesFromOptional does not handle a bunch + // of operations very well: + // - arith.shrui, e.g. 
arith.shrui %arg3, %c5_i32 + // + if (auto inferrable = dyn_cast(op)) { inferrable.inferResultRangesFromOptional(argIntValueRanges, joinCallback); return success(); } @@ -456,16 +680,23 @@ } void TritonIntegerRangeAnalysis::initializeFuncOp(tt::FuncOp op) { + Block *entryBlock = getFuncEntryBlock(op); for (BlockArgument argument : op.getArguments()) { - if (!this->assumptions.lookup(argument).empty()) { - dataflow::IntegerValueRangeLattice *argLattice = - getLatticeElement(argument); - auto anchor = argLattice->getAnchor(); - IntegerValueRange range = IntegerValueRange::getMaxRange(anchor); - if (auto maybeRange = maybeGetAssumedRange(anchor)) - range = *maybeRange; - (void)argLattice->join(range); - } + if (!this->assumptions.count(argument)) + continue; + + dataflow::IntegerValueRangeLattice *argLattice = + getLatticeElement(argument); + + IntegerValueRange range = IntegerValueRange::getMaxRange(argument); + if (auto maybeRange = maybeGetAssumedRange(argument, entryBlock)) + range = *maybeRange; + + // The lattice must be in the "bottom" state. The join() operation sets the + // state to the given "range". + assert(argLattice->getValue().isUninitialized() && + "lattice must be in bottom state"); + (void)argLattice->join(range); } } @@ -474,7 +705,7 @@ void TritonIntegerRangeAnalysis::visitRegionSuccessors( RegionBranchPoint successor, ArrayRef abstractLattices) { LLVM_DEBUG({ - DBGS() << "Inferring ranges for "; + DBGS() << "Visit Region Successors of "; OpPrintingFlags flags; flags.skipRegions(true); branch.print(llvm::dbgs(), flags); @@ -514,6 +745,27 @@ void TritonIntegerRangeAnalysis::visitRegionSuccessors( assert(predecessors->allPredecessorsKnown() && "unexpected unresolved region successors"); + // Note: It may not be obvious that this loop can update SCF + // operations' LHS. E.g., if the given "branch" argument is scf.if, and the + // scf.if construct looks like the following: + // x = scf.if cond + // m = ... // op_m + // yield m + // else + // n = ... // op_n + // yield n + // + // This loop tries to update lattice(x) = join(lattice(m), lattice(n)), + // provided lattice(m) and lattice(n) are initialized. + // + // Note that the state of lattice(m) and lattice(n) was updated in the + // "previous" round. In this "round", the scf.if itself is visited, and + // this is the moment its LHS gets updated. + // + // Alternatively, when we visit, say, op_m, we could notice its result is used by + // a yieldOp, get the yieldOp's corresponding receiver, in this case x, and + // update its state accordingly. 
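For scfif_range2 above, this region-successor join works out concretely as follows; a hand-computed C++ sketch using plain integers instead of the MLIR lattice types:

    #include <algorithm>
    #include <cstdint>
    #include <iostream>

    int main() {
      // then-branch: z = x + 3 with x in [1, 19] from the assumes -> [4, 22]
      int64_t thenLo = 1 + 3, thenHi = 19 + 3;
      // else-branch: z = y + 4 with y in [2, 99] from the assumes -> [6, 103]
      int64_t elseLo = 2 + 4, elseHi = 99 + 4;
      // lattice(z) = join(lattice(then), lattice(else)) -> [4, 103]
      int64_t lo = std::min(thenLo, elseLo);
      int64_t hi = std::max(thenHi, elseHi);
      // z2 = z + 1 -> [5, 104], matching the expected-remark on %13.
      std::cout << "[" << lo + 1 << ", " << hi + 1 << "]\n";
    }
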
+ // for (Operation *op : predecessors->getKnownPredecessors()) { std::optional operands; if (op == branch) { diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/ConvertToBufferOps.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/ConvertToBufferOps.cpp index fb0892c90b1b..f2347e269ac8 100644 --- a/third_party/amd/lib/TritonAMDGPUTransforms/ConvertToBufferOps.cpp +++ b/third_party/amd/lib/TritonAMDGPUTransforms/ConvertToBufferOps.cpp @@ -53,182 +53,58 @@ bool isSplatOneConstTensor(const Value v) { return false; } -bool verifyNonSmallerByAssumption( - Value expr, const DenseMap> &assumptions, - const std::function &matchesOther) { - if (!assumptions.contains(expr)) +bool isByteOffsetSmallerThan2GB(triton::AddPtrOp addPtrOp, + std::shared_ptr solver) { + Value elemIdx = addPtrOp.getOffset(); + LDBG("Determining value-range of element-index: " << elemIdx); + + // step 1: Get the value range of the element index + const auto *lattice = + solver->lookupState(elemIdx); + if (!lattice) { + // Note that it is not always possible to get a lattice, e.g. when the + // element-index is defined by a tt.load. + LDBG("Cannot get lattice"); return false; - for (Operation *assume : assumptions.at(expr)) { - auto cmpOp = llvm::dyn_cast(assume); - if (!cmpOp) - continue; - switch (cmpOp.getPredicate()) { - case arith::CmpIPredicate::eq: - case arith::CmpIPredicate::sge: - case arith::CmpIPredicate::sgt: { - if (cmpOp.getLhs() == expr && matchesOther(cmpOp.getRhs())) { - LDBG(" " << expr << " non-neg by assumption " << cmpOp); - return true; - } - break; - } - case arith::CmpIPredicate::sle: - case arith::CmpIPredicate::slt: { - if (cmpOp.getRhs() == expr && matchesOther(cmpOp.getLhs())) { - LDBG(" " << expr << " non-neg by assumption " << cmpOp); - return true; - } - break; - } - default: - break; - } } - return false; -} -bool verifyNonSmallerByAssumption( - Value expr, const DenseMap> &assumptions, - Value other) { - return verifyNonSmallerByAssumption( - expr, assumptions, [&](auto otherAssum) { return otherAssum == other; }); -} - -bool verifyNonNegativeExpr( - Value expr, const DenseMap> &assumptions, - std::shared_ptr solver) { - LDBG("Determing if non-negative: " << expr); - - auto nonNegativePred = [&solver](Value v) -> bool { - if (const auto *r = - solver->lookupState(v)) { - if (r->getValue().isUninitialized()) - return false; - if (AMD::isEmptyInitializedRange(r->getValue().getValue())) - return false; - } - return succeeded(dataflow::staticallyNonNegative(*solver, v)); + const mlir::IntegerValueRange &vr = lattice->getValue(); + if (vr.isUninitialized() || AMD::isEmptyInitializedRange(vr.getValue())) { + LDBG("Cannot get value range of the offset"); + return false; }; - if (nonNegativePred(expr)) - return true; + const auto &smin = vr.getValue().smin(); + const auto &smax = vr.getValue().smax(); - // Recurse if the operation is defined - Operation *op = expr.getDefiningOp(); - if (!op) { - LDBG(" No defining op, assuming possibly negative"); + LDBG("Element-index value-range: " << smin << " : " << smax); + if (smin.isNegative() || smax.isNegative()) + return false; + + // step 2: Get element type and size. + // e.g. if addPtrOp.getType() is tensor<64x64x!tt.ptr, then elemTy is + // !tt.ptr, and dereferencing elemTy gives f16. + // TODO: Not sure if we need to keep dereferencing in a loop. 
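Before the function continues below with step 2 and step 3, here is the bound arithmetic of the 2GB check in isolation; a toy, self-contained C++ sketch (the values are illustrative, not taken from a real kernel):

    #include <cstdint>
    #include <iostream>

    int main() {
      // An f32 tensor: 32-bit elements. elemMaxIdx stands in for the smax the
      // analysis proved for the element index.
      int64_t elemBitSz = 32;
      int64_t elemMaxIdx = 1LL << 29; // 536870912
      // Offset one past the last element, rounded up to whole bytes, exactly
      // as in isByteOffsetSmallerThan2GB.
      int64_t byteOfst = (elemBitSz * elemMaxIdx + elemBitSz + 7) / 8;
      int64_t szLimit2GB = (1LL << 31) - 1;
      std::cout << byteOfst
                << ((byteOfst <= szLimit2GB) ? ", in range" : ", out of range")
                << "\n"; // 2147483652, out of range
    }
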
+ Type elemTy = getElementTypeOrSelf(addPtrOp.getType()); + while (auto ptrTy = dyn_cast(elemTy)) + elemTy = ptrTy.getPointeeType(); + + if (!elemTy || !elemTy.isIntOrFloat()) { + LDBG("unknown element type: " << elemTy); return false; } - bool nonNegative = - llvm::TypeSwitch(expr.getDefiningOp()) - // Various unary triton ops that don't change the sign of the operand - .Case([&](auto unaryOp) { - return verifyNonNegativeExpr(unaryOp.getOperand(), assumptions, - solver); - }) - .Case([&](auto gatherOp) { - return verifyNonNegativeExpr(gatherOp.getSrc(), assumptions, - solver); - }) - // Joining two non-negative tensors is still non-negative - .Case([&](auto joinOp) { - return verifyNonNegativeExpr(joinOp.getLhs(), assumptions, - solver) && - verifyNonNegativeExpr(joinOp.getRhs(), assumptions, solver); - }) - // Returns a tensor representing histogram: histograms only contain - // buckets of non-negative values. - .Case([&](auto) { return true; }) - .Case([&](auto makeRangeOp) { - // See the warning in TritonOps.td: getStart/getEnd return unsigned, - // so we need to look through get*Attr. - return makeRangeOp.getStartAttr().getInt() >= 0 && - makeRangeOp.getEndAttr().getInt() >= 0; - }) - .Case( - [&](auto constIntOp) { return constIntOp.value() >= 0; }) - .Case([&](arith::ConstantOp constOp) { - Value val = constOp.getResult(); - DenseIntElementsAttr constVal; - if (matchPattern(val, m_Constant(&constVal)) && constVal.isSplat()) - return constVal.getSplatValue().isNonNegative(); - return false; - }) - .Case([&](auto) { - // These are defined as signless, but are actually unsigned - return true; - }) - .Case([&](auto maxOp) { - // max(a,b) >= 0 iff a>=0 || b>=0 - return verifyNonNegativeExpr(maxOp.getLhs(), assumptions, solver) || - verifyNonNegativeExpr(maxOp.getRhs(), assumptions, solver); - }) - .Case([&](auto remsiOp) { - // a % b >= 0 iff a>=0 - return verifyNonNegativeExpr(remsiOp.getLhs(), assumptions, solver); - }) - .Case([&](Operation *unaryOp) { - // a = OP b >= 0 iff b >= 0 - return verifyNonNegativeExpr(unaryOp->getOperand(0), assumptions, - solver); - }) - // Casting from arbitrary data does *not* guarantee the offset is in - // range (even if pointer, or the data is non-negative when - // interpreted as the src's type). - .Case( - [&](auto) { return false; }) - .Case( - // These OPs also return unsigned values. - // TODO: We can also sniff whether a Value is unsigned by looking - // for whether or not it's used as an argument to one of - // these OPs. - [&](auto uOp) { return true; }) - .Case( - // Generally speaking, a OP b >= 0 iff a >= 0 && b >= 0 when - // OP != sub - [&](Operation *binOp) { - return verifyNonNegativeExpr(binOp->getOperand(0), assumptions, - solver) && - verifyNonNegativeExpr(binOp->getOperand(1), assumptions, - solver); - }) - // TODO: more scf - .Case([&](auto ifOp) { - auto results = ifOp.getResults(); - auto it = std::find(results.begin(), results.end(), expr); - assert(it != results.end() && "expr should be the result of ifOp"); - auto resultIdx = it - results.begin(); - - // If we're here then we must have both then/else regions - // (each with 1 block) and each region must terminate with an - // `scf.yield` expression. 
- auto thenYield = cast(ifOp.thenYield()); - auto elseYield = cast(ifOp.elseYield()); - return verifyNonNegativeExpr(thenYield->getOperand(resultIdx), - assumptions, solver) && - verifyNonNegativeExpr(elseYield->getOperand(resultIdx), - assumptions, solver); - }) - .Case([&](auto op) { - // If a user annotates tl.assume(a >= b) then we know a - b >= 0 - return verifyNonSmallerByAssumption(op.getLhs(), assumptions, - op.getRhs()); - }) - .Case([&](auto op) { - return verifyNonNegativeExpr(op->getOperand(0), assumptions, - solver); - }) - .Default([&](Operation *) { - // Conservatively assume that the expression is negative - LDBG(" Unhandled op, cannot assume non-negative"); - return false; - }); - return nonNegative; + // step 3: Check that the byte-offset stays within 2GB. + int64_t elemBitSz = elemTy.getIntOrFloatBitWidth(); + int64_t elemMaxIdx = smax.getSExtValue(); + // Byte offset one past the last element, rounded up to whole bytes. + int64_t byteOfst = (elemBitSz * elemMaxIdx + elemBitSz + 7) / 8; + int64_t szLimit2GB = (1L << 31) - 1; + + LDBG("element bit sz:" << elemBitSz << ", max byte offset:" << byteOfst + << ((byteOfst <= szLimit2GB) ? ", in range" : ", out of range")); + + return byteOfst <= szLimit2GB; } bool isFuncArgWith32bitPtrRange(mlir::Value value) { @@ -294,7 +170,8 @@ bool canUseBufferOps(Value ptr, LDBG("base-ptr as tt.pointer_range=32 attribute"); return true; } - return verifyNonNegativeExpr(offset, assumptions, std::move(solver)); + + return isByteOffsetSmallerThan2GB(addPtrOp, std::move(solver)); } // Extract stride of the blocked offset of LD/ST ops. @@ -718,8 +595,10 @@ struct TritonAMDGPUConvertToBufferOpsPass DenseMap> assumptions = AMD::TritonIntegerRangeAnalysis::collectAssumptions(getOperation()); std::shared_ptr solver = createDataFlowSolver(); + AMD::TritonIntegerRangeAnalysis *rangeAnalysis = - solver->load(assumptions); + solver->load( + assumptions, &getAnalysis()); AMD::initializeFuncOps(mod, rangeAnalysis); if (failed(solver->initializeAndRun(getOperation()))) return signalPassFailure(); diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/FoldTrueCmpIOp.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/FoldTrueCmpIOp.cpp index 37f411f31403..07defebe6cd6 100644 --- a/third_party/amd/lib/TritonAMDGPUTransforms/FoldTrueCmpIOp.cpp +++ b/third_party/amd/lib/TritonAMDGPUTransforms/FoldTrueCmpIOp.cpp @@ -19,7 +19,8 @@ struct TritonAMDFoldTrueCmpIOpPass ModuleOp mod = getOperation(); std::unique_ptr solver = createDataFlowSolver(); AMD::TritonIntegerRangeAnalysis *rangeAnalysis = - solver->load(assumptions); + solver->load( + assumptions, &getAnalysis()); AMD::initializeFuncOps(mod, rangeAnalysis); if (failed(solver->initializeAndRun(getOperation()))) return signalPassFailure(); diff --git a/third_party/amd/test/lib/Analysis/TestAMDRangeAnalysis.cpp b/third_party/amd/test/lib/Analysis/TestAMDRangeAnalysis.cpp index df20182df198..da911bef13f6 100644 --- a/third_party/amd/test/lib/Analysis/TestAMDRangeAnalysis.cpp +++ b/third_party/amd/test/lib/Analysis/TestAMDRangeAnalysis.cpp @@ -31,7 +31,8 @@ struct TestAMDRangeAnalysisPass AMD::TritonIntegerRangeAnalysis::collectAssumptions(getOperation()); std::shared_ptr solver = createDataFlowSolver(); AMD::TritonIntegerRangeAnalysis *rangeAnalysis = - solver->load(assumptions); + solver->load( + assumptions, &getAnalysis()); AMD::initializeFuncOps(mod, rangeAnalysis); if (failed(solver->initializeAndRun(getOperation()))) return signalPassFailure();
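All three passes above now construct the analysis the same way. For reference, a sketch of that call site with the template arguments (stripped in this diff rendering) reconstructed by hand; it is a fragment of a pass body, not a standalone program, and the exact generic parameters should be treated as assumptions:

    // Inside a pass's runOnOperation(); mirrors the three call sites above.
    DenseMap<Value, SetVector<Operation *>> assumptions =
        AMD::TritonIntegerRangeAnalysis::collectAssumptions(getOperation());
    std::shared_ptr<DataFlowSolver> solver = createDataFlowSolver();
    AMD::TritonIntegerRangeAnalysis *rangeAnalysis =
        solver->load<AMD::TritonIntegerRangeAnalysis>(
            assumptions, &getAnalysis<mlir::DominanceInfo>());
    AMD::initializeFuncOps(mod, rangeAnalysis);
    if (failed(solver->initializeAndRun(getOperation())))
      return signalPassFailure();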