Skip to content
Merged
52 changes: 52 additions & 0 deletions clang/include/clang/CIR/Dialect/IR/CIROps.td
Original file line number Diff line number Diff line change
Expand Up @@ -3132,6 +3132,58 @@ def CIR_GetMethodOp : CIR_Op<"get_method"> {
let hasVerifier = 1;
}

//===----------------------------------------------------------------------===//
// GetElementOp
//===----------------------------------------------------------------------===//

def CIR_GetElementOp : CIR_Op<"get_element"> {
let summary = "Get the address of an array element";
let description = [{
The `cir.get_element` operation gets the address of a particular element
from the `base` array.

It expects a pointer to the `base` array and the `index` of the element.

Example:
```mlir
// Suppose we have a array.
!s32i = !cir.int<s, 32>
!arr_ty = !cir.array<!s32i x 4>

// Get the address of the element at index 1.
%elem_1 = cir.get_element %0[1] : (!cir.ptr<!array_ty>, !s32i) -> !cir.ptr<!s32i>

// Get the address of the element at index %i.
%i = ...
%elem_i = cir.get_element %0[%i] : (!cir.ptr<!array_ty>, !s32i) -> !cir.ptr<!s32i>
```
}];

let arguments = (ins
Arg<CIR_PtrToArray, "the base address of the array ">:$base,
Arg<CIR_AnyFundamentalIntType, "the index of the element">:$index
);

let results = (outs CIR_PointerType:$result);

let assemblyFormat = [{
$base `[` $index `]` `:` `(` qualified(type($base)) `,` qualified(type($index)) `)`
`->` qualified(type($result)) attr-dict
}];

let extraClassDeclaration = [{
// Get the type of the element.
mlir::Type getElementType() {
return getType().getPointee();
}
cir::PointerType getBaseType() {
return mlir::cast<cir::PointerType>(getBase().getType());
}
}];

let hasVerifier = 1;
}

//===----------------------------------------------------------------------===//
// VecInsertOp
//===----------------------------------------------------------------------===//
Expand Down
39 changes: 38 additions & 1 deletion clang/lib/CIR/CodeGen/CIRGenBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,48 @@ mlir::Value CIRGenBuilderTy::maybeBuildArrayDecay(mlir::Location loc,
return arrayPtr;
}

mlir::Value CIRGenBuilderTy::getArrayElement(mlir::Location arrayLocBegin,
mlir::Value CIRGenBuilderTy::promoteArrayIndex(const clang::TargetInfo &ti,
mlir::Location loc,
mlir::Value index) {
// Get the array index type.
auto arrayIndexWidth = ti.getTypeWidth(clang::TargetInfo::IntType::SignedInt);
mlir::Type arrayIndexType = getSIntNTy(arrayIndexWidth);

// If this is a boolean, zero-extend it to the array index type.
if (auto boolTy = mlir::dyn_cast<cir::BoolType>(index.getType()))
return create<cir::CastOp>(loc, arrayIndexType, cir::CastKind::bool_to_int,
index);

// If this an integer, ensure that it is at least as width as the array index
// type.
if (auto intTy = mlir::dyn_cast<cir::IntType>(index.getType())) {
if (intTy.getWidth() < arrayIndexWidth)
return create<cir::CastOp>(loc, arrayIndexType, cir::CastKind::integral,
index);
}

return index;
}

mlir::Value CIRGenBuilderTy::getArrayElement(const clang::TargetInfo &ti,
mlir::Location arrayLocBegin,
mlir::Location arrayLocEnd,
mlir::Value arrayPtr,
mlir::Type eltTy, mlir::Value idx,
bool shouldDecay) {
auto arrayPtrTy = mlir::dyn_cast<cir::PointerType>(arrayPtr.getType());
assert(arrayPtrTy && "expected pointer type");

// If the array pointer is not decayed, emit a GetElementOp.
auto arrayTy = mlir::dyn_cast<cir::ArrayType>(arrayPtrTy.getPointee());
if (shouldDecay && arrayTy && arrayTy == eltTy) {
auto eltPtrTy =
getPointerTo(arrayTy.getElementType(), arrayPtrTy.getAddrSpace());
return create<cir::GetElementOp>(arrayLocEnd, eltPtrTy, arrayPtr,
promoteArrayIndex(ti, arrayLocBegin, idx));
}

// If we don't have sufficient type information, emit a PtrStrideOp.
mlir::Value basePtr = arrayPtr;
if (shouldDecay)
basePtr = maybeBuildArrayDecay(arrayLocBegin, arrayPtr, eltTy);
Expand Down
8 changes: 7 additions & 1 deletion clang/lib/CIR/CodeGen/CIRGenBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

#include "clang/AST/Decl.h"
#include "clang/AST/Type.h"
#include "clang/Basic/TargetInfo.h"
#include "clang/CIR/Dialect/Builder/CIRBaseBuilder.h"
#include "clang/CIR/Dialect/IR/CIRAttrs.h"
#include "clang/CIR/Dialect/IR/CIRDataLayout.h"
Expand Down Expand Up @@ -1030,10 +1031,15 @@ class CIRGenBuilderTy : public cir::CIRBaseBuilderTy {
return create<cir::GetRuntimeMemberOp>(loc, resultTy, objectPtr, memberPtr);
}

/// Promote a value for use as an array index.
mlir::Value promoteArrayIndex(const clang::TargetInfo &TargetInfo,
mlir::Location loc, mlir::Value index);

/// Create a cir.ptr_stride operation to get access to an array element.
/// idx is the index of the element to access, shouldDecay is true if the
/// result should decay to a pointer to the element type.
mlir::Value getArrayElement(mlir::Location arrayLocBegin,
mlir::Value getArrayElement(const clang::TargetInfo &targetInfo,
mlir::Location arrayLocBegin,
mlir::Location arrayLocEnd, mlir::Value arrayPtr,
mlir::Type eltTy, mlir::Value idx,
bool shouldDecay);
Expand Down
4 changes: 2 additions & 2 deletions clang/lib/CIR/CodeGen/CIRGenExpr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1710,8 +1710,8 @@ emitArraySubscriptPtr(CIRGenFunction &CGF, mlir::Location beginLoc,
// that would enhance tracking this later in CIR?
if (inbounds)
assert(!cir::MissingFeatures::emitCheckedInBoundsGEP() && "NYI");
return CGM.getBuilder().getArrayElement(beginLoc, endLoc, ptr, eltTy, idx,
shouldDecay);
return CGM.getBuilder().getArrayElement(CGF.getTarget(), beginLoc, endLoc,
ptr, eltTy, idx, shouldDecay);
}

static QualType getFixedSizeElementType(const ASTContext &astContext,
Expand Down
6 changes: 3 additions & 3 deletions clang/lib/CIR/CodeGen/CIRGenExprAgg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -954,9 +954,9 @@ void AggExprEmitter::VisitCXXStdInitializerListExpr(
ArrayType->getElementType()) &&
"Expected std::initializer_list second field to be const E *");

auto ArrayEnd =
Builder.getArrayElement(loc, loc, ArrayPtr.getPointer(),
ArrayPtr.getElementType(), Size, false);
auto ArrayEnd = Builder.getArrayElement(
CGF.getTarget(), loc, loc, ArrayPtr.getPointer(),
ArrayPtr.getElementType(), Size, false);
CGF.emitStoreThroughLValue(RValue::get(ArrayEnd), EndOrLength);
}
}
Expand Down
12 changes: 12 additions & 0 deletions clang/lib/CIR/Dialect/IR/CIRDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3837,6 +3837,18 @@ LogicalResult cir::GetMethodOp::verify() {
return mlir::success();
}

//===----------------------------------------------------------------------===//
// GetMemberOp Definitions
//===----------------------------------------------------------------------===//

LogicalResult cir::GetElementOp::verify() {
auto arrayTy = mlir::cast<cir::ArrayType>(getBaseType().getPointee());
if (getElementType() != arrayTy.getElementType())
return emitError() << "element type mismatch";

return mlir::success();
}

//===----------------------------------------------------------------------===//
// InlineAsmOp Definitions
//===----------------------------------------------------------------------===//
Expand Down
9 changes: 7 additions & 2 deletions clang/lib/CIR/Dialect/Transforms/LifetimeCheck.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1254,6 +1254,12 @@ void LifetimeCheckPass::updatePointsTo(mlir::Value addr, mlir::Value data,
return;
}

if (auto getElemOp = mlir::dyn_cast<cir::GetElementOp>(dataSrcOp)) {
getPmap()[addr].clear();
getPmap()[addr].insert(State::getLocalValue(getElemOp.getBase()));
return;
}

// Initializes ptr types out of known lib calls marked with pointer
// attributes. TODO: find a better way to tag this.
if (auto callOp = dyn_cast<CallOp>(dataSrcOp)) {
Expand Down Expand Up @@ -1945,8 +1951,7 @@ void LifetimeCheckPass::dumpPmap(PMapType &pmap) {
int entry = 0;
for (auto &mapEntry : pmap) {
llvm::errs() << " " << entry << ": " << getVarNameFromValue(mapEntry.first)
<< " "
<< "=> ";
<< " => ";
printPset(mapEntry.second);
llvm::errs() << "\n";
entry++;
Expand Down
129 changes: 96 additions & 33 deletions clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -951,6 +951,51 @@ static mlir::Value getLLVMIntCast(mlir::ConversionPatternRewriter &rewriter,
return rewriter.create<mlir::LLVM::TruncOp>(loc, llvmDstIntTy, llvmSrc);
}

static mlir::Value promoteIndex(mlir::ConversionPatternRewriter &rewriter,
mlir::Value index, uint64_t layoutWidth,
bool isUnsigned) {
auto indexOp = index.getDefiningOp();
if (!indexOp)
return index;

auto indexType = mlir::cast<mlir::IntegerType>(index.getType());
auto width = indexType.getWidth();
if (layoutWidth == width)
return index;

// If the index definition is a unary minus (index = sub 0, x), then we need
// to
bool rewriteSub = false;
auto sub = mlir::dyn_cast<mlir::LLVM::SubOp>(indexOp);
if (sub) {
if (auto lhsConst = dyn_cast<mlir::LLVM::ConstantOp>(
sub.getOperand(0).getDefiningOp())) {
auto lhsConstInt = mlir::dyn_cast<mlir::IntegerAttr>(lhsConst.getValue());
if (lhsConstInt && lhsConstInt.getValue() == 0) {
rewriteSub = true;
index = sub.getOperand(1);
}
}
}

// Handle the cast
auto llvmDstType = mlir::IntegerType::get(rewriter.getContext(), layoutWidth);
index = getLLVMIntCast(rewriter, index, llvmDstType, isUnsigned, width,
layoutWidth);

if (rewriteSub) {
index = rewriter.create<mlir::LLVM::SubOp>(
index.getLoc(),
rewriter.create<mlir::LLVM::ConstantOp>(index.getLoc(), index.getType(),
0),
index);
// TODO: check if the sub is trivially dead now.
rewriter.eraseOp(sub);
}

return index;
}

mlir::LogicalResult CIRToLLVMPtrStrideOpLowering::matchAndRewrite(
cir::PtrStrideOp ptrStrideOp, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const {
Expand All @@ -964,50 +1009,67 @@ mlir::LogicalResult CIRToLLVMPtrStrideOpLowering::matchAndRewrite(
// make it i8 instead.
if (mlir::isa<mlir::LLVM::LLVMVoidType>(elementTy) ||
mlir::isa<mlir::LLVM::LLVMFunctionType>(elementTy))
elementTy = mlir::IntegerType::get(elementTy.getContext(), 8,
mlir::IntegerType::Signless);
elementTy = mlir::IntegerType::get(ctx, 8, mlir::IntegerType::Signless);

// Zero-extend, sign-extend or trunc the pointer value.
auto index = adaptor.getStride();
auto width = mlir::cast<mlir::IntegerType>(index.getType()).getWidth();
mlir::DataLayout LLVMLayout(ptrStrideOp->getParentOfType<mlir::ModuleOp>());
auto layoutWidth =
LLVMLayout.getTypeIndexBitwidth(adaptor.getBase().getType());
auto indexOp = index.getDefiningOp();
if (indexOp && layoutWidth && width != *layoutWidth) {
// If the index comes from a subtraction, make sure the extension happens
// before it. To achieve that, look at unary minus, which already got
// lowered to "sub 0, x".
auto sub = dyn_cast<mlir::LLVM::SubOp>(indexOp);
auto unary = dyn_cast_if_present<cir::UnaryOp>(
ptrStrideOp.getStride().getDefiningOp());
bool rewriteSub =
unary && unary.getKind() == cir::UnaryOpKind::Minus && sub;
if (rewriteSub)
index = indexOp->getOperand(1);

// Handle the cast
auto llvmDstType = mlir::IntegerType::get(ctx, *layoutWidth);
index = getLLVMIntCast(rewriter, index, llvmDstType,
ptrStrideOp.getStride().getType().isUnsigned(),
width, *layoutWidth);

// Rewrite the sub in front of extensions/trunc
if (rewriteSub) {
index = rewriter.create<mlir::LLVM::SubOp>(
index.getLoc(),
rewriter.create<mlir::LLVM::ConstantOp>(index.getLoc(),
index.getType(), 0),
index);
rewriter.eraseOp(sub);
}
if (auto layoutWidth =
LLVMLayout.getTypeIndexBitwidth(adaptor.getBase().getType())) {
bool isUnsigned = false;
if (auto strideTy =
mlir::dyn_cast<cir::IntType>(ptrStrideOp.getOperand(1).getType()))
isUnsigned = strideTy.isUnsigned();
index = promoteIndex(rewriter, index, *layoutWidth, isUnsigned);
}

rewriter.replaceOpWithNewOp<mlir::LLVM::GEPOp>(
ptrStrideOp, resultTy, elementTy, adaptor.getBase(), index);
return mlir::success();
}

mlir::LogicalResult CIRToLLVMGetElementOpLowering::matchAndRewrite(
cir::GetElementOp op, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const {

if (auto arrayTy =
mlir::dyn_cast<cir::ArrayType>(op.getBaseType().getPointee())) {
auto *tc = getTypeConverter();
const auto llResultTy = tc->convertType(op.getType());
auto elementTy = convertTypeForMemory(*tc, dataLayout, op.getElementType());
auto *ctx = elementTy.getContext();

// void and function types doesn't really have a layout to use in GEPs,
// make it i8 instead.
if (mlir::isa<mlir::LLVM::LLVMVoidType>(elementTy) ||
mlir::isa<mlir::LLVM::LLVMFunctionType>(elementTy))
elementTy = mlir::IntegerType::get(ctx, 8, mlir::IntegerType::Signless);

// Zero-extend, sign-extend or trunc the index value.
auto index = adaptor.getIndex();
mlir::DataLayout LLVMLayout(op->getParentOfType<mlir::ModuleOp>());
if (auto layoutWidth =
LLVMLayout.getTypeIndexBitwidth(adaptor.getBase().getType())) {
bool isUnsigned = false;
if (auto strideTy = dyn_cast<cir::IntType>(op.getOperand(1).getType()))
isUnsigned = strideTy.isUnsigned();
index = promoteIndex(rewriter, index, *layoutWidth, isUnsigned);
}

// Since the base address is a pointer to an aggregate, the first
// offset is always zero. The second offset tell us which member it
// will access.
const auto llArrayTy = getTypeConverter()->convertType(arrayTy);
llvm::SmallVector<mlir::LLVM::GEPArg, 2> offset{0, index};
rewriter.replaceOpWithNewOp<mlir::LLVM::GEPOp>(op, llResultTy, llArrayTy,
adaptor.getBase(), offset);

return mlir::success();
}

llvm_unreachable("NYI, GetElementOp lowering to LLVM for non-Array");
}

mlir::LogicalResult CIRToLLVMBaseClassAddrOpLowering::matchAndRewrite(
cir::BaseClassAddrOp baseClassOp, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const {
Expand Down Expand Up @@ -4388,6 +4450,7 @@ void populateCIRToLLVMConversionPatterns(
patterns.add<
// clang-format off
CIRToLLVMPtrStrideOpLowering,
CIRToLLVMGetElementOpLowering,
CIRToLLVMInlineAsmOpLowering
// clang-format on
>(converter, patterns.getContext(), dataLayout);
Expand Down
16 changes: 16 additions & 0 deletions clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,22 @@ class CIRToLLVMPtrStrideOpLowering
mlir::ConversionPatternRewriter &) const override;
};

class CIRToLLVMGetElementOpLowering
: public mlir::OpConversionPattern<cir::GetElementOp> {
mlir::DataLayout const &dataLayout;

public:
CIRToLLVMGetElementOpLowering(const mlir::TypeConverter &typeConverter,
mlir::MLIRContext *context,
mlir::DataLayout const &dataLayout)
: OpConversionPattern(typeConverter, context), dataLayout(dataLayout) {}
using mlir::OpConversionPattern<cir::GetElementOp>::OpConversionPattern;

mlir::LogicalResult
matchAndRewrite(cir::GetElementOp op, OpAdaptor,
mlir::ConversionPatternRewriter &) const override;
};

class CIRToLLVMBaseClassAddrOpLowering
: public mlir::OpConversionPattern<cir::BaseClassAddrOp> {
public:
Expand Down
Loading
Loading