@@ -528,7 +528,7 @@ class CIRWhileOpLowering : public mlir::OpConversionPattern<cir::WhileOp> {
528528 for (auto continueOp : continues) {
529529 bool nested = false ;
530530 // When there is another loop between this WhileOp and the ContinueOp,
531- // we shouldn't change that loop instead.
531+ // we should change that loop instead.
532532 for (mlir::Operation *parent = continueOp->getParentOp ();
533533 parent != whileOp; parent = parent->getParentOp ()) {
534534 if (isa<WhileOp>(parent)) {
@@ -570,6 +570,81 @@ class CIRWhileOpLowering : public mlir::OpConversionPattern<cir::WhileOp> {
570570 }
571571 }
572572
573+ void rewriteBreak (mlir::scf::WhileOp whileOp,
574+ mlir::ConversionPatternRewriter &rewriter) const {
575+ // Collect all BreakOp inside this while.
576+ llvm::SmallVector<cir::BreakOp> breaks;
577+ whileOp->walk ([&](mlir::Operation *op) {
578+ if (auto breakOp = dyn_cast<BreakOp>(op))
579+ breaks.push_back (breakOp);
580+ });
581+
582+ if (breaks.empty ())
583+ return ;
584+
585+ for (auto breakOp : breaks) {
586+ bool nested = false ;
587+ // When there is another loop between this WhileOp and the BreakOp,
588+ // we should change that loop instead.
589+ for (mlir::Operation *parent = breakOp->getParentOp (); parent != whileOp;
590+ parent = parent->getParentOp ()) {
591+ if (isa<WhileOp>(parent)) {
592+ nested = true ;
593+ break ;
594+ }
595+ }
596+ if (nested)
597+ continue ;
598+
599+ // Similar to the case of ContinueOp, when there is an `IfOp`,
600+ // we need to take special care.
601+ for (mlir::Operation *parent = breakOp->getParentOp (); parent != whileOp;
602+ parent = parent->getParentOp ()) {
603+ if (auto ifOp = dyn_cast<cir::IfOp>(parent))
604+ llvm_unreachable (" NYI" );
605+ }
606+
607+ // Operations after this BreakOp has to be removed.
608+ for (mlir::Operation *runner = breakOp->getNextNode (); runner;) {
609+ mlir::Operation *next = runner->getNextNode ();
610+ runner->erase ();
611+ runner = next;
612+ }
613+
614+ // Blocks after this BreakOp also has to be removed.
615+ for (mlir::Block *block = breakOp->getBlock ()->getNextNode (); block;) {
616+ mlir::Block *next = block->getNextNode ();
617+ block->erase ();
618+ block = next;
619+ }
620+
621+ // We know this BreakOp isn't nested in any IfOp.
622+ // Therefore, the loop is executed only once.
623+ // We pull everything out of the loop.
624+
625+ auto &beforeOps = whileOp.getBeforeBody ()->getOperations ();
626+ for (mlir::Operation *op = &*beforeOps.begin (); op;) {
627+ if (isa<ConditionOp>(op))
628+ break ;
629+ auto *next = op->getNextNode ();
630+ op->moveBefore (whileOp);
631+ op = next;
632+ }
633+
634+ auto &afterOps = whileOp.getAfterBody ()->getOperations ();
635+ for (mlir::Operation *op = &*afterOps.begin (); op;) {
636+ if (isa<YieldOp>(op))
637+ break ;
638+ auto *next = op->getNextNode ();
639+ op->moveBefore (whileOp);
640+ op = next;
641+ }
642+
643+ // The loop itself should now be removed.
644+ rewriter.eraseOp (whileOp);
645+ }
646+ }
647+
573648public:
574649 using OpConversionPattern<cir::WhileOp>::OpConversionPattern;
575650
@@ -579,6 +654,7 @@ class CIRWhileOpLowering : public mlir::OpConversionPattern<cir::WhileOp> {
579654 SCFWhileLoop loop (op, adaptor, &rewriter);
580655 auto whileOp = loop.transferToSCFWhileOp ();
581656 rewriteContinue (whileOp, rewriter);
657+ rewriteBreak (whileOp, rewriter);
582658 rewriter.eraseOp (op);
583659 return mlir::success ();
584660 }
0 commit comments