19 changes: 14 additions & 5 deletions test/TritonGPU/amd/amd-convert-buffer-ops-small-tensor.mlir
@@ -36,13 +36,15 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} {
%8 = tt.addptr %7, %4 : tensor<256x!tt.ptr<f32>, #blocked0>, tensor<256xi32, #blocked0>
// COMMON: buffer_load %arg0[%[[offset]]]
%9 = tt.load %6 : tensor<256x!tt.ptr<f32>, #blocked0>
// COMMON: buffer_load %arg1[%[[offset]]]
// Note: offset = pid * 256 + arange(0, 256); the byte offset, offset * sizeof(i32), may not fall within the 2 GB range.
// COMMON-NOT: buffer_load %arg1[%[[offset]]]
%10 = tt.load %8 : tensor<256x!tt.ptr<f32>, #blocked0>
// COMMON: %[[data:.*]] = arith.addf
%11 = arith.addf %9, %10 : tensor<256xf32, #blocked0>
%12 = tt.splat %arg2 : !tt.ptr<f32> -> tensor<256x!tt.ptr<f32>, #blocked0>
%13 = tt.addptr %12, %4 : tensor<256x!tt.ptr<f32>, #blocked0>, tensor<256xi32, #blocked0>
// COMMON: buffer_store %[[data]], %arg2[%[[offset]]]
// Note: see the explanation above
// COMMON-NOT: buffer_store %[[data]], %arg2[%[[offset]]]
tt.store %13, %11 : tensor<256x!tt.ptr<f32>, #blocked0>
tt.return
}
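
The notes in this test file all come down to one constraint: a load or store may only be rewritten into a buffer op when its byte offset (element offset times element size) provably lies in a non-negative signed 32-bit range. Below is a minimal standalone C++ sketch of that feasibility check, assuming 4-byte elements; it illustrates the constraint and is not the pass's actual implementation.

#include <cstdint>

// Sketch only: given the proven bounds of an element offset, decide whether
// every resulting byte offset is guaranteed to land in [0, 2^31 - 1], which is
// what the buffer-op rewrite requires.
bool byteOffsetFitsIn2GB(int64_t minElemOffset, int64_t maxElemOffset,
                         int64_t elemSizeBytes = 4) {
  if (minElemOffset < 0)
    return false; // cannot prove the offset is non-negative
  // Divide instead of multiplying so the check itself cannot overflow.
  return maxElemOffset <= INT32_MAX / elemSizeBytes;
}

// With offset = pid * 256 + arange(0, 256) and pid only known to be
// non-negative, maxElemOffset can approach INT32_MAX, so the check fails and
// the accesses flagged with COMMON-NOT above stay as plain tt.load/tt.store.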
@@ -70,7 +72,10 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
%5 = tt.addptr %arg0, %1 : !tt.ptr<f32>, i32
%8 = tt.splat %5 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>, #blocked>
%9 = tt.addptr %8, %4 : tensor<1024x!tt.ptr<f32>, #blocked>, tensor<1024xi32, #blocked>
// COMMON: buffer_load %[[scalar_ptr]][%[[offset]]]
// Note: the base "scalar_ptr" points to arg0 which is a large-tensor.
// the offset="%sub + arange(0,1024)" where "%sub=pid*1024-128",
// We can prove "offset > 0", but cannot prove byte-offset < 2G.
// COMMON-NOT: buffer_load %[[scalar_ptr]][%[[offset]]]
%10 = tt.load %9 : tensor<1024x!tt.ptr<f32>, #blocked>
tt.return %10 : tensor<1024xf32, #blocked>
}
@@ -122,7 +127,9 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
// COMMON: %[[offset_32_bit:.*]] = arith.trunci
%narrow4 = arith.trunci %4 : tensor<1024xi64, #blocked> to tensor <1024xi32, #blocked>
%9 = tt.addptr %8, %narrow4 : tensor<1024x!tt.ptr<f32>, #blocked>, tensor<1024xi32, #blocked>
// COMMON: buffer_load %[[scalar_ptr]][%[[offset_32_bit]]]
// Note: the base is %arg0, which is a large tensor; offset = int(long(pid * 1024) * long(arange(0, 1024))),
// so the offset is in [0, i32-max].
// COMMON-NOT: buffer_load %[[scalar_ptr]][%[[offset_32_bit]]]
%10 = tt.load %9 : tensor<1024x!tt.ptr<f32>, #blocked>
tt.return %10 : tensor<1024xf32, #blocked>
}
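
A quick arithmetic check for the note in this hunk, assuming 4-byte f32 elements: even when the element offset is proven to lie in [0, INT32_MAX], scaling it to a byte offset can overflow the 2 GB limit, so the load is left unconverted.

// Worked check (C++): the largest i32 element offset, scaled by 4 bytes,
// exceeds 2^31, so the byte offset cannot be proven to fit.
static_assert(4LL * ((1LL << 31) - 1) > (1LL << 31),
              "a 4-byte-scaled i32 element offset can exceed 2^31");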
@@ -555,7 +562,9 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
%5 = tt.addptr %arg0, %1 : !tt.ptr<f32>, i32
%6 = tt.splat %5 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>, #blocked>
%7 = tt.addptr %6, %4 : tensor<1024x!tt.ptr<f32>, #blocked>, tensor<1024xi32, #blocked>
// COMMON: %[[loaded:.*]] = amdgpu.buffer_atomic_rmw fadd, acq_rel, gpu, %arg1, %[[scalar_ptr]][%[[offset]]]
// Note: the large tensor is accessed; the offset is in the range [0, smax].
// Without tl.assume, the range would be [-128, smax].
// COMMON-NOT: amdgpu.buffer_atomic_rmw
%8 = tt.atomic_rmw fadd, acq_rel, gpu, %7, %arg1 : (tensor<1024x!tt.ptr<f32>, #blocked>, tensor<1024xf32, #blocked>) -> tensor<1024xf32, #blocked>
tt.return %8 : tensor<1024xf32, #blocked>
}
197 changes: 174 additions & 23 deletions test/TritonGPU/amd/amd-convert-buffer-ops.mlir

Large diffs are not rendered by default.

455 changes: 353 additions & 102 deletions test/TritonGPU/amd/amd-range-analysis.mlir

Large diffs are not rendered by default.

39 changes: 36 additions & 3 deletions third_party/amd/include/Analysis/RangeAnalysis.h
@@ -4,6 +4,7 @@
#include "mlir/Analysis/DataFlow/IntegerRangeAnalysis.h"
#include "mlir/Analysis/DataFlow/SparseAnalysis.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/IR/Dominance.h"
#include "mlir/Interfaces/LoopLikeInterface.h"

namespace mlir::triton {
@@ -32,15 +33,20 @@ namespace mlir::triton::AMD {
/// See visitRegionSuccessors.
struct TritonIntegerRangeAnalysis : dataflow::IntegerRangeAnalysis {
using dataflow::IntegerRangeAnalysis::IntegerRangeAnalysis;
using Base = dataflow::IntegerRangeAnalysis;
TritonIntegerRangeAnalysis(
DataFlowSolver &solver,
const DenseMap<Value, SetVector<Operation *>> &assumptions)
: dataflow::IntegerRangeAnalysis(solver), assumptions(assumptions) {}
const DenseMap<Value, SetVector<Operation *>> &assumptions,
DominanceInfo *dominanceInfo, bool assumeNoArithOverflow_ = false)
: dataflow::IntegerRangeAnalysis(solver), assumptions(assumptions),
domInfo(dominanceInfo), assumeNoArithOverflow(assumeNoArithOverflow_) {}

void setToEntryState(dataflow::IntegerValueRangeLattice *lattice) override;

void initializeFuncOp(triton::FuncOp funcOp);

LogicalResult initialize(Operation *top) override;

LogicalResult visitOperation(
Operation *op,
ArrayRef<const dataflow::IntegerValueRangeLattice *> operands,
@@ -95,7 +101,8 @@ struct TritonIntegerRangeAnalysis : dataflow::IntegerRangeAnalysis {
/// llvm.intr.assume %assumesltlhs : i1
/// for %K, will produce a final range
/// [0, 2147483647] ∩ [-2147483648, 128] = [0, 128]
std::optional<ConstantIntRanges> maybeGetAssumedRange(Value anchor) const;
std::optional<ConstantIntRanges> maybeGetAssumedRange(Value anchor,
Block *useBlock) const;

int64_t getTotalLoopTripCount(LoopLikeOpInterface loop);

@@ -125,6 +132,32 @@ struct TritonIntegerRangeAnalysis : dataflow::IntegerRangeAnalysis {
/// If one uses collectAssumptions below then `assumptions` will look like
/// %K -> {arith.cmpi slt..., arith.cmpi sge}.
llvm::DenseMap<Value, SetVector<Operation *>> assumptions;

/// The defaultTransferFunc is the default transfer function for this dataflow
/// problem.
/// @param[in] op: the Operation in question
/// @param[in] result: a particular value defined by this op. Note that op
/// may define multiple values.
/// @param[in] srcLattices: lattices of all source operands
/// @param[in] destLattices: lattices of all result values
/// @param[in] incomingRange: the value range inferred for `result`
void defaultTransferFunc(
Operation *op, Value result,
ArrayRef<const dataflow::IntegerValueRangeLattice *> srcLattices,
ArrayRef<dataflow::IntegerValueRangeLattice *> destLattices,
const IntegerValueRange &incomingRange);

private:
void visitYieldHelper(Operation *yieldOp, Value value);
LogicalResult visitOperationHelper(
Operation *op,
ArrayRef<const dataflow::IntegerValueRangeLattice *> operands,
ArrayRef<dataflow::IntegerValueRangeLattice *> resultsLattices);

DenseSet<Value> signedIntValues;
llvm::SmallMapVector<Value, ConstantIntRanges, 2> opResultAssumption;
DominanceInfo *domInfo = nullptr;
bool assumeNoArithOverflow = false;
};

std::optional<SmallVector<std::optional<ConstantIntRanges>>>
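
For context on the constructor parameters added in this header (the assumptions map, DominanceInfo, and assumeNoArithOverflow), here is a hedged sketch of how a pass might set up and run the analysis. collectAssumptions is the helper mentioned in the comment on the assumptions member; its exact signature, as well as the include paths, are assumed here, and this is not the pass's actual driver code.

#include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h"
#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
#include "mlir/Analysis/DataFlowFramework.h"
#include "mlir/IR/BuiltinOps.h"
#include "Analysis/RangeAnalysis.h" // include path assumed

using namespace mlir;

static LogicalResult runRangeAnalysis(ModuleOp module) {
  // Gather llvm.intr.assume-backed comparisons, keyed by the value they
  // constrain (the signature of collectAssumptions is assumed).
  DenseMap<Value, SetVector<Operation *>> assumptions =
      triton::AMD::collectAssumptions(module);

  DominanceInfo domInfo(module);
  DataFlowSolver solver;
  // Dead-code analysis and sparse constant propagation are the usual
  // prerequisites for sparse dataflow analyses in MLIR.
  solver.load<dataflow::DeadCodeAnalysis>();
  solver.load<dataflow::SparseConstantPropagation>();
  // load<> forwards these arguments to the constructor shown in the diff.
  solver.load<triton::AMD::TritonIntegerRangeAnalysis>(
      assumptions, &domInfo, /*assumeNoArithOverflow=*/true);
  return solver.initializeAndRun(module);
}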