Skip to content

Commit c6a4f77

Browse files
tommymcmandykaylorxlauko
authored
[CIR] Add get_element operation for computing pointer to array element (#1748)
## Overview

Currently, getting the pointer to an element of an array requires a pointer decay and a (possible) pointer stride. A similar pattern for records has been eliminated with the `cir.get_member` operation. This PR provides a similar level of abstraction for arrays with the `get_element` operation. `get_element` replaces the above pattern with a single operation, which takes a pointer to an array and an index, and produces a pointer to the element at that index. There are many places in CIR analysis and lowering where the `ptr_stride(array_to_ptrdecay(x), i)` pattern is handled as a special case. By subsuming the special case pattern with an explicit operation, we make these analyses and lowering more robust.

## Changes

- Adds the `cir.get_element` operation.
- Extends CIRGen to emit `cir.get_element` for array subscript expressions.
- Updated LifetimeCheck to handle `get_element` operation, subsuming special case analysis of `cir.ptr_stride` operation (did not remove the special case).
- Extends CIR-to-LLVM lowering to lower `cir.get_element` to `llvm.getelementptr`.
- Extends CIR-to-MLIR lowering to lower `cir.get_element` to `memref` operations, matching existing special case `cir.ptr_stride` lowering.

## Additional Notes

Currently, 47.6% of `cir.ptr_stride` operations in the llvm-test-suite (SingleSource and MultiSource) can be replaced by `cir.get_element` operations.
### Operator Breakdown (current) name | count | % -- | -- | -- cir.load | 825221 | 22.27% cir.br | 429822 | 11.60% cir.const | 380381 | 10.26% cir.cast | 325646 | 8.79% cir.store | 309586 | 8.35% cir.get_member | 226895 | 6.12% cir.get_global | 186851 | 5.04% cir.ptr_stride | 158094 | 4.27% cir.call | 144522 | 3.90% cir.binop | 141142 | 3.81% cir.alloca | 134346 | 3.63% cir.brcond | 112864 | 3.05% cir.cmp | 83532 | 2.25% ### Operator Breakdown (with `get_element`) name | count | % -- | -- | -- cir.load | 825221 | 22.74% cir.br | 429822 | 11.84% cir.const | 380381 | 10.48% cir.store | 309586 | 8.53% cir.cast | 248645 | 6.85% cir.get_member | 226895 | 6.25% cir.get_global | 186851 | 5.15% cir.call | 144522 | 3.98% cir.binop | 141142 | 3.89% cir.alloca | 134346 | 3.70% cir.brcond | 112864 | 3.11% cir.cmp | 83532 | 2.30% cir.ptr_stride | 81093 | 2.23% cir.get_elem | 77001 | 2.12% --------- Co-authored-by: Andy Kaylor <akaylor@nvidia.com> Co-authored-by: Henrich Lauko <xlauko@mail.muni.cz>
1 parent 754a11a commit c6a4f77

File tree

17 files changed

+339
-109
lines changed

17 files changed

+339
-109
lines changed

clang/include/clang/CIR/Dialect/IR/CIROps.td

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3132,6 +3132,58 @@ def CIR_GetMethodOp : CIR_Op<"get_method"> {
31323132
let hasVerifier = 1;
31333133
}
31343134

3135+
//===----------------------------------------------------------------------===//
3136+
// GetElementOp
3137+
//===----------------------------------------------------------------------===//
3138+
3139+
def CIR_GetElementOp : CIR_Op<"get_element"> {
3140+
let summary = "Get the address of an array element";
3141+
let description = [{
3142+
The `cir.get_element` operation gets the address of a particular element
3143+
from the `base` array.
3144+
3145+
It expects a pointer to the `base` array and the `index` of the element.
3146+
3147+
Example:
3148+
```mlir
3149+
// Suppose we have an array.
3150+
!s32i = !cir.int<s, 32>
3151+
!array_ty = !cir.array<!s32i x 4>
3152+
3153+
// Get the address of the element at index 1.
3154+
%elem_1 = cir.get_element %0[1] : (!cir.ptr<!array_ty>, !s32i) -> !cir.ptr<!s32i>
3155+
3156+
// Get the address of the element at index %i.
3157+
%i = ...
3158+
%elem_i = cir.get_element %0[%i] : (!cir.ptr<!array_ty>, !s32i) -> !cir.ptr<!s32i>
3159+
```
3160+
}];
3161+
3162+
let arguments = (ins
3163+
Arg<CIR_PtrToArray, "the base address of the array ">:$base,
3164+
Arg<CIR_AnyFundamentalIntType, "the index of the element">:$index
3165+
);
3166+
3167+
let results = (outs CIR_PointerType:$result);
3168+
3169+
let assemblyFormat = [{
3170+
$base `[` $index `]` `:` `(` qualified(type($base)) `,` qualified(type($index)) `)`
3171+
`->` qualified(type($result)) attr-dict
3172+
}];
3173+
3174+
let extraClassDeclaration = [{
3175+
// Get the type of the element.
3176+
mlir::Type getElementType() {
3177+
return getType().getPointee();
3178+
}
3179+
cir::PointerType getBaseType() {
3180+
return mlir::cast<cir::PointerType>(getBase().getType());
3181+
}
3182+
}];
3183+
3184+
let hasVerifier = 1;
3185+
}
3186+
31353187
//===----------------------------------------------------------------------===//
31363188
// VecInsertOp
31373189
//===----------------------------------------------------------------------===//

clang/lib/CIR/CodeGen/CIRGenBuilder.cpp

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,11 +28,48 @@ mlir::Value CIRGenBuilderTy::maybeBuildArrayDecay(mlir::Location loc,
2828
return arrayPtr;
2929
}
3030

31-
mlir::Value CIRGenBuilderTy::getArrayElement(mlir::Location arrayLocBegin,
31+
mlir::Value CIRGenBuilderTy::promoteArrayIndex(const clang::TargetInfo &ti,
32+
mlir::Location loc,
33+
mlir::Value index) {
34+
// Get the array index type.
35+
auto arrayIndexWidth = ti.getTypeWidth(clang::TargetInfo::IntType::SignedInt);
36+
mlir::Type arrayIndexType = getSIntNTy(arrayIndexWidth);
37+
38+
// If this is a boolean, zero-extend it to the array index type.
39+
if (auto boolTy = mlir::dyn_cast<cir::BoolType>(index.getType()))
40+
return create<cir::CastOp>(loc, arrayIndexType, cir::CastKind::bool_to_int,
41+
index);
42+
43+
// If this is an integer, ensure it is at least as wide as the array index
44+
// type.
45+
if (auto intTy = mlir::dyn_cast<cir::IntType>(index.getType())) {
46+
if (intTy.getWidth() < arrayIndexWidth)
47+
return create<cir::CastOp>(loc, arrayIndexType, cir::CastKind::integral,
48+
index);
49+
}
50+
51+
return index;
52+
}
53+
54+
mlir::Value CIRGenBuilderTy::getArrayElement(const clang::TargetInfo &ti,
55+
mlir::Location arrayLocBegin,
3256
mlir::Location arrayLocEnd,
3357
mlir::Value arrayPtr,
3458
mlir::Type eltTy, mlir::Value idx,
3559
bool shouldDecay) {
60+
auto arrayPtrTy = mlir::dyn_cast<cir::PointerType>(arrayPtr.getType());
61+
assert(arrayPtrTy && "expected pointer type");
62+
63+
// If the array pointer is not decayed, emit a GetElementOp.
64+
auto arrayTy = mlir::dyn_cast<cir::ArrayType>(arrayPtrTy.getPointee());
65+
if (shouldDecay && arrayTy && arrayTy == eltTy) {
66+
auto eltPtrTy =
67+
getPointerTo(arrayTy.getElementType(), arrayPtrTy.getAddrSpace());
68+
return create<cir::GetElementOp>(arrayLocEnd, eltPtrTy, arrayPtr,
69+
promoteArrayIndex(ti, arrayLocBegin, idx));
70+
}
71+
72+
// If we don't have sufficient type information, emit a PtrStrideOp.
3673
mlir::Value basePtr = arrayPtr;
3774
if (shouldDecay)
3875
basePtr = maybeBuildArrayDecay(arrayLocBegin, arrayPtr, eltTy);

clang/lib/CIR/CodeGen/CIRGenBuilder.h

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
#include "clang/AST/Decl.h"
1818
#include "clang/AST/Type.h"
19+
#include "clang/Basic/TargetInfo.h"
1920
#include "clang/CIR/Dialect/Builder/CIRBaseBuilder.h"
2021
#include "clang/CIR/Dialect/IR/CIRAttrs.h"
2122
#include "clang/CIR/Dialect/IR/CIRDataLayout.h"
@@ -1030,10 +1031,15 @@ class CIRGenBuilderTy : public cir::CIRBaseBuilderTy {
10301031
return create<cir::GetRuntimeMemberOp>(loc, resultTy, objectPtr, memberPtr);
10311032
}
10321033

1034+
/// Promote a value for use as an array index.
1035+
mlir::Value promoteArrayIndex(const clang::TargetInfo &TargetInfo,
1036+
mlir::Location loc, mlir::Value index);
1037+
10331038
/// Create a cir.get_element (or cir.ptr_stride) op to access an array element.
10341039
/// idx is the index of the element to access, shouldDecay is true if the
10351040
/// result should decay to a pointer to the element type.
1036-
mlir::Value getArrayElement(mlir::Location arrayLocBegin,
1041+
mlir::Value getArrayElement(const clang::TargetInfo &targetInfo,
1042+
mlir::Location arrayLocBegin,
10371043
mlir::Location arrayLocEnd, mlir::Value arrayPtr,
10381044
mlir::Type eltTy, mlir::Value idx,
10391045
bool shouldDecay);

clang/lib/CIR/CodeGen/CIRGenExpr.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1710,8 +1710,8 @@ emitArraySubscriptPtr(CIRGenFunction &CGF, mlir::Location beginLoc,
17101710
// that would enhance tracking this later in CIR?
17111711
if (inbounds)
17121712
assert(!cir::MissingFeatures::emitCheckedInBoundsGEP() && "NYI");
1713-
return CGM.getBuilder().getArrayElement(beginLoc, endLoc, ptr, eltTy, idx,
1714-
shouldDecay);
1713+
return CGM.getBuilder().getArrayElement(CGF.getTarget(), beginLoc, endLoc,
1714+
ptr, eltTy, idx, shouldDecay);
17151715
}
17161716

17171717
static QualType getFixedSizeElementType(const ASTContext &astContext,

clang/lib/CIR/CodeGen/CIRGenExprAgg.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -954,9 +954,9 @@ void AggExprEmitter::VisitCXXStdInitializerListExpr(
954954
ArrayType->getElementType()) &&
955955
"Expected std::initializer_list second field to be const E *");
956956

957-
auto ArrayEnd =
958-
Builder.getArrayElement(loc, loc, ArrayPtr.getPointer(),
959-
ArrayPtr.getElementType(), Size, false);
957+
auto ArrayEnd = Builder.getArrayElement(
958+
CGF.getTarget(), loc, loc, ArrayPtr.getPointer(),
959+
ArrayPtr.getElementType(), Size, false);
960960
CGF.emitStoreThroughLValue(RValue::get(ArrayEnd), EndOrLength);
961961
}
962962
}

clang/lib/CIR/Dialect/IR/CIRDialect.cpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3837,6 +3837,18 @@ LogicalResult cir::GetMethodOp::verify() {
38373837
return mlir::success();
38383838
}
38393839

3840+
//===----------------------------------------------------------------------===//
3841+
// GetElementOp Definitions
3842+
//===----------------------------------------------------------------------===//
3843+
3844+
LogicalResult cir::GetElementOp::verify() {
3845+
auto arrayTy = mlir::cast<cir::ArrayType>(getBaseType().getPointee());
3846+
if (getElementType() != arrayTy.getElementType())
3847+
return emitError() << "element type mismatch";
3848+
3849+
return mlir::success();
3850+
}
3851+
38403852
//===----------------------------------------------------------------------===//
38413853
// InlineAsmOp Definitions
38423854
//===----------------------------------------------------------------------===//

clang/lib/CIR/Dialect/Transforms/LifetimeCheck.cpp

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1254,6 +1254,12 @@ void LifetimeCheckPass::updatePointsTo(mlir::Value addr, mlir::Value data,
12541254
return;
12551255
}
12561256

1257+
if (auto getElemOp = mlir::dyn_cast<cir::GetElementOp>(dataSrcOp)) {
1258+
getPmap()[addr].clear();
1259+
getPmap()[addr].insert(State::getLocalValue(getElemOp.getBase()));
1260+
return;
1261+
}
1262+
12571263
// Initializes ptr types out of known lib calls marked with pointer
12581264
// attributes. TODO: find a better way to tag this.
12591265
if (auto callOp = dyn_cast<CallOp>(dataSrcOp)) {
@@ -1945,8 +1951,7 @@ void LifetimeCheckPass::dumpPmap(PMapType &pmap) {
19451951
int entry = 0;
19461952
for (auto &mapEntry : pmap) {
19471953
llvm::errs() << " " << entry << ": " << getVarNameFromValue(mapEntry.first)
1948-
<< " "
1949-
<< "=> ";
1954+
<< " => ";
19501955
printPset(mapEntry.second);
19511956
llvm::errs() << "\n";
19521957
entry++;

clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp

Lines changed: 96 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -951,6 +951,51 @@ static mlir::Value getLLVMIntCast(mlir::ConversionPatternRewriter &rewriter,
951951
return rewriter.create<mlir::LLVM::TruncOp>(loc, llvmDstIntTy, llvmSrc);
952952
}
953953

954+
static mlir::Value promoteIndex(mlir::ConversionPatternRewriter &rewriter,
955+
mlir::Value index, uint64_t layoutWidth,
956+
bool isUnsigned) {
957+
auto indexOp = index.getDefiningOp();
958+
if (!indexOp)
959+
return index;
960+
961+
auto indexType = mlir::cast<mlir::IntegerType>(index.getType());
962+
auto width = indexType.getWidth();
963+
if (layoutWidth == width)
964+
return index;
965+
966+
// If the index definition is a unary minus (index = sub 0, x), then we need
967+
// to cast x first and re-emit the subtraction on the casted value.
968+
bool rewriteSub = false;
969+
auto sub = mlir::dyn_cast<mlir::LLVM::SubOp>(indexOp);
970+
if (sub) {
971+
if (auto lhsConst = dyn_cast<mlir::LLVM::ConstantOp>(
972+
sub.getOperand(0).getDefiningOp())) {
973+
auto lhsConstInt = mlir::dyn_cast<mlir::IntegerAttr>(lhsConst.getValue());
974+
if (lhsConstInt && lhsConstInt.getValue() == 0) {
975+
rewriteSub = true;
976+
index = sub.getOperand(1);
977+
}
978+
}
979+
}
980+
981+
// Handle the cast
982+
auto llvmDstType = mlir::IntegerType::get(rewriter.getContext(), layoutWidth);
983+
index = getLLVMIntCast(rewriter, index, llvmDstType, isUnsigned, width,
984+
layoutWidth);
985+
986+
if (rewriteSub) {
987+
index = rewriter.create<mlir::LLVM::SubOp>(
988+
index.getLoc(),
989+
rewriter.create<mlir::LLVM::ConstantOp>(index.getLoc(), index.getType(),
990+
0),
991+
index);
992+
// TODO: check if the sub is trivially dead now.
993+
rewriter.eraseOp(sub);
994+
}
995+
996+
return index;
997+
}
998+
954999
mlir::LogicalResult CIRToLLVMPtrStrideOpLowering::matchAndRewrite(
9551000
cir::PtrStrideOp ptrStrideOp, OpAdaptor adaptor,
9561001
mlir::ConversionPatternRewriter &rewriter) const {
@@ -964,50 +1009,67 @@ mlir::LogicalResult CIRToLLVMPtrStrideOpLowering::matchAndRewrite(
9641009
// make it i8 instead.
9651010
if (mlir::isa<mlir::LLVM::LLVMVoidType>(elementTy) ||
9661011
mlir::isa<mlir::LLVM::LLVMFunctionType>(elementTy))
967-
elementTy = mlir::IntegerType::get(elementTy.getContext(), 8,
968-
mlir::IntegerType::Signless);
1012+
elementTy = mlir::IntegerType::get(ctx, 8, mlir::IntegerType::Signless);
9691013

9701014
// Zero-extend, sign-extend or trunc the stride value.
9711015
auto index = adaptor.getStride();
972-
auto width = mlir::cast<mlir::IntegerType>(index.getType()).getWidth();
9731016
mlir::DataLayout LLVMLayout(ptrStrideOp->getParentOfType<mlir::ModuleOp>());
974-
auto layoutWidth =
975-
LLVMLayout.getTypeIndexBitwidth(adaptor.getBase().getType());
976-
auto indexOp = index.getDefiningOp();
977-
if (indexOp && layoutWidth && width != *layoutWidth) {
978-
// If the index comes from a subtraction, make sure the extension happens
979-
// before it. To achieve that, look at unary minus, which already got
980-
// lowered to "sub 0, x".
981-
auto sub = dyn_cast<mlir::LLVM::SubOp>(indexOp);
982-
auto unary = dyn_cast_if_present<cir::UnaryOp>(
983-
ptrStrideOp.getStride().getDefiningOp());
984-
bool rewriteSub =
985-
unary && unary.getKind() == cir::UnaryOpKind::Minus && sub;
986-
if (rewriteSub)
987-
index = indexOp->getOperand(1);
988-
989-
// Handle the cast
990-
auto llvmDstType = mlir::IntegerType::get(ctx, *layoutWidth);
991-
index = getLLVMIntCast(rewriter, index, llvmDstType,
992-
ptrStrideOp.getStride().getType().isUnsigned(),
993-
width, *layoutWidth);
994-
995-
// Rewrite the sub in front of extensions/trunc
996-
if (rewriteSub) {
997-
index = rewriter.create<mlir::LLVM::SubOp>(
998-
index.getLoc(),
999-
rewriter.create<mlir::LLVM::ConstantOp>(index.getLoc(),
1000-
index.getType(), 0),
1001-
index);
1002-
rewriter.eraseOp(sub);
1003-
}
1017+
if (auto layoutWidth =
1018+
LLVMLayout.getTypeIndexBitwidth(adaptor.getBase().getType())) {
1019+
bool isUnsigned = false;
1020+
if (auto strideTy =
1021+
mlir::dyn_cast<cir::IntType>(ptrStrideOp.getOperand(1).getType()))
1022+
isUnsigned = strideTy.isUnsigned();
1023+
index = promoteIndex(rewriter, index, *layoutWidth, isUnsigned);
10041024
}
10051025

10061026
rewriter.replaceOpWithNewOp<mlir::LLVM::GEPOp>(
10071027
ptrStrideOp, resultTy, elementTy, adaptor.getBase(), index);
10081028
return mlir::success();
10091029
}
10101030

1031+
mlir::LogicalResult CIRToLLVMGetElementOpLowering::matchAndRewrite(
1032+
cir::GetElementOp op, OpAdaptor adaptor,
1033+
mlir::ConversionPatternRewriter &rewriter) const {
1034+
1035+
if (auto arrayTy =
1036+
mlir::dyn_cast<cir::ArrayType>(op.getBaseType().getPointee())) {
1037+
auto *tc = getTypeConverter();
1038+
const auto llResultTy = tc->convertType(op.getType());
1039+
auto elementTy = convertTypeForMemory(*tc, dataLayout, op.getElementType());
1040+
auto *ctx = elementTy.getContext();
1041+
1042+
// void and function types don't really have a layout to use in GEPs,
1043+
// make it i8 instead.
1044+
if (mlir::isa<mlir::LLVM::LLVMVoidType>(elementTy) ||
1045+
mlir::isa<mlir::LLVM::LLVMFunctionType>(elementTy))
1046+
elementTy = mlir::IntegerType::get(ctx, 8, mlir::IntegerType::Signless);
1047+
1048+
// Zero-extend, sign-extend or trunc the index value.
1049+
auto index = adaptor.getIndex();
1050+
mlir::DataLayout LLVMLayout(op->getParentOfType<mlir::ModuleOp>());
1051+
if (auto layoutWidth =
1052+
LLVMLayout.getTypeIndexBitwidth(adaptor.getBase().getType())) {
1053+
bool isUnsigned = false;
1054+
if (auto strideTy = dyn_cast<cir::IntType>(op.getOperand(1).getType()))
1055+
isUnsigned = strideTy.isUnsigned();
1056+
index = promoteIndex(rewriter, index, *layoutWidth, isUnsigned);
1057+
}
1058+
1059+
// Since the base address is a pointer to an aggregate, the first
1060+
// offset is always zero. The second offset tells us which element it
1061+
// will access.
1062+
const auto llArrayTy = getTypeConverter()->convertType(arrayTy);
1063+
llvm::SmallVector<mlir::LLVM::GEPArg, 2> offset{0, index};
1064+
rewriter.replaceOpWithNewOp<mlir::LLVM::GEPOp>(op, llResultTy, llArrayTy,
1065+
adaptor.getBase(), offset);
1066+
1067+
return mlir::success();
1068+
}
1069+
1070+
llvm_unreachable("NYI, GetElementOp lowering to LLVM for non-Array");
1071+
}
1072+
10111073
mlir::LogicalResult CIRToLLVMBaseClassAddrOpLowering::matchAndRewrite(
10121074
cir::BaseClassAddrOp baseClassOp, OpAdaptor adaptor,
10131075
mlir::ConversionPatternRewriter &rewriter) const {
@@ -4388,6 +4450,7 @@ void populateCIRToLLVMConversionPatterns(
43884450
patterns.add<
43894451
// clang-format off
43904452
CIRToLLVMPtrStrideOpLowering,
4453+
CIRToLLVMGetElementOpLowering,
43914454
CIRToLLVMInlineAsmOpLowering
43924455
// clang-format on
43934456
>(converter, patterns.getContext(), dataLayout);

clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,22 @@ class CIRToLLVMPtrStrideOpLowering
158158
mlir::ConversionPatternRewriter &) const override;
159159
};
160160

161+
class CIRToLLVMGetElementOpLowering
162+
: public mlir::OpConversionPattern<cir::GetElementOp> {
163+
mlir::DataLayout const &dataLayout;
164+
165+
public:
166+
CIRToLLVMGetElementOpLowering(const mlir::TypeConverter &typeConverter,
167+
mlir::MLIRContext *context,
168+
mlir::DataLayout const &dataLayout)
169+
: OpConversionPattern(typeConverter, context), dataLayout(dataLayout) {}
170+
using mlir::OpConversionPattern<cir::GetElementOp>::OpConversionPattern;
171+
172+
mlir::LogicalResult
173+
matchAndRewrite(cir::GetElementOp op, OpAdaptor,
174+
mlir::ConversionPatternRewriter &) const override;
175+
};
176+
161177
class CIRToLLVMBaseClassAddrOpLowering
162178
: public mlir::OpConversionPattern<cir::BaseClassAddrOp> {
163179
public:

0 commit comments

Comments
 (0)