Skip to content

Commit 5c3445c

Browse files
committed
[SCEV] Check if predicate is known false for predicated AddRecs.
Similarly to #131538, we can also try and check if a predicate is known to wrap given the backedge taken count. For now, this just checks directly when we try to create predicated AddRecs. This both helps to avoid spending compile-time on optimizations where we know the predicate is false, and can also help to allow additional vectorization (e.g. by deciding to scalarize memory accesses when otherwise we would try to create a predicated AddRec with a predicate that's always false). The initial version is quite restricted, but can be extended in follow-ups to cover more cases.
1 parent 1a3e857 commit 5c3445c

File tree

2 files changed

+114
-20
lines changed

2 files changed

+114
-20
lines changed

llvm/lib/Analysis/ScalarEvolution.cpp

Lines changed: 60 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2300,9 +2300,15 @@ CollectAddOperandsWithScales(SmallDenseMap<const SCEV *, APInt, 16> &M,
23002300
return Interesting;
23012301
}
23022302

2303-
bool ScalarEvolution::willNotOverflow(Instruction::BinaryOps BinOp, bool Signed,
2304-
const SCEV *LHS, const SCEV *RHS,
2305-
const Instruction *CtxI) {
2303+
namespace {
2304+
enum class OverflowCheckTy { WillOverflow, WillNotOverflow };
2305+
}
2306+
2307+
// Return true if (LHS BinOp RHS) is guaranteed to overflow (
2308+
static bool checkOverflow(OverflowCheckTy Check, ScalarEvolution *SE,
2309+
Instruction::BinaryOps BinOp, bool Signed,
2310+
const SCEV *LHS, const SCEV *RHS,
2311+
const Instruction *CtxI) {
23062312
const SCEV *(ScalarEvolution::*Operation)(const SCEV *, const SCEV *,
23072313
SCEV::NoWrapFlags, unsigned);
23082314
switch (BinOp) {
@@ -2328,12 +2334,12 @@ bool ScalarEvolution::willNotOverflow(Instruction::BinaryOps BinOp, bool Signed,
23282334
auto *WideTy =
23292335
IntegerType::get(NarrowTy->getContext(), NarrowTy->getBitWidth() * 2);
23302336

2331-
const SCEV *A = (this->*Extension)(
2332-
(this->*Operation)(LHS, RHS, SCEV::FlagAnyWrap, 0), WideTy, 0);
2333-
const SCEV *LHSB = (this->*Extension)(LHS, WideTy, 0);
2334-
const SCEV *RHSB = (this->*Extension)(RHS, WideTy, 0);
2335-
const SCEV *B = (this->*Operation)(LHSB, RHSB, SCEV::FlagAnyWrap, 0);
2336-
if (A == B)
2337+
const SCEV *A = (SE->*Extension)(
2338+
(SE->*Operation)(LHS, RHS, SCEV::FlagAnyWrap, 0), WideTy, 0);
2339+
const SCEV *LHSB = (SE->*Extension)(LHS, WideTy, 0);
2340+
const SCEV *RHSB = (SE->*Extension)(RHS, WideTy, 0);
2341+
const SCEV *B = (SE->*Operation)(LHSB, RHSB, SCEV::FlagAnyWrap, 0);
2342+
if (Check == OverflowCheckTy::WillNotOverflow && A == B)
23372343
return true;
23382344
// Can we use context to prove the fact we need?
23392345
if (!CtxI)
@@ -2361,21 +2367,31 @@ bool ScalarEvolution::willNotOverflow(Instruction::BinaryOps BinOp, bool Signed,
23612367
}
23622368

23632369
ICmpInst::Predicate Pred = Signed ? ICmpInst::ICMP_SLE : ICmpInst::ICMP_ULE;
2370+
if (Check == OverflowCheckTy::WillOverflow)
2371+
Pred = CmpInst::getInversePredicate(Pred);
2372+
23642373
if (OverflowDown) {
23652374
// To avoid overflow down, we need to make sure that MIN + Magnitude <= LHS.
23662375
APInt Min = Signed ? APInt::getSignedMinValue(NumBits)
23672376
: APInt::getMinValue(NumBits);
23682377
APInt Limit = Min + Magnitude;
2369-
return isKnownPredicateAt(Pred, getConstant(Limit), LHS, CtxI);
2378+
return SE->isKnownPredicateAt(Pred, SE->getConstant(Limit), LHS, CtxI);
23702379
} else {
23712380
// To avoid overflow up, we need to make sure that LHS <= MAX - Magnitude.
23722381
APInt Max = Signed ? APInt::getSignedMaxValue(NumBits)
23732382
: APInt::getMaxValue(NumBits);
23742383
APInt Limit = Max - Magnitude;
2375-
return isKnownPredicateAt(Pred, LHS, getConstant(Limit), CtxI);
2384+
return SE->isKnownPredicateAt(Pred, LHS, SE->getConstant(Limit), CtxI);
23762385
}
23772386
}
23782387

2388+
bool ScalarEvolution::willNotOverflow(Instruction::BinaryOps BinOp, bool Signed,
2389+
const SCEV *LHS, const SCEV *RHS,
2390+
const Instruction *CtxI) {
2391+
return checkOverflow(OverflowCheckTy::WillNotOverflow, this, BinOp, Signed,
2392+
LHS, RHS, CtxI);
2393+
}
2394+
23792395
std::optional<SCEV::NoWrapFlags>
23802396
ScalarEvolution::getStrengthenedNoWrapFlagsFromBinOp(
23812397
const OverflowingBinaryOperator *OBO) {
@@ -14930,6 +14946,39 @@ const SCEVAddRecExpr *ScalarEvolution::convertSCEVToAddRecWithPredicates(
1493014946
if (!AddRec)
1493114947
return nullptr;
1493214948

14949+
// Check if any of the transformed predicates is known to be false. In that
14950+
// case, it doesn't make sense to convert to an predicated AddRec, as the
14951+
// versioned loop will never execute.
14952+
for (const SCEVPredicate *Pred : TransformPreds) {
14953+
auto *WrapPred = dyn_cast<SCEVWrapPredicate>(Pred);
14954+
if (!WrapPred ||
14955+
(WrapPred->getFlags() != SCEVWrapPredicate::IncrementNSSW &&
14956+
WrapPred->getFlags() != SCEVWrapPredicate::IncrementNUSW))
14957+
continue;
14958+
14959+
const SCEVAddRecExpr *AddRecToCheck = WrapPred->getExpr();
14960+
const SCEV *ExitCount = getBackedgeTakenCount(AddRec->getLoop());
14961+
if (isa<SCEVCouldNotCompute>(ExitCount))
14962+
continue;
14963+
14964+
const SCEV *Step = AddRecToCheck->getStepRecurrence(*this);
14965+
if (!Step->isOne())
14966+
continue;
14967+
14968+
bool CheckSigned = WrapPred->getFlags() == SCEVWrapPredicate::IncrementNSSW;
14969+
ExitCount = getTruncateOrSignExtend(ExitCount, Step->getType());
14970+
if (checkOverflow(OverflowCheckTy::WillOverflow, this, Instruction::Add,
14971+
CheckSigned, AddRecToCheck->getStart(), ExitCount,
14972+
nullptr)) {
14973+
return nullptr;
14974+
}
14975+
const SCEV *S = getAddExpr(AddRecToCheck->getStart(), ExitCount);
14976+
if (isKnownPredicate(CheckSigned ? CmpInst::ICMP_SLT : CmpInst::ICMP_ULT, S,
14977+
AddRecToCheck->getStart())) {
14978+
return nullptr;
14979+
}
14980+
}
14981+
1493314982
// Since the transformation was successful, we can now transfer the SCEV
1493414983
// predicates.
1493514984
Preds.append(TransformPreds.begin(), TransformPreds.end());

llvm/test/Transforms/LoopVectorize/first-order-recurrence-dead-instructions.ll

Lines changed: 54 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6,16 +6,61 @@ define i8 @recurrence_phi_with_same_incoming_values_after_simplifications(i8 %fo
66
; CHECK-LABEL: define i8 @recurrence_phi_with_same_incoming_values_after_simplifications(
77
; CHECK-SAME: i8 [[FOR_START:%.*]], ptr [[DST:%.*]]) {
88
; CHECK-NEXT: [[ENTRY:.*]]:
9+
; CHECK-NEXT: br i1 false, label %[[SCALAR_PH:.*]], label %[[VECTOR_PH:.*]]
10+
; CHECK: [[VECTOR_PH]]:
11+
; CHECK-NEXT: [[BROADCAST_SPLATINSERT:%.*]] = insertelement <4 x i8> poison, i8 [[FOR_START]], i64 0
12+
; CHECK-NEXT: [[BROADCAST_SPLAT:%.*]] = shufflevector <4 x i8> [[BROADCAST_SPLATINSERT]], <4 x i8> poison, <4 x i32> zeroinitializer
13+
; CHECK-NEXT: [[TMP0:%.*]] = shufflevector <4 x i8> [[BROADCAST_SPLAT]], <4 x i8> [[BROADCAST_SPLAT]], <4 x i32> <i32 3, i32 4, i32 5, i32 6>
14+
; CHECK-NEXT: br label %[[VECTOR_BODY:.*]]
15+
; CHECK: [[VECTOR_BODY]]:
16+
; CHECK-NEXT: [[INDEX:%.*]] = phi i32 [ 0, %[[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], %[[VECTOR_BODY]] ]
17+
; CHECK-NEXT: [[OFFSET_IDX:%.*]] = add i32 1, [[INDEX]]
18+
; CHECK-NEXT: [[TMP1:%.*]] = add i32 [[OFFSET_IDX]], 0
19+
; CHECK-NEXT: [[TMP2:%.*]] = add i32 [[OFFSET_IDX]], 1
20+
; CHECK-NEXT: [[TMP3:%.*]] = add i32 [[OFFSET_IDX]], 2
21+
; CHECK-NEXT: [[TMP4:%.*]] = add i32 [[OFFSET_IDX]], 3
22+
; CHECK-NEXT: [[TMP5:%.*]] = add i32 [[OFFSET_IDX]], 4
23+
; CHECK-NEXT: [[TMP6:%.*]] = add i32 [[OFFSET_IDX]], 5
24+
; CHECK-NEXT: [[TMP7:%.*]] = add i32 [[OFFSET_IDX]], 6
25+
; CHECK-NEXT: [[TMP8:%.*]] = add i32 [[OFFSET_IDX]], 7
26+
; CHECK-NEXT: [[TMP9:%.*]] = getelementptr inbounds i8, ptr [[DST]], i32 [[TMP1]]
27+
; CHECK-NEXT: [[TMP10:%.*]] = getelementptr inbounds i8, ptr [[DST]], i32 [[TMP2]]
28+
; CHECK-NEXT: [[TMP11:%.*]] = getelementptr inbounds i8, ptr [[DST]], i32 [[TMP3]]
29+
; CHECK-NEXT: [[TMP12:%.*]] = getelementptr inbounds i8, ptr [[DST]], i32 [[TMP4]]
30+
; CHECK-NEXT: [[TMP13:%.*]] = getelementptr inbounds i8, ptr [[DST]], i32 [[TMP5]]
31+
; CHECK-NEXT: [[TMP14:%.*]] = getelementptr inbounds i8, ptr [[DST]], i32 [[TMP6]]
32+
; CHECK-NEXT: [[TMP15:%.*]] = getelementptr inbounds i8, ptr [[DST]], i32 [[TMP7]]
33+
; CHECK-NEXT: [[TMP16:%.*]] = getelementptr inbounds i8, ptr [[DST]], i32 [[TMP8]]
34+
; CHECK-NEXT: [[TMP17:%.*]] = extractelement <4 x i8> [[TMP0]], i32 0
35+
; CHECK-NEXT: store i8 [[TMP17]], ptr [[TMP9]], align 1
36+
; CHECK-NEXT: [[TMP18:%.*]] = extractelement <4 x i8> [[TMP0]], i32 1
37+
; CHECK-NEXT: store i8 [[TMP18]], ptr [[TMP10]], align 1
38+
; CHECK-NEXT: [[TMP19:%.*]] = extractelement <4 x i8> [[TMP0]], i32 2
39+
; CHECK-NEXT: store i8 [[TMP19]], ptr [[TMP11]], align 1
40+
; CHECK-NEXT: [[TMP20:%.*]] = extractelement <4 x i8> [[TMP0]], i32 3
41+
; CHECK-NEXT: store i8 [[TMP20]], ptr [[TMP12]], align 1
42+
; CHECK-NEXT: store i8 [[TMP17]], ptr [[TMP13]], align 1
43+
; CHECK-NEXT: store i8 [[TMP18]], ptr [[TMP14]], align 1
44+
; CHECK-NEXT: store i8 [[TMP19]], ptr [[TMP15]], align 1
45+
; CHECK-NEXT: store i8 [[TMP20]], ptr [[TMP16]], align 1
46+
; CHECK-NEXT: [[INDEX_NEXT]] = add nuw i32 [[INDEX]], 8
47+
; CHECK-NEXT: [[TMP21:%.*]] = icmp eq i32 [[INDEX_NEXT]], -8
48+
; CHECK-NEXT: br i1 [[TMP21]], label %[[MIDDLE_BLOCK:.*]], label %[[VECTOR_BODY]], !llvm.loop [[LOOP0:![0-9]+]]
49+
; CHECK: [[MIDDLE_BLOCK]]:
50+
; CHECK-NEXT: br label %[[SCALAR_PH]]
51+
; CHECK: [[SCALAR_PH]]:
52+
; CHECK-NEXT: [[BC_RESUME_VAL:%.*]] = phi i32 [ -7, %[[MIDDLE_BLOCK]] ], [ 1, %[[ENTRY]] ]
53+
; CHECK-NEXT: [[SCALAR_RECUR_INIT:%.*]] = phi i8 [ [[FOR_START]], %[[MIDDLE_BLOCK]] ], [ [[FOR_START]], %[[ENTRY]] ]
954
; CHECK-NEXT: br label %[[LOOP:.*]]
1055
; CHECK: [[LOOP]]:
11-
; CHECK-NEXT: [[IV:%.*]] = phi i32 [ 1, %[[ENTRY]] ], [ [[IV_NEXT:%.*]], %[[LOOP]] ]
12-
; CHECK-NEXT: [[FOR:%.*]] = phi i8 [ [[FOR_START]], %[[ENTRY]] ], [ [[FOR_NEXT:%.*]], %[[LOOP]] ]
56+
; CHECK-NEXT: [[IV:%.*]] = phi i32 [ [[BC_RESUME_VAL]], %[[SCALAR_PH]] ], [ [[IV_NEXT:%.*]], %[[LOOP]] ]
57+
; CHECK-NEXT: [[FOR:%.*]] = phi i8 [ [[SCALAR_RECUR_INIT]], %[[SCALAR_PH]] ], [ [[FOR_NEXT:%.*]], %[[LOOP]] ]
1358
; CHECK-NEXT: [[FOR_NEXT]] = and i8 [[FOR_START]], -1
1459
; CHECK-NEXT: [[IV_NEXT]] = add i32 [[IV]], 1
1560
; CHECK-NEXT: [[GEP_DST:%.*]] = getelementptr inbounds i8, ptr [[DST]], i32 [[IV]]
1661
; CHECK-NEXT: store i8 [[FOR]], ptr [[GEP_DST]], align 1
1762
; CHECK-NEXT: [[EC:%.*]] = icmp eq i32 [[IV_NEXT]], 0
18-
; CHECK-NEXT: br i1 [[EC]], label %[[EXIT:.*]], label %[[LOOP]]
63+
; CHECK-NEXT: br i1 [[EC]], label %[[EXIT:.*]], label %[[LOOP]], !llvm.loop [[LOOP3:![0-9]+]]
1964
; CHECK: [[EXIT]]:
2065
; CHECK-NEXT: [[FOR_NEXT_LCSSA:%.*]] = phi i8 [ [[FOR_NEXT]], %[[LOOP]] ]
2166
; CHECK-NEXT: ret i8 [[FOR_NEXT_LCSSA]]
@@ -61,7 +106,7 @@ define i32 @sink_after_dead_inst(ptr %A.ptr) {
61106
; CHECK-NEXT: [[INDEX_NEXT]] = add nuw i32 [[INDEX]], 8
62107
; CHECK-NEXT: [[VEC_IND_NEXT]] = add <4 x i16> [[STEP_ADD]], splat (i16 4)
63108
; CHECK-NEXT: [[TMP6:%.*]] = icmp eq i32 [[INDEX_NEXT]], 16
64-
; CHECK-NEXT: br i1 [[TMP6]], label %[[MIDDLE_BLOCK:.*]], label %[[VECTOR_BODY]], !llvm.loop [[LOOP0:![0-9]+]]
109+
; CHECK-NEXT: br i1 [[TMP6]], label %[[MIDDLE_BLOCK:.*]], label %[[VECTOR_BODY]], !llvm.loop [[LOOP4:![0-9]+]]
65110
; CHECK: [[MIDDLE_BLOCK]]:
66111
; CHECK-NEXT: [[VECTOR_RECUR_EXTRACT_FOR_PHI:%.*]] = extractelement <4 x i32> [[TMP2]], i32 2
67112
; CHECK-NEXT: br label %[[FOR_END:.*]]
@@ -82,7 +127,7 @@ define i32 @sink_after_dead_inst(ptr %A.ptr) {
82127
; CHECK-NEXT: [[EXT:%.*]] = zext i1 [[B3]] to i32
83128
; CHECK-NEXT: [[A_GEP:%.*]] = getelementptr i32, ptr [[A_PTR]], i16 [[IV]]
84129
; CHECK-NEXT: store i32 0, ptr [[A_GEP]], align 4
85-
; CHECK-NEXT: br i1 [[VEC_DEAD]], label %[[FOR_END]], label %[[LOOP]], !llvm.loop [[LOOP3:![0-9]+]]
130+
; CHECK-NEXT: br i1 [[VEC_DEAD]], label %[[FOR_END]], label %[[LOOP]], !llvm.loop [[LOOP5:![0-9]+]]
86131
; CHECK: [[FOR_END]]:
87132
; CHECK-NEXT: [[FOR_LCSSA:%.*]] = phi i32 [ [[FOR]], %[[LOOP]] ], [ [[VECTOR_RECUR_EXTRACT_FOR_PHI]], %[[MIDDLE_BLOCK]] ]
88133
; CHECK-NEXT: ret i32 [[FOR_LCSSA]]
@@ -142,7 +187,7 @@ define void @sink_dead_inst(ptr %a) {
142187
; CHECK-NEXT: [[INDEX_NEXT]] = add nuw i32 [[INDEX]], 8
143188
; CHECK-NEXT: [[VEC_IND_NEXT]] = add <4 x i16> [[STEP_ADD]], splat (i16 4)
144189
; CHECK-NEXT: [[TMP12:%.*]] = icmp eq i32 [[INDEX_NEXT]], 40
145-
; CHECK-NEXT: br i1 [[TMP12]], label %[[MIDDLE_BLOCK:.*]], label %[[VECTOR_BODY]], !llvm.loop [[LOOP4:![0-9]+]]
190+
; CHECK-NEXT: br i1 [[TMP12]], label %[[MIDDLE_BLOCK:.*]], label %[[VECTOR_BODY]], !llvm.loop [[LOOP6:![0-9]+]]
146191
; CHECK: [[MIDDLE_BLOCK]]:
147192
; CHECK-NEXT: [[VECTOR_RECUR_EXTRACT:%.*]] = extractelement <4 x i16> [[TMP4]], i32 3
148193
; CHECK-NEXT: [[VECTOR_RECUR_EXTRACT1:%.*]] = extractelement <4 x i32> [[TMP2]], i32 3
@@ -163,7 +208,7 @@ define void @sink_dead_inst(ptr %a) {
163208
; CHECK-NEXT: [[REC_1_PREV]] = add i16 [[IV_NEXT]], 5
164209
; CHECK-NEXT: [[GEP:%.*]] = getelementptr i16, ptr [[A]], i16 [[IV]]
165210
; CHECK-NEXT: store i16 [[USE_REC_1]], ptr [[GEP]], align 2
166-
; CHECK-NEXT: br i1 [[CMP]], label %[[FOR_END:.*]], label %[[FOR_COND]], !llvm.loop [[LOOP5:![0-9]+]]
211+
; CHECK-NEXT: br i1 [[CMP]], label %[[FOR_END:.*]], label %[[FOR_COND]], !llvm.loop [[LOOP7:![0-9]+]]
167212
; CHECK: [[FOR_END]]:
168213
; CHECK-NEXT: ret void
169214
;
@@ -205,7 +250,7 @@ define void @unused_recurrence(ptr %a) {
205250
; CHECK-NEXT: [[INDEX_NEXT]] = add nuw i32 [[INDEX]], 8
206251
; CHECK-NEXT: [[VEC_IND_NEXT]] = add <4 x i16> [[STEP_ADD]], splat (i16 4)
207252
; CHECK-NEXT: [[TMP2:%.*]] = icmp eq i32 [[INDEX_NEXT]], 1024
208-
; CHECK-NEXT: br i1 [[TMP2]], label %[[MIDDLE_BLOCK:.*]], label %[[VECTOR_BODY]], !llvm.loop [[LOOP6:![0-9]+]]
253+
; CHECK-NEXT: br i1 [[TMP2]], label %[[MIDDLE_BLOCK:.*]], label %[[VECTOR_BODY]], !llvm.loop [[LOOP8:![0-9]+]]
209254
; CHECK: [[MIDDLE_BLOCK]]:
210255
; CHECK-NEXT: [[VECTOR_RECUR_EXTRACT:%.*]] = extractelement <4 x i16> [[TMP1]], i32 3
211256
; CHECK-NEXT: br label %[[SCALAR_PH]]
@@ -220,7 +265,7 @@ define void @unused_recurrence(ptr %a) {
220265
; CHECK-NEXT: [[IV_NEXT]] = add i16 [[IV]], 1
221266
; CHECK-NEXT: [[REC_1_PREV]] = add i16 [[IV_NEXT]], 5
222267
; CHECK-NEXT: [[CMP:%.*]] = icmp eq i16 [[IV]], 1000
223-
; CHECK-NEXT: br i1 [[CMP]], label %[[FOR_END:.*]], label %[[FOR_COND]], !llvm.loop [[LOOP7:![0-9]+]]
268+
; CHECK-NEXT: br i1 [[CMP]], label %[[FOR_END:.*]], label %[[FOR_COND]], !llvm.loop [[LOOP9:![0-9]+]]
224269
; CHECK: [[FOR_END]]:
225270
; CHECK-NEXT: ret void
226271
;

0 commit comments

Comments
 (0)