1111// ===----------------------------------------------------------------------===//
1212
1313#include " LowerToMLIRHelpers.h"
14+ #include " mlir/Analysis/DataLayoutAnalysis.h"
1415#include " mlir/Conversion/AffineToStandard/AffineToStandard.h"
1516#include " mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
1617#include " mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
3536#include " mlir/IR/Operation.h"
3637#include " mlir/IR/Region.h"
3738#include " mlir/IR/TypeRange.h"
39+ #include " mlir/IR/Types.h"
3840#include " mlir/IR/Value.h"
3941#include " mlir/IR/ValueRange.h"
42+ #include " mlir/Interfaces/DataLayoutInterfaces.h"
4043#include " mlir/Pass/Pass.h"
4144#include " mlir/Pass/PassManager.h"
4245#include " mlir/Support/LLVM.h"
4851#include " mlir/Transforms/DialectConversion.h"
4952#include " clang/CIR/Dialect/IR/CIRDialect.h"
5053#include " clang/CIR/Dialect/IR/CIRTypes.h"
54+ #include " clang/CIR/Interfaces/CIRLoopOpInterface.h"
5155#include " clang/CIR/LowerToLLVM.h"
5256#include " clang/CIR/LowerToMLIR.h"
5357#include " clang/CIR/LoweringHelpers.h"
5458#include " clang/CIR/Passes.h"
5559#include " llvm/ADT/STLExtras.h"
56- #include " llvm/Support/ErrorHandling.h"
57- #include " clang/CIR/Interfaces/CIRLoopOpInterface.h"
58- #include " clang/CIR/LowerToLLVM.h"
59- #include " clang/CIR/Passes.h"
6060#include " llvm/ADT/Sequence.h"
6161#include " llvm/ADT/SmallVector.h"
6262#include " llvm/ADT/TypeSwitch.h"
6363#include " llvm/IR/Value.h"
64+ #include " llvm/Support/ErrorHandling.h"
6465#include " llvm/Support/TimeProfiler.h"
6566
6667using namespace cir ;
@@ -288,17 +289,17 @@ class CIRAllocaOpLowering : public mlir::OpConversionPattern<cir::AllocaOp> {
288289 matchAndRewrite (cir::AllocaOp op, OpAdaptor adaptor,
289290 mlir::ConversionPatternRewriter &rewriter) const override {
290291
291- mlir::Type mlirType =
292- convertTypeForMemory (*getTypeConverter (), adaptor. getAllocaType () );
292+ mlir::Type allocaType = adaptor. getAllocaType ();
293+ mlir::Type mlirType = convertTypeForMemory (*getTypeConverter (), allocaType );
293294
294295 // FIXME: Some types can not be converted yet (e.g. struct)
295296 if (!mlirType)
296297 return mlir::LogicalResult::failure ();
297298
298299 auto memreftype = mlir::dyn_cast<mlir::MemRefType>(mlirType);
299- if (memreftype && mlir::isa<cir::ArrayType>(adaptor. getAllocaType ())) {
300- // if the type is an array,
301- // we don't need to wrap with memref .
300+ if (memreftype && ( mlir::isa<cir::ArrayType>(allocaType) ||
301+ mlir::isa<cir::RecordType>(allocaType))) {
302+ // Arrays and structs are already memref. No need to wrap another one .
302303 } else {
303304 memreftype = mlir::MemRefType::get ({}, mlirType);
304305 }
@@ -946,8 +947,8 @@ class CIRScopeOpLowering : public mlir::OpConversionPattern<cir::ScopeOp> {
946947 } else {
947948 // For scopes with results, use scf.execute_region
948949 SmallVector<mlir::Type> types;
949- if (mlir::failed (
950- getTypeConverter ()-> convertTypes ( scopeOp->getResultTypes (), types)))
950+ if (mlir::failed (getTypeConverter ()-> convertTypes (
951+ scopeOp->getResultTypes (), types)))
951952 return mlir::failure ();
952953 auto exec =
953954 rewriter.create <mlir::scf::ExecuteRegionOp>(scopeOp.getLoc (), types);
@@ -1485,6 +1486,35 @@ class CIRPtrStrideOpLowering
14851486 }
14861487};
14871488
1489+ class CIRGetMemberOpLowering
1490+ : public mlir::OpConversionPattern<cir::GetMemberOp> {
1491+ public:
1492+ CIRGetMemberOpLowering (mlir::TypeConverter &converter, mlir::MLIRContext *ctx,
1493+ const mlir::DataLayout &layout)
1494+ : OpConversionPattern(converter, ctx), layout(layout) {}
1495+ mlir::LogicalResult
1496+ matchAndRewrite (cir::GetMemberOp op, OpAdaptor adaptor,
1497+ mlir::ConversionPatternRewriter &rewriter) const override {
1498+ auto baseAddr = op.getAddr ();
1499+ auto structType =
1500+ mlir::cast<cir::RecordType>(baseAddr.getType ().getPointee ());
1501+
1502+ uint64_t byteOffset = structType.getElementOffset (layout, op.getIndex ());
1503+ auto fieldType = op.getResult ().getType ();
1504+
1505+ auto resultType = mlir::cast<mlir::MemRefType>(
1506+ getTypeConverter ()->convertType (fieldType));
1507+ mlir::Value offsetValue =
1508+ rewriter.create <mlir::arith::ConstantIndexOp>(op.getLoc (), byteOffset);
1509+ rewriter.replaceOpWithNewOp <mlir::memref::ViewOp>(
1510+ op, resultType, adaptor.getAddr (), offsetValue, mlir::ValueRange{});
1511+ return mlir::success ();
1512+ }
1513+
1514+ private:
1515+ const mlir::DataLayout &layout;
1516+ };
1517+
14881518class CIRUnreachableOpLowering
14891519 : public mlir::OpConversionPattern<cir::UnreachableOp> {
14901520public:
@@ -1516,37 +1546,41 @@ class CIRTrapOpLowering : public mlir::OpConversionPattern<cir::TrapOp> {
15161546};
15171547
15181548void populateCIRToMLIRConversionPatterns (mlir::RewritePatternSet &patterns,
1519- mlir::TypeConverter &converter) {
1549+ mlir::TypeConverter &converter,
1550+ mlir::DataLayout layout) {
15201551 patterns.add <CIRReturnLowering, CIRBrOpLowering>(patterns.getContext ());
15211552
1522- patterns
1523- .add <CIRATanOpLowering, CIRCmpOpLowering, CIRCallOpLowering,
1524- CIRUnaryOpLowering, CIRBinOpLowering, CIRLoadOpLowering,
1525- CIRConstantOpLowering, CIRStoreOpLowering, CIRAllocaOpLowering,
1526- CIRFuncOpLowering, CIRBrCondOpLowering,
1527- CIRTernaryOpLowering, CIRYieldOpLowering, CIRCosOpLowering,
1528- CIRGlobalOpLowering, CIRGetGlobalOpLowering, CIRCastOpLowering,
1529- CIRPtrStrideOpLowering, CIRGetElementOpLowering, CIRSqrtOpLowering,
1530- CIRCeilOpLowering, CIRExp2OpLowering, CIRExpOpLowering,
1531- CIRFAbsOpLowering, CIRAbsOpLowering, CIRFloorOpLowering,
1532- CIRLog10OpLowering, CIRLog2OpLowering, CIRLogOpLowering,
1533- CIRRoundOpLowering, CIRSinOpLowering, CIRShiftOpLowering,
1534- CIRBitClzOpLowering, CIRBitCtzOpLowering, CIRBitPopcountOpLowering,
1535- CIRBitClrsbOpLowering, CIRBitFfsOpLowering, CIRBitParityOpLowering,
1536- CIRIfOpLowering, CIRVectorCreateLowering, CIRVectorInsertLowering,
1537- CIRVectorExtractLowering, CIRVectorCmpOpLowering, CIRACosOpLowering,
1538- CIRASinOpLowering, CIRUnreachableOpLowering, CIRTanOpLowering,
1539- CIRTrapOpLowering>(converter, patterns.getContext ());
1553+ patterns.add <
1554+ CIRATanOpLowering, CIRCmpOpLowering, CIRCallOpLowering,
1555+ CIRUnaryOpLowering, CIRBinOpLowering, CIRLoadOpLowering,
1556+ CIRConstantOpLowering, CIRStoreOpLowering, CIRAllocaOpLowering,
1557+ CIRFuncOpLowering, CIRBrCondOpLowering, CIRTernaryOpLowering,
1558+ CIRYieldOpLowering, CIRCosOpLowering, CIRGlobalOpLowering,
1559+ CIRGetGlobalOpLowering, CIRCastOpLowering, CIRPtrStrideOpLowering,
1560+ CIRGetElementOpLowering, CIRSqrtOpLowering, CIRCeilOpLowering,
1561+ CIRExp2OpLowering, CIRExpOpLowering, CIRFAbsOpLowering, CIRAbsOpLowering,
1562+ CIRFloorOpLowering, CIRLog10OpLowering, CIRLog2OpLowering,
1563+ CIRLogOpLowering, CIRRoundOpLowering, CIRSinOpLowering,
1564+ CIRShiftOpLowering, CIRBitClzOpLowering, CIRBitCtzOpLowering,
1565+ CIRBitPopcountOpLowering, CIRBitClrsbOpLowering, CIRBitFfsOpLowering,
1566+ CIRBitParityOpLowering, CIRIfOpLowering, CIRVectorCreateLowering,
1567+ CIRVectorInsertLowering, CIRVectorExtractLowering, CIRVectorCmpOpLowering,
1568+ CIRACosOpLowering, CIRASinOpLowering, CIRUnreachableOpLowering,
1569+ CIRTanOpLowering, CIRTrapOpLowering>(converter, patterns.getContext ());
1570+
1571+ patterns.add <CIRGetMemberOpLowering>(converter, patterns.getContext (),
1572+ layout);
15401573}
15411574
1542- static mlir::TypeConverter prepareTypeConverter () {
1575+ static mlir::TypeConverter prepareTypeConverter (mlir::DataLayout layout ) {
15431576 mlir::TypeConverter converter;
15441577 converter.addConversion ([&](cir::PointerType type) -> mlir::Type {
1545- auto ty = convertTypeForMemory (converter, type.getPointee ());
1578+ auto pointee = type.getPointee ();
1579+ auto ty = convertTypeForMemory (converter, pointee);
15461580 // FIXME: The pointee type might not be converted (e.g. struct)
15471581 if (!ty)
15481582 return nullptr ;
1549- if (isa<cir::ArrayType>(type. getPointee () ))
1583+ if (isa<cir::ArrayType>(pointee) || isa<cir::RecordType>(pointee ))
15501584 return ty;
15511585 return mlir::MemRefType::get ({}, ty);
15521586 });
@@ -1598,6 +1632,13 @@ static mlir::TypeConverter prepareTypeConverter() {
15981632 return nullptr ;
15991633 return mlir::MemRefType::get (shape, elementType);
16001634 });
1635+ converter.addConversion ([&](cir::RecordType type) -> mlir::Type {
1636+ // Reinterpret structs as raw bytes. Don't use tuples as they can't be put
1637+ // in memref.
1638+ auto size = type.getTypeSize (layout, {});
1639+ auto i8 = mlir::IntegerType::get (type.getContext (), /* width=*/ 8 );
1640+ return mlir::MemRefType::get (size.getFixedValue (), i8 );
1641+ });
16011642 converter.addConversion ([&](cir::VectorType type) -> mlir::Type {
16021643 auto ty = converter.convertType (type.getElementType ());
16031644 return mlir::VectorType::get (type.getSize (), ty);
@@ -1609,12 +1650,15 @@ void ConvertCIRToMLIRPass::runOnOperation() {
16091650 mlir::MLIRContext *context = &getContext ();
16101651 mlir::ModuleOp theModule = getOperation ();
16111652
1612- auto converter = prepareTypeConverter ();
1613-
1653+ mlir::DataLayoutAnalysis layoutAnalysis (theModule);
1654+ const mlir::DataLayout &layout = layoutAnalysis.getAtOrAbove (theModule);
1655+
1656+ auto converter = prepareTypeConverter (layout);
1657+
16141658 mlir::RewritePatternSet patterns (&getContext ());
16151659
16161660 populateCIRLoopToSCFConversionPatterns (patterns, converter);
1617- populateCIRToMLIRConversionPatterns (patterns, converter);
1661+ populateCIRToMLIRConversionPatterns (patterns, converter, layout );
16181662
16191663 mlir::ConversionTarget target (getContext ());
16201664 target.addLegalOp <mlir::ModuleOp>();
@@ -1628,10 +1672,11 @@ void ConvertCIRToMLIRPass::runOnOperation() {
16281672 // cir dialect, for example the `cir.continue`. If we marked cir as illegal
16291673 // here, then MLIR would think any remaining `cir.continue` indicates a
16301674 // failure, which is not what we want.
1631-
1632- patterns.add <CIRCastOpLowering, CIRIfOpLowering, CIRScopeOpLowering, CIRYieldOpLowering>(converter, context);
16331675
1634- if (mlir::failed (mlir::applyPartialConversion (theModule, target,
1676+ patterns.add <CIRCastOpLowering, CIRIfOpLowering, CIRScopeOpLowering,
1677+ CIRYieldOpLowering>(converter, context);
1678+
1679+ if (mlir::failed (mlir::applyPartialConversion (theModule, target,
16351680 std::move (patterns)))) {
16361681 signalPassFailure ();
16371682 }
0 commit comments