Skip to content

Commit 4184272

Browse files
committed
[SCEV] Check if predicate is known false for predicated AddRecs.
Similarly to #131538, we can also try to check whether a predicate is known to wrap given the backedge-taken count. For now, this only checks directly at the point where we try to create predicated AddRecs. This both helps to avoid spending compile time on optimizations where we know the predicate is false, and can also help to allow additional vectorization (e.g. by deciding to scalarize memory accesses when we would otherwise try to create a predicated AddRec with a predicate that is always false). The initial version is quite restricted, but can be extended in follow-ups to cover more cases.
1 parent 1a3e857 commit 4184272

File tree

2 files changed

+111
-20
lines changed

2 files changed

+111
-20
lines changed

llvm/lib/Analysis/ScalarEvolution.cpp

Lines changed: 57 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2300,9 +2300,16 @@ CollectAddOperandsWithScales(SmallDenseMap<const SCEV *, APInt, 16> &M,
23002300
return Interesting;
23012301
}
23022302

2303-
bool ScalarEvolution::willNotOverflow(Instruction::BinaryOps BinOp, bool Signed,
2304-
const SCEV *LHS, const SCEV *RHS,
2305-
const Instruction *CtxI) {
2303+
namespace {
2304+
enum class OverflowCheckTy { WillOverflow, WillNotOverflow };
2305+
}
2306+
2307+
// Return true if (LHS BinOp RHS) is guaranteed to overflow (if \p Check is
2308+
// WillOverflow) or to not overflow (if \p Check is WillNotOverflow).
2309+
static bool checkOverflow(OverflowCheckTy Check, ScalarEvolution *SE,
2310+
Instruction::BinaryOps BinOp, bool Signed,
2311+
const SCEV *LHS, const SCEV *RHS,
2312+
const Instruction *CtxI) {
23062313
const SCEV *(ScalarEvolution::*Operation)(const SCEV *, const SCEV *,
23072314
SCEV::NoWrapFlags, unsigned);
23082315
switch (BinOp) {
@@ -2328,12 +2335,12 @@ bool ScalarEvolution::willNotOverflow(Instruction::BinaryOps BinOp, bool Signed,
23282335
auto *WideTy =
23292336
IntegerType::get(NarrowTy->getContext(), NarrowTy->getBitWidth() * 2);
23302337

2331-
const SCEV *A = (this->*Extension)(
2332-
(this->*Operation)(LHS, RHS, SCEV::FlagAnyWrap, 0), WideTy, 0);
2333-
const SCEV *LHSB = (this->*Extension)(LHS, WideTy, 0);
2334-
const SCEV *RHSB = (this->*Extension)(RHS, WideTy, 0);
2335-
const SCEV *B = (this->*Operation)(LHSB, RHSB, SCEV::FlagAnyWrap, 0);
2336-
if (A == B)
2338+
const SCEV *A = (SE->*Extension)(
2339+
(SE->*Operation)(LHS, RHS, SCEV::FlagAnyWrap, 0), WideTy, 0);
2340+
const SCEV *LHSB = (SE->*Extension)(LHS, WideTy, 0);
2341+
const SCEV *RHSB = (SE->*Extension)(RHS, WideTy, 0);
2342+
const SCEV *B = (SE->*Operation)(LHSB, RHSB, SCEV::FlagAnyWrap, 0);
2343+
if (Check == OverflowCheckTy::WillNotOverflow && A == B)
23372344
return true;
23382345
// Can we use context to prove the fact we need?
23392346
if (!CtxI)
@@ -2361,21 +2368,31 @@ bool ScalarEvolution::willNotOverflow(Instruction::BinaryOps BinOp, bool Signed,
23612368
}
23622369

23632370
ICmpInst::Predicate Pred = Signed ? ICmpInst::ICMP_SLE : ICmpInst::ICMP_ULE;
2371+
if (Check == OverflowCheckTy::WillOverflow)
2372+
Pred = CmpInst::getInversePredicate(Pred);
2373+
23642374
if (OverflowDown) {
23652375
// To avoid overflow down, we need to make sure that MIN + Magnitude <= LHS.
23662376
APInt Min = Signed ? APInt::getSignedMinValue(NumBits)
23672377
: APInt::getMinValue(NumBits);
23682378
APInt Limit = Min + Magnitude;
2369-
return isKnownPredicateAt(Pred, getConstant(Limit), LHS, CtxI);
2379+
return SE->isKnownPredicateAt(Pred, SE->getConstant(Limit), LHS, CtxI);
23702380
} else {
23712381
// To avoid overflow up, we need to make sure that LHS <= MAX - Magnitude.
23722382
APInt Max = Signed ? APInt::getSignedMaxValue(NumBits)
23732383
: APInt::getMaxValue(NumBits);
23742384
APInt Limit = Max - Magnitude;
2375-
return isKnownPredicateAt(Pred, LHS, getConstant(Limit), CtxI);
2385+
return SE->isKnownPredicateAt(Pred, LHS, SE->getConstant(Limit), CtxI);
23762386
}
23772387
}
23782388

2389+
bool ScalarEvolution::willNotOverflow(Instruction::BinaryOps BinOp, bool Signed,
2390+
const SCEV *LHS, const SCEV *RHS,
2391+
const Instruction *CtxI) {
2392+
return checkOverflow(OverflowCheckTy::WillNotOverflow, this, BinOp, Signed,
2393+
LHS, RHS, CtxI);
2394+
}
2395+
23792396
std::optional<SCEV::NoWrapFlags>
23802397
ScalarEvolution::getStrengthenedNoWrapFlagsFromBinOp(
23812398
const OverflowingBinaryOperator *OBO) {
@@ -14930,6 +14947,35 @@ const SCEVAddRecExpr *ScalarEvolution::convertSCEVToAddRecWithPredicates(
1493014947
if (!AddRec)
1493114948
return nullptr;
1493214949

14950+
// Check if any of the transformed predicates is known to be false. In that
14951+
// case, it doesn't make sense to convert to a predicated AddRec, as the
14952+
// versioned loop will never execute.
14953+
for (const SCEVPredicate *Pred : TransformPreds) {
14954+
auto *WrapPred = dyn_cast<SCEVWrapPredicate>(Pred);
14955+
if (!WrapPred || WrapPred->getFlags() != SCEVWrapPredicate::IncrementNSSW)
14956+
continue;
14957+
14958+
const SCEVAddRecExpr *AddRecToCheck = WrapPred->getExpr();
14959+
const SCEV *ExitCount = getBackedgeTakenCount(AddRec->getLoop());
14960+
if (isa<SCEVCouldNotCompute>(ExitCount))
14961+
continue;
14962+
14963+
const SCEV *Step = AddRecToCheck->getStepRecurrence(*this);
14964+
if (!Step->isOne())
14965+
continue;
14966+
14967+
ExitCount = getTruncateOrSignExtend(ExitCount, Step->getType());
14968+
if (checkOverflow(OverflowCheckTy::WillOverflow, this, Instruction::Add,
14969+
/*CheckSigned=*/true, AddRecToCheck->getStart(),
14970+
ExitCount, nullptr)) {
14971+
return nullptr;
14972+
}
14973+
const SCEV *Add = getAddExpr(AddRecToCheck->getStart(), ExitCount);
14974+
if (isKnownPredicate(CmpInst::ICMP_SLT, Add, AddRecToCheck->getStart())) {
14975+
return nullptr;
14976+
}
14977+
}
14978+
1493314979
// Since the transformation was successful, we can now transfer the SCEV
1493414980
// predicates.
1493514981
Preds.append(TransformPreds.begin(), TransformPreds.end());

llvm/test/Transforms/LoopVectorize/first-order-recurrence-dead-instructions.ll

Lines changed: 54 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6,16 +6,61 @@ define i8 @recurrence_phi_with_same_incoming_values_after_simplifications(i8 %fo
66
; CHECK-LABEL: define i8 @recurrence_phi_with_same_incoming_values_after_simplifications(
77
; CHECK-SAME: i8 [[FOR_START:%.*]], ptr [[DST:%.*]]) {
88
; CHECK-NEXT: [[ENTRY:.*]]:
9+
; CHECK-NEXT: br i1 false, label %[[SCALAR_PH:.*]], label %[[VECTOR_PH:.*]]
10+
; CHECK: [[VECTOR_PH]]:
11+
; CHECK-NEXT: [[BROADCAST_SPLATINSERT:%.*]] = insertelement <4 x i8> poison, i8 [[FOR_START]], i64 0
12+
; CHECK-NEXT: [[BROADCAST_SPLAT:%.*]] = shufflevector <4 x i8> [[BROADCAST_SPLATINSERT]], <4 x i8> poison, <4 x i32> zeroinitializer
13+
; CHECK-NEXT: [[TMP0:%.*]] = shufflevector <4 x i8> [[BROADCAST_SPLAT]], <4 x i8> [[BROADCAST_SPLAT]], <4 x i32> <i32 3, i32 4, i32 5, i32 6>
14+
; CHECK-NEXT: br label %[[VECTOR_BODY:.*]]
15+
; CHECK: [[VECTOR_BODY]]:
16+
; CHECK-NEXT: [[INDEX:%.*]] = phi i32 [ 0, %[[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], %[[VECTOR_BODY]] ]
17+
; CHECK-NEXT: [[OFFSET_IDX:%.*]] = add i32 1, [[INDEX]]
18+
; CHECK-NEXT: [[TMP1:%.*]] = add i32 [[OFFSET_IDX]], 0
19+
; CHECK-NEXT: [[TMP2:%.*]] = add i32 [[OFFSET_IDX]], 1
20+
; CHECK-NEXT: [[TMP3:%.*]] = add i32 [[OFFSET_IDX]], 2
21+
; CHECK-NEXT: [[TMP4:%.*]] = add i32 [[OFFSET_IDX]], 3
22+
; CHECK-NEXT: [[TMP5:%.*]] = add i32 [[OFFSET_IDX]], 4
23+
; CHECK-NEXT: [[TMP6:%.*]] = add i32 [[OFFSET_IDX]], 5
24+
; CHECK-NEXT: [[TMP7:%.*]] = add i32 [[OFFSET_IDX]], 6
25+
; CHECK-NEXT: [[TMP8:%.*]] = add i32 [[OFFSET_IDX]], 7
26+
; CHECK-NEXT: [[TMP9:%.*]] = getelementptr inbounds i8, ptr [[DST]], i32 [[TMP1]]
27+
; CHECK-NEXT: [[TMP10:%.*]] = getelementptr inbounds i8, ptr [[DST]], i32 [[TMP2]]
28+
; CHECK-NEXT: [[TMP11:%.*]] = getelementptr inbounds i8, ptr [[DST]], i32 [[TMP3]]
29+
; CHECK-NEXT: [[TMP12:%.*]] = getelementptr inbounds i8, ptr [[DST]], i32 [[TMP4]]
30+
; CHECK-NEXT: [[TMP13:%.*]] = getelementptr inbounds i8, ptr [[DST]], i32 [[TMP5]]
31+
; CHECK-NEXT: [[TMP14:%.*]] = getelementptr inbounds i8, ptr [[DST]], i32 [[TMP6]]
32+
; CHECK-NEXT: [[TMP15:%.*]] = getelementptr inbounds i8, ptr [[DST]], i32 [[TMP7]]
33+
; CHECK-NEXT: [[TMP16:%.*]] = getelementptr inbounds i8, ptr [[DST]], i32 [[TMP8]]
34+
; CHECK-NEXT: [[TMP17:%.*]] = extractelement <4 x i8> [[TMP0]], i32 0
35+
; CHECK-NEXT: store i8 [[TMP17]], ptr [[TMP9]], align 1
36+
; CHECK-NEXT: [[TMP18:%.*]] = extractelement <4 x i8> [[TMP0]], i32 1
37+
; CHECK-NEXT: store i8 [[TMP18]], ptr [[TMP10]], align 1
38+
; CHECK-NEXT: [[TMP19:%.*]] = extractelement <4 x i8> [[TMP0]], i32 2
39+
; CHECK-NEXT: store i8 [[TMP19]], ptr [[TMP11]], align 1
40+
; CHECK-NEXT: [[TMP20:%.*]] = extractelement <4 x i8> [[TMP0]], i32 3
41+
; CHECK-NEXT: store i8 [[TMP20]], ptr [[TMP12]], align 1
42+
; CHECK-NEXT: store i8 [[TMP17]], ptr [[TMP13]], align 1
43+
; CHECK-NEXT: store i8 [[TMP18]], ptr [[TMP14]], align 1
44+
; CHECK-NEXT: store i8 [[TMP19]], ptr [[TMP15]], align 1
45+
; CHECK-NEXT: store i8 [[TMP20]], ptr [[TMP16]], align 1
46+
; CHECK-NEXT: [[INDEX_NEXT]] = add nuw i32 [[INDEX]], 8
47+
; CHECK-NEXT: [[TMP21:%.*]] = icmp eq i32 [[INDEX_NEXT]], -8
48+
; CHECK-NEXT: br i1 [[TMP21]], label %[[MIDDLE_BLOCK:.*]], label %[[VECTOR_BODY]], !llvm.loop [[LOOP0:![0-9]+]]
49+
; CHECK: [[MIDDLE_BLOCK]]:
50+
; CHECK-NEXT: br label %[[SCALAR_PH]]
51+
; CHECK: [[SCALAR_PH]]:
52+
; CHECK-NEXT: [[BC_RESUME_VAL:%.*]] = phi i32 [ -7, %[[MIDDLE_BLOCK]] ], [ 1, %[[ENTRY]] ]
53+
; CHECK-NEXT: [[SCALAR_RECUR_INIT:%.*]] = phi i8 [ [[FOR_START]], %[[MIDDLE_BLOCK]] ], [ [[FOR_START]], %[[ENTRY]] ]
954
; CHECK-NEXT: br label %[[LOOP:.*]]
1055
; CHECK: [[LOOP]]:
11-
; CHECK-NEXT: [[IV:%.*]] = phi i32 [ 1, %[[ENTRY]] ], [ [[IV_NEXT:%.*]], %[[LOOP]] ]
12-
; CHECK-NEXT: [[FOR:%.*]] = phi i8 [ [[FOR_START]], %[[ENTRY]] ], [ [[FOR_NEXT:%.*]], %[[LOOP]] ]
56+
; CHECK-NEXT: [[IV:%.*]] = phi i32 [ [[BC_RESUME_VAL]], %[[SCALAR_PH]] ], [ [[IV_NEXT:%.*]], %[[LOOP]] ]
57+
; CHECK-NEXT: [[FOR:%.*]] = phi i8 [ [[SCALAR_RECUR_INIT]], %[[SCALAR_PH]] ], [ [[FOR_NEXT:%.*]], %[[LOOP]] ]
1358
; CHECK-NEXT: [[FOR_NEXT]] = and i8 [[FOR_START]], -1
1459
; CHECK-NEXT: [[IV_NEXT]] = add i32 [[IV]], 1
1560
; CHECK-NEXT: [[GEP_DST:%.*]] = getelementptr inbounds i8, ptr [[DST]], i32 [[IV]]
1661
; CHECK-NEXT: store i8 [[FOR]], ptr [[GEP_DST]], align 1
1762
; CHECK-NEXT: [[EC:%.*]] = icmp eq i32 [[IV_NEXT]], 0
18-
; CHECK-NEXT: br i1 [[EC]], label %[[EXIT:.*]], label %[[LOOP]]
63+
; CHECK-NEXT: br i1 [[EC]], label %[[EXIT:.*]], label %[[LOOP]], !llvm.loop [[LOOP3:![0-9]+]]
1964
; CHECK: [[EXIT]]:
2065
; CHECK-NEXT: [[FOR_NEXT_LCSSA:%.*]] = phi i8 [ [[FOR_NEXT]], %[[LOOP]] ]
2166
; CHECK-NEXT: ret i8 [[FOR_NEXT_LCSSA]]
@@ -61,7 +106,7 @@ define i32 @sink_after_dead_inst(ptr %A.ptr) {
61106
; CHECK-NEXT: [[INDEX_NEXT]] = add nuw i32 [[INDEX]], 8
62107
; CHECK-NEXT: [[VEC_IND_NEXT]] = add <4 x i16> [[STEP_ADD]], splat (i16 4)
63108
; CHECK-NEXT: [[TMP6:%.*]] = icmp eq i32 [[INDEX_NEXT]], 16
64-
; CHECK-NEXT: br i1 [[TMP6]], label %[[MIDDLE_BLOCK:.*]], label %[[VECTOR_BODY]], !llvm.loop [[LOOP0:![0-9]+]]
109+
; CHECK-NEXT: br i1 [[TMP6]], label %[[MIDDLE_BLOCK:.*]], label %[[VECTOR_BODY]], !llvm.loop [[LOOP4:![0-9]+]]
65110
; CHECK: [[MIDDLE_BLOCK]]:
66111
; CHECK-NEXT: [[VECTOR_RECUR_EXTRACT_FOR_PHI:%.*]] = extractelement <4 x i32> [[TMP2]], i32 2
67112
; CHECK-NEXT: br label %[[FOR_END:.*]]
@@ -82,7 +127,7 @@ define i32 @sink_after_dead_inst(ptr %A.ptr) {
82127
; CHECK-NEXT: [[EXT:%.*]] = zext i1 [[B3]] to i32
83128
; CHECK-NEXT: [[A_GEP:%.*]] = getelementptr i32, ptr [[A_PTR]], i16 [[IV]]
84129
; CHECK-NEXT: store i32 0, ptr [[A_GEP]], align 4
85-
; CHECK-NEXT: br i1 [[VEC_DEAD]], label %[[FOR_END]], label %[[LOOP]], !llvm.loop [[LOOP3:![0-9]+]]
130+
; CHECK-NEXT: br i1 [[VEC_DEAD]], label %[[FOR_END]], label %[[LOOP]], !llvm.loop [[LOOP5:![0-9]+]]
86131
; CHECK: [[FOR_END]]:
87132
; CHECK-NEXT: [[FOR_LCSSA:%.*]] = phi i32 [ [[FOR]], %[[LOOP]] ], [ [[VECTOR_RECUR_EXTRACT_FOR_PHI]], %[[MIDDLE_BLOCK]] ]
88133
; CHECK-NEXT: ret i32 [[FOR_LCSSA]]
@@ -142,7 +187,7 @@ define void @sink_dead_inst(ptr %a) {
142187
; CHECK-NEXT: [[INDEX_NEXT]] = add nuw i32 [[INDEX]], 8
143188
; CHECK-NEXT: [[VEC_IND_NEXT]] = add <4 x i16> [[STEP_ADD]], splat (i16 4)
144189
; CHECK-NEXT: [[TMP12:%.*]] = icmp eq i32 [[INDEX_NEXT]], 40
145-
; CHECK-NEXT: br i1 [[TMP12]], label %[[MIDDLE_BLOCK:.*]], label %[[VECTOR_BODY]], !llvm.loop [[LOOP4:![0-9]+]]
190+
; CHECK-NEXT: br i1 [[TMP12]], label %[[MIDDLE_BLOCK:.*]], label %[[VECTOR_BODY]], !llvm.loop [[LOOP6:![0-9]+]]
146191
; CHECK: [[MIDDLE_BLOCK]]:
147192
; CHECK-NEXT: [[VECTOR_RECUR_EXTRACT:%.*]] = extractelement <4 x i16> [[TMP4]], i32 3
148193
; CHECK-NEXT: [[VECTOR_RECUR_EXTRACT1:%.*]] = extractelement <4 x i32> [[TMP2]], i32 3
@@ -163,7 +208,7 @@ define void @sink_dead_inst(ptr %a) {
163208
; CHECK-NEXT: [[REC_1_PREV]] = add i16 [[IV_NEXT]], 5
164209
; CHECK-NEXT: [[GEP:%.*]] = getelementptr i16, ptr [[A]], i16 [[IV]]
165210
; CHECK-NEXT: store i16 [[USE_REC_1]], ptr [[GEP]], align 2
166-
; CHECK-NEXT: br i1 [[CMP]], label %[[FOR_END:.*]], label %[[FOR_COND]], !llvm.loop [[LOOP5:![0-9]+]]
211+
; CHECK-NEXT: br i1 [[CMP]], label %[[FOR_END:.*]], label %[[FOR_COND]], !llvm.loop [[LOOP7:![0-9]+]]
167212
; CHECK: [[FOR_END]]:
168213
; CHECK-NEXT: ret void
169214
;
@@ -205,7 +250,7 @@ define void @unused_recurrence(ptr %a) {
205250
; CHECK-NEXT: [[INDEX_NEXT]] = add nuw i32 [[INDEX]], 8
206251
; CHECK-NEXT: [[VEC_IND_NEXT]] = add <4 x i16> [[STEP_ADD]], splat (i16 4)
207252
; CHECK-NEXT: [[TMP2:%.*]] = icmp eq i32 [[INDEX_NEXT]], 1024
208-
; CHECK-NEXT: br i1 [[TMP2]], label %[[MIDDLE_BLOCK:.*]], label %[[VECTOR_BODY]], !llvm.loop [[LOOP6:![0-9]+]]
253+
; CHECK-NEXT: br i1 [[TMP2]], label %[[MIDDLE_BLOCK:.*]], label %[[VECTOR_BODY]], !llvm.loop [[LOOP8:![0-9]+]]
209254
; CHECK: [[MIDDLE_BLOCK]]:
210255
; CHECK-NEXT: [[VECTOR_RECUR_EXTRACT:%.*]] = extractelement <4 x i16> [[TMP1]], i32 3
211256
; CHECK-NEXT: br label %[[SCALAR_PH]]
@@ -220,7 +265,7 @@ define void @unused_recurrence(ptr %a) {
220265
; CHECK-NEXT: [[IV_NEXT]] = add i16 [[IV]], 1
221266
; CHECK-NEXT: [[REC_1_PREV]] = add i16 [[IV_NEXT]], 5
222267
; CHECK-NEXT: [[CMP:%.*]] = icmp eq i16 [[IV]], 1000
223-
; CHECK-NEXT: br i1 [[CMP]], label %[[FOR_END:.*]], label %[[FOR_COND]], !llvm.loop [[LOOP7:![0-9]+]]
268+
; CHECK-NEXT: br i1 [[CMP]], label %[[FOR_END:.*]], label %[[FOR_COND]], !llvm.loop [[LOOP9:![0-9]+]]
224269
; CHECK: [[FOR_END]]:
225270
; CHECK-NEXT: ret void
226271
;

0 commit comments

Comments
 (0)