Skip to content

Commit 62f3e26

Browse files
[CIR][ThroughMLIR] Lower structs and GetMemberOp.
1 parent a725efb commit 62f3e26

File tree

3 files changed

+117
-45
lines changed

3 files changed

+117
-45
lines changed

clang/lib/CIR/Dialect/IR/CIRTypes.cpp

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -569,20 +569,22 @@ uint64_t RecordType::getElementOffset(const ::mlir::DataLayout &dataLayout,
569569
assert(idx < getNumElements());
570570
auto members = getMembers();
571571

572-
unsigned offset = 0;
572+
unsigned offset = 0, recordSize = 0;
573573

574-
for (unsigned i = 0, e = idx; i != e; ++i) {
574+
for (unsigned i = 0, e = idx; i != e + 1; ++i) {
575575
auto ty = members[i];
576576

577577
// This matches LLVM since it uses the ABI instead of preferred alignment.
578578
const llvm::Align tyAlign =
579579
llvm::Align(getPacked() ? 1 : dataLayout.getTypeABIAlignment(ty));
580580

581581
// Add padding if necessary to align the data element properly.
582-
offset = llvm::alignTo(offset, tyAlign);
582+
recordSize = llvm::alignTo(recordSize, tyAlign);
583+
if (i == idx)
584+
offset = recordSize;
583585

584586
// Consume space for this data item
585-
offset += dataLayout.getTypeSize(ty);
587+
recordSize += dataLayout.getTypeSize(ty);
586588
}
587589

588590
// Account for padding, if necessary, for the alignment of the field whose
@@ -781,8 +783,8 @@ LongDoubleType::getTypeSizeInBits(const mlir::DataLayout &dataLayout,
781783
uint64_t
782784
LongDoubleType::getABIAlignment(const mlir::DataLayout &dataLayout,
783785
mlir::DataLayoutEntryListRef params) const {
784-
return mlir::cast<mlir::DataLayoutTypeInterface>(getUnderlying()).getABIAlignment(
785-
dataLayout, params);
786+
return mlir::cast<mlir::DataLayoutTypeInterface>(getUnderlying())
787+
.getABIAlignment(dataLayout, params);
786788
}
787789

788790
//===----------------------------------------------------------------------===//

clang/lib/CIR/Lowering/ThroughMLIR/LowerCIRToMLIR.cpp

Lines changed: 84 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
//===----------------------------------------------------------------------===//
1212

1313
#include "LowerToMLIRHelpers.h"
14+
#include "mlir/Analysis/DataLayoutAnalysis.h"
1415
#include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
1516
#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
1617
#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
@@ -35,8 +36,10 @@
3536
#include "mlir/IR/Operation.h"
3637
#include "mlir/IR/Region.h"
3738
#include "mlir/IR/TypeRange.h"
39+
#include "mlir/IR/Types.h"
3840
#include "mlir/IR/Value.h"
3941
#include "mlir/IR/ValueRange.h"
42+
#include "mlir/Interfaces/DataLayoutInterfaces.h"
4043
#include "mlir/Pass/Pass.h"
4144
#include "mlir/Pass/PassManager.h"
4245
#include "mlir/Support/LLVM.h"
@@ -48,19 +51,17 @@
4851
#include "mlir/Transforms/DialectConversion.h"
4952
#include "clang/CIR/Dialect/IR/CIRDialect.h"
5053
#include "clang/CIR/Dialect/IR/CIRTypes.h"
54+
#include "clang/CIR/Interfaces/CIRLoopOpInterface.h"
5155
#include "clang/CIR/LowerToLLVM.h"
5256
#include "clang/CIR/LowerToMLIR.h"
5357
#include "clang/CIR/LoweringHelpers.h"
5458
#include "clang/CIR/Passes.h"
5559
#include "llvm/ADT/STLExtras.h"
56-
#include "llvm/Support/ErrorHandling.h"
57-
#include "clang/CIR/Interfaces/CIRLoopOpInterface.h"
58-
#include "clang/CIR/LowerToLLVM.h"
59-
#include "clang/CIR/Passes.h"
6060
#include "llvm/ADT/Sequence.h"
6161
#include "llvm/ADT/SmallVector.h"
6262
#include "llvm/ADT/TypeSwitch.h"
6363
#include "llvm/IR/Value.h"
64+
#include "llvm/Support/ErrorHandling.h"
6465
#include "llvm/Support/TimeProfiler.h"
6566

6667
using namespace cir;
@@ -288,17 +289,17 @@ class CIRAllocaOpLowering : public mlir::OpConversionPattern<cir::AllocaOp> {
288289
matchAndRewrite(cir::AllocaOp op, OpAdaptor adaptor,
289290
mlir::ConversionPatternRewriter &rewriter) const override {
290291

291-
mlir::Type mlirType =
292-
convertTypeForMemory(*getTypeConverter(), adaptor.getAllocaType());
292+
mlir::Type allocaType = adaptor.getAllocaType();
293+
mlir::Type mlirType = convertTypeForMemory(*getTypeConverter(), allocaType);
293294

294295
// FIXME: Some types can not be converted yet (e.g. struct)
295296
if (!mlirType)
296297
return mlir::LogicalResult::failure();
297298

298299
auto memreftype = mlir::dyn_cast<mlir::MemRefType>(mlirType);
299-
if (memreftype && mlir::isa<cir::ArrayType>(adaptor.getAllocaType())) {
300-
// if the type is an array,
301-
// we don't need to wrap with memref.
300+
if (memreftype && (mlir::isa<cir::ArrayType>(allocaType) ||
301+
mlir::isa<cir::RecordType>(allocaType))) {
302+
// Arrays and structs are already memref. No need to wrap another one.
302303
} else {
303304
memreftype = mlir::MemRefType::get({}, mlirType);
304305
}
@@ -946,8 +947,8 @@ class CIRScopeOpLowering : public mlir::OpConversionPattern<cir::ScopeOp> {
946947
} else {
947948
// For scopes with results, use scf.execute_region
948949
SmallVector<mlir::Type> types;
949-
if (mlir::failed(
950-
getTypeConverter()->convertTypes(scopeOp->getResultTypes(), types)))
950+
if (mlir::failed(getTypeConverter()->convertTypes(
951+
scopeOp->getResultTypes(), types)))
951952
return mlir::failure();
952953
auto exec =
953954
rewriter.create<mlir::scf::ExecuteRegionOp>(scopeOp.getLoc(), types);
@@ -1485,6 +1486,35 @@ class CIRPtrStrideOpLowering
14851486
}
14861487
};
14871488

1489+
class CIRGetMemberOpLowering
1490+
: public mlir::OpConversionPattern<cir::GetMemberOp> {
1491+
public:
1492+
CIRGetMemberOpLowering(mlir::TypeConverter &converter, mlir::MLIRContext *ctx,
1493+
const mlir::DataLayout &layout)
1494+
: OpConversionPattern(converter, ctx), layout(layout) {}
1495+
mlir::LogicalResult
1496+
matchAndRewrite(cir::GetMemberOp op, OpAdaptor adaptor,
1497+
mlir::ConversionPatternRewriter &rewriter) const override {
1498+
auto baseAddr = op.getAddr();
1499+
auto structType =
1500+
mlir::cast<cir::RecordType>(baseAddr.getType().getPointee());
1501+
1502+
uint64_t byteOffset = structType.getElementOffset(layout, op.getIndex());
1503+
auto fieldType = op.getResult().getType();
1504+
1505+
auto resultType = mlir::cast<mlir::MemRefType>(
1506+
getTypeConverter()->convertType(fieldType));
1507+
mlir::Value offsetValue =
1508+
rewriter.create<mlir::arith::ConstantIndexOp>(op.getLoc(), byteOffset);
1509+
rewriter.replaceOpWithNewOp<mlir::memref::ViewOp>(
1510+
op, resultType, adaptor.getAddr(), offsetValue, mlir::ValueRange{});
1511+
return mlir::success();
1512+
}
1513+
1514+
private:
1515+
const mlir::DataLayout &layout;
1516+
};
1517+
14881518
class CIRUnreachableOpLowering
14891519
: public mlir::OpConversionPattern<cir::UnreachableOp> {
14901520
public:
@@ -1516,37 +1546,41 @@ class CIRTrapOpLowering : public mlir::OpConversionPattern<cir::TrapOp> {
15161546
};
15171547

15181548
void populateCIRToMLIRConversionPatterns(mlir::RewritePatternSet &patterns,
1519-
mlir::TypeConverter &converter) {
1549+
mlir::TypeConverter &converter,
1550+
mlir::DataLayout layout) {
15201551
patterns.add<CIRReturnLowering, CIRBrOpLowering>(patterns.getContext());
15211552

1522-
patterns
1523-
.add<CIRATanOpLowering, CIRCmpOpLowering, CIRCallOpLowering,
1524-
CIRUnaryOpLowering, CIRBinOpLowering, CIRLoadOpLowering,
1525-
CIRConstantOpLowering, CIRStoreOpLowering, CIRAllocaOpLowering,
1526-
CIRFuncOpLowering, CIRBrCondOpLowering,
1527-
CIRTernaryOpLowering, CIRYieldOpLowering, CIRCosOpLowering,
1528-
CIRGlobalOpLowering, CIRGetGlobalOpLowering, CIRCastOpLowering,
1529-
CIRPtrStrideOpLowering, CIRGetElementOpLowering, CIRSqrtOpLowering,
1530-
CIRCeilOpLowering, CIRExp2OpLowering, CIRExpOpLowering,
1531-
CIRFAbsOpLowering, CIRAbsOpLowering, CIRFloorOpLowering,
1532-
CIRLog10OpLowering, CIRLog2OpLowering, CIRLogOpLowering,
1533-
CIRRoundOpLowering, CIRSinOpLowering, CIRShiftOpLowering,
1534-
CIRBitClzOpLowering, CIRBitCtzOpLowering, CIRBitPopcountOpLowering,
1535-
CIRBitClrsbOpLowering, CIRBitFfsOpLowering, CIRBitParityOpLowering,
1536-
CIRIfOpLowering, CIRVectorCreateLowering, CIRVectorInsertLowering,
1537-
CIRVectorExtractLowering, CIRVectorCmpOpLowering, CIRACosOpLowering,
1538-
CIRASinOpLowering, CIRUnreachableOpLowering, CIRTanOpLowering,
1539-
CIRTrapOpLowering>(converter, patterns.getContext());
1553+
patterns.add<
1554+
CIRATanOpLowering, CIRCmpOpLowering, CIRCallOpLowering,
1555+
CIRUnaryOpLowering, CIRBinOpLowering, CIRLoadOpLowering,
1556+
CIRConstantOpLowering, CIRStoreOpLowering, CIRAllocaOpLowering,
1557+
CIRFuncOpLowering, CIRBrCondOpLowering, CIRTernaryOpLowering,
1558+
CIRYieldOpLowering, CIRCosOpLowering, CIRGlobalOpLowering,
1559+
CIRGetGlobalOpLowering, CIRCastOpLowering, CIRPtrStrideOpLowering,
1560+
CIRGetElementOpLowering, CIRSqrtOpLowering, CIRCeilOpLowering,
1561+
CIRExp2OpLowering, CIRExpOpLowering, CIRFAbsOpLowering, CIRAbsOpLowering,
1562+
CIRFloorOpLowering, CIRLog10OpLowering, CIRLog2OpLowering,
1563+
CIRLogOpLowering, CIRRoundOpLowering, CIRSinOpLowering,
1564+
CIRShiftOpLowering, CIRBitClzOpLowering, CIRBitCtzOpLowering,
1565+
CIRBitPopcountOpLowering, CIRBitClrsbOpLowering, CIRBitFfsOpLowering,
1566+
CIRBitParityOpLowering, CIRIfOpLowering, CIRVectorCreateLowering,
1567+
CIRVectorInsertLowering, CIRVectorExtractLowering, CIRVectorCmpOpLowering,
1568+
CIRACosOpLowering, CIRASinOpLowering, CIRUnreachableOpLowering,
1569+
CIRTanOpLowering, CIRTrapOpLowering>(converter, patterns.getContext());
1570+
1571+
patterns.add<CIRGetMemberOpLowering>(converter, patterns.getContext(),
1572+
layout);
15401573
}
15411574

1542-
static mlir::TypeConverter prepareTypeConverter() {
1575+
static mlir::TypeConverter prepareTypeConverter(mlir::DataLayout layout) {
15431576
mlir::TypeConverter converter;
15441577
converter.addConversion([&](cir::PointerType type) -> mlir::Type {
1545-
auto ty = convertTypeForMemory(converter, type.getPointee());
1578+
auto pointee = type.getPointee();
1579+
auto ty = convertTypeForMemory(converter, pointee);
15461580
// FIXME: The pointee type might not be converted (e.g. struct)
15471581
if (!ty)
15481582
return nullptr;
1549-
if (isa<cir::ArrayType>(type.getPointee()))
1583+
if (isa<cir::ArrayType>(pointee) || isa<cir::RecordType>(pointee))
15501584
return ty;
15511585
return mlir::MemRefType::get({}, ty);
15521586
});
@@ -1598,6 +1632,13 @@ static mlir::TypeConverter prepareTypeConverter() {
15981632
return nullptr;
15991633
return mlir::MemRefType::get(shape, elementType);
16001634
});
1635+
converter.addConversion([&](cir::RecordType type) -> mlir::Type {
1636+
// Reinterpret structs as raw bytes. Don't use tuples as they can't be put
1637+
// in memref.
1638+
auto size = type.getTypeSize(layout, {});
1639+
auto i8 = mlir::IntegerType::get(type.getContext(), /*width=*/8);
1640+
return mlir::MemRefType::get(size.getFixedValue(), i8);
1641+
});
16011642
converter.addConversion([&](cir::VectorType type) -> mlir::Type {
16021643
auto ty = converter.convertType(type.getElementType());
16031644
return mlir::VectorType::get(type.getSize(), ty);
@@ -1609,12 +1650,15 @@ void ConvertCIRToMLIRPass::runOnOperation() {
16091650
mlir::MLIRContext *context = &getContext();
16101651
mlir::ModuleOp theModule = getOperation();
16111652

1612-
auto converter = prepareTypeConverter();
1613-
1653+
mlir::DataLayoutAnalysis layoutAnalysis(theModule);
1654+
const mlir::DataLayout &layout = layoutAnalysis.getAtOrAbove(theModule);
1655+
1656+
auto converter = prepareTypeConverter(layout);
1657+
16141658
mlir::RewritePatternSet patterns(&getContext());
16151659

16161660
populateCIRLoopToSCFConversionPatterns(patterns, converter);
1617-
populateCIRToMLIRConversionPatterns(patterns, converter);
1661+
populateCIRToMLIRConversionPatterns(patterns, converter, layout);
16181662

16191663
mlir::ConversionTarget target(getContext());
16201664
target.addLegalOp<mlir::ModuleOp>();
@@ -1628,10 +1672,11 @@ void ConvertCIRToMLIRPass::runOnOperation() {
16281672
// cir dialect, for example the `cir.continue`. If we marked cir as illegal
16291673
// here, then MLIR would think any remaining `cir.continue` indicates a
16301674
// failure, which is not what we want.
1631-
1632-
patterns.add<CIRCastOpLowering, CIRIfOpLowering, CIRScopeOpLowering, CIRYieldOpLowering>(converter, context);
16331675

1634-
if (mlir::failed(mlir::applyPartialConversion(theModule, target,
1676+
patterns.add<CIRCastOpLowering, CIRIfOpLowering, CIRScopeOpLowering,
1677+
CIRYieldOpLowering>(converter, context);
1678+
1679+
if (mlir::failed(mlir::applyPartialConversion(theModule, target,
16351680
std::move(patterns)))) {
16361681
signalPassFailure();
16371682
}
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
// RUN: cir-opt %s -cir-to-mlir -o %t.mlir
2+
// RUN: FileCheck --input-file=%t.mlir %s
3+
4+
!s32i = !cir.int<s, 32>
5+
!u8i = !cir.int<u, 8>
6+
!u32i = !cir.int<u, 32>
7+
!ty_S = !cir.record<struct "S" {!u8i, !s32i}>
8+
9+
module {
10+
cir.func @test() {
11+
%1 = cir.alloca !ty_S, !cir.ptr<!ty_S>, ["x"] {alignment = 4 : i64}
12+
%3 = cir.get_member %1[0] {name = "c"} : !cir.ptr<!ty_S> -> !cir.ptr<!u8i>
13+
%5 = cir.get_member %1[1] {name = "i"} : !cir.ptr<!ty_S> -> !cir.ptr<!s32i>
14+
cir.return
15+
}
16+
17+
// CHECK: func.func @test() {
18+
// CHECK: %[[alloca:[a-z0-9]+]] = memref.alloca() {alignment = 4 : i64} : memref<8xi8>
19+
// CHECK: %[[zero:[a-z0-9]+]] = arith.constant 0 : index
20+
// CHECK: memref.view %[[alloca]][%[[zero]]][] : memref<8xi8> to memref<i8>
21+
// CHECK: %[[four:[a-z0-9]+]] = arith.constant 4 : index
22+
// CHECK: %view_0 = memref.view %[[alloca]][%[[four]]][] : memref<8xi8> to memref<i32>
23+
// CHECK: return
24+
// CHECK: }
25+
}

0 commit comments

Comments
 (0)