1111// ===----------------------------------------------------------------------===//
1212
1313#include  " LowerToMLIRHelpers.h" 
14+ #include  " mlir/Analysis/DataLayoutAnalysis.h" 
1415#include  " mlir/Conversion/AffineToStandard/AffineToStandard.h" 
1516#include  " mlir/Conversion/ArithToLLVM/ArithToLLVM.h" 
1617#include  " mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h" 
3536#include  " mlir/IR/Operation.h" 
3637#include  " mlir/IR/Region.h" 
3738#include  " mlir/IR/TypeRange.h" 
39+ #include  " mlir/IR/Types.h" 
3840#include  " mlir/IR/Value.h" 
3941#include  " mlir/IR/ValueRange.h" 
42+ #include  " mlir/Interfaces/DataLayoutInterfaces.h" 
4043#include  " mlir/Pass/Pass.h" 
4144#include  " mlir/Pass/PassManager.h" 
4245#include  " mlir/Support/LLVM.h" 
4851#include  " mlir/Transforms/DialectConversion.h" 
4952#include  " clang/CIR/Dialect/IR/CIRDialect.h" 
5053#include  " clang/CIR/Dialect/IR/CIRTypes.h" 
54+ #include  " clang/CIR/Interfaces/CIRLoopOpInterface.h" 
5155#include  " clang/CIR/LowerToLLVM.h" 
5256#include  " clang/CIR/LowerToMLIR.h" 
5357#include  " clang/CIR/LoweringHelpers.h" 
5458#include  " clang/CIR/Passes.h" 
5559#include  " llvm/ADT/STLExtras.h" 
56- #include  " llvm/Support/ErrorHandling.h" 
57- #include  " clang/CIR/Interfaces/CIRLoopOpInterface.h" 
58- #include  " clang/CIR/LowerToLLVM.h" 
59- #include  " clang/CIR/Passes.h" 
6060#include  " llvm/ADT/Sequence.h" 
6161#include  " llvm/ADT/SmallVector.h" 
6262#include  " llvm/ADT/TypeSwitch.h" 
6363#include  " llvm/IR/Value.h" 
64+ #include  " llvm/Support/ErrorHandling.h" 
6465#include  " llvm/Support/TimeProfiler.h" 
6566
6667using  namespace  cir ; 
@@ -288,17 +289,17 @@ class CIRAllocaOpLowering : public mlir::OpConversionPattern<cir::AllocaOp> {
288289  matchAndRewrite (cir::AllocaOp op, OpAdaptor adaptor,
289290                  mlir::ConversionPatternRewriter &rewriter) const  override  {
290291
291-     mlir::Type mlirType = 
292-          convertTypeForMemory (*getTypeConverter (), adaptor. getAllocaType () );
292+     mlir::Type allocaType = adaptor. getAllocaType (); 
293+     mlir::Type mlirType =  convertTypeForMemory (*getTypeConverter (), allocaType );
293294
294295    //  FIXME: Some types can not be converted yet (e.g. struct)
295296    if  (!mlirType)
296297      return  mlir::LogicalResult::failure ();
297298
298299    auto  memreftype = mlir::dyn_cast<mlir::MemRefType>(mlirType);
299-     if  (memreftype && mlir::isa<cir::ArrayType>(adaptor. getAllocaType ())) { 
300-       //  if the type is an array, 
301-       //  we don't  need to wrap with memref .
300+     if  (memreftype && ( mlir::isa<cir::ArrayType>(allocaType) || 
301+                        mlir::isa<cir::RecordType>(allocaType))) { 
302+       //  Arrays and structs are already memref. No  need to wrap another one .
302303    } else  {
303304      memreftype = mlir::MemRefType::get ({}, mlirType);
304305    }
@@ -946,8 +947,8 @@ class CIRScopeOpLowering : public mlir::OpConversionPattern<cir::ScopeOp> {
946947    } else  {
947948      //  For scopes with results, use scf.execute_region
948949      SmallVector<mlir::Type> types;
949-       if  (mlir::failed (
950-               getTypeConverter ()-> convertTypes ( scopeOp->getResultTypes (), types)))
950+       if  (mlir::failed (getTypeConverter ()-> convertTypes ( 
951+               scopeOp->getResultTypes (), types)))
951952        return  mlir::failure ();
952953      auto  exec =
953954          rewriter.create <mlir::scf::ExecuteRegionOp>(scopeOp.getLoc (), types);
@@ -1485,6 +1486,51 @@ class CIRPtrStrideOpLowering
14851486  }
14861487};
14871488
1489+ class  CIRGetMemberOpLowering 
1490+     : public mlir::OpConversionPattern<cir::GetMemberOp> {
1491+ public: 
1492+   CIRGetMemberOpLowering (mlir::TypeConverter &converter, mlir::MLIRContext *ctx,
1493+                          const  mlir::DataLayout &layout)
1494+       : OpConversionPattern(converter, ctx), layout(layout) {}
1495+ 
1496+   mlir::LogicalResult
1497+   matchAndRewrite (cir::GetMemberOp op, OpAdaptor adaptor,
1498+                   mlir::ConversionPatternRewriter &rewriter) const  override  {
1499+ 
1500+     cir::PointerType ptrType = op.getAddr ().getType ();
1501+ 
1502+     auto  structType = mlir::dyn_cast<cir::RecordType>(ptrType.getPointee ());
1503+     if  (!structType) {
1504+       return  rewriter.notifyMatchFailure (
1505+           op, " expected RecordType as pointee of GetMemberOp base"  );
1506+     }
1507+ 
1508+     uint64_t  byteOffset = structType.getElementOffset (layout, op.getIndex ());
1509+     cir::PointerType fieldType = op.getResult ().getType ();
1510+     auto  convertedType = getTypeConverter ()->convertType (fieldType);
1511+     if  (!convertedType) {
1512+       return  rewriter.notifyMatchFailure (op, " failed to convert field type"  );
1513+     }
1514+ 
1515+     auto  resultType = mlir::dyn_cast<mlir::MemRefType>(convertedType);
1516+     if  (!resultType) {
1517+       return  rewriter.notifyMatchFailure (
1518+           op, " expected MemRefType after type conversion"  );
1519+     }
1520+ 
1521+     mlir::Value offsetValue =
1522+         rewriter.create <mlir::arith::ConstantIndexOp>(op.getLoc (), byteOffset);
1523+ 
1524+     rewriter.replaceOpWithNewOp <mlir::memref::ViewOp>(
1525+         op, resultType, adaptor.getAddr (), offsetValue, mlir::ValueRange{});
1526+ 
1527+     return  mlir::success ();
1528+   }
1529+ 
1530+ private: 
1531+   const  mlir::DataLayout &layout;
1532+ };
1533+ 
14881534class  CIRUnreachableOpLowering 
14891535    : public mlir::OpConversionPattern<cir::UnreachableOp> {
14901536public: 
@@ -1516,37 +1562,41 @@ class CIRTrapOpLowering : public mlir::OpConversionPattern<cir::TrapOp> {
15161562};
15171563
15181564void  populateCIRToMLIRConversionPatterns (mlir::RewritePatternSet &patterns,
1519-                                          mlir::TypeConverter &converter) {
1565+                                          mlir::TypeConverter &converter,
1566+                                          mlir::DataLayout layout) {
15201567  patterns.add <CIRReturnLowering, CIRBrOpLowering>(patterns.getContext ());
15211568
1522-   patterns
1523-       .add <CIRATanOpLowering, CIRCmpOpLowering, CIRCallOpLowering,
1524-            CIRUnaryOpLowering, CIRBinOpLowering, CIRLoadOpLowering,
1525-            CIRConstantOpLowering, CIRStoreOpLowering, CIRAllocaOpLowering,
1526-            CIRFuncOpLowering, CIRBrCondOpLowering,
1527-            CIRTernaryOpLowering, CIRYieldOpLowering, CIRCosOpLowering,
1528-            CIRGlobalOpLowering, CIRGetGlobalOpLowering, CIRCastOpLowering,
1529-            CIRPtrStrideOpLowering, CIRGetElementOpLowering, CIRSqrtOpLowering,
1530-            CIRCeilOpLowering, CIRExp2OpLowering, CIRExpOpLowering,
1531-            CIRFAbsOpLowering, CIRAbsOpLowering, CIRFloorOpLowering,
1532-            CIRLog10OpLowering, CIRLog2OpLowering, CIRLogOpLowering,
1533-            CIRRoundOpLowering, CIRSinOpLowering, CIRShiftOpLowering,
1534-            CIRBitClzOpLowering, CIRBitCtzOpLowering, CIRBitPopcountOpLowering,
1535-            CIRBitClrsbOpLowering, CIRBitFfsOpLowering, CIRBitParityOpLowering,
1536-            CIRIfOpLowering, CIRVectorCreateLowering, CIRVectorInsertLowering,
1537-            CIRVectorExtractLowering, CIRVectorCmpOpLowering, CIRACosOpLowering,
1538-            CIRASinOpLowering, CIRUnreachableOpLowering, CIRTanOpLowering,
1539-            CIRTrapOpLowering>(converter, patterns.getContext ());
1569+   patterns.add <
1570+       CIRATanOpLowering, CIRCmpOpLowering, CIRCallOpLowering,
1571+       CIRUnaryOpLowering, CIRBinOpLowering, CIRLoadOpLowering,
1572+       CIRConstantOpLowering, CIRStoreOpLowering, CIRAllocaOpLowering,
1573+       CIRFuncOpLowering, CIRBrCondOpLowering, CIRTernaryOpLowering,
1574+       CIRYieldOpLowering, CIRCosOpLowering, CIRGlobalOpLowering,
1575+       CIRGetGlobalOpLowering, CIRCastOpLowering, CIRPtrStrideOpLowering,
1576+       CIRGetElementOpLowering, CIRSqrtOpLowering, CIRCeilOpLowering,
1577+       CIRExp2OpLowering, CIRExpOpLowering, CIRFAbsOpLowering, CIRAbsOpLowering,
1578+       CIRFloorOpLowering, CIRLog10OpLowering, CIRLog2OpLowering,
1579+       CIRLogOpLowering, CIRRoundOpLowering, CIRSinOpLowering,
1580+       CIRShiftOpLowering, CIRBitClzOpLowering, CIRBitCtzOpLowering,
1581+       CIRBitPopcountOpLowering, CIRBitClrsbOpLowering, CIRBitFfsOpLowering,
1582+       CIRBitParityOpLowering, CIRIfOpLowering, CIRVectorCreateLowering,
1583+       CIRVectorInsertLowering, CIRVectorExtractLowering, CIRVectorCmpOpLowering,
1584+       CIRACosOpLowering, CIRASinOpLowering, CIRUnreachableOpLowering,
1585+       CIRTanOpLowering, CIRTrapOpLowering>(converter, patterns.getContext ());
1586+ 
1587+   patterns.add <CIRGetMemberOpLowering>(converter, patterns.getContext (),
1588+                                        layout);
15401589}
15411590
1542- static  mlir::TypeConverter prepareTypeConverter () {
1591+ static  mlir::TypeConverter prepareTypeConverter (mlir::DataLayout layout ) {
15431592  mlir::TypeConverter converter;
15441593  converter.addConversion ([&](cir::PointerType type) -> mlir::Type {
1545-     auto  ty = convertTypeForMemory (converter, type.getPointee ());
1594+     auto  pointee = type.getPointee ();
1595+     auto  ty = convertTypeForMemory (converter, pointee);
15461596    //  FIXME: The pointee type might not be converted (e.g. struct)
15471597    if  (!ty)
15481598      return  nullptr ;
1549-     if  (isa<cir::ArrayType>(type. getPointee () ))
1599+     if  (isa<cir::ArrayType>(pointee) || isa<cir::RecordType>(pointee ))
15501600      return  ty;
15511601    return  mlir::MemRefType::get ({}, ty);
15521602  });
@@ -1598,6 +1648,13 @@ static mlir::TypeConverter prepareTypeConverter() {
15981648      return  nullptr ;
15991649    return  mlir::MemRefType::get (shape, elementType);
16001650  });
1651+   converter.addConversion ([&](cir::RecordType type) -> mlir::Type {
1652+     //  Reinterpret structs as raw bytes. Don't use tuples as they can't be put
1653+     //  in memref.
1654+     auto  size = type.getTypeSize (layout, {});
1655+     auto  i8  = mlir::IntegerType::get (type.getContext (), /* width=*/ 8 );
1656+     return  mlir::MemRefType::get (size.getFixedValue (), i8 );
1657+   });
16011658  converter.addConversion ([&](cir::VectorType type) -> mlir::Type {
16021659    auto  ty = converter.convertType (type.getElementType ());
16031660    return  mlir::VectorType::get (type.getSize (), ty);
@@ -1609,12 +1666,15 @@ void ConvertCIRToMLIRPass::runOnOperation() {
16091666  mlir::MLIRContext *context = &getContext ();
16101667  mlir::ModuleOp theModule = getOperation ();
16111668
1612-   auto  converter = prepareTypeConverter ();
1613-   
1669+   mlir::DataLayoutAnalysis layoutAnalysis (theModule);
1670+   const  mlir::DataLayout &layout = layoutAnalysis.getAtOrAbove (theModule);
1671+ 
1672+   auto  converter = prepareTypeConverter (layout);
1673+ 
16141674  mlir::RewritePatternSet patterns (&getContext ());
16151675
16161676  populateCIRLoopToSCFConversionPatterns (patterns, converter);
1617-   populateCIRToMLIRConversionPatterns (patterns, converter);
1677+   populateCIRToMLIRConversionPatterns (patterns, converter, layout );
16181678
16191679  mlir::ConversionTarget target (getContext ());
16201680  target.addLegalOp <mlir::ModuleOp>();
@@ -1628,10 +1688,11 @@ void ConvertCIRToMLIRPass::runOnOperation() {
16281688  //  cir dialect, for example the `cir.continue`. If we marked cir as illegal
16291689  //  here, then MLIR would think any remaining `cir.continue` indicates a
16301690  //  failure, which is not what we want.
1631-   
1632-   patterns.add <CIRCastOpLowering, CIRIfOpLowering, CIRScopeOpLowering, CIRYieldOpLowering>(converter, context);
16331691
1634-   if  (mlir::failed (mlir::applyPartialConversion (theModule, target, 
1692+   patterns.add <CIRCastOpLowering, CIRIfOpLowering, CIRScopeOpLowering,
1693+                CIRYieldOpLowering>(converter, context);
1694+ 
1695+   if  (mlir::failed (mlir::applyPartialConversion (theModule, target,
16351696                                                std::move (patterns)))) {
16361697    signalPassFailure ();
16371698  }
0 commit comments